alexnasa commited on
Commit
f4d3cad
·
verified ·
1 Parent(s): 47be277

Create app_photo.py

Browse files
Files changed (1) hide show
  1. app_photo.py +192 -0
app_photo.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import glob
import os
import shutil
import subprocess
import sys
import uuid

import gradio as gr
import numpy as np
import spaces  # keep ahead of torch: ZeroGPU expects `spaces` to be imported before CUDA libraries
import torch
import trimesh
from omegaconf import OmegaConf
from PIL import Image
from pytorch3d.io import load_obj

from pixel3dmm import env_paths
from pixel3dmm.tracking.flame.FLAME import FLAME
from pixel3dmm.tracking.renderer_nvdiffrast import NVDRenderer
from pixel3dmm.tracking.tracker import Tracker
11
+
12
+
13
# Device that the FLAME model and the differentiable renderer are moved to.
DEVICE = "cuda"

# Tracking configuration loaded from the code base's configs folder.
base_conf = OmegaConf.load(f'{env_paths.CODE_BASE}/configs/tracking.yaml')

# Template head mesh; read below by load_obj and handed to the renderer.
_mesh_file = env_paths.head_template

# Shared FLAME model instance, built once at import time and reused by step 4.
flame_model = FLAME(base_conf).to(DEVICE)

# Second element of load_obj's result (presumably the face data - confirm
# against pytorch3d). Not referenced anywhere else in this file.
_obj_faces = load_obj(_mesh_file)[1]

# Shared renderer instance, built once at import time and reused by step 4.
diff_renderer = NVDRenderer(
    image_size=base_conf.size,
    obj_filename=_mesh_file,
    no_sh=False,
    white_bg=True,
).to(DEVICE)
28
+
29
+
30
+ # Utility to select first image from a folder
31
def first_image_from_dir(directory):
    """Return the lexicographically first image file in *directory*.

    Only files directly inside *directory* whose names end in ``.jpg``,
    ``.png`` or ``.jpeg`` are considered.

    Args:
        directory: Path of the folder to search.

    Returns:
        The path of the first matching file in sorted order, or ``None`` when
        the folder is missing or contains no matching file.
    """
    # Escape the directory part so glob metacharacters in the path
    # ("[", "]", "*", "?") are matched literally instead of as a pattern.
    safe_dir = glob.escape(directory)
    matches = []
    for pattern in ("*.jpg", "*.png", "*.jpeg"):
        matches.extend(glob.glob(os.path.join(safe_dir, pattern)))
    # min() yields the same element as sorted(matches)[0] without a full sort.
    return min(matches, default=None)
39
+
40
+ # Function to reset the UI and state
41
def reset_all():
    """Return the values that put the UI and session state back to the start.

    The tuple has ten entries, in this order: the four image panes
    (crop, normals, UV, tracking), the status text, a fresh state dict, and
    updates for the four step buttons (only the preprocess button enabled).
    """
    cleared_images = (None,) * 4  # crop_img, normals_img, uv_img, track_img
    status_text = "Awaiting new image upload..."
    fresh_state = {}
    button_updates = (
        gr.update(interactive=True),   # preprocess_btn
        gr.update(interactive=False),  # normals_btn
        gr.update(interactive=False),  # uv_map_btn
        gr.update(interactive=False),  # track_btn
    )
    return (*cleared_images, status_text, fresh_state, *button_updates)
