# Bethie's picture
# Run pipeline
# 74eeb7f verified
import os
import torch
import cv2
import numpy as np
from PIL import Image
import argparse
from diffusers import DDPMScheduler
from pipeline_sdxl_ipadapter import StableDiffusionXLControlNeXtPipeline
from transformers import CLIPVisionModelWithProjection
from transformers import CLIPTokenizer
import onnxruntime as ort
from configs import *
def log_validation(
    vae,
    scheduler,
    text_encoder,
    tokenizer,
    unet,
    controlnet,
    args,
    device,
    image_proj,
    text_encoder2,
    tokenizer2,
    image_encoder
):
    """Run the SDXL ControlNeXt + IP-Adapter pipeline over the validation
    prompt/image pairs and save side-by-side result grids to ``args.output_dir``.

    Parameters mirror the components built in ``__main__`` (ONNX Runtime
    sessions for the neural parts, HF tokenizers and scheduler).

    Returns:
        list[dict]: one entry per (prompt, image) pair with keys
        ``validation_image``, ``ip_adapter_image``, ``images`` and
        ``validation_prompt``.

    Raises:
        ValueError: if prompt/image counts cannot be broadcast against
            each other (should already be caught in ``parse_args``).
    """
    # Broadcast a single prompt/image against the other list, or require
    # matching lengths.
    if len(args.validation_image) == len(args.validation_prompt):
        validation_images = args.validation_image
        validation_prompts = args.validation_prompt
    elif len(args.validation_image) == 1:
        validation_images = args.validation_image * len(args.validation_prompt)
        validation_prompts = args.validation_prompt
    elif len(args.validation_prompt) == 1:
        validation_images = args.validation_image
        validation_prompts = args.validation_prompt * len(args.validation_image)
    else:
        raise ValueError(
            "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`"
        )

    if args.negative_prompt is not None:
        negative_prompts = args.negative_prompt
        # BUGFIX: the original asserted len(validation_prompts) ==
        # len(validation_prompts), which is always true; the intent is to
        # check the negative prompts line up with the positive ones.
        assert len(negative_prompts) == len(validation_prompts)
    else:
        negative_prompts = None

    # BUGFIX: torch.autocast expects a device *type* ("cuda"/"cpu"), not an
    # index-qualified string like "cuda:0"; strip the device index.
    inference_ctx = torch.autocast(device.split(":")[0])

    pipeline = StableDiffusionXLControlNeXtPipeline(
        vae=vae,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder2,
        tokenizer=tokenizer,
        tokenizer_2=tokenizer2,
        unet=unet,
        controlnext=controlnet,
        scheduler=scheduler,
        image_encoder=image_encoder,
        device=device,
        image_proj=image_proj
    )

    image_logs = []
    # Optional IP-Adapter reference image (--pil_image may be omitted).
    pil_image = args.pil_image
    if args.pil_image is not None:
        pil_image = Image.open(pil_image).convert("RGB")

    for i, (validation_prompt, validation_image) in enumerate(zip(validation_prompts, validation_images)):
        validation_image = Image.open(validation_image).convert("RGB")
        images = []
        negative_prompt = negative_prompts[i] if negative_prompts is not None else None
        for _ in range(args.num_validation_images):
            with inference_ctx:
                image = pipeline(
                    prompt=validation_prompt,
                    controlnet_image=validation_image,
                    num_inference_steps=args.num_inference_steps,
                    # BUGFIX: args.guidance_scale (CFG scale, default 7.5) was
                    # passed as `guidance_rescale`, which expects a 0..1
                    # rescale factor; route it to `guidance_scale` instead.
                    guidance_scale=args.guidance_scale,
                    negative_prompt=negative_prompt,
                    ip_adapter_image=pil_image,
                    control_scale=args.controlnext_scale,
                    width=args.width,
                    height=args.height,
                )[0]
            images.append(image)
        image_logs.append(
            {"validation_image": validation_image.resize((args.width, args.height)),
             # BUGFIX: the original called None.resize(...) when --pil_image
             # was not supplied; keep None in the log instead.
             "ip_adapter_image": pil_image.resize((args.width, args.height)) if pil_image is not None else None,
             "images": images, "validation_prompt": validation_prompt}
        )

    save_dir_path = args.output_dir
    os.makedirs(save_dir_path, exist_ok=True)

    for i, log in enumerate(image_logs):
        images = log["images"]
        validation_prompt = log["validation_prompt"]
        ip_adapter_image = log["ip_adapter_image"]
        validation_image = log["validation_image"]

        # Horizontal strip: [condition | ip-adapter (if any) | generations...]
        formatted_images = [np.asarray(validation_image)]
        if ip_adapter_image is not None:
            formatted_images.append(np.asarray(ip_adapter_image))
        for image in images:
            formatted_images.append(np.asarray(image))
        formatted_images = np.concatenate(formatted_images, 1)

        file_path = os.path.join(save_dir_path, "image_{}.png".format(i))
        # cv2.imwrite expects BGR channel order; the arrays above are RGB.
        formatted_images = cv2.cvtColor(formatted_images, cv2.COLOR_BGR2RGB)
        print("Save images to:", file_path)
        cv2.imwrite(file_path, formatted_images)
    return image_logs
