nickmuchi committed on
Commit
a4d579b
·
1 Parent(s): 5c66512

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -177
app.py CHANGED
@@ -1,209 +1,125 @@
1
- from langchain.prompts.prompt import PromptTemplate
2
- from langchain.llms import OpenAIChat
3
- from langchain.chains import ConversationalRetrievalChain, LLMChain
4
- from langchain.chains.question_answering import load_qa_chain
5
- from langchain.embeddings import HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings
6
- from langchain.callbacks import StdOutCallbackHandler
7
- from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
8
- from langchain.vectorstores import FAISS
9
- from langchain.memory import ConversationBufferMemory
10
  import os
11
- from typing import Optional, Tuple
12
- import gradio as gr
13
- import pickle
14
- from threading import Lock
15
-
16
- from langchain.prompts.chat import (
17
- ChatPromptTemplate,
18
- SystemMessagePromptTemplate,
19
- AIMessagePromptTemplate,
20
- HumanMessagePromptTemplate,
 
21
  )
22
- from langchain.schema import (
23
- AIMessage,
24
- HumanMessage,
25
- SystemMessage
 
 
 
 
 
 
 
 
26
  )
27
 
28
- from langchain.prompts import PromptTemplate
 
29
 
30
- prefix_messages = [{"role": "system", "content": "You are a helpful assistant that is very good at answering questions about investments using the information given."}]
31
 
32
  site_options = {'US': 'vanguard_embeddings_US',
33
  'AUS': 'vanguard_embeddings'}
34
 
35
  site_options_list = list(site_options.keys())
36
 
37
- memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True, output_key='answer')
 
 
38
 
39
- def load_prompt():
40
-
41
- system_template="""Use only the following pieces of context that has been scraped from a website to answer the users question accurately.
42
- Do not use any information not provided in the website context.
43
- If you don't know the answer, just say 'There is no relevant answer in the Investor Website',
44
- don't try to make up an answer.
45
-
46
- ALWAYS return a "SOURCES" part in your answer.
47
- The "SOURCES" part should be a reference to the source of the document from which you got your answer.
48
-
49
- Remember, do not reference any information not given in the context.
50
- If the answer is not available in the given context just say 'There is no relevant answer in the website content'
51
-
52
- Follow the below format when answering:
53
-
54
- Question: {question}
55
- SOURCES: [xyz]
56
-
57
- Begin!
58
- ----------------
59
- {context}"""
60
-
61
- messages = [
62
- SystemMessagePromptTemplate.from_template(system_template),
63
- HumanMessagePromptTemplate.from_template("{question}")
64
- ]
65
- prompt = ChatPromptTemplate.from_messages(messages)
66
-
67
- return prompt
68
-
69
  def load_vectorstore(site):
70
  '''load embeddings and vectorstore'''
71
 
72
  emb = HuggingFaceEmbeddings(model_name="all-mpnet-base-v2")
 
 
73
 
74
- return FAISS.load_local(site_options[site], emb)
75
 
76
- #default embeddings and store
77
- vectorstore = load_vectorstore(site_options_list[0])
78
 
