|
import numpy as np |
|
import os |
|
import sys |
|
from constants import * |
|
if __name__ == '__main__':
    from threading import Thread

    # Number of worker threads for the parallel preprocessing stages; can be
    # overridden on the command line with `--n_threads K`.
    N_THREADS = 32
    if '--n_threads' in sys.argv:
        flag_index = sys.argv.index('--n_threads')
        try:
            N_THREADS = int(sys.argv[flag_index + 1])
        except (IndexError, ValueError):
            sys.exit("--n_threads expects an integer argument")
        # Later stages divide the token list by N_THREADS, so it must be >= 1.
        if N_THREADS < 1:
            sys.exit("--n_threads must be at least 1")
|
|
|
if __name__ == '__main__':
    # Output directory for the lemma list written at the end of the script.
    os.makedirs('lemmas', exist_ok=True)

    corpus_path = "inputs/join.txt" if not KAGGLE else "inputs/join-kaggle.txt"
    with open(corpus_path, "r") as corpus_file:
        text = corpus_file.read()

    # Tokens are separated by single spaces; runs of spaces yield empty
    # strings, which are dropped.
    tokens = [x for x in text.split(" ") if x != '']
    print("Total number of tokens:", len(tokens))

    print("Counting tokens")
    counts = {}
    for token in tokens:
        counts[token] = counts.get(token, 0) + 1

    def _rank_words(token_counts):
        """Return the keys of ``token_counts`` by descending count.

        Ties keep first-seen order (the sort is stable), and every token in
        BANNED_TOKENS that is present is moved to the very end so it falls
        outside the vocabulary.
        """
        ranked = sorted(token_counts, key=token_counts.get, reverse=True)
        for banned in BANNED_TOKENS:
            if banned in ranked:
                ranked.remove(banned)
                ranked.append(banned)
        return ranked

    words = _rank_words(counts)

    # Fold every word outside the top VOCAB_SIZE into '<unk>'. Words are
    # unique, so "not in words[:VOCAB_SIZE]" is exactly the tail slice; summing
    # it avoids re-slicing and scanning the head for every word.
    counts['<unk>'] = 0
    counts['<unk>'] = sum(counts[word] for word in words[VOCAB_SIZE:])
    words = _rank_words(counts)

    vocab = set(words[:VOCAB_SIZE])
else:
    # Imported as a module: reuse the vocabulary saved by a previous run.
    print("Loading vocab")
    vocab = set(np.load('lemmas/lemmas.npy'))
|
|
|
# Irregular inflection tables keyed by suffix token, loaded on first use.
_SUFFIX_DICTS = None


def _suffix_dicts():
    """Load lemmas/{s,ed,er,est,ing}.npy once and return them keyed by suffix token."""
    global _SUFFIX_DICTS
    if _SUFFIX_DICTS is None:
        # allow_pickle is required for the object arrays these files hold;
        # only ever load trusted files here.
        _SUFFIX_DICTS = {
            suffix: np.load('lemmas/' + suffix[1:] + '.npy', allow_pickle=True).item()
            for suffix in ('=s', '=ed', '=er', '=est', '=ing')
        }
    return _SUFFIX_DICTS


def _apply_regular_suffix(word, suffix):
    """Attach a regular suffix token (e.g. '=ed') to ``word`` with English spelling rules."""
    marker = suffix[1]
    if marker in ('e', 'i'):
        # bake -> baking/baked, panic -> panicking
        if word.endswith('e'):
            word = word[:-1]
        elif word.endswith('c'):
            word = word + 'k'
    # consonant + y before an e-suffix: happy -> happier
    if word.endswith('y') and marker == 'e' and len(word) > 2 and word[-2] not in VOWELS:
        word = word[:-1] + 'i'
    if marker == 's':
        # box -> boxes, city -> cities
        if word.endswith(('s', 'sh', 'x', 'ch')):
            word = word + 'e'
        if word.endswith('y') and len(word) > 2 and word[-2] not in VOWELS:
            word = word[:-1] + 'ie'
    return word + suffix[1:]


