Spaces:
Runtime error
Runtime error
File size: 1,789 Bytes
6fd61b9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
import tensorflow as tf
from tensorflow.keras import layers
def spatial_attention_block(input_tensor):
    """Reweight ``input_tensor`` with a per-position (spatial) attention map.

    The last axis is collapsed twice -- once with a max and once with a
    mean -- and the two single-channel maps are stacked, projected to one
    channel by a 1x1 convolution and squashed with a sigmoid. That
    one-channel map is broadcast-multiplied against every channel of
    ``input_tensor``.

    Args:
        input_tensor: 4-D feature map; assumed to be
            (batch, height, width, channels) in Keras' default
            channels-last layout -- TODO confirm against callers.

    Returns:
        A tensor with the same shape as ``input_tensor``.
    """
    # keepdims=True leaves a trailing axis of size 1, so no separate
    # expand_dims is needed before concatenation.
    max_pooling = tf.reduce_max(input_tensor, axis=-1, keepdims=True)
    average_pooling = tf.reduce_mean(input_tensor, axis=-1, keepdims=True)

    # Order matters: the 1x1 convolution's kernel is indexed by input
    # channel, so keep max first and mean second to stay compatible with
    # weights trained on this layout.
    pooled = layers.Concatenate(axis=-1)([max_pooling, average_pooling])

    feature_map = layers.Conv2D(1, kernel_size=(1, 1))(pooled)
    feature_map = tf.nn.sigmoid(feature_map)
    return input_tensor * feature_map
def channel_attention_block(input_tensor, reduction_ratio=8):
    """Reweight the channels of ``input_tensor`` (squeeze-and-excitation style).

    Each channel is summarised by its global spatial average, passed through
    a bottleneck of two 1x1 convolutions (ReLU, then sigmoid), and the
    resulting per-channel weights scale ``input_tensor``.

    Args:
        input_tensor: 4-D feature map, assumed to be
            (batch, height, width, channels) -- TODO confirm against callers.
            The channel count must be statically known.
        reduction_ratio: Factor by which the bottleneck shrinks the channel
            count. Defaults to 8. The bottleneck always keeps at least one
            filter, so inputs with fewer than ``reduction_ratio`` channels
            still build.

    Returns:
        A tensor with the same shape as ``input_tensor``.

    Raises:
        ValueError: If the channel dimension is unknown or
            ``reduction_ratio`` is less than 1.
    """
    channels = input_tensor.shape[-1]
    if channels is None:
        raise ValueError(
            "channel_attention_block requires a statically known "
            "channel dimension"
        )
    if reduction_ratio < 1:
        raise ValueError(f"reduction_ratio must be >= 1, got {reduction_ratio}")

    # Without the clamp, channels < reduction_ratio would request a
    # 0-filter convolution, which is degenerate (recent Keras rejects it).
    bottleneck_filters = max(channels // reduction_ratio, 1)

    average_pooling = layers.GlobalAveragePooling2D()(input_tensor)
    # GlobalAveragePooling2D drops both spatial axes; restore them as 1x1
    # so the descriptor can go through the 1x1 convolutions below.
    feature_descriptor = tf.reshape(average_pooling, shape=(-1, 1, 1, channels))

    feature_activations = layers.Conv2D(
        filters=bottleneck_filters, kernel_size=(1, 1), activation="relu"
    )(feature_descriptor)
    feature_activations = layers.Conv2D(
        filters=channels, kernel_size=(1, 1), activation="sigmoid"
    )(feature_activations)

    # (batch, 1, 1, channels) weights broadcast over the spatial axes.
    return input_tensor * feature_activations
def dual_attention_unit_block(input_tensor):
    """Apply a residual dual-attention unit to ``input_tensor``.

    Two same-padded 3x3 convolutions extract features. Channel attention and
    spatial attention are then computed in parallel on those features, their
    outputs are concatenated, fused back to the input channel count with a
    1x1 convolution, and added to ``input_tensor`` as a residual.

    Args:
        input_tensor: 4-D feature map, assumed to be
            (batch, height, width, channels) -- TODO confirm against callers.

    Returns:
        A tensor with the same shape as ``input_tensor``.
    """
    num_channels = list(input_tensor.shape)[-1]

    def conv(kernel_size, **options):
        # Every convolution in this unit outputs ``num_channels`` filters.
        return layers.Conv2D(num_channels, kernel_size=kernel_size, **options)

    # Feature extraction: 3x3 conv + ReLU, then a linear 3x3 conv; "same"
    # padding keeps the spatial size unchanged.
    features = conv((3, 3), padding="same", activation="relu")(input_tensor)
    features = conv((3, 3), padding="same")(features)

    # Both attention branches read the same features. The list literal is
    # evaluated left to right, so channel attention is still built before
    # spatial attention.
    branches = [
        channel_attention_block(features),
        spatial_attention_block(features),
    ]
    fused = layers.Concatenate(axis=-1)(branches)

    # The concatenation has 2 * num_channels channels; the 1x1 conv brings it
    # back to num_channels so the residual sum has matching shapes.
    fused = conv((1, 1))(fused)

    return layers.Add()([input_tensor, fused])
|