# cloud-sean's picture
# Update app.py
# 965c324 verified
import streamlit as st
from openai import AzureOpenAI, AssistantEventHandler
import os
import re
from typing_extensions import override
# Set up your Azure OpenAI API credentials
# Configuration: credentials come from the environment so no secrets live in
# the source; the API version must match the Azure OpenAI resource.
api_key = os.environ.get("OPENAI_API_KEY")
azure_endpoint = os.environ.get("OPENAI_AZURE_ENDPOINT")
# NOTE(review): deployment_name is never referenced below — the assistant's
# own model configuration is used instead; confirm before removing.
deployment_name = "gpt-4o" # Replace with your deployment name
api_version = "2024-05-01-preview" # Ensure this matches your Azure OpenAI resource
# Initialize the Azure OpenAI client
client = AzureOpenAI(
    api_key=api_key,
    azure_endpoint=azure_endpoint,
    api_version=api_version
)
# Retrieve the assistant (network call at startup; fails fast if the ID or
# credentials are wrong).
assistant_id = "asst_c54LY1ZIel6lOG4CHsJ5xhaw" # Replace with your assistant ID
assistant = client.beta.assistants.retrieve(assistant_id)
# Streamlit app setup
st.set_page_config(page_title="ZimmerBot", page_icon=":hospital:", layout="wide")
st.image("https://download.logo.wine/logo/Zimmer_Biomet/Zimmer_Biomet-Logo.wine.png", width=200)
# Initialize session state:
#   thread_id    - one Assistants-API thread per browser session
#   messages     - rendered chat history, replayed on every Streamlit rerun
#   placeholders - NOTE(review): appears unused elsewhere in this file; verify
if "thread_id" not in st.session_state:
    thread = client.beta.threads.create()
    st.session_state.thread_id = thread.id
if "messages" not in st.session_state:
    st.session_state.messages = []
if "placeholders" not in st.session_state:
    st.session_state.placeholders = []
def clean_content(content):
    """Strip raw assistant citation markers and normalize bracket citations.

    Removes ``【12:3†source】``-style markers emitted by the Assistants API
    and collapses runs of adjacent numeric citations such as ``[1][2][3]``
    into a single, de-duplicated, numerically sorted ``[1, 2, 3]``.

    Returns the cleaned text with surrounding whitespace stripped.
    """
    # Drop raw annotation markers entirely; readable [n] indices are
    # substituted elsewhere.
    content = re.sub(r'【\d+(?::\d+)?†source】', '', content)

    def _merge(match):
        # A repeated regex group only retains its LAST capture, so re-scan
        # the whole matched span to recover every number; sort numerically
        # (key=int) so "[10]" does not sort before "[2]".
        nums = sorted(set(re.findall(r'\d+', match.group(0))), key=int)
        return f"[{', '.join(nums)}]"

    # Prettify adjacent [1][2]... runs into one [1, 2, ...] list.
    content = re.sub(r'\[(\d+)\](?:\[(\d+)\])+', _merge, content)
    return content.strip()
def display_message(role, content, citations=None):
    """Render one chat bubble for *role*, optionally followed by an
    "Annotations" section listing each citation as a bullet line."""
    with st.chat_message(role):
        st.markdown(content)
        if not citations:
            return
        # Visually separate the citation list from the message body.
        st.markdown("---")
        st.markdown("**Annotations:**")
        for entry in citations:
            st.markdown(f"- {entry}")
