import torch
from PIL import Image
from datasets import load_dataset
from torchvision import transforms
import random
import os
import numpy as np

# Disable PIL's decompression-bomb check so very large training images can be opened.
Image.MAX_IMAGE_PIXELS = None


def make_train_dataset(args, tokenizer, accelerator=None):
    """Build the training dataset: resolve column names, tokenize captions with the
    CLIP and T5 tokenizers, and attach an on-the-fly image/text preprocessing transform."""
    if args.train_data_dir is not None:
        print("load_data")
        dataset = load_dataset("json", data_files=args.train_data_dir)

    column_names = dataset["train"].column_names

    # Resolve the caption/source/target column names (assumed layout: caption, source, target).
    if args.caption_column is None:
        caption_column = column_names[0]
        print(f"caption column defaulting to {caption_column}")
    else:
        caption_column = args.caption_column
        if caption_column not in column_names:
            raise ValueError(
                f"`--caption_column` value '{args.caption_column}' not found in dataset columns. "
                f"Dataset columns are: {', '.join(column_names)}"
            )

    if args.source_column is None:
        source_column = column_names[1]
        print(f"source column defaulting to {source_column}")
    else:
        source_column = args.source_column
        if source_column not in column_names:
            raise ValueError(
                f"`--source_column` value '{args.source_column}' not found in dataset columns. "
                f"Dataset columns are: {', '.join(column_names)}"
            )

    if args.target_column is None:
        target_column = column_names[2]
        print(f"target column defaulting to {target_column}")
    else:
        target_column = args.target_column
        if target_column not in column_names:
            raise ValueError(
                f"`--target_column` value '{args.target_column}' not found in dataset columns. "
                f"Dataset columns are: {', '.join(column_names)}"
            )

    h = args.height
    w = args.width

    # Resize to the training resolution and normalize pixel values to [-1, 1].
    train_transforms = transforms.Compose(
        [
            transforms.Resize((h, w), interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )

    tokenizer_clip = tokenizer[0]
    tokenizer_t5 = tokenizer[1]

    def tokenize_prompt_clip_t5(examples):
        # Each caption entry may be a single string or a list of strings; lists are sampled from.
        captions = []
        for caption in examples[caption_column]:
            if isinstance(caption, str):
                captions.append(caption)
            elif isinstance(caption, list):
                captions.append(random.choice(caption))
            else:
                raise ValueError(
                    f"Caption column `{caption_column}` should contain either strings or lists of strings."
                )

        # CLIP tokenizer: fixed 77-token context.
        text_inputs = tokenizer_clip(
            captions,
            padding="max_length",
            max_length=77,
            truncation=True,
            return_length=False,
            return_overflowing_tokens=False,
            return_tensors="pt",
        )
        text_input_ids_1 = text_inputs.input_ids

        # T5 tokenizer: 512-token context.
        text_inputs = tokenizer_t5(
            captions,
            padding="max_length",
            max_length=512,
            truncation=True,
            return_length=False,
            return_overflowing_tokens=False,
            return_tensors="pt",
        )
        text_input_ids_2 = text_inputs.input_ids

        return text_input_ids_1, text_input_ids_2

    def preprocess_train(examples):
        _examples = {}

        source_images = [Image.open(image).convert("RGB") for image in examples[source_column]]
        target_images = [Image.open(image).convert("RGB") for image in examples[target_column]]

        _examples["cond_pixel_values"] = [train_transforms(source) for source in source_images]
        _examples["pixel_values"] = [train_transforms(image) for image in target_images]
        _examples["token_ids_clip"], _examples["token_ids_t5"] = tokenize_prompt_clip_t5(examples)

        return _examples

    # Apply the preprocessing lazily; run it on the main process first when using Accelerate.
    if accelerator is not None:
        with accelerator.main_process_first():
            train_dataset = dataset["train"].with_transform(preprocess_train)
    else:
        train_dataset = dataset["train"].with_transform(preprocess_train)

    return train_dataset


def collate_fn(examples):
    """Stack per-example tensors into a training batch."""
    cond_pixel_values = torch.stack([example["cond_pixel_values"] for example in examples])
    cond_pixel_values = cond_pixel_values.to(memory_format=torch.contiguous_format).float()

    target_pixel_values = torch.stack([example["pixel_values"] for example in examples])
    target_pixel_values = target_pixel_values.to(memory_format=torch.contiguous_format).float()

    token_ids_clip = torch.stack([torch.tensor(example["token_ids_clip"]) for example in examples])
    token_ids_t5 = torch.stack([torch.tensor(example["token_ids_t5"]) for example in examples])

    return {
        "cond_pixel_values": cond_pixel_values,
        "pixel_values": target_pixel_values,
        "text_ids_1": token_ids_clip,
        "text_ids_2": token_ids_t5,
    }
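

# Usage sketch (a minimal example, not part of the original module): everything below is
# illustrative. It assumes an `args` namespace exposing train_data_dir, caption_column,
# source_column, target_column, height, and width, and that the (CLIP, T5) tokenizer pair
# comes from a FLUX-style checkpoint; the checkpoint name and batch size are placeholders.
if __name__ == "__main__":
    import argparse

    from torch.utils.data import DataLoader
    from transformers import CLIPTokenizer, T5TokenizerFast

    parser = argparse.ArgumentParser()
    parser.add_argument("--train_data_dir", type=str, required=True)
    parser.add_argument("--caption_column", type=str, default=None)
    parser.add_argument("--source_column", type=str, default=None)
    parser.add_argument("--target_column", type=str, default=None)
    parser.add_argument("--height", type=int, default=512)
    parser.add_argument("--width", type=int, default=512)
    args = parser.parse_args()

    # Tokenizer pair expected by make_train_dataset: index 0 is CLIP, index 1 is T5.
    tokenizer_clip = CLIPTokenizer.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="tokenizer")
    tokenizer_t5 = T5TokenizerFast.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="tokenizer_2")

    train_dataset = make_train_dataset(args, (tokenizer_clip, tokenizer_t5))
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=2,
        shuffle=True,
        collate_fn=collate_fn,
    )

    # Pull one batch and inspect the tensor shapes.
    batch = next(iter(train_dataloader))
    print(batch["cond_pixel_values"].shape)  # (2, 3, height, width)
    print(batch["pixel_values"].shape)       # (2, 3, height, width)
    print(batch["text_ids_1"].shape)         # (2, 77)
    print(batch["text_ids_2"].shape)         # (2, 512)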