kmfoda commited on
Commit
2467ab2
·
1 Parent(s): 34e3f6e

Add app.py

Browse files
Files changed (3) hide show
  1. app.py +33 -0
  2. evaluate.py +94 -0
  3. results.json +318 -0
app.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import json
3
+ import pandas as pd
4
+
5
+ with open('results.json', 'r') as file:
6
+ results = json.load(file)
7
+
8
+ models = [key for key in results.keys()]
9
+ demo = gr.Blocks()
10
+
11
+ df = pd.DataFrame.from_dict(results[models[0]], orient = "index").reset_index()
12
+ df.columns = ["Step", "Loss"]
13
+ df["Step"] = pd.to_numeric(df["Step"])
14
+
15
+ def return_results(model_name):
16
+ print(model_name)
17
+ df = pd.DataFrame.from_dict(results[model_name], orient = "index").reset_index()
18
+ df.columns = ["Step", "Loss"]
19
+ df["Step"] = pd.to_numeric(df["Step"])
20
+ return df
21
+
22
+ with demo:
23
+ with gr.Row():
24
+ title = gr.Markdown(value=f"""# <p style="text-align: center;"> Subnet 38 Model Convergence</p>""")
25
+ with gr.Row():
26
+ dropdown_1 = gr.Dropdown(choices = models, value = models[0])
27
+ button_1 = gr.Button("Submit")
28
+ with gr.Row():
29
+ chart = gr.LinePlot(df, "Step", "Loss")
30
+
31
+ button_1.click(return_results, dropdown_1, chart)
32
+
33
+ demo.launch(debug=True, server_name="0.0.0.0", server_port=7860)
evaluate.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ from distributed_training.data.dataset import DataLoader
4
+ import random
5
+ from huggingface_hub import list_repo_refs
6
+ import matplotlib.pyplot as plt
7
+ import json
8
+
9
+ device = "cuda"
10
+ test_indices_length = 10
11
+
12
+ models = ["distributed/optimized-gpt2-250m", "distributed/gpt2-250m"]
13
+
14
+ with open('./results.json', 'r') as file:
15
+ results = json.load(file)
16
+
17
+ for model_name in models:
18
+
19
+ if (model_name not in results.keys()) or (model_name == "distributed/optimized-gpt2-250m"):
20
+ results[model_name] = {}
21
+
22
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
23
+
24
+ refs = list_repo_refs(model_name, repo_type="model")
25
+ global_epoch = max([int(tag.name) for tag in refs.tags]) if refs.tags else None
26
+
27
+ for epoch in range(0, global_epoch, 5):
28
+ # for epoch in [global_epoch]:
29
+
30
+ if str(epoch) in results[model_name].keys():
31
+ continue
32
+
33
+ model = AutoModelForCausalLM.from_pretrained(model_name, revision=str(epoch), trust_remote_code=True)
34
+ model = model.to(device)
35
+
36
+ search_start = random.choice(
37
+ range(
38
+ DataLoader.max_pages
39
+ - test_indices_length
40
+ + 1
41
+ )
42
+ )
43
+ group = [
44
+ i
45
+ for i in range(
46
+ search_start, search_start + test_indices_length
47
+ )
48
+ ]
49
+
50
+ dataloader = DataLoader(
51
+ batch_size=1,
52
+ sequence_length=1024,
53
+ rows=group,
54
+ )
55
+
56
+ total_loss = 0
57
+ index = 0
58
+ # Train data for one epoch
59
+ for index, batch in enumerate(dataloader):
60
+ inputs = batch[0].to(device)
61
+ labels = batch[1].to(device)
62
+
63
+ if (len(inputs[0]) != len(labels[0])):
64
+ breakpoint()
65
+
66
+ if "optimized" in model_name:
67
+ outputs = model(input_ids=inputs, labels=labels)
68
+ loss = outputs[1]
69
+ else:
70
+ outputs = model(input_ids=inputs, labels=inputs)
71
+ loss = outputs.loss
72
+
73
+ # Accumulate Total Loss
74
+ total_loss += loss.detach().item()
75
+
76
+ # Backward Pass
77
+ model.zero_grad()
78
+
79
+ average_loss = total_loss / (index+1)
80
+ results[model_name][str(epoch)] = [average_loss]
81
+ print(f"Epoch: {epoch} Average Loss: {average_loss:.2f}")
82
+
83
+ # breakpoint()
84
+ with open("./results.json", "w") as outfile:
85
+ json.dump(results, outfile, indent = 4)
86
+
87
+ for model_name in models:
88
+
89
+ plt.plot(results[model_name].keys(), results[model_name].values())
90
+ plt.title(f"{model_name} Convergence Over Time")
91
+ plt.xlabel("Steps")
92
+ plt.ylabel("Loss")
93
+ plt.xticks(fontsize=3.5)
94
+ plt.savefig(f"{model_name.split('/')[1]}_results.png")
results.json ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "distributed/optimized-gpt2-250m": {
3
+ "0": [
4
+ 11.042416954040528
5
+ ],
6
+ "5": [
7
+ 9.064676761627197
8
+ ],
9
+ "10": [
10
+ 8.353436279296876
11
+ ],
12
+ "15": [
13
+ 8.157295894622802
14
+ ],
15
+ "20": [
16
+ 7.744552771250407
17
+ ],
18
+ "25": [
19
+ 7.923193550109863
20
+ ],
21
+ "30": [
22
+ 7.360100865364075
23
+ ],
24
+ "35": [
25
+ 7.582625230153401
26
+ ],
27
+ "40": [
28
+ 7.635447263717651
29
+ ],
30
+ "45": [
31
+ 7.298124694824219
32
+ ],
33
+ "50": [
34
+ 7.584524154663086
35
+ ],
36
+ "55": [
37
+ 7.3763152122497555
38
+ ],
39
+ "60": [
40
+ 7.288678407669067
41
+ ],
42
+ "65": [
43
+ 7.490873456001282
44
+ ],
45
+ "70": [
46
+ 6.960979843139649
47
+ ],
48
+ "75": [
49
+ 7.144528865814209
50
+ ],
51
+ "80": [
52
+ 7.195922565460205
53
+ ],
54
+ "85": [
55
+ 7.632096767425537
56
+ ],
57
+ "90": [
58
+ 7.1985063552856445
59
+ ],
60
+ "95": [
61
+ 6.93459119796753
62
+ ],
63
+ "100": [
64
+ 6.701247930526733
65
+ ],
66
+ "105": [
67
+ 7.049336791038513
68
+ ],
69
+ "110": [
70
+ 6.837615370750427
71
+ ],
72
+ "115": [
73
+ 7.020212531089783
74
+ ],
75
+ "120": [
76
+ 6.697751712799072
77
+ ],
78
+ "125": [
79
+ 6.588788318634033
80
+ ],
81
+ "130": [
82
+ 6.7763800621032715
83
+ ],
84
+ "135": [
85
+ 6.9689741134643555
86
+ ],
87
+ "140": [
88
+ 6.709237098693848
89
+ ],
90
+ "145": [
91
+ 7.035352826118469
92
+ ],
93
+ "150": [
94
+ 6.6759562492370605
95
+ ],
96
+ "155": [
97
+ 6.7904438972473145
98
+ ],
99
+ "160": [
100
+ 6.934930443763733
101
+ ],
102
+ "165": [
103
+ 6.596151669820149
104
+ ],
105
+ "170": [
106
+ 6.548283481597901
107
+ ],
108
+ "175": [
109
+ 6.447548770904541
110
+ ],
111
+ "180": [
112
+ 6.536311149597168
113
+ ],
114
+ "185": [
115
+ 6.70653502146403
116
+ ],
117
+ "190": [
118
+ 6.557690461476644
119
+ ],
120
+ "195": [
121
+ 6.67773175239563
122
+ ],
123
+ "200": [
124
+ 6.467767238616943
125
+ ],
126
+ "205": [
127
+ 6.4236222267150875
128
+ ],
129
+ "210": [
130
+ 6.6386902809143065
131
+ ],
132
+ "215": [
133
+ 6.141726970672607
134
+ ],
135
+ "220": [
136
+ 6.378688907623291
137
+ ],
138
+ "225": [
139
+ 6.42099928855896
140
+ ],
141
+ "230": [
142
+ 6.738618612289429
143
+ ],
144
+ "235": [
145
+ 6.558012008666992
146
+ ],
147
+ "240": [
148
+ 6.777796030044556
149
+ ],
150
+ "245": [
151
+ 6.396033000946045
152
+ ],
153
+ "250": [
154
+ 6.102731609344483
155
+ ],
156
+ "255": [
157
+ 6.540631294250488
158
+ ]
159
+ },
160
+ "distributed/gpt2-250m": {
161
+ "0": [
162
+ 10.942681312561035
163
+ ],
164
+ "5": [
165
+ 9.673693656921387
166
+ ],
167
+ "10": [
168
+ 9.623630285263062
169
+ ],
170
+ "15": [
171
+ 9.381710529327393
172
+ ],
173
+ "20": [
174
+ 9.240305423736572
175
+ ],
176
+ "25": [
177
+ 9.34835402170817
178
+ ],
179
+ "30": [
180
+ 9.45114345550537
181
+ ],
182
+ "35": [
183
+ 9.190510940551757
184
+ ],
185
+ "40": [
186
+ 8.936849594116211
187
+ ],
188
+ "45": [
189
+ 8.903728485107422
190
+ ],
191
+ "50": [
192
+ 8.871788597106933
193
+ ],
194
+ "55": [
195
+ 8.653409957885742
196
+ ],
197
+ "60": [
198
+ 8.565237998962402
199
+ ],
200
+ "65": [
201
+ 8.616942405700684
202
+ ],
203
+ "70": [
204
+ 8.725053310394287
205
+ ],
206
+ "75": [
207
+ 8.058599853515625
208
+ ],
209
+ "80": [
210
+ 8.40323429107666
211
+ ],
212
+ "85": [
213
+ 8.251930522918702
214
+ ],
215
+ "90": [
216
+ 8.315114784240723
217
+ ],
218
+ "95": [
219
+ 8.024084663391113
220
+ ],
221
+ "100": [
222
+ 8.095765829086304
223
+ ],
224
+ "105": [
225
+ 8.223698139190674
226
+ ],
227
+ "110": [
228
+ 7.960695743560791
229
+ ],
230
+ "115": [
231
+ 7.827797985076904
232
+ ],
233
+ "120": [
234
+ 8.389174143473307
235
+ ],
236
+ "125": [
237
+ 7.795609354972839
238
+ ],
239
+ "130": [
240
+ 8.024239349365235
241
+ ],
242
+ "135": [
243
+ 7.622925678888957
244
+ ],
245
+ "140": [
246
+ 7.671920299530029
247
+ ],
248
+ "145": [
249
+ 7.719462108612061
250
+ ],
251
+ "150": [
252
+ 7.654707551002502
253
+ ],
254
+ "155": [
255
+ 7.858335399627686
256
+ ],
257
+ "160": [
258
+ 7.582762241363525
259
+ ],
260
+ "165": [
261
+ 7.7280534108479815
262
+ ],
263
+ "170": [
264
+ 7.398298358917236
265
+ ],
266
+ "175": [
267
+ 7.448758959770203
268
+ ],
269
+ "180": [
270
+ 7.248022079467773
271
+ ],
272
+ "185": [
273
+ 7.408734480539958
274
+ ],
275
+ "190": [
276
+ 7.431381821632385
277
+ ],
278
+ "195": [
279
+ 7.13822078704834
280
+ ],
281
+ "200": [
282
+ 7.499457120895386
283
+ ],
284
+ "205": [
285
+ 7.281359386444092
286
+ ],
287
+ "210": [
288
+ 7.49737777709961
289
+ ],
290
+ "215": [
291
+ 7.441878795623779
292
+ ],
293
+ "220": [
294
+ 7.27855650583903
295
+ ],
296
+ "225": [
297
+ 7.162156343460083
298
+ ],
299
+ "230": [
300
+ 7.732161164283752
301
+ ],
302
+ "235": [
303
+ 6.726261933644612
304
+ ],
305
+ "240": [
306
+ 6.9339855194091795
307
+ ],
308
+ "245": [
309
+ 7.31608259677887
310
+ ],
311
+ "250": [
312
+ 7.316546440124512
313
+ ],
314
+ "255": [
315
+ 7.263134765625
316
+ ]
317
+ }
318
+ }