Spaces:
Build error
Build error
Upload 7 files
Browse files- app.py +395 -0
- bird_classification_vit.pth +3 -0
- examples/1.jpg +0 -0
- examples/3.jpg +0 -0
- examples/5.jpg +0 -0
- model.py +24 -0
- requirements.txt +3 -0
app.py
ADDED
@@ -0,0 +1,395 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
### 1. Imports and class names setup ###
|
2 |
+
import gradio as gr
|
3 |
+
import os
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from model import create_effnetb2_model
|
7 |
+
from timeit import default_timer as timer
|
8 |
+
from typing import Tuple, Dict
|
9 |
+
|
10 |
+
# Setup class names
|
11 |
+
class_names = ['AFRICAN CROWNED CRANE',
|
12 |
+
'AFRICAN FIREFINCH',
|
13 |
+
'ALBATROSS',
|
14 |
+
'ALEXANDRINE PARAKEET',
|
15 |
+
'AMERICAN AVOCET',
|
16 |
+
'AMERICAN BITTERN',
|
17 |
+
'AMERICAN COOT',
|
18 |
+
'AMERICAN GOLDFINCH',
|
19 |
+
'AMERICAN KESTREL',
|
20 |
+
'AMERICAN PIPIT',
|
21 |
+
'AMERICAN REDSTART',
|
22 |
+
'ANHINGA',
|
23 |
+
'ANNAS HUMMINGBIRD',
|
24 |
+
'ANTBIRD',
|
25 |
+
'ARARIPE MANAKIN',
|
26 |
+
'ASIAN CRESTED IBIS',
|
27 |
+
'BALD EAGLE',
|
28 |
+
'BALD IBIS',
|
29 |
+
'BALI STARLING',
|
30 |
+
'BALTIMORE ORIOLE',
|
31 |
+
'BANANAQUIT',
|
32 |
+
'BANDED BROADBILL',
|
33 |
+
'BANDED PITA',
|
34 |
+
'BAR-TAILED GODWIT',
|
35 |
+
'BARN OWL',
|
36 |
+
'BARN SWALLOW',
|
37 |
+
'BARRED PUFFBIRD',
|
38 |
+
'BAY-BREASTED WARBLER',
|
39 |
+
'BEARDED BARBET',
|
40 |
+
'BEARDED BELLBIRD',
|
41 |
+
'BEARDED REEDLING',
|
42 |
+
'BELTED KINGFISHER',
|
43 |
+
'BIRD OF PARADISE',
|
44 |
+
'BLACK & YELLOW bROADBILL',
|
45 |
+
'BLACK BAZA',
|
46 |
+
'BLACK FRANCOLIN',
|
47 |
+
'BLACK SKIMMER',
|
48 |
+
'BLACK SWAN',
|
49 |
+
'BLACK TAIL CRAKE',
|
50 |
+
'BLACK THROATED BUSHTIT',
|
51 |
+
'BLACK THROATED WARBLER',
|
52 |
+
'BLACK VULTURE',
|
53 |
+
'BLACK-CAPPED CHICKADEE',
|
54 |
+
'BLACK-NECKED GREBE',
|
55 |
+
'BLACK-THROATED SPARROW',
|
56 |
+
'BLACKBURNIAM WARBLER',
|
57 |
+
'BLONDE CRESTED WOODPECKER',
|
58 |
+
'BLUE COAU',
|
59 |
+
'BLUE GROUSE',
|
60 |
+
'BLUE HERON',
|
61 |
+
'BLUE THROATED TOUCANET',
|
62 |
+
'BOBOLINK',
|
63 |
+
'BORNEAN BRISTLEHEAD',
|
64 |
+
'BORNEAN LEAFBIRD',
|
65 |
+
'BORNEAN PHEASANT',
|
66 |
+
'BRANDT CORMARANT',
|
67 |
+
'BROWN CREPPER',
|
68 |
+
'BROWN NOODY',
|
69 |
+
'BROWN THRASHER',
|
70 |
+
'BULWERS PHEASANT',
|
71 |
+
'CACTUS WREN',
|
72 |
+
'CALIFORNIA CONDOR',
|
73 |
+
'CALIFORNIA GULL',
|
74 |
+
'CALIFORNIA QUAIL',
|
75 |
+
'CANARY',
|
76 |
+
'CAPE GLOSSY STARLING',
|
77 |
+
'CAPE MAY WARBLER',
|
78 |
+
'CAPPED HERON',
|
79 |
+
'CAPUCHINBIRD',
|
80 |
+
'CARMINE BEE-EATER',
|
81 |
+
'CASPIAN TERN',
|
82 |
+
'CASSOWARY',
|
83 |
+
'CEDAR WAXWING',
|
84 |
+
'CERULEAN WARBLER',
|
85 |
+
'CHARA DE COLLAR',
|
86 |
+
'CHESTNET BELLIED EUPHONIA',
|
87 |
+
'CHIPPING SPARROW',
|
88 |
+
'CHUKAR PARTRIDGE',
|
89 |
+
'CINNAMON TEAL',
|
90 |
+
'CLARKS NUTCRACKER',
|
91 |
+
'COCK OF THE ROCK',
|
92 |
+
'COCKATOO',
|
93 |
+
'COLLARED ARACARI',
|
94 |
+
'COMMON FIRECREST',
|
95 |
+
'COMMON GRACKLE',
|
96 |
+
'COMMON HOUSE MARTIN',
|
97 |
+
'COMMON LOON',
|
98 |
+
'COMMON POORWILL',
|
99 |
+
'COMMON STARLING',
|
100 |
+
'COUCHS KINGBIRD',
|
101 |
+
'CRESTED AUKLET',
|
102 |
+
'CRESTED CARACARA',
|
103 |
+
'CRESTED NUTHATCH',
|
104 |
+
'CRIMSON SUNBIRD',
|
105 |
+
'CROW',
|
106 |
+
'CROWNED PIGEON',
|
107 |
+
'CUBAN TODY',
|
108 |
+
'CUBAN TROGON',
|
109 |
+
'CURL CRESTED ARACURI',
|
110 |
+
'D-ARNAUDS BARBET',
|
111 |
+
'DARK EYED JUNCO',
|
112 |
+
'DOUBLE BARRED FINCH',
|
113 |
+
'DOUBLE BRESTED CORMARANT',
|
114 |
+
'DOWNY WOODPECKER',
|
115 |
+
'EASTERN BLUEBIRD',
|
116 |
+
'EASTERN MEADOWLARK',
|
117 |
+
'EASTERN ROSELLA',
|
118 |
+
'EASTERN TOWEE',
|
119 |
+
'ELEGANT TROGON',
|
120 |
+
'ELLIOTS PHEASANT',
|
121 |
+
'EMPEROR PENGUIN',
|
122 |
+
'EMU',
|
123 |
+
'ENGGANO MYNA',
|
124 |
+
'EURASIAN GOLDEN ORIOLE',
|
125 |
+
'EURASIAN MAGPIE',
|
126 |
+
'EVENING GROSBEAK',
|
127 |
+
'FAIRY BLUEBIRD',
|
128 |
+
'FIRE TAILLED MYZORNIS',
|
129 |
+
'FLAME TANAGER',
|
130 |
+
'FLAMINGO',
|
131 |
+
'FRIGATE',
|
132 |
+
'GAMBELS QUAIL',
|
133 |
+
'GANG GANG COCKATOO',
|
134 |
+
'GILA WOODPECKER',
|
135 |
+
'GILDED FLICKER',
|
136 |
+
'GLOSSY IBIS',
|
137 |
+
'GO AWAY BIRD',
|
138 |
+
'GOLD WING WARBLER',
|
139 |
+
'GOLDEN CHEEKED WARBLER',
|
140 |
+
'GOLDEN CHLOROPHONIA',
|
141 |
+
'GOLDEN EAGLE',
|
142 |
+
'GOLDEN PHEASANT',
|
143 |
+
'GOLDEN PIPIT',
|
144 |
+
'GOULDIAN FINCH',
|
145 |
+
'GRAY CATBIRD',
|
146 |
+
'GRAY KINGBIRD',
|
147 |
+
'GRAY PARTRIDGE',
|
148 |
+
'GREAT GRAY OWL',
|
149 |
+
'GREAT KISKADEE',
|
150 |
+
'GREAT POTOO',
|
151 |
+
'GREATOR SAGE GROUSE',
|
152 |
+
'GREEN BROADBILL',
|
153 |
+
'GREEN JAY',
|
154 |
+
'GREEN MAGPIE',
|
155 |
+
'GREY PLOVER',
|
156 |
+
'GROVED BILLED ANI',
|
157 |
+
'GUINEA TURACO',
|
158 |
+
'GUINEAFOWL',
|
159 |
+
'GYRFALCON',
|
160 |
+
'HARLEQUIN DUCK',
|
161 |
+
'HARPY EAGLE',
|
162 |
+
'HAWAIIAN GOOSE',
|
163 |
+
'HELMET VANGA',
|
164 |
+
'HIMALAYAN MONAL',
|
165 |
+
'HOATZIN',
|
166 |
+
'HOODED MERGANSER',
|
167 |
+
'HOOPOES',
|
168 |
+
'HORNBILL',
|
169 |
+
'HORNED GUAN',
|
170 |
+
'HORNED LARK',
|
171 |
+
'HORNED SUNGEM',
|
172 |
+
'HOUSE FINCH',
|
173 |
+
'HOUSE SPARROW',
|
174 |
+
'HYACINTH MACAW',
|
175 |
+
'IMPERIAL SHAQ',
|
176 |
+
'INCA TERN',
|
177 |
+
'INDIAN BUSTARD',
|
178 |
+
'INDIAN PITTA',
|
179 |
+
'INDIAN ROLLER',
|
180 |
+
'INDIGO BUNTING',
|
181 |
+
'IWI',
|
182 |
+
'JABIRU',
|
183 |
+
'JAVA SPARROW',
|
184 |
+
'KAGU',
|
185 |
+
'KAKAPO',
|
186 |
+
'KILLDEAR',
|
187 |
+
'KING VULTURE',
|
188 |
+
'KIWI',
|
189 |
+
'KOOKABURRA',
|
190 |
+
'LARK BUNTING',
|
191 |
+
'LAZULI BUNTING',
|
192 |
+
'LILAC ROLLER',
|
193 |
+
'LONG-EARED OWL',
|
194 |
+
'MAGPIE GOOSE',
|
195 |
+
'MALABAR HORNBILL',
|
196 |
+
'MALACHITE KINGFISHER',
|
197 |
+
'MALAGASY WHITE EYE',
|
198 |
+
'MALEO',
|
199 |
+
'MALLARD DUCK',
|
200 |
+
'MANDRIN DUCK',
|
201 |
+
'MANGROVE CUCKOO',
|
202 |
+
'MARABOU STORK',
|
203 |
+
'MASKED BOOBY',
|
204 |
+
'MASKED LAPWING',
|
205 |
+
'MIKADO PHEASANT',
|
206 |
+
'MOURNING DOVE',
|
207 |
+
'MYNA',
|
208 |
+
'NICOBAR PIGEON',
|
209 |
+
'NOISY FRIARBIRD',
|
210 |
+
'NORTHERN CARDINAL',
|
211 |
+
'NORTHERN FLICKER',
|
212 |
+
'NORTHERN FULMAR',
|
213 |
+
'NORTHERN GANNET',
|
214 |
+
'NORTHERN GOSHAWK',
|
215 |
+
'NORTHERN JACANA',
|
216 |
+
'NORTHERN MOCKINGBIRD',
|
217 |
+
'NORTHERN PARULA',
|
218 |
+
'NORTHERN RED BISHOP',
|
219 |
+
'NORTHERN SHOVELER',
|
220 |
+
'OCELLATED TURKEY',
|
221 |
+
'OKINAWA RAIL',
|
222 |
+
'ORANGE BRESTED BUNTING',
|
223 |
+
'ORIENTAL BAY OWL',
|
224 |
+
'OSPREY',
|
225 |
+
'OSTRICH',
|
226 |
+
'OVENBIRD',
|
227 |
+
'OYSTER CATCHER',
|
228 |
+
'PAINTED BUNTIG',
|
229 |
+
'PALILA',
|
230 |
+
'PARADISE TANAGER',
|
231 |
+
'PARAKETT AKULET',
|
232 |
+
'PARUS MAJOR',
|
233 |
+
'PATAGONIAN SIERRA FINCH',
|
234 |
+
'PEACOCK',
|
235 |
+
'PELICAN',
|
236 |
+
'PEREGRINE FALCON',
|
237 |
+
'PHILIPPINE EAGLE',
|
238 |
+
'PINK ROBIN',
|
239 |
+
'POMARINE JAEGER',
|
240 |
+
'PUFFIN',
|
241 |
+
'PURPLE FINCH',
|
242 |
+
'PURPLE GALLINULE',
|
243 |
+
'PURPLE MARTIN',
|
244 |
+
'PURPLE SWAMPHEN',
|
245 |
+
'PYGMY KINGFISHER',
|
246 |
+
'QUETZAL',
|
247 |
+
'RAINBOW LORIKEET',
|
248 |
+
'RAZORBILL',
|
249 |
+
'RED BEARDED BEE EATER',
|
250 |
+
'RED BELLIED PITTA',
|
251 |
+
'RED BROWED FINCH',
|
252 |
+
'RED FACED CORMORANT',
|
253 |
+
'RED FACED WARBLER',
|
254 |
+
'RED FODY',
|
255 |
+
'RED HEADED DUCK',
|
256 |
+
'RED HEADED WOODPECKER',
|
257 |
+
'RED HONEY CREEPER',
|
258 |
+
'RED NAPED TROGON',
|
259 |
+
'RED TAILED HAWK',
|
260 |
+
'RED TAILED THRUSH',
|
261 |
+
'RED WINGED BLACKBIRD',
|
262 |
+
'RED WISKERED BULBUL',
|
263 |
+
'REGENT BOWERBIRD',
|
264 |
+
'RING-NECKED PHEASANT',
|
265 |
+
'ROADRUNNER',
|
266 |
+
'ROBIN',
|
267 |
+
'ROCK DOVE',
|
268 |
+
'ROSY FACED LOVEBIRD',
|
269 |
+
'ROUGH LEG BUZZARD',
|
270 |
+
'ROYAL FLYCATCHER',
|
271 |
+
'RUBY THROATED HUMMINGBIRD',
|
272 |
+
'RUDY KINGFISHER',
|
273 |
+
'RUFOUS KINGFISHER',
|
274 |
+
'RUFUOS MOTMOT',
|
275 |
+
'SAMATRAN THRUSH',
|
276 |
+
'SAND MARTIN',
|
277 |
+
'SANDHILL CRANE',
|
278 |
+
'SATYR TRAGOPAN',
|
279 |
+
'SCARLET CROWNED FRUIT DOVE',
|
280 |
+
'SCARLET IBIS',
|
281 |
+
'SCARLET MACAW',
|
282 |
+
'SCARLET TANAGER',
|
283 |
+
'SHOEBILL',
|
284 |
+
'SHORT BILLED DOWITCHER',
|
285 |
+
'SMITHS LONGSPUR',
|
286 |
+
'SNOWY EGRET',
|
287 |
+
'SNOWY OWL',
|
288 |
+
'SORA',
|
289 |
+
'SPANGLED COTINGA',
|
290 |
+
'SPLENDID WREN',
|
291 |
+
'SPOON BILED SANDPIPER',
|
292 |
+
'SPOONBILL',
|
293 |
+
'SPOTTED CATBIRD',
|
294 |
+
'SRI LANKA BLUE MAGPIE',
|
295 |
+
'STEAMER DUCK',
|
296 |
+
'STORK BILLED KINGFISHER',
|
297 |
+
'STRAWBERRY FINCH',
|
298 |
+
'STRIPED OWL',
|
299 |
+
'STRIPPED MANAKIN',
|
300 |
+
'STRIPPED SWALLOW',
|
301 |
+
'SUPERB STARLING',
|
302 |
+
'SWINHOES PHEASANT',
|
303 |
+
'TAIWAN MAGPIE',
|
304 |
+
'TAKAHE',
|
305 |
+
'TASMANIAN HEN',
|
306 |
+
'TEAL DUCK',
|
307 |
+
'TIT MOUSE',
|
308 |
+
'TOUCHAN',
|
309 |
+
'TOWNSENDS WARBLER',
|
310 |
+
'TREE SWALLOW',
|
311 |
+
'TROPICAL KINGBIRD',
|
312 |
+
'TRUMPTER SWAN',
|
313 |
+
'TURKEY VULTURE',
|
314 |
+
'TURQUOISE MOTMOT',
|
315 |
+
'UMBRELLA BIRD',
|
316 |
+
'VARIED THRUSH',
|
317 |
+
'VENEZUELIAN TROUPIAL',
|
318 |
+
'VERMILION FLYCATHER',
|
319 |
+
'VICTORIA CROWNED PIGEON',
|
320 |
+
'VIOLET GREEN SWALLOW',
|
321 |
+
'VULTURINE GUINEAFOWL',
|
322 |
+
'WALL CREAPER',
|
323 |
+
'WATTLED CURASSOW',
|
324 |
+
'WHIMBREL',
|
325 |
+
'WHITE BROWED CRAKE',
|
326 |
+
'WHITE CHEEKED TURACO',
|
327 |
+
'WHITE NECKED RAVEN',
|
328 |
+
'WHITE TAILED TROPIC',
|
329 |
+
'WHITE THROATED BEE EATER',
|
330 |
+
'WILD TURKEY',
|
331 |
+
'WILSONS BIRD OF PARADISE',
|
332 |
+
'WOOD DUCK',
|
333 |
+
'YELLOW BELLIED FLOWERPECKER',
|
334 |
+
'YELLOW CACIQUE',
|
335 |
+
'YELLOW HEADED BLACKBIRD']
|
336 |
+
|
337 |
+
### 2. Model and transforms perparation ###
|
338 |
+
vit, vit_transforms = create_transformer_model(
|
339 |
+
num_classes=325)
|
340 |
+
|
341 |
+
# Load save weights
|
342 |
+
vit.load_state_dict(
|
343 |
+
torch.load(
|
344 |
+
f="bird_classification_vit.pth",
|
345 |
+
map_location=torch.device("cpu") # load the model to the CPU
|
346 |
+
)
|
347 |
+
)
|
348 |
+
|
349 |
+
### 3. Predict function ###
|
350 |
+
|
351 |
+
def predict(img) -> Tuple[Dict, float]:
|
352 |
+
# Start a timer
|
353 |
+
start_time = timer()
|
354 |
+
|
355 |
+
# Transform the input image for use with EffNetB2
|
356 |
+
img = vit_transforms(img).unsqueeze(0) # unsqueeze = add batch dimension on 0th index
|
357 |
+
|
358 |
+
# Put model into eval mode, make prediction
|
359 |
+
vit.eval()
|
360 |
+
with torch.inference_mode():
|
361 |
+
# Pass transformed image through the model and turn the prediction logits into probaiblities
|
362 |
+
pred_probs = torch.softmax(effnetb2(img), dim=1)
|
363 |
+
|
364 |
+
# Create a prediction label and prediction probability dictionary
|
365 |
+
pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
|
366 |
+
|
367 |
+
# Calculate pred time
|
368 |
+
end_time = timer()
|
369 |
+
pred_time = round(end_time - start_time, 4)
|
370 |
+
|
371 |
+
# Return pred dict and pred time
|
372 |
+
return pred_labels_and_probs, pred_time
|
373 |
+
|
374 |
+
### 4. Gradio app ###
|
375 |
+
|
376 |
+
# Create title, description and article
|
377 |
+
title = "Bird Species Classification"
|
378 |
+
description = "A [Transformer] computer vision model to classify Bird species.."
|
379 |
+
article = "Created by Chidera Stanley [Transformer using PyTorch]"
|
380 |
+
|
381 |
+
# Create example list
|
382 |
+
example_list = [["examples/" + example] for example in os.listdir("examples")]
|
383 |
+
|
384 |
+
# Create the Gradio demo
|
385 |
+
app = gr.Interface(fn=predict, # maps inputs to outputs
|
386 |
+
inputs=gr.Image(type="pil"),
|
387 |
+
outputs=[gr.Label(num_top_classes=3, label="Predictions"),
|
388 |
+
gr.Number(label="Prediction time (s)")],
|
389 |
+
examples=example_list,
|
390 |
+
title=title,
|
391 |
+
description=description,
|
392 |
+
article=article)
|
393 |
+
|
394 |
+
# Launch the demo!
|
395 |
+
app.launch()
|
bird_classification_vit.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d052c78bae58442b8a357ffcad5ca504e1393934b66b88bae71672aff542287f
|
3 |
+
size 344256753
|
examples/1.jpg
ADDED
examples/3.jpg
ADDED
examples/5.jpg
ADDED
model.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torchvision
|
3 |
+
|
4 |
+
from torch import nn
|
5 |
+
|
6 |
+
def create_transformer_model(num_classes:int=325, # default output classes = 3 (pizza, steak, sushi)
|
7 |
+
seed:int=42):
|
8 |
+
# 1, 2, 3 Create EffNetB2 pretrained weights, transforms and model
|
9 |
+
weights = torchvision.models.ViT_B_16_Weights.DEFAULT
|
10 |
+
transforms = weights.transforms()
|
11 |
+
model = torchvision.models.vit_b_16(weights=weights)
|
12 |
+
|
13 |
+
# 4. Freeze all layers in the base model
|
14 |
+
for param in model.parameters():
|
15 |
+
param.requires_grad = False
|
16 |
+
|
17 |
+
# 5. Change classifier head with random seed for reproducibility
|
18 |
+
torch.manual_seed(seed)
|
19 |
+
model.classifier = nn.Sequential(
|
20 |
+
nn.Dropout(p=0.3, inplace=True),
|
21 |
+
nn.Linear(in_features=768, out_features=num_classes)
|
22 |
+
)
|
23 |
+
|
24 |
+
return model, transforms
|
requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
torch==1.12.0
|
2 |
+
torchvision==0.13.0
|
3 |
+
gradio==3.1.4
|