crackalamoo commited on
Commit
1427339
·
1 Parent(s): 7245849

Upload files for inference

Browse files
Files changed (10) hide show
  1. constants.py +197 -0
  2. lemmas/ed.npy +3 -0
  3. lemmas/er.npy +3 -0
  4. lemmas/est.npy +3 -0
  5. lemmas/ing.npy +3 -0
  6. lemmas/lemmas.npy +3 -0
  7. lemmas/s.npy +3 -0
  8. model.py +433 -0
  9. saved_models/b_model.h5 +3 -0
  10. tokens.py +534 -0
constants.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
def parseArgv(argv):
    """Parse command-line arguments into a settings dict.

    *argv* follows sys.argv conventions: argv[0] is the program name. An
    optional positional model type ('n', 't', or 'b') may appear as
    argv[1]; remaining arguments are ``--flag value`` pairs, except
    ``--kaggle`` which is a bare boolean flag. Unknown arguments are
    ignored.

    Returns a dict mapping flag names to parsed values.

    Raises ValueError when a value-taking flag is the last token (the
    original code raised a bare IndexError on argv[i+1] in that case).
    """
    data = {
        '--model-type': 'b',
        '--vocab-size': 4096,
        '--ngram-n': 4,
        '--transformer-n': 32,
        '--kaggle': False,
        '--rhyme-size': 4,
        '--meter-size': 3,
    }
    startParse = 1
    # Shorthand: a bare leading 'n'/'t'/'b' selects the model type.
    if len(argv) > 1 and argv[1] in ['n', 't', 'b']:
        data['--model-type'] = argv[1]
        startParse = 2
    for i in range(startParse, len(argv)):
        if argv[i] in data:
            if argv[i] == '--kaggle':
                data[argv[i]] = True
            elif i + 1 >= len(argv):
                # Guard against a trailing flag with no value.
                raise ValueError("missing value for argument " + argv[i])
            elif argv[i] == '--model-type':
                data[argv[i]] = argv[i + 1]
            else:
                data[argv[i]] = int(argv[i + 1])
    return data
25
# Parse settings once at import time; other modules import these constants.
sysArgs = parseArgv(sys.argv)

VOCAB_SIZE = sysArgs['--vocab-size']
NGRAM_N = sysArgs['--ngram-n']
TRANSFORMER_N = sysArgs['--transformer-n']
MODEL_TYPE = sysArgs['--model-type'] # n: ngram, t: transformer, b: bard
KAGGLE = sysArgs['--kaggle']
# Stride between training windows: 2 for the ngram model, otherwise one
# less than the transformer context length — presumably used when slicing
# the token stream into examples; confirm against the data-prep code.
TOKEN_SKIP = 2 if MODEL_TYPE == 'n' else TRANSFORMER_N-1
RHYME_STACK_SIZE = sysArgs['--rhyme-size']
METER_STACK_SIZE = sysArgs['--meter-size']
# Cardinalities of the vowel/consonant rhyme categories (one-hot depths
# used by model.rhyme_meter_encoding).
VOWEL_TYPES = 14
CONSONANT_TYPES = 10

