print("eeeh") | |
import asyncio | |
import json | |
from websockets.server import serve | |
import os | |
from langchain_chroma import Chroma | |
from langchain_community.embeddings import * | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain_huggingface.llms import HuggingFaceEndpoint | |
from langchain_community.document_loaders import TextLoader | |
from langchain_community.document_loaders import DirectoryLoader | |
from langchain import hub | |
from langchain_core.runnables import RunnablePassthrough | |
from langchain_core.output_parsers import StrOutputParser | |
from langchain.chains import create_history_aware_retriever | |
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder | |
from langchain.chains import create_retrieval_chain | |
from langchain.chains.combine_documents import create_stuff_documents_chain | |
from langchain_core.runnables.history import RunnableWithMessageHistory | |
from langchain_core.chat_history import BaseChatMessageHistory | |
from langchain_community.chat_message_histories import ChatMessageHistory | |
print()
print("-------")
print("started")
print("-------")
# Per-session scratch data (the retrieved documents for the last question).
userData = {}

async def echo(websocket):
    async for message in websocket:
        data = json.loads(message)
        if "message" not in data:
            return
        if "token" not in data:
            return
        m = data["message"]
        token = data["token"]
        # Stash the retrieved documents for this session alongside the answer.
        docs = retriever.get_relevant_documents(m)
        userData.setdefault(token, {})["docs"] = str(docs)
        response = conversational_rag_chain.invoke(
            {"input": m},
            config={"configurable": {"session_id": token}},
        )["answer"]
        await websocket.send(json.dumps({"response": response}))
async def main():
    async with serve(echo, "0.0.0.0", 7860):
        await asyncio.Future()  # serve until the process is stopped
def g():
    # Build the retrieval pipeline. These names are module-level globals so the
    # websocket handler above can use them once they are ready.
    global retriever, conversational_rag_chain
    if not os.path.isdir('database'):
        os.system("unzip database.zip")
    loader = DirectoryLoader('./database', glob="./*.txt", loader_cls=TextLoader)
    documents = loader.load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    splits = text_splitter.split_documents(documents)
    print()
    print("-------")
    print("TextSplitter, DirectoryLoader")
    print("-------")
    persist_directory = 'db'
    # embedding = HuggingFaceInferenceAPIEmbeddings(api_key=os.environ["HUGGINGFACE_API_KEY"], model=)
    model_name = "BAAI/bge-large-en"
    model_kwargs = {'device': 'cpu'}
    encode_kwargs = {'normalize_embeddings': True}
    embedding = HuggingFaceBgeEmbeddings(
        model_name=model_name,
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs,
        show_progress=True,
    )
    print()
    print("-------")
    print("Embeddings")
    print("-------")
    vectorstore = Chroma.from_documents(documents=splits, embedding=embedding)

    def format_docs(docs):
        return "\n\n".join(doc.page_content for doc in docs)

    retriever = vectorstore.as_retriever()
    prompt = hub.pull("rlm/rag-prompt")
    llm = HuggingFaceEndpoint(repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1")
    # Basic, history-free RAG chain; the name is rebound to the
    # history-aware chain further down.
    rag_chain = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )
    print()
    print("-------")
    print("Retriever, Prompt, LLM, Rag_Chain")
    print("-------")
    ### Contextualize question ###
    contextualize_q_system_prompt = """Given a chat history and the latest user question \
which might reference context in the chat history, formulate a standalone question \
which can be understood without the chat history. Do NOT answer the question, \
just reformulate it if needed and otherwise return it as is."""
    contextualize_q_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", contextualize_q_system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ]
    )
    history_aware_retriever = create_history_aware_retriever(
        llm, retriever, contextualize_q_prompt
    )
    ### Answer question ###
    qa_system_prompt = """You are an assistant for question-answering tasks. \
Use the following pieces of retrieved context to answer the question. \
If you don't know the answer, just say that you don't know. \
Use three sentences maximum and keep the answer concise.\
{context}"""
    qa_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", qa_system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ]
    )
    question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
    rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
    ### Statefully manage chat history ###
    store = {}

    def get_session_history(session_id: str) -> BaseChatMessageHistory:
        if session_id not in store:
            store[session_id] = ChatMessageHistory()
        return store[session_id]

    conversational_rag_chain = RunnableWithMessageHistory(
        rag_chain,
        get_session_history,
        input_messages_key="input",
        history_messages_key="chat_history",
        output_messages_key="answer",
    )
def f():
    asyncio.run(main())

# Start the websocket server and the (slow) pipeline build concurrently.
# Threads are used instead of processes so that echo() can see the globals
# that g() sets up once the index is ready.
Thread(target=f).start()
Thread(target=g).start()
""" | |
websocket | |
streamlit app ~> backend | |
{"token": "random", "message": "what is something"} ~> backend | |
backend ~> {"response": "something is something"} | |
streamlit app ~> display response | |
""" |