alexnasa commited on
Commit
c65fc1e
Β·
verified Β·
1 Parent(s): a55f645

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -31
app.py CHANGED
@@ -98,16 +98,14 @@ def reset_all():
98
 
99
  # Step 1: Preprocess the input image (Save and Crop)
100
  @spaces.GPU()
101
- def preprocess_image(image_array, state):
102
  if image_array is None:
103
- return "❌ Please upload an image first.", None, state, gr.update(interactive=True), gr.update(interactive=True)
104
 
105
- base_dir = os.path.join(os.environ["PIXEL3DMM_PREPROCESSED_DATA"], session_id)
106
  os.makedirs(base_dir, exist_ok=True)
107
- state.update({"base_dir": base_dir})
108
 
109
  img = Image.fromarray(image_array)
110
- saved_image_path = os.path.join(base_dir, f"{session_id}.png")
111
  img.save(saved_image_path)
112
 
113
  try:
@@ -121,11 +119,11 @@ def preprocess_image(image_array, state):
121
 
122
  crop_dir = os.path.join(base_dir, "cropped")
123
  image = first_image_from_dir(crop_dir)
124
- return "βœ… Step 1 complete. Ready for Normals.", image, state, gr.update(interactive=True), gr.update(interactive=True)
125
 
126
  # Step 2: Normals inference β†’ normals image
127
  @spaces.GPU()
128
- def step2_normals(state):
129
 
130
  base_conf = OmegaConf.load("configs/base.yaml")
131
 
@@ -136,18 +134,17 @@ def step2_normals(state):
136
  model = model.eval().to(DEVICE)
137
  _model_cache["normals_model"] = model
138
 
139
- session_id = state.get("session_id")
140
  base_conf.video_name = f'{session_id}'
141
  normals_n_uvs(base_conf, _model_cache["normals_model"])
142
 
143
- normals_dir = os.path.join(state["base_dir"], "p3dmm", "normals")
144
  image = first_image_from_dir(normals_dir)
145
 
146
- return "βœ… Step 2 complete. Ready for UV Map.", image, state, gr.update(interactive=True), gr.update(interactive=True)
147
 
148
  # Step 3: UV map inference β†’ uv map image
149
  @spaces.GPU()
150
- def step3_uv_map(state):
151
 
152
  base_conf = OmegaConf.load("configs/base.yaml")
153
 
@@ -158,19 +155,18 @@ def step3_uv_map(state):
158
  model = model.eval().to(DEVICE)
159
  _model_cache["uv_model"] = model
160
 
161
- session_id = state.get("session_id")
162
  base_conf.video_name = f'{session_id}'
163
  base_conf.model.prediction_type = "uv_map"
164
  normals_n_uvs(base_conf, _model_cache["uv_model"])
165
 
166
- uv_dir = os.path.join(state["base_dir"], "p3dmm", "uv_map")
167
  image = first_image_from_dir(uv_dir)
168
 
169
- return "βœ… Step 3 complete. Ready for Tracking.", image, state, gr.update(interactive=True), gr.update(interactive=True)
170
 
171
  # Step 4: Tracking β†’ final tracking image
172
  @spaces.GPU()
173
- def step4_track(state):
174
 
175
  tracking_conf = OmegaConf.load("configs/tracking.yaml")
176
 
@@ -203,7 +199,6 @@ def step4_track(state):
203
 
204
  flame_model = _model_cache["flame_model"]
205
  diff_renderer = _model_cache["diff_renderer"]
206
- session_id = state.get("session_id")
207
  tracking_conf.video_name = f'{session_id}'
208
  tracker = Tracker(tracking_conf, flame_model, diff_renderer)
209
  tracker.run()
@@ -212,31 +207,30 @@ def step4_track(state):
212
  tracking_dir = os.path.join(os.environ["PIXEL3DMM_TRACKING_OUTPUT"], session_id, "frames")
213
  image = first_image_from_dir(tracking_dir)
214
 
