import settings
import captum
import numpy as np
import torch
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from utils import get_args
from utils import CTCLabelConverter, AttnLabelConverter, Averager, TokenLabelConverter
import string
import time
import sys
from dataset import hierarchical_dataset, AlignCollate
import validators
from model import Model, STRScore
from PIL import Image
from lime.wrappers.scikit_image import SegmentationAlgorithm
from captum._utils.models.linear_model import SkLearnLinearModel, SkLearnRidge
import random
import os
from skimage.color import gray2rgb
import pickle
from train_shap_corr import getPredAndConf
import re
import copy
import statistics

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

from captum.attr import (
    GradientShap,
    DeepLift,
    DeepLiftShap,
    IntegratedGradients,
    LayerConductance,
    NeuronConductance,
    NoiseTunnel,
    Saliency,
    InputXGradient,
    GuidedBackprop,
    Deconvolution,
    GuidedGradCam,
    FeatureAblation,
    ShapleyValueSampling,
    Lime,
    KernelShap
)

from captum.metrics import (
    infidelity,
    sensitivity_max
)
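
### Hedged sketch (not exercised below): the captum.metrics imports above could
### be wired up roughly like this; perturb_fn is a hypothetical Gaussian-noise
### perturbation, and super_model/input_img/attributions follow the shapes used
### in acquireAttribution() further down.
# def perturb_fn(inputs):
#     noise = torch.randn_like(inputs) * 0.003
#     return noise, inputs - noise
# infid = infidelity(super_model, perturb_fn, input_img, attributions, target=0)
# sens = sensitivity_max(Saliency(super_model).attribute, input_img, target=0)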

### Returns the per-segment mean attribution, with the same shape as the input
### This function can only process one attribution image at a time
def averageSegmentsOut(attr, segments):
    averagedInput = torch.clone(attr)
    sortedDict = {}
    for x in np.unique(segments):
        segmentMean = torch.mean(attr[segments == x])
        sortedDict[x] = float(segmentMean.detach().cpu().numpy())
        averagedInput[segments == x] = segmentMean
    return averagedInput, sortedDict
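
### Hedged usage sketch: `attributions` is a (1,1,H,W) tensor from any captum
### method below and `segmOutput` an integer mask from quickshift.
# averaged, segScores = averageSegmentsOut(attributions[0, 0], segmOutput)
# bestSegment = max(segScores, key=segScores.get) # segment with the largest mean attribution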

### Output and save segmentations for one dataset only
def outputSegmOnly(opt):
    ### targetDataset - evaluate one dataset only, e.g. SVTP (645 images) or CUTE80 (288 images)
    targetDataset = "CUTE80" # ['IIIT5k_3000', 'SVT', 'IC03_867', 'IC13_1015', 'IC15_2077', 'SVTP', 'CUTE80']
    segmRootDir = "/home/uclpc1/Documents/STR/datasets/segmentations/224X224/{}/".format(targetDataset)

    if not os.path.exists(segmRootDir):
        os.makedirs(segmRootDir)

    opt.eval = True
    ### Both branches evaluate only the single target dataset set above
    if opt.fast_acc:
        # To easily compute the total accuracy of our paper.
        eval_data_list = [targetDataset]
    else:
        # The evaluation datasets, dataset order is same with Table 1 in our paper.
        eval_data_list = [targetDataset]

    ### Taken from LIME
    segmentation_fn = SegmentationAlgorithm('quickshift', kernel_size=4,
                                            max_dist=200, ratio=0.2,
                                            random_seed=random.randint(0, 1000))

    for eval_data_name in eval_data_list:
        eval_data_path = os.path.join(opt.eval_data, eval_data_name)
        AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD, opt=opt)
        eval_data, eval_data_log = hierarchical_dataset(root=eval_data_path, opt=opt)
        evaluation_loader = torch.utils.data.DataLoader(
            eval_data, batch_size=1,
            shuffle=False,
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_evaluation, pin_memory=True)
        for i, (image_tensors, labels) in enumerate(evaluation_loader):
            imgDataDict = {}
            img_numpy = image_tensors.cpu().detach().numpy()[0] ### Need to set batch size to 1 only
            if img_numpy.shape[0] == 1:
                img_numpy = gray2rgb(img_numpy[0])
            # print("img_numpy shape: ", img_numpy.shape) # (224,224,3)
            segmOutput = segmentation_fn(img_numpy)
            imgDataDict['segdata'] = segmOutput
            imgDataDict['label'] = labels[0]
            outputPickleFile = segmRootDir + "{}.pkl".format(i)
            with open(outputPickleFile, 'wb') as f:
                pickle.dump(imgDataDict, f)
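
### Hedged sketch of reading one saved segmentation back (mirrors the dump above):
# with open(segmRootDir + "0.pkl", 'rb') as f:
#     imgDataDict = pickle.load(f)
# segmOutput, label = imgDataDict['segdata'], imgDataDict['label']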

def acquireSelectivityHit(origImg, attributions, segmentations, model, converter, labels, scoring):
    # print("segmentations unique len: ", np.unique(segmentations))
    aveSegmentations, sortedDict = averageSegmentsOut(attributions[0,0], segmentations)
    sortedKeys = [k for k, v in sorted(sortedDict.items(), key=lambda item: item[1])]
    sortedKeys = sortedKeys[::-1] ### Reverse so the list runs from largest to smallest score
    # print("sortedDict: ", sortedDict) # {0: -5.51e-06, 1: -1.469e-05, 2: -3.06e-05,...}
    # print("aveSegmentations unique len: ", np.unique(aveSegmentations))
    # print("aveSegmentations device: ", aveSegmentations.device) # cuda:0
    # print("aveSegmentations shape: ", aveSegmentations.shape) # (224,224)
    # print("aveSegmentations: ", aveSegmentations)

    n_correct = []
    confidenceList = [] # First index is one feature removed, second index two features removed, and so on...
    clonedImg = torch.clone(origImg)
    gt = str(labels)
    for totalSegToHide in range(0, len(sortedKeys)):
        ### Hide the next highest-scoring segment, then re-run the prediction
        currentSegmentToHide = sortedKeys[totalSegToHide]
        clonedImg[0,0][segmentations == currentSegmentToHide] = 0.0
        pred, confScore = getPredAndConf(opt, model, scoring, clonedImg, converter, np.array([gt])) # opt is the module-level global set in __main__
        # To evaluate a 'case sensitive model' in an alphanumeric, case insensitive setting.
        if opt.sensitive and opt.data_filtering_off:
            pred = pred.lower()
            gt = gt.lower()
            alphanumeric_case_insensitive = '0123456789abcdefghijklmnopqrstuvwxyz'
            out_of_alphanumeric_case_insensitive = f'[^{alphanumeric_case_insensitive}]'
            pred = re.sub(out_of_alphanumeric_case_insensitive, '', pred)
            gt = re.sub(out_of_alphanumeric_case_insensitive, '', gt)
        if pred == gt:
            n_correct.append(1)
        else:
            n_correct.append(0)
        confScore = confScore[0][0]*100
        confidenceList.append(confScore)
    return n_correct, confidenceList
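
### Hedged sketch: the returned n_correct list is what acquire_bestacc_attr()
### below turns into a normalized selectivity AUC, roughly:
# dataList = [1] + n_correct # prepend 1 so a one-element list still has nonzero area
# auc_norm = np.trapz(dataList) / np.trapz([1] * len(dataList))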

### Once you have the selectivity_eval_results.pkl file, use this to flag images that were never predicted correctly
def acquire_selectivity_auc(opt, pkl_filename=None):
    if pkl_filename is None:
        pkl_filename = "/home/goo/str/str_vit_dataexplain_lambda/metrics_sensitivity_eval_results_CUTE80.pkl" # VITSTR

    with open(pkl_filename, 'rb') as f:
        selectivity_data = pickle.load(f)

    for resDictIdx, resDict in enumerate(selectivity_data):
        keylistAcc = []
        keylistConf = []
        for keyStr in resDict.keys():
            if "_acc" in keyStr: keylistAcc.append(keyStr)
            if "_conf" in keyStr: keylistConf.append(keyStr)
        # Flag images the network never predicted correctly at any masking step
        for metrics_accStr in keylistAcc:
            if 1 not in resDict[metrics_accStr]: print(resDictIdx)

### gtClassNum - set to gtClassNum=0 for the standard implementation, or to a specific class idx for a local explanation
def acquireAttribution(opt, super_model, input, segmTensor, gtClassNum, lowestAccKey, device):
    channels = 1
    if opt.rgb:
        channels = 3

    ### Perform attribution
    if "intgrad_" in lowestAccKey:
        ig = IntegratedGradients(super_model)
        attributions = ig.attribute(input, target=gtClassNum)
    elif "gradshap_" in lowestAccKey:
        gs = GradientShap(super_model)
        baseline_dist = torch.zeros((1, channels, opt.imgH, opt.imgW))
        baseline_dist = baseline_dist.to(device)
        attributions = gs.attribute(input, baselines=baseline_dist, target=gtClassNum)
    elif "deeplift_" in lowestAccKey:
        dl = DeepLift(super_model)
        attributions = dl.attribute(input, target=gtClassNum)
    elif "saliency_" in lowestAccKey:
        saliency = Saliency(super_model)
        attributions = saliency.attribute(input, target=gtClassNum)
    elif "inpxgrad_" in lowestAccKey:
        input_x_gradient = InputXGradient(super_model)
        attributions = input_x_gradient.attribute(input, target=gtClassNum)
    elif "guidedbp_" in lowestAccKey:
        gbp = GuidedBackprop(super_model)
        attributions = gbp.attribute(input, target=gtClassNum)
    elif "deconv_" in lowestAccKey:
        deconv = Deconvolution(super_model)
        attributions = deconv.attribute(input, target=gtClassNum)
    elif "featablt_" in lowestAccKey:
        ablator = FeatureAblation(super_model)
        attributions = ablator.attribute(input, target=gtClassNum, feature_mask=segmTensor)
    elif "shapley_" in lowestAccKey:
        svs = ShapleyValueSampling(super_model)
        attributions = svs.attribute(input, target=gtClassNum, feature_mask=segmTensor)
    elif "lime_" in lowestAccKey:
        interpretable_model = SkLearnRidge(alpha=1, fit_intercept=True) ### This is the default used by LIME
        lime = Lime(super_model, interpretable_model=interpretable_model)
        attributions = lime.attribute(input, target=gtClassNum, feature_mask=segmTensor)
    elif "kernelshap_" in lowestAccKey:
        ks = KernelShap(super_model)
        attributions = ks.attribute(input, target=gtClassNum, feature_mask=segmTensor)
    else:
        assert False, "Unrecognized attribution method in key: {}".format(lowestAccKey)
    return attributions
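
### Hedged usage sketch: `super_model` is assumed to be a Model wrapped with
### STRScore so it returns one scalar per class; `segmTensor` is only consumed
### by the mask-based methods (featablt_/shapley_/lime_/kernelshap_).
# attributions = acquireAttribution(opt, super_model, input_img, segmTensor,
#                                   gtClassNum=0, lowestAccKey="saliency_acc", device=device)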

### In addition to what acquire_average_auc() computes, this function also returns the attribution method with the best (lowest) selectivity_acc
### pickleFile - path to the pkl file produced by the captum test
def acquire_bestacc_attr(opt, pickleFile):
    # pickleFile = "metrics_sensitivity_eval_results_IIIT5k_3000.pkl"
    # pickleFile = "/home/goo/str/str_vit_dataexplain_lambda/shapley_singlechar_ave_matrn_SVT.pkl"
    acquireSelectivity = True # If True, compute the normalized selectivity AUC
    acquireInfidelity = False
    acquireSensitivity = False

    with open(pickleFile, 'rb') as f:
        data = pickle.load(f)
    metricDict = {} # Keys: "saliency_acc", "saliency_conf", "saliency_infid", "saliency_sens"
    for imgData in data:
        if acquireSelectivity:
            for keyStr in imgData.keys():
                if ("_acc" in keyStr or "_conf" in keyStr) and not ("_local_" in keyStr or "_global_local_" in keyStr): # Accept only selectivity
                    if keyStr not in metricDict:
                        metricDict[keyStr] = []
                    dataList = copy.deepcopy(imgData[keyStr]) # list of 0,1 [1,1,1,0,0,0,0]
                    dataList.insert(0, 1) # Insert 1 at beginning to avoid np.trapz([1]) = 0.0
                    denom = [1] * len(dataList) # Denominator to normalize AUC
                    auc_norm = np.trapz(dataList) / np.trapz(denom)
                    metricDict[keyStr].append(auc_norm)
        elif acquireInfidelity:
            pass # TODO
        elif acquireSensitivity:
            pass # TODO

    lowestAccKey = ""
    lowestAcc = float('inf')
    for metricKey in metricDict:
        if "_acc" in metricKey: # Used for selectivity accuracy only
            statisticVal = statistics.mean(metricDict[metricKey])
            if statisticVal < lowestAcc:
                lowestAcc = statisticVal
                lowestAccKey = metricKey
        # print("{}: {}".format(metricKey, statisticVal))

    assert lowestAccKey != "", "No *_acc key found in {}".format(pickleFile)
    return lowestAccKey
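
### Hedged sketch of driving acquireAttribution() with the returned key:
# lowestAccKey = acquire_bestacc_attr(opt, "shapley_singlechar_ave_matrn_SVT.pkl")
# attributions = acquireAttribution(opt, super_model, input_img, segmTensor, 0, lowestAccKey, device)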

def saveAttrData(filename, attribution, segmData, origImg):
    pklData = {}
    pklData['attribution'] = torch.clone(attribution).detach().cpu().numpy()
    pklData['segmData'] = segmData
    pklData['origImg'] = origImg
    with open(filename, 'wb') as f:
        pickle.dump(pklData, f)
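
### Hedged loading counterpart to saveAttrData() (same keys as the dump above):
# with open(filename, 'rb') as f:
#     pklData = pickle.load(f)
# attribution = torch.from_numpy(pklData['attribution'])
# segmData, origImg = pklData['segmData'], pklData['origImg']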

### New code (8/3/2022) to acquire average selectivity, infidelity, etc. after running captum test
def acquire_average_auc(opt):
    # pickleFile = "metrics_sensitivity_eval_results_IIIT5k_3000.pkl"
    pickleFile = "/home/goo/str/str_vit_dataexplain_lambda/shapley_singlechar_ave_vitstr_IC03_860.pkl"
    acquireSelectivity = True # If True, compute the normalized selectivity AUC
    acquireInfidelity = False
    acquireSensitivity = False

    with open(pickleFile, 'rb') as f:
        data = pickle.load(f)
    metricDict = {} # Keys: "saliency_acc", "saliency_conf", "saliency_infid", "saliency_sens"
    for imgData in data:
        if acquireSelectivity:
            for keyStr in imgData.keys():
                if "_acc" in keyStr or "_conf" in keyStr: # Accept only selectivity
                    if keyStr not in metricDict:
                        metricDict[keyStr] = []
                    dataList = copy.deepcopy(imgData[keyStr]) # list of 0,1 [1,1,1,0,0,0,0]
                    dataList.insert(0, 1) # Insert 1 at beginning to avoid np.trapz([1]) = 0.0
                    denom = [1] * len(dataList) # Denominator to normalize AUC
                    auc_norm = np.trapz(dataList) / np.trapz(denom)
                    metricDict[keyStr].append(auc_norm)
        elif acquireInfidelity:
            pass # TODO
        elif acquireSensitivity:
            pass # TODO

    for metricKey in metricDict:
        print("{}: {}".format(metricKey, statistics.mean(metricDict[metricKey])))

### Use this to acquire the list of average AUCs for each character position
def acquireListOfAveAUC(opt):
    acquireSelectivity = True
    acquireInfidelity = False
    acquireSensitivity = False
    totalChars = 10
    collectedMetricDict = {}
    for charNum in range(0, totalChars):
        pickleFile = f"/home/goo/str/str_vit_dataexplain_lambda/singlechar{charNum}_results_{totalChars}chardataset.pkl"
        with open(pickleFile, 'rb') as f:
            data = pickle.load(f)
        metricDict = {} # Keys: "saliency_acc", "saliency_conf", "saliency_infid", "saliency_sens"
        for imgData in data:
            if acquireSelectivity:
                for keyStr in imgData.keys():
                    if "_acc" in keyStr or "_conf" in keyStr: # Accept only selectivity
                        if keyStr not in metricDict:
                            metricDict[keyStr] = []
                        dataList = copy.deepcopy(imgData[keyStr]) # list of 0,1 [1,1,1,0,0,0,0]
                        dataList.insert(0, 1) # Insert 1 at beginning to avoid np.trapz([1]) = 0.0
                        denom = [1] * len(dataList) # Denominator to normalize AUC
                        auc_norm = np.trapz(dataList) / np.trapz(denom)
                        metricDict[keyStr].append(auc_norm)
        for metricKey in metricDict:
            selec_auc_normalize = statistics.mean(metricDict[metricKey])
            if metricKey not in collectedMetricDict:
                collectedMetricDict[metricKey] = []
            collectedMetricDict[metricKey].append(selec_auc_normalize)
    for collectedMetricDictKey in collectedMetricDict:
        print("{}: {}".format(collectedMetricDictKey, collectedMetricDict[collectedMetricDictKey]))
    for charNum in range(0, totalChars):
        selectivityAcrossCharsLs = []
        for collectedMetricDictKey in collectedMetricDict:
            if "_acc" in collectedMetricDictKey:
                selectivityAcrossCharsLs.append(collectedMetricDict[collectedMetricDictKey][charNum])
        print("accuracy -- {}: {}".format(charNum, statistics.mean(selectivityAcrossCharsLs)))
    for charNum in range(0, totalChars):
        selectivityAcrossCharsLs = []
        for collectedMetricDictKey in collectedMetricDict:
            if "_conf" in collectedMetricDictKey:
                selectivityAcrossCharsLs.append(collectedMetricDict[collectedMetricDictKey][charNum])
        print("confidence -- {}: {}".format(charNum, statistics.mean(selectivityAcrossCharsLs)))

if __name__ == '__main__':
    # deleteInf()
    opt = get_args(is_train=False)

    """ vocab / character number configuration """
    if opt.sensitive:
        opt.character = string.printable[:-6]  # same as the ASTER setting (94 characters).

    cudnn.benchmark = True
    cudnn.deterministic = True
    opt.num_gpu = torch.cuda.device_count()

    main(opt) # main() is assumed to be defined elsewhere in this script's captum test driver; it is not part of this section