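# Chainlit RAG app: the user uploads a PDF, which is chunked, embedded via a Hugging Face
# Inference Endpoint, and indexed in an in-memory Qdrant collection; questions are then
# answered by an LLM endpoint using only the retrieved context.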
import os
import uuid
from dotenv import load_dotenv
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyMuPDFLoader
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
from langchain.storage import LocalFileStore
from langchain_qdrant import QdrantVectorStore
from langchain.embeddings import CacheBackedEmbeddings
from chainlit.types import AskFileResponse
from operator import itemgetter
import chainlit as cl
from langchain_core.runnables.config import RunnableConfig
from langchain_huggingface import HuggingFaceEndpoint
from langchain_huggingface.embeddings import HuggingFaceEndpointEmbeddings
from langchain_core.prompts import PromptTemplate

load_dotenv()

YOUR_LLM_ENDPOINT_URL = os.environ["YOUR_LLM_ENDPOINT_URL"]
YOUR_EMBED_MODEL_URL = os.environ["YOUR_EMBED_MODEL_URL"]

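# Llama 3 style chat prompt. Assumes the LLM endpoint serves an instruct model that
# understands the <|start_header_id|> / <|eot_id|> special tokens.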
RAG_PROMPT_TEMPLATE = """\
<|start_header_id|>system<|end_header_id|>
You are a helpful assistant. You answer user questions based on provided context. If you can't answer the question with the provided context, say you don't know.<|eot_id|>

<|start_header_id|>user<|end_header_id|>
User Query:
{query}

Context:
{context}<|eot_id|>

<|start_header_id|>assistant<|end_header_id|>
"""

# Split documents into ~600-character chunks with 100 characters of overlap.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=600, chunk_overlap=100)

# Generator: a Hugging Face Inference Endpoint serving the LLM.
hf_llm = HuggingFaceEndpoint(
    endpoint_url=YOUR_LLM_ENDPOINT_URL,
    max_new_tokens=300,
    top_k=10,
    top_p=0.95,
    typical_p=0.95,
    temperature=0.01,
    repetition_penalty=1.03,
    huggingfacehub_api_token=os.environ["HF_TOKEN"],
)

# Embedding model: a Hugging Face Inference Endpoint used for feature extraction.
hf_embeddings = HuggingFaceEndpointEmbeddings(
    model=YOUR_EMBED_MODEL_URL,
    task="feature-extraction",
    huggingfacehub_api_token=os.environ["HF_TOKEN"],
)

rag_prompt = PromptTemplate.from_template(RAG_PROMPT_TEMPLATE)

def process_file(file: AskFileResponse):
    import tempfile

    # Write the uploaded PDF to a temporary file on disk so PyMuPDFLoader can open it.
    with tempfile.NamedTemporaryFile(mode="wb", delete=False, suffix=".pdf") as temp_file:
        temp_file.write(file.content)
        temp_file_path = temp_file.name

    loader = PyMuPDFLoader(temp_file_path)
    documents = loader.load()
    docs = text_splitter.split_documents(documents)

    # Tag each chunk with a simple source identifier for later reference.
    for i, doc in enumerate(docs):
        doc.metadata["source"] = f"source_{i}"
    return docs


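# Runs once at the start of each chat session: ingest the uploaded PDF, build the
# vector store, and assemble the RAG chain for this session.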
@cl.on_chat_start
async def on_chat_start():
    files = None

    # Wait until the user uploads a PDF file.
    while files is None:
        files = await cl.AskFileMessage(
            content="Please upload a PDF file to begin!",
            accept=["application/pdf"],
            max_size_mb=20,
            timeout=180,
            max_files=1
        ).send()

    file = files[0]
    msg = cl.Message(
        content=f"Processing `{file.name}`...",
    )
    await msg.send()
    docs = process_file(file)

    # Qdrant client set-up: an in-memory collection with 768-dimensional cosine vectors
    # (this must match the output dimension of the embedding model).
    collection_name = f"pdf_to_parse_{uuid.uuid4()}"
    client = QdrantClient(":memory:")
    client.create_collection(
        collection_name=collection_name,
        vectors_config=VectorParams(size=768, distance=Distance.COSINE),
    )

    # Optional: cache embeddings on local disk so identical chunks are not re-embedded.
    # store = LocalFileStore("./cache/")
    # cached_embedder = CacheBackedEmbeddings.from_bytes_store(
    #     hf_embeddings, store, namespace=hf_embeddings.model
    # )

    # Qdrant vector store with an MMR retriever returning the top 3 chunks per query.
    vectorstore = QdrantVectorStore(
        client=client,
        collection_name=collection_name,
        embedding=hf_embeddings,
    )
    retriever = vectorstore.as_retriever(search_type="mmr", search_kwargs={"k": 3})

    # Embed and index the chunks in batches of 32 to keep each embedding request small.
    for i in range(0, len(docs), 32):
        retriever.add_documents(docs[i:i + 32])


    # LCEL chain: retrieve context for the query, fill the prompt, and call the LLM.
    retrieval_augmented_qa_chain = (
        {"context": itemgetter("query") | retriever, "query": itemgetter("query")}
        | rag_prompt
        | hf_llm
    )

    # Let the user know that the system is ready
    msg.content = f"Processing `{file.name}` done. You can now ask questions!"
    await msg.update()

    cl.user_session.set("chain", retrieval_augmented_qa_chain)
    

### Rename Chains ###
@cl.author_rename
def rename(orig_author: str):
    """Rename chain components to friendlier labels in the Chainlit UI."""
    rename_dict = {
        "HuggingFaceEndpoint": "the Generator...",
        "ChatOpenAI": "the Generator...",
        "VectorStoreRetriever": "the Retriever...",
    }
    return rename_dict.get(orig_author, orig_author)

### On Message Section ###
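# Runs for every user message: look up the session's RAG chain and stream the answer
# token-by-token back to the UI.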
@cl.on_message
async def main(message: cl.Message):
    runnable = cl.user_session.get("chain")

    msg = cl.Message(content="")

    async for chunk in runnable.astream(
        {"query": message.content},
        config=RunnableConfig(callbacks=[cl.LangchainCallbackHandler()]),
    ):
        await msg.stream_token(chunk)

    await msg.send()

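# Running this file directly (python app.py) starts the Chainlit server for this app.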
if __name__ == "__main__":
    from chainlit.cli import run_chainlit
    run_chainlit(__file__)