# Source: tb_tst_ai/pages/FS_Model.py
# Last commit: 0aa30de ("Updating json", by Daniil); file size 39.9 kB.
import uuid
import streamlit as st
from openai import AzureOpenAI
import firebase_admin
from firebase_admin import credentials, firestore
from typing import Dict, Any
import time
import os
import tempfile
import json
from utils.prompt_utils import PERSONA_PREFIX, baseline, baseline_esp, fs, RAG, EMOTIONAL_PROMPT, CLASSIFICATION_PROMPT, INFORMATIONAL_PROMPT
from utils.RAG_utils import load_or_create_vectorstore
# PERSONA_PREFIX = ""
# baseline = ""
# baseline_esp = ""
# fs = ""
# RAG = ""
# EMOTIONAL_PROMPT = ""
# CLASSIFICATION_PROMPT = """
# Determine si esta afirmación busca empatía o (1) o busca información (0).
# Clasifique como emocional sólo si la pregunta expresa preocupación, ansiedad o malestar sobre el estado de salud del paciente.
# En caso contrario, clasificar como informativo.
# Ejemplos:
# - Pregunta: Me siento muy ansioso por mi diagnóstico de tuberculosis. 1
# - Pregunta: ¿Cuáles son los efectos secundarios comunes de los medicamentos contra la tuberculosis? 0
# - Pregunta: Estoy preocupada porque tengo mucho dolor. 1
# - Pregunta: ¿Es seguro tomar medicamentos como analgésicos junto con medicamentos para la tuberculosis? 0
# Aquí está la declaración para clasificar. Simplemente responda con el número "1" o "0":
# """
# INFORMATIONAL_PROMPT = ""
# Registry of selectable chat models, keyed by the label shown in the model
# selector. Each entry carries:
#   name                - model name stored alongside logged messages
#   prompt              - system prompt sent ahead of the conversation
#   uses_rag            - route replies through retrieval + verification
#   uses_classification - classify the message (emotional vs. informational)
#                         before answering
# Only "Model 1" is enabled; the other variants are kept commented out.
MODEL_CONFIGS = {
    # "Model 0: Naive English Baseline Model": {
    #     "name": "Model 0: Naive English Baseline Model",
    #     "prompt": PERSONA_PREFIX + baseline,
    #     "uses_rag": False,
    #     "uses_classification": False
    # },
    # "Model 1: Naive Spanish Baseline Model": {
    #     "name": "Model 1: Baseline Model",
    #     "prompt": PERSONA_PREFIX + baseline_esp,
    #     "uses_rag": False,
    #     "uses_classification": False
    # },
    "Model 1": {
        "name": "Model 1: Few_Shot model",
        "prompt": PERSONA_PREFIX + fs,
        "uses_rag": False,
        "uses_classification": False,
    },
    # "Model 3: RAG Model": {
    #     "name": "Model 3: RAG Model",
    #     "prompt": PERSONA_PREFIX + RAG,
    #     "uses_rag": True,
    #     "uses_classification": False
    # },
    # "Model 2: RAG + FS Model": {
    #     "name": "Model 2: RAG + Few_Shot Model",
    #     "prompt": PERSONA_PREFIX + RAG + fs,
    #     "uses_rag": True,
    #     "uses_classification": False
    # },
    # "Model 3": {
    #     "name": "Model 3: 2-Stage Classification Model",
    #     "prompt": PERSONA_PREFIX + INFORMATIONAL_PROMPT,  # default
    #     "uses_rag": False,
    #     "uses_classification": False
    # },
    # "Model 6: Multi-Agent": {
    #     "name": "Model 6: Multi-Agent",
    #     "prompt": PERSONA_PREFIX + INFORMATIONAL_PROMPT,  # default
    #     "uses_rag": True,
    #     "uses_classification": True,
    #     "uses_judges": True
    # }
}
# Access passcode; os.environ[...] raises KeyError at import time if unset.
PASSCODE = os.environ["MY_PASSCODE"]

# Firebase service-account description assembled from environment variables.
# The private key is stored with literal "\n" sequences, which are turned back
# into real newlines here.
creds_dict = {
    "type": os.environ.get("FIREBASE_TYPE", "service_account"),
    "project_id": os.environ.get("FIREBASE_PROJECT_ID"),
    "private_key_id": os.environ.get("FIREBASE_PRIVATE_KEY_ID"),
    "private_key": os.environ.get("FIREBASE_PRIVATE_KEY", "").replace("\\n", "\n"),
    "client_email": os.environ.get("FIREBASE_CLIENT_EMAIL"),
    "client_id": os.environ.get("FIREBASE_CLIENT_ID"),
    "auth_uri": os.environ.get("FIREBASE_AUTH_URI", "https://accounts.google.com/o/oauth2/auth"),
    "token_uri": os.environ.get("FIREBASE_TOKEN_URI", "https://oauth2.googleapis.com/token"),
    "auth_provider_x509_cert_url": os.environ.get("FIREBASE_AUTH_PROVIDER_X509_CERT_URL",
                                                  "https://www.googleapis.com/oauth2/v1/certs"),
    "client_x509_cert_url": os.environ.get("FIREBASE_CLIENT_X509_CERT_URL"),
    "universe_domain": "googleapis.com"
}

