import logging
import os
from collections import OrderedDict
from os.path import join as pjoin
from pathlib import Path

import numpy as np
import torch
from pytorch_lightning import LightningModule

from mGPT.config import get_obj_from_str
from mGPT.metrics import BaseMetrics


class BaseModel(LightningModule):
    """Common Lightning hooks for train/val/test stepping, logging and saving.

    This class references several attributes and methods that are not defined
    in it: ``allsplit_step``, ``feats2joints``, ``_losses``, ``metrics``,
    ``datamodule``, ``hparams.cfg`` and ``hparams.metrics_dict``. They are
    expected to be provided elsewhere (presumably by subclasses).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # self.configure_metrics()

        # Ablation
        self.test_step_outputs = []
        self.times = []
        self.rep_i = 0

    def training_step(self, batch, batch_idx):
        """Delegate a training batch to ``allsplit_step`` with split "train"."""
        return self.allsplit_step("train", batch, batch_idx)

    def validation_step(self, batch, batch_idx):
        """Delegate a validation batch to ``allsplit_step`` with split "val"."""
        return self.allsplit_step("val", batch, batch_idx)

    def test_step(self, batch, batch_idx):
        """Run ``allsplit_step`` for split "test" and remember its output.

        The output is appended to ``self.test_step_outputs`` so that
        ``on_test_epoch_end`` can hand the whole epoch to ``save_npy``.
        """
        outputs = self.allsplit_step("test", batch, batch_idx)
        self.test_step_outputs.append(outputs)
        return outputs

    def predict_step(self, batch, batch_idx):
        """Return ``self.forward(batch)``; ``batch_idx`` is unused."""
        return self.forward(batch)

    def _log_unless_sanity_checking(self, dico):
        """Log ``dico`` via ``log_dict`` unless the trainer is sanity checking."""
        if not self.trainer.sanity_checking:
            self.log_dict(dico, sync_dist=True, rank_zero_only=True)

    def on_train_epoch_end(self):
        """Log the step counters and the train losses."""
        dico = self.step_log_dict()
        dico.update(self.loss_log_dict('train'))
        self._log_unless_sanity_checking(dico)

    def on_validation_epoch_end(self):
        """Log the step counters, train and val losses, and the metrics."""
        dico = self.step_log_dict()
        dico.update(self.loss_log_dict('train'))
        dico.update(self.loss_log_dict('val'))
        dico.update(self.metrics_log_dict())
        self._log_unless_sanity_checking(dico)

    def on_test_epoch_end(self):
        """Log the metrics, save the collected test outputs, and reset state.

        ``rep_i`` is incremented after every test epoch; ``save_npy`` uses it
        to name the files of each replication.
        """
        dico = self.metrics_log_dict()
        self._log_unless_sanity_checking(dico)
        self.save_npy(self.test_step_outputs)
        self.rep_i = self.rep_i + 1
        # Free up the memory
        self.test_step_outputs.clear()

    def preprocess_state_dict(self, state_dict):
        """Build the state dict that is actually loaded into this module.

        Entries of ``state_dict`` whose key contains ``'_losses'`` or
        ``'Metrics'`` are dropped, and the current ``self._losses`` state is
        inserted under the ``'_losses.'`` prefix instead.

        Args:
            state_dict: Mapping of parameter names to values.

        Returns:
            A new ``OrderedDict``; ``state_dict`` itself is not modified.
        """
        new_state_dict = OrderedDict()

        # metric_state_dict = self.metrics.state_dict()
        # for k, v in metric_state_dict.items():
        #     new_state_dict['metrics.' + k] = v

        for k, v in self._losses.state_dict().items():
            new_state_dict['_losses.' + k] = v

        for k, v in state_dict.items():
            if '_losses' not in k and 'Metrics' not in k:
                new_state_dict[k] = v

        return new_state_dict

    def load_state_dict(self, state_dict, strict=True):
        """Load ``state_dict`` after passing it through ``preprocess_state_dict``.

        Returns:
            The value returned by the parent ``load_state_dict`` call, so that
            callers can inspect it instead of receiving ``None``.
        """
        new_state_dict = self.preprocess_state_dict(state_dict)
        return super().load_state_dict(new_state_dict, strict)

    def step_log_dict(self):
        """Return the ``epoch`` and ``step`` entries added to epoch-end logs."""
        # NOTE(review): "step" is deliberately left equal to the epoch index,
        # as in the original code; presumably this makes loggers plot against
        # epochs rather than optimizer steps - confirm before changing.
        return {
            "epoch": float(self.trainer.current_epoch),
            "step": float(self.trainer.current_epoch)
        }

    def loss_log_dict(self, split: str):
        """Return ``self._losses['losses_' + split].compute(split)``."""
        losses = self._losses['losses_' + split]
        return losses.compute(split)

    def metrics_log_dict(self):
        """Compute the selected metrics and flatten them into one dict.

        Only ``'MMMetrics'`` is computed when the datamodule has ``is_mm`` set
        and ``"TM2TMetrics"`` is among ``hparams.metrics_dict``; otherwise all
        entries of ``hparams.metrics_dict`` are computed.

        Returns:
            Dict mapping ``"Metrics/<name>"`` to ``value.item()`` for every
            entry returned by each metric's ``compute``.
        """
        # For TM2TMetrics MM
        if (self.trainer.datamodule.is_mm
                and "TM2TMetrics" in self.hparams.metrics_dict):
            metric_names = ['MMMetrics']
        else:
            metric_names = self.hparams.metrics_dict

        # Compute all metrics
        metrics_log_dict = {}
        for metric_name in metric_names:
            computed = getattr(self.metrics, metric_name).compute(
                sanity_flag=self.trainer.sanity_checking)
            # The logged key is the name of each computed value, not the name
            # of the metric object that produced it.
            metrics_log_dict.update({
                f"Metrics/{key}": value.item()
                for key, value in computed.items()
            })

        return metrics_log_dict

    def configure_optimizers(self):
        """Instantiate the optimizer and LR scheduler named in the config.

        A target without a dot is looked up in ``torch.optim`` (optimizer) or
        ``torch.optim.lr_scheduler`` (scheduler); a dotted target is passed to
        ``get_obj_from_str`` unchanged.

        Returns:
            Dict with the keys ``'optimizer'`` and ``'lr_scheduler'``.
        """
        train_cfg = self.hparams.cfg.TRAIN

        # Optimizer
        optim_target = train_cfg.OPTIM.target
        if '.' not in optim_target:
            optim_target = 'torch.optim.' + optim_target
        optimizer = get_obj_from_str(optim_target)(
            params=self.parameters(), **train_cfg.OPTIM.params)

        # Scheduler
        scheduler_target = train_cfg.LR_SCHEDULER.target
        if '.' not in scheduler_target:
            scheduler_target = 'torch.optim.lr_scheduler.' + scheduler_target
        lr_scheduler = get_obj_from_str(scheduler_target)(
            optimizer=optimizer, **train_cfg.LR_SCHEDULER.params)

        return {'optimizer': optimizer, 'lr_scheduler': lr_scheduler}

    def configure_metrics(self):
        """Create ``self.metrics`` from ``self.datamodule`` and the hparams."""
        self.metrics = BaseMetrics(datamodule=self.datamodule, **self.hparams)

    def _prediction_filename(self, keyid):
        """Return the file name for the prediction of sample ``keyid``.

        With more than one test replication the replication index ``rep_i``
        is part of the name, so replications do not overwrite each other.
        """
        if self.hparams.cfg.TEST.REPLICATION_TIMES > 1:
            return f"{keyid}_{self.rep_i}.npy"
        return f"{keyid}.npy"

    def save_npy(self, outputs):
        """Write the test predictions to ``.npy`` files.

        Nothing is written unless ``cfg.TEST.SAVE_PREDICTIONS`` is truthy.
        Files go to ``<FOLDER>/<model module name, lowercased>/<NAME>/
        samples_<TIME>``, which is created when missing.

        For "humanml3d"/"kit" each sample produces the prediction, a
        ``<keyid>_gt.npy`` file and a ``<keyid>.txt`` file with its captions;
        for "humanact12"/"uestc" only the prediction is written. Any other
        dataset name writes nothing.

        Args:
            outputs: One item per test batch; item ``[0]`` is the batch of
                predictions and item ``[1]`` the per-sample lengths.
        """
        cfg = self.hparams.cfg
        output_dir = Path(
            os.path.join(
                cfg.FOLDER,
                str(cfg.model.target.split('.')[-2].lower()),
                str(cfg.NAME),
                "samples_" + cfg.TIME,
            ))

        if not cfg.TEST.SAVE_PREDICTIONS:
            return

        # np.save does not create missing parent directories.
        output_dir.mkdir(parents=True, exist_ok=True)

        lengths = [item[1] for item in outputs]
        outputs = [item[0] for item in outputs]
        batch_size = cfg.TEST.BATCH_SIZE
        dataset_name = cfg.TEST.DATASETS[0].lower()

        if dataset_name in ["humanml3d", "kit"]:
            test_dataset = self.trainer.datamodule.test_dataset
            keyids = test_dataset.name_list
            for i, (batch_out, batch_len) in enumerate(zip(outputs, lengths)):
                for bid in range(min(batch_size, batch_out.shape[0])):
                    keyid = keyids[i * batch_size + bid]
                    data = test_dataset.data_dict[keyid]

                    motion = torch.tensor(data['motion'],
                                          device=batch_out.device)
                    motion = self.datamodule.normalize(motion)
                    text_list = data['text']

                    # save predictions results
                    gen_joints = batch_out[bid][:batch_len[bid]].cpu().numpy()
                    np.save(output_dir / self._prediction_filename(keyid),
                            gen_joints)

                    # save the reference motion next to the prediction
                    joints = self.feats2joints(motion).cpu().numpy()
                    np.save(output_dir / f"{keyid}_gt.npy", joints)

                    with open(output_dir / f"{keyid}.txt", "a") as f:
                        for text in text_list:
                            f.write(f"{text['caption']}\n")

        elif dataset_name in ["humanact12", "uestc"]:
            keyids = range(len(self.trainer.datamodule.test_dataset))
            for i, (batch_out, batch_len) in enumerate(zip(outputs, lengths)):
                for bid in range(min(batch_size, batch_out.shape[0])):
                    keyid = keyids[i * batch_size + bid]
                    gen_joints = batch_out[bid].cpu()
                    # Move the last axis to the front, then cut to the length.
                    gen_joints = gen_joints.permute(
                        2, 0, 1)[:batch_len[bid], ...].numpy()

                    # save predictions results
                    np.save(output_dir / self._prediction_filename(keyid),
                            gen_joints)