nikhiljais committed on
Commit
6d39f1a
·
2 Parent(s): 7a01e86 a2245ca

Merge master into main and resolve conflicts

Browse files
Files changed (6) hide show
  1. .gitattributes +1 -0
  2. app.py +133 -0
  3. best_model.pth +3 -0
  4. imagenet_classes.json +1000 -0
  5. requirements.txt +7 -0
  6. train_resnet50_local.py +637 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ best_model.pth filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ import torchvision.transforms as transforms
5
+ from PIL import Image
6
+ import json
7
+ import os
8
+ from train_resnet50 import ResNet, Bottleneck
9
+
10
+ # Load ImageNet class labels
11
# Placeholder labels used whenever the real label file cannot be loaded.
# This is a list (not a dict keyed by strings) because predict() indexes
# class_labels with an integer class id, exactly like the list stored in
# imagenet_classes.json.
_FALLBACK_LABELS = [f"class_{i}" for i in range(1000)]

try:
    with open("imagenet_classes.json", "r", encoding="utf-8") as f:
        class_labels = json.load(f)
    print(f"Loaded {len(class_labels)} class labels")
except FileNotFoundError:
    print("Warning: imagenet_classes.json not found, creating simplified labels")
    class_labels = list(_FALLBACK_LABELS)
except json.JSONDecodeError:
    print("Warning: Error parsing imagenet_classes.json, using simplified labels")
    class_labels = list(_FALLBACK_LABELS)
except Exception as e:
    # Best-effort: the app should still start with generic labels.
    print(f"Warning: Unexpected error loading class labels: {e}")
    class_labels = list(_FALLBACK_LABELS)
25
+
26
+
27
def create_model():
    """Return a freshly constructed ResNet built from Bottleneck blocks.

    The stage depths [3, 4, 6, 3] are the ResNet50 layout; the weights are
    whatever the ResNet constructor initialises them to.
    """
    stage_depths = [3, 4, 6, 3]
    return ResNet(Bottleneck, stage_depths)
30
+
31
+
32
def load_model(model_path):
    """Load the trained model from a checkpoint, with a pretrained fallback.

    Args:
        model_path: Path to a checkpoint whose "model_state_dict" entry holds
            the weights (optionally saved from a DataParallel/DDP wrapper).

    Returns:
        A model in eval mode. If anything goes wrong while reading or applying
        the checkpoint, a pretrained ResNet50 fetched through torch.hub is
        returned instead.
    """
    model = create_model()
    wrapper_prefix = "module."
    try:
        # NOTE(review): torch.load unpickles the file, so only point this at
        # checkpoints from a trusted source.
        checkpoint = torch.load(model_path, map_location="cpu")

        # DataParallel/DDP prepend "module." to every parameter name. Strip
        # only that leading prefix; str.replace would also remove the
        # substring from the middle of a key and corrupt it.
        state_dict = checkpoint["model_state_dict"]
        new_state_dict = {
            (k[len(wrapper_prefix):] if k.startswith(wrapper_prefix) else k): v
            for k, v in state_dict.items()
        }

        model.load_state_dict(new_state_dict)
        model.eval()
        print("Model loaded successfully!")
        return model
    except Exception as e:
        # Deliberate best-effort: keep the demo usable without the checkpoint.
        print(f"Error loading model: {e}")
        print("Loading pretrained ResNet50 as fallback...")
        model = torch.hub.load("pytorch/vision:v0.10.0", "resnet50", pretrained=True)
        model.eval()
        return model
54
+
55
+
56
+ # Preprocessing transform
57
# Inference preprocessing: resize, centre-crop to 224x224, convert to a
# tensor, then normalise with the ImageNet channel statistics.
_preprocess_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform = transforms.Compose(_preprocess_steps)

# Lazily populated model cache; predict() fills it on first use.
global_model = None
68
+
69
+
70
def _lookup_label(class_idx):
    """Return the label for class_idx, or "class_<idx>" when it is unknown."""
    placeholder = f"class_{class_idx}"
    if isinstance(class_labels, dict):
        # The fallback label table may be keyed by the stringified index.
        return class_labels.get(str(class_idx), class_labels.get(class_idx, placeholder))
    if 0 <= class_idx < len(class_labels):
        return class_labels[class_idx]
    return placeholder


def predict(image):
    """Classify an image and return its top predictions.

    Args:
        image: The value delivered by the Gradio image input; it is expected
            to be an array accepted by PIL.Image.fromarray, or None when no
            image was supplied.

    Returns:
        A list of up to five dicts with "label", "class_id" and "confidence"
        keys, ordered by decreasing confidence, or None when there is no
        input or when loading/prediction fails.
    """
    global global_model

    # Nothing to classify: return before paying for a model load.
    if image is None:
        return None

    # Load model only once
    if global_model is None:
        try:
            global_model = load_model("best_model.pth")
        except Exception as e:
            print(f"Error loading model: {e}")
            return None

    # Defined before the try block so the error handler can always read it.
    top_catid = None
    try:
        # Force three channels: RGBA or grayscale inputs would otherwise
        # break the 3-channel normalisation step.
        pil_image = Image.fromarray(image).convert("RGB")
        batch = transform(pil_image).unsqueeze(0)

        with torch.no_grad():
            outputs = global_model(batch)
            probabilities = torch.nn.functional.softmax(outputs[0], dim=0)

        # Never ask topk for more entries than the model has classes.
        k = min(5, probabilities.numel())
        top_prob, top_catid = torch.topk(probabilities, k)

        return [
            {
                "label": _lookup_label(class_idx),
                "class_id": class_idx,
                "confidence": float(prob),
            }
            for prob, class_idx in zip(top_prob.tolist(), top_catid.tolist())
        ]
    except Exception as e:
        print(f"Error during prediction: {e}")
        if top_catid is not None:
            print(f"Class indices: {top_catid.tolist()}")  # Debug info
        return None
