import os, json
import math, random
from multiprocessing import Pool
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import transforms
from transformers import CLIPTextModel
from transformers import PretrainedConfig


def pad_spec(spec, spec_length, pad_value=0, random_crop=True): # spec: [3, mel_dim, spec_len]
    """Force the last (time) axis of ``spec`` to exactly ``spec_length`` frames.

    Shorter inputs are right-padded with ``pad_value``. Longer (or equal) inputs
    are cropped: at a uniformly random offset when ``random_crop`` is true,
    otherwise from the start.

    Args:
        spec: 3-D tensor indexed as ``[:, :, time]``.
        spec_length: target number of frames; must be divisible by 8 (asserted).
        pad_value: fill value used when padding.
        random_crop: choose the crop offset randomly (``random.randint``) vs. use 0.

    Returns:
        The padded or cropped tensor (crops are views of ``spec``).
    """
    assert spec_length % 8 == 0, "spec_length must be divisible by 8"
    current = spec.shape[-1]
    if current < spec_length:
        # too short: right-pad the time axis
        return F.pad(spec, (0, spec_length - current), value=pad_value)
    # long enough: take a spec_length window, random or leading
    offset = random.randint(0, current - spec_length) if random_crop else 0
    return spec[:, :, offset:offset + spec_length]


def load_spec(spec_path):
    """Load a spectrogram from a ``.pt`` (torch) or ``.npy`` (NumPy) file.

    The loaded object must be 3-D (checked with ``assert``). If its first
    dimension has size 1, it is tiled to 3 along that dimension.

    Raises:
        ValueError: if the path ends in neither ``.pt`` nor ``.npy``.
    """
    if spec_path.endswith(".npy"):
        spec = torch.from_numpy(np.load(spec_path))
    elif spec_path.endswith(".pt"):
        # load onto CPU regardless of the device the tensor was saved from
        spec = torch.load(spec_path, map_location="cpu")
    else:
        raise ValueError(f"Unknown spec file type {spec_path}")

    assert len(spec.shape) == 3, f"spec shape must be [3, mel_dim, spec_len], got {spec.shape}"
    single_channel = spec.size(0) == 1
    return spec.repeat(3, 1, 1) if single_channel else spec


def random_crop_spec(spec, target_spec_length, pad_value=0, frame_per_sec=100, time_step=5): # spec: [3, mel_dim, spec_len]
    """Crop a window aligned to ``time_step``-second boundaries, then fix its length.

    The start time is a random multiple of ``time_step`` seconds whose start
    frame lies inside the spectrogram. The window spans
    ``ceil(target_spec_length / frame_per_sec)`` seconds, capped at the padded
    full duration. It is then right-padded with ``pad_value`` or trimmed to
    exactly ``target_spec_length`` frames.

    Args:
        spec: 3-D tensor indexed as ``[:, :, time]``.
        target_spec_length: output frame count; must be divisible by 8 (asserted).
        pad_value: fill value used when the window is shorter than the target.
        frame_per_sec: frames per second of the time axis.
        time_step: granularity, in seconds, of the allowed start times.

    Returns:
        ``(spec, start_s, end_s, full_s)``. ``start_s``/``end_s`` are the crop
        bounds in seconds. ``full_s`` is the duration rounded up to a multiple of
        ``time_step``.
    """
    assert target_spec_length % 8 == 0, "spec_length must be divisible by 8"

    spec_length = spec.shape[-1]
    n_steps = spec_length / frame_per_sec / time_step  # duration measured in time_step units
    full_s = math.ceil(n_steps) * time_step  # get full seconds (ceil)
    # Valid start chunks are 0 .. ceil(n_steps) - 1. The previous floor(n_steps)
    # upper bound let start_s == full_s when the length was an exact multiple,
    # which produced an empty crop (all padding). For non-exact lengths the
    # bound is unchanged.
    last_chunk = max(math.ceil(n_steps) - 1, 0)
    start_s = random.randint(0, last_chunk) * time_step  # random start seconds

    end_s = min(start_s + math.ceil(target_spec_length / frame_per_sec), full_s)  # get end seconds

    # may be longer than target_spec_length because of the ceiling above
    spec = spec[:, :, start_s * frame_per_sec : end_s * frame_per_sec]

    if spec.shape[-1] < target_spec_length:
        spec = F.pad(spec, (0, target_spec_length - spec.shape[-1]), value=pad_value)  # pad to target
    else:
        spec = spec[:, :, :target_spec_length]  # crop to target

    return spec, int(start_s), int(end_s), int(full_s)



