import os

import cv2
import numpy as np
import streamlit as st
import torch
from PIL import Image

from TranSalNet_Res import TranSalNet

# Load TranSalNet (ResNet backbone) with pretrained weights, on CPU.
device = torch.device('cpu')
model = TranSalNet()
model.load_state_dict(torch.load('pretrained_models/TranSalNet_Res.pth', map_location=device))
model.to(device)
model.eval()  # inference mode: disables dropout, freezes batch-norm statistics

def count_and_label_red_patches(heatmap, threshold=200):
    # Threshold the red (BGR index 2) channel and label connected components.
    red_mask = heatmap[:, :, 2] > threshold
    num_labels, _, stats, _ = cv2.connectedComponentsWithStats(
        red_mask.astype(np.uint8), connectivity=8)

    num_red_patches = num_labels - 1  # label 0 is the background

    for i in range(1, num_labels):
        # Draw the patch index at the center of its bounding box.
        patch_centroid_x = int(stats[i, cv2.CC_STAT_LEFT] + stats[i, cv2.CC_STAT_WIDTH] / 2)
        patch_centroid_y = int(stats[i, cv2.CC_STAT_TOP] + stats[i, cv2.CC_STAT_HEIGHT] / 2)
        cv2.putText(heatmap, str(i), (patch_centroid_x, patch_centroid_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2, cv2.LINE_AA)

    return heatmap, num_red_patches
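
# Hypothetical standalone sanity check for count_and_label_red_patches
# (not part of the app; a minimal sketch of the expected contract):
#
#     dummy = np.zeros((100, 100, 3), dtype=np.uint8)
#     dummy[10:20, 10:20, 2] = 255          # one bright red square in BGR
#     _, n = count_and_label_red_patches(dummy)
#     assert n == 1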

st.title('Saliency Detection App')
st.write('Upload an image for saliency detection:')
uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_image:
    image = Image.open(uploaded_image).convert('RGB')  # drop any alpha channel
    st.image(image, caption='Uploaded Image', use_column_width=True)

    if st.button('Detect Saliency'):
        # TranSalNet expects a 288x384 (H x W) RGB input in [0, 1], NCHW layout.
        img = image.resize((384, 288))
        img = np.array(img) / 255.
        img = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
        img = torch.from_numpy(img).float().to(device)

        with torch.no_grad():
            pred_saliency = model(img).squeeze().cpu().numpy()

        heatmap = (pred_saliency * 255).astype(np.uint8)
        heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)  # JET colormap; output is BGR

        heatmap = cv2.resize(heatmap, (image.width, image.height))

        heatmap, num_red_patches = count_and_label_red_patches(heatmap)

        # Convert the PIL image (RGB) to OpenCV's BGR order, then boost local
        # contrast in the blue channel with CLAHE.
        enhanced_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        b, g, r = cv2.split(enhanced_image)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        b_enhanced = clahe.apply(b)
        enhanced_image = cv2.merge((b_enhanced, g, r))

        # Overlay: 70% heatmap, 30% contrast-enhanced image.
        alpha = 0.7
        blended_img = cv2.addWeighted(enhanced_image, 1 - alpha, heatmap, alpha, 0)

        st.image(heatmap, caption='Enhanced Saliency Heatmap', use_column_width=True, channels='BGR')
        st.image(enhanced_image, caption='Enhanced Blue Image', use_column_width=True, channels='BGR')

        st.image(blended_img, caption=f'Blended Image with {num_red_patches} Red Patches', use_column_width=True, channels='BGR')

        # Save the blended result as a PNG (create the output directory if needed).
        os.makedirs('example', exist_ok=True)
        cv2.imwrite('example/result15.png', blended_img)
        st.success('Saliency detection complete. Result saved as "example/result15.png".')
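
# To run locally (assuming this file is saved as app.py and the TranSalNet
# repo, including pretrained_models/TranSalNet_Res.pth, is importable):
#
#     streamlit run app.py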