#!/usr/bin/env python3
#
# @File: test-sam2-windows.py
# @Date: 2026-05-26
#
# Test: BiRefNet bg removal + SAM 2 automatic mask generation for window transparency.
#
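# Strategy:
#   1. BiRefNet removes background (gives car silhouette, windows still solid)
#   2. SAM 2 segments the original image from point prompts placed in the
#      expected window zones (negative prompts on the hood and lower body)
#   3. Filter masks that look like car windows:
#       - >=80% of the mask overlaps the car silhouette
#       - Mask centre lies between 10% and 68% of the car height
#         (below the roof metal, above the lower body)
#       - ~1–35% of car area (not the whole car, not tiny specks)
#       - Colour differs from the body paint, is not a specular highlight,
#         and the mask is wider than it is tall
#   4. Punch those regions transparent in the BiRefNet alpha
#   5. Save side-by-side: original | birefnet-only | birefnet+windows | debug overlay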
#
# Usage:
#   python3 tools/internal/test-sam2-windows.py
#   python3 tools/internal/test-sam2-windows.py --images /path/to/images --out /tmp/sam2_test

import sys
import argparse
import warnings
import numpy as np
from pathlib import Path
from PIL import Image, ImageDraw, ImageFont

# Suppress every Python warning for the rest of the process.
warnings.filterwarnings("ignore")

# Default folder scanned for dealer_*.jpg images (overridable with --images).
MEDIA_DIR  = Path("/opt/homebrew/var/www/media")
# Default folder that receives the generated PNG/JPEG results (overridable with --out).
DEFAULT_OUT = Path("/tmp/sam2_windows_test")

# ──────────────────────────────────────────────────────
# Step 1: Background removal with BiRefNet (rembg)
# ──────────────────────────────────────────────────────
def birefnet_remove_bg(img_path: Path) -> Image.Image:
    """
    Remove the background of the image at *img_path* with rembg's
    "birefnet-general" session and return the cut-out as an RGBA image.

    The rembg session is created on the first call (requesting the
    CPUExecutionProvider) and cached on the function object as
    ``birefnet_remove_bg._session`` so later calls reuse it.
    """
    import io
    from rembg import remove, new_session

    session = getattr(birefnet_remove_bg, "_session", None)
    if session is None:
        print("  Loading BiRefNet model ...", flush=True)
        session = new_session(
            "birefnet-general", providers=["CPUExecutionProvider"]
        )
        birefnet_remove_bg._session = session

    # remove() is fed the raw file bytes; its result is decoded from memory.
    encoded = remove(img_path.read_bytes(), session=session)
    return Image.open(io.BytesIO(encoded)).convert("RGBA")


# ──────────────────────────────────────────────────────
# Step 2: SAM 2 automatic mask generation
# ──────────────────────────────────────────────────────
def load_sam2(device: str):
    """
    Load the SAM 2 processor and model from the
    ``facebook/sam2-hiera-small`` checkpoint.

    The model is moved to *device* and switched to eval mode.
    Returns ``(processor, model)``.
    """
    from transformers import Sam2Model, Sam2Processor

    checkpoint = "facebook/sam2-hiera-small"
    print("  Loading SAM 2 (facebook/sam2-hiera-small) ...", flush=True)

    sam_processor = Sam2Processor.from_pretrained(checkpoint)
    sam_model = Sam2Model.from_pretrained(checkpoint)
    sam_model.to(device)
    sam_model.eval()
    return sam_processor, sam_model


def _car_window_prompts(car_alpha: np.ndarray) -> tuple[list[list[int]], list[list[int]]]:
    """
    Build the SAM 2 point prompts for one car: ``(pos_points, neg_points)``.

    Every point is expressed as ``[x, y]`` in image pixels and is derived
    from a fractional position inside the bounding box of the silhouette
    (alpha > 127).  A point is kept only when it lands on a silhouette pixel.

    Positive anchors use two anatomical zones rather than one flat band:
      Zone A — front windshield (upper-left of the cabin)
      Zone B — rear/side windows (upper-right of the cabin), placed slightly
               lower because the roofline slopes down toward the rear

    Negative anchors cover the hood/bonnet and the lower body (paint and
    reflections).  Both lists are empty when the alpha has no silhouette.
    """
    silhouette = car_alpha > 127
    rows_hit = np.nonzero(silhouette.any(axis=1))[0]
    cols_hit = np.nonzero(silhouette.any(axis=0))[0]
    if rows_hit.size == 0:
        return [], []

    # np.nonzero yields ascending indices, so the ends are the min and max.
    y_min, y_max = int(rows_hit[0]), int(rows_hit[-1])
    x_min, x_max = int(cols_hit[0]), int(cols_hit[-1])
    span_y = y_max - y_min
    span_x = x_max - x_min
    img_h, img_w = car_alpha.shape[0], car_alpha.shape[1]

    def anchors(fys, fxs):
        """Grid of [x, y] points (fy outer, fx inner) that hit the silhouette."""
        found = []
        for fy, fx in [(fy, fx) for fy in fys for fx in fxs]:
            px = x_min + int(span_x * fx)
            py = y_min + int(span_y * fy)
            inside = 0 <= py < img_h and 0 <= px < img_w
            if inside and silhouette[py, px]:
                found.append([px, py])
        return found

    positives = (
        # Zone A: front windshield — upper-left quadrant of the cabin
        anchors((0.17, 0.27), (0.22, 0.35))
        # Zone B: rear / side window — upper-right quadrant, slightly lower
        + anchors((0.22, 0.35), (0.62, 0.78))
    )
    negatives = (
        # Hood area: centre of the car, mid height
        anchors((0.45, 0.55, 0.65), (0.30, 0.50, 0.70))
        # Lower body / doors — definitely paint
        + anchors((0.72, 0.82), (0.30, 0.55))
    )
    return positives, negatives


def generate_masks(img_rgb: Image.Image, processor, model,
                   device: str, car_alpha: np.ndarray) -> list[dict]:
    """
    Run SAM 2 on *img_rgb* with point prompts derived from *car_alpha*.

    Each positive prompt becomes its own SAM 2 "object"; all negative
    prompts are attached to every object.  The processor is given the
    4-level nesting [image][object][points_per_object][xy].

    For every object the candidate with the highest IoU score is kept,
    passed through a sigmoid, bilinearly resized to the size of *img_rgb*
    and thresholded at 0.5.  Masks with fewer than 200 pixels are dropped.

    Returns a list of dicts:
      { 'segmentation': bool ndarray, 'area': int, 'score': float }
    An empty list is returned when there are no positive prompts.
    """
    import torch
    from torch.nn.functional import interpolate

    fg_points, bg_points = _car_window_prompts(car_alpha)
    if not fg_points:
        return []

    # One object per positive point: [positive, *negatives] with labels
    # 1 = foreground for the first point and 0 = background for the rest.
    per_object_points = [[fg, *bg_points] for fg in fg_points]
    per_object_labels = [[1, *([0] * len(bg_points))] for _ in fg_points]

    encoded = processor(
        images=img_rgb,
        input_points=[per_object_points],   # leading list = batch of one image
        input_labels=[per_object_labels],
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        prediction = model(**encoded)

    # First batch entry: expected (num_objects, num_candidates, H', W') and
    # (num_objects, num_candidates) respectively.
    candidate_logits = prediction.pred_masks[0]
    candidate_scores = prediction.iou_scores[0]

    # Per object, keep only the best-scoring candidate as a (1, H', W') slice.
    winners = [
        logits[int(scores.argmax())].unsqueeze(0)
        for logits, scores in zip(candidate_logits, candidate_scores)
    ]
    if not winners:
        return []

    # Resize by hand to the original image size instead of calling the
    # processor's post_process_masks (the original author reports it is
    # handed tensors where it expects plain int sizes).
    target_size = (img_rgb.height, img_rgb.width)
    probabilities = torch.sigmoid(torch.stack(winners, dim=0).float())
    upscaled = interpolate(
        probabilities,
        size=target_size,
        mode="bilinear",
        align_corners=False,
    )
    binary_masks = upscaled.squeeze(1) > 0.5

    found = []
    for binary, scores in zip(binary_masks, candidate_scores):
        segmentation = binary.cpu().numpy().astype(bool)
        pixel_count = segmentation.sum()
        if pixel_count < 200:
            continue
        found.append({
            "segmentation": segmentation,
            "area": int(pixel_count),
            "score": float(scores.max()),
        })
    return found


# ──────────────────────────────────────────────────────
# Step 3: Filter masks → window candidates
# ──────────────────────────────────────────────────────
def _sample_body_color(img_np: np.ndarray, car_alpha: np.ndarray) -> np.ndarray:
    """
    Estimate the car's paint colour as the per-channel median of the
    silhouette pixels (alpha > 127) in the 55-80% vertical band of the car.

    That band is the lower body (doors/sills), which is almost always
    painted metal rather than glass.  Mid-grey (128, 128, 128) is returned
    when there is no silhouette or the band holds fewer than 50 pixels.
    """
    fallback = np.array([128.0, 128.0, 128.0])

    silhouette = car_alpha > 127
    occupied_rows = np.nonzero(silhouette.any(axis=1))[0]
    if occupied_rows.size == 0:
        return fallback

    # np.nonzero yields ascending indices, so the ends are the min and max.
    roof_row = int(occupied_rows[0])
    height = int(occupied_rows[-1]) - roof_row

    # Lower doors: rows from 55% up to (not including) 80% of the car height.
    band = slice(roof_row + int(height * 0.55), roof_row + int(height * 0.80))
    paint_pixels = img_np[band][silhouette[band]]
    if len(paint_pixels) < 50:
        return fallback

    return np.median(paint_pixels, axis=0).astype(float)


def _is_specular_highlight(mask: np.ndarray, img_np: np.ndarray) -> bool:
    """
    Decide whether the masked region looks like hood/bonnet glare.

    Glare can look like glass to SAM 2 but is near-white and very uniform,
    whereas real glass showing sky or buildings is more varied.  The region
    is flagged when either
      - more than 45% of its pixels have R, G and B all above 175, or
      - its overall mean exceeds 190 while its standard deviation is below 22.
    Regions with fewer than 50 pixels are never flagged.
    """
    region = img_np[mask].astype(float)
    if len(region) < 50:
        return False

    white_fraction = float(np.all(region > 175, axis=1).mean())
    mostly_white = white_fraction > 0.45

    # Bright and flat at once: a smooth specular blob, not textured glass.
    bright_and_flat = float(region.mean()) > 190 and float(region.std()) < 22

    return mostly_white or bright_and_flat


def _glass_divergence(mask: np.ndarray, img_np: np.ndarray,
                       body_color: np.ndarray) -> float:
    """
    Share of masked pixels whose RGB colour is far from the body paint.

    "Far" means a Euclidean RGB distance greater than 45 from *body_color*.
    Glass shows sky or the interior, which both differ strongly from paint,
    so a higher value (range 0..1) means more glass-like.  An empty mask
    yields 0.0.
    """
    region = img_np[mask].astype(float)
    if len(region) == 0:
        return 0.0

    delta = region - body_color
    distance = np.sqrt((delta ** 2).sum(axis=1))
    far_from_paint = distance > 45
    return float(far_from_paint.mean())


def filter_window_masks(
    all_masks: list[dict],
    car_alpha: np.ndarray,          # (H, W) uint8
    img_np: np.ndarray,             # (H, W, 3) uint8 RGB
    min_overlap: float  = 0.80,     # fraction of mask that must be inside car
    max_car_frac: float = 0.35,     # mask can't be more than 35% of car area
    min_car_frac: float = 0.008,    # mask must be at least ~1% of car area
    max_car_y_frac: float = 0.68,   # windows live in the top 68% of car height
    min_glass_div: float = 0.45,    # ≥45% of pixels must diverge from body color
) -> list[np.ndarray]:
    """
    Keep the candidate masks that plausibly cover car windows.

    *all_masks* is a list of dicts carrying a boolean ``'segmentation'``
    array.  A mask survives only if it passes, in this order:
      1. area between ``min_car_frac`` and ``max_car_frac`` of the car area
      2. at least ``min_overlap`` of its pixels inside the silhouette
         (alpha > 127)
      3. vertical centre between 10% and ``max_car_y_frac`` of car height
      4. colour divergence from the body paint of at least ``min_glass_div``
      5. not a specular highlight
      6. bounding box at least 20% wider than tall, spanning >= 5 columns

    Survivors are de-duplicated by IoU (first one wins).  Returns the list
    of boolean masks; an empty list when the alpha has no silhouette.
    """
    silhouette = car_alpha > 127
    car_px = int(silhouette.sum())
    if car_px == 0:
        return []

    paint_rgb = _sample_body_color(img_np, car_alpha)

    # Vertical window band, measured from the top of the car's bounding box.
    car_rows = np.nonzero(silhouette.any(axis=1))[0]
    roof_y = int(car_rows.min())
    car_height = int(car_rows.max()) - roof_y
    band_top = roof_y + int(car_height * 0.10)          # above this is roof metal
    band_bottom = roof_y + int(car_height * max_car_y_frac)

    smallest = car_px * min_car_frac
    largest = car_px * max_car_frac

    def looks_like_window(seg: np.ndarray) -> bool:
        """Apply checks 1-6 to one mask; cheapest checks run first."""
        px = seg.sum()
        if not (smallest <= px <= largest):
            return False

        inside = int((seg & silhouette).sum())
        if inside / px < min_overlap:
            return False

        rows = np.nonzero(seg.any(axis=1))[0]
        if len(rows) == 0:
            return False
        center_y = int(rows.mean())
        if not (band_top <= center_y <= band_bottom):
            return False

        # Too close to the paint colour means this is bodywork, not glass.
        if _glass_divergence(seg, img_np, paint_rgb) < min_glass_div:
            return False

        # Hood/bonnet glare is near-white and uniform.
        if _is_specular_highlight(seg, img_np):
            return False

        # Windows are landscape; glare blobs and headlights are roundish.
        cols = np.nonzero(seg.any(axis=0))[0]
        if len(cols) < 5:
            return False
        box_w = int(cols.max() - cols.min())
        box_h = int(rows.max() - rows.min())
        return box_w >= box_h * 1.2

    survivors = []
    for entry in all_masks:
        seg = entry["segmentation"]
        if looks_like_window(seg):
            survivors.append(seg)

    return _dedup_masks(survivors)


def _dedup_masks(masks: list[np.ndarray], iou_thresh: float = 0.5) -> list[np.ndarray]:
    """
    Drop near-duplicate boolean masks.

    Masks are visited in order; one is discarded when its IoU with any
    already-kept mask is strictly greater than *iou_thresh*.  The kept
    masks are returned in their original order.
    """
    def is_repeat(candidate: np.ndarray, prior: np.ndarray) -> bool:
        shared = int((candidate & prior).sum())
        combined = int((candidate | prior).sum())
        # An empty union has no defined IoU and never counts as a repeat.
        return combined > 0 and shared / combined > iou_thresh

    unique = []
    for candidate in masks:
        if not any(is_repeat(candidate, prior) for prior in unique):
            unique.append(candidate)
    return unique


def _morph_close(mask: np.ndarray, kernel_size: int = 11) -> np.ndarray:
    """
    Apply a morphological closing (dilate, then erode) to a boolean mask
    using a square ``kernel_size`` x ``kernel_size`` structuring element.

    This bridges the pixelated gaps SAM 2 leaves when the glass shows a
    high-contrast background (e.g. building stripes behind the rear
    window).  Returns a boolean array.
    """
    import cv2

    # OpenCV works on 8-bit images: encode True/False as 255/0.
    as_image = mask.astype(np.uint8) * 255
    square = cv2.getStructuringElement(
        cv2.MORPH_RECT, (kernel_size, kernel_size)
    )
    sealed = cv2.morphologyEx(as_image, cv2.MORPH_CLOSE, square)
    return sealed > 127


# ──────────────────────────────────────────────────────
# Step 4: Apply window transparency
# ──────────────────────────────────────────────────────
def apply_window_transparency(rgba: Image.Image, window_masks: list[np.ndarray]) -> Image.Image:
    """
    Return a new RGBA image in which the window regions are transparent.

    Each mask is morphologically closed first to fill pixelated gaps.
    Alpha is set to 0 only where a closed mask covers a pixel whose alpha
    is above 127, i.e. inside the silhouette BiRefNet kept.  The colour
    channels and the input image are left untouched.
    """
    bands = rgba.split()
    alpha = np.array(bands[3])

    # Union of all closed masks; clearing once is equivalent to clearing
    # mask by mask because cleared pixels never become opaque again.
    punch = np.zeros(alpha.shape, dtype=bool)
    for window in window_masks:
        punch |= _morph_close(window)
    alpha[punch & (alpha > 127)] = 0

    red, green, blue, _ = bands
    return Image.merge("RGBA", (red, green, blue, Image.fromarray(alpha)))


# ──────────────────────────────────────────────────────
# Visualisation helpers
# ──────────────────────────────────────────────────────
# Opaque mid-grey backdrop that makes transparent areas visible in previews.
GREY_BG = (180, 180, 180, 255)

def on_grey(rgba: Image.Image) -> Image.Image:
    """Composite *rgba* over a solid GREY_BG canvas and return it as RGB."""
    backdrop = Image.new("RGBA", rgba.size, GREY_BG)
    flattened = Image.alpha_composite(backdrop, rgba)
    return flattened.convert("RGB")

def label_img(img: Image.Image, text: str) -> Image.Image:
    """
    Return an RGB copy of *img* with a black banner across the top
    (rows 0-22) and *text* drawn on it in white.
    """
    labelled = img.convert("RGB")   # convert() already returns a new image
    pen = ImageDraw.Draw(labelled)
    banner_box = [0, 0, labelled.width, 22]
    pen.rectangle(banner_box, fill=(0, 0, 0))
    pen.text((4, 4), text, fill=(255, 255, 255))
    return labelled

def side_by_side(*imgs) -> Image.Image:
    """
    Paste the given images left to right, top-aligned, onto a dark grey
    RGB canvas as wide as their summed widths and as tall as the tallest.
    """
    widths = [tile.width for tile in imgs]
    tallest = max(tile.height for tile in imgs)
    canvas = Image.new("RGB", (sum(widths), tallest), (40, 40, 40))

    offset = 0
    for tile, width in zip(imgs, widths):
        canvas.paste(tile, (offset, 0))
        offset += width
    return canvas


# ──────────────────────────────────────────────────────
# Main
# ──────────────────────────────────────────────────────
def process_image(img_path: Path, out_dir: Path, processor, model, device: str) -> None:
    """
    Run the full pipeline on one image and write the results to *out_dir*.

    Steps: BiRefNet background removal, SAM 2 prompted mask generation,
    window filtering, then punching the window masks out of the BiRefNet
    alpha.  Progress and timings are printed to stdout.

    Files written (``<stem>`` is the input file name without extension):
      <stem>_birefnet.png    BiRefNet cut-out
      <stem>_windows.png     cut-out with transparent windows
      <stem>_comparison.jpg  original | BiRefNet | + windows | debug overlay
    """
    import time
    from PIL import ImageOps

    print(f"\n── {img_path.name} ──", flush=True)

    # Step 1: BiRefNet
    started = time.time()
    print("  [1/3] BiRefNet background removal ...", end=" ", flush=True)
    birefnet_rgba = birefnet_remove_bg(img_path)
    print(f"{time.time()-started:.1f}s")

    car_alpha = np.array(birefnet_rgba.split()[3])

    # Step 2: SAM 2 targeted prompts in window region
    started = time.time()
    print("  [2/3] SAM 2 mask generation ...", end=" ", flush=True)
    with Image.open(img_path) as src:
        # The alpha and the RGB frame must share one pixel grid.  rembg is
        # understood to apply the EXIF orientation before segmenting
        # (NOTE(review): confirm for the installed rembg version), so the
        # original is oriented the same way.  If only the untransposed
        # frame matches the BiRefNet output size, that one is used instead.
        upright = ImageOps.exif_transpose(src)
        if upright.size != birefnet_rgba.size and src.size == birefnet_rgba.size:
            upright = src
        original_rgb = upright.convert("RGB")
    all_masks = generate_masks(original_rgb, processor, model, device, car_alpha)
    print(f"{time.time()-started:.1f}s  ({len(all_masks)} candidate masks)")

    # Step 3: filter for windows
    img_np = np.array(original_rgb)
    window_masks = filter_window_masks(all_masks, car_alpha, img_np)
    print(f"  [3/3] Window masks found: {len(window_masks)}")

    # Step 4: apply transparency
    result_rgba = apply_window_transparency(birefnet_rgba, window_masks)

    # Save outputs
    out_dir.mkdir(parents=True, exist_ok=True)
    stem = img_path.stem
    birefnet_rgba.save(out_dir / f"{stem}_birefnet.png")
    result_rgba.save(out_dir / f"{stem}_windows.png")

    thumb_w = 600

    def _thumb(img: Image.Image, caption: str) -> Image.Image:
        """Scale *img* to thumb_w pixels wide (aspect preserved) and caption it."""
        scaled_h = int(img.height * thumb_w / img.width)
        return label_img(img.resize((thumb_w, scaled_h)), caption)

    # Debug view: tint the bounding box of every accepted mask on the original.
    tints = [(255,0,0,80),(0,255,0,80),(0,0,255,80),(255,255,0,80),(255,0,255,80)]
    debug = original_rgb.convert("RGBA")
    for idx, m in enumerate(window_masks):
        ys, xs = np.nonzero(m)
        if ys.size == 0:
            continue
        overlay = Image.new("RGBA", debug.size, (0, 0, 0, 0))
        box = [int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())]
        ImageDraw.Draw(overlay).rectangle(box, fill=tints[idx % len(tints)])
        debug = Image.alpha_composite(debug, overlay)
    debug = debug.convert("RGB")

    # Side-by-side comparison; cut-outs are shown on a grey background.
    comparison = side_by_side(
        _thumb(original_rgb, "Original"),
        _thumb(on_grey(birefnet_rgba), "BiRefNet only"),
        _thumb(on_grey(result_rgba), f"+ SAM2 windows ({len(window_masks)} masks)"),
        _thumb(debug, f"SAM2 detections ({len(window_masks)})"),
    )
    comparison_path = out_dir / f"{stem}_comparison.jpg"
    comparison.save(comparison_path, quality=90)

    print(f"  Saved → {comparison_path}")


def main():
    """
    Command-line entry point.

    Parses --images / --out / --limit, picks the "mps" device when
    available (otherwise "cpu"), loads SAM 2 once and processes the first
    ``--limit`` files matching ``dealer_*.jpg`` in sorted order.  A failure
    on one image is reported with its traceback and does not stop the run.
    Exits with status 1 when no matching image is found.
    """
    import traceback

    cli = argparse.ArgumentParser(description="SAM 2 window transparency test")
    cli.add_argument("--images", default=str(MEDIA_DIR), help="Folder of dealer JPEGs")
    cli.add_argument("--out",    default=str(DEFAULT_OUT), help="Output folder")
    cli.add_argument("--limit",  type=int, default=5, help="Max images to process")
    opts = cli.parse_args()

    # Imported only after argument parsing so that --help does not need torch.
    import torch
    if torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    print(f"Device: {device}")

    source_dir = Path(opts.images)
    batch = sorted(source_dir.glob("dealer_*.jpg"))[:opts.limit]
    if not batch:
        print(f"No dealer_*.jpg found in {source_dir}")
        sys.exit(1)

    print(f"Images : {len(batch)}")
    print(f"Output : {opts.out}\n")

    processor, model = load_sam2(device)

    target_dir = Path(opts.out)
    for photo in batch:
        try:
            process_image(photo, target_dir, processor, model, device)
        except Exception as e:
            # Best effort: report the failure and carry on with the next image.
            print(f"  ERROR: {e}")
            traceback.print_exc()

    print(f"\n{'='*50}")
    closing_notes = (
        f"Done. Open {opts.out}/ to review.",
        "What to check in each comparison image:",
        "  col 1: original photo",
        "  col 2: BiRefNet only (windows solid)",
        "  col 3: + SAM 2 window punch-through",
        "  col 4: detected regions highlighted (debug)",
    )
    for note in closing_notes:
        print(note)
    print("="*50)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
