import streamlit as st
import torch
import cv2
from PIL import Image
import numpy as np
from torchvision import transforms

from TranSalNet_Res import TranSalNet

# Load TranSalNet and its pretrained weights; inference runs on the CPU.
device = torch.device('cpu')
model = TranSalNet()
model.load_state_dict(torch.load('pretrained_models/TranSalNet_Res.pth', map_location=device))
model.to(device)
model.eval()

st.title('Saliency Detection App')
st.write('Upload an image for saliency detection:')
uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_image:
    # Convert to RGB so grayscale or RGBA uploads always have exactly three channels.
    image = Image.open(uploaded_image).convert('RGB')
    st.image(image, caption='Uploaded Image', use_column_width=True)

    if st.button('Detect Saliency'):
        # Preprocess: resize to the 384x288 (W x H) input the model expects,
        # scale to [0, 1], and reorder to a (1, 3, 288, 384) float tensor.
        img = image.resize((384, 288))
        img = np.array(img) / 255.
        img = np.transpose(img, (2, 0, 1))
        img = torch.from_numpy(img).unsqueeze(0).float()
        img = img.to(device)

        # Run the model without tracking gradients.
        with torch.no_grad():
            pred_saliency = model(img)

        # Convert the predicted saliency map to an 8-bit grayscale image, then
        # apply a JET colormap. OpenCV returns BGR, so convert to RGB before
        # handing the array to Streamlit.
        toPIL = transforms.ToPILImage()
        pic = toPIL(pred_saliency.squeeze())
        colorized_img = cv2.applyColorMap(np.uint8(pic), cv2.COLORMAP_JET)
        colorized_img = cv2.cvtColor(colorized_img, cv2.COLOR_BGR2RGB)

        # Resize the colorized saliency map back to the original image
        # resolution and display it.
        original_img = np.array(image)
        colorized_img = cv2.resize(colorized_img, (original_img.shape[1], original_img.shape[0]))

        st.image(colorized_img, caption='Colorized Saliency Map', use_column_width=True)
        st.write('Finished!')
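
# Usage sketch (assumption: this script is saved as app.py in the TranSalNet
# repository root, next to TranSalNet_Res.py and the pretrained_models/ folder):
#   streamlit run app.py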