# NOTE(review): the six lines that were here ("Bethie's picture", "Run pipeline",
# "27898b7 verified", "raw", "history blame", "9.76 kB") were Hugging Face
# blob-viewer UI residue, not Python — commented out so the module parses.
import os
import torch
import cv2
import numpy as np
from PIL import Image
import argparse
from diffusers import DDPMScheduler
from pipeline_sdxl_ipadapter import StableDiffusionXLControlNeXtPipeline
from transformers import CLIPVisionModelWithProjection
from transformers import CLIPTokenizer
import onnxruntime as ort
from configs import *
def log_validation(
    vae,
    scheduler,
    text_encoder,
    tokenizer,
    unet,
    controlnet,
    args,
    device,
    image_proj,
    text_encoder2,
    tokenizer2,
    image_encoder,
):
    """Run the SDXL ControlNeXt + IP-Adapter pipeline over the validation set
    and save one side-by-side PNG per prompt/image pair to ``args.output_dir``.

    Prompts and conditioning images are broadcast against each other when one
    side has a single entry; ``args.num_validation_images`` samples are
    generated per pair.

    Args:
        vae / unet / controlnet / image_proj: ONNX Runtime sessions (see
            ``__main__``) consumed by the custom pipeline.
        scheduler: diffusers scheduler instance.
        text_encoder / text_encoder2: ONNX Runtime text-encoder sessions.
        tokenizer / tokenizer2: CLIP tokenizers for the two SDXL encoders.
        args: parsed CLI namespace from ``parse_args``.
        device: device string (e.g. ``"cuda:0"``) or ``torch.device``.
        image_encoder: CLIP vision model used for the IP-Adapter image.

    Returns:
        list[dict]: one entry per pair with keys ``validation_image``,
        ``ip_adapter_image`` (``None`` when ``--pil_image`` was not given),
        ``images`` and ``validation_prompt``.
    """
    # Broadcast prompts/images so the two lists end up the same length.
    if len(args.validation_image) == len(args.validation_prompt):
        validation_images = args.validation_image
        validation_prompts = args.validation_prompt
    elif len(args.validation_image) == 1:
        validation_images = args.validation_image * len(args.validation_prompt)
        validation_prompts = args.validation_prompt
    elif len(args.validation_prompt) == 1:
        validation_images = args.validation_image
        validation_prompts = args.validation_prompt * len(args.validation_image)
    else:
        raise ValueError(
            "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`"
        )

    if args.negative_prompt is not None:
        negative_prompts = args.negative_prompt
        # BUG FIX: the original asserted len(validation_prompts) ==
        # len(validation_prompts) — a tautology. Check negative_prompts.
        assert len(validation_prompts) == len(negative_prompts)
    else:
        negative_prompts = None

    # BUG FIX: torch.autocast expects a device *type* ("cuda"/"cpu"), not a
    # fully-qualified device string such as "cuda:0" (which raises).
    device_type = device.type if isinstance(device, torch.device) else str(device).split(":")[0]
    inference_ctx = torch.autocast(device_type)

    pipeline = StableDiffusionXLControlNeXtPipeline(
        vae=vae,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder2,
        tokenizer=tokenizer,
        tokenizer_2=tokenizer2,
        unet=unet,
        controlnext=controlnet,
        scheduler=scheduler,
        image_encoder=image_encoder,
        device=device,
        image_proj=image_proj,
    )

    image_logs = []
    pil_image = args.pil_image
    if pil_image is not None:
        pil_image = Image.open(pil_image).convert("RGB")

    for i, (validation_prompt, validation_image) in enumerate(zip(validation_prompts, validation_images)):
        validation_image = Image.open(validation_image).convert("RGB")
        images = []
        negative_prompt = negative_prompts[i] if negative_prompts is not None else None
        for _ in range(args.num_validation_images):
            with inference_ctx:
                image = pipeline(
                    prompt=validation_prompt,
                    controlnet_image=validation_image,
                    num_inference_steps=args.num_inference_steps,
                    # BUG FIX: args.guidance_scale (default 7.5) was passed as
                    # `guidance_rescale`, which is a 0..1 rescaling factor in
                    # diffusers pipelines — it belongs in `guidance_scale`.
                    # NOTE(review): confirm the custom pipeline accepts
                    # `guidance_scale`; standard SDXL pipelines do.
                    guidance_scale=args.guidance_scale,
                    negative_prompt=negative_prompt,
                    ip_adapter_image=pil_image,
                    control_scale=args.controlnext_scale,
                    width=args.width,
                    height=args.height,
                )[0]
            images.append(image)
        image_logs.append(
            {
                "validation_image": validation_image.resize((args.width, args.height)),
                # BUG FIX: guard against --pil_image not being supplied; the
                # original unconditionally called None.resize(...) here.
                "ip_adapter_image": (
                    pil_image.resize((args.width, args.height)) if pil_image is not None else None
                ),
                "images": images,
                "validation_prompt": validation_prompt,
            }
        )

    save_dir_path = args.output_dir
    # exist_ok avoids the check-then-create race of the original code.
    os.makedirs(save_dir_path, exist_ok=True)
    for i, log in enumerate(image_logs):
        formatted_images = [np.asarray(log["validation_image"])]
        if log["ip_adapter_image"] is not None:
            formatted_images.append(np.asarray(log["ip_adapter_image"]))
        for image in log["images"]:
            formatted_images.append(np.asarray(image))
        for idx, img in enumerate(formatted_images):
            print(f"Image {idx} shape: {img.shape}")
        # Tile all panels horizontally into one strip.
        formatted_images = np.concatenate(formatted_images, 1)
        file_path = os.path.join(save_dir_path, "image_{}.png".format(i))
        # PIL arrays are RGB; cv2.imwrite expects BGR. (BGR2RGB and RGB2BGR
        # are the same channel swap for 3-channel images.)
        formatted_images = cv2.cvtColor(formatted_images, cv2.COLOR_BGR2RGB)
        print("Save images to:", file_path)
        cv2.imwrite(file_path, formatted_images)
    return image_logs
