# /home/user/app/agent.py
from langchain_openai import ChatOpenAI
from langchain.agents import AgentExecutor, create_openai_functions_agent
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import AIMessage, HumanMessage  # SystemMessage can be implicit or explicit

# --- Import your defined tools FROM THE 'tools' PACKAGE ---
from tools import (
    BioPortalLookupTool,
    UMLSLookupTool,
    QuantumTreatmentOptimizerTool,
    # QuantumOptimizerInput,  # Schemas are used by the tools themselves
    # GeminiTool,  # Not used in this OpenAI-centric agent
)
from config.settings import settings
from services.logger import app_logger
# --- Initialize LLM (OpenAI) ---
llm = None  # Initialize to None for robust error handling if init fails
try:
    if not settings.OPENAI_API_KEY:
        app_logger.error("CRITICAL: OPENAI_API_KEY not found in settings. Agent cannot initialize.")
        raise ValueError("OpenAI API Key not configured. Please set it in Hugging Face Space secrets as OPENAI_API_KEY.")
    llm = ChatOpenAI(
        model_name="gpt-4-turbo-preview",  # More capable model for function calling
        # model_name="gpt-3.5-turbo-0125",  # More cost-effective alternative
        temperature=0.1,  # Lower for more deterministic tool use and function calls
        openai_api_key=settings.OPENAI_API_KEY,
    )
    app_logger.info(f"ChatOpenAI ({llm.model_name}) initialized successfully for agent.")
except Exception as e:
    detailed_error_message = str(e)
    user_facing_error = f"OpenAI LLM initialization failed: {detailed_error_message}. Check the API key and model name."
    if "api_key" in detailed_error_message.lower() or "authenticate" in detailed_error_message.lower():
        user_facing_error = (
            "OpenAI LLM initialization failed: API key issue. Ensure OPENAI_API_KEY is correctly set "
            "in Hugging Face Secrets and is valid."
        )
        app_logger.error(user_facing_error + f" Original: {detailed_error_message}", exc_info=False)
    else:
        app_logger.error(user_facing_error, exc_info=True)
    raise ValueError(user_facing_error)  # Propagate the error to stop further agent setup
# --- Initialize Tools List ---
tools_list = [
    UMLSLookupTool(),
    BioPortalLookupTool(),
    QuantumTreatmentOptimizerTool(),
]
app_logger.info(f"Agent tools initialized: {[tool.name for tool in tools_list]}")
# --- Agent Prompt (for OpenAI Functions Agent - Simplified System Prompt) ---
# create_openai_functions_agent makes tool descriptions available to the LLM
# implicitly via the function-calling mechanism. Explicitly listing {tools} and
# {tool_names} in the system prompt string would be redundant and can conflict
# with how this agent type works. We still provide overall instructions and the
# {patient_context} placeholder.
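# For reference, the JSON function schema the model actually receives for each
# tool can be inspected with LangChain's converter. A debug-only sketch
# (commented out; assumes a langchain_core version that exposes
# convert_to_openai_function):
#
#   # from langchain_core.utils.function_calling import convert_to_openai_function
#   # for t in tools_list:
#   #     app_logger.debug(f"Function schema for {t.name}: {convert_to_openai_function(t)}")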
OPENAI_SYSTEM_PROMPT_TEXT_SIMPLIFIED = (
    "You are 'Quantum Health Navigator', an AI assistant for healthcare professionals. "
    "Your goal is to assist with medical information lookup, treatment optimization queries, and general medical Q&A. "
    "You have access to a set of specialized tools. Use them when a user's query can be best answered by one of them, based on their descriptions.\n"
    "Disclaimers: Always state that you are for informational support and not a substitute for clinical judgment. Do not provide direct medical advice for specific patient cases without using the 'quantum_treatment_optimizer' tool if relevant.\n"
    "Patient context for this session (if provided by the user earlier): {patient_context}\n"  # This variable is passed in at invoke time
    "When using the 'quantum_treatment_optimizer' tool, populate its 'patient_data' argument from the available {patient_context}.\n"
    "For `bioportal_lookup`, if the user doesn't specify an ontology, you may ask or default to 'SNOMEDCT_US'.\n"
    "Always be clear and concise. Cite tools when their output forms a key part of your answer."
)
# The ChatPromptTemplate defines the sequence of messages sent to the LLM.
# `create_openai_functions_agent` expects specific placeholders.
prompt = ChatPromptTemplate.from_messages([
    ("system", OPENAI_SYSTEM_PROMPT_TEXT_SIMPLIFIED),       # System instructions; expects {patient_context}
    MessagesPlaceholder(variable_name="chat_history"),      # Past Human/AI messages
    ("human", "{input}"),                                   # The current user query
    MessagesPlaceholder(variable_name="agent_scratchpad"),  # The agent's internal work (function calls/responses)
])
app_logger.info("Agent prompt template (simplified for OpenAI Functions) created.")
# --- Create Agent ---
if llm is None:  # Defensive check; should have been caught by the earlier raise
    app_logger.critical("LLM object is None at agent creation (OpenAI). Application cannot proceed.")
    raise SystemExit("Agent LLM failed to initialize. Application cannot start.")
try:
    # create_openai_functions_agent uses the tools' Pydantic schemas to define
    # the "functions" the OpenAI model can call.
    agent = create_openai_functions_agent(llm=llm, tools=tools_list, prompt=prompt)
    app_logger.info("OpenAI Functions agent created successfully.")
except Exception as e:
    app_logger.error(f"Failed to create OpenAI Functions agent: {e}", exc_info=True)
    # An "Input to ChatPromptTemplate is missing variables" error would surface here
    # if the prompt still expected variables not supplied by the agent constructor or invoke.
    raise ValueError(f"OpenAI agent creation failed: {e}")
# --- Create Agent Executor ---
agent_executor = AgentExecutor(
    agent=agent,
    tools=tools_list,
    verbose=True,                # Essential for debugging tool usage and agent thoughts
    handle_parsing_errors=True,  # Tries to gracefully handle LLM output-parsing issues
    max_iterations=7,            # Prevent runaway agent loops
    # return_intermediate_steps=True,  # Set to True to get detailed thought/action steps in the response
)
app_logger.info("AgentExecutor with OpenAI agent created successfully.")
# --- Getter Function for Streamlit App ---
_agent_executor_instance = agent_executor  # Store the successfully initialized executor


def get_agent_executor():
    """
    Return the configured OpenAI agent executor.
    The executor is initialized when this module is first imported.
    """
    global _agent_executor_instance
    if _agent_executor_instance is None:
        # This indicates a failure during the initial module load (LLM or agent creation).
        app_logger.critical("CRITICAL: Agent executor is None when get_agent_executor is called (OpenAI). Initialization likely failed.")
        raise RuntimeError("Agent executor (OpenAI) was not properly initialized. Check application startup logs for errors (e.g., API key issues, prompt errors).")
    # Final check for the API key, though LLM initialization is the primary guard.
    if not settings.OPENAI_API_KEY:
        app_logger.error("OpenAI API Key is missing at get_agent_executor call. Agent will fail.")
        raise ValueError("OpenAI API Key not configured.")
    return _agent_executor_instance
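# Typical caller-side usage from the Streamlit app, sketched here for reference
# (the session-state keys and variable names are assumptions about the app
# layer, not taken from this file):
#
#   # executor = get_agent_executor()
#   # result = executor.invoke({
#   #     "input": user_query,
#   #     "chat_history": st.session_state.get("chat_history", []),
#   #     "patient_context": patient_context_str,
#   # })
#   # st.write(result["output"])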
# --- Example Usage (for local testing of this agent.py file) ---
if __name__ == "__main__":
if not settings.OPENAI_API_KEY:
print("🚨 Please set your OPENAI_API_KEY in .env file or as an environment variable to run the test.")
else:
print("\nπŸš€ Quantum Health Navigator (OpenAI Agent Test Console) πŸš€")
print("-----------------------------------------------------------")
print("Type 'exit' or 'quit' to stop.")
print("Example topics: medical definitions, treatment optimization (will use simulated patient context).")
print("-" * 59)
try:
test_executor = get_agent_executor() # Get the executor
except ValueError as e_init: # Catch errors from get_agent_executor or LLM/agent init
print(f"⚠️ Agent initialization failed during test startup: {e_init}")
print("Ensure your API key is correctly configured and prompt variables are set.")
exit() # Exit if agent can't be initialized
current_chat_history_for_test_run = [] # List of HumanMessage, AIMessage
# Simulated patient context for testing the {patient_context} variable
test_patient_context_summary_str = (
"Age: 70; Gender: Male; Chief Complaint: Shortness of breath on exertion; "
"Key Medical History: COPD, Atrial Fibrillation; "
"Current Medications: Tiotropium inhaler, Apixaban 5mg BID; Allergies: Penicillin."
)
print(f"ℹ️ Simulated Patient Context for this test run: {test_patient_context_summary_str}\n")
while True:
user_input_str = input("πŸ‘€ You: ").strip()
if user_input_str.lower() in ["exit", "quit"]:
print("πŸ‘‹ Exiting test console.")
break
if not user_input_str: # Skip empty input
continue
try:
app_logger.info(f"__main__ test (OpenAI): Invoking with input: '{user_input_str}'")
# These are the keys expected by the ChatPromptTemplate and agent:
# "input", "chat_history", and "patient_context" (because it's in our system prompt)
response_dict = test_executor.invoke({
"input": user_input_str,
"chat_history": current_chat_history_for_test_run,
"patient_context": test_patient_context_summary_str
})
ai_output_str = response_dict.get('output', "Agent did not produce an 'output' key.")
print(f"πŸ€– Agent: {ai_output_str}")
# Update history for the next turn
current_chat_history_for_test_run.append(HumanMessage(content=user_input_str))
current_chat_history_for_test_run.append(AIMessage(content=ai_output_str))
# Optional: Limit history length to prevent overly long contexts
if len(current_chat_history_for_test_run) > 10: # Keep last 5 pairs
current_chat_history_for_test_run = current_chat_history_for_test_run[-10:]
except Exception as e_invoke:
print(f"⚠️ Error during agent invocation: {type(e_invoke).__name__} - {e_invoke}")
app_logger.error(f"Error in __main__ OpenAI agent test invocation: {e_invoke}", exc_info=True)