def pretty_tokens(tokens, mask=True):
    """Detokenise a list of word/suffix tokens into display text.

    Suffix tokens (strings starting with '=', e.g. '=s', '=ed', '=nt') are
    merged into the preceding word: '=nt' becomes "n't" (with will -> won't,
    shall -> shan't), other suffixes use the irregular tables in lemmas/*.npy
    when the word is listed there and regular spelling rules otherwise. The
    NEWLINE marker becomes a line break and the TITLE marker a title ornament.

    Args:
        tokens: sequence of token strings.
        mask: when True, tokens missing from the module-level ``vocab`` are
            rendered as "<unk>".

    Returns:
        The joined text, without a leading space.
    """
    dicts = _suffix_dicts()
    newline_word = NEWLINE.lower()[1:-1]
    title_word = TITLE.lower()[1:-1]
    res = []

    def includeSpace(this):
        """Return True if ``this`` should be preceded by a space, given ``res`` so far."""
        quote = {"'", '"'}
        nospace = {'\n', '-'}
        prev = res[-1] if res else None
        prev2 = res[-2] if len(res) > 1 else None
        space = (prev not in nospace
                 and this not in PUNCT and this != '\n'
                 # contractions such as "'s" attach to the previous word
                 and not (this.startswith("'") and this != "'"))
        if prev in quote and prev2 not in PUNCT:
            # no space right after a quote that does not follow punctuation
            space = False
        elif this in quote and prev in PUNCT:
            # no space before a quote that follows punctuation
            space = False
        return space

    i = 0
    while i < len(tokens):
        this = tokens[i]
        if this == newline_word:
            this = '\n'
        elif this == title_word:
            this = '\n ༄༅༅ '
        elif mask and this not in vocab:
            this = " <unk>"
            if not includeSpace(this):
                this = "<unk>"
            res.append(this)
            i += 1
            continue

        if i + 1 < len(tokens):
            suffix = tokens[i + 1]
            # A bare '=' carries no suffix letters; leave it to be handled as
            # an ordinary token on the next iteration.
            while suffix.startswith('=') and len(suffix) > 1:
                if suffix == "=nt":
                    if tokens[i].endswith('n'):
                        this = this[:-1]
                    if tokens[i] == 'will':
                        this = 'wo'
                    elif tokens[i] == 'shall':
                        this = 'sha'
                    this = this + "n't"
                else:
                    # Look up and index with the same key (the current stem);
                    # unknown suffixes fall back to the regular rules.
                    irregular = dicts.get(suffix, {})
                    if this in irregular:
                        this = irregular[this]
                    else:
                        this = _apply_regular_suffix(this, suffix)
                i += 1
                suffix = tokens[i + 1] if i + 1 < len(tokens) else ''

        if this.startswith('='):
            this = this[1:]
        elif includeSpace(this):
            this = " " + this
        res.append(this)
        i += 1

    text = ''.join(res)
    return text[1:] if text.startswith(' ') else text
