fashxp commited on
Commit
264e02e
·
1 Parent(s): 7c4332a

cleanup and text classification

Browse files
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: Fine Tuning Service
3
  emoji: 🦀
4
  colorFrom: green
5
  colorTo: yellow
@@ -8,5 +8,37 @@ pinned: false
8
  license: other
9
  ---
10
 
11
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
 
1
  ---
2
+ title: Fine-Tuning Service
3
  emoji: 🦀
4
  colorFrom: green
5
  colorTo: yellow
 
8
  license: other
9
  ---
10
 
11
+ # Pimcore Fine-Tuning Service
12
+
13
+ This app provides endpoints to showcase fine-tuning of models for image and text classification tasks.
14
+ It is possible to execute one training at a time and to get status information via the `/get_training_status` endpoint.
15
+ Via the `/stop_training` endpoint stopping the currently running training is possible. After the training, the fine-tuned model is uploaded to huggingface hub.
16
+
17
+ ## Neccesary Environment Variables
18
+ - `AUTHENTICATION_TOKEN`: Secret that is necessary to authorize calling the apps endpoints.
19
+ - `HUGGINGFACE_TOKEN`: Huggingface token to be used for accessing the huggingface hub and uploading the models.
20
+ - `HUGGINGFACE_ORGANIZATION`: Organization to be used for uploading the fine-tuned models.
21
+
22
+ Further details for parameters of the endpoints see: https://your-domain/docs
23
+
24
+
25
+ ## Image Classification
26
+
27
+ Use `/training/image_classification` for fine tuning a model for image classification tasks.
28
+
29
+ #### Important Parameters
30
+ - `training_data_zip`: The ZIP file containing the training data, with a folder per class which contains images belonging to that class.
31
+ - `project_name`: The name of the project. Will also be used as name of resulting model that will be created after fine tuning and as the name of the repository at huggingface.
32
+ - `source_model_name`: The source model to be used as basis for fine tuning.
33
+
34
+
35
+ ## Text Classification
36
+ Use `/training/text_classification` for fine tuning a model for text classification tasks.
37
+
38
+ #### Important Parameters
39
+ - `training_data_csv`: The CSV file containing the training data, necessary columns `value` (text data) and `target` (classification).
40
+ - `project_name`: The name of the project. Will also be used as name of resulting model that will be created after fine tuning and as the name of the repository at huggingface.
41
+ - `source_model_name`: The source model to be used as basis for fine tuning.
42
+
43
+
44
 
src/image_classification/image_classification_parameters.py CHANGED
@@ -4,14 +4,16 @@ from fastapi import Form
4
 
5
 
6
  class ImageClassificationTrainingParameters(BaseModel):
 
7
  epochs: int
8
  learning_rate: float
9
 
10
 
11
  def map_image_classification_training_parameters(
12
- epocs: Annotated[int, Form(...)] = 3,
13
- learning_rate: Annotated[float, Form(...)] = 5e-5
14
  ) -> ImageClassificationTrainingParameters:
 
