Update predictor.py
Browse files- predictor.py +10 -1
predictor.py
CHANGED
@@ -76,7 +76,7 @@ def get_points_on_a_grid(
|
|
76 |
class CoTrackerOnlinePredictor(torch.nn.Module):
|
77 |
def __init__(
|
78 |
self,
|
79 |
-
checkpoint=
|
80 |
offline=False,
|
81 |
v2=False,
|
82 |
window_len=16,
|
@@ -85,6 +85,15 @@ class CoTrackerOnlinePredictor(torch.nn.Module):
|
|
85 |
self.support_grid_size = 6
|
86 |
model = torch.hub.load("facebookresearch/co-tracker", "cotracker3_online").model
|
87 |
# build_cotracker(checkpoint, v2=v2, offline=False, window_len=window_len)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
self.interp_shape = model.model_resolution
|
89 |
self.step = model.window_len // 2
|
90 |
self.model = model
|
|
|
76 |
class CoTrackerOnlinePredictor(torch.nn.Module):
|
77 |
def __init__(
|
78 |
self,
|
79 |
+
checkpoint=None,
|
80 |
offline=False,
|
81 |
v2=False,
|
82 |
window_len=16,
|
|
|
85 |
self.support_grid_size = 6
|
86 |
model = torch.hub.load("facebookresearch/co-tracker", "cotracker3_online").model
|
87 |
# build_cotracker(checkpoint, v2=v2, offline=False, window_len=window_len)
|
88 |
+
if checkpoint is not None:
|
89 |
+
with open(checkpoint, "rb") as f:
|
90 |
+
state_dict = torch.load(f, map_location="cpu")
|
91 |
+
if "model" in state_dict:
|
92 |
+
state_dict = state_dict["model"]
|
93 |
+
model.model.load_state_dict(state_dict)
|
94 |
+
print('LOAD STATE DICT')
|
95 |
+
|
96 |
+
|
97 |
self.interp_shape = model.model_resolution
|
98 |
self.step = model.window_len // 2
|
99 |
self.model = model
|