File size: 834 Bytes
dbfc835
 
 
 
 
c7490aa
dbfc835
 
 
c7490aa
dbfc835
 
 
c7490aa
dbfc835
c7490aa
dbfc835
 
 
c7490aa
 
 
dbfc835
 
 
c7490aa
dbfc835
 
c7490aa
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import PIL
import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from PIL import Image
from pydantic import BaseModel

from utils.model_func import class_id_to_label, load_model, transform_image

# Application instance that the route and event decorators below attach to.
app = FastAPI()
# Placeholder for the classifier; the startup handler assigns the real model.
model = None


# Response schema: carries only the predicted class name
class ImageClass(BaseModel):
    """Pydantic model holding a single classification result."""

    # Label string; classify() fills it with the value returned by
    # class_id_to_label().
    prediction: str

# Load the model once, when the application starts
@app.on_event("startup")
def startup_event():
    """Store the result of ``load_model()`` in the module-level ``model``."""
    global model
    loaded = load_model()
    model = loaded

@app.get('/')
def return_info():
    """Answer GET requests on the root path with a fixed greeting string."""
    greeting = 'Hello FastAPI'
    return greeting


@app.post('/classify', response_model=ImageClass)
def classify(file: UploadFile = File(...)):
    """Classify an uploaded image and return the predicted class label.

    Args:
        file: The image, uploaded as multipart form data.

    Returns:
        An ``ImageClass`` whose ``prediction`` is the label that
        ``class_id_to_label`` returns for the arg-max of the model output.

    Raises:
        HTTPException: With status 400 when the upload cannot be opened
            as an image.
    """
    try:
        # Use the explicitly imported ``Image`` submodule: a bare
        # ``import PIL`` does not guarantee that ``PIL.Image`` is loaded.
        image = Image.open(file.file)
    except OSError as err:
        # Pillow signals unreadable image data with an OSError subclass;
        # report that as a client error rather than an unhandled 500.
        raise HTTPException(
            status_code=400,
            detail='Uploaded file is not a valid image',
        ) from err

    # NOTE(review): presumably transform_image copes with non-RGB modes
    # (e.g. RGBA or grayscale uploads) - confirm against utils.model_func.
    adapted_image = transform_image(image)

    # Inference only, so skip autograd bookkeeping during the forward pass.
    # unsqueeze(0) adds a leading dimension (presumably the batch axis).
    with torch.no_grad():
        scores = model(adapted_image.unsqueeze(0))

    # Index of the largest value in the flattened model output.
    pred_index = scores.detach().cpu().numpy().argmax()
    imagenet_class = class_id_to_label(pred_index)
    return ImageClass(prediction=imagenet_class)