Spaces:
Paused
Paused
import copy | |
import os | |
import numpy as np | |
import pytorch_lightning as pl | |
import torch | |
from pytorch_lightning import loggers as pl_loggers | |
from pytorch_lightning.callbacks import * | |
from torch.cuda import amp | |
from torch.optim.optimizer import Optimizer | |
from torch.utils.data.dataset import TensorDataset | |
from model.seq2seq import DiffusionPredictor | |
from config import * | |
from dist_utils import * | |
from renderer import * | |
# This part is modified from: https://github.com/phizaz/diffae/blob/master/experiment.py | |
class LitModel(pl.LightningModule):
    """Lightning module that trains a ``DiffusionPredictor`` and keeps an EMA copy.

    Adapted from diffae's ``experiment.py``. ``self.model`` is optimized;
    ``self.ema_model`` is a frozen deep copy updated by ``ema(...)`` after each
    (accumulated) optimizer step and is the model used by ``render``.
    """

    def __init__(self, conf: TrainConfig):
        super().__init__()
        assert conf.train_mode != TrainMode.manipulate
        if conf.seed is not None:
            pl.seed_everything(conf.seed)

        self.save_hyperparameters(conf.as_dict_jsonable())

        self.conf = conf

        self.model = DiffusionPredictor(conf)
        # Frozen EMA copy; never receives gradients, only ema() updates.
        self.ema_model = copy.deepcopy(self.model)
        self.ema_model.requires_grad_(False)
        self.ema_model.eval()

        # forward() consults this flag. It was previously never initialized,
        # so forward() always raised AttributeError. False keeps the EMA path
        # as the default; callers may set it to True to sample with self.model.
        self.disable_ema = False

        self.sampler = conf.make_diffusion_conf().make_sampler()
        self.eval_sampler = conf.make_eval_diffusion_conf().make_sampler()

        # this is shared for both model and latent
        self.T_sampler = conf.make_T_sampler()

        if conf.train_mode.use_latent_net():
            self.latent_sampler = conf.make_latent_diffusion_conf(
            ).make_sampler()
            self.eval_latent_sampler = conf.make_latent_eval_diffusion_conf(
            ).make_sampler()
        else:
            self.latent_sampler = None
            self.eval_latent_sampler = None

        # initial variables for consistent sampling
        self.register_buffer(
            'x_T',
            torch.randn(conf.sample_size, 3, conf.img_size, conf.img_size))

    def render(self, start, motion_direction_start, audio_driven,
               face_location, face_scale, ypr_info, noisyT, step_T,
               control_flag):
        """Run ``render_condition`` with the EMA model.

        Args:
            step_T: if None, the pre-built ``eval_sampler`` is used; otherwise a
                sampler is built from ``conf._make_diffusion_conf(step_T)``.
            The remaining arguments are forwarded to ``render_condition``
            unchanged.

        Returns:
            Whatever ``render_condition`` returns.
        """
        if step_T is None:
            sampler = self.eval_sampler
        else:
            sampler = self.conf._make_diffusion_conf(step_T).make_sampler()

        pred_img = render_condition(self.conf,
                                    self.ema_model,
                                    sampler, start, motion_direction_start,
                                    audio_driven, face_location, face_scale,
                                    ypr_info, noisyT, control_flag)
        return pred_img

    def forward(self, noise=None, x_start=None, ema_model: bool = False):
        """Sample with ``eval_sampler`` with autocast disabled.

        Uses ``self.ema_model`` unless ``self.disable_ema`` is true, in which
        case ``self.model`` is used. ``ema_model`` is accepted for backward
        compatibility but is not consulted.
        """
        with amp.autocast(False):
            if not self.disable_ema:
                model = self.ema_model
            else:
                model = self.model
            gen = self.eval_sampler.sample(model=model,
                                           noise=noise,
                                           x_start=x_start)
            return gen

    def setup(self, stage=None) -> None:
        """Build the datasets and seed each worker separately.

        Each process gets seed ``conf.seed * world_size + global_rank`` for
        numpy, torch and torch.cuda, so ranks do not draw identical noise.
        The validation set is the training set itself.
        """
        ##############################################
        # NEED TO SET THE SEED SEPARATELY HERE
        if self.conf.seed is not None:
            seed = self.conf.seed * get_world_size() + self.global_rank
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            print('local seed:', seed)
        ##############################################

        self.train_data = self.conf.make_dataset()
        print('train data:', len(self.train_data))
        self.val_data = self.train_data
        print('val data:', len(self.val_data))

    def _train_dataloader(self, drop_last=True):
        """Build the shuffled training dataloader with the per-worker batch size."""
        # make sure to use the fraction of batch size
        # the batch size is global!
        conf = self.conf.clone()
        conf.batch_size = self.batch_size

        dataloader = conf.make_loader(self.train_data,
                                      shuffle=True,
                                      drop_last=drop_last)
        return dataloader

    def train_dataloader(self):
        """Return the training dataloader.

        If ``train_mode.require_dataset_infer()`` is true, returns a loader over
        pre-computed conditions; otherwise returns the image/motion dataset.
        """
        print('on train dataloader start ...')
        if self.conf.train_mode.require_dataset_infer():
            # NOTE(review): self.conds, self.conds_mean, self.conds_std and
            # self.infer_whole_dataset are not defined anywhere in this class,
            # so this branch raises AttributeError if reached — confirm it is
            # unreachable for the train modes this project uses.
            if self.conds is None:
                # usually we load self.conds from a file
                # so we do not need to do this again!
                self.conds = self.infer_whole_dataset()
                # need to use float32! unless the mean & std will be off!
                # (1, c)
                self.conds_mean.data = self.conds.float().mean(dim=0,
                                                               keepdim=True)
                self.conds_std.data = self.conds.float().std(dim=0,
                                                             keepdim=True)
            print('mean:', self.conds_mean.mean(), 'std:',
                  self.conds_std.mean())

            # return the dataset with pre-calculated conds
            conf = self.conf.clone()
            conf.batch_size = self.batch_size
            data = TensorDataset(self.conds)
            return conf.make_loader(data, shuffle=True)
        else:
            return self._train_dataloader()

    @property
    def batch_size(self):
        """Local batch size for each worker (global batch size / world size).

        Must be a property: callers read ``self.batch_size`` as a value and
        assign it to ``conf.batch_size``.
        """
        ws = get_world_size()
        assert self.conf.batch_size % ws == 0
        return self.conf.batch_size // ws

    @property
    def num_samples(self):
        """(Global) number of samples seen so far: global_step * effective batch size.

        Must be a property: it is passed as the TensorBoard step value.
        """
        # batch size here is global!
        # global_step already takes into account the accum batches
        return self.global_step * self.conf.batch_size_effective

    def is_last_accum(self, batch_idx):
        """Return True on the last micro-batch of a gradient-accumulation cycle.

        Used with ``accum_batches > 1`` to tell whether the optimizer steps in
        this iteration.
        """
        return (batch_idx + 1) % self.conf.accum_batches == 0

    def training_step(self, batch, batch_idx):
        """Compute the diffusion training loss for one batch.

        No optimization happens here; the returned ``'loss'`` is handed back to
        Lightning. Each loss term is averaged across processes via
        ``all_gather`` and logged to TensorBoard on rank 0.
        """
        with amp.autocast(False):
            motion_start = batch['motion_start']  # torch.Size([B, 512])
            motion_direction = batch['motion_direction']  # torch.Size([B, 125, 20])
            audio_feats = batch['audio_feats'].float()  # torch.Size([B, 25, 250, 1024])
            face_location = batch['face_location'].float()  # torch.Size([B, 125])
            face_scale = batch['face_scale'].float()  # torch.Size([B, 125, 1])
            yaw_pitch_roll = batch['yaw_pitch_roll'].float()  # torch.Size([B, 125, 3])
            motion_direction_start = batch['motion_direction_start'].float()  # torch.Size([B, 20])

            if self.conf.train_mode == TrainMode.diffusion:
                # main training mode!!!
                # with numpy seed we have the problem that the sampled t's are related!
                t, _ = self.T_sampler.sample(len(motion_start),
                                             motion_start.device)
                losses = self.sampler.training_losses(
                    model=self.model,
                    motion_direction_start=motion_direction_start,
                    motion_target=motion_direction,
                    motion_start=motion_start,
                    audio_feats=audio_feats,
                    face_location=face_location,
                    face_scale=face_scale,
                    yaw_pitch_roll=yaw_pitch_roll,
                    t=t)
            else:
                raise NotImplementedError()

            loss = losses['loss'].mean()
            # NOTE(review): no manual division by accum_batches happens here;
            # presumably Lightning normalizes the loss for gradient
            # accumulation — confirm for the pinned PL version.
            # Average every loss term over all processes (for logging only).
            for key in losses.keys():
                losses[key] = self.all_gather(losses[key]).mean()

            if self.global_rank == 0:
                self.logger.experiment.add_scalar('loss', losses['loss'],
                                                  self.num_samples)
                for key in losses:
                    self.logger.experiment.add_scalar(
                        f'loss/{key}', losses[key], self.num_samples)

        return {'loss': loss}

    def on_train_batch_end(self, outputs, batch, batch_idx: int,
                           dataloader_idx: int) -> None:
        """After each training step, update the EMA weights on optimizer-step batches."""
        if self.is_last_accum(batch_idx):
            if self.conf.train_mode == TrainMode.latent_diffusion:
                # it trains only the latent hence change only the latent
                ema(self.model.latent_net, self.ema_model.latent_net,
                    self.conf.ema_decay)
            else:
                ema(self.model, self.ema_model, self.conf.ema_decay)

    def on_before_optimizer_step(self, optimizer: Optimizer,
                                 optimizer_idx: int) -> None:
        """Clip the gradient norm of all optimized params when ``conf.grad_clip > 0``."""
        # fix the fp16 + clip grad norm problem with pytorch lightning
        # this is the currently correct way to do it
        if self.conf.grad_clip > 0:
            params = [
                p for group in optimizer.param_groups for p in group['params']
            ]
            torch.nn.utils.clip_grad_norm_(params,
                                           max_norm=self.conf.grad_clip)

    def configure_optimizers(self):
        """Build Adam/AdamW over ``self.model`` plus an optional per-step warmup scheduler."""
        out = {}
        if self.conf.optimizer == OptimizerType.adam:
            optim = torch.optim.Adam(self.model.parameters(),
                                     lr=self.conf.lr,
                                     weight_decay=self.conf.weight_decay)
        elif self.conf.optimizer == OptimizerType.adamw:
            optim = torch.optim.AdamW(self.model.parameters(),
                                      lr=self.conf.lr,
                                      weight_decay=self.conf.weight_decay)
        else:
            raise NotImplementedError()
        out['optimizer'] = optim
        if self.conf.warmup > 0:
            sched = torch.optim.lr_scheduler.LambdaLR(optim,
                                                      lr_lambda=WarmupLR(
                                                          self.conf.warmup))
            out['lr_scheduler'] = {
                'scheduler': sched,
                'interval': 'step',
            }
        return out

    def split_tensor(self, x):
        """Extract this worker's contiguous slice of ``x`` along dim 0.

        Args:
            x: (n, c)

        Returns:
            x: (n_local, c) with ``n_local = n // world_size``; any remainder
            rows at the end are not assigned to any rank.
        """
        n = len(x)
        rank = self.global_rank
        world_size = get_world_size()
        per_rank = n // world_size
        return x[rank * per_rank:(rank + 1) * per_rank]
