import ast
import difflib
import os
import re
import time
from multiprocessing import Pool, Process, Queue

import Levenshtein
import numpy as np
import pandas as pd
import torch
from fuzzywuzzy import process
from PIL import Image
from torch.nn import functional as F

from lavis.common.registry import registry
from lavis.models import load_model_and_preprocess
from lavis.models.base_model import all_gather_with_grad, concat_all_gather
from lavis.processors import load_processor
# import obonet


def fuzzy_match(texts, candidates=None):
    """Map every text that is not already a known label to its closest known label.

    Args:
        texts: iterable of predicted text strings.
        candidates: sequence of known label strings to match against. Defaults to
            the module-level ``choices`` list, looked up at call time (Pool workers
            created by fork therefore see the parent's value).

    Returns:
        dict ``{text: closest candidate}`` for every text not in ``candidates``.
        Texts that already are candidates are omitted. The dict is empty when there
        are no candidates, because there is nothing to match against.
    """
    if candidates is None:
        candidates = choices
    if len(candidates) == 0:
        # get_close_matches would return [] and indexing [0] would raise IndexError.
        return {}
    # Build the set once, so each membership test is O(1) instead of a list scan.
    known = set(candidates)
    text_dict = {}
    for context in texts:
        if context not in known:
            # cutoff=0 accepts every candidate, so the single best-ratio one is returned.
            text_dict[context] = difflib.get_close_matches(context, candidates, n=1, cutoff=0.)[0]
    return text_dict


def txt_map(x, txt_dict):
    """Replace each text in ``x`` through ``txt_dict``, keeping unmapped texts as-is.

    Args:
        x: a list of strings, or the string repr of such a list (for example a
            list-valued cell read back from a CSV file).
        txt_dict: mapping ``{original text: replacement text}``.

    Returns:
        A new list with the same length and order as ``x``.
    """
    if isinstance(x, str):
        # literal_eval only parses Python literals. eval() would execute
        # arbitrary expressions coming from the data file.
        x = ast.literal_eval(x)
    return [txt_dict.get(i, i) for i in x]


def levenshtein_sim(text, label):
    """Score each ground-truth string by its best Levenshtein ratio against ``text``.

    Args:
        text: iterable of predicted strings.
        label: iterable of ground-truth strings.

    Returns:
        A list with one score per entry of ``label``, rounded to 3 decimals.
        A label gets 0 when ``text`` is empty or every ratio is 0.
    """
    best_per_label = []
    for target in label:
        # The leading 0 is the floor, so the result is an int 0 when nothing beats it.
        best = max([0, *(Levenshtein.ratio(target, candidate) for candidate in text)])
        best_per_label.append(round(best, 3))
    return best_per_label

def func(text, label):
    """Return, for each ground-truth string in ``label``, its best Levenshtein ratio.

    The ratio is taken against the strings in ``text`` and rounded to 3 decimals.
    This function used to be a line-for-line copy of ``levenshtein_sim``. It now
    delegates to it, so the two cannot drift apart, and keeps its name for
    existing callers.
    """
    return levenshtein_sim(text, label)


