Update app.py
app.py CHANGED
@@ -2,21 +2,19 @@ import streamlit as st
 import cv2
 import numpy as np
 import torch
-from torchvision import transforms,
+from torchvision import transforms, models
 from PIL import Image
 from TranSalNet_Res import TranSalNet
 from tqdm import tqdm
 import torch.nn as nn
 from utils.data_process import preprocess_img, postprocess_img

-# Load the model and set the device
 device = torch.device('cpu')
 model = TranSalNet()
 model.load_state_dict(torch.load('pretrained_models/TranSalNet_Res.pth', map_location=torch.device('cpu')))
 model.to(device)
 model.eval()

-# Define Streamlit app
 st.title('Saliency Detection App')
 st.write('Upload an image for saliency detection:')
 uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
@@ -25,60 +23,59 @@ if uploaded_image:
     image = Image.open(uploaded_image)
     st.image(image, caption='Uploaded Image', use_column_width=True)

-    # Check if the user clicks a button
     if st.button('Detect Saliency'):
-        # Create a blue background image with the same dimensions as the original image
-        blue_background = np.zeros_like(np.array(image))
-        blue_background[:] = (255, 0, 0)  # Set the background to blue (in BGR format)
-
-        # Preprocess the image
         img = image.resize((384, 288))
-        img = np.array(img) /
+        img = np.array(img) / 100.
         img = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
         img = torch.from_numpy(img)
         img = img.type(torch.FloatTensor).to(device)

-        # Get saliency prediction
         pred_saliency = model(img)

-        # Convert the result back to a PIL image
         toPIL = transforms.ToPILImage()
         pic = toPIL(pred_saliency.squeeze())

-
-        colorized_img = cv2.applyColorMap(np.uint8(pic), cv2.COLORMAP_JET)
+        colorized_img = cv2.applyColorMap(np.uint8(pic), cv2.COLORMAP_OCEAN)

-        # Ensure the colorized image has the same dimensions as the original image
         original_img = np.array(image)
         colorized_img = cv2.resize(colorized_img, (original_img.shape[1], original_img.shape[0]))

-
-
-
-
-
-
-
-        #
+        alpha = 0.7
+        blended_img = cv2.addWeighted(original_img, 1 - alpha, colorized_img, alpha, 0)
+
+        # Find all contours
+        contours, _ = cv2.findContours(np.uint8(pred_saliency.squeeze().detach().numpy() * 255), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+        saliency_8bit = np.uint8(pred_saliency.squeeze().detach().numpy() * 255)
+
+        # Apply dilation
+        kernel = np.ones((5,5),np.uint8)
+        dilated = cv2.dilate(saliency_8bit, kernel, iterations = 1)
+
+        # Find contours on dilated image
+        contours, _ = cv2.findContours(dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+
         font = cv2.FONT_HERSHEY_SIMPLEX
-
-        for
-
-
-
-
-
-
+        label = 1
+        for contour in contours:
+            # Get bounding box for contour
+            x, y, w, h = cv2.boundingRect(contour)
+
+            # Calculate center of bounding box
+            center_x = x + w // 2
+            center_y = y + h // 2
+
+            # Find point on contour closest to center of bounding box
+            distances = np.sqrt((contour[:,0,0] - center_x)**2 + (contour[:,0,1] - center_y)**2)
+            min_index = np.argmin(distances)
+            closest_point = tuple(contour[min_index][0])
+
+            # Place label at closest point on contour
+            cv2.putText(blended_img, str(label), closest_point, font, 1, (0, 0, 255), 3, cv2.LINE_AA)
+
+            label += 1

-        # Blend the colorized image with the blue background
-        alpha = 0.3  # Adjust the alpha value to control blending strength
-        blended_img = cv2.addWeighted(blue_background, 1 - alpha, colorized_img, alpha, 0)

-        # Display the final result
         st.image(blended_img, caption='Blended Image with Labels', use_column_width=True)

-        # Save the final result
         cv2.imwrite('example/result15.png', blended_img, [int(cv2.IMWRITE_JPEG_QUALITY), 200])
-        st.success('Saliency detection complete. Result saved as "example/result15.png"')
-
-        st.write('Finished, check the result at: example/result15.png')
+        st.success('Saliency detection complete. Result saved as "example/result15.png".')
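For reference, the preprocessing path in this commit can be checked in isolation. A minimal sketch with a stand-in image (the image size and the print statements are illustrative only, not part of the app); PIL's resize takes a (width, height) tuple, so (384, 288) yields a 288 x 384 x 3 array before the transpose:

import numpy as np
import torch
from PIL import Image

image = Image.new('RGB', (640, 480))        # stand-in for the uploaded image
img = image.resize((384, 288))              # PIL resize takes (width, height)
img = np.array(img) / 100.                  # same scaling as the committed code
print(img.shape)                            # (288, 384, 3)  ->  H x W x C
img = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
img = torch.from_numpy(img).type(torch.FloatTensor)
print(img.shape, img.dtype)                 # torch.Size([1, 3, 288, 384]) torch.float32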
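The colormap-and-blend step works the same way outside the app. A rough, self-contained sketch with a synthetic saliency blob (the sizes, the blob, and the variable names are made up for illustration); cv2.addWeighted computes (1 - alpha) * original + alpha * heatmap per pixel, so alpha = 0.7 weights the heatmap more heavily than the photo:

import cv2
import numpy as np

original_img = np.full((480, 640, 3), 128, dtype=np.uint8)    # stand-in photo
saliency_8bit = np.zeros((288, 384), dtype=np.uint8)
cv2.circle(saliency_8bit, (192, 144), 60, 255, -1)            # fake salient blob

heat = cv2.applyColorMap(saliency_8bit, cv2.COLORMAP_OCEAN)   # 3-channel BGR heatmap
heat = cv2.resize(heat, (original_img.shape[1], original_img.shape[0]))

alpha = 0.7
blended = cv2.addWeighted(original_img, 1 - alpha, heat, alpha, 0)
print(blended.shape, blended.dtype)                           # (480, 640, 3) uint8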
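The labeling loop added in this commit boils down to: dilate the 8-bit saliency map, take its external contours, and write an index at the contour vertex nearest each bounding-box center. A standalone sketch on a synthetic mask (assumes OpenCV 4.x, where findContours returns two values; the mask and the output filename are invented for the example):

import cv2
import numpy as np

mask = np.zeros((240, 320), dtype=np.uint8)
cv2.rectangle(mask, (40, 40), (120, 120), 255, -1)     # fake salient region 1
cv2.circle(mask, (230, 160), 40, 255, -1)              # fake salient region 2

kernel = np.ones((5, 5), np.uint8)
dilated = cv2.dilate(mask, kernel, iterations=1)
contours, _ = cv2.findContours(dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

canvas = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
font = cv2.FONT_HERSHEY_SIMPLEX
for label, contour in enumerate(contours, start=1):
    x, y, w, h = cv2.boundingRect(contour)
    center_x, center_y = x + w // 2, y + h // 2
    # contour has shape (N, 1, 2); pick the vertex closest to the box center
    distances = np.sqrt((contour[:, 0, 0] - center_x) ** 2 + (contour[:, 0, 1] - center_y) ** 2)
    px, py = contour[np.argmin(distances)][0]
    cv2.putText(canvas, str(label), (int(px), int(py)), font, 1, (0, 0, 255), 2, cv2.LINE_AA)

cv2.imwrite('labeled_regions_demo.png', canvas)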