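"""Streamlit app for visual saliency detection with TranSalNet.

Loads the pretrained TranSalNet_Res weights, predicts a saliency map for an
uploaded image, colorizes and blends the map over the original image, and
ranks the most salient regions by area.
"""
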
import os

import streamlit as st
import cv2
import numpy as np
import torch
from PIL import Image
from TranSalNet_Res import TranSalNet



# Load the model and set the device
device = torch.device('cpu')
model = TranSalNet()
model.load_state_dict(torch.load('pretrained_models/TranSalNet_Res.pth', map_location=torch.device('cpu')))
model.to(device)
model.eval()
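
# Note: in a long-running app, the model setup above could be wrapped in a
# function decorated with @st.cache_resource so the weights are loaded once
# per session instead of on every Streamlit rerun.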

# Define Streamlit app
st.title('Saliency Detection App')
st.write('Upload an image for saliency detection:')
uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_image:
    # Convert to RGB so RGBA or grayscale uploads do not break preprocessing
    image = Image.open(uploaded_image).convert('RGB')
    st.image(image, caption='Uploaded Image', use_column_width=True)

    # Run detection when the user clicks the button
    if st.button('Detect Saliency'):
        # Preprocess the image: TranSalNet expects a 288x384 (HxW) RGB input,
        # scaled to [0, 1], in NCHW layout
        img = image.resize((384, 288))  # PIL resize takes (width, height)
        img = np.array(img) / 255.
        img = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
        img = torch.from_numpy(img).float().to(device)

        # Get the saliency prediction (no gradients needed at inference)
        with torch.no_grad():
            pred_saliency = model(img)

        # Convert the prediction to an 8-bit grayscale saliency map
        saliency_gray = np.clip(pred_saliency.squeeze().cpu().numpy() * 255, 0, 255).astype(np.uint8)

        # Colorize the grayscale prediction; applyColorMap returns BGR, so
        # convert to RGB to match the PIL-derived original image
        colorized_img = cv2.applyColorMap(saliency_gray, cv2.COLORMAP_JET)
        colorized_img = cv2.cvtColor(colorized_img, cv2.COLOR_BGR2RGB)

        # Resize both maps to the dimensions of the original image
        original_img = np.array(image)
        colorized_img = cv2.resize(colorized_img, (original_img.shape[1], original_img.shape[0]))
        saliency_map = cv2.resize(saliency_gray, (original_img.shape[1], original_img.shape[0]))

        # Threshold the saliency map to create a binary mask of salient
        # regions (127 is a midpoint cutoff; tune as needed)
        _, binary_map = cv2.threshold(saliency_map, 127, 255, cv2.THRESH_BINARY)

        # Find contours of the salient regions in the binary map
        contours, _ = cv2.findContours(binary_map, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        # Sort the contours by area in descending order so rank 1 is the
        # largest salient region
        contours = sorted(contours, key=cv2.contourArea, reverse=True)

        # Blend the colorized map with the original image
        alpha = 0.7  # blending strength of the saliency overlay
        blended_img = cv2.addWeighted(original_img, 1 - alpha, colorized_img, alpha, 0)

        # Rank each region by area and draw its rank at the region centroid
        font = cv2.FONT_HERSHEY_SIMPLEX
        for rank, contour in enumerate(contours, start=1):
            M = cv2.moments(contour)
            if M["m00"] == 0:  # skip degenerate zero-area contours
                continue
            center = (int(M["m10"] / M["m00"]), int(M["m01"] / M["m00"]))
            cv2.putText(blended_img, str(rank), center, font, 1, (0, 0, 255), 3, cv2.LINE_AA)  # blue labels (RGB)

        # Display the final result
        st.image(colorized_img, caption='Colorized Saliency Map', use_column_width=True)
        st.image(blended_img, caption='Blended Image with Labels', use_column_width=True)
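        # Note: newer Streamlit releases deprecate use_column_width in
        # favor of use_container_width.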

        # Save the final result (cv2.imwrite expects BGR; the .png extension
        # selects the PNG encoder, so no JPEG quality flag is needed)
        os.makedirs('example', exist_ok=True)  # imwrite fails silently if the folder is missing
        cv2.imwrite('example/result15.png', cv2.cvtColor(blended_img, cv2.COLOR_RGB2BGR))
        st.success('Saliency detection complete. Result saved as "example/result15.png".')