|
import os |
|
import re |
|
import streamlit as st |
|
from dotenv import load_dotenv |
|
from langchain.agents.openai_assistant import OpenAIAssistantRunnable |
|
|
|
|
|
# Load secrets (API key, assistant id) from a local .env file into os.environ.
load_dotenv()

api_key = os.getenv("OPENAI_API_KEY")

# Id of the pre-created OpenAI Assistant backing "Solution Specifier A".
# NOTE(review): both values may be None if the .env entries are missing —
# the constructor below would then fail at first use; verify deployment env.
extractor_agent = os.getenv("ASSISTANT_ID_SOLUTION_SPECIFIER_A")




# Runnable wrapper around the hosted assistant. as_agent=True makes invoke()
# return agent-style results exposing .return_values and .thread_id, which
# predict() below relies on.
extractor_llm = OpenAIAssistantRunnable(

    assistant_id=extractor_agent,

    api_key=api_key,

    as_agent=True

)
|
|
|
def remove_citation(text: str) -> str: |
|
pattern = r"γ\d+β \w+γ" |
|
return re.sub(pattern, "π", text) |
|
|
|
|
|
# Initialise per-session state on first load; Streamlit reruns the whole
# script on every interaction, so only fill in keys that are still missing.
for _key, _default in (("messages", []), ("thread_id", None)):
    if _key not in st.session_state:
        st.session_state[_key] = _default

st.title("Solution Specifier A")
|
|
|
def predict(user_input: str) -> str:
    """Send *user_input* to the assistant and return its cleaned reply.

    The first call creates a fresh assistant thread and remembers its id
    in Streamlit session state; subsequent calls pass that id back so the
    conversation keeps its context across reruns.
    """
    thread_id = st.session_state["thread_id"]

    payload = {"content": user_input}
    if thread_id is not None:
        payload["thread_id"] = thread_id

    response = extractor_llm.invoke(payload)

    # First exchange: capture the thread id the backend just created.
    if thread_id is None:
        st.session_state["thread_id"] = response.thread_id

    return remove_citation(response.return_values["output"])
|
|
|
|
|
# Replay the stored conversation so it survives Streamlit's rerun-per-event
# execution model. Anything that isn't a user message renders as assistant.
for msg in st.session_state["messages"]:
    role = "user" if msg["role"] == "user" else "assistant"
    with st.chat_message(role):
        st.write(msg["content"])
|
|
|
|
|
# Chat input box at the bottom of the page; yields None until submitted.
user_input = st.chat_input("Type your message here...")


def _record_and_render(role: str, content: str) -> None:
    # Persist the message for future reruns, then draw it in the chat pane.
    st.session_state["messages"].append({"role": role, "content": content})
    with st.chat_message(role):
        st.write(content)


if user_input:
    # Echo the user's message first, then fetch and show the assistant reply.
    _record_and_render("user", user_input)
    _record_and_render("assistant", predict(user_input))