Tanzeer commited on
Commit
3a6ce4c
·
1 Parent(s): c0f6a4b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -12
app.py CHANGED
@@ -5,7 +5,6 @@ import torch
5
  from torchvision import transforms, models
6
  from PIL import Image
7
  from TranSalNet_Res import TranSalNet
8
- from tqdm import tqdm
9
  import torch.nn as nn
10
  from utils.data_process import preprocess_img, postprocess_img
11
 
@@ -18,15 +17,31 @@ model.eval()
18
  def count_and_label_red_patches(heatmap, threshold=200):
19
  red_mask = heatmap[:, :, 2] > threshold
20
  _, labels, stats, _ = cv2.connectedComponentsWithStats(red_mask.astype(np.uint8), connectivity=8)
21
-
22
  num_red_patches = labels.max()
23
 
 
 
24
  for i in range(1, num_red_patches + 1):
25
  patch_mask = (labels == i)
26
  patch_centroid_x, patch_centroid_y = int(stats[i, cv2.CC_STAT_LEFT] + stats[i, cv2.CC_STAT_WIDTH] / 2), int(stats[i, cv2.CC_STAT_TOP] + stats[i, cv2.CC_STAT_HEIGHT] / 2)
27
- cv2.putText(heatmap, str(i), (patch_centroid_x, patch_centroid_y), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2, cv2.LINE_AA)
28
-
29
- return heatmap, num_red_patches
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  st.title('Saliency Detection App')
32
  st.write('Upload an image for saliency detection:')
@@ -38,6 +53,8 @@ if uploaded_image:
38
 
39
  if st.button('Detect Saliency'):
40
  img = image.resize((384, 288))
 
 
41
  img = np.array(img) / 255.
42
  img = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
43
  img = torch.from_numpy(img)
@@ -50,22 +67,21 @@ if uploaded_image:
50
 
51
  heatmap = cv2.resize(heatmap, (image.width, image.height))
52
 
53
- heatmap, num_red_patches = count_and_label_red_patches(heatmap)
54
-
55
  enhanced_image = np.array(image)
56
  b, g, r = cv2.split(enhanced_image)
57
  clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
58
  b_enhanced = clahe.apply(b)
59
  enhanced_image = cv2.merge((b_enhanced, g, r))
60
 
61
- alpha = 0.7
62
  blended_img = cv2.addWeighted(enhanced_image, 1 - alpha, heatmap, alpha, 0)
63
 
64
- st.image(heatmap, caption='Enhanced Saliency Heatmap', use_column_width=True, channels='BGR')
65
- st.image(enhanced_image, caption='Enhanced Blue Image', use_column_width=True, channels='BGR')
 
66
 
67
- st.image(blended_img, caption=f'Blended Image with {num_red_patches} Red Patches', use_column_width=True, channels='BGR')
68
 
69
- # Create a dir with name example to save
70
  cv2.imwrite('example/result15.png', blended_img, [int(cv2.IMWRITE_JPEG_QUALITY), 200])
71
  st.success('Saliency detection complete. Result saved as "example/result15.png".')
 
5
  from torchvision import transforms, models
6
  from PIL import Image
7
  from TranSalNet_Res import TranSalNet
 
8
  import torch.nn as nn
9
  from utils.data_process import preprocess_img, postprocess_img
10
 
 
17
  def count_and_label_red_patches(heatmap, threshold=200):
18
  red_mask = heatmap[:, :, 2] > threshold
19
  _, labels, stats, _ = cv2.connectedComponentsWithStats(red_mask.astype(np.uint8), connectivity=8)
 
20
  num_red_patches = labels.max()
21
 
22
+ original_image = np.array(image)
23
+
24
  for i in range(1, num_red_patches + 1):
25
  patch_mask = (labels == i)
26
  patch_centroid_x, patch_centroid_y = int(stats[i, cv2.CC_STAT_LEFT] + stats[i, cv2.CC_STAT_WIDTH] / 2), int(stats[i, cv2.CC_STAT_TOP] + stats[i, cv2.CC_STAT_HEIGHT] / 2)
27
+ radius = 20 # Adjust the following variable to manage the circle image
28
+ circle_color = (0, 0, 0) # The circle is black adjust the following to change the color
29
+ cv2.circle(original_image, (patch_centroid_x, patch_centroid_y), radius, circle_color, -1) # Draw the circle
30
+
31
+ # Lines code
32
+ for j in range(i + 1, num_red_patches + 1):
33
+ patch_mask_j = (labels == j)
34
+ patch_centroid_x_j, patch_centroid_y_j = int(stats[j, cv2.CC_STAT_LEFT] + stats[j, cv2.CC_STAT_WIDTH] / 2), int(stats[j, cv2.CC_STAT_TOP] + stats[j, cv2.CC_STAT_HEIGHT] / 2)
35
+ line_color = (0, 0, 0) # Ajdust the following to manage the line color
36
+ cv2.line(original_image, (patch_centroid_x, patch_centroid_y), (patch_centroid_x_j, patch_centroid_y_j), line_color, 2) # Line
37
+
38
+ font = cv2.FONT_HERSHEY_SIMPLEX
39
+ font_scale = 1
40
+ font_color = (255, 255, 255)
41
+ line_type = cv2.LINE_AA
42
+ cv2.putText(original_image, str(i), (patch_centroid_x - 10, patch_centroid_y + 10), font, font_scale, font_color, 2, line_type)
43
+
44
+ return original_image, num_red_patches
45
 
46
  st.title('Saliency Detection App')
47
  st.write('Upload an image for saliency detection:')
 
53
 
54
  if st.button('Detect Saliency'):
55
  img = image.resize((384, 288))
56
+ img = np.array(img)
57
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) # Convert to BGR color space
58
  img = np.array(img) / 255.
59
  img = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
60
  img = torch.from_numpy(img)
 
67
 
68
  heatmap = cv2.resize(heatmap, (image.width, image.height))
69
 
 
 
70
  enhanced_image = np.array(image)
71
  b, g, r = cv2.split(enhanced_image)
72
  clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
73
  b_enhanced = clahe.apply(b)
74
  enhanced_image = cv2.merge((b_enhanced, g, r))
75
 
76
+ alpha = 0.7
77
  blended_img = cv2.addWeighted(enhanced_image, 1 - alpha, heatmap, alpha, 0)
78
 
79
+ original_image, num_red_patches = count_and_label_red_patches(heatmap)
80
+
81
+ st.image(original_image, caption=f'Image with {num_red_patches} Red Patches', use_column_width=True, channels='RGB')
82
 
83
+ st.image(blended_img, caption='Blended Image', use_column_width=True, channels='BGR')
84
 
85
+ # Create a dir with the name example to save
86
  cv2.imwrite('example/result15.png', blended_img, [int(cv2.IMWRITE_JPEG_QUALITY), 200])
87
  st.success('Saliency detection complete. Result saved as "example/result15.png".')