# -*- coding: utf-8 -*-
import re
from os.path import join, abspath, dirname
from collections import defaultdict
import epitran

# Epitran transliterator for German in Latin script (code "deu-Latn-nar").
# fetch_words() uses it as the fallback for words that have no row in the
# SQLite dictionary. Built once, at import time.
epi = epitran.Epitran("deu-Latn-nar")


def mode_type(mode_in):
    """Return a database handle for the requested lookup mode.

    Only ``"sql"`` (case-insensitive) is recognised: it opens the SQLite file
    ``Resources/de.db`` located next to this module and returns a cursor on
    it. Any other mode returns ``None``.
    """
    if mode_in.lower() != "sql":
        return None
    import sqlite3
    db_path = join(abspath(dirname(__file__)), "./Resources/de.db")
    connection = sqlite3.connect(db_path)
    return connection.cursor()


#TESTS
#NUMBERS ARE TOO HARD!



def preprocess(words):
    """Lower-case every whitespace-separated token and strip punctuation from its ends.

    The cleaned tokens are re-joined with single spaces. A token made only
    of punctuation becomes an empty string but still keeps its slot.
    """
    punctuation = '!"#$%&\'()*+,-./:;<=>/?@[\\]^_`{|}~«» '
    cleaned = []
    for token in words.split():
        cleaned.append(token.strip(punctuation).lower())
    return " ".join(cleaned)


def preserve_punc(words):
    """Split *words* into ``[leading_punct, center, trailing_punct]`` triples.

    One triple is produced per whitespace-separated token.

    - ``center`` is the token as cleaned by ``preprocess``: lower-cased, with
      surrounding punctuation stripped.
    - ``leading_punct`` is the run of non-alphanumeric characters before the
      first letter, or ``""`` if there is none.
    - ``trailing_punct`` is the run of non-alphanumeric characters after the
      last letter, or ``""`` if there is none.
    """
    # [\W_] means "not a letter or digit" and [^\W\d_] means "a letter". Both
    # are Unicode-aware for str patterns, so German letters (ä, ö, ü, ß) are
    # never mistaken for punctuation: "groß!" keeps "ß" in the word and only
    # "!" is treated as trailing punctuation.
    leading = re.compile(r"^([\W_]+)[^\W\d_]")
    trailing = re.compile(r"[^\W\d_]([\W_]+)$")
    words_preserved = []
    for w in words.split():
        punct_list = ["", preprocess(w), ""]
        before = leading.search(w)
        after = trailing.search(w)
        if before:
            punct_list[0] = before.group(1)
        if after:
            punct_list[2] = after.group(1)
        words_preserved.append(punct_list)
    return words_preserved



def apply_punct(triple, as_str=False):
    """Re-attach surrounding punctuation to the center of preserve_punc triples.

    Args:
        triple: either a single ``[before, center, after]`` list of strings,
            or a list of such triples.
        as_str: if true, return a string instead of a list.

    Returns:
        For a single triple: ``[joined]``, or ``joined`` when *as_str* is true.
        For a list of triples: a list of joined strings, or one
        space-separated string when *as_str* is true.

    The input is not modified.
    """
    if triple and isinstance(triple[0], list):
        # Build a new list instead of overwriting the caller's triples.
        joined = ["".join(part) for part in triple]
        return " ".join(joined) if as_str else joined
    joined = "".join(triple)
    return joined if as_str else [joined]


def _punct_replace_word(original, transcription):
    """Re-attach each source word's punctuation to all of its IPA variants.

    ``original[k]`` is the ``[before, center, after]`` triple for word ``k``.
    ``transcription[k]`` is that word's list of IPA strings. Each IPA string
    is replaced in place by ``before + ipa + after``, and the same (mutated)
    ``transcription`` object is returned.
    """
    for word_idx, variants in enumerate(transcription):
        for variant_idx, ipa in enumerate(variants):
            before = original[word_idx][0]
            after = original[word_idx][2]
            variants[variant_idx] = apply_punct([before, ipa, after], as_str=True)
    return transcription


