# -------------------------------------------------------------------
# Pimcore
#
# This source file is available under two different licenses:
# - GNU General Public License version 3 (GPLv3)
# - Pimcore Commercial License (PCL)
# Full copyright and license information is available in
# LICENSE.md which is distributed with this source code.
#
# @copyright Copyright (c) Pimcore GmbH (http://www.pimcore.org)
# @license http://www.pimcore.org/license GPLv3 and PCL
# -------------------------------------------------------------------
import logging
import os
import shutil
from typing import Tuple

import evaluate
import numpy as np
from datasets import DatasetDict, load_dataset
from huggingface_hub import HfFolder
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
)

from ..abstract_trainer import AbstractTrainer
from ..environment_variable_checker import EnvironmentVariableChecker
from ..progress_callback import ProgressCallback
from .text_classification_parameters import TextClassificationParameters

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
class TextClassificationTrainer(AbstractTrainer):

    def start_training(self, parameters: TextClassificationParameters):
        logger.info('Start Training...')
        try:
            task = 'Load and prepare training data'
            self.get_status().update_status(0, task, parameters.get_project_name())
            logger.info(task)

            tokenized_dataset, labels, label2id, id2label = self.__prepare_training_data(parameters)
            if self.get_status().is_training_aborted():
                return

            task = 'Start training model'
            self.get_status().update_status(10, task)
            logger.info(task)

            self.__train_model(tokenized_dataset, labels, label2id, id2label, parameters)
            if self.get_status().is_training_aborted():
                return

            self.get_status().update_status(100, "Training completed")
        except Exception as e:
            logger.error(e)
            self.get_status().finalize_abort_training(str(e))
            raise RuntimeError(f"An error occurred: {str(e)}")
        finally:
            # Clean up the temporary training files regardless of outcome
            logger.info('Cleaning up training files after training')
            shutil.rmtree(os.path.dirname(parameters.get_training_csv_file_path()))
            if self.get_status().is_training_aborted():
                self.get_status().finalize_abort_training("Training aborted")
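    # The training CSV is expected to provide a "value" column (the text to
    # classify) and a "target" column (the label), separated by the delimiter
    # configured in the parameters. A minimal sketch of such a file, assuming
    # ";" as the delimiter:
    #
    #   value;target
    #   "Great product, fast delivery";positive
    #   "Arrived broken";negative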
    def __prepare_training_data(self, parameters: TextClassificationParameters) -> Tuple[DatasetDict, list, dict, dict]:
        dataset = load_dataset('csv', data_files=parameters.get_training_csv_file_path(), delimiter=parameters.get_training_csv_limiter())
        dataset = dataset["train"]

        # Extract the labels and build the label <-> id lookup tables
        labels = dataset.unique('target')
        label2id, id2label = dict(), dict()
        for i, label in enumerate(labels):
            label2id[label] = i
            id2label[i] = label
        logger.info(id2label)

        # Hold out 20% of the data for evaluation
        dataset = dataset.train_test_split(test_size=0.2)
        logger.info(dataset)
        logger.info(dataset["train"][10])

        # Tokenize the value column
        tokenizer = AutoTokenizer.from_pretrained(parameters.get_source_model_name())

        def preprocess_function(examples):
            return tokenizer(examples["value"], truncation=True, padding='max_length')

        tokenized_dataset = dataset.map(preprocess_function, batched=True)
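        # The model's forward() only accepts the tokenizer outputs plus the
        # labels, and remove_unused_columns=False is set in the
        # TrainingArguments below, so every other CSV column has to be
        # dropped here by hand.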
        # Rename the target column to "labels" and remove unnecessary columns
        tokenized_dataset = tokenized_dataset.rename_column('target', 'labels')

        # Columns to keep
        columns_to_keep = ['input_ids', 'labels', 'attention_mask']
        all_columns = tokenized_dataset["train"].column_names
        columns_to_remove = [col for col in all_columns if col not in columns_to_keep]
        tokenized_dataset = tokenized_dataset.remove_columns(columns_to_remove)

        # Map labels to numeric ids
        def map_labels(example):
            example['labels'] = label2id[example['labels']]
            return example

        tokenized_dataset = tokenized_dataset.map(map_labels)
        logger.info(tokenized_dataset)
        logger.info(tokenized_dataset["train"][10])
        return tokenized_dataset, labels, label2id, id2label
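    # Fine-tunes the source model on the tokenized dataset and uploads the
    # result to the Hugging Face Hub under the configured organization.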
    def __train_model(self, tokenized_dataset: DatasetDict, labels: list, label2id: dict, id2label: dict, parameters: TextClassificationParameters):
        tokenizer = AutoTokenizer.from_pretrained(parameters.get_source_model_name())
        data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

        environment_variable_checker = EnvironmentVariableChecker()
        HfFolder.save_token(environment_variable_checker.get_huggingface_token())
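        # Map the Trainer's epoch progress onto the 11-89% band of the overall
        # status, between the 10% "start training" and 90% "uploading" marks.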
        progress_callback = ProgressCallback(self.get_status(), 11, 89)
        # Evaluation metric: accuracy on the held-out test split
        accuracy = evaluate.load("accuracy")

        def compute_metrics(eval_pred):
            # The model returns logits; argmax picks the highest-scoring class id
            predictions, labels = eval_pred
            predictions = np.argmax(predictions, axis=1)
            return accuracy.compute(predictions=predictions, references=labels)
        # Load the source model with a fresh sequence classification head
        model = AutoModelForSequenceClassification.from_pretrained(
            parameters.get_source_model_name(),
            num_labels=len(labels),
            id2label=id2label,
            label2id=label2id
        )

        target_model_id = environment_variable_checker.get_huggingface_organization() + '/' + parameters.get_result_model_name()
        training_args = TrainingArguments(
            output_dir=parameters.get_result_model_name(),
            hub_model_id=target_model_id,
            learning_rate=parameters.get_training_parameters().learning_rate,
            per_device_train_batch_size=16,
            per_device_eval_batch_size=16,
            num_train_epochs=parameters.get_training_parameters().epochs,
            weight_decay=0.01,
            eval_strategy="epoch",
            save_strategy="epoch",
            load_best_model_at_end=True,
            metric_for_best_model="accuracy",
            push_to_hub=False,
            remove_unused_columns=False,
            hub_private_repo=True,
        )
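        # push_to_hub=False keeps checkpoints local during training; the final
        # model is uploaded explicitly via trainer.push_to_hub() once training
        # has finished, into a private repo (hub_private_repo=True).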
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_dataset["train"],
            eval_dataset=tokenized_dataset["test"],
            tokenizer=tokenizer,
            data_collator=data_collator,
            compute_metrics=compute_metrics,
            callbacks=[progress_callback]
        )
        if self.get_status().is_training_aborted():
            return
        trainer.train()
        if self.get_status().is_training_aborted():
            return

        logger.info("Model trained, starting upload")
        self.get_status().update_status(90, "Uploading model to Hugging Face")
        trainer.push_to_hub()
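
# Usage sketch (hypothetical, not part of the module): a caller obtains a
# populated TextClassificationParameters instance (its construction lives in
# text_classification_parameters.py and is assumed here, as is a no-argument
# trainer constructor) and hands it to the trainer:
#
#   parameters = ...  # a configured TextClassificationParameters
#   trainer = TextClassificationTrainer()
#   trainer.start_training(parameters)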