RubenAMtz committed
Commit dd49b84 · Parent(s): cdda8d7

added chat memory and fixed sys_message bugs

Files changed (3):
  1. app.py +32 -88
  2. requirements.txt +2 -1
  3. utils/chain.py +21 -3
app.py CHANGED
@@ -2,7 +2,6 @@
 
 # OpenAI Chat completion
 import os
-from openai import AsyncOpenAI  # importing openai for API usage
 import chainlit as cl  # importing chainlit for our app
 from chainlit.prompt import Prompt, PromptMessage  # importing prompt tools
 from chainlit.playground.providers import ChatOpenAI  # importing ChatOpenAI tools
@@ -18,18 +17,14 @@ from utils.store import index_documents
 from utils.chain import create_chain
 from langchain.vectorstores import Pinecone
 from langchain.chat_models import ChatOpenAI
-from langchain.prompts import ChatPromptTemplate
-from langchain.prompts import PromptTemplate
-from operator import itemgetter
 from langchain.schema.runnable import RunnableSequence
 from langchain.schema import format_document
-from langchain.schema.output_parser import StrOutputParser
-from langchain.prompts.prompt import PromptTemplate
 from pprint import pprint
-from langchain_core.documents.base import Document
 from langchain_core.vectorstores import VectorStoreRetriever
 import langchain
 from langchain.cache import InMemoryCache
+from langchain_core.messages.human import HumanMessage
+from langchain.memory import ConversationBufferMemory
 
 load_dotenv()
 YOUR_API_KEY = os.environ["PINECONE_API_KEY"]
@@ -97,11 +92,16 @@ async def start_chat():
     # log data in WaB (on start)
     os.environ["LANGCHAIN_WANDB_TRACING"] = "true"
 
+    # setup memory
+
+    memory = ConversationBufferMemory(memory_key="chat_history")
+
     tools = {
         "arxiv_client": arxiv_client,
         "index": index,
         "embedder": embedder,
-        "llm": llm
+        "llm": llm,
+        "memory": memory
     }
     cl.user_session.set("tools", tools)
     cl.user_session.set("settings", settings)
@@ -111,18 +111,23 @@ async def start_chat():
 @cl.on_message  # marks a function that should be run each time the chatbot receives a message from a user
 async def main(message: cl.Message):
     settings = cl.user_session.get("settings")
-    tools = cl.user_session.get("tools")
+    tools: dict = cl.user_session.get("tools")
     first_run = cl.user_session.get("first_run")
+    retrieval_augmented_qa_chain = cl.user_session.get("chain", None)
+    memory: ConversationBufferMemory = cl.user_session.get("memory")
 
+    sys_message = cl.Message(content="")
+    await sys_message.send()  # renders a loader
+
     if not first_run:
 
         arxiv_client: arxiv.Client = tools['arxiv_client']
         index: pinecone.GRPCIndex = tools['index']
         embedder: CacheBackedEmbeddings = tools['embedder']
         llm: ChatOpenAI = tools['llm']
+        memory: ConversationBufferMemory = tools['memory']
 
         # using query search for ArXiv documents (on message)
-
         search = arxiv.Search(
             query = message.content,
             max_results = 10,
@@ -130,18 +135,10 @@ async def main(message: cl.Message):
         )
         paper_urls = []
 
-        sys_message = cl.Message(content="")
-        await sys_message.send()  # renders a loader
+
         for result in arxiv_client.results(search):
             paper_urls.append(result.pdf_url)
-            sys_message.content = """
-            I found some papers, let me study them real quick to help
-            you learn, don't worry it'll be a few seconds 😉"""
-            await sys_message.update()
-            await sys_message.send()
 
-        sys_message = cl.Message(content="")
-        await sys_message.send()  # renders a loader
         # load them and split them (on message)
         docs = []
         for paper_url in paper_urls:
@@ -159,9 +156,6 @@
 
         # create an index using pinecone (on message)
         index_documents(docs, text_splitter, embedder, index)
-        sys_message.content = "Done studying :)"
-        await sys_message.update()
-        await sys_message.send()
 
         text_field = "source_document"
         index = pinecone.Index(INDEX_NAME)
@@ -174,74 +168,24 @@
 
         # create the chain (on message)
         retrieval_augmented_qa_chain: RunnableSequence = create_chain(retriever=retriever, llm=llm)
+        cl.user_session.set("chain", retrieval_augmented_qa_chain)
 
-        # message.content = await cl.AskUserMessage(
-        #     content="Ask away"
-        # ).send()
-
+        sys_message.content = """
+        I found some papers and studied them 😉 \n"""
+        await sys_message.update()
+
     # run
-    msg = cl.Message(content="")
-    for chunk in retrieval_augmented_qa_chain.stream({"question": f"{message.content}"}):
+    for chunk in retrieval_augmented_qa_chain.stream({"question": f"{message.content}", "chat_history": memory.buffer_as_messages}):
         pprint(chunk)
         if res:= chunk.get('response'):
-            await msg.stream_token(res.content)
-    await msg.send()
-    cl.user_session.set("first_run", True)
-    # first_run = True
+            await sys_message.stream_token(res.content)
+    await sys_message.send()
 
-
-    # client = AsyncOpenAI()
-
-    # print(message.content)
-
-    # results_list = vector_db.search_by_text(query_text=message.content, k=3, return_as_text=True)
-    # if results_list:
-    #     results_string = "\n\n".join(results_list)
-    # else:
-    #     results_string = ""
-
-    # prompt = Prompt(
-    #     provider=ChatOpenAI.id,
-    #     messages=[
-    #         PromptMessage(
-    #             role="system",
-    #             template=system_template,
-    #             formatted=system_template,
-    #         ),
-    #         PromptMessage(
-    #             role="user",
-    #             template=user_template,
-    #             formatted=user_template.format(input=message.content),
-    #         ),
-    #         PromptMessage(
-    #             role="assistant",
-    #             template=assistant_template,
-    #             formatted=assistant_template.format(context=results_string)
-    #         )
-    #     ],
-    #     inputs={
-    #         "input": message.content,
-    #         "context": results_string
-    #     },
-    #     settings=settings,
-    # )
-
-    # print([m.to_openai() for m in prompt.messages])
-
-    # msg = cl.Message(content="")
-
-    # # Call OpenAI
-    # async for stream_resp in await client.chat.completions.create(
-    #     messages=[m.to_openai() for m in prompt.messages], stream=True, **settings
-    # ):
-    #     token = stream_resp.choices[0].delta.content
-    #     if not token:
-    #         token = ""
-    #     await msg.stream_token(token)
-
-    # # Update the prompt object with the completion
-    # prompt.completion = msg.content
-    # msg.prompt = prompt
-
-    # # Send and close the message stream
-    # await msg.send()
+    memory.chat_memory.add_user_message(message.content)
+    memory.chat_memory.add_ai_message(sys_message.content)
+
+    print(memory.buffer_as_str)
+
+
+    cl.user_session.set("memory", memory)
+    cl.user_session.set("first_run", True)
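For reference, a minimal sketch of the memory round-trip app.py now performs, assuming the `ConversationBufferMemory` API imported in the diff above; the message strings are placeholder turns.

```python
from langchain.memory import ConversationBufferMemory

memory = ConversationBufferMemory(memory_key="chat_history")

# After each turn, persist both sides of the exchange (placeholder text).
memory.chat_memory.add_user_message("What does the paper say about attention?")
memory.chat_memory.add_ai_message("It introduces multi-head self-attention ...")

# buffer_as_messages yields HumanMessage/AIMessage objects, the shape that
# MessagesPlaceholder(variable_name="chat_history") expects in utils/chain.py.
history = memory.buffer_as_messages
print(memory.buffer_as_str)  # flat transcript form, printed above for debugging
```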
 
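The sys_message fix follows the Chainlit message lifecycle: send a message once, then mutate it. A minimal sketch under that assumption (the tokens are stand-ins for chain output chunks):

```python
import chainlit as cl

@cl.on_message
async def main(message: cl.Message):
    # One message object per reply: send() once, then update()/stream_token().
    msg = cl.Message(content="")
    await msg.send()  # renders a loader until content arrives

    msg.content = "I found some papers and studied them 😉"
    await msg.update()  # edits the same bubble in place; no second send() here

    for token in ("streamed ", "answer"):  # stand-in for LLM chunks
        await msg.stream_token(token)  # appends tokens to the same bubble
    await msg.send()  # finalize the streamed message once
```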
requirements.txt CHANGED
@@ -6,4 +6,5 @@ python-dotenv==1.0.0
 numpy==1.25.2
 langchain
 pinecone-client[grpc]
-pypdf
+pypdf
+arxiv
utils/chain.py CHANGED
@@ -2,10 +2,12 @@ from operator import itemgetter
 from langchain_core.vectorstores import VectorStoreRetriever
 from langchain.schema.runnable import RunnableLambda, RunnableParallel, RunnableSequence
 from langchain.chat_models import ChatOpenAI
-from langchain.prompts import PromptTemplate
+from langchain.prompts import PromptTemplate, ChatPromptTemplate, MessagesPlaceholder
 from langchain_core.documents import Document
 from langchain_core.messages.ai import AIMessage
-
+from langchain_core.messages.human import HumanMessage
+from langchain_core.messages.system import SystemMessage
+from langchain_core.messages.function import FunctionMessage
 
 template = """
 You are a helpful assistant, your job is to answer the user's question using the relevant context.
@@ -16,7 +18,21 @@ CONTEXT
 
 User question: {question}
 """
+
 prompt = PromptTemplate.from_template(template=template)
+chat_prompt = ChatPromptTemplate.from_messages([
+    ("system", """
+    You are a helpful assistant, your job is to answer the user's question using the relevant context:
+    =========
+    CONTEXT:
+    {context}
+    =========
+    """),
+    MessagesPlaceholder(variable_name="chat_history"),
+    ("human", "{question}")
+])
+
+
 
 
 def to_doc(input: AIMessage) -> list[Document]:
@@ -46,7 +62,7 @@ def create_chain(**kwargs) -> RunnableSequence:
 
     docs_chain = (itemgetter("question") | retriever).with_config(config={"run_name": "docs"})
     self_knowledge_chain = (itemgetter("question") | llm | to_doc).with_config(config={"run_name": "self knowledge"})
-    response_chain = (prompt | llm).with_config(config={"run_name": "response"})
+    response_chain = (chat_prompt | llm).with_config(config={"run_name": "response"})
     merge_docs_link = RunnableLambda(merge_docs).with_config(config={"run_name": "merge docs"})
     context_chain = (
         RunnableParallel(
@@ -61,11 +77,13 @@
     retrieval_augmented_qa_chain = (
         RunnableParallel({
             "question": itemgetter("question"),
+            "chat_history": itemgetter("chat_history"),
             "context": context_chain
         })
         | RunnableParallel({
             "response": response_chain,
             "context": itemgetter("context"),
+            "chat_history": itemgetter("chat_history")
         })
     )
     return retrieval_augmented_qa_chain
 
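And a minimal sketch of how the new `chat_prompt` consumes `chat_history`, isolated from the retrieval pieces; the context string and history messages here are stand-ins.

```python
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages.ai import AIMessage
from langchain_core.messages.human import HumanMessage

chat_prompt = ChatPromptTemplate.from_messages([
    ("system", "Answer the user's question using this context:\n{context}"),
    MessagesPlaceholder(variable_name="chat_history"),
    ("human", "{question}"),
])

# format_messages splices the history between the system and human turns,
# which is how create_chain's response_chain sees prior turns.
messages = chat_prompt.format_messages(
    context="(retrieved arXiv chunks would go here)",
    chat_history=[
        HumanMessage(content="What is attention?"),
        AIMessage(content="A mechanism that weights tokens by relevance ..."),
    ],
    question="And multi-head attention?",
)
for m in messages:
    print(type(m).__name__, "->", m.content[:60])
```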