|
|
|
def getRhyme(line):
    """Classify the rhyme sound at the end of a tokenised line.

    ``line`` is a sequence of word/suffix tokens; ``None`` entries stand for
    unknown words. Returns ``[vowel_type, consonant_type]``, where the
    vowel class is 0-13 and the consonant class 0-9 (the same indices
    ``pretty_rhyme`` names). Either entry is -1 when it cannot be determined,
    e.g. for an empty line, a title line, or an unknown stem.
    """
    if line is None or len(line) == 0:
        return [-1, -1]
    nl = NEWLINE.lower()[1:-1]
    tl = TITLE.lower()[1:-1]
    if line[0] == tl:
        return [-1, -1]

    # Drop trailing newlines, punctuation, quotes and unknown tokens so that
    # `word` is the last real word of the line.
    while line[-1] == nl or line[-1] in PUNCT or line[-1] == '"' or line[-1] == "'" or line[-1] is None:
        line = line[:-1]
        if len(line) == 0:
            return [-1, -1]
    word = line[-1] + ''

    long_vowel = False
    vowel_type = None
    vowel_map = {'a': 0, 'e': 1, 'i': 2, 'o': 3, 'u': 4, 'ow': 5, 'ou': 5, 'oi': 6, 'oy': 6,
                 'ay': 7, 'ai': 7, 'au': 3, 'aw': 3, 'ea': 8, 'ee': 8, 'eu': 11, 'ew': 11,
                 'oa': 10, 'oo': 11, 'y': 9, 'ey': 7, 'ei': 9}
    consonant_type = -1
    cons_map = {'r': 0, 'l': 1, 'n': 2, 'm': 2, 'ng': 2,
                'p': 3, 'b': 3, 't': 4, 'd': 4, 'f': 5,
                'v': 5, 's': 6, 'sh': 6, 'z': 6, 'zh': 6,
                'th': 9, 'k': 7, 'ch': 8, 'j': 8}

    def getVowel(vowel, isLong, beforeR):
        """Adjust a base vowel class for a following R or a long (silent-e) vowel."""
        if beforeR and not isLong:
            if vowel == 0:
                return 12
            if vowel == 1 or vowel == 2 or vowel == 4:
                return 13
            if vowel == 3:
                return 10
        if isLong and 0 <= vowel <= 4:
            # short a/e/i/o/u -> the corresponding long class 7..11
            return vowel + 7
        return vowel

    # Suffix/contraction tokens whose final consonant is fixed; the vowel is
    # then taken from the preceding stem token. '=nt' is the token that
    # pretty_tokens renders as "n't".
    fixed_consonants = {"'re": 0, "'ve": 5, "'ll": 1, "'d": 4, "'m": 2, "=nt": 4}
    lock_consonant = -1
    if len(line) > 1 and (word in ('=ed', '=s', "'s") or word in fixed_consonants):
        stem = line[-2]
        if not isinstance(stem, str):
            # Unknown stem (e.g. an out-of-vocabulary token mapped to None).
            return [-1, -1]
        if word == '=ed':
            if stem.endswith(('t', 'd')):
                return [4, 4]
            lock_consonant = 4
        elif word in ('=s', "'s"):
            if stem.endswith(('s', 'z', 'ch', 'sh', 'x')):
                return [4, 6]
            lock_consonant = 6
        else:
            lock_consonant = fixed_consonants[word]
        word = stem

    def locked(default):
        """Return the suffix-imposed consonant if there is one, else ``default``."""
        return default if lock_consonant == -1 else lock_consonant

    if word in DEFINED_RHYMES:
        return [DEFINED_RHYMES[word][0], locked(DEFINED_RHYMES[word][1])]

    # Common spelling patterns whose sound the generic rules below get wrong.
    if word.endswith('o'):
        return [10, lock_consonant]
    if word.endswith(('bble', 'ggle')):
        return [4, locked(1)]
    if word.endswith('old'):
        return [10, locked(1)]
    if word.endswith('ance'):
        return [0, locked(6)]
    if word.endswith(('ense', 'ence')):
        return [1, locked(6)]
    if word.endswith('ince'):
        return [2, locked(6)]
    if word.endswith(('ture', 'sure')):
        return [13, locked(0)]
    if word.endswith('all'):
        return [3, locked(1)]
    if word.endswith(('row', 'low')):
        return [10, lock_consonant]
    if word.endswith('le') and len(word) >= 3 and word[-3] not in VOWELS:
        return [4, locked(1)]
    if word.endswith('on') and len(word) > 3 and not word.endswith('oon'):
        return [4, locked(2)]
    if word.endswith('al') and len(word) > 3 and not word.endswith('eal'):
        return [4, locked(1)]
    if word.endswith('ous'):
        return [4, locked(6)]
    if word.endswith('ly'):
        return [8, lock_consonant]
    if word.endswith('ward'):
        return [13, locked(4)]

    # A final silent 'e' lengthens the vowel.
    if word.endswith('e'):
        long_vowel = True
        word = word[:-1]
    if word == '':
        # Same fallback as the consonant-stripping loop below.
        return [8, lock_consonant]

    if lock_consonant == -1:
        if word[-2:] in cons_map:
            consonant_type = cons_map[word[-2:]]
        elif word[-1:] in cons_map:
            consonant_type = cons_map[word[-1:]]
        elif word[-1] == 'c' and long_vowel:
            consonant_type = cons_map['s']  # soft c: ice
        elif word[-1] == 'g' and long_vowel:
            consonant_type = cons_map['j']  # soft g: page
        elif word[-1] in ('c', 'g'):
            consonant_type = cons_map['k']  # hard c/g: magic, dog (class K/G)
    else:
        consonant_type = lock_consonant

    # Strip trailing consonants; remember whether the one touching the vowel
    # was an 'r' (which colours the vowel).
    lock_r = False
    while word[-1] not in SOMETIMES_VOWELS:
        if word.endswith('igh'):
            return [9, consonant_type]
        if word[-1] == 'r':
            lock_r = True
        elif lock_r:
            lock_r = False
        word = word[:-1]
        if word == '':
            return [8, lock_consonant]

    if word[-2:] in vowel_map:
        vowel_type = vowel_map[word[-2:]]
    elif word[-1:] in vowel_map:
        vowel_type = vowel_map[word[-1:]]
    if vowel_type is None:
        # The remaining letter is in SOMETIMES_VOWELS but has no vowel class.
        return [-1, consonant_type]

    vowel_type = getVowel(vowel_type, long_vowel, consonant_type == 0 or lock_r)
    return [vowel_type, consonant_type]
