import cv2
import io
import numpy as np
import torch
import gradio as gr
from PIL import Image
from fastapi import FastAPI, UploadFile, File, HTTPException
from starlette.responses import StreamingResponse
from TranSalNet_Res import TranSalNet

app = FastAPI()

# Load TranSalNet (ResNet variant) for CPU inference.
device = torch.device('cpu')
model = TranSalNet()
model.load_state_dict(torch.load('pretrained_models/TranSalNet_Res.pth', map_location=device))
model.to(device)
model.eval()


def count_and_label_red_patches(heatmap, image, threshold=200):
    """Locate high-saliency (red) patches in the BGR JET heatmap, then number
    their centroids and link them on a copy of `image` (a BGR ndarray)."""
    red_mask = heatmap[:, :, 2] > threshold
    contours, _ = cv2.findContours(red_mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Label the largest patches first.
    contours = sorted(contours, key=cv2.contourArea, reverse=True)

    original_image = image.copy()

    centroid_list = []

    for i, contour in enumerate(contours, start=1):
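        # Centroid of the contour from its spatial moments (degenerate
        # contours fall back to the origin).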
        M = cv2.moments(contour)
        if M["m00"] != 0:
            cX = int(M["m10"] / M["m00"])
            cY = int(M["m01"] / M["m00"])
        else:
            cX, cY = 0, 0

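        # Mark the centroid with a filled black disc and its rank number.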
        radius = 20
        circle_color = (0, 0, 0)
        cv2.circle(original_image, (cX, cY), radius, circle_color, -1)

        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 1
        font_color = (255, 255, 255)
        line_type = cv2.LINE_AA
        cv2.putText(original_image, str(i), (cX - 10, cY + 10), font, font_scale, font_color, 2, line_type)

        centroid_list.append((cX, cY))

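    # Link consecutive centroids in labeling order.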
    for i in range(len(centroid_list) - 1):
        start_point = centroid_list[i]
        end_point = centroid_list[i + 1]
        line_color = (0, 0, 0)
        cv2.line(original_image, start_point, end_point, line_color, 2)

    return original_image, len(contours)


def process_image(image: Image.Image) -> np.ndarray:
    # TranSalNet expects a 384x288 (width x height) BGR image as a
    # normalised NCHW float tensor.
    img = image.resize((384, 288))
    img = np.array(img)
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    img = img / 255.
    img = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
    img = torch.from_numpy(img).float().to(device)

    with torch.no_grad():
        pred_saliency = model(img).squeeze().cpu().numpy()

    # Map the saliency values to a JET heatmap at the original resolution
    # (clip first so the uint8 cast cannot overflow).
    heatmap = (np.clip(pred_saliency, 0, 1) * 255).astype(np.uint8)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    heatmap = cv2.resize(heatmap, (image.width, image.height))

    # Keep everything in BGR so the channel split and OpenCV drawing stay consistent.
    enhanced_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    b, g, r = cv2.split(enhanced_image)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    b_enhanced = clahe.apply(b)
    enhanced_image = cv2.merge((b_enhanced, g, r))

    # Blend the contrast-enhanced image with the heatmap.
    alpha = 0.7
    blended_img = cv2.addWeighted(enhanced_image, 1 - alpha, heatmap, alpha, 0)

    # Number the high-saliency patches and connect their centroids on the blend.
    blended_img, num_red_patches = count_and_label_red_patches(heatmap, blended_img)

    # Save the annotated result as PNG (optional).
    cv2.imwrite('example/result15.png', blended_img)

    # Gradio renders numpy images as RGB, so convert back from BGR.
    return cv2.cvtColor(blended_img, cv2.COLOR_BGR2RGB)


def gr_process_image(input_image):
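    # Gradio passes the input as an RGB ndarray; wrap it as a PIL image.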
    image = Image.fromarray(input_image)
    processed_image = process_image(image)
    return processed_image
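
# The FastAPI app above has no routes; below is a minimal sketch of the upload
# endpoint its imports suggest (the '/predict' path and PNG response are
# assumptions, not part of the original code).
@app.post('/predict')
async def predict(file: UploadFile = File(...)):
    # Decode the uploaded bytes into an RGB PIL image.
    try:
        image = Image.open(io.BytesIO(await file.read())).convert('RGB')
    except Exception:
        raise HTTPException(status_code=400, detail='Uploaded file is not a valid image')

    result = process_image(image)  # RGB ndarray

    # Encode as PNG (cv2.imencode expects BGR) and stream it back.
    ok, buffer = cv2.imencode('.png', cv2.cvtColor(result, cv2.COLOR_RGB2BGR))
    if not ok:
        raise HTTPException(status_code=500, detail='Failed to encode result image')
    return StreamingResponse(io.BytesIO(buffer.tobytes()), media_type='image/png')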


iface = gr.Interface(
    fn=gr_process_image,
    inputs=gr.Image(type="numpy"),
    outputs=gr.Image(type="numpy"),
)

if __name__ == '__main__':
    iface.launch(share=True)