Monsia commited on
Commit
3cb7480
·
unverified ·
2 Parent(s): a7aa9c3 a3b1498

Merge pull request #3 from data354/dev

Browse files
Files changed (6) hide show
  1. app.py +38 -73
  2. config.py +1 -1
  3. prompts.py +19 -3
  4. requirements.txt +2 -1
  5. scrape_data.py +23 -33
  6. utils.py +51 -0
app.py CHANGED
@@ -1,9 +1,9 @@
1
  import chainlit as cl
2
- from langchain.callbacks.base import BaseCallbackHandler
3
- from langchain.chains.query_constructor.schema import AttributeInfo
4
- from langchain.retrievers.self_query.base import SelfQueryRetriever
5
- from langchain.schema import StrOutputParser
6
- from langchain.schema.runnable import Runnable, RunnableConfig, RunnablePassthrough
7
  from langchain.vectorstores.chroma import Chroma
8
  from langchain_google_genai import (
9
  GoogleGenerativeAI,
@@ -14,21 +14,7 @@ from langchain_google_genai import (
14
 
15
  import config
16
  from prompts import prompt
17
-
18
- metadata_field_info = [
19
- AttributeInfo(
20
- name="title",
21
- description="Le titre de l'article",
22
- type="string",
23
- ),
24
- AttributeInfo(
25
- name="date",
26
- description="Date de publication",
27
- type="string",
28
- ),
29
- AttributeInfo(name="link", description="Source de l'article", type="string"),
30
- ]
31
- document_content_description = "Articles sur l'actualité."
32
 
33
  model = GoogleGenerativeAI(
34
  model=config.GOOGLE_CHAT_MODEL,
@@ -38,38 +24,36 @@ model = GoogleGenerativeAI(
38
  },
39
  ) # type: ignore
40
 
41
- # Load vector database that was persisted earlier
42
- embedding = embeddings_model = GoogleGenerativeAIEmbeddings(
43
- model="models/embedding-001", google_api_key=config.GOOGLE_API_KEY
44
  ) # type: ignore
45
 
46
- vectordb = Chroma(persist_directory=config.STORAGE_PATH, embedding_function=embedding)
47
 
48
- retriever = SelfQueryRetriever.from_llm(
49
- model,
50
- vectordb,
51
- document_content_description,
52
- metadata_field_info,
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  )
54
 
55
 
56
  @cl.on_chat_start
57
  async def on_chat_start():
58
 
59
- def format_docs(docs):
60
- return "\n\n".join(doc.page_content for doc in docs)
61
-
62
- rag_chain = (
63
- {
64
- "context": vectordb.as_retriever() | format_docs,
65
- "question": RunnablePassthrough(),
66
- }
67
- | prompt
68
- | model
69
- | StrOutputParser()
70
- )
71
-
72
- cl.user_session.set("rag_chain", rag_chain)
73
 
74
  msg = cl.Message(
75
  content=f"Vous pouvez poser vos questions sur les articles de SIKAFINANCE",
@@ -79,39 +63,20 @@ async def on_chat_start():
79
 
80
  @cl.on_message
81
  async def on_message(message: cl.Message):
82
- runnable = cl.user_session.get("rag_chain") # type: Runnable # type: ignore
83
- msg = cl.Message(content="")
84
 
85
- class PostMessageHandler(BaseCallbackHandler):
86
- """
87
- Callback handler for handling the retriever and LLM processes.
88
- Used to post the sources of the retrieved documents as a Chainlit element.
89
- """
90
-
91
- def __init__(self, msg: cl.Message):
92
- BaseCallbackHandler.__init__(self)
93
- self.msg = msg
94
- self.sources = []
95
-
96
- def on_retriever_end(self, documents, *, run_id, parent_run_id, **kwargs):
97
- for d in documents:
98
- source_doc = d.page_content + "\nSource: " + d.metadata["link"]
99
- self.sources.append(source_doc)
100
-
101
- def on_llm_end(self, response, *, run_id, parent_run_id, **kwargs):
102
- if len(self.sources):
103
- # Display the reference docs with a Text widget
104
- sources_element = [
105
- cl.Text(name=f"source_{idx+1}", content=content)
106
- for idx, content in enumerate(self.sources)
107
- ]
108
- source_names = [el.name for el in sources_element]
109
- self.msg.elements += sources_element
110
- self.msg.content += f"\nSources: {', '.join(source_names)}"
111
 
112
  async with cl.Step(type="run", name="QA Assistant"):
113
- async for chunk in runnable.astream(
114
- message.content,
 
 
 
 
115
  config=RunnableConfig(
116
  callbacks=[cl.LangchainCallbackHandler(), PostMessageHandler(msg)]
117
  ),
 
1
  import chainlit as cl
2
+ from langchain.retrievers import ParentDocumentRetriever
3
+ from langchain.schema.runnable import RunnableConfig
4
+ from langchain.storage import LocalFileStore
5
+ from langchain.storage._lc_store import create_kv_docstore
6
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
7
  from langchain.vectorstores.chroma import Chroma
8
  from langchain_google_genai import (
9
  GoogleGenerativeAI,
 
14
 
15
  import config
16
  from prompts import prompt
17
+ from utils import PostMessageHandler, format_docs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  model = GoogleGenerativeAI(
20
  model=config.GOOGLE_CHAT_MODEL,
 
24
  },
25
  ) # type: ignore
26
 
27
+ embeddings_model = GoogleGenerativeAIEmbeddings(
28
+ model=config.GOOGLE_EMBEDDING_MODEL
 
29
  ) # type: ignore
30
 
 
31
 
32
+ ## retriever
33
+ child_splitter = RecursiveCharacterTextSplitter(chunk_size=500, separators=["\n"])
34
+
35
+ # The vectorstore to use to index the child chunks
36
+ vectorstore = Chroma(
37
+ persist_directory=config.STORAGE_PATH + "vectorstore",
38
+ collection_name="full_documents",
39
+ embedding_function=embeddings_model,
40
+ )
41
+
42
+ # The storage layer for the parent documents
43
+ fs = LocalFileStore(config.STORAGE_PATH + "docstore")
44
+ store = create_kv_docstore(fs)
45
+
46
+ retriever = ParentDocumentRetriever(
47
+ vectorstore=vectorstore,
48
+ docstore=store,
49
+ child_splitter=child_splitter,
50
  )
51
 
52
 
53
  @cl.on_chat_start
54
  async def on_chat_start():
55
 
56
+ cl.user_session.set("retriever", retriever)
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
  msg = cl.Message(
59
  content=f"Vous pouvez poser vos questions sur les articles de SIKAFINANCE",
 
63
 
64
  @cl.on_message
65
  async def on_message(message: cl.Message):
 
 
66
 
67
+ # retriever = cl.user_session.get("retriever")
68
+
69
+ chain = prompt | model
70
+
71
+ msg = cl.Message(content="")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
  async with cl.Step(type="run", name="QA Assistant"):
74
+
75
+ question = message.content
76
+ context = format_docs(retriever.get_relevant_documents(question))
77
+
78
+ async for chunk in chain.astream(
79
+ input={"context": context, "question": question},
80
  config=RunnableConfig(
81
  callbacks=[cl.LangchainCallbackHandler(), PostMessageHandler(msg)]
82
  ),
config.py CHANGED
@@ -3,7 +3,7 @@ import os
3
  GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
4
  GOOGLE_CHAT_MODEL = "gemini-pro"
5
  GOOGLE_EMBEDDING_MODEL = "models/embedding-001"
6
- STORAGE_PATH = "data/chroma/"
7
  HIISTORY_FILE = "./data/qa_history.txt"
8
 
9
  NUM_DAYS_PAST = 30
 
3
  GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
4
  GOOGLE_CHAT_MODEL = "gemini-pro"
5
  GOOGLE_EMBEDDING_MODEL = "models/embedding-001"
6
+ STORAGE_PATH = "./data/"
7
  HIISTORY_FILE = "./data/qa_history.txt"
8
 
9
  NUM_DAYS_PAST = 30
prompts.py CHANGED
@@ -1,11 +1,27 @@
1
  from langchain.prompts import ChatPromptTemplate
2
 
3
  template = """
4
- Répondez à la question en vous basant uniquement sur le contexte suivant:
 
 
5
 
6
- {context}
7
 
8
- Question : {question}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  """
11
 
 
1
  from langchain.prompts import ChatPromptTemplate
2
 
3
  template = """
4
+ Vous êtes un assistant de recherche économique et financière, spécialement conçu pour répondre aux questions liées à l'économie et à la finance et pour aider à l'informations et la prise de décisions financières. Votre rôle consiste à analyser les articles et rapports d'actualité économique et financière qui vous sera fournis dans le contexte et à répondre de manière adequate aux questions spécifiques des utilisateurs. Lorsque vous répondez aux questions :
5
+ - Pour des questions d'ordre générales (ex: "Quelle est l'actualité du jour?") : Lisez attentivement tous les articles et résumez les points\évènements clés en mentionnant les dates de publications.
6
+ - Pour des questions spécifiques (ex: "Quelle est la tendance du marché boursier aujourd'hui?") : Recherchez les informations spécifiques à la question dans les articles.
7
 
8
+ -N'hésitez pas à utiliser vos connaissances et votre bon sens pour répondre aux questions.
9
 
10
+ - Basez vos réponses sur les articles d'actualité fournis. Citez directement les parties pertinentes de ces documents pour étayer vos réponses.
11
+ - Citez clairement les références, y compris les titres des articles, les dates de publication et tout autre détail pertinent, afin de vous assurer que les informations peuvent être facilement vérifiées et retracées jusqu'aux sources originales.
12
+
13
+ - Si la question sort du cadre des documents fournis ou si vous ne trouvez pas d'informations pertinentes, indiquez poliment que la réponse ne peut être déterminée sur la base des sources disponibles. Suggérez de consulter d'autres articles d'actualité financière ou des bases de données pour obtenir une réponse complète, le cas échéant.
14
+ - Insistez sur l'exactitude et la fiabilité de vos réponses, en comprenant la nature critique de votre aide dans les processus de prise de décision financière.
15
+ - Répondez aux utilisateurs dans la langue de leur question. Si la question est en français, votre réponse doit être en français. Si la question est en anglais, votre réponse doit être en anglais.
16
+ - Pour des question en relative à la date veuillez considerer qu'aujourd'hui est le Jeudi 11/04/2024. Par exemple pour repondre à une question sur l'actualité du jour, vous devez effectuer une comparaison entre les date de publications des articles et celle d'aujourdui pour filtrer sur les articles puis retourner les informations pertinantes.
17
+
18
+ <contexte>
19
+ ``{context}``
20
+ </contexte>
21
+
22
+ <question>
23
+ {question}
24
+ </question>
25
 
26
  """
27
 
requirements.txt CHANGED
@@ -4,4 +4,5 @@ chainlit==1.0.500
4
  chromadb==0.4.24
5
  lark==1.1.9
6
  bs4==0.0.2
7
- selenium==4.19.0
 
 
4
  chromadb==0.4.24
5
  lark==1.1.9
6
  bs4==0.0.2
7
+ selenium==4.19.0
8
+ tiktoken==0.1.1
scrape_data.py CHANGED
@@ -2,7 +2,9 @@ import os
2
  from datetime import date, timedelta
3
 
4
  import bs4
5
- from langchain.indexes import SQLRecordManager, index
 
 
6
  from langchain.text_splitter import RecursiveCharacterTextSplitter
7
  from langchain.vectorstores.chroma import Chroma
8
  from langchain_community.document_loaders import WebBaseLoader
@@ -81,7 +83,7 @@ def set_metadata(documents, metadatas):
81
 
82
 
83
  def process_docs(
84
- articles, persist_directory, embeddings_model, chunk_size=1000, chunk_overlap=100
85
  ):
86
  """
87
  #Scrap all articles urls content and save on a vector DB
@@ -105,45 +107,33 @@ def process_docs(
105
  # Update metadata: add title,
106
  set_metadata(documents=docs, metadatas=articles)
107
 
108
- print("Successfully loaded to document")
109
 
110
- text_splitter = RecursiveCharacterTextSplitter(
111
- chunk_size=chunk_size, chunk_overlap=chunk_overlap, separators=["\n"]
112
- )
113
- splits = text_splitter.split_documents(docs)
114
-
115
- # Create the storage path if it doesn't exist
116
- if not os.path.exists(persist_directory):
117
- os.makedirs(persist_directory)
118
 
119
- doc_search = Chroma.from_documents(
120
- documents=splits,
121
- embedding=embeddings_model,
122
- persist_directory=persist_directory,
 
123
  )
124
 
125
- # Indexing data
126
- namespace = "chromadb/my_documents"
127
- record_manager = SQLRecordManager(
128
- namespace, db_url="sqlite:///record_manager_cache.sql"
129
- )
130
- record_manager.create_schema()
131
-
132
- index_result = index(
133
- docs,
134
- record_manager,
135
- doc_search,
136
- cleanup="incremental",
137
- source_id_key="link",
138
- )
139
 
140
- print(f"Indexing stats: {index_result}")
 
 
 
 
141
 
142
- return doc_search
 
143
 
144
 
145
  if __name__ == "__main__":
146
 
147
  data = scrap_articles(DATA_URL, num_days_past=config.NUM_DAYS_PAST)
148
- vectordb = process_docs(data, config.STORAGE_PATH, embeddings_model)
149
- ret = vectordb.as_retriever()
 
2
  from datetime import date, timedelta
3
 
4
  import bs4
5
+ from langchain.retrievers import ParentDocumentRetriever
6
+ from langchain.storage import LocalFileStore
7
+ from langchain.storage._lc_store import create_kv_docstore
8
  from langchain.text_splitter import RecursiveCharacterTextSplitter
9
  from langchain.vectorstores.chroma import Chroma
10
  from langchain_community.document_loaders import WebBaseLoader
 
83
 
84
 
85
  def process_docs(
86
+ articles, persist_directory, embeddings_model, chunk_size=500, chunk_overlap=0
87
  ):
88
  """
89
  #Scrap all articles urls content and save on a vector DB
 
107
  # Update metadata: add title,
108
  set_metadata(documents=docs, metadatas=articles)
109
 
110
+ # print("Successfully loaded to document")
111
 
112
+ # This text splitter is used to create the child documents
113
+ child_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap, separators=["\n"])
 
 
 
 
 
 
114
 
115
+ # The vectorstore to use to index the child chunks
116
+ vectorstore = Chroma(
117
+ persist_directory=persist_directory + "vectorstore",
118
+ collection_name="full_documents",
119
+ embedding_function=embeddings_model,
120
  )
121
 
122
+ # The storage layer for the parent documents
123
+ fs = LocalFileStore(persist_directory + "docstore")
124
+ store = create_kv_docstore(fs)
 
 
 
 
 
 
 
 
 
 
 
125
 
126
+ retriever = ParentDocumentRetriever(
127
+ vectorstore=vectorstore,
128
+ docstore=store,
129
+ child_splitter=child_splitter,
130
+ )
131
 
132
+ retriever.add_documents(docs, ids=None)
133
+ print(len(docs), " documents added")
134
 
135
 
136
  if __name__ == "__main__":
137
 
138
  data = scrap_articles(DATA_URL, num_days_past=config.NUM_DAYS_PAST)
139
+ process_docs(data, config.STORAGE_PATH, embeddings_model)
 
utils.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import chainlit as cl
2
+ import tiktoken
3
+ from langchain.callbacks.base import BaseCallbackHandler
4
+
5
+
6
+ def format_docs(documents, max_context_size=100000, separator="\n\n"):
7
+ context = ""
8
+ encoder = tiktoken.get_encoding("cl100k_base")
9
+ i = 0
10
+ for doc in documents:
11
+ i += 1
12
+ if len(encoder.encode(context)) < max_context_size:
13
+ source = doc.metadata["link"]
14
+ title = doc.metadata["title"]
15
+ context += (
16
+ f"Article: {title}\n" + doc.page_content + f"\nSource: {source}" + separator
17
+ )
18
+ return context
19
+
20
+
21
+ class PostMessageHandler(BaseCallbackHandler):
22
+ """
23
+ Callback handler for handling the retriever and LLM processes.
24
+ Used to post the sources of the retrieved documents as a Chainlit element.
25
+ """
26
+
27
+ def __init__(self, msg: cl.Message):
28
+ BaseCallbackHandler.__init__(self)
29
+ self.msg = msg
30
+ self.sources = []
31
+
32
+ def on_retriever_end(self, documents, *, run_id, parent_run_id, **kwargs):
33
+ for d in documents:
34
+ source_doc = d.page_content + "\nSource: " + d.metadata["link"]
35
+ self.sources.append(source_doc)
36
+
37
+ def on_llm_end(self, response, *, run_id, parent_run_id, **kwargs):
38
+ if len(self.sources):
39
+ # Display the reference docs with a Text widget
40
+ sources_element = [
41
+ cl.Text(name=f"source_{idx+1}", content=content)
42
+ for idx, content in enumerate(self.sources)
43
+ ]
44
+ source_names = [el.name for el in sources_element]
45
+ self.msg.elements += sources_element
46
+ self.msg.content += f"\nSources: {', '.join(source_names)}"
47
+
48
+ def clean_text(text):
49
+ tx = text.replace("Tweet","")
50
+ tx = tx.replace("\n\n\n\n\n\n\n\n\n","")
51
+ return tx