File size: 3,829 Bytes
8395863 8271835 e4b253e 8271835 36945ed fc1d3e9 e4b253e fc1d3e9 36945ed 8395863 fc1d3e9 8395863 660142e cb374b8 3a6ce4c cb374b8 3489552 cb374b8 3a6ce4c cb374b8 3489552 3a6ce4c cb374b8 660142e e4b253e 3489552 e4b253e 8395863 e4b253e 8395863 e4b253e 8395863 e4b253e 8395863 e4b253e 8395863 e4b253e 776dd3c e4b253e 8271835 e4b253e 3a6ce4c e4b253e 621f2db e4b253e 8271835 e4b253e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
import io

import cv2
import numpy as np
import torch
from fastapi import FastAPI, UploadFile, File
from fastapi import HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse
from PIL import Image

from TranSalNet_Res import TranSalNet
from utils.data_process import preprocess_img, postprocess_img
# FastAPI application that serves the /process_image endpoint.
app = FastAPI()

# Single device used for inference; the checkpoint is mapped straight onto it so
# the stored weights and the model always live in the same place.
device = torch.device('cpu')

# TranSalNet with pretrained weights, loaded once at import time and put in
# eval mode because this module only runs inference.
model = TranSalNet()
model.load_state_dict(torch.load('pretrained_models/TranSalNet_Res.pth', map_location=device))
model.to(device)
model.eval()
def count_and_label_red_patches(heatmap, threshold=200, image=None):
    """Find strongly red regions of a BGR heatmap and draw a numbered path over them.

    Each external contour of the mask ``heatmap[:, :, 2] > threshold`` counts as
    one patch. Patches are numbered 1..N from largest to smallest contour area.
    A filled black circle containing the white label is drawn at each patch
    centre, and consecutive centres are joined by black line segments.

    Args:
        heatmap: ``(H, W, 3)`` array in BGR channel order (e.g. the output of
            ``cv2.applyColorMap``), so channel 2 is red.
        threshold: A pixel belongs to a patch when its red value is strictly
            greater than this.
        image: Optional image (PIL image or array) to draw the annotations on.
            Presumably the same height/width as ``heatmap`` so the centres line
            up -- TODO confirm at call sites. It is copied, never modified.
            When omitted, the annotations are drawn on a copy of ``heatmap``.

    Returns:
        ``(annotated_image, patch_count)``: the annotated copy and the number of
        contours found.
    """
    red_mask = (heatmap[:, :, 2] > threshold).astype(np.uint8)
    # [-2] picks the contour list under both the OpenCV 3 (3-tuple) and
    # OpenCV 4 (2-tuple) return conventions of findContours.
    contours = cv2.findContours(red_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    # Largest patch gets label 1.
    contours = sorted(contours, key=cv2.contourArea, reverse=True)

    # Draw on a C-contiguous copy so the caller's data is never mutated.
    canvas = np.array(heatmap if image is None else image, order='C')

    # Drawing style, shared by every patch.
    radius = 20  # sized to fit the label text
    circle_color = (0, 0, 0)  # black
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 1
    font_color = (255, 255, 255)  # white
    text_thickness = 2
    line_color = (0, 0, 0)  # black
    line_thickness = 2

    centroid_list = []  # patch centres, in label order
    for label, contour in enumerate(contours, start=1):
        m = cv2.moments(contour)
        if m["m00"] != 0:
            cX = int(m["m10"] / m["m00"])
            cY = int(m["m01"] / m["m00"])
        else:
            # Zero-area contour (a thin line or single pixel): use its
            # bounding-box centre instead of pinning the label to (0, 0).
            x, y, w, h = cv2.boundingRect(contour)
            cX, cY = x + w // 2, y + h // 2
        cv2.circle(canvas, (cX, cY), radius, circle_color, -1)  # -1 = filled
        cv2.putText(canvas, str(label), (cX - 10, cY + 10), font, font_scale,
                    font_color, text_thickness, cv2.LINE_AA)
        centroid_list.append((cX, cY))

    # Connect the patches in label order: 1 -> 2 -> ... -> N.
    for start_point, end_point in zip(centroid_list, centroid_list[1:]):
        cv2.line(canvas, start_point, end_point, line_color, line_thickness)

    return canvas, len(contours)
def process_image(image: Image.Image, *, alpha: float = 0.7,
                  save_path: "str | None" = 'example/result15.png') -> np.ndarray:
    """Overlay a TranSalNet saliency heatmap on *image*.

    Pipeline:
    1. The image is resized to 384x288 and run through the module-level ``model``.
    2. The predicted saliency map is colour-mapped with JET and resized back to
       the input size.
    3. The result is alpha-blended over a copy of the input whose blue channel
       has been contrast-enhanced with CLAHE.

    Args:
        image: Input PIL image of any mode; it is converted to RGB first.
        alpha: Weight of the heatmap in the blend (the image gets ``1 - alpha``).
        save_path: If not None, the blended result is also written to this path
            with ``cv2.imwrite``; its return value is not checked.

    Returns:
        The blended image as an ``(H, W, 3)`` uint8 array in BGR channel order,
        with the same width and height as *image*.
    """
    # Normalise L / P / RGBA / ... inputs to 3-channel RGB so the channel split
    # below always yields exactly three planes.
    rgb_image = image.convert('RGB')

    # --- Saliency prediction ---------------------------------------------
    model_input = np.array(rgb_image.resize((384, 288)))
    model_input = cv2.cvtColor(model_input, cv2.COLOR_RGB2BGR) / 255.0  # model expects BGR in [0, 1]
    model_input = np.expand_dims(np.transpose(model_input, (2, 0, 1)), axis=0)  # HWC -> 1xCxHxW
    tensor = torch.from_numpy(model_input).float().to(device)
    with torch.no_grad():  # inference only: don't build an autograd graph
        pred_saliency = model(tensor).squeeze().cpu().numpy()

    # Clip before the uint8 cast so out-of-range predictions saturate instead
    # of wrapping around.
    heatmap = np.clip(pred_saliency * 255, 0, 255).astype(np.uint8)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)  # BGR; low = blue, high = red
    heatmap = cv2.resize(heatmap, (rgb_image.width, rgb_image.height))

    # --- Blue-channel contrast enhancement ---------------------------------
    # PIL arrays are RGB; convert so the split really yields B, G, R. This also
    # matches the BGR heatmap in the blend, and cv2 encoders expect BGR.
    enhanced_image = cv2.cvtColor(np.array(rgb_image), cv2.COLOR_RGB2BGR)
    b, g, r = cv2.split(enhanced_image)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    enhanced_image = cv2.merge((clahe.apply(b), g, r))

    blended_img = cv2.addWeighted(enhanced_image, 1 - alpha, heatmap, alpha, 0)

    # NOTE: patch labelling (count_and_label_red_patches) does not contribute to
    # the returned image, so it is not computed here.

    # Optionally persist the result. JPEG quality flags do not apply to PNG
    # output, so no encoder parameters are passed.
    if save_path is not None:
        cv2.imwrite(save_path, blended_img)

    return blended_img
@app.post("/process_image")
async def process_uploaded_image(file: UploadFile = File(...)):
    """Return the uploaded image with its saliency heatmap overlaid, as a PNG.

    Responds with HTTP 400 if the upload cannot be read or decoded as an image.
    Responds with HTTP 500 if processing or PNG encoding fails.
    """
    try:
        contents = await file.read()
        image = Image.open(io.BytesIO(contents))
        # Image.open is lazy; force the full decode here so truncated or corrupt
        # uploads are reported as client errors (400), not as 500s later.
        image.load()
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error opening image: {str(e)}") from e

    try:
        # Model inference is CPU-bound; run it in a worker thread so the event
        # loop keeps serving other requests meanwhile.
        processed_image = await run_in_threadpool(process_image, image)
        ok, png = cv2.imencode('.png', processed_image)
        if not ok:
            raise RuntimeError("PNG encoding failed")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}") from e

    return StreamingResponse(io.BytesIO(png.tobytes()), media_type="image/png")
|