import io
import logging
from typing import Optional

import cv2
import numpy as np
import torch
from PIL import Image
from fastapi import FastAPI, UploadFile, File, HTTPException
# gradio has no `Gradio` export; the public API is gr.Interface / gr.Image.
import gradio as gr
from starlette.responses import StreamingResponse

from TranSalNet_Res import TranSalNet
# NOTE(review): io, UploadFile, File, HTTPException, StreamingResponse,
# preprocess_img and postprocess_img appear unused in this file as shown.
from utils.data_process import preprocess_img, postprocess_img

logger = logging.getLogger(__name__)

app = FastAPI()

device = torch.device('cpu')
model = TranSalNet()
model.load_state_dict(torch.load('pretrained_models/TranSalNet_Res.pth', map_location=device))
model.to(device)
model.eval()

# TranSalNet input size as (width, height) -- the order PIL's resize() expects.
MODEL_INPUT_SIZE = (384, 288)
# File process_image() writes its result to by default (None disables saving).
DEFAULT_SAVE_PATH = 'example/result15.png'
# Heatmap weight in the final blend (0 = enhanced image only, 1 = heatmap only).
HEATMAP_ALPHA = 0.7


def _contour_centroid(contour):
    """Return the integer (x, y) centroid of an OpenCV contour."""
    m = cv2.moments(contour)
    if m["m00"] != 0:
        return int(m["m10"] / m["m00"]), int(m["m01"] / m["m00"])
    # Degenerate (zero-area) contour such as a single pixel or a 1-px line:
    # use the mean of its vertices rather than dropping the label at (0, 0).
    cx, cy = contour.reshape(-1, 2).mean(axis=0)
    return int(cx), int(cy)


def count_and_label_red_patches(heatmap, threshold=200, image=None):
    """Find strongly red regions of a heatmap and draw numbered markers on them.

    Args:
        heatmap: BGR uint8 heatmap (e.g. ``cv2.applyColorMap`` output).
            Channel 2 -- red in OpenCV's BGR order -- is thresholded.
        threshold: A pixel belongs to a patch when its red channel is
            strictly greater than this value.
        image: Image to draw on (PIL image or ndarray with the same height
            and width as *heatmap*). A copy is drawn on, so the argument is
            not modified. Defaults to a copy of *heatmap*.

    Returns:
        ``(annotated, count)``: the annotated copy and the number of patches.
        Patches are numbered from 1 in order of decreasing contour area and
        consecutive centroids are joined with black lines.
    """
    red_mask = (heatmap[:, :, 2] > threshold).astype(np.uint8)
    contours, _ = cv2.findContours(red_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    contours = sorted(contours, key=cv2.contourArea, reverse=True)

    # np.array copies both PIL images and ndarrays, so drawing is side-effect free.
    canvas = np.array(heatmap if image is None else image)
    centroids = [_contour_centroid(contour) for contour in contours]

    for label, (cx, cy) in enumerate(centroids, start=1):
        cv2.circle(canvas, (cx, cy), 20, (0, 0, 0), -1)
        cv2.putText(canvas, str(label), (cx - 10, cy + 10), cv2.FONT_HERSHEY_SIMPLEX,
                    1, (255, 255, 255), 2, cv2.LINE_AA)

    for start_point, end_point in zip(centroids, centroids[1:]):
        cv2.line(canvas, start_point, end_point, (0, 0, 0), 2)

    return canvas, len(contours)


def _predict_saliency(image: Image.Image) -> np.ndarray:
    """Run TranSalNet on an RGB PIL image and return the squeezed prediction.

    The image is resized to MODEL_INPUT_SIZE, converted to BGR, scaled to
    [0, 1] and laid out as a (1, C, H, W) float32 tensor before inference.
    """
    resized = np.array(image.resize(MODEL_INPUT_SIZE))
    bgr = cv2.cvtColor(resized, cv2.COLOR_RGB2BGR) / 255.0
    batch = np.expand_dims(np.transpose(bgr, (2, 0, 1)), axis=0)
    tensor = torch.from_numpy(batch).float().to(device)
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        prediction = model(tensor)
    return prediction.squeeze().cpu().numpy()


def _saliency_to_heatmap(saliency: np.ndarray, size) -> np.ndarray:
    """Colour-map *saliency* with JET and resize it to ``size=(width, height)``.

    Values are scaled by 255 and clipped to [0, 255] so predictions slightly
    outside [0, 1] saturate instead of wrapping around in the uint8 cast.
    Returns a BGR uint8 image (OpenCV's colour-map channel order).
    """
    scaled = np.clip(saliency * 255, 0, 255).astype(np.uint8)
    heatmap = cv2.applyColorMap(scaled, cv2.COLORMAP_JET)
    return cv2.resize(heatmap, size)


def _enhance_blue_channel(image_bgr: np.ndarray) -> np.ndarray:
    """Apply CLAHE (clip limit 2.0, 8x8 tiles) to the blue channel of a BGR image."""
    blue, green, red = cv2.split(image_bgr)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    return cv2.merge((clahe.apply(blue), green, red))


def process_image(image: Image.Image, save_path: Optional[str] = DEFAULT_SAVE_PATH) -> np.ndarray:
    """Overlay a TranSalNet saliency heatmap on *image*.

    Args:
        image: Input PIL image. It is converted to RGB first so RGBA or
            greyscale inputs do not break the 3-channel OpenCV calls.
        save_path: File the blended result is written to via ``cv2.imwrite``
            (BGR, as imwrite expects). ``None`` skips saving. With the
            default, every call overwrites the same file.

    Returns:
        The blended image as an RGB uint8 array of the input's size.
    """
    image = image.convert('RGB')
    size = (image.width, image.height)

    heatmap_bgr = _saliency_to_heatmap(_predict_saliency(image), size)
    # Keep everything in BGR so cv2.split/merge and imwrite see the channel
    # order they assume; blending RGB pixels with the BGR heatmap would swap
    # red and blue in the result.
    enhanced_bgr = _enhance_blue_channel(cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR))
    blended_bgr = cv2.addWeighted(enhanced_bgr, 1 - HEATMAP_ALPHA, heatmap_bgr, HEATMAP_ALPHA, 0)

    # NOTE(review): the annotated image and patch count are computed but not
    # returned or displayed; wire them into the output if they are wanted.
    _annotated, _num_red_patches = count_and_label_red_patches(heatmap_bgr, image=image)

    # Save processed image (optional). imwrite signals failure (e.g. a missing
    # directory) only through its return value.
    if save_path is not None and not cv2.imwrite(save_path, blended_bgr):
        logger.warning("Could not write processed image to %s", save_path)

    # Gradio/PIL consumers expect RGB.
    return cv2.cvtColor(blended_bgr, cv2.COLOR_BGR2RGB)


def gr_process_image(input_image):
    """Gradio callback: numpy image in, RGB numpy saliency overlay out."""
    if input_image is None:
        # Submit pressed with no image: show a UI message instead of a traceback.
        raise gr.Error("Please upload an image.")
    return process_image(Image.fromarray(input_image))


iface = gr.Interface(
    fn=gr_process_image,
    inputs=gr.Image(type='numpy'),
    # `type` must be passed by keyword: gr.Image's first positional parameter
    # is the initial value, so gr.Image("numpy") would not set the output type.
    outputs=gr.Image(type='numpy'),
)

if __name__ == '__main__':
    # launch() blocks the main thread by default, so only start the UI when run
    # as a script, not when the module is imported (e.g. to serve `app`).
    # share=True additionally exposes a public share link.
    iface.launch(share=True)