import gradio as gr
from ultralytics import YOLO
import tempfile
import os
import cv2
import numpy as np
import torch
import atexit
import uuid

# Load the YOLOv8 pose estimation model once, at import time. detect_fall()
# reuses this module-level instance for every frame it processes instead of
# reloading the weights per call.
model = YOLO("yolov8n-pose.pt")

# Names of the 17 COCO keypoints. A name's position in this list is the
# keypoint index used by SKELETON_CONNECTIONS and by the shoulder/hip indices
# in calculate_torso_angle (e.g. index 5 is "left_shoulder").
# NOTE(review): assumed to match the keypoint order the pose model emits;
# this list is not referenced by the code visible in this file.
COCO_KEYPOINTS = [
    "nose", "left_eye", "right_eye", "left_ear", "right_ear",
    "left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
    "left_wrist", "right_wrist", "left_hip", "right_hip",
    "left_knee", "right_knee", "left_ankle", "right_ankle"
]

# Skeleton edges as (start_index, end_index) pairs of keypoint indices
# (positions in COCO_KEYPOINTS). draw_skeleton() draws one line per pair.
SKELETON_CONNECTIONS = [
    (0, 1), (0, 2),       # Nose to eyes
    (1, 3), (2, 4),       # Eyes to ears
    (0, 5), (0, 6),       # Nose to shoulders
    (5, 6),               # Shoulders to each other
    (5, 7), (6, 8),       # Shoulders to elbows
    (7, 9), (8, 10),      # Elbows to wrists
    (5, 11), (6, 12),     # Shoulders to hips
    (11, 12),             # Hips to each other
    (11, 13), (12, 14),   # Hips to knees
    (13, 15), (14, 16)    # Knees to ankles
]

def calculate_torso_angle(keypoints, frame_height, min_confidence=0.3):
    """
    Calculate the angle of the torso with respect to the vertical axis.

    The torso is the vector from the shoulder midpoint to the hip midpoint.
    In image coordinates (y grows downwards) a vector pointing straight down
    gives 0 degrees, a horizontal torso gives +/-90 degrees, and a torso
    pointing straight up gives +/-180 degrees.

    Args:
        keypoints (array-like): Keypoints with one row per keypoint and at least
            three columns (x, y, confidence). Rows 5, 6, 11 and 12 (shoulders
            and hips) must be present.
        frame_height (int): Height of the video frame in pixels. Currently
            unused; kept so existing callers keep working.
        min_confidence (float): Minimum confidence each of the four torso
            keypoints must have for the angle to be computed.

    Returns:
        float: Signed angle in degrees, or None when the input is malformed,
        a torso keypoint is below ``min_confidence`` or non-finite, or the
        shoulder and hip midpoints coincide (direction undefined).
    """
    # COCO keypoint indices
    LEFT_SHOULDER = 5
    RIGHT_SHOULDER = 6
    LEFT_HIP = 11
    RIGHT_HIP = 12

    try:
        kpts = np.asarray(keypoints, dtype=float)
    except (TypeError, ValueError) as e:
        print(f"Error calculating torso angle: {e}")
        return None

    # Validate the shape up front instead of relying on a blanket except.
    if kpts.ndim != 2 or kpts.shape[0] <= RIGHT_HIP or kpts.shape[1] < 3:
        print(f"Error calculating torso angle: unexpected keypoints shape {kpts.shape}")
        return None

    torso = kpts[[LEFT_SHOULDER, RIGHT_SHOULDER, LEFT_HIP, RIGHT_HIP], :3]

    # NaN/inf coordinates or confidences would silently yield a NaN angle.
    if not np.isfinite(torso).all():
        return None

    # Every torso keypoint must be confidently detected.
    if not np.all(torso[:, 2] >= min_confidence):
        return None

    # Calculate mid points
    mid_shoulder = (torso[0, :2] + torso[1, :2]) / 2
    mid_hip = (torso[2, :2] + torso[3, :2]) / 2

    # Calculate the vector of the torso (shoulders -> hips)
    dx, dy = mid_hip - mid_shoulder

    # A zero-length torso has no direction; arctan2(0, 0) would report 0
    # degrees, which would look like a perfectly upright person.
    if dx == 0 and dy == 0:
        return None

    # Angle with respect to the vertical (image y) axis
    return float(np.degrees(np.arctan2(dx, dy)))

