#!/usr/bin/env python3
#
# @File: rebinarize-masks.py
# @Date: 2026-06-05
#
# Re-fetch ImgIX masks for the existing dataset at threshold=200 instead of 127.
#
# Problem: The original prepare-training-data.py binarized at threshold=127.
# Window pixels have ImgIX alpha ~120-160, so:
#   - alpha 120-127 → 0 (transparent)   ← correct
#   - alpha 128-160 → 255 (opaque/body) ← WRONG — model learns to treat glass as body
#
# Fix: threshold=200 puts ALL window pixels (120-160) below threshold → labeled transparent.
# Model learns to predict 0 for glass → post-processing fills holes with dark tinted glass.
#
# Strategy:
#   1. SCP each dataset/im/*.jpg image to a temp dir on the media server
#   2. Call ImgIX on it and capture RGBA PNG (preserving soft alpha)
#   3. Binarize at threshold=200: alpha>=200 → 255 (body), else 0 (glass/bg)
#   4. Save the new GT mask as dataset/gt/<image stem>.png
#   5. Remove temp files from server
#
# Usage:
#   cd /path/to/bionicbobby
#   python3 tools/internal/rebinarize-masks.py --dataset dataset
#   python3 tools/internal/rebinarize-masks.py --dataset dataset --dry-run
#
# Install:
#   pip install requests Pillow numpy

import argparse
import shlex
import subprocess
import sys
import time
from io import BytesIO
from pathlib import Path
from urllib.parse import quote

import numpy as np
import requests
from PIL import Image

# --- Media server: images are staged here temporarily so ImgIX can fetch them ---
SSH_SERVER = "root@165.227.32.132"
MEDIA_PATH = "/var/www/media.liftkit.click/public_html"
REMOTE_TMP = MEDIA_PATH + "/tmp-rebinarize"
# Extra "-o" options passed to every scp/ssh call.
# NOTE(review): StrictHostKeyChecking=no disables host-key verification
# (MITM risk); consider "accept-new" once known_hosts is known to be clean.
SSH_OPTS = [
    token
    for option in ("StrictHostKeyChecking=no", "ConnectTimeout=10")
    for token in ("-o", option)
]

# --- ImgIX ---
IMGIX_DOMAIN = "liftkit-dev6.imgix.net"
# Query string: background removal, add-shadow option, PNG output so the
# soft alpha channel is returned intact.
IMGIX_PARAMS = "&".join(("bg-remove=true", "bg-remove-add-shadow=true", "fm=png"))
HEADERS = {"User-Agent": "Mozilla/5.0"}

# Binarization cutoff: alpha >= 200 is body; window pixels (alpha ~120-160)
# fall below it and are labeled transparent.
ALPHA_THRESHOLD = 200


def upload_image(local_path: Path, remote_name: str) -> bool:
    """Copy ``local_path`` to ``REMOTE_TMP/<remote_name>`` on the media server via scp.

    Returns True when scp exits with status 0. On failure, prints the reason
    (scp's stderr, or the local error such as a timeout or a missing scp
    binary) and returns False, so the caller can count the image as failed
    and continue with the batch.
    """
    destination = f"{SSH_SERVER}:{REMOTE_TMP}/{remote_name}"
    try:
        result = subprocess.run(
            ["scp", *SSH_OPTS, str(local_path), destination],
            capture_output=True, text=True, timeout=60,
        )
    except Exception as e:  # per-image boundary: report and let the batch continue
        print(f"    scp error: {e}")
        return False
    if result.returncode != 0:
        # stderr was captured; surface it instead of discarding the reason.
        print(f"    scp error: {result.stderr.strip()}")
        return False
    return True


