added model loading from ckpt
Browse files- tools/demo_api.py +6 -0
tools/demo_api.py
CHANGED
@@ -133,6 +133,12 @@ def build_predictor(
|
|
133 |
model.eval()
|
134 |
logger.info("Model Summary: {}".format(get_model_info(model, exp.test_size)))
|
135 |
|
|
|
|
|
|
|
|
|
|
|
|
|
136 |
predictor = Predictor(
|
137 |
model, exp, COCO_CLASSES,
|
138 |
None, decoder=None,
|
|
|
133 |
model.eval()
|
134 |
logger.info("Model Summary: {}".format(get_model_info(model, exp.test_size)))
|
135 |
|
136 |
+
logger.info("loading checkpoint")
|
137 |
+
ckpt = torch.load(ckpt_path, map_location="cpu")
|
138 |
+
# load the model state dict
|
139 |
+
model.load_state_dict(ckpt["model"])
|
140 |
+
logger.info("loaded checkpoint done.")
|
141 |
+
|
142 |
predictor = Predictor(
|
143 |
model, exp, COCO_CLASSES,
|
144 |
None, decoder=None,
|