File size: 3,193 Bytes
2869f1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import regex as re

# Sentinel tokens drawn from the ByT5 extra_id range, used to mark
# sections of model input/output (program vs. utterances vs. ground-truth
# program — presumably; confirm against the callers that embed them).
PROGRAM_SPECIAL_TOKEN="<extra_id_124>"
UTTERANCES_SPECIAL_TOKEN="<extra_id_123>"
GT_PROGRAM_SPECIAL_TOKEN="<extra_id_122>"

def consistent(rx, spec):
    """Check whether regex ``rx`` is consistent with a labeled example spec.

    Args:
        rx: regex pattern string.
        spec: iterable of ``(string, label)`` pairs, where label is ``'+'``
            (string must fully match) or ``'-'`` (string must not match).

    Returns:
        True if every '+' example fully matches and every '-' example does
        not; False on the first counterexample; None when the spec contains
        an unknown label, the pattern is invalid, or matching times out.
    """
    # spec is in the form of (string, '+'/'-') pairs
    for s, label in spec:
        # Malformed spec: any label other than '+'/'-' is uninterpretable.
        if label not in ('+', '-'):
            return None
        try:
            # timeout=1 guards against catastrophic backtracking
            # (supported by the third-party `regex` module imported as `re`).
            if re.fullmatch(rx, s, timeout=1):
                if label == '-':
                    return False
            else:
                if label == '+':
                    return False
        except re.error:
            # Pattern failed to compile — cannot evaluate consistency.
            return None
        except TimeoutError:
            return None

    return True

def decode(c):
    """Map a single ByT5 token id to its string form.

    Ids 0-2 are special tokens, rendered as "<0>"/"<1>"/"<2>".
    Ids 3-258 are byte tokens, rendered as chr(id - 3) (byte values 0-255).
    Ids >= 259 are sentinel tokens, rendered as "<extra_id_N>" with
    N = id - 259.
    """
    if c < 3:
        return f"<{c}>"
    elif c < 259:
        # Bug fix: was `c < 259` written as `c < 258`, which sent id 258
        # (byte 255) into the sentinel branch and produced "<extra_id_-1>".
        # Byte tokens span ids 3..258 inclusive (256 values).
        return chr(c - 3)
    else:
        return f"<extra_id_{c - 259}>"
        
def byt5_decode_batch(outputs, skip_special_tokens=True, skip_position_token=False):
    """Decode batched ByT5 token-id sequences into strings.

    ``outputs`` is a list of beams, each beam a list of token-id sequences.
    Optionally drops special tokens (ids < 3) and position tokens (ids > 258)
    before rendering each remaining id through ``decode``.

    Returns a list (per beam) of lists of decoded strings.
    """
    def keep(token_id):
        # Single-pass filter equivalent to the two original filtering passes.
        if skip_special_tokens and token_id < 3:
            return False
        if skip_position_token and token_id > 258:
            return False
        return True

    decoded = []
    for beam in outputs:
        decoded.append(
            [''.join(decode(t) for t in seq if keep(t)) for seq in beam]
        )
    return decoded

def get_preprocess_function(tokenizer):
    """Build a dataset-mapping closure that tokenizes context/target pairs.

    The returned function expects a batch dict with "context" and "target"
    columns; ``None`` contexts are replaced by a single space so the
    tokenizer always receives a string.
    """
    def preprocess_function(examples):
        contexts = [' ' if ctx is None else ctx for ctx in examples["context"]]
        return tokenizer(
            contexts,
            text_target=examples["target"],
            truncation=True,
        )

    return preprocess_function

def get_utterance_processing_functions(label_pos, idx, separator=' '):
    """Return a matched pair ``(utterances_to_string, string_to_utterances)``.

    ``utterances_to_string`` serializes a spec of ``(string, label)`` pairs;
    ``string_to_utterances`` parses such a serialization back into pairs.

    Args:
        label_pos: "suffix" places the label after each string; any other
            value places it before.
        idx: when truthy, each pair is prefixed with an "<extra_id_i>"
            index marker and pairs are concatenated without a separator.
        separator: joiner/splitter used in the non-indexed formats.
    """
    if label_pos == "suffix":
        if idx:
            def utterances_to_string(spec):
                pieces = [f"<extra_id_{i}>{s}{label}" for i, (s, label) in enumerate(spec)]
                return ''.join(pieces)

            def string_to_utterances(string):
                # Index markers become spaces so the string splits on them.
                cleaned = re.sub(r'<extra_id_\d+>', ' ', string)
                return [(tok[:-1], tok[-1]) for tok in cleaned.split(' ') if tok]
        else:
            def utterances_to_string(spec):
                return separator.join(f"{s}{label}" for s, label in spec)

            def string_to_utterances(string):
                return [(tok[:-1], tok[-1]) for tok in string.split(separator) if tok]
    else:
        if idx:
            def utterances_to_string(spec):
                pieces = [f"<extra_id_{i}>{label}{s}" for i, (s, label) in enumerate(spec)]
                return ''.join(pieces)

            def string_to_utterances(string):
                # NOTE(review): unlike the suffix variant, markers are removed
                # outright and the split uses `separator` — confirm intended.
                cleaned = re.sub(r'<extra_id_\d+>', '', string)
                return [(tok[1:], tok[0]) for tok in cleaned.split(separator) if tok]
        else:
            def utterances_to_string(spec):
                return separator.join(f"{label}{s}" for s, label in spec)

            def string_to_utterances(string):
                return [(tok[1:], tok[0]) for tok in string.split(separator) if tok]

    return utterances_to_string, string_to_utterances