#!/usr/bin/env python3
#
# @File: finetune-birefnet.py
# @Date: 2026-06-01
#
# Fine-tune BiRefNet to correctly handle car window / windshield transparency.
#
# Fine-tuning strategy:
#   - Freeze the bottom N stages of the Swin Transformer backbone (--freeze-stages;
#     default 4 = whole backbone frozen, 2 = also fine-tune the top two stages)
#   - Always train the full decoder (learns glass-specific semantics)
#   - Batch size 1, gradient clipping — compatible with Mac M-series MPS
#   - Loss: L1 on sigmoid output vs soft alpha target (better for semi-transparent windows)
#   - Augmentation: horizontal flip + random crop + colour jitter (prevents overfitting on small dataset)
#
# Install:
#   pip install torch torchvision transformers Pillow numpy
#
# Usage:
#   python3 tools/internal/finetune-birefnet.py --dataset dataset/
#   python3 tools/internal/finetune-birefnet.py --dataset dataset/ --epochs 50
#   python3 tools/internal/finetune-birefnet.py --dataset dataset/ --size 1024  ← higher resolution than the 512 default; needs much more memory

import os
import sys
import time
import argparse
import random
from pathlib import Path

# Force line-buffered stdout so training progress appears immediately when piped.
sys.stdout.reconfigure(line_buffering=True)

# BiRefNet's decoder uses deformable convolution, which is not implemented on MPS.
# PYTORCH_ENABLE_MPS_FALLBACK=1 lets PyTorch run such unsupported ops on the CPU
# instead of raising, while the rest of the forward/backward graph stays on MPS.
# It is set here, before `import torch` below, presumably so it is already in
# place when PyTorch initialises; setdefault() keeps any value the user exported.
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image
import numpy as np

# Hugging Face Hub repo id of the pretrained BiRefNet checkpoint that train() loads
# (with trust_remote_code=True) and fine-tunes.
MODEL_ID = "ZhengPeng7/BiRefNet-massive"


# ──────────────────────────────────────────────────────
# Dataset
# ──────────────────────────────────────────────────────
class VehicleDataset(Dataset):
    """
    Paired (RGB JPEG image, greyscale alpha-mask PNG) samples for fine-tuning.

    Files are matched by stem: ``<dataset_dir>/im/<stem>.jpg`` pairs with
    ``<dataset_dir>/gt/<stem>.png``. Images without a mask are skipped, and the
    remaining stems are kept in sorted order, so two instances over the same
    directory index the same files identically.

    Masks are used as *soft* alpha and are never binarised. Pixel value 255 maps
    to 1.0 (opaque car body), 0 maps to 0.0 (background), and intermediate values
    encode partially transparent glass. NOTE(review): the exact window values
    depend on how prepare-training-data.py writes the masks; confirm there.

    Each item is ``(image, mask)``:
      - image: float tensor (3, img_size, img_size), ImageNet-normalised
      - mask:  float tensor (1, img_size, img_size) with values in [0, 1]

    Raises:
        FileNotFoundError: if no image+mask pair is found.
    """

    def __init__(self, dataset_dir: Path, img_size: int = 1024, augment: bool = True):
        dataset_dir   = Path(dataset_dir)  # accept plain strings as well as Paths
        self.im_dir   = dataset_dir / "im"
        self.gt_dir   = dataset_dir / "gt"
        self.img_size = img_size
        self.augment  = augment

        stems = sorted(p.stem for p in self.im_dir.glob("*.jpg"))
        self.items = [s for s in stems if (self.gt_dir / f"{s}.png").exists()]

        if not self.items:
            raise FileNotFoundError(
                f"No image+mask pairs found in {dataset_dir}. "
                "Run prepare-training-data.py first."
            )

        # The transforms hold only configuration (randomness is sampled per call),
        # so build them once here rather than on every __getitem__.
        self._jitter = transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.15)
        self._img_pipeline = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        self._mask_pipeline = transforms.Compose([
            # Bilinear resampling preserves fractional alpha instead of snapping it.
            transforms.Resize(
                (img_size, img_size),
                interpolation=transforms.InterpolationMode.BILINEAR,
            ),
            transforms.ToTensor(),  # (1, H, W) float 0–1, no binarisation
        ])
        print(f"  Dataset: {len(self.items)} pairs  (img_size={img_size})")

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx: int):
        stem = self.items[idx]
        # convert() returns a fully loaded copy, so the source files can be closed now.
        with Image.open(self.im_dir / f"{stem}.jpg") as src:
            img = src.convert("RGB")
        with Image.open(self.gt_dir / f"{stem}.png") as src:
            mask = src.convert("L")

        if self.augment:
            img, mask = self._augment(img, mask)

        return self._img_transform(img), self._mask_transform(mask)

    def _augment(self, img: Image.Image, mask: Image.Image):
        """Apply the same random flip/crop to image and mask; colour-jitter the image only."""
        # Horizontal flip
        if random.random() > 0.5:
            img  = img.transpose(Image.FLIP_LEFT_RIGHT)
            mask = mask.transpose(Image.FLIP_LEFT_RIGHT)

        # Random crop (90–100% scale), identical window for image and mask
        scale = random.uniform(0.90, 1.00)
        w, h  = img.size
        cw, ch = max(1, int(w * scale)), max(1, int(h * scale))
        x0 = random.randint(0, w - cw)
        y0 = random.randint(0, h - ch)
        img  = img.crop((x0, y0, x0 + cw, y0 + ch))
        mask = mask.crop((x0, y0, x0 + cw, y0 + ch))

        # Colour jitter (image only — the mask's alpha values must not change)
        img = self._jitter(img)
        return img, mask

    def _img_transform(self, img: Image.Image) -> torch.Tensor:
        """Resize to img_size x img_size, convert to tensor, ImageNet-normalise."""
        return self._img_pipeline(img)

    def _mask_transform(self, mask: Image.Image) -> torch.Tensor:
        """Resize to img_size x img_size and scale to [0, 1], keeping soft alpha values.

        The targets stay fractional on purpose. The training loss is L1 on the
        sigmoid output, which can regress directly to fractional window alpha
        instead of treating glass as hard foreground.
        """
        return self._mask_pipeline(mask)


