import asyncio
import json
import os
from multiprocessing import Process

from websockets.server import serve

from langchain import hub
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_community.document_loaders import DirectoryLoader, TextLoader
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnablePassthrough
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_huggingface.llms import HuggingFaceEndpoint

# Unpack the bundled document database on first run.
if not os.path.isdir("database"):
    os.system("unzip database.zip")

retriever = None
conversational_rag_chain = None

# Load every .txt file in ./database and split it into overlapping chunks
# small enough for retrieval and the LLM context window.
loader = DirectoryLoader("./database", glob="./*.txt", loader_cls=TextLoader)
documents = loader.load()

text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
splits = text_splitter.split_documents(documents)

print()
print("-------")
print("TextSplitter, DirectoryLoader")
print("-------")

persist_directory = "db"  # currently unused; pass to Chroma.from_documents to persist the index

# BGE embeddings on CPU; normalized embeddings make similarity scores cosine-like.
model_name = "BAAI/bge-large-en"
model_kwargs = {"device": "cpu"}
encode_kwargs = {"normalize_embeddings": True}
embedding = HuggingFaceBgeEmbeddings(
    model_name=model_name,
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs,
    show_progress=True,
)


async def echo(websocket):
    """Handle one websocket connection: answer each JSON message via the RAG chain."""
    global retriever, conversational_rag_chain
    async for message in websocket:
        data = json.loads(message)
        # Validate the parsed payload (not the raw JSON string), and skip
        # malformed messages instead of closing the whole connection.
        if "message" not in data or "token" not in data:
            continue
        m = data["message"]
        token = data["token"]
        # Retrieval happens inside the chain; no separate retriever call is needed.
        response = conversational_rag_chain.invoke(
            {"input": m},
            config={"configurable": {"session_id": token}},
        )["answer"]
        await websocket.send(json.dumps({"response": response}))


async def main():
    async with serve(echo, "0.0.0.0", 7860):
        await asyncio.Future()  # run forever


def f():
    asyncio.run(main())


# Build the vector store and retriever over the split documents.
vectorstore = Chroma.from_documents(documents=splits, embedding=embedding)
retriever = vectorstore.as_retriever()


def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)


# Simple history-free RAG chain. The websocket handler uses the history-aware
# chain built below; this one is kept as a minimal reference pipeline
# (renamed so it no longer shadows the chain used for answering).
prompt = hub.pull("rlm/rag-prompt")
llm = HuggingFaceEndpoint(repo_id="mistralai/Mistral-7B-Instruct-v0.3")
simple_rag_chain = (
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)

### Contextualize question ###
contextualize_q_system_prompt = """Given a chat history and the latest user question \
which might reference context in the chat history, formulate a standalone question \
which can be understood without the chat history. Do NOT answer the question, \
just reformulate it if needed and otherwise return it as is."""
contextualize_q_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", contextualize_q_system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)
history_aware_retriever = create_history_aware_retriever(
    llm, retriever, contextualize_q_prompt
)
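# Illustrative sketch, commented out so it does not run at import time: the
# history-aware retriever rewrites a follow-up question into a standalone one
# using the chat history, then retrieves with the rewrite. The history
# contents below are assumptions, purely for demonstration.
#
#   from langchain_core.messages import AIMessage, HumanMessage
#
#   chat_history = [
#       HumanMessage(content="What embedding model does this app use?"),
#       AIMessage(content="BAAI/bge-large-en."),
#   ]
#   docs = history_aware_retriever.invoke(
#       {"input": "How are its vectors normalized?", "chat_history": chat_history}
#   )
#   # "its" is resolved against the history before retrieval runs.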
### Answer question ###
qa_system_prompt = """You are an assistant for question-answering tasks. \
Use the following pieces of retrieved context to answer the question. \
If you don't know the answer, just say that you don't know. \
Use three sentences maximum and keep the answer concise.\

{context}"""
qa_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", qa_system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)
question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)

### Statefully manage chat history ###
store = {}


def get_session_history(session_id: str) -> BaseChatMessageHistory:
    # One in-memory history per websocket token, created lazily.
    if session_id not in store:
        store[session_id] = ChatMessageHistory()
    return store[session_id]


conversational_rag_chain = RunnableWithMessageHistory(
    rag_chain,
    get_session_history,
    input_messages_key="input",
    history_messages_key="chat_history",
    output_messages_key="answer",
)

"""
websocket protocol

streamlit app ~> backend: {"token": "random", "message": "what is something"}
backend ~> streamlit app: {"response": "something is something"}
streamlit app displays the response
"""

if __name__ == "__main__":
    # Start the websocket server only after the chains above are built: a
    # forked child inherits module state as of Process.start(), so the
    # original placement (right after f was defined) left retriever and
    # conversational_rag_chain unset inside the server process.
    Process(target=f).start()
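# Illustrative client sketch for the protocol above, commented out so the
# server module stays import-safe. The URL and token value are assumptions;
# run something like this from a separate process (e.g. the streamlit app):
#
#   import asyncio
#   import json
#   import websockets
#
#   async def demo():
#       async with websockets.connect("ws://localhost:7860") as ws:
#           await ws.send(json.dumps({"token": "random",
#                                     "message": "what is something"}))
#           print(json.loads(await ws.recv())["response"])
#
#   asyncio.run(demo())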