alexnasa committed on
Commit
0c8ae54
·
verified ·
1 Parent(s): 1814d0e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +216 -4
app.py CHANGED
@@ -17,8 +17,220 @@ os.environ["PIXEL3DMM_TRACKING_OUTPUT"] = f"{os.getcwd()}/tracking_results"
17
 
18
  def sh(cmd): subprocess.check_call(cmd, shell=True)
19
 
20
-
21
  sh("pip install -e .")
22
- sh("cd src/pixel3dmm/preprocessing/facer && pip install -e .")
23
- sh("cd src/pixel3dmm/preprocessing/PIPNet/FaceBoxesV2/utils && sh make.sh")
24
- sh("python app_photo.py")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  def sh(cmd): subprocess.check_call(cmd, shell=True)
19
 
 
20
  sh("pip install -e .")
21
+ sh("cd src/pixel3dmm/preprocessing/facer && pip install -e . && cd ../../../..")
22
+ sh("cd src/pixel3dmm/preprocessing/PIPNet/FaceBoxesV2/utils && sh make.sh && cd ../../../../../..")
23
+
24
+
25
def install_cuda_toolkit():
    """Download and install the CUDA 12.1 toolkit, then export its paths.

    The NVIDIA runfile installer is fetched into ``/tmp`` with ``wget``, made
    executable and run with ``--silent --toolkit``.  Afterwards ``CUDA_HOME``,
    ``PATH``, ``LD_LIBRARY_PATH`` and ``TORCH_CUDA_ARCH_LIST`` are set in
    ``os.environ`` for this process and its children.

    Installation is best effort: when a step cannot be started or exits with
    a non-zero status, a warning is printed and the remaining install steps
    are skipped.  The environment variables are exported either way.

    Raises:
        KeyError: If ``PATH`` is not set in the environment.
    """
    cuda_toolkit_url = "https://developer.download.nvidia.com/compute/cuda/12.1.0/local_installers/cuda_12.1.0_530.30.02_linux.run"
    cuda_toolkit_file = "/tmp/%s" % os.path.basename(cuda_toolkit_url)

    install_steps = (
        ["wget", "-q", cuda_toolkit_url, "-O", cuda_toolkit_file],
        ["chmod", "+x", cuda_toolkit_file],
        [cuda_toolkit_file, "--silent", "--toolkit"],
    )
    installed = True
    for step in install_steps:
        try:
            return_code = subprocess.call(step)
        except OSError as exc:
            # e.g. wget is missing, or the downloaded file is not executable.
            print(f"==> CUDA toolkit install: could not run {step[0]!r}: {exc}")
            installed = False
            break
        if return_code != 0:
            # Later steps depend on this one, so do not run them on a failure.
            print(f"==> CUDA toolkit install: {step[0]!r} exited with status {return_code}")
            installed = False
            break

    cuda_home = "/usr/local/cuda"
    os.environ["CUDA_HOME"] = cuda_home
    os.environ["PATH"] = f"{cuda_home}/bin:{os.environ['PATH']}"

    # Only append the previous value when there is one: a trailing ":" adds
    # an empty entry, which the dynamic loader treats as the current directory.
    cuda_lib_dir = f"{cuda_home}/lib"
    previous_ld_path = os.environ.get("LD_LIBRARY_PATH")
    os.environ["LD_LIBRARY_PATH"] = (
        f"{cuda_lib_dir}:{previous_ld_path}" if previous_ld_path else cuda_lib_dir
    )

    # Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range
    os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0"

    if installed:
        print("==> finished installation")
41
+
42
+ install_cuda_toolkit()
43
+
44
+ import os
45
+ import torch
46
+ import numpy as np
47
+ import trimesh
48
+ from pytorch3d.io import load_obj
49
+
50
+ from pixel3dmm import env_paths
51
+ from pixel3dmm.tracking.renderer_nvdiffrast import NVDRenderer
52
+ from pixel3dmm.tracking.flame.FLAME import FLAME
53
+ from pixel3dmm.tracking.tracker import Tracker
54
+ from omegaconf import OmegaConf
55
+
56
+
57
# Device string handed to ``.to()`` for the FLAME model and the renderer.
DEVICE = "cuda"

# Tracking configuration, loaded once at import time.
base_conf = OmegaConf.load(f'{env_paths.CODE_BASE}/configs/tracking.yaml')

# Head template mesh; the renderer also reads it.
_mesh_file = env_paths.head_template

