dnnsdunca committed on
Commit
a4d6124
1 Parent(s): be0843c

Create src/train.py

Browse files
Files changed (1) hide show
  1. src/train.py +41 -0
src/train.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import Trainer, TrainingArguments
3
+ from model import get_model_and_tokenizer
4
+ from data_loader import get_dataloader
5
+ from utils import load_config, set_seed
6
+
7
def main():
    """Fine-tune the configured model with the Hugging Face ``Trainer``.

    Loads ``configs/model_config.yaml``, seeds the run, builds the model,
    tokenizer and the train/validation data, runs training, and saves the
    resulting model to ``./final_model``. ``TrainingArguments`` is given
    ``./results`` as its output directory and ``./logs`` as its logging
    directory.

    Keys read from the ``training`` section of the config: ``seed``,
    ``num_epochs``, ``batch_size`` and ``save_every``.
    """
    config = load_config('configs/model_config.yaml')
    training_cfg = config['training']
    set_seed(training_cfg['seed'])

    model, tokenizer = get_model_and_tokenizer(config)
    train_dataloader = get_dataloader(config, tokenizer, 'train')
    val_dataloader = get_dataloader(config, tokenizer, 'validation')

    # With load_best_model_at_end=True, TrainingArguments raises ValueError
    # unless the step-based save interval is a round multiple of the eval
    # interval. Keep the 1000-step evaluation cadence whenever the configured
    # save interval is compatible with it; otherwise evaluate exactly at each
    # checkpoint so the run does not fail before training starts.
    save_steps = training_cfg['save_every']
    eval_steps = 1000
    if save_steps % eval_steps != 0:
        eval_steps = save_steps

    # NOTE(review): recent transformers releases rename `evaluation_strategy`
    # to `eval_strategy` -- confirm against the pinned transformers version.
    training_args = TrainingArguments(
        output_dir="./results",
        num_train_epochs=training_cfg['num_epochs'],
        per_device_train_batch_size=training_cfg['batch_size'],
        per_device_eval_batch_size=training_cfg['batch_size'],
        warmup_steps=500,
        weight_decay=0.01,
        logging_dir='./logs',
        logging_steps=100,
        evaluation_strategy="steps",
        eval_steps=eval_steps,
        # Must match the evaluation strategy for load_best_model_at_end.
        save_strategy="steps",
        save_steps=save_steps,
        load_best_model_at_end=True,
    )

    # Trainer constructs its own DataLoaders, so only the underlying datasets
    # are handed over here.
    # NOTE(review): any collate_fn/sampler configured by get_dataloader is
    # therefore not used by Trainer -- confirm that is intended.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataloader.dataset,
        eval_dataset=val_dataloader.dataset,
    )

    trainer.train()
    trainer.save_model("./final_model")
39
+
40
# Start training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()