# Chainlit + LangChain RAG application: answers questions over local PDFs
# using HuggingFace embeddings, a Chroma vector store, and a Groq-hosted LLM.
import os
from pathlib import Path
from typing import List

import chainlit as cl
import chainlit.data as cl_data
from langchain.callbacks.base import BaseCallbackHandler
from langchain.indexes import SQLRecordManager, index
from langchain.prompts import ChatPromptTemplate
from langchain.schema import Document
from langchain.schema import StrOutputParser
from langchain.schema.runnable import Runnable, RunnablePassthrough, RunnableConfig
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import (
    PyPDFDirectoryLoader,
)
from langchain_community.vectorstores import Chroma
# from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEndpointEmbeddings

from feedback import CustomDataLayer
from rag_bot import RagBot

# Chunking parameters for the PDF text splitter (characters, not tokens).
chunk_size = 1024
chunk_overlap = 50

# Remote sentence-transformer embeddings via the HF Inference endpoint;
# requires HUGGINGFACEHUB_API_TOKEN in the environment.
embeddings_model = HuggingFaceEndpointEmbeddings(
    huggingfacehub_api_token=os.getenv("HUGGINGFACEHUB_API_TOKEN"),
    model="sentence-transformers/all-MiniLM-L12-v2",
)

# Feedback: route Chainlit persistence through the project's custom data layer.
# NOTE(review): assigning the private cl_data._data_layer is the documented
# Chainlit pattern for custom data layers, but it is not a public API.
cl_data._data_layer = CustomDataLayer()

PDF_STORAGE_PATH = "./data"
def process_pdfs(pdf_storage_path: str) -> Chroma:
    """Load, split, embed, and index every PDF under *pdf_storage_path*.

    Args:
        pdf_storage_path: Directory containing the PDF files to ingest.

    Returns:
        The Chroma vector store backing the retriever.

    Raises:
        ValueError: If splitting produces no document chunks (e.g. the
            directory is empty or contains no extractable text).
    """
    pdf_directory = Path(pdf_storage_path)
    documents = PyPDFDirectoryLoader(pdf_directory).load()

    recursive_text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
        is_separator_regex=False,
    )
    docs: List[Document] = recursive_text_splitter.split_documents(documents)
    if not docs:
        raise ValueError("No documents found in the specified directory.")

    doc_search = Chroma.from_documents(docs, embeddings_model)

    # The record manager enables incremental indexing: cleanup="full" also
    # removes vectors whose source documents disappeared since the last run.
    namespace = "chromadb/my_documents"
    record_manager = SQLRecordManager(
        namespace, db_url="sqlite:///record_manager_cache.sql"
    )
    record_manager.create_schema()

    index_result = index(
        docs,
        record_manager,
        doc_search,
        cleanup="full",
        source_id_key="source",
    )
    print(f"Indexing stats: {index_result}")

    return doc_search
# Build the vector store once at startup; sessions share the same retriever.
doc_search = process_pdfs(PDF_STORAGE_PATH)

# model = ChatOpenAI(model_name="gpt-4", streaming=True)
model = ChatGroq(
    model='llama-3.1-70b-versatile',
    temperature=0,  # deterministic output for factual document Q&A
    max_tokens=1024,
    timeout=None,
    max_retries=5,
    api_key=os.getenv("GROQ_API_KEY"),
    # other params...
)
@cl.on_chat_start  # FIX: decorator was missing, so Chainlit never ran this handler
async def on_chat_start():
    """Assemble the per-session RAG chain: retriever -> prompt -> LLM -> str.

    The runnable is stored in the Chainlit user session so ``on_message``
    can stream answers from it.
    """
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                # FIX: corrected duplicated word ("have have") in the prompt.
                """You are a helpful assistant that can answer questions about technical documents in any language.
Keep your answers only in the language of the question(s).
Only use the factual information from the document(s) to answer the question(s). Keep your answers concise and to the point.
If you do not have sufficient information to answer a question, politely refuse to answer and say "I don't know".
\n\nRelevant documents will be retrieved below."""
                "Context: {context}",
            ),
            ("human", "{question}"),
        ]
    )

    def format_docs(docs):
        # Concatenate retrieved chunks into a single context string.
        return "\n\n".join([d.page_content for d in docs])

    # Top-5 nearest chunks per question.
    retriever = doc_search.as_retriever(search_kwargs={"k": 5})

    runnable = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | model
        | StrOutputParser()
    )
    cl.user_session.set("runnable", runnable)
@cl.on_message  # FIX: decorator was missing, so incoming messages were never dispatched here
async def on_message(message: cl.Message):
    """Stream an answer for the user's message and append source citations.

    A per-message callback handler collects the retrieved documents and,
    once the LLM finishes, appends a deduplicated "Sources:" line plus
    side-panel preview elements to the streamed message.
    """
    runnable = cl.user_session.get("runnable")  # type: Runnable
    msg = cl.Message(content="")

    class PostMessageHandler(BaseCallbackHandler):
        """
        Callback handler for handling the retriever and LLM processes.
        Used to post the sources of the retrieved documents as a Chainlit element.
        """

        def __init__(self, msg: cl.Message):
            BaseCallbackHandler.__init__(self)
            self.msg = msg
            self.sources = []  # unique (source, page) records with their content

        def on_retriever_end(self, documents, *, run_id, parent_run_id, **kwargs):
            # Deduplicate on (source, page), keeping the first-seen content.
            for doc in documents:
                source = doc.metadata.get('source', 'Unknown Source')
                page = doc.metadata.get('page', 'N/A')
                if not any(
                    s["source"] == source and s["page"] == page
                    for s in self.sources
                ):
                    self.sources.append({
                        "source": source,
                        "page": page,
                        "content": doc.page_content,
                    })

        def on_llm_end(self, response, *, run_id, parent_run_id, **kwargs):
            if not self.sources:
                return
            text_elements = []
            source_references = []
            for src in self.sources:
                source_name = f"{src['source']} p.{src['page']}"
                source_references.append(source_name)
                # Side-panel element so each citation is previewable by name.
                text_elements.append(
                    cl.Text(
                        name=source_name,
                        content=src["content"],
                        display="side",
                    )
                )
            # FIX: the original nested double quotes inside an f-string across
            # multiple lines — a SyntaxError before Python 3.12. Join first.
            refs = ", ".join(source_references)
            self.msg.content += f"\n\nSources: {refs}"
            self.msg.elements.extend(text_elements)

    async for chunk in runnable.astream(
        message.content,
        config=RunnableConfig(callbacks=[
            cl.LangchainCallbackHandler(),
            PostMessageHandler(msg),
        ]),
    ):
        await msg.stream_token(chunk)

    await msg.send()