|
|
|
import os |
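
# Re-export the GPU mask stashed in _CUDA_VISIBLE_DEVICES (presumably set by whatever
# launched this script) so the devices selected below match it.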
|
if "_CUDA_VISIBLE_DEVICES" in os.environ: |
|
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"] |
|
import argparse |
|
import logging |
|
from pathlib import Path |
|
|
|
import platform

import torch
|
from pytorch_lightning import seed_everything |
|
from pytorch_lightning import Trainer |
|
from pytorch_lightning.callbacks import ModelCheckpoint |
|
from pytorch_lightning.loggers import TensorBoardLogger |
|
from pytorch_lightning.strategies import DDPStrategy |
|
from AR.data.data_module import Text2SemanticDataModule |
|
from AR.models.t2s_lightning_module import Text2SemanticLightningModule |
|
from AR.utils.io import load_yaml_config |
|
|
|
logging.getLogger("numba").setLevel(logging.WARNING) |
|
logging.getLogger("matplotlib").setLevel(logging.WARNING) |
|
torch.set_float32_matmul_precision("high") |
|
from AR.utils import get_newest_ckpt |
|
|
|
from collections import OrderedDict |
|
from time import time as ttime |
|
import shutil |
|
def my_save(fea, path):
    """Save `fea` to `path` by writing to a temporary file first and then moving it
    into place, so an interrupted write does not clobber `path` with a partial file."""
    dir_name = os.path.dirname(path)
    name = os.path.basename(path)
    tmp_path = "%s.pth" % (ttime())  # temporary file in the current working directory
    torch.save(fea, tmp_path)
    shutil.move(tmp_path, "%s/%s" % (dir_name, name))
|
|
|
|
|
class my_model_ckpt(ModelCheckpoint): |
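    """ModelCheckpoint that can keep only the latest checkpoint and additionally export
    half-precision weight files suitable for inference."""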
|
def __init__( |
|
self, |
|
config, |
|
if_save_latest, |
|
if_save_every_weights, |
|
half_weights_save_dir, |
|
exp_name, |
|
**kwargs |
|
): |
|
super().__init__(**kwargs) |
|
self.if_save_latest = if_save_latest |
|
self.if_save_every_weights = if_save_every_weights |
|
self.half_weights_save_dir = half_weights_save_dir |
|
self.exp_name = exp_name |
|
self.config = config |
|
|
|
def on_train_epoch_end(self, trainer, pl_module): |
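        # On each epoch end: write top-k / every-n checkpoints, optionally prune older
        # checkpoint files (if_save_latest), and optionally dump fp16 inference weights
        # (if_save_every_weights).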
|
|
|
if self._should_save_on_train_epoch_end(trainer): |
|
monitor_candidates = self._monitor_candidates(trainer) |
|
if ( |
|
self._every_n_epochs >= 1 |
|
and (trainer.current_epoch + 1) % self._every_n_epochs == 0 |
|
): |
|
                if self.if_save_latest:
                    # Snapshot the checkpoint files that already exist so the older ones
                    # can be pruned after this epoch's checkpoint is written.
                    to_clean = list(os.listdir(self.dirpath))
|
self._save_topk_checkpoint(trainer, monitor_candidates) |
|
                if self.if_save_latest:
                    # Keep only the newest checkpoint: delete everything recorded above.
                    for name in to_clean:
                        try:
                            os.remove("%s/%s" % (self.dirpath, name))
                        except Exception:
                            pass
|
                if self.if_save_every_weights:
                    # Also export a half-precision (fp16) copy of the weights, together
                    # with the training config, for inference use.
                    to_save_od = OrderedDict()
                    to_save_od["weight"] = OrderedDict()
                    state_dict = trainer.strategy._lightning_module.state_dict()
                    for key in state_dict:
                        to_save_od["weight"][key] = state_dict[key].half()
                    to_save_od["config"] = self.config
                    to_save_od["info"] = "GPT-e%s" % (trainer.current_epoch + 1)
|
|
|
my_save( |
|
to_save_od, |
|
"%s/%s-e%s.ckpt" |
|
% ( |
|
self.half_weights_save_dir, |
|
self.exp_name, |
|
trainer.current_epoch + 1, |
|
), |
|
) |
|
self._save_last_checkpoint(trainer, monitor_candidates) |
|
|
|
|
|
def main(args): |
|
config = load_yaml_config(args.config_file) |
|
|
|
output_dir = Path(config["output_dir"]) |
|
output_dir.mkdir(parents=True, exist_ok=True) |
|
|
|
ckpt_dir = output_dir / "ckpt" |
|
ckpt_dir.mkdir(parents=True, exist_ok=True) |
|
|
|
seed_everything(config["train"]["seed"], workers=True) |
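
    # Checkpoint callback: saves every `save_every_n_epoch` epochs (tracking top_3_acc),
    # with optional pruning of old checkpoints and fp16 weight export (see my_model_ckpt).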
|
ckpt_callback: ModelCheckpoint = my_model_ckpt( |
|
config=config, |
|
if_save_latest=config["train"]["if_save_latest"], |
|
if_save_every_weights=config["train"]["if_save_every_weights"], |
|
half_weights_save_dir=config["train"]["half_weights_save_dir"], |
|
exp_name=config["train"]["exp_name"], |
|
save_top_k=-1, |
|
monitor="top_3_acc", |
|
mode="max", |
|
save_on_train_epoch_end=True, |
|
every_n_epochs=config["train"]["save_every_n_epoch"], |
|
dirpath=ckpt_dir, |
|
) |
|
logger = TensorBoardLogger(name=output_dir.stem, save_dir=output_dir) |
|
os.environ["MASTER_ADDR"]="localhost" |
|
trainer: Trainer = Trainer( |
|
max_epochs=config["train"]["epochs"], |
|
accelerator="gpu" if torch.cuda.is_available() else "cpu", |
|
|
|
|
|
limit_val_batches=0, |
|
devices=-1 if torch.cuda.is_available() else 1, |
|
benchmark=False, |
|
fast_dev_run=False, |
|
        # NCCL is not available on Windows, so fall back to the gloo backend there;
        # without a GPU, let Lightning pick the strategy automatically.
        strategy=DDPStrategy(
            process_group_backend="nccl" if platform.system() != "Windows" else "gloo"
        )
        if torch.cuda.is_available()
        else "auto",
|
precision=config["train"]["precision"], |
|
logger=logger, |
|
num_sanity_val_steps=0, |
|
callbacks=[ckpt_callback], |
|
) |
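
    # Build the text-to-semantic model and the data module that feeds it.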
|
|
|
model: Text2SemanticLightningModule = Text2SemanticLightningModule( |
|
config, output_dir |
|
) |
|
|
|
data_module: Text2SemanticDataModule = Text2SemanticDataModule( |
|
config, |
|
train_semantic_path=config["train_semantic_path"], |
|
        train_phoneme_path=config["train_phoneme_path"],
    )
|
|
|
    try:
        # Resume from the most recent checkpoint in ckpt_dir, if one exists.
        newest_ckpt_name = get_newest_ckpt(os.listdir(ckpt_dir))
        ckpt_path = ckpt_dir / newest_ckpt_name
    except Exception:
        ckpt_path = None
|
print("ckpt_path:", ckpt_path) |
|
trainer.fit(model, data_module, ckpt_path=ckpt_path) |
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
parser = argparse.ArgumentParser() |
|
parser.add_argument( |
|
"-c", |
|
"--config_file", |
|
type=str, |
|
default="configs/s1longer.yaml", |
|
help="path of config file", |
|
    )

    args = parser.parse_args()
|
logging.info(str(args)) |
|
main(args) |
|
|