# Copyright (C) 2021-2024, Mindee.
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from typing import Tuple, Union

import numpy as np
import torch
import torch.nn as nn

__all__ = ["FASTConvLayer"]


class FASTConvLayer(nn.Module):
"""Convolutional layer used in the TextNet and FAST architectures"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int]],
stride: int = 1,
dilation: int = 1,
groups: int = 1,
bias: bool = False,
) -> None:
super().__init__()
self.groups = groups
self.in_channels = in_channels
self.converted_ks = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
self.hor_conv, self.hor_bn = None, None
self.ver_conv, self.ver_bn = None, None
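
        # "Same" padding for the full-size kernel, accounting for dilation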
padding = (int(((self.converted_ks[0] - 1) * dilation) / 2), int(((self.converted_ks[1] - 1) * dilation) / 2))
self.activation = nn.ReLU(inplace=True)
self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size=self.converted_ks,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
)
self.bn = nn.BatchNorm2d(out_channels)
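
        # Vertical (k x 1) branch, only needed when the full kernel is wider than one column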
if self.converted_ks[1] != 1:
self.ver_conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size=(self.converted_ks[0], 1),
padding=(int(((self.converted_ks[0] - 1) * dilation) / 2), 0),
stride=stride,
dilation=dilation,
groups=groups,
bias=bias,
)
self.ver_bn = nn.BatchNorm2d(out_channels)
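
        # Horizontal (1 x k) branch, only needed when the full kernel is taller than one row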
if self.converted_ks[0] != 1:
self.hor_conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size=(1, self.converted_ks[1]),
padding=(0, int(((self.converted_ks[1] - 1) * dilation) / 2)),
stride=stride,
dilation=dilation,
groups=groups,
bias=bias,
)
self.hor_bn = nn.BatchNorm2d(out_channels)
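
        # Identity branch (BN only), applicable when input and output shapes match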
self.rbr_identity = nn.BatchNorm2d(in_channels) if out_channels == in_channels and stride == 1 else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
if hasattr(self, "fused_conv"):
return self.activation(self.fused_conv(x))
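        # Multi-branch path: sum the outputs of every active branch before the activation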
main_outputs = self.bn(self.conv(x))
        vertical_outputs = (
            self.ver_bn(self.ver_conv(x)) if self.ver_conv is not None and self.ver_bn is not None else 0
        )
horizontal_outputs = (
self.hor_bn(self.hor_conv(x)) if self.hor_bn is not None and self.hor_conv is not None else 0
)
id_out = self.rbr_identity(x) if self.rbr_identity is not None else 0
return self.activation(main_outputs + vertical_outputs + horizontal_outputs + id_out)

    # The following logic is used to reparametrize the layer
    # Borrowed from: https://github.com/czczup/FAST/blob/main/models/utils/nas_utils.py
def _identity_to_conv(
self, identity: Union[nn.BatchNorm2d, None]
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
if identity is None or identity.running_var is None:
return 0, 0
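        # Build (once) a 1x1 kernel that maps each input channel to itself within its group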
if not hasattr(self, "id_tensor"):
input_dim = self.in_channels // self.groups
kernel_value = np.zeros((self.in_channels, input_dim, 1, 1), dtype=np.float32)
for i in range(self.in_channels):
kernel_value[i, i % input_dim, 0, 0] = 1
id_tensor = torch.from_numpy(kernel_value).to(identity.weight.device)
self.id_tensor = self._pad_to_mxn_tensor(id_tensor)
kernel = self.id_tensor
std = (identity.running_var + identity.eps).sqrt()
t = (identity.weight / std).reshape(-1, 1, 1, 1)
return kernel * t, identity.bias - identity.running_mean * identity.weight / std

    def _fuse_bn_tensor(self, conv: nn.Conv2d, bn: nn.BatchNorm2d) -> Tuple[torch.Tensor, torch.Tensor]:
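        # Fold the BN into the conv: scale the kernel by gamma / std, shift the bias by beta - mean * gamma / std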
kernel = conv.weight
kernel = self._pad_to_mxn_tensor(kernel)
std = (bn.running_var + bn.eps).sqrt() # type: ignore
t = (bn.weight / std).reshape(-1, 1, 1, 1)
return kernel * t, bn.bias - bn.running_mean * bn.weight / std

    def _get_equivalent_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
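        # Sum the BN-folded kernels and biases of all branches into a single equivalent convolution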
kernel_mxn, bias_mxn = self._fuse_bn_tensor(self.conv, self.bn)
if self.ver_conv is not None:
kernel_mx1, bias_mx1 = self._fuse_bn_tensor(self.ver_conv, self.ver_bn) # type: ignore[arg-type]
else:
kernel_mx1, bias_mx1 = 0, 0 # type: ignore[assignment]
if self.hor_conv is not None:
kernel_1xn, bias_1xn = self._fuse_bn_tensor(self.hor_conv, self.hor_bn) # type: ignore[arg-type]
else:
kernel_1xn, bias_1xn = 0, 0 # type: ignore[assignment]
kernel_id, bias_id = self._identity_to_conv(self.rbr_identity)
kernel_mxn = kernel_mxn + kernel_mx1 + kernel_1xn + kernel_id
bias_mxn = bias_mxn + bias_mx1 + bias_1xn + bias_id
return kernel_mxn, bias_mxn

    def _pad_to_mxn_tensor(self, kernel: torch.Tensor) -> torch.Tensor:
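        # Zero-pad a smaller branch kernel (k x 1, 1 x k or 1 x 1) to the full kernel size so the kernels can be summed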
kernel_height, kernel_width = self.converted_ks
height, width = kernel.shape[2:]
pad_left_right = (kernel_width - width) // 2
pad_top_down = (kernel_height - height) // 2
return torch.nn.functional.pad(kernel, [pad_left_right, pad_left_right, pad_top_down, pad_top_down], value=0)
def reparameterize_layer(self):
if hasattr(self, "fused_conv"):
return
kernel, bias = self._get_equivalent_kernel_bias()
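        # Replace the multi-branch block by a single convolution carrying the equivalent weights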
self.fused_conv = nn.Conv2d(
in_channels=self.conv.in_channels,
out_channels=self.conv.out_channels,
kernel_size=self.conv.kernel_size, # type: ignore[arg-type]
stride=self.conv.stride, # type: ignore[arg-type]
padding=self.conv.padding, # type: ignore[arg-type]
dilation=self.conv.dilation, # type: ignore[arg-type]
groups=self.conv.groups,
bias=True,
)
self.fused_conv.weight.data = kernel
self.fused_conv.bias.data = bias # type: ignore[union-attr]
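        # Detach all parameters and drop the original branches; only fused_conv is used from now on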
for para in self.parameters():
para.detach_()
for attr in ["conv", "bn", "ver_conv", "ver_bn", "hor_conv", "hor_bn"]:
if hasattr(self, attr):
self.__delattr__(attr)
if hasattr(self, "rbr_identity"):
self.__delattr__("rbr_identity")