Tanzeer committed
Commit 8271835 · 1 Parent(s): c6e2164

Update app.py

Files changed (1):
  1. app.py  +51 -12
app.py CHANGED
@@ -1,15 +1,14 @@
 import streamlit as st
-import torch
 import cv2
-from PIL import Image
 import numpy as np
+import torch
 from torchvision import transforms
-from TranSalNet_Res import TranSalNet  # Make sure TranSalNet is accessible from your Streamlit app
+from PIL import Image
 
 # Load the model and set the device
-model = TranSalNet()
+model = TranSalNet()  # Assuming you have defined your model
 model.load_state_dict(torch.load('pretrained_models/TranSalNet_Res.pth', map_location=torch.device('cpu')))
-model.eval()  # Set the model to evaluation mode
+model.eval()
 device = torch.device('cpu')
 model.to(device)
 
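
Note on this hunk: the new import list drops "from TranSalNet_Res import TranSalNet" while the code still calls TranSalNet(), so the app would raise a NameError at startup unless the class is defined elsewhere in app.py. A minimal sketch of an import block that keeps model loading working, assuming TranSalNet_Res.py still sits next to app.py as in the old revision:

    import streamlit as st
    import cv2
    import numpy as np
    import torch
    from torchvision import transforms
    from PIL import Image
    from TranSalNet_Res import TranSalNet  # restored; app.py itself does not define the class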
@@ -27,13 +26,12 @@ if uploaded_image:
     # Preprocess the image
     img = image.resize((384, 288))
     img = np.array(img) / 255.
-    img = np.transpose(img, (2, 0, 1))
-    img = torch.from_numpy(img).unsqueeze(0).float()
-    img = img.to(device)
+    img = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
+    img = torch.from_numpy(img)
+    img = img.type(torch.FloatTensor).to(device)
 
     # Get saliency prediction
-    with torch.no_grad():
-        pred_saliency = model(img)
+    pred_saliency = model(img)
 
     # Convert the result back to a PIL image
     toPIL = transforms.ToPILImage()
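
Note on this hunk: the rewritten preprocessing is equivalent to the old one (np.expand_dims(..., axis=0) matches .unsqueeze(0), and .type(torch.FloatTensor) matches .float()), but the change also drops the torch.no_grad() guard, so the forward pass now builds an autograd graph it never uses. A minimal sketch of the new pipeline with gradient tracking disabled, assuming model, image, and device as above:

    # Same preprocessing as the new revision: HWC -> NCHW with a batch axis
    img = image.resize((384, 288))
    img = np.array(img) / 255.
    img = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
    img = torch.from_numpy(img).type(torch.FloatTensor).to(device)

    # Inference only: no_grad skips gradient tracking and saves memory
    with torch.no_grad():
        pred_saliency = model(img)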
@@ -46,9 +44,50 @@ if uploaded_image:
     original_img = np.array(image)
     colorized_img = cv2.resize(colorized_img, (original_img.shape[1], original_img.shape[0]))
 
-    # You can add more post-processing here if needed
+    # Compute intensity values from the colorized image
+    intensity_map = cv2.cvtColor(colorized_img, cv2.COLOR_BGR2GRAY)
+
+    # Threshold the intensity map to create a binary mask
+    _, binary_map = cv2.threshold(intensity_map, 0, 255, cv2.THRESH_BINARY)
+
+    # Find contours in the binary map
+    contours, _ = cv2.findContours(binary_map, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+
+    # Sort the contours by area in descending order
+    contours = sorted(contours, key=cv2.contourArea, reverse=True)
+
+    # Create an empty label map for ranking based on area
+    label_map = np.zeros_like(intensity_map)
+
+    # Rank and label each region based on area
+    for i, contour in enumerate(contours):
+        M = cv2.moments(contour)
+        if M["m00"] == 0:
+            continue
+        center_x = int(M["m10"] / M["m00"])
+        center_y = int(M["m01"] / M["m00"])
+        cv2.putText(label_map, str(i + 1), (center_x, center_y), cv2.FONT_HERSHEY_SIMPLEX, 1, 255, 2, cv2.LINE_AA)
+
+    # Blend the colorized image with the original image
+    alpha = 0.7  # Adjust the alpha value to control blending strength
+    blended_img = cv2.addWeighted(original_img, 1 - alpha, colorized_img, alpha, 0)
+
+    # Overlay the labels on the blended image
+    font = cv2.FONT_HERSHEY_SIMPLEX
+    for i in range(1, len(contours) + 1):
+        mask = (label_map == i).astype(np.uint8)
+        x, y, w, h = cv2.boundingRect(contours[i-1])
+        org = (x, y)
+        color = (0, 0, 255)  # Red color
+        thickness = 2
+        cv2.putText(blended_img, str(i), org, font, 1, color, thickness, cv2.LINE_AA)
 
     # Display the final result
     st.image(colorized_img, caption='Colorized Saliency Map', use_column_width=True)
+    st.image(blended_img, caption='Blended Image with Labels', use_column_width=True)
+
+    # Save the final result
+    cv2.imwrite('example/result15.png', blended_img, [int(cv2.IMWRITE_JPEG_QUALITY), 200])
+    st.success('Saliency detection complete. Result saved as "example/result15.png".')
 
-    st.write('Finished!')
+    st.write('Finished, check the result at: example/result15.png')
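
Three details in the new post-processing block are worth flagging:

  1. cv2.threshold(intensity_map, 0, 255, cv2.THRESH_BINARY) marks every pixel brighter than 0 as foreground, so the "binary mask" covers nearly the whole colorized map. An Otsu threshold (cv2.THRESH_BINARY + cv2.THRESH_OTSU) derives the cutoff from the histogram instead.
  2. In the overlay loop, mask = (label_map == i) compares pixels to the loop index even though putText drew the digits with value 255; the mask is never read, and the text position actually comes from cv2.boundingRect(contours[i-1]). The two putText passes (one into label_map, one into blended_img) can collapse into a single loop over the sorted contours.
  3. cv2.imwrite(...) is handed a JPEG quality flag (valid range 0-100) for a .png path, where it is ignored; PNG takes cv2.IMWRITE_PNG_COMPRESSION with levels 0-9.

A sketch of the same block with those points addressed, assuming original_img comes from PIL (RGB) and colorized_img from an OpenCV colormap (BGR), as the cv2.COLOR_BGR2GRAY call suggests:

    # Grayscale intensity of the colorized saliency map
    intensity_map = cv2.cvtColor(colorized_img, cv2.COLOR_BGR2GRAY)

    # Otsu picks the threshold from the data instead of the fixed 0 above
    _, binary_map = cv2.threshold(intensity_map, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

    contours, _ = cv2.findContours(binary_map, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    contours = sorted(contours, key=cv2.contourArea, reverse=True)

    # PIL arrays are RGB, OpenCV colormaps are BGR: align channels before blending
    original_bgr = cv2.cvtColor(original_img, cv2.COLOR_RGB2BGR)
    alpha = 0.7
    blended_img = cv2.addWeighted(original_bgr, 1 - alpha, colorized_img, alpha, 0)

    # One pass: rank = position after sorting by area, anchored at each bounding box
    for rank, contour in enumerate(contours, start=1):
        x, y, w, h = cv2.boundingRect(contour)
        cv2.putText(blended_img, str(rank), (x, y), cv2.FONT_HERSHEY_SIMPLEX,
                    1, (0, 0, 255), 2, cv2.LINE_AA)

    # st.image expects RGB, so convert back before display
    st.image(cv2.cvtColor(blended_img, cv2.COLOR_BGR2RGB),
             caption='Blended Image with Labels', use_column_width=True)

    # PNG compression level runs 0-9 (higher = smaller file, slower write)
    cv2.imwrite('example/result15.png', blended_img, [cv2.IMWRITE_PNG_COMPRESSION, 3])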