def fetch_imgix_mask(remote_name: str, threshold: "int | None" = None):
    """
    Fetch the ImgIX background-removed PNG for an uploaded image and binarize its alpha.

    The image must already be on the media server as REMOTE_TMP/<remote_name>.
    It is requested as https://IMGIX_DOMAIN/tmp-rebinarize/<remote_name>, which
    assumes the ImgIX source serves MEDIA_PATH (TODO confirm against the ImgIX
    source configuration).

    Args:
        remote_name: File name under the server temp dir.
        threshold: Alpha cutoff. alpha >= threshold becomes 255 (body), and
            everything below becomes 0 (glass + background). None (the default)
            means ALPHA_THRESHOLD.

    Returns:
        2-D uint8 mask (255=body, 0=glass+background), or None on failure:
        network/decoding error, non-200 response, a response with no
        transparent pixels (background removal evidently not applied), or
        an all-zero mask.
    """
    if threshold is None:
        threshold = ALPHA_THRESHOLD

    # Percent-encode the name: a space, '#', '?' or '%' would otherwise
    # corrupt the URL path. Names made of URL-safe characters are unchanged.
    url = f"https://{IMGIX_DOMAIN}/tmp-rebinarize/{quote(remote_name)}?{IMGIX_PARAMS}"
    try:
        resp = requests.get(url, timeout=60, headers=HEADERS)
        if resp.status_code != 200:
            print(f"    ImgIX HTTP {resp.status_code}")
            return None
        with Image.open(BytesIO(resp.content)) as img:
            # convert("RGBA") synthesizes alpha=255 for images without alpha.
            alpha = np.asarray(img.convert("RGBA").getchannel("A"), dtype=np.uint8)
    except Exception as e:  # per-image boundary: report and let the batch continue
        print(f"    ImgIX error: {e}")
        return None

    if alpha.size and alpha.min() == 255:
        # A fully opaque result means the background was not removed.
        # Binarizing it would write an all-body mask over the existing GT.
        print("    ImgIX error: response has no transparent pixels")
        return None

    mask = (alpha >= threshold).astype(np.uint8) * 255
    if not mask.any():
        return None
    return mask


def delete_remote(remote_name: str) -> None:
    """Best-effort removal of ``REMOTE_TMP/<remote_name>`` on the media server.

    Failures (non-zero exit, timeout, ssh not installed) are deliberately
    ignored: a leftover temp file must not abort the batch.
    """
    # ssh hands this command string to the remote shell (running as root).
    # An unquoted name containing a space or ';' would make `rm -f` delete
    # unrelated paths, so quote the full path.
    remote_path = shlex.quote(f"{REMOTE_TMP}/{remote_name}")
    try:
        subprocess.run(
            ["ssh", *SSH_OPTS, SSH_SERVER, f"rm -f {remote_path}"],
            capture_output=True, timeout=15,
        )
    except (OSError, subprocess.SubprocessError):
        pass


def setup_remote_tmp() -> bool:
    """Create REMOTE_TMP on the media server with ``ssh ... mkdir -p``.

    Returns True when the remote command exits with status 0. Returns False
    on a non-zero exit, a timeout, or any local error launching ssh.
    """
    try:
        command = ["ssh", *SSH_OPTS, SSH_SERVER, f"mkdir -p {REMOTE_TMP}"]
        completed = subprocess.run(command, capture_output=True, text=True, timeout=15)
    except Exception:
        return False
    return not completed.returncode


def _parse_args() -> argparse.Namespace:
    """Parse the command line (flags and defaults unchanged)."""
    parser = argparse.ArgumentParser(
        description="Re-fetch GT masks at threshold=200 (windows become transparent)"
    )
    parser.add_argument("--dataset",   default="dataset", help="Dataset root (default: dataset/)")
    parser.add_argument("--dry-run",   action="store_true", help="Print what would happen, don't call ImgIX")
    # --overwrite is already the default and exists only for explicitness;
    # --no-overwrite flips the same `overwrite` destination to False.
    parser.add_argument("--overwrite", action="store_true", default=True,
                        help="Overwrite existing GT masks (default: True)")
    parser.add_argument("--no-overwrite", dest="overwrite", action="store_false",
                        help="Skip images whose GT mask already exists")
    return parser.parse_args()


def _change_summary(gt_path: Path, mask) -> str:
    """Describe how ``mask`` differs from the GT mask currently at ``gt_path``."""
    if not gt_path.exists():
        return "new"
    try:
        with Image.open(gt_path) as old_img:
            old_mask = np.array(old_img.convert("L"))
    except Exception as e:  # informational only; the old mask is about to be replaced
        return f"replaced unreadable mask: {e}"
    if old_mask.shape != mask.shape:
        return f"replaced; old shape {old_mask.shape} != new {mask.shape}"
    changed = int(np.count_nonzero(old_mask != mask))
    return f"changed {changed / mask.size:.1%} of pixels"


