import os

import cv2
import numpy as np
import streamlit as st
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms, models
from tqdm import tqdm

from TranSalNet_Res import TranSalNet
from utils.data_process import preprocess_img, postprocess_img


# All inference in this script runs on the CPU; the checkpoint tensors are
# mapped onto the same device while they are being loaded.
device = torch.device('cpu')

# Build the network, restore the pretrained weights, then move it to the
# target device and put it in evaluation mode for prediction.
model = TranSalNet()
model.load_state_dict(
    torch.load('pretrained_models/TranSalNet_Res.pth', map_location=device)
)
model = model.to(device).eval()


# --- Saliency detection page -------------------------------------------------

# Network input resolution as (width, height), the order PIL's resize expects.
MODEL_INPUT_SIZE = (384, 288)
# Weight of the colour-mapped saliency layer when it is blended over the photo.
OVERLAY_ALPHA = 0.7
# Destination of the annotated result image.
RESULT_PATH = 'example/result15.png'


def _predict_saliency(rgb_image):
    """Run the model on a PIL RGB image and return its saliency map.

    Args:
        rgb_image: PIL image in mode ``RGB`` (three channels).

    Returns:
        A 2-D ``uint8`` NumPy array at the model's output resolution.
    """
    resized = rgb_image.resize(MODEL_INPUT_SIZE)
    # Scale 8-bit pixel values into [0, 1].
    # NOTE(review): assumes the checkpoint was trained on [0, 1] inputs --
    # confirm against the training pipeline.
    pixels = np.array(resized) / 255.
    # HWC -> CHW, then add a leading batch dimension of one.
    batch = np.expand_dims(np.transpose(pixels, (2, 0, 1)), axis=0)
    tensor = torch.from_numpy(batch).type(torch.FloatTensor).to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        pred_saliency = model(tensor)

    # ToPILImage converts the float map to an 8-bit image (values * 255).
    to_pil = transforms.ToPILImage()
    return np.uint8(to_pil(pred_saliency.squeeze()))


def _label_anchor(contour):
    """Return the contour point nearest the centre of its bounding box."""
    x, y, w, h = cv2.boundingRect(contour)
    center_x = x + w // 2
    center_y = y + h // 2
    points = contour[:, 0, :]
    distances = np.sqrt(
        (points[:, 0] - center_x) ** 2 + (points[:, 1] - center_y) ** 2
    )
    return points[np.argmin(distances)]


def _draw_region_labels(canvas, saliency_8bit):
    """Write a running number (1, 2, ...) on each salient region of *canvas*.

    Regions are the external contours of the dilated saliency map. *canvas*
    is modified in place and may have a different resolution from the map;
    label positions are rescaled accordingly.
    """
    kernel = np.ones((5, 5), np.uint8)
    dilated = cv2.dilate(saliency_8bit, kernel, iterations=1)
    # NOTE(review): findContours treats every non-zero pixel as foreground, so
    # regions are only separated where the map is exactly zero. Consider
    # thresholding first if finer regions are wanted.
    contours, _ = cv2.findContours(
        dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )

    # Contour coordinates are in saliency-map pixels; map them onto the canvas.
    scale_x = canvas.shape[1] / saliency_8bit.shape[1]
    scale_y = canvas.shape[0] / saliency_8bit.shape[0]

    font = cv2.FONT_HERSHEY_SIMPLEX
    for label, contour in enumerate(contours, start=1):
        px, py = _label_anchor(contour)
        org = (int(round(px * scale_x)), int(round(py * scale_y)))
        cv2.putText(canvas, str(label), org, font, 1, (0, 0, 255), 3, cv2.LINE_AA)


st.title('Saliency Detection App')
st.write('Upload an image for saliency detection:')
uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_image:
    image = Image.open(uploaded_image)
    st.image(image, caption='Uploaded Image', use_column_width=True)

    if st.button('Detect Saliency'):
        # Uploads may be RGBA, palette or grayscale; the model input and the
        # blend below both need exactly three channels.
        rgb_image = image.convert('RGB')
        original_img = np.array(rgb_image)

        saliency_8bit = _predict_saliency(rgb_image)

        # applyColorMap yields BGR; convert so the whole overlay is RGB, which
        # is what st.image expects for NumPy arrays.
        heatmap = cv2.applyColorMap(saliency_8bit, cv2.COLORMAP_OCEAN)
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
        heatmap = cv2.resize(
            heatmap, (original_img.shape[1], original_img.shape[0])
        )

        blended_img = cv2.addWeighted(
            original_img, 1 - OVERLAY_ALPHA, heatmap, OVERLAY_ALPHA, 0
        )
        _draw_region_labels(blended_img, saliency_8bit)

        st.image(blended_img, caption='Blended Image with Labels', use_column_width=True)

        # imwrite expects BGR and returns False on failure (for example when
        # the target directory is missing), so create the directory first and
        # only report success when the file was really written.
        os.makedirs(os.path.dirname(RESULT_PATH), exist_ok=True)
        saved = cv2.imwrite(RESULT_PATH, cv2.cvtColor(blended_img, cv2.COLOR_RGB2BGR))
        if saved:
            st.success('Saliency detection complete. Result saved as "example/result15.png".')
        else:
            st.error('Saliency detection complete, but the result could not be saved to "example/result15.png".')