Update app.py
Browse files
app.py
CHANGED
@@ -12,14 +12,13 @@ import gradio as gr
|
|
12 |
|
13 |
# image= cv2.imread('image_4.png', cv2.IMREAD_COLOR)
|
14 |
def get_masks(model_type, image):
|
15 |
-
if model_type == 'vit_h':
|
16 |
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
|
17 |
-
|
18 |
-
|
19 |
-
if model_type == 'vit_b':
|
20 |
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
|
21 |
|
22 |
-
if model_type == 'vit_l':
|
23 |
sam = sam_model_registry["vit_l"](checkpoint="sam_vit_l_0b3195.pth")
|
24 |
|
25 |
mask_generator = SamAutomaticMaskGenerator(sam)
|
|
|
12 |
|
13 |
# image= cv2.imread('image_4.png', cv2.IMREAD_COLOR)
|
14 |
def get_masks(model_type, image):
|
15 |
+
if model_type.all() == 'vit_h':
|
16 |
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
|
17 |
+
|
18 |
+
if model_type,all() == 'vit_b':
|
|
|
19 |
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
|
20 |
|
21 |
+
if model_type.all() == 'vit_l':
|
22 |
sam = sam_model_registry["vit_l"](checkpoint="sam_vit_l_0b3195.pth")
|
23 |
|
24 |
mask_generator = SamAutomaticMaskGenerator(sam)
|