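"""Streamlit app for saliency detection with TranSalNet.

Runs TranSalNet_Res on an uploaded image, colorizes the predicted saliency map,
ranks the detected salient regions by area, and overlays the ranking labels on a
blend of the saliency map and the original image.
"""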
import streamlit as st
import cv2
import numpy as np
import torch
from torchvision import transforms
from PIL import Image

from TranSalNet_Res import TranSalNet
# Load the model and set the device
device = torch.device('cpu')
model = TranSalNet()
model.load_state_dict(torch.load('pretrained_models/TranSalNet_Res.pth', map_location=device))
model.to(device)
model.eval()
# Define Streamlit app
st.title('Saliency Detection App')
st.write('Upload an image for saliency detection:')
uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
if uploaded_image:
    image = Image.open(uploaded_image).convert('RGB')  # force 3 channels (drops alpha, expands grayscale)
    st.image(image, caption='Uploaded Image', use_column_width=True)

    # Run saliency detection when the user clicks the button
    if st.button('Detect Saliency'):
        # Preprocess the image: resize to TranSalNet's 384x288 input size, scale to [0, 1],
        # and reorder to a (1, C, H, W) float tensor
        img = image.resize((384, 288))
        img = np.array(img) / 255.
        img = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
        img = torch.from_numpy(img).float().to(device)
        # Get the saliency prediction (no gradients needed at inference time)
        with torch.no_grad():
            pred_saliency = model(img)

        # Convert the result back to a PIL image
        toPIL = transforms.ToPILImage()
        pic = toPIL(pred_saliency.squeeze())
        # Colorize the grayscale prediction; applyColorMap returns BGR, so convert to RGB
        # to match the RGB original before blending and display
        colorized_img = cv2.applyColorMap(np.array(pic, dtype=np.uint8), cv2.COLORMAP_JET)
        colorized_img = cv2.cvtColor(colorized_img, cv2.COLOR_BGR2RGB)

        # Ensure the colorized image has the same dimensions as the original image
        original_img = np.array(image)
        colorized_img = cv2.resize(colorized_img, (original_img.shape[1], original_img.shape[0]))

        # Compute intensity values from the colorized image
        intensity_map = cv2.cvtColor(colorized_img, cv2.COLOR_RGB2GRAY)

        # Threshold the intensity map to create a binary mask of salient regions
        _, binary_map = cv2.threshold(intensity_map, 0, 255, cv2.THRESH_BINARY)
        # Find contours in the binary map
        contours, _ = cv2.findContours(binary_map, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        # Sort the contours by area in descending order so rank 1 is the largest region
        contours = sorted(contours, key=cv2.contourArea, reverse=True)

        # Create an empty label map for ranking based on area
        label_map = np.zeros_like(intensity_map)

        # Rank and label each region at its centroid (computed from image moments)
        for i, contour in enumerate(contours):
            M = cv2.moments(contour)
            if M["m00"] == 0:
                continue  # skip degenerate contours with zero area
            center_x = int(M["m10"] / M["m00"])
            center_y = int(M["m01"] / M["m00"])
            cv2.putText(label_map, str(i + 1), (center_x, center_y),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, 255, 2, cv2.LINE_AA)
        # Blend the colorized saliency map with the original image
        alpha = 0.4  # adjust the alpha value to control blending strength
        blended_img = cv2.addWeighted(original_img, 1 - alpha, colorized_img, alpha, 0)

        # Overlay the rank labels on the blended image at each region's bounding-box corner
        font = cv2.FONT_HERSHEY_SIMPLEX
        for i in range(1, len(contours) + 1):
            x, y, w, h = cv2.boundingRect(contours[i - 1])
            org = (x, y)
            color = (255, 0, 0)  # red (images are handled in RGB order here)
            thickness = 2
            cv2.putText(blended_img, str(i), org, font, 1, color, thickness, cv2.LINE_AA)
        # Display the final results
        st.image(colorized_img, caption='Colorized Saliency Map', use_column_width=True)
        st.image(blended_img, caption='Blended Image with Labels', use_column_width=True)

        # Save the final result (convert back to BGR, since OpenCV's imwrite expects BGR;
        # the JPEG-quality flag does not apply to PNG output)
        cv2.imwrite('example/result15.png', cv2.cvtColor(blended_img, cv2.COLOR_RGB2BGR))
        st.success('Saliency detection complete. Result saved as "example/result15.png".')
        st.write('Finished, check the result at: example/result15.png')
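
# To run this app locally (assuming this file is saved as app.py and the TranSalNet_Res
# weights are available at pretrained_models/TranSalNet_Res.pth):
#   streamlit run app.py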