#!/usr/bin/env python3
#
# @File: test-finetuned.py
# @Date: 2026-06-03
#
# Visually test a fine-tuned BiRefNet checkpoint.
# Composites each output mask over a red background so window transparency is obvious.
#
# Usage:
#   python3 tools/internal/test-finetuned.py \
#       --checkpoint models/birefnet-vehicle/birefnet-vehicle-best.pt \
#       --images dataset/im/m002.jpg dataset/im/0264.jpg \
#       --out /tmp/birefnet-test

import argparse
import sys
from pathlib import Path

import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import numpy as np

# Identifier handed to transformers' from_pretrained() to build the base model
# that the fine-tuned checkpoint is loaded into (see load_model).
MODEL_ID = "ZhengPeng7/BiRefNet"
# Side length, in pixels, of the square that every image is resized to before
# it is fed to the model (see infer); the mask is scaled back afterwards.
IMG_SIZE  = 1024   # inference at full resolution


def load_model(checkpoint_path: Path, device: torch.device):
    """
    Build the base model and overlay fine-tuned weights from a checkpoint.

    Args:
        checkpoint_path: File containing a state dict; it is passed unchanged
            to ``model.load_state_dict``.
        device: Device the checkpoint tensors are mapped to and the model is
            moved onto.

    Returns:
        The model, on ``device`` and switched to eval mode.
    """
    # Imported here rather than at module level, so transformers is only
    # loaded when a model is actually built.
    from transformers import AutoModelForImageSegmentation

    print(f"Loading base model {MODEL_ID} ...")
    # NOTE: trust_remote_code=True lets transformers execute the model code
    # published in the MODEL_ID repository — only use repositories you trust.
    model = AutoModelForImageSegmentation.from_pretrained(MODEL_ID, trust_remote_code=True)

    print(f"Loading checkpoint {checkpoint_path} ...")
    # The checkpoint is a plain state dict, so restrict unpickling to tensors
    # and basic containers: a tampered .pt file must not be able to run code.
    state = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.load_state_dict(state)
    model.to(device)
    model.eval()
    return model


def infer(model, img_path: Path, device: torch.device) -> "tuple[Image.Image, Image.Image]":
    """
    Run the model on one image and turn its prediction into an alpha mask.

    Args:
        model: Callable model (see ``load_model``); it receives a normalised
            float tensor of shape (1, 3, IMG_SIZE, IMG_SIZE).
        img_path: Image file to read.
        device: Device the input tensor is moved to before the forward pass.

    Returns:
        ``(alpha, img)`` where ``alpha`` is an 8-bit "L" mask with the same
        size as the input image and ``img`` is the input converted to RGB.

    Raises:
        RuntimeError: If the model output contains no 4-D tensor.
    """
    # convert() yields a fully loaded, independent image, so the file handle
    # can be released straight away instead of lingering until GC.
    with Image.open(img_path) as src:
        img = src.convert("RGB")
    orig_w, orig_h = img.size

    transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        # Channel-wise mean/std (the usual ImageNet statistics).
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    inp = transform(img).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(inp)

    # The output may be a tensor or arbitrarily nested lists/tuples of tensors;
    # gather every 4-D tensor in depth-first order.
    candidates = []

    def collect(node):
        if isinstance(node, torch.Tensor) and node.ndim == 4:
            candidates.append(node)
        elif isinstance(node, (list, tuple)):
            for item in node:
                if item is not None:
                    collect(item)

    collect(outputs)
    if not candidates:
        raise RuntimeError(
            f"Model output for {img_path} contains no 4-D tensor "
            f"(got {type(outputs).__name__})"
        )
    # The tensor with the largest spatial area is taken as the final prediction
    # (first one wins on ties).
    pred = max(candidates, key=lambda t: t.shape[-1] * t.shape[-2])

    # Resize back to original, convert to 0-255 alpha
    pred = F.interpolate(pred, size=(orig_h, orig_w), mode="bilinear", align_corners=False)
    # Index batch 0 / channel 0 explicitly: squeeze() would also drop a
    # height or width of 1 and leave a non-2-D array for single-row images.
    alpha = (torch.sigmoid(pred)[0, 0].cpu().numpy() * 255).astype(np.uint8)
    return Image.fromarray(alpha, mode="L"), img


