import os
import pandas as pd
import textwrap
import random
from datasets import load_dataset
from IPython.display import display

class LMSYSChat1MHandler:
    def __init__(self, hf_token, streaming=False, verbose=True):
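        """
        Initializes the handler and loads lmsys/lmsys-chat-1m from the Hugging Face Hub.

        Parameters:
        - hf_token (str): Hugging Face access token with read access to the gated dataset.
        - streaming (bool): If True, stream the dataset lazily instead of downloading it to the local cache.
        - verbose (bool): If True, print dataset info and preview extracted dataframes.
        """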
        self.hf_token = hf_token
        self.streaming = streaming
        self.lmsys_dataset = load_dataset(
            'lmsys/lmsys-chat-1m',
            revision="main",
            token=self.hf_token,
            streaming=self.streaming
        )
        self.verbose = verbose
        if verbose:
            print(self.lmsys_dataset)
        self.df_sample = None
        self.df_prompts = None
        self.unwrapped_turns_df = None

        if not self.streaming and verbose:
            print('Data is cached at:\n')
            for file_info in self.lmsys_dataset['train'].cache_files:
                filename = file_info['filename']
                file_size = os.path.getsize(filename)
                # Shorten the long cache path for display: keep the start of the
                # path and the last 41 characters, replacing the middle with '*'
                i = int((len(filename) - 41) / 2)
                print(f"Filename: {filename[:i]}*{filename[-41:]}\nSize: {file_size} bytes")

    def extract_df_sample(self, n_samples=None, conversation_ids=None):
        """
        Extracts a sample of conversations or specific conversations based on their conversation IDs.

        Parameters:
        - n_samples (int): Number of random samples to extract. Required unless `conversation_ids` is provided.
        - conversation_ids (list): List of conversation IDs to extract. Takes precedence over `n_samples`;
          requires streaming=False.

        Returns:
        - pd.DataFrame: A DataFrame containing the extracted conversations.
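
        Example (illustrative; assumes an initialized handler `handler`):

            df = handler.extract_df_sample(n_samples=100)
            df = handler.extract_df_sample(conversation_ids=['<conversation_id>'])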
        """
        if conversation_ids:
            if self.streaming:
                raise ValueError("Filtering by conversation IDs is not supported in streaming mode.")
            # Filter conversations based on the provided conversation IDs
            df_sample = self.lmsys_dataset['train'].to_pandas()
            df_sample = df_sample[df_sample['conversation_id'].isin(conversation_ids)]
            print(f"Retrieved {len(df_sample)} conversations based on specified IDs")
        else:
            # Randomly sample conversations if no IDs are provided
            if n_samples is None:
                raise ValueError("n_samples must be provided when conversation_ids is not specified.")
            if not self.streaming:
                df_sample = self.lmsys_dataset['train'].to_pandas().sample(n_samples)
                print(f"Retrieved {len(df_sample)} random conversations from lmsys/lmsys-chat-1m")
            else:
                # Take the first n_samples rows from the streamed dataset
                streamed_samples = []
                for i, row in enumerate(self.lmsys_dataset['train']):
                    streamed_samples.append(row)
                    if i + 1 == n_samples:  # Stop once the desired number of rows is collected
                        break
                # Shuffle the collected rows and convert them to a Pandas DataFrame;
                # note this randomizes their order only, not which rows are drawn
                random.shuffle(streamed_samples)
                df_sample = pd.DataFrame(streamed_samples)
                
        self.df_sample = df_sample
        if self.verbose and len(df_sample) > 4:
            display(df_sample.head(2))
            print('...')
            display(df_sample.tail(2))
        return df_sample
    
    def parquet_sampling(self, n_samples):
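        """
        Samples conversations from one randomly chosen parquet shard of the dataset,
        downloading a single shard rather than the full dataset.

        Parameters:
        - n_samples (int): Number of random conversations to draw from the shard.

        Returns:
        - pd.DataFrame: A DataFrame containing the sampled conversations.
        """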
        base_url = "https://huggingface.co/datasets/lmsys/lmsys-chat-1m/resolve/main/data/"
        data_files = [
            "train-00000-of-00006-4feeb3f83346a0e9.parquet",
            "train-00001-of-00006-4030672591c2f478.parquet",
            "train-00002-of-00006-1779b7cec9462180.parquet",
            "train-00003-of-00006-2fa862bfed56af1f.parquet",
            "train-00004-of-00006-18f4bdd50c103e71.parquet",
            "train-00005-of-00006-fe1acc5d10a9f0e2.parquet"
        ]
        sample_file = random.choice(data_files)
        print(f"Sampling from {sample_file}")
        data_files = {"train": base_url + sample_file}
        parquet_sample = load_dataset("parquet", data_files=data_files, split="train")
        df_sample = parquet_sample.to_pandas().sample(n_samples)
        print(f"Retrieved {len(df_sample)} random conversations from lmsys/lmsys-chat-1m/{sample_file}")
        self.df_sample = df_sample
        if self.verbose and len(df_sample) > 4:
            display(df_sample.head(2))
            print('...')
            display(df_sample.tail(2))
        return df_sample

    def add_turns_to_conversations(self):
        """
        Adds a 'turn' key to each message in the 'conversation' column of the sample dataframe.
        Requires extract_df_sample() to have been called first.
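
        Example (illustrative; assumes df_sample has been populated):

            df_with_turns = handler.add_turns_to_conversations()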
        """
        self.df_sample['conversation'] = self.df_sample['conversation'].apply(
            lambda conv: Conversation(conv).add_turns()
        )
        df_with_turns = self.df_sample
        return df_with_turns
    
    def unwrap_turns(self):
        """
        Creates a dataframe where each row corresponds to one user-assistant message pair (a turn).
        The 'prompt' column contains the user's message, and the 'response' column contains the assistant's reply.
        Each row includes a 'turn_id' column (conversation_id plus a zero-padded turn number) that identifies
        the turn uniquely, and a 'turn_n' column with the turn number within the conversation.
        Requires add_turns_to_conversations() to have been called first.
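
        Example (illustrative; assumes add_turns_to_conversations() has been called):

            turns_df = handler.unwrap_turns()
            print(turns_df[['turn_id', 'turn_n', 'prompt', 'response']].head())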
        """
        paired_data = []
        for _, row in self.df_sample.iterrows():
            conversation_id = row['conversation_id']
            row_data = row.to_dict()
            row_data.pop('conversation')  # Remove the 'conversation' field as it's being unwrapped

            current_prompt = None
            turn_id = None

            for message in row['conversation']:
                if message['role'] == 'user':
                    current_prompt = message['content']
                    turn_id = f"{conversation_id}{message['turn']:03}"  # Create turn_id
                elif message['role'] == 'assistant' and current_prompt is not None:
                    # Create a new row with the user-assistant pair
                    paired_row = {
                        **row_data,
                        'turn_id': turn_id,
                        'turn_n': message['turn'],
                        'prompt': current_prompt,
                        'response': message['content'],
                    }
                    paired_data.append(paired_row)
                    current_prompt = None  # Reset after pairing

        unwrapped_turns_df = pd.DataFrame(paired_data)
        unwrapped_turns_df.rename(columns={"turn": "conversation_turns"}, inplace=True) # The naming in the original dataset is ambiguous
        self.unwrapped_turns_df = unwrapped_turns_df
        return unwrapped_turns_df

    def extract_prompts(self, filter_language=None, min_char_length=20, max_char_length=500, exclusions=None):
        """
        Extracts user prompts from the sample dataframe, optionally filtering by language and limiting the character length.

        Parameters:
        - filter_language (list of str or None): A list of specific languages to filter prompts by. If None, no language 
          filter is applied. Examples of valid values include ['English'], ['English', 'Portuguese'], or 
          ['Spanish', 'French', 'German'].
        - min_char_length (int): The minimum character length for user prompts to include. Defaults to 20.
        - max_char_length (int): The maximum character length for user prompts to include. Defaults to 500.
        - exclusions (str or None): Path to a text file containing phrases. Prompts containing any of these phrases
          will be excluded from the results. If None, no exclusions are applied.

        Returns:
        - pd.DataFrame: A DataFrame containing extracted prompts with columns 'prompt' and 'language'.
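
        Example (illustrative; assumes df_sample has been populated):

            df_prompts = handler.extract_prompts(filter_language=['English'], max_char_length=300)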
        """
        df_sample = self.df_sample
        if filter_language:
            # Keep only conversations in the requested languages before extracting prompts
            df_sample = df_sample[df_sample['language'].isin(filter_language)]
        extracted_data = df_sample.apply(
            lambda row: [
                {'content': entry['content'], 'language': row['language']}
                for entry in row['conversation']
                if entry['role'] == 'user' and min_char_length <= len(entry['content']) <= max_char_length
            ], axis=1
        ).explode().dropna()

        df_prompts = pd.DataFrame(extracted_data.tolist())
        df_prompts.rename(columns={'content': 'prompt'}, inplace=True)

        orig_length = len(df_prompts)
        if exclusions:
            # Excluding prompts with phrases that are repeated often in this dataset
            with open(exclusions, 'r') as f:
                # Skip blank lines: an empty phrase would match (and exclude) every prompt
                exclusion_phrases = [line.strip() for line in f if line.strip()]
            df_prompts = df_prompts[~df_prompts['prompt'].apply(lambda x: any(phrase in x for phrase in exclusion_phrases))]
            print(f"Excluded {orig_length - len(df_prompts)} prompts.")

        self.df_prompts = df_prompts
        if self.verbose and len(df_prompts) > 4:
            display(df_prompts.head(2))
            print('...')
            display(df_prompts.tail(2))
        return df_prompts
    
    def extract_prompt_sample(self):
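        """
        Returns a single random prompt from the extracted prompts dataframe.
        Requires extract_prompts() to have been called first.

        Returns:
        - str: The text of the sampled prompt.
        """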
        prompt_sample = self.df_prompts.sample(1)['prompt'].values[0]
        if self.verbose:
            wrapped_message = textwrap.fill(prompt_sample, width=120)
            print(wrapped_message)
        return prompt_sample
    
    def search_conversations(self, search_term):
        """
        Searches the dataset for a given string and returns a DataFrame with matching records.

        Parameters:
        - search_term (str): The string to search for in the dataset.

        Returns:
        - pd.DataFrame: A DataFrame containing conversations where the search term is found.
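
        Example (illustrative):

            matches = handler.search_conversations("tell me a joke")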
        """
        if self.streaming:
            raise ValueError("Search is not supported in streaming mode.")
        df = self.lmsys_dataset['train'].to_pandas()
        # Filter rows where the search term appears in the 'conversation' column
        matching_records = df[df['conversation'].apply(
            lambda conv: any(search_term.lower() in message['content'].lower() for message in conv)
        )]
        if self.verbose:
            print(f"Found {len(matching_records)} matching conversations for search term: '{search_term}'")
        return matching_records
        
    def print_language_counts(self, df):
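        """
        Prints the number of records per language in the given dataframe.

        Parameters:
        - df (pd.DataFrame): Any dataframe with a 'language' column, e.g. df_sample or df_prompts.
        """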
        language_counts = df['language'].value_counts()
        print("Language Record Counts:")
        print(language_counts.rename_axis('Language').reset_index(name='Count'))


class Conversation:
    def __init__(self, conversation_data):
        """
        Initializes the Conversation object with the conversation data.

        Parameters:
        - conversation_data (list): A list of dictionaries representing a conversation.
        """
        self.conversation_data = conversation_data

    def add_turns(self):
        """
        Adds a 'turn' key to each dictionary in the conversation,
        identifying the turn (pair of user and assistant messages).

        Returns:
        - list: The updated conversation with 'turn' keys added.
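
        Example (illustrative):

            conv = Conversation([
                {'role': 'user', 'content': 'Hi'},
                {'role': 'assistant', 'content': 'Hello!'},
            ])
            conv.add_turns()  # both messages receive 'turn': 1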
        """
        turn_counter = 0
        for message in self.conversation_data:
            if message['role'] == 'user':
                turn_counter += 1
            message['turn'] = turn_counter
        return self.conversation_data
    
    def pretty_print(self, user_prefix, assistant_prefix, width=80):
        """
        Prints the conversation with specified prefixes and wrapped text.

        Parameters:
        - user_prefix (str): Prefix to prepend to user messages.
        - assistant_prefix (str): Prefix to prepend to assistant messages.
        - width (int): Maximum characters per line for wrapping.
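
        Example (illustrative):

            conv.pretty_print(user_prefix="USER:", assistant_prefix="ASSISTANT:", width=80)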
        """
        wrapper = textwrap.TextWrapper(width=width)
        
        for message in self.conversation_data:
            if message['role'] == 'user':
                prefix = user_prefix
            elif message['role'] == 'assistant':
                prefix = assistant_prefix
            else:
                continue  # Ignore roles other than 'user' and 'assistant'
            
            # Split on existing newlines, wrap each line, and join back with newlines
            wrapped_content = "\n".join(
                wrapper.fill(line) for line in message['content'].splitlines()
            )
            print(f"{prefix} {wrapped_content}\n")