215
- return "βœ… Pipeline complete!", image, state, gr.update(interactive=True)
216
 
217
  # New: run all steps sequentially
218
  @spaces.GPU()
219
- def run_pipeline(image_array, state, request: gr.Request):
220
 
221
  session_id = request.session_hash
222
- state.update({"session_id": session_id, "base_dir": base_dir})
223
 
224
  # Step 1
225
- status1, crop_img, state, _, _ = preprocess_image(image_array, state)
226
  if "❌" in status1:
227
- return status1, None, None, None, None, None, {}
228
  # Step 2
229
- status2, normals_img, state, _, _ = step2_normals(state)
230
  # Step 3
231
- status3, uv_img, state, _, _ = step3_uv_map(state)
232
  # Step 4
233
- status4, track_img, state, _ = step4_track(state)
234
  # Locate mesh (.ply)
235
- mesh_dir = os.path.join(os.environ["PIXEL3DMM_TRACKING_OUTPUT"], state.get("session_id"), "mesh")
236
  mesh_file = first_file_from_dir(mesh_dir, "ply")
237
 
238
  final_status = "\n".join([status1, status2, status3, status4])
239
- return final_status, crop_img, normals_img, uv_img, track_img, mesh_file, state
240
 
241
  # Cleanup on unload
242
  def cleanup(request: gr.Request):
@@ -275,7 +269,7 @@ with gr.Blocks(css=css) as demo:
275
  with gr.Column():
276
  image_in = gr.Image(label="Upload Image", type="numpy", height=512)
277
  status = gr.Textbox(label="Status", lines=6, interactive=True, value="Upload an image to start.")
278
- state = gr.State({})
279
  with gr.Column():
280
  with gr.Row():
281
  crop_img = gr.Image(label="Preprocessed", height=256)
@@ -297,11 +291,11 @@ with gr.Blocks(css=css) as demo:
297
 
298
  run_btn.click(
299
  fn=run_pipeline,
300
- inputs=[image_in, state],
301
- outputs=[status, crop_img, normals_img, uv_img, track_img, mesh_file, state]
302
  )
303
- examples.outputs = [status, crop_img, normals_img, uv_img, track_img, mesh_file, state]
304
- image_in.upload(fn=reset_all, inputs=None, outputs=[crop_img, normals_img, uv_img, track_img, mesh_file, status, state, run_btn])
305
 
306
  demo.unload(cleanup)
307
 
 
98
 
99
  # Step 1: Preprocess the input image (Save and Crop)
100
  @spaces.GPU()
101
+ def preprocess_image(image_array, session_id):
102
  if image_array is None:
103
+ return "❌ Please upload an image first.", None, gr.update(interactive=True), gr.update(interactive=True)
104
 
 
105
  os.makedirs(base_dir, exist_ok=True)
 
106
 
107
  img = Image.fromarray(image_array)
108
+ saved_image_path = os.path.join(os.environ["PIXEL3DMM_PREPROCESSED_DATA"], session_id, f"{session_id}.png")
109
  img.save(saved_image_path)
110
 
111
  try:
 
119
 
120
  crop_dir = os.path.join(base_dir, "cropped")
121
  image = first_image_from_dir(crop_dir)
122
+ return "βœ… Step 1 complete. Ready for Normals.", image, gr.update(interactive=True), gr.update(interactive=True)
123
 
124
  # Step 2: Normals inference β†’ normals image
125
  @spaces.GPU()
126
+ def step2_normals(session_id):
127
 
128
  base_conf = OmegaConf.load("configs/base.yaml")
129
 
 
134
  model = model.eval().to(DEVICE)
135
  _model_cache["normals_model"] = model
136
 
 
137
  base_conf.video_name = f'{session_id}'
138
  normals_n_uvs(base_conf, _model_cache["normals_model"])
139
 
140
+ normals_dir = os.path.join(os.environ["PIXEL3DMM_PREPROCESSED_DATA"], session_id, "p3dmm", "normals")
141
  image = first_image_from_dir(normals_dir)
