saujasv's picture
make barebones gradio interface
2869f1d
import regex as re
# Marker strings written in the "<extra_id_N>" sentinel format (the same
# format that decode() renders for ids >= 259).
# NOTE(review): judging by the names, these presumably delimit the program,
# the utterances and the ground-truth program inside model text; their
# usage is not visible in this part of the file -- confirm against callers.
PROGRAM_SPECIAL_TOKEN="<extra_id_124>"
UTTERANCES_SPECIAL_TOKEN="<extra_id_123>"
GT_PROGRAM_SPECIAL_TOKEN="<extra_id_122>"
def consistent(rx, spec):
    """Check whether the regex ``rx`` agrees with every labelled example.

    ``spec`` is an iterable of ``(string, label)`` pairs. A ``'+'`` label
    means the string must fully match ``rx``; a ``'-'`` label means it must
    not. Each match attempt is made with ``timeout=1``.

    Returns:
        ``True`` when every example agrees with ``rx``, ``False`` at the
        first example that disagrees, and ``None`` when a label is neither
        ``'+'`` nor ``'-'``, when matching raises ``re.error``, or when it
        raises ``TimeoutError``.
    """
    for text, label in spec:
        if label not in ('+', '-'):
            return None
        try:
            matched = re.fullmatch(rx, text, timeout=1) is not None
        except (re.error, TimeoutError):
            return None
        # A positive example must match and a negative one must not, so the
        # match outcome has to equal "label is '+'".
        if matched != (label == '+'):
            return False
    return True
def decode(c):
    """Render a single token id as text.

    Args:
        c: Integer token id.

    Returns:
        ``"<c>"`` for ids below 3, the character ``chr(c - 3)`` for ids
        3..258, and ``"<extra_id_{c - 259}>"`` for ids 259 and above.
    """
    if c < 3:
        return f"<{c}>"
    # Id 258 is the last character id (value 255), not a sentinel: sentinel
    # numbering starts at 259 and byt5_decode_batch keeps ids <= 258 when
    # stripping sentinels. The previous bound (c < 258) rendered id 258 as
    # "<extra_id_-1>".
    if c < 259:
        return chr(c - 3)
    return f"<extra_id_{c - 259}>"
def byt5_decode_batch(outputs, skip_special_tokens=True, skip_position_token=False):
    """Decode nested token-id sequences into strings.

    Args:
        outputs: Iterable of beams, each an iterable of token-id sequences.
        skip_special_tokens: When true, drop ids below 3 before decoding.
        skip_position_token: When true, drop ids above 258 before decoding.

    Returns:
        A list with one list per beam, each holding one string per
        sequence, built by concatenating ``decode(t)`` for every kept id.
    """
    # Collect only the filters that were requested so that no comparison is
    # performed on the ids when a flag is off.
    filters = []
    if skip_special_tokens:
        filters.append(lambda tok: tok >= 3)
    if skip_position_token:
        filters.append(lambda tok: tok <= 258)

    decoded = []
    for beam in outputs:
        beam_strings = []
        for sequence in beam:
            kept = [tok for tok in sequence if all(keep(tok) for keep in filters)]
            beam_strings.append(''.join(decode(tok) for tok in kept))
        decoded.append(beam_strings)
    return decoded
def get_preprocess_function(tokenizer):
    """Build a batch preprocessing function bound to ``tokenizer``.

    The returned callable takes ``examples`` (indexable by ``"context"`` and
    ``"target"``), substitutes a single space for every ``None`` context,
    and returns ``tokenizer(contexts, text_target=targets, truncation=True)``.
    """
    def preprocess_function(examples):
        contexts = []
        for ctx in examples["context"]:
            # A missing context is passed to the tokenizer as one space.
            contexts.append(ctx if ctx is not None else ' ')
        return tokenizer(
            contexts,
            text_target=examples["target"],
            truncation=True,
        )

    return preprocess_function
def get_utterance_processing_functions(label_pos, idx, separator=' '):
    """Build a matching encoder/decoder pair for labelled utterances.

    A spec is a sequence of ``(string, label)`` pairs where ``label`` is a
    single character.

    Args:
        label_pos: ``"suffix"`` places the label after the string; any other
            value places it before the string.
        idx: When truthy, entry ``i`` is prefixed with ``<extra_id_i>`` and
            entries are concatenated without a separator. Otherwise entries
            are joined with ``separator``.
        separator: Delimiter used between entries when ``idx`` is falsy.

    Returns:
        ``(utterances_to_string, string_to_utterances)``: the first turns a
        spec into a string, the second parses such a string back into a list
        of ``(string, label)`` pairs, skipping empty pieces.
    """
    label_last = label_pos == "suffix"

    def _render(s, label):
        # Text for one example, with the label on the configured side.
        return f"{s}{label}" if label_last else f"{label}{s}"

    def _parse(piece):
        # Inverse of _render for one non-empty piece.
        if label_last:
            return (piece[:-1], piece[-1])
        return (piece[1:], piece[0])

    if idx:
        def utterances_to_string(spec):
            return ''.join(
                f"<extra_id_{i}>{_render(s, label)}"
                for i, (s, label) in enumerate(spec)
            )
    else:
        def utterances_to_string(spec):
            return separator.join(_render(s, label) for s, label in spec)

    if idx and label_last:
        def string_to_utterances(string):
            # Unchanged behavior: index markers become spaces and the text
            # is split on spaces, regardless of `separator`.
            string = re.sub(r'<extra_id_\d+>', ' ', string)
            return [_parse(piece) for piece in string.split(' ') if len(piece) > 0]
    elif idx:
        def string_to_utterances(string):
            # The encoder puts nothing but the index marker between entries,
            # so the markers themselves must act as delimiters. Previously
            # they were deleted before splitting on `separator`, which fused
            # all utterances into one. Splitting on `separator` is kept so
            # inputs that already parsed correctly still parse the same way.
            pieces = []
            for chunk in re.split(r'<extra_id_\d+>', string):
                pieces.extend(chunk.split(separator))
            return [_parse(piece) for piece in pieces if len(piece) > 0]
    else:
        def string_to_utterances(string):
            return [_parse(piece) for piece in string.split(separator) if len(piece) > 0]

    return utterances_to_string, string_to_utterances