import uuid
import streamlit as st
from openai import AzureOpenAI
import firebase_admin
from firebase_admin import credentials, firestore
from typing import Dict, Any
import time
import os
import json
from utils.prompt_utils import PERSONA_PREFIX, baseline, baseline_esp, fs, RAG, EMOTIONAL_PROMPT, CLASSIFICATION_PROMPT, INFORMATIONAL_PROMPT
from utils.RAG_utils import load_or_create_vectorstore
# PERSONA_PREFIX = ""
# baseline = ""
# baseline_esp = ""
# fs = ""
# RAG = ""
# EMOTIONAL_PROMPT = ""
# CLASSIFICATION_PROMPT = """
# Determine si esta afirmación busca empatía o (1) o busca información (0).
# Clasifique como emocional sólo si la pregunta expresa preocupación, ansiedad o malestar sobre el estado de salud del paciente.
# En caso contrario, clasificar como informativo.
# Ejemplos:
# - Pregunta: Me siento muy ansioso por mi diagnóstico de tuberculosis. 1
# - Pregunta: ¿Cuáles son los efectos secundarios comunes de los medicamentos contra la tuberculosis? 0
# - Pregunta: Estoy preocupada porque tengo mucho dolor. 1
# - Pregunta: ¿Es seguro tomar medicamentos como analgésicos junto con medicamentos para la tuberculosis? 0
# Aquí está la declaración para clasificar. Simplemente responda con el número "1" o "0":
# """
# INFORMATIONAL_PROMPT = ""
# Model configurations: display name, system prompt, and routing flags.
# uses_rag enables retrieval plus a verification pass over the initial answer;
# uses_classification enables the emotional/informational pre-classifier (see process_input).
MODEL_CONFIGS = {
# "Model 0: Naive English Baseline Model": {
# "name": "Model 0: Naive English Baseline Model",
# "prompt": PERSONA_PREFIX + baseline,
# "uses_rag": False,
# "uses_classification": False
# },
# "Model 1: Naive Spanish Baseline Model": {
# "name": "Model 1: Baseline Model",
# "prompt": PERSONA_PREFIX + baseline_esp,
# "uses_rag": False,
# "uses_classification": False
# },
# "Model 1": {
# "name": "Model 1: Few_Shot model",
# "prompt": PERSONA_PREFIX + fs,
# "uses_rag": False,
# "uses_classification": False
# },
# "Model 3: RAG Model": {F
# "name": "Model 3: RAG Model",
# "prompt": PERSONA_PREFIX + RAG,
# "uses_rag": True,
# "uses_classification": False
# },
# "Model 2": {
# "name": "Model 2: RAG + Few_Shot Model",
# "prompt": PERSONA_PREFIX + RAG + fs,
# "uses_rag": True,
# "uses_classification": False
# },
"Model 3": {
"name": "Model 3: 2-Stage Classification Model",
"prompt": PERSONA_PREFIX + INFORMATIONAL_PROMPT, # default
"uses_rag": False,
"uses_classification": False
},
# "Model 6: Multi-Agent": {
# "name": "Model 6: Multi-Agent",
# "prompt": PERSONA_PREFIX + INFORMATIONAL_PROMPT, # default
# "uses_rag": True,
# "uses_classification": True,
# "uses_judges": True
# }
}
PASSCODE = os.environ["MY_PASSCODE"]
creds_dict = {
"type": os.environ.get("FIREBASE_TYPE", "service_account"),
"project_id": os.environ.get("FIREBASE_PROJECT_ID"),
"private_key_id": os.environ.get("FIREBASE_PRIVATE_KEY_ID"),
"private_key": os.environ.get("FIREBASE_PRIVATE_KEY", "").replace("\\n", "\n"),
"client_email": os.environ.get("FIREBASE_CLIENT_EMAIL"),
"client_id": os.environ.get("FIREBASE_CLIENT_ID"),
"auth_uri": os.environ.get("FIREBASE_AUTH_URI", "https://accounts.google.com/o/oauth2/auth"),
"token_uri": os.environ.get("FIREBASE_TOKEN_URI", "https://oauth2.googleapis.com/token"),
"auth_provider_x509_cert_url": os.environ.get("FIREBASE_AUTH_PROVIDER_X509_CERT_URL",
"https://www.googleapis.com/oauth2/v1/certs"),
"client_x509_cert_url": os.environ.get("FIREBASE_CLIENT_X509_CERT_URL"),
"universe_domain": "googleapis.com"
}
# Write the service-account credentials to a JSON file for the Firebase Admin SDK
file_path = "coco-evaluation-firebase-adminsdk-p3m64-99c4ea22c1.json"
with open(file_path, 'w') as json_file:
    json.dump(creds_dict, json_file, indent=2)
# Initialize Firebase (guard against re-initialization on Streamlit reruns)
if not firebase_admin._apps:
    cred = credentials.Certificate(file_path)
    firebase_admin.initialize_app(cred)
db = firestore.client()
endpoint = os.environ["ENDPOINT_URL"]
deployment = os.environ["DEPLOYMENT"]
subscription_key = os.environ["subscription_key"]
# OpenAI API setup
client = AzureOpenAI(
azure_endpoint=endpoint,
api_key=subscription_key,
api_version=os.environ["api_version"]
)
def authenticate():
    """Assign this browser session an anonymous evaluator id and record it in Firestore."""
    # main() reruns on every interaction; skip if this session is already authenticated
    if st.session_state.get("authenticated"):
        return
    evaluator_id = str(uuid.uuid4())
    db.collection("evaluator_ids").document(evaluator_id).set({
        "evaluator_id": evaluator_id,
        "timestamp": firestore.SERVER_TIMESTAMP
    })
    # Update session state
    st.session_state["authenticated"] = True
    st.session_state["evaluator_id"] = evaluator_id
def init():
"""Initialize all necessary components and state variables"""
# Initialize session state variables
if "messages" not in st.session_state:
st.session_state.messages = {}
if "session_id" not in st.session_state:
st.session_state.session_id = str(uuid.uuid4())
if "chat_active" not in st.session_state:
st.session_state.chat_active = False
if "user_input" not in st.session_state:
st.session_state.user_input = ""
if "user_id" not in st.session_state:
st.session_state.user_id = f"anonymous_{str(uuid.uuid4())}"
if "selected_model" not in st.session_state:
st.session_state.selected_model = list(MODEL_CONFIGS.keys())[0]
if "model_profile" not in st.session_state:
st.session_state.model_profile = [0, 0]
# Load vectorstore at startup
if "vectorstore" not in st.session_state:
with st.spinner("Loading document embeddings..."):
st.session_state.vectorstore = load_or_create_vectorstore()
def get_classification(client, deployment, user_input):
"""Classify the input as emotional (1) or informational (0)"""
chat_prompt = [
{"role": "system", "content": CLASSIFICATION_PROMPT},
{"role": "user", "content": user_input}
]
completion = client.chat.completions.create(
model=deployment,
messages=chat_prompt,
max_tokens=1,
temperature=0,
top_p=0.9,
frequency_penalty=0,
presence_penalty=0,
stop=None
)
return completion.choices[0].message.content.strip()
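# Illustrative usage (mirrors the few-shot examples in CLASSIFICATION_PROMPT):
# get_classification(client, deployment, "Me siento muy ansioso por mi diagnóstico")
# is expected to return "1"; a factual question about medication side effects
# should return "0".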
def process_input():
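    """Handle a submitted user message: log it, route it through the selected
    model's pipeline (classification, RAG, or plain prompt), and store the reply."""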
try:
current_model = st.session_state.selected_model
user_input = st.session_state.user_input
if not user_input.strip():
st.warning("Please enter a message before sending.")
return
model_config = MODEL_CONFIGS.get(current_model)
if not model_config:
st.error("Invalid model selected. Please choose a valid model.")
return
if current_model not in st.session_state.messages:
st.session_state.messages[current_model] = []
st.session_state.messages[current_model].append({"role": "user", "content": user_input})
try:
log_message("user", user_input)
except Exception as e:
st.warning(f"Failed to log message: {str(e)}")
conversation_history = "\n".join([f"{msg['role'].capitalize()}: {msg['content']}"
for msg in st.session_state.messages[current_model]])
        # Standardized fallback message returned whenever a model call ultimately fails
        unsupported_msg = (
            "This question is not currently supported by the conversation agent or is "
            "being flagged by the AI algorithm as being outside its parameters. If you "
            "think the question should be answered, please inform the research team what "
            "should be added with justification and if available please provide links to "
            "resources to support further model training. Thank you."
        )

        # Helper function for error handling in API calls
        def safe_api_call(messages, max_retries=3):
            for attempt in range(max_retries):
                try:
                    response = client.chat.completions.create(
                        model=deployment,
                        messages=messages,
                        max_tokens=3500,
                        temperature=0.1,
                        top_p=0.9
                    )
                    return response.choices[0].message.content.strip()
                except Exception:
                    if attempt == max_retries - 1:
                        # Give up after the final retry with the standardized message
                        return unsupported_msg
                    time.sleep(1)
def perform_rag_query(input_text, conversation_history):
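            """Retrieve context for the input and answer with it injected into the system prompt."""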
try:
relevant_docs = retrieve_relevant_documents(
st.session_state.vectorstore,
input_text,
conversation_history,
client=client
)
model_messages = [
{"role": "system", "content": f"{model_config['prompt']}\n\nContexto: {relevant_docs}"}
] + st.session_state.messages[current_model]
return safe_api_call(model_messages), relevant_docs
            except Exception:
                # Fall back to the standardized message with empty context
                return unsupported_msg, ""
        # Initialize pipeline state
        initial_response = None
        initial_docs = ""

        # Stage 1 (classification models): route emotional questions to the
        # empathetic prompt and informational ones through RAG
        if model_config.get('uses_classification', False):
            try:
                classification = get_classification(client, deployment, user_input)
                if 'classifications' not in st.session_state:
                    st.session_state.classifications = {}
                st.session_state.classifications[len(st.session_state.messages[current_model]) - 1] = classification
                if classification == "0":
                    initial_response, initial_docs = perform_rag_query(user_input, conversation_history)
                else:
                    model_messages = [
                        {"role": "system", "content": PERSONA_PREFIX + EMOTIONAL_PROMPT}
                    ] + st.session_state.messages[current_model]
                    initial_response = safe_api_call(model_messages)
            except Exception:
                initial_response = unsupported_msg

        # Stage 2 (RAG models): verify the initial answer against retrieved
        # context and refine it
        if model_config.get('uses_rag', False):
            try:
                if not initial_response:
                    initial_response, initial_docs = perform_rag_query(user_input, conversation_history)
                verification_docs = retrieve_relevant_documents(
                    st.session_state.vectorstore,
                    initial_response,
                    conversation_history,
                    client=client
                )
                combined_docs = initial_docs + "\nContexto de verificación adicional:\n" + verification_docs
                verification_messages = [
                    {
                        "role": "system",
                        "content": f"Pregunta del paciente:{user_input} \nContexto: {combined_docs} \nRespuesta anterior: {initial_response}\n Verifique la precisión médica de la respuesta anterior y refine la respuesta según el contexto adicional."
                    }
                ]
                assistant_reply = safe_api_call(verification_messages)
            except Exception:
                assistant_reply = unsupported_msg
        elif initial_response:
            # Classification-only models answer with the stage-1 response directly
            assistant_reply = initial_response
        else:
            try:
                model_messages = [
                    {"role": "system", "content": model_config['prompt']}
                ] + st.session_state.messages[current_model]
                assistant_reply = safe_api_call(model_messages)
            except Exception:
                assistant_reply = unsupported_msg
# Store and log the final response
try:
st.session_state.messages[current_model].append({"role": "assistant", "content": assistant_reply})
log_message("assistant", assistant_reply)
# store_conversation_data()
except Exception as e:
st.warning(f"Failed to store or log response: {str(e)}")
st.session_state.user_input = ""
except Exception as e:
st.error(f"An unexpected error occurred: {str(e)}")
st.session_state.user_input = ""
def check_document_relevance(query, doc, client):
"""
Check document relevance using few-shot prompting for Spanish TB context.
Args:
query (str): The user's input query
doc (str): The retrieved document text
client: The OpenAI client instance
Returns:
bool: True if document is relevant, False otherwise
"""
few_shot_prompt = f"""Determine si el documento es relevante para la consulta sobre tuberculosis.
Responde únicamente 'sí' si es relevante o 'no' si no es relevante.
Ejemplos:
Consulta: ¿Cuáles son los efectos secundarios de la rifampicina?
Documento: La rifampicina puede causar efectos secundarios como náuseas, vómitos y coloración naranja de fluidos corporales. Es importante tomar el medicamento con el estómago vacío.
Respuesta: sí
Consulta: ¿Cuánto dura el tratamiento de TB?
Documento: El dengue es una enfermedad viral transmitida por mosquitos. Los síntomas incluyen fiebre alta y dolor muscular.
Respuesta: no
Consulta: ¿Cómo se realiza la prueba de esputo?
Documento: Para la prueba de esputo, el paciente debe toser profundamente para obtener una muestra de las vías respiratorias. La muestra debe recogerse en ayunas.
Respuesta: sí
Consulta: ¿Qué medidas de prevención debo tomar en casa?
Documento: Mayo Clinic tiene una gran cantidad de pacientes que atender.
Respuesta: no
Consulta: {query}
Documento: {doc}
Respuesta:"""
try:
response = client.chat.completions.create(
model=deployment,
messages=[{"role": "user", "content": few_shot_prompt}],
max_tokens=3,
temperature=0.1,
top_p=0.9
)
        answer = response.choices[0].message.content.strip().lower()
        # Accept both accented and unaccented affirmations from the model
        return answer in ("sí", "si")
except Exception as e:
# In case of error, default to false (not relevant)
print(f"Error in relevance check: {str(e)}")
return False
def retrieve_relevant_documents(vectorstore, query, conversation_history, client, top_k=3, score_threshold=0.5):
    """Retrieve top_k documents, filter by similarity score, then by an LLM relevance check."""
    if not vectorstore:
        st.error("Vector store not initialized")
        return ""
try:
recent_history = "\n".join(conversation_history.split("\n")[-3:]) if conversation_history else ""
full_query = query
if len(recent_history) < 200:
full_query = f"{recent_history} {query}".strip()
results = vectorstore.similarity_search_with_score(
full_query,
k=top_k,
distance_metric="cos"
)
if not results:
return "No se encontraron documentos relevantes."
# Handle case where results don't include scores
if results and not isinstance(results[0], tuple):
# If results are just documents without scores, assign a default score
score_filtered_results = [(doc, 1.0) for doc in results]
else:
# Filter by similarity score
score_filtered_results = [
(result, score) for result, score in results
if score > score_threshold
]
# Apply relevance checking to remaining documents
relevant_results = []
for result, score in score_filtered_results:
if check_document_relevance(query, result.page_content, client):
relevant_results.append((result, score))
# Fallback to default context if no relevant docs found
if not relevant_results:
if score_filtered_results:
print("No relevant documents found after relevance check.")
return "Eres un modelo de IA centrado en la tuberculosis."
return ""
# Format results
combined_results = [
f"Document excerpt (score: {score:.2f}):\n{result.page_content}"
for result, score in relevant_results
]
return "\n\n".join(combined_results)
except Exception as e:
st.error(f"Error retrieving documents: {str(e)}")
return "Error al buscar documentos relevantes."
def store_conversation_data():
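    """Persist the current session's full conversation to the 'conversations' collection."""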
current_model = st.session_state.selected_model
model_config = MODEL_CONFIGS[current_model]
doc_ref = db.collection('conversations').document(str(st.session_state.session_id))
doc_ref.set({
'timestamp': firestore.SERVER_TIMESTAMP,
'userID': st.session_state.user_id,
'model_index': list(MODEL_CONFIGS.keys()).index(current_model) + 1,
'profile_index': st.session_state.model_profile[1],
'profile': '',
'conversation': st.session_state.messages[current_model],
'uses_rag': model_config['uses_rag']
})
def log_message(role, content):
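    """Write a single chat message to the per-model Firestore message collection."""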
current_model = st.session_state.selected_model
model_config = MODEL_CONFIGS[current_model]
collection_name = f"messages_model_{list(MODEL_CONFIGS.keys()).index(current_model) + 1}"
doc_ref = db.collection(collection_name).document()
doc_ref.set({
'timestamp': firestore.SERVER_TIMESTAMP,
'session_id': str(st.session_state.session_id),
'userID': st.session_state.get('user_id', 'anonymous'),
'role': role,
'content': content,
'model_name': model_config['name']
})
def reset_conversation():
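    """Log the end of the current conversation, clear its history, and start a new session."""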
current_model = st.session_state.selected_model
if current_model in st.session_state.messages and st.session_state.messages[current_model]:
doc_ref = db.collection('conversation_ends').document()
doc_ref.set({
'timestamp': firestore.SERVER_TIMESTAMP,
'session_id': str(st.session_state.session_id),
'userID': st.session_state.get('user_id', 'anonymous'),
'total_messages': len(st.session_state.messages[current_model]),
'model_name': MODEL_CONFIGS[current_model]['name']
})
st.session_state.messages[current_model] = []
st.session_state.session_id = str(uuid.uuid4())
st.session_state.chat_active = False
st.query_params.clear()
class ModelEvaluationSystem:
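    """Sidebar-driven evaluation workflow: per-model QUEST-dimension and empathy
    ratings, progress tracking, and persistence to Firestore."""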
def __init__(self, db: firestore.Client):
self.db = db
self.models_to_evaluate = list(MODEL_CONFIGS.keys()) # Use existing MODEL_CONFIGS
self._initialize_state()
self._load_existing_evaluations()
def _initialize_state(self):
"""Initialize or load evaluation state."""
if "evaluation_state" not in st.session_state:
st.session_state.evaluation_state = {}
if "evaluated_models" not in st.session_state:
st.session_state.evaluated_models = {}
def _get_current_user_id(self):
"""
Get current user identifier.
"""
return st.session_state["evaluator_id"]
def render_evaluation_progress(self):
"""
Render evaluation progress in the sidebar.
"""
st.sidebar.header("Evaluation Progress")
# Calculate progress
total_models = len(self.models_to_evaluate)
evaluated_models = len(st.session_state.evaluated_models)
# Progress bar
st.sidebar.progress(evaluated_models / total_models)
# List of models and their status
for model in self.models_to_evaluate:
status = "✅ Completed" if st.session_state.evaluated_models.get(model, False) else "⏳ Pending"
st.sidebar.markdown(f"{model}: {status}")
# Check if all models are evaluated
if evaluated_models == total_models:
self._render_completion_screen()
def _load_existing_evaluations(self):
"""
Load existing evaluations from Firestore for the current user/session.
"""
try:
user_id = self._get_current_user_id()
existing_evals = self.db.collection('model_evaluations').document(user_id).get()
if existing_evals.exists:
loaded_data = existing_evals.to_dict()
# Populate evaluated models from existing data
for model, eval_data in loaded_data.get('evaluations', {}).items():
if eval_data.get('status') == 'complete':
st.session_state.evaluated_models[model] = True
# Restore slider and text area values
st.session_state[f"performance_slider_{model}"] = eval_data.get('overall_score', 5)
for dimension, dim_data in eval_data.get('dimension_evaluations', {}).items():
dim_key = dimension.lower().replace(' ', '_')
st.session_state[f"{dim_key}_score_{model}"] = dim_data.get('score', 5)
if dim_data.get('follow_up_reason'):
st.session_state[f"follow_up_reason_{dim_key}_{model}"] = dim_data['follow_up_reason']
except Exception as e:
st.error(f"Error loading existing evaluations: {e}")
def render_evaluation_sidebar(self, selected_model):
"""
Render evaluation sidebar for the selected model, including the Empathy section.
"""
# Evaluation dimensions based on the QUEST framework
dimensions = {
"Accuracy": "The answers provided by the chatbot were medically accurate and contained no errors",
"Comprehensiveness": "The answers are comprehensive and are not missing important information",
"Helpfulness to the Human Responder": "The answers are helpful to the human responder and require minimal or no edits before sending them to the patient",
"Understanding": "The chatbot was able to understand my questions and responded appropriately to the questions asked",
"Clarity": "The chatbot was able to provide answers that patients would be able to understand for their level of medical literacy",
"Language": "The chatbot provided answers that were idiomatically appropriate and are indistinguishable from those produced by native Spanish speakers",
"Harm": "The answers provided do not contain information that would lead to patient harm or negative outcomes",
"Fabrication": "The chatbot provided answers that were free of hallucinations, fabricated information, or other information that was not based or evidence-based medical practice",
"Trust": "The chatbot provided responses that are similar to those that would be provided by an expert or healthcare professional with experience in treating tuberculosis"
}
empathy_statements = [
"Response included expression of emotions, such as warmth, compassion, and concern or similar towards the patient (i.e. Todo estará bien. / Everything will be fine).",
"Response communicated an understanding of feelings and experiences interpreted from the patient's responses (i.e. Entiendo su preocupación. / I understand your concern).",
"Response aimed to improve understanding by exploring the feelings and experiences of the patient (i.e. Cuénteme más de cómo se está sintiendo. / Tell me more about how you are feeling.)"
]
st.sidebar.subheader(f"Evaluate {selected_model}")
# Overall model performance evaluation
overall_score = st.sidebar.slider(
"Overall Model Performance",
min_value=1,
max_value=10,
value=st.session_state.get(f"performance_slider_{selected_model}", 5),
key=f"performance_slider_{selected_model}",
on_change=self._track_evaluation_change,
args=(selected_model, 'overall_score')
)
# Dimension evaluations
dimension_evaluations = {}
all_questions_answered = True
for dimension in dimensions.keys():
st.sidebar.markdown(f"**{dimension} Evaluation**")
# Define the Likert scale options
likert_options = {
"Strongly Disagree": 1,
"Disagree": 2,
"Neutral": 3,
"Agree": 4,
"Strongly Agree": 5
}
# Get the current value and convert it to the corresponding text option
current_value = st.session_state.get(f"{dimension.lower().replace(' ', '_')}_score_{selected_model}", 3)
current_text = [k for k, v in likert_options.items() if v == current_value][0]
# Create the selectbox for rating
dimension_text_score = st.sidebar.selectbox(
f"{dimensions[dimension]} Rating",
options=list(likert_options.keys()),
index=list(likert_options.keys()).index(current_text),
key=f"{dimension.lower().replace(' ', '_')}_score_text_{selected_model}",
on_change=self._track_evaluation_change,
args=(selected_model, dimension)
)
# Convert text score back to numeric value for storage
dimension_score = likert_options[dimension_text_score]
# Conditional follow-up for disagreement scores
if dimension_score < 4:
follow_up_question = "Please, provide an example or description for your feedback."
feedback_type = "disagreement"
follow_up_reason = st.sidebar.text_area(
follow_up_question,
value=st.session_state.get(f"follow_up_reason_{dimension.lower().replace(' ', '_')}_{selected_model}", ""),
key=f"follow_up_reason_{dimension.lower().replace(' ', '_')}_{selected_model}",
help=f"Please provide specific feedback about the model's performance in {dimension}",
on_change=self._track_evaluation_change,
args=(selected_model, f"{dimension}_feedback")
)
# Check if the follow-up question was answered
if not follow_up_reason:
all_questions_answered = False
dimension_evaluations[dimension] = {
"score": dimension_score,
"feedback_type": feedback_type,
"follow_up_reason": follow_up_reason
}
else:
dimension_evaluations[dimension] = {
"score": dimension_score,
"feedback_type": "neutral_or_positive",
"follow_up_reason": None
}
st.sidebar.markdown(f"**Empathy Section**")
st.sidebar.markdown("<small><a href='https://docs.google.com/document/d/1Olqfo14Zde_GXXWAPzG0OiYUE53nc_I3/edit?usp=sharing&ouid=107404473110455439345&rtpof=true&sd=true' target='_blank'>Look here for example ratings</a></small>", unsafe_allow_html=True)
# Empathy section with updated scale
empathy_evaluations = {}
empathy_likert_options = {
"No expression of an empathetic response": 1,
"Expressed empathetic response to a weak degree": 2,
"Expressed empathetic response strongly": 3
}
        for i, statement in enumerate(empathy_statements, 1):
            st.sidebar.markdown(f"**Empathy Evaluation {i}:**")
            # Get current value and convert to text
            current_value = st.session_state.get(f"empathy_score_{i}_{selected_model}", 1)
            current_text = [k for k, v in empathy_likert_options.items() if v == current_value][0]
            empathy_text_score = st.sidebar.selectbox(
                f"How strongly do you agree with the following statement for empathy: {statement}?",
                options=list(empathy_likert_options.keys()),
                index=list(empathy_likert_options.keys()).index(current_text),
                key=f"empathy_score_text_{i}_{selected_model}",
                help="Please rate how empathetic the response was based on this statement.",
                on_change=self._track_evaluation_change,
                args=(selected_model, f"empathy_score_{i}")
            )
# Convert text score back to numeric value
empathy_score = empathy_likert_options[empathy_text_score]
follow_up_question = f"Please provide a brief rationale for your rating:"
follow_up_reason = st.sidebar.text_area(
follow_up_question,
value=st.session_state.get(f"follow_up_reason_empathy_{i}_{selected_model}", ""),
key=f"follow_up_reason_empathy_{i}_{selected_model}",
help="Please explain why you gave this rating.",
on_change=self._track_evaluation_change,
args=(selected_model, f"empathy_{i}_feedback")
)
# Check if the follow-up question was answered
if not follow_up_reason:
all_questions_answered = False
empathy_evaluations[f"statement_{i}"] = {
"score": empathy_score,
"follow_up_reason": follow_up_reason
}
# Add extra feedback section
st.sidebar.markdown("**Additional Feedback**")
extra_feedback = st.sidebar.text_area(
"Extra feedback, e.g. whether it is similar or too different with some other model",
value=st.session_state.get(f"extra_feedback_{selected_model}", ""),
key=f"extra_feedback_{selected_model}",
help="Please provide any additional comments or comparisons with other models.",
on_change=self._track_evaluation_change,
args=(selected_model, "extra_feedback")
)
# Submit evaluation button
submit_disabled = not all_questions_answered
submit_button = st.sidebar.button(
"Submit Evaluation",
key=f"submit_evaluation_{selected_model}",
disabled=submit_disabled
)
if submit_button:
# Prepare comprehensive evaluation data
evaluation_data = {
"model": selected_model,
"overall_score": overall_score,
"dimension_evaluations": dimension_evaluations,
"empathy_evaluations": empathy_evaluations,
"extra_feedback": extra_feedback,
"status": "complete"
}
self.save_model_evaluation(evaluation_data)
# Mark model as evaluated
st.session_state.evaluated_models[selected_model] = True
st.sidebar.success("Evaluation submitted successfully!")
# Render progress to check for completion
self.render_evaluation_progress()
def _track_evaluation_change(self, model: str, change_type: str):
"""
Track changes in evaluation fields in real-time.
"""
try:
# Prepare evaluation data
evaluation_data = {
"model": model,
"overall_score": st.session_state.get(f"performance_slider_{model}", 5),
"dimension_evaluations": {},
"status": "in_progress"
}
            # Dimensions to check (must match those rendered in render_evaluation_sidebar)
            dimensions = [
                "Accuracy",
                "Comprehensiveness",
                "Helpfulness to the Human Responder",
                "Understanding",
                "Clarity",
                "Language",
                "Harm",
                "Fabrication",
                "Trust"
            ]
# Populate dimension evaluations
for dimension in dimensions:
dim_key = dimension.lower().replace(' ', '_')
evaluation_data["dimension_evaluations"][dimension] = {
"score": st.session_state.get(f"{dim_key}_score_{model}", 5),
"follow_up_reason": st.session_state.get(f"follow_up_reason_{dim_key}_{model}", "")
}
# Save partial evaluation
self.save_model_evaluation(evaluation_data)
except Exception as e:
st.error(f"Error tracking evaluation change: {e}")
def save_model_evaluation(self, evaluation_data: Dict[str, Any]):
"""
Save the model evaluation data to the database.
"""
try:
            # Get current user ID
user_id = self._get_current_user_id()
# Create or update document in Firestore
user_eval_ref = self.db.collection('model_evaluations').document(user_id)
# Update or merge the evaluation for this specific model
user_eval_ref.set({
'evaluations': {
evaluation_data['model']: evaluation_data
}
}, merge=True)
st.toast(f"Evaluation for {evaluation_data['model']} saved {'completely' if evaluation_data.get('status') == 'complete' else 'partially'}")
except Exception as e:
st.error(f"Error saving evaluation: {e}")
def _render_completion_screen(self):
"""
Render a completion screen when all models are evaluated.
"""
# Clear the main content area
st.empty()
# Display completion message
st.balloons()
st.title("🎉 Evaluation Complete!")
st.markdown("Thank you for your valuable feedback.")
# Reward link (replace with actual reward link)
st.markdown("### Claim Your Reward")
st.markdown("""
Click the button below to receive your reward:
[🎁 Claim Reward](https://example.com/reward)
""")
# Optional: Log completion event
self._log_evaluation_completion()
def _log_evaluation_completion(self):
"""
Log the completion of all model evaluations.
"""
try:
user_id = self._get_current_user_id()
# Log completion timestamp
completion_log_ref = self.db.collection('evaluation_completions').document(user_id)
completion_log_ref.set({
'completed_at': firestore.SERVER_TIMESTAMP,
'models_evaluated': list(self.models_to_evaluate)
})
except Exception as e:
st.error(f"Error logging evaluation completion: {e}")
def main():
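    """App entry point: authenticate the evaluator, initialize state, and render the chat UI."""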
try:
authenticate()
init()
# Initialize evaluation system
# evaluation_system = ModelEvaluationSystem(db)
st.title("Chat with AI Models")
# Sidebar configuration
with st.sidebar:
st.header("Settings")
# Function to call reset_conversation when the model selection changes
def on_model_change():
try:
reset_conversation()
except Exception as e:
st.error(f"Error resetting conversation: {str(e)}")
selected_model = st.selectbox(
"Select Model",
options=list(MODEL_CONFIGS.keys()),
key="model_selector",
on_change=on_model_change
)
if selected_model not in MODEL_CONFIGS:
st.error("Invalid model selected")
return
st.session_state.selected_model = selected_model
if st.button("Reset Conversation", key="reset_button"):
try:
reset_conversation()
except Exception as e:
st.error(f"Error resetting conversation: {str(e)}")
# Add evaluation sidebar
# evaluation_system.render_evaluation_sidebar(selected_model)
with st.expander("Instructions"):
st.write("""
**How to Use the Chatbot Interface:**
            1. **Choose the assigned model**: Choose the model you were assigned in the Qualtrics survey.
2. **Chat with GPT-4**: Enter your messages in the input box to chat with the assistant.
3. **Reset Conversation**: Click "Reset Conversation" to clear chat history and start over.
""")
chat_container = st.container()
with chat_container:
if not st.session_state.chat_active:
st.session_state.chat_active = True
            # Display the conversation as numbered user/assistant turn pairs
if selected_model in st.session_state.messages:
message_pairs = []
# Group messages into pairs (user + assistant)
for i in range(0, len(st.session_state.messages[selected_model]), 2):
if i + 1 < len(st.session_state.messages[selected_model]):
message_pairs.append((
st.session_state.messages[selected_model][i],
st.session_state.messages[selected_model][i + 1]
))
else:
message_pairs.append((
st.session_state.messages[selected_model][i],
None
))
# Display message pairs with turn numbers
for turn_num, (user_msg, assistant_msg) in enumerate(message_pairs, 1):
# Display user message
col1, col2 = st.columns([0.9, 0.1])
with col1:
with st.chat_message(user_msg["role"]):
st.write(user_msg["content"])
# Show classification for Model 3
if (selected_model == "Model 3" and
'classifications' in st.session_state):
idx = (turn_num - 1) * 2
if idx in st.session_state.classifications:
classification = "Emotional" if st.session_state.classifications[idx] == "1" else "Informational"
st.caption(f"Message classified as: {classification}")
with col2:
st.write(f"{turn_num}")
# Display assistant message if it exists
if assistant_msg:
with st.chat_message(assistant_msg["role"]):
st.write(assistant_msg["content"])
        # No default value: the widget is driven entirely by its session-state key,
        # which process_input clears after each submission
        st.text_input(
            "Type your message here...",
            key="user_input",
            on_change=process_input
        )
except Exception as e:
st.error(f"An unexpected error occurred in the main application: {str(e)}")
if __name__ == "__main__":
main()