# Utilities for GPT-2 based text steganography: tokenizer patches, bit/int
# conversions, a 5-bit text codec, and the bin-sorting helpers.

import math
import random

import bitarray
import numpy as np
import torch
from pytorch_transformers import GPT2LMHeadModel, GPT2Tokenizer


def decode(self, token_ids, **kwargs):
    """Turn token ids back into text.

    Installed below as ``GPT2Tokenizer.decode``. Ids are mapped to tokens
    with ``convert_ids_to_tokens`` and joined with
    ``convert_tokens_to_string``. ``kwargs`` is accepted for call
    compatibility and is not used.
    """
    filtered_tokens = self.convert_ids_to_tokens(token_ids)
    text = self.convert_tokens_to_string(filtered_tokens)
    return text


GPT2Tokenizer.decode = decode


def _convert_token_to_id(self, token):
    """Look ``token`` up in ``self.encoder``; tokens not found map to id 0."""
    return self.encoder.get(token, 0)


GPT2Tokenizer._convert_token_to_id = _convert_token_to_id


def limit_past(past):
    """Return ``past`` as a list with every entry cut to its last 1022
    positions along its fourth dimension."""
    return [entry[:, :, :, -1022:] for entry in past]


def kl(q, logq, logp):
    """Return ``sum(q * (logq - logp))`` in bits as a Python float.

    The division by 0.69315 (approximately ln 2) converts from nats to bits.
    Entries where ``q == 0`` are forced to zero so that ``0 * inf`` style
    products do not poison the sum.
    """
    res = q * (logq - logp) / 0.69315
    res[q == 0] = 0
    return res.sum().item()  # in bits


def entropy(q, logq):
    """Return ``-sum(q * logq)`` in bits as a Python float.

    Uses the same nats-to-bits constant and ``q == 0`` masking as ``kl``.
    """
    res = q * logq / 0.69315
    res[q == 0] = 0
    return -res.sum().item()  # in bits


# e.g. [0, 1, 1, 1] looks like 1110=14
def bits2int(bits):
    """Interpret ``bits`` as a little-endian bit sequence (index 0 is the
    least significant bit) and return its integer value."""
    res = 0
    for i, bit in enumerate(bits):
        res += bit * (2 ** i)
    return res


def int2bits(inp, num_bits):
    """Return ``inp`` as a list of ``num_bits`` bits, least significant first.

    ``num_bits == 0`` yields an empty list. The value is formatted as a
    zero-padded binary string of width ``num_bits`` and then reversed; values
    needing more than ``num_bits`` digits therefore produce a longer list.
    """
    if num_bits == 0:
        return []
    strlist = ('{0:0%db}' % num_bits).format(inp)
    return [int(strval) for strval in reversed(strlist)]


def is_sent_finish(token_idx, enc):
    """Return True when the token ``enc.decoder[token_idx]`` contains
    '.', '!' or '?'."""
    token = enc.decoder[token_idx]
    return any(mark in token for mark in ('.', '!', '?'))


def num_same_from_beg(bits1, bits2):
    """Return the length of the common prefix of two equal-length sequences.

    The previous implementation returned the loop index after the loop, so
    two fully identical sequences gave ``len - 1`` instead of ``len`` and two
    empty sequences raised ``UnboundLocalError``. Both cases now return the
    true prefix length (``len(bits1)`` and ``0`` respectively).

    Raises:
        AssertionError: if the sequences differ in length.
    """
    assert len(bits1) == len(bits2)
    for i, (left, right) in enumerate(zip(bits1, bits2)):
        if left != right:
            return i
    return len(bits1)


def encode_context(raw_text, enc):
    """Encode ``raw_text`` with ``enc`` and prepend the id of the
    '<|endoftext|>' token."""
    context_tokens = [enc.encoder['<|endoftext|>']] + enc.encode(raw_text)
    return context_tokens


