Spaces:
Running
Running
theekshana
commited on
Commit
•
395275a
1
Parent(s):
ec5c64c
Upload 30 files
Browse files- .dockerignore +29 -0
- .env.example +22 -0
- .gitattributes +1 -0
- .gitignore +135 -0
- CHANGELOG.txt +2 -0
- Dockerfile +42 -0
- LICENSE +22 -0
- README.md +14 -11
- config.py +36 -0
- configs/__init__.py +5 -0
- configs/logger.py +40 -0
- controller.py +44 -0
- conversationBufferWindowMemory.py +134 -0
- data/__init__.py +5 -0
- data/splitted_texts.jsonl +0 -0
- ensemble_retriever.py +228 -0
- faissDb.py +68 -0
- faiss_embeddings_2024/index.faiss +3 -0
- faiss_embeddings_2024/index.pkl +3 -0
- llm.py +47 -0
- llmChain.py +96 -0
- multi_query_retriever.py +254 -0
- output_parser.py +39 -0
- prompts.py +123 -0
- qaPipeline.py +150 -0
- requirements.txt +0 -0
- retriever.py +137 -0
- schema.py +63 -0
- server.py +173 -0
- utils/__init__.py +0 -0
- utils/utils.py +40 -0
.dockerignore
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ignore node_modules
|
2 |
+
node_modules
|
3 |
+
|
4 |
+
# Ignore logs
|
5 |
+
logs
|
6 |
+
*.log
|
7 |
+
|
8 |
+
# Ignore temporary files
|
9 |
+
tmp
|
10 |
+
*.tmp
|
11 |
+
|
12 |
+
# Ignore build directories
|
13 |
+
dist
|
14 |
+
build
|
15 |
+
|
16 |
+
# Ignore environment variables
|
17 |
+
.env
|
18 |
+
|
19 |
+
# Ignore Docker files
|
20 |
+
Dockerfile
|
21 |
+
docker-compose.yml
|
22 |
+
|
23 |
+
# Ignore IDE specific files
|
24 |
+
.vscode
|
25 |
+
.idea
|
26 |
+
|
27 |
+
# Ignore OS generated files
|
28 |
+
.DS_Store
|
29 |
+
Thumbs.db
|
.env.example
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
################################################################################
|
2 |
+
### Chat App - Back-End - GENERAL SETTINGS
|
3 |
+
################################################################################
|
4 |
+
|
5 |
+
|
6 |
+
#api app
|
7 |
+
APP_HOST = 127.0.0.1
|
8 |
+
APP_PORT = 8000
|
9 |
+
|
10 |
+
################################################################################
|
11 |
+
### LLM MODELS
|
12 |
+
################################################################################
|
13 |
+
|
14 |
+
#API token keys
|
15 |
+
HUGGINGFACEHUB_API_TOKEN=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
|
16 |
+
ANYSCALE_ENDPOINT_TOKEN=esecret_n1svfld85uklyx5ebaasyiw2m9
|
17 |
+
OPENAI_API_KEY=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
|
18 |
+
|
19 |
+
|
20 |
+
################################################################################
|
21 |
+
### MEMORY
|
22 |
+
################################################################################
|
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
faiss_embeddings_2024/index.faiss filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
pip-wheel-metadata/
|
24 |
+
share/python-wheels/
|
25 |
+
*.egg-info/
|
26 |
+
.installed.cfg
|
27 |
+
*.egg
|
28 |
+
MANIFEST
|
29 |
+
|
30 |
+
# PyInstaller
|
31 |
+
# Usually these files are written by a python script from a template
|
32 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
33 |
+
*.manifest
|
34 |
+
*.spec
|
35 |
+
|
36 |
+
# Installer logs
|
37 |
+
pip-log.txt
|
38 |
+
pip-delete-this-directory.txt
|
39 |
+
|
40 |
+
# Unit test / coverage reports
|
41 |
+
htmlcov/
|
42 |
+
.tox/
|
43 |
+
.nox/
|
44 |
+
.coverage
|
45 |
+
.coverage.*
|
46 |
+
.cache
|
47 |
+
nosetests.xml
|
48 |
+
coverage.xml
|
49 |
+
*.cover
|
50 |
+
*.py,cover
|
51 |
+
.hypothesis/
|
52 |
+
.pytest_cache/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
target/
|
76 |
+
|
77 |
+
# Jupyter Notebook
|
78 |
+
.ipynb_checkpoints
|
79 |
+
|
80 |
+
# IPython
|
81 |
+
profile_default/
|
82 |
+
ipython_config.py
|
83 |
+
|
84 |
+
# pyenv
|
85 |
+
.python-version
|
86 |
+
|
87 |
+
# pipenv
|
88 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
89 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
90 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
91 |
+
# install all needed dependencies.
|
92 |
+
#Pipfile.lock
|
93 |
+
|
94 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
95 |
+
__pypackages__/
|
96 |
+
|
97 |
+
# Celery stuff
|
98 |
+
celerybeat-schedule
|
99 |
+
celerybeat.pid
|
100 |
+
|
101 |
+
# SageMath parsed files
|
102 |
+
*.sage.py
|
103 |
+
|
104 |
+
# Environments
|
105 |
+
.env
|
106 |
+
.venv
|
107 |
+
env/
|
108 |
+
venv/
|
109 |
+
ENV/
|
110 |
+
env.bak/
|
111 |
+
venv.bak/
|
112 |
+
|
113 |
+
# Spyder project settings
|
114 |
+
.spyderproject
|
115 |
+
.spyproject
|
116 |
+
|
117 |
+
# Rope project settings
|
118 |
+
.ropeproject
|
119 |
+
|
120 |
+
# mkdocs documentation
|
121 |
+
/site
|
122 |
+
|
123 |
+
# mypy
|
124 |
+
.mypy_cache/
|
125 |
+
.dmypy.json
|
126 |
+
dmypy.json
|
127 |
+
|
128 |
+
# Pyre type checker
|
129 |
+
.pyre/
|
130 |
+
|
131 |
+
# testing files generated
|
132 |
+
*.txt.json
|
133 |
+
|
134 |
+
*.ipynb
|
135 |
+
env
|
CHANGELOG.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
2023-11-30 pipeline with only document retrieval
|
2 |
+
2024-08-23 azure app serice , open ai 'gpt4o mini'
|
Dockerfile
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Step 1: Use Python 3.11.9 as required
|
2 |
+
FROM python:3.11.9
|
3 |
+
|
4 |
+
# Step 2: Set up environment variables and timezone configuration
|
5 |
+
ENV TZ=Asia/Colombo
|
6 |
+
RUN apt-get update && apt-get install -y libaio1 wget unzip tzdata \
|
7 |
+
&& ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
|
8 |
+
|
9 |
+
# Step 4: Add a user for running the app (after installations)
|
10 |
+
RUN useradd -m -u 1000 user
|
11 |
+
|
12 |
+
# Step 5: Create the /app directory and set ownership to the new user
|
13 |
+
RUN mkdir -p /app && chown -R user:user /app
|
14 |
+
|
15 |
+
# Step 6: Switch to non-root user after the directory has the right permissions
|
16 |
+
USER user
|
17 |
+
ENV PATH="/home/user/.local/bin:$PATH"
|
18 |
+
|
19 |
+
# Step 7: Set up the working directory for the app
|
20 |
+
WORKDIR /app
|
21 |
+
|
22 |
+
# Step 8: Copy the requirements file and install dependencies
|
23 |
+
COPY --chown=user ./requirements.txt requirements.txt
|
24 |
+
RUN pip install --no-cache-dir --upgrade -r requirements.txt
|
25 |
+
|
26 |
+
# Step 9: Install pipenv and handle pipenv environment
|
27 |
+
RUN pip install pipenv
|
28 |
+
COPY --chown=user . /app
|
29 |
+
RUN pipenv install
|
30 |
+
|
31 |
+
# Step 10: Expose the necessary port (7860 for Hugging Face Spaces)
|
32 |
+
EXPOSE 7860
|
33 |
+
|
34 |
+
# Step 11: Set environment variables for the app
|
35 |
+
ENV APP_HOST=0.0.0.0
|
36 |
+
ENV APP_PORT=7860
|
37 |
+
|
38 |
+
# Step 12: Create logs directory (if necessary)
|
39 |
+
RUN mkdir -p /app/logs
|
40 |
+
|
41 |
+
# Step 13: Run the app using Uvicorn, listening on port 7860
|
42 |
+
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
|
LICENSE
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
License
|
2 |
+
|
3 |
+
Copyright (2024-2025) AI Labs, IronOne Technologies, LLC
|
4 |
+
All Rights Reserved
|
5 |
+
|
6 |
+
This source code is protected under international copyright law. All rights
|
7 |
+
reserved and protected by the copyright holders.
|
8 |
+
This file is confidential and only available to authorized individuals with the
|
9 |
+
permission of the copyright holders.
|
10 |
+
|
11 |
+
Permission is hereby granted, to {User}. for testing and development purposes.
|
12 |
+
|
13 |
+
The above copyright notice and this permission notice shall be included in all
|
14 |
+
copies or substantial portions of the Software.
|
15 |
+
|
16 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
17 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
18 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
19 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
20 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
21 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
22 |
+
SOFTWARE.
|
README.md
CHANGED
@@ -1,11 +1,14 @@
|
|
1 |
-
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
-
sdk:
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: Boardpac Chat App Test Streamlit
|
3 |
+
emoji: 🏃
|
4 |
+
colorFrom: gray
|
5 |
+
colorTo: yellow
|
6 |
+
sdk: streamlit
|
7 |
+
sdk_version: 1.39.0
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
license: apache-2.0
|
11 |
+
short_description: chatbot on central bank regulations
|
12 |
+
---
|
13 |
+
|
14 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
config.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
AVALIABLE_MODELS=[
|
2 |
+
{
|
3 |
+
"id":"gpt-4o-mini",
|
4 |
+
"model_name":"openai/gpt-4o-mini",
|
5 |
+
"description":"gpt-4o-mini model from openai"
|
6 |
+
}
|
7 |
+
]
|
8 |
+
|
9 |
+
MODELS={
|
10 |
+
"DEFAULT":"openai",
|
11 |
+
"gpt-4o-mini":"openai",
|
12 |
+
|
13 |
+
}
|
14 |
+
|
15 |
+
DATASETS={
|
16 |
+
"DEFAULT":"faiss",
|
17 |
+
"a":"A",
|
18 |
+
"b":"B",
|
19 |
+
"c":"C"
|
20 |
+
|
21 |
+
}
|
22 |
+
|
23 |
+
MEMORY_WINDOW_K = 1
|
24 |
+
|
25 |
+
QA_MODEL_TYPE = "openai"
|
26 |
+
GENERAL_QA_MODEL_TYPE = "openai"
|
27 |
+
ROUTER_MODEL_TYPE = "openai"
|
28 |
+
Multi_Query_MODEL_TYPE = "openai"
|
29 |
+
|
30 |
+
|
31 |
+
ANSWER_TYPES = [
|
32 |
+
"relevant",
|
33 |
+
"greeting",
|
34 |
+
"other",
|
35 |
+
"not sure",
|
36 |
+
]
|
configs/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# import os
|
2 |
+
# import sys
|
3 |
+
|
4 |
+
# if os.path.dirname(os.path.abspath(__file__)) not in sys.path:
|
5 |
+
# sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
configs/logger.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import time
|
3 |
+
# from functools import wraps
|
4 |
+
|
5 |
+
logger = logging.getLogger(__name__)
|
6 |
+
|
7 |
+
stream_handler = logging.StreamHandler()
|
8 |
+
log_filename = "output.log"
|
9 |
+
file_handler = logging.FileHandler(filename=log_filename)
|
10 |
+
handlers = [stream_handler, file_handler]
|
11 |
+
|
12 |
+
|
13 |
+
class TimeFilter(logging.Filter):
|
14 |
+
def filter(self, record):
|
15 |
+
return "Running" in record.getMessage()
|
16 |
+
|
17 |
+
|
18 |
+
logger.addFilter(TimeFilter())
|
19 |
+
|
20 |
+
# Configure the logging module
|
21 |
+
logging.basicConfig(
|
22 |
+
level=logging.INFO,
|
23 |
+
format="%(name)s %(asctime)s - %(levelname)s - %(message)s",
|
24 |
+
handlers=handlers,
|
25 |
+
)
|
26 |
+
|
27 |
+
|
28 |
+
def time_logger(func):
|
29 |
+
"""Decorator function to log time taken by any function."""
|
30 |
+
|
31 |
+
# @wraps(func)
|
32 |
+
def wrapper(*args, **kwargs):
|
33 |
+
start_time = time.time() # Start time before function execution
|
34 |
+
result = func(*args, **kwargs) # Function execution
|
35 |
+
end_time = time.time() # End time after function execution
|
36 |
+
execution_time = end_time - start_time # Calculate execution time
|
37 |
+
logger.info(f"Running {func.__name__}: --- {execution_time} seconds ---")
|
38 |
+
return result
|
39 |
+
|
40 |
+
return wrapper
|
controller.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
/*************************************************************************
|
3 |
+
*
|
4 |
+
* CONFIDENTIAL
|
5 |
+
* __________________
|
6 |
+
*
|
7 |
+
* Copyright (2023-2024) AI Labs, IronOne Technologies, LLC
|
8 |
+
* All Rights Reserved
|
9 |
+
*
|
10 |
+
* Author : Theekshana Samaradiwakara
|
11 |
+
* Description :Python Backend API to chat with private data
|
12 |
+
* CreatedDate : 14/11/2023
|
13 |
+
* LastModifiedDate : 18/03/2024
|
14 |
+
*************************************************************************/
|
15 |
+
"""
|
16 |
+
|
17 |
+
import logging
|
18 |
+
logger = logging.getLogger(__name__)
|
19 |
+
from config import AVALIABLE_MODELS , MEMORY_WINDOW_K
|
20 |
+
|
21 |
+
# from qaPipeline import QAPipeline
|
22 |
+
# from qaPipeline_retriever_only import QAPipeline
|
23 |
+
# qaPipeline = QAPipeline()
|
24 |
+
|
25 |
+
from qaPipeline import run_agent
|
26 |
+
|
27 |
+
def get_QA_Answers(userQuery):
|
28 |
+
# model=userQuery.model
|
29 |
+
query=userQuery.content
|
30 |
+
|
31 |
+
# chat_history = userQuery.chat_history[-MEMORY_WINDOW_K:]
|
32 |
+
|
33 |
+
# logger.info(f"model: {model} \n query : {query} \n chat_history : {chat_history}")
|
34 |
+
logger.info(f"query : {query}")
|
35 |
+
# answer= run_agent(query=query, model=model, chat_history=chat_history)
|
36 |
+
answer= run_agent(query=query)
|
37 |
+
logger.info(f"Response: {answer}")
|
38 |
+
return answer
|
39 |
+
|
40 |
+
|
41 |
+
def get_avaliable_models():
|
42 |
+
logger.info("getting avaliable models")
|
43 |
+
return AVALIABLE_MODELS
|
44 |
+
|
conversationBufferWindowMemory.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
/*************************************************************************
|
3 |
+
*
|
4 |
+
* CONFIDENTIAL
|
5 |
+
* __________________
|
6 |
+
*
|
7 |
+
* Copyright (2023-2024) AI Labs, IronOne Technologies, LLC
|
8 |
+
* All Rights Reserved
|
9 |
+
*
|
10 |
+
* Author : Theekshana Samaradiwakara
|
11 |
+
* Description :Python Backend API to chat with private data
|
12 |
+
* CreatedDate : 14/11/2023
|
13 |
+
* LastModifiedDate : 18/11/2020
|
14 |
+
*************************************************************************/
|
15 |
+
"""
|
16 |
+
|
17 |
+
from abc import ABC
|
18 |
+
from typing import Any, Dict, Optional, Tuple
|
19 |
+
# import json
|
20 |
+
|
21 |
+
from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory
|
22 |
+
from langchain.memory.utils import get_prompt_input_key
|
23 |
+
from langchain.pydantic_v1 import Field
|
24 |
+
from langchain.schema import BaseChatMessageHistory, BaseMemory
|
25 |
+
|
26 |
+
from typing import List, Union
|
27 |
+
|
28 |
+
# from langchain.memory.chat_memory import BaseChatMemory
|
29 |
+
from langchain.schema.messages import BaseMessage, get_buffer_string
|
30 |
+
|
31 |
+
|
32 |
+
class BaseChatMemory(BaseMemory, ABC):
|
33 |
+
"""Abstract base class for chat memory."""
|
34 |
+
|
35 |
+
chat_memory: BaseChatMessageHistory = Field(default_factory=ChatMessageHistory)
|
36 |
+
output_key: Optional[str] = None
|
37 |
+
input_key: Optional[str] = None
|
38 |
+
return_messages: bool = False
|
39 |
+
|
40 |
+
def _get_input_output(
|
41 |
+
self, inputs: Dict[str, Any], outputs: Dict[str, str]
|
42 |
+
) -> Tuple[str, str]:
|
43 |
+
|
44 |
+
|
45 |
+
if self.input_key is None:
|
46 |
+
prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
|
47 |
+
else:
|
48 |
+
prompt_input_key = self.input_key
|
49 |
+
|
50 |
+
if self.output_key is None:
|
51 |
+
"""
|
52 |
+
output for agent with LLM chain tool = {answer}
|
53 |
+
output for agent with ConversationalRetrievalChain tool = {'question', 'chat_history', 'answer','source_documents'}
|
54 |
+
"""
|
55 |
+
|
56 |
+
LLM_key = 'output'
|
57 |
+
Retrieval_key = 'answer'
|
58 |
+
if isinstance(outputs[LLM_key], dict):
|
59 |
+
Retrieval_dict = outputs[LLM_key]
|
60 |
+
if Retrieval_key in Retrieval_dict.keys():
|
61 |
+
#output keys are 'answer' , 'source_documents'
|
62 |
+
output = Retrieval_dict[Retrieval_key]
|
63 |
+
else:
|
64 |
+
raise ValueError(f"output key: {LLM_key} not a valid dictionary")
|
65 |
+
|
66 |
+
else:
|
67 |
+
#otherwise output key will be 'output'
|
68 |
+
output_key = list(outputs.keys())[0]
|
69 |
+
output = outputs[output_key]
|
70 |
+
|
71 |
+
# if len(outputs) != 1:
|
72 |
+
# raise ValueError(f"One output key expected, got {outputs.keys()}")
|
73 |
+
|
74 |
+
|
75 |
+
else:
|
76 |
+
output_key = self.output_key
|
77 |
+
output = outputs[output_key]
|
78 |
+
|
79 |
+
return inputs[prompt_input_key], output
|
80 |
+
|
81 |
+
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
82 |
+
"""Save context from this conversation to buffer."""
|
83 |
+
input_str, output_str = self._get_input_output(inputs, outputs)
|
84 |
+
self.chat_memory.add_user_message(input_str)
|
85 |
+
self.chat_memory.add_ai_message(output_str)
|
86 |
+
|
87 |
+
def clear(self) -> None:
|
88 |
+
"""Clear memory contents."""
|
89 |
+
self.chat_memory.clear()
|
90 |
+
|
91 |
+
|
92 |
+
|
93 |
+
|
94 |
+
|
95 |
+
class ConversationBufferWindowMemory(BaseChatMemory):
|
96 |
+
"""Buffer for storing conversation memory inside a limited size window."""
|
97 |
+
|
98 |
+
human_prefix: str = "Human"
|
99 |
+
ai_prefix: str = "AI"
|
100 |
+
memory_key: str = "history" #: :meta private:
|
101 |
+
k: int = 5
|
102 |
+
"""Number of messages to store in buffer."""
|
103 |
+
|
104 |
+
@property
|
105 |
+
def buffer(self) -> Union[str, List[BaseMessage]]:
|
106 |
+
"""String buffer of memory."""
|
107 |
+
return self.buffer_as_messages if self.return_messages else self.buffer_as_str
|
108 |
+
|
109 |
+
@property
|
110 |
+
def buffer_as_str(self) -> str:
|
111 |
+
"""Exposes the buffer as a string in case return_messages is True."""
|
112 |
+
messages = self.chat_memory.messages[-self.k * 2 :] if self.k > 0 else []
|
113 |
+
return get_buffer_string(
|
114 |
+
messages,
|
115 |
+
human_prefix=self.human_prefix,
|
116 |
+
ai_prefix=self.ai_prefix,
|
117 |
+
)
|
118 |
+
|
119 |
+
@property
|
120 |
+
def buffer_as_messages(self) -> List[BaseMessage]:
|
121 |
+
"""Exposes the buffer as a list of messages in case return_messages is False."""
|
122 |
+
return self.chat_memory.messages[-self.k * 2 :] if self.k > 0 else []
|
123 |
+
|
124 |
+
@property
|
125 |
+
def memory_variables(self) -> List[str]:
|
126 |
+
"""Will always return list of memory variables.
|
127 |
+
|
128 |
+
:meta private:
|
129 |
+
"""
|
130 |
+
return [self.memory_key]
|
131 |
+
|
132 |
+
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
133 |
+
"""Return history buffer."""
|
134 |
+
return {self.memory_key: self.buffer}
|
data/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# import os
|
2 |
+
# import sys
|
3 |
+
|
4 |
+
# if os.path.dirname(os.path.abspath(__file__)) not in sys.path:
|
5 |
+
# sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
data/splitted_texts.jsonl
ADDED
The diff for this file is too large to render.
See raw diff
|
|
ensemble_retriever.py
ADDED
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
/*************************************************************************
|
3 |
+
*
|
4 |
+
* CONFIDENTIAL
|
5 |
+
* __________________
|
6 |
+
*
|
7 |
+
* Copyright (2023-2024) AI Labs, IronOne Technologies, LLC
|
8 |
+
* All Rights Reserved
|
9 |
+
*
|
10 |
+
* Author : Theekshana Samaradiwakara
|
11 |
+
* Description :Python Backend API to chat with private data
|
12 |
+
* CreatedDate : 14/11/2023
|
13 |
+
* LastModifiedDate : 18/03/2024
|
14 |
+
*************************************************************************/
|
15 |
+
"""
|
16 |
+
|
17 |
+
"""
|
18 |
+
Ensemble retriever that ensemble the results of
|
19 |
+
multiple retrievers by using weighted Reciprocal Rank Fusion
|
20 |
+
"""
|
21 |
+
|
22 |
+
import os
|
23 |
+
import sys
|
24 |
+
|
25 |
+
from pathlib import Path
|
26 |
+
Path(__file__).resolve().parent.parent
|
27 |
+
|
28 |
+
if os.path.dirname(os.path.abspath(__file__)) not in sys.path:
|
29 |
+
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
30 |
+
|
31 |
+
|
32 |
+
import logging
|
33 |
+
logger = logging.getLogger(__name__)
|
34 |
+
from typing import Any, Dict, List
|
35 |
+
|
36 |
+
from langchain.callbacks.manager import (
|
37 |
+
AsyncCallbackManagerForRetrieverRun,
|
38 |
+
CallbackManagerForRetrieverRun,
|
39 |
+
)
|
40 |
+
from langchain.pydantic_v1 import root_validator
|
41 |
+
from langchain.schema import BaseRetriever, Document
|
42 |
+
|
43 |
+
import numpy as np
|
44 |
+
import pandas as pd
|
45 |
+
|
46 |
+
|
47 |
+
class EnsembleRetriever(BaseRetriever):
|
48 |
+
"""Retriever that ensembles the multiple retrievers.
|
49 |
+
|
50 |
+
It uses a rank fusion.
|
51 |
+
|
52 |
+
Args:
|
53 |
+
retrievers: A list of retrievers to ensemble.
|
54 |
+
weights: A list of weights corresponding to the retrievers. Defaults to equal
|
55 |
+
weighting for all retrievers.
|
56 |
+
c: A constant added to the rank, controlling the balance between the importance
|
57 |
+
of high-ranked items and the consideration given to lower-ranked items.
|
58 |
+
Default is 60.
|
59 |
+
"""
|
60 |
+
|
61 |
+
retrievers: List[BaseRetriever]
|
62 |
+
weights: List[float]
|
63 |
+
c: int = 60
|
64 |
+
date_key: str = "year"
|
65 |
+
top_k: int = 4
|
66 |
+
|
67 |
+
@root_validator(pre=True,allow_reuse=True)
|
68 |
+
def set_weights(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
69 |
+
if not values.get("weights"):
|
70 |
+
n_retrievers = len(values["retrievers"])
|
71 |
+
values["weights"] = [1 / n_retrievers] * n_retrievers
|
72 |
+
return values
|
73 |
+
|
74 |
+
def _get_relevant_documents(
|
75 |
+
self,
|
76 |
+
query: str,
|
77 |
+
*,
|
78 |
+
run_manager: CallbackManagerForRetrieverRun,
|
79 |
+
) -> List[Document]:
|
80 |
+
"""
|
81 |
+
Get the relevant documents for a given query.
|
82 |
+
|
83 |
+
Args:
|
84 |
+
query: The query to search for.
|
85 |
+
|
86 |
+
Returns:
|
87 |
+
A list of reranked documents.
|
88 |
+
"""
|
89 |
+
|
90 |
+
# Get fused result of the retrievers.
|
91 |
+
fused_documents = self.rank_fusion(query, run_manager)
|
92 |
+
|
93 |
+
# check for key exists
|
94 |
+
if fused_documents[0].metadata[self.date_key] != None:
|
95 |
+
doc_dates = pd.to_datetime(
|
96 |
+
[doc.metadata[self.date_key] for doc in fused_documents]
|
97 |
+
)
|
98 |
+
sorted_node_idxs = np.flip(doc_dates.argsort())
|
99 |
+
fused_documents = [fused_documents[idx] for idx in sorted_node_idxs]
|
100 |
+
logger.info('Ensemble Retriever Documents sorted by year')
|
101 |
+
|
102 |
+
# return fused_documents[:self.top_k]
|
103 |
+
return fused_documents
|
104 |
+
|
105 |
+
async def _aget_relevant_documents(
|
106 |
+
self,
|
107 |
+
query: str,
|
108 |
+
*,
|
109 |
+
run_manager: AsyncCallbackManagerForRetrieverRun,
|
110 |
+
) -> List[Document]:
|
111 |
+
"""
|
112 |
+
Asynchronously get the relevant documents for a given query.
|
113 |
+
|
114 |
+
Args:
|
115 |
+
query: The query to search for.
|
116 |
+
|
117 |
+
Returns:
|
118 |
+
A list of reranked documents.
|
119 |
+
"""
|
120 |
+
|
121 |
+
# Get fused result of the retrievers.
|
122 |
+
fused_documents = await self.arank_fusion(query, run_manager)
|
123 |
+
|
124 |
+
return fused_documents
|
125 |
+
|
126 |
+
def rank_fusion(
|
127 |
+
self, query: str, run_manager: CallbackManagerForRetrieverRun
|
128 |
+
) -> List[Document]:
|
129 |
+
"""
|
130 |
+
Retrieve the results of the retrievers and use rank_fusion_func to get
|
131 |
+
the final result.
|
132 |
+
|
133 |
+
Args:
|
134 |
+
query: The query to search for.
|
135 |
+
|
136 |
+
Returns:
|
137 |
+
A list of reranked documents.
|
138 |
+
"""
|
139 |
+
|
140 |
+
# Get the results of all retrievers.
|
141 |
+
retriever_docs = [
|
142 |
+
retriever.get_relevant_documents(
|
143 |
+
query, callbacks=run_manager.get_child(tag=f"retriever_{i+1}")
|
144 |
+
)
|
145 |
+
for i, retriever in enumerate(self.retrievers)
|
146 |
+
]
|
147 |
+
|
148 |
+
# apply rank fusion
|
149 |
+
fused_documents = self.weighted_reciprocal_rank(retriever_docs)
|
150 |
+
|
151 |
+
return fused_documents
|
152 |
+
|
153 |
+
async def arank_fusion(
|
154 |
+
self, query: str, run_manager: AsyncCallbackManagerForRetrieverRun
|
155 |
+
) -> List[Document]:
|
156 |
+
"""
|
157 |
+
Asynchronously retrieve the results of the retrievers
|
158 |
+
and use rank_fusion_func to get the final result.
|
159 |
+
|
160 |
+
Args:
|
161 |
+
query: The query to search for.
|
162 |
+
|
163 |
+
Returns:
|
164 |
+
A list of reranked documents.
|
165 |
+
"""
|
166 |
+
|
167 |
+
# Get the results of all retrievers.
|
168 |
+
retriever_docs = [
|
169 |
+
await retriever.aget_relevant_documents(
|
170 |
+
query, callbacks=run_manager.get_child(tag=f"retriever_{i+1}")
|
171 |
+
)
|
172 |
+
for i, retriever in enumerate(self.retrievers)
|
173 |
+
]
|
174 |
+
|
175 |
+
# apply rank fusion
|
176 |
+
fused_documents = self.weighted_reciprocal_rank(retriever_docs)
|
177 |
+
|
178 |
+
return fused_documents
|
179 |
+
|
180 |
+
def weighted_reciprocal_rank(
|
181 |
+
self, doc_lists: List[List[Document]]
|
182 |
+
) -> List[Document]:
|
183 |
+
"""
|
184 |
+
Perform weighted Reciprocal Rank Fusion on multiple rank lists.
|
185 |
+
You can find more details about RRF here:
|
186 |
+
https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf
|
187 |
+
|
188 |
+
Args:
|
189 |
+
doc_lists: A list of rank lists, where each rank list contains unique items.
|
190 |
+
|
191 |
+
Returns:
|
192 |
+
list: The final aggregated list of items sorted by their weighted RRF
|
193 |
+
scores in descending order.
|
194 |
+
"""
|
195 |
+
if len(doc_lists) != len(self.weights):
|
196 |
+
raise ValueError(
|
197 |
+
"Number of rank lists must be equal to the number of weights."
|
198 |
+
)
|
199 |
+
|
200 |
+
# Create a union of all unique documents in the input doc_lists
|
201 |
+
all_documents = set()
|
202 |
+
for doc_list in doc_lists:
|
203 |
+
for doc in doc_list:
|
204 |
+
all_documents.add(doc.page_content)
|
205 |
+
|
206 |
+
# Initialize the RRF score dictionary for each document
|
207 |
+
rrf_score_dic = {doc: 0.0 for doc in all_documents}
|
208 |
+
|
209 |
+
# Calculate RRF scores for each document
|
210 |
+
for doc_list, weight in zip(doc_lists, self.weights):
|
211 |
+
for rank, doc in enumerate(doc_list, start=1):
|
212 |
+
rrf_score = weight * (1 / (rank + self.c))
|
213 |
+
rrf_score_dic[doc.page_content] += rrf_score
|
214 |
+
|
215 |
+
# Sort documents by their RRF scores in descending order
|
216 |
+
sorted_documents = sorted(
|
217 |
+
rrf_score_dic.keys(), key=lambda x: rrf_score_dic[x], reverse=True
|
218 |
+
)
|
219 |
+
|
220 |
+
# Map the sorted page_content back to the original document objects
|
221 |
+
page_content_to_doc_map = {
|
222 |
+
doc.page_content: doc for doc_list in doc_lists for doc in doc_list
|
223 |
+
}
|
224 |
+
sorted_docs = [
|
225 |
+
page_content_to_doc_map[page_content] for page_content in sorted_documents
|
226 |
+
]
|
227 |
+
|
228 |
+
return sorted_docs
|
faissDb.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
/*************************************************************************
|
3 |
+
*
|
4 |
+
* CONFIDENTIAL
|
5 |
+
* __________________
|
6 |
+
*
|
7 |
+
* Copyright (2023-2024) AI Labs, IronOne Technologies, LLC
|
8 |
+
* All Rights Reserved
|
9 |
+
*
|
10 |
+
* Author : Theekshana Samaradiwakara
|
11 |
+
* Description :Python Backend API to chat with private data
|
12 |
+
* CreatedDate : 14/11/2023
|
13 |
+
* LastModifiedDate : 14/03/2024
|
14 |
+
*************************************************************************/
|
15 |
+
"""
|
16 |
+
|
17 |
+
import logging
|
18 |
+
logger = logging.getLogger(__name__)
|
19 |
+
import os
|
20 |
+
from dotenv import load_dotenv
|
21 |
+
|
22 |
+
# from langchain_text_splitters import RecursiveCharacterTextSplitter
|
23 |
+
# from langchain.docstore.document import Document
|
24 |
+
# from langchain_community.document_loaders import PyPDFLoader
|
25 |
+
# from langchain.document_loaders import TextLoader
|
26 |
+
# from langchain_community.document_loaders import DirectoryLoader
|
27 |
+
from langchain_community.vectorstores import FAISS
|
28 |
+
|
29 |
+
|
30 |
+
# Text-splitting parameters (used only by the commented-out index build below).
chunk_size=2000
chunk_overlap=100

