Xinyoumeng233hu committed on
Commit 229a3ba · 1 Parent(s): c6624a3

Upload 8 files

Files changed (8)
  1. demo.py +68 -0
  2. drgb.py +43 -0
  3. huffman.py +119 -0
  4. huffman_baseline.py +166 -0
  5. meteor.py +275 -0
  6. run_single.py +166 -0
  7. sample.py +55 -0
  8. utils.py +296 -0
demo.py ADDED
@@ -0,0 +1,68 @@
+ from run_single import encode_message, decode_message
+ import gradio as gr
+
+ def encode_decode_message(mode, secret_message, chosen_context):
+     x = encode_message(mode, secret_message, chosen_context)
+     return x
+
+ def decode_encode_message(mode, stegomessage, chosen_context):
+     y = decode_message(mode, stegomessage, chosen_context)
+     return y
+
+ def clear1():
+     # Gradio applies the returned values to the handler's outputs,
+     # one value per component, so return exactly four empty strings.
+     return "", "", "", ""
+
+ def clear2():
+     return "", ""
+
+ # Demo block, with Textbox widgets for the metrics
+ with gr.Blocks() as demo:
+
+     gr.Markdown("<center><h1>GPT-2 Steganography System</h1></center>")
+     gr.Markdown("Embed and extract secret messages with the GPT-2 model.")
+
+     with gr.Tab("Encode"):
+         en_input1 = gr.Textbox(label="Context", placeholder="Input context")
+         en_input2 = gr.Textbox(label="Secret message", placeholder="Message text to be concealed")
+         en_method = gr.Dropdown(label="Embedding algorithm", choices=["meteor", "arithmetic", "huffman", "bins"])
+         en_output = gr.Textbox(label="Stego text")
+         with gr.Row():
+             en_ppl = gr.Textbox(label="Perplexity")
+             en_kl = gr.Textbox(label="KL divergence")
+             en_wordsbit = gr.Textbox(label="Words per bit")
+             # en_entropy = gr.Textbox(label="Channel capacity")
+         with gr.Row():
+             en_button1 = gr.Button("Clear")
+             en_button2 = gr.Button("Encode")
+         en_input1.value = "Despite a long history of research and wide-spread applications to censorship resistant systems, practical steganographic systems capable of embedding messages into realistic communication distributions, like text, do not exist."
+         en_input2.value = "In me the tiger sniffs the tiger."
+
+     with gr.Tab("Decode"):
+         de_input1 = gr.Textbox(label="Context")
+         de_input2 = gr.Textbox(label="Stego text", placeholder="Output of the encoder")
+         de_method = gr.Dropdown(label="Embedding algorithm", choices=["meteor", "arithmetic", "huffman", "bins"])
+         de_output = gr.Textbox(label="Recovered secret message")
+         with gr.Row():
+             de_button1 = gr.Button("Clear")
+             de_button2 = gr.Button("Decode")
+         de_input1.value = "Despite a long history of research and wide-spread applications to censorship resistant systems, practical steganographic systems capable of embedding messages into realistic communication distributions, like text, do not exist."
+
+     en_button1.click(clear1, outputs=[en_output, en_ppl, en_kl, en_wordsbit])
+     en_button2.click(encode_decode_message, inputs=[en_method, en_input2, en_input1], outputs=[en_output, en_ppl, en_kl, en_wordsbit])
+     de_button1.click(clear2, outputs=[de_input2, de_output])
+     de_button2.click(decode_encode_message, inputs=[de_method, de_input2, de_input1], outputs=[de_output])
+
+ demo.launch()
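A note on the wiring above (standard Gradio behavior, worth making explicit): a function attached with `.click` must return exactly one value per component in its `outputs` list, which is why `clear1` returns four empty strings for the four output boxes and `clear2` returns two. A minimal illustration, independent of this app:

import gradio as gr

with gr.Blocks() as tiny:
    a = gr.Textbox()
    b = gr.Textbox()
    btn = gr.Button("clear")
    btn.click(lambda: ("", ""), outputs=[a, b])  # two outputs, so the handler returns two values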
drgb.py ADDED
@@ -0,0 +1,43 @@
+ #@title Colab setup { run: "auto", display-mode: "form" }
+ #@markdown This downloads some prereqs. It might take a while! You only have to run this cell once.
+ # !pip install torch==1.13.1 pytorch-transformers==1.1.0 bitarray==1.0.1
+ import hashlib
+ import hmac
+ import numpy as np
+
+ class DRBG(object):
+     def __init__(self, key, seed):
+         self.key = key
+         self.val = b'\x01' * 64
+         self.reseed(seed)
+
+         self.byte_index = 0
+         self.bit_index = 0
+
+     def hmac(self, key, val):
+         return hmac.new(key, val, hashlib.sha512).digest()
+
+     def reseed(self, data=b''):
+         self.key = self.hmac(self.key, self.val + b'\x00' + data)
+         self.val = self.hmac(self.key, self.val)
+
+         if data:
+             self.key = self.hmac(self.key, self.val + b'\x01' + data)
+             self.val = self.hmac(self.key, self.val)
+
+     def generate_bits(self, n):
+         xs = np.zeros(n, dtype=bool)
+         for i in range(0, n):
+             xs[i] = (self.val[self.byte_index] >> (7 - self.bit_index)) & 1
+
+             self.bit_index += 1
+             if self.bit_index >= 8:
+                 self.bit_index = 0
+                 self.byte_index += 1
+
+             if self.byte_index >= 8:
+                 self.byte_index = 0
+                 self.val = self.hmac(self.key, self.val)
+
+         self.reseed()
+         return xs
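A quick determinism check for the generator (a sketch with placeholder key/seed values): two instances built from the same key and seed emit identical bit streams, which is what keeps the Meteor encoder's and decoder's masks in sync.

from drgb import DRBG

a = DRBG(b'\x00' * 64, b'sample' + b'\x00' * 16)
b = DRBG(b'\x00' * 64, b'sample' + b'\x00' * 16)
assert (a.generate_bits(16) == b.generate_bits(16)).all()  # same (key, seed) -> same mask bits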
huffman.py ADDED
@@ -0,0 +1,119 @@
+ import heapq
+ import os
+ from functools import total_ordering
+
+ """
+ Code for Huffman Coding, compression and decompression.
+ Explanation at http://bhrigu.me/blog/2017/01/17/huffman-coding-python-implementation/
+ Adapted from https://github.com/bhrigu123/huffman-coding
+ """
+
+ @total_ordering
+ class HeapNode:
+     def __init__(self, token, freq):
+         self.token = token
+         self.freq = freq
+         self.left = None
+         self.right = None
+
+     # defining comparators less_than and equals
+     def __lt__(self, other):
+         return self.freq < other.freq
+
+     def __eq__(self, other):
+         if other is None:
+             return False
+         if not isinstance(other, HeapNode):
+             return False
+         return self.freq == other.freq
+
+ class HuffmanCoding:
+     def __init__(self):
+         self.heap = []
+         self.codes = {}
+         self.reverse_mapping = {}
+
+     # functions for compression:
+
+     def make_heap(self, frequency):
+         for key in frequency:
+             node = HeapNode(key, frequency[key])
+             heapq.heappush(self.heap, node)
+
+     def make_heap_from_array(self, freqs):
+         for index in range(len(freqs)):
+             node = HeapNode(index, freqs[index])
+             heapq.heappush(self.heap, node)
+
+     def merge_nodes(self):
+         while len(self.heap) > 1:
+             node1 = heapq.heappop(self.heap)
+             node2 = heapq.heappop(self.heap)
+
+             merged = HeapNode(None, node1.freq + node2.freq)
+             merged.left = node1
+             merged.right = node2
+
+             heapq.heappush(self.heap, merged)
+
+     def make_codes_helper(self, root, current_code):
+         if root is None:
+             return
+
+         if root.token is not None:
+             self.codes[root.token] = current_code
+             self.reverse_mapping[current_code] = root.token
+             return
+
+         self.make_codes_helper(root.left, current_code + "0")
+         self.make_codes_helper(root.right, current_code + "1")
+
+     def make_codes(self):
+         root = heapq.heappop(self.heap)
+         current_code = ""
+         self.make_codes_helper(root, current_code)
+         return root
+
+     def get_encoded_tokens(self, token_list):
+         encoded_text = ""
+         for token in token_list:
+             encoded_text += self.codes[token]
+         return encoded_text
+
+     def decode_text(self, encoded_text):
+         current_code = ""
+         decoded_text = ""
+
+         for bit in encoded_text:
+             current_code += bit
+             if current_code in self.reverse_mapping:
+                 character = self.reverse_mapping[current_code]
+                 decoded_text += character
+                 current_code = ""
+
+         return decoded_text
+
+     def remove_padding(self, padded_encoded_text):
+         # Strips the 8-bit padding header written by the upstream repo's compress();
+         # reconstructed from that repo, since this file does not ship a compress().
+         padded_info = padded_encoded_text[:8]
+         extra_padding = int(padded_info, 2)
+         return padded_encoded_text[8:len(padded_encoded_text) - extra_padding]
+
+     def decompress(self, input_path):
+         filename, file_extension = os.path.splitext(input_path)  # was self.path, which is never set
+         output_path = filename + "_decompressed" + ".txt"
+
+         with open(input_path, 'rb') as file, open(output_path, 'w') as output:
+             bit_string = ""
+
+             byte = file.read(1)
+             while len(byte) > 0:
+                 byte = ord(byte)
+                 bits = bin(byte)[2:].rjust(8, '0')
+                 bit_string += bits
+                 byte = file.read(1)
+
+             encoded_text = self.remove_padding(bit_string)
+             decompressed_text = self.decode_text(encoded_text)
+             output.write(decompressed_text)
+
+         return output_path
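For reference, a small self-contained example of how the steganography code drives this class: build a tree over a toy probability array with `make_heap_from_array`, merge, then read off prefix-free codes keyed by token index. Exact 0/1 assignments can vary with heap tie-breaking, but the code lengths are determined by the frequencies.

from huffman import HuffmanCoding

coding = HuffmanCoding()
coding.make_heap_from_array([0.5, 0.25, 0.125, 0.125])
coding.merge_nodes()
root = coding.make_codes()
print(coding.codes)                          # e.g. {0: '0', 1: '10', 2: '110', 3: '111'}
print(coding.get_encoded_tokens([0, 2, 1]))  # concatenation of the three codes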
huffman_baseline.py ADDED
@@ -0,0 +1,166 @@
+ import torch
+ import torch.nn.functional as F
+
+ from huffman import HuffmanCoding
+ from utils import kl, entropy, is_sent_finish, limit_past
+
+ def encode_huffman(model, enc, message, context, bits_per_word, finish_sent=False, device='cuda'):
+     length = len(message)
+
+     context = torch.tensor(context[-1022:], device=device, dtype=torch.long)
+
+     prev = context
+     output = context
+     past = None
+
+     total_num = 0
+     total_num_for_stats = 0
+     total_log_probs = 0
+     total_kl = 0 # in bits
+     total_num_sents = 0
+
+     with torch.no_grad():
+         i = 0
+         sent_finish = False
+         while i < length or (finish_sent and not sent_finish):
+             logits, past = model(prev.unsqueeze(0), past=past)
+             past = limit_past(past)
+             logits[0, -1, -1] = -1e10 # endoftext can't happen
+             logits[0, -1, 628] = -1e10 # 2 newlines can't happen
+             logits, indices = logits[0, -1, :].sort(descending=True)
+
+             # Get the top 2**bits options
+             indices = indices[:2**bits_per_word]
+             log_probs = F.log_softmax(logits, dim=-1)[:2**bits_per_word]
+             probs = torch.exp(log_probs)
+
+             if i >= length:
+                 selection = 0
+                 sent_finish = is_sent_finish(indices[0].item(), enc)
+             else:
+                 probs_array = probs.cpu().numpy()
+                 coding = HuffmanCoding()
+                 coding.make_heap_from_array(probs_array)
+                 coding.merge_nodes()
+                 root = coding.make_codes()
+
+                 while root.token is None:
+                     if i >= length or message[i] == 0:
+                         root = root.left
+                     else:
+                         root = root.right
+                     i += 1
+                 selection = root.token
+
+                 logq = torch.tensor([-len(coding.codes[idx]) for idx in range(len(probs_array))], dtype=torch.float, device=device) # in bits
+                 logq = logq*0.69315 # in nats
+                 q = torch.exp(logq)
+                 total_kl += kl(q, logq, log_probs)
+                 total_log_probs += log_probs[selection].item()
+                 total_num_for_stats += 1
+
+             total_num += 1
+
+             prev = indices[selection].view(1)
+             output = torch.cat((output, prev))
+
+     avg_NLL = -total_log_probs/total_num_for_stats
+     avg_KL = total_kl/total_num_for_stats
+     words_per_bit = total_num_for_stats/i
+
+     return output[len(context):].tolist(), avg_NLL, avg_KL, words_per_bit
+
+ def decode_huffman(model, enc, text, context, bits_per_word, device='cuda'):
+     # inp is a list of token indices
+     # context is a list of token indices
+     inp = enc.encode(text)
+     i = 0
+     while i < len(inp):
+         if inp[i] == 628:
+             inp[i] = 198
+             inp[i+1:i+1] = [198]
+             i += 2
+         else:
+             i += 1
+
+     context = torch.tensor(context[-1022:], device=device, dtype=torch.long)
+     prev = context
+     past = None
+
+     message = []
+     with torch.no_grad():
+         i = 0
+         while i < len(inp):
+             if past and past[0].shape[3] >= 1023:
+                 raise RuntimeError
+
+             logits, past = model(prev.unsqueeze(0), past=past)
+             past = limit_past(past)
+             logits[0, -1, -1] = -1e10 # endoftext can't happen
+             logits[0, -1, 628] = -1e10 # 2 newlines can't happen
+             logits, indices = logits[0, -1, :].sort(descending=True)
+
+             # Get the top 2**bits options
+             indices = indices[:2**bits_per_word]
+             log_probs = F.log_softmax(logits, dim=-1)[:2**bits_per_word]
+             probs = torch.exp(log_probs)
+
+             if inp[i] not in indices:
+                 true_token_text = enc.decoder[inp[i]]
+                 for rank_idx in range(2**bits_per_word):
+                     prop_token_text = enc.decoder[indices[rank_idx].item()]
+                     # common case that is not caught
+                     if inp[i] == 128 and indices[rank_idx] == 198:
+                         rank = rank_idx
+                         inp[i] = indices[rank_idx].item()
+                         break
+
+                     # Is there a more likely prefix token that could be the actual token generated?
+                     if len(prop_token_text) <= len(true_token_text) and \
+                             prop_token_text == true_token_text[:len(prop_token_text)]:
+                         rank = rank_idx
+                         suffix = true_token_text[len(prop_token_text):]
+                         suffix_tokens = enc.encode(suffix) # a list
+                         inp[i] = indices[rank_idx].item()
+                         inp[i+1:i+1] = suffix_tokens # insert suffix tokens into list
+                         break
+
+                     # Is there a more likely longer token that could be the actual token generated?
+                     elif len(prop_token_text) > len(true_token_text) and \
+                             true_token_text == prop_token_text[:len(true_token_text)]:
+                         whole_text = true_token_text
+                         num_extra = 1
+                         while len(whole_text) < len(prop_token_text):
+                             whole_text += enc.decoder[inp[i+num_extra]]
+                             num_extra += 1
+                         if prop_token_text == whole_text[:len(prop_token_text)]:
+                             rank = rank_idx
+                             inp[i] = indices[rank_idx].item()
+                             for j in range(1, num_extra):
+                                 del inp[i+j]
+
+                             if len(whole_text) > len(prop_token_text):
+                                 suffix = whole_text[len(prop_token_text):]
+                                 suffix_tokens = enc.encode(suffix) # a list
+                                 inp[i+1:i+1] = suffix_tokens # insert suffix tokens into list
+                             break
+                 else:
+                     print('Unable to fix BPE error: token received: %s=%d, text: %s' % (true_token_text, inp[i], text))
+                     rank = 0
+             else:
+                 rank = (indices == inp[i]).nonzero().item()
+
+             probs_array = probs.cpu().numpy()
+             coding = HuffmanCoding()
+             coding.make_heap_from_array(probs_array)
+             coding.merge_nodes()
+             coding.make_codes()
+
+             tokens_t = map(int, coding.codes[rank])
+
+             message.extend(tokens_t)
+             prev = torch.tensor([inp[i]], device=device, dtype=torch.long)
+             i += 1
+
+     return message
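To make the embedding step concrete: `encode_huffman`'s inner while-loop walks the Huffman tree, consuming one message bit per branch until it reaches a leaf, and the leaf's token index is what gets emitted. A toy version of that walk with a fixed distribution (a sketch; the real code builds `probs_array` from GPT-2's top 2**bits_per_word probabilities):

from huffman import HuffmanCoding

def pick_token(probs_array, bits):
    coding = HuffmanCoding()
    coding.make_heap_from_array(probs_array)
    coding.merge_nodes()
    root = coding.make_codes()
    i = 0
    while root.token is None:  # same walk as encode_huffman's inner loop
        root = root.left if (i >= len(bits) or bits[i] == 0) else root.right
        i += 1
    return root.token, i       # token chosen and number of message bits consumed

token, used = pick_token([0.5, 0.25, 0.125, 0.125], [1, 0, 1])
print(token, used)  # likely tokens sit near the root, so they consume fewer bits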
meteor.py ADDED
@@ -0,0 +1,275 @@
+ #@title
+ import torch
+ import torch.nn.functional as F
+
+ import os
+
+ from drgb import DRBG
+ from utils import bin_sort, bits2int, entropy, int2bits, is_sent_finish, kl, limit_past, num_same_from_beg
+
+ # Constants for HMAC-DRBG -- MUST CHANGE FOR SECURE IMPLEMENTATION
+ sample_key = b'0x01'*64
+ sample_seed_prefix = b'sample'
+ sample_nonce_counter = b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'
+
+
+ def encode_meteor(model, enc, message, context, finish_sent=False, device='cuda', temp=1.0, precision=16, topk=50000, is_sort=False, randomize_key=False, input_key=sample_key, input_nonce=sample_nonce_counter):
+
+     if randomize_key:
+         input_key = os.urandom(64)
+     mask_generator = DRBG(input_key, sample_seed_prefix + input_nonce)
+     context = torch.tensor(context[-1022:], device=device, dtype=torch.long)
+
+     max_val = 2**precision
+     threshold = 2**(-precision)
+     cur_interval = [0, max_val] # bottom inclusive, top exclusive
+
+     prev = context
+     output = context
+     past = None
+
+     total_num = 0
+     total_num_for_stats = 0
+     total_log_probs = 0
+     total_kl = 0 # in bits
+     total_entropy_ptau = 0
+     total_num_sents = 0
+
+     with torch.no_grad():
+         i = 0
+         sent_finish = False
+         while i < len(message) or (finish_sent and not sent_finish):
+             logits, past = model(prev.unsqueeze(0), past=past)
+             past = limit_past(past)
+             logits[0, -1, -1] = -1e20 # endoftext token can't happen
+             logits[0, -1, 628] = -1e20 # 2 newlines token can't happen
+             logits, indices = logits[0, -1, :].sort(descending=True)
+             logits = logits.double()
+             logits_temp = logits / temp
+             probs_temp = F.softmax(logits_temp, dim=0)
+             log_probs_temp = F.log_softmax(logits_temp, dim=0)
+             log_probs = F.log_softmax(logits, dim=0)
+
+             # conditions for having reached the end of the message
+             if i >= len(message):
+                 selection = 0
+                 sent_finish = is_sent_finish(indices[selection].item(), enc)
+             else:
+                 # Cutoff low probabilities that would be rounded to 0
+                 cur_int_range = cur_interval[1]-cur_interval[0]
+                 cur_threshold = 1/cur_int_range
+                 k = min(max(2, (probs_temp < cur_threshold).nonzero()[0].item()), topk)
+                 probs_temp_int = probs_temp[:k] # Cutoff all but top k
+                 old_indices = indices
+                 indices = indices[:k]
+
+                 # Rescale to correct range
+                 probs_temp_int = probs_temp_int/probs_temp_int.sum()*cur_int_range
+
+                 entropy_in_this_distribution = entropy(probs_temp, log_probs_temp)
+
+                 # Round probabilities to integers given precision
+                 probs_temp_int = probs_temp_int.round().long()
+
+                 if is_sort:
+                     probs_temp_int, indices = bin_sort(probs_temp_int, indices, cur_int_range, entropy_in_this_distribution, device)
+                 cum_probs = probs_temp_int.cumsum(0)
+
+                 # Remove any elements from the bottom if rounding caused the total prob to be too large
+                 overfill_index = (cum_probs > cur_int_range).nonzero()
+                 if len(overfill_index) > 0:
+                     cum_probs = cum_probs[:overfill_index[0]]
+
+                 # Add any mass to the top if removing/rounding causes the total prob to be too small
+                 cum_probs += cur_int_range-cum_probs[-1] # add
+
+                 # Get out resulting probabilities
+                 probs_final = cum_probs.clone()
+                 probs_final[1:] = cum_probs[1:] - cum_probs[:-1]
+
+                 # Convert to position in range
+                 cum_probs += cur_interval[0]
+
+                 # Apply the mask to the message
+                 message_bits = message[i:i+precision]
+                 if i+precision > len(message):
+                     message_bits = message_bits + [0]*(i+precision-len(message))
+
+                 mask_bits = mask_generator.generate_bits(precision)
+                 for b in range(0, len(message_bits)):
+                     message_bits[b] = message_bits[b] ^ mask_bits[b]
+
+                 # Get selected index based on binary fraction from message bits
+                 message_idx = bits2int(reversed(message_bits))
+                 selection = (cum_probs > message_idx).nonzero()[0].item()
+
+                 # Calculate new range as ints
+                 new_int_bottom = cum_probs[selection-1] if selection > 0 else cur_interval[0]
+                 new_int_top = cum_probs[selection]
+
+                 # Convert range to bits
+                 new_int_bottom_bits_inc = list(reversed(int2bits(new_int_bottom, precision)))
+                 new_int_top_bits_inc = list(reversed(int2bits(new_int_top-1, precision))) # -1 here because upper bound is exclusive
+
+                 # Consume most significant bits which are now fixed and update interval
+                 num_bits_encoded = num_same_from_beg(new_int_bottom_bits_inc, new_int_top_bits_inc)
+                 i += num_bits_encoded
+
+                 # Gather statistics
+                 total_log_probs += log_probs[selection].item()
+
+                 q = probs_final.double()/probs_final.sum()
+                 logq = q.log()
+                 total_kl += kl(q, logq, log_probs[:len(q)])
+                 total_entropy_ptau += entropy_in_this_distribution
+                 total_num_for_stats += 1
+
+             # Update history with new token
+             prev = indices[selection].view(1)
+             output = torch.cat((output, prev))
+             total_num += 1
+
+             # For text->bits->text
+             partial = enc.decode(output[len(context):].tolist())
+             if '<eos>' in partial:
+                 break
+
+     avg_NLL = -total_log_probs/total_num_for_stats
+     avg_KL = total_kl/total_num_for_stats
+     # avg_Hq = total_entropy_ptau/total_num_for_stats
+     words_per_bit = total_num_for_stats/i
+
+     return output[len(context):].tolist(), avg_NLL, avg_KL, words_per_bit
+
+ def decode_meteor(model, enc, text, context, device='cuda', temp=1.0, precision=16, topk=50000, is_sort=False, input_key=sample_key, input_nonce=sample_nonce_counter):
+     # inp is a list of token indices
+     # context is a list of token indices
+     inp = enc.encode(text)
+
+     context = torch.tensor(context[-1022:], device=device, dtype=torch.long)
+     mask_generator = DRBG(input_key, sample_seed_prefix + input_nonce)
+
+     max_val = 2**precision
+     threshold = 2**(-precision)
+     cur_interval = [0, max_val] # bottom inclusive, top exclusive
+
+     prev = context
+     past = None
+     message = []
+     with torch.no_grad():
+         i = 0
+         while i < len(inp):
+             logits, past = model(prev.unsqueeze(0), past=past)
+             past = limit_past(past)
+             logits[0, -1, -1] = -1e20 # endoftext can't happen
+             logits[0, -1, 628] = -1e20 # 2 newlines can't happen
+             logits, indices = logits[0, -1, :].sort(descending=True)
+             logits = logits.double()
+             logits_temp = logits / temp
+             log_probs_temp = F.log_softmax(logits_temp, dim=0)
+             probs_temp = F.softmax(logits_temp, dim=0)
+
+             # Cutoff low probabilities that would be rounded to 0
+             cur_int_range = cur_interval[1]-cur_interval[0]
+             cur_threshold = 1/cur_int_range
+             k = min(max(2, (probs_temp < cur_threshold).nonzero()[0].item()), topk)
+             probs_temp_int = probs_temp[:k] # Cutoff all but top k
+
+             # Rescale to correct range
+             probs_temp_int = probs_temp_int/probs_temp_int.sum()*cur_int_range
+             entropy_in_this_distribution = entropy(probs_temp, log_probs_temp)
+
+             # Round probabilities to integers given precision
+             probs_temp_int = probs_temp_int.round().long()
+             if is_sort:
+                 probs_temp_int, indices = bin_sort(probs_temp_int, indices, cur_int_range, entropy_in_this_distribution, device)
+             cum_probs = probs_temp_int.cumsum(0)
+
+             # Remove any elements from the bottom if rounding caused the total prob to be too large
+             overfill_index = (cum_probs > cur_int_range).nonzero()
+             if len(overfill_index) > 0:
+                 cum_probs = cum_probs[:overfill_index[0]]
+                 k = overfill_index[0].item()
+
+             # Add any mass to the top if removing/rounding causes the total prob to be too small
+             cum_probs += cur_int_range-cum_probs[-1] # add
+
+             # Convert to position in range
+             cum_probs += cur_interval[0]
+
+             rank = (indices == inp[i]).nonzero().item()
+
+             # Handle most errors that could happen because of BPE with heuristic
+             if rank >= k:
+                 true_token_text = enc.decoder[inp[i]]
+                 for rank_idx in range(k):
+                     prop_token_text = enc.decoder[indices[rank_idx].item()]
+                     # common case that is not caught
+                     if inp[i] == 128 and indices[rank_idx] == 198:
+                         rank = rank_idx
+                         inp[i] = indices[rank_idx].item()
+                         break
+
+                     # Is there a more likely prefix token that could be the actual token generated?
+                     if len(prop_token_text) <= len(true_token_text) and \
+                             prop_token_text == true_token_text[:len(prop_token_text)]:
+                         rank = rank_idx
+                         suffix = true_token_text[len(prop_token_text):]
+                         suffix_tokens = enc.encode(suffix) # a list
+                         inp[i] = indices[rank_idx].item()
+                         inp[i+1:i+1] = suffix_tokens # insert suffix tokens into list
+                         break
+
+                     # Is there a more likely longer token that could be the actual token generated?
+                     elif len(prop_token_text) > len(true_token_text) and \
+                             true_token_text == prop_token_text[:len(true_token_text)]:
+                         whole_text = true_token_text
+                         num_extra = 1
+                         while len(whole_text) < len(prop_token_text):
+                             whole_text += enc.decoder[inp[i+num_extra]]
+                             num_extra += 1
+                         if prop_token_text == whole_text[:len(prop_token_text)]:
+                             rank = rank_idx
+                             inp[i] = indices[rank_idx].item()
+                             for j in range(1, num_extra):
+                                 del inp[i+j]
+
+                             if len(whole_text) > len(prop_token_text):
+                                 suffix = whole_text[len(prop_token_text):]
+                                 suffix_tokens = enc.encode(suffix) # a list
+                                 inp[i+1:i+1] = suffix_tokens # insert suffix tokens into list
+                             break
+                 else:
+                     print('Unable to fix BPE error: token received: %s=%d, text: %s' % (true_token_text, inp[i], text))
+                     rank = 0
+
+             selection = rank
+
+             # Calculate new range as ints
+             new_int_bottom = cum_probs[selection-1] if selection > 0 else cur_interval[0]
+             new_int_top = cum_probs[selection]
+
+             # Convert range to bits
+             new_int_bottom_bits_inc = list(reversed(int2bits(new_int_bottom, precision)))
+             new_int_top_bits_inc = list(reversed(int2bits(new_int_top-1, precision))) # -1 here because upper bound is exclusive
+
+             # Emit most significant bits which are now fixed and update interval
+             num_bits_encoded = num_same_from_beg(new_int_bottom_bits_inc, new_int_top_bits_inc)
+             if i == len(inp)-1:
+                 new_bits = new_int_bottom_bits_inc
+             else:
+                 new_bits = new_int_top_bits_inc[:num_bits_encoded]
+
+             # Get the mask and apply it to the recovered bits
+             mask_bits = mask_generator.generate_bits(precision)
+             for b in range(0, len(new_bits)):
+                 new_bits[b] = new_bits[b] ^ mask_bits[b]
+             message += new_bits
+
+             # Update history with new token
+             prev = torch.tensor([inp[i]], device=device, dtype=torch.long)
+
+             i += 1
+
+     return message
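The number of message bits a single token carries equals the length of the binary prefix shared by its subinterval's endpoints; that is what `num_same_from_beg` computes on the bit-decomposed interval. A worked toy example at precision=4, i.e. the interval [0, 16), with endpoints invented for illustration:

from utils import int2bits, num_same_from_beg

precision = 4
new_int_bottom, new_int_top = 8, 12   # suppose the selected token owns [8, 12)
bottom_bits = list(reversed(int2bits(new_int_bottom, precision)))  # [1, 0, 0, 0]
top_bits = list(reversed(int2bits(new_int_top - 1, precision)))    # 11 -> [1, 0, 1, 1]
print(num_same_from_beg(bottom_bits, top_bits))  # 2: every value in [8, 12) starts with '10'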
run_single.py ADDED
@@ -0,0 +1,166 @@
+ import numpy as np
+ import bitarray
+ import sys
+ import re
+ import math
+ from meteor import encode_meteor, decode_meteor
+ from utils import get_model, encode_context
+ # Note: arithmetic.py and block_baseline.py are required here but are not among the eight files in this upload.
+ from arithmetic import encode_arithmetic, decode_arithmetic
+ from block_baseline import get_bins, encode_block, decode_block
+ from huffman_baseline import encode_huffman, decode_huffman
+ from sample import sample
+
+ def encode_message(mode, message_str, context):
+     enc, model = get_model(model_name='gpt2')
+     ## PARAMETERS
+     unicode_enc = False
+     block_size = 3 # for huffman and bins
+     temp = 0.9 # for arithmetic
+     precision = 26 # for arithmetic
+     sample_tokens = 100 # for sample
+     topk = 300
+     finish_sent = True # whether to force-finish the sentence; if so, the displayed stats cover the unfinished sentence
+     meteor_sort = False
+     meteor_random = False
+
+     key = b'0x01'*64
+     sample_seed_prefix = b'sample'
+     nonce = b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'
+
+     ## VALIDATE PARAMETERS
+     if mode not in ['meteor', 'arithmetic', 'huffman', 'bins', 'sample']:  # 'sample' included: it is handled below
+         raise NotImplementedError
+
+     if mode == 'bins':
+         bin2words, words2bin = get_bins(len(enc.encoder), block_size)
+
+     context_tokens = encode_context(context, enc)
+     # ------------------------------------------------------------------------------------
+     # First encode the message to uniform bits, without any context
+     # (arithmetic vs. ASCII is not essential, but it is more efficient when the message is natural language)
+     if unicode_enc:
+         ba = bitarray.bitarray()
+         ba.frombytes(message_str.encode('utf-8'))
+         message = ba.tolist()
+     else:
+         message_ctx = [enc.encoder['<|endoftext|>']]
+         message_str += '<eos>'
+         message = decode_arithmetic(model, enc, message_str, message_ctx, precision=40, topk=60000)
+
+     # Next encode the bits into cover text, using an arbitrary context
+     Hq = 0
+     if mode == 'arithmetic':
+         out, nll, kl, words_per_bit = encode_arithmetic(model, enc, message, context_tokens, temp=temp, finish_sent=finish_sent, precision=precision, topk=topk)
+     elif mode == 'huffman':
+         out, nll, kl, words_per_bit = encode_huffman(model, enc, message, context_tokens, block_size, finish_sent=finish_sent)
+     elif mode == 'bins':
+         out, nll, kl, words_per_bit = encode_block(model, enc, message, context_tokens, block_size, bin2words, words2bin, finish_sent=finish_sent)
+     elif mode == 'meteor':
+         out, nll, kl, words_per_bit = encode_meteor(model, enc, message, context_tokens, temp=temp, finish_sent=finish_sent,
+                                                     precision=precision, topk=topk, is_sort=meteor_sort, randomize_key=meteor_random, input_key=key, input_nonce=nonce)
+     elif mode == 'sample':
+         out, nll, kl, Hq = sample(model, enc, sample_tokens, context_tokens, temperature=temp, topk=topk)
+         words_per_bit = 1
+     text = enc.decode(out)
+
+     print("="*40 + " Encoding " + "="*40)
+     print(text)
+     print('ppl: %0.2f, kl: %0.3f, words/bit: %0.2f, bits/word: %0.2f, entropy: %.2f' % (math.exp(nll), kl, words_per_bit, 1/words_per_bit, Hq/0.69315))
+
+     stats = {
+         "ppl": math.exp(nll),
+         "kl": kl,
+         "wordsbit": words_per_bit,
+         "entropy": Hq/0.69315
+     }
+     return text, stats["ppl"], stats["kl"], stats["wordsbit"]
+
+ def decode_message(mode, text, context):
+     enc, model = get_model(model_name='gpt2')
+     ## PARAMETERS
+     unicode_enc = False
+     block_size = 3 # for huffman and bins
+     temp = 0.9 # for arithmetic
+     precision = 26 # for arithmetic
+     sample_tokens = 100 # for sample
+     topk = 300
+     finish_sent = True
+     meteor_sort = False
+     meteor_random = False
+
+     key = b'0x01'*64
+     sample_seed_prefix = b'sample'
+     nonce = b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'
+
+     ## VALIDATE PARAMETERS
+     if mode not in ['meteor', 'arithmetic', 'huffman', 'bins', 'sample']:
+         raise NotImplementedError
+     if mode == 'bins':
+         bin2words, words2bin = get_bins(len(enc.encoder), block_size)
+
+     context_tokens = encode_context(context, enc)
+
+     if mode != 'sample':
+         if mode == 'arithmetic':
+             message_rec = decode_arithmetic(model, enc, text, context_tokens, temp=temp, precision=precision, topk=topk)
+         elif mode == 'huffman':
+             message_rec = decode_huffman(model, enc, text, context_tokens, block_size)
+         elif mode == 'bins':
+             message_rec = decode_block(model, enc, text, context_tokens, block_size, bin2words, words2bin)
+         elif mode == 'meteor':
+             message_rec = decode_meteor(model, enc, text, context_tokens, temp=temp,
+                                         precision=precision, topk=topk, is_sort=meteor_sort, input_key=key, input_nonce=nonce)
+
+     print("="*35 + " Recovered Message " + "="*35)
+     # Finally map the message bits back to the original text
+     if unicode_enc:
+         message_rec = [bool(item) for item in message_rec]
+         ba = bitarray.bitarray(message_rec)
+         reconst = ba.tobytes().decode('utf-8', 'ignore')
+     else:
+         message_ctx = [enc.encoder['<|endoftext|>']]
+         reconst = encode_arithmetic(model, enc, message_rec, message_ctx, precision=40, topk=60000)
+         reconst = enc.decode(reconst[0])
+     print(reconst[:-5]) # strip the trailing '<eos>' marker
+     print("=" * 80)
+     return reconst[:-5]
+
+ # def main():
+ #     chosen_context = "Despite a long history of research and wide-spread applications to censorship resistant systems, practical steganographic systems capable of embedding messages into realistic communication distributions, like text, do not exist."
+ #     message_text = "generate text!"
+ #     mode = input("Please enter mode (meteor, arithmetic, huffman, bins, or sample): ")
+ #     x = encode_message(mode, message_text, chosen_context)
+ #     y = decode_message(mode, x[0], chosen_context)
+
+ # if __name__ == '__main__':
+ #     main()
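Worth spelling out, since the naming is easy to misread: encoding *starts* with `decode_arithmetic`, because running the arithmetic decoder over natural-language text yields near-uniform bits, and recovery ends with `encode_arithmetic` to invert that step. A minimal live version of the commented-out `main` above (assumes a CUDA device and the `arithmetic`/`block_baseline` modules that are not part of this upload):

from run_single import encode_message, decode_message

context = "Despite a long history of research and wide-spread applications to censorship resistant systems, practical steganographic systems capable of embedding messages into realistic communication distributions, like text, do not exist."
stego, ppl, kl_div, words_per_bit = encode_message('meteor', 'generate text!', context)
print(decode_message('meteor', stego, context))  # should recover 'generate text!'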
sample.py ADDED
@@ -0,0 +1,55 @@
+ import torch
+ import torch.nn.functional as F
+
+ from utils import limit_past, kl, entropy
+
+ def sample(model, enc, length, context, temperature=1.0, device='cuda', topk=-1):
+     assert length > 0
+
+     context = torch.tensor(context[-1022:], device=device, dtype=torch.long)
+
+     prev = context
+     output = context
+     past = None
+
+     total_log_probs = 0
+     total_entropy_ptau = 0
+     total_num = 0
+     total_kl = 0 # in bits
+
+     with torch.no_grad():
+         while total_num < length:
+             if past and past[0].shape[3] >= 1023:
+                 raise RuntimeError
+
+             logits, past = model(prev.unsqueeze(0), past=past)
+             past = limit_past(past)
+             logits[0, -1, -1] = -1e10 # endoftext can't happen
+             logits[0, -1, 628] = -1e10 # 2 newlines can't happen
+             logits, indices = logits[0, -1, :].sort(descending=True)
+             base_log_probs = F.log_softmax(logits, dim=-1)
+
+             if topk > 0:
+                 logits = logits[:topk]
+
+             logits = logits / temperature
+             log_probs = F.log_softmax(logits, dim=-1)
+             probs = torch.exp(log_probs)
+
+             # slice by len(probs), not topk, so the default topk=-1 does not drop the last entry
+             total_kl += kl(probs, log_probs, base_log_probs[:len(probs)])
+
+             selection = torch.multinomial(probs, num_samples=1).item()
+             log_prob_chosen = base_log_probs[selection]
+             total_log_probs += log_prob_chosen.item()
+
+             total_entropy_ptau += entropy(probs, log_probs)
+
+             prev = indices[selection].view(1)
+             output = torch.cat((output, prev))
+             total_num += 1
+
+     avg_NLL = -total_log_probs/total_num
+     avg_KL = total_kl/total_num
+     avg_Hq = total_entropy_ptau/total_num
+
+     return output[len(context):].tolist(), avg_NLL, avg_KL, avg_Hq
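A usage sketch for the baseline sampler (assumes a CUDA device, since `sample` defaults to device='cuda'; `get_model` downloads GPT-2 on first use):

from utils import get_model, encode_context
from sample import sample
import math

enc, model = get_model(model_name='gpt2')
ctx = encode_context("The quick brown fox", enc)
tokens, nll, kl_div, entropy_q = sample(model, enc, 40, ctx, temperature=0.9, topk=300)
print(enc.decode(tokens))            # fluent cover text carrying no payload
print('ppl: %0.2f' % math.exp(nll))  # same perplexity statistic that run_single reports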
utils.py ADDED
@@ -0,0 +1,296 @@
+ import torch
+ import numpy as np
+ import bitarray
+ import math
+
+ from pytorch_transformers import GPT2LMHeadModel, GPT2Tokenizer
+
+ # Monkey-patch the tokenizer so token ids round-trip through text exactly
+ # (no special-token filtering or cleanup), which the BPE-repair heuristics rely on.
+ def decode(self, token_ids, **kwargs):
+     filtered_tokens = self.convert_ids_to_tokens(token_ids)
+     text = self.convert_tokens_to_string(filtered_tokens)
+     return text
+ GPT2Tokenizer.decode = decode
+
+ def _convert_token_to_id(self, token):
+     return self.encoder.get(token, 0)
+ GPT2Tokenizer._convert_token_to_id = _convert_token_to_id
+
+ def limit_past(past):
+     past = list(past)
+     for i in range(len(past)):
+         past[i] = past[i][:, :, :, -1022:]
+     return past
+
+ def kl(q, logq, logp):
+     res = q*(logq-logp)/0.69315
+     res[q==0] = 0
+     return res.sum().item() # in bits
+
+ def entropy(q, logq):
+     res = q*logq/0.69315
+     res[q==0] = 0
+     return -res.sum().item() # in bits
+
+ # e.g. [0, 1, 1, 1] looks like 1110=14
+ def bits2int(bits):
+     res = 0
+     for i, bit in enumerate(bits):
+         res += bit*(2**i)
+     return res
+
+ def int2bits(inp, num_bits):
+     if num_bits == 0:
+         return []
+     strlist = ('{0:0%db}'%num_bits).format(inp)
+     return [int(strval) for strval in reversed(strlist)]
+
+ def is_sent_finish(token_idx, enc):
+     token = enc.decoder[token_idx]
+     return '.' in token or '!' in token or '?' in token
+
+ def num_same_from_beg(bits1, bits2):
+     assert len(bits1) == len(bits2)
+     for i in range(len(bits1)):
+         if bits1[i] != bits2[i]:
+             break
+
+     return i
+
+ def encode_context(raw_text, enc):
+     context_tokens = [enc.encoder['<|endoftext|>']] + enc.encode(raw_text)
+     return context_tokens
+
+ # Use gpt2-medium for the 345M-parameter model
+ # Use gpt2-large for the 774M-parameter model
+ def get_model(seed=1234, model_name='gpt2'):
+     np.random.seed(seed)
+     torch.random.manual_seed(seed)
+     torch.cuda.manual_seed(seed)
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     enc = GPT2Tokenizer.from_pretrained(model_name)
+     enc.unk_token = None
+     enc.bos_token = None
+     enc.eos_token = None
+
+     model = GPT2LMHeadModel.from_pretrained(model_name)
+     model.to(device)
+     model.eval()
+     # model.double()
+
+     return enc, model
+
+ enc32_itoc = ['\0', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '.', ',', "'", '!', ' ']
+ enc32_ctoi = {k: v for v, k in enumerate(enc32_itoc)}
+
+ def enc32(text):
+     bits = []
+     for c in text:
+         bits.extend(int2bits(enc32_ctoi[c], 5))
+     return bits
+
+ def dec32(bits):
+     text = ''
+     for i in range(0, len(bits), 5):
+         c = enc32_itoc[bits2int(bits[i:i+5])]
+         if c == '\0':
+             break
+         text += c
+     return text
+
+ # message should be a bit string
+ # encoded should be a text string
+ def expansion_ratio(message, encoded):
+     message_bits = len(message)
+     encoded_ba = bitarray.bitarray()
+     encoded_ba.frombytes(encoded.encode('utf-8'))
+     encoded_bits = len(encoded_ba.tolist())
+     return encoded_bits/message_bits
+
+ def bin_sort(l, token_indices, total, entropy, device):
+     # use the entropy for an upper bound on the number of bins we need
+     num_bins = 2**int(entropy+1)
+     bucket_size = total / num_bins
+
+     bins = [torch.empty(0, dtype=torch.long, device=device)] * num_bins
+     value_in_bins = [0] * num_bins
+     space_left_after = [total - i*bucket_size for i in range(0, num_bins)]
+
+     token_bins = [torch.empty(0, dtype=torch.long, device=device)] * num_bins
+
+     # Figuring out what the search order should be
+     step_size = num_bins/4
+     search_order = []
+     priorities = [0]*num_bins
+     priority = 0
+     search_order.append(int(num_bins/2))
+     search_order.append(0)
+     priorities[int(num_bins/2)] = 0
+     priorities[0] = 0
+     while step_size >= 1:
+         priority += 1
+         for x in range(num_bins-int(step_size), -1, -int(step_size*2)):
+             search_order.append(x)
+             priorities[x] = priority
+         step_size = step_size/2
+
+     # Adding the actual elements
+     for (item, token_index) in zip(l.tolist(), token_indices.tolist()):
+         found_single_bucket_fit = False
+         single_bucket_index = -1
+         single_bucket_value = bucket_size
+
+         found_multi_bucket_bumpless_fit = False
+         multi_bucket_bumpless_index = -1
+         multi_bucket_bumpless_value = total
+
+         found_multi_bucket_bumping_fit = False
+         multi_bucket_bumping_index = -1
+         multi_bucket_bumping_value = total
+
+         for i in search_order: # for index in search_order
+             if item > space_left_after[i]:
+                 continue
+             if value_in_bins[i] >= bucket_size:
+                 continue
+
+             # Priority of choices
+             # 1. Can I place this in an empty bucket all on its own?
+             # 2. Can I place this somewhere it doesn't have to bump anything else around?
+             #    2a. Minimize wasted space: use the smallest space (of equal priority) that accomplishes this
+             # 3. If neither (1) nor (2), put it in the space that bumps things the least.
+
+             if value_in_bins[i] + item > bucket_size: # would overflow
+
+                 space_before_next_block = bucket_size - value_in_bins[i]
+                 for j in range(i+1, len(bins)):
+                     if value_in_bins[j] > 0: # found a bucket with something in it; this is how much space we have
+                         space_before_next_block = space_before_next_block + (bucket_size - value_in_bins[i])
+                         break
+                     else: # this was an empty bucket
+                         space_before_next_block = space_before_next_block + bucket_size
+
+                 if (not found_multi_bucket_bumpless_fit) or (found_multi_bucket_bumpless_fit and priorities[i] <= priorities[multi_bucket_bumpless_index]): # potentially a match
+
+                     # If this is a valid space to put this without bumping and it is a better fit than previous spaces
+                     if space_before_next_block > item and space_before_next_block < multi_bucket_bumpless_value:
+                         # set this to be the pointer! we can fit stuff here
+                         found_multi_bucket_bumpless_fit = True
+                         multi_bucket_bumpless_index = i
+                         multi_bucket_bumpless_value = space_before_next_block
+
+                     # Find the overflow that will bump the least
+                     if item - space_before_next_block < multi_bucket_bumping_value:
+                         found_multi_bucket_bumping_fit = True
+                         multi_bucket_bumping_index = i
+                         multi_bucket_bumping_value = item - space_before_next_block
+
+             if value_in_bins[i] + item <= bucket_size: # would fit
+                 if single_bucket_value > value_in_bins[i]:
+                     found_single_bucket_fit = True
+                     single_bucket_value = value_in_bins[i]
+                     single_bucket_index = i
+
+         if single_bucket_index == multi_bucket_bumpless_index == multi_bucket_bumping_index == -1:
+             bins[0] = torch.cat((torch.tensor([item], device=device), bins[0]), 0)
+             token_bins[0] = torch.cat((torch.tensor([token_index], device=device), token_bins[0]), 0)
+             continue
+
+         if found_single_bucket_fit:
+             # We found somewhere we can actually fit!
+             bins[single_bucket_index] = torch.cat((bins[single_bucket_index], torch.tensor([item], device=device)), 0)
+             token_bins[single_bucket_index] = torch.cat((token_bins[single_bucket_index], torch.tensor([token_index], device=device)), 0)
+             value_in_bins[single_bucket_index] += item
+             for i in range(0, single_bucket_index+1):
+                 space_left_after[i] -= item
+
+         elif found_multi_bucket_bumpless_fit:
+             # Found somewhere we can put this without upsetting the force
+             part_in_bucket = bucket_size - value_in_bins[multi_bucket_bumpless_index]
+             part_overflow = item - part_in_bucket
+             bins[multi_bucket_bumpless_index] = torch.cat((bins[multi_bucket_bumpless_index], torch.tensor([item], device=device)), 0)
+             token_bins[multi_bucket_bumpless_index] = torch.cat((token_bins[multi_bucket_bumpless_index], torch.tensor([token_index], device=device)), 0)
+             value_in_bins[multi_bucket_bumpless_index] = bucket_size
+
+             # Fill this bucket and continue overflowing
+             j = multi_bucket_bumpless_index + 1
+             for i in range(0, j):
+                 space_left_after[i] -= item
+
+             while part_overflow > 0:
+                 new_part_overflow = (value_in_bins[j] + part_overflow) - bucket_size
+                 value_in_bins[j] = min(bucket_size, part_overflow+value_in_bins[j]) # mark the bucket as filled
+                 space_left_after[j] -= part_overflow
+                 part_overflow = new_part_overflow
+                 j += 1
+
+         else:
+             part_in_bucket = bucket_size - value_in_bins[multi_bucket_bumping_index]
+             part_overflow = item - part_in_bucket
+             bins[multi_bucket_bumping_index] = torch.cat((bins[multi_bucket_bumping_index], torch.tensor([item], device=device)), 0)
+             token_bins[multi_bucket_bumping_index] = torch.cat((token_bins[multi_bucket_bumping_index], torch.tensor([token_index], device=device)), 0)
+             value_in_bins[multi_bucket_bumping_index] = bucket_size
+
+             # Fill this bucket and continue overflowing
+             j = multi_bucket_bumping_index + 1
+             for i in range(0, j):
+                 space_left_after[i] -= item
+             while part_overflow > 0:
+                 new_part_overflow = (value_in_bins[j] + part_overflow) - bucket_size
+                 value_in_bins[j] = min(bucket_size, part_overflow+value_in_bins[j]) # mark the bucket as filled
+                 space_left_after[j] -= part_overflow
+                 part_overflow = new_part_overflow
+                 j += 1
+
+     sorted_tensor = torch.cat(bins, 0)
+     sorted_tokens = torch.cat(token_bins, 0)
+
+     return sorted_tensor, sorted_tokens
+
+ def compute_ev(t, precision):
+     expected_bits = []
+     cum_probs = t.cumsum(0)
+
+     for selection in range(0, len(cum_probs)):
+
+         # Calculate new range as ints
+         new_int_bottom = cum_probs[selection-1] if selection > 0 else 0
+         new_int_top = cum_probs[selection]
+
+         # Convert range to bits
+         new_int_bottom_bits_inc = list(reversed(int2bits(new_int_bottom, precision)))
+         new_int_top_bits_inc = list(reversed(int2bits(new_int_top-1, precision))) # -1 here because upper bound is exclusive
+
+         # Consume most significant bits which are now fixed and update interval
+         num_bits_encoded = num_same_from_beg(new_int_bottom_bits_inc, new_int_top_bits_inc)
+         expected_bits.append(t[selection] * num_bits_encoded)
+
+     return float(sum(expected_bits).item())/(2**precision)
+
+ def visualize_bins(values_in_bins, bucket_size):
+     out_str = "["
+     for b in values_in_bins:
+         out_str = out_str + " " + str(round(100*b/bucket_size, 2)) + " |"
+     out_str = out_str + "]"
+     print(out_str)
+
+ def visualize_distribution(l):
+     total = sum(l)
+     out_str = "["
+     for b in l:
+         out_str = out_str + " " + str(round(100*b/total, 2)) + " |"
+     out_str = out_str + "]"
+     print(out_str)
+
+ def compute_entropy(lists):
+     total = sum(lists)
+     entropy = -1*sum([(x/total) * math.log2(x/total) for x in lists])
+     return entropy
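Round-trip checks for the bit helpers above (self-contained apart from utils' own imports; no GPU needed):

from utils import bits2int, int2bits, enc32, dec32

assert bits2int([0, 1, 1, 1]) == 14                    # LSB-first, matching the comment above
assert int2bits(14, 4) == [0, 1, 1, 1]                 # fixed-width inverse of bits2int
assert dec32(enc32('hello world!')) == 'hello world!'  # 5 bits per character, '\0'-terminated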