kmfoda commited on
Commit
3358017
·
1 Parent(s): a1bdc8e

Upload results

Browse files
Files changed (2) hide show
  1. evaluate.py +2 -2
  2. results.json +45 -0
evaluate.py CHANGED
@@ -8,7 +8,7 @@ from huggingface_hub import list_repo_refs
8
  from transformers import AutoModelForCausalLM, AutoTokenizer
9
 
10
  device = "cuda"
11
- test_indices_length = 10
12
 
13
  models = ["distributed/optimized-gpt2-250m", "distributed/optimized-gpt2-250m-v0.1.1", "distributed/gpt2-94m"]
14
 
@@ -28,7 +28,7 @@ for model_name in models:
28
  refs = list_repo_refs(model_name, repo_type="model")
29
  global_epoch = max([int(tag.name) for tag in refs.tags]) if refs.tags else None
30
 
31
- for epoch in range(0,global_epoch):
32
 
33
  if str(epoch) in results[model_name]['main-net'].keys():
34
  continue
 
8
  from transformers import AutoModelForCausalLM, AutoTokenizer
9
 
10
  device = "cuda"
11
+ test_indices_length = 1000
12
 
13
  models = ["distributed/optimized-gpt2-250m", "distributed/optimized-gpt2-250m-v0.1.1", "distributed/gpt2-94m"]
14
 
 
28
  refs = list_repo_refs(model_name, repo_type="model")
29
  global_epoch = max([int(tag.name) for tag in refs.tags]) if refs.tags else None
30
 
31
+ for epoch in range(0,global_epoch, 5):
32
 
33
  if str(epoch) in results[model_name]['main-net'].keys():
34
  continue
results.json CHANGED
@@ -771,6 +771,51 @@
771
  ],
772
  "1280": [
773
  5.533918690816667
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
774
  ]
775
  },
776
  "baseline": {
 
771
  ],
772
  "1280": [
773
  5.533918690816667
774
+ ],
775
+ "1285": [
776
+ 5.567983552905323
777
+ ],
778
+ "1290": [
779
+ 5.5455932653448885
780
+ ],
781
+ "1295": [
782
+ 5.618938806755789
783
+ ],
784
+ "1300": [
785
+ 5.57726935165956
786
+ ],
787
+ "1305": [
788
+ 5.504214384279199
789
+ ],
790
+ "1310": [
791
+ 5.591931197914315
792
+ ],
793
+ "1315": [
794
+ 5.649649393850551
795
+ ],
796
+ "1320": [
797
+ 5.618240403285898
798
+ ],
799
+ "1325": [
800
+ 5.617241533815735
801
+ ],
802
+ "1330": [
803
+ 5.576359107468154
804
+ ],
805
+ "1335": [
806
+ 5.539896753655762
807
+ ],
808
+ "1340": [
809
+ 5.518490235810309
810
+ ],
811
+ "1345": [
812
+ 5.588169578686511
813
+ ],
814
+ "1350": [
815
+ 5.50622625116791
816
+ ],
817
+ "1355": [
818
+ 5.502286915139202
819
  ]
820
  },
821
  "baseline": {