#
# @File: training_tools.py
# @Author: @osamabinishrat
# @Date: Thu Jun 26 2025
#
# Copyright (c) 2025 FlexDealer
#

import numpy as np
import tensorflow as tf

def np_get_angle_from_sin_cos(y_sincos: np.ndarray) -> np.ndarray:
    """Recover angles in radians from rows of (sin, cos) pairs.

    Args:
        y_sincos: 2-D array whose column 0 holds sines and column 1 holds cosines.

    Returns:
        1-D array of angles in (-pi, pi], as produced by ``np.arctan2``.
    """
    sines = y_sincos[:, 0]
    cosines = y_sincos[:, 1]
    return np.arctan2(sines, cosines)

def tf_align_pred_angle(y_true_angle: tf.Tensor, y_pred_angle: tf.Tensor) -> tf.Tensor:
    """Shift predicted angles by +/-2*pi so they lie within pi of the true angles.

    Only a single wrap is applied, which is sufficient when the true and
    predicted angles both come from ``atan2`` (so their difference lies in
    [-2*pi, 2*pi]).

    Args:
        y_true_angle: Reference angles in radians.
        y_pred_angle: Predicted angles in radians, same shape as ``y_true_angle``.

    Returns:
        ``y_pred_angle`` with 2*pi added where ``true - pred > pi`` and
        subtracted where ``true - pred < -pi``, in the dtype of ``y_pred_angle``.
    """
    y_pred_angle = tf.convert_to_tensor(y_pred_angle)
    # Cast the masks to the prediction's own dtype rather than a hard-coded
    # float32, so float16 (mixed precision) and float64 inputs also work.
    dtype = y_pred_angle.dtype
    diff = y_true_angle - y_pred_angle
    # Both masks use the *unshifted* difference; they are mutually exclusive,
    # so the shift is -1, 0 or +1 turns of 2*pi.
    shift = tf.cast(diff > np.pi, dtype) - tf.cast(diff < -np.pi, dtype)
    return y_pred_angle + shift * (2 * np.pi)

def tf_mean_absolute_angle_error(y_true_angle: tf.Tensor, y_pred_angle: tf.Tensor) -> tf.Tensor:
    """Mean absolute angular error in degrees, after 2*pi wrap alignment.

    Args:
        y_true_angle: Reference angles in radians.
        y_pred_angle: Predicted angles in radians.

    Returns:
        Scalar tensor: mean of ``|true - aligned_pred|`` converted to degrees.
    """
    aligned_pred = tf_align_pred_angle(y_true_angle, y_pred_angle)
    abs_err_rad = tf.abs(y_true_angle - aligned_pred)
    mean_err_rad = tf.reduce_mean(abs_err_rad)
    return mean_err_rad / np.pi * 180

