Commit 1427339
Parent(s): 7245849
Upload files for inference

Files changed:
- constants.py +197 -0
- lemmas/ed.npy +3 -0
- lemmas/er.npy +3 -0
- lemmas/est.npy +3 -0
- lemmas/ing.npy +3 -0
- lemmas/lemmas.npy +3 -0
- lemmas/s.npy +3 -0
- model.py +433 -0
- saved_models/b_model.h5 +3 -0
- tokens.py +534 -0
constants.py
ADDED
@@ -0,0 +1,197 @@
import sys
def parseArgv(argv):
    data = {
        '--model-type': 'b',
        '--vocab-size': 4096,
        '--ngram-n': 4,
        '--transformer-n': 32,
        '--kaggle': False,
        '--rhyme-size': 4,
        '--meter-size': 3,
    }
    startParse = 1
    if len(argv) > 1 and argv[1] in ['n','t','b']:
        data['--model-type'] = argv[1]
        startParse = 2
    for i in range(startParse, len(argv)):
        if argv[i] in data:
            if argv[i] == '--kaggle':
                data[argv[i]] = True
            elif argv[i] == '--model-type':
                data[argv[i]] = argv[i+1]
            else:
                data[argv[i]] = int(argv[i+1])
    return data
sysArgs = parseArgv(sys.argv)

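# Example invocation (illustrative only; flags are parsed by parseArgv above):
#   python model.py b --vocab-size 4096 --transformer-n 32 --kaggle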
VOCAB_SIZE = sysArgs['--vocab-size']
NGRAM_N = sysArgs['--ngram-n']
TRANSFORMER_N = sysArgs['--transformer-n']
MODEL_TYPE = sysArgs['--model-type'] # n: ngram, t: transformer, b: bard
KAGGLE = sysArgs['--kaggle']
TOKEN_SKIP = 2 if MODEL_TYPE == 'n' else TRANSFORMER_N-1
RHYME_STACK_SIZE = sysArgs['--rhyme-size']
METER_STACK_SIZE = sysArgs['--meter-size']
VOWEL_TYPES = 14
CONSONANT_TYPES = 10

TITLE = " <TITLE> "
NEWLINE = " <NEWLINE> "

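# American respellings of British '-our' words; the loop at the bottom of this
# file assigns each the unstressed [4,0] (schwa + R) rhyme class.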
BRITISH_OUR = ['neighbor','color','flavor','splendor','labor','favor','fervor','savior','vapor','endeavor','parlor',
               'clamor','harbor','splendor','behavior','rumor','humor','savor','valor','armor','honor','odor']

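# Hand-curated stem sets for suffix spelling rules (=ed, =d, =s, =st, =er, =est,
# =ing and their consonant-doubling/'y'/'e' variants). Presumably consumed by
# the preprocessing that builds the lemmas/*.npy dictionaries, which is not
# part of this commit.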
est_set = set(['for','t','n','liv','b','di','j','r','p','v','w','b','gu',
               'l','eld','pr','inter','sever','hug','earn','smil',
               'qu','ch','bl','conqu','pri'])
ed_set = set(['he','you','they','we','will','mov','w','wretch','fe','wav','gre',
              'till','far','fell','de','b','f','l','re','hopp','ne','br',
              'mann','bann','bl','pleas','mark','m','sh','se','spe','ble',
              'lov','ste','rous','arm','bar','di','unmov','asham','cre'])
d_set = set(['be', 'she','we','see','re','fe','rowe','fee','le','seale','dee','ne',
             'reveale','traine','warme','coole','saile','sweate','mowe','cooke',
             'gree','warne','aire','seate','ree','temp','doome','helpe','feare',
             'neare','designe','adde','parte','repeate','gaine','parke','mourne',
             'backe','cleane','raine','charme','climbe','wee','fle','barbe','roote',
             'waite','fixe','hee','ende','wounde','pointe','earne','cree','matte',
             'kisse','haire','marke','neede','summe','farme','poure','owne','showe',
             'crowne','entere','evene','turne','crouche','laye','jade','recorde',
             'flowe','looke','nee','calle','learne','spe','ble','fille','washe',
             'boxe','talke','returne','sacre','dreame','pulle','seeme','calle',
             'prie','forme','ruine','lighte','appeare','adorne','aske','locke',
             'crosse','misse','arme','towe','shoute','heade','burne','faile','bowe',
             'rolle','walke','heape','obtaine'])
c_ed_set = set(['ad','cares','jag','pis','kis','mat','er','mis','cal','pas','fil','wo'])
y_ed_set = set(['drapery','city','weary'])
s_set = set(['','a','i','it','his','her',"'",'their','one','will','your','our','down','pant','wa',
             'god','well','other','saw','good','new','ye','leave','right','wood',
             'ha','thi','hi','jesu','riche','specie','alway','ala','grasse','glorie',
             'goe','doe','mas','pis','mi','pi','selve','wherea','prie','masse',
             'beautie','jame','misse','san','la','lo','politic','u','ga','bu','tos',
             'len'])
st_set = set(['be','we','ne','re','tempe','le','mode', 'fore','le','que','riche','cre','pe',
              'harde','sweete','cleane','je','te','che','highe','earne','deepe','meane','prie',
              'olde'])
c_est_set = set(['ful','smal'])
er_set = set(['with','she','h','quak','curr','hopp','minist','eth','thund','whisp','whit',
              'fev','rememb','inn','rend','de','beak','wand','port','heath','clos','should',
              'wrapp','cap','cow','lett','moth','chart','prop','danc','dinn','slumb','tend',
              'sever','ladd','falt','eld','aft','hind','flatt','murd','show','flow','sob',
              'pray','s','numb','pond','ev','und','wint','shiv','ang','fin','hov','teach',
              'clov','ov','oth','riv','barb','post','nev','discov','wat','draw','wait',
              'suff','deliv','quiv','silv','cov','shelt','los','m','slipp','batt','plast',
              'bitt','p','be','pe','ti','pi','ve','se','us','ton','min','sew','lit','tig',
              'lat','inn','out','off','ent','low','pow','less','wond','mann','care','lov',
              'rath','form','summ','bett','found','quart','tap','pap','record','shudd','pitch',
              'shatt','tatt','rid','butt','mis','bould','bord','glimm','answ','wav','walk',
              'glitt','gath','stick','care','temp','fish','corn','flick','dress','feath','met',
              'broth','both','lock','tow','conqu','che','encount','head','alt','mutt','san'])
c_er_set = set(['of','in','but','up','man','let','shut','sum','slip','din','flit',
                'mat','bat','bit','lad','ban','bet','ad','flat','pe','ful','smal','up',
                'pis','kis','slip','lat','cop','begin','shud','washe','shat','tat','lit',
                'glim','lay','lad','cal','glit','pas','fil','ham','sup','pep','rub','chat',
                'skip','alte','flut','mut','scat','dip','stag','wo'])
r_set = set(['he',"'re",'rule','cottage','quake','cove','clove','warble','prime','lowe',
             'cape','tempe','late','e','rive','dee','eve','wave','me','rathe','meter',
             'anothe','mothe','mowe','sweate','saile','leade','hithe','warme','coole',
             'reaveale','traine','chee','manne','shee','uppe','withe','designe','neare',
             'barbe','darke','banne','pete','faste','soone','oute','rende','parke',
             'keepe','lee','rooste','cleane','sweete','bothe','harde','sleepe','poste',
             'loude','climbe','flowe','drawe','waite','highe','lathe','summe','fathe',
             'cove','farme','lose','showe','deepe','longe','hove','teache','pe','rule',
             'freeze','compute','consume','recorde','fille','washe','boxe','talke',
             'spide','meane','outside','inside','laye','lighte','reade','ladde',
             'eage','forme','coppe','answe','aske','dinne','wave','glitte','feve',
             'butte','gathe','pape','broke','matte','time','locke','olde','towe','inne',
             'shoute','heade','cunne','burne','singe','mutte','rolle','dippe','walke'])
ing_set = set(['','us','s','st','n','wan','din','k','heav','w','morn','cloth','br','wav',
               'even','cl','noth','charm','th','spr','bl','p','r','d','tempt','m','s','z',
               'ch','mean','exact','bless','train','lov','str','build','pleas','slid','light',
               'stock','feel','bo','gap'])
c_ing_set = set(['er','wed','ad','ear','begin','pis','kis','er','mis','cal','pas','fil'])
e_ing_set = set(['the','we','bee','bore','lute','ne','re','please','displease','tide','clothe','ke',
                 'neare','wounde','che','feare','doome','helpe','designe','evene','dye',
                 'adde','parte','repeate','gaine','parke','mourne','backe','cleane','charme',
                 'climbe','waite','fixe','raine','ende','wounde','pointe','earne','neede',
                 'summe','poure','owne','crowne','entere','turne','crouche','ble','laye',
                 'recorde','flowe','calle','morne','learne','fille','washe','boxe','talke',
                 'kisse','returne','dreame','pulle','seeme','matte','forme','meane','ruine',
                 'lighte','reade','appeare','adorne','stocke','aske','locke','calle','crosse',
                 'misse','towe','shoute','feele','heade','burne','singe','faile','bowe',
                 'rolle','walke','heape','obtaine'])
y_s_set = set(['ry'])
y_er_set = set(['by'])
y_est_set = set(['pry'])

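# Tokens forced out of the vocabulary: stray characters, numerals, and proper
# nouns from the source poems. tokens.py pushes these to the back of the
# frequency-sorted word list so they fall outside the VOCAB_SIZE cutoff.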
BANNED_TOKENS = ['1','2','3','y','e','l','maud','olaf','lorenzo','de','oscar',
                 'r','d','f','p','agnes','eulalie','kate','niam','thel','asius',
                 'saadi','\\\\','juanna','johnson','dudù','moore','xanthus',
                 'arjun','pandav','draupadi','bhishma','karna','pandu','bhima',
                 'duryodhan','drona','abhimanyu','yudhishthir','agamemnon','narad',
                 'antilochus','diomed','helen','ulysses','achilles','nestor',
                 'menelaus','patroclus','hector','aeneas','laertes','priam',
                 'penelope','eumaeus','telemachus','euryclea','sarpedon','peleus',
                 'polydamas','glaucus','antenor','idomeneus','rishi','boreas',
                 'phaeacian','savitri','kuru','diana','panchala','ida','ithaca',
                 'matsya','pritha','salya','kripa','hastina','sisupala','vidura',
                 'dhrita','rashtra','jayadratha','lamia','medon','highth','haydée',
                 'haidée', 'edward','ithacus',
                 'lenore','à','negro','juan','harold','etc','allan','adeline',
                 '+++++++++++++','c','j','h','4','5','6','7','8','9','10',
                 '11','12','*','x','b','/','k','g','ii','s','u','da','el',
                 'le','que','~','000','m','thu','thir','13','14','15','16','17',
                 '18','19','20','30','th','bu','ri','w','v','al','iv','wi',
                 'la','las','t','ma','ha','mee','ne','em','ry','di','st',
                 'yr','ful','iii','bo','faire','tos','ai','en','et','sug',
                 'ga','wel','hee','hon','n','wan','ut','te','ad','hym','na']
PUNCT = set(['.', ',', '!', '?', ':', ';', '-'])
VOWELS = set(['a','e','i','o','u'])
SOMETIMES_VOWELS = VOWELS.union(['y','w'])

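# Rhyme exceptions for words whose spelling defeats the heuristic in
# tokens.py:getRhyme. Each value is [vowel_type, consonant_type]; -1 means no
# usable vowel or no final consonant. The index tables are documented in getRhyme.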
DEFINED_RHYMES = {
    "'ll": [4,1], "=er": [13,0], "the": [4,-1], 'a': [4,-1], 'we': [8,-1], 'ye': [8,-1], 'e': [8,-1],
    'zimbabwe': [7,-1], 'one': [4,2], 'two': [11,-1], 'oh': [10,-1], 'ah': [12,-1], 'i': [9,-1],
    'you': [11,-1], 'own': [10,2], 'know': [10,-1], 'do': [11,-1], 'upon': [3,2], 'whereon': [3,2],
    'world': [13,4], 'learn': [13,2], 'earn': [13,2], 'yearn': [13,2], 'of': [4,5], 'service': [4,6],
    'practice': [4,6], 'police': [8,6], 'through': [11,-1], 'tough': [4,5], 'enough': [4,5],
    'thorough': [10,-1], 'dough': [10,-1], 'rough': [4,5], 'cough': [3,5], 'snow': [10,-1],
    'w': [11,-1], 'walk': [3,7], 'talk': [3,7], 'son':[4,2], 'iron': [13,2], 'anon': [3,2],
    'full': [11,1], 'pull': [11,1], 'bull': [11,1], 'put': [11,1], 'push': [11,6], 'book': [11,7],
    'won': [4,2], 'what': [4,4], 'who': [11,-1], 'whose': [11,6], 'where': [7,0], 'there': [7,0],
    'their': [7,0], 'theirs': [7,6], 'bear': [7,0], 'wear': [7,0], 'show': [10,-1], 'tow': [10,-1],
    'sow': [10,-1], 'brow': [5,-1], 'prow': [5,-1], 'allow': [5,-1], 'laugh': [0,5],
    'elbow': [10,-1], 'window': [10,-1], 'rainbow': [10,-1], 'shadow': [10,-1], 'ancient': [1,4],
    'meant': [1,4], 'dreamt': [1,4], 'learnt': [13,4], 'hymn': [2,2], 'could': [11,4], 'should': [11,4],
    'to': [11,-1], 'was': [4,6], 'were': [13,0], 'love': [4,5], 'eye': [9,-1], 'bury': [8,-1],
    'your': [11,0], 'heart': [12,4], 'some': [4,2], 'come': [4,2], 'from': [4,2], 'become': [4,2],
    'would': [11,4], 'pour': [10,0],'figure': [13,0], 'author': [4,0], 'sure': [11,0], 'rhythm': [4,2],
    'every': [8,-1], 'very': [8,-1], 'many': [8,-1], 'any': [8,-1], 'busy': [8,-1], 'easy': [8,-1],
    'happy': [8,-1], 'live': [2,5], 'into': [11,-1], 'soul': [10,2], 'only': [8,-1], 'earth': [13,10],
    'though': [10,-1], 'thought': [3,4], 'bought': [3,4], 'brought': [3,4], 'ought': [3,4],
    'said': [1,4], 'dead': [1,4], 'word': [13,4], 'heard': [13,4], 'death': [1,10], 'head': [1,4],
    'once': [4,6], 'great': [7,4], 'young': [4,2], 'among': [4,2], 'yon': [3,2], 'wh': [-1,-1],
    'door': [10,0], 'find': [9,4], 'mind': [9,4], 'kind': [9,4], 'behind': [9,4], 'blind': [9,4],
    'wild': [9,4], 'give': [2,5], 'beauty': [8,-1], 'duty': [8,-1], 'move': [11,5], 'above': [4,5],
    'prove': [11,5], 'have': [0,5], 'whom': [11,2], 'warm': [10,2], 'done': [4,2], 'gone': [3,2],
    'behind': [9,4], 'none': [4,2], 'most': [10,4], 'ghost': [10,4], 'host': [10,4], 'post': [10,4],
    'travel': [4,1], 'broad': [3,4],'veil': [7,1],'tread': [1,4], 'bread': [1,4], 'ocean': [4,2],
    'truth': [11,10], 'human': [4,2], 'woman': [4,2], 'unto': [11,-1], 'worm': [13,4], 'blood': [4,4],
    'instead': [1,4], 'spread': [1,4], 'ahead': [1,4], 'breadth': [1,10], 'breath': [1,10],
    'valley': [8,-1], 'key': [8,-1], 'journey': [8,-1], 'honey': [8,-1], 'money': [8,-1],
    'chimney': [8,-1], 'monkey': [8,-1], 'donkey': [8,-1], 'alley': [8,-1], 'trolley': [8,-1],
    'galley': [8,-1], 'silly': [8,-1], 'lily': [8,-1], 'barley': [8,-1], 'quiet': [4,4],
    'else': [1,1], 'christian': [4,2], 'shadow': [10,-1], 'meadow': [10,-1], 'mow': [10,-1],
    'bestow': [10,-1], 'widow': [10,-1], 'friend': [1,4], 'source': [10,6], 'course': [10,6],
    'lyre': [9,0], 'curse': [13,6], 'rehearse': [13,6], 'are': [12,0], 'genuine': [2,2],
    'fly': [9,-1], 'july': [9,-1], 'reply': [9,-1], 'butterfly': [9,-1], 'ply': [9,-1],
    'supply': [9,-1], 'folk': [10,7], 'welcome': [4,2], 'wash': [3,6], 'child': [9,4],
    'deaf': [1,4], 'league': [8,7], 'plague': [7,7], 'vague': [7,7], 'overhead': [1,4]
}
DEFINED_METERS = {
    "'re": 0, "'ve": 0, 'shakespeare': 2, 'every': 2, 'leaves': 1, 'evening': 2,
    'tongue': 1, 'lovely': 2, 'quiet': 2, 'people': 2, 'something': 2,
    'beautiful': 3, 'lyre': 1, 'hymn': 1, 'forego': 2, 'therefore': 2,
    'somewhere': 2
}
for word in BRITISH_OUR:
    DEFINED_RHYMES[word] = [4,0]
lemmas/ed.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:aca3dce224618296c3e20468ecbcc85dbe80643577560ad38428f023fbdede1c
size 5616
lemmas/er.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a1100564b80079fef9223fce5dcabbdef5a76acc75d82bcc2bceb038a8b60380
size 7715
lemmas/est.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7e22c89fd4f66d10760be51319efc98ae4f28b2c78ee08ffc4b2acaa6e251589
size 2641
lemmas/ing.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2917172bddf0147b993e4aab8383a0b95c2df0fba32dbc94821079ce1cbbe5cf
size 15719
lemmas/lemmas.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4278d8ab3ff2cc6e234d9257ca360e03cd0b3185d085067e9701e2d8b57033ed
size 213120
lemmas/s.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:195dc16152a9d6017d298dcba3dda551dda7ee39224fdb784f58b37f38901e5f
size 4658
model.py
ADDED
@@ -0,0 +1,433 @@
import os
import numpy as np
import tensorflow as tf
from tensorflow import keras
from keras.layers import Dense, Flatten, Dropout, Embedding,\
    Add, MultiHeadAttention, LayerNormalization, Input, Softmax
import sys

from constants import *
from tokens import pretty_tokens, rhymeMeterFromTokens

EPOCHS = 10
WARMUP_STEPS = 800
EMBED_DIM = 512
TRANSFORMER_LAYERS = 8
TRANSFORMER_DFF = 1024
RHYME_METER_DFF = 64
TRANSFORMER_HEADS = 4
VAL_SPLIT = 0.2
BATCH_SIZE = 256
SAVE_AT_END = False
VERBOSE = False
TRAINING = True

if '--epochs' in sys.argv:
    EPOCHS = int(sys.argv[sys.argv.index('--epochs')+1])
if '--warmup-steps' in sys.argv:
    WARMUP_STEPS = int(sys.argv[sys.argv.index('--warmup-steps')+1])
if '--embed-dim' in sys.argv:
    EMBED_DIM = int(sys.argv[sys.argv.index('--embed-dim')+1])
if '--transformer-layers' in sys.argv:
    TRANSFORMER_LAYERS = int(sys.argv[sys.argv.index('--transformer-layers')+1])
if '--transformer-dff' in sys.argv:
    TRANSFORMER_DFF = int(sys.argv[sys.argv.index('--transformer-dff')+1])
if '--rhyme-meter-dff' in sys.argv:
    RHYME_METER_DFF = int(sys.argv[sys.argv.index('--rhyme-meter-dff')+1])
if '--transformer-heads' in sys.argv:
    TRANSFORMER_HEADS = int(sys.argv[sys.argv.index('--transformer-heads')+1])
if '--val-split' in sys.argv:
    VAL_SPLIT = float(sys.argv[sys.argv.index('--val-split')+1])
if '--batch-size' in sys.argv:
    BATCH_SIZE = int(sys.argv[sys.argv.index('--batch-size')+1])
if '--save-at-end' in sys.argv:
    SAVE_AT_END = True
if '--verbose' in sys.argv:
    VERBOSE = True
if '--load' in sys.argv:
    TRAINING = False

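# Context length: n-gram models see a fixed window of NGRAM_N-1 tokens;
# transformer variants see up to TRANSFORMER_N tokens.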
N = NGRAM_N if MODEL_TYPE == 'n' else TRANSFORMER_N
VOCAB = list(np.load('lemmas/lemmas.npy'))
TEST_PROMPT = '<title> stop =ing by woods on a snowy evening <newline> '+\
    'whose woods these are i think i know <newline> '+\
    'his house is in the village though <newline> he'

def sampleVocab(dist, temperature):
    temperature = 1e-8 if temperature == 0 else temperature
    dist = np.power(dist, 1.0/temperature)  # sharpen as temperature -> 0 (the zero guard above implies a divisor)
    dist /= np.sum(dist)
    sample = np.random.choice(np.arange(VOCAB_SIZE), p=dist)
    return sample

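# Autoregressive sampling: start from the <title> token (or an encoded prompt)
# and repeatedly append the next sampled token.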
def genTokens(model, tokens, temperature=0.7, prompt=None):
    res = [model.vocab.index(TITLE.lower()[1:-1])]
    if prompt is not None:
        res = [model.vocab.index(x) for x in prompt.split(' ') if x in model.vocab]
    for _ in range(tokens):
        pred = model.generate(res, temperature)
        assert pred is not None
        res.append(pred)
    res = list(map(lambda token: model.vocab[token], res))
    return res

class LinearModel(keras.Model):
    def __init__(self):
        super(LinearModel, self).__init__()
        self.vocab = VOCAB
        self.seq = keras.Sequential([
            Input(shape=(NGRAM_N-1, VOCAB_SIZE)),
            Flatten(),
            Dense(1024, activation='relu'),
            Dense(1024, activation='relu'),
            Dense(2048, activation='relu'),
            Dropout(0.2),
            Dense(VOCAB_SIZE, activation='softmax')
        ])

    def call(self, input):
        x = tf.one_hot(input, VOCAB_SIZE)
        x = self.seq(x)
        return x

    def generate(self, fullContext, temperature=0.7):
        context = fullContext[-(N-1):]
        while len(context) > NGRAM_N-1:
            context.pop(0)
        while len(context) < NGRAM_N-1:
            context.append(-1)
        context = np.asarray([context])
        pred = self.call(context)[0]
        pred = sampleVocab(pred, temperature)
        return pred


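# Standard sinusoidal positional encoding ("Attention Is All You Need"):
# concatenated sine/cosine waves whose wavelengths form a geometric progression.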
def positional_encoding(length, depth):
    depth = depth / 2
    positions = np.arange(length)[:, np.newaxis]
    depths = np.arange(depth)[np.newaxis, :]/depth
    angle_rates = 1 / (10000**depths)
    angle_rads = positions * angle_rates
    pos_encoding = np.concatenate(
        [np.sin(angle_rads), np.cos(angle_rads)],
        axis=-1)
    return tf.cast(pos_encoding, dtype=tf.float32)

class InputEmbedding(keras.layers.Layer):
    def __init__(self):
        super().__init__()
        self.embed = Embedding(input_dim=VOCAB_SIZE+1, output_dim=EMBED_DIM)
        self.pos = positional_encoding(length=TRANSFORMER_N, depth=EMBED_DIM)
        self.add = Add()
        self.dropout = Dropout(0.1)
    def call(self, input):
        length = tf.shape(input)[1]
        x = self.embed(input)
        x *= tf.math.sqrt(tf.cast(EMBED_DIM, tf.float32))
        x = self.add([x, self.pos[tf.newaxis, :length, :]])
        x = self.dropout(x)
        return x

class AttentionBlock(keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__()
        self.mha = MultiHeadAttention(**kwargs)
        self.dropout = Dropout(0.1)
        self.norm = LayerNormalization()
        self.add = Add()
    def call(self, input):
        x = self.mha(query=input, value=input, key=input, use_causal_mask=True)
        x = self.dropout(x)
        x = self.add([input, x])
        x = self.norm(x)
        return x

class FeedForward(keras.layers.Layer):
    def __init__(self, dff):
        super().__init__()
        self.seq = keras.Sequential([
            Dense(dff, activation='relu'),
            Dense(EMBED_DIM),
            Dropout(0.1)
        ])
        self.add = Add()
        self.norm = LayerNormalization()
    def call(self, input):
        x = self.add([input, self.seq(input)])
        x = self.norm(x)
        return x

class Decoder(keras.layers.Layer):
    def __init__(self, *, num_layers, num_heads, dff):
        super(Decoder, self).__init__()
        attention = []
        for _ in range(num_layers):
            attention.append(AttentionBlock(num_heads=num_heads, key_dim=EMBED_DIM, dropout=0.1))
        self.attn_seq = keras.Sequential(attention)
        self.ffn = FeedForward(dff)
    def call(self, input):
        x = self.attn_seq(input)
        x = self.ffn(x)
        return x

class TransformerModel(keras.Model):
    def __init__(self, *, num_layers=TRANSFORMER_LAYERS, num_heads=TRANSFORMER_HEADS, dff=TRANSFORMER_DFF):
        super(TransformerModel, self).__init__()
        self.vocab = VOCAB
        self.embed = InputEmbedding()
        self.decoder = Decoder(num_layers=num_layers, num_heads=num_heads, dff=dff)
        self.out = Dense(VOCAB_SIZE, activation='softmax')

    def call(self, input):
        x = self.embed(input) # context x embedding
        x = self.decoder(x) # context x embedding
        x = self.out(x) # context x vocab size
        try:
            del x._keras_mask
        except AttributeError:
            pass

        return x

    def generate(self, fullContext, temperature=0.7):
        context = fullContext[-N:]
        lastToken = len(context)-1
        while len(context) > TRANSFORMER_N:
            context.pop(0)
        while len(context) < TRANSFORMER_N:
            context.append(-1)
        context = np.asarray([context])+1
        pred = self.call(context)[0]
        pred = pred[lastToken]
        pred = sampleVocab(pred, temperature)
        return pred


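# Per-timestep rhyme/meter feature layout (built by processRhymeStack in
# tokens.py): RHYME_STACK_SIZE-1 vowel classes of earlier line endings, their
# consonant classes, rhyme-match flags against the latest ending, then
# METER_STACK_SIZE syllable counts.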
def rhyme_meter_encoding(input):
    vowels = input[:,:,:RHYME_STACK_SIZE-1]
    consonants = input[:,:,RHYME_STACK_SIZE-1:(RHYME_STACK_SIZE-1)*2]
    rhyme_match = input[:,:,(RHYME_STACK_SIZE-1)*2:(RHYME_STACK_SIZE-1)*3]
    vowels = tf.cast(vowels, tf.int8)
    consonants = tf.cast(consonants, tf.int8)
    vowels = tf.one_hot(vowels, depth=VOWEL_TYPES)
    consonants = tf.one_hot(consonants, depth=CONSONANT_TYPES)
    vowels = tf.reshape(vowels, shape=(tf.shape(vowels)[0], tf.shape(vowels)[1], -1))
    consonants = tf.reshape(consonants, shape=(tf.shape(consonants)[0], tf.shape(consonants)[1], -1))
    meter = input[:,:,-METER_STACK_SIZE:]
    vowels = tf.cast(vowels, tf.float32)
    consonants = tf.cast(consonants, tf.float32)
    rhyme_match = tf.cast(rhyme_match, tf.float32)
    meter = tf.cast(meter, tf.float32)
    rhyme = tf.concat([vowels, consonants, rhyme_match], axis=2)
    return rhyme, meter

class RhymeMeterLayer(keras.layers.Layer):
    def __init__(self):
        super().__init__()
        self.dense_r1 = Dense(RHYME_METER_DFF, activation='relu')
        self.dense_m1 = Dense(RHYME_METER_DFF//2, activation='relu')
        self.dense_r2 = Dense(RHYME_METER_DFF, activation='relu')
        # self.dense_m2 = Dense(RHYME_METER_DFF//2, activation='relu')
        self.dense_3 = Dense(RHYME_METER_DFF*2, activation='relu')
        self.dense_final = Dense(VOCAB_SIZE)
    def call(self, input):
        rhyme, meter = rhyme_meter_encoding(input)
        rhyme = self.dense_r1(rhyme)
        rhyme = self.dense_r2(rhyme)
        meter = self.dense_m1(meter)
        # meter = self.dense_m2(meter)
        x = tf.concat([rhyme, meter], axis=2)
        x = self.dense_3(x)
        x = self.dense_final(x)
        return x

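# "Bard" model: transformer logits and rhyme/meter logits are summed before the
# softmax, so rhyme and meter context can bias next-token prediction.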
class BardModel(keras.Model):
    def __init__(self, *, num_layers=TRANSFORMER_LAYERS, num_heads=TRANSFORMER_HEADS, dff=TRANSFORMER_DFF):
        super(BardModel, self).__init__()
        self.vocab = VOCAB
        self.tl = VOCAB.index(TITLE.lower()[1:-1])
        self.rhyme_types = max(VOWEL_TYPES, CONSONANT_TYPES)
        self.embed = InputEmbedding()
        self.decoder = Decoder(num_layers=num_layers, num_heads=num_heads, dff=dff)
        self.transformer_pred = Dense(VOCAB_SIZE)
        self.rhyme_meter_pred = RhymeMeterLayer()
        self.add = Add()
        self.softmax = Softmax()

    def call(self, input):
        x = self.embed(input[0])
        x = self.decoder(x)
        x = self.transformer_pred(x)
        try:
            del x._keras_mask
        except AttributeError:
            pass

        rhyme_meter_x = self.rhyme_meter_pred(input[1])
        x = self.add([x, rhyme_meter_x])
        x = self.softmax(x)
        return x

    def generate(self, fullContext, temperature=0.7):
        context = fullContext[-N:]
        lastToken = len(context)-1
        while len(context) > TRANSFORMER_N:
            context.pop(0)
        while len(context) < TRANSFORMER_N:
            context.append(-1)
        context = np.asarray([context])+1
        rm = rhymeMeterFromTokens(fullContext, len(fullContext), self.tl, self.vocab)
        rm = np.asarray([rm])
        pred = self.call([context, rm])[0]
        pred = pred[lastToken]
        pred = sampleVocab(pred, temperature)
        return pred


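# Learning-rate schedule from "Attention Is All You Need": linear warmup for
# warmup_steps, then inverse-square-root decay, scaled by d_model**-0.5.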
class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, d_model, warmup_steps=WARMUP_STEPS):
        super().__init__()

        self.d_model = d_model
        self.d_model = tf.cast(self.d_model, tf.float32)

        self.warmup_steps = warmup_steps

    def __call__(self, step):
        step = tf.cast(step, dtype=tf.float32)
        arg1 = tf.math.rsqrt(step)
        arg2 = step * (self.warmup_steps ** -1.5)

        return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)


def sparse_loss(y_true, y_pred):
    loss_obj = keras.losses.SparseCategoricalCrossentropy(ignore_class=-1, reduction='none')
    loss = loss_obj(y_true, y_pred)
    return loss
def sparse_perplexity(y_true, y_pred):
    return tf.math.exp(tf.math.reduce_mean(sparse_loss(y_true, y_pred)))

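# Entry point. Training (the default): e.g. python model.py b --epochs 10
# Inference with the weights shipped in this commit: python model.py b --load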
if __name__ == '__main__':
    fname = {'n': 'inputs/ngram_train.npz',
             't': 'inputs/transformer_train.npz',
             'b': 'inputs/bard_train.npz'
             }[MODEL_TYPE]
    print("Loading data from", fname)
    loaded = np.load(fname)
    train_x = loaded['x']
    train_y = loaded['y']
    if MODEL_TYPE == 'b':
        train_x = [tf.convert_to_tensor(train_x), tf.convert_to_tensor(loaded['rm'])] # rhyme and syllables
    if MODEL_TYPE == 'n':
        train_x = tf.convert_to_tensor(train_x, tf.int32)
    del loaded

    if TRAINING and VERBOSE:
        if MODEL_TYPE != 'b':
            print("X:", train_x[10:14])
        else:
            print("X:", train_x[0][10:14])
            print("RM:", train_x[1][10:14][1])
        print("Y:", train_y[10:14])
        if MODEL_TYPE != 'b':
            print("X shape:", train_x.shape)
        print("Y shape:", train_y.shape)

    print("Initializing model")
    models = {'n': LinearModel, 't': TransformerModel, 'b': BardModel}
    model = models[MODEL_TYPE]()
    if MODEL_TYPE != 'b':
        res = model(train_x[:1])
    else:
        x0 = train_x[0][:1]
        x1 = train_x[1][:1]
        res = model([x0, x1])
    if VERBOSE:
        print(model)
        print(res)
        print(model.summary())

    if TRAINING:
        print("Compiling model")
        learning_rate = CustomSchedule(EMBED_DIM)
        model.compile(optimizer=keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9),
                      loss=sparse_loss, metrics=[sparse_perplexity])

        print("Generating sample from baseline")
        print(pretty_tokens(genTokens(model, 25)))

        print("Training model")
        min_perplexity = None
        if not os.path.exists('saved_models'):
            os.mkdir('saved_models')
        class TrainCallback(keras.callbacks.Callback):
            def on_epoch_end(self, epoch, logs=None):
                global min_perplexity
                perplexity = logs['val_sparse_perplexity'] if VAL_SPLIT > 0 else logs['sparse_perplexity']
                print("\rGenerating sample from model in training: "+
                      "epoch "+str(epoch+1)+", perplexity "+str(round(perplexity, 2)), end='')
                print(pretty_tokens(genTokens(model, 75)))
                if (min_perplexity is None or perplexity <= min_perplexity) and not SAVE_AT_END:
                    min_perplexity = perplexity
                    print("Saving model weights")
                    model.save_weights('saved_models/'+MODEL_TYPE+'_model.h5')

        model.fit(train_x, train_y,
                  batch_size=BATCH_SIZE, validation_split=VAL_SPLIT, epochs=EPOCHS,
                  callbacks=[TrainCallback()])

        if SAVE_AT_END:
            print("Saving final model weights")
            model.save_weights('saved_models/'+MODEL_TYPE+'_model.h5')

        print("Generating samples from final model")
        if VERBOSE:
            for i in range(10):
                print(pretty_tokens(genTokens(model, 100)))
        print(pretty_tokens(genTokens(model, 150, prompt=TEST_PROMPT)))
        print(pretty_tokens(genTokens(model, 500)))
        print(pretty_tokens(genTokens(model, 500)))

    else:
        del train_x
        del train_y
        print("Loading weights")
        model.load_weights('saved_models/'+MODEL_TYPE+'_model.h5')

        temp = 0.7  # set once, outside the loop, so the 't' command persists
        while True:
            print("Commands:\ng: generate sample with 250 tokens\nl: generate sample with custom length\np: generate sample with prompt\nt: set temperature\nq: quit")
            cmd = input("Enter command: ")
            try:
                if cmd == 'g':
                    print("Generating sample...")
                    print(pretty_tokens(genTokens(model, 250, temperature=temp)))
                if cmd == 'l':
                    length = int(input("Enter length: "))
                    print("Generating sample...")
                    print(pretty_tokens(genTokens(model, length, temperature=temp)))
                if cmd == 'p':
                    prompt = ""
                    print("Enter prompt as tokens separated by spaces and newlines.")
                    print("Example: <title> stop =ing by woods on a snowy evening\nwhose woods these are i think i know")
                    print("All tokens not in the vocabulary will be ignored.")
                    while not prompt.endswith('\n\n\n'):
                        prompt += input("")+'\n'
                    while prompt.startswith(' ') or prompt.startswith('\n'):
                        prompt = prompt[1:]
                    while prompt.endswith(' ') or prompt.endswith('\n'):
                        prompt = prompt[:-1]
                    prompt = prompt.replace('\n', NEWLINE.lower())
                    length = int(input("Enter length: "))
                    print("Generating sample...")
                    print(pretty_tokens(genTokens(model, length, temperature=temp, prompt=prompt)))
                if cmd == 't':
                    print("Current temperature:", temp)
                    temp = float(input("New temperature: "))
                    print("Temperature set to", temp)
                if cmd == 'q':
                    sys.exit(0)
            except Exception as e:
                print("Error:", e)
saved_models/b_model.h5
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f03eab3fd3dd13a08aadbe6a03e7b7e0c10ed7d38534484bc58e126634879a6f
size 157786768
tokens.py
ADDED
@@ -0,0 +1,534 @@
import numpy as np
import os
import sys
from constants import *
if __name__ == '__main__':
    from threading import Thread

N_THREADS = 32
if '--n_threads' in sys.argv:
    N_THREADS = int(sys.argv[sys.argv.index('--n_threads')+1])

if __name__ == '__main__':
    if not os.path.exists('lemmas'):
        os.mkdir('lemmas')

    file = open("inputs/join.txt" if not KAGGLE else "inputs/join-kaggle.txt", "r")
    text = file.read()
    file.close()

    tokens = text.split(" ")
    tokens = [x for x in tokens if x != '']
    print("Total number of tokens:", len(tokens))
    print("Counting tokens")
    counts = {}
    for token in tokens:
        if not token in counts:
            counts[token] = 0
        counts[token] += 1
    words = list(counts.keys())
    words.sort(reverse=True, key=lambda word: counts[word])

    for token in BANNED_TOKENS:
        if token in words:
            words.remove(token)
            words.append(token)
    counts['<unk>'] = 0
    for word in words:
        if word in words[:VOCAB_SIZE]:
            continue
        counts['<unk>'] += counts[word]
    words = list(counts.keys())
    words.sort(reverse=True, key=lambda word: counts[word])
    for token in BANNED_TOKENS:
        if token in words:
            words.remove(token)
            words.append(token)

    vocab = set(words[:VOCAB_SIZE])
else:
    print("Loading vocab")
    vocab = set(list(np.load('lemmas/lemmas.npy')))

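# Detokenizer: merges suffix tokens (=s, =ed, =er, =est, =ing, =nt) back onto
# their stems, consulting the lemma dictionaries for irregular spellings, and
# fixes spacing around punctuation and quotes.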
def pretty_tokens(tokens, mask=True):
    s_dict = np.load('lemmas/s.npy', allow_pickle=True).item()
    ed_dict = np.load('lemmas/ed.npy', allow_pickle=True).item()
    er_dict = np.load('lemmas/er.npy', allow_pickle=True).item()
    est_dict = np.load('lemmas/est.npy', allow_pickle=True).item()
    ing_dict = np.load('lemmas/ing.npy', allow_pickle=True).item()
    dicts = {'=s': s_dict, '=ed': ed_dict, '=er': er_dict, '=est': est_dict, '=ing': ing_dict}
    res = []
    i = 0
    def includeSpace(this):
        nonlocal res
        quote = set(["'", '"'])
        nospace = set(['\n','-'])
        prev = res[len(res)-1] if len(res) > 0 else None
        prev2 = res[len(res)-2] if len(res) > 1 else None
        space = not prev in nospace\
            and not this in PUNCT and not this == '\n'\
            and not (this.startswith("'") and this != "'")
        if prev in quote and not prev2 in PUNCT:
            space = False
        elif this in quote and prev in PUNCT:
            space = False
        return space
    while i < len(tokens):
        this = tokens[i]
        if this == NEWLINE.lower()[1:-1]:
            this = '\n'
        elif this == TITLE.lower()[1:-1]:
            this = '\n ༄༅༅ '
        elif mask and not this in vocab:
            this = " <unk>"
            if not includeSpace(this):
                this = "<unk>"
            res.append(this)
            i += 1
            continue
        if i+1 < len(tokens):
            next = tokens[i+1]
            while next.startswith('='):
                if next == "=nt":
                    if tokens[i].endswith('n'):
                        this = this[:-1]
                    if tokens[i] == 'will':
                        this = 'wo'
                    elif tokens[i] == 'shall':
                        this = 'sha'
                    this = this+"n't"
                else:
                    if tokens[i] in dicts[next]:
                        this = dicts[next][this]
                    else:
                        if next[1] == 'e' or next[1] == 'i':
                            if this.endswith('e'):
                                this = this[:-1]
                            elif this.endswith('c'):
                                this = this+'k'
                            if this.endswith('y') and next[1] == 'e' and len(this) > 2 and not this[-2] in VOWELS:
                                this = this[:-1]+'i'
                        if next[1] == 's':
                            if this.endswith('s') or this.endswith('sh') or this.endswith('x') or this.endswith('ch'):
                                this = this+'e'
                            if this.endswith('y') and len(this) > 2 and not this[-2] in VOWELS:
                                this = this[:-1]+'ie'

                        this = this+next[1:]
                i += 1
                next = tokens[i+1] if i+1 < len(tokens) else ''
        if this.startswith('='):
            this = this[1:]
        elif includeSpace(this):
            this = " "+this
        res.append(this)
        i += 1
    res = ''.join(res)
    res = res[1:] if res.startswith(' ') else res
    return res

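# Spelling-based rhyme classifier (no pronunciation dictionary): returns
# [vowel_type, consonant_type] for the final word of a line.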
def getRhyme(line):
    # rhyme format:
    # final vowel (short AEIO, schwa, long AEIOU, OW, OI, A/schwa before R; total 14)
    # final consonant (R, L, N/M/NG, P/B, T/D, F/V, S/SH/Z/ZH, K/G, CH/J, TH; total 10)
    if line is None or len(line) == 0:
        return [-1, -1]
    nl = NEWLINE.lower()[1:-1]
    tl = TITLE.lower()[1:-1]
    if line[0] == tl:
        return [-1, -1]
    while line[-1] == nl or line[-1] in PUNCT or line[-1] == '"' or line[-1] == "'" or line[-1] is None:
        line = line[:-1]
        if len(line) == 0:
            return [-1, -1]
    word = line[-1]+''
    long_vowel = False
    vowel_type = None
    vowel_map = {'a': 0, 'e': 1, 'i': 2, 'o': 3, 'u': 4, 'ow': 5, 'ou': 5, 'oi': 6, 'oy': 6,
                 'ay': 7, 'ai': 7, 'au': 3, 'aw': 3, 'ea': 8, 'ee': 8, 'eu': 11, 'ew': 11,
                 'oa': 10, 'oo': 11, 'y': 9, 'ey': 7, 'ei': 9}

    # vowel type format:
    # 0: A, 1: E, 2: I, 3: O, 4: U, 5: OW, 6: OI
    # short U is schwa
    # OW, OI are always long
    # before R: short E/I become schwa, schwa/short A get their own vowel type, short O becomes long O
    consonant_type = -1
    cons_map = {'r': 0, 'l': 1, 'n': 2, 'm': 2, 'ng': 2,
                'p': 3, 'b': 3, 't': 4, 'd': 4, 'f': 5,
                'v': 5, 's': 6, 'sh': 6, 'z': 6, 'zh': 6,
                'th': 9, 'k': 7, 'ch': 8, 'j': 8}
    # consonant type format:
    # 0: R, 1: L, 2: N/M/NG, 3: P/B, 4: T/D, 5: F/V, 6: S/SH/Z/ZH, 7: K/G, 8: CH/J, 9: TH
    # total 10 consonant types

    # full vowel type list: (L=long, S=short, R=before R)
    # 0: AS (bat), 1: ES (bet), 2: IS (bit), 3: OS (bot), 4: US/schwa (but)
    # 5: OW (bout), 6: OI (boil)
    # 7: AL (bait), 8: EL (beat), 9: IL (bite), 10: OL (boat), 11: UL (boot)
    # 12: AR (bar), 13: schwa_R (butter, bird, burn)
    # total 14 vowel types
    def getVowel(type, isLong, beforeR):
        if beforeR and not isLong:
            if type == 0:
                return 12
            if type == 1 or type == 2 or type == 4:
                return 13
            if type == 3:
                return 10
        if isLong and 0 <= type <= 4:
            return type+7
        return type

    lock_consonant = -1
    if len(line) > 1:
        if word == '=ed':
            if line[-2].endswith('t') or line[-2].endswith('d'):
                return [4, 4]
            lock_consonant = 4
            word = line[-2]
        if word == '=s' or word == "'s":
            if line[-2].endswith('s') or line[-2].endswith('z') or line[-2].endswith('ch') or line[-2].endswith('sh') or line[-2].endswith('x'):
                return [4, 6]
            lock_consonant = 6
            word = line[-2]
        elif word == "'re":
            lock_consonant = 0
            word = line[-2]
        elif word == "'ve":
            lock_consonant = 5
            word = line[-2]
        elif word == "'ll":
            lock_consonant = 1
            word = line[-2]
        elif word == "'d":
            lock_consonant = 4
            word = line[-2]
        elif word == "'m":
            lock_consonant = 2
            word = line[-2]
        elif word == "=nt'":
            lock_consonant = 4
            word = line[-2]
    if word in DEFINED_RHYMES:
        vowel_type = DEFINED_RHYMES[word][0]
        consonant_type = DEFINED_RHYMES[word][1] if lock_consonant == -1 else lock_consonant
        return [vowel_type, consonant_type]

    if word.endswith('o'):
        return [10, lock_consonant]
    if word.endswith('bble') or word.endswith('ggle'):
        return [4, 1 if lock_consonant == -1 else lock_consonant]
    if word.endswith('old'):
        return [10, 1 if lock_consonant == -1 else lock_consonant]
    if word.endswith('ance'):
        return [0, 6 if lock_consonant == -1 else lock_consonant]
    if word.endswith('ense') or word.endswith('ence'):
        return [1, 6 if lock_consonant == -1 else lock_consonant]
    if word.endswith('ince'):
        return [2, 6 if lock_consonant == -1 else lock_consonant]
    if word.endswith('ture') or word.endswith('sure'):
        return [13, 0 if lock_consonant == -1 else lock_consonant]
    if word.endswith('all'):
        return [3, 1 if lock_consonant == -1 else lock_consonant]
    if word.endswith('row') or word.endswith('low'):
        return [10, lock_consonant]
    if word.endswith('le') and len(word) >= 3 and not word[-3] in VOWELS:
        return [4, 1 if lock_consonant == -1 else lock_consonant]
    if word.endswith('on') and len(word) > 3 and not word.endswith('oon'):
        return [4, 2 if lock_consonant == -1 else lock_consonant]
    if word.endswith('al') and len(word) > 3 and not word.endswith('eal'):
        return [4, 1 if lock_consonant == -1 else lock_consonant]
    if word.endswith('ous'):
        return [4, 6 if lock_consonant == -1 else lock_consonant]
    if word.endswith('ly'):
        return [8, -1 if lock_consonant == -1 else lock_consonant]
    if word.endswith('ward'):
        return [13, 4 if lock_consonant == -1 else lock_consonant]

    if word.endswith('e'):
        long_vowel = True
        word = word[:-1]
    if lock_consonant == -1:
        if word[-2:] in cons_map:
            consonant_type = cons_map[word[-2:]]
        elif word[-1:] in cons_map:
            consonant_type = cons_map[word[-1:]]
        elif word[-1] == 'c' and long_vowel:
            consonant_type = cons_map['s']
        elif word[-1] == 'g' and long_vowel:
            consonant_type = cons_map['j']
    else:
        consonant_type = lock_consonant

    lock_r = False
    if not word[-1] in SOMETIMES_VOWELS:
        while not word[-1] in SOMETIMES_VOWELS:
            if word.endswith('igh'):
                return [9, consonant_type]
            if word[-1] == 'r':
                lock_r = True
            elif lock_r:
                lock_r = False
            word = word[:-1]
            if word == '':
                return [8, lock_consonant]
    if word[-2:] in vowel_map:
        vowel_type = vowel_map[word[-2:]]
    elif word[-1:] in vowel_map:
        vowel_type = vowel_map[word[-1:]]

    vowel_type = getVowel(vowel_type, long_vowel, consonant_type == 0 or lock_r)
    return [vowel_type, consonant_type]
def pretty_rhyme(rhyme):
    v_map = ['bat', 'bet', 'bit', 'bot', 'but', 'pout', 'boil', 'bait', 'beat', 'bite', 'boat', 'boot', 'bar', 'sir']
    c_map = ['R', 'L', 'N/M/NG', 'P/B', 'T/D', 'F/V', 'S/SH/Z/ZH', 'K/G', 'CH/J', 'TH']
    return "Rhyme is " +\
        (v_map[rhyme[0]] if rhyme[0] != -1 else '--') + ' ' + (c_map[rhyme[1]] if rhyme[1] != -1 else 'ø')


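# Syllable estimate: collapse common vowel digraphs to a single letter, handle
# silent final 'e', consonant+'le', '-tion'/'-sion', and suffix tokens, then
# count the remaining vowel letters.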
def getMeter(line):
    if line is None:
        return 0
    res = 0
    nl = NEWLINE.lower()[1:-1]
    tl = TITLE.lower()[1:-1]
    for i in range(len(line)):
        word = line[i]
        if word == nl or word == tl or word is None:
            continue
        if word in DEFINED_METERS:
            res += DEFINED_METERS[word]
            continue
        if word == '=ed' and i > 0:
            if line[i-1].endswith('t') or line[i-1].endswith('d') or line[i-1].endswith('te') or line[i-1].endswith('de'):
                res += 1
            continue
        if word == '=s' and i > 0:
            if line[i-1].endswith('s') or line[i-1].endswith('z') or line[i-1].endswith('ch') or line[i-1].endswith('sh') or line[i-1].endswith('x'):
                res += 1
            continue
        if word.endswith('le') and len(word) >= 3 and not word[-3] in VOWELS:
            res += 1 # to account for the dropped e
        removed_e = False
        if word.endswith('e'):
            word = word[:-1]
            removed_e = True
        if word.endswith('y') and len(word) > 2 and not word[-2] in VOWELS:
            word = word[:-1]+'i'
        word = word.replace('ea','i').replace('ee','i')
        word = word.replace('ai','i').replace('au','o')
        word = word.replace('eu','u')
        word = word.replace('ei','i').replace('ie','i')
        word = word.replace('oa','o').replace('ou','o')
        word = word.replace('oi','o').replace('oo','u')
        if word.endswith('tion') or word.endswith('sion') or word.endswith('tian'):
            word = word[:-4]+'shun'
        this_count = 0
        for vowel in VOWELS:
            this_count += word.count(vowel)
        if removed_e and this_count == 0:
            this_count = 1
        res += this_count
    return res

def lastLine(tokens, endl):
    res = []
    nl = NEWLINE.lower()[1:-1]
    i = endl-1
    while i > 0:
        if tokens[i] == nl:
            break
        i -= 1
    res = tokens[i:endl]
    if len(res) == 0:
        res = tokens[:endl]
    return res
def processRhymeStack(rhyme_stack):
    prev = rhyme_stack[:-1].flatten(order='F')
    lastRhyme = rhyme_stack[-1]
    res = np.zeros(RHYME_STACK_SIZE-1)
    if lastRhyme[0] != -1:
        for i in range(RHYME_STACK_SIZE-1):
            if rhyme_stack[i][0] == lastRhyme[0]:
                res[i] = 1
                if rhyme_stack[i][1] == lastRhyme[1]:
                    res[i] = 2
    res = np.concatenate([prev, res])
    return res
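# Walk a token stream line by line, keeping rolling stacks of the last
# RHYME_STACK_SIZE line-final rhymes and METER_STACK_SIZE syllable counts;
# both stacks reset at every <title> token.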
def processRhymeMeter(split):
    in_title = False
    meter = []
    rhymes = []
    meter_stack = np.zeros(METER_STACK_SIZE, np.int8)
    rhyme_stack = np.zeros((RHYME_STACK_SIZE, 2), np.int8) - 1
    tl = TITLE.lower()[1:-1]
    nl = NEWLINE.lower()[1:-1]
    for i in range(len(split)):
        line = lastLine(split, i)
        if split[i] == tl:
            in_title = True
            meter_stack = np.zeros(METER_STACK_SIZE, np.int8)
            rhyme_stack = np.zeros((RHYME_STACK_SIZE, 2), np.int8) - 1
            meter.append(meter_stack.copy())
            rhymes.append(processRhymeStack(rhyme_stack))
            continue
        elif in_title and split[i] == nl:
            in_title = False
            meter_stack = np.zeros(METER_STACK_SIZE, np.int8)
            meter_stack[-1] = getMeter(line)
            meter.append(meter_stack.copy())
            rhyme_stack = np.zeros((RHYME_STACK_SIZE, 2), np.int8) - 1
            rhyme_stack[-1] = np.array(getRhyme(line), np.int8)
            rhymes.append(processRhymeStack(rhyme_stack))
            meter_stack = np.zeros(METER_STACK_SIZE, np.int8)
            rhyme_stack = np.zeros((RHYME_STACK_SIZE, 2), np.int8) - 1
            continue
        if not in_title and split[i] == nl:
            rhymes.append(processRhymeStack(rhyme_stack))
            meter.append(meter_stack.copy())
            if split[i-1] != nl:
                rhyme_stack = np.roll(rhyme_stack, -1, axis=0)
                rhyme_stack[-1] = np.array(getRhyme(line), np.int8)
                meter_stack = np.roll(meter_stack, -1, axis=0)
                meter_stack[-1] = getMeter(line)
            else:
                meter_stack[-1] = getMeter(line)
                rhyme_stack[-1] = np.array(getRhyme(line), np.int8)
            continue  # keep exactly one feature row per token
        rhymes.append(processRhymeStack(rhyme_stack))
        meter.append(meter_stack.copy())
    return [rhymes, meter]

def rhymeMeterFromTokens(tokens, endl, tl, vocab=None):
    # used as input for model
    res = []
    start = endl-1
    if len(tokens) >= endl:
        while start > 0 and tokens[start] != tl:
            start -= 1
        lines = tokens[start:endl]
        while len(lines) < TRANSFORMER_N:
            lines.append(None)
        input_lines = lines if vocab is None else [(vocab[x] if (x is not None and 0 <= x < VOCAB_SIZE) else None) for x in lines]
        rhymes, meter = processRhymeMeter(input_lines)
        rhymes = rhymes[-TRANSFORMER_N:] # context x RHYME_STACK_SIZE x 2
        meter = meter[-TRANSFORMER_N:] # context x METER_STACK_SIZE
        rhymes = np.array(rhymes)
        meter = np.array(meter)
        res = np.concatenate([rhymes, meter], axis=1) # context x (RHYME_STACK_SIZE*2 + METER_STACK_SIZE)
    return res

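# Preprocessing entry point: builds the vocabulary, inserts masked dividers
# between poems, computes rhyme/meter features (bard model only), and writes
# inputs/*_train.npz plus lemmas/lemmas.npy.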
if __name__ == '__main__':
    N = NGRAM_N if MODEL_TYPE == 'n' else TRANSFORMER_N+1
    for i in range(N-1):
        tokens.append(None)
    words.remove('<unk>')
    print({word: counts[word] for word in words[:VOCAB_SIZE]})
    title_token = words.index(TITLE.lower()[1:-1])
    newline_token = words.index(NEWLINE.lower()[1:-1])

    print("Splitting poems with masked dividers")
    mask_list = [-1]*N
    splits = []
    chunk_size = len(tokens)//N_THREADS
    for i in range(N_THREADS):
        splits.append(
            tokens[i*chunk_size : (i+1)*chunk_size if i < N_THREADS-1 else len(tokens)])


    results = [None] * N_THREADS
    threads = []

    def add_dividers(thread_index, split):
        i = 1
        while i < len(split):
            if split[i] == TITLE.lower()[1:-1]:  # tokens are still strings here, so compare the title string rather than its index
                split = split[:i] + mask_list + split[i:]
                i += N+5
            i += 1
        results[thread_index] = split
        return split
    for i in range(N_THREADS):
        t = Thread(target=add_dividers, args=(i, splits[i],))
        threads.append(t)
        t.start()
    tokens = []
    for i in range(N_THREADS):
        threads[i].join()
        tokens += results[i]

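    # Rhyme/meter features are computed in parallel; each thread's slice is
    # aligned to a <title> boundary so line context never crosses poems.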
    if MODEL_TYPE == 'b':
        print("Computing rhyme and meter information")
        split_token_marks = []
        split_size = len(tokens)//N_THREADS
        for i in range(N_THREADS+1):
            split_token_marks.append(split_size*i)
        for i in range(1, N_THREADS):
            while tokens[split_token_marks[i]] != TITLE.lower()[1:-1]:
                split_token_marks[i] += 1
                if split_token_marks[i] >= len(tokens):
                    break
        meter_data = []
        rhymes_data = []
        split_token_marks[-1] = len(tokens)
        split_tokens = [tokens[split_token_marks[i]:split_token_marks[i+1]] for i in range(N_THREADS)]
        rhyme_meter_res = [None] * N_THREADS
        threads = []
        def rhymeMeterThread(thread_index, split):
            rhyme_meter_res[thread_index] = processRhymeMeter(split)
        for i in range(N_THREADS):
            t = Thread(target=rhymeMeterThread, args=(i, split_tokens[i]))
            threads.append(t)
            t.start()
        for i in range(N_THREADS):
            threads[i].join()
            rhymes_data += rhyme_meter_res[i][0]
            meter_data += rhyme_meter_res[i][1]

        print("Converting rhyme and meter information")
        rhymes_data = np.asarray(rhymes_data)
        meter_data = np.asarray(meter_data)
        rhyme_meter_data = np.concatenate([rhymes_data, meter_data], axis=1)

    print("Masking unknown tokens")
    tokens = [(words.index(x) if x in vocab else -1) for x in tokens]

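    # Sliding windows of N tokens with stride TOKEN_SKIP: transformer/bard
    # targets are the window shifted one token left; the n-gram target is the
    # single final token.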
    print("Creating sets of ngrams")
    ngrams = []
    rm_ngrams = []
    for i in range(0, len(tokens)-N, TOKEN_SKIP):
        ngrams.append(tokens[i:i+N])
        if MODEL_TYPE == 'b':
            rm_ngrams.append(rhyme_meter_data[i:i+N-1,:])
    train_x = []
    train_y = []
    train_rm = []
    for i in range(len(ngrams)):
        sample = ngrams[i][:N]
        train_x.append(sample[:N-1])
        if MODEL_TYPE == 'b':
            sample_rm = rm_ngrams[i]
            train_rm.append(sample_rm)
        if MODEL_TYPE != 'n':
            train_y.append(sample[1:])
        else:
            train_y.append(sample[N-1])
    print("Converting arrays")
    train_x = np.asarray(train_x)
    train_y = np.asarray(train_y)
    if MODEL_TYPE == 'b':
        train_rm = np.asarray(train_rm, np.int8)
    if MODEL_TYPE != 'n':
        train_x += 1 # x in [0, VOCAB_SIZE] since 0 is for <unk>
        # y in [-1, VOCAB_SIZE-1] with VOCAB_SIZE tokens, one for each vocabulary item, and -1 for <unk>

    print("Saving data")
    fname = {'n': 'inputs/ngram_train.npz',
             't': 'inputs/transformer_train.npz',
             'b': 'inputs/bard_train.npz'
             }[MODEL_TYPE]
    if MODEL_TYPE != 'b':
        np.savez_compressed(fname, x=train_x, y=train_y)
    else:
        np.savez_compressed(fname, x=train_x, rm=train_rm, y=train_y)
    np.save('lemmas/lemmas.npy', words[:VOCAB_SIZE])