Spaces:
Sleeping
Sleeping
Commit
·
be6042a
1
Parent(s):
84fd7cf
dismantle
Browse files
main.py
CHANGED
@@ -8,18 +8,12 @@ from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
|
|
8 |
from PIL import Image
|
9 |
import zipfile
|
10 |
|
11 |
-
import os
|
12 |
-
os.environ['TRANSFORMERS_CACHE'] = '/tmp/transformers_cache'
|
13 |
-
|
14 |
-
from transformers import pipeline
|
15 |
-
|
16 |
app = Flask(__name__)
|
17 |
CORS(app)
|
18 |
|
19 |
cudaOrNah = "cuda" if torch.cuda.is_available() else "cpu"
|
20 |
print(cudaOrNah)
|
21 |
|
22 |
-
|
23 |
# Global model setup
|
24 |
# running out of memory adjusted
|
25 |
# checkpoint = "sam_vit_h_4b8939.pth"
|
@@ -33,8 +27,6 @@ mask_generator = SamAutomaticMaskGenerator(
|
|
33 |
min_mask_region_area=0.0015 # Adjust this value as needed
|
34 |
)
|
35 |
print('Setup SAM model')
|
36 |
-
rembg_pipe = pipeline("image-segmentation", model="briaai/RMBG-1.4", trust_remote_code=True)
|
37 |
-
print('Setup rembg model')
|
38 |
|
39 |
@app.route('/')
|
40 |
def hello():
|
@@ -69,8 +61,7 @@ def get_masks():
|
|
69 |
if cudaOrNah == "cuda":
|
70 |
torch.cuda.empty_cache()
|
71 |
|
72 |
-
|
73 |
-
masks = mask_generator.generate(noBg)
|
74 |
|
75 |
if cudaOrNah == "cuda":
|
76 |
torch.cuda.empty_cache()
|
@@ -82,7 +73,7 @@ def get_masks():
|
|
82 |
segmentation[10, -10] or segmentation[-10, -10])
|
83 |
return val
|
84 |
|
85 |
-
|
86 |
|
87 |
for i in range(0, len(masks) - 1)[::-1]:
|
88 |
large_mask = masks[i]['segmentation']
|
|
|
8 |
from PIL import Image
|
9 |
import zipfile
|
10 |
|
|
|
|
|
|
|
|
|
|
|
11 |
app = Flask(__name__)
|
12 |
CORS(app)
|
13 |
|
14 |
cudaOrNah = "cuda" if torch.cuda.is_available() else "cpu"
|
15 |
print(cudaOrNah)
|
16 |
|
|
|
17 |
# Global model setup
|
18 |
# running out of memory adjusted
|
19 |
# checkpoint = "sam_vit_h_4b8939.pth"
|
|
|
27 |
min_mask_region_area=0.0015 # Adjust this value as needed
|
28 |
)
|
29 |
print('Setup SAM model')
|
|
|
|
|
30 |
|
31 |
@app.route('/')
|
32 |
def hello():
|
|
|
61 |
if cudaOrNah == "cuda":
|
62 |
torch.cuda.empty_cache()
|
63 |
|
64 |
+
masks = mask_generator.generate(image)
|
|
|
65 |
|
66 |
if cudaOrNah == "cuda":
|
67 |
torch.cuda.empty_cache()
|
|
|
73 |
segmentation[10, -10] or segmentation[-10, -10])
|
74 |
return val
|
75 |
|
76 |
+
masks = [mask for mask in masks if not is_background(mask['segmentation'])]
|
77 |
|
78 |
for i in range(0, len(masks) - 1)[::-1]:
|
79 |
large_mask = masks[i]['segmentation']
|