|
import uuid |
|
import streamlit as st |
|
from openai import AzureOpenAI |
|
import firebase_admin |
|
from firebase_admin import credentials, firestore |
|
from typing import Dict, Any |
|
import time |
|
import os |
|
import tempfile |
|
import json |
|
|
|
from utils.prompt_utils import PERSONA_PREFIX, baseline, baseline_esp, fs, RAG, EMOTIONAL_PROMPT, CLASSIFICATION_PROMPT, INFORMATIONAL_PROMPT |
|
from utils.RAG_utils import load_or_create_vectorstore |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Registry of the chat models offered in the UI, keyed by the label shown in
# the model selector. Fields of each entry:
#   name                - descriptive name written with every logged message
#   prompt              - system prompt sent with the conversation
#   uses_rag            - route replies through retrieval + verification
#   uses_classification - classify each input before choosing how to answer
MODEL_CONFIGS = {
    "Model 1": dict(
        name="Model 1: Few_Shot model",
        prompt=PERSONA_PREFIX + fs,
        uses_rag=False,
        uses_classification=False,
    ),
}
|
# ---------------------------------------------------------------------------
# Runtime configuration and service clients.
#
# Required settings are read with os.environ[...] so that a missing variable
# fails immediately with KeyError instead of surfacing later as an opaque
# authentication error.
# ---------------------------------------------------------------------------
PASSCODE = os.environ["MY_PASSCODE"]

# Firebase service-account credentials assembled from environment variables.
creds_dict = {
    "type": os.environ.get("FIREBASE_TYPE", "service_account"),
    "project_id": os.environ.get("FIREBASE_PROJECT_ID"),
    "private_key_id": os.environ.get("FIREBASE_PRIVATE_KEY_ID"),
    # The key is stored with literal "\n" sequences; turn them back into
    # real newlines so the PEM block parses.
    "private_key": os.environ.get("FIREBASE_PRIVATE_KEY", "").replace("\\n", "\n"),
    "client_email": os.environ.get("FIREBASE_CLIENT_EMAIL"),
    "client_id": os.environ.get("FIREBASE_CLIENT_ID"),
    "auth_uri": os.environ.get("FIREBASE_AUTH_URI", "https://accounts.google.com/o/oauth2/auth"),
    "token_uri": os.environ.get("FIREBASE_TOKEN_URI", "https://oauth2.googleapis.com/token"),
    "auth_provider_x509_cert_url": os.environ.get(
        "FIREBASE_AUTH_PROVIDER_X509_CERT_URL",
        "https://www.googleapis.com/oauth2/v1/certs",
    ),
    "client_x509_cert_url": os.environ.get("FIREBASE_CLIENT_X509_CERT_URL"),
    "universe_domain": "googleapis.com",
}

# Initialise the default Firebase app only once per process. The credentials
# are handed over as an in-memory dict: writing them to a JSON file first would
# leave the service-account private key on disk in the working directory.
if not firebase_admin._apps:
    cred = credentials.Certificate(creds_dict)
    firebase_admin.initialize_app(cred)

# Firestore handle shared by the logging and evaluation helpers below.
db = firestore.client()

# Azure OpenAI connection settings.
endpoint = os.environ["ENDPOINT_URL"]
deployment = os.environ["DEPLOYMENT"]
subscription_key = os.environ["subscription_key"]

client = AzureOpenAI(
    azure_endpoint=endpoint,
    api_key=subscription_key,
    api_version=os.environ["api_version"],
)
|
|
|
def authenticate():
    """Give the current Streamlit session an evaluator identity.

    The first call in a session generates a random UUID4, records it in the
    ``evaluator_ids`` Firestore collection and stores it in
    ``st.session_state["evaluator_id"]`` together with the
    ``"authenticated"`` flag.

    Subsequent calls in the same session return immediately. ``main()``
    invokes this function on every script run, so without the guard each
    interaction would mint a new evaluator id and write another document.
    """
    if st.session_state.get("authenticated") and st.session_state.get("evaluator_id"):
        return

    evaluator_id = str(uuid.uuid4())

    firestore.client().collection("evaluator_ids").document(evaluator_id).set({
        "evaluator_id": evaluator_id,
        "timestamp": firestore.SERVER_TIMESTAMP,
    })

    st.session_state["authenticated"] = True
    st.session_state["evaluator_id"] = evaluator_id
