Upload app.py and requirements files
Browse files- app.py +109 -0
- requirements.txt +7 -0
app.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import os
|
3 |
+
import numpy as np
|
4 |
+
import pandas as pd
|
5 |
+
import torch
|
6 |
+
import timm
|
7 |
+
from PIL import Image
|
8 |
+
from torchvision import transforms
|
9 |
+
import gradio
|
10 |
+
import warnings
|
11 |
+
warnings.filterwarnings("ignore")
|
12 |
+
|
13 |
+
model = timm.create_model('swinv2_cr_tiny_ns_224.sw_in1k', pretrained=True)
|
14 |
+
output_shape = 60
|
15 |
+
model.classifier = torch.nn.Sequential(
|
16 |
+
torch.nn.Dropout(p=0.2, inplace=True),
|
17 |
+
torch.nn.Linear(in_features=1000,
|
18 |
+
out_features=output_shape,
|
19 |
+
bias=True).to('cpu'))
|
20 |
+
model.load_state_dict(torch.load('./swin_70_65.pth', map_location=torch.device('cpu')))
|
21 |
+
|
22 |
+
preprocess = transforms.Compose([
|
23 |
+
transforms.Resize((224, 224)),
|
24 |
+
transforms.ToTensor(),
|
25 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
26 |
+
])
|
27 |
+
|
28 |
+
class_names = [
|
29 |
+
'Ahaetulla_prasina', 'Bitis_arietans', 'Boa_constrictor', 'Boa_imperator',
|
30 |
+
'Bothriechis_schlegelii', 'Bothrops_asper', 'Bothrops_atrox', 'Bungarus_fasciatus',
|
31 |
+
'Chrysopelea_ornata', 'Coelognathus_radiatus', 'Corallus_hortulana', 'Coronella_austriaca',
|
32 |
+
'Crotaphopeltis_hotamboeia', 'Dendrelaphis_pictus', 'Dolichophis_caspius', 'Drymarchon_melanurus',
|
33 |
+
'Drymobius_margaritiferus', 'Elaphe_dione', 'Epicrates_cenchria', 'Erythrolamprus_poecilogyrus',
|
34 |
+
'Eunectes_murinus', 'Fowlea_flavipunctata', 'Gonyosoma_oxycephalum', 'Helicops_angulatus',
|
35 |
+
'Hierophis_viridiflavus', 'Imantodes_cenchoa', 'Indotyphlops_braminus', 'Laticauda_colubrina',
|
36 |
+
'Leptodeira_annulata', 'Leptodeira_ornata', 'Leptodeira_septentrionalis', 'Leptophis_ahaetulla',
|
37 |
+
'Leptophis_mexicanus', 'Lycodon_capucinus', 'Malayopython_reticulatus', 'Malpolon_insignitus',
|
38 |
+
'Mastigodryas_boddaerti', 'Natrix_helvetica', 'Natrix_maura', 'Natrix_natrix', 'Natrix_tessellata',
|
39 |
+
'Ninia_sebae', 'Ophiophagus_hannah', 'Oxybelis_aeneus', 'Oxybelis_fulgidus', 'Oxyrhopus_petolarius',
|
40 |
+
'Phrynonax_poecilonotus', 'Psammodynastes_pulverulentus', 'Ptyas_korros', 'Ptyas_mucosa',
|
41 |
+
'Python_bivittatus', 'Rhabdophis_tigrinus', 'Sibon_nebulatus', 'Spilotes_pullatus',
|
42 |
+
'Tantilla_melanocephala', 'Trimeresurus_albolabris', 'Vipera_ammodytes', 'Vipera_aspis',
|
43 |
+
'Vipera_berus', 'Zamenis_longissimus'
|
44 |
+
]
|
45 |
+
|
46 |
+
|
47 |
+
def predict(image):
|
48 |
+
if image is None:
|
49 |
+
return "No image provided."
|
50 |
+
|
51 |
+
try:
|
52 |
+
input_tensor = preprocess(image)
|
53 |
+
except Exception as e:
|
54 |
+
return f"Error in preprocessing: {str(e)}"
|
55 |
+
|
56 |
+
input_batch = input_tensor.unsqueeze(0).to('cpu')
|
57 |
+
|
58 |
+
try:
|
59 |
+
with torch.no_grad():
|
60 |
+
output = model(input_batch)
|
61 |
+
except Exception as e:
|
62 |
+
return f"Error in model inference: {str(e)}"
|
63 |
+
|
64 |
+
probabilities = torch.nn.functional.softmax(output, dim=1)
|
65 |
+
percentages = probabilities[0].cpu().numpy() * 100
|
66 |
+
|
67 |
+
top_n = 5
|
68 |
+
combined = list(zip(class_names, percentages))
|
69 |
+
sorted_combined = sorted(combined, key=lambda x: x[1], reverse=True)
|
70 |
+
top_predictions = sorted_combined[:top_n]
|
71 |
+
|
72 |
+
# Generate HTML for progress bars with numbers above
|
73 |
+
html_content = "<div style='font-family: Arial, sans-serif;'>"
|
74 |
+
for class_label, confidence in top_predictions:
|
75 |
+
html_content += f"""
|
76 |
+
<div style='margin-bottom: 10px; position: relative;'>
|
77 |
+
<div style='display: flex; align-items: center;'>
|
78 |
+
<strong style='flex: 1;'>{class_label}</strong>
|
79 |
+
<span style='flex-shrink: 0; color: black; margin-left: 10px;'>
|
80 |
+
{confidence:.2f}%
|
81 |
+
</span>
|
82 |
+
</div>
|
83 |
+
<div style='background-color: #f3f3f3; border-radius: 5px; width: 100%; height: 20px; margin-top: 5px;'>
|
84 |
+
<div style='background-color: #4CAF50; height: 100%; width: {confidence:.2f}%; border-radius: 5px;'></div>
|
85 |
+
</div>
|
86 |
+
</div>
|
87 |
+
"""
|
88 |
+
html_content += "</div>"
|
89 |
+
|
90 |
+
return html_content
|
91 |
+
|
92 |
+
|
93 |
+
interface = gr.Interface(
|
94 |
+
fn=predict,
|
95 |
+
inputs=gr.Image(type='pil'),
|
96 |
+
outputs=gr.HTML(),
|
97 |
+
title="Snake Species Classification (SnakeCLEF2024)",
|
98 |
+
description = """
|
99 |
+
<div style='font-family: Arial, sans-serif; line-height: 1.6;'>
|
100 |
+
<p>Datasets and classes are referenced from: <a href="https://www.imageclef.org/node/319" target="_blank">ImageCLEF</a> and for more details: <a href="https://github.com/Tanaanan/SnakeCLEF2024_MLCS" target="_blank">GitHub repository</a></p>
|
101 |
+
<p style='font-size: smaller;'>This project is part of the course 'Machine Learning Systems (01418262).'</p>
|
102 |
+
<p style='font-size: smaller;'>Developed by Tanaanan, Narakorn, Department of Computer Science, Kasetsart University.</p>
|
103 |
+
</div>
|
104 |
+
""",
|
105 |
+
live=True
|
106 |
+
)
|
107 |
+
|
108 |
+
if __name__ == "__main__":
|
109 |
+
interface.launch()
|
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
timm
|
3 |
+
gradio
|
4 |
+
os
|
5 |
+
Pillow
|
6 |
+
torchvision
|
7 |
+
|