#
# @File: extract_frames.py
# @Author: @osamabinishrat
# @Date: Thu Jun 26 2025
#
# Copyright (c) 2025 FlexDealer
#

import os
import sys
import cv2
import argparse
import json
import io
import numpy as np
from PIL import Image
from datetime import datetime
import tensorflow as tf
from keras.models import load_model
from tensorflow.keras.preprocessing import image as keras_image
from tensorflow.keras.applications.resnet50 import preprocess_input
import tensorflow_hub as hub
import absl.logging
import logging
import random
import string
import time
import requests
from PIL import ImageOps
from io import BytesIO

# === Silence unnecessary TensorFlow logs for cleaner output ===
# Ask TensorFlow's native (C++) logger to suppress its messages ('3' is presumably
# the most restrictive documented level: INFO/WARNING/ERROR hidden — confirm).
# NOTE(review): tensorflow, keras and tensorflow_hub are imported at the top of the
# file, *before* this variable is set. TF may read it at import time, so
# import-time native logs could still appear — consider setting it before those imports.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
# Only ERROR and above from absl and from the Python-level 'tensorflow' logger.
absl.logging.set_verbosity(absl.logging.ERROR)
logging.getLogger('tensorflow').setLevel(logging.ERROR)
# Rebind Python's sys.stderr to /dev/null for the rest of the process.
# This discards everything written through sys.stderr, including warnings,
# uncaught-exception tracebacks and argparse usage errors. It rebinds the Python
# object only: file descriptor 2 (written directly by native code) is unchanged.
# stdout is left untouched.
sys.stderr = open(os.devnull, 'w')

# === Constants and model paths ===
# Paths are resolved relative to this script, not the current working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
CLASS_MODEL_PATH = os.path.join(BASE_DIR, "../models/liftkit_custom_trained_model.keras")
CLASS_NAMES_FILE = os.path.join(BASE_DIR, "class_names.txt")

# Classifier used by find_best_shot_by_label to rank candidate frames.
# It is loaded once at import time, so importing this module is slow and
# fails if the model file is missing.
class_model = load_model(CLASS_MODEL_PATH)

# One class name per line. Line order must match class_model's output indices,
# because predictions are mapped back with class_names[i].
# utf-8-sig decodes plain UTF-8/ASCII unchanged but also strips a leading BOM,
# which would otherwise be glued onto the first class name and never match.
with open(CLASS_NAMES_FILE, encoding="utf-8-sig") as f:
    class_names = [line.strip() for line in f]

# === Import custom training utilities used in pose model ===
from training_tools import (
    np_get_angle_from_sin_cos,
    tf_mean_absolute_angle_error_sin_cos_output,
    tf_rmse_angle_sin_cos_output,
    tf_r2_angle_score_sin_cos_output,
    tf_median_absolute_angle_error_sin_cos_output,
    tf_acc_pi_6_sin_cos_output,
    horizontal_flip_pose_sin_cos_output,
)

# (width, height) that candidate frames are resized to before being scored by the
# classification model.
IMG_SIZE = (224, 224)

# Viewpoint label -> (min_deg, max_deg) range of predicted angles in degrees.
# The lower bound is inclusive and the upper bound exclusive. Angles outside
# every range get no label.
LABEL_RANGES = dict(
    eff=(0, 1),
    efd=(20, 30),
    esd=(70, 80),
    erd=(150, 160),
    erf=(180, 185),
    erp=(190, 210),
    esp=(260, 280),
    efp=(330, 340),
)
# Label names, in LABEL_RANGES insertion order.
LABELS = [*LABEL_RANGES]

def get_label(angle_deg):
    """
    Map a predicted angle (degrees) to its viewpoint label.

    Returns the first label in LABEL_RANGES whose half-open range
    [min_deg, max_deg) contains ``angle_deg``, or None if no range matches.
    """
    matches = (
        name
        for name, (low, high) in LABEL_RANGES.items()
        if low <= angle_deg < high
    )
    return next(matches, None)

