File size: 5,208 Bytes
31ea27b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import streamlit as st
import os
import re
from openai import AzureOpenAI
from streamlit_pills import pills

st.set_page_config(layout="wide")

# Assistant to chat with. Overridable per deployment; defaults to the demo assistant.
ASSISTANT_ID = os.getenv("AZURE_OPENAI_ASSISTANT_ID", "asst_Tbjd3ckxAfOjj29TeP6KaZ0v")


@st.cache_resource
def _get_client():
    """Build the Azure OpenAI client once and reuse it across reruns.

    Streamlit re-executes this whole script on every user interaction; caching
    avoids constructing a new client each time.
    """
    return AzureOpenAI(
        azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
        api_key=os.getenv("AZURE_OPENAI_API_KEY"),
        api_version="2024-05-01-preview",
    )


@st.cache_resource
def _get_assistant(assistant_id):
    """Retrieve the assistant once instead of making a network call per rerun."""
    return _get_client().beta.assistants.retrieve(assistant_id)


# Initialize the Azure OpenAI client and retrieve the assistant (both cached)
client = _get_client()
assistant = _get_assistant(ASSISTANT_ID)

# Streamlit app
st.title("MultiCare AI Summary Demo")

# Per-browser-session state: one assistant thread per session, the rendered
# chat history, and the last quick-question pill that was submitted.
if "thread_id" not in st.session_state:
    thread = client.beta.threads.create()
    st.session_state.thread_id = thread.id

if "messages" not in st.session_state:
    st.session_state.messages = []

if "selected_pill" not in st.session_state:
    st.session_state.selected_pill = None

def clean_content(content):
    """Tidy citation markers in assistant text.

    * Strips raw file-search markers such as ``【18:18†source】`` or
      ``【5†source】`` (any label after the dagger is accepted).
    * Collapses runs of two or more adjacent numeric citations, e.g.
      ``[3][1][3]``, into one sorted, de-duplicated list: ``[1, 3]``.

    Args:
        content: message text.

    Returns:
        The cleaned text with surrounding whitespace stripped.
    """
    content = re.sub(r'【\d+(?::\d+)?†[^】]*】', '', content)

    def _merge(match):
        # A repeated capture group only remembers its last match, so re-scan
        # the whole run; compare as ints so that 10 sorts after 2.
        numbers = sorted({int(n) for n in re.findall(r'\d+', match.group(0))})
        return f"[{', '.join(str(n) for n in numbers)}]"

    content = re.sub(r'(?:\[\d+\]){2,}', _merge, content)
    return content.strip()

def process_message_content(message):
    """Convert an assistant message into display text plus a citation list.

    Each file annotation (``file_citation`` or ``file_path``) in the message's
    first text part is replaced in the text by a numeric marker ``[n]``. The
    same filename always gets the same number, assigned in order of first
    appearance. The resulting text is then passed through ``clean_content``.

    Args:
        message: a thread message as returned by
            ``client.beta.threads.messages.list``.

    Returns:
        ``(text, citations)`` where *citations* is a list of ``"[n] filename"``
        strings, one per cited filename, ordered by ``n``. A message with no
        text part yields ``("", [])``.
    """
    text_part = next(
        (part.text for part in message.content if getattr(part, "text", None) is not None),
        None,
    )
    if text_part is None:
        return "", []

    text = text_part.value
    citation_numbers = {}  # filename -> n; insertion order == numbering order
    filenames_by_id = {}   # file_id -> filename, so each file is fetched once

    for annotation in text_part.annotations:
        cited = getattr(annotation, "file_citation", None) or getattr(annotation, "file_path", None)
        if cited is None:
            continue
        filename = filenames_by_id.get(cited.file_id)
        if filename is None:
            filename = client.files.retrieve(cited.file_id).filename
            filenames_by_id[cited.file_id] = filename
        number = citation_numbers.setdefault(filename, len(citation_numbers) + 1)
        # Guard: replacing "" would insert the marker between every character.
        if annotation.text:
            text = text.replace(annotation.text, f"[{number}]")

    citations = [f"[{n}] {name}" for name, n in citation_numbers.items()]
    return clean_content(text), citations

