# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

from annotator.uniformer.mmcv.cnn import CONV_LAYERS, ConvAWS2d, constant_init
from annotator.uniformer.mmcv.ops.deform_conv import deform_conv2d
from annotator.uniformer.mmcv.utils import TORCH_VERSION, digit_version


@CONV_LAYERS.register_module(name='SAC')
class SAConv2d(ConvAWS2d):
    """SAC (Switchable Atrous Convolution)

    This is an implementation of SAC in DetectoRS
    (https://arxiv.org/pdf/2006.02334.pdf).

    The forward pass runs the same convolution twice on the input: once with
    the configured padding/dilation and ``weight``, and once with both
    padding and dilation multiplied by 3 and ``weight + weight_diff``. The
    two results are blended per location by a learned ``switch`` map. A
    global-context term (global average pool followed by a 1x1 conv) is added
    to the input before and to the output after the blended convolution.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of
            the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements.
            Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the
            output. Default: ``True``
        use_deform (bool, optional): If ``True``, replace convolution with
            deformable convolution. Default: ``False``.

    Note:
        ``forward`` temporarily overwrites ``self.padding`` and
        ``self.dilation`` while computing the large-rate branch, so a single
        instance must not be run concurrently from several threads.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 use_deform=False):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias)
        self.use_deform = use_deform
        # 1x1 conv producing the single-channel blending map; it shares the
        # main conv's stride so its output can be multiplied with the conv
        # outputs.
        self.switch = nn.Conv2d(
            self.in_channels, 1, kernel_size=1, stride=stride, bias=True)
        # Additive weight delta used only by the large-rate branch. Allocated
        # as zeros so it is well defined even if ``init_weights`` is
        # overridden by a subclass.
        self.weight_diff = nn.Parameter(torch.zeros_like(self.weight))
        self.pre_context = nn.Conv2d(
            self.in_channels, self.in_channels, kernel_size=1, bias=True)
        self.post_context = nn.Conv2d(
            self.out_channels, self.out_channels, kernel_size=1, bias=True)
        if self.use_deform:
            # NOTE(review): 18 offset channels presumably correspond to
            # 2 * 3 * 3 (x/y offsets for a 3x3 kernel, one deform group) --
            # confirm before using a different kernel size with use_deform.
            self.offset_s = nn.Conv2d(
                self.in_channels,
                18,
                kernel_size=3,
                padding=1,
                stride=stride,
                bias=True)
            self.offset_l = nn.Conv2d(
                self.in_channels,
                18,
                kernel_size=3,
                padding=1,
                stride=stride,
                bias=True)
        self.init_weights()

    def init_weights(self):
        """Initialise the SAC-specific sub-modules.

        The switch conv is initialised with constant 0 and ``bias=1``, the
        weight delta is zeroed, and the context / offset convs are
        initialised with constant 0 (see ``constant_init`` for the exact
        semantics of its arguments).
        """
        constant_init(self.switch, 0, bias=1)
        self.weight_diff.data.zero_()
        constant_init(self.pre_context, 0)
        constant_init(self.post_context, 0)
        if self.use_deform:
            constant_init(self.offset_s, 0)
            constant_init(self.offset_l, 0)

    def _apply_conv(self, x, weight, zero_bias, offset=None):
        """Run one SAC branch using the current ``self.padding/dilation``.

        Args:
            x (Tensor): Input feature map.
            weight (Tensor): Convolution weight for this branch.
            zero_bias (Tensor): Zero bias passed to ``_conv_forward`` on
                torch >= 1.8.0, where the bias argument is required.
            offset (Tensor, optional): Deformable-conv offsets. When given,
                ``deform_conv2d`` is used instead of the regular convolution.

        Returns:
            Tensor: Output of the (deformable) convolution.
        """
        if offset is not None:
            return deform_conv2d(x, offset, weight, self.stride, self.padding,
                                 self.dilation, self.groups, 1)
        if (TORCH_VERSION == 'parrots'
                or digit_version(TORCH_VERSION) < digit_version('1.5.0')):
            return super().conv2d_forward(x, weight)
        if digit_version(TORCH_VERSION) >= digit_version('1.8.0'):
            # bias is a required argument of _conv_forward in torch 1.8.0
            return super()._conv_forward(x, weight, zero_bias)
        return super()._conv_forward(x, weight)

    def forward(self, x):
        """Apply switchable atrous convolution to ``x``.

        Args:
            x (Tensor): Input feature map.

        Returns:
            Tensor: Blend of the small-rate and large-rate convolution
            outputs with the post-context term added.
        """
        # pre-context
        avg_x = F.adaptive_avg_pool2d(x, output_size=1)
        avg_x = self.pre_context(avg_x)
        avg_x = avg_x.expand_as(x)
        x = x + avg_x
        # switch: computed from a 5x5 average of the reflect-padded input
        avg_x = F.pad(x, pad=(2, 2, 2, 2), mode='reflect')
        avg_x = F.avg_pool2d(avg_x, kernel_size=5, stride=1, padding=0)
        switch = self.switch(avg_x)
        # sac
        weight = self._get_weight(self.weight)
        zero_bias = torch.zeros(
            self.out_channels, device=weight.device, dtype=weight.dtype)

        # small-rate branch: configured padding/dilation, base weight
        offset = self.offset_s(avg_x) if self.use_deform else None
        out_s = self._apply_conv(x, weight, zero_bias, offset)

        # large-rate branch: 3x padding/dilation, weight + weight_diff.
        # The parent conv reads padding/dilation from ``self``, so they are
        # swapped in temporarily; ``finally`` guarantees they are restored
        # even if the convolution raises.
        ori_p = self.padding
        ori_d = self.dilation
        large_p = tuple(3 * p for p in ori_p)
        large_d = tuple(3 * d for d in ori_d)
        weight = weight + self.weight_diff
        try:
            self.padding = large_p
            self.dilation = large_d
            offset = self.offset_l(avg_x) if self.use_deform else None
            out_l = self._apply_conv(x, weight, zero_bias, offset)
        finally:
            self.padding = ori_p
            self.dilation = ori_d

        out = switch * out_s + (1 - switch) * out_l
        # post-context
        avg_x = F.adaptive_avg_pool2d(out, output_size=1)
        avg_x = self.post_context(avg_x)
        avg_x = avg_x.expand_as(out)
        out = out + avg_x
        return out