# Copyright 2022 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Class for T5 relative position biases. | |
Adapted from flaxformer.components.relative_position_biases.py | |
""" | |

from typing import Any, Callable, Optional

from flax import linen as nn
import gin
from jax import lax
import jax.numpy as jnp
import numpy as np

from transformer import position

Array = Any


@gin.configurable
class T5RelativePositionBiases(nn.Module):
  """Adds T5-style relative positional embeddings to the attention logits.

  See the usage sketch at the end of this file.

  Attributes:
    num_buckets: Number of buckets to bucket distances between key and query
      positions into.
    max_distance: Maximum distance before everything is lumped into the last
      distance bucket.
    num_heads: Number of heads in the attention layer. Each head will get a
      different relative position weighting.
    dtype: Type of arrays through this module.
    embedding_init: Initializer for the relative embedding table.
  """
  num_buckets: int
  max_distance: int
  num_heads: int
  dtype: Any
  embedding_init: Callable[..., Array] = nn.linear.default_embed_init

  @staticmethod
  def _relative_position_bucket(relative_position,
                                bidirectional=True,
                                num_buckets=32,
                                max_distance=128):
"""Translate relative position to a bucket number for relative attention. | |
The relative position is defined as memory_position - query_position, i.e. | |
the distance in tokens from the attending position to the attended-to | |
position. If bidirectional=False, then positive relative positions are | |
invalid. | |
We use smaller buckets for small absolute relative_position and larger | |
buckets for larger absolute relative_positions. All relative | |
positions >=max_distance map to the same bucket. All relative | |
positions <=-max_distance map to the same bucket. This should allow for | |
more graceful generalization to longer sequences than the model has been | |
trained on. | |
Args: | |
relative_position: an int32 array | |
bidirectional: a boolean - whether the attention is bidirectional | |
num_buckets: an integer | |
max_distance: an integer | |
Returns: | |
a Tensor with the same shape as relative_position, containing int32 | |
values in the range [0, num_buckets) | |
""" | |
    ret = 0
    n = -relative_position
    if bidirectional:
      num_buckets //= 2
      ret += (n < 0).astype(np.int32) * num_buckets
      n = np.abs(n)
    else:
      n = np.maximum(n, 0)
    # Now n is in the range [0, inf).
    max_exact = num_buckets // 2
    is_small = (n < max_exact)
    val_if_large = max_exact + (
        np.log(n.astype(np.float32) / max_exact + np.finfo(np.float32).eps) /
        np.log(max_distance / max_exact) *
        (num_buckets - max_exact)).astype(np.int32)
    val_if_large = np.minimum(val_if_large, num_buckets - 1)
    ret += np.where(is_small, n, val_if_large)
    return ret

  @nn.compact
  def __call__(self, num_queries, num_keys, offset: Optional[int] = None,
               bidirectional=True):
    """Produce relative position embedding attention biases.

    Args:
      num_queries: Number of queries.
      num_keys: Number of keys.
      offset: Offset of the first query with respect to the first key.
        (See position.relative_positions() for more info.)
      bidirectional: Whether to allow positive memory-query relative position
        embeddings.

    Returns:
      output: `(1, num_heads, num_queries, num_keys)` attention bias.
    """
    # Find the distance between each query and each key.
    # This is where this implementation differs from the T5 implementation:
    # this version lines the /last/ N queries up with the /last/ N keys
    # (which is appropriate for XL/sliding window), while the T5 version
    # lines up the /first/ N queries with the /first/ N keys, in cases where
    # the number of keys and queries differ.
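    # For example (illustrative), with num_queries=2 and num_keys=4 the two
    # queries are treated as the last two of the four key positions, so the
    # relative positions (key - query) are [[-2, -1, 0, 1], [-3, -2, -1, 0]];
    # see position.relative_positions_np for the exact behavior, including
    # the offset argument.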
    relative_position = position.relative_positions_np(
        num_queries=num_queries, num_keys=num_keys, offset=offset)
    rp_bucket = self._relative_position_bucket(
        relative_position,
        bidirectional=bidirectional,
        num_buckets=self.num_buckets,
        max_distance=self.max_distance)
    relative_attention_bias = self.param('rel_embedding', self.embedding_init,
                                         (self.num_heads, self.num_buckets),
                                         jnp.float32)
    relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype)
    # Instead of using a slow gather, we create a leading-dimension one-hot
    # array from rp_bucket and use it to perform the gather-equivalent via a
    # contraction, i.e.:
    # (num_heads, num_buckets) x (num_buckets one-hot, num_queries, num_keys).
    # This is equivalent to relative_attention_bias[:, rp_bucket].
    bcast_iota = lax.broadcasted_iota(jnp.int32, (self.num_buckets, 1, 1), 0)
    rp_bucket_one_hot = jnp.array(
        rp_bucket[jnp.newaxis, ...] == bcast_iota, dtype=self.dtype)
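    # rp_bucket_one_hot has shape (num_buckets, num_queries, num_keys); entry
    # [b, q, k] is 1 where rp_bucket[q, k] == b, and 0 otherwise.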
    # --> shape (num_heads, num_queries, num_keys)
    values = lax.dot_general(
        relative_attention_bias,
        rp_bucket_one_hot,
        (
            ((1,), (0,)),  # lhs, rhs contracting dims
            ((), ())))  # no batched dims
    # Add a singleton batch dimension.
    # --> shape (1, num_heads, num_queries, num_keys)
    out = values[jnp.newaxis, ...]
    return out
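

# Minimal usage sketch (illustrative only; assumes the imports above,
# including `transformer.position`, resolve in your environment). It
# initializes the (num_heads, num_buckets) 'rel_embedding' table and
# produces a (1, num_heads, num_queries, num_keys) attention bias.
if __name__ == '__main__':
  import jax

  rel_bias = T5RelativePositionBiases(
      num_buckets=32,
      max_distance=128,
      num_heads=8,
      dtype=jnp.float32)

  params = rel_bias.init(
      jax.random.PRNGKey(0), num_queries=128, num_keys=128)
  bias = rel_bias.apply(params, num_queries=128, num_keys=128)
  print(bias.shape)  # --> (1, 8, 128, 128)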