File size: 3,459 Bytes
8395863
 
 
8271835
776dd3c
8271835
36945ed
fc1d3e9
 
 
 
 
36945ed
8395863
 
fc1d3e9
8395863
660142e
 
 
 
 
 
adda212
660142e
 
 
 
 
 
 
8395863
 
 
 
 
 
 
 
 
 
660142e
8271835
 
 
8395863
660142e
8395863
adda212
660142e
adda212
8395863
adda212
660142e
8395863
adda212
660142e
8271835
adda212
660142e
 
 
 
 
776dd3c
adda212
660142e
8271835
adda212
660142e
 
621f2db
adda212
660142e
8271835
adda212
83d61e1
776dd3c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import streamlit as st
import cv2
import numpy as np
import torch
from torchvision import transforms, models
from PIL import Image
from TranSalNet_Res import TranSalNet
from tqdm import tqdm
import torch.nn as nn
from utils.data_process import preprocess_img, postprocess_img

device = torch.device('cpu')


@st.cache_resource
def _load_model():
    """Build TranSalNet, load the pretrained weights onto ``device`` and switch to eval mode.

    Streamlit re-executes this whole script on every widget interaction; caching
    the loaded network means the checkpoint is read from disk once per server
    process instead of on every rerun.
    """
    net = TranSalNet()
    net.load_state_dict(torch.load('pretrained_models/TranSalNet_Res.pth', map_location=device))
    net.to(device)
    net.eval()  # inference only: disable dropout / use running batch-norm statistics
    return net


model = _load_model()

def count_and_label_red_patches(heatmap, threshold=200):
    """Find strongly red regions in a BGR heatmap and number them in place.

    A pixel counts as "red" when its channel-2 value (red in OpenCV's BGR order)
    is strictly greater than ``threshold``. Each 8-connected group of such pixels
    is one patch. Its 1-based index is drawn in black text at the centre of the
    patch's bounding box.

    Args:
        heatmap: BGR uint8 image array (e.g. the output of ``cv2.applyColorMap``).
            It is modified in place by ``cv2.putText``.
        threshold: Red-channel value a pixel must exceed to be part of a patch.

    Returns:
        ``(heatmap, num_red_patches)``: the same, now annotated, array and the
        number of patches found (0 when no pixel passes the threshold).
    """
    red_mask = (heatmap[:, :, 2] > threshold).astype(np.uint8)
    num_labels, _, stats, _ = cv2.connectedComponentsWithStats(red_mask, connectivity=8)

    # Label 0 is the background, so the patches are labels 1 .. num_labels - 1.
    num_red_patches = num_labels - 1

    for i in range(1, num_labels):
        # Centre of the bounding box (not the pixel centroid), truncated to ints.
        center_x = int(stats[i, cv2.CC_STAT_LEFT] + stats[i, cv2.CC_STAT_WIDTH] / 2)
        center_y = int(stats[i, cv2.CC_STAT_TOP] + stats[i, cv2.CC_STAT_HEIGHT] / 2)
        cv2.putText(heatmap, str(i), (center_x, center_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2, cv2.LINE_AA)

    return heatmap, num_red_patches

def _predict_saliency(image):
    """Run ``model`` on a PIL RGB image and return its squeezed output as a NumPy array.

    The image is resized to 384x288 (width x height), scaled to [0, 1] and laid
    out as a (1, C, H, W) float tensor before inference.
    NOTE(review): the output is presumably a 2-D (288, 384) map with values in
    [0, 1] -- confirm against TranSalNet's output head.
    """
    img = np.array(image.resize((384, 288))) / 255.
    img = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
    img = torch.from_numpy(img).float().to(device)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        return model(img).squeeze().cpu().numpy()


def _enhance_blue_channel(bgr_image):
    """Return a copy of a BGR uint8 image with CLAHE applied to the blue channel only."""
    b, g, r = cv2.split(bgr_image)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    return cv2.merge((clahe.apply(b), g, r))


st.title('Saliency Detection App')
st.write('Upload an image for saliency detection:')
uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_image:
    # Normalise to 3-channel RGB. PNGs may be RGBA/palette and JPEGs may be
    # grayscale, which would break the 3-channel transpose and cv2.split below.
    image = Image.open(uploaded_image).convert('RGB')
    st.image(image, caption='Uploaded Image', use_column_width=True)

    if st.button('Detect Saliency'):
        pred_saliency = _predict_saliency(image)

        # Colourise with the JET colormap (blue = low saliency, red = high).
        # Clip first so values outside [0, 1] cannot wrap around in the uint8 cast.
        heatmap = np.clip(pred_saliency * 255, 0, 255).astype(np.uint8)
        heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)  # output is BGR

        # Resize the heatmap to the uploaded image's size (cv2 dsize is (width, height)).
        heatmap = cv2.resize(heatmap, (image.width, image.height))

        # Number the red (high-saliency) patches directly on the heatmap.
        heatmap, num_red_patches = count_and_label_red_patches(heatmap)

        # PIL arrays are RGB, while everything below (CLAHE on the blue channel,
        # blending with the BGR heatmap, channels='BGR' display, cv2.imwrite)
        # expects OpenCV's BGR order, so convert once here.
        bgr_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        enhanced_image = _enhance_blue_channel(bgr_image)

        alpha = 0.7  # weight of the heatmap in the blend; the image gets 1 - alpha
        blended_img = cv2.addWeighted(enhanced_image, 1 - alpha, heatmap, alpha, 0)

        # Display the images in the Streamlit app
        st.image(heatmap, caption='Enhanced Saliency Heatmap', use_column_width=True, channels='BGR')
        st.image(enhanced_image, caption='Enhanced Blue Image', use_column_width=True, channels='BGR')
        st.image(blended_img, caption=f'Blended Image with {num_red_patches} Red Patches',
                 use_column_width=True, channels='BGR')

        # PNG is lossless, so no JPEG quality flag applies. cv2.imwrite returns
        # False (rather than raising) e.g. when the 'example' directory is missing.
        if cv2.imwrite('example/result15.png', blended_img):
            st.success('Saliency detection complete. Result saved as "example/result15.png".')
        else:
            st.error('Saliency detection complete, but saving "example/result15.png" failed.')