#!/usr/bin/env python3
"""
Generate a large-thumbnail review sheet for ONLY the server_* pairs. They are selected by their
"server_" filename prefix; in the 72-file staging set they occupy positions 48-72.
These are the pairs that use paths_index and can therefore be mismatched.
Each pair is shown as orig | composite side by side at higher resolution, so mismatches are obvious.
"""
import functools
import math
from pathlib import Path

from PIL import Image, ImageDraw, ImageFont

# Dataset layout (paths are relative to the current working directory).
DATASET_DIR = Path("dataset")
STAGING_DIR = DATASET_DIR.joinpath("staging")  # source files, enumerated in sorted order
IM_DIR = DATASET_DIR.joinpath("im")            # NNNN.jpg originals
GT_DIR = DATASET_DIR.joinpath("gt")            # NNNN.png masks used as alpha
OUT_DIR = DATASET_DIR.joinpath("review")       # where the review sheet is written

# Review-sheet geometry: each thumbnail is THUMB_W x THUMB_H px, COLS pairs per row.
THUMB_W, THUMB_H = 480, 360
COLS = 4

def composite_on_white(im_path: Path, gt_path: Path) -> Image.Image:
    """Return the image at *im_path* composited over white, using *gt_path* as alpha.

    The mask is converted to 8-bit greyscale ("L") and resized to the image's
    size, so a mask stored at a different resolution still lines up. Pillow's
    masked paste copies the original pixel where the mask is 255, keeps white
    where it is 0, and blends in between.

    Both files are opened with context managers so their handles are released
    as soon as the pixel data has been read (``convert`` forces the load).
    """
    with Image.open(im_path) as src:
        orig = src.convert("RGB")
    with Image.open(gt_path) as mask_src:
        alpha = mask_src.convert("L")
    alpha = alpha.resize(orig.size, Image.LANCZOS)
    white = Image.new("RGB", orig.size, (255, 255, 255))
    white.paste(orig, mask=alpha)
    return white

