File size: 10,518 Bytes
a1dbeb2
 
 
 
 
f1cd1c6
86911ce
 
 
a1dbeb2
 
86911ce
a1dbeb2
86911ce
 
a1dbeb2
86911ce
a1dbeb2
 
 
86911ce
a1dbeb2
86911ce
a1dbeb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86911ce
 
a1dbeb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86911ce
 
a1dbeb2
 
 
 
 
 
 
e7e593a
a1dbeb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e7e593a
a1dbeb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
"""
SynapseAI Clinical Decision Support System
Expert-Level Implementation with Safety-Centric Architecture
"""

import json
import logging
import os
import re
from enum import Enum
from functools import lru_cache
from typing import (Any, Callable, Dict, List, Optional, Sequence, Tuple,
                    TypedDict, Union)

import requests
from pydantic import BaseModel, Field, ValidationError
from langchain_groq import ChatGroq
from langchain_core.messages import (AIMessage, HumanMessage, SystemMessage,
                                     ToolMessage)
from langchain_core.tools import BaseTool, tool
from langgraph.graph import END, StateGraph
# NOTE(review): ToolExecutor is unused in this module and is deprecated in
# recent langgraph releases; confirm it is still importable in your version.
from langgraph.prebuilt import ToolExecutor

# ── Type Definitions ──────────────────────────────────────────────────
class ClinicalPriority(str, Enum):
    """Urgency level attached to a clinical order.

    Mixes in ``str`` so every member compares equal to (and JSON-encodes as)
    its plain string value, e.g. ``ClinicalPriority.URGENT == "Urgent"``.
    """

    STAT = "STAT"
    URGENT = "Urgent"
    ROUTINE = "Routine"  # default priority of LabOrderInput

# Shape of the LangGraph state dict passed between workflow nodes.
# Declared with the functional TypedDict syntax; every key is required.
ClinicalState = TypedDict(
    "ClinicalState",
    {
        # Chat history fed to / produced by the model and tools.
        "messages": List[Union[HumanMessage, SystemMessage, AIMessage, ToolMessage]],
        # Sanitized patient record (see ClinicalStateManager.initialize_state).
        "patient_data": Dict[str, Any],
        "safety_warnings": List[Dict[str, str]],
        # Counters/flags such as "iterations", "active_alerts", "safety_override".
        "workflow_metadata": Dict[str, Union[int, float, bool]],
        "execution_log": List[Dict[str, str]],
    },
)

# ── Configuration ─────────────────────────────────────────────────────
class ClinicalConfig:
    """Static tuning constants for the clinical workflow.

    ``MAX_ITERATIONS`` caps agent reasoning turns (read by
    ``ClinicalTerminationCriteria``); ``RECURSION_BUFFER`` is extra headroom
    added when computing LangGraph's ``recursion_limit``;
    ``DRUG_CHECK_REQUIRED`` switches the interaction-check gate in
    ``ClinicalSafetyEngine`` on or off.
    """

    MAX_ITERATIONS: int = 6
    RECURSION_BUFFER: int = 2
    DRUG_CHECK_REQUIRED: bool = True

    # Vital-sign thresholds. Units presumably mmHg / beats per minute /
    # percent SpO2 — TODO confirm; not referenced in the code shown here.
    SAFETY_PARAMETERS: Dict[str, int] = dict(
        max_bp_systolic=180,
        min_bp_systolic=90,
        max_hr=120,
        min_spo2=92,
    )

