|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" Flax IndicTrans model.""" |
|
|
|
|
|
import math
import random
|
from typing import List, Optional, Tuple, Union, Callable |
|
from functools import partial |
|
|
|
import flax.linen as nn |
|
import jax |
|
import jax.numpy as jnp |
|
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze |
|
from flax.linen import combine_masks, make_causal_mask |
|
from flax.linen.attention import dot_product_attention_weights |
|
from flax.traverse_util import flatten_dict, unflatten_dict |
|
from jax import lax |
|
from jax.random import PRNGKey |
|
|
|
from transformers.modeling_flax_outputs import ( |
|
FlaxBaseModelOutput, |
|
FlaxBaseModelOutputWithPastAndCrossAttentions, |
|
FlaxCausalLMOutputWithCrossAttentions, |
|
FlaxSeq2SeqLMOutput, |
|
FlaxSeq2SeqModelOutput, |
|
) |
|
from transformers.modeling_flax_utils import ( |
|
ACT2FN, |
|
FlaxPreTrainedModel, |
|
append_call_sample_docstring, |
|
append_replace_return_docstrings, |
|
overwrite_call_docstring, |
|
) |
|
from .configuration_indictrans import IndicTransConfig |
|
from transformers.utils import logging |
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
_CONFIG_FOR_DOC = "IndicTransConfig" |
|
|
|
INDICTRANS_PRETRAINED_MODEL_ARCHIVE_LIST = [""] |
|
|
|
def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: |
|
""" |
|
Shift input ids one token to the right. |
|
""" |
|
shifted_input_ids = jnp.zeros_like(input_ids) |
|
shifted_input_ids = shifted_input_ids.at[:, 1:].set(input_ids[:, :-1]) |
|
shifted_input_ids = shifted_input_ids.at[:, 0].set(decoder_start_token_id) |
|
|
|
if pad_token_id is None: |
|
raise ValueError("self.model.config.pad_token_id has to be defined.") |
|
|
|
shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids) |
|
|
|
return shifted_input_ids |
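# For example (illustrative ids, assuming pad_token_id=1 and decoder_start_token_id=2):
#     shift_tokens_right(jnp.array([[5, 6, 7, 1]]), 1, 2)  ->  [[2, 5, 6, 7]]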
|
|
|
|
|
class FlaxIndicTransSinusoidalPositionalEmbedding(nn.Module): |
|
"""This module produces sinusoidal positional embeddings of any length.""" |
|
num_positions: int |
|
embedding_dim: int |
|
padding_idx: Optional[int] = None |
|
|
|
|
|
|
|
offset: int = 2 |
|
|
|
def setup(self) -> None: |
|
self.weights = self._make_weights(self.num_positions + self.offset, self.embedding_dim, padding_idx=self.padding_idx) |
|
|
|
def _make_weights( |
|
self, |
|
num_embeddings: int, |
|
embedding_dim: int, |
|
        existing_weights: Optional[jnp.ndarray] = None,
|
padding_idx: Optional[int] = None |
|
): |
|
emb_weights = self._get_embedding(num_embeddings, embedding_dim, padding_idx) |
|
|
|
if existing_weights is not None: |
|
|
|
emb_weights = emb_weights.astype(existing_weights.dtype) |
|
|
|
return emb_weights |
|
|
|
def _get_embedding( |
|
self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None |
|
): |
|
""" |
|
Build sinusoidal embeddings. |
|
This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of |
|
"Attention Is All You Need". |
|
""" |
|
half_dim = embedding_dim // 2 |
|
emb = math.log(10000) / (half_dim - 1) |
|
emb = jnp.exp(-emb * jnp.arange(half_dim, dtype=jnp.float32)) |
|
emb = jnp.arange(num_embeddings, dtype=jnp.float32).reshape(-1, 1) * emb.reshape(1, -1) |
|
emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=1).reshape(num_embeddings, -1) |
|
|
|
if embedding_dim % 2 == 1: |
|
|
|
emb = jnp.concatenate([emb, jnp.zeros((num_embeddings, 1), dtype=emb.dtype)], axis=1) |
|
|
|
if padding_idx is not None: |
|
emb = emb.at[padding_idx].set(0) |
|
|
|
return emb |
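    # Layout of the table returned by `_get_embedding`: a (num_embeddings, embedding_dim)
    # array whose first half_dim columns hold the sine terms and whose remaining columns
    # hold the matching cosine terms (tensor2tensor layout), with the `padding_idx` row zeroed.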
|
|
|
def __call__( |
|
self, |
|
        input_ids: jnp.ndarray = None,

        inputs_embeds: jnp.ndarray = None,
|
past_key_values_length: int = 0 |
|
): |
|
if input_ids is not None: |
|
bsz, seq_len = input_ids.shape |
|
position_ids = self._create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length) |
|
else: |
|
bsz, seq_len = inputs_embeds.shape[:-1] |
|
position_ids = self._create_position_ids_from_inputs_embeds(inputs_embeds, past_key_values_length) |
|
|
|
|
|
        # Expand the precomputed table on the fly if the sequence is longer than the table.
        # Flax modules are frozen outside of `setup`, so keep the extended table local
        # instead of writing it back to `self.weights`.
        weights = self.weights
        max_pos = self.padding_idx + 1 + seq_len + past_key_values_length
        if max_pos > weights.shape[0]:
            weights = self._make_weights(max_pos + self.offset, self.embedding_dim, weights, self.padding_idx)

        return weights[position_ids.ravel()].reshape(bsz, seq_len, -1)
|
|
|
def _create_position_ids_from_input_ids( |
|
self, input_ids, padding_idx, past_key_values_length=0 |
|
): |
|
""" |
|
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols |
|
are ignored. This is a JAX conversion of the PyTorch function. |
|
""" |
|
mask = (input_ids != padding_idx) |
|
incremental_indices = (jnp.cumsum(mask, axis=1) + past_key_values_length) * mask |
|
return incremental_indices + padding_idx |
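    # For example, with padding_idx=1 and past_key_values_length=0:
    #     input_ids [[1, 7, 8, 1]]  ->  position_ids [[1, 2, 3, 1]]
    # (padding positions keep padding_idx as their "position").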
|
|
|
def _create_position_ids_from_inputs_embeds( |
|
self, inputs_embeds, past_key_values_length |
|
): |
|
""" |
|
Generate sequential position ids from input embeddings. |
|
Args: |
|
inputs_embeds: jnp.array (JAX array) |
|
past_key_values_length: int |
|
Returns: |
|
jnp.array: Position IDs corresponding to the inputs. |
|
""" |
|
input_shape = inputs_embeds.shape[:-1] |
|
sequence_length = input_shape[1] |
|
|
|
        # Use int32 since 64-bit integers are disabled in JAX by default.
        position_ids = jnp.arange(self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=jnp.int32)
|
return jnp.expand_dims(position_ids, axis=0).repeat(input_shape[0], axis=0) + past_key_values_length |
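    # Without token ids there is no way to tell which positions are padding, so the ids
    # are simply padding_idx+1, padding_idx+2, ..., shifted by past_key_values_length.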
|
|
|
|
|
class FlaxIndicTransAttention(nn.Module): |
|
config: IndicTransConfig |
|
embed_dim: int |
|
num_heads: int |
|
dropout: float = 0.0 |
|
causal: bool = False |
|
bias: bool = True |
|
dtype: jnp.dtype = jnp.float32 |
|
|
|
def setup(self) -> None: |
|
|
|
self.head_dim = self.embed_dim // self.num_heads |
|
if self.head_dim * self.num_heads != self.embed_dim: |
|
raise ValueError( |
|
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" |
|
f" and `num_heads`: {self.num_heads})." |
|
) |
|
|
|
|
|
|
|
|
|
|
|
dense = partial( |
|
nn.Dense, |
|
self.embed_dim, |
|
use_bias=self.bias, |
|
dtype=self.dtype, |
|
kernel_init=jax.nn.initializers.normal(self.config.init_std), |
|
) |
|
|
|
self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense() |
|
self.out_proj = dense() |
|
|
|
self.dropout_layer = nn.Dropout(rate=self.dropout) |
|
|
|
if self.causal: |
|
self.causal_mask = make_causal_mask( |
|
                jnp.ones((1, self.config.max_target_positions), dtype="bool"), dtype="bool"
|
) |
|
|
|
def _split_heads(self, hidden_states): |
|
return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim)) |
|
|
|
def _merge_heads(self, hidden_states): |
|
return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,)) |
|
|
|
@nn.compact |
|
def _concatenate_to_cache(self, key, value, query, attention_mask): |
|
""" |
|
This function takes projected key, value states from a single input token and concatenates the states to cached |
|
        states from previous steps. This function is slightly adapted from the official Flax repository:
|
https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 |
|
""" |
|
|
|
is_initialized = self.has_variable("cache", "cached_key") |
|
cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) |
|
cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) |
|
cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) |
|
|
|
if is_initialized: |
|
*batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape |
|
|
|
cur_index = cache_index.value |
|
indices = (0,) * len(batch_dims) + (cur_index, 0, 0) |
|
key = lax.dynamic_update_slice(cached_key.value, key, indices) |
|
value = lax.dynamic_update_slice(cached_value.value, value, indices) |
|
cached_key.value = key |
|
cached_value.value = value |
|
num_updated_cache_vectors = query.shape[1] |
|
cache_index.value = cache_index.value + num_updated_cache_vectors |
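            # Causal cache mask: a query written at the current step may only attend to
            # cache slots that have already been filled (index < cur_index + new tokens).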
|
|
|
pad_mask = jnp.broadcast_to( |
|
jnp.arange(max_length) < cur_index + num_updated_cache_vectors, |
|
tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), |
|
) |
|
attention_mask = combine_masks(pad_mask, attention_mask) |
|
return key, value, attention_mask |
|
|
|
def __call__( |
|
self, |
|
hidden_states: jnp.ndarray, |
|
key_value_states: Optional[jnp.ndarray] = None, |
|
attention_mask: Optional[jnp.ndarray] = None, |
|
init_cache: bool = False, |
|
deterministic: bool = True, |
|
) -> Tuple[jnp.ndarray]: |
|
"""Input shape: Batch x Time x Channel""" |
|
|
|
|
|
|
|
is_cross_attention = key_value_states is not None |
|
batch_size = hidden_states.shape[0] |
|
|
|
|
|
query_states = self.q_proj(hidden_states) |
|
|
|
if is_cross_attention: |
|
|
|
key_states = self.k_proj(key_value_states) |
|
value_states = self.v_proj(key_value_states) |
|
else: |
|
|
|
key_states = self.k_proj(hidden_states) |
|
value_states = self.v_proj(hidden_states) |
|
|
|
query_states = self._split_heads(query_states) |
|
key_states = self._split_heads(key_states) |
|
value_states = self._split_heads(value_states) |
|
|
|
|
|
if self.causal: |
|
query_length, key_length = query_states.shape[1], key_states.shape[1] |
|
if self.has_variable("cache", "cached_key"): |
|
mask_shift = self.variables["cache"]["cache_index"] |
|
max_decoder_length = self.variables["cache"]["cached_key"].shape[1] |
|
causal_mask = lax.dynamic_slice( |
|
self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) |
|
) |
|
else: |
|
causal_mask = self.causal_mask[:, :, :query_length, :key_length] |
|
causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) |
|
|
|
|
|
if attention_mask is not None and self.causal: |
|
attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) |
|
attention_mask = combine_masks(attention_mask, causal_mask) |
|
elif self.causal: |
|
attention_mask = causal_mask |
|
elif attention_mask is not None: |
|
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) |
|
|
|
|
|
|
|
if self.causal and (self.has_variable("cache", "cached_key") or init_cache): |
|
key_states, value_states, attention_mask = self._concatenate_to_cache( |
|
key_states, value_states, query_states, attention_mask |
|
) |
|
|
|
|
|
if attention_mask is not None: |
|
|
|
attention_bias = lax.select( |
|
attention_mask > 0, |
|
jnp.full(attention_mask.shape, 0.0).astype(self.dtype), |
|
jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), |
|
) |
|
else: |
|
attention_bias = None |
|
|
|
dropout_rng = None |
|
if not deterministic and self.dropout > 0.0: |
|
dropout_rng = self.make_rng("dropout") |
|
|
|
attn_weights = dot_product_attention_weights( |
|
query_states, |
|
key_states, |
|
bias=attention_bias, |
|
dropout_rng=dropout_rng, |
|
dropout_rate=self.dropout, |
|
broadcast_dropout=True, |
|
deterministic=deterministic, |
|
dtype=self.dtype, |
|
precision="high", |
|
) |
|
|
|
attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) |
|
attn_output = self._merge_heads(attn_output) |
|
attn_output = self.out_proj(attn_output) |
|
|
|
return attn_output, attn_weights |
|
|
|
|
|
class FlaxIndicTransEncoderLayer(nn.Module): |
|
config: IndicTransConfig |
|
dtype: jnp.dtype = jnp.float32 |
|
|
|
def setup(self) -> None: |
|
self.embed_dim = self.config.encoder_embed_dim |
|
self.self_attn = FlaxIndicTransAttention( |
|
config=self.config, |
|
embed_dim=self.embed_dim, |
|
num_heads=self.config.encoder_attention_heads, |
|
dropout=self.config.attention_dropout, |
|
dtype=self.dtype, |
|
) |
|
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) |
|
self.dropout_layer = nn.Dropout(rate=self.config.dropout) |
|
self.activation_fn = ACT2FN[self.config.activation_function] |
|
self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout) |
|
self.fc1 = nn.Dense( |
|
self.config.encoder_ffn_dim, |
|
dtype=self.dtype, |
|
kernel_init=jax.nn.initializers.normal(self.config.init_std), |
|
) |
|
self.fc2 = nn.Dense( |
|
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) |
|
) |
|
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) |
|
self.normalize_before = self.config.encoder_normalize_before |
|
|
|
def __call__( |
|
self, |
|
hidden_states: jnp.ndarray, |
|
attention_mask: jnp.ndarray, |
|
output_attentions: bool = True, |
|
deterministic: bool = True, |
|
) -> Tuple[jnp.ndarray]: |
|
residual = hidden_states |
|
if self.normalize_before: |
|
hidden_states = self.self_attn_layer_norm(hidden_states) |
|
hidden_states, attn_weights = self.self_attn( |
|
hidden_states=hidden_states, |
|
attention_mask=attention_mask |
|
|
|
) |
|
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) |
|
hidden_states = residual + hidden_states |
|
if not self.normalize_before: |
|
            hidden_states = self.self_attn_layer_norm(hidden_states)
|
|
|
residual = hidden_states |
|
if self.normalize_before: |
|
hidden_states = self.final_layer_norm(hidden_states) |
|
hidden_states = self.activation_fn(self.fc1(hidden_states)) |
|
hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic) |
|
hidden_states = self.fc2(hidden_states) |
|
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) |
|
hidden_states = residual + hidden_states |
|
if not self.normalize_before: |
|
hidden_states = self.final_layer_norm(hidden_states) |
|
|
|
outputs = (hidden_states,) |
|
|
|
if output_attentions: |
|
outputs += (attn_weights,) |
|
|
|
return outputs |
|
|
|
|
|
class FlaxIndicTransEncoderLayerCollection(nn.Module): |
|
config: IndicTransConfig |
|
dtype: jnp.dtype = jnp.float32 |
|
|
|
def setup(self): |
|
self.layers = [ |
|
FlaxIndicTransEncoderLayer(self.config, name=str(i), dtype=self.dtype) |
|
for i in range(self.config.encoder_layers) |
|
] |
|
self.layerdrop = self.config.encoder_layerdrop |
|
|
|
def __call__( |
|
self, |
|
hidden_states, |
|
attention_mask, |
|
deterministic: bool = True, |
|
output_attentions: bool = False, |
|
output_hidden_states: bool = False, |
|
return_dict: bool = True, |
|
): |
|
all_attentions = () if output_attentions else None |
|
all_hidden_states = () if output_hidden_states else None |
|
|
|
for encoder_layer in self.layers: |
|
if output_hidden_states: |
|
all_hidden_states = all_hidden_states + (hidden_states,) |
|
|
|
            # LayerDrop (https://arxiv.org/abs/1909.11556): sample a uniform probability per
            # layer; a normal draw from a fixed PRNGKey would be deterministic and unbounded.
            dropout_probability = random.uniform(0, 1)
|
if not deterministic and (dropout_probability < self.layerdrop): |
|
layer_outputs = (None, None) |
|
else: |
|
layer_outputs = encoder_layer( |
|
hidden_states, |
|
attention_mask, |
|
output_attentions, |
|
deterministic, |
|
) |
|
hidden_states = layer_outputs[0] |
|
if output_attentions: |
|
all_attentions = all_attentions + (layer_outputs[1],) |
|
|
|
if output_hidden_states: |
|
all_hidden_states += (hidden_states,) |
|
|
|
outputs = (hidden_states, all_hidden_states, all_attentions) |
|
|
|
if not return_dict: |
|
return tuple(v for v in outputs if v is not None) |
|
|
|
return FlaxBaseModelOutput( |
|
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions |
|
) |
|
|
|
|
|
class FlaxIndicTransDecoderLayer(nn.Module): |
|
config: IndicTransConfig |
|
dtype: jnp.dtype = jnp.float32 |
|
|
|
def setup(self) -> None: |
|
self.embed_dim = self.config.decoder_embed_dim |
|
self.self_attn = FlaxIndicTransAttention( |
|
config=self.config, |
|
embed_dim=self.embed_dim, |
|
num_heads=self.config.decoder_attention_heads, |
|
dropout=self.config.attention_dropout, |
|
causal=True, |
|
dtype=self.dtype, |
|
) |
|
self.dropout_layer = nn.Dropout(rate=self.config.dropout) |
|
self.activation_fn = ACT2FN[self.config.activation_function] |
|
self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout) |
|
|
|
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) |
|
self.encoder_attn = FlaxIndicTransAttention( |
|
config=self.config, |
|
embed_dim=self.embed_dim, |
|
num_heads=self.config.decoder_attention_heads, |
|
dropout=self.config.attention_dropout, |
|
causal=False, |
|
dtype=self.dtype, |
|
) |
|
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) |
|
self.fc1 = nn.Dense( |
|
self.config.decoder_ffn_dim, |
|
dtype=self.dtype, |
|
kernel_init=jax.nn.initializers.normal(self.config.init_std), |
|
) |
|
self.fc2 = nn.Dense( |
|
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) |
|
) |
|
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) |
|
self.normalize_before = self.config.decoder_normalize_before |
|
|
|
def __call__( |
|
self, |
|
hidden_states: jnp.ndarray, |
|
attention_mask: jnp.ndarray, |
|
encoder_hidden_states: Optional[jnp.ndarray] = None, |
|
encoder_attention_mask: Optional[jnp.ndarray] = None, |
|
init_cache: bool = False, |
|
output_attentions: bool = True, |
|
deterministic: bool = True, |
|
) -> Tuple[jnp.ndarray]: |
|
residual = hidden_states |
|
if self.normalize_before: |
|
hidden_states = self.self_attn_layer_norm(hidden_states) |
|
|
|
|
|
hidden_states, self_attn_weights = self.self_attn( |
|
hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache |
|
) |
|
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) |
|
hidden_states = residual + hidden_states |
|
if not self.normalize_before: |
|
hidden_states = self.self_attn_layer_norm(hidden_states) |
|
|
|
|
|
cross_attn_weights = None |
|
if encoder_hidden_states is not None: |
|
residual = hidden_states |
|
if self.normalize_before: |
|
hidden_states = self.encoder_attn_layer_norm(hidden_states) |
|
|
|
hidden_states, cross_attn_weights = self.encoder_attn( |
|
hidden_states=hidden_states, |
|
key_value_states=encoder_hidden_states, |
|
attention_mask=encoder_attention_mask, |
|
|
|
) |
|
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) |
|
hidden_states = residual + hidden_states |
|
if not self.normalize_before: |
|
hidden_states = self.encoder_attn_layer_norm(hidden_states) |
|
|
|
|
|
residual = hidden_states |
|
if self.normalize_before: |
|
hidden_states = self.final_layer_norm(hidden_states) |
|
hidden_states = self.activation_fn(self.fc1(hidden_states)) |
|
hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic) |
|
hidden_states = self.fc2(hidden_states) |
|
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) |
|
hidden_states = residual + hidden_states |
|
if not self.normalize_before: |
|
hidden_states = self.final_layer_norm(hidden_states) |
|
|
|
outputs = (hidden_states,) |
|
|
|
if output_attentions: |
|
outputs += (self_attn_weights, cross_attn_weights) |
|
|
|
return outputs |
|
|
|
|
|
class FlaxIndicTransDecoderLayerCollection(nn.Module): |
|
config: IndicTransConfig |
|
dtype: jnp.dtype = jnp.float32 |
|
|
|
def setup(self): |
|
self.layers = [ |
|
FlaxIndicTransDecoderLayer(self.config, name=str(i), dtype=self.dtype) |
|
for i in range(self.config.decoder_layers) |
|
] |
|
self.layerdrop = self.config.decoder_layerdrop |
|
|
|
def __call__( |
|
self, |
|
hidden_states, |
|
attention_mask, |
|
encoder_hidden_states: Optional[jnp.ndarray] = None, |
|
encoder_attention_mask: Optional[jnp.ndarray] = None, |
|
deterministic: bool = True, |
|
init_cache: bool = False, |
|
output_attentions: bool = False, |
|
output_hidden_states: bool = False, |
|
return_dict: bool = True, |
|
): |
|
|
|
all_hidden_states = () if output_hidden_states else None |
|
all_self_attns = () if output_attentions else None |
|
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None |
|
|
|
for decoder_layer in self.layers: |
|
if output_hidden_states: |
|
all_hidden_states += (hidden_states,) |
|
|
|
            # LayerDrop (https://arxiv.org/abs/1909.11556): sample a uniform probability per
            # layer; a normal draw from a fixed PRNGKey would be deterministic and unbounded.
            dropout_probability = random.uniform(0, 1)
|
if not deterministic and (dropout_probability < self.layerdrop): |
|
layer_outputs = (None, None, None) |
|
else: |
|
layer_outputs = decoder_layer( |
|
hidden_states, |
|
attention_mask=attention_mask, |
|
encoder_hidden_states=encoder_hidden_states, |
|
encoder_attention_mask=encoder_attention_mask, |
|
init_cache=init_cache, |
|
output_attentions=output_attentions, |
|
deterministic=deterministic, |
|
) |
|
|
|
hidden_states = layer_outputs[0] |
|
if output_attentions: |
|
all_self_attns += (layer_outputs[1],) |
|
|
|
if encoder_hidden_states is not None: |
|
all_cross_attentions += (layer_outputs[2],) |
|
|
|
|
|
if output_hidden_states: |
|
all_hidden_states += (hidden_states,) |
|
|
|
outputs = [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions] |
|
|
|
if not return_dict: |
|
return tuple(v for v in outputs if v is not None) |
|
|
|
return FlaxBaseModelOutputWithPastAndCrossAttentions( |
|
last_hidden_state=hidden_states, |
|
hidden_states=all_hidden_states, |
|
attentions=all_self_attns, |
|
cross_attentions=all_cross_attentions, |
|
) |
|
|
|
class FlaxIndicTransEncoder(nn.Module): |
|
config: IndicTransConfig |
|
dtype: jnp.dtype = jnp.float32 |
|
|
|
def setup(self): |
|
self.dropout_layer = nn.Dropout(rate=self.config.dropout) |
|
|
|
embed_dim = self.config.encoder_embed_dim |
|
self.padding_idx = self.config.pad_token_id |
|
self.max_source_positions = self.config.max_source_positions |
|
self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0 |
|
|
|
self.embed_tokens = nn.Embed( |
|
self.config.encoder_vocab_size, |
|
embed_dim, |
|
embedding_init=jax.nn.initializers.normal(self.config.init_std), |
|
) |
|
|
|
self.embed_positions = FlaxIndicTransSinusoidalPositionalEmbedding( |
|
self.config.max_source_positions, |
|
embed_dim, |
|
self.padding_idx, |
|
) |
|
self.layers = FlaxIndicTransEncoderLayerCollection(self.config, self.dtype) |
|
self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) if self.config.encoder_normalize_before else None |
|
self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) if self.config.layernorm_embedding else None |
|
|
|
def __call__( |
|
self, |
|
input_ids, |
|
attention_mask, |
|
position_ids, |
|
output_attentions: bool = False, |
|
output_hidden_states: bool = False, |
|
return_dict: bool = True, |
|
deterministic: bool = True, |
|
): |
|
input_shape = input_ids.shape |
|
input_ids = input_ids.reshape(-1, input_shape[-1]) |
|
|
|
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale |
|
|
|
embed_pos = self.embed_positions(input_ids, inputs_embeds) |
|
|
|
hidden_states = inputs_embeds + embed_pos |
|
if self.layernorm_embedding is not None: |
|
hidden_states = self.layernorm_embedding(hidden_states) |
|
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) |
|
|
|
outputs = self.layers( |
|
hidden_states, |
|
attention_mask, |
|
deterministic=deterministic, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
) |
|
|
|
last_hidden_states = outputs[0] |
|
|
|
if self.layer_norm is not None: |
|
last_hidden_states = self.layer_norm(last_hidden_states) |
|
|
|
|
|
hidden_states = None |
|
if output_hidden_states: |
|
hidden_states = outputs[1] |
|
hidden_states = hidden_states[:-1] + (last_hidden_states,) |
|
|
|
if not return_dict: |
|
outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:]) |
|
return tuple(v for v in outputs if v is not None) |
|
|
|
return FlaxBaseModelOutput( |
|
last_hidden_state=last_hidden_states, |
|
hidden_states=hidden_states, |
|
attentions=outputs.attentions, |
|
) |
|
|
|
|
|
class FlaxIndicTransDecoder(nn.Module): |
|
config: IndicTransConfig |
|
dtype: jnp.dtype = jnp.float32 |
|
|
|
def setup(self): |
|
self.dropout_layer = nn.Dropout(rate=self.config.dropout) |
|
|
|
        embed_dim = self.config.decoder_embed_dim
|
self.padding_idx = self.config.pad_token_id |
|
self.max_target_positions = self.config.max_target_positions |
|
self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0 |
|
|
|
self.embed_tokens = nn.Embed( |
|
self.config.decoder_vocab_size, |
|
embed_dim, |
|
embedding_init=jax.nn.initializers.normal(self.config.init_std), |
|
) |
|
|
|
self.embed_positions = FlaxIndicTransSinusoidalPositionalEmbedding( |
|
self.config.max_target_positions, |
|
embed_dim, |
|
self.padding_idx, |
|
) |
|
|
|
self.layers = FlaxIndicTransDecoderLayerCollection(self.config, self.dtype) |
|
self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) if self.config.decoder_normalize_before else None |
|
self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) if self.config.layernorm_embedding else None |
|
|
|
|
|
def __call__( |
|
self, |
|
input_ids, |
|
attention_mask, |
|
position_ids, |
|
encoder_hidden_states: Optional[jnp.ndarray] = None, |
|
encoder_attention_mask: Optional[jnp.ndarray] = None, |
|
init_cache: bool = False, |
|
output_attentions: bool = False, |
|
output_hidden_states: bool = False, |
|
return_dict: bool = True, |
|
deterministic: bool = True, |
|
): |
|
|
|
input_shape = input_ids.shape |
|
input_ids = input_ids.reshape(-1, input_shape[-1]) |
|
|
|
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale |
|
|
|
|
|
positions = self.embed_positions(input_ids, inputs_embeds) |
|
|
|
hidden_states = inputs_embeds + positions |
|
|
|
if self.layernorm_embedding is not None: |
|
hidden_states = self.layernorm_embedding(hidden_states) |
|
|
|
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) |
|
|
|
outputs = self.layers( |
|
hidden_states, |
|
attention_mask, |
|
encoder_hidden_states, |
|
encoder_attention_mask, |
|
deterministic=deterministic, |
|
init_cache=init_cache, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
) |
|
|
|
last_hidden_states = outputs[0] |
|
|
|
if self.layer_norm is not None: |
|
last_hidden_states = self.layer_norm(last_hidden_states) |
|
|
|
|
|
hidden_states = None |
|
if output_hidden_states: |
|
hidden_states = outputs[1] |
|
hidden_states = hidden_states[:-1] + (last_hidden_states,) |
|
|
|
if not return_dict: |
|
outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:]) |
|
return tuple(v for v in outputs if v is not None) |
|
|
|
return FlaxBaseModelOutputWithPastAndCrossAttentions( |
|
last_hidden_state=last_hidden_states, |
|
hidden_states=hidden_states, |
|
attentions=outputs.attentions, |
|
cross_attentions=outputs.cross_attentions, |
|
) |
|
|
|
|
|
class FlaxIndicTransModule(nn.Module): |
|
config: IndicTransConfig |
|
dtype: jnp.dtype = jnp.float32 |
|
|
|
def setup(self): |
|
self.encoder = FlaxIndicTransEncoder(self.config, dtype=self.dtype) |
|
self.decoder = FlaxIndicTransDecoder(self.config, dtype=self.dtype) |
|
|
|
def _get_encoder_module(self): |
|
return self.encoder |
|
|
|
def _get_decoder_module(self): |
|
return self.decoder |
|
|
|
def __call__( |
|
self, |
|
input_ids, |
|
attention_mask, |
|
decoder_input_ids, |
|
decoder_attention_mask, |
|
position_ids, |
|
decoder_position_ids, |
|
output_attentions: bool = False, |
|
output_hidden_states: bool = False, |
|
return_dict: bool = True, |
|
deterministic: bool = True, |
|
): |
|
encoder_outputs = self.encoder( |
|
input_ids=input_ids, |
|
attention_mask=attention_mask, |
|
position_ids=position_ids, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
deterministic=deterministic, |
|
) |
|
|
|
decoder_outputs = self.decoder( |
|
input_ids=decoder_input_ids, |
|
attention_mask=decoder_attention_mask, |
|
position_ids=decoder_position_ids, |
|
encoder_hidden_states=encoder_outputs[0], |
|
encoder_attention_mask=attention_mask, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
deterministic=deterministic, |
|
) |
|
|
|
if not return_dict: |
|
return decoder_outputs + encoder_outputs |
|
|
|
return FlaxSeq2SeqModelOutput( |
|
last_hidden_state=decoder_outputs.last_hidden_state, |
|
decoder_hidden_states=decoder_outputs.hidden_states, |
|
decoder_attentions=decoder_outputs.attentions, |
|
cross_attentions=decoder_outputs.cross_attentions, |
|
encoder_last_hidden_state=encoder_outputs.last_hidden_state, |
|
encoder_hidden_states=encoder_outputs.hidden_states, |
|
encoder_attentions=encoder_outputs.attentions, |
|
) |
|
|
|
|
|
class FlaxIndicTransPreTrainedModel(FlaxPreTrainedModel): |
|
config_class = IndicTransConfig |
|
base_model_prefix: str = "model" |
|
module_class: nn.Module = None |
|
|
|
def __init__( |
|
self, |
|
config: IndicTransConfig, |
|
input_shape: Tuple[int] = (1, 1), |
|
seed: int = 0, |
|
dtype: jnp.dtype = jnp.float32, |
|
_do_init: bool = True, |
|
**kwargs, |
|
): |
|
module = self.module_class(config=config, dtype=dtype, **kwargs) |
|
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) |
|
|
|
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: |
|
|
|
input_ids = jnp.zeros(input_shape, dtype="i4") |
|
|
|
input_ids = input_ids.at[(..., -1)].set(self.config.eos_token_id) |
|
attention_mask = jnp.ones_like(input_ids) |
|
decoder_input_ids = input_ids |
|
decoder_attention_mask = jnp.ones_like(input_ids) |
|
|
|
batch_size, sequence_length = input_ids.shape |
|
position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) |
|
decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) |
|
|
|
params_rng, dropout_rng = jax.random.split(rng) |
|
rngs = {"params": params_rng, "dropout": dropout_rng} |
|
|
|
random_params = self.module.init( |
|
rngs, |
|
input_ids, |
|
attention_mask, |
|
decoder_input_ids, |
|
decoder_attention_mask, |
|
position_ids, |
|
decoder_position_ids, |
|
)["params"] |
|
|
|
if params is not None: |
|
random_params = flatten_dict(unfreeze(random_params)) |
|
params = flatten_dict(unfreeze(params)) |
|
for missing_key in self._missing_keys: |
|
params[missing_key] = random_params[missing_key] |
|
self._missing_keys = set() |
|
return freeze(unflatten_dict(params)) |
|
else: |
|
return random_params |
|
|
|
def init_cache(self, batch_size, max_length, encoder_outputs): |
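        # Run a dummy decoder forward pass with `init_cache=True` so that Flax allocates
        # the "cache" collection for every attention layer; only that collection is kept.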
|
|
|
decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4") |
|
decoder_attention_mask = jnp.ones_like(decoder_input_ids) |
|
decoder_position_ids = jnp.broadcast_to( |
|
jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape |
|
) |
|
|
|
def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): |
|
decoder_module = module._get_decoder_module() |
|
return decoder_module( |
|
decoder_input_ids, |
|
decoder_attention_mask, |
|
decoder_position_ids, |
|
**kwargs, |
|
) |
|
|
|
init_variables = self.module.init( |
|
jax.random.PRNGKey(0), |
|
decoder_input_ids=decoder_input_ids, |
|
decoder_attention_mask=decoder_attention_mask, |
|
decoder_position_ids=decoder_position_ids, |
|
encoder_hidden_states=encoder_outputs[0], |
|
init_cache=True, |
|
method=_decoder_forward, |
|
) |
|
return unfreeze(init_variables["cache"]) |
|
|
|
def encode( |
|
self, |
|
input_ids: jnp.ndarray, |
|
attention_mask: Optional[jnp.ndarray] = None, |
|
position_ids: Optional[jnp.ndarray] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
train: bool = False, |
|
params: dict = None, |
|
dropout_rng: PRNGKey = None, |
|
): |
|
|
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
if attention_mask is None: |
|
attention_mask = jnp.ones_like(input_ids) |
|
if position_ids is None: |
|
batch_size, sequence_length = input_ids.shape |
|
position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) |
|
|
|
|
|
rngs = {} |
|
if dropout_rng is not None: |
|
rngs["dropout"] = dropout_rng |
|
|
|
def _encoder_forward(module, input_ids, attention_mask, position_ids, **kwargs): |
|
encode_module = module._get_encoder_module() |
|
return encode_module(input_ids, attention_mask, position_ids, **kwargs) |
|
|
|
return self.module.apply( |
|
{"params": params or self.params}, |
|
input_ids=jnp.array(input_ids, dtype="i4"), |
|
attention_mask=jnp.array(attention_mask, dtype="i4"), |
|
position_ids=jnp.array(position_ids, dtype="i4"), |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
deterministic=not train, |
|
rngs=rngs, |
|
method=_encoder_forward, |
|
) |
|
|
|
def decode( |
|
self, |
|
decoder_input_ids, |
|
encoder_outputs, |
|
encoder_attention_mask: Optional[jnp.ndarray] = None, |
|
decoder_attention_mask: Optional[jnp.ndarray] = None, |
|
decoder_position_ids: Optional[jnp.ndarray] = None, |
|
past_key_values: dict = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
train: bool = False, |
|
params: dict = None, |
|
dropout_rng: PRNGKey = None, |
|
): |
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
encoder_hidden_states = encoder_outputs[0] |
|
if encoder_attention_mask is None: |
|
batch_size, sequence_length = encoder_hidden_states.shape[:2] |
|
encoder_attention_mask = jnp.ones((batch_size, sequence_length)) |
|
|
|
batch_size, sequence_length = decoder_input_ids.shape |
|
if decoder_attention_mask is None: |
|
decoder_attention_mask = jnp.ones((batch_size, sequence_length)) |
|
|
|
if decoder_position_ids is None: |
|
if past_key_values is not None: |
|
raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") |
|
|
|
decoder_position_ids = jnp.broadcast_to( |
|
jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) |
|
) |
|
|
|
|
|
rngs = {} |
|
if dropout_rng is not None: |
|
rngs["dropout"] = dropout_rng |
|
|
|
inputs = {"params": params or self.params} |
|
|
|
|
|
|
|
|
|
if past_key_values: |
|
inputs["cache"] = past_key_values |
|
mutable = ["cache"] |
|
else: |
|
mutable = False |
|
|
|
def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): |
|
decoder_module = module._get_decoder_module() |
|
return decoder_module( |
|
decoder_input_ids, |
|
decoder_attention_mask, |
|
decoder_position_ids, |
|
**kwargs, |
|
) |
|
|
|
outputs = self.module.apply( |
|
inputs, |
|
decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), |
|
decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), |
|
decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), |
|
encoder_hidden_states=encoder_hidden_states, |
|
encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
deterministic=not train, |
|
rngs=rngs, |
|
mutable=mutable, |
|
method=_decoder_forward, |
|
) |
|
|
|
|
|
if past_key_values is not None and return_dict: |
|
outputs, past = outputs |
|
outputs["past_key_values"] = unfreeze(past["cache"]) |
|
return outputs |
|
elif past_key_values is not None and not return_dict: |
|
outputs, past = outputs |
|
outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] |
|
|
|
return outputs |
|
|
|
def __call__( |
|
self, |
|
input_ids: jnp.ndarray, |
|
attention_mask: Optional[jnp.ndarray] = None, |
|
decoder_input_ids: Optional[jnp.ndarray] = None, |
|
decoder_attention_mask: Optional[jnp.ndarray] = None, |
|
position_ids: Optional[jnp.ndarray] = None, |
|
decoder_position_ids: Optional[jnp.ndarray] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
train: bool = False, |
|
params: dict = None, |
|
dropout_rng: PRNGKey = None, |
|
): |
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
|
|
if attention_mask is None: |
|
attention_mask = jnp.ones_like(input_ids) |
|
if position_ids is None: |
|
batch_size, sequence_length = input_ids.shape |
|
position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) |
|
|
|
|
|
if decoder_input_ids is None: |
|
decoder_input_ids = shift_tokens_right(input_ids, self.config.pad_token_id, self.config.decoder_start_token_id) |
|
if decoder_attention_mask is None: |
|
decoder_attention_mask = jnp.ones_like(decoder_input_ids) |
|
if decoder_position_ids is None: |
|
batch_size, sequence_length = decoder_input_ids.shape |
|
decoder_position_ids = jnp.broadcast_to( |
|
jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) |
|
) |
|
|
|
|
|
rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} |
|
|
|
return self.module.apply( |
|
{"params": params or self.params}, |
|
input_ids=jnp.array(input_ids, dtype="i4"), |
|
attention_mask=jnp.array(attention_mask, dtype="i4"), |
|
position_ids=jnp.array(position_ids, dtype="i4"), |
|
decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), |
|
decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), |
|
decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
deterministic=not train, |
|
rngs=rngs, |
|
) |
|
|
|
|
|
class FlaxIndicTransModel(FlaxIndicTransPreTrainedModel): |
|
config: IndicTransConfig |
|
dtype: jnp.dtype = jnp.float32 |
|
module_class = FlaxIndicTransModule |
|
|
|
|
|
class FlaxIndicTransForConditionalGenerationModule(nn.Module): |
|
config: IndicTransConfig |
|
dtype: jnp.dtype = jnp.float32 |
|
bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros |
|
|
|
def setup(self): |
|
self.model = FlaxIndicTransModule(config=self.config, dtype=self.dtype) |
|
|
|
self.lm_head = nn.Dense( |
|
self.config.decoder_vocab_size, |
|
use_bias=False, |
|
dtype=self.dtype, |
|
kernel_init=jax.nn.initializers.normal(self.config.init_std), |
|
) |
|
|
|
def _get_encoder_module(self): |
|
return self.model.encoder |
|
|
|
def _get_decoder_module(self): |
|
return self.model.decoder |
|
|
|
def __call__( |
|
self, |
|
input_ids, |
|
attention_mask, |
|
decoder_input_ids, |
|
decoder_attention_mask, |
|
position_ids, |
|
decoder_position_ids, |
|
output_attentions: bool = False, |
|
output_hidden_states: bool = False, |
|
return_dict: bool = True, |
|
deterministic: bool = True, |
|
): |
|
outputs = self.model( |
|
input_ids=input_ids, |
|
attention_mask=attention_mask, |
|
decoder_input_ids=decoder_input_ids, |
|
decoder_attention_mask=decoder_attention_mask, |
|
position_ids=position_ids, |
|
decoder_position_ids=decoder_position_ids, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
deterministic=deterministic, |
|
) |
|
|
|
hidden_states = outputs[0] |
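        # When the decoder input and output embeddings are tied, reuse the decoder token
        # embedding matrix (transposed) as the lm_head kernel instead of separate weights.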
|
|
|
if self.config.share_decoder_input_output_embed: |
|
shared_embedding = self.model.variables["params"]["decoder"]["embed_tokens"]["embedding"] |
|
            lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
        else:
            lm_logits = self.lm_head(hidden_states)
|
|
|
if not return_dict: |
|
output = (lm_logits,) + outputs[1:] |
|
return output |
|
|
|
return FlaxSeq2SeqLMOutput( |
|
logits=lm_logits, |
|
decoder_hidden_states=outputs.decoder_hidden_states, |
|
decoder_attentions=outputs.decoder_attentions, |
|
cross_attentions=outputs.cross_attentions, |
|
encoder_last_hidden_state=outputs.encoder_last_hidden_state, |
|
encoder_hidden_states=outputs.encoder_hidden_states, |
|
encoder_attentions=outputs.encoder_attentions, |
|
) |
|
|
|
|
|
class FlaxIndicTransForConditionalGeneration(FlaxIndicTransPreTrainedModel): |
|
module_class = FlaxIndicTransForConditionalGenerationModule |
|
dtype: jnp.dtype = jnp.float32 |
|
|
|
def decode( |
|
self, |
|
decoder_input_ids, |
|
encoder_outputs, |
|
encoder_attention_mask: Optional[jnp.ndarray] = None, |
|
decoder_attention_mask: Optional[jnp.ndarray] = None, |
|
decoder_position_ids: Optional[jnp.ndarray] = None, |
|
past_key_values: dict = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
train: bool = False, |
|
params: dict = None, |
|
dropout_rng: PRNGKey = None, |
|
): |
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
encoder_hidden_states = encoder_outputs[0] |
|
if encoder_attention_mask is None: |
|
batch_size, sequence_length = encoder_hidden_states.shape[:2] |
|
encoder_attention_mask = jnp.ones((batch_size, sequence_length)) |
|
|
|
batch_size, sequence_length = decoder_input_ids.shape |
|
if decoder_attention_mask is None: |
|
decoder_attention_mask = jnp.ones((batch_size, sequence_length)) |
|
|
|
if decoder_position_ids is None: |
|
if past_key_values is not None: |
|
raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") |
|
|
|
decoder_position_ids = jnp.broadcast_to( |
|
jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) |
|
) |
|
|
|
|
|
rngs = {} |
|
if dropout_rng is not None: |
|
rngs["dropout"] = dropout_rng |
|
|
|
inputs = {"params": params or self.params} |
|
|
|
|
|
|
|
|
|
if past_key_values: |
|
inputs["cache"] = past_key_values |
|
mutable = ["cache"] |
|
else: |
|
mutable = False |
|
|
|
def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): |
|
decoder_module = module._get_decoder_module() |
|
outputs = decoder_module( |
|
decoder_input_ids, |
|
decoder_attention_mask, |
|
decoder_position_ids, |
|
**kwargs, |
|
) |
|
hidden_states = outputs[0] |
|
|
|
if self.config.share_decoder_input_output_embed: |
|
shared_embedding = module.model.variables["params"]["decoder"]["embed_tokens"]["embedding"] |
|
lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) |
|
else: |
|
lm_logits = module.lm_head(hidden_states) |
|
|
|
return lm_logits, outputs |
|
|
|
outputs = self.module.apply( |
|
inputs, |
|
decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), |
|
decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), |
|
decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), |
|
encoder_hidden_states=encoder_hidden_states, |
|
encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
deterministic=not train, |
|
rngs=rngs, |
|
mutable=mutable, |
|
method=_decoder_forward, |
|
) |
|
|
|
if past_key_values is None: |
|
lm_logits, decoder_outputs = outputs |
|
else: |
|
(lm_logits, decoder_outputs), past = outputs |
|
|
|
if return_dict: |
|
outputs = FlaxCausalLMOutputWithCrossAttentions( |
|
logits=lm_logits, |
|
hidden_states=decoder_outputs.hidden_states, |
|
attentions=decoder_outputs.attentions, |
|
cross_attentions=decoder_outputs.cross_attentions, |
|
) |
|
else: |
|
outputs = (lm_logits,) + decoder_outputs[1:] |
|
|
|
|
|
if past_key_values is not None and return_dict: |
|
outputs["past_key_values"] = unfreeze(past["cache"]) |
|
return outputs |
|
elif past_key_values is not None and not return_dict: |
|
outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] |
|
|
|
return outputs |
|
|
|
def prepare_inputs_for_generation( |
|
self, |
|
decoder_input_ids, |
|
max_length, |
|
attention_mask: Optional[jax.Array] = None, |
|
decoder_attention_mask: Optional[jax.Array] = None, |
|
encoder_outputs=None, |
|
**kwargs, |
|
): |
|
|
|
batch_size, seq_length = decoder_input_ids.shape |
|
|
|
past_key_values = self.init_cache(batch_size, max_length, encoder_outputs) |
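        # Note that usually one would have to put 0's in the attention_mask for positions
        # beyond the current input length, but since the decoder uses a causal mask those
        # positions are masked anyway. Thus a single static attention_mask of length
        # max_length can be created here, which is more efficient for compilation.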
|
|
|
|
|
|
|
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") |
|
if decoder_attention_mask is not None: |
|
position_ids = decoder_attention_mask.cumsum(axis=-1) - 1 |
|
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0)) |
|
else: |
|
position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) |
|
|
|
return { |
|
"past_key_values": past_key_values, |
|
"encoder_outputs": encoder_outputs, |
|
"encoder_attention_mask": attention_mask, |
|
"decoder_attention_mask": extended_attention_mask, |
|
"decoder_position_ids": position_ids, |
|
} |
|
|
|
def update_inputs_for_generation(self, model_outputs, model_kwargs): |
|
model_kwargs["past_key_values"] = model_outputs.past_key_values |
|
model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1 |
|
return model_kwargs |
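# Minimal usage sketch (not part of the model code): the checkpoint id below is purely
# illustrative and assumes Flax weights and a compatible tokenizer are available.
#
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("<indictrans-checkpoint>", trust_remote_code=True)
#     model = FlaxIndicTransForConditionalGeneration.from_pretrained(
#         "<indictrans-checkpoint>", trust_remote_code=True
#     )
#     batch = tokenizer("Hello, world!", return_tensors="np")
#     generated = model.generate(
#         batch["input_ids"], attention_mask=batch["attention_mask"], max_length=64
#     )
#     print(tokenizer.batch_decode(generated.sequences, skip_special_tokens=True))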
|
|
|
|