def ema(source, target, decay):
    """Update ``target`` in place as an exponential moving average of ``source``.

    For every floating-point (or complex) entry of ``source.state_dict()``,
    the matching entry of ``target`` becomes
    ``target * decay + source * (1 - decay)``. Non-floating entries (integer
    counters such as BatchNorm's ``num_batches_tracked``) are copied from
    ``source`` verbatim. Blending them would produce a float that ``copy_``
    truncates back to an integer, making the counter drift or stick.

    Args:
        source: module whose current weights are tracked.
        target: module holding the averaged weights; must expose every key of
            ``source.state_dict()``.
        decay: fraction of the existing ``target`` value to keep.
    """
    source_dict = source.state_dict()
    target_dict = target.state_dict()
    for key, src_value in source_dict.items():
        # state_dict() tensors share storage with the module, so the in-place
        # copy_ below updates ``target`` itself.
        tgt_value = target_dict[key].data
        if tgt_value.is_floating_point() or tgt_value.is_complex():
            tgt_value.copy_(tgt_value * decay + src_value.data * (1 - decay))
        else:
            tgt_value.copy_(src_value.data)
class WarmupLR:
    """Linear warmup multiplier intended for ``torch.optim.lr_scheduler.LambdaLR``.

    Calling an instance with a step index yields ``step / warmup`` while
    ``step`` is below ``warmup``, and 1.0 from then on.
    """

    def __init__(self, warmup) -> None:
        # number of steps over which the multiplier ramps from 0 to 1
        self.warmup = warmup

    def __call__(self, step):
        capped_step = min(step, self.warmup)
        fraction = capped_step / self.warmup
        return fraction
