import random

import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from peft import LoraConfig, get_peft_model
from torchvision.transforms import ToPILImage
from transformers import AutoTokenizer, SiglipTextModel

from models import build_vae_var

class SimpleAdapter(nn.Module):
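    """Two-layer MLP that projects a SigLIP text embedding into the VAR condition space."""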
    def __init__(self, input_dim=512, hidden_dim=1024, out_dim=1024):
        super().__init__()
        # Definitions follow forward order; attribute names are kept so existing checkpoints still load.
        self.norm0 = nn.LayerNorm(input_dim)
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.activation1 = nn.GELU()
        self.layer2 = nn.Linear(hidden_dim, out_dim)
        self.norm2 = nn.LayerNorm(out_dim)
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
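                # Tiny gain gives a near-zero initial output, so the adapter starts close to a no-op.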
                nn.init.xavier_uniform_(m.weight, gain=0.001)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        x = self.norm0(x)
        x = self.layer1(x)
        x = self.activation1(x)
        x = self.layer2(x)
        x = self.norm2(x)
        return x

class InferenceTextVAR(nn.Module):
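    """Inference wrapper: SigLIP text encoder -> adapter -> LoRA-tuned VAR -> VQ-VAE decoder."""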
    def __init__(self, pl_checkpoint=None, start_class_id=578, hugging_face_token=None, siglip_model='google/siglip-base-patch16-224', device="cpu", MODEL_DEPTH=16):
        super().__init__()
        self.device = device
        self.class_id = start_class_id
        # Multi-scale token-map sizes used by VAR (1x1 up to 16x16).
        patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)
        self.vae, self.var = build_vae_var(
            V=4096, Cvae=32, ch=160, share_quant_resi=4,
            device=device, patch_nums=patch_nums,
            num_classes=1000, depth=MODEL_DEPTH, shared_aln=False,
        )
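        # SigLIP tokenizer and text encoder provide the text embedding used as the condition.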
        self.text_processor = AutoTokenizer.from_pretrained(siglip_model, token=hugging_face_token)
        self.siglip_text_encoder = SiglipTextModel.from_pretrained(siglip_model, token=hugging_face_token).to(device)
        self.adapter = SimpleAdapter(
            input_dim=self.siglip_text_encoder.config.hidden_size,
            out_dim=self.var.C  # Ensure dimensional consistency
        ).to(device)
        self.apply_lora_to_var()
        if pl_checkpoint is not None:
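            # PyTorch Lightning checkpoint: split the flat state dict by submodule prefix.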
            state_dict = torch.load(pl_checkpoint, map_location="cpu")['state_dict']
            var_state_dict = {k[len('var.'):]: v for k, v in state_dict.items() if k.startswith('var.')}
            vae_state_dict = {k[len('vae.'):]: v for k, v in state_dict.items() if k.startswith('vae.')}
            adapter_state_dict = {k[len('adapter.'):]: v for k, v in state_dict.items() if k.startswith('adapter.')}
            self.var.load_state_dict(var_state_dict)
            self.vae.load_state_dict(vae_state_dict)
            self.adapter.load_state_dict(adapter_state_dict)
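        # The VAE encoder is only needed for training; inference only decodes tokens.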
        del self.vae.encoder

    def apply_lora_to_var(self):
        """
        Applies LoRA (Low-Rank Adaptation) to the VAR model.
        """
        def find_linear_module_names(model):
            linear_module_names = []
            for name, module in model.named_modules():
                if isinstance(module, nn.Linear):
                    linear_module_names.append(name)
            return linear_module_names

        linear_module_names = find_linear_module_names(self.var)

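        # Attach low-rank adapters to every linear layer of VAR; r=8 keeps the trainable parameter count small.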
        lora_config = LoraConfig(
            r=8,
            lora_alpha=32,
            target_modules=linear_module_names,
            lora_dropout=0.05,
            bias="none",
        )

        self.var = get_peft_model(self.var, lora_config)
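        # Note: get_peft_model wraps VAR in a PeftModel, which forwards attribute lookups
        # to the wrapped model, so calls like autoregressive_infer_cfg below still resolve.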

    @torch.no_grad()
    def generate_image(self, text, beta=1, seed=None, more_smooth=False, top_k=0, top_p=0.9):
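        """Generate a single image from `text`.

        The adapted SigLIP embedding is passed to the VAR sampler as a delta
        condition alongside the fixed class id; `top_k`/`top_p` control token
        sampling and `seed` fixes the RNG (a random seed is drawn when None).
        """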
        if seed is None:
            seed = random.randint(0, 2**32 - 1)
        # SigLIP expects "max_length" padding, matching how it was trained.
        inputs = self.text_processor([text], padding="max_length", return_tensors="pt").to(self.device)
        outputs = self.siglip_text_encoder(**inputs)
        pooled_output = outputs.pooler_output  # pooled text representation
        cond_delta = F.normalize(pooled_output, p=2, dim=-1)  # unit-normalize before the adapter
        cond_delta = self.adapter(cond_delta)
        cond_delta = F.normalize(cond_delta, p=2, dim=-1)  # re-normalize the adapted delta condition
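        # Autoregressive multi-scale sampling conditioned on the class id plus the text delta.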
        generated_images = self.var.autoregressive_infer_cfg(
            B=1,
            label_B=self.class_id,
            delta_condition=cond_delta[:1],
            beta=beta,
            alpha=1,
            top_k=top_k,
            top_p=top_p,
            more_smooth=more_smooth,
            g_seed=seed
        )
        image = ToPILImage()(generated_images[0].cpu())
        return image


if __name__ == '__main__':

    # Initialize the model
    checkpoint = 'VARtext_v1.pth'  # Replace with your actual checkpoint path
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
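    # The checkpoint holds a state dict for the whole wrapper module (loaded via load_state_dict below).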
    state_dict = torch.load(checkpoint, map_location="cpu")
    model = InferenceTextVAR(device=device)
    model.load_state_dict(state_dict)
    model.to(device)

    def generate_image_gradio(text, beta=1.0, seed=None, more_smooth=False, top_k=0, top_p=0.9):
        print(f"Generating image for text: {text}\n"
              f"beta: {beta}\n"
              f"seed: {seed}\n"
              f"more_smooth: {more_smooth}\n"
              f"top_k: {top_k}\n"
              f"top_p: {top_p}\n")
        image = model.generate_image(
            text,
            beta=beta,
            seed=int(seed) if seed is not None else None,  # gr.Number may return a float
            more_smooth=more_smooth,
            top_k=int(top_k),
            top_p=top_p,
        )
        return image

    with gr.Blocks() as demo:
        gr.Markdown("# Text-to-Image Generator")
        with gr.Tab("Generate Image"):
            text_input = gr.Textbox(label="Input Text")
            beta_input = gr.Slider(label="Beta", minimum=0.0, maximum=2.5, step=0.05, value=1.0)
            seed_input = gr.Number(label="Seed", value=None, precision=0)
            more_smooth_input = gr.Checkbox(label="More Smooth", value=False)
            top_k_input = gr.Number(label="Top K", value=0)
            top_p_input = gr.Slider(label="Top P", minimum=0.0, maximum=1.0, step=0.01, value=0.9)
            generate_button = gr.Button("Generate Image")
            image_output = gr.Image(label="Generated Image")
            generate_button.click(
                generate_image_gradio,
                inputs=[text_input, beta_input, seed_input, more_smooth_input, top_k_input, top_p_input],
                outputs=image_output
            )

    demo.launch()