Spaces:
Sleeping
Sleeping
Commit
·
c6624a3
1
Parent(s):
0833281
Create block_baseline.py
Browse files- block_baseline.py +195 -0
block_baseline.py
ADDED
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
from utils import kl, entropy, is_sent_finish, limit_past, bits2int, int2bits
|
6 |
+
|
7 |
+
# The vocabulary is split into 2^block_size bins of (roughly) equal size,
# each holding about vocab_size / 2^block_size token ids.
def get_bins(vocab_size, block_size):
    """Deterministically partition token ids into 2**block_size shuffled bins.

    Returns:
        bin2words: list of np.ndarray — the token ids belonging to each bin.
        words2bin: dict mapping each token id back to its bin index.
    """
    num_bins = 2 ** block_size
    words_per_bin = vocab_size / num_bins

    # Seed with block_size so encoder and decoder derive the same partition.
    vocab_ordering = np.arange(vocab_size)
    np.random.seed(block_size)
    np.random.shuffle(vocab_ordering)

    bin2words = []
    words2bin = {}
    for bin_idx in range(num_bins):
        lo = int(bin_idx * words_per_bin)
        hi = int((bin_idx + 1) * words_per_bin)
        members = np.array(vocab_ordering[lo:hi])
        bin2words.append(members)
        for token_id in members:
            words2bin[token_id] = bin_idx

    return bin2words, words2bin
|
25 |
+
|
26 |
+
def encode_block(model, enc, message, context, block_size, bin2words, words2bin, finish_sent=False, device='cuda'):
    """Hide the bit string `message` inside model-generated text (block baseline).

    At every generation step the next `block_size` bits of `message` select one
    of the 2**block_size vocabulary bins, and the highest-probability token of
    that bin is emitted.

    Args:
        model: language model; called as model(ids, past=past) and expected to
            return (logits, past) — presumably a GPT-2-style LM (TODO confirm).
        enc: tokenizer, used only for the sentence-finish check.
        message: sequence of bits (0/1) to embed.
        context: list of token ids used to prime the model (last 1022 kept).
        block_size: bits consumed per generated token.
        bin2words: bin index -> array of token ids (from get_bins).
        words2bin: unused here; kept for signature symmetry with decode_block.
        finish_sent: if True, keep generating greedily past the message until
            a sentence-ending token is produced.
        device: torch device for the context tensor.

    Returns:
        (generated token ids, avg negative log-likelihood per embedding step,
         avg KL between the binned distribution q and the model distribution,
         words_per_bit = embedding steps / bits consumed).
    """
    length = len(message)

    # Truncate to 1022 so one generated token still fits a 1024-position model
    # (presumably GPT-2's context limit — TODO confirm).
    context = torch.tensor(context[-1022:], device=device, dtype=torch.long)

    prev = context
    output = context
    past = None

    total_num = 0
    total_num_for_stats = 0
    total_log_probs = 0
    total_kl = 0 # in bits
    total_num_sents = 0  # (never updated in this baseline)

    with torch.no_grad():
        i = 0  # index into `message`, advances block_size bits per step
        sent_finish = False
        while i < length or (finish_sent and not sent_finish):
            logits, past = model(prev.unsqueeze(0), past=past)
            past = limit_past(past)
            logits[0, -1, -1] = -1e10 # endoftext can't happen
            logits[0, -1, 628] = -1e10 # 2 newlines can't happen
            logits = logits[0, -1, :]
            log_probs = F.log_softmax(logits, dim=-1)

            filtered_logits = logits.clone()
            filtered_logits[:] = -1e10 # mask everything (≈ probability 0)

            if i >= length:
                # Message exhausted: generate greedily until a sentence ends.
                _, indices = logits.sort(descending=True)
                sent_finish = is_sent_finish(indices[0].item(), enc)
            else:
                # First build logq: the effective encoding distribution, which
                # puts mass 2**-block_size on each bin's argmax token.
                logq = logits.clone()
                logq[:] = -1e10 # mask everything (≈ probability 0)

                for bin_val in range(2**block_size):
                    filtered_logits = logits.clone()
                    filtered_logits[:] = -1e10 # mask everything (≈ probability 0)
                    available_tokens = bin2words[bin_val]
                    filtered_logits[available_tokens] = logits[available_tokens]
                    filtered_logits, indices = filtered_logits.sort(descending=True)

                    logq[indices[0]] = -block_size # in bits

                logq = logq*0.69315 # bits -> nats (multiply by ln 2)
                q = torch.exp(logq)

                # Then find the actual word for the right bin
                m_part = message[i:i+block_size]

                filtered_logits = logits.clone()
                filtered_logits[:] = -1e10 # mask everything (≈ probability 0)
                available_tokens = bin2words[bits2int(m_part)]
                filtered_logits[available_tokens] = logits[available_tokens]
                filtered_logits, indices = filtered_logits.sort(descending=True)

                total_kl += kl(q, logq, log_probs)
                total_log_probs += log_probs[indices[0]].item()
                i += block_size
                total_num_for_stats += 1

            # indices[0] is the chosen token: the selected bin's argmax while
            # embedding, or the global argmax in the finish_sent tail.
            total_num += 1
            prev = indices[0].view(1)
            output = torch.cat((output, prev))

    avg_NLL = -total_log_probs/total_num_for_stats
    avg_KL = total_kl/total_num_for_stats
    words_per_bit = total_num_for_stats/i

    return output[len(context):].tolist(), avg_NLL, avg_KL, words_per_bit
