#!/usr/bin/env python3
#
# @File: process-staging.py
# @Date: 2026-06-08
#
# Process every image in dataset/staging/ through ImgIX to build the fine-tuning dataset.
#
# Staging image prefixes and how each is handled:
#   server_XXXX  — already on server; path looked up from candidates/source/paths.txt
#   interior_XXXX — same as server_* (from candidates/source)
#   cars_*       — local only; uploaded to training-stock/ on server first
#   dealer_*     — local only; uploaded to training-stock/ on server first
#
# Output:
#   dataset/im/XXXX.jpg    — source image
#   dataset/gt/XXXX.png    — soft alpha mask (body≈255, windows≈120-160, bg=0)
#   dataset/rgba/XXXX.png  — full RGBA result as returned by ImgIX (when rebuilt from im/ + gt/: source pixels + gt alpha)
#
# Pairing is tracked in dataset/manifest.json, whose "stems" map records
# staging_filename → stem. A stem never changes once assigned, so adding new
# files to staging/ cannot shift existing image/mask pairs.
#
# Usage:
#   python3 tools/internal/process-staging.py
#   python3 tools/internal/process-staging.py --min-window 0.001   # relax filter (default 0.003)
#   python3 tools/internal/process-staging.py --dry-run            # show plan only

import argparse
import json
import shutil
import subprocess
import sys
import time
from io import BytesIO
from pathlib import Path
from urllib.parse import quote

import numpy as np
import requests
from PIL import Image

# Remote host holding the media web root. ImgIX URLs are built from paths
# relative to MEDIA_PATH (see fetch_imgix_rgba).
SSH_USER      = "root"
SSH_HOST      = "165.227.32.132"
SERVER_DOMAIN = "media.liftkit.click"
IMGIX_ENV     = "dev6"

# Derived locations.
SSH_SERVER    = f"{SSH_USER}@{SSH_HOST}"
MEDIA_PATH    = f"/var/www/{SERVER_DOMAIN}/public_html"   # web root on the server
STOCK_PATH    = f"{MEDIA_PATH}/training-stock"             # upload target for local-only images
IMGIX_DOMAIN  = f"liftkit-{IMGIX_ENV}.imgix.net"

# NOTE(review): StrictHostKeyChecking=no disables SSH host-key verification;
# "accept-new" would keep unattended runs working while still detecting a changed key.
SSH_OPTS      = ["-o", "StrictHostKeyChecking=no"] + ["-o", "ConnectTimeout=10"]
IMGIX_PARAMS  = "&".join(("bg-remove=true", "bg-remove-add-shadow=true", "fm=png"))
HEADERS       = {"User-Agent": "Mozilla/5.0"}

STAGING_DIR   = Path("dataset", "staging")
PATHS_TXT     = Path("candidates", "source", "paths.txt")


# ──────────────────────────────────────────────────────
# Manifest — maps staging_filename → assigned stem
# ──────────────────────────────────────────────────────
def load_manifest(manifest_path: Path) -> dict:
    """Read the stem manifest from *manifest_path*.

    When the file does not exist yet, a fresh version-2 manifest is returned
    with no stem assignments and ``next_stem`` starting at 1. Each call
    returns a new dict, so callers may mutate it freely.
    """
    if not manifest_path.exists():
        return {"version": 2, "stems": {}, "next_stem": 1}
    raw_json = manifest_path.read_text()
    return json.loads(raw_json)


def save_manifest(manifest: dict, manifest_path: Path):
    """Persist *manifest* to *manifest_path* as indented, key-sorted JSON.

    The JSON is written to a sibling ``<name>.tmp`` file first and then
    renamed over the target with ``Path.replace`` (os.replace semantics,
    atomic on POSIX filesystems). This function is called after every staging
    file, so a plain ``write_text`` interrupted mid-write could leave a
    truncated manifest.json. That would crash the next ``load_manifest``, and
    deleting the file to recover would reassign every stem.
    """
    tmp_path = manifest_path.with_name(manifest_path.name + ".tmp")
    tmp_path.write_text(json.dumps(manifest, indent=2, sort_keys=True))
    tmp_path.replace(manifest_path)


