#!/usr/bin/env python3
#
# @File: fetch-imgix-masks.py
# @Date: 2026-06-01
#
# Step 3 of 3: Call ImgIX on the filtered source images to get ground-truth masks.
# Run this ONLY after filter-candidates.py has confirmed your source images
# show visible outdoor background through the windows.
#
# Each ImgIX call uses credits — this script is intentionally the last step.
#
# Output layout ready for finetune-birefnet.py:
#   dataset/
#     im/   ← original car JPEGs (copied from filtered source)
#     gt/   ← soft alpha masks, full 0-255 range (255=car body, 0=background;
#             semi-transparent windows get intermediate values, not binarised)
#
# Usage:
#   python3 tools/internal/fetch-imgix-masks.py --source candidates/filtered
#   python3 tools/internal/fetch-imgix-masks.py --source candidates/filtered --out dataset

import argparse
import shutil
import sys
import time
from io import BytesIO
from pathlib import Path
from urllib.parse import quote

import numpy as np
import requests
from PIL import Image

# Media server SSH target and its web root. Neither is referenced in this
# script; presumably shared with the other candidate tools — TODO confirm
# before removing. fetch_mask documents its paths as relative to "the media
# root", which is presumably MEDIA_PATH.
SSH_SERVER   = "root@165.227.32.132"
MEDIA_PATH   = "/var/www/media.liftkit.click/public_html"
# ImgIX source domain; fetch_mask requests https://<IMGIX_DOMAIN>/<path>?<IMGIX_PARAMS>.
IMGIX_DOMAIN = "liftkit-dev6.imgix.net"
# Query string: background removal with an added shadow, returned as PNG so
# the alpha channel read by fetch_mask survives.
# NOTE(review): bg-remove-add-shadow may add semi-transparent shadow pixels to
# the alpha channel, which would then end up in the ground-truth mask —
# confirm this is intended.
IMGIX_PARAMS = "bg-remove=true&bg-remove-add-shadow=true&fm=png"
# Browser-like User-Agent sent with every ImgIX request.
HEADERS      = {"User-Agent": "Mozilla/5.0"}


def filename_to_remote_path(filename: str, path_map: "dict[str, str] | None" = None) -> str:
    """
    Map a downloaded image filename (e.g. '0001.jpg') to its server-relative path.

    Source images are saved as {idx:04d}.jpg, but the ImgIX URL needs the
    original server path. That mapping lives in candidates/source/paths.txt
    (written by download-candidates.py) and is keyed by filename stem.

    filename: image filename or path; only its stem is used for the lookup.
    path_map: {stem: server-relative path}. When omitted or empty no mapping
        is available and "" is returned (the previous stub behaviour).

    Returns the server-relative path, or "" when the image has no mapping and
    therefore cannot be processed.
    """
    if not path_map:
        return ""
    return path_map.get(Path(filename).stem, "")


def fetch_mask(remote_rel_path: str) -> "np.ndarray | None":
    """
    Fetch a background-removed PNG from ImgIX and return its alpha channel.

    remote_rel_path: path relative to the media root, e.g.
        'bc1172/1467313/file-large.jpg'. A leading '/' is dropped and the
        path is percent-encoded (keeping '/'), so characters such as spaces,
        '#', '?' or '&' stay part of the path instead of truncating it or
        leaking into the query string.

    Returns the alpha channel as a 2-D uint8 array with the full 0-255 range
    preserved (a soft mask, NOT binarised), or None when the request fails,
    the response cannot be decoded as an image, or the mask is fully
    transparent. Errors are printed rather than raised so that one bad image
    does not abort a batch run.
    """
    # NOTE(review): assumes paths.txt stores raw, un-encoded paths; an
    # already percent-encoded path would be double-encoded here — confirm
    # against download-candidates.py.
    encoded_path = quote(remote_rel_path.lstrip("/"), safe="/")
    url = f"https://{IMGIX_DOMAIN}/{encoded_path}?{IMGIX_PARAMS}"
    try:
        resp = requests.get(url, timeout=45, headers=HEADERS)
        if resp.status_code != 200:
            print(f"    ImgIX HTTP {resp.status_code}")
            return None
        img = Image.open(BytesIO(resp.content))
        # ImgIX returns mode "P" (palette) PNGs with transparency in the palette.
        # PIL's convert("RGBA") handles this correctly only when called on the
        # original mode-P image — do NOT intermediate-convert to RGB first.
        alpha = np.array(img.convert("RGBA").getchannel("A"))
        # Keep the raw alpha channel — ImgIX uses semi-transparent values on
        # windows (alpha 50-200), so the full 0-255 range is preserved.
        # BiRefNet training uses this as a soft mask (not binarised).
        if alpha.max() == 0:
            print("    Empty mask — skipping")
            return None
        return alpha
    except Exception as e:
        # Deliberately broad: network, response-body and image-decoding
        # failures all simply mean "no mask for this image" in a batch run.
        print(f"    Error: {e}")
        return None


