# Copyright (C) 2021-2024, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from typing import Any, Tuple, Union

import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

from doctr.utils.repr import NestedObject

__all__ = ["FASTConvLayer"]


class FASTConvLayer(layers.Layer, NestedObject):
    """Convolutional layer used in the TextNet and FAST architectures"""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = False,
    ) -> None:
        super().__init__()

        self.groups = groups
        self.in_channels = in_channels
        self.converted_ks = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size

        self.hor_conv, self.hor_bn = None, None
        self.ver_conv, self.ver_bn = None, None

        padding = ((self.converted_ks[0] - 1) * dilation // 2, (self.converted_ks[1] - 1) * dilation // 2)

        self.activation = layers.ReLU()
        self.conv_pad = layers.ZeroPadding2D(padding=padding)

        # Main MxN branch
        self.conv = layers.Conv2D(
            filters=out_channels,
            kernel_size=self.converted_ks,
            strides=stride,
            dilation_rate=dilation,
            groups=groups,
            use_bias=bias,
        )
        self.bn = layers.BatchNormalization()

        # Vertical Mx1 branch (skipped when the kernel already has a single column)
        if self.converted_ks[1] != 1:
            self.ver_pad = layers.ZeroPadding2D(
                padding=(int(((self.converted_ks[0] - 1) * dilation) / 2), 0),
            )
            self.ver_conv = layers.Conv2D(
                filters=out_channels,
                kernel_size=(self.converted_ks[0], 1),
                strides=stride,
                dilation_rate=dilation,
                groups=groups,
                use_bias=bias,
            )
            self.ver_bn = layers.BatchNormalization()

        # Horizontal 1xN branch (skipped when the kernel already has a single row)
        if self.converted_ks[0] != 1:
            self.hor_pad = layers.ZeroPadding2D(
                padding=(0, int(((self.converted_ks[1] - 1) * dilation) / 2)),
            )
            self.hor_conv = layers.Conv2D(
                filters=out_channels,
                kernel_size=(1, self.converted_ks[1]),
                strides=stride,
                dilation_rate=dilation,
                groups=groups,
                use_bias=bias,
            )
            self.hor_bn = layers.BatchNormalization()

        # Identity branch, only valid when input and output shapes match
        self.rbr_identity = layers.BatchNormalization() if out_channels == in_channels and stride == 1 else None

    def call(self, x: tf.Tensor, **kwargs: Any) -> tf.Tensor:
        # After reparameterization, a single fused convolution replaces all branches
        if hasattr(self, "fused_conv"):
            return self.activation(self.fused_conv(self.conv_pad(x, **kwargs), **kwargs))

        main_outputs = self.bn(self.conv(self.conv_pad(x, **kwargs), **kwargs), **kwargs)
        vertical_outputs = (
            self.ver_bn(self.ver_conv(self.ver_pad(x, **kwargs), **kwargs), **kwargs)
            if self.ver_conv is not None and self.ver_bn is not None
            else 0
        )
        horizontal_outputs = (
            self.hor_bn(self.hor_conv(self.hor_pad(x, **kwargs), **kwargs), **kwargs)
            if self.hor_bn is not None and self.hor_conv is not None
            else 0
        )
        id_out = self.rbr_identity(x, **kwargs) if self.rbr_identity is not None else 0

        return self.activation(main_outputs + vertical_outputs + horizontal_outputs + id_out)

    # The following logic is used to reparametrize the layer
    # Adapted from: https://github.com/mindee/doctr/blob/main/doctr/models/modules/layers/pytorch.py
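    # A BatchNorm applied after a conv is a per-channel affine transform, so at
    # inference time it can be folded into the conv weights:
    #   W_fused = W * gamma / sqrt(moving_var + eps)
    #   b_fused = beta - moving_mean * gamma / sqrt(moving_var + eps)
    # The helpers below apply this fusion to every branch, zero-pad each branch
    # kernel to the main MxN kernel size, and sum the results into one
    # equivalent convolution (plus bias).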
    def _identity_to_conv(
        self, identity: layers.BatchNormalization
    ) -> Union[Tuple[tf.Tensor, tf.Tensor], Tuple[int, int]]:
        if identity is None or not hasattr(identity, "moving_mean") or not hasattr(identity, "moving_variance"):
            return 0, 0
        if not hasattr(self, "id_tensor"):
            input_dim = self.in_channels // self.groups
            # Build a 1x1 kernel that reproduces the identity mapping channel-wise
            kernel_value = np.zeros((1, 1, input_dim, self.in_channels), dtype=np.float32)
            for i in range(self.in_channels):
                kernel_value[0, 0, i % input_dim, i] = 1
            id_tensor = tf.constant(kernel_value, dtype=tf.float32)
            self.id_tensor = self._pad_to_mxn_tensor(id_tensor)
        kernel = self.id_tensor
        std = tf.sqrt(identity.moving_variance + identity.epsilon)
        t = tf.reshape(identity.gamma / std, (1, 1, 1, -1))
        return kernel * t, identity.beta - identity.moving_mean * identity.gamma / std

    def _fuse_bn_tensor(self, conv: layers.Conv2D, bn: layers.BatchNormalization) -> Tuple[tf.Tensor, tf.Tensor]:
        kernel = self._pad_to_mxn_tensor(conv.kernel)
        std = tf.sqrt(bn.moving_variance + bn.epsilon)
        t = tf.reshape(bn.gamma / std, (1, 1, 1, -1))
        return kernel * t, bn.beta - bn.moving_mean * bn.gamma / std

    def _get_equivalent_kernel_bias(self) -> Tuple[tf.Tensor, tf.Tensor]:
        kernel_mxn, bias_mxn = self._fuse_bn_tensor(self.conv, self.bn)
        if self.ver_conv is not None:
            kernel_mx1, bias_mx1 = self._fuse_bn_tensor(self.ver_conv, self.ver_bn)
        else:
            kernel_mx1, bias_mx1 = 0, 0
        if self.hor_conv is not None:
            kernel_1xn, bias_1xn = self._fuse_bn_tensor(self.hor_conv, self.hor_bn)
        else:
            kernel_1xn, bias_1xn = 0, 0
        kernel_id, bias_id = self._identity_to_conv(self.rbr_identity)
        kernel_mxn = kernel_mxn + kernel_mx1 + kernel_1xn + kernel_id
        bias_mxn = bias_mxn + bias_mx1 + bias_1xn + bias_id
        return kernel_mxn, bias_mxn

    def _pad_to_mxn_tensor(self, kernel: tf.Tensor) -> tf.Tensor:
        kernel_height, kernel_width = self.converted_ks
        height, width = kernel.shape[:2]
        pad_left_right = tf.maximum(0, (kernel_width - width) // 2)
        pad_top_down = tf.maximum(0, (kernel_height - height) // 2)
        return tf.pad(kernel, [[pad_top_down, pad_top_down], [pad_left_right, pad_left_right], [0, 0], [0, 0]])

    def reparameterize_layer(self) -> None:
        kernel, bias = self._get_equivalent_kernel_bias()
        self.fused_conv = layers.Conv2D(
            filters=self.conv.filters,
            kernel_size=self.conv.kernel_size,
            strides=self.conv.strides,
            padding=self.conv.padding,
            dilation_rate=self.conv.dilation_rate,
            groups=self.conv.groups,
            use_bias=True,
        )
        # Build with the full input channel count so that the (grouped) kernel shape
        # matches the fused kernel, then initialize weights and biases
        self.fused_conv.build(input_shape=(None, None, None, self.in_channels))
        self.fused_conv.set_weights([kernel.numpy(), bias.numpy()])
        # Freeze the fused layer and drop the training-time branches
        for para in self.trainable_variables:
            para._trainable = False
        for attr in ["conv", "bn", "ver_conv", "ver_bn", "hor_conv", "hor_bn", "rbr_identity"]:
            if hasattr(self, attr):
                delattr(self, attr)
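

# Minimal usage sketch (illustrative only, not part of the doctr API): the shapes
# and hyper-parameters below are arbitrary assumptions. It exercises the
# multi-branch forward pass, fuses the branches, and runs the single-conv path.
if __name__ == "__main__":
    layer = FASTConvLayer(in_channels=4, out_channels=4, kernel_size=3)
    x = tf.random.uniform((1, 32, 32, 4))  # NHWC input
    y_branches = layer(x, training=False)  # sums main, vertical, horizontal and identity branches
    layer.reparameterize_layer()  # folds every branch into `fused_conv`
    y_fused = layer(x, training=False)  # same mapping, single convolution
    # Up to numerical precision, both paths should agree
    print(float(tf.reduce_max(tf.abs(y_branches - y_fused))))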