alexnasa commited on
Commit
d568786
Β·
verified Β·
1 Parent(s): e2cec11

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -42
app.py CHANGED
@@ -44,7 +44,6 @@ def install_cuda_toolkit():
44
 
45
  install_cuda_toolkit()
46
 
47
-
48
  # Utility to select first image from a folder
49
  def first_image_from_dir(directory):
50
  patterns = ["*.jpg", "*.png", "*.jpeg"]
@@ -55,14 +54,27 @@ def first_image_from_dir(directory):
55
  return None
56
  return sorted(files)[0]
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  # Step 1: Preprocess the input image (Save and Crop)
59
  @spaces.GPU()
60
  def preprocess_image(image_array, state):
61
- # Check if an image was uploaded
62
  if image_array is None:
63
- return "❌ Please upload an image first.", None, state
64
 
65
- # Step 1a: Save the uploaded image
66
  session_id = str(uuid.uuid4())
67
  base_dir = os.path.join(os.environ["PIXEL3DMM_PREPROCESSED_DATA"], session_id)
68
  os.makedirs(base_dir, exist_ok=True)
@@ -73,96 +85,87 @@ def preprocess_image(image_array, state):
73
  img.save(saved_image_path)
74
  state["image_path"] = saved_image_path
75
 
76
- # Step 1b: Run the preprocessing script
77
  try:
78
  p = subprocess.run([
79
- "python", "scripts/run_preprocessing.py",
80
- "--video_or_images_path", saved_image_path
81
  ], check=True, capture_output=True, text=True)
82
  except subprocess.CalledProcessError as e:
83
  err = f"❌ Preprocess failed (exit {e.returncode}).\n\n{e.stdout}\n{e.stderr}"
84
- # Clean up created directory on failure
85
  shutil.rmtree(base_dir)
86
- return err, None, state
87
 
88
  crop_dir = os.path.join(base_dir, "cropped")
89
  image = first_image_from_dir(crop_dir)
90
- return "βœ… Preprocessing complete", image, state
91
-
92
 
93
  # Step 2: Normals inference β†’ normals image
94
  @spaces.GPU()
95
  def step2_normals(state):
96
  session_id = state.get("session_id")
97
  if not session_id:
98
- return "❌ Please preprocess an image first.", None, state
99
 
100
  try:
101
- # Execute the network inference for normals
102
  p = subprocess.run([
103
- "python", "scripts/network_inference.py",
104
- "model.prediction_type=normals", f"video_name={session_id}"
105
  ], check=True, capture_output=True, text=True)
106
  except subprocess.CalledProcessError as e:
107
  err = f"❌ Normal map failed (exit {e.returncode}).\n\n{e.stdout}\n{e.stderr}"
108
- return err, None, state
109
 
110
  normals_dir = os.path.join(state["base_dir"], "p3dmm", "normals")
111
  image = first_image_from_dir(normals_dir)
112
- return "βœ… Step 2: Normals inference complete", image, state
113
 
114
  # Step 3: UV map inference β†’ uv map image
115
  @spaces.GPU()
116
  def step3_uv_map(state):
117
  session_id = state.get("session_id")
118
  if not session_id:
119
- return "❌ Please preprocess an image first.", None, state
120
 
121
  try:
122
- # Execute the network inference for UV map
123
  p = subprocess.run([
124
- "python", "scripts/network_inference.py",
125
- "model.prediction_type=uv_map", f"video_name={session_id}"
126
  ], check=True, capture_output=True, text=True)
127
  except subprocess.CalledProcessError as e:
128
  err = f"❌ UV map failed (exit {e.returncode}).\n\n{e.stdout}\n{e.stderr}"
129
- return err, None, state
130
 
131
  uv_dir = os.path.join(state["base_dir"], "p3dmm", "uv_map")
132
  image = first_image_from_dir(uv_dir)
133
- return "βœ… Step 3: UV map inference complete", image, state
134
 
135
  # Step 4: Tracking β†’ final tracking image
136
  @spaces.GPU()
137
  def step4_track(state):
138
  session_id = state.get("session_id")
139
  if not session_id:
140
- return "❌ Please preprocess an image first.", None, state
141
 
142
  script = os.path.join(os.environ["PIXEL3DMM_CODE_BASE"], "scripts", "track.py")
143
  try:
144
- # Execute the tracking script
145
  p = subprocess.run([
146
- "python", script,
147
- f"video_name={session_id}"
148
  ], check=True, capture_output=True, text=True)
149
  except subprocess.CalledProcessError as e:
150
  err = f"❌ Tracking failed (exit {e.returncode}).\n\n{e.stdout}\n{e.stderr}"
151
- return err, None, state
152
 
153
  tracking_dir = os.path.join(os.environ["PIXEL3DMM_TRACKING_OUTPUT"], session_id, "frames")
154
  image = first_image_from_dir(tracking_dir)