# Special markers embedded in the token stream (space-padded so they
# survive naive whitespace splitting).
TITLE = " <TITLE> "
NEWLINE = " <NEWLINE> "
40
+
41
+ BRITISH_OUR = ['neighbor','color','flavor','splendor','labor','favor','fervor','savior','vapor','endeavor','parlor',
42
+ 'clamor','harbor','splendor','behavior','rumor','humor','savor','valor','armor','honor','odor']
43
+
44
+ est_set = set(['for','t','n','liv','b','di','j','r','p','v','w','b','gu',
45
+ 'l','eld','pr','inter','sever','hug','earn','smil',
46
+ 'qu','ch','bl','conqu','pri'])
47
+ ed_set = set(['he','you','they','we','will','mov','w','wretch','fe','wav','gre',
48
+ 'till','far','fell','de','b','f','l','re','hopp','ne','br',
49
+ 'mann','bann','bl','pleas','mark','m','sh','se','spe','ble',
50
+ 'lov','ste','rous','arm','bar','di','unmov','asham','cre'])
51
+ d_set = set(['be', 'she','we','see','re','fe','rowe','fee','le','seale','dee','ne',
52
+ 'reveale','traine','warme','coole','saile','sweate','mowe','cooke',
53
+ 'gree','warne','aire','seate','ree','temp','doome','helpe','feare',
54
+ 'neare','designe','adde','parte','repeate','gaine','parke','mourne',
55
+ 'backe','cleane','raine','charme','climbe','wee','fle','barbe','roote',
56
+ 'waite','fixe','hee','ende','wounde','pointe','earne','cree','matte',
57
+ 'kisse','haire','marke','neede','summe','farme','poure','owne','showe',
58
+ 'crowne','entere','evene','turne','crouche','laye','jade','recorde',
59
+ 'flowe','looke','nee','calle','learne','spe','ble','fille','washe',
60
+ 'boxe','talke','returne','sacre','dreame','pulle','seeme','calle',
61
+ 'prie','forme','ruine','lighte','appeare','adorne','aske','locke',
62
+ 'crosse','misse','arme','towe','shoute','heade','burne','faile','bowe',
63
+ 'rolle','walke','heape','obtaine'])
64
+ c_ed_set = set(['ad','cares','jag','pis','kis','mat','er','mis','cal','pas','fil','wo'])
65
+ y_ed_set = set(['drapery','city','weary'])
66
+ s_set = set(['','a','i','it','his','her',"'",'their','one','will','your','our','down','pant','wa',
67
+ 'god','well','other','saw','good','new','ye','leave','right','wood',
68
+ 'ha','thi','hi','jesu','riche','specie','alway','ala','grasse','glorie',
69
+ 'goe','doe','mas','pis','mi','pi','selve','wherea','prie','masse',
70
+ 'beautie','jame','misse','san','la','lo','politic','u','ga','bu','tos',
71
+ 'len'])
72
+ st_set = set(['be','we','ne','re','tempe','le','mode', 'fore','le','que','riche','cre','pe',
73
+ 'harde','sweete','cleane','je','te','che','highe','earne','deepe','meane','prie',
74
+ 'olde'])
75
+ c_est_set = set(['ful','smal'])
76
+ er_set = set(['with','she','h','quak','curr','hopp','minist','eth','thund','whisp','whit',
77
+ 'fev','rememb','inn','rend','de','beak','wand','port','heath','clos','should',
78
+ 'wrapp','cap','cow','lett','moth','chart','prop','danc','dinn','slumb','tend',
79
+ 'sever','ladd','falt','eld','aft','hind','flatt','murd','show','flow','sob',
80
+ 'pray','s','numb','pond','ev','und','wint','shiv','ang','fin','hov','teach',
81
+ 'clov','ov','oth','riv','barb','post','nev','discov','wat','draw','wait',
82
+ 'suff','deliv','quiv','silv','cov','shelt','los','m','slipp','batt','plast',
83
+ 'bitt','p','be','pe','ti','pi','ve','se','us','ton','min','sew','lit','tig',
84
+ 'lat','inn','out','off','ent','low','pow','less','wond','mann','care','lov',
85
+ 'rath','form','summ','bett','found','quart','tap','pap','record','shudd','pitch',
86
+ 'shatt','tatt','rid','butt','mis','bould','bord','glimm','answ','wav','walk',
87
+ 'glitt','gath','stick','care','temp','fish','corn','flick','dress','feath','met',
88
+ 'broth','both','lock','tow','conqu','che','encount','head','alt','mutt','san'])
89
+ c_er_set = set(['of','in','but','up','man','let','shut','sum','slip','din','flit',
90
+ 'mat','bat','bit','lad','ban','bet','ad','flat','pe','ful','smal','up',
91
+ 'pis','kis','slip','lat','cop','begin','shud','washe','shat','tat','lit',
92
+ 'glim','lay','lad','cal','glit','pas','fil','ham','sup','pep','rub','chat',
93
+ 'skip','alte','flut','mut','scat','dip','stag','wo'])
94
+ r_set = set(['he',"'re",'rule','cottage','quake','cove','clove','warble','prime','lowe',
95
+ 'cape','tempe','late','e','rive','dee','eve','wave','me','rathe','meter',
96
+ 'anothe','mothe','mowe','sweate','saile','leade','hithe','warme','coole',
97
+ 'reaveale','traine','chee','manne','shee','uppe','withe','designe','neare',
98
+ 'barbe','darke','banne','pete','faste','soone','oute','rende','parke',
99
+ 'keepe','lee','rooste','cleane','sweete','bothe','harde','sleepe','poste',
100
+ 'loude','climbe','flowe','drawe','waite','highe','lathe','summe','fathe',
101
+ 'cove','farme','lose','showe','deepe','longe','hove','teache','pe','rule',
102
+ 'freeze','compute','consume','recorde','fille','washe','boxe','talke',
103
+ 'spide','meane','outside','inside','laye','lighte','reade','ladde',
104
+ 'eage','forme','coppe','answe','aske','dinne','wave','glitte','feve',
105
+ 'butte','gathe','pape','broke','matte','time','locke','olde','towe','inne',
106
+ 'shoute','heade','cunne','burne','singe','mutte','rolle','dippe','walke'])
107
+ ing_set = set(['','us','s','st','n','wan','din','k','heav','w','morn','cloth','br','wav',
108
+ 'even','cl','noth','charm','th','spr','bl','p','r','d','tempt','m','s','z',
109
+ 'ch','mean','exact','bless','train','lov','str','build','pleas','slid','light',
110
+ 'stock','feel','bo','gap'])
111
+ c_ing_set = set(['er','wed','ad','ear','begin','pis','kis','er','mis','cal','pas','fil'])
112
+ e_ing_set = set(['the','we','bee','bore','lute','ne','re','please','displease','tide','clothe','ke',
113
+ 'neare','wounde','che','feare','doome','helpe','designe','evene','dye',
114
+ 'adde','parte','repeate','gaine','parke','mourne','backe','cleane','charme',
115
+ 'climbe','waite','fixe','raine','ende','wounde','pointe','earne','neede',
116
+ 'summe','poure','owne','crowne','entere','turne','crouche','ble','laye',
117
+ 'recorde','flowe','calle','morne','learne','fille','washe','boxe','talke',
118
+ 'kisse','returne','dreame','pulle','seeme','matte','forme','meane','ruine',
119
+ 'lighte','reade','appeare','adorne','stocke','aske','locke','calle','crosse',
120
+ 'misse','towe','shoute','feele','heade','burne','singe','faile','bowe',
121
+ 'rolle','walke','heape','obtaine'])
122
+ y_s_set = set(['ry'])
123
+ y_er_set = set(['by'])
124
+ y_est_set = set(['pry'])
125
+
126
+ BANNED_TOKENS = ['1','2','3','y','e','l','maud','olaf','lorenzo','de','oscar',
127
+ 'r','d','f','p','agnes','eulalie','kate','niam','thel','asius',
128
+ 'saadi','\\\\','juanna','johnson','dudù','moore','xanthus',
129
+ 'arjun','pandav','draupadi','bhishma','karna','pandu','bhima',
130
+ 'duryodhan','drona','abhimanyu','yudhishthir','agamemnon','narad',
131
+ 'antilochus','diomed','helen','ulysses','achilles','nestor',
132
+ 'menelaus','patroclus','hector','aeneas','laertes','priam',
133
+ 'penelope','eumaeus','telemachus','euryclea','sarpedon','peleus',
134
+ 'polydamas','glaucus','antenor','idomeneus','rishi','boreas',
135
+ 'phaeacian','savitri','kuru','diana','panchala','ida','ithaca',
136
+ 'matsya','pritha','salya','kripa','hastina','sisupala','vidura',
137
+ 'dhrita','rashtra','jayadratha','lamia','medon','highth','haydée',
138
+ 'haidée', 'edward','ithacus',
139
+ 'lenore','à','negro','juan','harold','etc','allan','adeline',
140
+ '+++++++++++++','c','j','h','4','5','6','7','8','9','10',
141
+ '11','12','*','x','b','/','k','g','ii','s','u','da','el',
142
+ 'le','que','~','000','m','thu','thir','13','14','15','16','17',
143
+ '18','19','20','30','th','bu','ri','w','v','al','iv','wi',
144
+ 'la','las','t','ma','ha','mee','ne','em','ry','di','st',
145
+ 'yr','ful','iii','bo','faire','tos','ai','en','et','sug',
146
+ 'ga','wel','hee','hon','n','wan','ut','te','ad','hym','na']
147
+ PUNCT = set(['.', ',', '!', '?', ':', ';', '-'])
148
+ VOWELS = set(['a','e','i','o','u'])
149
+ SOMETIMES_VOWELS = VOWELS.union(['y','w'])
150
+
151
+ DEFINED_RHYMES = {
152
+ "'ll": [4,1], "=er": [13,0], "the": [4,-1], 'a': [4,-1], 'we': [8,-1], 'ye': [8,-1], 'e': [8,-1],
153
+ 'zimbabwe': [7,-1], 'one': [4,2], 'two': [11,-1], 'oh': [10,-1], 'ah': [12,-1], 'i': [9,-1],
154
+ 'you': [11,-1], 'own': [10,2], 'know': [10,-1], 'do': [11,-1], 'upon': [3,2], 'whereon': [3,2],
155
+ 'world': [13,4], 'learn': [13,2], 'earn': [13,2], 'yearn': [13,2], 'of': [4,5], 'service': [4,6],
156
+ 'practice': [4,6], 'police': [8,6], 'through': [11,-1], 'tough': [4,5], 'enough': [4,5],
157
+ 'thorough': [10,-1], 'dough': [10,-1], 'rough': [4,5], 'cough': [3,5], 'snow': [10,-1],
158
+ 'w': [11,-1], 'walk': [3,7], 'talk': [3,7], 'son':[4,2], 'iron': [13,2], 'anon': [3,2],
159
+ 'full': [11,1], 'pull': [11,1], 'bull': [11,1], 'put': [11,1], 'push': [11,6], 'book': [11,7],
160
+ 'won': [4,2], 'what': [4,4], 'who': [11,-1], 'whose': [11,6], 'where': [7,0], 'there': [7,0],
161
+ 'their': [7,0], 'theirs': [7,6], 'bear': [7,0], 'wear': [7,0], 'show': [10,-1], 'tow': [10,-1],
162
+ 'sow': [10,-1], 'brow': [5,-1], 'prow': [5,-1], 'allow': [5,-1], 'laugh': [0,5],
163
+ 'elbow': [10,-1], 'window': [10,-1], 'rainbow': [10,-1], 'shadow': [10,-1], 'ancient': [1,4],
164
+ 'meant': [1,4], 'dreamt': [1,4], 'learnt': [13,4], 'hymn': [2,2], 'could': [11,4], 'should': [11,4],
165
+ 'to': [11,-1], 'was': [4,6], 'were': [13,0], 'love': [4,5], 'eye': [9,-1], 'bury': [8,-1],
166
+ 'your': [11,0], 'heart': [12,4], 'some': [4,2], 'come': [4,2], 'from': [4,2], 'become': [4,2],
167
+ 'would': [11,4], 'pour': [10,0],'figure': [13,0], 'author': [4,0], 'sure': [11,0], 'rhythm': [4,2],
168
+ 'every': [8,-1], 'very': [8,-1], 'many': [8,-1], 'any': [8,-1], 'busy': [8,-1], 'easy': [8,-1],
169
+ 'happy': [8,-1], 'live': [2,5], 'into': [11,-1], 'soul': [10,2], 'only': [8,-1], 'earth': [13,10],
170
+ 'though': [10,-1], 'thought': [3,4], 'bought': [3,4], 'brought': [3,4], 'ought': [3,4],
171
+ 'said': [1,4], 'dead': [1,4], 'word': [13,4], 'heard': [13,4], 'death': [1,10], 'head': [1,4],
172
+ 'once': [4,6], 'great': [7,4], 'young': [4,2], 'among': [4,2], 'yon': [3,2], 'wh': [-1,-1],
173
+ 'door': [10,0], 'find': [9,4], 'mind': [9,4], 'kind': [9,4], 'behind': [9,4], 'blind': [9,4],
174
+ 'wild': [9,4], 'give': [2,5], 'beauty': [8,-1], 'duty': [8,-1], 'move': [11,5], 'above': [4,5],
175
+ 'prove': [11,5], 'have': [0,5], 'whom': [11,2], 'warm': [10,2], 'done': [4,2], 'gone': [3,2],
176
+ 'behind': [9,4], 'none': [4,2], 'most': [10,4], 'ghost': [10,4], 'host': [10,4], 'post': [10,4],
177
+ 'travel': [4,1], 'broad': [3,4],'veil': [7,1],'tread': [1,4], 'bread': [1,4], 'ocean': [4,2],
178
+ 'truth': [11,10], 'human': [4,2], 'woman': [4,2], 'unto': [11,-1], 'worm': [13,4], 'blood': [4,4],
179
+ 'instead': [1,4], 'spread': [1,4], 'ahead': [1,4], 'breadth': [1,10], 'breath': [1,10],
180
+ 'valley': [8,-1], 'key': [8,-1], 'journey': [8,-1], 'honey': [8,-1], 'money': [8,-1],
181
+ 'chimney': [8,-1], 'monkey': [8,-1], 'donkey': [8,-1], 'alley': [8,-1], 'trolley': [8,-1],
182
+ 'galley': [8,-1], 'silly': [8,-1], 'lily': [8,-1], 'barley': [8,-1], 'quiet': [4,4],
183
+ 'else': [1,1], 'christian': [4,2], 'shadow': [10,-1], 'meadow': [10,-1], 'mow': [10,-1],
184
+ 'bestow': [10,-1], 'widow': [10,-1], 'friend': [1,4], 'source': [10,6], 'course': [10,6],
185
+ 'lyre': [9,0], 'curse': [13,6], 'rehearse': [13,6], 'are': [12,0], 'genuine': [2,2],
186
+ 'fly': [9,-1], 'july': [9,-1], 'reply': [9,-1], 'butterfly': [9,-1], 'ply': [9,-1],
187
+ 'supply': [9,-1], 'folk': [10,7], 'welcome': [4,2], 'wash': [3,6], 'child': [9,4],
188
+ 'deaf': [1,4], 'league': [8,7], 'plague': [7,7], 'vague': [7,7], 'overhead': [1,4]
189
+ }
190
+ DEFINED_METERS = {
191
+ "'re": 0, "'ve": 0, 'shakespeare': 2, 'every': 2, 'leaves': 1, 'evening': 2,
192
+ 'tongue': 1, 'lovely': 2, 'quiet': 2, 'people': 2, 'something': 2,
193
+ 'beautiful': 3, 'lyre': 1, 'hymn': 1, 'forego': 2, 'therefore': 2,
194
+ 'somewhere': 2
195
+ }
196
+ for word in BRITISH_OUR:
197
+ DEFINED_RHYMES[word] = [4,0]
lemmas/ed.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aca3dce224618296c3e20468ecbcc85dbe80643577560ad38428f023fbdede1c
3
+ size 5616
lemmas/er.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a1100564b80079fef9223fce5dcabbdef5a76acc75d82bcc2bceb038a8b60380
3
+ size 7715
lemmas/est.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7e22c89fd4f66d10760be51319efc98ae4f28b2c78ee08ffc4b2acaa6e251589
3
+ size 2641
lemmas/ing.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2917172bddf0147b993e4aab8383a0b95c2df0fba32dbc94821079ce1cbbe5cf
3
+ size 15719
lemmas/lemmas.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4278d8ab3ff2cc6e234d9257ca360e03cd0b3185d085067e9701e2d8b57033ed
3
+ size 213120
lemmas/s.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:195dc16152a9d6017d298dcba3dda551dda7ee39224fdb784f58b37f38901e5f
3
+ size 4658
model.py ADDED
@@ -0,0 +1,433 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ from tensorflow import keras
5
+ from keras.layers import Dense, Flatten, Dropout, Embedding,\
6
+ Add, MultiHeadAttention, LayerNormalization, Input, Softmax
7
+ import sys
8
+
9
+ from constants import *
10
+ from tokens import pretty_tokens, rhymeMeterFromTokens
11
+
12
+ EPOCHS = 10
13
+ WARMUP_STEPS = 800
14
+ EMBED_DIM = 512
15
+ TRANSFORMER_LAYERS = 8
16
+ TRANSFORMER_DFF = 1024
17
+ RHYME_METER_DFF = 64
18
+ TRANSFORMER_HEADS = 4
19
+ VAL_SPLIT = 0.2
20
+ BATCH_SIZE = 256
21
+ SAVE_AT_END = False
22
+ VERBOSE = False
23
+ TRAINING = True
24
+
25
+ if '--epochs' in sys.argv:
26
+ EPOCHS = int(sys.argv[sys.argv.index('--epochs')+1])
27
+ if '--warmup-steps' in sys.argv:
28
+ WARMUP_STEPS = int(sys.argv[sys.argv.index('--warmup-steps')+1])
29
+ if '--embed-dim' in sys.argv:
30
+ EMBED_DIM = int(sys.argv[sys.argv.index('--embed-dim')+1])
31
+ if '--transformer-layers' in sys.argv:
32
+ TRANSFORMER_LAYERS = int(sys.argv[sys.argv.index('--transformer-layers')+1])
33
+ if '--transformer-dff' in sys.argv:
34
+ TRANSFORMER_DFF = int(sys.argv[sys.argv.index('--transformer-dff')+1])
35
+ if '--rhyme-meter-dff' in sys.argv:
36
+ RHYME_METER_DFF = int(sys.argv[sys.argv.index('--rhyme-meter-dff')+1])
37
+ if '--transformer-heads' in sys.argv:
38
+ TRANSFORMER_HEADS = int(sys.argv[sys.argv.index('--transformer-heads')+1])
39
+ if '--val-split' in sys.argv:
40
+ VAL_SPLIT = float(sys.argv[sys.argv.index('--val-split')+1])
41
+ if '--batch-size' in sys.argv:
42
+ BATCH_SIZE = int(sys.argv[sys.argv.index('--batch-size')+1])
43
+ if '--save-at-end' in sys.argv:
44
+ SAVE_AT_END = True
45
+ if '--verbose' in sys.argv:
46
+ VERBOSE = True
47
+ if '--load' in sys.argv:
48
+ TRAINING = False
49
+
50
# Context window length depends on the model family.
N = NGRAM_N if MODEL_TYPE == 'n' else TRANSFORMER_N
# Vocabulary of lemma strings; list index == token id.
VOCAB = list(np.load('lemmas/lemmas.npy'))
# Fixed sample prompt (opening of Frost's "Stopping by Woods on a Snowy
# Evening") used when generating from a partially trained model.
TEST_PROMPT = '<title> stop =ing by woods on a snowy evening <newline> '+\
    'whose woods these are i think i know <newline> '+\
    'his house is in the village though <newline> he'
