#!/usr/bin/env python3
#
# @File: prepare-training-data.py
# @Date: 2026-06-01
#
# Build a BiRefNet fine-tuning dataset using ImgIX automotive masks as ground truth.
#
# Strategy:
#   1. SSH into media server → find one large JPEG per unique vehicle
#   2. SCP images locally
#   3. Call ImgIX with automotive bg-remove (which handles windows correctly)
#   4. FILTER: keep only images where background is visible through windows
#      (connected-component analysis on mask — interior transparent holes = windows)
#   5. Save qualifying pairs to dataset/im/ and dataset/gt/
#
# Why the filter matters:
#   If a photo shows no background through the windows (tinted glass, dark interior,
#   direct sunlight making glass opaque), BiRefNet and ImgIX produce nearly identical
#   masks. There is nothing for the model to learn from those images.
#   We only keep images where ImgIX actually punches transparent holes through the
#   glass — those are the teaching examples.
#
# Install:
#   pip install requests Pillow numpy scipy
#
# Usage:
#   python3 tools/internal/prepare-training-data.py --test          ← verify connectivity first
#   python3 tools/internal/prepare-training-data.py --count 100
#   python3 tools/internal/prepare-training-data.py --count 100 --out /path/to/dataset

import sys
import time
import argparse
import subprocess
import requests
from pathlib import Path
from io import BytesIO
from PIL import Image
import numpy as np

# Connection targets used by every SSH/SCP/ImgIX call in this script.
SSH_SERVER   = "root@165.227.32.132"                       # user@host passed to ssh/scp
MEDIA_PATH   = "/var/www/media.liftkit.click/public_html"  # remote root searched for *-large.jpg
IMGIX_DOMAIN = "liftkit-dev6.imgix.net"                    # host used to build ImgIX URLs
# Extra options prepended to every ssh/scp invocation.
# NOTE(review): StrictHostKeyChecking=no accepts unknown host keys without
# prompting, so a spoofed host would go unnoticed — consider pinning the key
# in known_hosts instead.
SSH_OPTS     = ["-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=10"]

# bg-remove-add-shadow triggers ImgIX's automotive model (handles windshield/windows)
# Query string appended verbatim to each ImgIX URL; fm=png presumably keeps the
# alpha channel that the mask is read from — confirm against the ImgIX docs.
IMGIX_PARAMS = "bg-remove=true&bg-remove-add-shadow=true&fm=png"
HEADERS      = {"User-Agent": "Mozilla/5.0"}               # sent with every ImgIX request


# ──────────────────────────────────────────────────────
# SSH / download helpers
# ──────────────────────────────────────────────────────
def find_remote_images(count: int, timeout: int = 30) -> list:
    """
    Find one large JPEG per unique vehicle on the media server (shuffled for diversity).

    Runs a single shell pipeline over SSH: list every directory that contains a
    ``*-large.jpg``, keep the first such file per directory, shuffle, and take
    the first ``count`` paths.

    Args:
        count:   Maximum number of remote paths to return.
        timeout: Seconds to wait for the remote pipeline before giving up.

    Returns:
        List of absolute remote file paths (at most ``count`` entries).

    Exits the process with status 1 if SSH fails, times out, or finds nothing.
    """
    # `IFS= read -r` keeps backslashes and leading/trailing blanks in directory
    # names intact; a bare `read` would mangle such paths.
    find_cmd = (
        f'find {MEDIA_PATH} -name "*-large.jpg" -printf "%h\\n" | '
        f'sort -u | '
        f'while IFS= read -r dir; do ls "$dir"/*-large.jpg 2>/dev/null | head -1; done | '
        f'shuf | head -{count}'
    )
    try:
        result = subprocess.run(
            ["ssh"] + SSH_OPTS + [SSH_SERVER, find_cmd],
            capture_output=True, text=True, timeout=timeout
        )
    except subprocess.TimeoutExpired:
        # Walking a large media tree can exceed the timeout; report it instead
        # of crashing with a traceback.
        print(f"  SSH error: remote find timed out after {timeout}s")
        sys.exit(1)

    if result.returncode != 0:
        print(f"  SSH error: {result.stderr.strip()}")
        sys.exit(1)

    paths = [line.strip() for line in result.stdout.splitlines() if line.strip()]
    if not paths:
        # The pipeline succeeded but produced no output — not an SSH failure.
        detail = result.stderr.strip()
        print(f"  SSH error: no *-large.jpg images found under {MEDIA_PATH}"
              + (f" ({detail})" if detail else ""))
        sys.exit(1)

    print(f"  Found {len(paths)} vehicle images on server")
    return paths


