doublelotus commited on
Commit
9d22666
·
1 Parent(s): dce6539

mask test 2

Browse files
Files changed (1) hide show
  1. main.py +27 -22
main.py CHANGED
@@ -14,8 +14,8 @@ CORS(app)
14
  cudaOrNah = "cuda" if torch.cuda.is_available() else "cpu"
15
  print(cudaOrNah)
16
 
17
- # Global model setup
18
- # running out of memory adjusted
19
  # checkpoint = "sam_vit_h_4b8939.pth"
20
  # model_type = "vit_h"
21
  checkpoint = "sam_vit_l_0b3195.pth"
@@ -30,7 +30,7 @@ print('Setup SAM model')
30
 
31
  @app.route('/')
32
  def hello():
33
- return {"hei": "Shredded to peices"}
34
 
35
  @app.route('/health', methods=['GET'])
36
  def health_check():
@@ -44,7 +44,7 @@ def get_masks():
44
  # Get the image file from the request
45
  if 'image' not in request.files:
46
  return jsonify({"error": "No image file provided"}), 400
47
-
48
  image_file = request.files['image']
49
  if image_file.filename == '':
50
  return jsonify({"error": "No image file provided"}), 400
@@ -57,17 +57,36 @@ def get_masks():
57
 
58
  if image is None:
59
  raise ValueError("Image not found or unable to read.")
60
-
61
  if cudaOrNah == "cuda":
62
  torch.cuda.empty_cache()
63
-
64
  masks = mask_generator.generate(image)
65
 
66
  if cudaOrNah == "cuda":
67
  torch.cuda.empty_cache()
68
 
69
- masks = sorted(masks, key=(lambda x: x['area']), reverse=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
 
 
 
 
71
  def is_background(segmentation):
72
  val = (segmentation[10, 10] or segmentation[-10, 10] or
73
  segmentation[10, -10] or segmentation[-10, -10])
@@ -75,20 +94,6 @@ def get_masks():
75
 
76
  masks = [mask for mask in masks if not is_background(mask['segmentation'])]
77
 
78
- for i in range(0, len(masks) - 1)[::-1]:
79
- large_mask = masks[i]['segmentation']
80
- for j in range(i+1, len(masks)):
81
- not_small_mask = np.logical_not(masks[j]['segmentation'])
82
- masks[i]['segmentation'] = np.logical_and(large_mask, not_small_mask)
83
- masks[i]['area'] = masks[i]['segmentation'].sum()
84
- large_mask = masks[i]['segmentation']
85
-
86
- def sum_under_threshold(segmentation, threshold):
87
- return segmentation.sum() / segmentation.size < 0.0015
88
-
89
- masks = [mask for mask in masks if not sum_under_threshold(mask['segmentation'], 100)]
90
- masks = sorted(masks, key=(lambda x: x['area']), reverse=True)
91
-
92
  # Create a zip file in memory
93
  zip_buffer = io.BytesIO()
94
  with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
@@ -101,7 +106,7 @@ def get_masks():
101
  zip_file.writestr(f'mask_{idx+1}.png', mask_io.read())
102
 
103
  zip_buffer.seek(0)
104
-
105
  return send_file(zip_buffer, mimetype='application/zip', as_attachment=True, download_name='masks.zip')
106
  except Exception as e:
107
  # Log the error message if needed
 
14
  cudaOrNah = "cuda" if torch.cuda.is_available() else "cpu"
15
  print(cudaOrNah)
16
 
17
+ # Global model setup
18
+ # Adjusted due to memory constraints
19
  # checkpoint = "sam_vit_h_4b8939.pth"
20
  # model_type = "vit_h"
21
  checkpoint = "sam_vit_l_0b3195.pth"
 
30
 
31
  @app.route('/')
32
  def hello():
33
+ return {"hei": "Shredded to pieces"}
34
 
35
  @app.route('/health', methods=['GET'])
36
  def health_check():
 
44
  # Get the image file from the request
45
  if 'image' not in request.files:
46
  return jsonify({"error": "No image file provided"}), 400
47
+
48
  image_file = request.files['image']
49
  if image_file.filename == '':
50
  return jsonify({"error": "No image file provided"}), 400
 
57
 
58
  if image is None:
59
  raise ValueError("Image not found or unable to read.")
60
+
61
  if cudaOrNah == "cuda":
62
  torch.cuda.empty_cache()
63
+
64
  masks = mask_generator.generate(image)
65
 
66
  if cudaOrNah == "cuda":
67
  torch.cuda.empty_cache()
68
 
69
+ # Sort masks by area in descending order
70
+ masks = sorted(masks, key=lambda x: x['area'], reverse=True)
71
+
72
+ # Initialize a cumulative mask to keep track of covered areas
73
+ cumulative_mask = np.zeros_like(masks[0]['segmentation'], dtype=bool)
74
+
75
+ # Process masks to remove overlaps
76
+ for mask in masks:
77
+ # Subtract areas already covered
78
+ mask['segmentation'] = np.logical_and(
79
+ mask['segmentation'], np.logical_not(cumulative_mask)
80
+ )
81
+ # Update the cumulative mask
82
+ cumulative_mask = np.logical_or(cumulative_mask, mask['segmentation'])
83
+ # Update the area
84
+ mask['area'] = mask['segmentation'].sum()
85
 
86
+ # Remove masks with zero area
87
+ masks = [mask for mask in masks if mask['area'] > 0]
88
+
89
+ # (Optional) Remove background masks if needed
90
  def is_background(segmentation):
91
  val = (segmentation[10, 10] or segmentation[-10, 10] or
92
  segmentation[10, -10] or segmentation[-10, -10])
 
94
 
95
  masks = [mask for mask in masks if not is_background(mask['segmentation'])]
96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  # Create a zip file in memory
98
  zip_buffer = io.BytesIO()
99
  with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
 
106
  zip_file.writestr(f'mask_{idx+1}.png', mask_io.read())
107
 
108
  zip_buffer.seek(0)
109
+
110
  return send_file(zip_buffer, mimetype='application/zip', as_attachment=True, download_name='masks.zip')
111
  except Exception as e:
112
  # Log the error message if needed