|
import streamlit as st |
|
import cv2 |
|
import numpy as np |
|
import torch |
|
from torchvision import transforms, models |
|
from PIL import Image |
|
from TranSalNet_Res import TranSalNet |
|
from tqdm import tqdm |
|
import torch.nn as nn |
|
from utils.data_process import preprocess_img, postprocess_img |
|
|
|
device = torch.device('cpu')


def _load_model(weights_path='pretrained_models/TranSalNet_Res.pth'):
    """Build a TranSalNet, load *weights_path* onto ``device`` and set eval mode.

    Args:
        weights_path: Path of the saved ``state_dict`` to load.

    Returns:
        The ready-to-use network, already moved to ``device``.
    """
    net = TranSalNet()
    # Map tensors onto the configured device instead of a hard-coded CPU, so
    # changing ``device`` above is the only edit needed to switch hardware.
    net.load_state_dict(torch.load(weights_path, map_location=device))
    net.to(device)
    net.eval()
    return net


# Streamlit re-executes this whole script on every widget interaction; without
# a cache the weights would be re-read from disk on each rerun. The guard keeps
# older Streamlit versions (without st.cache_resource) working uncached.
if hasattr(st, 'cache_resource'):
    _load_model = st.cache_resource(_load_model)

model = _load_model()
|
|
|
def count_and_label_red_patches(heatmap, threshold=200, connectivity=8):
    """Count connected "red" regions of a BGR heatmap and number them in place.

    A pixel is "red" when channel index 2 (red in OpenCV's BGR order) is
    strictly greater than ``threshold``. Connected regions of such pixels are
    found with ``cv2.connectedComponentsWithStats``. Each region's 1-based
    index is drawn in black at the centre of its bounding box.

    Args:
        heatmap: ``(H, W, 3)`` uint8 image in BGR order, such as the output of
            ``cv2.applyColorMap``. It is drawn on in place.
        threshold: Red-channel value a pixel must exceed to count as red.
        connectivity: Pixel connectivity used for grouping (8 or 4).

    Returns:
        ``(heatmap, num_red_patches)``: the same array with the labels drawn,
        and the number of connected red regions as an ``int``.
    """
    red_mask = (heatmap[:, :, 2] > threshold).astype(np.uint8)

    # Label 0 is the background, so the patch count is num_labels - 1. This
    # equals labels.max() without a second pass over the whole image.
    num_labels, _, stats, _ = cv2.connectedComponentsWithStats(
        red_mask, connectivity=connectivity
    )
    num_red_patches = int(num_labels - 1)

    for i in range(1, num_labels):
        left = stats[i, cv2.CC_STAT_LEFT]
        top = stats[i, cv2.CC_STAT_TOP]
        width = stats[i, cv2.CC_STAT_WIDTH]
        height = stats[i, cv2.CC_STAT_HEIGHT]
        # Text origin (bottom-left of the glyphs) at the bounding-box centre.
        origin = (int(left + width / 2), int(top + height / 2))
        cv2.putText(heatmap, str(i), origin, cv2.FONT_HERSHEY_SIMPLEX, 1,
                    (0, 0, 0), 2, cv2.LINE_AA)

    return heatmap, num_red_patches
|
|
|
RESULT_PATH = 'example/result15.png'


def _predict_saliency(rgb_image):
    """Run the module-level ``model`` on a PIL RGB image.

    The image is resized to 384x288 (PIL takes ``(width, height)``), scaled to
    [0, 1] and laid out as a ``(1, 3, 288, 384)`` float tensor.

    NOTE(review): ``preprocess_img``/``postprocess_img`` are imported at the top
    of the file but not used here. Confirm that the model was trained on inputs
    prepared this way (plain resize, RGB order, [0, 1] scaling).

    Returns:
        The squeezed model output as a numpy array. This is presumably a 2-D
        ``(288, 384)`` map; TODO confirm against the model definition.
    """
    resized = np.array(rgb_image.resize((384, 288))) / 255.
    batch = np.expand_dims(np.transpose(resized, (2, 0, 1)), axis=0)
    tensor = torch.from_numpy(batch).type(torch.FloatTensor).to(device)
    # Pure inference: skip building an autograd graph.
    with torch.no_grad():
        prediction = model(tensor)
    # .cpu() keeps the numpy conversion valid if ``device`` is ever a GPU.
    return prediction.squeeze().cpu().numpy()


def _saliency_to_heatmap(saliency, width, height):
    """Colour-map a saliency map (JET, BGR) and resize it to ``width`` x ``height``."""
    # Assumes the model output is roughly in [0, 1] (TODO confirm). Clipping
    # makes out-of-range values saturate instead of wrapping in the uint8 cast.
    scaled = np.clip(saliency * 255, 0, 255).astype(np.uint8)
    colored = cv2.applyColorMap(scaled, cv2.COLORMAP_JET)
    return cv2.resize(colored, (width, height))


def _enhance_blue_channel(bgr_image):
    """Apply CLAHE (clipLimit=2.0, 8x8 tiles) to the blue channel of a BGR image."""
    b, g, r = cv2.split(bgr_image)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    return cv2.merge((clahe.apply(b), g, r))


def _save_result(bgr_image, path):
    """Write ``bgr_image`` to ``path``, creating the parent directory first.

    Returns:
        True if OpenCV reported a successful write, False otherwise.
    """
    import os

    try:
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        return bool(cv2.imwrite(path, bgr_image))
    except (OSError, cv2.error):
        return False


st.title('Saliency Detection App')

st.write('Upload an image for saliency detection:')

uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_image:
    # Normalise palette, greyscale, RGBA and CMYK uploads to 3-channel RGB. The
    # model input transpose and the B/G/R split below both assume 3 channels.
    image = Image.open(uploaded_image).convert('RGB')

    st.image(image, caption='Uploaded Image', use_column_width=True)

    if st.button('Detect Saliency'):
        pred_saliency = _predict_saliency(image)

        heatmap = _saliency_to_heatmap(pred_saliency, image.width, image.height)

        heatmap, num_red_patches = count_and_label_red_patches(heatmap)

        # applyColorMap produces BGR. Convert the photo from PIL's RGB to BGR so
        # the blue-channel CLAHE, the blend, the channels='BGR' display and
        # cv2.imwrite all operate on the same channel order.
        enhanced_image = _enhance_blue_channel(
            cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        )

        alpha = 0.7

        blended_img = cv2.addWeighted(enhanced_image, 1 - alpha, heatmap, alpha, 0)

        st.image(heatmap, caption='Enhanced Saliency Heatmap', use_column_width=True, channels='BGR')

        st.image(enhanced_image, caption='Enhanced Blue Image', use_column_width=True, channels='BGR')

        st.image(blended_img, caption=f'Blended Image with {num_red_patches} Red Patches', use_column_width=True, channels='BGR')

        # Only report success when the file was actually written.
        if _save_result(blended_img, RESULT_PATH):
            st.success(f'Saliency detection complete. Result saved as "{RESULT_PATH}".')
        else:
            st.error(f'Saliency detection complete, but the result could not be saved to "{RESULT_PATH}".')
|
|