import itertools
import math
import re

import torch
import torch.nn as nn
import torch.nn.functional as F
from accelerate.logging import get_logger
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from diffusers.utils.import_utils import is_xformers_available
from einops import rearrange
from transformers import CLIPTextModel, CLIPTokenizer

from mixofshow.models.edlora import (LoRALinearLayer, revise_edlora_unet_attention_controller_forward,
                                     revise_edlora_unet_attention_forward)
from mixofshow.pipelines.pipeline_edlora import bind_concept_prompt
from mixofshow.utils.ptp_util import AttentionStore


class EDLoRATrainer(nn.Module):
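    """Training wrapper for ED-LoRA.

    Wraps the Stable Diffusion components (VAE, CLIP text encoder, UNet, DDPM
    scheduler), registers layer-wise embeddings for the new concept tokens,
    optionally injects LoRA layers into the text encoder and UNet, and computes
    the masked diffusion loss with an optional cross-attention regularizer.
    """
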
    def __init__(
        self,
        pretrained_path,
        new_concept_token,
        initializer_token,
        enable_edlora,
        finetune_cfg=None,
        noise_offset=None,
        attn_reg_weight=None,
        reg_full_identity=True,
        use_mask_loss=True,
        enable_xformers=False,
        gradient_checkpoint=False
    ):
        super().__init__()

        # 1. Load the pretrained Stable Diffusion components.
        self.vae = AutoencoderKL.from_pretrained(pretrained_path, subfolder='vae')
        self.tokenizer = CLIPTokenizer.from_pretrained(pretrained_path, subfolder='tokenizer')
        self.text_encoder = CLIPTextModel.from_pretrained(pretrained_path, subfolder='text_encoder')
        self.unet = UNet2DConditionModel.from_pretrained(pretrained_path, subfolder='unet')

        if gradient_checkpoint:
            self.unet.enable_gradient_checkpointing()

        if enable_xformers:
            assert is_xformers_available(), 'need to install xformers first'

        # 2. Load the training noise scheduler.
        self.scheduler = DDPMScheduler.from_pretrained(pretrained_path, subfolder='scheduler')

        # 3. Register the new concept tokens and patch the UNet attention forward.
        self.enable_edlora = enable_edlora
        self.new_concept_cfg = self.init_new_concept(new_concept_token, initializer_token, enable_edlora=enable_edlora)

        self.attn_reg_weight = attn_reg_weight
        self.reg_full_identity = reg_full_identity
        if self.attn_reg_weight is not None:
            # Store cross-attention maps during the forward pass for regularization.
            self.controller = AttentionStore(training=True)
            revise_edlora_unet_attention_controller_forward(self.unet, self.controller)
        else:
            revise_edlora_unet_attention_forward(self.unet)

        # 4. Freeze the base weights and collect the trainable parameter groups.
        if finetune_cfg:
            self.set_finetune_cfg(finetune_cfg)

        self.noise_offset = noise_offset
        self.use_mask_loss = use_mask_loss
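
    # Expected layout of `finetune_cfg`, inferred from the parsing below (a sketch;
    # the example values are assumptions, and any extra `lora_cfg` keys, e.g.
    # rank/alpha, are forwarded verbatim to LoRALinearLayer):
    #
    #   finetune_cfg = {
    #       'text_embedding': {'enable_tuning': True, 'lr': 1e-3},
    #       'text_encoder': {'enable_tuning': True, 'lr': 1e-5,
    #                        'lora_cfg': {'where': 'CLIPAttention', ...}},
    #       'unet': {'enable_tuning': True, 'lr': 1e-4,
    #                'lora_cfg': {'where': 'Attention', ...}},
    #   }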
    def set_finetune_cfg(self, finetune_cfg):
        logger = get_logger('mixofshow', log_level='INFO')
        params_to_freeze = [self.vae.parameters(), self.text_encoder.parameters(), self.unet.parameters()]

        # Freeze all base model parameters; only embeddings and LoRA layers are tuned.
        for params in itertools.chain(*params_to_freeze):
            params.requires_grad = False

        params_group_list = []

        # 1. New concept token embeddings.
        if finetune_cfg['text_embedding']['enable_tuning']:
            text_embedding_cfg = finetune_cfg['text_embedding']

            params_list = []
            for params in self.text_encoder.get_input_embeddings().parameters():
                params.requires_grad = True
                params_list.append(params)

            params_group = {'params': params_list, 'lr': text_embedding_cfg['lr']}
            if 'weight_decay' in text_embedding_cfg:
                params_group.update({'weight_decay': text_embedding_cfg['weight_decay']})
            params_group_list.append(params_group)
            logger.info(f"optimizing embedding using lr: {text_embedding_cfg['lr']}")

        # 2. LoRA layers injected into the text encoder.
        if finetune_cfg['text_encoder']['enable_tuning'] and finetune_cfg['text_encoder'].get('lora_cfg'):
            text_encoder_cfg = finetune_cfg['text_encoder']

            where = text_encoder_cfg['lora_cfg'].pop('where')
            assert where in ['CLIPEncoderLayer', 'CLIPAttention']

            self.text_encoder_lora = nn.ModuleList()
            params_list = []

            for name, module in self.text_encoder.named_modules():
                if module.__class__.__name__ == where:
                    for child_name, child_module in module.named_modules():
                        if child_module.__class__.__name__ == 'Linear':
                            lora_module = LoRALinearLayer(name + '.' + child_name, child_module, **text_encoder_cfg['lora_cfg'])
                            self.text_encoder_lora.append(lora_module)
                            params_list.extend(list(lora_module.parameters()))

            params_group_list.append({'params': params_list, 'lr': text_encoder_cfg['lr']})
            logger.info(f"optimizing text_encoder ({len(self.text_encoder_lora)} LoRAs), using lr: {text_encoder_cfg['lr']}")

        # 3. LoRA layers injected into the UNet (Linear and 1x1 Conv2d children).
        if finetune_cfg['unet']['enable_tuning'] and finetune_cfg['unet'].get('lora_cfg'):
            unet_cfg = finetune_cfg['unet']

            where = unet_cfg['lora_cfg'].pop('where')
            assert where in ['Transformer2DModel', 'Attention']

            self.unet_lora = nn.ModuleList()
            params_list = []

            for name, module in self.unet.named_modules():
                if module.__class__.__name__ == where:
                    for child_name, child_module in module.named_modules():
                        if child_module.__class__.__name__ == 'Linear' or (child_module.__class__.__name__ == 'Conv2d' and child_module.kernel_size == (1, 1)):
                            lora_module = LoRALinearLayer(name + '.' + child_name, child_module, **unet_cfg['lora_cfg'])
                            self.unet_lora.append(lora_module)
                            params_list.extend(list(lora_module.parameters()))

            params_group_list.append({'params': params_list, 'lr': unet_cfg['lr']})
            logger.info(f"optimizing unet ({len(self.unet_lora)} LoRAs), using lr: {unet_cfg['lr']}")

        self.params_to_optimize_iterator = params_group_list

    def get_params_to_optimize(self):
        return self.params_to_optimize_iterator
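
    # The parameter groups above follow torch's per-group format, so a training
    # script can pass them directly to an optimizer, e.g. (illustrative):
    #   optimizer = torch.optim.AdamW(trainer.get_params_to_optimize())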

    def init_new_concept(self, new_concept_tokens, initializer_tokens, enable_edlora=True):
        logger = get_logger('mixofshow', log_level='INFO')
        new_concept_cfg = {}
        new_concept_tokens = new_concept_tokens.split('+')

        if initializer_tokens is None:
            initializer_tokens = ['<rand-0.017>'] * len(new_concept_tokens)
        else:
            initializer_tokens = initializer_tokens.split('+')
        assert len(new_concept_tokens) == len(initializer_tokens), 'concept token should match init token.'

        for idx, (concept_name, init_token) in enumerate(zip(new_concept_tokens, initializer_tokens)):
            if enable_edlora:
                num_new_embedding = 16  # one embedding per cross-attention layer
            else:
                num_new_embedding = 1
            new_token_names = [f'<new{idx * num_new_embedding + layer_id}>' for layer_id in range(num_new_embedding)]

            num_added_tokens = self.tokenizer.add_tokens(new_token_names)
            assert num_added_tokens == len(new_token_names), 'some token is already in tokenizer'
            new_token_ids = [self.tokenizer.convert_tokens_to_ids(token_name) for token_name in new_token_names]

            # Resize the embedding matrix and initialize the newly added rows.
            self.text_encoder.resize_token_embeddings(len(self.tokenizer))
            token_embeds = self.text_encoder.get_input_embeddings().weight.data

            if init_token.startswith('<rand'):
                sigma_val = float(re.findall(r'<rand-(.*)>', init_token)[0])
                init_feature = torch.randn_like(token_embeds[0]) * sigma_val
                logger.info(f'{concept_name} ({min(new_token_ids)}-{max(new_token_ids)}) is randomly initialized by: {init_token}')
            else:
                # Initialize from the embedding of an existing token.
                init_token_ids = self.tokenizer.encode(init_token, add_special_tokens=False)

                # The initializer must map to exactly one known token.
                if len(init_token_ids) > 1 or init_token_ids[0] == 40497:
                    raise ValueError('The initializer token must be a single existing token.')
                init_feature = token_embeds[init_token_ids]
                logger.info(f'{concept_name} ({min(new_token_ids)}-{max(new_token_ids)}) is initialized from existing token ({init_token}): {init_token_ids[0]}')

            for token_id in new_token_ids:
                token_embeds[token_id] = init_feature.clone()

            new_concept_cfg.update({
                concept_name: {
                    'concept_token_ids': new_token_ids,
                    'concept_token_names': new_token_names
                }
            })

        return new_concept_cfg

    def get_all_concept_token_ids(self):
        new_concept_token_ids = []
        for _, new_token_cfg in self.new_concept_cfg.items():
            new_concept_token_ids.extend(new_token_cfg['concept_token_ids'])
        return new_concept_token_ids
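
    # `forward` expects a batch prepared by the training dataloader: `images`
    # already normalized for the VAE (typically [-1, 1]), `prompts` as strings
    # containing the concept tokens, and `masks` / `img_masks` at latent
    # resolution (64x64 for 512x512 images, per the cal_attn_reg docstring).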

    def forward(self, images, prompts, masks, img_masks):
        # 1. Encode images into latents and scale by the SD latent factor.
        latents = self.vae.encode(images).latent_dist.sample()
        latents = latents * 0.18215

        # 2. Sample noise (optionally with a per-channel offset) and timesteps.
        noise = torch.randn_like(latents)
        if self.noise_offset is not None:
            noise += self.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)

        bsz = latents.shape[0]
        timesteps = torch.randint(0, self.scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
        timesteps = timesteps.long()

        # 3. Add noise to the latents (forward diffusion process).
        noisy_latents = self.scheduler.add_noise(latents, noise, timesteps)

        # 4. Encode prompts; for ED-LoRA each prompt is expanded into per-layer prompts.
        if self.enable_edlora:
            prompts = bind_concept_prompt(prompts, new_concept_cfg=self.new_concept_cfg)

        text_input_ids = self.tokenizer(
            prompts,
            padding='max_length',
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors='pt').input_ids.to(latents.device)

        encoder_hidden_states = self.text_encoder(text_input_ids)[0]
        if self.enable_edlora:
            # (batch * num_layers, seq_len, dim) -> (batch, num_layers, seq_len, dim)
            encoder_hidden_states = rearrange(encoder_hidden_states, '(b n) m c -> b n m c', b=latents.shape[0])

        # 5. Predict the noise residual.
        model_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample

        # 6. Build the regression target according to the scheduler's prediction type.
        if self.scheduler.config.prediction_type == 'epsilon':
            target = noise
        elif self.scheduler.config.prediction_type == 'v_prediction':
            target = self.scheduler.get_velocity(latents, noise, timesteps)
        else:
            raise ValueError(f'Unknown prediction type {self.scheduler.config.prediction_type}')

        # 7. Diffusion loss, restricted to the masked region.
        if self.use_mask_loss:
            loss_mask = masks
        else:
            loss_mask = img_masks
        loss = F.mse_loss(model_pred.float(), target.float(), reduction='none')
        loss = ((loss * loss_mask).sum([1, 2, 3]) / loss_mask.sum([1, 2, 3])).mean()

        # 8. Optional cross-attention regularization on the stored attention maps.
        if self.attn_reg_weight is not None:
            attention_maps = self.controller.get_average_attention()
            attention_loss = self.cal_attn_reg(attention_maps, masks, text_input_ids)
            if not torch.isnan(attention_loss):
                loss = loss + attention_loss
            self.controller.reset()

        return loss
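
    # A single optimization step in the surrounding training script looks roughly
    # like this (sketch; `accelerator`, `optimizer` and the `batch` keys are
    # assumptions supplied by the caller, not part of this module):
    #   loss = trainer(batch['images'], batch['prompts'], batch['masks'], batch['img_masks'])
    #   accelerator.backward(loss)
    #   optimizer.step()
    #   optimizer.zero_grad()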

    def cal_attn_reg(self, attention_maps, masks, text_input_ids):
        '''
        Cross-attention regularization on the stored attention maps.
        attention_maps: {down_cross: [], mid_cross: [], up_cross: []}
        masks: torch.Size([1, 1, 64, 64])
        text_input_ids: torch.Size([16, 77])
        '''
        batch_size = masks.shape[0]
        text_input_ids = rearrange(text_input_ids, '(b l) n -> b l n', b=batch_size)

        # Locate the positions of the new concept tokens in each prompt.
        new_token_pos = []
        all_concept_token_ids = self.get_all_concept_token_ids()
        for text in text_input_ids:
            text = text[0]  # token positions are identical across the per-layer prompts
            new_token_pos.append([idx for idx in range(len(text)) if text[idx] in all_concept_token_ids])

        # Group cross-attention maps by spatial resolution.
        attention_groups = {'64': [], '32': [], '16': [], '8': []}
        for _, attention_list in attention_maps.items():
            for attn in attention_list:
                res = int(math.sqrt(attn.shape[1]))
                cross_map = attn.reshape(batch_size, -1, res, res, attn.shape[-1])
                attention_groups[str(res)].append(cross_map)

        # Average the maps per resolution and keep only the concept-token columns.
        for k, cross_map in attention_groups.items():
            cross_map = torch.cat(cross_map, dim=-4)
            cross_map = cross_map.sum(-4) / cross_map.shape[-4]
            cross_map = torch.stack([batch_map[..., batch_pos] for batch_pos, batch_map in zip(new_token_pos, cross_map)])
            attention_groups[k] = cross_map

        attn_reg_total = 0

        for k, cross_map in attention_groups.items():
            # Each prompt carries two concept tokens: the adjective token (index 0)
            # and the subject token (index 1).
            map_adjective, map_subject = cross_map[..., 0], cross_map[..., 1]

            map_subject = map_subject / map_subject.max()
            map_adjective = map_adjective / map_adjective.max()

            gt_mask = F.interpolate(masks, size=map_subject.shape[1:], mode='nearest').squeeze(1)

            if self.reg_full_identity:
                # Push the subject attention toward the full foreground mask.
                loss_subject = F.mse_loss(map_subject.float(), gt_mask.float(), reduction='mean')
            else:
                # Only penalize subject attention that leaks outside the mask.
                loss_subject = map_subject[gt_mask == 0].mean()

            loss_adjective = map_adjective[gt_mask == 0].mean()

            attn_reg_total += self.attn_reg_weight * (loss_subject + loss_adjective)
        return attn_reg_total

    def load_delta_state_dict(self, delta_state_dict):
        logger = get_logger('mixofshow', log_level='INFO')

        # 1. Load the new concept token embeddings.
        if 'new_concept_embedding' in delta_state_dict and len(delta_state_dict['new_concept_embedding']) != 0:
            new_concept_tokens = list(delta_state_dict['new_concept_embedding'].keys())

            token_embeds = self.text_encoder.get_input_embeddings().weight.data
            if set(new_concept_tokens) != set(self.new_concept_cfg.keys()):
                logger.warning('The checkpoint contains different concepts than this model; loading only the existing concepts.')

            for concept_name, concept_cfg in self.new_concept_cfg.items():
                logger.info(f'load: concept_{concept_name}')
                token_embeds[concept_cfg['concept_token_ids']] = token_embeds[
                    concept_cfg['concept_token_ids']].copy_(delta_state_dict['new_concept_embedding'][concept_name])

        # 2. Load text encoder weights: LoRA layers if the key count matches
        #    (two weights per LoRA module), otherwise plain named parameters.
        if 'text_encoder' in delta_state_dict and len(delta_state_dict['text_encoder']) != 0:
            load_keys = delta_state_dict['text_encoder'].keys()
            if hasattr(self, 'text_encoder_lora') and len(load_keys) == 2 * len(self.text_encoder_lora):
                logger.info('loading LoRA for text encoder:')
                for lora_module in self.text_encoder_lora:
                    for name, param in lora_module.named_parameters():
                        logger.info(f'load: {lora_module.name}.{name}')
                        param.data.copy_(delta_state_dict['text_encoder'][f'{lora_module.name}.{name}'])
            else:
                for name, param in self.text_encoder.named_parameters():
                    if name in load_keys and 'token_embedding' not in name:
                        logger.info(f'load: {name}')
                        param.data.copy_(delta_state_dict['text_encoder'][f'{name}'])

        # 3. Load UNet weights, following the same LoRA-or-full logic.
        if 'unet' in delta_state_dict and len(delta_state_dict['unet']) != 0:
            load_keys = delta_state_dict['unet'].keys()
            if hasattr(self, 'unet_lora') and len(load_keys) == 2 * len(self.unet_lora):
                logger.info('loading LoRA for unet:')
                for lora_module in self.unet_lora:
                    for name, param in lora_module.named_parameters():
                        logger.info(f'load: {lora_module.name}.{name}')
                        param.data.copy_(delta_state_dict['unet'][f'{lora_module.name}.{name}'])
            else:
                for name, param in self.unet.named_parameters():
                    if name in load_keys:
                        logger.info(f'load: {name}')
                        param.data.copy_(delta_state_dict['unet'][f'{name}'])
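
    # Note: `delta_state_dict()` below emits exactly the three groups consumed
    # here ('new_concept_embedding', 'text_encoder', 'unet'), so the two methods
    # round-trip a trained checkpoint.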

    def delta_state_dict(self):
        delta_dict = {'new_concept_embedding': {}, 'text_encoder': {}, 'unet': {}}

        # 1. Save the learned embeddings of the new concept tokens.
        for concept_name, concept_cfg in self.new_concept_cfg.items():
            learned_embeds = self.text_encoder.get_input_embeddings().weight[concept_cfg['concept_token_ids']]
            delta_dict['new_concept_embedding'][concept_name] = learned_embeds.detach().cpu()

        # 2. Save the text encoder LoRA weights.
        for lora_module in self.text_encoder_lora:
            for name, param in lora_module.named_parameters():
                delta_dict['text_encoder'][f'{lora_module.name}.{name}'] = param.cpu().clone()

        # 3. Save the UNet LoRA weights.
        for lora_module in self.unet_lora:
            for name, param in lora_module.named_parameters():
                delta_dict['unet'][f'{lora_module.name}.{name}'] = param.cpu().clone()

        return delta_dict
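

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The base-model path, concept and
# initializer tokens, and finetune_cfg values below are assumptions, not values
# shipped with this module; real runs are driven by the training script and its
# config file.
# ---------------------------------------------------------------------------
# trainer = EDLoRATrainer(
#     pretrained_path='runwayml/stable-diffusion-v1-5',
#     new_concept_token='<TOK1>+<TOK2>',
#     initializer_token='<rand-0.017>+person',
#     enable_edlora=True,
#     finetune_cfg={
#         'text_embedding': {'enable_tuning': True, 'lr': 1e-3},
#         'text_encoder': {'enable_tuning': True, 'lr': 1e-5,
#                          'lora_cfg': {'where': 'CLIPAttention'}},
#         'unet': {'enable_tuning': True, 'lr': 1e-4,
#                  'lora_cfg': {'where': 'Attention'}},
#     },
#     use_mask_loss=True,
# )
# optimizer = torch.optim.AdamW(trainer.get_params_to_optimize())
# ... train, then save only the learned delta:
# torch.save(trainer.delta_state_dict(), 'edlora_delta.pth')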