File size: 825 Bytes
3f8d76d
 
96c4203
410b92f
c015c4c
aef2f7d
c015c4c
 
 
 
3f8d76d
 
 
c015c4c
3f8d76d
 
c015c4c
 
3f8d76d
 
 
 
 
 
 
 
0ee5810
aef2f7d
0ee5810
aef2f7d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import yaml

from src.models.model import Summarization
import pandas as pd


def train_model(*, params_path="params.yml",
                train_path='data/processed/train.csv',
                eval_path='data/processed/validation.csv'):
    """
    Train a summarization model and save it to disk.

    Hyperparameters come from a YAML file. The keys read are
    ``model_type``, ``model_name``, ``batch_size``, ``max_epoch``,
    ``use_gpu``, ``learning_rate``, ``num_workers`` and ``model_dir``.

    Args:
        params_path: Path to the YAML parameter file.
        train_path: CSV file holding the training split.
        eval_path: CSV file holding the validation split.

    All arguments are keyword-only and default to the paths this script
    has always used, so ``train_model()`` behaves as before.
    """
    # Explicit encoding so the config parses the same regardless of locale.
    with open(params_path, encoding="utf-8") as f:
        params = yaml.safe_load(f)

    # Load the data
    train_df = pd.read_csv(train_path)
    eval_df = pd.read_csv(eval_path)

    model = Summarization()
    model.from_pretrained(model_type=params['model_type'], model_name=params['model_name'])

    # NOTE: the YAML key is the singular 'max_epoch'; it is passed on as
    # the plural 'max_epochs' keyword.
    model.train(train_df=train_df, eval_df=eval_df,
                batch_size=params['batch_size'], max_epochs=params['max_epoch'],
                use_gpu=params['use_gpu'], learning_rate=params['learning_rate'],
                num_workers=params['num_workers'])

    model.save_model(model_dir=params['model_dir'])


if __name__ == "__main__":
    # Run as a script: train using the default config and data paths.
    train_model()