"""Inference handler for a ResNetV2 image-classification model.

Requests are dictionaries of the form ``{"image": "<base64-encoded image>"}``;
responses are ``{"label": <int class index>}`` on success or
``{"error": "<message>"}`` on bad input.
"""
import base64
from io import BytesIO
from typing import Any, Dict

import numpy as np
import tensorflow as tf
from PIL import Image

# Loaded once at import time so every request reuses the same model instance.
model = tf.keras.models.load_model("resnetv2_model.h5")


def predict(inputs: Dict[str, Any]) -> Dict[str, Any]:
    """Handle an inference request.

    Args:
        inputs: Dictionary with a key ``"image"`` containing the
            base64-encoded image.

    Returns:
        ``{"label": <predicted class index>}`` on success, or
        ``{"error": <message>}`` when the image is missing or unreadable.
    """
    if "image" not in inputs:
        return {"error": "No image found in inputs"}

    try:
        image = preprocess_image(inputs["image"])
    except (ValueError, OSError) as exc:
        # binascii.Error (a ValueError) for malformed base64;
        # UnidentifiedImageError (an OSError) for bytes PIL cannot decode.
        # Return the handler's error shape instead of letting it propagate.
        return {"error": f"Could not decode image: {exc}"}

    prediction = model.predict(image)
    predicted_class = int(np.argmax(prediction, axis=1)[0])
    return {"label": predicted_class}


def preprocess_image(image_base64: str) -> np.ndarray:
    """Decode and preprocess a base64-encoded image for ResNetV2 inference.

    Args:
        image_base64: Base64-encoded image bytes (any PIL-readable format).

    Returns:
        float32 array of shape (1, 224, 224, 3) with values in [0, 1].

    Raises:
        ValueError: If the string is not valid base64 (``binascii.Error``).
        OSError: If the decoded bytes are not a readable image.
    """
    # validate=True makes malformed base64 raise instead of being silently
    # truncated at the first invalid character.
    image_data = base64.b64decode(image_base64, validate=True)
    image = Image.open(BytesIO(image_data)).convert("RGB")

    image = image.resize((224, 224))
    # float32 matches the model's expected input dtype and avoids the
    # implicit float64 that np.array(...) / 255.0 would produce.
    image_array = np.asarray(image, dtype=np.float32) / 255.0
    # NOTE(review): this normalizes to [0, 1]; Keras'
    # resnet_v2.preprocess_input scales to [-1, 1] — confirm which
    # normalization this model was trained with.
    return np.expand_dims(image_array, axis=0)  # add batch dimension