import re
import numpy as np

import tiktoken
from langchain.text_splitter import TokenTextSplitter

def strtobool(val):
    """Convert a textual truth value to a ``bool``.

    Matching is case-insensitive and ignores leading/trailing whitespace, so
    values such as ``"True\n"`` (typical of environment variables or config
    files) are accepted.

    Args:
        val: String to interpret.

    Returns:
        ``True`` for ``y``, ``yes``, ``t``, ``true``, ``on``, ``1``;
        ``False`` for ``n``, ``no``, ``f``, ``false``, ``off``, ``0``.

    Raises:
        ValueError: If ``val`` is not one of the recognised spellings.
    """
    val = val.strip().lower()
    if val in ('y', 'yes', 't', 'true', 'on', '1'):
        return True
    if val in ('n', 'no', 'f', 'false', 'off', '0'):
        return False
    raise ValueError(f"Invalid truth value {val}")


def split_camel_case(word):
    """Insert a space at every lowercase-to-uppercase boundary in ``word``.

    Only the transition from an ASCII lowercase letter (``a-z``) to an ASCII
    uppercase letter (``A-Z``) is split, so ``"camelCase"`` becomes
    ``"camel Case"`` while runs of capitals such as ``"HTTPServer"`` are left
    untouched.

    Args:
        word: The string to split.

    Returns:
        The string with a single space inserted at each boundary.
    """
    # Zero-width match: the position preceded by [a-z] and followed by [A-Z].
    return re.sub(r'(?<=[a-z])(?=[A-Z])', ' ', word)


def chunk_tokens(tokens, max_len):
    """Yield consecutive slices of ``tokens`` holding at most ``max_len`` items.

    Args:
        tokens: A sliceable sequence supporting ``len()``.
        max_len: Maximum number of items per chunk; used as the step of the
            underlying ``range``.

    Yields:
        Slices ``tokens[start:start + max_len]`` in order; the last one may
        be shorter than ``max_len``.
    """
    chunk_starts = range(0, len(tokens), max_len)
    yield from (tokens[start:start + max_len] for start in chunk_starts)


def update_nested_dict(d, u):
    """Recursively merge mapping ``u`` into dict ``d`` in place.

    For every key in ``u``: a dict value is merged into the dict already
    stored under that key in ``d``; any other value simply overwrites the
    entry in ``d``.

    Args:
        d: Destination dict; it is mutated.
        u: Dict of updates to apply.

    Returns:
        The same ``d`` object, after the merge.
    """
    for k, v in u.items():
        if isinstance(v, dict):
            existing = d.get(k)
            # If the current entry is missing or is not a dict (e.g. a scalar
            # or None), it cannot be merged into: start from an empty dict so
            # the nested update replaces it instead of crashing.
            if not isinstance(existing, dict):
                existing = {}
            d[k] = update_nested_dict(existing, v)
        else:
            d[k] = v
    return d


