Ulrixon committed
Commit 8751ffd · verified · 1 Parent(s): 453f7da

Create inference.py

Files changed (1)
  1. inference.py +52 -0
inference.py ADDED
@@ -0,0 +1,52 @@
+ import base64
+ from io import BytesIO
+ from typing import Dict
+ import numpy as np
+ import tensorflow as tf
+ from PIL import Image
+
+ # Load the ResNetV2 model once at startup
+ model = tf.keras.models.load_model("resnetv2_model.h5")
+
+ # Define the handler for the Inference API
+ def predict(inputs: Dict) -> Dict:
+     """
+     Handle inference requests.
+     Args:
+         inputs (Dict): A dictionary with a key 'image' containing the base64-encoded image.
+     Returns:
+         Dict: A dictionary containing the predicted class label.
+     """
+     # Validate the input
+     if "image" not in inputs:
+         return {"error": "No image found in inputs"}
+
+     # Preprocess the input image
+     image = preprocess_image(inputs["image"])
+
+     # Perform inference
+     prediction = model.predict(image)
+     predicted_class = np.argmax(prediction, axis=1)[0]  # Index of the highest-scoring class
+
+     # Return the predicted class
+     return {"label": int(predicted_class)}
+
+ def preprocess_image(image_base64: str) -> np.ndarray:
+     """
+     Preprocess the input image for ResNetV2.
+     Args:
+         image_base64 (str): Base64-encoded image.
+     Returns:
+         np.ndarray: Preprocessed image ready for inference.
+     """
+     # Decode the base64 image
+     image_data = base64.b64decode(image_base64)
+     image = Image.open(BytesIO(image_data)).convert("RGB")
+
+     # Resize and normalize the image; this must match the preprocessing used when the model
+     # was trained (tf.keras.applications.resnet_v2.preprocess_input scales to [-1, 1] instead)
+     image = image.resize((224, 224))
+     image_array = np.array(image, dtype=np.float32) / 255.0  # Normalize pixel values to [0, 1]
+     image_array = np.expand_dims(image_array, axis=0)  # Add batch dimension
+
+     return image_array
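
For reference, a minimal sketch of how this handler could be exercised locally. The test file sample.jpg is a hypothetical example, not part of the repository:

import base64

from inference import predict

# Encode a local image as base64 (sample.jpg is a hypothetical test file)
with open("sample.jpg", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

# Call the handler with the input format its docstring describes
result = predict({"image": image_b64})
print(result)  # e.g. {"label": 3} on success, or {"error": ...} if 'image' is missing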