from dataclasses import dataclass, field
from typing import Optional
import pandas as pd
import os

import torch
from transformers import VisionEncoderDecoderModel, TrOCRProcessor, Seq2SeqTrainer, Seq2SeqTrainingArguments, default_data_collator, EarlyStoppingCallback
from peft import LoraConfig, get_peft_model

from data import AphaPenDataset
import evaluate
from sklearn.model_selection import train_test_split

from src.calibrator import EncoderDecoderCalibrator
from src.loss import MarginLoss, KLRegularization
from src.similarity import CERSimilarity
from datetime import datetime
import torch.nn.functional as F

os.environ["WANDB_PROJECT"] = "Alphapen-TrOCR"

# Step 1: Load the dataset
train_df_path = "/mnt/data1/Datasets/AlphaPen/" + "training_data.csv"
test_df_path = "/mnt/data1/Datasets/AlphaPen/" + "testing_data.csv"

# NOTE: both splits are currently carved out of testing_data.csv (first 4000
# rows -> training, remaining rows -> evaluation); train_df_path is not read.
# The CSV is read once, split positionally, rows with any missing value are
# dropped, and each split is re-indexed from zero. Building new frames (instead
# of dropna(inplace=True) on a slice) avoids pandas' SettingWithCopy pitfalls.
_alphapen_df = pd.read_csv(test_df_path)
train_df = _alphapen_df.iloc[:4000].dropna().reset_index(drop=True)
test_df = _alphapen_df.iloc[4000:].dropna().reset_index(drop=True)
del _alphapen_df

root_dir = "/mnt/data1/Datasets/OCR/Alphapen/clean_data/final_cropped_rotated_"

model_name = "microsoft/trocr-large-handwritten"

processor = TrOCRProcessor.from_pretrained(model_name)
train_dataset = AphaPenDataset(root_dir=root_dir, df=train_df, processor=processor)
eval_dataset = AphaPenDataset(root_dir=root_dir, df=test_df, processor=processor)

# Step 2: Load the model
model = VisionEncoderDecoderModel.from_pretrained(model_name)

# make sure vocab size is set correctly
model.config.vocab_size = model.config.decoder.vocab_size
# for peft
model.vocab_size = model.config.decoder.vocab_size

# Special tokens (decoder_start_token_id / pad_token_id are used to build the
# decoder_input_ids from the labels) plus the beam-search settings used by
# model.generate().
_generation_params = {
    "decoder_start_token_id": processor.tokenizer.cls_token_id,
    "pad_token_id": processor.tokenizer.pad_token_id,
    "eos_token_id": processor.tokenizer.sep_token_id,
    "max_length": 64,
    "early_stopping": True,
    "no_repeat_ngram_size": 3,
    "length_penalty": 2.0,
    "num_beams": 4,
}
for _name, _value in _generation_params.items():
    setattr(model.config, _name, _value)
    # Newer transformers versions read decoding settings from
    # model.generation_config; values set only on model.config may be ignored
    # or trigger deprecation warnings there, so keep both in sync.
    if getattr(model, "generation_config", None) is not None:
        setattr(model.generation_config, _name, _value)

# LoRA adapters: rank 1, lora_alpha 8, dropout 0.1, attached to the modules
# matched by the names below.
# NOTE(review): these names appear to match the ViT encoder's submodules; the
# TrOCR decoder layers seem to use different names (e.g. q_proj/k_proj/v_proj/
# out_proj, fc1/fc2), so the decoder is presumably not adapted — confirm intent.
lora_target_modules = [
    'query',
    'key',
    'value',
    'intermediate.dense',
    'output.dense',
]
lora_config = LoraConfig(
    r=1,
    lora_alpha=8,
    lora_dropout=0.1,
    target_modules=lora_target_modules,
)
model = get_peft_model(model, lora_config)

tokenizer = processor.tokenizer
# Alternative calibration pipeline built from the src/ helpers (currently disabled):
# sim = CERSimilarity(tokenizer)
# loss = MarginLoss(sim, beta=0.1, num_samples=60)
# reg = KLRegularization(model)
# calibrator = EncoderDecoderCalibrator(model, loss, reg, 15, 15)

# Step 3: Define the training arguments
# NOTE(review): `evaluation_strategy` is renamed `eval_strategy` in newer
# transformers releases — adjust if upgrading.
# NOTE(review): save_steps (20000) exceeds max_steps (10000), so no
# intermediate checkpoint is written during training — confirm this is intended.
training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    evaluation_strategy="steps",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    bf16=True,
    bf16_full_eval=True,
    output_dir="./",
    logging_steps=100,
    save_steps=20000,
    eval_steps=500,
    # report_to="wandb",
    optim="adamw_torch_fused",
    lr_scheduler_type="cosine",
    gradient_accumulation_steps=2,
    learning_rate=1.0e-4,
    max_steps=10000,
    # %S is seconds-of-minute; lowercase %s is a non-portable glibc extension
    # that inserts the whole Unix epoch instead.
    run_name=f"trocr-LoRA-{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}",
)

