"""
SynapseAI Clinical Decision Support System

Expert-Level Implementation with Safety-Centric Architecture.

A LangGraph state machine alternates between an LLM reasoning node, a tool
execution node and a safety-review node.

NOTE(review): several collaborators used below are neither defined nor
imported in this module and must be supplied for the workflow to run:
ClinicalPrompts, ClinicalValidator, ClinicalLLM, ClinicalErrorHandler,
ClinicalToolExecutor, ClinicalSafetyProtocols, ClinicalDecisionRouter,
ClinicalDocumentation, CriticalClinicalError, CriticalSafetyViolation.
"""

import os
import re
import json
import logging
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Optional,
    Sequence,
    Tuple,
    TypedDict,
    Union,
)
from functools import lru_cache
from enum import Enum

import requests
from pydantic import BaseModel, Field, ValidationError
from langchain_groq import ChatGroq
from langchain_core.messages import (
    AIMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
)
from langchain_core.tools import BaseTool, tool
from langgraph.graph import StateGraph, END
# NOTE(review): ToolExecutor is unused in this module and is deprecated in
# newer langgraph releases (in favour of ToolNode); confirm the pinned
# langgraph version still exports it.
from langgraph.prebuilt import ToolExecutor


# ── Type Definitions ──────────────────────────────────────────────────
class ClinicalPriority(str, Enum):
    """Urgency levels accepted for clinical orders."""

    STAT = "STAT"
    URGENT = "Urgent"
    ROUTINE = "Routine"


class ClinicalState(TypedDict):
    """State dictionary passed between every LangGraph node.

    ``messages``, ``safety_warnings`` and ``execution_log`` are treated as
    append-only by ``ClinicalStateManager.propagate_state``.
    """

    messages: List[Union[HumanMessage, SystemMessage, AIMessage, ToolMessage]]
    patient_data: Dict[str, Any]
    safety_warnings: List[Dict[str, str]]
    workflow_metadata: Dict[str, Union[int, float, bool]]
    execution_log: List[Dict[str, str]]


# ── Configuration ─────────────────────────────────────────────────────
class ClinicalConfig:
    """Static limits and safety thresholds for the consultation workflow."""

    MAX_ITERATIONS = 6  # Maximum number of LLM reasoning turns per consultation
    RECURSION_BUFFER = 2  # Extra LangGraph steps allowed beyond the turn budget
    DRUG_CHECK_REQUIRED = True  # Every prescription needs an interaction check
    SAFETY_PARAMETERS = {
        'max_bp_systolic': 180,
        'min_bp_systolic': 90,
        'max_hr': 120,
        'min_spo2': 92,
    }


# ── Core Clinical Tools ──────────────────────────────────────────────
class ClinicalToolkit:
    """Namespace for the LLM-callable clinical tools and their schemas."""

    @staticmethod
    def get_essential_tools() -> List[BaseTool]:
        """Return the clinical tools exposed to the LLM.

        NOTE(review): ``prescribe_medication``, ``check_drug_interactions``
        and ``flag_clinical_risk`` are not defined on this class, so this
        call raises AttributeError until they are implemented.
        """
        return [
            ClinicalToolkit.order_lab_test,
            ClinicalToolkit.prescribe_medication,
            ClinicalToolkit.check_drug_interactions,
            ClinicalToolkit.flag_clinical_risk,
        ]

    class LabOrderInput(BaseModel):
        """Argument schema for the ``order_lab_test`` tool."""

        # Letters, digits, whitespace and hyphens only.
        test_name: str = Field(..., pattern=r"^[A-Za-z0-9\s-]+$")
        # A justification of at least 20 characters is required.
        rationale: str = Field(..., min_length=20)
        priority: ClinicalPriority = ClinicalPriority.ROUTINE

    # The docstring below is the tool description sent to the LLM.
    @tool("order_lab_test", args_schema=LabOrderInput)
    def order_lab_test(
        test_name: str, rationale: str, priority: ClinicalPriority
    ) -> Dict[str, Any]:
        """Standardized lab ordering with clinical validation"""
        # TODO: submit the order to the laboratory system. For now the
        # validated order is echoed back so the result is JSON-serialisable.
        return {
            "status": "ordered",
            "details": {
                "test_name": test_name,
                "rationale": rationale,
                "priority": ClinicalPriority(priority).value,
            },
        }

    class PrescriptionSafetyCheck(BaseModel):
        """Validated representation of a prescription request."""

        medication: str
        # Presumably an RxNorm concept identifier; None when not supplied.
        rxcui: Optional[str] = None
        contraindications: List[str]
        # Additional safety fields...

        @classmethod
        def validate_prescription(
            cls, rx_data: Dict
        ) -> "ClinicalToolkit.PrescriptionSafetyCheck":
            """Pharmaceutical safety validation pipeline.

            Currently performs schema validation of ``rx_data`` only.

            Args:
                rx_data: Raw prescription fields to validate.

            Returns:
                A validated ``PrescriptionSafetyCheck`` instance.

            Raises:
                ValidationError: If ``rx_data`` does not match the schema.
            """
            # TODO: add pharmacological checks beyond schema validation.
            return cls.model_validate(rx_data)


