fashxp commited on
Commit
fef773e
·
1 Parent(s): 5c263d5

additional tasks

Browse files
Files changed (2) hide show
  1. requirements.txt +1 -0
  2. src/main.py +223 -192
requirements.txt CHANGED
@@ -5,4 +5,5 @@ transformers
5
  sentencepiece
6
  sacremoses
7
  torch
 
8
  # Optional dependencies for specific features
 
5
  sentencepiece
6
  sacremoses
7
  torch
8
+ pillow
9
  # Optional dependencies for specific features
src/main.py CHANGED
@@ -11,26 +11,15 @@
11
  import os
12
  import torch
13
 
14
- #from .training_status import Status
15
- #from .environment_variable_checker import EnvironmentVariableChecker
16
-
17
- #from .training_manager import TrainingManager
18
- #from .image_classification.image_classification_trainer import ImageClassificationTrainer
19
- #from .image_classification.image_classification_parameters import ImageClassificationParameters, map_image_classification_training_parameters, ImageClassificationTrainingParameters
20
- #from .text_classification.text_classification_trainer import TextClassificationTrainer
21
- #from .text_classification.text_classification_parameters import TextClassificationParameters, map_text_classification_training_parameters, TextClassificationTrainingParameters
22
-
23
-
24
- from fastapi import FastAPI, Depends, HTTPException, UploadFile, Form, File, status
25
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
26
  from pydantic import BaseModel
27
  from typing import Annotated
28
-
29
 
30
  import logging
31
- from pathlib import Path
32
- import tempfile
33
  import sys
 
34
 
35
 
36
  from transformers import pipeline
@@ -41,9 +30,6 @@ app = FastAPI(
41
  version="1.0.0"
42
  )
43
 
44
- #environmentVariableChecker = EnvironmentVariableChecker()
45
- #environmentVariableChecker.validate_environment_variables()
46
-
47
  logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s')
48
  logger = logging.getLogger(__name__)
49
  logger.setLevel(logging.DEBUG)
@@ -65,7 +51,6 @@ class StreamToLogger(object):
65
  sys.stdout = StreamToLogger(logger, logging.INFO)
66
  sys.stderr = StreamToLogger(logger, logging.ERROR)
67
 
68
- #classification_trainer: TrainingManager = TrainingManager()
69
 
70
 
71
  class ResponseModel(BaseModel):
@@ -74,51 +59,6 @@ class ResponseModel(BaseModel):
74
  success: bool = True
75
 
76
 
77
- # ===========================================
78
- # Security Check
79
- # ===========================================
80
-
81
- # security = HTTPBearer()
82
- # def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
83
- # """Verify the token provided by the user."""
84
-
85
- # token = environmentVariableChecker.get_authentication_token()
86
-
87
- # if credentials.credentials != token:
88
- # raise HTTPException(
89
- # status_code=status.HTTP_401_UNAUTHORIZED,
90
- # detail="Invalid token",
91
- # headers={"WWW-Authenticate": "Bearer"},
92
- # )
93
- # return {"token": credentials.credentials}
94
-
95
-
96
- # ===========================================
97
- # Training Status Endpoints
98
- # ===========================================
99
-
100
- # @app.get("/get_training_status")
101
- # async def get_task_status(token_data: dict = Depends(verify_token)):
102
- # """ Get the status of the currently running training (if any). """
103
- # status = classification_trainer.get_task_status()
104
- # return {
105
- # "project": status.get_project_name(),
106
- # "progress": status.get_progress(),
107
- # "task": status.get_task(),
108
- # "status": status.get_status().value
109
- # }
110
-
111
- # @app.put("/stop_training")
112
- # async def stop_task(token_data: dict = Depends(verify_token)):
113
- # """ Stop the currently running training (if any). """
114
- # try:
115
- # status = classification_trainer.get_task_status()
116
- # classification_trainer.stop_task()
117
- # return ResponseModel(message=f"Training stopped for `{ status.get_project_name() }`")
118
- # except Exception as e:
119
- # raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
120
-
121
-
122
  @app.get("/gpu_check")
