# VisAt / app.py
# Page-header residue from the repository file viewer (not part of the program):
# Tanzeer's picture | Update app.py | 3a6ce4c | raw | history blame | 3.96 kB
import os

import cv2
import numpy as np
import streamlit as st
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms, models

from TranSalNet_Res import TranSalNet
from utils.data_process import preprocess_img, postprocess_img
# Inference is pinned to the CPU; the same device object is reused both for
# deserialising the checkpoint and for placing the network.
device = torch.device('cpu')

# Build the network, restore its pretrained weights, then switch it to
# evaluation mode on the target device.
# NOTE(review): this runs at module top level, so it executes whenever the
# script itself is (re-)executed - confirm whether caching is wanted here.
model = TranSalNet()
state_dict = torch.load('pretrained_models/TranSalNet_Res.pth', map_location=device)
model.load_state_dict(state_dict)
model = model.to(device).eval()
def count_and_label_red_patches(heatmap, threshold=200, base_image=None):
    """Annotate the hot regions of a saliency heatmap on a copy of the photo.

    Pixels whose channel-2 value exceeds ``threshold`` are grouped into
    8-connected components. Every component is marked at the centre of its
    bounding box with a filled black disc carrying a white 1-based index, and
    every pair of markers is joined by a black line.

    Args:
        heatmap: H x W x 3 array. Channel index 2 is read as "red", which
            matches the BGR output of ``cv2.applyColorMap`` that the caller
            in this file passes in.
        threshold: Channel-2 values strictly above this count as red.
        base_image: Image to draw on (anything ``np.array`` accepts). When
            omitted, the module-level ``image`` is used, which is what this
            function always did before the parameter existed. It is assumed
            to have the same width and height as ``heatmap`` - the marker
            coordinates are taken from the heatmap and used unchanged.

    Returns:
        A tuple ``(annotated, num_red_patches)``: a newly allocated array
        holding the annotated picture, and the number of components found.
    """
    if base_image is None:
        # Backward-compatible fallback for callers that pass only the heatmap.
        base_image = image

    red_mask = (heatmap[:, :, 2] > threshold).astype(np.uint8)
    num_labels, _, stats, _ = cv2.connectedComponentsWithStats(red_mask, connectivity=8)
    # Label 0 is the background, so it is excluded from the patch count.
    num_red_patches = num_labels - 1

    # np.array() copies, so the caller's image is never drawn on.
    annotated = np.array(base_image)

    # Marker position = centre of each component's bounding box.
    centers = [
        (
            int(stats[label, cv2.CC_STAT_LEFT] + stats[label, cv2.CC_STAT_WIDTH] / 2),
            int(stats[label, cv2.CC_STAT_TOP] + stats[label, cv2.CC_STAT_HEIGHT] / 2),
        )
        for label in range(1, num_labels)
    ]

    radius = 20  # Marker disc radius in pixels; adjust to resize the markers.
    circle_color = (0, 0, 0)  # Black disc; adjust to change the marker colour.
    line_color = (0, 0, 0)  # Black connector; adjust to change the line colour.
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 1
    font_color = (255, 255, 255)
    line_type = cv2.LINE_AA

    # Draw in three passes - connectors, then discs, then labels - so that a
    # connector belonging to a later patch can never strike through the label
    # of an earlier one.
    for index, start in enumerate(centers):
        for end in centers[index + 1:]:
            cv2.line(annotated, start, end, line_color, 2)

    for center in centers:
        cv2.circle(annotated, center, radius, circle_color, -1)

    for number, (center_x, center_y) in enumerate(centers, start=1):
        cv2.putText(
            annotated,
            str(number),
            (center_x - 10, center_y + 10),
            font,
            font_scale,
            font_color,
            2,
            line_type,
        )

    return annotated, num_red_patches
st.title('Saliency Detection App')
st.write('Upload an image for saliency detection:')
uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_image:
    # Normalise to 3-channel RGB up front: PNG uploads may be RGBA, palette or
    # grayscale, and the channel split / colour conversion below needs exactly
    # three channels.
    image = Image.open(uploaded_image).convert('RGB')
    st.image(image, caption='Uploaded Image', use_column_width=True)

    if st.button('Detect Saliency'):
        # Build the network input: 384x288 BGR, scaled to [0, 1], laid out as
        # a (1, 3, H, W) float tensor.
        img = np.array(image.resize((384, 288)))
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)  # Convert to BGR color space
        img = img / 255.
        img = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
        img = torch.from_numpy(img).type(torch.FloatTensor).to(device)

        # Inference only - no autograd graph is needed.
        with torch.no_grad():
            pred_saliency = model(img).squeeze().cpu().numpy()

        # Clip before the uint8 cast so out-of-range values cannot wrap around.
        heatmap = np.clip(pred_saliency * 255, 0, 255).astype(np.uint8)
        # JET colormap (blue = low, red = high); the result is a BGR image.
        heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
        heatmap = cv2.resize(heatmap, (image.width, image.height))

        # Bring the photo into BGR as well, so it shares a channel order with
        # the heatmap it is blended with and with the BGR display/save below.
        enhanced_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        b, g, r = cv2.split(enhanced_image)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        b_enhanced = clahe.apply(b)  # Contrast-enhance the blue channel only.
        enhanced_image = cv2.merge((b_enhanced, g, r))

        alpha = 0.7  # Weight of the heatmap in the blend.
        blended_img = cv2.addWeighted(enhanced_image, 1 - alpha, heatmap, alpha, 0)

        original_image, num_red_patches = count_and_label_red_patches(heatmap)

        st.image(original_image, caption=f'Image with {num_red_patches} Red Patches', use_column_width=True, channels='RGB')
        st.image(blended_img, caption='Blended Image', use_column_width=True, channels='BGR')

        # cv2.imwrite does not create missing directories, so make sure the
        # target directory exists and only report success if the write worked.
        os.makedirs('example', exist_ok=True)
        if cv2.imwrite('example/result15.png', blended_img):
            st.success('Saliency detection complete. Result saved as "example/result15.png".')
        else:
            st.error('Saliency detection complete, but the result could not be saved as "example/result15.png".')