120
+
121
+
122
+ # Create Gradio interface
123
# Gradio UI: one image input, JSON output produced by predict().
_interface_options = dict(
    fn=predict,
    inputs=gr.Image(),
    outputs=gr.JSON(),
    title="ResNet50 ImageNet Classifier",
    description="Upload an image to get top-5 predictions from our trained ResNet50 model.",
)
iface = gr.Interface(**_interface_options)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch(share=True)
best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ccb415799634943df1fa09c3e9a6cb9d2c51152db86ba59ed55fd108cb564ea7
3
+ size 204827231
imagenet_classes.json ADDED
@@ -0,0 +1,1000 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ["tench",
2
+ "goldfish",
3
+ "great white shark",
4
+ "tiger shark",
5
+ "hammerhead shark",
6
+ "electric ray",
7
+ "stingray",
8
+ "cock",
9
+ "hen",
10
+ "ostrich",
11
+ "brambling",
12
+ "goldfinch",
13
+ "house finch",
14
+ "junco",
15
+ "indigo bunting",
16
+ "American robin",
17
+ "bulbul",
18
+ "jay",
19
+ "magpie",
20
+ "chickadee",
21
+ "American dipper",
22
+ "kite",
23
+ "bald eagle",
24
+ "vulture",
25
+ "great grey owl",
26
+ "fire salamander",
27
+ "smooth newt",
28
+ "newt",
29
+ "spotted salamander",
30
+ "axolotl",
31
+ "American bullfrog",
32
+ "tree frog",
33
+ "tailed frog",
34
+ "loggerhead sea turtle",
35
+ "leatherback sea turtle",
36
+ "mud turtle",
37
+ "terrapin",
38
+ "box turtle",
39
+ "banded gecko",
40
+ "green iguana",
41
+ "Carolina anole",
42
+ "desert grassland whiptail lizard",
43
+ "agama",
44
+ "frilled-necked lizard",
45
+ "alligator lizard",
46
+ "Gila monster",
47
+ "European green lizard",
48
+ "chameleon",
49
+ "Komodo dragon",
50
+ "Nile crocodile",
51
+ "American alligator",
52
+ "triceratops",
53
+ "worm snake",
54
+ "ring-necked snake",
55
+ "eastern hog-nosed snake",
56
+ "smooth green snake",
57
+ "kingsnake",
58
+ "garter snake",
59
+ "water snake",
60
+ "vine snake",
61
+ "night snake",
62
+ "boa constrictor",
63
+ "African rock python",
64
+ "Indian cobra",
65
+ "green mamba",
66
+ "sea snake",
67
+ "Saharan horned viper",
68
+ "eastern diamondback rattlesnake",
69
+ "sidewinder",
70
+ "trilobite",
71
+ "harvestman",
72
+ "scorpion",
73
+ "yellow garden spider",
74
+ "barn spider",
75
+ "European garden spider",
76
+ "southern black widow",
77
+ "tarantula",
78
+ "wolf spider",
79
+ "tick",
80
+ "centipede",
81
+ "black grouse",
82
+ "ptarmigan",
83
+ "ruffed grouse",
84
+ "prairie grouse",
85
+ "peacock",
86
+ "quail",
87
+ "partridge",
88
+ "grey parrot",
89
+ "macaw",
90
+ "sulphur-crested cockatoo",
91
+ "lorikeet",
92
+ "coucal",
93
+ "bee eater",
94
+ "hornbill",
95
+ "hummingbird",
96
+ "jacamar",
97
+ "toucan",
98
+ "duck",
99
+ "red-breasted merganser",
100
+ "goose",
101
+ "black swan",
102
+ "tusker",
103
+ "echidna",
104
+ "platypus",
105
+ "wallaby",
106
+ "koala",
107
+ "wombat",
108
+ "jellyfish",
109
+ "sea anemone",
110
+ "brain coral",
111
+ "flatworm",
112
+ "nematode",
113
+ "conch",
114
+ "snail",
115
+ "slug",
116
+ "sea slug",
117
+ "chiton",
118
+ "chambered nautilus",
119
+ "Dungeness crab",
120
+ "rock crab",
121
+ "fiddler crab",
122
+ "red king crab",
123
+ "American lobster",
124
+ "spiny lobster",
125
+ "crayfish",
126
+ "hermit crab",
127
+ "isopod",
128
+ "white stork",
129
+ "black stork",
130
+ "spoonbill",
131
+ "flamingo",
132
+ "little blue heron",
133
+ "great egret",
134
+ "bittern",
135
+ "crane (bird)",
136
+ "limpkin",
137
+ "common gallinule",
138
+ "American coot",
139
+ "bustard",
140
+ "ruddy turnstone",
141
+ "dunlin",
142
+ "common redshank",
143
+ "dowitcher",
144
+ "oystercatcher",
145
+ "pelican",
146
+ "king penguin",
147
+ "albatross",
148
+ "grey whale",
149
+ "killer whale",
150
+ "dugong",
151
+ "sea lion",
152
+ "Chihuahua",
153
+ "Japanese Chin",
154
+ "Maltese",
155
+ "Pekingese",
156
+ "Shih Tzu",
157
+ "King Charles Spaniel",
158
+ "Papillon",
159
+ "toy terrier",
160
+ "Rhodesian Ridgeback",
161
+ "Afghan Hound",
162
+ "Basset Hound",
163
+ "Beagle",
164
+ "Bloodhound",
165
+ "Bluetick Coonhound",
166
+ "Black and Tan Coonhound",
167
+ "Treeing Walker Coonhound",
168
+ "English foxhound",
169
+ "Redbone Coonhound",
170
+ "borzoi",
171
+ "Irish Wolfhound",
172
+ "Italian Greyhound",
173
+ "Whippet",
174
+ "Ibizan Hound",
175
+ "Norwegian Elkhound",
176
+ "Otterhound",
177
+ "Saluki",
178
+ "Scottish Deerhound",
179
+ "Weimaraner",
180
+ "Staffordshire Bull Terrier",
181
+ "American Staffordshire Terrier",
182
+ "Bedlington Terrier",
183
+ "Border Terrier",
184
+ "Kerry Blue Terrier",
185
+ "Irish Terrier",
186
+ "Norfolk Terrier",
187
+ "Norwich Terrier",
188
+ "Yorkshire Terrier",
189
+ "Wire Fox Terrier",
190
+ "Lakeland Terrier",
191
+ "Sealyham Terrier",
192
+ "Airedale Terrier",
193
+ "Cairn Terrier",
194
+ "Australian Terrier",
195
+ "Dandie Dinmont Terrier",
196
+ "Boston Terrier",
197
+ "Miniature Schnauzer",
198
+ "Giant Schnauzer",
199
+ "Standard Schnauzer",
200
+ "Scottish Terrier",
201
+ "Tibetan Terrier",
202
+ "Australian Silky Terrier",
203
+ "Soft-coated Wheaten Terrier",
204
+ "West Highland White Terrier",
205
+ "Lhasa Apso",
206
+ "Flat-Coated Retriever",
207
+ "Curly-coated Retriever",
208
+ "Golden Retriever",
209
+ "Labrador Retriever",
210
+ "Chesapeake Bay Retriever",
211
+ "German Shorthaired Pointer",
212
+ "Vizsla",
213
+ "English Setter",
214
+ "Irish Setter",
215
+ "Gordon Setter",
216
+ "Brittany Spaniel",
217
+ "Clumber Spaniel",
218
+ "English Springer Spaniel",
219
+ "Welsh Springer Spaniel",
220
+ "Cocker Spaniels",
221
+ "Sussex Spaniel",
222
+ "Irish Water Spaniel",
223
+ "Kuvasz",
224
+ "Schipperke",
225
+ "Groenendael",
226
+ "Malinois",
227
+ "Briard",
228
+ "Australian Kelpie",
229
+ "Komondor",
230
+ "Old English Sheepdog",
231
+ "Shetland Sheepdog",
232
+ "collie",
233
+ "Border Collie",
234
+ "Bouvier des Flandres",
235
+ "Rottweiler",
236
+ "German Shepherd Dog",
237
+ "Dobermann",
238
+ "Miniature Pinscher",
239
+ "Greater Swiss Mountain Dog",
240
+ "Bernese Mountain Dog",
241
+ "Appenzeller Sennenhund",
242
+ "Entlebucher Sennenhund",
243
+ "Boxer",
244
+ "Bullmastiff",
245
+ "Tibetan Mastiff",
246
+ "French Bulldog",
247
+ "Great Dane",
248
+ "St. Bernard",
249
+ "husky",
250
+ "Alaskan Malamute",
251
+ "Siberian Husky",
252
+ "Dalmatian",
253
+ "Affenpinscher",
254
+ "Basenji",
255
+ "pug",
256
+ "Leonberger",
257
+ "Newfoundland",
258
+ "Pyrenean Mountain Dog",
259
+ "Samoyed",
260
+ "Pomeranian",
261
+ "Chow Chow",
262
+ "Keeshond",
263
+ "Griffon Bruxellois",
264
+ "Pembroke Welsh Corgi",
265
+ "Cardigan Welsh Corgi",
266
+ "Toy Poodle",
267
+ "Miniature Poodle",
268
+ "Standard Poodle",
269
+ "Mexican hairless dog",
270
+ "grey wolf",
271
+ "Alaskan tundra wolf",
272
+ "red wolf",
273
+ "coyote",
274
+ "dingo",
275
+ "dhole",
276
+ "African wild dog",
277
+ "hyena",
278
+ "red fox",
279
+ "kit fox",
280
+ "Arctic fox",
281
+ "grey fox",
282
+ "tabby cat",
283
+ "tiger cat",
284
+ "Persian cat",
285
+ "Siamese cat",
286
+ "Egyptian Mau",
287
+ "cougar",
288
+ "lynx",
289
+ "leopard",
290
+ "snow leopard",
291
+ "jaguar",
292
+ "lion",
293
+ "tiger",
294
+ "cheetah",
295
+ "brown bear",
296
+ "American black bear",
297
+ "polar bear",
298
+ "sloth bear",
299
+ "mongoose",
300
+ "meerkat",
301
+ "tiger beetle",
302
+ "ladybug",
303
+ "ground beetle",
304
+ "longhorn beetle",
305
+ "leaf beetle",
306
+ "dung beetle",
307
+ "rhinoceros beetle",
308
+ "weevil",
309
+ "fly",
310
+ "bee",
311
+ "ant",
312
+ "grasshopper",
313
+ "cricket",
314
+ "stick insect",
315
+ "cockroach",
316
+ "mantis",
317
+ "cicada",
318
+ "leafhopper",
319
+ "lacewing",
320
+ "dragonfly",
321
+ "damselfly",
322
+ "red admiral",
323
+ "ringlet",
324
+ "monarch butterfly",
325
+ "small white",
326
+ "sulphur butterfly",
327
+ "gossamer-winged butterfly",
328
+ "starfish",
329
+ "sea urchin",
330
+ "sea cucumber",
331
+ "cottontail rabbit",
332
+ "hare",
333
+ "Angora rabbit",
334
+ "hamster",
335
+ "porcupine",
336
+ "fox squirrel",
337
+ "marmot",
338
+ "beaver",
339
+ "guinea pig",
340
+ "common sorrel",
341
+ "zebra",
342
+ "pig",
343
+ "wild boar",
344
+ "warthog",
345
+ "hippopotamus",
346
+ "ox",
347
+ "water buffalo",
348
+ "bison",
349
+ "ram",
350
+ "bighorn sheep",
351
+ "Alpine ibex",
352
+ "hartebeest",
353
+ "impala",
354
+ "gazelle",
355
+ "dromedary",
356
+ "llama",
357
+ "weasel",
358
+ "mink",
359
+ "European polecat",
360
+ "black-footed ferret",
361
+ "otter",
362
+ "skunk",
363
+ "badger",
364
+ "armadillo",
365
+ "three-toed sloth",
366
+ "orangutan",
367
+ "gorilla",
368
+ "chimpanzee",
369
+ "gibbon",
370
+ "siamang",
371
+ "guenon",
372
+ "patas monkey",
373
+ "baboon",
374
+ "macaque",
375
+ "langur",
376
+ "black-and-white colobus",
377
+ "proboscis monkey",
378
+ "marmoset",
379
+ "white-headed capuchin",
380
+ "howler monkey",
381
+ "titi",
382
+ "Geoffroy's spider monkey",
383
+ "common squirrel monkey",
384
+ "ring-tailed lemur",
385
+ "indri",
386
+ "Asian elephant",
387
+ "African bush elephant",
388
+ "red panda",
389
+ "giant panda",
390
+ "snoek",
391
+ "eel",
392
+ "coho salmon",
393
+ "rock beauty",
394
+ "clownfish",
395
+ "sturgeon",
396
+ "garfish",
397
+ "lionfish",
398
+ "pufferfish",
399
+ "abacus",
400
+ "abaya",
401
+ "academic gown",
402
+ "accordion",
403
+ "acoustic guitar",
404
+ "aircraft carrier",
405
+ "airliner",
406
+ "airship",
407
+ "altar",
408
+ "ambulance",
409
+ "amphibious vehicle",
410
+ "analog clock",
411
+ "apiary",
412
+ "apron",
413
+ "waste container",
414
+ "assault rifle",
415
+ "backpack",
416
+ "bakery",
417
+ "balance beam",
418
+ "balloon",
419
+ "ballpoint pen",
420
+ "Band-Aid",
421
+ "banjo",
422
+ "baluster",
423
+ "barbell",
424
+ "barber chair",
425
+ "barbershop",
426
+ "barn",
427
+ "barometer",
428
+ "barrel",
429
+ "wheelbarrow",
430
+ "baseball",
431
+ "basketball",
432
+ "bassinet",
433
+ "bassoon",
434
+ "swimming cap",
435
+ "bath towel",
436
+ "bathtub",
437
+ "station wagon",
438
+ "lighthouse",
439
+ "beaker",
440
+ "military cap",
441
+ "beer bottle",
442
+ "beer glass",
443
+ "bell-cot",
444
+ "bib",
445
+ "tandem bicycle",
446
+ "bikini",
447
+ "ring binder",
448
+ "binoculars",
449
+ "birdhouse",
450
+ "boathouse",
451
+ "bobsleigh",
452
+ "bolo tie",
453
+ "poke bonnet",
454
+ "bookcase",
455
+ "bookstore",
456
+ "bottle cap",
457
+ "bow",
458
+ "bow tie",
459
+ "brass",
460
+ "bra",
461
+ "breakwater",
462
+ "breastplate",
463
+ "broom",
464
+ "bucket",
465
+ "buckle",
466
+ "bulletproof vest",
467
+ "high-speed train",
468
+ "butcher shop",
469
+ "taxicab",
470
+ "cauldron",
471
+ "candle",
472
+ "cannon",
473
+ "canoe",
474
+ "can opener",
475
+ "cardigan",
476
+ "car mirror",
477
+ "carousel",
478
+ "tool kit",
479
+ "carton",
480
+ "car wheel",
481
+ "automated teller machine",
482
+ "cassette",
483
+ "cassette player",
484
+ "castle",
485
+ "catamaran",
486
+ "CD player",
487
+ "cello",
488
+ "mobile phone",
489
+ "chain",
490
+ "chain-link fence",
491
+ "chain mail",
492
+ "chainsaw",
493
+ "chest",
494
+ "chiffonier",
495
+ "chime",
496
+ "china cabinet",
497
+ "Christmas stocking",
498
+ "church",
499
+ "movie theater",
500
+ "cleaver",
501
+ "cliff dwelling",
502
+ "cloak",
503
+ "clogs",
504
+ "cocktail shaker",
505
+ "coffee mug",
506
+ "coffeemaker",
507
+ "coil",
508
+ "combination lock",
509
+ "computer keyboard",
510
+ "confectionery store",
511
+ "container ship",
512
+ "convertible",
513
+ "corkscrew",
514
+ "cornet",
515
+ "cowboy boot",
516
+ "cowboy hat",
517
+ "cradle",
518
+ "crane (machine)",
519
+ "crash helmet",
520
+ "crate",
521
+ "infant bed",
522
+ "Crock Pot",
523
+ "croquet ball",
524
+ "crutch",
525
+ "cuirass",
526
+ "dam",
527
+ "desk",
528
+ "desktop computer",
529
+ "rotary dial telephone",
530
+ "diaper",
531
+ "digital clock",
532
+ "digital watch",
533
+ "dining table",
534
+ "dishcloth",
535
+ "dishwasher",
536
+ "disc brake",
537
+ "dock",
538
+ "dog sled",
539
+ "dome",
540
+ "doormat",
541
+ "drilling rig",
542
+ "drum",
543
+ "drumstick",
544
+ "dumbbell",
545
+ "Dutch oven",
546
+ "electric fan",
547
+ "electric guitar",
548
+ "electric locomotive",
549
+ "entertainment center",
550
+ "envelope",
551
+ "espresso machine",
552
+ "face powder",
553
+ "feather boa",
554
+ "filing cabinet",
555
+ "fireboat",
556
+ "fire engine",
557
+ "fire screen sheet",
558
+ "flagpole",
559
+ "flute",
560
+ "folding chair",
561
+ "football helmet",
562
+ "forklift",
563
+ "fountain",
564
+ "fountain pen",
565
+ "four-poster bed",
566
+ "freight car",
567
+ "French horn",
568
+ "frying pan",
569
+ "fur coat",
570
+ "garbage truck",
571
+ "gas mask",
572
+ "gas pump",
573
+ "goblet",
574
+ "go-kart",
575
+ "golf ball",
576
+ "golf cart",
577
+ "gondola",
578
+ "gong",
579
+ "gown",
580
+ "grand piano",
581
+ "greenhouse",
582
+ "grille",
583
+ "grocery store",
584
+ "guillotine",
585
+ "barrette",
586
+ "hair spray",
587
+ "half-track",
588
+ "hammer",
589
+ "hamper",
590
+ "hair dryer",
591
+ "hand-held computer",
592
+ "handkerchief",
593
+ "hard disk drive",
594
+ "harmonica",
595
+ "harp",
596
+ "harvester",
597
+ "hatchet",
598
+ "holster",
599
+ "home theater",
600
+ "honeycomb",
601
+ "hook",
602
+ "hoop skirt",
603
+ "horizontal bar",
604
+ "horse-drawn vehicle",
605
+ "hourglass",
606
+ "iPod",
607
+ "clothes iron",
608
+ "jack-o'-lantern",
609
+ "jeans",
610
+ "jeep",
611
+ "T-shirt",
612
+ "jigsaw puzzle",
613
+ "pulled rickshaw",
614
+ "joystick",
615
+ "kimono",
616
+ "knee pad",
617
+ "knot",
618
+ "lab coat",
619
+ "ladle",
620
+ "lampshade",
621
+ "laptop computer",
622
+ "lawn mower",
623
+ "lens cap",
624
+ "paper knife",
625
+ "library",
626
+ "lifeboat",
627
+ "lighter",
628
+ "limousine",
629
+ "ocean liner",
630
+ "lipstick",
631
+ "slip-on shoe",
632
+ "lotion",
633
+ "speaker",
634
+ "loupe",
635
+ "sawmill",
636
+ "magnetic compass",
637
+ "mail bag",
638
+ "mailbox",
639
+ "tights",
640
+ "tank suit",
641
+ "manhole cover",
642
+ "maraca",
643
+ "marimba",
644
+ "mask",
645
+ "match",
646
+ "maypole",
647
+ "maze",
648
+ "measuring cup",
649
+ "medicine chest",
650
+ "megalith",
651
+ "microphone",
652
+ "microwave oven",
653
+ "military uniform",
654
+ "milk can",
655
+ "minibus",
656
+ "miniskirt",
657
+ "minivan",
658
+ "missile",
659
+ "mitten",
660
+ "mixing bowl",
661
+ "mobile home",
662
+ "Model T",
663
+ "modem",
664
+ "monastery",
665
+ "monitor",
666
+ "moped",
667
+ "mortar",
668
+ "square academic cap",
669
+ "mosque",
670
+ "mosquito net",
671
+ "scooter",
672
+ "mountain bike",
673
+ "tent",
674
+ "computer mouse",
675
+ "mousetrap",
676
+ "moving van",
677
+ "muzzle",
678
+ "nail",
679
+ "neck brace",
680
+ "necklace",
681
+ "nipple",
682
+ "notebook computer",
683
+ "obelisk",
684
+ "oboe",
685
+ "ocarina",
686
+ "odometer",
687
+ "oil filter",
688
+ "organ",
689
+ "oscilloscope",
690
+ "overskirt",
691
+ "bullock cart",
692
+ "oxygen mask",
693
+ "packet",
694
+ "paddle",
695
+ "paddle wheel",
696
+ "padlock",
697
+ "paintbrush",
698
+ "pajamas",
699
+ "palace",
700
+ "pan flute",
701
+ "paper towel",
702
+ "parachute",
703
+ "parallel bars",
704
+ "park bench",
705
+ "parking meter",
706
+ "passenger car",
707
+ "patio",
708
+ "payphone",
709
+ "pedestal",
710
+ "pencil case",
711
+ "pencil sharpener",
712
+ "perfume",
713
+ "Petri dish",
714
+ "photocopier",
715
+ "plectrum",
716
+ "Pickelhaube",
717
+ "picket fence",
718
+ "pickup truck",
719
+ "pier",
720
+ "piggy bank",
721
+ "pill bottle",
722
+ "pillow",
723
+ "ping-pong ball",
724
+ "pinwheel",
725
+ "pirate ship",
726
+ "pitcher",
727
+ "hand plane",
728
+ "planetarium",
729
+ "plastic bag",
730
+ "plate rack",
731
+ "plow",
732
+ "plunger",
733
+ "Polaroid camera",
734
+ "pole",
735
+ "police van",
736
+ "poncho",
737
+ "billiard table",
738
+ "soda bottle",
739
+ "pot",
740
+ "potter's wheel",
741
+ "power drill",
742
+ "prayer rug",
743
+ "printer",
744
+ "prison",
745
+ "projectile",
746
+ "projector",
747
+ "hockey puck",
748
+ "punching bag",
749
+ "purse",
750
+ "quill",
751
+ "quilt",
752
+ "race car",
753
+ "racket",
754
+ "radiator",
755
+ "radio",
756
+ "radio telescope",
757
+ "rain barrel",
758
+ "recreational vehicle",
759
+ "reel",
760
+ "reflex camera",
761
+ "refrigerator",
762
+ "remote control",
763
+ "restaurant",
764
+ "revolver",
765
+ "rifle",
766
+ "rocking chair",
767
+ "rotisserie",
768
+ "eraser",
769
+ "rugby ball",
770
+ "ruler",
771
+ "running shoe",
772
+ "safe",
773
+ "safety pin",
774
+ "salt shaker",
775
+ "sandal",
776
+ "sarong",
777
+ "saxophone",
778
+ "scabbard",
779
+ "weighing scale",
780
+ "school bus",
781
+ "schooner",
782
+ "scoreboard",
783
+ "CRT screen",
784
+ "screw",
785
+ "screwdriver",
786
+ "seat belt",
787
+ "sewing machine",
788
+ "shield",
789
+ "shoe store",
790
+ "shoji",
791
+ "shopping basket",
792
+ "shopping cart",
793
+ "shovel",
794
+ "shower cap",
795
+ "shower curtain",
796
+ "ski",
797
+ "ski mask",
798
+ "sleeping bag",
799
+ "slide rule",
800
+ "sliding door",
801
+ "slot machine",
802
+ "snorkel",
803
+ "snowmobile",
804
+ "snowplow",
805
+ "soap dispenser",
806
+ "soccer ball",
807
+ "sock",
808
+ "solar thermal collector",
809
+ "sombrero",
810
+ "soup bowl",
811
+ "space bar",
812
+ "space heater",
813
+ "space shuttle",
814
+ "spatula",
815
+ "motorboat",
816
+ "spider web",
817
+ "spindle",
818
+ "sports car",
819
+ "spotlight",
820
+ "stage",
821
+ "steam locomotive",
822
+ "through arch bridge",
823
+ "steel drum",
824
+ "stethoscope",
825
+ "scarf",
826
+ "stone wall",
827
+ "stopwatch",
828
+ "stove",
829
+ "strainer",
830
+ "tram",
831
+ "stretcher",
832
+ "couch",
833
+ "stupa",
834
+ "submarine",
835
+ "suit",
836
+ "sundial",
837
+ "sunglass",
838
+ "sunglasses",
839
+ "sunscreen",
840
+ "suspension bridge",
841
+ "mop",
842
+ "sweatshirt",
843
+ "swimsuit",
844
+ "swing",
845
+ "switch",
846
+ "syringe",
847
+ "table lamp",
848
+ "tank",
849
+ "tape player",
850
+ "teapot",
851
+ "teddy bear",
852
+ "television",
853
+ "tennis ball",
854
+ "thatched roof",
855
+ "front curtain",
856
+ "thimble",
857
+ "threshing machine",
858
+ "throne",
859
+ "tile roof",
860
+ "toaster",
861
+ "tobacco shop",
862
+ "toilet seat",
863
+ "torch",
864
+ "totem pole",
865
+ "tow truck",
866
+ "toy store",
867
+ "tractor",
868
+ "semi-trailer truck",
869
+ "tray",
870
+ "trench coat",
871
+ "tricycle",
872
+ "trimaran",
873
+ "tripod",
874
+ "triumphal arch",
875
+ "trolleybus",
876
+ "trombone",
877
+ "tub",
878
+ "turnstile",
879
+ "typewriter keyboard",
880
+ "umbrella",
881
+ "unicycle",
882
+ "upright piano",
883
+ "vacuum cleaner",
884
+ "vase",
885
+ "vault",
886
+ "velvet",
887
+ "vending machine",
888
+ "vestment",
889
+ "viaduct",
890
+ "violin",
891
+ "volleyball",
892
+ "waffle iron",
893
+ "wall clock",
894
+ "wallet",
895
+ "wardrobe",
896
+ "military aircraft",
897
+ "sink",
898
+ "washing machine",
899
+ "water bottle",
900
+ "water jug",
901
+ "water tower",
902
+ "whiskey jug",
903
+ "whistle",
904
+ "wig",
905
+ "window screen",
906
+ "window shade",
907
+ "Windsor tie",
908
+ "wine bottle",
909
+ "wing",
910
+ "wok",
911
+ "wooden spoon",
912
+ "wool",
913
+ "split-rail fence",
914
+ "shipwreck",
915
+ "yawl",
916
+ "yurt",
917
+ "website",
918
+ "comic book",
919
+ "crossword",
920
+ "traffic sign",
921
+ "traffic light",
922
+ "dust jacket",
923
+ "menu",
924
+ "plate",
925
+ "guacamole",
926
+ "consomme",
927
+ "hot pot",
928
+ "trifle",
929
+ "ice cream",
930
+ "ice pop",
931
+ "baguette",
932
+ "bagel",
933
+ "pretzel",
934
+ "cheeseburger",
935
+ "hot dog",
936
+ "mashed potato",
937
+ "cabbage",
938
+ "broccoli",
939
+ "cauliflower",
940
+ "zucchini",
941
+ "spaghetti squash",
942
+ "acorn squash",
943
+ "butternut squash",
944
+ "cucumber",
945
+ "artichoke",
946
+ "bell pepper",
947
+ "cardoon",
948
+ "mushroom",
949
+ "Granny Smith",
950
+ "strawberry",
951
+ "orange",
952
+ "lemon",
953
+ "fig",
954
+ "pineapple",
955
+ "banana",
956
+ "jackfruit",
957
+ "custard apple",
958
+ "pomegranate",
959
+ "hay",
960
+ "carbonara",
961
+ "chocolate syrup",
962
+ "dough",
963
+ "meatloaf",
964
+ "pizza",
965
+ "pot pie",
966
+ "burrito",
967
+ "red wine",
968
+ "espresso",
969
+ "cup",
970
+ "eggnog",
971
+ "alp",
972
+ "bubble",
973
+ "cliff",
974
+ "coral reef",
975
+ "geyser",
976
+ "lakeshore",
977
+ "promontory",
978
+ "shoal",
979
+ "seashore",
980
+ "valley",
981
+ "volcano",
982
+ "baseball player",
983
+ "bridegroom",
984
+ "scuba diver",
985
+ "rapeseed",
986
+ "daisy",
987
+ "yellow lady's slipper",
988
+ "corn",
989
+ "acorn",
990
+ "rose hip",
991
+ "horse chestnut seed",
992
+ "coral fungus",
993
+ "agaric",
994
+ "gyromitra",
995
+ "stinkhorn mushroom",
996
+ "earth star",
997
+ "hen-of-the-woods",
998
+ "bolete",
999
+ "ear of corn",
1000
+ "toilet paper"]
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ gradio
4
+ Pillow
5
+ tqdm
6
+ numpy
7
+ tensorboard
train_resnet50_local.py ADDED
@@ -0,0 +1,637 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ import torchvision
5
+ from torch.utils.data import DataLoader, Subset
6
+ from torchvision import datasets, transforms
7
+ import torch.nn.functional as F
8
+ import os
9
+ from tqdm import tqdm
10
+ import random
11
+ import numpy as np
12
+ from torch.utils.tensorboard import SummaryWriter
13
+ import json
14
+ from datetime import timedelta, datetime
15
+ import logging
16
+ import torch.distributed as dist
17
+ import torch.multiprocessing as mp
18
+ from torch.nn.parallel import DistributedDataParallel as DDP
19
+ import socket
20
+ import argparse
21
+ import math
22
+
23
+
24
+ # Set random seeds for reproducibility
25
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random generators."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
30
+
31
+
32
+ # Training configuration
33
class Config:
    """Static hyperparameters and runtime switches for the training script."""

    # Schedule and optimiser settings
    num_epochs = 150
    learning_rate = 0.1
    momentum = 0.9
    weight_decay = 1e-4

    # Data loading
    batch_size = 512
    num_workers = 16
    subset_size = None

    # Logging interval
    print_freq = 100

    # Number of batches whose gradients are accumulated per optimiser step
    accum_iter = 1

    # Automatic mixed precision toggle
    use_amp = True
