alexnasa commited on
Commit
6b0fd85
Β·
verified Β·
1 Parent(s): d3706c1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -44
app.py CHANGED
@@ -32,11 +32,9 @@ import importlib, site; site.addsitedir(site.getsitepackages()[0]); importlib.in
32
 
33
  from pixel3dmm import env_paths
34
 
35
-
36
  sh("cd src/pixel3dmm/preprocessing/facer && pip install -e . && cd ../../../..")
37
  sh("cd src/pixel3dmm/preprocessing/PIPNet/FaceBoxesV2/utils && sh make.sh && cd ../../../../../..")
38
 
39
-
40
  def install_cuda_toolkit():
41
  CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.1.0/local_installers/cuda_12.1.0_530.30.02_linux.run"
42
  CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL)
@@ -57,12 +55,10 @@ def install_cuda_toolkit():
57
  install_cuda_toolkit()
58
 
59
  from omegaconf import OmegaConf
 
60
 
61
  DEVICE = "cuda"
62
 
63
- # 1. Prepare config at import time (no CUDA calls)
64
- base_conf = OmegaConf.load("configs/tracking.yaml")
65
-
66
  # 2. Empty cache for our heavy objects
67
  _model_cache = {}
68
 
@@ -87,16 +83,16 @@ def reset_all():
87
  "Awaiting new image upload...", # status
88
  {}, # state
89
  gr.update(interactive=True), # preprocess_btn
90
- gr.update(interactive=False), # normals_btn
91
- gr.update(interactive=False), # uv_map_btn
92
- gr.update(interactive=False) # track_btn
93
  )
94
 
95
  # Step 1: Preprocess the input image (Save and Crop)
96
  @spaces.GPU()
97
  def preprocess_image(image_array, state):
98
  if image_array is None:
99
- return "❌ Please upload an image first.", None, state, gr.update(interactive=True), gr.update(interactive=False)
100
 
101
  session_id = str(uuid.uuid4())
102
  base_dir = os.path.join(os.environ["PIXEL3DMM_PREPROCESSED_DATA"], session_id)
@@ -115,54 +111,63 @@ def preprocess_image(image_array, state):
115
  except subprocess.CalledProcessError as e:
116
  err = f"❌ Preprocess failed (exit {e.returncode}).\n\n{e.stdout}\n{e.stderr}"
117
  shutil.rmtree(base_dir)
118
- return err, None, {}, gr.update(interactive=True), gr.update(interactive=False)
119
 
120
  crop_dir = os.path.join(base_dir, "cropped")
121
  image = first_image_from_dir(crop_dir)
122
- return "βœ… Step 1 complete. Ready for Normals.", image, state, gr.update(interactive=False), gr.update(interactive=True)
123
 
124
  # Step 2: Normals inference β†’ normals image
125
  @spaces.GPU()
126
  def step2_normals(state):
127
- session_id = state.get("session_id")
128
- if not session_id:
129
- return "❌ State lost. Please start from Step 1.", None, state, gr.update(interactive=False), gr.update(interactive=False)
130
 
131
- try:
132
- p = subprocess.run([
133
- "python", "scripts/network_inference.py", "model.prediction_type=normals", f"video_name={session_id}"
134
- ], check=True, capture_output=True, text=True)
135
- except subprocess.CalledProcessError as e:
136
- err = f"❌ Normal map failed (exit {e.returncode}).\n\n{e.stdout}\n{e.stderr}"
137
- return err, None, state, gr.update(interactive=True), gr.update(interactive=False)
 
 
 
 
 
138
 
139
  normals_dir = os.path.join(state["base_dir"], "p3dmm", "normals")
140
  image = first_image_from_dir(normals_dir)
141
- return "βœ… Step 2 complete. Ready for UV Map.", image, state, gr.update(interactive=False), gr.update(interactive=True)
 
142
 
143
  # Step 3: UV map inference β†’ uv map image
144
  @spaces.GPU()
145
  def step3_uv_map(state):
146
- session_id = state.get("session_id")
147
- if not session_id:
148
- return "❌ State lost. Please start from Step 1.", None, state, gr.update(interactive=False), gr.update(interactive=False)
149
 
150
- try:
151
- p = subprocess.run([
152
- "python", "scripts/network_inference.py", "model.prediction_type=uv_map", f"video_name={session_id}"
153
- ], check=True, capture_output=True, text=True)
154
- except subprocess.CalledProcessError as e:
155
- err = f"❌ UV map failed (exit {e.returncode}).\n\n{e.stdout}\n{e.stderr}"
156
- return err, None, state, gr.update(interactive=True), gr.update(interactive=False)
 
 
 
 
 
 
157
 
158
  uv_dir = os.path.join(state["base_dir"], "p3dmm", "uv_map")
159
  image = first_image_from_dir(uv_dir)
160
- return "βœ… Step 3 complete. Ready for Tracking.", image, state, gr.update(interactive=False), gr.update(interactive=True)
 