# Embedding model name and on-disk location of the prebuilt FAISS index.
embeddings_model_name = "BAAI/bge-large-en-v1.5"
persist_directory = "faiss_embeddings_2024"

from langchain_community.embeddings import HuggingFaceEmbeddings
# Embeddings are loaded once at import time and shared by load_FAISS_store().
embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
print(f"> Local Embeddings loading")

load_dotenv()
|
41 |
+
|
42 |
+
# from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
|
43 |
+
# inference_api_key = os.environ.get('HUGGINGFACEHUB_API_TOKEN')
|
44 |
+
# embeddings = HuggingFaceInferenceAPIEmbeddings(
|
45 |
+
# api_key=inference_api_key, model_name=embeddings_model_name
|
46 |
+
# )
|
47 |
+
# print(f"> HuggingFace InferenceAPI Embeddings loading")
|
48 |
+
# print(f"> --- ---- ---- HuggingFace api key: {inference_api_key}")
|
49 |
+
|
50 |
+
|
51 |
+
# def create_faiss():
|
52 |
+
# # documents = DirectoryLoader(persist_directory, loader_cls=PyMuPDFLoader).load()
|
53 |
+
# documents = DirectoryLoader("CBSL", loader_cls=PyPDFLoader).load()
|
54 |
+
|
55 |
+
# text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
|
56 |
+
# texts = text_splitter.split_documents(documents)
|
57 |
+
# vectorstore = FAISS.from_documents(texts, embeddings)
|
58 |
+
# vectorstore.save_local("faiss_index")
|
59 |
+
|
60 |
+
|
61 |
+
def load_FAISS_store():
    """Load the persisted FAISS vector store from ``persist_directory``.

    Uses the module-level HuggingFace ``embeddings`` for query encoding.

    Returns:
        FAISS: the deserialized vector store.

    Raises:
        Exception: re-raised after logging if loading fails (e.g. the index
            directory is missing or corrupt).
    """
    try:
        print(f"> {persist_directory} loading")
        store = FAISS.load_local(
            persist_directory,
            embeddings,
            # The index pickle is produced by this project itself, so
            # deserializing it is trusted here.
            allow_dangerous_deserialization=True,
        )
        # BUG FIX: the original logged "loaded" *before* calling load_local,
        # so the log lied when loading failed.
        logger.info(f"{persist_directory} loaded")
        return store
    except Exception as e:
        logger.exception(e)
        raise
