Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -1,8 +1,8 @@
|
|
1 |
-
#
|
2 |
# Imports & Initial Configuration
|
3 |
-
#
|
4 |
import streamlit as st
|
5 |
-
#
|
6 |
st.set_page_config(page_title="NeuroResearch AI", layout="wide", initial_sidebar_state="expanded")
|
7 |
|
8 |
from langchain_openai import OpenAIEmbeddings
|
@@ -15,27 +15,24 @@ from langgraph.graph.message import add_messages
|
|
15 |
from typing_extensions import TypedDict, Annotated
|
16 |
from typing import Sequence, Dict, List, Optional, Any
|
17 |
import chromadb
|
18 |
-
import re
|
19 |
import os
|
20 |
import requests
|
21 |
import hashlib
|
22 |
-
import json
|
23 |
import time
|
24 |
-
from langchain.tools.retriever import create_retriever_tool
|
25 |
from concurrent.futures import ThreadPoolExecutor, as_completed
|
26 |
from datetime import datetime
|
27 |
|
28 |
-
#
|
29 |
# State Schema Definition
|
30 |
-
#
|
31 |
class AgentState(TypedDict):
|
32 |
messages: Annotated[Sequence[AIMessage | HumanMessage | ToolMessage], add_messages]
|
33 |
context: Dict[str, Any]
|
34 |
metadata: Dict[str, Any]
|
35 |
|
36 |
-
#
|
37 |
# Configuration
|
38 |
-
#
|
39 |
class ResearchConfig:
|
40 |
DEEPSEEK_API_KEY = os.environ.get("DEEPSEEK_API_KEY")
|
41 |
CHROMA_PATH = "chroma_db"
|
@@ -72,9 +69,9 @@ if not ResearchConfig.DEEPSEEK_API_KEY:
|
|
72 |
3. Rebuild deployment""")
|
73 |
st.stop()
|
74 |
|
75 |
-
#
|
76 |
# Quantum Document Processing
|
77 |
-
#
|
78 |
class QuantumDocumentManager:
|
79 |
def __init__(self):
|
80 |
self.client = chromadb.PersistentClient(path=ResearchConfig.CHROMA_PATH)
|
@@ -90,8 +87,7 @@ class QuantumDocumentManager:
|
|
90 |
separators=["\n\n", "\n", "|||"]
|
91 |
)
|
92 |
docs = splitter.create_documents(documents)
|
93 |
-
#
|
94 |
-
st.write(f"Created {len(docs)} chunks for collection '{collection_name}'")
|
95 |
return Chroma.from_documents(
|
96 |
documents=docs,
|
97 |
embedding=self.embeddings,
|
@@ -101,6 +97,7 @@ class QuantumDocumentManager:
|
|
101 |
)
|
102 |
|
103 |
def _document_id(self, content: str) -> str:
|
|
|
104 |
return f"{hashlib.sha256(content.encode()).hexdigest()[:16]}-{int(time.time())}"
|
105 |
|
106 |
# Initialize document collections
|
@@ -117,9 +114,9 @@ development_docs = qdm.create_collection([
|
|
117 |
"Product Y: In the Performance Optimization Stage Before Release"
|
118 |
], "development")
|
119 |
|
120 |
-
#
|
121 |
# Advanced Retrieval System
|
122 |
-
#
|
123 |
class ResearchRetriever:
|
124 |
def __init__(self):
|
125 |
self.retrievers = {
|
@@ -138,9 +135,9 @@ class ResearchRetriever:
|
|
138 |
}
|
139 |
|
140 |
def retrieve(self, query: str, domain: str) -> List[Any]:
|
|
|
141 |
try:
|
142 |
results = self.retrievers[domain].invoke(query)
|
143 |
-
st.write(f"[DEBUG] Retrieved {len(results)} documents for query: '{query}' in domain '{domain}'")
|
144 |
return results
|
145 |
except KeyError:
|
146 |
st.error(f"[ERROR] Retrieval domain '{domain}' not found.")
|
@@ -148,21 +145,19 @@ class ResearchRetriever:
|
|
148 |
|
149 |
retriever = ResearchRetriever()
|
150 |
|
151 |
-
#
|
152 |
# Cognitive Processing Unit
|
153 |
-
#
|
154 |
class CognitiveProcessor:
|
155 |
def __init__(self):
|
156 |
self.executor = ThreadPoolExecutor(max_workers=ResearchConfig.MAX_CONCURRENT_REQUESTS)
|
157 |
self.session_id = hashlib.sha256(datetime.now().isoformat().encode()).hexdigest()[:12]
|
158 |
|
159 |
def process_query(self, prompt: str) -> Dict:
|
|
|
160 |
futures = []
|
161 |
-
for _ in range(3):
|
162 |
-
futures.append(self.executor.submit(
|
163 |
-
self._execute_api_request,
|
164 |
-
prompt
|
165 |
-
))
|
166 |
|
167 |
results = []
|
168 |
for future in as_completed(futures):
|
@@ -174,6 +169,7 @@ class CognitiveProcessor:
|
|
174 |
return self._consensus_check(results)
|
175 |
|
176 |
def _execute_api_request(self, prompt: str) -> Dict:
|
|
|
177 |
headers = {
|
178 |
"Authorization": f"Bearer {ResearchConfig.DEEPSEEK_API_KEY}",
|
179 |
"Content-Type": "application/json",
|
@@ -202,15 +198,15 @@ class CognitiveProcessor:
|
|
202 |
return {"error": str(e)}
|
203 |
|
204 |
def _consensus_check(self, results: List[Dict]) -> Dict:
|
|
|
205 |
valid = [r for r in results if "error" not in r]
|
206 |
if not valid:
|
207 |
return {"error": "All API requests failed"}
|
208 |
-
# Choose the result with the longest content for robustness.
|
209 |
return max(valid, key=lambda x: len(x.get('choices', [{}])[0].get('message', {}).get('content', '')))
|
210 |
|
211 |
-
#
|
212 |
# Research Workflow Engine
|
213 |
-
#
|
214 |
class ResearchWorkflow:
|
215 |
def __init__(self):
|
216 |
self.processor = CognitiveProcessor()
|
@@ -225,6 +221,7 @@ class ResearchWorkflow:
|
|
225 |
self.workflow.add_node("validate", self.validate_output)
|
226 |
self.workflow.add_node("refine", self.refine_results)
|
227 |
|
|
|
228 |
self.workflow.set_entry_point("ingest")
|
229 |
self.workflow.add_edge("ingest", "retrieve")
|
230 |
self.workflow.add_edge("retrieve", "analyze")
|
@@ -239,9 +236,9 @@ class ResearchWorkflow:
|
|
239 |
self.app = self.workflow.compile()
|
240 |
|
241 |
def ingest_query(self, state: AgentState) -> Dict:
|
|
|
242 |
try:
|
243 |
query = state["messages"][-1].content
|
244 |
-
st.write(f"[DEBUG] Ingesting query: {query}")
|
245 |
return {
|
246 |
"messages": [AIMessage(content="Query ingested successfully")],
|
247 |
"context": {"raw_query": query},
|
@@ -251,10 +248,10 @@ class ResearchWorkflow:
|
|
251 |
return self._error_state(f"Ingestion Error: {str(e)}")
|
252 |
|
253 |
def retrieve_documents(self, state: AgentState) -> Dict:
|
|
|
254 |
try:
|
255 |
query = state["context"]["raw_query"]
|
256 |
docs = retriever.retrieve(query, "research")
|
257 |
-
st.write(f"[DEBUG] Retrieved {len(docs)} documents from retrieval node.")
|
258 |
return {
|
259 |
"messages": [AIMessage(content=f"Retrieved {len(docs)} documents")],
|
260 |
"context": {
|
@@ -266,14 +263,15 @@ class ResearchWorkflow:
|
|
266 |
return self._error_state(f"Retrieval Error: {str(e)}")
|
267 |
|
268 |
def analyze_content(self, state: AgentState) -> Dict:
|
|
|
269 |
try:
|
270 |
-
# Ensure documents are present before proceeding.
|
271 |
if "documents" not in state["context"] or not state["context"]["documents"]:
|
272 |
return self._error_state("No documents retrieved; please check your query or retrieval process.")
|
273 |
|
274 |
-
|
275 |
-
|
276 |
-
|
|
|
277 |
prompt = ResearchConfig.ANALYSIS_TEMPLATE.format(context=docs)
|
278 |
response = self.processor.process_query(prompt)
|
279 |
|
@@ -288,6 +286,7 @@ class ResearchWorkflow:
|
|
288 |
return self._error_state(f"Analysis Error: {str(e)}")
|
289 |
|
290 |
def validate_output(self, state: AgentState) -> Dict:
|
|
|
291 |
analysis = state["messages"][-1].content
|
292 |
validation_prompt = f"""Validate research analysis:
|
293 |
{analysis}
|
@@ -306,6 +305,7 @@ Respond with 'VALID' or 'INVALID'"""
|
|
306 |
}
|
307 |
|
308 |
def refine_results(self, state: AgentState) -> Dict:
|
|
|
309 |
refinement_prompt = f"""Refine this analysis:
|
310 |
{state["messages"][-1].content}
|
311 |
|
@@ -321,29 +321,32 @@ Improve:
|
|
321 |
}
|
322 |
|
323 |
def _quality_check(self, state: AgentState) -> str:
|
|
|
324 |
content = state["messages"][-1].content
|
325 |
return "valid" if "VALID" in content else "invalid"
|
326 |
|
327 |
def _error_state(self, message: str) -> Dict:
|
328 |
-
|
|
|
329 |
return {
|
330 |
"messages": [AIMessage(content=f"❌ {message}")],
|
331 |
"context": {"error": True},
|
332 |
"metadata": {"status": "error"}
|
333 |
}
|
334 |
|
335 |
-
#
|
336 |
# Research Interface
|
337 |
-
#
|
338 |
class ResearchInterface:
|
339 |
def __init__(self):
|
340 |
self.workflow = ResearchWorkflow()
|
341 |
-
#
|
342 |
self._inject_styles()
|
343 |
self._build_sidebar()
|
344 |
self._build_main_interface()
|
345 |
|
346 |
def _inject_styles(self):
|
|
|
347 |
st.markdown("""
|
348 |
<style>
|
349 |
:root {
|
@@ -390,6 +393,7 @@ class ResearchInterface:
|
|
390 |
""", unsafe_allow_html=True)
|
391 |
|
392 |
def _build_sidebar(self):
|
|
|
393 |
with st.sidebar:
|
394 |
st.title("🔍 Research Database")
|
395 |
st.subheader("Technical Papers")
|
@@ -402,6 +406,7 @@ class ResearchInterface:
|
|
402 |
st.metric("Embedding Dimensions", ResearchConfig.EMBEDDING_DIMENSIONS)
|
403 |
|
404 |
def _build_main_interface(self):
|
|
|
405 |
st.title("🧠 NeuroResearch AI")
|
406 |
query = st.text_area("Research Query:", height=200,
|
407 |
placeholder="Enter technical research question...")
|
@@ -410,6 +415,7 @@ class ResearchInterface:
|
|
410 |
self._execute_analysis(query)
|
411 |
|
412 |
def _execute_analysis(self, query: str):
|
|
|
413 |
try:
|
414 |
with st.spinner("Initializing Quantum Analysis..."):
|
415 |
results = self.workflow.app.stream(
|
@@ -427,6 +433,7 @@ Potential issues:
|
|
427 |
- Temporal processing constraints""")
|
428 |
|
429 |
def _render_event(self, event: Dict):
|
|
|
430 |
if 'ingest' in event:
|
431 |
with st.container():
|
432 |
st.success("✅ Query Ingested")
|
@@ -455,5 +462,8 @@ Potential issues:
|
|
455 |
with st.expander("View Validation Details", expanded=True):
|
456 |
st.markdown(content)
|
457 |
|
|
|
|
|
|
|
458 |
if __name__ == "__main__":
|
459 |
ResearchInterface()
|
|
|
1 |
+
# ---------------------------------------------
|
2 |
# Imports & Initial Configuration
|
3 |
+
# ---------------------------------------------
|
4 |
import streamlit as st
|
5 |
+
# IMPORTANT: Must be the first Streamlit command
|
6 |
st.set_page_config(page_title="NeuroResearch AI", layout="wide", initial_sidebar_state="expanded")
|
7 |
|
8 |
from langchain_openai import OpenAIEmbeddings
|
|
|
15 |
from typing_extensions import TypedDict, Annotated
|
16 |
from typing import Sequence, Dict, List, Optional, Any
|
17 |
import chromadb
|
|
|
18 |
import os
|
19 |
import requests
|
20 |
import hashlib
|
|
|
21 |
import time
|
|
|
22 |
from concurrent.futures import ThreadPoolExecutor, as_completed
|
23 |
from datetime import datetime
|
24 |
|
25 |
+
# ---------------------------------------------
|
26 |
# State Schema Definition
|
27 |
+
# ---------------------------------------------
|
28 |
class AgentState(TypedDict):
|
29 |
messages: Annotated[Sequence[AIMessage | HumanMessage | ToolMessage], add_messages]
|
30 |
context: Dict[str, Any]
|
31 |
metadata: Dict[str, Any]
|
32 |
|
33 |
+
# ---------------------------------------------
|
34 |
# Configuration
|
35 |
+
# ---------------------------------------------
|
36 |
class ResearchConfig:
|
37 |
DEEPSEEK_API_KEY = os.environ.get("DEEPSEEK_API_KEY")
|
38 |
CHROMA_PATH = "chroma_db"
|
|
|
69 |
3. Rebuild deployment""")
|
70 |
st.stop()
|
71 |
|
72 |
+
# ---------------------------------------------
|
73 |
# Quantum Document Processing
|
74 |
+
# ---------------------------------------------
|
75 |
class QuantumDocumentManager:
|
76 |
def __init__(self):
|
77 |
self.client = chromadb.PersistentClient(path=ResearchConfig.CHROMA_PATH)
|
|
|
87 |
separators=["\n\n", "\n", "|||"]
|
88 |
)
|
89 |
docs = splitter.create_documents(documents)
|
90 |
+
# Removed debug line that displayed chunk creation count
|
|
|
91 |
return Chroma.from_documents(
|
92 |
documents=docs,
|
93 |
embedding=self.embeddings,
|
|
|
97 |
)
|
98 |
|
99 |
def _document_id(self, content: str) -> str:
|
100 |
+
"""Create a unique ID for each document chunk."""
|
101 |
return f"{hashlib.sha256(content.encode()).hexdigest()[:16]}-{int(time.time())}"
|
102 |
|
103 |
# Initialize document collections
|
|
|
114 |
"Product Y: In the Performance Optimization Stage Before Release"
|
115 |
], "development")
|
116 |
|
117 |
+
# ---------------------------------------------
|
118 |
# Advanced Retrieval System
|
119 |
+
# ---------------------------------------------
|
120 |
class ResearchRetriever:
|
121 |
def __init__(self):
|
122 |
self.retrievers = {
|
|
|
135 |
}
|
136 |
|
137 |
def retrieve(self, query: str, domain: str) -> List[Any]:
|
138 |
+
"""Retrieve documents from the specified domain using the appropriate retriever."""
|
139 |
try:
|
140 |
results = self.retrievers[domain].invoke(query)
|
|
|
141 |
return results
|
142 |
except KeyError:
|
143 |
st.error(f"[ERROR] Retrieval domain '{domain}' not found.")
|
|
|
145 |
|
146 |
retriever = ResearchRetriever()
|
147 |
|
148 |
+
# ---------------------------------------------
|
149 |
# Cognitive Processing Unit
|
150 |
+
# ---------------------------------------------
|
151 |
class CognitiveProcessor:
|
152 |
def __init__(self):
|
153 |
self.executor = ThreadPoolExecutor(max_workers=ResearchConfig.MAX_CONCURRENT_REQUESTS)
|
154 |
self.session_id = hashlib.sha256(datetime.now().isoformat().encode()).hexdigest()[:12]
|
155 |
|
156 |
def process_query(self, prompt: str) -> Dict:
|
157 |
+
"""Send the prompt to the DeepSeek API using triple redundancy for robustness."""
|
158 |
futures = []
|
159 |
+
for _ in range(3):
|
160 |
+
futures.append(self.executor.submit(self._execute_api_request, prompt))
|
|
|
|
|
|
|
161 |
|
162 |
results = []
|
163 |
for future in as_completed(futures):
|
|
|
169 |
return self._consensus_check(results)
|
170 |
|
171 |
def _execute_api_request(self, prompt: str) -> Dict:
|
172 |
+
"""Make a single request to the DeepSeek API."""
|
173 |
headers = {
|
174 |
"Authorization": f"Bearer {ResearchConfig.DEEPSEEK_API_KEY}",
|
175 |
"Content-Type": "application/json",
|
|
|
198 |
return {"error": str(e)}
|
199 |
|
200 |
def _consensus_check(self, results: List[Dict]) -> Dict:
|
201 |
+
"""Pick the best result by comparing content length among successful responses."""
|
202 |
valid = [r for r in results if "error" not in r]
|
203 |
if not valid:
|
204 |
return {"error": "All API requests failed"}
|
|
|
205 |
return max(valid, key=lambda x: len(x.get('choices', [{}])[0].get('message', {}).get('content', '')))
|
206 |
|
207 |
+
# ---------------------------------------------
|
208 |
# Research Workflow Engine
|
209 |
+
# ---------------------------------------------
|
210 |
class ResearchWorkflow:
|
211 |
def __init__(self):
|
212 |
self.processor = CognitiveProcessor()
|
|
|
221 |
self.workflow.add_node("validate", self.validate_output)
|
222 |
self.workflow.add_node("refine", self.refine_results)
|
223 |
|
224 |
+
# Define workflow transitions
|
225 |
self.workflow.set_entry_point("ingest")
|
226 |
self.workflow.add_edge("ingest", "retrieve")
|
227 |
self.workflow.add_edge("retrieve", "analyze")
|
|
|
236 |
self.app = self.workflow.compile()
|
237 |
|
238 |
def ingest_query(self, state: AgentState) -> Dict:
|
239 |
+
"""Extract the user query and store it in the state."""
|
240 |
try:
|
241 |
query = state["messages"][-1].content
|
|
|
242 |
return {
|
243 |
"messages": [AIMessage(content="Query ingested successfully")],
|
244 |
"context": {"raw_query": query},
|
|
|
248 |
return self._error_state(f"Ingestion Error: {str(e)}")
|
249 |
|
250 |
def retrieve_documents(self, state: AgentState) -> Dict:
|
251 |
+
"""Retrieve relevant documents from the 'research' domain."""
|
252 |
try:
|
253 |
query = state["context"]["raw_query"]
|
254 |
docs = retriever.retrieve(query, "research")
|
|
|
255 |
return {
|
256 |
"messages": [AIMessage(content=f"Retrieved {len(docs)} documents")],
|
257 |
"context": {
|
|
|
263 |
return self._error_state(f"Retrieval Error: {str(e)}")
|
264 |
|
265 |
def analyze_content(self, state: AgentState) -> Dict:
|
266 |
+
"""Concatenate document contents and analyze them using the CognitiveProcessor."""
|
267 |
try:
|
|
|
268 |
if "documents" not in state["context"] or not state["context"]["documents"]:
|
269 |
return self._error_state("No documents retrieved; please check your query or retrieval process.")
|
270 |
|
271 |
+
docs = "\n\n".join([
|
272 |
+
d.page_content for d in state["context"]["documents"]
|
273 |
+
if hasattr(d, "page_content") and d.page_content
|
274 |
+
])
|
275 |
prompt = ResearchConfig.ANALYSIS_TEMPLATE.format(context=docs)
|
276 |
response = self.processor.process_query(prompt)
|
277 |
|
|
|
286 |
return self._error_state(f"Analysis Error: {str(e)}")
|
287 |
|
288 |
def validate_output(self, state: AgentState) -> Dict:
|
289 |
+
"""Validate the technical correctness of the analysis output."""
|
290 |
analysis = state["messages"][-1].content
|
291 |
validation_prompt = f"""Validate research analysis:
|
292 |
{analysis}
|
|
|
305 |
}
|
306 |
|
307 |
def refine_results(self, state: AgentState) -> Dict:
|
308 |
+
"""Refine the analysis based on the validation feedback."""
|
309 |
refinement_prompt = f"""Refine this analysis:
|
310 |
{state["messages"][-1].content}
|
311 |
|
|
|
321 |
}
|
322 |
|
323 |
def _quality_check(self, state: AgentState) -> str:
|
324 |
+
"""Check if the validation step indicates a 'VALID' or 'INVALID' output."""
|
325 |
content = state["messages"][-1].content
|
326 |
return "valid" if "VALID" in content else "invalid"
|
327 |
|
328 |
def _error_state(self, message: str) -> Dict:
|
329 |
+
"""Return an error message and mark the state as erroneous."""
|
330 |
+
st.error(f"[ERROR] {message}")
|
331 |
return {
|
332 |
"messages": [AIMessage(content=f"❌ {message}")],
|
333 |
"context": {"error": True},
|
334 |
"metadata": {"status": "error"}
|
335 |
}
|
336 |
|
337 |
+
# ---------------------------------------------
|
338 |
# Research Interface
|
339 |
+
# ---------------------------------------------
|
340 |
class ResearchInterface:
|
341 |
def __init__(self):
|
342 |
self.workflow = ResearchWorkflow()
|
343 |
+
# We've already set the page config at the top.
|
344 |
self._inject_styles()
|
345 |
self._build_sidebar()
|
346 |
self._build_main_interface()
|
347 |
|
348 |
def _inject_styles(self):
|
349 |
+
"""Inject custom CSS for a sleek interface."""
|
350 |
st.markdown("""
|
351 |
<style>
|
352 |
:root {
|
|
|
393 |
""", unsafe_allow_html=True)
|
394 |
|
395 |
def _build_sidebar(self):
|
396 |
+
"""Construct the left sidebar with document info and metrics."""
|
397 |
with st.sidebar:
|
398 |
st.title("🔍 Research Database")
|
399 |
st.subheader("Technical Papers")
|
|
|
406 |
st.metric("Embedding Dimensions", ResearchConfig.EMBEDDING_DIMENSIONS)
|
407 |
|
408 |
def _build_main_interface(self):
|
409 |
+
"""Construct the main interface for query input and result display."""
|
410 |
st.title("🧠 NeuroResearch AI")
|
411 |
query = st.text_area("Research Query:", height=200,
|
412 |
placeholder="Enter technical research question...")
|
|
|
415 |
self._execute_analysis(query)
|
416 |
|
417 |
def _execute_analysis(self, query: str):
|
418 |
+
"""Execute the entire research workflow and render the results."""
|
419 |
try:
|
420 |
with st.spinner("Initializing Quantum Analysis..."):
|
421 |
results = self.workflow.app.stream(
|
|
|
433 |
- Temporal processing constraints""")
|
434 |
|
435 |
def _render_event(self, event: Dict):
|
436 |
+
"""Render each node's output in the UI as it streams through the workflow."""
|
437 |
if 'ingest' in event:
|
438 |
with st.container():
|
439 |
st.success("✅ Query Ingested")
|
|
|
462 |
with st.expander("View Validation Details", expanded=True):
|
463 |
st.markdown(content)
|
464 |
|
465 |
+
# ---------------------------------------------
|
466 |
+
# Main Execution
|
467 |
+
# ---------------------------------------------
|
468 |
if __name__ == "__main__":
|
469 |
ResearchInterface()
|