Commit a0ed808 · initial commit · 0 parent(s)
Files changed:
- .gitattributes        +31 -0
- README.md             +13 -0
- app.py                +25 -0
- bot/interface.py      +48 -0
- bot/simctgdialogue.py +177 -0
- bot/utlis.py          +174 -0
- requirements.txt      +18 -0
.gitattributes
ADDED
@@ -0,0 +1,31 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zstandard filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
README.md
ADDED
@@ -0,0 +1,13 @@
+---
+title: Winnie
+emoji: 🔥
+colorFrom: indigo
+colorTo: indigo
+sdk: gradio
+sdk_version: 3.1.1
+app_file: app.py
+pinned: false
+license: mit
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,25 @@
+
+import gradio as gr
+
+from bot.interface import Chatbot
+
+bot = Chatbot()
+
+def greet(input_txt, history = []):
+    global bot
+
+    if bot is None:
+        bot = Chatbot()
+
+    history.append(input_txt)
+    response = bot.chat(history)
+    history.append(response)
+
+    return response, history
+
+if __name__ == "__main__":
+    gr.Interface(fn=greet,
+                 # title="Chat in Chinese with Vicky, whose brain is a bit fried",
+                 inputs=["text", "state"],
+                 outputs=["text", "state"]
+                 ).launch()
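The "text"/"state" pairing means Gradio passes the previous history list back into greet on every call and stores the updated list it returns, so the bot always sees the whole conversation so far. A minimal sketch of the same flow without the UI, assuming the model checkpoint downloads successfully (the utterances are only placeholders):

    history = []
    reply, history = greet("你好", history)          # history is now [user turn, bot reply]
    reply, history = greet("你叫什么名字", history)  # history keeps alternating user/bot turns
    print(reply)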
bot/interface.py
ADDED
@@ -0,0 +1,48 @@
+
+from random import choice
+from random import randint
+from random import uniform
+
+from bot.simctgdialogue import SimCTGDialogue
+
+class Chatbot():
+    def __init__(self):
+        self.model = SimCTGDialogue("cambridgeltl/simctg_lccc_dialogue", [])
+        self.tokenizer = self.model.tokenizer
+        self.model.eval()
+
+    def __contrastive_search(self, context_list):
+        print("__contrastive_search")
+        print(context_list)
+        beam_width, alpha, decoding_len = randint(1, 8), uniform(0.10, 0.40), 64
+        return self.model.contrastive_search(context_list, beam_width, alpha, decoding_len)
+
+    def __diverse_contrastive_search(self, context_list):
+        print("__diverse_contrastive_search")
+        print(context_list)
+        sample_step, nucleus_p = 1, uniform(0.10, 0.40)
+        beam_width, alpha, decoding_len = randint(1, 5), uniform(0.10, 0.40), 64
+        return self.model.diverse_contrastive_search(context_list, sample_step, nucleus_p, beam_width, alpha, decoding_len)
+
+    def __greedy_search(self, context_list):
+        print("__greedy_search")
+        print(context_list)
+        decoding_len = 64
+        return self.model.greedy_search(context_list, decoding_len)
+
+    def __beam_search(self, context_list):
+        print("__beam_search")
+        print(context_list)
+        beam_width, decoding_len = randint(1, 9), 64
+        return self.model.beam_search(context_list, beam_width, decoding_len)
+
+    def chat(self, prefix = []):
+        methods_for_sort_dialogue = [self.__contrastive_search, self.__greedy_search]
+        methods_for_long_dialogue = [self.__beam_search, self.__diverse_contrastive_search, self.__greedy_search, self.__contrastive_search]
+
+        if ( len(prefix) < 4 ):
+            response = choice(methods_for_sort_dialogue)(prefix)
+        else:
+            response = choice(methods_for_long_dialogue)(prefix)
+
+        return response
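Chatbot.chat picks a decoding strategy at random on each turn: dialogues with fewer than four utterances draw from contrastive or greedy search, longer ones may also use beam search or diverse contrastive search, and beam_width/alpha are re-drawn per call. A small usage sketch under the same assumptions as app.py (the turns are illustrative only):

    from bot.interface import Chatbot

    bot = Chatbot()
    history = ["你好"]
    print(bot.chat(history))              # len(history) < 4: contrastive or greedy search
    history += ["你好", "在干嘛", "没干嘛"]
    print(bot.chat(history))              # len(history) >= 4: any of the four strategies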
bot/simctgdialogue.py
ADDED
@@ -0,0 +1,177 @@
+
+import torch
+
+from torch import nn
+
+class SimCTGDialogue(nn.Module):
+    def __init__(self, model_name, additional_special_tokens):
+        super(SimCTGDialogue, self).__init__()
+        from transformers import AutoTokenizer, GPT2LMHeadModel
+        eos_token = '[SEP]'
+        pad_token = '[PAD]'
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name, additional_special_tokens=additional_special_tokens)
+        self.vocab_size = len(self.tokenizer)
+        self.model = GPT2LMHeadModel.from_pretrained(model_name)
+        self.embed_dim = self.model.config.hidden_size
+        if pad_token in self.tokenizer.vocab:
+            print ('PAD token exists.')
+        else:
+            print ('Add PAD token to the tokenizer.')
+            print ('Original vocabulary size is {}'.format(len(self.tokenizer)))
+            self.tokenizer.add_tokens([pad_token])
+            print ('Vocabulary size after extension is {}'.format(len(self.tokenizer)))
+            assert len(self.tokenizer.convert_tokens_to_ids([pad_token])) == 1
+            self.model.resize_token_embeddings(len(self.tokenizer))
+        self.pad_token_id = self.tokenizer.convert_tokens_to_ids([pad_token])[0]
+        self.vocab_size = len(self.tokenizer)
+        if 'e' in eos_token:
+            self.eos_token = self.tokenizer.eos_token
+        else:
+            self.eos_token = eos_token
+        print (self.eos_token)
+
+    def parse_dialogue_context(self, context_list, cuda_available=False, device=0):
+        # context_list: a list of utterances in the dialogue session
+        uttr_num = len(context_list)
+        context_text = self.eos_token.join(context_list).strip(self.eos_token) + self.eos_token
+        #print (context_text)
+        tokens = self.tokenizer.tokenize(context_text)
+        input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
+        input_ids = input_ids
+        input_ids = torch.LongTensor(input_ids).view(1,-1)
+        if cuda_available:
+            input_ids = input_ids.cuda(device)
+        return input_ids, uttr_num
+
+    def extract_response(self, output_ids, uttr_num):
+        output_text = self.tokenizer.decode(output_ids)
+        # extract response
+        item_list = output_text.split(self.eos_token)
+        response = item_list[uttr_num].strip()
+        if self.eos_token == '<|endoftext|>': # English GPT
+            response = ' '.join(response.split())
+        else:
+            response = ''.join(response.split())
+        return response
+
+    def contrastive_search(self, context_list, beam_width, alpha, decoding_len,
+        cuda_available=False, device=0):
+        input_ids, uttr_num = self.parse_dialogue_context(context_list,
+            cuda_available=cuda_available, device=device)
+        output = self.fast_contrastive_generation(input_ids, beam_width, alpha, decoding_len)
+        return self.extract_response(output, uttr_num)
+
+    def diverse_contrastive_search(self, context_list, sample_step, nucleus_p,
+        beam_width, alpha, decoding_len, cuda_available=False, device=0):
+        input_ids, uttr_num = self.parse_dialogue_context(context_list,
+            cuda_available=cuda_available, device=device)
+        output = self.diverse_contrastive_generation(input_ids, sample_step, nucleus_p,
+            beam_width, alpha, decoding_len)
+        return self.extract_response(output, uttr_num)
+
+    def greedy_search(self, context_list, decoding_len, cuda_available=False, device=0):
+        input_ids, uttr_num = self.parse_dialogue_context(context_list,
+            cuda_available=cuda_available, device=device)
+        output = self.greedy_generation(input_ids, decoding_len)
+        return self.extract_response(output, uttr_num)
+
+    def beam_search(self, context_list, beam_width, decoding_len,
+        cuda_available=False, device=0):
+        input_ids, uttr_num = self.parse_dialogue_context(context_list,
+            cuda_available=cuda_available, device=device)
+        output = self.beam_generation(input_ids, beam_width, decoding_len)
+        return self.extract_response(output, uttr_num)
+
+    def nucleus_sampling(self, context_list, nucleus_p, decoding_len,
+        cuda_available=False, device=0):
+        input_ids, uttr_num = self.parse_dialogue_context(context_list,
+            cuda_available=cuda_available, device=device)
+        output = self.nucleus_generation(input_ids, nucleus_p, decoding_len)
+        return self.extract_response(output, uttr_num)
+
+    def fast_contrastive_generation(self, input_ids, beam_width, alpha, decoding_len):
+        '''
+           input_ids: prefix input; 1 x prefix_len
+           decoding_len: how many tokens to generate
+           beam_width: size of candidate pool during decoding
+           alpha: regulates importance of model confidence and degeneration penalty
+        '''
+        self.model.eval()
+        from bot.utlis import ContrastiveDecodingOneStepFast
+        # sanity check
+        assert alpha >= 0. and alpha <= 1.0
+
+        # fast mode
+        batch_size, seqlen = input_ids.size()
+        #generated = [[] for _ in range(batch_size)]
+        generated = [item for item in input_ids.tolist()]
+        past_key_values = None
+        last_hidden_states = None
+        logits = None
+        for step in range(decoding_len):
+            input_ids, past_key_values, last_hidden_states, logits = ContrastiveDecodingOneStepFast(
+                self.model,
+                input_ids,
+                beam_width,
+                alpha,
+                past_key_values,
+                last_hidden_states,
+                self.tokenizer,
+                logits,
+                first_step=step == 0,
+            )
+            tokens = input_ids.squeeze(dim=-1).tolist()
+            for idx, t in enumerate(tokens):
+                generated[idx].append(t)
+        return generated[0]
+
+    def diverse_contrastive_generation(self, input_ids, sample_step, nucleus_p, beam_width, alpha, decoding_len):
+        '''
+            sample_step:
+                number of steps to decode with nucleus sampling,
+                for the remaining steps we use contrastive search
+            decoding_len:
+                the total number of generated tokens
+            beam_width:
+                size of candidate pool during decoding
+            alpha:
+                regulates importance of model confidence and degeneration penalty
+
+        '''
+        contrastive_step = decoding_len - sample_step
+        _, prefix_len = input_ids.size()
+        # first do sample
+        input_ids = self.model.generate(
+                            input_ids,
+                            do_sample=True,
+                            max_length=prefix_len+sample_step,
+                            top_p=nucleus_p,
+                            top_k=0)
+        # then do contrastive search
+        output = self.fast_contrastive_generation(input_ids, beam_width, alpha, contrastive_step)
+        return output
+
+    def greedy_generation(self, input_ids, decoding_len):
+        _, prefix_len = input_ids.size()
+        output = self.model.generate(
+                            input_ids,
+                            max_length=prefix_len+decoding_len)
+        return output[0]
+
+    def beam_generation(self, input_ids, beam_width, decoding_len):
+        _, prefix_len = input_ids.size()
+        output = self.model.generate(
+                            input_ids,
+                            max_length=prefix_len+decoding_len,
+                            num_beams=beam_width)
+        return output[0]
+
+    def nucleus_generation(self, input_ids, nucleus_p, decoding_len):
+        _, prefix_len = input_ids.size()
+        output = self.model.generate(
+                            input_ids,
+                            do_sample=True,
+                            max_length=prefix_len+decoding_len,
+                            top_p=nucleus_p,
+                            top_k=0)
+        return output[0]
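SimCTGDialogue joins the context utterances with the model's EOS token ([SEP] for this LCCC checkpoint), generates a continuation, and returns the utterance at position uttr_num as the reply. A minimal direct-use sketch mirroring the defaults in bot/interface.py (the beam_width/alpha values are illustrative, not tuned):

    from bot.simctgdialogue import SimCTGDialogue

    model = SimCTGDialogue("cambridgeltl/simctg_lccc_dialogue", [])
    model.eval()
    context = ["你好", "你好，很高兴认识你"]
    # candidate pool of 5, degeneration penalty weight 0.3, up to 64 new tokens
    print(model.contrastive_search(context, beam_width=5, alpha=0.3, decoding_len=64))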
bot/utlis.py
ADDED
@@ -0,0 +1,174 @@
+
+import torch
+import random
+import torch.nn.functional as F
+
+def ranking(context_hidden, next_hidden, next_top_k_ids, next_top_k_probs, alpha):
+    '''
+        context_hidden: beam_width x context_len x embed_dim
+        next_hidden: beam_width x 1 x embed_dim
+        next_top_k_ids: beam_width x 1
+    '''
+    beam_width, context_len, embed_dim = context_hidden.size()
+    assert next_hidden.size() == torch.Size([beam_width, 1, embed_dim])
+    norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True)
+    norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True)
+    cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1,2)).squeeze(-1)
+    assert cosine_matrix.size() == torch.Size([beam_width, context_len])
+    scores, _ = torch.max(cosine_matrix, dim = -1)
+    assert scores.size() == torch.Size([beam_width])
+    next_top_k_probs = next_top_k_probs.view(-1)
+    scores = (1.0 - alpha) * next_top_k_probs - alpha * scores
+    _, selected_idx = torch.topk(scores, k = 1)
+    assert selected_idx.size() == torch.Size([1])
+    selected_idx = selected_idx.unsqueeze(0)
+    assert selected_idx.size() == torch.Size([1,1])
+    next_id = torch.gather(next_top_k_ids, dim = 0, index=selected_idx)
+    assert next_id.size() == torch.Size([1,1])
+    return next_id
+
+def ContrastiveDecodingOneStep(model, input_ids, beam_width, alpha):
+    '''
+        model: the generation model, e.g., gpt2
+        input_ids: 1 x seqlen
+    '''
+    prev_hidden_states, logits = model.compute_logits_and_hidden_states(input_ids)
+    _, seqlen, embed_dim = prev_hidden_states.size()
+    _, _, vocab_size = logits.size()
+    p = random.uniform(0, 1)
+
+    logit_for_next_step = logits[:,-1,:]
+    assert logit_for_next_step.size() == torch.Size([1, vocab_size])
+
+    next_probs = F.softmax(logit_for_next_step, dim = -1)
+    assert next_probs.size() == logit_for_next_step.size()
+
+    _, top_k_ids = torch.topk(logit_for_next_step, dim = -1, k = beam_width)
+    assert top_k_ids.size() == torch.Size([1, beam_width])
+
+    top_k_probs = torch.gather(next_probs, dim = 1, index=top_k_ids)
+
+    assert top_k_probs.size() == top_k_ids.size()
+    # compute new hidden
+    expanded_context = [input_ids for _ in range(beam_width)]
+    expanded_context = torch.cat(expanded_context, dim = 0)
+    assert expanded_context.size() == torch.Size([beam_width, seqlen])
+    top_k_ids = top_k_ids.view(beam_width, 1)
+    next_input_ids = torch.cat([expanded_context, top_k_ids], dim = -1)
+    assert next_input_ids.size() == torch.Size([beam_width, seqlen+1])
+    new_hidden_states, next_logits = model.compute_logits_and_hidden_states(next_input_ids)
+    assert new_hidden_states.size() == torch.Size([beam_width, seqlen+1, embed_dim])
+    context_hidden = new_hidden_states[:,:seqlen,:]
+    assert context_hidden.size() == torch.Size([beam_width, seqlen, embed_dim])
+    next_hidden = new_hidden_states[:,seqlen:,:]
+    assert next_hidden.size() == torch.Size([beam_width, 1, embed_dim])
+
+    next_id = ranking(context_hidden, next_hidden, top_k_ids, top_k_probs, alpha)
+
+    next_input_ids = torch.cat([input_ids, next_id], dim = -1)
+    assert next_input_ids.size() == torch.Size([1, seqlen+1])
+    return next_input_ids
+
+# ========== batch version ========= #
+def ranking_fast(context_hidden, next_hidden, next_top_k_probs, alpha, beam_width):
+    '''
+        context_hidden: bsz*beam x seqlen x embed_dim
+        next_hidden: bsz*beam x 1 x embed_dim
+        next_top_k_probs: bsz x beam
+    '''
+    _, context_len, embed_dim = context_hidden.size()
+    norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True)
+    norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True)
+    cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1,2)).squeeze(-1)    # [B*K, S]
+    scores, _ = torch.max(cosine_matrix, dim=-1)    # [B*K]
+    next_top_k_probs = next_top_k_probs.view(-1)    # [B*K]
+    scores = (1.0 - alpha) * next_top_k_probs - alpha * scores
+    scores = torch.stack(torch.split(scores, beam_width))    # [B, K]
+    selected_idx = scores.max(dim=-1)[1]    # [B]
+    return selected_idx
+
+def ContrastiveDecodingOneStepFast(
+    model,
+    ids,
+    beam_width,
+    alpha,
+    past_key_values,
+    last_hidden_states,
+    vocab,
+    logit_for_next_step,
+    first_step=False,
+    ):
+    # input_ids: [B, S]
+    if first_step:
+        output = model(
+            input_ids=ids,
+            past_key_values=past_key_values,
+            use_cache=True,
+            output_hidden_states=True
+        )
+        past_key_values = output.past_key_values
+        last_hidden_states = output.hidden_states[-1]    # [B, S, E]
+        logit_for_next_step = output.logits[:, -1, :]    # [B, V]
+    bsz, seqlen, embed_dim = last_hidden_states.size()
+    p = random.uniform(0, 1)
+
+    next_probs = F.softmax(logit_for_next_step, dim=-1)
+    _, top_k_ids = torch.topk(logit_for_next_step, dim=-1, k=beam_width)    # [B, K]
+    top_k_probs = torch.gather(next_probs, dim=1, index=top_k_ids)    # [B, K]
+    # compute new hidden
+    past_key_values = enlarge_past_key_values(past_key_values, beam_width)
+    output = model(
+        input_ids=top_k_ids.view(-1, 1),
+        attention_mask=torch.ones_like(top_k_ids.view(-1, 1)),
+        past_key_values=past_key_values,
+        output_hidden_states=True,
+        use_cache=True,
+    )
+    past_key_values = output.past_key_values
+    logits = output.logits[:, -1, :]    # [B*K, V]
+    next_hidden = output.hidden_states[-1]    # [B*K, 1, E]
+    context_hidden = last_hidden_states.unsqueeze(1).expand(-1, beam_width, -1, -1).reshape(bsz*beam_width, seqlen, embed_dim)    # [B*K, S, E]
+
+    selected_idx = ranking_fast(
+        context_hidden,
+        next_hidden,
+        top_k_probs,    # [B, K]
+        alpha,
+        beam_width,
+    )    # [B]
+    # prepare for the next step
+    next_id = top_k_ids[range(len(top_k_ids)), selected_idx].unsqueeze(-1)    # [B, 1]
+    next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), beam_width))    # [B, K, E]
+    next_hidden = next_hidden[range(bsz), selected_idx, :]    # [B, E]
+    last_hidden_states = torch.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1)    # [B, S, E]
+    past_key_values = select_past_key_values(past_key_values, beam_width, selected_idx)
+    logits = torch.stack(torch.split(logits, beam_width))[range(bsz), selected_idx, :]    # [B, V]
+    # next_id: [B, 1]
+    return next_id, past_key_values, last_hidden_states, logits
+
+def enlarge_past_key_values(past_key_values, beam_width):
+    # from [B, num_head, seq_len, esz] to [B*K, num_head, seq_len, esz]
+    new_key_values = []
+    for layer in past_key_values:
+        items = []
+        for item in layer:
+            # item is the key and value matrix
+            bsz, num_head, seq_len, esz = item.size()
+            item = item.unsqueeze(1).expand(-1, beam_width, -1, -1, -1).reshape(bsz*beam_width, num_head, seq_len, esz)    # [bsz*beam, num_head, seq_len, esz]
+            items.append(item)
+        new_key_values.append(items)
+    return new_key_values
+
+def select_past_key_values(past_key_values, beam_width, selected_idx):
+    '''select_idx: [B]'''
+    new_key_values = []
+    for layer in past_key_values:
+        items = []
+        for item in layer:
+            bsz_and_beam, num_head, seq_len, esz = item.size()
+            bsz = int(bsz_and_beam//beam_width)
+            item = torch.stack(torch.split(item, beam_width, dim=0))    # [B, K, num_head, seq_len, esz]
+            item = item[range(bsz), selected_idx, :, :, :]    # [B, num_head, seq_len, esz]
+            items.append(item)
+        new_key_values.append(items)
+    return new_key_values
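ranking_fast scores each candidate token as (1 - alpha) * model confidence - alpha * max cosine similarity to the context hidden states, so alpha trades fluency against a degeneration penalty. A toy shape-only sketch with random tensors (the numbers are meaningless; only the shapes matter):

    import torch
    from bot.utlis import ranking_fast

    beam_width, seqlen, embed_dim = 3, 5, 8
    context_hidden = torch.randn(beam_width, seqlen, embed_dim)    # [B*K, S, E] with B = 1
    next_hidden = torch.randn(beam_width, 1, embed_dim)            # [B*K, 1, E]
    top_k_probs = torch.tensor([[0.5, 0.3, 0.2]])                  # [B, K] candidate probabilities
    idx = ranking_fast(context_hidden, next_hidden, top_k_probs, alpha=0.3, beam_width=beam_width)
    print(idx)    # index of the selected candidate for each batch item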
requirements.txt
ADDED
@@ -0,0 +1,18 @@
+absl-py
+pytest
+sacrebleu==1.4.10
+six
+wheel
+progressbar
+sklearn
+torch==1.6.0
+torchvision==0.7.0
+transformers==4.7.0
+pyyaml
+nltk
+sentencepiece
+spacy
+gdown
+seaborn
+matplotlib
+pandas