#!/usr/bin/env python3
#
# @File: test-rembg-comparison.py
# @Date: 2026-05-19
#
# Background removal quality test: rembg (free, self-hosted) vs ImgIX vs remove.bg.
# Sources: outdoor car photos from DuckDuckGo + real dealer photos from media server.
#
# Install:
#   pip install "rembg[cpu]" Pillow requests duckduckgo-search
#
# Basic run — internet photos only:
#   python test-rembg-comparison.py --output /tmp/cars
#
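# With real dealer photos from media server (ImgIX runs too, because
# --imgix-domain defaults to liftkit.imgix.net; add --imgix-domain '' to skip it):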
#   python test-rembg-comparison.py --output /tmp/cars \
#     --ssh-server root@165.227.32.132 \
#     --media-path /var/www/media.liftkit.click/public_html
#
# Full comparison with ImgIX (--imgix-domain shown explicitly; it is also the default):
#   python test-rembg-comparison.py --output /tmp/cars \
#     --ssh-server root@165.227.32.132 \
#     --media-path /var/www/media.liftkit.click/public_html \
#     --imgix-domain liftkit.imgix.net
#
# Optional: also compare remove.bg (free key at remove.bg/api):
#   ... --removebg-key=YOUR_KEY

import argparse
import io
import shlex
import subprocess
import sys
import time
from pathlib import Path
from urllib.parse import quote, urlparse

import requests

try:
    from rembg import remove, new_session
    from PIL import Image
except ImportError:
    print("ERROR: Run:  pip install 'rembg[cpu]' Pillow requests duckduckgo-search")
    sys.exit(1)

try:
    from duckduckgo_search import DDGS
except ImportError:
    print("ERROR: Run:  pip install duckduckgo-search")
    sys.exit(1)

# Real-world outdoor shots — cars WITH backgrounds (the actual use case).
# Queried in order by fetch_web_images() until --web-count images are saved.
SEARCH_TERMS = [
    'used car outdoor dealership lot for sale',
    'SUV truck parked outdoor lot',
    'sedan coupe outdoor parking lot dealership',
    'used car outdoor photo natural background',
]
# Browser-like User-Agent sent with web image downloads and ImgIX fetches.
HEADERS = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'}
# Sub-folder of --media-path used as the temporary ImgIX upload location;
# run_imgix() deletes it (rm -rf) when it is done.
IMGIX_TEST_DIR = 'bgtest'
# Options for every ssh/scp call. StrictHostKeyChecking=no skips host-key
# verification (the server is trusted blindly — review before reusing), and
# BatchMode=yes makes ssh/scp fail instead of prompting for a password.
SSH_OPTS = ['-o', 'StrictHostKeyChecking=no', '-o', 'BatchMode=yes']


# ─────────────────────────────────────────────
# Step 1a: Download outdoor car images from DuckDuckGo
# ─────────────────────────────────────────────
# Pillow format name -> file extension, for the formats this script processes.
_WEB_IMAGE_EXTS = {'JPEG': '.jpg', 'PNG': '.png', 'WEBP': '.webp'}


def _sniff_image_ext(data: bytes) -> str:
    """Return the extension for *data* if Pillow reads it as JPEG/PNG/WebP, else ''.

    A 200 response above the size threshold can still be an HTML page or an
    unsupported format. Checking the bytes instead of trusting the URL suffix
    keeps such files out of the comparison set and names the rest correctly.
    """
    try:
        with Image.open(io.BytesIO(data)) as img:
            fmt = img.format
            img.verify()
    except Exception:  # Pillow raises assorted types (OSError, SyntaxError, ...) on bad data
        return ''
    return _WEB_IMAGE_EXTS.get(fmt, '')


def fetch_web_images(source_dir: Path, count: int) -> list[Path]:
    """Download up to *count* outdoor car photos found via DuckDuckGo image search.

    Each term in SEARCH_TERMS is searched in turn until *count* images are
    saved as ``web_NN.<ext>`` in *source_dir* (created if missing). A result
    is skipped when:

    - its URL was already seen under an earlier term;
    - the download raises a request error;
    - the response is not HTTP 200 or is 10 KB or smaller;
    - the bytes do not decode as JPEG, PNG or WebP.

    Search failures for one term are printed and the next term is tried.

    Returns:
        Paths of the saved files, in download order (possibly fewer than *count*).
    """
    source_dir.mkdir(parents=True, exist_ok=True)
    downloaded: list[Path] = []
    seen_urls = set()
    # Over-fetch per term (x3 below) because many results get rejected.
    per_term = max(1, count // len(SEARCH_TERMS) + 1)

    print(f"\n── Downloading {count} outdoor car images from DuckDuckGo ──\n")

    with DDGS() as ddgs:
        for term in SEARCH_TERMS:
            if len(downloaded) >= count:
                break
            print(f"  Searching: \"{term}\"")
            try:
                results = list(ddgs.images(term, max_results=per_term * 3, type_image='photo'))
            except Exception as e:
                print(f"  Search failed: {e}")
                continue

            for r in results:
                if len(downloaded) >= count:
                    break
                url = r.get('image', '')
                if not url or url in seen_urls:
                    continue
                seen_urls.add(url)

                try:
                    resp = requests.get(url, timeout=15, headers=HEADERS)
                except (requests.RequestException, ValueError) as e:
                    print(f"  skip {url[:80]}: {type(e).__name__}")
                    continue
                if resp.status_code != 200 or len(resp.content) <= 10_000:
                    continue

                ext = _sniff_image_ext(resp.content)
                if not ext:
                    print(f"  skip {url[:80]}: not a JPEG/PNG/WebP image")
                    continue

                idx = len(downloaded) + 1
                dest = source_dir / f"web_{idx:02d}{ext}"
                dest.write_bytes(resp.content)
                print(f"  [web_{idx:02d}] {dest.name}  ({len(resp.content)//1024} KB)")
                downloaded.append(dest)

    print(f"\n  Downloaded {len(downloaded)} web images → {source_dir}\n")
    return downloaded


# ─────────────────────────────────────────────
# Step 1b: Fetch real dealer photos from media server
# ─────────────────────────────────────────────
def fetch_dealer_images(source_dir: Path, count: int, ssh_server: str, media_path: str, ssh_key: str = '') -> list[Path]:
    """Copy up to *count* real dealer photos from the media server over SSH.

    On *ssh_server*, every directory under *media_path* containing a
    ``*-large.jpg`` file contributes its first such file (in ``ls`` order).
    Each file is then copied with scp to ``dealer_NN.jpg`` in *source_dir*.

    Args:
        source_dir: Local folder to write into (created if missing).
        count: Maximum number of photos to fetch; ``<= 0`` fetches nothing.
        ssh_server: SSH target, e.g. ``user@host``.
        media_path: Absolute media root on the server.
        ssh_key: Optional private key, passed to ssh/scp as ``-i``.

    Returns:
        Local paths of the photos copied successfully. Failures are printed.
    """
    source_dir.mkdir(parents=True, exist_ok=True)
    if count <= 0:
        # `head -n 0` would print nothing and be misreported as a server error.
        return []

    print(f"── Fetching {count} real dealer photos from {ssh_server} ──\n")

    key_opts = ['-i', ssh_key] if ssh_key else []
    ssh_cmd = ['ssh', *SSH_OPTS, *key_opts]

    # Find one large JPG per unique vehicle directory. The string runs through
    # the remote shell, so the media path is shell-quoted. `IFS= read -r`
    # keeps leading blanks and backslashes in directory names intact.
    find_cmd = (
        f'find {shlex.quote(media_path)} -name "*-large.jpg" -printf "%h\\n" | '
        f'sort -u | '
        f'while IFS= read -r dir; do ls "$dir"/*-large.jpg 2>/dev/null | head -1; done | '
        f'head -n {count}'
    )
    result = subprocess.run([*ssh_cmd, ssh_server, find_cmd], capture_output=True, text=True)
    if result.returncode != 0 or not result.stdout.strip():
        print(f"  ERROR finding photos on server: {result.stderr.strip()}")
        return []

    remote_paths = [p.strip() for p in result.stdout.splitlines() if p.strip()][:count]
    print(f"  Found {len(remote_paths)} vehicle photos on server")

    downloaded = []
    for idx, remote_path in enumerate(remote_paths, 1):
        dest = source_dir / f"dealer_{idx:02d}.jpg"
        scp_cmd = ['scp', *SSH_OPTS, *key_opts, f"{ssh_server}:{remote_path}", str(dest)]
        result = subprocess.run(scp_cmd, capture_output=True, text=True)
        if result.returncode == 0 and dest.exists():
            print(f"  [dealer_{idx:02d}] {Path(remote_path).name}  ({dest.stat().st_size//1024} KB)")
            downloaded.append(dest)
        else:
            print(f"  [dealer_{idx:02d}] FAILED: {result.stderr.strip()}")

    print(f"\n  Fetched {len(downloaded)} dealer images → {source_dir}\n")
    return downloaded


# ─────────────────────────────────────────────
# Step 2: Upload to media server + fetch ImgIX output
# ─────────────────────────────────────────────
def run_imgix(images: list[Path], output_dir: Path, ssh_server: str, media_path: str, imgix_domain: str, ssh_key: str = ''):
    """Fetch ImgIX background-removal output for *images* into *output_dir*.

    The images are scp'd to ``<media_path>/<IMGIX_TEST_DIR>`` on *ssh_server*.
    Each one is then requested from
    ``https://<imgix_domain>/<IMGIX_TEST_DIR>/<name>`` with the bg-remove
    parameters and saved as ``<stem>.png``. The temporary remote folder is
    deleted afterwards, including when the upload or a download fails
    part-way. Errors are printed, not raised.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    remote_dir = f"{media_path}/{IMGIX_TEST_DIR}"
    # Remote ssh commands go through a shell, so quote the path so spaces or
    # metacharacters can't split it. Unquoted, "rm -rf /srv/my site/bgtest"
    # would delete /srv/my.
    remote_dir_q = shlex.quote(remote_dir)

    print("── ImgIX background removal ──\n")
    print(f"  Uploading {len(images)} images to {ssh_server}:{remote_dir} ...")

    key_opts = ['-i', ssh_key] if ssh_key else []
    ssh_cmd_base = ['ssh', *SSH_OPTS, *key_opts, ssh_server]

    result = subprocess.run([*ssh_cmd_base, f'mkdir -p {remote_dir_q}'], capture_output=True, text=True)
    if result.returncode != 0:
        print(f"  ERROR creating remote dir: {result.stderr.strip()}")
        return

    try:
        scp_cmd = ['scp', *SSH_OPTS, *key_opts, *(str(p) for p in images), f"{ssh_server}:{remote_dir}/"]
        result = subprocess.run(scp_cmd, capture_output=True, text=True)
        if result.returncode != 0:
            print(f"  ERROR uploading: {result.stderr.strip()}")
            return
        print("  Upload done.\n")

        # Short pause before the first request, presumably so the new files are
        # reachable through ImgIX — confirm whether it is still needed.
        time.sleep(2)

        ok = fail = 0
        for i, img_path in enumerate(images, 1):
            out = output_dir / (img_path.stem + '.png')
            # Percent-encode the filename so names with spaces or '#'/'?' form a valid URL.
            imgix_url = (
                f"https://{imgix_domain}/{IMGIX_TEST_DIR}/{quote(img_path.name)}"
                f"?bg-remove=true&bg-remove-fg-type=car&bg-remove-add-shadow=true&fm=png"
            )
            print(f"  [{i:02d}/{len(images)}] {img_path.name}", end=' ... ', flush=True)
            try:
                resp = requests.get(imgix_url, timeout=30, headers=HEADERS)
                if resp.status_code == 200:
                    out.write_bytes(resp.content)
                    print(f"OK  {out.stat().st_size//1024} KB")
                    ok += 1
                else:
                    print(f"FAILED HTTP {resp.status_code}  ({imgix_url})")
                    fail += 1
            except Exception as e:
                print(f"FAILED: {e}")
                fail += 1

        print(f"\n  ImgIX: {ok} OK, {fail} failed\n")
    finally:
        # Always remove the temporary upload folder, even on failure or Ctrl-C.
        subprocess.run([*ssh_cmd_base, f'rm -rf {remote_dir_q}'], capture_output=True)


# ─────────────────────────────────────────────
# Step 3: rembg
# ─────────────────────────────────────────────
def run_rembg(images: list[Path], output_dir: Path, model: str):
    """Remove backgrounds from *images* locally with rembg.

    Loads the rembg model named *model* once, restricted to the CPU execution
    provider. Each result is written to ``<output_dir>/<stem>.png``, with its
    processing time printed. Per-image failures are printed and counted, not
    raised. A failure to load the model still propagates.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    print(f"── rembg '{model}' ──\n")
    print("  Loading model ...")
    # Force CPU to avoid CoreML crashes on Apple Silicon with BiRefNet
    session = new_session(model, providers=['CPUExecutionProvider'])
    print("  Ready.\n")

    ok = fail = 0
    for i, img_path in enumerate(images, 1):
        out = output_dir / (img_path.stem + '.png')
        print(f"  [{i:02d}/{len(images)}] {img_path.name}", end=' ... ', flush=True)
        try:
            # perf_counter is monotonic; time.time() can jump when the wall clock changes.
            start = time.perf_counter()
            out.write_bytes(remove(img_path.read_bytes(), session=session))
            elapsed = time.perf_counter() - start
            print(f"OK  {out.stat().st_size//1024} KB  ({elapsed:.1f}s)")
            ok += 1
        except Exception as e:
            print(f"FAILED: {e}")
            fail += 1

    print(f"\n  rembg: {ok} OK, {fail} failed\n")


# ─────────────────────────────────────────────
# Step 4: remove.bg (optional reference)
# ─────────────────────────────────────────────
def _removebg_error_title(resp) -> str:
    """Return the first error title from a remove.bg error response, or a body snippet.

    Falls back to the first 60 characters of the body when the response is not
    JSON (e.g. a gateway HTML page) or lacks the ``errors[0].title`` shape.
    """
    try:
        errors = resp.json().get('errors') or [{}]
        return errors[0].get('title') or resp.text[:60]
    except (ValueError, LookupError, AttributeError, TypeError):
        return resp.text[:60]


def run_removebg(images: list[Path], output_dir: Path, api_key: str):
    """Remove backgrounds from *images* via the remove.bg HTTP API.

    Each image is posted to ``/v1.0/removebg`` with ``type=car``, and a
    successful response is saved to ``<output_dir>/<stem>.png``. The
    ``X-Credits-Charged`` header is printed per image. Failures are printed
    and counted, not raised.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    print("── remove.bg API ──\n")

    ok = fail = 0
    for i, img_path in enumerate(images, 1):
        out = output_dir / (img_path.stem + '.png')
        print(f"  [{i:02d}/{len(images)}] {img_path.name}", end=' ... ', flush=True)
        try:
            # `with` closes the upload handle once the request completes.
            with img_path.open('rb') as fh:
                r = requests.post(
                    'https://api.remove.bg/v1.0/removebg',
                    files={'image_file': fh},
                    # format=png: the default 'auto' may return JPG, which would
                    # then be saved under a .png name.
                    data={'size': 'auto', 'type': 'car', 'format': 'png'},
                    headers={'X-Api-Key': api_key},
                    timeout=60,
                )
            if r.status_code == 200:
                out.write_bytes(r.content)
                credits = r.headers.get('X-Credits-Charged', '?')
                print(f"OK  {out.stat().st_size//1024} KB  ({credits} credit)")
                ok += 1
            else:
                print(f"FAILED HTTP {r.status_code}: {_removebg_error_title(r)}")
                fail += 1
        except Exception as e:
            print(f"FAILED: {e}")
            fail += 1

    print(f"\n  remove.bg: {ok} OK, {fail} failed\n")


# ─────────────────────────────────────────────
# Main
# ─────────────────────────────────────────────
def main():
    """Parse CLI arguments, gather source images, and run each enabled remover.

    Source images come from either the ``--source`` folder, or a web search
    plus optional dealer photos over SSH. Then these steps run:

    - ImgIX, when ``--ssh-server`` and ``--imgix-domain`` are both non-empty;
    - rembg, unless ``--skip-rembg`` is given;
    - remove.bg, when ``--removebg-key`` is given.

    Each step writes to its own sub-folder of ``--output``.

    Exits with status 2 on invalid arguments and 1 when there is nothing to
    process.
    """
    parser = argparse.ArgumentParser(description='Background removal quality comparison')
    parser.add_argument('--output',        required=True,  help='Root folder for all results')
    parser.add_argument('--web-count',     type=int, default=15, help='Outdoor car images to download from web (default: 15)')
    parser.add_argument('--dealer-count',  type=int, default=15, help='Real dealer photos to fetch from server (default: 15)')
    parser.add_argument('--source',        default='',     help='Skip download — use an existing folder of images')
    parser.add_argument('--model',         default='isnet-general-use',
                        help='rembg model (default: isnet-general-use). Also try: birefnet-general, u2net')
    parser.add_argument('--skip-rembg',    action='store_true', help='Skip rembg step (use existing output)')
    # Media server (for dealer photos + ImgIX upload)
    parser.add_argument('--ssh-server',    default='',     help='SSH target, e.g. root@165.227.32.132')
    parser.add_argument('--media-path',    default='/var/www/media.liftkit.click/public_html',
                        help='Absolute path on media server')
    parser.add_argument('--imgix-domain',  default='liftkit.imgix.net', help='ImgIX domain')
    parser.add_argument('--ssh-key',       default='',     help='Path to SSH private key (optional)')
    # remove.bg
    parser.add_argument('--removebg-key',  default='',     help='remove.bg API key (optional)')
    args = parser.parse_args()

    # Validate up front so bad input fails with a usage message, not a traceback.
    if args.web_count < 0 or args.dealer_count < 0:
        parser.error('--web-count and --dealer-count must be >= 0')
    if args.source and not Path(args.source).is_dir():
        parser.error(f'--source is not a directory: {args.source}')

    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    # --imgix-domain has a non-empty default, so ImgIX runs whenever --ssh-server
    # is set; pass --imgix-domain '' to fetch dealer photos without ImgIX.
    run_imgix_step     = bool(args.ssh_server and args.imgix_domain)
    run_removebg_step  = bool(args.removebg_key)

    print("=" * 50)
    print("  Background Removal Comparison")
    print("=" * 50)
    if not args.source:
        print(f"  Web images  : {args.web_count} (outdoor/lot photos)")
        if args.ssh_server:
            print(f"  Dealer photos: {args.dealer_count} from {args.ssh_server}")
    print(f"  rembg model : {args.model}")
    print(f"  ImgIX       : {'enabled → ' + args.imgix_domain if run_imgix_step else 'skipped'}")
    print(f"  remove.bg   : {'enabled' if run_removebg_step else 'skipped'}")
    print()

    # Step 1: get source images
    if args.source:
        source_dir = Path(args.source)
        # is_file() skips sub-folders that happen to have an image-like suffix.
        images = sorted(
            f for f in source_dir.iterdir()
            if f.is_file() and f.suffix.lower() in ('.jpg', '.jpeg', '.png', '.webp')
        )
        print(f"Using {len(images)} existing images from {source_dir}\n")
    else:
        source_dir = output_dir / 'source'
        images = []

        # 1a: web outdoor photos
        web_images = fetch_web_images(source_dir, args.web_count)
        images.extend(web_images)

        # 1b: real dealer photos from media server
        if args.ssh_server:
            dealer_images = fetch_dealer_images(source_dir, args.dealer_count, args.ssh_server, args.media_path, args.ssh_key)
            images.extend(dealer_images)

    if not images:
        print("ERROR: No images to process.")
        sys.exit(1)

    print(f"Total images to process: {len(images)}\n")

    # Step 2: ImgIX
    if run_imgix_step:
        run_imgix(images, output_dir / 'imgix', args.ssh_server, args.media_path, args.imgix_domain, args.ssh_key)

    # Step 3: rembg
    if not args.skip_rembg:
        run_rembg(images, output_dir / 'rembg', args.model)

    # Step 4: remove.bg
    if run_removebg_step:
        run_removebg(images, output_dir / 'removebg', args.removebg_key)

    # Summary
    print("=" * 50)
    print("  Done! Compare these folders:")
    print(f"  Original  → {source_dir}")
    if run_imgix_step:
        print(f"  ImgIX     → {output_dir / 'imgix'}")
    if not args.skip_rembg:
        print(f"  rembg     → {output_dir / 'rembg'}")
    if run_removebg_step:
        print(f"  remove.bg → {output_dir / 'removebg'}")
    print()
    print("  What to check:")
    print("  - Wheels and tires (main failure area)")
    print("  - Window glass (should be transparent)")
    print("  - Mirrors and roof edges")
    print("  - Overall edge sharpness")
    print("  - web_* = internet outdoor photos")
    print("  - dealer_* = real dealer uploads")
    print("=" * 50)


# Run the comparison only when executed as a script; importing defines the helpers only.
if __name__ == '__main__':
    main()
