from transformers import AutoModelForSeq2SeqLM
import torch
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import default_data_collator
from transformers import AutoTokenizer
from tqdm import tqdm
import pandas as pd
import numpy
import random
import nevergrad as ng
from peft.utils.save_and_load import set_peft_model_state_dict, get_peft_model_state_dict
from peft import PeftModel, PeftConfig
from functools import partial

# Seed NumPy's and Python's global RNGs so repeated runs are reproducible.
# NOTE(review): nevergrad presumably draws its default randomness from NumPy's
# global state — verify before relying on this for optimizer determinism.
numpy.random.seed(42)
random.seed(42)

def load_base_model_and_lora_modules(lora_module_list):
    """Load the base model, its tokenizer, and a weight snapshot of every LoRA module.

    The base model is resolved from the PEFT config of the first entry of
    ``lora_module_list``. Every module is then loaded on top of that same base
    model and an independent copy of its adapter state dict is cached.

    Args:
        lora_module_list: Non-empty list of PEFT model ids / paths.

    Returns:
        ``(peft_model, tokenizer, cache)``: ``peft_model`` is the PeftModel
        (moved to CUDA when available, in eval mode) whose adapter weights are
        later overwritten during optimization; ``cache`` maps each module id to
        a cloned copy of its LoRA state dict.

    Raises:
        ValueError: if the modules' LoRA state dicts do not have identical keys
            and tensor shapes, so they cannot be linearly combined.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # The first module's config tells us which base model all modules adapt.
    default_peft_model_id = lora_module_list[0]
    model_name_or_path = PeftConfig.from_pretrained(default_peft_model_id).base_model_name_or_path
    base_model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    peft_model = PeftModel.from_pretrained(base_model, default_peft_model_id)
    peft_model = peft_model.to(device)
    peft_model.eval()

    print("> Begin to load lora modules")
    cache = {}
    for peft_model_id in tqdm(lora_module_list):
        print("> Loading {} ...".format(peft_model_id))
        cur_peft_model = PeftModel.from_pretrained(base_model, peft_model_id)
        # get_peft_model_state_dict is built from state_dict(), whose tensors
        # share storage with the live parameters. All PeftModels here wrap the
        # same base_model, so the adapter currently installed in it (and thus
        # in peft_model) could alias a cache entry; set_peft_model_state_dict
        # would then overwrite that entry in place. Clone to decouple them.
        cache[peft_model_id] = {
            key: value.detach().clone()
            for key, value in get_peft_model_state_dict(cur_peft_model).items()
        }

    # Weighted composition requires identical keys and shapes across modules;
    # fail here with a clear message instead of deep inside the optimizer.
    reference_id = lora_module_list[0]
    reference_state = cache[reference_id]
    for peft_model_id, state_dict in cache.items():
        mismatched = state_dict.keys() != reference_state.keys() or any(
            state_dict[key].shape != reference_state[key].shape
            for key in reference_state
        )
        if mismatched:
            raise ValueError(
                f"LoRA module {peft_model_id!r} is incompatible with "
                f"{reference_id!r}: all modules must target the same layers "
                "with the same rank."
            )

    return peft_model, tokenizer, cache


def preprocess_function(examples, tokenizer):
    """Tokenize a batch of input/output text pairs for seq2seq evaluation.

    Args:
        examples: Batch mapping with ``"input"`` and ``"output"`` text lists.
        tokenizer: Called on both sides with padding, truncation and PyTorch
            tensor output; its ``pad_token_id`` identifies padding in targets.

    Returns:
        The tokenizer output for the inputs (at most 2048 tokens), extended with
        a ``"labels"`` entry holding the target token ids (at most 256 tokens)
        where padding ids are replaced by -100 (the default ``ignore_index`` of
        PyTorch's cross-entropy loss).
    """
    source_texts, target_texts = examples["input"], examples["output"]
    shared_kwargs = dict(padding=True, truncation=True, return_tensors="pt")

    encoded = tokenizer(source_texts, max_length=2048, **shared_kwargs)
    label_ids = tokenizer(target_texts, max_length=256, **shared_kwargs)["input_ids"]

    # Mask padded target positions so they do not contribute to the loss.
    label_ids[label_ids == tokenizer.pad_token_id] = -100
    encoded["labels"] = label_ids
    return encoded


def load_dataset_and_run(example_inputs, example_outputs, tokenizer):
    """Build a tokenized ``Dataset`` from paired example texts.

    Despite its name, this only builds and tokenizes the dataset; no model is run.

    Args:
        example_inputs: Sequence of input strings.
        example_outputs: Sequence of target strings, aligned one-to-one with
            ``example_inputs``.
        tokenizer: Tokenizer forwarded to ``preprocess_function``.

    Returns:
        The dataset with ``preprocess_function`` mapped over it in batches
        (single process).

    Raises:
        ValueError: if the two sequences have different lengths. Previously a
            shorter ``example_outputs`` raised an opaque ``IndexError`` and a
            longer one was silently truncated.
    """
    if len(example_inputs) != len(example_outputs):
        raise ValueError(
            "example_inputs and example_outputs must have the same length, "
            f"got {len(example_inputs)} and {len(example_outputs)}"
        )
    rows = [
        {"input": source, "output": target}
        for source, target in zip(example_inputs, example_outputs)
    ]
    dataset = Dataset.from_pandas(pd.DataFrame(rows))
    preprocess_func_with_tokenizer = partial(preprocess_function, tokenizer=tokenizer)
    processed_datasets = dataset.map(
        preprocess_func_with_tokenizer,
        batched=True,
        num_proc=1,
        desc="Running tokenizer on dataset",
    )
    return processed_datasets

    
def get_score(weights, model, cache, example_dataset):
    """Evaluate one candidate weighting of the cached LoRA modules.

    This is the objective minimized by nevergrad. It composes the cached LoRA
    state dicts linearly with ``weights``, loads the result into ``model`` in
    place, and measures the loss on ``example_dataset``.

    Args:
        weights: One coefficient per entry of ``cache``, in ``cache`` iteration order.
        model: PeftModel whose adapter weights are overwritten by this call.
        cache: Mapping of module id -> LoRA state dict.
        example_dataset: Tokenized dataset whose tensor columns are fed to
            ``model`` through ``default_data_collator``.

    Returns:
        ``loss / num_examples + 0.05 * mean(|weights|)``, where ``loss`` is the
        model loss summed over batches (a single batch here, since the batch
        size equals the dataset size). The second term is an L1 penalty.
    """
    lora_module_list = list(cache.keys())
    final_state_dict = get_final_weights(weights, lora_module_list, cache)
    set_peft_model_state_dict(model, final_state_dict)

    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    dataloader = DataLoader(
        example_dataset,
        collate_fn=default_data_collator,
        batch_size=len(example_dataset),
        # Pinned host memory only speeds up host->GPU copies.
        pin_memory=use_cuda,
    )

    total_loss = 0.0
    with torch.no_grad():
        for batch in dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            total_loss += float(outputs.loss)
    # Keeps the original objective's scaling: summed batch loss divided by
    # the number of examples.
    loss = total_loss / len(example_dataset)

    # L1 regularization: mean absolute weight.
    l1_penalty = sum(abs(w) for w in weights) / len(weights)
    return loss + 0.05 * l1_penalty

def get_final_weights(weights, lora_module_list, cache):
    """Linearly combine the cached LoRA state dicts of the given modules.

    Args:
        weights: Indexable coefficients; ``weights[i]`` scales the state dict of
            ``lora_module_list[i]``.
        lora_module_list: Module ids to combine, looked up in ``cache``.
        cache: Mapping of module id -> state dict of tensors.

    Returns:
        A new dict with the first module's keys (in its order), where each value
        is ``weights[0] * sd_0[key] + weights[1] * sd_1[key] + ...``,
        accumulated left to right.
    """
    state_dicts = [cache[module_id] for module_id in lora_module_list]
    leading = state_dicts[0]
    combined = {}
    for key in leading.keys():
        total = weights[0] * leading[key]
        for position, other in enumerate(state_dicts[1:], start=1):
            total = total + weights[position] * other[key]
        combined[key] = total
    return combined
    


def lorahub_learning(lora_module_list, text_input, text_output, max_inference_step):
    """Search for a linear combination of LoRA modules that fits few-shot examples.

    Args:
        lora_module_list: PEFT model ids / paths of the LoRA modules to combine.
        text_input: Newline-separated example inputs.
        text_output: Newline-separated example outputs, aligned line-by-line
            with ``text_input``.
        max_inference_step: Evaluation budget passed to the nevergrad optimizer.

    Returns:
        ``(recommendation, final_lora)``: the nevergrad recommendation, whose
        ``.value`` holds one weight per module (bounded to [-1.5, 1.5]), and
        the LoRA state dict composed with those weights. Returns
        ``(None, None)`` when ``lora_module_list`` is empty, so callers can
        always unpack two values.
    """
    number_of_loras = len(lora_module_list)
    if number_of_loras == 0:
        return None, None
    # load model
    model, tokenizer, cache = load_base_model_and_lora_modules(lora_module_list)
    # process dataset: one example per line
    dataset = load_dataset_and_run(text_input.split("\n"), text_output.split("\n"), tokenizer)

    get_score_partial = partial(get_score, model=model, cache=cache,
                                example_dataset=dataset)
    # Start from all-zero weights, each bounded to [-1.5, 1.5].
    instrum = ng.p.Array(
        init=[0] * number_of_loras,
        upper=[1.5] * number_of_loras,
        lower=[-1.5] * number_of_loras,
    )
    optimizer = ng.optimizers.NGOpt(parametrization=instrum, budget=max_inference_step)
    print("> Begin to perform gradient-free optimization ...")
    recommendation = optimizer.minimize(get_score_partial, verbosity=1)
    final_lora = get_final_weights(recommendation.value, lora_module_list, cache)
    return recommendation, final_lora