def stage2_output(df_test):
    """Generate free-text function predictions for every row of ``df_test``.

    Loads the stage-2 ``blip2_protein_opt`` model from a hard-coded checkpoint and
    runs beam-search generation in batches of 12 sequences. For each input row it
    appends one ``<sequence>|<pred1;pred2;...>`` line to
    ``/cluster/home/wenkai/LAVIS/output/output{fix}.txt``.

    Args:
        df_test: DataFrame whose ``protein`` column holds the sequences fed to the
            model. Rows are not de-duplicated here, so there is one output line
            per row.

    Reads the module-level globals ``device`` and ``fix``.
    """
    config = {'arch': 'blip2_protein_opt', 'load_finetuned': False,
              'pretrained': '/cluster/home/wenkai/LAVIS/lavis/output/BLIP2/Pretrain_stage2/20230924220/checkpoint_5.pth',
              'finetuned': '', 'num_query_token': 32, 'opt_model': 'facebook/opt-2.7b', 'prompt': '',
              'model_type': 'pretrain_protein_opt2.7b', 'load_pretrained': True, 'freeze_vit': True,
              'max_protein_len': 600,
              'max_txt_len': 25}

    model_cls = registry.get_model_class(config['arch'])
    model = model_cls.from_config(config)
    model.to(device)
    model.eval()

    images = df_test['protein'].tolist()
    n = len(images)
    bsz = 12
    num_txt = 10         # beam width
    return_num_txt = 5   # generated sequences kept per protein
    out_path = '/cluster/home/wenkai/LAVIS/output/output{}.txt'.format(fix)

    # Loop-invariant tokenizer setting, set once instead of once per batch.
    model.opt_tokenizer.padding_side = "right"

    # range(0, n, bsz) visits each batch exactly once. The old `n // bsz + 1`
    # iteration count produced an empty trailing batch whenever n % bsz == 0.
    for start in range(0, n, bsz):
        batch = images[start:start + bsz]
        # Protein ids restart at 0 within every batch.
        image = [('protein{}'.format(j), seq) for j, seq in enumerate(batch)]

        # Inference only: do not build autograd graphs for the encoder / Q-Former.
        with torch.no_grad():
            with model.maybe_autocast():
                _, _, batch_tokens = model.visual_encoder(image)
                image_embeds = model.ln_vision(batch_tokens.to(device), repr_layers=[model.vis_layers], return_contacts=True)["representations"][model.vis_layers].contiguous()

            # NOTE(review): this mask is all ones, so any padded positions of shorter
            # sequences in a batch are attended to. Confirm whether batch_tokens
            # carries padding that should be masked out.
            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)

            query_tokens = model.query_tokens.expand(image_embeds.shape[0], -1, -1)
            query_output = model.Qformer.bert(
                query_embeds=query_tokens,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=True,
            )

            inputs_opt = model.opt_proj(query_output.last_hidden_state)
            atts_opt = torch.ones(inputs_opt.size()[:-1], dtype=torch.long).to(device)

            # Empty text prompt per protein: generation is conditioned only on the
            # projected query outputs prepended below.
            opt_tokens = model.opt_tokenizer(
                [''] * len(image),
                return_tensors="pt",
                padding="longest",
                truncation=True,
                max_length=model.max_txt_len,
            ).to(device)
            inputs_embeds = model.opt_model.model.decoder.embed_tokens(opt_tokens.input_ids)
            inputs_embeds = torch.cat([inputs_opt, inputs_embeds], dim=1)
            attention_mask = torch.cat([atts_opt, opt_tokens.attention_mask], dim=1)
            with model.maybe_autocast():
                # NOTE(review): eos_token_id=50118 is presumably the newline token of
                # the OPT vocabulary. Confirm.
                outputs = model.opt_model.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, min_length=3,
                                                   max_length=30,
                                                   repetition_penalty=5., num_beams=num_txt, eos_token_id=50118,
                                                   length_penalty=1., num_return_sequences=return_num_txt, temperature=1.)
        output_text = [t.strip() for t in model.opt_tokenizer.batch_decode(outputs, skip_special_tokens=True)]

        # generate() returns `return_num_txt` consecutive sequences per input row.
        with open(out_path, 'a+', encoding='utf-8') as f:
            for k, (_, seq) in enumerate(image):
                preds = output_text[k * return_num_txt:(k + 1) * return_num_txt]
                f.write(seq + "|" + ';'.join(preds) + '\n')


# Category to evaluate: 'mf', 'bp' or 'cc' (presumably the GO molecular-function /
# biological-process / cellular-component sub-ontologies). `fix` is the matching
# file-name suffix; any other value falls back to '_mf'.
cat = 'mf'
fix = {'bp': '_bp', 'cc': '_cc'}.get(cat, '_mf')

# model_pth = {'mf': 'uniprot_swissprot_mf_stage1_epo19.pth', 'bp': 'checkpoint17_GO_swissprot_reviewed_bp_stage1.pth', 'cc': ''}

# Use the GPU when one is available.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = "cpu"