123
  async def gpu_check():
124
  """ Check if a GPU is available """
@@ -133,39 +73,73 @@ async def gpu_check():
133
  return {'success': True, 'gpu': gpu}
134
 
135
 
136
- from fastapi import Body
137
  from typing import Optional
138
 
 
 
 
 
 
 
139
  class TranslationRequest(BaseModel):
140
  inputs: str
141
  parameters: Optional[dict] = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
  @app.post(
144
  "/translation/{model_name:path}/",
145
- )
146
- async def translation(
147
- model_name: str,
148
- body: TranslationRequest = Body(
149
- ...,
150
- example={
151
- "inputs": "I am a car",
152
- "parameters": {
153
- "repetition_penalty": 1.6,
154
  }
155
  }
 
 
 
 
 
 
 
 
156
  )
157
- ):
158
  """
159
  Execute translation tasks.
160
 
161
- Args:
162
- model_name (str): The HuggingFace model name to use for translation.
163
- body (TranslationRequest): The request payload containing translation parameters.
164
-
165
  Returns:
166
  list: The translation result(s) as returned by the pipeline.
167
  """
168
 
 
 
169
  try:
170
  pipe = pipeline("translation", model=model_name)
171
  except Exception as e:
@@ -176,7 +150,96 @@ async def translation(
176
  )
177
 
178
  try:
179
- result = pipe(body.inputs, **(body.parameters or {}))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
  except Exception as e:
181
  logger.error(f"Inference failed for model '{model_name}': {str(e)}")
