import os
import uuid
from dotenv import load_dotenv
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyMuPDFLoader
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
from langchain.storage import LocalFileStore
from langchain_qdrant import QdrantVectorStore
from langchain.embeddings import CacheBackedEmbeddings
from chainlit.types import AskFileResponse
from operator import itemgetter
from langchain_core.runnables.passthrough import RunnablePassthrough
import chainlit as cl
from langchain_core.runnables.config import RunnableConfig
from langchain_huggingface import HuggingFaceEndpoint
from langchain_huggingface.embeddings import HuggingFaceEndpointEmbeddings
from langchain_core.prompts import PromptTemplate
load_dotenv()
YOUR_LLM_ENDPOINT_URL = os.environ["YOUR_LLM_ENDPOINT_URL"]
YOUR_EMBED_MODEL_URL = os.environ["YOUR_EMBED_MODEL_URL"]
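
# RAG prompt written with Llama 3-style chat special tokens; {query} and {context}
# are filled in at runtime by the chain built in on_chat_start.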
RAG_PROMPT_TEMPLATE = """\
<|start_header_id|>system<|end_header_id|>
You are a helpful assistant. You answer user questions based on provided context. If you can't answer the question with the provided context, say you don't know.<|eot_id|>
<|start_header_id|>user<|end_header_id|>
User Query:
{query}
Context:
{context}<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
"""
text_splitter = RecursiveCharacterTextSplitter(chunk_size=600, chunk_overlap=100)
hf_llm = HuggingFaceEndpoint(
    endpoint_url=YOUR_LLM_ENDPOINT_URL,
    max_new_tokens=300,
    top_k=10,
    top_p=0.95,
    typical_p=0.95,
    temperature=0.01,
    repetition_penalty=1.03,
    huggingfacehub_api_token=os.environ["HF_TOKEN"],
)
hf_embeddings = HuggingFaceEndpointEmbeddings(
    model=YOUR_EMBED_MODEL_URL,
    task="feature-extraction",
    huggingfacehub_api_token=os.environ["HF_TOKEN"],
)
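
# NOTE: the Qdrant collection created in on_chat_start uses 768-dimensional vectors,
# which assumes the embedding endpoint serves a 768-dim model; adjust
# VectorParams(size=...) below if your model's embedding dimension differs.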
rag_prompt = PromptTemplate.from_template(RAG_PROMPT_TEMPLATE)
def process_file(file: AskFileResponse):
    """Write the uploaded PDF to a temporary file, load it, and split it into chunks."""
    import tempfile

    # Persist the uploaded bytes to disk so PyMuPDFLoader can open them.
    with tempfile.NamedTemporaryFile(mode="wb", delete=False, suffix=".pdf") as temp_file:
        temp_file.write(file.content)
        temp_path = temp_file.name

    loader = PyMuPDFLoader(temp_path)
    documents = loader.load()
    docs = text_splitter.split_documents(documents)

    # Tag each chunk with a simple source identifier for later reference.
    for i, doc in enumerate(docs):
        doc.metadata["source"] = f"source_{i}"
    return docs
@cl.on_chat_start
async def on_chat_start():
    files = None

    # Wait for the user to upload a PDF before continuing.
    while files is None:
        files = await cl.AskFileMessage(
            content="Please upload a PDF file to begin!",
            accept=["application/pdf"],
            max_size_mb=20,
            timeout=180,
            max_files=1,
        ).send()

    file = files[0]

    msg = cl.Message(content=f"Processing `{file.name}`...")
    await msg.send()

    # Load the PDF and split it into chunks.
    docs = process_file(file)
    # Qdrant client set-up: an in-memory collection with a unique name per session
    collection_name = f"pdf_to_parse_{uuid.uuid4()}"
    client = QdrantClient(":memory:")
    client.create_collection(
        collection_name=collection_name,
        vectors_config=VectorParams(size=768, distance=Distance.COSINE),
    )
    # Adding cache!
    # store = LocalFileStore("./cache/")
    # cached_embedder = CacheBackedEmbeddings.from_bytes_store(
    #     hf_embeddings, store, namespace=hf_embeddings.model
    # )
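    # NOTE: to enable embedding caching, uncomment the block above and pass
    # `cached_embedder` (instead of `hf_embeddings`) as the `embedding` argument below.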
    # Typical Qdrant vector store set-up
    vectorstore = QdrantVectorStore(
        client=client,
        collection_name=collection_name,
        embedding=hf_embeddings,
    )
    retriever = vectorstore.as_retriever(search_type="mmr", search_kwargs={"k": 3})
    # Embed and index the chunks in batches of 32 to keep embedding requests small.
    for i in range(0, len(docs), 32):
        retriever.add_documents(docs[i:i + 32])
    # LCEL chain: retrieve context for the query, fill the prompt, then generate with the LLM.
    retrieval_augmented_qa_chain = (
        {"context": itemgetter("query") | retriever, "query": itemgetter("query")}
        | rag_prompt
        | hf_llm
    )
    # Let the user know that the system is ready
    msg.content = f"Processing `{file.name}` done. You can now ask questions!"
    await msg.update()

    cl.user_session.set("chain", retrieval_augmented_qa_chain)
### Rename Chains ###
@cl.author_rename
def rename(orig_author: str):
""" RENAME CODE HERE """
rename_dict = {"ChatOpenAI": "the Generator...", "VectorStoreRetriever": "the Retriever..."}
return rename_dict.get(orig_author, orig_author)
### On Message Section ###
@cl.on_message
async def main(message: cl.Message):
    runnable = cl.user_session.get("chain")

    msg = cl.Message(content="")

    # Stream the LLM response token by token back to the user.
    async for chunk in runnable.astream(
        {"query": message.content},
        config=RunnableConfig(callbacks=[cl.LangchainCallbackHandler()]),
    ):
        await msg.stream_token(chunk)

    await msg.send()
if __name__ == "__main__":
from chainlit.cli import run_chainlit
run_chainlit(__file__)
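
# Alternative (assumption about the file name): with Chainlit installed, the app can also be
# launched from the CLI, e.g. `chainlit run <this_file>.py -w`, which enables auto-reload
# during development.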