### Levenshtein similarity
# Only the first 10000 test rows are evaluated; function texts are lower-cased.
test = pd.read_csv('/cluster/home/wenkai/LAVIS/data/sim_split/test{}.csv'.format(fix), sep='|')[:10000]
test['function'] = test['function'].apply(lambda func_text: func_text.lower())

# stage2_output appends batch by batch, so start from an empty file.
stage2_out_path = '/cluster/home/wenkai/LAVIS/output/output{}.txt'.format(fix)
if os.path.exists(stage2_out_path):
    os.remove(stage2_out_path)
print("stage 2 predict starting")
stage2_output(test)
print("stage 2 predict completed")



# Read back the stage-2 generations: one "<protein>|<pred1;pred2;...>" line per row.
# Lines with an extra '|' (for example inside generated text) are warned about and skipped.
df_pred = pd.read_csv('/cluster/home/wenkai/LAVIS/output/output{}.txt'.format(fix), sep='|', header=None, on_bad_lines='warn')
df_pred.columns = ['protein', 'function']
df_pred = df_pred.drop_duplicates()
# Split the ';'-joined predictions and strip each one BEFORE de-duplicating, so
# variants that differ only in surrounding whitespace collapse to one entry.
df_pred['function'] = df_pred['function'].apply(lambda x: list({i.strip() for i in str(x).split(';')}))

# Ground truth: every (lower-cased) function text of each protein, as a list.
test_g = test.groupby(['protein']).agg({'function': lambda x: list(x)}).reset_index()
test_g.columns = ['protein', 'label']

# Keep only predicted proteins that have ground-truth labels. The .copy() lets the
# columns below be added without a SettingWithCopyWarning.
data = pd.merge(df_pred, test_g, on='protein', how='left')
data = data[data['label'].notnull()].copy()

# Per label: best Levenshtein ratio against the predictions, then the per-protein mean.
data['sim'] = [func(text, label) for text, label in zip(data['function'].tolist(), data['label'].tolist())]
data['avg_score'] = data['sim'].apply(lambda x: round(np.mean(x), 3))
print("average similarity score: {}".format(round(data['avg_score'].mean(), 3)))
# data.to_csv('/home/nilin/LAVIS/predict_{}.csv'.format(cat), index=False, sep='|')


def _read_function_go_split(path_template):
    """Load one data split and build its function-text -> GO label lookup.

    Args:
        path_template: CSV path with one ``{}`` placeholder, filled with ``fix``.

    Returns:
        ``(frame, lookup)``. ``frame`` is the de-duplicated ``function``/``GO_label``
        frame with lower-cased ``function`` text. ``lookup`` is
        ``dict(function -> GO_label)`` built from it; the last row wins when a text
        appears with several labels.
    """
    frame = pd.read_csv(path_template.format(fix), sep='|', usecols=['function', 'GO_label'])
    frame['function'] = frame['function'].apply(lambda text: text.lower())
    frame = frame.drop_duplicates()
    return frame, dict(zip(frame['function'], frame['GO_label']))


test, test_dict = _read_function_go_split('/cluster/home/wenkai/LAVIS/data/sim_split/test{}.csv')
val, val_dict = _read_function_go_split('/cluster/home/wenkai/LAVIS/data/sim_split/val{}.csv')
train, train_dict = _read_function_go_split('/cluster/home/wenkai/LAVIS/data/sim_split/train{}.csv')

# (A disabled alternative seeded this lookup from go_descriptions_new.txt.)
# Merge the splits: on duplicate keys, test overrides val, which overrides train.
GO_dict = {**train_dict, **val_dict, **test_dict}
# Known GO label texts that fuzzy_match maps predictions onto.
choices = list(GO_dict.keys())



# Keep one row per protein (rows sorted by the protein column first).
data = data.sort_values(by='protein')
data = data.drop_duplicates('protein')

### If a predicted text is not one of the GO label texts, replace it with the most similar GO label text.
t0 = time.time()
txt_dict = {}

all_txt = []
for txt in data['function']:
    # Cells may be real lists or their string repr (e.g. after a CSV round trip).
    # literal_eval parses literals only; eval() would run arbitrary code from the data.
    all_txt.extend(ast.literal_eval(txt) if isinstance(txt, str) else txt)
