YaArtemNosenko committed
Commit 01591d1 · verified · 1 Parent(s): f7b0413

[ADD] Add IP adapter and ControlNet

Files changed (1): app.py (+150 -19)
app.py CHANGED
@@ -25,7 +25,13 @@ else:
 # Cache to avoid re-initializing pipelines repeatedly
 model_cache = {}
 
-def load_pipeline(model_id: str):
+def load_pipeline(model_id,
+                  lora_scale,
+                  controlnet_checkbox,
+                  controlnet_mode,
+                  ip_adapter_checkbox,
+                  ip_adapter_scale
+                  ):
     """
     Loads or retrieves a cached DiffusionPipeline.
 
@@ -34,11 +40,52 @@ def load_pipeline(model_id: str):
     """
     if model_id in model_cache:
         return model_cache[model_id]
-
+
+    if controlnet_checkbox:
+        if controlnet_mode == "depth_map":
+            controlnet = ControlNetModel.from_pretrained(
+                "lllyasviel/sd-controlnet-depth",
+                cache_dir="./models_cache",
+                torch_dtype=torch_dtype
+            )
+        elif controlnet_mode == "pose_estimation":
+            controlnet = ControlNetModel.from_pretrained(
+                "lllyasviel/sd-controlnet-openpose",
+                cache_dir="./models_cache",
+                torch_dtype=torch_dtype
+            )
+        elif controlnet_mode == "normal_map":
+            controlnet = ControlNetModel.from_pretrained(
+                "lllyasviel/sd-controlnet-normal",
+                cache_dir="./models_cache",
+                torch_dtype=torch_dtype
+            )
+        elif controlnet_mode == "scribbles":
+            controlnet = ControlNetModel.from_pretrained(
+                "lllyasviel/sd-controlnet-scribble",
+                cache_dir="./models_cache",
+                torch_dtype=torch_dtype
+            )
+        else:
+            controlnet = ControlNetModel.from_pretrained(
+                "lllyasviel/sd-controlnet-canny",
+                cache_dir="./models_cache",
+                torch_dtype=torch_dtype
+            )
+        pipe = StableDiffusionControlNetPipeline.from_pretrained(model_id,
+                                                                 controlnet=controlnet,
+                                                                 torch_dtype=torch_dtype,
+                                                                 safety_checker=None).to(device)
+        # params['image'] = controlnet_image
+        # params['controlnet_conditioning_scale'] = float(controlnet_strength)
+    else:
+        pipe = StableDiffusionPipeline.from_pretrained(model_id,
+                                                       torch_dtype=torch_dtype,
+                                                       safety_checker=None).to(device)
+
     if model_id == "YaArtemNosenko/dino_stickers":
         # Use the specified base model for your LoRA adapter.
         base_model = "CompVis/stable-diffusion-v1-4"
-        pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch_dtype)
         # Load the LoRA weights
         pipe.unet = PeftModel.from_pretrained(
             pipe.unet,
@@ -52,9 +99,21 @@ def load_pipeline(model_id: str):
             subfolder="text_encoder",
             torch_dtype=torch_dtype
         )
+        pipe.unet.load_state_dict({k: lora_scale * v for k, v in pipe.unet.state_dict().items()})
+        pipe.text_encoder.load_state_dict({k: lora_scale * v for k, v in pipe.text_encoder.state_dict().items()})
     else:
-        pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype)
-
+        pipe = DiffusionPipeline.from_pretrained(model_id,
+                                                 torch_dtype=torch_dtype
+                                                 )
+
+    if ip_adapter_checkbox:
+        pipe.load_ip_adapter("h94/IP-Adapter",
+                             subfolder="models",
+                             weight_name="ip-adapter-plus_sd15.bin"
+                             )
+        pipe.set_ip_adapter_scale(ip_adapter_scale)
+        # params['ip_adapter_image'] = ip_adapter_image
+
     pipe.to(device)
     model_cache[model_id] = pipe
     return pipe
@@ -72,17 +131,36 @@ def infer(
     height,
     guidance_scale,
     num_inference_steps,
-    lora_scale,  # New parameter for adjusting LoRA scale
+    lora_scale,  # New parameter for adjusting LoRA scale
+    controlnet_checkbox=False,  # whether to use ControlNet
+    controlnet_conditioning_scale=0.0,  # ControlNet conditioning weight
+    controlnet_mode="edge_detection",  # ControlNet variant
+    controlnet_image=None,  # ControlNet condition image
+    ip_adapter_checkbox=False,  # whether to use the IP-Adapter
+    ip_adapter_scale=0.0,  # IP-Adapter weight
+    ip_adapter_image=None,  # IP-Adapter condition image
     progress=gr.Progress(track_tqdm=True),
 ):
     # Load the pipeline for the chosen model
-    pipe = load_pipeline(model_id)
+    generator = torch.Generator(device=device).manual_seed(seed)
+    params = {'prompt': prompt,
+              'negative_prompt': negative_prompt,
+              'guidance_scale': guidance_scale,
+              'num_inference_steps': num_inference_steps,
+              'width': width,
+              'height': height,
+              'generator': generator
+              }
+    pipe = load_pipeline(model_id, lora_scale,
+                         controlnet_checkbox,
+                         controlnet_mode,
+                         ip_adapter_checkbox,
+                         ip_adapter_scale
+                         )
 
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
 
-    generator = torch.Generator(device=device).manual_seed(seed)
-
     # If using the LoRA model, update the LoRA scale if supported.
     if model_id == "YaArtemNosenko/dino_stickers":
         # This assumes your pipeline's unet has a method to update the LoRA scale.
@@ -90,17 +168,15 @@ def infer(
             pipe.unet.set_lora_scale(lora_scale)
         else:
             print("Warning: LoRA scale adjustment method not found on UNet.")
+    # if using ControlNet
+    if controlnet_checkbox:
+        params['image'] = controlnet_image
+        params['controlnet_conditioning_scale'] = float(controlnet_conditioning_scale)
+    # if using the IP-Adapter
+    if ip_adapter_checkbox:
+        params['ip_adapter_image'] = ip_adapter_image
 
-    image = pipe(
-        prompt=prompt,
-        negative_prompt=negative_prompt,
-        guidance_scale=guidance_scale,
-        num_inference_steps=num_inference_steps,
-        width=width,
-        height=height,
-        generator=generator,
-    ).images[0]
-
+    image = pipe(**params).images[0]
     return image, seed
 
 examples = [
@@ -201,6 +277,61 @@ with gr.Blocks(css=css) as demo:
                 value=1.0,
                 info="Adjust the influence of the LoRA weights",
             )
+            with gr.Row():
+                controlnet_checkbox = gr.Checkbox(
+                    label="ControlNet",
+                    value=False
+                )
+                with gr.Column(visible=False) as controlnet_params:
+                    controlnet_conditioning_scale = gr.Slider(
+                        label="ControlNet conditioning scale",
+                        minimum=0.0,
+                        maximum=1.0,
+                        step=0.01,
+                        value=1.0,
+                    )
+                    controlnet_mode = gr.Dropdown(
+                        label="ControlNet mode",
+                        choices=["edge_detection",
+                                 "depth_map",
+                                 "pose_estimation",
+                                 "normal_map",
+                                 "scribbles"],
+                        value="edge_detection",
+                        max_choices=1
+                    )
+                    controlnet_image = gr.Image(
+                        label="ControlNet condition image",
+                        type="pil",
+                        format="png"
+                    )
+                controlnet_checkbox.change(
+                    fn=lambda x: gr.Row.update(visible=x),
+                    inputs=controlnet_checkbox,
+                    outputs=controlnet_params
+                )
+            with gr.Row():
+                ip_adapter_checkbox = gr.Checkbox(
+                    label="IPAdapter",
+                    value=False
+                )
+                with gr.Column(visible=False) as ip_adapter_params:
+                    ip_adapter_scale = gr.Slider(
+                        label="IPAdapter scale",
+                        minimum=0.0,
+                        maximum=1.0,
+                        step=0.01,
+                        value=1.0,
+                    )
+                    ip_adapter_image = gr.Image(
+                        label="IPAdapter condition image",
+                        type="pil"
+                    )
+                ip_adapter_checkbox.change(
+                    fn=lambda x: gr.Row.update(visible=x),
+                    inputs=ip_adapter_checkbox,
+                    outputs=ip_adapter_params
+                )
 
     gr.Examples(examples=examples, inputs=[prompt])
     gr.on(
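
Note that the cache in load_pipeline is still keyed by model_id alone, so toggling the new ControlNet / IP-Adapter options returns whatever pipeline was cached first for that model. A minimal sketch of a configuration-aware cache key, reusing the option names from this commit (pipeline_cache_key itself is a hypothetical helper, not part of the diff):

def pipeline_cache_key(model_id, lora_scale, controlnet_checkbox,
                       controlnet_mode, ip_adapter_checkbox, ip_adapter_scale):
    # Hypothetical helper: a hashable key covering every option that
    # pipeline construction depends on, not just the model id.
    return (
        model_id,
        round(float(lora_scale), 4),
        bool(controlnet_checkbox),
        controlnet_mode,
        bool(ip_adapter_checkbox),
        round(float(ip_adapter_scale), 4),
    )

Using such a key for both the lookup and the model_cache[...] = pipe assignment would keep the caching behaviour while respecting the new options.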
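
The default ControlNet mode is "edge_detection" (lllyasviel/sd-controlnet-canny), but the UI passes the uploaded PIL image straight through as controlnet_image. A minimal sketch of turning a photo into the Canny edge map that this ControlNet expects, assuming OpenCV and NumPy are available (make_canny_condition is a hypothetical helper, not part of the diff):

import cv2
import numpy as np
from PIL import Image

def make_canny_condition(image: Image.Image, low: int = 100, high: int = 200) -> Image.Image:
    # Detect edges on the RGB input and replicate them to three channels,
    # the layout the canny ControlNet checkpoint was trained on.
    edges = cv2.Canny(np.array(image.convert("RGB")), low, high)
    return Image.fromarray(np.stack([edges] * 3, axis=-1))

The resulting image could then be passed as params['image'] in place of the raw upload.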
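
The gr.on(...) call is cut off at the hunk boundary, so the event wiring is not visible in this diff. For infer to receive the new values, the new components have to be appended to that event's inputs. A hedged sketch of what that wiring could look like; the trigger and component names (run_button, result) and the order of the earlier inputs are assumptions, not shown in this commit:

gr.on(
    triggers=[run_button.click, prompt.submit],  # assumed triggers
    fn=infer,
    inputs=[
        model_id, prompt, negative_prompt, seed, randomize_seed,
        width, height, guidance_scale, num_inference_steps,
        lora_scale,
        # new ControlNet inputs, in the order infer() expects them
        controlnet_checkbox, controlnet_conditioning_scale,
        controlnet_mode, controlnet_image,
        # new IP-Adapter inputs
        ip_adapter_checkbox, ip_adapter_scale, ip_adapter_image,
    ],
    outputs=[result, seed],
)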