taher30 commited on
Commit
8e4f839
1 Parent(s): c47cea1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -5
app.py CHANGED
@@ -12,14 +12,13 @@ import gradio as gr
12
 
13
  # image= cv2.imread('image_4.png', cv2.IMREAD_COLOR)
14
  def get_masks(model_type, image):
15
- if model_type == 'vit_h':
16
  sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
17
-
18
- masks_h = mask_generator_h.generate(image)
19
- if model_type == 'vit_b':
20
  sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
21
 
22
- if model_type == 'vit_l':
23
  sam = sam_model_registry["vit_l"](checkpoint="sam_vit_l_0b3195.pth")
24
 
25
  mask_generator = SamAutomaticMaskGenerator(sam)
 
12
 
13
  # image= cv2.imread('image_4.png', cv2.IMREAD_COLOR)
14
  def get_masks(model_type, image):
15
+ if model_type.all() == 'vit_h':
16
  sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
17
+
18
+ if model_type,all() == 'vit_b':
 
19
  sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
20
 
21
+ if model_type.all() == 'vit_l':
22
  sam = sam_model_registry["vit_l"](checkpoint="sam_vit_l_0b3195.pth")
23
 
24
  mask_generator = SamAutomaticMaskGenerator(sam)