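# Streamlit demo: visual saliency prediction with TranSalNet (ResNet variant).
# The app takes an uploaded image, runs the pretrained TranSalNet_Res model on CPU,
# colorizes the predicted saliency map, ranks salient regions, and saves the blended result.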
import streamlit as st
import cv2
import numpy as np
import torch
from torchvision import transforms, utils, models
from PIL import Image
from TranSalNet_Res import TranSalNet
from tqdm import tqdm
import torch.nn as nn
from utils.data_process import preprocess_img, postprocess_img

# Load the model and set the device
device = torch.device('cpu')
model = TranSalNet()
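# map_location='cpu' lets the (typically GPU-trained) checkpoint load on a CPU-only machine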
model.load_state_dict(torch.load('pretrained_models/TranSalNet_Res.pth', map_location=torch.device('cpu')))
model.to(device)
model.eval()

# Define Streamlit app
st.title('Saliency Detection App')
st.write('Upload an image for saliency detection:')
uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_image:
    image = Image.open(uploaded_image).convert('RGB')  # force 3 RGB channels (drops alpha from PNGs)
    st.image(image, caption='Uploaded Image', use_column_width=True)

    # Check if the user clicks a button
    if st.button('Detect Saliency'):
        # Create a blue background image with the same dimensions as the original image
        blue_background = np.zeros_like(np.array(image))
        blue_background[:] = (255, 0, 0)  # Set the background to blue (in BGR format)

        # Preprocess the image
        img = image.resize((384, 288))
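        # TranSalNet expects a 288x384 input scaled to [0, 1] in NCHW (batch, channel, height, width) layout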
        img = np.array(img) / 255.
        img = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
        img = torch.from_numpy(img)
        img = img.type(torch.FloatTensor).to(device)

        # Get saliency prediction (no gradient tracking needed at inference time)
        with torch.no_grad():
            pred_saliency = model(img)

        # Convert the result back to a PIL image
        toPIL = transforms.ToPILImage()
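        # ToPILImage rescales the float prediction to an 8-bit grayscale ('L') image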
        pic = toPIL(pred_saliency.squeeze())

        # Colorize the grayscale prediction
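        # (applyColorMap returns a 3-channel BGR heatmap, so the rest of the OpenCV pipeline stays in BGR order)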
        colorized_img = cv2.applyColorMap(np.uint8(pic), cv2.COLORMAP_JET)

        # Ensure the colorized image has the same dimensions as the original image
        original_img = np.array(image)
        colorized_img = cv2.resize(colorized_img, (original_img.shape[1], original_img.shape[0]))

        # Threshold the resized grayscale prediction into a binary map for contour detection
        gray_pred = cv2.resize(np.array(pic), (original_img.shape[1], original_img.shape[0]))
        _, binary_map = cv2.threshold(gray_pred, 127, 255, cv2.THRESH_BINARY)

        # Blend the colorized saliency map with the blue background
        alpha = 0.7  # Adjust the alpha value to control blending strength
        blended_img = cv2.addWeighted(blue_background, 1 - alpha, colorized_img, alpha, 0)

        # Rank salient regions by area and overlay the rank labels on the blended image
        font = cv2.FONT_HERSHEY_SIMPLEX
        contours, _ = cv2.findContours(binary_map, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        contours = sorted(contours, key=cv2.contourArea, reverse=True)
        for i, contour in enumerate(contours):
            M = cv2.moments(contour)
            if M["m00"] == 0:
                continue
            center_x = int(M["m10"] / M["m00"])
            center_y = int(M["m01"] / M["m00"])
            cv2.putText(blended_img, str(i + 1), (center_x, center_y), font, 1, (255, 0, 0), 2, cv2.LINE_AA)

        # Display the final result
        st.image(blended_img, caption='Blended Image with Labels', use_column_width=True, channels='BGR')

        # Save the final result
        cv2.imwrite('example/result15.png', blended_img)
        st.success('Saliency detection complete. Result saved as "example/result15.png"')

st.write('Finished, check the result at: example/result15.png')
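# To launch the app locally (assuming this file is saved as app.py):
#   streamlit run app.py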