guangkaixu commited on
Commit
2dc1d37
1 Parent(s): e2c3a00

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +117 -6
app.py CHANGED
@@ -60,6 +60,7 @@ def process_image_check(path_input):
60
  def process_image(
61
  pipe,
62
  path_input,
 
63
  processing_res=default_image_processing_res,
64
  ):
65
  name_base, name_ext = os.path.splitext(os.path.basename(path_input))
@@ -67,7 +68,6 @@ def process_image(
67
 
68
  path_output_dir = tempfile.mkdtemp()
69
  path_out_fp32 = os.path.join(path_output_dir, f"{name_base}_depth_fp32.npy")
70
- path_out_16bit = os.path.join(path_output_dir, f"{name_base}_depth_16bit.png")
71
  path_out_vis = os.path.join(path_output_dir, f"{name_base}_depth_colored.png")
72
 
73
  input_image = Image.open(path_input)
@@ -81,11 +81,17 @@ def process_image(
81
 
82
  depth_pred = pipe_out.pred_np
83
  depth_colored = pipe_out.pred_colored
84
- depth_16bit = (depth_pred * 65535.0).astype(np.uint16)
85
 
86
  np.save(path_out_fp32, depth_pred)
87
- Image.fromarray(depth_16bit).save(path_out_16bit, mode="I;16")
88
  depth_colored.save(path_out_vis)
 
 
 
 
 
 
 
89
 
90
  return (
91
  [path_out_16bit, path_out_vis],
@@ -93,7 +99,9 @@ def process_image(
93
  )
94
 
95
  def run_demo_server(pipe):
96
- process_pipe_image = spaces.GPU(functools.partial(process_image, pipe))
 
 
97
  gradio_theme = gr.themes.Default()
98
 
99
  with gr.Blocks(
@@ -164,7 +172,7 @@ def run_demo_server(pipe):
164
  )
165
 
166
  with gr.Tabs(elem_classes=["tabs"]):
167
- with gr.Tab("Depth Estimation"):
168
  with gr.Row():
169
  with gr.Column():
170
  image_input = gr.Image(
@@ -206,7 +214,7 @@ def run_demo_server(pipe):
206
  filenames.extend(["line_%d.jpg" %(i+1) for i in range(6)])
207
  filenames.extend(["real_%d.jpg" %(i+1) for i in range(24)])
208
  Examples(
209
- fn=process_pipe_image,
210
  examples=[
211
  os.path.join("images", "depth", name)
212
  for name in filenames
@@ -216,6 +224,109 @@ def run_demo_server(pipe):
216
  cache_examples=True,
217
  directory_name="examples_image",
218
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
 
220
  ### Image tab
221
  image_submit_btn.click(
 
60
  def process_image(
61
  pipe,
62
  path_input,
63
+ mode='depth',
64
  processing_res=default_image_processing_res,
65
  ):
66
  name_base, name_ext = os.path.splitext(os.path.basename(path_input))
 
68
 
69
  path_output_dir = tempfile.mkdtemp()
70
  path_out_fp32 = os.path.join(path_output_dir, f"{name_base}_depth_fp32.npy")
 
71
  path_out_vis = os.path.join(path_output_dir, f"{name_base}_depth_colored.png")
72
 
73
  input_image = Image.open(path_input)
 
81
 
82
  depth_pred = pipe_out.pred_np
83
  depth_colored = pipe_out.pred_colored
 
84
 
85
  np.save(path_out_fp32, depth_pred)
86
+
87
  depth_colored.save(path_out_vis)
88
+
89
+ if mode == 'depth':
90
+ path_out_16bit = os.path.join(path_output_dir, f"{name_base}_depth_16bit.png")
91
+ depth_16bit = (depth_pred * 65535.0).astype(np.uint16)
92
+ Image.fromarray(depth_16bit).save(path_out_16bit, mode="I;16")
93
+ else:
94
+ path_out_16bit = None
95
 
96
  return (
97
  [path_out_16bit, path_out_vis],
 
99
  )
100
 
101
  def run_demo_server(pipe):
102
+ process_pipe_depth = spaces.GPU(functools.partial(process_image, pipe, mode='depth'))
103
+ process_pipe_normal = spaces.GPU(functools.partial(process_image, pipe, mode='normal'))
104
+ process_pipe_dis = spaces.GPU(functools.partial(process_image, pipe, mode='dis'))
105
  gradio_theme = gr.themes.Default()
106
 
107
  with gr.Blocks(
 
172
  )
173
 
174
  with gr.Tabs(elem_classes=["tabs"]):
175
+ with gr.Tab("Depth"):
176
  with gr.Row():
177
  with gr.Column():
178
  image_input = gr.Image(
 
214
  filenames.extend(["line_%d.jpg" %(i+1) for i in range(6)])
215
  filenames.extend(["real_%d.jpg" %(i+1) for i in range(24)])
216
  Examples(
217
+ fn=process_pipe_depth,
218
  examples=[
219
  os.path.join("images", "depth", name)
220
  for name in filenames
 
224
  cache_examples=True,
225
  directory_name="examples_image",
226
  )
227
+
228
+ with gr.Tab("Normal"):
229
+ with gr.Row():
230
+ with gr.Column():
231
+ image_input = gr.Image(
232
+ label="Input Image",
233
+ type="filepath",
234
+ )
235
+ with gr.Row():
236
+ image_submit_btn = gr.Button(
237
+ value="Estimate Normal", variant="primary"
238
+ )
239
+ image_reset_btn = gr.Button(value="Reset")
240
+ with gr.Accordion("Advanced options", open=False):
241
+ image_processing_res = gr.Radio(
242
+ [
243
+ ("Native", 0),
244
+ ("Recommended", 768),
245
+ ],
246
+ label="Processing resolution",
247
+ value=default_image_processing_res,
248
+ )
249
+ with gr.Column():
250
+ image_output_slider = ImageSlider(
251
+ label="Predicted surface normal",
252
+ type="filepath",
253
+ show_download_button=True,
254
+ show_share_button=True,
255
+ interactive=False,
256
+ elem_classes="slider",
257
+ position=0.25,
258
+ )
259
+ image_output_files = gr.Files(
260
+ label="Normal outputs",
261
+ elem_id="download",
262
+ interactive=False,
263
+ )
264
+
265
+ filenames = []
266
+ filenames.extend(["%d.jpg" %(i+1) for i in range(10)])
267
+ Examples(
268
+ fn=process_pipe_normal,
269
+ examples=[
270
+ os.path.join("images", "normal", name)
271
+ for name in filenames
272
+ ],
273
+ inputs=[image_input],
274
+ outputs=[image_output_slider, image_output_files],
275
+ cache_examples=True,
276
+ directory_name="examples_image",
277
+ )
278
+
279
+ with gr.Tab("Dichotomous Segmentation"):
280
+ with gr.Row():
281
+ with gr.Column():
282
+ image_input = gr.Image(
283
+ label="Input Image",
284
+ type="filepath",
285
+ )
286
+ with gr.Row():
287
+ image_submit_btn = gr.Button(
288
+ value="Estimate Segmentation.", variant="primary"
289
+ )
290
+ image_reset_btn = gr.Button(value="Reset")
291
+ with gr.Accordion("Advanced options", open=False):
292
+ image_processing_res = gr.Radio(
293
+ [
294
+ ("Native", 0),
295
+ ("Recommended", 768),
296
+ ],
297
+ label="Processing resolution",
298
+ value=default_image_processing_res,
299
+ )
300
+ with gr.Column():
301
+ image_output_slider = ImageSlider(
302
+ label="Predicted dichotomous image segmentation",
303
+ type="filepath",
304
+ show_download_button=True,
305
+ show_share_button=True,
306
+ interactive=False,
307
+ elem_classes="slider",
308
+ position=0.25,
309
+ )
310
+ image_output_files = gr.Files(
311
+ label="DIS outputs",
312
+ elem_id="download",
313
+ interactive=False,
314
+ )
315
+
316
+ filenames = []
317
+ filenames.extend(["%d.jpg" %(i+1) for i in range(10)])
318
+ Examples(
319
+ fn=process_pipe_dis,
320
+ examples=[
321
+ os.path.join("images", "dis", name)
322
+ for name in filenames
323
+ ],
324
+ inputs=[image_input],
325
+ outputs=[image_output_slider, image_output_files],
326
+ cache_examples=True,
327
+ directory_name="examples_image",
328
+ )
329
+
330
 
331
  ### Image tab
332
  image_submit_btn.click(