#!/usr/bin/env python3
"""
Generate a visual review grid for all dataset pairs.
Left  = original input (im/)
Right = composite on white using GT mask (gt/)

For server_* pairs, the composite reveals whether the GT mask actually matches the car in the input image.
"""
import sys
import math
from pathlib import Path
from PIL import Image
import numpy as np

# Where the dataset lives and where the review pages are written.
DATASET_DIR = Path("dataset")
IM_DIR = DATASET_DIR.joinpath("im")        # original input images (*.jpg)
GT_DIR = DATASET_DIR.joinpath("gt")        # ground-truth masks (<stem>.png)
OUT_DIR = DATASET_DIR.joinpath("review")   # generated review pages

# Grid geometry.
THUMB_W, THUMB_H = 320, 240  # size of each individual thumbnail, in pixels
COLS = 6                     # pairs per row (each pair = 2 thumbs side-by-side)
PAIRS_PER_PAGE = 24          # pairs per output image

def composite_on_white(im_path: Path, gt_path: Path) -> Image.Image:
    """Return the input image composited onto a white background via its GT mask.

    The mask at *gt_path* is read as 8-bit grayscale and used as the alpha:
    white mask pixels keep the original pixel, black pixels become white,
    and grey values blend the two. If the mask size differs from the image,
    it is stretched to the image size (aspect ratio is not preserved).

    Both files are opened in ``with`` blocks so their handles are released
    as soon as the pixel data has been converted.
    """
    with Image.open(im_path) as src:
        orig = src.convert("RGB")
    with Image.open(gt_path) as mask_src:
        alpha = mask_src.convert("L")

    # Only resample when needed; a same-size resize would just copy the mask.
    if alpha.size != orig.size:
        alpha = alpha.resize(orig.size, Image.LANCZOS)

    white = Image.new("RGB", orig.size, (255, 255, 255))
    white.paste(orig, mask=alpha)
    return white

def make_thumb(img: Image.Image, w: int, h: int) -> Image.Image:
    """Return a w x h thumbnail of *img*, centred on a light-grey background.

    *img* itself is not modified; a copy is shrunk with LANCZOS resampling
    (aspect ratio preserved) and pasted in the middle of a fresh RGB canvas,
    so any leftover border shows as grey (200, 200, 200).
    """
    shrunk = img.copy()
    shrunk.thumbnail((w, h), Image.LANCZOS)

    offset = ((w - shrunk.width) // 2, (h - shrunk.height) // 2)
    framed = Image.new("RGB", (w, h), (200, 200, 200))
    framed.paste(shrunk, offset)
    return framed

def add_label(img: Image.Image, text: str) -> Image.Image:
    """Draw *text* in yellow on a dark banner across the top of *img*.

    The banner covers rows 0-22 of the image. *img* is modified in place and
    also returned for convenience.

    A 16 pt TrueType font is tried from common macOS, Linux, and Windows
    locations. If none can be loaded, Pillow's built-in default font is used.
    """
    from PIL import ImageDraw, ImageFont

    candidates = (
        "/System/Library/Fonts/Helvetica.ttc",  # macOS
        "DejaVuSans.ttf",                       # common on Linux (looked up in system font dirs)
        "arial.ttf",                            # Windows
    )
    for candidate in candidates:
        try:
            font = ImageFont.truetype(candidate, 16)
        except (OSError, ImportError):
            # OSError: font file missing/unreadable.
            # ImportError: Pillow was built without FreeType support.
            continue
        break
    else:
        font = ImageFont.load_default()

    draw = ImageDraw.Draw(img)
    draw.rectangle([(0, 0), (img.width, 22)], fill=(40, 40, 40))
    draw.text((4, 3), text, fill=(255, 255, 0), font=font)
    return img

def _render_page(page_pairs, page_idx):
    """Lay out one page of thumbnail pairs and save it as a JPEG in OUT_DIR.

    page_pairs: list of (stem, orig_thumb, comp_thumb); each thumb is assumed
    to already be THUMB_W x THUMB_H (as produced by make_thumb).
    page_idx: 0-based page number; the output file name uses it 1-based.
    """
    label_h = 24                    # dedicated label strip above each thumbnail
    gap = 4
    pair_w = THUMB_W * 2 + gap      # 2 thumbs per pair + gap
    cell_h = label_h + THUMB_H + gap
    rows = math.ceil(len(page_pairs) / COLS)
    canvas = Image.new(
        "RGB", (COLS * pair_w + gap, rows * cell_h + gap), (60, 60, 60)
    )

    for i, (stem, left_t, right_t) in enumerate(page_pairs):
        row, col = divmod(i, COLS)
        x = gap + col * pair_w
        y = gap + row * cell_h
        for dx, thumb, tag in ((0, left_t, "orig"), (THUMB_W + 2, right_t, "comp")):
            # Draw the label in its own strip so it never hides the top of the
            # image being reviewed.
            strip = add_label(
                Image.new("RGB", (THUMB_W, label_h), (40, 40, 40)), f"{stem} {tag}"
            )
            canvas.paste(strip, (x + dx, y))
            canvas.paste(thumb, (x + dx, y + label_h))

    out_path = OUT_DIR / f"review_page{page_idx+1:02d}.jpg"
    canvas.save(out_path, "JPEG", quality=88)
    print(f"  → {out_path}  ({len(page_pairs)} pairs)")


def main():
    """Write review pages that pair each input image with its GT composite.

    Each im/*.jpg that has a matching gt/<stem>.png is loaded and composited
    on white. Both images are shrunk to thumbnails right away, so memory use
    is bounded by the page size rather than by the source resolution.
    Inputs without GT, or that cannot be read, are reported and skipped.
    Pages of PAIRS_PER_PAGE pairs are saved to OUT_DIR as review_pageNN.jpg.
    """
    im_paths = sorted(IM_DIR.glob("*.jpg"))
    print(f"Found {len(im_paths)} pairs in {IM_DIR}/")

    OUT_DIR.mkdir(parents=True, exist_ok=True)

    page_idx = 0
    page_pairs = []

    for im_path in im_paths:
        stem = im_path.stem
        gt_path = GT_DIR / f"{stem}.png"
        if not gt_path.exists():
            print(f"  WARN: no GT for {stem}, skipping")
            continue

        try:
            with Image.open(im_path) as src:
                orig = src.convert("RGB")
            comp = composite_on_white(im_path, gt_path)
        except OSError as exc:
            # Corrupt or unreadable file (Pillow's UnidentifiedImageError is an
            # OSError): report it and keep going instead of aborting the run.
            print(f"  WARN: cannot read {stem} ({exc}), skipping")
            continue

        # Keep only the thumbnails; the full-size images are dropped here.
        page_pairs.append((
            stem,
            make_thumb(orig, THUMB_W, THUMB_H),
            make_thumb(comp, THUMB_W, THUMB_H),
        ))

        if len(page_pairs) == PAIRS_PER_PAGE:
            _render_page(page_pairs, page_idx)
            page_idx += 1
            page_pairs = []

    if page_pairs:
        _render_page(page_pairs, page_idx)

    print(f"\nDone. Review images saved to {OUT_DIR}/")
    print("Open them and look for any pair where the composite car doesn't match the original.")


if __name__ == "__main__":
    main()
