import os
import re
import html
import string
import torch
import config
import unicodedata
from nltk.tokenize import word_tokenize

from dataset import XRayDataset
from model import EncoderDecoderNet
from torch.utils.data import Subset
from sklearn.model_selection import train_test_split as sklearn_train_test_split


def load_dataset(raw_caption=False):
    """Build the X-ray dataset from the paths and settings in ``config``.

    The dataset is rooted at ``config.DATASET_PATH``. It uses
    ``config.basic_transforms`` as its transform and
    ``config.VOCAB_THRESHOLD`` as the vocabulary frequency threshold.

    Args:
        raw_caption: forwarded unchanged to ``XRayDataset``. Presumably
            selects un-tokenized captions; confirm against ``dataset.py``.

    Returns:
        The constructed ``XRayDataset`` instance.
    """
    dataset = XRayDataset(
        root=config.DATASET_PATH,
        transform=config.basic_transforms,
        freq_threshold=config.VOCAB_THRESHOLD,
        raw_caption=raw_caption,
    )
    return dataset


def get_model_instance(vocabulary, encoder_checkpoint='./weights/chexnet.pth.tar'):
    """Create an ``EncoderDecoderNet`` and move it to ``config.DEVICE``.

    Layer sizes come from ``config.FEATURES_SIZE``, ``config.EMBED_SIZE``
    and ``config.HIDDEN_SIZE``.

    Args:
        vocabulary: vocabulary object forwarded to the model unchanged.
        encoder_checkpoint: path of the encoder weights file passed to the
            model. It defaults to the previously hard-coded location, so
            existing callers behave the same. Callers can now point it
            elsewhere, e.g. when running from a different working directory.

    Returns:
        The model, already placed on ``config.DEVICE``.
    """
    model = EncoderDecoderNet(
        features_size=config.FEATURES_SIZE,
        embed_size=config.EMBED_SIZE,
        hidden_size=config.HIDDEN_SIZE,
        vocabulary=vocabulary,
        encoder_checkpoint=encoder_checkpoint,
    )
    # nn.Module.to moves parameters in place and returns the module.
    model = model.to(config.DEVICE)

    return model

def train_test_split(dataset, test_size=0.25, random_state=44):
    """Split ``dataset`` into train/test ``Subset`` views by index.

    The split is computed over the integer positions ``0..len(dataset)-1``
    using scikit-learn's ``train_test_split``. It is seeded with
    ``random_state``, so the split is reproducible. The dataset itself is
    not copied.

    Args:
        dataset: any sized, indexable dataset.
        test_size: forwarded to scikit-learn (fraction or absolute count).
        random_state: seed forwarded to scikit-learn.

    Returns:
        A ``(train_subset, test_subset)`` tuple of ``torch.utils.data.Subset``.
    """
    all_indices = list(range(len(dataset)))
    train_indices, test_indices = sklearn_train_test_split(
        all_indices,
        test_size=test_size,
        random_state=random_state,
    )
    train_subset = Subset(dataset, train_indices)
    test_subset = Subset(dataset, test_indices)
    return train_subset, test_subset


def save_checkpoint(checkpoint):
    """Serialize ``checkpoint`` with ``torch.save`` to ``config.CHECKPOINT_FILE``.

    Any existing file at that path is overwritten.
    """
    print('=> Saving checkpoint')
    target_path = config.CHECKPOINT_FILE
    torch.save(checkpoint, target_path)


def load_checkpoint(model, optimizer=None):
    """Restore model (and optionally optimizer) state from ``config.CHECKPOINT_FILE``.

    The checkpoint is expected to be a mapping with ``'state_dict'`` and
    ``'epoch'`` keys. It must also contain ``'optimizer'`` when an optimizer
    is supplied. Tensors are mapped onto the CPU while loading;
    ``load_state_dict`` then copies them into the model's existing
    parameters.

    Args:
        model: module whose state is restored.
        optimizer: optional optimizer whose state is restored as well.

    Returns:
        The value stored under ``'epoch'`` in the checkpoint.
    """
    print('=> Loading checkpoint')

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    cpu_device = torch.device('cpu')
    saved = torch.load(config.CHECKPOINT_FILE, map_location=cpu_device)
    model.load_state_dict(saved['state_dict'])

    if optimizer is None:
        return saved['epoch']

    optimizer.load_state_dict(saved['optimizer'])
    return saved['epoch']


def can_load_checkpoint():
    """Report whether a checkpoint should be resumed from.

    Returns a falsy value when ``config.CHECKPOINT_FILE`` does not exist.
    Otherwise it returns ``config.LOAD_MODEL`` as-is.
    """
    if not os.path.exists(config.CHECKPOINT_FILE):
        return False
    return config.LOAD_MODEL


def remove_special_chars(text):
    """Lower-case ``text`` and clean up common HTML/corpus escape artifacts.

    The substitutions run in the listed order, since later ones see the
    output of earlier ones. Afterwards any remaining HTML entities are
    unescaped, and each run of two or more spaces is collapsed to one space.
    """
    substitutions = (
        ('#39;', "'"),
        ('amp;', '&'),
        ('#146;', "'"),
        ('nbsp;', ' '),
        ('#36;', '$'),
        ('\\n', "\n"),
        ('quot;', "'"),
        ('<br />', "\n"),
        ('\\"', '"'),
        ('<unk>', 'u_n'),
        (' @.@ ', '.'),
        (' @-@ ', '-'),
        # Must stay last: pads every remaining backslash with spaces.
        ('\\', ' \\ '),
    )

    cleaned = text.lower()
    for needle, replacement in substitutions:
        cleaned = cleaned.replace(needle, replacement)

    return re.sub(r'  +', ' ', html.unescape(cleaned))


def remove_non_ascii(text):
    """Strip non-ASCII characters after NFKD compatibility decomposition.

    Decomposing first turns e.g. accented letters into base letter plus
    combining mark, so ``'é'`` becomes ``'e'`` rather than disappearing.
    """
    decomposed = unicodedata.normalize('NFKD', text)
    ascii_only = decomposed.encode('ascii', errors='ignore')
    return ascii_only.decode('utf-8', errors='ignore')


def to_lowercase(text):
    """Return ``text`` converted to lower case."""
    lowered = str.lower(text)
    return lowered


def remove_punctuation(text):
    """Delete every ASCII punctuation character (``string.punctuation``) from ``text``."""
    # Mapping each punctuation character to None deletes it in str.translate.
    deletion_table = str.maketrans(dict.fromkeys(string.punctuation))
    return text.translate(deletion_table)


def replace_numbers(text):
    """Remove every run of digits from ``text``.

    Despite the name, digits are deleted, not substituted.
    """
    digit_runs = re.compile(r'\d+')
    return digit_runs.sub('', text)


def text2words(text):
    """Split ``text`` into a list of tokens using NLTK's ``word_tokenize``."""
    tokens = word_tokenize(text)
    return tokens


def normalize_text(text):
    """Run ``text`` through the full cleaning pipeline and return the result.

    The steps run in this order: special-character cleanup, non-ASCII
    removal, punctuation removal, lower-casing, then digit removal.
    Lower-casing is repeated after the NFKD step because compatibility
    decomposition can introduce upper-case ASCII letters.
    """
    pipeline = (
        remove_special_chars,
        remove_non_ascii,
        remove_punctuation,
        to_lowercase,
        replace_numbers,
    )
    for step in pipeline:
        text = step(text)
    return text


def normalize_corpus(corpus):
    """Apply ``normalize_text`` to each document in ``corpus`` and return a new list."""
    return list(map(normalize_text, corpus))