Tanaanan commited on
Commit
df7f40f
·
verified ·
1 Parent(s): 4dc8399

Upload app.py and requirements files

Browse files
Files changed (2) hide show
  1. app.py +109 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import numpy as np
4
+ import pandas as pd
5
+ import torch
6
+ import timm
7
+ from PIL import Image
8
+ from torchvision import transforms
9
+ import gradio
10
+ import warnings
11
+ warnings.filterwarnings("ignore")
12
+
13
+ model = timm.create_model('swinv2_cr_tiny_ns_224.sw_in1k', pretrained=True)
14
+ output_shape = 60
15
+ model.classifier = torch.nn.Sequential(
16
+ torch.nn.Dropout(p=0.2, inplace=True),
17
+ torch.nn.Linear(in_features=1000,
18
+ out_features=output_shape,
19
+ bias=True).to('cpu'))
20
+ model.load_state_dict(torch.load('./swin_70_65.pth', map_location=torch.device('cpu')))
21
+
22
+ preprocess = transforms.Compose([
23
+ transforms.Resize((224, 224)),
24
+ transforms.ToTensor(),
25
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
26
+ ])
27
+
28
+ class_names = [
29
+ 'Ahaetulla_prasina', 'Bitis_arietans', 'Boa_constrictor', 'Boa_imperator',
30
+ 'Bothriechis_schlegelii', 'Bothrops_asper', 'Bothrops_atrox', 'Bungarus_fasciatus',
31
+ 'Chrysopelea_ornata', 'Coelognathus_radiatus', 'Corallus_hortulana', 'Coronella_austriaca',
32
+ 'Crotaphopeltis_hotamboeia', 'Dendrelaphis_pictus', 'Dolichophis_caspius', 'Drymarchon_melanurus',
33
+ 'Drymobius_margaritiferus', 'Elaphe_dione', 'Epicrates_cenchria', 'Erythrolamprus_poecilogyrus',
34
+ 'Eunectes_murinus', 'Fowlea_flavipunctata', 'Gonyosoma_oxycephalum', 'Helicops_angulatus',
35
+ 'Hierophis_viridiflavus', 'Imantodes_cenchoa', 'Indotyphlops_braminus', 'Laticauda_colubrina',
36
+ 'Leptodeira_annulata', 'Leptodeira_ornata', 'Leptodeira_septentrionalis', 'Leptophis_ahaetulla',
37
+ 'Leptophis_mexicanus', 'Lycodon_capucinus', 'Malayopython_reticulatus', 'Malpolon_insignitus',
38
+ 'Mastigodryas_boddaerti', 'Natrix_helvetica', 'Natrix_maura', 'Natrix_natrix', 'Natrix_tessellata',
39
+ 'Ninia_sebae', 'Ophiophagus_hannah', 'Oxybelis_aeneus', 'Oxybelis_fulgidus', 'Oxyrhopus_petolarius',
40
+ 'Phrynonax_poecilonotus', 'Psammodynastes_pulverulentus', 'Ptyas_korros', 'Ptyas_mucosa',
41
+ 'Python_bivittatus', 'Rhabdophis_tigrinus', 'Sibon_nebulatus', 'Spilotes_pullatus',
42
+ 'Tantilla_melanocephala', 'Trimeresurus_albolabris', 'Vipera_ammodytes', 'Vipera_aspis',
43
+ 'Vipera_berus', 'Zamenis_longissimus'
44
+ ]
45
+
46
+
47
+ def predict(image):
48
+ if image is None:
49
+ return "No image provided."
50
+
51
+ try:
52
+ input_tensor = preprocess(image)
53
+ except Exception as e:
54
+ return f"Error in preprocessing: {str(e)}"
55
+
56
+ input_batch = input_tensor.unsqueeze(0).to('cpu')
57
+
58
+ try:
59
+ with torch.no_grad():
60
+ output = model(input_batch)
61
+ except Exception as e:
62
+ return f"Error in model inference: {str(e)}"
63
+
64
+ probabilities = torch.nn.functional.softmax(output, dim=1)
65
+ percentages = probabilities[0].cpu().numpy() * 100
66
+
67
+ top_n = 5
68
+ combined = list(zip(class_names, percentages))
69
+ sorted_combined = sorted(combined, key=lambda x: x[1], reverse=True)
70
+ top_predictions = sorted_combined[:top_n]
71
+
72
+ # Generate HTML for progress bars with numbers above
73
+ html_content = "<div style='font-family: Arial, sans-serif;'>"
74
+ for class_label, confidence in top_predictions:
75
+ html_content += f"""
76
+ <div style='margin-bottom: 10px; position: relative;'>
77
+ <div style='display: flex; align-items: center;'>
78
+ <strong style='flex: 1;'>{class_label}</strong>
79
+ <span style='flex-shrink: 0; color: black; margin-left: 10px;'>
80
+ {confidence:.2f}%
81
+ </span>
82
+ </div>
83
+ <div style='background-color: #f3f3f3; border-radius: 5px; width: 100%; height: 20px; margin-top: 5px;'>
84
+ <div style='background-color: #4CAF50; height: 100%; width: {confidence:.2f}%; border-radius: 5px;'></div>
85
+ </div>
86
+ </div>
87
+ """
88
+ html_content += "</div>"
89
+
90
+ return html_content
91
+
92
+
93
+ interface = gr.Interface(
94
+ fn=predict,
95
+ inputs=gr.Image(type='pil'),
96
+ outputs=gr.HTML(),
97
+ title="Snake Species Classification (SnakeCLEF2024)",
98
+ description = """
99
+ <div style='font-family: Arial, sans-serif; line-height: 1.6;'>
100
+ <p>Datasets and classes are referenced from: <a href="https://www.imageclef.org/node/319" target="_blank">ImageCLEF</a> and for more details: <a href="https://github.com/Tanaanan/SnakeCLEF2024_MLCS" target="_blank">GitHub repository</a></p>
101
+ <p style='font-size: smaller;'>This project is part of the course 'Machine Learning Systems (01418262).'</p>
102
+ <p style='font-size: smaller;'>Developed by Tanaanan, Narakorn, Department of Computer Science, Kasetsart University.</p>
103
+ </div>
104
+ """,
105
+ live=True
106
+ )
107
+
108
+ if __name__ == "__main__":
109
+ interface.launch()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch
2
+ timm
3
+ gradio
4
+ os
5
+ Pillow
6
+ torchvision
7
+