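"""Streamlit saliency-detection app built on TranSalNet (ResNet backbone).

Workflow: upload an image, predict a saliency map with TranSalNet_Res,
render it as a JET heatmap, count and label the most salient (red) regions,
and blend the heatmap over a contrast-enhanced copy of the input.

Run with `streamlit run <this_file>.py` (the filename is whatever this
script is saved as). Assumes the TranSalNet repo layout: TranSalNet_Res.py
on the import path and weights at pretrained_models/TranSalNet_Res.pth.
"""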
import os

import streamlit as st
import cv2
import numpy as np
import torch
from PIL import Image
from TranSalNet_Res import TranSalNet

# Load the pretrained TranSalNet (ResNet backbone) weights on CPU.
device = torch.device('cpu')
model = TranSalNet()
model.load_state_dict(torch.load('pretrained_models/TranSalNet_Res.pth', map_location=device))
model.to(device)
model.eval()
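# eval() switches dropout/batch-norm layers to inference behavior; gradients
# are additionally disabled at prediction time via torch.no_grad() below.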

def count_and_label_red_patches(heatmap, image, threshold=200):
    """Find high-saliency (red) regions in a BGR heatmap and label them on a copy of the image."""
    red_mask = heatmap[:, :, 2] > threshold
    contours, _ = cv2.findContours(red_mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Sort the contours by area in descending order, so label 1 is the largest patch
    contours = sorted(contours, key=cv2.contourArea, reverse=True)

    original_image = np.array(image)
    if not contours:
        return original_image, 0

    # Centroid of the largest red patch; every other patch is connected to it
    M_largest = cv2.moments(contours[0])
    if M_largest["m00"] != 0:
        cX_largest = int(M_largest["m10"] / M_largest["m00"])
        cY_largest = int(M_largest["m01"] / M_largest["m00"])
    else:
        cX_largest, cY_largest = 0, 0

    for i, contour in enumerate(contours, start=1):
        # Compute the centroid of the current contour
        M = cv2.moments(contour)
        if M["m00"] != 0:
            cX = int(M["m10"] / M["m00"])
            cY = int(M["m01"] / M["m00"])
        else:
            cX, cY = 0, 0

        radius = 20  # Circle radius sized to fit the label numbers
        circle_color = (0, 0, 0)  # Black filled circle behind each label
        cv2.circle(original_image, (cX, cY), radius, circle_color, -1)

        # Connect the current red spot to the largest one
        line_color = (0, 0, 0)  # Black connecting line
        cv2.line(original_image, (cX, cY), (cX_largest, cY_largest), line_color, 2)

        # Draw the rank number in white on top of the circle
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 1
        font_color = (255, 255, 255)
        line_type = cv2.LINE_AA
        cv2.putText(original_image, str(i), (cX - 10, cY + 10), font, font_scale, font_color, 2, line_type)

    return original_image, len(contours)
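# The red-channel threshold (default 200) decides what counts as a "red patch"
# on the JET heatmap; it is a heuristic and may need tuning per image set.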

st.title('Saliency Detection App')
st.write('Upload an image for saliency detection:')
uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_image:
    image = Image.open(uploaded_image).convert('RGB')  # Force 3 channels (handles RGBA/grayscale uploads)
    st.image(image, caption='Uploaded Image', use_column_width=True)

    if st.button('Detect Saliency'):
        # TranSalNet expects a 288x384 BGR image, scaled to [0, 1], in NCHW layout
        img = image.resize((384, 288))
        img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)  # BGR to match the cv2-based training pipeline
        img = img / 255.
        img = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
        img = torch.from_numpy(img).type(torch.FloatTensor).to(device)

        with torch.no_grad():
            pred_saliency = model(img).squeeze().cpu().numpy()

        heatmap = (np.clip(pred_saliency, 0, 1) * 255).astype(np.uint8)  # Clip in case the output strays outside [0, 1]
        heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)  # JET maps low saliency to blue, high to red

        heatmap = cv2.resize(heatmap, (image.width, image.height))  # Upscale heatmap to the original image size

        # Work in BGR so the split below really yields (b, g, r)
        enhanced_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        b, g, r = cv2.split(enhanced_image)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        b_enhanced = clahe.apply(b)  # Contrast-limited histogram equalization on the blue channel only
        enhanced_image = cv2.merge((b_enhanced, g, r))

        # Alpha-blend the heatmap over the contrast-enhanced frame (70% heatmap, 30% image)
        alpha = 0.7
        blended_img = cv2.addWeighted(enhanced_image, 1 - alpha, heatmap, alpha, 0)

        original_image, num_red_patches = count_and_label_red_patches(heatmap, image)

        st.image(original_image, caption=f'Image with {num_red_patches} Red Patches', use_column_width=True, channels='RGB')

        st.image(blended_img, caption='Blended Image', use_column_width=True, channels='BGR')  # OpenCV output is BGR

        # Save the blended result (PNG ignores JPEG flags, so use PNG compression instead)
        os.makedirs('example', exist_ok=True)
        cv2.imwrite('example/result15.png', blended_img, [int(cv2.IMWRITE_PNG_COMPRESSION), 3])
        st.success('Saliency detection complete. Result saved as "example/result15.png".')