def download_image(remote_path: str, dest: Path, min_bytes: int = 5_000) -> bool:
    """
    Copy one file from the media server to ``dest`` via scp.

    Args:
        remote_path: Absolute path of the file on the server.
        dest:        Local path to write to.
        min_bytes:   The download only counts as successful if the local file
                     is larger than this many bytes (guards against truncated
                     or placeholder files).

    Returns:
        True if scp exited cleanly and the file is larger than ``min_bytes``.
        On any failure (non-zero exit, timeout, missing or undersized file)
        the partial ``dest`` is removed and False is returned, so callers never
        have to clean up after a failed download.
    """
    try:
        result = subprocess.run(
            ["scp"] + SSH_OPTS + [f"{SSH_SERVER}:{remote_path}", str(dest)],
            capture_output=True, text=True, timeout=60
        )
        succeeded = (
            result.returncode == 0
            and dest.exists()
            and dest.stat().st_size > min_bytes
        )
    except subprocess.TimeoutExpired:
        succeeded = False

    if not succeeded:
        # Never leave a partial/undersized file behind in the dataset folder.
        dest.unlink(missing_ok=True)
    return succeeded


# ──────────────────────────────────────────────────────
# ImgIX mask fetching
# ──────────────────────────────────────────────────────
def fetch_imgix_mask(remote_path: str):
    """
    Call ImgIX with automotive bg-remove and return the alpha channel as the mask.

    The server-side path is mapped to an ImgIX URL by stripping the leading
    ``MEDIA_PATH`` and appending ``IMGIX_PARAMS``.

    Args:
        remote_path: Absolute path of the image on the media server.

    Returns:
        2-D uint8 numpy array holding the *soft* alpha of the ImgIX result
        (not a binary mask), or None on any failure (HTTP status other than
        200, network/decoding error, or a fully transparent result).
    """
    from urllib.parse import quote

    # Strip the media root only when it is a real prefix; str.replace would
    # also remove a matching substring further down the path.
    root = MEDIA_PATH.rstrip("/") + "/"
    if remote_path.startswith(root):
        rel = remote_path[len(root):]
    else:
        rel = remote_path.lstrip("/")

    # Percent-encode the path so characters such as '?', '#' or '%' in a
    # filename cannot be mistaken for URL syntax.
    url = f"https://{IMGIX_DOMAIN}/{quote(rel, safe='/')}?{IMGIX_PARAMS}"

    try:
        resp = requests.get(url, timeout=30, headers=HEADERS)
        if resp.status_code != 200:
            print(f"    ImgIX HTTP {resp.status_code}")
            return None
        img = Image.open(BytesIO(resp.content)).convert("RGBA")
        # Save soft alpha directly — body≈255, windows≈120-160, background=0.
        # The model learns to predict these exact values, not a binary mask.
        # Binary thresholding destroys the semi-transparent window signal.
        # NOTE(review): bg-remove-add-shadow may also contribute semi-transparent
        # shadow pixels to this alpha — confirm that is wanted in the ground truth.
        mask = np.array(img.split()[3], dtype=np.uint8)
        if mask.max() == 0:
            print("    Empty mask")
            return None
        return mask
    except Exception as e:
        # Deliberately broad: one bad image (network error, undecodable PNG, ...)
        # must not abort the whole dataset build.
        print(f"    ImgIX error: {e}")
        return None