|
faiss_embeddings_2024/index.faiss
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a3087f02172887cbf8bfb0fb3b371843548619c2a2873fdf4629339e2031a2c1
|
3 |
+
size 10895405
|
faiss_embeddings_2024/index.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3da917c3758e2bfe0aedbd050199f4c80ec372d5b0349b49126b790fb1757db9
|
3 |
+
size 3935715
|
llm.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
/*************************************************************************
|
3 |
+
*
|
4 |
+
* CONFIDENTIAL
|
5 |
+
* __________________
|
6 |
+
*
|
7 |
+
* Copyright (2023-2024) AI Labs, IronOne Technologies, LLC
|
8 |
+
* All Rights Reserved
|
9 |
+
*
|
10 |
+
* Author : Theekshana Samaradiwakara
|
11 |
+
* Description :Python Backend API to chat with private data
|
12 |
+
* CreatedDate : 14/11/2023
|
13 |
+
* LastModifiedDate : 18/03/2024
|
14 |
+
*************************************************************************/
|
15 |
+
"""
|
16 |
+
|
17 |
+
import os
# import time
import logging
logger = logging.getLogger(__name__)
from dotenv import load_dotenv

# from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain_openai import ChatOpenAI

load_dotenv()

# SECURITY FIX: the API key must come from the environment only. The original
# code kept a hard-coded key in a comment (now removed; rotate that key) and
# printed the live key to stdout. Log only whether the key is present.
openai_api_key = os.environ.get('OPENAI_API_KEY')
print(f"--- ---- ---- openai_api_key set: {openai_api_key is not None}")

verbose = os.environ.get('VERBOSE')
|
33 |
+
|
34 |
+
def get_model(model_type):
|
35 |
+
|
36 |
+
match model_type:
|
37 |
+
case "openai":
|
38 |
+
llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0, openai_api_key=openai_api_key)
|
39 |
+
case _default:
|
40 |
+
# raise exception if model_type is not supported
|
41 |
+
msg=f"Model type '{model_type}' is not supported. Please choose a valid one"
|
42 |
+
logger.error(msg)
|
43 |
+
return Exception(msg)
|
44 |
+
|
45 |
+
|
46 |
+
logger.info(f"model_type: {model_type} loaded:")
|
47 |
+
return llm
|
llmChain.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
/*************************************************************************
|
3 |
+
*
|
4 |
+
* CONFIDENTIAL
|
5 |
+
* __________________
|
6 |
+
*
|
7 |
+
* Copyright (2023-2024) AI Labs, IronOne Technologies, LLC
|
8 |
+
* All Rights Reserved
|
9 |
+
*
|
10 |
+
* Author : Theekshana Samaradiwakara
|
11 |
+
* Description :Python Backend API to chat with private data
|
12 |
+
* CreatedDate : 14/11/2023
|
13 |
+
* LastModifiedDate : 18/03/2024
|
14 |
+
*************************************************************************/
|
15 |
+
"""
|
16 |
+
|
17 |
+
import os
|
18 |
+
import logging
|
19 |
+
logger = logging.getLogger(__name__)
|
20 |
+
from dotenv import load_dotenv
|
21 |
+
|
22 |
+
load_dotenv()
|
23 |
+
|
24 |
+
verbose = os.environ.get('VERBOSE')
|
25 |
+
|
26 |
+
from llm import get_model
|
27 |
+
from langchain.chains import ConversationalRetrievalChain
|
28 |
+
# from conversationBufferWindowMemory import ConversationBufferWindowMemory
|
29 |
+
|
30 |
+
# from langchain.prompts import PromptTemplate
|
31 |
+
from langchain.chains import LLMChain
|
32 |
+
|
33 |
+
from prompts import retrieval_qa_chain_prompt, document_combine_prompt, general_qa_chain_prompt, router_prompt
|
34 |
+
|
35 |
+
def get_qa_chain(model_type, retriever):
    """Create the ConversationalRetrievalChain used for document QA.

    Args:
        model_type: provider name passed through to ``get_model``.
        retriever: retriever supplying source documents.

    Returns:
        A configured ConversationalRetrievalChain that also returns its
        source documents.

    Raises:
        Exception: re-raised after logging if chain construction fails.
    """
    logger.info("creating qa_chain")

    try:
        qa_llm = get_model(model_type)

        # Prompts controlling answer generation and per-document formatting.
        combine_kwargs = {
            "prompt": retrieval_qa_chain_prompt,
            "document_prompt": document_combine_prompt,
        }

        qa_chain = ConversationalRetrievalChain.from_llm(
            llm=qa_llm,
            chain_type="stuff",
            retriever=retriever,
            return_source_documents=True,
            # Chat history is passed through unchanged.
            get_chat_history=lambda h: h,
            combine_docs_chain_kwargs=combine_kwargs,
            verbose=True,
        )

        logger.info("qa_chain created")
        return qa_chain

    except Exception as e:
        logger.exception(f"Error : {e}")
        raise
