naveensp committed on
Commit 064158a · verified · 1 Parent(s): 7ba9065

Upload beam_search.py with huggingface_hub

Files changed (1)
  beam_search.py +1078 -0
beam_search.py ADDED
@@ -0,0 +1,1078 @@
"""
This is a self-contained and flexible beam search implementation adapted from
AllenNLP's beam search: https://github.com/allenai/allennlp/blob/main/allennlp/nn/beam_search.py
"""

import copy
import warnings
from abc import abstractmethod
from inspect import signature
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, cast

import torch

__all__ = [
    "Sampler",
    "DeterministicSampler",
    "MultinomialSampler",
    "TopKSampler",
    "TopPSampler",
    "GumbelSampler",
    "FinalSequenceScorer",
    "SequenceLogProbabilityScorer",
    "LengthNormalizedSequenceLogProbabilityScorer",
    "Constraint",
    "RepeatedNGramBlockingConstraint",
    "BeamSearch",
]

StateType = Dict[str, torch.Tensor]
StepFunctionTypeWithTimestep = Callable[[torch.Tensor, StateType, int], Tuple[torch.Tensor, StateType]]
StepFunctionTypeNoTimestep = Callable[[torch.Tensor, StateType], Tuple[torch.Tensor, StateType]]

StepFunctionType = TypeVar("StepFunctionType", StepFunctionTypeWithTimestep, StepFunctionTypeNoTimestep)
"""
The type of step function that can be passed to [`BeamSearch.search`](#search).

This can either be [`StepFunctionTypeWithTimestep`](#stepfunctiontypewithtimestep)
or [`StepFunctionTypeNoTimestep`](#stepfunctiontypenotimestep).
"""

ConstraintStateType = List[List[Dict[str, Any]]]
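

# NOTE: illustrative sketch, not part of the original AllenNLP implementation.
# A minimal step function conforming to `StepFunctionTypeWithTimestep`. The
# state key "transitions" is hypothetical; a real model would run a decoder
# forward pass here instead of a table lookup.
def _example_step(
    last_predictions: torch.Tensor, state: StateType, timestep: int
) -> Tuple[torch.Tensor, StateType]:
    del timestep  # this toy model is time-invariant
    # shape: (group_size, num_classes)
    log_probs = torch.log_softmax(state["transitions"][last_predictions], dim=-1)
    return log_probs, state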


class Sampler:
    """
    An abstract class that can be used to sample candidates (either nodes or beams)
    within `BeamSearch`.

    A `Sampler` just has three methods, `init_state()`, `sample_nodes()` and `sample_beams()`.

    `init_state()` takes three arguments:

    - a tensor of starting log probs with shape `(batch_size, num_classes)`,
    - the batch size, an int,
    - and the number of classes, also an int.

    It returns a state dictionary with any state tensors needed for subsequent
    calls to `sample_nodes()` and `sample_beams()`.

    By default this method just returns an empty dictionary.

    Both `sample_nodes()` and `sample_beams()` should take three arguments:

    - tensor of normalized log probabilities with shape `(batch_size, num_examples)`,
    - an integer representing the number of samples to take for each example in the batch,
    - and a state dictionary which could contain any tensors needed for the `Sampler` to keep
      track of state.

    For `sample_nodes()`, `num_examples = num_classes`, but for `sample_beams`,
    `num_examples = beam_size * per_node_beam_size`.

    The return value should be a tuple containing:

    - a tensor of log probabilities of the sampled examples with shape `(batch_size, num_samples)`,
    - a tensor of indices of the sampled examples with shape `(batch_size, num_samples)`,
    - and the updated state dictionary.

    A default implementation of `sample_beams` is provided, which just deterministically
    picks the `k` examples with highest log probability.
    """

    def init_state(
        self, start_class_log_probabilities: torch.Tensor, batch_size: int, num_classes: int
    ) -> StateType:
        del start_class_log_probabilities, batch_size, num_classes
        return {}

    @abstractmethod
    def sample_nodes(
        self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
    ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
        raise NotImplementedError

    def sample_beams(
        self, log_probs: torch.Tensor, beam_size: int, state: StateType
    ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
        del state
        selected_log_probs, selected_indices = torch.topk(log_probs, beam_size, dim=-1)
        return selected_log_probs, selected_indices, {}


class DeterministicSampler(Sampler):
    """
    A `Sampler` that just deterministically returns the `k` nodes or beams with highest
    log probability.
    """

    def sample_nodes(
        self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
    ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
        del state
        selected_log_probs, selected_indices = torch.topk(log_probs, per_node_beam_size, dim=-1)
        return selected_log_probs, selected_indices, {}


class MultinomialSampler(Sampler):
    """
    A `Sampler` which samples nodes from the given multinomial distribution. Beams are sampled
    in the default, deterministic way.

    :param temperature: A `temperature` below 1.0 produces a sharper probability distribution and a `temperature`
        above 1.0 produces a flatter probability distribution.
    :param with_replacement: Whether to sample with replacement.
    """

    def __init__(
        self,
        temperature: float = 1.0,
        with_replacement: bool = False,
    ) -> None:
        self.temperature = temperature
        self.with_replacement = with_replacement

    def sample_nodes(
        self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
    ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
        if self.temperature != 1.0:
            _probabilities = torch.nn.functional.softmax(log_probs / self.temperature, dim=-1)
        else:
            _probabilities = log_probs.exp()

        selected_indices = torch.multinomial(_probabilities, per_node_beam_size, replacement=self.with_replacement)

        return torch.gather(log_probs, 1, selected_indices), selected_indices, state
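

# NOTE: illustrative sketch, not part of the original AllenNLP implementation.
# A quick demonstration of `MultinomialSampler`: a temperature below 1.0
# sharpens the distribution, so samples concentrate on likelier classes.
def _demo_multinomial_sampler() -> None:
    log_probs = torch.log_softmax(torch.randn(2, 8), dim=-1)  # (batch_size=2, num_classes=8)
    sampler = MultinomialSampler(temperature=0.5, with_replacement=True)
    state = sampler.init_state(log_probs, batch_size=2, num_classes=8)
    sampled_log_probs, indices, state = sampler.sample_nodes(log_probs, 3, state)
    assert sampled_log_probs.shape == indices.shape == (2, 3)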


class TopKSampler(Sampler):
    """
    A `Sampler` which redistributes the probability mass function for nodes among the
    top `k` choices, then samples from that subset after re-normalizing the probabilities.

    Beams are sampled in the default, deterministic way.

    :param k: The number of top choices to be selected from.
    :param temperature: A `temperature` below 1.0 produces a sharper probability distribution and a `temperature`
        above 1.0 produces a flatter probability distribution.
    :param with_replacement: If set to `True`, samples will be selected with replacement from the top k choices.
    """

    def __init__(
        self,
        k: int = 1,
        temperature: float = 1.0,
        with_replacement: bool = False,
    ):
        self.k = k
        self.temperature = temperature or 1.0
        self.with_replacement = with_replacement

    def sample_nodes(
        self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
    ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
        if not per_node_beam_size <= self.k <= log_probs.size()[1]:
            raise ValueError(
                "k must be a positive integer no less than per_node_beam_size and no greater than vocabulary size"
            )

        # shape (both): (batch_size, k)
        top_k_log_probs, top_k_indices = log_probs.topk(self.k, dim=-1)

        # Apply temperature if necessary.
        # shape: (batch_size, k)
        if self.temperature != 1.0:
            top_k_log_probs = top_k_log_probs / self.temperature

        # Re-normalize the subset.
        # shape: (batch_size, k)
        normalized_top_k_probs = torch.nn.functional.softmax(top_k_log_probs, dim=-1)

        # Sample from the re-normalized subset.
        # NOTE: These indices are not indices into `log_probs`, they are indices into `top_k_log_probs`.
        # shape: (batch_size, per_node_beam_size)
        sampled_indices = torch.multinomial(
            normalized_top_k_probs, per_node_beam_size, replacement=self.with_replacement
        )

        # Convert `sampled_indices` back to indices in the original `log_probs` tensor.
        # shape: (batch_size, per_node_beam_size)
        indices = top_k_indices.gather(-1, sampled_indices)

        return log_probs.gather(1, indices), indices, state
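

# NOTE: illustrative sketch, not part of the original AllenNLP implementation.
# `TopKSampler` only ever returns indices drawn from the `k` most likely
# classes of each row, as this small check demonstrates.
def _demo_top_k_sampler() -> None:
    log_probs = torch.log_softmax(torch.randn(2, 16), dim=-1)
    sampler = TopKSampler(k=5, with_replacement=True)
    _, indices, _ = sampler.sample_nodes(log_probs, 3, {})
    top_5 = log_probs.topk(5, dim=-1).indices
    # Every sampled index must come from the top-5 set of its own row.
    assert all(idx in top_5[row] for row in range(2) for idx in indices[row])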


class TopPSampler(Sampler):
    """
    A `Sampler` which redistributes the probability mass function for nodes among
    the top choices with a cumulative probability of at least `p`, then samples from that subset
    after re-normalizing the probabilities.

    Beams are sampled in the default, deterministic way.

    :param p:
        The cumulative probability cutoff threshold. A higher value of `p` will result in more possible
        examples to sample from. If `with_replacement` is `False` and the number of possible samples is
        insufficient to sample without replacement from when calling `sample_nodes`, then the top
        `per_node_beam_size` examples will be chosen.
    :param temperature:
        A `temperature` below 1.0 produces a sharper probability distribution and a `temperature`
        above 1.0 produces a flatter probability distribution.
    :param with_replacement:
        If set to `True`, samples will be selected with replacement from the top choices.
    """

    def __init__(
        self,
        p: float = 0.9,
        temperature: float = 1.0,
        with_replacement: bool = False,
    ):
        if p < 0.0 or p > 1.0:
            raise ValueError("p must be a non-negative float no greater than 1.0")
        self.p = p
        self.temperature = temperature or 1.0
        self.with_replacement = with_replacement

    def sample_nodes(
        self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
    ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
        if not per_node_beam_size <= log_probs.size()[1]:
            raise ValueError("per_node_beam_size cannot be greater than vocabulary size")

        # First apply temperature coefficient:
        if self.temperature != 1.0:
            _log_probs = torch.nn.functional.log_softmax(log_probs / self.temperature, dim=-1)
        else:
            _log_probs = log_probs

        # Sort the probabilities in descending order to find the cumulative sum.
        log_probs_descending, sorting_indices = torch.sort(_log_probs, descending=True)

        # shape: (batch_size, num_classes)
        probabilities_descending = log_probs_descending.exp()
        probabilities_summed = torch.cumsum(probabilities_descending, dim=-1)

        # Create a mask for filtering out probabilities that don't make the top `p`.
        # shape: (batch_size, num_classes)
        exclusion_mask = probabilities_summed >= self.p

        # We want to include the first index where probabilities_summed >= p, so we shift over one.
        exclusion_mask[..., 1:] = exclusion_mask[..., :-1].clone()
        exclusion_mask[..., 0] = False

        # Make sure there's at least `per_node_beam_size` options to be selected.
        if not self.with_replacement:
            exclusion_mask[..., :per_node_beam_size] = False

        log_probs_descending[exclusion_mask] = torch.finfo(log_probs.dtype).min

        # Now re-normalize the included log probs.
        # shape: (batch_size, num_classes)
        filtered_probabilities = torch.nn.functional.softmax(log_probs_descending, dim=-1)

        # Sample from the re-normalized subset.
        # NOTE: These indices are not indices into `log_probs`, they are indices into `log_probs_descending`.
        # shape: (batch_size, per_node_beam_size)
        sampled_indices = torch.multinomial(
            filtered_probabilities, per_node_beam_size, replacement=self.with_replacement
        )

        # Convert `sampled_indices` back to indices in the original `log_probs` tensor.
        # shape: (batch_size, per_node_beam_size)
        selected_indices = sorting_indices.gather(-1, sampled_indices)

        # Return (selected log probabilities, selected classes)
        # shape (both): (batch_size, per_node_beam_size)
        return torch.gather(log_probs, 1, selected_indices), selected_indices, state
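

# NOTE: illustrative sketch, not part of the original AllenNLP implementation.
# Nucleus (top-p) sampling in action: indices are drawn from the smallest set
# of classes whose cumulative probability reaches `p`, and the returned log
# probs are looked up from the original (unfiltered) distribution.
def _demo_top_p_sampler() -> None:
    log_probs = torch.log_softmax(torch.randn(2, 16), dim=-1)
    sampler = TopPSampler(p=0.9, with_replacement=True)
    sampled_log_probs, indices, _ = sampler.sample_nodes(log_probs, 3, {})
    assert indices.shape == (2, 3)
    assert torch.equal(sampled_log_probs, log_probs.gather(1, indices))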


class GumbelSampler(Sampler):
    """
    A `Sampler` which uses the Gumbel-Top-K trick to sample without replacement. See
    [*Stochastic Beams and Where to Find Them: The Gumbel-Top-k Trick for Sampling
    Sequences Without Replacement*, W Kool, H Van Hoof and M Welling, 2019]
    (https://api.semanticscholar.org/CorpusID:76662039).

    :param temperature: A `temperature` below 1.0 produces a sharper probability distribution and a `temperature`
        above 1.0 produces a flatter probability distribution.
    """

    def __init__(self, temperature: float = 1.0):
        self.temperature = temperature

    def init_state(
        self, start_class_log_probabilities: torch.Tensor, batch_size: int, num_classes: int
    ) -> StateType:
        # shape: (batch_size, num_classes)
        zeros = start_class_log_probabilities.new_zeros((batch_size, num_classes))

        # shape: (batch_size, num_classes)
        G_phi_S = self.gumbel_with_max(start_class_log_probabilities, zeros)

        return {"G_phi_S": G_phi_S}

    def sample_nodes(
        self,
        log_probs: torch.Tensor,
        per_node_beam_size: int,
        state: StateType,
    ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
        # First apply temperature coefficient:
        # shape: (batch_size * beam_size, num_classes)
        if self.temperature != 1.0:
            _log_probs = torch.nn.functional.log_softmax(log_probs / self.temperature, dim=-1)
        else:
            _log_probs = log_probs

        # shape: (group_size,)
        phi_S = state["phi_S"]

        # shape: (group_size, num_classes)
        phi_S = phi_S.unsqueeze(-1).expand_as(_log_probs)

        # shape: (group_size, num_classes)
        phi_S_new = phi_S + _log_probs

        # shape: (group_size, 1)
        G_phi_S = state["G_phi_S"].unsqueeze(-1)

        # shape: (group_size, num_classes)
        G_phi_S_new = self.gumbel_with_max(phi_S_new, G_phi_S)

        # Replace NaNs with very negative number.
        # shape: (group_size, num_classes)
        # G_phi_S_new[G_phi_S_new.isnan()] = torch.finfo(G_phi_S_new.dtype).min

        # shape (both): (group_size, per_node_beam_size)
        top_G_phi_S_new, top_indices = torch.topk(G_phi_S_new, per_node_beam_size, dim=-1)

        # shape: (group_size, per_node_beam_size)
        top_log_probs = log_probs.gather(1, top_indices)

        return top_log_probs, top_indices, {"G_phi_S": top_G_phi_S_new}

    def sample_beams(
        self,
        log_probs: torch.Tensor,
        beam_size: int,
        state: StateType,
    ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
        """
        Returns the beams with the highest perturbed log probabilities.
        """
        # shape (log_probs): (batch_size, beam_size * per_node_beam_size)

        batch_size = log_probs.size()[0]

        # shape: (batch_size * beam_size, per_node_beam_size)
        G_phi_S = state["G_phi_S"]

        # shape: (batch_size, beam_size * per_node_beam_size)
        G_phi_S = G_phi_S.reshape_as(log_probs)

        # shape (both): (batch_size, beam_size)
        G_phi_S_new, selected_indices = torch.topk(G_phi_S, beam_size, dim=-1)

        # shape: (batch_size, beam_size)
        selected_log_probs = log_probs.gather(1, selected_indices)

        # Now sort the selected beams by their true log prob.
        # shape (all): (batch_size, beam_size)
        selected_log_probs, sort_indices = selected_log_probs.sort(dim=-1, descending=True)
        selected_indices = selected_indices.gather(1, sort_indices)
        G_phi_S_new = G_phi_S_new.gather(1, sort_indices)

        # shape: (batch_size * beam_size,)
        G_phi_S_new = G_phi_S_new.reshape(batch_size * beam_size)

        # shape: (batch_size * beam_size,)
        phi_S = selected_log_probs.reshape(batch_size * beam_size)

        return selected_log_probs, selected_indices, {"G_phi_S": G_phi_S_new, "phi_S": phi_S}

    def gumbel(self, phi) -> torch.Tensor:
        """
        Sample `Gumbel(phi)`.

        `phi` should have shape `(batch_size, num_classes)`.
        """
        return -torch.log(-torch.log(torch.rand_like(phi))) + phi

    def gumbel_with_max(self, phi, T) -> torch.Tensor:
        """
        Sample `Gumbel(phi)` conditioned on the maximum value being equal to `T`.

        `phi` should have shape `(batch_size, num_classes)` and `T` should have
        shape `(batch_size, 1)`.
        """
        # Shape: (batch_size, num_classes)
        G_phi = self.gumbel(phi)

        # Now we find the maximum from these samples.
        # Shape: (batch_size,)
        Z, _ = G_phi.max(dim=-1)

        # Shape: (batch_size, num_classes)
        v = T - G_phi + torch.log1p(-torch.exp(G_phi - Z.unsqueeze(-1)))

        # Shape: (batch_size, num_classes)
        return T - torch.nn.functional.relu(v) - torch.log1p(torch.exp(-v.abs()))
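

# NOTE: illustrative sketch, not part of the original AllenNLP implementation.
# The Gumbel-Top-k trick relies on `gumbel_with_max(phi, T)` producing
# perturbed values whose per-row maximum equals the conditioning value `T`;
# this small check demonstrates that property.
def _demo_gumbel_with_max() -> None:
    sampler = GumbelSampler()
    phi = torch.log_softmax(torch.randn(4, 10), dim=-1)
    T = torch.zeros(4, 1)  # condition each row's maximum to be 0
    G = sampler.gumbel_with_max(phi, T)
    assert torch.allclose(G.max(dim=-1).values, T.squeeze(-1), atol=1e-5)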


class FinalSequenceScorer:
    """
    An abstract class that can be used to score the final generated sequences found
    by beam search. Given the predicted sequences and the corresponding log probabilities of
    those sequences, the class calculates and returns the final score of the sequences.

    The default implementation scores the sequences using the sum of the log probabilities of
    the sequence, which is passed as input.
    """

    @abstractmethod
    def score(self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int) -> torch.Tensor:
        """
        Score the final predictions found by beam search.
        Returns a tensor of the final sequence scores of shape `(batch_size, beam_size)`.

        :param predictions: A tensor containing the predicted sequences with shape `(batch_size, beam_size, max_steps)`.
        :param log_probabilities: A tensor containing the log probabilities of the sequence, defined as the sum
            of the log probabilities per token, with shape `(batch_size, beam_size)`.
        :param end_index: The index of the end symbol.
        """
        raise NotImplementedError


class SequenceLogProbabilityScorer(FinalSequenceScorer):
    """
    A :class:`FinalSequenceScorer` which scores the sequences by the sum of the log probabilities
    across the sequence's tokens.
    """

    def score(self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int) -> torch.Tensor:
        del predictions, end_index
        # The sum of the sequence log probabilities is the input parameter, so just
        # return it.
        return log_probabilities


class LengthNormalizedSequenceLogProbabilityScorer(FinalSequenceScorer):
    """
    A :class:`FinalSequenceScorer` which scores the sequences by the average log probability of the
    tokens in the sequence. It optionally includes a length penalty which promotes
    or demotes sequences based on their lengths. The final score for a sequence will
    be `(sequence_log_probability) / (sequence_length ** length_penalty)`. The sequence length
    here includes the end token.

    :param length_penalty: The length penalty to use. A value of 1.0 means no length penalty is used.
        A value > 1.0 favors longer sequences, and < 1.0 favors shorter sequences.
    """

    def __init__(self, length_penalty: float = 1.0):
        super().__init__()
        self.length_penalty = length_penalty

    def score(self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int) -> torch.Tensor:
        # shape: (batch_size, beam_size)
        lengths = (predictions != end_index).long().sum(dim=2)

        # If the sequence ended during beam search, the `log_probabilities` will include
        # the transition to the end token. Therefore, in such situations, `lengths` is
        # actually off by 1. This corrects for that.
        # shape: (batch_size, beam_size)
        is_end_token = predictions[:, :, -1] == end_index
        lengths += is_end_token.long()

        # shape: (batch_size, beam_size)
        average_log_probs = log_probabilities / (lengths**self.length_penalty)
        return average_log_probs
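

# NOTE: illustrative sketch, not part of the original AllenNLP implementation.
# A worked example of length normalization: two beams with the same total log
# probability (-6.0) but different lengths get different normalized scores.
def _demo_length_normalized_scorer() -> None:
    end = 0
    # batch_size=1, beam_size=2, max_steps=4; beam 0 emits two tokens then ends.
    predictions = torch.tensor([[[5, 3, end, end], [7, 2, 9, 4]]])
    log_probabilities = torch.tensor([[-6.0, -6.0]])
    scorer = LengthNormalizedSequenceLogProbabilityScorer(length_penalty=1.0)
    scores = scorer.score(predictions, log_probabilities, end_index=end)
    # Beam 0 has length 3 (two tokens plus the end token): -6 / 3 = -2.0.
    # Beam 1 never ended, so its length is 4: -6 / 4 = -1.5.
    assert torch.allclose(scores, torch.tensor([[-2.0, -1.5]]))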


class Constraint:
    """
    An abstract class that can be used to enforce constraints on the output predictions
    by manipulating the class log probabilities during beam search.

    A `Constraint` just has three methods that need to be implemented by subclasses:
    `init_state()`, `apply()` and `_update_state()`.

    `init_state()` takes one argument:

    - the batch size, an int

    It returns a constraint state, which is a nested list of dictionaries, with any state needed for subsequent
    calls to `apply()` and `update_state()`. The length of the outer list should be equal to `batch_size`.
    Each inner list should be of length 1.

    `apply()` takes two arguments:

    - the constraint state, which is a nested list of dictionaries. The length of the outer list is `batch_size`
      and the length of each inner list is `beam_size` except on the first time `apply()` is called when it is 1.
    - `class_log_probabilities`, a tensor of shape `(batch_size, beam_size, num_classes)` that contains the
      log probabilities for the classes during search. The first time `apply()` is called, `beam_size = 1`.

    The `apply()` method should return new `class_log_probabilities` that enforce the constraint
    for this step of beam search. For instance, it may prevent a specific class from being selected by setting
    the corresponding log probability to a negligible value such as `float("-inf")` or
    `torch.finfo(class_log_probabilities.dtype).min`.

    `_update_state()` takes two arguments:

    - the copied parent constraint state, which is a nested list of dictionaries. `state[i][j]` contains the
      copied state for the parent of `last_prediction[i, j]`. It is unique to that batch and beam, so it can be
      directly edited in-place without affecting the others.
    - last_prediction, a tensor of shape `(batch_size, beam_size)` containing the predictions from the last
      step of beam search.

    The `_update_state()` function should return a new constraint state, a nested list of dictionaries of
    length `batch_size` and inner list of length `beam_size`, one for each of the predictions in `last_prediction`.
    """

    @abstractmethod
    def init_state(
        self,
        batch_size: int,
    ) -> ConstraintStateType:
        raise NotImplementedError

    @abstractmethod
    def apply(
        self,
        state: ConstraintStateType,
        class_log_probabilities: torch.Tensor,
    ) -> torch.Tensor:
        raise NotImplementedError

    @staticmethod
    def _copy_state(
        state: ConstraintStateType,
        batch_size: int,
        beam_size: int,
        last_backpointer: Optional[torch.Tensor] = None,
    ) -> ConstraintStateType:
        """
        Copies the data in `state` using `copy.deepcopy()`. If this
        is not appropriate for your constraint, you will need to implement the copying yourself.
        """
        new_state = []
        for i in range(batch_size):
            batch_state = []
            for j in range(beam_size):
                if last_backpointer is None:
                    # This is the first prediction, so the backpointer is 0
                    backpointer = 0
                else:
                    backpointer = last_backpointer[i, j].item()
                batch_state.append(copy.deepcopy(state[i][backpointer]))  # type: ignore
            new_state.append(batch_state)
        return new_state

    def update_state(
        self,
        state: ConstraintStateType,
        last_prediction: torch.Tensor,
        last_backpointer: Optional[torch.Tensor] = None,
    ) -> ConstraintStateType:
        batch_size, beam_size = last_prediction.size()
        new_state = self._copy_state(state, batch_size, beam_size, last_backpointer)
        return self._update_state(new_state, last_prediction)

    @abstractmethod
    def _update_state(
        self,
        state: ConstraintStateType,
        last_prediction: torch.Tensor,
    ) -> ConstraintStateType:
        raise NotImplementedError


class RepeatedNGramBlockingConstraint(Constraint):
    def __init__(self, ngram_size: int, **kwargs) -> None:
        super().__init__(**kwargs)
        self.ngram_size = ngram_size

    def init_state(
        self,
        batch_size: int,
    ) -> ConstraintStateType:
        return [[{"seen_ngrams": {}, "current_prefix": []}] for _ in range(batch_size)]

    def apply(
        self,
        state: ConstraintStateType,
        class_log_probabilities: torch.Tensor,
    ) -> torch.Tensor:
        for i, batch in enumerate(state):
            for j, beam in enumerate(batch):
                current_prefix = tuple(beam["current_prefix"])
                seen_ngrams = beam["seen_ngrams"]
                try:
                    disallowed_indices = seen_ngrams[current_prefix]
                    class_log_probabilities[i, j, disallowed_indices] = torch.finfo(
                        class_log_probabilities.dtype
                    ).min
                except KeyError:
                    # We have not seen this prefix before, so there is no index
                    # that needs to be blocked
                    pass
        return class_log_probabilities

    def _update_state(
        self,
        state: ConstraintStateType,
        last_prediction: torch.Tensor,
    ) -> ConstraintStateType:
        for i, batch in enumerate(state):
            for j, beam in enumerate(batch):
                prediction = last_prediction[i, j].item()
                prefix = beam["current_prefix"]
                seen_ngrams = beam["seen_ngrams"]

                if len(prefix) == self.ngram_size - 1:
                    # This is a new ngram that we have to remember
                    if tuple(prefix) not in seen_ngrams:
                        seen_ngrams[tuple(prefix)] = []
                    seen_ngrams[tuple(prefix)].append(prediction)

                # Create the new prefix, removing the oldest index if the prefix
                # is too long
                prefix.append(prediction)
                if len(prefix) == self.ngram_size:
                    prefix.pop(0)
        return state
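

# NOTE: illustrative sketch, not part of the original AllenNLP implementation.
# Bigram blocking in action: after the bigram (3, 7) has been generated, the
# constraint masks token 7 whenever the current prefix is again (3,), so the
# bigram cannot repeat.
def _demo_ngram_blocking() -> None:
    constraint = RepeatedNGramBlockingConstraint(ngram_size=2)
    state = constraint.init_state(batch_size=1)
    # Simulate a single beam that predicted 3, then 7, then 3 again.
    state = constraint.update_state(state, torch.tensor([[3]]))
    state = constraint.update_state(state, torch.tensor([[7]]), last_backpointer=torch.tensor([[0]]))
    state = constraint.update_state(state, torch.tensor([[3]]), last_backpointer=torch.tensor([[0]]))
    log_probs = torch.zeros(1, 1, 10)
    log_probs = constraint.apply(state, log_probs)
    # Token 7 is blocked because the bigram (3, 7) was already generated.
    assert log_probs[0, 0, 7] == torch.finfo(log_probs.dtype).min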


class BeamSearch:
    """
    Implements the beam search algorithm for decoding the most likely sequences.

    :param end_index: The index of the "stop" or "end" token in the vocabulary. Usually the EOS token ID.

    :param max_steps: The maximum number of decoding steps to take, i.e. the maximum length
        of the predicted sequences.

    :param beam_size: The width of the beam used.

    :param per_node_beam_size: The maximum number of candidates to consider per node, at each step in the search.
        If not given, this just defaults to `beam_size`. Setting this parameter
        to a number smaller than `beam_size` may give better results, as it can introduce
        more diversity into the search. See
        [*Beam Search Strategies for Neural Machine Translation*, Freitag and Al-Onaizan, 2017]
        (https://api.semanticscholar.org/CorpusID:2229477).

    :param sampler: An optional `Sampler` which is used to pick next candidate nodes and beams.
        If not specified, `DeterministicSampler` will be used, which just takes the
        `per_node_beam_size` most likely nodes and the `beam_size` most likely beams.

        Using the [`GumbelSampler`](#gumbelsampler), on the other hand, will give you
        [Stochastic Beam Search](https://api.semanticscholar.org/CorpusID:76662039).

    :param min_steps: The minimum number of decoding steps to take, i.e. the minimum length of
        the predicted sequences. This does not include the start or end tokens. If `None`,
        no minimum is enforced.

    :param final_sequence_scorer: An optional `FinalSequenceScorer` which is used to score the final generated sequences.
        The output from this module is what is returned by the `search` method. If not
        specified, `SequenceLogProbabilityScorer` will be used, which scores the sequences
        by the sum of the token log probabilities.

    :param constraints: An optional list of `Constraint`s which should be applied during beam search. If not
        provided, no constraints will be enforced.
    """

    def __init__(
        self,
        end_index: int,
        *,
        max_steps: int = 50,
        beam_size: int = 10,
        per_node_beam_size: Optional[int] = None,
        sampler: Optional[Sampler] = None,
        min_steps: Optional[int] = None,
        final_sequence_scorer: Optional[FinalSequenceScorer] = None,
        constraints: Optional[List[Constraint]] = None,
    ) -> None:
        if not max_steps > 0:
            raise ValueError("max_steps must be positive")
        if not beam_size > 0:
            raise ValueError("beam_size must be positive")
        if per_node_beam_size is not None and not per_node_beam_size > 0:
            raise ValueError("per_node_beam_size must be positive")
        if min_steps is not None:
            if not min_steps >= 0:
                raise ValueError("min_steps must be non-negative")
            if not min_steps <= max_steps:
                raise ValueError("min_steps must be less than or equal to max_steps")

        self._end_index = end_index
        self.max_steps = max_steps
        self.beam_size = beam_size
        self.per_node_beam_size = per_node_beam_size or beam_size
        self.sampler = sampler or DeterministicSampler()
        self.min_steps = min_steps or 0
        self.final_sequence_scorer = final_sequence_scorer or SequenceLogProbabilityScorer()
        self.constraints = constraints or []

    @staticmethod
    def _reconstruct_sequences(predictions, backpointers):
        # Reconstruct the sequences.
        # shape: [(batch_size, beam_size, 1)]
        reconstructed_predictions = [predictions[-1].unsqueeze(2)]

        if not backpointers:
            return reconstructed_predictions

        # shape: (batch_size, beam_size)
        cur_backpointers = backpointers[-1]

        for timestep in range(len(predictions) - 2, 0, -1):
            # shape: (batch_size, beam_size, 1)
            cur_preds = predictions[timestep].gather(1, cur_backpointers).unsqueeze(2)

            reconstructed_predictions.append(cur_preds)

            # shape: (batch_size, beam_size)
            cur_backpointers = backpointers[timestep - 1].gather(1, cur_backpointers)

        # shape: (batch_size, beam_size, 1)
        final_preds = predictions[0].gather(1, cur_backpointers).unsqueeze(2)

        reconstructed_predictions.append(final_preds)

        return reconstructed_predictions

    def search(
        self,
        start_predictions: torch.Tensor,
        start_state: StateType,
        step: StepFunctionType,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Given a starting state and a step function, apply beam search to find the
        most likely target sequences.

        Returns a tuple of `(predictions, final_scores)`, where `predictions`
        has shape `(batch_size, beam_size, max_steps)` and `final_scores`
        has shape `(batch_size, beam_size)`.

        .. note::
            If your step function returns `-inf` for some log probabilities
            (like if you're using a masked log-softmax) then some of the "best"
            sequences returned may also have `-inf` log probability. Specifically
            this happens when the beam size is larger than the number of actions
            with finite log probability (non-zero probability) returned by the step function.
            Therefore if you're using a mask you may want to check the results from `search`
            and potentially discard sequences with non-finite log probability.

        :param start_predictions: A tensor containing the initial predictions with shape `(batch_size,)`.
            Usually the initial predictions are just the index of the "start" token
            in the target vocabulary.

        :param start_state: The initial state passed to the `step` function. Each value of the state dict
            should be a tensor of shape `(batch_size, *)`, where `*` means any other
            number of dimensions.

        :param step: A function that is responsible for computing the next most likely tokens,
            given the current state and the predictions from the last time step.
            The function should accept two or three arguments:

            - a tensor of shape `(group_size,)` representing the index of the predicted
              tokens from the last time step,
            - the current state, a `StateType`, and
            - optionally, the timestep, an `int`.

            The `group_size` will be `batch_size * beam_size`, except in the initial
            step, for which it will just be `batch_size`.

            The function is expected to return a tuple, where the first element
            is a tensor of shape `(group_size, vocab_size)` containing
            the log probabilities of the tokens for the next step, and the second
            element is the updated state. The tensor in the state should have shape
            `(group_size, *)`, where `*` means any other number of dimensions.
        """
        step_signature = signature(step)
        if len(step_signature.parameters) < 3:
            # If the step function we're given does not take the time step argument, wrap it
            # in one that does.
            old_step = cast(StepFunctionTypeNoTimestep, step)

            def new_step(last_predictions: torch.Tensor, state: Dict[str, torch.Tensor], time_step: int):
                del time_step
                return old_step(last_predictions, state)

            return self._search(start_predictions, start_state, new_step)
        else:
            return self._search(start_predictions, start_state, cast(StepFunctionTypeWithTimestep, step))

    def _search(
        self,
        start_predictions: torch.Tensor,
        start_state: StateType,
        step: StepFunctionTypeWithTimestep,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = start_predictions.size()[0]

        # List of (batch_size, beam_size) tensors. One for each time step. Does not
        # include the start symbols, which are implicit.
        predictions: List[torch.Tensor] = []

        # List of (batch_size, beam_size) tensors. One for each time step. None for
        # the first. Stores the index n for the parent prediction, i.e.
        # predictions[t-1][i][n], that it came from.
        backpointers: List[torch.Tensor] = []

        constraint_states = [constraint.init_state(batch_size) for constraint in self.constraints]

        # Calculate the first timestep. This is done outside the main loop
        # because we are going from a single decoder input (the output from the
        # encoder) to the top `beam_size` decoder outputs. On the other hand,
        # within the main loop we are going from the `beam_size` elements of the
        # beam to `beam_size`^2 candidates from which we will select the top
        # `beam_size` elements for the next iteration.
        # shape: (batch_size, num_classes)
        start_class_log_probabilities, state = step(start_predictions, start_state, 0)

        num_classes = start_class_log_probabilities.size()[1]

        # Make sure `per_node_beam_size` is not larger than `num_classes`.
        if self.per_node_beam_size > num_classes:
            raise ValueError(
                f"Vocab size ({num_classes:d}) too small "
                f"relative to per_node_beam_size ({self.per_node_beam_size:d}).\n"
                f"Please decrease beam_size or per_node_beam_size."
            )

        sampler_state = self.sampler.init_state(start_class_log_probabilities, batch_size, num_classes)

        # Apply all constraints.
        if self.constraints:
            # shape: (batch_size, 1, num_classes)
            expanded_start_class_log_probabilities = start_class_log_probabilities.unsqueeze(1)
            for constraint, constraint_state in zip(self.constraints, constraint_states):
                expanded_start_class_log_probabilities = constraint.apply(
                    constraint_state, expanded_start_class_log_probabilities
                )
            start_class_log_probabilities = expanded_start_class_log_probabilities.squeeze(1)

        # Prevent selecting the end symbol if there is any min_steps constraint.
        if self.min_steps >= 1:
            start_class_log_probabilities[:, self._end_index] = torch.finfo(
                start_class_log_probabilities.dtype
            ).min

        # Get the initial predicted classes and their log probabilities.
        # shape: (batch_size, beam_size), (batch_size, beam_size)
        (
            start_top_log_probabilities,
            start_predicted_classes,
            sampler_state,
        ) = self.sampler.sample_beams(start_class_log_probabilities, self.beam_size, sampler_state)

        if self.beam_size == 1 and (start_predicted_classes == self._end_index).all():
            warnings.warn(
                "Empty sequences predicted. You may want to increase the beam size or ensure "
                "your step function is working properly.",
                RuntimeWarning,
            )
            return start_predicted_classes.unsqueeze(-1), start_top_log_probabilities

        # The log probabilities for the last time step.
        # shape: (batch_size, beam_size)
        last_log_probabilities = start_top_log_probabilities

        # shape: [(batch_size, beam_size)]
        predictions.append(start_predicted_classes)

        # Log probability tensor that mandates that the end token is selected.
        # shape: (batch_size * beam_size, num_classes)
        log_probs_after_end = start_class_log_probabilities.new_full(
            (batch_size * self.beam_size, num_classes),
            torch.finfo(start_class_log_probabilities.dtype).min,
        )
        log_probs_after_end[:, self._end_index] = 0.0

        # Set the same state for each element in the beam.
        self._update_initial_state(state, batch_size)

        for i, constraint in enumerate(self.constraints):
            constraint_states[i] = constraint.update_state(constraint_states[i], start_predicted_classes)

        for timestep in range(self.max_steps - 1):
            # shape: (batch_size * beam_size,)
            last_predictions = predictions[-1].reshape(batch_size * self.beam_size)

            # If every predicted token from the last step is `self._end_index`,
            # then we can stop early.
            if (last_predictions == self._end_index).all():
                break

            # Take a step. This gets the predicted log probs of the next classes
            # and updates the state.
            # shape: (batch_size * beam_size, num_classes)
            class_log_probabilities, state = step(last_predictions, state, timestep + 1)

            # Apply all constraints.
            if self.constraints:
                # shape: (batch_size, beam_size, num_classes)
                reshaped_class_log_probabilities = class_log_probabilities.view(batch_size, self.beam_size, -1)
                for constraint, constraint_state in zip(self.constraints, constraint_states):
                    reshaped_class_log_probabilities = constraint.apply(
                        constraint_state, reshaped_class_log_probabilities
                    )
                # shape: (batch_size * beam_size, num_classes)
                class_log_probabilities = reshaped_class_log_probabilities.view(batch_size * self.beam_size, -1)

            # The `timestep`-th iteration of the for loop is generating the `timestep + 2`-th token
            # of the sequence (because `timestep` is 0-indexed and we generated the first token
            # before the for loop). Here we block the end index if the search is not allowed to
            # terminate on this iteration.
            if timestep + 2 <= self.min_steps:
                class_log_probabilities[:, self._end_index] = torch.finfo(class_log_probabilities.dtype).min

            # shape: (batch_size * beam_size, num_classes)
            last_predictions_expanded = last_predictions.unsqueeze(-1).expand(
                batch_size * self.beam_size, num_classes
            )

            # Here we are finding any beams where we predicted the end token in
            # the previous timestep and replacing the distribution with a
            # one-hot distribution, forcing the beam to predict the end token
            # this timestep as well.
            # shape: (batch_size * beam_size, num_classes)
            cleaned_log_probabilities = torch.where(
                last_predictions_expanded == self._end_index,
                log_probs_after_end,
                class_log_probabilities,
            )

            # shape (both): (batch_size * beam_size, per_node_beam_size)
            top_log_probabilities, predicted_classes, sampler_state = self.sampler.sample_nodes(
                cleaned_log_probabilities, self.per_node_beam_size, sampler_state
            )

            # Here we expand the last log probabilities to (batch_size * beam_size, per_node_beam_size)
            # so that we can add them to the current log probs for this timestep.
            # This lets us maintain the log probability of each element on the beam.
            # shape: (batch_size * beam_size, per_node_beam_size)
            expanded_last_log_probabilities = (
                last_log_probabilities.unsqueeze(2)
                .expand(batch_size, self.beam_size, self.per_node_beam_size)
                .reshape(batch_size * self.beam_size, self.per_node_beam_size)
            )

            # shape: (batch_size * beam_size, per_node_beam_size)
            summed_top_log_probabilities = top_log_probabilities + expanded_last_log_probabilities

            # shape: (batch_size, beam_size * per_node_beam_size)
            reshaped_summed = summed_top_log_probabilities.reshape(
                batch_size, self.beam_size * self.per_node_beam_size
            )

            # shape: (batch_size, beam_size * per_node_beam_size)
            reshaped_predicted_classes = predicted_classes.reshape(
                batch_size, self.beam_size * self.per_node_beam_size
            )

            # Keep only the top `beam_size` beam indices.
            # shape (both): (batch_size, beam_size)
            (
                restricted_beam_log_probs,
                restricted_beam_indices,
                sampler_state,
            ) = self.sampler.sample_beams(reshaped_summed, self.beam_size, sampler_state)

            # Use the beam indices to extract the corresponding classes.
            # shape: (batch_size, beam_size)
            restricted_predicted_classes = reshaped_predicted_classes.gather(1, restricted_beam_indices)

            predictions.append(restricted_predicted_classes)

            # shape: (batch_size, beam_size)
            last_log_probabilities = restricted_beam_log_probs

            # The beam indices come from a `beam_size * per_node_beam_size` dimension where the
            # indices with a common ancestor are grouped together. Hence
            # dividing by per_node_beam_size gives the ancestor. (Note that this is integer
            # division as the tensor is a LongTensor.)
            # shape: (batch_size, beam_size)
            backpointer = torch.divide(restricted_beam_indices, self.per_node_beam_size, rounding_mode="trunc")
            backpointers.append(backpointer)

            # Keep only the pieces of the state tensors corresponding to the
            # ancestors created this iteration.
            self._update_state(state, backpointer)

            for i, constraint in enumerate(self.constraints):
                constraint_states[i] = constraint.update_state(
                    constraint_states[i], restricted_predicted_classes, last_backpointer=backpointer
                )

        # Warn about "-inf" log probabilities if not using any constraints (negligible
        # log probabilities are expected when using constraints).
        if not self.constraints and (
            not torch.isfinite(last_log_probabilities).all()
            or (last_log_probabilities == torch.finfo(last_log_probabilities.dtype).min).any()
        ):
            warnings.warn(
                "Negligible log probabilities encountered ('-inf' or equivalent). "
                "Some final sequences may not make sense. "
                "This can happen when the beam size is larger than the number of valid (non-zero "
                "probability) transitions that the step function produces.",
                RuntimeWarning,
            )

        reconstructed_predictions = self._reconstruct_sequences(predictions, backpointers)

        # shape: (batch_size, beam_size, max_steps)
        all_predictions = torch.cat(list(reversed(reconstructed_predictions)), 2)

        # Calculate the final sequence scores.
        # shape: (batch_size, beam_size)
        final_scores = self.final_sequence_scorer.score(all_predictions, last_log_probabilities, self._end_index)

        # Sort the sequences based on the final scores so the best scoring
        # sequence is at index 0.
        sorted_final_scores, sorted_indices = torch.sort(final_scores, dim=1, descending=True)
        sorted_all_predictions = torch.gather(
            all_predictions, 1, sorted_indices.unsqueeze(-1).expand_as(all_predictions)
        )

        return sorted_all_predictions, sorted_final_scores

    def _update_initial_state(self, state: StateType, batch_size: int):
        """
        Expand tensors in a state dictionary from `(batch_size, *)` to `(batch_size * beam_size, *)`.
        """
        for key, state_tensor in state.items():
            if state_tensor is None:
                continue
            # shape: (batch_size * beam_size, *)
            _, *last_dims = state_tensor.size()
            state[key] = (
                state_tensor.unsqueeze(1)
                .expand(batch_size, self.beam_size, *last_dims)
                .reshape(batch_size * self.beam_size, *last_dims)
            )

    def _update_state(self, state: StateType, backpointer: torch.Tensor):
        batch_size = backpointer.size()[0]

        for key, state_tensor in state.items():
            if state_tensor is None:
                continue
            _, *last_dims = state_tensor.size()
            # shape: (batch_size, beam_size, *)
            expanded_backpointer = backpointer.view(batch_size, self.beam_size, *([1] * len(last_dims))).expand(
                batch_size, self.beam_size, *last_dims
            )
            # shape: (batch_size * beam_size, *)
            state[key] = (
                state_tensor.reshape(batch_size, self.beam_size, *last_dims)
                .gather(1, expanded_backpointer)
                .reshape(batch_size * self.beam_size, *last_dims)
            )
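

# NOTE: illustrative sketch, not part of the original AllenNLP implementation.
# An end-to-end run over a toy 5-token vocabulary with token 0 as the end
# token. The step function is a stand-in for a real decoder: instead of a
# forward pass it just looks up a fixed next-token distribution.
def _demo_beam_search() -> None:
    num_classes = 5
    # A fixed, time-invariant next-token log distribution for every input token.
    transitions = torch.log_softmax(torch.randn(num_classes, num_classes), dim=-1)

    def step(last_predictions: torch.Tensor, state: StateType) -> Tuple[torch.Tensor, StateType]:
        # shape: (group_size, num_classes)
        return transitions[last_predictions], state

    beam_search = BeamSearch(end_index=0, max_steps=6, beam_size=3)
    start_predictions = torch.tensor([1, 2])  # batch_size = 2
    predictions, scores = beam_search.search(start_predictions, {}, step)
    assert predictions.shape[:2] == (2, 3)  # (batch_size, beam_size, <= max_steps)
    assert scores.shape == (2, 3)


if __name__ == "__main__":
    # Smoke-test the illustrative demos defined throughout this file.
    _demo_multinomial_sampler()
    _demo_top_k_sampler()
    _demo_top_p_sampler()
    _demo_gumbel_with_max()
    _demo_length_normalized_scorer()
    _demo_ngram_blocking()
    _demo_beam_search()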