55
+
56
def sampleVocab(dist, temperature):
    """Sample an index from probability distribution *dist*.

    Uses standard temperature semantics: probabilities are raised to the
    power 1/temperature, so low temperatures sharpen the distribution and
    temperature == 0 is greedy (argmax). The original raised *dist* to the
    power `temperature`, which inverted the convention — temperature 0 was
    mapped to an exponent of ~0 and produced a near-uniform distribution
    instead of deterministic decoding.

    Also generalized to any distribution length instead of the hard-coded
    VOCAB_SIZE (callers always pass VOCAB_SIZE-long vectors, so this is
    backward compatible).
    """
    dist = np.asarray(dist, dtype=np.float64)
    if temperature <= 0:
        # Degenerate case: greedy decoding.
        return int(np.argmax(dist))
    dist = np.power(dist, 1.0 / temperature)
    dist /= np.sum(dist)
    return int(np.random.choice(len(dist), p=dist))
62
+
63
def genTokens(model, tokens, temperature=0.7, prompt=None):
    """Autoregressively sample *tokens* ids from *model* and map them back
    to vocabulary strings.

    When *prompt* is given, it is split on spaces and any words found in
    the model's vocabulary seed the generation; otherwise generation
    starts from the <title> marker.
    """
    seed = [model.vocab.index(TITLE.lower()[1:-1])]
    if prompt is not None:
        seed = [model.vocab.index(w) for w in prompt.split(' ') if w in model.vocab]
    for _ in range(tokens):
        nxt = model.generate(seed, temperature)
        assert nxt is not None
        seed.append(nxt)
    return [model.vocab[t] for t in seed]
73
+
74
class LinearModel(keras.Model):
    """Feed-forward n-gram baseline: predicts the next token from the
    one-hot encoding of the previous NGRAM_N-1 tokens."""

    def __init__(self):
        super(LinearModel, self).__init__()
        self.vocab = VOCAB
        layers = [
            Input(shape=(NGRAM_N-1, VOCAB_SIZE)),
            Flatten(),
            Dense(1024, activation='relu'),
            Dense(1024, activation='relu'),
            Dense(2048, activation='relu'),
            Dropout(0.2),
            Dense(VOCAB_SIZE, activation='softmax'),
        ]
        self.seq = keras.Sequential(layers)

    def call(self, input):
        # One-hot encode the token ids, then run the dense stack.
        encoded = tf.one_hot(input, VOCAB_SIZE)
        return self.seq(encoded)

    def generate(self, fullContext, temperature=0.7):
        """Sample the next token id given the full generation history."""
        # Slice copies, so trimming/padding never mutates the caller's list.
        window = fullContext[-(N-1):]
        while len(window) > NGRAM_N-1:
            window.pop(0)
        while len(window) < NGRAM_N-1:
            window.append(-1)  # -1 pads short contexts
        probs = self.call(np.asarray([window]))[0]
        return sampleVocab(probs, temperature)
103
+
104
+
105
def positional_encoding(length, depth):
    """Standard sinusoidal positional encoding.

    Returns a float32 tensor of shape (length, depth) whose first half of
    the feature axis is sines and second half cosines of position-scaled
    frequencies.
    """
    half = depth / 2
    pos = np.arange(length)[:, np.newaxis]       # (length, 1)
    freq = np.arange(half)[np.newaxis, :] / half  # (1, depth/2)
    rads = pos / (10000 ** freq)
    enc = np.concatenate([np.sin(rads), np.cos(rads)], axis=-1)
    return tf.cast(enc, dtype=tf.float32)
115
+
116
class InputEmbedding(keras.layers.Layer):
    """Token embedding plus fixed sinusoidal positional encoding."""

    def __init__(self):
        super().__init__()
        # VOCAB_SIZE+1 rows: callers shift ids by +1 so the padding id
        # (-1) lands on embedding row 0 (see the generate() methods).
        self.embed = Embedding(input_dim=VOCAB_SIZE+1, output_dim=EMBED_DIM)
        self.pos = positional_encoding(length=TRANSFORMER_N, depth=EMBED_DIM)
        self.add = Add()
        self.dropout = Dropout(0.1)

    def call(self, input):
        seq_len = tf.shape(input)[1]
        embedded = self.embed(input)
        # Scale embeddings before adding positions (standard Transformer).
        embedded = embedded * tf.math.sqrt(tf.cast(EMBED_DIM, tf.float32))
        summed = self.add([embedded, self.pos[tf.newaxis, :seq_len, :]])
        return self.dropout(summed)
