|
import re |
|
|
|
import chainlit as cl |
|
import tiktoken |
|
from langchain.callbacks.base import BaseCallbackHandler |
|
|
|
|
|
def format_docs(documents, max_context_size=100000, separator="\n\n", *, encoder=None):
    """
    Concatenate documents into a single context string for a prompt.

    Each document is rendered as::

        Article: <metadata["title"]>
        <page_content>
        Source: <metadata["link"]><separator>

    Documents are appended in order while the running token count is below
    ``max_context_size``. The limit is checked *before* a document is added,
    so the result may exceed it by up to one document. Once the limit is
    reached, the remaining documents are not read.

    Args:
        documents: Iterable of objects exposing ``page_content`` and a
            ``metadata`` mapping with ``"link"`` and ``"title"`` keys.
        max_context_size: Token budget checked before each document is added.
        separator: Text appended after each rendered document.
        encoder: Optional object with an ``encode(str)`` method whose result
            has a length (the token count). Defaults to tiktoken's
            ``cl100k_base`` encoding.

    Returns:
        The concatenated context, or "" if no document was added.

    Raises:
        KeyError: if a document that gets added lacks "link" or "title".
    """
    if encoder is None:
        encoder = tiktoken.get_encoding("cl100k_base")

    parts = []
    token_count = 0
    for doc in documents:
        # The context only grows, so once the budget is reached no later
        # document can be added; stop instead of checking every remaining one.
        if token_count >= max_context_size:
            break
        source = doc.metadata["link"]
        title = doc.metadata["title"]
        piece = f"Article: {title}\n" + doc.page_content + f"\nSource: {source}" + separator
        parts.append(piece)
        # Keep a running total rather than re-encoding the whole growing context
        # on every iteration (quadratic). NOTE: this sums per-document counts,
        # which may differ slightly from encoding the joined string if a token
        # would span a document boundary.
        token_count += len(encoder.encode(piece))
    return "".join(parts)
|
|
|
|
|
class PostMessageHandler(BaseCallbackHandler):
    """
    Callback handler for the retriever and LLM steps of a chain.

    Collects the documents returned by the retriever and, when an LLM run
    ends, attaches the not-yet-posted ones to ``msg`` as Chainlit ``Text``
    elements named ``source_1``, ``source_2``, ... and appends a
    "Sources: ..." line listing those names to the message content.
    """

    def __init__(self, msg: cl.Message):
        super().__init__()
        self.msg = msg
        # Rendered "<content>\nSource: <link>" strings, in retrieval order.
        self.sources = []
        # How many entries of self.sources have already been attached to msg.
        # This keeps a chain with several LLM runs from posting the same
        # sources (and duplicate element names) more than once.
        self._posted_count = 0

    def on_retriever_end(self, documents, *, run_id, parent_run_id=None, **kwargs):
        """Record each retrieved document's content followed by its "link" metadata."""
        self.sources.extend(
            d.page_content + "\nSource: " + d.metadata["link"] for d in documents
        )

    def on_llm_end(self, response, *, run_id, parent_run_id=None, **kwargs):
        """Attach sources collected since the last post to the message."""
        new_sources = self.sources[self._posted_count:]
        if not new_sources:
            return

        # Continue numbering after previously posted sources so element names
        # stay unique within the message.
        sources_element = [
            cl.Text(name=f"source_{idx}", content=content)
            for idx, content in enumerate(new_sources, start=self._posted_count + 1)
        ]
        self._posted_count = len(self.sources)

        source_names = [el.name for el in sources_element]
        self.msg.elements += sources_element
        self.msg.content += f"\nSources: {', '.join(source_names)}"
|
|
|
def clean_text(text):
    """
    Return a cleaned-up copy of ``text``.

    Applies, in order:
      1. delete every "Tweet"/"tweet" occurrence (also inside longer words),
      2. squeeze runs of spaces into a single space,
      3. squeeze runs of newlines into a single newline,
    then strips leading and trailing whitespace.
    """
    # Order matters: removing "tweet" can leave adjacent spaces that the
    # following step then collapses.
    substitutions = (
        ("[Tt]weet", ""),
        (r"\ +", " "),
        (r"\n+", "\n"),
    )
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)
    return text.strip()
|
|