79
- def on_value_change(site):
80
- '''When radio changes, change the website reference data'''
81
-
82
- global vectorstore
83
- vectorstore = load_vectorstore(site)
84
-
85
- # vectorstore = load_vectorstore('vanguard-embeddings',sbert_emb)
86
-
87
- _template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question.
88
- You can assume the question about investing and the investment management industry.
89
- Chat History:
90
- {chat_history}
91
- Follow Up Input: {question}
92
- Standalone question:"""
93
- CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
94
-
95
- template = """You are an AI assistant for answering questions about investing and the investment management industry.
96
- You are given the following extracted parts of a long document and a question. Provide a conversational answer.
97
- If you don't know the answer, just say "Hmm, I'm not sure." Don't try to make up an answer.
98
- If the question is not about investing, politely inform them that you are tuned to only answer questions about investing and the investment management industry.
99
- Question: {question}
100
- =========
101
- {context}
102
- =========
103
- Answer in Markdown:"""
104
- QA_PROMPT = PromptTemplate(template=template, input_variables=["question", "context"])
105
-
106
-
107
- def get_chain(vectorstore):
108
- llm = OpenAIChat(streaming=True,
109
- callbacks=[StdOutCallbackHandler()],
110
- verbose=True,
111
- temperature=0,
112
- model_name='gpt-4')
113
-
114
- question_generator = LLMChain(llm=llm, prompt=CONDENSE_QUESTION_PROMPT)
115
- doc_chain = load_qa_chain(llm=llm,chain_type="stuff",prompt=load_prompt())
116
-
117
- chain = ConversationalRetrievalChain(retriever=vectorstore.as_retriever(search_kwags={"k": 4}),
118
- question_generator=question_generator,
119
- combine_docs_chain=doc_chain,
120
- memory=memory,
121
- return_source_documents=True,
122
- get_chat_history=lambda h :h)
123
 
124
-
125
- return chain
 
 
 
 
 
 
 
 
126
 
127
- def load_chain():
128
- chain = get_chain(vectorstore)
129
- return chain
130
-
131
-
132
- class ChatWrapper:
133
-
134
- def __init__(self):
135
- self.lock = Lock()
136
- def __call__(
137
- self, inp: str, history: Optional[Tuple[str, str]], chain
138
- ):
139
- """Execute the chat functionality."""
140
- self.lock.acquire()
141
- try:
142
- history = history or []
143
- # Set OpenAI key
144
- # chain = get_chain(vectorstore)
145
- # Run chain and append input.
146
- output = chain({"question": inp})["answer"]
147
- history.append((inp, output))
148
- except Exception as e:
149
- raise e
150
- finally:
151
- self.lock.release()
152
- return history, history
153
-
154
- block = gr.Blocks(css=".gradio-container {background-color: lightgray}")
155
-
156
- with block:
157
- with gr.Row():
158
- gr.Markdown("<h3><center>Chat-Your-Data (Investor Education)</center></h3>")
159
- embed_but = gr.Button(value='Step 1: Click Me to Load the QA System')
160
- with gr.Row():
161
- websites = gr.Radio(choices=site_options_list,value=site_options_list[0],label='Select US or AUS website data',
162
- interactive=True)
163
- websites.change(on_value_change, websites)
164
-
165
- vectorstore = load_vectorstore(websites.value)
166
-
167
- chatbot = gr.Chatbot()
168
 
169
- chat = ChatWrapper()
170
 
 
 
171
 
172
- with gr.Row():
173
- message = gr.Textbox(
174
- label="What's your question?",
175
- placeholder="Ask questions about Investing",
176
- lines=1,
177
- )
178
- submit = gr.Button(value="Send", variant="secondary").style(full_width=False)
179
-
180
- gr.Examples(
181
- examples=[
182
- "What are the benefits of investing in ETFs?",
183
- "What is the average cost of investing in a managed fund?",
184
- "At what age can I start investing?",
185
- "Do you offer investment accounts for kids?"
186
- ],
187
- inputs=message,
188
- )
189
 
190
- gr.HTML("Demo application of a LangChain chain.")
 
 
 
 
 
191
 
192
- gr.HTML(
193
- "<center>Powered by <a href='https://github.com/hwchase17/langchain'>LangChain 🦜️🔗</a></center>"
194
- )
195
 
196
- state = gr.State()
197
- agent_state = gr.State()
 
 
 
 
 
 
 
 
 
 
 
 
198
 
199
-
200
- submit.click(chat, inputs=[message, state, agent_state], outputs=[chatbot, state])
201
- message.submit(chat, inputs=[message, state, agent_state], outputs=[chatbot, state])
 
 
 
 
 
 
202
 
203
- embed_but.click(
204
- load_chain,
205
- outputs=[agent_state],
206
- )
207
 
208
  gr.Markdown("![](https://komarev.com/ghpvc/?username=nickmuchi87&style=flat-square)")
209
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ import streamlit as st
3
+
4
+ from langchain.embeddings import HuggingFaceInstructEmbeddings, HuggingFaceEmbeddings
5
+ from langchain.vectorstores.faiss import FAISS
6
+ from huggingface_hub import snapshot_download
7
+
8
+ from langchain.callbacks import StreamlitCallbackHandler
9
+ from langchain.agents import OpenAIFunctionsAgent, AgentExecutor
10
+ from langchain.agents.agent_toolkits import create_retriever_tool
11
+ from langchain.agents.openai_functions_agent.agent_token_buffer_memory import (
12
+ AgentTokenBufferMemory,
13
  )
14
+ from langchain.chat_models import ChatOpenAI
15
+ from langchain.schema import SystemMessage, AIMessage, HumanMessage
16
+ from langchain.prompts import MessagesPlaceholder
17
+ from langsmith import Client
18
+
19
+ client = Client()
20
+
21
+ st.set_page_config(
22
+ page_title="Investor Education ChatChain",
23
+ page_icon="📖",
24
+ layout="wide",
25
+ initial_sidebar_state="collapsed",
26
  )
27
 
28
+ #Load API Key
29
+ api_key = os.environ["OPENAI_API_KEY"]
30
 
31
+ #### sidebar section 1 ####
32
 
33
  site_options = {'US': 'vanguard_embeddings_US',
34
  'AUS': 'vanguard_embeddings'}
35
 
36
  site_options_list = list(site_options.keys())
37
 
38
+ site_radio = st.radio(
39
+ "Which Vanguard website location would you want to chat to?",
40
+ ('US', 'AUS'))
41
 
42
+ @st.cache_data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  def load_vectorstore(site):
44
  '''load embeddings and vectorstore'''
45
 
46
  emb = HuggingFaceEmbeddings(model_name="all-mpnet-base-v2")
47
+
48
+ vectorstore = FAISS.load_local(site_options[site], emb)
49
 
50
+ return vectorstore.as_retriever(search_kwargs={"k": 4})
51
 
 
 
52
 
53
+ tool = create_retriever_tool(
54
+ load_vectorstore(site_radio),
55
+ "search_vaguard_website",
56
+ "Searches and returns documents regarding the Vanguard website across US and UK locations. The websites provide investment related information to the user")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
+ tools = [tool]
59
+ llm = ChatOpenAI(temperature=0, streaming=True, model="gpt-4")
60
+ message = SystemMessage(
61
+ content=(
62
+ "You are a helpful chatbot who is tasked with answering questions about investments using informationn that has been scraped from a website to answer the users question accurately."
63
+ "Do not use any information not provided in the website context."
64
+ "Unless otherwise explicitly stated, it is probably fair to assume that questions are about the CFA program and materials. "
65
+ "If there is any ambiguity, you probably assume they are about that."
66
+ )
67
+ )
68
 
69
+ prompt = OpenAIFunctionsAgent.create_prompt(
70
+ system_message=message,
71
+ extra_prompt_messages=[MessagesPlaceholder(variable_name="history")],
72
+ )
73
+ agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt)
74
+ agent_executor = AgentExecutor(
75
+ agent=agent,
76
+ tools=tools,
77
+ verbose=True,
78
+ return_intermediate_steps=True,
79
+ )
80
+ memory = AgentTokenBufferMemory(llm=llm)
81
+ starter_message = "Ask me anything about information on the Vanguard US/UK websites!"
82
+ if "messages" not in st.session_state or st.sidebar.button("Clear message history"):
83
+ st.session_state["messages"] = [AIMessage(content=starter_message)]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
 
85
 
86
+ def send_feedback(run_id, score):
87
+ client.create_feedback(run_id, "user_score", score=score)
88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
+ for msg in st.session_state.messages:
91
+ if isinstance(msg, AIMessage):
92
+ st.chat_message("assistant").write(msg.content)
93
+ elif isinstance(msg, HumanMessage):
94
+ st.chat_message("user").write(msg.content)
95
+ memory.chat_memory.add_message(msg)
96
 
 
 
 
97
 
98
+ if prompt := st.chat_input(placeholder=starter_message):
99
+ st.chat_message("user").write(prompt)
100
+ with st.chat_message("assistant"):
101
+ st_callback = StreamlitCallbackHandler(st.container())
102
+ response = agent_executor(
103
+ {"input": prompt, "history": st.session_state.messages},
104
+ callbacks=[st_callback],
105
+ include_run_info=True,
106
+ )
107
+ st.session_state.messages.append(AIMessage(content=response["output"]))
108
+ st.write(response["output"])
109
+ memory.save_context({"input": prompt}, response)
110
+ st.session_state["messages"] = memory.buffer
111
+ run_id = response["__run"].run_id
112
 
113
+ col_blank, col_text, col1, col2 = st.columns([10, 2, 1, 1])
114
+ with col_text:
115
+ st.text("Feedback:")
116
+
117
+ # with col1:
118
+ # st.button("👍", on_click=send_feedback, args=(run_id, 1))
119
+
120
+ # with col2:
121
+ # st.button("👎", on_click=send_feedback, args=(run_id, 0)
122
 
 
 
 
 
123
 
124
  gr.Markdown("![](https://komarev.com/ghpvc/?username=nickmuchi87&style=flat-square)")
125