Stanlito commited on
Commit
52284cd
·
1 Parent(s): 691411c

Upload 7 files

Browse files
Files changed (7) hide show
  1. app.py +395 -0
  2. bird_classification_vit.pth +3 -0
  3. examples/1.jpg +0 -0
  4. examples/3.jpg +0 -0
  5. examples/5.jpg +0 -0
  6. model.py +24 -0
  7. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### 1. Imports and class names setup ###
2
+ import gradio as gr
3
+ import os
4
+ import torch
5
+
6
+ from model import create_effnetb2_model
7
+ from timeit import default_timer as timer
8
+ from typing import Tuple, Dict
9
+
10
+ # Setup class names
11
+ class_names = ['AFRICAN CROWNED CRANE',
12
+ 'AFRICAN FIREFINCH',
13
+ 'ALBATROSS',
14
+ 'ALEXANDRINE PARAKEET',
15
+ 'AMERICAN AVOCET',
16
+ 'AMERICAN BITTERN',
17
+ 'AMERICAN COOT',
18
+ 'AMERICAN GOLDFINCH',
19
+ 'AMERICAN KESTREL',
20
+ 'AMERICAN PIPIT',
21
+ 'AMERICAN REDSTART',
22
+ 'ANHINGA',
23
+ 'ANNAS HUMMINGBIRD',
24
+ 'ANTBIRD',
25
+ 'ARARIPE MANAKIN',
26
+ 'ASIAN CRESTED IBIS',
27
+ 'BALD EAGLE',
28
+ 'BALD IBIS',
29
+ 'BALI STARLING',
30
+ 'BALTIMORE ORIOLE',
31
+ 'BANANAQUIT',
32
+ 'BANDED BROADBILL',
33
+ 'BANDED PITA',
34
+ 'BAR-TAILED GODWIT',
35
+ 'BARN OWL',
36
+ 'BARN SWALLOW',
37
+ 'BARRED PUFFBIRD',
38
+ 'BAY-BREASTED WARBLER',
39
+ 'BEARDED BARBET',
40
+ 'BEARDED BELLBIRD',
41
+ 'BEARDED REEDLING',
42
+ 'BELTED KINGFISHER',
43
+ 'BIRD OF PARADISE',
44
+ 'BLACK & YELLOW bROADBILL',
45
+ 'BLACK BAZA',
46
+ 'BLACK FRANCOLIN',
47
+ 'BLACK SKIMMER',
48
+ 'BLACK SWAN',
49
+ 'BLACK TAIL CRAKE',
50
+ 'BLACK THROATED BUSHTIT',
51
+ 'BLACK THROATED WARBLER',
52
+ 'BLACK VULTURE',
53
+ 'BLACK-CAPPED CHICKADEE',
54
+ 'BLACK-NECKED GREBE',
55
+ 'BLACK-THROATED SPARROW',
56
+ 'BLACKBURNIAM WARBLER',
57
+ 'BLONDE CRESTED WOODPECKER',
58
+ 'BLUE COAU',
59
+ 'BLUE GROUSE',
60
+ 'BLUE HERON',
61
+ 'BLUE THROATED TOUCANET',
62
+ 'BOBOLINK',
63
+ 'BORNEAN BRISTLEHEAD',
64
+ 'BORNEAN LEAFBIRD',
65
+ 'BORNEAN PHEASANT',
66
+ 'BRANDT CORMARANT',
67
+ 'BROWN CREPPER',
68
+ 'BROWN NOODY',
69
+ 'BROWN THRASHER',
70
+ 'BULWERS PHEASANT',
71
+ 'CACTUS WREN',
72
+ 'CALIFORNIA CONDOR',
73
+ 'CALIFORNIA GULL',
74
+ 'CALIFORNIA QUAIL',
75
+ 'CANARY',
76
+ 'CAPE GLOSSY STARLING',
77
+ 'CAPE MAY WARBLER',
78
+ 'CAPPED HERON',
79
+ 'CAPUCHINBIRD',
80
+ 'CARMINE BEE-EATER',
81
+ 'CASPIAN TERN',
82
+ 'CASSOWARY',
83
+ 'CEDAR WAXWING',
84
+ 'CERULEAN WARBLER',
85
+ 'CHARA DE COLLAR',
86
+ 'CHESTNET BELLIED EUPHONIA',
87
+ 'CHIPPING SPARROW',
88
+ 'CHUKAR PARTRIDGE',
89
+ 'CINNAMON TEAL',
90
+ 'CLARKS NUTCRACKER',
91
+ 'COCK OF THE ROCK',
92
+ 'COCKATOO',
93
+ 'COLLARED ARACARI',
94
+ 'COMMON FIRECREST',
95
+ 'COMMON GRACKLE',
96
+ 'COMMON HOUSE MARTIN',
97
+ 'COMMON LOON',
98
+ 'COMMON POORWILL',
99
+ 'COMMON STARLING',
100
+ 'COUCHS KINGBIRD',
101
+ 'CRESTED AUKLET',
102
+ 'CRESTED CARACARA',
103
+ 'CRESTED NUTHATCH',
104
+ 'CRIMSON SUNBIRD',
105
+ 'CROW',
106
+ 'CROWNED PIGEON',
107
+ 'CUBAN TODY',
108
+ 'CUBAN TROGON',
109
+ 'CURL CRESTED ARACURI',
110
+ 'D-ARNAUDS BARBET',
111
+ 'DARK EYED JUNCO',
112
+ 'DOUBLE BARRED FINCH',
113
+ 'DOUBLE BRESTED CORMARANT',
114
+ 'DOWNY WOODPECKER',
115
+ 'EASTERN BLUEBIRD',
116
+ 'EASTERN MEADOWLARK',
117
+ 'EASTERN ROSELLA',
118
+ 'EASTERN TOWEE',
119
+ 'ELEGANT TROGON',
120
+ 'ELLIOTS PHEASANT',
121
+ 'EMPEROR PENGUIN',
122
+ 'EMU',
123
+ 'ENGGANO MYNA',
124
+ 'EURASIAN GOLDEN ORIOLE',
125
+ 'EURASIAN MAGPIE',
126
+ 'EVENING GROSBEAK',
127
+ 'FAIRY BLUEBIRD',
128
+ 'FIRE TAILLED MYZORNIS',
129
+ 'FLAME TANAGER',
130
+ 'FLAMINGO',
131
+ 'FRIGATE',
132
+ 'GAMBELS QUAIL',
133
+ 'GANG GANG COCKATOO',
134
+ 'GILA WOODPECKER',
135
+ 'GILDED FLICKER',
136
+ 'GLOSSY IBIS',
137
+ 'GO AWAY BIRD',
138
+ 'GOLD WING WARBLER',
139
+ 'GOLDEN CHEEKED WARBLER',
140
+ 'GOLDEN CHLOROPHONIA',
141
+ 'GOLDEN EAGLE',
142
+ 'GOLDEN PHEASANT',
143
+ 'GOLDEN PIPIT',
144
+ 'GOULDIAN FINCH',
145
+ 'GRAY CATBIRD',
146
+ 'GRAY KINGBIRD',
147
+ 'GRAY PARTRIDGE',
148
+ 'GREAT GRAY OWL',
149
+ 'GREAT KISKADEE',
150
+ 'GREAT POTOO',
151
+ 'GREATOR SAGE GROUSE',
152
+ 'GREEN BROADBILL',
153
+ 'GREEN JAY',
154
+ 'GREEN MAGPIE',
155
+ 'GREY PLOVER',
156
+ 'GROVED BILLED ANI',
157
+ 'GUINEA TURACO',
158
+ 'GUINEAFOWL',
159
+ 'GYRFALCON',
160
+ 'HARLEQUIN DUCK',
161
+ 'HARPY EAGLE',
162
+ 'HAWAIIAN GOOSE',
163
+ 'HELMET VANGA',
164
+ 'HIMALAYAN MONAL',
165
+ 'HOATZIN',
166
+ 'HOODED MERGANSER',
167
+ 'HOOPOES',
168
+ 'HORNBILL',
169
+ 'HORNED GUAN',
170
+ 'HORNED LARK',
171
+ 'HORNED SUNGEM',
172
+ 'HOUSE FINCH',
173
+ 'HOUSE SPARROW',
174
+ 'HYACINTH MACAW',
175
+ 'IMPERIAL SHAQ',
176
+ 'INCA TERN',
177
+ 'INDIAN BUSTARD',
178
+ 'INDIAN PITTA',
179
+ 'INDIAN ROLLER',
180
+ 'INDIGO BUNTING',
181
+ 'IWI',
182
+ 'JABIRU',
183
+ 'JAVA SPARROW',
184
+ 'KAGU',
185
+ 'KAKAPO',
186
+ 'KILLDEAR',
187
+ 'KING VULTURE',
188
+ 'KIWI',
189
+ 'KOOKABURRA',
190
+ 'LARK BUNTING',
191
+ 'LAZULI BUNTING',
192
+ 'LILAC ROLLER',
193
+ 'LONG-EARED OWL',
194
+ 'MAGPIE GOOSE',
195
+ 'MALABAR HORNBILL',
196
+ 'MALACHITE KINGFISHER',
197
+ 'MALAGASY WHITE EYE',
198
+ 'MALEO',
199
+ 'MALLARD DUCK',
200
+ 'MANDRIN DUCK',
201
+ 'MANGROVE CUCKOO',
202
+ 'MARABOU STORK',
203
+ 'MASKED BOOBY',
204
+ 'MASKED LAPWING',
205
+ 'MIKADO PHEASANT',
206
+ 'MOURNING DOVE',
207
+ 'MYNA',
208
+ 'NICOBAR PIGEON',
209
+ 'NOISY FRIARBIRD',
210
+ 'NORTHERN CARDINAL',
211
+ 'NORTHERN FLICKER',
212
+ 'NORTHERN FULMAR',
213
+ 'NORTHERN GANNET',
214
+ 'NORTHERN GOSHAWK',
215
+ 'NORTHERN JACANA',
216
+ 'NORTHERN MOCKINGBIRD',
217
+ 'NORTHERN PARULA',
218
+ 'NORTHERN RED BISHOP',
219
+ 'NORTHERN SHOVELER',
220
+ 'OCELLATED TURKEY',
221
+ 'OKINAWA RAIL',
222
+ 'ORANGE BRESTED BUNTING',
223
+ 'ORIENTAL BAY OWL',
224
+ 'OSPREY',
225
+ 'OSTRICH',
226
+ 'OVENBIRD',
227
+ 'OYSTER CATCHER',
228
+ 'PAINTED BUNTIG',
229
+ 'PALILA',
230
+ 'PARADISE TANAGER',
231
+ 'PARAKETT AKULET',
232
+ 'PARUS MAJOR',
233
+ 'PATAGONIAN SIERRA FINCH',
234
+ 'PEACOCK',
235
+ 'PELICAN',
236
+ 'PEREGRINE FALCON',
237
+ 'PHILIPPINE EAGLE',
238
+ 'PINK ROBIN',
239
+ 'POMARINE JAEGER',
240
+ 'PUFFIN',
241
+ 'PURPLE FINCH',
242
+ 'PURPLE GALLINULE',
243
+ 'PURPLE MARTIN',
244
+ 'PURPLE SWAMPHEN',
245
+ 'PYGMY KINGFISHER',
246
+ 'QUETZAL',
247
+ 'RAINBOW LORIKEET',
248
+ 'RAZORBILL',
249
+ 'RED BEARDED BEE EATER',
250
+ 'RED BELLIED PITTA',
251
+ 'RED BROWED FINCH',
252
+ 'RED FACED CORMORANT',
253
+ 'RED FACED WARBLER',
254
+ 'RED FODY',
255
+ 'RED HEADED DUCK',
256
+ 'RED HEADED WOODPECKER',
257
+ 'RED HONEY CREEPER',
258
+ 'RED NAPED TROGON',
259
+ 'RED TAILED HAWK',
260
+ 'RED TAILED THRUSH',
261
+ 'RED WINGED BLACKBIRD',
262
+ 'RED WISKERED BULBUL',
263
+ 'REGENT BOWERBIRD',
264
+ 'RING-NECKED PHEASANT',
265
+ 'ROADRUNNER',
266
+ 'ROBIN',
267
+ 'ROCK DOVE',
268
+ 'ROSY FACED LOVEBIRD',
269
+ 'ROUGH LEG BUZZARD',
270
+ 'ROYAL FLYCATCHER',
271
+ 'RUBY THROATED HUMMINGBIRD',
272
+ 'RUDY KINGFISHER',
273
+ 'RUFOUS KINGFISHER',
274
+ 'RUFUOS MOTMOT',
275
+ 'SAMATRAN THRUSH',
276
+ 'SAND MARTIN',
277
+ 'SANDHILL CRANE',
278
+ 'SATYR TRAGOPAN',
279
+ 'SCARLET CROWNED FRUIT DOVE',
280
+ 'SCARLET IBIS',
281
+ 'SCARLET MACAW',
282
+ 'SCARLET TANAGER',
283
+ 'SHOEBILL',
284
+ 'SHORT BILLED DOWITCHER',
285
+ 'SMITHS LONGSPUR',
286
+ 'SNOWY EGRET',
287
+ 'SNOWY OWL',
288
+ 'SORA',
289
+ 'SPANGLED COTINGA',
290
+ 'SPLENDID WREN',
291
+ 'SPOON BILED SANDPIPER',
292
+ 'SPOONBILL',
293
+ 'SPOTTED CATBIRD',
294
+ 'SRI LANKA BLUE MAGPIE',
295
+ 'STEAMER DUCK',
296
+ 'STORK BILLED KINGFISHER',
297
+ 'STRAWBERRY FINCH',
298
+ 'STRIPED OWL',
299
+ 'STRIPPED MANAKIN',
300
+ 'STRIPPED SWALLOW',
301
+ 'SUPERB STARLING',
302
+ 'SWINHOES PHEASANT',
303
+ 'TAIWAN MAGPIE',
304
+ 'TAKAHE',
305
+ 'TASMANIAN HEN',
306
+ 'TEAL DUCK',
307
+ 'TIT MOUSE',
308
+ 'TOUCHAN',
309
+ 'TOWNSENDS WARBLER',
310
+ 'TREE SWALLOW',
311
+ 'TROPICAL KINGBIRD',
312
+ 'TRUMPTER SWAN',
313
+ 'TURKEY VULTURE',
314
+ 'TURQUOISE MOTMOT',
315
+ 'UMBRELLA BIRD',
316
+ 'VARIED THRUSH',
317
+ 'VENEZUELIAN TROUPIAL',
318
+ 'VERMILION FLYCATHER',
319
+ 'VICTORIA CROWNED PIGEON',
320
+ 'VIOLET GREEN SWALLOW',
321
+ 'VULTURINE GUINEAFOWL',
322
+ 'WALL CREAPER',
323
+ 'WATTLED CURASSOW',
324
+ 'WHIMBREL',
325
+ 'WHITE BROWED CRAKE',
326
+ 'WHITE CHEEKED TURACO',
327
+ 'WHITE NECKED RAVEN',
328
+ 'WHITE TAILED TROPIC',
329
+ 'WHITE THROATED BEE EATER',
330
+ 'WILD TURKEY',
331
+ 'WILSONS BIRD OF PARADISE',
332
+ 'WOOD DUCK',
333
+ 'YELLOW BELLIED FLOWERPECKER',
334
+ 'YELLOW CACIQUE',
335
+ 'YELLOW HEADED BLACKBIRD']
336
+
337
+ ### 2. Model and transforms perparation ###
338
+ vit, vit_transforms = create_transformer_model(
339
+ num_classes=325)
340
+
341
+ # Load save weights
342
+ vit.load_state_dict(
343
+ torch.load(
344
+ f="bird_classification_vit.pth",
345
+ map_location=torch.device("cpu") # load the model to the CPU
346
+ )
347
+ )
348
+
349
+ ### 3. Predict function ###
350
+
351
+ def predict(img) -> Tuple[Dict, float]:
352
+ # Start a timer
353
+ start_time = timer()
354
+
355
+ # Transform the input image for use with EffNetB2
356
+ img = vit_transforms(img).unsqueeze(0) # unsqueeze = add batch dimension on 0th index
357
+
358
+ # Put model into eval mode, make prediction
359
+ vit.eval()
360
+ with torch.inference_mode():
361
+ # Pass transformed image through the model and turn the prediction logits into probaiblities
362
+ pred_probs = torch.softmax(effnetb2(img), dim=1)
363
+
364
+ # Create a prediction label and prediction probability dictionary
365
+ pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
366
+
367
+ # Calculate pred time
368
+ end_time = timer()
369
+ pred_time = round(end_time - start_time, 4)
370
+
371
+ # Return pred dict and pred time
372
+ return pred_labels_and_probs, pred_time
373
+
374
+ ### 4. Gradio app ###
375
+
376
+ # Create title, description and article
377
+ title = "Bird Species Classification"
378
+ description = "A [Transformer] computer vision model to classify Bird species.."
379
+ article = "Created by Chidera Stanley [Transformer using PyTorch]"
380
+
381
+ # Create example list
382
+ example_list = [["examples/" + example] for example in os.listdir("examples")]
383
+
384
+ # Create the Gradio demo
385
+ app = gr.Interface(fn=predict, # maps inputs to outputs
386
+ inputs=gr.Image(type="pil"),
387
+ outputs=[gr.Label(num_top_classes=3, label="Predictions"),
388
+ gr.Number(label="Prediction time (s)")],
389
+ examples=example_list,
390
+ title=title,
391
+ description=description,
392
+ article=article)
393
+
394
+ # Launch the demo!
395
+ app.launch()
bird_classification_vit.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d052c78bae58442b8a357ffcad5ca504e1393934b66b88bae71672aff542287f
3
+ size 344256753
examples/1.jpg ADDED
examples/3.jpg ADDED
examples/5.jpg ADDED
model.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision
3
+
4
+ from torch import nn
5
+
6
+ def create_transformer_model(num_classes:int=325, # default output classes = 3 (pizza, steak, sushi)
7
+ seed:int=42):
8
+ # 1, 2, 3 Create EffNetB2 pretrained weights, transforms and model
9
+ weights = torchvision.models.ViT_B_16_Weights.DEFAULT
10
+ transforms = weights.transforms()
11
+ model = torchvision.models.vit_b_16(weights=weights)
12
+
13
+ # 4. Freeze all layers in the base model
14
+ for param in model.parameters():
15
+ param.requires_grad = False
16
+
17
+ # 5. Change classifier head with random seed for reproducibility
18
+ torch.manual_seed(seed)
19
+ model.classifier = nn.Sequential(
20
+ nn.Dropout(p=0.3, inplace=True),
21
+ nn.Linear(in_features=768, out_features=num_classes)
22
+ )
23
+
24
+ return model, transforms
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch==1.12.0
2
+ torchvision==0.13.0
3
+ gradio==3.1.4