panda1835's picture
Update app.py
6a5c0e0 verified
import gradio as gr
import onnxruntime as ort
import numpy as np
import pandas as pd
import cv2
import time
import datetime
from scipy.special import softmax
import json
# Load ONNX model
session = ort.InferenceSession("ViT_PreProcessing-ops11-preprocessing-int-dynam_graph.onnx?sequence=7")
snake_df = pd.read_csv("snake.csv")
with open('scientific2vietnamese.json', 'r') as f:
scientific2vietnamese = json.load(f)
url = 'https://raw.githubusercontent.com/picekl/PlosNTD-GlobalSnakeID/refs/heads/main/metadata/Species_Relevance.csv'
species_relevance_df = pd.read_csv(url, sep=';')
vietnam_species = species_relevance_df[species_relevance_df['vietnam']==1]['Unnamed: 0'].to_list()
url = 'https://raw.githubusercontent.com/picekl/PlosNTD-GlobalSnakeID/refs/heads/main/metadata/SnakeCLEF2021_train_metadata_PROD.csv'
df = pd.read_csv(url, sep=',')
class_id_to_binomial = dict(zip(df['class_id'], df['binomial']))
model_snake_list = df['binomial'].unique().tolist()
overlap_list = list(set(vietnam_species) & set(model_snake_list))
# Preprocessing function to prepare image for ONNX model
def preprocess_image(image):
HEIGHT = 384
WIDTH = 384
# Convert PIL Image to numpy array
input_data = np.array(image)
# Resize the image
input_data = cv2.resize(input_data, (WIDTH, HEIGHT))
# Normalize the image (scale pixel values to [-1, 1])
input_data = input_data / 255.0
mean = np.array([0.5, 0.5, 0.5])
std = np.array([0.5, 0.5, 0.5])
input_data = (input_data - mean) / std
# Scale the image back to [0, 255] and convert to int32
input_data = (input_data * 255).astype(np.int32)
return input_data
# Inference function to get top 5 predictions
def predict(image):
# Preprocess the image
input_data = preprocess_image(image)
# Get input and output node names
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
start_time = time.time()
result = session.run([output_name], {input_name: input_data})
end_time = time.time()
# Calculate the total execution time
execution_time = end_time - start_time
print(f"Execution on {datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=7)))} in {execution_time:.4f} seconds")
k = 20
top_k_index = np.argsort(result[0][0])[-k:][::-1]
top_classes, top_probabilities = [], []
# Also convert the logit into %
top_k_prob = softmax(result[0][0])
more_info = ""
for i in top_k_index:
if (top_k_prob[i] > 0.01): # Only display those with percentage > 1%
species = class_id_to_binomial[i]
if species in vietnam_species:
top_probabilities.append(top_k_prob[i])
if species in scientific2vietnamese:
top_classes.append(scientific2vietnamese[species])
url = f"https://nhandangdongvat.vercel.app/snakes/{species.replace(' ', '%20')}"
more_info += f"- ℹ️ Tìm hiểu thêm về [{scientific2vietnamese[species]}]({url}) \n"
else:
top_classes.append(species)
# Format the result
return {class_name: float(prob) for class_name, prob in zip(top_classes, top_probabilities)}, more_info
# Example images
examples = [
["example_snake00001.png"],
["example_snake00002.png"],
["example_snake00003.png"],
["example_snake00004.png"],
["example_snake00005.png"]
]
# Supported snakes (in the accordion)
supported_snakes = """
| Tên khoa học | Tên tiếng Việt |
|----------------------|--------------------------|
"""
for snake in overlap_list:
if snake in scientific2vietnamese.keys():
supported_snakes += f"|{snake}|{scientific2vietnamese[snake]}|\n"
else:
supported_snakes += f"|{snake}||\n"
# Gradio interface
description = """
Bạn có biết, ở Việt Nam có trên 250 loài rắn, trong đó chỉ có 12% rắn trên cạn là có độc và nguy hiểm với người. Khi môi trường sống của rắn bị thu hẹp thì việc đụng độ giữa người và rắn sẽ diễn ra thường xuyên hơn. Do vậy, mình phát triển ứng giúp này để giúp các bạn phân biệt được bạn rắn mà bạn bắt gặp có độc hay không bằng trí tuệ nhân tạo, từ đó có được biện pháp phòng tránh tốt nhất. Happy snaking 🐍.
"""
# Gradio interface
with gr.Blocks() as demo:
gr.Markdown("# Ứng dụng Nhận diện Rắn 🐍")
gr.Markdown(description) # Add description
# Upload image section
with gr.Row():
with gr.Column():
image_input = gr.Image(type="pil", label="Tải lên hình ảnh của bạn")
with gr.Column():
output = gr.Label(num_top_classes=5, label="Kết quả")
more_info = gr.Markdown()
# Automatically trigger prediction on image upload
image_input.change(predict, inputs=image_input, outputs=[output, more_info])
# Example images
gr.Examples(examples=examples, inputs=image_input, outputs=output, fn=predict, label="Ảnh chạy thử", cache_examples=False)
# Accordion for supported snake species
with gr.Accordion(f"Danh sách {len(overlap_list)} loài rắn ở Việt Nam được hỗ trợ", open=False):
gr.Markdown(supported_snakes)
# Launch the interface
demo.launch(debug=True)