Monsia committed on
Commit 0dfba83 · 1 Parent(s): 0c69aa1

refactor: code refactoring

Files changed (3)
  1. app.py +6 -50
  2. scrape_data.py +1 -1
  3. utils.py +45 -0
app.py CHANGED
@@ -1,9 +1,9 @@
 import chainlit as cl
-from langchain.callbacks.base import BaseCallbackHandler
-from langchain.chains.query_constructor.schema import AttributeInfo
-from langchain.retrievers.self_query.base import SelfQueryRetriever
+from langchain.retrievers import ParentDocumentRetriever
 from langchain.schema import StrOutputParser
 from langchain.schema.runnable import Runnable, RunnableConfig, RunnablePassthrough
+from langchain.storage import InMemoryStore
+from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.vectorstores.chroma import Chroma
 from langchain_google_genai import (
     GoogleGenerativeAI,
@@ -11,12 +11,10 @@ from langchain_google_genai import (
     HarmBlockThreshold,
     HarmCategory,
 )
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain.retrievers import ParentDocumentRetriever
-from langchain.storage import InMemoryStore
+
 import config
 from prompts import prompt
-import tiktoken
+from utils import PostMessageHandler, format_docs

 model = GoogleGenerativeAI(
     model=config.GOOGLE_CHAT_MODEL,
@@ -34,34 +32,20 @@ embedding = embeddings_model = GoogleGenerativeAIEmbeddings(
 vectordb = Chroma(persist_directory=config.STORAGE_PATH, embedding_function=embedding)

 ## retriever
-
 text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, separators=["\n"])

-
 # The storage layer for the parent documents
 store = InMemoryStore()
 retriever = ParentDocumentRetriever(
     vectorstore=vectordb,
     docstore=store,
     child_splitter=text_splitter,
-    )
-
+)


 @cl.on_chat_start
 async def on_chat_start():

-    def format_docs(documents, max_context_size= 100000, separator= "\n\n"):
-        context = ""
-        encoder = tiktoken.get_encoding("cl100k_base")
-        i=0
-        for doc in documents:
-            i+=1
-            if len(encoder.encode(context)) < max_context_size:
-                source = doc.metadata['link']
-                context += f"Article{i}:\n"+doc.page_content + f"\nSource: {source}" + separator
-        return context
-
     rag_chain = (
         {
             "context": retriever | format_docs,
@@ -71,7 +55,6 @@ async def on_chat_start():
         | model
         | StrOutputParser()
     )
-

     cl.user_session.set("rag_chain", rag_chain)

@@ -86,33 +69,6 @@ async def on_message(message: cl.Message):
     runnable = cl.user_session.get("rag_chain")  # type: Runnable # type: ignore
     msg = cl.Message(content="")

-    class PostMessageHandler(BaseCallbackHandler):
-        """
-        Callback handler for handling the retriever and LLM processes.
-        Used to post the sources of the retrieved documents as a Chainlit element.
-        """
-
-        def __init__(self, msg: cl.Message):
-            BaseCallbackHandler.__init__(self)
-            self.msg = msg
-            self.sources = []
-
-        def on_retriever_end(self, documents, *, run_id, parent_run_id, **kwargs):
-            for d in documents:
-                source_doc = d.page_content + "\nSource: " + d.metadata["link"]
-                self.sources.append(source_doc)
-
-        def on_llm_end(self, response, *, run_id, parent_run_id, **kwargs):
-            if len(self.sources):
-                # Display the reference docs with a Text widget
-                sources_element = [
-                    cl.Text(name=f"source_{idx+1}", content=content)
-                    for idx, content in enumerate(self.sources)
-                ]
-                source_names = [el.name for el in sources_element]
-                self.msg.elements += sources_element
-                self.msg.content += f"\nSources: {', '.join(source_names)}"
-
     async with cl.Step(type="run", name="QA Assistant"):
         async for chunk in runnable.astream(
             message.content,
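
Note: the last hunk above ends mid-call at message.content,. A minimal sketch of how the streaming call presumably continues, assuming PostMessageHandler (now imported from utils) is attached through the callbacks of a RunnableConfig alongside Chainlit's LangChain callback handler; the exact wiring is not shown in this diff:

    # Sketch (assumption) of the remainder of on_message: stream the answer and
    # let PostMessageHandler append the retrieved sources to the reply.
    async with cl.Step(type="run", name="QA Assistant"):
        async for chunk in runnable.astream(
            message.content,
            config=RunnableConfig(
                callbacks=[cl.LangchainCallbackHandler(), PostMessageHandler(msg)]
            ),
        ):
            await msg.stream_token(chunk)

    await msg.send()
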
scrape_data.py CHANGED
@@ -120,7 +120,7 @@ def process_docs(
         documents=splits,
         embedding=embeddings_model,
         persist_directory=persist_directory,
-        )
+    )

     return doc_search

utils.py ADDED
@@ -0,0 +1,45 @@
+import chainlit as cl
+import tiktoken
+from langchain.callbacks.base import BaseCallbackHandler
+
+
+def format_docs(documents, max_context_size=100000, separator="\n\n"):
+    context = ""
+    encoder = tiktoken.get_encoding("cl100k_base")
+    i = 0
+    for doc in documents:
+        i += 1
+        if len(encoder.encode(context)) < max_context_size:
+            source = doc.metadata["link"]
+            context += (
+                f"Article{i}:\n" + doc.page_content + f"\nSource: {source}" + separator
+            )
+    return context
+
+
+class PostMessageHandler(BaseCallbackHandler):
+    """
+    Callback handler for handling the retriever and LLM processes.
+    Used to post the sources of the retrieved documents as a Chainlit element.
+    """
+
+    def __init__(self, msg: cl.Message):
+        BaseCallbackHandler.__init__(self)
+        self.msg = msg
+        self.sources = []
+
+    def on_retriever_end(self, documents, *, run_id, parent_run_id, **kwargs):
+        for d in documents:
+            source_doc = d.page_content + "\nSource: " + d.metadata["link"]
+            self.sources.append(source_doc)
+
+    def on_llm_end(self, response, *, run_id, parent_run_id, **kwargs):
+        if len(self.sources):
+            # Display the reference docs with a Text widget
+            sources_element = [
+                cl.Text(name=f"source_{idx+1}", content=content)
+                for idx, content in enumerate(self.sources)
+            ]
+            source_names = [el.name for el in sources_element]
+            self.msg.elements += sources_element
+            self.msg.content += f"\nSources: {', '.join(source_names)}"
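
For reference, a minimal standalone sketch of how format_docs behaves; the Document objects and link values below are made up for illustration:

    from langchain.schema import Document

    from utils import format_docs

    docs = [
        Document(page_content="First article body", metadata={"link": "https://example.com/a"}),
        Document(page_content="Second article body", metadata={"link": "https://example.com/b"}),
    ]
    # Joins numbered articles with their source links, skipping further articles
    # once the tiktoken-measured size of the accumulated context exceeds
    # max_context_size.
    context = format_docs(docs, max_context_size=1000)
    print(context)
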