File size: 4,455 Bytes
788074d
 
 
b8dee1d
63b0a52
b564942
9c32b8a
fc636ce
 
 
 
 
 
63b0a52
 
 
 
9c32b8a
 
 
 
 
fc636ce
63b0a52
 
 
b564942
 
63b0a52
 
 
b564942
 
63b0a52
b564942
63b0a52
788074d
fc636ce
 
 
63b0a52
fc636ce
9c32b8a
 
63b0a52
9c32b8a
 
 
 
 
 
 
 
 
 
63b0a52
 
 
 
 
 
 
 
 
9c32b8a
 
788074d
b564942
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import json
from typing import Optional

import streamlit as st
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.tools import tool
from langchain_groq import ChatGroq

# Configuration
class MedicalConfig:
    """Static settings shared by the clinical assistant."""

    # System prompt placed before the patient record in every model request.
    # The action names under rule 3 must match the names of the bound tools
    # (order_lab_test / prescribe_medication). Otherwise the model is steered
    # toward tool calls that cannot be dispatched.
    SYSTEM_PROMPT = """You are an AI clinical assistant. Follow these rules:
1. Analyze patient data using latest medical guidelines
2. Always check for drug interactions
3. Use structured actions when needed:
   - order_lab_test: Order lab tests
   - prescribe_medication: Prescribe medication
4. Flag high-risk conditions immediately"""

    # Presumably alert thresholds: symptom keywords plus vital-sign
    # (threshold, unit) pairs. TODO confirm the intended comparison direction.
    # NOTE(review): not referenced by the visible code, and the intake form
    # does not collect "hr".
    RED_FLAGS = {
        "symptoms": ["chest pain", "shortness of breath", "severe headache"],
        "vitals": {"temp": (38.5, "°C"), "hr": (120, "bpm"), "bp": ("180/120", "mmHg")}
    }

# Define tools with proper schemas
@tool
def order_lab_test(test_name: str, reason: str) -> str:
    """Orders a lab test with specified parameters."""
    # The docstring above is what @tool uses as the tool description, so it
    # is runtime behavior. Change it deliberately.
    # Stub: nothing is submitted anywhere; only a confirmation string is built.
    confirmation = "✅ Lab ordered: {} ({})".format(test_name, reason)
    return confirmation

@tool
def prescribe_medication(name: str, dosage: str, frequency: str) -> str:
    """Prescribes medication with specific dosage instructions."""
    # The docstring above is what @tool uses as the tool description, so it
    # is runtime behavior. Change it deliberately.
    # Stub: no prescription is issued; only a confirmation string is built.
    template = "✅ Prescribed: {} {} {}"
    return template.format(name, dosage, frequency)

# Tools bound to the chat model inside MedicalAgent; the model may request any of them.
tools = [
    order_lab_test,
    prescribe_medication,
    TavilySearchResults(max_results=3),  # web search capped at 3 results
]