54
+
55
+ # Step 1: Preprocess the input image (Save and Crop)
56
+ # @spaces.GPU()
57
def preprocess_image(image_array, state):
    """Step 1: save the uploaded image and run the preprocessing script on it.

    Args:
        image_array: The uploaded image as a numpy array (the input component
            is created with ``type="numpy"``), or ``None`` if nothing was
            uploaded.
        state: Per-session dict. On the way through it receives
            ``session_id``, ``base_dir`` and ``image_path``.

    Returns:
        A 5-tuple ``(status, cropped_image_path, state, preprocess_btn_update,
        normals_btn_update)``. On failure the image is ``None`` and the
        preprocess button stays enabled so the user can retry.

    Raises:
        KeyError: If the ``PIXEL3DMM_PREPROCESSED_DATA`` environment variable
            is not set.
    """
    if image_array is None:
        return (
            "❌ Please upload an image first.",
            None,
            state,
            gr.update(interactive=True),
            gr.update(interactive=False),
        )

    # Each run gets its own folder, named after a fresh UUID.
    session_id = str(uuid.uuid4())
    base_dir = os.path.join(os.environ["PIXEL3DMM_PREPROCESSED_DATA"], session_id)
    os.makedirs(base_dir, exist_ok=True)
    state.update({"session_id": session_id, "base_dir": base_dir})

    saved_image_path = os.path.join(base_dir, f"{session_id}.png")
    Image.fromarray(image_array).save(saved_image_path)
    state["image_path"] = saved_image_path

    try:
        # sys.executable guarantees the script runs under the same interpreter
        # (and therefore the same installed packages) as this app.
        subprocess.run(
            [
                sys.executable,
                "scripts/run_preprocessing.py",
                "--video_or_images_path",
                saved_image_path,
            ],
            check=True,
            capture_output=True,
            text=True,
        )
    except subprocess.CalledProcessError as e:
        err = f"❌ Preprocess failed (exit {e.returncode}).\n\n{e.stdout}\n{e.stderr}"
        # Best-effort cleanup: a failing rmtree must not hide the real error.
        shutil.rmtree(base_dir, ignore_errors=True)
        return err, None, {}, gr.update(interactive=True), gr.update(interactive=False)

    crop_dir = os.path.join(base_dir, "cropped")
    image = first_image_from_dir(crop_dir)
    if image is None:
        # The script exited cleanly but left nothing to show; do not claim success.
        return (
            f"❌ Preprocess produced no cropped image in {crop_dir}.",
            None,
            state,
            gr.update(interactive=True),
            gr.update(interactive=False),
        )
    return (
        "✅ Step 1 complete. Ready for Normals.",
        image,
        state,
        gr.update(interactive=False),
        gr.update(interactive=True),
    )
83
+
84
+ # Step 2: Normals inference → normals image
85
@spaces.GPU()
def step2_normals(state):
    """Step 2: run normals inference for the current session.

    Args:
        state: Per-session dict; must contain ``session_id`` and ``base_dir``
            as written by step 1.

    Returns:
        A 4-tuple plus state: ``(status, normals_image_path, state,
        normals_btn_update, uv_map_btn_update)``. On failure the image is
        ``None`` and the UV-map button stays disabled.
    """
    session_id = state.get("session_id")
    if not session_id:
        return (
            "❌ State lost. Please start from Step 1.",
            None,
            state,
            gr.update(interactive=False),
            gr.update(interactive=False),
        )

    try:
        # sys.executable guarantees the script runs under the same interpreter
        # (and therefore the same installed packages) as this app.
        subprocess.run(
            [
                sys.executable,
                "scripts/network_inference.py",
                "model.prediction_type=normals",
                f"video_name={session_id}",
            ],
            check=True,
            capture_output=True,
            text=True,
        )
    except subprocess.CalledProcessError as e:
        err = f"❌ Normal map failed (exit {e.returncode}).\n\n{e.stdout}\n{e.stderr}"
        return err, None, state, gr.update(interactive=True), gr.update(interactive=False)

    normals_dir = os.path.join(state["base_dir"], "p3dmm", "normals")
    image = first_image_from_dir(normals_dir)
    if image is None:
        # The script exited cleanly but left nothing to show; do not claim success.
        return (
            f"❌ Normal map step produced no image in {normals_dir}.",
            None,
            state,
            gr.update(interactive=True),
            gr.update(interactive=False),
        )
    return (
        "✅ Step 2 complete. Ready for UV Map.",
        image,
        state,
        gr.update(interactive=False),
        gr.update(interactive=True),
    )
102
+
103
+ # Step 3: UV map inference → uv map image
104
@spaces.GPU()
def step3_uv_map(state):
    """Step 3: run UV-map inference for the current session.

    Args:
        state: Per-session dict; must contain ``session_id`` and ``base_dir``
            as written by step 1.

    Returns:
        ``(status, uv_image_path, state, uv_map_btn_update,
        track_btn_update)``. On failure the image is ``None`` and the tracking
        button stays disabled.
    """
    session_id = state.get("session_id")
    if not session_id:
        return (
            "❌ State lost. Please start from Step 1.",
            None,
            state,
            gr.update(interactive=False),
            gr.update(interactive=False),
        )

    try:
        # sys.executable guarantees the script runs under the same interpreter
        # (and therefore the same installed packages) as this app.
        subprocess.run(
            [
                sys.executable,
                "scripts/network_inference.py",
                "model.prediction_type=uv_map",
                f"video_name={session_id}",
            ],
            check=True,
            capture_output=True,
            text=True,
        )
    except subprocess.CalledProcessError as e:
        err = f"❌ UV map failed (exit {e.returncode}).\n\n{e.stdout}\n{e.stderr}"
        return err, None, state, gr.update(interactive=True), gr.update(interactive=False)

    uv_dir = os.path.join(state["base_dir"], "p3dmm", "uv_map")
    image = first_image_from_dir(uv_dir)
    if image is None:
        # The script exited cleanly but left nothing to show; do not claim success.
        return (
            f"❌ UV map step produced no image in {uv_dir}.",
            None,
            state,
            gr.update(interactive=True),
            gr.update(interactive=False),
        )
    return (
        "✅ Step 3 complete. Ready for Tracking.",
        image,
        state,
        gr.update(interactive=False),
        gr.update(interactive=True),
    )
