from einops import rearrange

import torch
import torch.nn as nn
import torch.nn.functional as F

# from timm.models.layers import trunc_normal_

from segmenter_model.utils import padding, unpadding


class Segmenter(nn.Module):
    """Encoder/decoder segmentation wrapper.

    The image is fed to ``encoder``. Its leading CLS (and DIST, when the
    encoder is distilled) tokens are dropped, and the remaining patch tokens
    are decoded by ``decoder``.

    Args:
        encoder: module exposing ``patch_size``, ``distilled``,
            ``no_weight_decay()``, ``get_attention_map(im, layer_id)`` and
            accepting ``return_features=True`` in its forward.
        decoder: module exposing ``no_weight_decay()``,
            ``get_attention_map(x, layer_id)`` and accepting
            ``(x, (H, W))`` plus optional ``features_only`` / ``no_rearrange``
            keyword arguments in its forward.
        n_cls: number of classes; stored as ``self.n_cls`` (not otherwise
            used inside this class).
    """

    def __init__(
            self,
            encoder,
            decoder,
            n_cls,
    ):
        super().__init__()
        self.n_cls = n_cls
        self.patch_size = encoder.patch_size
        self.encoder = encoder
        self.decoder = decoder

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the parameter names to exclude from weight decay.

        The encoder's and decoder's own ``no_weight_decay()`` names are
        prefixed with ``"encoder."`` / ``"decoder."`` so that they match
        this module's ``named_parameters()`` keys.
        """
        nwd_params = set()
        for prefix, module in (("encoder.", self.encoder), ("decoder.", self.decoder)):
            nwd_params.update(prefix + name for name in module.no_weight_decay())
        return nwd_params

    def _drop_extra_tokens(self, x):
        """Drop the leading CLS token (and DIST token if the encoder is distilled)."""
        return x[:, 1 + self.encoder.distilled:]

    def forward(self, im, decoder_features=False, no_upsample=False, encoder_features=False, no_rearrange=False,
                cls_only=False, encoder_only=False):
        """Run the encoder and (optionally) the decoder on ``im``.

        Args:
            im: image batch; dims 2 and 3 are read as height and width
                (presumably ``(B, C, H, W)``).
            decoder_features: call the decoder with ``features_only=True``.
                Combined with ``no_rearrange`` this returns immediately,
                without upsampling.
            no_upsample: skip padding of ``im`` and skip the final bilinear
                upsampling/unpadding of the output.
            encoder_features: also return the encoder patch tokens, as the
                first element of an ``(enc_fts, output)`` tuple.
            no_rearrange: keep the encoder features in token layout
                ``(b, h*w, c)`` instead of rearranging them to
                ``(b, c, h, w)``. It is also forwarded to the decoder when
                ``decoder_features`` is set.
            cls_only: return only the first (CLS) token of the encoder output.
            encoder_only: return only the encoder patch features; the decoder
                is not run.

        Returns:
            The decoder output, or ``(enc_fts, output)`` when
            ``encoder_features`` is set. ``cls_only`` and ``encoder_only``
            short-circuit as described above.
        """
        H_ori, W_ori = im.size(2), im.size(3)
        if not no_upsample:
            # presumably pads H/W up to a multiple of patch_size (see segmenter_model.utils.padding)
            im = padding(im, self.patch_size)
        H, W = im.size(2), im.size(3)

        x = self.encoder(im, return_features=True)  # self.patch_size times smaller than im

        if cls_only:
            return x[:, 0]
        # remove CLS/DIST tokens for decoding
        x = self._drop_extra_tokens(x)

        enc_fts = None
        if encoder_features or encoder_only:
            # clone so the returned features do not share storage with the decoder input
            enc_fts = x.clone()
            if not no_rearrange:
                enc_fts = rearrange(enc_fts, "b (h w) c -> b c h w", h=H // self.patch_size)
            if encoder_only:
                return enc_fts

        if decoder_features:
            output = self.decoder(x, (H, W), features_only=True, no_rearrange=no_rearrange)
            if no_rearrange:
                # token-layout features cannot be bilinearly upsampled; return as-is
                return (enc_fts, output) if encoder_features else output
        else:
            output = self.decoder(x, (H, W))  # shape (BS, NCLS, H/self.patch_size, W/self.patch_size)

        if not no_upsample:
            # upsample self.patch_size times; align_corners=False is the PyTorch default, made explicit
            output = F.interpolate(output, size=(H, W), mode="bilinear", align_corners=False)
            output = unpadding(output, (H_ori, W_ori))

        return (enc_fts, output) if encoder_features else output

    def get_attention_map_enc(self, im, layer_id):
        """Return the encoder's attention map for ``layer_id`` (delegates to the encoder)."""
        return self.encoder.get_attention_map(im, layer_id)

    def get_attention_map_dec(self, im, layer_id):
        """Return the decoder's attention map for ``layer_id``.

        ``im`` is encoded, and the CLS/DIST tokens are dropped, before the
        patch tokens are handed to the decoder. Unlike ``forward``, ``im`` is
        not padded here.
        """
        x = self._drop_extra_tokens(self.encoder(im, return_features=True))
        return self.decoder.get_attention_map(x, layer_id)