15
  return ImageClassificationTrainingParameters(
16
  epochs=epocs,
17
  learning_rate=learning_rate
@@ -19,23 +21,24 @@ def map_image_classification_training_parameters(
19
 
20
 
21
  class ImageClassificationParameters:
 
22
 
23
  __training_files_path: str
24
  __training_zip_file_path: str
25
- __result_model_name: str
26
  __source_model_name: str
27
  __training_parameters: ImageClassificationTrainingParameters
28
 
29
  def __init__(self,
30
  training_files_path: str,
31
  training_zip_file_path: str,
32
- result_model_name: str,
33
  source_model_name: str,
34
  training_parameters: ImageClassificationTrainingParameters
35
  ):
36
  self.__training_files_path = training_files_path
37
  self.__training_zip_file_path = training_zip_file_path
38
- self.__result_model_name = result_model_name
39
  self.__source_model_name = source_model_name
40
  self.__training_parameters = training_parameters
41
 
@@ -46,7 +49,10 @@ class ImageClassificationParameters:
46
  return self.__training_zip_file_path
47
 
48
  def get_result_model_name(self) -> str:
49
- return self.__result_model_name
 
 
 
50
 
51
  def get_source_model_name(self) -> str:
52
  return self.__source_model_name
 
4
 
5
 
6
  class ImageClassificationTrainingParameters(BaseModel):
7
+ """ Provides specific training parameters for the image classification fine tuning."""
8
  epochs: int
9
  learning_rate: float
10
 
11
 
12
  def map_image_classification_training_parameters(
13
+ epocs: Annotated[int, Form(description="Epochs executed during training.")] = 3,
14
+ learning_rate: Annotated[float, Form(description="Learning rate for training.")] = 5e-5
15
  ) -> ImageClassificationTrainingParameters:
16
+ """ Maps the parameters to the ImageClassificationTrainingParameters class. """
17
  return ImageClassificationTrainingParameters(
18
  epochs=epocs,
19
  learning_rate=learning_rate
 
21
 
22
 
23
  class ImageClassificationParameters:
24
+ """ Provides all parameters for the image classification fine tuning. """
25
 
26
  __training_files_path: str
27
  __training_zip_file_path: str
28
+ __project_name: str
29
  __source_model_name: str
30
  __training_parameters: ImageClassificationTrainingParameters
31
 
32
  def __init__(self,
33
  training_files_path: str,
34
  training_zip_file_path: str,
35
+ project_name: str,
36
  source_model_name: str,
37
  training_parameters: ImageClassificationTrainingParameters
38
  ):
39
  self.__training_files_path = training_files_path
40
  self.__training_zip_file_path = training_zip_file_path
41
+ self.__project_name = project_name
42
  self.__source_model_name = source_model_name
43
  self.__training_parameters = training_parameters
44
 
 
49
  return self.__training_zip_file_path
50
 
51
  def get_result_model_name(self) -> str:
52
+ return self.__project_name
53
+
54
+ def get_project_name(self) -> str:
55
+ return self.__project_name
56
 
57
  def get_source_model_name(self) -> str:
58
  return self.__source_model_name
src/image_classification/image_classification_trainer.py CHANGED
@@ -29,7 +29,7 @@ class ImageClassificationTrainer(AbstractTrainer):
29
 
30
  try:
31
  task = 'Extract training data'
32
- self.get_status().update_status(0, task)
33
  logger.info(task)
34
 
35
  self.__extract_training_data(parameters)
@@ -63,7 +63,7 @@ class ImageClassificationTrainer(AbstractTrainer):
63
  finally:
64
  # Cleanup after processing
65
  logger.info('Cleaning up training files after training')
66
- shutil.rmtree(parameters.get_training_files_path())
67
 
68
  if(self.get_status().is_training_aborted()):
69
  self.get_status().finalize_abort_training("Training aborted")
@@ -94,7 +94,7 @@ class ImageClassificationTrainer(AbstractTrainer):
94
  images = images.train_test_split(test_size=0.2)
95
 
96
  logger.info(images)
97
- logger.info(images["train"][100])
98
 
99
 
100
  # Preprocess the images
@@ -126,7 +126,7 @@ class ImageClassificationTrainer(AbstractTrainer):
126
 
127
  image_processor = AutoImageProcessor.from_pretrained(parameters.get_source_model_name())
128
  data_collator = DefaultDataCollator()
129
- progressCallback = ProgressCallback(self.get_status())
130
 
131
  # Evaluate and metrics
132
  accuracy = evaluate.load("accuracy")
 
29
 
30
  try:
31
  task = 'Extract training data'
32
+ self.get_status().update_status(0, task, parameters.get_project_name())
33
  logger.info(task)
34
 
35
  self.__extract_training_data(parameters)
 
63
  finally:
64
  # Cleanup after processing
65
  logger.info('Cleaning up training files after training')
66
+ shutil.rmtree(parameters.get_training_files_path())
67
 
68
  if(self.get_status().is_training_aborted()):
69
  self.get_status().finalize_abort_training("Training aborted")
 
94
  images = images.train_test_split(test_size=0.2)
95
 
96
  logger.info(images)
97
+ logger.info(images["train"][10])
98
 
99
 
100
  # Preprocess the images
 
126
 
127
  image_processor = AutoImageProcessor.from_pretrained(parameters.get_source_model_name())
128
  data_collator = DefaultDataCollator()
129
+ progressCallback = ProgressCallback(self.get_status(), 21, 89)
130
 
131
  # Evaluate and metrics
132
  accuracy = evaluate.load("accuracy")
src/main.py CHANGED
@@ -1,30 +1,26 @@
1
  import os
2
- import requests
3
  import torch
4
 
5
  from .training_status import Status
6
  from .environment_variable_checker import EnvironmentVariableChecker
7
 
8
- from .task_manager import TaskManager
9
  from .training_manager import TrainingManager
10
  from .image_classification.image_classification_trainer import ImageClassificationTrainer
11
  from .image_classification.image_classification_parameters import ImageClassificationParameters, map_image_classification_training_parameters, ImageClassificationTrainingParameters
 
 
12
 
13
- from fastapi import FastAPI, Header, Depends, HTTPException, BackgroundTasks, UploadFile, Form, File, status
14
- from fastapi.responses import FileResponse
15
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
16
  from pydantic import BaseModel
17
- from typing import Optional, Annotated
18
 
19
 
20
  import logging
21
- import sys
22
-
23
- import zipfile
24
  import os
25
  from pathlib import Path
26
  import tempfile
27
- import shutil
28
 
29
 
30
  app = FastAPI()
@@ -36,10 +32,22 @@ logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s')
36
  logger = logging.getLogger(__name__)
37
  logger.setLevel(logging.DEBUG)
38
 
39
- classification_trainer: TrainingManager = TrainingManager(ImageClassificationTrainer())
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  security = HTTPBearer()
42
  def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
 
43
 
44
  token = environmentVariableChecker.get_authentication_token()
45
 
@@ -52,32 +60,73 @@ def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
52
  return {"token": credentials.credentials}
53
 
54
 
55
- class ResponseModel(BaseModel):
56
- message: str
57
- success: bool = True
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
  @app.post(
61
- "/upload",
62
- summary="Upload a zip file containing training data",
63
  response_model=ResponseModel
64
  )
65
- async def upload_file(
66
  training_params: Annotated[ImageClassificationTrainingParameters, Depends(map_image_classification_training_parameters)],
67
- data_files_training: Annotated[UploadFile, File(...)],
68
  token_data: dict = Depends(verify_token),
69
- result_model_name: str = Form(...),
70
- source_model_name: str = Form('google/vit-base-patch16-224-in21k'),
71
  ):
 
 
 
72
 
73
  # check if training is running, if so then exit
74
  status = classification_trainer.get_task_status()
75
  if status.get_status() == Status.IN_PROGRESS or status.get_status() == Status.CANCELLING:
76
- raise HTTPException(status_code=405, detail="Training is already in progress")
77
 
78
  # Ensure the uploaded file is a ZIP file
79
- if not data_files_training.filename.endswith(".zip"):
80
- raise HTTPException(status_code=422, detail="Uploaded file is not a zip file")
81
 
82
  try:
83
  # Create a temporary directory to extract the contents
@@ -85,7 +134,7 @@ async def upload_file(
85
  path = Path(tmp_path)
86
  path.mkdir(parents=True, exist_ok=True)
87
 
88
- contents = await data_files_training.read()
89
  zip_path = os.path.join(tmp_path, 'image_classification_data.zip')
90
  with open(zip_path, 'wb') as temp_file:
91
  temp_file.write(contents)
@@ -94,52 +143,74 @@ async def upload_file(
94
  parameters = ImageClassificationParameters(
95
  training_files_path=tmp_path,
96
  training_zip_file_path=zip_path,
97
- result_model_name=result_model_name,
98
  source_model_name=source_model_name,
99
  training_parameters=training_params
100
  )
101
 
102
  # start training
103
- await classification_trainer.start_training(parameters)
104
 
105
- # TODO add more return parameters and information
106
- return ResponseModel(message="training started")
107
 
108
  except Exception as e:
109
  raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
110
 
111
 
112
- @app.get("/get_task_status")
113
- async def get_task_status(token_data: dict = Depends(verify_token)):
114
- status = classification_trainer.get_task_status()
115
- return {
116
- "progress": status.get_progress(),
117
- "task": status.get_task(),
118
- "status": status.get_status().value
119
- }
120
 
121
 
122
- @app.get("/stop_task")
123
- async def stop_task(token_data: dict = Depends(verify_token)):
124
- try:
125
- classification_trainer.stop_task()
126
- return {
127
- "success": True
128
- }
129
- except Exception as e:
130
- raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
131
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
 
 
 
 
133
 
134
- @app.get("/gpu_check")
135
- async def gpu_check():
 
136
 
137
- gpu = 'GPU not available'
138
- if torch.cuda.is_available():
139
- gpu = 'GPU is available'
140
- print("GPU is available")
141
- else:
142
- print("GPU is not available")
 
 
 
 
143
 
144
- return {'success': True, 'response': 'hello world 3', 'gpu': gpu}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
 
 
1
  import os
 
2
  import torch
3
 
4
  from .training_status import Status
5
  from .environment_variable_checker import EnvironmentVariableChecker
6
 
 
7
  from .training_manager import TrainingManager
8
  from .image_classification.image_classification_trainer import ImageClassificationTrainer
9
  from .image_classification.image_classification_parameters import ImageClassificationParameters, map_image_classification_training_parameters, ImageClassificationTrainingParameters
10
+ from .text_classification.text_classification_trainer import TextClassificationTrainer
11
+ from .text_classification.text_classification_parameters import TextClassificationParameters, map_text_classification_training_parameters, TextClassificationTrainingParameters
12
 
13
+
14
+ from fastapi import FastAPI, Depends, HTTPException, UploadFile, Form, File, status
15
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
16
  from pydantic import BaseModel
17
+ from typing import Annotated
18
 
19
 
20
  import logging
 
 
 
21
  import os
22
  from pathlib import Path
23
  import tempfile
 
24
 
25
 
26
  app = FastAPI()
 
32
  logger = logging.getLogger(__name__)
33
  logger.setLevel(logging.DEBUG)
34
 
35
+ classification_trainer: TrainingManager = TrainingManager()
36
+
37
+
38
+ class ResponseModel(BaseModel):
39
+ """ Default pesponse model for endpoints. """
40
+ message: str
41
+ success: bool = True
42
+
43
+
44
+ # ===========================================
45
+ # Security Check
46
+ # ===========================================
47
 
48
  security = HTTPBearer()
49
  def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
50
+ """Verify the token provided by the user."""
51
 
52
  token = environmentVariableChecker.get_authentication_token()
53
 
 
60
  return {"token": credentials.credentials}
61
 
62
 
63
+ # ===========================================
64
+ # Training Status Endpoints
65
+ # ===========================================
66
 
67
+ @app.get("/get_training_status")
68
+ async def get_task_status(token_data: dict = Depends(verify_token)):
69
+ """ Get the status of the currently running training (if any). """
70
+ status = classification_trainer.get_task_status()
71
+ return {
72
+ "project": status.get_project_name(),
73
+ "progress": status.get_progress(),
74
+ "task": status.get_task(),
75
+ "status": status.get_status().value
76
+ }
77
+
78
+ @app.get("/stop_training")
79
+ async def stop_task(token_data: dict = Depends(verify_token)):
80
+ """ Stop the currently running training (if any). """
81
+ try:
82
+ status = classification_trainer.get_task_status()
83
+ classification_trainer.stop_task()
84
+ return ResponseModel(message=f"Training stopped for `{ status.get_project_name() }`")
85
+ except Exception as e:
86
+ raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
87
+
88
+
89
+ @app.get("/gpu_check")
90
+ async def gpu_check():
91
+ """ Check if a GPU is available """
92
+
93
+ gpu = 'GPU not available'
94
+ if torch.cuda.is_available():
95
+ gpu = 'GPU is available'
96
+ print("GPU is available")
97
+ else:
98
+ print("GPU is not available")
99
+
100
+ return {'success': True, 'gpu': gpu}
101
+
102
+
103
+ # ===========================================
104
+ # Fine-Tuning Image Classification
105
+ # ===========================================
106
 
107
  @app.post(
108
+ "/training/image_classification",
 
109
  response_model=ResponseModel
110
  )
111
+ async def image_classification(
112
  training_params: Annotated[ImageClassificationTrainingParameters, Depends(map_image_classification_training_parameters)],
113
+ training_data_zip: Annotated[UploadFile, File(description="The ZIP file containing the training data, with a folder per class which contains images belonging to that class.")],
114
  token_data: dict = Depends(verify_token),
115
+ project_name: str = Form(description="The name of the project. Will also be used as name of resulting model that will be created after fine tuning and as the name of the repository at huggingface."),
116
+ source_model_name: str = Form('google/vit-base-patch16-224-in21k', description="The source model to be used as basis for fine tuning."),
117
  ):
118
+ """
119
+ Start fine tuning an image classification model with the provided data.
120
+ """
121
 
122
  # check if training is running, if so then exit
123
  status = classification_trainer.get_task_status()
124
  if status.get_status() == Status.IN_PROGRESS or status.get_status() == Status.CANCELLING:
125
+ raise HTTPException(status_code=405, detail="Training is already in progress.")
126
 
127
  # Ensure the uploaded file is a ZIP file
128
+ if not training_data_zip.filename.endswith(".zip"):
129
+ raise HTTPException(status_code=422, detail="Uploaded file is not a zip file.")
130
 
131
  try:
132
  # Create a temporary directory to extract the contents
 
134
  path = Path(tmp_path)
135
  path.mkdir(parents=True, exist_ok=True)
136
 
137
+ contents = await training_data_zip.read()
138
  zip_path = os.path.join(tmp_path, 'image_classification_data.zip')
139
  with open(zip_path, 'wb') as temp_file:
140
  temp_file.write(contents)
 
143
  parameters = ImageClassificationParameters(
144
  training_files_path=tmp_path,
145
  training_zip_file_path=zip_path,
146
+ project_name=project_name,
147
  source_model_name=source_model_name,
148
  training_parameters=training_params
149
  )
150
 
151
  # start training
152
+ await classification_trainer.start_training(ImageClassificationTrainer(), parameters)
153
 
154
+ return ResponseModel(message="Training started.")
 
155
 
156
  except Exception as e:
157
  raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
158
 
159
 
 
 
 
 
 
 
 
 
160
 
161
 
162
+ # ===========================================
163
+ # Fine-Tuning Text Classification
164
+ # ===========================================
 
 
 
 
 
 
165
 
166
+ @app.post(
167
+ "/training/text_classification",
168
+ response_model=ResponseModel
169
+ )
170
+ async def text_classificaiton(
171
+ training_params: Annotated[TextClassificationTrainingParameters, Depends(map_text_classification_training_parameters)],
172
+ training_data_csv: Annotated[UploadFile, File(description="The CSV file containing the training data, necessary columns `value` (text data) and `target` (classification).")],
173
+ token_data: dict = Depends(verify_token),
174
+ project_name: str = Form(description="The name of the project. Will also be used as name of resulting model that will be created after fine tuning and as the name of the repository at huggingface."),
175
+ training_csv_limiter: str = Form(';', description="The delimiter used in the CSV file."),
176
+ source_model_name: str = Form('distilbert/distilbert-base-uncased'),
177
+ ):
178
+ """Start fine tuning an text classification model with the provided data."""
179
 
180
+ # check if training is running, if so then exit
181
+ status = classification_trainer.get_task_status()
182
+ if status.get_status() == Status.IN_PROGRESS or status.get_status() == Status.CANCELLING:
183
+ raise HTTPException(status_code=405, detail="Training is already in progress")
184
 
185
+ # Ensure the uploaded file is a CSV file
186
+ if not training_data_csv.filename.endswith(".csv"):
187
+ raise HTTPException(status_code=422, detail="Uploaded file is not a csv file.")
188
 
189
+ try:
190
+ # Create a temporary directory to extract the contents
191
+ tmp_path = os.path.join(tempfile.gettempdir(), 'training_data')
192
+ path = Path(tmp_path)
193
+ path.mkdir(parents=True, exist_ok=True)
194
+
195
+ contents = await training_data_csv.read()
196
+ csv_path = os.path.join(tmp_path, 'data.csv')
197
+ with open(csv_path, 'wb') as temp_file:
198
+ temp_file.write(contents)
199
 
200
+ # prepare parameters
201
+ parameters = TextClassificationParameters(
202
+ training_csv_file_path=csv_path,
203
+ training_csv_limiter=training_csv_limiter,
204
+ project_name=project_name,
205
+ source_model_name=source_model_name,
206
+ training_parameters=training_params
207
+ )
208
+
209
+ # start training
210
+ await classification_trainer.start_training(TextClassificationTrainer(), parameters)
211
+
212
+ return ResponseModel(message="Training started.")
213
+
214
+ except Exception as e:
215
+ raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
216
 
src/progress_callback.py CHANGED
@@ -10,9 +10,13 @@ logger.setLevel(logging.DEBUG)
10
  class ProgressCallback(TrainerCallback):
11
 
12
  __trainingStatus: TrainingStatus = None
 
 
13
 
14
- def __init__(self, trainingStatus: TrainingStatus):
15
  self.__trainingStatus = trainingStatus
 
 
16
 
17
  def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
18
  logger.info(f"Completed step {state.global_step} of {state.max_steps}")
@@ -22,10 +26,8 @@ class ProgressCallback(TrainerCallback):
22
  logger.info("Training aborted")
23
  return
24
 
25
- startPercentage = 21
26
- endPercentage = 89
27
- scope = endPercentage - startPercentage
28
- progress = startPercentage + (state.global_step / state.max_steps) * scope
29
 
30
  self.__trainingStatus.update_status(progress, f"Training model, completed step {state.global_step} of {state.max_steps}")
31
 
 
10
  class ProgressCallback(TrainerCallback):
11
 
12
  __trainingStatus: TrainingStatus = None
13
+ __startPercentage: int = None
14
+ __endPercentage: int = None
15
 
16
+ def __init__(self, trainingStatus: TrainingStatus, startPercentage: int, endPercentage: int):
17
  self.__trainingStatus = trainingStatus
18
+ self.__startPercentage = startPercentage
19
+ self.__endPercentage = endPercentage
20
 
21
  def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
22
  logger.info(f"Completed step {state.global_step} of {state.max_steps}")
 
26
  logger.info("Training aborted")
27
  return
28
 
29
+ scope = self.__endPercentage - self.__startPercentage
30
+ progress = round(self.__startPercentage + (state.global_step / state.max_steps) * scope, 2)
 
 
31
 
32
  self.__trainingStatus.update_status(progress, f"Training model, completed step {state.global_step} of {state.max_steps}")
33
 
src/task_manager.py DELETED
@@ -1,72 +0,0 @@
1
- import asyncio
2
- import logging
3
- from fastapi import BackgroundTasks, HTTPException
4
-
5
- from concurrent.futures import ThreadPoolExecutor
6
-
7
- logger = logging.getLogger(__name__)
8
- logger.setLevel(logging.DEBUG)
9
-
10
-
11
- class Worker:
12
- def doing_work(self, task_manager):
13
- task_manager.task_status["status"] = "Running"
14
- for i in range(1, 101):
15
- if task_manager.task_status["status"] == "Stopped":
16
- break
17
- asyncio.sleep(1) # Simulate a time-consuming task
18
- task_manager.task_status["progress"] = i
19
- logger.info('process ' + str(i) + '%' + ' done')
20
-
21
- if task_manager.task_status["status"] != "Stopped":
22
- task_manager.task_status["status"] = "Completed"
23
-
24
-
25
-
26
- class TaskManager:
27
-
28
- task_status = {"progress": 0, "status": "Not started"}
29
- task = None
30
-
31
- #def __init__(self):
32
-
33
- worker = Worker()
34
-
35
- async def doing_work(self):
36
- loop = asyncio.get_running_loop()
37
- with ThreadPoolExecutor() as pool:
38
- await loop.run_in_executor(pool, self.worker.doing_work, self)
39
- #self.worker.doing_work(self)
40
-
41
- # self.task_status["status"] = "Running"
42
- # for i in range(1, 101):
43
- # if self.task_status["status"] == "Stopped":
44
- # break
45
- # await asyncio.sleep(1) # Simulate a time-consuming task
46
- # self.task_status["progress"] = i
47
- # logger.info('process ' + str(i) + '%' + ' done')
48
-
49
- # if self.task_status["status"] != "Stopped":
50
- # self.task_status["status"] = "Completed"
51
-
52
-
53
- async def start_task(self):
54
- if self.task is None or self.task.done():
55
- self.task_status["progress"] = 0
56
- self.task_status["status"] = "Not started"
57
- self.task = asyncio.create_task(self.doing_work())
58
- return {"message": "Task started"}
59
- else:
60
- raise HTTPException(status_code=409, detail="Task already running")
61
-
62
- async def get_task_status(self):
63
- return self.task_status
64
-
65
- async def stop_task(self):
66
- if self.task is not None and not self.task.done():
67
- self.task_status["status"] = "Stopped"
68
- self.task.cancel()
69
- return {"message": "Task stopped"}
70
- else:
71
- raise HTTPException(status_code=409, detail="No task running")
72
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/text_classification/text_classification_parameters.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel
2
+ from typing import Annotated
3
+ from fastapi import Form
4
+
5
+
6
+ class TextClassificationTrainingParameters(BaseModel):
7
+ """ Provides specific training parameters for the text classification fine tuning."""
8
+ epochs: int
9
+ learning_rate: float
10
+
11
+
12
+ def map_text_classification_training_parameters(
13
+ epocs: Annotated[int, Form(description="Epochs executed during training.")] = 3,
14
+ learning_rate: Annotated[float, Form(description="Learning rate for training.")] = 5e-5
15
+ ) -> TextClassificationTrainingParameters:
16
+ """ Maps the parameters to the TextClassificationTrainingParameters class. """
17
+ return TextClassificationTrainingParameters(
18
+ epochs=epocs,
19
+ learning_rate=learning_rate
20
+ )
21
+
22
+
23
+ class TextClassificationParameters:
24
+ """ Provides all parameters for the text classification fine tuning. """
25
+
26
+ __training_csv_file_path: str
27
+ __training_csv_limiter: str
28
+ __project_name: str
29
+ __source_model_name: str
30
+ __training_parameters: TextClassificationTrainingParameters
31
+
32
+ def __init__(self,
33
+ training_csv_file_path: str,
34
+ project_name: str,
35
+ source_model_name: str,
36
+ training_parameters: TextClassificationTrainingParameters,
37
+ training_csv_limiter: str = ';'
38
+ ):
39
+ self.__training_csv_file_path = training_csv_file_path
40
+ self.__project_name = project_name
41
+ self.__source_model_name = source_model_name
42
+ self.__training_parameters = training_parameters
43
+ self.__training_csv_limiter = training_csv_limiter
44
+
45
+ def get_training_csv_file_path(self) -> str:
46
+ return self.__training_csv_file_path
47
+
48
+ def get_training_csv_limiter(self) -> str:
49
+ return self.__training_csv_limiter
50
+
51
+ def get_project_name(self) -> str:
52
+ return self.__project_name
53
+
54
+ def get_result_model_name(self) -> str:
55
+ return self.__project_name
56
+
57
+ def get_source_model_name(self) -> str:
58
+ return self.__source_model_name
59
+
60
+ def get_training_parameters(self) -> TextClassificationTrainingParameters:
61
+ return self.__training_parameters
62
+
63
+
src/text_classification/text_classification_trainer.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+ from ..progress_callback import ProgressCallback
4
+ from ..abstract_trainer import AbstractTrainer
5
+ from ..environment_variable_checker import EnvironmentVariableChecker
6
+ from .text_classification_parameters import TextClassificationParameters
7
+
8
+ import shutil
9
+ import os
10
+
11
+ from datasets import load_dataset
12
+ from transformers import DataCollatorWithPadding, AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
13
+ from huggingface_hub import HfFolder
14
+
15
+ import evaluate
16
+ import numpy as np
17
+
18
+ from typing import Tuple
19
+
20
+
21
+ logger = logging.getLogger(__name__)
22
+ logger.setLevel(logging.DEBUG)
23
+
24
+
25
+ class TextClassificationTrainer(AbstractTrainer):
26
+
27
+ def start_training(self, parameters: TextClassificationParameters):
28
+
29
+ logger.info('Start Training...')
30
+
31
+ try:
32
+ task = 'Load and prepare training data'
33
+ self.get_status().update_status(0, task, parameters.get_project_name())
34
+ logger.info(task)
35
+
36
+ tokenized_dataset, labels, label2id, id2label = self.__prepare_training_data(parameters)
37
+
38
+ if(self.get_status().is_training_aborted()):
39
+ return
40
+
41
+ task = 'Start training model'
42
+ self.get_status().update_status(10, task)
43
+ logger.info(task)
44
+
45
+ self.__train_model(tokenized_dataset, labels, label2id, id2label, parameters)
46
+
47
+ self.get_status().update_status(100, "Training completed")
48
+
49
+ except Exception as e:
50
+ logger.error(e)
51
+ self.get_status().finalize_abort_training(str(e))
52
+
53
+ raise RuntimeError(f"An error occurred: {str(e)}")
54
+
55
+ finally:
56
+ # Cleanup after processing
57
+ logger.info('Cleaning up training files after training')
58
+ shutil.rmtree(os.path.dirname(parameters.get_training_csv_file_path()))
59
+
60
+ if(self.get_status().is_training_aborted()):
61
+ self.get_status().finalize_abort_training("Training aborted")
62
+
63
+
64
+ def __prepare_training_data(self, parameters: TextClassificationParameters) -> Tuple[dict, dict, dict, dict]:
65
+
66
+ dataset = load_dataset('csv', data_files=parameters.get_training_csv_file_path(), delimiter=parameters.get_training_csv_limiter())
67
+
68
+ dataset = dataset["train"]
69
+ dataset = dataset.train_test_split(test_size=0.2)
70
+
71
+ logger.info(dataset)
72
+ logger.info(dataset["train"][10])
73
+
74
+ # Tokenize the value column
75
+ tokenizer = AutoTokenizer.from_pretrained(parameters.get_source_model_name())
76
+
77
+ def preprocess_function(examples):
78
+ return tokenizer(examples["value"], truncation=True, padding='max_length')
79
+
80
+ tokenized_dataset = dataset.map(preprocess_function, batched=True)
81
+
82
+ # Extract the labels
83
+ labels = tokenized_dataset['train'].unique('target')
84
+ label2id, id2label = dict(), dict()
85
+ for i, label in enumerate(labels):
86
+ label2id[label] = i
87
+ id2label[i] = label
88
+
89
+ logger.info(id2label)
90
+
91
+ # Rename the Target column to labels and remove unnecessary columns
92
+ tokenized_dataset = tokenized_dataset.rename_column('target', 'labels')
93
+
94
+ # Columns to keep
95
+ columns_to_keep = ['input_ids', 'labels', 'attention_mask']
96
+ all_columns = tokenized_dataset["train"].column_names
97
+ columns_to_remove = [col for col in all_columns if col not in columns_to_keep]
98
+ tokenized_dataset = tokenized_dataset.remove_columns(columns_to_remove)
99
+
100
+ # Map labels to numeric ids
101
+ def map_labels(example):
102
+ example['labels'] = label2id[example['labels']]
103
+ return example
104
+ tokenized_dataset = tokenized_dataset.map(map_labels)
105
+
106
+ logger.info(tokenized_dataset)
107
+ logger.info(tokenized_dataset["train"][10])
108
+
109
+ return tokenized_dataset, labels, label2id, id2label
110
+
111
+
112
+ def __train_model(self, tokenized_dataset: dict, labels: dict, label2id: dict, id2label: dict, parameters: TextClassificationParameters):
113
+ tokenizer = AutoTokenizer.from_pretrained(parameters.get_source_model_name())
114
+ data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
115
+
116
+ environment_variable_checker = EnvironmentVariableChecker()
117
+ HfFolder.save_token(environment_variable_checker.get_huggingface_token())
118
+
119
+ progressCallback = ProgressCallback(self.get_status(), 11, 89)
120
+
121
+ # Evaluate and metrics
122
+ accuracy = evaluate.load("accuracy")
123
+ def compute_metrics(eval_pred):
124
+ predictions, labels = eval_pred
125
+ predictions = np.argmax(predictions, axis=1)
126
+ return accuracy.compute(predictions=predictions, references=labels)
127
+
128
+
129
+ # train the model
130
+ model = AutoModelForSequenceClassification.from_pretrained(
131
+ parameters.get_source_model_name(),
132
+ num_labels=len(labels),
133
+ id2label=id2label,
134
+ label2id=label2id
135
+ )
136
+
137
+ target_model_id = environment_variable_checker.get_huggingface_organization() + '/' + parameters.get_result_model_name()
138
+ training_args = TrainingArguments(
139
+ output_dir=parameters.get_result_model_name(),
140
+ hub_model_id=target_model_id,
141
+ learning_rate=parameters.get_training_parameters().learning_rate,
142
+ per_device_train_batch_size=16,
143
+ per_device_eval_batch_size=16,
144
+ num_train_epochs=parameters.get_training_parameters().epochs,
145
+ weight_decay=0.01,
146
+ eval_strategy="epoch",
147
+ save_strategy="epoch",
148
+ load_best_model_at_end=True,
149
+ metric_for_best_model="accuracy",
150
+ push_to_hub=False,
151
+ remove_unused_columns=False,
152
+ hub_private_repo=True,
153
+ )
154
+
155
+ trainer = Trainer(
156
+ model=model,
157
+ args=training_args,
158
+ train_dataset=tokenized_dataset["train"],
159
+ eval_dataset=tokenized_dataset["test"],
160
+ tokenizer=tokenizer,
161
+ data_collator=data_collator,
162
+ compute_metrics=compute_metrics,
163
+ callbacks=[progressCallback]
164
+ )
165
+
166
+
167
+ if(self.get_status().is_training_aborted()):
168
+ return
169
+
170
+ trainer.train()
171
+
172
+ if(self.get_status().is_training_aborted()):
173
+ return
174
+
175
+ logger.info(f"Model trained, start uploading")
176
+ self.get_status().update_status(90, f"Uploading model to Hugging Face")
177
+ trainer.push_to_hub()
src/training_manager.py CHANGED
@@ -15,11 +15,6 @@ class TrainingManager:
15
  __training_task = None
16
  __trainer: AbstractTrainer = None
17
 
18
- task_status = {"progress": 0, "status": "Not started"}
19
-
20
- def __init__(self, trainer: AbstractTrainer):
21
- self.__trainer = trainer
22
-
23
  async def __do_start_training(self, parameters):
24
  logger.info('do start training')
25
 
@@ -29,22 +24,28 @@ class TrainingManager:
29
 
30
  logger.info('done')
31
 
32
- async def start_training(self, parameters):
33
  logger.info('start training')
34
-
35
  if self.__training_task is None or self.__training_task.done():
 
36
  self.__training_task = asyncio.create_task(self.__do_start_training(parameters))
37
  else:
38
- raise RuntimeError("Training already running")
39
 
40
  def get_task_status(self) -> TrainingStatus:
 
 
 
 
41
  return self.__trainer.get_status()
42
 
43
  def stop_task(self):
44
- if self.__training_task is not None and not self.__training_task.done():
45
  self.__trainer.get_status().abort_training("Stopping training")
46
  #self.__training_task.cancel()
47
 
48
  else:
49
- raise RuntimeError("No task running")
50
-
 
 
15
  __training_task = None
16
  __trainer: AbstractTrainer = None
17
 
 
 
 
 
 
18
  async def __do_start_training(self, parameters):
19
  logger.info('do start training')
20
 
 
24
 
25
  logger.info('done')
26
 
27
+ async def start_training(self, trainer: AbstractTrainer, parameters):
28
  logger.info('start training')
29
+
30
  if self.__training_task is None or self.__training_task.done():
31
+ self.__trainer = trainer
32
  self.__training_task = asyncio.create_task(self.__do_start_training(parameters))
33
  else:
34
+ raise RuntimeError("Training already running.")
35
 
36
  def get_task_status(self) -> TrainingStatus:
37
+
38
+ if self.__trainer is None:
39
+ return TrainingStatus()
40
+
41
  return self.__trainer.get_status()
42
 
43
  def stop_task(self):
44
+ if self.__training_task is not None and not self.__training_task.done() and self.__trainer is not None:
45
  self.__trainer.get_status().abort_training("Stopping training")
46
  #self.__training_task.cancel()
47
 
48
  else:
49
+ raise RuntimeError("No task running.")
50
+
51
+
src/training_status.py CHANGED
@@ -15,10 +15,12 @@ class Status(Enum):
15
  class TrainingStatus:
16
 
17
  __status: Status = Status.NOT_STARTED
 
18
  __task: str = None
19
  __progress: int = 0
20
 
21
- def update_status(self, progress: int, task: str):
 
22
  if progress < 0 or progress > 100:
23
  raise ValueError("Progress must be between 0 and 100")
24
 
@@ -33,6 +35,9 @@ class TrainingStatus:
33
  if task is not None:
34
  self.__task = task
35
 
 
 
 
36
 
37
  def abort_training(self, task: str):
38
  self.__task = task
@@ -55,4 +60,6 @@ class TrainingStatus:
55
  def get_task(self) -> str:
56
  return self.__task
57
 
 
 
58
 
 
15
  class TrainingStatus:
16
 
17
  __status: Status = Status.NOT_STARTED
18
+ __project_name: str = None
19
  __task: str = None
20
  __progress: int = 0
21
 
22
+
23
+ def update_status(self, progress: int, task: str, project_name: str = None):
24
  if progress < 0 or progress > 100:
25
  raise ValueError("Progress must be between 0 and 100")
26
 
 
35
  if task is not None:
36
  self.__task = task
37
 
38
+ if project_name is not None:
39
+ self.__project_name = project_name
40
+
41
 
42
  def abort_training(self, task: str):
43
  self.__task = task
 
60
  def get_task(self) -> str:
61
  return self.__task
62
 
63
+ def get_project_name(self) -> str:
64
+ return self.__project_name
65