# ──────────────────────────────────────────────────────
# Window transparency filter  ← the key quality gate
# ──────────────────────────────────────────────────────
def _border_connected_bfs(transparent: np.ndarray) -> np.ndarray:
    """Return a boolean array marking transparent pixels reachable from the image border (4-connectivity)."""
    from collections import deque

    h, w = transparent.shape
    visited = np.zeros((h, w), dtype=bool)
    queue = deque()

    def seed(y, x):
        if transparent[y, x] and not visited[y, x]:
            visited[y, x] = True
            queue.append((y, x))

    # Seed the flood fill with every transparent pixel on the four edges.
    for x in range(w):
        seed(0, x)
        seed(h - 1, x)
    for y in range(h):
        seed(y, 0)
        seed(y, w - 1)

    while queue:
        y, x = queue.popleft()
        for dy, dx in ((-1, 0), (1, 0), (0, -1), (0, 1)):
            ny, nx = y + dy, x + dx
            if 0 <= ny < h and 0 <= nx < w and transparent[ny, nx] and not visited[ny, nx]:
                visited[ny, nx] = True
                queue.append((ny, nx))

    return visited


def measure_window_transparency(mask: np.ndarray, min_fraction: float = 0.005,
                                alpha_threshold: int = 200) -> float:
    """
    Detect interior transparent regions in the ImgIX mask — these are windows.

    A pixel counts as transparent when its alpha is below ``alpha_threshold``.
    Transparent pixels fall into two groups:
      a) Background — connected to the image border (outside the car)
      b) Windows    — NOT connected to the border, surrounded by solid car body

    Connected-component analysis (4-connectivity) separates (a) from (b).

    Uses scipy if available (fast); falls back to a pure-Python BFS (slower but works).

    Args:
        mask:            2-D alpha array as returned by fetch_imgix_mask.
        min_fraction:    Unused here; kept so existing callers that pass it keep
                         working. The comparison against it is done by the caller.
        alpha_threshold: Alpha values strictly below this are treated as
                         transparent. The default of 200 catches soft-alpha
                         windows (120-160) as well as background (0).

    Returns:
        Fraction of all image pixels that are interior (window) pixels, in [0, 1].
        An empty mask yields 0.0.
    """
    if mask.size == 0:
        # Nothing to analyse; also avoids dividing by zero below.
        return 0.0

    transparent = mask < alpha_threshold

    # Only the import is optional — keep the try body minimal so a genuine
    # ImportError raised deeper in the analysis is not silently re-routed.
    try:
        from scipy.ndimage import label
    except ImportError:
        label = None

    if label is not None:
        labeled, n_components = label(transparent)
        # Components that touch the image border = background
        border_labels = np.unique(np.concatenate([
            labeled[0, :], labeled[-1, :], labeled[:, 0], labeled[:, -1]
        ]))
        is_background = np.zeros(n_components + 1, dtype=bool)
        is_background[border_labels] = True
        is_background[0] = False  # label 0 is "not transparent", never background
        background = is_background[labeled]
    else:
        background = _border_connected_bfs(transparent)

    # Interior = transparent and NOT part of a border-connected component
    interior = transparent & ~background
    return int(interior.sum()) / mask.size


def has_visible_windows(mask: np.ndarray, min_fraction: float = 0.005) -> tuple:
    """
    Decide whether a mask shows enough see-through window area to be kept.

    The interior-window fraction is measured with measure_window_transparency
    and compared against ``min_fraction`` (the default 0.005 means windows must
    cover at least 0.5% of the image).

    Returns:
        (passes, window_fraction) — the verdict and the measured fraction.
    """
    window_fraction = measure_window_transparency(mask, min_fraction)
    passes = window_fraction >= min_fraction
    return passes, window_fraction


