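# Scraped page header (Hugging Face file view): uploaded by mgbam, commit "Update app.py" (e1707aa, verified), 17.5 kB.
# Page controls captured by the scrape: raw, history, blame.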
# ------------------------------
# Imports & Initial Configuration
# ------------------------------
import streamlit as st
# Set the page configuration immediately—this must be the first Streamlit command.
st.set_page_config(page_title="NeuroResearch AI", layout="wide", initial_sidebar_state="expanded")
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langgraph.graph import END, StateGraph
from langgraph.prebuilt import ToolNode
from langgraph.graph.message import add_messages
from typing_extensions import TypedDict, Annotated
from typing import Sequence, Dict, List, Optional, Any
import chromadb
import re
import os
import requests
import hashlib
import json
import time
from langchain.tools.retriever import create_retriever_tool
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
# ------------------------------
# State Schema Definition
# ------------------------------
class AgentState(TypedDict):
    """State dictionary handed to every node of the research workflow graph."""

    # Conversation so far; the annotation attaches `add_messages` (imported from
    # langgraph.graph.message) to this key.
    messages: Annotated[Sequence[AIMessage | HumanMessage | ToolMessage], add_messages]
    # Working data written by the nodes; this file stores keys such as
    # "raw_query", "documents", "analysis" and "error" here.
    context: Dict[str, Any]
    # Bookkeeping written by the nodes; this file stores "timestamp" and "status".
    metadata: Dict[str, Any]
# ------------------------------
# Configuration
# ------------------------------
class ResearchConfig:
    """Static configuration values read by the rest of this module."""

    # DeepSeek API key taken from the environment; None when the variable is unset.
    DEEPSEEK_API_KEY = os.environ.get("DEEPSEEK_API_KEY")
    # Directory passed to chromadb.PersistentClient.
    CHROMA_PATH = "chroma_db"
    # chunk_size / chunk_overlap passed to RecursiveCharacterTextSplitter.
    CHUNK_SIZE = 512
    CHUNK_OVERLAP = 64
    # max_workers of the ThreadPoolExecutor that issues API requests.
    MAX_CONCURRENT_REQUESTS = 5
    # `dimensions` argument passed to OpenAIEmbeddings; also shown in the sidebar.
    EMBEDDING_DIMENSIONS = 1536
    # Full document title -> short label used for the sidebar expanders.
    DOCUMENT_MAP = {
        "Research Report: Results of a New AI Model Improving Image Recognition Accuracy to 98%":
            "CV-Transformer Hybrid Architecture",
        "Academic Paper Summary: Why Transformers Became the Mainstream Architecture in Natural Language Processing":
            "Transformer Architecture Analysis",
        "Latest Trends in Machine Learning Methods Using Quantum Computing":
            "Quantum ML Frontiers"
    }
    # Prompt for the analysis step; `{context}` is filled in with str.format.
    ANALYSIS_TEMPLATE = """Analyze these technical documents with scientific rigor:
{context}
Respond with:
1. Key Technical Contributions (bullet points)
2. Novel Methodologies
3. Empirical Results (with metrics)
4. Potential Applications
5. Limitations & Future Directions
Format: Markdown with LaTeX mathematical notation where applicable
"""
# Without a DeepSeek key nothing below can work: show the setup steps and halt the script.
if not ResearchConfig.DEEPSEEK_API_KEY:
    _setup_steps = (
        "**Research Portal Configuration Required**\n"
        "1. Obtain DeepSeek API key: [platform.deepseek.com](https://platform.deepseek.com/)\n"
        "2. Configure secret: `DEEPSEEK_API_KEY` in Space settings\n"
        "3. Rebuild deployment"
    )
    st.error(_setup_steps)
    st.stop()
# ------------------------------
# Quantum Document Processing
# ------------------------------
class QuantumDocumentManager:
    """Split raw texts into chunks and store them in persistent Chroma collections."""

    def __init__(self):
        """Open the on-disk Chroma client and configure the embedding model."""
        self.client = chromadb.PersistentClient(path=ResearchConfig.CHROMA_PATH)
        self.embeddings = OpenAIEmbeddings(
            model="text-embedding-3-large",
            dimensions=ResearchConfig.EMBEDDING_DIMENSIONS
        )

    def create_collection(self, documents: List[str], collection_name: str) -> Chroma:
        """Chunk ``documents`` and store the chunks in the named collection.

        Args:
            documents: Raw texts to split and index.
            collection_name: Name of the Chroma collection to write to.

        Returns:
            The Chroma vector store wrapping the collection.
        """
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=ResearchConfig.CHUNK_SIZE,
            chunk_overlap=ResearchConfig.CHUNK_OVERLAP,
            separators=["\n\n", "\n", "|||"]
        )
        docs = splitter.create_documents(documents)
        # Debug: log the number of chunks created for the collection.
        st.write(f"Created {len(docs)} chunks for collection '{collection_name}'")

        # Key every chunk by its content-derived ID and keep the first occurrence:
        # two identical chunks would otherwise be submitted under the same ID in
        # one batch.
        unique_docs = {}
        for doc in docs:
            unique_docs.setdefault(self._document_id(doc.page_content), doc)

        return Chroma.from_documents(
            documents=list(unique_docs.values()),
            embedding=self.embeddings,
            client=self.client,
            collection_name=collection_name,
            ids=list(unique_docs)
        )

    def _document_id(self, content: str) -> str:
        """Return a stable ID for a chunk, derived only from its text.

        The ID must not depend on the clock: Streamlit re-executes this script
        on every interaction, and a time-stamped ID stored the same chunks again
        under new IDs on each run, so the persistent collection kept growing
        with duplicates.
        NOTE(review): this relies on the vector store writing by ID (upsert) so
        that re-adding a chunk replaces it -- confirm for the installed
        langchain_community / chromadb versions.
        """
        return hashlib.sha256(content.encode("utf-8")).hexdigest()[:16]
# Build the document manager and the two vector collections used by the retrievers.
_RESEARCH_TEXTS = [
    "Research Report: Results of a New AI Model Improving Image Recognition Accuracy to 98%",
    "Academic Paper Summary: Why Transformers Became the Mainstream Architecture in Natural Language Processing",
    "Latest Trends in Machine Learning Methods Using Quantum Computing",
]
_DEVELOPMENT_TEXTS = [
    "Project A: UI Design Completed, API Integration in Progress",
    "Project B: Testing New Feature X, Bug Fixes Needed",
    "Product Y: In the Performance Optimization Stage Before Release",
]

qdm = QuantumDocumentManager()
# Order matters for the debug output: "research" is created before "development".
research_docs = qdm.create_collection(_RESEARCH_TEXTS, "research")
development_docs = qdm.create_collection(_DEVELOPMENT_TEXTS, "development")
# ------------------------------
# Advanced Retrieval System
# ------------------------------
class ResearchRetriever:
    """Dispatch a query to the retriever registered for a named domain."""

    def __init__(self):
        """Create one retriever per module-level collection."""
        mmr_settings = {'k': 4, 'fetch_k': 20, 'lambda_mult': 0.85}
        research = research_docs.as_retriever(
            search_type="mmr",
            search_kwargs=mmr_settings,
        )
        development = development_docs.as_retriever(
            search_type="similarity",
            search_kwargs={'k': 3},
        )
        self.retrievers = {"research": research, "development": development}

    def retrieve(self, query: str, domain: str) -> List[Any]:
        """Return the documents found for ``query`` in ``domain``.

        A KeyError (for example an unregistered ``domain``) is reported through
        ``st.error`` and answered with an empty list.
        """
        try:
            hits = self.retrievers[domain].invoke(query)
            st.write(f"[DEBUG] Retrieved {len(hits)} documents for query: '{query}' in domain '{domain}'")
        except KeyError:
            st.error(f"[ERROR] Retrieval domain '{domain}' not found.")
            return []
        return hits


retriever = ResearchRetriever()
# ------------------------------
# Cognitive Processing Unit
# ------------------------------
class CognitiveProcessor:
    """Send a prompt to the DeepSeek chat API several times and keep one answer."""

    def __init__(self):
        """Create the worker pool and a short per-instance session identifier."""
        self.executor = ThreadPoolExecutor(max_workers=ResearchConfig.MAX_CONCURRENT_REQUESTS)
        self.session_id = hashlib.sha256(datetime.now().isoformat().encode()).hexdigest()[:12]

    def process_query(self, prompt: str) -> Dict:
        """Issue the same request three times in parallel and pick one response.

        Args:
            prompt: Text sent as the user message.

        Returns:
            The chosen API response, or ``{"error": ...}`` when no request succeeded.
        """
        # Triple redundancy for robustness
        futures = [
            self.executor.submit(self._execute_api_request, prompt)
            for _ in range(3)
        ]
        results = []
        for future in as_completed(futures):
            try:
                results.append(future.result())
            except Exception as e:
                # Request errors are already turned into {"error": ...} by the
                # worker; anything arriving here is unexpected, so surface it.
                st.error(f"Processing Error: {str(e)}")
        return self._consensus_check(results)

    def _execute_api_request(self, prompt: str) -> Dict:
        """POST one chat-completion request; return the JSON body or ``{"error": str}``."""
        headers = {
            "Authorization": f"Bearer {ResearchConfig.DEEPSEEK_API_KEY}",
            "Content-Type": "application/json",
            "X-Research-Session": self.session_id
        }
        try:
            response = requests.post(
                "https://api.deepseek.com/v1/chat/completions",
                headers=headers,
                json={
                    "model": "deepseek-chat",
                    "messages": [{
                        "role": "user",
                        "content": f"Respond as Senior AI Researcher:\n{prompt}"
                    }],
                    "temperature": 0.7,
                    "max_tokens": 1500,
                    "top_p": 0.9
                },
                timeout=45
            )
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            return {"error": str(e)}

    @staticmethod
    def _response_text(result: Dict) -> str:
        """Return the first choice's message text, or "" when any part is missing.

        Tolerates an empty ``choices`` list and a ``None`` content, both of
        which previously raised while the responses were being compared.
        """
        choices = result.get("choices") or [{}]
        first = choices[0] if isinstance(choices[0], dict) else {}
        message = first.get("message") or {}
        return message.get("content") or ""

    def _consensus_check(self, results: List[Dict]) -> Dict:
        """Pick the error-free response with the longest text.

        Returns ``{"error": "All API requests failed"}`` when every entry
        carries an "error" key or ``results`` is empty.
        """
        valid = [r for r in results if "error" not in r]
        if not valid:
            return {"error": "All API requests failed"}
        # Choose the result with the longest content for robustness.
        return max(valid, key=lambda r: len(self._response_text(r)))
# ------------------------------
# Research Workflow Engine
# ------------------------------
class ResearchWorkflow:
    """LangGraph pipeline: ingest -> retrieve -> analyze -> validate, with a bounded retry loop."""

    # Upper bound on refine -> retrieve -> analyze retries for a single query.
    MAX_REFINEMENTS = 2

    def __init__(self):
        """Create the API processor, then build and compile the state graph."""
        self.processor = CognitiveProcessor()
        self.workflow = StateGraph(AgentState)
        self._build_workflow()

    def _build_workflow(self):
        """Register the nodes and edges, then compile the graph into ``self.app``."""
        # Register nodes in the state graph
        self.workflow.add_node("ingest", self.ingest_query)
        self.workflow.add_node("retrieve", self.retrieve_documents)
        self.workflow.add_node("analyze", self.analyze_content)
        self.workflow.add_node("validate", self.validate_output)
        self.workflow.add_node("refine", self.refine_results)
        self.workflow.set_entry_point("ingest")
        self.workflow.add_edge("ingest", "retrieve")
        self.workflow.add_edge("retrieve", "analyze")
        self.workflow.add_conditional_edges(
            "analyze",
            self._quality_check,
            # "abort" ends the run after an error or once the retry budget is spent.
            {"valid": "validate", "invalid": "refine", "abort": END}
        )
        self.workflow.add_edge("validate", END)
        self.workflow.add_edge("refine", "retrieve")
        self.app = self.workflow.compile()

    @staticmethod
    def _response_text(response: Dict) -> str:
        """Return the first choice's message text from an API response, or ""."""
        choices = response.get("choices") or [{}]
        first = choices[0] if isinstance(choices[0], dict) else {}
        message = first.get("message") or {}
        return message.get("content") or ""

    def ingest_query(self, state: AgentState) -> Dict:
        """Store the latest message's text as ``context["raw_query"]``."""
        try:
            query = state["messages"][-1].content
            st.write(f"[DEBUG] Ingesting query: {query}")
            return {
                "messages": [AIMessage(content="Query ingested successfully")],
                "context": {"raw_query": query},
                "metadata": {"timestamp": datetime.now().isoformat()}
            }
        except Exception as e:
            return self._error_state(f"Ingestion Error: {str(e)}")

    def retrieve_documents(self, state: AgentState) -> Dict:
        """Fetch documents for the stored query from the "research" domain."""
        try:
            context = state["context"]
            query = context["raw_query"]
            docs = retriever.retrieve(query, "research")
            st.write(f"[DEBUG] Retrieved {len(docs)} documents from retrieval node.")
            return {
                "messages": [AIMessage(content=f"Retrieved {len(docs)} documents")],
                # A node's "context" replaces the previous one, so carry the
                # existing keys forward; otherwise "raw_query" is gone the next
                # time this node runs after a refinement.
                "context": {
                    **context,
                    "documents": docs,
                    "retrieval_time": time.time()
                }
            }
        except Exception as e:
            return self._error_state(f"Retrieval Error: {str(e)}")

    def analyze_content(self, state: AgentState) -> Dict:
        """Ask the model to analyse the retrieved documents."""
        try:
            context = state["context"]
            documents = context.get("documents")
            # Ensure documents are present before proceeding.
            if not documents:
                return self._error_state("No documents retrieved; please check your query or retrieval process.")
            # Concatenate all document content for analysis.
            docs = "\n\n".join(d.page_content for d in documents if hasattr(d, "page_content"))
            st.write(f"[DEBUG] Analyzing content from {len(documents)} documents.")
            prompt = ResearchConfig.ANALYSIS_TEMPLATE.format(context=docs)
            response = self.processor.process_query(prompt)
            if "error" in response:
                return self._error_state(response["error"])
            return {
                "messages": [AIMessage(content=self._response_text(response))],
                "context": {**context, "analysis": response}
            }
        except Exception as e:
            return self._error_state(f"Analysis Error: {str(e)}")

    def validate_output(self, state: AgentState) -> Dict:
        """Ask the model for a VALID/INVALID verdict and append it to the analysis."""
        analysis = state["messages"][-1].content
        validation_prompt = (
            "Validate research analysis:\n"
            f"{analysis}\n"
            "Check for:\n"
            "1. Technical accuracy\n"
            "2. Citation support\n"
            "3. Logical consistency\n"
            "4. Methodological soundness\n"
            "Respond with 'VALID' or 'INVALID'"
        )
        response = self.processor.process_query(validation_prompt)
        verdict = self._response_text(response)
        return {
            "messages": [AIMessage(content=analysis + f"\n\nValidation: {verdict}")]
        }

    def refine_results(self, state: AgentState) -> Dict:
        """Ask the model to refine the latest message and count the attempt."""
        context = state["context"]
        latest = state["messages"][-1].content
        refinement_prompt = (
            "Refine this analysis:\n"
            f"{latest}\n"
            "Improve:\n"
            "1. Technical precision\n"
            "2. Empirical grounding\n"
            "3. Theoretical coherence"
        )
        response = self.processor.process_query(refinement_prompt)
        return {
            "messages": [AIMessage(content=self._response_text(response))],
            "context": {
                **context,
                # Read by _quality_check to stop the retry loop.
                "refinement_count": context.get("refinement_count", 0) + 1
            }
        }

    def _quality_check(self, state: AgentState) -> str:
        """Route after ``analyze``: "valid", "invalid" (retry) or "abort".

        The analysis prompt never asks the model to print "VALID", so testing
        for that word sent practically every run around the refine loop until
        the graph's recursion limit was hit. The check is now structural:
        an error state or an exhausted retry budget stops the run, an empty
        analysis is retried, and anything else goes on to validation.
        """
        context = state.get("context") or {}
        if context.get("error"):
            return "abort"
        content = state["messages"][-1].content
        if content and content.strip():
            return "valid"
        if context.get("refinement_count", 0) >= self.MAX_REFINEMENTS:
            return "abort"
        return "invalid"

    def _error_state(self, message: str) -> Dict:
        """Log ``message`` and return a state update flagged as an error."""
        st.write(f"[ERROR] {message}")
        return {
            "messages": [AIMessage(content=f"❌ {message}")],
            "context": {"error": True},
            "metadata": {"status": "error"}
        }
# ------------------------------
# Research Interface
# ------------------------------
class ResearchInterface:
    """Streamlit front end: styling, sidebar, query box and rendering of workflow events."""

    # Separator that ResearchWorkflow.validate_output puts between analysis and verdict.
    _VERDICT_MARKER = "\n\nValidation: "

    def __init__(self):
        """Build the workflow and draw the page."""
        self.workflow = ResearchWorkflow()
        # Do not call st.set_page_config here because it has already been called at the top.
        self._inject_styles()
        self._build_sidebar()
        self._build_main_interface()

    def _inject_styles(self):
        """Inject the page's custom CSS."""
        st.markdown("""
<style>
:root {
--primary: #2ecc71;
--secondary: #3498db;
--background: #0a0a0a;
--text: #ecf0f1;
}
.stApp {
background: var(--background);
color: var(--text);
font-family: 'Roboto', sans-serif;
}
.stTextArea textarea {
background: #1a1a1a !important;
color: var(--text) !important;
border: 2px solid var(--secondary);
border-radius: 8px;
padding: 1rem;
}
.stButton>button {
background: linear-gradient(135deg, var(--primary), var(--secondary));
border: none;
border-radius: 8px;
padding: 1rem 2rem;
transition: all 0.3s;
}
.stButton>button:hover {
transform: translateY(-2px);
box-shadow: 0 4px 12px rgba(46, 204, 113, 0.3);
}
.stExpander {
background: #1a1a1a;
border: 1px solid #2a2a2a;
border-radius: 8px;
margin: 1rem 0;
}
</style>
""", unsafe_allow_html=True)

    def _build_sidebar(self):
        """List the configured documents and two summary metrics in the sidebar."""
        with st.sidebar:
            st.title("🔍 Research Database")
            st.subheader("Technical Papers")
            for title, short in ResearchConfig.DOCUMENT_MAP.items():
                with st.expander(short):
                    st.markdown(f"```\n{title}\n```")
            st.subheader("Analysis Metrics")
            st.metric("Vector Collections", 2)
            st.metric("Embedding Dimensions", ResearchConfig.EMBEDDING_DIMENSIONS)

    def _build_main_interface(self):
        """Draw the query box and run the analysis when the button is pressed."""
        st.title("🧠 NeuroResearch AI")
        query = st.text_area("Research Query:", height=200,
                             placeholder="Enter technical research question...")
        if st.button("Execute Analysis", type="primary"):
            # A blank query cannot be retrieved against; ask for input instead
            # of sending it through the whole workflow.
            if not query or not query.strip():
                st.warning("Please enter a research query before running the analysis.")
                return
            self._execute_analysis(query)

    def _execute_analysis(self, query: str):
        """Stream the workflow for ``query`` and render every event it yields."""
        try:
            failed = False
            # The stream is consumed inside the spinner: the work happens while
            # iterating, not when stream() is called.
            with st.spinner("Initializing Quantum Analysis..."):
                results = self.workflow.app.stream(
                    {"messages": [HumanMessage(content=query)], "context": {}, "metadata": {}}
                )
                for event in results:
                    self._render_event(event)
                    failed = failed or self._event_failed(event)
            if failed:
                # Do not announce success when a node reported an error state.
                st.error("❌ Analysis finished with errors")
            else:
                st.success("✅ Analysis Completed Successfully")
        except Exception as e:
            st.error(
                "**Analysis Failed**\n"
                f"{str(e)}\n"
                "Potential issues:\n"
                "- Complex query structure\n"
                "- Document correlation failure\n"
                "- Temporal processing constraints"
            )

    @staticmethod
    def _event_failed(event: Dict) -> bool:
        """Return True when any node update in ``event`` carries ``context["error"]``."""
        return any(
            isinstance(update, dict) and bool((update.get("context") or {}).get("error"))
            for update in event.values()
        )

    @staticmethod
    def _split_validation(content: str) -> tuple:
        """Split validated content into ``(analysis, verdict)``.

        The split is made at the last verdict marker so that an analysis which
        itself mentions "Validation:" is not cut short. Without a marker the
        whole text is returned for both parts.
        """
        analysis, marker, verdict = content.rpartition(ResearchInterface._VERDICT_MARKER)
        if not marker:
            return content, content
        return analysis, verdict

    @staticmethod
    def _validation_passed(content: str) -> bool:
        """Return True only when the verdict says VALID and does not say INVALID.

        A plain substring test for "VALID" is also true for "INVALID", so it
        reported every verdict as passed; whole-word matching fixes that.
        """
        _, verdict = ResearchInterface._split_validation(content)
        if re.search(r"\bINVALID\b", verdict):
            return False
        return re.search(r"\bVALID\b", verdict) is not None

    def _render_event(self, event: Dict):
        """Render one streamed workflow event, keyed by the node that produced it."""
        if 'ingest' in event:
            with st.container():
                st.success("✅ Query Ingested")
        elif 'retrieve' in event:
            with st.container():
                update = event['retrieve']
                docs = (update.get('context') or {}).get('documents')
                if docs is None:
                    # Error states carry no documents; show the node's message
                    # instead of failing on the missing key.
                    st.error(update['messages'][0].content)
                    return
                st.info(f"📚 Retrieved {len(docs)} documents")
                with st.expander("View Retrieved Documents", expanded=False):
                    for i, doc in enumerate(docs, 1):
                        st.markdown(f"**Document {i}**")
                        st.code(doc.page_content, language='text')
        elif 'analyze' in event:
            with st.container():
                content = event['analyze']['messages'][0].content
                with st.expander("Technical Analysis Report", expanded=True):
                    st.markdown(content)
        elif 'validate' in event:
            with st.container():
                content = event['validate']['messages'][0].content
                if self._validation_passed(content):
                    st.success("✅ Validation Passed")
                    with st.expander("View Validated Analysis", expanded=True):
                        st.markdown(self._split_validation(content)[0])
                else:
                    st.warning("⚠️ Validation Issues Detected")
                    with st.expander("View Validation Details", expanded=True):
                        st.markdown(content)
# Script entry point: constructing ResearchInterface injects the styles and
# draws the sidebar and the main page.
if __name__ == "__main__":
    ResearchInterface()