Xinyoumeng233hu's picture
Update utils.py
c864132
import torch
import numpy as np
import bitarray
from pytorch_transformers import GPT2LMHeadModel, GPT2Tokenizer
def decode(self, token_ids, **kwargs):
    """Convert a sequence of token ids back into a text string.

    The ids are mapped to tokens with ``convert_ids_to_tokens`` and the
    tokens are merged with ``convert_tokens_to_string``. Any extra keyword
    arguments are accepted for call compatibility and ignored.
    """
    return self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))
# Monkey-patch: make GPT2Tokenizer.decode use the simplified decode function defined above.
GPT2Tokenizer.decode = decode
def _convert_token_to_id(self, token):
    """Look up ``token`` in ``self.encoder``; unknown tokens map to id 0."""
    vocab = self.encoder
    return vocab.get(token, 0)
# Monkey-patch: unknown tokens now resolve to id 0 via the function defined above.
GPT2Tokenizer._convert_token_to_id = _convert_token_to_id
def limit_past(past):
    """Keep only the last 1022 positions along dimension 3 of every entry.

    Args:
        past: iterable of indexable arrays/tensors with at least four
            dimensions (presumably the per-layer attention cache -- confirm
            against the caller).

    Returns:
        A new list with each entry sliced as ``entry[:, :, :, -1022:]``.
    """
    return [layer[:, :, :, -1022:] for layer in past]
def kl(q, logq, logp):
    """Return sum(q * (logq - logp)) expressed in bits.

    Args:
        q: probabilities of the first distribution.
        logq: natural-log values paired with ``q``.
        logp: natural-log values of the second distribution.

    Entries where ``q`` is exactly zero are forced to contribute 0, which
    avoids NaN from ``0 * inf``.
    """
    zero_mass = q == 0
    terms = q * (logq - logp)
    terms = terms / 0.69315  # divide by ~ln(2) to convert nats to bits
    terms[zero_mass] = 0
    return terms.sum().item()  # in bits
def entropy(q, logq):
    """Return -sum(q * logq) expressed in bits.

    Args:
        q: probabilities.
        logq: natural-log values paired with ``q``.

    Entries where ``q`` is exactly zero are forced to contribute 0, which
    avoids NaN from ``0 * -inf``.
    """
    zero_mass = q == 0
    terms = q * logq
    terms = terms / 0.69315  # divide by ~ln(2) to convert nats to bits
    terms[zero_mass] = 0
    return -terms.sum().item()  # in bits
# e.g. [0, 1, 1, 1] looks like 1110=14
def bits2int(bits):
    """Interpret ``bits`` as a least-significant-bit-first sequence and return its value.

    The element at position ``i`` is weighted by ``2**i``; an empty
    sequence yields 0.
    """
    return sum(bit * (2 ** position) for position, bit in enumerate(bits))
def int2bits(inp, num_bits):
    """Return the binary digits of ``inp``, least-significant bit first.

    The value is formatted as a zero-padded binary string of width
    ``num_bits`` and then reversed. ``num_bits == 0`` yields an empty list.
    If ``inp`` needs more than ``num_bits`` digits, the result is longer
    than ``num_bits``.
    """
    if num_bits == 0:
        return []
    binary_text = format(inp, '0%db' % num_bits)
    return [int(digit) for digit in binary_text[::-1]]
def is_sent_finish(token_idx, enc):
    """Return True when the token ``enc.decoder[token_idx]`` contains '.', '!' or '?'."""
    token_text = enc.decoder[token_idx]
    return any(mark in token_text for mark in ('.', '!', '?'))
