import os
import cv2
import time
import torch
import gradio as gr
import numpy as np

# Local project imports: the model factory and training configuration.
from model import create_model
from config import NUM_CLASSES, DEVICE, CLASSES
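
# For reference, config.py is assumed to expose roughly the following
# (a hypothetical sketch; adjust to match your actual project):
#
#   import torch
#   CLASSES = ["__background__", "buffalo", "elephant", "rhino", "zebra"]
#   NUM_CLASSES = len(CLASSES)
#   DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")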

# ----------------------------------------------------------------
# GLOBAL SETUP
# ----------------------------------------------------------------
# Create the model and load the best weights.
model = create_model(num_classes=NUM_CLASSES)
checkpoint = torch.load("outputs/best_model_79.pth", map_location=DEVICE)
model.load_state_dict(checkpoint["model_state_dict"])
model.to(DEVICE).eval()
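
# `create_model` is assumed to wrap a torchvision RetinaNet; a minimal sketch
# (hypothetical, assuming torchvision >= 0.13):
#
#   from torchvision.models.detection import retinanet_resnet50_fpn_v2
#
#   def create_model(num_classes):
#       return retinanet_resnet50_fpn_v2(num_classes=num_classes)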

# Create one random color per class index
# (length matches len(CLASSES), including the background class).
COLORS = np.random.uniform(0, 255, size=(len(CLASSES), 3))

# COLORS = [
#     (255, 255, 0),  # Cyan - background
#     (50, 0, 255),  # Red - buffalo
#     (147, 20, 255),  # Pink - elephant
#     (0, 255, 0),  # Green - rhino
#     (238, 130, 238),  # Violet - zebra
# ]


# ----------------------------------------------------------------
# HELPER FUNCTIONS
# ----------------------------------------------------------------
def inference_on_image(orig_image: np.ndarray, resize_dim=None, threshold=0.25):
    """
    Runs inference on a single image (OpenCV BGR or NumPy array).
    - resize_dim: if not None, we resize to (resize_dim, resize_dim)
    - threshold: detection confidence threshold
    Returns: processed image with bounding boxes drawn.
    """
    image = orig_image.copy()
    # Optionally resize for inference.
    if resize_dim is not None:
        image = cv2.resize(image, (int(resize_dim), int(resize_dim)))

    # Convert BGR to RGB, normalize [0..1]
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    # Move channels to front (C,H,W)
    image_tensor = torch.tensor(image_rgb.transpose(2, 0, 1), dtype=torch.float).unsqueeze(0).to(DEVICE)
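    # Note: with a CUDA device, the forward pass runs asynchronously, so this
    # wall-clock FPS is approximate unless torch.cuda.synchronize() is called
    # around the inference call.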
    start_time = time.time()
    # Inference
    with torch.no_grad():
        outputs = model(image_tensor)
    end_time = time.time()
    # Get the current fps.
    fps = 1 / (end_time - start_time)
    fps_text = f"FPS: {fps:.2f}"
    # Move outputs to CPU numpy
    outputs = [{k: v.cpu() for k, v in t.items()} for t in outputs]
    boxes = outputs[0]["boxes"].numpy()
    scores = outputs[0]["scores"].numpy()
    labels = outputs[0]["labels"].numpy().astype(int)

    # Filter out boxes with low confidence
    valid_idx = np.where(scores >= threshold)[0]
    boxes = boxes[valid_idx]
    labels = labels[valid_idx]

    # If we resized for inference, rescale boxes back to orig_image size.
    # Scale in float first, then round, to avoid truncating coordinates early.
    h_orig, w_orig = orig_image.shape[:2]
    if resize_dim is not None:
        boxes[:, [0, 2]] = (boxes[:, [0, 2]] / resize_dim) * w_orig
        boxes[:, [1, 3]] = (boxes[:, [1, 3]] / resize_dim) * h_orig
    boxes = boxes.astype(int)

    # Draw bounding boxes and class labels
    for box, label_idx in zip(boxes, labels):
        class_name = CLASSES[label_idx] if 0 <= label_idx < len(CLASSES) else str(label_idx)
        # OpenCV draws in BGR and expects a tuple of plain Python ints.
        color = tuple(int(c) for c in COLORS[label_idx % len(COLORS)][::-1])
        x1, y1, x2, y2 = (int(v) for v in box)
        cv2.rectangle(orig_image, (x1, y1), (x2, y2), color, 5)
        cv2.putText(orig_image, class_name, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 3)

    # Draw the FPS counter once per image (not once per box), near the top center.
    cv2.putText(
        orig_image,
        fps_text,
        (int((w_orig / 2) - 50), 30),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.8,
        (0, 255, 0),
        2,
        cv2.LINE_AA,
    )
    return orig_image, fps


def inference_on_frame(frame: np.ndarray, resize_dim=None, threshold=0.25):
    """
    Same as inference_on_image but for a single video frame.
    Returns: (processed frame with bounding boxes, inference FPS).
    """
    return inference_on_image(frame, resize_dim, threshold)
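
# Quick standalone sanity check, independent of Gradio (hypothetical output path):
#
#   img = cv2.imread("examples/zebra.jpg")
#   result, fps = inference_on_image(img, resize_dim=640, threshold=0.5)
#   cv2.imwrite("zebra_pred.jpg", result)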


# ----------------------------------------------------------------
# GRADIO FUNCTIONS
# ----------------------------------------------------------------


def img_inf(image_path, resize_dim, threshold):
    """
    Gradio function for image inference.
    :param image_path: File path from Gradio (uploaded image).
    :param resize_dim: Inference resize dimension from the slider.
    :param threshold: Detection confidence threshold from the slider.
    Returns: A NumPy RGB image array with bounding boxes drawn.
    """
    if image_path is None:
        return None  # No image provided
    orig_image = cv2.imread(image_path)  # BGR
    if orig_image is None:
        return None  # Error reading image

    result_image, _ = inference_on_image(orig_image, resize_dim=resize_dim, threshold=threshold)
    # Return the image in RGB for Gradio's display
    result_image_rgb = cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB)
    return result_image_rgb


def vid_inf(video_path, resize_dim, threshold):
    """
    Gradio function for video inference.
    Processes each frame, draws bounding boxes, and writes to an output video.
    Returns: (last_processed_frame, output_video_file_path)
    """
    if video_path is None:
        yield None, None  # No video provided; yield (not return) in a generator.
        return

    # Prepare input capture
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        yield None, None
        return

    # Create an output file path
    os.makedirs("inference_outputs/videos", exist_ok=True)
    out_video_path = os.path.join("inference_outputs/videos", "video_output.mp4")

    # Get video properties
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")  # or 'XVID'
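    # Note: "mp4v" encodes fine with OpenCV but may not play back in all
    # browsers; if the Gradio video player stays blank, try the "avc1" fourcc
    # instead (requires an OpenCV build with H.264 support).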

    # Some containers report FPS as 0; fall back to a sane default.
    if fps <= 0:
        fps = 20.0

    out_writer = cv2.VideoWriter(out_video_path, fourcc, fps, (width, height))

    frame_count = 0
    total_fps = 0

    while True:
        ret, frame = cap.read()
        if not ret:
            break
        # Inference on frame
        processed_frame, frame_fps = inference_on_frame(frame, resize_dim=resize_dim, threshold=threshold)
        total_fps += frame_fps
        frame_count += 1

        # Write the processed frame
        out_writer.write(processed_frame)
        yield cv2.cvtColor(processed_frame, cv2.COLOR_BGR2RGB), None

    avg_fps = total_fps / frame_count if frame_count else 0.0

    cap.release()
    out_writer.release()
    print(f"Average FPS: {avg_fps:.3f}")
    yield None, out_video_path


# ----------------------------------------------------------------
# BUILD THE GRADIO INTERFACES
# ----------------------------------------------------------------

# Gradio components for the image tab: sliders for the inference resize
# dimension and the detection confidence threshold.
resize_dim = gr.Slider(100, 1024, value=640, label="Resize Dimension", info="Resize image to this dimension")
threshold = gr.Slider(0, 1, value=0.5, label="Threshold", info="Confidence threshold for detection")
inputs_image = gr.Image(type="filepath", label="Input Image")
outputs_image = gr.Image(type="numpy", label="Output Image")

interface_image = gr.Interface(
    fn=img_inf,
    inputs=[inputs_image, resize_dim, threshold],
    outputs=outputs_image,
    title="Image Inference",
    description="Upload your photo, select a model, and see the results!",
    examples=[["examples/buffalo.jpg"], ["examples/zebra.jpg"]],
    cache_examples=False,
)

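# Fresh component instances for the video tab (Gradio components should not
# be reused across interfaces).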
resize_dim = gr.Slider(100, 1024, value=640, label="Resize Dimension", info="Resize image to this dimension")
threshold = gr.Slider(0, 1, value=0.5, label="Threshold", info="Confidence threshold for detection")
input_video = gr.Video(label="Input Video")

# Output is a pair: (last_processed_frame, output_video_path)
output_frame = gr.Image(type="numpy", label="Output (Last Processed Frame)")
output_video_file = gr.Video(format="mp4", label="Output Video")

interface_video = gr.Interface(
    fn=vid_inf,
    inputs=[input_video, resize_dim, threshold],
    outputs=[output_frame, output_video_file],
    title="Video Inference",
    description="Upload your video and see the processed output!",
    examples=[["examples/elephants.mp4"], ["examples/rhino.mp4"]],
    cache_examples=False,
)

# Combine them in a Tabbed Interface
demo = gr.TabbedInterface(
    [interface_image, interface_video],
    tab_names=["Image", "Video"],
    title="Fine-Tuning RetinaNet for Wildlife Animal Detection",
    theme="gstaff/xkcd",
)

if __name__ == "__main__":
    demo.queue().launch()
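
# Launch with `python app.py` (substitute this file's actual name); Gradio
# serves the UI at http://localhost:7860 by default.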