kmfoda committed
Commit 121fc83 · 1 Parent(s): d5bfcd1

Update results

Files changed (2)
  1. evaluate.py +4 -4
  2. results.json +301 -4
evaluate.py CHANGED
@@ -8,7 +8,7 @@ from huggingface_hub import list_repo_refs
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 device = "cuda"
-test_indices_length = 1000
+test_indices_length = 10
 
 models = ["distributed/optimized-gpt2-250m", "distributed/optimized-gpt2-250m-v0.1.1", "distributed/gpt2-94m"]
 
@@ -28,9 +28,9 @@ for model_name in models:
     refs = list_repo_refs(model_name, repo_type="model")
     global_epoch = max([int(tag.name) for tag in refs.tags]) if refs.tags else None
 
-    for epoch in range(0,global_epoch, 5):
+    for epoch in range(0,global_epoch):
 
-        if str(epoch) in results[model_name].keys():
+        if str(epoch) in results[model_name]['main-net'].keys():
             continue
 
         model = AutoModelForCausalLM.from_pretrained(model_name, revision=str(epoch), trust_remote_code=True)
@@ -80,7 +80,7 @@ for model_name in models:
         model.zero_grad()
 
         average_loss = total_loss / (index+1)
-        results[model_name][str(epoch)] = [average_loss]
+        results[model_name]['main-net'][str(epoch)] = [average_loss]
         print(f"Epoch: {epoch} Average Loss: {average_loss:.2f}")
 
         with open("results.json", "w") as outfile:
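For orientation, the hunks above all sit inside evaluate.py's per-model evaluation loop. The sketch below shows how the changed lines fit together after this commit; the test data and per-sample loss computation are not visible in the diff, so the placeholder evaluation texts, the tokenizer/device handling, and the json.dump indentation here are illustrative assumptions rather than the repository's actual code.

import json
from huggingface_hub import list_repo_refs
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda"
test_indices_length = 10  # after this commit: 10 evaluation samples per checkpoint

models = ["distributed/optimized-gpt2-250m", "distributed/optimized-gpt2-250m-v0.1.1", "distributed/gpt2-94m"]

with open("results.json") as f:
    results = json.load(f)

for model_name in models:
    refs = list_repo_refs(model_name, repo_type="model")
    global_epoch = max([int(tag.name) for tag in refs.tags]) if refs.tags else None

    # After this commit every tagged epoch is evaluated (the old code stepped by 5),
    # and per-epoch losses are stored under the 'main-net' sub-dict.
    for epoch in range(0, global_epoch):
        if str(epoch) in results[model_name]['main-net'].keys():
            continue  # this checkpoint was already evaluated in an earlier run

        model = AutoModelForCausalLM.from_pretrained(
            model_name, revision=str(epoch), trust_remote_code=True
        ).to(device)
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

        # Placeholder evaluation data: the real test set is built elsewhere in the script.
        texts = ["The quick brown fox jumps over the lazy dog."] * test_indices_length

        total_loss = 0
        for index, text in enumerate(texts):
            batch = tokenizer(text, return_tensors="pt").to(device)
            outputs = model(**batch, labels=batch["input_ids"])
            total_loss += outputs.loss.item()
        model.zero_grad()

        average_loss = total_loss / (index + 1)
        results[model_name]['main-net'][str(epoch)] = [average_loss]
        print(f"Epoch: {epoch} Average Loss: {average_loss:.2f}")

        with open("results.json", "w") as outfile:
            json.dump(results, outfile, indent=4)

Written this way, re-running the script only fills in epochs that are missing from results.json, which is why the skip check now looks inside the 'main-net' sub-dict.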
results.json CHANGED
@@ -3285,9 +3285,306 @@
             ],
             "1094": [
                 5.5514771938323975
+            ],
+            "1095": [
+                5.654173533121745
+            ],
+            "1096": [
+                5.783674240112305
+            ],
+            "1097": [
+                5.732811212539673
+            ],
+            "1098": [
+                5.725842118263245
+            ],
+            "1099": [
+                6.016797780990601
+            ],
+            "1100": [
+                5.492693265279134
+            ],
+            "1101": [
+                5.746817111968994
+            ],
+            "1102": [
+                5.732641816139221
+            ],
+            "1103": [
+                5.6667522430419925
+            ],
+            "1104": [
+                6.042284965515137
+            ],
+            "1105": [
+                5.957233905792236
+            ],
+            "1106": [
+                6.250933527946472
+            ],
+            "1107": [
+                6.189672231674194
+            ],
+            "1108": [
+                5.723267146519253
+            ],
+            "1109": [
+                5.82790470123291
+            ],
+            "1110": [
+                5.603849093119304
+            ],
+            "1111": [
+                5.782259106636047
+            ],
+            "1112": [
+                5.6029471556345625
+            ],
+            "1113": [
+                5.546664714813232
+            ],
+            "1114": [
+                5.836468458175659
+            ],
+            "1115": [
+                5.762608170509338
+            ],
+            "1116": [
+                6.046388339996338
+            ],
+            "1117": [
+                6.0027131080627445
+            ],
+            "1118": [
+                6.388125038146972
+            ],
+            "1119": [
+                6.11626410484314
+            ],
+            "1120": [
+                6.112424373626709
+            ],
+            "1121": [
+                6.001961326599121
+            ],
+            "1122": [
+                5.683912754058838
+            ],
+            "1123": [
+                5.743198204040527
+            ],
+            "1124": [
+                5.722175757090251
+            ],
+            "1125": [
+                5.8825499534606935
+            ],
+            "1126": [
+                6.028543186187744
+            ],
+            "1127": [
+                5.9720460891723635
+            ],
+            "1128": [
+                6.058712800343831
+            ],
+            "1129": [
+                5.475493907928467
+            ],
+            "1130": [
+                5.970499634742737
+            ],
+            "1131": [
+                5.9493067264556885
+            ],
+            "1132": [
+                5.458620548248291
+            ],
+            "1133": [
+                5.992060820261638
+            ],
+            "1134": [
+                5.951226472854614
+            ],
+            "1135": [
+                5.877881646156311
+            ],
+            "1136": [
+                5.603206443786621
+            ],
+            "1137": [
+                5.8340943336486815
+            ],
+            "1138": [
+                5.788412570953369
+            ],
+            "1139": [
+                5.737103462219238
+            ],
+            "1140": [
+                5.636613210042317
+            ],
+            "1141": [
+                5.949309587478638
+            ],
+            "1142": [
+                5.854878067970276
+            ],
+            "1143": [
+                5.924749374389648
+            ],
+            "1144": [
+                6.321739387512207
+            ],
+            "1145": [
+                5.9811422030131025
+            ],
+            "1146": [
+                5.701364517211914
+            ],
+            "1147": [
+                5.503353691101074
+            ],
+            "1148": [
+                5.773120641708374
+            ],
+            "1149": [
+                6.042929470539093
+            ],
+            "1150": [
+                5.8076521555582685
+            ],
+            "1151": [
+                5.682760079701741
+            ],
+            "1152": [
+                5.757667303085327
+            ],
+            "1153": [
+                5.896499156951904
+            ],
+            "1154": [
+                6.025218367576599
+            ],
+            "1155": [
+                5.879011154174805
+            ],
+            "1156": [
+                5.868439674377441
+            ],
+            "1157": [
+                6.418252754211426
+            ],
+            "1158": [
+                6.2828675508499146
+            ],
+            "1159": [
+                6.36786642074585
+            ],
+            "1160": [
+                6.58310022354126
+            ],
+            "1161": [
+                6.19826873143514
+            ],
+            "1162": [
+                6.289691209793091
+            ],
+            "1163": [
+                5.9907801151275635
+            ],
+            "1164": [
+                6.041745066642761
+            ],
+            "1165": [
+                6.02010326385498
+            ],
+            "1166": [
+                5.7515941460927325
+            ],
+            "1167": [
+                5.48467755317688
+            ],
+            "1168": [
+                6.096215724945068
+            ],
+            "1169": [
+                5.959380865097046
+            ],
+            "1170": [
+                5.851028124491374
+            ],
+            "1171": [
+                5.8480740785598755
+            ],
+            "1172": [
+                5.9064167737960815
+            ],
+            "1173": [
+                5.956684430440267
+            ],
+            "1174": [
+                6.00377357006073
+            ],
+            "1175": [
+                6.077920118967692
+            ],
+            "1176": [
+                5.967975934346517
+            ],
+            "1177": [
+                6.253712558746338
+            ],
+            "1178": [
+                5.780354785919189
+            ],
+            "1179": [
+                5.4884843826293945
+            ],
+            "1180": [
+                5.482951482137044
+            ],
+            "1181": [
+                5.966793219248454
+            ],
+            "1182": [
+                5.51493239402771
+            ],
+            "1183": [
+                5.4840850830078125
+            ],
+            "1184": [
+                5.834247946739197
+            ],
+            "1185": [
+                5.770521521568298
+            ],
+            "1186": [
+                5.671548962593079
+            ],
+            "1187": [
+                5.491109371185303
+            ],
+            "1188": [
+                5.561888694763184
+            ],
+            "1189": [
+                5.711345076560974
+            ],
+            "1190": [
+                5.628474712371826
+            ],
+            "1191": [
+                5.514147567749023
+            ],
+            "1192": [
+                5.556583046913147
+            ],
+            "1193": [
+                5.653698126475017
             ]
         },
-        "baseline":{
+        "baseline": {
             "0": [
                 10.93135
             ],
@@ -3795,7 +4092,7 @@
         }
     },
     "distributed/optimized-gpt2-250m-v0.1.1": {
-        "main-net":{
+        "main-net": {
             "0": [
                 11.042416954040528
             ],
@@ -4727,7 +5024,7 @@
                 6.409368515014648
             ]
         },
-        "baseline":{
+        "baseline": {
             "0": [
                 10.93135
             ],
@@ -5235,7 +5532,7 @@
         }
     },
     "distributed/gpt2-94m": {
-        "main-net":{
+        "main-net": {
             "0": [
                 10.942681312561035
             ],
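The results.json hunks show the layout the script now writes: for each model a 'main-net' sub-dict keyed by epoch, each value a one-element list holding that epoch's average loss, alongside a 'baseline' sub-dict with the same shape. A minimal sketch of reading that layout back, assuming results.json sits next to the script:

import json

with open("results.json") as f:
    results = json.load(f)

# results[model]["main-net"][epoch] -> [average_loss]; "baseline" has the same shape.
for model_name, tracks in results.items():
    main_net = tracks["main-net"]
    latest_epoch = max(main_net, key=int)      # epoch keys are stored as strings
    latest_loss = main_net[latest_epoch][0]    # each entry is a one-element list
    print(f"{model_name}: epoch {latest_epoch}, loss {latest_loss:.3f}")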