|
import os |
|
import streamlit as st |
|
|
|
from langchain.embeddings import HuggingFaceInstructEmbeddings, HuggingFaceEmbeddings |
|
from langchain.vectorstores.faiss import FAISS |
|
from huggingface_hub import snapshot_download |
|
|
|
from langchain.callbacks import StreamlitCallbackHandler |
|
from langchain.agents import OpenAIFunctionsAgent, AgentExecutor |
|
from langchain.agents.agent_toolkits import create_retriever_tool |
|
from langchain.agents.openai_functions_agent.agent_token_buffer_memory import ( |
|
AgentTokenBufferMemory, |
|
) |
|
from langchain.chat_models import ChatOpenAI |
|
from langchain.schema import SystemMessage, AIMessage, HumanMessage |
|
from langchain.prompts import MessagesPlaceholder |
|
from langsmith import Client |
|
|
|
# LangSmith client used to record user feedback on traced agent runs
# (see send_feedback below).
client = Client()

# Page chrome; must run before any other Streamlit command in the script.
st.set_page_config(
    page_title="Investor Education ChatChain",
    page_icon="π",
    layout="wide",
    initial_sidebar_state="collapsed",
)

# Fail fast (KeyError) if the OpenAI key is missing from the environment.
# NOTE(review): the value itself is never used below -- ChatOpenAI reads
# OPENAI_API_KEY from the environment on its own -- so this line serves
# only as an early sanity check.
api_key = os.environ["OPENAI_API_KEY"]
|
|
|
|
|
|
|
# Map each selectable Vanguard site location to the local directory that
# holds its pre-built FAISS index.
site_options = {'US': 'vanguard_embeddings_US',
                'AUS': 'vanguard-embeddings'}

site_options_list = list(site_options.keys())

# Drive the radio choices from site_options so the widget and the lookup
# table cannot drift apart (previously the options were a hard-coded
# duplicate ('US', 'AUS') tuple and site_options_list went unused).
site_radio = st.radio(
    "Which Vanguard website location would you want to chat to?",
    site_options_list)
|
|
|
# cache_resource (not cache_data): cache_data pickles its return value on
# every hit, which is wrong for an unserializable, stateful FAISS retriever;
# cache_resource stores the live object once per session, per Streamlit docs.
@st.cache_resource
def load_vectorstore(site):
    '''Load HuggingFace embeddings and the FAISS index for *site*.

    Args:
        site: Key into site_options ('US' or 'AUS') selecting which
            pre-built index directory to load.

    Returns:
        A retriever over the index returning the top 4 matches per query.
    '''
    # NOTE(review): model name has no "sentence-transformers/" prefix --
    # confirm it resolves to the intended all-mpnet-base-v2 checkpoint.
    emb = HuggingFaceEmbeddings(model_name="all-mpnet-base-v2")

    # allow_dangerous_deserialization: the index is pickled on disk; safe
    # only because these files are produced by this project, not user input.
    vectorstore = FAISS.load_local(site_options[site], emb, allow_dangerous_deserialization=True)

    return vectorstore.as_retriever(search_kwargs={"k": 4})
|
|
|
|
|
# Retrieval tool the agent can call to search the scraped Vanguard content
# (tool-name typo fixed: "vaguard" -> "vanguard").
tool = create_retriever_tool(
    load_vectorstore(site_radio),
    "search_vanguard_website",
    "Searches and returns documents regarding the Vanguard website across US and AUS locations. The websites provide investment related information to the user")

tools = [tool]

# Streaming chat model; temperature 0 for deterministic, factual answers.
llm = ChatOpenAI(temperature=0, streaming=True, model="gpt-4o")

# System prompt.  Fixes vs. the original: "informationn" typo, missing
# spaces between the concatenated sentences, and a copy-pasted instruction
# about the "CFA program" left over from an unrelated app that contradicted
# the Vanguard retrieval context.
message = SystemMessage(
    content=(
        "You are a helpful chatbot who is tasked with answering questions "
        "about investments using information that has been scraped from a "
        "website to answer the users question accurately. "
        "Do not use any information not provided in the website context. "
        "Unless otherwise explicitly stated, it is probably fair to assume "
        "that questions are about the Vanguard US or AUS websites and the "
        "investment information they provide. "
        "If there is any ambiguity, you probably assume they are about that."
    )
)

# Prompt includes a "history" slot so prior turns are replayed to the model.
prompt = OpenAIFunctionsAgent.create_prompt(
    system_message=message,
    extra_prompt_messages=[MessagesPlaceholder(variable_name="history")],
)
agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt)
agent_executor = AgentExecutor(
    agent=agent,
    tools=tools,
    verbose=True,
    return_intermediate_steps=True,
)

# Token-bounded conversation memory; uses the same llm to count tokens.
memory = AgentTokenBufferMemory(llm=llm)

starter_message = "Ask me anything about information on the Vanguard US/AUS websites!"
|
# Render the clear-history button unconditionally.  With the original
# short-circuit (`"messages" not in st.session_state or st.sidebar.button(...)`)
# the button call was skipped -- and therefore never drawn -- on the very
# first script run, when "messages" was absent.
clear_history = st.sidebar.button("Clear message history")
if "messages" not in st.session_state or clear_history:
    st.session_state["messages"] = [AIMessage(content=starter_message)]
|
|
|
|
|
def send_feedback(run_id, score):
    """Record a user feedback score against a LangSmith run.

    Args:
        run_id: Identifier of the traced agent run to attach feedback to.
        score: Numeric score supplied by the user.
    """
    # NOTE(review): not referenced anywhere in this chunk -- presumably
    # meant to back the feedback buttons near the column layout below.
    client.create_feedback(run_id, "user_score", score=score)
|
|
|
|
|
# Replay the stored conversation into the chat UI and seed the fresh
# per-run memory object with every stored message.
_ROLE_BY_TYPE = ((AIMessage, "assistant"), (HumanMessage, "user"))

for past_msg in st.session_state.messages:
    for msg_type, role in _ROLE_BY_TYPE:
        if isinstance(past_msg, msg_type):
            st.chat_message(role).write(past_msg.content)
            break
    # Every message (regardless of type) is added to memory.
    memory.chat_memory.add_message(past_msg)
|
|
|
|
|
# Main chat turn.  The walrus target is renamed from ``prompt`` to
# ``user_query``: the original shadowed the agent prompt template built above.
if user_query := st.chat_input(placeholder=starter_message):
    st.chat_message("user").write(user_query)
    with st.chat_message("assistant"):
        # Streams the agent's intermediate steps into this container.
        st_callback = StreamlitCallbackHandler(st.container())
        response = agent_executor(
            {"input": user_query, "history": st.session_state.messages},
            callbacks=[st_callback],
            include_run_info=True,  # required so response carries "__run" below
        )
        st.session_state.messages.append(AIMessage(content=response["output"]))
        st.write(response["output"])
        # Persist the turn into token-buffer memory, then mirror the
        # (token-bounded) buffer back into session state.
        memory.save_context({"input": user_query}, response)
        st.session_state["messages"] = memory.buffer
        # LangSmith run id, available because include_run_info=True above.
        run_id = response["__run"].run_id

        # Feedback UI scaffold.  NOTE(review): col1/col2 are created but no
        # thumbs-up/down buttons calling send_feedback(run_id, ...) are wired
        # up in this chunk -- presumably missing or defined elsewhere.
        col_blank, col_text, col1, col2 = st.columns([10, 2, 1, 1])
        with col_text:
            st.text("Feedback:")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|