fashxp committed on
Commit
7c4332a
·
1 Parent(s): 489e3fa

initial commit

Browse files
.gitignore ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # local config
2
+ docker-compose.override.yaml
3
+
4
+ # PhpStorm / IDEA
5
+ .idea
6
+ # NetBeans
7
+ nbproject
Dockerfile ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.9
2
+
3
+ RUN useradd -m -u 1000 user
4
+ USER user
5
+
6
+ ENV HOME=/home/user \
7
+ PATH=/home/user/.local/bin:$PATH
8
+
9
+ WORKDIR $HOME/app
10
+
11
+ COPY --chown=user requirements.txt requirements.txt
12
+
13
+ RUN pip install --upgrade -r requirements.txt
14
+
15
+ COPY --chown=user . .
16
+
17
+ CMD ["uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -9,3 +9,4 @@ license: other
9
  ---
10
 
11
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
9
  ---
10
 
11
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
12
+
docker-compose.yaml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ services:
2
+ server:
3
+ build:
4
+ context: .
5
+ ports:
6
+ - 7860:7860
7
+ develop:
8
+ watch:
9
+ - action: rebuild
10
+ path: .
11
+ volumes:
12
+ - python-cache:/home/user/.cache
13
+
14
+ volumes:
15
+ python-cache:
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi==0.111.*
2
+ requests==2.*
3
+ uvicorn[standard]==0.30.*
4
+ pandas
5
+ transformers
6
+ datasets
7
+ evaluate
8
+ accelerate
9
+ pillow
10
+ torchvision
11
+ scikit-learn
12
+ huggingface_hub
13
+ pydantic
src/abstract_trainer.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+
3
+ import logging
4
+ from .training_status import TrainingStatus
5
+
6
+ logger = logging.getLogger(__name__)
7
+ logger.setLevel(logging.DEBUG)
8
+
9
+
10
+ class AbstractTrainer(ABC):
11
+
12
+ __training_status: TrainingStatus = TrainingStatus();
13
+
14
+ @abstractmethod
15
+ async def start_training(self):
16
+ logger.info('start abstract trainer training')
17
+ pass
18
+
19
+ def get_status(self) -> TrainingStatus:
20
+ return self.__training_status
src/environment_variable_checker.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from fastapi import HTTPException, status
3
+
4
+ class EnvironmentVariableChecker:
5
+
6
+ def validate_environment_variables(self):
7
+
8
+ variables = ['AUTHENTICATION_TOKEN', 'HUGGINGFACE_TOKEN', 'HUGGINGFACE_ORGANIZATION']
9
+
10
+ for variable in variables:
11
+ if os.getenv(variable) is None:
12
+ raise HTTPException(
13
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
14
+ detail=f"Environment variable {variable} not set, please set the {variable} environment variable",
15
+ )
16
+
17
+ def get_authentication_token(self):
18
+ return os.getenv('AUTHENTICATION_TOKEN')
19
+
20
+ def get_huggingface_token(self):
21
+ return os.getenv('HUGGINGFACE_TOKEN');
22
+
23
+ def get_huggingface_organization(self):
24
+ return os.getenv('HUGGINGFACE_ORGANIZATION');
25
+
26
+
27
+
src/image_classification/image_classification_parameters.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel
2
+ from typing import Annotated
3
+ from fastapi import Form
4
+
5
+
6
+ class ImageClassificationTrainingParameters(BaseModel):
7
+ epochs: int
8
+ learning_rate: float
9
+
10
+
11
+ def map_image_classification_training_parameters(
12
+ epocs: Annotated[int, Form(...)] = 3,
13
+ learning_rate: Annotated[float, Form(...)] = 5e-5
14
+ ) -> ImageClassificationTrainingParameters:
15
+ return ImageClassificationTrainingParameters(
16
+ epochs=epocs,
17
+ learning_rate=learning_rate
18
+ )
19
+
20
+
21
+ class ImageClassificationParameters:
22
+
23
+ __training_files_path: str
24
+ __training_zip_file_path: str
25
+ __result_model_name: str
26
+ __source_model_name: str
27
+ __training_parameters: ImageClassificationTrainingParameters
28
+
29
+ def __init__(self,
30
+ training_files_path: str,
31
+ training_zip_file_path: str,
32
+ result_model_name: str,
33
+ source_model_name: str,
34
+ training_parameters: ImageClassificationTrainingParameters
35
+ ):
36
+ self.__training_files_path = training_files_path
37
+ self.__training_zip_file_path = training_zip_file_path
38
+ self.__result_model_name = result_model_name
39
+ self.__source_model_name = source_model_name
40
+ self.__training_parameters = training_parameters
41
+
42
+ def get_training_files_path(self) -> str:
43
+ return self.__training_files_path
44
+
45
+ def get_training_zip_file(self) -> str:
46
+ return self.__training_zip_file_path
47
+
48
+ def get_result_model_name(self) -> str:
49
+ return self.__result_model_name
50
+
51
+ def get_source_model_name(self) -> str:
52
+ return self.__source_model_name
53
+
54
+ def get_training_parameters(self) -> ImageClassificationTrainingParameters:
55
+ return self.__training_parameters
56
+
src/image_classification/image_classification_trainer.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+ from ..progress_callback import ProgressCallback
4
+ from ..abstract_trainer import AbstractTrainer
5
+ from ..environment_variable_checker import EnvironmentVariableChecker
6
+ from .image_classification_parameters import ImageClassificationParameters
7
+
8
+ import zipfile
9
+ import os
10
+ import shutil
11
+
12
+ from datasets import load_dataset
13
+ from transformers import AutoImageProcessor, DefaultDataCollator, AutoModelForImageClassification, TrainingArguments, Trainer, TrainerState, TrainerControl
14
+ from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor
15
+ from huggingface_hub import HfFolder
16
+
17
+ import evaluate
18
+ import numpy as np
19
+
20
+ logger = logging.getLogger(__name__)
21
+ logger.setLevel(logging.DEBUG)
22
+
23
+
24
+ class ImageClassificationTrainer(AbstractTrainer):
25
+
26
+ def start_training(self, parameters: ImageClassificationParameters):
27
+
28
+ logger.info('Start Training...')
29
+
30
+ try:
31
+ task = 'Extract training data'
32
+ self.get_status().update_status(0, task)
33
+ logger.info(task)
34
+
35
+ self.__extract_training_data(parameters)
36
+
37
+ if(self.get_status().is_training_aborted()):
38
+ return
39
+
40
+ task = 'Prepare Data set'
41
+ self.get_status().update_status(10, task)
42
+ logger.info(task)
43
+
44
+ images = self.__prepare_data_set(parameters)
45
+
46
+ if(self.get_status().is_training_aborted()):
47
+ return
48
+
49
+ task = 'Start training model'
50
+ self.get_status().update_status(20, task)
51
+ logger.info(task)
52
+
53
+ self.__train_model(images, parameters)
54
+
55
+ self.get_status().update_status(100, "Training completed")
56
+
57
+ except Exception as e:
58
+ logger.error(e)
59
+ self.get_status().finalize_abort_training(str(e))
60
+
61
+ raise RuntimeError(f"An error occurred: {str(e)}")
62
+
63
+ finally:
64
+ # Cleanup after processing
65
+ logger.info('Cleaning up training files after training')
66
+ shutil.rmtree(parameters.get_training_files_path())
67
+
68
+ if(self.get_status().is_training_aborted()):
69
+ self.get_status().finalize_abort_training("Training aborted")
70
+
71
+
72
+
73
+ def __extract_training_data(self, parameters: ImageClassificationParameters):
74
+ training_file = parameters.get_training_zip_file()
75
+
76
+ # Check if it is a valid ZIP file
77
+ if not zipfile.is_zipfile(training_file):
78
+ raise RuntimeError("Uploaded file is not a valid zip file")
79
+
80
+ # Extract the ZIP file
81
+ with zipfile.ZipFile(training_file, 'r') as zip_ref:
82
+ zip_ref.extractall(parameters.get_training_files_path())
83
+
84
+ os.remove(training_file)
85
+ logger.info(os.listdir(parameters.get_training_files_path()))
86
+
87
+
88
+
89
+ def __prepare_data_set(self, parameters: ImageClassificationParameters) -> dict:
90
+
91
+ dataset = load_dataset("imagefolder", data_dir=parameters.get_training_files_path())
92
+
93
+ images = dataset["train"]
94
+ images = images.train_test_split(test_size=0.2)
95
+
96
+ logger.info(images)
97
+ logger.info(images["train"][100])
98
+
99
+
100
+ # Preprocess the images
101
+ image_processor = AutoImageProcessor.from_pretrained(parameters.get_source_model_name())
102
+
103
+ # Apply some image transformations to the images to make the model more robust against overfitting.
104
+ normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
105
+ size = (
106
+ image_processor.size["shortest_edge"]
107
+ if "shortest_edge" in image_processor.size
108
+ else (image_processor.size["height"], image_processor.size["width"])
109
+ )
110
+ _transforms = Compose([RandomResizedCrop(size), ToTensor(), normalize])
111
+
112
+ def transforms(examples):
113
+ examples["pixel_values"] = [_transforms(img.convert("RGB")) for img in examples["image"]]
114
+ del examples["image"]
115
+ return examples
116
+
117
+ images = images.with_transform(transforms)
118
+
119
+ return images
120
+
121
+
122
+ def __train_model(self, images: dict, parameters: ImageClassificationParameters):
123
+
124
+ environment_variable_checker = EnvironmentVariableChecker()
125
+ HfFolder.save_token(environment_variable_checker.get_huggingface_token())
126
+
127
+ image_processor = AutoImageProcessor.from_pretrained(parameters.get_source_model_name())
128
+ data_collator = DefaultDataCollator()
129
+ progressCallback = ProgressCallback(self.get_status())
130
+
131
+ # Evaluate and metrics
132
+ accuracy = evaluate.load("accuracy")
133
+ def compute_metrics(eval_pred):
134
+ predictions, labels = eval_pred
135
+ predictions = np.argmax(predictions, axis=1)
136
+ return accuracy.compute(predictions=predictions, references=labels)
137
+
138
+ # get label maps
139
+ labels = images["train"].features["label"].names
140
+ label2id, id2label = dict(), dict()
141
+ for i, label in enumerate(labels):
142
+ label2id[label] = str(i)
143
+ id2label[str(i)] = label
144
+ logger.info(id2label)
145
+
146
+ # train the model
147
+ model = AutoModelForImageClassification.from_pretrained(
148
+ parameters.get_source_model_name(),
149
+ num_labels=len(labels),
150
+ id2label=id2label,
151
+ label2id=label2id,
152
+ )
153
+
154
+ target_model_id = environment_variable_checker.get_huggingface_organization() + '/' + parameters.get_result_model_name()
155
+ training_args = TrainingArguments(
156
+ output_dir=parameters.get_result_model_name(),
157
+ hub_model_id=target_model_id,
158
+ remove_unused_columns=False,
159
+ eval_strategy="epoch",
160
+ save_strategy="epoch",
161
+ learning_rate=parameters.get_training_parameters().learning_rate,
162
+ per_device_train_batch_size=16,
163
+ gradient_accumulation_steps=4,
164
+ per_device_eval_batch_size=16,
165
+ num_train_epochs=parameters.get_training_parameters().epochs,
166
+ warmup_ratio=0.1,
167
+ logging_steps=10,
168
+ load_best_model_at_end=True,
169
+ metric_for_best_model="accuracy",
170
+ push_to_hub=False,
171
+ hub_private_repo=True,
172
+ )
173
+
174
+ trainer = Trainer(
175
+ model=model,
176
+ args=training_args,
177
+ data_collator=data_collator,
178
+ train_dataset=images["train"],
179
+ eval_dataset=images["test"],
180
+ tokenizer=image_processor,
181
+ compute_metrics=compute_metrics,
182
+ callbacks=[progressCallback]
183
+ )
184
+
185
+
186
+ if(self.get_status().is_training_aborted()):
187
+ return
188
+
189
+ trainer.train()
190
+
191
+ if(self.get_status().is_training_aborted()):
192
+ return
193
+
194
+ logger.info(f"Model trained, start uploading")
195
+ self.get_status().update_status(90, f"Uploading model to Hugging Face")
196
+ trainer.push_to_hub()
src/main.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import requests
3
+ import torch
4
+
5
+ from .training_status import Status
6
+ from .environment_variable_checker import EnvironmentVariableChecker
7
+
8
+ from .task_manager import TaskManager
9
+ from .training_manager import TrainingManager
10
+ from .image_classification.image_classification_trainer import ImageClassificationTrainer
11
+ from .image_classification.image_classification_parameters import ImageClassificationParameters, map_image_classification_training_parameters, ImageClassificationTrainingParameters
12
+
13
+ from fastapi import FastAPI, Header, Depends, HTTPException, BackgroundTasks, UploadFile, Form, File, status
14
+ from fastapi.responses import FileResponse
15
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
16
+ from pydantic import BaseModel
17
+ from typing import Optional, Annotated
18
+
19
+
20
+ import logging
21
+ import sys
22
+
23
+ import zipfile
24
+ import os
25
+ from pathlib import Path
26
+ import tempfile
27
+ import shutil
28
+
29
+
30
+ app = FastAPI()
31
+
32
+ environmentVariableChecker = EnvironmentVariableChecker()
33
+ environmentVariableChecker.validate_environment_variables()
34
+
35
+ logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s')
36
+ logger = logging.getLogger(__name__)
37
+ logger.setLevel(logging.DEBUG)
38
+
39
+ classification_trainer: TrainingManager = TrainingManager(ImageClassificationTrainer())
40
+
41
+ security = HTTPBearer()
42
+ def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
43
+
44
+ token = environmentVariableChecker.get_authentication_token()
45
+
46
+ if credentials.credentials != token:
47
+ raise HTTPException(
48
+ status_code=status.HTTP_401_UNAUTHORIZED,
49
+ detail="Invalid token",
50
+ headers={"WWW-Authenticate": "Bearer"},
51
+ )
52
+ return {"token": credentials.credentials}
53
+
54
+
55
+ class ResponseModel(BaseModel):
56
+ message: str
57
+ success: bool = True
58
+
59
+
60
+ @app.post(
61
+ "/upload",
62
+ summary="Upload a zip file containing training data",
63
+ response_model=ResponseModel
64
+ )
65
+ async def upload_file(
66
+ training_params: Annotated[ImageClassificationTrainingParameters, Depends(map_image_classification_training_parameters)],
67
+ data_files_training: Annotated[UploadFile, File(...)],
68
+ token_data: dict = Depends(verify_token),
69
+ result_model_name: str = Form(...),
70
+ source_model_name: str = Form('google/vit-base-patch16-224-in21k'),
71
+ ):
72
+
73
+ # check if training is running, if so then exit
74
+ status = classification_trainer.get_task_status()
75
+ if status.get_status() == Status.IN_PROGRESS or status.get_status() == Status.CANCELLING:
76
+ raise HTTPException(status_code=405, detail="Training is already in progress")
77
+
78
+ # Ensure the uploaded file is a ZIP file
79
+ if not data_files_training.filename.endswith(".zip"):
80
+ raise HTTPException(status_code=422, detail="Uploaded file is not a zip file")
81
+
82
+ try:
83
+ # Create a temporary directory to extract the contents
84
+ tmp_path = os.path.join(tempfile.gettempdir(), 'training_data')
85
+ path = Path(tmp_path)
86
+ path.mkdir(parents=True, exist_ok=True)
87
+
88
+ contents = await data_files_training.read()
89
+ zip_path = os.path.join(tmp_path, 'image_classification_data.zip')
90
+ with open(zip_path, 'wb') as temp_file:
91
+ temp_file.write(contents)
92
+
93
+ # prepare parameters
94
+ parameters = ImageClassificationParameters(
95
+ training_files_path=tmp_path,
96
+ training_zip_file_path=zip_path,
97
+ result_model_name=result_model_name,
98
+ source_model_name=source_model_name,
99
+ training_parameters=training_params
100
+ )
101
+
102
+ # start training
103
+ await classification_trainer.start_training(parameters)
104
+
105
+ # TODO add more return parameters and information
106
+ return ResponseModel(message="training started")
107
+
108
+ except Exception as e:
109
+ raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
110
+
111
+
112
+ @app.get("/get_task_status")
113
+ async def get_task_status(token_data: dict = Depends(verify_token)):
114
+ status = classification_trainer.get_task_status()
115
+ return {
116
+ "progress": status.get_progress(),
117
+ "task": status.get_task(),
118
+ "status": status.get_status().value
119
+ }
120
+
121
+
122
+ @app.get("/stop_task")
123
+ async def stop_task(token_data: dict = Depends(verify_token)):
124
+ try:
125
+ classification_trainer.stop_task()
126
+ return {
127
+ "success": True
128
+ }
129
+ except Exception as e:
130
+ raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
131
+
132
+
133
+
134
+ @app.get("/gpu_check")
135
+ async def gpu_check():
136
+
137
+ gpu = 'GPU not available'
138
+ if torch.cuda.is_available():
139
+ gpu = 'GPU is available'
140
+ print("GPU is available")
141
+ else:
142
+ print("GPU is not available")
143
+
144
+ return {'success': True, 'response': 'hello world 3', 'gpu': gpu}
145
+
src/progress_callback.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl
3
+
4
+ from .training_status import TrainingStatus
5
+
6
+
7
+ logger = logging.getLogger(__name__)
8
+ logger.setLevel(logging.DEBUG)
9
+
10
+ class ProgressCallback(TrainerCallback):
11
+
12
+ __trainingStatus: TrainingStatus = None
13
+
14
+ def __init__(self, trainingStatus: TrainingStatus):
15
+ self.__trainingStatus = trainingStatus
16
+
17
+ def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
18
+ logger.info(f"Completed step {state.global_step} of {state.max_steps}")
19
+
20
+ if self.__trainingStatus.is_training_aborted():
21
+ control.should_training_stop = True
22
+ logger.info("Training aborted")
23
+ return
24
+
25
+ startPercentage = 21
26
+ endPercentage = 89
27
+ scope = endPercentage - startPercentage
28
+ progress = startPercentage + (state.global_step / state.max_steps) * scope
29
+
30
+ self.__trainingStatus.update_status(progress, f"Training model, completed step {state.global_step} of {state.max_steps}")
31
+
src/task_manager.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import logging
3
+ from fastapi import BackgroundTasks, HTTPException
4
+
5
+ from concurrent.futures import ThreadPoolExecutor
6
+
7
+ logger = logging.getLogger(__name__)
8
+ logger.setLevel(logging.DEBUG)
9
+
10
+
11
+ class Worker:
12
+ def doing_work(self, task_manager):
13
+ task_manager.task_status["status"] = "Running"
14
+ for i in range(1, 101):
15
+ if task_manager.task_status["status"] == "Stopped":
16
+ break
17
+ asyncio.sleep(1) # Simulate a time-consuming task
18
+ task_manager.task_status["progress"] = i
19
+ logger.info('process ' + str(i) + '%' + ' done')
20
+
21
+ if task_manager.task_status["status"] != "Stopped":
22
+ task_manager.task_status["status"] = "Completed"
23
+
24
+
25
+
26
+ class TaskManager:
27
+
28
+ task_status = {"progress": 0, "status": "Not started"}
29
+ task = None
30
+
31
+ #def __init__(self):
32
+
33
+ worker = Worker()
34
+
35
+ async def doing_work(self):
36
+ loop = asyncio.get_running_loop()
37
+ with ThreadPoolExecutor() as pool:
38
+ await loop.run_in_executor(pool, self.worker.doing_work, self)
39
+ #self.worker.doing_work(self)
40
+
41
+ # self.task_status["status"] = "Running"
42
+ # for i in range(1, 101):
43
+ # if self.task_status["status"] == "Stopped":
44
+ # break
45
+ # await asyncio.sleep(1) # Simulate a time-consuming task
46
+ # self.task_status["progress"] = i
47
+ # logger.info('process ' + str(i) + '%' + ' done')
48
+
49
+ # if self.task_status["status"] != "Stopped":
50
+ # self.task_status["status"] = "Completed"
51
+
52
+
53
+ async def start_task(self):
54
+ if self.task is None or self.task.done():
55
+ self.task_status["progress"] = 0
56
+ self.task_status["status"] = "Not started"
57
+ self.task = asyncio.create_task(self.doing_work())
58
+ return {"message": "Task started"}
59
+ else:
60
+ raise HTTPException(status_code=409, detail="Task already running")
61
+
62
+ async def get_task_status(self):
63
+ return self.task_status
64
+
65
+ async def stop_task(self):
66
+ if self.task is not None and not self.task.done():
67
+ self.task_status["status"] = "Stopped"
68
+ self.task.cancel()
69
+ return {"message": "Task stopped"}
70
+ else:
71
+ raise HTTPException(status_code=409, detail="No task running")
72
+
src/training_manager.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import asyncio
3
+ from .abstract_trainer import AbstractTrainer
4
+ from .training_status import TrainingStatus
5
+ from concurrent.futures import ThreadPoolExecutor
6
+
7
+
8
+ logger = logging.getLogger(__name__)
9
+ logger.setLevel(logging.DEBUG)
10
+
11
+
12
+
13
+ class TrainingManager:
14
+
15
+ __training_task = None
16
+ __trainer: AbstractTrainer = None
17
+
18
+ task_status = {"progress": 0, "status": "Not started"}
19
+
20
+ def __init__(self, trainer: AbstractTrainer):
21
+ self.__trainer = trainer
22
+
23
+ async def __do_start_training(self, parameters):
24
+ logger.info('do start training')
25
+
26
+ loop = asyncio.get_running_loop()
27
+ with ThreadPoolExecutor() as pool:
28
+ await loop.run_in_executor(pool, self.__trainer.start_training, parameters)
29
+
30
+ logger.info('done')
31
+
32
+ async def start_training(self, parameters):
33
+ logger.info('start training')
34
+
35
+ if self.__training_task is None or self.__training_task.done():
36
+ self.__training_task = asyncio.create_task(self.__do_start_training(parameters))
37
+ else:
38
+ raise RuntimeError("Training already running")
39
+
40
+ def get_task_status(self) -> TrainingStatus:
41
+ return self.__trainer.get_status()
42
+
43
+ def stop_task(self):
44
+ if self.__training_task is not None and not self.__training_task.done():
45
+ self.__trainer.get_status().abort_training("Stopping training")
46
+ #self.__training_task.cancel()
47
+
48
+ else:
49
+ raise RuntimeError("No task running")
50
+
src/training_status.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from enum import Enum
3
+
4
+
5
+ logger = logging.getLogger(__name__)
6
+ logger.setLevel(logging.DEBUG)
7
+
8
+ class Status(Enum):
9
+ NOT_STARTED = "NOT_STARTED"
10
+ IN_PROGRESS = "IN_PROGRESS"
11
+ CANCELLING = "CANCELLING"
12
+ CANCELLED = "CANCELLED"
13
+ COMPLETED = "COMPLETED"
14
+
15
+ class TrainingStatus:
16
+
17
+ __status: Status = Status.NOT_STARTED
18
+ __task: str = None
19
+ __progress: int = 0
20
+
21
+ def update_status(self, progress: int, task: str):
22
+ if progress < 0 or progress > 100:
23
+ raise ValueError("Progress must be between 0 and 100")
24
+
25
+ if progress > 0:
26
+ self.__status = Status.IN_PROGRESS
27
+
28
+ if progress == 100:
29
+ self.__status = Status.COMPLETED
30
+
31
+ self.__progress = progress
32
+
33
+ if task is not None:
34
+ self.__task = task
35
+
36
+
37
+ def abort_training(self, task: str):
38
+ self.__task = task
39
+ self.__status = Status.CANCELLING
40
+
41
+ def finalize_abort_training(self, task: str):
42
+ self.__status = Status.CANCELLED
43
+ self.__progress = 0
44
+ self.__task = task
45
+
46
+ def is_training_aborted(self) -> bool:
47
+ return (self.__status == Status.CANCELLING)
48
+
49
+ def get_status(self) -> str:
50
+ return self.__status
51
+
52
+ def get_progress(self) -> int:
53
+ return self.__progress
54
+
55
+ def get_task(self) -> str:
56
+ return self.__task
57
+
58
+