# ── State Management Engine ──────────────────────────────────────────
class ClinicalStateManager:
    """Creates and evolves ``ClinicalState`` dictionaries."""

    # Channels whose updates are appended to the previous value instead of
    # replacing it: conversation history, warnings and the audit log.
    _APPEND_ONLY_FIELDS = ("messages", "safety_warnings", "execution_log")

    @staticmethod
    def initialize_state(patient_data: Dict) -> ClinicalState:
        """Build the initial state: system prompt, opening human turn,
        sanitized patient data and zeroed workflow counters."""
        return {
            "messages": [
                SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT),
                HumanMessage(content="Initiate clinical consultation"),
            ],
            "patient_data": ClinicalValidator.sanitize_patient_data(patient_data),
            "safety_warnings": [],
            "workflow_metadata": {
                "iterations": 0,
                "active_alerts": 0,
                "safety_override": False,
            },
            "execution_log": [],
        }

    @staticmethod
    def propagate_state(previous: ClinicalState, updates: Dict) -> ClinicalState:
        """Return the successor state after applying a node's ``updates``.

        * Fields absent from ``updates`` are carried over unchanged.
        * ``messages``, ``safety_warnings`` and ``execution_log`` entries in
          ``updates`` are appended to the previous lists, so the system
          prompt and earlier turns remain visible to the LLM.
        * ``workflow_metadata`` is merged key by key, so a partial update
          such as ``{"iterations": n}`` keeps ``active_alerts`` and
          ``safety_override``.
        * ``patient_data`` is kept from ``previous`` unless ``updates``
          sets it explicitly.

        The merged state is passed through ``ClinicalValidator.validate_state``.
        """
        merged: Dict[str, Any] = {**previous, **updates}
        for name in ClinicalStateManager._APPEND_ONLY_FIELDS:
            if name in updates:
                merged[name] = [*previous.get(name, []), *updates[name]]
        # Merge metadata *after* spreading ``updates`` so a partial metadata
        # update cannot wipe the other counters.
        merged["workflow_metadata"] = {
            **previous["workflow_metadata"],
            **updates.get("workflow_metadata", {}),
        }
        return ClinicalValidator.validate_state(merged)