# ──────────────────────────────────────────────────────
# Loss
# ──────────────────────────────────────────────────────
def training_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    Mean absolute error between sigmoid(pred) and the soft alpha target.

    Args:
        pred:   raw logits from the model; broadcast against ``target``.
        target: ground-truth alpha, expected in [0, 1] (fractional for glass).

    Returns:
        A 0-D tensor: the mean of |sigmoid(pred) - target| over all elements.

    A single uniform L1 term is used deliberately. Up-weighting window pixels was
    found to destabilise training on small datasets, causing collapse after a few
    epochs.
    """
    alpha = pred.sigmoid()
    residual = alpha - target
    return residual.abs().mean()


# ──────────────────────────────────────────────────────
# Model helpers
# ──────────────────────────────────────────────────────
def freeze_backbone_bottom(model: nn.Module, freeze_stages: int = 2) -> None:
    """
    Freeze the backbone (``model.bb``) except for its top ``bb.layers`` stages.

    First, every backbone parameter is frozen, including parameters outside
    ``bb.layers`` such as the patch embedding. Then the stages in ``bb.layers``
    with index >= ``freeze_stages`` are unfrozen. Parameters outside ``model.bb``
    (e.g. the decoder) are never touched.

    Rationale: the lower Swin stages hold low-level edge/texture features, and the
    upper stages hold the more semantic features that may need to adapt to glass
    vs. body. ``freeze_stages=4`` on a 4-stage backbone keeps it fully frozen;
    ``freeze_stages=2`` leaves the top two stages trainable.

    Args:
        model: model exposing its backbone as ``model.bb``. If that attribute is
            missing, a warning is printed and nothing is changed.
        freeze_stages: number of bottom stages to keep frozen. Values <= 0 unfreeze
            every stage; values >= the stage count keep the whole backbone frozen.
    """
    backbone = getattr(model, "bb", None)
    if backbone is None:
        print("  WARNING: model.bb not found — all parameters trainable (may not converge)")
        return

    # Always freeze patch embedding + early layers
    for param in backbone.parameters():
        param.requires_grad = False

    layers = getattr(backbone, "layers", None)
    if layers is None:
        # Without per-stage modules nothing can be selectively unfrozen.
        print("  WARNING: model.bb.layers not found — entire backbone left frozen")
        layers = []

    # Unfreeze the top (freeze_stages .. end) stages so semantic layers adapt
    for stage_idx, stage in enumerate(layers):
        if stage_idx >= freeze_stages:
            for param in stage.parameters():
                param.requires_grad = True

    # Clamp for reporting so out-of-range freeze_stages never prints negative counts.
    n_stages    = len(layers)
    n_frozen    = min(max(freeze_stages, 0), n_stages)
    n_trainable = n_stages - n_frozen
    if n_frozen:
        print(f"  Backbone stages 0-{n_frozen - 1} frozen, {n_trainable} top stages trainable")
    else:
        print(f"  Backbone stages: none frozen, {n_trainable} top stages trainable")

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6
    total     = sum(p.numel() for p in model.parameters()) / 1e6
    frozen    = total - trainable
    print(f"  Trainable: {trainable:.1f}M / {total:.1f}M params  (frozen: {frozen:.1f}M)")


def get_model_prediction(outputs) -> torch.Tensor:
    """
    Pick the highest-resolution prediction map out of a BiRefNet forward result.

    The model may return a bare tensor or arbitrarily nested lists/tuples of
    tensors; the structure differs between eval and train mode. All 4-D tensors
    are gathered depth-first, left to right, skipping ``None`` entries and tensors
    of any other rank. The one with the largest H*W is returned; on ties, the
    earliest wins.

    Raises:
        RuntimeError: if ``outputs`` contains no 4-D tensor at all.
    """
    def walk(node):
        if isinstance(node, torch.Tensor):
            if node.ndim == 4:
                yield node
            return
        if isinstance(node, (list, tuple)):
            for child in node:
                if child is not None:
                    yield from walk(child)

    candidates = list(walk(outputs))
    if not candidates:
        raise RuntimeError(f"No 4-D tensor found in model outputs: {type(outputs)}")
    return max(candidates, key=lambda t: t.shape[-2] * t.shape[-1])


def compute_metrics(pred_logits: torch.Tensor, target: torch.Tensor, threshold: float = 0.5) -> dict:
    """
    Compute binary foreground IoU and pixel accuracy for a batch.

    ``sigmoid(pred_logits)`` and the (soft) ``target`` are both binarised with a
    strict ``> threshold``. Counts are pooled over the whole batch rather than
    averaged per image.

    Args:
        pred_logits: raw model logits, same shape as ``target``.
        target:      ground-truth alpha in [0, 1].
        threshold:   foreground cut-off applied to both prediction and target.

    Returns:
        ``{"iou": float, "acc": float}``. IoU is 1.0 when both prediction and
        target contain no foreground, since that prediction is perfect. The
        previous epsilon-only formula scored it 0.
    """
    pred_fg = torch.sigmoid(pred_logits) > threshold
    true_fg = target > threshold
    inter = (pred_fg & true_fg).sum().item()
    union = (pred_fg | true_fg).sum().item()
    iou   = inter / union if union > 0 else 1.0
    acc   = (pred_fg == true_fg).float().mean().item()
    return {"iou": iou, "acc": acc}


# ──────────────────────────────────────────────────────
# Training
# ──────────────────────────────────────────────────────
def _select_device(requested) -> torch.device:
    """Return ``torch.device(requested)`` when given, else the first available of MPS, CUDA, CPU."""
    if requested:
        return torch.device(requested)
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def _build_loaders(dataset_dir: Path, img_size: int):
    """
    Build the train/val DataLoaders (batch size 1) from a fixed-seed random split.

    About 20% of the pairs (at least 2) are held out for validation. The
    validation subset reads from a second, non-augmented dataset instance.

    Returns:
        ``(train_loader, val_loader, train_n, val_n)``

    Raises:
        ValueError: if too few pairs remain to train on after the validation hold-out.
    """
    dataset = VehicleDataset(dataset_dir, img_size=img_size, augment=True)

    val_n   = max(2, int(len(dataset) * 0.20))
    train_n = len(dataset) - val_n
    if train_n < 1:
        # Without this guard the split yields an empty (or nonsensical) train set,
        # which only fails later with a division by zero mid-epoch.
        raise ValueError(
            f"Need at least {val_n + 1} image+mask pairs, found {len(dataset)} "
            f"({val_n} are held out for validation)."
        )

    train_set, val_split = random_split(
        dataset, [train_n, val_n],
        generator=torch.Generator().manual_seed(42),
    )
    # The second instance lists the same files in the same sorted order, so the
    # split's indices are valid for it; it just disables augmentation.
    val_dataset = VehicleDataset(dataset_dir, img_size=img_size, augment=False)
    val_set = torch.utils.data.Subset(val_dataset, val_split.indices)

    train_loader = DataLoader(train_set, batch_size=1, shuffle=True,  num_workers=0)
    val_loader   = DataLoader(val_set,   batch_size=1, shuffle=False, num_workers=0)
    return train_loader, val_loader, train_n, val_n


def _read_resume_state(log_path: Path, resuming: bool):
    """
    Recover ``(start_epoch, best_val_loss)`` from an existing training CSV log.

    Returns ``(1, inf)`` for a fresh run, or when the log is missing or has no rows.
    """
    if not (resuming and log_path.exists()):
        return 1, float("inf")

    import csv
    with open(log_path, newline="") as f:
        rows = list(csv.DictReader(f))
    if not rows:
        return 1, float("inf")

    start_epoch   = int(rows[-1]["epoch"]) + 1
    best_val_loss = min(float(r["val_loss"]) for r in rows)
    print(f"  Resuming from epoch {start_epoch}  (best val so far: {best_val_loss:.4f})\n")
    return start_epoch, best_val_loss


def _forward_logits(model: nn.Module, imgs: torch.Tensor, size) -> torch.Tensor:
    """Run ``model`` and return its highest-resolution logits, bilinearly resized to ``size`` (H, W) if needed."""
    preds = get_model_prediction(model(imgs))
    if preds.shape[-2:] != size:
        preds = F.interpolate(preds, size=size, mode="bilinear", align_corners=False)
    return preds


def _train_one_epoch(model, loader, optimizer, params, device) -> float:
    """Run one optimisation pass over ``loader`` and return the mean training loss."""
    # Train in eval mode: BiRefNet's train-mode output is a nested structure
    # of auxiliary heads; eval mode returns the clean final prediction.
    # Gradients still flow to trainable params since we don't use no_grad().
    # Any BatchNorm layers use their running statistics in eval mode, which
    # avoids per-batch statistics computed from a single sample.
    model.eval()
    total = 0.0
    for imgs, masks in loader:
        imgs, masks = imgs.to(device), masks.to(device)
        optimizer.zero_grad()
        loss = training_loss(_forward_logits(model, imgs, masks.shape[-2:]), masks)
        loss.backward()
        # Frozen parameters never receive gradients, so clipping over the
        # trainable subset gives the same norm as over model.parameters().
        torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)
        optimizer.step()
        total += loss.item()
    return total / max(len(loader), 1)


@torch.no_grad()
def _validate(model, loader, device):
    """Return ``(mean loss, mean IoU, mean accuracy)`` over ``loader`` without tracking gradients."""
    model.eval()
    loss_sum = iou_sum = acc_sum = 0.0
    for imgs, masks in loader:
        imgs, masks = imgs.to(device), masks.to(device)
        preds = _forward_logits(model, imgs, masks.shape[-2:])
        loss_sum += training_loss(preds, masks).item()
        metrics = compute_metrics(preds, masks)
        iou_sum += metrics["iou"]
        acc_sum += metrics["acc"]
    n = max(len(loader), 1)
    return loss_sum / n, iou_sum / n, acc_sum / n


def train(args):
    """
    Fine-tune BiRefNet on the vehicle dataset as configured by the parsed CLI ``args``.

    Reads ``args.device``, ``out``, ``resume``, ``freeze_stages``, ``dataset``,
    ``size``, ``lr`` and ``epochs``. Appends one row per epoch to
    ``<out>/training_log.csv``. Saves plain ``state_dict`` checkpoints to
    ``<out>``: the latest, the best by validation loss, and every 10th epoch.
    """
    import warnings

    device = _select_device(args.device)
    print(f"Device : {device}\n")

    out_dir = Path(args.out)
    out_dir.mkdir(parents=True, exist_ok=True)

    # ── Model ──
    print(f"Loading {MODEL_ID} ...")
    from transformers import AutoModelForImageSegmentation
    model = AutoModelForImageSegmentation.from_pretrained(MODEL_ID, trust_remote_code=True)

    if args.resume:
        resume_path = Path(args.resume)
        print(f"Resuming from checkpoint: {resume_path} ...")
        # NOTE: checkpoints hold model weights only. AdamW moment estimates are
        # not restored and restart from zero on resume.
        state = torch.load(resume_path, map_location="cpu")
        model.load_state_dict(state)

    model.to(device)

    print("Configuring trainable layers ...")
    freeze_backbone_bottom(model, freeze_stages=args.freeze_stages)

    # ── Dataset ──
    print("\nLoading dataset ...")
    train_loader, val_loader, train_n, val_n = _build_loaders(Path(args.dataset), args.size)
    print(f"  Train: {train_n}  Val: {val_n}\n")

    # ── Optimizer ──
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=args.lr, weight_decay=1e-4)

    # ── Resume: restore scheduler position and best loss from the existing log ──
    log_path = out_dir / "training_log.csv"
    start_epoch, best_val_loss = _read_resume_state(log_path, bool(args.resume))
    total_epochs = start_epoch - 1 + args.epochs

    # Cosine schedule spans the whole run (already-completed + new epochs),
    # so the LR at the resume point sits on the same curve.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=total_epochs, eta_min=1e-6
    )
    # Advance to where we left off (no-op for fresh runs). Stepping before any
    # optimizer.step() triggers PyTorch's one-off "lr_scheduler.step() before
    # optimizer.step()" warning, which is expected here, so silence it.
    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore",
            message=r"Detected call of `lr_scheduler\.step\(\)` before",
            category=UserWarning,
        )
        for _ in range(start_epoch - 1):
            scheduler.step()

    # ── Training loop ──
    print(f"Training epochs {start_epoch}–{total_epochs} (resolution={args.size}x{args.size}, lr={args.lr})\n")
    print(f"{'Epoch':>6}  {'Train Loss':>10}  {'Val Loss':>8}  {'Val IoU':>8}  {'Val Acc':>8}  {'Time':>6}")
    print("-" * 60)

    log_mode = "a" if args.resume and log_path.exists() else "w"
    with open(log_path, log_mode) as log:
        if log_mode == "w":
            log.write("epoch,train_loss,val_loss,val_iou,val_acc\n")

        for epoch in range(start_epoch, total_epochs + 1):
            t0 = time.perf_counter()
            train_loss = _train_one_epoch(model, train_loader, optimizer, trainable_params, device)
            val_loss, val_iou, val_acc = _validate(model, val_loader, device)
            elapsed = time.perf_counter() - t0

            improved = val_loss < best_val_loss
            marker = " ✓" if improved else ""
            print(f"{epoch:6d}  {train_loss:10.4f}  {val_loss:8.4f}  {val_iou:8.3f}  {val_acc:8.3f}  {elapsed:5.0f}s{marker}")
            log.write(f"{epoch},{train_loss:.6f},{val_loss:.6f},{val_iou:.4f},{val_acc:.4f}\n")
            log.flush()

            scheduler.step()

            # Always save latest so it can be tested immediately
            torch.save(model.state_dict(), out_dir / "birefnet-vehicle-latest.pt")

            # Save best checkpoint
            if improved:
                best_val_loss = val_loss
                torch.save(model.state_dict(), out_dir / "birefnet-vehicle-best.pt")

            # Periodic checkpoint every 10 epochs
            if epoch % 10 == 0:
                torch.save(model.state_dict(), out_dir / f"birefnet-vehicle-ep{epoch:03d}.pt")

    print("-" * 60)
    print(f"\nDone. Best val loss: {best_val_loss:.4f}")
    print(f"Best model  → {out_dir}/birefnet-vehicle-best.pt")
    print(f"Training log→ {log_path}")
    print("\nTo test the fine-tuned model:")
    print(f"  python3 tools/internal/test-finetuned.py --checkpoint {out_dir}/birefnet-vehicle-best.pt")


# ──────────────────────────────────────────────────────
# Main
# ──────────────────────────────────────────────────────
def main():
    """Parse the command line, echo the run configuration, then hand off to train()."""
    parser = argparse.ArgumentParser(description="Fine-tune BiRefNet for car window transparency")
    add = parser.add_argument
    add("--dataset", required=True, help="Dataset dir with im/ and gt/ folders")
    add("--epochs", type=int, default=60, help="Training epochs (default: 60)")
    add("--lr", type=float, default=1e-5, help="Peak learning rate (default: 1e-5)")
    add("--size", type=int, default=512, help="Input resolution (default: 512)")
    add("--out", default="models/birefnet-vehicle-v3", help="Output directory for checkpoints")
    add("--device", default=None, help="Force device: cpu, mps, cuda (default: auto)")
    add("--resume", default=None, help="Path to checkpoint to resume from (also reads existing CSV log)")
    add("--freeze-stages", type=int, default=4, help="Freeze bottom N backbone stages (default: 4 = full backbone frozen)")
    args = parser.parse_args()

    rule = "=" * 60
    print(rule)
    print("  BiRefNet Fine-Tuning — Vehicle Window Transparency")
    print(rule)

    # Labels are left-aligned to a 14-character column so the colons line up.
    summary = (
        ("Dataset", args.dataset),
        ("Epochs", args.epochs),
        ("LR", args.lr),
        ("Size", f"{args.size}x{args.size}"),
        ("Freeze stages", args.freeze_stages),
        ("Resume", args.resume or "no"),
        ("Output", args.out),
    )
    for label, value in summary:
        print(f"{label:<14}: {value}")
    print()

    train(args)


# Start training only when run as a script, not when the module is imported.
if __name__ == "__main__":
    main()
