# -------------------------------------------------------------------
# Pimcore
#
# This source file is available under two different licenses:
# - GNU General Public License version 3 (GPLv3)
# - Pimcore Commercial License (PCL)
# Full copyright and license information is available in
# LICENSE.md which is distributed with this source code.
#
# @copyright Copyright (c) Pimcore GmbH (http://www.pimcore.org)
# @license http://www.pimcore.org/license GPLv3 and PCL
# -------------------------------------------------------------------
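
"""
Trainer for fine-tuning image classification models.

Extracts an uploaded ZIP of training images, builds a Hugging Face "imagefolder"
dataset from it, fine-tunes the configured source model and uploads the result
to the Hugging Face Hub, reporting progress through the service's status object.
"""
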
import logging
from ..progress_callback import ProgressCallback
from ..abstract_trainer import AbstractTrainer
from ..environment_variable_checker import EnvironmentVariableChecker
from .image_classification_parameters import ImageClassificationParameters
import zipfile
import os
import shutil
from datasets import load_dataset
from transformers import AutoImageProcessor, DefaultDataCollator, AutoModelForImageClassification, TrainingArguments, Trainer, TrainerState, TrainerControl
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor
from huggingface_hub import HfFolder
import evaluate
import numpy as np
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


class ImageClassificationTrainer(AbstractTrainer):
    """Fine-tunes an image classification model and uploads it to the Hugging Face Hub."""

    def start_training(self, parameters: ImageClassificationParameters):
        logger.info('Start Training...')

        try:
            task = 'Extract training data'
            self.get_status().update_status(0, task, parameters.get_project_name())
            logger.info(task)
            self.__extract_training_data(parameters)
            if self.get_status().is_training_aborted():
                return

            task = 'Prepare data set'
            self.get_status().update_status(10, task)
            logger.info(task)
            images = self.__prepare_data_set(parameters)
            if self.get_status().is_training_aborted():
                return

            task = 'Start training model'
            self.get_status().update_status(20, task)
            logger.info(task)
            self.__train_model(images, parameters)
            if self.get_status().is_training_aborted():
                return

            self.get_status().update_status(100, "Training completed")

        except Exception as e:
            logger.error(e)
            self.get_status().finalize_abort_training(str(e))
            raise RuntimeError(f"An error occurred: {str(e)}")

        finally:
            # Clean up the extracted training files after processing
            logger.info('Cleaning up training files after training')
            shutil.rmtree(parameters.get_training_files_path())
            if self.get_status().is_training_aborted():
                self.get_status().finalize_abort_training("Training aborted")

    def __extract_training_data(self, parameters: ImageClassificationParameters):
        training_file = parameters.get_training_zip_file()

        # Check if it is a valid ZIP file
        if not zipfile.is_zipfile(training_file):
            raise RuntimeError("Uploaded file is not a valid zip file")

        # Extract the ZIP file
        with zipfile.ZipFile(training_file, 'r') as zip_ref:
            zip_ref.extractall(parameters.get_training_files_path())

        os.remove(training_file)
        logger.info(os.listdir(parameters.get_training_files_path()))
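
    # The extracted archive is expected to follow the Hugging Face "imagefolder"
    # layout: one sub-directory per class containing that class' images
    # (e.g. shirt/001.jpg, shoes/002.jpg). The imagefolder builder below infers
    # the class labels from these directory names.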
    def __prepare_data_set(self, parameters: ImageClassificationParameters) -> dict:
        dataset = load_dataset("imagefolder", data_dir=parameters.get_training_files_path())
        images = dataset["train"]
        images = images.train_test_split(test_size=0.2)

        logger.info(images)
        logger.info(images["train"][10])

        # Preprocess the images
        image_processor = AutoImageProcessor.from_pretrained(parameters.get_source_model_name())

        # Apply some image transformations to make the model more robust against overfitting.
        normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
        size = (
            image_processor.size["shortest_edge"]
            if "shortest_edge" in image_processor.size
            else (image_processor.size["height"], image_processor.size["width"])
        )
        _transforms = Compose([RandomResizedCrop(size), ToTensor(), normalize])

        def transforms(examples):
            examples["pixel_values"] = [_transforms(img.convert("RGB")) for img in examples["image"]]
            del examples["image"]
            return examples

        # with_transform applies the transforms lazily, on access, instead of up front
        images = images.with_transform(transforms)
        return images

    def __train_model(self, images: dict, parameters: ImageClassificationParameters):
        environment_variable_checker = EnvironmentVariableChecker()
        HfFolder.save_token(environment_variable_checker.get_huggingface_token())

        image_processor = AutoImageProcessor.from_pretrained(parameters.get_source_model_name())
        data_collator = DefaultDataCollator()
        progressCallback = ProgressCallback(self.get_status(), 21, 89)

        # Evaluation metric: accuracy on the held-out test split
        accuracy = evaluate.load("accuracy")

        def compute_metrics(eval_pred):
            predictions, labels = eval_pred
            predictions = np.argmax(predictions, axis=1)
            return accuracy.compute(predictions=predictions, references=labels)

        # Get the label maps (stored in the model config and pushed with the model)
        labels = images["train"].features["label"].names
        label2id, id2label = dict(), dict()
        for i, label in enumerate(labels):
            label2id[label] = str(i)
            id2label[str(i)] = label

        logger.info(id2label)

        # Load the source model with a classification head sized for our labels
        model = AutoModelForImageClassification.from_pretrained(
            parameters.get_source_model_name(),
            num_labels=len(labels),
            id2label=id2label,
            label2id=label2id,
        )

        target_model_id = environment_variable_checker.get_huggingface_organization() + '/' + parameters.get_result_model_name()

        training_args = TrainingArguments(
            output_dir=parameters.get_result_model_name(),
            hub_model_id=target_model_id,
            remove_unused_columns=False,
            eval_strategy="epoch",
            save_strategy="epoch",
            learning_rate=parameters.get_training_parameters().learning_rate,
            per_device_train_batch_size=16,
            gradient_accumulation_steps=4,
            per_device_eval_batch_size=16,
            num_train_epochs=parameters.get_training_parameters().epochs,
            warmup_ratio=0.1,
            logging_steps=10,
            load_best_model_at_end=True,
            metric_for_best_model="accuracy",
            push_to_hub=False,
            hub_private_repo=True,
        )
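        # With per_device_train_batch_size=16 and gradient_accumulation_steps=4 the
        # effective training batch size is 64 per device. push_to_hub=False keeps the
        # upload manual; the model is pushed explicitly via trainer.push_to_hub() below.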

        trainer = Trainer(
            model=model,
            args=training_args,
            data_collator=data_collator,
            train_dataset=images["train"],
            eval_dataset=images["test"],
            tokenizer=image_processor,
            compute_metrics=compute_metrics,
            callbacks=[progressCallback],
        )

        if self.get_status().is_training_aborted():
            return

        trainer.train()

        if self.get_status().is_training_aborted():
            return

        logger.info("Model trained, start uploading")
        self.get_status().update_status(90, "Uploading model to Hugging Face")
        trainer.push_to_hub()
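

# Minimal usage sketch (illustrative only): the surrounding service constructs the
# trainer and the parameters object elsewhere in this repository, so the exact
# wiring may differ.
#
#   trainer = ImageClassificationTrainer()
#   trainer.start_training(parameters)  # parameters: ImageClassificationParameters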