161
 
162
  # Step 4: Tracking β†’ final tracking image
163
  @spaces.GPU()
164
  def step4_track(state):
165
 
 
 
166
  # Lazy init + caching of FLAME model on GPU
167
  if "flame_model" not in _model_cache:
168
  import os
@@ -175,7 +180,7 @@ def step4_track(state):
175
  from pixel3dmm.tracking.renderer_nvdiffrast import NVDRenderer
176
  from pixel3dmm.tracking.tracker import Tracker
177
 
178
- flame = FLAME(base_conf) # CPU instantiation
179
  flame = flame.to(DEVICE) # CUDA init happens here
180
  _model_cache["flame_model"] = flame
181
 
@@ -184,24 +189,23 @@ def step4_track(state):
184
  _obj_faces = load_obj(_mesh_file)[1]
185
 
186
  _model_cache["diff_renderer"] = NVDRenderer(
187
- image_size=base_conf.size,
188
  obj_filename=_mesh_file,
189
  no_sh=False,
190
  white_bg=True
191
  ).to(DEVICE)
192
-
193
-
194
  flame_model = _model_cache["flame_model"]
195
  diff_renderer = _model_cache["diff_renderer"]
196
  session_id = state.get("session_id")
197
- base_conf.video_name = f'{session_id}'
198
- tracker = Tracker(base_conf, flame_model, diff_renderer)
199
  tracker.run()
200
 
201
  tracking_dir = os.path.join(os.environ["PIXEL3DMM_TRACKING_OUTPUT"], session_id, "frames")
202
  image = first_image_from_dir(tracking_dir)
203
 
204
- return "βœ… Pipeline complete!", image, state, gr.update(interactive=False)
205
 
206
  # Build Gradio UI
207
  demo = gr.Blocks()
@@ -212,7 +216,7 @@ with demo:
212
  with gr.Row():
213
  with gr.Column():
214
  image_in = gr.Image(label="Upload Image", type="numpy", height=512)
215
- status = gr.Textbox(label="Status", lines=2, interactive=False, value="Upload an image to start.")
216
  state = gr.State({})
217
  with gr.Column():
218
  with gr.Row():
@@ -224,9 +228,9 @@ with demo:
224
 
225
  with gr.Row():
226
  preprocess_btn = gr.Button("Step 1: Preprocess", interactive=True)
227
- normals_btn = gr.Button("Step 2: Normals", interactive=False)
228
- uv_map_btn = gr.Button("Step 3: UV Map", interactive=False)
229
- track_btn = gr.Button("Step 4: Track", interactive=False)
230
 
231
  # Define component list for reset
232
  outputs_for_reset = [crop_img, normals_img, uv_img, track_img, status, state, preprocess_btn, normals_btn, uv_map_btn, track_btn]
 
32
 
33
  from pixel3dmm import env_paths
34
 
 
35
  sh("cd src/pixel3dmm/preprocessing/facer && pip install -e . && cd ../../../..")
36
  sh("cd src/pixel3dmm/preprocessing/PIPNet/FaceBoxesV2/utils && sh make.sh && cd ../../../../../..")
37
 
 
38
  def install_cuda_toolkit():
39
  CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.1.0/local_installers/cuda_12.1.0_530.30.02_linux.run"
40
  CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL)
 
55
  install_cuda_toolkit()
56
 
57
  from omegaconf import OmegaConf
58
+ from pixel3dmm.network_inference import normals_n_uvs
59
 
60
  DEVICE = "cuda"
61
 
 
 
 
62
  # 2. Empty cache for our heavy objects
63
  _model_cache = {}
64
 
 
83
  "Awaiting new image upload...", # status
84
  {}, # state
85
  gr.update(interactive=True), # preprocess_btn
86
+ gr.update(interactive=True), # normals_btn
87
+ gr.update(interactive=True), # uv_map_btn
88
+ gr.update(interactive=True) # track_btn
89
  )
90
 
91
  # Step 1: Preprocess the input image (Save and Crop)
92
  @spaces.GPU()
93
  def preprocess_image(image_array, state):
94
  if image_array is None:
95
+ return "❌ Please upload an image first.", None, state, gr.update(interactive=True), gr.update(interactive=True)
96
 
97
  session_id = str(uuid.uuid4())
98
  base_dir = os.path.join(os.environ["PIXEL3DMM_PREPROCESSED_DATA"], session_id)
 
111
  except subprocess.CalledProcessError as e:
112
  err = f"❌ Preprocess failed (exit {e.returncode}).\n\n{e.stdout}\n{e.stderr}"
113
  shutil.rmtree(base_dir)
114
+ return err, None, {}, gr.update(interactive=True), gr.update(interactive=True)
115
 
116
  crop_dir = os.path.join(base_dir, "cropped")
117
  image = first_image_from_dir(crop_dir)
118
+ return "βœ… Step 1 complete. Ready for Normals.", image, state, gr.update(interactive=True), gr.update(interactive=True)
119
 