def draw_skeleton(frame, keypoints, show_labels=True):
    """
    Overlay the pose skeleton (and optionally keypoint indices) on a frame.

    Args:
        frame (numpy.ndarray): Video frame to draw on; it is modified in place.
        keypoints (numpy.ndarray): One (x, y, confidence) row per keypoint,
            indexed as in SKELETON_CONNECTIONS.
        show_labels (bool): When True, write each keypoint's index next to it.

    Returns:
        numpy.ndarray: The same frame object, now annotated.
    """
    blue = (255, 0, 0)  # OpenCV colours are BGR

    # Limbs: skip any edge with an endpoint that is not confidently detected.
    for first, second in SKELETON_CONNECTIONS:
        ax, ay, a_conf = keypoints[first]
        bx, by, b_conf = keypoints[second]
        if not (a_conf > 0.5 and b_conf > 0.5):
            continue
        cv2.line(frame, (int(ax), int(ay)), (int(bx), int(by)), blue, 2)

    if not show_labels:
        return frame

    # Index labels, again only for confidently detected keypoints.
    for index, (px, py, confidence) in enumerate(keypoints):
        if confidence > 0.5:
            anchor = (int(px), int(py))
            cv2.putText(frame, str(index), anchor, cv2.FONT_HERSHEY_SIMPLEX, 0.3, blue, 1)

    return frame

def _iter_person_keypoints(results, frame_number):
    """Yield one NumPy keypoint array per detected person in ``results``.

    Entries that are missing, not two-dimensional, empty, or that have fewer
    than 17 rows are skipped (with a log line) instead of being yielded.
    """
    for result in results:
        keypoints = getattr(result, "keypoints", None)
        if keypoints is None:
            continue
        for person in keypoints.data:
            if person is None:
                continue
            # Tensors are moved to the CPU first; anything else is handed to NumPy.
            if isinstance(person, torch.Tensor):
                kpts = person.detach().cpu().numpy()
            else:
                kpts = np.asarray(person)
            if kpts.ndim != 2 or kpts.size == 0 or kpts.shape[0] < 17:
                print(f"Insufficient keypoints for processing in frame {frame_number}")
                continue
            yield kpts