def parse_args(input_args=None):
    """Build and cross-validate the CLI options for the inference script.

    Args:
        input_args: optional explicit argument list; ``None`` means parse
            ``sys.argv`` as usual.

    Returns:
        argparse.Namespace holding the validated options.

    Raises:
        ValueError: when the prompt/image lists cannot be broadcast
            against each other.
    """
    cli = argparse.ArgumentParser(description="Simple example of a ControlNet training script.")
    cli.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="The output directory where the inference result will be written.",
    )
    cli.add_argument(
        "--pil_image",
        type=str,
        default=None,
        help="IP Adapter image path.",
    )
    cli.add_argument(
        "--validation_prompt",
        type=str,
        default=None,
        nargs="+",
        help=(
            "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
            " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
            " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
        ),
    )
    cli.add_argument(
        "--negative_prompt",
        type=str,
        default=None,
        nargs="+",
        help=(
            "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
            " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
            " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
        ),
    )
    cli.add_argument(
        "--validation_image",
        type=str,
        default=None,
        nargs="+",
        help=(
            "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
            " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
            " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
            " `--validation_image` that will be used with all `--validation_prompt`s."
        ),
    )
    cli.add_argument(
        "--num_validation_images",
        type=int,
        default=1,
        help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair.",
    )
    cli.add_argument(
        "--num_inference_steps",
        type=int,
        default=30,
        help="Number of steps for inference.",
    )
    cli.add_argument(
        "--controlnext_scale",
        type=float,
        default=2.5,
        help="ControlNext scale.",
    )
    cli.add_argument(
        "--guidance_scale",
        type=float,
        default=7.5,
        help="Guidance scale.",
    )
    cli.add_argument(
        "--height",
        type=int,
        default=1024,
        help="The height of output image.",
    )
    cli.add_argument(
        "--width",
        type=int,
        default=1024,
        help="The width of output image.",
    )

    # argparse treats an explicit None exactly like a missing argument list,
    # so a single call covers both entry styles.
    args = cli.parse_args(input_args)

    has_prompts = args.validation_prompt is not None
    has_images = args.validation_image is not None
    if has_prompts and not has_images:
        raise ValueError("`--validation_image` must be set if `--validation_prompt` is set")
    if has_images and not has_prompts:
        raise ValueError("`--validation_prompt` must be set if `--validation_image` is set")
    if has_prompts and has_images:
        n_prompts = len(args.validation_prompt)
        n_images = len(args.validation_image)
        # Allowed shapes: equal lengths, or one side broadcast from length 1.
        if n_prompts != 1 and n_images != 1 and n_prompts != n_images:
            raise ValueError(
                "Must provide either 1 `--validation_image`, 1 `--validation_prompt`,"
                " or the same number of `--validation_prompt`s and `--validation_image`s"
            )
    return args
if __name__ == "__main__":
    cli_args = parse_args()
    run_device = 'cuda:0'

    # Tokenizers and the noise scheduler load from their Hugging Face configs.
    clip_tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH)
    clip_tokenizer_2 = CLIPTokenizer.from_pretrained(TOKENIZER_PATH2)
    ddpm_scheduler = DDPMScheduler.from_pretrained(SCHEDULER_PATH)

    # Every neural component runs through ONNX Runtime; the UNet and image
    # encoder additionally carry per-provider options (paths/options come
    # from `configs`).
    vae_sess = ort.InferenceSession(VAE_ONNX_PATH, providers=providers, sess_options=session_options)
    unet_sess = ort.InferenceSession(UNET_ONNX_PATH, providers=providers, sess_options=session_options, provider_options=provider_options_1)
    te1_sess = ort.InferenceSession(TEXT_ENCODER_PATH, providers=providers, sess_options=session_options)
    te2_sess = ort.InferenceSession(TEXT_ENCODER_PATH2, providers=providers, sess_options=session_options)
    controlnext_sess = ort.InferenceSession(CONTROLNEXT_ONNX_PATH, providers=providers, sess_options=session_options)
    image_encoder_sess = ort.InferenceSession(IMAGE_ENCODER_ONNX_PATH, providers=providers, provider_options=provider_options_1)
    image_proj_sess = ort.InferenceSession(PROJ_ONNX_PATH, providers=providers, sess_options=session_options)

    log_validation(
        vae=vae_sess,
        scheduler=ddpm_scheduler,
        text_encoder=te1_sess,
        tokenizer=clip_tokenizer,
        unet=unet_sess,
        controlnet=controlnext_sess,
        image_encoder=image_encoder_sess,
        args=cli_args,
        device=run_device,
        image_proj=image_proj_sess,
        text_encoder2=te2_sess,
        tokenizer2=clip_tokenizer_2,
    )