# ------------------------------
# Imports & Dependencies
# ------------------------------
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_core.messages import HumanMessage, AIMessage
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages  # reducer used in AgentState below
from typing_extensions import TypedDict, Annotated
from typing import Sequence, Dict, List, Any
import chromadb
from chromadb.config import Settings
import numpy as np
import os
import streamlit as st
import requests
import hashlib
import re
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
# ------------------------------
# State Schema Definition
# ------------------------------
class AgentState(TypedDict):
    messages: Annotated[Sequence[AIMessage | HumanMessage], add_messages]
    context: Dict[str, Any]
    metadata: Dict[str, Any]
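# `add_messages` is LangGraph's append reducer: messages returned by a node are
# appended to the channel rather than overwriting it. `context` and `metadata`
# carry no reducer, so a node's partial return replaces them wholesale — nodes
# below therefore merge the previous context explicitly when they need to keep
# earlier keys (e.g. the original query).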
# ------------------------------
# Configuration
# ------------------------------
class ResearchConfig:
    DEEPSEEK_API_KEY = os.environ.get("DEEPSEEK_API_KEY")
    CHROMA_PATH = "chroma_db"
    CHUNK_SIZE = 512
    CHUNK_OVERLAP = 64
    MAX_CONCURRENT_REQUESTS = 5
    EMBEDDING_DIMENSIONS = 1536
    TENANT = "research_tenant"
    DATABASE = "ai_papers_db"
    # Raw strings keep the LaTeX escapes (\text, \frac, \sqrt) intact; in a
    # plain string, "\t" and "\f" would be consumed as escape sequences.
    DOCUMENT_MAP = {
        "CV-Transformer Hybrid Architecture": {
            "title": "Hybrid CV-Transformer Model (98% Accuracy)",
            "content": r"""
Combines CNN feature extraction with transformer attention mechanisms.
Key equation: $f(x) = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})V$
ImageNet-1k: 98.2% Top-1 Accuracy, 42ms/inference
"""
        },
        "Transformer Architecture Analysis": {
            "title": "Transformer Architectures in NLP",
            "content": r"""
Self-attention mechanisms enable parallel processing of sequences.
$\text{Attention}(Q, K, V) = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})V$
GLUE Score: 92.4%, Training Efficiency: 1.8x vs RNNs
"""
        }
    }
ANALYSIS_TEMPLATE = """Analyze these technical documents: | |
{context} | |
Respond in MARKDOWN with: | |
1. **Key Innovations** (mathematical formulations) | |
2. **Methodologies** (algorithms & architectures) | |
3. **Empirical Results** (comparative metrics) | |
4. **Applications** (industry use cases) | |
5. **Limitations** (theoretical boundaries) | |
Include LaTeX equations where applicable.""" | |
if not ResearchConfig.DEEPSEEK_API_KEY: | |
st.error("""**Configuration Required** | |
1. Get DeepSeek API key: [platform.deepseek.com](https://platform.deepseek.com/) | |
2. Set secret: `DEEPSEEK_API_KEY` | |
3. Rebuild deployment""") | |
st.stop() | |
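# Local-run sketch (the filename is hypothetical): export the key before
# launching the app, e.g.
#   export DEEPSEEK_API_KEY=sk-...
#   streamlit run app.py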
# ------------------------------
# ChromaDB Document Manager (Fixed)
# ------------------------------
class QuantumDocumentManager:
    def __init__(self):
        # chromadb >= 0.4 removed the legacy Settings(chroma_db_impl=...) API;
        # persistence is now configured via is_persistent + PersistentClient.
        self.client_settings = Settings(
            is_persistent=True,
            persist_directory=ResearchConfig.CHROMA_PATH,
            anonymized_telemetry=False
        )
        self._initialize_tenant_db()
        # Bind the client to the tenant/database up front so downstream
        # collections land in the right namespace.
        self.client = chromadb.PersistentClient(
            path=ResearchConfig.CHROMA_PATH,
            settings=self.client_settings,
            tenant=ResearchConfig.TENANT,
            database=ResearchConfig.DATABASE
        )
        self.embeddings = OpenAIEmbeddings(
            model="text-embedding-3-large",
            dimensions=ResearchConfig.EMBEDDING_DIMENSIONS
        )

    def _initialize_tenant_db(self):
        # Tenants and databases are provisioned through AdminClient in current
        # chromadb; "get then create" avoids depending on a version-specific
        # UniqueConstraintError import path.
        admin = chromadb.AdminClient(self.client_settings)
        try:
            admin.get_tenant(ResearchConfig.TENANT)
        except Exception:
            admin.create_tenant(ResearchConfig.TENANT)
        try:
            admin.get_database(ResearchConfig.DATABASE, tenant=ResearchConfig.TENANT)
        except Exception:
            admin.create_database(ResearchConfig.DATABASE, tenant=ResearchConfig.TENANT)
    def create_collection(self, document_map: Dict[str, Dict[str, str]], collection_name: str) -> Chroma:
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=ResearchConfig.CHUNK_SIZE,
            chunk_overlap=ResearchConfig.CHUNK_OVERLAP,
            separators=["\n\n", "\n", "|||"]
        )
        docs = []
        for key, data in document_map.items():
            chunks = splitter.split_text(data["content"])
            for chunk in chunks:
                docs.append(Document(
                    page_content=chunk,
                    metadata={
                        "title": data["title"],
                        "source": collection_name,
                        "hash": hashlib.sha256(chunk.encode()).hexdigest()[:16]
                    }
                ))
        # The langchain Chroma wrapper takes no tenant/database kwargs; the
        # tenant-bound client above already scopes the collection.
        return Chroma.from_documents(
            documents=docs,
            embedding=self.embeddings,
            collection_name=collection_name,
            client=self.client,
            collection_metadata={"hnsw:space": "cosine"},
            ids=[self._document_id(doc.page_content) for doc in docs]
        )

    def _document_id(self, content: str) -> str:
        # Content hash plus a timestamp suffix keeps re-ingested chunks distinct.
        return f"{hashlib.sha256(content.encode()).hexdigest()[:16]}-{int(time.time())}"
# Initialize document system
qdm = QuantumDocumentManager()
research_docs = qdm.create_collection(ResearchConfig.DOCUMENT_MAP, "research_papers")
# ------------------------------
# Retrieval System
# ------------------------------
class ResearchRetriever:
    def __init__(self):
        # MMR balances relevance against diversity; lambda_mult=0.85 leans
        # heavily toward relevance (1.0 = pure relevance, 0.0 = max diversity).
        self.retriever = research_docs.as_retriever(
            search_type="mmr",
            search_kwargs={
                'k': 4,
                'fetch_k': 20,
                'lambda_mult': 0.85
            }
        )

    def retrieve(self, query: str) -> List[Document]:
        try:
            docs = self.retriever.invoke(query)
            if not docs:
                raise ValueError("No relevant documents found")
            return docs
        except Exception as e:
            st.error(f"Retrieval Error: {str(e)}")
            return []
# ------------------------------
# Analysis Processor
# ------------------------------
class CognitiveProcessor:
    def __init__(self):
        self.executor = ThreadPoolExecutor(max_workers=ResearchConfig.MAX_CONCURRENT_REQUESTS)

    def process_query(self, prompt: str) -> Dict:
        # Fire three identical requests in parallel and keep the strongest
        # response; this trades tokens for robustness to flaky completions.
        futures = [self.executor.submit(self._api_request, prompt) for _ in range(3)]
        return self._best_result([f.result() for f in as_completed(futures)])

    def _api_request(self, prompt: str) -> Dict:
        headers = {
            "Authorization": f"Bearer {ResearchConfig.DEEPSEEK_API_KEY}",
            "Content-Type": "application/json"
        }
        try:
            response = requests.post(
                "https://api.deepseek.com/v1/chat/completions",
                headers=headers,
                json={
                    "model": "deepseek-chat",
                    "messages": [{
                        "role": "user",
                        "content": f"Respond as Senior AI Researcher:\n{prompt}"
                    }],
                    "temperature": 0.7,
                    "max_tokens": 1500,
                    "top_p": 0.9
                },
                timeout=45
            )
            response.raise_for_status()
            return response.json()
        except Exception as e:
            return {"error": str(e)}

    def _best_result(self, results: List[Dict]) -> Dict:
        valid = [r for r in results if "error" not in r]
        if not valid:
            return {"error": "All API requests failed"}
        # Prefer the response with the most inline LaTeX spans as a cheap
        # proxy for technical density.
        contents = [r.get('choices', [{}])[0].get('message', {}).get('content', '') for r in valid]
        tech_scores = [len(re.findall(r"\$.*?\$", c)) for c in contents]
        return valid[int(np.argmax(tech_scores))]
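# Illustration of the selection heuristic (hypothetical payloads; never called
# by the app): of two otherwise-valid responses, the one with more inline
# $...$ spans wins.
def _demo_best_result() -> None:
    plain = {"choices": [{"message": {"content": "No math here."}}]}
    mathy = {"choices": [{"message": {"content": "Uses $E = mc^2$, and again: $E = mc^2$."}}]}
    assert CognitiveProcessor()._best_result([plain, mathy]) is mathy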
# ------------------------------
# Workflow Engine
# ------------------------------
class ResearchWorkflow:
    def __init__(self):
        self.retriever = ResearchRetriever()
        self.processor = CognitiveProcessor()
        self.workflow = StateGraph(AgentState)
        self._build_workflow()

    def _build_workflow(self):
        self.workflow.add_node("ingest", self.ingest)
        self.workflow.add_node("retrieve", self.retrieve)
        self.workflow.add_node("analyze", self.analyze)
        self.workflow.add_node("validate", self.validate)
        self.workflow.add_node("refine", self.refine)
        self.workflow.set_entry_point("ingest")
        self.workflow.add_edge("ingest", "retrieve")
        self.workflow.add_edge("retrieve", "analyze")
        self.workflow.add_conditional_edges(
            "analyze",
            self._quality_check,
            {"valid": "validate", "invalid": "refine"}
        )
        self.workflow.add_edge("validate", END)
        self.workflow.add_edge("refine", "retrieve")
        self.app = self.workflow.compile()
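    # Graph shape, for orientation:
    #
    #   ingest -> retrieve -> analyze --(valid)--> validate -> END
    #                ^           |
    #                |       (invalid)
    #                +-- refine <+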
    def ingest(self, state: AgentState) -> Dict:
        try:
            query = state["messages"][-1].content
            return {
                "messages": [AIMessage(content="Query ingested")],
                "context": {"query": query},
                "metadata": {"timestamp": datetime.now().isoformat()}
            }
        except Exception as e:
            return self._error_state(f"Ingestion Error: {str(e)}")

    def retrieve(self, state: AgentState) -> Dict:
        try:
            docs = self.retriever.retrieve(state["context"]["query"])
            # Merge rather than replace: `context` has no reducer, so the
            # query must be carried forward for the refine -> retrieve loop.
            return {
                "messages": [AIMessage(content=f"Found {len(docs)} relevant papers")],
                "context": {**state["context"], "docs": docs}
            }
        except Exception as e:
            return self._error_state(f"Retrieval Error: {str(e)}")
    def analyze(self, state: AgentState) -> Dict:
        try:
            context = "\n\n".join([
                f"### {doc.metadata['title']}\n{doc.page_content}"
                for doc in state["context"]["docs"]
            ])
            prompt = ResearchConfig.ANALYSIS_TEMPLATE.format(context=context)
            response = self.processor.process_query(prompt)
            if "error" in response:
                raise RuntimeError(response["error"])
            content = response['choices'][0]['message']['content']
            self._validate_analysis(content)
            return {"messages": [AIMessage(content=content)]}
        except Exception as e:
            return self._error_state(f"Analysis Error: {str(e)}")
    def validate(self, state: AgentState) -> Dict:
        validation_prompt = f"""Validate this technical analysis:
{state["messages"][-1].content}
Check for:
1. Mathematical accuracy
2. Technical depth
3. Logical consistency
Respond with 'VALID' or 'INVALID'"""
        response = self.processor.process_query(validation_prompt)
        verdict = response.get('choices', [{}])[0].get('message', {}).get('content', '')
        # A bare substring test would accept "INVALID" (it contains "VALID"),
        # so rule out the negative verdict first.
        valid = "INVALID" not in verdict and "VALID" in verdict
        return {
            "messages": [AIMessage(content=f"{state['messages'][-1].content}\n\nValidation: {'✅ Valid' if valid else '❌ Invalid'}")],
            # Keep the retrieved docs in context so the UI can show excerpts.
            "context": {**state["context"], "valid": valid}
        }
    def refine(self, state: AgentState) -> Dict:
        refinement_prompt = f"""Improve this analysis:
{state["messages"][-1].content}
Focus on:
1. Mathematical precision
2. Technical terminology
3. Empirical references"""
        response = self.processor.process_query(refinement_prompt)
        if "error" in response:
            return self._error_state(f"Refinement Error: {response['error']}")
        return {"messages": [AIMessage(content=response['choices'][0]['message']['content'])]}
    def _quality_check(self, state: AgentState) -> str:
        # Route on the outcome of analyze(): failures are flagged via
        # _error_state. (Checking context["valid"] here would always be False
        # before the validate node has run, trapping the graph in the refine
        # loop.)
        return "invalid" if state.get("context", {}).get("error") else "valid"

    def _validate_analysis(self, content: str):
        # The template requests numbered bold headings, so check for the
        # section names themselves rather than "## " markdown headers.
        required_sections = [
            "Key Innovations",
            "Methodologies",
            "Empirical Results",
            "Applications",
            "Limitations"
        ]
        missing = [s for s in required_sections if s not in content]
        if missing:
            raise ValueError(f"Missing sections: {', '.join(missing)}")
        if not re.search(r"\$.*?\$", content):
            raise ValueError("Analysis lacks mathematical notation")
    def _error_state(self, message: str) -> Dict:
        return {
            "messages": [AIMessage(content=f"❌ {message}")],
            "context": {"error": True},
            "metadata": {"status": "error"}
        }
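# Direct invocation sketch (hypothetical query text; bypasses the Streamlit UI
# and is handy for debugging the graph in a REPL):
# ResearchWorkflow().app.invoke(
#     {"messages": [HumanMessage(content="Compare the two architectures")]}
# )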
# ------------------------------
# Streamlit Interface
# ------------------------------
class ResearchInterface:
    def __init__(self):
        self.workflow = ResearchWorkflow()
        self._initialize()

    def _initialize(self):
        st.set_page_config(
            page_title="AI Research Assistant",
            layout="wide",
            initial_sidebar_state="expanded"
        )
        self._inject_styles()
        self._build_sidebar()
        self._build_main()

    def _inject_styles(self):
        st.markdown("""
<style>
:root {
    --primary: #2ecc71;
    --secondary: #3498db;
    --background: #0a0a0a;
}
.stApp {
    background: var(--background);
    color: white;
}
.stTextArea textarea {
    background: #1a1a1a !important;
    border: 2px solid var(--secondary) !important;
}
code {
    color: var(--primary);
    background: #002200;
    padding: 2px 4px;
}
</style>
""", unsafe_allow_html=True)
    def _build_sidebar(self):
        with st.sidebar:
            st.title("🔬 Research Corpus")
            for key, data in ResearchConfig.DOCUMENT_MAP.items():
                with st.expander(data["title"]):
                    st.markdown(f"```latex\n{data['content']}\n```")
            st.metric("Vector DB Size", len(research_docs.get()['ids']))

    def _build_main(self):
        st.title("🧠 AI Research Analyst")
        query = st.text_area("Research Query:", height=150,
                             placeholder="Enter technical question...")
        if st.button("Analyze", type="primary"):
            self._execute_analysis(query)
    def _execute_analysis(self, query: str):
        try:
            with st.spinner("Analyzing research corpus..."):
                result = self.workflow.app.invoke(
                    {"messages": [HumanMessage(content=query)]}
                )
            if result.get("context", {}).get("error"):
                # The error flag is a boolean; the human-readable message
                # lives in the final AIMessage.
                self._show_error(result["messages"][-1].content)
            else:
                self._display_result(result)
        except Exception as e:
            self._show_error(str(e))
    def _display_result(self, result):
        with st.expander("Technical Report", expanded=True):
            st.markdown(result["messages"][-1].content)
        with st.expander("Source Excerpts", expanded=False):
            for doc in result["context"].get("docs", []):
                st.markdown(f"**{doc.metadata['title']}**")
                st.code(doc.page_content, language='latex')

    def _show_error(self, message):
        st.error(f"""
⚠️ Analysis Failed

{message}

Mitigation Steps:
1. Simplify query complexity
2. Check document connections
3. Verify technical terms
""")


if __name__ == "__main__":
    ResearchInterface()