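"""Streamlit app for visual saliency detection with TranSalNet.

Loads a pretrained TranSalNet_Res model, lets the user upload an image,
predicts a saliency map, colorizes it, ranks the salient regions by area,
and overlays the numbered regions on the original image.
"""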
import streamlit as st
import cv2
import numpy as np
import torch
from torchvision import transforms
from PIL import Image
from TranSalNet_Res import TranSalNet



# Load the pretrained TranSalNet weights on CPU and switch to inference mode
device = torch.device('cpu')
model = TranSalNet()
model.load_state_dict(torch.load('pretrained_models/TranSalNet_Res.pth', map_location=device))
model.to(device)
model.eval()

# Define Streamlit app
st.title('Saliency Detection App')
st.write('Upload an image for saliency detection:')
uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
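# file_uploader returns a file-like object, or None until a file is chosen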

if uploaded_image:
    # Force RGB so downstream channel handling works for RGBA/grayscale uploads
    image = Image.open(uploaded_image).convert('RGB')
    st.image(image, caption='Uploaded Image', use_column_width=True)

    # Check if the user clicks a button
    if st.button('Detect Saliency'):
        # Preprocess: TranSalNet expects a 288x384 (HxW) RGB input in [0, 1],
        # in NCHW layout; note that PIL's resize takes (width, height)
        img = image.resize((384, 288))
        img = np.array(img) / 255.
        img = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
        img = torch.from_numpy(img).float().to(device)

        # Get the saliency prediction without tracking gradients
        with torch.no_grad():
            pred_saliency = model(img)

        # Convert the result back to a grayscale PIL image
        toPIL = transforms.ToPILImage()
        pic = toPIL(pred_saliency.squeeze())

        # Colorize the grayscale prediction; applyColorMap returns BGR,
        # so convert to RGB to match the PIL-derived original image
        colorized_img = cv2.applyColorMap(np.array(pic), cv2.COLORMAP_JET)
        colorized_img = cv2.cvtColor(colorized_img, cv2.COLOR_BGR2RGB)
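        # COLORMAP_JET renders low saliency as blue and high saliency as red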

        # Ensure the colorized image has the same dimensions as the original image
        original_img = np.array(image)
        colorized_img = cv2.resize(colorized_img, (original_img.shape[1], original_img.shape[0]))

        # Compute intensity values from the colorized image (now RGB)
        intensity_map = cv2.cvtColor(colorized_img, cv2.COLOR_RGB2GRAY)

        # Threshold the intensity map to a binary mask; Otsu's method picks the
        # cutoff automatically (a fixed threshold of 0 would mark nearly every
        # pixel as salient)
        _, binary_map = cv2.threshold(intensity_map, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

        # Find contours in the binary map
        contours, _ = cv2.findContours(binary_map, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
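        # RETR_EXTERNAL returns only the outermost contours, so nested blobs
        # collapse into a single salient region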

        # Sort the contours by area in descending order
        contours = sorted(contours, key=cv2.contourArea, reverse=True)
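        # After sorting, contours[0] is the largest salient region, so label 1
        # marks the most prominent area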

        # Blend the colorized heatmap with the original image;
        # alpha controls the heatmap's contribution to the blend
        alpha = 0.4
        blended_img = cv2.addWeighted(original_img, 1 - alpha, colorized_img, alpha, 0)

        # Rank each region by area and draw its rank at the region's centroid,
        # computed from the contour moments
        font = cv2.FONT_HERSHEY_SIMPLEX
        for i, contour in enumerate(contours):
            M = cv2.moments(contour)
            if M["m00"] == 0:
                continue  # skip degenerate contours with zero area
            center_x = int(M["m10"] / M["m00"])
            center_y = int(M["m01"] / M["m00"])
            cv2.putText(blended_img, str(i + 1), (center_x, center_y),
                        font, 1, (255, 0, 0), 2, cv2.LINE_AA)  # red in RGB

        # Display the final results (st.image expects RGB arrays)
        st.image(colorized_img, caption='Colorized Saliency Map', use_column_width=True)
        st.image(blended_img, caption='Blended Image with Labels', use_column_width=True)

        # Save the final result; cv2.imwrite expects BGR channel order, and the
        # JPEG quality flag does not apply to PNG output
        cv2.imwrite('example/result15.png', cv2.cvtColor(blended_img, cv2.COLOR_RGB2BGR))
        st.success('Saliency detection complete. Result saved as "example/result15.png".')