|
63 |
+
|
64 |
+
|
65 |
+
def get_general_qa_chain(model_type):
    """Create the LLMChain that answers greetings / general questions.

    Args:
        model_type: provider name passed through to ``get_model``.

    Returns:
        An LLMChain bound to ``general_qa_chain_prompt``.

    Raises:
        Exception: re-raised after logging if construction fails.
    """
    logger.info("creating general_qa_chain")

    try:
        llm = get_model(model_type)
        chain = LLMChain(llm=llm, prompt=general_qa_chain_prompt)
        logger.info("general_qa_chain created")
        return chain

    except Exception as e:
        logger.exception(f"Error : {e}")
        raise
|
79 |
+
|
80 |
+
|
81 |
+
def get_router_chain(model_type):
    """Create the LLMChain that classifies a question (Relevant/Greeting/Other).

    Args:
        model_type: provider name passed through to ``get_model``.

    Returns:
        An LLMChain bound to ``router_prompt``.

    Raises:
        Exception: re-raised after logging if construction fails.
    """
    logger.info("creating router_chain")

    try:
        llm = get_model(model_type)
        chain = LLMChain(llm=llm, prompt=router_prompt)
        logger.info("router_chain created")
        return chain

    except Exception as e:
        logger.exception(f"Error : {e}")
        raise
