Spaces:
Sleeping
Sleeping
File size: 5,961 Bytes
15bcbe6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
# Copyright 2022 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Class for T5 relative position biases.
Adapted from flaxformer.components.relative_position_biases.py
"""
from typing import Any, Callable, Optional
from flax import linen as nn
import gin
from jax import lax
import jax.numpy as jnp
from transformer import position
import numpy as np
Array = Any
@gin.configurable
class T5RelativePositionBiases(nn.Module):
"""Adds T5-style relative positional embeddings to the attention logits.
Attributes:
num_buckets: Number of buckets to bucket distances between key and query
positions into.
max_distance: Maximum distance before everything is lumped into the last
distance bucket.
num_heads: Number of heads in the attention layer. Each head will get a
different relative position weighting.
dtype: Type of arrays through this module.
embedding_init: initializer for relative embedding table.
"""
num_buckets: int
max_distance: int
num_heads: int
dtype: Any
embedding_init: Callable[..., Array] = nn.linear.default_embed_init
@staticmethod
def _relative_position_bucket(relative_position,
bidirectional=True,
num_buckets=32,
max_distance=128):
"""Translate relative position to a bucket number for relative attention.
The relative position is defined as memory_position - query_position, i.e.
the distance in tokens from the attending position to the attended-to
position. If bidirectional=False, then positive relative positions are
invalid.
We use smaller buckets for small absolute relative_position and larger
buckets for larger absolute relative_positions. All relative
positions >=max_distance map to the same bucket. All relative
positions <=-max_distance map to the same bucket. This should allow for
more graceful generalization to longer sequences than the model has been
trained on.
Args:
relative_position: an int32 array
bidirectional: a boolean - whether the attention is bidirectional
num_buckets: an integer
max_distance: an integer
Returns:
a Tensor with the same shape as relative_position, containing int32
values in the range [0, num_buckets)
"""
ret = 0
n = -relative_position
if bidirectional:
num_buckets //= 2
ret += (n < 0).astype(np.int32) * num_buckets
n = np.abs(n)
else:
n = np.maximum(n, 0)
# now n is in the range [0, inf)
max_exact = num_buckets // 2
is_small = (n < max_exact)
val_if_large = max_exact + (
np.log(n.astype(np.float32) / max_exact + np.finfo(np.float32).eps) /
np.log(max_distance / max_exact) *
(num_buckets - max_exact)).astype(np.int32)
val_if_large = np.minimum(val_if_large, num_buckets - 1)
ret += np.where(is_small, n, val_if_large)
return ret
@nn.compact
def __call__(self, num_queries, num_keys, offset: Optional[int]=None,
bidirectional=True):
"""Produce relative position embedding attention biases.
Args:
num_queries: Number of queries.
num_keys: Number of keys.
offset: Offset of the first query with respect to the first key.
(See position.relative_positions() for more info.)
bidirectional: whether to allow positive memory-query relative position
embeddings.
Returns:
output: `(1, num_heads, num_queries, num_keys)` attention bias
"""
# Find the distance between each query and each key.
# This is where this implementation differs from the T5 implementation;
# this version lines the /last/ N queries up with the /last/ N keys,
# (which is appropriate for XL/sliding window) while the T5 version lines
# up the /first/ N queries with the first N keys, in cases where the
# number of keys and queries differ.
relative_position = position.relative_positions_np(
num_queries=num_queries, num_keys=num_keys, offset=offset)
rp_bucket = self._relative_position_bucket(
relative_position,
bidirectional=bidirectional,
num_buckets=self.num_buckets,
max_distance=self.max_distance)
relative_attention_bias = self.param('rel_embedding', self.embedding_init,
(self.num_heads, self.num_buckets),
jnp.float32)
relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype)
# Instead of using a slow gather, we create a leading-dimension one-hot
# array from rp_bucket and use it to perform the gather-equivalent via a
# contraction, i.e.:
# (num_head, num_buckets) x (num_buckets one-hot, num_queries, num_keys).
# This is equivalent to relative_attention_bias[:, rp_bucket]
bcast_iota = lax.broadcasted_iota(jnp.int32, (self.num_buckets, 1, 1), 0)
rp_bucket_one_hot = jnp.array(
rp_bucket[jnp.newaxis, ...] == bcast_iota, dtype=self.dtype)
# --> shape (num_queries, num_keys, num_heads)
values = lax.dot_general(
relative_attention_bias,
rp_bucket_one_hot,
(
((1,), (0,)), # rhs, lhs contracting dims
((), ()))) # no batched dims
# Add a singleton batch dimension.
# --> shape (1, num_heads, num_queries, num_keys)
out = values[jnp.newaxis, ...]
return out
|