# NOTE(review): removed non-Python viewer artifacts (file-size banner and
# line-number gutter) that were pasted into the source and broke parsing.
import streamlit as st
import torch
import cv2
from PIL import Image
import numpy as np
from torchvision import transforms
from TranSalNet_Res import TranSalNet # Make sure TranSalNet is accessible from your Streamlit app
# ---- Model setup -----------------------------------------------------------
# Streamlit re-executes this entire script on every user interaction, so an
# uncached load would re-read the checkpoint and rebuild the network each
# rerun. Cache the loaded model as a shared resource instead.
# NOTE(review): st.cache_resource requires Streamlit >= 1.18; on older
# versions substitute @st.experimental_singleton.
@st.cache_resource
def load_model():
    """Load TranSalNet_Res weights on CPU and return the model in eval mode."""
    net = TranSalNet()
    net.load_state_dict(
        torch.load('pretrained_models/TranSalNet_Res.pth',
                   map_location=torch.device('cpu')))
    net.eval()  # inference only: freezes dropout / batch-norm behavior
    return net

device = torch.device('cpu')
model = load_model()
model.to(device)
# ---- Streamlit UI ----------------------------------------------------------
st.title('Saliency Detection App')
st.write('Upload an image for saliency detection:')

uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_image:
    image = Image.open(uploaded_image)
    st.image(image, caption='Uploaded Image', use_column_width=True)

    if st.button('Detect Saliency'):
        # --- Preprocess -----------------------------------------------------
        # Force 3-channel RGB first: PNG uploads may be RGBA or grayscale,
        # which would break the HWC->CHW transpose below and mismatch the
        # model's 3-channel input.
        rgb = image.convert('RGB')
        # PIL resize takes (width, height); TranSalNet_Res expects 288x384 (HxW).
        img = np.array(rgb.resize((384, 288))) / 255.  # scale to [0, 1]
        img = np.transpose(img, (2, 0, 1))             # HWC -> CHW
        img = torch.from_numpy(img).unsqueeze(0).float()  # add batch dim
        img = img.to(device)

        # --- Inference ------------------------------------------------------
        with torch.no_grad():  # no autograd bookkeeping needed at inference
            pred_saliency = model(img)

        # --- Postprocess ----------------------------------------------------
        # ToPILImage maps the float [0, 1] saliency tensor to an 8-bit image.
        toPIL = transforms.ToPILImage()
        pic = toPIL(pred_saliency.squeeze())

        # cv2.applyColorMap returns BGR (OpenCV convention); convert to RGB
        # so st.image does not render the heatmap with red/blue swapped.
        heatmap = cv2.applyColorMap(np.uint8(pic), cv2.COLORMAP_JET)
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

        # Resize the heatmap back to the original image's dimensions
        # (cv2.resize takes (width, height), i.e. (shape[1], shape[0])).
        original_img = np.array(image)
        heatmap = cv2.resize(heatmap, (original_img.shape[1], original_img.shape[0]))

        st.image(heatmap, caption='Colorized Saliency Map', use_column_width=True)
        st.write('Finished!')
# NOTE(review): removed stray '|' paste artifact at end of file.