def process_message_content(message):
    """Convert an assistant message into display text plus a citation list.

    Concatenates every text block in ``message.content``, replaces each
    annotation's source span with a numbered ``[n]`` marker (one number per
    distinct cited file, assigned in first-seen order), and returns
    ``(cleaned_text, citations)`` where ``citations`` is a de-duplicated
    list like ``["[1] report.pdf", ...]``.
    """
    annotations = []
    content_text = ''
    citations = []
    citation_map = {}  # filename -> citation number, first-seen order

    # Gather text and annotations from every content block.
    for message_content in message.content:
        content_text += message_content.text.value
        annotations.extend(getattr(message_content, 'annotations', []))

    def _register(file_id):
        # Resolve the cited file and return its stable [n] index, recording
        # a "[n] filename" entry only the first time that file is seen so
        # the annotations list carries no duplicates.
        cited_file = client.files.retrieve(file_id)
        citation = f"{cited_file.filename}"
        if citation not in citation_map:
            citation_map[citation] = len(citation_map) + 1
            citations.append(f"[{citation_map[citation]}] {citation}")
        return citation_map[citation]

    for annotation in annotations:
        annotation_text = annotation.text.value  # span of text to replace
        # NOTE(review): SDK annotation objects can expose BOTH attributes
        # with a None value, so test the value, not mere attribute presence.
        file_citation = getattr(annotation, 'file_citation', None)
        file_path = getattr(annotation, 'file_path', None)
        if file_citation is not None:
            index = _register(file_citation.file_id)
        elif file_path is not None:
            index = _register(file_path.file_id)
        else:
            continue  # unknown annotation kind: leave the text untouched
        content_text = content_text.replace(annotation_text, f"[{index}]")

    cleaned_content = clean_content(content_text)
    return cleaned_content, citations
def submit_message(message):
    """Send *message* to the assistant and stream the reply into the chat.

    Records the user turn in session state, relays it to the assistant's
    thread, live-streams the reply into a placeholder, then re-fetches the
    finished message to resolve file citations into numbered annotations.
    """
    # Record and render the user's turn.
    st.session_state.messages.append({"role": "user", "content": message})
    display_message("user", message)

    # Relay the user message to the assistant's thread.
    client.beta.threads.messages.create(
        thread_id=st.session_state.thread_id,
        role="user",
        content=message
    )

    with st.chat_message("assistant"):
        # Single mutable slot updated as tokens stream in, then replaced
        # with the final annotated answer.
        message_placeholder = st.empty()

        class StreamEventHandler(AssistantEventHandler):
            """Accumulates streamed text deltas and mirrors them live."""

            def __init__(self):
                super().__init__()
                self.content = ""

            @override
            def on_text_delta(self, delta, snapshot):
                self.content += delta.value
                message_placeholder.markdown(self.content)

        # Create and run the assistant with streaming enabled; block until
        # the run finishes.
        event_handler = StreamEventHandler()
        with client.beta.threads.runs.stream(
            thread_id=st.session_state.thread_id,
            assistant_id=assistant.id,
            event_handler=event_handler,
        ) as stream_run:
            stream_run.until_done()
        streamed_content = event_handler.content

        # Re-fetch the thread: citation annotations are only available on
        # the completed message, not in the streamed deltas.
        messages = client.beta.threads.messages.list(
            thread_id=st.session_state.thread_id
        )
        # messages.list returns newest-first by default, so the reply we
        # just received is the FIRST assistant entry, not the last.
        assistant_messages = [msg for msg in messages.data if msg.role == 'assistant']
        if assistant_messages:
            assistant_message = assistant_messages[0]
            processed_content, citations = process_message_content(assistant_message)
            # Prefer the processed text: its [n] markers are what the
            # citations list refers to (the streamed text still carries raw
            # 【†source】 markers). Fall back to the streamed text if
            # processing yielded nothing.
            final_content = processed_content if processed_content else streamed_content

            st.session_state.messages.append({
                "role": "assistant",
                "content": final_content,
                "citations": citations
            })

            # Each direct call on an st.empty placeholder REPLACES its
            # previous content, so the answer and its annotations must be
            # emitted together inside one container.
            with message_placeholder.container():
                st.markdown(final_content)
                if citations:
                    st.markdown("---")
                    st.markdown("**Annotations:**")
                    for citation in citations:
                        st.markdown(f"- {citation}")
        else:
            st.error("No assistant response was found.")
# Replay the stored conversation so every Streamlit rerun shows the full
# history.
for entry in st.session_state.messages:
    display_message(entry["role"], entry["content"], entry.get("citations"))

# Prompt for the next question and hand it to the assistant when submitted.
if prompt := st.chat_input("What would you like to ask?"):
    submit_message(prompt)