import numpy as np
import h5py
import os

import mercury as mr

import sys
sys.path.append('/plot_scripts/')
from map_packages_colors_all import *
from plot_scripts_all import *

# Package identifiers handled by this script. The order matters: it fixes the
# row order of the matrices built below and the y-axis labels of the plots.
package_str = [
    'qiskit',
    'cirq',
    'qsimcirq',
    'pennylane',
    'pennylane_l',
    'qibo',
    'qibojit',
    'yao',
    'quest',
    'qulacs',
    'intel_qs_cpp',
    'projectq',
    'svsim',
    'hybridq',
    'hiq',
    'qcgpu',
    'qrack_sch',
    'cuquantum_qiskit',
    'cuquantum_qsimcirq',
    'qpanda',
    'qpp',
    'myqlm',
    'myqlm_cpp',
    'braket',
]


def _build_data_mat(task, cc, pr_1, pr_2, _n_arr):
    """Build the per-package matrix of ``pr_1`` / ``pr_2`` value ratios.

    For every package in ``package_str`` the files
    ``<cwd>/data/<task>/<package>_<cc>_<pr_1>.h5`` and
    ``<cwd>/data/<task>/<package>_<cc>_<pr_2>.h5`` are looked up. When both
    exist, the dataset named ``storage_dict[package]`` is read from each and
    the two are divided element-wise (``pr_1`` over ``pr_2``).

    Parameters
    ----------
    task, cc, pr_1, pr_2 : str
        Path components used verbatim to build the two file names.
    _n_arr : sequence
        Only its length is used; it sets the number of matrix columns.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(len(package_str), len(_n_arr))``. Row ``i`` belongs
        to ``package_str[i]``. Cells without data (missing files, or columns
        beyond the shorter of the two datasets) are ``-inf``.
    """
    cwd = os.getcwd()
    n_cols = len(_n_arr)

    # -inf marks "no data". np.full yields the same values as the former
    # np.log(np.zeros(...)) without the divide-by-zero RuntimeWarning.
    data_mat = np.full((len(package_str), n_cols), -np.inf)

    for p_i, pack in enumerate(package_str):

        path_pr1 = cwd + '/data/{}/{}_{}_{}.h5'.format(task, pack, cc, pr_1)
        path_pr2 = cwd + '/data/{}/{}_{}_{}.h5'.format(task, pack, cc, pr_2)

        # A ratio needs both files; otherwise the row keeps its placeholder.
        if not (os.path.isfile(path_pr1) and os.path.isfile(path_pr2)):
            continue

        # Context managers close the files even if reading the dataset fails.
        # NOTE(review): assumes the datasets are 1-D numeric arrays - confirm.
        with h5py.File(path_pr1, 'r') as h5f_pr1:
            vals_pr1 = np.asarray(h5f_pr1[storage_dict[pack]][:], dtype=float)
        with h5py.File(path_pr2, 'r') as h5f_pr2:
            vals_pr2 = np.asarray(h5f_pr2[storage_dict[pack]][:], dtype=float)

        # Only the leading entries present in both datasets AND fitting into
        # the matrix are used; everything else stays -inf.
        n_used = min(len(vals_pr1), len(vals_pr2), n_cols)
        data_mat[p_i, :n_used] = vals_pr1[:n_used] / vals_pr2[:n_used]

    return data_mat

def abs_time_pack(task, cc, N_end, pr_1, pr_2):
    """Show a heat map of the ``pr_1`` / ``pr_2`` ratios for every package.

    Rows are the packages in ``package_str``; columns follow
    ``np.arange(start, N_end, 2)`` with ``start = 6`` for ``hdyn``/``qft`` and
    ``start = 12`` for ``rqc``. The matrix itself comes from
    ``_build_data_mat``; the colour scale is fixed to ``[-1, 10]``.

    Parameters
    ----------
    task : str
        ``"Heisenberg dynamics"``, ``"Random Quantum Circuit"`` or
        ``"Quantum Fourier Transform"`` (or directly ``"hdyn"``, ``"rqc"``,
        ``"qft"``).
    cc : str
        ``"Singlethread"``, ``"Multithread"`` or ``"GPU"``; any other value is
        passed through unchanged into the data file names.
    N_end : int
        Exclusive upper bound of the column range.
    pr_1, pr_2 : str
        ``"Single"`` or ``"Double"``; any other value is passed through
        unchanged into the data file names.

    Raises
    ------
    ValueError
        If ``task`` is not one of the recognised tasks.
    """
    # Translate the human-readable labels into the short codes used in the
    # data file names; unknown values are passed through untouched.
    task_codes = {
        "Heisenberg dynamics": "hdyn",
        "Random Quantum Circuit": "rqc",
        "Quantum Fourier Transform": "qft",
    }
    cc_codes = {
        "Singlethread": 'singlethread',
        "Multithread": 'multithread',
        "GPU": 'gpu',
    }
    pr_codes = {"Single": "sp", "Double": "dp"}

    task = task_codes.get(task, task)
    cc = cc_codes.get(cc, cc)
    pr_1 = pr_codes.get(pr_1, pr_1)
    pr_2 = pr_codes.get(pr_2, pr_2)

    # Validate before creating a figure so that a bad task neither leaves an
    # empty figure behind nor fails later with an unbound N_arr.
    if task == 'hdyn' or task == 'qft':
        n_start = 6
    elif task == 'rqc':
        n_start = 12
    else:
        raise ValueError(
            "unknown task {!r}; expected one of 'hdyn', 'qft', 'rqc' "
            "or their long names".format(task))
    N_arr = np.arange(n_start, N_end, 2)

    fig, ax = plt.subplots()

    data_mat = _build_data_mat(task, cc, pr_1, pr_2, N_arr)

    plt.imshow(data_mat, cmap='gist_heat_r', vmin=-1., vmax=10)

    # One tick per matrix row; rows are ordered like package_str.
    plt.yticks(range(len(package_str)), package_str)
    plt.xticks(range(len(N_arr)), N_arr)

    # NOTE(review): these locators replace the fixed tick positions set by
    # plt.xticks above while the fixed label list stays in place, so labels
    # may no longer line up with their columns - confirm this is intended.
    ax.xaxis.set_major_locator(ticker.AutoLocator())
    ax.xaxis.set_minor_locator(ticker.AutoMinorLocator())

    plt.colorbar()
    plt.tight_layout()
    plt.show()

