import torch
from typing import Optional, Dict, Any
from .block import block_forward, single_block_forward
from .lora_controller import enable_lora
from diffusers.models.transformers.transformer_flux import (
    FluxTransformer2DModel,
    Transformer2DModelOutput,
    USE_PEFT_BACKEND,
    is_torch_version,
    scale_lora_layers,
    unscale_lora_layers,
    logger,
)
import numpy as np


def prepare_params(
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor = None,
    pooled_projections: torch.Tensor = None,
    timestep: torch.LongTensor = None,
    img_ids: torch.Tensor = None,
    txt_ids: torch.Tensor = None,
    guidance: torch.Tensor = None,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    controlnet_block_samples=None,
    controlnet_single_block_samples=None,
    return_dict: bool = True,
    **kwargs: dict,
):
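    """Collect the standard `FluxTransformer2DModel.forward` arguments from
    keyword parameters into a fixed positional tuple, ignoring any extra keys
    captured by ``**kwargs``.
    """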
    return (
        hidden_states,
        encoder_hidden_states,
        pooled_projections,
        timestep,
        img_ids,
        txt_ids,
        guidance,
        joint_attention_kwargs,
        controlnet_block_samples,
        controlnet_single_block_samples,
        return_dict,
    )


def tranformer_forward(
    transformer: FluxTransformer2DModel,
    condition_latents: torch.Tensor,
    condition_ids: torch.Tensor,
    condition_type_ids: torch.Tensor,
    model_config: Optional[Dict[str, Any]] = None,
    return_conditional_latents: bool = False,
    c_t=0,
    **params: dict,
):
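    """Conditioned forward pass for `FluxTransformer2DModel`.

    This mirrors the stock Flux transformer forward, but additionally embeds
    `condition_latents` with their own rotary embeddings (from `condition_ids`),
    their own timestep embedding at `c_t`, and an optional condition-type
    embedding, then threads them through the dual-stream blocks (and, if
    enabled, the single-stream blocks) via `block_forward` / `single_block_forward`.

    Example (illustrative sketch; the variable names below are placeholders and
    are not defined in this module)::

        out = tranformer_forward(
            pipe.transformer,
            condition_latents=cond_latents,            # packed condition latents
            condition_ids=cond_ids,                    # (cond_seq, 3) position ids
            condition_type_ids=None,
            model_config={"use_condition_in_single_blocks": True},
            hidden_states=latents,                     # (B, seq, in_channels), packed
            encoder_hidden_states=prompt_embeds,       # (B, txt_seq, joint_attention_dim)
            pooled_projections=pooled_prompt_embeds,   # (B, pooled_projection_dim)
            timestep=t,                                # in [0, 1]; scaled by 1000 inside
            img_ids=latent_image_ids,                  # (seq, 3)
            txt_ids=text_ids,                          # (txt_seq, 3)
            guidance=guidance,                         # guidance-distilled checkpoints only
            return_dict=False,
        )
        noise_pred = out[0]
    """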
    self = transformer
    model_config = model_config or {}
    use_condition = condition_latents is not None
    use_condition_in_single_blocks = model_config.get(
        "use_condition_in_single_blocks", True
    )
    # `return_conditional_latents` requires both `use_condition` and `use_condition_in_single_blocks`
    assert not return_conditional_latents or (
        use_condition and use_condition_in_single_blocks
    ), (
        "When `return_conditional_latents` is True, both `use_condition` and "
        "`use_condition_in_single_blocks` must be True."
    )

    (
        hidden_states,
        encoder_hidden_states,
        pooled_projections,
        timestep,
        img_ids,
        txt_ids,
        guidance,
        joint_attention_kwargs,
        controlnet_block_samples,
        controlnet_single_block_samples,
        return_dict,
    ) = prepare_params(**params)

    if joint_attention_kwargs is not None:
        joint_attention_kwargs = joint_attention_kwargs.copy()
        lora_scale = joint_attention_kwargs.pop("scale", 1.0)
    else:
        lora_scale = 1.0

    if USE_PEFT_BACKEND:
        # weight the lora layers by setting `lora_scale` for each PEFT layer
        scale_lora_layers(self, lora_scale)
    else:
        if (
            joint_attention_kwargs is not None
            and joint_attention_kwargs.get("scale", None) is not None
        ):
            logger.warning(
                "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
            )
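    # Embed packed latents into the transformer width. LoRA on `x_embedder` is
    # toggled by `latent_lora`; condition latents are embedded outside that scope.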
    with enable_lora((self.x_embedder,), model_config.get("latent_lora", False)):
        hidden_states = self.x_embedder(hidden_states)
    condition_latents = self.x_embedder(condition_latents) if use_condition else None

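    # Flux conditions on timesteps (and guidance) scaled to the [0, 1000] range.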
    timestep = timestep.to(hidden_states.dtype) * 1000
    if guidance is not None:
        guidance = guidance.to(hidden_states.dtype) * 1000
    temb = (
        self.time_text_embed(timestep, pooled_projections)
        if guidance is None
        else self.time_text_embed(timestep, guidance, pooled_projections)
    )
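    # The condition branch gets its own timestep embedding, evaluated at t = c_t
    # (also scaled by 1000) instead of the denoising timestep.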
    cond_temb = (
        self.time_text_embed(torch.ones_like(timestep) * c_t * 1000, pooled_projections)
        if guidance is None
        else self.time_text_embed(
            torch.ones_like(timestep) * c_t * 1000, guidance, pooled_projections
        )
    )
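    # Optional learned offset that tags the condition type; only applied when the
    # model defines a `cond_type_embed` module and type ids are provided.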
    if hasattr(self, "cond_type_embed") and condition_type_ids is not None:
        cond_type_proj = self.time_text_embed.time_proj(condition_type_ids[0])
        cond_type_emb = self.cond_type_embed(cond_type_proj.to(dtype=cond_temb.dtype))
        cond_temb = cond_temb + cond_type_emb
    encoder_hidden_states = self.context_embedder(encoder_hidden_states)

    if txt_ids.ndim == 3:
        logger.warning(
            "Passing `txt_ids` as a 3d torch.Tensor is deprecated. "
            "Please remove the batch dimension and pass it as a 2d torch.Tensor."
        )
        txt_ids = txt_ids[0]
    if img_ids.ndim == 3:
        logger.warning(
            "Passing `img_ids` as a 3d torch.Tensor is deprecated. "
            "Please remove the batch dimension and pass it as a 2d torch.Tensor."
        )
        img_ids = img_ids[0]

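    # Rotary embeddings are computed over the concatenated text+image ids; the
    # condition tokens get their own rotary embedding from `condition_ids`.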
    ids = torch.cat((txt_ids, img_ids), dim=0)
    image_rotary_emb = self.pos_embed(ids)
    if use_condition:
        cond_rotary_emb = self.pos_embed(condition_ids)

    # hidden_states = torch.cat([hidden_states, condition_latents], dim=1)

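    # Dual-stream (joint) transformer blocks: text, image, and (optionally)
    # condition tokens are processed as separate streams inside `block_forward`.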
    for index_block, block in enumerate(self.transformer_blocks):
        if self.training and self.gradient_checkpointing:
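            # NOTE: the gradient-checkpointing path uses the stock block call and
            # does not route condition latents through the block.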

            def create_custom_forward(module, return_dict=None):
                def custom_forward(*inputs):
                    if return_dict is not None:
                        return module(*inputs, return_dict=return_dict)
                    else:
                        return module(*inputs)

                return custom_forward

            ckpt_kwargs: Dict[str, Any] = (
                {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
            )
            encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
                create_custom_forward(block),
                hidden_states,
                encoder_hidden_states,
                temb,
                image_rotary_emb,
                **ckpt_kwargs,
            )

        else:
            encoder_hidden_states, hidden_states, condition_latents = block_forward(
                block,
                model_config=model_config,
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                condition_latents=condition_latents if use_condition else None,
                temb=temb,
                cond_temb=cond_temb if use_condition else None,
                cond_rotary_emb=cond_rotary_emb if use_condition else None,
                image_rotary_emb=image_rotary_emb,
            )

        # controlnet residual
        if controlnet_block_samples is not None:
            interval_control = len(self.transformer_blocks) / len(
                controlnet_block_samples
            )
            interval_control = int(np.ceil(interval_control))
            hidden_states = (
                hidden_states
                + controlnet_block_samples[index_block // interval_control]
            )
    hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

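    # Single-stream transformer blocks operate on the concatenated
    # [text tokens, image tokens] sequence built above.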
    for index_block, block in enumerate(self.single_transformer_blocks):
        if self.training and self.gradient_checkpointing:
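            # NOTE: as above, the checkpointed path skips the condition stream.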

            def create_custom_forward(module, return_dict=None):
                def custom_forward(*inputs):
                    if return_dict is not None:
                        return module(*inputs, return_dict=return_dict)
                    else:
                        return module(*inputs)

                return custom_forward

            ckpt_kwargs: Dict[str, Any] = (
                {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
            )
            hidden_states = torch.utils.checkpoint.checkpoint(
                create_custom_forward(block),
                hidden_states,
                temb,
                image_rotary_emb,
                **ckpt_kwargs,
            )

        else:
            result = single_block_forward(
                block,
                model_config=model_config,
                hidden_states=hidden_states,
                temb=temb,
                image_rotary_emb=image_rotary_emb,
                **(
                    {
                        "condition_latents": condition_latents,
                        "cond_temb": cond_temb,
                        "cond_rotary_emb": cond_rotary_emb,
                    }
                    if use_condition_in_single_blocks and use_condition
                    else {}
                ),
            )
            if use_condition_in_single_blocks and use_condition:
                hidden_states, condition_latents = result
            else:
                hidden_states = result

        # controlnet residual
        if controlnet_single_block_samples is not None:
            interval_control = len(self.single_transformer_blocks) / len(
                controlnet_single_block_samples
            )
            interval_control = int(np.ceil(interval_control))
            hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
                hidden_states[:, encoder_hidden_states.shape[1] :, ...]
                + controlnet_single_block_samples[index_block // interval_control]
            )

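    # Drop the text tokens; only image tokens go through the output projection.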
    hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]

    hidden_states = self.norm_out(hidden_states, temb)
    output = self.proj_out(hidden_states)
    if return_conditional_latents:
        condition_latents = (
            self.norm_out(condition_latents, cond_temb) if use_condition else None
        )
        condition_output = self.proj_out(condition_latents) if use_condition else None

    if USE_PEFT_BACKEND:
        # remove `lora_scale` from each PEFT layer
        unscale_lora_layers(self, lora_scale)

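    # Note that the `return_dict=True` path only carries the denoised sample;
    # conditional latents are returned only via the tuple path below.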
    if not return_dict:
        return (
            (output,) if not return_conditional_latents else (output, condition_output)
        )

    return Transformer2DModelOutput(sample=output)