|
import streamlit as st |
|
from openai import AzureOpenAI, AssistantEventHandler |
|
import os |
|
import re |
|
from typing_extensions import override |
|
|
|
|
|
|
|
# --- Azure OpenAI connection ------------------------------------------------
# Credentials are supplied via the environment; deployment/version are pinned.
api_key = os.environ.get("OPENAI_API_KEY")
azure_endpoint = os.environ.get("OPENAI_AZURE_ENDPOINT")
deployment_name = "gpt-4o"
api_version = "2024-05-01-preview"

# Single shared client reused by every helper in this module.
client = AzureOpenAI(
    api_key=api_key,
    api_version=api_version,
    azure_endpoint=azure_endpoint,
)

# Pre-configured assistant that every chat thread runs against.
assistant_id = "asst_c54LY1ZIel6lOG4CHsJ5xhaw"
assistant = client.beta.assistants.retrieve(assistant_id)
|
|
|
|
|
|
|
# Must be the first Streamlit call of the script run.
st.set_page_config(page_title="ZimmerBot", page_icon=":hospital:", layout="wide")
st.image("https://download.logo.wine/logo/Zimmer_Biomet/Zimmer_Biomet-Logo.wine.png", width=200)

# One server-side Assistants thread per browser session.
if "thread_id" not in st.session_state:
    st.session_state.thread_id = client.beta.threads.create().id

# Local mirrors kept across reruns: chat history and UI placeholders.
for _state_key, _default in (("messages", []), ("placeholders", [])):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _default
|
|
|
def clean_content(content):
    """Normalize assistant output text for display.

    Removes the raw citation glyphs the Assistants API embeds in message
    text (e.g. ``【12†source】`` or ``【12:3†source】``) and collapses runs of
    adjacent numeric citation brackets such as ``[1][2][3]`` into a single
    de-duplicated, numerically sorted bracket ``[1, 2, 3]``.

    Parameters
    ----------
    content : str
        Raw message text.

    Returns
    -------
    str
        Cleaned text with surrounding whitespace stripped.
    """
    # Drop the opaque citation markers embedded by the API.
    content = re.sub(r'【\d+(?::\d+)?†source】', '', content)

    def _merge_brackets(match):
        # Collect EVERY number in the run. The previous two-group regex
        # only kept the first and last capture, so "[1][2][3]" lost "[2]".
        nums = sorted({int(n) for n in re.findall(r'\d+', match.group(0))})
        return "[{}]".format(', '.join(str(n) for n in nums))

    # Two or more adjacent [n] brackets collapse into one merged bracket.
    content = re.sub(r'(?:\[\d+\]){2,}', _merge_brackets, content)
    return content.strip()
|
|
|
def display_message(role, content, citations=None):
    """Render one chat bubble for *role* with optional citation footer.

    Parameters
    ----------
    role : str
        Chat role ("user" or "assistant") passed to ``st.chat_message``.
    content : str
        Markdown body of the message.
    citations : list[str] | None
        Optional ``"[n] filename"`` entries rendered under a divider.
    """
    with st.chat_message(role):
        st.markdown(content)
        if not citations:
            return
        st.markdown("---")
        st.markdown("**Annotations:**")
        for entry in citations:
            st.markdown(f"- {entry}")