# ── Core Clinical Tools ──────────────────────────────────────────────
class ClinicalToolkit:
    """Namespace holding the LangChain tools exposed to the clinical agent.

    NOTE(review): ``get_essential_tools`` references ``prescribe_medication``,
    ``check_drug_interactions`` and ``flag_clinical_risk``, which are not
    defined on this class in the code shown; until they are added, calling it
    raises ``AttributeError``.
    """

    @staticmethod
    def get_essential_tools() -> List[BaseTool]:
        """Return the agent's tool objects in a fixed order."""
        return [
            ClinicalToolkit.order_lab_test,
            ClinicalToolkit.prescribe_medication,
            ClinicalToolkit.check_drug_interactions,
            ClinicalToolkit.flag_clinical_risk
        ]

    # Argument schema for ``order_lab_test``. Kept free of a class docstring
    # because pydantic would surface it in the generated JSON schema.
    class LabOrderInput(BaseModel):
        # Letters, digits, whitespace and hyphens only.
        test_name: str = Field(..., pattern=r"^[A-Za-z0-9\s-]+$")
        # Free-text justification of at least 20 characters.
        rationale: str = Field(..., min_length=20)
        priority: ClinicalPriority = ClinicalPriority.ROUTINE

    @tool("order_lab_test", args_schema=LabOrderInput)
    def order_lab_test(test_name: str, rationale: str,
                       priority: ClinicalPriority = ClinicalPriority.ROUTINE
                       ) -> Dict[str, Any]:
        """Standardized lab ordering with clinical validation"""
        # The docstring above is the tool description @tool sends to the
        # model, so it is intentionally left unchanged.
        #
        # ``priority`` needs a Python default: LangChain forwards only the
        # arguments the model actually supplied, so an omitted priority would
        # otherwise raise TypeError instead of falling back to ROUTINE.
        #
        # NOTE(review): placeholder — no order is placed yet; the validated
        # request is echoed back as plain, JSON-serializable data.
        return {
            "status": "ordered",
            "details": {
                "test_name": test_name,
                "rationale": rationale,
                # Accept either the enum member or its raw string value.
                "priority": ClinicalPriority(priority).value,
            },
        }

    # Validated prescription record returned by ``validate_prescription``.
    class PrescriptionSafetyCheck(BaseModel):
        medication: str
        # Under pydantic v2 this key is required but may be None.
        rxcui: Optional[str]
        contraindications: List[str]
        # Additional safety fields...

    @classmethod
    def validate_prescription(cls, rx_data: Dict) -> PrescriptionSafetyCheck:
        """Validate raw prescription data against ``PrescriptionSafetyCheck``.

        Args:
            rx_data: Mapping with at least ``medication``, ``rxcui`` and
                ``contraindications``.

        Returns:
            The validated ``PrescriptionSafetyCheck`` instance.

        Raises:
            pydantic.ValidationError: if ``rx_data`` does not match the schema.
        """
        # The nested model must be reached through ``cls``: class-body names
        # are not visible inside method bodies.
        # NOTE(review): further pharmaceutical checks belong here.
        return cls.PrescriptionSafetyCheck.model_validate(rx_data)

# ── State Management Engine ──────────────────────────────────────────
class ClinicalStateManager:
    """Creates and advances ``ClinicalState`` dictionaries."""

    @staticmethod
    def initialize_state(patient_data: Dict) -> ClinicalState:
        """Build the initial graph state for a consultation.

        Args:
            patient_data: Raw patient record; it is passed through
                ``ClinicalValidator.sanitize_patient_data`` before storage.

        Returns:
            A state holding the system prompt and an opening human message,
            zeroed workflow counters, and empty warning/log lists.
        """
        return {
            "messages": [
                SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT),
                HumanMessage(content="Initiate clinical consultation")
            ],
            "patient_data": ClinicalValidator.sanitize_patient_data(patient_data),
            "safety_warnings": [],
            "workflow_metadata": {
                "iterations": 0,
                "active_alerts": 0,
                "safety_override": False
            },
            "execution_log": []
        }

    @staticmethod
    def propagate_state(previous: ClinicalState,
                        updates: Dict) -> ClinicalState:
        """Merge a node's partial ``updates`` into ``previous``.

        Merge rules:
          * ``messages`` in ``updates`` are appended to the existing history.
            ``ClinicalState`` declares no reducer, so LangGraph stores
            whatever list a node returns; returning only the new messages
            would discard the conversation so far.
          * ``workflow_metadata`` is shallow-merged, keys in ``updates``
            winning, so a node may send just the counters it changed.
          * Every other key in ``updates`` replaces the previous value; keys
            absent from ``updates`` are carried over unchanged.

        Returns:
            The merged state after ``ClinicalValidator.validate_state``.
        """
        merged_metadata = {
            **previous['workflow_metadata'],
            **updates.get('workflow_metadata', {})
        }
        merged_messages = [
            *previous.get('messages', []),
            *updates.get('messages', [])
        ]
        # The merged fields are applied last so the raw partial values in
        # ``updates`` cannot overwrite them.
        return ClinicalValidator.validate_state({
            **previous,
            **updates,
            'messages': merged_messages,
            'workflow_metadata': merged_metadata,
        })

