refactor: code refactoring
- app.py +6 -50
- scrape_data.py +1 -1
- utils.py +45 -0
app.py
CHANGED
@@ -1,9 +1,9 @@
 import chainlit as cl
-from langchain.callbacks.base import BaseCallbackHandler
-from langchain.chains.query_constructor.schema import AttributeInfo
-from langchain.retrievers.self_query.base import SelfQueryRetriever
+from langchain.retrievers import ParentDocumentRetriever
 from langchain.schema import StrOutputParser
 from langchain.schema.runnable import Runnable, RunnableConfig, RunnablePassthrough
+from langchain.storage import InMemoryStore
+from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.vectorstores.chroma import Chroma
 from langchain_google_genai import (
     GoogleGenerativeAI,
@@ -11,12 +11,10 @@ from langchain_google_genai import (
     HarmBlockThreshold,
     HarmCategory,
 )
-
-from langchain.retrievers import ParentDocumentRetriever
-from langchain.storage import InMemoryStore
+
 import config
 from prompts import prompt
-import tiktoken
+from utils import PostMessageHandler, format_docs
 
 model = GoogleGenerativeAI(
     model=config.GOOGLE_CHAT_MODEL,
@@ -34,34 +32,20 @@ embedding = embeddings_model = GoogleGenerativeAIEmbeddings(
 vectordb = Chroma(persist_directory=config.STORAGE_PATH, embedding_function=embedding)
 
 ## retriever
-
 text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, separators=["\n"])
 
-
 # The storage layer for the parent documents
 store = InMemoryStore()
 retriever = ParentDocumentRetriever(
     vectorstore=vectordb,
     docstore=store,
     child_splitter=text_splitter,
-    )
-
+)
 
 
 @cl.on_chat_start
 async def on_chat_start():
 
-    def format_docs(documents, max_context_size= 100000, separator= "\n\n"):
-        context = ""
-        encoder = tiktoken.get_encoding("cl100k_base")
-        i=0
-        for doc in documents:
-            i+=1
-            if len(encoder.encode(context)) < max_context_size:
-                source = doc.metadata['link']
-                context += f"Article{i}:\n"+doc.page_content + f"\nSource: {source}" + separator
-        return context
-
     rag_chain = (
         {
             "context": retriever | format_docs,
@@ -71,7 +55,6 @@ async def on_chat_start():
         | model
         | StrOutputParser()
     )
-
 
     cl.user_session.set("rag_chain", rag_chain)
 
@@ -86,33 +69,6 @@ async def on_message(message: cl.Message):
     runnable = cl.user_session.get("rag_chain") # type: Runnable # type: ignore
     msg = cl.Message(content="")
 
-    class PostMessageHandler(BaseCallbackHandler):
-        """
-        Callback handler for handling the retriever and LLM processes.
-        Used to post the sources of the retrieved documents as a Chainlit element.
-        """
-
-        def __init__(self, msg: cl.Message):
-            BaseCallbackHandler.__init__(self)
-            self.msg = msg
-            self.sources = []
-
-        def on_retriever_end(self, documents, *, run_id, parent_run_id, **kwargs):
-            for d in documents:
-                source_doc = d.page_content + "\nSource: " + d.metadata["link"]
-                self.sources.append(source_doc)
-
-        def on_llm_end(self, response, *, run_id, parent_run_id, **kwargs):
-            if len(self.sources):
-                # Display the reference docs with a Text widget
-                sources_element = [
-                    cl.Text(name=f"source_{idx+1}", content=content)
-                    for idx, content in enumerate(self.sources)
-                ]
-                source_names = [el.name for el in sources_element]
-                self.msg.elements += sources_element
-                self.msg.content += f"\nSources: {', '.join(source_names)}"
-
     async with cl.Step(type="run", name="QA Assistant"):
         async for chunk in runnable.astream(
             message.content,
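The rest of on_message is truncated in this view. For reference, a minimal sketch of how the imported PostMessageHandler is typically wired into the streaming call — the RunnableConfig callbacks wiring, the stream_token/send calls, and the exact session key are assumptions, not something this diff shows:

    import chainlit as cl
    from langchain.schema.runnable import Runnable, RunnableConfig

    from utils import PostMessageHandler


    @cl.on_message
    async def on_message(message: cl.Message):
        runnable: Runnable = cl.user_session.get("rag_chain")  # key set in on_chat_start
        msg = cl.Message(content="")
        async with cl.Step(type="run", name="QA Assistant"):
            # Passing the handler as a callback lets on_retriever_end and
            # on_llm_end fire while the chain streams, so PostMessageHandler
            # can collect sources and append them to the reply.
            async for chunk in runnable.astream(
                message.content,
                config=RunnableConfig(callbacks=[PostMessageHandler(msg)]),
            ):
                await msg.stream_token(chunk)
        await msg.send()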
scrape_data.py
CHANGED
@@ -120,7 +120,7 @@ def process_docs(
         documents=splits,
         embedding=embeddings_model,
         persist_directory=persist_directory,
-        )
+    )
 
     return doc_search
 
utils.py
ADDED
@@ -0,0 +1,45 @@
+import chainlit as cl
+import tiktoken
+from langchain.callbacks.base import BaseCallbackHandler
+
+
+def format_docs(documents, max_context_size=100000, separator="\n\n"):
+    context = ""
+    encoder = tiktoken.get_encoding("cl100k_base")
+    i = 0
+    for doc in documents:
+        i += 1
+        if len(encoder.encode(context)) < max_context_size:
+            source = doc.metadata["link"]
+            context += (
+                f"Article{i}:\n" + doc.page_content + f"\nSource: {source}" + separator
+            )
+    return context
+
+
+class PostMessageHandler(BaseCallbackHandler):
+    """
+    Callback handler for handling the retriever and LLM processes.
+    Used to post the sources of the retrieved documents as a Chainlit element.
+    """
+
+    def __init__(self, msg: cl.Message):
+        BaseCallbackHandler.__init__(self)
+        self.msg = msg
+        self.sources = []
+
+    def on_retriever_end(self, documents, *, run_id, parent_run_id, **kwargs):
+        for d in documents:
+            source_doc = d.page_content + "\nSource: " + d.metadata["link"]
+            self.sources.append(source_doc)
+
+    def on_llm_end(self, response, *, run_id, parent_run_id, **kwargs):
+        if len(self.sources):
+            # Display the reference docs with a Text widget
+            sources_element = [
+                cl.Text(name=f"source_{idx+1}", content=content)
+                for idx, content in enumerate(self.sources)
+            ]
+            source_names = [el.name for el in sources_element]
+            self.msg.elements += sources_element
+            self.msg.content += f"\nSources: {', '.join(source_names)}"
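A hypothetical quick check of format_docs (the Document construction and example.com links are made up; the "link" metadata key matches what the handler above expects). Note that the token-budget check runs before each append, so the last appended article can overshoot max_context_size, and once the accumulated context exceeds the budget every remaining document is skipped:

    from langchain.schema import Document

    from utils import format_docs

    docs = [
        Document(page_content="First article body.", metadata={"link": "https://example.com/a"}),
        Document(page_content="Second article body.", metadata={"link": "https://example.com/b"}),
    ]

    # The default budget keeps both articles; a budget of 1 cl100k_base token
    # keeps only the first, because the check happens before each append.
    print(format_docs(docs))
    print(format_docs(docs, max_context_size=1))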