File size: 4,756 Bytes
7c4332a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import logging
import os
import secrets
import shutil
import sys
import tempfile
import zipfile
from pathlib import Path
from typing import Optional, Annotated

import requests
import torch
from fastapi import FastAPI, Header, Depends, HTTPException, BackgroundTasks, UploadFile, Form, File, status
from fastapi.responses import FileResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel

from .training_status import Status
from .environment_variable_checker import EnvironmentVariableChecker
from .task_manager import TaskManager
from .training_manager import TrainingManager
from .image_classification.image_classification_trainer import ImageClassificationTrainer
from .image_classification.image_classification_parameters import ImageClassificationParameters, map_image_classification_training_parameters, ImageClassificationTrainingParameters


# FastAPI application that serves the training endpoints defined in this module.
app = FastAPI()

# Validate the required environment variables once, at import time, before any
# endpoint can be served. NOTE(review): what happens on a missing/invalid
# variable is decided inside EnvironmentVariableChecker (presumably it raises).
environmentVariableChecker = EnvironmentVariableChecker()
environmentVariableChecker.validate_environment_variables()

# Timestamp + padded level name on every record; this module's logger is set to
# emit from DEBUG upwards.
logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s')
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# One training manager instance shared by every request; the endpoints use it
# to start training, query the task status, and stop the task.
classification_trainer: TrainingManager = TrainingManager(ImageClassificationTrainer())

