import os

import torch
import torch.nn as nn

from tangoflux import TangoFluxInference

def export_vae_encoder(vae, save_path, batch_size=1, audio_length=441000):
    """导出VAE编码器到ONNX格式
    
    Args:
        vae: AutoencoderOobleck实例
        save_path: 保存路径
        batch_size: batch大小
        audio_length: 音频长度(默认10秒,44100Hz采样率)
    """
    vae.eval()
    
    # Create a dummy input; note the audio is stereo (2 channels)
    dummy_input = torch.randn(batch_size, 2, audio_length)
    
    # Wrapper module so that export traces encode() rather than the default forward()
    class VAEEncoderWrapper(nn.Module):
        def __init__(self, vae):
            super().__init__()
            self.vae = vae
            
        def forward(self, audio):
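            # Note: .sample() draws from the latent posterior, which bakes a
            # random-normal op into the exported graph; use the distribution
            # mean instead if a deterministic encoder is needed.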
            return self.vae.encode(audio).latent_dist.sample()
    
    wrapper = VAEEncoderWrapper(vae)
    
    # Export the encoder
    torch.onnx.export(
        wrapper,
        dummy_input,
        save_path,
        input_names=['audio'],
        output_names=['latent'],
        dynamic_axes={
            'audio': {0: 'batch_size', 2: 'audio_length'},
            'latent': {0: 'batch_size', 2: 'latent_length'}
        },
        opset_version=17
    )
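
# Optional sanity check for the exported graphs: a minimal sketch, assuming
# onnxruntime is installed. It runs an exported model once on the given feed
# and returns the output shapes; a numerical comparison against PyTorch is
# skipped for the encoder because its sample() op is stochastic.
def verify_onnx(onnx_path, feed):
    """Run an exported model once with onnxruntime and return output shapes.

    Args:
        onnx_path: path to the .onnx file
        feed: dict mapping input names to numpy arrays
    """
    import onnxruntime as ort

    session = ort.InferenceSession(onnx_path)
    return [out.shape for out in session.run(None, feed)]

# Example:
#   verify_onnx("out/vae_encoder.onnx",
#               {"audio": np.random.randn(1, 2, 441000).astype(np.float32)})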

def export_vae_decoder(vae, save_path, batch_size=1, latent_length=645):
    """导出VAE解码器到ONNX格式
    
    Args:
        vae: AutoencoderOobleck实例
        save_path: 保存路径
        batch_size: batch大小
        latent_length: 潜在向量长度
    """
    vae.eval()
    
    # Create a dummy input; the Oobleck latent has 64 channels
    dummy_input = torch.randn(batch_size, 64, latent_length)
    
    # Wrapper module so that export traces decode() rather than the default forward()
    class VAEDecoderWrapper(nn.Module):
        def __init__(self, vae):
            super().__init__()
            self.vae = vae
            
        def forward(self, latent):
            return self.vae.decode(latent).sample
    
    wrapper = VAEDecoderWrapper(vae)
    
    # Export the decoder
    torch.onnx.export(
        wrapper,
        dummy_input,
        save_path,
        input_names=['latent'],
        output_names=['audio'],
        dynamic_axes={
            'latent': {0: 'batch_size', 2: 'latent_length'},
            'audio': {0: 'batch_size', 2: 'audio_length'}
        },
        opset_version=17
    )
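
# Encoder/decoder round trip as a smoke test: a sketch that chains the two
# exports with onnxruntime. The latent length follows from TangoFlux's
# Oobleck configuration, whose total downsampling factor is 2048
# (441000 samples -> ~215 latent frames).
def example_vae_roundtrip(encoder_path, decoder_path):
    import numpy as np
    import onnxruntime as ort

    audio = np.random.randn(1, 2, 441000).astype(np.float32)
    enc = ort.InferenceSession(encoder_path)
    dec = ort.InferenceSession(decoder_path)
    latent = enc.run(None, {"audio": audio})[0]   # [1, 64, ~215]
    recon = dec.run(None, {"latent": latent})[0]  # [1, 2, ~441000]
    return latent.shape, recon.shape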

def export_duration_embedder(duration_embedder, save_path, batch_size=1):
    """导出Duration Embedder到ONNX格式
    
    Args:
        duration_embedder: DurationEmbedder实例
        save_path: 保存路径
        batch_size: batch大小
    """
    duration_embedder.eval()
    
    # Create a dummy input: the duration in seconds, shape [batch, 1]
    dummy_input = torch.full((batch_size, 1), 10.0, dtype=torch.float32)  # 10 seconds
    
    # Export
    torch.onnx.export(
        duration_embedder,
        dummy_input,
        save_path,
        input_names=['duration'],
        output_names=['embedding'],
        dynamic_axes={
            'duration': {0: 'batch_size'},
            'embedding': {0: 'batch_size'}
        },
        opset_version=17
    )
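
# Quick check of the exported embedder: a sketch assuming onnxruntime. In
# TangoFlux the resulting embedding is concatenated with the projected text
# hidden states before conditioning the transformer.
def example_duration_embedding(onnx_path, seconds=10.0):
    import numpy as np
    import onnxruntime as ort

    session = ort.InferenceSession(onnx_path)
    duration = np.array([[seconds]], dtype=np.float32)
    return session.run(None, {"duration": duration})[0]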

def export_flux_transformer(transformer, save_path, batch_size=1, seq_length=645):
    """导出FluxTransformer2D到ONNX格式
    
    Args:
        transformer: FluxTransformer2DModel实例
        save_path: 保存路径
        batch_size: batch大小
        seq_length: 序列长度
    """
    transformer.eval()
    
    # Create dummy inputs; note the shape of each input
    hidden_states = torch.randn(batch_size, seq_length, 64)  # [B, S, C]
    timestep = torch.tensor([0.5] * batch_size)  # [B]
    pooled_text = torch.randn(batch_size, 1024)  # [B, D]
    encoder_hidden_states = torch.randn(batch_size, 64, 1024)  # [B, L, D]
    txt_ids = torch.zeros(batch_size, 64, 3).to(torch.int64)  # [B, L, 3]
    img_ids = torch.arange(seq_length).unsqueeze(0).unsqueeze(-1).repeat(batch_size, 1, 3).to(torch.int64)  # [B, S, 3]
    
    # Wrapper module that fixes the keyword arguments and unpacks the tuple output
    class TransformerWrapper(nn.Module):
        def __init__(self, transformer):
            super().__init__()
            self.transformer = transformer
            
        def forward(self, hidden_states, timestep, pooled_text, encoder_hidden_states, txt_ids, img_ids):
            return self.transformer(
                hidden_states=hidden_states,
                timestep=timestep,
                guidance=None,
                pooled_projections=pooled_text,
                encoder_hidden_states=encoder_hidden_states,
                txt_ids=txt_ids,
                img_ids=img_ids,
                return_dict=False
            )[0]
    
    wrapper = TransformerWrapper(transformer)
    
    # Export
    torch.onnx.export(
        wrapper,
        (hidden_states, timestep, pooled_text, encoder_hidden_states, txt_ids, img_ids),
        save_path,
        input_names=['hidden_states', 'timestep', 'pooled_text', 'encoder_hidden_states', 'txt_ids', 'img_ids'],
        output_names=['output'],
        dynamic_axes={
            'hidden_states': {0: 'batch_size', 1: 'sequence_length'},
            'timestep': {0: 'batch_size'},
            'pooled_text': {0: 'batch_size'},
            'encoder_hidden_states': {0: 'batch_size', 1: 'text_length'},
            'txt_ids': {0: 'batch_size', 1: 'text_length'},
            'img_ids': {0: 'batch_size', 1: 'sequence_length'},
            'output': {0: 'batch_size', 1: 'sequence_length'}
        },
        opset_version=17
    )
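
# One denoising step with the exported transformer: a minimal sketch,
# assuming onnxruntime and numpy float32/int64 feeds matching the export
# above. The transformer predicts the flow-matching velocity field; a plain
# Euler update is shown here for illustration, whereas TangoFlux itself
# drives sampling with diffusers' FlowMatchEulerDiscreteScheduler.
def example_denoise_step(session, latents, sigma, sigma_next, pooled_text,
                         encoder_hidden_states, txt_ids, img_ids):
    import numpy as np

    velocity = session.run(None, {
        "hidden_states": latents,
        "timestep": np.array([sigma], dtype=np.float32),
        "pooled_text": pooled_text,
        "encoder_hidden_states": encoder_hidden_states,
        "txt_ids": txt_ids,
        "img_ids": img_ids,
    })[0]
    # Euler step of the flow-matching ODE: x_{t+dt} = x_t + dt * v
    return latents + (sigma_next - sigma) * velocity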

def export_proj_layer(proj_layer, save_path, batch_size=1):
    """导出projection层到ONNX格式
    
    Args:
        proj_layer: 投影层(fc层)实例
        save_path: 保存路径
        batch_size: batch大小
    """
    proj_layer.eval()
    
    # Create a dummy input matching per-token T5 hidden states
    # (FLAN-T5-large hidden size is 1024). A 3D input keeps the traced
    # Linear usable on [batch, tokens, 1024] activations, which a 2D
    # export (traced as Gemm) would reject.
    dummy_input = torch.randn(batch_size, 64, 1024)
    
    # 导出
    torch.onnx.export(
        proj_layer,
        dummy_input,
        save_path,
        input_names=['text_embedding'],
        output_names=['projected'],
        dynamic_axes={
            'text_embedding': {0: 'batch_size', 1: 'text_length'},
            'projected': {0: 'batch_size', 1: 'text_length'}
        },
        opset_version=17
    )
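
# The T5 text encoder itself is not exported by this script. A minimal
# sketch of producing inputs for the projection layer with the PyTorch
# encoder, assuming transformers is installed and the checkpoint matches
# the "google/flan-t5-large" text encoder used by TangoFlux:
def example_text_hidden_states(prompt):
    from transformers import AutoTokenizer, T5EncoderModel

    tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")
    encoder = T5EncoderModel.from_pretrained("google/flan-t5-large").eval()
    tokens = tokenizer(prompt, return_tensors="pt", padding=True)
    with torch.no_grad():
        hidden = encoder(**tokens).last_hidden_state  # [1, num_tokens, 1024]
    return hidden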

def export_all(model_path, output_dir):
    """导出所有组件到ONNX格式
    
    Args:
        model_path: TangoFlux模型路径
        output_dir: 输出目录
    """
    # Load the model
    model = TangoFluxInference(name=model_path, device="cpu")
    
    # Create the output directory
    os.makedirs(output_dir, exist_ok=True)
    
    # Export the VAE encoder and decoder
    export_vae_encoder(model.vae, f"{output_dir}/vae_encoder.onnx")
    export_vae_decoder(model.vae, f"{output_dir}/vae_decoder.onnx")
    
    # Export the duration embedder ("duration_emebdder" matches the attribute
    # spelling in tangoflux.model)
    export_duration_embedder(model.model.duration_emebdder, f"{output_dir}/duration_embedder.onnx")
    
    # Export the transformer
    export_flux_transformer(model.model.transformer, f"{output_dir}/transformer.onnx")
    
    # Export the text projection layer
    export_proj_layer(model.model.fc, f"{output_dir}/proj.onnx")
    
    print(f"所有模型已导出到: {output_dir}")

if __name__ == "__main__":
    import argparse
    
    parser = argparse.ArgumentParser(description="Export TangoFlux models to ONNX format")
    parser.add_argument("--model_path", type=str, required=True, help="TangoFlux model path or HF repo id")
    parser.add_argument("--output_dir", type=str, required=True, help="output directory")
    
    args = parser.parse_args()
    export_all(args.model_path, args.output_dir)
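
    # Example invocation (assuming this script is saved as export_onnx.py):
    #   python export_onnx.py --model_path declare-lab/TangoFlux --output_dir ./onnx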