File size: 1,115 Bytes
74b29ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import pytorch_lightning as pl

from util.preprocessor import Preprocessor
from model.bilstm import BiLSTM
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

# Experiment configuration, collected in one place so runs can be tweaked easily.
SEED = 42
BATCH_SIZE = 64
LEARNING_RATE = 1e-3
MAX_EPOCHS = 100
EARLY_STOP_PATIENCE = 10
RUN_NAME = "bilstm_result"
CHECKPOINT_DIR = f"./checkpoints/{RUN_NAME}"
LOG_DIR = "log"


def main():
    """Train the BiLSTM on the Preprocessor data module, then test the best checkpoint.

    Seeds the RNGs, builds the data module and model, and trains for up to
    ``MAX_EPOCHS`` epochs. Training stops early once ``val_loss`` stops
    improving for ``EARLY_STOP_PATIENCE`` checks. Finally, the model is
    evaluated on the test split using the best checkpoint saved by
    ``ModelCheckpoint``.
    """
    # workers=True also configures DataLoader worker seeding, so data loading
    # is reproducible as well. This matches the deterministic=True intent below.
    pl.seed_everything(SEED, workers=True)

    datamodule = Preprocessor(batch_size=BATCH_SIZE)
    num_classes, input_size = datamodule.get_feature_size()

    model = BiLSTM(lr=LEARNING_RATE, num_classes=num_classes, input_size=input_size)

    # Tracks the best checkpoint by val_loss. This relies on ModelCheckpoint's
    # default mode ("min" per the Lightning docs), i.e. lower val_loss is better.
    checkpoint_callback = ModelCheckpoint(dirpath=CHECKPOINT_DIR, monitor="val_loss")
    logger = TensorBoardLogger(LOG_DIR, name=RUN_NAME)
    # check_on_train_epoch_end is a boolean flag: pass True rather than the int 1.
    early_stop_callback = EarlyStopping(
        monitor="val_loss",
        min_delta=0.0,
        check_on_train_epoch_end=True,
        patience=EARLY_STOP_PATIENCE,
    )

    trainer = pl.Trainer(
        # Requests a GPU explicitly. Switch to "auto" if CPU fallback is acceptable.
        accelerator="gpu",
        max_epochs=MAX_EPOCHS,
        default_root_dir=CHECKPOINT_DIR,
        callbacks=[checkpoint_callback, early_stop_callback],
        deterministic=True,
        logger=logger,
    )

    trainer.fit(model=model, datamodule=datamodule)
    # ckpt_path="best" evaluates the best checkpoint recorded during fit(),
    # not the final-epoch weights.
    trainer.test(model=model, datamodule=datamodule, ckpt_path="best")


if __name__ == "__main__":
    main()