Spaces:
Runtime error
Runtime error
import regex as re | |
# Marker strings written in the same "<extra_id_N>" form that decode() emits.
# NOTE(review): the role of each marker is inferred from its name only
# (program / utterances / ground-truth program) -- confirm against callers.
PROGRAM_SPECIAL_TOKEN = "<extra_id_124>"
UTTERANCES_SPECIAL_TOKEN = "<extra_id_123>"
GT_PROGRAM_SPECIAL_TOKEN = "<extra_id_122>"
def consistent(rx, spec):
    """Check whether the regex ``rx`` agrees with every labelled example.

    ``spec`` is an iterable of ``(string, label)`` pairs where ``label`` is
    ``'+'`` (the string must fully match ``rx``) or ``'-'`` (it must not).
    Pairs are examined in order and the first decisive pair wins.

    Returns:
        True  -- every pair agrees with ``rx`` (also for an empty ``spec``).
        False -- some pair disagrees with ``rx``.
        None  -- a label other than '+'/'-' was reached, the pattern raised
                 ``re.error``, or matching raised ``TimeoutError``.

    The ``timeout=1`` keyword is passed straight to ``re.fullmatch``; this
    file binds ``re`` to the ``regex`` package, which accepts that keyword.
    """
    for text, sign in spec:
        if sign not in ['+', '-']:
            return None

        try:
            matched = bool(re.fullmatch(rx, text, timeout=1))
        except (re.error, TimeoutError):
            # Invalid pattern or runaway match: the answer is unknown.
            return None

        should_match = sign == '+'
        if matched != should_match:
            return False

    return True
def decode(c):
    """Render a single token id as text.

    Id layout implied by the arithmetic below:
        0..2     -> shown as ``"<0>"``, ``"<1>"``, ``"<2>"``
        3..258   -> the character ``chr(c - 3)`` (256 values, 0..255)
        259 up   -> ``"<extra_id_{c - 259}>"``

    The middle range must include 258: it is the id of value 255, and the
    sentinel numbering only starts at 259. Comparing against 258 here would
    turn id 258 into the nonsensical ``"<extra_id_-1>"``.

    Note that each id is converted with ``chr`` on its own, so ids 131..258
    come out as the code points 128..255 rather than being combined into
    multi-byte UTF-8 characters.
    """
    if c < 3:
        return f"<{c}>"
    if c < 259:
        return chr(c - 3)
    return f"<extra_id_{c - 259}>"
def byt5_decode_batch(outputs, skip_special_tokens=True, skip_position_token=False):
    """Decode a doubly nested batch of token-id sequences into strings.

    ``outputs`` is iterated three levels deep: an outer collection, each
    element of which (a "beam" group) holds sequences of token ids. The
    result mirrors the two outer levels as a list of lists of strings.

    Args:
        outputs: nested iterable of token-id sequences.
        skip_special_tokens: when true, ids below 3 are dropped.
        skip_position_token: when true, ids above 258 are dropped.

    Returns:
        ``list[list[str]]`` where each string is the concatenation of
        ``decode(id)`` over the ids that survived filtering.
    """
    def keep(token):
        # Both filters are independent; a token must pass every active one.
        if skip_special_tokens and not token >= 3:
            return False
        if skip_position_token and not token <= 258:
            return False
        return True

    decoded = []
    for beam in outputs:
        texts = []
        for sequence in beam:
            pieces = [decode(token) for token in sequence if keep(token)]
            texts.append(''.join(pieces))
        decoded.append(texts)
    return decoded
def get_preprocess_function(tokenizer):
    """Return a batch-preprocessing callable bound to ``tokenizer``.

    The returned function reads the ``"context"`` and ``"target"`` entries of
    ``examples``, substitutes a single space for every ``None`` context, and
    returns whatever ``tokenizer`` yields when called with the contexts,
    ``text_target=<targets>`` and ``truncation=True``.
    """
    def preprocess_function(examples):
        # A missing context becomes " " so the tokenizer never receives None.
        contexts = [
            context if context is not None else ' '
            for context in examples["context"]
        ]
        targets = examples["target"]
        return tokenizer(contexts, text_target=targets, truncation=True)

    return preprocess_function
def get_utterance_processing_functions(label_pos, idx, separator=' '):
    """Build a matched ``(utterances_to_string, string_to_utterances)`` pair.

    A *spec* is a sequence of ``(string, label)`` pairs; ``label`` is expected
    to be a single character because the parser peels exactly one character
    off each chunk.

    Args:
        label_pos: ``"suffix"`` writes the label after the string
            (``"abc+"``); any other value writes it before (``"+abc"``).
        idx: when truthy, each utterance is prefixed with ``<extra_id_{i}>``
            (``i`` being its position) and utterances are concatenated with
            no separator; otherwise utterances are joined with ``separator``.
        separator: delimiter used in non-indexed mode, and as the split
            delimiter in indexed prefix mode.

    Returns:
        ``(utterances_to_string, string_to_utterances)``: the first maps a
        spec to a string, the second maps such a string back to a list of
        ``(string, label)`` tuples, skipping empty chunks.

    Utterances that themselves contain the split delimiter will not survive
    a round trip.
    """
    label_is_suffix = label_pos == "suffix"

    def _render(s, label):
        # Place the label on the configured side of the utterance text.
        return f"{s}{label}" if label_is_suffix else f"{label}{s}"

    def _parse(chunk):
        # Inverse of _render: peel the one-character label off the chunk.
        if label_is_suffix:
            return chunk[:-1], chunk[-1]
        return chunk[1:], chunk[0]

    if idx:
        def utterances_to_string(spec):
            return ''.join(
                f"<extra_id_{i}>{_render(s, label)}"
                for i, (s, label) in enumerate(spec)
            )

        # Indexed suffix mode has always split on a literal space; indexed
        # prefix mode splits on `separator`.
        delimiter = ' ' if label_is_suffix else separator

        def string_to_utterances(string):
            # Each sentinel must become the delimiter that is split on below.
            # Replacing it with '' (as prefix mode used to) glued every
            # utterance into one chunk, so the round trip lost all but the
            # first label. A callable replacement keeps any backslashes in
            # `separator` from being read as escape sequences.
            string = re.sub(r'<extra_id_\d+>', lambda _match: delimiter, string)
            return [_parse(chunk) for chunk in string.split(delimiter) if chunk]
    else:
        def utterances_to_string(spec):
            return separator.join(_render(s, label) for s, label in spec)

        def string_to_utterances(string):
            return [_parse(chunk) for chunk in string.split(separator) if chunk]

    return utterances_to_string, string_to_utterances