import os

# external libraries
import torch
import torch.utils.checkpoint
import torch.utils.checkpoint
from accelerate import Accelerator
from accelerate.logging import get_logger
from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available
from transformers import CLIPTextModel, CLIPTokenizer

# custom imports
from src.datasets.dresscode import DressCodeDataset
from src.datasets.vitonhd import VitonHDDataset
from src.mgd_pipelines.mgd_pipe import MGDPipe
from src.mgd_pipelines.mgd_pipe_disentangled import MGDPipeDisentangled
from src.utils.arg_parser import eval_parse_args
from src.utils.image_from_pipe import generate_images_from_mgd_pipe
from src.utils.set_seeds import set_seed

# Abort early if the installed diffusers is older than the minimum required version.
# Remove this check at your own risk.
check_min_version("0.10.0.dev0")

# Module-level logger obtained through accelerate's logging helper.
logger = get_logger(__name__, log_level="INFO")

# Process-wide environment settings read by the tokenizers and wandb libraries.
os.environ.update(
    {
        "TOKENIZERS_PARALLELISM": "true",
        "WANDB_START_METHOD": "thread",
    }
)


def main() -> None:
    """Generate images for the test split with an MGD pipeline.

    Parses the evaluation arguments, loads the scheduler, tokenizer, text
    encoder and VAE from ``args.pretrained_model_name_or_path``, loads the
    ``"mgd"`` unet through ``torch.hub``, builds the test dataloader for the
    selected dataset, assembles the (optionally disentangled) pipeline and
    hands everything to ``generate_images_from_mgd_pipe``.

    Raises:
        ValueError: If xformers memory-efficient attention is requested but
            xformers is not available.
        NotImplementedError: If ``args.dataset`` is neither ``"dresscode"``
            nor ``"vitonhd"``.
        TypeError: If the constructed pipeline object is not callable.
    """
    args = eval_parse_args()
    accelerator = Accelerator(
        mixed_precision=args.mixed_precision,
    )
    device = accelerator.device

    # Set the evaluation seed (only when one was provided)
    if args.seed is not None:
        set_seed(args.seed)

    # Load scheduler, tokenizer, and models
    val_scheduler = DDIMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
    # The sampler is configured for 50 inference steps.
    val_scheduler.set_timesteps(50, device=device)

    tokenizer = CLIPTokenizer.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
    )
    text_encoder = CLIPTextModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
    )
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)

    # Load unet from the project's torch.hub entry point; ``dataset`` is
    # forwarded to that entry point as a keyword argument.
    unet = torch.hub.load(
        dataset=args.dataset,
        repo_or_dir="aimagelab/multimodal-garment-designer",
        source="github",
        model="mgd",
        pretrained=True,
    )

    # Freeze vae and text_encoder
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)

    # Enable memory efficient attention if requested
    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    # Set the dataset category (only passed to the DressCode dataset below);
    # without an explicit category all three are used.
    category = [args.category] if args.category else ["dresses", "upper_body", "lower_body"]

    # Load the appropriate dataset
    if args.dataset == "dresscode":
        test_dataset = DressCodeDataset(
            dataroot_path=args.dataset_path,
            phase="test",
            order=args.test_order,
            radius=5,
            sketch_threshold_range=(20, 20),
            tokenizer=tokenizer,
            category=category,
            size=(512, 384),
        )
    elif args.dataset == "vitonhd":
        test_dataset = VitonHDDataset(
            dataroot_path=args.dataset_path,
            phase="test",
            order=args.test_order,
            sketch_threshold_range=(20, 20),
            radius=5,
            tokenizer=tokenizer,
            size=(512, 384),
        )
    else:
        raise NotImplementedError(f"Dataset {args.dataset} is not supported.")

    # Prepare the dataloader (fixed order, no shuffling)
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        shuffle=False,
        batch_size=args.batch_size,
        num_workers=args.num_workers_test,
    )

    # Cast text_encoder and vae to half precision for fp16 inference.
    # NOTE(review): any other mixed_precision value (e.g. "bf16") keeps float32
    # weights here -- confirm whether that is intended.
    weight_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.float32
    text_encoder.to(device, dtype=weight_dtype)
    vae.to(device, dtype=weight_dtype)

    # Ensure unet is in eval mode
    unet.eval()

    with torch.inference_mode():
        # Select the appropriate pipeline class; both take the same components.
        pipe_cls = MGDPipeDisentangled if args.disentagle else MGDPipe
        val_pipe = pipe_cls(
            text_encoder=text_encoder,
            vae=vae,
            unet=unet.to(vae.dtype),  # keep the unet in the same dtype as the vae
            tokenizer=tokenizer,
            scheduler=val_scheduler,
        ).to(device)

        # Explicit check instead of ``assert`` so it is not stripped under ``python -O``.
        if not callable(val_pipe):
            raise TypeError("The pipeline object (val_pipe) is not callable. Check MGDPipe implementation.")

        # Enable attention slicing for memory efficiency
        val_pipe.enable_attention_slicing()

        # Prepare dataloader with accelerator
        test_dataloader = accelerator.prepare(test_dataloader)

        # Call the image generation function
        generate_images_from_mgd_pipe(
            test_order=args.test_order,
            pipe=val_pipe,
            test_dataloader=test_dataloader,
            save_name=args.save_name,
            dataset=args.dataset,
            output_dir=args.output_dir,
            guidance_scale=args.guidance_scale,
            guidance_scale_pose=args.guidance_scale_pose,
            guidance_scale_sketch=args.guidance_scale_sketch,
            sketch_cond_rate=args.sketch_cond_rate,
            start_cond_rate=args.start_cond_rate,
            no_pose=False,
            disentagle=args.disentagle,
            seed=args.seed,
        )


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()