|
95 |
+
|
96 |
+
|
multi_query_retriever.py
ADDED
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
"""
|
3 |
+
/*************************************************************************
|
4 |
+
*
|
5 |
+
* CONFIDENTIAL
|
6 |
+
* __________________
|
7 |
+
*
|
8 |
+
* Copyright (2023-2024) AI Labs, IronOne Technologies, LLC
|
9 |
+
* All Rights Reserved
|
10 |
+
*
|
11 |
+
* Author : Theekshana Samaradiwakara
|
12 |
+
* Description :Python Backend API to chat with private data
|
13 |
+
* CreatedDate : 14/11/2023
|
14 |
+
* LastModifiedDate : 21/03/2024
|
15 |
+
*************************************************************************/
|
16 |
+
"""
|
17 |
+
|
18 |
+
import asyncio
|
19 |
+
import logging
|
20 |
+
from typing import List, Optional, Sequence
|
21 |
+
|
22 |
+
from langchain_core.callbacks import (
|
23 |
+
AsyncCallbackManagerForRetrieverRun,
|
24 |
+
CallbackManagerForRetrieverRun,
|
25 |
+
)
|
26 |
+
from langchain_core.documents import Document
|
27 |
+
from langchain_core.language_models import BaseLanguageModel
|
28 |
+
from langchain_core.output_parsers import BaseOutputParser
|
29 |
+
from langchain_core.prompts.prompt import PromptTemplate
|
30 |
+
from langchain_core.retrievers import BaseRetriever
|
31 |
+
|
32 |
+
from langchain.chains.llm import LLMChain
|
33 |
+
|
34 |
+
import numpy as np
|
35 |
+
import pandas as pd
|
36 |
+
|
37 |
+
logger = logging.getLogger(__name__)
|
38 |
+
|
39 |
+
from prompts import MULTY_QUERY_PROMPT
|
40 |
+
|
41 |
+
class LineListOutputParser(BaseOutputParser[List[str]]):
    """Parse an LLM completion into the list of its lines.

    Surrounding whitespace is stripped first so no empty leading/trailing
    entries appear; interior blank lines are preserved.
    """

    def parse(self, text: str) -> List[str]:
        stripped = text.strip()
        return stripped.split("\n")
|
47 |
+
|
48 |
+
|
49 |
+
# Default prompt
|
50 |
+
# DEFAULT_QUERY_PROMPT = PromptTemplate(
|
51 |
+
# input_variables=["question"],
|
52 |
+
# template="""You are an AI language model assistant. Your task is
|
53 |
+
# to generate 3 different versions of the given user
|
54 |
+
# question to retrieve relevant documents from a vector database.
|
55 |
+
# By generating multiple perspectives on the user question,
|
56 |
+
# your goal is to help the user overcome some of the limitations
|
57 |
+
# of distance-based similarity search. Provide these alternative
|
58 |
+
# questions separated by newlines. Original question: {question}""",
|
59 |
+
# )
|
60 |
+
|
61 |
+
|
62 |
+
|
63 |
+
|
64 |
+
def _unique_documents(documents: Sequence[Document]) -> List[Document]:
    """Return *documents* with duplicates removed, keeping first-seen order.

    Uses equality (not hashing), so it works for unhashable Document objects;
    cost is O(n^2) in the number of documents.
    """
    unique: List[Document] = []
    for doc in documents:
        if doc not in unique:
            unique.append(doc)
    return unique
|
66 |
+
|
67 |
+
|
68 |
+
class MultiQueryRetriever(BaseRetriever):
    """Given a query, use an LLM to write a set of queries.

    Retrieve docs for each query and return the unique union of all retrieved
    docs. The synchronous path additionally sorts the fused documents
    newest-first by their ``date_key`` metadata (when present) and truncates
    the result to ``top_k``.
    """

    retriever: BaseRetriever  # underlying retriever, run once per generated query
    llm_chain: LLMChain  # chain that rewrites the user question into variants
    verbose: bool = True
    parser_key: str = "lines"
    # DEPRECATED: parser_key is no longer used and should not be specified.
    include_original: bool = False
    # Whether to include the original query in the list of generated queries.
    date_key: str = "year"  # metadata key holding each document's publication year
    top_k: int = 4  # maximum number of documents returned by the sync path

    @classmethod
    def from_llm(
        cls,
        retriever: BaseRetriever,
        llm: BaseLanguageModel,
        prompt: PromptTemplate = MULTY_QUERY_PROMPT,
        parser_key: Optional[str] = None,
        include_original: bool = False,
    ) -> "MultiQueryRetriever":
        """Initialize from llm using default template.

        Args:
            retriever: retriever to query documents from
            llm: llm for query generation using the multi-query prompt
            include_original: Whether to include the original query in the
                list of generated queries.

        Returns:
            MultiQueryRetriever
        """
        output_parser = LineListOutputParser()
        llm_chain = LLMChain(llm=llm, prompt=prompt, output_parser=output_parser)
        return cls(
            retriever=retriever,
            llm_chain=llm_chain,
            include_original=include_original,
        )

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Get relevant documents given a user query (async path).

        NOTE(review): unlike the sync path, this does not date-sort or
        truncate to ``top_k`` — confirm whether that asymmetry is intended.

        Args:
            query: user query

        Returns:
            Unique union of relevant documents from all generated queries
        """
        queries = await self.agenerate_queries(query, run_manager)
        if self.include_original:
            queries.append(query)
        documents = await self.aretrieve_documents(queries, run_manager)
        return self.unique_union(documents)

    async def agenerate_queries(
        self, question: str, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[str]:
        """Generate queries based upon user input.

        Args:
            question: user query

        Returns:
            List of LLM generated queries that are similar to the user input
        """
        response = await self.llm_chain.ainvoke(
            inputs={"question": question}, callbacks=run_manager.get_child()
        )
        lines = response["text"]
        if self.verbose:
            logger.info(f"Generated queries: {lines}")
        return lines

    async def aretrieve_documents(
        self, queries: List[str], run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Run all LLM generated queries concurrently.

        Args:
            queries: query list

        Returns:
            List of retrieved Documents
        """
        document_lists = await asyncio.gather(
            *(
                self.retriever.aget_relevant_documents(
                    query, callbacks=run_manager.get_child()
                )
                for query in queries
            )
        )
        return [doc for docs in document_lists for doc in docs]

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Get relevant documents given a user query.

        Args:
            query: user query

        Returns:
            Unique union of relevant documents from all generated queries,
            sorted newest-first by ``date_key`` metadata when available and
            truncated to ``top_k``.
        """
        queries = self.generate_queries(query, run_manager)
        if self.include_original:
            queries.append(query)
        documents = self.retrieve_documents(queries, run_manager)
        fused_documents = self.unique_union(documents)
        # BUG FIX: the original read fused_documents[0].metadata[self.date_key]
        # directly, which raised IndexError when no documents were retrieved
        # and KeyError when a document lacked the date key (despite the
        # "check for key exists" comment). Guarded .get() lookups tolerate
        # both; pandas turns a missing (None) date into NaT while sorting.
        if fused_documents and fused_documents[0].metadata.get(self.date_key) is not None:
            doc_dates = pd.to_datetime(
                [doc.metadata.get(self.date_key) for doc in fused_documents]
            )
            sorted_node_idxs = np.flip(doc_dates.argsort())
            fused_documents = [fused_documents[idx] for idx in sorted_node_idxs]
            logger.info('Documents sorted by year')

        return fused_documents[:self.top_k]

    def generate_queries(
        self, question: str, run_manager: CallbackManagerForRetrieverRun
    ) -> List[str]:
        """Generate queries based upon user input.

        Args:
            question: user query

        Returns:
            List of LLM generated queries that are similar to the user input
        """
        response = self.llm_chain.invoke(
            {"question": question}, callbacks=run_manager.get_child()
        )
        lines = response["text"]
        if self.verbose:
            logger.info(f"Generated queries: {lines}")
        return lines

    def retrieve_documents(
        self, queries: List[str], run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Run all LLM generated queries sequentially.

        Args:
            queries: query list

        Returns:
            List of retrieved Documents (may contain duplicates)
        """
        documents = []
        for query in queries:
            logger.info(f"MQ Retriever question: {query}")
            docs = self.retriever.get_relevant_documents(
                query, callbacks=run_manager.get_child()
            )
            documents.extend(docs)
        return documents

    def unique_union(self, documents: List[Document]) -> List[Document]:
        """Get unique Documents.

        Args:
            documents: List of retrieved Documents

        Returns:
            List of unique retrieved Documents
        """
        return _unique_documents(documents)