130
+
131
class AttentionBlock(keras.layers.Layer):
    """Causal self-attention sublayer: MHA + dropout + residual + norm."""

    def __init__(self, **kwargs):
        super().__init__()
        self.mha = MultiHeadAttention(**kwargs)
        self.dropout = Dropout(0.1)
        self.norm = LayerNormalization()
        self.add = Add()

    def call(self, input):
        # Self-attention: query, key, and value are all the same tensor;
        # the causal mask keeps each position from attending forward.
        attended = self.mha(query=input, value=input, key=input,
                            use_causal_mask=True)
        attended = self.dropout(attended)
        residual = self.add([input, attended])
        return self.norm(residual)
144
+
145
class FeedForward(keras.layers.Layer):
    """Position-wise feed-forward sublayer with residual add + layer norm."""

    def __init__(self, dff):
        super().__init__()
        self.seq = keras.Sequential([
            Dense(dff, activation='relu'),
            Dense(EMBED_DIM),
            Dropout(0.1),
        ])
        self.add = Add()
        self.norm = LayerNormalization()

    def call(self, input):
        residual = self.add([input, self.seq(input)])
        return self.norm(residual)
159
+
160
class Decoder(keras.layers.Layer):
    """Stack of causal attention blocks followed by one feed-forward
    sublayer."""

    def __init__(self, *, num_layers, num_heads, dff):
        super(Decoder, self).__init__()
        blocks = [
            AttentionBlock(num_heads=num_heads, key_dim=EMBED_DIM, dropout=0.1)
            for _ in range(num_layers)
        ]
        self.attn_seq = keras.Sequential(blocks)
        self.ffn = FeedForward(dff)

    def call(self, input):
        return self.ffn(self.attn_seq(input))
172
+
173
class TransformerModel(keras.Model):
    """Decoder-only transformer language model over the poem vocabulary."""

    def __init__(self, *, num_layers=TRANSFORMER_LAYERS,
                 num_heads=TRANSFORMER_HEADS, dff=TRANSFORMER_DFF):
        super(TransformerModel, self).__init__()
        self.vocab = VOCAB
        self.embed = InputEmbedding()
        self.decoder = Decoder(num_layers=num_layers, num_heads=num_heads, dff=dff)
        self.out = Dense(VOCAB_SIZE, activation='softmax')

    def call(self, input):
        hidden = self.embed(input)     # batch x context x embedding
        hidden = self.decoder(hidden)  # batch x context x embedding
        probs = self.out(hidden)       # batch x context x vocab
        # Keras may attach a propagated mask to the output; strip it so
        # downstream loss/metrics see plain probabilities.
        try:
            del probs._keras_mask
        except AttributeError:
            pass
        return probs

    def generate(self, fullContext, temperature=0.7):
        """Sample the next token id from the last real context position."""
        window = fullContext[-N:]
        # Index of the final real token, recorded before padding.
        last = len(window) - 1
        while len(window) > TRANSFORMER_N:
            window.pop(0)
        while len(window) < TRANSFORMER_N:
            window.append(-1)
        # Shift ids by +1 so the padding id (-1) maps to embedding row 0.
        batch = np.asarray([window]) + 1
        probs = self.call(batch)[0][last]
        return sampleVocab(probs, temperature)
204
+
205
+
206
def rhyme_meter_encoding(input):
    """Split a packed rhyme/meter feature tensor into a one-hot rhyme
    tensor and a meter tensor.

    Last-axis layout (in chunks of RHYME_STACK_SIZE-1): vowel category
    ids, consonant category ids, rhyme-match flags, then the final
    METER_STACK_SIZE entries are meter features.

    Returns (rhyme, meter), both float32.
    """
    k = RHYME_STACK_SIZE - 1
    vowel_ids = tf.cast(input[:, :, :k], tf.int8)
    consonant_ids = tf.cast(input[:, :, k:2*k], tf.int8)
    match = tf.cast(input[:, :, 2*k:3*k], tf.float32)
    meter = tf.cast(input[:, :, -METER_STACK_SIZE:], tf.float32)

    vowel_1h = tf.one_hot(vowel_ids, depth=VOWEL_TYPES)
    consonant_1h = tf.one_hot(consonant_ids, depth=CONSONANT_TYPES)
    # Collapse the (stack, one-hot) dims into a single feature axis.
    vowel_1h = tf.reshape(
        vowel_1h, shape=(tf.shape(vowel_1h)[0], tf.shape(vowel_1h)[1], -1))
    consonant_1h = tf.reshape(
        consonant_1h, shape=(tf.shape(consonant_1h)[0], tf.shape(consonant_1h)[1], -1))
    vowel_1h = tf.cast(vowel_1h, tf.float32)
    consonant_1h = tf.cast(consonant_1h, tf.float32)

    rhyme = tf.concat([vowel_1h, consonant_1h, match], axis=2)
    return rhyme, meter