def apply_shadow_mode(orig: Image.Image, alpha: Image.Image, window_grey: int = 50) -> Image.Image:
    """
    Build an RGBA cut-out in which glass/window areas are painted as dark glass.

    Pixels are classified from the alpha mask alone:

      * Body   — alpha >= 230: original RGB, model alpha.
      * Holes  — alpha <= 15 but enclosed by the silhouette (found with
                 ``binary_fill_holes``): ``window_grey`` at a fixed alpha of 180.
      * Fringe — 15 < alpha < 230: ``window_grey`` with the model alpha kept,
                 so transitions stay soft.

    Everything else stays fully transparent black.

    NOTE(review): the fringe test is ANDed with the filled silhouette, but that
    silhouette already contains every pixel with alpha > 15, so soft pixels on
    the *outer* edge are painted grey too — confirm that this is intended.

    NOTE: glass that the model predicts as (near-)opaque counts as body and is
    not touched here; the original author's note names retraining with an
    unfrozen backbone as the remedy for those pixels.
    """
    from scipy import ndimage as ndi

    rgb = np.array(orig.convert("RGB"), dtype=np.uint8)
    mask = np.array(alpha, dtype=np.uint8)
    height, width = rgb.shape[0], rgb.shape[1]

    # Anything the model considers at least faintly visible.
    visible = mask > 15
    # Same outline with every enclosed gap closed.
    solid = ndi.binary_fill_holes(visible)

    enclosed_gaps = solid & ~visible
    soft_pixels = solid & visible & (mask < 230)
    glass = enclosed_gaps | soft_pixels
    opaque_body = visible & ~glass

    canvas = np.zeros((height, width, 4), dtype=np.uint8)

    # Body keeps its own colour and the model's alpha.
    canvas[opaque_body, :3] = rgb[opaque_body]
    canvas[opaque_body, 3] = mask[opaque_body]

    # Both glass classes share the tint and differ only in their alpha.
    canvas[glass, :3] = window_grey
    canvas[enclosed_gaps, 3] = 180
    canvas[soft_pixels, 3] = mask[soft_pixels]

    return Image.fromarray(canvas, "RGBA")


def composite_over_color(orig: Image.Image, alpha: Image.Image, bg_color=(255, 0, 0)) -> Image.Image:
    """
    Blend ``orig`` over a solid colour, using ``alpha`` as the blend mask.

    Args:
        orig: Foreground image; its size determines the output size.
        alpha: Mask used both as the foreground's alpha channel and as the
            paste mask.
        bg_color: RGB tuple for the backdrop (red by default).

    Returns:
        A new RGB image; neither input is modified.
    """
    # convert() already returns a fresh image, so ``orig`` stays untouched.
    cutout = orig.convert("RGBA")
    cutout.putalpha(alpha)

    opaque_fill = bg_color + (255,)
    backdrop = Image.new("RGBA", orig.size, opaque_fill)
    backdrop.paste(cutout, mask=alpha)

    flattened = backdrop.convert("RGB")
    return flattened


def _select_device() -> torch.device:
    """Pick the inference device: CUDA if present, else Apple MPS, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # torch.backends.mps does not exist on older torch builds.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def main():
    """
    Command-line entry point.

    Parses ``--checkpoint``, ``--images`` and ``--out``, runs the model on each
    image and writes four files per image into the output directory:
    ``<stem>_mask.png``, ``<stem>_shadow.png``, ``<stem>_red.jpg`` and
    ``<stem>_compare.jpg``. Exits with an error message when there are no
    input images.
    """
    parser = argparse.ArgumentParser(
        description="Visually test a fine-tuned BiRefNet checkpoint."
    )
    parser.add_argument("--checkpoint", required=True,
                        help="path to the fine-tuned state dict (.pt)")
    parser.add_argument("--images",     nargs="+", default=[],
                        help="images to process (default: dataset/im/*.jpg)")
    parser.add_argument("--out",        default="/tmp/birefnet-test",
                        help="directory the result files are written to")
    args = parser.parse_args()

    # Resolve the inputs first: loading the model is slow, so an empty image
    # set should fail before that work is done rather than "succeed" silently.
    images = [Path(p) for p in args.images]
    if not images:
        # Default: run on all dataset images
        images = sorted(Path("dataset/im").glob("*.jpg"))
    if not images:
        sys.exit("No input images: pass --images or put .jpg files in dataset/im/")

    device = _select_device()
    print(f"Device: {device}")

    out_dir = Path(args.out)
    out_dir.mkdir(parents=True, exist_ok=True)

    model = load_model(Path(args.checkpoint), device)

    for img_path in images:
        print(f"  {img_path.name} ...", end=" ", flush=True)
        alpha, orig = infer(model, img_path, device)

        # Save raw alpha mask
        alpha.save(out_dir / f"{img_path.stem}_mask.png")

        # Shadow-mode RGBA — window RGB replaced with dark grey (no old-background ghosting)
        shadow_rgba = apply_shadow_mode(orig, alpha, window_grey=50)
        shadow_rgba.save(out_dir / f"{img_path.stem}_shadow.png")

        # Raw mask (not the shadow cut-out) composited over red, so that any
        # transparency the model predicted shows up as red.
        comp = composite_over_color(orig, alpha)
        comp.save(out_dir / f"{img_path.stem}_red.jpg", quality=90)

        # Side-by-side: original | shadow composite over white
        white_bg = Image.new("RGBA", orig.size, (255, 255, 255, 255))
        white_bg.paste(shadow_rgba, mask=shadow_rgba.split()[3])
        side = Image.new("RGB", (orig.width * 2, orig.height))
        side.paste(orig, (0, 0))
        side.paste(white_bg.convert("RGB"), (orig.width, 0))
        side.save(out_dir / f"{img_path.stem}_compare.jpg", quality=85)

        print("done")

    print(f"\nResults saved to {out_dir}/")
    print("  *_mask.png   — raw alpha (white=car body, black=window/background)")
    print("  *_shadow.png — RGBA cut-out with windows replaced by dark tinted glass")
    print("  *_red.jpg    — car composited over red (windows should show red)")
    print("  *_compare.jpg— original | shadow cut-out over white, side-by-side")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