|
254 |
+
|
output_parser.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
/*************************************************************************
|
3 |
+
*
|
4 |
+
* CONFIDENTIAL
|
5 |
+
* __________________
|
6 |
+
*
|
7 |
+
* Copyright (2023-2024) AI Labs, IronOne Technologies, LLC
|
8 |
+
* All Rights Reserved
|
9 |
+
*
|
10 |
+
* Author : Theekshana Samaradiwakara
|
11 |
+
* Description :Python Backend API to chat with private data
|
12 |
+
* CreatedDate : 14/11/2023
|
13 |
+
* LastModifiedDate : 18/03/2024
|
14 |
+
*************************************************************************/
|
15 |
+
"""
|
16 |
+
|
17 |
+
def qa_chain_output_parser(result):
    """Project a retrieval-QA chain result onto the shared API response shape.

    Args:
        result: mapping containing at least "question", "answer" and
            "source_documents".

    Returns:
        dict with exactly those three keys, values copied from *result*.
    """
    wanted = ("question", "answer", "source_documents")
    return {key: result[key] for key in wanted}
|
23 |
+
|
24 |
+
def general_qa_chain_output_parser(result):
    """Map an LLMChain result onto the shared API response shape.

    The LLMChain puts its completion under "text"; this renames it to
    "answer" and attaches an empty source-document list.

    Args:
        result: mapping containing at least "question" and "text".

    Returns:
        dict with keys "question", "answer", "source_documents".
    """
    response = {
        "question": result["question"],
        "answer": result["text"],
    }
    response["source_documents"] = []
    return response
|
30 |
+
|
31 |
+
|
32 |
+
def out_of_domain_chain_parser(query):
    """Build the canned response used when a question is out of domain.

    Args:
        query: the original user question, echoed back unchanged.

    Returns:
        dict with keys "question", "answer", "source_documents".
    """
    canned_answer = "sorry this question is out of my domain."
    return {
        "question": query,
        "answer": canned_answer,
        "source_documents": [],
    }
|
38 |
+
|
39 |
+
|
prompts.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
/*************************************************************************
|
3 |
+
*
|
4 |
+
* CONFIDENTIAL
|
5 |
+
* __________________
|
6 |
+
*
|
7 |
+
* Copyright (2024-2025) AI Labs, IronOne Technologies, LLC
|
8 |
+
* All Rights Reserved
|
9 |
+
*
|
10 |
+
* Author : Theekshana Samaradiwakara
|
11 |
+
* Description :Python Backend API to chat with private data
|
12 |
+
* CreatedDate : 14/11/2023
|
13 |
+
* LastModifiedDate : 19/03/2024
|
14 |
+
*************************************************************************/
|
15 |
+
"""
|
16 |
+
|
17 |
+
from langchain.prompts import PromptTemplate

# multi query prompt
# Used by MultiQueryRetriever to expand one user question into three
# alternative phrasings for better vector-store recall.
MULTY_QUERY_PROMPT = PromptTemplate(
    input_variables=["question"],
    template="""You are an AI language model assistant. Your task is to generate three
different versions of the given user question to retrieve relevant documents from a vector
database. By generating multiple perspectives on the user question, your goal is to help
the user overcome some of the limitations of the distance-based similarity search.
Provide these alternative questions separated by newlines.

Dont add anything extra before or after to the 3 questions. Just give 3 lines with 3 questions.
Just provide 3 lines having 3 questions only.
Answer should be in following format.

1. alternative question 1
2. alternative question 2
3. alternative question 3

Original question: {question}""",
)

#retrieval prompt
# Llama/Mixtral-style instruction delimiters; the retrieval template below
# embeds the same markers inline.
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

retrieval_qa_template = (
    """<<SYS>>
You are the AI assistant of company 'boardpac' which provide services to company board members related to banking and financial sector.

please answer the question based on the chat history provided below. Answer should be short and simple as possible and on to the point.
<chat history>: {chat_history}

If the question is related to welcomes and greetings answer accordingly.

Else If the question is related to Banking and Financial Services Sector like Banking & Financial regulations, legal framework, governance framework, compliance requirements as per Central Bank regulations.
please answer the question based only on the information provided in following central bank documents published in various years.
The published year is mentioned as the metadata 'year' of each source document.
Please notice that content of a one document of a past year can updated by a new document from a recent year.
Always try to answer with latest information and mention the year which information extracted.
If you dont know the answer say you dont know, dont try to makeup answers. Dont add any extra details that is not mentioned in the context.

<</SYS>>

[INST]
<DOCUMENTS>
{context}
</DOCUMENTS>

Question : {question}[/INST]"""
)


# Prompt for the main ConversationalRetrievalChain (see llmChain.get_qa_chain).
retrieval_qa_chain_prompt = PromptTemplate(
    input_variables=["question", "context", "chat_history"],
    template=retrieval_qa_template
)



#document combine prompt
# Formats each retrieved document (with its source metadata) before it is
# stuffed into {context} of the retrieval template above.
document_combine_prompt = PromptTemplate(
    input_variables=["source","year", "page","page_content"],
    template=
    """<doc> source: {source}, year: {year}, page: {page}, page content: {page_content} </doc>"""
)


# Router prompt: classifies a question as Relevant / Greeting / Other
# (qaPipeline.chain_selector dispatches on the label returned).
router_template_Mixtral_V0= """
You are the AI assistant of company 'boardpac' which provide services to company board members related to banking and financial sector.

If a user asks a question you have to classify it to following 3 types Relevant, Greeting, Other.

"Relevant”: If the question is related to Banking and Financial Services Sector like Banking & Financial regulations, legal framework, governance framework, compliance requirements as per Central Bank regulations.
"Greeting”: If the question is a greeting like good morning, hi my name is., thank you or General Question ask about the AI assistance of a company boardpac.
"Other”: If the question is not related to research papers.

Give the correct name of question type. If you are not sure return "Not Sure" instead.

Question : {question}
"""
router_prompt=PromptTemplate.from_template(router_template_Mixtral_V0)


# General QA prompt: answers greetings and explains the assistant's scope.
general_qa_template_Mixtral_V0= """
You are the AI assistant of company 'boardpac' which provide services to company board members related to banking and financial sector.
you can answer Banking and Financial Services Sector like Banking & Financial regulations, legal framework, governance framework, compliance requirements as per Central Bank regulations related question .

Is the provided question below a greeting? First, evaluate whether the input resembles a typical greeting or not.

Greetings are used to say 'hello' and 'how are you?' and to say 'goodbye' and 'nice speaking with you.' and 'hi, I'm (user's name).'
Greetings are words used when we want to introduce ourselves to others and when we want to find out how someone is feeling.

You can only reply to the user's greetings.
If the question is a greeting, reply accordingly as the AI assistant of company boardpac.
If the question is not related to greetings and research papers, say that it is out of your domain.
If the question is not clear enough, ask for more details and don't try to make up answers.

Answer should be polite, short, and simple.

Additionally, it's important to note that this AI assistant has access to an internal collection of research papers, and answers can be provided using the information available in those CBSL Dataset.

Question: {question}
"""

