File size: 2,358 Bytes
59b7eeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import os
 
import pytorch_lightning as pl
import hydra
import torch
import random
import time
from os.path import join, basename, exists
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.strategies import DDPStrategy, FSDPStrategy
from torch.utils.data import DataLoader
from data_module import DataModule
from lightning_module import CodecLightningModule
from pytorch_lightning.loggers import TensorBoardLogger  # Changed from WandbLogger
from omegaconf import OmegaConf
 
# Fixed seed applied at import time through Lightning's seed_everything, i.e.
# before train() is invoked and constructs the data module or the model.
seed = 1024
seed_everything(seed)
 
@hydra.main(config_path='config', config_name='default', version_base=None)
def train(cfg):
    """Train the codec model described by the Hydra config ``cfg``.

    Config fields read here:
        cfg.log_dir: directory used for checkpoints and TensorBoard logs,
            and searched for ``last.ckpt`` to resume from.
        cfg.train.trainer: mapping unpacked as keyword arguments into
            ``pl.Trainer``. It must not contain ``strategy``, ``callbacks``,
            ``logger``, ``profiler`` or ``limit_train_batches``; those are
            passed explicitly below and a duplicate keyword raises
            ``TypeError``.
        cfg.debug: when truthy, only a 0.001 fraction of the training
            batches is used.

    ``cfg`` is also handed as-is to ``DataModule`` and
    ``CodecLightningModule``; whatever they read from it is not visible here.
    """
    # Checkpoint every 5000 training steps, keeping the five checkpoints with
    # the lowest monitored 'mel_loss' plus the most recent one (last.ckpt).
    checkpoint_callback = ModelCheckpoint(
        dirpath=cfg.log_dir,
        save_top_k=5,
        save_last=True,
        every_n_train_steps=5000,
        monitor='mel_loss',
        mode='min',
    )
    lr_monitor = LearningRateMonitor(logging_interval='step')
    callbacks = [checkpoint_callback, lr_monitor]

    datamodule = DataModule(cfg)
    lightning_module = CodecLightningModule(cfg)

    # Empty name/version: presumably so that event files are written directly
    # into cfg.log_dir rather than a nested run directory (confirm against the
    # TensorBoardLogger docs for the installed Lightning version).
    tensorboard_logger = TensorBoardLogger(
        save_dir=cfg.log_dir,
        name="",
        version="",
        log_graph=False,
        default_hp_metric=True,
    )

    # Resume automatically when a previous run left last.ckpt in the log dir.
    ckpt_path = None
    last_ckpt = os.path.join(cfg.log_dir, 'last.ckpt')
    if os.path.exists(last_ckpt):
        ckpt_path = last_ckpt
        print(f"Resuming from checkpoint: {ckpt_path}")
    else:
        print("No checkpoint found, starting training from scratch.")

    trainer = pl.Trainer(
        **cfg.train.trainer,
        strategy=DDPStrategy(find_unused_parameters=True),
        callbacks=callbacks,
        logger=tensorboard_logger,
        profiler="simple",
        limit_train_batches=1.0 if not cfg.debug else 0.001,
    )

    torch.backends.cudnn.benchmark = True

    trainer.fit(lightning_module, datamodule=datamodule, ckpt_path=ckpt_path)

    print(f'Training ends, best score: {checkpoint_callback.best_model_score}, ckpt path: {checkpoint_callback.best_model_path}')

if __name__ == '__main__':
    # Script entry point. The multiprocessing start method is forced to
    # 'spawn' (force=True replaces any method that was already set) before
    # the Hydra-decorated train() is invoked.
    # NOTE(review): presumably chosen so child processes do not inherit
    # CUDA state from a forked parent - confirm.
    torch.multiprocessing.set_start_method(method='spawn', force=True)
    train()