File size: 814 Bytes
3126b1e
 
 
 
 
 
4a0cabe
3126b1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import tensorflow as tf

from .feature import ReduceSize


@tf.keras.utils.register_keras_serializable(package="gcvit")
class Stem(tf.keras.layers.Layer):
    def __init__(self, dim, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim

    def build(self, input_shape):
        self.pad = tf.keras.layers.ZeroPadding2D(1, name='pad')
        self.proj = tf.keras.layers.Conv2D(self.dim, kernel_size=3, strides=2, name='proj')
        self.conv_down = ReduceSize(keep_dim=True, name='conv_down')
        super().build(input_shape)

    def call(self, inputs, **kwargs):
        x = self.pad(inputs)
        x = self.proj(x)
        x = self.conv_down(x)
        return x

    def get_config(self):
        config = super().get_config()
        config.update({'dim': self.dim})
        return config