|
|
|
def process_message_content(message):
    """Extract an assistant message's text and resolve its file citations.

    Concatenates every text block in *message*, replaces each inline
    annotation span with a numbered marker like ``[1]``, and builds the
    matching footer list of ``"[n] filename"`` strings (one entry per
    distinct file).

    Parameters
    ----------
    message : Assistants API thread message object (assistant role).

    Returns
    -------
    tuple[str, list[str]]
        The cleaned message text and the de-duplicated citation list.
    """
    content_text = ''
    annotations = []

    for message_content in message.content:
        # Non-text blocks (e.g. images) have no .text attribute; skip them
        # instead of raising AttributeError.
        text_part = getattr(message_content, 'text', None)
        if text_part is None:
            continue
        content_text += text_part.value
        # In the OpenAI SDK the annotations live on the text block
        # (content.text.annotations), not on the content block itself;
        # fall back to the content block for older shapes.
        annotations.extend(
            getattr(text_part, 'annotations', None)
            or getattr(message_content, 'annotations', None)
            or []
        )

    citations = []
    citation_map = {}  # filename -> citation number (1-based)

    for annotation in annotations:
        # SDK versions differ: annotation.text is usually a plain str, but
        # some shapes exposed .text.value — accept both. TODO confirm
        # against the installed openai version.
        raw_text = annotation.text
        marker_text = raw_text if isinstance(raw_text, str) else raw_text.value

        # file_citation and file_path annotations both just point at a
        # file id, so one code path handles both.
        file_ref = (getattr(annotation, 'file_citation', None)
                    or getattr(annotation, 'file_path', None))
        if file_ref is None:
            continue

        cited_file = client.files.retrieve(file_ref.file_id)
        citation = cited_file.filename
        if citation not in citation_map:
            citation_map[citation] = len(citation_map) + 1
            # Append only on first sight so each source file is listed once
            # (previously every annotation added a duplicate footer entry).
            citations.append(f"[{citation_map[citation]}] {citation}")
        content_text = content_text.replace(marker_text, f"[{citation_map[citation]}]")

    return clean_content(content_text), citations
|
|
|
def submit_message(message):
    """Send *message* to the assistant thread and stream the reply into the UI.

    Records the user turn in session history, appends it to the server-side
    thread, streams the assistant response into a chat placeholder, then
    re-fetches the stored reply to resolve citations and saves the
    assistant turn in session history.

    Parameters
    ----------
    message : str
        The user's chat input.
    """
    # Record and immediately render the user's turn.
    st.session_state.messages.append({"role": "user", "content": message})
    display_message("user", message)

    # Add the message to the server-side thread before running it.
    client.beta.threads.messages.create(
        thread_id=st.session_state.thread_id,
        role="user",
        content=message
    )

    with st.chat_message("assistant"):
        # Placeholder the streaming handler repaints with partial output.
        message_placeholder = st.empty()

    class StreamEventHandler(AssistantEventHandler):
        """Accumulates streamed text deltas and repaints the placeholder."""

        def __init__(self):
            super().__init__()
            self.content = ""

        @override
        def on_text_delta(self, delta, snapshot):
            self.content += delta.value
            # Re-render the partial answer on every delta.
            message_placeholder.markdown(self.content)

    event_handler = StreamEventHandler()
    with client.beta.threads.runs.stream(
        thread_id=st.session_state.thread_id,
        assistant_id=assistant.id,
        event_handler=event_handler,
    ) as stream_run:
        stream_run.until_done()

    streamed_content = event_handler.content

    # Re-fetch stored messages to recover citation annotations (the
    # streamed deltas carry text only).
    messages = client.beta.threads.messages.list(
        thread_id=st.session_state.thread_id
    )

    assistant_messages = [msg for msg in messages.data if msg.role == 'assistant']
    if not assistant_messages:
        st.error("No assistant response was found.")
        return

    # messages.list returns newest-first by default, so index 0 is the
    # reply we just streamed (the previous [-1] picked the OLDEST reply).
    assistant_message = assistant_messages[0]
    processed_content, citations = process_message_content(assistant_message)

    # Prefer the streamed text (what the user already saw); fall back to
    # the processed stored message if streaming yielded nothing.
    final_content = streamed_content if streamed_content else processed_content

    st.session_state.messages.append({
        "role": "assistant",
        "content": final_content,
        "citations": citations
    })

    if citations:
        # st.empty() holds exactly one element: calling .container() after
        # .markdown() would discard the answer text, so render the content
        # and the annotation footer together inside one container.
        with message_placeholder.container():
            st.markdown(final_content)
            st.markdown("---")
            st.markdown("**Annotations:**")
            for citation in citations:
                st.markdown(f"- {citation}")
    else:
        message_placeholder.markdown(final_content)
|
|
|
|
|
# Replay the stored conversation on every script rerun.
for entry in st.session_state.messages:
    display_message(entry["role"], entry["content"], entry.get("citations"))

# Handle a new prompt, if the user typed one this run.
if prompt := st.chat_input("What would you like to ask?"):
    submit_message(prompt)