xcurvnubaim
fix: fix img reshape
fba92a9
raw
history blame
899 Bytes
from io import BytesIO

import numpy as np
import tensorflow as tf
from fastapi import FastAPI, File, HTTPException, UploadFile
from PIL import Image
app = FastAPI()
model = tf.keras.models.load_model('./models.h5')
# NOTE(review): both paths are relative to the process working directory,
# not to this file — confirm the service is always started from the repo root.
# One class name per line; the label on line i is paired with model output
# index i when building the confidence dict.
# Explicit UTF-8 so label decoding does not depend on the platform locale.
with open("labels.txt", encoding="utf-8") as f:
    labels = [line.rstrip("\n") for line in f]
def classify_image(img):
    """Run the loaded Keras model on a PIL image and return per-label scores.

    Args:
        img: A PIL ``Image``. It is resized to 224x224 and converted to
            3-channel RGB before inference, so grayscale, palette, RGBA and
            CMYK inputs are all accepted.

    Returns:
        dict mapping each label in ``labels`` to its predicted score as a
        Python ``float``. Labels and model outputs are paired by position;
        if their counts differ, only the overlapping prefix is returned.
    """
    # Resize first (matching the original pipeline), then normalise the mode.
    # convert("RGB") replaces the old ``[..., :3]`` slice, which broke on
    # 2-D grayscale/palette arrays and picked C/M/Y channels for CMYK.
    rgb = img.resize((224, 224)).convert("RGB")
    img_array = np.asarray(rgb)
    # Add the batch axis: (224, 224, 3) -> (1, 224, 224, 3).
    img_array = np.expand_dims(img_array, axis=0)
    img_array = tf.keras.applications.efficientnet.preprocess_input(img_array)
    prediction = model.predict(img_array).flatten()
    # Pair labels with scores positionally instead of assuming exactly 90
    # classes; zip stops at the shorter of the two sequences.
    confidences = {
        label: float(score) for label, score in zip(labels, prediction)
    }
    return confidences
@app.post("/predict")
async def predict(file: bytes = File(...)):
    """Classify an uploaded image and return label -> confidence scores.

    Args:
        file: Raw bytes of the uploaded image from the multipart form field
            named ``file``.

    Returns:
        The dict produced by ``classify_image``.

    Raises:
        HTTPException: 400 if the bytes cannot be decoded as an image.
    """
    try:
        img = Image.open(BytesIO(file))
        # Image.open is lazy; force the decode here so truncated or corrupt
        # data is reported as a client error instead of an HTTP 500 raised
        # later inside the inference code.
        img.load()
    except OSError as exc:  # PIL's UnidentifiedImageError subclasses OSError
        raise HTTPException(status_code=400, detail="Invalid image file") from exc
    # NOTE(review): classify_image (and model.predict) is synchronous, so it
    # blocks the event loop for the duration of inference in this async
    # endpoint — consider offloading if concurrent requests matter.
    confidences = classify_image(img)
    return confidences