# VisAt / app.py
# Last commit: 3b4610d "Update app.py" (author: udayzee05)
import cv2
import io
import numpy as np
import torch
from PIL import Image
from fastapi import FastAPI, UploadFile, File, HTTPException
from gradio import Gradio, Image as GImage
from starlette.responses import StreamingResponse
from TranSalNet_Res import TranSalNet
from utils.data_process import preprocess_img, postprocess_img
app = FastAPI()

# All inference in this module runs on this device.
device = torch.device('cpu')

# Path of the pretrained TranSalNet (ResNet backbone) weights.
CHECKPOINT_PATH = 'pretrained_models/TranSalNet_Res.pth'

model = TranSalNet()
# Map the checkpoint onto `device` itself, so changing `device` above is
# enough to move both loading and inference.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
model.to(device)
# Inference mode: layers such as dropout / batch-norm switch to eval behaviour.
model.eval()
def _contour_centroid(contour):
    """Return the integer (x, y) centroid of an OpenCV contour.

    Uses image moments. A degenerate contour with zero area (a single point
    or a straight line) has m00 == 0. For those, the mean of the contour's
    points is used rather than the image origin.
    """
    moments = cv2.moments(contour)
    if moments["m00"] != 0:
        return (int(moments["m10"] / moments["m00"]),
                int(moments["m01"] / moments["m00"]))
    x, y = contour.reshape(-1, 2).mean(axis=0)
    return int(x), int(y)


def count_and_label_red_patches(heatmap, threshold=200, image=None):
    """Find the strongly "red" regions of a heatmap, number them and link them.

    Pixels whose channel-2 value is greater than ``threshold`` count as red.
    Channel 2 is red for the BGR output of ``cv2.applyColorMap``.

    Each connected red region is found as an external contour, and the
    regions are ranked from largest to smallest area. At each region's
    centroid, the function draws a filled black circle with the region's
    1-based rank written in white. Consecutive centroids, in rank order, are
    then joined with black lines.

    Args:
        heatmap: H x W x 3 uint8 array. Only channel 2 is thresholded.
        threshold: Channel-2 values strictly above this are treated as red.
        image: Picture to draw the annotations on: a PIL image or array with
            the same height and width as ``heatmap``. It is copied, never
            modified in place. When ``None``, a copy of ``heatmap`` is
            annotated instead.

    Returns:
        ``(annotated_image, patch_count)``. ``annotated_image`` is a new
        numpy array, and ``patch_count`` is the number of red regions found.
    """
    red_mask = (heatmap[:, :, 2] > threshold).astype(np.uint8)
    contours, _ = cv2.findContours(red_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    contours = sorted(contours, key=cv2.contourArea, reverse=True)

    # np.array copies, so the caller's image/heatmap is never drawn on.
    canvas = np.array(heatmap if image is None else image)

    # Drawing style, shared by every label.
    radius = 20
    marker_color = (0, 0, 0)
    text_color = (255, 255, 255)
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 1
    line_type = cv2.LINE_AA

    centroids = []
    for rank, contour in enumerate(contours, start=1):
        cx, cy = _contour_centroid(contour)
        cv2.circle(canvas, (cx, cy), radius, marker_color, -1)
        # Offset the text so the rank number sits roughly centred in the circle.
        cv2.putText(canvas, str(rank), (cx - 10, cy + 10), font, font_scale,
                    text_color, 2, line_type)
        centroids.append((cx, cy))

    # Join each centroid to the next one in rank order.
    for start_point, end_point in zip(centroids, centroids[1:]):
        cv2.line(canvas, start_point, end_point, marker_color, 2)

    return canvas, len(contours)
def _saliency_heatmap(image):
    """Predict saliency for an RGB PIL *image*; return a BGR JET heatmap of its size."""
    # Model input preprocessing: resize to 384x288 (PIL takes width, height),
    # RGB -> BGR, scale to [0, 1], then HWC -> 1 x C x H x W float32.
    small = cv2.cvtColor(np.array(image.resize((384, 288))), cv2.COLOR_RGB2BGR)
    tensor = torch.from_numpy(small.transpose(2, 0, 1) / 255.0).float().unsqueeze(0).to(device)

    with torch.no_grad():  # inference only: don't build an autograd graph
        saliency = model(tensor).squeeze().cpu().numpy()

    # Clip before the uint8 cast so out-of-range values saturate instead of
    # wrapping around.
    gray = (np.clip(saliency, 0.0, 1.0) * 255).astype(np.uint8)
    heatmap = cv2.applyColorMap(gray, cv2.COLORMAP_JET)  # BGR output
    return cv2.resize(heatmap, (image.width, image.height))  # dsize is (w, h)


def _enhance_blue_channel(bgr):
    """Return a copy of a BGR uint8 image with CLAHE applied to its blue channel."""
    blue, green, red = cv2.split(bgr)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    return cv2.merge((clahe.apply(blue), green, red))


def process_image(image: Image.Image, save_path='example/result15.png') -> np.ndarray:
    """Overlay a TranSalNet saliency heatmap on a contrast-enhanced *image*.

    The steps are:

    1. Predict a saliency map and colour it with the JET colormap at the
       input's size.
    2. Apply CLAHE to the blue channel of the input.
    3. Alpha-blend the two, with 70% heatmap and 30% image.

    Blending happens in BGR, the order OpenCV uses, so the image and heatmap
    channels line up.

    Args:
        image: Input picture. It is converted to RGB first, so RGBA, palette
            and greyscale inputs are accepted.
        save_path: If not ``None``, the blended result is also written there
            with ``cv2.imwrite``. Write failures (for example, a missing
            directory) are only reported through imwrite's return value,
            which is ignored because saving is best-effort.

    Returns:
        The blended image as an H x W x 3 uint8 array in RGB order, the same
        size as *image*. This is the order Gradio's numpy image components
        expect.
    """
    image = image.convert('RGB')

    heatmap_bgr = _saliency_heatmap(image)
    enhanced_bgr = _enhance_blue_channel(cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR))

    alpha = 0.7
    blended_bgr = cv2.addWeighted(enhanced_bgr, 1 - alpha, heatmap_bgr, alpha, 0)

    if save_path is not None:
        # No encoder params: JPEG-quality flags have no effect on PNG output.
        cv2.imwrite(save_path, blended_bgr)

    return cv2.cvtColor(blended_bgr, cv2.COLOR_BGR2RGB)
def gr_process_image(input_image):
    """Gradio callback: run ``process_image`` on the submitted picture.

    Args:
        input_image: The picture as an array accepted by ``Image.fromarray``,
            or ``None`` when the input component holds no image.

    Returns:
        The processed image array, or ``None`` when no image was supplied.
    """
    if input_image is None:
        # Gradio passes None for an empty image input. Clear the output
        # rather than failing inside Image.fromarray.
        return None
    return process_image(Image.fromarray(input_image))
import gradio as gr  # gradio's Interface lives at gr.Interface; it exports no `Gradio` class

# Web UI: one image in, one image out, both exchanged as numpy arrays.
# gr_process_image consumes arrays, and process_image returns an RGB array.
# `type` must be passed by keyword: gr.Image's first positional parameter
# is the component's default value, not its type.
iface = gr.Interface(
    fn=gr_process_image,
    inputs=gr.Image(type="numpy"),
    outputs=gr.Image(type="numpy"),
)
iface.launch(share=True)