VisAt / app.py
import streamlit as st
import cv2
import numpy as np
import torch
from torchvision import transforms, utils, models
from PIL import Image
from TranSalNet_Res import TranSalNet
from tqdm import tqdm
import torch.nn as nn
from utils.data_process import preprocess_img, postprocess_img
# Load the model and set the device
device = torch.device('cpu')
model = TranSalNet()
model.load_state_dict(torch.load('pretrained_models/TranSalNet_Res.pth', map_location=torch.device('cpu')))
model.to(device)
model.eval()
# Define Streamlit app
st.title('Saliency Detection App')
st.write('Upload an image for saliency detection:')
uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
if uploaded_image:
    # Force three RGB channels so the model always sees an (H, W, 3) array
    image = Image.open(uploaded_image).convert('RGB')
    st.image(image, caption='Uploaded Image', use_column_width=True)
    # Run saliency detection when the user clicks the button
    if st.button('Detect Saliency'):
        # Preprocess the image: resize to the model input size, scale to
        # [0, 1], and reorder to a (1, C, H, W) float tensor
        img = image.resize((384, 288))
        img = np.array(img) / 255.
        img = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
        img = torch.from_numpy(img)
        img = img.type(torch.FloatTensor).to(device)
        # Get the saliency prediction (no gradients needed at inference time)
        with torch.no_grad():
            pred_saliency = model(img)

        # Convert the result back to an 8-bit grayscale PIL image
        toPIL = transforms.ToPILImage()
        pic = toPIL(pred_saliency.squeeze())
        # Colorize the grayscale prediction (OpenCV colormaps return BGR)
        colorized_img = cv2.applyColorMap(np.uint8(pic), cv2.COLORMAP_JET)

        # Resize the colorized map to the original image dimensions
        original_img = np.array(image)
        colorized_img = cv2.resize(colorized_img, (original_img.shape[1], original_img.shape[0]))

        # Compute intensity values from the colorized map
        intensity_map = cv2.cvtColor(colorized_img, cv2.COLOR_BGR2GRAY)
        # Threshold the intensity map to create a binary mask
        # (threshold 0 keeps every pixel with non-zero intensity)
        _, binary_map = cv2.threshold(intensity_map, 0, 255, cv2.THRESH_BINARY)

        # Find contours in the binary map
        contours, _ = cv2.findContours(binary_map, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        # Sort the contours by area in descending order
        contours = sorted(contours, key=cv2.contourArea, reverse=True)

        # Create an empty label map for ranking regions by area
        label_map = np.zeros_like(intensity_map)
        # Rank each region by area and write its rank at the region centroid
        for i, contour in enumerate(contours):
            M = cv2.moments(contour)
            if M["m00"] == 0:
                continue
            center_x = int(M["m10"] / M["m00"])
            center_y = int(M["m01"] / M["m00"])
            cv2.putText(label_map, str(i + 1), (center_x, center_y), cv2.FONT_HERSHEY_SIMPLEX, 1, 255, 2, cv2.LINE_AA)
        # Convert the colormap to RGB so it blends and displays consistently
        # with the RGB original image, then blend the two
        colorized_img = cv2.cvtColor(colorized_img, cv2.COLOR_BGR2RGB)
        alpha = 0.4  # Adjust the alpha value to control blending strength
        blended_img = cv2.addWeighted(original_img, 1 - alpha, colorized_img, alpha, 0)

        # Overlay the rank labels on the blended image at each contour's
        # top-left corner
        font = cv2.FONT_HERSHEY_SIMPLEX
        for i in range(1, len(contours) + 1):
            x, y, w, h = cv2.boundingRect(contours[i - 1])
            org = (x, y)
            color = (255, 0, 0)  # Red (the blended image is RGB)
            thickness = 2
            cv2.putText(blended_img, str(i), org, font, 1, color, thickness, cv2.LINE_AA)
        # Display the final results
        st.image(colorized_img, caption='Colorized Saliency Map', use_column_width=True)
        st.image(blended_img, caption='Blended Image with Labels', use_column_width=True)

        # Save the final result (convert back to BGR for OpenCV's imwrite)
        cv2.imwrite('example/result15.png', cv2.cvtColor(blended_img, cv2.COLOR_RGB2BGR))
        st.success('Saliency detection complete. Result saved as "example/result15.png".')
        st.write('Finished, check the result at: example/result15.png')
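
# To try the app locally (a minimal sketch, assuming the TranSalNet_Res
# weights sit under pretrained_models/ as loaded above and that these
# packages cover the repo's remaining dependencies):
#   pip install streamlit opencv-python-headless torch torchvision
#   streamlit run app.py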