def load_pose_model(model_path):
    """
    Load the pose-estimation model stored at ``model_path``.

    The saved model refers to a TF-Hub ``KerasLayer`` and to several custom
    functions imported from ``training_tools``. They are registered as
    ``custom_objects`` so Keras can resolve them by name while deserializing.
    """
    custom_objects = dict(
        KerasLayer=hub.KerasLayer,
        tf_mean_absolute_angle_error_sin_cos_output=tf_mean_absolute_angle_error_sin_cos_output,
        tf_rmse_angle_sin_cos_output=tf_rmse_angle_sin_cos_output,
        tf_r2_angle_score_sin_cos_output=tf_r2_angle_score_sin_cos_output,
        tf_median_absolute_angle_error_sin_cos_output=tf_median_absolute_angle_error_sin_cos_output,
        tf_acc_pi_6_sin_cos_output=tf_acc_pi_6_sin_cos_output,
        horizontal_flip_pose_sin_cos_output=horizontal_flip_pose_sin_cos_output,
    )
    return load_model(model_path, custom_objects=custom_objects)

def init(video_path, output_dir, model_path, frame_every_n=15, log=False, imgixServer="liftkit.imgix.net", dealer="", vehicle_id=""):
    """
    Extract frames from a video, classify their viewpoint with the pose model,
    and render the best frame per viewpoint label.

    Every ``frame_every_n``-th frame goes through the pose model. Frames whose
    predicted angle falls inside a LABEL_RANGES entry are saved as candidate JPEGs
    under ``output_dir/temp/``. find_best_shot_by_label then picks one candidate
    per label, and the picks are handed to remove_background_and_save as imgix
    background-removal URLs.

    Args:
        video_path: Video file readable by ``cv2.VideoCapture``.
        output_dir: Output directory (created if missing).
        model_path: Pose model file, loaded with load_pose_model.
        frame_every_n: Sampling stride; frames 0, N, 2N, ... are classified.
        log: If true, print per-frame errors to stdout.
        imgixServer: imgix host used in the generated URLs.
        dealer: First path segment of the imgix URL.
        vehicle_id: Second path segment of the imgix URL.

    If the video cannot be opened, a JSON error object is printed to stdout and
    the process exits with status 1.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Open the video before the (slow) model load so a bad path fails fast.
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        # sys.stderr is redirected to os.devnull at import time, so stdout is the
        # only channel through which the caller can learn why we failed.
        print(json.dumps({"status": "error", "message": f"Could not open video: {video_path}"}))
        sys.exit(1)

    temp_output_dir = os.path.join(output_dir, "temp/")
    os.makedirs(temp_output_dir, exist_ok=True)

    # label -> list of {"path": ..., "filename": ...} candidate frames
    saved_file_registry = {}
    try:
        model = load_pose_model(model_path)
        frame_id = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            if frame_id % frame_every_n == 0:
                try:
                    _save_candidate_frame(model, frame, temp_output_dir, saved_file_registry)
                except Exception as e:
                    # Best effort: one bad frame must not abort the whole video.
                    # NOTE(review): this goes to stdout, the same stream that carries
                    # the final JSON result.
                    if log:
                        print(f"⚠️ Error on frame {frame_id}: {e}")

            frame_id += 1
    finally:
        # Release the capture even if model loading or decoding raises.
        cap.release()

    # Pick the best candidate per label and point imgix at it for background removal.
    final_images = {}
    for label, image_list in saved_file_registry.items():
        best_img = find_best_shot_by_label(image_list, label)
        final_images[label] = (
            f"https://{imgixServer}/{dealer}/{vehicle_id}/temp/{best_img['filename']}"
            "?bg-remove=true&trim=auto&trim-md=0&bg-remove-fg-type=car"
        )

    remove_background_and_save(final_images, output_dir, temp_output_dir)


def _save_candidate_frame(model, frame, temp_output_dir, registry):
    """Classify one BGR video frame; if its angle maps to a label, save it and record it in ``registry``."""
    image_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    angle_deg = _predict_angle_deg(model, image_rgb)

    label = get_label(angle_deg)
    if not label:
        return

    filename = f"{angle_deg}_{label}_{generate_random_string(4)}.jpg"
    full_path = os.path.join(temp_output_dir, filename)
    Image.fromarray(image_rgb).save(full_path)
    registry.setdefault(label, []).append({"path": full_path, "filename": filename})


def _predict_angle_deg(model, image_rgb):
    """Run the pose model on an RGB frame and return its angle as whole degrees in [0, 360)."""
    # The pose model is fed a batch of one 224x224 float32 image with raw 0-255
    # pixel values (no preprocess_input step).
    resized = Image.fromarray(image_rgb).resize((224, 224))
    batch = tf.expand_dims(tf.convert_to_tensor(np.array(resized), dtype=tf.float32), axis=0)
    prediction = model.predict(batch, verbose=0)
    return _radians_to_label_degrees(np_get_angle_from_sin_cos(prediction)[0])


def _radians_to_label_degrees(angle_rad):
    """Convert an angle in radians to an int number of degrees wrapped into [0, 360)."""
    # Round first, then wrap. Wrapping to [-180, 180) and then rounding turned
    # angles in (-0.5, 0) degrees into 360, which matches no LABEL_RANGES entry.
    return round(float(angle_rad / np.pi * 180)) % 360


def remove_background_and_save(images, output_dir, temp_output_dir, timeout=120):
    """
    Download background-removed vehicle images and save them centered on black canvases.

    Each URL in ``images`` is fetched over HTTP, scaled, and centered on an opaque
    black canvas. Two JPEGs are written to ``output_dir``:
    ``{label}-{MMDDHH}-large.jpg`` (1600x1200) and ``{label}-{MMDDHH}-small.jpg``
    (460x345).

    After a fixed 5 s pause, the files directly inside ``temp_output_dir`` are
    deleted. Finally a JSON summary is printed to stdout:
    ``{"status": "success", "classes": [{"class": label, "filename": basename}, ...]}``.

    Args:
        images: Mapping of viewpoint label -> image URL.
        output_dir: Destination directory for the rendered JPEGs.
        temp_output_dir: Directory whose files are removed when processing ends.
        timeout: Seconds to wait for each HTTP request (previously unbounded).

    Raises:
        requests.RequestException: if a download fails or returns an HTTP error
            status. The temp files are still removed in that case.
    """
    # Image.Resampling was added in Pillow 9.1; Image.ANTIALIAS was removed in Pillow 10.
    try:
        resample = Image.Resampling.LANCZOS
    except AttributeError:
        resample = Image.ANTIALIAS

    # size label -> ((canvas width, canvas height), target vehicle height in px)
    layouts = {
        "large": ((1600, 1200), 640),
        "small": ((460, 345), 184),
    }
    # One timestamp per run so all labels from the same run share a suffix.
    timestamp = datetime.now().strftime("%m%d%H")

    label_map = []
    try:
        for label, url in images.items():
            response = requests.get(url, timeout=timeout)
            response.raise_for_status()
            vehicle_image = Image.open(BytesIO(response.content)).convert("RGBA")

            filebasename = f"{label}-{timestamp}"
            for size_label, (canvas_size, target_height) in layouts.items():
                final_image = _render_on_canvas(vehicle_image, canvas_size, target_height, resample)
                output_path = os.path.join(output_dir, f"{filebasename}-{size_label}.jpg")
                final_image.save(output_path, format="JPEG", quality=100)

            label_map.append({"class": label, "filename": filebasename})

        # NOTE(review): the purpose of this fixed pause before deleting the temp
        # frames is not visible here (presumably letting the remote side finish
        # with them) — confirm before removing.
        time.sleep(5)
    finally:
        # Remove temp frames even if a download fails so they do not accumulate.
        _remove_files_in(temp_output_dir)

    print(json.dumps({"status": "success", "classes": label_map}))


def _render_on_canvas(vehicle_image, canvas_size, target_height, resample):
    """Scale an RGBA image to ``target_height`` (never wider than the canvas) and center it on black; returns RGB."""
    canvas_width, canvas_height = canvas_size
    src_width, src_height = vehicle_image.size
    aspect_ratio = src_width / src_height

    new_height = target_height
    new_width = int(aspect_ratio * new_height)
    if new_width > canvas_width:
        # Very wide images would overflow the canvas and be cropped at both sides;
        # shrink them to the canvas width instead, keeping the aspect ratio.
        new_width = canvas_width
        new_height = int(canvas_width / aspect_ratio)
    new_width, new_height = max(new_width, 1), max(new_height, 1)

    resized = vehicle_image.resize((new_width, new_height), resample)
    canvas = Image.new("RGBA", (canvas_width, canvas_height), (0, 0, 0, 255))
    # Paste exactly once, using the image's own alpha as the mask. Pasting twice
    # composites semi-transparent edge pixels twice and hardens the cut-out edges.
    offset = ((canvas_width - new_width) // 2, (canvas_height - new_height) // 2)
    canvas.paste(resized, offset, resized)
    return canvas.convert("RGB")


def _remove_files_in(directory):
    """Delete the files (not subdirectories) directly inside ``directory``; a missing directory is ignored."""
    try:
        names = os.listdir(directory)
    except FileNotFoundError:
        return
    for name in names:
        path = os.path.join(directory, name)
        if not os.path.isdir(path):
            os.remove(path)
     
def find_best_shot_by_label(images, currentClass):
    """
    Select the best candidate image for a viewpoint label using the classification model.

    Each candidate is loaded as RGB, resized to IMG_SIZE, passed through
    ``preprocess_input`` and scored by ``class_model``. Candidates whose top-3
    predicted class names do not include ``currentClass`` are skipped. Among the
    rest, the one with the highest probability for ``currentClass`` wins; the
    first one wins on ties.

    Args:
        images: Non-empty list of dicts with at least a ``"path"`` key (as built by
            init). An empty list raises IndexError.
        currentClass: Label to score; must appear in ``class_names`` to be scorable.

    Returns:
        The winning dict from ``images``, or the middle entry if no candidate
        qualifies.
    """
    fallback = images[get_center_of_list(images)]
    if currentClass not in class_names:
        # The top-3 check below can never pass, so skip the model calls entirely.
        return fallback
    class_index = class_names.index(currentClass)

    best_score = -1.0
    best_img = fallback
    for image in images:
        # Context manager closes the file handle opened by Image.open.
        with Image.open(image["path"]) as src:
            img = src.convert("RGB").resize(IMG_SIZE)
        x = preprocess_input(np.expand_dims(keras_image.img_to_array(img), axis=0))

        prediction = class_model.predict(x, verbose=0)[0]
        top_indices = prediction.argsort()[-3:][::-1]  # top-3, highest first

        # Only consider images where the current class ranks in the top 3.
        if currentClass not in {class_names[i] for i in top_indices}:
            continue

        score = float(prediction[class_index])
        if score > best_score:
            best_score = score
            best_img = image

    return best_img

def get_center_of_list(arr):
    """
    Index of the middle element of ``arr``.

    For even lengths this is the upper-middle index; for an empty sequence it is 0.
    """
    middle, _remainder = divmod(len(arr), 2)
    return middle

def generate_random_string(length):
    """
    Return ``length`` characters drawn from ASCII letters and digits.

    Uses the ``random`` module, so the result is not suitable for security tokens.
    """
    alphabet = string.ascii_letters + string.digits
    picked = []
    for _ in range(length):
        picked.append(random.choice(alphabet))
    return "".join(picked)

def log(msg, path="/tmp/python_debug.log"):
    """
    Append ``msg`` to a debug log file, prefixed with the current local timestamp.

    Args:
        msg: Message to write; formatted with ``str()`` by the f-string.
        path: Log file to append to (created if missing). Defaults to
            ``/tmp/python_debug.log``.
    """
    # Explicit UTF-8 so non-ASCII messages (e.g. emoji) cannot raise
    # UnicodeEncodeError under a non-UTF-8 locale.
    with open(path, "a", encoding="utf-8") as f:
        f.write(f"{datetime.now()} - {msg}\n")

def unit_test(output_dir=None):
    """
    Manual smoke test: run remove_background_and_save on fixed imgix URLs.

    Requires network access to the imgix host used below. Results are written to
    ``output_dir``. When it is omitted, the module-level ``output_dir`` set by the
    CLI entry point is used.

    Raises:
        ValueError: if no output directory is available.
    """
    if output_dir is None:
        # Fall back to the global set under ``if __name__ == "__main__"``.
        output_dir = globals().get("output_dir")
    if output_dir is None:
        raise ValueError("unit_test() needs an output_dir: pass one or run via the CLI")

    images = {
        "eff": "https://liftkit-dev4.imgix.net/sd2001/1001282/temp/0_eff_f9ra.jpg?bg-remove=true&trim=auto&trim-md=0",
        "efd": "https://liftkit-dev4.imgix.net/sd2001/1001282/temp/29_efd_Bfpy.jpg?bg-remove=true&trim=auto&trim-md=0",
        "esd": "https://liftkit-dev4.imgix.net/sd2001/1001282/temp/78_esd_z6U4.jpg?bg-remove=true&trim=auto&trim-md=0",
        "erd": "https://liftkit-dev4.imgix.net/sd2001/1001282/temp/154_erd_75xL.jpg?bg-remove=true&trim=auto&trim-md=0",
        "erf": "https://liftkit-dev4.imgix.net/sd2001/1001282/temp/181_erf_gOUY.jpg?bg-remove=true&trim=auto&trim-md=0",
        "erp": "https://liftkit-dev4.imgix.net/sd2001/1001282/temp/200_erp_P9yj.jpg?bg-remove=true&trim=auto&trim-md=0",
        "esp": "https://liftkit-dev4.imgix.net/sd2001/1001282/temp/263_esp_wJWJ.jpg?bg-remove=true&trim=auto&trim-md=0",
        "efp": "https://liftkit-dev4.imgix.net/sd2001/1001282/temp/339_efp_Y2xz.jpg?bg-remove=true&trim=auto&trim-md=0",
    }
    temp_output_dir = os.path.join(output_dir, "temp/")
    # Create both directories: remove_background_and_save writes JPEGs into
    # output_dir and lists temp_output_dir when cleaning up.
    os.makedirs(temp_output_dir, exist_ok=True)
    remove_background_and_save(images, output_dir, temp_output_dir)

if __name__ == "__main__":
    # === CLI Argument Handling ===

    def _positive_int(value):
        """argparse ``type``: parse ``value`` as an integer >= 1."""
        number = int(value)
        if number < 1:
            raise argparse.ArgumentTypeError(f"must be >= 1, got {value!r}")
        return number

    parser = argparse.ArgumentParser(description="Extract 1 best frame per view from car video using pose estimation")
    parser.add_argument("video_path", help="Path to the input video (e.g. `video.mp4`)")
    parser.add_argument("--output", required=True, help="Path to output directory (flat)")
    # Validated as >= 1 because init() uses it as a modulus (frame_id % frame_every_n);
    # 0 would raise ZeroDivisionError on the first frame.
    parser.add_argument("--frame_every_n", type=_positive_int, default=15, help="Process every Nth frame (default: 15)")
    parser.add_argument("--log", action="store_true", help="Verbose logging")
    parser.add_argument("--server", help="Required for imgix background removal", required=True)
    parser.add_argument("--dealer", help="Required for imgix background removal", required=True)
    parser.add_argument("--vehicle_id", help="Required for imgix background removal", required=True)

    # NOTE: argparse reports errors on sys.stderr, which this module redirects to
    # os.devnull at import time, so callers only see the exit status (2).
    args = parser.parse_args()
    # Module-level on purpose: unit_test() reads the global ``output_dir``.
    output_dir = os.path.expanduser(args.output)
    # Resolved relative to this script via the module-level BASE_DIR.
    model_path = os.path.join(BASE_DIR, "../models/efficientnetb0_approach2.h5")

    # Swap in unit_test() to exercise only the download/render step on fixed URLs.
    # unit_test()
    init(args.video_path, output_dir, model_path, args.frame_every_n, args.log, args.server, args.dealer, args.vehicle_id)
