# VisAt / app.py — Streamlit front end for TranSalNet saliency detection.
# Uploaded by Tanzeer in commit 8395863 ("Upload 15 files"); original size 1.91 kB.
import streamlit as st
import torch
import cv2
from PIL import Image
import numpy as np
from torchvision import transforms
from TranSalNet_Res import TranSalNet # Make sure TranSalNet is accessible from your Streamlit app
# Load the model and set the device
device = torch.device('cpu')


def _load_model():
    """Build TranSalNet, load the pretrained weights and prepare it for inference.

    Returns:
        The TranSalNet module in evaluation mode, placed on ``device``.
    """
    net = TranSalNet()
    net.load_state_dict(torch.load('pretrained_models/TranSalNet_Res.pth', map_location=device))
    net.eval()  # Set the model to evaluation (inference-only) mode
    net.to(device)
    return net


# Streamlit re-executes this whole script on every widget interaction (upload,
# button click). Cache the loaded model so the checkpoint is read from disk once
# instead of on every rerun. ``st.cache_resource`` only exists in newer Streamlit
# releases, so fall back to loading on each run when it is unavailable.
_cache_resource = getattr(st, 'cache_resource', None)
if _cache_resource is not None:
    _load_model = _cache_resource(_load_model)
model = _load_model()
# Define Streamlit app
def _preprocess(pil_image):
    """Turn an uploaded PIL image into the batched input tensor for the model.

    The image is converted to 3-channel RGB first. PNG uploads are often RGBA,
    palette or grayscale images; without the conversion they either break the
    HWC -> CHW transpose (2-D array) or give the model a 4-channel input.
    The image is then resized to 384x288 (width x height) and scaled to [0, 1].

    Returns:
        A float32 tensor of shape (1, 3, 288, 384) on ``device``.

    NOTE(review): upstream TranSalNet preprocessing appears to read images with
    OpenCV (BGR channel order); confirm which order the pretrained weights expect.
    """
    rgb = pil_image.convert('RGB').resize((384, 288))
    arr = np.array(rgb) / 255.
    arr = np.transpose(arr, (2, 0, 1))  # HWC -> CHW
    return torch.from_numpy(arr).unsqueeze(0).float().to(device)


def _colorize(pred_saliency, size):
    """Render the model output as an RGB JET heat map resized to ``size``.

    Args:
        pred_saliency: Raw model output tensor. Assumes it squeezes to a
            single-channel 2-D map with values in [0, 1] (TODO confirm against
            TranSalNet).
        size: Target (width, height). This is PIL's ``Image.size`` order and
            also the ``dsize`` order that ``cv2.resize`` expects.

    Returns:
        A uint8 RGB array of shape (height, width, 3).
    """
    gray = transforms.ToPILImage()(pred_saliency.squeeze().cpu())
    heat_bgr = cv2.applyColorMap(np.uint8(gray), cv2.COLORMAP_JET)
    heat_bgr = cv2.resize(heat_bgr, size)
    # OpenCV colour maps are BGR, but st.image interprets arrays as RGB by
    # default. Without this conversion, salient (red) regions would appear blue.
    return cv2.cvtColor(heat_bgr, cv2.COLOR_BGR2RGB)


st.title('Saliency Detection App')
st.write('Upload an image for saliency detection:')
uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
if uploaded_image is not None:
    try:
        image = Image.open(uploaded_image)
        image.load()  # Image.open is lazy; decode now so a corrupt file fails here
    except OSError:  # also covers PIL.UnidentifiedImageError
        st.error('Could not read the uploaded file as an image.')
    else:
        st.image(image, caption='Uploaded Image', use_column_width=True)
        # Check if the user clicks a button
        if st.button('Detect Saliency'):
            # Get saliency prediction without tracking gradients
            with torch.no_grad():
                pred_saliency = model(_preprocess(image))
            # Heat map at the original image's (width, height)
            colorized_img = _colorize(pred_saliency, image.size)
            # Display the final result
            st.image(colorized_img, caption='Colorized Saliency Map', use_column_width=True)
            st.write('Finished!')