182
  raise HTTPException(
@@ -187,117 +250,85 @@ async def translation(
187
  return result
188
 
189
 
190
- # ===========================================
191
- # Fine-Tuning Image Classification
192
- # ===========================================
193
-
194
- # @app.post(
195
- # "/training/image_classification",
196
- # response_model=ResponseModel
197
- # )
198
- # async def image_classification(
199
- # training_params: Annotated[ImageClassificationTrainingParameters, Depends(map_image_classification_training_parameters)],
200
- # training_data_zip: Annotated[UploadFile, File(description="The ZIP file containing the training data, with a folder per class which contains images belonging to that class.")],
201
- # token_data: dict = Depends(verify_token),
202
- # project_name: str = Form(description="The name of the project. Will also be used as name of resulting model that will be created after fine tuning and as the name of the repository at huggingface."),
203
- # source_model_name: str = Form('google/vit-base-patch16-224-in21k', description="The source model to be used as basis for fine tuning."),
204
- # ):
205
- # """
206
- # Start fine tuning an image classification model with the provided data.
207
- # """
208
-
209
- # # check if training is running, if so then exit
210
- # status = classification_trainer.get_task_status()
211
- # if status.get_status() == Status.IN_PROGRESS or status.get_status() == Status.CANCELLING:
212
- # raise HTTPException(status_code=405, detail="Training is already in progress.")
213
-
214
- # # Ensure the uploaded file is a ZIP file
215
- # if not training_data_zip.filename.endswith(".zip"):
216
- # raise HTTPException(status_code=422, detail="Uploaded file is not a zip file.")
217
-
218
- # try:
219
- # # Create a temporary directory to extract the contents
220
- # tmp_path = os.path.join(tempfile.gettempdir(), 'training_data')
221
- # path = Path(tmp_path)
222
- # path.mkdir(parents=True, exist_ok=True)
223
-
224
- # contents = await training_data_zip.read()
225
- # zip_path = os.path.join(tmp_path, 'image_classification_data.zip')
226
- # with open(zip_path, 'wb') as temp_file:
227
- # temp_file.write(contents)
228
-
229
- # # prepare parameters
230
- # parameters = ImageClassificationParameters(
231
- # training_files_path=tmp_path,
232
- # training_zip_file_path=zip_path,
233
- # project_name=project_name,
234
- # source_model_name=source_model_name,
235
- # training_parameters=training_params
236
- # )
237
-
238
- # # start training
239
- # await classification_trainer.start_training(ImageClassificationTrainer(), parameters)
240
-
241
- # return ResponseModel(message="Training started.")
242
-
243
- # except Exception as e:
244
- # raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
245
-
246
-
247
-
248
-
249
- # ===========================================
250
- # Fine-Tuning Text Classification
251
- # ===========================================
252
-
253
- # @app.post(
254
- # "/training/text_classification",
255
- # response_model=ResponseModel
256
- # )
257
- # async def text_classificaiton(
258
- # training_params: Annotated[TextClassificationTrainingParameters, Depends(map_text_classification_training_parameters)],
259
- # training_data_csv: Annotated[UploadFile, File(description="The CSV file containing the training data, necessary columns `value` (text data) and `target` (classification).")],
260
- # token_data: dict = Depends(verify_token),
261
- # project_name: str = Form(description="The name of the project. Will also be used as name of resulting model that will be created after fine tuning and as the name of the repository at huggingface."),
262
- # training_csv_limiter: str = Form(';', description="The delimiter used in the CSV file."),
263
- # source_model_name: str = Form('distilbert/distilbert-base-uncased'),
264
- # ):
265
- # """Start fine tuning an text classification model with the provided data."""
266
-
267
- # # check if training is running, if so then exit
268
- # status = classification_trainer.get_task_status()
269
- # if status.get_status() == Status.IN_PROGRESS or status.get_status() == Status.CANCELLING:
270
- # raise HTTPException(status_code=405, detail="Training is already in progress")
271
-
272
- # # Ensure the uploaded file is a CSV file
273
- # if not training_data_csv.filename.endswith(".csv"):
274
- # raise HTTPException(status_code=422, detail="Uploaded file is not a csv file.")
275
-
276
- # try:
277
- # # Create a temporary directory to extract the contents
278
- # tmp_path = os.path.join(tempfile.gettempdir(), 'training_data')
279
- # path = Path(tmp_path)
280
- # path.mkdir(parents=True, exist_ok=True)
281
-
282
- # contents = await training_data_csv.read()
283
- # csv_path = os.path.join(tmp_path, 'data.csv')
284
- # with open(csv_path, 'wb') as temp_file:
285
- # temp_file.write(contents)
286
-
287
- # # prepare parameters
288
- # parameters = TextClassificationParameters(
289
- # training_csv_file_path=csv_path,
290
- # training_csv_limiter=training_csv_limiter,
291
- # project_name=project_name,
292
- # source_model_name=source_model_name,
293
- # training_parameters=training_params
294
- # )
295
-
296
- # # start training
297
- # await classification_trainer.start_training(TextClassificationTrainer(), parameters)
298
-
299
- # return ResponseModel(message="Training started.")
300
-
301
- # except Exception as e:
302
- # raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
303
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  import os
12
  import torch
13
 
14
+ from fastapi import FastAPI, Path, Depends, HTTPException, UploadFile, Form, File, status, Request
 
 
 
 
 
 
 
 
 
 
15
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
16
  from pydantic import BaseModel
17
  from typing import Annotated
18
+ import json
19
 
20
  import logging
 
 
21
  import sys
22
+ import base64
23
 
24
 
25
  from transformers import pipeline
 
30
  version="1.0.0"
31
  )
32
 
 
 
 
33
  logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s')
34
  logger = logging.getLogger(__name__)
35
  logger.setLevel(logging.DEBUG)
 
51
  sys.stdout = StreamToLogger(logger, logging.INFO)
52
  sys.stderr = StreamToLogger(logger, logging.ERROR)
53
 
 
54
 
55
 
56
  class ResponseModel(BaseModel):
 
59
  success: bool = True
60
 
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  @app.get("/gpu_check")
63
  async def gpu_check():
64
  """ Check if a GPU is available """
 
73
  return {'success': True, 'gpu': gpu}
74
 
75
 
 
76
  from typing import Optional
77
 
78
+
79
+
80
+ # =========================
81
+ # Translation Task
82
+ # =========================
83
+
84
  class TranslationRequest(BaseModel):
85
  inputs: str
86
  parameters: Optional[dict] = None
87
+ options: Optional[dict] = None
88
+
89
+ async def get_translation_request(
90
+ request: Request
91
+ ) -> TranslationRequest:
92
+ content_type = request.headers.get("content-type", "")
93
+ if content_type.startswith("application/json"):
94
+ data = await request.json()
95
+ return TranslationRequest(**data)
96
+ if content_type.startswith("application/x-www-form-urlencoded"):
97
+ raw = await request.body()
98
+ try:
99
+ data = json.loads(raw)
100
+ return TranslationRequest(**data)
101
+ except Exception:
102
+ try:
103
+ data = json.loads(raw.decode("utf-8"))
104
+ return TranslationRequest(**data)
105
+ except Exception:
106
+ raise HTTPException(status_code=400, detail="Invalid request body")
107
+ raise HTTPException(status_code=400, detail="Unsupported content type")
108
+
109
+
110
 
111
  @app.post(
112
  "/translation/{model_name:path}/",
113
+ openapi_extra={
114
+ "requestBody": {
115
+ "content": {
116
+ "application/json": {
117
+ "example": {
118
+ "inputs": "Hello, world! foo bar",
119
+ "parameters": {"repetition_penalty": 1.6}
120
+ }
121
+ }
122
  }
123
  }
124
+ }
125
+ )
126
+ async def translate(
127
+ request: Request,
128
+ model_name: str = Path(
129
+ ...,
130
+ description="The name of the translation model (e.g. Helsinki-NLP/opus-mt-en-de)",
131
+ example="Helsinki-NLP/opus-mt-en-de"
132
  )
133
+ ):
134
  """
135
  Execute translation tasks.
136
 
 
 
 
 
137
  Returns:
138
  list: The translation result(s) as returned by the pipeline.
139
  """
140
 
141
+ translationRequest: TranslationRequest = await get_translation_request(request)
142
+
143
  try:
144
  pipe = pipeline("translation", model=model_name)
145
  except Exception as e:
 
150
  )