# Use gpt2-medium for 345M param model
# Use gpt2-large for 774M param model
def get_model(seed=1234, model_name='gpt2'):
    """Seed the RNGs and load the tokenizer and LM head model on the CPU.

    Seeds numpy, torch and torch.cuda with ``seed``, loads both pretrained
    objects by ``model_name``, clears the tokenizer's unk/bos/eos tokens and
    puts the model in eval mode.

    Returns:
        A ``(enc, model)`` tuple.
    """
    np.random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    device = torch.device("cpu")

    enc = GPT2Tokenizer.from_pretrained(model_name)
    enc.unk_token = None
    enc.bos_token = None
    enc.eos_token = None

    model = GPT2LMHeadModel.from_pretrained(model_name)
    model.to(device)
    model.eval()
    #model.double()

    return enc, model


# 32-symbol alphabet for the 5-bit text codec; index 0 ('\0') terminates a
# decoded message.
enc32_itoc = ['\0', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '.', ',', "'", '!', ' ']
enc32_ctoi = {k: v for v, k in enumerate(enc32_itoc)}


def enc32(text):
    """Encode ``text`` as a flat bit list, 5 bits per character.

    Raises:
        KeyError: if ``text`` contains a character outside ``enc32_itoc``.
    """
    bits = []
    for c in text:
        bits.extend(int2bits(enc32_ctoi[c], 5))
    return bits


def dec32(bits):
    """Decode a bit list produced by ``enc32`` back to text.

    Bits are consumed in groups of 5; decoding stops at the first group that
    maps to '\\0'.
    """
    chars = []
    for start in range(0, len(bits), 5):
        c = enc32_itoc[bits2int(bits[start:start + 5])]
        if c == '\0':
            break
        chars.append(c)
    return ''.join(chars)


# message should be bit string
# encoded should be text string
def expansion_ratio(message, encoded):
    """Return (number of bits in the UTF-8 encoding of ``encoded``) divided
    by ``len(message)``.

    Raises:
        ZeroDivisionError: if ``message`` is empty.
    """
    message_bits = len(message)
    encoded_ba = bitarray.bitarray()
    encoded_ba.frombytes(encoded.encode('utf-8'))
    encoded_bits = len(encoded_ba)
    return encoded_bits / message_bits


# ---------------------------------------------------------------------------
# Bin sorting
# ---------------------------------------------------------------------------