120
  # Step 2: Normals inference β†’ normals image
121
  @spaces.GPU()
122
  def step2_normals(state):
 
 
 
123
 
124
+ base_conf = OmegaConf.load("configs/base.yaml")
125
+
126
+ if "normals_model" not in _model_cache:
127
+ from pixel3dmm.lightning.p3dmm_system import system as p3dmm_system
128
+
129
+ model = p3dmm_system.load_from_checkpoint(f"{env_paths.CKPT_N_PRED}", strict=False)
130
+ model = model.eval().to(DEVICE)
131
+ _model_cache["normals_model"] = model
132
+
133
+ session_id = state.get("session_id")
134
+ base_conf.video_name = f'{session_id}'
135
+ normals_n_uvs(base_conf, _model_cache["normals_model"])
136
 
137
  normals_dir = os.path.join(state["base_dir"], "p3dmm", "normals")
138
  image = first_image_from_dir(normals_dir)
139
+
140
+ return "βœ… Step 2 complete. Ready for UV Map.", image, state, gr.update(interactive=True), gr.update(interactive=True)
141
 
142
  # Step 3: UV map inference β†’ uv map image
143
  @spaces.GPU()
144
  def step3_uv_map(state):
 
 
 
145
 
146
+ base_conf = OmegaConf.load("configs/base.yaml")
147
+
148
+ if "uv_model" not in _model_cache:
149
+ from pixel3dmm.lightning.p3dmm_system import system as p3dmm_system
150
+
151
+ model = p3dmm_system.load_from_checkpoint(f"{env_paths.CKPT_UV_PRED}", strict=False)
152
+ model = model.eval().to(DEVICE)
153
+ _model_cache["uv_model"] = model
154
+
155
+ session_id = state.get("session_id")
156
+ base_conf.video_name = f'{session_id}'
157
+ base_conf.model.prediction_type = "uv_map"
158
+ normals_n_uvs(base_conf, _model_cache["uv_model"])
159
 
160
  uv_dir = os.path.join(state["base_dir"], "p3dmm", "uv_map")
161
  image = first_image_from_dir(uv_dir)
162
+
163
+ return "βœ… Step 3 complete. Ready for Tracking.", image, state, gr.update(interactive=True), gr.update(interactive=True)
164
 
165
  # Step 4: Tracking β†’ final tracking image
166
  @spaces.GPU()
167
  def step4_track(state):
168
 
169
+ tracking_conf = OmegaConf.load("configs/tracking.yaml")
170
+
171
  # Lazy init + caching of FLAME model on GPU
172
  if "flame_model" not in _model_cache:
173
  import os
 
180
  from pixel3dmm.tracking.renderer_nvdiffrast import NVDRenderer
181
  from pixel3dmm.tracking.tracker import Tracker
182
 
183
+ flame = FLAME(tracking_conf) # CPU instantiation
184
  flame = flame.to(DEVICE) # CUDA init happens here
185
  _model_cache["flame_model"] = flame
186
 
 
189
  _obj_faces = load_obj(_mesh_file)[1]
190
 
191
  _model_cache["diff_renderer"] = NVDRenderer(
192
+ image_size=tracking_conf.size,
193
  obj_filename=_mesh_file,
194
  no_sh=False,
195
  white_bg=True
196
  ).to(DEVICE)
197
+
 
198
  flame_model = _model_cache["flame_model"]
199
  diff_renderer = _model_cache["diff_renderer"]
200
  session_id = state.get("session_id")
201
+ tracking_conf.video_name = f'{session_id}'
202
+ tracker = Tracker(tracking_conf, flame_model, diff_renderer)
203
  tracker.run()
204
 
205
  tracking_dir = os.path.join(os.environ["PIXEL3DMM_TRACKING_OUTPUT"], session_id, "frames")
206
  image = first_image_from_dir(tracking_dir)
207
 
208
+ return "βœ… Pipeline complete!", image, state, gr.update(interactive=True)
209
 
210
  # Build Gradio UI
211
  demo = gr.Blocks()
 
216
  with gr.Row():
217
  with gr.Column():
218
  image_in = gr.Image(label="Upload Image", type="numpy", height=512)
219
+ status = gr.Textbox(label="Status", lines=2, interactive=True, value="Upload an image to start.")
220
  state = gr.State({})
221
  with gr.Column():
222
  with gr.Row():
 
228
 
229
  with gr.Row():
230
  preprocess_btn = gr.Button("Step 1: Preprocess", interactive=True)
231
+ normals_btn = gr.Button("Step 2: Normals", interactive=True)
232
+ uv_map_btn = gr.Button("Step 3: UV Map", interactive=True)
233
+ track_btn = gr.Button("Step 4: Track", interactive=True)
234
 
235
  # Define component list for reset
236
  outputs_for_reset = [crop_img, normals_img, uv_img, track_img, status, state, preprocess_btn, normals_btn, uv_map_btn, track_btn]