#@title
import torch
import torch.nn.functional as F
import os

from drbg import DRBG
from utils import bin_sort, bits2int, entropy, int2bits, is_sent_finish, kl, limit_past, num_same_from_beg

# Constants for HMAC-DRBG -- MUST CHANGE FOR SECURE IMPLEMENTATION
sample_key = b'0x01'*64
sample_seed_prefix = b'sample'
sample_nonce_counter = b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'
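# Note: encoder and decoder must share the same key and nonce so that their
# DRBG mask streams agree; the sample values above are insecure placeholders.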


def encode_meteor(model, enc, message, context, finish_sent=False, device='cpu', temp=1.0, precision=16, topk=50000, is_sort=False, randomize_key=False, input_key=sample_key, input_nonce=sample_nonce_counter):
    if randomize_key:
        input_key = os.urandom(64)
    mask_generator = DRBG(input_key, sample_seed_prefix + input_nonce)
    context = torch.tensor(context[-1022:], device=device, dtype=torch.long)

    max_val = 2**precision
    threshold = 2**(-precision)
    cur_interval = [0, max_val]  # bottom inclusive, top exclusive
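    # Unlike a full arithmetic coder, this interval is never narrowed across
    # tokens: each step consumes the bits the sampled token fixes, then starts
    # again from the full [0, 2**precision) range.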

    prev = context
    output = context
    past = None

    total_num = 0
    total_num_for_stats = 0
    total_log_probs = 0
    total_kl = 0  # in bits
    total_entropy_ptau = 0
    total_num_sents = 0

    with torch.no_grad():
        i = 0
        sent_finish = False
        while i < len(message) or (finish_sent and not sent_finish):
            logits, past = model(prev.unsqueeze(0), past=past)
            past = limit_past(past)
            logits[0, -1, -1] = -1e20  # endoftext token can't happen
            logits[0, -1, 628] = -1e20  # 2 newlines token can't happen
            logits, indices = logits[0, -1, :].sort(descending=True)
            logits = logits.double()
            logits_temp = logits / temp
            probs_temp = F.softmax(logits_temp, dim=0)
            log_probs_temp = F.log_softmax(logits_temp, dim=0)
            log_probs = F.log_softmax(logits, dim=0)

            # conditions for having reached the end of the message
            if i >= len(message):
                selection = 0
                sent_finish = is_sent_finish(indices[selection].item(), enc)
            else:
                # Cutoff low probabilities that would be rounded to 0
                cur_int_range = cur_interval[1] - cur_interval[0]
                cur_threshold = 1 / cur_int_range
                k = min(max(2, (probs_temp < cur_threshold).nonzero()[0].item()), topk)
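                # k = number of tokens whose probability survives discretization,
                # clamped to at least 2 and at most topk.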
                probs_temp_int = probs_temp[:k]  # Cutoff all but top k

                old_indices = indices
                indices = indices[:k]

                # Rescale to correct range
                probs_temp_int = probs_temp_int / probs_temp_int.sum() * cur_int_range

                entropy_in_this_distribution = entropy(probs_temp, log_probs_temp)

                # Round probabilities to integers given precision
                probs_temp_int = probs_temp_int.round().long()

                if is_sort:
                    probs_temp_int, indices = bin_sort(probs_temp_int, indices, cur_int_range, entropy_in_this_distribution, device)
                cum_probs = probs_temp_int.cumsum(0)

                # Remove any elements from the bottom if rounding caused the total prob to be too large
                overfill_index = (cum_probs > cur_int_range).nonzero()
                if len(overfill_index) > 0:
                    cum_probs = cum_probs[:overfill_index[0]]

                # Add any mass to the top if removing/rounding causes the total prob to be too small
                cum_probs += cur_int_range - cum_probs[-1]  # add
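                # After these two corrections cum_probs[-1] == cur_int_range exactly,
                # so the discretized distribution partitions the whole interval.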

                # Get out resulting probabilities
                probs_final = cum_probs.clone()
                probs_final[1:] = cum_probs[1:] - cum_probs[:-1]

                # Convert to position in range
                cum_probs += cur_interval[0]

                # Apply the mask to the message
                message_bits = message[i:i+precision]
                if i + precision > len(message):
                    message_bits = message_bits + [0] * (i + precision - len(message))
                mask_bits = mask_generator.generate_bits(precision)
                for b in range(0, len(message_bits)):
                    message_bits[b] = message_bits[b] ^ mask_bits[b]
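                # XOR-ing with the DRBG stream makes the masked bits pseudorandom,
                # so the token selected below looks like a genuine sample from the
                # model's (discretized) next-token distribution.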

                # Get selected index based on binary fraction from message bits
                message_idx = bits2int(reversed(message_bits))
                selection = (cum_probs > message_idx).nonzero()[0].item()

                # Calculate new range as ints
                new_int_bottom = cum_probs[selection-1] if selection > 0 else cur_interval[0]
                new_int_top = cum_probs[selection]

                # Convert range to bits
                new_int_bottom_bits_inc = list(reversed(int2bits(new_int_bottom, precision)))
                new_int_top_bits_inc = list(reversed(int2bits(new_int_top - 1, precision)))  # -1 here because upper bound is exclusive

                # Consume most significant bits which are now fixed
                num_bits_encoded = num_same_from_beg(new_int_bottom_bits_inc, new_int_top_bits_inc)
                i += num_bits_encoded
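                # Only the bits shared by every point in the selected sub-interval
                # are recoverable by the decoder, so only those count as transmitted.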

                # Gather statistics
                total_log_probs += log_probs[selection].item()

                q = probs_final.double() / probs_final.sum()
                logq = q.log()
                total_kl += kl(q, logq, log_probs[:len(q)])
                total_entropy_ptau += entropy_in_this_distribution
                total_num_for_stats += 1

            # Update history with new token
            prev = indices[selection].view(1)
            output = torch.cat((output, prev))
            total_num += 1

            # For text->bits->text
            partial = enc.decode(output[len(context):].tolist())
            if '<eos>' in partial:
                break

    avg_NLL = -total_log_probs / total_num_for_stats
    avg_KL = total_kl / total_num_for_stats
    # avg_Hq = total_entropy_ptau / total_num_for_stats
    words_per_bit = total_num_for_stats / i

    return output[len(context):].tolist(), avg_NLL, avg_KL, words_per_bit
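
# Hedged helper sketch (not in the original file): encode_meteor expects the
# message as a list of bits. The exact bit order must agree with bits2int /
# int2bits from utils; this sketch assumes LSB-first packing within each byte.
def string_to_bits(s):
    bits = []
    for byte in s.encode('utf-8'):
        for b in range(8):
            bits.append((byte >> b) & 1)
    return bits

def bits_to_string(bits):
    data = bytearray()
    for pos in range(0, len(bits) - 7, 8):
        byte = 0
        for b in range(8):
            byte |= bits[pos + b] << b
        data.append(byte)
    return data.decode('utf-8', errors='replace')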


def decode_meteor(model, enc, text, context, device='cpu', temp=1.0, precision=16, topk=50000, is_sort=False, input_key=sample_key, input_nonce=sample_nonce_counter):
    # inp is a list of token indices
    # context is a list of token indices
    inp = enc.encode(text)
    context = torch.tensor(context[-1022:], device=device, dtype=torch.long)

    mask_generator = DRBG(input_key, sample_seed_prefix + input_nonce)

    max_val = 2**precision
    threshold = 2**(-precision)
    cur_interval = [0, max_val]  # bottom inclusive, top exclusive

    prev = context
    past = None
    message = []
    with torch.no_grad():
        i = 0
        while i < len(inp):
            logits, past = model(prev.unsqueeze(0), past=past)
            past = limit_past(past)
            logits[0, -1, -1] = -1e20  # endoftext can't happen
            logits[0, -1, 628] = -1e20  # 2 newlines can't happen
            logits, indices = logits[0, -1, :].sort(descending=True)
            logits = logits.double()
            logits_temp = logits / temp
            log_probs_temp = F.log_softmax(logits_temp, dim=0)
            probs_temp = F.softmax(logits_temp, dim=0)

            # Cutoff low probabilities that would be rounded to 0
            cur_int_range = cur_interval[1] - cur_interval[0]
            cur_threshold = 1 / cur_int_range
            k = min(max(2, (probs_temp < cur_threshold).nonzero()[0].item()), topk)
            probs_temp_int = probs_temp[:k]  # Cutoff all but top k

            # Rescale to correct range
            probs_temp_int = probs_temp_int / probs_temp_int.sum() * cur_int_range
            entropy_in_this_distribution = entropy(probs_temp, log_probs_temp)

            # Round probabilities to integers given precision
            probs_temp_int = probs_temp_int.round().long()
            if is_sort:
                probs_temp_int, indices = bin_sort(probs_temp_int, indices, cur_int_range, entropy_in_this_distribution, device)
            cum_probs = probs_temp_int.cumsum(0)

            # Remove any elements from the bottom if rounding caused the total prob to be too large
            overfill_index = (cum_probs > cur_int_range).nonzero()
            if len(overfill_index) > 0:
                cum_probs = cum_probs[:overfill_index[0]]
                k = overfill_index[0].item()

            # Add any mass to the top if removing/rounding causes the total prob to be too small
            cum_probs += cur_int_range - cum_probs[-1]  # add

            # Convert to position in range
            cum_probs += cur_interval[0]
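            # The discretized distribution above must match the encoder's exactly;
            # the observed token's rank in it determines the recovered sub-interval.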

            rank = (indices == inp[i]).nonzero().item()

            # Handle most errors that could happen because of BPE with heuristic
            if rank >= k:
                true_token_text = enc.decoder[inp[i]]
                for rank_idx in range(k):
                    prop_token_text = enc.decoder[indices[rank_idx].item()]
                    # common case that is not caught
                    if inp[i] == 128 and indices[rank_idx] == 198:
                        rank = rank_idx
                        inp[i] = indices[rank_idx].item()
                        break

                    # Is there a more likely prefix token that could be the actual token generated?
                    if len(prop_token_text) <= len(true_token_text) and \
                            prop_token_text == true_token_text[:len(prop_token_text)]:
                        rank = rank_idx
                        suffix = true_token_text[len(prop_token_text):]
                        suffix_tokens = enc.encode(suffix)  # a list
                        inp[i] = indices[rank_idx].item()
                        inp[i+1:i+1] = suffix_tokens  # insert suffix tokens into list
                        break

                    # Is there a more likely longer token that could be the actual token generated?
                    elif len(prop_token_text) > len(true_token_text) and \
                            true_token_text == prop_token_text[:len(true_token_text)]:
                        whole_text = true_token_text
                        num_extra = 1
                        while len(whole_text) < len(prop_token_text):
                            whole_text += enc.decoder[inp[i+num_extra]]
                            num_extra += 1
                        if prop_token_text == whole_text[:len(prop_token_text)]:
                            rank = rank_idx
                            inp[i] = indices[rank_idx].item()
                            # delete the extra tokens absorbed into the longer token
                            # (slice delete; per-index deletes would shift positions)
                            del inp[i+1:i+num_extra]
                            if len(whole_text) > len(prop_token_text):
                                suffix = whole_text[len(prop_token_text):]
                                suffix_tokens = enc.encode(suffix)  # a list
                                inp[i+1:i+1] = suffix_tokens  # insert suffix tokens into list
                            break
                else:  # for-else: the loop found no plausible repair
                    print('Unable to fix BPE error: token received: %s=%d, text: %s' % (true_token_text, inp[i], text))
                    rank = 0

            selection = rank

            # Calculate new range as ints
            new_int_bottom = cum_probs[selection-1] if selection > 0 else cur_interval[0]
            new_int_top = cum_probs[selection]

            # Convert range to bits
            new_int_bottom_bits_inc = list(reversed(int2bits(new_int_bottom, precision)))
            new_int_top_bits_inc = list(reversed(int2bits(new_int_top - 1, precision)))  # -1 here because upper bound is exclusive

            # Emit most significant bits which are now fixed
            num_bits_encoded = num_same_from_beg(new_int_bottom_bits_inc, new_int_top_bits_inc)
            if i == len(inp) - 1:
                new_bits = new_int_bottom_bits_inc
            else:
                new_bits = new_int_top_bits_inc[:num_bits_encoded]
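            # On the last token there are no further tokens to disambiguate the
            # interval, so the full bottom-of-interval bits are emitted as the
            # best guess for the message tail.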

            # Get the mask and apply it to the recovered bits
            mask_bits = mask_generator.generate_bits(precision)
            for b in range(0, len(new_bits)):
                new_bits[b] = new_bits[b] ^ mask_bits[b]
            message += new_bits

            # Update history with new token
            prev = torch.tensor([inp[i]], device=device, dtype=torch.long)
            i += 1

    return message
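
# Hedged usage sketch (not in the original file). It assumes the interface the
# functions above rely on: a GPT-2-style model where model(input_ids, past=...)
# returns (logits, past) -- an older transformers API -- plus an `enc` tokenizer
# exposing encode/decode/decoder. `get_model` is a hypothetical loader.
#
# model, enc = get_model(device='cpu')  # hypothetical; not defined in this file
# context_tokens = enc.encode('Despite a long history of research, ...')
# message_bits = string_to_bits('sample secret message')
# stego_tokens, avg_nll, avg_kl, words_per_bit = encode_meteor(
#     model, enc, message_bits, context_tokens)
# stego_text = enc.decode(stego_tokens)
# recovered = decode_meteor(model, enc, stego_text, context_tokens)
# assert recovered[:len(message_bits)] == message_bits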