nikkar commited on
Commit
b256705
·
verified ·
1 Parent(s): 79ddb78

Update predictor.py

Browse files
Files changed (1) hide show
  1. predictor.py +10 -1
predictor.py CHANGED
@@ -76,7 +76,7 @@ def get_points_on_a_grid(
76
  class CoTrackerOnlinePredictor(torch.nn.Module):
77
  def __init__(
78
  self,
79
- checkpoint="./checkpoints/scaled_online.pth",
80
  offline=False,
81
  v2=False,
82
  window_len=16,
@@ -85,6 +85,15 @@ class CoTrackerOnlinePredictor(torch.nn.Module):
85
  self.support_grid_size = 6
86
  model = torch.hub.load("facebookresearch/co-tracker", "cotracker3_online").model
87
  # build_cotracker(checkpoint, v2=v2, offline=False, window_len=window_len)
 
 
 
 
 
 
 
 
 
88
  self.interp_shape = model.model_resolution
89
  self.step = model.window_len // 2
90
  self.model = model
 
76
  class CoTrackerOnlinePredictor(torch.nn.Module):
77
  def __init__(
78
  self,
79
+ checkpoint=None,
80
  offline=False,
81
  v2=False,
82
  window_len=16,
 
85
  self.support_grid_size = 6
86
  model = torch.hub.load("facebookresearch/co-tracker", "cotracker3_online").model
87
  # build_cotracker(checkpoint, v2=v2, offline=False, window_len=window_len)
88
+ if checkpoint is not None:
89
+ with open(checkpoint, "rb") as f:
90
+ state_dict = torch.load(f, map_location="cpu")
91
+ if "model" in state_dict:
92
+ state_dict = state_dict["model"]
93
+ model.model.load_state_dict(state_dict)
94
+ print('LOAD STATE DICT')
95
+
96
+
97
  self.interp_shape = model.model_resolution
98
  self.step = model.window_len // 2
99
  self.model = model