Spaces:
Edmond98
/
Running on TPU v5e

Afrinetwork7 commited on
Commit
03f9951
1 Parent(s): 42573d3

Update whisper_jax/layers.py

Browse files
Files changed (1) hide show
  1. whisper_jax/layers.py +1 -66
whisper_jax/layers.py CHANGED
@@ -56,16 +56,6 @@ NdInitializer = Callable[[PRNGKey, Shape, DType, InitializerAxis, InitializerAxi
56
  default_embed_init = nn.initializers.variance_scaling(1.0, "fan_in", "normal", out_axis=0)
57
 
58
 
59
- def nd_dense_init(scale, mode, distribution):
60
- """Initializer with in_axis, out_axis set at call time."""
61
-
62
- def init_fn(key, shape, dtype, in_axis, out_axis):
63
- fn = variance_scaling(scale, mode, distribution, in_axis, out_axis)
64
- return fn(key, shape, dtype)
65
-
66
- return init_fn
67
-
68
-
69
  def dot_product_attention(
70
  query: Array,
71
  key: Array,
@@ -78,11 +68,9 @@ def dot_product_attention(
78
  float32_logits: bool = False,
79
  ):
80
  """Computes dot-product attention given query, key, and value.
81
-
82
  This is the core function for applying attention based on
83
  https://arxiv.org/abs/1706.03762. It calculates the attention weights given
84
  query and key and combines the values using the attention weights.
85
-
86
  Args:
87
  query: queries for calculating attention with shape of `[batch, q_length,
88
  num_heads, qk_depth_per_head]`.
@@ -99,7 +87,6 @@ def dot_product_attention(
99
  dtype: the dtype of the computation (default: float32)
100
  float32_logits: bool, if True then compute logits in float32 to avoid
101
  numerical issues with bfloat16.
102
-
103
  Returns:
104
  Output of shape `[batch, length, num_heads, v_depth_per_head]`.
105
  """
@@ -145,7 +132,6 @@ dynamic_vector_slice_in_dim = jax.vmap(lax.dynamic_slice_in_dim, in_axes=(None,
145
 
146
  class MultiHeadDotProductAttention(nn.Module):
147
  """Multi-head dot-product attention.
148
-
149
  Attributes:
150
  num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
151
  should be divisible by the number of heads.
@@ -176,22 +162,18 @@ class MultiHeadDotProductAttention(nn.Module):
176
  deterministic: bool = False,
177
  ) -> Array:
178
  """Applies multi-head dot product attention on the input data.
179
-
180
  Projects the inputs into multi-headed query, key, and value vectors,
181
  applies dot-product attention and project the results to an output vector.
182
-
183
  There are two modes: decoding and non-decoding (e.g., training). The mode is
184
  determined by `decode` argument. For decoding, this method is called twice,
185
  first to initialize the cache and then for an actual decoding process. The
186
  two calls are differentiated by the presence of 'cached_key' in the variable
187
  dict. In the cache initialization stage, the cache variables are initialized
188
  as zeros and will be filled in the subsequent decoding process.
189
-
190
  In the cache initialization call, `inputs_q` has a shape [batch, length,
191
  q_features] and `inputs_kv`: [batch, length, kv_features]. During the
192
  incremental decoding stage, query, key and value all have the shape [batch,
193
  1, qkv_features] corresponding to a single step.
194
-
195
  Args:
196
  inputs_q: input queries of shape `[batch, q_length, q_features]`.
197
  inputs_kv: key/values of shape `[batch, kv_length, kv_features]`.
@@ -199,7 +181,6 @@ class MultiHeadDotProductAttention(nn.Module):
199
  bias: attention bias of shape `[batch, num_heads, q_length, kv_length]`.
200
  decode: Whether to prepare and use an autoregressive cache.
201
  deterministic: Disables dropout if set to True.
202
-
203
  Returns:
204
  output of shape `[batch, length, q_features]`.
205
  """
@@ -360,7 +341,6 @@ def _canonicalize_tuple(x):
360
  # ------------------------------------------------------------------------------
361
  class DenseGeneral(nn.Module):
362
  """A linear transformation (without bias) with flexible axes.
363
-
364
  Attributes:
365
  features: tuple with numbers of output features.
366
  axis: tuple with axes to apply the transformation on.
@@ -380,10 +360,8 @@ class DenseGeneral(nn.Module):
380
  @nn.compact
381
  def __call__(self, inputs: Array) -> Array:
382
  """Applies a linear transformation to the inputs along multiple dimensions.
383
-
384
  Args:
385
  inputs: The nd-array to be transformed.
386
-
387
  Returns:
388
  The transformed input.
389
  """
@@ -432,7 +410,6 @@ def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Calla
432
 
433
  class MlpBlock(nn.Module):
434
  """Transformer MLP / feed-forward block.
435
-
436
  Attributes:
437
  intermediate_dim: Shared dimension of hidden layers.
438
  activations: Type of activations for each layer. Each element is either
@@ -482,7 +459,6 @@ class MlpBlock(nn.Module):
482
 
483
  class Embed(nn.Module):
484
  """A parameterized function from integers [0, n) to d-dimensional vectors.
485
-
486
  Attributes:
487
  num_embeddings: number of embeddings.
488
  features: number of feature dimensions for each embedding.
@@ -513,10 +489,8 @@ class Embed(nn.Module):
513
 
514
  def __call__(self, inputs: Array) -> Array:
515
  """Embeds the inputs along the last dimension.
516
-
517
  Args:
518
  inputs: input data, all dimensions are considered batch dimensions.
519
-
520
  Returns:
521
  Output which is embedded input data. The output shape follows the input,
522
  with an additional `features` dimension appended.
@@ -536,11 +510,9 @@ class Embed(nn.Module):
536
 
537
  def attend(self, query: Array) -> Array:
538
  """Attend over the embedding using a query array.
539
-
540
  Args:
541
  query: array with last dimension equal the feature depth `features` of the
542
  embedding.
543
-
544
  Returns:
545
  An array with final dim `num_embeddings` corresponding to the batched
546
  inner-product of the array of query vectors against each embedding.
@@ -553,7 +525,6 @@ class Embed(nn.Module):
553
 
554
  class RelativePositionBiases(nn.Module):
555
  """Adds T5-style relative positional embeddings to the attention logits.
556
-
557
  Attributes:
558
  num_buckets: Number of buckets to bucket distances between key and query
559
  positions into.
@@ -574,7 +545,6 @@ class RelativePositionBiases(nn.Module):
574
  @staticmethod
575
  def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
576
  """Translate relative position to a bucket number for relative attention.
577
-
578
  The relative position is defined as memory_position - query_position, i.e.
579
  the distance in tokens from the attending position to the attended-to
580
  position. If bidirectional=False, then positive relative positions are
@@ -585,13 +555,11 @@ class RelativePositionBiases(nn.Module):
585
  positions <=-max_distance map to the same bucket. This should allow for
586
  more graceful generalization to longer sequences than the model has been
587
  trained on.
588
-
589
  Args:
590
  relative_position: an int32 array
591
  bidirectional: a boolean - whether the attention is bidirectional
592
  num_buckets: an integer
593
  max_distance: an integer
594
-
595
  Returns:
596
  a Tensor with the same shape as relative_position, containing int32
597
  values in the range [0, num_buckets)
@@ -619,13 +587,11 @@ class RelativePositionBiases(nn.Module):
619
  @nn.compact
620
  def __call__(self, qlen, klen, bidirectional=True):
621
  """Produce relative position embedding attention biases.
622
-
623
  Args:
624
  qlen: attention query length.
625
  klen: attention key length.
626
  bidirectional: whether to allow positive memory-query relative position
627
  embeddings.
628
-
629
  Returns:
630
  output: `(1, len, q_len, k_len)` attention bias
631
  """
@@ -749,11 +715,9 @@ def make_attention_mask(
749
  dtype: DType = jnp.float32,
750
  ) -> Array:
751
  """Mask-making helper for attention weights.
752
-
753
  In case of 1d inputs (i.e., `[batch, len_q]`, `[batch, len_kv]`, the
754
  attention weights will be `[batch, heads, len_q, len_kv]` and this
755
  function will produce `[batch, 1, len_q, len_kv]`.
756
-
757
  Args:
758
  query_input: a batched, flat input of query_length size
759
  key_input: a batched, flat input of key_length size
@@ -761,7 +725,6 @@ def make_attention_mask(
761
  extra_batch_dims: number of extra batch dims to add singleton axes for, none
762
  by default
763
  dtype: mask return dtype
764
-
765
  Returns:
766
  A `[batch, 1, len_q, len_kv]` shaped mask for 1d attention.
767
  """
@@ -781,21 +744,17 @@ def make_attention_mask(
781
 
782
  def make_causal_mask(x: Array, extra_batch_dims: int = 0, dtype: DType = jnp.float32) -> Array:
783
  """Make a causal mask for self-attention.
784
-
785
  In case of 1d inputs (i.e., `[batch, len]`, the self-attention weights
786
  will be `[batch, heads, len, len]` and this function will produce a
787
  causal mask of shape `[batch, 1, len, len]`.
788
-
789
  Note that a causal mask does not depend on the values of x; it only depends on
790
  the shape. If x has padding elements, they will not be treated in a special
791
  manner.
792
-
793
  Args:
794
  x: input array of shape `[batch, len]`
795
  extra_batch_dims: number of batch dims to add singleton axes for, none by
796
  default
797
  dtype: mask return dtype
798
-
799
  Returns:
800
  A `[batch, 1, len, len]` shaped causal mask for 1d attention.
801
  """
@@ -805,11 +764,9 @@ def make_causal_mask(x: Array, extra_batch_dims: int = 0, dtype: DType = jnp.flo
805
 
806
  def combine_masks(*masks: Optional[Array], dtype: DType = jnp.float32):
807
  """Combine attention masks.
808
-
809
  Args:
810
  *masks: set of attention mask arguments to combine, some can be None.
811
  dtype: final mask dtype
812
-
813
  Returns:
814
  Combined mask, reduced by logical and, returns None if no masks given.
815
  """
@@ -827,10 +784,8 @@ def combine_masks(*masks: Optional[Array], dtype: DType = jnp.float32):
827
 
828
  def combine_biases(*masks: Optional[Array]):
829
  """Combine attention biases.
830
-
831
  Args:
832
  *masks: set of attention bias arguments to combine, some can be None.
833
-
834
  Returns:
835
  Combined mask, reduced by summation, returns None if no masks given.
836
  """
@@ -853,40 +808,30 @@ def make_decoder_mask(
853
  decoder_segment_ids: Optional[Array] = None,
854
  ) -> Array:
855
  """Compute the self-attention mask for a decoder.
856
-
857
  Decoder mask is formed by combining a causal mask, a padding mask and an
858
  optional packing mask. If decoder_causal_attention is passed, it makes the
859
  masking non-causal for positions that have value of 1.
860
-
861
  A prefix LM is applied to a dataset which has a notion of "inputs" and
862
  "targets", e.g., a machine translation task. The inputs and targets are
863
  concatenated to form a new target. `decoder_target_tokens` is the concatenated
864
  decoder output tokens.
865
-
866
  The "inputs" portion of the concatenated sequence can attend to other "inputs"
867
  tokens even for those at a later time steps. In order to control this
868
  behavior, `decoder_causal_attention` is necessary. This is a binary mask with
869
  a value of 1 indicating that the position belonged to "inputs" portion of the
870
  original dataset.
871
-
872
  Example:
873
-
874
  Suppose we have a dataset with two examples.
875
-
876
  ds = [{"inputs": [6, 7], "targets": [8]},
877
  {"inputs": [3, 4], "targets": [5]}]
878
-
879
  After the data preprocessing with packing, the two examples are packed into
880
  one example with the following three fields (some fields are skipped for
881
  simplicity).
882
-
883
  decoder_target_tokens = [[6, 7, 8, 3, 4, 5, 0]]
884
  decoder_segment_ids = [[1, 1, 1, 2, 2, 2, 0]]
885
  decoder_causal_attention = [[1, 1, 0, 1, 1, 0, 0]]
886
-
887
  where each array has [batch, length] shape with batch size being 1. Then,
888
  this function computes the following mask.
889
-
890
  mask = [[[[1, 1, 0, 0, 0, 0, 0],
891
  [1, 1, 0, 0, 0, 0, 0],
892
  [1, 1, 1, 0, 0, 0, 0],
@@ -894,14 +839,11 @@ def make_decoder_mask(
894
  [0, 0, 0, 1, 1, 0, 0],
895
  [0, 0, 0, 1, 1, 1, 0],
896
  [0, 0, 0, 0, 0, 0, 0]]]]
897
-
898
  mask[b, 1, :, :] represents the mask for the example `b` in the batch.
899
  Because mask is for a self-attention layer, the mask's shape is a square of
900
  shape [query length, key length].
901
-
902
  mask[b, 1, i, j] = 1 means that the query token at position i can attend to
903
  the key token at position j.
904
-
905
  Args:
906
  decoder_target_tokens: decoder output tokens. [batch, length]
907
  dtype: dtype of the output mask.
@@ -910,7 +852,6 @@ def make_decoder_mask(
910
  bidirectionally. [batch, length]
911
  decoder_segment_ids: decoder segmentation info for packed examples. [batch,
912
  length]
913
-
914
  Returns:
915
  the combined decoder mask.
916
  """
@@ -976,7 +917,6 @@ def _conv_dimension_numbers(input_shape):
976
 
977
  class _Conv(nn.Module):
978
  """Convolution Module wrapping `lax.conv_general_dilated[_local]`.
979
-
980
  Attributes:
981
  features: number of convolution filters.
982
  kernel_size: shape of the convolutional kernel. For 1D convolution,
@@ -1032,19 +972,16 @@ class _Conv(nn.Module):
1032
  @property
1033
  def shared_weights(self) -> bool: # type: ignore
1034
  """Defines whether weights are shared or not between different pixels.
1035
-
1036
  Returns:
1037
  `True` to use shared weights in convolution (regular convolution).
1038
  `False` to use different weights at different pixels, a.k.a.
1039
  "locally connected layer", "unshared convolution", or "local convolution".
1040
-
1041
  """
1042
  ...
1043
 
1044
  @nn.compact
1045
  def __call__(self, inputs: Array) -> Array:
1046
  """Applies a (potentially unshared) convolution to the inputs.
1047
-
1048
  Args:
1049
  inputs: input data with dimensions (*batch_dims, spatial_dims...,
1050
  features). This is the channels-last convention, i.e. NHWC for a 2d
@@ -1057,7 +994,6 @@ class _Conv(nn.Module):
1057
  better performance than this default flattening approach. If the input
1058
  lacks a batch dimension it will be added for the convolution and removed
1059
  n return, an allowance made to enable writing single-example code.
1060
-
1061
  Returns:
1062
  The convolved data.
1063
  """
@@ -1214,7 +1150,6 @@ class _Conv(nn.Module):
1214
 
1215
  class Conv(_Conv):
1216
  """Convolution Module wrapping `lax.conv_general_dilated`.
1217
-
1218
  Attributes:
1219
  features: number of convolution filters.
1220
  kernel_size: shape of the convolutional kernel. For 1D convolution,
@@ -1252,4 +1187,4 @@ class Conv(_Conv):
1252
 
1253
  @property
1254
  def shared_weights(self) -> bool:
1255
- return True
 
56
  default_embed_init = nn.initializers.variance_scaling(1.0, "fan_in", "normal", out_axis=0)
57
 
58
 
 
 
 
 
 
 
 
 
 
 
59
  def dot_product_attention(
60
  query: Array,
61
  key: Array,
 
68
  float32_logits: bool = False,
69
  ):
70
  """Computes dot-product attention given query, key, and value.
 
71
  This is the core function for applying attention based on
72
  https://arxiv.org/abs/1706.03762. It calculates the attention weights given
73
  query and key and combines the values using the attention weights.
 
74
  Args:
75
  query: queries for calculating attention with shape of `[batch, q_length,
76
  num_heads, qk_depth_per_head]`.
 
87
  dtype: the dtype of the computation (default: float32)
88
  float32_logits: bool, if True then compute logits in float32 to avoid
89
  numerical issues with bfloat16.
 
90
  Returns:
91
  Output of shape `[batch, length, num_heads, v_depth_per_head]`.
92
  """
 
132
 
133
  class MultiHeadDotProductAttention(nn.Module):
134
  """Multi-head dot-product attention.
 
135
  Attributes:
136
  num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
137
  should be divisible by the number of heads.
 
162
  deterministic: bool = False,
163
  ) -> Array:
164
  """Applies multi-head dot product attention on the input data.
 
165
  Projects the inputs into multi-headed query, key, and value vectors,
166
  applies dot-product attention and project the results to an output vector.
 
167
  There are two modes: decoding and non-decoding (e.g., training). The mode is
168
  determined by `decode` argument. For decoding, this method is called twice,
169
  first to initialize the cache and then for an actual decoding process. The
170
  two calls are differentiated by the presence of 'cached_key' in the variable
171
  dict. In the cache initialization stage, the cache variables are initialized
172
  as zeros and will be filled in the subsequent decoding process.
 
173
  In the cache initialization call, `inputs_q` has a shape [batch, length,
174
  q_features] and `inputs_kv`: [batch, length, kv_features]. During the
175
  incremental decoding stage, query, key and value all have the shape [batch,
176
  1, qkv_features] corresponding to a single step.
 
177
  Args:
178
  inputs_q: input queries of shape `[batch, q_length, q_features]`.
179
  inputs_kv: key/values of shape `[batch, kv_length, kv_features]`.
 
181
  bias: attention bias of shape `[batch, num_heads, q_length, kv_length]`.
182
  decode: Whether to prepare and use an autoregressive cache.
183
  deterministic: Disables dropout if set to True.
 
184
  Returns:
185
  output of shape `[batch, length, q_features]`.
186
  """
 
341
  # ------------------------------------------------------------------------------
342
  class DenseGeneral(nn.Module):
343
  """A linear transformation (without bias) with flexible axes.
 
344
  Attributes:
345
  features: tuple with numbers of output features.
346
  axis: tuple with axes to apply the transformation on.
 
360
  @nn.compact
361
  def __call__(self, inputs: Array) -> Array:
362
  """Applies a linear transformation to the inputs along multiple dimensions.
 
363
  Args:
364
  inputs: The nd-array to be transformed.
 
365
  Returns:
366
  The transformed input.
367
  """
 
410
 
411
  class MlpBlock(nn.Module):
412
  """Transformer MLP / feed-forward block.
 
413
  Attributes:
414
  intermediate_dim: Shared dimension of hidden layers.
415
  activations: Type of activations for each layer. Each element is either
 
459
 
460
  class Embed(nn.Module):
461
  """A parameterized function from integers [0, n) to d-dimensional vectors.
 
462
  Attributes:
463
  num_embeddings: number of embeddings.
464
  features: number of feature dimensions for each embedding.
 
489
 
490
  def __call__(self, inputs: Array) -> Array:
491
  """Embeds the inputs along the last dimension.
 
492
  Args:
493
  inputs: input data, all dimensions are considered batch dimensions.
 
494
  Returns:
495
  Output which is embedded input data. The output shape follows the input,
496
  with an additional `features` dimension appended.
 
510
 
511
  def attend(self, query: Array) -> Array:
512
  """Attend over the embedding using a query array.
 
513
  Args:
514
  query: array with last dimension equal the feature depth `features` of the
515
  embedding.
 
516
  Returns:
517
  An array with final dim `num_embeddings` corresponding to the batched
518
  inner-product of the array of query vectors against each embedding.
 
525
 
526
  class RelativePositionBiases(nn.Module):
527
  """Adds T5-style relative positional embeddings to the attention logits.
 
528
  Attributes:
529
  num_buckets: Number of buckets to bucket distances between key and query
530
  positions into.
 
545
  @staticmethod
546
  def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
547
  """Translate relative position to a bucket number for relative attention.
 
548
  The relative position is defined as memory_position - query_position, i.e.
549
  the distance in tokens from the attending position to the attended-to
550
  position. If bidirectional=False, then positive relative positions are
 
555
  positions <=-max_distance map to the same bucket. This should allow for
556
  more graceful generalization to longer sequences than the model has been
557
  trained on.
 
558
  Args:
559
  relative_position: an int32 array
560
  bidirectional: a boolean - whether the attention is bidirectional
561
  num_buckets: an integer
562
  max_distance: an integer
 
563
  Returns:
564
  a Tensor with the same shape as relative_position, containing int32
565
  values in the range [0, num_buckets)
 
587
  @nn.compact
588
  def __call__(self, qlen, klen, bidirectional=True):
589
  """Produce relative position embedding attention biases.
 
590
  Args:
591
  qlen: attention query length.
592
  klen: attention key length.
593
  bidirectional: whether to allow positive memory-query relative position
594
  embeddings.
 
595
  Returns:
596
  output: `(1, len, q_len, k_len)` attention bias
597
  """
 
715
  dtype: DType = jnp.float32,
716
  ) -> Array:
717
  """Mask-making helper for attention weights.
 
718
  In case of 1d inputs (i.e., `[batch, len_q]`, `[batch, len_kv]`, the
719
  attention weights will be `[batch, heads, len_q, len_kv]` and this
720
  function will produce `[batch, 1, len_q, len_kv]`.
 
721
  Args:
722
  query_input: a batched, flat input of query_length size
723
  key_input: a batched, flat input of key_length size
 
725
  extra_batch_dims: number of extra batch dims to add singleton axes for, none
726
  by default
727
  dtype: mask return dtype
 
728
  Returns:
729
  A `[batch, 1, len_q, len_kv]` shaped mask for 1d attention.
730
  """
 
744
 
745
  def make_causal_mask(x: Array, extra_batch_dims: int = 0, dtype: DType = jnp.float32) -> Array:
746
  """Make a causal mask for self-attention.
 
747
  In case of 1d inputs (i.e., `[batch, len]`, the self-attention weights
748
  will be `[batch, heads, len, len]` and this function will produce a
749
  causal mask of shape `[batch, 1, len, len]`.
 
750
  Note that a causal mask does not depend on the values of x; it only depends on
751
  the shape. If x has padding elements, they will not be treated in a special
752
  manner.
 
753
  Args:
754
  x: input array of shape `[batch, len]`
755
  extra_batch_dims: number of batch dims to add singleton axes for, none by
756
  default
757
  dtype: mask return dtype
 
758
  Returns:
759
  A `[batch, 1, len, len]` shaped causal mask for 1d attention.
760
  """
 
764
 
765
  def combine_masks(*masks: Optional[Array], dtype: DType = jnp.float32):
766
  """Combine attention masks.
 
767
  Args:
768
  *masks: set of attention mask arguments to combine, some can be None.
769
  dtype: final mask dtype
 
770
  Returns:
771
  Combined mask, reduced by logical and, returns None if no masks given.
772
  """
 
784
 
785
  def combine_biases(*masks: Optional[Array]):
786
  """Combine attention biases.
 
787
  Args:
788
  *masks: set of attention bias arguments to combine, some can be None.
 
789
  Returns:
790
  Combined mask, reduced by summation, returns None if no masks given.
791
  """
 
808
  decoder_segment_ids: Optional[Array] = None,
809
  ) -> Array:
810
  """Compute the self-attention mask for a decoder.
 
811
  Decoder mask is formed by combining a causal mask, a padding mask and an
812
  optional packing mask. If decoder_causal_attention is passed, it makes the
813
  masking non-causal for positions that have value of 1.
 
814
  A prefix LM is applied to a dataset which has a notion of "inputs" and
815
  "targets", e.g., a machine translation task. The inputs and targets are
816
  concatenated to form a new target. `decoder_target_tokens` is the concatenated
817
  decoder output tokens.
 
818
  The "inputs" portion of the concatenated sequence can attend to other "inputs"
819
  tokens even for those at a later time steps. In order to control this
820
  behavior, `decoder_causal_attention` is necessary. This is a binary mask with
821
  a value of 1 indicating that the position belonged to "inputs" portion of the
822
  original dataset.
 
823
  Example:
 
824
  Suppose we have a dataset with two examples.
 
825
  ds = [{"inputs": [6, 7], "targets": [8]},
826
  {"inputs": [3, 4], "targets": [5]}]
 
827
  After the data preprocessing with packing, the two examples are packed into
828
  one example with the following three fields (some fields are skipped for
829
  simplicity).
 
830
  decoder_target_tokens = [[6, 7, 8, 3, 4, 5, 0]]
831
  decoder_segment_ids = [[1, 1, 1, 2, 2, 2, 0]]
832
  decoder_causal_attention = [[1, 1, 0, 1, 1, 0, 0]]
 
833
  where each array has [batch, length] shape with batch size being 1. Then,
834
  this function computes the following mask.
 
835
  mask = [[[[1, 1, 0, 0, 0, 0, 0],
836
  [1, 1, 0, 0, 0, 0, 0],
837
  [1, 1, 1, 0, 0, 0, 0],
 
839
  [0, 0, 0, 1, 1, 0, 0],
840
  [0, 0, 0, 1, 1, 1, 0],
841
  [0, 0, 0, 0, 0, 0, 0]]]]
 
842
  mask[b, 1, :, :] represents the mask for the example `b` in the batch.
843
  Because mask is for a self-attention layer, the mask's shape is a square of
844
  shape [query length, key length].
 
845
  mask[b, 1, i, j] = 1 means that the query token at position i can attend to
846
  the key token at position j.
 
847
  Args:
848
  decoder_target_tokens: decoder output tokens. [batch, length]
849
  dtype: dtype of the output mask.
 
852
  bidirectionally. [batch, length]
853
  decoder_segment_ids: decoder segmentation info for packed examples. [batch,
854
  length]
 
855
  Returns:
856
  the combined decoder mask.
857
  """
 
917
 
918
  class _Conv(nn.Module):
919
  """Convolution Module wrapping `lax.conv_general_dilated[_local]`.
 
920
  Attributes:
921
  features: number of convolution filters.
922
  kernel_size: shape of the convolutional kernel. For 1D convolution,
 
972
  @property
973
  def shared_weights(self) -> bool: # type: ignore
974
  """Defines whether weights are shared or not between different pixels.
 
975
  Returns:
976
  `True` to use shared weights in convolution (regular convolution).
977
  `False` to use different weights at different pixels, a.k.a.
978
  "locally connected layer", "unshared convolution", or "local convolution".
 
979
  """
980
  ...
981
 
982
  @nn.compact
983
  def __call__(self, inputs: Array) -> Array:
984
  """Applies a (potentially unshared) convolution to the inputs.
 
985
  Args:
986
  inputs: input data with dimensions (*batch_dims, spatial_dims...,
987
  features). This is the channels-last convention, i.e. NHWC for a 2d
 
994
  better performance than this default flattening approach. If the input
995
  lacks a batch dimension it will be added for the convolution and removed
996
  n return, an allowance made to enable writing single-example code.
 
997
  Returns:
998
  The convolved data.
999
  """
 
1150
 
1151
  class Conv(_Conv):
1152
  """Convolution Module wrapping `lax.conv_general_dilated`.
 
1153
  Attributes:
1154
  features: number of convolution filters.
1155
  kernel_size: shape of the convolutional kernel. For 1D convolution,
 
1187
 
1188
  @property
1189
  def shared_weights(self) -> bool:
1190
+ return True