Xinyoumeng233hu committed
Commit c6624a3 · 1 Parent(s): 0833281

Create block_baseline.py

Files changed (1)
  1. block_baseline.py +195 -0
block_baseline.py ADDED
@@ -0,0 +1,195 @@
import torch
import torch.nn.functional as F

import numpy as np
from utils import kl, entropy, is_sent_finish, limit_past, bits2int, int2bits

# The number of bins is 2^block_size;
# each bin contains vocab_size / 2^block_size words.
def get_bins(vocab_size, block_size):
    num_bins = 2**block_size
    words_per_bin = vocab_size/num_bins

    # Shuffle the vocabulary with a seed derived from block_size,
    # so that encoder and decoder build identical bins
    vocab_ordering = np.arange(vocab_size)
    np.random.seed(block_size)
    np.random.shuffle(vocab_ordering)

    # bin2words maps a bin index to its array of token ids;
    # words2bin is the inverse map from token id to bin index
    bin2words = [vocab_ordering[int(i*words_per_bin):int((i+1)*words_per_bin)] for i in range(num_bins)]
    bin2words = [np.array(words) for words in bin2words]
    words2bin_list = [{i: j for i in bin2words[j]} for j in range(num_bins)]
    words2bin = {}
    for d in words2bin_list:
        words2bin.update(d)

    return bin2words, words2bin
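For concreteness: with the GPT-2 vocabulary (50257 tokens) and block_size=5, the values used in __main__ below, this yields 2^5 = 32 bins of 1570 or 1571 tokens each, and the two maps are mutual inverses. A quick sanity check using only the function above:

# Illustrative check of the partition produced by get_bins
bin2words, words2bin = get_bins(50257, 5)
assert len(bin2words) == 32
assert sum(len(b) for b in bin2words) == 50257   # every token lands in exactly one bin
assert all(words2bin[w] == b for b in range(32) for w in bin2words[b])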

def encode_block(model, enc, message, context, block_size, bin2words, words2bin, finish_sent=False, device='cuda'):
    length = len(message)

    context = torch.tensor(context[-1022:], device=device, dtype=torch.long)

    prev = context
    output = context
    past = None

    total_num = 0
    total_num_for_stats = 0
    total_log_probs = 0
    total_kl = 0 # in bits
    total_num_sents = 0

    with torch.no_grad():
        i = 0
        sent_finish = False
        while i < length or (finish_sent and not sent_finish):
            logits, past = model(prev.unsqueeze(0), past=past)
            past = limit_past(past)
            logits[0, -1, -1] = -1e10 # endoftext token can't be generated
            logits[0, -1, 628] = -1e10 # double-newline token can't be generated
            logits = logits[0, -1, :]
            log_probs = F.log_softmax(logits, dim=-1)

            if i >= length:
                # Message is done; keep generating until the sentence finishes
                _, indices = logits.sort(descending=True)
                sent_finish = is_sent_finish(indices[0].item(), enc)
            else:
                # First calculate logq, the distribution actually sampled from:
                # each bin is chosen uniformly (probability 2^-block_size) and
                # its most likely token is emitted deterministically
                logq = logits.clone()
                logq[:] = -1e10 # mask everything (log prob ~ -inf)

                for bin_val in range(2**block_size):
                    filtered_logits = logits.clone()
                    filtered_logits[:] = -1e10 # mask everything (log prob ~ -inf)
                    available_tokens = bin2words[bin_val]
                    filtered_logits[available_tokens] = logits[available_tokens]
                    filtered_logits, indices = filtered_logits.sort(descending=True)

                    logq[indices[0]] = -block_size # in bits

                logq = logq*0.69315 # convert bits to nats (multiply by ln 2)
                q = torch.exp(logq)

                # Then find the actual word for the right bin
                m_part = message[i:i+block_size]

                filtered_logits = logits.clone()
                filtered_logits[:] = -1e10 # mask everything (log prob ~ -inf)
                available_tokens = bin2words[bits2int(m_part)]
                filtered_logits[available_tokens] = logits[available_tokens]
                filtered_logits, indices = filtered_logits.sort(descending=True)

                total_kl += kl(q, logq, log_probs)
                total_log_probs += log_probs[indices[0]].item()
                i += block_size
                total_num_for_stats += 1

            total_num += 1
            prev = indices[0].view(1)
            output = torch.cat((output, prev))

    avg_NLL = -total_log_probs/total_num_for_stats
    avg_KL = total_kl/total_num_for_stats
    words_per_bit = total_num_for_stats/i

    return output[len(context):].tolist(), avg_NLL, avg_KL, words_per_bit
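Each generated token hides exactly block_size message bits: the next block of bits is read as a bin index via bits2int, and the most likely token within that bin is emitted. A minimal usage sketch; `model` and `enc` stand for the GPT-2 language model and BPE encoder this repo is assumed to load elsewhere (they are not defined in this file), with `model` returning (logits, past) as in the pytorch-transformers-era API used above:

# Hypothetical setup: model and enc come from the rest of the repo
message_bits = [0, 1, 1, 0, 1, 1, 0, 0, 1, 0]   # 10 bits -> 2 tokens at block_size=5
context_tokens = enc.encode('Despite a slow start to the day, ')
bin2words, words2bin = get_bins(50257, 5)
out_tokens, avg_nll, avg_kl, words_per_bit = encode_block(
    model, enc, message_bits, context_tokens, 5, bin2words, words2bin)
cover_text = enc.decode(out_tokens)   # assumes enc.decode, the inverse of enc.encode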

def decode_block(model, enc, text, context, block_size, bin2words, words2bin, device='cuda'):
    # inp is a list of token indices
    # context is a list of token indices
    inp = enc.encode(text)
    # Split double-newline tokens (628) into two single newlines (198):
    # encoding masks 628, but re-tokenizing the cover text can merge
    # two consecutive newlines back into it
    i = 0
    while i < len(inp):
        if inp[i] == 628:
            inp[i] = 198
            inp[i+1:i+1] = [198]
            i += 2
        else:
            i += 1

    context = torch.tensor(context[-1022:], device=device, dtype=torch.long)
    prev = context
    past = None

    message = []
    with torch.no_grad():
        i = 0
        while i < len(inp):
            if past and past[0].shape[3] >= 1023:
                raise RuntimeError('past exceeds the maximum context length')
            bin_num = words2bin[inp[i]]

            logits, past = model(prev.unsqueeze(0), past=past)
            past = limit_past(past)
            logits[0, -1, -1] = -1e10 # endoftext token can't be generated
            logits[0, -1, 628] = -1e10 # double-newline token can't be generated

            logits = logits[0, -1, :]
            filtered_logits = logits.clone()
            filtered_logits[:] = -1e10 # mask everything (log prob ~ -inf)

            available_tokens = bin2words[bin_num]
            filtered_logits[available_tokens] = logits[available_tokens]
            filtered_logits, indices = filtered_logits.sort(descending=True)

            rank = (indices == inp[i]).nonzero().item()

            # Handle errors that could happen because of BPE:
            # the received token should be the top-ranked token of its bin;
            # if it isn't, re-tokenization mangled a token boundary
            if rank > 0:
                true_token_text = enc.decoder[inp[i]]
                for bin_num in range(len(bin2words)):
                    filtered_logits = logits.clone()
                    filtered_logits[:] = -1e10 # mask everything (log prob ~ -inf)

                    available_tokens = bin2words[bin_num]
                    filtered_logits[available_tokens] = logits[available_tokens]
                    filtered_logits, indices = filtered_logits.sort(descending=True)

                    prop_token_text = enc.decoder[indices[0].item()]
                    #print(true_token_text, prop_token_text)

                    # Is there a more likely prefix token that could be the actual token generated?
                    if len(prop_token_text) < len(true_token_text) and \
                            prop_token_text == true_token_text[:len(prop_token_text)]:
                        suffix = true_token_text[len(prop_token_text):]
                        suffix_tokens = enc.encode(suffix) # a list
                        inp[i] = indices[0].item()
                        inp[i+1:i+1] = suffix_tokens # insert suffix tokens into list
                        break

                    # Is there a more likely longer token that could be the actual token generated?
                    elif len(prop_token_text) > len(true_token_text) and \
                            true_token_text == prop_token_text[:len(true_token_text)]:
                        whole_text = true_token_text
                        num_extra = 1
                        while len(whole_text) < len(prop_token_text):
                            whole_text += enc.decoder[inp[i+num_extra]]
                            num_extra += 1
                        if prop_token_text == whole_text[:len(prop_token_text)]:
                            inp[i] = indices[0].item()
                            del inp[i+1:i+num_extra] # drop the tokens absorbed into the proposed token

                            if len(whole_text) > len(prop_token_text):
                                suffix = whole_text[len(prop_token_text):]
                                suffix_tokens = enc.encode(suffix) # a list
                                inp[i+1:i+1] = suffix_tokens # insert suffix tokens into list
                            break
                        else:
                            print('Unable to fix BPE error: token received: %s=%d, text: %s' % (true_token_text, inp[i], text))

            # Emit the block_size bits encoded by this token's bin
            tokens_t = int2bits(bin_num, block_size)

            message.extend(tokens_t)
            prev = torch.tensor([inp[i]], device=device, dtype=torch.long)
            i += 1

    return message
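Decoding mirrors encoding: re-tokenize the cover text, look up each token's bin, and emit the bin index as block_size bits (the model is re-run only so BPE mismatches can be detected and repaired against each bin's argmax). Continuing the hypothetical sketch above, and assuming bits2int/int2bits in utils are inverse maps, a roundtrip should recover the original bits:

recovered_bits = decode_block(model, enc, cover_text, context_tokens, 5, bin2words, words2bin)
assert recovered_bits[:len(message_bits)] == message_bits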

if __name__ == '__main__':
    np.random.seed(123)

    # 50257 = GPT-2 vocabulary size; 5 message bits per generated token
    bin2words, words2bin = get_bins(50257, 5)
    print(words2bin[153])