from types import MethodType
from typing import Optional

from diffusers.models.attention_processor import Attention
import torch
import torch.nn.functional as F

from .feature import *
from .utils import *


def convolution_forward(  # From <class 'diffusers.models.resnet.ResnetBlock2D'>, forward (diffusers==0.28.0)
    self,
    input_tensor: torch.Tensor,
    temb: torch.Tensor,
    *args,
    **kwargs,
) -> torch.Tensor:
    """Residual-block forward pass with optional structure feature injection.

    Meant to be bound onto a resnet block (see ``register_control``) in place of
    its own ``forward``. Besides the block's usual sub-modules, it reads these
    attributes from ``self``: ``do_control``, ``t``, ``structure_schedule``,
    ``structure_target`` and ``batch_order``.

    Args:
        input_tensor: Block input; also the source of the residual shortcut.
        temb: Time embedding, or ``None``. Only projected when
            ``self.time_emb_proj`` is set.
        *args, **kwargs: Accepted for signature compatibility and ignored.

    Returns:
        ``(shortcut + features) / self.output_scale_factor``, passed through
        ``feature_injection`` when control is active for this timestep and
        ``"output_tensor"`` is listed in ``self.structure_target``.

    Raises:
        ValueError: If ``self.time_embedding_norm == "scale_shift"`` and no
            time embedding is available.
    """
    # Control is active only when enabled and the current timestep is scheduled.
    structure_active = self.do_control and self.t in self.structure_schedule

    def inject_if_targeted(tensor, target_name):
        # `feature_injection` is defined elsewhere in the package; it is only
        # reached when this activation is an explicitly requested target.
        if structure_active and target_name in self.structure_target:
            return feature_injection(tensor, batch_order=self.batch_order)
        return tensor

    shortcut = input_tensor
    features = self.nonlinearity(self.norm1(input_tensor))

    # At most one resampler is applied, and upsampling takes precedence.
    resample = self.upsample
    if resample is not None:
        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
        if features.shape[0] >= 64:
            shortcut = shortcut.contiguous()
            features = features.contiguous()
    else:
        resample = self.downsample
    if resample is not None:
        shortcut = resample(shortcut)
        features = resample(features)

    features = self.conv1(features)

    if self.time_emb_proj is not None:
        if not self.skip_time_act:
            temb = self.nonlinearity(temb)
        # Append two singleton axes so the embedding broadcasts over the trailing dims.
        temb = self.time_emb_proj(temb)[:, :, None, None]

    norm_mode = self.time_embedding_norm
    if norm_mode == "scale_shift":
        if temb is None:
            raise ValueError(
                f" `temb` should not be None when `time_embedding_norm` is {self.time_embedding_norm}"
            )
        time_scale, time_shift = torch.chunk(temb, 2, dim=1)
        features = self.norm2(features) * (1 + time_scale) + time_shift
    else:
        # "default" adds the embedding before normalising; any other mode only normalises.
        if norm_mode == "default" and temb is not None:
            features = features + temb
        features = self.norm2(features)

    features = self.conv2(self.dropout(self.nonlinearity(features)))

    # Feature injection (hidden_states)
    features = inject_if_targeted(features, "hidden_states")

    if self.conv_shortcut is not None:
        shortcut = self.conv_shortcut(shortcut)

    # Feature injection (output_tensor)
    return inject_if_targeted((shortcut + features) / self.output_scale_factor, "output_tensor")


