sjtu-deepvision commited on
Commit
7c4abde
·
verified ·
1 Parent(s): a390bd6

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -1
app.py CHANGED
@@ -6,6 +6,9 @@ import tempfile
6
  import numpy as np
7
  import torch as torch
8
  torch.backends.cuda.matmul.allow_tf32 = True
 
 
 
9
 
10
 
11
  from diffusers import (
@@ -83,6 +86,12 @@ if __name__ == "__main__":
83
  t_start=0,
84
  ).to(device)
85
 
 
 
 
 
 
 
86
  # Cache example images in memory
87
  example_images_dir = "files/image"
88
  example_images = []
@@ -93,7 +102,7 @@ if __name__ == "__main__":
93
 
94
  # Create a Gradio interface
95
  interface = gr.Interface(
96
- fn=lambda image: process_image(pipe, vae_2, image),
97
  inputs=gr.Image(type="pil"),
98
  outputs=gr.Image(type="pil"),
99
  title="Dereflection Any Image",
 
6
  import numpy as np
7
  import torch as torch
8
  torch.backends.cuda.matmul.allow_tf32 = True
9
+ import spaces
10
+ import functools
11
+
12
 
13
 
14
  from diffusers import (
 
86
  t_start=0,
87
  ).to(device)
88
 
89
+ try:
90
+ import xformers
91
+ pipe.enable_xformers_memory_efficient_attention()
92
+ except:
93
+ pass # run without xformers
94
+
95
  # Cache example images in memory
96
  example_images_dir = "files/image"
97
  example_images = []
 
102
 
103
  # Create a Gradio interface
104
  interface = gr.Interface(
105
+ fn=spaces.GPU(functools.partial(process_image, pipe, vae_2)),
106
  inputs=gr.Image(type="pil"),
107
  outputs=gr.Image(type="pil"),
108
  title="Dereflection Any Image",