# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sampling module for top_k, top_p and greedy decoding.""" | |
import abc | |
from typing import Any, Callable, Dict, Optional | |
import numpy as np | |
import tensorflow as tf, tf_keras | |
from official.nlp.modeling.ops import decoding_module | |


def greedy(log_probs):
  """Returns the top ids and scores based on greedy decoding."""
  log_probs, ids = tf.math.top_k(log_probs, k=1)
  return log_probs, ids
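
# Illustrative usage (not part of the library API): `greedy` keeps the single
# highest-scoring token per batch row. A minimal sketch, assuming eager
# execution and a toy 1 x 4 log-probability tensor:
#
#   log_probs = tf.math.log([[0.1, 0.2, 0.3, 0.4]])
#   scores, ids = greedy(log_probs)   # ids == [[3]], scores == [[log(0.4)]]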


def sample_logits_with_temperature(logits, temperature):
  """Applies a sampling temperature.

  Dividing logits by a temperature below 1.0 sharpens the distribution
  towards high-probability tokens and lowers the mass in the tail; a
  temperature above 1.0 flattens the distribution.

  Args:
    logits: Input logits for next token.
    temperature: Tensor for specifying the sampling temperature.

  Returns:
    Logits with applied temperature.
  """
  return logits / temperature
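
# Illustrative usage (an assumption, not from the original file): a temperature
# below 1.0 makes sampling closer to greedy, above 1.0 makes it more uniform.
# Sketch assuming eager execution:
#
#   logits = tf.constant([[1.0, 2.0, 4.0]])
#   sharp = sample_logits_with_temperature(logits, 0.5)  # [[2.0, 4.0, 8.0]]
#   flat = sample_logits_with_temperature(logits, 2.0)   # [[0.5, 1.0, 2.0]]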


def sample_top_k(logits, top_k):
  """Chooses top_k logits and sets the others to negative infinity.

  Args:
    logits: Input logits for next token.
    top_k: Tensor to specify the top_k values.

  Returns:
    Logits with top_k filtering applied.
  """
  top_k = tf.clip_by_value(
      top_k, clip_value_min=1, clip_value_max=tf.shape(logits)[-1])
  top_k_logits = tf.math.top_k(logits, k=top_k)
  indices_to_remove = logits < tf.expand_dims(top_k_logits[0][..., -1], -1)
  top_k_logits = set_tensor_by_indices_to_value(logits, indices_to_remove,
                                                np.NINF)
  return top_k_logits
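
# Illustrative usage (an assumption): keep only the k largest logits and push
# the rest to -inf so `tf.random.categorical` can never pick them. Sketch
# assuming eager execution:
#
#   logits = tf.constant([[1.0, 2.0, 3.0, 4.0]])
#   filtered = sample_top_k(logits, top_k=tf.constant(2))
#   # filtered == [[-inf, -inf, 3.0, 4.0]]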


def sample_top_p(logits, top_p):
  """Chooses most probable logits with cumulative probabilities up to top_p.

  Sets the remaining logits to negative infinity.

  Args:
    logits: Input logits for next token.
    top_p: Float tensor with a value >= 0 and < 1.0.

  Returns:
    Logits with top_p filtering applied.
  """
  sorted_indices = tf.argsort(logits, direction="DESCENDING")
  # Flatten logits as tf.gather on TPU needs axis to be compile time constant.
  logits_shape = decoding_module.shape_list(logits)
  range_for_gather = tf.expand_dims(tf.range(0, logits_shape[0]), axis=1)
  range_for_gather = tf.tile(range_for_gather * logits_shape[1],
                             [1, logits_shape[1]]) + sorted_indices
  flattened_logits = tf.reshape(logits, [-1])
  flattened_sorted_indices = tf.reshape(range_for_gather, [-1])
  sorted_logits = tf.reshape(
      tf.gather(flattened_logits, flattened_sorted_indices),
      [logits_shape[0], logits_shape[1]])
  cumulative_probs = tf.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1)

  # Remove tokens with cumulative probability above the threshold.
  sorted_indices_to_remove = cumulative_probs > top_p

  # Shift the indices to the right to keep the first token above threshold.
  sorted_indices_to_remove = tf.roll(sorted_indices_to_remove, 1, axis=-1)
  sorted_indices_to_remove = tf.concat([
      tf.zeros_like(sorted_indices_to_remove[:, :1]),
      sorted_indices_to_remove[:, 1:]
  ], -1)

  # Scatter sorted indices to original indexes.
  indices_to_remove = scatter_values_on_batch_indices(sorted_indices_to_remove,
                                                      sorted_indices)
  top_p_logits = set_tensor_by_indices_to_value(logits, indices_to_remove,
                                                np.NINF)
  return top_p_logits
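
# Illustrative usage (an assumption): keep the smallest set of tokens whose
# cumulative probability first exceeds top_p; everything else becomes -inf.
# Sketch assuming eager execution and a batch of one:
#
#   logits = tf.math.log([[0.4, 0.3, 0.2, 0.1]])
#   filtered = sample_top_p(logits, top_p=0.6)
#   # Tokens 0 and 1 stay finite (their cumulative probability, 0.7, is the
#   # first to exceed top_p); tokens 2 and 3 are set to -inf.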


def scatter_values_on_batch_indices(values, batch_indices):
  """Scatters `values` into a tensor using `batch_indices`.

  Args:
    values: tensor of shape [batch_size, vocab_size] containing the values to
      scatter.
    batch_indices: tensor of shape [batch_size, vocab_size] containing the
      indices to insert (should be a permutation in range(0, n)).

  Returns:
    Tensor of shape [batch_size, vocab_size] with values inserted at
    batch_indices.
  """
  tensor_shape = decoding_module.shape_list(batch_indices)
  broad_casted_batch_dims = tf.reshape(
      tf.broadcast_to(
          tf.expand_dims(tf.range(tensor_shape[0]), axis=-1), tensor_shape),
      [1, -1])
  pair_indices = tf.transpose(
      tf.concat([broad_casted_batch_dims,
                 tf.reshape(batch_indices, [1, -1])], 0))
  return tf.scatter_nd(pair_indices, tf.reshape(values, [-1]), tensor_shape)
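
# Illustrative usage (an assumption): undo a per-row sort by scattering values
# back to their original column positions. Sketch assuming eager execution:
#
#   values = tf.constant([[10, 20, 30]])
#   batch_indices = tf.constant([[2, 0, 1]])   # a permutation per row
#   scatter_values_on_batch_indices(values, batch_indices)
#   # == [[20, 30, 10]]  (values[0][j] lands at column batch_indices[0][j])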


def set_tensor_by_indices_to_value(input_tensor, indices, value):
  """Where indices is True, sets the value in input_tensor to value.

  Args:
    input_tensor: float (batch_size, dim)
    indices: bool (batch_size, dim)
    value: float scalar

  Returns:
    output_tensor: same shape as input_tensor.
  """
  value_tensor = tf.zeros_like(input_tensor) + value
  output_tensor = tf.where(indices, value_tensor, input_tensor)
  return output_tensor
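
# Illustrative usage (an assumption): a thin wrapper over `tf.where`. Sketch:
#
#   x = tf.constant([[1.0, 2.0]])
#   mask = tf.constant([[True, False]])
#   set_tensor_by_indices_to_value(x, mask, 0.0)   # == [[0.0, 2.0]]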


class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
  """Implementation for sampling strategies (go/decoding-tf-nlp)."""

  def __init__(self,
               symbols_to_logits_fn,
               vocab_size: int,
               max_decode_length: int,
               eos_id: int,
               padded_decode: bool,
               length_normalization_fn: Optional[Callable[[int, tf.DType],
                                                          float]] = None,
               top_k=0,
               top_p=1.0,
               sample_temperature=0.0,
               enable_greedy: bool = True,
               dtype: tf.DType = tf.float32,
               decoding_name: Optional[str] = None,
               extra_cache_output: bool = False):
    """Initializes the sampling module."""
    self.symbols_to_logits_fn = symbols_to_logits_fn
    self.length_normalization_fn = length_normalization_fn
    self.eos_id = eos_id
    self.padded_decode = padded_decode
    self.dtype = tf.as_dtype(dtype)
    self.vocab_size = tf.convert_to_tensor(vocab_size, dtype=tf.int32)
    self.max_decode_length = max_decode_length
    self.top_k = tf.convert_to_tensor(top_k, dtype=tf.int32)
    self.top_p = tf.convert_to_tensor(top_p, dtype=tf.float32)
    self.sample_temperature = tf.convert_to_tensor(
        sample_temperature, dtype=tf.float32)
    self.enable_greedy = enable_greedy
    self.decoding_name = decoding_name
    self.extra_cache_output = extra_cache_output
    super(SamplingModule, self).__init__(
        length_normalization_fn=length_normalization_fn,
        dtype=dtype,
        decoding_name=decoding_name,
        extra_cache_output=extra_cache_output)

  def _grow_alive_seq(self,
                      state: Dict[str, Any],
                      batch_size: int) -> decoding_module.InternalState:
    """Grows alive sequences by one token.

    This function implements the decoding strategies (top_p, top_k and
    greedy) for choosing the next token.

    Args:
      state: A dictionary with the current loop state.
      batch_size: The given batch size.

    Returns:
      Tuple of
      (Top sequences [batch, curr_index + 1] or [batch, max_decode_length + 1],
       Scores of returned sequences [batch, 1],
       New ids [batch, 1],
       New alive cache)
    """
    i = state[decoding_module.StateKeys.CUR_INDEX]
    alive_seq = state[decoding_module.StateKeys.ALIVE_SEQ]
    alive_log_probs = state[decoding_module.StateKeys.ALIVE_LOG_PROBS]
    alive_cache = state[decoding_module.StateKeys.ALIVE_CACHE]

    if self.padded_decode:
      ids = tf.slice(alive_seq, [0, i], [batch_size, 1])
    else:
      ids = alive_seq

    new_logits, new_cache = self.symbols_to_logits_fn(ids, i, alive_cache)
    candidate_log_probs = decoding_module.log_prob_from_logits(new_logits)
    original_log_probs = candidate_log_probs + alive_log_probs

    topk_log_probs, topk_ids = None, None
    if self.enable_greedy:
      topk_log_probs, topk_ids = greedy(original_log_probs)
    else:
      temperature_fn = sample_logits_with_temperature
      sampled_logits = tf.cond(
          self.sample_temperature > 0.0,
          lambda: temperature_fn(new_logits, self.sample_temperature),
          lambda: new_logits)
      sampled_logits = tf.cond(
          self.top_k > 0,
          lambda: sample_top_k(sampled_logits, self.top_k),
          lambda: sampled_logits)
      sampled_logits = tf.cond(
          self.top_p < 1,
          lambda: sample_top_p(sampled_logits, self.top_p),
          lambda: sampled_logits)
      topk_ids = tf.random.categorical(
          sampled_logits, dtype=tf.int32, num_samples=1)
      topk_log_probs = tf.gather(
          original_log_probs, topk_ids, axis=1, batch_dims=1)

    if self.padded_decode:
      topk_seq = tf.transpose(alive_seq, perm=[1, 0])
      topk_seq = tf.tensor_scatter_nd_update(
          topk_seq, [[i + 1]], tf.expand_dims(tf.squeeze(topk_ids, -1), 0))
      topk_seq = tf.transpose(topk_seq, perm=[1, 0])
    else:
      topk_seq = tf.concat([alive_seq, topk_ids], axis=-1)
    return topk_seq, topk_log_probs, topk_ids, new_cache

  def _create_initial_state(
      self,
      initial_ids: tf.Tensor,
      initial_cache: Dict[str, tf.Tensor],
      batch_size: int,
      initial_log_probs: Optional[tf.Tensor] = None
  ) -> decoding_module.InitialState:
    """Returns the initial state dictionary and its shape invariants."""
    for key, value in initial_cache.items():
      for inner_value in tf.nest.flatten(value):
        if inner_value.dtype != self.dtype:
          raise TypeError(
              "initial_cache element for key '%s' has dtype %s that does not "
              "match sampling_module's dtype of %s. Value: %s" %
              (key, inner_value.dtype.name, self.dtype.name, inner_value))
    # Current loop index (starts at 0).
    cur_index = tf.constant(0)

    # Alive sequence with shape [batch_size, 1].
    alive_seq = initial_ids
    alive_seq = tf.expand_dims(alive_seq, axis=-1)
    if self.padded_decode:
      alive_seq = tf.tile(alive_seq, [1, self.max_decode_length + 1])

    # Initial log probabilities with shape [batch_size, 1].
    if initial_log_probs is None:
      initial_log_probs = tf.constant([[0.]], dtype=self.dtype)
      alive_log_probs = tf.tile(initial_log_probs, [batch_size, 1])
    else:
      alive_log_probs = initial_log_probs
    alive_cache = initial_cache

    # Initialize tensor storing finished sequences [batch_size, 1, 1].
    finished_seq = tf.zeros(tf.shape(alive_seq), tf.int32)

    # Set scores of the initial finished seqs to negative infinity.
    finished_scores = tf.zeros([batch_size, 1], dtype=self.dtype)

    # Initialize finished flags with all False values.
    finished_flags = tf.zeros([batch_size, 1], tf.bool)

    # Create state dictionary and state shapes.
    state = {
        decoding_module.StateKeys.CUR_INDEX: cur_index,
        decoding_module.StateKeys.ALIVE_SEQ: alive_seq,
        decoding_module.StateKeys.ALIVE_LOG_PROBS: alive_log_probs,
        decoding_module.StateKeys.ALIVE_CACHE: alive_cache,
        decoding_module.StateKeys.FINISHED_SEQ: finished_seq,
        decoding_module.StateKeys.FINISHED_SCORES: finished_scores,
        decoding_module.StateKeys.FINISHED_FLAGS: finished_flags
    }

    if self.padded_decode:
      state_shape_invariants = {
          decoding_module.StateKeys.CUR_INDEX:
              tf.TensorShape([]),
          decoding_module.StateKeys.ALIVE_SEQ:
              tf.TensorShape([batch_size, self.max_decode_length + 1]),
          decoding_module.StateKeys.ALIVE_LOG_PROBS:
              tf.TensorShape([batch_size, 1]),
          decoding_module.StateKeys.ALIVE_CACHE:
              tf.nest.map_structure(lambda state: state.get_shape(),
                                    alive_cache),
          decoding_module.StateKeys.FINISHED_SEQ:
              tf.TensorShape([batch_size, self.max_decode_length + 1]),
          decoding_module.StateKeys.FINISHED_SCORES:
              tf.TensorShape([batch_size, 1]),
          decoding_module.StateKeys.FINISHED_FLAGS:
              tf.TensorShape([batch_size, 1])
      }
    else:
      state_shape_invariants = {
          decoding_module.StateKeys.CUR_INDEX:
              tf.TensorShape([]),
          decoding_module.StateKeys.ALIVE_SEQ:
              tf.TensorShape([None, None]),
          decoding_module.StateKeys.ALIVE_LOG_PROBS:
              tf.TensorShape([None, 1]),
          decoding_module.StateKeys.ALIVE_CACHE:
              tf.nest.map_structure(decoding_module.get_shape_keep_last_dim,
                                    alive_cache),
          decoding_module.StateKeys.FINISHED_SEQ:
              tf.TensorShape([None, None]),
          decoding_module.StateKeys.FINISHED_SCORES:
              tf.TensorShape([None, 1]),
          decoding_module.StateKeys.FINISHED_FLAGS:
              tf.TensorShape([None, 1])
      }

    if self.extra_cache_output:
      state.update(
          {decoding_module.StateKeys.INITIAL_OUTPUT_CACHE: alive_cache})
      if self.padded_decode:
        state_shape_invariants.update({
            decoding_module.StateKeys.INITIAL_OUTPUT_CACHE:
                tf.nest.map_structure(lambda state: state.get_shape(),
                                      alive_cache)
        })
      else:
        state_shape_invariants.update({
            decoding_module.StateKeys.INITIAL_OUTPUT_CACHE:
                tf.nest.map_structure(decoding_module.get_shape_keep_last_dim,
                                      alive_cache),
        })
    return state, state_shape_invariants

  def _get_new_alive_state(self, new_seq: tf.Tensor, new_log_probs: tf.Tensor,
                           new_finished_flags: tf.Tensor,
                           new_cache: Dict[str, tf.Tensor]) -> Dict[str, Any]:
    """Gathers the sequences that are still alive.

    This function resets the sequences in the alive state that have finished.

    Args:
      new_seq: New sequences generated by growing the current alive sequences,
        int32 tensor with shape [batch_size, cur_index + 1].
      new_log_probs: Log probabilities of new sequences, float32 tensor with
        shape [batch_size, 1].
      new_finished_flags: A boolean tensor indicating which sequences have
        just finished (reached EOS).
      new_cache: Dict of cached values for each sequence.

    Returns:
      Dictionary with alive keys.
    """
    new_seq = tf.multiply(
        new_seq, tf.cast(tf.logical_not(new_finished_flags), new_seq.dtype))
    return {
        decoding_module.StateKeys.ALIVE_SEQ: new_seq,
        decoding_module.StateKeys.ALIVE_LOG_PROBS: new_log_probs,
        decoding_module.StateKeys.ALIVE_CACHE: new_cache
    }

  def _get_new_finished_state(self, state: Dict[str, Any], new_seq: tf.Tensor,
                              new_log_probs: tf.Tensor,
                              new_finished_flags: tf.Tensor,
                              batch_size: int) -> Dict[str, tf.Tensor]:
    """Combines new and old finished sequences.

    Args:
      state: A dictionary with the current loop state.
      new_seq: New sequences generated by growing the current alive sequences,
        int32 tensor [batch, curr_index + 1] or [batch, max_decode_length + 1].
      new_log_probs: Log probabilities of new sequences, float32 tensor with
        shape [batch, 1].
      new_finished_flags: A boolean tensor indicating which sequences have
        just finished (reached EOS).
      batch_size: The given batch size.

    Returns:
      Dictionary with finished keys from StateKeys.
    """
    i = state[decoding_module.StateKeys.CUR_INDEX]
    finished_seq = state[decoding_module.StateKeys.FINISHED_SEQ]
    finished_scores = state[decoding_module.StateKeys.FINISHED_SCORES]
    finished_flags = state[decoding_module.StateKeys.FINISHED_FLAGS]
    if not self.padded_decode:
      finished_seq = tf.concat(
          [finished_seq, tf.zeros([batch_size, 1], tf.int32)], axis=-1)
    new_scores = new_log_probs
    if self.length_normalization_fn is not None:
      length_norm = self.length_normalization_fn(i + 1, self.dtype)
      new_scores = new_log_probs / length_norm
    new_seq = tf.multiply(
        new_seq, tf.cast(tf.logical_not(finished_flags), new_seq.dtype))
    new_scores = tf.multiply(
        new_scores, tf.cast(tf.logical_not(finished_flags), new_scores.dtype))
    finished_seq += tf.multiply(new_seq,
                                tf.cast(new_finished_flags, new_seq.dtype))
    finished_scores += tf.multiply(
        new_scores, tf.cast(new_finished_flags, new_scores.dtype))
    new_finished_flags = tf.logical_or(new_finished_flags, finished_flags)
    return {
        decoding_module.StateKeys.FINISHED_SEQ: finished_seq,
        decoding_module.StateKeys.FINISHED_SCORES: finished_scores,
        decoding_module.StateKeys.FINISHED_FLAGS: new_finished_flags
    }

  def _process_finished_state(
      self, finished_state: Dict[str, Any]) -> decoding_module.Output:
    """Processes the alive/finished state to return final sequences and scores."""
    alive_seq = finished_state[decoding_module.StateKeys.ALIVE_SEQ]
    alive_log_probs = finished_state[decoding_module.StateKeys.ALIVE_LOG_PROBS]
    finished_seq = finished_state[decoding_module.StateKeys.FINISHED_SEQ]
    finished_scores = finished_state[decoding_module.StateKeys.FINISHED_SCORES]
    finished_flags = finished_state[decoding_module.StateKeys.FINISHED_FLAGS]
    finished_cond = tf.reduce_any(finished_flags, 1, name="finished_cond")
    if self.length_normalization_fn is not None:
      length_norm = self.length_normalization_fn(self.max_decode_length + 1,
                                                 self.dtype)
      alive_log_probs = alive_log_probs / length_norm
    seq_cond = decoding_module.expand_to_same_rank(finished_cond, finished_seq)
    score_cond = decoding_module.expand_to_same_rank(finished_cond,
                                                     finished_scores)
    finished_seq = tf.where(seq_cond, finished_seq, alive_seq)
    finished_scores = tf.where(score_cond, finished_scores, alive_log_probs)
    if self.extra_cache_output:
      return finished_seq, finished_scores, finished_state[
          decoding_module.StateKeys.INITIAL_OUTPUT_CACHE]
    return finished_seq, finished_scores

  def _continue_search(self, state) -> tf.Tensor:
    i = state[decoding_module.StateKeys.CUR_INDEX]
    # Have we reached max decoding length?
    not_at_end = tf.less(i, self.max_decode_length)
    # Have all sampled sequences reached an EOS?
    all_has_eos = tf.reduce_all(
        state[decoding_module.StateKeys.FINISHED_FLAGS],
        axis=None,
        name="search_finish_cond")
    return tf.logical_and(not_at_end, tf.logical_not(all_has_eos))

  def _finished_flags(self, topk_ids, state) -> tf.Tensor:
    new_finished_flags = tf.equal(topk_ids, self.eos_id)
    new_finished_flags = tf.logical_or(
        new_finished_flags, state[decoding_module.StateKeys.FINISHED_FLAGS])
    return new_finished_flags
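

# Illustrative end-to-end sketch (an assumption, not part of this file): a
# SamplingModule is typically constructed with a model-specific
# `symbols_to_logits_fn(ids, index, cache)` that returns (logits, updated
# cache), and decoding is then driven through the entry point inherited from
# decoding_module.DecodingModule (e.g. its `generate` method in the Model
# Garden decoding tutorial). A hypothetical top-k setup might look like:
#
#   sampler = SamplingModule(
#       symbols_to_logits_fn=my_symbols_to_logits_fn,  # hypothetical callable
#       vocab_size=32000,
#       max_decode_length=64,
#       eos_id=2,
#       padded_decode=False,
#       enable_greedy=False,
#       top_k=10,
#       sample_temperature=0.7)
#   ids, scores = sampler.generate(
#       initial_ids=start_ids, initial_cache=empty_cache)  # hypothetical inputs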