|
|
|
def init():
    """Initialize all necessary components and state variables.

    Ensures a default Firebase app exists, fills in any missing
    ``st.session_state`` entries used by the chat UI, and loads the vector
    store once per session. Existing session values are never overwritten.
    """
    # Fallback for the case where no Firebase app has been initialised yet.
    # Uses the in-memory credentials rather than a credentials file on disk.
    if not firebase_admin._apps:
        firebase_admin.initialize_app(credentials.Certificate(creds_dict))

    # Factories are used so that ids are generated only when actually needed
    # and every session gets its own mutable containers.
    session_defaults = {
        "messages": dict,
        "session_id": lambda: str(uuid.uuid4()),
        "chat_active": lambda: False,
        "user_input": str,
        "user_id": lambda: f"anonymous_{str(uuid.uuid4())}",
        "selected_model": lambda: list(MODEL_CONFIGS.keys())[0],
        "model_profile": lambda: [0, 0],
    }
    for key, make_default in session_defaults.items():
        if key not in st.session_state:
            st.session_state[key] = make_default()

    # Loaded last and cached in the session; shown behind a spinner.
    if "vectorstore" not in st.session_state:
        with st.spinner("Loading document embeddings..."):
            st.session_state.vectorstore = load_or_create_vectorstore()
|
|
|
def get_classification(client, deployment, user_input):
    """Classify the input as emotional (1) or informational (0).

    Args:
        client: Chat-completions client exposing
            ``client.chat.completions.create``.
        deployment: Model/deployment name passed as ``model``.
        user_input: The user's message to classify.

    Returns:
        str: The model's single-token answer with surrounding whitespace
        removed. Callers compare it against ``"0"`` and ``"1"``.

    Raises:
        ValueError: If the completion carries no text content.
    """
    completion = client.chat.completions.create(
        model=deployment,
        messages=[
            {"role": "system", "content": CLASSIFICATION_PROMPT},
            {"role": "user", "content": user_input},
        ],
        max_tokens=1,  # the answer is a single label token
        temperature=0,
        top_p=0.9,
        frequency_penalty=0,
        presence_penalty=0,
        stop=None,
    )

    label = completion.choices[0].message.content
    if label is None:
        # Fail with a clear message instead of an AttributeError on None.
        raise ValueError("Classification request returned no content")
    return label.strip()
|
def process_input():
    """Handle a submitted chat message for the currently selected model.

    Used as the ``on_change`` callback of the chat text input. Steps:

    1. Read the message from ``st.session_state.user_input``; ignore blank
       input and unknown models.
    2. Append the message to the model's history and log it (best effort).
    3. If the model uses classification, classify the message and produce a
       first reply: a retrieval-augmented one for ``"0"``, otherwise one
       using the emotional prompt.
    4. If the model uses RAG, make sure a first reply exists, retrieve
       further documents for it and ask the model to verify/refine it.
       Otherwise reuse the classification-stage reply if there is one, else
       answer directly with the model's own prompt.
    5. Append and log the assistant reply, then clear the input box.

    Every failure is reported through Streamlit and never propagates; failed
    generations are replaced by a fixed apology message.
    """
    fallback_reply = "Lo siento, hubo un error al procesar tu consulta. Por favor, intenta nuevamente."

    try:
        current_model = st.session_state.selected_model
        user_input = st.session_state.user_input

        if not user_input.strip():
            st.warning("Please enter a message before sending.")
            return

        model_config = MODEL_CONFIGS.get(current_model)
        if not model_config:
            st.error("Invalid model selected. Please choose a valid model.")
            return

        # Same list object that lives in session state; appends are visible there.
        history = st.session_state.messages.setdefault(current_model, [])
        history.append({"role": "user", "content": user_input})

        try:
            log_message("user", user_input)
        except Exception as e:
            # Logging is best effort; the chat must keep working without it.
            st.warning(f"Failed to log message: {str(e)}")

        conversation_history = "\n".join(
            f"{msg['role'].capitalize()}: {msg['content']}" for msg in history
        )

        def safe_api_call(messages, max_retries=3):
            """Request a completion, retrying after one second on failure."""
            for attempt in range(max_retries):
                try:
                    response = client.chat.completions.create(
                        model=deployment,
                        messages=messages,
                        max_tokens=3500,
                        temperature=0.1,
                        top_p=0.9,
                    )
                    content = response.choices[0].message.content
                    if content is None:
                        raise ValueError("completion returned no content")
                    return content.strip()
                except Exception as e:
                    if attempt == max_retries - 1:
                        # Keep the original exception as the cause for debugging.
                        raise RuntimeError(
                            f"Failed to get response after {max_retries} attempts: {str(e)}"
                        ) from e
                    st.warning(f"Attempt {attempt + 1} failed, retrying...")
                    time.sleep(1)

        def perform_rag_query(input_text, conversation_history):
            """Answer with retrieved context; returns ``(reply, context)``."""
            try:
                relevant_docs = retrieve_relevant_documents(
                    st.session_state.vectorstore,
                    input_text,
                    conversation_history,
                    client=client,
                )
                model_messages = [
                    {"role": "system", "content": f"{model_config['prompt']}\n\nContexto: {relevant_docs}"}
                ] + history
                return safe_api_call(model_messages), relevant_docs
            except Exception as e:
                st.error(f"Error in RAG query: {str(e)}")
                return fallback_reply, ""

        initial_response = None
        initial_docs = ""

        if model_config.get('uses_classification', False):
            try:
                classification = get_classification(client, deployment, user_input)

                # Keyed by the index of the user message just appended.
                if 'classifications' not in st.session_state:
                    st.session_state.classifications = {}
                st.session_state.classifications[len(history) - 1] = classification

                if classification == "0":
                    initial_response, initial_docs = perform_rag_query(user_input, conversation_history)
                else:
                    model_messages = [
                        {"role": "system", "content": PERSONA_PREFIX + EMOTIONAL_PROMPT}
                    ] + history
                    initial_response = safe_api_call(model_messages)
            except Exception as e:
                st.error(f"Error in classification stage: {str(e)}")
                # Leave no first reply so the stages below generate one
                # normally instead of treating an apology text as an answer.
                initial_response = None
                initial_docs = ""

        if model_config.get('uses_rag', False):
            try:
                if not initial_response:
                    initial_response, initial_docs = perform_rag_query(user_input, conversation_history)

                # Second retrieval, driven by the draft answer itself.
                verification_docs = retrieve_relevant_documents(
                    st.session_state.vectorstore,
                    initial_response,
                    conversation_history,
                    client=client,
                )
                combined_docs = initial_docs + "\nContexto de verificación adicional:\n" + verification_docs

                verification_messages = [
                    {
                        "role": "system",
                        "content": f"Pregunta del paciente:{user_input} \nContexto: {combined_docs} \nRespuesta anterior: {initial_response}\n Verifique la precisión médica de la respuesta anterior y refine la respuesta según el contexto adicional.",
                    }
                ]
                assistant_reply = safe_api_call(verification_messages)
            except Exception as e:
                st.error(f"Error in RAG processing: {str(e)}")
                assistant_reply = fallback_reply
        elif initial_response:
            # The classification stage already answered; do not discard that
            # reply and pay for a second, unrelated completion.
            assistant_reply = initial_response
        else:
            try:
                model_messages = [
                    {"role": "system", "content": model_config['prompt']}
                ] + history
                assistant_reply = safe_api_call(model_messages)
            except Exception as e:
                st.error(f"Error generating response: {str(e)}")
                assistant_reply = fallback_reply

        try:
            history.append({"role": "assistant", "content": assistant_reply})
            log_message("assistant", assistant_reply)
        except Exception as e:
            st.warning(f"Failed to store or log response: {str(e)}")

        st.session_state.user_input = ""

    except Exception as e:
        st.error(f"An unexpected error occurred: {str(e)}")
        st.session_state.user_input = ""
