Xinyoumeng233hu committed
Commit 0833281 · 1 Parent(s): 69c4e15

Create arithmetic.py

Files changed (1)
  1. arithmetic.py +260 -0
arithmetic.py ADDED
@@ -0,0 +1,260 @@
import torch
import torch.nn.functional as F

from utils import limit_past, kl, entropy, bits2int, int2bits, is_sent_finish, num_same_from_beg

def encode_arithmetic(model, enc, message, context, finish_sent=False, device='cuda', temp=1.0, precision=16, topk=50000):
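    """Embed the bit list `message` in text sampled from `model` via arithmetic
    coding over the model's next-token distribution, conditioned on `context`
    (a list of token ids). If `finish_sent` is set, keep sampling greedily after
    the message is exhausted until the sentence ends.

    Returns (generated token ids, avg. NLL per token, avg. KL in bits,
    tokens per message bit).
    """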

    context = torch.tensor(context[-1022:], device=device, dtype=torch.long)

    max_val = 2**precision
    threshold = 2**(-precision)
    cur_interval = [0, max_val] # bottom inclusive, top exclusive

    prev = context
    output = context
    past = None

    total_num = 0
    total_num_for_stats = 0
    total_log_probs = 0
    total_kl = 0 # in bits
    total_entropy_ptau = 0
    total_num_sents = 0

    with torch.no_grad():
        i = 0
        sent_finish = False
        while i < len(message) or (finish_sent and not sent_finish):
            logits, past = model(prev.unsqueeze(0), past=past)
            past = limit_past(past)
            logits[0, -1, -1] = -1e20 # endoftext token can't happen
            logits[0, -1, 628] = -1e20 # 2 newlines token can't happen
            logits, indices = logits[0, -1, :].sort(descending=True)
            logits = logits.double()
            logits_temp = logits / temp
            probs_temp = F.softmax(logits_temp, dim=0)
            log_probs_temp = F.log_softmax(logits_temp, dim=0)
            log_probs = F.log_softmax(logits, dim=0)

            # conditions for having reached the end of the message
            if i >= len(message):
                selection = 0
                sent_finish = is_sent_finish(indices[selection].item(), enc)
            else:
                # Cutoff low probabilities that would be rounded to 0
                cur_int_range = cur_interval[1]-cur_interval[0]
                cur_threshold = 1/cur_int_range
                k = min(max(2, (probs_temp < cur_threshold).nonzero()[0].item()), topk)
                probs_temp_int = probs_temp[:k] # Cutoff all but top k

                # Rescale to correct range
                probs_temp_int = probs_temp_int/probs_temp_int.sum()*cur_int_range

                # Round probabilities to integers given precision
                probs_temp_int = probs_temp_int.round().long()
                cum_probs = probs_temp_int.cumsum(0)

                # Remove any elements from the bottom if rounding caused the total prob to be too large
                overfill_index = (cum_probs > cur_int_range).nonzero()
                if len(overfill_index) > 0:
                    cum_probs = cum_probs[:overfill_index[0]]

                # Add any mass to the top if removing/rounding causes the total prob to be too small
                cum_probs += cur_int_range-cum_probs[-1] # add

                # Get the resulting probabilities
                probs_final = cum_probs.clone()
                probs_final[1:] = cum_probs[1:] - cum_probs[:-1]

                # Convert to position in range
                cum_probs += cur_interval[0]

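                # cum_probs now holds the exclusive upper endpoints of one integer
                # subinterval of [cur_interval[0], cur_interval[1]) per candidate token,
                # with widths proportional to the quantized probabilities; the next token
                # is the one whose subinterval contains the integer formed by the next
                # `precision` message bits.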
                # Get selected index based on binary fraction from message bits
                message_bits = message[i:i+precision]
                if i+precision > len(message):
                    message_bits = message_bits + [0]*(i+precision-len(message))
                message_idx = bits2int(reversed(message_bits))
                selection = (cum_probs > message_idx).nonzero()[0].item()

                # Calculate new range as ints
                new_int_bottom = cum_probs[selection-1] if selection > 0 else cur_interval[0]
                new_int_top = cum_probs[selection]

                # Convert range to bits
                new_int_bottom_bits_inc = list(reversed(int2bits(new_int_bottom, precision)))
                new_int_top_bits_inc = list(reversed(int2bits(new_int_top-1, precision))) # -1 here because upper bound is exclusive

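                # Any leading bits shared by the new bottom and (inclusive) top are now
                # fixed regardless of future tokens, so they count as encoded; the interval
                # is then rescaled by shifting those bits out, restoring full precision
                # for the next step.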
                # Consume most significant bits which are now fixed and update interval
                num_bits_encoded = num_same_from_beg(new_int_bottom_bits_inc, new_int_top_bits_inc)
                i += num_bits_encoded

                new_int_bottom_bits = new_int_bottom_bits_inc[num_bits_encoded:] + [0]*num_bits_encoded
                new_int_top_bits = new_int_top_bits_inc[num_bits_encoded:] + [1]*num_bits_encoded

                cur_interval[0] = bits2int(reversed(new_int_bottom_bits))
                cur_interval[1] = bits2int(reversed(new_int_top_bits))+1 # +1 here because upper bound is exclusive

                # Gather statistics
                total_log_probs += log_probs[selection].item()

                q = probs_final.double()/probs_final.sum()
                logq = q.log()
                total_kl += kl(q, logq, log_probs[:len(q)])
                total_entropy_ptau += entropy(probs_temp, log_probs_temp)
                total_num_for_stats += 1

            # Update history with new token
            prev = indices[selection].view(1)
            output = torch.cat((output, prev))
            total_num += 1
            #print(enc.decode(prev.tolist()), message_bits[:num_bits_encoded])

            # For text->bits->text
            partial = enc.decode(output[len(context):].tolist())
            if '<eos>' in partial:
                break

    avg_NLL = -total_log_probs/total_num_for_stats
    avg_KL = total_kl/total_num_for_stats
    words_per_bit = total_num_for_stats/i
    # avg_Hq = total_entropy_ptau/total_num_for_stats

    return output[len(context):].tolist(), avg_NLL, avg_KL, words_per_bit

def decode_arithmetic(model, enc, text, context, device='cuda', temp=1.0, precision=16, topk=50000):
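    """Recover the hidden bit list from `text` by re-running the same model and
    arithmetic-coding procedure as encode_arithmetic, conditioned on the same
    `context` (a list of token ids). Returns the list of message bits.
    """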
    # inp is a list of token indices
    # context is a list of token indices
    inp = enc.encode(text)
    # common BPE error case: 198, 198 (2 newlines) is interpreted as 628 (2 newlines)
    i = 0
    while i < len(inp):
        if inp[i] == 628:
            inp[i] = 198
            inp[i+1:i+1] = [198]
            i += 2
        else:
            i += 1

    context = torch.tensor(context[-1022:], device=device, dtype=torch.long)

    max_val = 2**precision
    threshold = 2**(-precision)
    cur_interval = [0, max_val] # bottom inclusive, top exclusive

    prev = context
    past = None
    message = []
    with torch.no_grad():
        i = 0
        while i < len(inp):
            logits, past = model(prev.unsqueeze(0), past=past)
            past = limit_past(past)
            logits[0, -1, -1] = -1e10 # endoftext can't happen
            logits[0, -1, 628] = -1e10 # 2 newlines can't happen
            logits, indices = logits[0, -1, :].sort(descending=True)
            logits = logits.double()
            logits_temp = logits / temp
            probs_temp = F.softmax(logits_temp, dim=0)

            # Cutoff low probabilities that would be rounded to 0
            cur_int_range = cur_interval[1]-cur_interval[0]
            cur_threshold = 1/cur_int_range
            k = min(max(2, (probs_temp < cur_threshold).nonzero()[0].item()), topk)
            probs_temp_int = probs_temp[:k] # Cutoff all but top k

            # Rescale to correct range
            probs_temp_int = probs_temp_int/probs_temp_int.sum()*cur_int_range

            # Round probabilities to integers given precision
            probs_temp_int = probs_temp_int.round().long()
            cum_probs = probs_temp_int.cumsum(0)

            # Remove any elements from the bottom if rounding caused the total prob to be too large
            overfill_index = (cum_probs > cur_int_range).nonzero()
            if len(overfill_index) > 0:
                cum_probs = cum_probs[:overfill_index[0]]
                k = overfill_index[0].item()

            # Add any mass to the top if removing/rounding causes the total prob to be too small
            cum_probs += cur_int_range-cum_probs[-1] # add

            # Convert to position in range
            cum_probs += cur_interval[0]

            rank = (indices == inp[i]).nonzero().item()

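            # Re-encoding the generated text with BPE does not always reproduce the exact
            # tokenization the encoder sampled, so the received token can fall outside the
            # top-k set used above. The heuristic below tries to re-align the stream by
            # substituting a prefix token (re-inserting the leftover suffix) or a longer
            # token that spans the next few received tokens.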
            # Handle most errors that could happen because of BPE with heuristic
            if rank >= k:
                true_token_text = enc.decoder[inp[i]]
                for rank_idx in range(k):
                    prop_token_text = enc.decoder[indices[rank_idx].item()]
                    # common case that is not caught
                    if inp[i] == 128 and indices[rank_idx] == 198:
                        rank = rank_idx
                        inp[i] = indices[rank_idx].item()
                        break

                    # Is there a more likely prefix token that could be the actual token generated?
                    if len(prop_token_text) <= len(true_token_text) and \
                            prop_token_text == true_token_text[:len(prop_token_text)]:
                        rank = rank_idx
                        suffix = true_token_text[len(prop_token_text):]
                        suffix_tokens = enc.encode(suffix) # a list
                        inp[i] = indices[rank_idx].item()
                        inp[i+1:i+1] = suffix_tokens # insert suffix tokens into list
                        break

                    # Is there a more likely longer token that could be the actual token generated?
                    elif len(prop_token_text) > len(true_token_text) and \
                            true_token_text == prop_token_text[:len(true_token_text)]:
                        whole_text = true_token_text
                        num_extra = 1
                        while len(whole_text) < len(prop_token_text):
                            whole_text += enc.decoder[inp[i+num_extra]]
                            num_extra += 1
                        if prop_token_text == whole_text[:len(prop_token_text)]:
                            rank = rank_idx
                            inp[i] = indices[rank_idx].item()
                            for j in range(1, num_extra):
                                del inp[i+j]

                            if len(whole_text) > len(prop_token_text):
                                suffix = whole_text[len(prop_token_text):]
                                suffix_tokens = enc.encode(suffix) # a list
                                inp[i+1:i+1] = suffix_tokens # insert suffix tokens into list
                            break
                else:
                    print('Unable to fix BPE error: token received: %s=%d, text: %s' % (true_token_text, inp[i], text))
                    rank = 0

            selection = rank

            # Calculate new range as ints
            new_int_bottom = cum_probs[selection-1] if selection > 0 else cur_interval[0]
            new_int_top = cum_probs[selection]

            # Convert range to bits
            new_int_bottom_bits_inc = list(reversed(int2bits(new_int_bottom, precision)))
            new_int_top_bits_inc = list(reversed(int2bits(new_int_top-1, precision))) # -1 here because upper bound is exclusive

            # Emit most significant bits which are now fixed and update interval
            num_bits_encoded = num_same_from_beg(new_int_bottom_bits_inc, new_int_top_bits_inc)
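            # On the last received token, the full bottom-of-interval bit pattern is
            # emitted rather than only the newly fixed bits, so trailing message bits
            # that never became fixed are still flushed.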
            if i == len(inp)-1:
                new_bits = new_int_bottom_bits_inc
            else:
                new_bits = new_int_top_bits_inc[:num_bits_encoded]
            message += new_bits

            new_int_bottom_bits = new_int_bottom_bits_inc[num_bits_encoded:] + [0]*num_bits_encoded
            new_int_top_bits = new_int_top_bits_inc[num_bits_encoded:] + [1]*num_bits_encoded

            cur_interval[0] = bits2int(reversed(new_int_bottom_bits))
            cur_interval[1] = bits2int(reversed(new_int_top_bits))+1 # +1 here because upper bound is exclusive

            # Update history with new token
            prev = torch.tensor([inp[i]], device=device, dtype=torch.long)
            #print(enc.decode([inp[i]]), new_bits)
            i += 1

    return message
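
# A rough round-trip sketch, kept as a comment: `model` is assumed to be a GPT-2-style
# LM whose forward pass returns (logits, past) and accepts past=..., and `enc` its
# matching BPE encoder with encode/decode/decoder, as expected by the functions above;
# the names and example values are placeholders.
#
#   message_bits = [1, 0, 1, 1, 0, 0, 1, 0]                 # payload to hide
#   context_tokens = enc.encode('Here is some shared context.')
#   cover_tokens, avg_nll, avg_kl, words_per_bit = encode_arithmetic(
#       model, enc, message_bits, context_tokens, device='cuda')
#   cover_text = enc.decode(cover_tokens)
#   recovered_bits = decode_arithmetic(model, enc, cover_text, context_tokens, device='cuda')
#   # recovered_bits should begin with message_bits (the encoder zero-pads the tail)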