flame_model = FLAME(base_conf)
flame_model = flame_model.to(DEVICE)

# Second element of the value returned by ``load_obj``.
_obj_faces = load_obj(_mesh_file)[1]

diff_renderer = NVDRenderer(
    image_size=base_conf.size,
    obj_filename=_mesh_file,
    no_sh=False,
    white_bg=True,
)
diff_renderer = diff_renderer.to(DEVICE)
72
+
73
+
74
+ # Utility to select first image from a folder
75
def first_image_from_dir(directory):
    """Return the first ``.jpg``/``.png``/``.jpeg`` file found in ``directory``.

    Only the folder itself is searched (no recursion).  "First" means the
    smallest path in plain string ordering.

    Args:
        directory: Folder to look in.  Its name is matched literally, so it
            may contain glob metacharacters such as ``[``, ``?`` or ``*``.

    Returns:
        The path of the first matching image, or ``None`` when there is no
        match (including when the folder does not exist).
    """
    # Escape the folder so only the filename part is treated as a pattern.
    literal_dir = glob.escape(directory)
    matches = [
        path
        for pattern in ("*.jpg", "*.png", "*.jpeg")
        for path in glob.glob(os.path.join(literal_dir, pattern))
    ]
    # min() picks the same element as sorted(matches)[0] without a full sort.
    return min(matches, default=None)
83
+
84
+ # Function to reset the UI and state
85
def reset_all():
    """Build the output values that put the UI back into its initial state.

    Returns:
        A 10-tuple, in this order: four ``None`` values (crop_img,
        normals_img, uv_img, track_img), the status text, an empty state
        dict, and four button updates (preprocess_btn enabled; normals_btn,
        uv_map_btn and track_btn disabled).
    """
    cleared_images = (None,) * 4  # crop_img, normals_img, uv_img, track_img
    status_text = "Awaiting new image upload..."
    fresh_state = {}
    # Only the first pipeline step is clickable after a reset.
    button_updates = tuple(
        gr.update(interactive=enabled) for enabled in (True, False, False, False)
    )
    return (*cleared_images, status_text, fresh_state, *button_updates)
98
+
99
+ # Step 1: Preprocess the input image (Save and Crop)
100
+ # @spaces.GPU()
101
def preprocess_image(image_array, state):
    """Step 1: save the uploaded image and run the preprocessing script on it.

    A new session folder, named by a random UUID, is created under
    ``$PIXEL3DMM_PREPROCESSED_DATA``.  The upload is written there as a PNG
    and ``scripts/run_preprocessing.py`` is run on it in a child process.

    Args:
        image_array: Uploaded image as accepted by ``Image.fromarray``, or
            ``None`` when nothing has been uploaded.
        state: Per-session dict.  ``session_id``, ``base_dir`` and
            ``image_path`` are written into it.

    Returns:
        ``(status, crop_image, state, preprocess_btn_update,
        normals_btn_update)``.  On success ``crop_image`` is the first image
        in the session's ``cropped`` folder and the Normals button is
        enabled.  On any failure the session folder is removed, an empty
        state is returned and only the Preprocess button stays enabled.
    """
    import sys  # local import: only needed to locate the running interpreter

    if image_array is None:
        return (
            "❌ Please upload an image first.",
            None,
            state,
            gr.update(interactive=True),
            gr.update(interactive=False),
        )

    session_id = str(uuid.uuid4())
    base_dir = os.path.join(os.environ["PIXEL3DMM_PREPROCESSED_DATA"], session_id)
    os.makedirs(base_dir, exist_ok=True)
    state.update({"session_id": session_id, "base_dir": base_dir})

    saved_image_path = os.path.join(base_dir, f"{session_id}.png")
    Image.fromarray(image_array).save(saved_image_path)
    state["image_path"] = saved_image_path

    # Use the interpreter running this app so the child sees the same
    # installed packages; fall back to "python" if it cannot be determined.
    python_exe = sys.executable or "python"
    try:
        subprocess.run(
            [python_exe, "scripts/run_preprocessing.py", "--video_or_images_path", saved_image_path],
            check=True,
            capture_output=True,
            text=True,
        )
    except subprocess.CalledProcessError as e:
        err = f"❌ Preprocess failed (exit {e.returncode}).\n\n{e.stdout}\n{e.stderr}"
        shutil.rmtree(base_dir, ignore_errors=True)
        return err, None, {}, gr.update(interactive=True), gr.update(interactive=False)

    crop_dir = os.path.join(base_dir, "cropped")
    image = first_image_from_dir(crop_dir)
    if image is None:
        # The script exited cleanly but left nothing to show or build on.
        shutil.rmtree(base_dir, ignore_errors=True)
        return (
            "❌ Preprocess finished but produced no cropped image.",
            None,
            {},
            gr.update(interactive=True),
            gr.update(interactive=False),
        )

    return (
        "✅ Step 1 complete. Ready for Normals.",
        image,
        state,
        gr.update(interactive=False),
        gr.update(interactive=True),
    )
