guangkaixu committed on
Commit
fc8fe55
1 Parent(s): 22160e6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +132 -2
app.py CHANGED
@@ -162,10 +162,45 @@ def process_dis(
162
  [path_out_fp32, path_out_vis],
163
  )
164
 
165
- def run_demo_server(pipe_depth, pipe_normal, pipe_dis):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
  process_pipe_depth = spaces.GPU(functools.partial(process_depth, pipe_depth))
167
  process_pipe_normal = spaces.GPU(functools.partial(process_normal, pipe_normal))
168
  process_pipe_dis = spaces.GPU(functools.partial(process_dis, pipe_dis))
 
169
  gradio_theme = gr.themes.Default()
170
 
171
  with gr.Blocks(
@@ -409,6 +444,63 @@ def run_demo_server(pipe_depth, pipe_normal, pipe_dis):
409
  cache_examples=False,
410
  )
411
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
412
 
413
  ### Image tab
414
  depth_image_submit_btn.click(
@@ -510,6 +602,38 @@ def run_demo_server(pipe_depth, pipe_normal, pipe_dis):
510
  queue=False,
511
  )
512
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
513
  ### Server launch
514
 
515
  demo.queue(
@@ -534,6 +658,7 @@ def main():
534
  use_safetensors=True).to(dtype)
535
  unet_normal_v1 = UNet2DConditionModel.from_pretrained('guangkaixu/GenPercept', subfolder="unet_normal_v1", use_safetensors=True).to(dtype)
536
  unet_dis_v1 = UNet2DConditionModel.from_pretrained('guangkaixu/GenPercept', subfolder="unet_dis_v1", use_safetensors=True).to(dtype)
 
537
 
538
  empty_text_embed = torch.from_numpy(np.load("./empty_text_embed.npy")).to(device, dtype)[None] # [1, 77, 1024]
539
 
@@ -546,19 +671,24 @@ def main():
546
  pipe_dis = GenPerceptPipeline(vae=vae,
547
  unet=unet_dis_v1,
548
  empty_text_embed=empty_text_embed)
 
 
 
549
  try:
550
  import xformers
551
  pipe_depth.enable_xformers_memory_efficient_attention()
552
  pipe_normal.enable_xformers_memory_efficient_attention()
553
  pipe_dis.enable_xformers_memory_efficient_attention()
 
554
  except:
555
  pass # run without xformers
556
 
557
  pipe_depth = pipe_depth.to(device)
558
  pipe_normal = pipe_normal.to(device)
559
  pipe_dis = pipe_dis.to(device)
 
560
 
561
- run_demo_server(pipe_depth, pipe_normal, pipe_dis)
562
 
563
 
564
  if __name__ == "__main__":
 
162
  [path_out_fp32, path_out_vis],
163
  )
164
 
165
+ def process_matting(
166
+ pipe,
167
+ path_input,
168
+ processing_res=default_image_processing_res,
169
+ ):
170
+ name_base, name_ext = os.path.splitext(os.path.basename(path_input))
171
+ print(f"Processing image {name_base}{name_ext}")
172
+
173
+ path_output_dir = tempfile.mkdtemp()
174
+ path_out_fp32 = os.path.join(path_output_dir, f"{name_base}_matting_fp32.npy")
175
+ path_out_vis = os.path.join(path_output_dir, f"{name_base}_matting_colored.png")
176
+
177
+ input_image = Image.open(path_input)
178
+
179
+ pipe_out = pipe(
180
+ input_image,
181
+ processing_res=processing_res,
182
+ batch_size=1 if processing_res == 0 else 0,
183
+ show_progress_bar=False,
184
+ mode='seg',
185
+ )
186
+
187
+ depth_pred = pipe_out.pred_np
188
+ depth_colored = pipe_out.pred_colored
189
+
190
+ np.save(path_out_fp32, depth_pred)
191
+ depth_colored.save(path_out_vis)
192
+
193
+ return (
194
+ [path_out_vis],
195
+ [path_out_fp32, path_out_vis],
196
+ )
197
+
198
+
199
+ def run_demo_server(pipe_depth, pipe_normal, pipe_dis, pipe_matting):
200
  process_pipe_depth = spaces.GPU(functools.partial(process_depth, pipe_depth))
201
  process_pipe_normal = spaces.GPU(functools.partial(process_normal, pipe_normal))
202
  process_pipe_dis = spaces.GPU(functools.partial(process_dis, pipe_dis))
203
+ process_pipe_matting = spaces.GPU(functools.partial(process_matting, pipe_matting))
204
  gradio_theme = gr.themes.Default()
205
 
206
  with gr.Blocks(
 
444
  cache_examples=False,
445
  )
446
 
