from typing import Optional
from einops import rearrange

import torch
import torch.nn as nn
import torch.nn.functional as F

from diffusers.models.attention import Attention


class InflatedConv3d(nn.Conv2d):
	"""2D convolution applied independently to every frame of a video tensor.

	Input and output use the ``(batch, channels, frames, height, width)``
	layout. Frames are folded into the batch axis, so the inherited
	``nn.Conv2d`` weights are shared across time.
	"""

	def forward(self, x):
		num_frames = x.shape[2]
		# Fold frames into batch, convolve each frame, then unfold again.
		frames_as_batch = rearrange(x, "b c f h w -> (b f) c h w")
		convolved = super().forward(frames_as_batch)
		return rearrange(convolved, "(b f) c h w -> b c f h w", f=num_frames)


class FFInflatedConv3d(nn.Conv2d):
	"""Per-frame 2D convolution followed by a learned temporal residual.

	The inherited ``nn.Conv2d`` is applied to each frame of a
	``(batch, channels, frames, height, width)`` tensor. Then, at every
	spatial location, three feature vectors are concatenated:
	- the first frame's features;
	- the previous frame's features (frame 0 uses itself);
	- the current frame's features.
	``conv_temp`` maps this concatenation back to ``out_channels``, and the
	result is added to the current frame's features. ``conv_temp`` is
	zero-initialised, so at construction time the output equals the plain
	per-frame convolution.
	"""

	def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
		super().__init__(
			in_channels=in_channels,
			out_channels=out_channels,
			kernel_size=kernel_size,
			**kwargs,
		)
		# Temporal mixer over [first, previous, current] frame features.
		self.conv_temp = nn.Linear(3 * out_channels, out_channels)
		# Zero init (weights and bias) makes the temporal residual vanish at
		# start, so the block initially reduces to the per-frame convolution.
		nn.init.zeros_(self.conv_temp.weight.data)
		nn.init.zeros_(self.conv_temp.bias.data)

	def forward(self, x):
		video_length = x.shape[2]

		x = rearrange(x, "b c f h w -> (b f) c h w")
		x = super().forward(x)

		*_, h, w = x.shape
		# One temporal sequence per spatial location: (b*h*w, frames, channels).
		x = rearrange(x, "(b f) c h w -> (b h w) f c", f=video_length)

		# Build the index on x's device (no per-call host->device copy) and
		# clamp in integer arithmetic; frame 0's "previous" frame is itself.
		prev_frame_index = (torch.arange(video_length, device=x.device) - 1).clamp(min=0)
		first_frames = x[:, :1].expand(-1, video_length, -1)
		prev_frames = x[:, prev_frame_index]
		temporal_input = torch.cat([first_frames, prev_frames, x], dim=2)
		x = x + self.conv_temp(temporal_input)

		x = rearrange(x, "(b h w) f c -> b c f h w", h=h, w=w)

		return x


class FFAttention(Attention):
	r"""
	Attention layer for video tokens whose frames are folded into the batch axis.

	A thin subclass of diffusers' ``Attention``. Its ``forward`` takes an extra
	``video_length`` argument and hands it to the attention processor. If no
	processor is supplied, an ``FFAttnProcessor`` is installed after the base
	class has been constructed.

	Parameters:
		query_dim (`int`): The number of channels in the query.
		cross_attention_dim (`int`, *optional*):
			The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
		heads (`int`,  *optional*, defaults to 8): The number of heads to use for multi-head attention.
		dim_head (`int`,  *optional*, defaults to 64): The number of channels in each head.
		dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
		bias (`bool`, *optional*, defaults to False):
			Set to `True` for the query, key, and value linear layers to contain a bias parameter.
		scale_qk (`bool`, *optional*, defaults to True): Forwarded to ``Attention``.
		processor (`FFAttnProcessor`, *optional*):
			Processor to install. Defaults to a freshly built ``FFAttnProcessor``.

	All parameters except ``scale_qk`` and ``processor`` reach ``Attention``
	through ``*args`` / ``**kwargs``.
	"""

	def __init__(
			self,
			*args,
			scale_qk: bool = True,
			processor: Optional["FFAttnProcessor"] = None,
			**kwargs
	):
		super().__init__(*args, scale_qk=scale_qk, processor=processor, **kwargs)
		# Unlike the base class, the frame-aware processor is installed
		# unconditionally, regardless of `scale_qk`.
		# NOTE(review): confirm the processor honours a non-default scale.
		self.set_processor(FFAttnProcessor() if processor is None else processor)

	def forward(self, hidden_states, video_length, encoder_hidden_states=None, attention_mask=None,
	            **cross_attention_kwargs):
		"""Delegate to ``self.processor``, forwarding ``video_length`` and any extra kwargs."""
		processor_kwargs = dict(
			encoder_hidden_states=encoder_hidden_states,
			attention_mask=attention_mask,
			video_length=video_length,
		)
		# The named parameters above can never collide with extra kwargs.
		processor_kwargs.update(cross_attention_kwargs)
		return self.processor(self, hidden_states, **processor_kwargs)


