Spaces:
Sleeping
Sleeping
Commit
·
0833281
1
Parent(s):
69c4e15
Create arithmetic.py
Browse files- arithmetic.py +260 -0
arithmetic.py
ADDED
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
|
4 |
+
from utils import limit_past, kl, entropy, bits2int, int2bits, is_sent_finish, num_same_from_beg
|
5 |
+
|
6 |
+
def encode_arithmetic(model, enc, message, context, finish_sent=False, device='cuda', temp=1.0, precision=16, topk=50000):
|
7 |
+
|
8 |
+
context = torch.tensor(context[-1022:], device=device, dtype=torch.long)
|
9 |
+
|
10 |
+
max_val = 2**precision
|
11 |
+
threshold = 2**(-precision)
|
12 |
+
cur_interval = [0, max_val] # bottom inclusive, top exclusive
|
13 |
+
|
14 |
+
prev = context
|
15 |
+
output = context
|
16 |
+
past = None
|
17 |
+
|
18 |
+
total_num = 0
|
19 |
+
total_num_for_stats = 0
|
20 |
+
total_log_probs = 0
|
21 |
+
total_kl = 0 # in bits
|
22 |
+
total_entropy_ptau = 0
|
23 |
+
total_num_sents = 0
|
24 |
+
|
25 |
+
with torch.no_grad():
|
26 |
+
i = 0
|
27 |
+
sent_finish = False
|
28 |
+
while i < len(message) or (finish_sent and not sent_finish):
|
29 |
+
logits, past = model(prev.unsqueeze(0), past=past)
|
30 |
+
past = limit_past(past)
|
31 |
+
logits[0, -1, -1] = -1e20 # endoftext token can't happen
|
32 |
+
logits[0, -1, 628] = -1e20 # 2 newlines token can't happen
|
33 |
+
logits, indices = logits[0, -1, :].sort(descending=True)
|
34 |
+
logits = logits.double()
|
35 |
+
logits_temp = logits / temp
|
36 |
+
probs_temp = F.softmax(logits_temp, dim=0)
|
37 |
+
log_probs_temp = F.log_softmax(logits_temp, dim=0)
|
38 |
+
log_probs = F.log_softmax(logits, dim=0)
|
39 |
+
|
40 |
+
# conditions for having reached the end of the message
|
41 |
+
if i >= len(message):
|
42 |
+
selection = 0
|
43 |
+
sent_finish = is_sent_finish(indices[selection].item(), enc)
|
44 |
+
else:
|
45 |
+
# Cutoff low probabilities that would be rounded to 0
|
46 |
+
cur_int_range = cur_interval[1]-cur_interval[0]
|
47 |
+
cur_threshold = 1/cur_int_range
|
48 |
+
k = min(max(2, (probs_temp < cur_threshold).nonzero()[0].item()), topk)
|
49 |
+
probs_temp_int = probs_temp[:k] # Cutoff all but top k
|
50 |
+
|
51 |
+
# Rescale to correct range
|
52 |
+
probs_temp_int = probs_temp_int/probs_temp_int.sum()*cur_int_range
|
53 |
+
|
54 |
+
# Round probabilities to integers given precision
|
55 |
+
probs_temp_int = probs_temp_int.round().long()
|
56 |
+
cum_probs = probs_temp_int.cumsum(0)
|
57 |
+
|
58 |
+
# Remove any elements from the bottom if rounding caused the total prob to be too large
|
59 |
+
overfill_index = (cum_probs > cur_int_range).nonzero()
|
60 |
+
if len(overfill_index) > 0:
|
61 |
+
cum_probs = cum_probs[:overfill_index[0]]
|
62 |
+
|
63 |
+
# Add any mass to the top if removing/rounding causes the total prob to be too small
|
64 |
+
cum_probs += cur_int_range-cum_probs[-1] # add
|
65 |
+
|
66 |
+
# Get out resulting probabilities
|
67 |
+
probs_final = cum_probs.clone()
|
68 |
+
probs_final[1:] = cum_probs[1:] - cum_probs[:-1]
|
69 |
+
|
70 |
+
# Convert to position in range
|
71 |
+
cum_probs += cur_interval[0]
|
72 |
+
|
73 |
+
# Get selected index based on binary fraction from message bits
|
74 |
+
message_bits = message[i:i+precision]
|
75 |
+
if i+precision > len(message):
|
76 |
+
message_bits = message_bits + [0]*(i+precision-len(message))
|
77 |
+
message_idx = bits2int(reversed(message_bits))
|
78 |
+
selection = (cum_probs > message_idx).nonzero()[0].item()
|
79 |
+
|
80 |
+
# Calculate new range as ints
|
81 |
+
new_int_bottom = cum_probs[selection-1] if selection > 0 else cur_interval[0]
|
82 |
+
new_int_top = cum_probs[selection]
|
83 |
+
|
84 |
+
# Convert range to bits
|
85 |
+
new_int_bottom_bits_inc = list(reversed(int2bits(new_int_bottom, precision)))
|
86 |
+
new_int_top_bits_inc = list(reversed(int2bits(new_int_top-1, precision))) # -1 here because upper bound is exclusive
|
87 |
+
|
88 |
+
# Consume most significant bits which are now fixed and update interval
|
89 |
+
num_bits_encoded = num_same_from_beg(new_int_bottom_bits_inc, new_int_top_bits_inc)
|
90 |
+
i += num_bits_encoded
|
91 |
+
|
92 |
+
new_int_bottom_bits = new_int_bottom_bits_inc[num_bits_encoded:] + [0]*num_bits_encoded
|
93 |
+
new_int_top_bits = new_int_top_bits_inc[num_bits_encoded:] + [1]*num_bits_encoded
|
94 |
+
|
95 |
+
cur_interval[0] = bits2int(reversed(new_int_bottom_bits))
|
96 |
+
cur_interval[1] = bits2int(reversed(new_int_top_bits))+1 # +1 here because upper bound is exclusive
|
97 |
+
|
98 |
+
# Gather statistics
|
99 |
+
total_log_probs += log_probs[selection].item()
|
100 |
+
|
101 |
+
q = probs_final.double()/probs_final.sum()
|
102 |
+
logq = q.log()
|
103 |
+
total_kl += kl(q, logq, log_probs[:len(q)])
|
104 |
+
total_entropy_ptau += entropy(probs_temp, log_probs_temp)
|
105 |
+
total_num_for_stats += 1
|
106 |
+
|
107 |
+
# Update history with new token
|
108 |
+
prev = indices[selection].view(1)
|
109 |
+
output = torch.cat((output, prev))
|
110 |
+
total_num += 1
|
111 |
+
#print(enc.decode(prev.tolist()), message_bits[:num_bits_encoded])
|
112 |
+
|
113 |
+
# For text->bits->text
|
114 |
+
partial = enc.decode(output[len(context):].tolist())
|
115 |
+
if '<eos>' in partial:
|
116 |
+
break
|
117 |
+
|
118 |
+
avg_NLL = -total_log_probs/total_num_for_stats
|
119 |
+
avg_KL = total_kl/total_num_for_stats
|
120 |
+
words_per_bit = total_num_for_stats/i
|
121 |
+
# avg_Hq = total_entropy_ptau/total_num_for_stats
|
122 |
+
|
123 |
+
return output[len(context):].tolist(), avg_NLL, avg_KL, words_per_bit
|
124 |
+
|
125 |
+
def decode_arithmetic(model, enc, text, context, device='cuda', temp=1.0, precision=16, topk=50000):
|
126 |
+
# inp is a list of token indices
|
127 |
+
# context is a list of token indices
|
128 |
+
inp = enc.encode(text)
|
129 |
+
# common BPE error case: 128, 128 (2 newlines) is interpretted as 628 (2 newlines)
|
130 |
+
i = 0
|
131 |
+
while i < len(inp):
|
132 |
+
if inp[i] == 628:
|
133 |
+
inp[i] = 198
|
134 |
+
inp[i+1:i+1] = [198]
|
135 |
+
i += 2
|
136 |
+
else:
|
137 |
+
i += 1
|
138 |
+
|
139 |
+
context = torch.tensor(context[-1022:], device=device, dtype=torch.long)
|
140 |
+
|
141 |
+
max_val = 2**precision
|
142 |
+
threshold = 2**(-precision)
|
143 |
+
cur_interval = [0, max_val] # bottom inclusive, top exclusive
|
144 |
+
|
145 |
+
prev = context
|
146 |
+
past = None
|
147 |
+
message = []
|
148 |
+
with torch.no_grad():
|
149 |
+
i = 0
|
150 |
+
while i < len(inp):
|
151 |
+
logits, past = model(prev.unsqueeze(0), past=past)
|
152 |
+
past = limit_past(past)
|
153 |
+
logits[0, -1, -1] = -1e10 # endoftext can't happen
|
154 |
+
logits[0, -1, 628] = -1e10 # 2 newlines can't happen
|
155 |
+
logits, indices = logits[0, -1, :].sort(descending=True)
|
156 |
+
logits = logits.double()
|
157 |
+
logits_temp = logits / temp
|
158 |
+
probs_temp = F.softmax(logits_temp, dim=0)
|
159 |
+
|
160 |
+
# Cutoff low probabilities that would be rounded to 0
|
161 |
+
cur_int_range = cur_interval[1]-cur_interval[0]
|
162 |
+
cur_threshold = 1/cur_int_range
|
163 |
+
k = min(max(2, (probs_temp < cur_threshold).nonzero()[0].item()), topk)
|
164 |
+
probs_temp_int = probs_temp[:k] # Cutoff all but top k
|
165 |
+
|
166 |
+
# Rescale to correct range
|
167 |
+
probs_temp_int = probs_temp_int/probs_temp_int.sum()*cur_int_range
|
168 |
+
|
169 |
+
# Round probabilities to integers given precision
|
170 |
+
probs_temp_int = probs_temp_int.round().long()
|
171 |
+
cum_probs = probs_temp_int.cumsum(0)
|
172 |
+
|
173 |
+
# Remove any elements from the bottom if rounding caused the total prob to be too large
|
174 |
+
overfill_index = (cum_probs > cur_int_range).nonzero()
|
175 |
+
if len(overfill_index) > 0:
|
176 |
+
cum_probs = cum_probs[:overfill_index[0]]
|
177 |
+
k = overfill_index[0].item()
|
178 |
+
|
179 |
+
# Add any mass to the top if removing/rounding causes the total prob to be too small
|
180 |
+
cum_probs += cur_int_range-cum_probs[-1] # add
|
181 |
+
|
182 |
+
# Covnert to position in range
|
183 |
+
cum_probs += cur_interval[0]
|
184 |
+
|
185 |
+
rank = (indices == inp[i]).nonzero().item()
|
186 |
+
|
187 |
+
# Handle most errors that could happen because of BPE with heuristic
|
188 |
+
if rank >= k:
|
189 |
+
true_token_text = enc.decoder[inp[i]]
|
190 |
+
for rank_idx in range(k):
|
191 |
+
prop_token_text = enc.decoder[indices[rank_idx].item()]
|
192 |
+
# common case that is not caught
|
193 |
+
if inp[i] == 128 and indices[rank_idx] == 198:
|
194 |
+
rank = rank_idx
|
195 |
+
inp[i] = indices[rank_idx].item()
|
196 |
+
break
|
197 |
+
|
198 |
+
# Is there a more likely prefix token that could be the actual token generated?
|
199 |
+
if len(prop_token_text) <= len(true_token_text) and \
|
200 |
+
prop_token_text == true_token_text[:len(prop_token_text)]:
|
201 |
+
rank = rank_idx
|
202 |
+
suffix = true_token_text[len(prop_token_text):]
|
203 |
+
suffix_tokens = enc.encode(suffix) # a list
|
204 |
+
inp[i] = indices[rank_idx].item()
|
205 |
+
inp[i+1:i+1] = suffix_tokens # insert suffix tokens into list
|
206 |
+
break
|
207 |
+
|
208 |
+
# Is there a more likely longer token that could be the actual token generated?
|
209 |
+
elif len(prop_token_text) > len(true_token_text) and \
|
210 |
+
true_token_text == prop_token_text[:len(true_token_text)]:
|
211 |
+
whole_text = true_token_text
|
212 |
+
num_extra = 1
|
213 |
+
while len(whole_text) < len(prop_token_text):
|
214 |
+
whole_text += enc.decoder[inp[i+num_extra]]
|
215 |
+
num_extra += 1
|
216 |
+
if prop_token_text == whole_text[:len(prop_token_text)]:
|
217 |
+
rank = rank_idx
|
218 |
+
inp[i] = indices[rank_idx].item()
|
219 |
+
for j in range(1, num_extra):
|
220 |
+
del inp[i+j]
|
221 |
+
|
222 |
+
if len(whole_text) > len(prop_token_text):
|
223 |
+
suffix = whole_text[len(prop_token_text):]
|
224 |
+
suffix_tokens = enc.encode(suffix) # a list
|
225 |
+
inp[i+1:i+1] = suffix_tokens # insert suffix tokens into list
|
226 |
+
break
|
227 |
+
else:
|
228 |
+
print('Unable to fix BPE error: token received: %s=%d, text: %s' % (true_token_text, inp[i], text))
|
229 |
+
rank = 0
|
230 |
+
|
231 |
+
selection = rank
|
232 |
+
|
233 |
+
# Calculate new range as ints
|
234 |
+
new_int_bottom = cum_probs[selection-1] if selection > 0 else cur_interval[0]
|
235 |
+
new_int_top = cum_probs[selection]
|
236 |
+
|
237 |
+
# Convert range to bits
|
238 |
+
new_int_bottom_bits_inc = list(reversed(int2bits(new_int_bottom, precision)))
|
239 |
+
new_int_top_bits_inc = list(reversed(int2bits(new_int_top-1, precision))) # -1 here because upper bound is exclusive
|
240 |
+
|
241 |
+
# Emit most significant bits which are now fixed and update interval
|
242 |
+
num_bits_encoded = num_same_from_beg(new_int_bottom_bits_inc, new_int_top_bits_inc)
|
243 |
+
if i == len(inp)-1:
|
244 |
+
new_bits = new_int_bottom_bits_inc
|
245 |
+
else:
|
246 |
+
new_bits = new_int_top_bits_inc[:num_bits_encoded]
|
247 |
+
message += new_bits
|
248 |
+
|
249 |
+
new_int_bottom_bits = new_int_bottom_bits_inc[num_bits_encoded:] + [0]*num_bits_encoded
|
250 |
+
new_int_top_bits = new_int_top_bits_inc[num_bits_encoded:] + [1]*num_bits_encoded
|
251 |
+
|
252 |
+
cur_interval[0] = bits2int(reversed(new_int_bottom_bits))
|
253 |
+
cur_interval[1] = bits2int(reversed(new_int_top_bits))+1 # +1 here because upper bound is exclusive
|
254 |
+
|
255 |
+
# Update history with new token
|
256 |
+
prev = torch.tensor([inp[i]], device=device, dtype=torch.long)
|
257 |
+
#print(enc.decode([inp[i]]), new_bits)
|
258 |
+
i += 1
|
259 |
+
|
260 |
+
return message
|