HugoVoxx committed
Commit a08ef52 · verified · 1 Parent(s): 9334a63

Update ag4masses/alphageometry/transformer_layer.py

ag4masses/alphageometry/transformer_layer.py CHANGED
@@ -23,14 +23,14 @@
  from typing import Callable, Mapping, NewType, Optional, Tuple
 
  from absl import logging
  import gin
  import jax
  import jax.numpy as jnp
- from meliad_lib.meliad.transformer import attention
- from meliad_lib.meliad.transformer import nn_components
- from meliad_lib.meliad.transformer import position
- from meliad_lib.meliad.transformer import transformer_layer
+ from aglib.meliad.transformer import attention
+ from aglib.meliad.transformer import nn_components
+ from aglib.meliad.transformer import position
+ from aglib.meliad.transformer import transformer_layer
 
  Array = jnp.ndarray
  DecoderState = NewType("DecoderState", Mapping[str, Array])
  WindowState = Optional[Tuple[attention.KVITuple, Array]]
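
The file touched by this commit keeps a per-example key/value cache for autoregressive decoding and advances it with vmapped dynamic slices, so each sequence in the batch can sit at a different write position. Below is a minimal standalone sketch of that pattern; the helper names and array sizes are illustrative, not the repository's code.

# Sketch of the vmapped KV-cache slice/update pattern used by
# transformer_layer.py (illustrative names and shapes, not the repo's code).
import jax
import jax.numpy as jnp


@jax.vmap
def write_token(cache: jnp.ndarray, token_kv: jnp.ndarray, idx: jnp.ndarray) -> jnp.ndarray:
    # Write one token's keys (or values) into the cache at a per-example index.
    return jax.lax.dynamic_update_slice_in_dim(cache, token_kv, idx, axis=0)


def read_window(window_length: int):
    @jax.vmap
    def fn(cache: jnp.ndarray, start: jnp.ndarray) -> jnp.ndarray:
        # Read a window_length slice of the cache starting at a per-example index.
        return jax.lax.dynamic_slice_in_dim(cache, start, window_length, axis=0)
    return fn


batch, num_keys, num_heads, head_dim, window = 2, 8, 1, 4, 4
cache = jnp.zeros((batch, num_keys, num_heads, head_dim))
token_kv = jnp.ones((batch, 1, num_heads, head_dim))   # one new token per example
write_idx = jnp.array([4, 6])                          # per-example write positions

cache = write_token(cache, token_kv, write_idx)        # store the new token
window_kv = read_window(window)(cache, write_idx - window)
print(window_kv.shape)                                 # -> (2, 4, 1, 4)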