def num_same_from_beg(bits1, bits2):
    """Return the length of the common prefix of two equal-length sequences.

    Args:
        bits1: first sequence.
        bits2: second sequence; must have the same length as ``bits1``.

    Returns:
        The index of the first position at which the sequences differ, or
        ``len(bits1)`` when they agree everywhere (including when both are
        empty).

    Raises:
        AssertionError: if the two sequences differ in length.
    """
    assert len(bits1) == len(bits2)
    for position, (left, right) in enumerate(zip(bits1, bits2)):
        if left != right:
            return position
    # No mismatch: the whole sequence is shared. Returning the last loop
    # index here would under-count by one and fail on empty input.
    return len(bits1)
def encode_context(raw_text, enc):
    """Encode ``raw_text`` with ``enc`` and prepend the id of '<|endoftext|>'."""
    end_of_text_id = enc.encoder['<|endoftext|>']
    return [end_of_text_id] + enc.encode(raw_text)
# Use gpt2-medium for 345M param model
# Use gpt2-large for 774M param model
def get_model(seed=1234, model_name='gpt2'):
    """Seed the random generators and load a pretrained GPT-2 tokenizer and LM-head model.

    Args:
        seed: value passed to the NumPy, torch and torch.cuda seeding calls.
        model_name: name handed to ``from_pretrained`` for both the
            tokenizer and the model.

    Returns:
        ``(enc, model)``: the tokenizer with its unk/bos/eos tokens set to
        ``None``, and the model moved to the CPU and switched to eval mode.
    """
    for seed_fn in (np.random.seed, torch.random.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)

    enc = GPT2Tokenizer.from_pretrained(model_name)
    # Clear the special-token attributes on the tokenizer.
    enc.unk_token = enc.bos_token = enc.eos_token = None

    model = GPT2LMHeadModel.from_pretrained(model_name)
    model.to(torch.device("cpu"))
    model.eval()
    #model.double()
    return enc, model
# 32-symbol alphabet: index 0 is the NUL terminator, then a-z, then . , ' ! and space.
enc32_itoc = list("\0abcdefghijklmnopqrstuvwxyz.,'! ")
# Reverse table: character -> index in enc32_itoc.
enc32_ctoi = {char: code for code, char in enumerate(enc32_itoc)}
def enc32(text):
    """Encode ``text`` as a flat list of bits, five bits per character.

    Each character is looked up in ``enc32_ctoi`` and expanded with
    ``int2bits(code, 5)``. A character missing from the table raises
    ``KeyError``.
    """
    return [bit for char in text for bit in int2bits(enc32_ctoi[char], 5)]
def dec32(bits):
    """Decode a bit list produced five bits per character back into text.

    Bits are consumed in chunks of five, converted with ``bits2int`` and
    mapped through ``enc32_itoc``. Decoding stops at the first chunk that
    maps to the NUL character.
    """
    chars = []
    for start in range(0, len(bits), 5):
        char = enc32_itoc[bits2int(bits[start:start + 5])]
        if char == '\0':
            break
        chars.append(char)
    return ''.join(chars)
# message should be bit string
# encoded should be text string
def expansion_ratio(message, encoded):
    """Return how many bits of cover text are used per bit of message.

    Args:
        message: the hidden message as a sequence of bits; only its length
            is used.
        encoded: the cover text as a ``str``; its size is measured as the
            number of bits in its UTF-8 encoding.

    Returns:
        ``encoded_bits / message_bits`` as a float.

    Raises:
        ZeroDivisionError: if ``message`` is empty.
    """
    message_bits = len(message)
    # Every UTF-8 byte is 8 bits, so there is no need to materialise a bit
    # array (and then a Python list) just to count them.
    encoded_bits = 8 * len(encoded.encode('utf-8'))
    return encoded_bits / message_bits
