import torch
import torch.nn.functional as F
from dataclasses import dataclass
from diffusers.utils import BaseOutput
from typing import Any, Dict, List, Optional, Tuple, Union
from diffusers.models.unet_2d_blocks import UNetMidBlock2D, UpDecoderBlock2D, CrossAttnDownBlock2D, DownBlock2D, UNetMidBlock2DCrossAttn, UpBlock2D, CrossAttnUpBlock2D
from diffusers.models.resnet import ResnetBlock2D
from diffusers.models.attention import AttentionBlock
from diffusers.models.cross_attention import CrossAttention
from attribution import FullyConnectedLayer
import math


def customize_vae_decoder(vae, phi_dimension, modulation, finetune, weight_offset, lr_multiplier):
    """Patch ``vae`` in place so that its decoder is conditioned on an encoded fingerprint.

    Affine ``FullyConnectedLayer`` heads are attached to the decoder:
    ``affine_d``/``affine_e`` on every ``ResnetBlock2D`` and ``affine_q``/``affine_k``/
    ``affine_v`` on every ``AttentionBlock``. At decode time each head maps the fingerprint
    to one scale per *input* channel of the weight it modulates, so every sample in the
    batch is processed with its own weights.

    The following are replaced by bound methods that thread the fingerprint through:
    the ``forward`` of the decoder and of its ``UNetMidBlock2D``, ``UpDecoderBlock2D``,
    ``ResnetBlock2D`` and ``AttentionBlock`` children, plus ``vae._decode`` and
    ``vae.decode``. Afterwards, decode with ``vae.decode(z, encoded_fingerprint)``.

    Args:
        vae: model whose ``decoder`` (and ``_decode``/``decode``) is patched in place.
        phi_dimension: input size of every affine head, i.e. the size of the last
            dimension of the encoded fingerprint.
        modulation: string of flags selecting what to modulate: ``'d'`` -> resnet
            ``conv1``, ``'e'`` -> resnet ``conv2``, ``'q'``/``'k'``/``'v'`` -> attention
            ``query``/``key``/``value``.
        finetune: ``'all'`` leaves ``requires_grad`` untouched. ``'match'`` makes only the
            modulated weights and their affine heads trainable and freezes every other
            decoder parameter. Any other value also leaves ``requires_grad`` untouched.
        weight_offset: if false, the modulated weight is ``phi * W``; if true, it is
            ``W + phi * W``.
        lr_multiplier: forwarded to every ``FullyConnectedLayer``.

    Returns:
        The same ``vae`` object, patched.

    Note:
        Modulated layers use one weight set per fingerprint row, so the fingerprint's
        batch size must match the batch being decoded. The sliced ``decode`` path
        (``use_slicing``) splits the fingerprint together with the latents when their
        batch sizes match. Otherwise (e.g. a single shared fingerprint) it passes the
        fingerprint unchanged to every slice.
    """
    d = 'd' in modulation  # modulate conv1 of every ResnetBlock2D
    e = 'e' in modulation  # modulate conv2 of every ResnetBlock2D
    q = 'q' in modulation  # modulate AttentionBlock.query
    k = 'k' in modulation  # modulate AttentionBlock.key
    v = 'v' in modulation  # modulate AttentionBlock.value

    def add_affine_conv(vaed):
        """Recursively attach ``affine_d``/``affine_e`` heads to every ResnetBlock2D under ``vaed``."""
        if not (d or e):
            return

        for layer in vaed.children():
            # Exact type match on purpose: subclasses may have a different forward signature.
            if type(layer) == ResnetBlock2D:
                # One output per *input* channel of the conv weight (weight.shape[1]).
                # NOTE(review): bias_init=1 presumably makes the initial scales ~1; depends
                # on FullyConnectedLayer's initialisation.
                if d:
                    layer.affine_d = FullyConnectedLayer(phi_dimension, layer.conv1.weight.shape[1], lr_multiplier=lr_multiplier, bias_init=1)
                if e:
                    layer.affine_e = FullyConnectedLayer(phi_dimension, layer.conv2.weight.shape[1], lr_multiplier=lr_multiplier, bias_init=1)
            else:
                add_affine_conv(layer)

    def add_affine_attn(vaed):
        """Recursively attach ``affine_q``/``affine_k``/``affine_v`` heads to every AttentionBlock under ``vaed``."""
        if not (q or k or v):
            return

        for layer in vaed.children():
            if type(layer) == AttentionBlock:
                # One output per *input* feature of each projection (weight.shape[1]).
                if q:
                    layer.affine_q = FullyConnectedLayer(phi_dimension, layer.query.weight.shape[1], lr_multiplier=lr_multiplier, bias_init=1)
                if k:
                    layer.affine_k = FullyConnectedLayer(phi_dimension, layer.key.weight.shape[1], lr_multiplier=lr_multiplier, bias_init=1)
                if v:
                    layer.affine_v = FullyConnectedLayer(phi_dimension, layer.value.weight.shape[1], lr_multiplier=lr_multiplier, bias_init=1)
            else:
                add_affine_attn(layer)

    def impose_grad_condition(vaed, finetune):
        """Set ``requires_grad`` on ``vaed``'s parameters according to ``finetune``.

        Only ``'match'`` changes anything. It enables gradients for the weights selected by
        ``modulation`` and their affine heads, and disables them for all other parameters.
        Matching is by parameter-name substring.
        """
        # NOTE(review): values other than 'all'/'match' are a silent no-op; confirm this is
        # intended rather than "freeze everything".
        if finetune != 'match':
            return

        for name, params in vaed.named_parameters():
            d_cond = d and (('resnets' in name and 'conv1' in name) or 'affine_d' in name)
            e_cond = e and (('resnets' in name and 'conv2' in name) or 'affine_e' in name)
            q_cond = q and (('attentions' in name and 'query' in name) or 'affine_q' in name)
            k_cond = k and (('attentions' in name and 'key' in name) or 'affine_k' in name)
            v_cond = v and (('attentions' in name and 'value' in name) or 'affine_v' in name)
            params.requires_grad = bool(q_cond or k_cond or v_cond or d_cond or e_cond)

    def change_forward(vaed, layer_type, new_forward):
        """Bind ``new_forward`` as ``forward`` on every module of exactly ``layer_type`` under ``vaed``.

        Matched modules are not descended into. Their children are reached by separate
        calls with other ``layer_type`` values.
        """
        for layer in vaed.children():
            if type(layer) == layer_type:
                bound_method = new_forward.__get__(layer, layer.__class__)
                setattr(layer, 'forward', bound_method)
            else:
                change_forward(layer, layer_type, new_forward)

    def _modulate(weight, phis):
        """Return ``weight`` scaled along its input-channel axis (axis 1) by each row of ``phis``.

        ``weight`` is ``(out, in, *kernel)`` and ``phis`` is ``(batch, in)``. The result is
        ``(batch, out, in, *kernel)``. With ``weight_offset``, the scaled copy is added to
        the unmodulated weight instead of replacing it.
        """
        scale = phis.view(phis.shape[0], 1, -1, *([1] * (weight.dim() - 2)))
        modulated = scale * weight.unsqueeze(0)
        if weight_offset:
            modulated = weight.unsqueeze(0) + modulated
        return modulated

    def _modulated_conv2d(hidden_states, conv, phis):
        """Apply ``conv`` to ``hidden_states`` using a separately modulated weight for each sample."""
        batch_size = phis.shape[0]
        weight = _modulate(conv.weight, phis)  # (batch, out, in/groups, kh, kw)
        # Fold the batch into the channel axis and multiply the group count by the batch
        # size, so that sample b is convolved only with weight[b]. This requires
        # hidden_states.shape[0] == batch_size.
        out = F.conv2d(
            hidden_states.contiguous().view(1, -1, hidden_states.shape[-2], hidden_states.shape[-1]),
            weight.view(-1, *weight.shape[2:]),
            stride=conv.stride,
            padding=conv.padding,
            dilation=conv.dilation,
            groups=batch_size * conv.groups,
        )
        out = out.view(batch_size, weight.shape[1], out.shape[-2], out.shape[-1])
        if conv.bias is not None:
            out = out + conv.bias.view(1, -1, 1, 1)
        return out

    def _modulated_linear(hidden_states, linear, phis):
        """Apply ``linear`` to ``(batch, tokens, in)`` input using a separately modulated weight for each sample."""
        weight = _modulate(linear.weight, phis)  # (batch, out, in)
        out = torch.bmm(hidden_states, weight.transpose(1, 2))
        if linear.bias is not None:
            out = out + linear.bias
        return out

    def new_forward_MB(self, hidden_states, encoded_fingerprint, temb=None):
        """UNetMidBlock2D forward that passes ``(hidden_states, encoded_fingerprint)`` to its resnets/attentions."""
        hidden_states = self.resnets[0]((hidden_states, encoded_fingerprint), temb)
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            if attn is not None:
                hidden_states = attn((hidden_states, encoded_fingerprint))
            hidden_states = resnet((hidden_states, encoded_fingerprint), temb)

        return hidden_states

    def new_forward_UDB(self, hidden_states, encoded_fingerprint):
        """UpDecoderBlock2D forward that passes the fingerprint to its resnets (upsamplers are unmodulated)."""
        for resnet in self.resnets:
            hidden_states = resnet((hidden_states, encoded_fingerprint), temb=None)

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states)

        return hidden_states

    def new_forward_RB(self, input_tensor, temb):
        """ResnetBlock2D forward; ``input_tensor`` is a ``(hidden_states, encoded_fingerprint)`` pair.

        ``conv1`` (flag ``d``) and/or ``conv2`` (flag ``e``) use per-sample modulated
        weights. Everything else mirrors the stock block.
        """
        input_tensor, encoded_fingerprint = input_tensor
        hidden_states = input_tensor

        hidden_states = self.norm1(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)

        if self.upsample is not None:
            # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
            if hidden_states.shape[0] >= 64:
                input_tensor = input_tensor.contiguous()
                hidden_states = hidden_states.contiguous()
            input_tensor = self.upsample(input_tensor)
            hidden_states = self.upsample(hidden_states)
        elif self.downsample is not None:
            input_tensor = self.downsample(input_tensor)
            hidden_states = self.downsample(hidden_states)

        if d:
            hidden_states = _modulated_conv2d(hidden_states, self.conv1, self.affine_d(encoded_fingerprint))
        else:
            hidden_states = self.conv1(hidden_states)

        if temb is not None:
            temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]

        if temb is not None and self.time_embedding_norm == "default":
            hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states)

        if temb is not None and self.time_embedding_norm == "scale_shift":
            scale, shift = torch.chunk(temb, 2, dim=1)
            hidden_states = hidden_states * (1 + scale) + shift

        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)

        if e:
            hidden_states = _modulated_conv2d(hidden_states, self.conv2, self.affine_e(encoded_fingerprint))
        else:
            hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

        return output_tensor

    def new_forward_AB(self, hidden_states):
        """AttentionBlock forward; ``hidden_states`` is a ``(hidden_states, encoded_fingerprint)`` pair.

        The query/key/value projections selected by ``q``/``k``/``v`` use per-sample
        modulated weights. Everything else mirrors the stock block.
        """
        hidden_states, encoded_fingerprint = hidden_states
        residual = hidden_states
        batch, channel, height, width = hidden_states.shape

        # norm
        hidden_states = self.group_norm(hidden_states)

        # (batch, channel, H, W) -> (batch, H*W, channel)
        hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)

        # proj to q, k, v
        if q:
            query_proj = _modulated_linear(hidden_states, self.query, self.affine_q(encoded_fingerprint))
        else:
            query_proj = self.query(hidden_states)

        if k:
            key_proj = _modulated_linear(hidden_states, self.key, self.affine_k(encoded_fingerprint))
        else:
            key_proj = self.key(hidden_states)

        if v:
            value_proj = _modulated_linear(hidden_states, self.value, self.affine_v(encoded_fingerprint))
        else:
            value_proj = self.value(hidden_states)

        scale = 1 / math.sqrt(self.channels / self.num_heads)

        query_proj = self.reshape_heads_to_batch_dim(query_proj)
        key_proj = self.reshape_heads_to_batch_dim(key_proj)
        value_proj = self.reshape_heads_to_batch_dim(value_proj)

        if self._use_memory_efficient_attention_xformers:
            # xformers is an optional dependency that is not imported at module level, so
            # import it only on the code path that actually uses it.
            import xformers.ops

            # Memory efficient attention
            hidden_states = xformers.ops.memory_efficient_attention(
                query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op
            )
            hidden_states = hidden_states.to(query_proj.dtype)
        else:
            # beta=0, so the uninitialised `empty` buffer only supplies shape/dtype/device.
            attention_scores = torch.baddbmm(
                torch.empty(
                    query_proj.shape[0],
                    query_proj.shape[1],
                    key_proj.shape[1],
                    dtype=query_proj.dtype,
                    device=query_proj.device,
                ),
                query_proj,
                key_proj.transpose(-1, -2),
                beta=0,
                alpha=scale,
            )
            attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
            hidden_states = torch.bmm(attention_probs, value_proj)

        # reshape hidden_states
        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)

        # compute next hidden_states
        hidden_states = self.proj_attn(hidden_states)

        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)

        # res connect and rescale
        hidden_states = (hidden_states + residual) / self.rescale_output_factor
        return hidden_states

    # Reference: https://github.com/huggingface/diffusers
    def new_forward_vaed(self, z, enconded_fingerprint):
        """Decoder forward that threads the fingerprint through the mid block and every up block.

        The parameter spelling ``enconded_fingerprint`` is kept unchanged for keyword
        compatibility.
        """
        sample = z
        sample = self.conv_in(sample)

        # middle
        sample = self.mid_block(sample, enconded_fingerprint)

        # up
        for up_block in self.up_blocks:
            sample = up_block(sample, enconded_fingerprint)

        # post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        return sample

    @dataclass
    class DecoderOutput(BaseOutput):
        """
        Output of decoding method.
        Args:
            sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
                Decoded output sample of the model. Output of the last layer of the model.
        """

        sample: torch.FloatTensor

    def new__decode(self, z: torch.FloatTensor, encoded_fingerprint: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
        """Run ``post_quant_conv`` and the patched decoder on latents ``z`` with ``encoded_fingerprint``."""
        z = self.post_quant_conv(z)
        dec = self.decoder(z, encoded_fingerprint)

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

    def new_decode(self, z: torch.FloatTensor, encoded_fingerprint: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
        """Decode latents ``z`` with ``encoded_fingerprint``, one sample at a time when ``use_slicing`` is set."""
        if self.use_slicing and z.shape[0] > 1:
            # Modulated layers require the fingerprint batch to match the decoded batch, so
            # slice the fingerprint together with the latents when their batch sizes match.
            # Otherwise (e.g. one fingerprint shared by all latents) pass it through as is.
            if encoded_fingerprint.shape[0] == z.shape[0]:
                fingerprint_slices = encoded_fingerprint.split(1)
            else:
                fingerprint_slices = [encoded_fingerprint] * z.shape[0]
            decoded_slices = [
                self._decode(z_slice, fingerprint_slice).sample
                for z_slice, fingerprint_slice in zip(z.split(1), fingerprint_slices)
            ]
            decoded = torch.cat(decoded_slices)
        else:
            decoded = self._decode(z, encoded_fingerprint).sample

        if not return_dict:
            return (decoded,)

        return DecoderOutput(sample=decoded)

    add_affine_conv(vae.decoder)
    add_affine_attn(vae.decoder)
    # Must run after the affine heads exist so their parameters are covered too.
    impose_grad_condition(vae.decoder, finetune)
    change_forward(vae.decoder, UNetMidBlock2D, new_forward_MB)
    change_forward(vae.decoder, UpDecoderBlock2D, new_forward_UDB)
    change_forward(vae.decoder, ResnetBlock2D, new_forward_RB)
    change_forward(vae.decoder, AttentionBlock, new_forward_AB)
    setattr(vae.decoder, 'forward', new_forward_vaed.__get__(vae.decoder, vae.decoder.__class__))
    setattr(vae, '_decode', new__decode.__get__(vae, vae.__class__))
    setattr(vae, 'decode', new_decode.__get__(vae, vae.__class__))

    return vae