File size: 951 Bytes
3f8d76d
 
9d5ed04
410b92f
c015c4c
aef2f7d
c015c4c
 
 
 
d5a6d18
3f8d76d
 
c015c4c
c6e4955
 
c015c4c
d5a6d18
 
698a370
c015c4c
c6e4955
 
 
3f8d76d
c6e4955
 
 
 
 
 
 
 
 
3f8d76d
c6e4955
0ee5810
aef2f7d
c6e4955
aef2f7d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from pathlib import Path

import pandas as pd
import yaml

from model import Summarization


def train_model(
    params_path="model_params.yml",
    data_dir="data/processed",
    random_state=1,
):
    """
    Train a summarization model and save it to disk.

    Hyperparameters are read from a YAML file. The training and validation
    data are read from ``train.csv`` and ``validation.csv`` under *data_dir*.
    Both splits are shuffled reproducibly, the model is trained, and it is
    saved to ``params["model_dir"]``.

    Args:
        params_path: Path to the YAML hyperparameter file. The file must
            define ``model_type``, ``model_name``, ``batch_size``, ``epochs``,
            ``use_gpu``, ``learning_rate``, ``num_workers`` and ``model_dir``.
        data_dir: Directory containing ``train.csv`` and ``validation.csv``.
        random_state: Seed used when shuffling both splits.

    Raises:
        FileNotFoundError: If the params file or either CSV file is missing.
        KeyError: If a required key is absent from the params file.
    """
    with open(params_path, encoding="utf-8") as f:
        params = yaml.safe_load(f)

    # Load the data.
    data_path = Path(data_dir)
    train_df = pd.read_csv(data_path / "train.csv")
    eval_df = pd.read_csv(data_path / "validation.csv")

    # Shuffle every row reproducibly. DataFrame.sample() with neither `n` nor
    # `frac` returns a single row, so `frac=1` is required to keep the full
    # dataset.
    train_df = train_df.sample(frac=1, random_state=random_state)
    eval_df = eval_df.sample(frac=1, random_state=random_state)

    model = Summarization()
    model.from_pretrained(
        model_type=params["model_type"], model_name=params["model_name"]
    )

    model.train(
        train_df=train_df,
        eval_df=eval_df,
        batch_size=params["batch_size"],
        max_epochs=params["epochs"],
        use_gpu=params["use_gpu"],
        # PyYAML may load scientific notation without a decimal point
        # (e.g. `1e-4`) as a string, so coerce explicitly.
        learning_rate=float(params["learning_rate"]),
        num_workers=int(params["num_workers"]),
    )

    model.save_model(model_dir=params["model_dir"])


if __name__ == "__main__":
    # Script entry point: train using the default config and data locations.
    train_model()