amazonaws-la commited on
Commit
42aaf02
·
verified ·
1 Parent(s): e3e1bfc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -4
app.py CHANGED
@@ -10,7 +10,7 @@ import numpy as np
10
  import PIL.Image
11
  import spaces
12
  import torch
13
- from diffusers import AutoencoderKL, DiffusionPipeline
14
 
15
  DESCRIPTION = "# SDXL"
16
  if not torch.cuda.is_available():
@@ -23,6 +23,7 @@ USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE") == "1"
23
  ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
24
  ENABLE_REFINER = os.getenv("ENABLE_REFINER", "1") == "1"
25
  ENABLE_USE_LORA = os.getenv("ENABLE_USE_LORA", "1") == "1"
 
26
 
27
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
28
 
@@ -48,15 +49,25 @@ def generate(
48
  guidance_scale_refiner: float = 5.0,
49
  num_inference_steps_base: int = 25,
50
  num_inference_steps_refiner: int = 25,
 
51
  use_lora: bool = False,
52
  apply_refiner: bool = False,
53
- model = 'stabilityai/stable-diffusion-xl-base-1.0',
54
- vaecall = 'madebyollin/sdxl-vae-fp16-fix',
55
  lora = 'amazonaws-la/juliette',
 
56
  ) -> PIL.Image.Image:
57
  if torch.cuda.is_available():
58
- pipe = DiffusionPipeline.from_pretrained(model, torch_dtype=torch.float16)
59
 
 
 
 
 
 
 
 
 
 
60
  if use_lora:
61
  pipe.load_lora_weights(lora)
62
  pipe.fuse_lora(lora_scale=0.7)
@@ -134,6 +145,13 @@ with gr.Blocks(css="style.css") as demo:
134
  model = gr.Text(label='Modelo')
135
  vaecall = gr.Text(label='VAE')
136
  lora = gr.Text(label='LoRA')
 
 
 
 
 
 
 
137
  with gr.Row():
138
  prompt = gr.Text(
139
  label="Prompt",
@@ -191,6 +209,7 @@ with gr.Blocks(css="style.css") as demo:
191
  step=32,
192
  value=1024,
193
  )
 
194
  use_lora = gr.Checkbox(label='Use Lora', value=False, visible=ENABLE_USE_LORA)
195
  apply_refiner = gr.Checkbox(label="Apply refiner", value=False, visible=ENABLE_REFINER)
196
  with gr.Row():
@@ -253,6 +272,13 @@ with gr.Blocks(css="style.css") as demo:
253
  queue=False,
254
  api_name=False,
255
  )
 
 
 
 
 
 
 
256
  use_lora.change(
257
  fn=lambda x: gr.update(visible=x),
258
  inputs=use_lora,
@@ -298,11 +324,13 @@ with gr.Blocks(css="style.css") as demo:
298
  guidance_scale_refiner,
299
  num_inference_steps_base,
300
  num_inference_steps_refiner,
 
301
  use_lora,
302
  apply_refiner,
303
  model,
304
  vaecall,
305
  lora,
 
306
  ],
307
  outputs=result,
308
  api_name="run",
 
10
  import PIL.Image
11
  import spaces
12
  import torch
13
+ from diffusers import DPMSolverMultistepScheduler, AutoencoderKL, DiffusionPipeline
14
 
15
  DESCRIPTION = "# SDXL"
16
  if not torch.cuda.is_available():
 
23
  ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
24
  ENABLE_REFINER = os.getenv("ENABLE_REFINER", "1") == "1"
25
  ENABLE_USE_LORA = os.getenv("ENABLE_USE_LORA", "1") == "1"
26
+ ENABLE_USE_VAE = os.getenv("ENABLE_USE_VAE", "1") == "1"
27
 
28
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
29
 
 
49
  guidance_scale_refiner: float = 5.0,
50
  num_inference_steps_base: int = 25,
51
  num_inference_steps_refiner: int = 25,
52
+ use_vae: bool = False,
53
  use_lora: bool = False,
54
  apply_refiner: bool = False,
55
+ model = 'SG161222/Realistic_Vision_V6.0_B1_noVAE',
56
+ vaecall = 'stabilityai/sd-vae-ft-mse',
57
  lora = 'amazonaws-la/juliette',
58
+ lora_scale: float = 0.7,
59
  ) -> PIL.Image.Image:
60
  if torch.cuda.is_available():
 
61
 
62
+ if not use_vae:
63
+ pipe = DiffusionPipeline.from_pretrained(model, torch_dtype=torch.float16)
64
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
65
+
66
+ if use_vae:
67
+ vae = AutoencoderKL.from_pretrained(vaecall, torch_dtype=torch.float16)
68
+ pipe = DiffusionPipeline.from_pretrained(model, vae=vae, torch_dtype=torch.float16)
69
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
70
+
71
  if use_lora:
72
  pipe.load_lora_weights(lora)
73
  pipe.fuse_lora(lora_scale=0.7)
 
145
  model = gr.Text(label='Modelo')
146
  vaecall = gr.Text(label='VAE')
147
  lora = gr.Text(label='LoRA')
148
+ lora_scale = gr.Slider(
149
+ label="Lora Scale",
150
+ minimum=0.01,
151
+ maximum=1,
152
+ step=0.01,
153
+ value=0.7,
154
+ )
155
  with gr.Row():
156
  prompt = gr.Text(
157
  label="Prompt",
 
209
  step=32,
210
  value=1024,
211
  )
212
+ use_vae = gr.Checkbox(label='Use VAE', value=False, visible=ENABLE_USE_VAE)
213
  use_lora = gr.Checkbox(label='Use Lora', value=False, visible=ENABLE_USE_LORA)
214
  apply_refiner = gr.Checkbox(label="Apply refiner", value=False, visible=ENABLE_REFINER)
215
  with gr.Row():
 
272
  queue=False,
273
  api_name=False,
274
  )
275
+ use_vae.change(
276
+ fn=lambda x: gr.update(visible=x),
277
+ inputs=use_vae,
278
+ outputs=vaecall,
279
+ queue=False,
280
+ api_name=False,
281
+ )
282
  use_lora.change(
283
  fn=lambda x: gr.update(visible=x),
284
  inputs=use_lora,
 
324
  guidance_scale_refiner,
325
  num_inference_steps_base,
326
  num_inference_steps_refiner,
327
+ use_vae,
328
  use_lora,
329
  apply_refiner,
330
  model,
331
  vaecall,
332
  lora,
333
+ lora_scale,
334
  ],
335
  outputs=result,
336
  api_name="run",