Seokju Cho commited on
Commit
72bbdf9
·
1 Parent(s): c7a1328
Files changed (1) hide show
  1. app.py +3 -1
app.py CHANGED
@@ -2,6 +2,7 @@ import os
2
  import sys
3
  sys.path.append(os.path.join(os.path.dirname(__file__), "locotrack_pytorch"))
4
  import uuid
 
5
 
6
  import gradio as gr
7
  import mediapy
@@ -117,7 +118,7 @@ def clear_all_fn(frame_num, video_preview):
117
  def choose_frame(frame_num, video_preview_array):
118
  return video_preview_array[int(frame_num)]
119
 
120
-
121
  def extract_feature(video_input, model_size="small"):
122
  device = "cuda" if torch.cuda.is_available() else "cpu"
123
  dtype = torch.bfloat16 if device == "cuda" else torch.float16
@@ -177,6 +178,7 @@ def preprocess_video_input(video_path, model_size):
177
  )
178
 
179
 
 
180
  def track(
181
  model_size,
182
  video_preview,
 
2
  import sys
3
  sys.path.append(os.path.join(os.path.dirname(__file__), "locotrack_pytorch"))
4
  import uuid
5
+ import spaces
6
 
7
  import gradio as gr
8
  import mediapy
 
118
  def choose_frame(frame_num, video_preview_array):
119
  return video_preview_array[int(frame_num)]
120
 
121
+ @spaces.GPU
122
  def extract_feature(video_input, model_size="small"):
123
  device = "cuda" if torch.cuda.is_available() else "cpu"
124
  dtype = torch.bfloat16 if device == "cuda" else torch.float16
 
178
  )
179
 
180
 
181
+ @spaces.GPU
182
  def track(
183
  model_size,
184
  video_preview,