|
from typing import Dict, Tuple

import numpy as np
import tensorflow as tf
from PIL import Image
|
|
|
|
|
# Loaded once at module import so every predict() call reuses the same model
# object rather than re-reading the HDF5 file per request. The path is relative,
# so it is resolved against the process's working directory.
_MODEL_PATH = "resnetv2_model.h5"
model = tf.keras.models.load_model(_MODEL_PATH)
|
|
|
|
|
def predict(inputs: Dict) -> Dict:
    """
    Handle a single inference request.

    Args:
        inputs (Dict): A dictionary with a key 'image' containing the
            base64-encoded image.

    Returns:
        Dict: ``{"label": int}`` holding the index of the highest-scoring model
        output for the image, or ``{"error": str}`` when the 'image' key is
        missing or its value cannot be decoded into an image.
    """
    if "image" not in inputs:
        return {"error": "No image found in inputs"}

    # Report bad client payloads (malformed base64, non-image bytes, wrong
    # value type) the same way as a missing image instead of letting a raw
    # exception escape. binascii.Error is a ValueError subclass and PIL's
    # UnidentifiedImageError is an OSError subclass. The try is kept to the
    # decode step so model failures still propagate.
    try:
        image = preprocess_image(inputs["image"])
    except (TypeError, ValueError, OSError) as err:
        return {"error": f"Could not decode image: {err}"}

    prediction = model.predict(image)
    # preprocess_image returns a batch of one, so row 0 holds this image's scores.
    predicted_class = np.argmax(prediction, axis=1)[0]

    return {"label": int(predicted_class)}
|
|
|
def preprocess_image(
    image_base64: str,
    target_size: Tuple[int, int] = (224, 224),
) -> np.ndarray:
    """
    Preprocess the input image for ResNetV2.

    Decodes the base64 payload, converts it to RGB, resizes it, scales pixel
    values to [0, 1] and adds a leading batch axis of size 1.

    Args:
        image_base64 (str): Base64-encoded image. A data-URI prefix such as
            ``"data:image/png;base64,"`` is stripped before decoding.
        target_size (Tuple[int, int]): ``(width, height)`` passed to
            ``PIL.Image.resize``. Defaults to ``(224, 224)``.

    Returns:
        np.ndarray: Float array of shape ``(1, height, width, 3)`` ready for
        inference.

    Raises:
        binascii.Error: If the payload is not valid base64 (e.g. bad padding).
        TypeError: If ``image_base64`` is not a str or bytes-like object.
        PIL.UnidentifiedImageError: If the decoded bytes are not a readable image.
    """
    from io import BytesIO
    import base64

    # Clients often send "data:<mime>;base64,<payload>". Only the part after
    # the first comma is base64 (RFC 2397). A bare base64 payload never
    # contains ':' or ',', so this branch cannot fire on a plain payload.
    if isinstance(image_base64, str) and image_base64.startswith("data:"):
        image_base64 = image_base64.partition(",")[2]

    image_data = base64.b64decode(image_base64)

    # convert() returns an independent, fully loaded copy, so the source
    # image can be closed as soon as the conversion is done.
    with Image.open(BytesIO(image_data)) as source:
        image = source.convert("RGB")

    image = image.resize(tuple(target_size))

    # NOTE(review): Keras' ResNetV2 applications expect inputs scaled to
    # [-1, 1] (resnet_v2.preprocess_input). This scales to [0, 1] instead.
    # Confirm this matches how resnetv2_model.h5 was trained before changing it.
    image_array = np.array(image) / 255.0

    # Add the batch axis: (H, W, 3) -> (1, H, W, 3).
    return np.expand_dims(image_array, axis=0)