|
import io
import os

import cv2
import gradio as gr
import numpy as np
import torch
from fastapi import FastAPI, UploadFile, File, HTTPException
from gradio import Image as GImage
from PIL import Image
from starlette.responses import StreamingResponse

from TranSalNet_Res import TranSalNet
from utils.data_process import preprocess_img, postprocess_img
|
|
|
# FastAPI application object.  No routes are registered in this chunk;
# presumably endpoints live elsewhere or are planned — TODO confirm.
app = FastAPI()

# Run inference on CPU; map_location below keeps CUDA-trained weights loadable.
device = torch.device('cpu')

# TranSalNet saliency-prediction model (ResNet-based pretrained weights).
model = TranSalNet()

# NOTE(review): torch.load without weights_only=True unpickles arbitrary
# objects — only load checkpoints from a trusted source.
model.load_state_dict(torch.load('pretrained_models/TranSalNet_Res.pth', map_location=torch.device('cpu')))

model.to(device)

# Inference mode: disables dropout / batch-norm running-stat updates.
model.eval()
|
|
|
|
|
def count_and_label_red_patches(heatmap, threshold=200, image=None):
    """Locate red blobs in a BGR heatmap, annotate their centroids, count them.

    Connected regions whose red channel exceeds ``threshold`` are detected,
    ranked by area (largest first), marked with a filled black disc plus a
    white index label at each centroid, and consecutive centroids are joined
    by black lines.

    Parameters
    ----------
    heatmap : np.ndarray
        BGR uint8 image (e.g. output of ``cv2.applyColorMap``); channel 2 is
        treated as the red channel.
    threshold : int, optional
        Minimum red-channel value (0-255) for a pixel to count as "red".
    image : array-like or None, optional
        Image to draw annotations on.  Defaults to a copy of ``heatmap``.
        (The original code referenced an undefined global ``image`` here,
        which raised NameError at runtime; this parameter fixes that while
        staying call-compatible.)

    Returns
    -------
    tuple[np.ndarray, int]
        The annotated image and the number of red patches found.
    """
    red_mask = heatmap[:, :, 2] > threshold
    contours, _ = cv2.findContours(red_mask.astype(np.uint8),
                                   cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Give the largest patches the smallest labels.
    contours = sorted(contours, key=cv2.contourArea, reverse=True)

    # Bug fix: the original read an undefined global ``image``; fall back to
    # a copy of the heatmap so the function is self-contained.
    annotated = np.array(image) if image is not None else heatmap.copy()

    centroid_list = []

    for i, contour in enumerate(contours, start=1):
        # Centroid from image moments; zero-area contours fall back to (0, 0).
        M = cv2.moments(contour)
        if M["m00"] != 0:
            cX = int(M["m10"] / M["m00"])
            cY = int(M["m01"] / M["m00"])
        else:
            cX, cY = 0, 0

        # Filled black disc as a backdrop for the white index label.
        radius = 20
        cv2.circle(annotated, (cX, cY), radius, (0, 0, 0), -1)
        cv2.putText(annotated, str(i), (cX - 10, cY + 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2,
                    cv2.LINE_AA)

        centroid_list.append((cX, cY))

    # Join consecutive centroids to visualise the ranked scan path.
    for start_point, end_point in zip(centroid_list, centroid_list[1:]):
        cv2.line(annotated, start_point, end_point, (0, 0, 0), 2)

    return annotated, len(contours)
|
|
|
|
|
def process_image(image: Image.Image) -> np.ndarray:
    """Predict a saliency map for ``image`` and overlay it on the photo.

    Runs TranSalNet on a 384x288 resize of the input, colorizes the predicted
    saliency with the JET colormap, blends it (70% heatmap) over a CLAHE
    contrast-enhanced copy of the original, saves the result to
    ``example/result15.png``, and returns the blended image.

    Parameters
    ----------
    image : PIL.Image.Image
        Input photo; converted to 3-channel RGB internally.

    Returns
    -------
    np.ndarray
        The blended overlay as an RGB uint8 array (height x width x 3).
    """
    # Force 3-channel RGB so the channel split below cannot fail on RGBA or
    # grayscale inputs.
    image = image.convert('RGB')

    # Model input: 384x288, BGR, float in [0, 1], NCHW.
    img = np.array(image.resize((384, 288)))
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    img = np.array(img) / 255.
    img = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
    img = torch.from_numpy(img).type(torch.FloatTensor).to(device)

    # no_grad: inference only — avoid building an autograd graph.
    with torch.no_grad():
        pred_saliency = model(img).squeeze().cpu().numpy()

    # Saliency -> JET colormap (BGR), scaled back to the original size.
    # NOTE(review): assumes the model output is already in [0, 1] — confirm.
    heatmap = (pred_saliency * 255).astype(np.uint8)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    heatmap = cv2.resize(heatmap, (image.width, image.height))

    # Bug fix: np.array(PIL image) is RGB, but the original split it with BGR
    # variable names and blended it with a BGR heatmap, swapping red/blue.
    # Convert to BGR first so CLAHE hits the intended blue channel and the
    # blend is colorspace-consistent.
    enhanced_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    b, g, r = cv2.split(enhanced_image)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    enhanced_image = cv2.merge((clahe.apply(b), g, r))

    alpha = 0.7  # heatmap opacity in the overlay
    blended_img = cv2.addWeighted(enhanced_image, 1 - alpha, heatmap, alpha, 0)

    # Annotated patch image is currently unused; only the count survives, and
    # even that is not returned — kept for parity with the original code.
    original_image, num_red_patches = count_and_label_red_patches(heatmap)

    # Bug fix: the original passed IMWRITE_JPEG_QUALITY (with an out-of-range
    # value of 200) while writing a .png; use the PNG compression flag and
    # make sure the output directory exists.
    os.makedirs('example', exist_ok=True)
    cv2.imwrite('example/result15.png', blended_img,
                [int(cv2.IMWRITE_PNG_COMPRESSION), 3])

    # Return RGB for downstream consumers (Gradio expects RGB numpy arrays);
    # the on-disk copy above stays BGR as cv2.imwrite requires.
    return cv2.cvtColor(blended_img, cv2.COLOR_BGR2RGB)
|
|
|
|
|
def gr_process_image(input_image):
    """Gradio adapter: wrap the incoming numpy frame in a PIL image and run
    the saliency pipeline on it, returning the blended overlay array."""
    pil_image = Image.fromarray(input_image)
    return process_image(pil_image)
|
|
|
|
|
# Gradio UI wired to the saliency pipeline.
# Bug fix: the gradio package exposes no ``Gradio`` class — interfaces are
# built with ``gr.Interface`` and image components with ``gr.Image``.  The
# original also passed "numpy" positionally to the output component, which is
# the ``value`` slot, not ``type``.
iface = gr.Interface(
    fn=gr_process_image,
    inputs=gr.Image(),
    outputs=gr.Image(type="numpy"),
)

if __name__ == "__main__":
    # Guarded so importing this module (e.g. to serve ``app`` with uvicorn)
    # does not block on the Gradio server at import time.
    iface.launch(share=True)
|
|