# Function to handle message submission
def submit_message(message):
    """Send *message* to the assistant thread and render the exchange.

    The user's message is echoed and appended to ``st.session_state.messages``
    immediately. A run is then created and polled until it finishes. If the
    run completes, the newest assistant message produced by that run is
    processed for citations, stored in the session history and rendered.
    Otherwise the run status (and error detail, when present) is shown with
    ``st.error``.

    Args:
        message: the user's question text.
    """
    st.session_state.messages.append({"role": "user", "content": message})
    with st.chat_message("user", avatar="⚕️"):
        st.markdown(message)

    client.beta.threads.messages.create(
        thread_id=st.session_state.thread_id,
        role="user",
        content=message,
    )

    with st.spinner("Assistant is thinking..."):
        run = client.beta.threads.runs.create_and_poll(
            thread_id=st.session_state.thread_id,
            assistant_id=assistant.id,
        )

    if run.status != 'completed':
        last_error = getattr(run, "last_error", None)
        detail = f" ({last_error.message})" if last_error else ""
        st.error(f"Error: {run.status}{detail}")
        return

    # Newest-first listing; only accept a message generated by *this* run so an
    # older reply is never shown again if the run produced no output.
    messages = client.beta.threads.messages.list(
        thread_id=st.session_state.thread_id,
        order="desc",
    )
    assistant_message = next(
        (m for m in messages.data if m.role == "assistant" and m.run_id == run.id),
        None,
    )
    if assistant_message is None:
        st.error("Error: the assistant returned no message")
        return

    processed_content, citations = process_message_content(assistant_message)

    st.session_state.messages.append({
        "role": "assistant",
        "content": processed_content,
        "citations": citations,
    })

    with st.chat_message("assistant", avatar="https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcTETAUBM4r3hXmQ3eC31vWws8A9VbPvOxQNZQ&s"):
        st.markdown(processed_content)
        if citations:
            st.markdown("---")
            st.markdown("**Citations:**")
            for citation in citations:
                st.markdown(citation)

# Display chat messages. Use the same avatars that submit_message() uses when a
# message is first rendered, so bubbles look identical after Streamlit reruns.
_HISTORY_AVATARS = {
    "user": "⚕️",
    "assistant": "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcTETAUBM4r3hXmQ3eC31vWws8A9VbPvOxQNZQ&s",
}
for message in st.session_state.messages:
    with st.chat_message(message["role"], avatar=_HISTORY_AVATARS.get(message["role"])):
        st.markdown(message["content"])
        citations = message.get("citations")
        if citations:
            st.markdown("---")
            st.markdown("**Citations:**")
            for citation in citations:
                st.markdown(citation)

suggested_messages = [
    "I'm a physician working at a primary care clinic, what Epic update changes will affect me?",
    "Summarize Epic update changes that will impact a Med Surg Registered Nurse.",
    "What Epic update changes need to be reviewed by ED registration staff?",
    "Which Epic update changes impact surgeons?",
    "Are there any Epic update changes that will affect X-ray technicians?",
    "Create a summary of all Epic update changes that are relevant for referral coordinators?"
]

selected_pill = pills(
    "Quick Questions by Role",
    suggested_messages,
    icons=['👨‍⚕️', '👩‍⚕️', '🏥', '🥽', '📷', '📞'],
    index=None,
    label_visibility="visible",
    key="suggested_messages"
)

# A keyed Streamlit component keeps returning its last value on every rerun
# (e.g. when the user later types in the chat box). Only submit a pill when the
# selection differs from the one already submitted; otherwise the same question
# would be re-sent to the assistant on each rerun.
if selected_pill and selected_pill != st.session_state.get("selected_pill"):
    st.session_state.selected_pill = selected_pill
    submit_message(selected_pill)

# Chat input
user_input = st.chat_input("What would you like to ask?")

if user_input:
    submit_message(user_input)