from fake_face_detection.metrics.compute_metrics import compute_metrics
from fake_face_detection.data.collator import fake_face_collator
from transformers import Trainer, TrainingArguments, set_seed
from torch.utils.tensorboard import SummaryWriter
from torch import nn
from typing import Callable
import numpy as np
import json
import os

def train(epochs: int, output_dir: str, config: dict, model: Callable[[], nn.Module], trainer, get_datasets: Callable, log_dir: str = "fake_face_logs", metric: str = 'accuracy', seed: int = 0):
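    """Train a model with the given hyperparameter configuration and return its evaluation metric.
    
    Note: `model` must be a zero-argument callable (a class or factory) returning
    an nn.Module, and `trainer` a Trainer-like class such as transformers.Trainer.
    """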
    
    print("------------------------- Beginning of training")
    
    set_seed(seed)
    
    # instantiate the model from the provided factory
    model = model()
    
    # cast numpy integers to native ints so the config is JSON-serializable
    for key, value in config.items():
        
        if isinstance(value, np.integer): config[key] = int(value)
    
    pretty = json.dumps(config, indent = 4)
    
    print(f"Current Config: \n {pretty}")
    
    print(f"Checkpoints in {output_dir}")
    
    # load the train and test datasets, built with the sampled augmentation parameters
    train_dataset, test_dataset = get_datasets(config['h_flip_p'], config['v_flip_p'], config['gray_scale_p'], config['rotation'])
    
    # configure the training arguments
    training_args = TrainingArguments(output_dir,
                                      per_device_train_batch_size=config['batch_size'],
                                      evaluation_strategy='epoch',  # renamed eval_strategy in newer transformers versions
                                      save_strategy='epoch',
                                      logging_strategy='epoch',
                                      num_train_epochs=epochs,
                                      fp16=True,  # mixed precision training; requires a CUDA GPU
                                      save_total_limit=2,
                                      push_to_hub=False,
                                      logging_dir=os.path.join(log_dir, os.path.basename(output_dir)),
                                      load_best_model_at_end=True,
                                      learning_rate=config['lr']
                                      )
    
    # instantiate the trainer with the model, data, and metric function
    trainer_ = trainer(
        model = model,
        args = training_args,
        data_collator = fake_face_collator,
        compute_metrics = compute_metrics,
        train_dataset = train_dataset,
        eval_dataset = test_dataset
    )
    
    # train the model
    trainer_.train()
    
    # evaluate the model on the test set and collect the metrics
    metrics = trainer_.evaluate(test_dataset)
    
    # log the config and resulting metrics to the TensorBoard hyperparameter panel
    with SummaryWriter(os.path.join(log_dir, 'hparams')) as logger:
        
        logger.add_hparams(
            config, metrics
        )
    
    print(metrics)
    
    print("------------------------- End of training")
    # return the selected evaluation metric
    return metrics[f'eval_{metric}']
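
# Example usage (a minimal, illustrative sketch): `get_datasets` below stands in
# for the project's real dataset helper, and the ViT checkpoint and config values
# are assumed choices, not part of this module.
#
# from functools import partial
# from transformers import ViTForImageClassification, Trainer
#
# config = {'batch_size': 16, 'lr': 2e-5, 'h_flip_p': 0.5,
#           'v_flip_p': 0.5, 'gray_scale_p': 0.1, 'rotation': 10}
#
# model_factory = partial(ViTForImageClassification.from_pretrained,
#                         'google/vit-base-patch16-224-in21k', num_labels=2)
#
# score = train(epochs=5,
#               output_dir='fake_face_checkpoints',
#               config=config,
#               model=model_factory,
#               trainer=Trainer,
#               get_datasets=get_datasets)  # must return (train_dataset, test_dataset)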