import chainlit as cl
from langchain.retrievers import ParentDocumentRetriever
from langchain.schema.runnable import RunnableConfig
from langchain.storage import LocalFileStore
from langchain.storage._lc_store import create_kv_docstore
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores.chroma import Chroma
from langchain_google_genai import (
    GoogleGenerativeAI,
    GoogleGenerativeAIEmbeddings,
    HarmBlockThreshold,
    HarmCategory,
)

import config
from prompts import prompt
from utils import PostMessageHandler, format_docs

# LLM used to answer questions; dangerous-content blocking is disabled so
# retrieved passages are not silently filtered out of responses.
model = GoogleGenerativeAI(
    model=config.GOOGLE_CHAT_MODEL,
    google_api_key=config.GOOGLE_API_KEY,
    safety_settings={
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
    },
)  # type: ignore

embeddings_model = GoogleGenerativeAIEmbeddings(
    model=config.GOOGLE_EMBEDDING_MODEL
)  # type: ignore

## retriever

# Small child chunks are what get embedded and searched; the retriever then
# returns the larger parent documents they belong to.
child_splitter = RecursiveCharacterTextSplitter(chunk_size=500, separators=["\n"])

# The vectorstore used to index the child chunks
vectorstore = Chroma(
    persist_directory=config.STORAGE_PATH + "vectorstore",
    collection_name="full_documents",
    embedding_function=embeddings_model,
)

# The storage layer for the parent documents
fs = LocalFileStore(config.STORAGE_PATH + "docstore")
store = create_kv_docstore(fs)

retriever = ParentDocumentRetriever(
    vectorstore=vectorstore,
    docstore=store,
    child_splitter=child_splitter,
)


@cl.on_chat_start
async def on_chat_start():
    # Stash the shared retriever in the user session when a chat starts.
    cl.user_session.set("retriever", retriever)


@cl.on_message
async def on_message(message: cl.Message):
    chain = prompt | model
    msg = cl.Message(content="")

    async with cl.Step(type="run", name="QA Assistant"):
        question = message.content
        # Retrieve parent documents for the question and format them into the
        # prompt's context slot.
        context = format_docs(retriever.get_relevant_documents(question))

        # Stream the model's answer token by token into the Chainlit message.
        async for chunk in chain.astream(
            input={"context": context, "question": question},
            config=RunnableConfig(
                callbacks=[cl.LangchainCallbackHandler(), PostMessageHandler(msg)]
            ),
        ):
            await msg.stream_token(chunk)

    await msg.send()
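

# A minimal ingestion sketch, not part of the original app: the stores above
# are only read here, so documents presumably get indexed in a separate step.
# ParentDocumentRetriever.add_documents splits each document with
# child_splitter, embeds the resulting child chunks into the vectorstore, and
# stores the full parent documents in the docstore under generated ids. The
# loader and file path below are placeholder assumptions.
def ingest_example():
    from langchain.document_loaders import TextLoader

    docs = TextLoader(config.STORAGE_PATH + "example.txt").load()
    retriever.add_documents(docs)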