|
import os |
|
import sys |
|
import logging |
|
import datetime |
|
import os.path as osp |
|
|
|
from tqdm.auto import tqdm |
|
from omegaconf import OmegaConf |
|
|
|
import torch |
|
import swanlab |
|
import diffusers |
|
import transformers |
|
from torch.utils.tensorboard import SummaryWriter |
|
from diffusers.optimization import get_scheduler |
|
|
|
from mld.config import parse_args |
|
from mld.data.get_data import get_dataset |
|
from mld.models.modeltype.mld import MLD |
|
from mld.utils.utils import print_table, set_seed, move_batch_to_device |
|
|
|
# Export TOKENIZERS_PARALLELISM=false for this process at import time, i.e.
# before main() builds any dataset or model.
# NOTE(review): presumably this silences the HuggingFace tokenizers
# fork/parallelism warning when DataLoader workers are used -- confirm.
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})
|
|
|
|
|
def _save_checkpoint(target_model, save_path: str) -> None:
    """Write a checkpoint of ``target_model`` to ``save_path``.

    The checkpoint is the dict ``{"state_dict": target_model.state_dict()}``.
    It is handed to ``target_model.on_save_checkpoint`` before being written
    with ``torch.save``; the hook's return value is not used.
    """
    ckpt = dict(state_dict=target_model.state_dict())
    target_model.on_save_checkpoint(ckpt)
    torch.save(ckpt, save_path)


