recoilme commited on
Commit
8f1e053
·
verified ·
1 Parent(s): 95f0546

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -11
app.py CHANGED
@@ -1,7 +1,8 @@
1
  import gradio as gr
2
  import torch
 
3
 
4
- from diffusers import DiffusionPipeline
5
  from diffusers import EulerDiscreteScheduler
6
 
7
  device = "cpu"
@@ -17,17 +18,24 @@ if mps_available:
17
  dtype = torch.float16
18
  #print(f"device: {device}, dtype: {dtype}")
19
 
20
-
21
- pipeline = DiffusionPipeline.from_pretrained("recoilme/ColorfulXL-Lightning",
22
- variant="fp16",
23
- torch_dtype=dtype,
24
- use_safetensors=True)
25
  pipeline.to(device)
26
  pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config, timestep_spacing="trailing")
 
 
 
 
 
27
 
 
 
 
28
 
29
- def generate(prompt, width, height, sample_steps):
30
- return pipeline(prompt=prompt, guidance_scale=0, negative_prompt="", width=width, height=height, num_inference_steps=sample_steps).images[0]
31
 
32
  with gr.Blocks() as interface:
33
  with gr.Column():
@@ -41,14 +49,20 @@ with gr.Blocks() as interface:
41
  with gr.Column():
42
  width = gr.Slider(label="Width", info="The width in pixels of the generated image.", value=576, minimum=512, maximum=1280, step=64, interactive=True)
43
  height = gr.Slider(label="Height", info="The height in pixels of the generated image.", value=832, minimum=512, maximum=1280, step=64, interactive=True)
 
 
 
 
 
 
 
44
  with gr.Column():
45
  sampling_steps = gr.Slider(label="Sampling Steps", info="The number of denoising steps.", value=5, minimum=3, maximum=10, step=1, interactive=True)
46
 
47
  with gr.Row():
48
  output = gr.Image()
49
 
50
- generate_button.click(fn=generate, inputs=[prompt, width, height, sampling_steps], outputs=[output])
51
 
52
  if __name__ == "__main__":
53
- interface.launch()
54
-
 
1
  import gradio as gr
2
  import torch
3
+ import random
4
 
5
+ from diffusers import StableDiffusionXLPipeline
6
  from diffusers import EulerDiscreteScheduler
7
 
8
  device = "cpu"
 
18
  dtype = torch.float16
19
  #print(f"device: {device}, dtype: {dtype}")
20
 
21
+ pipeline = StableDiffusionXLPipeline.from_pretrained("recoilme/ColorfulXL-Lightning",
22
+ variant="fp16",
23
+ torch_dtype=dtype,
24
+ use_safetensors=True)
 
25
  pipeline.to(device)
26
  pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config, timestep_spacing="trailing")
27
+ # Comes from
28
+ # https://wandb.ai/nasirk24/UNET-FreeU-SDXL/reports/FreeU-SDXL-Optimal-Parameters--Vmlldzo1NDg4NTUw
29
+ if device == "cuda":
30
+ pipeline.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2)
31
+
32
 
33
+ def generate(prompt, width, height, sample_steps, seed):
34
+ generator = torch.Generator(device=device).manual_seed(int(seed))
35
+ return pipeline(prompt=prompt, prompt_2=prompt, guidance_scale=0, generator=generator, negative_prompt=None, negative_prompt_2=None, width=width, height=height, num_inference_steps=sample_steps).images[0]
36
 
37
+ def random_seed():
38
+ return random.randint(0, 2**32 - 1)
39
 
40
  with gr.Blocks() as interface:
41
  with gr.Column():
 
49
  with gr.Column():
50
  width = gr.Slider(label="Width", info="The width in pixels of the generated image.", value=576, minimum=512, maximum=1280, step=64, interactive=True)
51
  height = gr.Slider(label="Height", info="The height in pixels of the generated image.", value=832, minimum=512, maximum=1280, step=64, interactive=True)
52
+ with gr.Row():
53
+ seed = gr.Number(label="Seed",
54
+ value=None,
55
+ scale=8,
56
+ info="Random seed for reproducibility.")
57
+ seed_button = gr.Button("🎲", scale=2, elem_id="seed_button")
58
+ seed_button.click(fn=random_seed, inputs=[], outputs=seed)
59
  with gr.Column():
60
  sampling_steps = gr.Slider(label="Sampling Steps", info="The number of denoising steps.", value=5, minimum=3, maximum=10, step=1, interactive=True)
61
 
62
  with gr.Row():
63
  output = gr.Image()
64
 
65
+ generate_button.click(fn=generate, inputs=[prompt, width, height, sampling_steps, seed], outputs=[output])
66
 
67
  if __name__ == "__main__":
68
+ interface.launch()