araeyn committed
Commit 7b591d9
1 Parent(s): 647b731

Update app.py

Files changed (1)
  1. app.py +26 -33
app.py CHANGED
@@ -2,7 +2,7 @@ import asyncio
 import json
 from websockets.server import serve
 import os
-from langchain_community.vectorstores import Chroma
+from langchain_chroma import Chroma
 from langchain_huggingface.embeddings import HuggingFaceEmbeddings
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_huggingface.llms import HuggingFaceEndpoint
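
Note: the Chroma integration moved out of langchain_community.vectorstores into the standalone langchain-chroma package, so this one-line swap tracks the supported import path and silences the deprecation warning. A minimal sketch, assuming langchain-chroma is added to the Space's requirements (the requirements file is not part of this diff):

    # requirements.txt would need a line: langchain-chroma
    from langchain_chroma import Chroma  # replaces langchain_community.vectorstores.Chroma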
@@ -18,24 +18,22 @@ from langchain.chains.combine_documents import create_stuff_documents_chain
 from langchain_core.runnables.history import RunnableWithMessageHistory
 from langchain_core.chat_history import BaseChatMessageHistory
 from langchain_community.chat_message_histories import ChatMessageHistory
+from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
 
 if not os.path.isdir('database'):
     os.system("unzip database.zip")
 
-clean_up_tokenization_spaces = True
-
 loader = DirectoryLoader('./database', glob="./*.txt", loader_cls=TextLoader)
 
 documents = loader.load()
 
 text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
-texts = text_splitter.split_documents(documents)
+splits = text_splitter.split_documents(documents)
 
 print()
 print("-------")
 print("TextSplitter, DirectoryLoader")
 print("-------")
-print("--")
 
 persist_directory = 'db'
 
@@ -45,34 +43,14 @@ print()
 print("-------")
 print("Embeddings")
 print("-------")
-print("--")
-
-vectordb = Chroma.from_documents(documents=texts,
-                                 embedding=embedding,
-                                 persist_directory=persist_directory)
-
-vectordb.persist()
-vectordb = None
-
-print()
-print("-------")
-print("Chroma1")
-print("-------")
-print("--")
 
-vectordb = Chroma(persist_directory=persist_directory,
-                  embedding_function=embedding)
-
-print()
-print("-------")
-print("Chroma2")
-print("-------")
-print("--")
+vectorstore = Chroma.from_documents(documents=splits, embedding=embedding)
 
 def format_docs(docs):
     return "\n\n".join(doc.page_content for doc in docs)
 
-retriever = vectordb.as_retriever()
+retriever = vectorstore.as_retriever()
+
 prompt = hub.pull("rlm/rag-prompt")
 llm = HuggingFaceEndpoint(repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1")
 rag_chain = (
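
This hunk collapses the old persist-then-reload sequence (build, vectordb.persist(), then reopen Chroma from persist_directory) into a single in-memory store built on each startup; persist_directory = 'db' is now unused. If on-disk persistence were still wanted, a sketch with langchain_chroma, which persists automatically when given a directory (this class has no explicit .persist() method):

    vectorstore = Chroma.from_documents(
        documents=splits,
        embedding=embedding,
        persist_directory=persist_directory,  # writes to ./db as documents are embedded
    )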
@@ -86,8 +64,8 @@ print()
 print("-------")
 print("Retriever, Prompt, LLM, Rag_Chain")
 print("-------")
-print("--")
 
+### Contextualize question ###
 contextualize_q_system_prompt = """Given a chat history and the latest user question \
 which might reference context in the chat history, formulate a standalone question \
 which can be understood without the chat history. Do NOT answer the question, \
@@ -103,6 +81,8 @@ history_aware_retriever = create_history_aware_retriever(
     llm, retriever, contextualize_q_prompt
 )
 
+
+### Answer question ###
 qa_system_prompt = """You are an assistant for question-answering tasks. \
 Use the following pieces of retrieved context to answer the question. \
 If you don't know the answer, just say that you don't know. \
@@ -116,16 +96,20 @@ qa_prompt = ChatPromptTemplate.from_messages(
         ("human", "{input}"),
     ]
 )
+question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
+
+rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
+
 
+### Statefully manage chat history ###
 store = {}
 
+
 def get_session_history(session_id: str) -> BaseChatMessageHistory:
     if session_id not in store:
         store[session_id] = ChatMessageHistory()
     return store[session_id]
 
-question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
-rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
 
 conversational_rag_chain = RunnableWithMessageHistory(
     rag_chain,
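
RunnableWithMessageHistory keys conversations on session_id: get_session_history lazily creates one ChatMessageHistory per id in the module-level store dict, so each websocket client could get its own thread of memory. A usage sketch (the question is the one the old smoke test used; the session ids are illustrative):

    answer = conversational_rag_chain.invoke(
        {"input": "who is the math teacher"},
        config={"configurable": {"session_id": "alice"}},
    )["answer"]
    # A different session_id starts from an empty history:
    answer = conversational_rag_chain.invoke(
        {"input": "who is the math teacher"},
        config={"configurable": {"session_id": "bob"}},
    )["answer"]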
@@ -140,11 +124,20 @@ print("started")
140
  print("-------")
141
 
142
  response = conversational_rag_chain.invoke(
143
- {"input": "who is the math teacher"},
 
 
 
 
 
 
 
 
144
  config={
145
  "configurable": {"session_id": "test"}
146
  },
147
  )["answer"]
 
148
 
149
  async def echo(websocket):
150
  async for message in websocket:
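
One caveat with this hunk: both input() calls run at module import time, before the websocket server starts, and a headless Space has no usable stdin, so startup can raise EOFError or hang. A sketch of gating the interactive smoke test behind an environment variable (INTERACTIVE_TEST is hypothetical, not part of this commit):

    if os.environ.get("INTERACTIVE_TEST"):
        response = conversational_rag_chain.invoke(
            {"input": input()},
            config={"configurable": {"session_id": "test"}},
        )["answer"]
        print(response)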
@@ -170,4 +163,4 @@ async def main():
     async with serve(echo, "0.0.0.0", 7860):
         await asyncio.Future()
 
-asyncio.run(main())
+asyncio.run(main())