import os
from pathlib import Path
from typing import List

import chainlit as cl
import chainlit.data as cl_data
from langchain.callbacks.base import BaseCallbackHandler
from langchain.indexes import SQLRecordManager, index
from langchain.prompts import ChatPromptTemplate
from langchain.schema import Document
from langchain.schema import StrOutputParser
from langchain.schema.runnable import Runnable, RunnablePassthrough, RunnableConfig
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import (
    PyPDFDirectoryLoader,
)
from langchain_community.vectorstores import Chroma
# from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEndpointEmbeddings

from feedback import CustomDataLayer
from rag_bot import RagBot

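# Chunk size and overlap (in characters) for the recursive text splitter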
chunk_size = 1024
chunk_overlap = 50

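# Remote sentence-transformers embeddings served through the Hugging Face
# Inference API (requires HUGGINGFACEHUB_API_TOKEN in the environment)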
embeddings_model = HuggingFaceEndpointEmbeddings(
    huggingfacehub_api_token=os.getenv("HUGGINGFACEHUB_API_TOKEN"),
    model="sentence-transformers/all-MiniLM-L12-v2",
)


# Register the custom data layer so Chainlit persists user feedback
cl_data._data_layer = CustomDataLayer()

PDF_STORAGE_PATH = "./data"


def process_pdfs(pdf_storage_path: str) -> Chroma:
    """Load every PDF under `pdf_storage_path`, split it into overlapping
    chunks, and index the chunks into a Chroma vector store."""
    pdf_directory = Path(pdf_storage_path)

    loader = PyPDFDirectoryLoader(pdf_directory)
    documents = loader.load()
    recursive_text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
        is_separator_regex=False,
    )
    docs: List[Document] = recursive_text_splitter.split_documents(documents)

    doc_search = Chroma.from_documents(docs, embeddings_model)

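    # The SQL record manager tracks what has already been indexed, so repeated
    # runs skip or update unchanged chunks instead of duplicating them.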
    namespace = "chromadb/my_documents"
    record_manager = SQLRecordManager(
        namespace, db_url="sqlite:///record_manager_cache.sql"
    )
    record_manager.create_schema()

    index_result = index(
        docs,
        record_manager,
        doc_search,
        cleanup="full",
        source_id_key="source",
    )

    print(f"Indexing stats: {index_result}")

    return doc_search


doc_search = process_pdfs(PDF_STORAGE_PATH)
# model = ChatOpenAI(model_name="gpt-4", streaming=True)
model = ChatGroq(
    model='llama-3.1-70b-versatile',
    temperature=0,
    max_tokens=1024,
    timeout=None,
    max_retries=5,
    api_key=os.getenv("GROQ_API_KEY"),
)


@cl.on_chat_start
async def on_chat_start():
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system",
             """You are a helpful assistant that can answer questions about technical documents in any language.
             Answer only in the language of the question.

             Use only factual information from the retrieved documents to answer, and keep your answers concise and to the point.

             If you do not have sufficient information to answer a question, politely refuse and say "I don't know".

             Relevant retrieved documents:
             Context: {context}"""
             ),
            ("human", "{question}"),
        ]
    )

    def format_docs(docs):
        return "\n\n".join([d.page_content for d in docs])

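    # Retrieve the 5 most similar chunks for every question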
    retriever = doc_search.as_retriever(search_kwargs={"k": 5})

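    # LCEL chain: retrieve and format the context, fill the prompt,
    # call the LLM, and parse the result to a plain string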
    runnable = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | model
        | StrOutputParser()
    )

    cl.user_session.set("runnable", runnable)


@cl.on_message
async def on_message(message: cl.Message):
    runnable: Runnable = cl.user_session.get("runnable")
    msg = cl.Message(content="")

    class PostMessageHandler(BaseCallbackHandler):
        """
        Callback handler that observes the retriever and LLM runs and posts
        the sources of the retrieved documents as Chainlit elements.
        """

        def __init__(self, msg: cl.Message):
            super().__init__()
            self.msg = msg
            self.sources = []  # unique (source, page) entries with their page content

        def on_retriever_end(self, documents, *, run_id, parent_run_id, **kwargs):
            for doc in documents:
                source = doc.metadata.get('source', 'Unknown Source')
                page = doc.metadata.get('page', 'N/A')
                # Deduplicate on (source, page) so each page is listed only once
                if not any(s["source"] == source and s["page"] == page for s in self.sources):
                    self.sources.append({
                        "source": source,
                        "page": page,
                        "content": doc.page_content
                    })

        def on_llm_end(self, response, *, run_id, parent_run_id, **kwargs):
            if self.sources:
                # Build a clickable, previewable Chainlit element for each source
                text_elements = []
                source_references = []
                for src in self.sources:
                    source_name = f"{src['source']} p.{src['page']}"
                    source_references.append(source_name)
                    text_elements.append(
                        cl.Text(
                            name=source_name,
                            content=src["content"],
                            display="side",
                        )
                    )
                # Append the clickable source names to the answer
                self.msg.content += "\n\nSources: " + ", ".join(source_references)

                # Attach the preview elements to the message
                self.msg.elements.extend(text_elements)

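    # Stream the answer token by token; PostMessageHandler appends the source
    # references and side previews once the LLM run finishes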
    async for chunk in runnable.astream(
            message.content,
            config=RunnableConfig(callbacks=[
                cl.LangchainCallbackHandler(),
                PostMessageHandler(msg)
            ]),
    ):
        await msg.stream_token(chunk)

    await msg.send()