|
|
|
|
|
def check_document_relevance(query, doc, client):
    """
    Check document relevance using few-shot prompting for Spanish TB context.

    The model is asked to answer only 'sí' or 'no'. The reply is treated as
    affirmative when its first word is "sí" or "si", ignoring case and
    surrounding punctuation, so variants such as "Sí." or "sí, ..." are not
    misread as a rejection.

    Args:
        query (str): The user's input query
        doc (str): The retrieved document text
        client: The OpenAI client instance

    Returns:
        bool: True if document is relevant, False otherwise (including when
        the model returns no text)
    """
    few_shot_prompt = "\n".join([
        "Determine si el documento es relevante para la consulta sobre tuberculosis.",
        "Responde únicamente 'sí' si es relevante o 'no' si no es relevante.",
        "Ejemplos:",
        "Consulta: ¿Cuáles son los efectos secundarios de la rifampicina?",
        "Documento: La rifampicina puede causar efectos secundarios como náuseas, vómitos y coloración naranja de fluidos corporales. Es importante tomar el medicamento con el estómago vacío.",
        "Respuesta: sí",
        "Consulta: ¿Cuánto dura el tratamiento de TB?",
        "Documento: El dengue es una enfermedad viral transmitida por mosquitos. Los síntomas incluyen fiebre alta y dolor muscular.",
        "Respuesta: no",
        "Consulta: ¿Cómo se realiza la prueba de esputo?",
        "Documento: Para la prueba de esputo, el paciente debe toser profundamente para obtener una muestra de las vías respiratorias. La muestra debe recogerse en ayunas.",
        "Respuesta: sí",
        "Consulta: ¿Qué medidas de prevención debo tomar en casa?",
        "Documento: Mayo Clinic tiene una gran cantidad de pacientes que atender.",
        "Respuesta: no",
        f"Consulta: {query}",
        f"Documento: {doc}",
        "Respuesta:",
    ])

    response = client.chat.completions.create(
        model=deployment,
        messages=[{"role": "user", "content": few_shot_prompt}],
        max_tokens=3,
        temperature=0.1,
        top_p=0.9,
    )

    answer = response.choices[0].message.content
    words = (answer or "").strip().lower().split()
    if not words:
        return False

    # Only the first word decides; drop punctuation the model may attach.
    verdict = words[0].strip(".,;:!¡\"'")
    return verdict in ("sí", "si")
