# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Sampling module for top_k, top_p and greedy decoding."""

import abc
from typing import Any, Callable, Dict, Optional

import numpy as np
import tensorflow as tf, tf_keras

from official.nlp.modeling.ops import decoding_module


def greedy(log_probs):
  """Returns the top ids and scores based on greedy decoding."""
  log_probs, ids = tf.math.top_k(log_probs, k=1)
  return log_probs, ids


def sample_logits_with_temperature(logits, temperature):
  """Applies a sampling temperature.

  A temperature below 1.0 sharpens the distribution towards high-probability
  tokens and lowers the mass in the tail; a temperature above 1.0 flattens it.

  Args:
    logits: Input logits for next token.
    temperature: Tensor for specifying the sampling temperature.

  Returns:
    Logits with applied temperature.
  """
  return logits / temperature


def sample_top_k(logits, top_k):
  """Chooses top_k logits and sets the others to negative infinity.

  Args:
    logits: Input logits for next token.
    top_k: Tensor to specify the top_k values.

  Returns:
    Logits with top_k filtering applied.
  """
  top_k = tf.clip_by_value(
      top_k, clip_value_min=1, clip_value_max=tf.shape(logits)[-1])
  top_k_logits = tf.math.top_k(logits, k=top_k)
  indices_to_remove = logits < tf.expand_dims(top_k_logits[0][..., -1], -1)
  top_k_logits = set_tensor_by_indices_to_value(logits, indices_to_remove,
                                                np.NINF)
  return top_k_logits
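
# Illustrative example of sample_top_k (kept as a comment so importing this
# module stays side-effect free): with logits = tf.constant([[1.0, 3.0, 2.0,
# 0.5]]) and top_k = 2, the two largest logits (3.0 and 2.0) are kept and the
# rest are masked, returning [[-inf, 3.0, 2.0, -inf]]; tf.random.categorical
# applied to the result can then only ever draw token ids 1 or 2.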
def sample_top_p(logits, top_p):
  """Chooses most probable logits with cumulative probabilities up to top_p.

  Sets the remaining logits to negative infinity.

  Args:
    logits: Input logits for next token.
    top_p: Float tensor with a value >= 0 and < 1.0.

  Returns:
    Logits with top_p filtering applied.
  """
  sorted_indices = tf.argsort(logits, direction="DESCENDING")
  # Flatten logits as tf.gather on TPU needs axis to be compile time constant.
  logits_shape = decoding_module.shape_list(logits)
  range_for_gather = tf.expand_dims(tf.range(0, logits_shape[0]), axis=1)
  range_for_gather = tf.tile(range_for_gather * logits_shape[1],
                             [1, logits_shape[1]]) + sorted_indices
  flattened_logits = tf.reshape(logits, [-1])
  flattened_sorted_indices = tf.reshape(range_for_gather, [-1])
  sorted_logits = tf.reshape(
      tf.gather(flattened_logits, flattened_sorted_indices),
      [logits_shape[0], logits_shape[1]])
  cumulative_probs = tf.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1)

  # Remove tokens with cumulative probability above the threshold.
  sorted_indices_to_remove = cumulative_probs > top_p

  # Shift the indices to the right to keep the first token above the threshold.
  sorted_indices_to_remove = tf.roll(sorted_indices_to_remove, 1, axis=-1)
  sorted_indices_to_remove = tf.concat([
      tf.zeros_like(sorted_indices_to_remove[:, :1]),
      sorted_indices_to_remove[:, 1:]
  ], -1)

  # Scatter the sorted indices back to their original positions.
  indices_to_remove = scatter_values_on_batch_indices(sorted_indices_to_remove,
                                                      sorted_indices)
  top_p_logits = set_tensor_by_indices_to_value(logits, indices_to_remove,
                                                np.NINF)
  return top_p_logits


def scatter_values_on_batch_indices(values, batch_indices):
  """Scatter `values` into a tensor using `batch_indices`.

  Args:
    values: Tensor of shape [batch_size, vocab_size] containing the values to
      scatter.
    batch_indices: Tensor of shape [batch_size, vocab_size] containing the
      indices to insert (should be a permutation of range(0, vocab_size)).

  Returns:
    Tensor of shape [batch_size, vocab_size] with values inserted at
    batch_indices.
  """
  tensor_shape = decoding_module.shape_list(batch_indices)
  broad_casted_batch_dims = tf.reshape(
      tf.broadcast_to(
          tf.expand_dims(tf.range(tensor_shape[0]), axis=-1), tensor_shape),
      [1, -1])
  pair_indices = tf.transpose(
      tf.concat([broad_casted_batch_dims,
                 tf.reshape(batch_indices, [1, -1])], 0))
  return tf.scatter_nd(pair_indices, tf.reshape(values, [-1]), tensor_shape)


def set_tensor_by_indices_to_value(input_tensor, indices, value):
  """Where indices is True, sets the value in input_tensor to value.

  Args:
    input_tensor: float (batch_size, dim)
    indices: bool (batch_size, dim)
    value: float scalar

  Returns:
    output_tensor: same shape as input_tensor.
  """
  value_tensor = tf.zeros_like(input_tensor) + value
  output_tensor = tf.where(indices, value_tensor, input_tensor)
  return output_tensor
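
# Illustrative example of the top_p path above (comment only): if the
# probabilities of a row, sorted in descending order, are
# [0.5, 0.3, 0.15, 0.05] and top_p = 0.9, the cumulative sums are
# [0.5, 0.8, 0.95, 1.0]. The raw removal mask is [False, False, True, True],
# but after the right-shift (which keeps the first token that crosses the
# threshold) it becomes [False, False, False, True], so only the 0.05 token
# is masked to -inf. scatter_values_on_batch_indices then maps this mask from
# sorted order back to the original vocabulary order.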
class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
  """Implementation for sampling strategies (go/decoding-tf-nlp)."""

  def __init__(self,
               symbols_to_logits_fn,
               vocab_size: int,
               max_decode_length: int,
               eos_id: int,
               padded_decode: bool,
               length_normalization_fn: Optional[Callable[[int, tf.DType],
                                                          float]] = None,
               top_k=0,
               top_p=1.0,
               sample_temperature=0.0,
               enable_greedy: bool = True,
               dtype: tf.DType = tf.float32,
               decoding_name: Optional[str] = None,
               extra_cache_output: bool = False):
    """Initialize sampling module."""
    self.symbols_to_logits_fn = symbols_to_logits_fn
    self.length_normalization_fn = length_normalization_fn
    self.eos_id = eos_id
    self.padded_decode = padded_decode
    self.dtype = tf.as_dtype(dtype)
    self.vocab_size = tf.convert_to_tensor(vocab_size, dtype=tf.int32)
    self.max_decode_length = max_decode_length
    self.top_k = tf.convert_to_tensor(top_k, dtype=tf.int32)
    self.top_p = tf.convert_to_tensor(top_p, dtype=tf.float32)
    self.sample_temperature = tf.convert_to_tensor(
        sample_temperature, dtype=tf.float32)
    self.enable_greedy = enable_greedy
    self.decoding_name = decoding_name
    self.extra_cache_output = extra_cache_output
    super(SamplingModule, self).__init__(
        length_normalization_fn=length_normalization_fn,
        dtype=dtype,
        decoding_name=decoding_name,
        extra_cache_output=extra_cache_output)

  def _grow_alive_seq(self, state: Dict[str, Any],
                      batch_size: int) -> decoding_module.InternalState:
    """Grow alive sequences by one token.

    This function implements the decoding strategies, i.e. top_p, top_k and
    greedy, for choosing the next token.

    Args:
      state: A dictionary with the current loop state.
      batch_size: The given batch size.

    Returns:
      Tuple of
      (Top sequences [batch, curr_index + 1] or
         [batch, max_decode_length + 1],
       Scores of returned sequences [batch, 1],
       New ids [batch, 1],
       New alive cache)
    """
    i = state[decoding_module.StateKeys.CUR_INDEX]
    alive_seq = state[decoding_module.StateKeys.ALIVE_SEQ]
    alive_log_probs = state[decoding_module.StateKeys.ALIVE_LOG_PROBS]
    alive_cache = state[decoding_module.StateKeys.ALIVE_CACHE]

    if self.padded_decode:
      ids = tf.slice(alive_seq, [0, i], [batch_size, 1])
    else:
      ids = alive_seq

    new_logits, new_cache = self.symbols_to_logits_fn(ids, i, alive_cache)
    candidate_log_probs = decoding_module.log_prob_from_logits(new_logits)
    original_log_probs = candidate_log_probs + alive_log_probs

    topk_log_probs, topk_ids = None, None
    if self.enable_greedy:
      topk_log_probs, topk_ids = greedy(original_log_probs)
    else:
      temperature_fn = sample_logits_with_temperature
      sampled_logits = tf.cond(
          self.sample_temperature > 0.0,
          lambda: temperature_fn(new_logits, self.sample_temperature),
          lambda: new_logits)
      sampled_logits = tf.cond(
          self.top_k > 0,
          lambda: sample_top_k(sampled_logits, self.top_k),
          lambda: sampled_logits)
      sampled_logits = tf.cond(
          self.top_p < 1,
          lambda: sample_top_p(sampled_logits, self.top_p),
          lambda: sampled_logits)
      topk_ids = tf.random.categorical(
          sampled_logits, dtype=tf.int32, num_samples=1)
      topk_log_probs = tf.gather(
          original_log_probs, topk_ids, axis=1, batch_dims=1)

    if self.padded_decode:
      topk_seq = tf.transpose(alive_seq, perm=[1, 0])
      topk_seq = tf.tensor_scatter_nd_update(
          topk_seq, [[i + 1]], tf.expand_dims(tf.squeeze(topk_ids, -1), 0))
      topk_seq = tf.transpose(topk_seq, perm=[1, 0])
    else:
      topk_seq = tf.concat([alive_seq, topk_ids], axis=-1)
    return topk_seq, topk_log_probs, topk_ids, new_cache
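
  # Note on the sampling branch in `_grow_alive_seq`: when `enable_greedy` is
  # False, the per-step logits are transformed in a fixed order -- temperature
  # scaling (only if sample_temperature > 0), then top-k masking (only if
  # top_k > 0), then top-p masking (only if top_p < 1) -- before one token per
  # batch element is drawn with tf.random.categorical. For example, top_k=10
  # combined with top_p=0.9 first restricts sampling to the 10 largest logits
  # and then drops the low-probability tail of that restricted set.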
  def _create_initial_state(
      self,
      initial_ids: tf.Tensor,
      initial_cache: Dict[str, tf.Tensor],
      batch_size: int,
      initial_log_probs: Optional[tf.Tensor] = None
  ) -> decoding_module.InitialState:
    """Return initial state dictionary and its shape invariants."""
    for key, value in initial_cache.items():
      for inner_value in tf.nest.flatten(value):
        if inner_value.dtype != self.dtype:
          raise TypeError(
              "initial_cache element for key '%s' has dtype %s that does not "
              "match sampling_module's dtype of %s. Value: %s" %
              (key, inner_value.dtype.name, self.dtype.name, inner_value))

    # Current loop index (starts at 0).
    cur_index = tf.constant(0)

    # Alive sequence with shape [batch_size, 1].
    alive_seq = initial_ids
    alive_seq = tf.expand_dims(alive_seq, axis=-1)
    if self.padded_decode:
      alive_seq = tf.tile(alive_seq, [1, self.max_decode_length + 1])

    # Initial log probabilities with shape [batch_size, 1].
    if initial_log_probs is None:
      initial_log_probs = tf.constant([[0.]], dtype=self.dtype)
      alive_log_probs = tf.tile(initial_log_probs, [batch_size, 1])
    else:
      alive_log_probs = initial_log_probs

    alive_cache = initial_cache

    # Initialize tensor storing finished sequences; same shape as alive_seq.
    finished_seq = tf.zeros(tf.shape(alive_seq), tf.int32)

    # Initialize scores of the finished sequences to zero.
    finished_scores = tf.zeros([batch_size, 1], dtype=self.dtype)

    # Initialize finished flags with all False values.
    finished_flags = tf.zeros([batch_size, 1], tf.bool)

    # Create state dictionary and state shapes.
    state = {
        decoding_module.StateKeys.CUR_INDEX: cur_index,
        decoding_module.StateKeys.ALIVE_SEQ: alive_seq,
        decoding_module.StateKeys.ALIVE_LOG_PROBS: alive_log_probs,
        decoding_module.StateKeys.ALIVE_CACHE: alive_cache,
        decoding_module.StateKeys.FINISHED_SEQ: finished_seq,
        decoding_module.StateKeys.FINISHED_SCORES: finished_scores,
        decoding_module.StateKeys.FINISHED_FLAGS: finished_flags
    }

    if self.padded_decode:
      state_shape_invariants = {
          decoding_module.StateKeys.CUR_INDEX:
              tf.TensorShape([]),
          decoding_module.StateKeys.ALIVE_SEQ:
              tf.TensorShape([batch_size, self.max_decode_length + 1]),
          decoding_module.StateKeys.ALIVE_LOG_PROBS:
              tf.TensorShape([batch_size, 1]),
          decoding_module.StateKeys.ALIVE_CACHE:
              tf.nest.map_structure(lambda state: state.get_shape(),
                                    alive_cache),
          decoding_module.StateKeys.FINISHED_SEQ:
              tf.TensorShape([batch_size, self.max_decode_length + 1]),
          decoding_module.StateKeys.FINISHED_SCORES:
              tf.TensorShape([batch_size, 1]),
          decoding_module.StateKeys.FINISHED_FLAGS:
              tf.TensorShape([batch_size, 1])
      }
    else:
      state_shape_invariants = {
          decoding_module.StateKeys.CUR_INDEX:
              tf.TensorShape([]),
          decoding_module.StateKeys.ALIVE_SEQ:
              tf.TensorShape([None, None]),
          decoding_module.StateKeys.ALIVE_LOG_PROBS:
              tf.TensorShape([None, 1]),
          decoding_module.StateKeys.ALIVE_CACHE:
              tf.nest.map_structure(decoding_module.get_shape_keep_last_dim,
                                    alive_cache),
          decoding_module.StateKeys.FINISHED_SEQ:
              tf.TensorShape([None, None]),
          decoding_module.StateKeys.FINISHED_SCORES:
              tf.TensorShape([None, 1]),
          decoding_module.StateKeys.FINISHED_FLAGS:
              tf.TensorShape([None, 1])
      }

    if self.extra_cache_output:
      state.update(
          {decoding_module.StateKeys.INITIAL_OUTPUT_CACHE: alive_cache})
      if self.padded_decode:
        state_shape_invariants.update({
            decoding_module.StateKeys.INITIAL_OUTPUT_CACHE:
                tf.nest.map_structure(lambda state: state.get_shape(),
                                      alive_cache)
        })
      else:
        state_shape_invariants.update({
            decoding_module.StateKeys.INITIAL_OUTPUT_CACHE:
                tf.nest.map_structure(decoding_module.get_shape_keep_last_dim,
                                      alive_cache),
        })
    return state, state_shape_invariants

  def _get_new_alive_state(self, new_seq: tf.Tensor, new_log_probs: tf.Tensor,
                           new_finished_flags: tf.Tensor,
                           new_cache: Dict[str, tf.Tensor]) -> Dict[str, Any]:
    """Gather the sequences that are still alive.

    This function resets the sequences in the alive_state that are finished.

    Args:
      new_seq: New sequences generated by growing the current alive sequences;
        int32 tensor with shape [batch_size, cur_index + 1].
      new_log_probs: Log probabilities of new sequences; float32 tensor with
        shape [batch_size, 1].
      new_finished_flags: A boolean tensor indicating which of the grown
        sequences are finished.
      new_cache: Dict of cached values for each sequence.

    Returns:
      Dictionary with alive keys.
    """
    new_seq = tf.multiply(
        new_seq, tf.cast(tf.logical_not(new_finished_flags), new_seq.dtype))
    return {
        decoding_module.StateKeys.ALIVE_SEQ: new_seq,
        decoding_module.StateKeys.ALIVE_LOG_PROBS: new_log_probs,
        decoding_module.StateKeys.ALIVE_CACHE: new_cache
    }
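
  # Note: `length_normalization_fn`, when provided, is used below and in
  # `_process_finished_state` to turn accumulated log probabilities into
  # length-normalized scores. A typical (purely illustrative) choice is a
  # GNMT-style penalty such as
  #   lambda length, dtype: tf.pow((5. + tf.cast(length, dtype)) / 6., 0.75)
  # so that longer finished sequences are not unduly penalized.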
  def _get_new_finished_state(self, state: Dict[str, Any], new_seq: tf.Tensor,
                              new_log_probs: tf.Tensor,
                              new_finished_flags: tf.Tensor,
                              batch_size: int) -> Dict[str, tf.Tensor]:
    """Combine new and old finished sequences.

    Args:
      state: A dictionary with the current loop state.
      new_seq: New sequences generated by growing the current alive sequences;
        int32 tensor [batch, curr_index + 1] or [batch, max_decode_length + 1].
      new_log_probs: Log probabilities of new sequences; float32 tensor with
        shape [batch, 1].
      new_finished_flags: A boolean tensor indicating which of the new
        sequences are finished.
      batch_size: The given batch size.

    Returns:
      Dictionary with finished keys from StateKeys.
    """
    i = state[decoding_module.StateKeys.CUR_INDEX]
    finished_seq = state[decoding_module.StateKeys.FINISHED_SEQ]
    finished_scores = state[decoding_module.StateKeys.FINISHED_SCORES]
    finished_flags = state[decoding_module.StateKeys.FINISHED_FLAGS]

    if not self.padded_decode:
      finished_seq = tf.concat(
          [finished_seq, tf.zeros([batch_size, 1], tf.int32)], axis=-1)
    new_scores = new_log_probs
    if self.length_normalization_fn is not None:
      length_norm = self.length_normalization_fn(i + 1, self.dtype)
      new_scores = new_log_probs / length_norm
    new_seq = tf.multiply(
        new_seq, tf.cast(tf.logical_not(finished_flags), new_seq.dtype))
    new_scores = tf.multiply(
        new_scores, tf.cast(tf.logical_not(finished_flags), new_scores.dtype))

    finished_seq += tf.multiply(new_seq,
                                tf.cast(new_finished_flags, new_seq.dtype))
    finished_scores += tf.multiply(
        new_scores, tf.cast(new_finished_flags, new_scores.dtype))
    new_finished_flags = tf.logical_or(new_finished_flags, finished_flags)
    return {
        decoding_module.StateKeys.FINISHED_SEQ: finished_seq,
        decoding_module.StateKeys.FINISHED_SCORES: finished_scores,
        decoding_module.StateKeys.FINISHED_FLAGS: new_finished_flags
    }

  def _process_finished_state(
      self, finished_state: Dict[str, Any]) -> decoding_module.Output:
    """Process the alive/finished state to return final sequences and scores."""
    alive_seq = finished_state[decoding_module.StateKeys.ALIVE_SEQ]
    alive_log_probs = finished_state[decoding_module.StateKeys.ALIVE_LOG_PROBS]
    finished_seq = finished_state[decoding_module.StateKeys.FINISHED_SEQ]
    finished_scores = finished_state[decoding_module.StateKeys.FINISHED_SCORES]
    finished_flags = finished_state[decoding_module.StateKeys.FINISHED_FLAGS]
    finished_cond = tf.reduce_any(finished_flags, 1, name="finished_cond")
    if self.length_normalization_fn is not None:
      length_norm = self.length_normalization_fn(self.max_decode_length + 1,
                                                 self.dtype)
      alive_log_probs = alive_log_probs / length_norm
    seq_cond = decoding_module.expand_to_same_rank(finished_cond, finished_seq)
    score_cond = decoding_module.expand_to_same_rank(finished_cond,
                                                     finished_scores)
    finished_seq = tf.where(seq_cond, finished_seq, alive_seq)
    finished_scores = tf.where(score_cond, finished_scores, alive_log_probs)
    if self.extra_cache_output:
      return finished_seq, finished_scores, finished_state[
          decoding_module.StateKeys.INITIAL_OUTPUT_CACHE]
    return finished_seq, finished_scores

  def _continue_search(self, state) -> tf.Tensor:
    i = state[decoding_module.StateKeys.CUR_INDEX]
    # Have we reached the maximum decoding length?
    not_at_end = tf.less(i, self.max_decode_length)
    # Have all sampled sequences reached an EOS?
    all_has_eos = tf.reduce_all(
        state[decoding_module.StateKeys.FINISHED_FLAGS],
        axis=None,
        name="search_finish_cond")
    return tf.logical_and(not_at_end, tf.logical_not(all_has_eos))

  def _finished_flags(self, topk_ids, state) -> tf.Tensor:
    new_finished_flags = tf.equal(topk_ids, self.eos_id)
    new_finished_flags = tf.logical_or(
        new_finished_flags, state[decoding_module.StateKeys.FINISHED_FLAGS])
    return new_finished_flags
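

if __name__ == "__main__":
  # Minimal usage sketch, not an official test. It assumes the `generate`
  # entry point inherited from decoding_module.DecodingModule with the
  # signature generate(initial_ids, initial_cache), and it uses a toy
  # symbols_to_logits_fn that returns fixed logits and an unchanged cache;
  # a real decoder would run a model here.
  def _toy_symbols_to_logits_fn(ids, i, cache):
    """Returns static logits for a batch of 2 and a vocabulary of 5."""
    del ids, i  # Unused by the toy function.
    logits = tf.constant(
        [[0.1, 0.2, 0.3, 0.4, 5.0],
         [5.0, 0.2, 0.3, 0.4, 0.1]], dtype=tf.float32)
    return logits, cache

  sampler = SamplingModule(
      symbols_to_logits_fn=_toy_symbols_to_logits_fn,
      vocab_size=5,
      max_decode_length=4,
      eos_id=4,
      padded_decode=False,
      enable_greedy=True)
  decoded_ids, scores = sampler.generate(
      initial_ids=tf.constant([0, 0], tf.int32),
      initial_cache={"dummy": tf.zeros([2, 1, 1], tf.float32)})
  print("decoded ids:", decoded_ids.numpy())
  print("scores:", scores.numpy())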