#@title
import torch
import math
import random
def bin_sort(l, token_indices, total, entropy, device):
    """Reorder ``l`` and ``token_indices`` in lockstep by packing items into equal-size buckets.

    ``total`` is divided into ``2**int(entropy+1)`` buckets of size
    ``total / num_bins``. Every item of ``l`` is appended to exactly one
    bucket (an item larger than the room left in its bucket also consumes
    capacity of the following buckets in the bookkeeping). The result is the
    concatenation of all buckets in bucket order.

    Args:
        l: tensor of item sizes, read via ``.tolist()``. Presumably these
            are probability masses that sum to ``total`` (TODO confirm
            against the caller).
        token_indices: tensor parallel to ``l``; each entry travels with the
            item at the same position.
        total: overall capacity that is split across the buckets.
        entropy: number used only to choose the bucket count.
        device: device passed to every ``torch.empty`` / ``torch.tensor`` call.

    Returns:
        ``(sorted_tensor, sorted_tokens)``: the items and their token
        indices, concatenated bucket by bucket.
    """
    #compute entropy for upper bound on the number of bins we need
    bucket_size = total  # placeholder; overwritten two lines below
    num_bins = 2**int(entropy+1)
    bucket_size = total / num_bins
    # Every slot starts out referencing the same empty tensor; slots are only
    # ever rebound to fresh torch.cat results below, never mutated in place.
    bins = [torch.empty(0, dtype=torch.long, device=device)] * num_bins
    # Amount of capacity already claimed in each bucket.
    value_in_bins = [0] * num_bins
    # Capacity from the start of bucket i to the end of the range; reduced as items are placed.
    space_left_after = [total - i*bucket_size for i in range(0,num_bins)]
    token_bins = [torch.empty(0, dtype=torch.long, device=device)] * num_bins
    # Figuring out what the search order should be
    # Start with the middle bucket and bucket 0, then visit the remaining
    # buckets in passes with a step that halves each pass. priorities[x]
    # records the pass in which bucket x was added (0 for the first two).
    step_size = num_bins/4
    search_order = []
    priorities = [0]*num_bins
    priority = 0
    search_order.append(int(num_bins/2))
    search_order.append(0)
    priorities[int(num_bins/2)] = 0
    priorities[0] = 0
    while(step_size>=1):
        priority += 1
        for x in range(num_bins-int(step_size), -1, -int(step_size*2)):
            search_order.append(x)
            priorities[x] = priority
        step_size = step_size/2
    # Adding the actual elements
    for (item, token_index) in zip(l.tolist(), token_indices.tolist()):
        # Best candidate seen so far for each of the three placement strategies.
        found_single_bucket_fit = False
        single_bucket_index = -1
        single_bucket_value = bucket_size
        found_multi_bucket_bumpless_fit = False
        multi_bucket_bumpless_index = -1
        multi_bucket_bumpless_value = total
        found_multi_bucket_bumping_fit = False
        multi_bucket_bumping_index = -1
        multi_bucket_bumping_value = total
        for i in search_order: # for index in search_order
            # Skip buckets that have too little room after them, or that are already full.
            if(item > space_left_after[i]):
                continue
            if(value_in_bins[i] >= bucket_size):
                continue
            # Priority of choices
            # 1. Can I place this thing in an empty bucket all on its own?
            # 2. Can I place this somewhere where it doesn't have to bump anything else around?
            # 2a. Minimize the wasted space. Aka use the smallest space (of equal priority) that accomplishes this goal
            # 3. If not (1) and (2), then put it in the space that bumps stuff the least.
            if(value_in_bins[i] + item > bucket_size): #Would overflow.
                # Measure the room from bucket i up to the next occupied bucket.
                space_before_next_block = bucket_size - value_in_bins[i]
                for j in range(i+1, len(bins)):
                    if(value_in_bins[j] > 0): # We have found a bucket with something in it. This is how much space we have here.
                        # NOTE(review): this reads value_in_bins[i] although bucket j is the
                        # occupied one -- possibly value_in_bins[j] was intended; confirm.
                        space_before_next_block = space_before_next_block + (bucket_size - value_in_bins[i])
                        break
                    else: # This was an empty bucket
                        space_before_next_block = space_before_next_block + bucket_size
                # NOTE(review): the nesting of the two checks below was reconstructed from
                # syntax (the source lost its indentation) -- confirm against the original file.
                if((not found_multi_bucket_bumpless_fit) or (found_multi_bucket_bumpless_fit and priorities[i] <= priorities[multi_bucket_bumpless_index])): #This could potentially be a match
                    # If this is a valid space to put this without bumping and it is a better fit than previous spaces
                    if(space_before_next_block > item and space_before_next_block < multi_bucket_bumpless_value):
                        # set this to be the pointer! we can fit stuff here
                        found_multi_bucket_bumpless_fit = True
                        multi_bucket_bumpless_index = i
                        multi_bucket_bumpless_value = space_before_next_block
                    # Find the overflow that will bump the least
                    if ( item - space_before_next_block < multi_bucket_bumping_value):
                        found_multi_bucket_bumping_fit = True
                        multi_bucket_bumping_index = i
                        multi_bucket_bumping_value = item - space_before_next_block
            if(value_in_bins[i] + item <= bucket_size): #Would fit
                # Prefer the least-filled bucket among those the item fits into.
                if(single_bucket_value > value_in_bins[i]):
                    found_single_bucket_fit = True
                    single_bucket_value = value_in_bins[i]
                    single_bucket_index = i
        if (single_bucket_index == multi_bucket_bumpless_index == multi_bucket_bumping_index == -1):
            # No candidate at all: put the item at the front of bucket 0.
            # The capacity bookkeeping is not updated on this path.
            bins[0] = torch.cat( (torch.tensor([item], device=device), bins[0]), 0)
            token_bins[0] = torch.cat( (torch.tensor([token_index], device=device), token_bins[0]), 0)
            continue
        if found_single_bucket_fit:
            # We found somewhere we can actually fit!
            bins[single_bucket_index] = torch.cat( (bins[single_bucket_index], torch.tensor([item], device=device)), 0)
            token_bins[single_bucket_index] = torch.cat( (token_bins[single_bucket_index], torch.tensor([token_index], device=device)), 0)
            value_in_bins[single_bucket_index] += item
            # Every bucket up to and including this one now has less room after it.
            for i in range(0, single_bucket_index+1):
                space_left_after[i] -= item
        elif found_multi_bucket_bumpless_fit:
            # Found somewhere we can put this without upsetting the force
            part_in_bucket = bucket_size - value_in_bins[multi_bucket_bumpless_index]
            part_overflow = item - part_in_bucket
            bins[multi_bucket_bumpless_index] = torch.cat( (bins[multi_bucket_bumpless_index], torch.tensor([item], device=device)), 0)
            token_bins[multi_bucket_bumpless_index] = torch.cat( (token_bins[multi_bucket_bumpless_index], torch.tensor([token_index], device=device)), 0)
            value_in_bins[multi_bucket_bumpless_index] = bucket_size
            # Fill this bucket and continue overflowing
            j = multi_bucket_bumpless_index + 1
            for i in range(0, j):
                space_left_after[i] -= item
            # Spill the excess into the following buckets' bookkeeping.
            # NOTE(review): j is not bounds-checked against num_bins here -- confirm it cannot run past the end.
            while(part_overflow > 0):
                new_part_overflow = (value_in_bins[j] + part_overflow) - bucket_size
                value_in_bins[j] = min(bucket_size, part_overflow+value_in_bins[j]) # mark the bucket as filled
                space_left_after[j] -= part_overflow
                part_overflow = new_part_overflow
                j+=1
        else:
            # Last resort: the candidate whose overflow bumps the least.
            # The bookkeeping is the same as in the bumpless branch.
            part_in_bucket = bucket_size - value_in_bins[multi_bucket_bumping_index]
            part_overflow = item - part_in_bucket
            bins[multi_bucket_bumping_index] = torch.cat( (bins[multi_bucket_bumping_index], torch.tensor([item], device=device)), 0)
            token_bins[multi_bucket_bumping_index] = torch.cat( (token_bins[multi_bucket_bumping_index], torch.tensor([token_index], device=device)), 0)
            value_in_bins[multi_bucket_bumping_index] = bucket_size
            # Fill this bucket and continue overflowing
            j = multi_bucket_bumping_index + 1
            for i in range(0, j):
                space_left_after[i] -= item
            # NOTE(review): j is not bounds-checked against num_bins here either -- confirm.
            while(part_overflow > 0):
                new_part_overflow = (value_in_bins[j] + part_overflow) - bucket_size
                value_in_bins[j] = min(bucket_size, part_overflow+value_in_bins[j]) # mark the bucket as filled
                space_left_after[j] -= part_overflow
                part_overflow = new_part_overflow
                j+=1
    # Flatten the buckets, in bucket order, into the final ordering.
    sorted_tensor = torch.cat(bins, 0)
    sorted_tokens = torch.cat(token_bins, 0)
    return sorted_tensor, sorted_tokens
