anindya-hf-2002 committed
Commit db17bc0 · verified · 1 Parent(s): d3b52c3

Upload 19 files

app.py ADDED
@@ -0,0 +1,432 @@
import streamlit as st
import asyncio
from src.vectorstore.pinecone_db import ingest_data, get_retriever, load_documents, process_chunks, save_to_parquet
from src.agents.research_agent import create_industry_research_workflow
from src.agents.workflow import run_adaptive_rag
from pinecone import Pinecone
from langchain_openai import ChatOpenAI
from langchain_ollama import ChatOllama
from langgraph.pregel import GraphRecursionError
import tempfile
import os
import time
from pathlib import Path

# Page configuration
st.set_page_config(
    page_title="Research & RAG Assistant",
    page_icon="🤖",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Custom CSS for better UI
st.markdown("""
    <style>
    .stTabs [data-baseweb="tab-list"] {
        gap: 24px;
    }
    .stTabs [data-baseweb="tab"] {
        padding: 8px 16px;
    }
    .config-section {
        background-color: #f0f2f6;
        border-radius: 10px;
        padding: 20px;
        margin: 10px 0;
    }
    .chat-container {
        border: 1px solid #e0e0e0;
        border-radius: 10px;
        padding: 20px;
        margin-top: 20px;
    }
    .stButton>button {
        width: 100%;
    }
    </style>
""", unsafe_allow_html=True)

# Initialize session state
if "messages" not in st.session_state:
    st.session_state.messages = []
if "documents_processed" not in st.session_state:
    st.session_state.documents_processed = False
if "retriever" not in st.session_state:
    st.session_state.retriever = None
if "pinecone_client" not in st.session_state:
    st.session_state.pinecone_client = None
if "research_config_saved" not in st.session_state:
    st.session_state.research_config_saved = False
if "rag_config_saved" not in st.session_state:
    st.session_state.rag_config_saved = False

def save_research_config(api_keys):
    """Save research configuration."""
    st.session_state.research_openai_key = api_keys['openai']
    st.session_state.research_tavily_key = api_keys['tavily']
    st.session_state.research_config_saved = True


def research_config_section():
    """Configuration section for the Company Research tab."""
    st.markdown("### ⚙️ Configuration")

    with st.expander("API Configuration", expanded=not st.session_state.research_config_saved):
        col1, col2 = st.columns(2)
        with col1:
            openai_key = st.text_input(
                "OpenAI API Key",
                type="password",
                value=st.session_state.get('research_openai_key', ''),
                key="research_openai_input"
            )
        with col2:
            tavily_key = st.text_input(
                "Tavily API Key",
                type="password",
                value=st.session_state.get('research_tavily_key', ''),
                key="research_tavily_input"
            )

        if st.button("Save Research Configuration", key="save_research_config"):
            if openai_key and tavily_key:
                save_research_config({
                    'openai': openai_key,
                    'tavily': tavily_key
                })
                if not os.environ.get("TAVILY_API_KEY"):
                    os.environ["TAVILY_API_KEY"] = tavily_key
                st.success("✅ Research configuration saved!")
            else:
                st.error("Please provide both API keys.")


async def run_industry_research(company: str, industry: str, llm):
    """Run the industry research workflow asynchronously."""
    workflow = create_industry_research_workflow(llm)

    output = await workflow.ainvoke(input={
        "company": company,
        "industry": industry
    }, config={"recursion_limit": 5})

    return output['final_report']


def research_input_section():
    """Input section for the Company Research tab."""
    st.markdown("### 🔍 Research Parameters")

    col1, col2 = st.columns(2)
    with col1:
        company_name = st.text_input(
            "Company Name",
            placeholder="e.g., Tesla",
            help="Enter the name of the company to research"
        )
    with col2:
        industry_type = st.text_input(
            "Industry Type",
            placeholder="e.g., Automotive",
            help="Enter the industry sector"
        )

    if st.button("Generate Research Report",
                 disabled=not st.session_state.research_config_saved,
                 type="primary"):
        if company_name and industry_type:
            with st.spinner("🔄 Generating comprehensive research report..."):
                try:
                    # Initialize LLM and run research
                    llm = ChatOpenAI(
                        model="gpt-3.5-turbo-0125",
                        temperature=0.1,
                        api_key=st.session_state.research_openai_key
                    )

                    report_path = asyncio.run(run_industry_research(
                        company=company_name,
                        industry=industry_type,
                        llm=llm
                    ))

                    # Guard against a None report path before checking the filesystem
                    if report_path and os.path.exists(report_path):
                        with open(report_path, "rb") as file:
                            st.download_button(
                                "📥 Download Research Report",
                                data=file,
                                file_name=f"{company_name}_research_report.pdf",
                                mime="application/pdf"
                            )
                    else:
                        st.error("Report generation failed.")
                except Exception as e:
                    st.error(f"Error during report generation: {str(e)}")
        else:
            st.warning("Please fill in both company name and industry type.")

def initialize_pinecone(api_key):
    """Initialize the Pinecone client with an API key."""
    try:
        return Pinecone(api_key=api_key)
    except Exception as e:
        st.error(f"Error initializing Pinecone: {str(e)}")
        return None

def initialize_llm(llm_option, openai_api_key=None):
    """Initialize the LLM based on user selection."""
    if llm_option == "OpenAI":
        if not openai_api_key:
            st.sidebar.warning("Please enter OpenAI API key.")
            return None
        return ChatOpenAI(api_key=openai_api_key, model="gpt-3.5-turbo")
    # Other options are not configured here; callers receive None.

def clear_pinecone_index(pc, index_name="vector-index"):
    """Clear the Pinecone index."""
    try:
        pc.delete_index(index_name)
        st.session_state.documents_processed = False
        st.session_state.retriever = None
        st.success("Database cleared successfully!")
    except Exception as e:
        st.error(f"Error clearing database: {str(e)}")

def process_documents(uploaded_files, pc):
    """Process uploaded documents and store them in Pinecone."""
    if not uploaded_files:
        st.warning("Please upload at least one document.")
        return False

    with st.spinner("Processing documents..."):
        temp_dir = tempfile.mkdtemp()
        file_paths = []
        markdown_path = Path(temp_dir) / "combined.md"
        parquet_path = Path(temp_dir) / "documents.parquet"

        for uploaded_file in uploaded_files:
            file_path = Path(temp_dir) / uploaded_file.name
            with open(file_path, "wb") as f:
                f.write(uploaded_file.getvalue())
            file_paths.append(str(file_path))

        try:
            markdown_path = load_documents(file_paths, output_path=markdown_path)
            chunks = process_chunks(markdown_path, chunk_size=256, threshold=0.6)
            print(f"Processed chunks: {chunks}")
            parquet_path = save_to_parquet(chunks, parquet_path)

            ingest_data(
                pc=pc,
                parquet_path=parquet_path,
                text_column="text",
                pinecone_client=pc
            )

            st.session_state.retriever = get_retriever(pc)
            st.session_state.documents_processed = True

            return True

        except Exception as e:
            st.error(f"Error processing documents: {str(e)}")
            return False
        finally:
            # Clean up temporary files
            for file_path in file_paths:
                try:
                    os.remove(file_path)
                except OSError:
                    pass
            try:
                os.rmdir(temp_dir)
            except OSError:
                pass

def run_rag_with_streaming(retriever, question, llm, enable_web_search=False):
    """Run the RAG workflow and yield the answer word by word."""
    try:
        response = run_adaptive_rag(
            retriever=retriever,
            question=question,
            llm=llm,
            top_k=5,
            enable_websearch=enable_web_search
        )

        for word in response.split():
            yield word + " "
            time.sleep(0.03)

    except GraphRecursionError:
        response = "I apologize, but I cannot find a sufficient answer to your question in the provided documents. Please try rephrasing your question or ask something else about the content of the documents."
        for word in response.split():
            yield word + " "
            time.sleep(0.03)

    except Exception as e:
        yield f"I encountered an error while processing your question: {str(e)}"


def document_upload_section():
    """Document upload section for the RAG tab."""
    st.markdown("### 📄 Document Management")

    if not st.session_state.documents_processed:
        uploaded_files = st.file_uploader(
            "Upload your documents",
            accept_multiple_files=True,
            type=["pdf", "docx", "txt", "pptx", "md"],
            help="Supports multiple file uploads"
        )

        col1, col2 = st.columns([3, 1])
        with col1:
            if uploaded_files:
                st.info(f"📁 {len(uploaded_files)} files selected")
        with col2:
            if st.button(
                "Process Documents",
                disabled=not (uploaded_files and st.session_state.rag_config_saved)
            ):
                if process_documents(uploaded_files, st.session_state.pinecone_client):
                    st.success("✅ Documents processed successfully!")
    else:
        st.success("✅ Documents are loaded and ready for querying!")
        if st.button("Upload New Documents"):
            st.session_state.documents_processed = False
            st.rerun()

def save_rag_config(config):
    """Save RAG configuration."""
    st.session_state.rag_pinecone_key = config['pinecone']
    st.session_state.rag_openai_key = config['openai']
    st.session_state.rag_config_saved = True

def rag_config_section():
    """Configuration section for the RAG tab."""
    st.markdown("### ⚙️ Configuration")

    with st.expander("API Configuration", expanded=not st.session_state.rag_config_saved):
        col1, col2 = st.columns(2)
        with col1:
            pinecone_key = st.text_input(
                "Pinecone API Key",
                type="password",
                value=st.session_state.get('rag_pinecone_key', ''),
                key="rag_pinecone_input"
            )
        with col2:
            openai_key = st.text_input(
                "OpenAI API Key",
                type="password",
                value=st.session_state.get('rag_openai_key', ''),
                key="rag_openai_input"
            )

        if st.button("Save RAG Configuration", key="save_rag_config"):
            if pinecone_key and openai_key:
                save_rag_config({
                    'pinecone': pinecone_key,
                    'openai': openai_key
                })
                # Initialize Pinecone client
                st.session_state.pinecone_client = initialize_pinecone(pinecone_key)
                st.success("✅ RAG configuration saved!")
            else:
                st.error("Please provide both API keys.")

def chat_interface():
    """Chat interface with streaming responses and a web search toggle."""
    st.markdown("### 💬 Chat Interface")

    # Web search toggle in the chat interface
    col1, col2 = st.columns([3, 1])
    with col2:
        web_search = st.checkbox(
            "🌐 Enable Web Search",
            value=st.session_state.get('use_web_search', False),
            help="Toggle web search for additional context",
            key="web_search_toggle"
        )
        st.session_state.use_web_search = web_search

    # Chat container with messages
    chat_container = st.container()
    with chat_container:
        for message in st.session_state.messages:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])

    # Chat input
    if prompt := st.chat_input(
        "Ask a question about your documents...",
        disabled=not st.session_state.documents_processed,
        key="chat_input"
    ):
        # User message
        with st.chat_message("user"):
            if st.session_state.use_web_search:
                st.markdown(f"{prompt} 🌐")
            else:
                st.markdown(prompt)
        st.session_state.messages.append({"role": "user", "content": prompt})

        # Assistant response
        with st.chat_message("assistant"):
            response_container = st.empty()
            full_response = ""

            try:
                with st.spinner("Thinking..."):
                    llm = ChatOpenAI(
                        api_key=st.session_state.rag_openai_key,
                        model="gpt-3.5-turbo"
                    )

                    for chunk in run_rag_with_streaming(
                        retriever=st.session_state.retriever,
                        question=prompt,
                        llm=llm,
                        enable_web_search=st.session_state.use_web_search
                    ):
                        full_response += chunk
                        response_container.markdown(full_response + "▌")

                    response_container.markdown(full_response)
                    st.session_state.messages.append(
                        {"role": "assistant", "content": full_response}
                    )

            except Exception as e:
                st.error(f"Error: {str(e)}")

def main():
    """Main application layout."""
    st.title("🤖 Research & RAG Assistant")

    tab1, tab2 = st.tabs(["🔍 Company Research", "💬 Document Q&A"])

    with tab1:
        research_config_section()
        if st.session_state.research_config_saved:
            st.divider()
            research_input_section()
        else:
            st.info("👆 Please configure your API keys above to get started.")

    with tab2:
        rag_config_section()
        if st.session_state.rag_config_saved:
            st.divider()
            document_upload_section()
            if st.session_state.documents_processed:
                st.divider()
                chat_interface()
        else:
            st.info("👆 Please configure your API keys above to get started.")

if __name__ == "__main__":
    main()
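Once the dependencies from the `requirements.txt` added in this commit are installed (`pip install -r requirements.txt`), the app can be launched locally with `streamlit run app.py`.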
requirements.txt ADDED
@@ -0,0 +1,24 @@
langchain-community
tiktoken
langchain-openai
langchainhub
chromadb
langchain
langgraph
duckduckgo-search
langchain-groq
langchain-huggingface
sentence_transformers
tavily-python
langchain-ollama
ollama
crawl4ai
docling
easyocr
FlagEmbedding
chonkie[semantic]
pinecone
streamlit
markdown2
xhtml2pdf
PyPDF2
src/__init__.py ADDED
File without changes
src/agents/__init__.py ADDED
File without changes
src/agents/research_agent.py ADDED
@@ -0,0 +1,476 @@
from langgraph.graph import END, StateGraph, START
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_openai import ChatOpenAI

import re
import asyncio
from typing import TypedDict, List, Optional, Dict
from src.tools.deep_crawler import DeepWebCrawler, ResourceCollectionAgent

class ResearchGraphState(TypedDict):
    company: str
    industry: str
    research_results: Optional[dict]
    use_cases: Optional[str]
    search_queries: Optional[Dict[str, List[str]]]
    resources: Optional[List[dict]]
    final_report: Optional[str]


def clean_text(text):
    """
    Clean the given text by:
    1. Removing all hyperlinks.
    2. Removing unnecessary parentheses and square brackets.

    Args:
        text (str): The input text to be cleaned.

    Returns:
        str: The cleaned text with hyperlinks, parentheses, and square brackets removed.
    """
    # Regular expression pattern for matching URLs
    url_pattern = r'https?://\S+|www\.\S+'
    # Remove hyperlinks
    text_without_links = re.sub(url_pattern, '', text)

    # Regular expression pattern for matching parentheses and square brackets
    brackets_pattern = r'[\[\]\(\)]'
    # Remove unnecessary brackets
    cleaned_text = re.sub(brackets_pattern, '', text_without_links)

    return cleaned_text.strip()


def create_industry_research_workflow(llm):
    async def industry_research(state: ResearchGraphState):
        """Research the industry and company using DeepWebCrawler."""
        company = state['company']
        industry = state['industry']

        queries = [
            f"{company} company profile services",
        ]

        crawler = DeepWebCrawler(
            max_search_results=3,
            max_external_links=1,
            word_count_threshold=100,
            content_filter_type='bm25',
            filter_threshold=0.48
        )

        all_results = []
        for query in queries:
            results = await crawler.search_and_crawl(query)
            all_results.extend(results)
        print(all_results)
        combined_content = "\n\n".join([
            f"Title: {clean_text(r['title'])} \n{clean_text(r['content'])}"
            for r in all_results if r['success']
        ])
        print("Combined Content: ", combined_content)
        prompt = PromptTemplate.from_template(
            """Analyze this research about {company} in the {industry} industry:
            {content}

            Provide a comprehensive overview including:
            1. Company Overview
            2. Market Segments
            3. Products and Services
            4. Strategic Focus Areas
            5. Industry Trends
            6. Competitive Position

            Format the analysis in clear sections with headers."""
        )

        chain = prompt | llm | StrOutputParser()
        analysis = chain.invoke({
            "company": company,
            "industry": industry,
            "content": combined_content
        })
        print("Analysis: ", analysis)
        return {
            "research_results": {
                "analysis": analysis,
                "raw_content": combined_content
            }
        }

    def generate_use_cases_and_queries(state: ResearchGraphState):
        """Generate AI/ML use cases and extract relevant search queries."""
        research_data = state['research_results']
        company = state['company']
        industry = state['industry']

        # First, generate use cases
        use_case_prompt = PromptTemplate.from_template(
            """Based on this research:

            Analysis: {analysis}
            Raw Research: {raw_content}

            Generate innovative use cases where {company} in the {industry} industry can leverage
            Generative AI and Large Language Models for:

            1. Internal Process Improvements
            2. Customer Experience Enhancement
            3. Product/Service Innovation
            4. Data Analytics and Decision Making

            For each use case, provide:
            - Clear description
            - Expected benefits
            - Implementation considerations"""
        )

        chain = use_case_prompt | llm | StrOutputParser()
        use_cases = chain.invoke({
            "company": company,
            "industry": industry,
            "analysis": research_data['analysis'],
            "raw_content": research_data['raw_content']
        })

        # Then, extract relevant search queries
        query_extraction_prompt = PromptTemplate.from_template(
            """Based on these AI/ML use cases for {company}:

            {use_cases}

            Extract two specific search queries for finding relevant datasets and implementations.

            Provide your response in this exact format:
            DATASET QUERIES:
            - query1
            - query2

            IMPLEMENTATION QUERIES:
            - query1
            - query2

            Make queries specific and technical. Include ML model types, data types, and specific AI techniques."""
        )

        chain = query_extraction_prompt | llm | StrOutputParser()
        queries_text = chain.invoke({
            "company": company,
            "use_cases": use_cases
        })

        # Parse the text response into a structured format
        def parse_queries(text):
            dataset_queries = []
            implementation_queries = []
            current_section = None

            for line in text.split('\n'):
                line = line.strip()
                if line == "DATASET QUERIES:":
                    current_section = "dataset"
                elif line == "IMPLEMENTATION QUERIES:":
                    current_section = "implementation"
                elif line.startswith("- "):
                    query = line[2:].strip()
                    if current_section == "dataset":
                        dataset_queries.append(query)
                    elif current_section == "implementation":
                        implementation_queries.append(query)

            # Fall back to generic queries if parsing fails
            return {
                "dataset_queries": dataset_queries or ["machine learning datasets business", "AI training data industry"],
                "implementation_queries": implementation_queries or ["AI tools business automation", "machine learning implementation"]
            }

        search_queries = parse_queries(queries_text)
        print("Search_queries: ", search_queries)
        return {
            "use_cases": use_cases,
            "search_queries": search_queries
        }

    async def collect_targeted_resources(state: ResearchGraphState):
        """Find relevant datasets and resources using the extracted queries."""
        search_queries = state['search_queries']
        resource_agent = ResourceCollectionAgent(max_results_per_query=5)

        # Collect resources using targeted queries
        all_resources = {
            "datasets": [],
            "implementations": []
        }

        # Search for datasets
        for query in search_queries['dataset_queries']:
            # Platform-specific query modifiers; note that collect_resources()
            # takes no query argument here, so these are built but not passed.
            kaggle_query = f"site:kaggle.com/datasets {query}"
            huggingface_query = f"site:huggingface.co/datasets {query}"

            resources = await resource_agent.collect_resources()

            # Process and categorize results
            if resources.get("kaggle_datasets"):
                all_resources["datasets"].extend([{
                    "title": item["title"],
                    "url": item["url"],
                    "description": item["snippet"],
                    "platform": "Kaggle",
                    "query": query
                } for item in resources["kaggle_datasets"]])

            if resources.get("huggingface_datasets"):
                all_resources["datasets"].extend([{
                    "title": item["title"],
                    "url": item["url"],
                    "description": item["snippet"],
                    "platform": "HuggingFace",
                    "query": query
                } for item in resources["huggingface_datasets"]])

        # Search for implementations
        for query in search_queries['implementation_queries']:
            github_query = f"site:github.com {query}"

            resources = await resource_agent.collect_resources()

            if resources.get("github_repositories"):
                all_resources["implementations"].extend([{
                    "title": item["title"],
                    "url": item["url"],
                    "description": item["snippet"],
                    "platform": "GitHub",
                    "query": query
                } for item in resources["github_repositories"]])
        print("Resources: ", all_resources)
        return {"resources": all_resources}

    def generate_pdf_report(state: ResearchGraphState):
        """Generate the final PDF report with all collected information."""
        research_data = state['research_results']
        use_cases = state['use_cases']
        resources = state['resources']
        company = state['company']
        industry = state['industry']

        # Format resources to append after the generated report body
        datasets_section = "\n## Available Datasets\n"
        if resources.get('datasets'):
            for dataset in resources['datasets']:
                datasets_section += f"  - {dataset['platform']}: {dataset['url']}\n"

        implementations_section = "\n## Implementation Resources\n"
        if resources.get('implementations'):
            for impl in resources['implementations']:
                implementations_section += f"  - {impl['platform']}: {impl['url']}\n"

        prompt = PromptTemplate.from_template(
            """
            # GenAI & ML Implementation Proposal for {company}

            ## Executive Summary
            - **Current Position in the {industry} Industry**:
            - **Key Opportunities for AI/ML Implementation**:
            - **Expected Business Impact and ROI**:
            - **Implementation Timeline Overview**:

            ## Industry and Company Analysis
            {analysis}

            ## Strategic AI/ML Implementation Opportunities

            Based on the analysis, here are the key opportunities for AI/ML implementation:

            {use_cases}

            Format the report in Markdown for clear sections, headings, and bullet points. Ensure professional formatting with structured subsections.
            """
        )

        chain = prompt | llm | StrOutputParser()
        markdown_content = chain.invoke({
            "company": company,
            "industry": industry,
            "analysis": research_data['analysis'],
            "use_cases": use_cases,
        })

        # Strip a wrapping ```markdown fence if the model added one
        if markdown_content.startswith("```markdown") and markdown_content.endswith("```"):
            markdown_content = markdown_content[len("```markdown"):].rstrip("`").strip()

        markdown_content += "\n\n" + datasets_section + "\n\n" + implementations_section

        # Convert markdown to PDF
        import tempfile
        import os
        import markdown2
        from xhtml2pdf import pisa

        # Create a temporary directory and full path for the PDF
        temp_dir = tempfile.mkdtemp()
        pdf_filename = f"{company.replace(' ', '_')}_research_report.pdf"
        pdf_path = os.path.join(temp_dir, pdf_filename)

        html_content = markdown2.markdown(markdown_content, extras=['tables', 'break-on-newline'])
        # HTML template with enhanced styles
        html_template = f"""
        <!DOCTYPE html>
        <html>
        <head>
        <meta charset="UTF-8">
        <style>
            @page {{
                size: A4;
                margin: 2.5cm;
                @frame footer {{
                    -pdf-frame-content: footerContent;
                    bottom: 1cm;
                    margin-left: 1cm;
                    margin-right: 1cm;
                    height: 1cm;
                }}
            }}
            body {{
                font-family: Helvetica, Arial, sans-serif;
                font-size: 11pt;
                line-height: 1.6;
                color: #2c3e50;
            }}
            h1 {{
                font-size: 24pt;
                color: #1a237e;
                text-align: center;
                margin-bottom: 2cm;
                font-weight: bold;
            }}
            h2 {{
                font-size: 18pt;
                color: #283593;
                margin-top: 1.5cm;
                border-bottom: 2px solid #3949ab;
                padding-bottom: 0.3cm;
            }}
            h3 {{
                font-size: 14pt;
                color: #3949ab;
                margin-top: 1cm;
            }}
            h4 {{
                font-size: 12pt;
                color: #5c6bc0;
                margin-top: 0.8cm;
            }}
            p {{
                text-align: justify;
                margin-bottom: 0.5cm;
            }}
            ul {{
                margin-left: 0;
                padding-left: 1cm;
                margin-bottom: 0.5cm;
            }}
            li {{
                margin-bottom: 0.3cm;
            }}
            a {{
                color: #3f51b5;
                text-decoration: none;
            }}
            strong {{
                color: #283593;
            }}
            .use-case {{
                background-color: #f5f7fa;
                padding: 1cm;
                margin: 0.5cm 0;
                border-left: 4px solid #3949ab;
            }}
            .benefit {{
                margin-left: 1cm;
                color: #34495e;
            }}
        </style>
        </head>
        <body>
            {html_content}
            <div id="footerContent" style="text-align: center; font-size: 8pt; color: #7f8c8d;">
                Page <pdf:pagenumber> of <pdf:pagecount>
            </div>
        </body>
        </html>
        """

        # Convert HTML to PDF with proper error handling
        try:
            with open(pdf_path, "w+b") as pdf_file:
                result = pisa.CreatePDF(
                    html_template,
                    dest=pdf_file
                )
                if result.err:
                    print(f"Error generating PDF: {result.err}")
                    return {"final_report": None}

            # Verify the file exists
            if os.path.exists(pdf_path):
                print(f"PDF successfully generated at: {pdf_path}")
                return {"final_report": pdf_path}
            else:
                print("PDF file was not created successfully")
                return {"final_report": None}

        except Exception as e:
            print(f"Exception during PDF generation: {str(e)}")
            return {"final_report": None}

    # Create workflow
    workflow = StateGraph(ResearchGraphState)

    # Add nodes
    workflow.add_node("industry_research", industry_research)
    workflow.add_node("use_cases_gen", generate_use_cases_and_queries)
    workflow.add_node("resources_gen", collect_targeted_resources)
    workflow.add_node("report", generate_pdf_report)

    # Define edges
    workflow.add_edge(START, "industry_research")
    workflow.add_edge("industry_research", "use_cases_gen")
    workflow.add_edge("use_cases_gen", "resources_gen")
    workflow.add_edge("resources_gen", "report")
    workflow.add_edge("report", END)

    return workflow.compile()

async def run_industry_research(company: str, industry: str, llm):
    """Run the industry research workflow asynchronously."""
    workflow = create_industry_research_workflow(llm)

    output = await workflow.ainvoke(input={
        "company": company,
        "industry": industry
    }, config={"recursion_limit": 5})

    return output['final_report']

# Example usage
if __name__ == "__main__":
    async def main():
        # Initialize LLM
        llm = ChatOpenAI(
            model="gpt-3.5-turbo-0125",
            temperature=0.3,
            timeout=None,
            max_retries=2,
        )

        # Run the workflow
        report_path = await run_industry_research(
            company="Adani Defence & Aerospace",
            industry="Defense Engineering and Construction",
            llm=llm
        )
        print(f"Report generated at: {report_path}")

    asyncio.run(main())
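For reference, the nested parse_queries helper above assumes the LLM echoes the exact section format requested by the extraction prompt. A minimal sketch of that contract, with hypothetical queries (not from this commit):

queries_text = """DATASET QUERIES:
- retail demand forecasting time series dataset
- customer support conversation corpus

IMPLEMENTATION QUERIES:
- retrieval augmented generation pipeline LLM
- time series transformer forecasting implementation"""

# parse_queries(queries_text) returns:
# {
#     "dataset_queries": [
#         "retail demand forecasting time series dataset",
#         "customer support conversation corpus",
#     ],
#     "implementation_queries": [
#         "retrieval augmented generation pipeline LLM",
#         "time series transformer forecasting implementation",
#     ],
# }
# If neither section header is found, the hard-coded fallback queries are used instead.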
src/agents/router.py ADDED
@@ -0,0 +1,59 @@
from langchain_core.prompts import ChatPromptTemplate
from langchain_ollama import ChatOllama
from pydantic import BaseModel, Field
from typing import Literal

class RouteQuery(BaseModel):
    """Route a user query to the most relevant datasource."""
    datasource: Literal["vectorstore", "web_search"] = Field(
        description="Route question to web search or vectorstore retrieval"
    )

def create_query_router():
    """
    Create a query router to determine the data source for a given question.

    Returns:
        Callable: Query router chain
    """
    # LLM with structured output
    llm = ChatOllama(model="llama3.2", temperature=0.1, num_predict=256, top_p=0.5)
    structured_llm_router = llm.with_structured_output(RouteQuery)

    # Prompt
    system = """You are an expert at routing a user question to a vectorstore or web search.
    The vectorstore contains documents related to agents, prompt engineering, and adversarial attacks.
    Use the vectorstore for questions on these topics. Otherwise, use web-search."""

    route_prompt = ChatPromptTemplate.from_messages([
        ("system", system),
        ("human", "{question}"),
    ])

    return route_prompt | structured_llm_router

def route_query(question: str):
    """
    Route a specific query to its appropriate data source.

    Args:
        question (str): User's input question

    Returns:
        str: Recommended data source
    """
    router = create_query_router()
    result = router.invoke({"question": question})
    return result.datasource

if __name__ == "__main__":
    # Example usage
    test_questions = [
        "Who will the Bears draft first in the NFL draft?",
        "What are the types of agent memory?"
    ]

    for q in test_questions:
        source = route_query(q)
        print(f"Question: {q}")
        print(f"Routed to: {source}\n")
src/agents/state.py ADDED
@@ -0,0 +1,30 @@
from typing import List, TypedDict
from langchain_core.documents.base import Document

class GraphState(TypedDict):
    """
    Represents the state of our adaptive RAG graph.

    Attributes:
        question (str): Original user question
        generation (str, optional): LLM generated answer
        documents (List[Document], optional): Retrieved or searched documents
    """
    question: str
    generation: str | None
    documents: List[Document]

class ResearchGraphState(TypedDict):
    """
    Represents the state of the industry research workflow.

    Attributes:
        company (str): Target company name
        industry (str, optional): Industry sector of the company
        research_results (str, optional): Collected research findings
        use_cases (str, optional): Generated AI/ML use cases
        search_queries (str, optional): Extracted dataset and implementation queries
    """
    company: str
    industry: str | None
    research_results: str | None
    use_cases: str | None
    search_queries: str | None
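Since these are TypedDicts, workflow nodes read and return plain dicts. A minimal initial state for the RAG graph (hypothetical question) would look like:

initial_state: GraphState = {
    "question": "What are the types of agent memory?",
    "generation": None,
    "documents": [],
}

The workflow in src/agents/workflow.py actually seeds only the "question" key when streaming, which works because TypedDict keys are not enforced at runtime.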
src/agents/workflow.py ADDED
@@ -0,0 +1,274 @@
from langgraph.graph import END, StateGraph, START
from langchain_core.prompts import PromptTemplate
from src.agents.state import GraphState
# from src.agents.router import route_query
import asyncio
from src.vectorstore.pinecone_db import get_retriever
from src.tools.web_search import AdvancedWebCrawler
from src.llm.graders import (
    grade_document_relevance,
    check_hallucination,
    grade_answer_quality
)
from langchain_core.output_parsers import StrOutputParser
from src.llm.query_rewriter import rewrite_query
from langchain_ollama import ChatOllama

def perform_web_search(question: str):
    """
    Perform a web search using the AdvancedWebCrawler.

    Args:
        question (str): User's input question

    Returns:
        List: Web search results
    """
    # Initialize web crawler
    crawler = AdvancedWebCrawler(
        max_search_results=5,
        word_count_threshold=50,
        content_filter_type='f',
        filter_threshold=0.48
    )
    results = asyncio.run(crawler.search_and_crawl(question))

    return results


def create_adaptive_rag_workflow(retriever, llm, top_k=5, enable_websearch=False):
    """
    Create the adaptive RAG workflow graph.

    Args:
        retriever: Vector store retriever
        llm: Chat model used for generation and grading
        top_k (int): Number of documents to retrieve
        enable_websearch (bool): Route questions to web search instead of the vectorstore

    Returns:
        Compiled LangGraph workflow
    """
    def retrieve(state: GraphState):
        """Retrieve documents from the vectorstore."""
        print("---RETRIEVE---")
        question = state['question']
        documents = retriever.invoke(question, top_k)
        print(f"Retrieved {len(documents)} documents.")
        print(documents)
        return {"documents": documents, "question": question}

    def route_to_datasource(state: GraphState):
        """Route the question to web search or the vectorstore."""
        print("---ROUTE QUESTION---")
        # question = state['question']
        # source = route_query(question)

        if enable_websearch:
            print("---ROUTE TO WEB SEARCH---")
            return "web_search"
        else:
            print("---ROUTE TO RAG---")
            return "vectorstore"

    def generate_answer(state: GraphState):
        """Generate an answer using the retrieved documents."""
        print("---GENERATE---")
        question = state['question']
        documents = state['documents']

        # Prepare context
        context = "\n\n".join([doc["page_content"] for doc in documents])
        prompt_template = PromptTemplate.from_template(
            """You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise.
            Question: {question}
            Context: {context}
            Answer:"""
        )
        # Generate answer
        rag_chain = prompt_template | llm | StrOutputParser()

        generation = rag_chain.invoke({"context": context, "question": question})

        return {"generation": generation, "documents": documents, "question": question}

    def grade_documents(state: GraphState):
        """Filter out irrelevant documents."""
        print("---GRADE DOCUMENTS---")
        question = state['question']
        documents = state['documents']

        # Keep only documents graded as relevant
        filtered_docs = []
        for doc in documents:
            score = grade_document_relevance(question, doc["page_content"], llm)
            if score == "yes":
                filtered_docs.append(doc)

        return {"documents": filtered_docs, "question": question}

    def web_search(state: GraphState):
        """Perform a web search."""
        print("---WEB SEARCH---")
        question = state['question']

        # Perform web search
        results = perform_web_search(question)
        web_documents = [
            {
                "page_content": result['content'],
                "metadata": {"source": result['url']}
            } for result in results
        ]

        return {"documents": web_documents, "question": question}

    def check_generation_quality(state: GraphState):
        """Check the generated answer for grounding and quality."""
        print("---ASSESS GENERATION---")
        question = state['question']
        documents = state['documents']
        generation = state['generation']

        # Check grounding; regenerate if the answer is not supported by the documents
        hallucination_score = check_hallucination(
            [doc["page_content"] for doc in documents], generation, llm
        )
        if hallucination_score != "yes":
            print("---Generation is hallucinated, regenerating.---")
            return "regenerate"
        print("---Generation is not hallucinated.---")

        # Check answer quality
        quality_score = grade_answer_quality(question, generation, llm)
        if quality_score == "yes":
            print("---Answer quality is good.---")
        else:
            print("---Answer quality is poor.---")
        return "end" if quality_score == "yes" else "rewrite"

    # Create workflow
    workflow = StateGraph(GraphState)

    # Add nodes
    workflow.add_node("vectorstore", retrieve)
    workflow.add_node("web_search", web_search)
    workflow.add_node("grade_documents", grade_documents)
    workflow.add_node("generate", generate_answer)
    workflow.add_node("rewrite_query", lambda state: {
        "question": rewrite_query(state['question'], llm),
        "documents": [],
        "generation": None
    })

    # Define edges
    workflow.add_conditional_edges(
        START,
        route_to_datasource,
        {
            "web_search": "web_search",
            "vectorstore": "vectorstore"
        }
    )

    workflow.add_edge("web_search", "generate")
    workflow.add_edge("vectorstore", "grade_documents")

    workflow.add_conditional_edges(
        "grade_documents",
        lambda state: "generate" if state['documents'] else "rewrite_query"
    )

    workflow.add_edge("rewrite_query", "vectorstore")

    workflow.add_conditional_edges(
        "generate",
        check_generation_quality,
        {
            "end": END,
            "regenerate": "generate",
            "rewrite": "rewrite_query"
        }
    )

    # Compile the workflow
    app = workflow.compile()
    return app

def run_adaptive_rag(retriever, question: str, llm, top_k=5, enable_websearch=False):
    """
    Run the adaptive RAG workflow for a given question.

    Args:
        retriever: Vector store retriever
        question (str): User's input question
        llm: Chat model used for generation and grading
        top_k (int): Number of documents to retrieve
        enable_websearch (bool): Route the question to web search

    Returns:
        str: Generated answer
    """
    # Create workflow
    workflow = create_adaptive_rag_workflow(retriever, llm, top_k, enable_websearch=enable_websearch)

    # Run the workflow and keep the last emitted state
    final_state = None
    for output in workflow.stream({"question": question}, config={"recursion_limit": 5}):
        for key, value in output.items():
            print(f"Node '{key}':")
            # Optionally print state details
            # print(value)
            final_state = value

    return final_state.get('generation', 'No answer could be generated.')

if __name__ == "__main__":
    # Example usage
    from src.vectorstore.pinecone_db import PINECONE_API_KEY, ingest_data, load_documents, process_chunks, save_to_parquet
    from pinecone import Pinecone

    # Load and prepare documents
    pc = Pinecone(api_key=PINECONE_API_KEY)

    # Define input files
    file_paths = [
        # './data/2404.19756v1.pdf',
        # './data/OD429347375590223100.pdf',
        # './data/Project Report Format.docx',
        './data/UNIT 2 GENDER BASED VIOLENCE.pptx'
    ]

    # Process pipeline
    try:
        # Step 1: Load and combine documents
        print("Loading documents...")
        markdown_path = load_documents(file_paths)

        # Step 2: Process into chunks with embeddings
        print("Processing chunks...")
        chunks = process_chunks(markdown_path)

        # Step 3: Save to Parquet
        print("Saving to Parquet...")
        parquet_path = save_to_parquet(chunks)

        # Step 4: Ingest into Pinecone
        print("Ingesting into Pinecone...")
        ingest_data(
            pc,
            parquet_path=parquet_path,
            text_column="text",
            pinecone_client=pc,
        )

        # Step 5: Test retrieval
        print("\nTesting retrieval...")
        retriever = get_retriever(
            pinecone_client=pc,
            index_name="vector-index",
            namespace="rag"
        )

    except Exception as e:
        print(f"Error in pipeline: {str(e)}")

    llm = ChatOllama(model="llama3.2", temperature=0.1, num_predict=256, top_p=0.5)

    # Test questions
    test_questions = [
        # "What are the key components of AI agent memory?",
        # "Explain prompt engineering techniques",
        # "What are recent advancements in adversarial attacks on LLMs?"
        "what are the trending papers that are published in NeurIPS 2024?"
    ]

    # Run the workflow for each test question
    for question in test_questions:
        print(f"\n--- Processing Question: {question} ---")
        answer = run_adaptive_rag(retriever, question, llm)
        print("\nFinal Answer:", answer)
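Note that both run_adaptive_rag and app.py invoke the graph with recursion_limit=5, so a question whose retrievals repeatedly fail grading exhausts the limit and raises GraphRecursionError, which app.py catches in run_rag_with_streaming and turns into a fallback answer.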
src/data_processing/__init__.py ADDED
File without changes
src/data_processing/chunker.py ADDED
@@ -0,0 +1,85 @@
from typing import List
import numpy as np
from chonkie.embeddings import BaseEmbeddings
from FlagEmbedding import BGEM3FlagModel
from chonkie import SDPMChunker

class BGEM3Embeddings(BaseEmbeddings):
    def __init__(self, model_name):
        self.model = BGEM3FlagModel(model_name, use_fp16=True)
        self.task = "separation"

    @property
    def dimension(self):
        return 1024

    def embed(self, text: str):
        return self.model.encode(
            [text], return_dense=True, return_sparse=False, return_colbert_vecs=False
        )['dense_vecs']

    def embed_batch(self, texts: List[str]):
        embeddings = self.model.encode(
            texts, return_dense=True, return_sparse=False, return_colbert_vecs=False
        )
        return embeddings['dense_vecs']

    def count_tokens(self, text: str):
        return len(self.model.tokenizer.encode(text))

    def count_tokens_batch(self, texts: List[str]):
        encodings = self.model.tokenizer(texts)
        return [len(enc) for enc in encodings["input_ids"]]

    def get_tokenizer_or_token_counter(self):
        return self.model.tokenizer

    def similarity(self, u: "np.ndarray", v: "np.ndarray"):
        """Compute cosine similarity between two embeddings."""
        # Dot product of the unit-normalized dense vectors equals cosine similarity
        return u @ v.T

    @classmethod
    def is_available(cls):
        return True

    def __repr__(self):
        return "bgem3"


def main():
    # Initialize the BGE M3 embeddings model
    embedding_model = BGEM3Embeddings(
        model_name="BAAI/bge-m3"
    )

    # Initialize the SDPM chunker
    chunker = SDPMChunker(
        embedding_model=embedding_model,
        chunk_size=256,
        threshold=0.7,
        skip_window=2
    )

    with open('./output.md', 'r') as file:
        text = file.read()

    # Generate chunks
    chunks = chunker.chunk(text)

    # Print the chunks
    for i, chunk in enumerate(chunks, 1):
        print(f"\nChunk {i}:")
        print(f"Text: {chunk.text}")
        print(f"Token count: {chunk.token_count}")
        print(f"Start index: {chunk.start_index}")
        print(f"End index: {chunk.end_index}")
        print(f"Number of sentences: {len(chunk.sentences)}")
        print("-" * 80)

if __name__ == "__main__":
    main()
src/data_processing/loader.py ADDED
@@ -0,0 +1,153 @@
from pathlib import Path
from typing import List, Union
import logging
from dataclasses import dataclass

from langchain_core.documents import Document as LCDocument
from langchain_core.document_loaders import BaseLoader
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling.datamodel.base_models import InputFormat, ConversionStatus
from docling.datamodel.pipeline_options import (
    PdfPipelineOptions,
    EasyOcrOptions
)

logging.basicConfig(level=logging.INFO)
_log = logging.getLogger(__name__)

@dataclass
class ProcessingResult:
    """Store results of document processing"""
    success_count: int = 0
    failure_count: int = 0
    partial_success_count: int = 0
    failed_files: List[str] = None

    def __post_init__(self):
        if self.failed_files is None:
            self.failed_files = []

class MultiFormatDocumentLoader(BaseLoader):
    """Loader for multiple document formats that converts to LangChain documents"""

    def __init__(
        self,
        file_paths: Union[str, List[str]],
        enable_ocr: bool = True,
        enable_tables: bool = True
    ):
        self._file_paths = [file_paths] if isinstance(file_paths, str) else file_paths
        self._enable_ocr = enable_ocr
        self._enable_tables = enable_tables
        self._converter = self._setup_converter()

    def _setup_converter(self):
        """Set up the document converter with appropriate options"""
        # Configure pipeline options
        pipeline_options = PdfPipelineOptions(
            do_ocr=False,
            do_table_structure=False,
            ocr_options=EasyOcrOptions(force_full_page_ocr=True)
        )
        if self._enable_ocr:
            pipeline_options.do_ocr = True
        if self._enable_tables:
            pipeline_options.do_table_structure = True
            pipeline_options.table_structure_options.do_cell_matching = True

        # Create converter with supported formats
        return DocumentConverter(
            allowed_formats=[
                InputFormat.PDF,
                InputFormat.IMAGE,
                InputFormat.DOCX,
                InputFormat.HTML,
                InputFormat.PPTX,
                InputFormat.ASCIIDOC,
                InputFormat.MD,
            ],
            format_options={
                InputFormat.PDF: PdfFormatOption(
                    pipeline_options=pipeline_options,
                )
            }
        )

    def lazy_load(self):
        """Convert documents and yield LangChain documents"""
        results = ProcessingResult()

        for file_path in self._file_paths:
            try:
                path = Path(file_path)
                if not path.exists():
                    _log.warning(f"File not found: {file_path}")
                    results.failure_count += 1
                    results.failed_files.append(file_path)
                    continue

                conversion_result = self._converter.convert(path)

                if conversion_result.status == ConversionStatus.SUCCESS:
                    results.success_count += 1
                    text = conversion_result.document.export_to_markdown()
                    metadata = {
                        'source': str(path),
                        'file_type': path.suffix,
                    }
                    yield LCDocument(
                        page_content=text,
                        metadata=metadata
                    )
                elif conversion_result.status == ConversionStatus.PARTIAL_SUCCESS:
                    results.partial_success_count += 1
                    _log.warning(f"Partial conversion for {file_path}")
                    text = conversion_result.document.export_to_markdown()
                    metadata = {
                        'source': str(path),
                        'file_type': path.suffix,
                        'conversion_status': 'partial'
                    }
                    yield LCDocument(
                        page_content=text,
                        metadata=metadata
                    )
                else:
                    results.failure_count += 1
                    results.failed_files.append(file_path)
                    _log.error(f"Failed to convert {file_path}")

            except Exception as e:
                _log.error(f"Error processing {file_path}: {str(e)}")
                results.failure_count += 1
                results.failed_files.append(file_path)

        # Log final results
        total = results.success_count + results.partial_success_count + results.failure_count
        _log.info(
            f"Processed {total} documents:\n"
            f"- Successfully converted: {results.success_count}\n"
            f"- Partially converted: {results.partial_success_count}\n"
            f"- Failed: {results.failure_count}"
        )
        if results.failed_files:
            _log.info("Failed files:")
            for file in results.failed_files:
                _log.info(f"- {file}")


if __name__ == '__main__':
    # Load documents from a list of file paths
    loader = MultiFormatDocumentLoader(
        file_paths=[
            # './data/2404.19756v1.pdf',
            # './data/OD429347375590223100.pdf',
            './data/Project Report Format.docx',
            # './data/UNIT 2 GENDER BASED VIOLENCE.pptx'
        ],
        enable_ocr=False,
        enable_tables=True
    )
    for doc in loader.lazy_load():
        print(doc.page_content)
        print(doc.metadata)
        # Save the document to a .md file
        with open('output.md', 'w') as f:
            f.write(doc.page_content)
src/llm/__init__.py ADDED
File without changes
src/llm/graders.py ADDED
@@ -0,0 +1,128 @@
1
+ from langchain_core.prompts import ChatPromptTemplate
2
+ from langchain_ollama import ChatOllama
3
+ from pydantic import BaseModel, Field
4
+ from typing import List
5
+
6
+ class DocumentRelevance(BaseModel):
7
+ """Binary score for relevance check on retrieved documents."""
8
+ binary_score: str = Field(
9
+ description="Documents are relevant to the question, 'yes' or 'no'"
10
+ )
11
+
12
+ class HallucinationCheck(BaseModel):
13
+ """Binary score for hallucination present in generation answer."""
14
+ binary_score: str = Field(
15
+ description="Answer is grounded in the facts, 'yes' or 'no'"
16
+ )
17
+
18
+ class AnswerQuality(BaseModel):
19
+ """Binary score to assess answer addresses question."""
20
+ binary_score: str = Field(
21
+ description="Answer addresses the question, 'yes' or 'no'"
22
+ )
23
+
24
+ def create_llm_grader(grader_type: str, llm):
25
+ """
26
+ Create an LLM grader based on the specified type.
27
+
28
+ Args:
29
+ grader_type (str): Type of grader to create
30
+
31
+ Returns:
32
+ Callable: LLM grader function
33
+ """
34
+ # Initialize LLM
35
+
36
+ # Select grader type and create structured output
37
+ if grader_type == "document_relevance":
38
+ structured_llm_grader = llm.with_structured_output(DocumentRelevance)
39
+ system = """You are a grader assessing relevance of a retrieved document to a user question.
40
+ If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant.
41
+ It does not need to be a stringent test. The goal is to filter out erroneous retrievals.
42
+ Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question."""
43
+
44
+ prompt = ChatPromptTemplate.from_messages([
45
+ ("system", system),
46
+ ("human", "Retrieved document: \n\n {document} \n\n User question: {question}"),
47
+ ])
48
+
49
+ elif grader_type == "hallucination":
50
+ structured_llm_grader = llm.with_structured_output(HallucinationCheck)
51
+ system = """You are a grader assessing whether an LLM generation is grounded in / supported by a set of retrieved facts.
52
+ Give a binary score 'yes' or 'no'. 'Yes' means that the answer is grounded in / supported by the set of facts."""
53
+
54
+ prompt = ChatPromptTemplate.from_messages([
55
+ ("system", system),
56
+ ("human", "Set of facts: \n\n {documents} \n\n LLM generation: {generation}"),
57
+ ])
58
+
59
+ elif grader_type == "answer_quality":
60
+ structured_llm_grader = llm.with_structured_output(AnswerQuality)
61
+ system = """You are a grader assessing whether an answer addresses / resolves a question.
62
+ Give a binary score 'yes' or 'no'. 'Yes' means that the answer resolves the question."""
63
+
64
+ prompt = ChatPromptTemplate.from_messages([
65
+ ("system", system),
66
+ ("human", "User question: \n\n {question} \n\n LLM generation: {generation}"),
67
+ ])
68
+
69
+ else:
70
+ raise ValueError(f"Unknown grader type: {grader_type}")
71
+
72
+ return prompt | structured_llm_grader
73
+
74
+ def grade_document_relevance(question: str, document: str, llm):
75
+ """
76
+ Grade the relevance of a document to a given question.
77
+
78
+ Args:
79
+ question (str): User's question
80
+ document (str): Retrieved document content
81
+
82
+ Returns:
83
+ str: Binary score ('yes' or 'no')
84
+ """
85
+ grader = create_llm_grader("document_relevance", llm)
86
+ result = grader.invoke({"question": question, "document": document})
87
+ return result.binary_score
88
+
89
+ def check_hallucination(documents: List[str], generation: str, llm):
90
+ """
91
+ Check if the generation is grounded in the provided documents.
92
+
93
+ Args:
94
+ documents (List[str]): List of source documents
95
+ generation (str): LLM generated answer
96
+
97
+ Returns:
98
+ str: Binary score ('yes' or 'no')
99
+ """
100
+ grader = create_llm_grader("hallucination", llm)
101
+ result = grader.invoke({"documents": documents, "generation": generation})
102
+ return result.binary_score
103
+
104
+ def grade_answer_quality(question: str, generation: str, llm):
105
+ """
106
+ Grade the quality of the answer in addressing the question.
107
+
108
+ Args:
109
+ question (str): User's original question
110
+ generation (str): LLM generated answer
111
+
112
+ Returns:
113
+ str: Binary score ('yes' or 'no')
114
+ """
115
+ grader = create_llm_grader("answer_quality", llm)
116
+ result = grader.invoke({"question": question, "generation": generation})
117
+ return result.binary_score
118
+
119
+ if __name__ == "__main__":
120
+ # Example usage
121
+ test_question = "What are the types of agent memory?"
122
+ test_document = "Agent memory can be classified into different types such as episodic, semantic, and working memory."
123
+ test_generation = "Agent memory includes episodic memory for storing experiences, semantic memory for general knowledge, and working memory for immediate processing."
124
+ llm = ChatOllama(model = "llama3.2", temperature = 0.1, num_predict = 256, top_p=0.5)
125
+
126
+ print("Document Relevance:", grade_document_relevance(test_question, test_document, llm))
127
+ print("Hallucination Check:", check_hallucination([test_document], test_generation, llm))
128
+ print("Answer Quality:", grade_answer_quality(test_question, test_generation, llm))
src/llm/query_rewriter.py ADDED
@@ -0,0 +1,55 @@
+ from langchain_core.prompts.chat import ChatPromptTemplate
+ from langchain_ollama import ChatOllama
+ from langchain_core.output_parsers import StrOutputParser
+
+ def create_query_rewriter(llm):
+     """
+     Create a query rewriter chain to optimize retrieval.
+
+     Args:
+         llm: Chat model used for rewriting
+
+     Returns:
+         Runnable: Query rewriter chain
+     """
+     # Prompt for query rewriting
+     system = """You are a question re-writer that converts an input question to a better version that is optimized
+     for vectorstore retrieval. Look at the input and try to reason about the underlying semantic intent / meaning."""
+
+     re_write_prompt = ChatPromptTemplate.from_messages([
+         ("system", system),
+         ("human", "Here is the initial question: \n\n {question} \n Formulate an improved question."),
+     ])
+
+     # Create query rewriter chain
+     return re_write_prompt | llm | StrOutputParser()
+
+ def rewrite_query(question: str, llm):
+     """
+     Rewrite a given query to optimize retrieval.
+
+     Args:
+         question (str): Original user question
+         llm: Chat model used for rewriting
+
+     Returns:
+         str: Rewritten query (falls back to the original question on error)
+     """
+     query_rewriter = create_query_rewriter(llm)
+     try:
+         rewritten_query = query_rewriter.invoke({"question": question})
+         return rewritten_query
+     except Exception as e:
+         print(f"Query rewriting error: {e}")
+         return question
+
+ if __name__ == "__main__":
+     # Example usage
+     test_queries = [
+         "Tell me about AI agents",
+         "What do we know about memory in AI systems?",
+         "Bears draft strategy"
+     ]
+     llm = ChatOllama(model="llama3.2", temperature=0.1, num_predict=256, top_p=0.5)
+
+     for query in test_queries:
+         rewritten = rewrite_query(query, llm)
+         print(f"Original: {query}")
+         print(f"Rewritten: {rewritten}\n")
src/tools/__init__.py ADDED
File without changes
src/tools/deep_crawler.py ADDED
@@ -0,0 +1,325 @@
+ import os
+ import sys
+ import asyncio
+ from typing import List, Dict, Optional, Set
+ from urllib.parse import urlparse
+ from langchain_community.tools import DuckDuckGoSearchResults, TavilySearchResults
+ from langchain_community.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
+ from crawl4ai import AsyncWebCrawler, CacheMode
+ from crawl4ai.content_filter_strategy import PruningContentFilter, BM25ContentFilter
+ from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
+ from dotenv import load_dotenv
+
+ load_dotenv()
+
+ class DeepWebCrawler:
+     def __init__(self,
+                  max_search_results: int = 5,
+                  max_external_links: int = 3,
+                  word_count_threshold: int = 50,
+                  content_filter_type: str = 'pruning',
+                  filter_threshold: float = 0.48):
+         """
+         Initialize the Deep Web Crawler with support for one-level deep crawling
+
+         Args:
+             max_search_results (int): Maximum number of search results to process
+             max_external_links (int): Maximum number of external links to crawl per page
+             word_count_threshold (int): Minimum word count for crawled content
+             content_filter_type (str): Type of content filter ('pruning' or 'bm25')
+             filter_threshold (float): Threshold for content filtering
+         """
+         self.max_search_results = max_search_results
+         self.max_external_links = max_external_links
+         self.word_count_threshold = word_count_threshold
+         self.content_filter_type = content_filter_type
+         self.filter_threshold = filter_threshold
+         self.crawled_urls: Set[str] = set()
+
+     def _create_web_search_tool(self):
+         return TavilySearchResults(max_results=self.max_search_results)
+
+     def _create_content_filter(self, user_query: Optional[str] = None):
+         if self.content_filter_type == 'bm25' and user_query:
+             return BM25ContentFilter(
+                 user_query=user_query,
+                 bm25_threshold=self.filter_threshold
+             )
+         return PruningContentFilter(
+             threshold=self.filter_threshold,
+             threshold_type="fixed",
+             min_word_threshold=self.word_count_threshold
+         )
+
+     def _extract_links_from_search_results(self, results: List[Dict]) -> List[str]:
+         """Safely extract URLs from search results"""
+         urls = []
+         for result in results:
+             if isinstance(result, dict) and 'url' in result:
+                 urls.append(result['url'])
+             elif isinstance(result, str):
+                 urls.append(result)
+         return urls
+
+     def _extract_url_from_link(self, link):
+         """Extract the URL string from a link object, which might be a dict or a string"""
+         if isinstance(link, dict):
+             return link.get('url', '')  # Assuming the URL is stored in a 'url' key
+         elif isinstance(link, str):
+             return link
+         return ''
+
+     def _process_crawl_result(self, result) -> Dict:
+         """Process an individual crawl result into a structured format"""
+         return {
+             "url": result.url,
+             "success": result.success,
+             "title": result.metadata.get('title', 'N/A'),
+             "content": result.markdown_v2.raw_markdown if result.success else result.error_message,
+             "word_count": len(result.markdown_v2.raw_markdown.split()) if result.success else 0,
+             "links": {
+                 "internal": result.links.get('internal', []),
+                 "external": result.links.get('external', [])
+             },
+             "images": len(result.media.get('images', []))
+         }
+
+     async def crawl_urls(self, urls: List[str], user_query: Optional[str] = None, depth: int = 0):
+         """
+         Crawl URLs with support for external link crawling
+
+         Args:
+             urls (List[str]): List of URLs to crawl
+             user_query (Optional[str]): Query for content filtering
+             depth (int): Current crawl depth (0 for initial, 1 for external links)
+
+         Returns:
+             List of crawl results including external link content
+         """
+         if not urls or depth > 1:
+             return []
+
+         # Filter out already crawled URLs
+         new_urls = [url for url in urls if url not in self.crawled_urls]
+         if not new_urls:
+             return []
+
+         async with AsyncWebCrawler(
+             browser_type="chromium",
+             headless=True,
+             verbose=True
+         ) as crawler:
+             content_filter = self._create_content_filter(user_query)
+
+             results = await crawler.arun_many(
+                 urls=new_urls,
+                 word_count_threshold=self.word_count_threshold,
+                 cache_mode=CacheMode.BYPASS,
+                 markdown_generator=DefaultMarkdownGenerator(content_filter=content_filter),
+                 exclude_external_links=True,
+                 exclude_social_media_links=True,
+                 remove_overlay_elements=True,
+                 simulate_user=True,
+                 magic=True
+             )
+
+             processed_results = []
+             external_urls = set()
+
+             # Process results and collect external URLs
+             for result in results:
+                 self.crawled_urls.add(result.url)
+                 processed_result = self._process_crawl_result(result)
+                 processed_results.append(processed_result)
+
+                 if depth == 0 and result.success:
+                     # Collect unique external URLs for further crawling
+                     external_links = result.links.get('external', [])[:self.max_external_links]
+                     external_urls.update(
+                         self._extract_url_from_link(link)
+                         for link in external_links
+                         if self._extract_url_from_link(link)
+                         and self._extract_url_from_link(link) not in self.crawled_urls
+                     )
+
+             # Crawl the collected external links one level deep
+             if depth == 0 and external_urls:
+                 external_results = await self.crawl_urls(
+                     list(external_urls),
+                     user_query=user_query,
+                     depth=1
+                 )
+                 processed_results.extend(external_results)
+
+             return processed_results
+
+     async def search_and_crawl(self, query: str) -> List[Dict]:
+         """
+         Perform a web search and deep crawl of the results
+
+         Args:
+             query (str): Search query
+
+         Returns:
+             List of crawled content results including external links
+         """
+         search_tool = self._create_web_search_tool()
+         search_results = search_tool.invoke(query)
+
+         # Handle different types of search results
+         if isinstance(search_results, str):
+             urls = [search_results]
+         elif isinstance(search_results, list):
+             urls = self._extract_links_from_search_results(search_results)
+         else:
+             print(f"Unexpected search results format: {type(search_results)}")
+             return []
+
+         if not urls:
+             print("No valid URLs found in search results")
+             return []
+
+         print(f"Initial search found {len(urls)} URLs for query: {query}")
+         print(urls)
+         crawl_results = await self.crawl_urls(urls, user_query=query)
+
+         return crawl_results
+
+
+ class ResourceCollectionAgent:
+     def __init__(self, max_results_per_query: int = 10):
+         """
+         Initialize the Resource Collection Agent
+
+         Args:
+             max_results_per_query (int): Maximum number of results per search query
+         """
+         self.max_results_per_query = max_results_per_query
+         self.search_tool = TavilySearchResults(max_results=max_results_per_query)
+
+     def _is_valid_domain(self, url: str, valid_domains: List[str]) -> bool:
+         """Check if a URL belongs to the allowed domains"""
+         try:
+             domain = urlparse(url).netloc.lower()
+             return any(valid_domain in domain for valid_domain in valid_domains)
+         except Exception:
+             return False
+
+     def _extract_search_result(self, result) -> Optional[Dict]:
+         """Safely extract information from a search result"""
+         try:
+             if isinstance(result, dict):
+                 return {
+                     "title": result.get("title", "No title"),
+                     "url": result.get("url", ""),
+                     "snippet": result.get("snippet", "No description")
+                 }
+             elif isinstance(result, str):
+                 return {
+                     "title": "Unknown",
+                     "url": result,
+                     "snippet": "No description available"
+                 }
+             return None
+         except Exception as e:
+             print(f"Error processing search result: {str(e)}")
+             return None
+
+     async def collect_resources(self) -> Dict[str, List[Dict]]:
+         """
+         Collect AI/ML resources from specific platforms
+
+         Returns:
+             Dictionary with categorized resource links
+         """
+         search_queries = {
+             "datasets": [
+                 ("kaggle", "site:kaggle.com/datasets machine learning"),
+                 ("huggingface", "site:huggingface.co/datasets artificial intelligence")
+             ],
+             "repositories": [
+                 ("github", "site:github.com AI tools repository")
+             ]
+         }
+
+         results = {
+             "kaggle_datasets": [],
+             "huggingface_datasets": [],
+             "github_repositories": []
+         }
+
+         for category, queries in search_queries.items():
+             for platform, query in queries:
+                 try:
+                     search_results = self.search_tool.invoke(query)
+
+                     # Handle different result formats
+                     if isinstance(search_results, str):
+                         search_results = [search_results]
+                     elif not isinstance(search_results, list):
+                         print(f"Unexpected search results format for {platform}: {type(search_results)}")
+                         continue
+
+                     # Filter results based on domain
+                     valid_domains = {
+                         "kaggle": ["kaggle.com"],
+                         "huggingface": ["huggingface.co"],
+                         "github": ["github.com"]
+                     }
+
+                     for result in search_results:
+                         processed_result = self._extract_search_result(result)
+                         if processed_result and self._is_valid_domain(
+                             processed_result["url"],
+                             valid_domains[platform]
+                         ):
+                             if platform == "kaggle":
+                                 results["kaggle_datasets"].append(processed_result)
+                             elif platform == "huggingface":
+                                 results["huggingface_datasets"].append(processed_result)
+                             elif platform == "github":
+                                 results["github_repositories"].append(processed_result)
+
+                 except Exception as e:
+                     print(f"Error collecting {platform} resources: {str(e)}")
+                     continue
+
+         return results
+
+ def main():
+     async def run_examples():
+         # Test DeepWebCrawler
+         deep_crawler = DeepWebCrawler(
+             max_search_results=3,
+             max_external_links=2,
+             word_count_threshold=50
+         )
+
+         crawl_results = await deep_crawler.search_and_crawl(
+             "Adani Defence & Aerospace"
+         )
+
+         print("\nDeep Crawler Results:")
+         for result in crawl_results:
+             print(f"URL: {result['url']}")
+             print(f"Title: {result['title']}")
+             print(f"Word Count: {result['word_count']}")
+             print(f"External Links: {len(result['links']['external'])}\n")
+
+         # Test ResourceCollectionAgent
+         resource_agent = ResourceCollectionAgent(max_results_per_query=5)
+         resources = await resource_agent.collect_resources()
+
+         print("\nResource Collection Results:")
+         for category, items in resources.items():
+             print(f"\n{category.upper()}:")
+             for item in items:
+                 print(f"Title: {item['title']}")
+                 print(f"URL: {item['url']}")
+                 print("---")
+
+     asyncio.run(run_examples())
+
+ if __name__ == "__main__":
+     main()
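
Crawled pages can feed the same ingestion pipeline the local documents use. A hedged sketch of that hand-off, assuming `process_chunks` and `save_to_parquet` from `src/vectorstore/pinecone_db.py` (shown later in this upload) accept any markdown file; the file paths here are placeholders:

```python
import asyncio
from src.tools.deep_crawler import DeepWebCrawler
from src.vectorstore.pinecone_db import process_chunks, save_to_parquet

async def crawl_to_parquet(query: str, md_path: str = "./data/crawl.md") -> str:
    crawler = DeepWebCrawler(max_search_results=3, max_external_links=2)
    results = await crawler.search_and_crawl(query)
    # Persist successful pages as one markdown file, the format process_chunks expects
    with open(md_path, "w") as f:
        for r in results:
            if r["success"]:
                f.write(f"# {r['title']}\n\n{r['content']}\n\n")
    chunks = process_chunks(md_path)
    return save_to_parquet(chunks, "./data/crawl_chunks.parquet")

if __name__ == "__main__":
    print(asyncio.run(crawl_to_parquet("AI agent memory")))
```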
src/tools/web_search.py ADDED
@@ -0,0 +1,173 @@
+ import os
+ import sys
+ import asyncio
+ from typing import List, Dict, Optional
+
+ from langchain_community.tools import DuckDuckGoSearchResults
+ from crawl4ai import AsyncWebCrawler, CacheMode
+ from crawl4ai.content_filter_strategy import PruningContentFilter, BM25ContentFilter
+ from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
+ from dotenv import load_dotenv
+
+ load_dotenv()
+
+ class AdvancedWebCrawler:
+     def __init__(self,
+                  max_search_results: int = 5,
+                  word_count_threshold: int = 50,
+                  content_filter_type: str = 'pruning',
+                  filter_threshold: float = 0.48):
+         """
+         Initialize the Advanced Web Crawler
+
+         Args:
+             max_search_results (int): Maximum number of search results to process
+             word_count_threshold (int): Minimum word count for crawled content
+             content_filter_type (str): Type of content filter ('pruning' or 'bm25')
+             filter_threshold (float): Threshold for content filtering
+         """
+         self.max_search_results = max_search_results
+         self.word_count_threshold = word_count_threshold
+         self.content_filter_type = content_filter_type
+         self.filter_threshold = filter_threshold
+
+     def _create_web_search_tool(self):
+         """
+         Create a web search tool using DuckDuckGo
+
+         Returns:
+             DuckDuckGoSearchResults: Web search tool
+         """
+         return DuckDuckGoSearchResults(max_results=self.max_search_results, output_format="list")
+
+     def _create_content_filter(self, user_query: Optional[str] = None):
+         """
+         Create a content filter based on the specified type
+
+         Args:
+             user_query (Optional[str]): Query to use for BM25 filtering
+
+         Returns:
+             Content filter strategy
+         """
+         if self.content_filter_type == 'bm25' and user_query:
+             return BM25ContentFilter(
+                 user_query=user_query,
+                 bm25_threshold=self.filter_threshold
+             )
+         else:
+             return PruningContentFilter(
+                 threshold=self.filter_threshold,
+                 threshold_type="fixed",
+                 min_word_threshold=self.word_count_threshold
+             )
+
+     async def crawl_urls(self, urls: List[str], user_query: Optional[str] = None):
+         """
+         Crawl multiple URLs with content filtering
+
+         Args:
+             urls (List[str]): List of URLs to crawl
+             user_query (Optional[str]): Query used for BM25 content filtering
+
+         Returns:
+             List of crawl results
+         """
+         async with AsyncWebCrawler(
+             browser_type="chromium",
+             headless=True,
+             verbose=True
+         ) as crawler:
+             # Create the appropriate content filter
+             content_filter = self._create_content_filter(user_query)
+
+             # Run crawling for multiple URLs
+             results = await crawler.arun_many(
+                 urls=urls,
+                 word_count_threshold=self.word_count_threshold,
+                 markdown_generator=DefaultMarkdownGenerator(
+                     content_filter=content_filter
+                 ),
+                 cache_mode=CacheMode.DISABLED,  # supersedes the legacy bypass_cache flag
+                 exclude_external_links=True,
+                 remove_overlay_elements=True,
+                 simulate_user=True,
+                 magic=True
+             )
+
+             # Process and return crawl results
+             processed_results = []
+             for result in results:
+                 crawl_result = {
+                     "url": result.url,
+                     "success": result.success,
+                     "title": result.metadata.get('title', 'N/A'),
+                     "content": result.markdown_v2.raw_markdown if result.success else result.error_message,
+                     "word_count": len(result.markdown_v2.raw_markdown.split()) if result.success else 0,
+                     "links": {
+                         "internal": len(result.links.get('internal', [])),
+                         "external": len(result.links.get('external', []))
+                     },
+                     "images": len(result.media.get('images', []))
+                 }
+                 processed_results.append(crawl_result)
+
+             return processed_results
+
+     async def search_and_crawl(self, query: str) -> List[Dict]:
+         """
+         Perform a web search and crawl the results
+
+         Args:
+             query (str): Search query
+
+         Returns:
+             List of crawled content results
+         """
+         # Perform web search
+         search_tool = self._create_web_search_tool()
+         try:
+             search_results = search_tool.invoke({"query": query})
+
+             # Extract URLs from search results
+             urls = [result['link'] for result in search_results]
+             print(f"Found {len(urls)} URLs for query: {query}")
+
+             # Crawl URLs
+             crawl_results = await self.crawl_urls(urls, user_query=query)
+
+             return crawl_results
+
+         except Exception as e:
+             print(f"Web search and crawl error: {e}")
+             return []
+
+ def main():
+     # Example usage
+     crawler = AdvancedWebCrawler(
+         max_search_results=5,
+         word_count_threshold=50,
+         content_filter_type='pruning',
+         filter_threshold=0.48
+     )
+
+     test_queries = [
+         "Latest developments in AI agents",
+         "Today's weather forecast in Kolkata",
+     ]
+
+     for query in test_queries:
+         # Run search and crawl asynchronously
+         results = asyncio.run(crawler.search_and_crawl(query))
+
+         print(f"\nResults for query: {query}")
+         for result in results:
+             print(f"URL: {result['url']}")
+             print(f"Success: {result['success']}")
+             print(f"Title: {result['title']}")
+             print(f"Word Count: {result['word_count']}")
+             print(f"Content Preview: {result['content'][:500]}...\n")
+
+ if __name__ == "__main__":
+     main()
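
For query-focused extraction, the same crawler can be switched to the BM25 filter, which scores page sections against the query instead of pruning generically. A small hedged usage sketch; the threshold value is illustrative, not tuned:

```python
# Hedged sketch: BM25-filtered crawl. search_and_crawl passes the query through
# as user_query, which is what activates BM25ContentFilter in _create_content_filter.
import asyncio
from src.tools.web_search import AdvancedWebCrawler

crawler = AdvancedWebCrawler(
    max_search_results=3,
    content_filter_type='bm25',  # score page sections against the query
    filter_threshold=1.0,        # used as bm25_threshold; higher keeps less content
)
results = asyncio.run(crawler.search_and_crawl("vector database benchmarks"))
for r in results:
    print(r["url"], r["word_count"])
```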
src/vectorstore/__init__.py ADDED
File without changes
src/vectorstore/pinecone_db.py ADDED
@@ -0,0 +1,282 @@
+ from src.data_processing.loader import MultiFormatDocumentLoader
+ from src.data_processing.chunker import SDPMChunker, BGEM3Embeddings
+
+ import pandas as pd
+ from typing import List, Dict, Any
+ from pinecone import Pinecone, ServerlessSpec
+ import time
+ from tqdm import tqdm
+ from dotenv import load_dotenv
+ import os
+
+
+ load_dotenv()
+
+ # API Keys
+ PINECONE_API_KEY = os.getenv('PINECONE_API_KEY')
+
+ embedding_model = BGEM3Embeddings(model_name="BAAI/bge-m3")
+
+
+ def load_documents(file_paths: List[str], output_path='./data/output.md'):
+     """
+     Load documents from multiple sources and combine them into a single markdown file
+     """
+     loader = MultiFormatDocumentLoader(
+         file_paths=file_paths,
+         enable_ocr=False,
+         enable_tables=True
+     )
+
+     # Append all documents to the markdown file
+     with open(output_path, 'w') as f:
+         for doc in loader.lazy_load():
+             # Add metadata as YAML frontmatter
+             f.write('---\n')
+             for key, value in doc.metadata.items():
+                 f.write(f'{key}: {value}\n')
+             f.write('---\n\n')
+             f.write(doc.page_content)
+             f.write('\n\n')
+
+     return output_path
+
+ def process_chunks(markdown_path: str, chunk_size: int = 256,
+                    threshold: float = 0.7, skip_window: int = 2):
+     """
+     Process the markdown file into chunks and prepare for vector storage
+     """
+     chunker = SDPMChunker(
+         embedding_model=embedding_model,
+         chunk_size=chunk_size,
+         threshold=threshold,
+         skip_window=skip_window
+     )
+
+     # Read the markdown file
+     with open(markdown_path, 'r') as file:
+         text = file.read()
+
+     # Generate chunks
+     chunks = chunker.chunk(text)
+
+     # Prepare data for Parquet
+     processed_chunks = []
+     for chunk in chunks:
+         processed_chunks.append({
+             'text': chunk.text,
+             'token_count': chunk.token_count,
+             'start_index': chunk.start_index,
+             'end_index': chunk.end_index,
+             'num_sentences': len(chunk.sentences),
+         })
+
+     return processed_chunks
+
+ def save_to_parquet(chunks: List[Dict[str, Any]], output_path='./data/chunks.parquet'):
+     """
+     Save processed chunks to a Parquet file
+     """
+     df = pd.DataFrame(chunks)
+     print(f"Saving to Parquet: {output_path}")
+     df.to_parquet(output_path)
+     print(f"Saved to Parquet: {output_path}")
+     return output_path
+
+
+ class PineconeRetriever:
+     def __init__(
+         self,
+         pinecone_client: Pinecone,
+         index_name: str,
+         namespace: str,
+         embedding_generator: BGEM3Embeddings
+     ):
+         """Initialize the retriever with a Pinecone client and embedding generator.
+
+         Args:
+             pinecone_client: Initialized Pinecone client
+             index_name: Name of the Pinecone index
+             namespace: Namespace in the index
+             embedding_generator: BGEM3Embeddings instance
+         """
+         self.pinecone = pinecone_client
+         self.index = self.pinecone.Index(index_name)
+         self.namespace = namespace
+         self.embedding_generator = embedding_generator
+
+     def invoke(self, question: str, top_k: int = 5):
+         """Retrieve similar documents for a question.
+
+         Args:
+             question: Query string
+             top_k: Number of results to return
+
+         Returns:
+             List of dictionaries containing retrieved documents
+         """
+         # Generate an embedding for the question
+         question_embedding = self.embedding_generator.embed(question)
+         question_embedding = question_embedding.tolist()
+         # Query Pinecone
+         results = self.index.query(
+             namespace=self.namespace,
+             vector=question_embedding,
+             top_k=top_k,
+             include_values=False,
+             include_metadata=True
+         )
+
+         # Format results
+         retrieved_docs = [
+             {"page_content": match.metadata["text"], "score": match.score}
+             for match in results.matches
+         ]
+
+         return retrieved_docs
+
+ def ingest_data(
+     pc,
+     parquet_path: str,
+     text_column: str,
+     pinecone_client: Pinecone,
+     index_name="vector-index",
+     namespace="rag",
+     batch_size: int = 100
+ ):
+     """Ingest data from a Parquet file into Pinecone.
+
+     Args:
+         pc: Unused; retained for backward compatibility with existing callers
+         parquet_path: Path to the Parquet file
+         text_column: Name of the column containing text data
+         pinecone_client: Initialized Pinecone client
+         index_name: Name of the Pinecone index
+         namespace: Namespace in the index
+         batch_size: Batch size for processing
+     """
+     # Read the Parquet file
+     print(f"Reading Parquet file: {parquet_path}")
+     df = pd.read_parquet(parquet_path)
+     print(f"Total records: {len(df)}")
+     # Create the index if it does not exist yet
+     if not pinecone_client.has_index(index_name):
+         pinecone_client.create_index(
+             name=index_name,
+             dimension=1024,  # BGE-M3 dimension
+             metric="cosine",
+             spec=ServerlessSpec(
+                 cloud='aws',
+                 region='us-east-1'
+             )
+         )
+
+     # Wait for the index to be ready
+     while not pinecone_client.describe_index(index_name).status['ready']:
+         time.sleep(1)
+
+     index = pinecone_client.Index(index_name)
+
+     # Process in batches
+     for i in tqdm(range(0, len(df), batch_size)):
+         batch_df = df.iloc[i:i+batch_size]
+
+         # Generate embeddings for the batch
+         texts = batch_df[text_column].tolist()
+         embeddings = embedding_model.embed_batch(texts)
+         print(f"Generated embeddings for batch starting at row {i}")
+         # Prepare records for upsert
+         records = []
+         for idx, (_, row) in enumerate(batch_df.iterrows()):
+             records.append({
+                 "id": str(row.name),  # Using the DataFrame index as the ID
+                 "values": embeddings[idx],
+                 "metadata": {"text": row[text_column]}
+             })
+
+         # Upsert to Pinecone
+         index.upsert(vectors=records, namespace=namespace)
+
+         # Small delay to handle rate limits
+         time.sleep(0.5)
+
+ def get_retriever(
+     pinecone_client: Pinecone,
+     index_name="vector-index",
+     namespace="rag"
+ ):
+     """Create and return a PineconeRetriever instance.
+
+     Args:
+         pinecone_client: Initialized Pinecone client
+         index_name: Name of the Pinecone index
+         namespace: Namespace in the index
+
+     Returns:
+         Configured PineconeRetriever instance
+     """
+     return PineconeRetriever(
+         pinecone_client=pinecone_client,
+         index_name=index_name,
+         namespace=namespace,
+         embedding_generator=embedding_model
+     )
+
+ def main():
+     # Initialize Pinecone client
+     pc = Pinecone(api_key=PINECONE_API_KEY)
+
+     # Define input files
+     file_paths = [
+         # './data/2404.19756v1.pdf',
+         # './data/OD429347375590223100.pdf',
+         # './data/Project Report Format.docx',
+         './data/UNIT 2 GENDER BASED VIOLENCE.pptx'
+     ]
+
+     # Process pipeline
+     try:
+         # Step 1: Load and combine documents
+         # print("Loading documents...")
+         # markdown_path = load_documents(file_paths)
+
+         # # Step 2: Process into chunks with embeddings
+         # print("Processing chunks...")
+         # chunks = process_chunks(markdown_path)
+
+         # # Step 3: Save to Parquet
+         # print("Saving to Parquet...")
+         # parquet_path = save_to_parquet(chunks)
+
+         # # Step 4: Ingest into Pinecone
+         # print("Ingesting into Pinecone...")
+         # ingest_data(
+         #     pc,
+         #     parquet_path=parquet_path,
+         #     text_column="text",
+         #     pinecone_client=pc,
+         # )
+
+         # Step 5: Test retrieval
+         print("\nTesting retrieval...")
+         retriever = get_retriever(
+             pinecone_client=pc,
+             index_name="vector-index",
+             namespace="rag"
+         )
+
+         results = retriever.invoke(
+             question="describe the gender based violence",
+             top_k=5
+         )
+
+         for i, doc in enumerate(results, 1):
+             print(f"\nResult {i}:")
+             print(f"Content: {doc['page_content']}...")
+             print(f"Score: {doc['score']}")
+
+     except Exception as e:
+         print(f"Error in pipeline: {str(e)}")
+
+ if __name__ == "__main__":
+     main()
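
Taken together, this module carries the full flow that app.py drives: load documents, chunk, persist to Parquet, ingest into Pinecone, retrieve. A compact hedged sketch of that flow with the ingestion steps enabled, using only the functions defined above; the input path is a placeholder:

```python
import os
from pinecone import Pinecone
from src.vectorstore.pinecone_db import (
    load_documents, process_chunks, save_to_parquet, ingest_data, get_retriever
)

pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))

markdown_path = load_documents(["./data/example.pdf"])          # combine sources into markdown
parquet_path = save_to_parquet(process_chunks(markdown_path))   # semantic chunks -> Parquet
ingest_data(pc, parquet_path=parquet_path, text_column="text", pinecone_client=pc)

retriever = get_retriever(pinecone_client=pc)
for doc in retriever.invoke("What does the document say about X?", top_k=3):
    print(doc["score"], doc["page_content"][:120])
```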