Spaces:
Runtime error
Runtime error
# Copyright 2023 The TensorFlow Authors. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Keras-based bigbird attention layer.""" | |
import numpy as np | |
import tensorflow as tf, tf_keras | |
MAX_SEQ_LEN = 4096 | |
def create_band_mask_from_inputs(from_blocked_mask, to_blocked_mask): | |
"""Create 3D attention mask from a 2D tensor mask. | |
Args: | |
from_blocked_mask: 2D Tensor of shape [batch_size, | |
from_seq_length//from_block_size, from_block_size]. | |
to_blocked_mask: int32 Tensor of shape [batch_size, | |
to_seq_length//to_block_size, to_block_size]. | |
Returns: | |
float Tensor of shape [batch_size, 1, from_seq_length//from_block_size-4, | |
from_block_size, 3*to_block_size]. | |
""" | |
exp_blocked_to_pad = tf.concat([ | |
to_blocked_mask[:, 1:-3], to_blocked_mask[:, 2:-2], to_blocked_mask[:, | |
3:-1] | |
], 2) | |
band_mask = tf.einsum("BLQ,BLK->BLQK", from_blocked_mask[:, 2:-2], | |
exp_blocked_to_pad) | |
band_mask = tf.expand_dims(band_mask, 1) | |
return band_mask | |
def bigbird_block_rand_mask(from_seq_length, | |
to_seq_length, | |
from_block_size, | |
to_block_size, | |
num_rand_blocks, | |
last_idx=-1): | |
"""Create adjacency list of random attention. | |
Args: | |
from_seq_length: int. length of from sequence. | |
to_seq_length: int. length of to sequence. | |
from_block_size: int. size of block in from sequence. | |
to_block_size: int. size of block in to sequence. | |
num_rand_blocks: int. Number of random chunks per row. | |
last_idx: if -1 then num_rand_blocks blocks chosen anywhere in to sequence, | |
if positive then num_rand_blocks blocks choosen only upto last_idx. | |
Returns: | |
adjacency list of size from_seq_length//from_block_size-2 by num_rand_blocks | |
""" | |
assert from_seq_length//from_block_size == to_seq_length//to_block_size, \ | |
"Error the number of blocks needs to be same!" | |
rand_attn = np.zeros( | |
(from_seq_length // from_block_size - 2, num_rand_blocks), dtype=np.int32) | |
middle_seq = np.arange(1, to_seq_length // to_block_size - 1, dtype=np.int32) | |
last = to_seq_length // to_block_size - 1 | |
if last_idx > (2 * to_block_size): | |
last = (last_idx // to_block_size) - 1 | |
r = num_rand_blocks # shorthand | |
for i in range(1, from_seq_length // from_block_size - 1): | |
start = i - 2 | |
end = i | |
if i == 1: | |
rand_attn[i - 1, :] = np.random.permutation(middle_seq[2:last])[:r] | |
elif i == 2: | |
rand_attn[i - 1, :] = np.random.permutation(middle_seq[3:last])[:r] | |
elif i == from_seq_length // from_block_size - 3: | |
rand_attn[i - 1, :] = np.random.permutation(middle_seq[:last])[:r] | |
# Missing -3: should have been sliced till last-3 | |
elif i == from_seq_length // from_block_size - 2: | |
rand_attn[i - 1, :] = np.random.permutation(middle_seq[:last])[:r] | |
# Missing -4: should have been sliced till last-4 | |
else: | |
if start > last: | |
start = last | |
rand_attn[i - 1, :] = np.random.permutation(middle_seq[:start])[:r] | |
elif (end + 1) == last: | |
rand_attn[i - 1, :] = np.random.permutation(middle_seq[:start])[:r] | |
else: | |
rand_attn[i - 1, :] = np.random.permutation( | |
np.concatenate((middle_seq[:start], middle_seq[end + 1:last])))[:r] | |
return rand_attn | |
def create_rand_mask_from_inputs(from_blocked_mask, to_blocked_mask, rand_attn, | |
num_attention_heads, num_rand_blocks, | |
batch_size, from_seq_length, from_block_size): | |
"""Create 3D attention mask from a 2D tensor mask. | |
Args: | |
from_blocked_mask: 2D Tensor of shape [batch_size, | |
from_seq_length//from_block_size, from_block_size]. | |
to_blocked_mask: int32 Tensor of shape [batch_size, | |
to_seq_length//to_block_size, to_block_size]. | |
rand_attn: [batch_size, num_attention_heads, | |
from_seq_length//from_block_size-2, num_rand_blocks] | |
num_attention_heads: int. Number of attention heads. | |
num_rand_blocks: int. Number of random chunks per row. | |
batch_size: int. Batch size for computation. | |
from_seq_length: int. length of from sequence. | |
from_block_size: int. size of block in from sequence. | |
Returns: | |
float Tensor of shape [batch_size, num_attention_heads, | |
from_seq_length//from_block_size-2, | |
from_block_size, num_rand_blocks*to_block_size]. | |
""" | |
num_windows = from_seq_length // from_block_size - 2 | |
rand_mask = tf.reshape( | |
tf.gather(to_blocked_mask, rand_attn, batch_dims=1), [ | |
batch_size, num_attention_heads, num_windows, | |
num_rand_blocks * from_block_size | |
]) | |
rand_mask = tf.einsum("BLQ,BHLK->BHLQK", from_blocked_mask[:, 1:-1], | |
rand_mask) | |
return rand_mask | |
def bigbird_block_sparse_attention( | |
query_layer, key_layer, value_layer, band_mask, from_mask, to_mask, | |
from_blocked_mask, to_blocked_mask, rand_attn, num_attention_heads, | |
num_rand_blocks, size_per_head, batch_size, from_seq_length, to_seq_length, | |
from_block_size, to_block_size): | |
"""BigBird attention sparse calculation using blocks in linear time. | |
Assumes from_seq_length//from_block_size == to_seq_length//to_block_size. | |
Args: | |
query_layer: float Tensor of shape [batch_size, num_attention_heads, | |
from_seq_length, size_per_head] | |
key_layer: float Tensor of shape [batch_size, num_attention_heads, | |
to_seq_length, size_per_head] | |
value_layer: float Tensor of shape [batch_size, num_attention_heads, | |
to_seq_length, size_per_head] | |
band_mask: (optional) int32 Tensor of shape [batch_size, 1, | |
from_seq_length//from_block_size-4, from_block_size, 3*to_block_size]. The | |
values should be 1 or 0. The attention scores will effectively be set to | |
-infinity for any positions in the mask that are 0, and will be unchanged | |
for positions that are 1. | |
from_mask: (optional) int32 Tensor of shape [batch_size, 1, from_seq_length, | |
1]. The values should be 1 or 0. The attention scores will effectively be | |
set to -infinity for any positions in the mask that are 0, and will be | |
unchanged for positions that are 1. | |
to_mask: (optional) int32 Tensor of shape [batch_size, 1, 1, to_seq_length]. | |
The values should be 1 or 0. The attention scores will effectively be set | |
to -infinity for any positions in the mask that are 0, and will be | |
unchanged for positions that are 1. | |
from_blocked_mask: (optional) int32 Tensor of shape [batch_size, | |
from_seq_length//from_block_size, from_block_size]. Same as from_mask, | |
just reshaped. | |
to_blocked_mask: (optional) int32 Tensor of shape [batch_size, | |
to_seq_length//to_block_size, to_block_size]. Same as to_mask, just | |
reshaped. | |
rand_attn: [batch_size, num_attention_heads, | |
from_seq_length//from_block_size-2, num_rand_blocks] | |
num_attention_heads: int. Number of attention heads. | |
num_rand_blocks: int. Number of random chunks per row. | |
size_per_head: int. Size of each attention head. | |
batch_size: int. Batch size for computation. | |
from_seq_length: int. length of from sequence. | |
to_seq_length: int. length of to sequence. | |
from_block_size: int. size of block in from sequence. | |
to_block_size: int. size of block in to sequence. | |
Returns: | |
float Tensor of shape [batch_size, from_seq_length, num_attention_heads, | |
size_per_head]. | |
""" | |
rand_attn = tf.expand_dims(rand_attn, 0) | |
rand_attn = tf.repeat(rand_attn, batch_size, 0) | |
rand_mask = create_rand_mask_from_inputs( | |
from_blocked_mask, | |
to_blocked_mask, | |
rand_attn, | |
num_attention_heads, | |
num_rand_blocks, | |
batch_size, | |
from_seq_length, | |
from_block_size, | |
) | |
# Define shorthands | |
h = num_attention_heads | |
r = num_rand_blocks | |
d = size_per_head | |
b = batch_size | |
m = from_seq_length | |
n = to_seq_length | |
wm = from_block_size | |
wn = to_block_size | |
dtype = query_layer.dtype | |
query_layer = tf.transpose(query_layer, perm=[0, 2, 1, 3]) | |
key_layer = tf.transpose(key_layer, perm=[0, 2, 1, 3]) | |
value_layer = tf.transpose(value_layer, perm=[0, 2, 1, 3]) | |
blocked_query_matrix = tf.reshape(query_layer, (b, h, m // wm, wm, -1)) | |
blocked_key_matrix = tf.reshape(key_layer, (b, h, n // wn, wn, -1)) | |
blocked_value_matrix = tf.reshape(value_layer, (b, h, n // wn, wn, -1)) | |
gathered_key = tf.reshape( | |
tf.gather(blocked_key_matrix, rand_attn, batch_dims=2, name="gather_key"), | |
(b, h, m // wm - 2, r * wn, -1)) # [b, h, n//wn-2, r, wn, -1] | |
gathered_value = tf.reshape( | |
tf.gather( | |
blocked_value_matrix, rand_attn, batch_dims=2, name="gather_value"), | |
(b, h, m // wm - 2, r * wn, -1)) # [b, h, n//wn-2, r, wn, -1] | |
first_product = tf.einsum( | |
"BHQD,BHKD->BHQK", blocked_query_matrix[:, :, 0], | |
key_layer) # [b, h, wm, -1] x [b, h, n, -1] ==> [b, h, wm, n] | |
first_product = tf.multiply(first_product, 1.0 / np.sqrt(d)) | |
first_product += (1.0 - tf.cast(to_mask, dtype=dtype)) * -10000.0 | |
first_attn_weights = tf.nn.softmax(first_product) # [b, h, wm, n] | |
first_context_layer = tf.einsum( | |
"BHQK,BHKD->BHQD", first_attn_weights, | |
value_layer) # [b, h, wm, n] x [b, h, n, -1] ==> [b, h, wm, -1] | |
first_context_layer = tf.expand_dims(first_context_layer, 2) | |
second_key_mat = tf.concat([ | |
blocked_key_matrix[:, :, 0], blocked_key_matrix[:, :, 1], | |
blocked_key_matrix[:, :, 2], blocked_key_matrix[:, :, | |
-1], gathered_key[:, :, 0] | |
], 2) # [b, h, (4+r)*wn, -1] | |
second_value_mat = tf.concat([ | |
blocked_value_matrix[:, :, 0], blocked_value_matrix[:, :, 1], | |
blocked_value_matrix[:, :, 2], blocked_value_matrix[:, :, -1], | |
gathered_value[:, :, 0] | |
], 2) # [b, h, (4+r)*wn, -1] | |
second_product = tf.einsum( | |
"BHQD,BHKD->BHQK", blocked_query_matrix[:, :, 1], second_key_mat | |
) # [b, h, wm, -1] x [b, h, (4+r)*wn, -1] ==> [b, h, wm, (4+r)*wn] | |
second_seq_pad = tf.concat([ | |
to_mask[:, :, :, :3 * wn], to_mask[:, :, :, -wn:], | |
tf.ones([b, 1, 1, r * wn], dtype=dtype) | |
], 3) | |
second_rand_pad = tf.concat([ | |
tf.ones([b, h, wm, 4 * wn], dtype=dtype), rand_mask[:, :, 0] | |
], 3) | |
second_product = tf.multiply(second_product, 1.0 / np.sqrt(d)) | |
second_product += (1.0 - | |
tf.minimum(second_seq_pad, second_rand_pad)) * -10000.0 | |
second_attn_weights = tf.nn.softmax(second_product) # [b , h, wm, (4+r)*wn] | |
second_context_layer = tf.einsum( | |
"BHQK,BHKD->BHQD", second_attn_weights, second_value_mat | |
) # [b, h, wm, (4+r)*wn] x [b, h, (4+r)*wn, -1] ==> [b, h, wm, -1] | |
second_context_layer = tf.expand_dims(second_context_layer, 2) | |
exp_blocked_key_matrix = tf.concat([ | |
blocked_key_matrix[:, :, 1:-3], blocked_key_matrix[:, :, 2:-2], | |
blocked_key_matrix[:, :, 3:-1] | |
], 3) # [b, h, m//wm-4, 3*wn, -1] | |
exp_blocked_value_matrix = tf.concat([ | |
blocked_value_matrix[:, :, 1:-3], blocked_value_matrix[:, :, 2:-2], | |
blocked_value_matrix[:, :, 3:-1] | |
], 3) # [b, h, m//wm-4, 3*wn, -1] | |
middle_query_matrix = blocked_query_matrix[:, :, 2:-2] | |
inner_band_product = tf.einsum( | |
"BHLQD,BHLKD->BHLQK", middle_query_matrix, exp_blocked_key_matrix | |
) # [b, h, m//wm-4, wm, -1] x [b, h, m//wm-4, 3*wn, -1] | |
# ==> [b, h, m//wm-4, wm, 3*wn] | |
inner_band_product = tf.multiply(inner_band_product, 1.0 / np.sqrt(d)) | |
rand_band_product = tf.einsum( | |
"BHLQD,BHLKD->BHLQK", middle_query_matrix, | |
gathered_key[:, :, | |
1:-1]) # [b, h, m//wm-4, wm, -1] x [b, h, m//wm-4, r*wn, -1] | |
# ==> [b, h, m//wm-4, wm, r*wn] | |
rand_band_product = tf.multiply(rand_band_product, 1.0 / np.sqrt(d)) | |
first_band_product = tf.einsum( | |
"BHLQD,BHKD->BHLQK", middle_query_matrix, blocked_key_matrix[:, :, 0] | |
) # [b, h, m//wm-4, wm, -1] x [b, h, wn, -1] ==> [b, h, m//wm-4, wm, wn] | |
first_band_product = tf.multiply(first_band_product, 1.0 / np.sqrt(d)) | |
last_band_product = tf.einsum( | |
"BHLQD,BHKD->BHLQK", middle_query_matrix, blocked_key_matrix[:, :, -1] | |
) # [b, h, m//wm-4, wm, -1] x [b, h, wn, -1] ==> [b, h, m//wm-4, wm, wn] | |
last_band_product = tf.multiply(last_band_product, 1.0 / np.sqrt(d)) | |
inner_band_product += (1.0 - band_mask) * -10000.0 | |
first_band_product += (1.0 - | |
tf.expand_dims(to_mask[:, :, :, :wn], 3)) * -10000.0 | |
last_band_product += (1.0 - | |
tf.expand_dims(to_mask[:, :, :, -wn:], 3)) * -10000.0 | |
rand_band_product += (1.0 - rand_mask[:, :, 1:-1]) * -10000.0 | |
band_product = tf.concat([ | |
first_band_product, inner_band_product, rand_band_product, | |
last_band_product | |
], -1) # [b, h, m//wm-4, wm, (5+r)*wn] | |
attn_weights = tf.nn.softmax(band_product) # [b, h, m//wm-4, wm, (5+r)*wn] | |
context_layer = tf.einsum( | |
"BHLQK,BHLKD->BHLQD", attn_weights[:, :, :, :, | |
wn:4 * wn], exp_blocked_value_matrix | |
) # [b, h, m//wm-4, wm, 3*wn] x [b, h, m//wm-4, 3*wn, -1] | |
# ==> [b, h, m//wm-4, wm, -1] | |
context_layer += tf.einsum( | |
"BHLQK,BHLKD->BHLQD", attn_weights[:, :, :, :, | |
4 * wn:-wn], gathered_value[:, :, 1:-1] | |
) # [b, h, m//wm-4, wm, r*wn] x [b, h, m//wm-4, r*wn, -1] | |
# ==> [b, h, m//wm-4, wm, -1] | |
context_layer += tf.einsum( | |
"BHLQK,BHKD->BHLQD", attn_weights[:, :, :, :, :wn], | |
blocked_value_matrix[:, :, 0] | |
) # [b, h, m//wm-4, wm, wn] x [b, h, wn, -1] ==> [b, h, m//wm-4, wm, -1] | |
context_layer += tf.einsum( | |
"BHLQK,BHKD->BHLQD", attn_weights[:, :, :, :, | |
-wn:], blocked_value_matrix[:, :, -1] | |
) # [b, h, m//wm-4, wm, wn] x [b, h, wn, -1] ==> [b, h, m//wm-4, wm, -1] | |
second_last_key_mat = tf.concat([ | |
blocked_key_matrix[:, :, 0], blocked_key_matrix[:, :, -3], | |
blocked_key_matrix[:, :, -2], blocked_key_matrix[:, :, -1], | |
gathered_key[:, :, -1] | |
], 2) # [b, h, (4+r)*wn, -1] | |
second_last_value_mat = tf.concat([ | |
blocked_value_matrix[:, :, 0], blocked_value_matrix[:, :, -3], | |
blocked_value_matrix[:, :, -2], blocked_value_matrix[:, :, -1], | |
gathered_value[:, :, -1] | |
], 2) # [b, h, (4+r)*wn, -1] | |
second_last_product = tf.einsum( | |
"BHQD,BHKD->BHQK", blocked_query_matrix[:, :, -2], second_last_key_mat | |
) # [b, h, wm, -1] x [b, h, (4+r)*wn, -1] ==> [b, h, wm, (4+r)*wn] | |
second_last_seq_pad = tf.concat([ | |
to_mask[:, :, :, :wn], to_mask[:, :, :, -3 * wn:], | |
tf.ones([b, 1, 1, r * wn], dtype=dtype) | |
], 3) | |
second_last_rand_pad = tf.concat( | |
[tf.ones([b, h, wm, 4 * wn], dtype=dtype), rand_mask[:, :, -1]], 3) | |
second_last_product = tf.multiply(second_last_product, 1.0 / np.sqrt(d)) | |
second_last_product += ( | |
1.0 - tf.minimum(second_last_seq_pad, second_last_rand_pad)) * -10000.0 | |
second_last_attn_weights = tf.nn.softmax( | |
second_last_product) # [b, h, wm, (4+r)*wn] | |
second_last_context_layer = tf.einsum( | |
"BHQK,BHKD->BHQD", second_last_attn_weights, second_last_value_mat | |
) # [b, h, wm, (4+r)*wn] x [b, h, (4+r)*wn, -1] ==> [b, h, wm, -1] | |
second_last_context_layer = tf.expand_dims(second_last_context_layer, 2) | |
last_product = tf.einsum( | |
"BHQD,BHKD->BHQK", blocked_query_matrix[:, :, -1], | |
key_layer) # [b, h, wm, -1] x [b, h, n, -1] ==> [b, h, wm, n] | |
last_product = tf.multiply(last_product, 1.0 / np.sqrt(d)) | |
last_product += (1.0 - to_mask) * -10000.0 | |
last_attn_weights = tf.nn.softmax(last_product) # [b, h, wm, n] | |
last_context_layer = tf.einsum( | |
"BHQK,BHKD->BHQD", last_attn_weights, | |
value_layer) # [b, h, wm, n] x [b, h, n, -1] ==> [b, h, wm, -1] | |
last_context_layer = tf.expand_dims(last_context_layer, 2) | |
context_layer = tf.concat([ | |
first_context_layer, second_context_layer, context_layer, | |
second_last_context_layer, last_context_layer | |
], 2) | |
context_layer = tf.reshape(context_layer, (b, h, m, -1)) * from_mask | |
context_layer = tf.transpose(context_layer, (0, 2, 1, 3)) | |
return context_layer | |
class BigBirdMasks(tf_keras.layers.Layer): | |
"""Creates bigbird attention masks.""" | |
def __init__(self, block_size, **kwargs): | |
super().__init__(**kwargs) | |
self._block_size = block_size | |
def call(self, inputs, mask): | |
encoder_shape = tf.shape(mask) | |
mask = tf.cast(mask, inputs.dtype) | |
batch_size, seq_length = encoder_shape[0], encoder_shape[1] | |
# reshape for blocking | |
blocked_encoder_mask = tf.reshape( | |
mask, (batch_size, seq_length // self._block_size, self._block_size)) | |
encoder_from_mask = tf.reshape(mask, (batch_size, 1, seq_length, 1)) | |
encoder_to_mask = tf.reshape(mask, (batch_size, 1, 1, seq_length)) | |
band_mask = create_band_mask_from_inputs(blocked_encoder_mask, | |
blocked_encoder_mask) | |
return [band_mask, encoder_from_mask, encoder_to_mask, blocked_encoder_mask] | |
class BigBirdAttention(tf_keras.layers.MultiHeadAttention): | |
"""BigBird, a sparse attention mechanism. | |
This layer follows the paper "Big Bird: Transformers for Longer Sequences" | |
(https://arxiv.org/abs/2007.14062). | |
It reduces this quadratic dependency of attention | |
computation to linear. | |
Arguments are the same as `MultiHeadAttention` layer. | |
""" | |
def __init__(self, | |
num_rand_blocks=3, | |
from_block_size=64, | |
to_block_size=64, | |
max_rand_mask_length=MAX_SEQ_LEN, | |
seed=None, | |
**kwargs): | |
super().__init__(**kwargs) | |
self._num_rand_blocks = num_rand_blocks | |
self._from_block_size = from_block_size | |
self._to_block_size = to_block_size | |
self._seed = seed | |
# Generates random attention. | |
np.random.seed(self._seed) | |
# pylint: disable=g-complex-comprehension | |
rand_attn = [ | |
bigbird_block_rand_mask( | |
max_rand_mask_length, | |
max_rand_mask_length, | |
from_block_size, | |
to_block_size, | |
num_rand_blocks, | |
last_idx=1024) for _ in range(self._num_heads) | |
] | |
# pylint: enable=g-complex-comprehension | |
rand_attn = np.stack(rand_attn, axis=0) | |
self.rand_attn = tf.constant(rand_attn, dtype=tf.int32) | |
def _compute_attention(self, query, key, value, attention_mask=None): | |
(band_mask, encoder_from_mask, encoder_to_mask, | |
blocked_encoder_mask) = attention_mask | |
query_shape = tf.shape(query) | |
from_seq_length = query_shape[1] | |
to_seq_length = tf.shape(key)[1] | |
rand_attn = self.rand_attn[:, :(from_seq_length // self._from_block_size - | |
2)] | |
return bigbird_block_sparse_attention( | |
query, | |
key, | |
value, | |
band_mask, | |
encoder_from_mask, | |
encoder_to_mask, | |
blocked_encoder_mask, | |
blocked_encoder_mask, | |
num_attention_heads=self._num_heads, | |
num_rand_blocks=self._num_rand_blocks, | |
size_per_head=self._key_dim, | |
batch_size=query_shape[0], | |
from_seq_length=from_seq_length, | |
to_seq_length=to_seq_length, | |
from_block_size=self._from_block_size, | |
to_block_size=self._to_block_size, | |
rand_attn=rand_attn) | |
def call(self, query, value, key=None, attention_mask=None, **kwargs): # pytype: disable=signature-mismatch # overriding-parameter-count-checks | |
if not self._built_from_signature: | |
self._build_from_signature(query=query, value=value, key=key) | |
if key is None: | |
key = value | |
# N = `num_attention_heads` | |
# H = `size_per_head` | |
# `query` = [B, T, N ,H] | |
query = self._query_dense(query) | |
# `key` = [B, S, N, H] | |
key = self._key_dense(key) | |
# `value` = [B, S, N, H] | |
value = self._value_dense(value) | |
attention_output = self._compute_attention(query, key, value, | |
attention_mask) | |
attention_output.set_shape([None, None, self._num_heads, self._value_dim]) | |
attention_output = self._output_dense(attention_output) | |
return attention_output | |
def get_config(self): | |
config = { | |
"num_rand_blocks": self._num_rand_blocks, | |
"from_block_size": self._from_block_size, | |
"to_block_size": self._to_block_size, | |
"seed": self._seed | |
} | |
base_config = super().get_config() | |
return dict(list(base_config.items()) + list(config.items())) | |