|
from datasets_.util import _get_dataset_config_names, _load_dataset |
|
from langcodes import Language, standardize_tag |
|
|
|
def _standardized_tag_map(slug, **kwargs):
    """Build a ``{standardized tag: config name}`` dict for the dataset `slug`.

    Each config name reported by ``_get_dataset_config_names`` is passed
    through ``standardize_tag(..., macro=True)`` and used as the key; the
    untouched config name is the value. Extra keyword arguments are forwarded
    to ``_get_dataset_config_names`` unchanged.
    """
    config_names = _get_dataset_config_names(slug, **kwargs)
    return {standardize_tag(name, macro=True): name for name in config_names}


# Dataset identifiers and, for each, a lookup from standardized language tag
# to the dataset's own config name. The lookups are built at import time.
slug_mgsm = "juletxara/mgsm"
tags_mgsm = _standardized_tag_map(slug_mgsm)

slug_afrimgsm = "masakhane/afrimgsm"
tags_afrimgsm = _standardized_tag_map(slug_afrimgsm)

# This dataset is queried with trust_remote_code=True.
slug_gsm8kx = "Eurolingua/gsm8kx"
tags_gsm8kx = _standardized_tag_map(slug_gsm8kx, trust_remote_code=True)
|
|
|
def parse_number(i):
    """Convert `i` to an ``int``, or return ``None`` if that is not possible.

    Args:
        i: The value to parse. Ints are returned unchanged. Strings have
            every ``","`` and ``"."`` removed before being handed to
            ``int()`` (so ``"1,234"`` and ``"1.234"`` both give ``1234``).
            Floats are accepted only when they hold an integral value.

    Returns:
        The parsed integer, or ``None`` when `i` is a string ``int()``
        rejects, a non-integral float, or any other unsupported type
        (including ``None``).
    """
    if isinstance(i, int):
        return i
    if isinstance(i, float):
        # nan/inf report is_integer() == False, so they fall through to None.
        return int(i) if i.is_integer() else None
    if not isinstance(i, str):
        # Previously this path raised AttributeError on ``.replace``; treat it
        # like any other unparsable input instead.
        return None

    digits = i.replace(",", "").replace(".", "")
    try:
        return int(digits)
    except ValueError:
        return None
|
|
|
def load_mgsm(language_bcp_47, nr):
    """Fetch item `nr` of the "test" split for a language from the first dataset that has it.

    The tag lookups are consulted in a fixed order: ``tags_mgsm``, then
    ``tags_afrimgsm``, then ``tags_gsm8kx``. The first one containing
    `language_bcp_47` decides which dataset is loaded.

    Args:
        language_bcp_47: Language tag used as the key into the tag lookups.
        nr: Index applied to the loaded dataset.

    Returns:
        A ``(slug, row)`` pair, where ``slug`` is the identifier of the
        dataset that was used. For the gsm8kx dataset the row additionally
        gets an ``"answer_number"`` entry: the stripped text of the segment
        right after the first ``"####"`` marker in ``row["answer"]`` (kept as
        a string). ``(None, None)`` is returned when no lookup contains the
        language.
    """
    if language_bcp_47 in tags_mgsm:
        config = tags_mgsm[language_bcp_47]
        dataset = _load_dataset(slug_mgsm, subset=config, split="test")
        return slug_mgsm, dataset[nr]

    if language_bcp_47 in tags_afrimgsm:
        config = tags_afrimgsm[language_bcp_47]
        dataset = _load_dataset(slug_afrimgsm, subset=config, split="test")
        return slug_afrimgsm, dataset[nr]

    if language_bcp_47 in tags_gsm8kx:
        config = tags_gsm8kx[language_bcp_47]
        dataset = _load_dataset(
            slug_gsm8kx, subset=config, split="test", trust_remote_code=True
        )
        example = dataset[nr]
        # The answer text carries its final result after a "####" marker.
        answer_parts = example["answer"].split("####")
        example["answer_number"] = answer_parts[1].strip()
        return slug_gsm8kx, example

    return None, None
|
|