def bin_sort(l, token_indices, total, entropy, device):
    """Reorder items (and their token indices) by greedily placing them into
    equal-sized buckets.

    The range ``[0, total)`` is divided into ``2**int(entropy + 1)`` buckets.
    Each item is placed, in priority order, (1) into the emptiest bucket it
    fits in entirely, else (2) at a bucket from which it can overflow into
    following free space without reaching an occupied bucket, else (3) at the
    bucket where the overflow collides the least. Items that match none of
    these are prepended to bucket 0.

    Args:
        l: tensor of item sizes; iterated via ``.tolist()``.
            (presumably integer probability masses summing to ``total`` --
            confirm against the caller)
        token_indices: tensor parallel to ``l``; iterated via ``.tolist()``.
        total: size of the whole range being partitioned.
        entropy: value used to choose the bucket count. Note this parameter
            shadows the module-level ``entropy`` function inside this body.
        device: torch device for the tensors created here.

    Returns:
        ``(sorted_tensor, sorted_tokens)``: the per-bucket item tensors and
        token tensors, each concatenated in bucket order.
    """
    # compute entropy for upper bound on the number of bins we need
    num_bins = 2 ** int(entropy + 1)
    bucket_size = total / num_bins

    # The list entries initially alias one empty tensor; that is safe because
    # entries are only ever replaced (torch.cat), never mutated in place.
    bins = [torch.empty(0, dtype=torch.long, device=device)] * num_bins
    value_in_bins = [0] * num_bins
    space_left_after = [total - i * bucket_size for i in range(0, num_bins)]
    token_bins = [torch.empty(0, dtype=torch.long, device=device)] * num_bins

    # Figuring out what the search order should be
    step_size = num_bins / 4
    search_order = []
    priorities = [0] * num_bins
    priority = 0
    search_order.append(int(num_bins / 2))
    search_order.append(0)
    priorities[int(num_bins / 2)] = 0
    priorities[0] = 0
    while step_size >= 1:
        priority += 1
        for x in range(num_bins - int(step_size), -1, -int(step_size * 2)):
            search_order.append(x)
            priorities[x] = priority
        step_size = step_size / 2

    # Adding the actual elements
    for (item, token_index) in zip(l.tolist(), token_indices.tolist()):
        found_single_bucket_fit = False
        single_bucket_index = -1
        single_bucket_value = bucket_size

        found_multi_bucket_bumpless_fit = False
        multi_bucket_bumpless_index = -1
        multi_bucket_bumpless_value = total

        multi_bucket_bumping_index = -1
        multi_bucket_bumping_value = total

        for i in search_order:  # for index in search_order
            if item > space_left_after[i]:
                continue
            if value_in_bins[i] >= bucket_size:
                continue

            # Priority of choices
            #  1. Can the item be placed in a bucket all on its own?
            #  2. Can it be placed somewhere where it does not have to bump
            #     anything else around?
            #    2a. Minimize the wasted space, i.e. use the smallest space
            #        (of equal priority) that accomplishes this goal.
            #  3. If not (1) and (2), put it in the space that bumps stuff
            #     the least.

            if value_in_bins[i] + item > bucket_size:  # Would overflow.
                space_before_next_block = bucket_size - value_in_bins[i]
                for j in range(i + 1, len(bins)):
                    if value_in_bins[j] > 0:
                        # We have found a bucket with something in it. This
                        # is how much space we have here.
                        # NOTE(review): this adds the free space of bucket i
                        # again rather than that of bucket j; it looks like
                        # value_in_bins[j] may have been intended. Behavior
                        # is kept as-is -- confirm before changing.
                        space_before_next_block = space_before_next_block + (bucket_size - value_in_bins[i])
                        break
                    else:  # This was an empty bucket
                        space_before_next_block = space_before_next_block + bucket_size

                if ((not found_multi_bucket_bumpless_fit)
                        or (found_multi_bucket_bumpless_fit
                            and priorities[i] <= priorities[multi_bucket_bumpless_index])):
                    # This could potentially be a match

                    # If this is a valid space to put this without bumping
                    # and it is a better fit than previous spaces
                    if (space_before_next_block > item
                            and space_before_next_block < multi_bucket_bumpless_value):
                        # set this to be the pointer! we can fit stuff here
                        found_multi_bucket_bumpless_fit = True
                        multi_bucket_bumpless_index = i
                        multi_bucket_bumpless_value = space_before_next_block

                    # Find the overflow that will bump the least. The result
                    # is only consulted when no bumpless fit exists.
                    if item - space_before_next_block < multi_bucket_bumping_value:
                        multi_bucket_bumping_index = i
                        multi_bucket_bumping_value = item - space_before_next_block

            if value_in_bins[i] + item <= bucket_size:  # Would fit
                if single_bucket_value > value_in_bins[i]:
                    found_single_bucket_fit = True
                    single_bucket_value = value_in_bins[i]
                    single_bucket_index = i

        if single_bucket_index == multi_bucket_bumpless_index == multi_bucket_bumping_index == -1:
            # No candidate at all: put the item at the front of bucket 0
            # without touching the occupancy bookkeeping.
            bins[0] = torch.cat((torch.tensor([item], device=device), bins[0]), 0)
            token_bins[0] = torch.cat((torch.tensor([token_index], device=device), token_bins[0]), 0)
            continue

        if found_single_bucket_fit:
            # We found somewhere we can actually fit!
            bins[single_bucket_index] = torch.cat(
                (bins[single_bucket_index], torch.tensor([item], device=device)), 0)
            token_bins[single_bucket_index] = torch.cat(
                (token_bins[single_bucket_index], torch.tensor([token_index], device=device)), 0)
            value_in_bins[single_bucket_index] += item
            for i in range(0, single_bucket_index + 1):
                space_left_after[i] -= item
        else:
            # Multi-bucket placement. Prefer the spot that overflows without
            # bumping anything; otherwise use the one that bumps the least.
            # The bookkeeping is the same for both.
            if found_multi_bucket_bumpless_fit:
                target = multi_bucket_bumpless_index
            else:
                target = multi_bucket_bumping_index

            part_in_bucket = bucket_size - value_in_bins[target]
            part_overflow = item - part_in_bucket
            bins[target] = torch.cat(
                (bins[target], torch.tensor([item], device=device)), 0)
            token_bins[target] = torch.cat(
                (token_bins[target], torch.tensor([token_index], device=device)), 0)
            value_in_bins[target] = bucket_size

            # Fill this bucket and continue overflowing
            j = target + 1
            for i in range(0, j):
                space_left_after[i] -= item

            while part_overflow > 0:
                new_part_overflow = (value_in_bins[j] + part_overflow) - bucket_size
                value_in_bins[j] = min(bucket_size, part_overflow + value_in_bins[j])  # mark the bucket as filled
                space_left_after[j] -= part_overflow
                part_overflow = new_part_overflow
                j += 1

    sorted_tensor = torch.cat(bins, 0)
    sorted_tokens = torch.cat(token_bins, 0)

    return sorted_tensor, sorted_tokens


