import cv2
import numpy as np
import gradio as gr
from PIL import Image, ImageOps
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import os
import time
import io
import base64
import torch
import cv2
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from functools import partial

class Net2(nn.Module):
    def __init__(self):
        super(Net2, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.dropout1 = nn.Dropout(0.25)

        self.conv2 = nn.Conv2d(64, 64, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.dropout2 = nn.Dropout(0.25)

        self.conv3 = nn.Conv2d(64, 64, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        self.pool3 = nn.MaxPool2d(2, 2)
        self.dropout3 = nn.Dropout(0.25)

        self.conv4 = nn.Conv2d(64, 64, 3, padding=1)
        self.bn4 = nn.BatchNorm2d(64)
        self.pool4 = nn.MaxPool2d(2, 2)
        self.dropout4 = nn.Dropout(0.25)

        self.flatten = nn.Flatten()

        self.fc1 = nn.Linear(64 * 5 * 5, 200)
        self.fc2 = nn.Linear(200, 150)
        self.fc3 = nn.Linear(150, 2)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.pool1(x)
        x = self.dropout1(x)

        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool2(x)
        x = self.dropout2(x)

        x = F.relu(self.bn3(self.conv3(x)))
        x = self.pool3(x)
        x = self.dropout3(x)

        x = F.relu(self.bn4(self.conv4(x)))
        x = self.pool4(x)
        x = self.dropout4(x)

        x = self.flatten(x)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.softmax(self.fc3(x), dim=1)
        return x
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.dropout1 = nn.Dropout(0.25)

        self.conv2 = nn.Conv2d(32, 32, 3, padding=1)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.dropout2 = nn.Dropout(0.25)

        self.conv3 = nn.Conv2d(32, 32, 3, padding=1)
        self.pool3 = nn.MaxPool2d(2, 2)
        self.dropout3 = nn.Dropout(0.25)

        self.conv4 = nn.Conv2d(32, 32, 3, padding=1)
        self.pool4 = nn.MaxPool2d(2, 2)
        self.dropout4 = nn.Dropout(0.25)

        self.flatten = nn.Flatten()

        self.fc1 = nn.Linear(32 * 5 * 5, 200)
        self.fc2 = nn.Linear(200, 150)
        self.fc3 = nn.Linear(150, 2)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = self.dropout1(x)

        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        x = self.dropout2(x)

        x = F.relu(self.conv3(x))
        x = self.pool3(x)
        x = self.dropout3(x)

        x = F.relu(self.conv4(x))
        x = self.pool4(x)
        x = self.dropout4(x)

        x = self.flatten(x)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = torch.sigmoid(self.fc3(x))
        return x

model = None
model_path = "model3.pth"

model2 = None
model2_path = "model4.pth"

if os.path.exists(model_path):
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    new_state_dict = {}
    for key, value in state_dict.items():
        new_key = key.replace("module.", "")
        new_state_dict[new_key] = value

    model = Net()
    model.load_state_dict(new_state_dict)
    model.eval()

else:
    print("Model file not found at", model_path)


# def process_image(input_image):
#     image = Image.fromarray(input_image).convert("RGB")
#
#     start_time = time.time()
#     heatmap = scanmap(np.array(image), model)
#     elapsed_time = time.time() - start_time
#     heatmap_img = Image.fromarray(np.uint8(plt.cm.hot(heatmap) * 255)).convert('RGB')
#
#     heatmap_img = heatmap_img.resize(image.size)
#
#     return image, heatmap_img, int(elapsed_time)
#
#
# def scanmap(image_np, model):
#     image_np = image_np.astype(np.float32) / 255.0
#
#     window_size = (80, 80)
#     stride = 10
#
#     height, width, channels = image_np.shape
#
#     probabilities_map = []
#
#     for y in range(0, height - window_size[1] + 1, stride):
#         row_probabilities = []
#         for x in range(0, width - window_size[0] + 1, stride):
#             cropped_window = image_np[y:y + window_size[1], x:x + window_size[0]]
#             cropped_window_torch = transforms.ToTensor()(cropped_window).unsqueeze(0)
#
#             with torch.no_grad():
#                 probabilities = model(cropped_window_torch)
#
#             row_probabilities.append(probabilities[0, 1].item())
#
#         probabilities_map.append(row_probabilities)
#
#     probabilities_map = np.array(probabilities_map)
#     return probabilities_map
#
# def gradio_process_image(input_image):
#     original, heatmap, elapsed_time = process_image(input_image)
#     return original, heatmap, f"Elapsed Time (seconds): {elapsed_time}"
#
# inputs = gr.Image(label="Upload Image")
# outputs = [
#     gr.Image(label="Original Image"),
#     gr.Image(label="Heatmap"),
#     gr.Textbox(label="Elapsed Time")
# ]
#
# iface = gr.Interface(fn=gradio_process_image, inputs=inputs, outputs=outputs)
# iface.launch()
def scanmap(image_np, model, threshold=0.5):
    image_np = image_np.astype(np.float32) / 255.0

    window_size = (80, 80)
    stride = 10

    height, width, channels = image_np.shape

    fig, ax = plt.subplots(1)
    ax.imshow(image_np)

    for y in range(0, height - window_size[1] + 1, stride):
        for x in range(0, width - window_size[0] + 1, stride):
            cropped_window = image_np[y:y + window_size[1], x:x + window_size[0]]
            cropped_window_torch = transforms.ToTensor()(cropped_window).unsqueeze(0)

            with torch.no_grad():
                probabilities = model(cropped_window_torch)

            # if probability is greater than threshold, draw a bounding box
            if probabilities[0, 1].item() > threshold:
                rect = patches.Rectangle((x, y), window_size[0], window_size[1], linewidth=1, edgecolor='r',
                                         facecolor='none')
                ax.add_patch(rect)

    # Convert matplotlib figure to PIL Image
    fig.canvas.draw()
    img_arr = np.array(fig.canvas.renderer.buffer_rgba())
    plt.close(fig)

    img = Image.fromarray(img_arr)

    return img


def process_image(input_image):
    image = Image.fromarray(input_image).convert("RGB")

    start_time = time.time()
    detected_ships_image = scanmap(np.array(image), model)
    elapsed_time = time.time() - start_time

    return image, detected_ships_image, int(elapsed_time)

def gradio_process_image(input_image):
    original, detected_ships_image, elapsed_time = process_image(input_image)
    return original, detected_ships_image, f"Elapsed Time (seconds): {elapsed_time}"

inputs = gr.Image(label="Upload Image")
outputs = [
    gr.Image(label="Original Image"),
    gr.Image(label="Heatmap"),
    gr.Textbox(label="Elapsed Time")
]

iface = gr.Interface(fn=gradio_process_image, inputs=inputs, outputs=outputs)
iface.launch()