Spaces:
Sleeping
Sleeping
Commit
·
229a3ba
1
Parent(s):
c6624a3
Upload 8 files
Browse files- demo.py +68 -0
- drgb.py +43 -0
- huffman.py +119 -0
- huffman_baseline.py +166 -0
- meteor.py +275 -0
- run_single.py +166 -0
- sample.py +55 -0
- utils.py +296 -0
demo.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from run_single import encode_message,decode_message
|
2 |
+
import gradio as gr
|
3 |
+
|
4 |
+
def encode_decode_message(mode, secret_message, chosen_context):
|
5 |
+
x= encode_message(mode, secret_message, chosen_context)
|
6 |
+
return x
|
7 |
+
|
8 |
+
def decode_encode_message(mode, stegomessage, chosen_context):
|
9 |
+
y = decode_message(mode, stegomessage, chosen_context)
|
10 |
+
return y
|
11 |
+
|
12 |
+
def clear1():
|
13 |
+
# en_input2.value = ""
|
14 |
+
en_method.value = ""
|
15 |
+
en_output.value = ""
|
16 |
+
en_ppl.value = ""
|
17 |
+
en_kl.value = ""
|
18 |
+
en_wordsbit.value = ""
|
19 |
+
# en_entropy.value = ""
|
20 |
+
return "","","","","","",""
|
21 |
+
|
22 |
+
def clear2():
|
23 |
+
# de_input2.value = ""
|
24 |
+
de_method.value = ""
|
25 |
+
de_output.value = ""
|
26 |
+
return "","",""
|
27 |
+
|
28 |
+
# Modify the demo block to add Textbox widgets for metrics
|
29 |
+
with gr.Blocks() as demo:
|
30 |
+
|
31 |
+
gr.Markdown("<center><h1>GPT2隐写系统</h1></center>")
|
32 |
+
gr.Markdown("使用gpt2模型进行隐写与提取.")
|
33 |
+
|
34 |
+
with gr.Tab("加密"):
|
35 |
+
en_input1 = gr.Textbox(label="上下文",placeholder="Input context")
|
36 |
+
en_input2 = gr.Textbox(label="秘密信息",placeholder="Message text to be concealed")
|
37 |
+
en_method = gr.Dropdown(label="嵌入算法", choices=["meteor", "arithmetic", "huffman", "bins"])
|
38 |
+
en_output = gr.Textbox(label="含密文本")
|
39 |
+
with gr.Row():
|
40 |
+
en_ppl = gr.Textbox(label="困惑度")
|
41 |
+
en_kl = gr.Textbox(label="KL散度")
|
42 |
+
en_wordsbit = gr.Textbox(label="每比特所携带的信息量")
|
43 |
+
# en_entropy = gr.Textbox(label="信道容量")
|
44 |
+
with gr.Row():
|
45 |
+
en_button1 = gr.Button("清除")
|
46 |
+
en_button2 = gr.Button("加密")
|
47 |
+
en_input1.value = "Despite a long history of research and wide-spread applications to censorship resistant systems, practical steganographic systems capable of embedding messages into realistic communication distributions, like text, do not exist." #@param ["Washington received his initial military training and command with the Virginia Regiment during the French and Indian War. He was later elected to the Virginia House of Burgesses and was named a delegate to the Continental Congress, where he was appointed Commanding General of the nation's Continental Army. Washington led American forces, allied with France, in the defeat of the British at Yorktown. Once victory for the United States was in hand in 1783, Washington resigned his commission.", "The Alvarez hypothesis posits that the mass extinction of the dinosaurs and many other living things during the Cretaceous-Paleogene extinction event was caused by the impact of a large asteroid on the Earth. Prior to 2013, it was commonly cited as having happened about 65 million years ago, but Renne and colleagues (2013) gave an updated value of 66 million years. Evidence indicates that the asteroid fell in the Yucatan Peninsula, at Chicxulub, Mexico. The hypothesis is named after the father-and-son team of scientists Luis and Walter Alvarez, who first suggested it in 1980. Shortly afterwards, and independently, the same was suggested by Dutch paleontologist Jan Smit.", "Despite a long history of research and wide-spread applications to censorship resistant systems, practical steganographic systems capable of embedding messages into realistic communication distributions, like text, do not exist."] {allow-input: true}
|
48 |
+
# en_input1.value = "Washington received his initial military training and command with the Virginia Regiment during the French and Indian War. He was later elected to the Virginia House of Burgesses and was named a delegate to the Continental Congress, where he was appointed Commanding General of the nation's Continental Army. Washington led American forces, allied with France, in the defeat of the British at Yorktown. Once victory for the United States was in hand in 1783, Washington resigned his commission."
|
49 |
+
en_input2.value = "In me the tiger sniffs the tiger."
|
50 |
+
# en_input2.value = "hello"
|
51 |
+
|
52 |
+
with gr.Tab("解密"):
|
53 |
+
de_input1 = gr.Textbox(label="上下文")
|
54 |
+
de_input2 = gr.Textbox(label="含密文本",placeholder="Output of encrypt")
|
55 |
+
de_method = gr.Dropdown(label="嵌入算法",choices=["meteor", "arithmetic", "huffman", "bins"])
|
56 |
+
de_output = gr.Textbox(label="恢复的秘密信息")
|
57 |
+
with gr.Row():
|
58 |
+
de_button1 = gr.Button("清除")
|
59 |
+
de_button2 = gr.Button("解密")
|
60 |
+
de_input1.value = "Despite a long history of research and wide-spread applications to censorship resistant systems, practical steganographic systems capable of embedding messages into realistic communication distributions, like text, do not exist." #@param ["Washington received his initial military training and command with the Virginia Regiment during the French and Indian War. He was later elected to the Virginia House of Burgesses and was named a delegate to the Continental Congress, where he was appointed Commanding General of the nation's Continental Army. Washington led American forces, allied with France, in the defeat of the British at Yorktown. Once victory for the United States was in hand in 1783, Washington resigned his commission.", "The Alvarez hypothesis posits that the mass extinction of the dinosaurs and many other living things during the Cretaceous-Paleogene extinction event was caused by the impact of a large asteroid on the Earth. Prior to 2013, it was commonly cited as having happened about 65 million years ago, but Renne and colleagues (2013) gave an updated value of 66 million years. Evidence indicates that the asteroid fell in the Yucatan Peninsula, at Chicxulub, Mexico. The hypothesis is named after the father-and-son team of scientists Luis and Walter Alvarez, who first suggested it in 1980. Shortly afterwards, and independently, the same was suggested by Dutch paleontologist Jan Smit.", "Despite a long history of research and wide-spread applications to censorship resistant systems, practical steganographic systems capable of embedding messages into realistic communication distributions, like text, do not exist."] {allow-input: true}
|
61 |
+
# de_input1.value = "Washington received his initial military training and command with the Virginia Regiment during the French and Indian War. He was later elected to the Virginia House of Burgesses and was named a delegate to the Continental Congress, where he was appointed Commanding General of the nation's Continental Army. Washington led American forces, allied with France, in the defeat of the British at Yorktown. Once victory for the United States was in hand in 1783, Washington resigned his commission."
|
62 |
+
|
63 |
+
en_button1.click(clear1, outputs=[en_output,en_ppl,en_kl,en_wordsbit])
|
64 |
+
en_button2.click(encode_decode_message, inputs=[en_method, en_input2, en_input1], outputs=[en_output,en_ppl,en_kl,en_wordsbit])
|
65 |
+
de_button1.click(clear2, outputs=[de_input2,de_output])
|
66 |
+
de_button2.click(decode_encode_message, inputs=[de_method, de_input2, de_input1], outputs=[de_output])
|
67 |
+
|
68 |
+
demo.launch()
|
drgb.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#@title Colab setup { run: "auto", display-mode: "form" }
|
2 |
+
#@markdown This downloads some prereqs. It might take a while! You only have to run this cell once.
|
3 |
+
# !pip install torch==1.13.1 pytorch-transformers==1.1.0 bitarray==1.0.1
|
4 |
+
import hashlib
|
5 |
+
import hmac
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
class DRBG(object):
|
9 |
+
def __init__(self, key, seed):
|
10 |
+
self.key = key
|
11 |
+
self.val = b'\x01' * 64
|
12 |
+
self.reseed(seed)
|
13 |
+
|
14 |
+
self.byte_index = 0
|
15 |
+
self.bit_index = 0
|
16 |
+
|
17 |
+
def hmac(self, key, val):
|
18 |
+
return hmac.new(key, val, hashlib.sha512).digest()
|
19 |
+
|
20 |
+
def reseed(self, data=b''):
|
21 |
+
self.key = self.hmac(self.key, self.val + b'\x00' + data)
|
22 |
+
self.val = self.hmac(self.key, self.val)
|
23 |
+
|
24 |
+
if data:
|
25 |
+
self.key = self.hmac(self.key, self.val + b'\x01' + data)
|
26 |
+
self.val = self.hmac(self.key, self.val)
|
27 |
+
|
28 |
+
def generate_bits(self, n):
|
29 |
+
xs = np.zeros(n, dtype=bool)
|
30 |
+
for i in range(0,n):
|
31 |
+
xs[i] = (self.val[self.byte_index] >> (7 - self.bit_index)) & 1
|
32 |
+
|
33 |
+
self.bit_index += 1
|
34 |
+
if self.bit_index >= 8:
|
35 |
+
self.bit_index = 0
|
36 |
+
self.byte_index += 1
|
37 |
+
|
38 |
+
if self.byte_index >= 8:
|
39 |
+
self.byte_index = 0
|
40 |
+
self.val = self.hmac(self.key, self.val)
|
41 |
+
|
42 |
+
self.reseed()
|
43 |
+
return xs
|
huffman.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import heapq
|
2 |
+
from functools import total_ordering
|
3 |
+
|
4 |
+
"""
|
5 |
+
Code for Huffman Coding, compression and decompression.
|
6 |
+
Explanation at http://bhrigu.me/blog/2017/01/17/huffman-coding-python-implementation/
|
7 |
+
Adapted from https://github.com/bhrigu123/huffman-coding
|
8 |
+
"""
|
9 |
+
|
10 |
+
@total_ordering
|
11 |
+
class HeapNode:
|
12 |
+
def __init__(self, token, freq):
|
13 |
+
self.token = token
|
14 |
+
self.freq = freq
|
15 |
+
self.left = None
|
16 |
+
self.right = None
|
17 |
+
|
18 |
+
# defining comparators less_than and equals
|
19 |
+
def __lt__(self, other):
|
20 |
+
return self.freq < other.freq
|
21 |
+
|
22 |
+
def __eq__(self, other):
|
23 |
+
if(other == None):
|
24 |
+
return False
|
25 |
+
if(not isinstance(other, HeapNode)):
|
26 |
+
return False
|
27 |
+
return self.freq == other.freq
|
28 |
+
|
29 |
+
class HuffmanCoding:
|
30 |
+
def __init__(self):
|
31 |
+
self.heap = []
|
32 |
+
self.codes = {}
|
33 |
+
self.reverse_mapping = {}
|
34 |
+
|
35 |
+
# functions for compression:
|
36 |
+
|
37 |
+
def make_heap(self, frequency):
|
38 |
+
for key in frequency:
|
39 |
+
node = HeapNode(key, frequency[key])
|
40 |
+
heapq.heappush(self.heap, node)
|
41 |
+
|
42 |
+
def make_heap_from_array(self, freqs):
|
43 |
+
for index in range(len(freqs)):
|
44 |
+
node = HeapNode(index, freqs[index])
|
45 |
+
heapq.heappush(self.heap, node)
|
46 |
+
|
47 |
+
def merge_nodes(self):
|
48 |
+
while(len(self.heap)>1):
|
49 |
+
node1 = heapq.heappop(self.heap)
|
50 |
+
node2 = heapq.heappop(self.heap)
|
51 |
+
|
52 |
+
merged = HeapNode(None, node1.freq + node2.freq)
|
53 |
+
merged.left = node1
|
54 |
+
merged.right = node2
|
55 |
+
|
56 |
+
heapq.heappush(self.heap, merged)
|
57 |
+
|
58 |
+
|
59 |
+
def make_codes_helper(self, root, current_code):
|
60 |
+
if(root == None):
|
61 |
+
return
|
62 |
+
|
63 |
+
if(root.token != None):
|
64 |
+
self.codes[root.token] = current_code
|
65 |
+
self.reverse_mapping[current_code] = root.token
|
66 |
+
return
|
67 |
+
|
68 |
+
self.make_codes_helper(root.left, current_code + "0")
|
69 |
+
self.make_codes_helper(root.right, current_code + "1")
|
70 |
+
|
71 |
+
def make_codes(self):
|
72 |
+
root = heapq.heappop(self.heap)
|
73 |
+
current_code = ""
|
74 |
+
self.make_codes_helper(root, current_code)
|
75 |
+
return root
|
76 |
+
|
77 |
+
|
78 |
+
def get_encoded_tokens(self, token_list):
|
79 |
+
encoded_text = ""
|
80 |
+
for token in token_list:
|
81 |
+
encoded_text += self.codes[token]
|
82 |
+
return encoded_text
|
83 |
+
|
84 |
+
def decode_text(self, encoded_text):
|
85 |
+
current_code = ""
|
86 |
+
decoded_text = ""
|
87 |
+
|
88 |
+
for bit in encoded_text:
|
89 |
+
current_code += bit
|
90 |
+
if(current_code in self.reverse_mapping):
|
91 |
+
character = self.reverse_mapping[current_code]
|
92 |
+
decoded_text += character
|
93 |
+
current_code = ""
|
94 |
+
|
95 |
+
return decoded_text
|
96 |
+
|
97 |
+
|
98 |
+
def decompress(self, input_path):
|
99 |
+
filename, file_extension = os.path.splitext(self.path)
|
100 |
+
output_path = filename + "_decompressed" + ".txt"
|
101 |
+
|
102 |
+
with open(input_path, 'rb') as file, open(output_path, 'w') as output:
|
103 |
+
bit_string = ""
|
104 |
+
|
105 |
+
byte = file.read(1)
|
106 |
+
while(len(byte) > 0):
|
107 |
+
byte = ord(byte)
|
108 |
+
bits = bin(byte)[2:].rjust(8, '0')
|
109 |
+
bit_string += bits
|
110 |
+
byte = file.read(1)
|
111 |
+
|
112 |
+
encoded_text = self.remove_padding(bit_string)
|
113 |
+
|
114 |
+
decompressed_text = self.decode_text(encoded_text)
|
115 |
+
|
116 |
+
output.write(decompressed_text)
|
117 |
+
|
118 |
+
return output_path
|
119 |
+
|
huffman_baseline.py
ADDED
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
|
4 |
+
from huffman import HuffmanCoding
|
5 |
+
from utils import kl, entropy, is_sent_finish, limit_past
|
6 |
+
|
7 |
+
def encode_huffman(model, enc, message, context, bits_per_word, finish_sent=False, device='cuda'):
|
8 |
+
length = len(message)
|
9 |
+
|
10 |
+
context = torch.tensor(context[-1022:], device=device, dtype=torch.long)
|
11 |
+
|
12 |
+
prev = context
|
13 |
+
output = context
|
14 |
+
past = None
|
15 |
+
|
16 |
+
total_num = 0
|
17 |
+
total_num_for_stats = 0
|
18 |
+
total_log_probs = 0
|
19 |
+
total_kl = 0 # in bits
|
20 |
+
total_num_sents = 0
|
21 |
+
|
22 |
+
with torch.no_grad():
|
23 |
+
i = 0
|
24 |
+
sent_finish = False
|
25 |
+
while i < length or (finish_sent and not sent_finish):
|
26 |
+
logits, past = model(prev.unsqueeze(0), past=past)
|
27 |
+
past = limit_past(past)
|
28 |
+
logits[0, -1, -1] = -1e10 # endoftext can't happen
|
29 |
+
logits[0, -1, 628] = -1e10 # 2 newlines can't happen
|
30 |
+
logits, indices = logits[0, -1, :].sort(descending=True)
|
31 |
+
|
32 |
+
# Get the top 2**bits options
|
33 |
+
indices = indices[:2**bits_per_word]
|
34 |
+
log_probs = F.log_softmax(logits, dim=-1)[:2**bits_per_word]
|
35 |
+
probs = torch.exp(log_probs)
|
36 |
+
|
37 |
+
if i >= length:
|
38 |
+
selection = 0
|
39 |
+
sent_finish = is_sent_finish(indices[0].item(), enc)
|
40 |
+
else:
|
41 |
+
probs_array = probs.cpu().numpy()
|
42 |
+
coding = HuffmanCoding()
|
43 |
+
coding.make_heap_from_array(probs_array)
|
44 |
+
coding.merge_nodes()
|
45 |
+
root = coding.make_codes()
|
46 |
+
|
47 |
+
#print(message[i:i+10])
|
48 |
+
while root.token is None:
|
49 |
+
if i >= length or message[i] == 0:
|
50 |
+
root = root.left
|
51 |
+
else:
|
52 |
+
root = root.right
|
53 |
+
i += 1
|
54 |
+
selection = root.token
|
55 |
+
|
56 |
+
logq = torch.tensor([-len(coding.codes[idx]) for idx in range(len(probs_array))], dtype=torch.float, device=device) # in bits
|
57 |
+
logq = logq*0.69315 # in nats
|
58 |
+
q = torch.exp(logq)
|
59 |
+
total_kl += kl(q, logq, log_probs)
|
60 |
+
total_log_probs += log_probs[selection].item()
|
61 |
+
total_num_for_stats += 1
|
62 |
+
|
63 |
+
total_num += 1
|
64 |
+
|
65 |
+
prev = indices[selection].view(1)
|
66 |
+
output = torch.cat((output, prev))
|
67 |
+
|
68 |
+
avg_NLL = -total_log_probs/total_num_for_stats
|
69 |
+
avg_KL = total_kl/total_num_for_stats
|
70 |
+
words_per_bit = total_num_for_stats/i
|
71 |
+
|
72 |
+
return output[len(context):].tolist(), avg_NLL, avg_KL, words_per_bit
|
73 |
+
|
74 |
+
def decode_huffman(model, enc, text, context, bits_per_word, device='cuda'):
|
75 |
+
# inp is a list of token indices
|
76 |
+
# context is a list of token indices
|
77 |
+
inp = enc.encode(text)
|
78 |
+
i = 0
|
79 |
+
while i < len(inp):
|
80 |
+
if inp[i] == 628:
|
81 |
+
inp[i] = 198
|
82 |
+
inp[i+1:i+1] = [198]
|
83 |
+
i += 2
|
84 |
+
else:
|
85 |
+
i += 1
|
86 |
+
|
87 |
+
context = torch.tensor(context[-1022:], device=device, dtype=torch.long)
|
88 |
+
prev = context
|
89 |
+
past = None
|
90 |
+
|
91 |
+
message = []
|
92 |
+
with torch.no_grad():
|
93 |
+
i = 0
|
94 |
+
while i < len(inp):
|
95 |
+
if past and past[0].shape[3] >= 1023:
|
96 |
+
raise RuntimeError
|
97 |
+
|
98 |
+
logits, past = model(prev.unsqueeze(0), past=past)
|
99 |
+
past = limit_past(past)
|
100 |
+
logits[0, -1, -1] = -1e10 # endoftext can't happen
|
101 |
+
logits[0, -1, 628] = -1e10 # 2 newlines can't happen
|
102 |
+
logits, indices = logits[0, -1, :].sort(descending=True)
|
103 |
+
|
104 |
+
# Get the top 2**bits options
|
105 |
+
indices = indices[:2**bits_per_word]
|
106 |
+
log_probs = F.log_softmax(logits, dim=-1)[:2**bits_per_word]
|
107 |
+
probs = torch.exp(log_probs)
|
108 |
+
|
109 |
+
if inp[i] not in indices:
|
110 |
+
true_token_text = enc.decoder[inp[i]]
|
111 |
+
for rank_idx in range(2**bits_per_word):
|
112 |
+
prop_token_text = enc.decoder[indices[rank_idx].item()]
|
113 |
+
# common case that is not caught
|
114 |
+
if inp[i] == 128 and indices[rank_idx] == 198:
|
115 |
+
rank = rank_idx
|
116 |
+
inp[i] = indices[rank_idx].item()
|
117 |
+
break
|
118 |
+
|
119 |
+
# Is there a more likely prefix token that could be the actual token generated?
|
120 |
+
if len(prop_token_text) <= len(true_token_text) and \
|
121 |
+
prop_token_text == true_token_text[:len(prop_token_text)]:
|
122 |
+
rank = rank_idx
|
123 |
+
suffix = true_token_text[len(prop_token_text):]
|
124 |
+
suffix_tokens = enc.encode(suffix) # a list
|
125 |
+
inp[i] = indices[rank_idx].item()
|
126 |
+
inp[i+1:i+1] = suffix_tokens # insert suffix tokens into list
|
127 |
+
break
|
128 |
+
|
129 |
+
# Is there a more likely longer token that could be the actual token generated?
|
130 |
+
elif len(prop_token_text) > len(true_token_text) and \
|
131 |
+
true_token_text == prop_token_text[:len(true_token_text)]:
|
132 |
+
whole_text = true_token_text
|
133 |
+
num_extra = 1
|
134 |
+
while len(whole_text) < len(prop_token_text):
|
135 |
+
whole_text += enc.decoder[inp[i+num_extra]]
|
136 |
+
num_extra += 1
|
137 |
+
if prop_token_text == whole_text[:len(prop_token_text)]:
|
138 |
+
rank = rank_idx
|
139 |
+
inp[i] = indices[rank_idx].item()
|
140 |
+
for j in range(1, num_extra):
|
141 |
+
del inp[i+j]
|
142 |
+
|
143 |
+
if len(whole_text) > len(prop_token_text):
|
144 |
+
suffix = whole_text[len(prop_token_text):]
|
145 |
+
suffix_tokens = enc.encode(suffix) # a list
|
146 |
+
inp[i+1:i+1] = suffix_tokens # insert suffix tokens into list
|
147 |
+
break
|
148 |
+
else:
|
149 |
+
print('Unable to fix BPE error: token received: %s=%d, text: %s' % (true_token_text, inp[i], text))
|
150 |
+
rank = 0
|
151 |
+
else:
|
152 |
+
rank = (indices == inp[i]).nonzero().item()
|
153 |
+
|
154 |
+
probs_array = probs.cpu().numpy()
|
155 |
+
coding = HuffmanCoding()
|
156 |
+
coding.make_heap_from_array(probs_array)
|
157 |
+
coding.merge_nodes()
|
158 |
+
coding.make_codes()
|
159 |
+
|
160 |
+
tokens_t = map(int, coding.codes[rank])
|
161 |
+
|
162 |
+
message.extend(tokens_t)
|
163 |
+
prev = torch.tensor([inp[i]], device=device, dtype=torch.long)
|
164 |
+
i += 1
|
165 |
+
|
166 |
+
return message
|
meteor.py
ADDED
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#@title
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
import os
|
6 |
+
|
7 |
+
from drgb import DRBG
|
8 |
+
from utils import bin_sort, bits2int, entropy, int2bits, is_sent_finish, kl, limit_past, num_same_from_beg
|
9 |
+
|
10 |
+
# Constants for HMAC-DRBG -- MUST CHANGE FOR SECURE IMPLEMENTATION
|
11 |
+
sample_key = b'0x01'*64
|
12 |
+
sample_seed_prefix = b'sample'
|
13 |
+
sample_nonce_counter = b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'
|
14 |
+
|
15 |
+
|
16 |
+
def encode_meteor(model, enc, message, context, finish_sent=False, device='cuda', temp=1.0, precision=16, topk=50000, is_sort=False, randomize_key=False, input_key=sample_key, input_nonce=sample_nonce_counter):
|
17 |
+
|
18 |
+
if randomize_key:
|
19 |
+
input_key = os.urandom(64)
|
20 |
+
mask_generator = DRBG(input_key, sample_seed_prefix + input_nonce)
|
21 |
+
context = torch.tensor(context[-1022:], device=device, dtype=torch.long)
|
22 |
+
|
23 |
+
|
24 |
+
max_val = 2**precision
|
25 |
+
threshold = 2**(-precision)
|
26 |
+
cur_interval = [0, max_val] # bottom inclusive, top exclusive
|
27 |
+
|
28 |
+
prev = context
|
29 |
+
output = context
|
30 |
+
past = None
|
31 |
+
|
32 |
+
total_num = 0
|
33 |
+
total_num_for_stats = 0
|
34 |
+
total_log_probs = 0
|
35 |
+
total_kl = 0 # in bits
|
36 |
+
total_entropy_ptau = 0
|
37 |
+
total_num_sents = 0
|
38 |
+
|
39 |
+
with torch.no_grad():
|
40 |
+
i = 0
|
41 |
+
sent_finish = False
|
42 |
+
while i < len(message) or (finish_sent and not sent_finish):
|
43 |
+
logits, past = model(prev.unsqueeze(0), past=past)
|
44 |
+
past = limit_past(past)
|
45 |
+
logits[0, -1, -1] = -1e20 # endoftext token can't happen
|
46 |
+
logits[0, -1, 628] = -1e20 # 2 newlines token can't happen
|
47 |
+
logits, indices = logits[0, -1, :].sort(descending=True)
|
48 |
+
logits = logits.double()
|
49 |
+
logits_temp = logits / temp
|
50 |
+
probs_temp = F.softmax(logits_temp, dim=0)
|
51 |
+
log_probs_temp = F.log_softmax(logits_temp, dim=0)
|
52 |
+
log_probs = F.log_softmax(logits, dim=0)
|
53 |
+
|
54 |
+
# conditions for having reached the end of the message
|
55 |
+
if i >= len(message):
|
56 |
+
selection = 0
|
57 |
+
sent_finish = is_sent_finish(indices[selection].item(), enc)
|
58 |
+
else:
|
59 |
+
# Cutoff low probabilities that would be rounded to 0
|
60 |
+
cur_int_range = cur_interval[1]-cur_interval[0]
|
61 |
+
cur_threshold = 1/cur_int_range
|
62 |
+
k = min(max(2, (probs_temp < cur_threshold).nonzero()[0].item()), topk)
|
63 |
+
probs_temp_int = probs_temp[:k] # Cutoff all but top k
|
64 |
+
old_indices = indices
|
65 |
+
indices = indices[:k]
|
66 |
+
|
67 |
+
# Rescale to correct range
|
68 |
+
probs_temp_int = probs_temp_int/probs_temp_int.sum()*cur_int_range
|
69 |
+
|
70 |
+
entropy_in_this_distribution = entropy(probs_temp, log_probs_temp)
|
71 |
+
|
72 |
+
# Round probabilities to integers given precision
|
73 |
+
probs_temp_int = probs_temp_int.round().long()
|
74 |
+
|
75 |
+
if is_sort:
|
76 |
+
probs_temp_int, indices = bin_sort(probs_temp_int, indices, cur_int_range, entropy_in_this_distribution, device)
|
77 |
+
cum_probs = probs_temp_int.cumsum(0)
|
78 |
+
|
79 |
+
# Remove any elements from the bottom if rounding caused the total prob to be too large
|
80 |
+
overfill_index = (cum_probs > cur_int_range).nonzero()
|
81 |
+
if len(overfill_index) > 0:
|
82 |
+
cum_probs = cum_probs[:overfill_index[0]]
|
83 |
+
|
84 |
+
# Add any mass to the top if removing/rounding causes the total prob to be too small
|
85 |
+
cum_probs += cur_int_range-cum_probs[-1] # add
|
86 |
+
|
87 |
+
# Get out resulting probabilities
|
88 |
+
probs_final = cum_probs.clone()
|
89 |
+
probs_final[1:] = cum_probs[1:] - cum_probs[:-1]
|
90 |
+
|
91 |
+
# Convert to position in range
|
92 |
+
cum_probs += cur_interval[0]
|
93 |
+
|
94 |
+
# Apply the mask to the message
|
95 |
+
message_bits = message[i:i+precision]
|
96 |
+
if i+precision > len(message):
|
97 |
+
message_bits = message_bits + [0]*(i+precision-len(message))
|
98 |
+
|
99 |
+
mask_bits = mask_generator.generate_bits(precision)
|
100 |
+
for b in range(0, len(message_bits)):
|
101 |
+
message_bits[b] = message_bits[b] ^ mask_bits[b]
|
102 |
+
|
103 |
+
# Get selected index based on binary fraction from message bits
|
104 |
+
message_idx = bits2int(reversed(message_bits))
|
105 |
+
selection = (cum_probs > message_idx).nonzero()[0].item()
|
106 |
+
|
107 |
+
# Calculate new range as ints
|
108 |
+
new_int_bottom = cum_probs[selection-1] if selection > 0 else cur_interval[0]
|
109 |
+
new_int_top = cum_probs[selection]
|
110 |
+
|
111 |
+
# Convert range to bits
|
112 |
+
new_int_bottom_bits_inc = list(reversed(int2bits(new_int_bottom, precision)))
|
113 |
+
new_int_top_bits_inc = list(reversed(int2bits(new_int_top-1, precision))) # -1 here because upper bound is exclusive
|
114 |
+
|
115 |
+
# Consume most significant bits which are now fixed and update interval
|
116 |
+
num_bits_encoded = num_same_from_beg(new_int_bottom_bits_inc, new_int_top_bits_inc)
|
117 |
+
i += num_bits_encoded
|
118 |
+
|
119 |
+
# Gather statistics
|
120 |
+
total_log_probs += log_probs[selection].item()
|
121 |
+
|
122 |
+
q = probs_final.double()/probs_final.sum()
|
123 |
+
logq = q.log()
|
124 |
+
total_kl += kl(q, logq, log_probs[:len(q)])
|
125 |
+
total_entropy_ptau += entropy_in_this_distribution
|
126 |
+
total_num_for_stats += 1
|
127 |
+
|
128 |
+
# Update history with new token
|
129 |
+
prev = indices[selection].view(1)
|
130 |
+
output = torch.cat((output, prev))
|
131 |
+
total_num += 1
|
132 |
+
|
133 |
+
# For text->bits->text
|
134 |
+
partial = enc.decode(output[len(context):].tolist())
|
135 |
+
if '<eos>' in partial:
|
136 |
+
break
|
137 |
+
|
138 |
+
avg_NLL = -total_log_probs/total_num_for_stats
|
139 |
+
avg_KL = total_kl/total_num_for_stats
|
140 |
+
# avg_Hq = total_entropy_ptau/total_num_for_stats
|
141 |
+
words_per_bit = total_num_for_stats/i
|
142 |
+
|
143 |
+
return output[len(context):].tolist(), avg_NLL, avg_KL, words_per_bit
|
144 |
+
|
145 |
+
def decode_meteor(model, enc, text, context, device='cuda', temp=1.0, precision=16, topk=50000, is_sort=False, input_key=sample_key, input_nonce=sample_nonce_counter):
|
146 |
+
# inp is a list of token indices
|
147 |
+
# context is a list of token indices
|
148 |
+
inp = enc.encode(text)
|
149 |
+
|
150 |
+
context = torch.tensor(context[-1022:], device=device, dtype=torch.long)
|
151 |
+
mask_generator = DRBG(input_key, sample_seed_prefix + input_nonce)
|
152 |
+
|
153 |
+
max_val = 2**precision
|
154 |
+
threshold = 2**(-precision)
|
155 |
+
cur_interval = [0, max_val] # bottom inclusive, top exclusive
|
156 |
+
|
157 |
+
prev = context
|
158 |
+
past = None
|
159 |
+
message = []
|
160 |
+
with torch.no_grad():
|
161 |
+
i = 0
|
162 |
+
while i < len(inp):
|
163 |
+
logits, past = model(prev.unsqueeze(0), past=past)
|
164 |
+
past = limit_past(past)
|
165 |
+
logits[0, -1, -1] = -1e20 # endoftext can't happen
|
166 |
+
logits[0, -1, 628] = -1e20 # 2 newlines can't happen
|
167 |
+
logits, indices = logits[0, -1, :].sort(descending=True)
|
168 |
+
logits = logits.double()
|
169 |
+
logits_temp = logits / temp
|
170 |
+
log_probs_temp = F.log_softmax(logits_temp, dim=0)
|
171 |
+
probs_temp = F.softmax(logits_temp, dim=0)
|
172 |
+
|
173 |
+
# Cutoff low probabilities that would be rounded to 0
|
174 |
+
cur_int_range = cur_interval[1]-cur_interval[0]
|
175 |
+
cur_threshold = 1/cur_int_range
|
176 |
+
k = min(max(2, (probs_temp < cur_threshold).nonzero()[0].item()), topk)
|
177 |
+
probs_temp_int = probs_temp[:k] # Cutoff all but top k
|
178 |
+
|
179 |
+
# Rescale to correct range
|
180 |
+
probs_temp_int = probs_temp_int/probs_temp_int.sum()*cur_int_range
|
181 |
+
entropy_in_this_distribution = entropy(probs_temp, log_probs_temp)
|
182 |
+
|
183 |
+
# Round probabilities to integers given precision
|
184 |
+
probs_temp_int = probs_temp_int.round().long()
|
185 |
+
if is_sort:
|
186 |
+
probs_temp_int, indices = bin_sort(probs_temp_int, indices, cur_int_range, entropy_in_this_distribution, device)
|
187 |
+
cum_probs = probs_temp_int.cumsum(0)
|
188 |
+
|
189 |
+
# Remove any elements from the bottom if rounding caused the total prob to be too large
|
190 |
+
overfill_index = (cum_probs > cur_int_range).nonzero()
|
191 |
+
if len(overfill_index) > 0:
|
192 |
+
cum_probs = cum_probs[:overfill_index[0]]
|
193 |
+
k = overfill_index[0].item()
|
194 |
+
|
195 |
+
# Add any mass to the top if removing/rounding causes the total prob to be too small
|
196 |
+
cum_probs += cur_int_range-cum_probs[-1] # add
|
197 |
+
|
198 |
+
# Covnert to position in range
|
199 |
+
cum_probs += cur_interval[0]
|
200 |
+
|
201 |
+
rank = (indices == inp[i]).nonzero().item()
|
202 |
+
|
203 |
+
# Handle most errors that could happen because of BPE with heuristic
|
204 |
+
if rank >= k:
|
205 |
+
true_token_text = enc.decoder[inp[i]]
|
206 |
+
for rank_idx in range(k):
|
207 |
+
prop_token_text = enc.decoder[indices[rank_idx].item()]
|
208 |
+
# common case that is not caught
|
209 |
+
if inp[i] == 128 and indices[rank_idx] == 198:
|
210 |
+
rank = rank_idx
|
211 |
+
inp[i] = indices[rank_idx].item()
|
212 |
+
break
|
213 |
+
|
214 |
+
# Is there a more likely prefix token that could be the actual token generated?
|
215 |
+
if len(prop_token_text) <= len(true_token_text) and \
|
216 |
+
prop_token_text == true_token_text[:len(prop_token_text)]:
|
217 |
+
rank = rank_idx
|
218 |
+
suffix = true_token_text[len(prop_token_text):]
|
219 |
+
suffix_tokens = enc.encode(suffix) # a list
|
220 |
+
inp[i] = indices[rank_idx].item()
|
221 |
+
inp[i+1:i+1] = suffix_tokens # insert suffix tokens into list
|
222 |
+
break
|
223 |
+
|
224 |
+
# Is there a more likely longer token that could be the actual token generated?
|
225 |
+
elif len(prop_token_text) > len(true_token_text) and \
|
226 |
+
true_token_text == prop_token_text[:len(true_token_text)]:
|
227 |
+
whole_text = true_token_text
|
228 |
+
num_extra = 1
|
229 |
+
while len(whole_text) < len(prop_token_text):
|
230 |
+
whole_text += enc.decoder[inp[i+num_extra]]
|
231 |
+
num_extra += 1
|
232 |
+
if prop_token_text == whole_text[:len(prop_token_text)]:
|
233 |
+
rank = rank_idx
|
234 |
+
inp[i] = indices[rank_idx].item()
|
235 |
+
for j in range(1, num_extra):
|
236 |
+
del inp[i+j]
|
237 |
+
|
238 |
+
if len(whole_text) > len(prop_token_text):
|
239 |
+
suffix = whole_text[len(prop_token_text):]
|
240 |
+
suffix_tokens = enc.encode(suffix) # a list
|
241 |
+
inp[i+1:i+1] = suffix_tokens # insert suffix tokens into list
|
242 |
+
break
|
243 |
+
else:
|
244 |
+
print('Unable to fix BPE error: token received: %s=%d, text: %s' % (true_token_text, inp[i], text))
|
245 |
+
rank = 0
|
246 |
+
|
247 |
+
selection = rank
|
248 |
+
|
249 |
+
# Calculate new range as ints
|
250 |
+
new_int_bottom = cum_probs[selection-1] if selection > 0 else cur_interval[0]
|
251 |
+
new_int_top = cum_probs[selection]
|
252 |
+
|
253 |
+
# Convert range to bits
|
254 |
+
new_int_bottom_bits_inc = list(reversed(int2bits(new_int_bottom, precision)))
|
255 |
+
new_int_top_bits_inc = list(reversed(int2bits(new_int_top-1, precision))) # -1 here because upper bound is exclusive
|
256 |
+
|
257 |
+
# Emit most significant bits which are now fixed and update interval
|
258 |
+
num_bits_encoded = num_same_from_beg(new_int_bottom_bits_inc, new_int_top_bits_inc)
|
259 |
+
if i == len(inp)-1:
|
260 |
+
new_bits = new_int_bottom_bits_inc
|
261 |
+
else:
|
262 |
+
new_bits = new_int_top_bits_inc[:num_bits_encoded]
|
263 |
+
|
264 |
+
# Get the mask and apply it to the recovered bits
|
265 |
+
mask_bits = mask_generator.generate_bits(precision)
|
266 |
+
for b in range(0, len(new_bits)):
|
267 |
+
new_bits[b] = new_bits[b] ^ mask_bits[b]
|
268 |
+
message += new_bits
|
269 |
+
|
270 |
+
# Update history with new token
|
271 |
+
prev = torch.tensor([inp[i]], device=device, dtype=torch.long)
|
272 |
+
|
273 |
+
i += 1
|
274 |
+
|
275 |
+
return message
|
run_single.py
ADDED
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import bitarray
|
3 |
+
import sys
|
4 |
+
import re
|
5 |
+
import math
|
6 |
+
from meteor import encode_meteor,decode_meteor
|
7 |
+
from utils import get_model, encode_context
|
8 |
+
from arithmetic import encode_arithmetic, decode_arithmetic
|
9 |
+
from block_baseline import get_bins, encode_block, decode_block
|
10 |
+
from huffman_baseline import encode_huffman, decode_huffman
|
11 |
+
from sample import sample
|
12 |
+
|
13 |
+
def encode_message(mode, message_str, context):
|
14 |
+
enc, model = get_model(model_name='gpt2')
|
15 |
+
## PARAMETERS
|
16 |
+
# message_str = input("input secret message:")
|
17 |
+
unicode_enc = False
|
18 |
+
# mode = 'meteor'
|
19 |
+
# mode = input("Please enter mode (arithmetic, huffman, bins, or sample): ")
|
20 |
+
block_size = 3 # for huffman and bins
|
21 |
+
temp = 0.9 # for arithmetic
|
22 |
+
precision = 26 # for arithmetic
|
23 |
+
sample_tokens = 100 # for sample
|
24 |
+
topk = 300
|
25 |
+
finish_sent=True # whether or not to force finish sent. If so, stats displayed will be for non-finished sentence
|
26 |
+
meteor_sort = False
|
27 |
+
meteor_random = False
|
28 |
+
|
29 |
+
key = b'0x01'*64
|
30 |
+
sample_seed_prefix = b'sample'
|
31 |
+
nonce = b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'
|
32 |
+
|
33 |
+
## VALIDATE PARAMETERS
|
34 |
+
if mode not in ['meteor', 'arithmetic', 'huffman', 'bins']:
|
35 |
+
raise NotImplementedError
|
36 |
+
|
37 |
+
if mode == 'bins':
|
38 |
+
bin2words, words2bin = get_bins(len(enc.encoder), block_size)
|
39 |
+
|
40 |
+
# context = \
|
41 |
+
# """Washington received his initial military training and command with the Virginia Regiment during the French and Indian War. He was later elected to the Virginia House of Burgesses and was named a delegate to the Continental Congress, where he was appointed Commanding General of the nation's Continental Army. Washington led American forces, allied with France, in the defeat of the British at Yorktown. Once victory for the United States was in hand in 1783, Washington resigned his commission.
|
42 |
+
# """
|
43 |
+
# context = "Despite a long history of research and wide-spread applications to censorship resistant systems, practical steganographic systems capable of embedding messages into realistic communication distributions, like text, do not exist." #@param ["Washington received his initial military training and command with the Virginia Regiment during the French and Indian War. He was later elected to the Virginia House of Burgesses and was named a delegate to the Continental Congress, where he was appointed Commanding General of the nation's Continental Army. Washington led American forces, allied with France, in the defeat of the British at Yorktown. Once victory for the United States was in hand in 1783, Washington resigned his commission.", "The Alvarez hypothesis posits that the mass extinction of the dinosaurs and many other living things during the Cretaceous-Paleogene extinction event was caused by the impact of a large asteroid on the Earth. Prior to 2013, it was commonly cited as having happened about 65 million years ago, but Renne and colleagues (2013) gave an updated value of 66 million years. Evidence indicates that the asteroid fell in the Yucatan Peninsula, at Chicxulub, Mexico. The hypothesis is named after the father-and-son team of scientists Luis and Walter Alvarez, who first suggested it in 1980. Shortly afterwards, and independently, the same was suggested by Dutch paleontologist Jan Smit.", "Despite a long history of research and wide-spread applications to censorship resistant systems, practical steganographic systems capable of embedding messages into realistic communication distributions, like text, do not exist."] {allow-input: true}
|
44 |
+
context_tokens = encode_context(context, enc)
|
45 |
+
# ------------------------------------------------------------------------------------
|
46 |
+
# ------------------------------------------------------------------------------------
|
47 |
+
# First encode message to uniform bits, without any context
|
48 |
+
# (not essential this is arithmetic vs ascii, but it's more efficient when the message is natural language)
|
49 |
+
if unicode_enc:
|
50 |
+
ba = bitarray.bitarray()
|
51 |
+
ba.frombytes(message_str.encode('utf-8'))
|
52 |
+
message = ba.tolist()
|
53 |
+
else:
|
54 |
+
message_ctx = [enc.encoder['<|endoftext|>']]
|
55 |
+
message_str += '<eos>'
|
56 |
+
message = decode_arithmetic(model, enc, message_str, message_ctx, precision=40, topk=60000)
|
57 |
+
# Next encode bits into cover text, using arbitrary context
|
58 |
+
Hq = 0
|
59 |
+
if mode == 'arithmetic':
|
60 |
+
out, nll, kl, words_per_bit = encode_arithmetic(model, enc, message, context_tokens, temp=temp, finish_sent=finish_sent, precision=precision, topk=topk)
|
61 |
+
elif mode == 'huffman':
|
62 |
+
out, nll, kl, words_per_bit = encode_huffman(model, enc, message, context_tokens, block_size, finish_sent=finish_sent)
|
63 |
+
elif mode == 'bins':
|
64 |
+
out, nll, kl, words_per_bit = encode_block(model, enc, message, context_tokens, block_size, bin2words, words2bin, finish_sent=finish_sent)
|
65 |
+
elif mode == 'meteor':
|
66 |
+
out, nll, kl, words_per_bit = encode_meteor(model, enc, message, context_tokens, temp=temp, finish_sent=finish_sent,
|
67 |
+
precision=precision, topk=topk, is_sort=meteor_sort, randomize_key=meteor_random, input_key=key, input_nonce=nonce)
|
68 |
+
elif mode == 'sample':
|
69 |
+
out, nll, kl, Hq = sample(model, enc, sample_tokens, context_tokens, temperature=temp, topk=topk)
|
70 |
+
words_per_bit = 1
|
71 |
+
text = enc.decode(out)
|
72 |
+
|
73 |
+
# print(message)
|
74 |
+
# print(len(message))
|
75 |
+
print("="*40 + " Encoding " + "="*40)
|
76 |
+
print(text)
|
77 |
+
print('ppl: %0.2f, kl: %0.3f, words/bit: %0.2f, bits/word: %0.2f, entropy: %.2f' % (math.exp(nll), kl, words_per_bit, 1/words_per_bit, Hq/0.69315))
|
78 |
+
|
79 |
+
stats = {
|
80 |
+
"ppl": math.exp(nll),
|
81 |
+
"kl": kl,
|
82 |
+
"wordsbit": words_per_bit,
|
83 |
+
"entropy": Hq/0.69315
|
84 |
+
}
|
85 |
+
# return text, stats
|
86 |
+
return text,stats["ppl"], stats["kl"], stats["wordsbit"]
|
87 |
+
|
88 |
+
def decode_message(mode, text, context):
|
89 |
+
enc, model = get_model(model_name='gpt2')
|
90 |
+
## PARAMETERS
|
91 |
+
unicode_enc = False
|
92 |
+
# mode = 'meteor'
|
93 |
+
block_size = 3 # for huffman and bins
|
94 |
+
temp = 0.9 # for arithmetic
|
95 |
+
precision = 26 # for arithmetic
|
96 |
+
sample_tokens = 100 # for sample
|
97 |
+
topk = 300
|
98 |
+
finish_sent=True # whether or not to force finish sent. If so, stats displayed will be for non-finished sentence
|
99 |
+
meteor_sort = False
|
100 |
+
meteor_random = False
|
101 |
+
|
102 |
+
key = b'0x01'*64
|
103 |
+
sample_seed_prefix = b'sample'
|
104 |
+
nonce = b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'
|
105 |
+
|
106 |
+
## VALIDATE PARAMETERS
|
107 |
+
if mode not in ['meteor', 'arithmetic', 'huffman', 'bins', 'sample']:
|
108 |
+
raise NotImplementedError
|
109 |
+
if mode == 'bins':
|
110 |
+
bin2words, words2bin = get_bins(len(enc.encoder), block_size)
|
111 |
+
|
112 |
+
context_tokens = encode_context(context, enc)
|
113 |
+
|
114 |
+
if mode != 'sample':
|
115 |
+
if mode == 'arithmetic':
|
116 |
+
message_rec = decode_arithmetic(model, enc, text, context_tokens, temp=temp, precision=precision, topk=topk)
|
117 |
+
elif mode == 'huffman':
|
118 |
+
message_rec = decode_huffman(model, enc, text, context_tokens, block_size)
|
119 |
+
elif mode == 'bins':
|
120 |
+
message_rec = decode_block(model, enc, text, context_tokens, block_size, bin2words, words2bin)
|
121 |
+
elif mode == 'meteor':
|
122 |
+
message_rec = decode_meteor(model, enc, text, context_tokens, temp=temp,
|
123 |
+
precision=precision, topk=topk, is_sort=meteor_sort, input_key=key, input_nonce=nonce)
|
124 |
+
|
125 |
+
print("="*35 + " Recovered Message " + "="*35)
|
126 |
+
# print(message_rec)
|
127 |
+
# print("=" * 80)
|
128 |
+
# Finally map message bits back to original text
|
129 |
+
if unicode_enc:
|
130 |
+
message_rec = [bool(item) for item in message_rec]
|
131 |
+
ba = bitarray.bitarray(message_rec)
|
132 |
+
reconst = ba.tobytes().decode('utf-8', 'ignore')
|
133 |
+
else:
|
134 |
+
message_ctx = [enc.encoder['<|endoftext|>']]
|
135 |
+
reconst = encode_arithmetic(model, enc, message_rec, message_ctx, precision=40, topk=60000)
|
136 |
+
reconst = enc.decode(reconst[0])
|
137 |
+
print(reconst[:-5])
|
138 |
+
print("=" * 80)
|
139 |
+
return reconst[:-5]
|
140 |
+
|
141 |
+
# def main():
|
142 |
+
# chosen_context = "Despite a long history of research and wide-spread applications to censorship resistant systems, practical steganographic systems capable of embedding messages into realistic communication distributions, like text, do not exist." #@param ["Washington received his initial military training and command with the Virginia Regiment during the French and Indian War. He was later elected to the Virginia House of Burgesses and was named a delegate to the Continental Congress, where he was appointed Commanding General of the nation's Continental Army. Washington led American forces, allied with France, in the defeat of the British at Yorktown. Once victory for the United States was in hand in 1783, Washington resigned his commission.", "The Alvarez hypothesis posits that the mass extinction of the dinosaurs and many other living things during the Cretaceous-Paleogene extinction event was caused by the impact of a large asteroid on the Earth. Prior to 2013, it was commonly cited as having happened about 65 million years ago, but Renne and colleagues (2013) gave an updated value of 66 million years. Evidence indicates that the asteroid fell in the Yucatan Peninsula, at Chicxulub, Mexico. The hypothesis is named after the father-and-son team of scientists Luis and Walter Alvarez, who first suggested it in 1980. Shortly afterwards, and independently, the same was suggested by Dutch paleontologist Jan Smit.", "Despite a long history of research and wide-spread applications to censorship resistant systems, practical steganographic systems capable of embedding messages into realistic communication distributions, like text, do not exist."] {allow-input: true}
|
143 |
+
# # #@title { run: "auto", display-mode: "form" }
|
144 |
+
# message_text = "generate text!" #@param {type:"string"}
|
145 |
+
# mode = input("Please enter mode (meteor, arithmetic, huffman, bins, or sample): ")
|
146 |
+
# #@title Run me!
|
147 |
+
# #@markdown Make sure to re-run this cell if you change the parameters above.
|
148 |
+
# x = encode_message(mode, message_text, chosen_context)
|
149 |
+
# # print(x[0])
|
150 |
+
# y = decode_message(mode, x[0], chosen_context)
|
151 |
+
|
152 |
+
# if __name__ == '__main__':
|
153 |
+
# main()
|
154 |
+
|
155 |
+
# chosen_context = "Despite a long history of research and wide-spread applications to censorship resistant systems, practical steganographic systems capable of embedding messages into realistic communication distributions, like text, do not exist." #@param ["Washington received his initial military training and command with the Virginia Regiment during the French and Indian War. He was later elected to the Virginia House of Burgesses and was named a delegate to the Continental Congress, where he was appointed Commanding General of the nation's Continental Army. Washington led American forces, allied with France, in the defeat of the British at Yorktown. Once victory for the United States was in hand in 1783, Washington resigned his commission.", "The Alvarez hypothesis posits that the mass extinction of the dinosaurs and many other living things during the Cretaceous-Paleogene extinction event was caused by the impact of a large asteroid on the Earth. Prior to 2013, it was commonly cited as having happened about 65 million years ago, but Renne and colleagues (2013) gave an updated value of 66 million years. Evidence indicates that the asteroid fell in the Yucatan Peninsula, at Chicxulub, Mexico. The hypothesis is named after the father-and-son team of scientists Luis and Walter Alvarez, who first suggested it in 1980. Shortly afterwards, and independently, the same was suggested by Dutch paleontologist Jan Smit.", "Despite a long history of research and wide-spread applications to censorship resistant systems, practical steganographic systems capable of embedding messages into realistic communication distributions, like text, do not exist."] {allow-input: true}
|
156 |
+
# # chosen_context = "Washington received his initial military training and command with the Virginia Regiment during the French and Indian War. He was later elected to the Virginia House of Burgesses and was named a delegate to the Continental Congress, where he was appointed Commanding General of the nation's Continental Army. Washington led American forces, allied with France, in the defeat of the British at Yorktown. Once victory for the United States was in hand in 1783, Washington resigned his commission."
|
157 |
+
# # chosen_context += "\n\n" # to add a little spacing
|
158 |
+
|
159 |
+
# # #@title { run: "auto", display-mode: "form" }
|
160 |
+
# message_text = "generate text!" #@param {type:"string"}
|
161 |
+
# mode = input("Please enter mode (arithmetic, huffman, bins, or sample): ")
|
162 |
+
# #@title Run me!
|
163 |
+
# #@markdown Make sure to re-run this cell if you change the parameters above.
|
164 |
+
# x = encode_message(mode, message_text, chosen_context)
|
165 |
+
# # print(x[0])
|
166 |
+
# y = decode_message(mode, x[0], chosen_context)
|
sample.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
|
4 |
+
from utils import limit_past, kl, entropy
|
5 |
+
|
6 |
+
def sample(model, enc, length, context, temperature=1.0, device='cuda', topk=-1):
|
7 |
+
assert length > 0
|
8 |
+
|
9 |
+
context = torch.tensor(context[-1022:], device=device, dtype=torch.long)
|
10 |
+
|
11 |
+
prev = context
|
12 |
+
output = context
|
13 |
+
past = None
|
14 |
+
|
15 |
+
total_log_probs = 0
|
16 |
+
total_entropy_ptau = 0
|
17 |
+
total_num = 0
|
18 |
+
total_kl = 0 # in bits
|
19 |
+
|
20 |
+
with torch.no_grad():
|
21 |
+
while total_num < length:
|
22 |
+
if past and past[0].shape[3] >= 1023:
|
23 |
+
raise RuntimeError
|
24 |
+
|
25 |
+
logits, past = model(prev.unsqueeze(0), past=past)
|
26 |
+
past = limit_past(past)
|
27 |
+
logits[0, -1, -1] = -1e10 # endoftext can't happen
|
28 |
+
logits[0, -1, 628] = -1e10 # 2 newlines can't happen
|
29 |
+
logits, indices = logits[0, -1, :].sort(descending=True)
|
30 |
+
base_log_probs = F.log_softmax(logits, dim=-1)
|
31 |
+
|
32 |
+
if topk > 0:
|
33 |
+
logits = logits[:topk]
|
34 |
+
|
35 |
+
logits = logits / temperature
|
36 |
+
log_probs = F.log_softmax(logits, dim=-1)
|
37 |
+
probs = torch.exp(log_probs)
|
38 |
+
|
39 |
+
total_kl += kl(probs, log_probs, base_log_probs[:topk])
|
40 |
+
|
41 |
+
selection = torch.multinomial(probs, num_samples=1).item()
|
42 |
+
log_prob_chosen = base_log_probs[selection]
|
43 |
+
total_log_probs += log_prob_chosen.item()
|
44 |
+
|
45 |
+
total_entropy_ptau += entropy(probs, log_probs)
|
46 |
+
|
47 |
+
prev = indices[selection].view(1)
|
48 |
+
output = torch.cat((output, prev))
|
49 |
+
total_num += 1
|
50 |
+
|
51 |
+
avg_NLL = -total_log_probs/total_num
|
52 |
+
avg_KL = total_kl/total_num
|
53 |
+
avg_Hq = total_entropy_ptau/total_num
|
54 |
+
|
55 |
+
return output[len(context):].tolist(), avg_NLL, avg_KL, avg_Hq
|
utils.py
ADDED
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
import bitarray
|
4 |
+
|
5 |
+
from pytorch_transformers import GPT2LMHeadModel, GPT2Tokenizer
|
6 |
+
|
7 |
+
def decode(self, token_ids, **kwargs):
|
8 |
+
filtered_tokens = self.convert_ids_to_tokens(token_ids)
|
9 |
+
text = self.convert_tokens_to_string(filtered_tokens)
|
10 |
+
return text
|
11 |
+
GPT2Tokenizer.decode = decode
|
12 |
+
|
13 |
+
def _convert_token_to_id(self, token):
|
14 |
+
return self.encoder.get(token, 0)
|
15 |
+
GPT2Tokenizer._convert_token_to_id = _convert_token_to_id
|
16 |
+
|
17 |
+
|
18 |
+
def limit_past(past):
|
19 |
+
past = list(past)
|
20 |
+
for i in range(len(past)):
|
21 |
+
past[i] = past[i][:, :, :, -1022:]
|
22 |
+
return past
|
23 |
+
|
24 |
+
def kl(q, logq, logp):
|
25 |
+
res = q*(logq-logp)/0.69315
|
26 |
+
res[q==0] = 0
|
27 |
+
return res.sum().item() # in bits
|
28 |
+
|
29 |
+
def entropy(q, logq):
|
30 |
+
res = q*logq/0.69315
|
31 |
+
res[q==0] = 0
|
32 |
+
return -res.sum().item() # in bits
|
33 |
+
|
34 |
+
# e.g. [0, 1, 1, 1] looks like 1110=14
|
35 |
+
def bits2int(bits):
|
36 |
+
res = 0
|
37 |
+
for i, bit in enumerate(bits):
|
38 |
+
res += bit*(2**i)
|
39 |
+
return res
|
40 |
+
|
41 |
+
def int2bits(inp, num_bits):
|
42 |
+
if num_bits == 0:
|
43 |
+
return []
|
44 |
+
strlist = ('{0:0%db}'%num_bits).format(inp)
|
45 |
+
return [int(strval) for strval in reversed(strlist)]
|
46 |
+
|
47 |
+
def is_sent_finish(token_idx, enc):
|
48 |
+
token = enc.decoder[token_idx]
|
49 |
+
return '.' in token or '!' in token or '?' in token
|
50 |
+
|
51 |
+
def num_same_from_beg(bits1, bits2):
|
52 |
+
assert len(bits1) == len(bits2)
|
53 |
+
for i in range(len(bits1)):
|
54 |
+
if bits1[i] != bits2[i]:
|
55 |
+
break
|
56 |
+
|
57 |
+
return i
|
58 |
+
|
59 |
+
def encode_context(raw_text, enc):
|
60 |
+
context_tokens = [enc.encoder['<|endoftext|>']] + enc.encode(raw_text)
|
61 |
+
return context_tokens
|
62 |
+
|
63 |
+
# Use gpt2-medium for 345M param model
|
64 |
+
# Use gpt2-large for 774M param model
|
65 |
+
def get_model(seed=1234, model_name='gpt2'):
|
66 |
+
np.random.seed(seed)
|
67 |
+
torch.random.manual_seed(seed)
|
68 |
+
torch.cuda.manual_seed(seed)
|
69 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
70 |
+
|
71 |
+
enc = GPT2Tokenizer.from_pretrained(model_name)
|
72 |
+
enc.unk_token = None
|
73 |
+
enc.bos_token = None
|
74 |
+
enc.eos_token = None
|
75 |
+
|
76 |
+
model = GPT2LMHeadModel.from_pretrained(model_name)
|
77 |
+
model.to(device)
|
78 |
+
model.eval()
|
79 |
+
#model.double()
|
80 |
+
|
81 |
+
return enc, model
|
82 |
+
|
83 |
+
enc32_itoc = ['\0', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '.', ',', "'", '!', ' ']
|
84 |
+
enc32_ctoi = {k: v for v, k in enumerate(enc32_itoc)}
|
85 |
+
def enc32(text):
|
86 |
+
bits = []
|
87 |
+
for c in text:
|
88 |
+
bits.extend(int2bits(enc32_ctoi[c], 5))
|
89 |
+
return bits
|
90 |
+
|
91 |
+
def dec32(bits):
|
92 |
+
text = ''
|
93 |
+
for i in range(0, len(bits), 5):
|
94 |
+
c = enc32_itoc[bits2int(bits[i:i+5])]
|
95 |
+
if c == '\0':
|
96 |
+
break
|
97 |
+
text += c
|
98 |
+
return text
|
99 |
+
|
100 |
+
# message should be bit string
|
101 |
+
# encoded should be text string
|
102 |
+
def expansion_ratio(message, encoded):
|
103 |
+
message_bits = len(message)
|
104 |
+
encoded_ba = bitarray.bitarray()
|
105 |
+
encoded_ba.frombytes(encoded.encode('utf-8'))
|
106 |
+
encoded_bits = len(encoded_ba.tolist())
|
107 |
+
return encoded_bits/message_bits
|
108 |
+
|
109 |
+
#@title
|
110 |
+
import torch
|
111 |
+
import math
|
112 |
+
import random
|
113 |
+
|
114 |
+
def bin_sort(l, token_indices, total, entropy, device):
|
115 |
+
#compute entropy for upper bound on the number of bins we need
|
116 |
+
|
117 |
+
bucket_size = total
|
118 |
+
num_bins = 2**int(entropy+1)
|
119 |
+
bucket_size = total / num_bins
|
120 |
+
|
121 |
+
bins = [torch.empty(0, dtype=torch.long, device=device)] * num_bins
|
122 |
+
value_in_bins = [0] * num_bins
|
123 |
+
space_left_after = [total - i*bucket_size for i in range(0,num_bins)]
|
124 |
+
|
125 |
+
|
126 |
+
token_bins = [torch.empty(0, dtype=torch.long, device=device)] * num_bins
|
127 |
+
|
128 |
+
# Figuring out what the search order should be
|
129 |
+
step_size = num_bins/4
|
130 |
+
search_order = []
|
131 |
+
priorities = [0]*num_bins
|
132 |
+
priority = 0
|
133 |
+
search_order.append(int(num_bins/2))
|
134 |
+
search_order.append(0)
|
135 |
+
priorities[int(num_bins/2)] = 0
|
136 |
+
priorities[0] = 0
|
137 |
+
while(step_size>=1):
|
138 |
+
priority += 1
|
139 |
+
for x in range(num_bins-int(step_size), -1, -int(step_size*2)):
|
140 |
+
search_order.append(x)
|
141 |
+
priorities[x] = priority
|
142 |
+
step_size = step_size/2
|
143 |
+
|
144 |
+
# Adding the actual elements
|
145 |
+
for (item, token_index) in zip(l.tolist(), token_indices.tolist()):
|
146 |
+
found_single_bucket_fit = False
|
147 |
+
single_bucket_index = -1
|
148 |
+
single_bucket_value = bucket_size
|
149 |
+
|
150 |
+
found_multi_bucket_bumpless_fit = False
|
151 |
+
multi_bucket_bumpless_index = -1
|
152 |
+
multi_bucket_bumpless_value = total
|
153 |
+
|
154 |
+
found_multi_bucket_bumping_fit = False
|
155 |
+
multi_bucket_bumping_index = -1
|
156 |
+
multi_bucket_bumping_value = total
|
157 |
+
|
158 |
+
for i in search_order: # for index in search_order
|
159 |
+
if(item > space_left_after[i]):
|
160 |
+
continue
|
161 |
+
if(value_in_bins[i] >= bucket_size):
|
162 |
+
continue
|
163 |
+
|
164 |
+
# Priority of choices
|
165 |
+
# 1. Can i place this thing in an empty bucket all on its own?
|
166 |
+
# 2. Can i plan this somewhere where is doesnt have to bump anything else around?
|
167 |
+
# 2a. Minimize the wasted space. Aka use the smallest space (of equal priority) that accomplishes this goal
|
168 |
+
# 3. If not (1) and (2), then put it in the space the bumps stuff the least.
|
169 |
+
|
170 |
+
if(value_in_bins[i] + item > bucket_size): #Would overflow.
|
171 |
+
|
172 |
+
space_before_next_block = bucket_size - value_in_bins[i]
|
173 |
+
for j in range(i+1, len(bins)):
|
174 |
+
if(value_in_bins[j] > 0): # We have found a bucket with something in it. This is how much space we have here.
|
175 |
+
space_before_next_block = space_before_next_block + (bucket_size - value_in_bins[i])
|
176 |
+
break
|
177 |
+
else: # This was a empty bucket
|
178 |
+
space_before_next_block = space_before_next_block + bucket_size
|
179 |
+
|
180 |
+
if((not found_multi_bucket_bumpless_fit) or (found_multi_bucket_bumpless_fit and priorities[i] <= priorities[multi_bucket_bumpless_index])): #This could potentially be a match
|
181 |
+
|
182 |
+
# If this is a valid space to put this without bumping and it is a better fit than previous spaces
|
183 |
+
if(space_before_next_block > item and space_before_next_block < multi_bucket_bumpless_value):
|
184 |
+
# set this to be the pointer! we can fit stuff here
|
185 |
+
found_multi_bucket_bumpless_fit = True
|
186 |
+
multi_bucket_bumpless_index = i
|
187 |
+
multi_bucket_bumpless_value = space_before_next_block
|
188 |
+
|
189 |
+
# Find the overflow that will bump the least
|
190 |
+
if ( item - space_before_next_block < multi_bucket_bumping_value):
|
191 |
+
found_multi_bucket_bumping_fit = True
|
192 |
+
multi_bucket_bumping_index = i
|
193 |
+
multi_bucket_bumping_value = item - space_before_next_block
|
194 |
+
|
195 |
+
if(value_in_bins[i] + item <= bucket_size): #Would fit
|
196 |
+
if(single_bucket_value > value_in_bins[i]):
|
197 |
+
found_single_bucket_fit = True
|
198 |
+
single_bucket_value = value_in_bins[i]
|
199 |
+
single_bucket_index = i
|
200 |
+
|
201 |
+
if (single_bucket_index == multi_bucket_bumpless_index == multi_bucket_bumping_index == -1):
|
202 |
+
bins[0] = torch.cat( (torch.tensor([item], device=device), bins[0]), 0)
|
203 |
+
token_bins[0] = torch.cat( (torch.tensor([token_index], device=device), token_bins[0]), 0)
|
204 |
+
continue
|
205 |
+
|
206 |
+
|
207 |
+
if found_single_bucket_fit:
|
208 |
+
# We found somewhere we can actually fit!
|
209 |
+
bins[single_bucket_index] = torch.cat( (bins[single_bucket_index], torch.tensor([item], device=device)), 0)
|
210 |
+
token_bins[single_bucket_index] = torch.cat( (token_bins[single_bucket_index], torch.tensor([token_index], device=device)), 0)
|
211 |
+
value_in_bins[single_bucket_index] += item
|
212 |
+
for i in range(0, single_bucket_index+1):
|
213 |
+
space_left_after[i] -= item
|
214 |
+
|
215 |
+
elif found_multi_bucket_bumpless_fit:
|
216 |
+
# Found somewhere we can put this without upsetting the force
|
217 |
+
part_in_bucket = bucket_size - value_in_bins[multi_bucket_bumpless_index]
|
218 |
+
part_overflow = item - part_in_bucket
|
219 |
+
bins[multi_bucket_bumpless_index] = torch.cat( (bins[multi_bucket_bumpless_index], torch.tensor([item], device=device)), 0)
|
220 |
+
token_bins[multi_bucket_bumpless_index] = torch.cat( (token_bins[multi_bucket_bumpless_index], torch.tensor([token_index], device=device)), 0)
|
221 |
+
value_in_bins[multi_bucket_bumpless_index] = bucket_size
|
222 |
+
|
223 |
+
# Fill this bucket and continue overflowing
|
224 |
+
j = multi_bucket_bumpless_index + 1
|
225 |
+
for i in range(0, j):
|
226 |
+
space_left_after[i] -= item
|
227 |
+
|
228 |
+
while(part_overflow > 0):
|
229 |
+
new_part_overflow = (value_in_bins[j] + part_overflow) - bucket_size
|
230 |
+
value_in_bins[j] = min(bucket_size, part_overflow+value_in_bins[j]) # mark the bucket as filled
|
231 |
+
space_left_after[j] -= part_overflow
|
232 |
+
part_overflow = new_part_overflow
|
233 |
+
j+=1
|
234 |
+
|
235 |
+
else:
|
236 |
+
part_in_bucket = bucket_size - value_in_bins[multi_bucket_bumping_index]
|
237 |
+
part_overflow = item - part_in_bucket
|
238 |
+
bins[multi_bucket_bumping_index] = torch.cat( (bins[multi_bucket_bumping_index], torch.tensor([item], device=device)), 0)
|
239 |
+
token_bins[multi_bucket_bumping_index] = torch.cat( (token_bins[multi_bucket_bumping_index], torch.tensor([token_index], device=device)), 0)
|
240 |
+
value_in_bins[multi_bucket_bumping_index] = bucket_size
|
241 |
+
|
242 |
+
# Fill this bucket and continue overflowing
|
243 |
+
j = multi_bucket_bumping_index + 1
|
244 |
+
for i in range(0, j):
|
245 |
+
space_left_after[i] -= item
|
246 |
+
while(part_overflow > 0):
|
247 |
+
new_part_overflow = (value_in_bins[j] + part_overflow) - bucket_size
|
248 |
+
value_in_bins[j] = min(bucket_size, part_overflow+value_in_bins[j]) # mark the bucket as filled
|
249 |
+
space_left_after[j] -= part_overflow
|
250 |
+
part_overflow = new_part_overflow
|
251 |
+
j+=1
|
252 |
+
|
253 |
+
sorted_tensor = torch.cat(bins, 0)
|
254 |
+
sorted_tokens = torch.cat(token_bins, 0)
|
255 |
+
|
256 |
+
return sorted_tensor, sorted_tokens
|
257 |
+
|
258 |
+
def compute_ev(t, precision):
|
259 |
+
expected_bits = []
|
260 |
+
cum_probs = t.cumsum(0)
|
261 |
+
|
262 |
+
for selection in range(0, len(cum_probs)):
|
263 |
+
|
264 |
+
# Calculate new range as ints
|
265 |
+
new_int_bottom = cum_probs[selection-1] if selection > 0 else 0
|
266 |
+
new_int_top = cum_probs[selection]
|
267 |
+
|
268 |
+
# Convert range to bits
|
269 |
+
new_int_bottom_bits_inc = list(reversed(int2bits(new_int_bottom, precision)))
|
270 |
+
new_int_top_bits_inc = list(reversed(int2bits(new_int_top-1, precision))) # -1 here because upper bound is exclusive
|
271 |
+
|
272 |
+
# Consume most significant bits which are now fixed and update interval
|
273 |
+
num_bits_encoded = num_same_from_beg(new_int_bottom_bits_inc, new_int_top_bits_inc)
|
274 |
+
expected_bits.append(t[selection] * num_bits_encoded)
|
275 |
+
|
276 |
+
return(float(sum(expected_bits).item())/(2**precision))
|
277 |
+
|
278 |
+
def visualize_bins(values_in_bins, bucket_size):
|
279 |
+
out_str = "["
|
280 |
+
for b in values_in_bins:
|
281 |
+
out_str = out_str + " " + str(round(100*b/bucket_size,2)) + " |"
|
282 |
+
out_str = out_str + "]"
|
283 |
+
print(out_str)
|
284 |
+
|
285 |
+
def visualize_distribution(l):
|
286 |
+
total = sum(l)
|
287 |
+
out_str = "["
|
288 |
+
for b in l:
|
289 |
+
out_str = out_str + " " + str(round(100*b/total,2)) + " |"
|
290 |
+
out_str = out_str + "]"
|
291 |
+
print(out_str)
|
292 |
+
|
293 |
+
def compute_entropy(lists):
|
294 |
+
total = sum(lists)
|
295 |
+
entropy = -1*sum([ (x/total) * math.log2(x/total) for x in lists])
|
296 |
+
return entropy
|