all_txt = list(set(all_txt))

n = len(all_txt)
thread = 20
# Ceiling division gives at most `thread` chunks. max(1, ...) keeps the step
# positive: the old int(n/thread) was 0 when n < thread, and range(0, n, 0)
# raised ValueError.
size = max(1, -(-n // thread))
all_txt_sep = [all_txt[i:i + size] for i in range(0, n, size)]

# NOTE(review): workers read the module-level `choices` global, which only works
# with the 'fork' start method. Under 'spawn' (the macOS/Windows default), each
# worker would re-import this module and re-run the whole script, because there
# is no `if __name__ == "__main__":` guard.
# Leaving the `with` block terminates the pool, so no close()/join() is needed.
with Pool(processes=thread) as pool:
    result = pool.map(fuzzy_match, all_txt_sep)
for d in result:
    txt_dict.update(d)

# Map every prediction onto its GO label text, then de-duplicate.
data['function'] = data['function'].apply(lambda x: list(set(txt_map(x, txt_dict))))
print("fuzzy matching time: {}".format(time.time() - t0))




### Find the generated GO texts that are not in the ground truth, then pair each of them with every ground-truth label.
# pair_a, pair_b = [], []
# for preds, labels in zip(data['function'], data['label']):
#     if type(preds) == str:
#         preds = eval(preds)
#     if type(labels) == str:
#         labels = eval(labels)
#     l = len(labels)
#     for pred in preds:
#         if pred not in labels:
#             pair_a.extend([pred]*l)
#             pair_b.extend(labels[:])
# pair_a = [re.sub('_', ':', GO_dict[i]) for i in pair_a]
# pair_b = [re.sub('_', ':', GO_dict[i]) for i in pair_b]
# with open('/home/nilin/LAVIS/examples/GO_pair{}.txt'.format(fix), 'w+') as f:
#     for i, j in zip(pair_a, pair_b):
#         f.write(i+' '+j+'\n')


# load the stage-1 model used to score (protein, candidate text) pairs
model_config = {'arch': 'blip2_protein', 'load_finetuned': False,
                'pretrained': '/cluster/home/wenkai/LAVIS/lavis/output/BLIP2/Pretrain_stage1/20230922185/checkpoint_15.pth',
                'finetuned': '', 'num_query_token': 32, 'prompt': '',
                'model_type': 'pretrain', 'load_pretrained': True, 'freeze_vit': False,
                'max_protein_len': 512, 'max_txt_len': 25}

model_cls = registry.get_model_class(model_config['arch'])
model = model_cls.from_config(model_config)
model = model.to(device)
model.eval()

# evaluate: one list of similarity scores per protein, aligned with its candidate texts
t0 = time.time()
proteins = list(data['protein'])
txts = list(data['function'])
scores = []
# Pure inference. Previously only the text branch ran under no_grad; with
# freeze_vit False, the protein branch built autograd graphs for nothing.
with torch.no_grad():
    for seq, txt in zip(proteins, txts):
        image = [('protein1', seq)]
        _, _, batch_tokens = model.visual_encoder(image)
        # NOTE(review): representation layer 30 is hard-coded here, while
        # stage2_output uses model.vis_layers. Confirm the two agree.
        image_embeds = model.ln_vision(batch_tokens.to(device), repr_layers=[30], return_contacts=True)["representations"][
            30].contiguous()

        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)

        query_tokens = model.query_tokens.expand(image_embeds.shape[0], -1, -1)

        query_output = model.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            use_cache=True,
            return_dict=True,
        )

        image_feats = F.normalize(model.vision_proj(query_output.last_hidden_state), dim=-1)

        if isinstance(txt, str):
            # literal_eval parses literals only; eval() would run arbitrary code from the data.
            txt = ast.literal_eval(txt)
        text_tokens = model.tokenizer(
            txt,
            padding="max_length",
            truncation=True,
            max_length=model.max_txt_len,
            return_tensors="pt",
        ).to(device)
        text_output = model.Qformer.bert(
            text_tokens.input_ids,
            attention_mask=text_tokens.attention_mask,
            return_dict=True,
        )

        text_feat = F.normalize(
            model.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1
        )

        text_feat_all = concat_all_gather(text_feat)
        # Presumably image_feats is (1, Q, D) and text_feat_all is (T, D): the matmul
        # broadcasts to (1, T, Q, 1), and squeeze(-1) gives (1, T, Q). Squeezing only
        # the last axis keeps the text axis even when T == 1, so no special case is needed.
        sim_q2t = torch.matmul(image_feats.unsqueeze(1), text_feat_all.unsqueeze(-1)).squeeze(-1)
        # Best-matching query token per text, for the single protein in the batch.
        sim_i2t = sim_q2t.max(-1).values[0]
        # Plain Python floats, so the saved CSV holds parseable numbers rather than
        # numpy scalar reprs.
        scores.append(sim_i2t.float().cpu().tolist())