def assign_stem(staging_name: str, manifest: dict) -> str:
    """Return the stem recorded for *staging_name*, allocating one if it is new.

    A new stem is ``manifest["next_stem"]`` zero-padded to at least four
    digits. The counter is then incremented, so it only ever moves forward.
    The manifest dict is updated in place.
    """
    known = manifest["stems"]
    if staging_name not in known:
        known[staging_name] = format(manifest["next_stem"], "04d")
        manifest["next_stem"] += 1
    return known[staging_name]


# ──────────────────────────────────────────────────────
# Path lookup
# ──────────────────────────────────────────────────────
def load_paths_index() -> dict:
    """Parse PATHS_TXT into an ``{index: relative_path}`` mapping.

    Each line is whitespace-stripped and split on tabs. Only lines with
    exactly two fields are kept, and a later duplicate key overrides an
    earlier one. If PATHS_TXT is missing, an error is printed and the
    process exits with status 1.
    """
    if not PATHS_TXT.exists():
        print(f"ERROR: {PATHS_TXT} not found — run from bionicbobby/ root")
        sys.exit(1)
    with PATHS_TXT.open() as handle:
        rows = (raw.strip().split("\t") for raw in handle)
        return {row[0]: row[1] for row in rows if len(row) == 2}


def resolve_server_path(staging_name: str, paths_index: dict):
    """Map a ``server_*`` / ``interior_*`` staging file to its absolute server path.

    The key is the part of the filename after the first "_", with the file
    extension removed (any extension, any case). It is looked up in
    *paths_index*, and the hit is joined onto MEDIA_PATH.

    Returns:
        The absolute server path, or None when the name has another prefix
        or the index has no (non-empty) entry for it.
    """
    if not staging_name.startswith(("server_", "interior_")):
        return None
    # Path.stem instead of replace(".jpg", ""): the old form left ".jpeg",
    # ".png" or ".JPG" in the key (lookup miss, so the file was silently
    # treated as local and re-uploaded) and also removed ".jpg" mid-name.
    idx = Path(staging_name).stem.split("_", 1)[1]
    rel = paths_index.get(idx)
    if rel:
        return f"{MEDIA_PATH}/{rel}"
    return None


# ──────────────────────────────────────────────────────
# Upload helper (for cars_* and dealer_*)
# ──────────────────────────────────────────────────────
def upload_to_stock(local_path: Path, remote_name: str):
    """Copy *local_path* to ``STOCK_PATH/<remote_name>`` on the server with scp.

    Returns:
        The remote absolute path on success. Otherwise a short reason is
        printed and None is returned: a non-zero scp exit (its stderr is
        shown) or no completion within 60 seconds.
    """
    remote_path = f"{STOCK_PATH}/{remote_name}"
    command = ["scp", *SSH_OPTS, str(local_path), f"{SSH_SERVER}:{remote_path}"]
    try:
        completed = subprocess.run(command, capture_output=True, text=True, timeout=60)
    except subprocess.TimeoutExpired:
        print("    SCP upload timed out")
        return None
    if completed.returncode != 0:
        print(f"    SCP upload failed: {completed.stderr.strip()}")
        return None
    return remote_path


# ──────────────────────────────────────────────────────
# ImgIX — fetch full RGBA
# ──────────────────────────────────────────────────────
def _imgix_url(remote_path: str) -> str:
    """Build the ImgIX background-removal URL for an absolute server path.

    Only a leading MEDIA_PATH is stripped; the previous ``str.replace`` also
    removed later occurrences inside the path. The remaining relative path is
    percent-encoded (slashes kept), so filenames containing spaces, '#', '?'
    or '%' still yield a valid URL.
    """
    prefix = MEDIA_PATH.rstrip("/") + "/"
    if remote_path.startswith(prefix):
        rel = remote_path[len(prefix):]
    else:
        rel = remote_path.lstrip("/")
    return f"https://{IMGIX_DOMAIN}/{quote(rel)}?{IMGIX_PARAMS}"


