Monsia committed
Commit c4331f2 · 0 Parent(s)

first commit

Files changed (13)
  1. .gitattributes +35 -0
  2. .gitignore +119 -0
  3. Dockerfile +14 -0
  4. README.md +25 -0
  5. app.py +121 -0
  6. chainlit.md +5 -0
  7. config.py +7 -0
  8. prompts.py +12 -0
  9. public/favicon.png +0 -0
  10. public/logo_dark.png +0 -0
  11. public/logo_light.png +0 -0
  12. requirements.txt +7 -0
  13. scrape_data.py +149 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,119 @@
+ # Created by .ignore support plugin (hsz.mobi)
+ ### Python template
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # pyenv
+ .python-version
+
+ # celery beat schedule file
+ celerybeat-schedule
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+
+ .idea/*
+ .files/*
+
+ tmp
+ secret.*
+ volumes/
+ .chainlit
+ .DS_Store
+ __init__.py
+ data/
+ data*
+ record_manager_cache.sql
Dockerfile ADDED
@@ -0,0 +1,14 @@
+ FROM python:3.11
+ RUN useradd -m -u 1000 user
+ USER user
+ ENV HOME=/home/user \
+     PATH=/home/user/.local/bin:$PATH
+ WORKDIR $HOME/app
+ COPY --chown=user . $HOME/app
+ COPY ./requirements.txt $HOME/app/requirements.txt
+ RUN pip install -r requirements.txt
+ COPY --chown=user . .
+ RUN --mount=type=secret,id=GOOGLE_API_KEY,mode=0444,required=true \
+     export GOOGLE_API_KEY=$(cat /run/secrets/GOOGLE_API_KEY) && \
+     python scrape_data.py
+ CMD ["chainlit", "run", "app.py", "--port", "7860"]
README.md ADDED
@@ -0,0 +1,25 @@
+ ---
+ title: FinChat
+ emoji: 🤑
+ colorFrom: yellow
+ colorTo: purple
+ sdk: docker
+ pinned: false
+ license: apache-2.0
+ ---
+
+ # FinChat
+
+ FinChat is a chatbot built by [data354](https://data354.com/) to answer questions about economic and financial news.
+
+
+ ## How to run?
+
+ 1. Run the script that scrapes and stores the data:
+ ```shell
+ python scrape_data.py
+ ```
+ 2. Launch the demo and start interacting with the agent.
+
+ ```shell
+ chainlit run app.py
+ ```
app.py ADDED
@@ -0,0 +1,121 @@
+ import chainlit as cl
+ from langchain.callbacks.base import BaseCallbackHandler
+ from langchain.chains.query_constructor.schema import AttributeInfo
+ from langchain.retrievers.self_query.base import SelfQueryRetriever
+ from langchain.schema import StrOutputParser
+ from langchain.schema.runnable import Runnable, RunnableConfig, RunnablePassthrough
+ from langchain.vectorstores.chroma import Chroma
+ from langchain_google_genai import (
+     GoogleGenerativeAI,
+     GoogleGenerativeAIEmbeddings,
+     HarmBlockThreshold,
+     HarmCategory,
+ )
+
+ import config
+ from prompts import prompt
+
+ metadata_field_info = [
+     AttributeInfo(
+         name="title",
+         description="Le titre de l'article",
+         type="string",
+     ),
+     AttributeInfo(
+         name="date",
+         description="Date de publication",
+         type="string",
+     ),
+     AttributeInfo(name="link", description="Source de l'article", type="string"),
+ ]
+ document_content_description = "Articles sur l'actualité."
+
+ model = GoogleGenerativeAI(
+     model=config.GOOGLE_CHAT_MODEL,
+     google_api_key=config.GOOGLE_API_KEY,
+     safety_settings={
+         HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
+     },
+ )  # type: ignore
+
+ # Load vector database that was persisted earlier
+ embedding = embeddings_model = GoogleGenerativeAIEmbeddings(
+     model="models/embedding-001", google_api_key=config.GOOGLE_API_KEY
+ )  # type: ignore
+
+ vectordb = Chroma(persist_directory=config.STORAGE_PATH, embedding_function=embedding)
+
+ retriever = SelfQueryRetriever.from_llm(
+     model,
+     vectordb,
+     document_content_description,
+     metadata_field_info,
+ )
+
+
+ @cl.on_chat_start
+ async def on_chat_start():
+
+     def format_docs(docs):
+         return "\n\n".join(doc.page_content for doc in docs)
+
+     rag_chain = (
+         {
+             "context": vectordb.as_retriever() | format_docs,
+             "question": RunnablePassthrough(),
+         }
+         | prompt
+         | model
+         | StrOutputParser()
+     )
+
+     cl.user_session.set("rag_chain", rag_chain)
+
+     msg = cl.Message(
+         content=f"Vous pouvez poser vos questions sur les articles de SIKAFINANCE",
+     )
+     await msg.send()
+
+
+ @cl.on_message
+ async def on_message(message: cl.Message):
+     runnable = cl.user_session.get("rag_chain")  # type: Runnable # type: ignore
+     msg = cl.Message(content="")
+
+     class PostMessageHandler(BaseCallbackHandler):
+         """
+         Callback handler for handling the retriever and LLM processes.
+         Used to post the sources of the retrieved documents as a Chainlit element.
+         """
+
+         def __init__(self, msg: cl.Message):
+             BaseCallbackHandler.__init__(self)
+             self.msg = msg
+             self.sources = []
+
+         def on_retriever_end(self, documents, *, run_id, parent_run_id, **kwargs):
+             for d in documents:
+                 source_doc = d.page_content + "\nSource: " + d.metadata["link"]
+                 self.sources.append(source_doc)
+
+         def on_llm_end(self, response, *, run_id, parent_run_id, **kwargs):
+             if len(self.sources):
+                 # Display the reference docs with a Text widget
+                 sources_element = [
+                     cl.Text(name=f"source_{idx+1}", content=content)
+                     for idx, content in enumerate(self.sources)
+                 ]
+                 source_names = [el.name for el in sources_element]
+                 self.msg.elements += sources_element
+                 self.msg.content += f"\nSources: {', '.join(source_names)}"
+
+     async with cl.Step(type="run", name="QA Assistant"):
+         async for chunk in runnable.astream(
+             message.content,
+             config=RunnableConfig(
+                 callbacks=[cl.LangchainCallbackHandler(), PostMessageHandler(msg)]
+             ),
+         ):
+             await msg.stream_token(chunk)
+
+     await msg.send()
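
For a quick smoke test outside the Chainlit UI, the same chain can be rebuilt and queried from a plain script. A minimal sketch, assuming `scrape_data.py` has already populated the Chroma store under `config.STORAGE_PATH` and `GOOGLE_API_KEY` is set; the question is illustrative only:

```python
# Sketch only: rebuild the RAG chain from app.py and query it once from the CLI.
from langchain.schema import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
from langchain.vectorstores.chroma import Chroma
from langchain_google_genai import GoogleGenerativeAI, GoogleGenerativeAIEmbeddings

import config
from prompts import prompt

embedding = GoogleGenerativeAIEmbeddings(
    model=config.GOOGLE_EMBEDDING_MODEL, google_api_key=config.GOOGLE_API_KEY
)
vectordb = Chroma(persist_directory=config.STORAGE_PATH, embedding_function=embedding)
model = GoogleGenerativeAI(
    model=config.GOOGLE_CHAT_MODEL, google_api_key=config.GOOGLE_API_KEY
)

rag_chain = (
    {
        # Plain functions/lambdas are coerced to runnables when piped after a retriever.
        "context": vectordb.as_retriever()
        | (lambda docs: "\n\n".join(d.page_content for d in docs)),
        "question": RunnablePassthrough(),
    }
    | prompt
    | model
    | StrOutputParser()
)

# Illustrative question; any string related to the indexed articles works.
print(rag_chain.invoke("Quelle est la dernière actualité de la BRVM ?"))
```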
chainlit.md ADDED
@@ -0,0 +1,5 @@
+ # Welcome to FinChat! 🚀🤖
+
+ FinChat is a chatbot built by [data354](https://data354.com/) to answer questions about economic and financial news.
+
+ > That's it! You can now ask your questions 💻😊.
config.py ADDED
@@ -0,0 +1,7 @@
+ import os
+
+ GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
+ GOOGLE_CHAT_MODEL = "gemini-pro"
+ GOOGLE_EMBEDDING_MODEL = "models/embedding-001"
+ STORAGE_PATH = "data/chroma/"
+ HISTORY_FILE = "./data/qa_history.txt"
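
config.py only reads `GOOGLE_API_KEY` from the environment, so the key has to be exported before app.py or scrape_data.py runs. A small sketch of a guard you could run first (not part of the repo):

```python
# Sketch: fail fast if the Google API key is missing before the app starts.
import os
import sys

if not os.getenv("GOOGLE_API_KEY"):
    sys.exit("GOOGLE_API_KEY is not set; export it before running the app.")

import config  # noqa: E402  (imported after the environment check on purpose)

print(f"Chat model: {config.GOOGLE_CHAT_MODEL}, storage: {config.STORAGE_PATH}")
```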
prompts.py ADDED
@@ -0,0 +1,12 @@
+ from langchain.prompts import ChatPromptTemplate
+
+ template = """
+ Répondez à la question en vous basant uniquement sur le contexte suivant:
+
+ {context}
+
+ Question : {question}
+
+ """
+
+ prompt = ChatPromptTemplate.from_template(template)
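
To see what the model actually receives, the template can be rendered by hand. A small sketch with made-up context and question values:

```python
from prompts import prompt

# Illustrative values only; in the app they come from the retriever and the user.
messages = prompt.format_messages(
    context="La BRVM a clôturé en hausse de 0,5 % ce lundi.",
    question="Comment la BRVM a-t-elle clôturé lundi ?",
)
print(messages[0].content)
```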
public/favicon.png ADDED
public/logo_dark.png ADDED
public/logo_light.png ADDED
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ langchain==0.1.14
+ langchain-google-genai==1.0.1
+ chainlit==1.0.500
+ chromadb==0.4.24
+ lark==1.1.9
+ bs4==0.0.2
+ selenium==4.19.0
scrape_data.py ADDED
@@ -0,0 +1,149 @@
+ import os
+ from datetime import date, timedelta
+
+ import bs4
+ from langchain.indexes import SQLRecordManager, index
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.vectorstores.chroma import Chroma
+ from langchain_community.document_loaders import WebBaseLoader
+ from langchain_google_genai import GoogleGenerativeAIEmbeddings
+ from selenium import webdriver
+ from selenium.webdriver.common.by import By
+ from selenium.webdriver.support import expected_conditions as EC
+ from selenium.webdriver.support.ui import WebDriverWait
+
+ import config
+
+ DATA_URL = "https://www.sikafinance.com/marches/actualites_bourse_brvm"
+
+ embeddings_model = GoogleGenerativeAIEmbeddings(
+     model=config.GOOGLE_EMBEDDING_MODEL
+ )  # type: ignore
+
+
+ options = webdriver.ChromeOptions()
+ options.add_argument("--headless")
+ options.add_argument("--no-sandbox")
+ options.add_argument("--disable-dev-shm-usage")
+ driver = webdriver.Chrome(options=options)
+
+
+ def scrap_articles(
+     url="https://www.sikafinance.com/marches/actualites_bourse_brvm", num_days_past=5
+ ):
+
+     today = date.today()
+
+     driver.get(url)
+
+     all_articles = []
+     for i in range(num_days_past + 1):
+         past_date = today - timedelta(days=i)
+         date_str = past_date.strftime("%Y-%m-%d")
+         WebDriverWait(driver, 10).until(
+             EC.presence_of_element_located((By.ID, "dateActu"))
+         )
+         text_box = driver.find_element(By.ID, "dateActu")
+         text_box.send_keys(date_str)
+
+         submit_btn = WebDriverWait(driver, 10).until(
+             EC.element_to_be_clickable((By.ID, "btn"))
+         )
+         submit_btn.click()
+
+         dates = driver.find_elements(By.CLASS_NAME, "sp1")
+         titles = driver.find_elements(By.XPATH, "//td/a")
+
+         articles = []
+         for i in range(len(titles)):
+             art = {
+                 "title": titles[i].text.strip(),
+                 "date": dates[i].text,
+                 "link": titles[i].get_attribute("href"),
+             }
+             articles.append(art)
+
+         all_articles += articles
+     # driver.quit()
+
+     return all_articles
+
+
+ def set_metadata(documents, metadatas):
+     """
+     Edit the metadata of LangChain Document objects.
+     """
+     for idx, doc in enumerate(documents):
+         doc.metadata = metadatas[idx]
+     print("Metadata successfully changed")
+     print(documents[0].metadata)
+
+
+ def process_docs(
+     articles, persist_directory, embeddings_model, chunk_size=1000, chunk_overlap=100
+ ):
+     """
+     Scrape the content of every article URL and save it to a vector DB.
+     """
+     article_urls = [a["link"] for a in articles]
+
+     print("Starting to scrape ..")
+
+     loader = WebBaseLoader(
+         web_paths=article_urls,
+         bs_kwargs=dict(
+             parse_only=bs4.SoupStrainer(
+                 class_=("inarticle txtbig", "dt_sign", "innerUp")
+             )
+         ),
+     )
+
+     print("Loading scraped pages ..")
+     docs = loader.load()
+
+     # Update metadata: add title, date and link
+     set_metadata(documents=docs, metadatas=articles)
+
+     print("Successfully loaded the documents")
+
+     text_splitter = RecursiveCharacterTextSplitter(
+         chunk_size=chunk_size, chunk_overlap=chunk_overlap, separators=["\n"]
+     )
+     splits = text_splitter.split_documents(docs)
+
+     # Create the storage path if it doesn't exist
+     if not os.path.exists(persist_directory):
+         os.makedirs(persist_directory)
+
+     doc_search = Chroma.from_documents(
+         documents=splits,
+         embedding=embeddings_model,
+         persist_directory=persist_directory,
+     )
+
+     # Indexing data
+     namespace = "chromadb/my_documents"
+     record_manager = SQLRecordManager(
+         namespace, db_url="sqlite:///record_manager_cache.sql"
+     )
+     record_manager.create_schema()
+
+     index_result = index(
+         docs,
+         record_manager,
+         doc_search,
+         cleanup="incremental",
+         source_id_key="link",
+     )
+
+     print(f"Indexing stats: {index_result}")
+
+     return doc_search
+
+
+ if __name__ == "__main__":
+
+     data = scrap_articles(DATA_URL, num_days_past=2)
+     vectordb = process_docs(data, config.STORAGE_PATH, embeddings_model)
+     ret = vectordb.as_retriever()
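
Once the script has run, the persisted store can be sanity-checked without starting the app. A minimal sketch, assuming the `data/chroma/` directory created above and a `GOOGLE_API_KEY` in the environment; the query string is illustrative:

```python
# Sketch: load the persisted Chroma index and inspect the top matches for a query.
from langchain.vectorstores.chroma import Chroma
from langchain_google_genai import GoogleGenerativeAIEmbeddings

import config

embedding = GoogleGenerativeAIEmbeddings(model=config.GOOGLE_EMBEDDING_MODEL)
vectordb = Chroma(persist_directory=config.STORAGE_PATH, embedding_function=embedding)

for doc in vectordb.similarity_search("résultats BRVM", k=3):
    print(doc.metadata.get("title"), "->", doc.metadata.get("link"))
```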