alexnasa commited on
Commit
616acbb
Β·
1 Parent(s): ae38969

tracker optimised for zero gpu

Browse files
Files changed (2) hide show
  1. app.py +35 -14
  2. src/pixel3dmm/tracking/tracker.py +6 -7
app.py CHANGED
@@ -17,11 +17,12 @@ os.environ["PIXEL3DMM_TRACKING_OUTPUT"] = f"{os.getcwd()}/tracking_results"
17
 
18
  def sh(cmd): subprocess.check_call(cmd, shell=True)
19
 
20
- # only do this once per VM restart
21
  sh("pip install -e .")
22
  sh("cd src/pixel3dmm/preprocessing/facer && pip install -e .")
23
  sh("cd src/pixel3dmm/preprocessing/PIPNet/FaceBoxesV2/utils && sh make.sh")
24
 
 
25
  def install_cuda_toolkit():
26
  CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.1.0/local_installers/cuda_12.1.0_530.30.02_linux.run"
27
  CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL)
@@ -39,9 +40,36 @@ def install_cuda_toolkit():
39
  os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0"
40
  print("==> finished installation")
41
 
42
-
43
  install_cuda_toolkit()
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  # Utility to select first image from a folder
46
  def first_image_from_dir(directory):
47
  patterns = ["*.jpg", "*.png", "*.jpeg"]
@@ -68,7 +96,7 @@ def reset_all():
68
  )
69
 
70
  # Step 1: Preprocess the input image (Save and Crop)
71
- @spaces.GPU()
72
  def preprocess_image(image_array, state):
73
  if image_array is None:
74
  return "❌ Please upload an image first.", None, state, gr.update(interactive=True), gr.update(interactive=False)
@@ -138,20 +166,13 @@ def step3_uv_map(state):
138
  @spaces.GPU()
139
  def step4_track(state):
140
  session_id = state.get("session_id")
141
- if not session_id:
142
- return "❌ State lost. Please start from Step 1.", None, state, gr.update(interactive=False)
143
-
144
- script = os.path.join(os.environ["PIXEL3DMM_CODE_BASE"], "scripts", "track.py")
145
- try:
146
- p = subprocess.run([
147
- "python", script, f"video_name={session_id}"
148
- ], check=True, capture_output=True, text=True)
149
- except subprocess.CalledProcessError as e:
150
- err = f"❌ Tracking failed (exit {e.returncode}).\n\n{e.stdout}\n{e.stderr}"
151
- return err, None, state, gr.update(interactive=True)
152
 
153
  tracking_dir = os.path.join(os.environ["PIXEL3DMM_TRACKING_OUTPUT"], session_id, "frames")
154
  image = first_image_from_dir(tracking_dir)
 
155
  return "βœ… Pipeline complete!", image, state, gr.update(interactive=False)
156
 
157
  # Build Gradio UI
 
17
 
18
  def sh(cmd): subprocess.check_call(cmd, shell=True)
19
 
20
+
21
  sh("pip install -e .")
22
  sh("cd src/pixel3dmm/preprocessing/facer && pip install -e .")
23
  sh("cd src/pixel3dmm/preprocessing/PIPNet/FaceBoxesV2/utils && sh make.sh")
24
 
25
+
26
  def install_cuda_toolkit():
27
  CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.1.0/local_installers/cuda_12.1.0_530.30.02_linux.run"
28
  CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL)
 
40
  os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0"
41
  print("==> finished installation")
42
 
 
43
  install_cuda_toolkit()
44
 
