# -------------------------------------------------------------------
# Pimcore
#
# This source file is available under two different licenses:
#  - GNU General Public License version 3 (GPLv3)
#  - Pimcore Commercial License (PCL)
# Full copyright and license information is available in
# LICENSE.md which is distributed with this source code.
#
#  @copyright  Copyright (c) Pimcore GmbH (http://www.pimcore.org)
#  @license    http://www.pimcore.org/license     GPLv3 and PCL
# -------------------------------------------------------------------

from typing import Annotated, Optional

from fastapi import Form
from pydantic import BaseModel


class ImageClassificationTrainingParameters(BaseModel):
    """Provides specific training parameters for the image classification fine tuning."""

    # Number of epochs executed during training.
    epochs: int
    # Learning rate used for training.
    learning_rate: float


def map_image_classification_training_parameters(
        epocs: Annotated[int, Form(description="Epochs executed during training.")] = 3,
        learning_rate: Annotated[float, Form(description="Learning rate for training.")] = 5e-5,
        epochs: Annotated[
            Optional[int],
            Form(description="Epochs executed during training. Correctly spelled "
                             "alternative to 'epocs'; takes precedence when given."),
        ] = None,
) -> ImageClassificationTrainingParameters:
    """Map the form parameters to an ImageClassificationTrainingParameters instance.

    The parameters are declared as FastAPI ``Form`` fields, so this function is
    presumably used as a request dependency (verify against the routes).

    Args:
        epocs: Number of training epochs (default 3). The misspelled name is kept
            because it is the existing form field name clients send.
        learning_rate: Learning rate for training (default 5e-5).
        epochs: Optional, correctly spelled replacement for ``epocs``. When it is
            provided, it overrides ``epocs``.

    Returns:
        The populated ImageClassificationTrainingParameters model.
    """
    # Honour the correctly spelled field first, but keep the legacy one working.
    effective_epochs = epocs if epochs is None else epochs
    return ImageClassificationTrainingParameters(
        epochs=effective_epochs,
        learning_rate=learning_rate,
    )


class ImageClassificationParameters:
    """Provides all parameters for the image classification fine tuning."""

    __training_files_path: str
    __training_zip_file_path: str
    __project_name: str
    __source_model_name: str
    __training_parameters: ImageClassificationTrainingParameters

    def __init__(self,
                 training_files_path: str,
                 training_zip_file_path: str,
                 project_name: str,
                 source_model_name: str,
                 training_parameters: ImageClassificationTrainingParameters
                 ):
        """Store the paths, names and training parameters for a fine tuning run.

        Args:
            training_files_path: Path to the training files.
            training_zip_file_path: Path to the training zip file.
            project_name: Name of the project; also used as the result model name.
            source_model_name: Name of the model the fine tuning starts from.
            training_parameters: Epochs and learning rate for the training.
        """
        self.__training_files_path = training_files_path
        self.__training_zip_file_path = training_zip_file_path
        self.__project_name = project_name
        self.__source_model_name = source_model_name
        self.__training_parameters = training_parameters

    def __repr__(self) -> str:
        """Return a debug representation listing all stored parameters."""
        return (
            f"{type(self).__name__}("
            f"training_files_path={self.__training_files_path!r}, "
            f"training_zip_file_path={self.__training_zip_file_path!r}, "
            f"project_name={self.__project_name!r}, "
            f"source_model_name={self.__source_model_name!r}, "
            f"training_parameters={self.__training_parameters!r})"
        )

    def get_training_files_path(self) -> str:
        """Return the path to the training files."""
        return self.__training_files_path

    def get_training_zip_file(self) -> str:
        """Return the path (not the contents) of the training zip file."""
        return self.__training_zip_file_path

    def get_result_model_name(self) -> str:
        """Return the name of the resulting model, which is the project name."""
        return self.__project_name

    def get_project_name(self) -> str:
        """Return the project name."""
        return self.__project_name

    def get_source_model_name(self) -> str:
        """Return the name of the model the fine tuning starts from."""
        return self.__source_model_name

    def get_training_parameters(self) -> ImageClassificationTrainingParameters:
        """Return the training parameters (epochs, learning rate)."""
        return self.__training_parameters