import streamlit as st
import cv2
import numpy as np
import torch
from torchvision import transforms
from PIL import Image
from TranSalNet_Res import TranSalNet

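# Load the pretrained TranSalNet weights on CPU and switch to inference mode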
device = torch.device('cpu')
model = TranSalNet()
model.load_state_dict(torch.load('pretrained_models/TranSalNet_Res.pth', map_location=device))
model.to(device)
model.eval()

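# Streamlit UI: title, prompt, and image uploader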
st.title('Saliency Detection App')
st.write('Upload an image for saliency detection:')
uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_image is not None:
    # Force RGB: uploaded PNGs may carry an alpha channel or a palette
    image = Image.open(uploaded_image).convert('RGB')
    st.image(image, caption='Uploaded Image', use_column_width=True)

    if st.button('Detect Saliency'):
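        # Preprocess: resize to TranSalNet's 384x288 input, scale to [0, 1],
        # and reorder to a (1, 3, H, W) float tensor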
        img = image.resize((384, 288))
        img = np.array(img) / 255.
        img = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
        img = torch.from_numpy(img)
        img = img.type(torch.FloatTensor).to(device)

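        # Run inference with gradient tracking disabled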
        with torch.no_grad():
            pred_saliency = model(img)

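        # Convert the (1, 1, H, W) prediction to an 8-bit grayscale image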
        toPIL = transforms.ToPILImage()
        pic = toPIL(pred_saliency.squeeze())

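        # Colorize the saliency map; OpenCV colormaps return BGR images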
        colorized_img = cv2.applyColorMap(np.uint8(pic), cv2.COLORMAP_JET)

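        # Keep everything in BGR (OpenCV's channel order) and resize the map
        # back to the original image dimensions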
        original_img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        colorized_img = cv2.resize(colorized_img, (original_img.shape[1], original_img.shape[0]))

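        # Collapse the heat map to intensities and binarize with Otsu's method;
        # a fixed threshold of 0 would mark nearly every pixel as salient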
        intensity_map = cv2.cvtColor(colorized_img, cv2.COLOR_BGR2GRAY)
        _, binary_map = cv2.threshold(intensity_map, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

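        # Find the salient regions and rank them by area, largest first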
        contours, _ = cv2.findContours(binary_map, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        contours = sorted(contours, key=cv2.contourArea, reverse=True)

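        # Overlay the heat map on the original image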
        alpha = 0.4
        blended_img = cv2.addWeighted(original_img, 1 - alpha, colorized_img, alpha, 0)

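        # Number each salient region at its centroid, skipping degenerate contours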
        font = cv2.FONT_HERSHEY_SIMPLEX
        for i, contour in enumerate(contours, start=1):
            M = cv2.moments(contour)
            if M["m00"] == 0:
                continue
            center_x = int(M["m10"] / M["m00"])
            center_y = int(M["m01"] / M["m00"])
            cv2.putText(blended_img, str(i), (center_x, center_y), font, 1, (0, 0, 255), 2, cv2.LINE_AA)

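        # Streamlit expects RGB, so convert the BGR results for display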
        st.image(cv2.cvtColor(colorized_img, cv2.COLOR_BGR2RGB), caption='Colorized Saliency Map', use_column_width=True)
        st.image(cv2.cvtColor(blended_img, cv2.COLOR_BGR2RGB), caption='Blended Image with Labels', use_column_width=True)

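        # Save the BGR result directly; the JPEG quality flag does not apply to PNG output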
        cv2.imwrite('example/result15.png', blended_img)
        st.success('Saliency detection complete. Result saved as "example/result15.png".')