YaArtemNosenko commited on
Commit
0837027
·
verified ·
1 Parent(s): 4179d35

[ADD] Add LoRa

Browse files
Files changed (1) hide show
  1. app.py +80 -19
app.py CHANGED
@@ -6,14 +6,60 @@ import random
6
  from diffusers import DiffusionPipeline
7
  import torch
8
 
 
 
 
 
 
 
 
 
 
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
- # model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
11
 
12
  if torch.cuda.is_available():
13
  torch_dtype = torch.float16
14
  else:
15
  torch_dtype = torch.float32
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  # pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
18
  # pipe = pipe.to(device)
19
 
@@ -37,10 +83,14 @@ def infer(
37
  if randomize_seed:
38
  seed = random.randint(0, MAX_SEED)
39
 
40
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
41
- pipe = pipe.to(device)
42
 
43
  generator = torch.Generator().manual_seed(seed)
 
 
 
 
 
44
 
45
  image = pipe(
46
  prompt=prompt,
@@ -73,13 +123,14 @@ with gr.Blocks(css=css) as demo:
73
  gr.Markdown(" # Text-to-Image Gradio Template")
74
 
75
  with gr.Row():
76
- model_repo_id = gr.Text(
77
- label="model_repo_id",
78
- show_label=False,
79
- max_lines=1,
80
- placeholder="Enter your model_repo_id",
81
- container=False,
82
  )
 
 
83
  prompt = gr.Text(
84
  label="Prompt",
85
  show_label=False,
@@ -97,7 +148,6 @@ with gr.Blocks(css=css) as demo:
97
  label="Negative prompt",
98
  max_lines=1,
99
  placeholder="Enter a negative prompt",
100
- visible=False,
101
  )
102
 
103
  seed = gr.Slider(
@@ -105,7 +155,7 @@ with gr.Blocks(css=css) as demo:
105
  minimum=0,
106
  maximum=MAX_SEED,
107
  step=1,
108
- value=0,
109
  )
110
 
111
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
@@ -116,7 +166,7 @@ with gr.Blocks(css=css) as demo:
116
  minimum=256,
117
  maximum=MAX_IMAGE_SIZE,
118
  step=32,
119
- value=1024, # Replace with defaults that work for your model
120
  )
121
 
122
  height = gr.Slider(
@@ -124,32 +174,42 @@ with gr.Blocks(css=css) as demo:
124
  minimum=256,
125
  maximum=MAX_IMAGE_SIZE,
126
  step=32,
127
- value=1024, # Replace with defaults that work for your model
128
  )
129
 
130
  with gr.Row():
131
  guidance_scale = gr.Slider(
132
  label="Guidance scale",
133
  minimum=0.0,
134
- maximum=10.0,
135
- step=0.1,
136
- value=0.0, # Replace with defaults that work for your model
137
  )
138
 
139
  num_inference_steps = gr.Slider(
140
  label="Number of inference steps",
141
  minimum=1,
142
- maximum=50,
143
  step=1,
144
- value=2, # Replace with defaults that work for your model
145
  )
146
 
 
 
 
 
 
 
 
 
 
 
147
  gr.Examples(examples=examples, inputs=[prompt])
148
  gr.on(
149
  triggers=[run_button.click, prompt.submit],
150
  fn=infer,
151
  inputs=[
152
- model_repo_id,
153
  prompt,
154
  negative_prompt,
155
  seed,
@@ -158,6 +218,7 @@ with gr.Blocks(css=css) as demo:
158
  height,
159
  guidance_scale,
160
  num_inference_steps,
 
161
  ],
162
  outputs=[result, seed],
163
  )
 
6
  from diffusers import DiffusionPipeline
7
  import torch
8
 
9
+
10
+ # Model list including your LoRA model
11
+ MODEL_LIST = [
12
+ "CompVis/stable-diffusion-v1-4",
13
+ "stabilityai/sdxl-turbo",
14
+ "runwayml/stable-diffusion-v1-5",
15
+ "stabilityai/stable-diffusion-2-1",
16
+ "YaArtemNosenko/dino_stickers",
17
+ ]
18
+
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
20
 
21
  if torch.cuda.is_available():
22
  torch_dtype = torch.float16
23
  else:
24
  torch_dtype = torch.float32
25
 
26
+
27
+ def load_pipeline(model_id: str):
28
+ """
29
+ Loads or retrieves a cached DiffusionPipeline.
30
+
31
+ If the chosen model is your LoRA adapter, then load the base model
32
+ (CompVis/stable-diffusion-v1-4) and apply the LoRA weights.
33
+ """
34
+ if model_id in model_cache:
35
+ return model_cache[model_id]
36
+
37
+ if model_id == "YaArtemNosenko/dino_stickers":
38
+ # Use the specified base model for your LoRA adapter.
39
+ base_model = "CompVis/stable-diffusion-v1-4"
40
+ pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch_dtype)
41
+ # Load the LoRA weights
42
+ pipe.unet = PeftModel.from_pretrained(
43
+ pipe.unet,
44
+ model_id,
45
+ subfolder="unet",
46
+ torch_dtype=torch_dtype
47
+ )
48
+ pipe.text_encoder = PeftModel.from_pretrained(
49
+ pipe.text_encoder,
50
+ model_id,
51
+ subfolder="text_encoder",
52
+ torch_dtype=torch_dtype
53
+ )
54
+ else:
55
+ pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype)
56
+
57
+ pipe.to(device)
58
+ model_cache[model_id] = pipe
59
+
60
+ return pipe
61
+
62
+ # model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
63
  # pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
