import csv
import numpy as np

from rdkit import Chem
from rdkit.Chem import MolStandardize
from src import metrics
from src.delinker_utils import sascorer, calc_SC_RDKit
from tqdm import tqdm

from pdb import set_trace


def get_valid_as_in_delinker(data, progress=False):
    """Filter predictions using the DeLinker validity criterion.

    A prediction is valid when (1) its largest connected component, the
    true molecule and the input fragments all pass RDKit sanitization,
    and (2) the largest component contains the fragments as a
    substructure.

    Args:
        data: list of dicts with keys 'pred_mol', 'true_mol',
            'pred_mol_smi', 'true_mol_smi', 'frag_smi'.
        progress: if True, wrap iteration in a tqdm progress bar.

    Returns:
        List of dicts for the valid entries, with SMILES re-canonicalized
        from the sanitized molecules.
    """
    valid = []
    generator = tqdm(enumerate(data), total=len(data)) if progress else enumerate(data)
    for i, m in generator:
        pred_mol = Chem.MolFromSmiles(m['pred_mol_smi'], sanitize=False)
        true_mol = Chem.MolFromSmiles(m['true_mol_smi'], sanitize=False)
        frag = Chem.MolFromSmiles(m['frag_smi'], sanitize=False)

        # Keep only the largest connected component of the prediction
        # (generation may produce several disconnected pieces)
        pred_mol_frags = Chem.GetMolFrags(pred_mol, asMols=True, sanitizeFrags=False)
        pred_mol_filtered = max(pred_mol_frags, default=pred_mol, key=lambda mol: mol.GetNumAtoms())

        try:
            Chem.SanitizeMol(pred_mol_filtered)
            Chem.SanitizeMol(true_mol)
            Chem.SanitizeMol(frag)
        except Exception:
            # Sanitization failure => chemically invalid; skip the entry.
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
            # are no longer swallowed.
            continue

        if len(pred_mol_filtered.GetSubstructMatch(frag)) > 0:
            valid.append({
                'pred_mol': m['pred_mol'],
                'true_mol': m['true_mol'],
                'pred_mol_smi': Chem.MolToSmiles(pred_mol_filtered),
                'true_mol_smi': Chem.MolToSmiles(true_mol),
                'frag_smi': Chem.MolToSmiles(frag)
            })

    return valid


def extract_linker_smiles(molecule, fragments):
    """Delete the fragment atoms from `molecule` and return the linker SMILES.

    Args:
        molecule: sanitized RDKit molecule (full compound).
        fragments: RDKit molecule with the input fragments to remove.

    Returns:
        SMILES string of the remaining linker, stereochemistry removed and
        tautomer-canonicalized when possible (may be empty if no atoms remain).
    """
    match = molecule.GetSubstructMatch(fragments)
    editable = Chem.EditableMol(molecule)
    # Remove from the highest atom index down so earlier indices stay valid
    for atom_id in sorted(match, reverse=True):
        editable.RemoveAtom(atom_id)
    linker = editable.GetMol()
    Chem.RemoveStereochemistry(linker)
    try:
        return MolStandardize.canonicalize_tautomer_smiles(Chem.MolToSmiles(linker))
    except Exception:
        # Tautomer canonicalization can fail on unusual linkers; fall back to
        # the plain canonical SMILES. Narrowed from a bare `except:`.
        return Chem.MolToSmiles(linker)


def compute_and_add_linker_smiles(data, progress=False):
    """Return a copy of `data` with 'pred_linker'/'true_linker' SMILES added.

    Each record must carry 'pred_mol_smi', 'true_mol_smi' and 'frag_smi';
    linkers are obtained by removing the fragment atoms from each molecule.
    """
    results = []
    iterator = tqdm(data) if progress else data
    for record in iterator:
        predicted = Chem.MolFromSmiles(record['pred_mol_smi'], sanitize=True)
        reference = Chem.MolFromSmiles(record['true_mol_smi'], sanitize=True)
        fragments = Chem.MolFromSmiles(record['frag_smi'], sanitize=True)

        enriched = dict(record)
        enriched['pred_linker'] = extract_linker_smiles(predicted, fragments)
        enriched['true_linker'] = extract_linker_smiles(reference, fragments)
        results.append(enriched)

    return results


def compute_uniqueness(data, progress=False):
    """Fraction of distinct predictions per (true molecule, fragments) pair.

    Predictions are grouped by the 'true_mol_smi'.'frag_smi' key; the result
    is (sum of unique predictions per group) / (total predictions).
    """
    grouped = {}
    iterator = tqdm(data) if progress else data
    for record in iterator:
        group_key = '{}.{}'.format(record['true_mol_smi'], record['frag_smi'])
        grouped.setdefault(group_key, []).append(record['pred_mol_smi'])

    total_count = sum(len(preds) for preds in grouped.values())
    unique_count = sum(len(set(preds)) for preds in grouped.values())

    return unique_count / total_count


