Upload 19 files
- app.py +432 -0
- requirements.txt +24 -0
- src/__init__.py +0 -0
- src/agents/__init__.py +0 -0
- src/agents/research_agent.py +476 -0
- src/agents/router.py +59 -0
- src/agents/state.py +30 -0
- src/agents/workflow.py +274 -0
- src/data_processing/__init__.py +0 -0
- src/data_processing/chunker.py +85 -0
- src/data_processing/loader.py +153 -0
- src/llm/__init__.py +0 -0
- src/llm/graders.py +128 -0
- src/llm/query_rewriter.py +55 -0
- src/tools/__init__.py +0 -0
- src/tools/deep_crawler.py +325 -0
- src/tools/web_search.py +173 -0
- src/vectorstore/__init__.py +0 -0
- src/vectorstore/pinecone_db.py +282 -0
app.py
ADDED
@@ -0,0 +1,432 @@
```python
import streamlit as st
import asyncio
from src.vectorstore.pinecone_db import ingest_data, get_retriever, load_documents, process_chunks, save_to_parquet
from src.agents.research_agent import create_industry_research_workflow
from src.agents.workflow import run_adaptive_rag
from pinecone import Pinecone
from langchain_openai import ChatOpenAI
from langchain_ollama import ChatOllama
from langgraph.pregel import GraphRecursionError
import tempfile
import os
import time
from pathlib import Path

# Page configuration
st.set_page_config(
    page_title="Research & RAG Assistant",
    page_icon="🤖",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Custom CSS for better UI
st.markdown("""
    <style>
    .stTabs [data-baseweb="tab-list"] {
        gap: 24px;
    }
    .stTabs [data-baseweb="tab"] {
        padding: 8px 16px;
    }
    .config-section {
        background-color: #f0f2f6;
        border-radius: 10px;
        padding: 20px;
        margin: 10px 0;
    }
    .chat-container {
        border: 1px solid #e0e0e0;
        border-radius: 10px;
        padding: 20px;
        margin-top: 20px;
    }
    .stButton>button {
        width: 100%;
    }
    </style>
""", unsafe_allow_html=True)

# Initialize session state
if "messages" not in st.session_state:
    st.session_state.messages = []
if "documents_processed" not in st.session_state:
    st.session_state.documents_processed = False
if "retriever" not in st.session_state:
    st.session_state.retriever = None
if "pinecone_client" not in st.session_state:
    st.session_state.pinecone_client = None
if "research_config_saved" not in st.session_state:
    st.session_state.research_config_saved = False
if "rag_config_saved" not in st.session_state:
    st.session_state.rag_config_saved = False

def save_research_config(api_keys):
    """Save research configuration."""
    st.session_state.research_openai_key = api_keys['openai']
    st.session_state.research_tavily_key = api_keys['tavily']
    st.session_state.research_config_saved = True


def research_config_section():
    """Configuration section for the Company Research tab."""
    st.markdown("### ⚙️ Configuration")

    with st.expander("API Configuration", expanded=not st.session_state.research_config_saved):
        col1, col2 = st.columns(2)
        with col1:
            openai_key = st.text_input(
                "OpenAI API Key",
                type="password",
                value=st.session_state.get('research_openai_key', ''),
                key="research_openai_input"
            )
        with col2:
            tavily_key = st.text_input(
                "Tavily API Key",
                type="password",
                value=st.session_state.get('research_tavily_key', ''),
                key="research_tavily_input"
            )

        if st.button("Save Research Configuration", key="save_research_config"):
            if openai_key and tavily_key:
                save_research_config({
                    'openai': openai_key,
                    'tavily': tavily_key
                })
                if not os.environ.get("TAVILY_API_KEY"):
                    os.environ["TAVILY_API_KEY"] = tavily_key
                st.success("✅ Research configuration saved!")
            else:
                st.error("Please provide both API keys.")


async def run_industry_research(company: str, industry: str, llm):
    """Run the industry research workflow asynchronously."""
    workflow = create_industry_research_workflow(llm)

    output = await workflow.ainvoke(input={
        "company": company,
        "industry": industry
    }, config={"recursion_limit": 5})

    return output['final_report']


def research_input_section():
    """Input section for the Company Research tab."""
    st.markdown("### 🔍 Research Parameters")

    col1, col2 = st.columns(2)
    with col1:
        company_name = st.text_input(
            "Company Name",
            placeholder="e.g., Tesla",
            help="Enter the name of the company to research"
        )
    with col2:
        industry_type = st.text_input(
            "Industry Type",
            placeholder="e.g., Automotive",
            help="Enter the industry sector"
        )

    if st.button("Generate Research Report",
                 disabled=not st.session_state.research_config_saved,
                 type="primary"):
        if company_name and industry_type:
            with st.spinner("🔄 Generating comprehensive research report..."):
                try:
                    # Initialize LLM and run research
                    llm = ChatOpenAI(
                        model="gpt-3.5-turbo-0125",
                        temperature=0.1,
                        api_key=st.session_state.research_openai_key
                    )

                    report_path = asyncio.run(run_industry_research(
                        company=company_name,
                        industry=industry_type,
                        llm=llm
                    ))

                    # run_industry_research returns None when PDF generation fails
                    if report_path and os.path.exists(report_path):
                        with open(report_path, "rb") as file:
                            st.download_button(
                                "📥 Download Research Report",
                                data=file,
                                file_name=f"{company_name}_research_report.pdf",
                                mime="application/pdf"
                            )
                    else:
                        st.error("Report generation failed.")
                except Exception as e:
                    st.error(f"Error during report generation: {str(e)}")
        else:
            st.warning("Please fill in both company name and industry type.")

def initialize_pinecone(api_key):
    """Initialize the Pinecone client with an API key."""
    try:
        return Pinecone(api_key=api_key)
    except Exception as e:
        st.error(f"Error initializing Pinecone: {str(e)}")
        return None

def initialize_llm(llm_option, openai_api_key=None):
    """Initialize an LLM based on user selection."""
    if llm_option == "OpenAI":
        if not openai_api_key:
            st.sidebar.warning("Please enter OpenAI API key.")
            return None
        return ChatOpenAI(api_key=openai_api_key, model="gpt-3.5-turbo")

def clear_pinecone_index(pc, index_name="vector-index"):
    """Clear the Pinecone index."""
    try:
        pc.delete_index(index_name)
        st.session_state.documents_processed = False
        st.session_state.retriever = None
        st.success("Database cleared successfully!")
    except Exception as e:
        st.error(f"Error clearing database: {str(e)}")

def process_documents(uploaded_files, pc):
    """Process uploaded documents and store them in Pinecone."""
    if not uploaded_files:
        st.warning("Please upload at least one document.")
        return False

    with st.spinner("Processing documents..."):
        temp_dir = tempfile.mkdtemp()
        file_paths = []
        markdown_path = Path(temp_dir) / "combined.md"
        parquet_path = Path(temp_dir) / "documents.parquet"

        for uploaded_file in uploaded_files:
            file_path = Path(temp_dir) / uploaded_file.name
            with open(file_path, "wb") as f:
                f.write(uploaded_file.getvalue())
            file_paths.append(str(file_path))

        try:
            markdown_path = load_documents(file_paths, output_path=markdown_path)
            chunks = process_chunks(markdown_path, chunk_size=256, threshold=0.6)
            print(f"Processed chunks: {chunks}")
            parquet_path = save_to_parquet(chunks, parquet_path)

            ingest_data(
                pc=pc,
                parquet_path=parquet_path,
                text_column="text",
                pinecone_client=pc
            )

            st.session_state.retriever = get_retriever(pc)
            st.session_state.documents_processed = True

            return True

        except Exception as e:
            st.error(f"Error processing documents: {str(e)}")
            return False
        finally:
            # Best-effort cleanup of the temporary upload directory
            for file_path in file_paths:
                try:
                    os.remove(file_path)
                except OSError:
                    pass
            try:
                os.rmdir(temp_dir)
            except OSError:
                pass

def run_rag_with_streaming(retriever, question, llm, enable_web_search=False):
    """Run the RAG workflow and yield the answer word by word."""
    try:
        response = run_adaptive_rag(
            retriever=retriever,
            question=question,
            llm=llm,
            top_k=5,
            enable_websearch=enable_web_search
        )

        for word in response.split():
            yield word + " "
            time.sleep(0.03)

    except GraphRecursionError:
        response = ("I apologize, but I cannot find a sufficient answer to your question "
                    "in the provided documents. Please try rephrasing your question or ask "
                    "something else about the content of the documents.")
        for word in response.split():
            yield word + " "
            time.sleep(0.03)

    except Exception as e:
        yield f"I encountered an error while processing your question: {str(e)}"


def document_upload_section():
    """Document upload section for the RAG tab."""
    st.markdown("### 📄 Document Management")

    if not st.session_state.documents_processed:
        uploaded_files = st.file_uploader(
            "Upload your documents",
            accept_multiple_files=True,
            type=["pdf", "docx", "txt", "pptx", "md"],
            help="Supports multiple file uploads"
        )

        col1, col2 = st.columns([3, 1])
        with col1:
            if uploaded_files:
                st.info(f"📁 {len(uploaded_files)} files selected")
        with col2:
            if st.button(
                "Process Documents",
                disabled=not (uploaded_files and st.session_state.rag_config_saved)
            ):
                if process_documents(uploaded_files, st.session_state.pinecone_client):
                    st.success("✅ Documents processed successfully!")
    else:
        st.success("✅ Documents are loaded and ready for querying!")
        if st.button("Upload New Documents"):
            st.session_state.documents_processed = False
            st.rerun()

def save_rag_config(config):
    """Save RAG configuration."""
    st.session_state.rag_pinecone_key = config['pinecone']
    st.session_state.rag_openai_key = config['openai']
    st.session_state.rag_config_saved = True

def rag_config_section():
    """Configuration section for the RAG tab."""
    st.markdown("### ⚙️ Configuration")

    with st.expander("API Configuration", expanded=not st.session_state.rag_config_saved):
        col1, col2 = st.columns(2)
        with col1:
            pinecone_key = st.text_input(
                "Pinecone API Key",
                type="password",
                value=st.session_state.get('rag_pinecone_key', ''),
                key="rag_pinecone_input"
            )
        with col2:
            openai_key = st.text_input(
                "OpenAI API Key",
                type="password",
                value=st.session_state.get('rag_openai_key', ''),
                key="rag_openai_input"
            )

        if st.button("Save RAG Configuration", key="save_rag_config"):
            if pinecone_key and openai_key:
                save_rag_config({
                    'pinecone': pinecone_key,
                    'openai': openai_key
                })
                # Initialize Pinecone client
                st.session_state.pinecone_client = initialize_pinecone(pinecone_key)
                st.success("✅ RAG configuration saved!")
            else:
                st.error("Please provide both API keys.")

def chat_interface():
    """Chat interface with streaming responses and a web search toggle."""
    st.markdown("### 💬 Chat Interface")

    # Web search toggle in the chat interface
    col1, col2 = st.columns([3, 1])
    with col2:
        web_search = st.checkbox(
            "🌐 Enable Web Search",
            value=st.session_state.get('use_web_search', False),
            help="Toggle web search for additional context",
            key="web_search_toggle"
        )
        st.session_state.use_web_search = web_search

    # Chat container with message history
    chat_container = st.container()
    with chat_container:
        for message in st.session_state.messages:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])

    # Chat input
    if prompt := st.chat_input(
        "Ask a question about your documents...",
        disabled=not st.session_state.documents_processed,
        key="chat_input"
    ):
        # User message
        with st.chat_message("user"):
            if st.session_state.use_web_search:
                st.markdown(f"{prompt} 🌐")
            else:
                st.markdown(prompt)
        st.session_state.messages.append({"role": "user", "content": prompt})

        # Assistant response
        with st.chat_message("assistant"):
            response_container = st.empty()
            full_response = ""

            try:
                with st.spinner("Thinking..."):
                    llm = ChatOpenAI(
                        api_key=st.session_state.rag_openai_key,
                        model="gpt-3.5-turbo"
                    )

                    for chunk in run_rag_with_streaming(
                        retriever=st.session_state.retriever,
                        question=prompt,
                        llm=llm,
                        enable_web_search=st.session_state.use_web_search
                    ):
                        full_response += chunk
                        response_container.markdown(full_response + "▌")

                    response_container.markdown(full_response)
                    st.session_state.messages.append(
                        {"role": "assistant", "content": full_response}
                    )

            except Exception as e:
                st.error(f"Error: {str(e)}")

def main():
    """Main application layout."""
    st.title("🤖 Research & RAG Assistant")

    tab1, tab2 = st.tabs(["🔍 Company Research", "💬 Document Q&A"])

    with tab1:
        research_config_section()
        if st.session_state.research_config_saved:
            st.divider()
            research_input_section()
        else:
            st.info("👆 Please configure your API keys above to get started.")

    with tab2:
        rag_config_section()
        if st.session_state.rag_config_saved:
            st.divider()
            document_upload_section()
            if st.session_state.documents_processed:
                st.divider()
                chat_interface()
        else:
            st.info("👆 Please configure your API keys above to get started.")

if __name__ == "__main__":
    main()
```
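One detail worth calling out: the chat tab simulates token streaming. `run_rag_with_streaming` computes the full answer first, then yields it word by word, while `chat_interface` re-renders a single `st.empty()` placeholder on every chunk. A minimal sketch of that pattern, with a hard-coded answer standing in for the RAG call:

```python
import time
import streamlit as st

def fake_stream(answer: str):
    """Yield a pre-computed answer word by word, like run_rag_with_streaming."""
    for word in answer.split():
        yield word + " "
        time.sleep(0.03)

placeholder = st.empty()  # one slot, overwritten on every chunk
partial = ""
for chunk in fake_stream("The answer arrives in one shot but renders incrementally."):
    partial += chunk
    placeholder.markdown(partial + "▌")  # trailing cursor while "streaming"
placeholder.markdown(partial)  # final render without the cursor
```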
requirements.txt
ADDED
@@ -0,0 +1,24 @@
```
langchain-community
tiktoken
langchain-openai
langchainhub
chromadb
langchain
langgraph
duckduckgo-search
langchain-groq
langchain-huggingface
sentence_transformers
tavily-python
langchain-ollama
ollama
crawl4ai
docling
easyocr
FlagEmbedding
chonkie[semantic]
pinecone
streamlit
markdown2
xhtml2pdf
PyPDF2
```
src/__init__.py
ADDED
File without changes
src/agents/__init__.py
ADDED
File without changes
src/agents/research_agent.py
ADDED
@@ -0,0 +1,476 @@
```python
from langgraph.graph import END, StateGraph, START
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_openai import ChatOpenAI

import re
import asyncio
from typing import TypedDict, List, Optional, Dict
from src.tools.deep_crawler import DeepWebCrawler, ResourceCollectionAgent

class ResearchGraphState(TypedDict):
    company: str
    industry: str
    research_results: Optional[dict]
    use_cases: Optional[str]
    search_queries: Optional[Dict[str, List[str]]]
    resources: Optional[List[dict]]
    final_report: Optional[str]


def clean_text(text):
    """
    Clean the given text by:
    1. Removing all hyperlinks.
    2. Removing unnecessary parentheses and square brackets.

    Args:
        text (str): The input text to be cleaned.

    Returns:
        str: The cleaned text with hyperlinks, parentheses, and square brackets removed.
    """
    # Regular expression pattern for matching URLs
    url_pattern = r'https?://\S+|www\.\S+'
    # Remove hyperlinks
    text_without_links = re.sub(url_pattern, '', text)

    # Regular expression pattern for matching parentheses and square brackets
    brackets_pattern = r'[\[\]\(\)]'
    # Remove unnecessary brackets
    cleaned_text = re.sub(brackets_pattern, '', text_without_links)

    return cleaned_text.strip()


def create_industry_research_workflow(llm):
    async def industry_research(state: ResearchGraphState):
        """Research the industry and company using DeepWebCrawler."""
        company = state['company']
        industry = state['industry']

        queries = [
            f"{company} company profile services",
        ]

        crawler = DeepWebCrawler(
            max_search_results=3,
            max_external_links=1,
            word_count_threshold=100,
            content_filter_type='bm25',
            filter_threshold=0.48
        )

        all_results = []
        for query in queries:
            results = await crawler.search_and_crawl(query)
            all_results.extend(results)
        print(all_results)
        combined_content = "\n\n".join([
            f"Title: {clean_text(r['title'])} \n{clean_text(r['content'])}"
            for r in all_results if r['success']
        ])
        print("Combined Content: ", combined_content)
        prompt = PromptTemplate.from_template(
            """Analyze this research about {company} in the {industry} industry:
            {content}

            Provide a comprehensive overview including:
            1. Company Overview
            2. Market Segments
            3. Products and Services
            4. Strategic Focus Areas
            5. Industry Trends
            6. Competitive Position

            Format the analysis in clear sections with headers."""
        )

        chain = prompt | llm | StrOutputParser()
        analysis = chain.invoke({
            "company": company,
            "industry": industry,
            "content": combined_content
        })
        print("Analysis: ", analysis)
        return {
            "research_results": {
                "analysis": analysis,
                "raw_content": combined_content
            }
        }

    def generate_use_cases_and_queries(state: ResearchGraphState):
        """Generate AI/ML use cases and extract relevant search queries."""
        research_data = state['research_results']
        company = state['company']
        industry = state['industry']

        # First, generate use cases
        use_case_prompt = PromptTemplate.from_template(
            """Based on this research:

            Analysis: {analysis}
            Raw Research: {raw_content}

            Generate innovative use cases where {company} in the {industry} industry can leverage
            Generative AI and Large Language Models for:

            1. Internal Process Improvements
            2. Customer Experience Enhancement
            3. Product/Service Innovation
            4. Data Analytics and Decision Making

            For each use case, provide:
            - Clear description
            - Expected benefits
            - Implementation considerations"""
        )

        chain = use_case_prompt | llm | StrOutputParser()
        use_cases = chain.invoke({
            "company": company,
            "industry": industry,
            "analysis": research_data['analysis'],
            "raw_content": research_data['raw_content']
        })

        # Then, extract relevant search queries
        query_extraction_prompt = PromptTemplate.from_template(
            """Based on these AI/ML use cases for {company}:

            {use_cases}

            Extract two specific search queries for finding relevant datasets and implementations.

            Provide your response in this exact format:
            DATASET QUERIES:
            - query1
            - query2

            IMPLEMENTATION QUERIES:
            - query1
            - query2

            Make queries specific and technical. Include ML model types, data types, and specific AI techniques."""
        )

        chain = query_extraction_prompt | llm | StrOutputParser()
        queries_text = chain.invoke({
            "company": company,
            "use_cases": use_cases
        })

        # Parse the text response into a structured format
        def parse_queries(text):
            dataset_queries = []
            implementation_queries = []
            current_section = None

            for line in text.split('\n'):
                line = line.strip()
                if line == "DATASET QUERIES:":
                    current_section = "dataset"
                elif line == "IMPLEMENTATION QUERIES:":
                    current_section = "implementation"
                elif line.startswith("- "):
                    query = line[2:].strip()
                    if current_section == "dataset":
                        dataset_queries.append(query)
                    elif current_section == "implementation":
                        implementation_queries.append(query)

            # Fall back to generic defaults if the LLM output did not parse
            return {
                "dataset_queries": dataset_queries or ["machine learning datasets business", "AI training data industry"],
                "implementation_queries": implementation_queries or ["AI tools business automation", "machine learning implementation"]
            }

        search_queries = parse_queries(queries_text)
        print("Search_queries: ", search_queries)
        return {
            "use_cases": use_cases,
            "search_queries": search_queries
        }

    async def collect_targeted_resources(state: ResearchGraphState):
        """Find relevant datasets and resources using the extracted queries."""
        search_queries = state['search_queries']
        resource_agent = ResourceCollectionAgent(max_results_per_query=5)

        # Collect resources using targeted queries
        all_resources = {
            "datasets": [],
            "implementations": []
        }

        # Search for datasets
        for query in search_queries['dataset_queries']:
            # Platform-specific modifiers. NOTE: these are constructed but never
            # passed to collect_resources(), which runs its own default queries.
            kaggle_query = f"site:kaggle.com/datasets {query}"
            huggingface_query = f"site:huggingface.co/datasets {query}"

            resources = await resource_agent.collect_resources()

            # Process and categorize results
            if resources.get("kaggle_datasets"):
                all_resources["datasets"].extend([{
                    "title": item["title"],
                    "url": item["url"],
                    "description": item["snippet"],
                    "platform": "Kaggle",
                    "query": query
                } for item in resources["kaggle_datasets"]])

            if resources.get("huggingface_datasets"):
                all_resources["datasets"].extend([{
                    "title": item["title"],
                    "url": item["url"],
                    "description": item["snippet"],
                    "platform": "HuggingFace",
                    "query": query
                } for item in resources["huggingface_datasets"]])

        # Search for implementations
        for query in search_queries['implementation_queries']:
            # NOTE: same issue as above; github_query is currently unused.
            github_query = f"site:github.com {query}"

            resources = await resource_agent.collect_resources()

            if resources.get("github_repositories"):
                all_resources["implementations"].extend([{
                    "title": item["title"],
                    "url": item["url"],
                    "description": item["snippet"],
                    "platform": "GitHub",
                    "query": query
                } for item in resources["github_repositories"]])
        print("Resources: ", all_resources)
        return {"resources": all_resources}

    def generate_pdf_report(state: ResearchGraphState):
        """Generate the final PDF report with all collected information."""
        research_data = state['research_results']
        use_cases = state['use_cases']
        resources = state['resources']
        company = state['company']
        industry = state['industry']

        # Format resources to append manually after LLM generation
        datasets_section = "\n## Available Datasets\n"
        if resources.get('datasets'):
            for dataset in resources['datasets']:
                datasets_section += f" - {dataset['platform']}: {dataset['url']}\n"

        implementations_section = "\n## Implementation Resources\n"
        if resources.get('implementations'):
            for impl in resources['implementations']:
                implementations_section += f" - {impl['platform']}: {impl['url']}\n"

        prompt = PromptTemplate.from_template(
            """
            # GenAI & ML Implementation Proposal for {company}

            ## Executive Summary
            - **Current Position in the {industry} Industry**:
            - **Key Opportunities for AI/ML Implementation**:
            - **Expected Business Impact and ROI**:
            - **Implementation Timeline Overview**:

            ## Industry and Company Analysis
            {analysis}

            ## Strategic AI/ML Implementation Opportunities

            Based on the analysis, here are the key opportunities for AI/ML implementation:

            {use_cases}

            Format the report in Markdown with clear sections, headings, and bullet points. Ensure professional formatting with structured subsections.
            """
        )

        chain = prompt | llm | StrOutputParser()
        markdown_content = chain.invoke({
            "company": company,
            "industry": industry,
            "analysis": research_data['analysis'],
            "use_cases": use_cases,
        })

        # Strip a ```markdown fence if the LLM wrapped its output in one
        if markdown_content.startswith("```markdown") and markdown_content.endswith("```"):
            markdown_content = markdown_content[len("```markdown"):].rstrip("`").strip()

        markdown_content += "\n\n" + datasets_section + "\n\n" + implementations_section

        # Convert markdown to PDF (lazy imports: only needed for report generation)
        import tempfile
        import os
        import markdown2
        from xhtml2pdf import pisa

        # Create a temporary directory and full path for the PDF
        temp_dir = tempfile.mkdtemp()
        pdf_filename = f"{company.replace(' ', '_')}_research_report.pdf"
        pdf_path = os.path.join(temp_dir, pdf_filename)

        html_content = markdown2.markdown(markdown_content, extras=['tables', 'break-on-newline'])
        # HTML template with enhanced styles
        html_template = f"""
        <!DOCTYPE html>
        <html>
        <head>
            <meta charset="UTF-8">
            <style>
                @page {{
                    size: A4;
                    margin: 2.5cm;
                    @frame footer {{
                        -pdf-frame-content: footerContent;
                        bottom: 1cm;
                        margin-left: 1cm;
                        margin-right: 1cm;
                        height: 1cm;
                    }}
                }}
                body {{
                    font-family: Helvetica, Arial, sans-serif;
                    font-size: 11pt;
                    line-height: 1.6;
                    color: #2c3e50;
                }}
                h1 {{
                    font-size: 24pt;
                    color: #1a237e;
                    text-align: center;
                    margin-bottom: 2cm;
                    font-weight: bold;
                }}
                h2 {{
                    font-size: 18pt;
                    color: #283593;
                    margin-top: 1.5cm;
                    border-bottom: 2px solid #3949ab;
                    padding-bottom: 0.3cm;
                }}
                h3 {{
                    font-size: 14pt;
                    color: #3949ab;
                    margin-top: 1cm;
                }}
                h4 {{
                    font-size: 12pt;
                    color: #5c6bc0;
                    margin-top: 0.8cm;
                }}
                p {{
                    text-align: justify;
                    margin-bottom: 0.5cm;
                }}
                ul {{
                    margin-left: 0;
                    padding-left: 1cm;
                    margin-bottom: 0.5cm;
                }}
                li {{
                    margin-bottom: 0.3cm;
                }}
                a {{
                    color: #3f51b5;
                    text-decoration: none;
                }}
                strong {{
                    color: #283593;
                }}
                .use-case {{
                    background-color: #f5f7fa;
                    padding: 1cm;
                    margin: 0.5cm 0;
                    border-left: 4px solid #3949ab;
                }}
                .benefit {{
                    margin-left: 1cm;
                    color: #34495e;
                }}
            </style>
        </head>
        <body>
            {html_content}
            <div id="footerContent" style="text-align: center; font-size: 8pt; color: #7f8c8d;">
                Page <pdf:pagenumber> of <pdf:pagecount>
            </div>
        </body>
        </html>
        """

        # Convert HTML to PDF with proper error handling
        try:
            with open(pdf_path, "w+b") as pdf_file:
                result = pisa.CreatePDF(
                    html_template,
                    dest=pdf_file
                )
                if result.err:
                    print(f"Error generating PDF: {result.err}")
                    return {"final_report": None}

            # Verify the file exists
            if os.path.exists(pdf_path):
                print(f"PDF successfully generated at: {pdf_path}")
                return {"final_report": pdf_path}
            else:
                print("PDF file was not created successfully")
                return {"final_report": None}

        except Exception as e:
            print(f"Exception during PDF generation: {str(e)}")
            return {"final_report": None}

    # Create workflow
    workflow = StateGraph(ResearchGraphState)

    # Add nodes
    workflow.add_node("industry_research", industry_research)
    workflow.add_node("use_cases_gen", generate_use_cases_and_queries)
    workflow.add_node("resources_gen", collect_targeted_resources)
    workflow.add_node("report", generate_pdf_report)

    # Define edges
    workflow.add_edge(START, "industry_research")
    workflow.add_edge("industry_research", "use_cases_gen")
    workflow.add_edge("use_cases_gen", "resources_gen")
    workflow.add_edge("resources_gen", "report")
    workflow.add_edge("report", END)

    return workflow.compile()

async def run_industry_research(company: str, industry: str, llm):
    """Run the industry research workflow asynchronously."""
    workflow = create_industry_research_workflow(llm)

    output = await workflow.ainvoke(input={
        "company": company,
        "industry": industry
    }, config={"recursion_limit": 5})

    return output['final_report']

# Example usage
if __name__ == "__main__":
    async def main():
        # Initialize LLM
        llm = ChatOpenAI(
            model="gpt-3.5-turbo-0125",
            temperature=0.3,
            timeout=None,
            max_retries=2,
        )

        # Run the workflow
        report_path = await run_industry_research(
            company="Adani Defence & Aerospace",
            industry="Defense Engineering and Construction",
            llm=llm
        )
        print(f"Report generated at: {report_path}")

    asyncio.run(main())
```
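Note that `parse_queries` depends on the LLM reproducing the `DATASET QUERIES:` / `IMPLEMENTATION QUERIES:` headings and `- ` bullets exactly; any deviation falls through to the hard-coded default queries. An illustration of the expected contract (the sample response below is invented):

```python
sample_response = """DATASET QUERIES:
- retail demand forecasting time series dataset
- customer support ticket classification data

IMPLEMENTATION QUERIES:
- LLM customer support automation github
- transformer demand forecasting implementation"""

# parse_queries(sample_response) would yield:
# {"dataset_queries": ["retail demand forecasting time series dataset",
#                      "customer support ticket classification data"],
#  "implementation_queries": ["LLM customer support automation github",
#                             "transformer demand forecasting implementation"]}
```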
src/agents/router.py
ADDED
@@ -0,0 +1,59 @@
```python
from langchain_core.prompts import ChatPromptTemplate
from langchain_ollama import ChatOllama
from pydantic import BaseModel, Field
from typing import Literal

class RouteQuery(BaseModel):
    """Route a user query to the most relevant datasource."""
    datasource: Literal["vectorstore", "web_search"] = Field(
        description="Route question to web search or vectorstore retrieval"
    )

def create_query_router():
    """
    Create a query router that determines the data source for a given question.

    Returns:
        Callable: Query router chain
    """
    # LLM with structured output
    llm = ChatOllama(model="llama3.2", temperature=0.1, num_predict=256, top_p=0.5)
    structured_llm_router = llm.with_structured_output(RouteQuery)

    # Prompt
    system = """You are an expert at routing a user question to a vectorstore or web search.
    The vectorstore contains documents related to agents, prompt engineering, and adversarial attacks.
    Use the vectorstore for questions on these topics. Otherwise, use web search."""

    route_prompt = ChatPromptTemplate.from_messages([
        ("system", system),
        ("human", "{question}"),
    ])

    return route_prompt | structured_llm_router

def route_query(question: str):
    """
    Route a specific query to its appropriate data source.

    Args:
        question (str): User's input question

    Returns:
        str: Recommended data source
    """
    router = create_query_router()
    result = router.invoke({"question": question})
    return result.datasource

if __name__ == "__main__":
    # Example usage
    test_questions = [
        "Who will the Bears draft first in the NFL draft?",
        "What are the types of agent memory?"
    ]

    for q in test_questions:
        source = route_query(q)
        print(f"Question: {q}")
        print(f"Routed to: {source}\n")
```
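As shipped, `src/agents/workflow.py` (later in this commit) leaves this router commented out and routes on a boolean flag instead. If the LLM were to decide, the graph's routing function could be re-wired along these lines (a sketch, assuming the `GraphState` used in `workflow.py`):

```python
from src.agents.router import create_query_router

router = create_query_router()

def route_to_datasource(state) -> str:
    """Let the structured-output router pick the branch."""
    result = router.invoke({"question": state["question"]})
    # result.datasource is "vectorstore" or "web_search", which must match
    # the node names registered on the graph
    return result.datasource
```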
src/agents/state.py
ADDED
@@ -0,0 +1,30 @@
```python
from typing import List, TypedDict
from langchain_core.documents.base import Document

class GraphState(TypedDict):
    """
    Represents the state of our adaptive RAG graph.

    Attributes:
        question (str): Original user question
        generation (str, optional): LLM generated answer
        documents (List[Document], optional): Retrieved or searched documents
    """
    question: str
    generation: str | None
    documents: List[Document]

class ResearchGraphState(TypedDict):
    """
    Represents the state of the industry research graph.

    Attributes:
        company (str): Company being researched
        industry (str, optional): Industry sector
        research_results (str, optional): Collected research output
        use_cases (str, optional): Generated AI/ML use cases
        search_queries (str, optional): Extracted search queries
    """
    company: str
    industry: str | None
    research_results: str | None
    use_cases: str | None
    search_queries: str | None
```
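A note on how these TypedDicts behave at runtime: LangGraph nodes return partial updates, so a node only emits the keys it changed and the graph merges them into the running state. A minimal, self-contained sketch (with `documents` simplified to `List[str]`):

```python
from typing import List, TypedDict
from langgraph.graph import END, START, StateGraph

class DemoState(TypedDict):
    question: str
    generation: str | None
    documents: List[str]

def retrieve(state: DemoState):
    # Only the changed key is returned; LangGraph merges it into the state.
    return {"documents": [f"doc about {state['question']}"]}

def generate(state: DemoState):
    return {"generation": f"Answer built from {len(state['documents'])} document(s)."}

graph = StateGraph(DemoState)
graph.add_node("retrieve", retrieve)
graph.add_node("generate", generate)
graph.add_edge(START, "retrieve")
graph.add_edge("retrieve", "generate")
graph.add_edge("generate", END)

app = graph.compile()
print(app.invoke({"question": "agent memory", "generation": None, "documents": []}))
```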
src/agents/workflow.py
ADDED
@@ -0,0 +1,274 @@
```python
from langgraph.graph import END, StateGraph, START
from langchain_core.prompts import PromptTemplate
from src.agents.state import GraphState
# from src.agents.router import route_query
import asyncio
from src.vectorstore.pinecone_db import get_retriever
from src.tools.web_search import AdvancedWebCrawler
from src.llm.graders import (
    grade_document_relevance,
    check_hallucination,
    grade_answer_quality
)
from langchain_core.output_parsers import StrOutputParser
from src.llm.query_rewriter import rewrite_query
from langchain_ollama import ChatOllama

def perform_web_search(question: str):
    """
    Perform a web search using the AdvancedWebCrawler.

    Args:
        question (str): User's input question

    Returns:
        List: Web search results
    """
    # Initialize web crawler
    crawler = AdvancedWebCrawler(
        max_search_results=5,
        word_count_threshold=50,
        content_filter_type='f',
        filter_threshold=0.48
    )
    results = asyncio.run(crawler.search_and_crawl(question))

    return results


def create_adaptive_rag_workflow(retriever, llm, top_k=5, enable_websearch=False):
    """
    Create the adaptive RAG workflow graph.

    Args:
        retriever: Vector store retriever
        llm: Chat model used for generation and grading
        top_k (int): Number of documents to retrieve
        enable_websearch (bool): Route questions to web search instead of the vectorstore

    Returns:
        Compiled LangGraph workflow
    """
    def retrieve(state: GraphState):
        """Retrieve documents from the vectorstore."""
        print("---RETRIEVE---")
        question = state['question']
        documents = retriever.invoke(question, top_k)
        print(f"Retrieved {len(documents)} documents.")
        print(documents)
        return {"documents": documents, "question": question}

    def route_to_datasource(state: GraphState):
        """Route the question to web search or the vectorstore."""
        print("---ROUTE QUESTION---")
        # The LLM-based router is currently disabled; routing is controlled
        # by the enable_websearch flag instead:
        # source = route_query(state['question'])

        if enable_websearch:
            print("---ROUTE TO WEB SEARCH---")
            return "web_search"
        else:
            print("---ROUTE TO RAG---")
            return "vectorstore"

    def generate_answer(state: GraphState):
        """Generate an answer using the retrieved documents."""
        print("---GENERATE---")
        question = state['question']
        documents = state['documents']

        # Prepare context
        context = "\n\n".join([doc["page_content"] for doc in documents])
        prompt_template = PromptTemplate.from_template(
            """You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise.
            Question: {question}
            Context: {context}
            Answer:"""
        )
        # Generate answer
        rag_chain = prompt_template | llm | StrOutputParser()

        generation = rag_chain.invoke({"context": context, "question": question})

        return {"generation": generation, "documents": documents, "question": question}

    def grade_documents(state: GraphState):
        """Filter out irrelevant documents."""
        print("---GRADE DOCUMENTS---")
        question = state['question']
        documents = state['documents']

        # Keep only documents graded as relevant
        filtered_docs = []
        for doc in documents:
            score = grade_document_relevance(question, doc["page_content"], llm)
            if score == "yes":
                filtered_docs.append(doc)

        return {"documents": filtered_docs, "question": question}

    def web_search(state: GraphState):
        """Perform a web search."""
        print("---WEB SEARCH---")
        question = state['question']

        # Perform web search
        results = perform_web_search(question)
        web_documents = [
            {
                "page_content": result['content'],
                "metadata": {"source": result['url']}
            } for result in results
        ]

        return {"documents": web_documents, "question": question}

    def check_generation_quality(state: GraphState):
        """Check the quality of the generated answer."""
        print("---ASSESS GENERATION---")
        question = state['question']
        documents = state['documents']
        generation = state['generation']

        # NOTE: the hallucination grader (check_hallucination) is imported but
        # not invoked here; the generation is assumed to be grounded.
        print("---Generation is not hallucinated.---")
        # Check answer quality
        quality_score = grade_answer_quality(question, generation, llm)
        if quality_score == "yes":
            print("---Answer quality is good.---")
        else:
            print("---Answer quality is poor.---")
        return "end" if quality_score == "yes" else "rewrite"

    # Create workflow
    workflow = StateGraph(GraphState)

    # Add nodes
    workflow.add_node("vectorstore", retrieve)
    workflow.add_node("web_search", web_search)
    workflow.add_node("grade_documents", grade_documents)
    workflow.add_node("generate", generate_answer)
    workflow.add_node("rewrite_query", lambda state: {
        "question": rewrite_query(state['question'], llm),
        "documents": [],
        "generation": None
    })

    # Define edges
    workflow.add_conditional_edges(
        START,
        route_to_datasource,
        {
            "web_search": "web_search",
            "vectorstore": "vectorstore"
        }
    )

    workflow.add_edge("web_search", "generate")
    workflow.add_edge("vectorstore", "grade_documents")

    workflow.add_conditional_edges(
        "grade_documents",
        lambda state: "generate" if state['documents'] else "rewrite_query"
    )

    workflow.add_edge("rewrite_query", "vectorstore")

    workflow.add_conditional_edges(
        "generate",
        check_generation_quality,
        {
            "end": END,
            "regenerate": "generate",  # unreachable: the grader only returns "end" or "rewrite"
            "rewrite": "rewrite_query"
        }
    )

    # Compile the workflow
    app = workflow.compile()
    return app

def run_adaptive_rag(retriever, question: str, llm, top_k=5, enable_websearch=False):
    """
    Run the adaptive RAG workflow for a given question.

    Args:
        retriever: Vector store retriever
        question (str): User's input question
        llm: Chat model used for generation and grading
        top_k (int): Number of documents to retrieve
        enable_websearch (bool): Route the question to web search

    Returns:
        str: Generated answer
    """
    # Create workflow
    workflow = create_adaptive_rag_workflow(retriever, llm, top_k, enable_websearch=enable_websearch)

    # Run workflow, keeping the state emitted by the last node
    final_state = None
    for output in workflow.stream({"question": question}, config={"recursion_limit": 5}):
        for key, value in output.items():
            print(f"Node '{key}':")
            final_state = value

    return final_state.get('generation', 'No answer could be generated.')

if __name__ == "__main__":
    # Example usage
    from vectorstore.pinecone_db import PINECONE_API_KEY, ingest_data, get_retriever, load_documents, process_chunks, save_to_parquet
    from pinecone import Pinecone

    # Load and prepare documents
    pc = Pinecone(api_key=PINECONE_API_KEY)

    # Define input files
    file_paths = [
        # './data/2404.19756v1.pdf',
        # './data/OD429347375590223100.pdf',
        # './data/Project Report Format.docx',
        './data/UNIT 2 GENDER BASED VIOLENCE.pptx'
    ]

    # Processing pipeline
    try:
        # Step 1: Load and combine documents
        print("Loading documents...")
        markdown_path = load_documents(file_paths)

        # Step 2: Process into chunks with embeddings
        print("Processing chunks...")
        chunks = process_chunks(markdown_path)

        # Step 3: Save to Parquet
        print("Saving to Parquet...")
        parquet_path = save_to_parquet(chunks)

        # Step 4: Ingest into Pinecone
        print("Ingesting into Pinecone...")
        ingest_data(pc,
                    parquet_path=parquet_path,
                    text_column="text",
                    pinecone_client=pc,
                    )

        # Step 5: Test retrieval
        print("\nTesting retrieval...")
        retriever = get_retriever(
            pinecone_client=pc,
            index_name="vector-index",
            namespace="rag"
        )

    except Exception as e:
        print(f"Error in pipeline: {str(e)}")

    llm = ChatOllama(model="llama3.2", temperature=0.1, num_predict=256, top_p=0.5)

    # Test questions
    test_questions = [
        # "What are the key components of AI agent memory?",
        # "Explain prompt engineering techniques",
        # "What are recent advancements in adversarial attacks on LLMs?"
        "what are the trending papers that are published in NeurIPS 2024?"
    ]

    # Run the workflow for each test question
    for question in test_questions:
        print(f"\n--- Processing Question: {question} ---")
        answer = run_adaptive_rag(retriever, question, llm)
        print("\nFinal Answer:", answer)
```
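Because `run_adaptive_rag` streams with `recursion_limit: 5`, the grade → rewrite → retrieve loop can exhaust the limit when no relevant documents ever surface, and LangGraph raises `GraphRecursionError` (app.py catches this and streams an apology instead). A sketch of handling the same failure around a compiled workflow, using the same import path app.py uses:

```python
from langgraph.pregel import GraphRecursionError

def safe_run(compiled_workflow, question: str) -> str:
    """Invoke the graph, converting a recursion blow-up into a readable message."""
    try:
        state = compiled_workflow.invoke(
            {"question": question},
            config={"recursion_limit": 5},
        )
        return state.get("generation", "No answer could be generated.")
    except GraphRecursionError:
        return "Could not find a sufficient answer in the provided documents."
```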
src/data_processing/__init__.py
ADDED
File without changes
src/data_processing/chunker.py
ADDED
@@ -0,0 +1,85 @@
```python
from typing import List
import numpy as np
from chonkie.embeddings import BaseEmbeddings
from FlagEmbedding import BGEM3FlagModel
from chonkie import SDPMChunker

class BGEM3Embeddings(BaseEmbeddings):
    def __init__(self, model_name):
        self.model = BGEM3FlagModel(model_name, use_fp16=True)
        self.task = "separation"

    @property
    def dimension(self):
        return 1024

    def embed(self, text: str):
        return self.model.encode(
            [text], return_dense=True, return_sparse=False, return_colbert_vecs=False
        )['dense_vecs']

    def embed_batch(self, texts: List[str]):
        embeddings = self.model.encode(
            texts, return_dense=True, return_sparse=False, return_colbert_vecs=False
        )
        return embeddings['dense_vecs']

    def count_tokens(self, text: str):
        return len(self.model.tokenizer.encode(text))

    def count_tokens_batch(self, texts: List[str]):
        encodings = self.model.tokenizer(texts)
        return [len(enc) for enc in encodings["input_ids"]]

    def get_tokenizer_or_token_counter(self):
        return self.model.tokenizer

    def similarity(self, u: "np.ndarray", v: "np.ndarray"):
        """Compute cosine similarity between two embeddings."""
        return u @ v.T

    @classmethod
    def is_available(cls):
        return True

    def __repr__(self):
        return "bgem3"


def main():
    # Initialize the BGE M3 embeddings model
    embedding_model = BGEM3Embeddings(
        model_name="BAAI/bge-m3"
    )

    # Initialize the SDPM chunker
    chunker = SDPMChunker(
        embedding_model=embedding_model,
        chunk_size=256,
        threshold=0.7,
        skip_window=2
    )

    with open('./output.md', 'r') as file:
        text = file.read()

    # Generate chunks
    chunks = chunker.chunk(text)

    # Print the chunks
    for i, chunk in enumerate(chunks, 1):
        print(f"\nChunk {i}:")
        print(f"Text: {chunk.text}")
        print(f"Token count: {chunk.token_count}")
        print(f"Start index: {chunk.start_index}")
        print(f"End index: {chunk.end_index}")
        print(f"Number of sentences: {len(chunk.sentences)}")
        print("-" * 80)

if __name__ == "__main__":
    main()
```
src/data_processing/loader.py
ADDED
@@ -0,0 +1,153 @@
```python
from pathlib import Path
from typing import List, Union
import logging
from dataclasses import dataclass

from langchain_core.documents import Document as LCDocument
from langchain_core.document_loaders import BaseLoader
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling.datamodel.base_models import InputFormat, ConversionStatus
from docling.datamodel.pipeline_options import (
    PdfPipelineOptions,
    EasyOcrOptions
)

logging.basicConfig(level=logging.INFO)
_log = logging.getLogger(__name__)

@dataclass
class ProcessingResult:
    """Store results of document processing"""
    success_count: int = 0
    failure_count: int = 0
    partial_success_count: int = 0
    failed_files: List[str] = None

    def __post_init__(self):
        if self.failed_files is None:
            self.failed_files = []

class MultiFormatDocumentLoader(BaseLoader):
    """Loader for multiple document formats that converts to LangChain documents"""

    def __init__(
        self,
        file_paths: Union[str, List[str]],
        enable_ocr: bool = True,
        enable_tables: bool = True
    ):
        self._file_paths = [file_paths] if isinstance(file_paths, str) else file_paths
        self._enable_ocr = enable_ocr
        self._enable_tables = enable_tables
        self._converter = self._setup_converter()

    def _setup_converter(self):
        """Set up the document converter with appropriate options"""
        # Configure pipeline options
        pipeline_options = PdfPipelineOptions(
            do_ocr=False,
            do_table_structure=False,
            ocr_options=EasyOcrOptions(force_full_page_ocr=True)
        )
        if self._enable_ocr:
            pipeline_options.do_ocr = True
        if self._enable_tables:
            pipeline_options.do_table_structure = True
            pipeline_options.table_structure_options.do_cell_matching = True

        # Create converter with supported formats
        return DocumentConverter(
            allowed_formats=[
                InputFormat.PDF,
                InputFormat.IMAGE,
                InputFormat.DOCX,
                InputFormat.HTML,
                InputFormat.PPTX,
                InputFormat.ASCIIDOC,
                InputFormat.MD,
            ],
            format_options={
                InputFormat.PDF: PdfFormatOption(
                    pipeline_options=pipeline_options,
                )
            }
        )

    def lazy_load(self):
        """Convert documents and yield LangChain documents"""
        results = ProcessingResult()

        for file_path in self._file_paths:
            try:
                path = Path(file_path)
                if not path.exists():
                    _log.warning(f"File not found: {file_path}")
                    results.failure_count += 1
                    results.failed_files.append(file_path)
                    continue

                conversion_result = self._converter.convert(path)

                if conversion_result.status == ConversionStatus.SUCCESS:
                    results.success_count += 1
                    text = conversion_result.document.export_to_markdown()
                    metadata = {
                        'source': str(path),
                        'file_type': path.suffix,
                    }
                    yield LCDocument(
                        page_content=text,
                        metadata=metadata
                    )
                elif conversion_result.status == ConversionStatus.PARTIAL_SUCCESS:
                    results.partial_success_count += 1
                    _log.warning(f"Partial conversion for {file_path}")
                    text = conversion_result.document.export_to_markdown()
                    metadata = {
                        'source': str(path),
                        'file_type': path.suffix,
                        'conversion_status': 'partial'
                    }
                    yield LCDocument(
                        page_content=text,
                        metadata=metadata
                    )
                else:
                    results.failure_count += 1
                    results.failed_files.append(file_path)
                    _log.error(f"Failed to convert {file_path}")

            except Exception as e:
                _log.error(f"Error processing {file_path}: {str(e)}")
                results.failure_count += 1
                results.failed_files.append(file_path)

        # Log final results
        total = results.success_count + results.partial_success_count + results.failure_count
```
|
124 |
+
_log.info(
|
125 |
+
f"Processed {total} documents:\n"
|
126 |
+
f"- Successfully converted: {results.success_count}\n"
|
127 |
+
f"- Partially converted: {results.partial_success_count}\n"
|
128 |
+
f"- Failed: {results.failure_count}"
|
129 |
+
)
|
130 |
+
if results.failed_files:
|
131 |
+
_log.info("Failed files:")
|
132 |
+
for file in results.failed_files:
|
133 |
+
_log.info(f"- {file}")
|
134 |
+
|
135 |
+
|
136 |
+
if __name__ == '__main__':
|
137 |
+
# Load documents from a list of file paths
|
138 |
+
loader = MultiFormatDocumentLoader(
|
139 |
+
file_paths=[
|
140 |
+
# './data/2404.19756v1.pdf',
|
141 |
+
# './data/OD429347375590223100.pdf',
|
142 |
+
'./data/Project Report Format.docx',
|
143 |
+
# './data/UNIT 2 GENDER BASED VIOLENCE.pptx'
|
144 |
+
],
|
145 |
+
enable_ocr=False,
|
146 |
+
enable_tables=True
|
147 |
+
)
|
148 |
+
for doc in loader.lazy_load():
|
149 |
+
print(doc.page_content)
|
150 |
+
print(doc.metadata)
|
151 |
+
# save document in .md file
|
152 |
+
with open('output.md', 'w') as f:
|
153 |
+
f.write(doc.page_content)
|
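The __main__ block above runs with OCR off; scanned documents need the OCR path. A short sketch of OCR-enabled loading (the file name is hypothetical, and this assumes docling with its EasyOCR extra is installed):

from src.data_processing.loader import MultiFormatDocumentLoader

loader = MultiFormatDocumentLoader(
    file_paths=["./data/scanned_report.pdf"],  # hypothetical input file
    enable_ocr=True,      # run EasyOCR over full pages
    enable_tables=True,   # reconstruct table structure with cell matching
)
for doc in loader.lazy_load():  # documents are converted one at a time
    print(doc.metadata["source"], len(doc.page_content), "chars")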
src/llm/__init__.py
ADDED
File without changes
|
src/llm/graders.py
ADDED
@@ -0,0 +1,128 @@
|
1 |
+
from langchain_core.prompts import ChatPromptTemplate
|
2 |
+
from langchain_ollama import ChatOllama
|
3 |
+
from pydantic import BaseModel, Field
|
4 |
+
from typing import List
|
5 |
+
|
6 |
+
class DocumentRelevance(BaseModel):
|
7 |
+
"""Binary score for relevance check on retrieved documents."""
|
8 |
+
binary_score: str = Field(
|
9 |
+
description="Documents are relevant to the question, 'yes' or 'no'"
|
10 |
+
)
|
11 |
+
|
12 |
+
class HallucinationCheck(BaseModel):
|
13 |
+
"""Binary score for hallucination present in generation answer."""
|
14 |
+
binary_score: str = Field(
|
15 |
+
description="Answer is grounded in the facts, 'yes' or 'no'"
|
16 |
+
)
|
17 |
+
|
18 |
+
class AnswerQuality(BaseModel):
|
19 |
+
"""Binary score to assess answer addresses question."""
|
20 |
+
binary_score: str = Field(
|
21 |
+
description="Answer addresses the question, 'yes' or 'no'"
|
22 |
+
)
|
23 |
+
|
24 |
+
def create_llm_grader(grader_type: str, llm):
|
25 |
+
"""
|
26 |
+
Create an LLM grader based on the specified type.
|
27 |
+
|
28 |
+
Args:
|
29 |
+
grader_type (str): Type of grader to create
|
30 |
+
|
31 |
+
Returns:
|
32 |
+
Callable: LLM grader function
|
33 |
+
"""
|
34 |
+
# The LLM instance is supplied by the caller
|
35 |
+
|
36 |
+
# Select grader type and create structured output
|
37 |
+
if grader_type == "document_relevance":
|
38 |
+
structured_llm_grader = llm.with_structured_output(DocumentRelevance)
|
39 |
+
system = """You are a grader assessing relevance of a retrieved document to a user question.
|
40 |
+
If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant.
|
41 |
+
It does not need to be a stringent test. The goal is to filter out erroneous retrievals.
|
42 |
+
Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question."""
|
43 |
+
|
44 |
+
prompt = ChatPromptTemplate.from_messages([
|
45 |
+
("system", system),
|
46 |
+
("human", "Retrieved document: \n\n {document} \n\n User question: {question}"),
|
47 |
+
])
|
48 |
+
|
49 |
+
elif grader_type == "hallucination":
|
50 |
+
structured_llm_grader = llm.with_structured_output(HallucinationCheck)
|
51 |
+
system = """You are a grader assessing whether an LLM generation is grounded in / supported by a set of retrieved facts.
|
52 |
+
Give a binary score 'yes' or 'no'. 'Yes' means that the answer is grounded in / supported by the set of facts."""
|
53 |
+
|
54 |
+
prompt = ChatPromptTemplate.from_messages([
|
55 |
+
("system", system),
|
56 |
+
("human", "Set of facts: \n\n {documents} \n\n LLM generation: {generation}"),
|
57 |
+
])
|
58 |
+
|
59 |
+
elif grader_type == "answer_quality":
|
60 |
+
structured_llm_grader = llm.with_structured_output(AnswerQuality)
|
61 |
+
system = """You are a grader assessing whether an answer addresses / resolves a question.
|
62 |
+
Give a binary score 'yes' or 'no'. 'Yes' means that the answer resolves the question."""
|
63 |
+
|
64 |
+
prompt = ChatPromptTemplate.from_messages([
|
65 |
+
("system", system),
|
66 |
+
("human", "User question: \n\n {question} \n\n LLM generation: {generation}"),
|
67 |
+
])
|
68 |
+
|
69 |
+
else:
|
70 |
+
raise ValueError(f"Unknown grader type: {grader_type}")
|
71 |
+
|
72 |
+
return prompt | structured_llm_grader
|
73 |
+
|
74 |
+
def grade_document_relevance(question: str, document: str, llm):
|
75 |
+
"""
|
76 |
+
Grade the relevance of a document to a given question.
|
77 |
+
|
78 |
+
Args:
|
79 |
+
question (str): User's question
|
80 |
+
document (str): Retrieved document content
|
81 |
+
|
82 |
+
Returns:
|
83 |
+
str: Binary score ('yes' or 'no')
|
84 |
+
"""
|
85 |
+
grader = create_llm_grader("document_relevance", llm)
|
86 |
+
result = grader.invoke({"question": question, "document": document})
|
87 |
+
return result.binary_score
|
88 |
+
|
89 |
+
def check_hallucination(documents: List[str], generation: str, llm):
|
90 |
+
"""
|
91 |
+
Check if the generation is grounded in the provided documents.
|
92 |
+
|
93 |
+
Args:
|
94 |
+
documents (List[str]): List of source documents
|
95 |
+
generation (str): LLM generated answer
|
96 |
+
|
97 |
+
Returns:
|
98 |
+
str: Binary score ('yes' or 'no')
|
99 |
+
"""
|
100 |
+
grader = create_llm_grader("hallucination", llm)
|
101 |
+
result = grader.invoke({"documents": documents, "generation": generation})
|
102 |
+
return result.binary_score
|
103 |
+
|
104 |
+
def grade_answer_quality(question: str, generation: str, llm):
|
105 |
+
"""
|
106 |
+
Grade the quality of the answer in addressing the question.
|
107 |
+
|
108 |
+
Args:
|
109 |
+
question (str): User's original question
|
110 |
+
generation (str): LLM generated answer
|
111 |
+
|
112 |
+
Returns:
|
113 |
+
str: Binary score ('yes' or 'no')
|
114 |
+
"""
|
115 |
+
grader = create_llm_grader("answer_quality", llm)
|
116 |
+
result = grader.invoke({"question": question, "generation": generation})
|
117 |
+
return result.binary_score
|
118 |
+
|
119 |
+
if __name__ == "__main__":
|
120 |
+
# Example usage
|
121 |
+
test_question = "What are the types of agent memory?"
|
122 |
+
test_document = "Agent memory can be classified into different types such as episodic, semantic, and working memory."
|
123 |
+
test_generation = "Agent memory includes episodic memory for storing experiences, semantic memory for general knowledge, and working memory for immediate processing."
|
124 |
+
llm = ChatOllama(model="llama3.2", temperature=0.1, num_predict=256, top_p=0.5)
|
125 |
+
|
126 |
+
print("Document Relevance:", grade_document_relevance(test_question, test_document, llm))
|
127 |
+
print("Hallucination Check:", check_hallucination([test_document], test_generation, llm))
|
128 |
+
print("Answer Quality:", grade_answer_quality(test_question, test_generation, llm))
|
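Used together, the three graders form the self-check loop of the adaptive RAG workflow. A sketch of that composition (assuming a local llama3.2 model is available through Ollama, as in the __main__ block above; docs and answer are placeholders):

from langchain_ollama import ChatOllama
from src.llm.graders import (
    grade_document_relevance, check_hallucination, grade_answer_quality,
)

llm = ChatOllama(model="llama3.2", temperature=0.1)
question, docs, answer = "What is agent memory?", ["..."], "..."  # placeholders

# Drop retrieved documents the grader marks irrelevant before generating.
relevant = [d for d in docs if grade_document_relevance(question, d, llm) == "yes"]

# Accept the generation only if it is grounded in the documents and on-topic.
accepted = (check_hallucination(relevant, answer, llm) == "yes"
            and grade_answer_quality(question, answer, llm) == "yes")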
src/llm/query_rewriter.py
ADDED
@@ -0,0 +1,55 @@
|
1 |
+
from langchain_core.prompts.chat import ChatPromptTemplate
|
2 |
+
from langchain_ollama import ChatOllama
|
3 |
+
from langchain_core.output_parsers import StrOutputParser
|
4 |
+
|
5 |
+
def create_query_rewriter(llm):
|
6 |
+
"""
|
7 |
+
Create a query rewriter to optimize retrieval.
|
8 |
+
|
9 |
+
Returns:
|
10 |
+
Callable: Query rewriter function
|
11 |
+
"""
|
12 |
+
|
13 |
+
# Prompt for query rewriting
|
14 |
+
system = """You are a question re-writer that converts an input question to a better version that is optimized
|
15 |
+
for vectorstore retrieval. Look at the input and try to reason about the underlying semantic intent / meaning."""
|
16 |
+
|
17 |
+
re_write_prompt = ChatPromptTemplate.from_messages([
|
18 |
+
("system", system),
|
19 |
+
("human", "Here is the initial question: \n\n {question} \n Formulate an improved question."),
|
20 |
+
])
|
21 |
+
|
22 |
+
# Create query rewriter chain
|
23 |
+
return re_write_prompt | llm | StrOutputParser()
|
24 |
+
|
25 |
+
def rewrite_query(question: str, llm):
|
26 |
+
"""
|
27 |
+
Rewrite a given query to optimize retrieval.
|
28 |
+
|
29 |
+
Args:
|
30 |
+
question (str): Original user question
|
31 |
+
|
32 |
+
Returns:
|
33 |
+
str: Rewritten query
|
34 |
+
"""
|
35 |
+
query_rewriter = create_query_rewriter(llm)
|
36 |
+
try:
|
37 |
+
rewritten_query = query_rewriter.invoke({"question": question})
|
38 |
+
return rewritten_query
|
39 |
+
except Exception as e:
|
40 |
+
print(f"Query rewriting error: {e}")
|
41 |
+
return question
|
42 |
+
|
43 |
+
if __name__ == "__main__":
|
44 |
+
# Example usage
|
45 |
+
test_queries = [
|
46 |
+
"Tell me about AI agents",
|
47 |
+
"What do we know about memory in AI systems?",
|
48 |
+
"Bears draft strategy"
|
49 |
+
]
|
50 |
+
llm = ChatOllama(model="llama3.2", temperature=0.1, num_predict=256, top_p=0.5)
|
51 |
+
|
52 |
+
for query in test_queries:
|
53 |
+
rewritten = rewrite_query(query, llm)
|
54 |
+
print(f"Original: {query}")
|
55 |
+
print(f"Rewritten: {rewritten}\n")
|
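The rewriter earns its keep when retrieval comes back empty. A sketch of a retry-on-miss loop (the retriever argument is a stand-in for any object with an invoke(question) method, such as the PineconeRetriever defined later in this commit):

from langchain_ollama import ChatOllama
from src.llm.query_rewriter import rewrite_query

def retrieve_with_rewrite(question, retriever, llm, max_rewrites=2):
    """Retry retrieval with a rewritten query whenever nothing comes back."""
    query = question
    for _ in range(max_rewrites + 1):
        docs = retriever.invoke(query)
        if docs:
            return docs, query
        query = rewrite_query(query, llm)  # falls back to the input on error
    return [], query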
src/tools/__init__.py
ADDED
File without changes
|
src/tools/deep_crawler.py
ADDED
@@ -0,0 +1,325 @@
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import asyncio
|
4 |
+
from typing import List, Dict, Optional, Set
|
5 |
+
from urllib.parse import urlparse
|
6 |
+
from langchain_community.tools import DuckDuckGoSearchResults, TavilySearchResults
|
7 |
+
from langchain_community.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
|
8 |
+
from crawl4ai import AsyncWebCrawler, CacheMode
|
9 |
+
from crawl4ai.content_filter_strategy import PruningContentFilter, BM25ContentFilter
|
10 |
+
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
|
11 |
+
from dotenv import load_dotenv
|
12 |
+
|
13 |
+
load_dotenv()
|
14 |
+
|
15 |
+
class DeepWebCrawler:
|
16 |
+
def __init__(self,
|
17 |
+
max_search_results: int = 5,
|
18 |
+
max_external_links: int = 3,
|
19 |
+
word_count_threshold: int = 50,
|
20 |
+
content_filter_type: str = 'pruning',
|
21 |
+
filter_threshold: float = 0.48):
|
22 |
+
"""
|
23 |
+
Initialize the Deep Web Crawler with support for one-level deep crawling
|
24 |
+
|
25 |
+
Args:
|
26 |
+
max_search_results (int): Maximum number of search results to process
|
27 |
+
max_external_links (int): Maximum number of external links to crawl per page
|
28 |
+
word_count_threshold (int): Minimum word count for crawled content
|
29 |
+
content_filter_type (str): Type of content filter ('pruning' or 'bm25')
|
30 |
+
filter_threshold (float): Threshold for content filtering
|
31 |
+
"""
|
32 |
+
self.max_search_results = max_search_results
|
33 |
+
self.max_external_links = max_external_links
|
34 |
+
self.word_count_threshold = word_count_threshold
|
35 |
+
self.content_filter_type = content_filter_type
|
36 |
+
self.filter_threshold = filter_threshold
|
37 |
+
self.crawled_urls: Set[str] = set()
|
38 |
+
|
39 |
+
def _create_web_search_tool(self):
|
40 |
+
return TavilySearchResults(max_results=self.max_search_results)
|
41 |
+
|
42 |
+
def _create_content_filter(self, user_query: Optional[str] = None):
|
43 |
+
if self.content_filter_type == 'bm25' and user_query:
|
44 |
+
return BM25ContentFilter(
|
45 |
+
user_query=user_query,
|
46 |
+
bm25_threshold=self.filter_threshold
|
47 |
+
)
|
48 |
+
return PruningContentFilter(
|
49 |
+
threshold=self.filter_threshold,
|
50 |
+
threshold_type="fixed",
|
51 |
+
min_word_threshold=self.word_count_threshold
|
52 |
+
)
|
53 |
+
|
54 |
+
def _extract_links_from_search_results(self, results: List[Dict]) -> List[str]:
|
55 |
+
"""Safely extract URLs from search results"""
|
56 |
+
urls = []
|
57 |
+
for result in results:
|
58 |
+
if isinstance(result, dict) and 'url' in result:
|
59 |
+
urls.append(result['url'])
|
60 |
+
elif isinstance(result, str):
|
61 |
+
urls.append(result)
|
62 |
+
return urls
|
63 |
+
|
64 |
+
def _extract_url_from_link(self, link):
|
65 |
+
"""Extract URL string from link object which might be a dict or string"""
|
66 |
+
if isinstance(link, dict):
|
67 |
+
return link.get('url', '') # Assuming the URL is stored in a 'url' key
|
68 |
+
elif isinstance(link, str):
|
69 |
+
return link
|
70 |
+
return ''
|
71 |
+
|
72 |
+
def _process_crawl_result(self, result) -> Dict:
|
73 |
+
"""Process individual crawl result into structured format"""
|
74 |
+
return {
|
75 |
+
"url": result.url,
|
76 |
+
"success": result.success,
|
77 |
+
"title": result.metadata.get('title', 'N/A'),
|
78 |
+
"content": result.markdown_v2.raw_markdown if result.success else result.error_message,
|
79 |
+
"word_count": len(result.markdown_v2.raw_markdown.split()) if result.success else 0,
|
80 |
+
"links": {
|
81 |
+
"internal": result.links.get('internal', []),
|
82 |
+
"external": result.links.get('external', [])
|
83 |
+
},
|
84 |
+
"images": len(result.media.get('images', []))
|
85 |
+
}
|
86 |
+
|
87 |
+
async def crawl_urls(self, urls: List[str], user_query: Optional[str] = None, depth: int = 0):
|
88 |
+
"""
|
89 |
+
Crawl URLs with support for external link crawling
|
90 |
+
|
91 |
+
Args:
|
92 |
+
urls (List[str]): List of URLs to crawl
|
93 |
+
user_query (Optional[str]): Query for content filtering
|
94 |
+
depth (int): Current crawl depth (0 for initial, 1 for external links)
|
95 |
+
|
96 |
+
Returns:
|
97 |
+
List of crawl results including external link content
|
98 |
+
"""
|
99 |
+
if not urls or depth > 1:
|
100 |
+
return []
|
101 |
+
|
102 |
+
# Filter out already crawled URLs
|
103 |
+
new_urls = [url for url in urls if url not in self.crawled_urls]
|
104 |
+
if not new_urls:
|
105 |
+
return []
|
106 |
+
|
107 |
+
async with AsyncWebCrawler(
|
108 |
+
browser_type="chromium",
|
109 |
+
headless=True,
|
110 |
+
verbose=True
|
111 |
+
) as crawler:
|
112 |
+
content_filter = self._create_content_filter(user_query)
|
113 |
+
|
114 |
+
results = await crawler.arun_many(
|
115 |
+
urls=new_urls,
|
116 |
+
word_count_threshold=self.word_count_threshold,
|
117 |
+
cache_mode=CacheMode.BYPASS,
|
118 |
+
markdown_generator=DefaultMarkdownGenerator(content_filter=content_filter),
|
119 |
+
exclude_external_links=True,
|
120 |
+
exclude_social_media_links=True,
|
121 |
+
remove_overlay_elements=True,
|
122 |
+
simulate_user=True,
|
123 |
+
magic=True
|
124 |
+
)
|
125 |
+
|
126 |
+
processed_results = []
|
127 |
+
external_urls = set()
|
128 |
+
|
129 |
+
# Process results and collect external URLs
|
130 |
+
for result in results:
|
131 |
+
self.crawled_urls.add(result.url)
|
132 |
+
processed_result = self._process_crawl_result(result)
|
133 |
+
processed_results.append(processed_result)
|
134 |
+
|
135 |
+
if depth == 0 and result.success:
|
136 |
+
# Collect unique external URLs for further crawling
|
137 |
+
external_links = result.links.get('external', [])[:self.max_external_links]
|
138 |
+
external_urls.update(
|
139 |
+
self._extract_url_from_link(link)
|
140 |
+
for link in external_links
|
141 |
+
if self._extract_url_from_link(link)
|
142 |
+
and self._extract_url_from_link(link) not in self.crawled_urls
|
143 |
+
)
|
144 |
+
|
145 |
+
# Crawl collected external links one level deep
|
146 |
+
if depth == 0 and external_urls:
|
147 |
+
external_results = await self.crawl_urls(
|
148 |
+
list(external_urls),
|
149 |
+
user_query=user_query,
|
150 |
+
depth=1  # mark as an external-link pass so the depth guard stops recursion
|
151 |
+
)
|
152 |
+
processed_results.extend(external_results)
|
153 |
+
|
154 |
+
return processed_results
|
155 |
+
|
156 |
+
async def search_and_crawl(self, query: str) -> List[Dict]:
|
157 |
+
"""
|
158 |
+
Perform web search and deep crawl of results
|
159 |
+
|
160 |
+
Args:
|
161 |
+
query (str): Search query
|
162 |
+
|
163 |
+
Returns:
|
164 |
+
List of crawled content results including external links
|
165 |
+
"""
|
166 |
+
|
167 |
+
search_tool = self._create_web_search_tool()
|
168 |
+
search_results = search_tool.invoke(query)
|
169 |
+
|
170 |
+
# Handle different types of search results
|
171 |
+
if isinstance(search_results, str):
|
172 |
+
urls = [search_results]
|
173 |
+
elif isinstance(search_results, list):
|
174 |
+
urls = self._extract_links_from_search_results(search_results)
|
175 |
+
else:
|
176 |
+
print(f"Unexpected search results format: {type(search_results)}")
|
177 |
+
return []
|
178 |
+
|
179 |
+
if not urls:
|
180 |
+
print("No valid URLs found in search results")
|
181 |
+
return []
|
182 |
+
|
183 |
+
print(f"Initial search found {len(urls)} URLs for query: {query}")
|
184 |
+
print(urls)
|
185 |
+
crawl_results = await self.crawl_urls(urls, user_query=query)
|
186 |
+
|
187 |
+
return crawl_results
|
188 |
+
|
189 |
+
|
190 |
+
class ResourceCollectionAgent:
|
191 |
+
def __init__(self, max_results_per_query: int = 10):
|
192 |
+
"""
|
193 |
+
Initialize the Resource Collection Agent
|
194 |
+
|
195 |
+
Args:
|
196 |
+
max_results_per_query (int): Maximum number of results per search query
|
197 |
+
"""
|
198 |
+
self.max_results_per_query = max_results_per_query
|
199 |
+
self.search_tool = TavilySearchResults(max_results=max_results_per_query)
|
200 |
+
|
201 |
+
def _is_valid_domain(self, url: str, valid_domains: List[str]) -> bool:
|
202 |
+
"""Check if URL belongs to allowed domains"""
|
203 |
+
try:
|
204 |
+
domain = urlparse(url).netloc.lower()
|
205 |
+
return any(valid_domain in domain for valid_domain in valid_domains)
|
206 |
+
except Exception:
|
207 |
+
return False
|
208 |
+
|
209 |
+
def _extract_search_result(self, result) -> Optional[Dict]:
|
210 |
+
"""Safely extract information from a search result"""
|
211 |
+
try:
|
212 |
+
if isinstance(result, dict):
|
213 |
+
return {
|
214 |
+
"title": result.get("title", "No title"),
|
215 |
+
"url": result.get("url", ""),
|
216 |
+
"snippet": result.get("snippet", "No description")
|
217 |
+
}
|
218 |
+
elif isinstance(result, str):
|
219 |
+
return {
|
220 |
+
"title": "Unknown",
|
221 |
+
"url": result,
|
222 |
+
"snippet": "No description available"
|
223 |
+
}
|
224 |
+
return None
|
225 |
+
except Exception as e:
|
226 |
+
print(f"Error processing search result: {str(e)}")
|
227 |
+
return None
|
228 |
+
|
229 |
+
async def collect_resources(self) -> Dict[str, List[Dict]]:
|
230 |
+
"""
|
231 |
+
Collect AI/ML resources from specific platforms
|
232 |
+
|
233 |
+
Returns:
|
234 |
+
Dictionary with categorized resource links
|
235 |
+
"""
|
236 |
+
search_queries = {
|
237 |
+
"datasets": [
|
238 |
+
("kaggle", "site:kaggle.com/datasets machine learning"),
|
239 |
+
("huggingface", "site:huggingface.co/datasets artificial intelligence")
|
240 |
+
],
|
241 |
+
"repositories": [
|
242 |
+
("github", "site:github.com AI tools repository")
|
243 |
+
]
|
244 |
+
}
|
245 |
+
|
246 |
+
results = {
|
247 |
+
"kaggle_datasets": [],
|
248 |
+
"huggingface_datasets": [],
|
249 |
+
"github_repositories": []
|
250 |
+
}
|
251 |
+
|
252 |
+
for category, queries in search_queries.items():
|
253 |
+
for platform, query in queries:
|
254 |
+
try:
|
255 |
+
search_results = self.search_tool.invoke(query)
|
256 |
+
|
257 |
+
# Handle different result formats
|
258 |
+
if isinstance(search_results, str):
|
259 |
+
search_results = [search_results]
|
260 |
+
elif not isinstance(search_results, list):
|
261 |
+
print(f"Unexpected search results format for {platform}: {type(search_results)}")
|
262 |
+
continue
|
263 |
+
|
264 |
+
# Filter results based on domain
|
265 |
+
valid_domains = {
|
266 |
+
"kaggle": ["kaggle.com"],
|
267 |
+
"huggingface": ["huggingface.co"],
|
268 |
+
"github": ["github.com"]
|
269 |
+
}
|
270 |
+
|
271 |
+
for result in search_results:
|
272 |
+
processed_result = self._extract_search_result(result)
|
273 |
+
if processed_result and self._is_valid_domain(
|
274 |
+
processed_result["url"],
|
275 |
+
valid_domains[platform]
|
276 |
+
):
|
277 |
+
if platform == "kaggle":
|
278 |
+
results["kaggle_datasets"].append(processed_result)
|
279 |
+
elif platform == "huggingface":
|
280 |
+
results["huggingface_datasets"].append(processed_result)
|
281 |
+
elif platform == "github":
|
282 |
+
results["github_repositories"].append(processed_result)
|
283 |
+
|
284 |
+
except Exception as e:
|
285 |
+
print(f"Error collecting {platform} resources: {str(e)}")
|
286 |
+
continue
|
287 |
+
|
288 |
+
return results
|
289 |
+
|
290 |
+
def main():
|
291 |
+
async def run_examples():
|
292 |
+
# Test DeepWebCrawler
|
293 |
+
deep_crawler = DeepWebCrawler(
|
294 |
+
max_search_results=3,
|
295 |
+
max_external_links=2,
|
296 |
+
word_count_threshold=50
|
297 |
+
)
|
298 |
+
|
299 |
+
crawl_results = await deep_crawler.search_and_crawl(
|
300 |
+
"Adani Defence & Aerospace"
|
301 |
+
)
|
302 |
+
|
303 |
+
print("\nDeep Crawler Results:")
|
304 |
+
for result in crawl_results:
|
305 |
+
print(f"URL: {result['url']}")
|
306 |
+
print(f"Title: {result['title']}")
|
307 |
+
print(f"Word Count: {result['word_count']}")
|
308 |
+
print(f"External Links: {len(result['links']['external'])}\n")
|
309 |
+
|
310 |
+
# Test ResourceCollectionAgent
|
311 |
+
resource_agent = ResourceCollectionAgent(max_results_per_query=5)
|
312 |
+
resources = await resource_agent.collect_resources()
|
313 |
+
|
314 |
+
print("\nResource Collection Results:")
|
315 |
+
for category, items in resources.items():
|
316 |
+
print(f"\n{category.upper()}:")
|
317 |
+
for item in items:
|
318 |
+
print(f"Title: {item['title']}")
|
319 |
+
print(f"URL: {item['url']}")
|
320 |
+
print("---")
|
321 |
+
|
322 |
+
asyncio.run(run_examples())
|
323 |
+
|
324 |
+
if __name__ == "__main__":
|
325 |
+
main()
|
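The content_filter_type switch is the main tuning knob: 'pruning' keeps dense blocks of any page, while 'bm25' keeps only blocks that score against the query. A sketch of a query-focused crawl (assuming TAVILY_API_KEY is set in the environment, as the search tool requires; the threshold value is illustrative):

import asyncio
from src.tools.deep_crawler import DeepWebCrawler

crawler = DeepWebCrawler(
    max_search_results=3,
    content_filter_type="bm25",  # keep only page blocks that score against the query
    filter_threshold=1.0,        # illustrative BM25 cut-off
)
results = asyncio.run(crawler.search_and_crawl("BGE-M3 embedding model"))
for r in results:
    print(r["url"], r["word_count"])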
src/tools/web_search.py
ADDED
@@ -0,0 +1,173 @@
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import asyncio
|
4 |
+
from typing import List, Dict, Optional
|
5 |
+
|
6 |
+
from langchain_community.tools import DuckDuckGoSearchResults
|
7 |
+
from crawl4ai import AsyncWebCrawler, CacheMode
|
8 |
+
from crawl4ai.content_filter_strategy import PruningContentFilter, BM25ContentFilter
|
9 |
+
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
|
10 |
+
from dotenv import load_dotenv
|
11 |
+
|
12 |
+
load_dotenv()
|
13 |
+
|
14 |
+
class AdvancedWebCrawler:
|
15 |
+
def __init__(self,
|
16 |
+
max_search_results: int = 5,
|
17 |
+
word_count_threshold: int = 50,
|
18 |
+
content_filter_type: str = 'pruning',
|
19 |
+
filter_threshold: float = 0.48):
|
20 |
+
"""
|
21 |
+
Initialize the Advanced Web Crawler
|
22 |
+
|
23 |
+
Args:
|
24 |
+
max_search_results (int): Maximum number of search results to process
|
25 |
+
word_count_threshold (int): Minimum word count for crawled content
|
26 |
+
content_filter_type (str): Type of content filter ('pruning' or 'bm25')
|
27 |
+
filter_threshold (float): Threshold for content filtering
|
28 |
+
"""
|
29 |
+
self.max_search_results = max_search_results
|
30 |
+
self.word_count_threshold = word_count_threshold
|
31 |
+
self.content_filter_type = content_filter_type
|
32 |
+
self.filter_threshold = filter_threshold
|
33 |
+
|
34 |
+
def _create_web_search_tool(self):
|
35 |
+
"""
|
36 |
+
Create a web search tool using DuckDuckGo
|
37 |
+
|
38 |
+
Returns:
|
39 |
+
DuckDuckGoSearchResults: Web search tool
|
40 |
+
"""
|
41 |
+
return DuckDuckGoSearchResults(max_results=self.max_search_results, output_format="list")
|
42 |
+
|
43 |
+
def _create_content_filter(self, user_query: Optional[str] = None):
|
44 |
+
"""
|
45 |
+
Create content filter based on specified type
|
46 |
+
|
47 |
+
Args:
|
48 |
+
user_query (Optional[str]): Query to use for BM25 filtering
|
49 |
+
|
50 |
+
Returns:
|
51 |
+
Content filter strategy
|
52 |
+
"""
|
53 |
+
if self.content_filter_type == 'bm25' and user_query:
|
54 |
+
return BM25ContentFilter(
|
55 |
+
user_query=user_query,
|
56 |
+
bm25_threshold=self.filter_threshold
|
57 |
+
)
|
58 |
+
else:
|
59 |
+
return PruningContentFilter(
|
60 |
+
threshold=self.filter_threshold,
|
61 |
+
threshold_type="fixed",
|
62 |
+
min_word_threshold=self.word_count_threshold
|
63 |
+
)
|
64 |
+
|
65 |
+
async def crawl_urls(self, urls: List[str], user_query: Optional[str] = None):
|
66 |
+
"""
|
67 |
+
Crawl multiple URLs with content filtering
|
68 |
+
|
69 |
+
Args:
|
70 |
+
urls (List[str]): List of URLs to crawl
|
71 |
+
user_query (Optional[str]): Query used for BM25 content filtering
|
72 |
+
|
73 |
+
Returns:
|
74 |
+
List of crawl results
|
75 |
+
"""
|
76 |
+
async with AsyncWebCrawler(
|
77 |
+
browser_type="chromium",
|
78 |
+
headless=True,
|
79 |
+
verbose=True
|
80 |
+
) as crawler:
|
81 |
+
# Create appropriate content filter
|
82 |
+
content_filter = self._create_content_filter(user_query)
|
83 |
+
|
84 |
+
# Run crawling for multiple URLs
|
85 |
+
results = await crawler.arun_many(
|
86 |
+
urls=urls,
|
87 |
+
word_count_threshold=self.word_count_threshold,
|
88 |
+
# caching is already disabled via cache_mode below; the deprecated bypass_cache flag is omitted
|
89 |
+
markdown_generator=DefaultMarkdownGenerator(
|
90 |
+
content_filter=content_filter
|
91 |
+
),
|
92 |
+
cache_mode=CacheMode.DISABLED,
|
93 |
+
exclude_external_links=True,
|
94 |
+
remove_overlay_elements=True,
|
95 |
+
simulate_user=True,
|
96 |
+
magic=True
|
97 |
+
)
|
98 |
+
|
99 |
+
# Process and return crawl results
|
100 |
+
processed_results = []
|
101 |
+
for result in results:
|
102 |
+
crawl_result = {
|
103 |
+
"url": result.url,
|
104 |
+
"success": result.success,
|
105 |
+
"title": result.metadata.get('title', 'N/A'),
|
106 |
+
"content": result.markdown_v2.raw_markdown if result.success else result.error_message,
|
107 |
+
"word_count": len(result.markdown_v2.raw_markdown.split()) if result.success else 0,
|
108 |
+
"links": {
|
109 |
+
"internal": len(result.links.get('internal', [])),
|
110 |
+
"external": len(result.links.get('external', []))
|
111 |
+
},
|
112 |
+
"images": len(result.media.get('images', []))
|
113 |
+
}
|
114 |
+
processed_results.append(crawl_result)
|
115 |
+
|
116 |
+
return processed_results
|
117 |
+
|
118 |
+
async def search_and_crawl(self, query: str) -> List[Dict]:
|
119 |
+
"""
|
120 |
+
Perform web search and crawl the results
|
121 |
+
|
122 |
+
Args:
|
123 |
+
query (str): Search query
|
124 |
+
|
125 |
+
Returns:
|
126 |
+
List of crawled content results
|
127 |
+
"""
|
128 |
+
# Perform web search
|
129 |
+
search_tool = self._create_web_search_tool()
|
130 |
+
try:
|
131 |
+
search_results = search_tool.invoke({"query": query})
|
132 |
+
|
133 |
+
# Extract URLs from search results
|
134 |
+
urls = [result['link'] for result in search_results]
|
135 |
+
print(f"Found {len(urls)} URLs for query: {query}")
|
136 |
+
|
137 |
+
# Crawl URLs
|
138 |
+
crawl_results = await self.crawl_urls(urls, user_query=query)
|
139 |
+
|
140 |
+
return crawl_results
|
141 |
+
|
142 |
+
except Exception as e:
|
143 |
+
print(f"Web search and crawl error: {e}")
|
144 |
+
return []
|
145 |
+
|
146 |
+
def main():
|
147 |
+
# Example usage
|
148 |
+
crawler = AdvancedWebCrawler(
|
149 |
+
max_search_results=5,
|
150 |
+
word_count_threshold=50,
|
151 |
+
content_filter_type='pruning',
|
152 |
+
filter_threshold=0.48
|
153 |
+
)
|
154 |
+
|
155 |
+
test_queries = [
|
156 |
+
"Latest developments in AI agents",
|
157 |
+
"Today's weather forecast in Kolkata",
|
158 |
+
]
|
159 |
+
|
160 |
+
for query in test_queries:
|
161 |
+
# Run search and crawl asynchronously
|
162 |
+
results = asyncio.run(crawler.search_and_crawl(query))
|
163 |
+
|
164 |
+
print(f"\nResults for query: {query}")
|
165 |
+
for result in results:
|
166 |
+
print(f"URL: {result['url']}")
|
167 |
+
print(f"Success: {result['success']}")
|
168 |
+
print(f"Title: {result['title']}")
|
169 |
+
print(f"Word Count: {result['word_count']}")
|
170 |
+
print(f"Content Preview: {result['content'][:500]}...\n")
|
171 |
+
|
172 |
+
if __name__ == "__main__":
|
173 |
+
main()
|
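Search can also be skipped entirely when the target pages are already known, since crawl_urls accepts URLs directly (the URL below is hypothetical):

import asyncio
from src.tools.web_search import AdvancedWebCrawler

crawler = AdvancedWebCrawler(content_filter_type="pruning")
results = asyncio.run(crawler.crawl_urls(
    ["https://example.com/blog/ai-agents"],  # hypothetical URL
))
for r in results:
    print(r["title"], r["word_count"])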
src/vectorstore/__init__.py
ADDED
File without changes
|
src/vectorstore/pinecone_db.py
ADDED
@@ -0,0 +1,282 @@
|
1 |
+
from src.data_processing.loader import MultiFormatDocumentLoader
|
2 |
+
from src.data_processing.chunker import SDPMChunker, BGEM3Embeddings
|
3 |
+
|
4 |
+
import pandas as pd
|
5 |
+
from typing import List, Dict, Any
|
6 |
+
from pinecone import Pinecone, ServerlessSpec
|
7 |
+
import time
|
8 |
+
from tqdm import tqdm
|
9 |
+
from dotenv import load_dotenv
|
10 |
+
import os
|
11 |
+
|
12 |
+
|
13 |
+
load_dotenv()
|
14 |
+
|
15 |
+
# API Keys
|
16 |
+
PINECONE_API_KEY = os.getenv('PINECONE_API_KEY')
|
17 |
+
|
18 |
+
embedding_model = BGEM3Embeddings(model_name="BAAI/bge-m3")
|
19 |
+
|
20 |
+
|
21 |
+
def load_documents(file_paths: List[str], output_path='./data/output.md'):
|
22 |
+
"""
|
23 |
+
Load documents from multiple sources and combine them into a single markdown file
|
24 |
+
"""
|
25 |
+
loader = MultiFormatDocumentLoader(
|
26 |
+
file_paths=file_paths,
|
27 |
+
enable_ocr=False,
|
28 |
+
enable_tables=True
|
29 |
+
)
|
30 |
+
|
31 |
+
# Append all documents to the markdown file
|
32 |
+
with open(output_path, 'w') as f:
|
33 |
+
for doc in loader.lazy_load():
|
34 |
+
# Add metadata as YAML frontmatter
|
35 |
+
f.write('---\n')
|
36 |
+
for key, value in doc.metadata.items():
|
37 |
+
f.write(f'{key}: {value}\n')
|
38 |
+
f.write('---\n\n')
|
39 |
+
f.write(doc.page_content)
|
40 |
+
f.write('\n\n')
|
41 |
+
|
42 |
+
return output_path
|
43 |
+
|
44 |
+
def process_chunks(markdown_path: str, chunk_size: int = 256,
|
45 |
+
threshold: float = 0.7, skip_window: int = 2):
|
46 |
+
"""
|
47 |
+
Process the markdown file into chunks and prepare for vector storage
|
48 |
+
"""
|
49 |
+
chunker = SDPMChunker(
|
50 |
+
embedding_model=embedding_model,
|
51 |
+
chunk_size=chunk_size,
|
52 |
+
threshold=threshold,
|
53 |
+
skip_window=skip_window
|
54 |
+
)
|
55 |
+
|
56 |
+
# Read the markdown file
|
57 |
+
with open(markdown_path, 'r') as file:
|
58 |
+
text = file.read()
|
59 |
+
|
60 |
+
# Generate chunks
|
61 |
+
chunks = chunker.chunk(text)
|
62 |
+
|
63 |
+
# Prepare data for Parquet
|
64 |
+
processed_chunks = []
|
65 |
+
for chunk in chunks:
|
66 |
+
|
67 |
+
processed_chunks.append({
|
68 |
+
'text': chunk.text,
|
69 |
+
'token_count': chunk.token_count,
|
70 |
+
'start_index': chunk.start_index,
|
71 |
+
'end_index': chunk.end_index,
|
72 |
+
'num_sentences': len(chunk.sentences),
|
73 |
+
})
|
74 |
+
|
75 |
+
return processed_chunks
|
76 |
+
|
77 |
+
def save_to_parquet(chunks: List[Dict[str, Any]], output_path='./data/chunks.parquet'):
|
78 |
+
"""
|
79 |
+
Save processed chunks to a Parquet file
|
80 |
+
"""
|
81 |
+
df = pd.DataFrame(chunks)
|
82 |
+
print(f"Saving to Parquet: {output_path}")
|
83 |
+
df.to_parquet(output_path)
|
84 |
+
print(f"Saved to Parquet: {output_path}")
|
85 |
+
return output_path
|
86 |
+
|
87 |
+
|
88 |
+
class PineconeRetriever:
|
89 |
+
def __init__(
|
90 |
+
self,
|
91 |
+
pinecone_client: Pinecone,
|
92 |
+
index_name: str,
|
93 |
+
namespace: str,
|
94 |
+
embedding_generator: BGEM3Embeddings
|
95 |
+
):
|
96 |
+
"""Initialize the retriever with Pinecone client and embedding generator.
|
97 |
+
|
98 |
+
Args:
|
99 |
+
pinecone_client: Initialized Pinecone client
|
100 |
+
index_name: Name of the Pinecone index
|
101 |
+
namespace: Namespace in the index
|
102 |
+
embedding_generator: BGEM3Embeddings instance
|
103 |
+
"""
|
104 |
+
self.pinecone = pinecone_client
|
105 |
+
self.index = self.pinecone.Index(index_name)
|
106 |
+
self.namespace = namespace
|
107 |
+
self.embedding_generator = embedding_generator
|
108 |
+
|
109 |
+
def invoke(self, question: str, top_k: int = 5):
|
110 |
+
"""Retrieve similar documents for a question.
|
111 |
+
|
112 |
+
Args:
|
113 |
+
question: Query string
|
114 |
+
top_k: Number of results to return
|
115 |
+
|
116 |
+
Returns:
|
117 |
+
List of dictionaries containing retrieved documents
|
118 |
+
"""
|
119 |
+
# Generate embedding for the question
|
120 |
+
question_embedding = self.embedding_generator.embed(question)
|
121 |
+
question_embedding = question_embedding.tolist()
|
122 |
+
# Query Pinecone
|
123 |
+
results = self.index.query(
|
124 |
+
namespace=self.namespace,
|
125 |
+
vector=question_embedding,
|
126 |
+
top_k=top_k,
|
127 |
+
include_values=False,
|
128 |
+
include_metadata=True
|
129 |
+
)
|
130 |
+
|
131 |
+
# Format results
|
132 |
+
retrieved_docs = [
|
133 |
+
{"page_content": match.metadata["text"], "score": match.score}
|
134 |
+
for match in results.matches
|
135 |
+
]
|
136 |
+
|
137 |
+
return retrieved_docs
|
138 |
+
|
139 |
+
def ingest_data(
|
140 |
+
pc,  # unused duplicate of pinecone_client; callers currently pass the client twice
|
141 |
+
parquet_path: str,
|
142 |
+
text_column: str,
|
143 |
+
pinecone_client: Pinecone,
|
144 |
+
index_name= "vector-index",
|
145 |
+
namespace= "rag",
|
146 |
+
batch_size: int = 100
|
147 |
+
):
|
148 |
+
"""Ingest data from a Parquet file into Pinecone.
|
149 |
+
|
150 |
+
Args:
|
151 |
+
parquet_path: Path to the Parquet file
|
152 |
+
text_column: Name of the column containing text data
|
153 |
+
pinecone_client: Initialized Pinecone client
|
154 |
+
index_name: Name of the Pinecone index
|
155 |
+
namespace: Namespace in the index
|
156 |
+
batch_size: Batch size for processing
|
157 |
+
"""
|
158 |
+
# Read Parquet file
|
159 |
+
print(f"Reading Parquet file: {parquet_path}")
|
160 |
+
df = pd.read_parquet(parquet_path)
|
161 |
+
print(f"Total records: {len(df)}")
|
162 |
+
# Create or get index
|
163 |
+
if not pinecone_client.has_index(index_name):
|
164 |
+
pinecone_client.create_index(
|
165 |
+
name=index_name,
|
166 |
+
dimension=1024, # BGE-M3 dimension
|
167 |
+
metric="cosine",
|
168 |
+
spec=ServerlessSpec(
|
169 |
+
cloud='aws',
|
170 |
+
region='us-east-1'
|
171 |
+
)
|
172 |
+
)
|
173 |
+
|
174 |
+
# Wait for index to be ready
|
175 |
+
while not pinecone_client.describe_index(index_name).status['ready']:
|
176 |
+
time.sleep(1)
|
177 |
+
|
178 |
+
index = pinecone_client.Index(index_name)
|
179 |
+
|
180 |
+
# Process in batches
|
181 |
+
for i in tqdm(range(0, len(df), batch_size)):
|
182 |
+
batch_df = df.iloc[i:i+batch_size]
|
183 |
+
|
184 |
+
# Generate embeddings for batch
|
185 |
+
texts = batch_df[text_column].tolist()
|
186 |
+
embeddings = embedding_model.embed_batch(texts)
|
187 |
+
print(f"embeddings for batch: {i}")
|
188 |
+
# Prepare records for upsert
|
189 |
+
records = []
|
190 |
+
for idx, (_, row) in enumerate(batch_df.iterrows()):
|
191 |
+
records.append({
|
192 |
+
"id": str(row.name), # Using DataFrame index as ID
|
193 |
+
"values": embeddings[idx],
|
194 |
+
"metadata": {"text": row[text_column]}
|
195 |
+
})
|
196 |
+
|
197 |
+
# Upsert to Pinecone
|
198 |
+
index.upsert(vectors=records, namespace=namespace)
|
199 |
+
|
200 |
+
# Small delay to handle rate limits
|
201 |
+
time.sleep(0.5)
|
202 |
+
|
203 |
+
def get_retriever(
|
204 |
+
pinecone_client: Pinecone,
|
205 |
+
index_name= "vector-index",
|
206 |
+
namespace= "rag"
|
207 |
+
):
|
208 |
+
"""Create and return a PineconeRetriever instance.
|
209 |
+
|
210 |
+
Args:
|
211 |
+
pinecone_client: Initialized Pinecone client
|
212 |
+
index_name: Name of the Pinecone index
|
213 |
+
namespace: Namespace in the index
|
214 |
+
|
215 |
+
Returns:
|
216 |
+
Configured PineconeRetriever instance
|
217 |
+
"""
|
218 |
+
return PineconeRetriever(
|
219 |
+
pinecone_client=pinecone_client,
|
220 |
+
index_name=index_name,
|
221 |
+
namespace=namespace,
|
222 |
+
embedding_generator=embedding_model
|
223 |
+
)
|
224 |
+
|
225 |
+
def main():
|
226 |
+
# Initialize Pinecone client
|
227 |
+
pc = Pinecone(api_key=PINECONE_API_KEY)
|
228 |
+
|
229 |
+
# Define input files
|
230 |
+
file_paths=[
|
231 |
+
# './data/2404.19756v1.pdf',
|
232 |
+
# './data/OD429347375590223100.pdf',
|
233 |
+
# './data/Project Report Format.docx',
|
234 |
+
'./data/UNIT 2 GENDER BASED VIOLENCE.pptx'
|
235 |
+
]
|
236 |
+
|
237 |
+
# Process pipeline
|
238 |
+
try:
|
239 |
+
# Step 1: Load and combine documents
|
240 |
+
# print("Loading documents...")
|
241 |
+
# markdown_path = load_documents(file_paths)
|
242 |
+
|
243 |
+
# # Step 2: Process into chunks with embeddings
|
244 |
+
# print("Processing chunks...")
|
245 |
+
# chunks = process_chunks(markdown_path)
|
246 |
+
|
247 |
+
# # Step 3: Save to Parquet
|
248 |
+
# print("Saving to Parquet...")
|
249 |
+
# parquet_path = save_to_parquet(chunks)
|
250 |
+
|
251 |
+
# # Step 4: Ingest into Pinecone
|
252 |
+
# print("Ingesting into Pinecone...")
|
253 |
+
# ingest_data(
|
254 |
+
# pc,
|
255 |
+
# parquet_path=parquet_path,
|
256 |
+
# text_column="text",
|
257 |
+
# pinecone_client=pc,
|
258 |
+
# )
|
259 |
+
|
260 |
+
# Step 5: Test retrieval
|
261 |
+
print("\nTesting retrieval...")
|
262 |
+
retriever = get_retriever(
|
263 |
+
pinecone_client=pc,
|
264 |
+
index_name="vector-index",
|
265 |
+
namespace="rag"
|
266 |
+
)
|
267 |
+
|
268 |
+
results = retriever.invoke(
|
269 |
+
question="describe the gender based violence",
|
270 |
+
top_k=5
|
271 |
+
)
|
272 |
+
|
273 |
+
for i, doc in enumerate(results, 1):
|
274 |
+
print(f"\nResult {i}:")
|
275 |
+
print(f"Content: {doc['page_content']}...")
|
276 |
+
print(f"Score: {doc['score']}")
|
277 |
+
|
278 |
+
except Exception as e:
|
279 |
+
print(f"Error in pipeline: {str(e)}")
|
280 |
+
|
281 |
+
if __name__ == "__main__":
|
282 |
+
main()
|
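End to end, the module reduces to four calls. A condensed sketch of the full ingest-and-query pipeline (assuming PINECONE_API_KEY is set and the input path exists; the file name is illustrative):

import os
from pinecone import Pinecone
from src.vectorstore.pinecone_db import (
    load_documents, process_chunks, save_to_parquet, ingest_data, get_retriever,
)

pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
md_path = load_documents(["./data/report.pdf"])     # 1. combine sources into one markdown file
parquet = save_to_parquet(process_chunks(md_path))  # 2./3. chunk semantically and persist
ingest_data(pc, parquet_path=parquet, text_column="text", pinecone_client=pc)  # 4. upsert vectors
docs = get_retriever(pc).invoke("summarise the report", top_k=3)
for d in docs:
    print(d["score"], d["page_content"][:120])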