def fetch_imgix_rgba(remote_path: str):
    """Fetch the background-removed RGBA image for *remote_path* from ImgIX.

    Args:
        remote_path: Absolute path on the server, under MEDIA_PATH.

    Returns:
        A PIL Image in RGBA mode, or None on a non-200 response, any
        request/decode error, or an all-zero alpha channel. The reason is
        printed in each case.
    """
    url = _imgix_url(remote_path)
    try:
        resp = requests.get(url, timeout=30, headers=HEADERS)
        if resp.status_code != 200:
            print(f"    ImgIX HTTP {resp.status_code}")
            return None
        img = Image.open(BytesIO(resp.content)).convert("RGBA")
        # An alpha channel that is 0 everywhere means nothing was kept.
        if np.array(img.split()[3]).max() == 0:
            print("    Empty mask")
            return None
        return img
    except Exception as e:  # per-image boundary: report and let the caller skip it
        print(f"    ImgIX error: {e}")
        return None


# ──────────────────────────────────────────────────────
# Window transparency filter
# ──────────────────────────────────────────────────────
def measure_window_transparency(alpha: np.ndarray) -> float:
    """Return the fraction of all pixels that are see-through yet enclosed.

    A pixel is see-through when its alpha is below 200. See-through regions
    connected (4-neighbourhood) to the image border are discarded as
    background. The remaining see-through pixels (presumably car windows)
    are counted and divided by ``alpha.size``.

    ``scipy.ndimage.label`` is used when scipy is importable; otherwise an
    equivalent pure-Python flood fill from the border is run.
    """
    holes = alpha < 200
    try:
        from scipy.ndimage import label as scipy_label
    except ImportError:
        scipy_label = None

    if scipy_label is not None:
        # Default structuring element in 2-D is the 4-connected cross.
        components, _ = scipy_label(holes)
        rim = np.concatenate((components[0, :], components[-1, :],
                              components[:, 0], components[:, -1]))
        edge_ids = np.unique(rim)
        edge_ids = edge_ids[edge_ids != 0]  # label 0 = not see-through
        enclosed = holes & ~np.isin(components, edge_ids)
    else:
        height, width = alpha.shape
        reached = np.zeros((height, width), dtype=bool)
        stack = []
        seeds = ([(0, c) for c in range(width)] + [(height - 1, c) for c in range(width)]
                 + [(r, 0) for r in range(height)] + [(r, width - 1) for r in range(height)])
        for r, c in seeds:
            if holes[r, c] and not reached[r, c]:
                reached[r, c] = True
                stack.append((r, c))
        # Depth-first flood fill: visit order differs from BFS, the reached set does not.
        while stack:
            r, c = stack.pop()
            for nr, nc in ((r - 1, c), (r + 1, c), (r, c - 1), (r, c + 1)):
                if 0 <= nr < height and 0 <= nc < width and holes[nr, nc] and not reached[nr, nc]:
                    reached[nr, nc] = True
                    stack.append((nr, nc))
        enclosed = holes & ~reached

    return np.count_nonzero(enclosed) / alpha.size