def load_condion_embed(text_embed_path):
    """Load a condition (e.g. text) embedding from a ``.pt`` or ``.npy`` file.

    If the file holds a list/tuple of embeddings, one is picked with
    ``random.choice``. Otherwise the loaded tensor is used directly. A 3-D
    embedding has ``squeeze(0)`` applied, which removes a leading dimension of
    size 1 (e.g. ``[1, text_len, text_dim]`` -> ``[text_len, text_dim]``).

    Returns:
        The embedding, detached and on CPU.

    Raises:
        ValueError: if the path ends in neither ``.pt`` nor ``.npy``.
    """
    if text_embed_path.endswith(".pt"):
        loaded = torch.load(text_embed_path, map_location="cpu")
    elif text_embed_path.endswith(".npy"):
        loaded = torch.from_numpy(np.load(text_embed_path))
    else:
        raise ValueError(f"Unknown text embedding file type {text_embed_path}")
    # A single tensor used to leave `text_embed` unassigned (UnboundLocalError).
    if isinstance(loaded, (list, tuple)):
        text_embed = random.choice(loaded)
    else:
        text_embed = loaded
    if len(text_embed.shape) == 3:  # [1, text_len, text_dim]
        text_embed = text_embed.squeeze(0)  # -> [text_len, text_dim]
    return text_embed.detach().cpu()
    

def process_condition_embed(cond_emb, max_length):
    """Zero-pad or truncate a 2-D embedding along its first axis to ``max_length`` rows.

    ``cond_emb`` is indexed as ``[rows, features]`` (e.g. ``[text_len, text_dim]``).
    No random (classifier-free-guidance) dropping is performed here.
    """
    deficit = max_length - cond_emb.shape[0]
    if deficit > 0:
        # append `deficit` zero rows at the end of the first axis
        return F.pad(cond_emb, (0, 0, 0, deficit), value=0)
    return cond_emb[:max_length, :]
                

def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str):
    """Return the transformers text-encoder class named by a pretrained config.

    Loads the config with ``PretrainedConfig.from_pretrained`` and inspects
    ``architectures[0]``:
      - exactly ``"CLIPTextModel"`` -> ``CLIPTextModel``
      - contains ``"t5"`` (case-insensitive) -> ``T5EncoderModel``
      - contains ``"clap"`` (case-insensitive) -> ``ClapTextModelWithProjection``

    Raises:
        ValueError: if the config lists no architectures, or the first one is
            not supported.
    """
    text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path
    )
    # `architectures` may be None/empty; indexing it would raise a cryptic TypeError/IndexError.
    architectures = getattr(text_encoder_config, "architectures", None)
    if not architectures:
        raise ValueError(
            f"No `architectures` entry in the config at {pretrained_model_name_or_path}; "
            "cannot determine the text encoder class."
        )
    model_class = architectures[0]
    lowered = model_class.lower()

    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel
        return CLIPTextModel
    if "t5" in lowered:
        from transformers import T5EncoderModel
        return T5EncoderModel
    if "clap" in lowered:
        from transformers import ClapTextModelWithProjection
        return ClapTextModelWithProjection
    raise ValueError(f"{model_class} is not supported.")
    


def str2bool(string):
    """Parse a boolean-like command-line string.

    Accepts ``"True"``/``"true"`` (-> True) and ``"False"``/``"false"``/``"none"``/``"None"``
    (-> False). Anything else raises ``ValueError``.
    """
    lookup = {"True": True, "False": False, "true": True, "false": False, "none": False, "None": False}
    if string not in lookup:
        raise ValueError(f"Expected one of {set(lookup.keys())}, got {string}")
    return lookup[string]
    

def str2str(string):
    """Map null-like strings to None and return any other string unchanged.

    Null-like means ``"none"``, ``"null"`` or ``"false"`` in any letter case,
    or the empty string.
    """
    # "".lower() == "" and no non-empty string lowercases to "", so "" covers the empty case
    null_like = {"none", "null", "false", ""}
    return None if string.lower() in null_like else string


def json_dump(data_json, json_save_path):
    """Write ``data_json`` to ``json_save_path`` as JSON indented by 4 spaces (overwrites).

    The file is written as UTF-8, matching ``load_json_list``/``save_json_list``.
    The ``with`` block closes the file, so no explicit ``close()`` is needed.
    """
    with open(json_save_path, "w", encoding="utf-8") as f:
        json.dump(data_json, f, indent=4)


def json_load(json_path):
    """Read and return the JSON document stored in ``json_path``.

    The file is decoded as UTF-8 (the encoding JSON mandates), not the platform
    locale's default encoding, which could fail on non-ASCII content.
    """
    with open(json_path, "r", encoding="utf-8") as f:
        return json.load(f)


def load_json_list(path):
    """Read a JSON Lines file and return one parsed object per non-blank line.

    Whitespace-only lines (e.g. a trailing empty line) are skipped instead of
    making ``json.loads`` raise. The file is streamed line by line.
    """
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]
    

