import torch
import numpy as np
import bitarray

from pytorch_transformers import GPT2LMHeadModel, GPT2Tokenizer

# Patch the tokenizer so decoding is a plain id -> token -> string mapping,
# without the library's extra cleanup or special-token filtering.
def decode(self, token_ids, **kwargs):
    filtered_tokens = self.convert_ids_to_tokens(token_ids)
    text = self.convert_tokens_to_string(filtered_tokens)
    return text
GPT2Tokenizer.decode = decode

# Map unknown tokens to id 0 instead of raising a KeyError.
def _convert_token_to_id(self, token):
    return self.encoder.get(token, 0)
GPT2Tokenizer._convert_token_to_id = _convert_token_to_id
def limit_past(past):
    # Truncate the cached key/value tensors along the sequence dimension so the
    # context never exceeds GPT-2's 1024-token window.
    past = list(past)
    for i in range(len(past)):
        past[i] = past[i][:, :, :, -1022:]
    return past
def kl(q, logq, logp):
    # KL divergence between q and p, converted from nats to bits (0.69315 ~ ln 2).
    res = q*(logq-logp)/0.69315
    res[q==0] = 0
    return res.sum().item()  # in bits

def entropy(q, logq):
    # Entropy of q, converted from nats to bits.
    res = q*logq/0.69315
    res[q==0] = 0
    return -res.sum().item()  # in bits
# Bits are stored least-significant-first, e.g. [0, 1, 1, 1] reads as binary 1110 = 14.
def bits2int(bits):
    res = 0
    for i, bit in enumerate(bits):
        res += bit*(2**i)
    return res

def int2bits(inp, num_bits):
    if num_bits == 0:
        return []
    strlist = ('{0:0%db}'%num_bits).format(inp)
    return [int(strval) for strval in reversed(strlist)]
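# Illustrative check of the bit convention used throughout: the two helpers are
# inverses of each other for a fixed bit width.
assert bits2int([0, 1, 1, 1]) == 14
assert int2bits(14, 4) == [0, 1, 1, 1]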
def is_sent_finish(token_idx, enc):
    # A sentence is considered finished if the token contains terminal punctuation.
    token = enc.decoder[token_idx]
    return '.' in token or '!' in token or '?' in token

def num_same_from_beg(bits1, bits2):
    # Number of leading bits on which the two equal-length bit sequences agree.
    assert len(bits1) == len(bits2)
    for i in range(len(bits1)):
        if bits1[i] != bits2[i]:
            break
    return i
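# Illustrative check: the first two bits agree, the third differs.
assert num_same_from_beg([1, 0, 1, 1], [1, 0, 0, 1]) == 2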
def encode_context(raw_text, enc):
    context_tokens = [enc.encoder['<|endoftext|>']] + enc.encode(raw_text)
    return context_tokens
# Use gpt2-medium for 345M param model
# Use gpt2-large for 774M param model
def get_model(seed=1234, model_name='gpt2'):
    np.random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    enc = GPT2Tokenizer.from_pretrained(model_name)
    enc.unk_token = None
    enc.bos_token = None
    enc.eos_token = None

    model = GPT2LMHeadModel.from_pretrained(model_name)
    model.to(device)
    model.eval()
    #model.double()

    return enc, model
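# Example usage (left commented out so importing this file does not download
# model weights; 'gpt2' is just the default checkpoint name):
#   enc, model = get_model(model_name='gpt2')
#   context_tokens = encode_context("Some context text to condition on", enc)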
# Fixed 32-symbol alphabet: each character costs 5 bits, with '\0' as the terminator.
enc32_itoc = ['\0', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '.', ',', "'", '!', ' ']
enc32_ctoi = {k: v for v, k in enumerate(enc32_itoc)}

def enc32(text):
    bits = []
    for c in text:
        bits.extend(int2bits(enc32_ctoi[c], 5))
    return bits

def dec32(bits):
    text = ''
    for i in range(0, len(bits), 5):
        c = enc32_itoc[bits2int(bits[i:i+5])]
        if c == '\0':
            break
        text += c
    return text
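# Illustrative round trip: 9 characters cost exactly 9 * 5 bits, and decoding
# reproduces the original string (it would stop early at a '\0' symbol).
assert len(enc32('hi there.')) == 9 * 5
assert dec32(enc32('hi there.')) == 'hi there.'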
# message should be a bit string
# encoded should be a text string
def expansion_ratio(message, encoded):
    message_bits = len(message)
    encoded_ba = bitarray.bitarray()
    encoded_ba.frombytes(encoded.encode('utf-8'))
    encoded_bits = len(encoded_ba.tolist())
    return encoded_bits/message_bits
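# Illustrative example (made-up values): a 40-bit message carried by 20 ASCII
# characters (160 bits of UTF-8) gives an expansion ratio of 4.0.
assert expansion_ratio([0, 1] * 20, 'a' * 20) == 4.0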
import torch
import math
import random
def bin_sort(l, token_indices, total, entropy, device):
    # Use the entropy as an upper bound on the number of bins we need
    num_bins = 2**int(entropy+1)
    bucket_size = total / num_bins

    bins = [torch.empty(0, dtype=torch.long, device=device)] * num_bins
    value_in_bins = [0] * num_bins
    space_left_after = [total - i*bucket_size for i in range(0, num_bins)]

    token_bins = [torch.empty(0, dtype=torch.long, device=device)] * num_bins

    # Figuring out what the search order over buckets should be
    step_size = num_bins/4
    search_order = []
    priorities = [0]*num_bins
    priority = 0

    search_order.append(int(num_bins/2))
    search_order.append(0)
    priorities[int(num_bins/2)] = 0
    priorities[0] = 0

    while(step_size >= 1):
        priority += 1
        for x in range(num_bins-int(step_size), -1, -int(step_size*2)):
            search_order.append(x)
            priorities[x] = priority
        step_size = step_size/2
    # Adding the actual elements
    for (item, token_index) in zip(l.tolist(), token_indices.tolist()):
        found_single_bucket_fit = False
        single_bucket_index = -1
        single_bucket_value = bucket_size

        found_multi_bucket_bumpless_fit = False
        multi_bucket_bumpless_index = -1
        multi_bucket_bumpless_value = total

        found_multi_bucket_bumping_fit = False
        multi_bucket_bumping_index = -1
        multi_bucket_bumping_value = total

        for i in search_order:  # for index in search_order
            if(item > space_left_after[i]):
                continue
            if(value_in_bins[i] >= bucket_size):
                continue

            # Priority of choices
            # 1. Can I place this in an empty bucket all on its own?
            # 2. Can I place this somewhere where it doesn't have to bump anything else around?
            #    2a. Minimize the wasted space, i.e. use the smallest space (of equal priority) that accomplishes this goal.
            # 3. If not (1) and (2), then put it in the space that bumps stuff the least.

            if(value_in_bins[i] + item > bucket_size):  # Would overflow.
                space_before_next_block = bucket_size - value_in_bins[i]
                for j in range(i+1, len(bins)):
                    if(value_in_bins[j] > 0):  # We have found a bucket with something in it. This is how much space we have here.
                        space_before_next_block = space_before_next_block + (bucket_size - value_in_bins[i])
                        break
                    else:  # This was an empty bucket
                        space_before_next_block = space_before_next_block + bucket_size

                if((not found_multi_bucket_bumpless_fit) or (found_multi_bucket_bumpless_fit and priorities[i] <= priorities[multi_bucket_bumpless_index])):  # This could potentially be a match
                    # If this is a valid space to put this without bumping and it is a better fit than previous spaces
                    if(space_before_next_block > item and space_before_next_block < multi_bucket_bumpless_value):
                        # Set this to be the pointer! We can fit stuff here
                        found_multi_bucket_bumpless_fit = True
                        multi_bucket_bumpless_index = i
                        multi_bucket_bumpless_value = space_before_next_block

                    # Find the overflow that will bump the least
                    if(item - space_before_next_block < multi_bucket_bumping_value):
                        found_multi_bucket_bumping_fit = True
                        multi_bucket_bumping_index = i
                        multi_bucket_bumping_value = item - space_before_next_block

            if(value_in_bins[i] + item <= bucket_size):  # Would fit
                if(single_bucket_value > value_in_bins[i]):
                    found_single_bucket_fit = True
                    single_bucket_value = value_in_bins[i]
                    single_bucket_index = i

        if (single_bucket_index == multi_bucket_bumpless_index == multi_bucket_bumping_index == -1):
            bins[0] = torch.cat((torch.tensor([item], device=device), bins[0]), 0)
            token_bins[0] = torch.cat((torch.tensor([token_index], device=device), token_bins[0]), 0)
            continue
        if found_single_bucket_fit:
            # We found somewhere we can actually fit!
            bins[single_bucket_index] = torch.cat((bins[single_bucket_index], torch.tensor([item], device=device)), 0)
            token_bins[single_bucket_index] = torch.cat((token_bins[single_bucket_index], torch.tensor([token_index], device=device)), 0)
            value_in_bins[single_bucket_index] += item
            for i in range(0, single_bucket_index+1):
                space_left_after[i] -= item

        elif found_multi_bucket_bumpless_fit:
            # Found somewhere we can put this without upsetting the force
            part_in_bucket = bucket_size - value_in_bins[multi_bucket_bumpless_index]
            part_overflow = item - part_in_bucket
            bins[multi_bucket_bumpless_index] = torch.cat((bins[multi_bucket_bumpless_index], torch.tensor([item], device=device)), 0)
            token_bins[multi_bucket_bumpless_index] = torch.cat((token_bins[multi_bucket_bumpless_index], torch.tensor([token_index], device=device)), 0)
            value_in_bins[multi_bucket_bumpless_index] = bucket_size

            # Fill this bucket and continue overflowing
            j = multi_bucket_bumpless_index + 1
            for i in range(0, j):
                space_left_after[i] -= item

            while(part_overflow > 0):
                new_part_overflow = (value_in_bins[j] + part_overflow) - bucket_size
                value_in_bins[j] = min(bucket_size, part_overflow+value_in_bins[j])  # mark the bucket as filled
                space_left_after[j] -= part_overflow
                part_overflow = new_part_overflow
                j += 1

        else:
            part_in_bucket = bucket_size - value_in_bins[multi_bucket_bumping_index]
            part_overflow = item - part_in_bucket
            bins[multi_bucket_bumping_index] = torch.cat((bins[multi_bucket_bumping_index], torch.tensor([item], device=device)), 0)
            token_bins[multi_bucket_bumping_index] = torch.cat((token_bins[multi_bucket_bumping_index], torch.tensor([token_index], device=device)), 0)
            value_in_bins[multi_bucket_bumping_index] = bucket_size

            # Fill this bucket and continue overflowing
            j = multi_bucket_bumping_index + 1
            for i in range(0, j):
                space_left_after[i] -= item

            while(part_overflow > 0):
                new_part_overflow = (value_in_bins[j] + part_overflow) - bucket_size
                value_in_bins[j] = min(bucket_size, part_overflow+value_in_bins[j])  # mark the bucket as filled
                space_left_after[j] -= part_overflow
                part_overflow = new_part_overflow
                j += 1

    sorted_tensor = torch.cat(bins, 0)
    sorted_tokens = torch.cat(token_bins, 0)

    return sorted_tensor, sorted_tokens
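# Illustrative usage sketch (the numbers below are made up for demonstration):
# `l` holds integer token probabilities summing to `total`, `token_indices` the
# matching token ids, and `entropy` the distribution's entropy in bits; bin_sort
# spreads the mass across 2**int(entropy+1) equal-size buckets and returns both
# tensors in the resulting bucket order.
#   probs = torch.tensor([400, 300, 200, 100], dtype=torch.long)
#   ids = torch.tensor([11, 22, 33, 44], dtype=torch.long)
#   sorted_probs, sorted_ids = bin_sort(probs, ids, 1000, 1.85, torch.device('cpu'))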
def compute_ev(t, precision):
    # Expected number of message bits fixed per step under arithmetic coding,
    # given integer token probabilities t that sum to 2**precision.
    expected_bits = []
    cum_probs = t.cumsum(0)

    for selection in range(0, len(cum_probs)):
        # Calculate new range as ints
        new_int_bottom = cum_probs[selection-1] if selection > 0 else 0
        new_int_top = cum_probs[selection]

        # Convert range to bits
        new_int_bottom_bits_inc = list(reversed(int2bits(new_int_bottom, precision)))
        new_int_top_bits_inc = list(reversed(int2bits(new_int_top-1, precision)))  # -1 here because upper bound is exclusive

        # Consume most significant bits which are now fixed and update interval
        num_bits_encoded = num_same_from_beg(new_int_bottom_bits_inc, new_int_top_bits_inc)
        expected_bits.append(t[selection] * num_bits_encoded)

    return(float(sum(expected_bits).item())/(2**precision))
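# Worked example (by hand, for illustration): with precision=4, probabilities
# [8, 4, 4] out of 16 give the intervals [0,8), [8,12), [12,16); these fix 1, 2
# and 2 leading bits respectively, so the expected value is
# (8*1 + 4*2 + 4*2) / 16 = 1.5 bits per step.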
def visualize_bins(values_in_bins, bucket_size):
    out_str = "["
    for b in values_in_bins:
        out_str = out_str + " " + str(round(100*b/bucket_size, 2)) + " |"
    out_str = out_str + "]"
    print(out_str)

def visualize_distribution(l):
    total = sum(l)
    out_str = "["
    for b in l:
        out_str = out_str + " " + str(round(100*b/total, 2)) + " |"
    out_str = out_str + "]"
    print(out_str)
def compute_entropy(lists):
    total = sum(lists)
    entropy = -1*sum([(x/total) * math.log2(x/total) for x in lists])
    return entropy
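# Illustrative check: a uniform distribution over four outcomes carries 2 bits
# of entropy (the input is an unnormalized list of counts).
assert abs(compute_entropy([1, 1, 1, 1]) - 2.0) < 1e-9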