Update to include gpt2-2b

- evaluate.py +60 -56
- results.json +260 -0

evaluate.py CHANGED

@@ -1,22 +1,29 @@
 import json
+import os
 import random
+import time
 
 import torch
-import time
-import os
 from distributed_training.data.dataset import DataLoader
-from huggingface_hub import list_repo_refs
-from transformers import AutoModelForCausalLM, AutoTokenizer
 from huggingface_hub import create_tag, list_repo_refs, scan_cache_dir
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
 device = "cuda"
 test_indices_length = 1000
 AUTOMATE = True
 
-models = [ …
+models = [
+    "distributed/optimized-gpt2-2b",
+    "distributed/optimized-gpt2-1b",
+    "distributed/optimized-gpt2-500m",
+    "distributed/optimized-gpt2-250m",
+    "distributed/optimized-gpt2-250m-v0.1.3",
+    "distributed/optimized-gpt2-250m-v0.1.1",
+    "distributed/gpt2-94m",
+]
 
 if os.path.exists("results.json"):
-    with open( …
+    with open("results.json", "r") as file:
         results = json.load(file)
 else:
     results = {}
@@ -24,39 +31,34 @@ else:
 while True:
     for model_name in [models[0]]:
 
-        if …
+        if model_name not in results.keys():
             results[model_name] = {}
 
-        tokenizer = AutoTokenizer.from_pretrained( …
+        tokenizer = AutoTokenizer.from_pretrained(
+            "distributed/optimized-gpt2-250m", trust_remote_code=True
+        )
 
         refs = list_repo_refs(model_name, repo_type="model")
         global_epoch = max([int(tag.name) for tag in refs.tags]) if refs.tags else None
 
-        if global_epoch in results[model_name][ …
+        if global_epoch in results[model_name]["main-net"].keys():
             print(f"Results for epoch {global_epoch} already calcualted")
-            time.sleep(30*60)
+            time.sleep(30 * 60)
 
-        for epoch in range(0,global_epoch, 1):
+        for epoch in range(0, global_epoch, 1):
 
-            if str(epoch) in results[model_name][ …
+            if str(epoch) in results[model_name]["main-net"].keys():
                 continue
 
-            model = AutoModelForCausalLM.from_pretrained( …
+            model = AutoModelForCausalLM.from_pretrained(
+                model_name, revision=str(epoch), trust_remote_code=True
+            )
             model = model.to(device)
 
             search_start = random.choice(
-                range(
-                    DataLoader.max_pages
-                    - test_indices_length
-                    + 1
-                )
+                range(DataLoader.max_pages - test_indices_length + 1)
             )
-            group = [
-                i
-                for i in range(
-                    search_start, search_start + test_indices_length
-                )
-            ]
+            group = [i for i in range(search_start, search_start + test_indices_length)]
 
             dataloader = DataLoader(
                 batch_size=1,
@@ -71,7 +73,7 @@ while True:
             inputs = batch[0].to(device)
             labels = batch[1].to(device)
 
-            if …
+            if len(inputs[0]) != len(labels[0]):
                 breakpoint()
             if "optimized" in model_name:
                 outputs = model(input_ids=inputs, labels=labels)
@@ -86,37 +88,39 @@ while True:
             # Backward Pass
             model.zero_grad()
 
-            average_loss = total_loss / (index+1)
-            results[model_name][ …
+            average_loss = total_loss / (index + 1)
+            results[model_name]["main-net"][str(epoch)] = [average_loss]
             print(f"Epoch: {epoch} Average Loss: {average_loss:.2f}")
 
             with open("results.json", "w") as outfile:
-                json.dump(results, outfile, indent …
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                json.dump(results, outfile, indent=4)
+
+            current_revision = model.config._commit_hash
+            keep_recent = 1
+            try:
+                cache_info = scan_cache_dir()
+                for repo in cache_info.repos:
+                    if repo.repo_id == model_name:
+                        revisions = sorted(
+                            repo.revisions, key=lambda r: r.last_modified, reverse=True
+                        )
+                        current_index = next(
+                            (
+                                i
+                                for i, r in enumerate(revisions)
+                                if r.commit_hash == current_revision
+                            ),
+                            None,
+                        )
+                        if current_index is not None:
+                            for revision in revisions[
+                                max(current_index + 1, keep_recent) :
+                            ]:
+                                cache_info.delete_revisions(
+                                    revision.commit_hash
+                                ).execute()
+                            break
+            except:
+                print(
+                    "Failed to delete previous model version from cache. This might lead to 100% disk space utlisation in the future."
+                )

results.json CHANGED

@@ -1,4 +1,264 @@
 {
+    "distributed/optimized-gpt2-2b": {
+        "main-net": {
+            "0": [
+                24.161225109466358
+            ],
+            "1": [
+                10.691863035173064
+            ],
+            "2": [
+                10.022417357756433
+            ],
+            "3": [
+                9.55640449560465
+            ],
+            "4": [
+                9.58835850097239
+            ],
+            "5": [
+                9.462455684855833
+            ],
+            "6": [
+                9.277088693210057
+            ],
+            "7": [
+                9.301309664550528
+            ],
+            "8": [
+                9.130553578411487
+            ],
+            "9": [
+                9.198034809787515
+            ],
+            "10": [
+                9.150826927009486
+            ],
+            "11": [
+                9.101872412292888
+            ],
+            "12": [
+                9.037806239881014
+            ],
+            "13": [
+                8.92622016663717
+            ],
+            "14": [
+                8.890519196425027
+            ],
+            "15": [
+                8.851365919739123
+            ],
+            "16": [
+                8.86057827515688
+            ],
+            "17": [
+                8.770904886188791
+            ],
+            "18": [
+                8.787317971086813
+            ],
+            "19": [
+                8.762063222648823
+            ],
+            "20": [
+                8.691791485353338
+            ],
+            "21": [
+                8.612739718558705
+            ],
+            "22": [
+                8.662117434136661
+            ],
+            "23": [
+                8.569304224873378
+            ],
+            "24": [
+                8.508418809899077
+            ],
+            "25": [
+                8.416297421540703
+            ],
+            "26": [
+                8.395312497974823
+            ],
+            "27": [
+                8.361652030098822
+            ],
+            "28": [
+                8.309751656976077
+            ],
+            "29": [
+                8.271234605559991
+            ],
+            "30": [
+                8.302275388588832
+            ],
+            "31": [
+                8.172970907627247
+            ],
+            "32": [
+                8.112572425603867
+            ],
+            "33": [
+                0.0
+            ],
+            "34": [
+                8.067226740030142
+            ],
+            "35": [
+                8.015923105084333
+            ],
+            "36": [
+                8.000407182927034
+            ],
+            "37": [
+                7.897538427511851
+            ],
+            "38": [
+                7.8652859703003175
+            ],
+            "39": [
+                7.817014323654024
+            ],
+            "40": [
+                7.807054872649335
+            ],
+            "41": [
+                7.827541650510302
+            ],
+            "42": [
+                7.689037536915112
+            ],
+            "43": [
+                7.757941870595895
+            ],
+            "44": [
+                7.804858555885658
+            ],
+            "45": [
+                7.6064825819472395
+            ],
+            "46": [
+                7.611153989357136
+            ],
+            "47": [
+                7.59192221113976
+            ],
+            "48": [
+                7.578028715109523
+            ],
+            "49": [
+                7.535055926722339
+            ],
+            "50": [
+                7.4285404285591445
+            ],
+            "51": [
+                7.508890847739933
+            ],
+            "52": [
+                7.594857940802703
+            ],
+            "53": [
+                7.512502627618094
+            ],
+            "54": [
+                7.506787989576394
+            ],
+            "55": [
+                7.501947044107324
+            ],
+            "56": [
+                7.429504378702631
+            ],
+            "57": [
+                7.372085496371972
+            ],
+            "58": [
+                7.408436578101554
+            ],
+            "59": [
+                7.408653273726955
+            ],
+            "60": [
+                7.3867659356859
+            ],
+            "61": [
+                7.328268373037534
+            ],
+            "62": [
+                7.374929182813982
+            ],
+            "63": [
+                7.309664613777591
+            ],
+            "64": [
+                7.282248006827795
+            ],
+            "65": [
+                7.386888501138398
+            ],
+            "66": [
+                7.2420131648637325
+            ],
+            "67": [
+                7.3391031794848
+            ],
+            "68": [
+                7.266478459521978
+            ],
+            "69": [
+                7.2372944774106145
+            ],
+            "70": [
+                7.293267532594487
+            ],
+            "71": [
+                7.174058415324812
+            ],
+            "72": [
+                7.300561785442671
+            ],
+            "73": [
+                7.2531355329462
+            ],
+            "74": [
+                7.176742718436501
+            ],
+            "75": [
+                7.150713069236231
+            ],
+            "76": [
+                7.181416940589538
+            ],
+            "77": [
+                7.206587009709836
+            ],
+            "78": [
+                7.08934546457475
+            ],
+            "79": [
+                7.042178735546037
+            ],
+            "80": [
+                7.034408964761874
+            ],
+            "81": [
+                7.1328608132201765
+            ],
+            "82": [
+                7.020384328287156
+            ],
+            "83": [
+                6.989416580784077
+            ],
+            "84": [
+                7.075146196260734
+            ]
+        },
+        "baseline": {}
+    },
     "distributed/optimized-gpt2-1b": {
         "main-net": {
             "0": [
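
For reference, a short way to read these numbers back out of results.json; a sketch assuming the file layout shown above, where each epoch maps to a one-element list of average losses and the 0.0 recorded at epoch "33" looks like a failed run, so it is skipped.

import json

with open("results.json") as f:
    results = json.load(f)

losses = {
    int(epoch): values[0]
    for epoch, values in results["distributed/optimized-gpt2-2b"]["main-net"].items()
    if values and values[0] > 0.0  # epoch "33" recorded 0.0; skip it
}
best_epoch = min(losses, key=losses.get)
print(f"Best epoch: {best_epoch} (loss {losses[best_epoch]:.3f})")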