anindya-hf-2002 commited on
Commit
b6d19d9
·
verified ·
1 Parent(s): 53a8618

Upload 21 files

Browse files
.gitignore ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ data/
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # UV
98
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ #uv.lock
102
+
103
+ # poetry
104
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
106
+ # commonly ignored for libraries.
107
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
+ #poetry.lock
109
+
110
+ # pdm
111
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112
+ #pdm.lock
113
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114
+ # in version control.
115
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116
+ .pdm.toml
117
+ .pdm-python
118
+ .pdm-build/
119
+
120
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121
+ __pypackages__/
122
+
123
+ # Celery stuff
124
+ celerybeat-schedule
125
+ celerybeat.pid
126
+
127
+ # SageMath parsed files
128
+ *.sage.py
129
+
130
+ # Environments
131
+ .env
132
+ .venv
133
+ env/
134
+ venv/
135
+ ENV/
136
+ env.bak/
137
+ venv.bak/
138
+
139
+ # Spyder project settings
140
+ .spyderproject
141
+ .spyproject
142
+
143
+ # Rope project settings
144
+ .ropeproject
145
+
146
+ # mkdocs documentation
147
+ /site
148
+
149
+ # mypy
150
+ .mypy_cache/
151
+ .dmypy.json
152
+ dmypy.json
153
+
154
+ # Pyre type checker
155
+ .pyre/
156
+
157
+ # pytype static type analyzer
158
+ .pytype/
159
+
160
+ # Cython debug symbols
161
+ cython_debug/
162
+
163
+ # PyCharm
164
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
167
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168
+ #.idea/
169
+
170
+ # PyPI configuration file
171
+ .pypirc
app.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from src.vectorstore.pinecone_db import ingest_data, get_retriever, load_documents, process_chunks, save_to_parquet
3
+ from pinecone import Pinecone
4
+ from langchain_openai import ChatOpenAI
5
+ from langchain_ollama import ChatOllama
6
+ from src.agents.workflow import run_adaptive_rag
7
+ from langgraph.pregel import GraphRecursionError
8
+ import tempfile
9
+ import os
10
+ import time
11
+ from pathlib import Path
12
+
13
+ # Page config
14
+ st.set_page_config(page_title="RAG Chat Assistant", layout="wide")
15
+
16
+ # Initialize session states
17
+ if "messages" not in st.session_state:
18
+ st.session_state.messages = []
19
+ if "documents_processed" not in st.session_state:
20
+ st.session_state.documents_processed = False
21
+ if "retriever" not in st.session_state:
22
+ st.session_state.retriever = None
23
+ if "pinecone_client" not in st.session_state:
24
+ st.session_state.pinecone_client = None
25
+
26
+ def initialize_pinecone(api_key):
27
+ """Initialize Pinecone client with API key."""
28
+ try:
29
+ return Pinecone(api_key=api_key)
30
+ except Exception as e:
31
+ st.error(f"Error initializing Pinecone: {str(e)}")
32
+ return None
33
+
34
+ def initialize_llm(llm_option, openai_api_key=None):
35
+ """Initialize LLM based on user selection."""
36
+ if llm_option == "OpenAI":
37
+ if not openai_api_key:
38
+ st.sidebar.warning("Please enter OpenAI API key.")
39
+ return None
40
+ return ChatOpenAI(api_key=openai_api_key, model="gpt-3.5-turbo")
41
+ else:
42
+ return ChatOllama(model="llama3.2", temperature=0.3, num_predict=512, top_p=0.6)
43
+
44
+ def clear_pinecone_index(pc, index_name="vector-index"):
45
+ """Clear the Pinecone index."""
46
+ try:
47
+ pc.delete_index(index_name)
48
+ st.session_state.documents_processed = False
49
+ st.session_state.retriever = None
50
+ st.success("Database cleared successfully!")
51
+ except Exception as e:
52
+ st.error(f"Error clearing database: {str(e)}")
53
+
54
+ def process_documents(uploaded_files, pc):
55
+ """Process uploaded documents and store in Pinecone."""
56
+ if not uploaded_files:
57
+ st.warning("Please upload at least one document.")
58
+ return False
59
+
60
+ with st.spinner("Processing documents..."):
61
+ temp_dir = tempfile.mkdtemp()
62
+ file_paths = []
63
+ markdown_path = Path(temp_dir) / "combined.md"
64
+ parquet_path = Path(temp_dir) / "documents.parquet"
65
+
66
+ for uploaded_file in uploaded_files:
67
+ file_path = Path(temp_dir) / uploaded_file.name
68
+ with open(file_path, "wb") as f:
69
+ f.write(uploaded_file.getvalue())
70
+ file_paths.append(str(file_path))
71
+
72
+ try:
73
+ markdown_path = load_documents(file_paths, output_path=markdown_path)
74
+ chunks = process_chunks(markdown_path, chunk_size=256, threshold=0.6)
75
+ print(f"Processed chunks: {chunks}")
76
+ parquet_path = save_to_parquet(chunks, parquet_path)
77
+
78
+ ingest_data(
79
+ pc=pc,
80
+ parquet_path=parquet_path,
81
+ text_column="text",
82
+ pinecone_client=pc
83
+ )
84
+
85
+ st.session_state.retriever = get_retriever(pc)
86
+ st.session_state.documents_processed = True
87
+
88
+ return True
89
+
90
+ except Exception as e:
91
+ st.error(f"Error processing documents: {str(e)}")
92
+ return False
93
+ finally:
94
+ for file_path in file_paths:
95
+ try:
96
+ os.remove(file_path)
97
+ except:
98
+ pass
99
+ try:
100
+ os.rmdir(temp_dir)
101
+ except:
102
+ pass
103
+
104
+ def run_rag_with_streaming(retriever, question, llm, enable_web_search=False):
105
+ """Run RAG workflow and yield streaming results."""
106
+ try:
107
+ response = run_adaptive_rag(
108
+ retriever=retriever,
109
+ question=question,
110
+ llm=llm,
111
+ top_k=5,
112
+ enable_websearch=enable_web_search
113
+ )
114
+
115
+ for word in response.split():
116
+ yield word + " "
117
+ time.sleep(0.03)
118
+
119
+ except GraphRecursionError:
120
+ response = "I apologize, but I cannot find a sufficient answer to your question in the provided documents. Please try rephrasing your question or ask something else about the content of the documents."
121
+ for word in response.split():
122
+ yield word + " "
123
+ time.sleep(0.03)
124
+
125
+ except Exception as e:
126
+ yield f"I encountered an error while processing your question: {str(e)}"
127
+
128
+ def main():
129
+ st.title("🤖 RAG Chat Assistant")
130
+
131
+ # Sidebar configuration
132
+ st.sidebar.title("Configuration")
133
+
134
+ # API Keys in sidebar
135
+ pinecone_api_key = st.sidebar.text_input("Enter Pinecone API Key:", type="password")
136
+
137
+ # LLM Selection
138
+ llm_option = st.sidebar.selectbox("Select Language Model:", ["OpenAI", "Ollama"])
139
+ openai_api_key = None
140
+ if llm_option == "OpenAI":
141
+ openai_api_key = st.sidebar.text_input("Enter OpenAI API Key:", type="password")
142
+
143
+ # Web search tool in sidebar
144
+ st.sidebar.markdown("---")
145
+ st.sidebar.markdown("### Tools")
146
+ use_web_search = st.sidebar.checkbox("Web search")
147
+
148
+ # Initialize Pinecone
149
+ if pinecone_api_key:
150
+ if st.session_state.pinecone_client is None:
151
+ st.session_state.pinecone_client = initialize_pinecone(pinecone_api_key)
152
+ else:
153
+ st.sidebar.warning("Please enter Pinecone API key to continue.")
154
+ st.stop()
155
+
156
+ # Initialize LLM
157
+ llm = initialize_llm(llm_option, openai_api_key)
158
+ if llm is None:
159
+ st.stop()
160
+
161
+ # Clear DB Button
162
+ st.sidebar.markdown("---")
163
+ if st.sidebar.button("Clear Database"):
164
+ if st.session_state.pinecone_client:
165
+ clear_pinecone_index(st.session_state.pinecone_client)
166
+ st.session_state.messages = [] # Clear chat history
167
+
168
+ # Document upload section
169
+ if not st.session_state.documents_processed:
170
+ st.header("📄 Document Upload")
171
+ uploaded_files = st.file_uploader(
172
+ "Upload your documents",
173
+ accept_multiple_files=True,
174
+ type=["pdf", "docx", "txt", "pptx", "md"]
175
+ )
176
+
177
+ if st.button("Process Documents"):
178
+ if process_documents(uploaded_files, st.session_state.pinecone_client):
179
+ st.success("Documents processed successfully!")
180
+
181
+ # Chat interface
182
+ if st.session_state.documents_processed:
183
+ st.header("💬 Chat")
184
+
185
+ # Display chat history
186
+ for message in st.session_state.messages:
187
+ with st.chat_message(message["role"]):
188
+ st.markdown(message["content"])
189
+
190
+ # Chat input
191
+ if prompt := st.chat_input("Ask a question about your documents..."):
192
+ # Display user message
193
+ with st.chat_message("user"):
194
+ if use_web_search:
195
+ st.markdown(prompt.strip() + ''' :red-background[Web Search]''')
196
+ else:
197
+ st.markdown(prompt)
198
+ st.session_state.messages.append({"role": "user", "content": prompt})
199
+
200
+ # Generate and stream response
201
+ with st.chat_message("assistant"):
202
+ response_container = st.empty()
203
+ full_response = ""
204
+
205
+ # Show spinner while processing
206
+ with st.spinner("Thinking..."):
207
+ # Stream the response
208
+ for chunk in run_rag_with_streaming(
209
+ retriever=st.session_state.retriever,
210
+ question=prompt,
211
+ llm=llm,
212
+ enable_web_search=use_web_search
213
+ ):
214
+ full_response += chunk
215
+ response_container.markdown(full_response + "▌")
216
+
217
+ # Final update without cursor
218
+ response_container.markdown(full_response)
219
+
220
+ # Save to chat history
221
+ st.session_state.messages.append(
222
+ {"role": "assistant", "content": full_response}
223
+ )
224
+
225
+ if __name__ == "__main__":
226
+ main()
main.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pinecone import Pinecone
2
+ from langchain_openai import ChatOpenAI
3
+ from src.vectorstore.pinecone_db import ingest_data, get_retriever, load_documents, process_chunks, save_to_parquet
4
+ from src.agents.workflow import run_adaptive_rag
5
+ from langgraph.pregel import GraphRecursionError
6
+ import tempfile
7
+ import os
8
+ from pathlib import Path
9
+
10
+ def initialize_pinecone(api_key):
11
+ """Initialize Pinecone client with API key."""
12
+ try:
13
+ return Pinecone(api_key=api_key)
14
+ except Exception as e:
15
+ print(f"Error initializing Pinecone: {str(e)}")
16
+ return None
17
+
18
+ def initialize_llm(api_key):
19
+ """Initialize OpenAI LLM."""
20
+ try:
21
+ return ChatOpenAI(api_key=api_key, model="gpt-3.5-turbo")
22
+ except Exception as e:
23
+ print(f"Error initializing OpenAI: {str(e)}")
24
+ return None
25
+
26
+ def process_documents(file_paths, pc):
27
+ """Process documents and store in Pinecone."""
28
+ if not file_paths:
29
+ print("No documents provided.")
30
+ return None
31
+
32
+ print("Processing documents...")
33
+ temp_dir = tempfile.mkdtemp()
34
+ markdown_path = Path(temp_dir) / "combined.md"
35
+ parquet_path = Path(temp_dir) / "documents.parquet"
36
+
37
+ try:
38
+ markdown_path = load_documents(file_paths, output_path=markdown_path)
39
+ chunks = process_chunks(markdown_path, chunk_size=256, threshold=0.6)
40
+ parquet_path = save_to_parquet(chunks, parquet_path)
41
+
42
+ ingest_data(
43
+ pc=pc,
44
+ parquet_path=parquet_path,
45
+ text_column="text",
46
+ pinecone_client=pc
47
+ )
48
+
49
+ retriever = get_retriever(pc)
50
+ print("Documents processed successfully!")
51
+ return retriever
52
+
53
+ except Exception as e:
54
+ print(f"Error processing documents: {str(e)}")
55
+ return None
56
+ finally:
57
+ try:
58
+ os.remove(markdown_path)
59
+ os.remove(parquet_path)
60
+ os.rmdir(temp_dir)
61
+ except:
62
+ pass
63
+
64
+ def main():
65
+ # Get API keys
66
+ pinecone_api_key = input("Enter your Pinecone API key: ")
67
+ openai_api_key = input("Enter your OpenAI API key: ")
68
+
69
+ # Initialize clients
70
+ pc = initialize_pinecone(pinecone_api_key)
71
+ if not pc:
72
+ return
73
+
74
+ llm = initialize_llm(openai_api_key)
75
+ if not llm:
76
+ return
77
+
78
+ # Get document paths
79
+ print("\nEnter the paths to your documents (one per line).")
80
+ print("Press Enter twice when done:")
81
+
82
+ file_paths = []
83
+ while True:
84
+ path = input()
85
+ if not path:
86
+ break
87
+ if os.path.exists(path):
88
+ file_paths.append(path)
89
+ else:
90
+ print(f"Warning: File {path} does not exist")
91
+
92
+ # Process documents
93
+ retriever = process_documents(file_paths, pc)
94
+ if not retriever:
95
+ return
96
+
97
+ # Chat loop
98
+ print("\nChat with your documents! Type 'exit' to quit.")
99
+ while True:
100
+ question = input("\nYou: ")
101
+
102
+ if question.lower() == 'exit':
103
+ print("Goodbye!")
104
+ break
105
+
106
+ try:
107
+ response = run_adaptive_rag(
108
+ retriever=retriever,
109
+ question=question,
110
+ llm=llm,
111
+ top_k=5,
112
+ enable_websearch=False
113
+ )
114
+ print("\nAssistant:", response)
115
+
116
+ except GraphRecursionError:
117
+ print("\nAssistant: I cannot find a sufficient answer to your question in the provided documents. Please try rephrasing your question or ask something else about the content of the documents.")
118
+
119
+ except Exception as e:
120
+ print(f"\nError: {str(e)}")
121
+
122
+ if __name__ == "__main__":
123
+ main()
notebooks/adaptive_rag.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
notebooks/app.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from src.vectorstore.pinecone_db import ingest_data, get_retriever, load_documents, process_chunks, save_to_parquet
3
+ from pinecone import Pinecone
4
+ from langchain_openai import ChatOpenAI
5
+ from langchain_ollama import ChatOllama
6
+ from src.agents.workflow import run_adaptive_rag
7
+ from langgraph.pregel import GraphRecursionError
8
+ import tempfile
9
+ import os
10
+ import time
11
+ from pathlib import Path
12
+
13
+ # Page config
14
+ st.set_page_config(page_title="RAG Chat Assistant", layout="wide")
15
+
16
+ # Initialize session states
17
+ if "messages" not in st.session_state:
18
+ st.session_state.messages = []
19
+ if "documents_processed" not in st.session_state:
20
+ st.session_state.documents_processed = False
21
+ if "retriever" not in st.session_state:
22
+ st.session_state.retriever = None
23
+ if "pinecone_client" not in st.session_state:
24
+ st.session_state.pinecone_client = None
25
+
26
+ def initialize_pinecone(api_key):
27
+ """Initialize Pinecone client with API key."""
28
+ try:
29
+ return Pinecone(api_key=api_key)
30
+ except Exception as e:
31
+ st.error(f"Error initializing Pinecone: {str(e)}")
32
+ return None
33
+
34
+ def initialize_llm(llm_option, openai_api_key=None):
35
+ """Initialize LLM based on user selection."""
36
+ if llm_option == "OpenAI":
37
+ if not openai_api_key:
38
+ st.sidebar.warning("Please enter OpenAI API key.")
39
+ return None
40
+ return ChatOpenAI(api_key=openai_api_key, model="gpt-3.5-turbo")
41
+ else:
42
+ return ChatOllama(model="llama3.2", temperature=0.3, num_predict=512, top_p=0.6)
43
+
44
+ def clear_pinecone_index(pc, index_name="vector-index"):
45
+ """Clear the Pinecone index."""
46
+ try:
47
+ pc.delete_index(index_name)
48
+ st.session_state.documents_processed = False
49
+ st.session_state.retriever = None
50
+ st.success("Database cleared successfully!")
51
+ except Exception as e:
52
+ st.error(f"Error clearing database: {str(e)}")
53
+
54
+ def process_documents(uploaded_files, pc):
55
+ """Process uploaded documents and store in Pinecone."""
56
+ if not uploaded_files:
57
+ st.warning("Please upload at least one document.")
58
+ return False
59
+
60
+ with st.spinner("Processing documents..."):
61
+ temp_dir = tempfile.mkdtemp()
62
+ file_paths = []
63
+ markdown_path = Path(temp_dir) / "combined.md"
64
+ parquet_path = Path(temp_dir) / "documents.parquet"
65
+
66
+ for uploaded_file in uploaded_files:
67
+ file_path = Path(temp_dir) / uploaded_file.name
68
+ with open(file_path, "wb") as f:
69
+ f.write(uploaded_file.getvalue())
70
+ file_paths.append(str(file_path))
71
+
72
+ try:
73
+ markdown_path = load_documents(file_paths, output_path=markdown_path)
74
+ chunks = process_chunks(markdown_path, chunk_size=256, threshold=0.6)
75
+ print(f"Processed chunks: {chunks}")
76
+ parquet_path = save_to_parquet(chunks, parquet_path)
77
+
78
+ ingest_data(
79
+ pc=pc,
80
+ parquet_path=parquet_path,
81
+ text_column="text",
82
+ pinecone_client=pc
83
+ )
84
+
85
+ st.session_state.retriever = get_retriever(pc)
86
+ st.session_state.documents_processed = True
87
+
88
+ return True
89
+
90
+ except Exception as e:
91
+ st.error(f"Error processing documents: {str(e)}")
92
+ return False
93
+ finally:
94
+ for file_path in file_paths:
95
+ try:
96
+ os.remove(file_path)
97
+ except:
98
+ pass
99
+ try:
100
+ os.rmdir(temp_dir)
101
+ except:
102
+ pass
103
+
104
+ def run_rag_with_streaming(retriever, question, llm, enable_web_search=False):
105
+ """Run RAG workflow and yield streaming results."""
106
+ try:
107
+ response = run_adaptive_rag(
108
+ retriever=retriever,
109
+ question=question,
110
+ llm=llm,
111
+ top_k=5,
112
+ enable_websearch=enable_web_search
113
+ )
114
+
115
+ for word in response.split():
116
+ yield word + " "
117
+ time.sleep(0.03)
118
+
119
+ except GraphRecursionError:
120
+ response = "I apologize, but I cannot find a sufficient answer to your question in the provided documents. Please try rephrasing your question or ask something else about the content of the documents."
121
+ for word in response.split():
122
+ yield word + " "
123
+ time.sleep(0.03)
124
+
125
+ except Exception as e:
126
+ yield f"I encountered an error while processing your question: {str(e)}"
127
+
128
+ def main():
129
+ st.title("🤖 RAG Chat Assistant")
130
+
131
+ # Sidebar configuration
132
+ st.sidebar.title("Configuration")
133
+
134
+ # API Keys in sidebar
135
+ pinecone_api_key = st.sidebar.text_input("Enter Pinecone API Key:", type="password")
136
+
137
+ # LLM Selection
138
+ llm_option = st.sidebar.selectbox("Select Language Model:", ["OpenAI", "Ollama"])
139
+ openai_api_key = None
140
+ if llm_option == "OpenAI":
141
+ openai_api_key = st.sidebar.text_input("Enter OpenAI API Key:", type="password")
142
+
143
+ # Web search tool in sidebar
144
+ st.sidebar.markdown("---")
145
+ st.sidebar.markdown("### Tools")
146
+ use_web_search = st.sidebar.checkbox("Web search")
147
+
148
+ # Initialize Pinecone
149
+ if pinecone_api_key:
150
+ if st.session_state.pinecone_client is None:
151
+ st.session_state.pinecone_client = initialize_pinecone(pinecone_api_key)
152
+ else:
153
+ st.sidebar.warning("Please enter Pinecone API key to continue.")
154
+ st.stop()
155
+
156
+ # Initialize LLM
157
+ llm = initialize_llm(llm_option, openai_api_key)
158
+ if llm is None:
159
+ st.stop()
160
+
161
+ # Clear DB Button
162
+ st.sidebar.markdown("---")
163
+ if st.sidebar.button("Clear Database"):
164
+ if st.session_state.pinecone_client:
165
+ clear_pinecone_index(st.session_state.pinecone_client)
166
+ st.session_state.messages = [] # Clear chat history
167
+
168
+ # Document upload section
169
+ if not st.session_state.documents_processed:
170
+ st.header("📄 Document Upload")
171
+ uploaded_files = st.file_uploader(
172
+ "Upload your documents",
173
+ accept_multiple_files=True,
174
+ type=["pdf", "docx", "txt", "pptx", "md"]
175
+ )
176
+
177
+ if st.button("Process Documents"):
178
+ if process_documents(uploaded_files, st.session_state.pinecone_client):
179
+ st.success("Documents processed successfully!")
180
+
181
+ # Chat interface
182
+ if st.session_state.documents_processed:
183
+ st.header("💬 Chat")
184
+
185
+ # Display chat history
186
+ for message in st.session_state.messages:
187
+ with st.chat_message(message["role"]):
188
+ st.markdown(message["content"])
189
+
190
+ # Chat input
191
+ if prompt := st.chat_input("Ask a question about your documents..."):
192
+ # Display user message
193
+ with st.chat_message("user"):
194
+ if use_web_search:
195
+ st.markdown(prompt.strip() + ''' :red-background[Web Search]''')
196
+ else:
197
+ st.markdown(prompt)
198
+ st.session_state.messages.append({"role": "user", "content": prompt})
199
+
200
+ # Generate and stream response
201
+ with st.chat_message("assistant"):
202
+ response_container = st.empty()
203
+ full_response = ""
204
+
205
+ # Show spinner while processing
206
+ with st.spinner("Thinking..."):
207
+ # Stream the response
208
+ for chunk in run_rag_with_streaming(
209
+ retriever=st.session_state.retriever,
210
+ question=prompt,
211
+ llm=llm,
212
+ enable_web_search=use_web_search
213
+ ):
214
+ full_response += chunk
215
+ response_container.markdown(full_response + "▌")
216
+
217
+ # Final update without cursor
218
+ response_container.markdown(full_response)
219
+
220
+ # Save to chat history
221
+ st.session_state.messages.append(
222
+ {"role": "assistant", "content": full_response}
223
+ )
224
+
225
+ if __name__ == "__main__":
226
+ main()
requirements.txt ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ langchain-community
2
+ tiktoken
3
+ langchain-openai
4
+ langchainhub
5
+ chromadb
6
+ langchain
7
+ langgraph
8
+ duckduckgo-search
9
+ langchain-groq
10
+ langchain-huggingface
11
+ sentence_transformers
12
+ tavily-python
13
+ langchain-ollama
14
+ ollama
15
+ crawl4ai
16
+ docling
17
+ easyocr
18
+ FlagEmbedding
19
+ chonkie[semantic]
20
+ pinecone
21
+ streamlit
src/__init__.py ADDED
File without changes
src/agents/__init__.py ADDED
File without changes
src/agents/router.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_core.prompts import ChatPromptTemplate
2
+ from langchain_ollama import ChatOllama
3
+ from pydantic import BaseModel, Field
4
+ from typing import Literal
5
+
6
+ class RouteQuery(BaseModel):
7
+ """Route a user query to the most relevant datasource."""
8
+ datasource: Literal["vectorstore", "web_search"] = Field(
9
+ description="Route question to web search or vectorstore retrieval"
10
+ )
11
+
12
+ def create_query_router():
13
+ """
14
+ Create a query router to determine data source for a given question.
15
+
16
+ Returns:
17
+ Callable: Query router function
18
+ """
19
+ # LLM with function call
20
+ llm = ChatOllama(model = "llama3.2", temperature = 0.1, num_predict = 256, top_p=0.5)
21
+ structured_llm_router = llm.with_structured_output(RouteQuery)
22
+
23
+ # Prompt
24
+ system = """You are an expert at routing a user question to a vectorstore or web search.
25
+ The vectorstore contains documents related to agents, prompt engineering, and adversarial attacks.
26
+ Use the vectorstore for questions on these topics. Otherwise, use web-search."""
27
+
28
+ route_prompt = ChatPromptTemplate.from_messages([
29
+ ("system", system),
30
+ ("human", "{question}"),
31
+ ])
32
+
33
+ return route_prompt | structured_llm_router
34
+
35
+ def route_query(question: str):
36
+ """
37
+ Route a specific query to its appropriate data source.
38
+
39
+ Args:
40
+ question (str): User's input question
41
+
42
+ Returns:
43
+ str: Recommended data source
44
+ """
45
+ router = create_query_router()
46
+ result = router.invoke({"question": question})
47
+ return result.datasource
48
+
49
+ if __name__ == "__main__":
50
+ # Example usage
51
+ test_questions = [
52
+ "Who will the Bears draft first in the NFL draft?",
53
+ "What are the types of agent memory?"
54
+ ]
55
+
56
+ for q in test_questions:
57
+ source = route_query(q)
58
+ print(f"Question: {q}")
59
+ print(f"Routed to: {source}\n")
src/agents/state.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, TypedDict
2
+ from langchain_core.documents.base import Document
3
+
4
+ class GraphState(TypedDict):
5
+ """
6
+ Represents the state of our adaptive RAG graph.
7
+
8
+ Attributes:
9
+ question (str): Original user question
10
+ generation (str, optional): LLM generated answer
11
+ documents (List[Document], optional): Retrieved or searched documents
12
+ """
13
+ question: str
14
+ generation: str | None
15
+ documents: List[Document]
src/agents/workflow.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langgraph.graph import END, StateGraph, START
2
+ from langchain_core.prompts import PromptTemplate
3
+ from agents.state import GraphState
4
+ # from agents.router import route_query
5
+ import asyncio
6
+ from vectorstore.pinecone_db import get_retriever
7
+ from tools.web_search import AdvancedWebCrawler
8
+ from llm.graders import (
9
+ grade_document_relevance,
10
+ check_hallucination,
11
+ grade_answer_quality
12
+ )
13
+ from langchain_core.output_parsers import StrOutputParser
14
+ from llm.query_rewriter import rewrite_query
15
+ from langchain_ollama import ChatOllama
16
+
17
+ def perform_web_search(question: str):
18
+ """
19
+ Perform web search using the AdvancedWebCrawler.
20
+
21
+ Args:
22
+ question (str): User's input question
23
+
24
+ Returns:
25
+ List: Web search results
26
+ """
27
+ # Initialize web crawler
28
+ crawler = AdvancedWebCrawler(
29
+ max_search_results=5,
30
+ word_count_threshold=50,
31
+ content_filter_type='f',
32
+ filter_threshold=0.48
33
+ )
34
+ results = asyncio.run(crawler.search_and_crawl(question))
35
+
36
+ return results
37
+
38
+
39
+ def create_adaptive_rag_workflow(retriever, llm, top_k=5, enable_websearch=False):
40
+ """
41
+ Create the adaptive RAG workflow graph.
42
+
43
+ Args:
44
+ retriever: Vector store retriever
45
+
46
+ Returns:
47
+ Compiled LangGraph workflow
48
+ """
49
+ def retrieve(state: GraphState):
50
+ """Retrieve documents from vectorstore."""
51
+ print("---RETRIEVE---")
52
+ question = state['question']
53
+ documents = retriever.invoke(question, top_k)
54
+ print(f"Retrieved {len(documents)} documents.")
55
+ print(documents)
56
+ return {"documents": documents, "question": question}
57
+
58
+ def route_to_datasource(state: GraphState):
59
+ """Route question to web search or vectorstore."""
60
+ print("---ROUTE QUESTION---")
61
+ # question = state['question']
62
+ # source = route_query(question)
63
+
64
+ if enable_websearch:
65
+ print("---ROUTE TO WEB SEARCH---")
66
+ return "web_search"
67
+ else:
68
+ print("---ROUTE TO RAG---")
69
+ return "vectorstore"
70
+
71
+ def generate_answer(state: GraphState):
72
+ """Generate answer using retrieved documents."""
73
+ print("---GENERATE---")
74
+ question = state['question']
75
+ documents = state['documents']
76
+
77
+ # Prepare context
78
+ context = "\n\n".join([doc["page_content"] for doc in documents])
79
+ prompt_template = PromptTemplate.from_template("""You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise.
80
+ Question: {question}
81
+ Context: {context}
82
+ Answer:""")
83
+ # Generate answer
84
+ rag_chain = prompt_template | llm | StrOutputParser()
85
+
86
+ generation = rag_chain.invoke({"context": context, "question": question})
87
+
88
+ return {"generation": generation, "documents": documents, "question": question}
89
+
90
+ def grade_documents(state: GraphState):
91
+ """Filter relevant documents."""
92
+ print("---GRADE DOCUMENTS---")
93
+ question = state['question']
94
+ documents = state['documents']
95
+
96
+ # Filter documents
97
+ filtered_docs = []
98
+ for doc in documents:
99
+ score = grade_document_relevance(question, doc["page_content"], llm)
100
+ if score == "yes":
101
+ filtered_docs.append(doc)
102
+
103
+ return {"documents": filtered_docs, "question": question}
104
+
105
+ def web_search(state: GraphState):
106
+ """Perform web search."""
107
+ print("---WEB SEARCH---")
108
+ question = state['question']
109
+
110
+ # Perform web search
111
+ results = perform_web_search(question)
112
+ web_documents = [
113
+ {
114
+ "page_content": result['content'],
115
+ "metadata": {"source": result['url']}
116
+ } for result in results
117
+ ]
118
+
119
+ return {"documents": web_documents, "question": question}
120
+
121
+ def check_generation_quality(state: GraphState):
122
+ """Check the quality of generated answer."""
123
+ print("---ASSESS GENERATION---")
124
+ question = state['question']
125
+ documents = state['documents']
126
+ generation = state['generation']
127
+
128
+
129
+ print("---Generation is not hallucinated.---")
130
+ # Check answer quality
131
+ quality_score = grade_answer_quality(question, generation, llm)
132
+ if quality_score == "yes":
133
+ print("---Answer quality is good.---")
134
+ else:
135
+ print("---Answer quality is poor.---")
136
+ return "end" if quality_score == "yes" else "rewrite"
137
+
138
+ # Create workflow
139
+ workflow = StateGraph(GraphState)
140
+
141
+ # Add nodes
142
+ workflow.add_node("vectorstore", retrieve)
143
+ workflow.add_node("web_search", web_search)
144
+ workflow.add_node("grade_documents", grade_documents)
145
+ workflow.add_node("generate", generate_answer)
146
+ workflow.add_node("rewrite_query", lambda state: {
147
+ "question": rewrite_query(state['question'], llm),
148
+ "documents": [],
149
+ "generation": None
150
+ })
151
+
152
+ # Define edges
153
+ workflow.add_conditional_edges(
154
+ START,
155
+ route_to_datasource,
156
+ {
157
+ "web_search": "web_search",
158
+ "vectorstore": "vectorstore"
159
+ }
160
+ )
161
+
162
+ workflow.add_edge("web_search", "generate")
163
+ workflow.add_edge("vectorstore", "grade_documents")
164
+
165
+ workflow.add_conditional_edges(
166
+ "grade_documents",
167
+ lambda state: "generate" if state['documents'] else "rewrite_query"
168
+ )
169
+
170
+ workflow.add_edge("rewrite_query", "vectorstore")
171
+
172
+ workflow.add_conditional_edges(
173
+ "generate",
174
+ check_generation_quality,
175
+ {
176
+ "end": END,
177
+ "regenerate": "generate",
178
+ "rewrite": "rewrite_query"
179
+ }
180
+ )
181
+
182
+ # Compile the workflow
183
+ app = workflow.compile()
184
+ return app
185
+
186
+ def run_adaptive_rag(retriever, question: str, llm, top_k=5, enable_websearch=False):
187
+ """
188
+ Run the adaptive RAG workflow for a given question.
189
+
190
+ Args:
191
+ retriever: Vector store retriever
192
+ question (str): User's input question
193
+
194
+ Returns:
195
+ str: Generated answer
196
+ """
197
+ # Create workflow
198
+ workflow = create_adaptive_rag_workflow(retriever, llm, top_k, enable_websearch=enable_websearch)
199
+
200
+ # Run workflow
201
+ final_state = None
202
+ for output in workflow.stream({"question": question}, config={"recursion_limit": 5}):
203
+ for key, value in output.items():
204
+ print(f"Node '{key}':")
205
+ # Optionally print state details
206
+ # print(value)
207
+ final_state = value
208
+
209
+ return final_state.get('generation', 'No answer could be generated.')
210
+
211
+ if __name__ == "__main__":
212
+ # Example usage
213
+ from vectorstore.pinecone_db import PINECONE_API_KEY, ingest_data, get_retriever, load_documents, process_chunks, save_to_parquet
214
+ from pinecone import Pinecone
215
+
216
+ # Load and prepare documents
217
+ pc = Pinecone(api_key=PINECONE_API_KEY)
218
+
219
+ # Define input files
220
+ file_paths=[
221
+ # './data/2404.19756v1.pdf',
222
+ # './data/OD429347375590223100.pdf',
223
+ # './data/Project Report Format.docx',
224
+ './data/UNIT 2 GENDER BASED VIOLENCE.pptx'
225
+ ]
226
+
227
+ # Process pipeline
228
+ try:
229
+ # Step 1: Load and combine documents
230
+ print("Loading documents...")
231
+ markdown_path = load_documents(file_paths)
232
+
233
+ # Step 2: Process into chunks with embeddings
234
+ print("Processing chunks...")
235
+ chunks = process_chunks(markdown_path)
236
+
237
+ # Step 3: Save to Parquet
238
+ print("Saving to Parquet...")
239
+ parquet_path = save_to_parquet(chunks)
240
+
241
+ # Step 4: Ingest into Pinecone
242
+ print("Ingesting into Pinecone...")
243
+ ingest_data(pc,
244
+ parquet_path=parquet_path,
245
+ text_column="text",
246
+ pinecone_client=pc,
247
+ )
248
+
249
+ # Step 5: Test retrieval
250
+ print("\nTesting retrieval...")
251
+ retriever = get_retriever(
252
+ pinecone_client=pc,
253
+ index_name="vector-index",
254
+ namespace="rag"
255
+ )
256
+
257
+ except Exception as e:
258
+ print(f"Error in pipeline: {str(e)}")
259
+
260
+ llm = ChatOllama(model = "llama3.2", temperature = 0.1, num_predict = 256, top_p=0.5)
261
+
262
+ # Test questions
263
+ test_questions = [
264
+ # "What are the key components of AI agent memory?",
265
+ # "Explain prompt engineering techniques",
266
+ # "What are recent advancements in adversarial attacks on LLMs?"
267
+ "what are the trending papers that are published in NeurIPS 2024?"
268
+ ]
269
+
270
+ # Run workflow for each test question
271
+ for question in test_questions:
272
+ print(f"\n--- Processing Question: {question} ---")
273
+ answer = run_adaptive_rag(retriever, question, llm)
274
+ print("\nFinal Answer:", answer)
src/data_processing/__init__.py ADDED
File without changes
src/data_processing/chunker.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+ import numpy as np
3
+ from chonkie.embeddings import BaseEmbeddings
4
+ from FlagEmbedding import BGEM3FlagModel
5
+ from chonkie import SDPMChunker as SDPMChunker
6
+
7
+ class BGEM3Embeddings(BaseEmbeddings):
8
+ def __init__(self, model_name):
9
+ self.model = BGEM3FlagModel(model_name, use_fp16=True)
10
+ self.task = "separation"
11
+
12
+ @property
13
+ def dimension(self):
14
+ return 1024
15
+
16
+ def embed(self, text: str):
17
+ e = self.model.encode([text], return_dense=True, return_sparse=False, return_colbert_vecs=False)['dense_vecs']
18
+ # print(e)
19
+ return e
20
+
21
+ def embed_batch(self, texts: List[str]):
22
+ embeddings = self.model.encode(texts, return_dense=True, return_sparse=False, return_colbert_vecs=False
23
+ )
24
+ # print(embeddings['dense_vecs'])
25
+ return embeddings['dense_vecs']
26
+
27
+ def count_tokens(self, text: str):
28
+ l = len(self.model.tokenizer.encode(text))
29
+ # print(l)
30
+ return l
31
+
32
+ def count_tokens_batch(self, texts: List[str]):
33
+ encodings = self.model.tokenizer(texts)
34
+ # print([len(enc) for enc in encodings["input_ids"]])
35
+ return [len(enc) for enc in encodings["input_ids"]]
36
+
37
+ def get_tokenizer_or_token_counter(self):
38
+ return self.model.tokenizer
39
+
40
+ def similarity(self, u: "np.ndarray", v: "np.ndarray"):
41
+ """Compute cosine similarity between two embeddings."""
42
+ s = ([email protected])#.item()
43
+ # print(s)
44
+ return s
45
+
46
+ @classmethod
47
+ def is_available(cls):
48
+ return True
49
+
50
+ def __repr__(self):
51
+ return "bgem3"
52
+
53
+
54
+ def main():
55
+ # Initialize the BGE M3 embeddings model
56
+ embedding_model = BGEM3Embeddings(
57
+ model_name="BAAI/bge-m3"
58
+ )
59
+
60
+ # Initialize the SDPM chunker
61
+ chunker = SDPMChunker(
62
+ embedding_model=embedding_model,
63
+ chunk_size=256,
64
+ threshold=0.7,
65
+ skip_window=2
66
+ )
67
+
68
+ with open('./output.md', 'r') as file:
69
+ text = file.read()
70
+
71
+ # Generate chunks
72
+ chunks = chunker.chunk(text)
73
+
74
+ # Print the chunks
75
+ for i, chunk in enumerate(chunks, 1):
76
+ print(f"\nChunk {i}:")
77
+ print(f"Text: {chunk.text}")
78
+ print(f"Token count: {chunk.token_count}")
79
+ print(f"Start index: {chunk.start_index}")
80
+ print(f"End index: {chunk.end_index}")
81
+ print(f"no of sentences: {len(chunk.sentences)}")
82
+ print("-" * 80)
83
+
84
+ if __name__ == "__main__":
85
+ main()
src/data_processing/loader.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from typing import List, Union
3
+ import logging
4
+ from dataclasses import dataclass
5
+
6
+ from langchain_core.documents import Document as LCDocument
7
+ from langchain_core.document_loaders import BaseLoader
8
+ from docling.document_converter import DocumentConverter, PdfFormatOption
9
+ from docling.datamodel.base_models import InputFormat, ConversionStatus
10
+ from docling.datamodel.pipeline_options import (
11
+ PdfPipelineOptions,
12
+ EasyOcrOptions
13
+ )
14
+
15
+ logging.basicConfig(level=logging.INFO)
16
+ _log = logging.getLogger(__name__)
17
+
18
+ @dataclass
19
+ class ProcessingResult:
20
+ """Store results of document processing"""
21
+ success_count: int = 0
22
+ failure_count: int = 0
23
+ partial_success_count: int = 0
24
+ failed_files: List[str] = None
25
+
26
+ def __post_init__(self):
27
+ if self.failed_files is None:
28
+ self.failed_files = []
29
+
30
+ class MultiFormatDocumentLoader(BaseLoader):
31
+ """Loader for multiple document formats that converts to LangChain documents"""
32
+
33
+ def __init__(
34
+ self,
35
+ file_paths: Union[str, List[str]],
36
+ enable_ocr: bool = True,
37
+ enable_tables: bool = True
38
+ ):
39
+ self._file_paths = [file_paths] if isinstance(file_paths, str) else file_paths
40
+ self._enable_ocr = enable_ocr
41
+ self._enable_tables = enable_tables
42
+ self._converter = self._setup_converter()
43
+
44
+ def _setup_converter(self):
45
+ """Set up the document converter with appropriate options"""
46
+ # Configure pipeline options
47
+ pipeline_options = PdfPipelineOptions(do_ocr=False, do_table_structure=False, ocr_options=EasyOcrOptions(
48
+ force_full_page_ocr=True
49
+ ))
50
+ if self._enable_ocr:
51
+ pipeline_options.do_ocr = True
52
+ if self._enable_tables:
53
+ pipeline_options.do_table_structure = True
54
+ pipeline_options.table_structure_options.do_cell_matching = True
55
+
56
+ # Create converter with supported formats
57
+ return DocumentConverter(
58
+ allowed_formats=[
59
+ InputFormat.PDF,
60
+ InputFormat.IMAGE,
61
+ InputFormat.DOCX,
62
+ InputFormat.HTML,
63
+ InputFormat.PPTX,
64
+ InputFormat.ASCIIDOC,
65
+ InputFormat.MD,
66
+ ],
67
+ format_options={
68
+ InputFormat.PDF: PdfFormatOption(
69
+ pipeline_options=pipeline_options,
70
+ )}
71
+ )
72
+
73
+ def lazy_load(self):
74
+ """Convert documents and yield LangChain documents"""
75
+ results = ProcessingResult()
76
+
77
+ for file_path in self._file_paths:
78
+ try:
79
+ path = Path(file_path)
80
+ if not path.exists():
81
+ _log.warning(f"File not found: {file_path}")
82
+ results.failure_count += 1
83
+ results.failed_files.append(file_path)
84
+ continue
85
+
86
+ conversion_result = self._converter.convert(path)
87
+
88
+ if conversion_result.status == ConversionStatus.SUCCESS:
89
+ results.success_count += 1
90
+ text = conversion_result.document.export_to_markdown()
91
+ metadata = {
92
+ 'source': str(path),
93
+ 'file_type': path.suffix,
94
+ }
95
+ yield LCDocument(
96
+ page_content=text,
97
+ metadata=metadata
98
+ )
99
+ elif conversion_result.status == ConversionStatus.PARTIAL_SUCCESS:
100
+ results.partial_success_count += 1
101
+ _log.warning(f"Partial conversion for {file_path}")
102
+ text = conversion_result.document.export_to_markdown()
103
+ metadata = {
104
+ 'source': str(path),
105
+ 'file_type': path.suffix,
106
+ 'conversion_status': 'partial'
107
+ }
108
+ yield LCDocument(
109
+ page_content=text,
110
+ metadata=metadata
111
+ )
112
+ else:
113
+ results.failure_count += 1
114
+ results.failed_files.append(file_path)
115
+ _log.error(f"Failed to convert {file_path}")
116
+
117
+ except Exception as e:
118
+ _log.error(f"Error processing {file_path}: {str(e)}")
119
+ results.failure_count += 1
120
+ results.failed_files.append(file_path)
121
+
122
+ # Log final results
123
+ total = results.success_count + results.partial_success_count + results.failure_count
124
+ _log.info(
125
+ f"Processed {total} documents:\n"
126
+ f"- Successfully converted: {results.success_count}\n"
127
+ f"- Partially converted: {results.partial_success_count}\n"
128
+ f"- Failed: {results.failure_count}"
129
+ )
130
+ if results.failed_files:
131
+ _log.info("Failed files:")
132
+ for file in results.failed_files:
133
+ _log.info(f"- {file}")
134
+
135
+
136
+ if __name__ == '__main__':
137
+ # Load documents from a list of file paths
138
+ loader = MultiFormatDocumentLoader(
139
+ file_paths=[
140
+ # './data/2404.19756v1.pdf',
141
+ # './data/OD429347375590223100.pdf',
142
+ './data/Project Report Format.docx',
143
+ # './data/UNIT 2 GENDER BASED VIOLENCE.pptx'
144
+ ],
145
+ enable_ocr=False,
146
+ enable_tables=True
147
+ )
148
+ for doc in loader.lazy_load():
149
+ print(doc.page_content)
150
+ print(doc.metadata)
151
+ # save document in .md file
152
+ with open('output.md', 'w') as f:
153
+ f.write(doc.page_content)
src/llm/__init__.py ADDED
File without changes
src/llm/graders.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_core.prompts import ChatPromptTemplate
2
+ from langchain_ollama import ChatOllama
3
+ from pydantic import BaseModel, Field
4
+ from typing import List
5
+
6
+ class DocumentRelevance(BaseModel):
7
+ """Binary score for relevance check on retrieved documents."""
8
+ binary_score: str = Field(
9
+ description="Documents are relevant to the question, 'yes' or 'no'"
10
+ )
11
+
12
+ class HallucinationCheck(BaseModel):
13
+ """Binary score for hallucination present in generation answer."""
14
+ binary_score: str = Field(
15
+ description="Answer is grounded in the facts, 'yes' or 'no'"
16
+ )
17
+
18
+ class AnswerQuality(BaseModel):
19
+ """Binary score to assess answer addresses question."""
20
+ binary_score: str = Field(
21
+ description="Answer addresses the question, 'yes' or 'no'"
22
+ )
23
+
24
+ def create_llm_grader(grader_type: str, llm):
25
+ """
26
+ Create an LLM grader based on the specified type.
27
+
28
+ Args:
29
+ grader_type (str): Type of grader to create
30
+
31
+ Returns:
32
+ Callable: LLM grader function
33
+ """
34
+ # Initialize LLM
35
+
36
+ # Select grader type and create structured output
37
+ if grader_type == "document_relevance":
38
+ structured_llm_grader = llm.with_structured_output(DocumentRelevance)
39
+ system = """You are a grader assessing relevance of a retrieved document to a user question.
40
+ If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant.
41
+ It does not need to be a stringent test. The goal is to filter out erroneous retrievals.
42
+ Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question."""
43
+
44
+ prompt = ChatPromptTemplate.from_messages([
45
+ ("system", system),
46
+ ("human", "Retrieved document: \n\n {document} \n\n User question: {question}"),
47
+ ])
48
+
49
+ elif grader_type == "hallucination":
50
+ structured_llm_grader = llm.with_structured_output(HallucinationCheck)
51
+ system = """You are a grader assessing whether an LLM generation is grounded in / supported by a set of retrieved facts.
52
+ Give a binary score 'yes' or 'no'. 'Yes' means that the answer is grounded in / supported by the set of facts."""
53
+
54
+ prompt = ChatPromptTemplate.from_messages([
55
+ ("system", system),
56
+ ("human", "Set of facts: \n\n {documents} \n\n LLM generation: {generation}"),
57
+ ])
58
+
59
+ elif grader_type == "answer_quality":
60
+ structured_llm_grader = llm.with_structured_output(AnswerQuality)
61
+ system = """You are a grader assessing whether an answer addresses / resolves a question.
62
+ Give a binary score 'yes' or 'no'. 'Yes' means that the answer resolves the question."""
63
+
64
+ prompt = ChatPromptTemplate.from_messages([
65
+ ("system", system),
66
+ ("human", "User question: \n\n {question} \n\n LLM generation: {generation}"),
67
+ ])
68
+
69
+ else:
70
+ raise ValueError(f"Unknown grader type: {grader_type}")
71
+
72
+ return prompt | structured_llm_grader
73
+
74
+ def grade_document_relevance(question: str, document: str, llm):
75
+ """
76
+ Grade the relevance of a document to a given question.
77
+
78
+ Args:
79
+ question (str): User's question
80
+ document (str): Retrieved document content
81
+
82
+ Returns:
83
+ str: Binary score ('yes' or 'no')
84
+ """
85
+ grader = create_llm_grader("document_relevance", llm)
86
+ result = grader.invoke({"question": question, "document": document})
87
+ return result.binary_score
88
+
89
+ def check_hallucination(documents: List[str], generation: str, llm):
90
+ """
91
+ Check if the generation is grounded in the provided documents.
92
+
93
+ Args:
94
+ documents (List[str]): List of source documents
95
+ generation (str): LLM generated answer
96
+
97
+ Returns:
98
+ str: Binary score ('yes' or 'no')
99
+ """
100
+ grader = create_llm_grader("hallucination", llm)
101
+ result = grader.invoke({"documents": documents, "generation": generation})
102
+ return result.binary_score
103
+
104
+ def grade_answer_quality(question: str, generation: str, llm):
105
+ """
106
+ Grade the quality of the answer in addressing the question.
107
+
108
+ Args:
109
+ question (str): User's original question
110
+ generation (str): LLM generated answer
111
+
112
+ Returns:
113
+ str: Binary score ('yes' or 'no')
114
+ """
115
+ grader = create_llm_grader("answer_quality", llm)
116
+ result = grader.invoke({"question": question, "generation": generation})
117
+ return result.binary_score
118
+
119
+ if __name__ == "__main__":
120
+ # Example usage
121
+ test_question = "What are the types of agent memory?"
122
+ test_document = "Agent memory can be classified into different types such as episodic, semantic, and working memory."
123
+ test_generation = "Agent memory includes episodic memory for storing experiences, semantic memory for general knowledge, and working memory for immediate processing."
124
+ llm = ChatOllama(model = "llama3.2", temperature = 0.1, num_predict = 256, top_p=0.5)
125
+
126
+ print("Document Relevance:", grade_document_relevance(test_question, test_document, llm))
127
+ print("Hallucination Check:", check_hallucination([test_document], test_generation, llm))
128
+ print("Answer Quality:", grade_answer_quality(test_question, test_generation, llm))
src/llm/query_rewriter.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_core.prompts.chat import ChatPromptTemplate
2
+ from langchain_ollama import ChatOllama
3
+ from langchain_core.output_parsers import StrOutputParser
4
+
5
+ def create_query_rewriter(llm):
6
+ """
7
+ Create a query rewriter to optimize retrieval.
8
+
9
+ Returns:
10
+ Callable: Query rewriter function
11
+ """
12
+
13
+ # Prompt for query rewriting
14
+ system = """You are a question re-writer that converts an input question to a better version that is optimized
15
+ for vectorstore retrieval. Look at the input and try to reason about the underlying semantic intent / meaning."""
16
+
17
+ re_write_prompt = ChatPromptTemplate.from_messages([
18
+ ("system", system),
19
+ ("human", "Here is the initial question: \n\n {question} \n Formulate an improved question."),
20
+ ])
21
+
22
+ # Create query rewriter chain
23
+ return re_write_prompt | llm | StrOutputParser()
24
+
25
+ def rewrite_query(question: str, llm):
26
+ """
27
+ Rewrite a given query to optimize retrieval.
28
+
29
+ Args:
30
+ question (str): Original user question
31
+
32
+ Returns:
33
+ str: Rewritten query
34
+ """
35
+ query_rewriter = create_query_rewriter(llm)
36
+ try:
37
+ rewritten_query = query_rewriter.invoke({"question": question})
38
+ return rewritten_query
39
+ except Exception as e:
40
+ print(f"Query rewriting error: {e}")
41
+ return question
42
+
43
+ if __name__ == "__main__":
44
+ # Example usage
45
+ test_queries = [
46
+ "Tell me about AI agents",
47
+ "What do we know about memory in AI systems?",
48
+ "Bears draft strategy"
49
+ ]
50
+ llm = ChatOllama(model = "llama3.2", temperature = 0.1, num_predict = 256, top_p=0.5)
51
+
52
+ for query in test_queries:
53
+ rewritten = rewrite_query(query, llm)
54
+ print(f"Original: {query}")
55
+ print(f"Rewritten: {rewritten}\n")
src/tools/__init__.py ADDED
File without changes
src/tools/web_search.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import asyncio
4
+ from typing import List, Dict, Optional
5
+
6
+ from langchain_community.tools import DuckDuckGoSearchResults
7
+ from crawl4ai import AsyncWebCrawler, CacheMode
8
+ from crawl4ai.content_filter_strategy import PruningContentFilter, BM25ContentFilter
9
+ from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
10
+ from dotenv import load_dotenv
11
+
12
+ load_dotenv()
13
+
14
+ class AdvancedWebCrawler:
15
+ def __init__(self,
16
+ max_search_results: int = 5,
17
+ word_count_threshold: int = 50,
18
+ content_filter_type: str = 'pruning',
19
+ filter_threshold: float = 0.48):
20
+ """
21
+ Initialize the Advanced Web Crawler
22
+
23
+ Args:
24
+ max_search_results (int): Maximum number of search results to process
25
+ word_count_threshold (int): Minimum word count for crawled content
26
+ content_filter_type (str): Type of content filter ('pruning' or 'bm25')
27
+ filter_threshold (float): Threshold for content filtering
28
+ """
29
+ self.max_search_results = max_search_results
30
+ self.word_count_threshold = word_count_threshold
31
+ self.content_filter_type = content_filter_type
32
+ self.filter_threshold = filter_threshold
33
+
34
+ def _create_web_search_tool(self):
35
+ """
36
+ Create a web search tool using DuckDuckGo
37
+
38
+ Returns:
39
+ DuckDuckGoSearchResults: Web search tool
40
+ """
41
+ return DuckDuckGoSearchResults(max_results=self.max_search_results, output_format="list")
42
+
43
+ def _create_content_filter(self, user_query: Optional[str] = None):
44
+ """
45
+ Create content filter based on specified type
46
+
47
+ Args:
48
+ user_query (Optional[str]): Query to use for BM25 filtering
49
+
50
+ Returns:
51
+ Content filter strategy
52
+ """
53
+ if self.content_filter_type == 'bm25' and user_query:
54
+ return BM25ContentFilter(
55
+ user_query=user_query,
56
+ bm25_threshold=self.filter_threshold
57
+ )
58
+ else:
59
+ return PruningContentFilter(
60
+ threshold=self.filter_threshold,
61
+ threshold_type="fixed",
62
+ min_word_threshold=self.word_count_threshold
63
+ )
64
+
65
+ async def crawl_urls(self, urls: List[str], user_query: Optional[str] = None):
66
+ """
67
+ Crawl multiple URLs with content filtering
68
+
69
+ Args:
70
+ urls (List[str]): List of URLs to crawl
71
+ user_query (Optional[str]): Query used for BM25 content filtering
72
+
73
+ Returns:
74
+ List of crawl results
75
+ """
76
+ async with AsyncWebCrawler(
77
+ browser_type="chromium",
78
+ headless=True,
79
+ verbose=True
80
+ ) as crawler:
81
+ # Create appropriate content filter
82
+ content_filter = self._create_content_filter(user_query)
83
+
84
+ # Run crawling for multiple URLs
85
+ results = await crawler.arun_many(
86
+ urls=urls,
87
+ word_count_threshold=self.word_count_threshold,
88
+ bypass_cache=True,
89
+ markdown_generator=DefaultMarkdownGenerator(
90
+ content_filter=content_filter
91
+ ),
92
+ cache_mode=CacheMode.DISABLED,
93
+ exclude_external_links=True,
94
+ remove_overlay_elements=True,
95
+ simulate_user=True,
96
+ magic=True
97
+ )
98
+
99
+ # Process and return crawl results
100
+ processed_results = []
101
+ for result in results:
102
+ crawl_result = {
103
+ "url": result.url,
104
+ "success": result.success,
105
+ "title": result.metadata.get('title', 'N/A'),
106
+ "content": result.markdown_v2.raw_markdown if result.success else result.error_message,
107
+ "word_count": len(result.markdown_v2.raw_markdown.split()) if result.success else 0,
108
+ "links": {
109
+ "internal": len(result.links.get('internal', [])),
110
+ "external": len(result.links.get('external', []))
111
+ },
112
+ "images": len(result.media.get('images', []))
113
+ }
114
+ processed_results.append(crawl_result)
115
+
116
+ return processed_results
117
+
118
+ async def search_and_crawl(self, query: str) -> List[Dict]:
119
+ """
120
+ Perform web search and crawl the results
121
+
122
+ Args:
123
+ query (str): Search query
124
+
125
+ Returns:
126
+ List of crawled content results
127
+ """
128
+ # Perform web search
129
+ search_tool = self._create_web_search_tool()
130
+ try:
131
+ search_results = search_tool.invoke({"query": query})
132
+
133
+ # Extract URLs from search results
134
+ urls = [result['link'] for result in search_results]
135
+ print(f"Found {len(urls)} URLs for query: {query}")
136
+
137
+ # Crawl URLs
138
+ crawl_results = await self.crawl_urls(urls, user_query=query)
139
+
140
+ return crawl_results
141
+
142
+ except Exception as e:
143
+ print(f"Web search and crawl error: {e}")
144
+ return []
145
+
146
+ def main():
147
+ # Example usage
148
+ crawler = AdvancedWebCrawler(
149
+ max_search_results=5,
150
+ word_count_threshold=50,
151
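+ # valid values: 'pruning' (query-agnostic) or 'bm25' (query-aware, scores pages against the search query)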
+ content_filter_type='bm25',
152
+ filter_threshold=0.48
153
+ )
154
+
155
+ test_queries = [
156
+ "Latest developments in AI agents",
157
+ "Today's weather forecast in Kolkata",
158
+ ]
159
+
160
+ for query in test_queries:
161
+ # Run search and crawl asynchronously
162
+ results = asyncio.run(crawler.search_and_crawl(query))
163
+
164
+ print(f"\nResults for query: {query}")
165
+ for result in results:
166
+ print(f"URL: {result['url']}")
167
+ print(f"Success: {result['success']}")
168
+ print(f"Title: {result['title']}")
169
+ print(f"Word Count: {result['word_count']}")
170
+ print(f"Content Preview: {result['content'][:500]}...\n")
171
+
172
+ if __name__ == "__main__":
173
+ main()
src/vectorstore/__init__.py ADDED
File without changes
src/vectorstore/pinecone_db.py ADDED
@@ -0,0 +1,282 @@
1
+ from data_processing.loader import MultiFormatDocumentLoader
2
+ from data_processing.chunker import SDPMChunker, BGEM3Embeddings
3
+
4
+ import pandas as pd
5
+ from typing import List, Dict, Any
6
+ from pinecone import Pinecone, ServerlessSpec
7
+ import time
8
+ from tqdm import tqdm
9
+ from dotenv import load_dotenv
10
+ import os
11
+
12
+
13
+ load_dotenv()
14
+
15
+ # API Keys
16
+ PINECONE_API_KEY = os.getenv('PINECONE_API_KEY')
17
+
18
+ embedding_model = BGEM3Embeddings(model_name="BAAI/bge-m3")
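+ # BGE-M3 produces 1024-dimensional dense vectors, which must match the Pinecone index dimension configured below.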
19
+
20
+
21
+ def load_documents(file_paths: List[str], output_path='./data/output.md'):
22
+ """
23
+ Load documents from multiple sources and combine them into a single markdown file
24
+ """
25
+ loader = MultiFormatDocumentLoader(
26
+ file_paths=file_paths,
27
+ enable_ocr=False,
28
+ enable_tables=True
29
+ )
30
+
31
+ # Append all documents to the markdown file
32
+ with open(output_path, 'w') as f:
33
+ for doc in loader.lazy_load():
34
+ # Add metadata as YAML frontmatter
35
+ f.write('---\n')
36
+ for key, value in doc.metadata.items():
37
+ f.write(f'{key}: {value}\n')
38
+ f.write('---\n\n')
39
+ f.write(doc.page_content)
40
+ f.write('\n\n')
41
+
42
+ return output_path
43
+
44
+ def process_chunks(markdown_path: str, chunk_size: int = 256,
45
+ threshold: float = 0.7, skip_window: int = 2):
46
+ """
47
+ Process the markdown file into chunks and prepare for vector storage
48
+ """
49
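+ # Assuming a chonkie-style SDPM (semantic double-pass merging) chunker: adjacent chunks - and, via skip_window, chunks a few positions apart - are merged when their embedding similarity exceeds the threshold.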
+ chunker = SDPMChunker(
50
+ embedding_model=embedding_model,
51
+ chunk_size=chunk_size,
52
+ threshold=threshold,
53
+ skip_window=skip_window
54
+ )
55
+
56
+ # Read the markdown file
57
+ with open(markdown_path, 'r') as file:
58
+ text = file.read()
59
+
60
+ # Generate chunks
61
+ chunks = chunker.chunk(text)
62
+
63
+ # Prepare data for Parquet
64
+ processed_chunks = []
65
+ for chunk in chunks:
66
+
67
+ processed_chunks.append({
68
+ 'text': chunk.text,
69
+ 'token_count': chunk.token_count,
70
+ 'start_index': chunk.start_index,
71
+ 'end_index': chunk.end_index,
72
+ 'num_sentences': len(chunk.sentences),
73
+ })
74
+
75
+ return processed_chunks
76
+
77
+ def save_to_parquet(chunks: List[Dict[str, Any]], output_path='./data/chunks.parquet'):
78
+ """
79
+ Save processed chunks to a Parquet file
80
+ """
81
+ df = pd.DataFrame(chunks)
82
+ print(f"Saving to Parquet: {output_path}")
83
+ df.to_parquet(output_path)
84
+ print(f"Saved to Parquet: {output_path}")
85
+ return output_path
86
+
87
+
88
+ class PineconeRetriever:
89
+ def __init__(
90
+ self,
91
+ pinecone_client: Pinecone,
92
+ index_name: str,
93
+ namespace: str,
94
+ embedding_generator: BGEM3Embeddings
95
+ ):
96
+ """Initialize the retriever with Pinecone client and embedding generator.
97
+
98
+ Args:
99
+ pinecone_client: Initialized Pinecone client
100
+ index_name: Name of the Pinecone index
101
+ namespace: Namespace in the index
102
+ embedding_generator: BGEM3Embeddings instance
103
+ """
104
+ self.pinecone = pinecone_client
105
+ self.index = self.pinecone.Index(index_name)
106
+ self.namespace = namespace
107
+ self.embedding_generator = embedding_generator
108
+
109
+ def invoke(self, question: str, top_k: int = 5):
110
+ """Retrieve similar documents for a question.
111
+
112
+ Args:
113
+ question: Query string
114
+ top_k: Number of results to return
115
+
116
+ Returns:
117
+ List of dictionaries containing retrieved documents
118
+ """
119
+ # Generate embedding for the question
120
+ question_embedding = self.embedding_generator.embed(question)
121
+ question_embedding = question_embedding.tolist()
122
+ # Query Pinecone
123
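+ # include_metadata=True returns the stored chunk text, so results can be turned into page_content without a second lookup.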
+ results = self.index.query(
124
+ namespace=self.namespace,
125
+ vector=question_embedding,
126
+ top_k=top_k,
127
+ include_values=False,
128
+ include_metadata=True
129
+ )
130
+
131
+ # Format results
132
+ retrieved_docs = [
133
+ {"page_content": match.metadata["text"], "score": match.score}
134
+ for match in results.matches
135
+ ]
136
+
137
+ return retrieved_docs
138
+
139
+ def ingest_data(
140
+ pc,
141
+ parquet_path: str,
142
+ text_column: str,
143
+ pinecone_client: Pinecone,
144
+ index_name= "vector-index",
145
+ namespace= "rag",
146
+ batch_size: int = 100
147
+ ):
148
+ """Ingest data from a Parquet file into Pinecone.
149
+
150
+ Args:
151
+ parquet_path: Path to the Parquet file
152
+ text_column: Name of the column containing text data
153
+ pinecone_client: Initialized Pinecone client
154
+ index_name: Name of the Pinecone index
155
+ namespace: Namespace in the index
156
+ batch_size: Batch size for processing
157
+ """
158
+ # Read Parquet file
159
+ print(f"Reading Parquet file: {parquet_path}")
160
+ df = pd.read_parquet(parquet_path)
161
+ print(f"Total records: {len(df)}")
162
+ # Create or get index
163
+ if not pinecone_client.has_index(index_name):
164
+ pinecone_client.create_index(
165
+ name=index_name,
166
+ dimension=1024, # BGE-M3 dimension
167
+ metric="cosine",
168
+ spec=ServerlessSpec(
169
+ cloud='aws',
170
+ region='us-east-1'
171
+ )
172
+ )
173
+
174
+ # Wait for index to be ready
175
+ while not pinecone_client.describe_index(index_name).status['ready']:
176
+ time.sleep(1)
177
+
178
+ index = pinecone_client.Index(index_name)
179
+
180
+ # Process in batches
181
+ for i in tqdm(range(0, len(df), batch_size)):
182
+ batch_df = df.iloc[i:i+batch_size]
183
+
184
+ # Generate embeddings for batch
185
+ texts = batch_df[text_column].tolist()
186
+ embeddings = embedding_model.embed_batch(texts)
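+ # embed_batch is assumed to return one 1024-dim vector per text; convert numpy arrays to plain Python lists if needed before upserting.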
187
+ print(f"embeddings for batch: {i}")
188
+ # Prepare records for upsert
189
+ records = []
190
+ for idx, (_, row) in enumerate(batch_df.iterrows()):
191
+ records.append({
192
+ "id": str(row.name), # Using DataFrame index as ID
193
+ "values": embeddings[idx],
194
+ "metadata": {"text": row[text_column]}
195
+ })
196
+
197
+ # Upsert to Pinecone
198
+ index.upsert(vectors=records, namespace=namespace)
199
+
200
+ # Small delay to handle rate limits
201
+ time.sleep(0.5)
202
+
203
+ def get_retriever(
204
+ pinecone_client: Pinecone,
205
+ index_name= "vector-index",
206
+ namespace= "rag"
207
+ ):
208
+ """Create and return a PineconeRetriever instance.
209
+
210
+ Args:
211
+ pinecone_client: Initialized Pinecone client
212
+ index_name: Name of the Pinecone index
213
+ namespace: Namespace in the index
214
+
215
+ Returns:
216
+ Configured PineconeRetriever instance
217
+ """
218
+ return PineconeRetriever(
219
+ pinecone_client=pinecone_client,
220
+ index_name=index_name,
221
+ namespace=namespace,
222
+ embedding_generator=embedding_model
223
+ )
224
+
225
+ def main():
226
+ # Initialize Pinecone client
227
+ pc = Pinecone(api_key=PINECONE_API_KEY)
228
+
229
+ # Define input files
230
+ file_paths=[
231
+ # './data/2404.19756v1.pdf',
232
+ # './data/OD429347375590223100.pdf',
233
+ # './data/Project Report Format.docx',
234
+ './data/UNIT 2 GENDER BASED VIOLENCE.pptx'
235
+ ]
236
+
237
+ # Process pipeline
238
+ try:
239
+ # Step 1: Load and combine documents
240
+ # print("Loading documents...")
241
+ # markdown_path = load_documents(file_paths)
242
+
243
+ # # Step 2: Process into chunks with embeddings
244
+ # print("Processing chunks...")
245
+ # chunks = process_chunks(markdown_path)
246
+
247
+ # # Step 3: Save to Parquet
248
+ # print("Saving to Parquet...")
249
+ # parquet_path = save_to_parquet(chunks)
250
+
251
+ # # Step 4: Ingest into Pinecone
252
+ # print("Ingesting into Pinecone...")
253
+ # ingest_data(
254
+ # pc,
255
+ # parquet_path=parquet_path,
256
+ # text_column="text",
257
+ # pinecone_client=pc,
258
+ # )
259
+
260
+ # Step 5: Test retrieval
261
+ print("\nTesting retrieval...")
262
+ retriever = get_retriever(
263
+ pinecone_client=pc,
264
+ index_name="vector-index",
265
+ namespace="rag"
266
+ )
267
+
268
+ results = retriever.invoke(
269
+ question="describe the gender based violence",
270
+ top_k=5
271
+ )
272
+
273
+ for i, doc in enumerate(results, 1):
274
+ print(f"\nResult {i}:")
275
+ print(f"Content: {doc['page_content']}...")
276
+ print(f"Score: {doc['score']}")
277
+
278
+ except Exception as e:
279
+ print(f"Error in pipeline: {str(e)}")
280
+
281
+ if __name__ == "__main__":
282
+ main()