"""Base class for Decoding Strategies (beam_search, top_k, top_p and greedy).""" |
|
|
|
import abc |
|
from typing import Any, Callable, Dict, Optional, Tuple |
|
|
|
import tensorflow as tf, tf_keras |
|
|
|
from tensorflow.python.framework import dtypes |
|
from official.modeling import tf_utils |
|
|
|
Output = Tuple[tf.Tensor, tf.Tensor, Optional[tf.Tensor]] |
|
InternalState = Tuple[tf.Tensor, tf.Tensor, tf.Tensor, Dict] |
|
InitialState = Tuple[Dict[str, Any], Dict[str, Any]] |


class StateKeys:
  """Keys to the dictionary storing the state of the decoding loop."""

  # Variable storing the loop index.
  CUR_INDEX = "CUR_INDEX"

  # Top sequences that are still alive for each batch item. Alive sequences
  # are ones that have not yet generated an EOS token.
  ALIVE_SEQ = "ALIVE_SEQ"

  # Log probabilities of each alive sequence.
  ALIVE_LOG_PROBS = "ALIVE_LOG_PROBS"

  # Dictionary of cached values for each alive sequence, e.g. model outputs
  # carried over from the previous decoding step.
  ALIVE_CACHE = "ALIVE_CACHE"

  # The cache produced by the first decoding step; tracked only when
  # `extra_cache_output` is True.
  INITIAL_OUTPUT_CACHE = "INITIAL_OUTPUT_CACHE"

  # Top finished sequences for each batch item.
  FINISHED_SEQ = "FINISHED_SEQ"

  # Scores for each finished sequence.
  FINISHED_SCORES = "FINISHED_SCORES"

  # Flags indicating which sequences in the finished set are finished.
  FINISHED_FLAGS = "FINISHED_FLAGS"


def log_prob_from_logits(logits):
  """Converts logits to normalized log probabilities (log softmax)."""
  return logits - tf.reduce_logsumexp(logits, axis=-1, keepdims=True)
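

# A minimal usage sketch of `log_prob_from_logits` (values illustrative): the
# result matches tf.nn.log_softmax, so each exponentiated row sums to one.
#
#   logits = tf.constant([[2.0, 1.0, 0.1]])
#   log_probs = log_prob_from_logits(logits)
#   # tf.reduce_sum(tf.exp(log_probs), axis=-1) ~= 1.0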


def shape_list(tensor):
  """Returns a list of the tensor's shape, ensuring no None values."""
  return tf_utils.get_shape_list(tensor)


def get_shape_keep_last_dim(tensor):
  """Returns tensor's shape with only a statically-known last dim kept."""
  shape_list_obj = shape_list(tensor)
  for i in range(len(shape_list_obj) - 1):
    shape_list_obj[i] = None

  if isinstance(shape_list_obj[-1], tf.Tensor):
    shape_list_obj[-1] = None
  return tf.TensorShape(shape_list_obj)
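

# A short sketch of what this yields (shapes illustrative): every dimension
# except a statically-known last one becomes None, which is the
# shape-invariant form consumed by tf.while_loop in `DecodingModule.generate`
# below.
#
#   x = tf.zeros([2, 3, 4])
#   get_shape_keep_last_dim(x)  # TensorShape([None, None, 4])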


def expand_to_same_rank(tensor, target):
  """Expands a given tensor to target's rank to be broadcastable.

  Args:
    tensor: input tensor to tile. Shape: [b, d1, ..., da]
    target: target tensor. Shape: [b, d1, ..., da, ..., dn]

  Returns:
    Tiled tensor of shape [b, d1, ..., da, 1, ..., 1] with the same rank as
    the target.

  Raises:
    ValueError: if the rank of `tensor` or `target` is None.
  """
  if tensor.shape.rank is None:
    raise ValueError("Expect rank for tensor shape, but got None.")
  if target.shape.rank is None:
    raise ValueError("Expect rank for target shape, but got None.")

  with tf.name_scope("expand_rank"):
    diff_rank = target.shape.rank - tensor.shape.rank
    for _ in range(diff_rank):
      tensor = tf.expand_dims(tensor, -1)
    return tensor
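

# A hedged example of the broadcasting this enables (shapes illustrative):
#
#   per_seq_score = tf.zeros([4, 8])         # [batch, beam]
#   logits = tf.zeros([4, 8, 32000])         # [batch, beam, vocab]
#   expanded = expand_to_same_rank(per_seq_score, logits)  # [4, 8, 1]
#   combined = logits + expanded             # broadcasts over the vocab dim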


class DecodingModule(tf.Module, metaclass=abc.ABCMeta):
  """A base class for the API required for decoding (go/decoding-tf-nlp)."""

  def __init__(self,
               length_normalization_fn: Callable[[int, tf.DType], float],
               dtype: tf.DType = tf.float32,
               decoding_name: Optional[str] = None,
               extra_cache_output: bool = False):
    """Initializes the decoding module.

    Args:
      length_normalization_fn: Closure for returning the length normalization
        parameter. The function takes a length and a dtype and returns a
        float.
      dtype: A tensorflow data type used for score computation. The default is
        tf.float32.
      decoding_name: An optional name for the decoding loop tensors.
      extra_cache_output: If True, the cache from the first decoding step is
        kept in the loop state and returned alongside the sequences.
    """
    self.length_normalization_fn = length_normalization_fn
    self.dtype = tf.as_dtype(dtype)
    self.decoding_name = decoding_name
    self.extra_cache_output = extra_cache_output

  def generate(self,
               initial_ids: tf.Tensor,
               initial_cache: Dict[str, tf.Tensor],
               initial_log_probs: Optional[tf.Tensor] = None) -> Output:
    """Implements the decoding strategy (beam_search or sampling).

    Args:
      initial_ids: Initial ids to pass into the symbols_to_logits_fn; an int
        tensor with shape [batch_size, 1].
      initial_cache: Dictionary for caching model outputs from the previous
        step.
      initial_log_probs: Optional initial log probs, used when decoding
        continues from a prefix sequence.

    Returns:
      Tuple of tensors representing
        finished_sequence: shape [batch, max_seq_length]
        finished_scores: [batch]
        first_cache: the cache right after the initial token.
    """
    # `padded_decode` is expected to be set by the concrete subclass.
    batch_size = (
        initial_ids.shape.as_list()[0]
        if self.padded_decode else tf.shape(initial_ids)[0])

    state, state_shapes = self._create_initial_state(initial_ids,
                                                     initial_cache,
                                                     batch_size,
                                                     initial_log_probs)

    def _generate_step(state):
      topk_seq, topk_log_probs, topk_ids, new_cache = self._grow_alive_seq(
          state, batch_size)
      new_finished_flags = self._finished_flags(topk_ids, state)
      alive_state = self._get_new_alive_state(topk_seq,
                                              topk_log_probs,
                                              new_finished_flags,
                                              new_cache)
      finished_state = self._get_new_finished_state(state,
                                                    topk_seq,
                                                    topk_log_probs,
                                                    new_finished_flags,
                                                    batch_size)
      new_state = {StateKeys.CUR_INDEX: state[StateKeys.CUR_INDEX] + 1}
      new_state.update(alive_state)
      new_state.update(finished_state)
      if self.extra_cache_output:
        i = state[StateKeys.CUR_INDEX]
        old_cache = state[StateKeys.INITIAL_OUTPUT_CACHE]

        def update_with_cache(new_state, cache):
          """Updates new_state with the given cache."""
          new_state.update({StateKeys.INITIAL_OUTPUT_CACHE: cache})

        # Keep the cache produced by the very first step (index 0); on later
        # steps, carry the stored copy forward unchanged.
        tf.cond(
            tf.equal(i, 0), lambda: update_with_cache(new_state, new_cache),
            lambda: update_with_cache(new_state, old_cache))
      return [new_state]

    finished_state = tf.nest.map_structure(
        tf.stop_gradient,
        tf.while_loop(
            self._continue_search,
            _generate_step,
            loop_vars=[state],
            shape_invariants=[state_shapes],
            parallel_iterations=1,
            name=self.decoding_name))
    final_state = self._process_finished_state(finished_state[0])
    return final_state
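
  # A hedged usage sketch. `MyBeamSearch` is a hypothetical concrete subclass
  # that implements the abstract methods below and sets `padded_decode`;
  # batch_size and hidden_size are illustrative placeholders:
  #
  #   decoder = MyBeamSearch(length_normalization_fn=lambda length, dtype: 1.0)
  #   seqs, scores, first_cache = decoder.generate(
  #       initial_ids=tf.zeros([batch_size, 1], tf.int32),
  #       initial_cache={"layer_0": tf.zeros([batch_size, 0, hidden_size])})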

  @abc.abstractmethod
  def _create_initial_state(
      self,
      initial_ids: tf.Tensor,
      initial_cache: Dict[str, tf.Tensor],
      batch_size: int,
      initial_log_probs: Optional[tf.Tensor] = None) -> InitialState:
    """Returns the initial state dictionary and its shape invariants."""
    pass

  @abc.abstractmethod
  def _grow_alive_seq(self,
                      state: Dict[str, Any],
                      batch_size: int) -> InternalState:
    """Grows alive sequences by one token.

    Args:
      state: A dictionary with the current loop state.
      batch_size: The given batch size.

    Returns:
      Tuple of
        (Top sequences,
         Scores of returned sequences,
         New ids,
         New alive cache)
    """
    pass

  @abc.abstractmethod
  def _get_new_alive_state(
      self,
      new_seq: tf.Tensor,
      new_log_probs: tf.Tensor,
      new_finished_flags: tf.Tensor,
      new_cache: Dict[str, tf.Tensor]) -> Dict[str, Any]:
    """Gathers the sequences that are still alive.

    Args:
      new_seq: New sequences generated by growing the current alive sequences;
        an int32 tensor.
      new_log_probs: Log probabilities of the new sequences; a float32 tensor.
      new_finished_flags: A boolean tensor indicating which of the new
        sequences have finished.
      new_cache: Dict of cached values for each sequence.

    Returns:
      Dictionary with alive keys from StateKeys.
    """
    pass

  @abc.abstractmethod
  def _get_new_finished_state(self,
                              state: Dict[str, Any],
                              new_seq: tf.Tensor,
                              new_log_probs: tf.Tensor,
                              new_finished_flags: tf.Tensor,
                              batch_size: int) -> Dict[str, tf.Tensor]:
    """Combines new and old finished sequences.

    Args:
      state: A dictionary with the current loop state.
      new_seq: New sequences generated by growing the current alive sequences;
        an int32 tensor.
      new_log_probs: Log probabilities of the new sequences; a float32 tensor.
      new_finished_flags: A boolean tensor indicating which of the new
        sequences have finished.
      batch_size: The given batch size.

    Returns:
      Dictionary with finished keys from StateKeys.
    """
    pass

  @abc.abstractmethod
  def _process_finished_state(self, finished_state: Dict[str, Any]) -> Output:
    """Processes the alive/finished state into final sequences and scores."""
    pass

  @abc.abstractmethod
  def _continue_search(self, state: Dict[str, Any]) -> tf.Tensor:
    """Returns a bool tensor indicating whether decoding should continue."""
    pass

  @abc.abstractmethod
  def _finished_flags(self,
                      topk_ids: tf.Tensor,
                      state: Dict[str, Any]) -> tf.Tensor:
    """Calculates the finished flags."""
    pass

  def inf(self):
    """Returns a value close to infinity, but still finite in `dtype`.

    This is useful to get a very large value that is still zero when
    multiplied by zero. The floating-point "Inf" value is NaN when multiplied
    by zero.

    Returns:
      A very large value.
    """
    if self.dtype == dtypes.float32 or self.dtype == dtypes.bfloat16:
      return 1e7
    elif self.dtype == dtypes.float16:
      return dtypes.float16.max
    else:
      raise AssertionError("Invalid dtype: %s" % self.dtype)
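

# A brief illustration of why `inf()` returns a finite sentinel rather than
# float("inf") (`module` is any concrete DecodingModule instance):
#
#   0.0 * module.inf()    # == 0.0, so masked scores stay well defined
#   0.0 * float("inf")    # == nan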