def compute_novelty(data, progress=False):
    """Fraction of predicted linkers that never occur as a true linker.

    Args:
        data: list of dicts with 'pred_linker' and 'true_linker' keys.
        progress: if True, wrap iteration in a tqdm progress bar.

    Returns:
        Number of novel predicted linkers divided by len(data).
    """
    true_linkers = {m['true_linker'] for m in data}
    generator = tqdm(data) if progress else data
    # Counting via sum() of booleans replaces the former if/continue/else
    novel = sum(m['pred_linker'] not in true_linkers for m in generator)
    return novel / len(data)


def compute_recovery_rate(data, progress=False):
    """Fraction of (true molecule, true linker) pairs exactly recovered.

    A pair counts as recovered when at least one prediction, after removing
    stereochemistry and explicit hydrogens, matches the true molecule's
    canonical SMILES.
    """
    def normalize(smiles):
        # Canonical SMILES with stereochemistry and explicit Hs stripped
        mol = Chem.MolFromSmiles(smiles, sanitize=True)
        Chem.RemoveStereochemistry(mol)
        return Chem.MolToSmiles(Chem.RemoveHs(mol))

    all_pairs = set()
    recovered_pairs = set()
    iterator = tqdm(data) if progress else data
    for record in iterator:
        pred_smi = normalize(record['pred_mol_smi'])
        true_smi = normalize(record['true_mol_smi'])

        pair_key = '{}.{}'.format(true_smi, record['true_linker'])
        all_pairs.add(pair_key)
        if pred_smi == true_smi:
            recovered_pairs.add(pair_key)

    return len(recovered_pairs) / len(all_pairs)


def calc_sa_score_mol(mol):
    """Synthetic accessibility score of `mol`, or None when `mol` is None."""
    return None if mol is None else sascorer.calculateScore(mol)


def check_ring_filter(linker):
    """Ring filter from DeLinker: reject linkers with a double bond in a ring.

    Args:
        linker: sanitized RDKit molecule of the linker.

    Returns:
        False if any ring of the linker contains a (non-aromatic) double
        bond with both endpoints inside that ring, True otherwise.
    """
    # Iterate the symmetrized smallest set of smallest rings
    for ring in Chem.GetSymmSSSR(linker):
        ring_atoms = set(ring)
        for atom_idx in ring_atoms:
            for bond in linker.GetAtomWithIdx(atom_idx).GetBonds():
                # Explicit enum instead of the former magic constant 2;
                # aromatic bonds have their own type and are not flagged.
                if (bond.GetBondType() == Chem.BondType.DOUBLE
                        and bond.GetBeginAtomIdx() in ring_atoms
                        and bond.GetEndAtomIdx() in ring_atoms):
                    # Early exit: one offending bond is enough to fail
                    return False
    return True


def check_pains(mol, pains_smarts):
    """Return True when `mol` matches none of the PAINS SMARTS alerts."""
    return not any(mol.HasSubstructMatch(alert) for alert in pains_smarts)


def calc_2d_filters(toks, pains_smarts):
    """Evaluate the three DeLinker 2D filters for one prediction record.

    Args:
        toks: dict with 'pred_mol_smi', 'frag_smi' and 'pred_linker'.
        pains_smarts: list of compiled PAINS SMARTS molecules.

    Returns:
        [sa_ok, ring_ok, pains_ok]; all False when the prediction does not
        contain the input fragments. Each check that raises is reported and
        left False.
    """
    pred_mol = Chem.MolFromSmiles(toks['pred_mol_smi'])
    frag = Chem.MolFromSmiles(toks['frag_smi'])
    linker = Chem.MolFromSmiles(toks['pred_linker'])

    # Guard clause: without the fragments present, every filter fails
    if len(pred_mol.GetSubstructMatch(frag)) == 0:
        return [False, False, False]

    sa_score = False
    ra_score = False
    pains_score = False

    try:
        # Prediction should not be harder to synthesize than the fragments
        sa_score = calc_sa_score_mol(pred_mol) < calc_sa_score_mol(frag)
    except Exception as e:
        print(f'Could not compute SA score: {e}')
    try:
        ra_score = check_ring_filter(linker)
    except Exception as e:
        print(f'Could not compute RA score: {e}')
    try:
        pains_score = check_pains(pred_mol, pains_smarts)
    except Exception as e:
        print(f'Could not compute PAINS score: {e}')

    return [sa_score, ra_score, pains_score]