def parse_args(input_args=None):
    """Parse and validate command-line arguments for the inference run.

    Args:
        input_args: optional list of argument strings; when None, argparse
            reads ``sys.argv`` as usual.

    Returns:
        argparse.Namespace with the parsed options.

    Raises:
        ValueError: when ``--validation_prompt`` / ``--validation_image`` are
            not both given, or their counts cannot be broadcast (neither is 1
            and they differ).
    """
    # BUG FIX: the description claimed this was a "ControlNet training
    # script", but this file only performs inference — --help was misleading.
    parser = argparse.ArgumentParser(description="Simple example of a ControlNeXt inference script.")
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="The output directory where the inference result will be written.",
    )
    parser.add_argument(
        "--pil_image",
        type=str,
        default=None,
        help="IP Adapter image path.",
    )
    parser.add_argument(
        "--validation_prompt",
        type=str,
        default=None,
        nargs="+",
        help=(
            "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
            " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
            " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
        ),
    )
    parser.add_argument(
        "--negative_prompt",
        type=str,
        default=None,
        nargs="+",
        help=(
            "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
            " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
            " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
        ),
    )
    parser.add_argument(
        "--validation_image",
        type=str,
        default=None,
        nargs="+",
        help=(
            "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
            " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
            " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
            " `--validation_image` that will be used with all `--validation_prompt`s."
        ),
    )
    parser.add_argument(
        "--num_validation_images",
        type=int,
        default=1,
        help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair.",
    )
    parser.add_argument(
        "--num_inference_steps",
        type=int,
        default=30,
        help="Number of steps for inference.",
    )
    parser.add_argument(
        "--controlnext_scale",
        type=float,
        default=2.5,
        help="ControlNext scale.",
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=7.5,
        help="Guidance scale.",
    )
    parser.add_argument(
        "--height",
        type=int,
        default=1024,
        help="The height of output image.",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=1024,
        help="The width of output image.",
    )

    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
        args = parser.parse_args()

    # Prompts and images must both be present...
    if args.validation_prompt is not None and args.validation_image is None:
        raise ValueError("`--validation_image` must be set if `--validation_prompt` is set")
    if args.validation_prompt is None and args.validation_image is not None:
        raise ValueError("`--validation_prompt` must be set if `--validation_image` is set")
    # ...and their counts must be broadcastable (equal, or one side is 1).
    if (
        args.validation_image is not None
        and args.validation_prompt is not None
        and len(args.validation_image) != 1
        and len(args.validation_prompt) != 1
        and len(args.validation_image) != len(args.validation_prompt)
    ):
        raise ValueError(
            "Must provide either 1 `--validation_image`, 1 `--validation_prompt`,"
            " or the same number of `--validation_prompt`s and `--validation_image`s"
        )
    return args
if __name__ == "__main__":
    args = parse_args()
    device = 'cuda:0'

    # Tokenizers and the scheduler load from regular HF checkpoints.
    tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH)
    tokenizer2 = CLIPTokenizer.from_pretrained(TOKENIZER_PATH2)
    scheduler = DDPMScheduler.from_pretrained(SCHEDULER_PATH)

    # Heavy model components run as ONNX Runtime sessions; providers,
    # session_options and provider_options_* come from configs.
    vae_session = ort.InferenceSession(VAE_ONNX_PATH, providers=providers, sess_options=session_options)
    unet_session = ort.InferenceSession(UNET_ONNX_PATH, providers=providers, sess_options=session_options, provider_options=provider_options_1)
    text_encoder_session = ort.InferenceSession(TEXT_ENCODER_PATH, providers=providers, sess_options=session_options)
    text_encoder_session2 = ort.InferenceSession(TEXT_ENCODER_PATH2, providers=providers, sess_options=session_options)
    controlnet = ort.InferenceSession(CONTROLNEXT_ONNX_PATH, providers=providers, sess_options=session_options)
    image_proj = ort.InferenceSession(PROJ_ONNX_PATH, providers=providers, sess_options=session_options)

    # The CLIP image encoder stays a PyTorch module; an ONNX session variant
    # exists but is disabled:
    # image_encoder = ort.InferenceSession(IMAGE_ENCODER_ONNX_PATH, providers=providers, provider_options=provider_options_0)
    image_encoder = CLIPVisionModelWithProjection.from_pretrained('h94/IP-Adapter', subfolder='sdxl_models/image_encoder').to(device, dtype=torch.float32)

    log_validation(
        vae=vae_session,
        scheduler=scheduler,
        text_encoder=text_encoder_session,
        tokenizer=tokenizer,
        unet=unet_session,
        controlnet=controlnet,
        image_encoder=image_encoder,
        args=args,
        device=device,
        image_proj=image_proj,
        text_encoder2=text_encoder_session2,
        tokenizer2=tokenizer2,
    )