def cleanInputText(textInputLLM):
    """Remove markdown markers and quoting residue from text destined for an LLM.

    The replacements run in a fixed order, in three phases:

    1. Single-pass removal of the literal sequences ``('\\n\\n``, ``("\\n\\n``,
       ``\\n\\n',)`` and ``\\n\\n",)`` (backslash-n as two characters).
    2. Replacements repeated until the pattern no longer occurs: ``#`` and
       ``*`` markers, doubled dashes, spaces and periods, and one specific
       line-wrapped phrase.
    3. Single-pass removal of ``(".``, ``('.``, ``",)`` and ``',)``.

    Newlines that are not consumed by one of these patterns are left in
    place; leading and trailing whitespace is stripped at the end.

    Args:
        textInputLLM: The raw text.

    Returns:
        The cleaned text.
    """
    cleaned = textInputLLM

    # Phase 1: applied exactly once each.
    opening_patterns = (
        r'\(\'\\n\\n',
        r'\(\"\\n\\n',
        r'\\n\\n\',\)',
        r'\\n\\n\",\)',
    )
    for pattern in opening_patterns:
        cleaned = re.sub(pattern, ' ', cleaned)

    # Phase 2: (is_regex, pattern, replacement). Each step is repeated until
    # its pattern is gone, because one pass can create a fresh match
    # (e.g. '---' -> '--' -> '-'). Order matters: longer markers come before
    # their shorter prefixes.
    repeated_steps = (
        (True, r'##\n', '. '),
        (False, "###", ' '),
        (False, "##", ' '),
        (False, " # ", ' '),
        (False, "--", '-'),
        (True, r"\\\\-", '.'),
        (True, r"\*\*\n", '. '),
        (True, r"\*\*\*", ' '),
        (True, r"\*\*", ' '),
        (True, r" \* ", ' '),
        (
            True,
            r'is a program of the\n\nInternational Society for Infectious Diseases',
            'is a program of the International Society for Infectious Diseases',
        ),
        (True, r' \*\.', ' .'),
        (False, "  ", ' '),
        (True, r'\.\.', '.'),
        (True, r'\. \.', '.'),
    )
    for is_regex, pattern, replacement in repeated_steps:
        if is_regex:
            while re.search(pattern, cleaned):
                cleaned = re.sub(pattern, replacement, cleaned)
        else:
            while pattern in cleaned:
                cleaned = cleaned.replace(pattern, replacement)

    # Phase 3: applied exactly once each.
    closing_patterns = (
        r'\(\"\.',
        r'\(\'\.',
        r'\",\)',
        r'\',\)',
    )
    for pattern in closing_patterns:
        cleaned = re.sub(pattern, ' ', cleaned)

    return cleaned.strip()



def encoding_getter(encoding_type: str):
    """
    Resolve a tiktoken encoding from an encoding name or a model name.

    If ``encoding_type`` contains ``"k_base"`` it is treated as an encoding
    name and passed to ``tiktoken.get_encoding``. Otherwise it is treated as a
    model name and passed to ``tiktoken.encoding_for_model``; if that lookup
    raises for any reason, ``cl100k_base`` is returned instead.

    Encodings used by OpenAI models, per the cookbook linked below:

    Encoding name          OpenAI models
    cl100k_base            gpt-4, gpt-3.5-turbo, text-embedding-ada-002
    p50k_base              Codex models, text-davinci-002, text-davinci-003
    r50k_base (or gpt2)    GPT-3 models like davinci

    https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb

    """
    if "k_base" not in encoding_type:
        try:
            return tiktoken.encoding_for_model(encoding_type)
        except Exception:
            # Best-effort fallback: the default for gpt-4 / gpt-3.5-turbo.
            return tiktoken.get_encoding("cl100k_base")
    return tiktoken.get_encoding(encoding_type)


def tokenizer(string: str, encoding_type: str) -> list:
    """
    Encode ``string`` into tokens.

    ``encoding_type`` is resolved through ``encoding_getter`` and the result
    of that encoding's ``encode`` call is returned.
    """
    return encoding_getter(encoding_type).encode(string)


def token_counter(string: str, encoding_type: str) -> int:
    """
    Count the tokens ``tokenizer`` produces for ``string`` under ``encoding_type``.
    """
    tokens = tokenizer(string, encoding_type)
    return len(tokens)


def extract_words(text, putInLower=False):
    """Return every word found in ``text``, in order of appearance.

    A word is a maximal run of ``\\w`` characters delimited by word
    boundaries, so punctuation such as apostrophes splits words.

    Args:
        text: The string to scan.
        putInLower: When truthy, each word is lowercased.

    Returns:
        A list of the matched words.
    """
    found = re.findall(r'\b\w+\b', text)
    if not putInLower:
        return found
    return [token.lower() for token in found]

def all_words_in_list(compound_word, word_list, putInLower=False):
    """Tell whether every word of ``compound_word`` occurs in ``word_list``.

    The words are obtained with ``extract_words``. When ``putInLower`` is
    truthy the extracted words are lowercased before the membership test;
    ``word_list`` itself is used as given.

    Args:
        compound_word: Text whose words are looked up.
        word_list: Container tested with ``in``.
        putInLower: Lowercase the extracted words before the lookup.

    Returns:
        ``True`` if all words are present (also when there are no words),
        otherwise ``False``.
    """
    for candidate in extract_words(compound_word, putInLower=putInLower):
        probe = candidate.lower() if putInLower else candidate
        if probe not in word_list:
            return False
    return True


