Update app.py
app.py (CHANGED)
@@ -2,7 +2,7 @@ import asyncio
 import json
 from websockets.server import serve
 import os
-from
+from langchain_chroma import Chroma
 from langchain_huggingface.embeddings import HuggingFaceEmbeddings
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_huggingface.llms import HuggingFaceEndpoint
@@ -18,24 +18,22 @@ from langchain.chains.combine_documents import create_stuff_documents_chain
 from langchain_core.runnables.history import RunnableWithMessageHistory
 from langchain_core.chat_history import BaseChatMessageHistory
 from langchain_community.chat_message_histories import ChatMessageHistory
+from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
 
 if not os.path.isdir('database'):
     os.system("unzip database.zip")
 
-clean_up_tokenization_spaces = True
-
 loader = DirectoryLoader('./database', glob="./*.txt", loader_cls=TextLoader)
 
 documents = loader.load()
 
 text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
-
+splits = text_splitter.split_documents(documents)
 
 print()
 print("-------")
 print("TextSplitter, DirectoryLoader")
 print("-------")
-print("--")
 
 persist_directory = 'db'
 
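The key fix in this hunk: the old file configured the splitter but never ran it, and set a stray `clean_up_tokenization_spaces = True` that nothing consumed. The new `splits = text_splitter.split_documents(documents)` line produces the chunks the vector store needs. For reference, a minimal standalone sketch of what the splitter does with these settings (the sample document below is illustrative only):

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.documents import Document

# 1000-character chunks with 200 characters of overlap between
# consecutive chunks, matching the settings in app.py.
splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
docs = [Document(page_content="some long passage of text " * 200)]
chunks = splitter.split_documents(docs)
print(len(chunks), max(len(c.page_content) for c in chunks))  # several chunks, each <= 1000 chars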
@@ -45,34 +43,14 @@ print()
 print("-------")
 print("Embeddings")
 print("-------")
-print("--")
-
-vectordb = Chroma.from_documents(documents=texts,
-                                 embedding=embedding,
-                                 persist_directory=persist_directory)
-
-vectordb.persist()
-vectordb = None
-
-print()
-print("-------")
-print("Chroma1")
-print("-------")
-print("--")
 
-
-embedding_function=embedding)
-
-print()
-print("-------")
-print("Chroma2")
-print("-------")
-print("--")
+vectorstore = Chroma.from_documents(documents=splits, embedding=embedding)
 
 def format_docs(docs):
     return "\n\n".join(doc.page_content for doc in docs)
 
-retriever =
+retriever = vectorstore.as_retriever()
+
 prompt = hub.pull("rlm/rag-prompt")
 llm = HuggingFaceEndpoint(repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1")
 rag_chain = (
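This hunk replaces the broken persist-and-reload logic (note the old file's orphaned `embedding_function=embedding)` fragment and the dangling `retriever =`) with a single in-memory index. One consequence: `persist_directory = 'db'` is now unused, and the corpus is re-embedded on every start. If persistence is still wanted, a sketch using the langchain_chroma API, reusing `splits`, `embedding`, and `persist_directory` from the surrounding code:

from langchain_chroma import Chroma

# First run: build the index and write it under ./db. langchain_chroma
# persists automatically; the old explicit vectordb.persist() is not needed.
vectorstore = Chroma.from_documents(
    documents=splits,
    embedding=embedding,
    persist_directory=persist_directory,
)

# Later runs: reload the saved index instead of re-embedding everything.
vectorstore = Chroma(
    persist_directory=persist_directory,
    embedding_function=embedding,
)
retriever = vectorstore.as_retriever()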
@@ -86,8 +64,8 @@ print()
 print("-------")
 print("Retriever, Prompt, LLM, Rag_Chain")
 print("-------")
-print("--")
 
+### Contextualize question ###
 contextualize_q_system_prompt = """Given a chat history and the latest user question \
 which might reference context in the chat history, formulate a standalone question \
 which can be understood without the chat history. Do NOT answer the question, \
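The construction of `contextualize_q_prompt` sits in the unchanged region between this hunk and the next, so the diff never shows it. Given the `create_history_aware_retriever(llm, retriever, contextualize_q_prompt)` call below, it presumably follows the standard LangChain pattern; a sketch of the assumed shape:

from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

# Assumed shape of the elided prompt (standard history-aware retriever setup).
contextualize_q_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", contextualize_q_system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)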
@@ -103,6 +81,8 @@ history_aware_retriever = create_history_aware_retriever(
     llm, retriever, contextualize_q_prompt
 )
 
+
+### Answer question ###
 qa_system_prompt = """You are an assistant for question-answering tasks. \
 Use the following pieces of retrieved context to answer the question. \
 If you don't know the answer, just say that you don't know. \
@@ -116,16 +96,20 @@ qa_prompt = ChatPromptTemplate.from_messages(
         ("human", "{input}"),
     ]
 )
+question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
+
+rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
+
 
+### Statefully manage chat history ###
 store = {}
 
+
 def get_session_history(session_id: str) -> BaseChatMessageHistory:
     if session_id not in store:
         store[session_id] = ChatMessageHistory()
     return store[session_id]
 
-question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
-rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
 
 conversational_rag_chain = RunnableWithMessageHistory(
     rag_chain,
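This hunk moves the chain construction above `conversational_rag_chain`, fixing a use-before-definition bug: the old file built `RunnableWithMessageHistory(rag_chain, ...)` before `rag_chain = create_retrieval_chain(...)` had run. The hunk is also cut off inside the `RunnableWithMessageHistory(...)` call, so the keyword arguments wiring `get_session_history` into the chain are not visible. Since the app invokes the chain with an "input" key and reads the "answer" key, the elided arguments are presumably the standard ones; a sketch of the assumed call:

conversational_rag_chain = RunnableWithMessageHistory(
    rag_chain,
    get_session_history,                 # defined above; one history per session_id
    input_messages_key="input",
    history_messages_key="chat_history",
    output_messages_key="answer",
)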
@@ -140,11 +124,20 @@ print("started")
 print("-------")
 
 response = conversational_rag_chain.invoke(
-    {"input":
+    {"input": input()},
+    config={
+        "configurable": {"session_id": "test"}
+    },
+)["answer"]
+print(response)
+
+response = conversational_rag_chain.invoke(
+    {"input": input()},
     config={
         "configurable": {"session_id": "test"}
     },
 )["answer"]
+print(response)
 
 async def echo(websocket):
     async for message in websocket:
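Two caveats on this hunk. First, both new `input()`-driven invocations run at module import time, before the websocket server starts; on a headless Space there is no interactive stdin, so these calls will likely block or raise EOFError. Second, the body of `echo` is truncated by the diff. A hypothetical handler that routes each websocket message through the chain instead (the JSON message shape here is an assumption, not the app's actual protocol):

async def echo(websocket):
    async for message in websocket:
        # Assumed message shape: {"input": "...", "session_id": "..."}
        data = json.loads(message)
        answer = (await conversational_rag_chain.ainvoke(
            {"input": data["input"]},
            config={"configurable": {"session_id": data.get("session_id", "test")}},
        ))["answer"]
        await websocket.send(json.dumps({"answer": answer}))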
@@ -170,4 +163,4 @@ async def main():
     async with serve(echo, "0.0.0.0", 7860):
         await asyncio.Future()
 
-asyncio.run(main())
+asyncio.run(main())