def fetch_words(words_in, db_type="sql"):
    """Look up each word in the German pronunciation database.

    For ``db_type == "sql"`` (case-insensitive), every word is queried in the
    ``De_words`` table:

    - If at least one row is found, the last returned ``(Words, phonemes)``
      row is kept.
    - Otherwise the word is transliterated with epitran and ``(word, ipa)``
      is used instead.

    Duplicate pairs are removed and the remaining pairs are grouped by word.

    Returns:
        A list of ``(word, [phonemes, ...])`` tuples in no guaranteed order,
        or ``None`` when *db_type* is not ``"sql"``.
    """
    asset = mode_type(db_type)
    if db_type.lower() != "sql":
        return None
    f_result = []
    try:
        for word in words_in:
            asset.execute("SELECT Words, phonemes FROM De_words WHERE Words IN (?)", (word,))
            rows = asset.fetchall()
            if rows:
                # NOTE(review): only the last matching row is kept, so a word
                # with several DB pronunciations contributes one variant and
                # get_all() cannot enumerate alternatives. Confirm intent.
                f_result.append(rows[-1])
            else:
                # Not in the dictionary: fall back to rule-based transliteration.
                f_result.append((word, epi.transliterate(word)))
    finally:
        # mode_type() opens a new SQLite connection on every call; close it
        # so repeated lookups do not leak connections and file handles.
        asset.connection.close()
    grouped = defaultdict(list)
    for key, phonemes in set(f_result):
        grouped[key].append(phonemes)
    return list(grouped.items())

def get_deu(tokens_in, db_type="sql"):
    """Return the phoneme variants for each token, in the order of *tokens_in*.

    Each element is the variant list that ``fetch_words`` produced for that
    token, or ``["__IGNORE__" + token]`` when the token has no entry.
    """
    found = fetch_words(tokens_in, db_type)

    def _variants(token):
        matches = [entry[1] for entry in found if entry[0] == token]
        return matches[0] if matches else ["__IGNORE__" + token]

    return [_variants(token) for token in tokens_in]


def deu_to_ipa(deu_list, mark=True):
    """Turn per-word lists of phoneme strings into per-word lists of IPA.

    Args:
        deu_list: one list of entries per word, as returned by get_deu.
            Entries prefixed with "__IGNORE__" mark words that were not
            found; they are passed through as plain text.
        mark: when true, append "*" to passed-through words that are not
            purely numeric.

    Returns:
        One sorted, de-duplicated list of IPA strings per input word.
    """
    swap_list = [["ˈər", "əˈr"], ["ˈie", "iˈe"]]
    result = []
    for word_list in deu_list:
        ipa_word_list = []
        for word in word_list:
            # ASCII digits are removed from the entry unless it is purely
            # numeric (presumably leftover index/stress numbers in the source
            # data — TODO confirm).
            if re.sub(r"\d*", "", word.replace("__IGNORE__", "")) != "":
                word = re.sub(r"[0-9]", "", word)
            if word.startswith("__IGNORE__"):
                ipa_form = word.replace("__IGNORE__", "")
                # Mark words we couldn't transliterate with an asterisk.
                if mark and re.sub(r"\d*", "", ipa_form) != "":
                    ipa_form += "*"
            else:
                # Entries may separate phonemes with spaces; the IPA form is
                # the pieces concatenated. Handles empty entries (e.g. from
                # punctuation-only tokens) without indexing into them.
                ipa_form = word.replace(" ", "")
            # Move the stress mark inside these sequences, unless the form
            # starts with the sequence (then no occurrence is swapped).
            for old, new in swap_list:
                if not ipa_form.startswith(old):
                    ipa_form = ipa_form.replace(old, new)
            ipa_word_list.append(ipa_form)
        result.append(sorted(set(ipa_word_list)))
    return result