print("model evaluate time: {}".format(time.time() - t0))
data['score'] = scores

# Precision / recall sweep over top-k (1..6) and score threshold (0.01..0.24).
# Ground-truth vocabulary: every distinct label text across all proteins.
labels = []
for l in data['label']:
    if isinstance(l, str):
        l = ast.literal_eval(l)
    labels.extend(l)

labels = list(set(labels))
total = len(labels)
for topk in range(1, 7):
    for threshold_pct in range(1, 25, 1):
        threshold = threshold_pct / 100
        tp_dict = dict.fromkeys(labels, 0)
        fp_dict = dict.fromkeys(labels, 0)
        fn_dict = dict.fromkeys(labels, 0)
        for cand_txts, cand_scores, label in zip(data['function'], data['score'], data['label']):
            if isinstance(label, str):
                label = ast.literal_eval(label)
            all_txts = np.array(cand_txts)
            all_scores = np.array(cand_scores)
            keep = all_scores > threshold
            selected = all_txts[keep]
            kept_scores = all_scores[keep]
            if selected.size == 0 and all_scores.size > 0:
                # Nothing clears the threshold, so fall back to the single best candidate.
                # Index with a list to keep a 1-element array. The old scalar index
                # returned a bare string, which made the membership tests below match
                # substrings and counted false positives per character.
                selected = all_txts[[int(np.argmax(all_scores))]]
            elif kept_scores.size > topk:
                selected = selected[np.argpartition(kept_scores, -topk)[-topk:]]
            selected = selected.tolist()
            predicted = set(selected)
            for t in label:
                if t in predicted:
                    tp_dict[t] += 1
                else:
                    fn_dict[t] += 1
            for p in selected:
                if p not in label:
                    fp_dict[p] = fp_dict.get(p, 0) + 1
        # Macro average over ground-truth labels. The sums are deliberately not named
        # `re`, which previously shadowed the `re` module for the rest of the script.
        recall_sum, precision_sum = 0., 0.
        for x in labels:
            recall_sum += tp_dict[x] / (tp_dict[x] + fn_dict[x] + 1e-8)
            precision_sum += tp_dict[x] / (tp_dict[x] + fp_dict[x] + 1e-8)
        r = recall_sum / total if total else 0.0
        p = precision_sum / total if total else 0.0
        # Guard the p == r == 0 case (previously ZeroDivisionError).
        # NOTE(review): this is the F1 of the macro-averaged P/R, although the log
        # label says "micro_f1". The label is kept unchanged for downstream parsers.
        f1 = 2 * p * r / (p + r) if (p + r) > 0 else 0.0
        print("Topk: {}, threshould: {}, macro_recall: {}, macro_precision: {}, micro_f1: {}".format(topk, threshold, r, p, f1))
    





# sim = []
# for text, label in zip(data['predict'].tolist(), data['label'].tolist()):
#     sim.append(levenshtein_sim(text, label))
#
# data['sim_filter'] = sim
# data['avg_score'] = data['sim_filter'].apply(lambda x: round(np.mean(x), 3))


