# VisAt / app.py: Streamlit saliency detection demo built on TranSalNet
import os

import streamlit as st
import cv2
import numpy as np
import torch
from torchvision import transforms
from PIL import Image

from TranSalNet_Res import TranSalNet
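
# Load the pretrained TranSalNet (ResNet variant) on CPU and switch to eval
# mode so the network behaves deterministically at inference time.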
device = torch.device('cpu')
model = TranSalNet()
model.load_state_dict(torch.load('pretrained_models/TranSalNet_Res.pth', map_location=device))
model.to(device)
model.eval()

st.title('Saliency Detection App')
st.write('Upload an image for saliency detection:')

uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_image:
    # Force RGB so grayscale or RGBA uploads do not break the channel transpose below.
    image = Image.open(uploaded_image).convert('RGB')
    st.image(image, caption='Uploaded Image', use_column_width=True)
    if st.button('Detect Saliency'):
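        # The TranSalNet pipeline resizes inputs to 288x384 (H x W) and scales
        # pixel values to [0, 1] before building the NCHW tensor.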
        img = image.resize((384, 288))  # PIL resize takes (width, height)
        img = np.array(img) / 255.      # scale to [0, 1]; the original /100. left values outside that range
        img = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)  # HWC -> NCHW
        img = torch.from_numpy(img).float().to(device)

        with torch.no_grad():  # inference only; also lets ToPILImage call .numpy() on the output
            pred_saliency = model(img)

        toPIL = transforms.ToPILImage()
        pic = toPIL(pred_saliency.squeeze())

        # OpenCV colormaps return BGR; convert to RGB so the blend matches the
        # RGB original and displays correctly in st.image.
        colorized_img = cv2.applyColorMap(np.uint8(pic), cv2.COLORMAP_OCEAN)
        colorized_img = cv2.cvtColor(colorized_img, cv2.COLOR_BGR2RGB)
        original_img = np.array(image)
        colorized_img = cv2.resize(colorized_img, (original_img.shape[1], original_img.shape[0]))
        alpha = 0.7
        blended_img = cv2.addWeighted(original_img, 1 - alpha, colorized_img, alpha, 0)
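
        # Number the salient regions: dilate the 8-bit saliency map so nearby
        # blobs merge, take the outer contours, and stamp an index at the
        # contour point nearest each region's center.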
        saliency_8bit = np.uint8(pred_saliency.squeeze().numpy() * 255)
        # Resize to the original resolution so contour coordinates line up
        # with the blended image the labels are drawn on.
        saliency_8bit = cv2.resize(saliency_8bit, (original_img.shape[1], original_img.shape[0]))
        # Apply dilation
        kernel = np.ones((5, 5), np.uint8)
        dilated = cv2.dilate(saliency_8bit, kernel, iterations=1)
        # Find contours on the dilated image
        contours, _ = cv2.findContours(dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        font = cv2.FONT_HERSHEY_SIMPLEX
        label = 1
        for contour in contours:
            # Get the bounding box for the contour and compute its center
            x, y, w, h = cv2.boundingRect(contour)
            center_x = x + w // 2
            center_y = y + h // 2
            # Find the contour point closest to the center of the bounding box
            distances = np.sqrt((contour[:, 0, 0] - center_x) ** 2 + (contour[:, 0, 1] - center_y) ** 2)
            min_index = np.argmin(distances)
            closest_point = tuple(contour[min_index][0])
            # Place the region's label at that point
            cv2.putText(blended_img, str(label), closest_point, font, 1, (0, 0, 255), 3, cv2.LINE_AA)
            label += 1

        st.image(blended_img, caption='Blended Image with Labels', use_column_width=True)
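
        # cv2.imwrite expects BGR channel order and, for .png files, an
        # IMWRITE_PNG_COMPRESSION level in 0-9 rather than a JPEG quality flag.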
        os.makedirs('example', exist_ok=True)
        cv2.imwrite('example/result15.png',
                    cv2.cvtColor(blended_img, cv2.COLOR_RGB2BGR),
                    [int(cv2.IMWRITE_PNG_COMPRESSION), 3])
        st.success('Saliency detection complete. Result saved as "example/result15.png".')