File size: 5,045 Bytes
df7f40f 3b11ca6 df7f40f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 |
import gradio as gr
import os
import numpy as np
import pandas as pd
import torch
import timm
from PIL import Image
from torchvision import transforms
import gradio
import warnings
# Suppress every warning process-wide.
# NOTE(review): this also hides genuinely useful warnings; consider narrowing
# it with a category= or module= filter.
warnings.filterwarnings("ignore")

# Build the backbone architecture only. ImageNet weights are not needed: the
# strict load_state_dict() below overwrites every parameter and persistent
# buffer, so pretrained=False just skips a network download at startup (and
# lets the app start offline).
model = timm.create_model('swinv2_cr_tiny_ns_224.sw_in1k', pretrained=False)
output_shape = 60  # number of species; must equal len(class_names)

# Extra head mapping a 1000-dim output to the species logits. Its structure
# must match the checkpoint exactly, or the strict load below fails.
# NOTE(review): confirm that the backbone's forward() actually calls
# `model.classifier`. If it does not, the model still emits its own 1000
# logits, and predict() pairs only the first len(class_names) probabilities
# with labels.
model.classifier = torch.nn.Sequential(
    torch.nn.Dropout(p=0.2, inplace=True),
    torch.nn.Linear(in_features=1000,
                    out_features=output_shape,
                    bias=True).to('cpu'))

# NOTE: torch.load unpickles the file; only load checkpoints you trust.
model.load_state_dict(torch.load('./swin_70_65.pth', map_location=torch.device('cpu')))

# Inference only: switch off dropout (including the Dropout above) so the same
# image always produces the same prediction.
model.eval()

# Resize straight to 224x224 (aspect ratio is not preserved), convert to a
# [0, 1] float tensor, then normalize each of the 3 channels with the ImageNet
# mean/std values used by torchvision's pretrained models.
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Species labels, alphabetical, 60 entries (== output_shape).
# predict() pairs these by position with the model's output probabilities, so
# the order must match the label indices used in training. Do not reorder or
# edit this list without retraining.
class_names = """
    Ahaetulla_prasina Bitis_arietans Boa_constrictor Boa_imperator
    Bothriechis_schlegelii Bothrops_asper Bothrops_atrox Bungarus_fasciatus
    Chrysopelea_ornata Coelognathus_radiatus Corallus_hortulana Coronella_austriaca
    Crotaphopeltis_hotamboeia Dendrelaphis_pictus Dolichophis_caspius Drymarchon_melanurus
    Drymobius_margaritiferus Elaphe_dione Epicrates_cenchria Erythrolamprus_poecilogyrus
    Eunectes_murinus Fowlea_flavipunctata Gonyosoma_oxycephalum Helicops_angulatus
    Hierophis_viridiflavus Imantodes_cenchoa Indotyphlops_braminus Laticauda_colubrina
    Leptodeira_annulata Leptodeira_ornata Leptodeira_septentrionalis Leptophis_ahaetulla
    Leptophis_mexicanus Lycodon_capucinus Malayopython_reticulatus Malpolon_insignitus
    Mastigodryas_boddaerti Natrix_helvetica Natrix_maura Natrix_natrix Natrix_tessellata
    Ninia_sebae Ophiophagus_hannah Oxybelis_aeneus Oxybelis_fulgidus Oxyrhopus_petolarius
    Phrynonax_poecilonotus Psammodynastes_pulverulentus Ptyas_korros Ptyas_mucosa
    Python_bivittatus Rhabdophis_tigrinus Sibon_nebulatus Spilotes_pullatus
    Tantilla_melanocephala Trimeresurus_albolabris Vipera_ammodytes Vipera_aspis
    Vipera_berus Zamenis_longissimus
""".split()
def _render_predictions_html(predictions):
    """Render ``(label, percentage)`` pairs as labelled HTML progress bars.

    Args:
        predictions: Iterable of ``(class_label, confidence)`` pairs, where
            ``confidence`` is a percentage (0-100) that supports ``:.2f``.

    Returns:
        One HTML string: a wrapper ``<div>`` containing, for each pair in the
        given order, the label and its percentage above a bar filled to that
        width.
    """
    row_template = (
        "<div style='margin-bottom: 10px; position: relative;'>"
        "<div style='display: flex; align-items: center;'>"
        "<strong style='flex: 1;'>{label}</strong>"
        "<span style='flex-shrink: 0; color: black; margin-left: 10px;'>"
        "{pct:.2f}%"
        "</span>"
        "</div>"
        "<div style='background-color: #f3f3f3; border-radius: 5px; "
        "width: 100%; height: 20px; margin-top: 5px;'>"
        "<div style='background-color: #4CAF50; height: 100%; "
        "width: {pct:.2f}%; border-radius: 5px;'></div>"
        "</div>"
        "</div>"
    )
    # One join instead of repeated `+=`, which is quadratic in general.
    rows = "".join(
        row_template.format(label=label, pct=pct) for label, pct in predictions
    )
    return f"<div style='font-family: Arial, sans-serif;'>{rows}</div>"


