# MedQA / quantum_treatment_optimizer_tool.py
# Committed by mgbam in dbccd06 (verified): "Rename tools/quantum_tool.py to quantum_treatment_optimizer_tool.py"
# File size: 9.27 kB
# /home/user/app/tools/quantum_treatment_optimizer_tool.py
from langchain_core.tools import BaseTool # Updated import path
from typing import Type, List, Dict, Any, Optional # Optional for potentially missing fields in result
from pydantic import BaseModel, Field # For input schema validation
# The application logger and metrics hook are imported BEFORE the optimizer
# import on purpose: the ImportError fallback below logs through app_logger,
# so it must already be bound when that branch runs.
from services.logger import app_logger  # Your application logger
from services.metrics import log_tool_usage  # Your metrics logger

# Assuming your actual optimizer function is in this path.
# If it's in a different location, adjust the import.
try:
    from quantum.optimizer import optimize_treatment
except ImportError:
    # Provide a mock function if the actual optimizer is not available.
    # This allows the rest of the app to run for UI/agent testing.
    app_logger.warning("Actual 'quantum.optimizer.optimize_treatment' not found. Using mock function for QuantumTreatmentOptimizerTool.")

    def optimize_treatment(patient_data: Dict[str, Any], current_treatments: List[str], conditions: List[str]) -> Dict[str, Any]:
        """Mock optimizer that returns a fixed, clearly simulated plan.

        Only the first entries of ``current_treatments`` and ``conditions`` are
        used (generic placeholders are substituted when they are empty);
        ``patient_data`` is accepted for signature compatibility but ignored.

        Returns:
            A dict with the keys ``simulated_optimization_id``,
            ``suggested_actions`` (list of str), ``primary_focus_condition``,
            ``confidence_level_simulated`` (a float in [0, 1]) and ``summary_notes``.
        """
        first_treatment = current_treatments[0] if current_treatments else 'current treatment'
        first_condition = conditions[0] if conditions else 'primary condition'
        mock_suggestions = [
            f"Consider adjusting {first_treatment} based on {first_condition}.",
            "Explore adding a complementary therapy Y.",
            "Monitor key biomarker Z closely.",
        ]
        return {
            "simulated_optimization_id": "QO-Sim-12345",
            "suggested_actions": mock_suggestions,
            "primary_focus_condition": conditions[0] if conditions else "N/A",
            "confidence_level_simulated": 0.75,
            "summary_notes": "This simulated plan aims to address the primary condition while managing current treatments. Further clinical evaluation is essential.",
        }
class QuantumOptimizerInput(BaseModel):
    """Input schema for the QuantumTreatmentOptimizerTool."""
    # NOTE: the class docstring above and every Field description below are
    # runtime values (Pydantic copies them into the model's JSON schema), so
    # treat edits to them as behavior changes. None of the fields declares a
    # default, so all three are required.

    # Free-form patient context: only "dict with str keys" is validated; the
    # example keys in the description are guidance, not an enforced schema.
    patient_data: Dict[str, Any] = Field(
        description=(
            "A dictionary containing relevant patient characteristics. "
            "Examples: {'age': 55, 'gender': 'Male', 'relevant_labs': {'creatinine': 1.2, 'hbA1c': 7.5}, "
            "'allergies': ['penicillin']}. This should be populated from the overall patient context."
        )
    )
    # Plain strings; dose/frequency formatting inside each string is not validated.
    current_treatments: List[str] = Field(
        description="A list of current medications or therapies the patient is on (e.g., ['Aspirin 81mg', 'Metformin 500mg OD'])."
    )
    # Plain strings; per the description these may be diagnoses or symptoms.
    conditions: List[str] = Field(
        description="A list of primary diagnosed conditions or symptoms to be addressed (e.g., ['Type 2 Diabetes', 'Hypertension', 'Chronic Back Pain'])."
    )
    # Optional: Add other specific parameters your optimizer might need
    # optimization_goal: Optional[str] = Field(default=None, description="Specific goal for the optimization, e.g., 'minimize side effects', 'maximize efficacy for condition X'.")
