initial commit
- .gitignore +7 -0
- Dockerfile +17 -0
- README.md +1 -0
- docker-compose.yaml +15 -0
- requirements.txt +13 -0
- src/abstract_trainer.py +20 -0
- src/environment_variable_checker.py +27 -0
- src/image_classification/image_classification_parameters.py +56 -0
- src/image_classification/image_classification_trainer.py +196 -0
- src/main.py +145 -0
- src/progress_callback.py +31 -0
- src/task_manager.py +72 -0
- src/training_manager.py +50 -0
- src/training_status.py +58 -0
.gitignore
ADDED
@@ -0,0 +1,7 @@
+# local config
+docker-compose.override.yaml
+
+# PhpStorm / IDEA
+.idea
+# NetBeans
+nbproject
Dockerfile
ADDED
@@ -0,0 +1,17 @@
+FROM python:3.9
+
+RUN useradd -m -u 1000 user
+USER user
+
+ENV HOME=/home/user \
+    PATH=/home/user/.local/bin:$PATH
+
+WORKDIR $HOME/app
+
+COPY --chown=user requirements.txt requirements.txt
+
+RUN pip install --upgrade -r requirements.txt
+
+COPY --chown=user . .
+
+CMD ["uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "7860"]
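The image follows the usual conventions for a Docker-based Hugging Face Space: the application runs as a non-root user with UID 1000 and listens on port 7860, which is the default application port for Docker Spaces.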
README.md
CHANGED
@@ -9,3 +9,4 @@ license: other
 ---

 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
docker-compose.yaml
ADDED
@@ -0,0 +1,15 @@
+services:
+  server:
+    build:
+      context: .
+    ports:
+      - 7860:7860
+    develop:
+      watch:
+        - action: rebuild
+          path: .
+    volumes:
+      - python-cache:/home/user/.cache
+
+volumes:
+  python-cache:
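For local development, docker compose up --build starts the API on port 7860, and docker compose watch rebuilds the image when project files change (the develop/watch section assumes a Docker Compose release recent enough to support it). The named python-cache volume is mounted at /home/user/.cache, so downloaded models and datasets survive image rebuilds.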
requirements.txt
ADDED
@@ -0,0 +1,13 @@
+fastapi==0.111.*
+requests==2.*
+uvicorn[standard]==0.30.*
+pandas
+transformers
+datasets
+evaluate
+accelerate
+pillow
+torchvision
+scikit-learn
+huggingface_hub
+pydantic
src/abstract_trainer.py
ADDED
@@ -0,0 +1,20 @@
+from abc import ABC, abstractmethod
+
+import logging
+from .training_status import TrainingStatus
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
+
+class AbstractTrainer(ABC):
+
+    # Class-level attribute: all trainer instances share this TrainingStatus object.
+    __training_status: TrainingStatus = TrainingStatus()
+
+    @abstractmethod
+    def start_training(self, parameters):
+        logger.info('start abstract trainer training')
+        pass
+
+    def get_status(self) -> TrainingStatus:
+        return self.__training_status
src/environment_variable_checker.py
ADDED
@@ -0,0 +1,27 @@
+import os
+from fastapi import HTTPException, status
+
+
+class EnvironmentVariableChecker:
+
+    def validate_environment_variables(self):
+
+        variables = ['AUTHENTICATION_TOKEN', 'HUGGINGFACE_TOKEN', 'HUGGINGFACE_ORGANIZATION']
+
+        for variable in variables:
+            if os.getenv(variable) is None:
+                raise HTTPException(
+                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                    detail=f"Environment variable {variable} is not set; please set the {variable} environment variable",
+                )
+
+    def get_authentication_token(self):
+        return os.getenv('AUTHENTICATION_TOKEN')
+
+    def get_huggingface_token(self):
+        return os.getenv('HUGGINGFACE_TOKEN')
+
+    def get_huggingface_organization(self):
+        return os.getenv('HUGGINGFACE_ORGANIZATION')
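None of these variables are defined in the repository itself. On a Hugging Face Space they would typically be configured as Space secrets; for local runs, a docker-compose.override.yaml (already anticipated by the .gitignore entry above) is a natural place to set them.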
src/image_classification/image_classification_parameters.py
ADDED
@@ -0,0 +1,56 @@
+from pydantic import BaseModel
+from typing import Annotated
+from fastapi import Form
+
+
+class ImageClassificationTrainingParameters(BaseModel):
+    epochs: int
+    learning_rate: float
+
+
+def map_image_classification_training_parameters(
+    epochs: Annotated[int, Form(...)] = 3,
+    learning_rate: Annotated[float, Form(...)] = 5e-5
+) -> ImageClassificationTrainingParameters:
+    return ImageClassificationTrainingParameters(
+        epochs=epochs,
+        learning_rate=learning_rate
+    )
+
+
+class ImageClassificationParameters:
+
+    __training_files_path: str
+    __training_zip_file_path: str
+    __result_model_name: str
+    __source_model_name: str
+    __training_parameters: ImageClassificationTrainingParameters
+
+    def __init__(self,
+                 training_files_path: str,
+                 training_zip_file_path: str,
+                 result_model_name: str,
+                 source_model_name: str,
+                 training_parameters: ImageClassificationTrainingParameters
+                 ):
+        self.__training_files_path = training_files_path
+        self.__training_zip_file_path = training_zip_file_path
+        self.__result_model_name = result_model_name
+        self.__source_model_name = source_model_name
+        self.__training_parameters = training_parameters
+
+    def get_training_files_path(self) -> str:
+        return self.__training_files_path
+
+    def get_training_zip_file(self) -> str:
+        return self.__training_zip_file_path
+
+    def get_result_model_name(self) -> str:
+        return self.__result_model_name
+
+    def get_source_model_name(self) -> str:
+        return self.__source_model_name
+
+    def get_training_parameters(self) -> ImageClassificationTrainingParameters:
+        return self.__training_parameters
src/image_classification/image_classification_trainer.py
ADDED
@@ -0,0 +1,196 @@
+import logging
+
+from ..progress_callback import ProgressCallback
+from ..abstract_trainer import AbstractTrainer
+from ..environment_variable_checker import EnvironmentVariableChecker
+from .image_classification_parameters import ImageClassificationParameters
+
+import zipfile
+import os
+import shutil
+
+from datasets import load_dataset
+from transformers import AutoImageProcessor, DefaultDataCollator, AutoModelForImageClassification, TrainingArguments, Trainer, TrainerState, TrainerControl
+from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor
+from huggingface_hub import HfFolder
+
+import evaluate
+import numpy as np
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
+
+class ImageClassificationTrainer(AbstractTrainer):
+
+    def start_training(self, parameters: ImageClassificationParameters):
+
+        logger.info('Start Training...')
+
+        try:
+            task = 'Extract training data'
+            self.get_status().update_status(0, task)
+            logger.info(task)
+
+            self.__extract_training_data(parameters)
+
+            if self.get_status().is_training_aborted():
+                return
+
+            task = 'Prepare Data set'
+            self.get_status().update_status(10, task)
+            logger.info(task)
+
+            images = self.__prepare_data_set(parameters)
+
+            if self.get_status().is_training_aborted():
+                return
+
+            task = 'Start training model'
+            self.get_status().update_status(20, task)
+            logger.info(task)
+
+            self.__train_model(images, parameters)
+
+            self.get_status().update_status(100, "Training completed")
+
+        except Exception as e:
+            logger.error(e)
+            self.get_status().finalize_abort_training(str(e))
+
+            raise RuntimeError(f"An error occurred: {str(e)}")
+
+        finally:
+            # Cleanup after processing
+            logger.info('Cleaning up training files after training')
+            shutil.rmtree(parameters.get_training_files_path())
+
+            if self.get_status().is_training_aborted():
+                self.get_status().finalize_abort_training("Training aborted")
+
+    def __extract_training_data(self, parameters: ImageClassificationParameters):
+        training_file = parameters.get_training_zip_file()
+
+        # Check if it is a valid ZIP file
+        if not zipfile.is_zipfile(training_file):
+            raise RuntimeError("Uploaded file is not a valid zip file")
+
+        # Extract the ZIP file
+        with zipfile.ZipFile(training_file, 'r') as zip_ref:
+            zip_ref.extractall(parameters.get_training_files_path())
+
+        os.remove(training_file)
+        logger.info(os.listdir(parameters.get_training_files_path()))
+
+    def __prepare_data_set(self, parameters: ImageClassificationParameters) -> dict:
+
+        dataset = load_dataset("imagefolder", data_dir=parameters.get_training_files_path())
+
+        images = dataset["train"]
+        images = images.train_test_split(test_size=0.2)
+
+        logger.info(images)
+        logger.info(images["train"][100])
+
+        # Preprocess the images
+        image_processor = AutoImageProcessor.from_pretrained(parameters.get_source_model_name())
+
+        # Apply some image transformations to the images to make the model more robust against overfitting.
+        normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
+        size = (
+            image_processor.size["shortest_edge"]
+            if "shortest_edge" in image_processor.size
+            else (image_processor.size["height"], image_processor.size["width"])
+        )
+        _transforms = Compose([RandomResizedCrop(size), ToTensor(), normalize])
+
+        def transforms(examples):
+            examples["pixel_values"] = [_transforms(img.convert("RGB")) for img in examples["image"]]
+            del examples["image"]
+            return examples
+
+        images = images.with_transform(transforms)
+
+        return images
+
+    def __train_model(self, images: dict, parameters: ImageClassificationParameters):
+
+        environment_variable_checker = EnvironmentVariableChecker()
+        HfFolder.save_token(environment_variable_checker.get_huggingface_token())
+
+        image_processor = AutoImageProcessor.from_pretrained(parameters.get_source_model_name())
+        data_collator = DefaultDataCollator()
+        progressCallback = ProgressCallback(self.get_status())
+
+        # Evaluate and metrics
+        accuracy = evaluate.load("accuracy")
+
+        def compute_metrics(eval_pred):
+            predictions, labels = eval_pred
+            predictions = np.argmax(predictions, axis=1)
+            return accuracy.compute(predictions=predictions, references=labels)
+
+        # get label maps
+        labels = images["train"].features["label"].names
+        label2id, id2label = dict(), dict()
+        for i, label in enumerate(labels):
+            label2id[label] = str(i)
+            id2label[str(i)] = label
+        logger.info(id2label)
+
+        # train the model
+        model = AutoModelForImageClassification.from_pretrained(
+            parameters.get_source_model_name(),
+            num_labels=len(labels),
+            id2label=id2label,
+            label2id=label2id,
+        )
+
+        target_model_id = environment_variable_checker.get_huggingface_organization() + '/' + parameters.get_result_model_name()
+        training_args = TrainingArguments(
+            output_dir=parameters.get_result_model_name(),
+            hub_model_id=target_model_id,
+            remove_unused_columns=False,
+            eval_strategy="epoch",
+            save_strategy="epoch",
+            learning_rate=parameters.get_training_parameters().learning_rate,
+            per_device_train_batch_size=16,
+            gradient_accumulation_steps=4,
+            per_device_eval_batch_size=16,
+            num_train_epochs=parameters.get_training_parameters().epochs,
+            warmup_ratio=0.1,
+            logging_steps=10,
+            load_best_model_at_end=True,
+            metric_for_best_model="accuracy",
+            push_to_hub=False,
+            hub_private_repo=True,
+        )
+
+        trainer = Trainer(
+            model=model,
+            args=training_args,
+            data_collator=data_collator,
+            train_dataset=images["train"],
+            eval_dataset=images["test"],
+            tokenizer=image_processor,
+            compute_metrics=compute_metrics,
+            callbacks=[progressCallback]
+        )
+
+        if self.get_status().is_training_aborted():
+            return
+
+        trainer.train()
+
+        if self.get_status().is_training_aborted():
+            return
+
+        logger.info("Model trained, start uploading")
+        self.get_status().update_status(90, "Uploading model to Hugging Face")
+        trainer.push_to_hub()
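A note on the expected data layout: __prepare_data_set hands the extracted directory to the datasets "imagefolder" loader, which derives one label per class sub-folder, so the uploaded zip should contain one folder per class (for example cats/001.jpg and dogs/001.jpg). The 80/20 train/test split and the label2id/id2label maps are then built from whatever class folders are present.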
src/main.py
ADDED
@@ -0,0 +1,145 @@
+import os
+import requests
+import torch
+
+from .training_status import Status
+from .environment_variable_checker import EnvironmentVariableChecker
+
+from .task_manager import TaskManager
+from .training_manager import TrainingManager
+from .image_classification.image_classification_trainer import ImageClassificationTrainer
+from .image_classification.image_classification_parameters import ImageClassificationParameters, map_image_classification_training_parameters, ImageClassificationTrainingParameters
+
+from fastapi import FastAPI, Header, Depends, HTTPException, BackgroundTasks, UploadFile, Form, File, status
+from fastapi.responses import FileResponse
+from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
+from pydantic import BaseModel
+from typing import Optional, Annotated
+
+import logging
+import sys
+
+import zipfile
+from pathlib import Path
+import tempfile
+import shutil
+
+
+app = FastAPI()
+
+environmentVariableChecker = EnvironmentVariableChecker()
+environmentVariableChecker.validate_environment_variables()
+
+logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s')
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
+classification_trainer: TrainingManager = TrainingManager(ImageClassificationTrainer())
+
+security = HTTPBearer()
+
+
+def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
+
+    token = environmentVariableChecker.get_authentication_token()
+
+    if credentials.credentials != token:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail="Invalid token",
+            headers={"WWW-Authenticate": "Bearer"},
+        )
+    return {"token": credentials.credentials}
+
+
+class ResponseModel(BaseModel):
+    message: str
+    success: bool = True
+
+
+@app.post(
+    "/upload",
+    summary="Upload a zip file containing training data",
+    response_model=ResponseModel
+)
+async def upload_file(
+    training_params: Annotated[ImageClassificationTrainingParameters, Depends(map_image_classification_training_parameters)],
+    data_files_training: Annotated[UploadFile, File(...)],
+    token_data: dict = Depends(verify_token),
+    result_model_name: str = Form(...),
+    source_model_name: str = Form('google/vit-base-patch16-224-in21k'),
+):
+
+    # check if training is running, if so then exit
+    status = classification_trainer.get_task_status()
+    if status.get_status() == Status.IN_PROGRESS or status.get_status() == Status.CANCELLING:
+        raise HTTPException(status_code=405, detail="Training is already in progress")
+
+    # Ensure the uploaded file is a ZIP file
+    if not data_files_training.filename.endswith(".zip"):
+        raise HTTPException(status_code=422, detail="Uploaded file is not a zip file")
+
+    try:
+        # Create a temporary directory to extract the contents
+        tmp_path = os.path.join(tempfile.gettempdir(), 'training_data')
+        path = Path(tmp_path)
+        path.mkdir(parents=True, exist_ok=True)
+
+        contents = await data_files_training.read()
+        zip_path = os.path.join(tmp_path, 'image_classification_data.zip')
+        with open(zip_path, 'wb') as temp_file:
+            temp_file.write(contents)
+
+        # prepare parameters
+        parameters = ImageClassificationParameters(
+            training_files_path=tmp_path,
+            training_zip_file_path=zip_path,
+            result_model_name=result_model_name,
+            source_model_name=source_model_name,
+            training_parameters=training_params
+        )
+
+        # start training
+        await classification_trainer.start_training(parameters)
+
+        # TODO add more return parameters and information
+        return ResponseModel(message="training started")
+
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
+
+
+@app.get("/get_task_status")
+async def get_task_status(token_data: dict = Depends(verify_token)):
+    status = classification_trainer.get_task_status()
+    return {
+        "progress": status.get_progress(),
+        "task": status.get_task(),
+        "status": status.get_status().value
+    }
+
+
+@app.get("/stop_task")
+async def stop_task(token_data: dict = Depends(verify_token)):
+    try:
+        classification_trainer.stop_task()
+        return {
+            "success": True
+        }
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
+
+
+@app.get("/gpu_check")
+async def gpu_check():
+
+    gpu = 'GPU not available'
+    if torch.cuda.is_available():
+        gpu = 'GPU is available'
+        print("GPU is available")
+    else:
+        print("GPU is not available")
+
+    return {'success': True, 'response': 'hello world 3', 'gpu': gpu}
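Since these endpoints are the whole public surface, a short client sketch may help. The base URL, token, file name and model names below are placeholders; the form field names follow the /upload signature above, and the zip is assumed to be laid out for the imagefolder loader (one folder per class).

import requests

BASE_URL = "http://localhost:7860"   # port published by the Dockerfile / docker-compose.yaml
TOKEN = "my-authentication-token"    # must match the server's AUTHENTICATION_TOKEN
HEADERS = {"Authorization": f"Bearer {TOKEN}"}

# Start a training run: scalar fields go in the form data,
# the training archive goes into the data_files_training upload field.
with open("training.zip", "rb") as archive:
    response = requests.post(
        f"{BASE_URL}/upload",
        headers=HEADERS,
        data={
            "result_model_name": "my-image-classifier",
            "source_model_name": "google/vit-base-patch16-224-in21k",
            "epochs": 3,
            "learning_rate": 5e-5,
        },
        files={"data_files_training": ("training.zip", archive, "application/zip")},
    )
response.raise_for_status()
print(response.json())  # {"message": "training started", "success": true}

# Poll progress while the background task runs.
print(requests.get(f"{BASE_URL}/get_task_status", headers=HEADERS).json())

# Abort the run if needed.
# requests.get(f"{BASE_URL}/stop_task", headers=HEADERS)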
src/progress_callback.py
ADDED
@@ -0,0 +1,31 @@
+import logging
+from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl
+
+from .training_status import TrainingStatus
+
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
+
+class ProgressCallback(TrainerCallback):
+
+    __trainingStatus: TrainingStatus = None
+
+    def __init__(self, trainingStatus: TrainingStatus):
+        self.__trainingStatus = trainingStatus
+
+    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+        logger.info(f"Completed step {state.global_step} of {state.max_steps}")
+
+        if self.__trainingStatus.is_training_aborted():
+            control.should_training_stop = True
+            logger.info("Training aborted")
+            return
+
+        startPercentage = 21
+        endPercentage = 89
+        scope = endPercentage - startPercentage
+        progress = startPercentage + (state.global_step / state.max_steps) * scope
+
+        self.__trainingStatus.update_status(progress, f"Training model, completed step {state.global_step} of {state.max_steps}")
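The hard-coded 21-89 window maps per-step training progress into the gap between the milestones reported by ImageClassificationTrainer: 20 when model training starts and 90 when the trained model is uploaded.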
src/task_manager.py
ADDED
@@ -0,0 +1,72 @@
+import asyncio
+import logging
+import time
+from fastapi import BackgroundTasks, HTTPException
+
+from concurrent.futures import ThreadPoolExecutor
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
+
+class Worker:
+    def doing_work(self, task_manager):
+        task_manager.task_status["status"] = "Running"
+        for i in range(1, 101):
+            if task_manager.task_status["status"] == "Stopped":
+                break
+            time.sleep(1)  # Simulate a time-consuming task (blocking sleep; runs in a worker thread)
+            task_manager.task_status["progress"] = i
+            logger.info('process ' + str(i) + '%' + ' done')
+
+        if task_manager.task_status["status"] != "Stopped":
+            task_manager.task_status["status"] = "Completed"
+
+
+class TaskManager:
+
+    task_status = {"progress": 0, "status": "Not started"}
+    task = None
+
+    #def __init__(self):
+
+    worker = Worker()
+
+    async def doing_work(self):
+        loop = asyncio.get_running_loop()
+        with ThreadPoolExecutor() as pool:
+            await loop.run_in_executor(pool, self.worker.doing_work, self)
+        #self.worker.doing_work(self)
+
+        # self.task_status["status"] = "Running"
+        # for i in range(1, 101):
+        #     if self.task_status["status"] == "Stopped":
+        #         break
+        #     await asyncio.sleep(1)  # Simulate a time-consuming task
+        #     self.task_status["progress"] = i
+        #     logger.info('process ' + str(i) + '%' + ' done')
+
+        # if self.task_status["status"] != "Stopped":
+        #     self.task_status["status"] = "Completed"
+
+    async def start_task(self):
+        if self.task is None or self.task.done():
+            self.task_status["progress"] = 0
+            self.task_status["status"] = "Not started"
+            self.task = asyncio.create_task(self.doing_work())
+            return {"message": "Task started"}
+        else:
+            raise HTTPException(status_code=409, detail="Task already running")
+
+    async def get_task_status(self):
+        return self.task_status
+
+    async def stop_task(self):
+        if self.task is not None and not self.task.done():
+            self.task_status["status"] = "Stopped"
+            self.task.cancel()
+            return {"message": "Task stopped"}
+        else:
+            raise HTTPException(status_code=409, detail="No task running")
src/training_manager.py
ADDED
@@ -0,0 +1,50 @@
+import logging
+import asyncio
+from .abstract_trainer import AbstractTrainer
+from .training_status import TrainingStatus
+from concurrent.futures import ThreadPoolExecutor
+
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
+
+class TrainingManager:
+
+    __training_task = None
+    __trainer: AbstractTrainer = None
+
+    task_status = {"progress": 0, "status": "Not started"}
+
+    def __init__(self, trainer: AbstractTrainer):
+        self.__trainer = trainer
+
+    async def __do_start_training(self, parameters):
+        logger.info('do start training')
+
+        # Run the blocking trainer in a thread pool so the event loop stays responsive.
+        loop = asyncio.get_running_loop()
+        with ThreadPoolExecutor() as pool:
+            await loop.run_in_executor(pool, self.__trainer.start_training, parameters)
+
+        logger.info('done')
+
+    async def start_training(self, parameters):
+        logger.info('start training')
+
+        if self.__training_task is None or self.__training_task.done():
+            self.__training_task = asyncio.create_task(self.__do_start_training(parameters))
+        else:
+            raise RuntimeError("Training already running")
+
+    def get_task_status(self) -> TrainingStatus:
+        return self.__trainer.get_status()
+
+    def stop_task(self):
+        if self.__training_task is not None and not self.__training_task.done():
+            # Cooperative stop: the trainer checks is_training_aborted() between steps.
+            self.__trainer.get_status().abort_training("Stopping training")
+            #self.__training_task.cancel()
+        else:
+            raise RuntimeError("No task running")
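TrainingManager only depends on the AbstractTrainer interface, so the same start/stop/status plumbing could drive other task types. A minimal sketch, assuming it is run from the repository root; DummyTrainer and the main() driver below are hypothetical and not part of this commit.

import asyncio
import time

from src.abstract_trainer import AbstractTrainer
from src.training_manager import TrainingManager


class DummyTrainer(AbstractTrainer):
    """Hypothetical trainer used only to illustrate the TrainingManager contract."""

    def start_training(self, parameters):
        # Runs synchronously inside TrainingManager's thread pool executor.
        for step in range(1, 11):
            if self.get_status().is_training_aborted():
                return
            time.sleep(0.05)  # stand-in for real work
            self.get_status().update_status(step * 10, f"dummy step {step}/10")


async def main():
    manager = TrainingManager(DummyTrainer())
    await manager.start_training(parameters=None)  # schedules the background task
    await asyncio.sleep(1.0)                       # give the dummy run time to finish
    status = manager.get_task_status()
    print(status.get_progress(), status.get_task(), status.get_status().value)


if __name__ == "__main__":
    asyncio.run(main())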
src/training_status.py
ADDED
@@ -0,0 +1,58 @@
+import logging
+from enum import Enum
+
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
+
+class Status(Enum):
+    NOT_STARTED = "NOT_STARTED"
+    IN_PROGRESS = "IN_PROGRESS"
+    CANCELLING = "CANCELLING"
+    CANCELLED = "CANCELLED"
+    COMPLETED = "COMPLETED"
+
+
+class TrainingStatus:
+
+    __status: Status = Status.NOT_STARTED
+    __task: str = None
+    __progress: int = 0
+
+    def update_status(self, progress: int, task: str):
+        if progress < 0 or progress > 100:
+            raise ValueError("Progress must be between 0 and 100")
+
+        if progress > 0:
+            self.__status = Status.IN_PROGRESS
+
+        if progress == 100:
+            self.__status = Status.COMPLETED
+
+        self.__progress = progress
+
+        if task is not None:
+            self.__task = task
+
+    def abort_training(self, task: str):
+        self.__task = task
+        self.__status = Status.CANCELLING
+
+    def finalize_abort_training(self, task: str):
+        self.__status = Status.CANCELLED
+        self.__progress = 0
+        self.__task = task
+
+    def is_training_aborted(self) -> bool:
+        return self.__status == Status.CANCELLING
+
+    def get_status(self) -> Status:
+        return self.__status
+
+    def get_progress(self) -> int:
+        return self.__progress
+
+    def get_task(self) -> str:
+        return self.__task
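Taken together with the trainer code above, the intended lifecycle is NOT_STARTED, then IN_PROGRESS while update_status reports progress, then COMPLETED at 100. A stop request moves the status to CANCELLING via abort_training; the training loop notices is_training_aborted() and returns, and the cleanup path in ImageClassificationTrainer.start_training calls finalize_abort_training to end in CANCELLED.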