# ──────────────────────────────────────────────────────
# Connectivity test
# ──────────────────────────────────────────────────────
def test_connectivity():
    """
    Smoke-test the pipeline end to end and print the outcome of each stage.

    1. Runs ``echo ok`` over SSH.
    2. Picks up to 5 random ``*-large.jpg`` files on the server, fetches their
       ImgIX masks and reports whether each would pass the window filter.

    Timeouts are reported as failures instead of raising, and the ImgIX stage
    is skipped when SSH itself is unusable (it needs SSH to list images).
    """
    print("Testing SSH ...")
    try:
        r = subprocess.run(["ssh"] + SSH_OPTS + [SSH_SERVER, "echo ok"],
                           capture_output=True, text=True, timeout=15)
    except subprocess.TimeoutExpired:
        print("  SSH: FAILED — timed out")
        return
    if r.returncode != 0:
        print(f"  SSH: FAILED — {r.stderr.strip()}")
        return
    print("  SSH: OK")

    print("Testing ImgIX + window filter ...")
    find_cmd = f'find {MEDIA_PATH} -name "*-large.jpg" | shuf | head -5'
    try:
        r = subprocess.run(["ssh"] + SSH_OPTS + [SSH_SERVER, find_cmd],
                           capture_output=True, text=True, timeout=15)
    except subprocess.TimeoutExpired:
        print("  Could not find test images (remote find timed out)")
        return
    if r.returncode != 0 or not r.stdout.strip():
        print("  Could not find test images")
        return

    for line in r.stdout.splitlines():
        remote_path = line.strip()
        if not remote_path:
            continue  # ignore blank lines in the remote output
        mask = fetch_imgix_mask(remote_path)
        if mask is None:
            print(f"  ImgIX: FAILED for {Path(remote_path).name}")
            continue
        passes, frac = has_visible_windows(mask)
        status = "PASS" if passes else "filtered-out"
        print(f"  {Path(remote_path).name}: window_frac={frac:.4f} → {status}")


# ──────────────────────────────────────────────────────
# Main
# ──────────────────────────────────────────────────────
def _pair_paths(im_dir: Path, gt_dir: Path, index: int) -> tuple:
    """Return the (image, mask) output paths for the 1-based pair number *index*."""
    stem = f"{index:04d}"
    return im_dir / f"{stem}.jpg", gt_dir / f"{stem}.png"


def _count_existing_pairs(im_dir: Path, gt_dir: Path, done: int, target: int) -> int:
    """Count consecutive complete pairs on disk that directly follow pair number *done*, capped at *target*."""
    found = 0
    while done + found < target:
        im_path, gt_path = _pair_paths(im_dir, gt_dir, done + found + 1)
        if not (im_path.exists() and gt_path.exists()):
            break
        found += 1
    return found


