import numpy as np
import torch
import torch.nn.functional as F
from sat.generation.sampling_strategies.base_strategy import top_k_logits
from sat.mpu.initialize import get_model_parallel_world_size, get_model_parallel_src_rank, get_model_parallel_group
class AdvancedBaseStrategy:
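    """Token-by-token sampling strategy with temperature, top-k/top-p filtering,
    invalid-slice masking, optional no-repeat-ngram banning, a minimum generation
    length and end-token tracking.

    Tokens are handled as [batch_size, num_beams, seq_len]; finished batch
    elements are tracked in `_is_done`.
    """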
    def __init__(self, batch_size, invalid_slices=[], temperature=1., no_repeat_ngram_size=0, top_k=200, eps=1e-4, top_p=0.0, min_gen_length=1, end_tokens=None):
self.batch_size = batch_size
self.invalid_slices = invalid_slices
self.temperature = temperature
self.topk = top_k
self.top_p = top_p
self.eps = eps
self.min_gen_length = min_gen_length
        self.ngram = no_repeat_ngram_size
if end_tokens is None:
end_tokens = []
self.end_tokens = end_tokens
        # _init_cache() sets up length_generated, the per-beam ngram bans and the done mask.
        self._init_cache()
@property
def is_done(self) -> bool:
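        """True once every batch element has finished generating."""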
return self._is_done.all()
def _init_cache(self):
self.length_generated = 0
self.cached_beam_ngram_bans = [[{}] for _ in range(self.batch_size)]
self._is_done = np.zeros(self.batch_size, dtype=bool)
def forward(self, logits, tokens, mems, is_first = False, temperature=None):
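        """Sample one token per sequence and append it.

        logits: [batch_size, num_beams, vocab_size] scores for the next position.
        tokens: [batch_size, num_beams, seq_len] tokens generated so far.
        mems:   passed through unchanged.
        Returns (tokens with the sampled token appended, mems); finished
        sequences receive -1 as a placeholder token.
        """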
        batch_size, num_beam, seq_len = tokens.shape
if temperature is None:
temperature = self.temperature
logits = logits / temperature
if self.min_gen_length > self.length_generated:
for end_token in self.end_tokens:
logits[..., end_token] = -65504
for invalid_slice in self.invalid_slices:
logits[..., invalid_slice] = -65504
if self.ngram > 0 and seq_len > self.ngram:
for batch_idx in range(batch_size):
for i in range(num_beam):
ngram_prefix = tokens[batch_idx, i, -(self.ngram - 1) :].tolist() # TODO ngram=1
for banned_index in self.cached_beam_ngram_bans[batch_idx][i].get(tuple(ngram_prefix), []):
logits[batch_idx, i, banned_index] = -65504
logits = logits.view(-1, logits.size(-1))
logits = top_k_logits(logits, self.topk, self.top_p)
        probs = F.softmax(logits.float(), dim=-1)  # .float() is essential, due to a bug in PyTorch
pred = torch.multinomial(probs, num_samples=1)
for i in range(self.batch_size):
if i >= batch_size:
self._is_done[i] = True
elif self._is_done[i]:
pred[i] = -1
elif pred[i].item() in self.end_tokens:
self._is_done[i] = True
if self.ngram > 0:
for batch_idx in range(batch_size):
bans_continue = []
for i in range(num_beam):
bans = self.cached_beam_ngram_bans[batch_idx][i].copy()
ngram_prefix = tuple(tokens[batch_idx, i, -(self.ngram - 1):].tolist())
                    bans[ngram_prefix] = bans.get(ngram_prefix, tuple()) + (pred[batch_idx * num_beam + i].item(),)  # pred is flattened over (batch, beam)
bans_continue.append(bans)
self.cached_beam_ngram_bans[batch_idx] = bans_continue
tokens = torch.cat((tokens, pred.view(tokens.shape[:-1] + (1,))), dim=-1)
self.length_generated += 1
return tokens, mems
def finalize(self, tokens, mems):
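        """Reset the strategy state; tokens and mems are returned unchanged."""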
self._is_done = np.zeros(self.batch_size, dtype=np.bool_)
self._init_cache()
return tokens, mems
class BeamSearchStrategy:
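    """Beam search over [batch_size, num_beams] hypotheses.

    Candidate tokens are chosen either deterministically (top-k over the joint
    beam/vocabulary scores) or by multinomial sampling. Finished hypotheses are
    collected in `end_beams` under an OpenNMT-style length penalty, and
    no_repeat_ngram_size > 0 bans repeated n-grams per beam.
    """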
def __init__(
self,
batch_size,
num_beams,
length_penalty=1.0,
consider_end=False,
end_tokens=[],
invalid_slices=[],
no_repeat_ngram_size=0,
min_gen_length=0,
deterministic=False,
):
self.batch_size = batch_size
self.num_beams = num_beams
self.length_penalty = length_penalty
self.end_tokens = end_tokens
self.ngram = no_repeat_ngram_size
self.min_gen_length = min_gen_length
self.invalid_slices = invalid_slices
self.consider_end = consider_end
self.deterministic = deterministic
self._init_cache()
def _init_cache(self):
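        """Reset per-generation state: finished beams, beam scores, ngram bans, step counter and done mask."""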
        self.end_beams = [[] for _ in range(self.batch_size)]  # per batch: list of finished beams (LongTensors)
        self.end_beams_penalized_scores = [[] for _ in range(self.batch_size)]  # per batch: penalized scores, kept in descending order
        self.cached_beam_scores = 0  # scalar before the first step, then a [batch_size, num_beams] tensor
self.cached_beam_ngram_bans = [[{} for _ in range(self.num_beams)] for _ in range(self.batch_size)]
self.length_generated = 0
self._is_done = np.zeros(self.batch_size, dtype=np.bool_)
def _add_end_beams(self, score, beam, batch_idx):
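        """Insert a finished beam into end_beams[batch_idx], keeping the list sorted by penalized score.

        The raw score is divided by ((5 + len(beam)) / 6) ** length_penalty and only
        the best num_beams finished beams are kept per batch element.
        """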
score = score / ((5.0 + len(beam)) / 6) ** self.length_penalty # Magic number for OpenNMT
for i in range(len(self.end_beams[batch_idx]), -1, -1):
if i == 0 or score < self.end_beams_penalized_scores[batch_idx][i - 1]:
break
self.end_beams[batch_idx].insert(i, beam)
self.end_beams_penalized_scores[batch_idx].insert(i, score)
self.end_beams[batch_idx] = self.end_beams[batch_idx][: self.num_beams]
self.end_beams_penalized_scores[batch_idx] = self.end_beams_penalized_scores[batch_idx][: self.num_beams]
@property
def is_done(self) -> bool:
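        """True once every batch element has finished generating."""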
return self._is_done.all()
def forward(self, logits, tokens, mems):
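        """Advance every beam by one token.

        logits: [batch_size, num_beams, vocab_size]; tokens: [batch_size, num_beams, seq_len];
        mems: per-layer memories indexed as mems[:, batch, beam, ...].
        Candidates that hit an end token are moved to end_beams; the best remaining
        candidates continue, with tokens, mems, scores and ngram bans reordered to match.
        Returns (tokens, mems) for the surviving beams.
        """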
batch_size, num_beams, vocab_size = logits.shape
seq_len = tokens.shape[-1]
logits = logits.float()
for invalid_slice in self.invalid_slices:
logits[..., invalid_slice] = -65504
if self.min_gen_length > self.length_generated:
for end_token in self.end_tokens:
logits[..., end_token] = -65504
if self.ngram > 0 and seq_len > self.ngram:
for batch_idx in range(batch_size):
for i in range(num_beams):
ngram_prefix = tokens[batch_idx, i, -(self.ngram - 1) :].tolist() # TODO ngram=1
for banned_index in self.cached_beam_ngram_bans[batch_idx][i].get(tuple(ngram_prefix), []):
logits[batch_idx, i, banned_index] = -65504
        next_token_scores = F.log_softmax(logits, dim=-1)  # [batch_size, num_beams, vocab_size]
prev_scores = self.cached_beam_scores
if isinstance(prev_scores, torch.Tensor):
prev_scores = prev_scores[..., None].expand_as(next_token_scores)
next_token_scores = next_token_scores + prev_scores
next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
probs = F.softmax(next_token_scores, dim=-1)
if num_beams < self.num_beams: # First token
probs = probs[..., :vocab_size]
        if self.deterministic:
            next_tokens = torch.topk(probs, k=(max(1, len(self.end_tokens)) + 1) * self.num_beams).indices  # [batch_size, k]
        else:
            next_tokens = torch.multinomial(
                probs, num_samples=(max(1, len(self.end_tokens)) + 1) * self.num_beams
            )  # [batch_size, k], where k = (max(1, |end_tokens|) + 1) * num_beams
next_token_scores = next_token_scores[torch.arange(batch_size).unsqueeze(1), next_tokens]
next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
next_tokens = next_tokens[torch.arange(batch_size).unsqueeze(1), _indices]
next_indices = torch.div(next_tokens, vocab_size, rounding_mode="trunc")
next_tokens = next_tokens % vocab_size
# select out end beams or continue beams
beam_continue_batch, score_continue_batch, mems_continue_batch = [], [], []
for batch_idx in range(batch_size):
beam_continue = []
scores_continue = []
bans_continue = []
            mems_continue = []
for i in range(len(next_tokens[batch_idx])):
beam = torch.cat((tokens[batch_idx, next_indices[batch_idx, i]], next_tokens[batch_idx, i : i + 1]))
if not self._is_done[batch_idx] and int(next_tokens[batch_idx, i]) in self.end_tokens:
self._add_end_beams(next_token_scores[batch_idx, i], beam, batch_idx)
elif len(beam_continue) < self.num_beams:
beam_continue.append(beam)
                    mems_continue.append(mems[:, batch_idx, next_indices[batch_idx, i]])
# update caches
scores_continue.append(next_token_scores[batch_idx, i])
if self.ngram > 0:
bans = self.cached_beam_ngram_bans[batch_idx][next_indices[batch_idx, i]].copy()
# TODO ngram=1
ngram_prefix = tuple(tokens[batch_idx, next_indices[batch_idx, i], -(self.ngram - 1):].tolist())
bans[ngram_prefix] = bans.get(ngram_prefix, tuple()) + (next_tokens[batch_idx, i],)
bans_continue.append(bans)
else:
break
beam_continue_batch.append(torch.stack(beam_continue))
            mems_continue_batch.append(torch.stack(mems_continue, dim=1))
score_continue_batch.append(scores_continue)
self.cached_beam_ngram_bans[batch_idx] = bans_continue
tokens = torch.stack(beam_continue_batch)
mems = torch.stack(mems_continue_batch, dim=1)
self.cached_beam_scores = torch.tensor(score_continue_batch, device=logits.device)
self.length_generated += 1
for batch_idx in range(self.batch_size):
if batch_idx >= batch_size:
self._is_done[batch_idx] = True
elif (
len(self.end_beams[batch_idx]) == self.num_beams
and self.end_beams_penalized_scores[batch_idx][-1]
>= self.cached_beam_scores[batch_idx].max() / ((5.0 + (seq_len + 1)) / 6) ** self.length_penalty
            ):  # we are done if none of the still-active beams can score better than the worst finished beam in end_beams
self._is_done[batch_idx] = True
return tokens, mems
def finalize(self, tokens, mems):
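        """Return the final beams and reset the strategy state.

        With consider_end=True, still-active beams are pushed into end_beams
        (length penalty applied) and the per-batch lists of finished beams are
        returned with mems set to None; otherwise the current tokens and mems
        are returned as-is.
        """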
if self.consider_end:
batch_size, num_beams = tokens.shape[:2]
for batch_idx in range(batch_size):
if not self._is_done[batch_idx]:
for i in range(num_beams):
self._add_end_beams(self.cached_beam_scores[batch_idx, i], tokens[batch_idx, i], batch_idx)
mems = None
ret = self.end_beams[:batch_size]
else:
ret = tokens
self._init_cache()
        return ret, mems
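

if __name__ == "__main__":
    # Minimal smoke test (hypothetical usage, not part of the sat API): drive
    # AdvancedBaseStrategy with random logits instead of a real model, purely to
    # illustrate the forward()/finalize() contract. In practice the strategy is
    # handed to sat's autoregressive generation loop, which supplies real logits.
    torch.manual_seed(0)
    vocab_size = 32
    strategy = AdvancedBaseStrategy(batch_size=2, top_k=5, end_tokens=[0])
    tokens = torch.randint(1, vocab_size, (2, 1, 4))  # [batch, beam, seq]
    mems = None
    for _ in range(8):
        logits = torch.randn(2, 1, vocab_size)  # stand-in for model output at the next position
        tokens, mems = strategy.forward(logits, tokens, mems)
        if strategy.is_done:
            break
    tokens, mems = strategy.finalize(tokens, mems)
    print(tokens.shape)  # [2, 1, 4 + number of generated steps]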