Spaces:
Sleeping
Sleeping
import functools | |
import tensorflow as tf | |
from tensorflow.keras import backend as K | |
from tensorflow.keras import layers | |
from ..layers import BlockImages, SwapAxes, UnblockImages | |
from .block_gating import BlockGmlpLayer | |
from .grid_gating import GridGmlpLayer | |
Conv1x1 = functools.partial(layers.Conv2D, kernel_size=(1, 1), padding="same") | |
Conv3x3 = functools.partial(layers.Conv2D, kernel_size=(3, 3), padding="same") | |
ConvT_up = functools.partial( | |
layers.Conv2DTranspose, kernel_size=(2, 2), strides=(2, 2), padding="same" | |
) | |
Conv_down = functools.partial( | |
layers.Conv2D, kernel_size=(4, 4), strides=(2, 2), padding="same" | |
) | |
def ResidualSplitHeadMultiAxisGmlpLayer( | |
block_size, | |
grid_size, | |
block_gmlp_factor: int = 2, | |
grid_gmlp_factor: int = 2, | |
input_proj_factor: int = 2, | |
use_bias: bool = True, | |
dropout_rate: float = 0.0, | |
name: str = "residual_split_head_maxim", | |
): | |
"""The multi-axis gated MLP block.""" | |
def apply(x): | |
shortcut = x | |
n, h, w, num_channels = ( | |
K.int_shape(x)[0], | |
K.int_shape(x)[1], | |
K.int_shape(x)[2], | |
K.int_shape(x)[3], | |
) | |
x = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm_in")(x) | |
x = layers.Dense( | |
int(num_channels) * input_proj_factor, | |
use_bias=use_bias, | |
name=f"{name}_in_project", | |
)(x) | |
x = tf.nn.gelu(x, approximate=True) | |
u, v = tf.split(x, 2, axis=-1) | |
# GridGMLPLayer | |
u = GridGmlpLayer( | |
grid_size=grid_size, | |
factor=grid_gmlp_factor, | |
use_bias=use_bias, | |
dropout_rate=dropout_rate, | |
name=f"{name}_GridGmlpLayer", | |
)(u) | |
# BlockGMLPLayer | |
v = BlockGmlpLayer( | |
block_size=block_size, | |
factor=block_gmlp_factor, | |
use_bias=use_bias, | |
dropout_rate=dropout_rate, | |
name=f"{name}_BlockGmlpLayer", | |
)(v) | |
x = tf.concat([u, v], axis=-1) | |
x = layers.Dense( | |
num_channels, | |
use_bias=use_bias, | |
name=f"{name}_out_project", | |
)(x) | |
x = layers.Dropout(dropout_rate)(x) | |
x = x + shortcut | |
return x | |
return apply | |
def GetSpatialGatingWeights( | |
features: int, | |
block_size, | |
grid_size, | |
input_proj_factor: int = 2, | |
dropout_rate: float = 0.0, | |
use_bias: bool = True, | |
name: str = "spatial_gating", | |
): | |
"""Get gating weights for cross-gating MLP block.""" | |
def apply(x): | |
n, h, w, num_channels = ( | |
K.int_shape(x)[0], | |
K.int_shape(x)[1], | |
K.int_shape(x)[2], | |
K.int_shape(x)[3], | |
) | |
# input projection | |
x = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm_in")(x) | |
x = layers.Dense( | |
num_channels * input_proj_factor, | |
use_bias=use_bias, | |
name=f"{name}_in_project", | |
)(x) | |
x = tf.nn.gelu(x, approximate=True) | |
u, v = tf.split(x, 2, axis=-1) | |
# Get grid MLP weights | |
gh, gw = grid_size | |
fh, fw = h // gh, w // gw | |
u = BlockImages()(u, patch_size=(fh, fw)) | |
dim_u = K.int_shape(u)[-3] | |
u = SwapAxes()(u, -1, -3) | |
u = layers.Dense(dim_u, use_bias=use_bias, name=f"{name}_Dense_0")(u) | |
u = SwapAxes()(u, -1, -3) | |
u = UnblockImages()(u, grid_size=(gh, gw), patch_size=(fh, fw)) | |
# Get Block MLP weights | |
fh, fw = block_size | |
gh, gw = h // fh, w // fw | |
v = BlockImages()(v, patch_size=(fh, fw)) | |
dim_v = K.int_shape(v)[-2] | |
v = SwapAxes()(v, -1, -2) | |
v = layers.Dense(dim_v, use_bias=use_bias, name=f"{name}_Dense_1")(v) | |
v = SwapAxes()(v, -1, -2) | |
v = UnblockImages()(v, grid_size=(gh, gw), patch_size=(fh, fw)) | |
x = tf.concat([u, v], axis=-1) | |
x = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_out_project")(x) | |
x = layers.Dropout(dropout_rate)(x) | |
return x | |
return apply | |
def CrossGatingBlock( | |
features: int, | |
block_size, | |
grid_size, | |
dropout_rate: float = 0.0, | |
input_proj_factor: int = 2, | |
upsample_y: bool = True, | |
use_bias: bool = True, | |
name: str = "cross_gating", | |
): | |
"""Cross-gating MLP block.""" | |
def apply(x, y): | |
# Upscale Y signal, y is the gating signal. | |
if upsample_y: | |
y = ConvT_up( | |
filters=features, use_bias=use_bias, name=f"{name}_ConvTranspose_0" | |
)(y) | |
x = Conv1x1(filters=features, use_bias=use_bias, name=f"{name}_Conv_0")(x) | |
n, h, w, num_channels = ( | |
K.int_shape(x)[0], | |
K.int_shape(x)[1], | |
K.int_shape(x)[2], | |
K.int_shape(x)[3], | |
) | |
y = Conv1x1(filters=num_channels, use_bias=use_bias, name=f"{name}_Conv_1")(y) | |
shortcut_x = x | |
shortcut_y = y | |
# Get gating weights from X | |
x = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm_x")(x) | |
x = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_in_project_x")(x) | |
x = tf.nn.gelu(x, approximate=True) | |
gx = GetSpatialGatingWeights( | |
features=num_channels, | |
block_size=block_size, | |
grid_size=grid_size, | |
dropout_rate=dropout_rate, | |
use_bias=use_bias, | |
name=f"{name}_SplitHeadMultiAxisGating_x", | |
)(x) | |
# Get gating weights from Y | |
y = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm_y")(y) | |
y = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_in_project_y")(y) | |
y = tf.nn.gelu(y, approximate=True) | |
gy = GetSpatialGatingWeights( | |
features=num_channels, | |
block_size=block_size, | |
grid_size=grid_size, | |
dropout_rate=dropout_rate, | |
use_bias=use_bias, | |
name=f"{name}_SplitHeadMultiAxisGating_y", | |
)(y) | |
# Apply cross gating: X = X * GY, Y = Y * GX | |
y = y * gx | |
y = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_out_project_y")(y) | |
y = layers.Dropout(dropout_rate)(y) | |
y = y + shortcut_y | |
x = x * gy # gating x using y | |
x = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_out_project_x")(x) | |
x = layers.Dropout(dropout_rate)(x) | |
x = x + y + shortcut_x # get all aggregated signals | |
return x, y | |
return apply | |