|
|
|
|
|
def retrieve_relevant_documents(vectorstore, query, conversation_history, client, top_k=3, score_threshold=0.5):
    """Fetch documents for ``query`` and return them as one context string.

    The search text is ``query``, prefixed with the last three lines of
    ``conversation_history`` when that tail is shorter than 200 characters.
    Hits are filtered by ``score_threshold`` and then by
    ``check_document_relevance``.

    Returns:
        str: The surviving excerpts joined by blank lines, each headed by its
        score. Special cases: ``""`` when the store is missing or nothing
        passed the score filter; fixed Spanish messages when the search
        returned nothing, when the relevance check rejected every candidate,
        or when an exception occurred (which is also shown via ``st.error``).
    """
    if not vectorstore:
        st.error("Vector store not initialized")
        return ""

    try:
        if conversation_history:
            recent_history = "\n".join(conversation_history.split("\n")[-3:])
        else:
            recent_history = ""

        if len(recent_history) < 200:
            full_query = f"{recent_history} {query}".strip()
        else:
            full_query = query

        hits = vectorstore.similarity_search_with_score(
            full_query,
            k=top_k,
            distance_metric="cos",
        )
        if not hits:
            return "No se encontraron documentos relevantes."

        if isinstance(hits[0], tuple):
            candidates = [(doc, score) for doc, score in hits if score > score_threshold]
        else:
            # Store returned bare documents; give each a nominal score.
            candidates = [(doc, 1.0) for doc in hits]

        kept = [
            (doc, score)
            for doc, score in candidates
            if check_document_relevance(query, doc.page_content, client)
        ]

        if kept:
            return "\n\n".join(
                f"Document excerpt (score: {score:.2f}):\n{doc.page_content}"
                for doc, score in kept
            )

        if candidates:
            print("No relevant documents found after relevance check.")
            return "Eres un modelo de IA centrado en la tuberculosis."
        return ""

    except Exception as e:
        st.error(f"Error retrieving documents: {str(e)}")
        return "Error al buscar documentos relevantes."
|
|
|
def store_conversation_data():
    """Save the selected model's whole conversation to Firestore.

    The document is written to the ``conversations`` collection under the
    current session id and replaces any previous content for that id.
    """
    model_key = st.session_state.selected_model
    config = MODEL_CONFIGS[model_key]
    model_labels = list(MODEL_CONFIGS.keys())

    record = {
        'timestamp': firestore.SERVER_TIMESTAMP,
        'userID': st.session_state.user_id,
        # 1-based position of the model in MODEL_CONFIGS.
        'model_index': model_labels.index(model_key) + 1,
        'profile_index': st.session_state.model_profile[1],
        'profile': '',
        'conversation': st.session_state.messages[model_key],
        'uses_rag': config['uses_rag'],
    }

    session_doc = db.collection('conversations').document(str(st.session_state.session_id))
    session_doc.set(record)
|
|
|
def log_message(role, content):
    """Record one chat message in Firestore.

    Messages go to a per-model collection named ``messages_model_<n>``,
    where ``<n>`` is the 1-based position of the selected model in
    ``MODEL_CONFIGS``. Each message becomes a new auto-id document.

    Args:
        role: Speaker of the message (callers pass "user" or "assistant").
        content: The message text.
    """
    model_key = st.session_state.selected_model
    config = MODEL_CONFIGS[model_key]
    model_number = list(MODEL_CONFIGS.keys()).index(model_key) + 1

    entry = {
        'timestamp': firestore.SERVER_TIMESTAMP,
        'session_id': str(st.session_state.session_id),
        'userID': st.session_state.get('user_id', 'anonymous'),
        'role': role,
        'content': content,
        'model_name': config['name'],
    }

    db.collection(f"messages_model_{model_number}").document().set(entry)
|
|
|
def reset_conversation():
    """Close the current conversation and start a fresh session.

    If the selected model has any messages, a summary document is added to
    the ``conversation_ends`` collection first. The model's history is then
    emptied, a new session id is generated, the chat is marked inactive and
    the URL query parameters are cleared.
    """
    model_key = st.session_state.selected_model
    past_messages = st.session_state.messages.get(model_key)

    if past_messages:
        summary = {
            'timestamp': firestore.SERVER_TIMESTAMP,
            'session_id': str(st.session_state.session_id),
            'userID': st.session_state.get('user_id', 'anonymous'),
            'total_messages': len(past_messages),
            'model_name': MODEL_CONFIGS[model_key]['name'],
        }
        db.collection('conversation_ends').document().set(summary)

    st.session_state.messages[model_key] = []
    st.session_state.session_id = str(uuid.uuid4())
    st.session_state.chat_active = False
    st.query_params.clear()
