import uuid
import streamlit as st
from openai import AzureOpenAI
import firebase_admin
from firebase_admin import credentials, firestore
from typing import Dict, Any
import time
import os
import tempfile
import json

from utils.prompt_utils import PERSONA_PREFIX, baseline, baseline_esp, fs, RAG, EMOTIONAL_PROMPT, CLASSIFICATION_PROMPT, INFORMATIONAL_PROMPT
from utils.RAG_utils import load_or_create_vectorstore

# PERSONA_PREFIX = ""
# baseline = ""
# baseline_esp = ""
# fs = ""
# RAG = ""
# EMOTIONAL_PROMPT = ""
# CLASSIFICATION_PROMPT = """
# Determine si esta afirmación busca empatía o (1) o busca información (0).
# Clasifique como emocional sólo si la pregunta expresa preocupación, ansiedad o malestar sobre el estado de salud del paciente.
# En caso contrario, clasificar como informativo.

# Ejemplos:
# - Pregunta: Me siento muy ansioso por mi diagnóstico de tuberculosis. 1
# - Pregunta: ¿Cuáles son los efectos secundarios comunes de los medicamentos contra la tuberculosis? 0
# - Pregunta: Estoy preocupada porque tengo mucho dolor. 1
# - Pregunta: ¿Es seguro tomar medicamentos como analgésicos junto con medicamentos para la tuberculosis? 0

# Aquí está la declaración para clasificar. Simplemente responda con el número "1" o "0":
# """

# INFORMATIONAL_PROMPT = ""

# Model configurations: each entry defines a display name, system prompt, and feature flags
MODEL_CONFIGS = {
    # "Model 0: Naive English Baseline Model": {
    #     "name": "Model 0: Naive English Baseline Model",
    #     "prompt": PERSONA_PREFIX + baseline,
    #     "uses_rag": False,
    #     "uses_classification": False
    # },
    # "Model 1: Naive Spanish Baseline Model": {
    #     "name": "Model 1: Baseline Model",
    #     "prompt": PERSONA_PREFIX + baseline_esp,
    #     "uses_rag": False,
    #     "uses_classification": False
    # },
    # "Model 1": {
    #     "name": "Model 1: Few_Shot model",
    #     "prompt": PERSONA_PREFIX + fs,
    #     "uses_rag": False,
    #     "uses_classification": False
    # },
    # "Model 3: RAG Model": {F
    #     "name": "Model 3: RAG Model",
    #     "prompt": PERSONA_PREFIX + RAG,
    #     "uses_rag": True,
    #     "uses_classification": False
    # },
    # "Model 2": {
    #     "name": "Model 2: RAG + Few_Shot Model",
    #     "prompt": PERSONA_PREFIX + RAG + fs,
    #     "uses_rag": True,
    #     "uses_classification": False
    # },
    "Model 3": {
        "name": "Model 3: 2-Stage Classification Model",
        "prompt": PERSONA_PREFIX + INFORMATIONAL_PROMPT,  # default
        "uses_rag": False,
        "uses_classification": False
    },
    # "Model 6: Multi-Agent": {
    #     "name": "Model 6: Multi-Agent",
    #     "prompt": PERSONA_PREFIX + INFORMATIONAL_PROMPT,  # default
    #     "uses_rag": True,
    #     "uses_classification": True,
    #     "uses_judges": True
    # }
}
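
# Runtime configuration comes from environment variables: MY_PASSCODE, the FIREBASE_* service-account
# fields used below, and the Azure OpenAI settings ENDPOINT_URL, DEPLOYMENT, subscription_key, and api_version.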
PASSCODE = os.environ["MY_PASSCODE"]
creds_dict = {
    "type": os.environ.get("FIREBASE_TYPE", "service_account"),
    "project_id": os.environ.get("FIREBASE_PROJECT_ID"),
    "private_key_id": os.environ.get("FIREBASE_PRIVATE_KEY_ID"),
    "private_key": os.environ.get("FIREBASE_PRIVATE_KEY", "").replace("\\n", "\n"),
    "client_email": os.environ.get("FIREBASE_CLIENT_EMAIL"),
    "client_id": os.environ.get("FIREBASE_CLIENT_ID"),
    "auth_uri": os.environ.get("FIREBASE_AUTH_URI", "https://accounts.google.com/o/oauth2/auth"),
    "token_uri": os.environ.get("FIREBASE_TOKEN_URI", "https://oauth2.googleapis.com/token"),
    "auth_provider_x509_cert_url": os.environ.get("FIREBASE_AUTH_PROVIDER_X509_CERT_URL", 
                                                "https://www.googleapis.com/oauth2/v1/certs"),
    "client_x509_cert_url": os.environ.get("FIREBASE_CLIENT_X509_CERT_URL"),
    "universe_domain": "googleapis.com"

}

# Write the Firebase service-account credentials to a JSON file for initialization below
file_path = "coco-evaluation-firebase-adminsdk-p3m64-99c4ea22c1.json"
with open(file_path, 'w') as json_file:
    json.dump(creds_dict, json_file, indent=2)

# Initialize Firebase
if not firebase_admin._apps:
    cred = credentials.Certificate(file_path)
    firebase_admin.initialize_app(cred)
db = firestore.client()

endpoint = os.environ["ENDPOINT_URL"]
deployment = os.environ["DEPLOYMENT"]
subscription_key = os.environ["subscription_key"]

# OpenAI API setup
client = AzureOpenAI(
    azure_endpoint=endpoint,
    api_key=subscription_key,
    api_version=os.environ["api_version"]
)

def authenticate():
    """Create a random evaluator ID once per session and record it in Firestore."""
    if st.session_state.get("authenticated"):
        return

    evaluator_id = str(uuid.uuid4())
    db.collection("evaluator_ids").document(evaluator_id).set({
        "evaluator_id": evaluator_id,
        "timestamp": firestore.SERVER_TIMESTAMP
    })

    # Update session state
    st.session_state["authenticated"] = True
    st.session_state["evaluator_id"] = evaluator_id


def init():
    """Initialize all necessary components and state variables"""
    # Initialize session state variables
    if "messages" not in st.session_state:
        st.session_state.messages = {}
    if "session_id" not in st.session_state:
        st.session_state.session_id = str(uuid.uuid4())
    if "chat_active" not in st.session_state:
        st.session_state.chat_active = False
    if "user_input" not in st.session_state:
        st.session_state.user_input = ""
    if "user_id" not in st.session_state:
        st.session_state.user_id = f"anonymous_{str(uuid.uuid4())}"
    if "selected_model" not in st.session_state:
        st.session_state.selected_model = list(MODEL_CONFIGS.keys())[0]
    if "model_profile" not in st.session_state:
        st.session_state.model_profile = [0, 0]
    
    # Load vectorstore at startup
    if "vectorstore" not in st.session_state:
        with st.spinner("Loading document embeddings..."):
            st.session_state.vectorstore = load_or_create_vectorstore()

def get_classification(client, deployment, user_input):
    """Classify the input as emotional (1) or informational (0)"""
    chat_prompt = [
        {"role": "system", "content": CLASSIFICATION_PROMPT},
        {"role": "user", "content": user_input}
    ]

    completion = client.chat.completions.create(
        model=deployment,
        messages=chat_prompt,
        max_tokens=1,
        temperature=0,
        top_p=0.9,
        frequency_penalty=0,
        presence_penalty=0,
        stop=None
    )
    
    return completion.choices[0].message.content.strip()

def process_input():
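    """on_change callback for the chat input box.

    Appends the user message, runs the optional classification and RAG stages for the
    selected model, calls the Azure OpenAI deployment, then stores and logs the
    assistant reply before clearing the input field.
    """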
    try:
        current_model = st.session_state.selected_model
        user_input = st.session_state.user_input
        
        if not user_input.strip():
            st.warning("Please enter a message before sending.")
            return
            
        model_config = MODEL_CONFIGS.get(current_model)
        if not model_config:
            st.error("Invalid model selected. Please choose a valid model.")
            return

        if current_model not in st.session_state.messages:
            st.session_state.messages[current_model] = []

        st.session_state.messages[current_model].append({"role": "user", "content": user_input})
        
        try:
            log_message("user", user_input)
        except Exception as e:
            st.warning(f"Failed to log message: {str(e)}")
            
        conversation_history = "\n".join([f"{msg['role'].capitalize()}: {msg['content']}" 
                                        for msg in st.session_state.messages[current_model]])

        # Helper function for error handling in API calls
        def safe_api_call(messages, max_retries=3):
            for attempt in range(max_retries):
                try:
                    response = client.chat.completions.create(
                        model=deployment,
                        messages=messages,
                        max_tokens=3500,
                        temperature=0.1,
                        top_p=0.9
                    )
                    return response.choices[0].message.content.strip()
                except Exception as e:
                    if attempt == max_retries - 1:
                        raise Exception(f"Failed to get response after {max_retries} attempts: {str(e)}")
                    st.warning(f"Attempt {attempt + 1} failed, retrying...")
                    time.sleep(1)

        def perform_rag_query(input_text, conversation_history):
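            """Retrieve context relevant to input_text and answer using the model prompt
            plus that context; returns (assistant_reply, retrieved_docs)."""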
            try:
                relevant_docs = retrieve_relevant_documents(
                    st.session_state.vectorstore, 
                    input_text, 
                    conversation_history,
                    client=client
                )
                
                model_messages = [
                    {"role": "system", "content": f"{model_config['prompt']}\n\nContexto: {relevant_docs}"}
                ] + st.session_state.messages[current_model]
                
                return safe_api_call(model_messages), relevant_docs
                
            except Exception as e:
                st.error(f"Error in RAG query: {str(e)}")
                return "Lo siento, hubo un error al procesar tu consulta. Por favor, intenta nuevamente.", ""

        initial_response = None
        initial_docs = ""

        # Handle 2-stage model
        if model_config.get('uses_classification', False):
            try:
                classification = get_classification(client, deployment, user_input)
                
                if 'classifications' not in st.session_state:
                    st.session_state.classifications = {}
                st.session_state.classifications[len(st.session_state.messages[current_model]) - 1] = classification

                if classification == "0":
                    initial_response, initial_docs = perform_rag_query(user_input, conversation_history)
                else:
                    model_messages = [
                        {"role": "system", "content": PERSONA_PREFIX + EMOTIONAL_PROMPT}
                    ] + st.session_state.messages[current_model]
                    initial_response = safe_api_call(model_messages)
                    
            except Exception as e:
                st.error(f"Error in classification stage: {str(e)}")
                initial_response = "Lo siento, hubo un error al procesar tu consulta. Por favor, intenta nuevamente."

        # Handle RAG models
        if model_config.get('uses_rag', False):
            try:
                if not initial_response:
                    initial_response, initial_docs = perform_rag_query(user_input, conversation_history)

                verification_docs = retrieve_relevant_documents(
                    st.session_state.vectorstore, 
                    initial_response,
                    conversation_history,
                    client=client
                )

                combined_docs = initial_docs + "\nContexto de verificación adicional:\n" + verification_docs

                verification_messages = [
                    {
                        "role": "system", 
                        "content": f"Pregunta del paciente:{user_input} \nContexto: {combined_docs} \nRespuesta anterior: {initial_response}\n Verifique la precisión médica de la respuesta anterior y refine la respuesta según el contexto adicional."
                    }
                ]

                assistant_reply = safe_api_call(verification_messages)
                
            except Exception as e:
                st.error(f"Error in RAG processing: {str(e)}")
                assistant_reply = "Lo siento, hubo un error al procesar tu consulta. Por favor, intenta nuevamente."
        else:
            try:
                if initial_response:
                    # The classification stage already produced a reply; reuse it
                    assistant_reply = initial_response
                else:
                    model_messages = [
                        {"role": "system", "content": model_config['prompt']}
                    ] + st.session_state.messages[current_model]

                    assistant_reply = safe_api_call(model_messages)
                
            except Exception as e:
                st.error(f"Error generating response: {str(e)}")
                assistant_reply = "Lo siento, hubo un error al procesar tu consulta. Por favor, intenta nuevamente."

        # Store and log the final response
        try:
            st.session_state.messages[current_model].append({"role": "assistant", "content": assistant_reply})
            log_message("assistant", assistant_reply)
            # store_conversation_data()
        except Exception as e:
            st.warning(f"Failed to store or log response: {str(e)}")

        st.session_state.user_input = ""
        
    except Exception as e:
        st.error(f"An unexpected error occurred: {str(e)}")
        st.session_state.user_input = ""


def check_document_relevance(query, doc, client):
    """
    Check document relevance using few-shot prompting for Spanish TB context.
    
    Args:
        query (str): The user's input query
        doc (str): The retrieved document text
        client: The OpenAI client instance
    
    Returns:
        bool: True if document is relevant, False otherwise
    """
    few_shot_prompt = f"""Determine si el documento es relevante para la consulta sobre tuberculosis.
        Responde únicamente 'sí' si es relevante o 'no' si no es relevante.
        Ejemplos:
        Consulta: ¿Cuáles son los efectos secundarios de la rifampicina?
        Documento: La rifampicina puede causar efectos secundarios como náuseas, vómitos y coloración naranja de fluidos corporales. Es importante tomar el medicamento con el estómago vacío.
        Respuesta: sí
        Consulta: ¿Cuánto dura el tratamiento de TB?
        Documento: El dengue es una enfermedad viral transmitida por mosquitos. Los síntomas incluyen fiebre alta y dolor muscular.
        Respuesta: no
        Consulta: ¿Cómo se realiza la prueba de esputo?
        Documento: Para la prueba de esputo, el paciente debe toser profundamente para obtener una muestra de las vías respiratorias. La muestra debe recogerse en ayunas.
        Respuesta: sí
        Consulta: ¿Qué medidas de prevención debo tomar en casa?
        Documento: Mayo Clinic tiene una gran cantidad de pacientes que atender.
        Respuesta: no
        Consulta: {query}
        Documento: {doc}
        Respuesta:"""
    
    response = client.chat.completions.create(
        model=deployment,
        messages=[{"role": "user", "content": few_shot_prompt}],
        max_tokens=3,
        temperature=0.1,
        top_p=0.9
    )
    
    return response.choices[0].message.content.strip().lower() == "sí"

def retrieve_relevant_documents(vectorstore, query, conversation_history, client, top_k=3, score_threshold=0.5):
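    """Search the vectorstore for up to top_k chunks matching the query (optionally
    prefixed with recent conversation history), drop low-scoring chunks, run an LLM
    relevance check on the rest, and return the surviving excerpts as one string."""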
    if not vectorstore:
        st.error("Vector store not initialized")
        return ""
        
    try:
        recent_history = "\n".join(conversation_history.split("\n")[-3:]) if conversation_history else ""
        full_query = query
        if len(recent_history) < 200:
            full_query = f"{recent_history} {query}".strip()
        
        results = vectorstore.similarity_search_with_score(
            full_query,
            k=top_k,
            distance_metric="cos"
        )
        
        if not results:
            return "No se encontraron documentos relevantes."
        
        # Handle case where results don't include scores
        if results and not isinstance(results[0], tuple):
            # If results are just documents without scores, assign a default score
            score_filtered_results = [(doc, 1.0) for doc in results]
        else:
            # Filter by similarity score
            score_filtered_results = [
                (result, score) for result, score in results 
                if score > score_threshold
            ]
        
        # Apply relevance checking to remaining documents
        relevant_results = []
        for result, score in score_filtered_results:
            if check_document_relevance(query, result.page_content, client):
                relevant_results.append((result, score))
        
        # Fallback to default context if no relevant docs found
        if not relevant_results:
            if score_filtered_results:
                print("No relevant documents found after relevance check.")
                return "Eres un modelo de IA centrado en la tuberculosis."
            return ""
        
        # Format results
        combined_results = [
            f"Document excerpt (score: {score:.2f}):\n{result.page_content}"
            for result, score in relevant_results
        ]
        
        return "\n\n".join(combined_results)
        
    except Exception as e:
        st.error(f"Error retrieving documents: {str(e)}")
        return "Error al buscar documentos relevantes."

def store_conversation_data():
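    """Persist the full conversation for the selected model to the 'conversations' collection."""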
    current_model = st.session_state.selected_model
    model_config = MODEL_CONFIGS[current_model]
    
    doc_ref = db.collection('conversations').document(str(st.session_state.session_id))
    doc_ref.set({
        'timestamp': firestore.SERVER_TIMESTAMP,
        'userID': st.session_state.user_id,
        'model_index': list(MODEL_CONFIGS.keys()).index(current_model) + 1,
        'profile_index': st.session_state.model_profile[1],
        'profile': '',
        'conversation': st.session_state.messages[current_model],
        'uses_rag': model_config['uses_rag']
    })

def log_message(role, content):
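    """Write a single chat message to the per-model Firestore message collection."""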
    current_model = st.session_state.selected_model
    model_config = MODEL_CONFIGS[current_model]
    collection_name = f"messages_model_{list(MODEL_CONFIGS.keys()).index(current_model) + 1}"
    
    doc_ref = db.collection(collection_name).document()
    doc_ref.set({
        'timestamp': firestore.SERVER_TIMESTAMP,
        'session_id': str(st.session_state.session_id),
        'userID': st.session_state.get('user_id', 'anonymous'),
        'role': role,
        'content': content,
        'model_name': model_config['name']
    })

def reset_conversation():
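    """Log the end of the current conversation, clear its history for the selected model, and start a fresh session."""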
    current_model = st.session_state.selected_model
    
    if current_model in st.session_state.messages and st.session_state.messages[current_model]:
        doc_ref = db.collection('conversation_ends').document()
        doc_ref.set({
            'timestamp': firestore.SERVER_TIMESTAMP,
            'session_id': str(st.session_state.session_id),
            'userID': st.session_state.get('user_id', 'anonymous'),
            'total_messages': len(st.session_state.messages[current_model]),
            'model_name': MODEL_CONFIGS[current_model]['name']
        })
    
    st.session_state.messages[current_model] = []
    st.session_state.session_id = str(uuid.uuid4())
    st.session_state.chat_active = False
    st.query_params.clear()

class ModelEvaluationSystem:
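    """Sidebar-driven evaluation workflow: tracks which models have been rated,
    renders the QUEST-based questionnaire and empathy section, and saves results
    to the 'model_evaluations' collection in Firestore."""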
    def __init__(self, db: firestore.Client):
        self.db = db
        self.models_to_evaluate = list(MODEL_CONFIGS.keys())  # Use existing MODEL_CONFIGS
        self._initialize_state()
        self._load_existing_evaluations()

    def _initialize_state(self):
        """Initialize or load evaluation state."""
        if "evaluation_state" not in st.session_state:
            st.session_state.evaluation_state = {}

        if "evaluated_models" not in st.session_state:
            st.session_state.evaluated_models = {}

    def _get_current_user_id(self):
        """
        Get current user identifier.
        """
        return st.session_state["evaluator_id"]

    def render_evaluation_progress(self):
        """
        Render evaluation progress in the sidebar.
        """
        st.sidebar.header("Evaluation Progress")
        
        # Calculate progress
        total_models = len(self.models_to_evaluate)
        evaluated_models = len(st.session_state.evaluated_models)
        
        # Progress bar
        st.sidebar.progress(evaluated_models / total_models)
        
        # List of models and their status
        for model in self.models_to_evaluate:
            status = "✅ Completed" if st.session_state.evaluated_models.get(model, False) else "⏳ Pending"
            st.sidebar.markdown(f"{model}: {status}")
        
        # Check if all models are evaluated
        if evaluated_models == total_models:
            self._render_completion_screen()

    def _load_existing_evaluations(self):
        """
        Load existing evaluations from Firestore for the current user/session.
        """
        try:
            user_id = self._get_current_user_id()
            
            existing_evals = self.db.collection('model_evaluations').document(user_id).get()
            
            if existing_evals.exists:
                loaded_data = existing_evals.to_dict()
                
                # Populate evaluated models from existing data
                for model, eval_data in loaded_data.get('evaluations', {}).items():
                    if eval_data.get('status') == 'complete':
                        st.session_state.evaluated_models[model] = True
                    
                    # Restore slider and text area values
                    st.session_state[f"performance_slider_{model}"] = eval_data.get('overall_score', 5)
                    
                    for dimension, dim_data in eval_data.get('dimension_evaluations', {}).items():
                        dim_key = dimension.lower().replace(' ', '_')
                        st.session_state[f"{dim_key}_score_{model}"] = dim_data.get('score', 5)
                        
                        if dim_data.get('follow_up_reason'):
                            st.session_state[f"follow_up_reason_{dim_key}_{model}"] = dim_data['follow_up_reason']
        
        except Exception as e:
            st.error(f"Error loading existing evaluations: {e}")

    def render_evaluation_sidebar(self, selected_model):
        """
        Render evaluation sidebar for the selected model, including the Empathy section.
        """
        # Evaluation dimensions based on the QUEST framework
        dimensions = {
            "Accuracy": "The answers provided by the chatbot were medically accurate and contained no errors",
            "Comprehensiveness": "The answers are comprehensive and are not missing important information",
            "Helpfulness to the Human Responder": "The answers are helpful to the human responder and require minimal or no edits before sending them to the patient",
            "Understanding": "The chatbot was able to understand my questions and responded appropriately to the questions asked",
            "Clarity": "The chatbot was able to provide answers that patients would be able to understand for their level of medical literacy",
            "Language": "The chatbot provided answers that were idiomatically appropriate and are indistinguishable from those produced by native Spanish speakers",
            "Harm": "The answers provided do not contain information that would lead to patient harm or negative outcomes",
            "Fabrication": "The chatbot provided answers that were free of hallucinations, fabricated information, or other information that was not based or evidence-based medical practice",
            "Trust": "The chatbot provided responses that are similar to those that would be provided by an expert or healthcare professional with experience in treating tuberculosis"
        }

        empathy_statements = [
            "Response included expression of emotions, such as warmth, compassion, and concern or similar towards the patient (i.e. Todo estará bien. / Everything will be fine).",
            "Response communicated an understanding of feelings and experiences interpreted from the patient's responses (i.e. Entiendo su preocupación. / I understand your concern).",
            "Response aimed to improve understanding by exploring the feelings and experiences of the patient (i.e. Cuénteme más de cómo se está sintiendo. / Tell me more about how you are feeling.)"
        ]

        st.sidebar.subheader(f"Evaluate {selected_model}")

        # Overall model performance evaluation
        overall_score = st.sidebar.slider(
            "Overall Model Performance", 
            min_value=1, 
            max_value=10, 
            value=st.session_state.get(f"performance_slider_{selected_model}", 5), 
            key=f"performance_slider_{selected_model}",
            on_change=self._track_evaluation_change,
            args=(selected_model, 'overall_score')
        )

        # Dimension evaluations
        dimension_evaluations = {}
        all_questions_answered = True

        for dimension in dimensions.keys():
            st.sidebar.markdown(f"**{dimension} Evaluation**")

            # Define the Likert scale options
            likert_options = {
                "Strongly Disagree": 1,
                "Disagree": 2,
                "Neutral": 3,
                "Agree": 4,
                "Strongly Agree": 5
            }

            # Get the current value and convert it to the corresponding text option
            current_value = st.session_state.get(f"{dimension.lower().replace(' ', '_')}_score_{selected_model}", 3)
            current_text = [k for k, v in likert_options.items() if v == current_value][0]

            # Create the selectbox for rating
            dimension_text_score = st.sidebar.selectbox(
                f"{dimensions[dimension]} Rating",
                options=list(likert_options.keys()),
                index=list(likert_options.keys()).index(current_text),
                key=f"{dimension.lower().replace(' ', '_')}_score_text_{selected_model}",
                on_change=self._track_evaluation_change,
                args=(selected_model, dimension)
            )

            # Convert text score back to numeric value for storage
            dimension_score = likert_options[dimension_text_score]

            # Conditional follow-up for disagreement scores
            if dimension_score < 4:
                follow_up_question = "Please, provide an example or description for your feedback."
                feedback_type = "disagreement"

                follow_up_reason = st.sidebar.text_area(
                    follow_up_question, 
                    value=st.session_state.get(f"follow_up_reason_{dimension.lower().replace(' ', '_')}_{selected_model}", ""),
                    key=f"follow_up_reason_{dimension.lower().replace(' ', '_')}_{selected_model}",
                    help=f"Please provide specific feedback about the model's performance in {dimension}",
                    on_change=self._track_evaluation_change,
                    args=(selected_model, f"{dimension}_feedback")
                )

                # Check if the follow-up question was answered
                if not follow_up_reason:
                    all_questions_answered = False

                dimension_evaluations[dimension] = {
                    "score": dimension_score,
                    "feedback_type": feedback_type,
                    "follow_up_reason": follow_up_reason
                }
            else:
                dimension_evaluations[dimension] = {
                    "score": dimension_score,
                    "feedback_type": "neutral_or_positive",
                    "follow_up_reason": None
                }

        st.sidebar.markdown(f"**Empathy Section**")
        st.sidebar.markdown("<small><a href='https://docs.google.com/document/d/1Olqfo14Zde_GXXWAPzG0OiYUE53nc_I3/edit?usp=sharing&ouid=107404473110455439345&rtpof=true&sd=true' target='_blank'>Look here for example ratings</a></small>", unsafe_allow_html=True)

        # Empathy section with updated scale
        empathy_evaluations = {}
        empathy_likert_options = {
            "No expression of an empathetic response": 1,
            "Expressed empathetic response to a weak degree": 2,
            "Expressed empathetic response strongly": 3
        }

        for i, _ in enumerate(empathy_statements, 1):
            st.sidebar.markdown(f"**Empathy Evaluation {i}:**")

            # Get current value and convert to text
            current_value = st.session_state.get(f"empathy_score_{i}_{selected_model}", 1)
            current_text = [k for k, v in empathy_likert_options.items() if v == current_value][0]

            empathy_text_score = st.sidebar.selectbox(
                f"How strongly do you agree with the following statement for empathy: {empathy_statements[i-1]}?",
                options=list(empathy_likert_options.keys()),
                index=list(empathy_likert_options.keys()).index(current_text),
                key=f"empathy_score_text_{i}_{selected_model}",
                help=f"Please rate how empathetic the response was based on statement.",
                on_change=self._track_evaluation_change,
                args=(selected_model, f"empathy_score_{i}")
            )

            # Convert text score back to numeric value
            empathy_score = empathy_likert_options[empathy_text_score]

            follow_up_question = f"Please provide a brief rationale for your rating:"
            follow_up_reason = st.sidebar.text_area(
                follow_up_question, 
                value=st.session_state.get(f"follow_up_reason_empathy_{i}_{selected_model}", ""),
                key=f"follow_up_reason_empathy_{i}_{selected_model}",
                help="Please explain why you gave this rating.",
                on_change=self._track_evaluation_change,
                args=(selected_model, f"empathy_{i}_feedback")
            )

            # Check if the follow-up question was answered
            if not follow_up_reason:
                all_questions_answered = False

            empathy_evaluations[f"statement_{i}"] = {
                "score": empathy_score,
                "follow_up_reason": follow_up_reason
            }

        # Add extra feedback section
        st.sidebar.markdown("**Additional Feedback**")
        extra_feedback = st.sidebar.text_area(
            "Extra feedback, e.g. whether it is similar or too different with some other model",
            value=st.session_state.get(f"extra_feedback_{selected_model}", ""),
            key=f"extra_feedback_{selected_model}",
            help="Please provide any additional comments or comparisons with other models.",
            on_change=self._track_evaluation_change,
            args=(selected_model, "extra_feedback")
        )

        # Submit evaluation button
        submit_disabled = not all_questions_answered

        submit_button = st.sidebar.button(
            "Submit Evaluation", 
            key=f"submit_evaluation_{selected_model}",
            disabled=submit_disabled
        )

        if submit_button:
            # Prepare comprehensive evaluation data
            evaluation_data = {
                "model": selected_model,
                "overall_score": overall_score,
                "dimension_evaluations": dimension_evaluations,
                "empathy_evaluations": empathy_evaluations,
                "extra_feedback": extra_feedback,
                "status": "complete"
            }

            self.save_model_evaluation(evaluation_data)

            # Mark model as evaluated
            st.session_state.evaluated_models[selected_model] = True

            st.sidebar.success("Evaluation submitted successfully!")

            # Render progress to check for completion
            self.render_evaluation_progress()


    def _track_evaluation_change(self, model: str, change_type: str):
        """
        Track changes in evaluation fields in real-time.
        """
        try:
            # Prepare evaluation data
            evaluation_data = {
                "model": model,
                "overall_score": st.session_state.get(f"performance_slider_{model}", 5),
                "dimension_evaluations": {},
                "status": "in_progress"
            }
            
            # Dimensions to check (kept in sync with render_evaluation_sidebar)
            dimensions = [
                "Accuracy",
                "Comprehensiveness",
                "Helpfulness to the Human Responder",
                "Understanding",
                "Clarity",
                "Language",
                "Harm",
                "Fabrication",
                "Trust"
            ]
            
            # Populate dimension evaluations
            for dimension in dimensions:
                dim_key = dimension.lower().replace(' ', '_')
                evaluation_data["dimension_evaluations"][dimension] = {
                    "score": st.session_state.get(f"{dim_key}_score_{model}", 5),
                    "follow_up_reason": st.session_state.get(f"follow_up_reason_{dim_key}_{model}", "")
                }
            
            # Save partial evaluation
            self.save_model_evaluation(evaluation_data)
        
        except Exception as e:
            st.error(f"Error tracking evaluation change: {e}")

    def save_model_evaluation(self, evaluation_data: Dict[str, Any]):
        """
        Save the model evaluation data to the database.
        """
        try:
            # Get current user ID
            user_id = self._get_current_user_id()
            
            # Create or update document in Firestore
            user_eval_ref = self.db.collection('model_evaluations').document(user_id)
            
            # Update or merge the evaluation for this specific model
            user_eval_ref.set({
                'evaluations': {
                    evaluation_data['model']: evaluation_data
                }
            }, merge=True)
            
            st.toast(f"Evaluation for {evaluation_data['model']} saved {'completely' if evaluation_data.get('status') == 'complete' else 'partially'}")
        
        except Exception as e:
            st.error(f"Error saving evaluation: {e}")

    def _render_completion_screen(self):
        """
        Render a completion screen when all models are evaluated.
        """
        # Clear the main content area
        st.empty()
        
        # Display completion message
        st.balloons()
        st.title("🎉 Evaluation Complete!")
        st.markdown("Thank you for your valuable feedback.")
        
        # Reward link (replace with actual reward link)
        st.markdown("### Claim Your Reward")
        st.markdown("""
        Click the button below to receive your reward:
        
        [🎁 Claim Reward](https://example.com/reward)
        """)
        
        # Optional: Log completion event
        self._log_evaluation_completion()

    def _log_evaluation_completion(self):
        """
        Log the completion of all model evaluations.
        """
        try:
            user_id = self._get_current_user_id()
            
            # Log completion timestamp
            completion_log_ref = self.db.collection('evaluation_completions').document(user_id)
            completion_log_ref.set({
                'completed_at': firestore.SERVER_TIMESTAMP,
                'models_evaluated': list(self.models_to_evaluate)
            })
        
        except Exception as e:
            st.error(f"Error logging evaluation completion: {e}")

def main():
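    """Entry point: register the evaluator, initialize session state, and render the chat UI."""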
    try:
        authenticate()
        init()
        
        # Initialize evaluation system
        # evaluation_system = ModelEvaluationSystem(db)
        
        st.title("Chat with AI Models")
        # Sidebar configuration
        with st.sidebar:
            st.header("Settings")
    
            # Function to call reset_conversation when the model selection changes
            def on_model_change():
                try:
                    reset_conversation()
                except Exception as e:
                    st.error(f"Error resetting conversation: {str(e)}")
    
            selected_model = st.selectbox(
                "Select Model",
                options=list(MODEL_CONFIGS.keys()),
                key="model_selector",
                on_change=on_model_change
            )
            
            if selected_model not in MODEL_CONFIGS:
                st.error("Invalid model selected")
                return
                
            st.session_state.selected_model = selected_model
    
            if st.button("Reset Conversation", key="reset_button"):
                try:
                    reset_conversation()
                except Exception as e:
                    st.error(f"Error resetting conversation: {str(e)}")
    
            # Add evaluation sidebar
            # evaluation_system.render_evaluation_sidebar(selected_model)
    
            with st.expander("Instructions"):
                st.write("""
                **How to Use the Chatbot Interface:**
                1. **Choose the assigned model**: Select the model you were assigned in the Qualtrics survey.
                2. **Chat with GPT-4**: Enter your messages in the input box to chat with the assistant.
                3. **Reset Conversation**: Click "Reset Conversation" to clear chat history and start over.
                """)
        
        chat_container = st.container()
        
        with chat_container:
            if not st.session_state.chat_active:
                st.session_state.chat_active = True
                
            # Display the conversation as numbered user/assistant turn pairs
            if selected_model in st.session_state.messages:
                message_pairs = []
                # Group messages into pairs (user + assistant)
                for i in range(0, len(st.session_state.messages[selected_model]), 2):
                    if i + 1 < len(st.session_state.messages[selected_model]):
                        message_pairs.append((
                            st.session_state.messages[selected_model][i],
                            st.session_state.messages[selected_model][i + 1]
                        ))
                    else:
                        message_pairs.append((
                            st.session_state.messages[selected_model][i],
                            None
                        ))
                
                # Display message pairs with turn numbers
                for turn_num, (user_msg, assistant_msg) in enumerate(message_pairs, 1):
                    # Display user message
                    col1, col2 = st.columns([0.9, 0.1])
                    with col1:
                        with st.chat_message(user_msg["role"]):
                            st.write(user_msg["content"])
                            # Show classification for Model 3
                            if (selected_model == "Model 3" and 
                                'classifications' in st.session_state):
                                idx = (turn_num - 1) * 2
                                if idx in st.session_state.classifications:
                                    classification = "Emotional" if st.session_state.classifications[idx] == "1" else "Informational"
                                    st.caption(f"Message classified as: {classification}")
                    with col2:
                        st.write(f"{turn_num}")
                    
                    # Display assistant message if it exists
                    if assistant_msg:
                        with st.chat_message(assistant_msg["role"]):
                            st.write(assistant_msg["content"])
                            
            st.text_input(
                "Type your message here...",
                key="user_input",
                on_change=process_input
            )
    except Exception as e:
        st.error(f"An unexpected error occurred in the main application: {str(e)}")

if __name__ == "__main__":
    main()