finchat / utils.py
Monsia's picture
perf: update prompt and clean the code
8bfa348
import re
import chainlit as cl
import tiktoken
from langchain.callbacks.base import BaseCallbackHandler
def format_docs(documents, max_context_size=100000, separator="\n\n"):
    """Concatenate retrieved documents into a single context string.

    Each document is rendered as ``Article: <title>``, followed by its page
    content and ``Source: <link>``, and terminated with ``separator``.
    Documents are appended in order for as long as the token count of the
    context built so far (``cl100k_base`` encoding) is below
    ``max_context_size``.

    Args:
        documents: Iterable of objects exposing ``page_content`` and a
            ``metadata`` mapping with ``"title"`` and ``"link"`` keys.
        max_context_size: Token budget. It is checked *before* each document
            is appended, so the result may exceed it by the last document
            that was added.
        separator: Text appended after every document.

    Returns:
        The assembled context string; empty when no document was appended.

    Raises:
        KeyError: If an appended document lacks ``"title"`` or ``"link"``.
    """
    encoder = tiktoken.get_encoding("cl100k_base")
    context = ""
    for doc in documents:
        # Once the budget is reached the context can no longer change, so
        # stop here instead of re-encoding it for every remaining document.
        if len(encoder.encode(context)) >= max_context_size:
            break
        source = doc.metadata["link"]
        title = doc.metadata["title"]
        context += (
            f"Article: {title}\n"
            + doc.page_content
            + f"\nSource: {source}"
            + separator
        )
    return context
class PostMessageHandler(BaseCallbackHandler):
    """
    Callback handler hooked into the retriever and LLM steps.

    Collects the content and link of every retrieved document and, once the
    LLM finishes, attaches them to the Chainlit message as Text elements.
    """

    def __init__(self, msg: cl.Message):
        super().__init__()
        self.sources = []
        self.msg = msg

    def on_retriever_end(self, documents, *, run_id, parent_run_id, **kwargs):
        """Record each retrieved document's content followed by its link."""
        self.sources.extend(
            doc.page_content + "\nSource: " + doc.metadata["link"]
            for doc in documents
        )

    def on_llm_end(self, response, *, run_id, parent_run_id, **kwargs):
        """Append the collected sources to the message, if there are any."""
        if not self.sources:
            return

        # Display the reference docs with one Text widget per source,
        # numbered from 1 in the order they were collected.
        text_elements = []
        labels = []
        for position, body in enumerate(self.sources, start=1):
            element = cl.Text(name=f"source_{position}", content=body)
            text_elements.append(element)
            labels.append(element.name)

        self.msg.elements += text_elements
        self.msg.content += f"\nSources: {', '.join(labels)}"
def clean_text(text):
    """Normalize scraped text.

    Removes every occurrence of "Tweet"/"tweet", collapses runs of spaces
    into one space and runs of newlines into one newline, then strips
    leading and trailing whitespace.
    """
    # Order matters: removing "tweet" first can leave adjacent spaces that
    # the second substitution then collapses.
    substitutions = (
        ("[Tt]weet", ""),
        (r"\ +", " "),
        (r"\n+", "\n"),
    )
    cleaned = text
    for pattern, replacement in substitutions:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned.strip()