import numpy as np
import gradio as gr
from PIL import Image
import onnxruntime as ort

def resize_and_crop(image, image_size):
    # Ensure three channels so the transpose below always sees (H, W, 3),
    # even for grayscale or RGBA uploads
    image = image.convert('RGB')

    # Resize the image so that the shortest side equals image_size
    original_size = image.size
    ratio = float(image_size) / min(original_size)
    new_size = tuple(int(x * ratio) for x in original_size)
    resized_image = image.resize(new_size, Image.LANCZOS)

    # Calculate integer coordinates for center cropping
    left = (resized_image.width - image_size) // 2
    top = (resized_image.height - image_size) // 2

    # Crop the image to image_size x image_size
    cropped_image = resized_image.crop((left, top, left + image_size, top + image_size))

    # Convert to a normalized float32 batch of shape (1, 3, H, W)
    array = np.array(cropped_image)
    array = np.transpose(array, (2, 0, 1))
    array = np.expand_dims(array, axis=0)
    array = (array / 255).astype(np.float32)

    return array
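
# A minimal sketch of how the preprocessing can be sanity-checked with a
# synthetic image (illustrative only; not part of the app itself):
#
#   dummy = Image.fromarray(np.zeros((480, 640, 3), dtype=np.uint8))
#   batch = resize_and_crop(dummy, 320)
#   assert batch.shape == (1, 3, 320, 320) and batch.dtype == np.float32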

# Read class labels from a text file, one label per line
def read_labels(file_path):
    with open(file_path, 'r') as f:
        labels = [line.strip() for line in f]
    return labels

# Load the class labels
class_labels = read_labels('vocab_formatted.txt')

taxon_included = read_labels('taxon_included.txt')
string_taxon_included = ', '.join(sorted(taxon_included))

taxon_not_included = read_labels('taxon_not_included.txt')
string_taxon_not_included = ', '.join(sorted(taxon_not_included))

# Load the ONNX model
onnx_model = ort.InferenceSession('convnext_tiny.onnx')

input_name = onnx_model.get_inputs()[0].name
output_name = onnx_model.get_outputs()[0].name
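
# Optional sanity check: the model's declared input shape should match the
# 320x320 preprocessing used below (values depend on how it was exported):
#   print(onnx_model.get_inputs()[0].shape)  # e.g. [1, 3, 320, 320]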

# Define the inference function
def classify_image(image):
    input_array = resize_and_crop(image, 320)
    outputs = onnx_model.run([output_name], {input_name: input_array})[0]
    # Map each class label to its score; the exported model is assumed to
    # emit probabilities (i.e. softmax is part of the ONNX graph)
    result = {taxon: float(prob) for taxon, prob in zip(class_labels, outputs[0])}
    return result
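
# classify_image returns a {label: score} dict, which gr.Label renders as a
# ranked list. Hypothetical example (labels and values are illustrative):
#   {"aborre": 0.91, "gedde": 0.04, "skalle": 0.02, ...}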

# Create the Gradio interface
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=5),
    title="Image Recognition of Danish Freshwater Fish",
    description=f"**Upload a picture of a fish to identify the species**.\n\nSpecies included (Danish name):\n*{string_taxon_included}*\n\nSpecies not included (yet!):\n*{string_taxon_not_included}*\n\nMade by: Kenneth Thorø Martinsen ([email protected])",
)

# Launch the app
if __name__ == "__main__":
    iface.launch()