48
+
49
+
50
class AverageMeter(object):
    """Track the latest value of a metric together with its running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the latest value, the weighted sum, the count and the mean."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record val as observed n times and refresh the running mean."""
        self.count += n
        self.sum += val * n
        self.val = val
        self.avg = self.sum / self.count
67
+
68
+
69
def get_data_loaders(subset_size=None, distributed=False, world_size=None, rank=None):
    """Build the ImageNet training and validation data loaders.

    Args:
        subset_size: If truthy, train on a random subset of this many images
            and validate on a random subset of ``subset_size // 10`` images.
        distributed: If True, shard both datasets with a DistributedSampler.
        world_size: Number of replicas handed to the samplers.
        rank: Rank of the current process handed to the samplers.

    Returns:
        A ``(train_loader, val_loader, train_sampler)`` tuple; ``train_sampler``
        is None when ``distributed`` is False.
    """
    # ImageNet normalization values
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )

    # Training-time augmentation
    train_transform = transforms.Compose(
        [
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
            transforms.RandomAffine(degrees=15, translate=(0.1, 0.1), scale=(0.9, 1.1)),
            transforms.ToTensor(),
            normalize,
            # RandomErasing operates on tensors, so it must follow ToTensor.
            transforms.RandomErasing(p=0.5),
        ]
    )

    # Deterministic preprocessing for validation
    val_transform = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]
    )

    training_folder_name = "ILSVRC/Data/CLS-LOC/train"
    val_folder_name = "ILSVRC/Data/CLS-LOC/val"

    train_dataset = torchvision.datasets.ImageFolder(
        root=training_folder_name, transform=train_transform
    )
    val_dataset = torchvision.datasets.ImageFolder(
        root=val_folder_name, transform=val_transform
    )

    # Create subset for initial testing
    if subset_size:
        train_indices = torch.randperm(len(train_dataset))[:subset_size]
        val_indices = torch.randperm(len(val_dataset))[: subset_size // 10]
        train_dataset = Subset(train_dataset, train_indices)
        val_dataset = Subset(val_dataset, val_indices)

    # Create samplers for distributed training
    train_sampler = None
    val_sampler = None
    if distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset, num_replicas=world_size, rank=rank
        )
        # Validation needs no shuffling; keep each shard in dataset order.
        val_sampler = torch.utils.data.distributed.DistributedSampler(
            val_dataset, num_replicas=world_size, rank=rank, shuffle=False
        )

    # Options shared by both loaders.
    loader_options = {
        "batch_size": Config.batch_size,
        "num_workers": Config.num_workers,
        "pin_memory": True,
    }
    if Config.num_workers > 0:
        # DataLoader rejects these two options when loading in the main
        # process (num_workers == 0), so only pass them with worker processes.
        loader_options["persistent_workers"] = True
        loader_options["prefetch_factor"] = 2

    train_loader = DataLoader(
        train_dataset,
        shuffle=(train_sampler is None),
        sampler=train_sampler,
        **loader_options,
    )
    val_loader = DataLoader(
        val_dataset,
        shuffle=False,
        sampler=val_sampler,
        **loader_options,
    )

    return train_loader, val_loader, train_sampler
