# ------------------------------
# Imports & Dependencies
# ------------------------------
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from typing_extensions import TypedDict, Annotated
from typing import Dict, List, Any
import chromadb
import os
import streamlit as st
import requests
import hashlib
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from pydantic import BaseModel, ValidationError
import traceback
# ------------------------------
# Configuration & Constants
# ------------------------------
class ResearchConfig:
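    """Central configuration for the pipeline.

    DEEPSEEK_API_KEY must be exported in the environment before launch;
    the remaining values control chunking, concurrency, and embedding size.
    """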
    DEEPSEEK_API_KEY = os.environ.get("DEEPSEEK_API_KEY")
    CHROMA_PATH = "chroma_db"
    CHUNK_SIZE = 512
    CHUNK_OVERLAP = 64
    MAX_CONCURRENT_REQUESTS = 5
    EMBEDDING_DIMENSIONS = 1536
    ANALYSIS_TEMPLATE = """**Technical Analysis Request**
{context}
Respond with:
1. Key Technical Innovations (markdown table)
2. Methodological Breakdown (bullet points)
3. Quantitative Results (LaTeX equations)
4. Critical Evaluation
5. Research Impact Assessment
Include proper academic citations where applicable."""
# ------------------------------
# Document Schema & Content
# ------------------------------
DOCUMENT_CONTENT = {
    "CV-Transformer Hybrid": {
        "content": """## Hybrid Architecture for Computer Vision
**Authors**: DeepVision Research Team
**Abstract**: Novel combination of convolutional layers with transformer attention mechanisms.
### Key Innovations:
- Cross-attention feature fusion
- Adaptive spatial pooling
- Multi-scale gradient propagation
$$\\mathcal{L}_{total} = \\alpha\\mathcal{L}_{CE} + \\beta\\mathcal{L}_{SSIM}$$""",
        "metadata": {
            "year": 2024,
            "domain": "computer_vision",
            "citations": 142
        }
    },
    "Quantum ML Advances": {
        "content": """## Quantum Machine Learning Breakthroughs
**Authors**: Quantum AI Lab
### Achievements:
- Quantum-enhanced SGD (40% faster convergence)
- 5-qubit QNN achieving 98% accuracy
- Hybrid quantum-classical GANs
$$\\mathcal{H} = -\\sum_{i<j} J_{ij}\\sigma_i^z\\sigma_j^z - \\Gamma\\sum_i\\sigma_i^x$$""",
        "metadata": {
            "year": 2023,
            "domain": "quantum_ml",
            "citations": 89
        }
    }
}

class DocumentSchema(BaseModel):
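    """Validated shape of a document before it is chunked and embedded."""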
    content: str
    metadata: dict
    doc_id: str
# ------------------------------
# State Management
# ------------------------------
class ResearchState(TypedDict):
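    """Shared state threaded through the LangGraph workflow.

    `add_messages` is LangGraph's reducer: messages returned by a node are
    appended to the history rather than replacing it.
    """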
    messages: Annotated[List[BaseMessage], add_messages]
    context: Annotated[Dict[str, Any], "research_context"]
    metadata: Annotated[Dict[str, str], "system_metadata"]
# ------------------------------
# Document Processing
# ------------------------------
class DocumentManager:
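    """Owns the persistent Chroma client and builds one collection per domain."""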
    def __init__(self):
        self.client = chromadb.PersistentClient(path=ResearchConfig.CHROMA_PATH)
        self.embeddings = OpenAIEmbeddings(
            model="text-embedding-3-large",
            dimensions=ResearchConfig.EMBEDDING_DIMENSIONS
        )
        # Predefine both collections so attribute access stays safe even if
        # initialize_collections() fails partway through.
        self.research_col = None
        self.dev_col = None

    def initialize_collections(self):
        try:
            self.research_col = self._create_collection("research")
            self.dev_col = self._create_collection("development")
        except Exception as e:
            st.error(f"Collection initialization failed: {str(e)}")
            traceback.print_exc()
    def _create_collection(self, name: str) -> Chroma:
        documents, metadatas = [], []
        for title, data in DOCUMENT_CONTENT.items():
            try:
                doc = DocumentSchema(
                    content=data["content"],
                    metadata=data["metadata"],
                    doc_id=hashlib.sha256(title.encode()).hexdigest()[:16]
                )
                documents.append(doc.content)
                # Store doc_id in metadata so it survives chunking and can be
                # recovered at retrieval time.
                metadatas.append({**doc.metadata, "doc_id": doc.doc_id})
            except ValidationError as e:
                st.error(f"Invalid document format: {title} - {str(e)}")
                continue
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=ResearchConfig.CHUNK_SIZE,
            chunk_overlap=ResearchConfig.CHUNK_OVERLAP,
            separators=["\n## ", "\n### ", "\n\n", "\n- "]
        )
        try:
            docs = splitter.create_documents(documents, metadatas=metadatas)
            # ids must match the number of chunks, not the number of source
            # documents, so derive one id per chunk from its content hash.
            chunk_ids = [
                f"{hashlib.sha256(d.page_content.encode()).hexdigest()[:16]}-{i}"
                for i, d in enumerate(docs)
            ]
            return Chroma.from_documents(
                docs,
                self.embeddings,
                client=self.client,
                collection_name=name,
                ids=chunk_ids
            )
        except Exception as e:
            raise RuntimeError(f"Failed creating {name} collection: {str(e)}")
# ------------------------------
# Retrieval System
# ------------------------------
class ResearchRetriever:
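    """Routes a query to the matching collection and runs MMR (maximal
    marginal relevance) retrieval, trading a little relevance for diversity."""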
    def __init__(self):
        self.dm = DocumentManager()
        self.dm.initialize_collections()

    def retrieve(self, query: str, domain: str) -> List[DocumentSchema]:
        try:
            collection = self.dm.research_col if domain == "research" else self.dm.dev_col
            if not collection:
                return []
            results = collection.as_retriever(
                search_type="mmr",
                search_kwargs={'k': 4, 'fetch_k': 20}
            ).invoke(query)
            return [DocumentSchema(
                content=doc.page_content,
                metadata=doc.metadata,
                doc_id=doc.metadata.get("doc_id", "")
            ) for doc in results if doc.page_content]
        except Exception as e:
            st.error(f"Retrieval failure: {str(e)}")
            traceback.print_exc()
            return []
# ------------------------------
# Analysis Processor
# ------------------------------
class AnalysisEngine:
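    """Sends the analysis prompt to the DeepSeek chat API.

    Three identical requests run concurrently and the longest successful
    completion wins, a cheap redundancy measure against transient failures
    and truncated responses.
    """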
    def __init__(self):
        self.executor = ThreadPoolExecutor(max_workers=ResearchConfig.MAX_CONCURRENT_REQUESTS)
        self.session_hash = hashlib.sha256(str(time.time()).encode()).hexdigest()[:12]

    def analyze(self, prompt: str) -> Dict:
        # Best-of-three: submit the same prompt three times and keep the
        # longest valid completion.
        futures = [self.executor.submit(self._api_request, prompt) for _ in range(3)]
        return self._validate_results([f.result() for f in as_completed(futures)])

    def _api_request(self, prompt: str) -> Dict:
        if not ResearchConfig.DEEPSEEK_API_KEY:
            return {"error": "DEEPSEEK_API_KEY is not set", "status_code": 401}
        headers = {
            "Authorization": f"Bearer {ResearchConfig.DEEPSEEK_API_KEY}",
            "X-Session-ID": self.session_hash,
            "Content-Type": "application/json"
        }
        try:
            response = requests.post(
                "https://api.deepseek.com/v1/chat/completions",
                headers=headers,
                json={
                    "model": "deepseek-chat",
                    "messages": [{"role": "user", "content": prompt}],
                    "temperature": 0.7,
                    "max_tokens": 2000
                },
                timeout=30
            )
            response.raise_for_status()
            return response.json()
        except Exception as e:
            return {"error": str(e), "status_code": 500}

    def _validate_results(self, results: List[Dict]) -> Dict:
        valid = [r for r in results if "error" not in r]
        if not valid:
            return {"error": "All analysis attempts failed", "results": results}
        # Prefer the response with the longest message content.
        best = max(
            valid,
            key=lambda x: len(x.get('choices', [{}])[0].get('message', {}).get('content', ''))
        )
        return best
# ------------------------------
# Workflow Implementation
# ------------------------------
class ResearchWorkflow:
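    """LangGraph pipeline: ingest -> retrieve -> analyze, then a quality gate
    routes either to validate (terminal) or to refine, which loops back to
    retrieval for another pass."""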
    def __init__(self):
        self.retriever = ResearchRetriever()
        self.engine = AnalysisEngine()
        self.workflow = StateGraph(ResearchState)
        self._build_graph()

    def _build_graph(self):
        self.workflow.add_node("ingest", self._ingest)
        self.workflow.add_node("retrieve", self._retrieve)
        self.workflow.add_node("analyze", self._analyze)
        self.workflow.add_node("validate", self._validate)
        self.workflow.add_node("refine", self._refine)
        self.workflow.set_entry_point("ingest")
        self.workflow.add_edge("ingest", "retrieve")
        self.workflow.add_edge("retrieve", "analyze")
        self.workflow.add_conditional_edges(
            "analyze",
            self._quality_gate,
            {"valid": "validate", "invalid": "refine"}
        )
        self.workflow.add_edge("validate", END)
        self.workflow.add_edge("refine", "retrieve")
    def _ingest(self, state: ResearchState) -> ResearchState:
        try:
            query = next(msg.content for msg in reversed(state["messages"])
                         if isinstance(msg, HumanMessage))
            return {
                "messages": [AIMessage(content="Query ingested")],
                "context": {
                    "query": query,
                    "documents": [],
                    "errors": []
                },
                "metadata": {
                    "session_id": hashlib.sha256(str(time.time()).encode()).hexdigest()[:8],
                    "timestamp": datetime.now().isoformat()
                }
            }
        except Exception as e:
            return self._handle_error(f"Ingest failed: {str(e)}", state)
    def _retrieve(self, state: ResearchState) -> ResearchState:
        try:
            docs = self.retriever.retrieve(state["context"]["query"], "research")
            return {
                "messages": [AIMessage(content=f"Retrieved {len(docs)} documents")],
                "context": {
                    **state["context"],
                    "documents": docs,
                    "retrieval_time": time.time()
                },
                "metadata": state["metadata"]
            }
        except Exception as e:
            return self._handle_error(f"Retrieval error: {str(e)}", state)
    def _analyze(self, state: ResearchState) -> ResearchState:
        docs = state["context"].get("documents", [])
        if not docs:
            return self._handle_error("No documents for analysis", state)
        try:
            context = "\n\n".join([d.content for d in docs])
            prompt = ResearchConfig.ANALYSIS_TEMPLATE.format(context=context)
            result = self.engine.analyze(prompt)
            if "error" in result:
                raise RuntimeError(result["error"])
            content = result['choices'][0]['message']['content']
            if len(content) < 200 or not any(c.isalpha() for c in content):
                raise ValueError("Insufficient analysis content")
            return {
                "messages": [AIMessage(content=content)],
                "context": state["context"],
                "metadata": state["metadata"]
            }
        except Exception as e:
            return self._handle_error(f"Analysis failed: {str(e)}", state)
    def _validate(self, state: ResearchState) -> ResearchState:
        # Placeholder: the quality gate has already accepted the analysis.
        return state

    def _refine(self, state: ResearchState) -> ResearchState:
        # Placeholder: a refinement pass could rewrite the query here before
        # the graph loops back to retrieval.
        return state

    def _quality_gate(self, state: ResearchState) -> str:
        content = state["messages"][-1].content if state["messages"] else ""
        required = ["Innovations", "Results", "Evaluation"]
        return "valid" if all(kw in content for kw in required) else "invalid"
    def _handle_error(self, message: str, state: ResearchState) -> ResearchState:
        return {
            "messages": [AIMessage(content=f"🚨 Error: {message}")],
            "context": {
                **state.get("context", {}),
                "errors": state.get("context", {}).get("errors", []) + [message]
            },
            "metadata": state.get("metadata", {})
        }
# ------------------------------
# User Interface
# ------------------------------
class ResearchInterface:
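    """Streamlit front end: dark theme, a sidebar document browser, and a
    main panel that streams workflow progress and renders the final report."""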
    def __init__(self):
        self.workflow = ResearchWorkflow().workflow.compile()
        self._setup_interface()

    def _setup_interface(self):
        st.set_page_config(
            page_title="Research Assistant",
            layout="wide",
            initial_sidebar_state="expanded"
        )
        self._apply_styles()
        self._build_sidebar()
        self._build_main()

    def _apply_styles(self):
        st.markdown("""
        <style>
        .stApp {
            background: #0a192f;
            color: #64ffda;
        }
        .stTextArea textarea {
            background: #172a45 !important;
            color: #a8b2d1 !important;
        }
        .stButton>button {
            background: #233554;
            border: 1px solid #64ffda;
        }
        .error-box {
            border: 1px solid #ff4444;
            border-radius: 5px;
            padding: 1rem;
            margin: 1rem 0;
        }
        </style>
        """, unsafe_allow_html=True)

    def _build_sidebar(self):
        with st.sidebar:
            st.title("📚 Document Database")
            for title, data in DOCUMENT_CONTENT.items():
                with st.expander(title[:25] + "..."):
                    st.markdown(f"```\n{data['content'][:300]}...\n```")

    def _build_main(self):
        st.title("🧠 Research Analysis System")
        query = st.text_area("Enter your research query:", height=150)
        if st.button("Start Analysis", type="primary"):
            self._run_analysis(query)
    def _run_analysis(self, query: str):
        try:
            with st.spinner("🔍 Analyzing documents..."):
                state = {
                    "messages": [HumanMessage(content=query)],
                    "context": {
                        "query": "",
                        "documents": [],
                        "errors": []
                    },
                    "metadata": {}
                }
                # Stream once and keep the last node output as the final
                # state, rather than running the entire graph a second time
                # with a separate invoke() call.
                final_state = state
                for event in self.workflow.stream(state):
                    self._display_progress(event)
                    final_state = next(iter(event.values()))
                self._show_results(final_state)
        except Exception as e:
            st.error(f"""**Analysis Failed**
{str(e)}
Common solutions:
- Simplify your query
- Check document database status
- Verify API connectivity""")
    def _display_progress(self, event):
        current_state = next(iter(event.values()))
        with st.container():
            st.markdown("---")
            cols = st.columns([1, 2, 1])
            with cols[0]:
                st.subheader("Processing Stage")
                stage = list(event.keys())[0].title()
                st.code(stage)
            with cols[1]:
                st.subheader("Documents")
                docs = current_state["context"].get("documents", [])
                st.metric("Retrieved", len(docs))
            with cols[2]:
                st.subheader("Status")
                if current_state["context"].get("errors"):
                    st.error("Errors detected")
                else:
                    st.success("Normal operation")

    def _show_results(self, state: ResearchState):
        if state["context"].get("errors"):
            st.error("Analysis completed with errors")
            with st.expander("Error Details"):
                for error in state["context"]["errors"]:
                    st.markdown(f"- {error}")
        else:
            st.success("Analysis completed successfully ✅")
            with st.expander("Full Report"):
                st.markdown(state["messages"][-1].content)
if __name__ == "__main__":
    ResearchInterface()
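
# To launch (a sketch, assuming this file is saved as app.py and both
# DEEPSEEK_API_KEY and OPENAI_API_KEY, needed by OpenAIEmbeddings, are
# exported in the environment):
#   streamlit run app.py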