import streamlit as st
from langchain_groq import ChatGroq
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from typing import TypedDict, Annotated, List, Optional
import json
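
# Run with `streamlit run <this file>`. ChatGroq reads GROQ_API_KEY and
# TavilySearchResults reads TAVILY_API_KEY from the environment.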

# Configuration
class MedicalConfig:
    SYSTEM_PROMPT = """You are an AI clinical assistant. Follow these rules:
1. Analyze patient data using latest medical guidelines
2. Always check for drug interactions
3. Suggest tests only when necessary
4. Use structured actions:
   - order_lab_test: {test_name, reason}
   - prescribe_medication: {name, dosage, frequency}
5. Flag high-risk conditions immediately"""

    RED_FLAGS = {
        "symptoms": ["chest pain", "shortness of breath", "severe headache"],
        "vitals": {"temp": (38.5, "°C"), "hr": (120, "bpm"), "bp": ("180/120", "mmHg")}
    }
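
    # Plain OpenAI-style function schemas for the structured actions named in
    # SYSTEM_PROMPT; bind_tools accepts dicts in this shape directly. The field
    # names mirror the {test_name, reason} / {name, dosage, frequency} structure
    # the prompt asks the model to use.
    TOOL_SCHEMAS = [
        {
            "name": "order_lab_test",
            "description": "Order a laboratory test for the patient.",
            "parameters": {
                "type": "object",
                "properties": {
                    "test_name": {"type": "string"},
                    "reason": {"type": "string"}
                },
                "required": ["test_name", "reason"]
            }
        },
        {
            "name": "prescribe_medication",
            "description": "Prescribe a medication for the patient.",
            "parameters": {
                "type": "object",
                "properties": {
                    "name": {"type": "string"},
                    "dosage": {"type": "string"},
                    "frequency": {"type": "string"}
                },
                "required": ["name", "dosage", "frequency"]
            }
        }
    ]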

# State Management
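# The Annotated reducers merge updates by list concatenation, so messages and
# approvals accumulate across steps instead of being overwritten.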
class AgentState(TypedDict):
    messages: Annotated[List[dict], lambda l, r: l + r]
    patient_data: dict
    approvals: Annotated[List[dict], lambda l, r: l + r]

class MedicalAgent:
    def __init__(self):
        # Model id must match an available Groq model id
        self.model = ChatGroq(temperature=0.1, model="llama-3.3-70b-specdec")
        self.tools = {
            "medical_actions": {
                "order_lab_test": self.order_lab_test,
                "prescribe_medication": self.prescribe_medication
            },
            "research": TavilySearchResults(max_results=3)
        }

    def analyze_patient(self, patient_data: dict) -> Optional[AIMessage]:
        try:
            # Bind the structured-action schemas (MedicalConfig.TOOL_SCHEMAS) so the model can emit tool calls
            model_with_tools = self.model.bind_tools(MedicalConfig.TOOL_SCHEMAS)
            response = model_with_tools.invoke([
                SystemMessage(content=MedicalConfig.SYSTEM_PROMPT),
                HumanMessage(content=f"Patient Data: {json.dumps(patient_data)}")
            ])
            return response
        except Exception as e:
            st.error(f"Error analyzing patient data: {str(e)}")
            return None

    def process_action(self, action: dict) -> str:
        # `action` is a tool-call dict with 'name' and 'args' keys
        try:
            if action['name'] in self.tools['medical_actions']:
                return self.tools['medical_actions'][action['name']](action['args'])
            return f"Unknown action: {action['name']}"
        except Exception as e:
            return f"Error processing action: {str(e)}"

    def order_lab_test(self, params: dict) -> str:
        return f"✅ Lab ordered: {params['test_name']} ({params['reason']})"

    def prescribe_medication(self, params: dict) -> str:
        return f"✅ Prescribed: {params['name']} {params['dosage']} {params['frequency']}"

    def check_red_flags(self, patient_data: dict) -> List[str]:
        flags = []
        for symptom in patient_data.get("symptoms", []):
            if symptom.lower() in MedicalConfig.RED_FLAGS["symptoms"]:
                flags.append(f"🚨 High-risk symptom: {symptom}")
        # Also check the numeric vital threshold (temperature) defined in RED_FLAGS
        temp = patient_data.get("vitals", {}).get("temp")
        if temp is not None and temp >= MedicalConfig.RED_FLAGS["vitals"]["temp"][0]:
            flags.append(f"🚨 High-risk vital: temperature {temp} °C")
        return flags

# Streamlit UI
def main():
    st.set_page_config(page_title="AI Clinic", layout="wide")
    
    # Create the agent once per session; Streamlit reruns this script on every interaction
    if 'agent' not in st.session_state:
        st.session_state.agent = MedicalAgent()
    
    # Patient intake
    with st.sidebar:
        st.header("Patient Intake")
        patient_data = {
            "symptoms": st.multiselect("Symptoms", ["Fever", "Cough", "Chest Pain", "Shortness of Breath"]),
            "history": {
                "conditions": st.text_input("Medical History"),
                "medications": st.text_input("Current Medications")
            },
            "vitals": {
                "temp": st.number_input("Temp (°C)", 35.0, 42.0, 37.0),
                "bp": st.text_input("BP (mmHg)", "120/80")
            }
        }
    
    # Main interface
    st.title("AI-Powered Clinical Support System")
    
    # Check for red flags
    red_flags = st.session_state.agent.check_red_flags(patient_data)
    if red_flags:
        st.warning("High-Risk Conditions Detected")
        for flag in red_flags:
            st.error(flag)
    
    if st.button("Start Analysis"):
        with st.spinner("Analyzing patient data..."):
            response = st.session_state.agent.analyze_patient(patient_data)
            
            if response:
                st.subheader("Analysis Results")
                if hasattr(response, 'tool_calls') and response.tool_calls:
                    for action in response.tool_calls:
                        result = st.session_state.agent.process_action(action)
                        st.success(result)
                else:
                    st.info("No specific actions recommended. Here's the analysis:")
                    st.write(response.content)

# Run the app
if __name__ == "__main__":
    main()