Eugeoter commited on
Commit
067bc73
·
1 Parent(s): 189acad
Files changed (2) hide show
  1. app.py +9 -3
  2. utils/tools.py +2 -2
app.py CHANGED
@@ -13,6 +13,13 @@ CONTROLNET_FILENAME = "ControlAny-SDXL/anime_canny/controlnet.safetensors"
13
  CACHE_DIR = None
14
 
15
 
 
 
 
 
 
 
 
16
  def ui():
17
  device = "cuda" if torch.cuda.is_available() else "cpu"
18
  model_file = hf_hub_download(
@@ -41,10 +48,9 @@ def ui():
41
  device=device,
42
  hf_cache_dir=CACHE_DIR,
43
  use_safetensors=True,
44
- enable_xformers_memory_efficient_attention=True,
45
  )
46
- with spaces.GPU:
47
- pipeline = pipeline.to(dtype=torch.float16)
48
 
49
  preprocessors = ['canny']
50
  schedulers = ['Euler A', 'UniPC', 'Euler', 'DDIM', 'DDPM']
 
13
  CACHE_DIR = None
14
 
15
 
16
+ @spaces.GPU
17
+ def optimize_pipeline(pipeline):
18
+ pipeline.to(dtype=torch.float16)
19
+ pipeline.enable_xformers_memory_efficient_attention()
20
+ return pipeline
21
+
22
+
23
  def ui():
24
  device = "cuda" if torch.cuda.is_available() else "cpu"
25
  model_file = hf_hub_download(
 
48
  device=device,
49
  hf_cache_dir=CACHE_DIR,
50
  use_safetensors=True,
51
+ enable_xformers_memory_efficient_attention=torch.cuda.is_available(),
52
  )
53
+ pipeline = optimize_pipeline(pipeline)
 
54
 
55
  preprocessors = ['canny']
56
  schedulers = ['Euler A', 'UniPC', 'Euler', 'DDIM', 'DDPM']
utils/tools.py CHANGED
@@ -112,11 +112,11 @@ def get_pipeline(
112
 
113
  pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
114
  pipeline.set_progress_bar_config()
115
- pipeline = pipeline.to(device, dtype=torch.float16 if torch.cuda.is_available() else torch.float32)
116
 
117
  if lora_path is not None:
118
  pipeline.load_lora_weights(lora_path)
119
- if enable_xformers_memory_efficient_attention and torch.cuda.is_available():
120
  pipeline.enable_xformers_memory_efficient_attention()
121
 
122
  return pipeline
 
112
 
113
  pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
114
  pipeline.set_progress_bar_config()
115
+ pipeline = pipeline.to(device, dtype=torch.float16)
116
 
117
  if lora_path is not None:
118
  pipeline.load_lora_weights(lora_path)
119
+ if enable_xformers_memory_efficient_attention:
120
  pipeline.enable_xformers_memory_efficient_attention()
121
 
122
  return pipeline