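"""Streamlit demo: TranSalNet-Res saliency detection on an uploaded image.

The predicted saliency map is colorized, salient regions are numbered, and the
result is blended over a solid blue background before being displayed and saved.
"""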
import streamlit as st
import cv2
import numpy as np
import torch
from PIL import Image
from TranSalNet_Res import TranSalNet
# Load the model and set the device
device = torch.device('cpu')
model = TranSalNet()
model.load_state_dict(torch.load('pretrained_models/TranSalNet_Res.pth', map_location=device))
model.to(device)
model.eval()
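# Note: inference runs on the CPU here; with a CUDA-capable GPU the same code
# works with device = torch.device('cuda' if torch.cuda.is_available() else 'cpu').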
# Define Streamlit app
st.title('Saliency Detection App')
st.write('Upload an image for saliency detection:')
uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
if uploaded_image:
    image = Image.open(uploaded_image).convert('RGB')
    st.image(image, caption='Uploaded Image', use_column_width=True)

    # Check if the user clicks the button
    if st.button('Detect Saliency'):
        # Create a blue background image with the same dimensions as the original image
        blue_background = np.zeros_like(np.array(image))
        blue_background[:] = (255, 0, 0)  # Set the background to blue (in BGR format)

        # Preprocess the image: TranSalNet expects a 384x288 RGB input scaled to [0, 1]
        img = image.resize((384, 288))
        img = np.array(img) / 255.
        img = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
        img = torch.from_numpy(img)
        img = img.type(torch.FloatTensor).to(device)
        # Get saliency prediction (no gradients are needed for inference)
        with torch.no_grad():
            pred_saliency = model(img)

        # Convert the prediction to an 8-bit grayscale saliency map
        saliency = pred_saliency.squeeze().cpu().numpy()
        saliency_gray = (255 * (saliency - saliency.min()) /
                         (saliency.max() - saliency.min() + 1e-8)).astype(np.uint8)

        # Colorize the grayscale prediction
        colorized_img = cv2.applyColorMap(saliency_gray, cv2.COLORMAP_JET)
        # Ensure the colorized image has the same dimensions as the original image
        original_img = np.array(image)
        colorized_img = cv2.resize(colorized_img, (original_img.shape[1], original_img.shape[0]))

        # Create an empty label map and binarize the saliency map to locate salient regions
        label_map = np.zeros_like(colorized_img)
        intensity_map = cv2.resize(saliency_gray, (original_img.shape[1], original_img.shape[0]))
        _, binary_map = cv2.threshold(intensity_map, 127, 255, cv2.THRESH_BINARY)  # threshold value is tunable
        # Rank the salient regions by area and draw a rank number at each centroid
        font = cv2.FONT_HERSHEY_SIMPLEX
        contours, _ = cv2.findContours(binary_map, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        contours = sorted(contours, key=cv2.contourArea, reverse=True)
        for i, contour in enumerate(contours):
            M = cv2.moments(contour)
            if M["m00"] == 0:
                continue
            center_x = int(M["m10"] / M["m00"])
            center_y = int(M["m01"] / M["m00"])
            cv2.putText(label_map, str(i + 1), (center_x, center_y), font, 1, (255, 0, 0), 2, cv2.LINE_AA)
        # Blend the colorized image with the blue background
        alpha = 0.3  # Adjust the alpha value to control blending strength
        blended_img = cv2.addWeighted(blue_background, 1 - alpha, colorized_img, alpha, 0)

        # Overlay the rank labels on the blended image
        blended_img = cv2.add(blended_img, label_map)

        # Display the final result (applyColorMap produces BGR, so tell Streamlit)
        st.image(blended_img, caption='Blended Image with Labels', channels='BGR', use_column_width=True)

        # Save the final result as a PNG
        cv2.imwrite('example/result15.png', blended_img)
        st.success('Saliency detection complete. Result saved as "example/result15.png"')
        st.write('Finished, check the result at: example/result15.png')
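# To run the app locally (assuming this script is saved as app.py in the
# TranSalNet repository root, next to pretrained_models/):
#   streamlit run app.py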