doublelotus committed on
Commit
a4acab9
·
1 Parent(s): 9d22666

removing bg first

Browse files
Files changed (2) hide show
  1. main.py +29 -30
  2. requirements.txt +2 -1
main.py CHANGED
@@ -7,6 +7,7 @@ import cv2
7
  from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
8
  from PIL import Image
9
  import zipfile
 
10
 
11
  app = Flask(__name__)
12
  CORS(app)
@@ -14,8 +15,8 @@ CORS(app)
14
  cudaOrNah = "cuda" if torch.cuda.is_available() else "cpu"
15
  print(cudaOrNah)
16
 
17
- # Global model setup
18
- # Adjusted due to memory constraints
19
  # checkpoint = "sam_vit_h_4b8939.pth"
20
  # model_type = "vit_h"
21
  checkpoint = "sam_vit_l_0b3195.pth"
@@ -27,10 +28,12 @@ mask_generator = SamAutomaticMaskGenerator(
27
  min_mask_region_area=0.0015 # Adjust this value as needed
28
  )
29
  print('Setup SAM model')
 
 
30
 
31
@app.route('/')
def hello():
    """Root endpoint; returns a small JSON greeting as a liveness check."""
    greeting = {"hei": "Shredded to pieces"}
    return greeting
34
 
35
  @app.route('/health', methods=['GET'])
36
  def health_check():
@@ -44,7 +47,7 @@ def get_masks():
44
  # Get the image file from the request
45
  if 'image' not in request.files:
46
  return jsonify({"error": "No image file provided"}), 400
47
-
48
  image_file = request.files['image']
49
  if image_file.filename == '':
50
  return jsonify({"error": "No image file provided"}), 400
@@ -57,42 +60,38 @@ def get_masks():
57
 
58
  if image is None:
59
  raise ValueError("Image not found or unable to read.")
60
-
61
  if cudaOrNah == "cuda":
62
  torch.cuda.empty_cache()
63
-
64
- masks = mask_generator.generate(image)
 
65
 
66
  if cudaOrNah == "cuda":
67
  torch.cuda.empty_cache()
68
 
69
- # Sort masks by area in descending order
70
- masks = sorted(masks, key=lambda x: x['area'], reverse=True)
71
-
72
- # Initialize a cumulative mask to keep track of covered areas
73
- cumulative_mask = np.zeros_like(masks[0]['segmentation'], dtype=bool)
74
-
75
- # Process masks to remove overlaps
76
- for mask in masks:
77
- # Subtract areas already covered
78
- mask['segmentation'] = np.logical_and(
79
- mask['segmentation'], np.logical_not(cumulative_mask)
80
- )
81
- # Update the cumulative mask
82
- cumulative_mask = np.logical_or(cumulative_mask, mask['segmentation'])
83
- # Update the area
84
- mask['area'] = mask['segmentation'].sum()
85
 
86
- # Remove masks with zero area
87
- masks = [mask for mask in masks if mask['area'] > 0]
88
-
89
- # (Optional) Remove background masks if needed
90
def is_background(segmentation):
    """Heuristic background check: a mask counts as background when it
    covers any of the four sample points 10 px in from each corner.

    NOTE(review): assumes the image is at least ~20 px on each side —
    smaller inputs would wrap the indices; confirm against callers.
    """
    top_left = segmentation[10, 10]
    bottom_left = segmentation[-10, 10]
    top_right = segmentation[10, -10]
    bottom_right = segmentation[-10, -10]
    return top_left or bottom_left or top_right or bottom_right
94
 
95
- masks = [mask for mask in masks if not is_background(mask['segmentation'])]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
  # Create a zip file in memory
98
  zip_buffer = io.BytesIO()
@@ -106,7 +105,7 @@ def get_masks():
106
  zip_file.writestr(f'mask_{idx+1}.png', mask_io.read())
107
 
108
  zip_buffer.seek(0)
109
-
110
  return send_file(zip_buffer, mimetype='application/zip', as_attachment=True, download_name='masks.zip')
111
  except Exception as e:
112
  # Log the error message if needed
@@ -115,4 +114,4 @@ def get_masks():
115
  return jsonify({"error": "Error processing the image", "details": str(e)}), 400
116
 
117
  if __name__ == '__main__':
118
- app.run(debug=True)
 
7
  from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
8
  from PIL import Image
9
  import zipfile
10
+ from transformers import pipeline
11
 
12
  app = Flask(__name__)
