import streamlit as st
import cv2
import numpy as np
import torch
from torchvision import transforms
from PIL import Image
from TranSalNet_Res import TranSalNet
# Load the model and set the device
device = torch.device('cpu')
model = TranSalNet()
model.load_state_dict(torch.load('pretrained_models/TranSalNet_Res.pth', map_location=torch.device('cpu')))
model.to(device)
model.eval()
# Define Streamlit app
st.title('Saliency Detection App')
st.write('Upload an image for saliency detection:')
uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
if uploaded_image:
    image = Image.open(uploaded_image).convert('RGB')
    st.image(image, caption='Uploaded Image', use_column_width=True)
    # Run saliency detection when the user clicks the button
    if st.button('Detect Saliency'):
        # Create a blue background image with the same dimensions as the original image
        blue_background = np.zeros_like(np.array(image))
        blue_background[:] = (0, 0, 255)  # Blue in RGB (st.image interprets arrays as RGB)

        # Preprocess the image: resize to the model's 384x288 input, scale to [0, 1],
        # and reorder to a (1, C, H, W) float tensor
        img = image.resize((384, 288))
        img = np.array(img) / 255.
        img = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
        img = torch.from_numpy(img)
        img = img.type(torch.FloatTensor).to(device)
        # Get the saliency prediction (no gradients needed for inference)
        with torch.no_grad():
            pred_saliency = model(img)

        # Convert the prediction back to a PIL image (ToPILImage rescales [0, 1] floats to 0-255)
        toPIL = transforms.ToPILImage()
        pic = toPIL(pred_saliency.squeeze().cpu())

        # Colorize the grayscale prediction and convert BGR -> RGB so colors display correctly
        gray_pred = np.uint8(pic)
        colorized_img = cv2.applyColorMap(gray_pred, cv2.COLORMAP_JET)
        colorized_img = cv2.cvtColor(colorized_img, cv2.COLOR_BGR2RGB)

        # Ensure the colorized map has the same dimensions as the original image
        original_img = np.array(image)
        colorized_img = cv2.resize(colorized_img, (original_img.shape[1], original_img.shape[0]))
        # Build a binary map of salient regions from the grayscale prediction, resized to
        # the original image size (threshold at the midpoint; adjust as needed)
        intensity_map = cv2.resize(gray_pred, (original_img.shape[1], original_img.shape[0]))
        _, binary_map = cv2.threshold(intensity_map, 127, 255, cv2.THRESH_BINARY)

        # Rank the regions by contour area and draw each rank at the region's centroid
        contours, _ = cv2.findContours(binary_map, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        contours = sorted(contours, key=cv2.contourArea, reverse=True)  # largest area gets label 1
        label_map = np.zeros_like(colorized_img)
        for i, contour in enumerate(contours):
            M = cv2.moments(contour)
            if M["m00"] == 0:
                continue
            center_x = int(M["m10"] / M["m00"])
            center_y = int(M["m01"] / M["m00"])
            cv2.putText(label_map, str(i + 1), (center_x, center_y), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2, cv2.LINE_AA)
        # Blend the colorized saliency map with the blue background, then overlay the labels
        alpha = 0.3  # Adjust the alpha value to control blending strength
        blended_img = cv2.addWeighted(blue_background, 1 - alpha, colorized_img, alpha, 0)
        blended_img = cv2.add(blended_img, label_map)
        # Display the final result
        st.image(blended_img, caption='Blended Image with Labels', use_column_width=True)

        # Save the final result (convert RGB -> BGR before writing with OpenCV)
        cv2.imwrite('example/result15.png', cv2.cvtColor(blended_img, cv2.COLOR_RGB2BGR))
        st.success('Saliency detection complete. Result saved as "example/result15.png"')
        st.write('Finished, check the result at: example/result15.png')
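
# To try the app locally (a minimal sketch; assumes the TranSalNet_Res weights are present
# at pretrained_models/TranSalNet_Res.pth and that the example/ output directory exists):
#   pip install streamlit opencv-python-headless torch torchvision pillow
#   streamlit run app.py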