Tanzeer commited on
Commit
660142e
·
1 Parent(s): 776dd3c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -40
app.py CHANGED
@@ -15,6 +15,19 @@ model.load_state_dict(torch.load('pretrained_models/TranSalNet_Res.pth', map_loc
15
  model.to(device)
16
  model.eval()
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  st.title('Saliency Detection App')
19
  st.write('Upload an image for saliency detection:')
20
  uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
@@ -25,57 +38,33 @@ if uploaded_image:
25
 
26
  if st.button('Detect Saliency'):
27
  img = image.resize((384, 288))
28
- img = np.array(img) / 100.
29
  img = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
30
  img = torch.from_numpy(img)
31
  img = img.type(torch.FloatTensor).to(device)
32
 
33
- pred_saliency = model(img)
34
 
35
- toPIL = transforms.ToPILImage()
36
- pic = toPIL(pred_saliency.squeeze())
37
 
38
- colorized_img = cv2.applyColorMap(np.uint8(pic), cv2.COLORMAP_OCEAN)
39
 
40
- original_img = np.array(image)
41
- colorized_img = cv2.resize(colorized_img, (original_img.shape[1], original_img.shape[0]))
42
 
43
- alpha = 0.7
44
- blended_img = cv2.addWeighted(original_img, 1 - alpha, colorized_img, alpha, 0)
 
 
 
45
 
46
- # Find all contours
47
- contours, _ = cv2.findContours(np.uint8(pred_saliency.squeeze().detach().numpy() * 255), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
48
- saliency_8bit = np.uint8(pred_saliency.squeeze().detach().numpy() * 255)
49
-
50
- # Apply dilation
51
- kernel = np.ones((5,5),np.uint8)
52
- dilated = cv2.dilate(saliency_8bit, kernel, iterations = 1)
53
-
54
- # Find contours on dilated image
55
- contours, _ = cv2.findContours(dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
56
-
57
- font = cv2.FONT_HERSHEY_SIMPLEX
58
- label = 1
59
- for contour in contours:
60
- # Get bounding box for contour
61
- x, y, w, h = cv2.boundingRect(contour)
62
-
63
- # Calculate center of bounding box
64
- center_x = x + w // 2
65
- center_y = y + h // 2
66
-
67
- # Find point on contour closest to center of bounding box
68
- distances = np.sqrt((contour[:,0,0] - center_x)**2 + (contour[:,0,1] - center_y)**2)
69
- min_index = np.argmin(distances)
70
- closest_point = tuple(contour[min_index][0])
71
-
72
- # Place label at closest point on contour
73
- cv2.putText(blended_img, str(label), closest_point, font, 1, (0, 0, 255), 3, cv2.LINE_AA)
74
-
75
- label += 1
76
 
 
 
77
 
78
- st.image(blended_img, caption='Blended Image with Labels', use_column_width=True)
79
 
80
  cv2.imwrite('example/result15.png', blended_img, [int(cv2.IMWRITE_JPEG_QUALITY), 200])
81
  st.success('Saliency detection complete. Result saved as "example/result15.png".')
 
15
  model.to(device)
16
  model.eval()
17
 
18
+ def count_and_label_red_patches(heatmap, threshold=200):
19
+ red_mask = heatmap[:, :, 2] > threshold
20
+ _, labels, stats, _ = cv2.connectedComponentsWithStats(red_mask.astype(np.uint8), connectivity=8)
21
+
22
+ num_red_patches = labels.max()
23
+
24
+ for i in range(1, num_red_patches + 1):
25
+ patch_mask = (labels == i)
26
+ patch_centroid_x, patch_centroid_y = int(stats[i, cv2.CC_STAT_LEFT] + stats[i, cv2.CC_STAT_WIDTH] / 2), int(stats[i, cv2.CC_STAT_TOP] + stats[i, cv2.CC_STAT_HEIGHT] / 2)
27
+ cv2.putText(heatmap, str(i), (patch_centroid_x, patch_centroid_y), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2, cv2.LINE_AA)
28
+
29
+ return heatmap, num_red_patches
30
+
31
  st.title('Saliency Detection App')
32
  st.write('Upload an image for saliency detection:')
33
  uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
 
38
 
39
  if st.button('Detect Saliency'):
40
  img = image.resize((384, 288))
41
+ img = np.array(img) / 255.
42
  img = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
43
  img = torch.from_numpy(img)
44
  img = img.type(torch.FloatTensor).to(device)
45
 
46
+ pred_saliency = model(img).squeeze().detach().numpy()
47
 
48
+ heatmap = (pred_saliency * 255).astype(np.uint8)
49
+ heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
50
 
51
+ heatmap = cv2.resize(heatmap, (image.width, image.height))
52
 
53
+ heatmap, num_red_patches = count_and_label_red_patches(heatmap)
 
54
 
55
+ enhanced_image = np.array(image)
56
+ b, g, r = cv2.split(enhanced_image)
57
+ clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
58
+ b_enhanced = clahe.apply(b)
59
+ enhanced_image = cv2.merge((b_enhanced, g, r))
60
 
61
+ alpha = 0.7
62
+ blended_img = cv2.addWeighted(enhanced_image, 1 - alpha, heatmap, alpha, 0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
+ st.image(heatmap, caption='Enhanced Saliency Heatmap', use_column_width=True, channels='BGR')
65
+ st.image(enhanced_image, caption='Enhanced Blue Image', use_column_width=True, channels='BGR')
66
 
67
+ st.image(blended_img, caption=f'Blended Image with {num_red_patches} Red Patches', use_column_width=True, channels='BGR')
68
 
69
  cv2.imwrite('example/result15.png', blended_img, [int(cv2.IMWRITE_JPEG_QUALITY), 200])
70
  st.success('Saliency detection complete. Result saved as "example/result15.png".')