# SteganographywithGPT-2 / huffman_baseline.py
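"""Huffman coding baseline for linguistic steganography with GPT-2.

encode_huffman hides a bit string in model-generated text: at each step the
top 2**bits_per_word next-token candidates are Huffman-coded according to
their model probabilities, and the generator walks the code tree along the
message bits to select the next token.  decode_huffman re-runs the model over
the cover text and, for each observed token, emits the Huffman code of that
token's rank, recovering the hidden bits.
"""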
import torch
import torch.nn.functional as F
from huffman import HuffmanCoding
from utils import kl, entropy, is_sent_finish, limit_past

def encode_huffman(model, enc, message, context, bits_per_word, finish_sent=False, device='cuda'):
    length = len(message)

    context = torch.tensor(context[-1022:], device=device, dtype=torch.long)

    prev = context
    output = context
    past = None

    total_num = 0
    total_num_for_stats = 0
    total_log_probs = 0
    total_kl = 0  # in bits
    total_num_sents = 0

    with torch.no_grad():
        i = 0
        sent_finish = False
        while i < length or (finish_sent and not sent_finish):
            logits, past = model(prev.unsqueeze(0), past=past)
            past = limit_past(past)
            logits[0, -1, -1] = -1e10  # endoftext can't happen
            logits[0, -1, 628] = -1e10  # 2 newlines can't happen
            logits, indices = logits[0, -1, :].sort(descending=True)

            # Get the top 2**bits_per_word options
            indices = indices[:2**bits_per_word]
            log_probs = F.log_softmax(logits, dim=-1)[:2**bits_per_word]
            probs = torch.exp(log_probs)

            if i >= length:
                # Message fully embedded; keep taking the most likely token
                # until the sentence finishes.
                selection = 0
                sent_finish = is_sent_finish(indices[0].item(), enc)
            else:
                # Build a Huffman code over the top candidates' probabilities.
                probs_array = probs.cpu().numpy()
                coding = HuffmanCoding()
                coding.make_heap_from_array(probs_array)
                coding.merge_nodes()
                root = coding.make_codes()

                # Walk the Huffman tree following the message bits
                # (0 = left, 1 = right) until a leaf (candidate rank) is reached.
                while root.token is None:
                    if i >= length or message[i] == 0:
                        root = root.left
                    else:
                        root = root.right
                    i += 1
                selection = root.token

                # Track KL between the implied code distribution q and the model
                # distribution, plus the log-probability of the chosen token.
                logq = torch.tensor([-len(coding.codes[idx]) for idx in range(len(probs_array))], dtype=torch.float, device=device)  # in bits
                logq = logq*0.69315  # in nats
                q = torch.exp(logq)
                total_kl += kl(q, logq, log_probs)
                total_log_probs += log_probs[selection].item()
                total_num_for_stats += 1

            total_num += 1

            prev = indices[selection].view(1)
            output = torch.cat((output, prev))

    avg_NLL = -total_log_probs/total_num_for_stats
    avg_KL = total_kl/total_num_for_stats
    words_per_bit = total_num_for_stats/i

    return output[len(context):].tolist(), avg_NLL, avg_KL, words_per_bit
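
# Tiny illustration (comments only) of the HuffmanCoding interface used above,
# with made-up probabilities; the real call sites pass the model's top-2**b
# next-token probabilities:
#
#     coding = HuffmanCoding()
#     coding.make_heap_from_array([0.5, 0.25, 0.15, 0.10])
#     coding.merge_nodes()
#     root = coding.make_codes()
#     # coding.codes maps candidate rank -> bit string, with more probable
#     # ranks getting shorter codes.  encode_huffman walks `root` along the
#     # message bits (0 = left, 1 = right) to a leaf; decode_huffman looks the
#     # observed token's rank up in coding.codes to emit the bits back.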

def decode_huffman(model, enc, text, context, bits_per_word, device='cuda'):
    # inp is a list of token indices
    # context is a list of token indices
    inp = enc.encode(text)

    # Re-encoding the cover text merges two consecutive newline tokens (198, 198)
    # into the single two-newline token 628; split it back so decoding sees the
    # token sequence the encoder produced.
    i = 0
    while i < len(inp):
        if inp[i] == 628:
            inp[i] = 198
            inp[i+1:i+1] = [198]
            i += 2
        else:
            i += 1

    context = torch.tensor(context[-1022:], device=device, dtype=torch.long)
    prev = context
    past = None

    message = []
    with torch.no_grad():
        i = 0
        while i < len(inp):
            if past and past[0].shape[3] >= 1023:
                raise RuntimeError  # model context window exhausted

            logits, past = model(prev.unsqueeze(0), past=past)
            past = limit_past(past)
            logits[0, -1, -1] = -1e10  # endoftext can't happen
            logits[0, -1, 628] = -1e10  # 2 newlines can't happen
            logits, indices = logits[0, -1, :].sort(descending=True)

            # Get the top 2**bits_per_word options
            indices = indices[:2**bits_per_word]
            log_probs = F.log_softmax(logits, dim=-1)[:2**bits_per_word]
            probs = torch.exp(log_probs)

            if inp[i] not in indices:
                # The observed token is not among the top candidates: the cover
                # text was re-tokenized differently than it was generated.  Try
                # to repair the BPE mismatch by matching prefixes/suffixes.
                true_token_text = enc.decoder[inp[i]]
                for rank_idx in range(2**bits_per_word):
                    prop_token_text = enc.decoder[indices[rank_idx].item()]
                    # common case that is not caught
                    if inp[i] == 128 and indices[rank_idx] == 198:
                        rank = rank_idx
                        inp[i] = indices[rank_idx].item()
                        break

                    # Is there a more likely prefix token that could be the actual token generated?
                    if len(prop_token_text) <= len(true_token_text) and \
                            prop_token_text == true_token_text[:len(prop_token_text)]:
                        rank = rank_idx
                        suffix = true_token_text[len(prop_token_text):]
                        suffix_tokens = enc.encode(suffix)  # a list
                        inp[i] = indices[rank_idx].item()
                        inp[i+1:i+1] = suffix_tokens  # insert suffix tokens into list
                        break

                    # Is there a more likely longer token that could be the actual token generated?
                    elif len(prop_token_text) > len(true_token_text) and \
                            true_token_text == prop_token_text[:len(true_token_text)]:
                        whole_text = true_token_text
                        num_extra = 1
                        while len(whole_text) < len(prop_token_text):
                            whole_text += enc.decoder[inp[i+num_extra]]
                            num_extra += 1
                        if prop_token_text == whole_text[:len(prop_token_text)]:
                            rank = rank_idx
                            inp[i] = indices[rank_idx].item()
                            for j in range(1, num_extra):
                                del inp[i+j]

                            if len(whole_text) > len(prop_token_text):
                                suffix = whole_text[len(prop_token_text):]
                                suffix_tokens = enc.encode(suffix)  # a list
                                inp[i+1:i+1] = suffix_tokens  # insert suffix tokens into list
                            break
                else:
                    print('Unable to fix BPE error: token received: %s=%d, text: %s' % (true_token_text, inp[i], text))
                    rank = 0
            else:
                rank = (indices == inp[i]).nonzero().item()

            # Rebuild the same Huffman code the encoder used and emit the bits
            # corresponding to the observed token's rank.
            probs_array = probs.cpu().numpy()
            coding = HuffmanCoding()
            coding.make_heap_from_array(probs_array)
            coding.merge_nodes()
            coding.make_codes()

            tokens_t = map(int, coding.codes[rank])
            message.extend(tokens_t)
            prev = torch.tensor([inp[i]], device=device, dtype=torch.long)
            i += 1

    return message
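
# ----------------------------------------------------------------------
# Minimal round-trip sketch, not part of the original module.  It assumes a
# GPT-2 language model whose forward call returns (logits, past) -- the
# interface the functions above are written against -- and a matching BPE
# encoder `enc` with encode/decode/decoder; `load_model_and_encoder` is a
# hypothetical placeholder for however the surrounding repo constructs them.
#
#     model, enc = load_model_and_encoder(device='cuda')      # hypothetical helper
#     context_tokens = enc.encode('The weather today was')
#     message_bits = [1, 0, 1, 1, 0, 0, 1, 0]                  # payload to hide
#
#     cover_tokens, avg_nll, avg_kl, words_per_bit = encode_huffman(
#         model, enc, message_bits, context_tokens, bits_per_word=3)
#     cover_text = enc.decode(cover_tokens)
#
#     recovered = decode_huffman(model, enc, cover_text, context_tokens, bits_per_word=3)
#     # Ideally `recovered` starts with message_bits; padding bits from the
#     # last token and BPE re-tokenization quirks can add trailing noise.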