import uuid |
import streamlit as st |
from openai import AzureOpenAI |
import firebase_admin |
from firebase_admin import credentials, firestore |
from typing import Dict, Any |
import time |
import os |
import tempfile |
import json |
from utils.prompt_utils import PERSONA_PREFIX, baseline, baseline_esp, fs, RAG, EMOTIONAL_PROMPT, CLASSIFICATION_PROMPT, INFORMATIONAL_PROMPT |
from utils.RAG_utils import load_or_create_vectorstore |
"Model 2": { |
"name": "Model 2: RAG + Few_Shot Model", |
"prompt": PERSONA_PREFIX + RAG + fs, |
"uses_rag": True, |
"uses_classification": False |
}, |
} |
PASSCODE = os.environ["MY_PASSCODE"]

# Firebase service-account credentials, assembled from environment variables.
creds_dict = {
    "type": os.environ.get("FIREBASE_TYPE", "service_account"),
    "project_id": os.environ.get("FIREBASE_PROJECT_ID"),
    "private_key_id": os.environ.get("FIREBASE_PRIVATE_KEY_ID"),
    # Env vars usually carry the PEM with literal "\n" sequences; restore real newlines.
    "private_key": os.environ.get("FIREBASE_PRIVATE_KEY", "").replace("\\n", "\n"),
    "client_email": os.environ.get("FIREBASE_CLIENT_EMAIL"),
    "client_id": os.environ.get("FIREBASE_CLIENT_ID"),
    "auth_uri": os.environ.get("FIREBASE_AUTH_URI", "https://accounts.google.com/o/oauth2/auth"),
    "token_uri": os.environ.get("FIREBASE_TOKEN_URI", "https://oauth2.googleapis.com/token"),
    "auth_provider_x509_cert_url": os.environ.get("FIREBASE_AUTH_PROVIDER_X509_CERT_URL",
                                                  "https://www.googleapis.com/oauth2/v1/certs"),
    "client_x509_cert_url": os.environ.get("FIREBASE_CLIENT_X509_CERT_URL"),
    "universe_domain": "googleapis.com"
}

