gaur3009 committed
Commit a416f9f · verified · 1 Parent(s): 79b8531

Update app.py

Files changed (1):
  app.py (+33, −27)
app.py CHANGED
@@ -1,39 +1,41 @@
 import gradio as gr
 import numpy as np
 import random
-from diffusers import DiffusionPipeline, DDPMScheduler, DPMSolverMultistepScheduler
 import torch
+from diffusers import DiffusionPipeline, EDMEulerScheduler
 
+# 🖥️ Detect device
 device = "cuda" if torch.cuda.is_available() else "cpu"
-model_repo_id = "stabilityai/sdxl-turbo"
 
-if torch.cuda.is_available():
-    torch_dtype = torch.float16
-else:
-    torch_dtype = torch.float32
+# 🎯 Model ID and config
+model_repo_id = "stabilityai/sdxl-turbo"
+MAX_SEED = np.iinfo(np.int32).max
+MAX_IMAGE_SIZE = 1024
 
-# 🧮 Use ODE-based DPM-SolverMultistepScheduler for faster inference
-pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
+# 🔁 Load model with EDM + VPred scheduler
+pipe = DiffusionPipeline.from_pretrained(
+    model_repo_id,
+    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+    variant="fp16" if torch.cuda.is_available() else None,
+)
 
-# ✅ Replace default scheduler with an ODE-optimized scheduler
-pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
+# 🔄 Replace scheduler with EDM + V-prediction
+pipe.scheduler = EDMEulerScheduler.from_config(pipe.scheduler.config)
 
-# 🚀 Optimization Flags (Math-Driven)
-try:
-    pipe.enable_xformers_memory_efficient_attention()
-except Exception as e:
-    print("xFormers not available, falling back to slicing.")
-    pipe.enable_attention_slicing()
+# 🧠 Enable optimizations if GPU
+if device == "cuda":
+    try:
+        pipe.enable_xformers_memory_efficient_attention()
+    except Exception as e:
+        print("⚠️ xFormers not available, using attention slicing.")
+        pipe.enable_attention_slicing()
 
-pipe.enable_model_cpu_offload()
-pipe.enable_attention_slicing()
-pipe.enable_vae_tiling()
+    pipe.enable_model_cpu_offload()
+    pipe.enable_vae_tiling()
 
 pipe = pipe.to(device)
 
-MAX_SEED = np.iinfo(np.int32).max
-MAX_IMAGE_SIZE = 1024
-
+# 🚀 Inference function
 def infer(
     prompt,
     negative_prompt,
@@ -48,7 +50,7 @@ def infer(
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
 
-    generator = torch.Generator().manual_seed(seed)
+    generator = torch.Generator(device=device).manual_seed(seed)
 
     image = pipe(
         prompt=prompt,
@@ -62,12 +64,14 @@ def infer(
 
     return image, seed
 
+# 🧪 Prompt examples
 examples = [
     "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
     "An astronaut riding a green horse",
     "A delicious ceviche cheesecake slice",
 ]
 
+# 🎨 UI CSS
 css = """
 #col-container {
     margin: 0 auto;
@@ -75,9 +79,10 @@ css = """
 }
 """
 
+# 🧱 Gradio Interface
 with gr.Blocks(css=css) as demo:
     with gr.Column(elem_id="col-container"):
-        gr.Markdown(" # Text-to-Image Gradio Template")
+        gr.Markdown(" # Text-to-Image Gradio (EDM + VPred)")
 
         with gr.Row():
             prompt = gr.Text(
@@ -133,18 +138,19 @@ with gr.Blocks(css=css) as demo:
                 minimum=0.0,
                 maximum=10.0,
                 step=0.1,
-                value=0.0,
+                value=2.5,  # Optimal for SDXL-Turbo
             )
 
             num_inference_steps = gr.Slider(
                 label="Number of inference steps",
                 minimum=1,
-                maximum=50,
+                maximum=20,
                 step=1,
-                value=2,
+                value=4,  # Low default for EDM
             )
 
         gr.Examples(examples=examples, inputs=[prompt])
+
         gr.on(
             triggers=[run_button.click, prompt.submit],
             fn=infer,
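
A quick way to sanity-check the scheduler swap outside the Space (a minimal sketch, not part of the commit; it only loads the checkpoint and inspects the scheduler, assuming the same diffusers API used in app.py):

import torch
from diffusers import DiffusionPipeline, EDMEulerScheduler

# Load the same checkpoint app.py uses; float32 keeps the check CPU-friendly.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/sdxl-turbo",
    torch_dtype=torch.float32,
)

# Swap schedulers exactly as the commit does.
pipe.scheduler = EDMEulerScheduler.from_config(pipe.scheduler.config)

# from_config() keeps only the keys EDMEulerScheduler understands and falls
# back to its own defaults for the rest, so prediction_type may still read
# "epsilon" here unless the loaded config explicitly sets "v_prediction".
print(type(pipe.scheduler).__name__)          # EDMEulerScheduler
print(pipe.scheduler.config.prediction_type)

The other behavioral change is equally easy to verify: torch.Generator(device=device).manual_seed(seed) now pins the RNG to the sampling device, so two runs with the same seed on the same device should produce identical images.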