def compute_ev(t, precision):
    """Return the weighted number of fixed leading bits over the intervals
    described by ``t``, divided by ``2**precision``.

    For each entry ``k`` the interval is ``[cumsum[k-1], cumsum[k])``; both
    ends are written as ``precision``-bit numbers (most significant bit
    first) and the length of their shared prefix is weighted by ``t[k]``.

    Args:
        t: tensor of interval widths; must support ``cumsum(0)``.
            (presumably integers summing to ``2**precision`` -- confirm)
        precision: bit width used for the interval bounds.
    """
    expected_bits = []
    cum_probs = t.cumsum(0)

    for selection in range(0, len(cum_probs)):
        # Calculate new range as ints
        new_int_bottom = cum_probs[selection - 1] if selection > 0 else 0
        new_int_top = cum_probs[selection]

        # Convert range to bits
        new_int_bottom_bits_inc = list(reversed(int2bits(new_int_bottom, precision)))
        new_int_top_bits_inc = list(reversed(int2bits(new_int_top - 1, precision)))  # -1 here because upper bound is exclusive

        # Consume most significant bits which are now fixed and update interval
        num_bits_encoded = num_same_from_beg(new_int_bottom_bits_inc, new_int_top_bits_inc)
        expected_bits.append(t[selection] * num_bits_encoded)

    return float(sum(expected_bits).item()) / (2 ** precision)


def visualize_bins(values_in_bins, bucket_size):
    """Print each bucket's fill level as a percentage of ``bucket_size``."""
    out_str = "["
    for b in values_in_bins:
        out_str = out_str + " " + str(round(100 * b / bucket_size, 2)) + " |"
    out_str = out_str + "]"
    print(out_str)


def visualize_distribution(l):
    """Print each entry of ``l`` as a percentage of ``sum(l)``."""
    total = sum(l)
    out_str = "["
    for b in l:
        out_str = out_str + " " + str(round(100 * b / total, 2)) + " |"
    out_str = out_str + "]"
    print(out_str)


def compute_entropy(lists):
    """Return the Shannon entropy, in bits, of the weights in ``lists``.

    Weights are normalized by their sum. Zero weights contribute nothing
    (the 0 * log 0 = 0 convention); previously a zero weight raised a math
    domain error from ``math.log2(0)``.
    """
    total = sum(lists)
    entropy = -1 * sum((x / total) * math.log2(x / total) for x in lists if x != 0)
    return entropy