def compute_ev(t, precision):
    """Return the weighted count of shared leading bits over the intervals defined by ``t``.

    The cumulative sum of ``t`` splits a range into consecutive intervals
    ``[bottom, top)``. For each interval, ``bottom`` and ``top - 1`` are
    written as ``precision``-bit numbers (most significant bit first) and
    ``num_same_from_beg`` reports how many leading bits they share. That
    count is weighted by ``t[i]``; the weighted sum is divided by
    ``2**precision``.

    Args:
        t: 1-D tensor of interval widths. Presumably integer-valued and
            summing to ``2**precision`` (TODO confirm against the caller).
        precision: bit width used for the binary expansion.
    """
    cumulative = t.cumsum(0)
    weighted_bits = []
    for idx in range(len(cumulative)):
        # Interval covered by this selection, as integers.
        bottom = cumulative[idx - 1] if idx > 0 else 0
        top = cumulative[idx]
        # Most-significant-bit-first expansions of both inclusive ends.
        bottom_bits = int2bits(bottom, precision)[::-1]
        top_bits = int2bits(top - 1, precision)[::-1]  # -1 because the upper bound is exclusive
        # Leading bits that are already fixed for this interval.
        shared = num_same_from_beg(bottom_bits, top_bits)
        weighted_bits.append(t[idx] * shared)
    return float(sum(weighted_bits).item()) / (2 ** precision)