13
  CORS(app)
 
15
  cudaOrNah = "cuda" if torch.cuda.is_available() else "cpu"
16
  print(cudaOrNah)
17
 
18
+ # Global model setup
19
+ # Adjusted: the vit_h checkpoint ran out of memory, so vit_l is used instead
20
  # checkpoint = "sam_vit_h_4b8939.pth"
21
  # model_type = "vit_h"
22
  checkpoint = "sam_vit_l_0b3195.pth"
 
28
  min_mask_region_area=0.0015 # Adjust this value as needed
29
  )
30
  print('Setup SAM model')
31
+ rembg_pipe = pipeline("image-segmentation", model="briaai/RMBG-1.4", trust_remote_code=True)
32
+ print('Setup rembg model')
33
 
34
@app.route('/')
def hello():
    """Root endpoint; returns a small JSON greeting as a liveness check."""
    # Fix typo introduced in this revision: "peices" -> "pieces"
    # (the previous revision of this handler spelled it correctly).
    return {"hei": "Shredded to pieces"}
37
 
38
  @app.route('/health', methods=['GET'])
39
  def health_check():
 
47
  # Get the image file from the request
48
  if 'image' not in request.files:
49
  return jsonify({"error": "No image file provided"}), 400
50
+
51
  image_file = request.files['image']
52
  if image_file.filename == '':
53
  return jsonify({"error": "No image file provided"}), 400
 
60
 
61
  if image is None:
62
  raise ValueError("Image not found or unable to read.")
63
+
64
  if cudaOrNah == "cuda":
65
  torch.cuda.empty_cache()
66
+
67
+ noBg = rembg_pipe(image)
68
+ masks = mask_generator.generate(noBg)
69
 
70
  if cudaOrNah == "cuda":
71
  torch.cuda.empty_cache()
72
 
73
+ masks = sorted(masks, key=(lambda x: x['area']), reverse=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
 
 
 
 
75
def is_background(segmentation):
    """Return truthy when the mask touches any near-corner probe point
    (10 px in from each of the four corners) — used to drop masks that
    look like image background.
    """
    probes = ((10, 10), (-10, 10), (10, -10), (-10, -10))
    val = False
    for row, col in probes:
        val = val or segmentation[row, col]
    return val
79
 
80
+ # masks = [mask for mask in masks if not is_background(mask['segmentation'])]
81
+
82
+ for i in range(0, len(masks) - 1)[::-1]:
83
+ large_mask = masks[i]['segmentation']
84
+ for j in range(i+1, len(masks)):
85
+ not_small_mask = np.logical_not(masks[j]['segmentation'])
86
+ masks[i]['segmentation'] = np.logical_and(large_mask, not_small_mask)
87
+ masks[i]['area'] = masks[i]['segmentation'].sum()
88
+ large_mask = masks[i]['segmentation']
89
+
90
def sum_under_threshold(segmentation, threshold):
    """True when the mask covers less than 0.15% of the image area.

    NOTE(review): `threshold` is accepted but ignored — the 0.0015
    cutoff is hard-coded, so the caller's literal `100` has no effect.
    Confirm intent before wiring the parameter in (doing so would
    change which masks survive the filter).
    """
    coverage_fraction = segmentation.sum() / segmentation.size
    return coverage_fraction < 0.0015
92
+
93
+ masks = [mask for mask in masks if not sum_under_threshold(mask['segmentation'], 100)]
94
+ masks = sorted(masks, key=(lambda x: x['area']), reverse=True)
95
 
96
  # Create a zip file in memory
97
  zip_buffer = io.BytesIO()
 
105
  zip_file.writestr(f'mask_{idx+1}.png', mask_io.read())
106
 
107
  zip_buffer.seek(0)
108
+
109
  return send_file(zip_buffer, mimetype='application/zip', as_attachment=True, download_name='masks.zip')
110
  except Exception as e:
111
  # Log the error message if needed
 
114
  return jsonify({"error": "Error processing the image", "details": str(e)}), 400
115
 
116
  if __name__ == '__main__':
117
+ app.run(debug=True)
requirements.txt CHANGED
@@ -12,4 +12,5 @@ torchvision
12
  matplotlib # Required for image processing and mask visualization
13
  onnxruntime
14
  onnx
15
- pycocotools
 
 
12
  matplotlib # Required for image processing and mask visualization
13
  onnxruntime
14
  onnx
15
+ pycocotools
16
+ transformers