# schoolQuest / app.py
# Last commit 9c90141 (verified) by araeyn: "Update app.py" (5.64 kB)
# Debug marker emitted before any of the third-party imports below execute.
print("eeeh", end="\n")
import asyncio
import json
from websockets.server import serve
import os
from langchain_chroma import Chroma
from langchain_community.embeddings import *
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface.llms import HuggingFaceEndpoint
from langchain_community.document_loaders import TextLoader
from langchain_community.document_loaders import DirectoryLoader
from langchain import hub
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain.chains import create_history_aware_retriever
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_community.chat_message_histories import ChatMessageHistory
from multiprocessing import Process
# Progress banner: a blank line, then "started" framed by dashes.
print("\n-------\nstarted\n-------")
# Per-client scratch data keyed by the request "token"; echo() stores the
# documents behind the latest answer for that token under the "docs" key.
userData = {}


async def echo(websocket):
    """Answer chat questions arriving on one websocket connection.

    Each incoming frame must be a JSON object with string fields
    ``"message"`` (the user's question) and ``"token"`` (the chat-session id;
    conversation history is kept per token). Every valid frame gets a reply
    ``{"response": <answer>}``. The first frame that is not valid JSON, is not
    a JSON object, or lacks either string field ends the handler, which closes
    the connection.

    Relies on the module global ``conversational_rag_chain`` having been
    built before the first valid frame arrives.
    """
    loop = asyncio.get_running_loop()
    async for message in websocket:
        try:
            data = json.loads(message)
        except json.JSONDecodeError:
            return
        # Inspect the parsed object, not the raw text: a substring test on the
        # frame would accept e.g. {"text": "message token"} and then crash.
        if not isinstance(data, dict):
            return
        m = data.get("message")
        token = data.get("token")
        if not isinstance(m, str) or not isinstance(token, str):
            return
        # The chain is synchronous (embedding + remote LLM call). Run it in a
        # worker thread so the event loop keeps serving other connections.
        result = await loop.run_in_executor(
            None,
            conversational_rag_chain.invoke,
            {"input": m},
            {"configurable": {"session_id": token}},
        )
        # The retrieval chain already returns the documents it grounded the
        # answer on under "context"; no second retrieval is needed.
        userData.setdefault(token, {})["docs"] = str(result["context"])
        await websocket.send(json.dumps({"response": result["answer"]}))
async def main():
    """Serve ``echo`` on 0.0.0.0:7860 and never return."""
    loop = asyncio.get_running_loop()
    async with serve(echo, "0.0.0.0", 7860):
        # Awaiting a future that nobody resolves keeps the server context
        # (and with it the listening socket) open indefinitely.
        await loop.create_future()
def g():
    """Build the conversational RAG pipeline and publish it as module globals.

    Steps:
      1. Extract ``database.zip`` into the working directory if the
         ``database`` directory does not exist yet.
      2. Load the ``*.txt`` files in ``./database`` and split them into
         chunks of up to 1000 characters with 200 characters of overlap.
      3. Embed the chunks with ``BAAI/bge-large-en`` on the CPU and index them
         in a Chroma vector store.
      4. Chain a history-aware retriever and a "stuff documents" QA chain
         around the Mixtral-8x7B-Instruct Hugging Face endpoint, keeping chat
         history per session id.

    Side effects: assigns the module globals ``retriever`` and
    ``conversational_rag_chain``; the websocket handler looks them up there.

    Returns:
        The ``RunnableWithMessageHistory`` that answers questions.

    Raises:
        shutil.ReadError: if ``database`` is missing and ``database.zip`` is
            missing or not a valid zip archive.
    """
    global retriever, conversational_rag_chain
    import shutil

    if not os.path.isdir('database'):
        # Unlike os.system("unzip ..."), this fails loudly instead of leaving
        # the loader below with nothing to read.
        shutil.unpack_archive("database.zip")
    loader = DirectoryLoader('./database', glob="./*.txt", loader_cls=TextLoader)
    documents = loader.load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    splits = text_splitter.split_documents(documents)
    print()
    print("-------")
    print("TextSplitter, DirectoryLoader")
    print("-------")

    embedding = HuggingFaceBgeEmbeddings(
        model_name="BAAI/bge-large-en",
        model_kwargs={'device': 'cpu'},
        # Emit unit-length vectors.
        encode_kwargs={'normalize_embeddings': True},
        show_progress=True,
    )
    print()
    print("-------")
    print("Embeddings")
    print("-------")

    # No persist_directory is passed, so the index is rebuilt on every call.
    vectorstore = Chroma.from_documents(documents=splits, embedding=embedding)
    retriever = vectorstore.as_retriever()
    llm = HuggingFaceEndpoint(repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1")
    print()
    print("-------")
    print("Retriever, Prompt, LLM, Rag_Chain")
    print("-------")

    ### Contextualize question ###
    # Rewrites a follow-up question into a standalone one before retrieval.
    contextualize_q_system_prompt = """Given a chat history and the latest user question \
which might reference context in the chat history, formulate a standalone question \
which can be understood without the chat history. Do NOT answer the question, \
just reformulate it if needed and otherwise return it as is."""
    contextualize_q_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", contextualize_q_system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ]
    )
    history_aware_retriever = create_history_aware_retriever(
        llm, retriever, contextualize_q_prompt
    )

    ### Answer question ###
    # The blank line keeps the retrieved context from being glued onto the
    # last instruction sentence.
    qa_system_prompt = """You are an assistant for question-answering tasks. \
Use the following pieces of retrieved context to answer the question. \
If you don't know the answer, just say that you don't know. \
Use three sentences maximum and keep the answer concise.

{context}"""
    qa_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", qa_system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ]
    )
    question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
    rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)

    ### Statefully manage chat history ###
    # In-memory histories, one per session id; lost when the process exits.
    store = {}

    def get_session_history(session_id: str) -> BaseChatMessageHistory:
        """Return the history for ``session_id``, creating it on first use."""
        if session_id not in store:
            store[session_id] = ChatMessageHistory()
        return store[session_id]

    conversational_rag_chain = RunnableWithMessageHistory(
        rag_chain,
        get_session_history,
        input_messages_key="input",
        history_messages_key="chat_history",
        output_messages_key="answer",
    )
    return conversational_rag_chain


def f():
    """Run the websocket server in this process; blocks until it is killed."""
    asyncio.run(main())


if __name__ == "__main__":
    # Build the pipeline first, then serve from the SAME process: echo() reads
    # the chain from this module's globals, and a separate
    # multiprocessing.Process has its own memory that would never see what
    # g() builds.
    g()
    f()
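"""
WebSocket protocol (server listens on port 7860):
    Streamlit app ~> backend: opens a websocket connection
    Streamlit app ~> backend: {"token": "random", "message": "what is something"}
    backend ~> Streamlit app: {"response": "something is something"}
    Streamlit app: displays the response
"token" identifies the chat session; conversation history is kept per token.
"""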