import torch
import torch.nn.functional as F

from huffman import HuffmanCoding
from utils import kl, entropy, is_sent_finish, limit_past
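
# Huffman steganography baseline. encode_huffman hides a bit string inside
# model-generated text: at each step it builds a Huffman code over the model's
# top 2**bits_per_word next-token probabilities and emits the token whose
# codeword matches the next message bits. decode_huffman replays the same
# model steps over the cover text and reads the codewords back off.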


def encode_huffman(model, enc, message, context, bits_per_word, finish_sent=False, device='cpu'):
    length = len(message)

    context = torch.tensor(context[-1022:], device=device, dtype=torch.long)

    prev = context
    output = context
    past = None

    total_num = 0
    total_num_for_stats = 0
    total_log_probs = 0
    total_kl = 0  # in bits
    total_num_sents = 0

    with torch.no_grad():
        i = 0
        sent_finish = False
        while i < length or (finish_sent and not sent_finish):
            logits, past = model(prev.unsqueeze(0), past=past)
            past = limit_past(past)
            logits[0, -1, -1] = -1e10  # endoftext can't happen
            logits[0, -1, 628] = -1e10  # 2 newlines can't happen
            logits, indices = logits[0, -1, :].sort(descending=True)

            # Get the top 2**bits_per_word options
            indices = indices[:2**bits_per_word]
            log_probs = F.log_softmax(logits, dim=-1)[:2**bits_per_word]
            probs = torch.exp(log_probs)

            if i >= length:
                # Message exhausted: greedily continue with the most likely
                # token until the sentence finishes
                selection = 0
                sent_finish = is_sent_finish(indices[0].item(), enc)
            else:
                # Build a Huffman tree over the top token probabilities, then
                # consume message bits to walk from the root to a leaf:
                # bit 0 goes left, bit 1 goes right
                probs_array = probs.cpu().numpy()
                coding = HuffmanCoding()
                coding.make_heap_from_array(probs_array)
                coding.merge_nodes()
                root = coding.make_codes()

                while root.token is None:
                    if i >= length or message[i] == 0:
                        root = root.left
                    else:
                        root = root.right
                    i += 1
                selection = root.token

                # q is the distribution implied by the Huffman code: a token
                # whose codeword has k bits gets probability 2**-k
                logq = torch.tensor([-len(coding.codes[idx]) for idx in range(len(probs_array))],
                                    dtype=torch.float, device=device)  # in bits
                logq = logq * 0.69315  # ln(2): convert bits to nats
                q = torch.exp(logq)
                total_kl += kl(q, logq, log_probs)
                total_log_probs += log_probs[selection].item()
                total_num_for_stats += 1

            total_num += 1

            prev = indices[selection].view(1)
            output = torch.cat((output, prev))

    avg_NLL = -total_log_probs / total_num_for_stats
    avg_KL = total_kl / total_num_for_stats
    words_per_bit = total_num_for_stats / i

    return output[len(context):].tolist(), avg_NLL, avg_KL, words_per_bit
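

# Example of the bit-to-token mapping in encode_huffman: with bits_per_word=2
# and sorted top-4 probabilities [0.5, 0.25, 0.15, 0.10], a Huffman code might
# assign codewords 0, 10, 110, 111 to ranks 0-3 (the exact assignment depends
# on the HuffmanCoding implementation); a message starting with bits 1,0 then
# selects the rank-1 token and consumes exactly those two bits.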


def decode_huffman(model, enc, text, context, bits_per_word, device='cpu'):
    # inp is a list of token indices
    # context is a list of token indices
    inp = enc.encode(text)

    # Re-tokenizing the cover text can merge two newline tokens (198) into the
    # double-newline token (628); expand it back into two 198s
    i = 0
    while i < len(inp):
        if inp[i] == 628:
            inp[i] = 198
            inp[i+1:i+1] = [198]
            i += 2
        else:
            i += 1

    context = torch.tensor(context[-1022:], device=device, dtype=torch.long)
    prev = context
    past = None

    message = []
    with torch.no_grad():
        i = 0
        while i < len(inp):
            if past and past[0].shape[3] >= 1023:
                raise RuntimeError('Cover text is longer than the model context allows')

            logits, past = model(prev.unsqueeze(0), past=past)
            past = limit_past(past)
            logits[0, -1, -1] = -1e10  # endoftext can't happen
            logits[0, -1, 628] = -1e10  # 2 newlines can't happen
            logits, indices = logits[0, -1, :].sort(descending=True)

            # Get the top 2**bits_per_word options
            indices = indices[:2**bits_per_word]
            log_probs = F.log_softmax(logits, dim=-1)[:2**bits_per_word]
            probs = torch.exp(log_probs)

            # The observed token may fall outside the top-2**bits_per_word set
            # when re-tokenizing the cover text split or merged BPE tokens;
            # try to map it back onto a candidate the encoder could have used
            if inp[i] not in indices:
                true_token_text = enc.decoder[inp[i]]
                for rank_idx in range(2**bits_per_word):
                    prop_token_text = enc.decoder[indices[rank_idx].item()]
                    # common case that is not caught
                    if inp[i] == 128 and indices[rank_idx] == 198:
                        rank = rank_idx
                        inp[i] = indices[rank_idx].item()
                        break

                    # Is there a more likely prefix token that could be the actual token generated?
                    if len(prop_token_text) <= len(true_token_text) and \
                            prop_token_text == true_token_text[:len(prop_token_text)]:
                        rank = rank_idx
                        suffix = true_token_text[len(prop_token_text):]
                        suffix_tokens = enc.encode(suffix)  # a list
                        inp[i] = indices[rank_idx].item()
                        inp[i+1:i+1] = suffix_tokens  # insert suffix tokens into list
                        break

                    # Is there a more likely longer token that could be the actual token generated?
                    elif len(prop_token_text) > len(true_token_text) and \
                            true_token_text == prop_token_text[:len(true_token_text)]:
                        whole_text = true_token_text
                        num_extra = 1
                        while len(whole_text) < len(prop_token_text):
                            whole_text += enc.decoder[inp[i+num_extra]]
                            num_extra += 1
                        if prop_token_text == whole_text[:len(prop_token_text)]:
                            rank = rank_idx
                            inp[i] = indices[rank_idx].item()
                            # delete the consumed extra tokens as a slice, so
                            # indices don't shift between deletions
                            del inp[i+1:i+num_extra]

                            if len(whole_text) > len(prop_token_text):
                                suffix = whole_text[len(prop_token_text):]
                                suffix_tokens = enc.encode(suffix)  # a list
                                inp[i+1:i+1] = suffix_tokens  # insert suffix tokens into list
                            break
                else:
                    # for-else: no candidate matched
                    print('Unable to fix BPE error: token received: %s=%d, text: %s' % (true_token_text, inp[i], text))
                    rank = 0
            else:
                rank = (indices == inp[i]).nonzero().item()

            # Rebuild the Huffman tree the encoder used at this step and emit
            # the observed token's codeword as the recovered message bits
            probs_array = probs.cpu().numpy()
            coding = HuffmanCoding()
            coding.make_heap_from_array(probs_array)
            coding.merge_nodes()
            coding.make_codes()

            message.extend(map(int, coding.codes[rank]))

            prev = torch.tensor([inp[i]], device=device, dtype=torch.long)
            i += 1

    return message
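

# ---------------------------------------------------------------------------
# Minimal round-trip sketch (illustrative only). It assumes a get_model()
# helper returning a GPT-2-style (enc, model) pair whose forward pass accepts
# and returns `past`, matching how `model` is called above; substitute your
# own loader if your utils module does not provide one.
if __name__ == '__main__':
    from utils import get_model  # assumed helper; adjust to your setup

    enc, model = get_model()
    context_tokens = enc.encode('Earlier today, ')
    message_bits = [0, 1, 1, 0, 1, 0, 0, 1]  # payload as a list of bits

    cover_tokens, avg_nll, avg_kl, words_per_bit = encode_huffman(
        model, enc, message_bits, context_tokens, bits_per_word=2)
    cover_text = enc.decode(cover_tokens)

    recovered = decode_huffman(model, enc, cover_text, context_tokens,
                               bits_per_word=2)
    # recovered should begin with message_bits; the tail may carry extra bits
    # from completing the final Huffman codeword
    print(recovered[:len(message_bits)] == message_bits)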