File size: 3,983 Bytes
8395863 8271835 776dd3c 8271835 36945ed fc1d3e9 36945ed 8395863 fc1d3e9 8395863 660142e cb374b8 3a6ce4c cb374b8 3a6ce4c cb374b8 3a6ce4c cb374b8 660142e 8395863 3a6ce4c 660142e 8271835 8395863 660142e 8395863 660142e adda212 8395863 660142e 8395863 660142e 776dd3c 3a6ce4c 660142e 8271835 3a6ce4c 621f2db 3a6ce4c 8271835 cb374b8 83d61e1 776dd3c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 |
import os

import cv2
import numpy as np
import streamlit as st
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms, models

from TranSalNet_Res import TranSalNet
from utils.data_process import preprocess_img, postprocess_img
def _load_saliency_model(weights_path, target_device):
    """Build a TranSalNet, load its weights from ``weights_path`` and prepare it.

    The checkpoint is mapped onto ``target_device`` while loading, the network
    is moved to that device, and ``eval()`` is called on it before returning.
    """
    net = TranSalNet()
    checkpoint = torch.load(weights_path, map_location=target_device)
    net.load_state_dict(checkpoint)
    net.to(target_device)
    net.eval()
    return net


# Everything in this app runs on the CPU.
device = torch.device('cpu')
model = _load_saliency_model('pretrained_models/TranSalNet_Res.pth', device)
def _contour_centroid(contour):
    """Return the integer ``(x, y)`` centroid of a single OpenCV contour."""
    moments = cv2.moments(contour)
    if moments["m00"] != 0:
        return (
            int(moments["m10"] / moments["m00"]),
            int(moments["m01"] / moments["m00"]),
        )
    # Zero-area contours (a single pixel, a thin line) have m00 == 0, so the
    # moment-based centroid is undefined. Use the mean of the contour's own
    # points instead of (0, 0), which would put the label in the image corner.
    mean_x, mean_y = contour.reshape(-1, 2).mean(axis=0)
    return int(mean_x), int(mean_y)


def count_and_label_red_patches(heatmap, threshold=200, base_image=None):
    """Number the red patches of ``heatmap`` on a copy of the source image.

    A pixel is treated as red when channel index 2 of ``heatmap`` is greater
    than ``threshold`` (index 2 is the red channel of a BGR image such as the
    output of ``cv2.applyColorMap``). The external contours of that mask are
    sorted by area, largest first. Each one gets a filled black disc at its
    centroid with its 1-based rank drawn in white, plus a black line joining
    it to the centroid of the largest patch.

    Args:
        heatmap: 3-channel array indexed as ``heatmap[:, :, 2]``.
        threshold: Values strictly above this in channel 2 count as red.
        base_image: Image to annotate; anything ``np.array`` accepts. When
            omitted, the module-level ``image`` is used, which is what this
            function always read before the parameter existed.

    Returns:
        ``(annotated, count)`` where ``annotated`` is a NumPy copy of the
        source image with the markers drawn on it and ``count`` is the number
        of patches found. With no red patch the copy is returned unmarked
        and ``count`` is 0.
    """
    red_mask = heatmap[:, :, 2] > threshold
    contours, _ = cv2.findContours(
        red_mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    # Largest patch first, so label 1 is always the biggest region.
    contours = sorted(contours, key=cv2.contourArea, reverse=True)

    # Fall back to the module-level `image` for callers that pass only the heatmap.
    source = image if base_image is None else base_image
    original_image = np.array(source)

    # Nothing exceeded the threshold: there is no "largest" patch to anchor to.
    if not contours:
        return original_image, 0

    anchor = _contour_centroid(contours[0])

    # Drawing style, identical for every patch.
    radius = 20  # large enough for the label to fit inside the disc
    marker_color = (0, 0, 0)  # black disc and black connector line
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 1
    font_color = (255, 255, 255)  # white label text
    line_type = cv2.LINE_AA

    for label, contour in enumerate(contours, start=1):
        cx, cy = _contour_centroid(contour)
        cv2.circle(original_image, (cx, cy), radius, marker_color, -1)
        # Connect this patch to the largest one.
        cv2.line(original_image, (cx, cy), anchor, marker_color, 2)
        cv2.putText(
            original_image, str(label), (cx - 10, cy + 10),
            font, font_scale, font_color, 2, line_type,
        )

    return original_image, len(contours)
# Streamlit page: upload an image, run the saliency model on it, then show the
# numbered red patches and the heatmap blended over the photo.
st.title('Saliency Detection App')
st.write('Upload an image for saliency detection:')
uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_image:
    # Force 3-channel RGB. PNG uploads can be RGBA, palette or grayscale, and
    # those would break the 3-channel colour conversion, split and blend below.
    image = Image.open(uploaded_image).convert('RGB')
    st.image(image, caption='Uploaded Image', use_column_width=True)

    if st.button('Detect Saliency'):
        # Model input: resized to 384x288 (width x height), BGR channel order,
        # scaled to [0, 1], laid out as a (1, 3, 288, 384) float32 tensor.
        img = np.array(image.resize((384, 288)))
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        img = img / 255.
        img = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
        img = torch.from_numpy(img).type(torch.FloatTensor).to(device)

        # Inference only, so skip autograd bookkeeping.
        with torch.no_grad():
            pred_saliency = model(img).squeeze().cpu().numpy()

        heatmap = (pred_saliency * 255).astype(np.uint8)
        # JET is a rainbow colormap: low values map to blue, high values to
        # red. The result is a BGR image.
        heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
        heatmap = cv2.resize(heatmap, (image.width, image.height))

        # Convert the photo to BGR so its channel order matches the heatmap it
        # is blended with; otherwise red and blue are swapped in the output.
        enhanced_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        b, g, r = cv2.split(enhanced_image)
        # Contrast-limited histogram equalisation on the blue channel only.
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        b_enhanced = clahe.apply(b)
        enhanced_image = cv2.merge((b_enhanced, g, r))

        alpha = 0.7  # weight of the heatmap in the blend
        blended_img = cv2.addWeighted(enhanced_image, 1 - alpha, heatmap, alpha, 0)

        original_image, num_red_patches = count_and_label_red_patches(heatmap)
        st.image(original_image, caption=f'Image with {num_red_patches} Red Patches', use_column_width=True, channels='RGB')
        st.image(blended_img, caption='Blended Image', use_column_width=True, channels='BGR')

        # cv2.imwrite does not create missing directories and signals failure
        # through its return value, so create the folder and check the result.
        # No encoder flags are passed: the old JPEG-quality flag does not
        # apply to a .png file and 200 is outside its 0-100 range.
        os.makedirs('example', exist_ok=True)
        if cv2.imwrite('example/result15.png', blended_img):
            st.success('Saliency detection complete. Result saved as "example/result15.png".')
        else:
            st.error('Saliency detection complete, but the result could not be saved as "example/result15.png".')
|