import argparse
import itertools
import math
import os
import random

import numpy as np
import PIL
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.utils.data import Dataset

from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

# Base Stable Diffusion checkpoint to teach the new concept to.
pretrained_model_name_or_path = "stabilityai/stable-diffusion-2"

# Example images of the concept (a cat toy) to learn from.
urls = [
    "https://huggingface.co/datasets/valhalla/images/resolve/main/2.jpeg",
    "https://huggingface.co/datasets/valhalla/images/resolve/main/3.jpeg",
    "https://huggingface.co/datasets/valhalla/images/resolve/main/5.jpeg",
    "https://huggingface.co/datasets/valhalla/images/resolve/main/6.jpeg",
]
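

# `save_path` is needed by the dataset construction below but was never defined in
# this snippet; this sketch fills the gap by downloading the example images into a
# local folder. The helper name `download_image` and the folder "./my_concept" are
# assumptions, not fixed by the surrounding code.
from io import BytesIO

import requests


def download_image(url):
    try:
        response = requests.get(url)
    except requests.exceptions.RequestException:
        return None
    return Image.open(BytesIO(response.content)).convert("RGB")


images = list(filter(None, [download_image(url) for url in urls]))
save_path = "./my_concept"
os.makedirs(save_path, exist_ok=True)
for i, image in enumerate(images):
    image.save(f"{save_path}/{i}.jpeg")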

# Whether the concept is a new "object" or a new "style".
what_to_teach = "object"
# The new token that will stand in for the concept in prompts.
placeholder_token = "<cat-toy>"
# An existing word that roughly describes the concept; its embedding seeds the
# placeholder token's embedding.
initializer_token = "toy"


def image_grid(imgs, rows, cols):
    """Paste equally sized PIL images into a single rows x cols grid."""
    assert len(imgs) == rows * cols

    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))

    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
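
# Hypothetical usage (not in the original snippet): preview the training images
# as a single row, assuming the `images` list from the download sketch above.
grid = image_grid(images, 1, len(images))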


# Prompt templates used when teaching a new object.
imagenet_templates_small = [
    "a photo of a {}",
    "a rendering of a {}",
    "a cropped photo of the {}",
    "the photo of a {}",
    "a photo of a clean {}",
    "a photo of a dirty {}",
    "a dark photo of the {}",
    "a photo of my {}",
    "a photo of the cool {}",
    "a close-up photo of a {}",
    "a bright photo of the {}",
    "a cropped photo of a {}",
    "a photo of the {}",
    "a good photo of the {}",
    "a photo of one {}",
    "a close-up photo of the {}",
    "a rendition of the {}",
    "a photo of the clean {}",
    "a rendition of a {}",
    "a photo of a nice {}",
    "a good photo of a {}",
    "a photo of the nice {}",
    "a photo of the small {}",
    "a photo of the weird {}",
    "a photo of the large {}",
    "a photo of a cool {}",
    "a photo of a small {}",
]

# Prompt templates used when teaching a new style.
imagenet_style_templates_small = [
    "a painting in the style of {}",
    "a rendering in the style of {}",
    "a cropped painting in the style of {}",
    "the painting in the style of {}",
    "a clean painting in the style of {}",
    "a dirty painting in the style of {}",
    "a dark painting in the style of {}",
    "a picture in the style of {}",
    "a cool painting in the style of {}",
    "a close-up painting in the style of {}",
    "a bright painting in the style of {}",
    "a cropped painting in the style of {}",
    "a good painting in the style of {}",
    "a close-up painting in the style of {}",
    "a rendition in the style of {}",
    "a nice painting in the style of {}",
    "a small painting in the style of {}",
    "a weird painting in the style of {}",
    "a large painting in the style of {}",
]


class TextualInversionDataset(Dataset):
    """Pairs each training image with a randomly templated prompt containing the placeholder token."""

    def __init__(
        self,
        data_root,
        tokenizer,
        learnable_property="object",  # "object" or "style"
        size=512,
        repeats=100,
        interpolation="bicubic",
        flip_p=0.5,
        split="train",
        placeholder_token="*",
        center_crop=False,
    ):
        self.data_root = data_root
        self.tokenizer = tokenizer
        self.learnable_property = learnable_property
        self.size = size
        self.placeholder_token = placeholder_token
        self.center_crop = center_crop
        self.flip_p = flip_p

        self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]

        self.num_images = len(self.image_paths)
        self._length = self.num_images

        # Repeat the (typically tiny) image set so an epoch has a useful length.
        if split == "train":
            self._length = self.num_images * repeats

        # Resampling filters live in PIL.Image.Resampling on Pillow >= 9.1;
        # "linear" is an alias for bilinear.
        self.interpolation = {
            "linear": PIL.Image.Resampling.BILINEAR,
            "bilinear": PIL.Image.Resampling.BILINEAR,
            "bicubic": PIL.Image.Resampling.BICUBIC,
            "lanczos": PIL.Image.Resampling.LANCZOS,
        }[interpolation]

        self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
        self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)

    def __len__(self):
        return self._length

    def __getitem__(self, i):
        example = {}
        image = Image.open(self.image_paths[i % self.num_images])

        if image.mode != "RGB":
            image = image.convert("RGB")

        # Tokenize a randomly chosen prompt template with the placeholder filled in.
        placeholder_string = self.placeholder_token
        text = random.choice(self.templates).format(placeholder_string)

        example["input_ids"] = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids[0]

        img = np.array(image).astype(np.uint8)

        # Optionally center-crop to a square before resizing.
        if self.center_crop:
            crop = min(img.shape[0], img.shape[1])
            h, w = img.shape[0], img.shape[1]
            img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]

        image = Image.fromarray(img)
        image = image.resize((self.size, self.size), resample=self.interpolation)
        image = self.flip_transform(image)

        # Scale uint8 pixels to [-1, 1], the range the VAE encoder expects.
        image = np.array(image).astype(np.uint8)
        image = (image / 127.5 - 1.0).astype(np.float32)

        example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
        return example


# Load the tokenizer of the base checkpoint.
tokenizer = CLIPTokenizer.from_pretrained(
    pretrained_model_name_or_path,
    subfolder="tokenizer",
)

# Register the placeholder as a new token; `add_tokens` returns 0 if it is
# already in the vocabulary.
num_added_tokens = tokenizer.add_tokens(placeholder_token)
if num_added_tokens == 0:
    raise ValueError(
        f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
        " `placeholder_token` that is not already in the tokenizer."
    )

# The initializer must map to exactly one token so its embedding can seed the
# placeholder's embedding.
token_ids = tokenizer.encode(initializer_token, add_special_tokens=False)
if len(token_ids) > 1:
    raise ValueError("The initializer token must be a single token.")

initializer_token_id = token_ids[0]
placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)

# Load the remaining components of the Stable Diffusion pipeline.
text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet")

# Grow the embedding matrix to cover the new token, then initialize its row with
# the initializer token's embedding.
text_encoder.resize_token_embeddings(len(tokenizer))

token_embeds = text_encoder.get_input_embeddings().weight.data
token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]


def freeze_params(params):
    for param in params:
        param.requires_grad = False


# Only the token embeddings are trained; freeze everything else.
freeze_params(vae.parameters())
freeze_params(unet.parameters())

# Freeze all of the text encoder except its token embedding layer.
params_to_freeze = itertools.chain(
    text_encoder.text_model.encoder.parameters(),
    text_encoder.text_model.final_layer_norm.parameters(),
    text_encoder.text_model.embeddings.position_embedding.parameters(),
)
freeze_params(params_to_freeze)
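
# Optional sanity check (an addition, not in the original script): after freezing,
# the only trainable parameters should be the token embedding matrix.
trainable = [name for name, p in text_encoder.named_parameters() if p.requires_grad]
assert trainable == ["text_model.embeddings.token_embedding.weight"], trainable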

train_dataset = TextualInversionDataset(
    data_root=save_path,
    tokenizer=tokenizer,
    size=vae.config.sample_size,  # train at the VAE's native resolution
    placeholder_token=placeholder_token,
    repeats=100,
    learnable_property=what_to_teach,
    center_crop=False,
    split="train",
)


def create_dataloader(train_batch_size=1):
    return torch.utils.data.DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
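
# Hypothetical usage: the training loop would typically build its own loader, e.g.
# train_dataloader = create_dataloader(train_batch_size=4)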


# DDPM noise scheduler with the training configuration of the base checkpoint.
noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
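

# To show how the pieces above fit together, here is a minimal sketch of a single
# textual inversion training step, condensed from the usual diffusers training
# loop. `batch` and `optimizer` are assumed; a real loop also handles device
# placement, gradient accumulation, and learning-rate scheduling.
def training_step(batch, optimizer):
    # Encode images into the VAE latent space the UNet operates on.
    latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
    latents = latents * vae.config.scaling_factor

    # Sample noise and a random timestep, then produce the noisy latents.
    noise = torch.randn_like(latents)
    timesteps = torch.randint(
        0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=latents.device
    ).long()
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

    # Condition the UNet on the prompt containing the placeholder token.
    encoder_hidden_states = text_encoder(batch["input_ids"])[0]
    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

    # Stable Diffusion 2 uses v-prediction; fall back to epsilon otherwise.
    if noise_scheduler.config.prediction_type == "v_prediction":
        target = noise_scheduler.get_velocity(latents, noise, timesteps)
    else:
        target = noise
    loss = F.mse_loss(model_pred, target)
    loss.backward()

    # Zero the gradient on every embedding row except the placeholder's, so only
    # the new token's embedding is actually updated.
    grads = text_encoder.get_input_embeddings().weight.grad
    mask = torch.arange(len(tokenizer)) != placeholder_token_id
    grads.data[mask, :] = 0.0

    optimizer.step()
    optimizer.zero_grad()
    return loss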