Update app.py
app.py
CHANGED
@@ -1,209 +1,125 @@
-from langchain.prompts.prompt import PromptTemplate
-from langchain.llms import OpenAIChat
-from langchain.chains import ConversationalRetrievalChain, LLMChain
-from langchain.chains.question_answering import load_qa_chain
-from langchain.embeddings import HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings
-from langchain.callbacks import StdOutCallbackHandler
-from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
-from langchain.vectorstores import FAISS
-from langchain.memory import ConversationBufferMemory
 import os
-import
-from
-)
-from langchain.
-)
 
 site_options = {'US': 'vanguard_embeddings_US',
                 'AUS': 'vanguard_embeddings'}
 
 site_options_list = list(site_options.keys())
 
-system_template="""Use only the following pieces of context that has been scraped from a website to answer the users question accurately.
-Do not use any information not provided in the website context.
-If you don't know the answer, just say 'There is no relevant answer in the Investor Website',
-don't try to make up an answer.
-
-ALWAYS return a "SOURCES" part in your answer.
-The "SOURCES" part should be a reference to the source of the document from which you got your answer.
-
-Remember, do not reference any information not given in the context.
-If the answer is not available in the given context just say 'There is no relevant answer in the website content'
-
-Follow the below format when answering:
-
-Question: {question}
-SOURCES: [xyz]
-
-Begin!
-----------------
-{context}"""
-
-messages = [
-    SystemMessagePromptTemplate.from_template(system_template),
-    HumanMessagePromptTemplate.from_template("{question}")
-]
-prompt = ChatPromptTemplate.from_messages(messages)
-
-return prompt
-
 def load_vectorstore(site):
     '''load embeddings and vectorstore'''
 
     emb = HuggingFaceEmbeddings(model_name="all-mpnet-base-v2")
 
-    return
 
-#default embeddings and store
-vectorstore = load_vectorstore(site_options_list[0])
-
-vectorstore = load_vectorstore(site)
-
-# vectorstore = load_vectorstore('vanguard-embeddings',sbert_emb)
-
-_template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question.
-You can assume the question about investing and the investment management industry.
-Chat History:
-{chat_history}
-Follow Up Input: {question}
-Standalone question:"""
-CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
-
-template = """You are an AI assistant for answering questions about investing and the investment management industry.
-You are given the following extracted parts of a long document and a question. Provide a conversational answer.
-If you don't know the answer, just say "Hmm, I'm not sure." Don't try to make up an answer.
-If the question is not about investing, politely inform them that you are tuned to only answer questions about investing and the investment management industry.
-Question: {question}
-=========
-{context}
-=========
-Answer in Markdown:"""
-QA_PROMPT = PromptTemplate(template=template, input_variables=["question", "context"])
-
-def get_chain(vectorstore):
-    llm = OpenAIChat(streaming=True,
-                     callbacks=[StdOutCallbackHandler()],
-                     verbose=True,
-                     temperature=0,
-                     model_name='gpt-4')
-
-    question_generator = LLMChain(llm=llm, prompt=CONDENSE_QUESTION_PROMPT)
-    doc_chain = load_qa_chain(llm=llm,chain_type="stuff",prompt=load_prompt())
-
-    chain = ConversationalRetrievalChain(retriever=vectorstore.as_retriever(search_kwags={"k": 4}),
-                                         question_generator=question_generator,
-                                         combine_docs_chain=doc_chain,
-                                         memory=memory,
-                                         return_source_documents=True,
-                                         get_chat_history=lambda h :h)
-
-        history = history or []
-        # Set OpenAI key
-        # chain = get_chain(vectorstore)
-        # Run chain and append input.
-        output = chain({"question": inp})["answer"]
-        history.append((inp, output))
-    except Exception as e:
-        raise e
-    finally:
-        self.lock.release()
-    return history, history
-
-block = gr.Blocks(css=".gradio-container {background-color: lightgray}")
-
-with block:
-    with gr.Row():
-        gr.Markdown("<h3><center>Chat-Your-Data (Investor Education)</center></h3>")
-        embed_but = gr.Button(value='Step 1: Click Me to Load the QA System')
-    with gr.Row():
-        websites = gr.Radio(choices=site_options_list,value=site_options_list[0],label='Select US or AUS website data',
-                            interactive=True)
-        websites.change(on_value_change, websites)
-
-    vectorstore = load_vectorstore(websites.value)
-
-    chatbot = gr.Chatbot()
-
-    chat = ChatWrapper()
-
-    with gr.Row():
-        message = gr.Textbox(
-            label="What's your question?",
-            placeholder="Ask questions about Investing",
-            lines=1,
-        )
-        submit = gr.Button(value="Send", variant="secondary").style(full_width=False)
-
-    gr.Examples(
-        examples=[
-            "What are the benefits of investing in ETFs?",
-            "What is the average cost of investing in a managed fund?",
-            "At what age can I start investing?",
-            "Do you offer investment accounts for kids?"
-        ],
-        inputs=message,
-    )
-
-    gr.HTML(
-        "<center>Powered by <a href='https://github.com/hwchase17/langchain'>LangChain 🦜️🔗</a></center>"
-    )
-
-    embed_but.click(
-        load_chain,
-        outputs=[agent_state],
-    )
 
 gr.Markdown("")
 import os
+import streamlit as st
+
+from langchain.embeddings import HuggingFaceInstructEmbeddings, HuggingFaceEmbeddings
+from langchain.vectorstores.faiss import FAISS
+from huggingface_hub import snapshot_download
+
+from langchain.callbacks import StreamlitCallbackHandler
+from langchain.agents import OpenAIFunctionsAgent, AgentExecutor
+from langchain.agents.agent_toolkits import create_retriever_tool
+from langchain.agents.openai_functions_agent.agent_token_buffer_memory import (
+    AgentTokenBufferMemory,
 )
+from langchain.chat_models import ChatOpenAI
+from langchain.schema import SystemMessage, AIMessage, HumanMessage
+from langchain.prompts import MessagesPlaceholder
+from langsmith import Client
+
+client = Client()
+
+st.set_page_config(
+    page_title="Investor Education ChatChain",
+    page_icon="π",
+    layout="wide",
+    initial_sidebar_state="collapsed",
 )
 
+# Load API key
+api_key = os.environ["OPENAI_API_KEY"]
 
+#### sidebar section 1 ####
 
 site_options = {'US': 'vanguard_embeddings_US',
                 'AUS': 'vanguard_embeddings'}
 
 site_options_list = list(site_options.keys())
 
+site_radio = st.radio(
+    "Which Vanguard website location would you like to chat with?",
+    ('US', 'AUS'))
 
+@st.cache_data
 def load_vectorstore(site):
     '''load embeddings and vectorstore'''
 
     emb = HuggingFaceEmbeddings(model_name="all-mpnet-base-v2")
+
+    vectorstore = FAISS.load_local(site_options[site], emb)
 
+    return vectorstore.as_retriever(search_kwargs={"k": 4})
 
+tool = create_retriever_tool(
+    load_vectorstore(site_radio),
+    "search_vanguard_website",
+    "Searches and returns documents regarding the Vanguard website across US and AUS locations. The websites provide investment-related information to the user.")
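The vanguard_embeddings_US / vanguard_embeddings folders loaded above are prebuilt FAISS indexes. For reference, a minimal sketch of how such an index could be built with the same embedding model; the scraped_docs list and folder name here are illustrative assumptions, not part of this commit:

# Hypothetical sketch: building a FAISS index that load_vectorstore() can read back.
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores.faiss import FAISS

emb = HuggingFaceEmbeddings(model_name="all-mpnet-base-v2")
scraped_docs = ["Example page text scraped from the Vanguard US site."]  # placeholder data
index = FAISS.from_texts(scraped_docs, emb)
index.save_local("vanguard_embeddings_US")  # folder name matching site_options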
 
+tools = [tool]
+llm = ChatOpenAI(temperature=0, streaming=True, model="gpt-4")
+message = SystemMessage(
+    content=(
+        "You are a helpful chatbot who is tasked with answering questions about investments, using information that has been scraped from a website to answer the user's question accurately. "
+        "Do not use any information not provided in the website context. "
+        "Unless otherwise explicitly stated, it is probably fair to assume that questions are about investing and the Vanguard website content. "
+        "If there is any ambiguity, you probably assume they are about that."
+    )
+)
 
+prompt = OpenAIFunctionsAgent.create_prompt(
+    system_message=message,
+    extra_prompt_messages=[MessagesPlaceholder(variable_name="history")],
+)
+agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt)
+agent_executor = AgentExecutor(
+    agent=agent,
+    tools=tools,
+    verbose=True,
+    return_intermediate_steps=True,
+)
+memory = AgentTokenBufferMemory(llm=llm)
+starter_message = "Ask me anything about information on the Vanguard US/AUS websites!"
+if "messages" not in st.session_state or st.sidebar.button("Clear message history"):
+    st.session_state["messages"] = [AIMessage(content=starter_message)]
 
 
+def send_feedback(run_id, score):
+    client.create_feedback(run_id, "user_score", score=score)
 
 
+for msg in st.session_state.messages:
+    if isinstance(msg, AIMessage):
+        st.chat_message("assistant").write(msg.content)
+    elif isinstance(msg, HumanMessage):
+        st.chat_message("user").write(msg.content)
+    memory.chat_memory.add_message(msg)
 
+if prompt := st.chat_input(placeholder=starter_message):
+    st.chat_message("user").write(prompt)
+    with st.chat_message("assistant"):
+        st_callback = StreamlitCallbackHandler(st.container())
+        response = agent_executor(
+            {"input": prompt, "history": st.session_state.messages},
+            callbacks=[st_callback],
+            include_run_info=True,
+        )
+        st.session_state.messages.append(AIMessage(content=response["output"]))
+        st.write(response["output"])
+        memory.save_context({"input": prompt}, response)
+        st.session_state["messages"] = memory.buffer
+        run_id = response["__run"].run_id
 
+        col_blank, col_text, col1, col2 = st.columns([10, 2, 1, 1])
+        with col_text:
+            st.text("Feedback:")
+
+        # with col1:
+        #     st.button("👍", on_click=send_feedback, args=(run_id, 1))
+
+        # with col2:
+        #     st.button("👎", on_click=send_feedback, args=(run_id, 0))
 
 
 gr.Markdown("")
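The thumbs up/down feedback buttons above are left commented out. A minimal sketch of the wiring if they were enabled, reusing the send_feedback helper and the run_id captured from the agent response (illustrative, not part of this commit):

        # Hypothetical: enable the feedback buttons in the columns created above.
        with col1:
            st.button("👍", on_click=send_feedback, args=(run_id, 1))
        with col2:
            st.button("👎", on_click=send_feedback, args=(run_id, 0))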