# Bearer-token scheme: extracts the Authorization header for verify_token.
security = HTTPBearer()
def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Check the request's bearer token against the configured token.

    Used as a FastAPI dependency on the protected endpoints.

    Args:
        credentials: Scheme and credentials parsed from the ``Authorization``
            header by the module-level ``HTTPBearer`` scheme.

    Returns:
        ``{"token": <presented token>}`` when the token matches.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when no
            usable (non-empty string) token is configured, or when the
            presented token does not match it.
    """
    expected = environmentVariableChecker.get_authentication_token()
    presented = credentials.credentials

    # compare_digest runs in time independent of where the inputs differ, so
    # response timing cannot be used to guess the token one character at a
    # time. Both sides are compared as UTF-8 bytes because compare_digest
    # rejects str arguments that contain non-ASCII characters.
    token_ok = (
        isinstance(expected, str)
        and bool(expected)
        and secrets.compare_digest(presented.encode("utf-8"), expected.encode("utf-8"))
    )
    if not token_ok:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return {"token": presented}


class ResponseModel(BaseModel):
    """JSON response body declared as the ``response_model`` of ``/upload``."""

    # Human-readable outcome, e.g. "training started".
    message: str
    # Outcome flag; defaults to True so success responses only set a message.
    success: bool = True


@app.post(
    "/upload",
    summary="Upload a zip file containing training data",
    response_model=ResponseModel
)
async def upload_file(
    training_params: Annotated[ImageClassificationTrainingParameters, Depends(map_image_classification_training_parameters)],
    data_files_training: Annotated[UploadFile, File(...)],
    token_data: dict = Depends(verify_token),
    result_model_name: str = Form(...),
    source_model_name: str = Form('google/vit-base-patch16-224-in21k'),
):
    """Store an uploaded training-data zip and start a training run.

    Args:
        training_params: Training parameters built from the request by
            ``map_image_classification_training_parameters``.
        data_files_training: The uploaded zip archive with the training data.
        token_data: Result of ``verify_token``; present only to enforce auth.
        result_model_name: Name for the model produced by the run.
        source_model_name: Model the run starts from.

    Returns:
        ``ResponseModel`` with message ``"training started"``.

    Raises:
        HTTPException: 405 if a task is in progress or being cancelled;
            422 if the upload is not a zip (by name or by content);
            500 for any other error while storing the file or starting
            training (the cause is logged and chained).
    """
    # Refuse to start a second run while one is still active.
    # NOTE(review): 409 Conflict would be the conventional code; 405 is kept
    # so existing clients keep working.
    task_state = classification_trainer.get_task_status().get_status()
    if task_state in (Status.IN_PROGRESS, Status.CANCELLING):
        raise HTTPException(status_code=405, detail="Training is already in progress")

    # UploadFile.filename can be None; match the extension case-insensitively
    # so e.g. "DATA.ZIP" is accepted as well.
    filename = data_files_training.filename or ""
    if not filename.lower().endswith(".zip"):
        raise HTTPException(status_code=422, detail="Uploaded file is not a zip file")

    try:
        # Fixed working directory under the system temp dir.
        # NOTE(review): this path is reused across uploads, so files from a
        # previous run remain there — confirm the trainer tolerates that.
        tmp_path = os.path.join(tempfile.gettempdir(), 'training_data')
        Path(tmp_path).mkdir(parents=True, exist_ok=True)

        # Stream the upload to disk in fixed-size chunks instead of holding the
        # whole archive in memory at once.
        chunk_size = 1024 * 1024
        zip_path = os.path.join(tmp_path, 'image_classification_data.zip')
        with open(zip_path, 'wb') as zip_file:
            while chunk := await data_files_training.read(chunk_size):
                zip_file.write(chunk)

        # The extension check alone does not prove the content is a zip archive.
        if not zipfile.is_zipfile(zip_path):
            raise HTTPException(status_code=422, detail="Uploaded file is not a zip file")

        # prepare parameters
        parameters = ImageClassificationParameters(
            training_files_path=tmp_path,
            training_zip_file_path=zip_path,
            result_model_name=result_model_name,
            source_model_name=source_model_name,
            training_parameters=training_params
        )

        # start training
        await classification_trainer.start_training(parameters)
    except HTTPException:
        # Deliberate HTTP errors (e.g. the zip-content check above) must reach
        # the client unchanged instead of being rewrapped as a 500.
        raise
    except Exception as e:
        logger.exception("Failed to store upload or start training")
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") from e

    # TODO add more return parameters and information
    return ResponseModel(message="training started")


@app.get("/get_task_status")
async def get_task_status(token_data: dict = Depends(verify_token)):
    """Return the training manager's current progress, task and status value.

    Requires a valid bearer token (enforced by the ``verify_token`` dependency).
    """
    current = classification_trainer.get_task_status()
    report = dict(
        progress=current.get_progress(),
        task=current.get_task(),
        status=current.get_status().value,
    )
    return report


@app.get("/stop_task")
async def stop_task(token_data: dict = Depends(verify_token)):
    """Ask the training manager to stop the current task.

    Requires a valid bearer token (enforced by the ``verify_token`` dependency).

    Returns:
        ``{"success": True}`` once ``stop_task`` returns without raising.

    Raises:
        HTTPException: 500 if the training manager raises; the original
            exception is logged with its traceback and chained as the cause.
    """
    # NOTE(review): a state-changing action exposed via GET; POST would be more
    # appropriate, but the route is kept as-is for existing clients.
    try:
        classification_trainer.stop_task()
    except Exception as e:
        logger.exception("Failed to stop training task")
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") from e
    return {
        "success": True
    }



@app.get("/gpu_check")
async def gpu_check():
    """Report whether PyTorch can see a CUDA device.

    This endpoint has no ``verify_token`` dependency, so it is unauthenticated.

    Returns:
        ``{'success': True, 'response': 'hello world 3', 'gpu': <str>}`` where
        ``gpu`` is ``'GPU is available'`` or ``'GPU not available'``.
    """
    available = torch.cuda.is_available()

    # Use the module logger rather than print() so the message carries the
    # configured timestamp/level format and can be filtered like other logs.
    logger.info("GPU is available" if available else "GPU is not available")

    gpu = 'GPU is available' if available else 'GPU not available'
    return {'success': True, 'response': 'hello world 3', 'gpu': gpu}