File size: 21,928 Bytes
bcb8ccf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
from typing import List

import datasets
import evaluate
import numpy as np

import pytest


@pytest.fixture(scope="session")
def syntaxgym_dataset():
    return datasets.load_dataset("syntaxgym", "subordination_src-src")


@pytest.fixture(scope="session")
def syntaxgym_metric():
    return evaluate.load("./syntaxgym.py")


@pytest.fixture(scope="session")
def model_ref():
    # return "hf-internal-testing/tiny-random-gpt_neo"
    return "gpt2"


# Reference region surprisals computed with syntaxgym-core.
# See notebook in https://colab.research.google.com/drive/1qziyPcu65jffizSPi-ZGHKR0x7BaHFMS#scrollTo=RgtnScy6LLKi .
GPT2_SUBORDINATION_SRC_REFERENCE = \
[{('no-sub_matrix', 1): 13.151199615123803,
  ('no-sub_matrix', 2): 38.503222716703526,
  ('no-sub_matrix', 3): 27.623861034812286,
  ('no-sub_matrix', 4): 48.831672846038224,
  ('no-sub_matrix', 5): 38.08533699286694,
  ('no-sub_no-matrix', 1): 13.151199615123803,
  ('no-sub_no-matrix', 2): 38.503222716703526,
  ('no-sub_no-matrix', 3): 27.623861034812286,
  ('no-sub_no-matrix', 4): 48.831687980511504,
  ('no-sub_no-matrix', 5): 1.8096143510772873,
  ('sub_matrix', 1): 14.905592916748805,
  ('sub_matrix', 2): 39.06304309956175,
  ('sub_matrix', 3): 26.862648365854433,
  ('sub_matrix', 4): 50.56554401687938,
  ('sub_matrix', 5): 26.532245572980194,
  ('sub_no-matrix', 1): 14.905592916748805,
  ('sub_no-matrix', 2): 39.06304309956175,
  ('sub_no-matrix', 3): 26.862648365854433,
  ('sub_no-matrix', 4): 50.56553438585093,
  ('sub_no-matrix', 5): 7.470089829866611},
 {('no-sub_matrix', 1): 10.116093820255577,
  ('no-sub_matrix', 2): 20.96513246705127,
  ('no-sub_matrix', 3): 20.02959138986416,
  ('no-sub_matrix', 4): 23.779661397107446,
  ('no-sub_matrix', 5): 33.2560281692696,
  ('no-sub_no-matrix', 1): 10.116093820255577,
  ('no-sub_no-matrix', 2): 20.96513246705127,
  ('no-sub_no-matrix', 3): 20.02959138986416,
  ('no-sub_no-matrix', 4): 23.779661397107446,
  ('no-sub_no-matrix', 5): 1.9449125865631063,
  ('sub_matrix', 1): 13.545157521732826,
  ('sub_matrix', 2): 24.96048395897244,
  ('sub_matrix', 3): 18.609464944317324,
  ('sub_matrix', 4): 23.057566440062317,
  ('sub_matrix', 5): 26.424454285669032,
  ('sub_no-matrix', 1): 13.545157521732826,
  ('sub_no-matrix', 2): 24.96048395897244,
  ('sub_no-matrix', 3): 18.609464944317324,
  ('sub_no-matrix', 4): 23.057566440062317,
  ('sub_no-matrix', 5): 2.807467838359704},
 {('no-sub_matrix', 1): 11.992867568477442,
  ('no-sub_matrix', 2): 45.813114232935774,
  ('no-sub_matrix', 3): 24.57554828372551,
  ('no-sub_matrix', 4): 45.334025774062916,
  ('no-sub_matrix', 5): 26.208189541862073,
  ('no-sub_no-matrix', 1): 11.992867568477442,
  ('no-sub_no-matrix', 2): 45.813114232935774,
  ('no-sub_no-matrix', 3): 24.57554828372551,
  ('no-sub_no-matrix', 4): 45.33402766587207,
  ('no-sub_no-matrix', 5): 1.8284485151385752,
  ('sub_matrix', 1): 14.219887768799735,
  ('sub_matrix', 2): 46.25055434117979,
  ('sub_matrix', 3): 23.054221678472672,
  ('sub_matrix', 4): 47.08503858470256,
  ('sub_matrix', 5): 22.154772321452022,
  ('sub_no-matrix', 1): 14.219887768799735,
  ('sub_no-matrix', 2): 46.25055434117979,
  ('sub_no-matrix', 3): 23.054221678472672,
  ('sub_no-matrix', 4): 47.08503858470256,
  ('sub_no-matrix', 5): 3.0655133594366757},
 {('no-sub_matrix', 1): 10.55002943802296,
  ('no-sub_matrix', 2): 52.419810137608856,
  ('no-sub_matrix', 3): 23.30710475332303,
  ('no-sub_matrix', 4): 37.957905964008944,
  ('no-sub_matrix', 5): 29.259648135104936,
  ('no-sub_no-matrix', 1): 10.55002943802296,
  ('no-sub_no-matrix', 2): 52.419810137608856,
  ('no-sub_no-matrix', 3): 23.30710475332303,
  ('no-sub_no-matrix', 4): 37.957905964008944,
  ('no-sub_no-matrix', 5): 1.9632913405649093,
  ('sub_matrix', 1): 15.289384584900025,
  ('sub_matrix', 2): 53.93652737134243,
  ('sub_matrix', 3): 19.43915835312633,
  ('sub_matrix', 4): 36.459591551099386,
  ('sub_matrix', 5): 22.185742699245417,
  ('sub_no-matrix', 1): 15.289384584900025,
  ('sub_no-matrix', 2): 53.93652737134243,
  ('sub_no-matrix', 3): 19.43915835312633,
  ('sub_no-matrix', 4): 36.4595598203003,
  ('sub_no-matrix', 5): 5.707732355645454},
 {('no-sub_matrix', 1): 23.543723213902986,
  ('no-sub_matrix', 2): 31.967972102825854,
  ('no-sub_matrix', 3): 29.159572978411727,
  ('no-sub_matrix', 4): 36.61365345925747,
  ('no-sub_matrix', 5): 44.576591305970545,
  ('no-sub_no-matrix', 1): 23.543723213902986,
  ('no-sub_no-matrix', 2): 31.967972102825854,
  ('no-sub_no-matrix', 3): 29.159572978411727,
  ('no-sub_no-matrix', 4): 36.61365345925747,
  ('no-sub_no-matrix', 5): 3.2813457388593714,
  ('sub_matrix', 1): 27.118410129310597,
  ('sub_matrix', 2): 33.909617362987866,
  ('sub_matrix', 3): 28.791166362258743,
  ('sub_matrix', 4): 37.24960609010374,
  ('sub_matrix', 5): 31.660933798006262,
  ('sub_no-matrix', 1): 27.118410129310597,
  ('sub_no-matrix', 2): 33.909617362987866,
  ('sub_no-matrix', 3): 28.791166362258743,
  ('sub_no-matrix', 4): 37.24960609010374,
  ('sub_no-matrix', 5): 7.3613541428239015},
 {('no-sub_matrix', 1): 14.22171869610082,
  ('no-sub_matrix', 2): 30.270423022911977,
  ('no-sub_matrix', 3): 25.973276891204705,
  ('no-sub_matrix', 4): 28.43856735947716,
  ('no-sub_matrix', 5): 57.39887418731055,
  ('no-sub_no-matrix', 1): 14.22171869610082,
  ('no-sub_no-matrix', 2): 30.270423022911977,
  ('no-sub_no-matrix', 3): 25.973276891204705,
  ('no-sub_no-matrix', 4): 28.43856735947716,
  ('no-sub_no-matrix', 5): 1.7127059109344136,
  ('sub_matrix', 1): 16.39289784951447,
  ('sub_matrix', 2): 31.5671111565765,
  ('sub_matrix', 3): 24.54307828171008,
  ('sub_matrix', 4): 29.249645624130757,
  ('sub_matrix', 5): 53.59155769093577,
  ('sub_no-matrix', 1): 16.39289784951447,
  ('sub_no-matrix', 2): 31.5671111565765,
  ('sub_no-matrix', 3): 24.54307828171008,
  ('sub_no-matrix', 4): 29.249645624130757,
  ('sub_no-matrix', 5): 7.225276653947023},
 {('no-sub_matrix', 1): 13.729688714733188,
  ('no-sub_matrix', 2): 36.018118127225165,
  ('no-sub_matrix', 3): 28.232055923783275,
  ('no-sub_matrix', 4): 44.44634394296659,
  ('no-sub_matrix', 5): 38.277975147059344,
  ('no-sub_no-matrix', 1): 13.729688714733188,
  ('no-sub_no-matrix', 2): 36.018118127225165,
  ('no-sub_no-matrix', 3): 28.232055923783275,
  ('no-sub_no-matrix', 4): 44.44634394296659,
  ('no-sub_no-matrix', 5): 3.0318996942908414,
  ('sub_matrix', 1): 16.93528744674245,
  ('sub_matrix', 2): 36.545024814326574,
  ('sub_matrix', 3): 26.279603445823692,
  ('sub_matrix', 4): 46.501226364074995,
  ('sub_matrix', 5): 32.155418057793035,
  ('sub_no-matrix', 1): 16.93528744674245,
  ('sub_no-matrix', 2): 36.545024814326574,
  ('sub_no-matrix', 3): 26.279603445823692,
  ('sub_no-matrix', 4): 46.501226364074995,
  ('sub_no-matrix', 5): 4.4581122618864155},
 {('no-sub_matrix', 1): 15.598113737151568,
  ('no-sub_matrix', 2): 56.12543415244172,
  ('no-sub_matrix', 3): 29.755667770007285,
  ('no-sub_matrix', 4): 51.689282097269995,
  ('no-sub_matrix', 5): 45.575230324010775,
  ('no-sub_no-matrix', 1): 15.598113737151568,
  ('no-sub_no-matrix', 2): 56.12543415244172,
  ('no-sub_no-matrix', 3): 29.755667770007285,
  ('no-sub_no-matrix', 4): 51.68928424705313,
  ('no-sub_no-matrix', 5): 1.235207173694806,
  ('sub_matrix', 1): 18.909088991066888,
  ('sub_matrix', 2): 57.753410746636746,
  ('sub_matrix', 3): 28.677667873674363,
  ('sub_matrix', 4): 51.99410775929489,
  ('sub_matrix', 5): 35.754144966112236,
  ('sub_no-matrix', 1): 18.909088991066888,
  ('sub_no-matrix', 2): 57.753410746636746,
  ('sub_no-matrix', 3): 28.677667873674363,
  ('sub_no-matrix', 4): 51.9941480032352,
  ('sub_no-matrix', 5): 5.033266273930268},
 {('no-sub_matrix', 1): 14.859413855165633,
  ('no-sub_matrix', 2): 34.54519231993284,
  ('no-sub_matrix', 3): 24.26528519671309,
  ('no-sub_matrix', 4): 35.42343514121054,
  ('no-sub_matrix', 5): 55.85308623165151,
  ('no-sub_no-matrix', 1): 14.859413855165633,
  ('no-sub_no-matrix', 2): 34.54519231993284,
  ('no-sub_no-matrix', 3): 24.26528519671309,
  ('no-sub_no-matrix', 4): 35.42343514121054,
  ('no-sub_no-matrix', 5): 2.3309861205259734,
  ('sub_matrix', 1): 17.053809634549854,
  ('sub_matrix', 2): 33.66637542056656,
  ('sub_matrix', 3): 23.26181234829638,
  ('sub_matrix', 4): 35.61438567264568,
  ('sub_matrix', 5): 48.48551986050014,
  ('sub_no-matrix', 1): 17.053809634549854,
  ('sub_no-matrix', 2): 33.66637542056656,
  ('sub_no-matrix', 3): 23.26181234829638,
  ('sub_no-matrix', 4): 35.61438704850689,
  ('sub_no-matrix', 5): 2.969309360231736},
 {('no-sub_matrix', 1): 13.708973748402064,
  ('no-sub_matrix', 2): 31.147590264691182,
  ('no-sub_matrix', 3): 30.495597241955565,
  ('no-sub_matrix', 4): 34.65164493728535,
  ('no-sub_matrix', 5): 35.87510990950117,
  ('no-sub_no-matrix', 1): 13.708973748402064,
  ('no-sub_no-matrix', 2): 31.147590264691182,
  ('no-sub_no-matrix', 3): 30.495597241955565,
  ('no-sub_no-matrix', 4): 34.65164493728535,
  ('no-sub_no-matrix', 5): 3.232032121481573,
  ('sub_matrix', 1): 17.681722076468287,
  ('sub_matrix', 2): 33.77225997922327,
  ('sub_matrix', 3): 29.435808932487806,
  ('sub_matrix', 4): 34.354368969668016,
  ('sub_matrix', 5): 20.802733205442486,
  ('sub_no-matrix', 1): 17.681722076468287,
  ('sub_no-matrix', 2): 33.77225997922327,
  ('sub_no-matrix', 3): 29.435808932487806,
  ('sub_no-matrix', 4): 34.354368969668016,
  ('sub_no-matrix', 5): 3.7902066303710424},
 {('no-sub_matrix', 1): 15.72185319065555,
  ('no-sub_matrix', 2): 45.25539814380218,
  ('no-sub_matrix', 3): 24.94273362957689,
  ('no-sub_matrix', 4): 40.81704901026569,
  ('no-sub_matrix', 5): 42.898794519499596,
  ('no-sub_no-matrix', 1): 15.72185319065555,
  ('no-sub_no-matrix', 2): 45.25539814380218,
  ('no-sub_no-matrix', 3): 24.94273362957689,
  ('no-sub_no-matrix', 4): 40.81704901026569,
  ('no-sub_no-matrix', 5): 2.6826901255924644,
  ('sub_matrix', 1): 17.565795106862403,
  ('sub_matrix', 2): 46.9371803702329,
  ('sub_matrix', 3): 23.887805807796486,
  ('sub_matrix', 4): 39.058599411828766,
  ('sub_matrix', 5): 32.234453544910295,
  ('sub_no-matrix', 1): 17.565795106862403,
  ('sub_no-matrix', 2): 46.9371803702329,
  ('sub_no-matrix', 3): 23.887805807796486,
  ('sub_no-matrix', 4): 39.058599411828766,
  ('sub_no-matrix', 5): 4.214674259243127},
 {('no-sub_matrix', 1): 13.910878628792588,
  ('no-sub_matrix', 2): 33.45626834359109,
  ('no-sub_matrix', 3): 16.127584513594687,
  ('no-sub_matrix', 4): 32.59623120264939,
  ('no-sub_matrix', 5): 29.87568851789407,
  ('no-sub_no-matrix', 1): 13.910878628792588,
  ('no-sub_no-matrix', 2): 33.45626834359109,
  ('no-sub_no-matrix', 3): 16.127584513594687,
  ('no-sub_no-matrix', 4): 32.59623120264939,
  ('no-sub_no-matrix', 5): 2.3891779982892625,
  ('sub_matrix', 1): 17.18981661053988,
  ('sub_matrix', 2): 36.38883326650068,
  ('sub_matrix', 3): 13.081088737716442,
  ('sub_matrix', 4): 33.419732612590224,
  ('sub_matrix', 5): 22.665485632721676,
  ('sub_no-matrix', 1): 17.18981661053988,
  ('sub_no-matrix', 2): 36.38883326650068,
  ('sub_no-matrix', 3): 13.081088737716442,
  ('sub_no-matrix', 4): 33.419732612590224,
  ('sub_no-matrix', 5): 6.155199912348024},
 {('no-sub_matrix', 1): 18.196771699177763,
  ('no-sub_matrix', 2): 35.624058750852136,
  ('no-sub_matrix', 3): 23.746554392851053,
  ('no-sub_matrix', 4): 29.44669921790574,
  ('no-sub_matrix', 5): 39.72412918901379,
  ('no-sub_no-matrix', 1): 18.196771699177763,
  ('no-sub_no-matrix', 2): 35.624058750852136,
  ('no-sub_no-matrix', 3): 23.746554392851053,
  ('no-sub_no-matrix', 4): 29.44669921790574,
  ('no-sub_no-matrix', 5): 2.870123353843486,
  ('sub_matrix', 1): 20.38619930823735,
  ('sub_matrix', 2): 36.29781144853154,
  ('sub_matrix', 3): 22.13637404741934,
  ('sub_matrix', 4): 29.68729899086184,
  ('sub_matrix', 5): 36.993790238103884,
  ('sub_no-matrix', 1): 20.38619930823735,
  ('sub_no-matrix', 2): 36.29781144853154,
  ('sub_no-matrix', 3): 22.13637404741934,
  ('sub_no-matrix', 4): 29.68729899086184,
  ('sub_no-matrix', 5): 7.650303570399713},
 {('no-sub_matrix', 1): 11.992867568477442,
  ('no-sub_matrix', 2): 26.44083030170154,
  ('no-sub_matrix', 3): 27.574921221726136,
  ('no-sub_matrix', 4): 28.94213565689118,
  ('no-sub_matrix', 5): 46.973469397495556,
  ('no-sub_no-matrix', 1): 11.992867568477442,
  ('no-sub_no-matrix', 2): 26.44083030170154,
  ('no-sub_no-matrix', 3): 27.574921221726136,
  ('no-sub_no-matrix', 4): 28.94213565689118,
  ('no-sub_no-matrix', 5): 3.354326576753004,
  ('sub_matrix', 1): 14.434047100994839,
  ('sub_matrix', 2): 26.76571524620116,
  ('sub_matrix', 3): 25.83488399989926,
  ('sub_matrix', 4): 30.263621195061678,
  ('sub_matrix', 5): 36.822532494114455,
  ('sub_no-matrix', 1): 14.434047100994839,
  ('sub_no-matrix', 2): 26.76571524620116,
  ('sub_no-matrix', 3): 25.83488399989926,
  ('sub_no-matrix', 4): 30.263621195061678,
  ('sub_no-matrix', 5): 6.748976893757906},
 {('no-sub_matrix', 1): 16.27614914680276,
  ('no-sub_matrix', 2): 41.35282905624703,
  ('no-sub_matrix', 3): 25.173115913245226,
  ('no-sub_matrix', 4): 52.876981987369014,
  ('no-sub_matrix', 5): 49.49767321075167,
  ('no-sub_no-matrix', 1): 16.27614914680276,
  ('no-sub_no-matrix', 2): 41.35282905624703,
  ('no-sub_no-matrix', 3): 25.173115913245226,
  ('no-sub_no-matrix', 4): 52.876981987369014,
  ('no-sub_no-matrix', 5): 1.5962803636236758,
  ('sub_matrix', 1): 18.735912436641787,
  ('sub_matrix', 2): 43.36213985849511,
  ('sub_matrix', 3): 24.582800598631913,
  ('sub_matrix', 4): 53.1616607417586,
  ('sub_matrix', 5): 41.2664433745972,
  ('sub_no-matrix', 1): 18.735912436641787,
  ('sub_no-matrix', 2): 43.36213985849511,
  ('sub_no-matrix', 3): 24.582800598631913,
  ('sub_no-matrix', 4): 53.16165799003619,
  ('sub_no-matrix', 5): 6.4917878462822305},
 {('no-sub_matrix', 1): 14.036280122634507,
  ('no-sub_matrix', 2): 53.72802368862095,
  ('no-sub_matrix', 3): 18.940766131564004,
  ('no-sub_matrix', 4): 40.74964840745327,
  ('no-sub_matrix', 5): 39.57008490907742,
  ('no-sub_no-matrix', 1): 14.036280122634507,
  ('no-sub_no-matrix', 2): 53.72802368862095,
  ('no-sub_no-matrix', 3): 18.940766131564004,
  ('no-sub_no-matrix', 4): 40.74964840745327,
  ('no-sub_no-matrix', 5): 2.1275557540222967,
  ('sub_matrix', 1): 19.641722357026286,
  ('sub_matrix', 2): 52.709120728751486,
  ('sub_matrix', 3): 17.976257844509426,
  ('sub_matrix', 4): 42.51851542500959,
  ('sub_matrix', 5): 28.25018664655579,
  ('sub_no-matrix', 1): 19.641722357026286,
  ('sub_no-matrix', 2): 52.709120728751486,
  ('sub_no-matrix', 3): 17.976257844509426,
  ('sub_no-matrix', 4): 42.51851267328718,
  ('sub_no-matrix', 5): 5.409622788119386},
 {('no-sub_matrix', 1): 16.961927903326398,
  ('no-sub_matrix', 2): 38.5455951142925,
  ('no-sub_matrix', 3): 25.122316709729276,
  ('no-sub_matrix', 4): 35.90131439006518,
  ('no-sub_matrix', 5): 41.65886977570029,
  ('no-sub_no-matrix', 1): 16.961927903326398,
  ('no-sub_no-matrix', 2): 38.5455951142925,
  ('no-sub_no-matrix', 3): 25.122316709729276,
  ('no-sub_no-matrix', 4): 35.90131439006518,
  ('no-sub_no-matrix', 5): 3.2679255886472447,
  ('sub_matrix', 1): 20.247934372024154,
  ('sub_matrix', 2): 40.408716019775625,
  ('sub_matrix', 3): 23.782735071043668,
  ('sub_matrix', 4): 37.00513584758997,
  ('sub_matrix', 5): 29.22700479607527,
  ('sub_no-matrix', 1): 20.247934372024154,
  ('sub_no-matrix', 2): 40.408716019775625,
  ('sub_no-matrix', 3): 23.782735071043668,
  ('sub_no-matrix', 4): 37.00513584758997,
  ('sub_no-matrix', 5): 4.780011845541033},
 {('no-sub_matrix', 1): 12.109815771064152,
  ('no-sub_matrix', 2): 38.32406752938649,
  ('no-sub_matrix', 3): 25.987801084044044,
  ('no-sub_matrix', 4): 40.40950903177875,
  ('no-sub_matrix', 5): 52.86522525335603,
  ('no-sub_no-matrix', 1): 12.109815771064152,
  ('no-sub_no-matrix', 2): 38.32406752938649,
  ('no-sub_no-matrix', 3): 25.987801084044044,
  ('no-sub_no-matrix', 4): 40.40950903177875,
  ('no-sub_no-matrix', 5): 3.61917194787979,
  ('sub_matrix', 1): 15.130341564722832,
  ('sub_matrix', 2): 37.89719334728088,
  ('sub_matrix', 3): 24.65681032273433,
  ('sub_matrix', 4): 40.731610867030774,
  ('sub_matrix', 5): 37.566910985257906,
  ('sub_no-matrix', 1): 15.130341564722832,
  ('sub_no-matrix', 2): 37.89719334728088,
  ('sub_no-matrix', 3): 24.65681032273433,
  ('sub_no-matrix', 4): 40.731610867030774,
  ('sub_no-matrix', 5): 9.39736249989602},
 {('no-sub_matrix', 1): 16.25058564557851,
  ('no-sub_matrix', 2): 37.20405682898803,
  ('no-sub_matrix', 3): 30.5107090995129,
  ('no-sub_matrix', 4): 44.537084655292894,
  ('no-sub_matrix', 5): 46.50046620075818,
  ('no-sub_no-matrix', 1): 16.25058564557851,
  ('no-sub_no-matrix', 2): 37.20405682898803,
  ('no-sub_no-matrix', 3): 30.5107090995129,
  ('no-sub_no-matrix', 4): 44.537084655292894,
  ('no-sub_no-matrix', 5): 1.8752506698658238,
  ('sub_matrix', 1): 18.440281483079957,
  ('sub_matrix', 2): 38.54769605435544,
  ('sub_matrix', 3): 30.510800250317864,
  ('sub_matrix', 4): 44.99740645329493,
  ('sub_matrix', 5): 39.55738177603457,
  ('sub_no-matrix', 1): 18.440281483079957,
  ('sub_no-matrix', 2): 38.54769605435544,
  ('sub_no-matrix', 3): 30.510800250317864,
  ('sub_no-matrix', 4): 44.99740645329493,
  ('sub_no-matrix', 5): 2.6233048602148386},
 {('no-sub_matrix', 1): 16.324447378609865,
  ('no-sub_matrix', 2): 30.87308462806543,
  ('no-sub_matrix', 3): 22.765564836381643,
  ('no-sub_matrix', 4): 38.337445027901204,
  ('no-sub_matrix', 5): 40.98815076599078,
  ('no-sub_no-matrix', 1): 16.324447378609865,
  ('no-sub_no-matrix', 2): 30.87308462806543,
  ('no-sub_no-matrix', 3): 22.765564836381643,
  ('no-sub_no-matrix', 4): 38.337445027901204,
  ('no-sub_no-matrix', 5): 1.4796406979126138,
  ('sub_matrix', 1): 17.9623592385626,
  ('sub_matrix', 2): 32.36568198294609,
  ('sub_matrix', 3): 22.438215466486483,
  ('sub_matrix', 4): 40.900713840387546,
  ('sub_matrix', 5): 33.396627340011634,
  ('sub_no-matrix', 1): 17.9623592385626,
  ('sub_no-matrix', 2): 32.36568198294609,
  ('sub_no-matrix', 3): 22.438215466486483,
  ('sub_no-matrix', 4): 40.900713840387546,
  ('sub_no-matrix', 5): 6.609518913895668},
 {('no-sub_matrix', 1): 14.033258731424148,
  ('no-sub_matrix', 2): 28.37206528002418,
  ('no-sub_matrix', 3): 27.043658386061033,
  ('no-sub_matrix', 4): 36.167049513436204,
  ('no-sub_matrix', 5): 52.280797076864395,
  ('no-sub_no-matrix', 1): 14.033258731424148,
  ('no-sub_no-matrix', 2): 28.37206528002418,
  ('no-sub_no-matrix', 3): 27.043658386061033,
  ('no-sub_no-matrix', 4): 36.167049513436204,
  ('no-sub_no-matrix', 5): 1.9358795417918389,
  ('sub_matrix', 1): 16.606623097498794,
  ('sub_matrix', 2): 29.98729916366884,
  ('sub_matrix', 3): 24.737985875967603,
  ('sub_matrix', 4): 34.93154214402433,
  ('sub_matrix', 5): 42.35241303296243,
  ('sub_no-matrix', 1): 16.606623097498794,
  ('sub_no-matrix', 2): 29.98729916366884,
  ('sub_no-matrix', 3): 24.737985875967603,
  ('sub_no-matrix', 4): 34.931551775052775,
  ('sub_no-matrix', 5): 7.151971456773863},
 {('no-sub_matrix', 1): 10.482293039084738,
  ('no-sub_matrix', 2): 52.67861788579445,
  ('no-sub_matrix', 3): 21.665543335527666,
  ('no-sub_matrix', 4): 23.53727708917033,
  ('no-sub_matrix', 5): 32.2645584918966,
  ('no-sub_no-matrix', 1): 10.482293039084738,
  ('no-sub_no-matrix', 2): 52.67861788579445,
  ('no-sub_no-matrix', 3): 21.665543335527666,
  ('no-sub_no-matrix', 4): 23.53727708917033,
  ('no-sub_no-matrix', 5): 2.5207572809328243,
  ('sub_matrix', 1): 11.523882918360123,
  ('sub_matrix', 2): 57.336257883871156,
  ('sub_matrix', 3): 21.647716645835132,
  ('sub_matrix', 4): 23.491483569694733,
  ('sub_matrix', 5): 24.264706351480406,
  ('sub_no-matrix', 1): 11.523882918360123,
  ('sub_no-matrix', 2): 57.336257883871156,
  ('sub_no-matrix', 3): 21.647716645835132,
  ('sub_no-matrix', 4): 23.491462243846026,
  ('sub_no-matrix', 5): 9.714244661694366},
 {('no-sub_matrix', 1): 11.992867568477442,
  ('no-sub_matrix', 2): 28.861638231250264,
  ('no-sub_matrix', 3): 24.222607873884137,
  ('no-sub_matrix', 4): 41.28280460012173,
  ('no-sub_matrix', 5): 56.6084264455065,
  ('no-sub_no-matrix', 1): 11.992867568477442,
  ('no-sub_no-matrix', 2): 28.861638231250264,
  ('no-sub_no-matrix', 3): 24.222607873884137,
  ('no-sub_no-matrix', 4): 41.28280460012173,
  ('no-sub_no-matrix', 5): 2.4980576348107437,
  ('sub_matrix', 1): 14.531057698832324,
  ('sub_matrix', 2): 31.280393934821902,
  ('sub_matrix', 3): 20.756528260470358,
  ('sub_matrix', 4): 42.15937712589425,
  ('sub_matrix', 5): 52.45767194621365,
  ('sub_no-matrix', 1): 14.531057698832324,
  ('sub_no-matrix', 2): 31.280393934821902,
  ('sub_no-matrix', 3): 20.756528260470358,
  ('sub_no-matrix', 4): 42.15937712589425,
  ('sub_no-matrix', 5): 4.819862633503057}]


def test_gpt_subordination_region_totals():
    """
    Check region-level surprisals against the original syntaxgym-core
    implementation, using the same underlying `gpt2` model.
    """
    reference = ...  # TODO

    # TODO work out references
    dataset = datasets.load_dataset("cpllab/syntaxgym", "subordination_src-src")
    metric = evaluate.load("./syntaxgym.py")
    result = metric.compute(suite=dataset["test"], model_id="gpt2")

    from pprint import pprint
    pprint(result["region_totals"][0])
    pprint(GPT2_SUBORDINATION_SRC_REFERENCE[0])

    keys = result["region_totals"][0].keys()
    assert set(keys) == set(GPT2_SUBORDINATION_SRC_REFERENCE[0].keys())

    result_ndarray = np.concatenate([np.array([region_totals[key] for key in keys])
                                     for region_totals in result["region_totals"]])
    reference_ndarray = np.concatenate([np.array([region_totals[key] for key in keys])
                                        for region_totals in GPT2_SUBORDINATION_SRC_REFERENCE])
    pprint(sorted(zip(keys, np.abs(result_ndarray - reference_ndarray)),
                  key=lambda x: -x[1]))
    np.testing.assert_allclose(result_ndarray, reference_ndarray, atol=1e-3)