121
+
122
+ # Step 4: Tracking → final tracking image
123
@spaces.GPU()
def step4_track(state):
    """Step 4: run the tracker for the current session and show its first frame.

    Args:
        state: Per-session dict; must contain ``session_id`` as written by
            step 1.

    Returns:
        ``(status, tracking_image_path, state, track_btn_update)``. On failure
        the image is ``None``.

    Raises:
        KeyError: If the ``PIXEL3DMM_TRACKING_OUTPUT`` environment variable is
            not set.
    """
    session_id = state.get("session_id")
    if not session_id:
        # Same guard as steps 2 and 3: without it a missing id would reach
        # os.path.join as None and the config would get video_name='None'.
        return (
            "❌ State lost. Please start from Step 1.",
            None,
            state,
            gr.update(interactive=False),
        )

    # NOTE(review): this writes to the module-level config object, which is
    # shared by every session - confirm that requests never run concurrently.
    base_conf.video_name = f'{session_id}'
    tracker = Tracker(base_conf, flame_model, diff_renderer)
    tracker.run()

    tracking_dir = os.path.join(os.environ["PIXEL3DMM_TRACKING_OUTPUT"], session_id, "frames")
    image = first_image_from_dir(tracking_dir)
    if image is None:
        # The tracker returned but left nothing to show; do not claim success.
        return (
            f"❌ Tracking produced no frame in {tracking_dir}.",
            None,
            state,
            gr.update(interactive=True),
        )

    return "✅ Pipeline complete!", image, state, gr.update(interactive=False)
134
+
135
+ # Build Gradio UI
136
with gr.Blocks() as demo:
    gr.Markdown("## Image Processing Pipeline")
    gr.Markdown("Upload an image, then click the buttons in order. Uploading a new image will reset the process.")

    with gr.Row():
        # Left column: upload widget, status line and the per-session state.
        with gr.Column():
            image_in = gr.Image(label="Upload Image", type="numpy", height=512)
            status = gr.Textbox(label="Status", lines=2, interactive=False, value="Upload an image to start.")
            state = gr.State({})
        # Right column: a 2x2 grid with one result pane per pipeline step.
        with gr.Column():
            with gr.Row():
                crop_img = gr.Image(label="Preprocessed", height=256)
                normals_img = gr.Image(label="Normals", height=256)
            with gr.Row():
                uv_img = gr.Image(label="UV Map", height=256)
                track_img = gr.Image(label="Tracking", height=256)

    # Only the first step starts enabled; each handler unlocks the next one.
    with gr.Row():
        preprocess_btn = gr.Button("Step 1: Preprocess", interactive=True)
        normals_btn = gr.Button("Step 2: Normals", interactive=False)
        uv_map_btn = gr.Button("Step 3: UV Map", interactive=False)
        track_btn = gr.Button("Step 4: Track", interactive=False)

    # Components rewritten by reset_all, in the order of its return tuple.
    outputs_for_reset = [
        crop_img, normals_img, uv_img, track_img,
        status, state,
        preprocess_btn, normals_btn, uv_map_btn, track_btn,
    ]

    # Pipeline wiring: (button, handler, inputs, outputs) for each step.
    _pipeline_steps = (
        (preprocess_btn, preprocess_image, [image_in, state],
         [status, crop_img, state, preprocess_btn, normals_btn]),
        (normals_btn, step2_normals, [state],
         [status, normals_img, state, normals_btn, uv_map_btn]),
        (uv_map_btn, step3_uv_map, [state],
         [status, uv_img, state, uv_map_btn, track_btn]),
        (track_btn, step4_track, [state],
         [status, track_img, state, track_btn]),
    )
    for _button, _handler, _inputs, _outputs in _pipeline_steps:
        _button.click(fn=_handler, inputs=_inputs, outputs=_outputs)

    # Uploading a new image resets every pane, the state and the buttons.
    image_in.upload(fn=reset_all, inputs=None, outputs=outputs_for_reset)

# ------------------------------------------------------------------
# START THE GRADIO SERVER
# ------------------------------------------------------------------
demo.queue()
demo.launch(share=True, ssr_mode=False)