def get_top(ipa_list):
    """Return one transcription per word, joined with spaces into one string.

    When a word has several candidates, the *last* entry of its list is used.
    deu_to_ipa returns each word's candidates in sorted order.
    """
    return ' '.join(word_list[-1] for word_list in ipa_list)


def get_all(ipa_list):
    """Return every combination of per-word IPA variants as sorted sentences.

    *ipa_list* holds one list of candidate transcriptions per word. Each
    combination picks one candidate per word and joins them with spaces.

    Edge cases:

    - An empty *ipa_list* yields ``[""]``.
    - If any word has no candidates, the result is ``[]``.
    """
    from itertools import product
    return sorted(" ".join(combo) for combo in product(*ipa_list))


def ipa_list(words_in, keep_punct=True, db_type="sql"):
    """Transcribe words into lists of candidate IPA strings, one list per word.

    *words_in* is either a string (split on whitespace) or an iterable of
    tokens. Each token is lower-cased. When *keep_punct* is true, the
    punctuation around each token is re-attached to every candidate.
    """
    tokens = words_in.split() if type(words_in) is str else words_in
    triples = [preserve_punc(token.lower())[0] for token in tokens]
    centers = [triple[1] for triple in triples]
    transcriptions = deu_to_ipa(get_deu(centers, db_type=db_type))
    if not keep_punct:
        return transcriptions
    return _punct_replace_word(triples, transcriptions)


def isin_deu(word, db_type="sql"):
    """Check whether word(s) are present in the German pronunciation database.

    Input handling:

    - A string is split on whitespace, and each token is lower-cased and
      stripped of surrounding punctuation via ``preprocess``.
    - Any other iterable of words is used as-is.

    Returns True only if every word has at least one row in the ``De_words``
    table (vacuously True for no words).

    Unlike ``fetch_words``, no epitran fallback is applied. That fallback
    creates an entry for every word, which would make the check always true.

    Raises:
        ValueError: if *db_type* is not ``"sql"``.
    """
    if isinstance(word, str):
        word = [preprocess(w) for w in word.split()]
    cursor = mode_type(db_type)
    if cursor is None:
        raise ValueError("unsupported db_type: %r" % (db_type,))
    try:
        for w in set(word):
            cursor.execute("SELECT Words FROM De_words WHERE Words IN (?)", (w,))
            if cursor.fetchone() is None:
                return False
        return True
    finally:
        # mode_type() opens a new connection per call; don't leak it.
        cursor.connection.close()

def replace_number(text):
    """Spell out every ASCII digit in *text* as its German name plus a space.

    Digits are replaced one at a time, so "25" becomes "zwei fünf ".
    """
    digit_words = {
        "0": "null ",
        "1": "eins ",
        "2": "zwei ",
        "3": "drei ",
        "4": "vier ",
        "5": "fünf ",
        "6": "sechs ",
        "7": "sieben ",
        "8": "acht ",
        "9": "neun ",
    }
    # Single pass; the replacements contain no digits, so this matches a
    # sequence of str.replace calls.
    return text.translate(str.maketrans(digit_words))



def convert(text, retrieve_all=False, keep_punct=True, mode="sql"):
    """Convert German text to IPA.

    Args:
        text: a string, or an iterable of word strings.
        retrieve_all: if true, return every combination of per-word variants
            as a sorted list of strings; otherwise return one string.
        keep_punct: re-attach the punctuation surrounding each word.
        mode: lookup backend passed to the database layer; only "sql" is
            handled.

    Digits are spelled out one by one with replace_number before lookup.
    """
    if not isinstance(text, str):
        # replace_number() works on strings only. Join the tokens so digit
        # expansions such as "25" -> "zwei fünf " are split back into
        # separate words by ipa_list.
        text = " ".join(text)
    text = replace_number(text)
    ipa = ipa_list(
                   words_in=text,
                   keep_punct=keep_punct,
                   db_type=mode)
    if retrieve_all:
        return get_all(ipa)
    return get_top(ipa)