def main():
    """Train the denoiser of a pre-trained MLD model.

    Steps, in order:
      1. Parse the config, seed the RNGs and create a fresh, timestamped
         output directory (with ``checkpoints`` and, when EMA is enabled,
         ``checkpoints_ema`` sub-directories).
      2. Set up the metric writer (TensorBoard or SwanLab) and logging to both
         stdout and ``output.log``; dump the config to ``config.yaml``.
      3. Build the dataset, dataloaders and model, load the pre-trained
         weights (non-strict), freeze the VAE and text encoder, and create
         the AdamW optimizer / LR scheduler over the denoiser parameters.
      4. Run one validation pass, then train until ``cfg.TRAIN.max_train_steps``
         optimizer steps have been taken, checkpointing and validating
         periodically and keeping the best R-precision-top-1 / FID checkpoints.
      5. Save ``checkpoint-last.ckpt`` and terminate the process.

    Raises:
        ValueError: if ``cfg.vis`` is not ``"tb"`` or ``"swanlab"``, if
            ``cfg.TRAIN.PRETRAINED`` is unset, or if a ``*_steps`` option is
            ``-1`` while its ``*_epochs`` counterpart is also ``-1``.
        RuntimeError: if the training dataloader yields no batches.
        SystemExit: with code 0 once training has finished.
    """
    cfg = parse_args()
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    set_seed(cfg.SEED_VALUE)

    # Every run gets its own timestamped directory; exist_ok=False makes a
    # name collision fail loudly instead of mixing two runs' outputs.
    name_time_str = osp.join(cfg.NAME, datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S"))
    cfg.output_dir = osp.join(cfg.FOLDER, name_time_str)
    os.makedirs(cfg.output_dir, exist_ok=False)
    os.makedirs(f"{cfg.output_dir}/checkpoints", exist_ok=False)
    if cfg.TRAIN.model_ema:
        os.makedirs(f"{cfg.output_dir}/checkpoints_ema", exist_ok=False)

    if cfg.vis == "tb":
        writer = SummaryWriter(cfg.output_dir)
    elif cfg.vis == "swanlab":
        writer = swanlab.init(project="MotionLCM",
                              experiment_name=os.path.normpath(cfg.output_dir).replace(os.path.sep, "-"),
                              suffix=None, config=dict(**cfg), logdir=cfg.output_dir)
    else:
        raise ValueError(f"Invalid vis method: {cfg.vis}")

    # Mirror every log record to stdout and to <output_dir>/output.log.
    stream_handler = logging.StreamHandler(sys.stdout)
    file_handler = logging.FileHandler(osp.join(cfg.output_dir, 'output.log'))
    handlers = [file_handler, stream_handler]
    logging.basicConfig(level=logging.INFO,
                        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
                        datefmt="%m/%d/%Y %H:%M:%S",
                        handlers=handlers)
    logger = logging.getLogger(__name__)

    OmegaConf.save(cfg, osp.join(cfg.output_dir, 'config.yaml'))

    transformers.utils.logging.set_verbosity_warning()
    diffusers.utils.logging.set_verbosity_info()

    dataset = get_dataset(cfg)
    train_dataloader = dataset.train_dataloader()
    val_dataloader = dataset.val_dataloader()

    model = MLD(cfg, dataset)

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not cfg.TRAIN.PRETRAINED:
        raise ValueError("cfg.TRAIN.PRETRAINED must not be None.")
    logger.info(f"Loading pre-trained model: {cfg.TRAIN.PRETRAINED}")
    # NOTE(review): torch.load unpickles the file, so only point
    # cfg.TRAIN.PRETRAINED at checkpoints from a trusted source.
    state_dict = torch.load(cfg.TRAIN.PRETRAINED, map_location="cpu")["state_dict"]
    # strict=False: the result (missing / unexpected keys) is logged, not raised.
    logger.info(model.load_state_dict(state_dict, strict=False))

    # The VAE and text encoder stay frozen; only the denoiser is optimized.
    model.vae.requires_grad_(False)
    model.text_encoder.requires_grad_(False)
    model.vae.eval()
    model.text_encoder.eval()
    model.to(device)

    logger.info(f"learning_rate: {cfg.TRAIN.learning_rate}")
    optimizer = torch.optim.AdamW(
        model.denoiser.parameters(),
        lr=cfg.TRAIN.learning_rate,
        betas=(cfg.TRAIN.adam_beta1, cfg.TRAIN.adam_beta2),
        weight_decay=cfg.TRAIN.adam_weight_decay,
        eps=cfg.TRAIN.adam_epsilon)

    # A value of -1 for a *_steps option means "derive it from the matching
    # *_epochs option". len(train_dataloader) is only evaluated when needed.
    if cfg.TRAIN.max_train_steps == -1:
        if cfg.TRAIN.max_train_epochs == -1:
            raise ValueError("Either cfg.TRAIN.max_train_steps or cfg.TRAIN.max_train_epochs must be set.")
        cfg.TRAIN.max_train_steps = cfg.TRAIN.max_train_epochs * len(train_dataloader)

    if cfg.TRAIN.checkpointing_steps == -1:
        if cfg.TRAIN.checkpointing_epochs == -1:
            raise ValueError("Either cfg.TRAIN.checkpointing_steps or cfg.TRAIN.checkpointing_epochs must be set.")
        cfg.TRAIN.checkpointing_steps = cfg.TRAIN.checkpointing_epochs * len(train_dataloader)

    if cfg.TRAIN.validation_steps == -1:
        if cfg.TRAIN.validation_epochs == -1:
            raise ValueError("Either cfg.TRAIN.validation_steps or cfg.TRAIN.validation_epochs must be set.")
        cfg.TRAIN.validation_steps = cfg.TRAIN.validation_epochs * len(train_dataloader)

    lr_scheduler = get_scheduler(
        cfg.TRAIN.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=cfg.TRAIN.lr_warmup_steps,
        num_training_steps=cfg.TRAIN.max_train_steps)

    model_ema = None
    if cfg.TRAIN.model_ema:
        alpha = 1.0 - cfg.TRAIN.model_ema_decay
        logger.info(f'EMA alpha: {alpha}')
        # Averaging rule: new = (1 - alpha) * p0 + alpha * p1; the third
        # argument passed by AveragedModel is ignored.
        model_ema = torch.optim.swa_utils.AveragedModel(
            model, device, lambda p0, p1, _: (1 - alpha) * p0 + alpha * p1)

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataloader.dataset)}")
    logger.info(f" Num Epochs = {cfg.TRAIN.max_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {cfg.TRAIN.BATCH_SIZE}")
    logger.info(f" Total optimization steps = {cfg.TRAIN.max_train_steps}")

    global_step = 0

    def log_scalar(tag: str, value) -> None:
        """Send one scalar to the configured writer at the current global step."""
        if cfg.vis == "tb":
            writer.add_scalar(tag, value, global_step=global_step)
        elif cfg.vis == "swanlab":
            writer.log({tag: value}, step=global_step)

    @torch.no_grad()
    def validation(target_model: MLD, ema: bool = False) -> tuple:
        """Run one pass over ``val_dataloader`` and log the resulting metrics.

        Args:
            target_model: model to evaluate; its denoiser is put in eval mode
                for the pass and back in train mode afterwards.
            ema: when True, every logged tag gets an ``_EMA`` suffix.

        Returns:
            ``(R_precision_top_1, FID)`` taken from the computed metrics.
        """
        target_model.denoiser.eval()
        val_loss_list = []
        for val_batch in tqdm(val_dataloader):
            val_batch = move_batch_to_device(val_batch, device)
            val_loss_list.append(target_model.allsplit_step(split='val', batch=val_batch))

        metrics = target_model.allsplit_epoch_end()
        # Average each loss over the number of validation batches.
        num_val_batches = len(val_dataloader)
        for loss_name in ('loss', 'diff_loss', 'router_loss'):
            total = sum(d[loss_name] for d in val_loss_list)
            metrics[f"Val/{loss_name}"] = total.item() / num_val_batches

        max_val_rp1 = metrics['Metrics/R_precision_top_1']
        min_val_fid = metrics['Metrics/FID']
        print_table(f'Validation@Step-{global_step}', metrics)
        for mk, mv in metrics.items():
            log_scalar(mk + '_EMA' if ema else mk, mv)

        target_model.denoiser.train()
        return max_val_rp1, min_val_fid

    # Baseline validation before any training step; its results seed the
    # "best so far" trackers used for best-checkpoint saving.
    max_rp1, min_fid = validation(model)
    if cfg.TRAIN.model_ema:
        validation(model_ema.module, ema=True)

    ckpt_dir = os.path.join(cfg.output_dir, 'checkpoints')
    ckpt_ema_dir = os.path.join(cfg.output_dir, 'checkpoints_ema')

    progress_bar = tqdm(range(0, cfg.TRAIN.max_train_steps), desc="Steps")
    while True:
        step_at_epoch_start = global_step
        for batch in train_dataloader:
            batch = move_batch_to_device(batch, device)
            loss_dict = model.allsplit_step('train', batch)

            diff_loss = loss_dict['diff_loss']
            router_loss = loss_dict['router_loss']
            loss = loss_dict['loss']
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.denoiser.parameters(), cfg.TRAIN.max_grad_norm)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad(set_to_none=True)

            progress_bar.update(1)
            global_step += 1

            if cfg.TRAIN.model_ema and global_step % cfg.TRAIN.model_ema_steps == 0:
                model_ema.update_parameters(model)

            # Periodic checkpoint (plus the EMA weights when enabled).
            if global_step % cfg.TRAIN.checkpointing_steps == 0:
                save_path = os.path.join(ckpt_dir, f"checkpoint-{global_step}.ckpt")
                _save_checkpoint(model, save_path)
                logger.info(f"Saved state to {save_path}")

                if cfg.TRAIN.model_ema:
                    save_path = os.path.join(ckpt_ema_dir, f"checkpoint-{global_step}.ckpt")
                    _save_checkpoint(model_ema.module, save_path)
                    logger.info(f"Saved EMA state to {save_path}")

            # Periodic validation; only the non-EMA model's metrics decide
            # whether a "best" checkpoint is written.
            if global_step % cfg.TRAIN.validation_steps == 0:
                cur_rp1, cur_fid = validation(model)
                if cfg.TRAIN.model_ema:
                    validation(model_ema.module, ema=True)

                if cur_rp1 > max_rp1:
                    max_rp1 = cur_rp1
                    save_path = os.path.join(ckpt_dir,
                                             f"checkpoint-{global_step}-rp1-{round(cur_rp1, 3)}.ckpt")
                    _save_checkpoint(model, save_path)
                    logger.info(f"Saved state to {save_path} with rp1:{round(cur_rp1, 3)}")

                if cur_fid < min_fid:
                    min_fid = cur_fid
                    save_path = os.path.join(ckpt_dir,
                                             f"checkpoint-{global_step}-fid-{round(cur_fid, 3)}.ckpt")
                    _save_checkpoint(model, save_path)
                    logger.info(f"Saved state to {save_path} with fid:{round(cur_fid, 3)}")

            logs = {"loss": loss.item(),
                    "diff_loss": diff_loss.item(),
                    "router_loss": router_loss.item(),
                    "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            for k, v in logs.items():
                log_scalar(f"Train/{k}", v)

            if global_step >= cfg.TRAIN.max_train_steps:
                _save_checkpoint(model, os.path.join(ckpt_dir, "checkpoint-last.ckpt"))
                if cfg.TRAIN.model_ema:
                    _save_checkpoint(model_ema.module, os.path.join(ckpt_ema_dir, "checkpoint-last.ckpt"))

                # Flush pending output before terminating. Only the
                # TensorBoard writer is closed explicitly here.
                progress_bar.close()
                if cfg.vis == "tb":
                    writer.close()
                # sys.exit rather than the site-provided `exit` helper, which
                # is not guaranteed to exist in every interpreter setup.
                sys.exit(0)

        # Without this guard an empty dataloader would spin forever, because
        # the termination check above lives inside the batch loop.
        if global_step == step_at_epoch_start:
            raise RuntimeError("train_dataloader yielded no batches; training cannot make progress.")
|
|
|
|
|
# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|