155
- return "βœ… Step 4: Tracking complete", image, state
156
 
157
  # Build Gradio UI
158
  demo = gr.Blocks()
159
 
160
  with demo:
161
  gr.Markdown("## Image Processing Pipeline")
 
162
  with gr.Row():
163
  with gr.Column():
164
  image_in = gr.Image(label="Upload Image", type="numpy", height=512)
165
- status = gr.Textbox(label="Status", lines=2, interactive=False)
166
  state = gr.State({})
167
  with gr.Column():
168
  with gr.Row():
@@ -173,21 +176,41 @@ with demo:
173
  track_img = gr.Image(label="Tracking", height=256)
174
 
175
  with gr.Row():
176
- preprocess_btn = gr.Button("Step 1: Preprocess")
177
- normals_btn = gr.Button("Step 2: Normals")
178
- uv_map_btn = gr.Button("Step 3: UV Map")
179
- track_btn = gr.Button("Step 4: Track")
180
-
181
- # Pipeline execution
182
- preprocess_btn.click(fn=preprocess_image, inputs=[image_in, state], outputs=[status, crop_img, state])
183
- normals_btn.click(fn=step2_normals, inputs=[state], outputs=[status, normals_img, state])
184
- uv_map_btn.click(fn=step3_uv_map, inputs=[state], outputs=[status, uv_img, state])
185
- track_btn.click(fn=step4_track, inputs=[state], outputs=[status, track_img, state])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
 
 
 
187
 
188
  # ------------------------------------------------------------------
189
  # START THE GRADIO SERVER
190
  # ------------------------------------------------------------------
191
  demo.queue()
192
- demo.launch(share=True, ssr_mode=False)
193
-
 
44
 
45
  install_cuda_toolkit()
46
 
 
47
  # Utility to select first image from a folder
48
  def first_image_from_dir(directory):
49
  patterns = ["*.jpg", "*.png", "*.jpeg"]
 
54
  return None
55
  return sorted(files)[0]
56
 
57
+ # Function to reset the UI and state
58
+ def reset_all():
59
+ return (
60
+ None, # crop_img
61
+ None, # normals_img
62
+ None, # uv_img
63
+ None, # track_img
64
+ "Awaiting new image upload...", # status
65
+ {}, # state
66
+ gr.update(interactive=True), # preprocess_btn
67
+ gr.update(interactive=False), # normals_btn
68
+ gr.update(interactive=False), # uv_map_btn
69
+ gr.update(interactive=False) # track_btn
70
+ )
71
+
72
  # Step 1: Preprocess the input image (Save and Crop)
73
  @spaces.GPU()
74
  def preprocess_image(image_array, state):
 
75
  if image_array is None:
76
+ return "❌ Please upload an image first.", None, state, gr.update(interactive=True), gr.update(interactive=False)
77
 
 
78
  session_id = str(uuid.uuid4())
79
  base_dir = os.path.join(os.environ["PIXEL3DMM_PREPROCESSED_DATA"], session_id)
80
  os.makedirs(base_dir, exist_ok=True)
 
85
  img.save(saved_image_path)
86
  state["image_path"] = saved_image_path
87
 
 
88
  try:
89
  p = subprocess.run([
90
+ "python", "scripts/run_preprocessing.py", "--video_or_images_path", saved_image_path
 
91
  ], check=True, capture_output=True, text=True)
92
  except subprocess.CalledProcessError as e:
93
  err = f"❌ Preprocess failed (exit {e.returncode}).\n\n{e.stdout}\n{e.stderr}"
 
94
  shutil.rmtree(base_dir)
95
+ return err, None, {}, gr.update(interactive=True), gr.update(interactive=False)
96
 
97
  crop_dir = os.path.join(base_dir, "cropped")
98
  image = first_image_from_dir(crop_dir)
99
+ return "βœ… Step 1 complete. Ready for Normals.", image, state, gr.update(interactive=False), gr.update(interactive=True)
 
100
 
101
  # Step 2: Normals inference β†’ normals image
102
  @spaces.GPU()
103
  def step2_normals(state):
104
  session_id = state.get("session_id")
105
  if not session_id:
106
+ return "❌ State lost. Please start from Step 1.", None, state, gr.update(interactive=False), gr.update(interactive=False)
107
 
108
  try:
 
109
  p = subprocess.run([
110
+ "python", "scripts/network_inference.py", "model.prediction_type=normals", f"video_name={session_id}"
 
111
  ], check=True, capture_output=True, text=True)
112
  except subprocess.CalledProcessError as e:
113
  err = f"❌ Normal map failed (exit {e.returncode}).\n\n{e.stdout}\n{e.stderr}"
114
+ return err, None, state, gr.update(interactive=True), gr.update(interactive=False)
115
 
116
  normals_dir = os.path.join(state["base_dir"], "p3dmm", "normals")