# ── Clinical Workflow Nodes ─────────────────────────────────────────
class ClinicalWorkflowNodes:
    """LangGraph node callables for the clinical workflow.

    NOTE(review): ``ClinicalLLM``, ``ClinicalErrorHandler``,
    ``ClinicalToolExecutor``, ``CriticalClinicalError``,
    ``apply_termination_protocol`` and the ``ClinicalSafetyEngine`` methods
    ``pre_execution_checks`` / ``post_drug_order_checks`` /
    ``aggregate_warnings`` are not defined in the code shown here.
    """

    @staticmethod
    def agent_node(state: ClinicalState) -> ClinicalState:
        """Run one model reasoning turn and record its response.

        Increments ``workflow_metadata["iterations"]``. If
        ``ClinicalTerminationCriteria.should_terminate`` holds for the new
        state, the result of ``apply_termination_protocol`` is returned
        instead. ``CriticalClinicalError`` is delegated to
        ``ClinicalErrorHandler.handle_critical_error`` with the incoming state.
        """
        ClinicalValidator.check_iteration_limit(state)

        try:
            response = ClinicalLLM.invoke(state)
            metadata = state["workflow_metadata"]
            new_state = ClinicalStateManager.propagate_state(state, {
                "messages": [response],
                # Send the complete metadata, not only the changed counter,
                # so flags like "active_alerts" survive however the update is
                # merged (should_terminate reads them next).
                "workflow_metadata": {
                    **metadata,
                    "iterations": metadata["iterations"] + 1,
                },
            })

            if ClinicalTerminationCriteria.should_terminate(new_state):
                return ClinicalWorkflowNodes.apply_termination_protocol(new_state)

            return new_state
        except CriticalClinicalError as e:
            return ClinicalErrorHandler.handle_critical_error(state, e)

    @staticmethod
    def tool_node(state: ClinicalState) -> ClinicalState:
        """Execute every tool call requested by the latest message.

        Each call runs through ``ClinicalToolExecutor.execute_with_audit``;
        results whose ``category`` is ``"DRUG_ORDER"`` additionally go through
        ``ClinicalSafetyEngine.post_drug_order_checks``. One ``ToolMessage`` per
        call is emitted, linked to its request by ``tool_call_id`` (chat APIs
        reject a tool-calling turn without a matching response per call).
        """
        ClinicalSafetyEngine.pre_execution_checks(state)

        # Only AI messages carry ``tool_calls``; anything else yields no calls.
        tool_calls = getattr(state["messages"][-1], "tool_calls", None) or []

        tool_results = []
        tool_messages = []
        for tool_call in tool_calls:
            result = ClinicalToolExecutor.execute_with_audit(tool_call)
            tool_results.append(result)
            tool_messages.append(ToolMessage(
                # Result contents are opaque here; default=str keeps values
                # such as enums or datetimes from breaking serialization.
                content=json.dumps(result, default=str),
                tool_call_id=tool_call["id"],
            ))

            if result.get('category') == "DRUG_ORDER":
                ClinicalSafetyEngine.post_drug_order_checks(result)

        return ClinicalStateManager.propagate_state(state, {
            "messages": tool_messages,
            "safety_warnings": ClinicalSafetyEngine.aggregate_warnings(tool_results)
        })

