# NOTE: "Spaces: Sleeping / Sleeping" below was Hugging Face Spaces page
# residue captured by the extraction — it is not part of the source code.
import logging
import os
import secrets
import shutil
import sys
import tempfile
import zipfile
from pathlib import Path
from typing import Annotated, Optional

import requests
import torch
from fastapi import (
    BackgroundTasks,
    Depends,
    FastAPI,
    File,
    Form,
    Header,
    HTTPException,
    UploadFile,
    status,
)
from fastapi.responses import FileResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel

from .environment_variable_checker import EnvironmentVariableChecker
from .image_classification.image_classification_parameters import (
    ImageClassificationParameters,
    ImageClassificationTrainingParameters,
    map_image_classification_training_parameters,
)
from .image_classification.image_classification_trainer import ImageClassificationTrainer
from .task_manager import TaskManager
from .training_manager import TrainingManager
from .training_status import Status
# FastAPI application object; routes are registered against this instance.
app = FastAPI()

# Fail fast at import time if required environment variables are missing.
environmentVariableChecker = EnvironmentVariableChecker()
environmentVariableChecker.validate_environment_variables()

# Module-level logger; DEBUG level so training progress is fully visible.
logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s')
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Single shared manager that owns the (at most one) training task at a time.
classification_trainer: TrainingManager = TrainingManager(ImageClassificationTrainer())

# HTTP Bearer scheme consumed by verify_token via Depends().
security = HTTPBearer()
def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Validate the request's Bearer token against the configured auth token.

    Args:
        credentials: Bearer credentials extracted by the ``security`` scheme.

    Returns:
        dict: ``{"token": <accepted token>}`` on success.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when the
            presented token does not match the configured one.
    """
    expected = environmentVariableChecker.get_authentication_token()
    # Use a constant-time comparison so the check does not leak how many
    # leading characters of the token were correct (timing attack).
    # NOTE(review): assumes both tokens are str — confirm the env checker
    # never returns None/bytes here.
    if not secrets.compare_digest(credentials.credentials, expected):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return {"token": credentials.credentials}
class ResponseModel(BaseModel):
    """Generic API response body: a human-readable message plus a success flag."""
    message: str
    success: bool = True  # defaults to True; endpoints only set message explicitly
# NOTE(review): no @app.post(...) decorator is visible on this handler —
# presumably lost in extraction; confirm how this route is registered.
async def upload_file(
    training_params: Annotated[ImageClassificationTrainingParameters, Depends(map_image_classification_training_parameters)],
    data_files_training: Annotated[UploadFile, File(...)],
    token_data: dict = Depends(verify_token),
    result_model_name: str = Form(...),
    source_model_name: str = Form('google/vit-base-patch16-224-in21k'),
):
    """Accept a ZIP of training data and start an image-classification run.

    Args:
        training_params: Hyperparameters mapped from the form fields.
        data_files_training: Uploaded ZIP archive with the training images.
        token_data: Result of Bearer-token verification (dependency only).
        result_model_name: Name under which the trained model is stored.
        source_model_name: Base model to fine-tune from.

    Returns:
        ResponseModel: confirmation that training was started.

    Raises:
        HTTPException: 405 if a run is already in progress or cancelling,
            422 if the upload is not a ZIP file, 500 on any other failure.
    """
    # Only one training task may run at a time; refuse while busy.
    # (Renamed from `status` — the original shadowed the fastapi.status import.)
    task_status = classification_trainer.get_task_status()
    if task_status.get_status() in (Status.IN_PROGRESS, Status.CANCELLING):
        raise HTTPException(status_code=405, detail="Training is already in progress")
    # Accept the extension case-insensitively (".ZIP" uploads are common).
    if not data_files_training.filename.lower().endswith(".zip"):
        raise HTTPException(status_code=422, detail="Uploaded file is not a zip file")
    try:
        # Persist the upload under a fixed temp directory; the trainer
        # extracts it from there.
        tmp_path = os.path.join(tempfile.gettempdir(), 'training_data')
        Path(tmp_path).mkdir(parents=True, exist_ok=True)
        contents = await data_files_training.read()
        zip_path = os.path.join(tmp_path, 'image_classification_data.zip')
        with open(zip_path, 'wb') as temp_file:
            temp_file.write(contents)
        # Bundle paths, model names and hyperparameters for the trainer.
        parameters = ImageClassificationParameters(
            training_files_path=tmp_path,
            training_zip_file_path=zip_path,
            result_model_name=result_model_name,
            source_model_name=source_model_name,
            training_parameters=training_params
        )
        await classification_trainer.start_training(parameters)
        # TODO add more return parameters and information
        return ResponseModel(message="training started")
    except HTTPException:
        # Don't re-wrap deliberate HTTP errors as opaque 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
async def get_task_status(token_data: dict = Depends(verify_token)):
    """Return progress, task name, and state of the current training task."""
    current = classification_trainer.get_task_status()
    return {
        "progress": current.get_progress(),
        "task": current.get_task(),
        "status": current.get_status().value,
    }
async def stop_task(token_data: dict = Depends(verify_token)):
    """Request cancellation of the running training task."""
    try:
        classification_trainer.stop_task()
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
    return {"success": True}
async def gpu_check():
    """Report whether CUDA is visible to torch, plus a static hello payload."""
    if torch.cuda.is_available():
        print("GPU is available")
        gpu = 'GPU is available'
    else:
        print("GPU is not available")
        gpu = 'GPU not available'
    return {'success': True, 'response': 'hello world 3', 'gpu': gpu}