import ctypes
import os
import platform
import re
from collections import Counter, defaultdict
from dataclasses import field, make_dataclass
from typing import Union

import numpy as np
import torch
import ujson
from datasketch import MinHash, MinHashLSH
# from nltk import ngrams
from nltk.translate.bleu_score import sentence_bleu
from transformers import T5Config, TrainerCallback, TrainingArguments
from transformers.trainer_callback import TrainerControl, TrainerState

from config import T5ModelConfig

# Sentence-ending punctuation marks
END_PUN = set(".。!!))》}】??\"”")

class MyTrainerCallback(TrainerCallback):
    log_cnt = 0
    def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        '''
        Clear the CUDA cache every couple of logging calls; helps prevent OOM on low-VRAM devices.
        '''
        self.log_cnt += 1
        if self.log_cnt % 2 == 0:
            torch.cuda.empty_cache()
    
    def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        '''
        Save the model once at the end of every epoch.
        In TrainingArguments, save_strategy cannot combine 'epoch' and 'steps'. To also save a checkpoint
        every save_steps steps while, for disk-space reasons, keeping only the most recent N checkpoints,
        it is enough to set should_save=True here and return the control object.
        '''
        # Setting should_save=True and returning control is sufficient
        control.should_save = True
        return control
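
# Illustrative usage sketch (not executed here): `model`, `train_dataset` and `training_args` are assumed
# to be defined elsewhere in the project; the callback is passed to the Trainer via `callbacks`.
# >>> from transformers import Trainer
# >>> trainer = Trainer(model=model, args=training_args,
# ...                   train_dataset=train_dataset,
# ...                   callbacks=[MyTrainerCallback()])
# >>> trainer.train()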


# Keep only Chinese characters, English letters, digits and underscores; drop punctuation
NON_CHAR = re.compile("[^\u4E00-\u9FA5A-Za-z_0-9]")

def _get_doc_mini_hash(doc: list[str] | str, num_perm: int) -> MinHash:
    '''
    Compute the MinHash of a piece of text (iterates over characters for a str, over items for a list).
    '''
    mini_hash = MinHash(num_perm=num_perm)
    for s in doc:
        mini_hash.update(s.encode('utf-8'))
    return mini_hash

class DropDatasetDuplicate:

    def __init__(self,  threshold: float=0.85, num_perm: int=256) -> None:
        '''
        Collect the indexes of all duplicated documents in a dataset (similarity above `threshold`).
        The input is a list[str], where each str element is one document (doc).
        E.g. for input [a, b, c, d, c, d, e] the duplicates are {4, 5} (the indexes of the later c and d).
        '''
        self.similar_index_cluster = defaultdict(set)
        self.data_lsh = MinHashLSH(threshold=threshold, num_perm=num_perm) 
        self.num_perm = num_perm

    def add_doc(self, index: object, doc: str) -> None:
        '''
        Add a document.
        index: index of the document
        doc: the document text itself
        '''

        # Keep only Chinese characters, English letters, digits and underscores; drop punctuation
        doc = ''.join(NON_CHAR.split(doc))
        # doc = [''.join(t) for t in list(ngrams(doc, 3))]

        doc_hash = _get_doc_mini_hash(doc, self.num_perm)
        close_duplicates = self.data_lsh.query(doc_hash)

        self.data_lsh.insert(index, doc_hash)

        # All mutually similar docs are grouped in similar_index_cluster under the earliest index among them.
        # E.g. if indexes 2, 7, 8, 9, 10, 12 of the data are similar, similar_index_cluster holds {2: {7, 8, 9, 10, 12}}
        if len(close_duplicates) > 0:
            min_idx = min(close_duplicates)
            self.similar_index_cluster[min_idx].add(index)
    
    def get_duplicate_indexs(self):
        '''
        Return the indexes of all duplicated documents.
        '''
        similar_index_cluster = self.similar_index_cluster
        need_to_remove_idx = set()
        
        for key_idx in similar_index_cluster.keys():
            need_to_remove_idx |= similar_index_cluster[key_idx]

        return need_to_remove_idx
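
# Illustrative usage sketch (toy data, not executed here):
# >>> deduper = DropDatasetDuplicate(threshold=0.85, num_perm=256)
# >>> corpus = ['how to cook rice', 'how to cook rice?', 'what is a transformer model']
# >>> for i, doc in enumerate(corpus):
# ...     deduper.add_doc(i, doc)
# >>> deduper.get_duplicate_indexs()   # indexes of near-duplicates of earlier docs, e.g. {1}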


def get_T5_config(config: T5ModelConfig, vocab_size: int, decoder_start_token_id: int=0, eos_token_id: int=1) -> T5Config:
    '''
    Convert the project-level T5ModelConfig into a transformers T5Config.
    '''
    t5_config = T5Config()
    # t5_config.model_type = 'TextToTextModel'
    # Copy over the fields
    t5_config.d_ff = config.d_ff
    t5_config.d_kv = config.d_kv
    t5_config.d_model = config.d_model
    t5_config.num_decoder_layers = config.num_decoder_layers
    t5_config.num_heads = config.num_heads
    t5_config.num_layers = config.num_layers
    t5_config.vocab_size = vocab_size
    t5_config.decoder_start_token_id = decoder_start_token_id
    t5_config.eos_token_id = eos_token_id

    return t5_config
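
# Illustrative sketch (assumes T5ModelConfig in config.py provides the fields copied above; the vocab size
# below is a placeholder and should come from the real tokenizer):
# >>> model_config = T5ModelConfig()
# >>> t5_config = get_T5_config(model_config, vocab_size=32128, decoder_start_token_id=0, eos_token_id=1)
# >>> t5_config.d_model, t5_config.vocab_size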

def f1_p_r_compute(spo_list_pred: list, spo_list_true: list, repair: bool=False):
    '''
    spo_list: [ [(s,p,o), ...], [(s,p,o)] ]; each row [(s,p,o), ...] holds the SPO triples of one sentence.
    Computes the F1 score, precision and recall over SPO triples.
    '''
    assert len(spo_list_pred) == len(spo_list_true)

    def repair_song_album(spo_list: list, song: list, album: list):
        '''
        Repair the song/album SPO triples of one text: for the relations '歌手', '作词' and '作曲' with a
        song subject x, x must appear in `song` and not only in `album`, otherwise the triple is dropped.
        '''
        if len(song) == 0 and len(album) == 0:
            return spo_list

        ps = ['歌手', '作词', '作曲']
        new_spo_list = []
        for spo in spo_list:
            s, p = spo[0], spo[1]
            if p in ps and s in album and s not in song:
                continue
            new_spo_list.append(spo)
        
        return new_spo_list

    def repair_song_album_list(spo_list: list):
        '''
        Apply repair_song_album to the SPO list of every sentence.
        '''
        new_spo_list = []
        for spos in spo_list:
            song, album = [], []
            for spo in spos:
                s, p, o = spo
                if p == '所属专辑':
                    song.append(s)
                    album.append(o)
            new_spo_list.append(repair_song_album(spos, song, album))
        
        return new_spo_list
    if repair:
        spo_list_pred = repair_song_album_list(spo_list_pred)
        spo_list_true = repair_song_album_list(spo_list_true)

    TP = 1e-10      # true positives: predicted positive and actually positive, A
    # TN = 1e-10    # true negatives: predicted negative and actually negative
    TP_FP = 1e-10   # everything retrieved (all predictions), A + B
    TP_FN = 1e-10   # everything actually wanted (ground truth), A + C
    # FP = 1e-10    # false positives: negatives predicted as positive
    # FN = 1e-10    # false negatives: positives predicted as negative

    # p = a / (a + b)
    # r = a / (a + c)
    # f1 = 2pr / (p + r)

    for i in range(len(spo_list_true)):
        pred_set = set(spo_list_pred[i])
        true_set = set(spo_list_true[i])

        pred_true_set = pred_set & true_set     # intersection of predictions and ground truth

        TP += len(pred_true_set)    # retrieved and wanted, A
        TP_FP += len(pred_set)      # everything retrieved, wanted or not, A + B
        TP_FN += len(true_set)      # everything actually wanted, retrieved or not, A + C

    p = TP / TP_FP
    r = TP / TP_FN
    f1 = (2 * p * r) / (p + r)
    
    return f1, p, r
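
# Illustrative sketch with toy SPO lists (one sentence per row):
# >>> pred = [[('周杰伦', '歌手', '七里香'), ('周杰伦', '作曲', '夜曲')]]
# >>> true = [[('周杰伦', '歌手', '七里香')]]
# >>> f1, p, r = f1_p_r_compute(pred, true)   # p ≈ 0.5, r ≈ 1.0, f1 ≈ 0.67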


def fixed_response(item: str) -> str:
    '''
    Repair a truncated answer: scan backwards from the end for the first sentence-ending punctuation mark
    and cut the text there.
    '''
    if len(item) <= 1: return item
    if item[-1] in END_PUN: return item

    n = len(item)
    i = n - 1
    while i > 0 and item[i] not in END_PUN:
        i -= 1

    # If no ending punctuation exists at all, return the text unchanged rather than a single character
    if i == 0 and item[0] not in END_PUN:
        return item

    return item[0: i + 1]
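
# Example: the trailing fragment after the last ending punctuation is dropped.
# >>> fixed_response('你好,很高兴认识你。我是')
# '你好,很高兴认识你。'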


def fixed_space(sentence: str)->str:
    '''
    Drop isolated single spaces; collapse two consecutive spaces into one.
    '''
    n = len(sentence)
    new_sentence = []
    i = 0
    while i < n:
        word =  sentence[i]
        if word != ' ':
            new_sentence.append(word)
        elif i + 1 < n and sentence[i + 1] == ' ':
            new_sentence.append(word)
            i += 1  # two spaces keep one: skip the second space
        i += 1

    return ''.join(new_sentence)
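
# Example: an isolated space is removed, a double space collapses into one.
# >>> fixed_space('我 是  中国人')
# '我是 中国人'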

def get_free_space_of_disk(folder: str='./') -> float:
    '''
    Return the free space, in GB, of the disk containing the given directory.
    '''
    res_val = 0.0
    if platform.system() == 'Windows':
        free_bytes = ctypes.c_ulonglong(0)
        ctypes.windll.kernel32.GetDiskFreeSpaceExW(ctypes.c_wchar_p(folder), None, None, ctypes.pointer(free_bytes))
        res_val = free_bytes.value 
    else:
        st = os.statvfs(folder)
        res_val = st.f_bavail * st.f_frsize
    
    return res_val / (1024 ** 3)

def my_average(arry_list: list[float]) -> float:
    '''
    Custom mean that returns 0.0 for an empty list.
    '''
    if len(arry_list) == 0: return 0.0 
    
    return np.average(arry_list)


def json_to_dataclass(json_file: str, class_name: str='Config') -> type:
    '''
    Convert a JSON config file into a dataclass type.
    >>> example:
    >>> data_class = json_to_dataclass('my_config.json', 'Config')
    >>> my_config = data_class()
    >>> assert my_config.name == 'Alice'
    >>> my_config.name = 'Bob' 
    '''
    json_dict = {}
    with open(json_file, 'r', encoding='utf-8') as f:
        json_dict = ujson.load(f)

    # Turn the dict into (field name, field type, default value) tuples for make_dataclass
    fields_list = []
    for k, v in json_dict.items():
        fields_list.append( (k, type(v), field(default=v)) )
    
    data_class = make_dataclass(cls_name=class_name, fields=fields_list)

    return data_class


def get_path_of_suffix_files(root: str, suffix: str, with_create_time: bool=False) -> list:
    '''
    Get the paths of all files under the given directory (recursively) that end with the given suffix,
    optionally paired with their creation time.
    '''
    suffix_files = []
    for root, _, files in os.walk(root):
        for file in files:
            if file.endswith(suffix):
                full_path = '{}/{}'.format(root, file)
                if with_create_time:
                    suffix_files.append( (full_path, os.path.getctime(full_path)) )
                else:
                    suffix_files.append(full_path)
                            
    return suffix_files
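
# Illustrative sketch (the directory and suffix are assumptions): list checkpoints by creation time.
# >>> files = get_path_of_suffix_files('./model_save', '.bin', with_create_time=True)
# >>> files.sort(key=lambda x: x[1])   # oldest first; drop all but the most recent N if disk space is tight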

def get_bleu4_score(reference: Union[str, list[str]], outputs: Union[str, list[str]], n_gram: int=4) -> float:
    '''
    Compute the BLEU-4 score. A plain str input is treated as a sequence of characters.
    '''
    
    weights = np.ones(n_gram) * (1.0 / n_gram)

    outputs_len, reference_len = len(outputs), len(reference)

    if not isinstance(reference, list):
        reference = list(reference)
    if not isinstance(outputs, list):
        outputs = list(outputs)

    outputs_counter = extract_Ngram(outputs, n_gram=n_gram)
    reference_counter = extract_Ngram(reference, n_gram=n_gram)

    ngram_counter_clip = outputs_counter & reference_counter

    clip_counter = np.zeros(n_gram)
    output_ngram_counter = np.zeros(n_gram)

    for (key, ngram), cnt in ngram_counter_clip.items():
        clip_counter[ngram - 1] += cnt 
    
    for (key, ngram), cnt in outputs_counter.items():
        output_ngram_counter[ngram - 1] += cnt
    
    # print(clip_counter, output_ngram_counter)
    # If any n-gram order has zero overlap, the geometric mean (and hence BLEU) is 0
    if np.min(clip_counter) == 0.0:
        return 0.0

    precision_scores = clip_counter / output_ngram_counter
   
    # BLEU: weighted geometric mean of the n-gram precisions, times the brevity penalty
    log_precision_scores = weights * np.log(precision_scores)

    # geometric mean = exp of the weighted sum of log precisions
    geometric_mean = np.exp(np.sum(log_precision_scores))
    brevity_penalty = np.exp(1 - (reference_len / outputs_len))

    # brevity_penalty = 1.0,   bleu = sentence_bleu([reference], outputs)
    # brevity_penalty = 1.0

    bleu = brevity_penalty * geometric_mean

    return bleu


def extract_Ngram(words_list: list[str], n_gram: int) -> Counter:
    '''
    Extract all 1..n_gram n-grams of a sentence.
    return:
        ngram_counter: key = ('w1 w2 ... wn', n), value = count of that n-gram
    '''
    n = len(words_list)
    ngram_counter = Counter()

    for i in range(1, n_gram + 1):
        for j in range(n - i + 1):
            key = ' '.join(words_list[j: j + i])
            ngram_counter[(key, i)] += 1

    return ngram_counter
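
# Example: all 1-grams and 2-grams of a short token list.
# >>> extract_Ngram(['a', 'b', 'a'], n_gram=2)
# Counter({('a', 1): 2, ('b', 1): 1, ('a b', 2): 1, ('b a', 2): 1})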


def save_model_config(config_dict: dict, file: str) -> None:
    '''
    Write a model config dict to a JSON file; `file` is the save directory plus file name.
    '''
    # file = file.replace('\\', '/')
    # file = '{}/model_config.json'.format('/'.join(file.split('/')[0: -1]))
    
    with open(file, 'w', encoding='utf-8') as f:
        ujson.dump(config_dict, f, indent=4, ensure_ascii=False)
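
# Illustrative sketch (the path and values are assumptions):
# >>> save_model_config({'d_model': 768, 'vocab_size': 32128}, './model_save/model_config.json')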

if __name__ == '__main__':
    ref = '抱歉,我不知道ABB代表什么意思'
    out = '我不明白ABB是什么意思'
    b1 = sentence_bleu([list(out)], list(ref),  weights=(0.25, 0.25, 0.25, 0.25))
    print(b1)
    b2 = get_bleu4_score(out, ref)
    print(b2)

    
    candidate_corpus = ['i', 'have', 'a', 'pen', 'on', 'my', 'desk', 'a', 'b', 'c', 'd','f','f']
    reference_corpus = ['there', 'is', 'a', 'pen', 'on', 'my', 'desk', 'a', 'b', 'd', 'd', 'fd']
    
    print('----')
    print(sentence_bleu([reference_corpus], candidate_corpus,  weights=(0.25, 0.25, 0.25, 0.25)))
    print(get_bleu4_score(reference_corpus, candidate_corpus))