File size: 1,907 Bytes
8395863
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import streamlit as st
import torch
import cv2
from PIL import Image
import numpy as np
from torchvision import transforms
from TranSalNet_Res import TranSalNet  # Make sure TranSalNet is accessible from your Streamlit app

# Load the model and set the device
device = torch.device('cpu')


def _load_model():
    """Build TranSalNet, load the pretrained weights, and return it in eval mode.

    The weights are read from 'pretrained_models/TranSalNet_Res.pth' and
    mapped onto `device`, so the checkpoint loads on CPU-only machines.
    """
    net = TranSalNet()
    net.load_state_dict(torch.load('pretrained_models/TranSalNet_Res.pth', map_location=device))
    net.eval()  # Set the model to evaluation mode
    net.to(device)
    return net


# Streamlit re-executes this script on every widget interaction. Without
# caching, each button click would re-read the checkpoint from disk and
# rebuild the network. `st.cache_resource` keeps one shared instance alive;
# older Streamlit releases that lack it fall back to loading on each run.
if hasattr(st, 'cache_resource'):
    _load_model = st.cache_resource(_load_model)

model = _load_model()

# Page header, followed by the upload widget whose result drives the rest of the app
st.title('Saliency Detection App')
st.write('Upload an image for saliency detection:')
uploaded_image = st.file_uploader(
    "Choose an image...",
    type=["jpg", "jpeg", "png"],
)

def _to_model_input(image):
    """Convert a PIL image into the batched float tensor passed to the model.

    The image is first forced to 3-channel RGB. Uploaded PNGs may be
    grayscale, palette-based, or RGBA; a 2-D array would make the
    (2, 0, 1) transpose fail, and an alpha channel would give the tensor
    four channels instead of three.

    Returns a tensor of shape (1, 3, 288, 384) on `device`, scaled to [0, 1].
    """
    rgb = image.convert('RGB').resize((384, 288))  # PIL sizes are (width, height)
    arr = np.array(rgb) / 255.
    arr = np.transpose(arr, (2, 0, 1))  # HWC -> CHW
    return torch.from_numpy(arr).unsqueeze(0).float().to(device)


def _colorize_saliency(pred_saliency, size):
    """Render a saliency prediction as an RGB JET heat map.

    `size` is the target (width, height), which is the order both
    `PIL.Image.size` and `cv2.resize` use.
    """
    # NOTE(review): assumes the squeezed model output is a single 2-D map
    # with values in [0, 1] -- confirm against TranSalNet.
    gray = transforms.ToPILImage()(pred_saliency.squeeze())
    heat_bgr = cv2.applyColorMap(np.uint8(gray), cv2.COLORMAP_JET)
    heat_bgr = cv2.resize(heat_bgr, size)
    # OpenCV colour maps are BGR, while st.image treats arrays as RGB by
    # default; convert so red/blue are not swapped on screen.
    return cv2.cvtColor(heat_bgr, cv2.COLOR_BGR2RGB)


if uploaded_image:
    image = Image.open(uploaded_image)
    st.image(image, caption='Uploaded Image', use_column_width=True)

    # Run inference only when the user clicks the button
    if st.button('Detect Saliency'):
        img = _to_model_input(image)

        # Get saliency prediction (no gradients needed for inference)
        with torch.no_grad():
            pred_saliency = model(img)

        # Heat map scaled back to the uploaded image's dimensions
        colorized_img = _colorize_saliency(pred_saliency, image.size)

        # You can add more post-processing here if needed

        # Display the final result
        st.image(colorized_img, caption='Colorized Saliency Map', use_column_width=True)

# NOTE: this is at top level, so it is written on every script run,
# including before any image has been uploaded.
st.write('Finished!')