semantic-segmentation / pipeline.py
merve's picture
merve HF staff
Update pipeline.py
6cb57f7
raw
history blame
2.48 kB
import base64
import io
import json
from typing import Any, Dict, List

import numpy as np
import tensorflow as tf
from huggingface_hub import from_pretrained_keras, hf_hub_download
from PIL import Image
from tensorflow import keras

from app.pipelines import Pipeline
# File name of a serialized model graph. Not referenced by any code visible
# in this file. NOTE(review): presumably kept from the pipeline template.
MODEL_FILENAME = "saved_model.pb"
# File name of the JSON config that may carry an "id2label" mapping.
# NOTE(review): presumably fetched from the model repo; confirm before relying on it.
CONFIG_FILENAME = "config.json"
class PreTrainedPipeline(Pipeline):
    """Semantic-segmentation pipeline backed by a Keras model.

    The pipeline loads a Keras model once at construction time and, when
    called with a PIL image, returns one binary mask per class found in the
    prediction.
    """

    # Used when the class count cannot be read from the model's output shape.
    _DEFAULT_NUM_LABELS = 3
    # PIL mode matching the channel count expected by the model input.
    _CHANNELS_TO_MODE = {1: "L", 3: "RGB", 4: "RGBA"}

    def __init__(self, model_id: str):
        """Load the model and build the class-index -> label mapping.

        Args:
            model_id: Identifier of the model repository.
                NOTE(review): currently unused; the model is always read
                from the local file ``./model.h5``.
        """
        # Reload the Keras model from the local HDF5 file.
        self.model = keras.models.load_model('./model.h5')

        # For a channels-last segmentation model the last output axis holds
        # one unit per class. Multi-output models expose a list of shapes,
        # in which case the last entry is not an int and we fall back to the
        # default.
        last_dim = self.model.output_shape[-1]
        output_units = last_dim if isinstance(last_dim, int) else None

        # A single output unit means a binary (two-class) segmentation.
        self.single_output_unit = output_units == 1
        if output_units is None:
            self.num_labels = self._DEFAULT_NUM_LABELS
        elif self.single_output_unit:
            self.num_labels = 2
        else:
            self.num_labels = output_units

        # TODO: restore the optional "id2label" lookup from CONFIG_FILENAME
        # (via hf_hub_download(model_id, filename=CONFIG_FILENAME)) once the
        # model repo ships a config. Until then generic names are used.
        config: Dict[str, Any] = {}
        self.id2label = config.get(
            "id2label", {str(i): f"LABEL_{i}" for i in range(self.num_labels)}
        )

    def __call__(self, inputs: "Image.Image") -> List[Dict[str, Any]]:
        """
        Args:
            inputs (:obj:`PIL.Image`):
                The raw image representation as PIL.
                No transformation made whatsoever from the input. Make all necessary transformations here.
        Return:
            A :obj:`list`:. The list contains items that are dicts and should look like {"label": "XXX" (str), "mask": "base64 encoding of the mask" (str), "score": float}
            It is preferred if the returned list is in decreasing `score` order
        """
        expected_input_size = self.model.input_shape

        # Convert the colour mode while the input is still a PIL image, so
        # the channel count matches what the model expects.
        mode = self._CHANNELS_TO_MODE.get(expected_input_size[-1])
        image = inputs.convert(mode) if mode is not None else inputs
        original_size = image.size  # (width, height)

        img_array = tf.keras.preprocessing.image.img_to_array(image)
        height, width = expected_input_size[1], expected_input_size[2]
        # Models with a dynamic spatial size report None here; only resize
        # when the model demands a fixed size.
        if height is not None and width is not None:
            img_array = tf.image.resize(img_array, (height, width))
        img_array = img_array[tf.newaxis, ...]  # add the batch axis

        # Drop the batch axis: (height, width, units).
        predictions = np.asarray(self.model.predict(img_array))[0]

        if predictions.shape[-1] == 1:
            # NOTE(review): assumes a sigmoid output in [0, 1] - confirm.
            class_map = (predictions[..., 0] > 0.5).astype(np.uint8)
        else:
            class_map = np.argmax(predictions, axis=-1)

        labels = []
        for class_id in np.unique(class_map):
            class_id = int(class_id)
            # White (255) where the pixel belongs to the class, black elsewhere.
            mask = Image.fromarray((class_map == class_id).astype(np.uint8) * 255)
            if mask.size != original_size:
                # Nearest-neighbour keeps the mask strictly binary.
                mask = mask.resize(original_size, resample=Image.NEAREST)
            labels.append({
                "label": self.id2label.get(str(class_id), f"LABEL_{class_id}"),
                "mask": self._encode_mask(mask),
                "score": 1.0,
            })
        return sorted(labels, key=lambda item: item["score"], reverse=True)

    @staticmethod
    def _encode_mask(mask: "Image.Image") -> str:
        """Return *mask* serialized as PNG and base64-encoded into a str."""
        buffer = io.BytesIO()
        mask.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")