File size: 3,189 Bytes
8395863
3b4610d
8395863
8271835
 
3b4610d
 
 
36945ed
fc1d3e9
 
e4b253e
 
fc1d3e9
36945ed
8395863
 
fc1d3e9
8395863
3b4610d
660142e
 
cb374b8
3b4610d
cb374b8
3b4610d
3a6ce4c
3b4610d
 
 
cb374b8
 
 
 
 
 
 
3b4610d
 
 
 
cb374b8
3a6ce4c
 
 
 
cb374b8
3489552
3b4610d
 
3489552
 
 
3b4610d
3489552
3a6ce4c
cb374b8
660142e
3b4610d
e4b253e
 
 
3b4610d
e4b253e
 
 
 
3489552
e4b253e
8395863
e4b253e
3b4610d
8395863
e4b253e
8395863
e4b253e
 
 
 
 
8395863
e4b253e
 
8395863
e4b253e
776dd3c
e4b253e
 
8271835
e4b253e
3a6ce4c
621f2db
3b4610d
 
 
 
 
 
 
 
 
 
 
8271835
3b4610d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import cv2
import io
import numpy as np
import torch
from PIL import Image
from fastapi import FastAPI, UploadFile, File, HTTPException
from gradio import Gradio, Image as GImage
from starlette.responses import StreamingResponse
from TranSalNet_Res import TranSalNet
from utils.data_process import preprocess_img, postprocess_img

# NOTE(review): ``app`` has no routes registered anywhere in this file; the
# demo below is served through Gradio instead.
app = FastAPI()

# All inference in this module runs on this device.
device = torch.device('cpu')

model = TranSalNet()
# Map the checkpoint tensors straight onto ``device`` (rather than a separately
# hard-coded CPU device) so the load target and the model's device can't drift.
model.load_state_dict(torch.load('pretrained_models/TranSalNet_Res.pth', map_location=device))
model.to(device)
# Inference mode: affects layers such as Dropout / BatchNorm.
model.eval()