class QuantumTreatmentOptimizerTool(BaseTool):
    """LangChain tool wrapping ``optimize_treatment`` and returning LLM-readable text.

    The agent supplies ``patient_data``, ``current_treatments`` and ``conditions``
    (validated against ``QuantumOptimizerInput``). The tool calls
    ``optimize_treatment`` and renders its dict result as plain text. Failures
    are reported back to the agent as strings starting with ``"Error"`` rather
    than raised.
    """

    name: str = "quantum_treatment_optimizer"
    description: str = (
        "A specialized (simulated) tool that uses advanced algorithms to suggest optimized or alternative treatment plans "
        "based on provided patient data, current treatments, and diagnosed conditions. "
        "Use this when seeking novel therapeutic strategies, needing to optimize complex polypharmacy, "
        "or exploring options for patients with multiple comorbidities. "
        "You MUST provide detailed 'patient_data', 'current_treatments', and 'conditions'."
    )
    args_schema: Type[BaseModel] = QuantumOptimizerInput
    # return_direct: bool = False  # Usually False, so the agent can process the tool's output

    def _format_results_for_llm(self, optimization_output: Dict[str, Any]) -> str:
        """Render the dict returned by ``optimize_treatment`` as text for the LLM.

        Recognised keys (all optional): ``suggested_actions``,
        ``primary_focus_condition``, ``confidence_level_simulated``,
        ``summary_notes`` and ``simulated_optimization_id``. If none of them
        produces a line, the raw dict is echoed back instead.

        Args:
            optimization_output: The optimizer's result. Anything that is not a
                non-empty dict yields a short "no structured result" message.

        Returns:
            A multi-line summary string. Unexpected value types are shown
            verbatim rather than raising, so one odd field cannot discard the
            rest of the result.
        """
        if not optimization_output or not isinstance(optimization_output, dict):
            return "The optimizer did not return a structured result."

        summary_lines = ["Quantum Treatment Optimizer Suggestions:"]

        actions = optimization_output.get("suggested_actions")
        if actions:
            # A bare string (or other non-iterable) would otherwise be iterated
            # character by character, or raise; treat it as a single action.
            if isinstance(actions, str) or not hasattr(actions, "__iter__"):
                actions = [actions]
            summary_lines.append(" Key Suggested Actions:")
            summary_lines.extend(f" - {action}" for action in actions)

        if "primary_focus_condition" in optimization_output:
            summary_lines.append(f" Primary Focus: Addressing {optimization_output['primary_focus_condition']}.")

        if "confidence_level_simulated" in optimization_output:
            confidence = optimization_output["confidence_level_simulated"]
            try:
                # Format as a percentage, e.g. 0.75 -> "75%".
                # NOTE(review): assumes a 0-1 fraction; a 0-100 value would render as e.g. "7500%".
                confidence_text = f"{confidence:.0%}"
            except (TypeError, ValueError):
                # None / strings / other non-numeric values do not support the '%'
                # format spec; show them as-is instead of failing the whole result.
                confidence_text = str(confidence)
            summary_lines.append(f" Simulated Confidence Level: {confidence_text}")

        if "summary_notes" in optimization_output:
            summary_lines.append(f" Summary Notes: {optimization_output['summary_notes']}")

        if "simulated_optimization_id" in optimization_output:
            summary_lines.append(f" (Simulated Optimization ID: {optimization_output['simulated_optimization_id']})")

        if len(summary_lines) == 1:  # Only the initial title
            return f"The optimizer processed the request but provided no specific actionable suggestions. Raw data: {str(optimization_output)}"
        return "\n".join(summary_lines)

    def _run(self, patient_data: Dict[str, Any], current_treatments: List[str], conditions: List[str], **kwargs: Any) -> str:
        """Run the optimizer synchronously and return a text summary for the agent.

        LangChain populates the named arguments from the LLM's ``action_input``,
        validated against ``args_schema`` (``QuantumOptimizerInput``).

        Args:
            patient_data: Patient characteristics; must be non-empty.
            current_treatments: Current medications/therapies; may be empty.
            conditions: Conditions or symptoms to address; must be non-empty.
            **kwargs: Any extra arguments the caller passes; accepted but unused.

        Returns:
            The formatted optimizer result, or a string starting with ``"Error"``
            when required input is missing or the optimizer raises.
        """
        # Logging and metrics tolerate None so that a direct, unvalidated call
        # reaches the explicit "Missing: ..." message below instead of crashing here.
        app_logger.info(
            f"Quantum Optimizer Tool called. Patient Data Keys: {list((patient_data or {}).keys())}, "
            f"Treatments: {current_treatments}, Conditions: {conditions}"
        )
        log_tool_usage(
            self.name,
            {"conditions_count": len(conditions or []), "treatments_count": len(current_treatments or [])},
        )

        # Business-level check on top of the schema: an empty dict/list passes
        # type validation but leaves the optimizer nothing to work with.
        missing_info = [
            label
            for label, value in (("'patient_data'", patient_data), ("'conditions'", conditions))
            if not value
        ]
        if missing_info:
            return f"Error: Insufficient information provided for optimization. Missing: {', '.join(missing_info)}. Please provide comprehensive details."

        try:
            optimization_output: Dict[str, Any] = optimize_treatment(
                patient_data=patient_data,
                current_treatments=current_treatments or [],  # schema declares a list
                conditions=conditions,
            )
            app_logger.info(f"Quantum optimizer raw output: {str(optimization_output)[:500]}...")  # Log snippet

            # Format the potentially complex result into a string for the LLM.
            formatted_result = self._format_results_for_llm(optimization_output)
            app_logger.info(f"Quantum optimizer formatted result for LLM: {formatted_result}")
            return formatted_result
        except ImportError as ie:
            # The optimizer itself may import optional dependencies lazily.
            app_logger.error(f"ImportError in QuantumTreatmentOptimizerTool (quantum.optimizer likely missing): {ie}", exc_info=True)
            return "Error: The core optimization module is currently unavailable."
        except Exception as e:
            # Tool boundary: report the failure to the agent instead of propagating it.
            app_logger.error(f"Unexpected error during quantum optimization process: {e}", exc_info=True)
            return f"Error encountered during the optimization process: {str(e)}. Please ensure input data is correctly formatted."

    async def _arun(self, patient_data: Dict[str, Any], current_treatments: List[str], conditions: List[str], **kwargs: Any) -> str:
        """Async entry point; logs the call and then delegates to the synchronous ``_run``.

        NOTE(review): ``_run`` executes directly on the event-loop thread, so a slow
        ``optimize_treatment`` blocks the loop for its full duration. Consider
        ``await asyncio.to_thread(self._run, ...)`` once it is confirmed that
        ``app_logger`` and ``log_tool_usage`` are safe to call from a worker
        thread (the original code kept this synchronous for Streamlit).
        """
        app_logger.info(
            f"Quantum Optimizer Tool (async) called. Patient Data Keys: {list((patient_data or {}).keys())}, "
            f"Treatments: {current_treatments}, Conditions: {conditions}"
        )
        return self._run(patient_data, current_treatments, conditions, **kwargs)