152
+
153
+
154
def train_epoch(model, train_loader, criterion, optimizer, epoch, device, scaler=None):
    """Run one training epoch with mixed precision and gradient accumulation.

    Args:
        model: Network to train; switched to training mode here.
        train_loader: Iterable of ``(images, targets)`` batches.
        criterion: Loss function called as ``criterion(outputs, targets)``.
        optimizer: Optimizer stepped every ``Config.accum_iter`` batches and
            on the final batch.
        epoch: Epoch index, used only for the progress-bar label.
        device: Device the batches are moved to.
        scaler: Optional ``GradScaler`` to reuse across epochs so its adapted
            loss scale is kept. A fresh one is created when omitted.

    Returns:
        dict with ``"time"`` (``timedelta``), ``"loss"`` (mean per-batch loss)
        and ``"accuracy"`` (top-1 percentage over the samples processed).
        Loss and accuracy are 0.0 when nothing was processed.
    """
    epoch_start_time = datetime.now()
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    if scaler is None:
        scaler = torch.cuda.amp.GradScaler(enabled=Config.use_amp)

    num_batches = len(train_loader)
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}")
    optimizer.zero_grad()

    for i, data in enumerate(pbar):
        try:
            images, targets = data
            images, targets = images.to(device), targets.to(device)

            # Forward pass under autocast
            with torch.cuda.amp.autocast(enabled=Config.use_amp):
                outputs = model(images)
                loss = criterion(outputs, targets)
                # Normalize so accumulated gradients average over the window
                loss = loss / Config.accum_iter

            # Backward pass with gradient scaling
            scaler.scale(loss).backward()

            # Step at the end of each accumulation window and on the last batch
            if ((i + 1) % Config.accum_iter == 0) or (i + 1 == num_batches):
                # Unscale first so clipping sees the true gradient norm
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            # Undo the accumulation normalization for reporting
            running_loss += loss.item() * Config.accum_iter
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            if i % Config.print_freq == 0:
                accuracy = 100.0 * correct / total
                pbar.set_postfix(
                    {
                        "loss": running_loss / (i + 1),
                        "acc": f"{accuracy:.2f}%",
                        "lr": optimizer.param_groups[0]["lr"],
                    }
                )
        except Exception as e:
            # Best effort: report the failing batch and keep training.
            print(f"Error in batch {i}: {str(e)}")
            continue

    epoch_time = datetime.now() - epoch_start_time
    epoch_metrics = {
        "time": epoch_time,
        # Guard against an empty loader or an epoch where every batch failed
        "loss": running_loss / num_batches if num_batches else 0.0,
        "accuracy": 100.0 * correct / total if total else 0.0,
    }
    return epoch_metrics
