import os
import string
import shutil
from itertools import permutations, chain
from collections import defaultdict
from tqdm import tqdm
import sys

# Language codes of the Indic languages this script processes; "en" is added
# separately in get_src_tgt_lang_lists().
INDIC_LANGS = ["as", "bn", "gu", "hi", "kn", "ml", "mr", "or", "pa", "ta", "te"]
# Training data is checked for overlaps against every benchmark found in the
# devtest directory (the directory is listed at runtime), for example:
# benchmarks = ['wat2021-devtest', 'wat2020-devtest', 'wat-2018', 'wmt-news', 'ufal-ta', 'pmi']


def read_lines(path):
    """Return the lines of the text file at ``path``, or ``[]`` if it is missing.

    Lines keep their trailing newline characters, exactly as produced by
    ``file.readlines()``.

    Args:
        path: Location of the file to read.

    Returns:
        A list of lines; an empty list when ``path`` does not exist.
    """
    # A missing file is not an error: callers treat an empty list as
    # "no data available for this file".
    if not os.path.exists(path):
        return []
    # Decode explicitly as UTF-8 so non-ASCII text is read the same way on
    # every platform instead of depending on the locale's default encoding.
    with open(path, "r", encoding="utf-8") as f:
        return f.readlines()


def create_txt(outFile, lines):
    """Write ``lines`` to ``outFile``, one entry per line, replacing the file.

    Every entry is written followed by exactly one line terminator: entries
    that already end with a newline are written unchanged, all others get a
    newline appended. The decision is made per entry, so a single
    unterminated entry (for example the last line of a file that lacks a
    trailing newline) can never be merged with the entry that follows it.

    Args:
        outFile: Path of the file to (over)write.
        lines: Iterable of strings; may be empty, in which case an empty
            file is written.
    """
    # The context manager closes the file even if a write fails; UTF-8 is
    # set explicitly so output does not depend on the platform default.
    with open("{0}".format(outFile), "w", encoding="utf-8") as outfile:
        outfile.writelines(
            line if line.endswith("\n") else line + "\n" for line in lines
        )


def pair_dedup_files(src_file, tgt_file):
    """Drop duplicate (source, target) line pairs from two parallel files.

    Both files are read, deduplicated as pairs via ``pair_dedup_lists`` and
    then overwritten in place. The number of dropped pairs, measured against
    the line count of ``src_file``, is printed to stdout.

    Args:
        src_file: Path of the source-side file.
        tgt_file: Path of the target-side file.
    """
    source_lines = read_lines(src_file)
    original_count = len(source_lines)

    unique_src, unique_tgt = pair_dedup_lists(source_lines, read_lines(tgt_file))

    num_duplicates = original_count - len(unique_src)
    print(f"Dropped duplicate pairs in {src_file} Num duplicates -> {num_duplicates}")

    # Overwrite both sides so they stay aligned line by line.
    create_txt(src_file, unique_src)
    create_txt(tgt_file, unique_tgt)


def pair_dedup_lists(src_list, tgt_list):
    """Remove duplicate (source, target) pairs from two parallel sequences.

    Two positions are duplicates only when both the source and the target
    entries are equal. The first occurrence of each pair is kept and the
    original relative order is preserved, so the result is deterministic
    across runs.

    Args:
        src_list: Source-side entries.
        tgt_list: Target-side entries, aligned with ``src_list``. If the
            lengths differ, the extra entries of the longer one are ignored
            (``zip`` semantics).

    Returns:
        A ``(src_deduped, tgt_deduped)`` pair of tuples of equal length;
        two empty tuples when there is nothing to pair up.
    """
    # dict keys keep insertion order, which gives an order-preserving "set".
    unique_pairs = list(dict.fromkeys(zip(src_list, tgt_list)))
    if not unique_pairs:
        # zip(*[]) yields nothing, so the unpacking below would fail.
        return (), ()
    src_deduped, tgt_deduped = zip(*unique_pairs)
    return src_deduped, tgt_deduped


# Translation table that deletes ASCII punctuation and the Devanagari danda
# (U+0964). Built once at import time instead of on every call;
# str.translate with such a table is far faster than filtering characters
# one by one in Python.
_PUNCTUATION_TABLE = str.maketrans("", "", string.punctuation + "\u0964")


def strip_and_normalize(line):
    """Return a canonical form of ``line`` for exact-match comparison.

    The line is lowercased and all whitespace and punctuation are removed.
    Whitespace includes the trailing newline, so the same sentence
    normalizes identically whether or not it ends with ``"\\n"`` (as happens
    for the last line of a file without a final newline).

    Args:
        line: The text to normalize.

    Returns:
        The normalized string; may be empty.
    """
    # str.split() without arguments splits on any run of whitespace, so
    # joining the pieces drops spaces, tabs and line terminators alike.
    without_whitespace = "".join(line.split())
    return without_whitespace.lower().translate(_PUNCTUATION_TABLE)


