Spaces:
Sleeping
Sleeping
import gradio as gr | |
import onnxruntime as ort | |
import numpy as np | |
import pandas as pd | |
import cv2 | |
import time | |
import datetime | |
from scipy.special import softmax | |
import json | |
# Load ONNX model | |
session = ort.InferenceSession("ViT_PreProcessing-ops11-preprocessing-int-dynam_graph.onnx?sequence=7") | |
snake_df = pd.read_csv("snake.csv") | |
with open('scientific2vietnamese.json', 'r') as f: | |
scientific2vietnamese = json.load(f) | |
url = 'https://raw.githubusercontent.com/picekl/PlosNTD-GlobalSnakeID/refs/heads/main/metadata/Species_Relevance.csv' | |
species_relevance_df = pd.read_csv(url, sep=';') | |
vietnam_species = species_relevance_df[species_relevance_df['vietnam']==1]['Unnamed: 0'].to_list() | |
url = 'https://raw.githubusercontent.com/picekl/PlosNTD-GlobalSnakeID/refs/heads/main/metadata/SnakeCLEF2021_train_metadata_PROD.csv' | |
df = pd.read_csv(url, sep=',') | |
class_id_to_binomial = dict(zip(df['class_id'], df['binomial'])) | |
model_snake_list = df['binomial'].unique().tolist() | |
overlap_list = list(set(vietnam_species) & set(model_snake_list)) | |
# Preprocessing function to prepare image for ONNX model | |
def preprocess_image(image): | |
HEIGHT = 384 | |
WIDTH = 384 | |
# Convert PIL Image to numpy array | |
input_data = np.array(image) | |
# Resize the image | |
input_data = cv2.resize(input_data, (WIDTH, HEIGHT)) | |
# Normalize the image (scale pixel values to [-1, 1]) | |
input_data = input_data / 255.0 | |
mean = np.array([0.5, 0.5, 0.5]) | |
std = np.array([0.5, 0.5, 0.5]) | |
input_data = (input_data - mean) / std | |
# Scale the image back to [0, 255] and convert to int32 | |
input_data = (input_data * 255).astype(np.int32) | |
return input_data | |
# Inference function to get top 5 predictions | |
def predict(image): | |
# Preprocess the image | |
input_data = preprocess_image(image) | |
# Get input and output node names | |
input_name = session.get_inputs()[0].name | |
output_name = session.get_outputs()[0].name | |
start_time = time.time() | |
result = session.run([output_name], {input_name: input_data}) | |
end_time = time.time() | |
# Calculate the total execution time | |
execution_time = end_time - start_time | |
print(f"Execution on {datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=7)))} in {execution_time:.4f} seconds") | |
k = 20 | |
top_k_index = np.argsort(result[0][0])[-k:][::-1] | |
top_classes, top_probabilities = [], [] | |
# Also convert the logit into % | |
top_k_prob = softmax(result[0][0]) | |
more_info = "" | |
for i in top_k_index: | |
if (top_k_prob[i] > 0.01): # Only display those with percentage > 1% | |
species = class_id_to_binomial[i] | |
if species in vietnam_species: | |
top_probabilities.append(top_k_prob[i]) | |
if species in scientific2vietnamese: | |
top_classes.append(scientific2vietnamese[species]) | |
url = f"https://nhandangdongvat.vercel.app/snakes/{species.replace(' ', '%20')}" | |
more_info += f"- ℹ️ Tìm hiểu thêm về [{scientific2vietnamese[species]}]({url}) \n" | |
else: | |
top_classes.append(species) | |
# Format the result | |
return {class_name: float(prob) for class_name, prob in zip(top_classes, top_probabilities)}, more_info | |
# Example images | |
examples = [ | |
["example_snake00001.png"], | |
["example_snake00002.png"], | |
["example_snake00003.png"], | |
["example_snake00004.png"], | |
["example_snake00005.png"] | |
] | |
# Supported snakes (in the accordion) | |
supported_snakes = """ | |
| Tên khoa học | Tên tiếng Việt | | |
|----------------------|--------------------------| | |
""" | |
for snake in overlap_list: | |
if snake in scientific2vietnamese.keys(): | |
supported_snakes += f"|{snake}|{scientific2vietnamese[snake]}|\n" | |
else: | |
supported_snakes += f"|{snake}||\n" | |
# Gradio interface | |
description = """ | |
Bạn có biết, ở Việt Nam có trên 250 loài rắn, trong đó chỉ có 12% rắn trên cạn là có độc và nguy hiểm với người. Khi môi trường sống của rắn bị thu hẹp thì việc đụng độ giữa người và rắn sẽ diễn ra thường xuyên hơn. Do vậy, mình phát triển ứng giúp này để giúp các bạn phân biệt được bạn rắn mà bạn bắt gặp có độc hay không bằng trí tuệ nhân tạo, từ đó có được biện pháp phòng tránh tốt nhất. Happy snaking 🐍. | |
""" | |
# Gradio interface | |
with gr.Blocks() as demo: | |
gr.Markdown("# Ứng dụng Nhận diện Rắn 🐍") | |
gr.Markdown(description) # Add description | |
# Upload image section | |
with gr.Row(): | |
with gr.Column(): | |
image_input = gr.Image(type="pil", label="Tải lên hình ảnh của bạn") | |
with gr.Column(): | |
output = gr.Label(num_top_classes=5, label="Kết quả") | |
more_info = gr.Markdown() | |
# Automatically trigger prediction on image upload | |
image_input.change(predict, inputs=image_input, outputs=[output, more_info]) | |
# Example images | |
gr.Examples(examples=examples, inputs=image_input, outputs=output, fn=predict, label="Ảnh chạy thử", cache_examples=False) | |
# Accordion for supported snake species | |
with gr.Accordion(f"Danh sách {len(overlap_list)} loài rắn ở Việt Nam được hỗ trợ", open=False): | |
gr.Markdown(supported_snakes) | |
# Launch the interface | |
demo.launch(debug=True) | |