# -------------------------------------------------------------------
# Pimcore
#
# This source file is available under two different licenses:
# - GNU General Public License version 3 (GPLv3)
# - Pimcore Commercial License (PCL)
# Full copyright and license information is available in
# LICENSE.md which is distributed with this source code.
#
# @copyright Copyright (c) Pimcore GmbH (http://www.pimcore.org)
# @license http://www.pimcore.org/license GPLv3 and PCL
# -------------------------------------------------------------------

import logging
import os
import shutil
import zipfile

import evaluate
import numpy as np
from datasets import load_dataset
from huggingface_hub import HfFolder
from torchvision.transforms import Compose, Normalize, RandomResizedCrop, ToTensor
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    DefaultDataCollator,
    Trainer,
    TrainerControl,
    TrainerState,
    TrainingArguments,
)

from ..abstract_trainer import AbstractTrainer
from ..environment_variable_checker import EnvironmentVariableChecker
from ..progress_callback import ProgressCallback
from .image_classification_parameters import ImageClassificationParameters

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


class ImageClassificationTrainer(AbstractTrainer):
    """Fine-tune a Hugging Face image-classification model from an uploaded ZIP of images.

    Pipeline: extract the uploaded ZIP -> build an ``imagefolder`` dataset with a
    train/test split and torchvision preprocessing -> fine-tune the source model
    and push the result to the Hugging Face Hub. Progress is reported through the
    status object provided by :class:`AbstractTrainer`, and the user may abort
    between (and, via ``ProgressCallback``, during) the individual phases.
    """

    def start_training(self, parameters: ImageClassificationParameters):
        """Run the full training pipeline described in the class docstring.

        :param parameters: project name, file locations, model names and
            hyper-parameters for this training run.
        :raises RuntimeError: wraps any failure during extraction, dataset
            preparation or training (original exception attached as cause).
        """
        logger.info('Start Training...')

        try:
            task = 'Extract training data'
            self.get_status().update_status(0, task, parameters.get_project_name())
            logger.info(task)
            self.__extract_training_data(parameters)
            if self.get_status().is_training_aborted():
                return

            task = 'Prepare Data set'
            self.get_status().update_status(10, task)
            logger.info(task)
            images = self.__prepare_data_set(parameters)
            if self.get_status().is_training_aborted():
                return

            task = 'Start training model'
            self.get_status().update_status(20, task)
            logger.info(task)
            self.__train_model(images, parameters)
            if self.get_status().is_training_aborted():
                return

            self.get_status().update_status(100, "Training completed")

        except Exception as e:
            logger.error(e)
            self.get_status().finalize_abort_training(str(e))
            # Chain explicitly so the original traceback stays attached.
            raise RuntimeError(f"An error occurred: {str(e)}") from e

        finally:
            # Cleanup after processing. ignore_errors=True: if the failure
            # happened before extraction (e.g. invalid ZIP), the directory may
            # not exist and an unguarded rmtree inside ``finally`` would mask
            # the original exception with a FileNotFoundError.
            logger.info('Cleaning up training files after training')
            shutil.rmtree(parameters.get_training_files_path(), ignore_errors=True)

            if self.get_status().is_training_aborted():
                self.get_status().finalize_abort_training("Training aborted")

    def __extract_training_data(self, parameters: ImageClassificationParameters):
        """Validate and extract the uploaded training ZIP, then delete the archive.

        :raises RuntimeError: if the uploaded file is not a valid ZIP archive.
        """
        training_file = parameters.get_training_zip_file()

        # Check if it is a valid ZIP file
        if not zipfile.is_zipfile(training_file):
            raise RuntimeError("Uploaded file is not a valid zip file")

        # Extract the ZIP file. NOTE(review): the archive is user-supplied;
        # zipfile sanitizes absolute paths and ".." members on extraction, but
        # content (file count/size) is not limited here.
        with zipfile.ZipFile(training_file, 'r') as zip_ref:
            zip_ref.extractall(parameters.get_training_files_path())

        os.remove(training_file)
        logger.info(os.listdir(parameters.get_training_files_path()))

    def __prepare_data_set(self, parameters: ImageClassificationParameters) -> dict:
        """Build an 80/20 train/test split from the extracted image folder.

        Loads the extracted files as an ``imagefolder`` dataset (labels are
        inferred from sub-directory names) and attaches on-the-fly torchvision
        transforms matching the source model's preprocessing configuration.

        :return: a DatasetDict with ``train`` and ``test`` splits yielding
            ``pixel_values`` tensors plus labels.
        """
        dataset = load_dataset("imagefolder", data_dir=parameters.get_training_files_path())
        images = dataset["train"]
        images = images.train_test_split(test_size=0.2)

        logger.info(images)
        # Sample record for debugging; assumes the split has > 10 training images.
        logger.info(images["train"][10])

        # Preprocess the images
        image_processor = AutoImageProcessor.from_pretrained(parameters.get_source_model_name())

        # Apply some image transformations to the images to make the model more
        # robust against overfitting. Normalization uses the model's own
        # mean/std; crop size follows the processor config (square for
        # "shortest_edge", otherwise explicit height x width).
        normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
        size = (
            image_processor.size["shortest_edge"]
            if "shortest_edge" in image_processor.size
            else (image_processor.size["height"], image_processor.size["width"])
        )
        _transforms = Compose([RandomResizedCrop(size), ToTensor(), normalize])

        def transforms(examples):
            # Convert to RGB first so grayscale/RGBA images do not break the model.
            examples["pixel_values"] = [_transforms(img.convert("RGB")) for img in examples["image"]]
            del examples["image"]
            return examples

        images = images.with_transform(transforms)
        return images

    def __train_model(self, images: dict, parameters: ImageClassificationParameters):
        """Fine-tune the source model on the prepared splits and upload the result.

        Progress is mapped to 21-89% via ``ProgressCallback``; the upload to the
        (private) Hub repo of the configured organization happens at 90%.

        :param images: DatasetDict produced by ``__prepare_data_set``.
        :param parameters: model names and hyper-parameters for the run.
        """
        environment_variable_checker = EnvironmentVariableChecker()
        HfFolder.save_token(environment_variable_checker.get_huggingface_token())

        image_processor = AutoImageProcessor.from_pretrained(parameters.get_source_model_name())
        data_collator = DefaultDataCollator()
        progressCallback = ProgressCallback(self.get_status(), 21, 89)

        # Evaluate and metrics
        accuracy = evaluate.load("accuracy")

        def compute_metrics(eval_pred):
            # eval_pred is (logits, labels); argmax over classes -> predictions.
            predictions, labels = eval_pred
            predictions = np.argmax(predictions, axis=1)
            return accuracy.compute(predictions=predictions, references=labels)

        # get label maps (string ids, as expected by the transformers config)
        labels = images["train"].features["label"].names
        label2id, id2label = dict(), dict()
        for i, label in enumerate(labels):
            label2id[label] = str(i)
            id2label[str(i)] = label
        logger.info(id2label)

        # train the model
        model = AutoModelForImageClassification.from_pretrained(
            parameters.get_source_model_name(),
            num_labels=len(labels),
            id2label=id2label,
            label2id=label2id,
        )

        target_model_id = environment_variable_checker.get_huggingface_organization() + '/' + parameters.get_result_model_name()

        training_args = TrainingArguments(
            output_dir=parameters.get_result_model_name(),
            hub_model_id=target_model_id,
            remove_unused_columns=False,
            eval_strategy="epoch",
            save_strategy="epoch",
            learning_rate=parameters.get_training_parameters().learning_rate,
            per_device_train_batch_size=16,
            gradient_accumulation_steps=4,
            per_device_eval_batch_size=16,
            num_train_epochs=parameters.get_training_parameters().epochs,
            warmup_ratio=0.1,
            logging_steps=10,
            load_best_model_at_end=True,
            metric_for_best_model="accuracy",
            # Upload is triggered explicitly below (after the abort check),
            # not automatically during training.
            push_to_hub=False,
            hub_private_repo=True,
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            data_collator=data_collator,
            train_dataset=images["train"],
            eval_dataset=images["test"],
            tokenizer=image_processor,
            compute_metrics=compute_metrics,
            callbacks=[progressCallback]
        )

        if self.get_status().is_training_aborted():
            return

        trainer.train()

        if self.get_status().is_training_aborted():
            return

        logger.info("Model trained, start uploading")
        self.get_status().update_status(90, "Uploading model to Hugging Face")
        trainer.push_to_hub()