def calc_filters_2d_dataset(data):
    """Fractions of predictions passing the DeLinker 2D filters.

    Loads the WEHI PAINS alerts from 'models/wehi_pains.csv' (SMARTS in the
    first CSV column), then scores every record with `calc_2d_filters`.

    Args:
        data: list of dicts with 'pred_mol_smi', 'frag_smi', 'pred_linker'.

    Returns:
        Tuple (pass_all, pass_SA, pass_RA, pass_PAINS) of fractions in [0, 1].
    """
    with open('models/wehi_pains.csv', 'r') as f:
        # `if row` skips blank lines, which would raise IndexError on row[0]
        pains_smarts = [Chem.MolFromSmarts(row[0], mergeHs=True) for row in csv.reader(f) if row]

    pass_all = pass_SA = pass_RA = pass_PAINS = 0
    for m in data:
        sa_ok, ra_ok, pains_ok = calc_2d_filters(m, pains_smarts)
        # Booleans count as 0/1; `and` replaces the former bitwise `&`
        pass_all += sa_ok and ra_ok and pains_ok
        pass_SA += sa_ok
        pass_RA += ra_ok
        pass_PAINS += pains_ok

    return pass_all / len(data), pass_SA / len(data), pass_RA / len(data), pass_PAINS / len(data)


def calc_sc_rdkit_full_mol(gen_mol, ref_mol):
    """3D shape-and-color (SC_RDKit) similarity of a generated molecule.

    Args:
        gen_mol: generated RDKit molecule with a 3D conformer.
        ref_mol: reference RDKit molecule with a 3D conformer.

    Returns:
        The SC_RDKit score, or the sentinel -0.5 when scoring fails
        (e.g. alignment error). Narrowed from a bare `except:` so
        KeyboardInterrupt/SystemExit propagate.
    """
    try:
        return calc_SC_RDKit.calc_SC_RDKit_score(gen_mol, ref_mol)
    except Exception:
        return -0.5


def sc_rdkit_score(data):
    """Mean SC_RDKit 3D similarity over all prediction records."""
    return np.mean([
        calc_sc_rdkit_full_mol(record['pred_mol'], record['true_mol'])
        for record in data
    ])


def get_delinker_metrics(pred_molecules, true_molecules, true_fragments):
    """Compute the full DeLinker metric suite for a batch of predictions.

    Args:
        pred_molecules: generated RDKit molecules.
        true_molecules: ground-truth molecules, aligned with predictions.
        true_fragments: input fragment molecules, aligned with predictions.

    Returns:
        Dict mapping 'DeLinker/...' metric names to values. All zeros when
        there are no predictions or no valid predictions.
    """
    default_values = {
        'DeLinker/validity': 0,
        'DeLinker/uniqueness': 0,
        'DeLinker/novelty': 0,
        'DeLinker/recovery': 0,
        'DeLinker/2D_filters': 0,
        'DeLinker/2D_filters_SA': 0,
        'DeLinker/2D_filters_RA': 0,
        'DeLinker/2D_filters_PAINS': 0,
        'DeLinker/SC_RDKit': 0,
    }
    if len(pred_molecules) == 0:
        return default_values

    data = [
        {
            'pred_mol': pred,
            'true_mol': true,
            'pred_mol_smi': Chem.MolToSmiles(pred),
            'true_mol_smi': Chem.MolToSmiles(true),
            'frag_smi': Chem.MolToSmiles(frag),
        }
        for pred, true, frag in zip(pred_molecules, true_molecules, true_fragments)
    ]

    # Validity according to the DeLinker paper: sanitizable, and the biggest
    # fragment of the prediction contains the input fragments
    valid_data = get_valid_as_in_delinker(data)
    validity_as_in_delinker = len(valid_data) / len(data)
    if len(valid_data) == 0:
        return default_values

    # Linker SMILES are required by the novelty and 2D-filter metrics
    valid_data = compute_and_add_linker_smiles(valid_data)

    uniqueness = compute_uniqueness(valid_data)
    novelty = compute_novelty(valid_data)
    recovery_rate = compute_recovery_rate(valid_data)

    # 2D filters (SA, ring, PAINS) and the 3D shape/color score
    pass_all, pass_SA, pass_RA, pass_PAINS = calc_filters_2d_dataset(valid_data)
    sc_rdkit = sc_rdkit_score(valid_data)

    return {
        'DeLinker/validity': validity_as_in_delinker,
        'DeLinker/uniqueness': uniqueness,
        'DeLinker/novelty': novelty,
        'DeLinker/recovery': recovery_rate,
        'DeLinker/2D_filters': pass_all,
        'DeLinker/2D_filters_SA': pass_SA,
        'DeLinker/2D_filters_RA': pass_RA,
        'DeLinker/2D_filters_PAINS': pass_PAINS,
        'DeLinker/SC_RDKit': sc_rdkit,
    }