File size: 5,393 Bytes
97f6a6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195f2d0
 
 
 
97f6a6d
 
 
6a5c0e0
97f6a6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b2afbd
97f6a6d
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import gradio as gr
import onnxruntime as ort
import numpy as np
import pandas as pd
import cv2
import time
import datetime
from scipy.special import softmax
import json

# Load ONNX model
session = ort.InferenceSession("ViT_PreProcessing-ops11-preprocessing-int-dynam_graph.onnx?sequence=7")

snake_df = pd.read_csv("snake.csv")

with open('scientific2vietnamese.json', 'r') as f:
    scientific2vietnamese = json.load(f)

url = 'https://raw.githubusercontent.com/picekl/PlosNTD-GlobalSnakeID/refs/heads/main/metadata/Species_Relevance.csv'
species_relevance_df = pd.read_csv(url, sep=';')
vietnam_species = species_relevance_df[species_relevance_df['vietnam']==1]['Unnamed: 0'].to_list()

url = 'https://raw.githubusercontent.com/picekl/PlosNTD-GlobalSnakeID/refs/heads/main/metadata/SnakeCLEF2021_train_metadata_PROD.csv'
df = pd.read_csv(url, sep=',')
class_id_to_binomial = dict(zip(df['class_id'], df['binomial']))

model_snake_list = df['binomial'].unique().tolist()
overlap_list = list(set(vietnam_species) & set(model_snake_list))

# Preprocessing function to prepare image for ONNX model
def preprocess_image(image):
    HEIGHT = 384
    WIDTH = 384

    # Convert PIL Image to numpy array
    input_data = np.array(image)

    # Resize the image
    input_data = cv2.resize(input_data, (WIDTH, HEIGHT))

    # Normalize the image (scale pixel values to [-1, 1])
    input_data = input_data / 255.0
    mean = np.array([0.5, 0.5, 0.5])
    std = np.array([0.5, 0.5, 0.5])
    input_data = (input_data - mean) / std

    # Scale the image back to [0, 255] and convert to int32
    input_data = (input_data * 255).astype(np.int32)
    return input_data

# Inference function to get top 5 predictions
def predict(image):
    # Preprocess the image
    input_data = preprocess_image(image)

    # Get input and output node names
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name

    start_time = time.time()
    result = session.run([output_name], {input_name: input_data})
    end_time = time.time()

    # Calculate the total execution time
    execution_time = end_time - start_time
    print(f"Execution on {datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=7)))} in {execution_time:.4f} seconds")

    k = 20
    top_k_index = np.argsort(result[0][0])[-k:][::-1]

    top_classes, top_probabilities = [], []
    # Also convert the logit into %
    top_k_prob = softmax(result[0][0])

    more_info = ""

    for i in top_k_index:
      if (top_k_prob[i] > 0.01): # Only display those with percentage > 1%
        species = class_id_to_binomial[i]
        if species in vietnam_species:
          top_probabilities.append(top_k_prob[i])
          
          if species in scientific2vietnamese:
            top_classes.append(scientific2vietnamese[species])
            url = f"https://nhandangdongvat.vercel.app/snakes/{species.replace(' ', '%20')}"
            more_info += f"- ℹ️ Tìm hiểu thêm về [{scientific2vietnamese[species]}]({url}) \n"
          else:
            top_classes.append(species)
    
    # Format the result
    return {class_name: float(prob) for class_name, prob in zip(top_classes, top_probabilities)}, more_info

# Example images
examples = [
    ["example_snake00001.png"], 
    ["example_snake00002.png"],
    ["example_snake00003.png"],
    ["example_snake00004.png"],
    ["example_snake00005.png"]
]

# Supported snakes (in the accordion)
supported_snakes = """
| Tên khoa học         | Tên tiếng Việt           |
|----------------------|--------------------------|
"""

for snake in overlap_list:
    if snake in scientific2vietnamese.keys():
        supported_snakes += f"|{snake}|{scientific2vietnamese[snake]}|\n"
    else:
        supported_snakes += f"|{snake}||\n"

# Gradio interface
description = """
Bạn có biết, ở Việt Nam có trên 250 loài rắn, trong đó chỉ có 12% rắn trên cạn là có độc và nguy hiểm với người. Khi môi trường sống của rắn bị thu hẹp thì việc đụng độ giữa người và rắn sẽ diễn ra thường xuyên hơn. Do vậy, mình phát triển ứng giúp này để giúp các bạn phân biệt được bạn rắn mà bạn bắt gặp có độc hay không bằng trí tuệ nhân tạo, từ đó có được biện pháp phòng tránh tốt nhất. Happy snaking 🐍.
"""

# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Ứng dụng Nhận diện Rắn 🐍")
    gr.Markdown(description)  # Add description
    
    # Upload image section
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Tải lên hình ảnh của bạn")
        with gr.Column():
            output = gr.Label(num_top_classes=5, label="Kết quả")
            more_info = gr.Markdown()

    # Automatically trigger prediction on image upload
    image_input.change(predict, inputs=image_input, outputs=[output, more_info])
    
    # Example images
    gr.Examples(examples=examples, inputs=image_input, outputs=output, fn=predict, label="Ảnh chạy thử", cache_examples=False)
    
    # Accordion for supported snake species
    with gr.Accordion(f"Danh sách {len(overlap_list)} loài rắn ở Việt Nam được hỗ trợ", open=False):
        gr.Markdown(supported_snakes)

# Launch the interface
demo.launch(debug=True)