142
 
143
+ return "βœ… Step 2 complete. Ready for UV Map.", image, gr.update(interactive=True), gr.update(interactive=True)
144
 
145
  # Step 3: UV map inference β†’ uv map image
146
  @spaces.GPU()
147
+ def step3_uv_map(session_id):
148
 
149
  base_conf = OmegaConf.load("configs/base.yaml")
150
 
 
155
  model = model.eval().to(DEVICE)
156
  _model_cache["uv_model"] = model
157
 
 
158
  base_conf.video_name = f'{session_id}'
159
  base_conf.model.prediction_type = "uv_map"
160
  normals_n_uvs(base_conf, _model_cache["uv_model"])
161
 
162
+ uv_dir = os.path.join(os.environ["PIXEL3DMM_PREPROCESSED_DATA"], session_id, "p3dmm", "uv_map")
163
  image = first_image_from_dir(uv_dir)
164
 
165
+ return "βœ… Step 3 complete. Ready for Tracking.", image, gr.update(interactive=True), gr.update(interactive=True)
166
 
167
  # Step 4: Tracking β†’ final tracking image
168
  @spaces.GPU()
169
+ def step4_track(session_id):
170
 
171
  tracking_conf = OmegaConf.load("configs/tracking.yaml")
172
 
 
199
 
200
  flame_model = _model_cache["flame_model"]
201
  diff_renderer = _model_cache["diff_renderer"]
 
202
  tracking_conf.video_name = f'{session_id}'
203
  tracker = Tracker(tracking_conf, flame_model, diff_renderer)
204
  tracker.run()
 
207
  tracking_dir = os.path.join(os.environ["PIXEL3DMM_TRACKING_OUTPUT"], session_id, "frames")
208
  image = first_image_from_dir(tracking_dir)
209
 
210
+ return "βœ… Pipeline complete!", image, gr.update(interactive=True)
211
 
212
  # New: run all steps sequentially
213
  @spaces.GPU()
214
+ def run_pipeline(image_array, request: gr.Request):
215
 
216
  session_id = request.session_hash
 
217
 
218
  # Step 1
219
+ status1, crop_img, _, _ = preprocess_image(image_array)
220
  if "❌" in status1:
221
+ return status1, None, None, None, None, None
222
  # Step 2
223
+ status2, normals_img, _, _ = step2_normals(session_id)
224
  # Step 3
225
+ status3, uv_img, _, _ = step3_uv_map(session_id)
226
  # Step 4
227
+ status4, track_img, _ = step4_track(session_id)
228
  # Locate mesh (.ply)
229
+ mesh_dir = os.path.join(os.environ["PIXEL3DMM_TRACKING_OUTPUT"], session_id, "mesh")
230
  mesh_file = first_file_from_dir(mesh_dir, "ply")
231
 
232
  final_status = "\n".join([status1, status2, status3, status4])
233
+ return final_status, crop_img, normals_img, uv_img, track_img, mesh_file
234
 
235
  # Cleanup on unload
236
  def cleanup(request: gr.Request):
 
269
  with gr.Column():
270
  image_in = gr.Image(label="Upload Image", type="numpy", height=512)
271
  status = gr.Textbox(label="Status", lines=6, interactive=True, value="Upload an image to start.")
272
+
273
  with gr.Column():
274
  with gr.Row():
275
  crop_img = gr.Image(label="Preprocessed", height=256)
 
291
 
292
  run_btn.click(
293
  fn=run_pipeline,
294
+ inputs=[image_in],
295
+ outputs=[status, crop_img, normals_img, uv_img, track_img, mesh_file]
296
  )
297
+ examples.outputs = [status, crop_img, normals_img, uv_img, track_img, mesh_file]
298
+ image_in.upload(fn=reset_all, inputs=None, outputs=[crop_img, normals_img, uv_img, track_img, mesh_file, status, run_btn])
299
 
300
  demo.unload(cleanup)
301