general_qa_chain_prompt = PromptTemplate.from_template(general_qa_template_Mixtral_V0)
|
123 |
+
|
qaPipeline.py
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
/*************************************************************************
|
3 |
+
*
|
4 |
+
* CONFIDENTIAL
|
5 |
+
* __________________
|
6 |
+
*
|
7 |
+
* Copyright (2023-2024) AI Labs, IronOne Technologies, LLC
|
8 |
+
* All Rights Reserved
|
9 |
+
*
|
10 |
+
* Author : Theekshana Samaradiwakara
|
11 |
+
* Description :Python Backend API to chat with private data
|
12 |
+
* CreatedDate : 14/11/2023
|
13 |
+
* LastModifiedDate : 18/03/2024
|
14 |
+
*************************************************************************/
|
15 |
+
"""
|
16 |
+
|
17 |
+
import os
import time
import logging
logger = logging.getLogger(__name__)
from dotenv import load_dotenv
from fastapi import HTTPException
from llmChain import get_qa_chain, get_general_qa_chain, get_router_chain
from output_parser import general_qa_chain_output_parser, qa_chain_output_parser, out_of_domain_chain_parser

from config import QA_MODEL_TYPE, GENERAL_QA_MODEL_TYPE, ROUTER_MODEL_TYPE, Multi_Query_MODEL_TYPE
from retriever import load_faiss_retriever, load_ensemble_retriever, load_multi_query_retriever
load_dotenv()

verbose = os.environ.get('VERBOSE')

# Model backends per chain, configured in config.py.
qa_model_type=QA_MODEL_TYPE
general_qa_model_type=GENERAL_QA_MODEL_TYPE
router_model_type=ROUTER_MODEL_TYPE #"google/flan-t5-xxl"
multi_query_model_type=Multi_Query_MODEL_TYPE #"google/flan-t5-xxl"
# model_type="tiiuae/falcon-7b-instruct"

# Active retriever is the ensemble (BM25 + FAISS); the alternatives are kept
# commented out for experimentation.
# retriever=load_faiss_retriever()
retriever=load_ensemble_retriever()
# retriever=load_multi_query_retriever(multi_query_model_type)
logger.info("retriever loaded:")

# Chains are created once at import time and reused across all requests.
qa_chain= get_qa_chain(qa_model_type,retriever)
general_qa_chain= get_general_qa_chain(general_qa_model_type)
router_chain= get_router_chain(router_model_type)
|
46 |
+
|
47 |
+
def chain_selector(chain_type, query):
    """Dispatch *query* to the chain matching the router's *chain_type* label.

    "greeting" -> general QA chain; "other" -> out-of-domain response;
    "relevant" / "not sure" -> retrieval QA chain.

    Raises:
        ValueError: if the label matches none of the known types.
    """
    normalized = chain_type.lower().strip()
    logger.info(f"chain_selector : chain_type: {normalized} Question: {query}")

    if "greeting" in normalized:
        return run_general_qa_chain(query)
    if "other" in normalized:
        return run_out_of_domain_chain(query)
    if "relevant" in normalized or "not sure" in normalized:
        return run_qa_chain(query)

    raise ValueError(f"Received invalid type '{normalized}'")
|
60 |
+
|
61 |
+
def run_agent(query):
    """Route a user question through the router chain and run the selected chain.

    Args:
        query: raw user question.

    Returns:
        dict with keys "question", "answer", "source_documents"
        (produced by whichever chain ``chain_selector`` picks).

    Raises:
        HTTPException: propagated from downstream chains.
        Exception: any other failure, re-raised after logging.
    """
    try:
        logger.info(f"run_agent : Question: {query}")
        # Get the answer from the chain
        start = time.time()
        chain_type = run_router_chain(query)
        res = chain_selector(chain_type, query)
        end = time.time()

        # BUG FIX: the original logged this success-path message at ERROR
        # level (and spammed stdout with debug prints); it is informational.
        logger.info(f"Answer (took {round(end - start, 2)} s.) \n: {res}")

        return res

    except HTTPException as e:
        logger.exception(e)
        raise

    except Exception as e:
        logger.exception(e)
        raise
|
88 |
+
|
89 |
+
|
90 |
+
def run_router_chain(query):
    """Classify *query* via the router chain into a chain-type label.

    Args:
        query: raw user question.

    Returns:
        The router LLM's text label (e.g. "Relevant", "Greeting", ...).

    Raises:
        Exception: re-raised after logging if the chain invocation fails.
    """
    try:
        logger.info(f"run_router_chain : Question: {query}")
        start = time.time()
        # Only the raw text label is needed from the LLMChain result dict.
        chain_type = router_chain.invoke(query)['text']
        elapsed = round(time.time() - start, 2)

        logger.info(f"Answer (took {elapsed} s.) chain_type: {chain_type}")
        return chain_type

    except Exception as e:
        logger.exception(e)
        raise
|
106 |
+
|
107 |
+
|
108 |
+
def run_qa_chain(query):
    """Answer *query* with the retrieval QA chain (empty chat history).

    Args:
        query: raw user question.

    Returns:
        Parsed dict with "question", "answer", "source_documents".

    Raises:
        Exception: re-raised after logging if the chain invocation fails.
    """
    try:
        logger.info(f"run_qa_chain : Question: {query}")
        start = time.time()
        # Chat history is intentionally empty: each request is standalone.
        result = qa_chain.invoke({"question": query, "chat_history": ""})
        elapsed = round(time.time() - start, 2)

        logger.info(f"Answer (took {elapsed} s.) \n: {result}")
        return qa_chain_output_parser(result)

    except Exception as e:
        logger.exception(e)
        raise
|
127 |
+
|
128 |
+
|
129 |
+
def run_general_qa_chain(query):
    """Answer a greeting / general question with the general QA chain.

    Args:
        query: raw user question.

    Returns:
        Parsed dict with "question", "answer", "source_documents" (empty list).

    Raises:
        Exception: re-raised after logging if the chain invocation fails.
    """
    try:
        logger.info(f"run_general_qa_chain : Question: {query}")
        start = time.time()
        result = general_qa_chain.invoke(query)
        elapsed = round(time.time() - start, 2)

        logger.info(f"Answer (took {elapsed} s.) \n: {result}")
        return general_qa_chain_output_parser(result)

    except Exception as e:
        logger.exception(e)
        raise
|
147 |
+
|
148 |
+
|
149 |
+
def run_out_of_domain_chain(query):
    """Return the canned out-of-domain response for *query* (no LLM call)."""
    return out_of_domain_chain_parser(query)
|
requirements.txt
ADDED
Binary file (2.24 kB). View file
|
|
retriever.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
"""
|
3 |
+
/*************************************************************************
|
4 |
+
*
|
5 |
+
* CONFIDENTIAL
|
6 |
+
* __________________
|
7 |
+
*
|
8 |
+
* Copyright (2023-2024) AI Labs, IronOne Technologies, LLC
|
9 |
+
* All Rights Reserved
|
10 |
+
*
|
11 |
+
* Author : Theekshana Samaradiwakara
|
12 |
+
* Description :Python Backend API to chat with private data
|
13 |
+
* CreatedDate : 19/03/2023
|
14 |
+
* LastModifiedDate : 19/03/2024
|
15 |
+
*************************************************************************/
|
16 |
+
"""
|
17 |
+
|
18 |
+
"""
|
19 |
+
Ensemble retriever that ensemble the results of
|
20 |
+
multiple retrievers by using weighted Reciprocal Rank Fusion
|
21 |
+
"""
|
22 |
+
|
23 |
+
import logging
|
24 |
+
logger = logging.getLogger(__name__)
|
25 |
+
|
26 |
+
from faissDb import load_FAISS_store
|
27 |
+
|
28 |
+
from langchain_community.retrievers import BM25Retriever
|
29 |
+
from langchain_community.document_loaders import PyPDFLoader
|
30 |
+
from langchain_community.document_loaders import DirectoryLoader
|
31 |
+
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
32 |
+
|
33 |
+
from langchain.schema import Document
|
34 |
+
from typing import Iterable
|
35 |
+
import json
|
36 |
+
|
37 |
+
def save_docs_to_jsonl(array: Iterable[Document], file_path: str) -> None:
    """Serialize each Document in *array* as one JSON object per line of *file_path*."""
    with open(file_path, 'w') as out:
        out.writelines(doc.json() + '\n' for doc in array)
42 |
+
def load_docs_from_jsonl(file_path) -> Iterable[Document]:
    """Read Documents back from a JSONL file written by ``save_docs_to_jsonl``.

    Returns a list with one ``Document`` per non-empty line.
    """
    with open(file_path, 'r') as jsonl_file:
        return [Document(**json.loads(line)) for line in jsonl_file]
51 |
+
def split_documents():
    """Load the yearly CBSL PDF corpus, split it into chunks and persist them.

    For each year directory ``data/CBSL/<year>`` the PDFs are loaded, every
    page is tagged with a ``year`` metadata field (so retrievers can filter
    by year), and the pages are split into 2000-character chunks with a
    100-character overlap.  All chunks are written to
    ``data/splitted_texts.jsonl`` via ``save_docs_to_jsonl``.

    Fix: the original also accumulated every raw page into a ``docs_list``
    that was never read or returned — dead memory growth over the whole
    corpus — so that accumulator is removed.
    """
    chunk_size = 2000
    chunk_overlap = 100

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)

    years = [2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022, 2023, 2024]
    splits_list = []

    for year in years:
        data_path = f"data/CBSL/{year}"
        logger.info(f"Loading year : {data_path}")

        documents = DirectoryLoader(data_path, loader_cls=PyPDFLoader).load()

        # Tag every page with its publication year before splitting so the
        # chunks inherit the metadata.
        for doc in documents:
            doc.metadata['year'] = year
            logger.info(f"{doc.metadata['year']} : {doc.metadata['source']}")

        splits_list.extend(text_splitter.split_documents(documents))

    splitted_texts_file = 'data/splitted_texts.jsonl'
    save_docs_to_jsonl(splits_list, splitted_texts_file)
80 |
+
from ensemble_retriever import EnsembleRetriever
|
81 |
+
from multi_query_retriever import MultiQueryRetriever
|
82 |
+
|
83 |
+
def load_faiss_retriever():
    """Build a plain FAISS similarity retriever.

    Returns the top 5 chunks out of 10 fetched candidates per query.
    Exceptions are logged and re-raised.
    """
    try:
        store = load_FAISS_store()
        retriever = store.as_retriever(
            # search_type="mmr",
            search_kwargs={'k': 5, 'fetch_k': 10}
        )
        logger.info("FAISS Retriever loaded:")
        return retriever

    except Exception as e:
        logger.exception(e)
        raise e
97 |
+
def load_ensemble_retriever():
    """Combine a BM25 keyword retriever with the FAISS semantic retriever.

    Both retrievers are weighted equally (0.5 / 0.5): BM25 contributes up to
    2 chunks, FAISS up to 4, and the fused result is capped at ``top_k`` = 4.
    The BM25 index is rebuilt from the persisted chunk file
    ``./data/splitted_texts.jsonl``.
    """
    try:
        corpus_file = './data/splitted_texts.jsonl'
        chunks = load_docs_from_jsonl(corpus_file)

        keyword_retriever = BM25Retriever.from_documents(chunks)
        keyword_retriever.k = 2

        semantic_retriever = load_FAISS_store().as_retriever(
            search_kwargs={'k': 4}
        )

        fused = EnsembleRetriever(
            retrievers=[keyword_retriever, semantic_retriever],
            weights=[0.5, 0.5],
        )
        fused.top_k = 4

        logger.info("EnsembleRetriever loaded:")
        return fused

    except Exception as e:
        logger.exception(e)
        raise e
121 |
+
from llm import get_model
|
122 |
+
|
123 |
+
def load_multi_query_retriever(multi_query_model_type):
    """Wrap the ensemble retriever in a query-expanding MultiQueryRetriever.

    :param multi_query_model_type: identifier passed to ``get_model`` to pick
        the LLM that generates the alternative query phrasings.
    """
    try:
        query_llm = get_model(multi_query_model_type)
        base_retriever = load_ensemble_retriever()

        retriever = MultiQueryRetriever.from_llm(
            retriever=base_retriever,
            llm=query_llm,
        )

        logger.info("MultiQueryRetriever loaded:")
        return retriever

    except Exception as e:
        logger.exception(e)
        raise e
|
schema.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
/*************************************************************************
|
3 |
+
*
|
4 |
+
* CONFIDENTIAL
|
5 |
+
* __________________
|
6 |
+
*
|
7 |
+
* Copyright (2023-2025) AI Labs, IronOne Technologies, LLC
|
8 |
+
* All Rights Reserved
|
9 |
+
*
|
10 |
+
* Author : Theekshana Samaradiwakara
|
11 |
+
* Description :Python Backend API to chat with private data
|
12 |
+
* CreatedDate : 14/11/2023
|
13 |
+
* LastModifiedDate : 15/10/2024
|
14 |
+
*************************************************************************/
|
15 |
+
"""
|
16 |
+
|
17 |
+
from typing import Optional, List, Any, Dict
|
18 |
+
from pydantic import BaseModel
|
19 |
+
|
20 |
+
|
21 |
+
class LoginRequest(BaseModel):
    """Credentials payload accepted by the POST /api/v1/login endpoint."""
    username: str
    password: str
25 |
+
|
26 |
+
# model for front end session
|
27 |
+
# Output model (UserModel response)
|
28 |
+
# model for front end session
# Output model (UserModel response)
class UserModel(BaseModel):
    """User profile returned to the front end after a successful login.

    Deliberately excludes the password; ``token`` is the session token the
    client sends with subsequent requests.
    """
    userId: int
    firstName: str
    lastName: str
    userName: str
    token: str
