# finchat/app.py
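"""Chainlit QA assistant: Gemini answers user questions from documents retrieved
through a Chroma-backed ParentDocumentRetriever."""
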
import chainlit as cl
from langchain.retrievers import ParentDocumentRetriever
from langchain.schema.runnable import RunnableConfig
from langchain.storage import LocalFileStore
from langchain.storage._lc_store import create_kv_docstore
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores.chroma import Chroma
from langchain_google_genai import (
GoogleGenerativeAI,
GoogleGenerativeAIEmbeddings,
HarmBlockThreshold,
HarmCategory,
)

import config
from prompts import prompt
from utils import PostMessageHandler, format_docs

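# Gemini chat model, with the dangerous-content safety filter disabled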
model = GoogleGenerativeAI(
    model=config.GOOGLE_CHAT_MODEL,
    google_api_key=config.GOOGLE_API_KEY,
    safety_settings={
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
    },
)  # type: ignore
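
# Google embedding model used to embed document chunks for the vector store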
embeddings_model = GoogleGenerativeAIEmbeddings(
    model=config.GOOGLE_EMBEDDING_MODEL
)  # type: ignore

## Retriever
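# Child splitter: breaks parent documents into small (~500-character) chunks for indexing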
child_splitter = RecursiveCharacterTextSplitter(chunk_size=500, separators=["\n"])

# The vector store used to index the child chunks
vectorstore = Chroma(
    persist_directory=config.STORAGE_PATH + "vectorstore",
    collection_name="full_documents",
    embedding_function=embeddings_model,
)

# The storage layer for the parent documents
fs = LocalFileStore(config.STORAGE_PATH + "docstore")
store = create_kv_docstore(fs)
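
# The retriever searches over the small child chunks but returns the full parent documents they belong to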
retriever = ParentDocumentRetriever(
    vectorstore=vectorstore,
    docstore=store,
    child_splitter=child_splitter,
)
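

# Keep the retriever in the user session so the message handler can reuse it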
@cl.on_chat_start
async def on_chat_start():
    cl.user_session.set("retriever", retriever)
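

# Handle each user message: retrieve relevant context, then stream the model's answer back to the UI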
@cl.on_message
async def on_message(message: cl.Message):
    retriever = cl.user_session.get("retriever")
    chain = prompt | model
    msg = cl.Message(content="")

    async with cl.Step(type="run", name="QA Assistant"):
        question = message.content
        # Fetch the parent documents relevant to the question and flatten them into a single context string
        context = format_docs(retriever.get_relevant_documents(question))

        # Stream the model's answer token by token into the Chainlit message
        async for chunk in chain.astream(
            input={"context": context, "question": question},
            config=RunnableConfig(
                callbacks=[cl.LangchainCallbackHandler(), PostMessageHandler(msg)]
            ),
        ):
            await msg.stream_token(chunk)

    await msg.send()