# SICI-model-ResnetV2 / inference.py
# Author: Ulrixon - commit 8751ffd ("Create inference.py", verified)
import tensorflow as tf
import numpy as np
from PIL import Image
from typing import Dict
# Location of the serialized Keras ResNetV2 model. This is a relative path, so
# it is resolved against the process's current working directory.
_MODEL_PATH = "resnetv2_model.h5"
# The model is loaded once, at import time; the handler below reuses this
# module-level object for every request.
model = tf.keras.models.load_model(_MODEL_PATH)
# Define the handler for the Inference API
def predict(inputs: Dict) -> Dict:
    """
    Handle inference requests.

    Args:
        inputs (Dict): A dictionary with a key 'image' containing the
            base64-encoded image.

    Returns:
        Dict: ``{"label": <int>}`` holding the index of the highest-scoring
        class on success, or ``{"error": <message>}`` when the request carries
        no image or the image payload cannot be decoded.
    """
    from collections.abc import Mapping

    # Missing or non-mapping payloads (e.g. None, or a bare string where
    # ``"image" in inputs`` would be a substring test) get the same error
    # response instead of raising or being misinterpreted.
    if not isinstance(inputs, Mapping) or "image" not in inputs:
        return {"error": "No image found in inputs"}

    # Bad client data should yield an error response, not an unhandled
    # exception. binascii.Error (invalid base64) subclasses ValueError and
    # PIL.UnidentifiedImageError (not an image) subclasses OSError.
    try:
        image = preprocess_image(inputs["image"])
    except (ValueError, TypeError, OSError) as err:
        return {"error": f"Invalid image: {err}"}

    # Perform inference.
    # NOTE(review): assumes the model outputs (batch, num_classes) scores.
    prediction = model.predict(image)
    # Batch of one: argmax over the class axis, then take the only sample.
    predicted_class = np.argmax(prediction, axis=1)[0]
    return {"label": int(predicted_class)}
def preprocess_image(image_base64: str, target_size: tuple = (224, 224)) -> np.ndarray:
    """
    Preprocess the input image for ResNetV2.

    Args:
        image_base64 (str): Base64-encoded image. A browser-style
            ``data:<mime>;base64,`` URL prefix is also accepted and stripped.
        target_size (tuple): ``(width, height)`` to resize to, in PIL's
            argument order. Defaults to ``(224, 224)``.

    Returns:
        np.ndarray: float32 array of shape ``(1, height, width, 3)`` with
        pixel values scaled to [0, 1].

    Raises:
        ValueError: if the input is not valid base64 (``binascii.Error`` is a
            subclass) or is a ``str`` with non-ASCII characters.
        TypeError: if the input is neither ``str`` nor bytes-like.
        PIL.UnidentifiedImageError: if the decoded bytes are not an image
            format Pillow can identify.
    """
    from io import BytesIO
    import base64

    # A data-URL header ("data:image/png;base64,") contains many characters
    # from the base64 alphabet, so b64decode would turn it into garbage bytes
    # in front of the real image. Raw base64 never contains ':' or ','.
    if isinstance(image_base64, str) and image_base64.startswith("data:"):
        image_base64 = image_base64.partition(",")[2]

    image_data = base64.b64decode(image_base64)
    # convert() returns a new, fully loaded image, so the source can be
    # closed immediately.
    with Image.open(BytesIO(image_data)) as source:
        image = source.convert("RGB")

    image = image.resize(tuple(target_size))
    # float32 matches Keras' default compute dtype and halves memory versus
    # float64.
    # NOTE(review): Keras' resnet_v2.preprocess_input scales to [-1, 1];
    # this scales to [0, 1]. Keep whichever the model was trained with.
    # Confirm against the training pipeline.
    image_array = np.asarray(image, dtype=np.float32) / 255.0
    return np.expand_dims(image_array, axis=0)  # Add batch dimension