doublelotus commited on
Commit
be6042a
·
1 Parent(s): 84fd7cf
Files changed (1) hide show
  1. main.py +2 -11
main.py CHANGED
@@ -8,18 +8,12 @@ from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
8
  from PIL import Image
9
  import zipfile
10
 
11
- import os
12
- os.environ['TRANSFORMERS_CACHE'] = '/tmp/transformers_cache'
13
-
14
- from transformers import pipeline
15
-
16
  app = Flask(__name__)
17
  CORS(app)
18
 
19
  cudaOrNah = "cuda" if torch.cuda.is_available() else "cpu"
20
  print(cudaOrNah)
21
 
22
-
23
  # Global model setup
24
  # running out of memory adjusted
25
  # checkpoint = "sam_vit_h_4b8939.pth"
@@ -33,8 +27,6 @@ mask_generator = SamAutomaticMaskGenerator(
33
  min_mask_region_area=0.0015 # Adjust this value as needed
34
  )
35
  print('Setup SAM model')
36
- rembg_pipe = pipeline("image-segmentation", model="briaai/RMBG-1.4", trust_remote_code=True)
37
- print('Setup rembg model')
38
 
39
  @app.route('/')
40
  def hello():
@@ -69,8 +61,7 @@ def get_masks():
69
  if cudaOrNah == "cuda":
70
  torch.cuda.empty_cache()
71
 
72
- noBg = rembg_pipe(image)
73
- masks = mask_generator.generate(noBg)
74
 
75
  if cudaOrNah == "cuda":
76
  torch.cuda.empty_cache()
@@ -82,7 +73,7 @@ def get_masks():
82
  segmentation[10, -10] or segmentation[-10, -10])
83
  return val
84
 
85
- # masks = [mask for mask in masks if not is_background(mask['segmentation'])]
86
 
87
  for i in range(0, len(masks) - 1)[::-1]:
88
  large_mask = masks[i]['segmentation']
 
8
  from PIL import Image
9
  import zipfile
10
 
 
 
 
 
 
11
  app = Flask(__name__)
12
  CORS(app)
13
 
14
  cudaOrNah = "cuda" if torch.cuda.is_available() else "cpu"
15
  print(cudaOrNah)
16
 
 
17
  # Global model setup
18
  # running out of memory adjusted
19
  # checkpoint = "sam_vit_h_4b8939.pth"
 
27
  min_mask_region_area=0.0015 # Adjust this value as needed
28
  )
29
  print('Setup SAM model')
 
 
30
 
31
  @app.route('/')
32
  def hello():
 
61
  if cudaOrNah == "cuda":
62
  torch.cuda.empty_cache()
63
 
64
+ masks = mask_generator.generate(image)
 
65
 
66
  if cudaOrNah == "cuda":
67
  torch.cuda.empty_cache()
 
73
  segmentation[10, -10] or segmentation[-10, -10])
74
  return val
75
 
76
+ masks = [mask for mask in masks if not is_background(mask['segmentation'])]
77
 
78
  for i in range(0, len(masks) - 1)[::-1]:
79
  large_mask = masks[i]['segmentation']