def save_json_list(data, path):
    """Write each item of ``data`` to ``path`` as one JSON document per line (JSON Lines, UTF-8)."""
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(record) + "\n" for record in data)
    

def multiprocess_function(func, func_args, n_jobs=32):
    """Run ``func`` on every element of ``func_args`` in a pool of ``n_jobs`` processes.

    A tqdm bar sized to ``len(func_args)`` advances as each call finishes, in
    completion order (``imap_unordered``). Return values are discarded, so
    ``func`` is called for its side effects; this function returns None.
    """
    with Pool(processes=n_jobs) as pool, tqdm(total=len(func_args)) as progress:
        for _ in pool.imap_unordered(func, func_args):
            progress.update()


def image_add_color(spec_img):
    """Colorize a spectrogram image with matplotlib's ``viridis`` colormap.

    Only channel 0 of ``np.array(spec_img)`` is used, so ``spec_img`` must
    convert to an array with at least 3 dimensions. It is presumably an
    H x W x C image such as an RGB PIL image; TODO confirm against callers.
    The colormapped RGB values are min-max stretched (over all pixels and
    channels) to 0..255 and returned as a PIL image.
    """
    cmap = plt.get_cmap("viridis")
    channel = np.array(spec_img)[:, :, 0]
    image = cmap(channel)[:, :, :3]  # drop the alpha (transparency) channel
    lo, hi = image.min(), image.max()
    if hi > lo:
        image = (image - lo) / (hi - lo)
    # else: every pixel has the same color. Stretching would compute 0/0 -> NaN,
    # whose uint8 cast is undefined, so keep the colormap's own [0, 1] values.
    return Image.fromarray(np.uint8(image * 255))


def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray:
    """
    Convert a batch of channel-first tensors (N, C, H, W) to a channel-last
    float32 NumPy array (N, H, W, C).

    The tensor is detached first, so tensors that require grad can also be
    converted.
    """
    # This is a plain module-level function: the previous @staticmethod
    # decorator made the name a non-callable staticmethod object on Python < 3.10.
    images = images.detach().cpu().permute(0, 2, 3, 1).float().numpy()
    return images


def numpy_to_pil(images):
    """
    Convert one NumPy image (H, W, C) or a batch (N, H, W, C) to a list of PIL images.

    Values are scaled by 255, rounded and cast to uint8, so inputs are
    presumably floats in [0, 1]. Single-channel images become mode "L" images.
    """
    batch = images[None, ...] if images.ndim == 3 else images
    batch = (batch * 255).round().astype("uint8")
    is_grayscale = batch.shape[-1] == 1
    if is_grayscale:
        # PIL expects a 2-D array for grayscale, so drop the channel axis
        return [Image.fromarray(frame.squeeze(), mode="L") for frame in batch]
    return [Image.fromarray(frame) for frame in batch]

### CODE FOR INPAINTING ###
def normalize(images):
    """
    Map an image array from [0, 1] to [-1, 1] via ``2 * x - 1``.

    If any value is already negative, the input is assumed to be normalized
    and is returned unchanged (the same object).
    """
    return 2.0 * images - 1.0 if images.min() >= 0 else images

def denormalize(images):
    """
    Map an image tensor back to [0, 1].

    If any value is negative, the input is treated as [-1, 1]-ranged and
    transformed with ``x / 2 + 0.5``. In every case the result is clamped to [0, 1].
    """
    if images.min() < 0:
        images = images / 2 + 0.5
    return images.clamp(0, 1)
    

def prepare_mask_and_masked_image(image, mask):
    """
    Build the masked input image for inpainting.

    Steps:
    1. If ``image.max() > 1``, min-max rescale the image to [0, 1].
    2. If the image has no negative values, map it to [-1, 1] with ``normalize``.
    3. Zero every position where ``mask >= 0.5``.

    NOTE(review): a constant image with values > 1 hits 0/0 in step 1.

    Parameters:
    - image (torch.Tensor): image tensor, e.g. shape [3, height, width].
    - mask (torch.Tensor): mask tensor, e.g. shape [1, height, width]; it must
      broadcast against ``image``.

    Returns:
    - tuple: ``(mask, masked_image)``. ``mask`` is returned unchanged; it is
      not binarized here.
    """
    peak = image.max()
    if peak > 1:
        # bring arbitrary-range input into [0, 1]
        floor = image.min()
        image = (image - floor) / (peak - floor)
    if image.min() >= 0:
        # [0, 1] -> [-1, 1]
        image = normalize(image)
    keep = mask < 0.5  # True where the pixel is kept
    return mask, image * keep