# ── Clinical Workflow Nodes ─────────────────────────────────────────
class ClinicalWorkflowNodes:
    """LangGraph node implementations."""

    @staticmethod
    def agent_node(state: ClinicalState) -> ClinicalState:
        """Run one LLM reasoning turn and increment the iteration counter.

        If the termination criteria are met after the turn, the state is
        handed to ``apply_termination_protocol``. A ``CriticalClinicalError``
        is routed to ``ClinicalErrorHandler.handle_critical_error``.
        """
        ClinicalValidator.check_iteration_limit(state)
        try:
            response = ClinicalLLM.invoke(state)
            new_state = ClinicalStateManager.propagate_state(state, {
                "messages": [response],
                "workflow_metadata": {
                    "iterations": state["workflow_metadata"]["iterations"] + 1,
                },
            })
            if ClinicalTerminationCriteria.should_terminate(new_state):
                return ClinicalWorkflowNodes.apply_termination_protocol(new_state)
            return new_state
        except CriticalClinicalError as e:
            return ClinicalErrorHandler.handle_critical_error(state, e)

    @staticmethod
    def tool_node(state: ClinicalState) -> ClinicalState:
        """Execute the tool calls requested by the last message.

        Safety pre-checks run before any tool executes. Each call is run via
        ``ClinicalToolExecutor.execute_with_audit``, and one ``ToolMessage``
        with the matching ``tool_call_id`` is appended per call.

        Assumes the last message is an AIMessage carrying ``tool_calls``
        (presumably guaranteed by the router — confirm).
        """
        ClinicalSafetyEngine.pre_execution_checks(state)

        tool_results = []
        tool_messages = []
        for tool_call in state["messages"][-1].tool_calls:
            result = ClinicalToolExecutor.execute_with_audit(tool_call)
            tool_results.append(result)
            # default=str is deliberate: audit results may hold enums or
            # other non-JSON values, and the reply must still be produced.
            tool_messages.append(ToolMessage(
                content=json.dumps(result, default=str),
                tool_call_id=tool_call["id"],
                name=tool_call["name"],
            ))
            if result['category'] == "DRUG_ORDER":
                ClinicalSafetyEngine.post_drug_order_checks(result)

        return ClinicalStateManager.propagate_state(state, {
            "messages": tool_messages,
            "safety_warnings": ClinicalSafetyEngine.aggregate_warnings(tool_results),
        })


# ── Safety Subsystems ───────────────────────────────────────────────
class ClinicalSafetyEngine:
    """Hard safety gates applied around tool execution.

    NOTE(review): ``post_drug_order_checks`` and ``aggregate_warnings`` are
    called by ``ClinicalWorkflowNodes.tool_node`` but are not defined here.
    """

    @staticmethod
    def pre_execution_checks(state: ClinicalState) -> None:
        """Validate the pending tool calls before any of them executes.

        Raises:
            CriticalSafetyViolation: Propagated from
                ``enforce_prescription_rules``.
        """
        last_message = state["messages"][-1]
        ClinicalSafetyEngine.enforce_prescription_rules(
            getattr(last_message, "tool_calls", None) or []
        )

    @staticmethod
    def enforce_prescription_rules(tool_calls: List) -> None:
        """Require a same-batch interaction check for every prescription.

        Args:
            tool_calls: LangChain ``ToolCall`` dicts (keys ``name``, ``args``).

        Raises:
            CriticalSafetyViolation: If ``DRUG_CHECK_REQUIRED`` is set and a
                ``prescribe_medication`` call has no ``check_drug_interactions``
                call for the same medication (or names no medication at all).
        """
        if not ClinicalConfig.DRUG_CHECK_REQUIRED:
            return

        checked = {
            tc["args"].get("medication")
            for tc in tool_calls
            if tc["name"] == "check_drug_interactions"
        }
        # A check that names no medication must not satisfy an order that
        # also names none; fail closed.
        checked.discard(None)

        for rx in tool_calls:
            if rx["name"] != "prescribe_medication":
                continue
            medication = rx["args"].get("medication")
            if medication not in checked:
                raise CriticalSafetyViolation(
                    f"Missing interaction check for {medication}"
                )