def count_and_label_red_patches(heatmap, threshold=200, image=None):
    """Find strongly-red regions of a heatmap, number them, and draw them.

    Pixels whose channel-2 value exceeds ``threshold`` are grouped into
    external contours. Channel 2 is red for OpenCV BGR images such as the
    output of ``cv2.applyColorMap``. Contours are ranked by area, largest
    first. At each contour's centroid, a filled black circle is drawn with
    the contour's 1-based rank written in white. Consecutive centroids are
    then joined by black lines.

    Args:
        heatmap: H x W x 3 uint8 array; only channel 2 is thresholded.
        threshold: Channel-2 value a pixel must exceed to count as "red".
        image: Picture to annotate: a PIL image or anything ``np.array``
            accepts, presumably the same size as ``heatmap``. A copy is drawn
            on, so the caller's object is left untouched. Defaults to a copy
            of ``heatmap`` itself.

    Returns:
        Tuple ``(annotated_image, patch_count)``.
    """
    red_mask = (heatmap[:, :, 2] > threshold).astype(np.uint8)
    # ``[-2]`` picks the contour list under both OpenCV 3.x (3 return values)
    # and 4.x (2 return values).
    contours = cv2.findContours(red_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    contours = sorted(contours, key=cv2.contourArea, reverse=True)

    # Draw on an explicit copy. The previous version read an undefined global
    # ``image`` here and raised NameError on every call.
    canvas = np.array(heatmap if image is None else image)

    radius = 20
    marker_color = (0, 0, 0)
    text_color = (255, 255, 255)
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 1

    centroids = []
    for rank, contour in enumerate(contours, start=1):
        moments = cv2.moments(contour)
        if moments["m00"] != 0:
            cx = int(moments["m10"] / moments["m00"])
            cy = int(moments["m01"] / moments["m00"])
        else:
            # Zero-area contour (a single pixel or a 1-px-wide line): use the
            # mean of its points instead of labelling the image corner (0, 0).
            cx, cy = (int(v) for v in contour.reshape(-1, 2).mean(axis=0))

        cv2.circle(canvas, (cx, cy), radius, marker_color, -1)
        cv2.putText(canvas, str(rank), (cx - 10, cy + 10), font, font_scale,
                    text_color, 2, cv2.LINE_AA)
        centroids.append((cx, cy))

    # Connect the patches in rank order (largest -> next largest -> ...).
    for start_point, end_point in zip(centroids, centroids[1:]):
        cv2.line(canvas, start_point, end_point, marker_color, 2)

    return canvas, len(contours)


def process_image(image: Image.Image, save_path="example/result15.png") -> np.ndarray:
    """Predict a saliency heatmap for ``image`` and blend it over the picture.

    Steps:
      1. Resize to 384x288, convert to BGR scaled to [0, 1], run ``model``.
      2. Colour the prediction with the JET map and resize it to the input size.
      3. Apply CLAHE to the blue channel of the picture.
      4. Blend picture (weight 0.3) with heatmap (weight 0.7).

    Everything is kept in OpenCV's BGR order internally. The previous version
    blended an RGB photo with a BGR heatmap and applied CLAHE to the red
    channel while calling it ``b``.

    Args:
        image: Input PIL image. It is converted to RGB first, so RGBA or
            greyscale uploads no longer break the 3-way channel split.
        save_path: File ``cv2.imwrite`` writes a BGR copy of the result to.
            Pass ``None`` or ``""`` to skip writing.

    Returns:
        The blended image as an H x W x 3 uint8 array in RGB order, with the
        same width and height as ``image``.
    """
    image = image.convert("RGB")

    # Model input: 384x288, BGR, scaled to [0, 1], NCHW float32.
    model_input = cv2.cvtColor(np.array(image.resize((384, 288))), cv2.COLOR_RGB2BGR) / 255.0
    model_input = np.expand_dims(np.transpose(model_input, (2, 0, 1)), axis=0)
    tensor = torch.from_numpy(model_input).type(torch.FloatTensor).to(device)

    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        pred_saliency = model(tensor).squeeze().cpu().numpy()

    # Clip before the uint8 cast so out-of-range predictions saturate
    # instead of wrapping around.
    heatmap = np.clip(pred_saliency * 255, 0, 255).astype(np.uint8)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)  # BGR output
    heatmap = cv2.resize(heatmap, (image.width, image.height))

    # Contrast-enhance only the blue channel of the (BGR) photo.
    blue, green, red = cv2.split(cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR))
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    enhanced = cv2.merge((clahe.apply(blue), green, red))

    alpha = 0.7
    blended_bgr = cv2.addWeighted(enhanced, 1 - alpha, heatmap, alpha, 0)

    # NOTE: red-patch labelling (count_and_label_red_patches) is not applied
    # here. Its result was previously computed and then discarded.

    # Optional debug copy. imwrite expects BGR; the JPEG-quality flag
    # previously passed here has no meaning for a .png file.
    if save_path:
        cv2.imwrite(save_path, blended_bgr)

    return cv2.cvtColor(blended_bgr, cv2.COLOR_BGR2RGB)


def gr_process_image(input_image):
    """Gradio callback: wrap the incoming array as a PIL image and process it.

    Args:
        input_image: Image array from the Gradio input component. It may be
            ``None`` (e.g. the user submitted without uploading).

    Returns:
        The array produced by ``process_image``, or ``None`` when there is no
        input (instead of crashing inside ``Image.fromarray``).
    """
    if input_image is None:
        return None
    return process_image(Image.fromarray(input_image))


# NOTE(review): gradio exposes no ``Gradio`` name, so the top-of-file
# ``from gradio import Gradio, Image as GImage`` fails with ImportError.
# Drop ``Gradio`` from that import. This block uses the module directly.
import gradio as gr

iface = gr.Interface(
    fn=gr_process_image,
    # type="numpy" must be passed by keyword: the first positional parameter
    # of gr.Image is the component's initial *value*, not its type.
    # gr_process_image expects an array (it calls Image.fromarray).
    inputs=gr.Image(type="numpy"),
    outputs=gr.Image(type="numpy"),
)

# share=True also publishes a temporary public URL, so anyone with the link
# can run the model.
iface.launch(share=True)