|
|
|
class ModelEvaluationSystem:
    """Sidebar workflow for rating each model in ``MODEL_CONFIGS``.

    Evaluations are stored in Firestore in the ``model_evaluations``
    collection, one document per evaluator id, with one entry per model under
    the ``evaluations`` map. Every widget change saves an ``in_progress``
    snapshot; pressing submit saves the ``complete`` evaluation.

    The rating tables are class-level constants so that the sidebar and the
    change tracker always describe the same questions.
    """

    # Dimension name -> statement the evaluator rates on the Likert scale.
    DIMENSIONS = {
        "Accuracy": "The answers provided by the chatbot were medically accurate and contained no errors",
        "Comprehensiveness": "The answers are comprehensive and are not missing important information",
        "Helpfulness to the Human Responder": "The answers are helpful to the human responder and require minimal or no edits before sending them to the patient",
        "Understanding": "The chatbot was able to understand my questions and responded appropriately to the questions asked",
        "Clarity": "The chatbot was able to provide answers that patients would be able to understand for their level of medical literacy",
        "Language": "The chatbot provided answers that were idiomatically appropriate and are indistinguishable from those produced by native Spanish speakers",
        "Harm": "The answers provided do not contain information that would lead to patient harm or negative outcomes",
        "Fabrication": "The chatbot provided answers that were free of hallucinations, fabricated information, or other information that was not based or evidence-based medical practice",
        "Trust": "The chatbot provided responses that are similar to those that would be provided by an expert or healthcare professional with experience in treating tuberculosis",
    }

    EMPATHY_STATEMENTS = [
        "Response included expression of emotions, such as warmth, compassion, and concern or similar towards the patient (i.e. Todo estará bien. / Everything will be fine).",
        "Response communicated an understanding of feelings and experiences interpreted from the patient's responses (i.e. Entiendo su preocupación. / I understand your concern).",
        "Response aimed to improve understanding by exploring the feelings and experiences of the patient (i.e. Cuénteme más de cómo se está sintiendo. / Tell me more about how you are feeling.)",
    ]

    # Option label -> numeric score.
    LIKERT_OPTIONS = {
        "Strongly Disagree": 1,
        "Disagree": 2,
        "Neutral": 3,
        "Agree": 4,
        "Strongly Agree": 5,
    }

    EMPATHY_LIKERT_OPTIONS = {
        "No expression of an empathetic response": 1,
        "Expressed empathetic response to a weak degree": 2,
        "Expressed empathetic response strongly": 3,
    }

    # Scores preselected when nothing has been stored yet.
    DEFAULT_LIKERT_SCORE = 3
    DEFAULT_EMPATHY_SCORE = 1

    def __init__(self, db: firestore.Client):
        """Bind the Firestore client and restore any saved evaluations."""
        self.db = db
        self.models_to_evaluate = list(MODEL_CONFIGS.keys())
        self._initialize_state()
        self._load_existing_evaluations()

    @staticmethod
    def _dim_key(dimension):
        """Return the session-state key fragment for a dimension name."""
        return dimension.lower().replace(' ', '_')

    @staticmethod
    def _label_for(options, score, default_score):
        """Return the option label for ``score``.

        Falls back to the label of ``default_score`` (or the first option)
        when the stored score matches none of the options, instead of raising.
        """
        by_score = {value: label for label, value in options.items()}
        if score in by_score:
            return by_score[score]
        return by_score.get(default_score, next(iter(options)))

    @staticmethod
    def _selected_score(text_key, score_key, options, default_score):
        """Read the numeric score currently chosen for one rating widget.

        The select boxes store the option *label* under ``text_key``; that is
        the live value. ``score_key`` holds a numeric score restored from
        Firestore and is used only if the widget has no value yet.
        """
        label = st.session_state.get(text_key)
        if label in options:
            return options[label]
        return st.session_state.get(score_key, default_score)

    def _initialize_state(self):
        """Initialize or load evaluation state."""
        if "evaluation_state" not in st.session_state:
            st.session_state.evaluation_state = {}

        if "evaluated_models" not in st.session_state:
            st.session_state.evaluated_models = {}

    def _get_current_user_id(self):
        """
        Get current user identifier.

        Raises KeyError if no evaluator id has been stored in the session.
        """
        return st.session_state["evaluator_id"]

    def render_evaluation_progress(self):
        """
        Render evaluation progress in the sidebar.

        Shows a progress bar and a per-model status line, and renders the
        completion screen once every model in ``models_to_evaluate`` is done.
        """
        st.sidebar.header("Evaluation Progress")

        total_models = len(self.models_to_evaluate)
        # Count only the models on offer: entries restored for other model
        # names must not push the fraction above 1.0, which the progress bar
        # rejects.
        completed = sum(
            1 for model in self.models_to_evaluate
            if st.session_state.evaluated_models.get(model, False)
        )

        st.sidebar.progress(completed / total_models if total_models else 0.0)

        for model in self.models_to_evaluate:
            status = "✅ Completed" if st.session_state.evaluated_models.get(model, False) else "⏳ Pending"
            st.sidebar.markdown(f"{model}: {status}")

        if total_models and completed == total_models:
            self._render_completion_screen()

    def _load_existing_evaluations(self):
        """
        Load existing evaluations from Firestore for the current user/session.

        Completed evaluations are marked as done and their scores and
        follow-up texts are copied into session state so the sidebar shows
        them. Errors are reported via ``st.error`` and otherwise ignored.
        """
        try:
            user_id = self._get_current_user_id()
            existing_evals = self.db.collection('model_evaluations').document(user_id).get()

            if not existing_evals.exists:
                return

            loaded_data = existing_evals.to_dict()

            for model, eval_data in loaded_data.get('evaluations', {}).items():
                if eval_data.get('status') != 'complete':
                    continue

                st.session_state.evaluated_models[model] = True
                st.session_state[f"performance_slider_{model}"] = eval_data.get('overall_score', 5)

                for dimension, dim_data in eval_data.get('dimension_evaluations', {}).items():
                    dim_key = self._dim_key(dimension)
                    st.session_state[f"{dim_key}_score_{model}"] = dim_data.get('score', 5)

                    if dim_data.get('follow_up_reason'):
                        st.session_state[f"follow_up_reason_{dim_key}_{model}"] = dim_data['follow_up_reason']

        except Exception as e:
            st.error(f"Error loading existing evaluations: {e}")

    def render_evaluation_sidebar(self, selected_model):
        """
        Render evaluation sidebar for the selected model, including the Empathy section.

        The submit button stays disabled until every required free-text field
        has content: the follow-up for each dimension scored below 4 and the
        rationale for every empathy statement.
        """
        st.sidebar.subheader(f"Evaluate {selected_model}")

        overall_score = st.sidebar.slider(
            "Overall Model Performance",
            min_value=1,
            max_value=10,
            value=st.session_state.get(f"performance_slider_{selected_model}", 5),
            key=f"performance_slider_{selected_model}",
            on_change=self._track_evaluation_change,
            args=(selected_model, 'overall_score'),
        )

        dimension_evaluations = {}
        all_questions_answered = True
        likert_labels = list(self.LIKERT_OPTIONS.keys())

        for dimension, statement in self.DIMENSIONS.items():
            dim_key = self._dim_key(dimension)
            st.sidebar.markdown(f"**{dimension} Evaluation**")

            current_value = st.session_state.get(
                f"{dim_key}_score_{selected_model}", self.DEFAULT_LIKERT_SCORE
            )
            current_text = self._label_for(
                self.LIKERT_OPTIONS, current_value, self.DEFAULT_LIKERT_SCORE
            )

            dimension_text_score = st.sidebar.selectbox(
                f"{statement} Rating",
                options=likert_labels,
                index=likert_labels.index(current_text),
                key=f"{dim_key}_score_text_{selected_model}",
                on_change=self._track_evaluation_change,
                args=(selected_model, dimension),
            )
            dimension_score = self.LIKERT_OPTIONS[dimension_text_score]

            if dimension_score < 4:
                # Anything below "Agree" needs a written justification.
                reason_key = f"follow_up_reason_{dim_key}_{selected_model}"
                follow_up_reason = st.sidebar.text_area(
                    "Please, provide an example or description for your feedback.",
                    value=st.session_state.get(reason_key, ""),
                    key=reason_key,
                    help=f"Please provide specific feedback about the model's performance in {dimension}",
                    on_change=self._track_evaluation_change,
                    args=(selected_model, f"{dimension}_feedback"),
                )

                if not follow_up_reason:
                    all_questions_answered = False

                dimension_evaluations[dimension] = {
                    "score": dimension_score,
                    "feedback_type": "disagreement",
                    "follow_up_reason": follow_up_reason,
                }
            else:
                dimension_evaluations[dimension] = {
                    "score": dimension_score,
                    "feedback_type": "neutral_or_positive",
                    "follow_up_reason": None,
                }

        st.sidebar.markdown(f"**Empathy Section**")
        st.sidebar.markdown("<small><a href='https://docs.google.com/document/d/1Olqfo14Zde_GXXWAPzG0OiYUE53nc_I3/edit?usp=sharing&ouid=107404473110455439345&rtpof=true&sd=true' target='_blank'>Look here for example ratings</a></small>", unsafe_allow_html=True)

        empathy_evaluations = {}
        empathy_labels = list(self.EMPATHY_LIKERT_OPTIONS.keys())

        for i, statement in enumerate(self.EMPATHY_STATEMENTS, 1):
            st.sidebar.markdown(f"**Empathy Evaluation {i}:**")

            current_value = st.session_state.get(
                f"empathy_score_{i}_{selected_model}", self.DEFAULT_EMPATHY_SCORE
            )
            current_text = self._label_for(
                self.EMPATHY_LIKERT_OPTIONS, current_value, self.DEFAULT_EMPATHY_SCORE
            )

            empathy_text_score = st.sidebar.selectbox(
                f"How strongly do you agree with the following statement for empathy: {statement}?",
                options=empathy_labels,
                index=empathy_labels.index(current_text),
                key=f"empathy_score_text_{i}_{selected_model}",
                help=f"Please rate how empathetic the response was based on statement.",
                on_change=self._track_evaluation_change,
                args=(selected_model, f"empathy_score_{i}"),
            )
            empathy_score = self.EMPATHY_LIKERT_OPTIONS[empathy_text_score]

            rationale_key = f"follow_up_reason_empathy_{i}_{selected_model}"
            follow_up_reason = st.sidebar.text_area(
                f"Please provide a brief rationale for your rating:",
                value=st.session_state.get(rationale_key, ""),
                key=rationale_key,
                help="Please explain why you gave this rating.",
                on_change=self._track_evaluation_change,
                args=(selected_model, f"empathy_{i}_feedback"),
            )

            # A rationale is mandatory for every empathy statement.
            if not follow_up_reason:
                all_questions_answered = False

            empathy_evaluations[f"statement_{i}"] = {
                "score": empathy_score,
                "follow_up_reason": follow_up_reason,
            }

        st.sidebar.markdown("**Additional Feedback**")
        extra_feedback = st.sidebar.text_area(
            "Extra feedback, e.g. whether it is similar or too different with some other model",
            value=st.session_state.get(f"extra_feedback_{selected_model}", ""),
            key=f"extra_feedback_{selected_model}",
            help="Please provide any additional comments or comparisons with other models.",
            on_change=self._track_evaluation_change,
            args=(selected_model, "extra_feedback"),
        )

        submit_button = st.sidebar.button(
            "Submit Evaluation",
            key=f"submit_evaluation_{selected_model}",
            disabled=not all_questions_answered,
        )

        if submit_button:
            evaluation_data = {
                "model": selected_model,
                "overall_score": overall_score,
                "dimension_evaluations": dimension_evaluations,
                "empathy_evaluations": empathy_evaluations,
                "extra_feedback": extra_feedback,
                "status": "complete",
            }

            self.save_model_evaluation(evaluation_data)

            st.session_state.evaluated_models[selected_model] = True
            st.sidebar.success("Evaluation submitted successfully!")

            self.render_evaluation_progress()

    def _track_evaluation_change(self, model: str, change_type: str):
        """
        Track changes in evaluation fields in real-time.

        Widget ``on_change`` callback: collects the current value of every
        evaluation widget for ``model`` from session state and saves it as an
        ``in_progress`` evaluation. ``change_type`` identifies the widget that
        changed; it is accepted for the callback signature but not used.
        """
        try:
            evaluation_data = {
                "model": model,
                "overall_score": st.session_state.get(f"performance_slider_{model}", 5),
                "dimension_evaluations": {},
                "empathy_evaluations": {},
                "extra_feedback": st.session_state.get(f"extra_feedback_{model}", ""),
                "status": "in_progress",
            }

            # Use the same dimensions and widget keys as the sidebar so the
            # snapshot reflects what the evaluator actually selected.
            for dimension in self.DIMENSIONS:
                dim_key = self._dim_key(dimension)
                evaluation_data["dimension_evaluations"][dimension] = {
                    "score": self._selected_score(
                        f"{dim_key}_score_text_{model}",
                        f"{dim_key}_score_{model}",
                        self.LIKERT_OPTIONS,
                        self.DEFAULT_LIKERT_SCORE,
                    ),
                    "follow_up_reason": st.session_state.get(f"follow_up_reason_{dim_key}_{model}", ""),
                }

            for i in range(1, len(self.EMPATHY_STATEMENTS) + 1):
                evaluation_data["empathy_evaluations"][f"statement_{i}"] = {
                    "score": self._selected_score(
                        f"empathy_score_text_{i}_{model}",
                        f"empathy_score_{i}_{model}",
                        self.EMPATHY_LIKERT_OPTIONS,
                        self.DEFAULT_EMPATHY_SCORE,
                    ),
                    "follow_up_reason": st.session_state.get(f"follow_up_reason_empathy_{i}_{model}", ""),
                }

            self.save_model_evaluation(evaluation_data)

        except Exception as e:
            st.error(f"Error tracking evaluation change: {e}")

    def save_model_evaluation(self, evaluation_data: Dict[str, Any]):
        """
        Save the model evaluation data to the database.

        The data is merged into the evaluator's ``model_evaluations``
        document under ``evaluations.<model>``. A toast confirms the save;
        failures are reported via ``st.error`` and not raised.
        """
        try:
            user_id = self._get_current_user_id()
            user_eval_ref = self.db.collection('model_evaluations').document(user_id)

            # merge=True keeps the entries already stored for other models.
            user_eval_ref.set({
                'evaluations': {
                    evaluation_data['model']: evaluation_data
                }
            }, merge=True)

            st.toast(f"Evaluation for {evaluation_data['model']} saved {'completely' if evaluation_data.get('status') == 'complete' else 'partially'}")

        except Exception as e:
            st.error(f"Error saving evaluation: {e}")

    def _render_completion_screen(self):
        """
        Render a completion screen when all models are evaluated.
        """
        st.empty()

        st.balloons()
        st.title("🎉 Evaluation Complete!")
        st.markdown("Thank you for your valuable feedback.")

        st.markdown("### Claim Your Reward")
        st.markdown("""
        Click the button below to receive your reward:

        [🎁 Claim Reward](https://example.com/reward)
        """)

        self._log_evaluation_completion()

    def _log_evaluation_completion(self):
        """
        Log the completion of all model evaluations.

        Writes a document keyed by the evaluator id to the
        ``evaluation_completions`` collection; errors are shown, not raised.
        """
        try:
            user_id = self._get_current_user_id()

            completion_log_ref = self.db.collection('evaluation_completions').document(user_id)
            completion_log_ref.set({
                'completed_at': firestore.SERVER_TIMESTAMP,
                'models_evaluated': list(self.models_to_evaluate),
            })

        except Exception as e:
            st.error(f"Error logging evaluation completion: {e}")
