alexneakameni committed
Commit 15aea1e · verified · 1 Parent(s): 902b219

Medivocate: An AI-powered platform exploring African history, culture, and traditional medicine, fostering understanding and appreciation of the continent's rich heritage.

.gitattributes CHANGED
@@ -1,35 +1,36 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ data/chroma_db/chroma.sqlite3 filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,178 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # UV
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ #uv.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+ .pdm.toml
+ .pdm-python
+ .pdm-build/
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+
+ # PyPI configuration file
+ .pypirc
+ .vscode/
+
+ *.out
+ .python-version
+
+ .venv
+ *.sh
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 KameniAlexNea
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,14 +1,65 @@
- ---
- title: Medivocate
- emoji: 📈
- colorFrom: purple
- colorTo: indigo
- sdk: gradio
- sdk_version: 5.16.0
- app_file: app.py
- pinned: false
- license: mit
- short_description: Exploring African history, culture, and traditional medicine
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ---
+ title: Medivocate
+ emoji: 🐢
+ colorFrom: purple
+ colorTo: red
+ sdk: gradio
+ sdk_version: 5.12.0
+ app_file: app.py
+ pinned: false
+ license: apache-2.0
+ short_description: Medivocate is an AI-driven platform leveraging Retrieval-Aug
+ ---
+
+ # Medivocate
+
+ An AI-driven platform empowering users with trustworthy, personalized history guidance to combat misinformation and promote equitable access to historical knowledge.
+
+ ## Follow us [here](https://github.com/KameniAlexNea/medivocate)
+
+ * [**Alex Kameni**](https://www.linkedin.com/in/elie-alex-kameni-ngangue/)
+ * [**Esdras Fandio**](https://www.linkedin.com/in/esdras-fandio/)
+ * [**Patric Zeufack**](https://www.linkedin.com/in/zeufack-patric-hermann-7a9256143/)
+
+ ## Project Overview
+
+ **Medivocate** is structured for modular development and ease of scalability, as seen in its directory layout:
+
+ ```
+ 📦 ./
+ ├── 📁 docs/
+ ├── 📁 src/
+ │   ├── 📁 ocr/
+ │   ├── 📁 preprocessing/
+ │   ├── 📁 chunking/
+ │   ├── 📁 vector_store/
+ │   ├── 📁 rag_pipeline/
+ │   ├── 📁 llm_integration/
+ │   └── 📁 prompt_engineering/
+ ├── 📁 tests/
+ │   ├── 📁 unit/
+ │   └── 📁 integration/
+ ├── 📁 examples/
+ ├── 📁 notebooks/
+ ├── 📁 config/
+ ├── 📄 README.md
+ ├── 📄 CONTRIBUTING.md
+ ├── 📄 requirements.txt
+ ├── 📄 .gitignore
+ └── 📄 LICENSE
+ ```
+
+ ### Key Features
+
+ 1. **Trustworthy Information Access**: Using RAG (Retrieval-Augmented Generation) pipelines to deliver fact-based responses.
+ 2. **Advanced Document Handling**: Leveraging OCR, preprocessing, and chunking for scalable document ingestion.
+ 3. **Integrated Tools**: Supports integration with vector databases (e.g., Chroma), LLMs, and advanced prompt engineering techniques.
+
+ ### Recommendations for Integration
+
+ * **Groq**: Utilize the Groq API's free tier for LLM support, ideal for prototyping RAG applications.
+ * **LangChain + LangSmith**: Build and monitor intelligent agents with LangChain, and enhance debugging and evaluation with LangSmith.
+ * **Hugging Face Datasets**: One-liner dataset loading and preprocessing to support efficient ML training pipelines.
+ * **Search Index**: Include Chroma for robust semantic search capabilities in RAG.
+
+ This modular design and extensive integration make Medivocate a powerful tool for historical education and research.
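
For readers who want to try the Space locally, the following is a minimal sketch (not part of this commit) of how the pieces above fit together. The environment variables mirror the ones read by `app.py`, `src/utilities/llm_models.py`, and `src/utilities/embedding.py`; the model names are placeholders and must match whatever was used to build the persisted Chroma collection, since the collection name is derived from `HF_MODEL`.

```python
# Minimal local-run sketch (assumptions: a local Ollama chat model and an
# embedding model matching the one used to build data/chroma_db).
import os

os.environ["IS_APP"] = "1"                 # switches the embedder to query mode
os.environ["USE_OLLAMA_CHAT"] = "1"        # "1" selects ChatOllama, otherwise ChatGroq
os.environ["OLLAMA_MODEL"] = "llama3.2"    # placeholder model name
os.environ["USE_HF_EMBEDDING"] = "1"       # "1" selects CustomEmbedding
os.environ["HF_MODEL"] = "nomic-ai/nomic-embed-text-v1.5"  # placeholder embedding model

from src.rag_pipeline.rag_system import RAGSystem

rag = RAGSystem(top_k_documents=4)
rag.initialize_vector_store()              # loads the existing data/chroma_db collection

# query() is a generator of answer chunks, so stream them as they arrive
for chunk in rag.query("Quels étaient les premiers peuples d'Afrique centrale ?"):
    print(chunk, end="", flush=True)
```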
app.py ADDED
@@ -0,0 +1,57 @@
+ import os
+
+ os.environ["IS_APP"] = "1"
+ from typing import List
+
+ import gradio as gr
+
+ from src.rag_pipeline.rag_system import RAGSystem
+ from load_data import download_and_prepare_data
+
+ os.environ["TOKENIZERS_PARALLELISM"] = "true"
+
+
+ class ChatInterface:
+     def __init__(self, rag_system: RAGSystem):
+         self.rag_system = rag_system
+         self.history_depth = int(os.getenv("MAX_MESSAGES") or 5) * 2
+
+     def respond(self, message: str, history: List[List[str]]):
+         result = ""
+         history = [(turn["role"], turn["content"]) for turn in history[-self.history_depth:]]
+         for text in self.rag_system.query(message, history):
+             result += text
+             yield result
+         return result
+
+     def create_interface(self) -> gr.ChatInterface:
+         description = (
+             "Medivocate is an application that offers clear and structured information "
+             "about African history and traditional medicine. The knowledge is exclusively "
+             "based on historical documentaries about the African continent.\n\n"
+             "🌟 **Code Repository**: [Medivocate GitHub](https://github.com/KameniAlexNea/medivocate)"
+         )
+         return gr.ChatInterface(
+             fn=self.respond,
+             type="messages",
+             title="Medivocate",
+             description=description,
+         )
+
+
+ # Usage example:
+ if __name__ == "__main__":
+     # Example usage
+     zip_filename = "chroma_db.zip"
+     extract_to = "chroma_db"
+     target_folder = "data/chroma_db"
+     gdrive_url = os.getenv("GDRIVE_URL")
+     download_and_prepare_data(gdrive_url, zip_filename, extract_to, target_folder)
+
+     top_k_docs = int(os.getenv("N_CONTEXT") or 4)
+     rag_system = RAGSystem(top_k_documents=top_k_docs)
+     rag_system.initialize_vector_store()
+
+     chat_interface = ChatInterface(rag_system)
+     demo = chat_interface.create_interface()
+     demo.launch(share=False)
data/chroma_db/ad04fd29-b3fe-456e-a525-757a6756c29e/data_level0.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:16fea540d03a37ecce67de43d3bce99a5ca61a0fcec19cbfe67928ed19064e72
+ size 16296000
data/chroma_db/ad04fd29-b3fe-456e-a525-757a6756c29e/header.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2f49783cbcb025a93b8ac35a9c337d4cfdc94f741d1bd1c4b944127b212554a6
+ size 100
data/chroma_db/ad04fd29-b3fe-456e-a525-757a6756c29e/index_metadata.pickle ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d834576f8575c2eb62318da8a2e7fa1ec58b849d9bc02ae88f09c50c6cf60dad
+ size 755153
data/chroma_db/ad04fd29-b3fe-456e-a525-757a6756c29e/length.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:24f6284c00d0f3b2567bb2d48cf91e813e827d7d3045017992afd4f904428f64
+ size 56000
data/chroma_db/ad04fd29-b3fe-456e-a525-757a6756c29e/link_lists.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:610ed567f1bd6179e7df87fe10f2d8c0b0d29cab23628d900a99ce93e3688922
+ size 118696
data/chroma_db/chroma.sqlite3 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:399ec17ca61e724b6f5fe0818842f32c046f5f2e2014b9dda310f967b68faeb1
+ size 199651328
load_data.py ADDED
@@ -0,0 +1,47 @@
+ import os
+ import subprocess
+ import zipfile
+ import shutil
+
+ import logging
+
+
+ def download_and_prepare_data(gdrive_url, zip_filename, extract_to, target_folder):
+     """
+     Download, extract, and organize data from a Google Drive link.
+
+     :param gdrive_url: Google Drive URL to download the zip file.
+     :param zip_filename: Name for the downloaded zip file.
+     :param extract_to: Directory to extract the zip file.
+     :param target_folder: Final directory to move extracted content.
+     """
+     try:
+         if os.path.exists(os.path.join(target_folder, "chroma.sqlite3")):
+             logging.info(f"Data already exists in {target_folder}")
+             return
+         # Step 1: Download the file using gdown
+         logging.info("Downloading file...")
+         subprocess.run(["gdown", gdrive_url, "-O", zip_filename], check=True)
+
+         # Step 2: Unzip the downloaded file
+         logging.info("Unzipping file...")
+         with zipfile.ZipFile(zip_filename, 'r') as zip_ref:
+             zip_ref.extractall(extract_to)
+
+         # Step 3: Remove old data folder if it exists
+         if os.path.exists(target_folder):
+             logging.info(f"Removing existing folder: {target_folder}")
+             shutil.rmtree(target_folder)
+
+         # Step 4: Move the extracted folder to the target location
+         logging.info(f"Moving extracted data to {target_folder}")
+         extracted_folder = os.path.join(extract_to, os.path.basename(target_folder))
+         shutil.move(extracted_folder, target_folder)
+
+         # Step 5: Remove the downloaded zip file
+         logging.info(f"Cleaning up, removing zip file: {zip_filename}")
+         os.remove(zip_filename)
+
+         logging.info("Data preparation completed successfully!")
+     except Exception as e:
+         logging.error(f"An error occurred: {e}")
requirements.txt ADDED
@@ -0,0 +1,15 @@
+ langchain-ollama==0.2.2
+ langchain-groq==0.2.3
+ langchain-community==0.3.14
+ langchain-chroma==0.1.4
+ langchain-huggingface
+ langchain==0.3.14
+ ollama==0.4.5
+ chromadb==0.5.23
+ # OCR
+ tqdm==4.67.1
+ gradio==5.11.0
+ rank_bm25==0.2.2
+ groq==0.15.0
+ gdown==5.2.0
+ einops==0.8.1
src/rag_pipeline/__init__.py ADDED
File without changes
src/rag_pipeline/prompts.py ADDED
@@ -0,0 +1,62 @@
+ from langchain.prompts.chat import (
+     ChatPromptTemplate,
+     HumanMessagePromptTemplate,
+     MessagesPlaceholder,
+     SystemMessagePromptTemplate,
+ )
+
+ system_prompt = """
+ Vous êtes **Dikoka**, un assistant IA expert en histoire de l'Afrique et en médecine traditionnelle africaine, basé sur des recherches et documents historiques validés.
+
+ **Instructions :**
+ - **Répondez strictement en utilisant uniquement le contexte fourni.**
+ - **Résumez les points clés lorsque cela est demandé.**
+ - **Maintenez une grande rigueur dans l'exactitude et la neutralité ; évitez toute spéculation ou ajout d'informations externes.**
+
+ **Directives de réponse :**
+ 1. **Réponses fondées uniquement sur le contexte :** Appuyez-vous exclusivement sur le contexte fourni.
+ 2. **Informations insuffisantes :** Si les détails manquent, répondez :
+    > "Je n'ai pas suffisamment d'informations pour répondre à cette question en fonction du contexte fourni."
+ 3. **Demandes concernant la langue :** Si une question est posée dans une langue africaine ou demande une traduction, répondez :
+    > "Je ne peux fournir les informations que dans la langue du contexte original. Pourriez-vous reformuler votre question dans cette langue ?"
+ 4. **Sujets non pertinents :** Pour les questions qui ne concernent pas :
+    - L'histoire de l'Afrique
+    - La médecine traditionnelle africaine
+
+    répondez :
+    > "Je n'ai pas d'informations sur ce sujet en fonction du contexte fourni. Pourriez-vous poser une question relative à l'histoire de l'Afrique ou à la médecine traditionnelle africaine ?"
+ 5. **Résumés :** Fournissez des résumés concis et structurés (à l'aide de points ou de paragraphes) basés uniquement sur le contexte.
+ 6. **Mise en forme :** Organisez vos réponses avec des listes à puces, des listes numérotées, ainsi que des titres et sous-titres lorsque cela est approprié.
+
+ Contexte :
+ {context}
+ """
+
+ # Define the messages for the main chat prompt
+ chat_messages = [
+     MessagesPlaceholder(variable_name="chat_history"),
+     SystemMessagePromptTemplate.from_template(system_prompt),
+     HumanMessagePromptTemplate.from_template(
+         "Répondre dans la même langue que l'utilisateur :\n{input}"
+     ),
+ ]
+ CHAT_PROMPT = ChatPromptTemplate.from_messages(chat_messages)
+
+
+ contextualize_q_system_prompt = (
+     "Votre tâche consiste à formuler une question autonome, claire et compréhensible sans recourir à l'historique de conversation. Veuillez suivre ces instructions :\n"
+     "1. Analysez l'historique de conversation ainsi que la dernière question posée par l'utilisateur.\n"
+     "2. Reformulez la question en intégrant tout contexte nécessaire pour qu'elle soit compréhensible sans l'historique.\n"
+     "3. Si la question initiale est déjà autonome, renvoyez-la telle quelle.\n"
+     "4. Conservez l'intention et la langue d'origine de la question.\n"
+     "5. Fournissez uniquement la question autonome, sans explications ou texte additionnel.\n"
+     "NE répondez PAS à la question."
+ )
+
+ CONTEXTUEL_QUERY_PROMPT = ChatPromptTemplate.from_messages(
+     [
+         # SystemMessagePromptTemplate.from_template(contextualize_q_system_prompt),
+         MessagesPlaceholder("chat_history"),
+         HumanMessagePromptTemplate.from_template("{input}"),
+     ]
+ )
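
As a rough illustration (not part of the commit), `CHAT_PROMPT` expects three inputs: the running `chat_history`, the retrieved `context`, and the user's `input`. Formatting it outside the chain shows the message layout it produces; the history and context values below are placeholders.

```python
# Hypothetical snippet: inspect the messages CHAT_PROMPT produces.
from langchain_core.messages import AIMessage, HumanMessage

from src.rag_pipeline.prompts import CHAT_PROMPT

messages = CHAT_PROMPT.format_messages(
    chat_history=[HumanMessage("Bonjour"), AIMessage("Bonjour, comment puis-je vous aider ?")],
    context="(extrait de documentaire fourni par le retriever)",  # placeholder context
    input="Parle-moi de l'empire du Mali.",
)
for message in messages:
    print(type(message).__name__, "->", message.content[:60])
```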
src/rag_pipeline/rag_system.py ADDED
@@ -0,0 +1,115 @@
+ import logging
+ import os
+ from typing import List, Optional
+
+ from langchain.chains.combine_documents import create_stuff_documents_chain
+ from langchain.chains.conversational_retrieval.base import (
+     BaseConversationalRetrievalChain,
+ )
+ from langchain.chains.history_aware_retriever import (
+     create_history_aware_retriever,
+ )
+ from langchain.chains.retrieval import create_retrieval_chain
+
+ from ..utilities.llm_models import get_llm_model_chat
+ from ..vector_store.vector_store import VectorStoreManager
+ from .prompts import CHAT_PROMPT, CONTEXTUEL_QUERY_PROMPT
+
+
+ class RAGSystem:
+     def __init__(
+         self,
+         docs_dir: str = "data/chunks",
+         persist_directory_dir="data/chroma_db",
+         batch_size: int = 64,
+         top_k_documents=5,
+     ):
+         self.top_k_documents = top_k_documents
+         self.llm = self._get_llm()
+         self.chain: Optional[BaseConversationalRetrievalChain] = None
+         self.vector_store_management = VectorStoreManager(
+             persist_directory_dir, batch_size
+         )
+         self.docs_dir = docs_dir
+
+     def _get_llm(
+         self,
+     ):
+         return get_llm_model_chat(temperature=0.1, max_tokens=1000)
+
+     def load_documents(self) -> List:
+         """Load and split documents from the specified directory"""
+         return self.vector_store_management.load_and_process_documents(self.docs_dir)
+
+     def initialize_vector_store(self, documents: List = None):
+         """Initialize or load the vector store"""
+         self.vector_store_management.initialize_vector_store(documents)
+
+     def setup_rag_chain(self):
+         if self.chain is not None:
+             return
+         retriever = self.vector_store_management.create_retriever(
+             self.llm, self.top_k_documents, bm25_portion=0.03
+         )
+
+         # Contextualize the question using the chat history
+         self.history_aware_retriever = create_history_aware_retriever(
+             self.llm, retriever, CONTEXTUEL_QUERY_PROMPT
+         )
+         self.question_answer_chain = create_stuff_documents_chain(self.llm, CHAT_PROMPT)
+         self.chain = create_retrieval_chain(
+             self.history_aware_retriever, self.question_answer_chain
+         )
+         logging.info("RAG chain setup complete: " + str(self.chain))
+
+         return self.chain
+
+     def query(self, question: str, history: list = []):
+         """Query the RAG system, yielding answer chunks as they are generated"""
+         if not self.vector_store_management.vs_initialized:
+             self.initialize_vector_store()
+
+         self.setup_rag_chain()
+
+         for token in self.chain.stream({"input": question, "chat_history": history}):
+             if "answer" in token:
+                 yield token["answer"]
+
+
+ if __name__ == "__main__":
+     from glob import glob
+
+     from dotenv import load_dotenv
+
+     # Load variables from the .env file
+     load_dotenv()
+
+     docs_dir = "data/docs"
+     persist_directory_dir = "data/chroma_db"
+     batch_size = 64
+
+     # Initialize the RAG system
+     rag = RAGSystem(docs_dir, persist_directory_dir, batch_size)
+
+     if len(glob(os.path.join(persist_directory_dir, "*/*.bin"))):
+         rag.initialize_vector_store()  # vector store already persisted
+     else:
+         # Load and index documents
+         documents = rag.load_documents()
+         rag.initialize_vector_store(documents)
+
+     queries = [
+         "Quand a eu lieu la traite négrière ?",
+         "Explique moi comment soigner la typhoïde puis le paludisme",
+         "Quels étaient les premiers peuples d'afrique centrale et quelles ont été leurs migrations?",
+     ]
+
+     for query in queries:
+         print("Query:", query, "\n")
+         # query() returns a generator of answer chunks; consume it to print the answer
+         for token in rag.query(question=query):
+             print(token, end="", flush=True)
+         print("\n")
src/utilities/__init__.py ADDED
File without changes
src/utilities/embedding.py ADDED
@@ -0,0 +1,128 @@
+ import logging
+ import os
+ from typing import Any, List
+
+ import torch
+ from langchain_core.embeddings import Embeddings
+ from langchain_huggingface import (
+     HuggingFaceEmbeddings,
+     HuggingFaceEndpointEmbeddings,
+ )
+ from pydantic import BaseModel, Field
+
+
+ class CustomEmbedding(BaseModel, Embeddings):
+     """
+     Custom embedding class that supports both hosted and CPU embeddings.
+     """
+
+     hosted_embedding: HuggingFaceEndpointEmbeddings = Field(
+         default_factory=lambda: None
+     )
+     cpu_embedding: HuggingFaceEmbeddings = Field(default_factory=lambda: None)
+     matryoshka_dim: int = Field(default=256)
+
+     def get_instruction(self) -> str:
+         """
+         Generates the instruction for the embedding model based on environment variables.
+
+         Returns:
+             str: The instruction string.
+         """
+         if "nomic" in os.getenv("HF_MODEL"):
+             return (
+                 "query"
+                 if (os.getenv("IS_APP", "0") == "1")
+                 else "search_document: "
+             )
+         return (
+             "Represent this sentence for searching relevant passages"
+             if (os.getenv("IS_APP", "0") == "1")
+             else ""
+         )
+
+     def get_hf_embedd(self) -> HuggingFaceEmbeddings:
+         """
+         Initializes the HuggingFaceEmbeddings with the appropriate settings.
+
+         Returns:
+             HuggingFaceEmbeddings: The initialized HuggingFaceEmbeddings object.
+         """
+         return HuggingFaceEmbeddings(
+             model_name=os.getenv("HF_MODEL"),  # You can replace with any HF model
+             model_kwargs={
+                 "device": "cpu" if not torch.cuda.is_available() else "cuda",
+                 "trust_remote_code": True,
+             },
+             encode_kwargs={
+                 "normalize_embeddings": True,
+                 "prompt": self.get_instruction(),
+             },
+         )
+
+     def __init__(self, matryoshka_dim=256, **kwargs: Any):
+         """
+         Initializes the CustomEmbedding with the given parameters.
+
+         Args:
+             matryoshka_dim (int): Dimension of the embeddings.
+             **kwargs: Additional keyword arguments.
+         """
+         super().__init__(**kwargs)
+         query_instruction = self.get_instruction()
+         self.matryoshka_dim = matryoshka_dim
+         if torch.cuda.is_available():
+             logging.info("CUDA is available")
+             self.hosted_embedding = self.get_hf_embedd()
+             self.cpu_embedding = self.hosted_embedding
+         else:
+             logging.info("CUDA is not available")
+             self.hosted_embedding = HuggingFaceEndpointEmbeddings(
+                 model=os.getenv("HF_MODEL"),
+                 model_kwargs={
+                     "encode_kwargs": {
+                         "normalize_embeddings": True,
+                         "prompt": query_instruction,
+                     }
+                 },
+                 huggingfacehub_api_token=os.getenv("HUGGINGFACEHUB_API_TOKEN"),
+             )
+             self.cpu_embedding = self.get_hf_embedd()
+
+     def embed_documents(self, texts: List[str]) -> List[List[float]]:
+         """
+         Embeds a list of documents using the appropriate embedding model.
+
+         Args:
+             texts (List[str]): List of document texts to embed.
+
+         Returns:
+             List[List[float]]: List of embedded document vectors.
+         """
+         try:
+             embed = self.hosted_embedding.embed_documents(texts)
+         except Exception as e:
+             logging.warning(f"Issue with batch hosted embedding, moving to CPU: {e}")
+             embed = self.cpu_embedding.embed_documents(texts)
+         return (
+             [e[: self.matryoshka_dim] for e in embed] if self.matryoshka_dim else embed
+         )
+
+     def embed_query(self, text: str) -> List[float]:
+         """
+         Embeds a single query using the appropriate embedding model.
+
+         Args:
+             text (str): The query text to embed.
+
+         Returns:
+             List[float]: The embedded query vector.
+         """
+         try:
+             logging.info(text)
+             embed = self.hosted_embedding.embed_query(text)
+         except Exception as e:
+             logging.warning(f"Issue with hosted embedding, moving to CPU: {e}")
+             embed = self.cpu_embedding.embed_query(text)
+             logging.warning(text)
+         return embed[: self.matryoshka_dim] if self.matryoshka_dim else embed
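
A small usage sketch (assumptions: `HF_MODEL` points to a sentence-embedding model and, for the hosted path, `HUGGINGFACEHUB_API_TOKEN` is set; otherwise the class falls back to the local CPU model). It mainly illustrates the Matryoshka truncation: vectors are cut to the first `matryoshka_dim` components.

```python
# Hypothetical usage of CustomEmbedding (the model name is a placeholder).
import os

os.environ.setdefault("HF_MODEL", "nomic-ai/nomic-embed-text-v1.5")
os.environ.setdefault("IS_APP", "1")  # use the query-style instruction

from src.utilities.embedding import CustomEmbedding

embedder = CustomEmbedding(matryoshka_dim=256)
vector = embedder.embed_query("médecine traditionnelle africaine")
print(len(vector))  # 256: truncated to matryoshka_dim regardless of the model's native size
```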
src/utilities/llm_models.py ADDED
@@ -0,0 +1,46 @@
+ import os
+ from enum import Enum
+
+ from langchain_groq import ChatGroq
+ from langchain_ollama import ChatOllama, OllamaEmbeddings
+
+ from .embedding import CustomEmbedding
+
+
+ class LLMModel(Enum):
+     OLLAMA = "ChatOllama"
+     GROQ = "ChatGroq"
+
+
+ def get_llm_model_chat(temperature=0.01, max_tokens: int = None):
+     if str(os.getenv("USE_OLLAMA_CHAT")) == "1":
+         return ChatOllama(
+             model=os.getenv("OLLAMA_MODEL"),
+             temperature=temperature,
+             num_predict=max_tokens,
+         )
+     return ChatGroq(
+         model=os.getenv("GROQ_MODEL_NAME"),
+         temperature=temperature,
+         max_tokens=max_tokens,
+     )
+
+
+ def get_llm_model_embedding():
+     if str(os.getenv("USE_HF_EMBEDDING")) == "1":
+         return CustomEmbedding()
+     return OllamaEmbeddings(
+         model=os.getenv("OLLAM_EMB"),
+         base_url=(
+             os.getenv("OLLAMA_HOST") if os.getenv("OLLAMA_HOST") is not None else None
+         ),
+         client_kwargs=(
+             {
+                 "headers": {
+                     "Authorization": "Bearer " + (os.getenv("OLLAMA_TOKEN") or "")
+                 }
+             }
+             if os.getenv("OLLAMA_HOST") is not None
+             else None
+         ),
+     )
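
The two factories are driven entirely by environment variables, so switching providers is a configuration change rather than a code change. A hedged sketch (model names are placeholders; the Groq branch additionally expects `GROQ_API_KEY` in the environment):

```python
# Hypothetical provider selection via environment variables.
import os

os.environ["USE_OLLAMA_CHAT"] = "1"      # "1" -> ChatOllama; anything else -> ChatGroq
os.environ["OLLAMA_MODEL"] = "llama3.2"  # placeholder local model

from src.utilities.llm_models import get_llm_model_chat

llm = get_llm_model_chat(temperature=0.1, max_tokens=200)
print(llm.invoke("Réponds en un mot : quelle est la capitale du Sénégal ?").content)
```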
src/vector_store/__init__.py ADDED
File without changes
src/vector_store/bivector_store.py ADDED
@@ -0,0 +1,138 @@
+ import os
+ from typing import List, Union
+
+ from langchain.retrievers import EnsembleRetriever, MultiQueryRetriever
+ from langchain_chroma import Chroma
+ from langchain_community.retrievers import BM25Retriever
+ from langchain_core.documents import Document
+ from tqdm import tqdm
+ from transformers import AutoTokenizer
+
+ from ..utilities.llm_models import get_llm_model_embedding
+ from .document_loader import DocumentLoader
+ from .vector_store import get_collection_name
+ from .prompts import DEFAULT_QUERY_PROMPT
+
+
+ class VectorStoreManager:
+     """
+     Manages vector store initialization, updates, and retrieval.
+     """
+
+     def __init__(self, persist_directory: str, batch_size: int = 64):
+         """
+         Initializes the VectorStoreManager with the given parameters.
+
+         Args:
+             persist_directory (str): Directory to persist the vector store.
+             batch_size (int): Number of documents to process in each batch.
+         """
+         self.persist_directory = persist_directory
+         self.batch_size = batch_size
+         self.embeddings = get_llm_model_embedding()
+         self.collection_name = get_collection_name()
+         self.vector_stores: dict[str, Union[Chroma, BM25Retriever]] = {
+             "chroma": None,
+             "bm25": None,
+         }
+         self.tokenizer = AutoTokenizer.from_pretrained(
+             os.getenv("HF_MODEL", "meta-llama/Llama-3.2-1B")
+         )
+         self.vs_initialized = False
+         self.vector_store = None
+
+     def _batch_process_documents(self, documents: List[Document]):
+         """
+         Processes documents in batches for vector store initialization.
+
+         Args:
+             documents (List[Document]): List of documents to process.
+         """
+         for i in tqdm(
+             range(0, len(documents), self.batch_size), desc="Processing documents"
+         ):
+             batch = documents[i : i + self.batch_size]
+
+             if not self.vs_initialized:
+                 self.vector_stores["chroma"] = Chroma.from_documents(
+                     collection_name=self.collection_name,
+                     documents=batch,
+                     embedding=self.embeddings,
+                     persist_directory=self.persist_directory,
+                 )
+                 self.vs_initialized = True
+             else:
+                 self.vector_stores["chroma"].add_documents(batch)
+
+         self.vector_stores["bm25"] = BM25Retriever.from_documents(
+             documents, tokenizer=self.tokenizer
+         )
+
+     def initialize_vector_store(self, documents: List[Document] = None):
+         """
+         Initializes or loads the vector store.
+
+         Args:
+             documents (List[Document], optional): List of documents to initialize the vector store. Defaults to None.
+         """
+         if documents:
+             self._batch_process_documents(documents)
+         else:
+             self.vector_stores["chroma"] = Chroma(
+                 collection_name=self.collection_name,
+                 persist_directory=self.persist_directory,
+                 embedding_function=self.embeddings,
+             )
+             all_documents = self.vector_stores["chroma"].get(
+                 include=["documents", "metadatas"]
+             )
+             documents = [
+                 Document(page_content=content, id=doc_id, metadata=metadata)
+                 for content, doc_id, metadata in zip(
+                     all_documents["documents"],
+                     all_documents["ids"],
+                     all_documents["metadatas"],
+                 )
+             ]
+             self.vector_stores["bm25"] = BM25Retriever.from_documents(documents)
+         self.vs_initialized = True
+
+     def create_retriever(
+         self, llm, n_documents: int, bm25_portion: float = 0.8
+     ) -> EnsembleRetriever:
+         """
+         Creates an ensemble retriever combining Chroma and BM25.
+
+         Args:
+             llm: Language model to use for retrieval.
+             n_documents (int): Number of documents to retrieve.
+             bm25_portion (float): Proportion of BM25 retriever in the ensemble.
+
+         Returns:
+             EnsembleRetriever: The created ensemble retriever.
+         """
+         self.vector_stores["bm25"].k = n_documents
+         self.vector_store = MultiQueryRetriever.from_llm(
+             retriever=EnsembleRetriever(
+                 retrievers=[
+                     self.vector_stores["bm25"],
+                     self.vector_stores["chroma"].as_retriever(
+                         search_kwargs={"k": n_documents}
+                     ),
+                 ],
+                 weights=[bm25_portion, 1 - bm25_portion],
+             ),
+             llm=llm,
+             include_original=True,
+             prompt=DEFAULT_QUERY_PROMPT,
+         )
+         return self.vector_store
+
+     def load_and_process_documents(self, doc_dir) -> List[Document]:
+         """
+         Loads and processes documents from the specified directory.
+
+         Returns:
+             List[Document]: List of loaded and processed documents.
+         """
+         loader = DocumentLoader(doc_dir)
+         return loader.load_documents()
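
To see what the ensemble weighting does in isolation, here is a self-contained sketch with hypothetical in-memory documents (it reuses BM25 twice so it does not need a Chroma store); in `create_retriever` above, the sparse retriever receives `bm25_portion` of the weight and the dense Chroma retriever the remainder.

```python
# Hypothetical standalone ensemble, mirroring the [bm25_portion, 1 - bm25_portion] weighting.
from langchain.retrievers import EnsembleRetriever
from langchain_community.retrievers import BM25Retriever
from langchain_core.documents import Document

docs = [
    Document(page_content="Histoire de l'empire du Ghana"),
    Document(page_content="Plantes médicinales d'Afrique de l'Ouest"),
]
sparse = BM25Retriever.from_documents(docs)
sparse.k = 1

# The second slot would normally be the Chroma retriever; BM25 again keeps the sketch offline.
ensemble = EnsembleRetriever(retrievers=[sparse, sparse], weights=[0.03, 0.97])
print([d.page_content for d in ensemble.invoke("empire du Ghana")])
```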
src/vector_store/document_loader.py ADDED
@@ -0,0 +1,68 @@
+ import json
+ import os
+ from concurrent.futures import ThreadPoolExecutor
+ from glob import glob
+ from typing import List
+
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_community.document_loaders import DirectoryLoader, TextLoader
+ from langchain_core.documents import Document
+ from tqdm import tqdm
+
+
+ def sanitize_metadata(metadata: dict) -> dict:
+     sanitized = {}
+     for key, value in metadata.items():
+         if isinstance(value, list):
+             sanitized[key] = ", ".join(value)
+         elif isinstance(value, (str, int, float, bool)):
+             sanitized[key] = value
+         else:
+             raise ValueError(
+                 f"Unsupported metadata type for key '{key}': {type(value)}"
+             )
+     return sanitized
+
+
+ class DocumentLoader:
+     """
+     Handles loading and splitting documents from directories.
+     """
+
+     def __init__(self, docs_dir: str):
+         self.docs_dir = docs_dir
+
+     def load_text_documents(self) -> List[Document]:
+         """Loads and splits text documents."""
+         loader = DirectoryLoader(self.docs_dir, glob="**/*.txt", loader_cls=TextLoader)
+         documents = loader.load()
+         splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
+         return splitter.split_documents(documents)
+
+     def load_json_documents(self) -> List[Document]:
+         """Loads and processes JSON documents."""
+         files = glob(os.path.join(self.docs_dir, "*.json"))
+
+         def load_json_file(file_path):
+             with open(file_path, "r") as f:
+                 data = json.load(f)["kwargs"]
+             return Document.model_validate(
+                 {**data, "metadata": sanitize_metadata(data["metadata"])}
+             )
+
+         with ThreadPoolExecutor() as executor:
+             documents = list(
+                 tqdm(
+                     executor.map(load_json_file, files),
+                     total=len(files),
+                     desc="Loading JSON documents",
+                 )
+             )
+
+         return documents
+
+     def load_documents(self) -> List[Document]:
+         """Determines and loads documents based on file type."""
+         if glob(os.path.join(self.docs_dir, "*.json")):
+             return self.load_json_documents()
+         return self.load_text_documents()
src/vector_store/prompts.py ADDED
@@ -0,0 +1,23 @@
+ from langchain_core.prompts.prompt import PromptTemplate
+
+ DEFAULT_QUERY_PROMPT = PromptTemplate(
+     input_variables=["question"],
+     template="""You are an AI language model assistant tasked with generating alternative versions of a given user question. Your goal is to create 3 different perspectives on the original question to help retrieve relevant documents from a vector database. This approach aims to overcome some limitations of distance-based similarity search.
+
+ When generating alternative questions, follow these guidelines:
+ 1. Maintain the core intent of the original question
+ 2. Use different phrasing, synonyms, or sentence structures
+ 3. Consider potential related aspects or implications of the question
+ 4. Avoid introducing new topics or drastically changing the subject matter
+
+ Here is the original question:
+
+ {question}
+
+ Generate 3 alternative versions of this question. Provide your output as a numbered list, with each alternative question on a new line. Do not include any additional explanation or commentary.
+
+ Remember, the purpose of these alternative questions is to broaden the search scope while staying relevant to the user's original intent. This will help in retrieving a diverse set of potentially relevant documents from the vector database.
+
+ Do not include any additional explanation or commentary, just give 3 alternative questions.
+ """,
+ )
src/vector_store/vector_store.py ADDED
@@ -0,0 +1,114 @@
+ import os
+ from typing import List
+
+ from langchain.retrievers import MultiQueryRetriever
+ from langchain_chroma import Chroma
+ from langchain_core.documents import Document
+ from tqdm import tqdm
+
+ from ..utilities.llm_models import get_llm_model_embedding
+ from .document_loader import DocumentLoader
+ from .prompts import DEFAULT_QUERY_PROMPT
+
+
+ def get_collection_name() -> str:
+     """
+     Derives the collection name from an environment variable.
+
+     Returns:
+         str: Processed collection name.
+     """
+     return "medivocate-" + os.getenv("HF_MODEL", "default_model").split(":")[0].split("/")[-1]
+
+
+ class VectorStoreManager:
+     """
+     Manages vector store initialization, updates, and retrieval.
+     """
+
+     def __init__(self, persist_directory: str, batch_size: int = 64):
+         """
+         Initializes the VectorStoreManager with the given parameters.
+
+         Args:
+             persist_directory (str): Directory to persist the vector store.
+             batch_size (int): Number of documents to process in each batch.
+         """
+         self.persist_directory = persist_directory
+         self.batch_size = batch_size
+         self.embeddings = get_llm_model_embedding()
+         self.collection_name = get_collection_name()
+         self.vector_stores: dict[str, Chroma] = {"chroma": None}
+         self.vs_initialized = False
+
+     def _batch_process_documents(self, documents: List[Document]):
+         """
+         Processes documents in batches for vector store initialization.
+
+         Args:
+             documents (List[Document]): List of documents to process.
+         """
+         for i in tqdm(
+             range(0, len(documents), self.batch_size), desc="Processing documents"
+         ):
+             batch = documents[i : i + self.batch_size]
+             if not self.vs_initialized:
+                 self.vector_stores["chroma"] = Chroma.from_documents(
+                     collection_name=self.collection_name,
+                     documents=batch,
+                     embedding=self.embeddings,
+                     persist_directory=self.persist_directory,
+                 )
+                 self.vs_initialized = True
+             else:
+                 self.vector_stores["chroma"].add_documents(batch)
+
+     def initialize_vector_store(self, documents: List[Document] = None):
+         """
+         Initializes or loads the vector store.
+
+         Args:
+             documents (List[Document], optional): List of documents to initialize the vector store with.
+         """
+         if documents:
+             self._batch_process_documents(documents)
+         else:
+             self.vector_stores["chroma"] = Chroma(
+                 collection_name=self.collection_name,
+                 persist_directory=self.persist_directory,
+                 embedding_function=self.embeddings,
+             )
+         self.vs_initialized = True
+
+     def create_retriever(
+         self, llm, n_documents: int, bm25_portion: float = 0.8
+     ) -> MultiQueryRetriever:
+         """
+         Creates a retriever using Chroma.
+
+         Args:
+             llm: Language model to use for the retriever.
+             n_documents (int): Number of documents to retrieve.
+             bm25_portion (float): Portion of BM25 to use (unused in this Chroma-only variant).
+
+         Returns:
+             MultiQueryRetriever: Configured retriever.
+         """
+         self.vector_store = MultiQueryRetriever.from_llm(
+             retriever=self.vector_stores["chroma"].as_retriever(
+                 search_kwargs={"k": n_documents}
+             ),
+             llm=llm,
+             include_original=True,
+             prompt=DEFAULT_QUERY_PROMPT,
+         )
+         return self.vector_store
+
+     def load_and_process_documents(self, doc_dir: str) -> List[Document]:
+         """
+         Loads and processes documents from the specified directory.
+
+         Returns:
+             List[Document]: List of processed documents.
+         """
+         loader = DocumentLoader(doc_dir)
+         return loader.load_documents()
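
A rough end-to-end sketch of this Chroma-only manager (assumptions: the embedding environment variables are configured as above and `data/docs` contains the `.txt` or serialized `.json` chunks the loader expects):

```python
# Hypothetical build-then-retrieve flow with VectorStoreManager.
from src.utilities.llm_models import get_llm_model_chat
from src.vector_store.vector_store import VectorStoreManager

manager = VectorStoreManager("data/chroma_db", batch_size=64)

documents = manager.load_and_process_documents("data/docs")
manager.initialize_vector_store(documents)   # first run: embed and persist in batches

retriever = manager.create_retriever(get_llm_model_chat(), n_documents=5)
for doc in retriever.invoke("médecine traditionnelle"):
    print(doc.metadata, "|", doc.page_content[:80])
```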