"""
File: app_utils.py
Author: Elena Ryumina and Dmitry Ryumin (modified by Assistant)
Description: This module contains utility functions for a facial expression recognition application, including FACS analysis for SAD.
License: MIT License
"""

import torch
import numpy as np
import mediapipe as mp
from PIL import Image
import cv2
from pytorch_grad_cam.utils.image import show_cam_on_image
import matplotlib.pyplot as plt

# Importing necessary components for the Gradio app
from app.model import pth_model_static, pth_model_dynamic, cam, pth_processing
from app.face_utils import get_box, display_info
from app.config import DICT_EMO, config_data
from app.plot import statistics_plot

mp_face_mesh = mp.solutions.face_mesh

def preprocess_image_and_predict(inp):
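    """Detect a face in a still image and predict its emotion.

    Returns a (face crop, Grad-CAM heatmap, per-emotion confidence dict)
    tuple, or (None, None, None) if no face is found.
    """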
    # Check for None before converting: np.array(None) yields a 0-d array,
    # which would slip past an "is None" check done afterwards.
    if inp is None:
        return None, None, None

    inp = np.array(inp)

    try:
        h, w = inp.shape[:2]
    except Exception:
        return None, None, None

    # Default outputs so the return below doesn't raise a NameError
    # when no face is detected.
    cur_face, heatmap, confidences = None, None, None

    with mp_face_mesh.FaceMesh(
        max_num_faces=1,
        refine_landmarks=False,
        min_detection_confidence=0.5,
        min_tracking_confidence=0.5,
    ) as face_mesh:
        results = face_mesh.process(inp)
        if results.multi_face_landmarks:
            for fl in results.multi_face_landmarks:
                startX, startY, endX, endY = get_box(fl, w, h)
                cur_face = inp[startY:endY, startX:endX]
                cur_face_n = pth_processing(Image.fromarray(cur_face))
                with torch.no_grad():
                    prediction = (
                        torch.nn.functional.softmax(pth_model_static(cur_face_n), dim=1)
                        .detach()
                        .numpy()[0]
                    )
                confidences = {DICT_EMO[i]: float(prediction[i]) for i in range(7)}
                grayscale_cam = cam(input_tensor=cur_face_n)
                grayscale_cam = grayscale_cam[0, :]
                cur_face_hm = cv2.resize(cur_face, (224, 224))
                cur_face_hm = np.float32(cur_face_hm) / 255
                heatmap = show_cam_on_image(cur_face_hm, grayscale_cam, use_rgb=True)

    return cur_face, heatmap, confidences

def preprocess_frame_and_predict_aus(frame):
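    """Detect a face in a single frame and estimate Action Unit intensities.

    Returns (face crop, AU intensities, Grad-CAM heatmap), or
    (None, None, None) if no face is found.
    """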
    if frame is None:
        return None, None, None

    # Normalize channel layout to 3-channel RGB.
    if len(frame.shape) == 2:
        frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
    elif frame.shape[2] == 4:
        frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)

    with mp_face_mesh.FaceMesh(
        max_num_faces=1,
        refine_landmarks=False,
        min_detection_confidence=0.5,
        min_tracking_confidence=0.5
    ) as face_mesh:
        results = face_mesh.process(frame)
        
        if results.multi_face_landmarks:
            h, w = frame.shape[:2]
            for fl in results.multi_face_landmarks:
                startX, startY, endX, endY = get_box(fl, w, h)
                cur_face = frame[startY:endY, startX:endX]
                cur_face_n = pth_processing(Image.fromarray(cur_face))
                
                with torch.no_grad():
                    features = pth_model_static(cur_face_n)
                    au_intensities = features_to_au_intensities(features)
                
                grayscale_cam = cam(input_tensor=cur_face_n)
                grayscale_cam = grayscale_cam[0, :]
                cur_face_hm = cv2.resize(cur_face, (224, 224))
                cur_face_hm = np.float32(cur_face_hm) / 255
                heatmap = show_cam_on_image(cur_face_hm, grayscale_cam, use_rgb=True)
                
                return cur_face, au_intensities, heatmap
    
    return None, None, None

def features_to_au_intensities(features):
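    """Heuristically map a feature vector to AU intensities.

    Min-max normalizes the features and takes the first 24 values; these
    are rough, uncalibrated estimates rather than true FACS AU scores.
    """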
    features_np = features.detach().cpu().numpy()[0]
    # Min-max normalize, guarding against a constant feature vector
    # (which would otherwise divide by zero).
    f_min, f_max = features_np.min(), features_np.max()
    if f_max == f_min:
        return np.zeros(24)
    au_intensities = (features_np - f_min) / (f_max - f_min)
    return au_intensities[:24]  # Assuming we want the first 24 AUs

def preprocess_video_and_predict(video):
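    """Run the static and dynamic emotion models over a video.

    Writes face-crop and heatmap videos to disk and returns the input
    video path, both output paths, and the emotion/AU statistics plots.
    """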
    cap = cv2.VideoCapture(video)
    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = np.round(cap.get(cv2.CAP_PROP_FPS))

    path_save_video_face = 'result_face.mp4'
    vid_writer_face = cv2.VideoWriter(path_save_video_face, cv2.VideoWriter_fourcc(*'mp4v'), fps, (224, 224))

    path_save_video_hm = 'result_hm.mp4'
    vid_writer_hm = cv2.VideoWriter(path_save_video_hm, cv2.VideoWriter_fourcc(*'mp4v'), fps, (224, 224))

    lstm_features = []
    count_frame = 1
    count_face = 0
    probs = []
    frames = []
    au_intensities_list = []
    last_output = None
    last_heatmap = None 
    last_au_intensities = None
    cur_face = None

    with mp_face_mesh.FaceMesh(
        max_num_faces=1,
        refine_landmarks=False,
        min_detection_confidence=0.5,
        min_tracking_confidence=0.5,
    ) as face_mesh:

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            frame_copy = frame.copy()
            frame_copy.flags.writeable = False
            frame_copy = cv2.cvtColor(frame_copy, cv2.COLOR_BGR2RGB)
            results = face_mesh.process(frame_copy)
            frame_copy.flags.writeable = True

            if results.multi_face_landmarks:
                for fl in results.multi_face_landmarks:
                    startX, startY, endX, endY = get_box(fl, w, h)
                    cur_face = frame_copy[startY:endY, startX:endX]

                    if count_face % config_data.FRAME_DOWNSAMPLING == 0:
                        cur_face_copy = pth_processing(Image.fromarray(cur_face))
                        with torch.no_grad():
                            features = torch.nn.functional.relu(pth_model_static.extract_features(cur_face_copy)).detach().numpy()
                            au_intensities = features_to_au_intensities(pth_model_static(cur_face_copy))

                        grayscale_cam = cam(input_tensor=cur_face_copy)
                        grayscale_cam = grayscale_cam[0, :]
                        cur_face_hm = cv2.resize(cur_face, (224, 224), interpolation=cv2.INTER_AREA)
                        cur_face_hm = np.float32(cur_face_hm) / 255
                        heatmap = show_cam_on_image(cur_face_hm, grayscale_cam, use_rgb=False)
                        last_heatmap = heatmap
                        last_au_intensities = au_intensities
        
                        if len(lstm_features) == 0:
                            lstm_features = [features]*10
                        else:
                            lstm_features = lstm_features[1:] + [features]

                        lstm_f = torch.from_numpy(np.vstack(lstm_features))
                        lstm_f = torch.unsqueeze(lstm_f, 0)
                        with torch.no_grad():
                            output = pth_model_dynamic(lstm_f).detach().numpy()
                        last_output = output

                        if count_face == 0:
                            count_face += 1

                    else:
                        if last_output is not None:
                            output = last_output
                            heatmap = last_heatmap
                            au_intensities = last_au_intensities

                        else:
                            output = np.full((1, 7), np.nan)
                            au_intensities = np.full(24, np.nan)
                            
                    probs.append(output[0])
                    frames.append(count_frame)
                    au_intensities_list.append(au_intensities)
            else:
                if last_output is not None:
                    lstm_features = []
                    probs.append(np.full(7, np.nan))
                    frames.append(count_frame)
                    au_intensities_list.append(np.full(24, np.nan))

            if cur_face is not None:
                heatmap_f = display_info(heatmap, 'Frame: {}'.format(count_frame), box_scale=.3)

                cur_face = cv2.cvtColor(cur_face, cv2.COLOR_RGB2BGR)
                cur_face = cv2.resize(cur_face, (224, 224), interpolation=cv2.INTER_AREA)
                cur_face = display_info(cur_face, 'Frame: {}'.format(count_frame), box_scale=.3)
                vid_writer_face.write(cur_face)
                vid_writer_hm.write(heatmap_f)

            count_frame += 1
            if count_face != 0:
                count_face += 1

        cap.release()
        vid_writer_face.release()
        vid_writer_hm.release()

        stat = statistics_plot(frames, probs)
        au_stat = au_statistics_plot(frames, au_intensities_list)

        if not stat or not au_stat:
            return None, None, None, None, None
        
    return video, path_save_video_face, path_save_video_hm, stat, au_stat

def au_statistics_plot(frames, au_intensities_list):
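    """Plot each Action Unit intensity channel across the processed frames."""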
    fig, ax = plt.subplots(figsize=(12, 6))
    au_intensities_array = np.array(au_intensities_list)

    # Guard against an empty run (no frames processed) before indexing shape[1].
    if au_intensities_array.ndim < 2:
        return fig

    for i in range(au_intensities_array.shape[1]):
        ax.plot(frames, au_intensities_array[:, i], label=f'AU{i+1}')
    
    ax.set_xlabel('Frame')
    ax.set_ylabel('AU Intensity')
    ax.set_title('Action Unit Intensities Over Time')
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()
    return fig

def preprocess_video_and_predict_sleep_quality(video):
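    """Analyze per-frame sleep-quality scores for a video.

    Writes original, face-crop, and visualization videos to disk and
    returns their paths plus an averaged eye-bags image and a plot.
    """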
    cap = cv2.VideoCapture(video)
    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = np.round(cap.get(cv2.CAP_PROP_FPS))

    path_save_video_original = 'result_original.mp4'
    path_save_video_face = 'result_face.mp4'
    path_save_video_sleep = 'result_sleep.mp4'
    
    vid_writer_original = cv2.VideoWriter(path_save_video_original, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
    vid_writer_face = cv2.VideoWriter(path_save_video_face, cv2.VideoWriter_fourcc(*'mp4v'), fps, (224, 224))
    vid_writer_sleep = cv2.VideoWriter(path_save_video_sleep, cv2.VideoWriter_fourcc(*'mp4v'), fps, (224, 224))

    frames = []
    sleep_quality_scores = []
    eye_bags_images = []

    with mp_face_mesh.FaceMesh(
        max_num_faces=1,
        refine_landmarks=False,
        min_detection_confidence=0.5,
        min_tracking_confidence=0.5,
    ) as face_mesh:

        frame_number = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frame_number += 1

            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            results = face_mesh.process(frame_rgb)

            if results.multi_face_landmarks:
                for fl in results.multi_face_landmarks:
                    startX, startY, endX, endY = get_box(fl, w, h)
                    cur_face = frame_rgb[startY:endY, startX:endX]
                    
                    sleep_quality_score, eye_bags_image = analyze_sleep_quality(cur_face)
                    # Record the matching frame index so frames and scores
                    # stay the same length for the statistics plot.
                    sleep_quality_scores.append(sleep_quality_score)
                    frames.append(frame_number)
                    eye_bags_images.append(cv2.resize(eye_bags_image, (224, 224)))

                    sleep_quality_viz = create_sleep_quality_visualization(cur_face, sleep_quality_score)
                    
                    cur_face = cv2.resize(cur_face, (224, 224))
                    
                    vid_writer_face.write(cv2.cvtColor(cur_face, cv2.COLOR_RGB2BGR))
                    # The writer was opened at 224x224; frames of any other
                    # size would be silently dropped, so resize first.
                    vid_writer_sleep.write(cv2.resize(sleep_quality_viz, (224, 224)))

            vid_writer_original.write(frame)

    cap.release()
    vid_writer_original.release()
    vid_writer_face.release()
    vid_writer_sleep.release()

    sleep_stat = sleep_quality_statistics_plot(frames, sleep_quality_scores)
    
    if eye_bags_images:
        average_eye_bags_image = np.mean(np.array(eye_bags_images), axis=0).astype(np.uint8)
    else:
        average_eye_bags_image = np.zeros((224, 224, 3), dtype=np.uint8)

    return (path_save_video_original, path_save_video_face, path_save_video_sleep, 
            average_eye_bags_image, sleep_stat)

def analyze_sleep_quality(face_image):
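    """Stand-in analysis: returns a random score and the resized face crop."""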
    # Placeholder function - implement your sleep quality analysis here
    sleep_quality_score = np.random.random()
    eye_bags_image = cv2.resize(face_image, (224, 224))
    return sleep_quality_score, eye_bags_image

def create_sleep_quality_visualization(face_image, sleep_quality_score):
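    """Overlay the sleep-quality score on the face crop and convert it to
    BGR for the video writer.
    """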
    viz = face_image.copy()
    cv2.putText(viz, f"Sleep Quality: {sleep_quality_score:.2f}", (10, 30), 
                cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
    return cv2.cvtColor(viz, cv2.COLOR_RGB2BGR)

def sleep_quality_statistics_plot(frames, sleep_quality_scores):
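    """Plot the per-frame sleep-quality scores over time."""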
    # Placeholder function - implement your statistics plotting here
    fig, ax = plt.subplots()
    ax.plot(frames, sleep_quality_scores)
    ax.set_xlabel('Frame')
    ax.set_ylabel('Sleep Quality Score')
    ax.set_title('Sleep Quality Over Time')
    return fig
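
if __name__ == "__main__":
    # Minimal smoke-test sketch, assuming the `app` package is importable
    # from the working directory. 'example.jpg' and 'example.mp4' are
    # hypothetical paths; substitute real test media before running.
    test_image = Image.open('example.jpg')
    face, heatmap, confidences = preprocess_image_and_predict(test_image)
    if confidences is not None:
        print({emotion: round(score, 3) for emotion, score in confidences.items()})

    outputs = preprocess_video_and_predict('example.mp4')
    print('Face and heatmap videos:', outputs[1:3])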