127
+
128
+ # Step 2: Normals inference → normals image
129
@spaces.GPU()
def step2_normals(state):
    """Step 2: run normals inference for the current session.

    Runs ``scripts/network_inference.py`` with
    ``model.prediction_type=normals`` in a child process, then picks the
    first image from ``<base_dir>/p3dmm/normals``.

    Args:
        state: Per-session dict; ``session_id`` and ``base_dir`` are read.

    Returns:
        ``(status, normals_image, state, normals_btn_update,
        uv_map_btn_update)``.  On success the UV Map button is enabled.  If
        the script fails or leaves no image, the Normals button stays enabled
        for a retry.  If the session id is missing, both buttons are disabled.
    """
    import sys  # local import: only needed to locate the running interpreter

    session_id = state.get("session_id")
    if not session_id:
        return (
            "❌ State lost. Please start from Step 1.",
            None,
            state,
            gr.update(interactive=False),
            gr.update(interactive=False),
        )

    # Use the interpreter running this app so the child sees the same
    # installed packages; fall back to "python" if it cannot be determined.
    python_exe = sys.executable or "python"
    try:
        subprocess.run(
            [python_exe, "scripts/network_inference.py", "model.prediction_type=normals", f"video_name={session_id}"],
            check=True,
            capture_output=True,
            text=True,
        )
    except subprocess.CalledProcessError as e:
        err = f"❌ Normal map failed (exit {e.returncode}).\n\n{e.stdout}\n{e.stderr}"
        return err, None, state, gr.update(interactive=True), gr.update(interactive=False)

    normals_dir = os.path.join(state["base_dir"], "p3dmm", "normals")
    image = first_image_from_dir(normals_dir)
    if image is None:
        # The script exited cleanly but wrote no image; allow a retry.
        return (
            "❌ Normal map finished but produced no output image.",
            None,
            state,
            gr.update(interactive=True),
            gr.update(interactive=False),
        )

    return (
        "✅ Step 2 complete. Ready for UV Map.",
        image,
        state,
        gr.update(interactive=False),
        gr.update(interactive=True),
    )
146
+
147
+ # Step 3: UV map inference → uv map image
148
@spaces.GPU()
def step3_uv_map(state):
    """Step 3: run UV-map inference for the current session.

    Runs ``scripts/network_inference.py`` with
    ``model.prediction_type=uv_map`` in a child process, then picks the first
    image from ``<base_dir>/p3dmm/uv_map``.

    Args:
        state: Per-session dict; ``session_id`` and ``base_dir`` are read.

    Returns:
        ``(status, uv_image, state, uv_map_btn_update, track_btn_update)``.
        On success the Track button is enabled.  If the script fails or
        leaves no image, the UV Map button stays enabled for a retry.  If the
        session id is missing, both buttons are disabled.
    """
    import sys  # local import: only needed to locate the running interpreter

    session_id = state.get("session_id")
    if not session_id:
        return (
            "❌ State lost. Please start from Step 1.",
            None,
            state,
            gr.update(interactive=False),
            gr.update(interactive=False),
        )

    # Use the interpreter running this app so the child sees the same
    # installed packages; fall back to "python" if it cannot be determined.
    python_exe = sys.executable or "python"
    try:
        subprocess.run(
            [python_exe, "scripts/network_inference.py", "model.prediction_type=uv_map", f"video_name={session_id}"],
            check=True,
            capture_output=True,
            text=True,
        )
    except subprocess.CalledProcessError as e:
        err = f"❌ UV map failed (exit {e.returncode}).\n\n{e.stdout}\n{e.stderr}"
        return err, None, state, gr.update(interactive=True), gr.update(interactive=False)

    uv_dir = os.path.join(state["base_dir"], "p3dmm", "uv_map")
    image = first_image_from_dir(uv_dir)
    if image is None:
        # The script exited cleanly but wrote no image; allow a retry.
        return (
            "❌ UV map finished but produced no output image.",
            None,
            state,
            gr.update(interactive=True),
            gr.update(interactive=False),
        )

    return (
        "✅ Step 3 complete. Ready for Tracking.",
        image,
        state,
        gr.update(interactive=False),
        gr.update(interactive=True),
    )
