Spaces:
Sleeping
Sleeping
import cv2 | |
import numpy as np | |
import gradio as gr | |
from PIL import Image, ImageOps | |
import matplotlib.pyplot as plt | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torchvision import transforms | |
import os | |
# TODO: Define the model classes (Net and Net2) here; the loader below instantiates Net2.
# Loading model: if the checkpoint exists, build Net2, load its weights and put
# it in eval mode as the module-level `model` that process_image() uses.
model = None
model2 = None  # NOTE(review): not assigned or read anywhere in the visible code.
model2_path = "model4.pth"


def strip_module_prefix(state_dict):
    """Return a copy of *state_dict* with a leading ``"module."`` removed from keys.

    Presumably the checkpoint was saved from a model wrapped in
    ``nn.DataParallel``/``DistributedDataParallel``, which prefixes parameter
    names with ``"module."``. Only the *leading* prefix is removed, so keys that
    merely contain the text (e.g. ``"submodule.weight"``) are left intact.

    Args:
        state_dict: Mapping of parameter name to tensor (or any value).

    Returns:
        A new dict with the same values and de-prefixed keys.
    """
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


if os.path.exists(model2_path):
    # map_location forces all tensors onto the CPU regardless of where they
    # were saved. NOTE: torch.load unpickles the file — only load trusted files.
    state_dict = torch.load(model2_path, map_location=torch.device('cpu'))
    new_state_dict = strip_module_prefix(state_dict)
    model = Net2()  # Net2 is expected to be defined earlier in this file.
    model.load_state_dict(new_state_dict)
    model.eval()
else:
    print("Model file not found at", model2_path)
# TODO: Define scanmap(image_np, model) here; process_image() below calls it to build the heatmap.
def process_image(image: Image.Image):
    """Run ``scanmap`` on *image* and return the rendered heatmap and its runtime.

    Args:
        image: The uploaded image, as a PIL image or a NumPy array (Gradio's
            ``Image`` input delivers NumPy unless ``type="pil"`` is set). It is
            converted with ``np.array`` before scanning.

    Returns:
        ``(heatmap_img, elapsed_time)``: the heatmap coloured with matplotlib's
        ``hot`` colormap as an RGB PIL image resized to the input's width and
        height, and the seconds spent inside ``scanmap``.

    Raises:
        ValueError: If no image was supplied.
        RuntimeError: If the module-level ``model`` was not loaded.
    """
    import time  # `time` is not imported at the top of this file.

    if image is None:
        raise ValueError("No image was uploaded.")
    if model is None:
        raise RuntimeError("Model is not loaded; cannot compute the heatmap.")

    image_np = np.array(image)
    # Take the target size from the array so this works for PIL and NumPy
    # inputs alike (ndarray.size is the element count, not (width, height)).
    height, width = image_np.shape[:2]

    start_time = time.perf_counter()  # monotonic clock, suited to durations
    heatmap = scanmap(image_np, model)
    elapsed_time = time.perf_counter() - start_time

    # plt.cm.hot returns RGBA floats in [0, 1]; assumes scanmap yields a 2-D
    # float map normalised to [0, 1] — TODO confirm (out-of-range values map
    # to the colormap's end colours). Alpha is dropped by convert('RGB').
    heatmap_img = Image.fromarray(np.uint8(plt.cm.hot(heatmap) * 255)).convert('RGB')
    heatmap_img = heatmap_img.resize((width, height))
    return heatmap_img, elapsed_time
# Gradio UI: one uploaded image in; the heatmap image and the scan time out.
# gr.inputs / gr.outputs were deprecated in Gradio 3.x and removed in 4.0; the
# top-level components below work in both. type="pil" hands process_image a
# PIL image, matching its annotation.
inputs = gr.Image(type="pil", label="Upload Image")
outputs = [
    gr.Image(label="Heatmap"),
    gr.Textbox(label="Elapsed Time (seconds)"),
]
iface = gr.Interface(fn=process_image, inputs=inputs, outputs=outputs, title="ShipNet Heatmap")
iface.launch()