# ──────────────────────────────────────────────────────
# Main
# ──────────────────────────────────────────────────────
def main():
    """Build dataset/im, dataset/gt and dataset/rgba from every staging image.

    Staging files are processed in sorted filename order:
      * a stable stem is looked up in, or added to, the manifest;
      * if the im/, gt/ and rgba/ outputs all exist, the file is skipped;
      * if only rgba/ is missing, it is rebuilt locally from im/ + gt/;
      * otherwise the image is located on the server via paths.txt (or
        uploaded to training-stock/), run through ImgIX background removal,
        and kept only if its enclosed see-through fraction is >= --min-window.

    The manifest is saved after every file so an interrupted run keeps its
    stem assignments. With --dry-run nothing is uploaded, fetched or written.
    That includes the manifest, the output directories and rgba regeneration.
    """
    parser = argparse.ArgumentParser(description="Process staging → dataset/im + gt + rgba via ImgIX")
    parser.add_argument("--staging",    default=str(STAGING_DIR), help="Staging directory")
    parser.add_argument("--out",        default="dataset",         help="Output root")
    parser.add_argument("--min-window", type=float, default=0.003, help="Min interior window fraction")
    parser.add_argument("--dry-run",    action="store_true",       help="Print plan without fetching")
    args = parser.parse_args()

    staging_dir   = Path(args.staging)
    out_dir       = Path(args.out)
    im_dir        = out_dir / "im"
    gt_dir        = out_dir / "gt"
    rgba_dir      = out_dir / "rgba"
    manifest_path = out_dir / "manifest.json"

    if not staging_dir.is_dir():
        print(f"ERROR: {staging_dir} not found or not a directory")
        sys.exit(1)

    # A dry run must leave the filesystem untouched.
    if not args.dry_run:
        for out_subdir in (im_dir, gt_dir, rgba_dir):
            out_subdir.mkdir(parents=True, exist_ok=True)

    staging_files = sorted(f for f in staging_dir.iterdir() if f.suffix.lower() in ('.jpg', '.jpeg', '.png', '.avif'))
    paths_index   = load_paths_index()
    manifest      = load_manifest(manifest_path)
    total         = len(staging_files)

    def persist():
        # Save after every file so new stem assignments survive an interrupted
        # run; skipped in dry-run mode so the plan never changes disk state.
        if not args.dry_run:
            save_manifest(manifest, manifest_path)

    print(f"Staging     : {staging_dir}/ ({total} images)")
    print(f"Output      : {out_dir}/  (im/ + gt/ + rgba/)")
    print(f"Min windows : {args.min_window:.3f}")
    print(f"Manifest    : {len(manifest['stems'])} known, next_stem={manifest['next_stem']}")
    if args.dry_run:
        print("DRY RUN — no ImgIX calls\n")
    print()

    ok = filtered = failed = skipped = 0

    for i, staging_path in enumerate(staging_files, 1):
        name   = staging_path.name
        tag    = f"[{i:3d}/{total}]"
        suffix = staging_path.suffix.lower()

        # Stem comes from the manifest (stable across runs, not position-based).
        stem     = assign_stem(name, manifest)
        im_out   = im_dir   / f"{stem}.jpg"
        gt_out   = gt_dir   / f"{stem}.png"
        rgba_out = rgba_dir / f"{stem}.png"

        # Resume: all three outputs already exist.
        if im_out.exists() and gt_out.exists() and rgba_out.exists():
            print(f"{tag} {name} → skip (stem {stem}, already done)")
            ok += 1
            skipped += 1
            persist()
            continue

        # im/ + gt/ exist but rgba/ is missing: rebuild it from source RGB + gt alpha.
        if im_out.exists() and gt_out.exists():
            if args.dry_run:
                print(f"{tag} {name} → would regenerate rgba (stem {stem})")
            else:
                with Image.open(im_out) as im_file:
                    src = im_file.convert("RGB")
                with Image.open(gt_out) as gt_file:
                    alpha = gt_file.convert("L")
                if alpha.size != src.size:
                    alpha = alpha.resize(src.size, Image.LANCZOS)
                src.putalpha(alpha)
                src.save(rgba_out)
                print(f"{tag} {name} → rgba regenerated (stem {stem})")
            ok += 1
            skipped += 1
            persist()
            continue

        print(f"{tag} {name} (stem {stem})", end=" ", flush=True)

        # ── Resolve server path ───────────────────────────────────────────
        # server_/interior_ files listed in paths.txt are already on the
        # server. Anything else, including server_/interior_ files missing
        # from paths.txt, is uploaded to training-stock/ first.
        server_path = resolve_server_path(name, paths_index)

        if server_path is None:
            # Upload under the full staging filename (prefix kept) so every
            # staging file gets its own remote object. Stripping the prefix
            # made cars_X.jpg / dealer_X.jpg (and X.png / X.jpg) overwrite
            # each other, and names without "_" collapsed to the bare ".jpg".
            # ImgIX may then serve a cached render of a different image.
            remote_name = name if suffix in ('.jpg', '.jpeg') else f"{name}.jpg"

            if args.dry_run:
                print(f"→ would upload {remote_name} then fetch ImgIX")
                persist()
                continue

            print(f"[upload→{remote_name}]", end=" ", flush=True)

            # Non-JPEG sources are converted to a temporary JPEG for upload.
            # The finally clause removes it even if conversion or scp raises,
            # so no stray _tmp_*.jpg is left in im/.
            tmp_jpg = None
            try:
                if suffix in ('.avif', '.png'):
                    tmp_jpg = im_dir / f"_tmp_{stem}.jpg"
                    with Image.open(staging_path) as src_img:
                        src_img.convert("RGB").save(tmp_jpg, "JPEG", quality=95)
                upload_src = tmp_jpg if tmp_jpg is not None else staging_path
                server_path = upload_to_stock(upload_src, remote_name)
            finally:
                if tmp_jpg is not None:
                    tmp_jpg.unlink(missing_ok=True)

            if server_path is None:
                print("→ SKIP (upload failed)")
                failed += 1
                persist()
                continue
        elif args.dry_run:
            print(f"→ server path: {server_path}")
            persist()
            continue

        # ── Fetch full RGBA from ImgIX ────────────────────────────────────
        time.sleep(0.25)  # brief pause between consecutive ImgIX requests
        rgba_img = fetch_imgix_rgba(server_path)
        if rgba_img is None:
            print("→ SKIP (ImgIX failed)")
            failed += 1
            persist()
            continue

        # ── Extract alpha and apply window filter ─────────────────────────
        alpha_arr = np.array(rgba_img.split()[3])
        frac = measure_window_transparency(alpha_arr)
        if frac < args.min_window:
            print(f"→ filtered (window_frac={frac:.4f})")
            filtered += 1
            persist()
            continue

        # ── Save im/ (source RGB) ─────────────────────────────────────────
        if suffix in ('.avif', '.png'):
            with Image.open(staging_path) as src_img:
                src_img.convert("RGB").save(im_out, "JPEG", quality=95)
        else:
            shutil.copy2(staging_path, im_out)

        # ── Save gt/ (alpha mask, L mode) ─────────────────────────────────
        Image.fromarray(alpha_arr, mode="L").save(gt_out)

        # ── Save rgba/ ────────────────────────────────────────────────────
        # NOTE(review): this stores the RGBA exactly as returned by ImgIX (its
        # own RGB channels), whereas the regeneration path above combines the
        # im/ pixels with the gt/ alpha. Confirm which is intended.
        rgba_img.save(rgba_out)

        size_kb = im_out.stat().st_size // 1024
        print(f"→ OK  ({size_kb}KB, windows={frac:.2%}, stem={stem})")
        ok += 1
        persist()

    persist()

    print(f"\n{'='*55}")
    print(f"Dataset complete : {ok} qualifying pairs")
    print(f"  Filtered out   : {filtered}  (windows < {args.min_window:.3f})")
    print(f"  Failed         : {failed}  (upload/ImgIX errors)")
    print(f"  Resumed        : {skipped}  (already on disk)")
    print(f"  Images → {im_dir}/")
    print(f"  Masks  → {gt_dir}/")
    print(f"  RGBA   → {rgba_dir}/")
    if ok > 0 and not args.dry_run:
        print("\nNext step:")
        print(f"  PYTORCH_ENABLE_MPS_FALLBACK=1 python3 tools/internal/finetune-birefnet.py --dataset {out_dir}")
    print("=" * 55)


# Script entry point: run main() only when executed directly, not on import.
if __name__ == "__main__":
    main()
