File size: 1,501 Bytes
549360a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from datasets_.util import _get_dataset_config_names, _load_dataset
from langcodes import Language, standardize_tag

def _config_tags(slug, **config_kwargs):
    """Map standardized language tags to the raw config names of ``slug``.

    Every config name returned by ``_get_dataset_config_names(slug,
    **config_kwargs)`` is keyed by ``standardize_tag(name, macro=True)``.
    If two config names standardize to the same tag, the one listed later
    overwrites the earlier one.
    """
    return {
        standardize_tag(config, macro=True): config
        for config in _get_dataset_config_names(slug, **config_kwargs)
    }


# Per-benchmark lookup tables: standardized language tag -> dataset config name.
# NOTE(review): these are built at import time, i.e. importing this module
# calls _get_dataset_config_names once per dataset below.
slug_mgsm = "juletxara/mgsm"
tags_mgsm = _config_tags(slug_mgsm)

slug_afrimgsm = "masakhane/afrimgsm"
tags_afrimgsm = _config_tags(slug_afrimgsm)

slug_gsm8kx = "Eurolingua/gsm8kx"
tags_gsm8kx = _config_tags(slug_gsm8kx, trust_remote_code=True)

def parse_number(i):
    """Convert an answer value to an ``int``, or return ``None`` if impossible.

    ``int`` values are returned unchanged. For strings, every ``","`` and
    ``"."`` is removed before conversion, so both are treated as digit-group
    separators: ``"1,234"`` and ``"1.234"`` both give ``1234``. A decimal
    such as ``"2.5"`` therefore becomes ``25``.

    Args:
        i: An ``int`` or ``str``. Any other value (e.g. ``None`` or a float)
            is treated as unparseable.

    Returns:
        The parsed integer, or ``None`` when ``i`` cannot be converted.
    """
    if isinstance(i, int):
        return i
    if not isinstance(i, str):
        # Non-string input used to crash with AttributeError on .replace();
        # report it as "not a number" like any other unparseable value.
        return None
    try:
        return int(i.replace(",", "").replace(".", ""))
    except ValueError:
        return None

def load_mgsm(language_bcp_47, nr):
    """Fetch item ``nr`` from the test split of a math benchmark for a language.

    ``language_bcp_47`` is looked up, in this order, in the module-level
    ``tags_mgsm``, ``tags_afrimgsm`` and ``tags_gsm8kx`` tables. The first
    table that contains it decides which dataset and subset are loaded.

    Args:
        language_bcp_47: Language tag, matched exactly against the table keys.
        nr: Index of the row to take from the dataset's ``"test"`` split.

    Returns:
        A ``(dataset_slug, row)`` pair, or ``(None, None)`` if no table knows
        the language. GSM8K-X rows additionally get an ``"answer_number"``
        entry: the stripped text following the first ``"####"`` in the row's
        ``"answer"`` string.
    """
    if language_bcp_47 in tags_mgsm:
        mgsm_rows = _load_dataset(
            slug_mgsm, subset=tags_mgsm[language_bcp_47], split="test"
        )
        return slug_mgsm, mgsm_rows[nr]

    if language_bcp_47 in tags_afrimgsm:
        afri_rows = _load_dataset(
            slug_afrimgsm, subset=tags_afrimgsm[language_bcp_47], split="test"
        )
        return slug_afrimgsm, afri_rows[nr]

    if language_bcp_47 not in tags_gsm8kx:
        return None, None

    gsm8kx_rows = _load_dataset(
        slug_gsm8kx,
        subset=tags_gsm8kx[language_bcp_47],
        split="test",
        trust_remote_code=True,
    )
    row = gsm8kx_rows[nr]
    # Keep the text right after the first "####" marker (presumably the
    # GSM8K "solution #### final answer" layout); an answer without the
    # marker raises IndexError here.
    answer_segments = row["answer"].split("####")
    row["answer_number"] = answer_segments[1].strip()
    return slug_gsm8kx, row