def _process_image(img_path: Path, gt_path: Path) -> bool:
    """Upload one image, fetch its re-binarized ImgIX mask, and save it to ``gt_path``.

    The caller has already printed the "[i/N] name" prefix without a newline;
    this prints the outcome. The uploaded temp file is removed from the server
    even if fetching raises or the run is interrupted. Returns True on success.
    """
    remote_name = f"rb_{img_path.name}"
    if not upload_image(img_path, remote_name):
        print("→ FAIL (upload)")
        return False

    try:
        # ImgIX needs a moment to see the new file
        time.sleep(0.5)
        mask = fetch_imgix_mask(remote_name)
    finally:
        # Clean up the server immediately, whatever happened above
        delete_remote(remote_name)

    if mask is None:
        print("→ FAIL (ImgIX)")
        return False

    summary = _change_summary(gt_path, mask)
    try:
        # A 2-D uint8 array maps to mode "L" automatically; passing mode=
        # explicitly is deprecated in recent Pillow releases.
        Image.fromarray(mask).save(gt_path)
    except OSError as e:
        print(f"→ FAIL (write: {e})")
        return False
    print(f"→ OK  ({summary})")
    return True


def _remove_remote_tmp() -> None:
    """Best-effort rmdir of REMOTE_TMP; never raises (a leftover empty dir is harmless)."""
    try:
        subprocess.run(
            ["ssh", *SSH_OPTS, SSH_SERVER, f"rmdir {REMOTE_TMP} 2>/dev/null; true"],
            capture_output=True, timeout=15,
        )
    except (OSError, subprocess.SubprocessError):
        pass


def main():
    """Re-fetch and re-binarize every dataset/im/*.jpg into dataset/gt/<stem>.png.

    Exits with status 1 if no images are found or the server temp dir cannot
    be created. ``--dry-run`` only lists what would happen and touches nothing.
    """
    args = _parse_args()

    im_dir = Path(args.dataset) / "im"
    gt_dir = Path(args.dataset) / "gt"

    images = sorted(im_dir.glob("*.jpg"))
    if not images:
        print(f"No images found in {im_dir}")
        sys.exit(1)

    print(f"Dataset  : {args.dataset}/")
    print(f"Images   : {len(images)}")
    print(f"Threshold: {ALPHA_THRESHOLD} (was 127 — windows 120-160 now become transparent)")
    print(f"Dry run  : {args.dry_run}\n")

    if args.dry_run:
        for img in images:
            gt = gt_dir / f"{img.stem}.png"
            if not gt.exists():
                status = "new"
            elif args.overwrite:
                status = "overwrite"
            else:
                status = "skip"  # --no-overwrite leaves existing masks alone
            print(f"  {img.name}  → {gt.name}  [{status}]")
        return

    # Only a real run touches the local filesystem.
    gt_dir.mkdir(parents=True, exist_ok=True)

    # Create temp dir on server
    if not setup_remote_tmp():
        print("ERROR: Could not create temp dir on media server")
        sys.exit(1)

    ok = skip = fail = 0
    try:
        for i, img_path in enumerate(images, 1):
            gt_path = gt_dir / f"{img_path.stem}.png"

            if not args.overwrite and gt_path.exists():
                skip += 1
                continue

            print(f"[{i:3d}/{len(images)}] {img_path.name}", end=" ", flush=True)
            if _process_image(img_path, gt_path):
                ok += 1
                time.sleep(0.2)  # throttle successive ImgIX requests
            else:
                fail += 1
    finally:
        # Runs on normal completion, on errors, and on Ctrl-C.
        _remove_remote_tmp()

    print(f"\n{'='*55}")
    print(f"Done: {ok} masks updated, {skip} skipped, {fail} failed")
    print(f"GT masks → {gt_dir}/")
    print("\nNext: retrain from scratch with corrected masks:")
    print(f"  python3 tools/internal/finetune-birefnet.py --dataset {args.dataset} --epochs 60")
    print("=" * 55)


if __name__ == "__main__":
    main()