def row_to_dict_string(rrrow, columnsDict):
    """Render selected entries of a row as a ``{"key": value, ...}`` string.

    Args:
        rrrow: Row object exposing ``.index`` and label-based ``rrrow[col]``
            access (presumably a pandas Series).
        columnsDict: Container of the column labels to keep; labels not in it
            are skipped.

    Returns:
        A string of comma-separated ``"col": value`` pairs wrapped in braces.
        Keys are double-quoted. Numeric values are written bare; every other
        value is formatted as text and double-quoted (embedded quotes are not
        escaped).
    """
    # NumPy scalars such as np.int64 or np.float32 do not subclass the
    # built-in int/float, yet they are what a numeric Series returns on item
    # access, so np.number must be checked too or they would be quoted.
    numeric_types = (int, float, np.number)

    formatted_items = []
    for col in rrrow.index:
        if col not in columnsDict:
            continue
        value = rrrow[col]
        if isinstance(value, numeric_types):
            formatted_items.append(f'"{col}": {value}')
        else:
            formatted_items.append(f'"{col}": "{value}"')
    return '{' + ', '.join(formatted_items) + '}'

#
# def row_to_dict_string(rrrow):
#     formatted_items = []
#     for col in rrrow.index:
#         value = rrrow[col]
#         # Check if the value is an instance of a number (int, float, etc.)
#         if isinstance(value, (int, float)):
#             formatted_items.append(f"'{col}': {value}")
#         else:
#             formatted_items.append(f"'{col}': '{value}'")
#     # Join items and enclose them in {}
#     return '{' + ', '.join(formatted_items) + '}'


def rescale_exponential_to_linear(df, column, new_min=0.5, new_max=1.0):
    """Min-max rescale ``df[column]`` onto the interval ``[new_min, new_max]``.

    The column is overwritten in place on ``df``.

    Args:
        df: DataFrame holding the scores; it is mutated.
        column: Label of the column to rescale.
        new_min: Value the smallest score is mapped to.
        new_max: Value the largest score is mapped to.

    Returns:
        The same ``df`` object, with ``column`` replaced by the rescaled
        scores. If every score is identical, all of them map to ``new_min``.
    """
    original_scores = df[column]

    min_score = original_scores.min()
    score_range = original_scores.max() - min_score
    # A constant column (including a single-row frame) has a zero range;
    # dividing by it would turn every score into NaN. Using 1 instead leaves
    # the normalized value at 0, i.e. everything maps to new_min.
    if score_range == 0:
        score_range = 1

    # Normalize to 0-1, then stretch onto [new_min, new_max].
    normalized_scores = (original_scores - min_score) / score_range
    df[column] = new_min + (normalized_scores * (new_max - new_min))

    return df


def rescale_exponential_to_logarithmic(df, column, new_min=0.5, new_max=1.0):
    """Log-transform ``df[column]`` and rescale it onto ``[new_min, new_max]``.

    A small epsilon is added before taking the natural logarithm so that
    zeros do not produce ``-inf``. Negative scores are not handled: their
    logarithm is NaN. The column is overwritten in place on ``df``.

    Args:
        df: DataFrame holding the scores; it is mutated.
        column: Label of the column to rescale.
        new_min: Value the smallest log-score is mapped to.
        new_max: Value the largest log-score is mapped to.

    Returns:
        The same ``df`` object, with ``column`` replaced by the rescaled
        scores. If every score is identical, all of them map to ``new_min``.
    """
    # log(0) is undefined, so shift every value by a tiny positive amount.
    epsilon = 1e-10
    log_transformed_scores = np.log(df[column] + epsilon)

    min_score = log_transformed_scores.min()
    score_range = log_transformed_scores.max() - min_score
    # A constant column (including a single-row frame) has a zero range;
    # dividing by it would turn every score into NaN. Using 1 instead leaves
    # the normalized value at 0, i.e. everything maps to new_min.
    if score_range == 0:
        score_range = 1

    # Normalize to 0-1, then stretch onto [new_min, new_max].
    normalized_log_scores = (log_transformed_scores - min_score) / score_range
    df[column] = new_min + (normalized_log_scores * (new_max - new_min))

    return df