# ── Safety Subsystems ───────────────────────────────────────────────
class ClinicalSafetyEngine:
    """Hard safety gates applied to the tool calls of one agent turn.

    NOTE(review): ``enforce_prescription_rules`` is not called anywhere in the
    code shown, and ``ClinicalWorkflowNodes.tool_node`` calls
    ``pre_execution_checks`` / ``post_drug_order_checks`` /
    ``aggregate_warnings``, which this class does not define. Confirm how
    this gate is meant to be wired in.
    """

    @staticmethod
    def _tool_call_field(tool_call: Any, key: str) -> Any:
        """Return ``key`` from a tool call given as a dict or as an object.

        langchain_core's ``AIMessage.tool_calls`` holds ``ToolCall``
        TypedDicts (plain dicts), which do not support attribute access.
        """
        if isinstance(tool_call, dict):
            return tool_call.get(key)
        return getattr(tool_call, key, None)

    @staticmethod
    def _medication_of(tool_call: Any) -> Any:
        """Return the raw ``medication`` argument of a tool call, or None."""
        args = ClinicalSafetyEngine._tool_call_field(tool_call, "args")
        if isinstance(args, dict):
            return args.get("medication")
        return None

    @staticmethod
    def _normalize(medication: Any) -> Optional[str]:
        """Return a comparison key for a medication name (trimmed, casefolded)."""
        if isinstance(medication, str) and medication.strip():
            return medication.strip().casefold()
        return None

    @staticmethod
    def enforce_prescription_rules(tool_calls: List) -> None:
        """Require an interaction check for every medication being prescribed.

        When ``ClinicalConfig.DRUG_CHECK_REQUIRED`` is true, each
        ``prescribe_medication`` call in ``tool_calls`` must be matched by a
        ``check_drug_interactions`` call in the same batch for the same
        medication, compared case-insensitively and ignoring surrounding
        whitespace.

        Args:
            tool_calls: Tool calls as dicts (``{"name": ..., "args": {...}}``)
                or objects exposing ``name`` / ``args`` attributes.

        Raises:
            CriticalSafetyViolation: for the first prescription without a
                matching check, including one with no usable ``medication``.
        """
        if not ClinicalConfig.DRUG_CHECK_REQUIRED:
            return

        engine = ClinicalSafetyEngine
        checked = {
            engine._normalize(engine._medication_of(tc))
            for tc in tool_calls
            if engine._tool_call_field(tc, "name") == "check_drug_interactions"
        }
        # A check without a usable medication name covers nothing.
        checked.discard(None)

        for tc in tool_calls:
            if engine._tool_call_field(tc, "name") != "prescribe_medication":
                continue
            medication = engine._medication_of(tc)
            # An unnamed prescription can never match a check, so it fails
            # closed as a violation rather than crashing with KeyError.
            if engine._normalize(medication) not in checked:
                raise CriticalSafetyViolation(
                    f"Missing interaction check for {medication}"
                )

class ClinicalTerminationCriteria:
    """Decides when the agent/tool loop should stop."""

    # Case-insensitive phrase in the latest message that ends the consultation.
    TERMINATION_PHRASE = "terminate consultation"
    # Stop once strictly more alerts than this are active.
    MAX_ACTIVE_ALERTS = 3

    @staticmethod
    def _message_text(message: Any) -> str:
        """Return the text of a message whose ``content`` is a str or a list.

        LangChain message content may be a plain string or a list of content
        blocks (strings or dicts with a ``"text"`` entry); anything else
        yields an empty string.
        """
        content = getattr(message, "content", "")
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = []
            for block in content:
                if isinstance(block, str):
                    parts.append(block)
                elif isinstance(block, dict) and isinstance(block.get("text"), str):
                    parts.append(block["text"])
            return " ".join(parts)
        return ""

    @staticmethod
    def should_terminate(state: ClinicalState) -> bool:
        """Return True when the consultation should end.

        Terminates when any of these holds, checked in order with
        short-circuiting so later checks cannot fail once one is met:
          * ``iterations`` has reached ``ClinicalConfig.MAX_ITERATIONS``;
          * ``active_alerts`` exceeds ``MAX_ACTIVE_ALERTS``;
          * the latest message contains ``TERMINATION_PHRASE``.
        Missing counters are treated as 0; an empty history never matches.
        """
        criteria = ClinicalTerminationCriteria
        metadata = state["workflow_metadata"]

        if metadata.get("iterations", 0) >= ClinicalConfig.MAX_ITERATIONS:
            return True
        if metadata.get("active_alerts", 0) > criteria.MAX_ACTIVE_ALERTS:
            return True

        messages = state["messages"]
        if not messages:
            return False
        text = criteria._message_text(messages[-1])
        return criteria.TERMINATION_PHRASE in text.lower()