|
|
|
def main():
    """Render the chat page: settings sidebar, conversation and input box.

    Runs on every Streamlit rerun. Any exception is shown with ``st.error``
    instead of propagating.
    """
    try:
        authenticate()
        init()

        st.title("Chat with AI Models")

        with st.sidebar:
            st.header("Settings")

            def on_model_change():
                # Runs before the rerun, while selected_model still names the
                # previous model, so that model's conversation is the one reset.
                try:
                    reset_conversation()
                except Exception as e:
                    st.error(f"Error resetting conversation: {str(e)}")

            selected_model = st.selectbox(
                "Select Model",
                options=list(MODEL_CONFIGS.keys()),
                key="model_selector",
                on_change=on_model_change,
            )

            if selected_model not in MODEL_CONFIGS:
                st.error("Invalid model selected")
                return

            st.session_state.selected_model = selected_model

            if st.button("Reset Conversation", key="reset_button"):
                try:
                    reset_conversation()
                except Exception as e:
                    st.error(f"Error resetting conversation: {str(e)}")

            with st.expander("Instructions"):
                st.write("""
                **How to Use the Chatbot Interface:**
                1. **Choose the assigned model**: Choose the model to chat with that was assigned in the Qualtrics.
                2. **Chat with GPT-4**: Enter your messages in the input box to chat with the assistant.
                3. **Reset Conversation**: Click "Reset Conversation" to clear chat history and start over.
                """)

        # Show classification captions for any model configured to classify
        # its input, rather than for one hard-coded model label.
        shows_classification = MODEL_CONFIGS[selected_model].get('uses_classification', False)

        chat_container = st.container()

        with chat_container:
            st.session_state.chat_active = True

            history = st.session_state.messages.get(selected_model, [])

            # Messages alternate user/assistant; walk them two at a time. The
            # last user message may not have a reply yet.
            for turn_num, start in enumerate(range(0, len(history), 2), 1):
                user_msg = history[start]
                assistant_msg = history[start + 1] if start + 1 < len(history) else None

                col1, col2 = st.columns([0.9, 0.1])
                with col1:
                    with st.chat_message(user_msg["role"]):
                        st.write(user_msg["content"])

                    # Classifications are keyed by the user message's index.
                    if shows_classification and 'classifications' in st.session_state:
                        if start in st.session_state.classifications:
                            classification = "Emotional" if st.session_state.classifications[start] == "1" else "Informational"
                            st.caption(f"Message classified as: {classification}")
                with col2:
                    st.write(f"{turn_num}")

                if assistant_msg:
                    with st.chat_message(assistant_msg["role"]):
                        st.write(assistant_msg["content"])

        st.text_input(
            "Type your message here...",
            key="user_input",
            value="",
            on_change=process_input,
        )

    except Exception as e:
        st.error(f"An unexpected error occurred in the main application: {str(e)}")


if __name__ == "__main__":
    main()
|
|