udayzee05 commited on
Commit
3b4610d
·
1 Parent(s): 434eb56

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -31
app.py CHANGED
@@ -1,12 +1,14 @@
1
  import cv2
 
2
  import numpy as np
3
  import torch
4
- from fastapi import FastAPI, UploadFile, File
5
  from PIL import Image
 
 
 
6
  from TranSalNet_Res import TranSalNet
7
  from utils.data_process import preprocess_img, postprocess_img
8
 
9
-
10
  app = FastAPI()
11
 
12
  device = torch.device('cpu')
@@ -15,51 +17,50 @@ model.load_state_dict(torch.load('pretrained_models/TranSalNet_Res.pth', map_loc
15
  model.to(device)
16
  model.eval()
17
 
 
18
  def count_and_label_red_patches(heatmap, threshold=200):
19
  red_mask = heatmap[:, :, 2] > threshold
20
  contours, _ = cv2.findContours(red_mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
21
-
22
- # Sort the contours based on their areas in descending order
23
  contours = sorted(contours, key=cv2.contourArea, reverse=True)
24
-
25
  original_image = np.array(image)
26
-
27
- centroid_list = [] # List to store the centroids of the contours in order
28
-
29
  for i, contour in enumerate(contours, start=1):
30
- # Compute the centroid of the current contour
31
  M = cv2.moments(contour)
32
  if M["m00"] != 0:
33
  cX = int(M["m10"] / M["m00"])
34
  cY = int(M["m01"] / M["m00"])
35
  else:
36
  cX, cY = 0, 0
37
-
38
- radius = 20 # Adjust the circle radius to fit the numbers
39
- circle_color = (0, 0, 0) # Blue color
40
- cv2.circle(original_image, (cX, cY), radius, circle_color, -1) # Draw blue circle
41
 
42
  font = cv2.FONT_HERSHEY_SIMPLEX
43
  font_scale = 1
44
  font_color = (255, 255, 255)
45
  line_type = cv2.LINE_AA
46
  cv2.putText(original_image, str(i), (cX - 10, cY + 10), font, font_scale, font_color, 2, line_type)
47
-
48
- centroid_list.append((cX, cY)) # Add the centroid to the list
49
 
50
- # Connect the red spots in the desired order
 
51
  for i in range(len(centroid_list) - 1):
52
  start_point = centroid_list[i]
53
  end_point = centroid_list[i + 1]
54
- line_color = (0, 0, 0) # Red color
55
  cv2.line(original_image, start_point, end_point, line_color, 2)
56
 
57
  return original_image, len(contours)
58
 
 
59
  def process_image(image: Image.Image) -> np.ndarray:
60
  img = image.resize((384, 288))
61
  img = np.array(img)
62
- img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) # Convert to BGR color space
63
  img = np.array(img) / 255.
64
  img = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
65
  img = torch.from_numpy(img)
@@ -68,7 +69,7 @@ def process_image(image: Image.Image) -> np.ndarray:
68
  pred_saliency = model(img).squeeze().detach().numpy()
69
 
70
  heatmap = (pred_saliency * 255).astype(np.uint8)
71
- heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET) # Use a blue colormap (JET)
72
 
73
  heatmap = cv2.resize(heatmap, (image.width, image.height))
74
 
@@ -88,17 +89,17 @@ def process_image(image: Image.Image) -> np.ndarray:
88
 
89
  return blended_img
90
 
91
- @app.post("/process_image")
92
- async def process_uploaded_image(file: UploadFile = File(...)):
93
- try:
94
- contents = await file.read()
95
- image = Image.open(io.BytesIO(contents))
96
- except Exception as e:
97
- raise HTTPException(status_code=400, detail=f"Error opening image: {str(e)}")
98
 
99
- try:
100
- processed_image = process_image(image)
101
- return StreamingResponse(io.BytesIO(cv2.imencode('.png', processed_image)[1].tobytes()), media_type="image/png")
 
 
 
 
 
 
 
 
102
 
103
- except Exception as e:
104
- raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
 
1
  import cv2
2
+ import io
3
  import numpy as np
4
  import torch
 
5
  from PIL import Image
6
+ from fastapi import FastAPI, UploadFile, File, HTTPException
7
+ from gradio import Gradio, Image as GImage
8
+ from starlette.responses import StreamingResponse
9
  from TranSalNet_Res import TranSalNet
10
  from utils.data_process import preprocess_img, postprocess_img
11
 
 
12
  app = FastAPI()
13
 
14
  device = torch.device('cpu')
 
17
  model.to(device)
18
  model.eval()
19
 
20
+
21
  def count_and_label_red_patches(heatmap, threshold=200):
22
  red_mask = heatmap[:, :, 2] > threshold
23
  contours, _ = cv2.findContours(red_mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
24
+
 
25
  contours = sorted(contours, key=cv2.contourArea, reverse=True)
26
+
27
  original_image = np.array(image)
28
+
29
+ centroid_list = []
30
+
31
  for i, contour in enumerate(contours, start=1):
 
32
  M = cv2.moments(contour)
33
  if M["m00"] != 0:
34
  cX = int(M["m10"] / M["m00"])
35
  cY = int(M["m01"] / M["m00"])
36
  else:
37
  cX, cY = 0, 0
38
+
39
+ radius = 20
40
+ circle_color = (0, 0, 0)
41
+ cv2.circle(original_image, (cX, cY), radius, circle_color, -1)
42
 
43
  font = cv2.FONT_HERSHEY_SIMPLEX
44
  font_scale = 1
45
  font_color = (255, 255, 255)
46
  line_type = cv2.LINE_AA
47
  cv2.putText(original_image, str(i), (cX - 10, cY + 10), font, font_scale, font_color, 2, line_type)
 
 
48
 
49
+ centroid_list.append((cX, cY))
50
+
51
  for i in range(len(centroid_list) - 1):
52
  start_point = centroid_list[i]
53
  end_point = centroid_list[i + 1]
54
+ line_color = (0, 0, 0)
55
  cv2.line(original_image, start_point, end_point, line_color, 2)
56
 
57
  return original_image, len(contours)
58
 
59
+
60
  def process_image(image: Image.Image) -> np.ndarray:
61
  img = image.resize((384, 288))
62
  img = np.array(img)
63
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
64
  img = np.array(img) / 255.
65
  img = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
66
  img = torch.from_numpy(img)
 
69
  pred_saliency = model(img).squeeze().detach().numpy()
70
 
71
  heatmap = (pred_saliency * 255).astype(np.uint8)
72
+ heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
73
 
74
  heatmap = cv2.resize(heatmap, (image.width, image.height))
75
 
 
89
 
90
  return blended_img
91
 
 
 
 
 
 
 
 
92
 
93
+ def gr_process_image(input_image):
94
+ image = Image.fromarray(input_image)
95
+ processed_image = process_image(image)
96
+ return processed_image
97
+
98
+
99
+ iface = Gradio.Interface(
100
+ fn=gr_process_image,
101
+ inputs=GImage(),
102
+ outputs=GImage("numpy")
103
+ )
104
 
105
+ iface.launch(share=True)