def main():
    """
    Command-line entry point: build the image/mask dataset.

    Parses the CLI options, then for each candidate image on the server:
    downloads it, fetches its ImgIX mask, applies the window filter and, if it
    qualifies, stores the pair as ``<out>/im/NNNN.jpg`` + ``<out>/gt/NNNN.png``.
    Pairs already on disk are counted as done (resume support). A summary is
    printed at the end.
    """
    parser = argparse.ArgumentParser(description="Prepare BiRefNet fine-tuning dataset")
    parser.add_argument("--count",      type=int,   default=100,   help="Target image+mask pairs (default: 100)")
    parser.add_argument("--out",        default="dataset",         help="Output directory (default: dataset/)")
    parser.add_argument("--min-window", type=float, default=0.005, help="Min window fraction to keep image (default: 0.005)")
    parser.add_argument("--test",       action="store_true",       help="Test SSH + ImgIX + filter on 5 images")
    args = parser.parse_args()

    if args.test:
        test_connectivity()
        return

    if args.count < 1:
        # A non-positive count would otherwise produce a broken `head -N` on the server.
        parser.error("--count must be a positive integer")

    out_dir = Path(args.out)
    im_dir  = out_dir / "im"
    gt_dir  = out_dir / "gt"
    im_dir.mkdir(parents=True, exist_ok=True)
    gt_dir.mkdir(parents=True, exist_ok=True)

    print(f"Target     : {args.count} qualifying image+mask pairs")
    print(f"Output     : {out_dir}/")
    print(f"Server     : {SSH_SERVER}:{MEDIA_PATH}")
    print(f"ImgIX      : {IMGIX_DOMAIN}")
    print(f"Min windows: {args.min_window:.3f} (interior transparent fraction)\n")

    ok = fail = filtered = skip = 0

    # Resume support: count pairs that are already complete *before* asking the
    # server for candidates, so finished pairs do not each burn a fresh
    # candidate and a fully built dataset needs no SSH round trip at all.
    # NOTE: candidates are re-shuffled on every run and outputs are named by
    # index, so a resumed run cannot tell whether a vehicle is already present.
    resumed = _count_existing_pairs(im_dir, gt_dir, ok, args.count)
    ok += resumed
    skip += resumed

    remote_paths = []
    if ok < args.count:
        # Fetch 4x target to account for download failures + window filter rejections
        candidate_count = args.count * 4
        remote_paths = find_remote_images(candidate_count)
        print()

    for remote_path in remote_paths:
        # A previous run may have left gaps in the numbering; step over any
        # pair that is already complete without consuming this candidate.
        resumed = _count_existing_pairs(im_dir, gt_dir, ok, args.count)
        ok += resumed
        skip += resumed
        if ok >= args.count:
            break

        im_path, gt_path = _pair_paths(im_dir, gt_dir, ok + 1)

        name = Path(remote_path).name
        print(f"[{ok+1:3d}/{args.count}] {name}", end=" ", flush=True)

        # Step 1: download original JPEG
        tmp_path = im_dir / f"_tmp_{name}"
        if not download_image(remote_path, tmp_path):
            # Make sure a partial download never lingers next to the real images.
            tmp_path.unlink(missing_ok=True)
            print("→ SKIP (download failed)")
            fail += 1
            continue

        # Step 2: get ImgIX mask
        time.sleep(0.25)  # small pause between ImgIX requests
        mask = fetch_imgix_mask(remote_path)
        if mask is None:
            tmp_path.unlink(missing_ok=True)
            print("→ SKIP (ImgIX failed)")
            fail += 1
            continue

        # Step 3: window quality filter — discard images with no visible background through glass
        passes, frac = has_visible_windows(mask, min_fraction=args.min_window)
        if not passes:
            tmp_path.unlink(missing_ok=True)
            print(f"→ filtered (window_frac={frac:.4f} < {args.min_window})")
            filtered += 1
            continue

        # Step 4: save the pair. Write the mask first and move the image into
        # place last: the move is the "commit", so an interrupted mask write can
        # never look like a finished pair on resume. replace() (unlike rename())
        # also overwrites a stale target on every platform.
        Image.fromarray(mask, mode="L").save(gt_path)
        tmp_path.replace(im_path)
        size_kb = im_path.stat().st_size // 1024
        print(f"→ OK  ({size_kb} KB, window_frac={frac:.4f})")
        ok += 1

    # Pairs that directly follow the last one written may already exist.
    resumed = _count_existing_pairs(im_dir, gt_dir, ok, args.count)
    ok += resumed
    skip += resumed

    print(f"\n{'='*55}")
    print(f"Dataset complete : {ok} qualifying pairs saved")
    print(f"  Filtered out   : {filtered}  (no visible background through windows)")
    print(f"  Failed         : {fail}  (download/ImgIX errors)")
    print(f"  Resumed        : {skip}  (already on disk)")
    print(f"  Images → {im_dir}/")
    print(f"  Masks  → {gt_dir}/")
    if ok < args.count:
        print(f"  WARNING        : only {ok}/{args.count} pairs — ran out of candidates, re-run to continue")
    if ok > 0:
        print("\nNext step:")
        print(f"  python3 tools/internal/finetune-birefnet.py --dataset {out_dir}")
    print("=" * 55)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