223
+
224
class RhymeMeterLayer(keras.layers.Layer):
    """Small MLP that turns rhyme and meter features into per-position
    vocabulary logits (added to the transformer's logits in BardModel)."""

    def __init__(self):
        super().__init__()
        self.dense_r1 = Dense(RHYME_METER_DFF, activation='relu')
        self.dense_m1 = Dense(RHYME_METER_DFF//2, activation='relu')
        self.dense_r2 = Dense(RHYME_METER_DFF, activation='relu')
        self.dense_3 = Dense(RHYME_METER_DFF*2, activation='relu')
        self.dense_final = Dense(VOCAB_SIZE)

    def call(self, input):
        rhyme, meter = rhyme_meter_encoding(input)
        # Separate branches for rhyme and meter, then merge and project.
        rhyme_h = self.dense_r2(self.dense_r1(rhyme))
        meter_h = self.dense_m1(meter)
        merged = tf.concat([rhyme_h, meter_h], axis=2)
        return self.dense_final(self.dense_3(merged))
243
+
244
class BardModel(keras.Model):
    """Transformer decoder combined with a rhyme/meter head; the two logit
    streams are summed before a final softmax."""

    def __init__(self, *, num_layers=TRANSFORMER_LAYERS,
                 num_heads=TRANSFORMER_HEADS, dff=TRANSFORMER_DFF):
        super(BardModel, self).__init__()
        self.vocab = VOCAB
        # Token id of the title marker (TITLE with its space padding and
        # surrounding angle-bracket padding chars stripped to '<title>').
        self.tl = VOCAB.index(TITLE.lower()[1:-1])
        self.rhyme_types = max(VOWEL_TYPES, CONSONANT_TYPES)
        self.embed = InputEmbedding()
        self.decoder = Decoder(num_layers=num_layers, num_heads=num_heads, dff=dff)
        self.transformer_pred = Dense(VOCAB_SIZE)
        self.rhyme_meter_pred = RhymeMeterLayer()
        self.add = Add()
        self.softmax = Softmax()

    def call(self, input):
        # input[0]: token ids; input[1]: packed rhyme/meter features.
        logits = self.transformer_pred(self.decoder(self.embed(input[0])))
        # Strip any propagated Keras mask before combining logit streams.
        try:
            del logits._keras_mask
        except AttributeError:
            pass
        rm_logits = self.rhyme_meter_pred(input[1])
        return self.softmax(self.add([logits, rm_logits]))

    def generate(self, fullContext, temperature=0.7):
        """Sample the next token id, conditioning on rhyme/meter features
        computed from the full history."""
        window = fullContext[-N:]
        # Index of the final real token, recorded before padding.
        last = len(window) - 1
        while len(window) > TRANSFORMER_N:
            window.pop(0)
        while len(window) < TRANSFORMER_N:
            window.append(-1)
        ids = np.asarray([window]) + 1  # shift so padding (-1) -> row 0
        rm = np.asarray([rhymeMeterFromTokens(fullContext, len(fullContext),
                                              self.tl, self.vocab)])
        probs = self.call([ids, rm])[0][last]
        return sampleVocab(probs, temperature)
285
+
286
+
287
+
288
class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Transformer learning-rate schedule: linear warmup for
    `warmup_steps`, then inverse-square-root decay, all scaled by
    1/sqrt(d_model)."""

    def __init__(self, d_model, warmup_steps=WARMUP_STEPS):
        super().__init__()
        self.d_model = tf.cast(d_model, tf.float32)
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        step = tf.cast(step, dtype=tf.float32)
        decay = tf.math.rsqrt(step)
        warmup = step * (self.warmup_steps ** -1.5)
        return tf.math.rsqrt(self.d_model) * tf.math.minimum(decay, warmup)
303
+
304
+
305
# Build the loss object once at import time; the original constructed a
# new SparseCategoricalCrossentropy on every call, i.e. every training
# step and every metric evaluation.
_SPARSE_XENT = keras.losses.SparseCategoricalCrossentropy(
    ignore_class=-1, reduction='none')

def sparse_loss(y_true, y_pred):
    """Per-token cross-entropy; padding positions (class -1) are ignored."""
    return _SPARSE_XENT(y_true, y_pred)

def sparse_perplexity(y_true, y_pred):
    """exp(mean token loss) — the standard perplexity metric."""
    return tf.math.exp(tf.math.reduce_mean(sparse_loss(y_true, y_pred)))
311
+
312
if __name__ == '__main__':
    # ---- Load training data for the selected model type ----
    fname = {'n': 'inputs/ngram_train.npz',
             't': 'inputs/transformer_train.npz',
             'b': 'inputs/bard_train.npz'
             }[MODEL_TYPE]
    print("Loading data from", fname)
    loaded = np.load(fname)
    train_x = loaded['x']
    train_y = loaded['y']
    if MODEL_TYPE == 'b':
        train_x = [tf.convert_to_tensor(train_x), tf.convert_to_tensor(loaded['rm'])] # rhyme and syllables
    if MODEL_TYPE == 'n':
        train_x = tf.convert_to_tensor(train_x, tf.int32)
    del loaded  # release the npz buffers before model construction

    if TRAINING and VERBOSE:
        if MODEL_TYPE != 'b':
            print("X:", train_x[10:14])
        else:
            print("X:", train_x[0][10:14])
            print("RM:", train_x[1][10:14][1])
        print("Y:", train_y[10:14])
        if MODEL_TYPE != 'b':
            print("X shape:", train_x.shape)
        print("Y shape:", train_y.shape)

    print("Initializing model")
    models = {'n': LinearModel, 't': TransformerModel, 'b': BardModel}
    model = models[MODEL_TYPE]()
    # Run one example through the model so Keras builds the weights
    # (required before summary()/load_weights()).
    if MODEL_TYPE != 'b':
        res = model(train_x[:1])
    else:
        x0 = train_x[0][:1]
        x1 = train_x[1][:1]
        res = model([x0, x1])
    if VERBOSE:
        print(model)
        print(res)
        print(model.summary())

    if TRAINING:
        print("Compiling model")
        learning_rate = CustomSchedule(EMBED_DIM)
        model.compile(optimizer=keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9),
                      loss=sparse_loss, metrics=[sparse_perplexity])

        print("Generating sample from baseline")
        print(pretty_tokens(genTokens(model, 25)))

        print("Training model")
        min_perplexity = None
        if not os.path.exists('saved_models'):
            os.mkdir('saved_models')

        class TrainCallback(keras.callbacks.Callback):
            # After each epoch: print a sample and checkpoint the weights
            # whenever (validation) perplexity reaches a new minimum.
            def on_epoch_end(self, epoch, logs=None):
                global min_perplexity
                perplexity = logs['val_sparse_perplexity'] if VAL_SPLIT > 0 else logs['sparse_perplexity']
                print("\rGenerating sample from model in training: "+
                      "epoch "+str(epoch+1)+", perplexity "+str(round(perplexity, 2)), end='')
                print(pretty_tokens(genTokens(model, 75)))
                if (min_perplexity is None or perplexity <= min_perplexity) and not SAVE_AT_END:
                    min_perplexity = perplexity
                    print("Saving model weights")
                    model.save_weights('saved_models/'+MODEL_TYPE+'_model.h5')

        model.fit(train_x, train_y,
                  batch_size=BATCH_SIZE, validation_split=VAL_SPLIT, epochs=EPOCHS,
                  callbacks=[TrainCallback()])

        if SAVE_AT_END:
            print("Saving final model weights")
            model.save_weights('saved_models/'+MODEL_TYPE+'_model.h5')

        print("Generating samples from final model")
        if VERBOSE:
            for i in range(10):
                print(pretty_tokens(genTokens(model, 100)))
        print(pretty_tokens(genTokens(model, 150, prompt=TEST_PROMPT)))
        print(pretty_tokens(genTokens(model, 500)))
        print(pretty_tokens(genTokens(model, 500)))

    else:
        # Inference mode (--load): drop training data and run a REPL.
        del train_x
        del train_y
        print("Loading weights")
        model.load_weights('saved_models/'+MODEL_TYPE+'_model.h5')

        # BUG FIX: `temp = 0.7` used to sit at the top of the while loop,
        # so the value set by the 't' command was discarded on the very
        # next iteration. Initialize it once, before the loop.
        temp = 0.7
        while True:
            print("Commands:\ng: generate sample with 250 tokens\nl: generate sample with custom length\np: generate sample with prompt\nt: set temperature\nq: quit")
            cmd = input("Enter command: ")
            try:
                if cmd == 'g':
                    print("Generating sample...")
                    print(pretty_tokens(genTokens(model, 250, temperature=temp)))
                if cmd == 'l':
                    length = int(input("Enter length: "))
                    print("Generating sample...")
                    print(pretty_tokens(genTokens(model, length, temperature=temp)))
                if cmd == 'p':
                    prompt = ""
                    print("Enter prompt as tokens separated by spaces and newlines.")
                    print("Example: <title> stop =ing by woods on a snowy evening\nwhose woods these are i think i know")
                    print("All tokens not in the vocabulary will be ignored.")
                    # Read lines until three consecutive blank lines are entered.
                    while not prompt.endswith('\n\n\n'):
                        prompt += input("")+'\n'
                    # Strip leading/trailing spaces and newlines.
                    while prompt.startswith(' ') or prompt.startswith('\n'):
                        prompt = prompt[1:]
                    while prompt.endswith(' ') or prompt.endswith('\n'):
                        prompt = prompt[:-1]
                    prompt = prompt.replace('\n', NEWLINE.lower())
                    length = int(input("Enter length: "))
                    print("Generating sample...")
                    print(pretty_tokens(genTokens(model, length, temperature=temp, prompt=prompt)))
                if cmd == 't':
                    print("Current temperature:", temp)
                    temp = float(input("New temperature: "))
                    print("Temperature set to", temp)
                if cmd == 'q':
                    sys.exit(0)
            except Exception as e:
                # Keep the REPL alive on bad input (e.g. non-numeric length).
                print("Error:", e)
saved_models/b_model.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f03eab3fd3dd13a08aadbe6a03e7b7e0c10ed7d38534484bc58e126634879a6f
3
+ size 157786768
tokens.py ADDED
@@ -0,0 +1,534 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import numpy as np
import os
import sys
from constants import *
if __name__ == '__main__':
    from threading import Thread

# Number of worker threads for the preprocessing passes below.
N_THREADS = 32
if '--n_threads' in sys.argv:
    N_THREADS = int(sys.argv[sys.argv.index('--n_threads')+1])

if __name__ == '__main__':
    # Training-time path: build the vocabulary from the joined corpus.
    if not os.path.exists('lemmas'):
        os.mkdir('lemmas')

    # Use a context manager so the corpus file is always closed.
    with open("inputs/join.txt" if not KAGGLE else "inputs/join-kaggle.txt", "r") as file:
        text = file.read()

    tokens = text.split(" ")
    tokens = [x for x in tokens if x != '']
    print("Total number of tokens:", len(tokens))
    print("Counting tokens")
    counts = {}
    for token in tokens:
        counts[token] = counts.get(token, 0) + 1
    # Most frequent first (stable, so ties keep first-occurrence order).
    words = sorted(counts, key=lambda word: counts[word], reverse=True)

    # Push banned tokens to the back so they fall outside the vocab window.
    for token in BANNED_TOKENS:
        if token in words:
            words.remove(token)
            words.append(token)
    # Fold every out-of-vocabulary token's count into <unk>.
    # PERF FIX: membership was previously tested against the list slice
    # words[:VOCAB_SIZE] inside the loop — O(VOCAB_SIZE) per word.
    # A precomputed set makes each test O(1) with identical semantics.
    top_words = set(words[:VOCAB_SIZE])
    counts['<unk>'] = 0
    for word in words:
        if word in top_words:
            continue
        counts['<unk>'] += counts[word]
    # Re-rank now that <unk> has an aggregate count.
    words = sorted(counts, key=lambda word: counts[word], reverse=True)
    for token in BANNED_TOKENS:
        if token in words:
            words.remove(token)
            words.append(token)

    vocab = set(words[:VOCAB_SIZE])
else:
    # Inference-time path: reuse the vocabulary saved by a training run.
    print("Loading vocab")
    vocab = set(list(np.load('lemmas/lemmas.npy')))
def pretty_tokens(tokens, mask=True):
    """Render a list of lemma/suffix tokens as human-readable text.

    Suffix tokens beginning with '=' ('=s', '=ed', '=er', '=est', '=ing',
    '=nt') are merged into the preceding word, preferring the saved
    lemma->inflection dictionaries and falling back to heuristic English
    spelling rules.  Newline/title markers become a line break and a
    decorative title glyph.  When `mask` is True, words outside the
    vocabulary are rendered as '<unk>'.
    """
    # Lemma -> inflected-form tables produced by the preprocessing pipeline.
    s_dict = np.load('lemmas/s.npy', allow_pickle=True).item()
    ed_dict = np.load('lemmas/ed.npy', allow_pickle=True).item()
    er_dict = np.load('lemmas/er.npy', allow_pickle=True).item()
    est_dict = np.load('lemmas/est.npy', allow_pickle=True).item()
    ing_dict = np.load('lemmas/ing.npy', allow_pickle=True).item()
    dicts = {'=s': s_dict, '=ed': ed_dict, '=er': er_dict, '=est': est_dict, '=ing': ing_dict}
    res = []
    i = 0
    def includeSpace(this):
        # Should `this` be preceded by a space, given the output so far?
        # Punctuation, dashes, newlines and clitics ("'s") attach directly.
        nonlocal res
        quote = set(["'", '"'])
        nospace = set(['\n','-'])
        prev = res[len(res)-1] if len(res) > 0 else None
        prev2 = res[len(res)-2] if len(res) > 1 else None
        space = not prev in nospace\
            and not this in PUNCT and not this == '\n'\
            and not (this.startswith("'") and this != "'")
        if prev in quote and not prev2 in PUNCT:
            # An opening quote (one not following punctuation) hugs the next word.
            space = False
        elif this in quote and prev in PUNCT:
            # A quote following punctuation hugs that punctuation.
            space = False
        return space
    while i < len(tokens):
        this = tokens[i]
        if this == NEWLINE.lower()[1:-1]:
            this = '\n'
        elif this == TITLE.lower()[1:-1]:
            this = '\n ༄༅༅ '  # decorative poem-title marker
        elif mask and not this in vocab:
            # Out-of-vocabulary word: emit '<unk>' and skip suffix merging.
            this = " <unk>"
            if not includeSpace(this):
                this = "<unk>"
            res.append(this)
            i += 1
            continue
        if i+1 < len(tokens):
            next = tokens[i+1]
            # Fold any run of '=suffix' tokens into the current word.
            while next.startswith('='):
                if next == "=nt":
                    # Negative contraction: "can" + "=nt" -> "can't", etc.
                    if tokens[i].endswith('n'):
                        this = this[:-1]
                    if tokens[i] == 'will':
                        this = 'wo'   # will + n't -> won't
                    elif tokens[i] == 'shall':
                        this = 'sha'  # shall + n't -> shan't
                    this = this+"n't"
                else:
                    if tokens[i] in dicts[next]:
                        # NOTE(review): membership is tested on tokens[i] but
                        # the lookup key is `this`; after an earlier suffix
                        # merge in this while-loop the two can differ —
                        # confirm this cannot raise KeyError.
                        this = dicts[next][this]
                    else:
                        # Heuristic English spelling rules for the suffix.
                        if next[1] == 'e' or next[1] == 'i':
                            if this.endswith('e'):
                                this = this[:-1]      # drop silent e (bake -> baking)
                            elif this.endswith('c'):
                                this = this+'k'       # picnic -> picnicking
                            if this.endswith('y') and next[1] == 'e' and len(this) > 2 and not this[-2] in VOWELS:
                                this = this[:-1]+'i'  # happy -> happier
                        if next[1] == 's':
                            if this.endswith('s') or this.endswith('sh') or this.endswith('x') or this.endswith('ch'):
                                this = this+'e'       # box -> boxes
                            if this.endswith('y') and len(this) > 2 and not this[-2] in VOWELS:
                                this = this[:-1]+'ie' # fly -> flies

                        this = this+next[1:]
                i += 1
                next = tokens[i+1] if i+1 < len(tokens) else ''
        if this.startswith('='):
            # Orphan suffix with no word to attach to: drop the '=' marker.
            this = this[1:]
        elif includeSpace(this):
            this = " "+this
        res.append(this)
        i += 1
    res = ''.join(res)
    res = res[1:] if res.startswith(' ') else res
    return res
def getRhyme(line):
    """Heuristically classify the rhyme sound of the last word of `line`.

    Returns [vowel_type, consonant_type]; -1 in either slot means unknown
    or not applicable (empty line, title line, no final consonant).
    """
    # rhyme format:
    # final vowel (short AEIO, schwa, long AEIOU, OW, OI, A/schwa before R; total 14)
    # final consonant (R, L, N/M/NG, P/B, T/D, F/V, S/SH/Z/ZH, K/G, CH/J, TH; total 10)
    if line is None or len(line) == 0:
        return [-1, -1]
    nl = NEWLINE.lower()[1:-1]
    tl = TITLE.lower()[1:-1]
    if line[0] == tl:
        return [-1, -1]  # title lines carry no rhyme information
    # Strip trailing newline tokens, punctuation and quotes.
    while line[-1] == nl or line[-1] in PUNCT or line[-1] == '"' or line[-1] == "'" or line[-1] is None:
        line = line[:-1]
        if len(line) == 0:
            return [-1, -1]
    word = line[-1]+''  # copy the token before trimming it below
    long_vowel = False
    vowel_type = None
    # Final-grapheme spelling -> base vowel class (see table below).
    vowel_map = {'a': 0, 'e': 1, 'i': 2, 'o': 3, 'u': 4, 'ow': 5, 'ou': 5, 'oi': 6, 'oy': 6,
                 'ay': 7, 'ai': 7, 'au': 3, 'aw': 3, 'ea': 8, 'ee': 8, 'eu': 11, 'ew': 11,
                 'oa': 10, 'oo': 11, 'y': 9, 'ey': 7, 'ei': 9}

    # vowel type format:
    # 0: A, 1: E, 2: I, 3: O, 4: U, 5: OW, 6: OI
    # short U is schwa
    # OW, OI are always long
    # before R: short E/I become schwa, schwa/short A get their own vowel type, short O becomes long O
    consonant_type = -1
    # Final-grapheme spelling -> consonant class.
    cons_map = {'r': 0, 'l': 1, 'n': 2, 'm': 2, 'ng': 2,
                'p': 3, 'b': 3, 't': 4, 'd': 4, 'f': 5,
                'v': 5, 's': 6, 'sh': 6, 'z': 6, 'zh': 6,
                'th': 9, 'k': 7, 'ch': 8, 'j': 8}
    # consonant type format:
    # 0: R, 1: L, 2: N/M/NG, 3: P/B, 4: T/D, 5: F/V, 6: S/SH/Z/ZH/TH, 7: K/G, 8: CH/J
    # total 9 consonant types

    # full vowel type list: (L=long, S=short, R=before R)
    # 0: AS (bat), 1: ES (bet), 2: IS (bit), 3: OS (bot), 4: US/schwa (but)
    # 5: OW (bout), 6: OI (boil)
    # 7: AL (bait), 8: EL (beat), 9: IL (bite), 10: OL (boat), 11: UL (boot)
    # 12: AR (bar), 13: schwa_R (butter, bird, burn)
    # total 14 vowel types
    def getVowel(type, isLong, beforeR):
        # Map a base vowel class to its final slot given length and a
        # following R (see the table above).
        if beforeR and not isLong:
            if type == 0:
                return 12  # AR as in "bar"
            if type == 1 or type == 2 or type == 4:
                return 13  # schwa+R as in "sir"/"burn"
            if type == 3:
                return 10  # OR merges with long O
        if isLong and 0 <= type <= 4:
            return type+7  # long counterparts occupy slots 7-11
        # NOTE(review): `type` can still be None here when no grapheme
        # matched vowel_map; with isLong True that comparison above would
        # raise TypeError — confirm inputs always hit vowel_map.
        return type

    # Contraction / suffix tokens rhyme on the previous word but force a
    # particular final consonant class.
    lock_consonant = -1
    if len(line) > 1:
        if word == '=ed':
            if line[-2].endswith('t') or line[-2].endswith('d'):
                return [4, 4]  # adds a schwa + T/D syllable ("wanted")
            lock_consonant = 4
            word = line[-2]
        if word == '=s' or word == "'s":
            if line[-2].endswith('s') or line[-2].endswith('z') or line[-2].endswith('ch') or line[-2].endswith('sh') or line[-2].endswith('x'):
                return [4, 6]  # adds a schwa + S syllable ("wishes")
            lock_consonant = 6
            word = line[-2]
        elif word == "'re":
            lock_consonant = 0
            word = line[-2]
        elif word == "'ve":
            lock_consonant = 5
            word = line[-2]
        elif word == "'ll":
            lock_consonant = 1
            word = line[-2]
        elif word == "'d":
            lock_consonant = 4
            word = line[-2]
        elif word == "'m":
            lock_consonant = 2
            word = line[-2]
        elif word == "=nt'":
            lock_consonant = 4
            word = line[-2]
    # Hand-curated overrides for words the heuristics get wrong.
    if word in DEFINED_RHYMES:
        vowel_type = DEFINED_RHYMES[word][0]
        consonant_type = DEFINED_RHYMES[word][1] if lock_consonant == -1 else lock_consonant
        return [vowel_type, consonant_type]

    # Spelling-pattern special cases, checked before the generic scan.
    if word.endswith('o'):
        return [10, lock_consonant]  # final O is long ("go", "hello")
    if word.endswith('bble') or word.endswith('ggle'):
        return [4, 1 if lock_consonant == -1 else lock_consonant]
    if word.endswith('old'):
        return [10, 1 if lock_consonant == -1 else lock_consonant]
    if word.endswith('ance'):
        return [0, 6 if lock_consonant == -1 else lock_consonant]
    if word.endswith('ense') or word.endswith('ence'):
        return [1, 6 if lock_consonant == -1 else lock_consonant]
    if word.endswith('ince'):
        return [2, 6 if lock_consonant == -1 else lock_consonant]
    if word.endswith('ture') or word.endswith('sure'):
        return [13, 0 if lock_consonant == -1 else lock_consonant]
    if word.endswith('all'):
        return [3, 1 if lock_consonant == -1 else lock_consonant]
    if word.endswith('row') or word.endswith('low'):
        return [10, lock_consonant]
    if word.endswith('le') and len(word) >= 3 and not word[-3] in VOWELS:
        return [4, 1 if lock_consonant == -1 else lock_consonant]  # "-ble", "-tle": schwa + L
    if word.endswith('on') and len(word) > 3 and not word.endswith('oon'):
        return [4, 2 if lock_consonant == -1 else lock_consonant]  # unstressed "-on"
    if word.endswith('al') and len(word) > 3 and not word.endswith('eal'):
        return [4, 1 if lock_consonant == -1 else lock_consonant]  # unstressed "-al"
    if word.endswith('ous'):
        return [4, 6 if lock_consonant == -1 else lock_consonant]
    if word.endswith('ly'):
        return [8, -1 if lock_consonant == -1 else lock_consonant]  # "-ly" rhymes on EE
    if word.endswith('ward'):
        return [13, 4 if lock_consonant == -1 else lock_consonant]

    # A final silent e marks the preceding vowel as long.
    if word.endswith('e'):
        long_vowel = True
        word = word[:-1]
    # Determine the final consonant class (two-letter digraphs first),
    # unless a contraction already locked it.
    if lock_consonant == -1:
        if word[-2:] in cons_map:
            consonant_type = cons_map[word[-2:]]
        elif word[-1:] in cons_map:
            consonant_type = cons_map[word[-1:]]
        elif word[-1] == 'c' and long_vowel:
            consonant_type = cons_map['s']  # soft c ("ice")
        elif word[-1] == 'g' and long_vowel:
            consonant_type = cons_map['j']  # soft g ("age")
    else:
        consonant_type = lock_consonant

    # Strip trailing consonant letters to expose the final vowel grapheme,
    # remembering whether an R immediately follows it.
    lock_r = False
    if not word[-1] in SOMETIMES_VOWELS:
        while not word[-1] in SOMETIMES_VOWELS:
            if word.endswith('igh'):
                return [9, consonant_type]  # "-igh" is long I
            if word[-1] == 'r':
                lock_r = True
            elif lock_r:
                lock_r = False  # R no longer adjacent to the vowel
            word = word[:-1]
            if word == '':
                return [8, lock_consonant]
    # Classify the final vowel grapheme (digraphs first).
    if word[-2:] in vowel_map:
        vowel_type = vowel_map[word[-2:]]
    elif word[-1:] in vowel_map:
        vowel_type = vowel_map[word[-1:]]

    vowel_type = getVowel(vowel_type, long_vowel, consonant_type == 0 or lock_r)
    return [vowel_type, consonant_type]
def pretty_rhyme(rhyme):
    """Describe a [vowel_type, consonant_type] pair with example words."""
    # Example word for each of the 14 vowel classes, keyed by index.
    v_map = ['bat', 'bet', 'bit', 'bot', 'but', 'pout', 'boil', 'bait', 'beat', 'bite', 'boat', 'boot', 'bar', 'sir']
    # Label for each of the 10 consonant classes, keyed by index.
    c_map = ['R', 'L', 'N/M/NG', 'P/B', 'T/D', 'F/V', 'S/SH/Z/ZH', 'K/G', 'CH/J', 'TH']
    vowel, cons = rhyme[0], rhyme[1]
    # -1 means "no rhyme information" in either slot.
    vowel_text = '--' if vowel == -1 else v_map[vowel]
    cons_text = 'ø' if cons == -1 else c_map[cons]
    return "Rhyme is " + vowel_text + ' ' + cons_text
def getMeter(line):
    """Estimate the syllable count of `line` by counting vowel groups.

    Applies English-specific adjustments: suffix tokens ('=ed', '=s') only
    add a syllable after certain stems, silent final e is discounted,
    consonant+'le' adds one, y acts as a vowel, and common digraphs are
    collapsed to a single vowel letter before counting.
    """
    if line is None:
        return 0
    res = 0
    nl = NEWLINE.lower()[1:-1]
    tl = TITLE.lower()[1:-1]
    for i in range(len(line)):
        word = line[i]
        # Structural markers and padding contribute no syllables.
        if word == nl or word == tl or word is None:
            continue
        # Hand-curated syllable counts override the heuristics.
        if word in DEFINED_METERS:
            res += DEFINED_METERS[word]
            continue
        # '=ed' is syllabic only after t/d stems ("wanted" vs "walked").
        if word == '=ed' and i > 0:
            if line[i-1].endswith('t') or line[i-1].endswith('d') or line[i-1].endswith('te') or line[i-1].endswith('de'):
                res += 1
            continue
        # '=s' is syllabic only after sibilant stems ("wishes" vs "cats").
        if word == '=s' and i > 0:
            if line[i-1].endswith('s') or line[i-1].endswith('z') or line[i-1].endswith('ch') or line[i-1].endswith('sh') or line[i-1].endswith('x'):
                res += 1
            continue
        # Consonant + "le" ("little", "able") forms its own syllable.
        if word.endswith('le') and len(word) >= 3 and not word[-3] in VOWELS:
            res += 1 # to account for the dropped e
        removed_e = False
        # Silent final e does not count as a vowel.
        if word.endswith('e'):
            word = word[:-1]
            removed_e = True
        # Final y after a consonant acts as a vowel ("happy" -> "happi").
        if word.endswith('y') and len(word) > 2 and not word[-2] in VOWELS:
            word = word[:-1]+'i'
        # Collapse vowel digraphs to a single letter so each counts once.
        word = word.replace('ea','i').replace('ee','i')
        word = word.replace('ai','i').replace('au','o')
        word = word.replace('eu','u')
        word = word.replace('ei','i').replace('ie','i')
        word = word.replace('oa','o').replace('ou','o')
        word = word.replace('oi','o').replace('oo','u')
        # "-tion"/"-sion"/"-tian" is one syllable ("shun"), not two vowels.
        if word.endswith('tion') or word.endswith('sion') or word.endswith('tian'):
            word = word[:-4]+'shun'
        this_count = 0
        for vowel in VOWELS:
            this_count += word.count(vowel)
        # A word whose only vowel was the stripped e still has one syllable.
        if removed_e and this_count == 0:
            this_count = 1
        res += this_count
    return res
def lastLine(tokens, endl):
    """Return the tokens of the line ending just before index `endl`.

    Scans backwards from endl-1 for the most recent newline token and
    returns the slice from it (inclusive) up to endl; if no newline is
    found, returns everything from the start of `tokens`.
    """
    newline_tok = NEWLINE.lower()[1:-1]
    start = 0
    # Walk backwards; index 0 is never tested, matching the original scan.
    for idx in range(endl - 1, 0, -1):
        if tokens[idx] == newline_tok:
            start = idx
            break
    line = tokens[start:endl]
    # Degenerate slice (e.g. endl == 0): fall back to the prefix.
    return line if len(line) > 0 else tokens[:endl]
def processRhymeStack(rhyme_stack):
    """Flatten the rhyme history into a feature vector for the model.

    rhyme_stack is (RHYME_STACK_SIZE, 2): one [vowel, consonant] pair per
    recent line, newest last, -1 meaning unknown.  The output concatenates
    the raw older pairs (column-major) with RHYME_STACK_SIZE-1 match flags
    comparing each older line against the newest one.
    """
    # Older entries, flattened column-major: all vowels then all consonants.
    prev = rhyme_stack[:-1].flatten(order='F')
    lastRhyme = rhyme_stack[-1]
    res = np.zeros(RHYME_STACK_SIZE-1)
    if lastRhyme[0] != -1:
        for i in range(RHYME_STACK_SIZE-1):
            # NOTE(review): nesting reconstructed from a whitespace-stripped
            # source — assumed 2 (full rhyme) requires the vowel match as
            # well as the consonant match; confirm against the original.
            if rhyme_stack[i][0] == lastRhyme[0]:
                res[i] = 1  # vowel matches the newest line
                if rhyme_stack[i][1] == lastRhyme[1]:
                    res[i] = 2  # vowel and consonant both match
    res = np.concatenate([prev, res])
    return res
def processRhymeMeter(split):
    """Compute per-token rhyme and meter feature rows for a token stream.

    Walks `split` token by token, maintaining a rolling stack of the
    rhymes and syllable counts of the most recent completed lines; both
    stacks reset at each poem title.  Returns [rhymes, meter] where each
    list has one entry per input token.
    """
    in_title = False
    meter = []
    rhymes = []
    # meter_stack: syllable counts of recent lines; rhyme_stack: their
    # [vowel, consonant] classes, -1 meaning unknown/empty.
    meter_stack = np.zeros(METER_STACK_SIZE, np.int8)
    rhyme_stack = np.zeros((RHYME_STACK_SIZE, 2), np.int8) - 1
    tl = TITLE.lower()[1:-1]
    nl = NEWLINE.lower()[1:-1]
    for i in range(len(split)):
        # Tokens of the (possibly partial) current line ending at i.
        line = lastLine(split, i)
        if split[i] == tl:
            # New poem: clear all history.
            in_title = True
            meter_stack = np.zeros(METER_STACK_SIZE, np.int8)
            rhyme_stack = np.zeros((RHYME_STACK_SIZE, 2), np.int8) - 1
            meter.append(meter_stack.copy())
            rhymes.append(processRhymeStack(rhyme_stack))
            continue
        elif in_title and split[i] == nl:
            # End of the title line: record its rhyme/meter once, then
            # reset so the poem body starts with a clean history.
            in_title = False
            meter_stack = np.zeros(METER_STACK_SIZE, np.int8)
            meter_stack[-1] = getMeter(line)
            meter.append(meter_stack.copy())
            rhyme_stack = np.zeros((RHYME_STACK_SIZE, 2), np.int8) - 1
            rhyme_stack[-1] = np.array(getRhyme(line), np.int8)
            rhymes.append(processRhymeStack(rhyme_stack))
            meter_stack = np.zeros(METER_STACK_SIZE, np.int8)
            rhyme_stack = np.zeros((RHYME_STACK_SIZE, 2), np.int8) - 1
            continue
        if not in_title and split[i] == nl:
            # Line boundary: emit features first, then push the completed
            # line onto the stacks (skipping blank lines).
            rhymes.append(processRhymeStack(rhyme_stack))
            meter.append(meter_stack.copy())
            if split[i-1] != nl:
                rhyme_stack = np.roll(rhyme_stack, -1, axis=0)
                rhyme_stack[-1] = np.array(getRhyme(line), np.int8)
                meter_stack = np.roll(meter_stack, -1, axis=0)
                meter_stack[-1] = getMeter(line)
        else:
            # Mid-line token: the newest slot tracks the partial line.
            meter_stack[-1] = getMeter(line)
            rhyme_stack[-1] = np.array(getRhyme(line), np.int8)
            rhymes.append(processRhymeStack(rhyme_stack))
            meter.append(meter_stack.copy())
    return [rhymes, meter]
def rhymeMeterFromTokens(tokens, endl, tl, vocab=None):
    # used as input for model
    """Build the rhyme/meter feature matrix for the context ending at `endl`.

    tokens: token stream (integer indices when `vocab` is given, otherwise
    word strings); tl: the title token; vocab: optional index->word list
    used to decode integer tokens (out-of-range / None become None).
    Returns a numpy array of TRANSFORMER_N feature rows.
    """
    res = []
    # Back up to the start of the current poem (its title token).
    start = endl-1
    if len(tokens) >= endl:
        while start > 0 and tokens[start] != tl:
            start -= 1
    lines = tokens[start:endl]
    # Pad with None so processRhymeMeter sees at least TRANSFORMER_N tokens.
    while len(lines) < TRANSFORMER_N:
        lines.append(None)
    # Decode integer tokens to words when a vocabulary is supplied.
    input_lines = lines if vocab is None else [(vocab[x] if (x is not None and 0 <= x < VOCAB_SIZE) else None) for x in lines]
    rhymes, meter = processRhymeMeter(input_lines)
    rhymes = rhymes[-TRANSFORMER_N:] # context x RHYME_STACK_SIZE x 2
    meter = meter[-TRANSFORMER_N:] # context x METER_STACK_SIZE
    rhymes = np.array(rhymes)
    meter = np.array(meter)
    res = np.concatenate([rhymes, meter], axis=1) # context x (RHYME_STACK_SIZE*2 + METER_STACK_SIZE)
    return res
if __name__ == '__main__':
    # Build and save the training arrays for the selected model type.
    # Window length: NGRAM_N for the ngram model, otherwise a transformer
    # context of TRANSFORMER_N plus one target token.
    N = NGRAM_N if MODEL_TYPE == 'n' else TRANSFORMER_N+1
    # Pad the end of the stream so the final windows are complete.
    for i in range(N-1):
        tokens.append(None)
    words.remove('<unk>')
    print({word: counts[word] for word in words[:VOCAB_SIZE]})
    title_token = words.index(TITLE.lower()[1:-1])
    newline_token = words.index(NEWLINE.lower()[1:-1])

    print("Splitting poems with masked dividers")
    # A run of N mask values inserted before each poem title keeps windows
    # from spanning two poems.
    mask_list = [-1]*N
    splits = []
    chunk_size = len(tokens)//N_THREADS
    for i in range(N_THREADS):
        splits.append(
            tokens[i*chunk_size : (i+1)*chunk_size if i < N_THREADS-1 else len(tokens)])

    results = [None] * N_THREADS
    threads = []

    def add_dividers(thread_index, split):
        # Insert mask dividers before every title token in this chunk.
        # NOTE(review): `split` still holds word strings at this stage
        # (masking to indices happens later below), while `title_token` is
        # an integer vocabulary index — this equality test looks like it
        # can never be True; confirm whether it should compare against the
        # title word string instead.
        i = 1
        while i < len(split):
            if split[i] == title_token:
                split = split[:i] + mask_list + split[i:]
                i += N+5
            i += 1
        results[thread_index] = split
        return split
    for i in range(N_THREADS):
        t = Thread(target=add_dividers, args=(i, splits[i],))
        threads.append(t)
        t.start()
    tokens = []
    for i in range(N_THREADS):
        threads[i].join()
        tokens += results[i]

    if MODEL_TYPE == 'b':
        print("Computing rhyme and meter information")
        # Choose split points, then advance each to the next poem title so
        # every thread processes whole poems only.
        split_token_marks = []
        split_size = len(tokens)//N_THREADS
        for i in range(N_THREADS+1):
            split_token_marks.append(split_size*i)
        for i in range(1, N_THREADS):
            while tokens[split_token_marks[i]] != TITLE.lower()[1:-1]:
                split_token_marks[i] += 1
                if split_token_marks[i] >= len(tokens):
                    break
        meter_data = []
        rhymes_data = []
        split_token_marks[-1] = len(tokens)
        split_tokens = [tokens[split_token_marks[i]:split_token_marks[i+1]] for i in range(N_THREADS)]
        rhyme_meter_res = [None] * N_THREADS
        threads = []
        def rhymeMeterThread(thread_index, split):
            # Each thread fills its own slot; no locking needed.
            rhyme_meter_res[thread_index] = processRhymeMeter(split)
        for i in range(N_THREADS):
            t = Thread(target=rhymeMeterThread, args=(i, split_tokens[i]))
            threads.append(t)
            t.start()
        for i in range(N_THREADS):
            threads[i].join()
            rhymes_data += rhyme_meter_res[i][0]
            meter_data += rhyme_meter_res[i][1]

        print("Converting rhyme and meter information")
        rhymes_data = np.asarray(rhymes_data)
        meter_data = np.asarray(meter_data)
        # One feature row per token: rhyme flags followed by meter stack.
        rhyme_meter_data = np.concatenate([rhymes_data, meter_data], axis=1)

    print("Masking unknown tokens")
    # Map words to vocabulary indices; anything out of vocab becomes -1.
    # NOTE: words.index(x) is a linear scan per token — a precomputed
    # word->index dict would make this pass much faster.
    tokens = [(words.index(x) if x in vocab else -1) for x in tokens]

    print("Creating sets of ngrams")
    ngrams = []
    rm_ngrams = []
    for i in range(0, len(tokens)-N, TOKEN_SKIP):
        ngrams.append(tokens[i:i+N])
        if MODEL_TYPE == 'b':
            # Rhyme/meter rows aligned with the N-1 input positions.
            rm_ngrams.append(rhyme_meter_data[i:i+N-1,:])
    train_x = []
    train_y = []
    train_rm = []
    for i in range(len(ngrams)):
        sample = ngrams[i][:N]
        train_x.append(sample[:N-1])
        if MODEL_TYPE == 'b':
            sample_rm = rm_ngrams[i]
            train_rm.append(sample_rm)
        if MODEL_TYPE != 'n':
            # Transformer targets: the input shifted by one position.
            train_y.append(sample[1:])
        else:
            # Ngram target: the single next token.
            train_y.append(sample[N-1])
    print("Converting arrays")
    train_x = np.asarray(train_x)
    train_y = np.asarray(train_y)
    if MODEL_TYPE == 'b':
        train_rm = np.asarray(train_rm, np.int8)
    if MODEL_TYPE != 'n':
        train_x += 1 # x in [0, VOCAB_SIZE] since 0 is for <unk>
        # y in [-1, VOCAB_SIZE-1] with VOCAB_SIZE tokens, one for each vocabulary item, and -1 for <unk>

    print("Saving data")
    fname = {'n': 'inputs/ngram_train.npz',
             't': 'inputs/transformer_train.npz',
             'b': 'inputs/bard_train.npz'
             }[MODEL_TYPE]
    if MODEL_TYPE != 'b':
        np.savez_compressed(fname, x=train_x, y=train_y)
    else:
        np.savez_compressed(fname, x=train_x, rm=train_rm, y=train_y)
    # Persist the vocabulary for inference-time loading.
    np.save('lemmas/lemmas.npy', words[:VOCAB_SIZE])