File size: 3,829 Bytes
8395863
 
8271835
e4b253e
8271835
36945ed
fc1d3e9
 
e4b253e
 
 
fc1d3e9
36945ed
8395863
 
fc1d3e9
8395863
660142e
 
cb374b8
 
 
 
 
3a6ce4c
cb374b8
3489552
cb374b8
 
 
 
 
 
 
 
 
 
 
 
 
 
3a6ce4c
 
 
 
cb374b8
3489552
 
 
 
 
 
 
 
 
3a6ce4c
cb374b8
660142e
e4b253e
 
 
 
 
 
 
 
3489552
e4b253e
8395863
e4b253e
 
8395863
e4b253e
8395863
e4b253e
 
 
 
 
8395863
e4b253e
 
8395863
e4b253e
776dd3c
e4b253e
 
8271835
e4b253e
3a6ce4c
e4b253e
 
 
 
 
 
 
621f2db
e4b253e
 
 
8271835
e4b253e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import io

import cv2
import numpy as np
import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.responses import StreamingResponse
from PIL import Image

from TranSalNet_Res import TranSalNet
from utils.data_process import preprocess_img, postprocess_img


app = FastAPI()

device = torch.device('cpu')
model = TranSalNet()
model.load_state_dict(torch.load('pretrained_models/TranSalNet_Res.pth', map_location=torch.device('cpu')))
model.to(device)
model.eval()

def count_and_label_red_patches(heatmap, threshold=200):
    red_mask = heatmap[:, :, 2] > threshold
    contours, _ = cv2.findContours(red_mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    
    # Sort the contours based on their areas in descending order
    contours = sorted(contours, key=cv2.contourArea, reverse=True)
    
    original_image = np.array(image)
    
    centroid_list = []  # List to store the centroids of the contours in order
    
    for i, contour in enumerate(contours, start=1):
        # Compute the centroid of the current contour
        M = cv2.moments(contour)
        if M["m00"] != 0:
            cX = int(M["m10"] / M["m00"])
            cY = int(M["m01"] / M["m00"])
        else:
            cX, cY = 0, 0
        
        radius = 20  # Adjust the circle radius to fit the numbers
        circle_color = (0, 0, 0)  # Blue color
        cv2.circle(original_image, (cX, cY), radius, circle_color, -1)  # Draw blue circle

        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 1
        font_color = (255, 255, 255)
        line_type = cv2.LINE_AA
        cv2.putText(original_image, str(i), (cX - 10, cY + 10), font, font_scale, font_color, 2, line_type)
        
        centroid_list.append((cX, cY))  # Add the centroid to the list

    # Connect the red spots in the desired order
    for i in range(len(centroid_list) - 1):
        start_point = centroid_list[i]
        end_point = centroid_list[i + 1]
        line_color = (0, 0, 0)  # Red color
        cv2.line(original_image, start_point, end_point, line_color, 2)

    return original_image, len(contours)

def process_image(image: Image.Image) -> np.ndarray:
    img = image.resize((384, 288))
    img = np.array(img)
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)  # Convert to BGR color space
    img = np.array(img) / 255.
    img = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
    img = torch.from_numpy(img)
    img = img.type(torch.FloatTensor).to(device)

    pred_saliency = model(img).squeeze().detach().numpy()

    heatmap = (pred_saliency * 255).astype(np.uint8)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)  # Use a blue colormap (JET)

    heatmap = cv2.resize(heatmap, (image.width, image.height))

    enhanced_image = np.array(image)
    b, g, r = cv2.split(enhanced_image)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    b_enhanced = clahe.apply(b)
    enhanced_image = cv2.merge((b_enhanced, g, r))

    alpha = 0.7
    blended_img = cv2.addWeighted(enhanced_image, 1 - alpha, heatmap, alpha, 0)

    original_image, num_red_patches = count_and_label_red_patches(heatmap)

    # Save processed image (optional)
    cv2.imwrite('example/result15.png', blended_img, [int(cv2.IMWRITE_JPEG_QUALITY), 200])

    return blended_img

@app.post("/process_image")
async def process_uploaded_image(file: UploadFile = File(...)):
    try:
        contents = await file.read()
        image = Image.open(io.BytesIO(contents))
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error opening image: {str(e)}")

    try:
        processed_image = process_image(image)
        return StreamingResponse(io.BytesIO(cv2.imencode('.png', processed_image)[1].tobytes()), media_type="image/png")

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")