class MedicalAgent:
    """LLM-backed assistant that analyzes patient data and runs its tool calls.

    The chat model is bound to the module-level ``tools`` list. Every tool in
    that list can be executed through :meth:`process_action`.
    """

    def __init__(self, model: str = "Llama-3.3-70b-Specdec", temperature: float = 0.1):
        """Create the Groq chat model and bind the module-level tools to it.

        Args:
            model: Groq model identifier passed to ``ChatGroq``.
            temperature: Sampling temperature passed to ``ChatGroq``.
        """
        # NOTE(review): Groq model IDs are normally lowercase, and this
        # "specdec" variant may no longer be served. Confirm against Groq's
        # model list.
        self.model = ChatGroq(temperature=temperature, model=model)
        self.model_with_tools = self.model.bind_tools(tools)
        # Name -> tool lookup so every bound tool (web search included) can be
        # dispatched, instead of only a hard-coded subset.
        self._tools_by_name = {t.name: t for t in tools}

    def analyze_patient(self, patient_data: dict) -> "Optional[AIMessage]":
        """Send the system prompt and the serialized patient record to the model.

        Args:
            patient_data: JSON-serializable patient record.

        Returns:
            The model's reply message, which may carry ``tool_calls``. Returns
            ``None`` if serialization or the model call raised; in that case
            the error has already been shown with ``st.error``.
        """
        try:
            # ensure_ascii=False keeps non-ASCII text (e.g. "°C", accented
            # drug names) readable to the model instead of \\uXXXX escapes.
            record = json.dumps(patient_data, ensure_ascii=False)
            return self.model_with_tools.invoke([
                SystemMessage(content=MedicalConfig.SYSTEM_PROMPT),
                HumanMessage(content=f"Patient Data: {record}"),
            ])
        except Exception as e:
            # UI boundary: report the failure to the user rather than crash the page.
            st.error(f"Error analyzing patient data: {str(e)}")
            return None

    def process_action(self, action: dict) -> str:
        """Execute one tool call requested by the model.

        Args:
            action: Mapping with ``name`` (tool name) and optionally ``args``
                (tool arguments). A missing or empty ``args`` means no arguments.

        Returns:
            The tool output as text. Returns ``"Unknown action: <name>"`` when
            no bound tool has that name, or ``"Error processing action: <err>"``
            if reading the action or running the tool raised.
        """
        try:
            tool_name = action['name']
            selected = self._tools_by_name.get(tool_name)
            if selected is None:
                return f"Unknown action: {tool_name}"
            result = selected.invoke(action.get('args') or {})
        except Exception as e:
            return f"Error processing action: {str(e)}"
        # Some tools (presumably the web-search one) may return structured
        # data rather than a string; callers display text.
        return result if isinstance(result, str) else str(result)

# Streamlit UI
def main():
    """Render the AI Clinic Streamlit page.

    The sidebar collects the patient intake. Pressing "Start Analysis" sends
    it to the session's MedicalAgent, then shows the model's text reply and
    the outcome of every tool call it requested.
    """
    st.set_page_config(page_title="AI Clinic", layout="wide")

    # Create the agent once per session: Streamlit reruns this script on every
    # interaction, and st.session_state persists across those reruns.
    if 'agent' not in st.session_state:
        st.session_state.agent = MedicalAgent()

    # Patient intake form
    with st.sidebar:
        st.header("Patient Intake")
        symptoms = st.multiselect("Symptoms", ["Fever", "Cough", "Chest Pain", "Shortness of Breath"])
        medical_history = st.text_input("Medical History")
        current_meds = st.text_input("Current Medications")
        temp = st.number_input("Temperature (°C)", 35.0, 42.0, 37.0)
        bp = st.text_input("Blood Pressure (mmHg)", "120/80")

        patient_data = {
            "symptoms": symptoms,
            "history": {
                "conditions": medical_history,
                "medications": current_meds
            },
            "vitals": {
                "temp": temp,
                "bp": bp
            }
        }

    # Main interface
    st.title("AI-Powered Clinical Support System")

    # Prefixes MedicalAgent.process_action uses for failed actions. These are
    # shown as errors so a failed order is not presented as a success.
    failure_prefixes = ("Unknown action:", "Error processing action:")

    if st.button("Start Analysis", type="primary"):
        with st.spinner("Analyzing patient data..."):
            response = st.session_state.agent.analyze_patient(patient_data)
            # analyze_patient returns None after reporting the error itself.
            if response is None:
                return

            st.subheader("Clinical Recommendations")
            tool_calls = getattr(response, "tool_calls", None) or []
            if tool_calls:
                # Text returned alongside tool calls is the model's rationale;
                # show it instead of discarding it.
                if response.content:
                    st.write(response.content)
                for action in tool_calls:
                    result = str(st.session_state.agent.process_action(action))
                    if result.startswith(failure_prefixes):
                        st.error(result)
                    else:
                        st.success(result)
            else:
                st.info("No specific actions recommended. Clinical Summary:")
                st.write(response.content)

# Entry point: render the app only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()