import streamlit as st
from groq import Groq
import sys
import io
import re
import traceback
from dotenv import load_dotenv
import os

# Load environment variables from a local .env file (if one exists) into
# os.environ; the Groq client used below is built without an explicit api_key.
load_dotenv()

# Browser-tab title/icon and a full-width layout. Streamlit has historically
# required set_page_config to be the first st.* call in the script.
st.set_page_config(page_title="AI Assistant with Code Execution", page_icon="🤖", layout="wide")

def main():
    """Render the chat UI and handle at most one user turn per Streamlit rerun.

    The conversation lives in ``st.session_state['messages']`` and always
    starts with a system prompt asking the model to answer with a Python code
    block. Only ``user`` and ``assistant`` messages are rendered.
    """
    st.title("AI Assistant with Code Execution")
    st.write("Interact with an AI assistant that can execute code when needed.")
    st.markdown("[Concept by cfahlgren1, check out his amazing work here](https://huggingface.co/spaces/cfahlgren1/qwen-2.5-code-interpreter)")
    st.markdown("Thanks Groq for the super fast model and thanks cfahlgren1 for the prompt and this idea!")

    # Sidebar settings
    st.sidebar.title("Settings")
    use_code_interpreter = st.sidebar.checkbox("Enable Code Interpreter", value=True)
    reset_button = st.sidebar.button("Reset Chat")

    # (Re)initialise the conversation on first load or when "Reset Chat" is clicked.
    if 'messages' not in st.session_state or reset_button:
        st.session_state['messages'] = [
            {
                "role": "system",
                "content": (
                    "The user will ask you a tricky question, your job is to write Python code to answer the question. \n\n" +
                    "Really think step by step before writing any code to ensure you're answering the question correctly. \n\n" +
                    "Respond with a markdown code block starting with ```python and ``` at the end. Make sure the code can be executed without any changes"
                ),
            }
        ]
        if reset_button:
            # Rerun immediately so the page redraws with the emptied history.
            st.rerun()

    # Replay the visible history; the system prompt (and any other role) is not shown.
    for message in st.session_state['messages']:
        if message["role"] in ("user", "assistant"):
            with st.chat_message(message["role"]):
                st.markdown(message["content"])

    # st.chat_input returns None until the user submits something.
    user_input = st.chat_input("Type your message")
    if not user_input:
        return

    with st.chat_message("user"):
        st.markdown(user_input)
    st.session_state['messages'].append({"role": "user", "content": user_input})

    try:
        assistant_reply = get_assistant_response(st.session_state['messages'], use_code_interpreter)
    except Exception as exc:  # UI boundary: report the failure instead of crashing the page.
        # Drop the unanswered user turn so the stored history never contains
        # two user messages in a row when the user tries again.
        st.session_state['messages'].pop()
        st.error(f"Could not get a response from the assistant: {exc}")
        return

    with st.chat_message("assistant"):
        st.markdown(assistant_reply)
    st.session_state['messages'].append({"role": "assistant", "content": assistant_reply})

def _request_completion(client, messages, model):
    """Return the text of one non-streaming chat completion for ``messages``.

    The SDK types ``message.content`` as optional, so a missing body is
    normalised to ``""`` to keep callers' ``.strip()`` safe.
    """
    completion = client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=1,
        max_tokens=1024,
        top_p=1,
        stream=False,
        stop=None,
    )
    return completion.choices[0].message.content or ""


def get_assistant_response(conversation, use_code_interpreter, *, model="llama-3.2-3b-preview"):
    """Answer the latest user turn, optionally executing the model's code first.

    Args:
        conversation: Chat history as a list of ``{"role", "content"}`` dicts.
            It is not modified.
        use_code_interpreter: When true and the first reply contains a Python
            code block, that code is executed. The result is then fed back to
            the model to produce the final answer.
        model: Groq model name used for both requests.

    Returns:
        The stripped text of the model's final answer.
    """
    # Work on a copy so the internal execution-result turn never leaks into
    # the caller's (displayed) history.
    messages = list(conversation)

    # Groq() is constructed without an explicit api_key, so it reads
    # GROQ_API_KEY from the environment (populated by load_dotenv()).
    client = Groq()

    assistant_reply = _request_completion(client, messages, model)

    code = extract_code(assistant_reply) if use_code_interpreter else None
    if not code:
        # Interpreter disabled or no code block found: answer as-is.
        return assistant_reply.strip()

    execution_result = execute_code_safely(code)

    # Keep the model's own code-bearing reply in context so that "the code you
    # provided" in the follow-up prompt refers to something the model can see.
    messages.append({"role": "assistant", "content": assistant_reply})
    messages.append({
        "role": "user",
        "content": (
            f"The code you provided was executed and returned the result:\n{execution_result}\n"
            "Use this result to provide a final answer to the user's question. "
            "Do not mention the code or that you executed code."
        ),
    })

    return _request_completion(client, messages, model).strip()

