Update app.py
Browse files
app.py
CHANGED
@@ -9,8 +9,6 @@ from tqdm import tqdm
|
|
9 |
import torch.nn as nn
|
10 |
from utils.data_process import preprocess_img, postprocess_img
|
11 |
|
12 |
-
|
13 |
-
|
14 |
# Load the model and set the device
|
15 |
device = torch.device('cpu')
|
16 |
model = TranSalNet()
|
@@ -29,6 +27,10 @@ if uploaded_image:
|
|
29 |
|
30 |
# Check if the user clicks a button
|
31 |
if st.button('Detect Saliency'):
|
|
|
|
|
|
|
|
|
32 |
# Preprocess the image
|
33 |
img = image.resize((384, 288))
|
34 |
img = np.array(img) / 100.
|
@@ -48,54 +50,31 @@ if uploaded_image:
|
|
48 |
|
49 |
# Ensure the colorized image has the same dimensions as the original image
|
50 |
original_img = np.array(image)
|
51 |
-
colorized_img = cv2.resize(colorized_img, (original_img.shape[1], original_img.shape[0
|
52 |
-
|
53 |
-
# Compute intensity values from the colorized image
|
54 |
-
intensity_map = cv2.cvtColor(colorized_img, cv2.COLOR_BGR2GRAY)
|
55 |
-
|
56 |
-
# Threshold the intensity map to create a binary mask
|
57 |
-
_, binary_map = cv2.threshold(intensity_map, 255, 0, cv2.THRESH_BINARY)
|
58 |
-
|
59 |
-
# Find contours in the binary map
|
60 |
-
contours, _ = cv2.findContours(binary_map, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
61 |
-
|
62 |
-
# Sort the contours by area in descending order
|
63 |
-
contours = sorted(contours, key=cv2.contourArea, reverse=True)
|
64 |
|
65 |
# Create an empty label map for ranking based on area
|
66 |
-
label_map = np.zeros_like(
|
67 |
|
68 |
-
#
|
|
|
|
|
69 |
for i, contour in enumerate(contours):
|
70 |
M = cv2.moments(contour)
|
71 |
if M["m00"] == 0:
|
72 |
continue
|
73 |
center_x = int(M["m10"] / M["m00"])
|
74 |
center_y = int(M["m01"] / M["m00"])
|
75 |
-
cv2.putText(label_map, str(i + 1), (center_x, center_y), cv2.FONT_HERSHEY_SIMPLEX, 1, 255, 2, cv2.LINE_AA)
|
76 |
|
77 |
# Blend the colorized image with the original image
|
78 |
alpha = 0.7 # Adjust the alpha value to control blending strength
|
79 |
-
blended_img = cv2.addWeighted(
|
80 |
-
|
81 |
-
# Overlay the labels on the blended image
|
82 |
-
font = cv2.FONT_HERSHEY_SIMPLEX
|
83 |
-
for i in range(1, len(contours) + 1):
|
84 |
-
mask = (label_map == i).astype(np.uint8)
|
85 |
-
x, y, w, h = cv2.boundingRect(contours[i-1])
|
86 |
-
org = (x, y)
|
87 |
-
color = (255, 0, 0) # Blue color
|
88 |
-
thickness = 3
|
89 |
-
cv2.putText(blended_img, str(i), org, font, 1, color, thickness, cv2.LINE_AA)
|
90 |
-
|
91 |
-
|
92 |
|
93 |
# Display the final result
|
94 |
-
st.image(colorized_img, caption='Colorized Saliency Map', use_column_width=True)
|
95 |
st.image(blended_img, caption='Blended Image with Labels', use_column_width=True)
|
96 |
|
97 |
# Save the final result
|
98 |
cv2.imwrite('example/result15.png', blended_img, [int(cv2.IMWRITE_JPEG_QUALITY), 200])
|
99 |
-
st.success('Saliency detection complete. Result saved as "example/result15.png"
|
100 |
|
101 |
st.write('Finished, check the result at: example/result15.png')
|
|
|
9 |
import torch.nn as nn
|
10 |
from utils.data_process import preprocess_img, postprocess_img
|
11 |
|
|
|
|
|
12 |
# Load the model and set the device
|
13 |
device = torch.device('cpu')
|
14 |
model = TranSalNet()
|
|
|
27 |
|
28 |
# Check if the user clicks a button
|
29 |
if st.button('Detect Saliency'):
|
30 |
+
# Create a blue background image
|
31 |
+
blue_background = np.zeros_like(np.array(image))
|
32 |
+
blue_background[:] = (0, 0, 255) # Set the background to blue (in BGR format)
|
33 |
+
|
34 |
# Preprocess the image
|
35 |
img = image.resize((384, 288))
|
36 |
img = np.array(img) / 100.
|
|
|
50 |
|
51 |
# Ensure the colorized image has the same dimensions as the original image
|
52 |
original_img = np.array(image)
|
53 |
+
colorized_img = cv2.resize(colorized_img, (original_img.shape[1], original_img.shape[0))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
|
55 |
# Create an empty label map for ranking based on area
|
56 |
+
label_map = np.zeros_like(colorized_img)
|
57 |
|
58 |
+
# Overlay the labels on the blended image
|
59 |
+
font = cv2.FONT_HERSHEY_SIMPLEX
|
60 |
+
contours, _ = cv2.findContours(binary_map, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
61 |
for i, contour in enumerate(contours):
|
62 |
M = cv2.moments(contour)
|
63 |
if M["m00"] == 0:
|
64 |
continue
|
65 |
center_x = int(M["m10"] / M["m00"])
|
66 |
center_y = int(M["m01"] / M["m00"])
|
67 |
+
cv2.putText(label_map, str(i + 1), (center_x, center_y), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2, cv2.LINE_AA)
|
68 |
|
69 |
# Blend the colorized image with the original image
|
70 |
alpha = 0.7 # Adjust the alpha value to control blending strength
|
71 |
+
blended_img = cv2.addWeighted(blue_background, 1 - alpha, colorized_img, alpha, 0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
|
73 |
# Display the final result
|
|
|
74 |
st.image(blended_img, caption='Blended Image with Labels', use_column_width=True)
|
75 |
|
76 |
# Save the final result
|
77 |
cv2.imwrite('example/result15.png', blended_img, [int(cv2.IMWRITE_JPEG_QUALITY), 200])
|
78 |
+
st.success('Saliency detection complete. Result saved as "example/result15.png"')
|
79 |
|
80 |
st.write('Finished, check the result at: example/result15.png')
|