import torch
import torch.nn.functional as F

import numpy as np
from utils import kl, entropy, is_sent_finish, limit_past, bits2int, int2bits
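
# Block ("bin") steganography baseline: the GPT-2 vocabulary is shuffled and split
# into 2**block_size equal bins. To hide block_size message bits, the encoder
# interprets them as a bin index and emits the most likely token from that bin;
# the decoder maps each received token back to its bin to read the bits off again.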

# number of bins is 2^block_size
# each bin contains vocab_size/2^block_size words
def get_bins(vocab_size, block_size):
    num_bins = 2**block_size
    words_per_bin = vocab_size/num_bins 

    vocab_ordering = np.arange(vocab_size)
    np.random.seed(block_size)
    np.random.shuffle(vocab_ordering)

    bin2words = [vocab_ordering[int(i*words_per_bin):int((i+1)*words_per_bin)] for i in range(num_bins)]
    bin2words = [np.array(words) for words in bin2words]
    words2bin_list = [{i: j for i in bin2words[j]} for j in range(num_bins)]
    words2bin = {}
    for d in words2bin_list:
        words2bin.update(d)

    return bin2words, words2bin

def encode_block(model, enc, message, context, block_size, bin2words, words2bin, finish_sent=False, device='cpu'):
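    """Hide the message bits in text generated greedily from `context` (token ids).

    Each emitted token encodes block_size bits: the bits select one of the
    2**block_size bins and that bin's most likely token is generated. If
    finish_sent is set, generation continues past the message until the
    sentence ends. Returns (stego token ids, avg NLL, avg KL in bits,
    words per bit).
    """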
    length = len(message)

    context = torch.tensor(context[-1022:], device=device, dtype=torch.long)
    
    prev = context
    output = context
    past = None
    
    total_num = 0
    total_num_for_stats = 0
    total_log_probs = 0
    total_kl = 0 # in bits
    total_num_sents = 0
    
    # Greedy, bin-constrained generation: each iteration reads block_size message
    # bits, restricts the logits to the corresponding bin of tokens, and emits
    # that bin's most likely token.
    with torch.no_grad():
        i = 0
        sent_finish = False
        while i < length or (finish_sent and not sent_finish):
            logits, past = model(prev.unsqueeze(0), past=past)
            past = limit_past(past)
            logits[0, -1, -1] = -1e10 # forbid <|endoftext|>
            logits[0, -1, 628] = -1e10 # forbid the double-newline token
            logits = logits[0, -1, :]
            log_probs = F.log_softmax(logits, dim=-1)
            
            filtered_logits = logits.clone()
            filtered_logits[:] = -1e10 # mask every token (probability ~0)

            if i >= length:
                # Message exhausted: with finish_sent, keep generating greedily,
                # without bin constraints, until the current sentence ends.
                _, indices = logits.sort(descending=True)
                sent_finish = is_sent_finish(indices[0].item(), enc)
            else:
                # First build q, the distribution the scheme actually samples from:
                # each bin's most likely token gets probability 2**-block_size.
                # q is only used for the kl() statistic against the model distribution.
                logq = logits.clone()
                logq[:] = -1e10 # mask every token (probability ~0)

                for bin_val in range(2**block_size):
                    filtered_logits = logits.clone()
                    filtered_logits[:] = -1e10 # mask every token (probability ~0)
                    available_tokens = bin2words[bin_val]
                    filtered_logits[available_tokens] = logits[available_tokens]
                    filtered_logits, indices = filtered_logits.sort(descending=True)

                    logq[indices[0]] = -block_size # log2 probability, in bits: each bin's top token gets 2**-block_size

                logq = logq*0.69315 # convert from bits to nats (multiply by ln 2)
                q = torch.exp(logq)

                # Then read the next block_size message bits, select the bin they
                # index, and emit that bin's most likely token
                m_part = message[i:i+block_size]

                filtered_logits = logits.clone()
                filtered_logits[:] = -1e10 # mask every token (probability ~0)
                available_tokens = bin2words[bits2int(m_part)]
                filtered_logits[available_tokens] = logits[available_tokens]
                filtered_logits, indices = filtered_logits.sort(descending=True)

                total_kl += kl(q, logq, log_probs)
                total_log_probs += log_probs[indices[0]].item()
                i += block_size
                total_num_for_stats += 1
                
            total_num += 1
            prev = indices[0].view(1)
            output = torch.cat((output, prev))

    avg_NLL = -total_log_probs/total_num_for_stats # mean negative log-likelihood per message-carrying token (nats)
    avg_KL = total_kl/total_num_for_stats # mean KL from the model distribution per message-carrying token (bits)
    words_per_bit = total_num_for_stats/i # stego tokens emitted per message bit

    return output[len(context):].tolist(), avg_NLL, avg_KL, words_per_bit

def decode_block(model, enc, text, context, block_size, bin2words, words2bin, device='cpu'):
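    """Recover the hidden bits from stegotext `text` produced by encode_block.

    Re-tokenizes the text, maps every token to its bin, and emits block_size
    bits per token, repairing tokens that BPE re-segmented differently where
    possible.
    """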
    # inp is a list of token indices
    # context is a list of token indices
    inp = enc.encode(text)
    i = 0
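    # GPT-2's BPE can merge two newlines into the single token 628; generation
    # never emits 628 (it is masked out), so split it back into two
    # single-newline tokens (198) before decoding.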
    while i < len(inp):
        if inp[i] == 628:
            inp[i] = 198
            inp[i+1:i+1] = [198]
            i += 2
        else:
            i += 1

    context = torch.tensor(context[-1022:], device=device, dtype=torch.long)
    prev = context
    past = None

    message = []
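    # Recover the hidden bits: every stegotext token belongs to exactly one bin,
    # and its bin index yields the next block_size bits of the message.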
    with torch.no_grad():
        i = 0
        while i < len(inp):
            if past and past[0].shape[3] >= 1023:
                raise RuntimeError('past key/value cache reached the 1023-token limit')
            bin_num = words2bin[inp[i]]

            logits, past = model(prev.unsqueeze(0), past=past)
            past = limit_past(past)
            logits[0, -1, -1] = -1e10 # forbid <|endoftext|>
            logits[0, -1, 628] = -1e10 # forbid the double-newline token

            logits = logits[0, -1, :]
            filtered_logits = logits.clone()
            filtered_logits[:] = -1e10 # mask every token (probability ~0)

            available_tokens = bin2words[bin_num]
            filtered_logits[available_tokens] = logits[available_tokens]
            filtered_logits, indices = filtered_logits.sort(descending=True)
            
            rank = (indices == inp[i]).nonzero().item() # rank of the received token within its bin (0 = the bin's most likely token)

            # Handle BPE re-tokenization errors: the encoder always emits the most
            # likely token of the chosen bin (rank 0), so a nonzero rank means the
            # stegotext tokenized differently here and must be re-segmented
            if rank > 0:
                true_token_text = enc.decoder[inp[i]]
                for bin_num in range(len(bin2words)):
                    filtered_logits = logits.clone()
                    filtered_logits[:] = -1e10 # mask every token (probability ~0)

                    available_tokens = bin2words[bin_num]
                    filtered_logits[available_tokens] = logits[available_tokens]
                    filtered_logits, indices = filtered_logits.sort(descending=True)

                    prop_token_text = enc.decoder[indices[0].item()]
                    #print(true_token_text, prop_token_text)

                    # Is there a more likely prefix token that could be the actual token generated?
                    if len(prop_token_text) < len(true_token_text) and \
                            prop_token_text == true_token_text[:len(prop_token_text)]:
                        suffix = true_token_text[len(prop_token_text):]
                        suffix_tokens = enc.encode(suffix) # a list
                        inp[i] = indices[0].item()
                        inp[i+1:i+1] = suffix_tokens # insert suffix tokens into list
                        break

                    # Is there a more likely longer token that could be the actual token generated?
                    elif len(prop_token_text) > len(true_token_text) and \
                              true_token_text == prop_token_text[:len(true_token_text)]:
                        whole_text = true_token_text
                        num_extra = 1
                        while len(whole_text) < len(prop_token_text):
                            whole_text += enc.decoder[inp[i+num_extra]]
                            num_extra += 1
                        if prop_token_text == whole_text[:len(prop_token_text)]:
                            inp[i] = indices[0].item()
                            for j in range(1, num_extra):
                                del inp[i+j]
                        
                            if len(whole_text) > len(prop_token_text):
                                suffix = whole_text[len(prop_token_text):]
                                suffix_tokens = enc.encode(suffix) # a list
                                inp[i+1:i+1] = suffix_tokens # insert suffix tokens into list
                            break
                else:
                    print('Unable to fix BPE error: token received: %s=%d, text: %s' % (true_token_text, inp[i], text))

            tokens_t = int2bits(bin_num, block_size) # the bin index carries block_size message bits

            message.extend(tokens_t)
            prev = torch.tensor([inp[i]], device=device, dtype=torch.long)
            i += 1

    return message

if __name__ == '__main__':
    np.random.seed(123)

    bin2words, words2bin = get_bins(50257, 5)
    print(words2bin[153])
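
    # Rough end-to-end sketch (comments only, not executed): it assumes `model` and
    # `enc` are a GPT-2 language model and BPE encoder loaded elsewhere with the
    # interface used above (model(ids, past=past) -> (logits, past); enc.encode,
    # enc.decoder), that `enc.decode` inverts `enc.encode`, and that the message is
    # a list of 0/1 ints as expected by utils.bits2int / int2bits:
    #
    #   block_size = 3
    #   bin2words, words2bin = get_bins(50257, block_size)
    #   context_tokens = enc.encode('Some innocuous context sentence.')
    #   message_bits = [0, 1, 1, 0, 1, 0]
    #   stego_tokens, nll, kl_bits, words_per_bit = encode_block(
    #       model, enc, message_bits, context_tokens, block_size, bin2words, words2bin)
    #   recovered = decode_block(model, enc, enc.decode(stego_tokens), context_tokens,
    #                            block_size, bin2words, words2bin)
    #   # recovered[:len(message_bits)] should equal message_bits, up to BPE
    #   # re-tokenization issues that decode_block could not repair.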