guangkaixu committed on
Commit bf2146f
1 Parent(s): 843312c

Update app.py

Files changed (1)
  1. app.py +23 -12
app.py CHANGED
@@ -77,13 +77,13 @@ def process_image(
         processing_res=processing_res,
         batch_size=1 if processing_res == 0 else 0,
         show_progress_bar=False,
+        mode=mode,
     )
 
     depth_pred = pipe_out.pred_np
     depth_colored = pipe_out.pred_colored
 
     np.save(path_out_fp32, depth_pred)
-
     depth_colored.save(path_out_vis)
 
     if mode == 'depth':
@@ -98,10 +98,10 @@ def process_image(
         [path_out_16bit, path_out_fp32, path_out_vis],
     )
 
-def run_demo_server(pipe):
-    process_pipe_depth = spaces.GPU(functools.partial(process_image, pipe, mode='depth'))
-    process_pipe_normal = spaces.GPU(functools.partial(process_image, pipe, mode='normal'))
-    process_pipe_dis = spaces.GPU(functools.partial(process_image, pipe, mode='dis'))
+def run_demo_server(pipe_depth, pipe_normal, pipe_dis):
+    process_pipe_depth = spaces.GPU(functools.partial(process_image, pipe_depth, mode='depth'))
+    process_pipe_normal = spaces.GPU(functools.partial(process_image, pipe_normal, mode='normal'))
+    process_pipe_dis = spaces.GPU(functools.partial(process_image, pipe_dis, mode='dis'))
     gradio_theme = gr.themes.Default()
 
     with gr.Blocks(
@@ -336,7 +336,7 @@ def run_demo_server(pipe):
             preprocess=False,
             queue=False,
         ).success(
-            fn=process_pipe_normal,
+            fn=process_pipe_image,
             inputs=[
                 image_input,
                 image_processing_res,
@@ -383,17 +383,28 @@ def main():
     unet_depth_v1 = UNet2DConditionModel.from_pretrained('guangkaixu/GenPercept', subfolder="unet_depth_v1").to(dtype)
     empty_text_embed = torch.from_numpy(np.load("./empty_text_embed.npy")).to(device, dtype)[None]  # [1, 77, 1024]
 
-    pipe = GenPerceptPipeline(vae=vae,
-                              unet=unet_depth_v1,
-                              empty_text_embed=empty_text_embed)
+    pipe_depth = GenPerceptPipeline(vae=vae,
+                                    unet=unet_depth_v1,
+                                    empty_text_embed=empty_text_embed)
+    pipe_normal = GenPerceptPipeline(vae=vae,
+                                     unet=unet_normal_v1,
+                                     empty_text_embed=empty_text_embed)
+    pipe_dis = GenPerceptPipeline(vae=vae,
+                                  unet=unet_dis_v1,
+                                  empty_text_embed=empty_text_embed)
     try:
         import xformers
-        pipe.enable_xformers_memory_efficient_attention()
+        pipe_depth.enable_xformers_memory_efficient_attention()
+        pipe_normal.enable_xformers_memory_efficient_attention()
+        pipe_dis.enable_xformers_memory_efficient_attention()
     except:
         pass  # run without xformers
 
-    pipe = pipe.to(device)
-    run_demo_server(pipe)
+    pipe_depth = pipe_depth.to(device)
+    pipe_normal = pipe_normal.to(device)
+    pipe_dis = pipe_dis.to(device)
+
+    run_demo_server(pipe_depth, pipe_normal, pipe_dis)
 
 
 if __name__ == "__main__":
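
The core of this commit is the switch from one shared pipeline to one GenPerceptPipeline per task, each frozen into a Gradio handler with functools.partial before being wrapped by spaces.GPU. Below is a minimal, self-contained sketch of that wiring pattern; DummyPipe and the image string are hypothetical stand-ins, and the spaces.GPU wrapper from the real app.py is omitted so the snippet runs anywhere.

    import functools

    class DummyPipe:
        """Hypothetical stand-in for a task-specific GenPerceptPipeline."""
        def __init__(self, name):
            self.name = name
        def __call__(self, image, mode):
            return f"{self.name}(mode={mode}) on {image}"

    def process_image(pipe, image, mode='depth'):
        # Mirrors app.py: the commit now forwards `mode` into the pipeline call.
        return pipe(image, mode=mode)

    # One pipeline per task, mirroring pipe_depth / pipe_normal / pipe_dis.
    pipes = {m: DummyPipe(f"pipe_{m}") for m in ('depth', 'normal', 'dis')}

    # functools.partial freezes the pipeline and mode, leaving only the
    # image argument for Gradio to supply at event time.
    handlers = {m: functools.partial(process_image, pipes[m], mode=m)
                for m in ('depth', 'normal', 'dis')}

    print(handlers['normal']('example.png'))  # pipe_normal(mode=normal) on example.png

Note that all three GenPerceptPipeline constructors in the diff share the same vae and empty_text_embed; only the UNet weights differ between the three pipelines.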
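
The try/except around xformers now has to enable memory-efficient attention on all three pipelines. A sketch of that guard, factored into a helper (the helper name try_enable_xformers is hypothetical; enable_xformers_memory_efficient_attention is the diffusers method already used in the diff):

    def try_enable_xformers(*pipes):
        """Enable xformers attention on each pipeline, if xformers is installed."""
        try:
            import xformers  # noqa: F401 -- only probing for availability
        except ImportError:
            return False  # run without xformers, as app.py does
        for pipe in pipes:
            pipe.enable_xformers_memory_efficient_attention()
        return True

    # Usage, mirroring main():
    # try_enable_xformers(pipe_depth, pipe_normal, pipe_dis)

Catching ImportError specifically is slightly tighter than the bare except: in the diff, which would also swallow errors raised by the enable calls themselves.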