import os import torch import cv2 import numpy as np from PIL import Image import argparse from diffusers import DDPMScheduler from pipeline_sdxl_ipadapter import StableDiffusionXLControlNeXtPipeline from transformers import CLIPVisionModelWithProjection from transformers import CLIPTokenizer import onnxruntime as ort from configs import * def log_validation( vae, scheduler, text_encoder, tokenizer, unet, controlnet, args, device, image_proj, text_encoder2, tokenizer2, image_encoder ): if len(args.validation_image) == len(args.validation_prompt): validation_images = args.validation_image validation_prompts = args.validation_prompt elif len(args.validation_image) == 1: validation_images = args.validation_image * len(args.validation_prompt) validation_prompts = args.validation_prompt elif len(args.validation_prompt) == 1: validation_images = args.validation_image validation_prompts = args.validation_prompt * len(args.validation_image) else: raise ValueError( "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`" ) if args.negative_prompt is not None: negative_prompts = args.negative_prompt assert len(validation_prompts) == len(validation_prompts) else: negative_prompts = None inference_ctx = torch.autocast(device) pipeline = StableDiffusionXLControlNeXtPipeline( vae=vae, text_encoder=text_encoder, text_encoder_2=text_encoder2, tokenizer=tokenizer, tokenizer_2=tokenizer2, unet=unet, controlnext=controlnet, scheduler=scheduler, image_encoder=image_encoder, device=device, image_proj=image_proj ) image_logs = [] pil_image = args.pil_image if args.pil_image is not None: pil_image = Image.open(pil_image).convert("RGB") for i, (validation_prompt, validation_image) in enumerate(zip(validation_prompts, validation_images)): validation_image = Image.open(validation_image).convert("RGB") images = [] negative_prompt = negative_prompts[i] if negative_prompts is not None else None for _ in range(args.num_validation_images): with inference_ctx: image = pipeline( prompt=validation_prompt, controlnet_image=validation_image, num_inference_steps=args.num_inference_steps, guidance_rescale = args.guidance_scale, negative_prompt=negative_prompt, ip_adapter_image=pil_image, control_scale=args.controlnext_scale, width = args.width, height=args.height, )[0] images.append(image) image_logs.append( {"validation_image": validation_image.resize((args.width,args.height)), "ip_adapter_image": pil_image.resize((args.width,args.height)), "images": images, "validation_prompt": validation_prompt} ) save_dir_path = args.output_dir if not os.path.exists(save_dir_path): os.makedirs(save_dir_path) for i, log in enumerate(image_logs): images = log["images"] validation_prompt = log["validation_prompt"] ip_adapter_image = log["ip_adapter_image"] validation_image = log["validation_image"] formatted_images = [] formatted_images.append(np.asarray(validation_image)) formatted_images.append(np.asarray(ip_adapter_image)) for image in images: formatted_images.append(np.asarray(image)) for idx, img in enumerate(formatted_images): print(f"Image {idx} shape: {img.shape}") formatted_images = np.concatenate(formatted_images, 1) file_path = os.path.join(save_dir_path, "image_{}.png".format(i)) formatted_images = cv2.cvtColor(formatted_images, cv2.COLOR_BGR2RGB) print("Save images to:", file_path) cv2.imwrite(file_path, formatted_images) return image_logs def parse_args(input_args=None): parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.") parser.add_argument( "--output_dir", type=str, default=None, help="The output directory where the inference result will be written.", ) parser.add_argument( "--pil_image", type=str, default=None, help="IP Adapter image path.", ) parser.add_argument( "--validation_prompt", type=str, default=None, nargs="+", help=( "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`." " Provide either a matching number of `--validation_image`s, a single `--validation_image`" " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s." ), ) parser.add_argument( "--negative_prompt", type=str, default=None, nargs="+", help=( "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`." " Provide either a matching number of `--validation_image`s, a single `--validation_image`" " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s." ), ) parser.add_argument( "--validation_image", type=str, default=None, nargs="+", help=( "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`" " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a" " a single `--validation_prompt` to be used with all `--validation_image`s, or a single" " `--validation_image` that will be used with all `--validation_prompt`s." ), ) parser.add_argument( "--num_validation_images", type=int, default=1, help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair.", ) parser.add_argument( "--num_inference_steps", type=int, default=30, help="Number of steps for inference.", ) parser.add_argument( "--controlnext_scale", type=float, default=2.5, help="ControlNext scale.", ) parser.add_argument( "--guidance_scale", type=float, default=7.5, help="Guidance scale.", ) parser.add_argument( "--height", type=int, default=1024, help="The height of output image.", ) parser.add_argument( "--width", type=int, default=1024, help="The width of output image.", ) if input_args is not None: args = parser.parse_args(input_args) else: args = parser.parse_args() if args.validation_prompt is not None and args.validation_image is None: raise ValueError("`--validation_image` must be set if `--validation_prompt` is set") if args.validation_prompt is None and args.validation_image is not None: raise ValueError("`--validation_prompt` must be set if `--validation_image` is set") if ( args.validation_image is not None and args.validation_prompt is not None and len(args.validation_image) != 1 and len(args.validation_prompt) != 1 and len(args.validation_image) != len(args.validation_prompt) ): raise ValueError( "Must provide either 1 `--validation_image`, 1 `--validation_prompt`," " or the same number of `--validation_prompt`s and `--validation_image`s" ) return args if __name__ == "__main__": args = parse_args() device = 'cuda:0' vae_session = ort.InferenceSession(VAE_ONNX_PATH, providers=providers, sess_options=session_options) unet_session = ort.InferenceSession(UNET_ONNX_PATH, providers=providers, sess_options=session_options, provider_options=provider_options_1) tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) tokenizer2 = CLIPTokenizer.from_pretrained(TOKENIZER_PATH2) text_encoder_session = ort.InferenceSession(TEXT_ENCODER_PATH, providers=providers, sess_options=session_options) text_encoder_session2 = ort.InferenceSession(TEXT_ENCODER_PATH2, providers=providers, sess_options=session_options) scheduler = DDPMScheduler.from_pretrained(SCHEDULER_PATH) controlnet = ort.InferenceSession(CONTROLNEXT_ONNX_PATH, providers=providers, sess_options=session_options) image_encoder = ort.InferenceSession(IMAGE_ENCODER_ONNX_PATH, providers=providers, provider_options=provider_options_1) image_proj = ort.InferenceSession(PROJ_ONNX_PATH, providers=providers, sess_options=session_options) log_validation( vae=vae_session, scheduler=scheduler, text_encoder=text_encoder_session, tokenizer=tokenizer, unet=unet_session, controlnet=controlnet, image_encoder = image_encoder, args=args, device=device, image_proj = image_proj, text_encoder2 = text_encoder_session2, tokenizer2 = tokenizer2 )