File size: 2,441 Bytes
550665c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import dataclasses
import re
import copy
import yaml
from pathlib import Path
from dataclasses import dataclass, field
from typing import Any, Iterable, List, NewType, Optional, Tuple, Union, Dict

from transformers.hf_argparser import DataClass, HfArgumentParser as OriginalHfArgumentParser

# Opaque placeholder types used purely in annotations.
# NOTE: this assignment rebinds ``DataClass``, replacing the name of the same
# spelling imported from ``transformers.hf_argparser`` above.
DataClass = NewType("DataClass", Any)
# NOTE(review): not referenced anywhere in the visible code — confirm whether
# other modules import it before removing.
DataClassType = NewType("DataClassType", Any)

def lambda_field(default, **kwargs):
    """Build a dataclass field whose default is a fresh copy of ``default``.

    The value is supplied through ``default_factory`` and copied on every
    call, so mutable defaults (lists, dicts, ...) can be declared on a
    dataclass without every instance sharing the same object.

    Args:
        default: Value to copy for each new instance. Only a shallow copy
            (``copy.copy``) is made, so nested containers remain shared.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``dataclasses.field`` (for example ``metadata``, ``init`` or
            ``repr``).

    Returns:
        The field object produced by ``dataclasses.field``.
    """
    # Forward kwargs: they were previously accepted but silently discarded.
    return field(default_factory=lambda: copy.copy(default), **kwargs)

class HfArgumentParser(OriginalHfArgumentParser):
    """Argument parser that can populate its dataclasses from a sectioned YAML file."""

    # YAML 1.1 float grammar, relaxed so that exponent forms without a decimal
    # point (e.g. ``1e-5``) resolve to floats instead of plain strings.
    _YAML_FLOAT_PATTERN = re.compile(
        r"""^(?:
             [-+]?(?:[0-9][0-9_]*)\.[0-9_]*(?:[eE][-+]?[0-9]+)?
            |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
            |\.[0-9_]+(?:[eE][-+][0-9]+)?
            |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\.[0-9_]*
            |[-+]?\.(?:inf|Inf|INF)
            |\.(?:nan|NaN|NAN))$""",
        re.X,
    )

    @classmethod
    def _make_yaml_loader(cls):
        """Return a ``yaml.SafeLoader`` subclass that also accepts exponent floats.

        The resolver is registered on a private subclass rather than on
        ``yaml.SafeLoader`` itself: mutating the shared loader class would
        change YAML parsing for the whole process and would append one more
        duplicate resolver on every call.
        """

        class _ExponentFloatLoader(yaml.SafeLoader):
            pass

        _ExponentFloatLoader.add_implicit_resolver(
            "tag:yaml.org,2002:float",
            cls._YAML_FLOAT_PATTERN,
            list("-+0123456789."),
        )
        return _ExponentFloatLoader

    def parse_yaml_file(self, yaml_file: str) -> Tuple[DataClass, ...]:
        """Parse a YAML file and return one instance per registered dataclass.

        The file is expected to hold a top-level mapping with one section per
        dataclass. Keys inside a section that do not correspond to an
        ``init`` field of the dataclass are ignored.

        Args:
            yaml_file (str): Path to the YAML file.

        Returns:
            Tuple[DataClass, ...]: Dataclass instances, in the same order as
            ``self.dataclass_types``.

        Raises:
            KeyError: If the file has no section for one of the dataclasses.
        """
        text = Path(yaml_file).read_text(encoding="utf-8")
        # An empty document loads as None; treat it as having no sections.
        data = yaml.load(text, Loader=self._make_yaml_loader()) or {}

        outputs = []
        for dtype in self.dataclass_types:
            # Only fields accepted by the generated __init__ may be passed in.
            keys = {f.name for f in dataclasses.fields(dtype) if f.init}
            # The section name is the class sitting just before ``object`` in
            # the MRO, so a subclass reads the section named after that base.
            arg_name = dtype.__mro__[-2].__name__
            if arg_name not in data:
                raise KeyError(
                    f"Section '{arg_name}' not found in YAML file '{yaml_file}'"
                )
            # A section written with no body loads as None; use defaults then.
            section = data[arg_name] or {}
            inputs = {k: v for k, v in section.items() if k in keys}
            outputs.append(dtype(**inputs))

        return tuple(outputs)