File size: 2,983 Bytes
8395863 8271835 776dd3c 8271835 36945ed fc1d3e9 36945ed 8395863 fc1d3e9 8395863 660142e c0f6a4b 660142e 8395863 660142e 8271835 8395863 660142e 8395863 660142e adda212 8395863 660142e 8395863 660142e 8271835 660142e 776dd3c c0f6a4b 660142e 8271835 660142e 621f2db 660142e 8271835 c0f6a4b 83d61e1 776dd3c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
import os

import cv2
import numpy as np
import streamlit as st
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms, models
from tqdm import tqdm

from TranSalNet_Res import TranSalNet
from utils.data_process import preprocess_img, postprocess_img
device = torch.device('cpu')

# Streamlit re-executes this whole script on every widget interaction, so an
# uncached load would re-read the checkpoint from disk each time.
# ``st.cache_resource`` keeps the loaded model across reruns; the ``getattr``
# fallback keeps Streamlit releases that lack it working (just uncached).
_cache_resource = getattr(st, 'cache_resource', lambda func: func)


@_cache_resource
def _load_model():
    """Build TranSalNet, load the pretrained weights onto ``device``, and switch to eval mode."""
    net = TranSalNet()
    net.load_state_dict(torch.load('pretrained_models/TranSalNet_Res.pth', map_location=device))
    net.to(device)
    net.eval()
    return net


model = _load_model()
def count_and_label_red_patches(heatmap, threshold=200):
    """Count connected "red" regions of a heatmap and number them in place.

    A pixel counts as red when its channel-2 value (red in OpenCV's BGR
    order) is strictly greater than ``threshold``. Regions are found with
    8-connectivity, and each region's 1-based index is drawn in black at the
    centre of its bounding box.

    NOTE(review): with ``cv2.COLORMAP_JET`` the red channel is also high for
    orange/yellow hues, so "red" here is broader than pure red -- confirm
    this is intended.

    Args:
        heatmap: Image indexed as ``[row, col, channel]``, channel 2 taken as
            red; the caller in this file passes a ``cv2.applyColorMap``
            result. It is drawn on in place by ``cv2.putText``.
        threshold: Red-channel values above this are treated as red.

    Returns:
        ``(heatmap, num_red_patches)``: the same (now annotated) array and the
        number of red regions found, as a plain ``int``.
    """
    red_mask = (heatmap[:, :, 2] > threshold).astype(np.uint8)
    num_labels, _, stats, _ = cv2.connectedComponentsWithStats(red_mask, connectivity=8)
    # Label 0 is always the background, so the patches are labels 1..num_labels-1.
    num_red_patches = int(num_labels) - 1
    for label in range(1, num_labels):
        left = int(stats[label, cv2.CC_STAT_LEFT])
        top = int(stats[label, cv2.CC_STAT_TOP])
        width = int(stats[label, cv2.CC_STAT_WIDTH])
        height = int(stats[label, cv2.CC_STAT_HEIGHT])
        # Centre of the bounding box (not the pixel centroid); for
        # non-convex patches this point can fall outside the patch itself.
        center = (left + width // 2, top + height // 2)
        cv2.putText(heatmap, str(label), center, cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2, cv2.LINE_AA)
    return heatmap, num_red_patches
st.title('Saliency Detection App')
st.write('Upload an image for saliency detection:')
uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_image:
    # Normalise to 3-channel RGB: PNG uploads may be RGBA, greyscale or
    # palette images, which would break the 3-channel transpose/split below.
    image = Image.open(uploaded_image).convert('RGB')
    st.image(image, caption='Uploaded Image', use_column_width=True)

    if st.button('Detect Saliency'):
        # --- Saliency prediction ------------------------------------------
        # PIL's resize takes (width, height).
        img = image.resize((384, 288))
        img = np.array(img) / 255.
        # HWC -> CHW, then add a leading batch axis.
        img = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
        img = torch.from_numpy(img).type(torch.FloatTensor).to(device)
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            # NOTE(review): assumes squeeze() leaves a 2-D (H, W) map -- confirm model output shape.
            pred_saliency = model(img).squeeze().cpu().numpy()
        # Clip before the uint8 cast so out-of-range predictions cannot wrap around.
        heatmap = np.clip(pred_saliency * 255, 0, 255).astype(np.uint8)
        # JET maps low values to blue and high values to red; output is BGR.
        heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
        heatmap = cv2.resize(heatmap, (image.width, image.height))
        heatmap, num_red_patches = count_and_label_red_patches(heatmap)

        # --- Blue-channel contrast enhancement ------------------------------
        # PIL yields RGB; convert to OpenCV's BGR order so the split really
        # gives (b, g, r) and the image lines up with the BGR heatmap.
        enhanced_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        b, g, r = cv2.split(enhanced_image)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        b_enhanced = clahe.apply(b)
        enhanced_image = cv2.merge((b_enhanced, g, r))

        alpha = 0.7
        blended_img = cv2.addWeighted(enhanced_image, 1 - alpha, heatmap, alpha, 0)
        st.image(heatmap, caption='Enhanced Saliency Heatmap', use_column_width=True, channels='BGR')
        st.image(enhanced_image, caption='Enhanced Blue Image', use_column_width=True, channels='BGR')
        st.image(blended_img, caption=f'Blended Image with {num_red_patches} Red Patches', use_column_width=True, channels='BGR')

        # --- Save result ------------------------------------------------------
        # cv2.imwrite does not create directories and signals failure only via
        # its return value, so create the folder and check the result.
        # (IMWRITE_JPEG_QUALITY does not apply to .png output, so no flags are passed.)
        os.makedirs('example', exist_ok=True)
        if cv2.imwrite('example/result15.png', blended_img):
            st.success('Saliency detection complete. Result saved as "example/result15.png".')
        else:
            st.error('Saliency detection complete, but the result could not be saved to "example/result15.png".')
|