217
+
218
+
219
def validate(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader``.

    Returns a tuple ``(top1, top5, loss)``: sample-weighted averages of the
    top-1 accuracy (percent), the top-5 accuracy (percent) and the loss.
    """
    model.eval()
    loss_meter = AverageMeter()
    top1_meter = AverageMeter()
    top5_meter = AverageMeter()

    amp_context = torch.cuda.amp.autocast(enabled=Config.use_amp)
    with torch.no_grad(), amp_context:
        for images, targets in tqdm(val_loader, desc="Validating"):
            images = images.to(device)
            targets = targets.to(device)
            logits = model(images)
            batch_loss = criterion(logits, targets)

            n_samples = targets.size(0)

            # Indices of the five highest-scoring classes, best first
            top5_idx = logits.topk(5, dim=1, largest=True, sorted=True)[1]
            # hits[i, k] is True when the k-th guess for sample i is correct
            hits = top5_idx.eq(targets.view(-1, 1))

            top1_pct = hits[:, 0].float().sum() * 100.0 / n_samples
            top5_pct = hits.float().sum() * 100.0 / n_samples

            top1_meter.update(top1_pct.item(), n_samples)
            top5_meter.update(top5_pct.item(), n_samples)
            loss_meter.update(batch_loss.item(), n_samples)

    return top1_meter.avg, top5_meter.avg, loss_meter.avg
250
+
251
+
252
+ # Add ResNet building blocks
253
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions.

    The skip path is the identity unless the stride or the channel count
    changes, in which case it is a 1x1 convolution followed by batch norm.
    """

    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        conv3x3 = dict(kernel_size=3, padding=1, bias=False)

        self.conv1 = nn.Conv2d(in_channels, out_channels, stride=stride, **conv3x3)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, stride=1, **conv3x3)
        self.bn2 = nn.BatchNorm2d(out_channels)

        shape_preserved = stride == 1 and in_channels == out_channels
        if shape_preserved:
            # Empty Sequential acts as the identity
            self.shortcut = nn.Sequential()
        else:
            projection = nn.Conv2d(
                in_channels, out_channels, kernel_size=1, stride=stride, bias=False
            )
            self.shortcut = nn.Sequential(projection, nn.BatchNorm2d(out_channels))

    def forward(self, x):
        residual = F.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))
        return F.relu(residual + self.shortcut(x))