|
def pretty_rhyme(rhyme):
    """Describe a ``[vowel, consonant]`` rhyme pair in words.

    The vowel class is shown by an example word and the consonant class by
    its letter group; -1 in either slot is shown as '--' or 'ø' respectively.
    """
    vowel_names = ('bat', 'bet', 'bit', 'bot', 'but', 'pout', 'boil',
                   'bait', 'beat', 'bite', 'boat', 'boot', 'bar', 'sir')
    consonant_names = ('R', 'L', 'N/M/NG', 'P/B', 'T/D', 'F/V',
                       'S/SH/Z/ZH', 'K/G', 'CH/J', 'TH')
    vowel, consonant = rhyme[0], rhyme[1]
    vowel_text = '--' if vowel == -1 else vowel_names[vowel]
    consonant_text = 'ø' if consonant == -1 else consonant_names[consonant]
    return "Rhyme is {} {}".format(vowel_text, consonant_text)
|
|
|
|
|
def getMeter(line):
    """Estimate the number of syllables in a tokenised line.

    Newline/title markers and non-string tokens (``None`` for unknown words)
    contribute nothing. Words in DEFINED_METERS use the listed count; '=ed'
    and '=s' add a syllable only after stems ending in t/d or a sibilant.
    Other words are counted by vowel groups after collapsing common digraphs.
    Returns 0 for ``None``.
    """
    if line is None:
        return 0
    nl = NEWLINE.lower()[1:-1]
    tl = TITLE.lower()[1:-1]
    res = 0
    for i, word in enumerate(line):
        if not isinstance(word, str) or word == nl or word == tl:
            continue
        if word in DEFINED_METERS:
            res += DEFINED_METERS[word]
            continue

        prev = line[i - 1] if i > 0 else None
        if word == '=ed' and i > 0:
            # waited, faded: extra syllable; walked: none. Unknown stems add 0.
            if isinstance(prev, str) and prev.endswith(('t', 'd', 'te', 'de')):
                res += 1
            continue
        if word == '=s' and i > 0:
            # boxes, churches: extra syllable; cats: none.
            if isinstance(prev, str) and prev.endswith(('s', 'z', 'ch', 'sh', 'x')):
                res += 1
            continue

        # consonant + "le" (table, little) is a syllable with no vowel letter
        if word.endswith('le') and len(word) >= 3 and word[-3] not in VOWELS:
            res += 1
        removed_e = word.endswith('e')
        if removed_e:
            word = word[:-1]
        if word.endswith('y') and len(word) > 2 and word[-2] not in VOWELS:
            word = word[:-1] + 'i'
        # Collapse digraphs that are a single vowel sound (order matters).
        for digraph, single in (('ea', 'i'), ('ee', 'i'), ('ai', 'i'), ('au', 'o'),
                                ('eu', 'u'), ('ei', 'i'), ('ie', 'i'), ('oa', 'o'),
                                ('ou', 'o'), ('oi', 'o'), ('oo', 'u')):
            word = word.replace(digraph, single)
        if word.endswith(('tion', 'sion', 'tian')):
            word = word[:-4] + 'shun'
        this_count = sum(word.count(vowel) for vowel in VOWELS)
        if removed_e and this_count == 0:
            # e.g. "the": the only vowel was the stripped 'e'
            this_count = 1
        res += this_count
    return res
