Spaces:
Sleeping
Sleeping
File size: 7,556 Bytes
8a35bc0 7c4332a 264e02e 7c4332a ade1b4d 7c4332a 264e02e 7c4332a 264e02e 7c4332a 264e02e 7c4332a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 |
# -------------------------------------------------------------------
# Pimcore
#
# This source file is available under two different licenses:
# - GNU General Public License version 3 (GPLv3)
# - Pimcore Commercial License (PCL)
# Full copyright and license information is available in
# LICENSE.md which is distributed with this source code.
#
# @copyright Copyright (c) Pimcore GmbH (http://www.pimcore.org)
# @license http://www.pimcore.org/license GPLv3 and PCL
# -------------------------------------------------------------------
import logging
from ..progress_callback import ProgressCallback
from ..abstract_trainer import AbstractTrainer
from ..environment_variable_checker import EnvironmentVariableChecker
from .image_classification_parameters import ImageClassificationParameters
import zipfile
import os
import shutil
from datasets import load_dataset
from transformers import AutoImageProcessor, DefaultDataCollator, AutoModelForImageClassification, TrainingArguments, Trainer, TrainerState, TrainerControl
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor
from huggingface_hub import HfFolder
import evaluate
import numpy as np
# Module-level logger; DEBUG level so detailed training progress is captured.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
class ImageClassificationTrainer(AbstractTrainer):
    """Fine-tunes a pretrained Hugging Face image-classification model.

    Expects an uploaded ZIP laid out for the ``imagefolder`` dataset format
    (one sub-directory per class label). Progress is reported through the
    status object supplied by :class:`AbstractTrainer`, and the resulting
    model is pushed to the Hugging Face Hub.
    """

    def start_training(self, parameters: ImageClassificationParameters):
        """Run the full pipeline: extract data, prepare dataset, train, upload.

        Raises:
            RuntimeError: wrapping any exception raised during training.
        """
        logger.info('Start Training...')
        try:
            task = 'Extract training data'
            self.get_status().update_status(0, task, parameters.get_project_name())
            logger.info(task)
            self.__extract_training_data(parameters)
            if self.get_status().is_training_aborted():
                return

            task = 'Prepare Data set'
            self.get_status().update_status(10, task)
            logger.info(task)
            images = self.__prepare_data_set(parameters)
            if self.get_status().is_training_aborted():
                return

            task = 'Start training model'
            self.get_status().update_status(20, task)
            logger.info(task)
            self.__train_model(images, parameters)
            if self.get_status().is_training_aborted():
                return

            self.get_status().update_status(100, "Training completed")
        except Exception as e:
            # logger.exception keeps the full traceback; chain the cause so
            # callers inspecting __cause__ still see the original error.
            logger.exception("Training failed")
            self.get_status().finalize_abort_training(str(e))
            raise RuntimeError(f"An error occurred: {str(e)}") from e
        finally:
            # Cleanup after processing. ignore_errors: the directory may never
            # have been created (e.g. invalid zip), and a secondary failure
            # here must not mask the original exception.
            logger.info('Cleaning up training files after training')
            shutil.rmtree(parameters.get_training_files_path(), ignore_errors=True)
            if self.get_status().is_training_aborted():
                self.get_status().finalize_abort_training("Training aborted")

    def __extract_training_data(self, parameters: ImageClassificationParameters):
        """Validate the uploaded ZIP, extract it, then delete the archive."""
        training_file = parameters.get_training_zip_file()

        # Check if it is a valid ZIP file
        if not zipfile.is_zipfile(training_file):
            raise RuntimeError("Uploaded file is not a valid zip file")

        # Extract the ZIP file.
        # NOTE(review): extractall() on an uploaded archive is susceptible to
        # path traversal ("zip slip"); consider validating member names or
        # using the extraction filter available in Python 3.12+.
        with zipfile.ZipFile(training_file, 'r') as zip_ref:
            zip_ref.extractall(parameters.get_training_files_path())

        os.remove(training_file)
        logger.info(os.listdir(parameters.get_training_files_path()))

    def __prepare_data_set(self, parameters: ImageClassificationParameters) -> dict:
        """Load extracted images as an ``imagefolder`` dataset, split it
        80/20 into train/test, and attach the preprocessing transforms.

        Returns:
            A DatasetDict with ``train`` and ``test`` splits yielding
            ``pixel_values`` tensors.
        """
        dataset = load_dataset("imagefolder", data_dir=parameters.get_training_files_path())
        images = dataset["train"].train_test_split(test_size=0.2)
        logger.info(images)
        # Log the first sample (if any) instead of a hard-coded index, which
        # raised IndexError for datasets with few images.
        if len(images["train"]) > 0:
            logger.info(images["train"][0])

        # Preprocess the images.
        image_processor = AutoImageProcessor.from_pretrained(parameters.get_source_model_name())

        # Apply some image transformations to make the model more robust
        # against overfitting.
        normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
        # Processor configs expose either a single "shortest_edge" or an
        # explicit height/width pair depending on the model family.
        size = (
            image_processor.size["shortest_edge"]
            if "shortest_edge" in image_processor.size
            else (image_processor.size["height"], image_processor.size["width"])
        )
        _transforms = Compose([RandomResizedCrop(size), ToTensor(), normalize])

        def transforms(examples):
            # Convert to RGB so grayscale/RGBA inputs do not break the model.
            examples["pixel_values"] = [_transforms(img.convert("RGB")) for img in examples["image"]]
            del examples["image"]
            return examples

        return images.with_transform(transforms)

    def __train_model(self, images: dict, parameters: ImageClassificationParameters):
        """Fine-tune the source model on the prepared dataset and upload the
        result to the Hugging Face Hub under the configured organization."""
        environment_variable_checker = EnvironmentVariableChecker()
        HfFolder.save_token(environment_variable_checker.get_huggingface_token())

        image_processor = AutoImageProcessor.from_pretrained(parameters.get_source_model_name())
        data_collator = DefaultDataCollator()
        # Maps trainer progress into the 21-89% range of the overall status.
        progress_callback = ProgressCallback(self.get_status(), 21, 89)

        # Evaluation metric: accuracy over argmax predictions.
        accuracy = evaluate.load("accuracy")

        def compute_metrics(eval_pred):
            predictions, labels = eval_pred
            predictions = np.argmax(predictions, axis=1)
            return accuracy.compute(predictions=predictions, references=labels)

        # Build label <-> id maps from the dataset's class names.
        labels = images["train"].features["label"].names
        label2id = {label: str(i) for i, label in enumerate(labels)}
        id2label = {str(i): label for i, label in enumerate(labels)}
        logger.info(id2label)

        # Train the model.
        model = AutoModelForImageClassification.from_pretrained(
            parameters.get_source_model_name(),
            num_labels=len(labels),
            id2label=id2label,
            label2id=label2id,
        )

        target_model_id = environment_variable_checker.get_huggingface_organization() + '/' + parameters.get_result_model_name()
        training_args = TrainingArguments(
            output_dir=parameters.get_result_model_name(),
            hub_model_id=target_model_id,
            remove_unused_columns=False,
            eval_strategy="epoch",
            save_strategy="epoch",
            learning_rate=parameters.get_training_parameters().learning_rate,
            per_device_train_batch_size=16,
            gradient_accumulation_steps=4,
            per_device_eval_batch_size=16,
            num_train_epochs=parameters.get_training_parameters().epochs,
            warmup_ratio=0.1,
            logging_steps=10,
            load_best_model_at_end=True,
            metric_for_best_model="accuracy",
            # Push happens explicitly after training so aborts can skip it.
            push_to_hub=False,
            hub_private_repo=True,
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            data_collator=data_collator,
            train_dataset=images["train"],
            eval_dataset=images["test"],
            tokenizer=image_processor,
            compute_metrics=compute_metrics,
            callbacks=[progress_callback]
        )

        if self.get_status().is_training_aborted():
            return
        trainer.train()
        if self.get_status().is_training_aborted():
            return

        logger.info("Model trained, start uploading")
        self.get_status().update_status(90, "Uploading model to Hugging Face")
        trainer.push_to_hub()