# ------------------------------
# Imports & Dependencies
# ------------------------------
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langgraph.graph import END, StateGraph
from typing_extensions import TypedDict, Annotated
from typing import Dict, List, Any
from langgraph.graph.message import add_messages
import chromadb
import os
import streamlit as st
import requests
import hashlib
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from pydantic import BaseModel, ValidationError
import traceback
# ------------------------------
# Configuration & Constants
# ------------------------------
class ResearchConfig:
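    """Central configuration for the analysis pipeline.

    DEEPSEEK_API_KEY must be exported in the environment before launch;
    without it every chat-completion request will fail authentication.
    """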
DEEPSEEK_API_KEY = os.environ.get("DEEPSEEK_API_KEY")
CHROMA_PATH = "chroma_db"
CHUNK_SIZE = 512
CHUNK_OVERLAP = 64
MAX_CONCURRENT_REQUESTS = 5
EMBEDDING_DIMENSIONS = 1536
ANALYSIS_TEMPLATE = """**Technical Analysis Request**
{context}
Respond with:
1. Key Technical Innovations (markdown table)
2. Methodological Breakdown (bullet points)
3. Quantitative Results (LaTeX equations)
4. Critical Evaluation
5. Research Impact Assessment
Include proper academic citations where applicable."""
# ------------------------------
# Document Schema & Content
# ------------------------------
DOCUMENT_CONTENT = {
"CV-Transformer Hybrid": {
"content": """## Hybrid Architecture for Computer Vision
**Authors**: DeepVision Research Team
**Abstract**: Novel combination of convolutional layers with transformer attention mechanisms.
### Key Innovations:
- Cross-attention feature fusion
- Adaptive spatial pooling
- Multi-scale gradient propagation
$$\\mathcal{L}_{total} = \\alpha\\mathcal{L}_{CE} + \\beta\\mathcal{L}_{SSIM}$$""",
"metadata": {
"year": 2024,
"domain": "computer_vision",
"citations": 142
}
},
"Quantum ML Advances": {
"content": """## Quantum Machine Learning Breakthroughs
**Authors**: Quantum AI Lab
### Achievements:
- Quantum-enhanced SGD (40% faster convergence)
- 5-qubit QNN achieving 98% accuracy
- Hybrid quantum-classical GANs
$$\\mathcal{H} = -\\sum_{i<j} J_{ij}\\sigma_i^z\\sigma_j^z - \\Gamma\\sum_i\\sigma_i^x$$""",
"metadata": {
"year": 2023,
"domain": "quantum_ml",
"citations": 89
}
}
}
class DocumentSchema(BaseModel):
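    """Validated wrapper around a single corpus entry."""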
content: str
metadata: dict
doc_id: str
# ------------------------------
# State Management
# ------------------------------
class ResearchState(TypedDict):
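    """Shared state threaded through the LangGraph nodes.

    The `add_messages` reducer appends new messages instead of overwriting
    the list, so each node only returns the messages it produced.
    """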
messages: Annotated[List[BaseMessage], add_messages]
context: Annotated[Dict[str, Any], "research_context"]
metadata: Annotated[Dict[str, str], "system_metadata"]
# ------------------------------
# Document Processing
# ------------------------------
class DocumentManager:
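    """Builds persistent Chroma collections from the in-memory corpus."""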
def __init__(self):
self.client = chromadb.PersistentClient(path=ResearchConfig.CHROMA_PATH)
self.embeddings = OpenAIEmbeddings(
model="text-embedding-3-large",
dimensions=ResearchConfig.EMBEDDING_DIMENSIONS
)
    def initialize_collections(self):
        # Default both handles to None so a failed build still leaves
        # well-defined attributes for the retriever to test against.
        self.research_col = None
        self.dev_col = None
        try:
            self.research_col = self._create_collection("research")
            self.dev_col = self._create_collection("development")
        except Exception as e:
            st.error(f"Collection initialization failed: {str(e)}")
            traceback.print_exc()
    def _create_collection(self, name: str) -> Chroma:
        documents, metadatas = [], []
        for title, data in DOCUMENT_CONTENT.items():
            try:
                doc = DocumentSchema(
                    content=data["content"],
                    metadata=data["metadata"],
                    doc_id=hashlib.sha256(title.encode()).hexdigest()[:16]
                )
                documents.append(doc.content)
                # Carry doc_id inside the metadata so retrieval can recover it
                # later; otherwise it lives only on the schema object.
                metadatas.append({**doc.metadata, "doc_id": doc.doc_id})
            except ValidationError as e:
                st.error(f"Invalid document format: {title} - {str(e)}")
                continue
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=ResearchConfig.CHUNK_SIZE,
            chunk_overlap=ResearchConfig.CHUNK_OVERLAP,
            separators=["\n## ", "\n### ", "\n\n", "\n• "]
        )
        try:
            docs = splitter.create_documents(documents, metadatas=metadatas)
            # One id per chunk: splitting can yield more chunks than source
            # documents, so the original per-document ids would not line up.
            chunk_ids = [f"{d.metadata['doc_id']}_{i}" for i, d in enumerate(docs)]
            return Chroma.from_documents(
                docs,
                self.embeddings,
                client=self.client,
                collection_name=name,
                ids=chunk_ids
            )
        except Exception as e:
            raise RuntimeError(f"Failed creating {name} collection: {str(e)}")
# ------------------------------
# Retrieval System
# ------------------------------
class ResearchRetriever:
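    """Domain-aware retrieval over the Chroma collections.

    MMR (maximal marginal relevance) search trades a little raw relevance
    for diversity, which helps when chunks of one document dominate.
    """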
def __init__(self):
self.dm = DocumentManager()
self.dm.initialize_collections()
def retrieve(self, query: str, domain: str) -> List[DocumentSchema]:
try:
collection = self.dm.research_col if domain == "research" else self.dm.dev_col
if not collection:
return []
results = collection.as_retriever(
search_type="mmr",
search_kwargs={'k': 4, 'fetch_k': 20}
).invoke(query)
return [DocumentSchema(
content=doc.page_content,
metadata=doc.metadata,
doc_id=doc.metadata.get("doc_id", "")
) for doc in results if doc.page_content]
except Exception as e:
st.error(f"Retrieval failure: {str(e)}")
traceback.print_exc()
return []
# ------------------------------
# Analysis Processor
# ------------------------------
class AnalysisEngine:
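    """Client for the DeepSeek chat-completions API with basic redundancy."""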
def __init__(self):
self.executor = ThreadPoolExecutor(max_workers=ResearchConfig.MAX_CONCURRENT_REQUESTS)
self.session_hash = hashlib.sha256(str(time.time()).encode()).hexdigest()[:12]
def analyze(self, prompt: str) -> Dict:
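        # Fire a few identical requests in parallel and keep the best reply;
        # one transient API failure then cannot sink the whole analysis.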
futures = [self.executor.submit(self._api_request, prompt) for _ in range(3)]
return self._validate_results([f.result() for f in as_completed(futures)])
def _api_request(self, prompt: str) -> Dict:
headers = {
"Authorization": f"Bearer {ResearchConfig.DEEPSEEK_API_KEY}",
"X-Session-ID": self.session_hash,
"Content-Type": "application/json"
}
try:
response = requests.post(
"https://api.deepseek.com/v1/chat/completions",
headers=headers,
json={
"model": "deepseek-chat",
"messages": [{"role": "user", "content": prompt}],
"temperature": 0.7,
"max_tokens": 2000
},
timeout=30
)
response.raise_for_status()
return response.json()
except Exception as e:
return {"error": str(e), "status_code": 500}
def _validate_results(self, results: List[Dict]) -> Dict:
valid = [r for r in results if "error" not in r]
if not valid:
return {"error": "All analysis attempts failed", "results": results}
        # Of the successful responses, keep the one with the longest message
        # content as a crude proxy for the most complete analysis.
        best = max(valid, key=lambda x: len(x.get('choices', [{}])[0].get('message', {}).get('content', '')))
return best
# ------------------------------
# Workflow Implementation
# ------------------------------
class ResearchWorkflow:
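    """LangGraph pipeline: ingest -> retrieve -> analyze -> (validate | refine).

    A failed quality gate routes through `refine` back to `retrieve`, so a
    weak analysis triggers another retrieval pass instead of a hard stop.
    """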
def __init__(self):
self.retriever = ResearchRetriever()
self.engine = AnalysisEngine()
self.workflow = StateGraph(ResearchState)
self._build_graph()
def _build_graph(self):
self.workflow.add_node("ingest", self._ingest)
self.workflow.add_node("retrieve", self._retrieve)
self.workflow.add_node("analyze", self._analyze)
self.workflow.add_node("validate", self._validate)
self.workflow.add_node("refine", self._refine)
self.workflow.set_entry_point("ingest")
self.workflow.add_edge("ingest", "retrieve")
self.workflow.add_edge("retrieve", "analyze")
self.workflow.add_conditional_edges(
"analyze",
self._quality_gate,
{"valid": "validate", "invalid": "refine"}
)
self.workflow.add_edge("validate", END)
self.workflow.add_edge("refine", "retrieve")
def _ingest(self, state: ResearchState) -> ResearchState:
try:
query = next(msg.content for msg in reversed(state["messages"])
if isinstance(msg, HumanMessage))
return {
"messages": [AIMessage(content="Query ingested")],
"context": {
"query": query,
"documents": [],
"errors": []
},
"metadata": {
"session_id": hashlib.sha256(str(time.time()).encode()).hexdigest()[:8],
"timestamp": datetime.now().isoformat()
}
}
except Exception as e:
return self._handle_error(f"Ingest failed: {str(e)}", state)
def _retrieve(self, state: ResearchState) -> ResearchState:
try:
docs = self.retriever.retrieve(state["context"]["query"], "research")
return {
"messages": [AIMessage(content=f"Retrieved {len(docs)} documents")],
"context": {
**state["context"],
"documents": docs,
"retrieval_time": time.time()
},
"metadata": state["metadata"]
}
except Exception as e:
return self._handle_error(f"Retrieval error: {str(e)}", state)
def _analyze(self, state: ResearchState) -> ResearchState:
docs = state["context"].get("documents", [])
if not docs:
return self._handle_error("No documents for analysis", state)
try:
context = "\n\n".join([d.content for d in docs])
prompt = ResearchConfig.ANALYSIS_TEMPLATE.format(context=context)
result = self.engine.analyze(prompt)
if "error" in result:
raise RuntimeError(result["error"])
content = result['choices'][0]['message']['content']
if len(content) < 200 or not any(c.isalpha() for c in content):
raise ValueError("Insufficient analysis content")
return {
"messages": [AIMessage(content=content)],
"context": state["context"],
"metadata": state["metadata"]
}
except Exception as e:
return self._handle_error(f"Analysis failed: {str(e)}", state)
    def _validate(self, state: ResearchState) -> ResearchState:
        return state
    def _refine(self, state: ResearchState) -> ResearchState:
        # Cap refinement passes so a weak analysis cannot cycle forever.
        count = state["context"].get("refine_count", 0) + 1
        return {
            "messages": [AIMessage(content=f"Refining analysis (attempt {count})")],
            "context": {**state["context"], "refine_count": count},
            "metadata": state["metadata"]
        }
    def _quality_gate(self, state: ResearchState) -> str:
        if state["context"].get("refine_count", 0) >= 3:
            return "valid"  # give up refining and surface the best effort
        content = state["messages"][-1].content if state["messages"] else ""
        required = ("Innovations", "Results", "Evaluation")
        return "valid" if all(kw in content for kw in required) else "invalid"
    def _handle_error(self, message: str, state: ResearchState) -> ResearchState:
        return {
            "messages": [AIMessage(content=f"🚨 Error: {message}")],
            "context": {
                **state["context"],
                "errors": state["context"].get("errors", []) + [message]
            },
            "metadata": state["metadata"]
        }
# ------------------------------
# User Interface
# ------------------------------
class ResearchInterface:
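    """Streamlit front end: a sidebar corpus browser plus the query view."""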
def __init__(self):
self.workflow = ResearchWorkflow().workflow.compile()
self._setup_interface()
def _setup_interface(self):
st.set_page_config(
page_title="Research Assistant",
layout="wide",
initial_sidebar_state="expanded"
)
self._apply_styles()
self._build_sidebar()
self._build_main()
def _apply_styles(self):
st.markdown("""
<style>
.stApp {
background: #0a192f;
color: #64ffda;
}
.stTextArea textarea {
background: #172a45 !important;
color: #a8b2d1 !important;
}
.stButton>button {
background: #233554;
border: 1px solid #64ffda;
}
.error-box {
border: 1px solid #ff4444;
border-radius: 5px;
padding: 1rem;
margin: 1rem 0;
}
</style>
""", unsafe_allow_html=True)
    def _build_sidebar(self):
        with st.sidebar:
            st.title("🔍 Document Database")
            for title, data in DOCUMENT_CONTENT.items():
                # Only truncate titles that actually exceed the width budget.
                label = title if len(title) <= 25 else title[:25] + "..."
                with st.expander(label):
                    st.markdown(f"```\n{data['content'][:300]}...\n```")
def _build_main(self):
st.title("🧠 Research Analysis System")
query = st.text_area("Enter your research query:", height=150)
        if st.button("Start Analysis", type="primary"):
            if query.strip():
                self._run_analysis(query)
            else:
                st.warning("Please enter a query first.")
    def _run_analysis(self, query: str):
        try:
            with st.spinner("🔍 Analyzing documents..."):
                state = {
                    "messages": [HumanMessage(content=query)],
                    "context": {
                        "query": "",
                        "documents": [],
                        "errors": []
                    },
                    "metadata": {}
                }
                # Stream the graph once and keep the last node update for the
                # final report; a separate invoke() here would silently re-run
                # the whole pipeline and double every API call.
                final_state = state
                for event in self.workflow.stream(state):
                    self._display_progress(event)
                    final_state = next(iter(event.values()))
                self._show_results(final_state)
except Exception as e:
st.error(f"""**Analysis Failed**
{str(e)}
Common solutions:
- Simplify your query
- Check document database status
- Verify API connectivity""")
    def _display_progress(self, event):
        # stream() yields one {node_name: state_update} mapping per step.
        stage, current_state = next(iter(event.items()))
        with st.container():
            st.markdown("---")
            cols = st.columns([1, 2, 1])
            with cols[0]:
                st.subheader("Processing Stage")
                st.code(stage.title())
            with cols[1]:
                st.subheader("Documents")
                docs = current_state.get("context", {}).get("documents", [])
                st.metric("Retrieved", len(docs))
            with cols[2]:
                st.subheader("Status")
                if current_state.get("context", {}).get("errors"):
                    st.error("Errors detected")
                else:
                    st.success("Normal operation")
def _show_results(self, state: ResearchState):
if state["context"].get("errors"):
st.error("Analysis completed with errors")
with st.expander("Error Details"):
for error in state["context"]["errors"]:
st.markdown(f"- {error}")
else:
st.success("Analysis completed successfully βœ…")
with st.expander("Full Report"):
st.markdown(state["messages"][-1].content)
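# Launch locally with `streamlit run app.py`. Both DEEPSEEK_API_KEY (chat)
# and OPENAI_API_KEY (read by OpenAIEmbeddings itself) must be exported.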
if __name__ == "__main__":
ResearchInterface()