# abs_time_pack("Heisenberg dynamics", "Singlethread", 36, "Double", "Single")
# abs_time_pack("Random Quantum Circuit", "Multithread", 36, "Double", "Single")

def comp_time_pack(task_1, task_2, cc, N_end, pr_1, pr_2):
    """Show a heat map comparing the ``pr_1``/``pr_2`` ratios of two tasks.

    For each task a ratio matrix is built with ``_build_data_mat``. The two
    matrices are aligned on their common column range and divided
    element-wise (``task_1`` over ``task_2``). Rows are the packages in
    ``package_str``.

    Parameters
    ----------
    task_1, task_2 : str
        ``"Heisenberg dynamics"``, ``"Random Quantum Circuit"`` or
        ``"Quantum Fourier Transform"`` (or directly ``"hdyn"``, ``"rqc"``,
        ``"qft"``).
    cc : str
        ``"Singlethread"``, ``"Multithread"`` or ``"GPU"``; any other value is
        passed through unchanged into the data file names.
    N_end : int
        Exclusive upper bound of the column range of both tasks.
    pr_1, pr_2 : str
        ``"Single"`` or ``"Double"``; any other value is passed through
        unchanged into the data file names.

    Raises
    ------
    ValueError
        If ``task_1`` or ``task_2`` is not one of the recognised tasks.
    """
    # Translate the human-readable labels into the short codes used in the
    # data file names; unknown values are passed through untouched.
    task_codes = {
        "Heisenberg dynamics": "hdyn",
        "Random Quantum Circuit": "rqc",
        "Quantum Fourier Transform": "qft",
    }
    cc_codes = {
        "Singlethread": 'singlethread',
        "Multithread": 'multithread',
        "GPU": 'gpu',
    }
    pr_codes = {"Single": "sp", "Double": "dp"}

    task_1 = task_codes.get(task_1, task_1)
    task_2 = task_codes.get(task_2, task_2)
    cc = cc_codes.get(cc, cc)
    pr_1 = pr_codes.get(pr_1, pr_1)
    pr_2 = pr_codes.get(pr_2, pr_2)

    # First column value per task; every task advances in steps of ``step``.
    first_n = {'hdyn': 6, 'qft': 6, 'rqc': 12}
    step = 2

    # Validate before creating a figure so that a bad task neither leaves an
    # empty figure behind nor fails later with an unbound N_arr_*.
    for code in (task_1, task_2):
        if code not in first_n:
            raise ValueError(
                "unknown task {!r}; expected one of 'hdyn', 'qft', 'rqc' "
                "or their long names".format(code))

    start_1 = first_n[task_1]
    start_2 = first_n[task_2]
    N_arr_1 = np.arange(start_1, N_end, step)
    N_arr_2 = np.arange(start_2, N_end, step)

    fig, ax = plt.subplots()

    data_mat_1 = _build_data_mat(task_1, cc, pr_1, pr_2, N_arr_1)
    data_mat_2 = _build_data_mat(task_2, cc, pr_1, pr_2, N_arr_2)

    # Align both matrices on the common column range: drop the leading
    # columns of whichever task starts earlier (3 columns for 6 vs. 12).
    if start_1 > start_2:
        data_mat_2 = data_mat_2[:, (start_1 - start_2) // step:]
        N_common = N_arr_1
    elif start_1 < start_2:
        data_mat_1 = data_mat_1[:, (start_2 - start_1) // step:]
        N_common = N_arr_2
    else:
        N_common = N_arr_1

    # Element-wise task_1 / task_2. The -inf "no data" placeholders
    # propagate by IEEE rules (-inf/-inf -> nan, x/-inf -> -0.0,
    # -inf/x -> -inf); the resulting floating-point warnings are expected
    # and therefore silenced.
    with np.errstate(divide='ignore', invalid='ignore'):
        comp_data_mat = data_mat_1 / data_mat_2

    plt.imshow(comp_data_mat, cmap='gist_heat_r', vmin=-0.5)

    # One tick per matrix row; rows are ordered like package_str.
    plt.yticks(range(len(package_str)), package_str)
    plt.xticks(range(len(N_common)), N_common)

    # NOTE(review): these locators replace the fixed tick positions set by
    # plt.xticks above while the fixed label list stays in place, so labels
    # may no longer line up with their columns - confirm this is intended.
    ax.xaxis.set_major_locator(ticker.AutoLocator())
    ax.xaxis.set_minor_locator(ticker.AutoMinorLocator())

    plt.colorbar()
    plt.tight_layout()
    plt.show()

# comp_time_pack("Heisenberg dynamics", "Random Quantum Circuit", "Singlethread", 36, "Double", "Single")