|
|
|
def lastLine(tokens, endl):
    """Return the tokens of the line that ends just before index ``endl``.

    Scans backwards from ``endl - 1`` for the nearest newline token at an
    index >= 1 and returns ``tokens[start:endl]``. The slice therefore starts
    with that newline token when one is found, and at index 0 otherwise.
    Returns an empty slice for ``endl == 0``.
    """
    newline = NEWLINE.lower()[1:-1]
    start = endl - 1
    while start > 0 and tokens[start] != newline:
        start -= 1
    segment = tokens[start:endl]
    if len(segment) == 0:
        segment = tokens[:endl]
    return segment
|
def processRhymeStack(rhyme_stack):
    """Flatten a rhyme history into one feature vector.

    ``rhyme_stack`` holds RHYME_STACK_SIZE ``[vowel, consonant]`` rows; the
    last row is the current line. The result is the earlier rows flattened
    column-major (all vowels, then all consonants), followed by one score per
    earlier row: 1 if its vowel matches the current vowel, 2 if the consonant
    also matches, 0 otherwise (all 0 when the current vowel is -1).
    """
    n_prev = RHYME_STACK_SIZE - 1
    history = rhyme_stack[:-1].flatten(order='F')
    last_vowel, last_consonant = rhyme_stack[-1]
    scores = np.zeros(n_prev)
    if last_vowel != -1:
        for row in range(n_prev):
            vowel, consonant = rhyme_stack[row]
            if vowel != last_vowel:
                continue
            scores[row] = 2 if consonant == last_consonant else 1
    return np.concatenate([history, scores])
