Update ag4masses/alphageometry/transformer_layer.py
ag4masses/alphageometry/transformer_layer.py
CHANGED
@@ -1,526 +1,526 @@

Only the import block at file lines 29-32 changes between the two versions; the rest of the file is identical. The four removed import statements appear truncated in this view:

-from
-from
-from
-from

They are replaced by:

+from aglib.meliad.transformer import attention
+from aglib.meliad.transformer import nn_components
+from aglib.meliad.transformer import position
+from aglib.meliad.transformer import transformer_layer

The full file after the update:
# Copyright 2023 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""A single transformer layer in inference mode.

Modified from
https://github.com/google-research/meliad/blob/main/transformer/transformer_layer.py
to accommodate sequence packing + kv cache + relative position during test time.
"""

from typing import Callable, Mapping, NewType, Optional, Tuple

from absl import logging
import gin
import jax
import jax.numpy as jnp
from aglib.meliad.transformer import attention
from aglib.meliad.transformer import nn_components
from aglib.meliad.transformer import position
from aglib.meliad.transformer import transformer_layer

Array = jnp.ndarray
DecoderState = NewType("DecoderState", Mapping[str, Array])
WindowState = Optional[Tuple[attention.KVITuple, Array]]


@jax.vmap
def update_slice_in_dim_1(array: Array, update: Array, idx: Array) -> Array:
  """Update a stored keys/values slice for different-lengthed seqs in batch."""
  return jax.lax.dynamic_update_slice_in_dim(array, update, idx, axis=0)


def slice_in_dim_1(window_length: int) -> Callable[[Array, Array], Array]:
  @jax.vmap
  def fn(array: Array, idx: Array) -> Array:
    return jax.lax.dynamic_slice_in_dim(array, idx, window_length, axis=0)

  return fn

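An aside, not part of the file: the two helpers above wrap jax.lax.dynamic_update_slice_in_dim and jax.lax.dynamic_slice_in_dim in jax.vmap, so each sequence in a packed batch can read and write its KV cache at its own index. A minimal sketch of their behavior, assuming the helpers above are in scope and using purely illustrative shapes:

# Illustrative only: batch of 2 sequences, cache of 6 key slots, feature dim 1.
cache = jnp.zeros((2, 6, 1))
token = jnp.ones((2, 1, 1))          # one new key/value per sequence
idx = jnp.array([2, 4])              # per-sequence write positions

cache = update_slice_in_dim_1(cache, token, idx)   # writes at index 2 and 4
window = slice_in_dim_1(3)(cache, idx - 2)         # per-sequence 3-slot window
# window.shape == (2, 3, 1)
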
@gin.configurable
class TransformerLayerGenerate(transformer_layer.TransformerLayer):
  """Full transformer layer, with attention."""

  def _next_decoder_state(
      self, decoder_state: DecoderState, keys: Array, values: Array
  ) -> Tuple[DecoderState, Array, Array]:
    """Compute the next decoder state, and return keys,values to attend to.

    The keys,values returned from this function are drawn from the prior
    decoding state, and comprise a full window of local context.

    Args:
      decoder_state: The current decoder state, initially created using
        init_decoder_state().
      keys: The key for the current token, of shape (batch_size, 1, dim)
      values: The value for the current token of shape (batch_size, 1, dim)

    Returns:
      (next_decoder_state,
       window of keys of shape (batch_size, window_length, dim),
       window of values of shape (batch_size, window_length, dim))
    """

    assert keys.shape[1] == 1  # single-token autoregressive decoding.

    # Unpack decoder_state
    stored_keys = decoder_state["keys"]
    stored_values = decoder_state["values"]
    curr_index = decoder_state["current_index"]

    # Slice to get window_length-sized chunk of previous keys,values.
    out_decoder_state = {}
    curr_win_index = curr_index - self.window_length

    # out_keys = jax.lax.dynamic_slice_in_dim(
    #     stored_keys, curr_win_index, self.window_length, axis=1)
    out_keys = slice_in_dim_1(self.window_length)(stored_keys, curr_win_index)

    # out_values = jax.lax.dynamic_slice_in_dim(
    #     stored_values, curr_win_index, self.window_length, axis=1)
    out_values = slice_in_dim_1(self.window_length)(
        stored_values, curr_win_index
    )

    # Write current keys,values to stored keys, values.
    # stored_keys = jax.lax.dynamic_update_slice_in_dim(
    #     stored_keys, keys, curr_index, axis=1)
    stored_keys = update_slice_in_dim_1(stored_keys, keys, curr_index)
    # stored_values = jax.lax.dynamic_update_slice_in_dim(
    #     stored_values, values, curr_index, axis=1)
    stored_values = update_slice_in_dim_1(stored_values, values, curr_index)
    curr_index = curr_index + 1

    # Pack a new decoder_state object.
    out_decoder_state["keys"] = stored_keys
    out_decoder_state["values"] = stored_values
    out_decoder_state["current_index"] = curr_index
    out_decoder_state["relative_position_bias"] = decoder_state[
        "relative_position_bias"
    ]
    out_decoder_state["recurrent_kvq"] = decoder_state["recurrent_kvq"]

    return (DecoderState(out_decoder_state), out_keys, out_values)

  def __call__(
      self,
      xs: Array,
      start_of_sequence: Array,
      *,
      importance: Optional[Array] = None,
      cross_attention_kv: Optional[Tuple[Array, Array]] = None,
      window_state: Optional[WindowState] = None,
      decoder_state: Optional[DecoderState] = None,
  ):
    """Computes attention over a sequence of inputs.

    Args:
      xs: input sequence of shape (batch_size, sequence_length, num_hidden)
      start_of_sequence: An input array of shape (batch_size) --- The following
        must be passed by keyword only. ---
      importance: Array of shape (batch_size, sequence_length). An importance
        bias for attention.
      cross_attention_kv: Keys and values from encoder for cross-attention.
      window_state: State object which contains context from the prior window
        when using a transformer-XL or sliding window. Initially created with
        load_window_state().
      decoder_state: State object for autoregressive decoding, initially
        created from init_decoder_state().

    Returns:
      (ys: outputs of shape (batch_size, sequence_length, num_hidden),
       importance_score: importance score for the next layer,
       next_window_state: state to pass to the next window,
       next_decoder_state: next decoder state for autoregressive decoding,
       viz_dict: dictionary of visualizations
      )
    """

    xs = jnp.asarray(xs, dtype=self.dtype)
    logging.info("tlayer: recurrent = %r", self.recurrent_attention)
    logging.info("tlayer: compute_importance = %r", self.compute_importance)

    is_training = self.mode == "train"

    # Compute keys, values and queries.
    # ---------------------------------
    logging.info("tlayer: compute keys,values,queries.")
    (keys, values, queries, queries2) = self.tbase.kvq(xs)
    attention_scale_factors = self.tbase.attention_scale_factors()
    (_, sequence_length, num_heads, _) = queries.shape  # (b, k, h, d)

    # Get biases and masks that are shared across windows.
    # ----------------------------------------------------
    if decoder_state is not None:
      logging.info("tlayer: using autoregressive decoder.")
      # When decoding, prior keys,values are loaded from the decoder state.
      # Other values are precomputed, and loaded from the decoder state.
      # The decoder state will be updated with the current token.
      assert window_state is None

      prev_kvi = None
      recurrent_state = None  # Use precomputed recurrent_kvq.
      cross_attention_kv = None
      rel_position_bias = decoder_state["relative_position_bias"]
      causal_mask = None
      dropout_multiplier = None

      # Reuse cached recurrent keys,values for each token.
      cached_recurrent_kvq = decoder_state["recurrent_kvq"]
      if cached_recurrent_kvq is not None:
        assert cross_attention_kv is None
        cross_attention_kv = (cached_recurrent_kvq[0], cached_recurrent_kvq[1])
      del cached_recurrent_kvq

      # Get a full window of keys,values and update decoder state.
      (decoder_state, keys, values) = self._next_decoder_state(
          decoder_state, keys, values
      )

      # Each query attends to window_length prior keys.
      assert keys.shape[1] == self.window_length
      kq_relative_offset = self.window_length

      if not self.use_long_xl_architecture:
        kqpos = position.relative_positions(
            1, self.window_length, offset=0
        )  # 2D mask
        current_idx = decoder_state["current_index"]

        # add (batch, heads) dims for kqpos
        kqpos = jnp.expand_dims(kqpos, axis=(0, 1))
        kqpos = jnp.tile(kqpos, (1, self.num_heads, 1, 1))

        # add (_, heads, _) dim for current_idx
        current_idx = jnp.expand_dims(current_idx, axis=(1, 2, 3))

        causal_mask = kqpos > self.window_length * 2 - current_idx
    else:
      logging.info("tlayer: windowed attention.")
      # When training, attention is done using windows or chunks, and prior
      # context (e.g. keys,values from the previous window) is stored in the
      # window_state object.
      (prev_kvi, recurrent_state) = (
          window_state  # pytype: disable=attribute-error
      )

      # Get the size of the sliding window for pos bias, dropout, & causal mask.
      (num_queries, num_keys) = attention.sliding_attention_window_shape(
          (keys, values, importance),
          prev_kvi,
          queries,
          window_length=self.window_length,
      )
      kq_relative_offset = num_keys - num_queries

      # Get the relative position bias.
      # The bias doesn't depend on the query content, and so can be precomputed.
      if self.relative_positions is not None:
        rel_position_bias = self.relative_positions(
            num_queries, num_keys, bidirectional=False
        )
      else:
        rel_position_bias = None

      # Get causal mask.
      if self.use_causal_mask:
        causal_mask = position.causal_mask(
            num_queries, num_keys, window_length=self.window_length
        )
      else:
        causal_mask = None

      # Apply dropout to the attention matrix.
      # The mask will be broadcast across batches and windows.
      if self.attn_dropout_rate > 0.0 and is_training:
        dropout_rng = self.make_rng("dropout")
        attn_shape = (self.num_heads, num_queries, num_keys)
        dropout_multiplier = nn_components.dropout_multiplier_mask(
            dropout_rng, self.attn_dropout_rate, attn_shape, self.dtype
        )
      else:
        dropout_multiplier = None

    # Load and store values into external memory, if memory is not None.
    # ------------------------------------------------------------------
    (mode, _, update_memory) = self._get_cache_name_from_mode(self.mode)
    external_kv = self._query_external_memory(
        keys,
        values,
        queries,
        start_of_sequence=start_of_sequence,
        mode=mode,
        update_memory=decoder_state is None and update_memory,
    )

    if (
        self.memory is not None
        and self.memory_combine_with_local == "TRAINABLE_WEIGHTED_MEAN"
    ):
      external_memory_bias = jnp.asarray(self.memory_bias, dtype=self.dtype)
      external_memory_bias = jnp.reshape(
          external_memory_bias, (1, 1, num_heads, 1)
      )
      external_memory_bias = jax.nn.sigmoid(external_memory_bias)
    else:
      external_memory_bias = None

    # Compute the number of windows.
    # ------------------------------
    if sequence_length < self.window_length:
      num_windows = 1  # Happens with autoregressive decoding.
    elif sequence_length == self.window_length:
      num_windows = 1
      if self.use_long_xl_architecture:
        assert prev_kvi is not None
    else:
      if not self.use_long_xl_architecture:
        raise ValueError("Can only use sliding window with Transformer XL.")
      num_windows = sequence_length // self.window_length
      if (num_windows * self.window_length) != sequence_length:
        raise ValueError(
            f"Window length {self.window_length} must be a "
            + f"multiple of sequence length {sequence_length}"
        )
    logging.info("tlayer: num_windows = %d.", num_windows)

    # Define the function to do attention within a single window.
    # ---------------------------------------------------------
    def single_window_attention(
        carry: tuple[Array, Array], inputs_w: tuple[Array, Array]
    ) -> tuple[tuple[Array, Array], tuple[Array, Array]]:
      # This function uses the following variables from the outer scope.
      # They are listed here for clarity.
      nonlocal rel_position_bias
      nonlocal causal_mask
      nonlocal kq_relative_offset
      nonlocal dropout_multiplier
      nonlocal attention_scale_factors
      nonlocal external_memory_bias
      nonlocal cross_attention_kv  # externally supplied.

      # keys,values,queries over the whole sequence will be split into chunks.
      # xs_w, kvqi_w, etc. are the chunk for the current window.
      (prev_kvi_w, rec_state) = carry  # carried from one window to the next.
      (kvqi_w, external_kv_w) = inputs_w  # inputs to the current window.
      # (keys_curr_w, values_curr_w, _, _, importance_curr_w) = kvqi_w

      # Concatenate keys,values from the previous window with the current
      # window to implement sliding window attention.
      (kvqi_w, next_kvi_w) = attention.concat_kvqi(kvqi_w, prev_kvi_w)
      (keys_w, values_w, queries_w, queries2_w, importance_w) = kvqi_w

      # Perform recurrent attention within the current window to get the next
      # recurrent state, and set up cross attention.
      if rec_state is not None:
        logging.info("tlayer: recurrent attention.")

        # NOTE -- recurrent states and input tokens are handled separately,
        # because they have separate learned positional embeddings. Due to
        # the way TransformerBase does cross-attention, this means that we use
        # separate key,value layers for rec_state and tokens_w.

        # Keys, values, queries from recurrent state.
        logging.info("tlayer: recurrent kvq.")
        rec_kvq = self.recurrent_tbase.kvq(rec_state)
        r_scale_factors = self.recurrent_tbase.attention_scale_factors()
        (r_keys, r_values, r_queries, r_queries2) = rec_kvq

        # Joint attention over both recurrent states and input tokens.
        logging.info("tlayer: recurrent self-attention.")
        r_attn_ys = attention.simple_attention(
            r_keys,
            r_values,
            r_queries,
            None,
            scale_factor=r_scale_factors[0],
            dtype=self.dtype,
        )

        logging.info("tlayer: recurrent cross-attention.")
        r_cross_attn_ys = attention.simple_attention(
            keys_w,
            values_w,
            r_queries2,
            importance_w,
            scale_factor=r_scale_factors[1],
            dtype=self.dtype,
        )

        # Recurrent post-attention FFN.
        logging.info("tlayer: recurrent ffn.")
        next_rec_state = self.recurrent_tbase.post_attn_ffn(
            rec_state, r_attn_ys, r_cross_attn_ys
        )

        # Get keys and values for cross-attention from recurrent state.
        assert cross_attention_kv is None
        local_cross_attention_kv = (r_keys, r_values)
      else:
        # Get keys and values for cross-attention from external argument.
        next_rec_state = None
        local_cross_attention_kv = cross_attention_kv

      # If using RoPE, keys and queries are rotated before self-attention.
      if self.relative_position_type == "rotary":
        logging.info(
            "Using rotary position encodings (RoPE), offset = %d",
            kq_relative_offset,
        )
        (keys_w, queries_w) = position.rotate_kq(
            keys_w, queries_w, max_wavelength=10_000, offset=kq_relative_offset
        )

      # Self-attention over input tokens.
      logging.info("tlayer: self-attention.")
      attn_ys_w = attention.simple_attention(
          keys_w,
          values_w,
          queries_w,
          importance_w,
          relative_position_bias=rel_position_bias,
          scale_factor=attention_scale_factors[0],
          causal_mask=causal_mask,
          dropout_multiplier=dropout_multiplier,
          dtype=self.dtype,
      )

      # Attention over external memory.
      if external_kv_w is not None:
        (external_keys_w, external_values_w) = external_kv_w
        y_ext = attention.external_attention(
            external_keys_w,
            external_values_w,
            queries_w,
            scale_factor=attention_scale_factors[0],
        )
        if external_memory_bias is not None:
          ebias = external_memory_bias
          attn_ys_w = (attn_ys_w * (1 - ebias)) + (y_ext * ebias)
        elif self.memory_combine_with_local == "ADD":
          attn_ys_w += y_ext
        elif self.memory_combine_with_local == "STOP_FORWARD":
          attn_ys_w = y_ext + (attn_ys_w - jax.lax.stop_gradient(attn_ys_w))
        else:
          raise ValueError(
              f"Unexpected setting: {self.memory_combine_with_local = }"
          )

      # Cross attention from input tokens to encoder or recurrent state.
      if local_cross_attention_kv is not None:
        logging.info("tlayer: cross-attention.")
        (c_keys, c_values) = local_cross_attention_kv

        # Cross-attention using queries2.
        cross_attn_ys_w = attention.simple_attention(
            c_keys,
            c_values,
            queries2_w,
            None,
            scale_factor=attention_scale_factors[1],
            dtype=self.dtype,
        )
      else:
        cross_attn_ys_w = None

      # End function single_window_attention(...)
      return ((next_kvi_w, next_rec_state), (attn_ys_w, cross_attn_ys_w))

    # Initialize recurrent_tbase before calling jax.lax.scan.
    # Otherwise flax will throw a tantrum.
    if (
        self.recurrent_attention
        and 0 <= self.max_unrolled_windows
        and self.max_unrolled_windows < num_windows
    ):
      logging.info("tlayer: force initialization of recurrent_tbase.")
      self.recurrent_tbase.force_init(recurrent_state)

    # Perform sliding window attention over all keys,values,queries.
    # --------------------------------------------------------------
    initial_carry = (prev_kvi, recurrent_state)  # window state.
    kvqi = (keys, values, queries, queries2, importance)
    attn_inputs = (kvqi, external_kv)
    (next_carry, attn_outputs) = attention.split_and_scan(
        single_window_attention,
        initial_carry,
        attn_inputs,
        sections=num_windows,
        axis=1,
        max_unrolled_windows=self.max_unrolled_windows,
    )
    (attn_ys, cross_attn_ys) = attn_outputs

    logging.info("tlayer: End windows.")

    # Post-attention MLP, resnet, and FFN.
    # ------------------------------------
    logging.info("tlayer: final FFN.")
    ys = self.tbase.post_attn_ffn(xs, attn_ys, cross_attn_ys)

    # Compute importance scores for each token if requested.
    if self.compute_importance:
      (batch_size, sequence_length, _) = ys.shape
      importance_score = self.importance_layer(ys)
      importance_score = importance_score.reshape((batch_size, sequence_length))
    else:
      importance_score = None

    next_window_state = next_carry if window_state is not None else None
    viz_dict = {}  # Visualizations, not currently enabled.
    return (ys, importance_score, next_window_state, decoder_state, viz_dict)

  def init_decoder_state_vanilla(
      self, sequence_length: int, start_of_sequence: Array
  ) -> DecoderState:
    """Initialize decoder state for autoregressive generation.

    Args:
      sequence_length: The maximum length of the sequence to generate.
      start_of_sequence: Array of boolean of shape (batch_size,) True if
        starting a new sequence (with no prefix).

    Returns:
      A state object that can be passed to __call__.
    """

    if not self.use_causal_mask:
      raise ValueError("Generator must have been trained with a causal mask.")

    # Get relative position bias.
    rel_position_bias = self.relative_positions(
        1, self.window_length, offset=self.window_length, bidirectional=False
    )
    rel_position_bias = jnp.tile(rel_position_bias, (self.batch_size, 1, 1, 1))

    # Initialize autoregressive storage for (key, value) pairs.
    # Include space for a prefix of window_length tokens.
    num_keys = sequence_length + self.window_length
    stored_shape = (self.batch_size, num_keys, self.num_heads, self.head_size)
    stored_keys = jnp.zeros(stored_shape, dtype=self.dtype)
    stored_values = jnp.zeros(stored_shape, dtype=self.dtype)

    recurrent_kvq = None
    current_index = jnp.array([self.window_length] * self.batch_size)

    decoder_state_dict = {
        "keys": stored_keys,
        "values": stored_values,
        "current_index": current_index,
        "relative_position_bias": rel_position_bias,
        "recurrent_kvq": recurrent_kvq,
    }
    return DecoderState(decoder_state_dict)
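
For orientation, a sketch (not from the repository) of how the decoding path above is meant to be driven: init_decoder_state_vanilla() allocates the per-layer KV cache with room for a window_length prefix, and each subsequent __call__ with decoder_state set consumes exactly one token per sequence and returns the updated state as the fourth element of its output tuple. The parent module, its tlayer attribute, and the precomputed token embeddings below are assumptions made for illustration.

# Hypothetical method on a parent flax module that owns a
# TransformerLayerGenerate submodule as `self.tlayer`; assumes the
# module-level imports above (jnp) are in scope.
def generate(self, token_embeddings, start_of_sequence, max_len: int):
  # Allocate the KV cache; it reserves window_length slots for a prefix.
  dstate = self.tlayer.init_decoder_state_vanilla(max_len, start_of_sequence)

  outputs = []
  for t in range(max_len):
    x_t = token_embeddings[:, t:t + 1, :]   # one token per step: (b, 1, d)
    (y_t, _, _, dstate, _) = self.tlayer(
        x_t,
        start_of_sequence,
        decoder_state=dstate,               # window_state stays None here.
    )
    outputs.append(y_t)
  return jnp.concatenate(outputs, axis=1)   # (b, max_len, d)

In a real sampling loop the next x_t would come from embedding the previously generated token rather than from a precomputed array; the loop above only illustrates how the decoder state is threaded through repeated single-token calls.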