#!/usr/bin/env python3
#
# @File: fetch-stock-masks.py
# @Date: 2026-06-04
#
# Upload manually curated stock car photos to the media server, fetch ImgIX
# bg-remove masks, and save pairs to dataset/im/ + dataset/gt/.
#
# Handles JPEG, PNG, AVIF, WebP — converts everything to JPEG before upload.
# Preserves soft alpha (0-255) from ImgIX — windows get ~120-160, not binarised.
# Adds s-prefix names (s001, s002 …) to keep stock images distinct from server images.
#
# Usage:
#   cd bionicbobby/
#   python3 tools/internal/fetch-stock-masks.py \
#       --source "/Users/muhammadalishahnawaz/Downloads/cars data set" \
#       --dataset dataset/
#
#   Add --dry-run to list the images that would be processed without uploading
#   anything or calling ImgIX.

import argparse
import shutil
import subprocess
import sys
import tempfile
import time
from io import BytesIO
from pathlib import Path

import numpy as np
import requests
from PIL import Image

# --- Media server (reached with ssh/scp) -------------------------------------
SSH_SERVER = "root@165.227.32.132"
# Remote directory under which REMOTE_DIR is created; presumably the web root
# that the ImgIX source is configured to read from (verify against the ImgIX
# source settings).
MEDIA_PATH = "/var/www/media.liftkit.click/public_html"
# Sub-folder used both for the upload target and for the ImgIX URL path.
REMOTE_DIR = "training-stock"

# --- ImgIX -------------------------------------------------------------------
IMGIX_DOMAIN = "liftkit-dev6.imgix.net"
# Query string appended to every mask request.
IMGIX_PARAMS = "&".join((
    "bg-remove=true",
    "bg-remove-add-shadow=true",
    "fm=png",
))
# Headers sent with each ImgIX request.
HEADERS = {"User-Agent": "Mozilla/5.0"}

# --- SSH options shared by the ssh and scp invocations -----------------------
SSH_KEY = str(Path.home().joinpath(".ssh", "id_ed25519"))
# NOTE(review): StrictHostKeyChecking=no turns off host-key verification.
SSH_OPTS = ["-i", SSH_KEY]
SSH_OPTS += ["-o", "StrictHostKeyChecking=no"]
SSH_OPTS += ["-o", "ConnectTimeout=15"]

# Source-file suffixes (compared lower-cased) that the pipeline picks up.
SUPPORTED_EXTS = set(".jpg .jpeg .png .avif .webp".split())


def ssh(cmd: str) -> str:
    """Run *cmd* on SSH_SERVER and return its stdout with surrounding whitespace removed.

    Raises:
        RuntimeError: ssh exited with a non-zero status; the message carries
            the stripped stderr output.
        subprocess.TimeoutExpired: the command did not finish within 30 seconds.
    """
    argv = ["ssh", *SSH_OPTS, SSH_SERVER, cmd]
    proc = subprocess.run(argv, capture_output=True, text=True, timeout=30)
    if proc.returncode == 0:
        return proc.stdout.strip()
    raise RuntimeError(f"SSH error: {proc.stderr.strip()}")


def scp_up(local: Path, remote: str) -> bool:
    """Copy *local* to the path *remote* on SSH_SERVER using scp.

    Returns:
        True when scp exits with status 0. False when scp exits non-zero or
        the transfer does not finish within 90 seconds, so a single stalled
        upload is reported like any other failed upload instead of raising
        ``subprocess.TimeoutExpired`` into the caller.
    """
    argv = ["scp", *SSH_OPTS, str(local), f"{SSH_SERVER}:{remote}"]
    try:
        result = subprocess.run(argv, capture_output=True, text=True, timeout=90)
    except subprocess.TimeoutExpired:
        # subprocess.run has already killed the child at this point.
        return False
    return result.returncode == 0


