nsfw / app.py
suko's picture
Update app.py
ca8afc0
raw
history blame
4.94 kB
import gradio as gr
import numpy as np
import json
import os
from PIL import Image
import onnxruntime as rt
class ONNXModel:
def __init__(self, dir_path) -> None:
"""Method to get name of model file. Assumes model is in the parent directory for script."""
model_dir = os.path.dirname(dir_path)
with open(os.path.join(model_dir, "signature.json"), "r") as f:
self.signature = json.load(f)
self.model_file = os.path.join(model_dir, self.signature.get("filename"))
if not os.path.isfile(self.model_file):
raise FileNotFoundError(f"Model file does not exist")
# get the signature for model inputs and outputs
self.signature_inputs = self.signature.get("inputs")
self.signature_outputs = self.signature.get("outputs")
self.session = None
if "Image" not in self.signature_inputs:
raise ValueError("ONNX model doesn't have 'Image' input! Check signature.json, and please report issue to Lobe.")
# Look for the version in signature file.
# If it's not found or the doesn't match expected, print a message
version = self.signature.get("export_model_version")
if version is None or version != EXPORT_MODEL_VERSION:
print(
f"There has been a change to the model format. Please use a model with a signature 'export_model_version' that matches {EXPORT_MODEL_VERSION}."
)
def load(self) -> None:
"""Load the model from path to model file"""
# Load ONNX model as session.
self.session = rt.InferenceSession(path_or_bytes=self.model_file)
def predict(self, image: Image.Image) -> dict:
"""
Predict with the ONNX session!
"""
# process image to be compatible with the model
img = self.process_image(image, self.signature_inputs.get("Image").get("shape"))
# run the model!
fetches = [(key, value.get("name")) for key, value in self.signature_outputs.items()]
# make the image a batch of 1
feed = {self.signature_inputs.get("Image").get("name"): [img]}
outputs = self.session.run(output_names=[name for (_, name) in fetches], input_feed=feed)
return self.process_output(fetches, outputs)
def process_image(self, image: Image.Image, input_shape: list) -> np.ndarray:
"""
Given a PIL Image, center square crop and resize to fit the expected model input, and convert from [0,255] to [0,1] values.
"""
width, height = image.size
# ensure image type is compatible with model and convert if not
if image.mode != "RGB":
image = image.convert("RGB")
# center crop image (you can substitute any other method to make a square image, such as just resizing or padding edges with 0)
if width != height:
square_size = min(width, height)
left = (width - square_size) / 2
top = (height - square_size) / 2
right = (width + square_size) / 2
bottom = (height + square_size) / 2
# Crop the center of the image
image = image.crop((left, top, right, bottom))
# now the image is square, resize it to be the right shape for the model input
input_width, input_height = input_shape[1:3]
if image.width != input_width or image.height != input_height:
image = image.resize((input_width, input_height))
# make 0-1 float instead of 0-255 int (that PIL Image loads by default)
image = np.asarray(image) / 255.0
# format input as model expects
return image.astype(np.float32)
def process_output(self, fetches: dict, outputs: dict) -> dict:
# un-batch since we ran an image with batch size of 1,
# convert to normal python types with tolist(), and convert any byte strings to normal strings with .decode()
out_keys = ["label", "confidence"]
results = {}
for i, (key, _) in enumerate(fetches):
val = outputs[i].tolist()[0]
if isinstance(val, bytes):
val = val.decode()
results[key] = val
confs = results["Confidences"]
labels = self.signature.get("classes").get("Label")
output = [dict(zip(out_keys, group)) for group in zip(labels, confs)]
sorted_output = {"predictions": sorted(output, key=lambda k: k["confidence"], reverse=True)}
return sorted_output
EXPORT_MODEL_VERSION=1
model = ONNXModel(dir_path="/model/model.onnx")
model.load()
def predict(image):
image = Image.fromarray(np.uint8(image), 'RGB')
prediction = model.predict(image)
for output in prediction["predictions"]:
output["confidence"] = round(output["confidence"], 2)
return prediction
inputs = gr.inputs.Image(type="pil")
outputs = gr.outputs.JSON()
gr.Interface(fn=predict, inputs=inputs, outputs=outputs).launch()