multimodalart HF staff committed on
Commit 61bc6a3
1 Parent(s): 5e64d98

Update app.py

Files changed (1)
  1. app.py +27 -59
app.py CHANGED
@@ -1,79 +1,47 @@
  import gradio as gr
- from gradio_client import Client
- from diffusers import AutoencoderKL, StableDiffusionXLPipeline
+ from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler, LCMScheduler
  import torch
- import concurrent.futures
- import spaces
-
- client_lightning = Client("AP123/SDXL-Lightning")
- client_hyper = Client("ByteDance/Hyper-SDXL-1Step-T2I")
-
- vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
+ from huggingface_hub import hf_hub_download
+ from safetensors.torch import load_file

  ### SDXL Turbo ####
- pipe_turbo = StableDiffusionXLPipeline.from_pretrained("stabilityai/sdxl-turbo",
-                                                        vae=vae,
-                                                        torch_dtype=torch.float16,
-                                                        variant="fp16"
-                                                        )
+ pipe_turbo = StableDiffusionXLPipeline.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16")
  pipe_turbo.to("cuda")

- def get_lighting_result(prompt):
-     result_lighting = client_lightning.predict(
-         prompt,    # Your prompt
-         "1-Step",  # Number of inference steps
-         api_name="/generate_image"
-     )
-     return result_lighting
-
- def get_hyper_result(prompt):
-     result_hyper = client_hyper.predict(
-         num_images=1,
-         height=1024,
-         width=1024,
-         prompt=prompt,
-         seed=3413,
-         api_name="/process_image"
-     )
-     return result_hyper
-
- @spaces.GPU
- def get_turbo_result(prompt):
-     image_turbo = pipe_turbo(prompt=prompt, num_inference_steps=1, guidance_scale=0).images[0]
-     return image_turbo
+ ### SDXL Lightning ###
+ base = "stabilityai/stable-diffusion-xl-base-1.0"
+ repo = "ByteDance/SDXL-Lightning"
+ ckpt = "sdxl_lightning_1step_unet_x0.safetensors"
+
+ unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda", torch.float16)
+ unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))
+ pipe_lightning = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cuda")
+ pipe_lightning.scheduler = EulerDiscreteScheduler.from_config(pipe_lightning.scheduler.config, timestep_spacing="trailing", prediction_type="sample")
+ pipe_lightning.to("cuda")
+
+ ### Hyper SDXL ###
+ repo_name = "ByteDance/Hyper-SD"
+ ckpt_name = "Hyper-SDXL-1step-Unet.safetensors"
+
+ unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda", torch.float16)
+ unet.load_state_dict(load_file(hf_hub_download(repo_name, ckpt_name), device="cuda"))
+ pipe_hyper = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cuda")
+ pipe_hyper.scheduler = LCMScheduler.from_config(pipe_hyper.scheduler.config)
+ pipe_hyper.to("cuda")

  def run_comparison(prompt):
-     with concurrent.futures.ThreadPoolExecutor() as executor:
-         # Submit tasks to the executor
-         future_lighting = executor.submit(get_lighting_result, prompt)
-         future_hyper = executor.submit(get_hyper_result, prompt)
-         future_turbo = executor.submit(get_turbo_result, prompt)
-
-         # Wait for all futures to complete
-         results = concurrent.futures.wait(
-             [future_lighting, future_hyper, future_turbo],
-             return_when=concurrent.futures.ALL_COMPLETED
-         )
-
-         # Extract results from futures
-         result_lighting = future_lighting.result()
-         result_hyper = future_hyper.result()
-         image_turbo = future_turbo.result()
-     print(result_lighting)
-     print(result_hyper)
-     return image_turbo, result_lighting, result_hyper
+     image_turbo=pipe_turbo(prompt=prompt, num_inference_steps=1, guidance_scale=0).images[0]
+     image_lightning=pipe_lightning(prompt=prompt, num_inference_steps=1, guidance_scale=0).images[0]
+     image_hyper=pipe_hyper(prompt=prompt, num_inference_steps=1, guidance_scale=0, timesteps=[800]).images[0]
+     return image_turbo, image_lightning, image_hyper

- css = '''
- .gradio-container{max-width: 768px !important}
- '''
- with gr.Blocks(css=css) as demo:
+ with gr.Blocks() as demo:
      prompt = gr.Textbox(label="Prompt")
      run = gr.Button("Run")
      with gr.Row():
          image_turbo = gr.Image(label="SDXL Turbo")
          image_lightning = gr.Image(label="SDXL Lightning")
-         image_hyper = gr.Image("Hyper SDXL")
+         image_hyper = gr.Image(label="Hyper SDXL")

      run.click(fn=run_comparison, inputs=prompt, outputs=[image_turbo, image_lightning, image_hyper])
  demo.launch()
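
For reference, a minimal sketch of how the updated comparison could be exercised outside the Gradio UI, assuming the three pipelines and run_comparison are defined as in the new app.py above; the prompt and output filenames are illustrative only and not part of the commit.

# Illustrative smoke test (not part of the commit): assumes pipe_turbo,
# pipe_lightning, pipe_hyper and run_comparison exist as defined in the new app.py.
test_prompt = "a photo of a corgi wearing sunglasses"  # example prompt, chosen arbitrarily

# Each pipeline call returns a PIL image generated in a single denoising step.
turbo_img, lightning_img, hyper_img = run_comparison(test_prompt)

# Save the three results for a side-by-side visual check.
turbo_img.save("turbo.png")
lightning_img.save("lightning.png")
hyper_img.save("hyper.png")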