Update app.py
Browse files
app.py
CHANGED
@@ -5,7 +5,6 @@ import torch
|
|
5 |
from torchvision import transforms, models
|
6 |
from PIL import Image
|
7 |
from TranSalNet_Res import TranSalNet
|
8 |
-
from tqdm import tqdm
|
9 |
import torch.nn as nn
|
10 |
from utils.data_process import preprocess_img, postprocess_img
|
11 |
|
@@ -18,15 +17,31 @@ model.eval()
|
|
18 |
def count_and_label_red_patches(heatmap, threshold=200):
|
19 |
red_mask = heatmap[:, :, 2] > threshold
|
20 |
_, labels, stats, _ = cv2.connectedComponentsWithStats(red_mask.astype(np.uint8), connectivity=8)
|
21 |
-
|
22 |
num_red_patches = labels.max()
|
23 |
|
|
|
|
|
24 |
for i in range(1, num_red_patches + 1):
|
25 |
patch_mask = (labels == i)
|
26 |
patch_centroid_x, patch_centroid_y = int(stats[i, cv2.CC_STAT_LEFT] + stats[i, cv2.CC_STAT_WIDTH] / 2), int(stats[i, cv2.CC_STAT_TOP] + stats[i, cv2.CC_STAT_HEIGHT] / 2)
|
27 |
-
|
28 |
-
|
29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
st.title('Saliency Detection App')
|
32 |
st.write('Upload an image for saliency detection:')
|
@@ -38,6 +53,8 @@ if uploaded_image:
|
|
38 |
|
39 |
if st.button('Detect Saliency'):
|
40 |
img = image.resize((384, 288))
|
|
|
|
|
41 |
img = np.array(img) / 255.
|
42 |
img = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
|
43 |
img = torch.from_numpy(img)
|
@@ -50,22 +67,21 @@ if uploaded_image:
|
|
50 |
|
51 |
heatmap = cv2.resize(heatmap, (image.width, image.height))
|
52 |
|
53 |
-
heatmap, num_red_patches = count_and_label_red_patches(heatmap)
|
54 |
-
|
55 |
enhanced_image = np.array(image)
|
56 |
b, g, r = cv2.split(enhanced_image)
|
57 |
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
|
58 |
b_enhanced = clahe.apply(b)
|
59 |
enhanced_image = cv2.merge((b_enhanced, g, r))
|
60 |
|
61 |
-
alpha = 0.7
|
62 |
blended_img = cv2.addWeighted(enhanced_image, 1 - alpha, heatmap, alpha, 0)
|
63 |
|
64 |
-
|
65 |
-
|
|
|
66 |
|
67 |
-
st.image(blended_img, caption=
|
68 |
|
69 |
-
# Create a dir with name example to save
|
70 |
cv2.imwrite('example/result15.png', blended_img, [int(cv2.IMWRITE_JPEG_QUALITY), 200])
|
71 |
st.success('Saliency detection complete. Result saved as "example/result15.png".')
|
|
|
5 |
from torchvision import transforms, models
|
6 |
from PIL import Image
|
7 |
from TranSalNet_Res import TranSalNet
|
|
|
8 |
import torch.nn as nn
|
9 |
from utils.data_process import preprocess_img, postprocess_img
|
10 |
|
|
|
17 |
def count_and_label_red_patches(heatmap, threshold=200):
|
18 |
red_mask = heatmap[:, :, 2] > threshold
|
19 |
_, labels, stats, _ = cv2.connectedComponentsWithStats(red_mask.astype(np.uint8), connectivity=8)
|
|
|
20 |
num_red_patches = labels.max()
|
21 |
|
22 |
+
original_image = np.array(image)
|
23 |
+
|
24 |
for i in range(1, num_red_patches + 1):
|
25 |
patch_mask = (labels == i)
|
26 |
patch_centroid_x, patch_centroid_y = int(stats[i, cv2.CC_STAT_LEFT] + stats[i, cv2.CC_STAT_WIDTH] / 2), int(stats[i, cv2.CC_STAT_TOP] + stats[i, cv2.CC_STAT_HEIGHT] / 2)
|
27 |
+
radius = 20 # Adjust the following variable to manage the circle image
|
28 |
+
circle_color = (0, 0, 0) # The circle is black adjust the following to change the color
|
29 |
+
cv2.circle(original_image, (patch_centroid_x, patch_centroid_y), radius, circle_color, -1) # Draw the circle
|
30 |
+
|
31 |
+
# Lines code
|
32 |
+
for j in range(i + 1, num_red_patches + 1):
|
33 |
+
patch_mask_j = (labels == j)
|
34 |
+
patch_centroid_x_j, patch_centroid_y_j = int(stats[j, cv2.CC_STAT_LEFT] + stats[j, cv2.CC_STAT_WIDTH] / 2), int(stats[j, cv2.CC_STAT_TOP] + stats[j, cv2.CC_STAT_HEIGHT] / 2)
|
35 |
+
line_color = (0, 0, 0) # Ajdust the following to manage the line color
|
36 |
+
cv2.line(original_image, (patch_centroid_x, patch_centroid_y), (patch_centroid_x_j, patch_centroid_y_j), line_color, 2) # Line
|
37 |
+
|
38 |
+
font = cv2.FONT_HERSHEY_SIMPLEX
|
39 |
+
font_scale = 1
|
40 |
+
font_color = (255, 255, 255)
|
41 |
+
line_type = cv2.LINE_AA
|
42 |
+
cv2.putText(original_image, str(i), (patch_centroid_x - 10, patch_centroid_y + 10), font, font_scale, font_color, 2, line_type)
|
43 |
+
|
44 |
+
return original_image, num_red_patches
|
45 |
|
46 |
st.title('Saliency Detection App')
|
47 |
st.write('Upload an image for saliency detection:')
|
|
|
53 |
|
54 |
if st.button('Detect Saliency'):
|
55 |
img = image.resize((384, 288))
|
56 |
+
img = np.array(img)
|
57 |
+
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) # Convert to BGR color space
|
58 |
img = np.array(img) / 255.
|
59 |
img = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
|
60 |
img = torch.from_numpy(img)
|
|
|
67 |
|
68 |
heatmap = cv2.resize(heatmap, (image.width, image.height))
|
69 |
|
|
|
|
|
70 |
enhanced_image = np.array(image)
|
71 |
b, g, r = cv2.split(enhanced_image)
|
72 |
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
|
73 |
b_enhanced = clahe.apply(b)
|
74 |
enhanced_image = cv2.merge((b_enhanced, g, r))
|
75 |
|
76 |
+
alpha = 0.7
|
77 |
blended_img = cv2.addWeighted(enhanced_image, 1 - alpha, heatmap, alpha, 0)
|
78 |
|
79 |
+
original_image, num_red_patches = count_and_label_red_patches(heatmap)
|
80 |
+
|
81 |
+
st.image(original_image, caption=f'Image with {num_red_patches} Red Patches', use_column_width=True, channels='RGB')
|
82 |
|
83 |
+
st.image(blended_img, caption='Blended Image', use_column_width=True, channels='BGR')
|
84 |
|
85 |
+
# Create a dir with the name example to save
|
86 |
cv2.imwrite('example/result15.png', blended_img, [int(cv2.IMWRITE_JPEG_QUALITY), 200])
|
87 |
st.success('Saliency detection complete. Result saved as "example/result15.png".')
|