from fake_face_detection.metrics.compute_metrics import compute_metrics
from fake_face_detection.data.collator import fake_face_collator
from transformers import Trainer, TrainingArguments, set_seed
from torch.utils.tensorboard import SummaryWriter
from torch import nn
from typing import Callable
import numpy as np
import json
import os
def train(
    epochs: int,
    output_dir: str,
    config: dict,
    model: Callable[[], nn.Module],
    trainer,
    get_datasets: Callable,
    log_dir: str = "fake_face_logs",
    metric: str = "accuracy",
    seed: int = 0,
):
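    """Train and evaluate a model with the Hugging Face Trainer.

    Args:
        epochs: number of training epochs.
        output_dir: directory where checkpoints are written.
        config: hyperparameter dict; must provide 'batch_size', 'lr',
            'h_flip_p', 'v_flip_p', 'gray_scale_p' and 'rotation'.
        model: zero-argument factory returning a fresh nn.Module.
        trainer: the Trainer class (or a compatible subclass) to instantiate.
        get_datasets: callable building the augmented train/test datasets
            from the four augmentation parameters.
        log_dir: root directory for the TensorBoard logs.
        metric: key (without the 'eval_' prefix) of the metric to return.
        seed: random seed passed to transformers.set_seed.

    Returns:
        The evaluation metric selected by `metric`.
    """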
print("------------------------- Beginning of training")
set_seed(seed)
# initialize the model
model = model()
# reformat the config integer type
for key, value in config.items():
if isinstance(value, np.int32): config[key] = int(value)
pretty = json.dumps(config, indent = 4)
print(f"Current Config: \n {pretty}")
print(f"Checkpoints in {output_dir}")
# recuperate the dataset
train_dataset, test_dataset = get_datasets(config['h_flip_p'], config['v_flip_p'], config['gray_scale_p'], config['rotation'])
    # configure the training arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=config['batch_size'],
        evaluation_strategy='epoch',
        save_strategy='epoch',
        logging_strategy='epoch',
        num_train_epochs=epochs,
        fp16=True,  # mixed-precision training; assumes a CUDA-capable GPU
        save_total_limit=2,
        push_to_hub=False,
        logging_dir=os.path.join(log_dir, os.path.basename(output_dir)),
        load_best_model_at_end=True,
        learning_rate=config['lr'],
    )
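    # Note: with load_best_model_at_end=True, the eval and save strategies must
    # match (both 'epoch' here), and the Trainer keeps the best checkpoint even
    # though save_total_limit=2 rotates older ones.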
    # initialize the trainer
    trainer_ = trainer(
        model=model,
        args=training_args,
        data_collator=fake_face_collator,
        compute_metrics=compute_metrics,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
    )
    # train the model
    trainer_.train()

    # evaluate the model and collect the metrics
    metrics = trainer_.evaluate(test_dataset)

    # log the config and metrics to the TensorBoard hyperparameter panel
    with SummaryWriter(os.path.join(log_dir, 'hparams')) as logger:
        logger.add_hparams(config, metrics)

    print(metrics)

    print("------------------------- End of training")

    # return the metric used for model selection
    return metrics[f'eval_{metric}']
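

# ---------------------------------------------------------------------------
# Hedged usage sketch (illustrative, not part of the original module).
# `build_model` and the `get_datasets` stub below are hypothetical stand-ins
# for the project's real factories; the ViT checkpoint name and the config
# values are assumptions. Only the `train` signature above is authoritative.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from functools import partial

    from transformers import Trainer, ViTForImageClassification

    # zero-argument model factory, as `train` expects (hypothetical checkpoint)
    build_model = partial(
        ViTForImageClassification.from_pretrained,
        "google/vit-base-patch16-224-in21k",
        num_labels=2,
    )

    def get_datasets(h_flip_p, v_flip_p, gray_scale_p, rotation):
        """Hypothetical stub: the real project returns the augmented
        (train_dataset, test_dataset) pair built from these parameters."""
        raise NotImplementedError("plug in the project's dataset builder")

    # illustrative hyperparameters matching the keys `train` reads
    config = {
        "batch_size": 8,
        "lr": 2e-5,
        "h_flip_p": 0.5,
        "v_flip_p": 0.5,
        "gray_scale_p": 0.2,
        "rotation": 15,
    }

    best_metric = train(
        epochs=5,
        output_dir="checkpoints/vit_example",
        config=config,
        model=build_model,
        trainer=Trainer,
        get_datasets=get_datasets,
    )
    print(f"eval accuracy: {best_metric}")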