import chainlit as cl
from langchain.retrievers import ParentDocumentRetriever
from langchain.schema.runnable import RunnableConfig
from langchain.storage import LocalFileStore
from langchain.storage._lc_store import create_kv_docstore
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores.chroma import Chroma
from langchain_google_genai import (
    GoogleGenerativeAI,
    GoogleGenerativeAIEmbeddings,
    HarmBlockThreshold,
    HarmCategory,
)

import config
from prompts import prompt
from utils import PostMessageHandler, format_docs
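
# Google Generative AI chat model used to generate answers; blocking is
# disabled for the dangerous-content harm category.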
model = GoogleGenerativeAI(
    model=config.GOOGLE_CHAT_MODEL,
    google_api_key=config.GOOGLE_API_KEY,
    safety_settings={
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
    },
)
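
# Embedding model used to index and search document chunks.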
embeddings_model = GoogleGenerativeAIEmbeddings(
    model=config.GOOGLE_EMBEDDING_MODEL
)
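
# Splitter that produces the small "child" chunks stored in the vector store.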
child_splitter = RecursiveCharacterTextSplitter(chunk_size=500, separators=["\n"])
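
# Persistent Chroma collection holding the embedded child chunks.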
vectorstore = Chroma(
    persist_directory=config.STORAGE_PATH + "vectorstore",
    collection_name="full_documents",
    embedding_function=embeddings_model,
)
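
# On-disk key-value store holding the full parent documents.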
fs = LocalFileStore(config.STORAGE_PATH + "docstore")
store = create_kv_docstore(fs)
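
# Searches over the child chunks, then returns the full parent documents
# they came from.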
retriever = ParentDocumentRetriever(
    vectorstore=vectorstore,
    docstore=store,
    child_splitter=child_splitter,
)
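
# Store the retriever in the user session when a new chat starts.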
@cl.on_chat_start
async def on_chat_start():
    cl.user_session.set("retriever", retriever)
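
# Answer each user message: retrieve context with the parent-document
# retriever and stream the model's response back to the UI.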
@cl.on_message
async def on_message(message: cl.Message):
    # Compose the prompt template with the chat model.
    chain = prompt | model
    msg = cl.Message(content="")

    async with cl.Step(type="run", name="QA Assistant"):
        question = message.content
        # Fetch the parent documents relevant to the question and collapse
        # them into a single context string.
        context = format_docs(retriever.get_relevant_documents(question))
        # Stream the answer token by token, wiring in Chainlit's LangChain
        # tracer and the custom PostMessageHandler from utils.
        async for chunk in chain.astream(
            input={"context": context, "question": question},
            config=RunnableConfig(
                callbacks=[cl.LangchainCallbackHandler(), PostMessageHandler(msg)]
            ),
        ):
            await msg.stream_token(chunk)

    await msg.send()