165
+
166
+ # Step 4: Tracking → final tracking image
167
@spaces.GPU()
def step4_track(state):
    """Step 4: run the tracker for the current session and show a frame.

    Sets ``base_conf.video_name`` to the session id, runs a ``Tracker`` built
    from the module-level config, FLAME model and renderer, then picks the
    first image from ``$PIXEL3DMM_TRACKING_OUTPUT/<session_id>/frames``.

    Args:
        state: Per-session dict; ``session_id`` is read.

    Returns:
        ``(status, tracking_image, state, track_btn_update)``.  The Track
        button is disabled on success and when the session id is missing.
        It stays enabled when tracking leaves no frame, so the step can be
        retried.
    """
    session_id = state.get("session_id")
    if not session_id:
        # Same guard as steps 2 and 3: without it the tracker would run for a
        # video named "None" and os.path.join below would raise TypeError.
        return (
            "❌ State lost. Please start from Step 1.",
            None,
            state,
            gr.update(interactive=False),
        )

    # NOTE(review): this mutates the shared module-level config; confirm that
    # tracking requests are never processed concurrently.
    base_conf.video_name = f'{session_id}'
    tracker = Tracker(base_conf, flame_model, diff_renderer)
    tracker.run()

    tracking_dir = os.path.join(os.environ["PIXEL3DMM_TRACKING_OUTPUT"], session_id, "frames")
    image = first_image_from_dir(tracking_dir)
    if image is None:
        return (
            "❌ Tracking finished but produced no output frame.",
            None,
            state,
            gr.update(interactive=True),
        )

    return "✅ Pipeline complete!", image, state, gr.update(interactive=False)
178
+
179
+ # Build Gradio UI
180
with gr.Blocks() as demo:
    gr.Markdown("## Image Processing Pipeline")
    gr.Markdown("Upload an image, then click the buttons in order. Uploading a new image will reset the process.")

    with gr.Row():
        # Left column: upload, status line and per-session state.
        with gr.Column():
            image_in = gr.Image(label="Upload Image", type="numpy", height=512)
            status = gr.Textbox(label="Status", lines=2, interactive=False, value="Upload an image to start.")
            state = gr.State({})
        # Right column: a 2x2 grid with one preview per pipeline step.
        with gr.Column():
            with gr.Row():
                crop_img = gr.Image(label="Preprocessed", height=256)
                normals_img = gr.Image(label="Normals", height=256)
            with gr.Row():
                uv_img = gr.Image(label="UV Map", height=256)
                track_img = gr.Image(label="Tracking", height=256)

    # One button per step; only the first is clickable initially.
    with gr.Row():
        preprocess_btn = gr.Button("Step 1: Preprocess", interactive=True)
        normals_btn = gr.Button("Step 2: Normals", interactive=False)
        uv_map_btn = gr.Button("Step 3: UV Map", interactive=False)
        track_btn = gr.Button("Step 4: Track", interactive=False)

    # Components rewritten by reset_all, in the order it returns its values.
    outputs_for_reset = [
        crop_img,
        normals_img,
        uv_img,
        track_img,
        status,
        state,
        preprocess_btn,
        normals_btn,
        uv_map_btn,
        track_btn,
    ]

    # Each step updates the status, its own preview, the state, its own
    # button and (except for the last step) the next step's button.
    preprocess_btn.click(
        fn=preprocess_image,
        inputs=[image_in, state],
        outputs=[status, crop_img, state, preprocess_btn, normals_btn],
    )
    normals_btn.click(
        fn=step2_normals,
        inputs=[state],
        outputs=[status, normals_img, state, normals_btn, uv_map_btn],
    )
    uv_map_btn.click(
        fn=step3_uv_map,
        inputs=[state],
        outputs=[status, uv_img, state, uv_map_btn, track_btn],
    )
    track_btn.click(
        fn=step4_track,
        inputs=[state],
        outputs=[status, track_img, state, track_btn],
    )

    # Uploading a new image resets every preview, the state and the buttons.
    image_in.upload(fn=reset_all, inputs=None, outputs=outputs_for_reset)

# ------------------------------------------------------------------
# START THE GRADIO SERVER
# ------------------------------------------------------------------
demo.queue()
demo.launch(share=True, ssr_mode=False)