def expand_tupled_list(list_of_tuples):
    """Split a sequence of pairs into two parallel lists.

    Example:
        [(en, as), (as, bn), (bn, gu)] -> [en, as, bn], [as, bn, gu]

    Args:
        list_of_tuples: Iterable of 2-item tuples.

    Returns:
        ``(list_a, list_b)`` holding the first and second items respectively;
        two empty lists when the input is empty.

    Raises:
        ValueError: If the items do not transpose into exactly two columns.
    """
    # https://stackoverflow.com/questions/8081545/how-to-convert-list-of-tuples-to-multiple-lists
    columns = [list(column) for column in zip(*list_of_tuples)]
    if not columns:
        # Transposing an empty input yields no columns at all, which would
        # make the two-way unpacking below fail.
        return [], []
    list_a, list_b = columns
    return list_a, list_b


def get_src_tgt_lang_lists(many2many=False):
    """Return the source and target language lists for the chosen mode.

    Args:
        many2many: When this is the literal ``False`` (the default), the
            source side is English only and the targets are ``INDIC_LANGS``.
            Any other value selects many-to-many mode, in which both sides
            are ``INDIC_LANGS`` plus ``"en"``.

    Returns:
        ``(SRC_LANGS, TGT_LANGS)``. In many-to-many mode both names refer to
        the same list object, and callers are expected to skip pairs whose
        source and target are equal.
    """
    if many2many is not False:
        every_language = INDIC_LANGS + ["en"]
        return every_language, every_language

    return ["en"], INDIC_LANGS


def normalize_and_gather_all_benchmarks(devtest_dir, many2many=False):
    """Collect the normalized dev+test lines of every benchmark per language pair.

    Each entry of ``devtest_dir`` is treated as one benchmark. For every
    language pair, its dev and test files are read, normalized with
    ``strip_and_normalize`` and appended to that pair's pool. Once all
    benchmarks are merged, each pool is deduplicated as (src, tgt) pairs.

    Args:
        devtest_dir: Directory containing one sub-directory per benchmark.
        many2many: Forwarded to ``get_src_tgt_lang_lists`` to choose the
            language pairs.

    Returns:
        A dict of dicts keyed first by ``"<src>-<tgt>"`` and then by
        ``"src"`` / ``"tgt"``; e.g. ``result["en-as"]["src"]`` holds the
        normalized English lines and ``result["en-as"]["tgt"]`` the
        normalized Assamese lines.
    """
    gathered = defaultdict(lambda: defaultdict(list))
    SRC_LANGS, TGT_LANGS = get_src_tgt_lang_lists(many2many)
    lang_pairs = [
        (src_lang, tgt_lang)
        for src_lang in SRC_LANGS
        for tgt_lang in TGT_LANGS
        if src_lang != tgt_lang
    ]

    for dataset in os.listdir(devtest_dir):
        for src_lang, tgt_lang in lang_pairs:
            pair = f"{src_lang}-{tgt_lang}"
            if dataset == "wat2021-devtest":
                # wat2021 keeps all languages in one flat folder
                folder = f"{devtest_dir}/{dataset}"
            else:
                # every other benchmark has one sub-folder per language pair
                folder = f"{devtest_dir}/{dataset}/{pair}"

            src_dev = read_lines(f"{folder}/dev.{src_lang}")
            tgt_dev = read_lines(f"{folder}/dev.{tgt_lang}")
            src_test = read_lines(f"{folder}/test.{src_lang}")
            tgt_test = read_lines(f"{folder}/test.{tgt_lang}")

            # read_lines gives [] for a missing file, so an empty target
            # side means this benchmark has no data for the pair
            if not (tgt_test and tgt_dev):
                continue

            # dev and test are pooled together before normalizing
            src_normalized = [strip_and_normalize(line) for line in src_dev + src_test]
            tgt_normalized = [strip_and_normalize(line) for line in tgt_dev + tgt_test]

            gathered[pair]["src"].extend(src_normalized)
            gathered[pair]["tgt"].extend(tgt_normalized)

    # dedup the merged benchmark data, pair by pair
    for src_lang, tgt_lang in lang_pairs:
        pool = gathered[f"{src_lang}-{tgt_lang}"]
        src_side, tgt_side = pool["src"], pool["tgt"]
        # nothing gathered for this pair
        if not src_side or not tgt_side:
            continue
        pool["src"], pool["tgt"] = pair_dedup_lists(src_side, tgt_side)

    return gathered


