"""Hydra entry point that trains ``CodecLightningModule`` with PyTorch Lightning."""

import os
import random
import time
from os.path import join, basename, exists

import hydra
import pytorch_lightning as pl
import torch
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger  # Changed from WandbLogger
from pytorch_lightning.strategies import DDPStrategy, FSDPStrategy
from torch.utils.data import DataLoader

from data_module import DataModule
from lightning_module import CodecLightningModule

# Seeding is done at import time (before Hydra builds the config), so the
# value is fixed here rather than read from ``cfg``.
seed = 1024
seed_everything(seed)


def _find_resume_checkpoint(log_dir):
    """Return the path of ``last.ckpt`` inside *log_dir*, or ``None`` if absent.

    Prints which of the two cases applies so the choice is visible in the
    run's console output.
    """
    last_ckpt = os.path.join(log_dir, 'last.ckpt')
    if os.path.exists(last_ckpt):
        print(f"Resuming from checkpoint: {last_ckpt}")
        return last_ckpt
    print("No checkpoint found, starting training from scratch.")
    return None


@hydra.main(config_path='config', config_name='default', version_base=None)
def train(cfg):
    """Build the callbacks, logger and trainer from *cfg* and run ``fit``.

    Config keys read here:
        cfg.log_dir: directory used for checkpoints, TensorBoard event files
            and the ``last.ckpt`` resume lookup.
        cfg.train.trainer: mapping unpacked as keyword arguments into
            ``pl.Trainer``. It must not define ``strategy``, ``callbacks``,
            ``logger``, ``profiler`` or ``limit_train_batches``, because those
            are passed explicitly below and a duplicate keyword raises
            ``TypeError``.
        cfg.debug: when truthy, only a 0.001 fraction of the training batches
            is used per epoch.

    ``cfg`` is also handed unchanged to ``DataModule`` and
    ``CodecLightningModule``.
    """
    # Checkpoint every 5000 training steps; keep the 5 checkpoints with the
    # lowest monitored 'mel_loss' and additionally save the most recent one
    # (save_last=True), which is what the resume lookup below relies on.
    checkpoint_callback = ModelCheckpoint(
        dirpath=cfg.log_dir,
        save_top_k=5,
        save_last=True,
        every_n_train_steps=5000,
        monitor='mel_loss',
        mode='min',
    )
    lr_monitor = LearningRateMonitor(logging_interval='step')
    callbacks = [checkpoint_callback, lr_monitor]

    datamodule = DataModule(cfg)
    lightning_module = CodecLightningModule(cfg)

    # Empty name/version keep the logger from adding sub-directories, so the
    # event files are written directly into cfg.log_dir.
    tensorboard_logger = TensorBoardLogger(
        save_dir=cfg.log_dir,
        name="",
        version="",
        log_graph=False,
        default_hp_metric=True,
    )

    ckpt_path = _find_resume_checkpoint(cfg.log_dir)

    trainer = pl.Trainer(
        **cfg.train.trainer,
        # find_unused_parameters=True makes DDP tolerate parameters that get
        # no gradient in a given step.
        strategy=DDPStrategy(find_unused_parameters=True),
        callbacks=callbacks,
        logger=tensorboard_logger,
        profiler="simple",
        limit_train_batches=0.001 if cfg.debug else 1.0,
    )

    torch.backends.cudnn.benchmark = True

    # ckpt_path=None starts from scratch; otherwise fit() resumes from it.
    trainer.fit(lightning_module, datamodule=datamodule, ckpt_path=ckpt_path)
    print(f'Training ends, best score: {checkpoint_callback.best_model_score}, ckpt path: {checkpoint_callback.best_model_path}')


if __name__ == '__main__':
    # Force the 'spawn' start method for torch.multiprocessing before any
    # training processes or workers are created.
    torch.multiprocessing.set_start_method('spawn', force=True)
    train()