def tf_mean_absolute_angle_error_sin_cos_output(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
    """Mean absolute angular error in degrees for (sin, cos) encoded targets.

    Args:
        y_true: Tensor whose column 0 is sin and column 1 is cos of the true angle.
        y_pred: Tensor with the same (sin, cos) column layout for predictions.

    Returns:
        Scalar tensor with the mean absolute angle error in degrees.
    """
    true_sin, true_cos = y_true[:, 0], y_true[:, 1]
    pred_sin, pred_cos = y_pred[:, 0], y_pred[:, 1]
    return tf_mean_absolute_angle_error(
        tf.atan2(true_sin, true_cos),
        tf.atan2(pred_sin, pred_cos),
    )

def tf_rmse_angle_sin_cos_output(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
    """Root-mean-square angular error in degrees for (sin, cos) encoded targets.

    Args:
        y_true: Tensor whose column 0 is sin and column 1 is cos of the true angle.
        y_pred: Tensor with the same (sin, cos) column layout for predictions.

    Returns:
        Scalar tensor: sqrt(mean((aligned_pred - true)^2)) converted to degrees.
    """
    true_angle = tf.atan2(y_true[:, 0], y_true[:, 1])
    pred_angle = tf_align_pred_angle(true_angle, tf.atan2(y_pred[:, 0], y_pred[:, 1]))
    mse_rad = tf.reduce_mean(tf.square(pred_angle - true_angle))
    return tf.sqrt(mse_rad) * 180 / np.pi

def tf_r2_angle_score_sin_cos_output(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
    """Coefficient of determination (R^2) on angles decoded from (sin, cos) pairs.

    Predictions are wrap-aligned to the true angles before the residuals are
    computed. The baseline uses the plain arithmetic mean of the true angles
    in radians (not a circular mean), so the score is only meaningful when the
    true angles do not straddle the +/-pi seam.

    NOTE(review): if every true angle in the batch is identical, the total sum
    of squares is zero and the result is -inf or NaN.

    Args:
        y_true: Tensor whose column 0 is sin and column 1 is cos of the true angle.
        y_pred: Tensor with the same (sin, cos) column layout for predictions.

    Returns:
        Scalar tensor ``1 - SS_res / SS_tot``.
    """
    true_angle = tf.atan2(y_true[:, 0], y_true[:, 1])
    raw_pred_angle = tf.atan2(y_pred[:, 0], y_pred[:, 1])
    pred_angle = tf_align_pred_angle(true_angle, raw_pred_angle)
    ss_res = tf.reduce_sum(tf.square(true_angle - pred_angle))
    ss_tot = tf.reduce_sum(tf.square(true_angle - tf.reduce_mean(true_angle)))
    return 1 - ss_res / ss_tot

def tf_median_absolute_angle_error_sin_cos_output(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
    """Median absolute angular error in degrees for (sin, cos) encoded targets.

    The median is computed with core TF ops (``tf.sort`` + ``tf.gather``)
    instead of ``tf.experimental.numpy.percentile``, which the TF-NumPy API
    does not provide. For an even count it averages the two middle values,
    matching NumPy's default (linear-interpolation) 50th percentile.

    Args:
        y_true: Tensor whose column 0 is sin and column 1 is cos of the true angle.
        y_pred: Tensor with the same (sin, cos) column layout for predictions.

    Returns:
        Scalar tensor with the median absolute angle error in degrees.
    """
    y_true_angle = tf.atan2(y_true[:, 0], y_true[:, 1])
    y_pred_angle = tf.atan2(y_pred[:, 0], y_pred[:, 1])
    y_pred_angle = tf_align_pred_angle(y_true_angle, y_pred_angle)
    sorted_err = tf.sort(tf.abs(y_true_angle - y_pred_angle))
    count = tf.shape(sorted_err)[0]
    # For odd counts both indices point at the same middle element.
    lower = tf.gather(sorted_err, (count - 1) // 2)
    upper = tf.gather(sorted_err, count // 2)
    median_rad = (lower + upper) / 2
    return median_rad / np.pi * 180

def tf_acc_pi_6_sin_cos_output(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
    """Fraction of samples whose angular error is strictly below pi/6 (30 degrees).

    Args:
        y_true: Tensor whose column 0 is sin and column 1 is cos of the true angle.
        y_pred: Tensor with the same (sin, cos) column layout for predictions.

    Returns:
        Scalar float32 tensor in [0, 1].
    """
    true_angle = tf.atan2(y_true[:, 0], y_true[:, 1])
    pred_angle = tf_align_pred_angle(true_angle, tf.atan2(y_pred[:, 0], y_pred[:, 1]))
    within_tolerance = tf.abs(true_angle - pred_angle) < (np.pi / 6)
    return tf.reduce_mean(tf.cast(within_tolerance, tf.float32))

def horizontal_flip_pose_sin_cos_output(pose):
    """Mirror a (sin, cos) encoded angle horizontally.

    Negating the angle flips the sign of its sine and leaves its cosine
    unchanged. Only the first two entries of ``pose`` are read.

    Args:
        pose: Indexable with ``pose[0]`` = sin and ``pose[1]`` = cos.

    Returns:
        A two-element list ``[-sin, cos]``.
    """
    sin_component = pose[0]
    cos_component = pose[1]
    return [-sin_component, cos_component]