def torch_to_pil(image):
    """
    Convert an image tensor to a PIL image.

    Tensors containing negative values are first mapped back to [0, 1] with
    ``denormalize``. The tensor is then moved to CPU, detached, squeezed, and
    converted with torchvision's ``ToPILImage``.
    """
    needs_rescale = image.min() < 0
    if needs_rescale:
        image = denormalize(image)
    to_pil = transforms.ToPILImage()
    return to_pil(image.cpu().detach().squeeze())



# class TextEncoderAdapter(nn.Module):    
#     def __init__(self, hidden_size, cross_attention_dim=768):
#         super(TextEncoderAdapter, self).__init__()
#         self.hidden_size = hidden_size
#         self.cross_attention_dim = cross_attention_dim
#         self.proj = nn.Linear(self.hidden_size, self.cross_attention_dim)
#         self.norm = torch.nn.LayerNorm(self.cross_attention_dim)

#     def forward(self, x):
#         x = self.proj(x)
#         x = self.norm(x)
#         return x
    
#     def save_pretrained(self, save_directory, subfolder=""):
#         if subfolder:
#             save_directory = os.path.join(save_directory, subfolder)
#         os.makedirs(save_directory, exist_ok=True)
#         ckpt_path = os.path.join(save_directory, "adapter.pt")
#         config_path = os.path.join(save_directory, "config.json")
#         config = {"hidden_size": self.hidden_size, "cross_attention_dim": self.cross_attention_dim}
#         json_dump(config, config_path)
#         torch.save(self.state_dict(), ckpt_path)
#         print(f"Saving adapter model to {ckpt_path}")

#     @classmethod
#     def from_pretrained(cls, load_directory, subfolder=""):
#         if subfolder:
#             load_directory = os.path.join(load_directory, subfolder)
#         ckpt_path = os.path.join(load_directory, "adapter.pt")
#         config_path = os.path.join(load_directory, "config.json")
#         config = json_load(config_path)
#         instance = cls(**config)
#         instance.load_state_dict(torch.load(ckpt_path))
#         print(f"Loading adapter model from {ckpt_path}")
#         return instance          



class ConditionAdapter(nn.Module):
    """Linear projection of condition embeddings to ``cross_attention_dim``, followed by LayerNorm.

    ``config`` is a dict-like object that must provide ``condition_dim`` (input
    feature size) and ``cross_attention_dim`` (output feature size).
    ``save_pretrained`` additionally reads the optional ``condition_adapter_name``
    for its log line. The config is stored as-is and written to ``config.json``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.proj = nn.Linear(self.config["condition_dim"], self.config["cross_attention_dim"])
        self.norm = torch.nn.LayerNorm(self.config["cross_attention_dim"])
        print(f"INITIATED: ConditionAdapter: {self.config}")

    def forward(self, x):
        """Project the last axis of ``x`` to ``cross_attention_dim``, then LayerNorm it."""
        x = self.proj(x)
        x = self.norm(x)
        return x

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path):
        """Build an adapter from ``config.json`` and load weights from ``condition_adapter.pt`` in the given directory."""
        config_path = os.path.join(pretrained_model_name_or_path, "config.json")
        ckpt_path = os.path.join(pretrained_model_name_or_path, "condition_adapter.pt")
        config = json_load(config_path)
        instance = cls(config)
        # The freshly built module lives on CPU. Loading there also lets
        # GPU-saved checkpoints load on CPU-only machines; load_state_dict then
        # copies the values into the module's parameters.
        instance.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
        print(f"LOADED: ConditionAdapter from {pretrained_model_name_or_path}")
        return instance

    def save_pretrained(self, pretrained_model_name_or_path):
        """Write ``config.json`` and ``condition_adapter.pt`` (state dict) into the directory, creating it if needed."""
        os.makedirs(pretrained_model_name_or_path, exist_ok=True)
        config_path = os.path.join(pretrained_model_name_or_path, "config.json")
        ckpt_path = os.path.join(pretrained_model_name_or_path, "condition_adapter.pt")
        json_dump(self.config, config_path)
        torch.save(self.state_dict(), ckpt_path)
        # The name is optional in the config (__init__ does not require it); a
        # missing key used to raise KeyError after the files were already written.
        adapter_name = self.config.get("condition_adapter_name", "unnamed")
        print(f"SAVED: ConditionAdapter {adapter_name} to {pretrained_model_name_or_path}")


# class TextEncoderWrapper(CLIPTextModel):
#     def __init__(self, text_encoder, text_encoder_adapter):
#         super().__init__(text_encoder.config)
#         self.text_encoder = text_encoder
#         self.adapter = text_encoder_adapter

#     def forward(self, input_ids, **kwargs):
#         outputs = self.text_encoder(input_ids, **kwargs)
#         adapted_output = self.adapter(outputs[0])
#         return [adapted_output] # to compatible with last_hidden_state