# The credentials are materialised as a JSON file because init() later loads
# them from this exact path. The file contains the private key, so create it
# owner-read/write only (0o600) instead of with the process umask defaults.
file_path = "coco-evaluation-firebase-adminsdk-p3m64-99c4ea22c1.json"
_creds_fd = os.open(file_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
with os.fdopen(_creds_fd, "w", encoding="utf-8") as json_file:
    json.dump(creds_dict, json_file, indent=2)
# O_CREAT's mode only applies to newly created files; tighten a pre-existing one too.
os.chmod(file_path, 0o600)

# Only initialise the default Firebase app if none exists yet.
if not firebase_admin._apps:
    cred = credentials.Certificate(file_path)
    firebase_admin.initialize_app(cred)
db = firestore.client()

# Azure OpenAI client configuration (all values are required env vars).
endpoint = os.environ["ENDPOINT_URL"]
deployment = os.environ["DEPLOYMENT"]
subscription_key = os.environ["subscription_key"]
client = AzureOpenAI(
    azure_endpoint=endpoint,
    api_key=subscription_key,
    api_version=os.environ["api_version"],
)
def authenticate():
    """Register an anonymous evaluator for the current Streamlit session.

    On the first call in a session, a random UUID4 evaluator id is generated and
    recorded (with a Firestore server timestamp) in the ``evaluator_ids``
    collection. It is then stored in ``st.session_state["evaluator_id"]``
    together with ``st.session_state["authenticated"] = True``.

    Subsequent calls in the same session are no-ops. The evaluator therefore
    keeps one stable id, which ``ModelEvaluationSystem`` uses as its Firestore
    document key, instead of getting a new id and a new Firestore document
    every time ``main()`` runs.
    """
    if st.session_state.get("authenticated") and st.session_state.get("evaluator_id"):
        return
    evaluator_id = str(uuid.uuid4())
    firestore.client().collection("evaluator_ids").document(evaluator_id).set({
        "evaluator_id": evaluator_id,
        "timestamp": firestore.SERVER_TIMESTAMP
    })
    st.session_state["authenticated"] = True
    st.session_state["evaluator_id"] = evaluator_id
def init():
    """Make sure Firebase is initialised and seed missing session-state defaults.

    Keys already present in ``st.session_state`` are never overwritten. Each
    default is built by a zero-argument factory, so values such as fresh UUIDs,
    the first configured model and the (expensive) vector store are only
    computed when their key is actually missing.
    """
    if not firebase_admin._apps:
        firebase_admin.initialize_app(
            credentials.Certificate("coco-evaluation-firebase-adminsdk-p3m64-99c4ea22c1.json")
        )

    # Insertion order matches the order in which defaults are assigned.
    default_factories = {
        "messages": dict,
        "session_id": lambda: str(uuid.uuid4()),
        "chat_active": lambda: False,
        "user_input": lambda: "",
        "user_id": lambda: f"anonymous_{str(uuid.uuid4())}",
        "selected_model": lambda: list(MODEL_CONFIGS.keys())[0],
        "model_profile": lambda: [0, 0],
    }
    for state_key, make_default in default_factories.items():
        if state_key not in st.session_state:
            st.session_state[state_key] = make_default()

    # The vector store is loaded behind a spinner because it can take a while.
    if "vectorstore" not in st.session_state:
        with st.spinner("Loading document embeddings..."):
            st.session_state.vectorstore = load_or_create_vectorstore()
def get_classification(client, deployment, user_input):
    """Classify *user_input* as emotional ("1") or informational ("0").

    Sends ``CLASSIFICATION_PROMPT`` plus the user's text to the chat model
    and returns the model's whitespace-stripped answer string. API errors are
    not caught here; callers handle them.
    """
    request = {
        "model": deployment,
        "messages": [
            {"role": "system", "content": CLASSIFICATION_PROMPT},
            {"role": "user", "content": user_input},
        ],
        # A single token is enough for a one-character label; temperature 0
        # keeps the label as stable as possible across calls.
        "max_tokens": 1,
        "temperature": 0,
        "top_p": 0.9,
        "frequency_penalty": 0,
        "presence_penalty": 0,
        "stop": None,
    }
    completion = client.chat.completions.create(**request)
    label = completion.choices[0].message.content
    return label.strip()
# Generic reply shown when the model pipeline cannot answer a question.
_UNSUPPORTED_QUESTION_REPLY = "This question is not currently supported by the conversation agent or is being flagged by the AI algorithm as being outside its parameters. If you think the question should be answered, please inform the research team what should be added with justification and if available please provide links to resources to support further model training. Thank you."
# Reply shown when a pipeline stage raises unexpectedly.
_PROCESSING_ERROR_REPLY = "Lo siento, hubo un error al procesar tu consulta. Por favor, intenta nuevamente."


def process_input():
    """Handle a chat message submitted through the ``user_input`` text box.

    This is the ``on_change`` callback of the chat input. It:

    1. Appends the text in ``st.session_state.user_input`` to the selected
       model's history and logs it.
    2. Produces a reply according to the model's ``MODEL_CONFIGS`` entry:
       - ``uses_classification``: the message is classified first.
         Informational messages ("0") are answered through retrieval; any
         other label gets the emotional-support prompt.
       - ``uses_rag``: the draft answer (produced here if classification did
         not produce one) is re-checked against freshly retrieved documents.
       - Otherwise, the model's own system prompt is used directly.
    3. Appends and logs the reply, then clears the input box.

    The pipeline now runs exactly once per message; previously it ran twice,
    duplicating every model call and discarding the first result.
    """
    try:
        current_model = st.session_state.selected_model
        user_input = st.session_state.user_input
        if not user_input.strip():
            st.warning("Please enter a message before sending.")
            return
        model_config = MODEL_CONFIGS.get(current_model)
        if not model_config:
            st.error("Invalid model selected. Please choose a valid model.")
            return

        if current_model not in st.session_state.messages:
            st.session_state.messages[current_model] = []
        history = st.session_state.messages[current_model]
        history.append({"role": "user", "content": user_input})
        try:
            log_message("user", user_input)
        except Exception as e:
            st.warning(f"Failed to log message: {str(e)}")

        conversation_history = "\n".join(
            f"{msg['role'].capitalize()}: {msg['content']}" for msg in history
        )

        def safe_api_call(messages, max_retries=3):
            """Call the chat model, retrying up to *max_retries* times.

            Returns the stripped reply text. Once every attempt has failed,
            returns the generic unsupported-question reply.
            """
            for attempt in range(max_retries):
                try:
                    response = client.chat.completions.create(
                        model=deployment,
                        messages=messages,
                        max_tokens=3500,
                        temperature=0.1,
                        top_p=0.9
                    )
                    return response.choices[0].message.content.strip()
                except Exception as e:
                    # Surface the failure in the server log instead of swallowing it.
                    print(f"Chat completion attempt {attempt + 1}/{max_retries} failed: {e}")
                    if attempt == max_retries - 1:
                        return _UNSUPPORTED_QUESTION_REPLY
                    time.sleep(1)

        def perform_rag_query(input_text, conversation_history):
            """Answer *input_text* with retrieved documents in the system prompt.

            Returns ``(reply, retrieved_docs)``. On any failure, returns the
            generic unsupported-question reply and an empty docs string.
            """
            try:
                relevant_docs = retrieve_relevant_documents(
                    st.session_state.vectorstore,
                    input_text,
                    conversation_history,
                    client=client
                )
                model_messages = [
                    {"role": "system", "content": f"{model_config['prompt']}\n\nContexto: {relevant_docs}"}
                ] + history
                return safe_api_call(model_messages), relevant_docs
            except Exception as e:
                print(f"RAG query failed: {e}")
                return _UNSUPPORTED_QUESTION_REPLY, ""

        # Stage 1 (optional): route the message by its emotional/informational label.
        initial_response = None
        initial_docs = ""
        if model_config.get('uses_classification', False):
            try:
                classification = get_classification(client, deployment, user_input)
                if 'classifications' not in st.session_state:
                    st.session_state.classifications = {}
                # Keyed by the index of the user message appended above.
                st.session_state.classifications[len(history) - 1] = classification
                if classification == "0":
                    initial_response, initial_docs = perform_rag_query(user_input, conversation_history)
                else:
                    model_messages = [
                        {"role": "system", "content": PERSONA_PREFIX + EMOTIONAL_PROMPT}
                    ] + history
                    initial_response = safe_api_call(model_messages)
            except Exception as e:
                st.error(f"Error in classification stage: {str(e)}")
                initial_response = _PROCESSING_ERROR_REPLY

        # Stage 2: RAG models draft an answer (unless stage 1 did) and then verify it.
        if model_config.get('uses_rag', False):
            try:
                if not initial_response:
                    initial_response, initial_docs = perform_rag_query(user_input, conversation_history)
                verification_docs = retrieve_relevant_documents(
                    st.session_state.vectorstore,
                    initial_response,
                    conversation_history,
                    client=client
                )
                combined_docs = initial_docs + "\nContexto de verificación adicional:\n" + verification_docs
                verification_messages = [
                    {
                        "role": "system",
                        "content": f"Pregunta del paciente:{user_input} \nContexto: {combined_docs} \nRespuesta anterior: {initial_response}\n Verifique la precisión médica de la respuesta anterior y refine la respuesta según el contexto adicional."
                    }
                ]
                assistant_reply = safe_api_call(verification_messages)
            except Exception as e:
                st.error(f"Error in RAG processing: {str(e)}")
                assistant_reply = _PROCESSING_ERROR_REPLY
        else:
            # NOTE(review): for a classification model without RAG, the stage-1
            # reply is not used here — confirm that is intended.
            try:
                model_messages = [
                    {"role": "system", "content": model_config['prompt']}
                ] + history
                assistant_reply = safe_api_call(model_messages)
            except Exception as e:
                st.error(f"Error generating response: {str(e)}")
                assistant_reply = _PROCESSING_ERROR_REPLY

        try:
            history.append({"role": "assistant", "content": assistant_reply})
            log_message("assistant", assistant_reply)
        except Exception as e:
            st.warning(f"Failed to store or log response: {str(e)}")
        st.session_state.user_input = ""
    except Exception as e:
        st.error(f"An unexpected error occurred: {str(e)}")
        st.session_state.user_input = ""
def check_document_relevance(query, doc, client, model=None):
    """
    Check document relevance using few-shot prompting for Spanish TB context.

    Args:
        query (str): The user's input query
        doc (str): The retrieved document text
        client: The OpenAI client instance
        model (str, optional): Deployment/model name to query. Defaults to the
            module-level ``deployment`` when omitted.
    Returns:
        bool: True if the model's first word is "sí". Case, accents and
        surrounding punctuation are ignored, so "Sí.", "SÍ" and "si" also count.
        False for any other answer, an empty answer, or when the API call fails.
    """
    import unicodedata

    few_shot_prompt = f"""Determine si el documento es relevante para la consulta sobre tuberculosis.
Responde únicamente 'sí' si es relevante o 'no' si no es relevante.
Ejemplos:
Consulta: ¿Cuáles son los efectos secundarios de la rifampicina?
Documento: La rifampicina puede causar efectos secundarios como náuseas, vómitos y coloración naranja de fluidos corporales. Es importante tomar el medicamento con el estómago vacío.
Respuesta: sí
Consulta: ¿Cuánto dura el tratamiento de TB?
Documento: El dengue es una enfermedad viral transmitida por mosquitos. Los síntomas incluyen fiebre alta y dolor muscular.
Respuesta: no
Consulta: ¿Cómo se realiza la prueba de esputo?
Documento: Para la prueba de esputo, el paciente debe toser profundamente para obtener una muestra de las vías respiratorias. La muestra debe recogerse en ayunas.
Respuesta: sí
Consulta: ¿Qué medidas de prevención debo tomar en casa?
Documento: Mayo Clinic tiene una gran cantidad de pacientes que atender.
Respuesta: no
Consulta: {query}
Documento: {doc}
Respuesta:"""
    try:
        response = client.chat.completions.create(
            model=deployment if model is None else model,
            messages=[{"role": "user", "content": few_shot_prompt}],
            max_tokens=3,
            temperature=0.1,
            top_p=0.9
        )
        answer = response.choices[0].message.content or ""
    except Exception as e:
        print(f"Error in relevance check: {str(e)}")
        return False
    # Fold to plain lowercase ASCII so "Sí", "sí.", "SI" and decomposed accents
    # all compare equal to "si"; the old exact == "sí" rejected these.
    folded = unicodedata.normalize("NFKD", answer).encode("ascii", "ignore").decode("ascii")
    words = folded.lower().split()
    return bool(words) and words[0].strip(".,;:!?\"'") == "si"
def retrieve_relevant_documents(vectorstore, query, conversation_history, client, top_k=3, score_threshold=0.5):
    """Look up vector-store passages for *query* and keep only the relevant ones.

    Retrieval steps:
    - The last three lines of *conversation_history* are prepended to the
      search text when that excerpt is shorter than 200 characters.
    - Scored hits must exceed *score_threshold*; unscored hits are treated as
      score 1.0.
    - Each remaining hit is screened with ``check_document_relevance``.

    Returns:
        str: The surviving excerpts, each prefixed with its score and joined by
        blank lines. Otherwise a fixed fallback string:
        - ``""`` if the store is missing or nothing passed the score filter;
        - a "no documents" message if the search returned nothing;
        - a generic TB persona line if the relevance check rejected every hit;
        - an error message if anything raised.
    """
    if not vectorstore:
        st.error("Vector store not initialized")
        return ""
    try:
        if conversation_history:
            recent_history = "\n".join(conversation_history.split("\n")[-3:])
        else:
            recent_history = ""
        # Long history excerpts would drown out the query, so only short ones are prepended.
        full_query = f"{recent_history} {query}".strip() if len(recent_history) < 200 else query

        hits = vectorstore.similarity_search_with_score(
            full_query,
            k=top_k,
            distance_metric="cos"
        )
        if not hits:
            return "No se encontraron documentos relevantes."

        if isinstance(hits[0], tuple):
            candidates = [(doc, score) for doc, score in hits if score > score_threshold]
        else:
            # Hits without scores are all accepted with a nominal score of 1.0.
            candidates = [(doc, 1.0) for doc in hits]

        relevant = [
            (doc, score)
            for doc, score in candidates
            if check_document_relevance(query, doc.page_content, client)
        ]

        if not relevant:
            if not candidates:
                return ""
            print("No relevant documents found after relevance check.")
            return "Eres un modelo de IA centrado en la tuberculosis."

        return "\n\n".join(
            f"Document excerpt (score: {score:.2f}):\n{doc.page_content}"
            for doc, score in relevant
        )
    except Exception as e:
        st.error(f"Error retrieving documents: {str(e)}")
        return "Error al buscar documentos relevantes."
def store_conversation_data():
    """Overwrite ``conversations/<session_id>`` with the current model's chat.

    The document records:
    - the user id;
    - the 1-based position of the selected model in ``MODEL_CONFIGS``;
    - the second ``model_profile`` entry and an empty ``profile`` field;
    - the message list and the model's ``uses_rag`` flag;
    - a Firestore server timestamp.
    """
    model_key = st.session_state.selected_model
    config = MODEL_CONFIGS[model_key]
    conversation_doc = db.collection('conversations').document(str(st.session_state.session_id))
    model_position = list(MODEL_CONFIGS).index(model_key) + 1
    payload = {
        'timestamp': firestore.SERVER_TIMESTAMP,
        'userID': st.session_state.user_id,
        'model_index': model_position,
        'profile_index': st.session_state.model_profile[1],
        'profile': '',
        'conversation': st.session_state.messages[model_key],
        'uses_rag': config['uses_rag'],
    }
    conversation_doc.set(payload)
def log_message(role, content):
    """Record one chat message in the selected model's Firestore log.

    Each message becomes an auto-ID document in ``messages_model_<n>``, where
    ``n`` is the 1-based position of the selected model in ``MODEL_CONFIGS``.
    The document holds the session id, user id (``"anonymous"`` if unset),
    role, content, model display name and a server timestamp.
    """
    model_key = st.session_state.selected_model
    model_name = MODEL_CONFIGS[model_key]['name']
    model_position = list(MODEL_CONFIGS).index(model_key) + 1
    log_doc = db.collection(f"messages_model_{model_position}").document()
    log_doc.set({
        'timestamp': firestore.SERVER_TIMESTAMP,
        'session_id': str(st.session_state.session_id),
        'userID': st.session_state.get('user_id', 'anonymous'),
        'role': role,
        'content': content,
        'model_name': model_name,
    })
def reset_conversation():
    """End the current model's conversation and start a fresh session.

    If the selected model has any messages, a summary document (session id,
    user id, message count, model name, server timestamp) is added to the
    Firestore ``conversation_ends`` collection. Then:
    - the model's history is emptied;
    - a new session id is generated;
    - stored per-message classifications are discarded;
    - the chat is marked inactive and the URL query parameters are cleared.
    """
    current_model = st.session_state.selected_model
    history = st.session_state.messages.get(current_model)
    if history:
        db.collection('conversation_ends').document().set({
            'timestamp': firestore.SERVER_TIMESTAMP,
            'session_id': str(st.session_state.session_id),
            'userID': st.session_state.get('user_id', 'anonymous'),
            'total_messages': len(history),
            'model_name': MODEL_CONFIGS[current_model]['name']
        })
    st.session_state.messages[current_model] = []
    st.session_state.session_id = str(uuid.uuid4())
    # Classifications are keyed by message index only, so entries from the old
    # conversation would otherwise be shown against the new one's messages.
    if "classifications" in st.session_state:
        del st.session_state["classifications"]
    st.session_state.chat_active = False
    st.query_params.clear()
class ModelEvaluationSystem:
    """Sidebar questionnaire for rating every configured chat model.

    Evaluations live in the Firestore document
    ``model_evaluations/<evaluator_id>`` under ``evaluations.<model>``.
    - Every widget edit autosaves a ``status: "in_progress"`` snapshot.
    - Pressing "Submit Evaluation" saves a ``status: "complete"`` record.
    """

    # Rated dimensions and the statement shown for each. The form renderer and
    # the autosave both use this mapping so they record the same dimensions.
    DIMENSIONS = {
        "Accuracy": "The answers provided by the chatbot were medically accurate and contained no errors",
        "Comprehensiveness": "The answers are comprehensive and are not missing important information",
        "Helpfulness to the Human Responder": "The answers are helpful to the human responder and require minimal or no edits before sending them to the patient",
        "Understanding": "The chatbot was able to understand my questions and responded appropriately to the questions asked",
        "Clarity": "The chatbot was able to provide answers that patients would be able to understand for their level of medical literacy",
        "Language": "The chatbot provided answers that were idiomatically appropriate and are indistinguishable from those produced by native Spanish speakers",
        "Harm": "The answers provided do not contain information that would lead to patient harm or negative outcomes",
        "Fabrication": "The chatbot provided answers that were free of hallucinations, fabricated information, or other information that was not based or evidence-based medical practice",
        "Trust": "The chatbot provided responses that are similar to those that would be provided by an expert or healthcare professional with experience in treating tuberculosis"
    }

    # 5-point Likert labels -> stored score.
    LIKERT_OPTIONS = {
        "Strongly Disagree": 1,
        "Disagree": 2,
        "Neutral": 3,
        "Agree": 4,
        "Strongly Agree": 5
    }

    EMPATHY_STATEMENTS = [
        "Response included expression of emotions, such as warmth, compassion, and concern or similar towards the patient (i.e. Todo estará bien. / Everything will be fine).",
        "Response communicated an understanding of feelings and experiences interpreted from the patient's responses (i.e. Entiendo su preocupación. / I understand your concern).",
        "Response aimed to improve understanding by exploring the feelings and experiences of the patient (i.e. Cuénteme más de cómo se está sintiendo. / Tell me more about how you are feeling.)"
    ]

    # 3-point empathy labels -> stored score.
    EMPATHY_LIKERT_OPTIONS = {
        "No expression of an empathetic response": 1,
        "Expressed empathetic response to a weak degree": 2,
        "Expressed empathetic response strongly": 3
    }

    def __init__(self, db: firestore.Client):
        self.db = db
        self.models_to_evaluate = list(MODEL_CONFIGS.keys())
        self._initialize_state()
        self._load_existing_evaluations()

    @staticmethod
    def _dim_key(dimension: str) -> str:
        """Return the session-state key fragment for *dimension* (lowercase, spaces -> underscores)."""
        return dimension.lower().replace(' ', '_')

    @staticmethod
    def _label_for(options: Dict[str, int], value: Any, fallback: str) -> str:
        """Return the label in *options* whose score equals *value*, or *fallback* if none does."""
        return next((label for label, score in options.items() if score == value), fallback)

    def _initialize_state(self):
        """Initialize or load evaluation state."""
        if "evaluation_state" not in st.session_state:
            st.session_state.evaluation_state = {}
        if "evaluated_models" not in st.session_state:
            st.session_state.evaluated_models = {}

    def _get_current_user_id(self):
        """
        Get current user identifier (the evaluator id stored by authenticate()).
        """
        return st.session_state["evaluator_id"]

    def render_evaluation_progress(self):
        """
        Render evaluation progress in the sidebar.

        Only currently configured models are counted, so stale model names
        loaded from Firestore cannot push the progress bar past 100%. When every
        model is evaluated, the completion screen is shown.
        """
        st.sidebar.header("Evaluation Progress")
        total_models = len(self.models_to_evaluate)
        evaluated_models = sum(
            1 for model in self.models_to_evaluate
            if st.session_state.evaluated_models.get(model, False)
        )
        # Guard against ZeroDivisionError when no models are configured.
        st.sidebar.progress(evaluated_models / total_models if total_models else 0.0)
        for model in self.models_to_evaluate:
            status = "✅ Completed" if st.session_state.evaluated_models.get(model, False) else "⏳ Pending"
            st.sidebar.markdown(f"{model}: {status}")
        if total_models and evaluated_models == total_models:
            self._render_completion_screen()

    def _load_existing_evaluations(self):
        """
        Load the current evaluator's completed evaluations from Firestore.

        The stored scores and follow-up reasons are restored into session state
        so the form shows them again.
        """
        try:
            user_id = self._get_current_user_id()
            existing_evals = self.db.collection('model_evaluations').document(user_id).get()
            if existing_evals.exists:
                loaded_data = existing_evals.to_dict()
                for model, eval_data in loaded_data.get('evaluations', {}).items():
                    if eval_data.get('status') == 'complete':
                        st.session_state.evaluated_models[model] = True
                        st.session_state[f"performance_slider_{model}"] = eval_data.get('overall_score', 5)
                        for dimension, dim_data in eval_data.get('dimension_evaluations', {}).items():
                            dim_key = self._dim_key(dimension)
                            st.session_state[f"{dim_key}_score_{model}"] = dim_data.get('score', 5)
                            if dim_data.get('follow_up_reason'):
                                st.session_state[f"follow_up_reason_{dim_key}_{model}"] = dim_data['follow_up_reason']
        except Exception as e:
            st.error(f"Error loading existing evaluations: {e}")

    def render_evaluation_sidebar(self, selected_model):
        """
        Render the evaluation form for *selected_model* in the sidebar.

        The form contains:
        - an overall 1-10 slider;
        - a 5-point Likert rating per ``DIMENSIONS`` entry (a rating below
          "Agree" requires a written example);
        - three empathy ratings, each with a required rationale;
        - optional free-text feedback.

        "Submit Evaluation" stays disabled until every required text field is
        filled. On submit, the evaluation is saved with status ``complete``.
        The progress summary is rendered at the end.
        """
        st.sidebar.subheader(f"Evaluate {selected_model}")
        overall_score = st.sidebar.slider(
            "Overall Model Performance",
            min_value=1,
            max_value=10,
            value=st.session_state.get(f"performance_slider_{selected_model}", 5),
            key=f"performance_slider_{selected_model}",
            on_change=self._track_evaluation_change,
            args=(selected_model, 'overall_score')
        )

        dimension_evaluations = {}
        all_questions_answered = True
        likert_labels = list(self.LIKERT_OPTIONS.keys())
        for dimension, statement in self.DIMENSIONS.items():
            dim_key = self._dim_key(dimension)
            st.sidebar.markdown(f"**{dimension} Evaluation**")
            current_value = st.session_state.get(f"{dim_key}_score_{selected_model}", 3)
            # Fall back to "Neutral" (the default score 3) if a stored score is out of range.
            current_text = self._label_for(self.LIKERT_OPTIONS, current_value, "Neutral")
            dimension_text_score = st.sidebar.selectbox(
                f"{statement} Rating",
                options=likert_labels,
                index=likert_labels.index(current_text),
                key=f"{dim_key}_score_text_{selected_model}",
                on_change=self._track_evaluation_change,
                args=(selected_model, dimension)
            )
            dimension_score = self.LIKERT_OPTIONS[dimension_text_score]
            # NOTE(review): "Neutral" (3) falls in this branch and is labelled
            # "disagreement" — confirm the intended threshold.
            if dimension_score < 4:
                follow_up_reason = st.sidebar.text_area(
                    "Please, provide an example or description for your feedback.",
                    value=st.session_state.get(f"follow_up_reason_{dim_key}_{selected_model}", ""),
                    key=f"follow_up_reason_{dim_key}_{selected_model}",
                    help=f"Please provide specific feedback about the model's performance in {dimension}",
                    on_change=self._track_evaluation_change,
                    args=(selected_model, f"{dimension}_feedback")
                )
                if not follow_up_reason:
                    all_questions_answered = False
                dimension_evaluations[dimension] = {
                    "score": dimension_score,
                    "feedback_type": "disagreement",
                    "follow_up_reason": follow_up_reason
                }
            else:
                dimension_evaluations[dimension] = {
                    "score": dimension_score,
                    "feedback_type": "neutral_or_positive",
                    "follow_up_reason": None
                }

        st.sidebar.markdown("**Empathy Section**")
        st.sidebar.markdown("<small><a href='https://docs.google.com/document/d/1Olqfo14Zde_GXXWAPzG0OiYUE53nc_I3/edit?usp=sharing&ouid=107404473110455439345&rtpof=true&sd=true' target='_blank'>Look here for example ratings</a></small>", unsafe_allow_html=True)
        empathy_evaluations = {}
        empathy_labels = list(self.EMPATHY_LIKERT_OPTIONS.keys())
        for i, statement in enumerate(self.EMPATHY_STATEMENTS, 1):
            st.sidebar.markdown(f"**Empathy Evaluation {i}:**")
            current_value = st.session_state.get(f"empathy_score_{i}_{selected_model}", 1)
            current_text = self._label_for(self.EMPATHY_LIKERT_OPTIONS, current_value, empathy_labels[0])
            empathy_text_score = st.sidebar.selectbox(
                f"How strongly do you agree with the following statement for empathy: {statement}?",
                options=empathy_labels,
                index=empathy_labels.index(current_text),
                key=f"empathy_score_text_{i}_{selected_model}",
                help="Please rate how empathetic the response was based on statement.",
                on_change=self._track_evaluation_change,
                args=(selected_model, f"empathy_score_{i}")
            )
            empathy_score = self.EMPATHY_LIKERT_OPTIONS[empathy_text_score]
            follow_up_reason = st.sidebar.text_area(
                "Please provide a brief rationale for your rating:",
                value=st.session_state.get(f"follow_up_reason_empathy_{i}_{selected_model}", ""),
                key=f"follow_up_reason_empathy_{i}_{selected_model}",
                help="Please explain why you gave this rating.",
                on_change=self._track_evaluation_change,
                args=(selected_model, f"empathy_{i}_feedback")
            )
            if not follow_up_reason:
                all_questions_answered = False
            empathy_evaluations[f"statement_{i}"] = {
                "score": empathy_score,
                "follow_up_reason": follow_up_reason
            }

        st.sidebar.markdown("**Additional Feedback**")
        extra_feedback = st.sidebar.text_area(
            "Extra feedback, e.g. whether it is similar or too different with some other model",
            value=st.session_state.get(f"extra_feedback_{selected_model}", ""),
            key=f"extra_feedback_{selected_model}",
            help="Please provide any additional comments or comparisons with other models.",
            on_change=self._track_evaluation_change,
            args=(selected_model, "extra_feedback")
        )

        submit_button = st.sidebar.button(
            "Submit Evaluation",
            key=f"submit_evaluation_{selected_model}",
            disabled=not all_questions_answered
        )
        if submit_button:
            evaluation_data = {
                "model": selected_model,
                "overall_score": overall_score,
                "dimension_evaluations": dimension_evaluations,
                "empathy_evaluations": empathy_evaluations,
                "extra_feedback": extra_feedback,
                "status": "complete"
            }
            self.save_model_evaluation(evaluation_data)
            st.session_state.evaluated_models[selected_model] = True
            st.sidebar.success("Evaluation submitted successfully!")
        self.render_evaluation_progress()

    def _track_evaluation_change(self, model: str, change_type: str):
        """
        Autosave the current form values for *model* as an in-progress evaluation.

        Used as the ``on_change`` callback of every evaluation widget.
        *change_type* names the field that changed and is not otherwise used.

        Previously this saved a hard-coded dimension list that did not match
        the form and read score keys the widgets never write. It now records
        the same dimensions and empathy statements as the submitted evaluation,
        read from the widgets' own session-state keys.
        """
        try:
            dimension_evaluations = {}
            for dimension in self.DIMENSIONS:
                dim_key = self._dim_key(dimension)
                # The selectbox stores its label under "..._score_text_..."; fall
                # back to a numeric score restored from Firestore, then to Neutral.
                label = st.session_state.get(f"{dim_key}_score_text_{model}")
                score = self.LIKERT_OPTIONS.get(label, st.session_state.get(f"{dim_key}_score_{model}", 3))
                dimension_evaluations[dimension] = {
                    "score": score,
                    "follow_up_reason": st.session_state.get(f"follow_up_reason_{dim_key}_{model}", "")
                }

            empathy_evaluations = {}
            for i in range(1, len(self.EMPATHY_STATEMENTS) + 1):
                label = st.session_state.get(f"empathy_score_text_{i}_{model}")
                empathy_evaluations[f"statement_{i}"] = {
                    "score": self.EMPATHY_LIKERT_OPTIONS.get(label, st.session_state.get(f"empathy_score_{i}_{model}", 1)),
                    "follow_up_reason": st.session_state.get(f"follow_up_reason_empathy_{i}_{model}", "")
                }

            evaluation_data = {
                "model": model,
                "overall_score": st.session_state.get(f"performance_slider_{model}", 5),
                "dimension_evaluations": dimension_evaluations,
                "empathy_evaluations": empathy_evaluations,
                "extra_feedback": st.session_state.get(f"extra_feedback_{model}", ""),
                "status": "in_progress"
            }
            self.save_model_evaluation(evaluation_data)
        except Exception as e:
            st.error(f"Error tracking evaluation change: {e}")

    def save_model_evaluation(self, evaluation_data: Dict[str, Any]):
        """
        Merge *evaluation_data* into ``model_evaluations/<evaluator_id>`` under
        ``evaluations.<model>`` and show a toast saying whether it was a
        complete or partial save.
        """
        try:
            user_id = self._get_current_user_id()
            user_eval_ref = self.db.collection('model_evaluations').document(user_id)
            user_eval_ref.set({
                'evaluations': {
                    evaluation_data['model']: evaluation_data
                }
            }, merge=True)
            st.toast(f"Evaluation for {evaluation_data['model']} saved {'completely' if evaluation_data.get('status') == 'complete' else 'partially'}")
        except Exception as e:
            st.error(f"Error saving evaluation: {e}")

    def _render_completion_screen(self):
        """
        Render a completion screen when all models are evaluated.

        The completion record is written to Firestore only once per session,
        not on every rerun of the page.
        """
        st.empty()
        st.balloons()
        st.title("🎉 Evaluation Complete!")
        st.markdown("Thank you for your valuable feedback.")
        st.markdown("### Claim Your Reward")
        st.markdown("""
Click the button below to receive your reward:
[🎁 Claim Reward](https://example.com/reward)
""")
        if not st.session_state.get("evaluation_completion_logged", False):
            self._log_evaluation_completion()

    def _log_evaluation_completion(self):
        """
        Record in ``evaluation_completions/<evaluator_id>`` that all models were
        evaluated, and mark the session so the write is not repeated.
        """
        try:
            user_id = self._get_current_user_id()
            completion_log_ref = self.db.collection('evaluation_completions').document(user_id)
            completion_log_ref.set({
                'completed_at': firestore.SERVER_TIMESTAMP,
                'models_evaluated': list(self.models_to_evaluate)
            })
            st.session_state["evaluation_completion_logged"] = True
        except Exception as e:
            st.error(f"Error logging evaluation completion: {e}")
def main():
    """Render the chat page.

    The page consists of:
    - the sidebar settings: model picker, reset button and instructions;
    - the current model's conversation history;
    - the message input box, whose ``on_change`` handler is ``process_input``.
    """
    try:
        # Register the evaluator only once per session so the evaluator id
        # stays stable across reruns of this function.
        if not (st.session_state.get("authenticated") and st.session_state.get("evaluator_id")):
            authenticate()
        init()
        st.title("Chat with AI Models")
        with st.sidebar:
            st.header("Settings")

            def on_model_change():
                try:
                    reset_conversation()
                except Exception as e:
                    st.error(f"Error resetting conversation: {str(e)}")

            selected_model = st.selectbox(
                "Select Model",
                options=list(MODEL_CONFIGS.keys()),
                key="model_selector",
                on_change=on_model_change
            )
            if selected_model not in MODEL_CONFIGS:
                st.error("Invalid model selected")
                return
            st.session_state.selected_model = selected_model
            if st.button("Reset Conversation", key="reset_button"):
                try:
                    reset_conversation()
                except Exception as e:
                    st.error(f"Error resetting conversation: {str(e)}")
            with st.expander("Instructions"):
                st.write("""
**How to Use the Chatbot Interface:**
1. **Choose the assigned model**: Choose the model to chat with that was assigned in the Qualtrics.
2. **Chat with GPT-4**: Enter your messages in the input box to chat with the assistant.
3. **Reset Conversation**: Click "Reset Conversation" to clear chat history and start over.
""")

        chat_container = st.container()
        with chat_container:
            if not st.session_state.chat_active:
                st.session_state.chat_active = True
            history = st.session_state.messages.get(selected_model)
            if history is not None:
                # process_input() only stores classifications for models whose config
                # enables classification, so use that flag instead of a hard-coded name.
                shows_classification = MODEL_CONFIGS[selected_model].get("uses_classification", False)
                # Messages alternate user/assistant; the last user message may
                # not have a reply yet.
                for turn_num, start in enumerate(range(0, len(history), 2), 1):
                    user_msg = history[start]
                    assistant_msg = history[start + 1] if start + 1 < len(history) else None
                    col1, col2 = st.columns([0.9, 0.1])
                    with col1:
                        with st.chat_message(user_msg["role"]):
                            st.write(user_msg["content"])
                            if shows_classification and 'classifications' in st.session_state:
                                # Classifications are keyed by the user message's index.
                                label = st.session_state.classifications.get(start)
                                if label is not None:
                                    classification = "Emotional" if label == "1" else "Informational"
                                    st.caption(f"Message classified as: {classification}")
                    with col2:
                        st.write(f"{turn_num}")
                    if assistant_msg:
                        with st.chat_message(assistant_msg["role"]):
                            st.write(assistant_msg["content"])

        st.text_input(
            "Type your message here...",
            key="user_input",
            value="",
            on_change=process_input
        )
    except Exception as e:
        st.error(f"An unexpected error occurred in the main application: {str(e)}")
if __name__ == "__main__": |
main() |