35 |
+
|
36 |
+
class UserQuery(BaseModel):
    """A chat question submitted by the front end to POST /api/v1/chat."""
    content: str
    userId: int
    # Fix: the annotation was `int` while the default is the string
    # "default"; model identifiers are strings (cf. get_model(...) in the
    # retriever pipeline), so the type is corrected to `str`.
    aiModel: str = "default"
41 |
+
|
42 |
+
class Document(BaseModel):
    """A source-document chunk attached to a chat answer."""
    # Optional display name; may be absent.
    name: Optional[str]
    page_content: str
    # Free-form metadata; the retriever pipeline sets keys such as
    # 'source' and 'year' — confirm the full set against the callers.
    metadata: Dict[str, Any]
47 |
+
|
48 |
+
class ResponseModel(BaseModel):
    """Response envelope for the chat endpoint: the question, the generated
    answer, and (optionally) the supporting source documents."""
    question: str
    answer: str
    # Fix: the field defaults to None, so the annotation must be Optional;
    # `List[Document] = None` mismatches annotation and default.
    source_documents: Optional[List[Document]] = None
53 |
+
|
54 |
+
# class Feedback(BaseModel):
|
55 |
+
# """
|
56 |
+
# Schema for collecting feedback from the user.
|
57 |
+
# It includes the question, bot response, and user feedback.
|
58 |
+
# """
|
59 |
+
|
60 |
+
# question: str
|
61 |
+
# botResponse: str
|
62 |
+
# userFeedback: str
|
63 |
+
# feedback: str
|
server.py
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
/*************************************************************************
|
3 |
+
*
|
4 |
+
* CONFIDENTIAL
|
5 |
+
* __________________
|
6 |
+
*
|
7 |
+
* Copyright (2023-2025) AI Labs, IronOne Technologies, LLC
|
8 |
+
* All Rights Reserved
|
9 |
+
*
|
10 |
+
* Author : Theekshana Samaradiwakara
|
11 |
+
* Description :Python Backend API to chat with private data
|
12 |
+
* CreatedDate : 14/11/2023
|
13 |
+
* LastModifiedDate : 15/10/2024
|
14 |
+
*************************************************************************/
|
15 |
+
"""
|
16 |
+
|
17 |
+
import os
|
18 |
+
import time
|
19 |
+
import sys
|
20 |
+
import logging
|
21 |
+
import datetime
|
22 |
+
import uvicorn
|
23 |
+
from dotenv import load_dotenv
|
24 |
+
|
25 |
+
from fastapi import FastAPI, APIRouter, HTTPException, status
|
26 |
+
from fastapi import HTTPException, status
|
27 |
+
from fastapi.middleware.cors import CORSMiddleware
|
28 |
+
|
29 |
+
from schema import UserQuery, ResponseModel, Document, LoginRequest, UserModel
|
30 |
+
from controller import get_QA_Answers, get_avaliable_models
|
31 |
+
|
32 |
+
|
33 |
+
def filer():
    """Return the path of the log file used by the logging file handler.

    A fixed path for now; a date-stamped variant
    (``logs/<YYYY>-<MM>-<DD>.log``) was considered but is disabled.
    """
    return "logs/log"
39 |
+
|
40 |
+
# All log output goes to a single file (see filer()); a daily rotating
# handler was considered but is disabled.
file_handler = logging.FileHandler(filer())
# file_handler = logging.handlers.TimedRotatingFileHandler(filer(),when="D")
# The handler filters to INFO even though the root logger is set to DEBUG
# below, so DEBUG records never reach the file.
file_handler.setLevel(logging.INFO)

logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s %(levelname)s (%(name)s) : %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    handlers=[file_handler],
    # force=True replaces any handlers already configured by imported
    # libraries so this configuration always wins.
    force=True,
)

logger = logging.getLogger(__name__)
|
54 |
+
load_dotenv()
# Bind address/port come from the environment.  Fix: the original called
# int(os.environ.get("APP_PORT")) which raises TypeError when APP_PORT is
# unset; fall back to sane defaults so a missing .env cannot crash startup
# (the __main__ block overrides these values anyway).
host = os.environ.get("APP_HOST", "0.0.0.0")
port = int(os.environ.get("APP_PORT", "8000"))
58 |
+
|
59 |
+
class ChatAPI:
    """Registers the chatbot's HTTP routes on an APIRouter.

    Routes: GET /api/v1/health, GET /api/v1/models,
    POST /api/v1/login, POST /api/v1/chat.
    """

    def __init__(self):
        self.router = APIRouter()
        self.router.add_api_route("/api/v1/health", self.hello, methods=["GET"])
        self.router.add_api_route("/api/v1/models", self.avaliable_models, methods=["GET"])
        self.router.add_api_route(
            "/api/v1/login", self.login, methods=["POST"], response_model=UserModel
        )
        self.router.add_api_route("/api/v1/chat", self.chat, methods=["POST"])

    async def hello(self):
        """Liveness probe: returns a constant greeting."""
        return "Hello there!"

    async def avaliable_models(self):
        """Return the list of selectable AI models; 404 when none are found."""
        logger.info("getting avaliable models")
        models = get_avaliable_models()

        if not models:
            logger.exception("models not found")
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND, detail="models not found"
            )

        return models

    async def login(self, login_request: LoginRequest):
        """Validate credentials against an in-memory dummy user and return a
        UserModel (without the password).

        NOTE(review): this logs the raw request, which includes the plaintext
        password — remove before production.
        """
        logger.info(f"username password: {login_request} ")
        # Dummy user data for demonstration (normally, you'd use a database)
        dummy_users_db = {
            "john_doe": {
                "userId": 1,
                "firstName": "John",
                "lastName": "Doe",
                "userName": "john_doe",
                "password": "password",  # Normally, passwords would be hashed and stored securely
                "token": "dummy_token_123",  # In a real scenario, this would be a JWT or another kind of token
            }
        }
        # Fetch user by username
        # user = dummy_users_db.get(login_request.username)
        # NOTE(review): the submitted username is ignored — the lookup is
        # hardcoded to "john_doe", so any username with the right password
        # logs in as John Doe. Confirm whether this is intended demo behavior.
        user = dummy_users_db.get("john_doe")
        # Validate user credentials
        if not user or user["password"] != login_request.password:
            raise HTTPException(status_code=401, detail="Invalid username or password")

        # Return the user model without the password
        return UserModel(
            userId=user["userId"],
            firstName=user["firstName"],
            lastName=user["lastName"],
            userName=user["userName"],
            token=user["token"],
        )

    async def chat(
        self, userQuery: UserQuery
    ):
        """Makes query to doc store via the Langchain pipeline.

        :param userQuery: question, user id and chosen AI model.
        :type userQuery: UserQuery

        Delegates to get_QA_Answers; HTTPExceptions pass through unchanged,
        any other failure is mapped to a 400.
        """
        logger.info(f"userQuery: {userQuery} ")

        try:
            start = time.time()
            res = get_QA_Answers(userQuery)
            logger.info(
                f"-------------------------- answer: {res} -------------------------- "
            )
            # return res
            end = time.time()
            logger.info(
                f"-------------------------- Server process (took {round(end - start, 2)} s.) \n: {res}"
            )
            print(
                f" \n -------------------------- Server process (took {round(end - start, 2)} s.) ------------------------- \n"
            )
            return res

        except HTTPException as e:
            logger.exception(e)
            raise e

        except Exception as e:
            logger.exception(e)
            raise HTTPException(status_code=400, detail=f"Error : {e}")
|
149 |
+
# initialize API: mount the ChatAPI routes on the FastAPI application.
app = FastAPI(title="Boardpac chatbot API")
api = ChatAPI()
app.include_router(api.router)

# origins = ['http://localhost:8000','http://192.168.10.100:8000']

# NOTE(review): CORS is wide open (all origins/methods/headers with
# credentials allowed) — fine for development, restrict before production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
164 |
+
if __name__ == "__main__":

    # When run as a script, bind on all interfaces at port 8000,
    # overriding the values read from the environment above.
    host = "0.0.0.0"
    port = 8000

    # config = uvicorn.Config("server:app",host=host, port=port, log_config= logging.basicConfig())
    config = uvicorn.Config("server:app", host=host, port=port)
    server = uvicorn.Server(config)
    server.run()
    # uvicorn.run(app)
|
utils/__init__.py
ADDED
File without changes
|
utils/utils.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Python Backend API to chat with private data
|
3 |
+
15/11/2023
|
4 |
+
Theekshana Samaradiwakara
|
5 |
+
"""
|
6 |
+
"""
|
7 |
+
/*************************************************************************
|
8 |
+
*
|
9 |
+
* CONFIDENTIAL
|
10 |
+
* __________________
|
11 |
+
*
|
12 |
+
* Copyright (2023-2024) AI Labs, IronOne Technologies, LLC
|
13 |
+
* All Rights Reserved
|
14 |
+
*
|
15 |
+
* Author : Theekshana Samaradiwakara
|
16 |
+
* Description :Python Backend API to chat with private data
|
17 |
+
* CreatedDate : 15/11/2023
|
18 |
+
* LastModifiedDate : 10/12/2020
|
19 |
+
*************************************************************************/
|
20 |
+
"""
|
21 |
+
|
22 |
+
# from passlib.context import CryptContext
|
23 |
+
# pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
24 |
+
|
25 |
+
|
26 |
+
# def hash(password: str):
|
27 |
+
# return pwd_context.hash(password)
|
28 |
+
|
29 |
+
|
30 |
+
# def verify(plain_password, hashed_password):
|
31 |
+
# return pwd_context.verify(plain_password, hashed_password)
|
32 |
+
|
33 |
+
|
34 |
+
|
35 |
+
import re
|
36 |
+
def is_valid_open_ai_api_key(secretKey):
    """Return True if *secretKey* matches the classic OpenAI API key shape.

    Accepts an ``sk-`` prefix followed by 32 or more ASCII alphanumerics.
    NOTE(review): newer project-scoped keys (``sk-proj-...``) contain extra
    separators and are rejected by this pattern — confirm whether they must
    be accepted.

    Fix: the ``if/else: return True/False`` around ``re.search`` is replaced
    with a direct boolean expression; the pattern is now a raw string.
    """
    return re.search(r"^sk-[a-zA-Z0-9]{32,}$", secretKey) is not None