447
+ with gr.Tab("Matting"):
448
+ with gr.Row():
449
+ with gr.Column():
450
+ dis_image_input = gr.Image(
451
+ label="Input Image",
452
+ type="filepath",
453
+ )
454
+ with gr.Row():
455
+ matting_image_submit_btn = gr.Button(
456
+ value="Estimate Matting", variant="primary"
457
+ )
458
+ matting_image_reset_btn = gr.Button(value="Reset")
459
+ with gr.Accordion("Advanced options", open=False):
460
+ image_processing_res = gr.Radio(
461
+ [
462
+ ("Native", 0),
463
+ ("Recommended", 768),
464
+ ],
465
+ label="Processing resolution",
466
+ value=default_image_processing_res,
467
+ )
468
+ with gr.Column():
469
+ # dis_image_output_slider = ImageSlider(
470
+ # label="Predicted dichotomous image segmentation",
471
+ # type="filepath",
472
+ # show_download_button=True,
473
+ # show_share_button=True,
474
+ # interactive=False,
475
+ # elem_classes="slider",
476
+ # position=0.25,
477
+ # )
478
+ matting_image_output = gr.Gallery(label='Output', show_label=False, elem_id="gallery", columns=[1], height='auto')
479
+ matting_image_output_files = gr.Files(
480
+ label="Matting outputs",
481
+ elem_id="download",
482
+ interactive=False,
483
+ )
484
+
485
+ filenames = []
486
+ filenames.extend(["matting_%d.jpg" %(i+1) for i in range(10)])
487
+ # example_folder = "images"
488
+ # print('line 396', __file__)
489
+ example_folder = os.path.join(os.path.dirname(__file__), "matting_images")
490
+ # print(example_folder)
491
+ Examples(
492
+ fn=process_pipe_dis,
493
+ examples=[
494
+ os.path.join(example_folder, name)
495
+ for name in filenames
496
+ ],
497
+ inputs=[dis_image_input],
498
+ outputs=[dis_image_output, dis_image_output_files],
499
+ # cache_examples=True,
500
+ directory_name="images_cache",
501
+ cache_examples=False,
502
+ )
503
+
504
 
505
  ### Image tab
506
  depth_image_submit_btn.click(
 
602
  queue=False,
603
  )
604
 
605
+ matting_image_submit_btn.click(
606
+ fn=process_image_check,
607
+ inputs=matting_image_input,
608
+ outputs=None,
609
+ preprocess=False,
610
+ queue=False,
611
+ ).success(
612
+ fn=process_pipe_dis,
613
+ inputs=[
614
+ matting_image_input,
615
+ image_processing_res,
616
+ ],
617
+ outputs=[matting_image_output, matting_image_output_files],
618
+ concurrency_limit=1,
619
+ )
620
+
621
+ matting_image_reset_btn.click(
622
+ fn=lambda: (
623
+ None,
624
+ None,
625
+ None,
626
+ default_image_processing_res,
627
+ ),
628
+ inputs=[],
629
+ outputs=[
630
+ matting_image_input,
631
+ matting_image_output,
632
+ matting_image_output_files,
633
+ image_processing_res,
634
+ ],
635
+ queue=False,
636
+ )
637
  ### Server launch
638
 
639
  demo.queue(
 
658
  use_safetensors=True).to(dtype)
659
  unet_normal_v1 = UNet2DConditionModel.from_pretrained('guangkaixu/GenPercept', subfolder="unet_normal_v1", use_safetensors=True).to(dtype)
660
  unet_dis_v1 = UNet2DConditionModel.from_pretrained('guangkaixu/GenPercept', subfolder="unet_dis_v1", use_safetensors=True).to(dtype)
661
+ unet_matting_v1 = UNet2DConditionModel.from_pretrained('guangkaixu/genpercept-matting', subfolder="unet", use_safetensors=True).to(dtype)
662
 
663
  empty_text_embed = torch.from_numpy(np.load("./empty_text_embed.npy")).to(device, dtype)[None] # [1, 77, 1024]
664
 
 
671
  pipe_dis = GenPerceptPipeline(vae=vae,
672
  unet=unet_dis_v1,
673
  empty_text_embed=empty_text_embed)
674
+ pipe_matting = GenPerceptPipeline(vae=vae,
675
+ unet=unet_matting_v1,
676
+ empty_text_embed=empty_text_embed)
677
  try:
678
  import xformers
679
  pipe_depth.enable_xformers_memory_efficient_attention()
680
  pipe_normal.enable_xformers_memory_efficient_attention()
681
  pipe_dis.enable_xformers_memory_efficient_attention()
682
+ pipe_matting.enable_xformers_memory_efficient_attention()
683
  except:
684
  pass # run without xformers
685
 
686
  pipe_depth = pipe_depth.to(device)
687
  pipe_normal = pipe_normal.to(device)
688
  pipe_dis = pipe_dis.to(device)
689
+ pipe_matting = pipe_matting.to(device)
690
 
691
+ run_demo_server(pipe_depth, pipe_normal, pipe_dis, pipe_matting)
692
 
693
 
694
  if __name__ == "__main__":