def is_time(num_samples, every, step_size):
    """Report whether a multiple of ``every`` was crossed within the last step.

    Returns True when fewer than ``step_size`` samples have elapsed since the
    most recent multiple of ``every`` at or below ``num_samples``.
    """
    last_multiple = every * (num_samples // every)
    elapsed_since = num_samples - last_multiple
    return elapsed_since < step_size
def train(conf: TrainConfig, gpus, nodes=1, mode: str = 'train'):
    """Build a ``LitModel`` from ``conf`` and fit it with PyTorch Lightning.

    Resumes from ``{conf.logdir}/last.ckpt`` if it exists. Otherwise it
    resumes from ``conf.continue_from`` when that is set, and else starts
    fresh. Checkpoints are written to ``conf.logdir`` every 10 epochs (plus
    ``last.ckpt``).

    Args:
        conf: training configuration.
        gpus: device list passed to ``pl.Trainer(gpus=...)``. A single device
            on a single node trains without DDP; anything else uses ``'ddp'``.
        nodes: number of nodes.
        mode: accepted for backward compatibility; not used by this function.
    """
    print('conf:', conf.name)
    # Gradient clipping happens in LitModel.on_before_optimizer_step rather
    # than via Trainer(gradient_clip_val=...), to sidestep the fp16 + grad
    # clipping problem in pytorch lightning.
    model = LitModel(conf)

    os.makedirs(conf.logdir, exist_ok=True)
    checkpoint = ModelCheckpoint(dirpath=f'{conf.logdir}',
                                 save_last=True,
                                 save_top_k=-1,
                                 every_n_epochs=10)
    checkpoint_path = f'{conf.logdir}/last.ckpt'
    print('ckpt path:', checkpoint_path)
    if os.path.exists(checkpoint_path):
        resume = checkpoint_path
        print('resume!')
    elif conf.continue_from is not None:
        # continue from a checkpoint. The original read `.pathcd`, a typo that
        # raised AttributeError; assumes continue_from exposes `.path` (as
        # diffae's PretrainConfig does) — confirm against config.py.
        resume = conf.continue_from.path
    else:
        resume = None

    tb_logger = pl_loggers.TensorBoardLogger(save_dir=conf.logdir,
                                             name=None,
                                             version='')

    plugins = []
    if len(gpus) == 1 and nodes == 1:
        accelerator = None
    else:
        accelerator = 'ddp'
        from pytorch_lightning.plugins import DDPPlugin

        # important for working with gradient checkpoint
        plugins.append(DDPPlugin(find_unused_parameters=True))

    trainer = pl.Trainer(
        max_steps=conf.total_samples // conf.batch_size_effective,
        resume_from_checkpoint=resume,
        gpus=gpus,
        num_nodes=nodes,
        accelerator=accelerator,
        precision=16 if conf.fp16 else 32,
        callbacks=[
            checkpoint,
            LearningRateMonitor(),
        ],
        # clip in the model instead (see on_before_optimizer_step)
        replace_sampler_ddp=True,
        logger=tb_logger,
        accumulate_grad_batches=conf.accum_batches,
        plugins=plugins,
    )
    trainer.fit(model)