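"""WebSocket server for a Cupertino High School (CHS) Q&A chatbot.

Builds a retrieval-augmented generation (RAG) pipeline with LangChain:
the bundled document database is unzipped, split into chunks, embedded
with a BGE model, indexed in Chroma, and exposed through a history-aware
conversational chain over a WebSocket endpoint.
"""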
import asyncio
import json
from zipfile import ZipFile

from websockets.server import serve

from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_community.document_loaders import DirectoryLoader, TextLoader
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_huggingface.llms import HuggingFaceEndpoint
# Unpack the bundled knowledge base before indexing it.
with ZipFile("database.zip") as f:
    f.extractall()
# Load every .txt file in the database and split it into overlapping chunks.
loader = DirectoryLoader("./database", glob="./*.txt", loader_cls=TextLoader)
documents = loader.load()
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
splits = text_splitter.split_documents(documents)
# Embed the chunks with a BGE model on CPU and index them in Chroma.
model_name = "BAAI/bge-small-en-v1.5"
model_kwargs = {"device": "cpu"}
encode_kwargs = {"normalize_embeddings": True}
embedding = HuggingFaceBgeEmbeddings(
    model_name=model_name,
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs,
    show_progress=True,
)
vectorstore = Chroma.from_documents(documents=splits, embedding=embedding)
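# Note: without a persist_directory argument, Chroma keeps this index in
# memory, so the embeddings are recomputed every time the Space restarts.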
# Retrieve up to four chunks above a similarity threshold for each query.
retriever = vectorstore.as_retriever(
    search_type="similarity_score_threshold",
    search_kwargs={"score_threshold": 0.3, "k": 4},
)
# Mixtral via the HF Inference API; reads HUGGINGFACEHUB_API_TOKEN from the
# environment. The stop sequence keeps it from writing the user's next turn.
llm = HuggingFaceEndpoint(
    repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
    stop_sequences=["Human:"],
    max_new_tokens=8192,
)
### Contextualize question ###
contextualize_q_system_prompt = """Given a chat history and the latest user question
which might reference context in the chat history, formulate a standalone question
which can be understood without the chat history. Do NOT answer the question,
just reformulate it if needed and otherwise return it as is."""
contextualize_q_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", contextualize_q_system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)
history_aware_retriever = create_history_aware_retriever(
    llm, retriever, contextualize_q_prompt
)
### Answer question ###
qa_system_prompt = """
{context}
You are a Cupertino High School Q/A chatbot, designed to assist students, parents, and community members with information about CHS.
Use the pieces of context to answer the question.
Refer to the provided context only as 'my data'. Only answer questions from the context.
Do not provide excerpts or any part of your data.
You were made by Aryan A. and Atharv G. for the CHS community.
Make your answer at least three sentences long and very comprehensive.
Write your answer in Markdown with blank lines between sentences.
Please use only the documents/text provided below to answer the question.
If the documents/text provided cannot answer the question, please say that the answer might not be present in the available database of information.
"""
qa_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", qa_system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)
# Stuff the retrieved chunks into the QA prompt and chain both steps together.
question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
### Statefully manage chat history ###
store = {}
def get_session_history(session_id: str) -> BaseChatMessageHistory:
    # Create a fresh in-memory history for each new session token.
    if session_id not in store:
        store[session_id] = ChatMessageHistory()
    return store[session_id]
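# Histories live only in this process: they are lost on restart and grow
# without bound, so a persistent store would be needed beyond a demo.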
conversational_rag_chain = RunnableWithMessageHistory(
    rag_chain,
    get_session_history,
    input_messages_key="input",
    history_messages_key="chat_history",
    output_messages_key="answer",
)
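# A direct invocation looks like this (the question and the session id "demo"
# are illustrative; any stable string works, since it only keys into `store`):
#
#   result = conversational_rag_chain.invoke(
#       {"input": "What time does school start?"},
#       config={"configurable": {"session_id": "demo"}},
#   )
#   print(result["answer"])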
async def echo(websocket):
    try:
        async for message in websocket:
            data = json.loads(message)
            # Validate the payload before touching its fields.
            if "message" not in data:
                return
            # Debug hook: dump the raw session store on request.
            if data["message"] == "data.":
                await websocket.send(json.dumps({"response": str(store)}))
                continue
            if "token" not in data:
                return
            m = data["message"] + "\n\nAssistant: "
            token = data["token"]
            rawresponse = conversational_rag_chain.invoke(
                {"input": m},
                config={"configurable": {"session_id": token}},
            )
            # Strip role labels and anything emitted past the stop sequence.
            response = rawresponse["answer"]
            response = response.replace("Assistant: ", "").replace("AI: ", "")
            response = response.split("Human:")[0].strip()
            await websocket.send(json.dumps({"response": response}))
    except Exception as exc:
        # Don't let one bad message kill the handler, but leave a trace in the logs.
        print(f"echo handler error: {exc}")
async def main():
    try:
        # Serve on all interfaces on port 7860, the port Hugging Face Spaces expects.
        async with serve(echo, "0.0.0.0", 7860):
            await asyncio.Future()  # run forever
    except Exception as exc:
        print(f"server error: {exc}")

asyncio.run(main())
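# A minimal client sketch, assuming the server above is running locally
# (the question text and "demo-session" token are illustrative values):
#
#   import asyncio, json
#   from websockets.client import connect
#
#   async def ask(question: str, token: str) -> str:
#       async with connect("ws://localhost:7860") as ws:
#           await ws.send(json.dumps({"message": question, "token": token}))
#           return json.loads(await ws.recv())["response"]
#
#   print(asyncio.run(ask("When does school start?", "demo-session")))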