def fetch_mask(rel_path: str) -> "np.ndarray | None":
    """Download the ImgIX bg-remove render of *rel_path* and return its alpha channel.

    The request goes to ``https://IMGIX_DOMAIN/<rel_path>?IMGIX_PARAMS`` with a
    60 second timeout.

    Returns:
        The alpha channel as a NumPy array, or None when the response status
        is not 200, the alpha channel is entirely zero, or any exception is
        raised while fetching or decoding. A short diagnostic is printed in
        each of the None cases.
    """
    mask_url = f"https://{IMGIX_DOMAIN}/{rel_path}?{IMGIX_PARAMS}"
    try:
        response = requests.get(mask_url, timeout=60, headers=HEADERS)
        if response.status_code != 200:
            print(f"  ImgIX HTTP {response.status_code}")
            return None
        decoded = Image.open(BytesIO(response.content))
        # Converting to RGBA first also covers palette (mode-P) PNGs that carry
        # their transparency in the palette rather than in a separate band.
        *_, alpha_band = decoded.convert("RGBA").split()
        alpha = np.array(alpha_band)
        if alpha.max() == 0:
            print("  Empty mask — skipping")
            return None
    except Exception as exc:
        # Best effort: any network or decode problem counts as "no mask".
        print(f"  Error: {exc}")
        return None
    return alpha


def next_stock_index(im_dir: Path) -> int:
    """Return the lowest index N >= 1 for which *im_dir* holds no ``s<N>.jpg``.

    Only files matching ``s*.jpg`` whose name after the leading ``s`` is a
    decimal number are counted; anything else (``sample.jpg``, ``s.jpg``) is
    ignored. Gaps are reused: with s001 and s003 present the result is 2.
    """
    # isdecimal() rather than isdigit(): isdigit() also accepts characters such
    # as superscripts ("²") that int() cannot parse, which would raise here.
    used = {
        int(p.stem[1:])
        for p in im_dir.glob("s*.jpg")
        if p.stem[1:].isdecimal()
    }
    # Among the len(used) + 1 candidates 1..len(used)+1 at least one is free.
    return next(i for i in range(1, len(used) + 2) if i not in used)


def _process_image(src_path: Path, stem: str, im_dest: Path, gt_dest: Path,
                   remote_base: str, work_dir: Path) -> bool:
    """Convert, upload and mask one photo; return True once the pair is saved.

    Progress fragments are printed onto the output line the caller has already
    started. The intermediate JPEG lives in *work_dir* and is always removed,
    whichever step fails.
    """
    tmp_jpg = work_dir / f"{stem}_stock.jpg"
    try:
        # Convert to JPEG (handles AVIF, PNG, WebP). The context manager closes
        # the source file handle even when decoding fails part-way.
        try:
            with Image.open(src_path) as img:
                img.convert("RGB").save(tmp_jpg, "JPEG", quality=92)
        except Exception as e:
            print(f"→ convert failed: {e}")
            return False

        # Upload to media server
        print("→ uploading", end=" ", flush=True)
        if not scp_up(tmp_jpg, f"{remote_base}/{stem}.jpg"):
            print("FAILED")
            return False

        # Fetch ImgIX mask (soft alpha, preserves window semi-transparency)
        print("→ ImgIX", end=" ", flush=True)
        time.sleep(0.3)
        mask = fetch_mask(f"{REMOTE_DIR}/{stem}.jpg")
        if mask is None:
            print("FAILED")
            return False

        # Save pair. shutil.move (unlike Path.rename) also works when the
        # temp directory and the dataset are on different filesystems.
        shutil.move(str(tmp_jpg), str(im_dest))
        Image.fromarray(mask, mode="L").save(gt_dest)
        size_kb = im_dest.stat().st_size // 1024
        print(f"→ OK  ({size_kb} KB)")
        return True
    finally:
        tmp_jpg.unlink(missing_ok=True)


