tidalove commited on
Commit
ac5ef0a
·
verified ·
1 Parent(s): 9748030

added model loading from ckpt

Browse files
Files changed (1) hide show
  1. tools/demo_api.py +6 -0
tools/demo_api.py CHANGED
@@ -133,6 +133,12 @@ def build_predictor(
133
  model.eval()
134
  logger.info("Model Summary: {}".format(get_model_info(model, exp.test_size)))
135
 
 
 
 
 
 
 
136
  predictor = Predictor(
137
  model, exp, COCO_CLASSES,
138
  None, decoder=None,
 
133
  model.eval()
134
  logger.info("Model Summary: {}".format(get_model_info(model, exp.test_size)))
135
 
136
+ logger.info("loading checkpoint")
137
+ ckpt = torch.load(ckpt_path, map_location="cpu")
138
+ # load the model state dict
139
+ model.load_state_dict(ckpt["model"])
140
+ logger.info("loaded checkpoint done.")
141
+
142
  predictor = Predictor(
143
  model, exp, COCO_CLASSES,
144
  None, decoder=None,