fine-tuning-service / src /training_status.py
fashxp's picture
license
8a35bc0
# -------------------------------------------------------------------
# Pimcore
#
# This source file is available under two different licenses:
# - GNU General Public License version 3 (GPLv3)
# - Pimcore Commercial License (PCL)
# Full copyright and license information is available in
# LICENSE.md which is distributed with this source code.
#
# @copyright Copyright (c) Pimcore GmbH (http://www.pimcore.org)
# @license http://www.pimcore.org/license GPLv3 and PCL
# -------------------------------------------------------------------
import logging
from enum import Enum
# Logger named after this module. Its own threshold is pinned to DEBUG, so
# records of every level pass the logger itself; attached handlers may still
# filter them.
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.DEBUG)
class Status(Enum):
    """Lifecycle states of a fine-tuning run.

    Each member's value is the upper-case string equal to its name.
    """

    NOT_STARTED = "NOT_STARTED"
    IN_PROGRESS = "IN_PROGRESS"
    # An abort was requested (abort_training) but not yet finalized.
    CANCELLING = "CANCELLING"
    # The abort was finalized (finalize_abort_training).
    CANCELLED = "CANCELLED"
    COMPLETED = "COMPLETED"


class TrainingStatus:
    """Mutable record of the current training run: status, progress, task, project.

    The class-level attributes below are the initial values every instance
    starts with. They are immutable, so the first assignment through ``self``
    simply shadows them with a per-instance value.
    """

    __status: Status = Status.NOT_STARTED
    __project_name: str = None
    __task: str = None
    __progress: int = 0

    def update_status(self, progress: int, task: str, project_name: str = None):
        """Record a progress update for the running training.

        Args:
            progress: Percentage complete, 0 to 100 inclusive.
            task: Description of the current step; ``None`` keeps the previous one.
            project_name: Project being trained; ``None`` keeps the previous one.

        Raises:
            ValueError: If ``progress`` is outside 0..100.

        Status transitions:
        - 100 marks the run COMPLETED.
        - Any other positive value marks it IN_PROGRESS, unless an abort is
          pending (CANCELLING). In that case the status is left alone, so the
          request stays visible to ``is_training_aborted``.
        - 0 leaves the status unchanged.
        """
        if not 0 <= progress <= 100:
            raise ValueError("Progress must be between 0 and 100")
        if progress == 100:
            self.__status = Status.COMPLETED
        elif progress > 0 and self.__status is not Status.CANCELLING:
            # A progress tick must not clobber a pending abort request; only
            # finalize_abort_training (or completion) leaves CANCELLING.
            self.__status = Status.IN_PROGRESS
        self.__progress = progress
        if task is not None:
            self.__task = task
        if project_name is not None:
            self.__project_name = project_name

    def abort_training(self, task: str):
        """Request an abort: set the status to CANCELLING and record ``task``."""
        self.__task = task
        self.__status = Status.CANCELLING

    def finalize_abort_training(self, task: str):
        """Complete an abort: status CANCELLED, progress reset to 0, task recorded."""
        self.__status = Status.CANCELLED
        self.__progress = 0
        self.__task = task

    def is_training_aborted(self) -> bool:
        """Return True while an abort has been requested but not yet finalized."""
        return self.__status is Status.CANCELLING

    def get_status(self) -> Status:
        """Return the current ``Status`` member (use ``.value`` for its string)."""
        return self.__status

    def get_progress(self) -> int:
        """Return the last recorded progress percentage (0 after an abort)."""
        return self.__progress

    def get_task(self) -> str:
        """Return the last recorded task description, or None if never set."""
        return self.__task

    def get_project_name(self) -> str:
        """Return the last recorded project name, or None if never set."""
        return self.__project_name

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(status={self.__status.value}, "
            f"progress={self.__progress}, task={self.__task!r}, "
            f"project_name={self.__project_name!r})"
        )