# -------------------------------------------------------------------
# Pimcore
#
# This source file is available under two different licenses:
#  - GNU General Public License version 3 (GPLv3)
#  - Pimcore Commercial License (PCL)
# Full copyright and license information is available in
# LICENSE.md which is distributed with this source code.
#
#  @copyright  Copyright (c) Pimcore GmbH (http://www.pimcore.org)
#  @license    http://www.pimcore.org/license     GPLv3 and PCL
# -------------------------------------------------------------------

from typing import Annotated

from fastapi import Form
from pydantic import BaseModel


class TextClassificationTrainingParameters(BaseModel):
    """Provides specific training parameters for the text classification fine tuning.

    Attributes:
        epochs: Number of epochs executed during training.
        learning_rate: Learning rate used for training.
    """

    epochs: int
    learning_rate: float


def map_text_classification_training_parameters(
    # NOTE: ``epocs`` (sic) is intentionally left unrenamed: FastAPI uses the
    # parameter name as the form field name, so correcting the spelling would
    # change the field clients must submit and break keyword callers.
    epocs: Annotated[int, Form(description="Epochs executed during training.")] = 3,
    learning_rate: Annotated[float, Form(description="Learning rate for training.")] = 5e-5,
) -> TextClassificationTrainingParameters:
    """Map the training parameters to a TextClassificationTrainingParameters instance.

    Both arguments are declared with ``Form(...)`` so that FastAPI reads them
    from submitted form data.

    Args:
        epocs: Number of epochs executed during training (default ``3``).
            Stored as ``epochs`` on the returned model.
        learning_rate: Learning rate for training (default ``5e-5``).

    Returns:
        A TextClassificationTrainingParameters built from the arguments.
    """
    return TextClassificationTrainingParameters(
        epochs=epocs,
        learning_rate=learning_rate,
    )


class TextClassificationParameters:
    """Provides all parameters for the text classification fine tuning.

    Bundles the training CSV file path and its separator character, the
    project name, the source model name and the training hyperparameters.
    Values are kept in name-mangled private attributes and exposed through
    getter methods; no setters are provided.
    """

    __training_csv_file_path: str
    __training_csv_limiter: str
    __project_name: str
    __source_model_name: str
    __training_parameters: TextClassificationTrainingParameters

    def __init__(self,
                 training_csv_file_path: str,
                 project_name: str,
                 source_model_name: str,
                 training_parameters: TextClassificationTrainingParameters,
                 training_csv_limiter: str = ';'
                 ):
        """Store the fine tuning parameters.

        Args:
            training_csv_file_path: Path of the training CSV file.
            project_name: Name of the project; also returned as the result
                model name by ``get_result_model_name``.
            source_model_name: Name of the source model.
            training_parameters: Training hyperparameters (epochs, learning rate).
            training_csv_limiter: Separator character for the training CSV
                file (presumably the field delimiter — confirm with the
                CSV reader). Defaults to ``';'``.
        """
        self.__training_csv_file_path = training_csv_file_path
        self.__project_name = project_name
        self.__source_model_name = source_model_name
        self.__training_parameters = training_parameters
        self.__training_csv_limiter = training_csv_limiter

    def get_training_csv_file_path(self) -> str:
        """Return the path of the training CSV file."""
        return self.__training_csv_file_path

    def get_training_csv_limiter(self) -> str:
        """Return the separator character configured for the training CSV file."""
        return self.__training_csv_limiter

    def get_project_name(self) -> str:
        """Return the project name."""
        return self.__project_name

    def get_result_model_name(self) -> str:
        """Return the name for the resulting model.

        Currently this is the project name itself.
        """
        return self.__project_name

    def get_source_model_name(self) -> str:
        """Return the name of the source model."""
        return self.__source_model_name

    def get_training_parameters(self) -> TextClassificationTrainingParameters:
        """Return the training hyperparameters."""
        return self.__training_parameters