# Decimal number with a comma separator, e.g. "3,14".
_decimal_number_re = re.compile(r'\d+\,\d+')
# Euro amount with the sign in front, e.g. "€1,50"; group 1 is the amount.
_euros_pre = re.compile(r'€([0-9\,]*[0-9]+)')
# Euro amount with the sign after it, e.g. "1,50€"; group 1 is the amount.
_euros_re = re.compile(r'([0-9\,]*[0-9]+)€')
# Ordinal after a lower-case article, e.g. "der 3.". Group 1 is the article
# (including its trailing space); group 2 is the digits.
_ordinal_re = re.compile(r'(der |die |das )([0-9]+)\.')
# Clock time written "H:MM" or "HH:MM".
_clock_re=re.compile(r'\d{1,2}\:\d{2}')
# Any remaining run of ASCII digits.
_number_re = re.compile(r'[0-9]+')

def base(text):
    """Return *text* with each ASCII digit replaced by its German name and a space."""
    names = ("null ", "eins ", "zwei ", "drei ", "vier ",
             "fünf ", "sechs ", "sieben ", "acht ", "neun ")
    return "".join(
        names[ord(ch) - ord("0")] if "0" <= ch <= "9" else ch
        for ch in text
    )

def tens_to_word(num):
    """Spell out a two-digit numeral string ("10"–"99") in German.

    Output format:

    - Teens and multiples of ten are one word: "sechzehn", "siebzig",
      "dreißig".
    - Other numbers keep the parts separate, ones first: "ein und zwanzig".

    Assumes ``num[0]`` is non-zero. num_to_word strips leading zeros before
    dispatching here.
    """
    tens = num[0]
    ones = num[1]
    # base() appends a trailing space to each digit name. Strip it so the
    # suffix slicing below works on the bare word ("sechs" -> "sech").
    ones_word = base(ones).strip()

    if num == "10":
        return "zehn"
    elif num == "11":
        return "elf"
    elif num == "12":
        return "zwölf"

    if tens == "1":
        if ones == "6":
            ones_word = ones_word[:-1]   # sechs -> sech(zehn)
        elif ones == "7":
            ones_word = ones_word[:-2]   # sieben -> sieb(zehn)
        return ones_word + "zehn"

    tens_word = base(tens).strip()
    if ones == "1":
        ones_word = ones_word[:-1]       # eins -> ein (und ...)
    if tens == "2":
        tens_word = "zwan"
    elif tens == "6":
        tens_word = tens_word[:-1]       # sechs -> sech(zig)
    elif tens == "7":
        tens_word = tens_word[:-2]       # sieben -> sieb(zig)
    tens_word += "ßig" if tens == "3" else "zig"
    if ones == "0":
        return tens_word
    return ones_word + " und " + tens_word

def huns_to_word(num):
    """Spell out a three-digit numeral string in German.

    The hundreds digit becomes "hundert" (for 1) or "<digit> hundert". The
    remaining two digits follow via num_to_word, separated by a space, when
    they are non-zero.
    """
    hundreds = num[0]
    # base() ends with a space; strip it to avoid "zwei  hundert".
    if hundreds == "1":
        huns_word = "hundert"
    else:
        huns_word = base(hundreds).strip() + " hundert"
    remain = num_to_word(num[1:])
    if remain:
        return huns_word + " " + remain
    return huns_word

def thos_to_word(num):
    """Spell out a four-digit numeral string in German.

    The thousands digit becomes "tausend" (for 1) or "<digit> tausend". The
    remaining three digits follow via num_to_word, separated by a space,
    when they are non-zero.
    """
    thousands = num[0]
    # base() ends with a space; strip it to avoid "zwei  tausend".
    if thousands == "1":
        thos_word = "tausend"
    else:
        thos_word = base(thousands).strip() + " tausend"
    remain = num_to_word(num[1:])
    if remain:
        return thos_word + " " + remain
    return thos_word

