"""Gradio demo that classifies freshwater fish species of Denmark with an ONNX model."""

import numpy as np
import gradio as gr
from PIL import Image
import onnxruntime as ort

# Side length (in pixels) of the square image fed to the model.
IMAGE_SIZE = 320


def resize_and_crop(image, image_size):
    """Resize and center-crop a PIL image into a model-ready array.

    The image is converted to RGB, scaled so that its shortest side equals
    ``image_size`` and then center-cropped to a square.

    Args:
        image: A ``PIL.Image.Image`` in any mode.
        image_size: Side length in pixels of the square output.

    Returns:
        A ``float32`` array of shape ``(1, 3, image_size, image_size)``
        with values scaled to the range ``[0, 1]``.
    """
    # Guarantee exactly three channels so the HWC -> CHW transpose below is
    # valid for grayscale, palette and RGBA inputs as well.
    rgb_image = image.convert("RGB")

    width, height = rgb_image.size
    scale = float(image_size) / min(width, height)
    # Round instead of truncating and clamp to image_size: floating-point
    # error could otherwise leave the shortest side one pixel too small,
    # making the crop below reach outside the image.
    new_size = (
        max(image_size, round(width * scale)),
        max(image_size, round(height * scale)),
    )
    resized_image = rgb_image.resize(new_size, Image.LANCZOS)

    # Integer crop box: the result is always exactly image_size x image_size,
    # independent of how PIL would round fractional coordinates.
    left = (resized_image.width - image_size) // 2
    top = (resized_image.height - image_size) // 2
    cropped_image = resized_image.crop(
        (left, top, left + image_size, top + image_size)
    )

    array = np.array(cropped_image)
    array = np.transpose(array, (2, 0, 1))  # HWC -> CHW
    array = np.expand_dims(array, axis=0)  # add the batch dimension
    return (array / 255).astype(np.float32)


def read_labels(file_path):
    """Read one label per line from a text file.

    Args:
        file_path: Path of the text file to read.

    Returns:
        A list with one whitespace-stripped string per line of the file.
    """
    # Explicit encoding so non-ASCII names do not depend on the platform
    # default. NOTE(review): assumes the label files are UTF-8 -- confirm.
    with open(file_path, 'r', encoding='utf-8') as f:
        return [line.strip() for line in f]


# Load the class labels and the species lists shown in the description
class_labels = read_labels('vocab_formatted.txt')
taxon_included = read_labels('taxon_included.txt')
string_taxon_included = ', '.join(sorted(taxon_included))
taxon_not_included = read_labels('taxon_not_included.txt')
string_taxon_not_included = ', '.join(sorted(taxon_not_included))

# Load the ONNX model once at start-up and look up its first input/output
onnx_model = ort.InferenceSession('convnext_tiny.onnx')
input_name = onnx_model.get_inputs()[0].name
output_name = onnx_model.get_outputs()[0].name


def classify_image(image):
    """Run the ONNX model on an image and map each class label to its score.

    Args:
        image: A ``PIL.Image.Image``, or ``None`` when no image was provided.

    Returns:
        A dict mapping each entry of ``class_labels`` to the corresponding
        value of the first row of the model output as a Python ``float``,
        or ``None`` when ``image`` is ``None``.
    """
    # Submitting without an image passes None; return an empty result
    # instead of failing inside the preprocessing.
    if image is None:
        return None

    input_array = resize_and_crop(image, IMAGE_SIZE)
    outputs = onnx_model.run([output_name], {input_name: input_array})[0]
    # Plain floats rather than numpy scalars so the result serializes cleanly.
    # NOTE(review): presumably the model already emits probabilities -- confirm.
    return {taxon: float(prob) for taxon, prob in zip(class_labels, outputs[0])}


# Create the Gradio interface
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=5),
    title="Image Classification for Freshwater Fish Species of Denmark",
    description=(
        "**Upload an image to classify it**.\n\n"
        "Species included (Danish common name):\n"
        f"*{string_taxon_included}*\n\n"
        "Species not included (yet!):\n"
        f"*{string_taxon_not_included}*"
    ),
)

# Launch the app
if __name__ == "__main__":
    iface.launch()