Commit · c4331f2
Parent(s): first commit

Files changed:
- .gitattributes +35 -0
- .gitignore +119 -0
- Dockerfile +14 -0
- README.md +25 -0
- app.py +121 -0
- chainlit.md +5 -0
- config.py +7 -0
- prompts.py +12 -0
- public/favicon.png +0 -0
- public/logo_dark.png +0 -0
- public/logo_light.png +0 -0
- requirements.txt +7 -0
- scrape_data.py +149 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,119 @@
+# Created by .ignore support plugin (hsz.mobi)
+### Python template
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+
+.idea/*
+.files/*
+
+tmp
+secret.*
+volumes/
+.chainlit
+.DS_Store
+__init__.py
+data/
+data*
+record_manager_cache.sql
Dockerfile
ADDED
@@ -0,0 +1,14 @@
+FROM python:3.11
+RUN useradd -m -u 1000 user
+USER user
+ENV HOME=/home/user \
+    PATH=/home/user/.local/bin:$PATH
+WORKDIR $HOME/app
+COPY --chown=user . $HOME/app
+COPY ./requirements.txt ~/app/requirements.txt
+RUN pip install -r requirements.txt
+COPY --chown=user . .
+RUN --mount=type=secret,id=GOOGLE_API_KEY,mode=0444,required=true \
+    export GOOGLE_API_KEY=$(cat /run/secrets/GOOGLE_API_KEY) &&\
+    python scrape_data.py
+CMD ["chainlit", "run", "app.py", "--port", "7860"]
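The Dockerfile mounts GOOGLE_API_KEY as a BuildKit secret so that scrape_data.py can populate the vector store at build time. A minimal local build-and-run sketch, assuming the key sits in a local file (the file name and image tag below are placeholders, not part of the repo):

```shell
# Build with BuildKit so the secret mount in the Dockerfile is available
# (./google_api_key.txt is an assumed location for the key).
DOCKER_BUILDKIT=1 docker build \
  --secret id=GOOGLE_API_KEY,src=./google_api_key.txt \
  -t finchat .

# The key is also needed at runtime, since app.py reads it through config.py.
docker run -p 7860:7860 -e GOOGLE_API_KEY="$(cat ./google_api_key.txt)" finchat
```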
README.md
ADDED
@@ -0,0 +1,25 @@
+---
+title: FinChat
+emoji: 🤑
+colorFrom: yellow
+colorTo: purple
+sdk: docker
+pinned: false
+license: apache-2.0
+---
+
+# FinChat
+
+FinChat is a chatbot built by [data354](https://data354.com/) to answer questions about economic and financial news.
+
+
+## How to run?
+
+1. Run the script that scrapes and stores the data:
+```shell
+python scrape_data.py
+```
+2. Launch the demo and start interacting with the agent.
+
+```shell
+chainlit run app.py
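Note that config.py reads GOOGLE_API_KEY from the environment, so both steps above need the key exported first; a minimal sketch of a local run (the key value is a placeholder):

```shell
# Placeholder key; config.py picks it up via os.getenv("GOOGLE_API_KEY").
export GOOGLE_API_KEY="your-google-api-key"
python scrape_data.py   # scrape articles and build the Chroma store under data/chroma/
chainlit run app.py     # start the Chainlit UI
```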
app.py
ADDED
@@ -0,0 +1,121 @@
+import chainlit as cl
+from langchain.callbacks.base import BaseCallbackHandler
+from langchain.chains.query_constructor.schema import AttributeInfo
+from langchain.retrievers.self_query.base import SelfQueryRetriever
+from langchain.schema import StrOutputParser
+from langchain.schema.runnable import Runnable, RunnableConfig, RunnablePassthrough
+from langchain.vectorstores.chroma import Chroma
+from langchain_google_genai import (
+    GoogleGenerativeAI,
+    GoogleGenerativeAIEmbeddings,
+    HarmBlockThreshold,
+    HarmCategory,
+)
+
+import config
+from prompts import prompt
+
+metadata_field_info = [
+    AttributeInfo(
+        name="title",
+        description="Le titre de l'article",
+        type="string",
+    ),
+    AttributeInfo(
+        name="date",
+        description="Date de publication",
+        type="string",
+    ),
+    AttributeInfo(name="link", description="Source de l'article", type="string"),
+]
+document_content_description = "Articles sur l'actualité."
+
+model = GoogleGenerativeAI(
+    model=config.GOOGLE_CHAT_MODEL,
+    google_api_key=config.GOOGLE_API_KEY,
+    safety_settings={
+        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
+    },
+)  # type: ignore
+
+# Load vector database that was persisted earlier
+embedding = embeddings_model = GoogleGenerativeAIEmbeddings(
+    model="models/embedding-001", google_api_key=config.GOOGLE_API_KEY
+)  # type: ignore
+
+vectordb = Chroma(persist_directory=config.STORAGE_PATH, embedding_function=embedding)
+
+retriever = SelfQueryRetriever.from_llm(
+    model,
+    vectordb,
+    document_content_description,
+    metadata_field_info,
+)
+
+
+@cl.on_chat_start
+async def on_chat_start():
+
+    def format_docs(docs):
+        return "\n\n".join(doc.page_content for doc in docs)
+
+    rag_chain = (
+        {
+            "context": vectordb.as_retriever() | format_docs,
+            "question": RunnablePassthrough(),
+        }
+        | prompt
+        | model
+        | StrOutputParser()
+    )
+
+    cl.user_session.set("rag_chain", rag_chain)
+
+    msg = cl.Message(
+        content=f"Vous pouvez poser vos questions sur les articles de SIKAFINANCE",
+    )
+    await msg.send()
+
+
+@cl.on_message
+async def on_message(message: cl.Message):
+    runnable = cl.user_session.get("rag_chain")  # type: Runnable # type: ignore
+    msg = cl.Message(content="")
+
+    class PostMessageHandler(BaseCallbackHandler):
+        """
+        Callback handler for handling the retriever and LLM processes.
+        Used to post the sources of the retrieved documents as a Chainlit element.
+        """
+
+        def __init__(self, msg: cl.Message):
+            BaseCallbackHandler.__init__(self)
+            self.msg = msg
+            self.sources = []
+
+        def on_retriever_end(self, documents, *, run_id, parent_run_id, **kwargs):
+            for d in documents:
+                source_doc = d.page_content + "\nSource: " + d.metadata["link"]
+                self.sources.append(source_doc)
+
+        def on_llm_end(self, response, *, run_id, parent_run_id, **kwargs):
+            if len(self.sources):
+                # Display the reference docs with a Text widget
+                sources_element = [
+                    cl.Text(name=f"source_{idx+1}", content=content)
+                    for idx, content in enumerate(self.sources)
+                ]
+                source_names = [el.name for el in sources_element]
+                self.msg.elements += sources_element
+                self.msg.content += f"\nSources: {', '.join(source_names)}"
+
+    async with cl.Step(type="run", name="QA Assistant"):
+        async for chunk in runnable.astream(
+            message.content,
+            config=RunnableConfig(
+                callbacks=[cl.LangchainCallbackHandler(), PostMessageHandler(msg)]
+            ),
+        ):
+            await msg.stream_token(chunk)
+
+    await msg.send()
chainlit.md
ADDED
@@ -0,0 +1,5 @@
+# Welcome to FinChat! 🚀🤖
+
+FinChat is a chatbot built by [data354](https://data354.com/) to answer questions about economic and financial news.
+
+> There you go, it's all set! You can now ask your questions 💻😊.
config.py
ADDED
@@ -0,0 +1,7 @@
+import os
+
+GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
+GOOGLE_CHAT_MODEL = "gemini-pro"
+GOOGLE_EMBEDDING_MODEL = "models/embedding-001"
+STORAGE_PATH = "data/chroma/"
+HIISTORY_FILE = "./data/qa_history.txt"
prompts.py
ADDED
@@ -0,0 +1,12 @@
+from langchain.prompts import ChatPromptTemplate
+
+template = """
+Répondez à la question en vous basant uniquement sur le contexte suivant:
+
+{context}
+
+Question : {question}
+
+"""
+
+prompt = ChatPromptTemplate.from_template(template)
public/favicon.png
ADDED
public/logo_dark.png
ADDED
public/logo_light.png
ADDED
requirements.txt
ADDED
@@ -0,0 +1,7 @@
+langchain==0.1.14
+langchain-google-genai==1.0.1
+chainlit==1.0.500
+chromadb==0.4.24
+lark==1.1.9
+bs4==0.0.2
+selenium==4.19.0
scrape_data.py
ADDED
@@ -0,0 +1,149 @@
+import os
+from datetime import date, timedelta
+
+import bs4
+from langchain.indexes import SQLRecordManager, index
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain.vectorstores.chroma import Chroma
+from langchain_community.document_loaders import WebBaseLoader
+from langchain_google_genai import GoogleGenerativeAIEmbeddings
+from selenium import webdriver
+from selenium.webdriver.common.by import By
+from selenium.webdriver.support import expected_conditions as EC
+from selenium.webdriver.support.ui import WebDriverWait
+
+import config
+
+DATA_URL = "https://www.sikafinance.com/marches/actualites_bourse_brvm"
+
+embeddings_model = GoogleGenerativeAIEmbeddings(
+    model=config.GOOGLE_EMBEDDING_MODEL
+)  # type: ignore
+
+
+options = webdriver.ChromeOptions()
+options.add_argument("--headless")
+options.add_argument("--no-sandbox")
+options.add_argument("--disable-dev-shm-usage")
+driver = webdriver.Chrome(options=options)
+
+
+def scrap_articles(
+    url="https://www.sikafinance.com/marches/actualites_bourse_brvm", num_days_past=5
+):
+
+    today = date.today()
+
+    driver.get(url)
+
+    all_articles = []
+    for i in range(num_days_past + 1):
+        past_date = today - timedelta(days=i)
+        date_str = past_date.strftime("%Y-%m-%d")
+        WebDriverWait(driver, 10).until(
+            EC.presence_of_element_located((By.ID, "dateActu"))
+        )
+        text_box = driver.find_element(By.ID, "dateActu")
+        text_box.send_keys(date_str)
+
+        submit_btn = WebDriverWait(driver, 10).until(
+            EC.element_to_be_clickable((By.ID, "btn"))
+        )
+        submit_btn.click()
+
+        dates = driver.find_elements(By.CLASS_NAME, "sp1")
+        titles = driver.find_elements(By.XPATH, "//td/a")
+
+        articles = []
+        for i in range(len(titles)):
+            art = {
+                "title": titles[i].text.strip(),
+                "date": dates[i].text,
+                "link": titles[i].get_attribute("href"),
+            }
+            articles.append(art)
+
+        all_articles += articles
+    # driver.quit()
+
+    return all_articles
+
+
+def set_metadata(documents, metadatas):
+    """
+    #Edit a metadata of lanchain Documents object
+    """
+    for doc in documents:
+        idx = documents.index(doc)
+        doc.metadata = metadatas[idx]
+    print("Metadata successfully changed")
+    print(documents[0].metadata)
+
+
+def process_docs(
+    articles, persist_directory, embeddings_model, chunk_size=1000, chunk_overlap=100
+):
+    """
+    #Scrap all articles urls content and save on a vector DB
+    """
+    article_urls = [a["link"] for a in articles]
+
+    print("Starting to scrap ..")
+
+    loader = WebBaseLoader(
+        web_paths=article_urls,
+        bs_kwargs=dict(
+            parse_only=bs4.SoupStrainer(
+                class_=("inarticle txtbig", "dt_sign", "innerUp")
+            )
+        ),
+    )
+
+    print("After scraping Loading ..")
+    docs = loader.load()
+
+    # Update metadata: add title,
+    set_metadata(documents=docs, metadatas=articles)
+
+    print("Successfully loaded to document")
+
+    text_splitter = RecursiveCharacterTextSplitter(
+        chunk_size=chunk_size, chunk_overlap=chunk_overlap, separators=["\n"]
+    )
+    splits = text_splitter.split_documents(docs)
+
+    # Create the storage path if it doesn't exist
+    if not os.path.exists(persist_directory):
+        os.makedirs(persist_directory)
+
+    doc_search = Chroma.from_documents(
+        documents=splits,
+        embedding=embeddings_model,
+        persist_directory=persist_directory,
+    )
+
+    # Indexing data
+    namespace = "chromadb/my_documents"
+    record_manager = SQLRecordManager(
+        namespace, db_url="sqlite:///record_manager_cache.sql"
+    )
+    record_manager.create_schema()
+
+    index_result = index(
+        docs,
+        record_manager,
+        doc_search,
+        cleanup="incremental",
+        source_id_key="link",
+    )
+
+    print(f"Indexing stats: {index_result}")
+
+    return doc_search
+
+
+if __name__ == "__main__":
+
+    data = scrap_articles(DATA_URL, num_days_past=2)
+    vectordb = process_docs(data, config.STORAGE_PATH, embeddings_model)
+    ret = vectordb.as_retriever()