def num_to_word(num):
    """Spell out a digit string in German, ignoring leading zeros.

    - 1 to 4 significant digits are spelled as a number.
    - Longer strings fall back to reading digit by digit.
    - An empty or all-zero string yields "".
    """
    significant = num.lstrip("0")
    if not significant:
        return ""
    spellers = {1: base, 2: tens_to_word, 3: huns_to_word, 4: thos_to_word}
    return spellers.get(len(significant), base)(significant)

def number_to_words(m):
    """re.sub callback: spell out the matched digit run in German.

    An all-zero match yields "null".
    """
    digits = m.group(0).lstrip("0")
    return num_to_word(digits) if digits else "null"

def _expand_ordinal(m):

    pre=m.group(1)
    m = m.group(2).lstrip("0")

    if m=="":
        return"NULL"
    num=int(m)
    if num<=19 & num>=1:
        if num ==1:
            return "erste"
        elif num==3:
            return "dritte"
        elif num==7:
            return "siebte"
        elif num==8:
            return "achte"
        else:
            return pre + num_to_word(m) + "te"
    else:
        return pre + num_to_word(m) + "ste"

def _expand_decimal(m):
    """re.sub callback: read "3,14" as "<drei> komma <eins vier>".

    The integer part is spelled as a cardinal ("null" when it is 0). The
    fractional digits are read one by one.
    """
    parts = m.group(0).split(',')
    head = "null" if int(parts[0]) == 0 else num_to_word(parts[0])
    return '%s komma %s' % (head, base(parts[1]))

def _expand_euros(m):
    match = m.group(1)
    parts = match.split(',')
    if len(parts) > 2:
        return match + ' euro'  # Unexpected format
    euros = int(parts[0]) if parts[0] else 0
    cents = int(parts[1])*10 if len(parts) > 1 and parts[1] else 0
    if euros and cents:
        return '%s euro %s' % (euros, cents)
    elif euros:
        return '%s euro' % (euros)
    elif cents:
        return '%s cent' % (cents)
    else:
        return 'null euro'

def _expand_clock(m):
    """re.sub callback: read "H:MM" / "HH:MM" as "<hours> Uhr <minutes>".

    Hour 0 is read as "null" and hour 1 as "ein". Other hours and the minutes
    go through num_to_word, so minutes "00" produce an empty string.
    """
    parts = m.group(0).split(':')
    spoken_hour = {0: "null", 1: "ein"}.get(int(parts[0]))
    if spoken_hour is None:
        spoken_hour = num_to_word(parts[0])
    return '%s Uhr %s' % (spoken_hour, num_to_word(parts[1]))

def normalize_numbers(text):
    """Spell out numbers in *text* in German.

    Substitutions run in a fixed order: euro amounts, clock times, decimals,
    ordinals, then any remaining digit runs. The specific patterns therefore
    run before the generic digit pattern. Finally, one pass replaces each
    pair of consecutive spaces with a single space.
    """
    substitutions = (
        (_euros_pre, _expand_euros),
        (_euros_re, _expand_euros),
        (_clock_re, _expand_clock),
        (_decimal_number_re, _expand_decimal),
        (_ordinal_re, _expand_ordinal),
        (_number_re, number_to_words),
    )
    for pattern, expand in substitutions:
        text = pattern.sub(expand, text)
    return text.replace("  ", " ")

def collapse_whitespace(text):
    """Replace each run of whitespace characters in *text* with one space.

    Leading and trailing whitespace is collapsed, not removed.
    """
    whitespace_run = re.compile(r'\s+')
    return whitespace_run.sub(' ', text)

def mark_dark_l(text):
    """Rewrite "l" as dark "ɫ" when no vowel follows it before the word ends.

    An "l" is rewritten when it is followed only by characters outside
    ``aeiouæɑɔəɛɪʊ`` up to the next space or the end of the string. The
    characters after it are kept.
    """
    def _darken(match):
        return 'ɫ' + match.group(1)

    return re.sub(r'l([^aeiouæɑɔəɛɪʊ ]*(?: |$))', _darken, text)