import torch
from PIL import Image
from datasets import load_dataset
from torchvision import transforms
import random
import os
import numpy as np
Image.MAX_IMAGE_PIXELS = None
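
# The training metadata is expected to be a JSON / JSONL file (passed via
# ``args.train_data_dir``) whose records provide a caption column, a source
# (condition) image-path column and a target image-path column. The column
# names can be set explicitly with ``args.caption_column`` /
# ``args.source_column`` / ``args.target_column``; otherwise they default to
# positional columns of the file.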

def make_train_dataset(args, tokenizer, accelerator=None):
    if args.train_data_dir is not None:
        print("load_data")
        dataset = load_dataset('json', data_files=args.train_data_dir)

    column_names = dataset["train"].column_names

    # Get the column names for input/target.
    if args.caption_column is None:
        caption_column = column_names[0]
        print(f"caption column defaulting to {caption_column}")
    else:
        caption_column = args.caption_column
        if caption_column not in column_names:
            raise ValueError(
                f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
            )

    if args.source_column is None:
        source_column = column_names[1]
        print(f"source column defaulting to {source_column}")
    else:
        source_column = args.source_column
        if source_column not in column_names:
            raise ValueError(
                f"`--source_column` value '{args.source_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
            )

    if args.target_column is None:
        # Fall back to the column after the source column (caption, source, target).
        target_column = column_names[2]
        print(f"target column defaulting to {target_column}")
    else:
        target_column = args.target_column
        if target_column not in column_names:
            raise ValueError(
                f"`--target_column` value '{args.target_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
            )

    h = args.height
    w = args.width

    # Resize to the training resolution and map pixel values to [-1, 1].
    train_transforms = transforms.Compose(
        [
            transforms.Resize((h, w), interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )

    tokenizer_clip = tokenizer[0]
    tokenizer_t5 = tokenizer[1]

    def tokenize_prompt_clip_t5(examples):
        captions = []
        for caption in examples[caption_column]:
            if isinstance(caption, str):
                captions.append(caption)
            elif isinstance(caption, list):
                # Pick one caption at random when several are provided per image.
                captions.append(random.choice(caption))
            else:
                raise ValueError(
                    f"Caption column `{caption_column}` should contain either strings or lists of strings."
                )

        # CLIP tokenizer with its fixed 77-token context.
        text_inputs = tokenizer_clip(
            captions,
            padding="max_length",
            max_length=77,
            truncation=True,
            return_length=False,
            return_overflowing_tokens=False,
            return_tensors="pt",
        )
        text_input_ids_1 = text_inputs.input_ids

        # T5 tokenizer with a 512-token context.
        text_inputs = tokenizer_t5(
            captions,
            padding="max_length",
            max_length=512,
            truncation=True,
            return_length=False,
            return_overflowing_tokens=False,
            return_tensors="pt",
        )
        text_input_ids_2 = text_inputs.input_ids

        return text_input_ids_1, text_input_ids_2
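
    # Applied lazily through ``with_transform``: opens the source / target images
    # from their paths, converts them to normalized tensors and tokenizes the
    # captions for both text encoders.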
    def preprocess_train(examples):
        _examples = {}

        source_images = [Image.open(image).convert("RGB") for image in examples[source_column]]
        target_images = [Image.open(image).convert("RGB") for image in examples[target_column]]

        _examples["cond_pixel_values"] = [train_transforms(source) for source in source_images]
        _examples["pixel_values"] = [train_transforms(image) for image in target_images]
        _examples["token_ids_clip"], _examples["token_ids_t5"] = tokenize_prompt_clip_t5(examples)

        return _examples

    if accelerator is not None:
        with accelerator.main_process_first():
            train_dataset = dataset["train"].with_transform(preprocess_train)
    else:
        train_dataset = dataset["train"].with_transform(preprocess_train)

    return train_dataset
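

# Batches individual dataset examples for a ``torch.utils.data.DataLoader``:
# stacks the condition / target image tensors and the CLIP / T5 token ids,
# exposing the token ids as ``text_ids_1`` / ``text_ids_2``.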
def collate_fn(examples):
    cond_pixel_values = torch.stack([example["cond_pixel_values"] for example in examples])
    cond_pixel_values = cond_pixel_values.to(memory_format=torch.contiguous_format).float()

    target_pixel_values = torch.stack([example["pixel_values"] for example in examples])
    target_pixel_values = target_pixel_values.to(memory_format=torch.contiguous_format).float()

    token_ids_clip = torch.stack([torch.tensor(example["token_ids_clip"]) for example in examples])
    token_ids_t5 = torch.stack([torch.tensor(example["token_ids_t5"]) for example in examples])

    return {
        "cond_pixel_values": cond_pixel_values,
        "pixel_values": target_pixel_values,
        "text_ids_1": token_ids_clip,
        "text_ids_2": token_ids_t5,
    }
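

# Minimal usage sketch, not part of the training pipeline: the metadata path,
# column names, resolution and tokenizer checkpoints below are assumptions
# chosen only to illustrate how ``make_train_dataset`` and ``collate_fn`` fit
# into a DataLoader.
if __name__ == "__main__":
    from argparse import Namespace

    from torch.utils.data import DataLoader
    from transformers import CLIPTokenizer, T5TokenizerFast

    args = Namespace(
        train_data_dir="data/train.jsonl",  # hypothetical metadata file
        caption_column="caption",           # hypothetical column names
        source_column="source",
        target_column="target",
        height=512,
        width=512,
    )
    tokenizers = (
        CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14"),
        T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl"),
    )

    train_dataset = make_train_dataset(args, tokenizers)
    train_dataloader = DataLoader(
        train_dataset, batch_size=2, shuffle=True, collate_fn=collate_fn
    )

    batch = next(iter(train_dataloader))
    # Expected shapes: pixel_values / cond_pixel_values -> (2, 3, 512, 512),
    # text_ids_1 -> (2, 77), text_ids_2 -> (2, 512).
    print({k: tuple(v.shape) for k, v in batch.items()})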