araeyn commited on
Commit
95e34ed
1 Parent(s): 244b084

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +141 -0
  2. database.zip +3 -0
  3. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import asyncio
import json
import os
import zipfile

from websockets.server import serve

from langchain import hub
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_community.document_loaders import DirectoryLoader, TextLoader
from langchain_community.vectorstores import Chroma
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnablePassthrough
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
from langchain_huggingface.llms import HuggingFaceEndpoint

# First run only: unpack the bundled knowledge base next to the script.
# BUG FIX: replaced os.system("unzip database.zip") -- a shell-out that fails
# silently when the `unzip` binary is not installed -- with the stdlib
# zipfile module, which raises a clear error on a missing/corrupt archive.
if not os.path.isdir('database'):
    with zipfile.ZipFile("database.zip") as archive:
        archive.extractall()

# Load every .txt document from the unpacked database directory.
loader = DirectoryLoader('./database', glob="./*.txt", loader_cls=TextLoader)
documents = loader.load()

# Split documents into overlapping chunks so retrieval returns focused passages.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
texts = text_splitter.split_documents(documents)

persist_directory = 'db'

embedding = HuggingFaceEmbeddings()

# Build the vector store, persist it to disk, then reopen it from disk so the
# rest of the app always reads the persisted copy.
vectordb = Chroma.from_documents(documents=texts,
                                 embedding=embedding,
                                 persist_directory=persist_directory)

vectordb.persist()
vectordb = None

vectordb = Chroma(persist_directory=persist_directory,
                  embedding_function=embedding)
46
def format_docs(docs):
    """Concatenate the page text of *docs* into one string, separated by blank lines."""
    chunks = [doc.page_content for doc in docs]
    return "\n\n".join(chunks)
48
+
49
retriever = vectordb.as_retriever()

# NOTE(review): the original also built a plain RAG chain here from
# hub.pull("rlm/rag-prompt") | llm, but that `rag_chain` was unconditionally
# shadowed by the history-aware chain built later and never invoked, so it
# (and its network fetch at startup) has been removed as dead code.
llm = HuggingFaceEndpoint(repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1")

# Prompt that rewrites a follow-up question into a standalone question,
# using the chat history for context.
contextualize_q_system_prompt = """Given a chat history and the latest user question \
which might reference context in the chat history, formulate a standalone question \
which can be understood without the chat history. Do NOT answer the question, \
just reformulate it if needed and otherwise return it as is."""
contextualize_q_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", contextualize_q_system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)
history_aware_retriever = create_history_aware_retriever(
    llm, retriever, contextualize_q_prompt
)

# Answering prompt: answer only from the retrieved context, concisely.
qa_system_prompt = """You are an assistant for question-answering tasks. \
Use the following pieces of retrieved context to answer the question. \
If you don't know the answer, just say that you don't know. \
Use three sentences maximum and keep the answer concise.\

{context}"""
qa_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", qa_system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)

# In-memory chat histories, keyed by session id (the client's token).
store = {}
90
def get_session_history(session_id: str) -> BaseChatMessageHistory:
    """Return the chat history for *session_id*, creating an empty one on first use."""
    history = store.get(session_id)
    if history is None:
        history = ChatMessageHistory()
        store[session_id] = history
    return history
94
+
95
# Stuff retrieved documents into the QA prompt, chain it behind the
# history-aware retriever, and wrap the whole thing with per-session
# message history so follow-up questions work.
question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)

conversational_rag_chain = RunnableWithMessageHistory(
    rag_chain,
    get_session_history,
    input_messages_key="input",
    history_messages_key="chat_history",
    output_messages_key="answer",
)

print()
print("-------")
print("started")
print("-------")

# Warm-up / smoke-test query issued once at startup; the answer is unused.
# NOTE(review): this seeds chat history for session id "test" -- confirm
# that is intentional.
response = conversational_rag_chain.invoke(
    {"input": "who is the math teacher"},
    config={
        "configurable": {"session_id": "test"}
    },
)["answer"]
117
async def echo(websocket):
    """Serve one websocket client: answer RAG queries for each incoming message.

    Each frame is expected to be a JSON object with "message" (the user's
    question) and "token" (a session identifier).  Replies with a JSON
    object of the form {"response": <answer>}.
    """
    async for message in websocket:
        # BUG FIX: a malformed frame used to crash the handler (json.loads
        # raising) or close the whole connection (`return`); now we just
        # skip the bad frame and keep serving.
        try:
            data = json.loads(message)
        except json.JSONDecodeError:
            continue
        # BUG FIX: the original tested key membership on the raw JSON
        # *string* (`"message" in message`), not on the parsed dict.
        if "message" not in data or "token" not in data:
            continue
        m = data["message"]
        token = data["token"]
        # BUG FIX: the file was opened with mode "w", which truncates it and
        # makes json.load raise on the write-only handle.  Read it, record
        # the retrieved docs, then persist the update back to disk (the
        # original mutated the dict but never saved it).
        with open("userData.json") as f:
            userData = json.load(f)
        docs = retriever.get_relevant_documents(m)
        userData.setdefault(token, {})["docs"] = str(docs)
        with open("userData.json", "w") as f:
            json.dump(userData, f)
        response = conversational_rag_chain.invoke(
            {"input": m},
            config={
                "configurable": {"session_id": token}
            },
        )["answer"]
        await websocket.send(json.dumps({"response": response}))
136
+
137
async def main():
    """Run the websocket server on all interfaces, port 7860, until cancelled."""
    async with serve(echo, "0.0.0.0", 7860):
        # A bare Future never completes, so the server stays up for the
        # lifetime of the event loop.
        await asyncio.Future()

asyncio.run(main())
database.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:559d3ec60542ae76b8aebf3bffe3b8d8530b37d8fdab31411b0d6fc038d35ed9
3
+ size 528998
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ websockets
2
+ langchain
3
+ langchain-community
4
+ huggingface_hub
5
+ tiktoken
6
+ chromadb
7
+ langchain-huggingface
8
+ accelerate