287
+
288
+
289
class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 reduce, 3x3, then 1x1 expand.

    The output has ``out_channels * expansion`` channels. Any stride is applied
    in the 3x3 convolution. The skip path is the identity unless the stride or
    channel count changes, in which case it is a 1x1 convolution plus batch
    norm.
    """

    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        expanded_channels = out_channels * self.expansion

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(out_channels, expanded_channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(expanded_channels)

        shape_preserved = stride == 1 and in_channels == expanded_channels
        if shape_preserved:
            # Empty Sequential acts as the identity
            self.shortcut = nn.Sequential()
        else:
            projection = nn.Conv2d(
                in_channels, expanded_channels, kernel_size=1, stride=stride, bias=False
            )
            self.shortcut = nn.Sequential(projection, nn.BatchNorm2d(expanded_channels))

    def forward(self, x):
        residual = F.relu(self.bn1(self.conv1(x)))
        residual = F.relu(self.bn2(self.conv2(residual)))
        residual = self.bn3(self.conv3(residual))
        return F.relu(residual + self.shortcut(x))
330
+
331
+
332
class ResNet(nn.Module):
    """ResNet backbone with a 7x7 stem, four residual stages and a linear head.

    Args:
        block: Residual block class. It must expose an ``expansion`` attribute
            and accept ``(in_channels, out_channels, stride)``.
        num_blocks: Number of blocks in each of the four stages.
        num_classes: Size of the classifier output.
    """

    def __init__(self, block, num_blocks, num_classes=1000):
        super().__init__()
        # Channel count feeding the next block; advanced by _make_layer
        self.in_channels = 64

        # Stem
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Four stages; every stage after the first halves the resolution
        depth1, depth2, depth3, depth4 = num_blocks[0], num_blocks[1], num_blocks[2], num_blocks[3]
        self.layer1 = self._make_layer(block, 64, depth1, stride=1)
        self.layer2 = self._make_layer(block, 128, depth2, stride=2)
        self.layer3 = self._make_layer(block, 256, depth3, stride=2)
        self.layer4 = self._make_layer(block, 512, depth4, stride=2)

        # Head
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Kaiming init for convolutions; batch norm starts as the identity
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(
                    module.weight, mode="fan_out", nonlinearity="relu"
                )
                continue
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``."""
        stage = []
        for position in range(num_blocks):
            block_stride = stride if position == 0 else 1
            stage.append(block(self.in_channels, out_channels, block_stride))
            self.in_channels = out_channels * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        features = self.maxpool(F.relu(self.bn1(self.conv1(x))))

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)

        pooled = torch.flatten(self.avgpool(features), 1)
        return self.fc(pooled)