|
def processRhymeMeter(split):
    """Compute per-token rhyme and meter features for a token sequence.

    Exactly one rhyme vector (from ``processRhymeStack``) and one meter stack
    (length METER_STACK_SIZE, int8) is emitted per token, so the outputs align
    index-for-index with ``split``. The features at position i describe the
    text before token i. The stacks are reset at each title; the title line
    gets its own isolated entry.

    Returns:
        ``[rhymes, meter]``: two lists of length ``len(split)``.
    """
    tl = TITLE.lower()[1:-1]
    nl = NEWLINE.lower()[1:-1]

    def fresh_meter():
        return np.zeros(METER_STACK_SIZE, np.int8)

    def fresh_rhymes():
        return np.zeros((RHYME_STACK_SIZE, 2), np.int8) - 1

    in_title = False
    meter = []
    rhymes = []
    meter_stack = fresh_meter()
    rhyme_stack = fresh_rhymes()

    for i in range(len(split)):
        token = split[i]
        line = lastLine(split, i)

        if token == tl:
            in_title = True
            meter_stack = fresh_meter()
            rhyme_stack = fresh_rhymes()
            meter.append(meter_stack.copy())
            rhymes.append(processRhymeStack(rhyme_stack))
            continue

        if in_title and token == nl:
            # End of the title line: emit it on its own, then start the poem
            # body with empty stacks.
            in_title = False
            title_meter = fresh_meter()
            title_meter[-1] = getMeter(line)
            meter.append(title_meter)
            title_rhymes = fresh_rhymes()
            title_rhymes[-1] = np.array(getRhyme(line), np.int8)
            rhymes.append(processRhymeStack(title_rhymes))
            meter_stack = fresh_meter()
            rhyme_stack = fresh_rhymes()
            continue

        if not in_title and token == nl:
            rhymes.append(processRhymeStack(rhyme_stack))
            meter.append(meter_stack.copy())
            # Blank lines (consecutive newlines) do not push a new entry; the
            # i > 0 guard stops split[-1] from wrapping to the last token.
            if i > 0 and split[i - 1] != nl:
                rhyme_stack = np.roll(rhyme_stack, -1, axis=0)
                rhyme_stack[-1] = np.array(getRhyme(line), np.int8)
                meter_stack = np.roll(meter_stack, -1, axis=0)
                meter_stack[-1] = getMeter(line)
        else:
            # Mid-line: refresh the in-progress line's slot.
            meter_stack[-1] = getMeter(line)
            rhyme_stack[-1] = np.array(getRhyme(line), np.int8)
            rhymes.append(processRhymeStack(rhyme_stack))
            meter.append(meter_stack.copy())

    return [rhymes, meter]
|
|
|
def rhymeMeterFromTokens(tokens, endl, tl, vocab=None):
    """Rhyme/meter features for the last TRANSFORMER_N positions before ``endl``.

    Scans back from ``endl - 1`` to the nearest ``tl`` token (the title marker
    as it appears in ``tokens``) or index 0. The span is right-padded with
    ``None`` to at least TRANSFORMER_N entries, so the result always has
    TRANSFORMER_N rows.

    Args:
        tokens: list or 1-D array of tokens (word strings, or ids when
            ``vocab`` is given).
        endl: exclusive end index.
        tl: title marker in the same representation as ``tokens``.
        vocab: optional id -> word lookup; ids outside [0, VOCAB_SIZE) or
            ``None`` become ``None`` (unknown).

    Returns:
        Array of rhyme features concatenated with meter features along axis 1.
    """
    start = endl - 1
    if len(tokens) >= endl:
        while start > 0 and tokens[start] != tl:
            start -= 1
    # list() copies the span, so padding works for numpy arrays too and never
    # touches the caller's sequence.
    lines = list(tokens[start:endl])
    lines.extend([None] * (TRANSFORMER_N - len(lines)))

    if vocab is not None:
        lines = [vocab[x] if (x is not None and 0 <= x < VOCAB_SIZE) else None for x in lines]

    rhymes, meter = processRhymeMeter(lines)
    rhymes = np.array(rhymes[-TRANSFORMER_N:])
    meter = np.array(meter[-TRANSFORMER_N:])
    return np.concatenate([rhymes, meter], axis=1)