151
 
152
  try:
153
+ result = pipe(translationRequest.inputs, **(translationRequest.parameters or {}))
154
+ except Exception as e:
155
+ logger.error(f"Inference failed for model '{model_name}': {str(e)}")
156
+ raise HTTPException(
157
+ status_code=500,
158
+ detail=f"Inference failed: {str(e)}"
159
+ )
160
+
161
+ return result
162
+
163
+
164
+ # =========================
165
+ # Zero-Shot Image Classification Task
166
+ # =========================
167
+
168
+
169
+ class ZeroShotImageClassificationRequest(BaseModel):
170
+ inputs: str
171
+ parameters: Optional[dict] = None
172
+
173
+ async def get_zero_shot_image_classification_request(
174
+ request: Request
175
+ ) -> ZeroShotImageClassificationRequest:
176
+ content_type = request.headers.get("content-type", "")
177
+ if content_type.startswith("application/json"):
178
+ data = await request.json()
179
+ return ZeroShotImageClassificationRequest(**data)
180
+ if content_type.startswith("application/x-www-form-urlencoded"):
181
+ raw = await request.body()
182
+ try:
183
+ data = json.loads(raw)
184
+ return ZeroShotImageClassificationRequest(**data)
185
+ except Exception:
186
+ try:
187
+ data = json.loads(raw.decode("utf-8"))
188
+ return ZeroShotImageClassificationRequest(**data)
189
+ except Exception:
190
+ raise HTTPException(status_code=400, detail="Invalid request body")
191
+ raise HTTPException(status_code=400, detail="Unsupported content type")
192
+
193
+
194
+
195
+ @app.post(
196
+ "/zero-shot-image-classification/{model_name:path}/",
197
+ openapi_extra={
198
+ "requestBody": {
199
+ "content": {
200
+ "application/json": {
201
+ "example": {
202
+ "inputs": "base64_encoded_image_string",
203
+ "parameters": {"candidate_labels": "green, yellow, blue, white, silver"}
204
+ }
205
+ }
206
+ }
207
+ }
208
+ }
209
+ )
210
+ async def zero_shot_image_classification(
211
+ request: Request,
212
+ model_name: str = Path(
213
+ ...,
214
+ description="The name of the zero-shot classification model (e.g., openai/clip-vit-large-patch14-336)",
215
+ example="openai/clip-vit-large-patch14-336"
216
+ )
217
+ ):
218
+ """
219
+ Execute zero-shot image classification tasks.
220
+
221
+ Returns:
222
+ list: The classification result(s) as returned by the pipeline.
223
+ """
224
+
225
+ zeroShotRequest: ZeroShotImageClassificationRequest = await get_zero_shot_image_classification_request(request)
226
+
227
+ try:
228
+ pipe = pipeline("zero-shot-image-classification", model=model_name)
229
+ except Exception as e:
230
+ logger.error(f"Failed to load model '{model_name}': {str(e)}")
231
+ raise HTTPException(
232
+ status_code=404,
233
+ detail=f"Model '{model_name}' could not be loaded: {str(e)}"
234
+ )
235
+
236
+ try:
237
+ candidate_labels = []
238
+ if zeroShotRequest.parameters:
239
+ candidate_labels = zeroShotRequest.parameters.get('candidate_labels', [])
240
+ if isinstance(candidate_labels, str):
241
+ candidate_labels = [label.strip() for label in candidate_labels.split(',')]
242
+ result = pipe(zeroShotRequest.inputs, candidate_labels=candidate_labels)
243
  except Exception as e:
