File size: 3,459 Bytes
8395863 8271835 776dd3c 8271835 36945ed fc1d3e9 36945ed 8395863 fc1d3e9 8395863 660142e adda212 660142e 8395863 660142e 8271835 8395863 660142e 8395863 adda212 660142e adda212 8395863 adda212 660142e 8395863 adda212 660142e 8271835 adda212 660142e 776dd3c adda212 660142e 8271835 adda212 660142e 621f2db adda212 660142e 8271835 adda212 83d61e1 776dd3c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
from pathlib import Path

import streamlit as st
import cv2
import numpy as np
import torch
from torchvision import transforms, models
from PIL import Image
from tqdm import tqdm
import torch.nn as nn

from TranSalNet_Res import TranSalNet
from utils.data_process import preprocess_img, postprocess_img
# All inference in this app runs on the CPU.
device = torch.device('cpu')

# Build the saliency network and restore its pretrained weights; the checkpoint
# tensors are mapped onto the CPU while loading.
# NOTE(review): Streamlit re-executes this whole script on every widget
# interaction, so this checkpoint is reloaded each time — consider caching it.
model = TranSalNet()
model.load_state_dict(
    torch.load(
        'pretrained_models/TranSalNet_Res.pth',
        map_location=torch.device('cpu'),
    )
)

# nn.Module.to() and .eval() both return the module itself, so they chain:
# place the weights on `device`, then switch the network to evaluation mode.
model.to(device).eval()
def count_and_label_red_patches(heatmap, threshold=200):
    """Find strongly red regions of a BGR heatmap and number them in place.

    A pixel belongs to a "red patch" when channel 2 (red, in OpenCV's BGR
    order) is strictly greater than ``threshold``. Such pixels are grouped
    into 8-connected components, and each component's index (1, 2, ...) is
    drawn in black text at the centre of its bounding box.

    Args:
        heatmap: Image array indexed as ``heatmap[:, :, channel]``; presumably
            a uint8 BGR image such as the output of ``cv2.applyColorMap``.
            It is modified in place by ``cv2.putText`` and also returned.
        threshold: Red-channel value a pixel must exceed to count as red.

    Returns:
        Tuple ``(heatmap, num_red_patches)``: the annotated heatmap and the
        number of connected red components found (0 when there are none).
    """
    red_mask = (heatmap[:, :, 2] > threshold).astype(np.uint8)
    # connectedComponentsWithStats already reports how many labels it used;
    # label 0 is the background, so the real patches are 1 .. num_labels - 1.
    num_labels, _, stats, _ = cv2.connectedComponentsWithStats(red_mask, connectivity=8)
    num_red_patches = num_labels - 1

    for i in range(1, num_labels):
        left = stats[i, cv2.CC_STAT_LEFT]
        top = stats[i, cv2.CC_STAT_TOP]
        width = stats[i, cv2.CC_STAT_WIDTH]
        height = stats[i, cv2.CC_STAT_HEIGHT]
        # Anchor the number at the bounding-box centre (not the true centroid).
        # putText treats this point as the text's bottom-left corner.
        anchor = (int(left + width / 2), int(top + height / 2))
        cv2.putText(heatmap, str(i), anchor, cv2.FONT_HERSHEY_SIMPLEX, 1,
                    (0, 0, 0), 2, cv2.LINE_AA)
    return heatmap, num_red_patches
# --- Streamlit UI: upload an image, run saliency detection, show and save results ---
st.title('Saliency Detection App')
st.write('Upload an image for saliency detection:')
uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_image is not None:
    # Force 3-channel RGB: PNG uploads may be RGBA, palette or grayscale, which
    # would break the (2, 0, 1) transpose and the 3-channel OpenCV calls below.
    image = Image.open(uploaded_image).convert('RGB')
    st.image(image, caption='Uploaded Image', use_column_width=True)

    if st.button('Detect Saliency'):
        # PIL's resize takes (width, height), giving a (288, 384, 3) array that
        # becomes a (1, 3, 288, 384) float tensor with values in [0, 1].
        # NOTE(review): pixels are fed in RGB order — confirm this matches the
        # channel order the TranSalNet checkpoint was trained with.
        img = np.array(image.resize((384, 288))) / 255.
        img = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
        img = torch.from_numpy(img).type(torch.FloatTensor).to(device)

        # Inference only: skip building an autograd graph.
        with torch.no_grad():
            pred_saliency = model(img).squeeze().cpu().numpy()

        # Scale the saliency map to 0-255 (clipping prevents uint8 wrap-around
        # if the output ever leaves [0, 1]) and colour it with the JET colormap,
        # which runs from blue (low saliency) to red (high saliency).
        heatmap = np.clip(pred_saliency * 255, 0, 255).astype(np.uint8)
        heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
        # Resize the BGR heatmap to the uploaded image's dimensions.
        heatmap = cv2.resize(heatmap, (image.width, image.height))

        # Number the red (most salient) patches directly on the heatmap.
        heatmap, num_red_patches = count_and_label_red_patches(heatmap)

        # PIL arrays are RGB but OpenCV (and the heatmap) are BGR: convert first
        # so `b` really is the blue channel and the blend/display colours match.
        enhanced_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        b, g, r = cv2.split(enhanced_image)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        enhanced_image = cv2.merge((clahe.apply(b), g, r))

        alpha = 0.7  # weight of the heatmap in the blend; the photo gets 1 - alpha
        blended_img = cv2.addWeighted(enhanced_image, 1 - alpha, heatmap, alpha, 0)

        # Every displayed array is in BGR order.
        st.image(heatmap, caption='Enhanced Saliency Heatmap', use_column_width=True, channels='BGR')
        st.image(enhanced_image, caption='Enhanced Blue Image', use_column_width=True, channels='BGR')
        st.image(blended_img, caption=f'Blended Image with {num_red_patches} Red Patches',
                 use_column_width=True, channels='BGR')

        # Save the result. cv2.imwrite signals failure only by returning False
        # (e.g. when the folder is missing), so create the folder and check it.
        # No JPEG-quality flag: it does not apply to PNG output.
        result_path = Path('example') / 'result15.png'
        result_path.parent.mkdir(parents=True, exist_ok=True)
        if cv2.imwrite(str(result_path), blended_img):
            st.success('Saliency detection complete. Result saved as "example/result15.png".')
        else:
            st.error(f'Saliency detection complete, but saving "{result_path}" failed.')
|