class AttnProcessor2_0:  # From <class 'diffusers.models.attention_processor.AttnProcessor2_0'> (diffusers==0.28.0)
    """Scaled-dot-product attention processor with structure/appearance control hooks.

    Adapted from the diffusers processor named in the header comment. In addition
    to the plain attention computation, ``__call__`` reads the following
    attributes from the ``attn`` module it is given (``register_control`` and
    ``register_attr`` in this file assign them):

    * ``do_control``, ``t`` -- whether control is enabled and the current timestep.
    * ``structure_schedule`` / ``structure_target`` -- timesteps at which
      ``feature_injection`` is applied, and to which of ``"query"``, ``"key"``,
      ``"value"``.
    * ``appearance_schedule`` / ``appearance_target`` -- timesteps at which
      ``appearance_transfer`` is applied, and at which stage: ``"before"``,
      ``"value"`` and/or ``"after"``.
    * ``batch_order`` -- forwarded unchanged to both helpers.

    ``feature_injection``, ``appearance_transfer`` and ``normalize`` are defined
    elsewhere in the package; their exact semantics are not visible here.
    """

    def __init__(self):
        """Check that ``torch.nn.functional.scaled_dot_product_attention`` exists.

        Raises:
            ImportError: If the installed PyTorch lacks that function.
        """
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        """Run attention for ``attn``, applying the configured control hooks.

        Args:
            attn: Attention module providing the projections, norms and the
                control attributes listed in the class docstring.
            hidden_states: Query-side input, either 3-D
                ``(batch, sequence, channels)`` or 4-D
                ``(batch, channels, height, width)``; a 4-D input is flattened
                for attention and restored before returning.
            encoder_hidden_states: Key/value-side input. ``None`` means
                self-attention, in which case ``hidden_states`` is used.
            attention_mask: Optional mask, prepared with
                ``attn.prepare_attention_mask``.
            temb: Only used by ``attn.spatial_norm`` when that is set.
            *args: Forwarded to the output projection (see the note there).
            **kwargs: Ignored.

        Returns:
            The attention output, with the residual added when
            ``attn.residual_connection`` is set, divided by
            ``attn.rescale_output_factor``.
        """
        # Each kind of control is active only when enabled and scheduled for this timestep.
        do_structure_control = attn.do_control and attn.t in attn.structure_schedule
        do_appearance_control = attn.do_control and attn.t in attn.appearance_schedule

        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            # Flatten the spatial dims into a sequence: (B, C, H, W) -> (B, H*W, C).
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        # Remember whether this is self-attention; the "before" branch below needs to know.
        no_encoder_hidden_states = encoder_hidden_states is None
        if no_encoder_hidden_states:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        if do_appearance_control:  # Assume we only have this for self attention
            # Queries/keys computed from inputs normalized along dim=-2 (the
            # sequence axis here); they are passed to appearance_transfer below.
            hidden_states_normed = normalize(hidden_states, dim=-2)
            encoder_hidden_states_normed = normalize(encoder_hidden_states, dim=-2)

            query_normed = attn.to_q(hidden_states_normed)
            key_normed = attn.to_k(encoder_hidden_states_normed)

            inner_dim = key_normed.shape[-1]
            head_dim = inner_dim // attn.heads
            query_normed = query_normed.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            key_normed = key_normed.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            # Match query and key injection with structure injection (if injection is happening this layer)
            if do_structure_control:
                if "query" in attn.structure_target:
                    query_normed = feature_injection(query_normed, batch_order=attn.batch_order)
                if "key" in attn.structure_target:
                    key_normed = feature_injection(key_normed, batch_order=attn.batch_order)

        # Appearance transfer (before)
        if do_appearance_control and "before" in attn.appearance_target:
            # Split into heads for the transfer, then merge back to (batch, sequence, channels).
            hidden_states = hidden_states.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            hidden_states = appearance_transfer(hidden_states, query_normed, key_normed, batch_order=attn.batch_order)
            hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

            # Self-attention: keys/values must be built from the transferred states.
            # Cross-attention inputs were already normalized above, so they are
            # deliberately NOT passed through norm_encoder_hidden_states again.
            if no_encoder_hidden_states:
                encoder_hidden_states = hidden_states

        query = attn.to_q(hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # Split the projection width into heads: (batch, heads, sequence, head_dim).
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # Feature injection (query, key, and/or value)
        if do_structure_control:
            if "query" in attn.structure_target:
                query = feature_injection(query, batch_order=attn.batch_order)
            if "key" in attn.structure_target:
                key = feature_injection(key, batch_order=attn.batch_order)
            if "value" in attn.structure_target:
                value = feature_injection(value, batch_order=attn.batch_order)

        # Appearance transfer (value)
        if do_appearance_control and "value" in attn.appearance_target:
            value = appearance_transfer(value, query_normed, key_normed, batch_order=attn.batch_order)

        # The output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        # Appearance transfer (after)
        if do_appearance_control and "after" in attn.appearance_target:
            hidden_states = appearance_transfer(hidden_states, query_normed, key_normed, batch_order=attn.batch_order)

        # Merge the heads back into a single channel axis.
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # Linear projection
        # NOTE(review): extra positional arguments are forwarded to the output
        # projection as-is; confirm it accepts them for the pinned diffusers version.
        hidden_states = attn.to_out[0](hidden_states, *args)
        # Dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            # Restore the original (B, C, H, W) layout.
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
    

def register_control(
    model,
    timesteps,
    control_schedule,  # structure_conv, structure_attn, appearance_attn
    control_target=None,
):
    """Install the control hooks on the resnets and self-attentions of ``model.unet``.

    For each scheduled UNet block this
    * sets ``structure_target`` / ``structure_schedule`` on every resnet and
      rebinds its ``forward`` to ``convolution_forward``;
    * sets ``structure_target`` / ``structure_schedule`` and
      ``appearance_target`` / ``appearance_schedule`` on the ``attn1`` module of
      every transformer block, and replaces its processor with this module's
      ``AttnProcessor2_0``.

    Args:
        model: Object exposing ``unet.down_blocks``, ``unet.up_blocks`` and
            ``unet.mid_block``.
        timesteps: Passed to ``get_schedule`` (defined elsewhere in the package).
        control_schedule: Mapping with keys ``"encoder"``, ``"decoder"`` and
            ``"middle"``. The first two hold one entry per block; ``"middle"``
            holds a single entry. Each entry is indexed ``[0]`` (structure,
            convolution), ``[1]`` (structure, attention) and ``[2]``
            (appearance, attention); the per-sub-block value is picked with
            ``get_elem`` (defined elsewhere in the package).
        control_target: Three lists of target names, in the same order as the
            schedule entries. ``None`` (the default) means
            ``[["output_tensor"], ["query", "key"], ["before"]]``.
    """
    # Build the defaults per call: the inner lists are stored on the modules, so
    # a shared mutable default would leak any later mutation into future calls.
    if control_target is None:
        control_target = [["output_tensor"], ["query", "key"], ["before"]]

    blocks_by_type = {
        "encoder": model.unet.down_blocks,
        "decoder": model.unet.up_blocks,
        "middle": [model.unet.mid_block],
    }

    # Assume timesteps in reverse order (T -> 0)
    for block_type in ("encoder", "decoder", "middle"):
        blocks = blocks_by_type[block_type]

        control_schedule_block = control_schedule[block_type]
        if block_type == "middle":
            # The middle block has a single schedule entry; wrap it so it is
            # indexed like the per-block lists of the encoder/decoder.
            control_schedule_block = [control_schedule_block]

        for layer, layer_schedule in enumerate(control_schedule_block):
            unet_block = blocks[layer]

            # Convolution
            for index, convolution in enumerate(getattr(unet_block, "resnets", ())):
                convolution.structure_target = control_target[0]
                convolution.structure_schedule = get_schedule(
                    timesteps, get_elem(layer_schedule[0], index)
                )
                convolution.forward = MethodType(convolution_forward, convolution)

            # Self-attention
            for index, attention_block in enumerate(getattr(unet_block, "attentions", ())):
                for transformer_block in attention_block.transformer_blocks:
                    attention = transformer_block.attn1
                    attention.structure_target = control_target[1]
                    attention.structure_schedule = get_schedule(
                        timesteps, get_elem(layer_schedule[1], index)
                    )
                    attention.appearance_target = control_target[2]
                    attention.appearance_schedule = get_schedule(
                        timesteps, get_elem(layer_schedule[2], index)
                    )
                    attention.processor = AttnProcessor2_0()
                    
                    
def register_attr(model, t, do_control, batch_order):
    """Set the per-step control attributes on every resnet and self-attention module.

    Walks ``model.unet.down_blocks``, ``model.unet.up_blocks`` and
    ``model.unet.mid_block`` (in that order) and assigns ``t``, ``do_control``
    and ``batch_order`` to each module in a block's ``resnets`` and to the
    ``attn1`` module of each transformer block in its ``attentions``.

    Blocks lacking ``resnets`` or ``attentions`` (including a ``None`` mid
    block) are skipped, matching how ``register_control`` guards both.

    Args:
        model: Object exposing ``unet.down_blocks``, ``unet.up_blocks`` and
            ``unet.mid_block``.
        t: Current timestep, stored as ``module.t``.
        do_control: Flag stored as ``module.do_control``.
        batch_order: Value stored as ``module.batch_order``.
    """
    unet = model.unet
    for layer in (*unet.down_blocks, *unet.up_blocks, unet.mid_block):
        # Convolution
        for module in getattr(layer, "resnets", ()):
            module.t = t
            module.do_control = do_control
            module.batch_order = batch_order
        # Self-attention
        for block in getattr(layer, "attentions", ()):
            for transformer_block in block.transformer_blocks:
                attention = transformer_block.attn1
                attention.t = t
                attention.do_control = do_control
                attention.batch_order = batch_order