def make_thumb(img: Image.Image, w: int, h: int) -> Image.Image:
    """Return a ``w`` x ``h`` letterboxed thumbnail of *img*.

    A copy of *img* is scaled with ``Image.thumbnail`` (aspect ratio kept, never
    enlarged) and centred on a light-grey RGB board; *img* itself is untouched.
    """
    shrunk = img.copy()
    shrunk.thumbnail((w, h), Image.LANCZOS)
    board = Image.new("RGB", (w, h), (200, 200, 200))
    offset = ((w - shrunk.width) // 2, (h - shrunk.height) // 2)
    board.paste(shrunk, offset)
    return board

# Fonts tried in order for label text; the first one Pillow can load is used.
_LABEL_FONT_CANDIDATES = (
    "/System/Library/Fonts/Helvetica.ttc",  # macOS system font
    "DejaVuSans.ttf",                       # often installed on Linux
    "arial.ttf",                            # usually present on Windows
)
_LABEL_FONT_SIZE = 18
_LABEL_BAR_H = 26  # height of the caption bar drawn across the top of a thumbnail


@functools.lru_cache(maxsize=None)
def _label_font():
    """Load the label font once and reuse it for every label."""
    for candidate in _LABEL_FONT_CANDIDATES:
        try:
            return ImageFont.truetype(candidate, _LABEL_FONT_SIZE)
        except (OSError, ImportError):
            # OSError: font file missing/unreadable.
            # ImportError: Pillow was built without FreeType support.
            continue
    # Best-effort fallback so labelling never aborts the review.
    return ImageFont.load_default()


def add_label(img: Image.Image, text: str, warn=False) -> Image.Image:
    """Draw a caption bar across the top of *img* and return it.

    Args:
        img: RGB image to annotate; it is modified in place.
        text: caption text, drawn in yellow.
        warn: if True the bar is red instead of dark grey (flags a known mismatch).

    Returns:
        The same *img* object, so calls can be chained/reassigned.
    """
    draw = ImageDraw.Draw(img)
    bg = (180, 0, 0) if warn else (40, 40, 40)
    draw.rectangle([(0, 0), (img.width, _LABEL_BAR_H)], fill=bg)
    draw.text((4, 4), text, fill=(255, 255, 0), font=_label_font())
    return img

# Staging names (without extension) of pairs already identified as mismatched.
PREVIOUSLY_BAD = {"server_0144", "server_0467"}  # known mismatches from previous analysis


def _find_server_pairs():
    """Return ``(stem, staging_name, im_path, gt_path)`` for each complete server_* pair.

    Staging files are enumerated in sorted order (same order as the main
    pipeline), and the 1-based position becomes the ``NNNN`` stem of the
    ``IM_DIR``/``GT_DIR`` files. Entries missing either file are reported and
    skipped.
    """
    all_staging = sorted(
        f for f in STAGING_DIR.iterdir()
        if f.suffix.lower() in (".jpg", ".jpeg", ".png", ".avif")
    )

    pairs = []
    for i, f in enumerate(all_staging, 1):
        if not f.name.startswith("server_"):
            continue
        stem = f"{i:04d}"
        im_path = IM_DIR / f"{stem}.jpg"
        gt_path = GT_DIR / f"{stem}.png"
        if im_path.exists() and gt_path.exists():
            pairs.append((stem, f.name, im_path, gt_path))
        else:
            print(f"  SKIP (no pair): {stem} = {f.name}")
    return pairs


def _render_review_sheet(server_pairs):
    """Lay out *server_pairs* COLS per row as "orig | composite" and return the sheet."""
    gap = 4                        # margin around and between cells
    row_pad = 30                   # extra vertical padding below each row
    pair_w = THUMB_W * 2 + 6       # two thumbnails plus a 6px divider
    row_h = THUMB_H + row_pad + gap
    rows = math.ceil(len(server_pairs) / COLS)
    page_w = COLS * pair_w + (COLS + 1) * gap
    page_h = rows * row_h + gap
    canvas = Image.new("RGB", (page_w, page_h), (60, 60, 60))

    for i, (stem, staging_name, im_path, gt_path) in enumerate(server_pairs):
        row, col = divmod(i, COLS)
        x = gap + col * (pair_w + gap)
        y = gap + row * row_h

        with Image.open(im_path) as src:
            orig = src.convert("RGB")
        comp = composite_on_white(im_path, gt_path)

        left_t = make_thumb(orig, THUMB_W, THUMB_H)
        right_t = make_thumb(comp, THUMB_W, THUMB_H)

        # Compare on the bare stem so every accepted extension
        # (.jpg/.jpeg/.png/.avif, any case) is recognised.
        warn = Path(staging_name).stem in PREVIOUSLY_BAD
        left_t = add_label(left_t, f"{stem} orig  ({staging_name})", warn=warn)
        right_t = add_label(right_t, f"{stem} comp", warn=warn)

        canvas.paste(left_t, (x, y))
        canvas.paste(right_t, (x + THUMB_W + 6, y))
    return canvas


def main():
    """Build a review sheet of every server_* pair and save it as a JPEG.

    The sheet is written to ``OUT_DIR/review_server_pairs.jpg`` (the directory
    is created if needed). Pairs listed in ``PREVIOUSLY_BAD`` get red labels.
    """
    server_pairs = _find_server_pairs()
    print(f"Found {len(server_pairs)} server_* pairs to review")

    if not server_pairs:
        # Nothing to draw; avoid writing a degenerate, empty sheet.
        print("Nothing to review.")
        return

    canvas = _render_review_sheet(server_pairs)

    OUT_DIR.mkdir(parents=True, exist_ok=True)
    out_path = OUT_DIR / "review_server_pairs.jpg"
    canvas.save(out_path, "JPEG", quality=90)
    print(f"  → {out_path}")
    print()
    print("RED label = previously-identified mismatch. Check if orig car ≠ comp car.")


if __name__ == "__main__":
    main()
