|
import pytorch_lightning as pl |
|
|
|
from util.preprocessor import Preprocessor |
|
from model.bilstm import BiLSTM |
|
from pytorch_lightning.loggers import TensorBoardLogger |
|
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint |
|
|
|
def main(
    batch_size: int = 64,
    lr: float = 1e-3,
    max_epochs: int = 100,
    patience: int = 10,
    seed: int = 42,
    accelerator: str = "gpu",
    run_name: str = "bilstm_result",
) -> None:
    """Train the BiLSTM model, then run the test loop on the best checkpoint.

    Defaults reproduce the original hard-coded experiment settings.

    Args:
        batch_size: Batch size passed to ``Preprocessor``.
        lr: Learning rate passed to ``BiLSTM``.
        max_epochs: Upper bound on training epochs; early stopping may stop sooner.
        patience: Number of validation checks without ``val_loss`` improvement
            before early stopping triggers.
        seed: Global seed passed to ``pl.seed_everything``.
        accelerator: Lightning accelerator name (e.g. ``"gpu"``, ``"cpu"``, ``"auto"``).
        run_name: Name shared by the checkpoint sub-directory and the
            TensorBoard experiment, so the two cannot drift apart.
    """
    pl.seed_everything(seed)

    datamodule = Preprocessor(batch_size=batch_size)
    # Unpacking order (num_classes, input_size) is taken from the original
    # script; presumably matches Preprocessor.get_feature_size — verify there.
    num_classes, input_size = datamodule.get_feature_size()

    model = BiLSTM(lr=lr, num_classes=num_classes, input_size=input_size)

    # One source of truth for the run's output directory (previously duplicated).
    checkpoint_dir = f"./checkpoints/{run_name}"

    # Keep the single best checkpoint by lowest val_loss; trainer.test(ckpt_path="best")
    # below relies on this callback having a monitored metric.
    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_dir,
        monitor="val_loss",
        mode="min",
    )
    logger = TensorBoardLogger("log", name=run_name)
    # val_loss is produced by the validation loop, so evaluate the stopping
    # criterion after validation rather than at the end of the training epoch.
    # (The original passed the int 1 to this bool flag.)
    early_stop_callback = EarlyStopping(
        monitor="val_loss",
        mode="min",
        min_delta=0.0,
        patience=patience,
        check_on_train_epoch_end=False,
    )

    trainer = pl.Trainer(
        accelerator=accelerator,
        max_epochs=max_epochs,
        default_root_dir=checkpoint_dir,
        callbacks=[checkpoint_callback, early_stop_callback],
        deterministic=True,
        logger=logger,
    )

    trainer.fit(model=model, datamodule=datamodule)
    # Evaluate the best (lowest val_loss) checkpoint, not the last-epoch weights.
    trainer.test(model=model, datamodule=datamodule, ckpt_path="best")


if __name__ == "__main__":
    main()