244
  logger.error(f"Inference failed for model '{model_name}': {str(e)}")
245
  raise HTTPException(
 
250
  return result
251
 
252
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
 
254
+ # =========================
255
+ # Image to Text Task
256
+ # =========================
257
+
258
+
259
+ async def get_encoded_image(
260
+ request: Request
261
+ ) -> str:
262
+ content_type = request.headers.get("content-type", "")
263
+ if content_type.startswith("multipart/form-data"):
264
+ form = await request.form()
265
+ image = form.get("image")
266
+ if image:
267
+ image_bytes = await image.read()
268
+ return base64.b64encode(image_bytes).decode("utf-8")
269
+ if content_type.startswith("image/"):
270
+ image_bytes = await request.body()
271
+ return base64.b64encode(image_bytes).decode("utf-8")
272
+
273
+ raise HTTPException(status_code=400, detail="Unsupported content type")
274
+
275
+
276
+
277
+ @app.post(
278
+ "/image-to-text/{model_name:path}/",
279
+ openapi_extra={
280
+ "requestBody": {
281
+ "content": {
282
+ "multipart/form-data": {
283
+ "schema": {
284
+ "type": "object",
285
+ "properties": {
286
+ "image": {
287
+ "type": "string",
288
+ "format": "binary",
289
+ "description": "Image file to upload"
290
+ }
291
+ },
292
+ "required": ["image"]
293
+ }
294
+ }
295
+ }
296
+ }
297
+ }
298
+ )
299
+ async def image_to_text(
300
+ request: Request,
301
+ model_name: str = Path(
302
+ ...,
303
+ description="The name of the image-to-text (e.g., Salesforce/blip-image-captioning-base)",
304
+ example="Salesforce/blip-image-captioning-base"
305
+ )
306
+ ):
307
+ """
308
+ Execute image-to-text tasks.
309
+
310
+ Returns:
311
+ list: The generated text as returned by the pipeline.
312
+ """
313
+
314
+ encoded_image = await get_encoded_image(request)
315
+
316
+ try:
317
+ pipe = pipeline("image-to-text", model=model_name, use_fast=True)
318
+ except Exception as e:
319
+ logger.error(f"Failed to load model '{model_name}': {str(e)}")
320
+ raise HTTPException(
321
+ status_code=404,
322
+ detail=f"Model '{model_name}' could not be loaded: {str(e)}"
323
+ )
324
+
325
+ try:
326
+ result = pipe(encoded_image)
327
+ except Exception as e:
328
+ logger.error(f"Inference failed for model '{model_name}': {str(e)}")
329
+ raise HTTPException(
330
+ status_code=500,
331
+ detail=f"Inference failed: {str(e)}"
332
+ )
333
+
334
+ return result