# shipnet / app.py
# Author: Mehmet Batuhan Duman
# Commit message: Add Gradio app and requirements
# Commit: b53b6d3
# (scraped page controls: raw, history blame; file size 1.4 kB)
import os
import time

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image, ImageOps
from torchvision import transforms
# Add your model classes (Net and Net2) here.
# Loading model
model = None
model2 = None
model2_path = "model4.pth"
if os.path.exists(model2_path):
state_dict = torch.load(model2_path, map_location=torch.device('cpu'))
new_state_dict = {}
for key, value in state_dict.items():
new_key = key.replace("module.", "")
new_state_dict[new_key] = value
model = Net2()
model.load_state_dict(new_state_dict)
model.eval()
else:
print("Model file not found at", model2_path)
# Add the scanmap function here.
def process_image(image: Image.Image):
    """Run ``scanmap`` on *image* and render its output as an RGB heatmap.

    Args:
        image: The uploaded image, either a ``PIL.Image.Image`` or a NumPy
            array (the legacy ``gr.inputs.Image`` default delivers an array).

    Returns:
        ``(heatmap_img, elapsed_time)``: the heatmap as an RGB ``PIL.Image``
        resized to the input's width/height, and the seconds spent in
        ``scanmap``.
    """
    image_np = np.array(image)
    # Derive (width, height) from the array shape: on an ndarray, ``.size`` is
    # the element count, not a (width, height) tuple as on a PIL image.
    height, width = image_np.shape[:2]

    start = time.perf_counter()  # monotonic clock for measuring durations
    heatmap = scanmap(image_np, model)
    elapsed_time = time.perf_counter() - start

    # plt.cm.hot maps floats to RGBA in [0, 1]; scale to 8-bit for PIL.
    # NOTE(review): assumes scanmap returns values normalised to [0, 1] -- confirm.
    rgba = plt.cm.hot(heatmap)
    heatmap_img = Image.fromarray(np.uint8(rgba * 255)).convert('RGB')
    heatmap_img = heatmap_img.resize((width, height))
    return heatmap_img, elapsed_time
# Gradio UI. type="pil" makes Gradio hand process_image a PIL.Image, matching
# its annotation; the legacy default ("numpy") would pass an ndarray instead.
# NOTE(review): gr.inputs / gr.outputs are deprecated in Gradio 3.x and removed
# in 4.0 -- migrate to gr.Image / gr.Textbox once the pinned version allows.
inputs = gr.inputs.Image(type="pil", label="Upload Image")
outputs = [
    gr.outputs.Image(label="Heatmap"),
    gr.outputs.Textbox(label="Elapsed Time (seconds)"),
]
iface = gr.Interface(fn=process_image, inputs=inputs, outputs=outputs, title="ShipNet Heatmap")
iface.launch()