def main() -> None:
    """Command-line entry point: turn a folder of photos into im/gt dataset pairs.

    Every supported image in ``--source`` is converted to JPEG, uploaded to the
    media server, and its ImgIX alpha mask is downloaded. The JPEG is stored as
    ``<dataset>/im/sNNN.jpg`` and the mask as ``<dataset>/gt/sNNN.png``.

    Each source image is given a slot that is free in both ``im/`` and
    ``gt/``, so existing files are never overwritten and no source image is
    dropped because its slot is occupied. A failed image still consumes its
    index. There is no source-to-slot bookkeeping, so running the script twice
    on the same folder adds the images a second time.

    Exits with status 1 when the source folder does not exist.
    """
    parser = argparse.ArgumentParser(
        description="Upload stock photos to media server and fetch ImgIX bg-remove masks"
    )
    parser.add_argument("--source",  required=True, help="Folder of manually curated car photos (any format)")
    parser.add_argument("--dataset", default="dataset", help="Dataset output dir (default: dataset/)")
    parser.add_argument("--dry-run", action="store_true", help="List images without uploading or calling ImgIX")
    args = parser.parse_args()

    source_dir = Path(args.source)
    im_dir     = Path(args.dataset) / "im"
    gt_dir     = Path(args.dataset) / "gt"

    # is_dir() rather than exists(): a plain file would otherwise reach
    # iterdir() below and fail with a traceback.
    if not source_dir.is_dir():
        print(f"ERROR: source folder not found: {source_dir}")
        sys.exit(1)

    # Sorted for a deterministic processing order; directories that happen to
    # carry an image suffix are ignored.
    images = sorted(
        p for p in source_dir.iterdir()
        if p.is_file() and p.suffix.lower() in SUPPORTED_EXTS
    )

    print("=" * 60)
    print("  Stock Photo → Dataset Pipeline")
    print("=" * 60)
    print(f"Source     : {source_dir}/  ({len(images)} images)")
    print(f"Dataset    : {args.dataset}/")
    print(f"ImgIX      : {IMGIX_DOMAIN}")
    print(f"Credits    : ~{len(images)} bg-remove calls\n")

    if args.dry_run:
        for i, p in enumerate(images, 1):
            print(f"  [{i:02d}] {p.name}")
        return

    im_dir.mkdir(parents=True, exist_ok=True)
    gt_dir.mkdir(parents=True, exist_ok=True)

    remote_base = f"{MEDIA_PATH}/{REMOTE_DIR}"
    ssh(f"mkdir -p {remote_base}")

    idx = next_stock_index(im_dir)
    ok = fail = skip = 0

    # Private scratch directory instead of fixed /tmp/<stem>_stock.jpg names;
    # it is removed on exit even if the loop raises.
    with tempfile.TemporaryDirectory(prefix="stock-masks-") as work:
        work_dir = Path(work)

        for src_path in images:
            # Advance to a slot where neither file exists. An occupied slot is
            # reported and counted, but the current source image is kept and
            # simply lands in the next free slot.
            while True:
                stem    = f"s{idx:03d}"
                im_dest = im_dir / f"{stem}.jpg"
                gt_dest = gt_dir / f"{stem}.png"
                if not (im_dest.exists() or gt_dest.exists()):
                    break
                print(f"[{stem}] SKIP — already done")
                skip += 1
                idx += 1

            print(f"[{stem}] {src_path.name}", end=" ", flush=True)
            if _process_image(src_path, stem, im_dest, gt_dest, remote_base, work_dir):
                ok += 1
            else:
                fail += 1
            # Consumed on failure too, so the remote file name uploaded for a
            # failed image is not reused for a different image in this run.
            idx += 1

    total = sum(1 for _ in im_dir.glob("*.jpg"))
    print(f"\n{'='*60}")
    print(f"OK         : {ok}")
    print(f"Skipped    : {skip}  (already done)")
    print(f"Failed     : {fail}")
    print(f"Dataset now: {total} pairs")
    print(f"{'='*60}")
    if ok > 0:
        print("\nNext step:")
        print(f"  python3 tools/internal/finetune-birefnet.py --dataset {args.dataset} --epochs 80 --size 512")


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
