# Snapshot: commit c3b1078, file size 7,902 bytes.
import os
from pathlib import Path
import re
# Directory containing this module; verb-form-vocab.txt is loaded from here.
VOCAB_DIR = Path(__file__).resolve().parents[0]

# Special token strings.
PAD = "@@PADDING@@"
UNK = "@@UNKNOWN@@"
START_TOKEN = "$START"

# Separator string per field type. The (misspelled) name is kept as-is because
# other modules may import it.
# NOTE(review): presumably used to serialise/parse token, label and operation
# sequences — confirm against the code that reads/writes them.
SEQ_DELIMETERS = dict(
    tokens=" ",
    labels="SEPL|||SEPR",
    operations="SEPL__SEPR",
)
def get_verb_form_dicts(path_to_dict=None):
    """Load the verb-form vocabulary and build its encode/decode lookup tables.

    Each non-blank line of the vocabulary file has the form
    ``word1_word2:tag1_tag2``. Only the first line seen for a given
    ``(word1, tag1, tag2)`` combination is recorded; later duplicates are
    ignored in both tables.

    Args:
        path_to_dict: Path of the vocabulary file. Defaults to
            ``VOCAB_DIR / "verb-form-vocab.txt"``.

    Returns:
        An ``(encode, decode)`` pair of dicts:
        ``encode["word1_word2"]`` maps to the raw ``"tag1_tag2"`` text with
        its trailing newline (``encode_verb_form`` strips it), and
        ``decode["word1_tag1_tag2"]`` maps to ``word2``.

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if a non-blank line does not split into exactly two
            ``:``-separated parts, each with exactly two ``_``-separated halves.
    """
    if path_to_dict is None:
        path_to_dict = VOCAB_DIR / "verb-form-vocab.txt"
    encode, decode = {}, {}
    with open(path_to_dict, encoding="utf-8") as f:
        for line in f:
            # Skip blank lines (e.g. a trailing empty line); unpacking them
            # below would raise ValueError.
            if not line.strip():
                continue
            words, tags = line.split(":")
            word1, word2 = words.split("_")
            tag1, tag2 = tags.split("_")
            decode_key = f"{word1}_{tag1}_{tag2.strip()}"
            if decode_key not in decode:
                encode[words] = tags
                decode[decode_key] = word2
    return encode, decode
# Build the verb-form lookup tables once, at import time. encode_verb_form and
# decode_verb_form read these module globals, so importing this module raises
# if the vocabulary file cannot be opened.
(ENCODE_VERB_DICT, DECODE_VERB_DICT) = get_verb_form_dicts()
def get_target_sent_by_edits(source_tokens, edits):
    """Apply token-level edits to a copy of *source_tokens* and return the result.

    Args:
        source_tokens: List of source tokens; it is copied, not modified.
        edits: Iterable of 4-tuples ``(start, end, label, _)``; the fourth
            element is ignored here. ``start``/``end`` are positions in the
            ORIGINAL source token list.

    Returns:
        The edited token list after ``replace_merge_transforms`` has resolved
        any ``$MERGE_*`` placeholder tokens.
    """
    target_tokens = source_tokens[:]
    # Net number of tokens inserted (+) or deleted (-) so far; added to each
    # edit's source position to find the matching position in target_tokens.
    # NOTE(review): this assumes edits arrive in increasing source order —
    # confirm with callers.
    shift_idx = 0
    for edit in edits:
        start, end, label, _ = edit
        target_pos = start + shift_idx
        if start < 0:
            # Negative positions are skipped entirely (presumably edits anchored
            # on a sentence-start placeholder — confirm against the caller).
            continue
        elif len(target_tokens) > target_pos:
            source_token = target_tokens[target_pos]
        else:
            # Position is past the end of the current tokens.
            source_token = ""
        if label == "":
            # Empty label: delete the token at this position.
            del target_tokens[target_pos]
            shift_idx -= 1
        elif start == end:
            # Zero-width edit: insert a token. Only the "$APPEND_" prefix is
            # stripped, so zero-width "$MERGE_*" labels also land here and are
            # inserted verbatim as placeholders for replace_merge_transforms.
            word = label.replace("$APPEND_", "")
            # Avoid appending same token twice
            if (target_pos < len(target_tokens) and target_tokens[target_pos] == word) or (
                target_pos > 0 and target_tokens[target_pos - 1] == word
            ):
                continue
            target_tokens[target_pos:target_pos] = [word]
            shift_idx += 1
        elif label.startswith("$TRANSFORM_"):
            word = apply_reverse_transformation(source_token, label)
            if word is None:
                # No transformation result (e.g. verb form missing from the
                # dictionary): keep the token unchanged.
                word = source_token
            target_tokens[target_pos] = word
        elif start == end - 1:
            # One-token span: replace it with the label minus "$REPLACE_".
            word = label.replace("$REPLACE_", "")
            target_tokens[target_pos] = word
        elif label.startswith("$MERGE_"):
            # Reached only for merge labels whose span is neither zero nor one
            # token wide; the placeholder is inserted after target_pos.
            target_tokens[target_pos + 1 : target_pos + 1] = [label]
            shift_idx += 1
    # NOTE(review): the delete/transform/replace branches index
    # target_tokens[target_pos] directly, so a position past the end raises
    # IndexError.
    return replace_merge_transforms(target_tokens)
def replace_merge_transforms(tokens):
    """Resolve ``$MERGE_*`` placeholder tokens by joining their neighbours.

    ``$MERGE_HYPHEN`` joins the surrounding tokens with ``-``, and
    ``$MERGE_SPACE`` joins them with nothing in between. Placeholders at the
    start or end of the sequence have nothing to join, so they are dropped.

    Args:
        tokens: List of token strings.

    Returns:
        *tokens* itself (same object) when it contains no merge placeholder.
        Otherwise, a new list re-tokenised with ``str.split()``, so tokens that
        contain spaces are split apart. On this path, runs of ``.``/``,``/``?``/
        ``:`` marks that are each followed by whitespace are collapsed to the
        last mark of the run.
    """
    if not any(token.startswith("$MERGE_") for token in tokens):
        return tokens
    # Trim every leading/trailing placeholder, not just one. Otherwise
    # ["$MERGE_SPACE"] raises IndexError on the emptied list, and a literal
    # "$MERGE_*" token survives at either edge.
    begin, end = 0, len(tokens)
    while begin < end and tokens[begin].startswith("$MERGE_"):
        begin += 1
    while end > begin and tokens[end - 1].startswith("$MERGE_"):
        end -= 1
    target_line = " ".join(tokens[begin:end])
    target_line = target_line.replace(" $MERGE_HYPHEN ", "-")
    target_line = target_line.replace(" $MERGE_SPACE ", "")
    # For a repeated group, \1 is the LAST repetition, so ", . " becomes ". ".
    target_line = re.sub(r'([\.\,\?\:]\s+)+', r'\1', target_line)
    return target_line.split()
def convert_using_case(token, smart_action):
    """Apply a ``$TRANSFORM_CASE_*`` label to *token*.

    Supported label endings:
        LOWER      -> ``token.lower()``
        UPPER      -> ``token.upper()``
        CAPITAL    -> ``token.capitalize()``
        CAPITAL_1  -> keep the first character, capitalize the remainder
        UPPER_-1   -> upper-case everything except the last character

    Args:
        token: The token to transform.
        smart_action: The edit label.

    Returns:
        The transformed token. The token is returned unchanged for labels that
        are not ``$TRANSFORM_CASE_*``, for unknown endings, and when *token*
        is empty.
    """
    if not smart_action.startswith("$TRANSFORM_CASE_"):
        return token
    # get_target_sent_by_edits passes "" for positions past the end of the
    # sentence. CAPITAL_1 / UPPER_-1 index into the token and would raise
    # IndexError on it; the other transforms map "" to "" anyway.
    if not token:
        return token
    if smart_action.endswith("LOWER"):
        return token.lower()
    elif smart_action.endswith("UPPER"):
        return token.upper()
    elif smart_action.endswith("CAPITAL"):
        return token.capitalize()
    elif smart_action.endswith("CAPITAL_1"):
        return token[0] + token[1:].capitalize()
    elif smart_action.endswith("UPPER_-1"):
        return token[:-1].upper() + token[-1]
    else:
        return token
def convert_using_verb(token, smart_action):
    """Turn *token* into another verb form according to a ``$TRANSFORM_VERB_*`` label.

    The text after ``$TRANSFORM_VERB_`` is appended to the token with an
    underscore, and the result is looked up with ``decode_verb_form``.

    Returns:
        Whatever ``decode_verb_form`` returns for that key (``None`` when the
        key is unknown).

    Raises:
        Exception: if *smart_action* is not a ``$TRANSFORM_VERB_`` label.
    """
    prefix = "$TRANSFORM_VERB_"
    if smart_action.startswith(prefix):
        tag_pair = smart_action[len(prefix):]
        return decode_verb_form(f"{token}_{tag_pair}")
    raise Exception(f"Unknown action type {smart_action}")
def convert_using_split(token, smart_action):
    """Split a hyphenated *token* into space-separated words (``$TRANSFORM_SPLIT*``).

    Every ``-`` becomes one space. The result is a single string that may
    contain spaces.

    Raises:
        Exception: if *smart_action* is not a ``$TRANSFORM_SPLIT`` label.
    """
    if smart_action.startswith("$TRANSFORM_SPLIT"):
        return token.replace("-", " ")
    raise Exception(f"Unknown action type {smart_action}")
def convert_using_plural(token, smart_action):
    """Apply a number-agreement label to *token*.

    A label ending in ``PLURAL`` appends ``"s"``. A label ending in
    ``SINGULAR`` drops the last character, whatever it is.

    Raises:
        Exception: for any other label ending.
    """
    rules = (
        ("PLURAL", lambda word: word + "s"),
        ("SINGULAR", lambda word: word[:-1]),
    )
    for ending, rule in rules:
        if smart_action.endswith(ending):
            return rule(token)
    raise Exception(f"Unknown action type {smart_action}")
def apply_reverse_transformation(source_token, transform):
    """Apply a ``$TRANSFORM_*`` edit label to *source_token*.

    Labels that do not start with ``$TRANSFORM`` (``$KEEP``, ``$APPEND_*``,
    ``$REPLACE_*``, ...) leave the token unchanged.

    Args:
        source_token: The token to transform.
        transform: The edit label.

    Returns:
        The transformed token. ``$TRANSFORM_VERB_*`` can yield ``None`` when
        the verb-form dictionary has no entry for the token.

    Raises:
        Exception: for a ``$TRANSFORM...`` label with an unrecognised sub-type,
            or one raised by the specific converter.
    """
    # "$KEEP" does not start with "$TRANSFORM", so it is covered by this guard.
    # (An explicit "$KEEP" check inside the transform branch could never match.)
    if not transform.startswith("$TRANSFORM"):
        return source_token
    if transform.startswith("$TRANSFORM_CASE"):
        return convert_using_case(source_token, transform)
    if transform.startswith("$TRANSFORM_VERB"):
        return convert_using_verb(source_token, transform)
    if transform.startswith("$TRANSFORM_SPLIT"):
        return convert_using_split(source_token, transform)
    if transform.startswith("$TRANSFORM_AGREEMENT"):
        return convert_using_plural(source_token, transform)
    raise Exception(f"Unknown action type {transform}")
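# Earlier list-based version, kept for reference. It asserted that both files
# had the same number of lines, dropped pairs where either side was blank, and
# returned two lists. The generator below yields each stripped line pair with
# no filtering and no length check.
# def read_parallel_lines(fn1, fn2):
#     lines1 = read_lines(fn1, skip_strip=True)
#     lines2 = read_lines(fn2, skip_strip=True)
#     assert len(lines1) == len(lines2)
#     out_lines1, out_lines2 = [], []
#     for line1, line2 in zip(lines1, lines2):
#         if not line1.strip() or not line2.strip():
#             continue
#         else:
#             out_lines1.append(line1)
#             out_lines2.append(line2)
#     return out_lines1, out_lines2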
def read_parallel_lines(fn1, fn2):
    """Lazily yield aligned ``(line1, line2)`` pairs from two UTF-8 text files.

    Both lines are whitespace-stripped. Blank lines are yielded as ``""``
    rather than skipped. Iteration stops at the end of the shorter file, and a
    length mismatch is not reported. Both files stay open until the generator
    is exhausted or closed.
    """
    with open(fn1, encoding='utf-8') as left, open(fn2, encoding='utf-8') as right:
        yield from ((a.strip(), b.strip()) for a, b in zip(left, right))
def read_lines(fn, skip_strip=False):
    """Return the whitespace-stripped lines of the UTF-8 text file *fn*.

    Args:
        fn: Path of the file. If it does not exist, ``[]`` is returned instead
            of raising.
        skip_strip: Despite the name, lines are always stripped. When True,
            blank lines are kept (as ``""``); otherwise they are dropped.
    """
    if not os.path.exists(fn):
        return []
    with open(fn, 'r', encoding='utf-8') as handle:
        cleaned = [raw.strip() for raw in handle]
    return cleaned if skip_strip else [line for line in cleaned if line]
def write_lines(fn, lines, mode='w'):
    """Write each item of *lines* to *fn*, one per line.

    Args:
        fn: Output path.
        lines: Iterable of items; each is converted with ``str()`` and
            followed by ``"\\n"``.
        mode: File mode passed to ``open()``. With ``'w'`` (the default), an
            existing file is deleted first and a fresh one is created.
    """
    if mode == 'w' and os.path.exists(fn):
        # Delete rather than just truncate, so the file is recreated.
        os.remove(fn)
    with open(fn, encoding='utf-8', mode=mode) as out:
        # Wrap each item in a 1-tuple. A bare tuple item would otherwise be
        # taken as the %-format argument list, raising TypeError (or losing its
        # parentheses for a 1-element tuple).
        out.writelines('%s\n' % (s,) for s in lines)
def decode_verb_form(original):
    """Look up a ``word_tag1_tag2`` key in ``DECODE_VERB_DICT``.

    Returns:
        The target word form, or ``None`` when the key is not present.
    """
    target_form = DECODE_VERB_DICT.get(original, None)
    return target_form
def encode_verb_form(original_word, corrected_word):
    """Return the verb-form tag pair that turns *original_word* into *corrected_word*.

    The pair is looked up in ``ENCODE_VERB_DICT`` under
    ``"original_word_corrected_word"``, and surrounding whitespace is stripped.

    Returns:
        The stripped tag text, or ``None`` when *original_word* is empty or
        the pair has no (non-blank) entry.
    """
    lookup_key = original_word + "_" + corrected_word
    tags = ENCODE_VERB_DICT.get(lookup_key, "").strip()
    if not (original_word and tags):
        return None
    return tags
def get_weights_name(transformer_name, lowercase):
    """Map a short transformer name to the pretrained-weights name to load.

    The returned strings look like Hugging Face model ids (presumably passed
    to a ``from_pretrained``-style loader — confirm with callers).

    Args:
        transformer_name: Short model name such as ``'bert'`` or ``'roberta'``.
            Names not recognised here are returned unchanged.
        lowercase: Whether the input text is lower-cased. It selects the cased
            or uncased BERT checkpoint and triggers a warning when it does not
            match how the chosen model was trained.

    Returns:
        The weights name string.
    """
    if transformer_name == 'bert':
        return 'bert-base-uncased' if lowercase else 'bert-base-cased'
    if transformer_name == 'bert-large':
        # Handle both casings, as for 'bert'. Falling through to the generic
        # path would print a misleading "cased only" warning and return the
        # bare name 'bert-large'.
        return 'bert-large-uncased' if lowercase else 'bert-large-cased'
    uncased_only = {
        'distilbert': 'distilbert-base-uncased',
        'albert': 'albert-base-v1',
    }
    if transformer_name in uncased_only:
        if not lowercase:
            print('Warning! This model was trained only on uncased sentences.')
        return uncased_only[transformer_name]
    # Every remaining name, including unknown ones, is treated as cased-only.
    if lowercase:
        print('Warning! This model was trained only on cased sentences.')
    cased_only = {
        'roberta': 'roberta-base',
        'roberta-large': 'roberta-large',
        'gpt2': 'gpt2',
        'transformerxl': 'transfo-xl-wt103',
        'xlnet': 'xlnet-base-cased',
        'xlnet-large': 'xlnet-large-cased',
    }
    return cased_only.get(transformer_name, transformer_name)