Spaces:
Sleeping
Sleeping
from typing import List | |
import datasets | |
import evaluate | |
import numpy as np | |
import pytest | |
def syntaxgym_dataset():
    """
    Load the ``subordination_src-src`` configuration of the ``syntaxgym``
    dataset and return whatever ``datasets.load_dataset`` yields for it.
    """
    # NOTE(review): the region-totals test in this file loads
    # "cpllab/syntaxgym" rather than "syntaxgym" -- confirm which dataset id
    # is intended here.
    # NOTE(review): shaped like a pytest fixture but not registered as one --
    # confirm whether an `@pytest.fixture` decorator is missing.
    dataset_path = "syntaxgym"
    config_name = "subordination_src-src"
    return datasets.load_dataset(dataset_path, config_name)
@pytest.fixture
def syntaxgym_metric():
    """
    Provide the SyntaxGym metric under test.

    Registered as a pytest fixture so that test functions can obtain the
    metric simply by declaring a ``syntaxgym_metric`` parameter; without the
    registration pytest cannot resolve that parameter.

    Returns:
        The object produced by ``evaluate.load("./syntaxgym.py")``.
    """
    # NOTE(review): "./syntaxgym.py" is a relative path, so this presumably
    # requires pytest to be started from the directory holding that script --
    # confirm.
    # TODO work out reference
    return evaluate.load("./syntaxgym.py")
def model_ref():
    """
    Return the model identifier string ``"gpt2"``.

    The region-totals test in this file currently hard-codes the same id
    instead of calling this function.
    """
    # Alternative kept from the original for reference:
    #   "hf-internal-testing/tiny-random-gpt_neo"
    model_id = "gpt2"
    return model_id