# Step 4: Define a metric
# Module-level character-error-rate metric used by compute_cer.
cer_metric = evaluate.load("cer")

def compute_cer(pred, target):
    """Return the character error rate of one prediction against one reference.

    Args:
        pred: predicted transcription.
        target: reference transcription.

    Returns:
        CER as a float. CER is undefined for an empty ``target``; in that case
        this returns 0.0 when ``pred`` is also empty and 1.0 otherwise.
    """
    if not target:
        return 0.0 if not pred else 1.0
    result = cer_metric.compute(predictions=[pred], references=[target])
    # The `evaluate` CER metric returns a bare float, so indexing it with
    # ['cer'] raised TypeError; a dict result is still accepted for safety.
    if isinstance(result, dict):
        result = result['cer']
    return float(result)

def generate_candidates(model, pixel_values, num_candidates=10):
    """Run beam search and return every finished beam for each image.

    The beam width equals ``num_candidates`` and all beams are returned. The
    result is the ``generate`` output object, including per-step scores.
    """
    search_kwargs = dict(
        num_beams=num_candidates,
        num_return_sequences=num_candidates,
        return_dict_in_generate=True,
        output_scores=True,
    )
    return model.generate(pixel_values, **search_kwargs)

def rank_loss(positive_scores, negative_scores):
    """Hinge ranking loss with a fixed margin of 1.

    Computes ``mean(max(0, 1 - positive + negative))`` element-wise over the
    paired score tensors.
    """
    hinge = F.relu(1 - positive_scores + negative_scores)
    return torch.mean(hinge)

def margin_loss(positive_scores, negative_scores, margin=0.1):
    """Hinge loss ``mean(max(0, margin - positive + negative))``.

    Penalises every pair whose positive score does not exceed its negative
    score by at least ``margin``.
    """
    violations = F.relu(margin - positive_scores + negative_scores)
    return torch.mean(violations)

def _decode_references(ground_truth, processor):
    """Decode label ids (1-D single sample or 2-D batch) into reference strings.

    ``-100`` entries (ignored label positions) are mapped to the pad token on a
    copy before decoding, so the caller's tensor is not modified.
    """
    labels = torch.as_tensor(ground_truth)
    if labels.dim() == 1:
        labels = labels.unsqueeze(0)
    labels = labels.masked_fill(labels == -100, processor.tokenizer.pad_token_id)
    return processor.batch_decode(labels, skip_special_tokens=True)


def _sequence_log_likelihood(model, pixel_values, sequences, num_candidates, pad_token_id):
    """Mean per-token log-probability of each candidate sequence, with gradients.

    ``model.generate`` runs without gradient tracking, so its
    ``sequences_scores`` cannot train the model; candidates are re-scored here
    with a teacher-forced forward pass. Assumes ``sequences`` holds
    ``num_candidates`` consecutive rows per image and begins with the decoder
    start token, as ``generate`` returns for encoder-decoder models — TODO confirm.
    """
    decoder_input_ids = sequences[:, :-1]
    targets = sequences[:, 1:]
    images = pixel_values.repeat_interleave(num_candidates, dim=0)
    logits = model(pixel_values=images, decoder_input_ids=decoder_input_ids).logits
    log_probs = F.log_softmax(logits.float(), dim=-1)
    token_log_probs = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    # Padding emitted after EOS does not count towards the score.
    mask = (targets != pad_token_id).to(token_log_probs.dtype)
    return (token_log_probs * mask).sum(dim=-1) / mask.sum(dim=-1).clamp(min=1.0)