64
  # pipe = pipe.to(device)
65
 
 
83
  if randomize_seed:
84
  seed = random.randint(0, MAX_SEED)
85
 
86
+ pipe = load_pipeline(model_id)
 
87
 
88
  generator = torch.Generator().manual_seed(seed)
89
+ if model_id == "YaArtemNosenko/dino_stickers":
90
+ if hasattr(pipe.unet, "set_lora_scale"):
91
+ pipe.unet.set_lora_scale(lora_scale)
92
+ else:
93
+ print("Warning: LoRA scale adjustment method not found on UNet.")
94
 
95
  image = pipe(
96
  prompt=prompt,
 
123
  gr.Markdown(" # Text-to-Image Gradio Template")
124
 
125
  with gr.Row():
126
+ # Dropdown to select the model from Hugging Face
127
+ model_id = gr.Dropdown(
128
+ label="Model",
129
+ choices=MODEL_LIST,
130
+ value=MODEL_LIST[0], # Default model
 
131
  )
132
+
133
+ with gr.Row():
134
  prompt = gr.Text(
135
  label="Prompt",
136
  show_label=False,
 
148
  label="Negative prompt",
149
  max_lines=1,
150
  placeholder="Enter a negative prompt",
 
151
  )
152
 
153
  seed = gr.Slider(
 
155
  minimum=0,
156
  maximum=MAX_SEED,
157
  step=1,
158
+ value=42, # Default seed
159
  )
160
 
161
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
 
166
  minimum=256,
167
  maximum=MAX_IMAGE_SIZE,
168
  step=32,
169
+ value=1024,
170
  )
171
 
172
  height = gr.Slider(
 
174
  minimum=256,
175
  maximum=MAX_IMAGE_SIZE,
176
  step=32,
177
+ value=1024,
178
  )
179
 
180
  with gr.Row():
181
  guidance_scale = gr.Slider(
182
  label="Guidance scale",
183
  minimum=0.0,
184
+ maximum=20.0,
185
+ step=0.5,
186
+ value=7.0,
187
  )
188
 
189
  num_inference_steps = gr.Slider(
190
  label="Number of inference steps",
191
  minimum=1,
192
+ maximum=100,
193
  step=1,
194
+ value=20,
195
  )
196
 
197
+ # New slider for LoRA scale.
198
+ lora_scale = gr.Slider(
199
+ label="LoRA Scale",
200
+ minimum=0.0,
201
+ maximum=2.0,
202
+ step=0.1,
203
+ value=1.0,
204
+ info="Adjust the influence of the LoRA weights",
205
+ )
206
+
207
  gr.Examples(examples=examples, inputs=[prompt])
208
  gr.on(
209
  triggers=[run_button.click, prompt.submit],
210
  fn=infer,
211
  inputs=[
212
+ model_id,
213
  prompt,
214
  negative_prompt,
215
  seed,
 
218
  height,
219
  guidance_scale,
220
  num_inference_steps,
221
+ lora_scale, # Pass the new slider value
222
  ],
223
  outputs=[result, seed],
224
  )