haritsahm commited on
Commit
e661c00
·
1 Parent(s): f665217

fix load checkpoint from path

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. utils/utils.py +4 -3
app.py CHANGED
@@ -10,7 +10,7 @@ from PIL import Image
10
  from streamlit_drawable_canvas import st_canvas
11
  from utils import utils
12
 
13
- PREDICTOR_MODEL, AUTOMASK_MODEL = utils.get_model('vit_b')
14
 
15
 
16
  def process_box(predictor_model, show_mask, radius_width):
 
10
  from streamlit_drawable_canvas import st_canvas
11
  from utils import utils
12
 
13
+ PREDICTOR_MODEL, AUTOMASK_MODEL = utils.get_model('checkpoint/sam_vit_b_01ec64.pth')
14
 
15
 
16
  def process_box(predictor_model, show_mask, radius_width):
utils/utils.py CHANGED
@@ -15,10 +15,11 @@ def get_color():
15
 
16
 
17
  @st.cache_resource
18
- def get_model(model):
19
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
- build_sam = sam_model_registry[model]
21
- model = build_sam(checkpoint=get_checkpoint_path(model)).to(device)
 
22
  if torch.cuda.is_available():
23
  torch.cuda.empty_cache()
24
  predictor = SamPredictor(model)
 
15
 
16
 
17
  @st.cache_resource
18
+ def get_model(model, checkpoint='checkpoint/sam_vit_b_01ec64.pth'):
19
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
+ build_sam = sam_model_registry['vit_b']
21
+ model = build_sam(checkpoint=checkpoint)
22
+ model = model.to(device)
23
  if torch.cuda.is_available():
24
  torch.cuda.empty_cache()
25
  predictor = SamPredictor(model)