def calibration_loss(model, pixel_values, ground_truth, processor, loss_type='margin', num_candidates=10):
    """Pairwise ranking loss that pushes low-CER candidates to score higher.

    For each image, ``num_candidates`` beam-search candidates are generated and
    ranked by similarity (1 - CER) to that image's reference. For every pair of
    candidates with strictly different similarity, the better candidate's
    model score should exceed the worse one's by a margin. Note that this
    re-scores batch_size * num_candidates sequences with a forward pass.

    Args:
        model: encoder-decoder model (optionally PEFT-wrapped) used to generate
            and re-score candidates.
        pixel_values: image batch.
        ground_truth: label token ids, 1-D (one sample) or 2-D (one row per
            image); ``-100`` entries are treated as padding.
        processor: TrOCR processor used for decoding.
        loss_type: ``'margin'`` (hinge margin 0.1) or ``'rank'`` (margin 1.0).
        num_candidates: beams generated and returned per image.

    Returns:
        Scalar tensor; zero when no pair with differing similarity exists.

    Raises:
        ValueError: for an unknown ``loss_type``, or when the number of
            generated candidates does not match the number of references.
    """
    if loss_type not in ('rank', 'margin'):
        raise ValueError("Invalid loss type. Choose 'rank' or 'margin'.")

    candidates = generate_candidates(model, pixel_values, num_candidates)
    sequences = candidates.sequences
    candidate_texts = processor.batch_decode(sequences, skip_special_tokens=True)
    references = _decode_references(ground_truth, processor)
    if len(candidate_texts) != len(references) * num_candidates:
        raise ValueError(
            f"Expected {num_candidates} candidates per reference, got "
            f"{len(candidate_texts)} candidates for {len(references)} references."
        )

    # (better, worse) global candidate indices, compared only within one image.
    # Ties carry no ranking signal and are skipped.
    pairs = []
    for sample, reference in enumerate(references):
        if not reference:
            continue  # CER is undefined for an empty reference
        start = sample * num_candidates
        sims = [
            1 - compute_cer(text, reference)
            for text in candidate_texts[start:start + num_candidates]
        ]
        for i in range(num_candidates):
            for j in range(i + 1, num_candidates):
                if sims[i] == sims[j]:
                    continue
                hi, lo = (i, j) if sims[i] > sims[j] else (j, i)
                pairs.append((start + hi, start + lo))

    if not pairs:
        return torch.tensor(0.0, device=pixel_values.device)

    scores = _sequence_log_likelihood(
        model, pixel_values, sequences, num_candidates, processor.tokenizer.pad_token_id
    )
    index = torch.tensor(pairs, device=scores.device)
    positive_scores = scores[index[:, 0]]
    negative_scores = scores[index[:, 1]]

    if loss_type == 'rank':
        return rank_loss(positive_scores, negative_scores)
    return margin_loss(positive_scores, negative_scores)

class CalibratedTrainer(Seq2SeqTrainer):
    """Seq2SeqTrainer whose loss adds a weighted candidate-ranking (calibration) term.

    Extra keyword arguments, removed before delegating to ``Seq2SeqTrainer``:
        processor: TrOCR processor passed to ``calibration_loss`` for decoding.
        calibration_weight: multiplier on the calibration term (default 0.1);
            0 disables it.
        calibration_loss_type: ``'margin'`` or ``'rank'`` (default ``'margin'``).
    """

    def __init__(self, *args, **kwargs):
        self.processor = kwargs.pop('processor', None)
        self.calibration_weight = kwargs.pop('calibration_weight', 0.1)
        self.calibration_loss_type = kwargs.pop('calibration_loss_type', 'margin')
        super().__init__(*args, **kwargs)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return cross-entropy plus ``calibration_weight`` times the calibration loss.

        ``num_items_in_batch`` is accepted for compatibility with newer
        Trainer versions that pass it; it is not used here.
        """
        labels = inputs["labels"]
        pixel_values = inputs["pixel_values"]

        # Teacher-forced forward pass: with `labels` present the model builds the
        # decoder inputs from them, so logits align with the labels and
        # `outputs.loss` is token-level cross-entropy (ignoring -100).
        # model.generate() is unusable here: it runs without gradients and its
        # per-step logits do not align with the labels.
        outputs = model(**inputs)
        total_loss = outputs.loss

        if self.calibration_weight:
            if self.processor is None:
                raise ValueError(
                    "CalibratedTrainer needs `processor=` when calibration_weight is non-zero."
                )
            cal_loss = calibration_loss(
                model, pixel_values, labels, self.processor, self.calibration_loss_type
            )
            total_loss = total_loss + self.calibration_weight * cal_loss

        return (total_loss, outputs) if return_outputs else total_loss



def compute_metrics(pred):
    """Compute the corpus-level character error rate for Trainer evaluation.

    Args:
        pred: Trainer prediction output with ``predictions`` (generated token
            ids) and ``label_ids``.

    Returns:
        ``{"cer": <float>}``.
    """
    pad_token_id = processor.tokenizer.pad_token_id

    # Trainer pads labels (and ragged generated batches) with -100, which is not
    # a valid token id; map it to padding. Work on copies so the caller's arrays
    # are not mutated.
    label_ids = pred.label_ids.copy()
    label_ids[label_ids == -100] = pad_token_id
    pred_ids = pred.predictions.copy()
    pred_ids[pred_ids == -100] = pad_token_id

    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)

    # Reuse the module-level metric instead of reloading it on every evaluation.
    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    return {"cer": cer}

# Step 5: Define the Trainer and train.
trainer_kwargs = dict(
    model=model,
    args=training_args,
    tokenizer=processor.feature_extractor,
    data_collator=default_data_collator,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    # CalibratedTrainer-specific options.
    processor=processor,
    calibration_weight=0.1,
    calibration_loss_type='margin',  # or 'rank'
)
# compute_metrics is not passed, so evaluation reports loss only;
# add compute_metrics=compute_metrics to report CER.
trainer = CalibratedTrainer(**trainer_kwargs)

trainer.train()