Spaces:
Sleeping
Sleeping
# -------------------------------------------------------------------
# Pimcore
#
# This source file is available under two different licenses:
# - GNU General Public License version 3 (GPLv3)
# - Pimcore Commercial License (PCL)
# Full copyright and license information is available in
# LICENSE.md which is distributed with this source code.
#
# @copyright Copyright (c) Pimcore GmbH (http://www.pimcore.org)
# @license http://www.pimcore.org/license GPLv3 and PCL
# -------------------------------------------------------------------
from pydantic import BaseModel | |
from typing import Annotated | |
from fastapi import Form | |
class TextClassificationTrainingParameters(BaseModel):
    """Training-specific settings for text classification fine tuning.

    Attributes:
        epochs: Number of epochs executed during training.
        learning_rate: Learning rate used for training.
    """

    epochs: int
    learning_rate: float
def map_text_classification_training_parameters(
    epocs: Annotated[int, Form(description="Epochs executed during training.")] = 3,
    learning_rate: Annotated[float, Form(description="Learning rate for training.")] = 5e-5,
) -> TextClassificationTrainingParameters:
    """Bundle the individual form values into a training-parameters model.

    Args:
        epocs: Number of epochs executed during training (defaults to 3).
        learning_rate: Learning rate for training (defaults to 5e-5).

    Returns:
        A TextClassificationTrainingParameters instance carrying both values.
    """
    # NOTE(review): `epocs` looks like a typo for `epochs`. It is kept as-is
    # because it is the keyword callers use (and presumably the form field
    # name clients send) - confirm before renaming.
    training_parameters = TextClassificationTrainingParameters(
        learning_rate=learning_rate,
        epochs=epocs,
    )
    return training_parameters
class TextClassificationParameters:
    """Holds every setting needed for a text classification fine tuning run.

    The values are stored in name-mangled private attributes and are only
    exposed through the ``get_*`` accessor methods below.
    """

    # Declared types of the private attributes assigned in ``__init__``.
    __training_csv_file_path: str
    __training_csv_limiter: str
    __project_name: str
    __source_model_name: str
    __training_parameters: TextClassificationTrainingParameters

    def __init__(
        self,
        training_csv_file_path: str,
        project_name: str,
        source_model_name: str,
        training_parameters: TextClassificationTrainingParameters,
        training_csv_limiter: str = ';',
    ):
        """Store the given settings on the instance.

        Args:
            training_csv_file_path: Path of the training CSV file.
            project_name: Name of the project.
            source_model_name: Name of the source model.
            training_parameters: Training-specific parameters.
            training_csv_limiter: Limiter of the training CSV; ``';'`` unless
                the caller passes something else.
        """
        # CSV related settings.
        self.__training_csv_file_path = training_csv_file_path
        self.__training_csv_limiter = training_csv_limiter
        # Naming and model settings.
        self.__project_name = project_name
        self.__source_model_name = source_model_name
        self.__training_parameters = training_parameters

    def get_training_csv_file_path(self) -> str:
        """Return the training CSV file path passed to the constructor."""
        return self.__training_csv_file_path

    def get_training_csv_limiter(self) -> str:
        """Return the training CSV limiter passed to the constructor."""
        return self.__training_csv_limiter

    def get_project_name(self) -> str:
        """Return the project name passed to the constructor."""
        return self.__project_name

    def get_result_model_name(self) -> str:
        """Return the result model name, which is the stored project name."""
        return self.__project_name

    def get_source_model_name(self) -> str:
        """Return the source model name passed to the constructor."""
        return self.__source_model_name

    def get_training_parameters(self) -> TextClassificationTrainingParameters:
        """Return the training parameters object passed to the constructor."""
        return self.__training_parameters