# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/train_t2s.py
import argparse
import logging
import os
from pathlib import Path

import torch
from pytorch_lightning import seed_everything
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.strategies import DDPStrategy

from AR.data.data_module_librilight_6k import Text2SemanticDataModule
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from soundstorm.utils import get_newest_ckpt
from soundstorm.utils.io import load_yaml_config
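
# quiet noisy third-party loggers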
logging.getLogger('numba').setLevel(logging.WARNING)
logging.getLogger('matplotlib').setLevel(logging.WARNING)
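
# trade a little float32 matmul precision for speed (enables TF32 on supported GPUs)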
torch.set_float32_matmul_precision('high')


def main(args):
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    ckpt_dir = output_dir / 'ckpt'
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    config = load_yaml_config(args.config_file)
    seed_everything(config["train"]["seed"], workers=True)

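    # keep every checkpoint (save_top_k=-1), saving every `every_n_train_steps`
    # training steps rather than at epoch boundaries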
    ckpt_callback: ModelCheckpoint = ModelCheckpoint(
        save_top_k=-1,
        save_on_train_epoch_end=False,
        every_n_train_steps=config["train"]["every_n_train_steps"],
        dirpath=ckpt_dir)
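
    # track the run with Weights & Biases, named after the experiment directory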
    logger = WandbLogger(
        project="AR_S1_LibriLight",
        name=output_dir.stem,
        save_dir=output_dir,
        # resume the loss curve
        resume=True,
        # id='k19kvsq8'
    )
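
    # find_unused_parameters=True lets DDP tolerate parameters that receive no
    # gradient in a given step, at some extra synchronization cost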
    trainer: Trainer = Trainer(
        max_epochs=config["train"]["epochs"],
        accelerator='gpu',
        devices=-1,  # use all visible GPUs
        benchmark=False,
        fast_dev_run=False,
        strategy=DDPStrategy(find_unused_parameters=True),
        precision=config["train"]["precision"],
        logger=logger,
        callbacks=[ckpt_callback])
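
    # model and data module are both built from the same YAML config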
    model: Text2SemanticLightningModule = Text2SemanticLightningModule(
        config, output_dir)
    data_module: Text2SemanticDataModule = Text2SemanticDataModule(
        config,
        train_semantic_dirs=args.train_semantic_dirs,
        train_phoneme_dirs=args.train_phoneme_dirs,
        dev_semantic_dirs=args.dev_semantic_dirs,
        dev_phoneme_dirs=args.dev_phoneme_dirs,
        train_non_speech_dirs=args.train_non_speech_dirs,
        dev_non_speech_dirs=args.dev_non_speech_dirs)
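
    # resume from the newest checkpoint in ckpt_dir when one exists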
    try:
        newest_ckpt_name = get_newest_ckpt(os.listdir(ckpt_dir))
        ckpt_path = ckpt_dir / newest_ckpt_name
    except Exception:
        # no usable checkpoint found; start training from scratch
        ckpt_path = None
    print("ckpt_path:", ckpt_path)

    trainer.fit(model, data_module, ckpt_path=ckpt_path)


# srun --gpus-per-node=1 --ntasks-per-node=1 python train.py --config_file conf/default.yaml
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--config_file',
        type=str,
        default='conf/default.yaml',
        help='path of config file')
    # args for dataset; nargs='+' accepts one or more directories per flag
    # (type=str keeps each path intact instead of splitting it into characters)
    parser.add_argument(
        '--train_semantic_dirs',
        type=str,
        nargs='+',
        default=["dump/small/train/"],
        help='dirs of train semantic tokens')
    parser.add_argument(
        '--train_phoneme_dirs',
        type=str,
        nargs='+',
        default=["dump/small/train/"],
        help='dirs of train phonemes')
    parser.add_argument(
        '--dev_semantic_dirs',
        type=str,
        nargs='+',
        default=["dump/small/dev/"],
        help='dirs of dev semantic tokens')
    parser.add_argument(
        '--dev_phoneme_dirs',
        type=str,
        nargs='+',
        default=["dump/small/dev/"],
        help='dirs of dev phonemes')
    parser.add_argument(
        '--output_dir',
        type=str,
        default='exp/default',
        help='directory to save the results')
    parser.add_argument(
        '--train_non_speech_dirs',
        type=str,
        nargs='+',
        default=None,
        help='dirs of train non-speech data')
    parser.add_argument(
        '--dev_non_speech_dirs',
        type=str,
        nargs='+',
        default=None,
        help='dirs of dev non-speech data')

    args = parser.parse_args()

    logging.info(str(args))
    main(args)
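
# Example multi-corpus invocation (directory names are illustrative):
# python train.py \
#     --config_file conf/default.yaml \
#     --train_semantic_dirs dump/small/train/ dump/medium/train/ \
#     --train_phoneme_dirs dump/small/train/ dump/medium/train/ \
#     --output_dir exp/ar_s1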