import gradio as gr import onnxruntime as ort import numpy as np import pandas as pd import cv2 import time import datetime from scipy.special import softmax import json # Load ONNX model session = ort.InferenceSession("ViT_PreProcessing-ops11-preprocessing-int-dynam_graph.onnx?sequence=7") snake_df = pd.read_csv("snake.csv") with open('scientific2vietnamese.json', 'r') as f: scientific2vietnamese = json.load(f) url = 'https://raw.githubusercontent.com/picekl/PlosNTD-GlobalSnakeID/refs/heads/main/metadata/Species_Relevance.csv' species_relevance_df = pd.read_csv(url, sep=';') vietnam_species = species_relevance_df[species_relevance_df['vietnam']==1]['Unnamed: 0'].to_list() url = 'https://raw.githubusercontent.com/picekl/PlosNTD-GlobalSnakeID/refs/heads/main/metadata/SnakeCLEF2021_train_metadata_PROD.csv' df = pd.read_csv(url, sep=',') class_id_to_binomial = dict(zip(df['class_id'], df['binomial'])) model_snake_list = df['binomial'].unique().tolist() overlap_list = list(set(vietnam_species) & set(model_snake_list)) # Preprocessing function to prepare image for ONNX model def preprocess_image(image): HEIGHT = 384 WIDTH = 384 # Convert PIL Image to numpy array input_data = np.array(image) # Resize the image input_data = cv2.resize(input_data, (WIDTH, HEIGHT)) # Normalize the image (scale pixel values to [-1, 1]) input_data = input_data / 255.0 mean = np.array([0.5, 0.5, 0.5]) std = np.array([0.5, 0.5, 0.5]) input_data = (input_data - mean) / std # Scale the image back to [0, 255] and convert to int32 input_data = (input_data * 255).astype(np.int32) return input_data # Inference function to get top 5 predictions def predict(image): # Preprocess the image input_data = preprocess_image(image) # Get input and output node names input_name = session.get_inputs()[0].name output_name = session.get_outputs()[0].name start_time = time.time() result = session.run([output_name], {input_name: input_data}) end_time = time.time() # Calculate the total execution time execution_time = end_time - start_time print(f"Execution on {datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=7)))} in {execution_time:.4f} seconds") k = 20 top_k_index = np.argsort(result[0][0])[-k:][::-1] top_classes, top_probabilities = [], [] # Also convert the logit into % top_k_prob = softmax(result[0][0]) more_info = "" for i in top_k_index: if (top_k_prob[i] > 0.01): # Only display those with percentage > 1% species = class_id_to_binomial[i] if species in vietnam_species: top_probabilities.append(top_k_prob[i]) if species in scientific2vietnamese: top_classes.append(scientific2vietnamese[species]) url = f"https://nhandangdongvat.vercel.app/snakes/{species.replace(' ', '%20')}" more_info += f"- ℹ️ Tìm hiểu thêm về [{scientific2vietnamese[species]}]({url}) \n" else: top_classes.append(species) # Format the result return {class_name: float(prob) for class_name, prob in zip(top_classes, top_probabilities)}, more_info # Example images examples = [ ["example_snake00001.png"], ["example_snake00002.png"], ["example_snake00003.png"], ["example_snake00004.png"], ["example_snake00005.png"] ] # Supported snakes (in the accordion) supported_snakes = """ | Tên khoa học | Tên tiếng Việt | |----------------------|--------------------------| """ for snake in overlap_list: if snake in scientific2vietnamese.keys(): supported_snakes += f"|{snake}|{scientific2vietnamese[snake]}|\n" else: supported_snakes += f"|{snake}||\n" # Gradio interface description = """ Bạn có biết, ở Việt Nam có trên 250 loài rắn, trong đó chỉ có 12% rắn trên cạn là có độc và nguy hiểm với người. Khi môi trường sống của rắn bị thu hẹp thì việc đụng độ giữa người và rắn sẽ diễn ra thường xuyên hơn. Do vậy, mình phát triển ứng giúp này để giúp các bạn phân biệt được bạn rắn mà bạn bắt gặp có độc hay không bằng trí tuệ nhân tạo, từ đó có được biện pháp phòng tránh tốt nhất. Happy snaking 🐍. """ # Gradio interface with gr.Blocks() as demo: gr.Markdown("# Ứng dụng Nhận diện Rắn 🐍") gr.Markdown(description) # Add description # Upload image section with gr.Row(): with gr.Column(): image_input = gr.Image(type="pil", label="Tải lên hình ảnh của bạn") with gr.Column(): output = gr.Label(num_top_classes=5, label="Kết quả") more_info = gr.Markdown() # Automatically trigger prediction on image upload image_input.change(predict, inputs=image_input, outputs=[output, more_info]) # Example images gr.Examples(examples=examples, inputs=image_input, outputs=output, fn=predict, label="Ảnh chạy thử", cache_examples=False) # Accordion for supported snake species with gr.Accordion(f"Danh sách {len(overlap_list)} loài rắn ở Việt Nam được hỗ trợ", open=False): gr.Markdown(supported_snakes) # Launch the interface demo.launch(debug=True)