class ClinicalTerminationCriteria:
    """Decides when the consultation loop should stop."""

    @staticmethod
    def _message_text(message: Any) -> str:
        """Return a message's content as plain text (string or content-part list)."""
        content = getattr(message, "content", "")
        if isinstance(content, str):
            return content
        parts = []
        for part in content or []:
            if isinstance(part, str):
                parts.append(part)
            elif isinstance(part, dict):
                parts.append(str(part.get("text", "")))
        return " ".join(parts)

    @staticmethod
    def should_terminate(state: ClinicalState) -> bool:
        """Return True when the turn budget is spent, more than three alerts
        are active, or the last message asks to terminate the consultation."""
        metadata = state["workflow_metadata"]
        if metadata.get("iterations", 0) >= ClinicalConfig.MAX_ITERATIONS:
            return True
        if metadata.get("active_alerts", 0) > 3:
            return True
        messages = state["messages"]
        if not messages:
            return False
        last_text = ClinicalTerminationCriteria._message_text(messages[-1])
        return "terminate consultation" in last_text.lower()


# ── Execution Framework ─────────────────────────────────────────────
class ClinicalWorkflow:
    """Builds the consultation graph and runs it for a patient case."""

    def __init__(self):
        self.workflow = self._build_workflow()
        self.toolkit = ClinicalToolkit.get_essential_tools()
        # NOTE(review): self.llm and self.toolkit are not used by the graph
        # nodes in this module (agent_node calls ClinicalLLM.invoke); confirm
        # how they are meant to be wired in.
        self.llm = ChatGroq(model_name="llama3-70b-8192", temperature=0.1)

    def _build_workflow(self) -> Any:
        """Build and compile the consultation graph.

        clinical_reasoning routes to tool_execution, safety_review or END;
        both tool_execution and safety_review loop back to
        clinical_reasoning. Returns the compiled graph
        (``StateGraph.compile()``), not the builder.
        """
        workflow = StateGraph(ClinicalState)
        workflow.add_node("clinical_reasoning", ClinicalWorkflowNodes.agent_node)
        workflow.add_node("tool_execution", ClinicalWorkflowNodes.tool_node)
        workflow.add_node("safety_review", ClinicalSafetyProtocols.review_node)

        workflow.set_entry_point("clinical_reasoning")
        workflow.add_conditional_edges(
            "clinical_reasoning",
            ClinicalDecisionRouter.route_agent_output,
            {
                "require_tools": "tool_execution",
                "need_safety_review": "safety_review",
                "final_output": END,
            },
        )
        workflow.add_edge("tool_execution", "clinical_reasoning")
        workflow.add_edge("safety_review", "clinical_reasoning")
        return workflow.compile()

    def execute_consultation(self, patient_data: Dict) -> ClinicalState:
        """Run the full consultation graph for ``patient_data``; return the final state."""
        initial_state = ClinicalStateManager.initialize_state(patient_data)
        # LangGraph counts every node execution toward recursion_limit. Each
        # reasoning turn after the first is preceded by a tool_execution or
        # safety_review step, so MAX_ITERATIONS turns need up to
        # 2 * MAX_ITERATIONS - 1 steps. Without the factor of two the graph
        # would abort with GraphRecursionError before the turn cap is reached.
        recursion_limit = (
            2 * ClinicalConfig.MAX_ITERATIONS + ClinicalConfig.RECURSION_BUFFER
        )
        return self.workflow.invoke(
            initial_state,
            config={"recursion_limit": recursion_limit},
        )


# ── Usage Example ───────────────────────────────────────────────────
if __name__ == "__main__":
    # Initialize clinical environment
    ClinicalValidator.validate_environment()

    # Sample patient scenario
    complex_case = {
        "demographics": {"age": 68, "sex": "F", "weight_kg": 82},
        "presenting_complaint": "Chest pain radiating to left arm",
        "medical_history": ["HTN", "Type 2 DM", "HLD"],
        "current_meds": ["Atenolol 50mg daily", "Simvastatin 40mg HS"],
    }

    # Execute clinical workflow
    workflow = ClinicalWorkflow()
    result = workflow.execute_consultation(complex_case)

    # Generate clinical summary
    final_report = ClinicalDocumentation.generate_report(result)
    print(json.dumps(final_report, indent=2))