import cv2
import io
import numpy as np
import torch
from PIL import Image
from fastapi import FastAPI, UploadFile, File, HTTPException
import gradio as gr
from starlette.responses import StreamingResponse
from TranSalNet_Res import TranSalNet
from utils.data_process import preprocess_img, postprocess_img
app = FastAPI()
# Run on CPU; load the pretrained TranSalNet (ResNet backbone) weights.
device = torch.device('cpu')
model = TranSalNet()
model.load_state_dict(torch.load('pretrained_models/TranSalNet_Res.pth', map_location=device))
model.to(device)
model.eval()
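# Optional smoke test (an assumption, not part of the original pipeline): push
# one zero tensor through the network to confirm the weights loaded and the CPU
# forward pass works. TranSalNet takes NCHW input at 288x384 (H x W), matching
# the resize in process_image below.
with torch.no_grad():
    _ = model(torch.zeros(1, 3, 288, 384, device=device))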
def count_and_label_red_patches(heatmap, original_image, threshold=200):
    """Find high-saliency (red) regions in the BGR heatmap, number them on a
    copy of the original image, and connect consecutive centroids with lines."""
    # In a BGR heatmap produced by cv2.applyColorMap, channel 2 is red.
    red_mask = heatmap[:, :, 2] > threshold
    contours, _ = cv2.findContours(red_mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    contours = sorted(contours, key=cv2.contourArea, reverse=True)

    annotated = original_image.copy()
    centroid_list = []
    for i, contour in enumerate(contours, start=1):
        M = cv2.moments(contour)
        if M["m00"] != 0:
            cX = int(M["m10"] / M["m00"])
            cY = int(M["m01"] / M["m00"])
        else:
            cX, cY = 0, 0

        # Mark each patch with a filled black disc and its rank (1 = largest).
        cv2.circle(annotated, (cX, cY), 20, (0, 0, 0), -1)
        cv2.putText(annotated, str(i), (cX - 10, cY + 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
        centroid_list.append((cX, cY))

    # Connect consecutive centroids in order of decreasing patch area.
    for start_point, end_point in zip(centroid_list, centroid_list[1:]):
        cv2.line(annotated, start_point, end_point, (0, 0, 0), 2)

    return annotated, len(contours)
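# Example usage (hypothetical arrays, for illustration only): given a JET
# heatmap and a same-sized BGR frame, the function returns the annotated frame
# and the number of red patches found:
#   annotated, n = count_and_label_red_patches(heatmap_bgr, frame_bgr, threshold=200)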
def process_image(image: Image.Image) -> np.ndarray:
    # TranSalNet_Res expects a 384x288 (width x height) input.
    img = image.resize((384, 288))
    img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
    img = np.float32(img) / 255.0
    img = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)  # HWC -> NCHW
    img = torch.from_numpy(img).to(device)

    with torch.no_grad():
        pred_saliency = model(img).squeeze().cpu().numpy()

    # Clip to [0, 1] before scaling so the uint8 conversion cannot wrap around.
    heatmap = (np.clip(pred_saliency, 0, 1) * 255).astype(np.uint8)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    heatmap = cv2.resize(heatmap, (image.width, image.height))

    # Work in BGR throughout so the channel split and the heatmap blend agree.
    original_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

    # Contrast-enhance the blue channel only with CLAHE.
    b, g, r = cv2.split(original_bgr)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    enhanced_image = cv2.merge((clahe.apply(b), g, r))

    alpha = 0.7
    blended_img = cv2.addWeighted(enhanced_image, 1 - alpha, heatmap, alpha, 0)

    # The annotated image and patch count are computed but not returned here.
    annotated_image, num_red_patches = count_and_label_red_patches(heatmap, original_bgr)

    # Save the processed image (optional); PNG ignores JPEG quality flags.
    cv2.imwrite('example/result15.png', blended_img)
    return blended_img
def gr_process_image(input_image):
    # Gradio hands us an RGB numpy array; process_image returns BGR for
    # OpenCV, so convert back to RGB before displaying.
    image = Image.fromarray(input_image)
    processed_image = process_image(image)
    return cv2.cvtColor(processed_image, cv2.COLOR_BGR2RGB)
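# A minimal sketch of an HTTP counterpart to the Gradio UI, using the FastAPI
# imports above. The route name "/saliency" and the PNG streaming are
# assumptions, not part of the original pipeline; to reach this endpoint, serve
# `app` with uvicorn instead of, or alongside, iface.launch() below.
@app.post("/saliency")
async def saliency(file: UploadFile = File(...)):
    try:
        image = Image.open(io.BytesIO(await file.read())).convert("RGB")
    except Exception:
        raise HTTPException(status_code=400, detail="File is not a readable image.")
    blended_bgr = process_image(image)
    ok, png = cv2.imencode(".png", blended_bgr)
    if not ok:
        raise HTTPException(status_code=500, detail="Failed to encode result.")
    return StreamingResponse(io.BytesIO(png.tobytes()), media_type="image/png")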
iface = gr.Interface(
    fn=gr_process_image,
    inputs=gr.Image(type="numpy"),
    outputs=gr.Image(type="numpy"),
)

iface.launch(share=True)