|
98 |
+
|
99 |
+
def decode_block(model, enc, text, context, block_size, bin2words, words2bin, device='cuda'):
    """Recover the hidden bit string from stegotext produced by encode_block.

    Re-tokenizes `text`, maps each token to its bin, and emits that bin index
    as `block_size` bits. When re-tokenization disagrees with what the encoder
    must have generated (a BPE merge/split mismatch), tries to repair `inp`
    in place by splitting or merging tokens.

    Args:
        model: language model; called as model(ids, past=past) returning
            (logits, past).
        enc: tokenizer with .encode(str) -> list[int] and .decoder[id] -> str.
        text: the stegotext string.
        context: list of token ids; must match the encoder's context.
        block_size: bits recovered per token.
        bin2words / words2bin: the same bin partition used by the encoder.
        device: torch device for tensors.

    Returns:
        list of recovered bits (0/1).
    """
    # inp is a list of token indices
    # context is a list of token indices
    inp = enc.encode(text)
    # Split every double-newline token (628) into two single newlines (198),
    # since the encoder forbade 628 from ever being generated.
    i = 0
    while i < len(inp):
        if inp[i] == 628:
            inp[i] = 198
            inp[i+1:i+1] = [198]
            i += 2
        else:
            i += 1

    context = torch.tensor(context[-1022:], device=device, dtype=torch.long)
    prev = context
    past = None

    message = []
    with torch.no_grad():
        i = 0  # index into `inp`
        while i < len(inp):
            # Guard against exceeding the model's position limit.
            if past and past[0].shape[3] >= 1023:
                raise RuntimeError
            bin_num = words2bin[inp[i]]

            logits, past = model(prev.unsqueeze(0), past=past)
            past = limit_past(past)
            logits[0, -1, -1] = -1e10 # endoftext can't happen
            logits[0, -1, 628] = -1e10 # 2 newlines can't happen

            logits = logits[0, -1, :]
            filtered_logits = logits.clone()
            filtered_logits[:] = -1e10 # mask everything (≈ probability 0)

            available_tokens = bin2words[bin_num]
            filtered_logits[available_tokens] = logits[available_tokens]
            filtered_logits, indices = filtered_logits.sort(descending=True)

            # rank 0 means this token is exactly what the encoder would have
            # emitted for its bin; anything else signals a BPE mismatch.
            rank = (indices == inp[i]).nonzero().item()

            # Handle errors that could happen because of BPE
            if rank > 0:
                true_token_text = enc.decoder[inp[i]]
                # Try every bin's argmax as the token the encoder really produced.
                # NOTE(review): this loop reuses the name bin_num, so after a
                # repair the bits emitted below come from the repaired bin —
                # presumably intentional; verify against encode_block.
                for bin_num in range(len(bin2words)):
                    filtered_logits = logits.clone()
                    filtered_logits[:] = -1e10 # mask everything (≈ probability 0)

                    available_tokens = bin2words[bin_num]
                    filtered_logits[available_tokens] = logits[available_tokens]
                    filtered_logits, indices = filtered_logits.sort(descending=True)

                    prop_token_text = enc.decoder[indices[0].item()]

                    # Is there a more likely prefix token that could be the actual token generated?
                    if len(prop_token_text) < len(true_token_text) and \
                            prop_token_text == true_token_text[:len(prop_token_text)]:
                        # Replace inp[i] with the shorter token and re-insert
                        # the leftover text as fresh tokens after it.
                        suffix = true_token_text[len(prop_token_text):]
                        suffix_tokens = enc.encode(suffix) # a list
                        inp[i] = indices[0].item()
                        inp[i+1:i+1] = suffix_tokens # insert suffix tokens into list
                        break

                    # Is there a more likely longer token that could be the actual token generated?
                    elif len(prop_token_text) > len(true_token_text) and \
                            true_token_text == prop_token_text[:len(true_token_text)]:
                        # Absorb following tokens until their concatenation
                        # covers the proposed longer token's text.
                        whole_text = true_token_text
                        num_extra = 1
                        while len(whole_text) < len(prop_token_text):
                            whole_text += enc.decoder[inp[i+num_extra]]
                            num_extra += 1
                        if prop_token_text == whole_text[:len(prop_token_text)]:
                            inp[i] = indices[0].item()
                            for j in range(1, num_extra):
                                del inp[i+j]

                            if len(whole_text) > len(prop_token_text):
                                # Overshot: re-insert the trailing text.
                                suffix = whole_text[len(prop_token_text):]
                                suffix_tokens = enc.encode(suffix) # a list
                                inp[i+1:i+1] = suffix_tokens # insert suffix tokens into list
                            break
                else:
                    # for-else: no bin produced a compatible repair.
                    print('Unable to fix BPE error: token received: %s=%d, text: %s' % (true_token_text, inp[i], text))

            # Emit the (possibly repaired) bin index as block_size bits.
            tokens_t = int2bits(bin_num, block_size)

            message.extend(tokens_t)
            prev = torch.tensor([inp[i]], device=device, dtype=torch.long)
            i += 1

    return message
|
190 |
+
|
191 |
+
if __name__ == '__main__':
    # Smoke test: partition a 50257-token vocabulary (presumably GPT-2's —
    # TODO confirm) into 2**5 = 32 bins and show where token id 153 lands.
    # NOTE(review): get_bins reseeds with np.random.seed(block_size), so this
    # seed(123) call does not influence the printed bin assignment.
    np.random.seed(123)

    bin2words, words2bin = get_bins(50257, 5)
    print(words2bin[153])
|