class FFAttnProcessor:
	r"""
	Attention processor in which every frame attends to the first frame of its clip.

	Queries come from each frame's own tokens. The keys and values of every
	frame are replaced by those of frame 0 of the same clip. Inputs carry the
	frames folded into the batch axis, i.e. ``(batch * video_length, tokens, channels)``,
	with the frames of one clip adjacent along dim 0.
	Requires PyTorch 2.0 for ``F.scaled_dot_product_attention``.
	"""

	def __init__(self):
		if not hasattr(F, "scaled_dot_product_attention"):
			raise ImportError(
				"FFAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

	@staticmethod
	def _broadcast_first_frame(tensor, video_length):
		"""Return ``tensor`` with each clip's frame-0 entry copied to all of its frames.

		``tensor`` has shape ``(batch * video_length, ...)``. The result has the same shape.
		"""
		per_clip = tensor.reshape(-1, video_length, *tensor.shape[1:])
		first = per_clip[:, :1].expand(-1, video_length, *tensor.shape[1:])
		return first.reshape(tensor.shape)

	def __call__(self, attn: "Attention", hidden_states, video_length, encoder_hidden_states=None, attention_mask=None):
		"""Compute first-frame attention for ``hidden_states``.

		Parameters:
			attn: The owning attention module. Uses its ``to_q`` / ``to_k`` / ``to_v`` / ``to_out``
				projections, ``heads`` and ``norm_cross``. When a mask is given, it also uses
				``prepare_attention_mask``.
			hidden_states: Query tokens, ``(batch * video_length, tokens, query_dim)``.
			video_length: Number of frames per clip folded into dim 0.
			encoder_hidden_states: Optional key/value source. Defaults to ``hidden_states``.
			attention_mask: Optional mask, reshaped to ``(batch, heads, -1, source_len)``.

		Returns:
			The output of ``attn.to_out`` with shape ``(batch * video_length, tokens, out_dim)``.
		"""
		batch_size, sequence_length, _ = (
			hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
		)

		if attention_mask is not None:
			attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
			# scaled_dot_product_attention expects attention_mask shape to be
			# (batch, heads, source_length, target_length)
			attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

		query = attn.to_q(hidden_states)

		if encoder_hidden_states is None:
			encoder_hidden_states = hidden_states
		elif attn.norm_cross:
			encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

		key = attn.to_k(encoder_hidden_states)
		value = attn.to_v(encoder_hidden_states)

		# First-frame attention: every frame reads keys/values from frame 0 of its clip.
		key = self._broadcast_first_frame(key, video_length)
		value = self._broadcast_first_frame(value, video_length)

		# The head split must use the projected width (heads * dim_head), not
		# query_dim. The two differ whenever query_dim != heads * dim_head.
		inner_dim = key.shape[-1]
		head_dim = inner_dim // attn.heads
		query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
		key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
		value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

		# the output of sdp = (batch, num_heads, seq_len, head_dim)
		hidden_states = F.scaled_dot_product_attention(
			query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
		)

		hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
		hidden_states = hidden_states.to(query.dtype)

		# linear proj
		hidden_states = attn.to_out[0](hidden_states)
		# dropout
		hidden_states = attn.to_out[1](hidden_states)
		return hidden_states