def remove_train_devtest_overlaps(train_dir, devtest_dir, many2many=False):
    """Remove from the training files every pair that overlaps a benchmark.

    For each language pair, ``{train_dir}/<src>-<tgt>/train.<src>`` and
    ``train.<tgt>`` are read and compared, after ``strip_and_normalize``,
    against the normalized dev/test lines gathered from ``devtest_dir``.
    A training pair is dropped when its source side OR its target side
    matches an overlapping sentence. The filtered files are written back
    in place.

    Overlaps accumulate across language pairs: a sentence found to overlap
    while processing an earlier pair is also removed from later pairs.

    Args:
        train_dir: Directory with one ``<src>-<tgt>`` sub-directory per pair.
        devtest_dir: Directory containing all benchmark sets.
        many2many: When falsy, the source side is compared against the
            source sentences of all benchmarks and pairs combined; when
            truthy, only against the devtest data of the same pair.
    """
    devtest_pairs_normalized = normalize_and_gather_all_benchmarks(
        devtest_dir, many2many
    )
    SRC_LANGS, TGT_LANGS = get_src_tgt_lang_lists(many2many)

    if not many2many:
        # all normalized source sentences across every benchmark and pair
        all_src_sentences_normalized = set()
        for pair_data in devtest_pairs_normalized.values():
            all_src_sentences_normalized.update(pair_data["src"])
    else:
        all_src_sentences_normalized = None

    # Cumulative overlap sets, shared by all pairs. Sets give O(1) lookups
    # and avoid re-hashing an ever-growing list for every language pair.
    src_overlaps = set()
    tgt_overlaps = set()

    for src_lang in SRC_LANGS:
        for tgt_lang in TGT_LANGS:
            if src_lang == tgt_lang:
                continue

            pair = f"{src_lang}-{tgt_lang}"
            src_train = read_lines(f"{train_dir}/{pair}/train.{src_lang}")
            tgt_train = read_lines(f"{train_dir}/{pair}/train.{tgt_lang}")

            len_before = len(src_train)
            if len_before == 0:
                # no training data for this pair
                continue
            if len(tgt_train) != len_before:
                # zip() below stops at the shorter file, so the surplus
                # lines of the longer one are not written back.
                print(
                    f"WARNING: train files for {pair} differ in length "
                    f"({len_before} vs {len(tgt_train)}); unpaired lines will be dropped"
                )

            src_train_normalized = [strip_and_normalize(line) for line in src_train]
            tgt_train_normalized = [strip_and_normalize(line) for line in tgt_train]

            if all_src_sentences_normalized:
                src_devtest_normalized = all_src_sentences_normalized
            else:
                src_devtest_normalized = devtest_pairs_normalized[pair]["src"]
            tgt_devtest_normalized = devtest_pairs_normalized[pair]["tgt"]

            # compute all src and tgt super strict overlaps for this pair
            src_overlaps.update(
                set(src_train_normalized).intersection(src_devtest_normalized)
            )
            tgt_overlaps.update(
                set(tgt_train_normalized).intersection(tgt_devtest_normalized)
            )

            # keep only the pairs whose both sides are overlap-free
            new_src_train = []
            new_tgt_train = []
            for src_line, tgt_line, src_line_norm, tgt_line_norm in tqdm(
                zip(src_train, tgt_train, src_train_normalized, tgt_train_normalized),
                total=len_before,
            ):
                if src_line_norm in src_overlaps or tgt_line_norm in tgt_overlaps:
                    continue
                new_src_train.append(src_line)
                new_tgt_train.append(tgt_line)

            len_after = len(new_src_train)
            print(
                f"Detected overlaps between train and devetest for {pair} is {len_before - len_after}"
            )
            print(f"saving new files at {train_dir}/{pair}/")
            create_txt(f"{train_dir}/{pair}/train.{src_lang}", new_src_train)
            create_txt(f"{train_dir}/{pair}/train.{tgt_lang}", new_tgt_train)


def parse_cli_args(argv):
    """Parse the command line into ``(train_dir, devtest_dir, many2many)``.

    Expected usage::

        script.py <train_data_dir> <devtest_data_dir> [many2many]

    Args:
        argv: The full argument vector, program name first (``sys.argv``).

    Returns:
        A tuple ``(train_data_dir, devtest_data_dir, many2many)`` where
        ``many2many`` is ``True`` only if the optional third argument equals
        ``"true"`` (case-insensitive) and ``False`` otherwise.

    Raises:
        SystemExit: With a usage message when the number of arguments is
            not two or three.
    """
    if len(argv) not in (3, 4):
        program = argv[0] if argv else "script"
        raise SystemExit(
            f"usage: {program} <train_data_dir> <devtest_data_dir> [many2many: true|false]"
        )
    # The optional flag is the fourth element of argv, i.e. index 3.
    many2many = len(argv) == 4 and argv[3].lower() == "true"
    return argv[1], argv[2], many2many


if __name__ == "__main__":
    # the benchmarks directory should contain all the test sets
    train_data_dir, devtest_data_dir, many2many = parse_cli_args(sys.argv)
    remove_train_devtest_overlaps(train_data_dir, devtest_data_dir, many2many)