Spaces:
Sleeping
Sleeping
from pathlib import Path

import gradio as gr
import numpy as np
import onnxruntime as ort
from PIL import Image

def resize_and_crop(image, image_size):
    """Resize *image* so its shorter side is *image_size*, then center-crop it.

    The square crop is returned as a batched, channels-first float tensor.

    Args:
        image: A PIL image. Images whose mode is not "RGB" (e.g. "RGBA",
            "L", "P") are converted to RGB first, so the result always has
            exactly three channels.
        image_size: Side length, in pixels, of the square crop.

    Returns:
        A float32 NumPy array of shape (1, 3, image_size, image_size) with
        pixel values scaled from [0, 255] to [0, 1].
    """
    # Force three channels: a 2-D grayscale array would make the (2, 0, 1)
    # transpose below fail, and RGBA/palette images would hand the model the
    # wrong number of channels.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Scale so the shorter side becomes exactly image_size. Pinning the short
    # side (and rounding, not truncating, the long side) avoids float error
    # such as 319.999... being truncated to 319, which would make the crop
    # below reach past the image border, where PIL pads with black.
    width, height = image.size
    ratio = float(image_size) / min(width, height)
    if width <= height:
        new_size = (image_size, max(image_size, round(height * ratio)))
    else:
        new_size = (max(image_size, round(width * ratio)), image_size)
    resized_image = image.resize(new_size, Image.LANCZOS)

    # Integer crop box centred on the resized image; integer arithmetic keeps
    # the box exactly image_size x image_size regardless of parity.
    left = (resized_image.width - image_size) // 2
    top = (resized_image.height - image_size) // 2
    cropped_image = resized_image.crop(
        (left, top, left + image_size, top + image_size)
    )

    # (H, W, C) uint8 -> (1, C, H, W) float32 in [0, 1].
    array = np.array(cropped_image)
    array = np.transpose(array, (2, 0, 1))
    array = np.expand_dims(array, axis=0)
    return (array / 255).astype(np.float32)


def read_labels(file_path):
    """Read a text file and return one stripped string per line.

    Surrounding whitespace (including the trailing newline) is stripped from
    every line. Blank lines are kept as empty strings, so the i-th entry
    always corresponds to the i-th line of the file.

    Args:
        file_path: Path (str or os.PathLike) to a UTF-8 text file. A leading
            UTF-8 byte-order mark, if present, is ignored.

    Returns:
        A list of strings, one per line of the file.
    """
    # Decode explicitly: labels may contain non-ASCII characters (e.g. Danish
    # æ, ø, å), which the platform default encoding (cp1252 on Windows, for
    # example) would garble. "utf-8-sig" reads plain UTF-8 identically and
    # also drops a BOM, which str.strip() would otherwise leave on the first
    # label.
    with open(file_path, 'r', encoding='utf-8-sig') as f:
        return [line.strip() for line in f]


# Resolve data files relative to this script rather than the current working
# directory, so the app also starts when launched from another folder.
APP_DIR = Path(__file__).resolve().parent

# Class labels; classify_image pairs them positionally with the model's
# output scores, so their order must match the model's output order.
class_labels = read_labels(APP_DIR / 'vocab_formatted.txt')

# Taxa the model does / does not (yet) cover, rendered as sorted,
# comma-separated strings for the interface description.
taxon_included = read_labels(APP_DIR / 'taxon_included.txt')
string_taxon_included = ', '.join(sorted(taxon_included))
taxon_not_included = read_labels(APP_DIR / 'taxon_not_included.txt')
string_taxon_not_included = ', '.join(sorted(taxon_not_included))

# Load the ONNX model once at import time and remember the names of its first
# input and first output tensor for use at inference time.
onnx_model = ort.InferenceSession(str(APP_DIR / 'convnext_tiny.onnx'))
input_name = onnx_model.get_inputs()[0].name
output_name = onnx_model.get_outputs()[0].name


def classify_image(image):
    """Run the ONNX classifier on *image* and score every class label.

    Args:
        image: A PIL image, or None when no image was provided.

    Returns:
        A dict mapping each entry of ``class_labels`` to the model's output
        score for it, as a plain Python float (no softmax is applied here),
        or None when *image* is None.
    """
    # Without an image there is nothing to classify; returning None leaves
    # the label output empty instead of raising inside resize_and_crop.
    if image is None:
        return None

    input_array = resize_and_crop(image, 320)
    # run() returns one array per requested output name: take that array,
    # then its first row (the input batch holds a single image).
    scores = onnx_model.run([output_name], {input_name: input_array})[0][0]
    # float() turns NumPy float32 scalars into plain floats for the Label
    # component.
    return {label: float(score) for label, score in zip(class_labels, scores)}


# Web UI: one PIL image in, the five highest-scoring labels out. The
# description lists which taxa are and are not covered by the model.
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=5),
    title="Billedgenkendelse af Danske Ferskvandsfisk",
    # NOTE(review): "[email protected]" looks like an e-mail obfuscation
    # placeholder left over from copying this file; restore the real address.
    description=(
        "**Upload et billede af en fisk for at identificere arten**.\n\n"
        "Arter inkluderet (dansk navn):\n"
        f"*{string_taxon_included}*\n\n"
        "Arter ikke inkluderet (endnu!):\n"
        f"*{string_taxon_not_included}*\n\n"
        "Lavet af: Kenneth Thorø Martinsen ([email protected])"
    ),
)


def main():
    """Start the Gradio web server for the module-level ``iface``."""
    iface.launch()


# Only serve the app when executed as a script, not when imported.
if __name__ == "__main__":
    main()