# data['function'] = data['function'].apply(lambda x: eval(re.sub(';', ',', str(x))))
# data['label'] = data['label'].apply(lambda x: eval(re.sub(';', ',', str(x))))
# data['sim'] = data['sim'].apply(lambda x: eval(re.sub(';', ',', str(x))))
#
# data['function'] = data['function'].apply(lambda x: re.sub(',', ';', str(x)))
# data['label'] = data['label'].apply(lambda x: re.sub(',', ';', str(x)))
# data['sim'] = data['sim'].apply(lambda x: re.sub(',', ';', str(x)))
# data['predict'] = data['predict'].apply(lambda x: re.sub(',', ';', str(x)))
# data['sim_filter'] = data['sim_filter'].apply(lambda x: re.sub(',', ';', str(x)))

# Persist the per-protein table (predictions, labels, similarities, scores) as '|'-separated CSV.
predict_sim_path = '/cluster/home/wenkai/LAVIS/output/predict_sim{}.csv'.format(fix)
data.to_csv(predict_sim_path, index=False, sep='|')
# data = pd.read_csv('/cluster/home/wenkai/LAVIS/output/predict_sim{}.csv'.format(fix), sep='|')








#
# # example
# image = ['MIELKHVTFGYNKKQMVLQDINITIPDGENVGILGESGCGKSTLASLVLGLFKPVKGEIYLSDNAVLTIFQHPLTSFNPDWTIETSLKEALYYYRGLTDNTAQDQLLLQHLSTFELNAQLLTKLPSEVSGGQLQRFNVMRSLLAQPRVLICDEITSNLDVIAEQNVINILKAQTITNLNHFIVISHDLSVLQRLVNRIIVLKDGMIVDDFAIEELFNVDRHPYTKELVQTFSY']
# image = [('protein{}'.format(i), x) for i, x in enumerate(image)]
#
# _, _, batch_tokens = model.visual_encoder(image)
# image_embeds = model.ln_vision(batch_tokens.to(device), repr_layers=[30], return_contacts=True)["representations"][30].contiguous()
#
# image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
#
# query_tokens = model.query_tokens.expand(image_embeds.shape[0], -1, -1)
#
# query_output = model.Qformer.bert(
#     query_embeds=query_tokens,
#     encoder_hidden_states=image_embeds,
#     encoder_attention_mask=image_atts,
#     use_cache=True,
#     return_dict=True,
# )
#
# image_feats = F.normalize(model.vision_proj(query_output.last_hidden_state), dim=-1)
#
# image_feats_all = concat_all_gather(image_feats)
#
# functions = ['transmembrane transporter activity', 'nickel cation transmembrane transporter activity', 'nickel cation binding', 'atp hydrolysis activity', 'atp hydrolysis', 'cadmium binding', 'abc-type nickel transmembrane transporter activity', 'abc-type nickel transporter activity', 'nickel transmembrane transporter activity', 'atp binding']
# for text in functions:
#     with torch.no_grad():
#         # text = 'flavin adenine dinucleotide binding'
#         text_tokens = model.tokenizer(
#             text,
#             padding="max_length",
#             truncation=True,
#             max_length=model.max_txt_len,
#             return_tensors="pt",
#         ).to(device)
#         text_output = model.Qformer.bert(
#             text_tokens.input_ids,
#             attention_mask=text_tokens.attention_mask,
#             return_dict=True,
#         )
#
#         text_feat = F.normalize(
#             model.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1
#         )
#
#         text_feat_all = concat_all_gather(text_feat)
#         sim_q2t = torch.matmul(image_feats.unsqueeze(1), text_feat_all.unsqueeze(-1)).squeeze()
#         sim_i2t, _ = sim_q2t.max(-1)
#         print('sim_i2t: {}'.format(sim_i2t))
#
#         # # text-query similarity: [batch_size, batch_size*num_gpu, num_query_tokens]
#         # sim_t2q = torch.matmul(
#         #     text_feat.unsqueeze(1).unsqueeze(1), image_feats_all.permute(0, 2, 1)
#         # ).squeeze()
#         #
#         # # text-image similarity: aggregate across all query tokens
#         # sim_t2i, _ = sim_t2q.max(-1)
#         # print('sim_t2i: {}'.format(sim_t2i))