|
|
|
if __name__ == '__main__':
    # Window length: n-gram context + target, or TRANSFORMER_N + 1 so that the
    # input (sample[:N-1]) and the shifted target (sample[1:]) each have
    # TRANSFORMER_N tokens.
    N = NGRAM_N if MODEL_TYPE == 'n' else TRANSFORMER_N+1
    # Pad the tail with N-1 placeholders; they are masked to -1 below.
    tokens.extend([None] * (N - 1))
    words.remove('<unk>')
    print({word: counts[word] for word in words[:VOCAB_SIZE]})
    # Marker ids (also a sanity check: .index raises ValueError if a marker
    # never occurs in the corpus).
    title_token = words.index(TITLE.lower()[1:-1])
    newline_token = words.index(NEWLINE.lower()[1:-1])

    print("Splitting poems with masked dividers")
    # Insert N placeholder tokens before every title (except at position 0).
    # `tokens` still holds word strings here, so compare with the title
    # *string*; comparing with the integer `title_token` never matches.
    # None is the placeholder: the rhyme/meter pass already tolerates None
    # (used for the tail padding) and masking maps None to -1.
    # A single linear pass also avoids rebuilding the list on every insertion
    # and cannot miss titles that fall on a chunk boundary.
    title_word = TITLE.lower()[1:-1]
    divider = [None] * N
    divided = []
    next_check = 1  # nothing precedes the first poem
    for pos, token in enumerate(tokens):
        if pos >= next_check and token == title_word:
            divided.extend(divider)
            next_check = pos + 6  # skip the 5 tokens after a title, as before
        divided.append(token)
    tokens = divided

    if MODEL_TYPE == 'b':
        print("Computing rhyme and meter information")
        # Cut the corpus into N_THREADS spans, moving each cut forward to the
        # next title so no poem is split across spans.
        split_size = len(tokens)//N_THREADS
        split_token_marks = [split_size*k for k in range(N_THREADS+1)]
        for k in range(1, N_THREADS):
            while tokens[split_token_marks[k]] != title_word:
                split_token_marks[k] += 1
                if split_token_marks[k] >= len(tokens):
                    break
        split_token_marks[-1] = len(tokens)
        split_tokens = [tokens[split_token_marks[k]:split_token_marks[k+1]] for k in range(N_THREADS)]

        rhyme_meter_res = [None] * N_THREADS

        def rhymeMeterThread(thread_index, split):
            rhyme_meter_res[thread_index] = processRhymeMeter(split)

        threads = [Thread(target=rhymeMeterThread, args=(k, split_tokens[k])) for k in range(N_THREADS)]
        for t in threads:
            t.start()
        rhymes_data = []
        meter_data = []
        for k, t in enumerate(threads):
            t.join()
            rhymes_data += rhyme_meter_res[k][0]
            meter_data += rhyme_meter_res[k][1]

        print("Converting rhyme and meter information")
        rhymes_data = np.asarray(rhymes_data)
        meter_data = np.asarray(meter_data)
        rhyme_meter_data = np.concatenate([rhymes_data, meter_data], axis=1)

    print("Masking unknown tokens")
    # Dict lookup instead of words.index (a linear scan per token); words are
    # unique, so the ids are identical. Out-of-vocabulary tokens and None -> -1.
    word_ids = {word: idx for idx, word in enumerate(words)}
    tokens = [word_ids.get(x, -1) if x in vocab else -1 for x in tokens]

    print("Creating sets of ngrams")
    train_x = []
    train_y = []
    train_rm = []
    for start in range(0, len(tokens)-N, TOKEN_SKIP):
        sample = tokens[start:start+N]
        train_x.append(sample[:N-1])
        if MODEL_TYPE == 'b':
            # features aligned with the N-1 input positions
            train_rm.append(rhyme_meter_data[start:start+N-1, :])
        if MODEL_TYPE != 'n':
            train_y.append(sample[1:])
        else:
            train_y.append(sample[N-1])

    print("Converting arrays")
    train_x = np.asarray(train_x)
    train_y = np.asarray(train_y)
    if MODEL_TYPE == 'b':
        train_rm = np.asarray(train_rm, np.int8)
    if MODEL_TYPE != 'n':
        # shift inputs so the unknown/placeholder id -1 becomes 0
        train_x += 1

    print("Saving data")
    fname = {'n': 'inputs/ngram_train.npz',
             't': 'inputs/transformer_train.npz',
             'b': 'inputs/bard_train.npz'
             }[MODEL_TYPE]
    if MODEL_TYPE != 'b':
        np.savez_compressed(fname, x=train_x, y=train_y)
    else:
        np.savez_compressed(fname, x=train_x, rm=train_rm, y=train_y)
    np.save('lemmas/lemmas.npy', words[:VOCAB_SIZE])
|
|