# -------------------------------------------------------------------
# Pimcore
#
# This source file is available under two different licenses:
# - GNU General Public License version 3 (GPLv3)
# - Pimcore Commercial License (PCL)
# Full copyright and license information is available in
# LICENSE.md which is distributed with this source code.
#
# @copyright Copyright (c) Pimcore GmbH (http://www.pimcore.org)
# @license http://www.pimcore.org/license GPLv3 and PCL
# -------------------------------------------------------------------
import logging
import os
import shutil
from typing import Tuple

import evaluate
import numpy as np
from datasets import load_dataset
from huggingface_hub import HfFolder
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
)

from ..abstract_trainer import AbstractTrainer
from ..environment_variable_checker import EnvironmentVariableChecker
from ..progress_callback import ProgressCallback
from .text_classification_parameters import TextClassificationParameters

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
class TextClassificationTrainer(AbstractTrainer):
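    """Fine-tunes a Hugging Face sequence classification model on CSV training data."""
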
def start_training(self, parameters: TextClassificationParameters):
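        """Run the full pipeline: prepare the data, fine-tune the model, upload the result."""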
logger.info('Start Training...')
try:
task = 'Load and prepare training data'
self.get_status().update_status(0, task, parameters.get_project_name())
logger.info(task)
tokenized_dataset, labels, label2id, id2label = self.__prepare_training_data(parameters)
            if self.get_status().is_training_aborted():
return
task = 'Start training model'
self.get_status().update_status(10, task)
logger.info(task)
self.__train_model(tokenized_dataset, labels, label2id, id2label, parameters)
            if self.get_status().is_training_aborted():
return
self.get_status().update_status(100, "Training completed")
        except Exception as e:
            logger.error(e)
            self.get_status().finalize_abort_training(str(e))
            raise RuntimeError(f"An error occurred: {str(e)}") from e
finally:
# Cleanup after processing
logger.info('Cleaning up training files after training')
shutil.rmtree(os.path.dirname(parameters.get_training_csv_file_path()))
            if self.get_status().is_training_aborted():
self.get_status().finalize_abort_training("Training aborted")

    def __prepare_training_data(self, parameters: TextClassificationParameters) -> Tuple[dict, list, dict, dict]:
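        """Load the CSV dataset, tokenize the 'value' column, and map 'target' labels to numeric ids."""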
dataset = load_dataset('csv', data_files=parameters.get_training_csv_file_path(), delimiter=parameters.get_training_csv_limiter())
dataset = dataset["train"]
        # Extract the unique labels and build label <-> id lookup tables
        labels = dataset.unique('target')
label2id, id2label = dict(), dict()
for i, label in enumerate(labels):
label2id[label] = i
id2label[i] = label
logger.info(id2label)
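        # Hold out 20% of the rows as an evaluation split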
dataset = dataset.train_test_split(test_size=0.2)
logger.info(dataset)
logger.info(dataset["train"][10])
# Tokenize the value column
tokenizer = AutoTokenizer.from_pretrained(parameters.get_source_model_name())
def preprocess_function(examples):
return tokenizer(examples["value"], truncation=True, padding='max_length')
tokenized_dataset = dataset.map(preprocess_function, batched=True)
# Rename the Target column to labels and remove unnecessary columns
tokenized_dataset = tokenized_dataset.rename_column('target', 'labels')
# Columns to keep
columns_to_keep = ['input_ids', 'labels', 'attention_mask']
all_columns = tokenized_dataset["train"].column_names
columns_to_remove = [col for col in all_columns if col not in columns_to_keep]
tokenized_dataset = tokenized_dataset.remove_columns(columns_to_remove)
# Map labels to numeric ids
def map_labels(example):
example['labels'] = label2id[example['labels']]
return example
tokenized_dataset = tokenized_dataset.map(map_labels)
logger.info(tokenized_dataset)
logger.info(tokenized_dataset["train"][10])
return tokenized_dataset, labels, label2id, id2label

    def __train_model(self, tokenized_dataset: dict, labels: list, label2id: dict, id2label: dict, parameters: TextClassificationParameters):
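        """Fine-tune the sequence classification model and upload the result to the Hugging Face Hub."""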
tokenizer = AutoTokenizer.from_pretrained(parameters.get_source_model_name())
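        # Pad each batch dynamically to the longest sequence in that batch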
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
environment_variable_checker = EnvironmentVariableChecker()
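        # Persist the Hugging Face token so the trainer can authenticate when pushing to the Hub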
HfFolder.save_token(environment_variable_checker.get_huggingface_token())
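        # Forward trainer progress to the status tracker, mapped onto the 11-89% range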
        progress_callback = ProgressCallback(self.get_status(), 11, 89)
        # Accuracy metric used to evaluate the model during training
accuracy = evaluate.load("accuracy")
        def compute_metrics(eval_pred):
            predictions, references = eval_pred
            predictions = np.argmax(predictions, axis=1)
            return accuracy.compute(predictions=predictions, references=references)
        # Load the base model with a classification head sized to the number of labels
model = AutoModelForSequenceClassification.from_pretrained(
parameters.get_source_model_name(),
num_labels=len(labels),
id2label=id2label,
label2id=label2id
)
target_model_id = environment_variable_checker.get_huggingface_organization() + '/' + parameters.get_result_model_name()
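        # push_to_hub is disabled here; the model is uploaded explicitly after training completes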
training_args = TrainingArguments(
output_dir=parameters.get_result_model_name(),
hub_model_id=target_model_id,
learning_rate=parameters.get_training_parameters().learning_rate,
per_device_train_batch_size=16,
per_device_eval_batch_size=16,
num_train_epochs=parameters.get_training_parameters().epochs,
weight_decay=0.01,
eval_strategy="epoch",
save_strategy="epoch",
load_best_model_at_end=True,
metric_for_best_model="accuracy",
push_to_hub=False,
remove_unused_columns=False,
hub_private_repo=True,
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_dataset["train"],
eval_dataset=tokenized_dataset["test"],
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics,
            callbacks=[progress_callback]
)
        if self.get_status().is_training_aborted():
return
trainer.train()
        if self.get_status().is_training_aborted():
return
logger.info(f"Model trained, start uploading")
self.get_status().update_status(90, f"Uploading model to Hugging Face")
trainer.push_to_hub()