45
+ import os
46
+ import torch
47
+ import numpy as np
48
+ import trimesh
49
+ from pytorch3d.io import load_obj
50
+ from pixel3dmm.tracking.renderer_nvdiffrast import NVDRenderer
51
+ from pixel3dmm.tracking.flame.FLAME import FLAME
52
+ from pixel3dmm import env_paths
53
+ from omegaconf import OmegaConf
54
+ from pixel3dmm.tracking.tracker import Tracker
55
+
56
+ DEVICE = "cuda"
57
+
58
+ base_conf = OmegaConf.load(f'{env_paths.CODE_BASE}/configs/tracking.yaml')
59
+
60
+ _mesh_file = env_paths.head_template
61
+ flame_model = FLAME(base_conf).to(DEVICE)
62
+
63
+ _obj_faces = load_obj(_mesh_file)[1]
64
+
65
+ diff_renderer = NVDRenderer(
66
+ image_size=base_conf.size,
67
+ obj_filename=_mesh_file,
68
+ no_sh=False,
69
+ white_bg=True
70
+ ).to(DEVICE)
71
+
72
+
73
  # Utility to select first image from a folder
74
  def first_image_from_dir(directory):
75
  patterns = ["*.jpg", "*.png", "*.jpeg"]
 
96
  )
97
 
98
  # Step 1: Preprocess the input image (Save and Crop)
99
+ # @spaces.GPU()
100
  def preprocess_image(image_array, state):
101
  if image_array is None:
102
  return "❌ Please upload an image first.", None, state, gr.update(interactive=True), gr.update(interactive=False)
 
166
  @spaces.GPU()
167
  def step4_track(state):
168
  session_id = state.get("session_id")
169
+ base_conf.video_name = f'{session_id}'
170
+ tracker = Tracker(base_conf, flame_model, diff_renderer)
171
+ tracker.run()
 
 
 
 
 
 
 
 
172
 
173
  tracking_dir = os.path.join(os.environ["PIXEL3DMM_TRACKING_OUTPUT"], session_id, "frames")
174
  image = first_image_from_dir(tracking_dir)
175
+
176
  return "βœ… Pipeline complete!", image, state, gr.update(interactive=False)
177
 
178
  # Build Gradio UI
src/pixel3dmm/tracking/tracker.py CHANGED
@@ -128,9 +128,13 @@ if COMPILE:
128
 
129
 
130
  class Tracker(object):
131
- def __init__(self, config,
132
  device='cuda:0',
133
  ):
 
 
 
 
134
  self.config = config
135
  self.device = device
136
  self.actor_name = self.config.video_name
@@ -240,7 +244,6 @@ class Tracker(object):
240
  def setup_renderer(self):
241
  mesh_file = f'{env_paths.head_template}'
242
  self.config.image_size = self.get_image_size()
243
- self.flame = FLAME(self.config).to(self.device)
244
  self.flame.vertex_face_mask = self.vertex_face_mask
245
 
246
 
@@ -251,11 +254,7 @@ class Tracker(object):
251
  self.actual_smooth = torch.compile(self.actual_smooth)
252
 
253
 
254
- self.diff_renderer = NVDRenderer(self.config.size,
255
- obj_filename=mesh_file,
256
- no_sh=self.no_sh,
257
- white_bg= True,
258
- ).to(self.device)
259
 
260
 
261
  self.faces = load_obj(mesh_file)[1]
 
128
 
129
 
130
  class Tracker(object):
131
+ def __init__(self, config, flame_module, renderer,
132
  device='cuda:0',
133
  ):
134
+ self.config = config
135
+ self.flame = flame_module
136
+ self.diff_renderer = renderer
137
+
138
  self.config = config
139
  self.device = device
140
  self.actor_name = self.config.video_name
 
244
  def setup_renderer(self):
245
  mesh_file = f'{env_paths.head_template}'
246
  self.config.image_size = self.get_image_size()
 
247
  self.flame.vertex_face_mask = self.vertex_face_mask
248
 
249
 
 
254
  self.actual_smooth = torch.compile(self.actual_smooth)
255
 
256
 
257
+ self.renderer = self.diff_renderer # already global
 
 
 
 
258
 
259
 
260
  self.faces = load_obj(mesh_file)[1]