Tanzeer committed
Commit 776dd3c · 1 Parent(s): 9ba4ed8

Update app.py

Files changed (1):
  1. app.py +36 -39
app.py CHANGED
@@ -2,21 +2,19 @@ import streamlit as st
 import cv2
 import numpy as np
 import torch
-from torchvision import transforms, utils, models
+from torchvision import transforms, models
 from PIL import Image
 from TranSalNet_Res import TranSalNet
 from tqdm import tqdm
 import torch.nn as nn
 from utils.data_process import preprocess_img, postprocess_img
 
-# Load the model and set the device
 device = torch.device('cpu')
 model = TranSalNet()
 model.load_state_dict(torch.load('pretrained_models/TranSalNet_Res.pth', map_location=torch.device('cpu')))
 model.to(device)
 model.eval()
 
-# Define Streamlit app
 st.title('Saliency Detection App')
 st.write('Upload an image for saliency detection:')
 uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
@@ -25,60 +23,59 @@ if uploaded_image:
     image = Image.open(uploaded_image)
     st.image(image, caption='Uploaded Image', use_column_width=True)
 
-    # Check if the user clicks a button
     if st.button('Detect Saliency'):
-        # Create a blue background image with the same dimensions as the original image
-        blue_background = np.zeros_like(np.array(image))
-        blue_background[:] = (255, 0, 0)  # Set the background to blue (in BGR format)
-
-        # Preprocess the image
         img = image.resize((384, 288))
-        img = np.array(img) / 255.
+        img = np.array(img) / 100.
         img = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
         img = torch.from_numpy(img)
        img = img.type(torch.FloatTensor).to(device)
 
-        # Get saliency prediction
         pred_saliency = model(img)
 
-        # Convert the result back to a PIL image
         toPIL = transforms.ToPILImage()
         pic = toPIL(pred_saliency.squeeze())
 
-        # Colorize the grayscale prediction
-        colorized_img = cv2.applyColorMap(np.uint8(pic), cv2.COLORMAP_JET)
+        colorized_img = cv2.applyColorMap(np.uint8(pic), cv2.COLORMAP_OCEAN)
 
-        # Ensure the colorized image has the same dimensions as the original image
         original_img = np.array(image)
         colorized_img = cv2.resize(colorized_img, (original_img.shape[1], original_img.shape[0]))
 
-        # Create an empty label map for ranking based on area
-        label_map = np.zeros_like(colorized_img)
-
-        intensity_map = cv2.cvtColor(colorized_img, cv2.COLOR_BGR2GRAY)
-
-        _, binary_map = cv2.threshold(intensity_map, 255, 0, cv2.THRESH_BINARY)
-
-        # Overlay the labels on the blended image
+        alpha = 0.7
+        blended_img = cv2.addWeighted(original_img, 1 - alpha, colorized_img, alpha, 0)
+
+        # Find all contours
+        contours, _ = cv2.findContours(np.uint8(pred_saliency.squeeze().detach().numpy() * 255), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+        saliency_8bit = np.uint8(pred_saliency.squeeze().detach().numpy() * 255)
+
+        # Apply dilation
+        kernel = np.ones((5, 5), np.uint8)
+        dilated = cv2.dilate(saliency_8bit, kernel, iterations=1)
+
+        # Find contours on dilated image
+        contours, _ = cv2.findContours(dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+
         font = cv2.FONT_HERSHEY_SIMPLEX
-        contours, _ = cv2.findContours(binary_map, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
-        for i, contour in enumerate(contours):
-            M = cv2.moments(contour)
-            if M["m00"] == 0:
-                continue
-            center_x = int(M["m10"] / M["m00"])
-            center_y = int(M["m01"] / M["m00"])
-            cv2.putText(label_map, str(i + 1), (center_x, center_y), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2, cv2.LINE_AA)
+        label = 1
+        for contour in contours:
+            # Get bounding box for contour
+            x, y, w, h = cv2.boundingRect(contour)
+
+            # Calculate center of bounding box
+            center_x = x + w // 2
+            center_y = y + h // 2
+
+            # Find point on contour closest to center of bounding box
+            distances = np.sqrt((contour[:, 0, 0] - center_x)**2 + (contour[:, 0, 1] - center_y)**2)
+            min_index = np.argmin(distances)
+            closest_point = tuple(contour[min_index][0])
+
+            # Place label at closest point on contour
+            cv2.putText(blended_img, str(label), closest_point, font, 1, (0, 0, 255), 3, cv2.LINE_AA)
+
+            label += 1
 
-        # Blend the colorized image with the blue background
-        alpha = 0.3  # Adjust the alpha value to control blending strength
-        blended_img = cv2.addWeighted(blue_background, 1 - alpha, colorized_img, alpha, 0)
 
-        # Display the final result
         st.image(blended_img, caption='Blended Image with Labels', use_column_width=True)
 
-        # Save the final result
         cv2.imwrite('example/result15.png', blended_img, [int(cv2.IMWRITE_JPEG_QUALITY), 200])
-        st.success('Saliency detection complete. Result saved as "example/result15.png"')
-
-        st.write('Finished, check the result at: example/result15.png')
+        st.success('Saliency detection complete. Result saved as "example/result15.png".')
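The core addition in this commit is the contour-labelling pass: the saliency prediction is scaled to 8-bit, dilated so that nearby responses merge into single regions, and each resulting external contour gets a numeric label drawn at the contour point nearest its bounding-box centre. Below is a minimal standalone sketch of that pass; the synthetic two-blob saliency map and the output filename are assumptions for illustration, not part of app.py.

import cv2
import numpy as np

# Synthetic saliency map standing in for TranSalNet's prediction (assumption).
saliency = np.zeros((288, 384), dtype=np.float32)
cv2.circle(saliency, (100, 100), 40, 1.0, -1)   # blob 1
cv2.circle(saliency, (280, 200), 60, 1.0, -1)   # blob 2

# Scale to 8-bit, as the commit does with pred_saliency.
saliency_8bit = np.uint8(saliency * 255)

# Dilate so nearby responses merge into single regions.
kernel = np.ones((5, 5), np.uint8)
dilated = cv2.dilate(saliency_8bit, kernel, iterations=1)

# External contours of the dilated map become the labelled regions.
contours, _ = cv2.findContours(dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

canvas = cv2.cvtColor(saliency_8bit, cv2.COLOR_GRAY2BGR)
for label, contour in enumerate(contours, start=1):
    x, y, w, h = cv2.boundingRect(contour)
    center = np.array([x + w // 2, y + h // 2])
    # Same placement rule as the commit: the contour point nearest
    # the bounding-box centre receives the numeric label.
    distances = np.linalg.norm(contour[:, 0, :] - center, axis=1)
    closest_point = tuple(int(v) for v in contour[np.argmin(distances)][0])
    cv2.putText(canvas, str(label), closest_point, cv2.FONT_HERSHEY_SIMPLEX,
                1, (0, 0, 255), 3, cv2.LINE_AA)

cv2.imwrite('labelled_sketch.png', canvas)  # hypothetical output path

Two small caveats against the committed code: the first cv2.findContours call on the raw prediction is immediately overwritten by the call on the dilated map, and cv2.IMWRITE_JPEG_QUALITY (valid range 0-100) has no effect when writing a .png; cv2.IMWRITE_PNG_COMPRESSION is the matching flag for PNG output.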