from typing import Optional, Tuple

import jax
from flax import linen as nn
from flax.core import FrozenDict, unfreeze, freeze
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import numpy as jnp
from transformers import FlaxPreTrainedModel
from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
from transformers.modeling_flax_utils import ACT2FN

from .configuration_retnet import RetNetConfig
|
|
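# Rotary-style position helpers used by the retention layers below:
# `rotate_every_two` maps (x0, x1, x2, x3, ...) to (-x1, x0, -x3, x2, ...), and
# `theta_shift` combines that with the sin/cos tables to rotate queries and keys.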
def rotate_every_two(tensor):
    rotate_half_tensor = jnp.stack(
        (-tensor[:, :, :, 1::2], tensor[:, :, :, ::2]), axis=-1
    )
    rotate_half_tensor = rotate_half_tensor.reshape(
        rotate_half_tensor.shape[:-2] + (-1,)
    )
    return rotate_half_tensor
|
|
def theta_shift(x, sin, cos):
    return (x * cos) + (rotate_every_two(x) * sin)
|
|
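# Computes the retention relative-position quantities for a given sequence
# length: per-position sin/cos angles plus per-head decay terms, in one of
# three layouts (single-step recurrent, chunkwise recurrent, or fully parallel).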
class FlaxRetNetRelPos(nn.Module):
    config: RetNetConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        angle = 1.0 / (
            10000
            ** jnp.linspace(
                0, 1, self.config.hidden_size // self.config.num_rettention_heads // 2
            )
        )
        self.angle = angle.repeat(2).flatten()
        self.decay = jnp.log(
            1
            - 2
            ** (-5 - jnp.arange(self.config.num_rettention_heads, dtype=jnp.float32))
        )
        self.recurrent_chunk_size = self.config.recurrent_chunk_size

    def __call__(
        self,
        slen: int,
        activate_recurrent: bool = False,
        chunkwise_recurrent: bool = False,
    ):
        if activate_recurrent:
            sin = jnp.sin(self.angle * (slen - 1))
            cos = jnp.cos(self.angle * (slen - 1))
            retention_rel_pos = ((sin, cos), jnp.exp(self.decay))
        elif chunkwise_recurrent:
            index = jnp.arange(slen)
            sin = jnp.sin(index[:, None] * self.angle[None, :])
            cos = jnp.cos(index[:, None] * self.angle[None, :])

            block_index = jnp.arange(self.recurrent_chunk_size)
            mask = jnp.tril(
                jnp.ones((self.recurrent_chunk_size, self.recurrent_chunk_size))
            )
            mask = jnp.where(
                ~mask.astype(jnp.bool_),
                float("inf"),
                block_index[:, None] - block_index[None, :],
            )
            mask = jnp.exp(mask * self.decay[:, None, None])
            mask = jnp.nan_to_num(mask)
            scale = jnp.sqrt(mask.sum(axis=-1, keepdims=True))
            mask = mask / scale

            cross_decay = jnp.exp(self.decay * self.recurrent_chunk_size)
            inner_decay = jnp.exp(self.decay[:, None] * (block_index + 1))
            cross_decay = cross_decay[:, None, None]
            inner_decay = inner_decay[:, :, None] / (scale / scale[:, -1, None])

            retention_rel_pos = ((sin, cos), (mask, cross_decay, inner_decay))
        else:
            index = jnp.arange(slen)
            sin = jnp.sin(index[:, None] * self.angle[None, :])
            cos = jnp.cos(index[:, None] * self.angle[None, :])
            mask = jnp.tril(jnp.ones((slen, slen)))
            mask = jnp.where(
                ~mask.astype(jnp.bool_), float("inf"), index[:, None] - index[None, :]
            )
            mask = jnp.exp(mask * self.decay[:, None, None])
            mask = jnp.nan_to_num(mask)
            mask = mask / jnp.sqrt(mask.sum(axis=-1, keepdims=True))
            retention_rel_pos = ((sin, cos), mask)

        return retention_rel_pos
|
|
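# Position-wise feed-forward block: fc1 -> activation -> dropout -> fc2 -> dropout.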
class FlaxRetNetFeedForward(nn.Module):
    config: RetNetConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        self.fc1 = nn.Dense(
            self.config.intermediate_size,
            kernel_init=nn.initializers.xavier_normal(),
            dtype=self.dtype,
        )
        self.fc2 = nn.Dense(
            self.config.hidden_size,
            kernel_init=nn.initializers.xavier_normal(),
            dtype=self.dtype,
        )
        self.activation_fn = ACT2FN[self.config.hidden_act]
        self.activation_dropout = nn.Dropout(rate=self.config.dropout)
        self.dropout = nn.Dropout(rate=self.config.dropout)

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        deterministic: bool = True,
    ) -> jnp.ndarray:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.activation_dropout(
            hidden_states, deterministic=deterministic
        )
        hidden_states = self.fc2(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)

        return hidden_states
|
|
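# Multi-scale retention. Queries and keys are projected to hidden_size, while
# the value and gate projections use twice that width (factor 2); the output is
# normalized per head, gated with swish(g), and projected back to hidden_size.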
class FlaxRetNetRetention(nn.Module):
    config: RetNetConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        self.factor = 2
        self.embed_dim = self.config.hidden_size
        self.num_heads = self.config.num_rettention_heads
        self.head_dim = self.embed_dim * self.factor // self.num_heads
        self.key_dim = self.embed_dim // self.num_heads
        self.scaling = self.key_dim**-0.5

        self.q_proj = nn.Dense(
            self.embed_dim,
            use_bias=True,
            kernel_init=jax.nn.initializers.xavier_normal(),
            dtype=self.dtype,
        )
        self.k_proj = nn.Dense(
            self.embed_dim,
            use_bias=True,
            kernel_init=jax.nn.initializers.xavier_normal(),
            dtype=self.dtype,
        )
        self.v_proj = nn.Dense(
            self.embed_dim * self.factor,
            use_bias=True,
            kernel_init=jax.nn.initializers.xavier_normal(),
            dtype=self.dtype,
        )
        self.g_proj = nn.Dense(
            self.embed_dim * self.factor,
            use_bias=True,
            kernel_init=nn.initializers.xavier_normal(),
            dtype=self.dtype,
        )

        self.out_proj = nn.Dense(
            self.embed_dim,
            use_bias=True,
            kernel_init=jax.nn.initializers.xavier_normal(),
            dtype=self.dtype,
        )

        self.group_norm = nn.LayerNorm(epsilon=1e-6, dtype=self.dtype)
|
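    # Parallel retention: per-head q @ k^T scores weighted by the decay mask,
    # normalized by the clipped absolute row sums, then applied to the values.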
    def parallel_forward(self, qr, kr, v, mask):
        bsz, tgt_len, embed_dim = v.shape

        vr = v.reshape(bsz, tgt_len, self.num_heads, self.head_dim).transpose(
            (0, 2, 1, 3)
        )

        qk_mat = qr @ kr.transpose((0, 1, 3, 2))
        qk_mat = qk_mat * mask
        qk_mat /= jnp.abs(
            jax.lax.stop_gradient(qk_mat).sum(axis=-1, keepdims=True)
        ).clip(min=1)
        output = jnp.matmul(qk_mat, vr)
        output = output.transpose((0, 2, 1, 3))

        return output
|
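    # Chunkwise-recurrent retention: each chunk is processed in parallel
    # (inner_output) while a decayed key-value state is carried across chunks
    # (cross_output); the two parts are rescaled to a common normalization
    # before being summed.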
    def chunk_recurrent_forward(self, qr, kr, v, inner_mask):
        mask, cross_decay, inner_decay = inner_mask
        bsz, tgt_len, embed_dim = v.shape
        chunk_len = mask.shape[1]
        num_chunks = tgt_len // chunk_len

        assert tgt_len % chunk_len == 0

        qr = qr.reshape(
            bsz, self.num_heads, num_chunks, chunk_len, self.key_dim
        ).transpose((0, 2, 1, 3, 4))
        kr = kr.reshape(
            bsz, self.num_heads, num_chunks, chunk_len, self.key_dim
        ).transpose((0, 2, 1, 3, 4))
        v = v.reshape(
            bsz, num_chunks, chunk_len, self.num_heads, self.head_dim
        ).transpose((0, 1, 3, 2, 4))

        kr_t = kr.transpose((0, 1, 2, 4, 3))

        # Within-chunk (parallel) part, weighted by the intra-chunk decay mask.
        qk_mat = qr @ kr_t
        qk_mat = qk_mat * mask
        inner_scale = jnp.abs(
            jax.lax.stop_gradient(qk_mat).sum(axis=-1, keepdims=True)
        ).clip(min=1)
        qk_mat = qk_mat / inner_scale
        inner_output = jnp.matmul(qk_mat, v)

        kv = kr_t @ v
        kv = kv.reshape(bsz, num_chunks, self.num_heads, self.key_dim, self.head_dim)

        kv_recurrent = []
        cross_scale = []
        kv_state = jnp.zeros((bsz, self.num_heads, self.key_dim, self.head_dim))
        kv_scale = jnp.ones((bsz, self.num_heads, 1, 1))

        for i in range(num_chunks):
            kv_recurrent.append(kv_state / kv_scale)
            cross_scale.append(kv_scale)

            kv_state = kv_state * cross_decay + kv[:, i]
            kv_scale = (
                jnp.abs(jax.lax.stop_gradient(kv_state).sum(axis=-2, keepdims=True))
                .max(axis=-1, keepdims=True)
                .clip(min=1)
            )

        kv_recurrent = jnp.stack(kv_recurrent, axis=1)
        cross_scale = jnp.stack(cross_scale, axis=1)

        all_scale = jnp.maximum(inner_scale, cross_scale)
        align_inner_scale = all_scale / inner_scale
        align_cross_scale = all_scale / cross_scale

        cross_output = (qr * inner_decay) @ kv_recurrent
        output = inner_output / align_inner_scale + cross_output / align_cross_scale

        # (bsz, num_chunks, num_heads, chunk_len, head_dim)
        # -> (bsz, num_chunks, chunk_len, num_heads, head_dim), so the later
        # reshape to (bsz, tgt_len, num_heads * head_dim) keeps time and head
        # dimensions in the right order.
        output = output.transpose((0, 1, 3, 2, 4))
        return output
|
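    # Projects hidden states to q/k/v/g, scales the keys, applies the rotary
    # shift, and dispatches to the chunkwise-recurrent or parallel path; the
    # result is group-normalized, gated with swish(g), and projected out.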
    def __call__(
        self,
        hidden_states: jnp.ndarray,
        rel_pos: Optional[Tuple] = None,
        chunkwise_recurrent: bool = True,
        incremental_state=None,
    ) -> jnp.ndarray:
        bsz, tgt_len, _ = hidden_states.shape
        (sin, cos), inner_mask = rel_pos

        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)
        g = self.g_proj(hidden_states)

        k *= self.scaling
        q = q.reshape(bsz, tgt_len, self.num_heads, self.key_dim).transpose(
            (0, 2, 1, 3)
        )
        k = k.reshape(bsz, tgt_len, self.num_heads, self.key_dim).transpose(
            (0, 2, 1, 3)
        )

        qr = theta_shift(q, sin, cos)
        kr = theta_shift(k, sin, cos)

        if incremental_state is not None:
            # Single-step recurrent inference is not implemented in this port.
            raise NotImplementedError
        elif self.config.attention_type == "chunkwise_recurrent":
            output = self.chunk_recurrent_forward(qr, kr, v, inner_mask=inner_mask)
        else:
            output = self.parallel_forward(qr, kr, v, inner_mask)

        output = self.group_norm(output)
        output = output.reshape(bsz, tgt_len, -1)

        output = nn.swish(g) * output
        output = self.out_proj(output)

        return output
|
|
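# Pre-LayerNorm residual block: the retention sub-layer (with dropout) followed
# by the feed-forward sub-layer, each wrapped in a residual connection.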
class FlaxRetNetLayer(nn.Module):
    config: RetNetConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        self.retention = FlaxRetNetRetention(self.config, dtype=self.dtype)
        self.retention_layer_norm = nn.LayerNorm(
            epsilon=self.config.layer_norm_eps, dtype=self.dtype
        )

        self.ffn = FlaxRetNetFeedForward(self.config, dtype=self.dtype)
        self.final_layer_norm = nn.LayerNorm(
            epsilon=self.config.layer_norm_eps, dtype=self.dtype
        )

        self.dropout_module = nn.Dropout(rate=self.config.dropout)

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        retention_rel_pos: Optional[tuple] = None,
        deterministic: bool = True,
    ) -> jnp.ndarray:
        residual = hidden_states
        hidden_states = self.retention_layer_norm(hidden_states)
        hidden_states = self.retention(hidden_states, rel_pos=retention_rel_pos)
        hidden_states = self.dropout_module(hidden_states, deterministic=deterministic)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.ffn(hidden_states, deterministic=deterministic)
        hidden_states = residual + hidden_states

        return hidden_states
|
|
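# Stack of num_hidden_layers retention blocks. Hidden states are optionally
# collected per layer; retention maps are not emitted by the layers, so
# all_retentions is returned unfilled.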
class FlaxRetNetLayerCollection(nn.Module):
    config: RetNetConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        self.layers = [
            FlaxRetNetLayer(self.config, dtype=self.dtype)
            for _ in range(self.config.num_hidden_layers)
        ]

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        retention_rel_pos: Optional[tuple] = None,
        deterministic: bool = True,
        output_retentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Tuple:
        all_hidden_states = () if output_hidden_states else None
        all_retentions = () if output_retentions else None

        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = layer(
                hidden_states,
                retention_rel_pos=retention_rel_pos,
                deterministic=deterministic,
            )
            hidden_states = layer_outputs

        # Include the final layer's output in the collected hidden states.
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        outputs = (hidden_states, all_hidden_states, all_retentions)
        return outputs
|
|
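# Hugging Face-style wrapper: ties the Flax module to RetNetConfig, initializes
# parameters, and routes __call__ through module.apply; concrete models only
# need to set module_class.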
class FlaxRetNetPretrainedModel(FlaxPreTrainedModel):
    config_class = RetNetConfig
    base_model_prefix = "transformer"
    main_input_name = "input_ids"
    module_class: nn.Module = None

    def __init__(
        self,
        config: RetNetConfig,
        input_shape: Tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs
    ):
        module = self.module_class(config, dtype=dtype, **kwargs)
        super().__init__(
            config,
            module,
            input_shape=input_shape,
            seed=seed,
            dtype=dtype,
            _do_init=_do_init,
        )

    def init_weights(
        self,
        rng: jax.random.PRNGKey,
        input_shape: Tuple,
        params: FrozenDict = None,
    ) -> FrozenDict:
        input_ids = jnp.zeros(input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_ids)
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        module_init_outputs = self.module.init(
            rngs, input_ids, attention_mask, return_dict=False
        )

        random_params = module_init_outputs["params"]

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = []
            return freeze(unflatten_dict(params))
        else:
            return random_params

    def __call__(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        params: dict = None,
        dropout_rng: jnp.ndarray = None,
        train: bool = False,
        output_retentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        output_retentions = (
            output_retentions
            if output_retentions is not None
            else self.config.output_retentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.return_dict
        )

        batch_size, sequence_length = input_ids.shape

        if attention_mask is None:
            attention_mask = jnp.ones((batch_size, sequence_length))

        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        outputs = self.module.apply(
            inputs,
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            not train,
            output_retentions,
            output_hidden_states,
            return_dict,
            rngs=rngs,
        )

        return outputs
|
|
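# Bare RetNet backbone: token embeddings, the shared relative-position module,
# and the layer stack, returned as a FlaxBaseModelOutput (retention maps, when
# present, are carried in the `attentions` field).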
class FlaxRetNetModule(nn.Module):
    config: RetNetConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        self.embed_tokens = nn.Embed(
            self.config.vocab_size,
            self.config.hidden_size,
            embedding_init=jax.nn.initializers.xavier_normal(),
            dtype=self.dtype,
        )
        self.retnet_rel_pos = FlaxRetNetRelPos(self.config, dtype=self.dtype)

        self.layers = FlaxRetNetLayerCollection(self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        deterministic: bool = True,
        output_retentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        input_embeds = self.embed_tokens(input_ids)

        batch_size, sequence_length = input_embeds.shape[:2]
        retention_rel_pos = self.retnet_rel_pos(
            sequence_length,
            activate_recurrent=False,
            chunkwise_recurrent=self.config.attention_type == "chunkwise_recurrent",
        )

        outputs = self.layers(
            input_embeds,
            retention_rel_pos=retention_rel_pos,
            deterministic=deterministic,
            output_retentions=output_retentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            return tuple(v for v in outputs if v is not None)

        return FlaxBaseModelOutput(
            last_hidden_state=outputs[0],
            hidden_states=outputs[1],
            attentions=outputs[-1],
        )
|
|
class FlaxRetNetModel(FlaxRetNetPretrainedModel):
    module_class = FlaxRetNetModule
|
|
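# RetNet backbone with a language-modeling head: a bias-free Dense projection
# from the last hidden states to vocabulary logits.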
class FlaxRetNetForCausalLMModule(nn.Module):
    config: RetNetConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        self.transformer = FlaxRetNetModule(self.config, dtype=self.dtype)

        self.lm_head = nn.Dense(
            self.config.vocab_size,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )

    def __call__(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        deterministic: bool = True,
        output_retentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            deterministic=deterministic,
            output_retentions=output_retentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]

        lm_logits = self.lm_head(hidden_states)

        if not return_dict:
            return (lm_logits,) + outputs[1:]

        return FlaxCausalLMOutput(
            logits=lm_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
class FlaxRetNetForCausalLM(FlaxRetNetPretrainedModel):
    module_class = FlaxRetNetForCausalLMModule
|
|
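# Minimal smoke-test sketch (not part of the modeling API). It assumes that
# RetNetConfig accepts the keyword arguments referenced throughout this file
# (vocab_size, hidden_size, intermediate_size, num_hidden_layers,
# num_rettention_heads, recurrent_chunk_size, attention_type, ...) and that
# unspecified fields have usable defaults; adjust to the real config class.
if __name__ == "__main__":
    config = RetNetConfig(
        vocab_size=1000,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_rettention_heads=4,
        recurrent_chunk_size=4,
        attention_type="parallel",  # any value other than "chunkwise_recurrent"
    )
    model = FlaxRetNetForCausalLM(config, input_shape=(1, 8), seed=0)

    input_ids = jnp.ones((1, 8), dtype="i4")
    outputs = model(input_ids)
    print(outputs.logits.shape)  # expected: (1, 8, config.vocab_size)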