from fake_face_detection.metrics.compute_metrics import compute_metrics
from fake_face_detection.data.collator import fake_face_collator
from transformers import Trainer, TrainingArguments, set_seed
from torch.utils.tensorboard import SummaryWriter
from torch import nn
from typing import Callable
import numpy as np
import json
import os
def train(
    epochs: int,
    output_dir: str,
    config: dict,
    model: Callable[[], nn.Module],
    trainer,
    get_datasets: Callable,
    log_dir: str = "fake_face_logs",
    metric: str = 'accuracy',
    seed: int = 0,
):
print("------------------------- Beginning of training") | |
set_seed(seed) | |
# initialize the model | |
model = model() | |
    # cast numpy integers in the config (e.g. np.int32/np.int64) to Python ints
    # so the config is JSON serializable
    for key, value in config.items():
        if isinstance(value, np.integer):
            config[key] = int(value)
    pretty = json.dumps(config, indent=4)
    print(f"Current config: \n{pretty}")
    print(f"Checkpoints in {output_dir}")
    # retrieve the training and test datasets, using the augmentation
    # parameters from the config
    train_dataset, test_dataset = get_datasets(
        config['h_flip_p'], config['v_flip_p'], config['gray_scale_p'], config['rotation']
    )
    # set up the training arguments
    training_args = TrainingArguments(
        output_dir,
        per_device_train_batch_size=config['batch_size'],
        evaluation_strategy='epoch',
        save_strategy='epoch',
        logging_strategy='epoch',
        num_train_epochs=epochs,
        fp16=True,  # mixed-precision training (requires GPU support)
        save_total_limit=2,
        push_to_hub=False,
        logging_dir=os.path.join(log_dir, os.path.basename(output_dir)),
        load_best_model_at_end=True,
        learning_rate=config['lr'],
    )
    # initialize the trainer
    trainer_ = trainer(
        model=model,
        args=training_args,
        data_collator=fake_face_collator,
        compute_metrics=compute_metrics,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
    )
    # train the model
    trainer_.train()
    # evaluate the model on the test set and collect the metrics
    metrics = trainer_.evaluate(test_dataset)
    # log the config and metrics to the TensorBoard hparams panel
    with SummaryWriter(os.path.join(log_dir, 'hparams')) as logger:
        logger.add_hparams(config, metrics)
    print(metrics)
    print("------------------------- End of training")
    # return the chosen metric so the caller (e.g. a hyperparameter search)
    # can use it as an objective value
    return metrics[f'eval_{metric}']
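
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the module above): since `train`
# returns a single scalar score, it plugs directly into a hyperparameter
# search loop. The checkpoint name, the config values, and the availability
# of a project-level `get_datasets` helper are assumptions; adapt as needed.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import ViTForImageClassification

    def model_factory() -> nn.Module:
        # assumption: a ViT backbone with a 2-class head (real vs. fake face)
        return ViTForImageClassification.from_pretrained(
            "google/vit-base-patch16-224-in21k", num_labels=2
        )

    # example search point; in practice these values come from the optimizer
    config = {
        'batch_size': 16,
        'lr': 2e-5,
        'h_flip_p': 0.5,
        'v_flip_p': 0.5,
        'gray_scale_p': 0.2,
        'rotation': 15,
    }

    score = train(
        epochs=5,
        output_dir="checkpoints/vit_run_0",
        config=config,
        model=model_factory,
        trainer=Trainer,
        get_datasets=get_datasets,  # assumed defined in the project's data module
        metric='accuracy',
    )
    print(f"eval_accuracy: {score}")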