def predict(image):
    """Classify a snake photo and render the most likely species as HTML.

    Args:
        image: PIL image supplied by the Gradio ``gr.Image(type='pil')``
            input, or ``None`` when the input is empty.

    Returns:
        An HTML string listing the 5 most probable species with their softmax
        percentages, highest first. Returns a plain-text message instead when
        no image is given or when preprocessing or inference raises.
    """
    if image is None:
        return "No image provided."
    try:
        # Uploaded PNGs may be RGBA, palette or grayscale, but `preprocess`
        # normalizes exactly 3 channels, so convert to RGB first.
        input_tensor = preprocess(image.convert("RGB"))
    except Exception as e:
        # Report the failure in the UI instead of crashing the live interface.
        return f"Error in preprocessing: {str(e)}"
    input_batch = input_tensor.unsqueeze(0).to('cpu')  # add the batch dimension
    try:
        with torch.no_grad():
            output = model(input_batch)
    except Exception as e:
        return f"Error in model inference: {str(e)}"
    probabilities = torch.nn.functional.softmax(output, dim=1)
    percentages = probabilities[0].cpu().numpy() * 100
    top_n = 5
    # zip pairs labels with probabilities by position and stops at the shorter
    # sequence, so the label order must match the model's output order.
    ranked = sorted(zip(class_names, percentages), key=lambda pair: pair[1], reverse=True)
    return _render_predictions_html(ranked[:top_n])
def _build_interface():
    """Assemble the Gradio UI that wires `predict` to an image input.

    The interface takes a PIL image and renders predict()'s HTML. It runs
    live (re-predicting whenever the input changes) and offers one bundled
    sample image per listed species as a clickable example.
    """
    species_with_samples = [
        'Bitis_arietans',
        'Boa_imperator',
        'Coelognathus_radiatus',
        'Leptodeira_septentrionalis',
        'Natrix_tessellata',
        'Psammodynastes_pulverulentus',
        'Ptyas_mucosa',
        'Vipera_ammodytes',
        'Zamenis_longissimus',
    ]
    sample_paths = [f'./sample_imgs/{species}.png' for species in species_with_samples]

    # Credits / data-source blurb shown under the title.
    about_html = (
        "\n"
        "<div style='font-family: Arial, sans-serif; line-height: 1.6;'>\n"
        '<p>Datasets and classes are referenced from: <a href="https://www.imageclef.org/node/319" target="_blank">ImageCLEF</a> and for more details: <a href="https://github.com/Tanaanan/SnakeCLEF2024_MLCS" target="_blank">GitHub repository</a></p>\n'
        "<p style='font-size: smaller;'>This project is part of the course 'Machine Learning Systems (01418262).'</p>\n"
        "<p style='font-size: smaller;'>Developed by Tanaanan, Narakorn, Department of Computer Science, Kasetsart University.</p>\n"
        "</div>\n"
    )

    return gr.Interface(
        fn=predict,
        inputs=gr.Image(type='pil'),
        outputs=gr.HTML(),
        title="Snake Species Classification (SnakeCLEF2024)",
        description=about_html,
        examples=sample_paths,
        live=True,
    )


interface = _build_interface()

if __name__ == "__main__":
    interface.launch()
|