def main():
    parser = argparse.ArgumentParser(description="Fetch ImgIX ground-truth masks for filtered images")
    parser.add_argument("--source", required=True, help="Folder of filtered source JPEGs")
    parser.add_argument("--out",    default="dataset", help="Output dataset directory (default: dataset/)")
    args = parser.parse_args()

    source_dir = Path(args.source)
    out_dir    = Path(args.out)
    im_dir     = out_dir / "im"
    gt_dir     = out_dir / "gt"
    im_dir.mkdir(parents=True, exist_ok=True)
    gt_dir.mkdir(parents=True, exist_ok=True)

    # Load path map written by download-candidates.py
    path_map_file = source_dir.parent / "source" / "paths.txt"
    if not path_map_file.exists():
        # Try sibling paths.txt
        path_map_file = source_dir / "paths.txt"
    if not path_map_file.exists():
        print(f"ERROR: paths.txt not found. Expected at {path_map_file}")
        print("Re-run download-candidates.py — it writes paths.txt automatically.")
        sys.exit(1)

    path_map = {}  # filename stem → relative server path
    for line in path_map_file.read_text().splitlines():
        if "\t" in line:
            stem, rel = line.split("\t", 1)
            path_map[stem] = rel.strip()

    images = sorted(source_dir.glob("*.jpg"))
    if not images:
        print(f"No JPEGs found in {source_dir}")
        sys.exit(1)

    print(f"Source     : {source_dir}/  ({len(images)} images)")
    print(f"Output     : {out_dir}/")
    print(f"ImgIX      : {IMGIX_DOMAIN}")
    print(f"Credits    : ~{len(images)} ImgIX bg-remove calls\n")

    ok = fail = skip = 0

    for img_path in images:
        stem    = img_path.stem
        im_dest = im_dir / f"{stem}.jpg"
        gt_dest = gt_dir / f"{stem}.png"

        # Resume
        if im_dest.exists() and gt_dest.exists():
            skip += 1
            continue

        rel_path = path_map.get(stem)
        if not rel_path:
            print(f"[{stem}] SKIP — no server path in paths.txt")
            fail += 1
            continue

        print(f"[{stem}]", end=" ", flush=True)

        # Fetch ImgIX mask
        time.sleep(0.2)
        mask = fetch_mask(rel_path)
        if mask is None:
            print("SKIP (ImgIX failed)")
            fail += 1
            continue

        # Copy source image + save mask
        shutil.copy2(img_path, im_dest)
        Image.fromarray(mask, mode="L").save(gt_dest)
        size_kb = img_path.stat().st_size // 1024
        print(f"OK  ({size_kb} KB)")
        ok += 1

    print(f"\n{'='*50}")
    print(f"Masks fetched : {ok}")
    print(f"Skipped       : {skip}  (already done)")
    print(f"Failed        : {fail}")
    print(f"Dataset       : {out_dir}/im/  +  {out_dir}/gt/")
    if ok + skip > 0:
        print(f"\nNext step:")
        print(f"  python3 tools/internal/finetune-birefnet.py --dataset {out_dir}")
    print("=" * 50)


# Run the fetch only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
