import argparse
import os

import cv2
import numpy as np
import onnxruntime as ort
import torch
from PIL import Image

from diffusers import DDPMScheduler
from transformers import CLIPTokenizer, CLIPVisionModelWithProjection

from pipeline_sdxl_ipadapter import StableDiffusionXLControlNeXtPipeline
from configs import *
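
# `configs` is expected to provide the ONNX model paths (VAE_ONNX_PATH,
# UNET_ONNX_PATH, TEXT_ENCODER_PATH, TEXT_ENCODER_PATH2, CONTROLNEXT_ONNX_PATH,
# PROJ_ONNX_PATH), the tokenizer/scheduler paths, and the ONNX Runtime settings
# (`providers`, `session_options`, `provider_options_1`) used below.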
|
|
def log_validation(
    vae,
    scheduler,
    text_encoder,
    tokenizer,
    unet,
    controlnet,
    args,
    device,
    image_proj,
    text_encoder2,
    tokenizer2,
    image_encoder,
):
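    """Run the ControlNeXt + IP-Adapter SDXL pipeline on every validation
    prompt/image pair and save the stitched results to `args.output_dir`."""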
|
    # Pair validation images with prompts, broadcasting whichever side has a
    # single entry across the other.
    if len(args.validation_image) == len(args.validation_prompt):
        validation_images = args.validation_image
        validation_prompts = args.validation_prompt
    elif len(args.validation_image) == 1:
        validation_images = args.validation_image * len(args.validation_prompt)
        validation_prompts = args.validation_prompt
    elif len(args.validation_prompt) == 1:
        validation_images = args.validation_image
        validation_prompts = args.validation_prompt * len(args.validation_image)
    else:
        raise ValueError(
            "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`"
        )

    if args.negative_prompt is not None:
        negative_prompts = args.negative_prompt
        # One negative prompt is expected per validation prompt (the original
        # assert compared `validation_prompts` to itself, a no-op).
        assert len(negative_prompts) == len(validation_prompts), "Provide one negative prompt per validation prompt."
    else:
        negative_prompts = None

    # `torch.autocast` expects a device type such as "cuda", not a device
    # index such as "cuda:0".
    inference_ctx = torch.autocast(device.split(":")[0])
|
|
    pipeline = StableDiffusionXLControlNeXtPipeline(
        vae=vae,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder2,
        tokenizer=tokenizer,
        tokenizer_2=tokenizer2,
        unet=unet,
        controlnext=controlnet,
        scheduler=scheduler,
        image_encoder=image_encoder,
        device=device,
        image_proj=image_proj,
    )
|
|
    image_logs = []

    pil_image = args.pil_image
    if pil_image is not None:
        pil_image = Image.open(pil_image).convert("RGB")

    for i, (validation_prompt, validation_image) in enumerate(zip(validation_prompts, validation_images)):
        validation_image = Image.open(validation_image).convert("RGB")

        images = []
        negative_prompt = negative_prompts[i] if negative_prompts is not None else None

        for _ in range(args.num_validation_images):
            with inference_ctx:
                image = pipeline(
                    prompt=validation_prompt,
                    controlnet_image=validation_image,
                    num_inference_steps=args.num_inference_steps,
                    # `args.guidance_scale` was previously passed as
                    # `guidance_rescale`, which silently dropped the CFG scale.
                    guidance_scale=args.guidance_scale,
                    negative_prompt=negative_prompt,
                    ip_adapter_image=pil_image,
                    control_scale=args.controlnext_scale,
                    width=args.width,
                    height=args.height,
                )[0]

            images.append(image)

        image_logs.append(
            {
                "validation_image": validation_image.resize((args.width, args.height)),
                # `--pil_image` is optional, so guard against a missing reference.
                "ip_adapter_image": pil_image.resize((args.width, args.height)) if pil_image is not None else None,
                "images": images,
                "validation_prompt": validation_prompt,
            }
        )
|
|
    save_dir_path = args.output_dir
    os.makedirs(save_dir_path, exist_ok=True)

    for i, log in enumerate(image_logs):
        images = log["images"]
        validation_prompt = log["validation_prompt"]
        ip_adapter_image = log["ip_adapter_image"]
        validation_image = log["validation_image"]

        # Stitch the conditioning image, the IP-Adapter reference (if any), and
        # the generated samples into one horizontal strip.
        formatted_images = [np.asarray(validation_image)]
        if ip_adapter_image is not None:
            formatted_images.append(np.asarray(ip_adapter_image))
        for image in images:
            formatted_images.append(np.asarray(image))

        for idx, img in enumerate(formatted_images):
            print(f"Image {idx} shape: {img.shape}")

        formatted_images = np.concatenate(formatted_images, axis=1)

        # PIL arrays are RGB while OpenCV writes BGR, so swap channels before saving.
        file_path = os.path.join(save_dir_path, "image_{}.png".format(i))
        formatted_images = cv2.cvtColor(formatted_images, cv2.COLOR_RGB2BGR)
        print("Save images to:", file_path)
        cv2.imwrite(file_path, formatted_images)

    return image_logs
|
|
def parse_args(input_args=None):
    parser = argparse.ArgumentParser(description="Inference script for SDXL with ControlNeXt and IP-Adapter.")

    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="The output directory where the inference results will be written.",
    )
    parser.add_argument(
        "--pil_image",
        type=str,
        default=None,
        help="IP-Adapter reference image path.",
    )
    parser.add_argument(
        "--validation_prompt",
        type=str,
        default=None,
        nargs="+",
        help=(
            "A set of prompts to run inference on. Provide either a matching number of"
            " `--validation_image`s, a single `--validation_image` to be used with all prompts,"
            " or a single prompt that will be used with all `--validation_image`s."
        ),
    )
    parser.add_argument(
        "--negative_prompt",
        type=str,
        default=None,
        nargs="+",
        help=(
            "A set of negative prompts used during inference. If provided, the number of"
            " negative prompts must match the number of validation prompts."
        ),
    )
    parser.add_argument(
        "--validation_image",
        type=str,
        default=None,
        nargs="+",
        help=(
            "A set of paths to the ControlNeXt conditioning images. Provide either a matching"
            " number of `--validation_prompt`s, a single `--validation_prompt` to be used with"
            " all `--validation_image`s, or a single `--validation_image` that will be used"
            " with all `--validation_prompt`s."
        ),
    )
    parser.add_argument(
        "--num_validation_images",
        type=int,
        default=1,
        help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair.",
    )
    parser.add_argument(
        "--num_inference_steps",
        type=int,
        default=30,
        help="Number of denoising steps for inference.",
    )
    parser.add_argument(
        "--controlnext_scale",
        type=float,
        default=2.5,
        help="ControlNeXt conditioning scale.",
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=7.5,
        help="Classifier-free guidance scale.",
    )
    parser.add_argument(
        "--height",
        type=int,
        default=1024,
        help="The height of the output image.",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=1024,
        help="The width of the output image.",
    )
|
|
    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
        args = parser.parse_args()

    if args.validation_prompt is not None and args.validation_image is None:
        raise ValueError("`--validation_image` must be set if `--validation_prompt` is set")

    if args.validation_prompt is None and args.validation_image is not None:
        raise ValueError("`--validation_prompt` must be set if `--validation_image` is set")

    if (
        args.validation_image is not None
        and args.validation_prompt is not None
        and len(args.validation_image) != 1
        and len(args.validation_prompt) != 1
        and len(args.validation_image) != len(args.validation_prompt)
    ):
        raise ValueError(
            "Must provide either 1 `--validation_image`, 1 `--validation_prompt`,"
            " or the same number of `--validation_prompt`s and `--validation_image`s"
        )

    return args
|
|
if __name__ == "__main__":
    args = parse_args()

    device = "cuda:0"
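
    # Every heavy submodule (VAE, UNet, text encoders, ControlNeXt, IP-Adapter
    # projection) runs through an ONNX Runtime session; only the CLIP image
    # encoder runs as a PyTorch module.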
|
|
    vae_session = ort.InferenceSession(VAE_ONNX_PATH, providers=providers, sess_options=session_options)
    unet_session = ort.InferenceSession(UNET_ONNX_PATH, providers=providers, sess_options=session_options, provider_options=provider_options_1)
    tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH)
    tokenizer2 = CLIPTokenizer.from_pretrained(TOKENIZER_PATH2)
    text_encoder_session = ort.InferenceSession(TEXT_ENCODER_PATH, providers=providers, sess_options=session_options)
    text_encoder_session2 = ort.InferenceSession(TEXT_ENCODER_PATH2, providers=providers, sess_options=session_options)
    scheduler = DDPMScheduler.from_pretrained(SCHEDULER_PATH)
    controlnet = ort.InferenceSession(CONTROLNEXT_ONNX_PATH, providers=providers, sess_options=session_options)

    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
        "h94/IP-Adapter", subfolder="sdxl_models/image_encoder"
    ).to(device, dtype=torch.float32)
    image_proj = ort.InferenceSession(PROJ_ONNX_PATH, providers=providers, sess_options=session_options)
|
|
    log_validation(
        vae=vae_session,
        scheduler=scheduler,
        text_encoder=text_encoder_session,
        tokenizer=tokenizer,
        unet=unet_session,
        controlnet=controlnet,
        image_encoder=image_encoder,
        args=args,
        device=device,
        image_proj=image_proj,
        text_encoder2=text_encoder_session2,
        tokenizer2=tokenizer2,
    )
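
# Example invocation (the script filename and all paths below are illustrative
# placeholders, not files shipped with this repo):
#
#   python inference_sdxl_ipadapter.py \
#       --output_dir ./outputs \
#       --pil_image ./style_reference.png \
#       --validation_image ./pose_condition.png \
#       --validation_prompt "a photo of a dancer, best quality" \
#       --num_validation_images 2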