# Reference region surprisals computed with syntaxgym-core.
# The list holds one dict per item; each dict is keyed by a
# (condition name, region number) tuple.
# See the notebook at https://colab.research.google.com/drive/1qziyPcu65jffizSPi-ZGHKR0x7BaHFMS#scrollTo=RgtnScy6LLKi
GPT2_SUBORDINATION_SRC_REFERENCE = \ | |
[{('no-sub_matrix', 1): 13.151199615123803, | |
('no-sub_matrix', 2): 38.503222716703526, | |
('no-sub_matrix', 3): 27.623861034812286, | |
('no-sub_matrix', 4): 48.831672846038224, | |
('no-sub_matrix', 5): 38.08533699286694, | |
('no-sub_no-matrix', 1): 13.151199615123803, | |
('no-sub_no-matrix', 2): 38.503222716703526, | |
('no-sub_no-matrix', 3): 27.623861034812286, | |
('no-sub_no-matrix', 4): 48.831687980511504, | |
('no-sub_no-matrix', 5): 1.8096143510772873, | |
('sub_matrix', 1): 14.905592916748805, | |
('sub_matrix', 2): 39.06304309956175, | |
('sub_matrix', 3): 26.862648365854433, | |
('sub_matrix', 4): 50.56554401687938, | |
('sub_matrix', 5): 26.532245572980194, | |
('sub_no-matrix', 1): 14.905592916748805, | |
('sub_no-matrix', 2): 39.06304309956175, | |
('sub_no-matrix', 3): 26.862648365854433, | |
('sub_no-matrix', 4): 50.56553438585093, | |
('sub_no-matrix', 5): 7.470089829866611}, | |
{('no-sub_matrix', 1): 10.116093820255577, | |
('no-sub_matrix', 2): 20.96513246705127, | |
('no-sub_matrix', 3): 20.02959138986416, | |
('no-sub_matrix', 4): 23.779661397107446, | |
('no-sub_matrix', 5): 33.2560281692696, | |
('no-sub_no-matrix', 1): 10.116093820255577, | |
('no-sub_no-matrix', 2): 20.96513246705127, | |
('no-sub_no-matrix', 3): 20.02959138986416, | |
('no-sub_no-matrix', 4): 23.779661397107446, | |
('no-sub_no-matrix', 5): 1.9449125865631063, | |
('sub_matrix', 1): 13.545157521732826, | |
('sub_matrix', 2): 24.96048395897244, | |
('sub_matrix', 3): 18.609464944317324, | |
('sub_matrix', 4): 23.057566440062317, | |
('sub_matrix', 5): 26.424454285669032, | |
('sub_no-matrix', 1): 13.545157521732826, | |
('sub_no-matrix', 2): 24.96048395897244, | |
('sub_no-matrix', 3): 18.609464944317324, | |
('sub_no-matrix', 4): 23.057566440062317, | |
('sub_no-matrix', 5): 2.807467838359704}, | |
{('no-sub_matrix', 1): 11.992867568477442, | |
('no-sub_matrix', 2): 45.813114232935774, | |
('no-sub_matrix', 3): 24.57554828372551, | |
('no-sub_matrix', 4): 45.334025774062916, | |
('no-sub_matrix', 5): 26.208189541862073, | |
('no-sub_no-matrix', 1): 11.992867568477442, | |
('no-sub_no-matrix', 2): 45.813114232935774, | |
('no-sub_no-matrix', 3): 24.57554828372551, | |
('no-sub_no-matrix', 4): 45.33402766587207, | |
('no-sub_no-matrix', 5): 1.8284485151385752, | |
('sub_matrix', 1): 14.219887768799735, | |
('sub_matrix', 2): 46.25055434117979, | |
('sub_matrix', 3): 23.054221678472672, | |
('sub_matrix', 4): 47.08503858470256, | |
('sub_matrix', 5): 22.154772321452022, | |
('sub_no-matrix', 1): 14.219887768799735, | |
('sub_no-matrix', 2): 46.25055434117979, | |
('sub_no-matrix', 3): 23.054221678472672, | |
('sub_no-matrix', 4): 47.08503858470256, | |
('sub_no-matrix', 5): 3.0655133594366757}, | |
{('no-sub_matrix', 1): 10.55002943802296, | |
('no-sub_matrix', 2): 52.419810137608856, | |
('no-sub_matrix', 3): 23.30710475332303, | |
('no-sub_matrix', 4): 37.957905964008944, | |
('no-sub_matrix', 5): 29.259648135104936, | |
('no-sub_no-matrix', 1): 10.55002943802296, | |
('no-sub_no-matrix', 2): 52.419810137608856, | |
('no-sub_no-matrix', 3): 23.30710475332303, | |
('no-sub_no-matrix', 4): 37.957905964008944, | |
('no-sub_no-matrix', 5): 1.9632913405649093, | |
('sub_matrix', 1): 15.289384584900025, | |
('sub_matrix', 2): 53.93652737134243, | |
('sub_matrix', 3): 19.43915835312633, | |
('sub_matrix', 4): 36.459591551099386, | |
('sub_matrix', 5): 22.185742699245417, | |
('sub_no-matrix', 1): 15.289384584900025, | |
('sub_no-matrix', 2): 53.93652737134243, | |
('sub_no-matrix', 3): 19.43915835312633, | |
('sub_no-matrix', 4): 36.4595598203003, | |
('sub_no-matrix', 5): 5.707732355645454}, | |
{('no-sub_matrix', 1): 23.543723213902986, | |
('no-sub_matrix', 2): 31.967972102825854, | |
('no-sub_matrix', 3): 29.159572978411727, | |
('no-sub_matrix', 4): 36.61365345925747, | |
('no-sub_matrix', 5): 44.576591305970545, | |
('no-sub_no-matrix', 1): 23.543723213902986, | |
('no-sub_no-matrix', 2): 31.967972102825854, | |
('no-sub_no-matrix', 3): 29.159572978411727, | |
('no-sub_no-matrix', 4): 36.61365345925747, | |
('no-sub_no-matrix', 5): 3.2813457388593714, | |
('sub_matrix', 1): 27.118410129310597, | |
('sub_matrix', 2): 33.909617362987866, | |
('sub_matrix', 3): 28.791166362258743, | |
('sub_matrix', 4): 37.24960609010374, | |
('sub_matrix', 5): 31.660933798006262, | |
('sub_no-matrix', 1): 27.118410129310597, | |
('sub_no-matrix', 2): 33.909617362987866, | |
('sub_no-matrix', 3): 28.791166362258743, | |
('sub_no-matrix', 4): 37.24960609010374, | |
('sub_no-matrix', 5): 7.3613541428239015}, | |
{('no-sub_matrix', 1): 14.22171869610082, | |
('no-sub_matrix', 2): 30.270423022911977, | |
('no-sub_matrix', 3): 25.973276891204705, | |
('no-sub_matrix', 4): 28.43856735947716, | |
('no-sub_matrix', 5): 57.39887418731055, | |
('no-sub_no-matrix', 1): 14.22171869610082, | |
('no-sub_no-matrix', 2): 30.270423022911977, | |
('no-sub_no-matrix', 3): 25.973276891204705, | |
('no-sub_no-matrix', 4): 28.43856735947716, | |
('no-sub_no-matrix', 5): 1.7127059109344136, | |
('sub_matrix', 1): 16.39289784951447, | |
('sub_matrix', 2): 31.5671111565765, | |
('sub_matrix', 3): 24.54307828171008, | |
('sub_matrix', 4): 29.249645624130757, | |
('sub_matrix', 5): 53.59155769093577, | |
('sub_no-matrix', 1): 16.39289784951447, | |
('sub_no-matrix', 2): 31.5671111565765, | |
('sub_no-matrix', 3): 24.54307828171008, | |
('sub_no-matrix', 4): 29.249645624130757, | |
('sub_no-matrix', 5): 7.225276653947023}, | |
{('no-sub_matrix', 1): 13.729688714733188, | |
('no-sub_matrix', 2): 36.018118127225165, | |
('no-sub_matrix', 3): 28.232055923783275, | |
('no-sub_matrix', 4): 44.44634394296659, | |
('no-sub_matrix', 5): 38.277975147059344, | |
('no-sub_no-matrix', 1): 13.729688714733188, | |
('no-sub_no-matrix', 2): 36.018118127225165, | |
('no-sub_no-matrix', 3): 28.232055923783275, | |
('no-sub_no-matrix', 4): 44.44634394296659, | |
('no-sub_no-matrix', 5): 3.0318996942908414, | |
('sub_matrix', 1): 16.93528744674245, | |
('sub_matrix', 2): 36.545024814326574, | |
('sub_matrix', 3): 26.279603445823692, | |
('sub_matrix', 4): 46.501226364074995, | |
('sub_matrix', 5): 32.155418057793035, | |
('sub_no-matrix', 1): 16.93528744674245, | |
('sub_no-matrix', 2): 36.545024814326574, | |
('sub_no-matrix', 3): 26.279603445823692, | |
('sub_no-matrix', 4): 46.501226364074995, | |
('sub_no-matrix', 5): 4.4581122618864155}, | |
{('no-sub_matrix', 1): 15.598113737151568, | |
('no-sub_matrix', 2): 56.12543415244172, | |
('no-sub_matrix', 3): 29.755667770007285, | |
('no-sub_matrix', 4): 51.689282097269995, | |
('no-sub_matrix', 5): 45.575230324010775, | |
('no-sub_no-matrix', 1): 15.598113737151568, | |
('no-sub_no-matrix', 2): 56.12543415244172, | |
('no-sub_no-matrix', 3): 29.755667770007285, | |
('no-sub_no-matrix', 4): 51.68928424705313, | |
('no-sub_no-matrix', 5): 1.235207173694806, | |
('sub_matrix', 1): 18.909088991066888, | |
('sub_matrix', 2): 57.753410746636746, | |
('sub_matrix', 3): 28.677667873674363, | |
('sub_matrix', 4): 51.99410775929489, | |
('sub_matrix', 5): 35.754144966112236, | |
('sub_no-matrix', 1): 18.909088991066888, | |
('sub_no-matrix', 2): 57.753410746636746, | |
('sub_no-matrix', 3): 28.677667873674363, | |
('sub_no-matrix', 4): 51.9941480032352, | |
('sub_no-matrix', 5): 5.033266273930268}, | |
{('no-sub_matrix', 1): 14.859413855165633, | |
('no-sub_matrix', 2): 34.54519231993284, | |
('no-sub_matrix', 3): 24.26528519671309, | |
('no-sub_matrix', 4): 35.42343514121054, | |
('no-sub_matrix', 5): 55.85308623165151, | |
('no-sub_no-matrix', 1): 14.859413855165633, | |
('no-sub_no-matrix', 2): 34.54519231993284, | |
('no-sub_no-matrix', 3): 24.26528519671309, | |
('no-sub_no-matrix', 4): 35.42343514121054, | |
('no-sub_no-matrix', 5): 2.3309861205259734, | |
('sub_matrix', 1): 17.053809634549854, | |
('sub_matrix', 2): 33.66637542056656, | |
('sub_matrix', 3): 23.26181234829638, | |
('sub_matrix', 4): 35.61438567264568, | |
('sub_matrix', 5): 48.48551986050014, | |
('sub_no-matrix', 1): 17.053809634549854, | |
('sub_no-matrix', 2): 33.66637542056656, | |
('sub_no-matrix', 3): 23.26181234829638, | |
('sub_no-matrix', 4): 35.61438704850689, | |
('sub_no-matrix', 5): 2.969309360231736}, | |
{('no-sub_matrix', 1): 13.708973748402064, | |
('no-sub_matrix', 2): 31.147590264691182, | |
('no-sub_matrix', 3): 30.495597241955565, | |
('no-sub_matrix', 4): 34.65164493728535, | |
('no-sub_matrix', 5): 35.87510990950117, | |
('no-sub_no-matrix', 1): 13.708973748402064, | |
('no-sub_no-matrix', 2): 31.147590264691182, | |
('no-sub_no-matrix', 3): 30.495597241955565, | |
('no-sub_no-matrix', 4): 34.65164493728535, | |
('no-sub_no-matrix', 5): 3.232032121481573, | |
('sub_matrix', 1): 17.681722076468287, | |
('sub_matrix', 2): 33.77225997922327, | |
('sub_matrix', 3): 29.435808932487806, | |
('sub_matrix', 4): 34.354368969668016, | |
('sub_matrix', 5): 20.802733205442486, | |
('sub_no-matrix', 1): 17.681722076468287, | |
('sub_no-matrix', 2): 33.77225997922327, | |
('sub_no-matrix', 3): 29.435808932487806, | |
('sub_no-matrix', 4): 34.354368969668016, | |
('sub_no-matrix', 5): 3.7902066303710424}, | |
{('no-sub_matrix', 1): 15.72185319065555, | |
('no-sub_matrix', 2): 45.25539814380218, | |
('no-sub_matrix', 3): 24.94273362957689, | |
('no-sub_matrix', 4): 40.81704901026569, | |
('no-sub_matrix', 5): 42.898794519499596, | |
('no-sub_no-matrix', 1): 15.72185319065555, | |
('no-sub_no-matrix', 2): 45.25539814380218, | |
('no-sub_no-matrix', 3): 24.94273362957689, | |
('no-sub_no-matrix', 4): 40.81704901026569, | |
('no-sub_no-matrix', 5): 2.6826901255924644, | |
('sub_matrix', 1): 17.565795106862403, | |
('sub_matrix', 2): 46.9371803702329, | |
('sub_matrix', 3): 23.887805807796486, | |
('sub_matrix', 4): 39.058599411828766, | |
('sub_matrix', 5): 32.234453544910295, | |
('sub_no-matrix', 1): 17.565795106862403, | |
('sub_no-matrix', 2): 46.9371803702329, | |
('sub_no-matrix', 3): 23.887805807796486, | |
('sub_no-matrix', 4): 39.058599411828766, | |
('sub_no-matrix', 5): 4.214674259243127}, | |
{('no-sub_matrix', 1): 13.910878628792588, | |
('no-sub_matrix', 2): 33.45626834359109, | |
('no-sub_matrix', 3): 16.127584513594687, | |
('no-sub_matrix', 4): 32.59623120264939, | |
('no-sub_matrix', 5): 29.87568851789407, | |
('no-sub_no-matrix', 1): 13.910878628792588, | |
('no-sub_no-matrix', 2): 33.45626834359109, | |
('no-sub_no-matrix', 3): 16.127584513594687, | |
('no-sub_no-matrix', 4): 32.59623120264939, | |
('no-sub_no-matrix', 5): 2.3891779982892625, | |
('sub_matrix', 1): 17.18981661053988, | |
('sub_matrix', 2): 36.38883326650068, | |
('sub_matrix', 3): 13.081088737716442, | |
('sub_matrix', 4): 33.419732612590224, | |
('sub_matrix', 5): 22.665485632721676, | |
('sub_no-matrix', 1): 17.18981661053988, | |
('sub_no-matrix', 2): 36.38883326650068, | |
('sub_no-matrix', 3): 13.081088737716442, | |
('sub_no-matrix', 4): 33.419732612590224, | |
('sub_no-matrix', 5): 6.155199912348024}, | |
{('no-sub_matrix', 1): 18.196771699177763, | |
('no-sub_matrix', 2): 35.624058750852136, | |
('no-sub_matrix', 3): 23.746554392851053, | |
('no-sub_matrix', 4): 29.44669921790574, | |
('no-sub_matrix', 5): 39.72412918901379, | |
('no-sub_no-matrix', 1): 18.196771699177763, | |
('no-sub_no-matrix', 2): 35.624058750852136, | |
('no-sub_no-matrix', 3): 23.746554392851053, | |
('no-sub_no-matrix', 4): 29.44669921790574, | |
('no-sub_no-matrix', 5): 2.870123353843486, | |
('sub_matrix', 1): 20.38619930823735, | |
('sub_matrix', 2): 36.29781144853154, | |
('sub_matrix', 3): 22.13637404741934, | |
('sub_matrix', 4): 29.68729899086184, | |
('sub_matrix', 5): 36.993790238103884, | |
('sub_no-matrix', 1): 20.38619930823735, | |
('sub_no-matrix', 2): 36.29781144853154, | |
('sub_no-matrix', 3): 22.13637404741934, | |
('sub_no-matrix', 4): 29.68729899086184, | |
('sub_no-matrix', 5): 7.650303570399713}, | |
{('no-sub_matrix', 1): 11.992867568477442, | |
('no-sub_matrix', 2): 26.44083030170154, | |
('no-sub_matrix', 3): 27.574921221726136, | |
('no-sub_matrix', 4): 28.94213565689118, | |
('no-sub_matrix', 5): 46.973469397495556, | |
('no-sub_no-matrix', 1): 11.992867568477442, | |
('no-sub_no-matrix', 2): 26.44083030170154, | |
('no-sub_no-matrix', 3): 27.574921221726136, | |
('no-sub_no-matrix', 4): 28.94213565689118, | |
('no-sub_no-matrix', 5): 3.354326576753004, | |
('sub_matrix', 1): 14.434047100994839, | |
('sub_matrix', 2): 26.76571524620116, | |
('sub_matrix', 3): 25.83488399989926, | |
('sub_matrix', 4): 30.263621195061678, | |
('sub_matrix', 5): 36.822532494114455, | |
('sub_no-matrix', 1): 14.434047100994839, | |
('sub_no-matrix', 2): 26.76571524620116, | |
('sub_no-matrix', 3): 25.83488399989926, | |
('sub_no-matrix', 4): 30.263621195061678, | |
('sub_no-matrix', 5): 6.748976893757906}, | |
{('no-sub_matrix', 1): 16.27614914680276, | |
('no-sub_matrix', 2): 41.35282905624703, | |
('no-sub_matrix', 3): 25.173115913245226, | |
('no-sub_matrix', 4): 52.876981987369014, | |
('no-sub_matrix', 5): 49.49767321075167, | |
('no-sub_no-matrix', 1): 16.27614914680276, | |
('no-sub_no-matrix', 2): 41.35282905624703, | |
('no-sub_no-matrix', 3): 25.173115913245226, | |
('no-sub_no-matrix', 4): 52.876981987369014, | |
('no-sub_no-matrix', 5): 1.5962803636236758, | |
('sub_matrix', 1): 18.735912436641787, | |
('sub_matrix', 2): 43.36213985849511, | |
('sub_matrix', 3): 24.582800598631913, | |
('sub_matrix', 4): 53.1616607417586, | |
('sub_matrix', 5): 41.2664433745972, | |
('sub_no-matrix', 1): 18.735912436641787, | |
('sub_no-matrix', 2): 43.36213985849511, | |
('sub_no-matrix', 3): 24.582800598631913, | |
('sub_no-matrix', 4): 53.16165799003619, | |
('sub_no-matrix', 5): 6.4917878462822305}, | |
{('no-sub_matrix', 1): 14.036280122634507, | |
('no-sub_matrix', 2): 53.72802368862095, | |
('no-sub_matrix', 3): 18.940766131564004, | |
('no-sub_matrix', 4): 40.74964840745327, | |
('no-sub_matrix', 5): 39.57008490907742, | |
('no-sub_no-matrix', 1): 14.036280122634507, | |
('no-sub_no-matrix', 2): 53.72802368862095, | |
('no-sub_no-matrix', 3): 18.940766131564004, | |
('no-sub_no-matrix', 4): 40.74964840745327, | |
('no-sub_no-matrix', 5): 2.1275557540222967, | |
('sub_matrix', 1): 19.641722357026286, | |
('sub_matrix', 2): 52.709120728751486, | |
('sub_matrix', 3): 17.976257844509426, | |
('sub_matrix', 4): 42.51851542500959, | |
('sub_matrix', 5): 28.25018664655579, | |
('sub_no-matrix', 1): 19.641722357026286, | |
('sub_no-matrix', 2): 52.709120728751486, | |
('sub_no-matrix', 3): 17.976257844509426, | |
('sub_no-matrix', 4): 42.51851267328718, | |
('sub_no-matrix', 5): 5.409622788119386}, | |
{('no-sub_matrix', 1): 16.961927903326398, | |
('no-sub_matrix', 2): 38.5455951142925, | |
('no-sub_matrix', 3): 25.122316709729276, | |
('no-sub_matrix', 4): 35.90131439006518, | |
('no-sub_matrix', 5): 41.65886977570029, | |
('no-sub_no-matrix', 1): 16.961927903326398, | |
('no-sub_no-matrix', 2): 38.5455951142925, | |
('no-sub_no-matrix', 3): 25.122316709729276, | |
('no-sub_no-matrix', 4): 35.90131439006518, | |
('no-sub_no-matrix', 5): 3.2679255886472447, | |
('sub_matrix', 1): 20.247934372024154, | |
('sub_matrix', 2): 40.408716019775625, | |
('sub_matrix', 3): 23.782735071043668, | |
('sub_matrix', 4): 37.00513584758997, | |
('sub_matrix', 5): 29.22700479607527, | |
('sub_no-matrix', 1): 20.247934372024154, | |
('sub_no-matrix', 2): 40.408716019775625, | |
('sub_no-matrix', 3): 23.782735071043668, | |
('sub_no-matrix', 4): 37.00513584758997, | |
('sub_no-matrix', 5): 4.780011845541033}, | |
{('no-sub_matrix', 1): 12.109815771064152, | |
('no-sub_matrix', 2): 38.32406752938649, | |
('no-sub_matrix', 3): 25.987801084044044, | |
('no-sub_matrix', 4): 40.40950903177875, | |
('no-sub_matrix', 5): 52.86522525335603, | |
('no-sub_no-matrix', 1): 12.109815771064152, | |
('no-sub_no-matrix', 2): 38.32406752938649, | |
('no-sub_no-matrix', 3): 25.987801084044044, | |
('no-sub_no-matrix', 4): 40.40950903177875, | |
('no-sub_no-matrix', 5): 3.61917194787979, | |
('sub_matrix', 1): 15.130341564722832, | |
('sub_matrix', 2): 37.89719334728088, | |
('sub_matrix', 3): 24.65681032273433, | |
('sub_matrix', 4): 40.731610867030774, | |
('sub_matrix', 5): 37.566910985257906, | |
('sub_no-matrix', 1): 15.130341564722832, | |
('sub_no-matrix', 2): 37.89719334728088, | |
('sub_no-matrix', 3): 24.65681032273433, | |
('sub_no-matrix', 4): 40.731610867030774, | |
('sub_no-matrix', 5): 9.39736249989602}, | |
{('no-sub_matrix', 1): 16.25058564557851, | |
('no-sub_matrix', 2): 37.20405682898803, | |
('no-sub_matrix', 3): 30.5107090995129, | |
('no-sub_matrix', 4): 44.537084655292894, | |
('no-sub_matrix', 5): 46.50046620075818, | |
('no-sub_no-matrix', 1): 16.25058564557851, | |
('no-sub_no-matrix', 2): 37.20405682898803, | |
('no-sub_no-matrix', 3): 30.5107090995129, | |
('no-sub_no-matrix', 4): 44.537084655292894, | |
('no-sub_no-matrix', 5): 1.8752506698658238, | |
('sub_matrix', 1): 18.440281483079957, | |
('sub_matrix', 2): 38.54769605435544, | |
('sub_matrix', 3): 30.510800250317864, | |
('sub_matrix', 4): 44.99740645329493, | |
('sub_matrix', 5): 39.55738177603457, | |
('sub_no-matrix', 1): 18.440281483079957, | |
('sub_no-matrix', 2): 38.54769605435544, | |
('sub_no-matrix', 3): 30.510800250317864, | |
('sub_no-matrix', 4): 44.99740645329493, | |
('sub_no-matrix', 5): 2.6233048602148386}, | |
{('no-sub_matrix', 1): 16.324447378609865, | |
('no-sub_matrix', 2): 30.87308462806543, | |
('no-sub_matrix', 3): 22.765564836381643, | |
('no-sub_matrix', 4): 38.337445027901204, | |
('no-sub_matrix', 5): 40.98815076599078, | |
('no-sub_no-matrix', 1): 16.324447378609865, | |
('no-sub_no-matrix', 2): 30.87308462806543, | |
('no-sub_no-matrix', 3): 22.765564836381643, | |
('no-sub_no-matrix', 4): 38.337445027901204, | |
('no-sub_no-matrix', 5): 1.4796406979126138, | |
('sub_matrix', 1): 17.9623592385626, | |
('sub_matrix', 2): 32.36568198294609, | |
('sub_matrix', 3): 22.438215466486483, | |
('sub_matrix', 4): 40.900713840387546, | |
('sub_matrix', 5): 33.396627340011634, | |
('sub_no-matrix', 1): 17.9623592385626, | |
('sub_no-matrix', 2): 32.36568198294609, | |
('sub_no-matrix', 3): 22.438215466486483, | |
('sub_no-matrix', 4): 40.900713840387546, | |
('sub_no-matrix', 5): 6.609518913895668}, | |
{('no-sub_matrix', 1): 14.033258731424148, | |
('no-sub_matrix', 2): 28.37206528002418, | |
('no-sub_matrix', 3): 27.043658386061033, | |
('no-sub_matrix', 4): 36.167049513436204, | |
('no-sub_matrix', 5): 52.280797076864395, | |
('no-sub_no-matrix', 1): 14.033258731424148, | |
('no-sub_no-matrix', 2): 28.37206528002418, | |
('no-sub_no-matrix', 3): 27.043658386061033, | |
('no-sub_no-matrix', 4): 36.167049513436204, | |
('no-sub_no-matrix', 5): 1.9358795417918389, | |
('sub_matrix', 1): 16.606623097498794, | |
('sub_matrix', 2): 29.98729916366884, | |
('sub_matrix', 3): 24.737985875967603, | |
('sub_matrix', 4): 34.93154214402433, | |
('sub_matrix', 5): 42.35241303296243, | |
('sub_no-matrix', 1): 16.606623097498794, | |
('sub_no-matrix', 2): 29.98729916366884, | |
('sub_no-matrix', 3): 24.737985875967603, | |
('sub_no-matrix', 4): 34.931551775052775, | |
('sub_no-matrix', 5): 7.151971456773863}, | |
{('no-sub_matrix', 1): 10.482293039084738, | |
('no-sub_matrix', 2): 52.67861788579445, | |
('no-sub_matrix', 3): 21.665543335527666, | |
('no-sub_matrix', 4): 23.53727708917033, | |
('no-sub_matrix', 5): 32.2645584918966, | |
('no-sub_no-matrix', 1): 10.482293039084738, | |
('no-sub_no-matrix', 2): 52.67861788579445, | |
('no-sub_no-matrix', 3): 21.665543335527666, | |
('no-sub_no-matrix', 4): 23.53727708917033, | |
('no-sub_no-matrix', 5): 2.5207572809328243, | |
('sub_matrix', 1): 11.523882918360123, | |
('sub_matrix', 2): 57.336257883871156, | |
('sub_matrix', 3): 21.647716645835132, | |
('sub_matrix', 4): 23.491483569694733, | |
('sub_matrix', 5): 24.264706351480406, | |
('sub_no-matrix', 1): 11.523882918360123, | |
('sub_no-matrix', 2): 57.336257883871156, | |
('sub_no-matrix', 3): 21.647716645835132, | |
('sub_no-matrix', 4): 23.491462243846026, | |
('sub_no-matrix', 5): 9.714244661694366}, | |
{('no-sub_matrix', 1): 11.992867568477442, | |
('no-sub_matrix', 2): 28.861638231250264, | |
('no-sub_matrix', 3): 24.222607873884137, | |
('no-sub_matrix', 4): 41.28280460012173, | |
('no-sub_matrix', 5): 56.6084264455065, | |
('no-sub_no-matrix', 1): 11.992867568477442, | |
('no-sub_no-matrix', 2): 28.861638231250264, | |
('no-sub_no-matrix', 3): 24.222607873884137, | |
('no-sub_no-matrix', 4): 41.28280460012173, | |
('no-sub_no-matrix', 5): 2.4980576348107437, | |
('sub_matrix', 1): 14.531057698832324, | |
('sub_matrix', 2): 31.280393934821902, | |
('sub_matrix', 3): 20.756528260470358, | |
('sub_matrix', 4): 42.15937712589425, | |
('sub_matrix', 5): 52.45767194621365, | |
('sub_no-matrix', 1): 14.531057698832324, | |
('sub_no-matrix', 2): 31.280393934821902, | |
('sub_no-matrix', 3): 20.756528260470358, | |
('sub_no-matrix', 4): 42.15937712589425, | |
('sub_no-matrix', 5): 4.819862633503057}] | |
def test_gpt_subordination_region_totals(syntaxgym_metric):
    """
    Check region-level surprisals against the original syntaxgym-core
    implementation, using the same underlying `gpt2` model.

    The metric is run on the ``test`` split of the ``subordination_src-src``
    suite and every region total is compared with the matching entry of
    ``GPT2_SUBORDINATION_SRC_REFERENCE`` (``atol=1e-3``).

    Args:
        syntaxgym_metric: object exposing ``compute(dataset=..., model_id=...)``;
            pytest resolves it from the parameter name.
    """
    from pprint import pprint

    suite_name = "subordination_src-src"
    dataset = datasets.load_dataset("cpllab/syntaxgym", suite_name)
    result = syntaxgym_metric.compute(dataset=dataset["test"], model_id="gpt2")
    region_totals = result[suite_name].region_totals
    reference = GPT2_SUBORDINATION_SRC_REFERENCE

    # Debug output to ease diagnosing a mismatch.
    pprint(region_totals[0])
    pprint(reference[0])

    # Fail with a readable message rather than a NumPy broadcasting error
    # when the metric returns a different number of entries than the table.
    assert len(region_totals) == len(reference), (
        f"expected {len(reference)} entries, got {len(region_totals)}")

    # Freeze one key order so both flattened arrays line up element-wise.
    keys = list(region_totals[0])
    assert set(keys) == set(reference[0].keys())

    # Flatten entry-major, key-minor: position i * len(keys) + j holds
    # entry i, key j in both arrays.
    result_ndarray = np.array(
        [entry[key] for entry in region_totals for key in keys])
    reference_ndarray = np.array(
        [entry[key] for entry in reference for key in keys])

    # Label every flattened position with (entry index, key) so that the
    # sorted report covers all entries, not only the first one.
    labels = [(entry_idx, key)
              for entry_idx in range(len(reference))
              for key in keys]
    abs_diffs = np.abs(result_ndarray - reference_ndarray)
    pprint(sorted(zip(labels, abs_diffs), key=lambda pair: -pair[1]))

    np.testing.assert_allclose(result_ndarray, reference_ndarray, atol=1e-3)