# ── Execution Framework ─────────────────────────────────────────────
class ClinicalWorkflow:
    """Compiled LangGraph consultation workflow plus its model and tools.

    NOTE(review): ``self.toolkit`` and ``self.llm`` are created but not used
    by the graph built here (the nodes call ``ClinicalLLM`` /
    ``ClinicalToolExecutor``). Confirm whether they should be wired in.
    """

    def __init__(self):
        """Compile the graph, collect the tool list and create the Groq model."""
        self.workflow = self._build_workflow()
        self.toolkit = ClinicalToolkit.get_essential_tools()
        self.llm = ChatGroq(model_name="llama3-70b-8192", temperature=0.1)

    def _build_workflow(self) -> StateGraph:
        """Wire the reasoning, tool and safety-review nodes and compile them.

        ``clinical_reasoning`` is the entry point. Its output is routed by
        ``ClinicalDecisionRouter.route_agent_output``: ``"require_tools"`` goes
        to ``tool_execution``, ``"need_safety_review"`` goes to
        ``safety_review``, and ``"final_output"`` ends the run. Tool execution
        and safety review both loop back to ``clinical_reasoning``.

        Returns:
            The compiled graph produced by ``StateGraph.compile()`` (the
            annotation names the builder class).
        """
        workflow = StateGraph(ClinicalState)

        workflow.add_node("clinical_reasoning", ClinicalWorkflowNodes.agent_node)
        workflow.add_node("tool_execution", ClinicalWorkflowNodes.tool_node)
        workflow.add_node("safety_review", ClinicalSafetyProtocols.review_node)

        workflow.set_entry_point("clinical_reasoning")

        workflow.add_conditional_edges(
            "clinical_reasoning",
            ClinicalDecisionRouter.route_agent_output,
            {
                "require_tools": "tool_execution",
                "need_safety_review": "safety_review",
                "final_output": END
            }
        )

        workflow.add_edge("tool_execution", "clinical_reasoning")
        workflow.add_edge("safety_review", "clinical_reasoning")

        return workflow.compile()

    def execute_consultation(self, patient_data: Dict) -> ClinicalState:
        """Run a full consultation for ``patient_data`` and return the final state.

        Raises:
            Whatever the graph's nodes raise, and LangGraph's
            ``GraphRecursionError`` if the step budget is exhausted.
        """
        initial_state = ClinicalStateManager.initialize_state(patient_data)
        # LangGraph's recursion_limit counts supersteps (every node run), not
        # agent turns. Each reasoning turn after the first is preceded by a
        # tool or safety-review step, so MAX_ITERATIONS turns need up to
        # 2 * MAX_ITERATIONS - 1 steps. A limit of MAX_ITERATIONS + buffer
        # would abort before the iteration cap could take effect.
        recursion_limit = (2 * ClinicalConfig.MAX_ITERATIONS
                           + ClinicalConfig.RECURSION_BUFFER)
        return self.workflow.invoke(
            initial_state,
            config={"recursion_limit": recursion_limit}
        )

# ── Usage Example ───────────────────────────────────────────────────
if __name__ == "__main__":
    # Environment validation (defined elsewhere) runs before any other work.
    ClinicalValidator.validate_environment()

    # Demonstration case: 68-year-old woman with chest pain radiating to the
    # left arm and a cardiovascular-risk history.
    demographics = {"age": 68, "sex": "F", "weight_kg": 82}
    complex_case = dict(
        demographics=demographics,
        presenting_complaint="Chest pain radiating to left arm",
        medical_history=["HTN", "Type 2 DM", "HLD"],
        current_meds=["Atenolol 50mg daily", "Simvastatin 40mg HS"],
    )

    # Run the consultation, then render its report as indented JSON.
    consultation_state = ClinicalWorkflow().execute_consultation(complex_case)
    report = ClinicalDocumentation.generate_report(consultation_state)
    print(json.dumps(report, indent=2))