Seokju Cho commited on
Commit
205164d
·
1 Parent(s): 5d50721
Files changed (1) hide show
  1. app.py +7 -5
app.py CHANGED
@@ -11,7 +11,7 @@ import cv2
11
  import matplotlib
12
  import torch
13
 
14
- from locotrack_pytorch.models.locotrack_model import load_model
15
  from viz_utils import paint_point_track
16
 
17
 
@@ -132,10 +132,12 @@ def extract_feature(video_input, model_size="small"):
132
  with torch.no_grad():
133
  feature = model.get_feature_grids(video_input)
134
 
135
- feature.lowres = (x.cpu() for x in feature.lowres)
136
- feature.hires = (x.cpu() for x in feature.hires)
137
- feature.highest = (x.cpu() for x in feature.highest)
138
-
 
 
139
  return feature
140
 
141
 
 
11
  import matplotlib
12
  import torch
13
 
14
+ from locotrack_pytorch.models.locotrack_model import load_model, FeatureGrids
15
  from viz_utils import paint_point_track
16
 
17
 
 
132
  with torch.no_grad():
133
  feature = model.get_feature_grids(video_input)
134
 
135
+ feature = FeatureGrids(
136
+ lowres=(feature.lowres[0].cpu(),),
137
+ hires=(feature.hires[0].cpu(),),
138
+ highest=(feature.highest[0].cpu(),),
139
+ resolutions=feature.resolutions,
140
+ )
141
  return feature
142
 
143