def extract_code(text):
    """Return the body of the first Python (or untagged) fenced code block in ``text``.

    The opening fence may be tagged ``python``, ``python3`` or ``py`` in any
    letter case, may have trailing spaces/tabs, and may end in ``\\r\\n``.
    Fences tagged with another language (e.g. ``bash``) are ignored.

    Returns:
        The text between the fences (including its trailing newline), or
        ``None`` when no such block exists.
    """
    match = re.search(r"```(?:python3?|py)?[ \t]*\r?\n(.*?)```", text, re.DOTALL | re.IGNORECASE)
    return match.group(1) if match else None

def remove_code_blocks(text):
    """Return ``text`` with every Python (or untagged) fenced code block removed.

    Opening fences tagged ``python``, ``python3`` or ``py`` (any letter case,
    optional trailing spaces/tabs, ``\\n`` or ``\\r\\n``) are recognised.
    Blocks tagged with other languages are left in place. The result is stripped
    of surrounding whitespace.
    """
    return re.sub(
        r"```(?:python3?|py)?[ \t]*\r?\n.*?```",
        "",
        text,
        flags=re.DOTALL | re.IGNORECASE,
    ).strip()

def execute_code_safely(code, *, allowed_modules=("math", "datetime")):
    """Execute ``code`` with a restricted set of built-ins and return what it printed.

    WARNING: this is NOT a security sandbox. Restricting ``__builtins__`` only
    hides the obvious names; Python code can still reach arbitrary objects via
    attribute introspection, and nothing bounds CPU time or memory (an
    infinite loop blocks the caller). In a production environment, run
    model-generated code in an isolated process/container or a dedicated
    code-execution service.

    Args:
        code: Python source to execute.
        allowed_modules: Top-level module names the code may ``import``
            (absolute imports only). Defaults to the two modules that are also
            pre-exposed as names.

    Returns:
        The stripped text written via ``print`` on success. If compiling or
        running the code raised an ``Exception``, returns
        ``"Error executing code:\\n<traceback>"`` (stripped).
    """
    output_buffer = io.StringIO()
    allowed = frozenset(allowed_modules)

    def restricted_print(*args, **kwargs):
        # Write to our own buffer instead of swapping the process-wide
        # sys.stdout, which is shared by every Streamlit session thread.
        if kwargs.get("file") is None:
            kwargs["file"] = output_buffer
        print(*args, **kwargs)

    def restricted_import(name, _globals=None, _locals=None, fromlist=(), level=0):
        # Called positionally by the `import` statement. Only whitelisted,
        # absolute, top-level imports are allowed, so that ordinary
        # `import math` / `from datetime import date` lines work.
        if level != 0 or name not in allowed:
            raise ImportError(f"Import of '{name}' is not allowed")
        return __import__(name, _globals, _locals, fromlist, level)

    safe_builtins = {
        # Original whitelist.
        'abs': abs,
        'all': all,
        'any': any,
        'len': len,
        'max': max,
        'min': min,
        'sum': sum,
        'range': range,
        'print': restricted_print,
        'str': str,
        'int': int,
        'float': float,
        'bool': bool,
        'list': list,
        'dict': dict,
        'set': set,
        'tuple': tuple,
        'enumerate': enumerate,
        'zip': zip,
        'math': __import__('math'),
        'datetime': __import__('datetime'),
        # Side-effect-free helpers that generated code routinely relies on.
        'sorted': sorted,
        'reversed': reversed,
        'round': round,
        'pow': pow,
        'divmod': divmod,
        'map': map,
        'filter': filter,
        'isinstance': isinstance,
        'chr': chr,
        'ord': ord,
        'repr': repr,
        'format': format,
        'frozenset': frozenset,
        'iter': iter,
        'next': next,
        'hex': hex,
        'bin': bin,
        # Exception classes so `raise` / `except` clauses resolve.
        'Exception': Exception,
        'ArithmeticError': ArithmeticError,
        'AssertionError': AssertionError,
        'ValueError': ValueError,
        'TypeError': TypeError,
        'KeyError': KeyError,
        'IndexError': IndexError,
        'ZeroDivisionError': ZeroDivisionError,
        'OverflowError': OverflowError,
        'StopIteration': StopIteration,
        # Required by `class` statements and `import` statements respectively.
        '__build_class__': __build_class__,
        '__import__': restricted_import,
    }

    # A single namespace (not separate globals/locals) gives module-like
    # scoping: top-level functions can call each other and themselves.
    # __name__ makes the usual `if __name__ == "__main__":` guard run.
    namespace = {"__builtins__": safe_builtins, "__name__": "__main__"}

    try:
        exec(code, namespace)
    except Exception:
        output = f"Error executing code:\n{traceback.format_exc()}"
    else:
        output = output_buffer.getvalue()

    return output.strip()

if __name__ == "__main__":
    # `streamlit run` executes this file as __main__, so main() runs on every rerun.
    main()

# requirements.txt (third-party packages imported above): streamlit, groq, python-dotenv