File size: 481 Bytes
96c4203
 
c015c4c
 
 
 
 
 
 
0ee5810
c015c4c
 
 
 
0ee5810
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
from src.models.model import Summarization
from src.data.make_dataset import make_dataset

def train_model(*, model_name='t5-base', batch_size=4, max_epochs=3, use_gpu=True):
    """
    Fine-tune a ``Summarization`` model and save it.

    Loads the ``'train'`` and ``'val'`` splits via ``make_dataset``, loads
    pretrained weights into a fresh ``Summarization`` model, trains it on
    those splits, and finally calls ``save_model()``. Calling with no
    arguments reproduces the original hard-coded configuration.

    Args:
        model_name: Checkpoint identifier passed to
            ``Summarization.from_pretrained`` (default ``'t5-base'``).
        batch_size: Forwarded to ``Summarization.train`` as ``batch_size``.
        max_epochs: Forwarded to ``Summarization.train`` as ``max_epochs``.
        use_gpu: Forwarded to ``Summarization.train`` as ``use_gpu``;
            presumably selects GPU training — confirm against ``Summarization``.
    """
    # Load the training and validation splits.
    train_df = make_dataset(split='train')
    eval_df = make_dataset(split='val')

    model = Summarization()
    model.from_pretrained(model_name)
    model.train(
        train_df=train_df,
        eval_df=eval_df,
        batch_size=batch_size,
        max_epochs=max_epochs,
        use_gpu=use_gpu,
    )
    model.save_model()

if __name__ == "__main__":
    # Script entry point: run training with the default configuration.
    train_model()