378
+
379
+
380
+ # Replace the model creation in main() with this:
381
def create_resnet50():
    """Return a freshly initialized ResNet built from Bottleneck blocks (3-4-6-3)."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(block=Bottleneck, num_blocks=stage_depths)
383
+
384
+
385
+ # Add logging setup function
386
def setup_logging(log_dir):
    """Configure root logging to ``<log_dir>/training.log`` and the console.

    Creates ``log_dir`` if needed and returns this module's logger.
    """
    os.makedirs(log_dir, exist_ok=True)

    log_path = os.path.join(log_dir, "training.log")
    file_handler = logging.FileHandler(log_path)
    console_handler = logging.StreamHandler()

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        handlers=[file_handler, console_handler],
    )
    return logging.getLogger(__name__)
398
+
399
+
400
+ # Add distributed training setup
401
def setup_distributed():
    """Parse ``--nodes`` and read the launcher's rank / world-size variables.

    Returns the argparse namespace extended with:
      * ``local_rank`` - integer value of ``LOCAL_RANK``; when the variable is
        unset it is written to the environment as ``"-1"``.
      * ``world_size`` - integer value of ``WORLD_SIZE`` when set, otherwise
        the ``--nodes`` value.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--nodes", type=int, default=1)
    args = cli.parse_args()

    # setdefault stores "-1" in os.environ when LOCAL_RANK is absent
    args.local_rank = int(os.environ.setdefault("LOCAL_RANK", "-1"))

    env_world_size = os.environ.get("WORLD_SIZE")
    if env_world_size is None:
        args.world_size = args.nodes
    else:
        args.world_size = int(env_world_size)

    return args