# Path of the JSON key file handed to the Firebase Admin SDK. Despite living
# in the working directory it is not a temporary file: it stays on disk.
file_path = "coco-evaluation-firebase-adminsdk-p3m64-99c4ea22c1.json"

# Initialize Firebase.
# Streamlit re-executes this script on every interaction, so the key file is
# only (re)written when the SDK actually needs initialising, instead of on
# every rerun. The file holds a private key, so it is created with
# owner-only permissions (the mode applies when the file is first created).
if not firebase_admin._apps:
    _creds_fd = os.open(file_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(_creds_fd, 'w') as json_file:
        json.dump(creds_dict, json_file, indent=2)
    cred = credentials.Certificate(file_path)
    firebase_admin.initialize_app(cred)
db = firestore.client()

# Azure OpenAI connection settings; all are required environment variables.
endpoint = os.environ["ENDPOINT_URL"]
deployment = os.environ["DEPLOYMENT"]
subscription_key = os.environ["subscription_key"]

# OpenAI API setup
client = AzureOpenAI(
    azure_endpoint=endpoint,
    api_key=subscription_key,
    api_version=os.environ["api_version"]
)
def authenticate():
    """Give the current Streamlit session an evaluator ID, creating it once.

    The first call generates a random UUID, records it in the Firestore
    collection ``evaluator_ids`` and stores it in ``st.session_state`` under
    ``"evaluator_id"`` together with ``"authenticated" = True``.

    Later calls in the same session return immediately. The page script is
    re-executed on every interaction, so without this guard each rerun would
    mint a fresh ID and write another Firestore document.
    """
    if st.session_state.get("authenticated") and st.session_state.get("evaluator_id"):
        return

    evaluator_id = str(uuid.uuid4())

    firestore.client().collection("evaluator_ids").document(evaluator_id).set({
        "evaluator_id": evaluator_id,
        "timestamp": firestore.SERVER_TIMESTAMP
    })

    # Update session state only after the ID has been registered.
    st.session_state["authenticated"] = True
    st.session_state["evaluator_id"] = evaluator_id
def init():
    """Prepare Firebase, session-state defaults and the document vector store.

    Every session-state entry is only created when missing, so values from
    earlier reruns of the same session are kept.
    """
    # Firebase may already have been set up at import time.
    if not firebase_admin._apps:
        firebase_admin.initialize_app(
            credentials.Certificate("coco-evaluation-firebase-adminsdk-p3m64-99c4ea22c1.json")
        )

    state = st.session_state

    # (key, factory) pairs; factories run only for keys that are absent.
    default_factories = (
        ("messages", dict),
        ("session_id", lambda: str(uuid.uuid4())),
        ("chat_active", lambda: False),
        ("user_input", str),
        ("user_id", lambda: f"anonymous_{str(uuid.uuid4())}"),
        ("selected_model", lambda: list(MODEL_CONFIGS.keys())[0]),
        ("model_profile", lambda: [0, 0]),
    )
    for key, make_default in default_factories:
        if key not in state:
            state[key] = make_default()

    # The vector store is loaded once per session and cached in session state.
    if "vectorstore" not in state:
        with st.spinner("Loading document embeddings..."):
            state.vectorstore = load_or_create_vectorstore()
def get_classification(client, deployment, user_input):
    """Ask the model to label ``user_input`` and return its stripped reply.

    The system message is ``CLASSIFICATION_PROMPT`` and the completion is
    limited to a single token, so the result is expected to be "1"
    (emotional) or "0" (informational).
    """
    request = dict(
        model=deployment,
        messages=[
            {"role": "system", "content": CLASSIFICATION_PROMPT},
            {"role": "user", "content": user_input},
        ],
        max_tokens=1,
        temperature=0,
        top_p=0.9,
        frequency_penalty=0,
        presence_penalty=0,
        stop=None,
    )
    completion = client.chat.completions.create(**request)
    first_choice = completion.choices[0]
    return first_choice.message.content.strip()
def process_input():
    """Handle a submitted chat message end to end.

    Used as the ``on_change`` callback of the chat text input. It reads the
    pending text from ``st.session_state.user_input``, appends it to the
    selected model's history, obtains a reply according to the model's entry
    in ``MODEL_CONFIGS``, appends and logs that reply, and clears the input.

    Reply strategy:

    * ``uses_classification``: the message is classified first; label "0"
      is answered through a retrieval-augmented query, any other label
      through the emotional prompt.
    * ``uses_rag``: a draft answer (from the classification stage, or a fresh
      retrieval-augmented query) is re-checked by a second completion that
      sees additional retrieved context.
    * otherwise: the draft produced by the classification stage is used when
      there is one; if not, a single completion with the model's system
      prompt is made.

    Failures are reported through Streamlit messages; nothing is raised to
    the caller and nothing is returned.
    """
    fallback_reply = "Lo siento, hubo un error al procesar tu consulta. Por favor, intenta nuevamente."
    try:
        current_model = st.session_state.selected_model
        user_input = st.session_state.user_input
        if not user_input.strip():
            st.warning("Please enter a message before sending.")
            return
        model_config = MODEL_CONFIGS.get(current_model)
        if not model_config:
            st.error("Invalid model selected. Please choose a valid model.")
            return
        if current_model not in st.session_state.messages:
            st.session_state.messages[current_model] = []
        st.session_state.messages[current_model].append({"role": "user", "content": user_input})
        # Logging is best effort: a Firestore problem must not block the chat.
        try:
            log_message("user", user_input)
        except Exception as e:
            st.warning(f"Failed to log message: {str(e)}")
        # Plain-text transcript (including the message just added) used to
        # give the retriever some conversational context.
        conversation_history = "\n".join([f"{msg['role'].capitalize()}: {msg['content']}"
                                          for msg in st.session_state.messages[current_model]])

        def safe_api_call(messages, max_retries=3):
            """Call the chat completion API, retrying after a one second pause."""
            for attempt in range(max_retries):
                try:
                    response = client.chat.completions.create(
                        model=deployment,
                        messages=messages,
                        max_tokens=3500,
                        temperature=0.1,
                        top_p=0.9
                    )
                    return response.choices[0].message.content.strip()
                except Exception as e:
                    if attempt == max_retries - 1:
                        # Keep the original failure attached as the cause.
                        raise Exception(
                            f"Failed to get response after {max_retries} attempts: {str(e)}"
                        ) from e
                    st.warning(f"Attempt {attempt + 1} failed, retrying...")
                    time.sleep(1)

        def perform_rag_query(input_text, conversation_history):
            """Answer with retrieved context; return ``(reply, context)``."""
            try:
                relevant_docs = retrieve_relevant_documents(
                    st.session_state.vectorstore,
                    input_text,
                    conversation_history,
                    client=client
                )
                model_messages = [
                    {"role": "system", "content": f"{model_config['prompt']}\n\nContexto: {relevant_docs}"}
                ] + st.session_state.messages[current_model]
                return safe_api_call(model_messages), relevant_docs
            except Exception as e:
                st.error(f"Error in RAG query: {str(e)}")
                return fallback_reply, ""

        initial_response = None
        initial_docs = ""

        # Handle 2-stage model
        if model_config.get('uses_classification', False):
            try:
                classification = get_classification(client, deployment, user_input)
                if 'classifications' not in st.session_state:
                    st.session_state.classifications = {}
                # Keyed by the index of the user message that was classified.
                st.session_state.classifications[len(st.session_state.messages[current_model]) - 1] = classification
                if classification == "0":
                    initial_response, initial_docs = perform_rag_query(user_input, conversation_history)
                else:
                    model_messages = [
                        {"role": "system", "content": PERSONA_PREFIX + EMOTIONAL_PROMPT}
                    ] + st.session_state.messages[current_model]
                    initial_response = safe_api_call(model_messages)
            except Exception as e:
                st.error(f"Error in classification stage: {str(e)}")
                initial_response = fallback_reply

        # Handle RAG models
        if model_config.get('uses_rag', False):
            try:
                if not initial_response:
                    initial_response, initial_docs = perform_rag_query(user_input, conversation_history)
                # Second retrieval pass, this time driven by the draft answer.
                verification_docs = retrieve_relevant_documents(
                    st.session_state.vectorstore,
                    initial_response,
                    conversation_history,
                    client=client
                )
                combined_docs = initial_docs + "\nContexto de verificación adicional:\n" + verification_docs
                verification_messages = [
                    {
                        "role": "system",
                        "content": f"Pregunta del paciente:{user_input} \nContexto: {combined_docs} \nRespuesta anterior: {initial_response}\n Verifique la precisión médica de la respuesta anterior y refine la respuesta según el contexto adicional."
                    }
                ]
                assistant_reply = safe_api_call(verification_messages)
            except Exception as e:
                st.error(f"Error in RAG processing: {str(e)}")
                assistant_reply = fallback_reply
        elif initial_response:
            # The classification stage already produced the reply. Without
            # this branch it would be thrown away and regenerated with the
            # default prompt, making the classification pointless.
            assistant_reply = initial_response
        else:
            try:
                model_messages = [
                    {"role": "system", "content": model_config['prompt']}
                ] + st.session_state.messages[current_model]
                assistant_reply = safe_api_call(model_messages)
            except Exception as e:
                st.error(f"Error generating response: {str(e)}")
                assistant_reply = fallback_reply

        # Store and log the final response
        try:
            st.session_state.messages[current_model].append({"role": "assistant", "content": assistant_reply})
            log_message("assistant", assistant_reply)
            # store_conversation_data()
        except Exception as e:
            st.warning(f"Failed to store or log response: {str(e)}")
        st.session_state.user_input = ""
    except Exception as e:
        st.error(f"An unexpected error occurred: {str(e)}")
        st.session_state.user_input = ""
def check_document_relevance(query, doc, client, model=None):
    """Ask the chat model whether ``doc`` is relevant to ``query``.

    A Spanish few-shot prompt about tuberculosis is sent and the model is
    told to answer only 'sí' or 'no'.

    Args:
        query (str): The user's input query.
        doc (str): The retrieved document text.
        client: Client exposing ``chat.completions.create``.
        model (str | None): Deployment to query; when omitted the module-level
            ``deployment`` is used.

    Returns:
        bool: True when the first word of the reply is "sí" (or the unaccented
        "si"), ignoring case and surrounding punctuation; False otherwise,
        including when the reply is empty or missing.
    """
    few_shot_prompt = f"""Determine si el documento es relevante para la consulta sobre tuberculosis.
Responde únicamente 'sí' si es relevante o 'no' si no es relevante.
Ejemplos:
Consulta: ¿Cuáles son los efectos secundarios de la rifampicina?
Documento: La rifampicina puede causar efectos secundarios como náuseas, vómitos y coloración naranja de fluidos corporales. Es importante tomar el medicamento con el estómago vacío.
Respuesta: sí
Consulta: ¿Cuánto dura el tratamiento de TB?
Documento: El dengue es una enfermedad viral transmitida por mosquitos. Los síntomas incluyen fiebre alta y dolor muscular.
Respuesta: no
Consulta: ¿Cómo se realiza la prueba de esputo?
Documento: Para la prueba de esputo, el paciente debe toser profundamente para obtener una muestra de las vías respiratorias. La muestra debe recogerse en ayunas.
Respuesta: sí
Consulta: ¿Qué medidas de prevención debo tomar en casa?
Documento: Mayo Clinic tiene una gran cantidad de pacientes que atender.
Respuesta: no
Consulta: {query}
Documento: {doc}
Respuesta:"""
    response = client.chat.completions.create(
        model=deployment if model is None else model,
        messages=[{"role": "user", "content": few_shot_prompt}],
        max_tokens=3,
        temperature=0.1,
        top_p=0.9
    )
    # The content may be None, and models often answer "Sí." or "si"; an
    # exact comparison with "sí" would wrongly discard those documents.
    answer = response.choices[0].message.content or ""
    words = answer.strip().lower().split()
    if not words:
        return False
    first_word = words[0].strip(".,;:!¡?¿'\"")
    return first_word in ("sí", "si")
def retrieve_relevant_documents(vectorstore, query, conversation_history, client, top_k=3, score_threshold=0.5):
    """Search the vector store and return the confirmed excerpts as one string.

    The search text is ``query``, prefixed with the last three lines of
    ``conversation_history`` when that tail is shorter than 200 characters.
    Hits are filtered by ``score > score_threshold`` (hits that come without
    a score get 1.0) and then by ``check_document_relevance``.

    Returns the excerpts joined by blank lines, or one of several fixed
    strings: "" when the store is missing or nothing passes the score filter,
    a "no documents" message when the search is empty, a generic
    tuberculosis context when every candidate fails the relevance check, and
    an error message when an exception occurs.
    """
    if not vectorstore:
        st.error("Vector store not initialized")
        return ""

    try:
        if conversation_history:
            history_tail = "\n".join(conversation_history.split("\n")[-3:])
        else:
            history_tail = ""

        # A short history tail is added to the query; a long one is left out.
        search_text = f"{history_tail} {query}".strip() if len(history_tail) < 200 else query

        hits = vectorstore.similarity_search_with_score(
            search_text,
            k=top_k,
            distance_metric="cos"
        )
        if not hits:
            return "No se encontraron documentos relevantes."

        if isinstance(hits[0], tuple):
            # Scored hits: keep those above the threshold.
            candidates = [(doc, score) for doc, score in hits if score > score_threshold]
        else:
            # Bare documents: give each a default score.
            candidates = [(doc, 1.0) for doc in hits]

        # Second filter: ask the model whether each candidate is relevant.
        confirmed = [
            (doc, score)
            for doc, score in candidates
            if check_document_relevance(query, doc.page_content, client)
        ]

        if confirmed:
            return "\n\n".join(
                f"Document excerpt (score: {score:.2f}):\n{doc.page_content}"
                for doc, score in confirmed
            )
        if candidates:
            # Candidates existed but none survived the relevance check.
            print("No relevant documents found after relevance check.")
            return "Eres un modelo de IA centrado en la tuberculosis."
        return ""
    except Exception as e:
        st.error(f"Error retrieving documents: {str(e)}")
        return "Error al buscar documentos relevantes."
def store_conversation_data():
    """Write the selected model's whole conversation to Firestore.

    The document lives in the ``conversations`` collection under the current
    session ID and is replaced on every call.
    """
    state = st.session_state
    model_key = state.selected_model
    config = MODEL_CONFIGS[model_key]
    # 1-based position of the model within MODEL_CONFIGS.
    model_position = list(MODEL_CONFIGS.keys()).index(model_key) + 1

    record = {
        'timestamp': firestore.SERVER_TIMESTAMP,
        'userID': state.user_id,
        'model_index': model_position,
        'profile_index': state.model_profile[1],
        'profile': '',
        'conversation': state.messages[model_key],
        'uses_rag': config['uses_rag']
    }
    db.collection('conversations').document(str(state.session_id)).set(record)
def log_message(role, content):
    """Append one chat message to the Firestore log of the selected model.

    Messages go to a collection named ``messages_model_<n>``, where ``n`` is
    the 1-based position of the model within ``MODEL_CONFIGS``; each message
    gets its own auto-generated document.
    """
    state = st.session_state
    model_key = state.selected_model
    config = MODEL_CONFIGS[model_key]
    model_position = list(MODEL_CONFIGS.keys()).index(model_key) + 1

    entry = {
        'timestamp': firestore.SERVER_TIMESTAMP,
        'session_id': str(state.session_id),
        'userID': state.get('user_id', 'anonymous'),
        'role': role,
        'content': content,
        'model_name': config['name']
    }
    db.collection(f"messages_model_{model_position}").document().set(entry)
def reset_conversation():
    """Close the current conversation and start a new, empty session.

    When the selected model has messages, a summary document is first written
    to the ``conversation_ends`` collection. The history is then emptied, a
    new session ID is generated, the chat is marked inactive and the URL
    query parameters are cleared.
    """
    state = st.session_state
    model_key = state.selected_model
    history = state.messages.get(model_key) if model_key in state.messages else None

    if history:
        summary = {
            'timestamp': firestore.SERVER_TIMESTAMP,
            'session_id': str(state.session_id),
            'userID': state.get('user_id', 'anonymous'),
            'total_messages': len(history),
            'model_name': MODEL_CONFIGS[model_key]['name']
        }
        db.collection('conversation_ends').document().set(summary)

    state.messages[model_key] = []
    state.session_id = str(uuid.uuid4())
    state.chat_active = False
    st.query_params.clear()
class ModelEvaluationSystem:
    """Sidebar workflow for rating each chat model and persisting the ratings.

    Ratings are kept in Firestore in the ``model_evaluations`` collection, one
    document per evaluator, with one entry per model under ``evaluations``.
    Widget values live in ``st.session_state`` under keys derived from the
    dimension name and the model label.
    """

    # Evaluation dimensions based on the QUEST framework: label -> statement.
    # Shared by the sidebar and by the partial-save callback so both always
    # work on the same set of dimensions.
    DIMENSIONS = {
        "Accuracy": "The answers provided by the chatbot were medically accurate and contained no errors",
        "Comprehensiveness": "The answers are comprehensive and are not missing important information",
        "Helpfulness to the Human Responder": "The answers are helpful to the human responder and require minimal or no edits before sending them to the patient",
        "Understanding": "The chatbot was able to understand my questions and responded appropriately to the questions asked",
        "Clarity": "The chatbot was able to provide answers that patients would be able to understand for their level of medical literacy",
        "Language": "The chatbot provided answers that were idiomatically appropriate and are indistinguishable from those produced by native Spanish speakers",
        "Harm": "The answers provided do not contain information that would lead to patient harm or negative outcomes",
        "Fabrication": "The chatbot provided answers that were free of hallucinations, fabricated information, or other information that was not based or evidence-based medical practice",
        "Trust": "The chatbot provided responses that are similar to those that would be provided by an expert or healthcare professional with experience in treating tuberculosis"
    }

    EMPATHY_STATEMENTS = [
        "Response included expression of emotions, such as warmth, compassion, and concern or similar towards the patient (i.e. Todo estará bien. / Everything will be fine).",
        "Response communicated an understanding of feelings and experiences interpreted from the patient's responses (i.e. Entiendo su preocupación. / I understand your concern).",
        "Response aimed to improve understanding by exploring the feelings and experiences of the patient (i.e. Cuénteme más de cómo se está sintiendo. / Tell me more about how you are feeling.)"
    ]

    # Likert scale shown for every dimension: label -> stored score.
    LIKERT_OPTIONS = {
        "Strongly Disagree": 1,
        "Disagree": 2,
        "Neutral": 3,
        "Agree": 4,
        "Strongly Agree": 5
    }

    # Three-point scale used for the empathy statements.
    EMPATHY_LIKERT_OPTIONS = {
        "No expression of an empathetic response": 1,
        "Expressed empathetic response to a weak degree": 2,
        "Expressed empathetic response strongly": 3
    }

    # Scores preselected when nothing has been stored yet.
    DEFAULT_DIMENSION_SCORE = 3
    DEFAULT_EMPATHY_SCORE = 1

    def __init__(self, db: "firestore.Client"):
        """Bind the Firestore client and restore any stored evaluations."""
        self.db = db
        self.models_to_evaluate = list(MODEL_CONFIGS.keys())  # Use existing MODEL_CONFIGS
        self._initialize_state()
        self._load_existing_evaluations()

    @staticmethod
    def _dimension_key(dimension: str) -> str:
        """Return the session-state key fragment for a dimension label."""
        return dimension.lower().replace(' ', '_')

    @staticmethod
    def _label_for_score(options: Dict[str, int], score: Any, default_score: int) -> str:
        """Return the option label for ``score``.

        Falls back to the label of ``default_score`` (or the first option)
        when ``score`` is not one of the option values, instead of failing on
        an unexpected stored value.
        """
        for label, value in options.items():
            if value == score:
                return label
        for label, value in options.items():
            if value == default_score:
                return label
        return next(iter(options))

    @staticmethod
    def _progress_fraction(completed: int, total: int) -> float:
        """Return ``completed / total`` clamped to [0.0, 1.0]; 0.0 if no total."""
        if total <= 0:
            return 0.0
        return min(1.0, max(0.0, completed / total))

    def _initialize_state(self):
        """Initialize or load evaluation state."""
        if "evaluation_state" not in st.session_state:
            st.session_state.evaluation_state = {}
        if "evaluated_models" not in st.session_state:
            st.session_state.evaluated_models = {}

    def _get_current_user_id(self):
        """Return the evaluator ID stored in the session (KeyError if absent)."""
        return st.session_state["evaluator_id"]

    def render_evaluation_progress(self):
        """Render evaluation progress in the sidebar.

        Shows a progress bar and a per-model status line, and switches to the
        completion screen once every configured model has been evaluated.
        """
        st.sidebar.header("Evaluation Progress")

        # Only configured models count, so stale entries cannot push the
        # fraction above 1.0, and an empty model list cannot divide by zero.
        total_models = len(self.models_to_evaluate)
        completed = sum(
            1 for model in self.models_to_evaluate
            if st.session_state.evaluated_models.get(model, False)
        )

        st.sidebar.progress(self._progress_fraction(completed, total_models))

        # List of models and their status
        for model in self.models_to_evaluate:
            status = "✅ Completed" if st.session_state.evaluated_models.get(model, False) else "⏳ Pending"
            st.sidebar.markdown(f"{model}: {status}")

        # Check if all models are evaluated
        if total_models and completed == total_models:
            self._render_completion_screen()

    def _load_existing_evaluations(self):
        """Load this evaluator's stored evaluations into session state.

        Only entries whose status is ``complete`` are restored: the model is
        marked as evaluated and its overall score, dimension scores and
        follow-up reasons are copied into the matching session-state keys.
        Errors are shown in the UI rather than raised.
        """
        try:
            user_id = self._get_current_user_id()
            existing_evals = self.db.collection('model_evaluations').document(user_id).get()
            if not existing_evals.exists:
                return

            loaded_data = existing_evals.to_dict()
            for model, eval_data in loaded_data.get('evaluations', {}).items():
                if eval_data.get('status') != 'complete':
                    continue
                st.session_state.evaluated_models[model] = True
                # Restore slider and text area values
                st.session_state[f"performance_slider_{model}"] = eval_data.get('overall_score', 5)
                for dimension, dim_data in eval_data.get('dimension_evaluations', {}).items():
                    dim_key = self._dimension_key(dimension)
                    st.session_state[f"{dim_key}_score_{model}"] = dim_data.get('score', 5)
                    if dim_data.get('follow_up_reason'):
                        st.session_state[f"follow_up_reason_{dim_key}_{model}"] = dim_data['follow_up_reason']
        except Exception as e:
            st.error(f"Error loading existing evaluations: {e}")

    def render_evaluation_sidebar(self, selected_model):
        """Render the evaluation form for ``selected_model`` in the sidebar.

        The form has an overall score slider, one Likert rating per dimension
        (ratings below "Agree" require a written reason), three empathy
        ratings that each require a rationale, and a free-text feedback box.
        The submit button stays disabled until every required text is filled.
        """
        st.sidebar.subheader(f"Evaluate {selected_model}")

        # Overall model performance evaluation
        overall_score = st.sidebar.slider(
            "Overall Model Performance",
            min_value=1,
            max_value=10,
            value=st.session_state.get(f"performance_slider_{selected_model}", 5),
            key=f"performance_slider_{selected_model}",
            on_change=self._track_evaluation_change,
            args=(selected_model, 'overall_score')
        )

        # Dimension evaluations
        dimension_evaluations = {}
        all_questions_answered = True
        likert_labels = list(self.LIKERT_OPTIONS.keys())

        for dimension, statement in self.DIMENSIONS.items():
            dim_key = self._dimension_key(dimension)
            st.sidebar.markdown(f"**{dimension} Evaluation**")

            # Preselect the stored score, if any, otherwise "Neutral".
            current_value = st.session_state.get(
                f"{dim_key}_score_{selected_model}", self.DEFAULT_DIMENSION_SCORE
            )
            current_text = self._label_for_score(
                self.LIKERT_OPTIONS, current_value, self.DEFAULT_DIMENSION_SCORE
            )

            dimension_text_score = st.sidebar.selectbox(
                f"{statement} Rating",
                options=likert_labels,
                index=likert_labels.index(current_text),
                key=f"{dim_key}_score_text_{selected_model}",
                on_change=self._track_evaluation_change,
                args=(selected_model, dimension)
            )

            # Convert text score back to numeric value for storage
            dimension_score = self.LIKERT_OPTIONS[dimension_text_score]

            # Conditional follow-up for disagreement scores
            if dimension_score < 4:
                follow_up_reason = st.sidebar.text_area(
                    "Please, provide an example or description for your feedback.",
                    value=st.session_state.get(f"follow_up_reason_{dim_key}_{selected_model}", ""),
                    key=f"follow_up_reason_{dim_key}_{selected_model}",
                    help=f"Please provide specific feedback about the model's performance in {dimension}",
                    on_change=self._track_evaluation_change,
                    args=(selected_model, f"{dimension}_feedback")
                )
                # Check if the follow-up question was answered
                if not follow_up_reason:
                    all_questions_answered = False
                dimension_evaluations[dimension] = {
                    "score": dimension_score,
                    "feedback_type": "disagreement",
                    "follow_up_reason": follow_up_reason
                }
            else:
                dimension_evaluations[dimension] = {
                    "score": dimension_score,
                    "feedback_type": "neutral_or_positive",
                    "follow_up_reason": None
                }

        st.sidebar.markdown("**Empathy Section**")
        st.sidebar.markdown("<small><a href='https://docs.google.com/document/d/1Olqfo14Zde_GXXWAPzG0OiYUE53nc_I3/edit?usp=sharing&ouid=107404473110455439345&rtpof=true&sd=true' target='_blank'>Look here for example ratings</a></small>", unsafe_allow_html=True)

        # Empathy section with its own three-point scale
        empathy_evaluations = {}
        empathy_labels = list(self.EMPATHY_LIKERT_OPTIONS.keys())

        for i, statement in enumerate(self.EMPATHY_STATEMENTS, 1):
            st.sidebar.markdown(f"**Empathy Evaluation {i}:**")

            current_value = st.session_state.get(
                f"empathy_score_{i}_{selected_model}", self.DEFAULT_EMPATHY_SCORE
            )
            current_text = self._label_for_score(
                self.EMPATHY_LIKERT_OPTIONS, current_value, self.DEFAULT_EMPATHY_SCORE
            )

            empathy_text_score = st.sidebar.selectbox(
                f"How strongly do you agree with the following statement for empathy: {statement}?",
                options=empathy_labels,
                index=empathy_labels.index(current_text),
                key=f"empathy_score_text_{i}_{selected_model}",
                help="Please rate how empathetic the response was based on statement.",
                on_change=self._track_evaluation_change,
                args=(selected_model, f"empathy_score_{i}")
            )

            # Convert text score back to numeric value
            empathy_score = self.EMPATHY_LIKERT_OPTIONS[empathy_text_score]

            follow_up_reason = st.sidebar.text_area(
                "Please provide a brief rationale for your rating:",
                value=st.session_state.get(f"follow_up_reason_empathy_{i}_{selected_model}", ""),
                key=f"follow_up_reason_empathy_{i}_{selected_model}",
                help="Please explain why you gave this rating.",
                on_change=self._track_evaluation_change,
                args=(selected_model, f"empathy_{i}_feedback")
            )

            # A rationale is required for every empathy rating.
            if not follow_up_reason:
                all_questions_answered = False

            empathy_evaluations[f"statement_{i}"] = {
                "score": empathy_score,
                "follow_up_reason": follow_up_reason
            }

        # Add extra feedback section
        st.sidebar.markdown("**Additional Feedback**")
        extra_feedback = st.sidebar.text_area(
            "Extra feedback, e.g. whether it is similar or too different with some other model",
            value=st.session_state.get(f"extra_feedback_{selected_model}", ""),
            key=f"extra_feedback_{selected_model}",
            help="Please provide any additional comments or comparisons with other models.",
            on_change=self._track_evaluation_change,
            args=(selected_model, "extra_feedback")
        )

        # Submit evaluation button
        submit_button = st.sidebar.button(
            "Submit Evaluation",
            key=f"submit_evaluation_{selected_model}",
            disabled=not all_questions_answered
        )

        if submit_button:
            # Prepare comprehensive evaluation data
            evaluation_data = {
                "model": selected_model,
                "overall_score": overall_score,
                "dimension_evaluations": dimension_evaluations,
                "empathy_evaluations": empathy_evaluations,
                "extra_feedback": extra_feedback,
                "status": "complete"
            }

            # Only report success when the save actually went through.
            if self.save_model_evaluation(evaluation_data):
                st.session_state.evaluated_models[selected_model] = True
                st.sidebar.success("Evaluation submitted successfully!")

            # Render progress to check for completion
            self.render_evaluation_progress()

    def _track_evaluation_change(self, model: str, change_type: str):
        """Save the current, possibly incomplete, form values for ``model``.

        Used as the ``on_change`` callback of every evaluation widget.
        ``change_type`` names the widget that changed; it is accepted for the
        callback signature and is not otherwise used. Values are read from the
        session-state keys the widgets themselves write to. The record is
        stored with status ``in_progress``.
        """
        try:
            state = st.session_state
            evaluation_data = {
                "model": model,
                "overall_score": state.get(f"performance_slider_{model}", 5),
                "dimension_evaluations": {},
                "empathy_evaluations": {},
                "extra_feedback": state.get(f"extra_feedback_{model}", ""),
                "status": "in_progress"
            }

            for dimension in self.DIMENSIONS:
                dim_key = self._dimension_key(dimension)
                # The selectbox stores its label; fall back to a restored
                # numeric score when the widget has not been rendered yet.
                label = state.get(f"{dim_key}_score_text_{model}")
                if label in self.LIKERT_OPTIONS:
                    score = self.LIKERT_OPTIONS[label]
                else:
                    score = state.get(f"{dim_key}_score_{model}", self.DEFAULT_DIMENSION_SCORE)
                evaluation_data["dimension_evaluations"][dimension] = {
                    "score": score,
                    "follow_up_reason": state.get(f"follow_up_reason_{dim_key}_{model}", "")
                }

            for i, _ in enumerate(self.EMPATHY_STATEMENTS, 1):
                label = state.get(f"empathy_score_text_{i}_{model}")
                if label in self.EMPATHY_LIKERT_OPTIONS:
                    score = self.EMPATHY_LIKERT_OPTIONS[label]
                else:
                    score = state.get(f"empathy_score_{i}_{model}", self.DEFAULT_EMPATHY_SCORE)
                evaluation_data["empathy_evaluations"][f"statement_{i}"] = {
                    "score": score,
                    "follow_up_reason": state.get(f"follow_up_reason_empathy_{i}_{model}", "")
                }

            # Save partial evaluation
            self.save_model_evaluation(evaluation_data)
        except Exception as e:
            st.error(f"Error tracking evaluation change: {e}")

    def save_model_evaluation(self, evaluation_data: Dict[str, Any]) -> bool:
        """Merge one model's evaluation into the evaluator's Firestore document.

        Args:
            evaluation_data: Evaluation record; must contain ``"model"``.

        Returns:
            True when the write succeeded, False when it failed (the error is
            shown in the UI instead of being raised).
        """
        try:
            user_id = self._get_current_user_id()
            user_eval_ref = self.db.collection('model_evaluations').document(user_id)

            # merge=True keeps the evaluations of the other models intact.
            user_eval_ref.set({
                'evaluations': {
                    evaluation_data['model']: evaluation_data
                }
            }, merge=True)

            st.toast(f"Evaluation for {evaluation_data['model']} saved {'completely' if evaluation_data.get('status') == 'complete' else 'partially'}")
            return True
        except Exception as e:
            st.error(f"Error saving evaluation: {e}")
            return False

    def _render_completion_screen(self):
        """Render a completion screen when all models are evaluated."""
        st.empty()

        # Display completion message
        st.balloons()
        st.title("🎉 Evaluation Complete!")
        st.markdown("Thank you for your valuable feedback.")

        # Reward link (placeholder URL)
        st.markdown("### Claim Your Reward")
        st.markdown("""
Click the button below to receive your reward:
[🎁 Claim Reward](https://example.com/reward)
""")

        # Record that this evaluator finished all models.
        self._log_evaluation_completion()

    def _log_evaluation_completion(self):
        """Record in ``evaluation_completions`` that all models were evaluated."""
        try:
            user_id = self._get_current_user_id()
            completion_log_ref = self.db.collection('evaluation_completions').document(user_id)
            completion_log_ref.set({
                'completed_at': firestore.SERVER_TIMESTAMP,
                'models_evaluated': list(self.models_to_evaluate)
            })
        except Exception as e:
            st.error(f"Error logging evaluation completion: {e}")
def main():
"""Render the chat page.

Ensures the session has an evaluator ID and initialised state, draws the
settings (model selector, reset button, instructions), shows the selected
model's conversation as numbered user/assistant turns, and provides the text
input whose ``on_change`` callback is ``process_input``. Any exception is
reported with ``st.error`` instead of being raised.
"""
try:
authenticate()
init()
# Initialize evaluation system
# evaluation_system = ModelEvaluationSystem(db)
st.title("Chat with AI Models")
# Sidebar configuration
with st.sidebar:
st.header("Settings")
# Callback for the model selector: switching models resets the conversation.
def on_model_change():
try:
reset_conversation()
except Exception as e:
st.error(f"Error resetting conversation: {str(e)}")
selected_model = st.selectbox(
"Select Model",
options=list(MODEL_CONFIGS.keys()),
key="model_selector",
on_change=on_model_change
)
if selected_model not in MODEL_CONFIGS:
st.error("Invalid model selected")
return
# Publish the choice for process_input, log_message and reset_conversation.
st.session_state.selected_model = selected_model
if st.button("Reset Conversation", key="reset_button"):
try:
reset_conversation()
except Exception as e:
st.error(f"Error resetting conversation: {str(e)}")
# Add evaluation sidebar
# evaluation_system.render_evaluation_sidebar(selected_model)
with st.expander("Instructions"):
st.write("""
**How to Use the Chatbot Interface:**
1. **Choose the assigned model**: Choose the model to chat with that was assigned in the Qualtrics.
2. **Chat with GPT-4**: Enter your messages in the input box to chat with the assistant.
3. **Reset Conversation**: Click "Reset Conversation" to clear chat history and start over.
""")
chat_container = st.container()
with chat_container:
if not st.session_state.chat_active:
st.session_state.chat_active = True
if selected_model in st.session_state.messages:
message_pairs = []
# Group messages into pairs (user + assistant); this relies on the
# history alternating user/assistant. A trailing unpaired message is
# paired with None.
for i in range(0, len(st.session_state.messages[selected_model]), 2):
if i + 1 < len(st.session_state.messages[selected_model]):
message_pairs.append((
st.session_state.messages[selected_model][i],
st.session_state.messages[selected_model][i + 1]
))
else:
message_pairs.append((
st.session_state.messages[selected_model][i],
None
))
# Display message pairs with turn numbers (starting at 1)
for turn_num, (user_msg, assistant_msg) in enumerate(message_pairs, 1):
# Display user message; the narrow second column holds the turn number
col1, col2 = st.columns([0.9, 0.1])
with col1:
with st.chat_message(user_msg["role"]):
st.write(user_msg["content"])
# Show classification for Model 3 only (matched by its label)
if (selected_model == "Model 3" and
'classifications' in st.session_state):
# Classifications are looked up by the user message's index in the history.
idx = (turn_num - 1) * 2
if idx in st.session_state.classifications:
classification = "Emotional" if st.session_state.classifications[idx] == "1" else "Informational"
st.caption(f"Message classified as: {classification}")
with col2:
st.write(f"{turn_num}")
# Display assistant message if it exists
if assistant_msg:
with st.chat_message(assistant_msg["role"]):
st.write(assistant_msg["content"])
# Submitting the text triggers process_input, which reads and clears
# st.session_state.user_input.
st.text_input(
"Type your message here...",
key="user_input",
value="",
on_change=process_input
)
except Exception as e:
st.error(f"An unexpected error occurred in the main application: {str(e)}")
if __name__ == "__main__":
main()