def detect_fall(video_path, angle_threshold=30, consecutive_frames=3, frame_sampling_rate=1, confidence_threshold=0.3, show_labels=True):
    """
    Detects falls in the uploaded video using pose estimation.

    Every ``frame_sampling_rate``-th frame is run through the pose model. A
    processed frame counts as a "fall frame" when at least one detected person
    has a torso angle whose magnitude exceeds ``angle_threshold``. A fall is
    recorded once ``consecutive_frames`` processed fall frames occur in a row;
    the streak ends at the first processed frame in which a person is measured
    and nobody exceeds the threshold. Processed frames where no torso could be
    measured leave the streak untouched. All frames (annotated or not) are
    written to the output video.

    Args:
        video_path (str): The path to the input video file uploaded by the user.
        angle_threshold (float): Angle threshold to classify a fall (in degrees).
        consecutive_frames (int): Number of consecutive processed frames needed
            to confirm a fall. Values below 1 are treated as 1.
        frame_sampling_rate (int): Process every nth frame. Values below 1 are
            treated as 1.
        confidence_threshold (float): Confidence passed to the pose model as ``conf``.
        show_labels (bool): Whether to display keypoint indices.

    Returns:
        tuple: (annotated_video_path, notification_message). On failure the
        path is None and the message describes the error. Reported frame
        numbers are 1-based and mark the first frame of each fall streak.
    """
    annotated_video_path = None
    try:
        if not video_path:
            raise ValueError("No video file was provided.")

        # Coerce to int in case the caller passes floats, and clamp so the
        # modulo below can never divide by zero.
        sampling_rate = max(1, int(frame_sampling_rate))
        required_streak = max(1, int(consecutive_frames))

        cap = cv2.VideoCapture(video_path)
        out = None
        try:
            if not cap.isOpened():
                raise ValueError("Unable to open the video file.")

            # Video properties
            fps = cap.get(cv2.CAP_PROP_FPS)
            if not fps > 0:
                # Missing metadata (0, negative or NaN): fall back to a
                # default rate so the writer can still be opened.
                fps = 30.0
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')

            # Reserve a unique temporary path; the handle is closed right away
            # so the video writer can open the file itself.
            unique_id = uuid.uuid4().hex
            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4", prefix=f"annotated_{unique_id}_") as tmp:
                annotated_video_path = tmp.name

            out = cv2.VideoWriter(annotated_video_path, fourcc, fps, (width, height))
            if not out.isOpened():
                raise ValueError("Unable to create the annotated video file.")

            current_frame = 0
            fall_streak = 0      # consecutive processed frames showing a fall
            streak_start = None  # frame number at which the current streak began
            fall_frames = []     # first frame of every confirmed fall

            while True:
                ret, frame = cap.read()
                if not ret:
                    break  # End of video

                current_frame += 1

                # Frame sampling: skipped frames are copied through unchanged.
                if current_frame % sampling_rate != 0:
                    out.write(frame)
                    continue

                print(f"Processing frame {current_frame}/{frame_count}")

                # Run pose estimation
                results = model.predict(source=frame, conf=confidence_threshold, save=False, stream=False)

                # Evaluate every person, but decide fall/normal once per frame
                # so several people cannot inflate or reset the streak.
                person_measured = False
                fall_in_frame = False
                for kpts in _iter_person_keypoints(results, current_frame):
                    angle = calculate_torso_angle(kpts, height)
                    if angle is None:
                        continue
                    person_measured = True
                    if abs(angle) > angle_threshold:
                        fall_in_frame = True
                    frame = draw_skeleton(frame, kpts, show_labels=show_labels)

                if fall_in_frame:
                    if fall_streak == 0:
                        streak_start = current_frame
                    fall_streak += 1
                    # Only label the frame once the fall is confirmed.
                    if fall_streak >= required_streak:
                        cv2.putText(frame, "Fall Detected!", (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
                elif person_measured:
                    # Someone was measured and nobody is down: close the streak.
                    if fall_streak >= required_streak:
                        fall_frames.append(streak_start)
                    fall_streak = 0

                # Write annotated frame
                out.write(frame)

            # A fall that lasts until the end of the video still counts.
            if fall_streak >= required_streak:
                fall_frames.append(streak_start)
        finally:
            # Release on every path so file handles are not leaked and the
            # output container is finalised before it is returned or deleted.
            cap.release()
            if out is not None:
                out.release()

        # Generate notification message
        total_falls = len(fall_frames)
        if total_falls == 1:
            notification = f"A fall was detected at frame {fall_frames[0]}."
        elif total_falls > 1:
            frames = ', '.join(map(str, fall_frames))
            notification = f"{total_falls} falls were detected at frames: {frames}."
        else:
            notification = "No falls were detected in the video."

        # Check if annotated video was created
        if not os.path.exists(annotated_video_path):
            raise FileNotFoundError("Annotated video was not found. Please check the model and processing steps.")

        return annotated_video_path, notification

    except Exception as e:
        # Top-level boundary for the UI callback: report the error as text and
        # remove the partially written temporary file.
        print(f"Error during fall detection: {e}")
        if annotated_video_path and os.path.exists(annotated_video_path):
            try:
                os.remove(annotated_video_path)
            except OSError as cleanup_error:
                print(f"Could not remove temporary file {annotated_video_path}: {cleanup_error}")
        return None, f"An error occurred during fall detection: {e}"

def create_gradio_interface():
    """Build the Gradio interface that exposes detect_fall with tunable parameters.

    Returns:
        gr.Interface: Interface taking a video plus five tuning controls and
        producing the annotated video and a notification message.
    """
    # Input controls, in the positional order detect_fall expects them.
    video_input = gr.Video(label="Upload Video")
    angle_slider = gr.Slider(
        label="Angle Threshold (degrees)",
        minimum=0,
        maximum=90,
        step=1,
        value=30,
        interactive=True,
        info="Adjust the torso angle threshold to classify a fall. Lower values increase sensitivity.",
    )
    consecutive_slider = gr.Slider(
        label="Consecutive Frames to Confirm Fall",
        minimum=1,
        maximum=10,
        step=1,
        value=3,
        interactive=True,
        info="Number of consecutive frames exceeding the angle threshold required to confirm a fall.",
    )
    sampling_slider = gr.Slider(
        label="Frame Sampling Rate",
        minimum=1,
        maximum=10,
        step=1,
        value=1,
        interactive=True,
        info="Process every nth frame to speed up detection. Higher values reduce processing time.",
    )
    confidence_slider = gr.Slider(
        label="Confidence Threshold",
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        value=0.3,
        interactive=True,
        info="Minimum confidence required for keypoint detection. Higher values reduce false positives.",
    )
    labels_checkbox = gr.Checkbox(
        label="Show Keypoint Labels",
        value=True,
        interactive=True,
        info="Toggle the display of keypoint indices on the video.",
    )

    # Output widgets, matching detect_fall's (video path, message) return.
    annotated_output = gr.Video(label="Annotated Video")
    notification_output = gr.Textbox(label="Fall Detection Notification")

    app_description = (
        "Upload a video of a person falling, and the app will detect and annotate the fall "
        "using pose estimation. Adjust the angle threshold, consecutive frames, frame sampling rate, "
        "and confidence threshold to fine-tune detection sensitivity and performance. "
        "The annotated video will display keypoints, skeleton lines, and indicate when a fall is detected."
    )

    # One example row: a demo video with a value for each control above.
    example_rows = [
        ["demo/person falling.mp4", 30, 3, 1, 0.3, True]
    ]

    return gr.Interface(
        fn=detect_fall,
        inputs=[
            video_input,
            angle_slider,
            consecutive_slider,
            sampling_slider,
            confidence_slider,
            labels_checkbox,
        ],
        outputs=[annotated_output, notification_output],
        title="Fall Detection App 🚨",
        description=app_description,
        examples=example_rows,
        flagging_mode="never",
    )

# Build the Gradio interface at import time; it is launched from the
# __main__ guard at the bottom of this file.
iface = create_gradio_interface()

# Ensure the annotated videos written to the temp directory are removed on exit
def cleanup_temp_dirs():
    """Delete leftover annotated videos from the system temporary directory.

    detect_fall() writes its output to files named
    ``annotated_<32 hex chars>_<random>.mp4`` directly inside
    ``tempfile.gettempdir()``. This removes every regular file there that
    matches that naming scheme. Files that cannot be listed or removed are
    reported and skipped, so the function never raises during interpreter
    shutdown.

    NOTE(review): matching is by file name only, so outputs of another
    instance of this app sharing the same temp directory are removed too.
    """
    prefix = "annotated_"
    suffix = ".mp4"
    hex_length = 32  # length of uuid4().hex
    hex_digits = "0123456789abcdef"

    temp_dir = tempfile.gettempdir()
    try:
        names = os.listdir(temp_dir)
    except OSError as e:
        print(f"Could not list temporary directory {temp_dir}: {e}")
        return

    for name in names:
        if not (name.startswith(prefix) and name.endswith(suffix)):
            continue
        # Require the uuid hex token plus its trailing underscore so that
        # unrelated "annotated_*.mp4" files are left alone.
        token = name[len(prefix):len(prefix) + hex_length]
        separator = name[len(prefix) + hex_length:len(prefix) + hex_length + 1]
        if len(token) != hex_length or separator != "_":
            continue
        if any(char not in hex_digits for char in token):
            continue

        path = os.path.join(temp_dir, name)
        if not os.path.isfile(path):
            continue
        try:
            os.remove(path)
        except OSError as e:
            print(f"Could not remove temporary file {path}: {e}")

atexit.register(cleanup_temp_dirs)

# Launch the app only when this file is executed as a script; importing the
# module builds `iface` but does not start the server.
if __name__ == "__main__":
    iface.launch()