Bethie committed
Commit 27898b7 · verified · 1 Parent(s): 0a58c97

Run pipeline

Files changed (1)
  1. code_inference/run_controlnext.py +280 -0
code_inference/run_controlnext.py ADDED
@@ -0,0 +1,280 @@
+ import os
+ import torch
+ import cv2
+ import numpy as np
+ from PIL import Image
+ import argparse
+ from diffusers import DDPMScheduler
+
+ from pipeline_sdxl_ipadapter import StableDiffusionXLControlNeXtPipeline
+ from transformers import CLIPVisionModelWithProjection
+ from transformers import CLIPTokenizer
+ import onnxruntime as ort
+ from configs import *
+
+ def log_validation(
+     vae,
+     scheduler,
+     text_encoder,
+     tokenizer,
+     unet,
+     controlnet,
+     args,
+     device,
+     image_proj,
+     text_encoder2,
+     tokenizer2,
+     image_encoder
+ ):
+     if len(args.validation_image) == len(args.validation_prompt):
+         validation_images = args.validation_image
+         validation_prompts = args.validation_prompt
+     elif len(args.validation_image) == 1:
+         validation_images = args.validation_image * len(args.validation_prompt)
+         validation_prompts = args.validation_prompt
+     elif len(args.validation_prompt) == 1:
+         validation_images = args.validation_image
+         validation_prompts = args.validation_prompt * len(args.validation_image)
+     else:
+         raise ValueError(
+             "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`"
+         )
+
+     if args.negative_prompt is not None:
+         negative_prompts = args.negative_prompt
+         assert len(negative_prompts) == len(validation_prompts)
+     else:
+         negative_prompts = None
+
+     # torch.autocast expects a device type such as "cuda", not a device index like "cuda:0"
+     inference_ctx = torch.autocast(device.split(":")[0])
+
+     pipeline = StableDiffusionXLControlNeXtPipeline(
+         vae=vae,
+         text_encoder=text_encoder,
+         text_encoder_2=text_encoder2,
+         tokenizer=tokenizer,
+         tokenizer_2=tokenizer2,
+         unet=unet,
+         controlnext=controlnet,
+         scheduler=scheduler,
+         image_encoder=image_encoder,
+         device=device,
+         image_proj=image_proj
+     )
+
+     image_logs = []
+     pil_image = args.pil_image
+
+     if args.pil_image is not None:
+         pil_image = Image.open(pil_image).convert("RGB")
+
+     for i, (validation_prompt, validation_image) in enumerate(zip(validation_prompts, validation_images)):
+         validation_image = Image.open(validation_image).convert("RGB")
+
+         images = []
+         negative_prompt = negative_prompts[i] if negative_prompts is not None else None
+
+         for _ in range(args.num_validation_images):
+
+             with inference_ctx:
+
+                 image = pipeline(
+                     prompt=validation_prompt,
+                     controlnet_image=validation_image,
+                     num_inference_steps=args.num_inference_steps,
+                     guidance_scale=args.guidance_scale,
+                     negative_prompt=negative_prompt,
+                     ip_adapter_image=pil_image,
+                     control_scale=args.controlnext_scale,
+                     width=args.width,
+                     height=args.height,
+                 )[0]
+
+             images.append(image)
+
+         image_logs.append(
+             {"validation_image": validation_image.resize((args.width, args.height)),
+              "ip_adapter_image": pil_image.resize((args.width, args.height)),
+              "images": images, "validation_prompt": validation_prompt}
+         )
+
+     save_dir_path = args.output_dir
+
+     if not os.path.exists(save_dir_path):
+         os.makedirs(save_dir_path)
+     for i, log in enumerate(image_logs):
+         images = log["images"]
+         validation_prompt = log["validation_prompt"]
+         ip_adapter_image = log["ip_adapter_image"]
+         validation_image = log["validation_image"]
+
+         formatted_images = []
+         formatted_images.append(np.asarray(validation_image))
+         formatted_images.append(np.asarray(ip_adapter_image))
+
+         for image in images:
+             formatted_images.append(np.asarray(image))
+
+         for idx, img in enumerate(formatted_images):
+             print(f"Image {idx} shape: {img.shape}")
+
+         formatted_images = np.concatenate(formatted_images, 1)
+
+         file_path = os.path.join(save_dir_path, "image_{}.png".format(i))
+         # the arrays come from PIL (RGB); convert to BGR for cv2.imwrite
+         formatted_images = cv2.cvtColor(formatted_images, cv2.COLOR_RGB2BGR)
+         print("Save images to:", file_path)
+         cv2.imwrite(file_path, formatted_images)
+
+     return image_logs
+
+ def parse_args(input_args=None):
+     parser = argparse.ArgumentParser(description="Simple example of a ControlNeXt inference script.")
+
+
+     parser.add_argument(
+         "--output_dir",
+         type=str,
+         default=None,
+         help="The output directory where the inference result will be written.",
+     )
+     parser.add_argument(
+         "--pil_image",
+         type=str,
+         default=None,
+         help="IP-Adapter reference image path.",
+     )
+
+     parser.add_argument(
+         "--validation_prompt",
+         type=str,
+         default=None,
+         nargs="+",
+         help=(
+             "A set of prompts to run inference with. Provide either a matching number of"
+             " `--validation_image`s, a single `--validation_image` to be used with all prompts,"
+             " or a single prompt that will be used with all `--validation_image`s."
+         ),
+     )
+     parser.add_argument(
+         "--negative_prompt",
+         type=str,
+         default=None,
+         nargs="+",
+         help=(
+             "A set of negative prompts, one per `--validation_prompt`. Omit to run without negative prompts."
+         ),
+     )
+     parser.add_argument(
+         "--validation_image",
+         type=str,
+         default=None,
+         nargs="+",
+         help=(
+             "A set of paths to the ControlNeXt conditioning images. Provide either a matching number"
+             " of `--validation_prompt`s, a single `--validation_prompt` to be used with all"
+             " `--validation_image`s, or a single `--validation_image` that will be used with all"
+             " `--validation_prompt`s."
+         ),
+     )
+     parser.add_argument(
+         "--num_validation_images",
+         type=int,
+         default=1,
+         help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair.",
+     )
+
+     parser.add_argument(
+         "--num_inference_steps",
+         type=int,
+         default=30,
+         help="Number of steps for inference.",
+     )
+
+     parser.add_argument(
+         "--controlnext_scale",
+         type=float,
+         default=2.5,
+         help="ControlNeXt conditioning scale.",
+     )
+
+     parser.add_argument(
+         "--guidance_scale",
+         type=float,
+         default=7.5,
+         help="Guidance scale.",
+     )
+
+     parser.add_argument(
+         "--height",
+         type=int,
+         default=1024,
+         help="The height of the output image.",
+     )
+
+     parser.add_argument(
+         "--width",
+         type=int,
+         default=1024,
+         help="The width of the output image.",
+     )
+
+     if input_args is not None:
+         args = parser.parse_args(input_args)
+     else:
+         args = parser.parse_args()
+
+     if args.validation_prompt is not None and args.validation_image is None:
+         raise ValueError("`--validation_image` must be set if `--validation_prompt` is set")
+
+     if args.validation_prompt is None and args.validation_image is not None:
+         raise ValueError("`--validation_prompt` must be set if `--validation_image` is set")
+
+     if (
+         args.validation_image is not None
+         and args.validation_prompt is not None
+         and len(args.validation_image) != 1
+         and len(args.validation_prompt) != 1
+         and len(args.validation_image) != len(args.validation_prompt)
+     ):
+         raise ValueError(
+             "Must provide either 1 `--validation_image`, 1 `--validation_prompt`,"
+             " or the same number of `--validation_prompt`s and `--validation_image`s"
+         )
+
+     return args
+
+ if __name__ == "__main__":
+     args = parse_args()
+
+     device = 'cuda:0'
+
+     # ONNX Runtime sessions for the exported model components (paths and session settings come from configs)
+     vae_session = ort.InferenceSession(VAE_ONNX_PATH, providers=providers, sess_options=session_options)
+
+     unet_session = ort.InferenceSession(UNET_ONNX_PATH, providers=providers, sess_options=session_options, provider_options=provider_options_1)
+     tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH)
+     tokenizer2 = CLIPTokenizer.from_pretrained(TOKENIZER_PATH2)
+     text_encoder_session = ort.InferenceSession(TEXT_ENCODER_PATH, providers=providers, sess_options=session_options)
+     text_encoder_session2 = ort.InferenceSession(TEXT_ENCODER_PATH2, providers=providers, sess_options=session_options)
+     scheduler = DDPMScheduler.from_pretrained(SCHEDULER_PATH)
+
+     controlnet = ort.InferenceSession(CONTROLNEXT_ONNX_PATH, providers=providers, sess_options=session_options)
+     # image_encoder = ort.InferenceSession(IMAGE_ENCODER_ONNX_PATH, providers=providers, provider_options=provider_options_0)
+     image_encoder = CLIPVisionModelWithProjection.from_pretrained('h94/IP-Adapter', subfolder='sdxl_models/image_encoder').to(device, dtype=torch.float32)
+     image_proj = ort.InferenceSession(PROJ_ONNX_PATH, providers=providers, sess_options=session_options)
+
+     log_validation(
+         vae=vae_session,
+         scheduler=scheduler,
+         text_encoder=text_encoder_session,
+         tokenizer=tokenizer,
+         unet=unet_session,
+         controlnet=controlnet,
+         image_encoder=image_encoder,
+         args=args,
+         device=device,
+         image_proj=image_proj,
+         text_encoder2=text_encoder_session2,
+         tokenizer2=tokenizer2
+     )
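
For reference, a minimal sketch of how the script's CLI options fit together, exercising the `parse_args(input_args=...)` hook defined above; every file path below is a placeholder for illustration, not a file shipped in this repository.

    # Hypothetical invocation via parse_args(); all paths are placeholders.
    example_args = parse_args([
        "--output_dir", "outputs/controlnext",
        "--pil_image", "examples/ip_adapter_reference.png",   # reference image for the IP-Adapter branch
        "--validation_image", "examples/pose_condition.png",  # ControlNeXt conditioning image
        "--validation_prompt", "a portrait photo, studio lighting",
        "--negative_prompt", "low quality, blurry",
        "--num_inference_steps", "30",
        "--controlnext_scale", "2.5",
        "--guidance_scale", "7.5",
        "--width", "1024",
        "--height", "1024",
    ])

The same flags can be passed on the command line when running the script directly, since `parse_args()` falls back to `sys.argv` when `input_args` is None.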