117
  image = first_image_from_dir(normals_dir)
118
+ return "βœ… Step 2 complete. Ready for UV Map.", image, state, gr.update(interactive=False), gr.update(interactive=True)
119
 
120
  # Step 3: UV map inference β†’ uv map image
121
  @spaces.GPU()
122
  def step3_uv_map(state):
123
  session_id = state.get("session_id")
124
  if not session_id:
125
+ return "❌ State lost. Please start from Step 1.", None, state, gr.update(interactive=False), gr.update(interactive=False)
126
 
127
  try:
 
128
  p = subprocess.run([
129
+ "python", "scripts/network_inference.py", "model.prediction_type=uv_map", f"video_name={session_id}"
 
130
  ], check=True, capture_output=True, text=True)
131
  except subprocess.CalledProcessError as e:
132
  err = f"❌ UV map failed (exit {e.returncode}).\n\n{e.stdout}\n{e.stderr}"
133
+ return err, None, state, gr.update(interactive=True), gr.update(interactive=False)
134
 
135
  uv_dir = os.path.join(state["base_dir"], "p3dmm", "uv_map")
136
  image = first_image_from_dir(uv_dir)
137
+ return "βœ… Step 3 complete. Ready for Tracking.", image, state, gr.update(interactive=False), gr.update(interactive=True)
138
 
139
  # Step 4: Tracking β†’ final tracking image
140
  @spaces.GPU()
141
  def step4_track(state):
142
  session_id = state.get("session_id")
143
  if not session_id:
144
+ return "❌ State lost. Please start from Step 1.", None, state, gr.update(interactive=False)
145
 
146
  script = os.path.join(os.environ["PIXEL3DMM_CODE_BASE"], "scripts", "track.py")
147
  try:
 
148
  p = subprocess.run([
149
+ "python", script, f"video_name={session_id}"
 
150
  ], check=True, capture_output=True, text=True)
151
  except subprocess.CalledProcessError as e:
152
  err = f"❌ Tracking failed (exit {e.returncode}).\n\n{e.stdout}\n{e.stderr}"
153
+ return err, None, state, gr.update(interactive=True)
154
 
155
  tracking_dir = os.path.join(os.environ["PIXEL3DMM_TRACKING_OUTPUT"], session_id, "frames")
156
  image = first_image_from_dir(tracking_dir)
157
+ return "βœ… Pipeline complete!", image, state, gr.update(interactive=False)
158
 
159
  # Build Gradio UI
160
  demo = gr.Blocks()
161
 
162
  with demo:
163
  gr.Markdown("## Image Processing Pipeline")
164
+ gr.Markdown("Upload an image, then click the buttons in order. Uploading a new image will reset the process.")
165
  with gr.Row():
166
  with gr.Column():
167
  image_in = gr.Image(label="Upload Image", type="numpy", height=512)
168
+ status = gr.Textbox(label="Status", lines=2, interactive=False, value="Upload an image to start.")
169
  state = gr.State({})
170
  with gr.Column():
171
  with gr.Row():
 
176
  track_img = gr.Image(label="Tracking", height=256)
177
 
178
  with gr.Row():
179
+ preprocess_btn = gr.Button("Step 1: Preprocess", interactive=True)
180
+ normals_btn = gr.Button("Step 2: Normals", interactive=False)
181
+ uv_map_btn = gr.Button("Step 3: UV Map", interactive=False)
182
+ track_btn = gr.Button("Step 4: Track", interactive=False)
183
+
184
+ # Define component list for reset
185
+ outputs_for_reset = [crop_img, normals_img, uv_img, track_img, status, state, preprocess_btn, normals_btn, uv_map_btn, track_btn]
186
+
187
+ # Pipeline execution logic
188
+ preprocess_btn.click(
189
+ fn=preprocess_image,
190
+ inputs=[image_in, state],
191
+ outputs=[status, crop_img, state, preprocess_btn, normals_btn]
192
+ )
193
+ normals_btn.click(
194
+ fn=step2_normals,
195
+ inputs=[state],
196
+ outputs=[status, normals_img, state, normals_btn, uv_map_btn]
197
+ )
198
+ uv_map_btn.click(
199
+ fn=step3_uv_map,
200
+ inputs=[state],
201
+ outputs=[status, uv_img, state, uv_map_btn, track_btn]
202
+ )
203
+ track_btn.click(
204
+ fn=step4_track,
205
+ inputs=[state],
206
+ outputs=[status, track_img, state, track_btn]
207
+ )
208
 
209
+ # Event to reset everything when a new image is uploaded
210
+ image_in.upload(fn=reset_all, inputs=None, outputs=outputs_for_reset)
211
 
212
  # ------------------------------------------------------------------
213
  # START THE GRADIO SERVER
214
  # ------------------------------------------------------------------
215
  demo.queue()
216
+ demo.launch(share=True, ssr_mode=False)