lewiswu1209 committed
Commit a0ed808 · 0 Parent(s)

initial commit

Files changed (7)
  1. .gitattributes +31 -0
  2. README.md +13 -0
  3. app.py +25 -0
  4. bot/interface.py +48 -0
  5. bot/simctgdialogue.py +177 -0
  6. bot/utlis.py +174 -0
  7. requirements.txt +18 -0
.gitattributes ADDED
@@ -0,0 +1,31 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ title: Winnie
+ emoji: 🔥
+ colorFrom: indigo
+ colorTo: indigo
+ sdk: gradio
+ sdk_version: 3.1.1
+ app_file: app.py
+ pinned: false
+ license: mit
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,25 @@
+
+ import gradio as gr
+
+ from bot.interface import Chatbot
+
+ bot = Chatbot()
+
+ def greet(input_txt, history=None):
+     global bot
+
+     # avoid a shared mutable default for per-session chat history
+     if history is None:
+         history = []
+
+     if bot is None:
+         bot = Chatbot()
+
+     history.append(input_txt)
+     response = bot.chat(history)
+     history.append(response)
+
+     return response, history
+
+ if __name__ == "__main__":
+     gr.Interface(fn=greet,
+                  # title="Chat in Chinese with Vicky, whose brain is fried",
+                  inputs=["text", "state"],
+                  outputs=["text", "state"]
+                  ).launch()
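
For reference, a minimal sketch of how `greet` threads Gradio's `state` value (a hypothetical local test, not part of this commit; note that importing `app` instantiates `Chatbot` and downloads the model):

```python
# Each call appends the user turn and then the bot reply, so history
# alternates [user, bot, user, bot, ...] across turns.
from app import greet

history = []
reply, history = greet("你好", history)
print(reply, history)    # history now holds ["你好", reply]
reply, history = greet("你叫什么名字?", history)
print(len(history))      # 4 utterances after two turns
```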
bot/interface.py ADDED
@@ -0,0 +1,48 @@
+
+ from random import choice
+ from random import randint
+ from random import uniform
+
+ from bot.simctgdialogue import SimCTGDialogue
+
+ class Chatbot():
+     def __init__(self):
+         self.model = SimCTGDialogue("cambridgeltl/simctg_lccc_dialogue", [])
+         self.tokenizer = self.model.tokenizer
+         self.model.eval()
+
+     def __contrastive_search(self, context_list):
+         print("__contrastive_search")
+         print(context_list)
+         beam_width, alpha, decoding_len = randint(1, 8), uniform(0.10, 0.40), 64
+         return self.model.contrastive_search(context_list, beam_width, alpha, decoding_len)
+
+     def __diverse_contrastive_search(self, context_list):
+         print("__diverse_contrastive_search")
+         print(context_list)
+         sample_step, nucleus_p = 1, uniform(0.10, 0.40)
+         beam_width, alpha, decoding_len = randint(1, 5), uniform(0.10, 0.40), 64
+         return self.model.diverse_contrastive_search(context_list, sample_step, nucleus_p, beam_width, alpha, decoding_len)
+
+     def __greedy_search(self, context_list):
+         print("__greedy_search")
+         print(context_list)
+         decoding_len = 64
+         return self.model.greedy_search(context_list, decoding_len)
+
+     def __beam_search(self, context_list):
+         print("__beam_search")
+         print(context_list)
+         beam_width, decoding_len = randint(1, 9), 64
+         return self.model.beam_search(context_list, beam_width, decoding_len)
+
+     def chat(self, prefix=[]):
+         methods_for_short_dialogue = [self.__contrastive_search, self.__greedy_search]
+         methods_for_long_dialogue = [self.__beam_search, self.__diverse_contrastive_search, self.__greedy_search, self.__contrastive_search]
+
+         # short dialogues (fewer than 4 utterances) use the more conservative decoders
+         if len(prefix) < 4:
+             response = choice(methods_for_short_dialogue)(prefix)
+         else:
+             response = choice(methods_for_long_dialogue)(prefix)
+
+         return response
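
A usage sketch for the class above (assuming the `cambridgeltl/simctg_lccc_dialogue` checkpoint can be fetched from the Hugging Face Hub; the decoding hyperparameters inside `chat` are re-randomized on every call):

```python
from bot.interface import Chatbot

bot = Chatbot()
context = ["你好"]                        # dialogue history, oldest utterance first
print(bot.chat(context))                  # <4 utterances: contrastive or greedy search
context += ["你好,很高兴认识你", "你叫什么名字?", "我叫小明"]
print(bot.chat(context))                  # >=4 utterances: any of the four decoders
```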
bot/simctgdialogue.py ADDED
@@ -0,0 +1,177 @@
+
+ import torch
+
+ from torch import nn
+
+ class SimCTGDialogue(nn.Module):
+     def __init__(self, model_name, additional_special_tokens):
+         super(SimCTGDialogue, self).__init__()
+         from transformers import AutoTokenizer, GPT2LMHeadModel
+         eos_token = '[SEP]'
+         pad_token = '[PAD]'
+         self.tokenizer = AutoTokenizer.from_pretrained(model_name, additional_special_tokens=additional_special_tokens)
+         self.vocab_size = len(self.tokenizer)
+         self.model = GPT2LMHeadModel.from_pretrained(model_name)
+         self.embed_dim = self.model.config.hidden_size
+         if pad_token in self.tokenizer.vocab:
+             print ('PAD token exists.')
+         else:
+             print ('Add PAD token to the tokenizer.')
+             print ('Original vocabulary size is {}'.format(len(self.tokenizer)))
+             self.tokenizer.add_tokens([pad_token])
+             print ('Vocabulary size after extension is {}'.format(len(self.tokenizer)))
+             assert len(self.tokenizer.convert_tokens_to_ids([pad_token])) == 1
+             self.model.resize_token_embeddings(len(self.tokenizer))
+         self.pad_token_id = self.tokenizer.convert_tokens_to_ids([pad_token])[0]
+         self.vocab_size = len(self.tokenizer)
+         # '<|endoftext|>' (English GPT-2) contains an 'e'; '[SEP]' (Chinese checkpoints) does not
+         if 'e' in eos_token:
+             self.eos_token = self.tokenizer.eos_token
+         else:
+             self.eos_token = eos_token
+         print (self.eos_token)
+
+     def parse_dialogue_context(self, context_list, cuda_available=False, device=0):
+         # context_list: a list of utterances in the dialogue session
+         uttr_num = len(context_list)
+         context_text = self.eos_token.join(context_list).strip(self.eos_token) + self.eos_token
+         tokens = self.tokenizer.tokenize(context_text)
+         input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
+         input_ids = torch.LongTensor(input_ids).view(1,-1)
+         if cuda_available:
+             input_ids = input_ids.cuda(device)
+         return input_ids, uttr_num
+
+     def extract_response(self, output_ids, uttr_num):
+         output_text = self.tokenizer.decode(output_ids)
+         # extract response
+         item_list = output_text.split(self.eos_token)
+         response = item_list[uttr_num].strip()
+         if self.eos_token == '<|endoftext|>': # English GPT
+             response = ' '.join(response.split())
+         else:
+             response = ''.join(response.split())
+         return response
+
+     def contrastive_search(self, context_list, beam_width, alpha, decoding_len,
+         cuda_available=False, device=0):
+         input_ids, uttr_num = self.parse_dialogue_context(context_list,
+             cuda_available=cuda_available, device=device)
+         output = self.fast_contrastive_generation(input_ids, beam_width, alpha, decoding_len)
+         return self.extract_response(output, uttr_num)
+
+     def diverse_contrastive_search(self, context_list, sample_step, nucleus_p,
+         beam_width, alpha, decoding_len, cuda_available=False, device=0):
+         input_ids, uttr_num = self.parse_dialogue_context(context_list,
+             cuda_available=cuda_available, device=device)
+         output = self.diverse_contrastive_generation(input_ids, sample_step, nucleus_p,
+             beam_width, alpha, decoding_len)
+         return self.extract_response(output, uttr_num)
+
+     def greedy_search(self, context_list, decoding_len, cuda_available=False, device=0):
+         input_ids, uttr_num = self.parse_dialogue_context(context_list,
+             cuda_available=cuda_available, device=device)
+         output = self.greedy_generation(input_ids, decoding_len)
+         return self.extract_response(output, uttr_num)
+
+     def beam_search(self, context_list, beam_width, decoding_len,
+         cuda_available=False, device=0):
+         input_ids, uttr_num = self.parse_dialogue_context(context_list,
+             cuda_available=cuda_available, device=device)
+         output = self.beam_generation(input_ids, beam_width, decoding_len)
+         return self.extract_response(output, uttr_num)
+
+     def nucleus_sampling(self, context_list, nucleus_p, decoding_len,
+         cuda_available=False, device=0):
+         input_ids, uttr_num = self.parse_dialogue_context(context_list,
+             cuda_available=cuda_available, device=device)
+         output = self.nucleus_generation(input_ids, nucleus_p, decoding_len)
+         return self.extract_response(output, uttr_num)
+
+     def fast_contrastive_generation(self, input_ids, beam_width, alpha, decoding_len):
+         '''
+            input_ids: prefix input; 1 x prefix_len
+            decoding_len: how many tokens to generate
+            beam_width: size of candidate pool during decoding
+            alpha: regulates importance of model confidence and degeneration penalty
+         '''
+         self.model.eval()
+         from bot.utlis import ContrastiveDecodingOneStepFast
+         # sanity check
+         assert alpha >= 0. and alpha <= 1.0
+
+         # fast mode
+         batch_size, seqlen = input_ids.size()
+         generated = input_ids.tolist()
+         past_key_values = None
+         last_hidden_states = None
+         logits = None
+         for step in range(decoding_len):
+             input_ids, past_key_values, last_hidden_states, logits = ContrastiveDecodingOneStepFast(
+                 self.model,
+                 input_ids,
+                 beam_width,
+                 alpha,
+                 past_key_values,
+                 last_hidden_states,
+                 self.tokenizer,
+                 logits,
+                 first_step=step == 0,
+             )
+             tokens = input_ids.squeeze(dim=-1).tolist()
+             for idx, t in enumerate(tokens):
+                 generated[idx].append(t)
+         return generated[0]
+
+     def diverse_contrastive_generation(self, input_ids, sample_step, nucleus_p, beam_width, alpha, decoding_len):
+         '''
+             sample_step:
+                 number of steps decoded with nucleus sampling;
+                 the remaining steps use contrastive search
+             decoding_len:
+                 the total number of generated tokens
+             beam_width:
+                 size of candidate pool during decoding
+             alpha:
+                 regulates importance of model confidence and degeneration penalty
+         '''
+         contrastive_step = decoding_len - sample_step
+         _, prefix_len = input_ids.size()
+         # first do sample
+         input_ids = self.model.generate(
+             input_ids,
+             do_sample=True,
+             max_length=prefix_len+sample_step,
+             top_p=nucleus_p,
+             top_k=0)
+         # then do contrastive search
+         output = self.fast_contrastive_generation(input_ids, beam_width, alpha, contrastive_step)
+         return output
+
+     def greedy_generation(self, input_ids, decoding_len):
+         _, prefix_len = input_ids.size()
+         output = self.model.generate(
+             input_ids,
+             max_length=prefix_len+decoding_len)
+         return output[0]
+
+     def beam_generation(self, input_ids, beam_width, decoding_len):
+         _, prefix_len = input_ids.size()
+         output = self.model.generate(
+             input_ids,
+             max_length=prefix_len+decoding_len,
+             num_beams=beam_width)
+         return output[0]
+
+     def nucleus_generation(self, input_ids, nucleus_p, decoding_len):
+         _, prefix_len = input_ids.size()
+         output = self.model.generate(
+             input_ids,
+             do_sample=True,
+             max_length=prefix_len+decoding_len,
+             top_p=nucleus_p,
+             top_k=0)
+         return output[0]
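
To make the context handling above concrete, here is a small sketch of the `parse_dialogue_context` / `extract_response` round trip, assuming a Chinese checkpoint where `self.eos_token` is `'[SEP]'` (the strings are illustrative):

```python
context_list = ["你好", "很高兴认识你"]
eos = "[SEP]"

# parse_dialogue_context joins the turns and terminates with the eos token:
context_text = eos.join(context_list).strip(eos) + eos
assert context_text == "你好[SEP]很高兴认识你[SEP]"

# extract_response splits the decoded text on the eos token and keeps
# item_list[uttr_num], i.e. the first utterance past the given context;
# for non-English checkpoints the tokenizer's whitespace is then removed:
decoded = context_text + "我 也 很 高 兴[SEP]"
response = decoded.split(eos)[len(context_list)].strip()
assert "".join(response.split()) == "我也很高兴"
```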
bot/utlis.py ADDED
@@ -0,0 +1,174 @@
+
+ import torch
+ import torch.nn.functional as F
+
+ def ranking(context_hidden, next_hidden, next_top_k_ids, next_top_k_probs, alpha):
+     '''
+         context_hidden: beam_width x context_len x embed_dim
+         next_hidden: beam_width x 1 x embed_dim
+         next_top_k_ids: beam_width x 1
+     '''
+     beam_width, context_len, embed_dim = context_hidden.size()
+     assert next_hidden.size() == torch.Size([beam_width, 1, embed_dim])
+     norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True)
+     norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True)
+     cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1,2)).squeeze(-1)
+     assert cosine_matrix.size() == torch.Size([beam_width, context_len])
+     scores, _ = torch.max(cosine_matrix, dim = -1)
+     assert scores.size() == torch.Size([beam_width])
+     next_top_k_probs = next_top_k_probs.view(-1)
+     scores = (1.0 - alpha) * next_top_k_probs - alpha * scores
+     _, selected_idx = torch.topk(scores, k = 1)
+     assert selected_idx.size() == torch.Size([1])
+     selected_idx = selected_idx.unsqueeze(0)
+     assert selected_idx.size() == torch.Size([1,1])
+     next_id = torch.gather(next_top_k_ids, dim = 0, index=selected_idx)
+     assert next_id.size() == torch.Size([1,1])
+     return next_id
+
+ def ContrastiveDecodingOneStep(model, input_ids, beam_width, alpha):
+     '''
+        model: the generation model, e.g., gpt2
+        input_ids: 1 x seqlen
+        note: expects a model exposing compute_logits_and_hidden_states()
+        (the original SimCTG wrapper); unused by the fast path below
+     '''
+     prev_hidden_states, logits = model.compute_logits_and_hidden_states(input_ids)
+     _, seqlen, embed_dim = prev_hidden_states.size()
+     _, _, vocab_size = logits.size()
+
+     logit_for_next_step = logits[:,-1,:]
+     assert logit_for_next_step.size() == torch.Size([1, vocab_size])
+
+     next_probs = F.softmax(logit_for_next_step, dim = -1)
+     assert next_probs.size() == logit_for_next_step.size()
+
+     _, top_k_ids = torch.topk(logit_for_next_step, dim = -1, k = beam_width)
+     assert top_k_ids.size() == torch.Size([1, beam_width])
+
+     top_k_probs = torch.gather(next_probs, dim = 1, index=top_k_ids)
+     assert top_k_probs.size() == top_k_ids.size()
+     # compute new hidden
+     expanded_context = [input_ids for _ in range(beam_width)]
+     expanded_context = torch.cat(expanded_context, dim = 0)
+     assert expanded_context.size() == torch.Size([beam_width, seqlen])
+     top_k_ids = top_k_ids.view(beam_width, 1)
+     next_input_ids = torch.cat([expanded_context, top_k_ids], dim = -1)
+     assert next_input_ids.size() == torch.Size([beam_width, seqlen+1])
+     new_hidden_states, next_logits = model.compute_logits_and_hidden_states(next_input_ids)
+     assert new_hidden_states.size() == torch.Size([beam_width, seqlen+1, embed_dim])
+     context_hidden = new_hidden_states[:,:seqlen,:]
+     assert context_hidden.size() == torch.Size([beam_width, seqlen, embed_dim])
+     next_hidden = new_hidden_states[:,seqlen:,:]
+     assert next_hidden.size() == torch.Size([beam_width, 1, embed_dim])
+
+     next_id = ranking(context_hidden, next_hidden, top_k_ids, top_k_probs, alpha)
+
+     next_input_ids = torch.cat([input_ids, next_id], dim = -1)
+     assert next_input_ids.size() == torch.Size([1, seqlen+1])
+     return next_input_ids
+
+ # ========== batch version ========= #
+ def ranking_fast(context_hidden, next_hidden, next_top_k_probs, alpha, beam_width):
+     '''
+         context_hidden: bsz*beam x seqlen x embed_dim
+         next_hidden: bsz*beam x 1 x embed_dim
+         next_top_k_probs: bsz x beam
+     '''
+     _, context_len, embed_dim = context_hidden.size()
+     norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True)
+     norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True)
+     cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1,2)).squeeze(-1)    # [B*K, S]
+     scores, _ = torch.max(cosine_matrix, dim=-1)    # [B*K]
+     next_top_k_probs = next_top_k_probs.view(-1)    # [B*K]
+     scores = (1.0 - alpha) * next_top_k_probs - alpha * scores
+     scores = torch.stack(torch.split(scores, beam_width))    # [B, K]
+     selected_idx = scores.max(dim=-1)[1]    # [B]
+     return selected_idx
+
+ def ContrastiveDecodingOneStepFast(
+     model,
+     ids,
+     beam_width,
+     alpha,
+     past_key_values,
+     last_hidden_states,
+     vocab,
+     logit_for_next_step,
+     first_step=False,
+     ):
+     # input_ids: [B, S]
+     if first_step:
+         output = model(
+             input_ids=ids,
+             past_key_values=past_key_values,
+             use_cache=True,
+             output_hidden_states=True
+         )
+         past_key_values = output.past_key_values
+         last_hidden_states = output.hidden_states[-1]    # [B, S, E]
+         logit_for_next_step = output.logits[:, -1, :]    # [B, V]
+     bsz, seqlen, embed_dim = last_hidden_states.size()
+
+     next_probs = F.softmax(logit_for_next_step, dim=-1)
+     _, top_k_ids = torch.topk(logit_for_next_step, dim=-1, k=beam_width)    # [B, K]
+     top_k_probs = torch.gather(next_probs, dim=1, index=top_k_ids)    # [B, K]
+     # compute new hidden
+     past_key_values = enlarge_past_key_values(past_key_values, beam_width)
+     output = model(
+         input_ids=top_k_ids.view(-1, 1),
+         attention_mask=torch.ones_like(top_k_ids.view(-1, 1)),
+         past_key_values=past_key_values,
+         output_hidden_states=True,
+         use_cache=True,
+     )
+     past_key_values = output.past_key_values
+     logits = output.logits[:, -1, :]    # [B*K, V]
+     next_hidden = output.hidden_states[-1]    # [B*K, 1, E]
+     context_hidden = last_hidden_states.unsqueeze(1).expand(-1, beam_width, -1, -1).reshape(bsz*beam_width, seqlen, embed_dim)    # [B*K, S, E]
+
+     selected_idx = ranking_fast(
+         context_hidden,
+         next_hidden,
+         top_k_probs,    # [B, K]
+         alpha,
+         beam_width,
+     )    # [B]
+     # prepare for the next step
+     next_id = top_k_ids[range(len(top_k_ids)), selected_idx].unsqueeze(-1)    # [B, 1]
+     next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), beam_width))    # [B, K, E]
+     next_hidden = next_hidden[range(bsz), selected_idx, :]    # [B, E]
+     last_hidden_states = torch.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1)    # [B, S, E]
+     past_key_values = select_past_key_values(past_key_values, beam_width, selected_idx)
+     logits = torch.stack(torch.split(logits, beam_width))[range(bsz), selected_idx, :]    # [B, V]
+     # next_id: [B, 1]
+     return next_id, past_key_values, last_hidden_states, logits
+
+ def enlarge_past_key_values(past_key_values, beam_width):
+     # from [B, num_head, seq_len, esz] to [B*K, num_head, seq_len, esz]
+     new_key_values = []
+     for layer in past_key_values:
+         items = []
+         for item in layer:
+             # item is the key or value matrix
+             bsz, num_head, seq_len, esz = item.size()
+             item = item.unsqueeze(1).expand(-1, beam_width, -1, -1, -1).reshape(bsz*beam_width, num_head, seq_len, esz)    # [bsz*beam, num_head, seq_len, esz]
+             items.append(item)
+         new_key_values.append(items)
+     return new_key_values
+
+ def select_past_key_values(past_key_values, beam_width, selected_idx):
+     '''selected_idx: [B]'''
+     new_key_values = []
+     for layer in past_key_values:
+         items = []
+         for item in layer:
+             bsz_and_beam, num_head, seq_len, esz = item.size()
+             bsz = bsz_and_beam // beam_width
+             item = torch.stack(torch.split(item, beam_width, dim=0))    # [B, K, num_head, seq_len, esz]
+             item = item[range(bsz), selected_idx, :, :, :]    # [B, num_head, seq_len, esz]
+             items.append(item)
+         new_key_values.append(items)
+     return new_key_values
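
The scoring rule shared by `ranking` and `ranking_fast` is the contrastive-search objective: `score = (1 - alpha) * model_confidence - alpha * degeneration_penalty`, where the penalty is the candidate's maximum cosine similarity to the context hidden states. A toy illustration with made-up numbers:

```python
import torch

alpha = 0.6
top_k_probs = torch.tensor([0.50, 0.30])   # model confidence of two candidates
max_cos_sim = torch.tensor([0.90, 0.20])   # max similarity to context hidden states

scores = (1.0 - alpha) * top_k_probs - alpha * max_cos_sim
print(scores)                  # tensor([-0.3400,  0.0000])
print(scores.argmax().item())  # 1 -> the less repetitive candidate wins
```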
requirements.txt ADDED
@@ -0,0 +1,18 @@
+ absl-py
+ pytest
+ sacrebleu==1.4.10
+ six
+ wheel
+ progressbar
+ scikit-learn
+ torch==1.6.0
+ torchvision==0.7.0
+ transformers==4.7.0
+ pyyaml
+ nltk
+ sentencepiece
+ spacy
+ gdown
+ seaborn
+ matplotlib
+ pandas