def visualize_bins(values_in_bins, bucket_size):
    """Print each bin's fill level as a percentage of ``bucket_size``.

    The output is a single line such as ``[ 25.0 | 50.0 |]``, with every
    percentage rounded to two decimals.
    """
    cells = [" " + str(round(100*value/bucket_size, 2)) + " |" for value in values_in_bins]
    print("[" + "".join(cells) + "]")
def visualize_distribution(l):
    """Print each entry of ``l`` as a percentage of the sum of ``l``.

    The output is a single line such as ``[ 25.0 | 75.0 |]``, with every
    percentage rounded to two decimals.
    """
    total = sum(l)
    cells = [" " + str(round(100*value/total, 2)) + " |" for value in l]
    print("[" + "".join(cells) + "]")
def compute_entropy(lists):
    """Return the Shannon entropy, in bits, of the distribution given by ``lists``.

    Args:
        lists: iterable of non-negative weights; they are normalised by
            their sum, so they need not add up to 1.

    Returns:
        ``-sum(p * log2(p))`` over the normalised weights. Zero weights are
        skipped (the usual ``0 * log 0 = 0`` convention) instead of raising
        a math domain error.

    Raises:
        ZeroDivisionError: if the weights sum to zero.
    """
    total = sum(lists)
    # Skip zero weights: math.log2(0) is undefined, and such terms contribute nothing.
    probabilities = (x / total for x in lists if x != 0)
    return -1 * sum(p * math.log2(p) for p in probabilities)