417
+
418
+
419
+ # Add this function to get dataset statistics
420
def get_dataset_stats(train_loader, val_loader):
    """Collect dataset sizes, loader geometry and key hyper-parameters.

    Returns a dict with sample/batch counts, the class count, the GPU count,
    selected ``Config`` values, and the ``classes`` / ``class_to_idx`` mappings
    of the training dataset.
    """
    train_set = train_loader.dataset
    # A wrapper that exposes `.dataset` (e.g. a Subset) holds the class
    # metadata on the wrapped dataset rather than on itself.
    if hasattr(train_set, "dataset"):
        labelled_set = train_set.dataset
    else:
        labelled_set = train_set

    class_names = labelled_set.classes
    class_index = labelled_set.class_to_idx

    return {
        "num_train_samples": len(train_set),
        "num_val_samples": len(val_loader.dataset),
        "num_classes": len(class_names),
        "batch_size": train_loader.batch_size,
        "num_train_batches": len(train_loader),
        "num_val_batches": len(val_loader),
        "device_count": torch.cuda.device_count(),
        "max_epochs": Config.num_epochs,
        "learning_rate": Config.learning_rate,
        "weight_decay": Config.weight_decay,
        "num_workers": Config.num_workers,
        "classes": class_names,
        "class_to_idx": class_index,
    }
451
+
452
+
453
+ # Modify the main function to support distributed training
454
def main():
    """Train ResNet-50, optionally under DistributedDataParallel.

    Each epoch trains the model, then (on the master process) validates it,
    logs metrics to the log file and TensorBoard, and writes checkpoints into
    ``runs/resnet50_<timestamp>``:
      * ``best_model.pth`` - written only when top-1 validation accuracy
        improves.
      * ``target_achieved_model.pth`` - written whenever top-1 accuracy is at
        least 70%.
    """
    start_time = datetime.now()
    args = setup_distributed()
    distributed = args.local_rank != -1
    # local_rank is -1 for single-process runs and 0 for the first local process
    is_master = args.local_rank <= 0

    if distributed:
        torch.cuda.set_device(args.local_rank)
        dist.init_process_group(
            backend="nccl",
            init_method="env://",  # rendezvous via environment variables
        )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Logging and TensorBoard are set up on the master process only
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_dir = f"runs/resnet50_{timestamp}"
    if is_master:
        os.makedirs(log_dir, exist_ok=True)
        writer = SummaryWriter(log_dir)
        logger = setup_logging(log_dir)
        logger.info(f"Starting training on {socket.gethostname()}")
        logger.info(f"Available GPUs: {torch.cuda.device_count()}")
        logger.info(f"Training started at: {start_time}")

    set_seed()

    model = create_resnet50()
    if distributed:
        model = DDP(model.to(device), device_ids=[args.local_rank])
    else:
        model = torch.nn.DataParallel(model).to(device)

    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = optim.SGD(
        model.parameters(),
        lr=Config.learning_rate,
        momentum=Config.momentum,
        weight_decay=Config.weight_decay,
        nesterov=True,
    )

    # Linear warmup followed by cosine annealing
    warmup_epochs = 5

    def warmup_lr_scheduler(epoch):
        if epoch < warmup_epochs:
            # (epoch + 1) keeps the first epoch from training with lr == 0
            return (epoch + 1) / warmup_epochs
        # Avoid dividing by zero when num_epochs == warmup_epochs
        decay_epochs = max(1, Config.num_epochs - warmup_epochs)
        return 0.5 * (1 + math.cos(math.pi * (epoch - warmup_epochs) / decay_epochs))

    scheduler = optim.lr_scheduler.LambdaLR(optimizer, warmup_lr_scheduler)

    # The sampler needs the global rank; the local rank repeats on every node.
    train_loader, val_loader, train_sampler = get_data_loaders(
        subset_size=Config.subset_size,
        distributed=distributed,
        world_size=dist.get_world_size() if distributed else None,
        rank=dist.get_rank() if distributed else None,
    )

    if is_master:
        dataset_stats = get_dataset_stats(train_loader, val_loader)
        logger.info("Dataset Statistics:")
        logger.info(f"Training samples: {dataset_stats['num_train_samples']}")
        logger.info(f"Validation samples: {dataset_stats['num_val_samples']}")
        logger.info(f"Number of classes: {dataset_stats['num_classes']}")
        logger.info(f"Batch size: {dataset_stats['batch_size']}")
        logger.info(f"Training batches per epoch: {dataset_stats['num_train_batches']}")
        logger.info(f"Validation batches per epoch: {dataset_stats['num_val_batches']}")

    best_acc = 0
    total_training_time = timedelta()

    for epoch in range(Config.num_epochs):
        if is_master:
            logger.info(f"Starting epoch {epoch}")

        # Reseed the distributed sampler so each epoch sees a new shuffle
        if train_sampler is not None:
            train_sampler.set_epoch(epoch)

        train_metrics = train_epoch(
            model, train_loader, criterion, optimizer, epoch, device
        )
        total_training_time += train_metrics["time"]

        # NOTE(review): in distributed runs only the master validates, and the
        # validation loader is sharded, so it scores just its own shard while
        # calling a DDP-wrapped model the other ranks are not calling. Confirm
        # this is intended.
        if is_master:
            logger.info(
                f"Epoch {epoch} completed in {train_metrics['time']}, "
                f"Training Loss: {train_metrics['loss']:.4f}, "
                f"Training Accuracy: {train_metrics['accuracy']:.2f}%"
            )

            top1_acc, top5_acc, val_loss = validate(
                model, val_loader, criterion, device
            )

            logger.info(
                f"Validation metrics - "
                f"Top1 Acc: {top1_acc:.2f}%, "
                f"Top5 Acc: {top5_acc:.2f}%, "
                f"Val Loss: {val_loss:.4f}"
            )
            logger.info(f"Total training time so far: {total_training_time}")

            writer.add_scalar("Training/Loss", train_metrics["loss"], epoch)
            writer.add_scalar("Training/Accuracy", train_metrics["accuracy"], epoch)
            writer.add_scalar(
                "Training/Time", train_metrics["time"].total_seconds(), epoch
            )
            writer.add_scalar("Validation/Top1_Accuracy", top1_acc, epoch)
            writer.add_scalar("Validation/Top5_Accuracy", top5_acc, epoch)
            writer.add_scalar("Validation/Loss", val_loss, epoch)

            is_best = top1_acc > best_acc
            best_acc = max(top1_acc, best_acc)

            checkpoint = {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "best_acc": best_acc,
                "top1_accuracy": top1_acc,
                "top5_accuracy": top5_acc,
            }

            # Only overwrite the best checkpoint when accuracy actually improved
            if is_best:
                torch.save(checkpoint, os.path.join(log_dir, "best_model.pth"))

            if top1_acc >= 70.0:
                logger.info(
                    f"\nTarget accuracy of 70% achieved! Current accuracy: {top1_acc:.2f}%"
                )
                torch.save(
                    checkpoint, os.path.join(log_dir, "target_achieved_model.pth")
                )

        scheduler.step()

    if is_master:
        end_time = datetime.now()
        training_time = end_time - start_time
        writer.close()
        logger.info("\nTraining completed!")
        logger.info(f"Total training time: {training_time}")
        # best_acc holds the best validation top-1; train_metrics has no such key
        logger.info(f"Best Top-1 Accuracy: {best_acc:.2f}%")
        logger.info(
            f"Target accuracy of 70% {'achieved' if best_acc >= 70.0 else 'not achieved'}"
        )
634
+
635
+
636
# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()