# ------------------------------
# Imports & Dependencies
# ------------------------------
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from typing_extensions import TypedDict, Annotated
from typing import Dict, List, Any
import chromadb
import os
import streamlit as st
import requests
import hashlib
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from pydantic import BaseModel, ValidationError
import traceback
# ------------------------------
# Configuration & Constants
# ------------------------------
class ResearchConfig:
DEEPSEEK_API_KEY = os.environ.get("DEEPSEEK_API_KEY")
CHROMA_PATH = "chroma_db"
CHUNK_SIZE = 512
CHUNK_OVERLAP = 64
MAX_CONCURRENT_REQUESTS = 5
EMBEDDING_DIMENSIONS = 1536
ANALYSIS_TEMPLATE = """**Technical Analysis Request**
{context}
Respond with:
1. Key Technical Innovations (markdown table)
2. Methodological Breakdown (bullet points)
3. Quantitative Results (LaTeX equations)
4. Critical Evaluation
5. Research Impact Assessment
Include proper academic citations where applicable."""
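# ANALYSIS_TEMPLATE's {context} placeholder is filled with the concatenated
# retrieved chunks in ResearchWorkflow._analyze below.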
# ------------------------------
# Document Schema & Content
# ------------------------------
DOCUMENT_CONTENT = {
"CV-Transformer Hybrid": {
"content": """## Hybrid Architecture for Computer Vision
**Authors**: DeepVision Research Team
**Abstract**: Novel combination of convolutional layers with transformer attention mechanisms.
### Key Innovations:
- Cross-attention feature fusion
- Adaptive spatial pooling
- Multi-scale gradient propagation
$$\\mathcal{L}_{total} = \\alpha\\mathcal{L}_{CE} + \\beta\\mathcal{L}_{SSIM}$$""",
"metadata": {
"year": 2024,
"domain": "computer_vision",
"citations": 142
}
},
"Quantum ML Advances": {
"content": """## Quantum Machine Learning Breakthroughs
**Authors**: Quantum AI Lab
### Achievements:
- Quantum-enhanced SGD (40% faster convergence)
- 5-qubit QNN achieving 98% accuracy
- Hybrid quantum-classical GANs
$$\\mathcal{H} = -\\sum_{i<j} J_{ij}\\sigma_i^z\\sigma_j^z - \\Gamma\\sum_i\\sigma_i^x$$""",
"metadata": {
"year": 2023,
"domain": "quantum_ml",
"citations": 89
}
}
}
class DocumentSchema(BaseModel):
content: str
metadata: dict
doc_id: str
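# Each DOCUMENT_CONTENT entry is validated against DocumentSchema before
# indexing; doc_id is the first 16 hex chars of a SHA-256 hash of the title.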
# ------------------------------
# State Management
# ------------------------------
class ResearchState(TypedDict):
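    # add_messages is LangGraph's reducer: node updates append to this list
    # instead of replacing it, so history accumulates across workflow nodes.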
messages: Annotated[List[BaseMessage], add_messages]
context: Annotated[Dict[str, Any], "research_context"]
metadata: Annotated[Dict[str, str], "system_metadata"]
# ------------------------------
# Document Processing
# ------------------------------
class DocumentManager:
def __init__(self):
self.client = chromadb.PersistentClient(path=ResearchConfig.CHROMA_PATH)
self.embeddings = OpenAIEmbeddings(
model="text-embedding-3-large",
dimensions=ResearchConfig.EMBEDDING_DIMENSIONS
)
def initialize_collections(self):
try:
self.research_col = self._create_collection("research")
self.dev_col = self._create_collection("development")
except Exception as e:
st.error(f"Collection initialization failed: {str(e)}")
traceback.print_exc()
def _create_collection(self, name: str) -> Chroma:
        documents, metadatas = [], []
        for title, data in DOCUMENT_CONTENT.items():
            try:
                doc = DocumentSchema(
                    content=data["content"],
                    metadata=data["metadata"],
                    doc_id=hashlib.sha256(title.encode()).hexdigest()[:16]
                )
                documents.append(doc.content)
                # Carry doc_id inside the metadata so retrieval can recover it.
                metadatas.append({**doc.metadata, "doc_id": doc.doc_id})
            except ValidationError as e:
                st.error(f"Invalid document format: {title} - {str(e)}")
                continue
splitter = RecursiveCharacterTextSplitter(
chunk_size=ResearchConfig.CHUNK_SIZE,
chunk_overlap=ResearchConfig.CHUNK_OVERLAP,
separators=["\n## ", "\n### ", "\n\n", "\nβ€’ "]
)
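        # Chunks split preferentially at section headings, then blank lines,
        # then bullets, with CHUNK_OVERLAP characters of shared context.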
        try:
            docs = splitter.create_documents(documents, metadatas=metadatas)
            # Splitting produces more chunks than source documents, so the
            # one-id-per-document list cannot be passed through here;
            # let Chroma assign per-chunk ids instead.
            return Chroma.from_documents(
                docs,
                self.embeddings,
                client=self.client,
                collection_name=name
            )
        except Exception as e:
            raise RuntimeError(f"Failed creating {name} collection: {str(e)}") from e
# ------------------------------
# Retrieval System
# ------------------------------
class ResearchRetriever:
def __init__(self):
self.dm = DocumentManager()
self.dm.initialize_collections()
    def retrieve(self, query: str, domain: str) -> List[DocumentSchema]:
        try:
            # Collections may be absent if initialization failed; degrade to empty.
            collection = getattr(
                self.dm, "research_col" if domain == "research" else "dev_col", None
            )
            if collection is None:
                return []
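            # MMR (maximal marginal relevance) trades similarity against
            # diversity: fetch_k candidates are retrieved, then k are chosen
            # to reduce redundancy among the returned chunks.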
results = collection.as_retriever(
search_type="mmr",
search_kwargs={'k': 4, 'fetch_k': 20}
).invoke(query)
return [DocumentSchema(
content=doc.page_content,
metadata=doc.metadata,
doc_id=doc.metadata.get("doc_id", "")
) for doc in results if doc.page_content]
except Exception as e:
st.error(f"Retrieval failure: {str(e)}")
traceback.print_exc()
return []
# ------------------------------
# Analysis Processor
# ------------------------------
class AnalysisEngine:
def __init__(self):
self.executor = ThreadPoolExecutor(max_workers=ResearchConfig.MAX_CONCURRENT_REQUESTS)
self.session_hash = hashlib.sha256(str(time.time()).encode()).hexdigest()[:12]
def analyze(self, prompt: str) -> Dict:
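        # Best-of-3 sampling: fire three identical requests concurrently and
        # let _validate_results keep the strongest response.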
futures = [self.executor.submit(self._api_request, prompt) for _ in range(3)]
return self._validate_results([f.result() for f in as_completed(futures)])
    def _api_request(self, prompt: str) -> Dict:
        # Fail fast with a clear error when the API key is missing.
        if not ResearchConfig.DEEPSEEK_API_KEY:
            return {"error": "DEEPSEEK_API_KEY environment variable is not set", "status_code": 401}
        headers = {
            "Authorization": f"Bearer {ResearchConfig.DEEPSEEK_API_KEY}",
            "X-Session-ID": self.session_hash,
            "Content-Type": "application/json"
        }
try:
response = requests.post(
"https://api.deepseek.com/v1/chat/completions",
headers=headers,
json={
"model": "deepseek-chat",
"messages": [{"role": "user", "content": prompt}],
"temperature": 0.7,
"max_tokens": 2000
},
timeout=30
)
response.raise_for_status()
return response.json()
except Exception as e:
return {"error": str(e), "status_code": 500}
    def _validate_results(self, results: List[Dict]) -> Dict:
        valid = [r for r in results if "error" not in r]
        if not valid:
            return {"error": "All analysis attempts failed", "results": results}
        # Keep the candidate whose generated content is longest.
        return max(
            valid,
            key=lambda x: len(x.get('choices', [{}])[0].get('message', {}).get('content', ''))
        )
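# Expected response shape (OpenAI-compatible chat completion; a sketch, not
# the full payload):
#   {"choices": [{"message": {"role": "assistant", "content": "..."}}], ...}
# _validate_results and _analyze rely on choices[0].message.content.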
# ------------------------------
# Workflow Implementation
# ------------------------------
class ResearchWorkflow:
def __init__(self):
self.retriever = ResearchRetriever()
self.engine = AnalysisEngine()
self.workflow = StateGraph(ResearchState)
self._build_graph()
def _build_graph(self):
self.workflow.add_node("ingest", self._ingest)
self.workflow.add_node("retrieve", self._retrieve)
self.workflow.add_node("analyze", self._analyze)
self.workflow.add_node("validate", self._validate)
self.workflow.add_node("refine", self._refine)
self.workflow.set_entry_point("ingest")
self.workflow.add_edge("ingest", "retrieve")
self.workflow.add_edge("retrieve", "analyze")
self.workflow.add_conditional_edges(
"analyze",
self._quality_gate,
{"valid": "validate", "invalid": "refine"}
)
self.workflow.add_edge("validate", END)
self.workflow.add_edge("refine", "retrieve")
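        # Resulting topology:
        #   ingest -> retrieve -> analyze --(valid)--> validate -> END
        #                            \--(invalid)--> refine -> retrieve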
def _ingest(self, state: ResearchState) -> ResearchState:
try:
query = next(msg.content for msg in reversed(state["messages"])
if isinstance(msg, HumanMessage))
return {
"messages": [AIMessage(content="Query ingested")],
"context": {
"query": query,
"documents": [],
"errors": []
},
"metadata": {
"session_id": hashlib.sha256(str(time.time()).encode()).hexdigest()[:8],
"timestamp": datetime.now().isoformat()
}
}
except Exception as e:
return self._handle_error(f"Ingest failed: {str(e)}", state)
def _retrieve(self, state: ResearchState) -> ResearchState:
try:
docs = self.retriever.retrieve(state["context"]["query"], "research")
return {
"messages": [AIMessage(content=f"Retrieved {len(docs)} documents")],
"context": {
**state["context"],
"documents": docs,
"retrieval_time": time.time()
},
"metadata": state["metadata"]
}
except Exception as e:
return self._handle_error(f"Retrieval error: {str(e)}", state)
def _analyze(self, state: ResearchState) -> ResearchState:
docs = state["context"].get("documents", [])
if not docs:
return self._handle_error("No documents for analysis", state)
try:
context = "\n\n".join([d.content for d in docs])
prompt = ResearchConfig.ANALYSIS_TEMPLATE.format(context=context)
result = self.engine.analyze(prompt)
if "error" in result:
raise RuntimeError(result["error"])
content = result['choices'][0]['message']['content']
if len(content) < 200 or not any(c.isalpha() for c in content):
raise ValueError("Insufficient analysis content")
return {
"messages": [AIMessage(content=content)],
"context": state["context"],
"metadata": state["metadata"]
}
except Exception as e:
return self._handle_error(f"Analysis failed: {str(e)}", state)
    def _validate(self, state: ResearchState) -> ResearchState:
        # Pass-through: the quality gate has already vetted the content.
        return state
    def _refine(self, state: ResearchState) -> ResearchState:
        # Pass-through placeholder: loops back to retrieval for another
        # attempt; a fuller implementation would rewrite the query here.
        return state
def _quality_gate(self, state: ResearchState) -> str:
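        # Cheap structural check: the analysis template requests sections
        # containing these keywords, so their absence flags a weak response.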
content = state["messages"][-1].content if state["messages"] else ""
required = ["Innovations", "Results", "Evaluation"]
return "valid" if all(kw in content for kw in required) else "invalid"
    def _handle_error(self, message: str, state: ResearchState) -> ResearchState:
        return {
            "messages": [AIMessage(content=f"🚨 Error: {message}")],
            "context": {
                **state.get("context", {}),
                "errors": state.get("context", {}).get("errors", []) + [message]
            },
            "metadata": state.get("metadata", {})
        }
# ------------------------------
# User Interface
# ------------------------------
class ResearchInterface:
def __init__(self):
self.workflow = ResearchWorkflow().workflow.compile()
self._setup_interface()
def _setup_interface(self):
st.set_page_config(
page_title="Research Assistant",
layout="wide",
initial_sidebar_state="expanded"
)
self._apply_styles()
self._build_sidebar()
self._build_main()
def _apply_styles(self):
st.markdown("""
<style>
.stApp {
background: #0a192f;
color: #64ffda;
}
.stTextArea textarea {
background: #172a45 !important;
color: #a8b2d1 !important;
}
.stButton>button {
background: #233554;
border: 1px solid #64ffda;
}
.error-box {
border: 1px solid #ff4444;
border-radius: 5px;
padding: 1rem;
margin: 1rem 0;
}
</style>
""", unsafe_allow_html=True)
def _build_sidebar(self):
with st.sidebar:
st.title("πŸ” Document Database")
for title, data in DOCUMENT_CONTENT.items():
                with st.expander(title if len(title) <= 25 else title[:25] + "..."):
st.markdown(f"```\n{data['content'][:300]}...\n```")
def _build_main(self):
st.title("🧠 Research Analysis System")
query = st.text_area("Enter your research query:", height=150)
if st.button("Start Analysis", type="primary"):
self._run_analysis(query)
    def _run_analysis(self, query: str):
        try:
            with st.spinner("🔍 Analyzing documents..."):
                state = {
                    "messages": [HumanMessage(content=query)],
                    "context": {
                        "query": "",
                        "documents": [],
                        "errors": []
                    },
                    "metadata": {}
                }
                # Stream once and keep the latest node output; calling
                # invoke() afterwards would execute the whole graph twice.
                final_state = state
                for event in self.workflow.stream(state):
                    self._display_progress(event)
                    final_state = next(iter(event.values()))
                self._show_results(final_state)
except Exception as e:
st.error(f"""**Analysis Failed**
{str(e)}
Common solutions:
- Simplify your query
- Check document database status
- Verify API connectivity""")
def _display_progress(self, event):
current_state = next(iter(event.values()))
with st.container():
st.markdown("---")
cols = st.columns([1,2,1])
with cols[0]:
st.subheader("Processing Stage")
stage = list(event.keys())[0].title()
st.code(stage)
with cols[1]:
st.subheader("Documents")
docs = current_state["context"].get("documents", [])
st.metric("Retrieved", len(docs))
with cols[2]:
st.subheader("Status")
if current_state["context"].get("errors"):
st.error("Errors detected")
else:
st.success("Normal operation")
def _show_results(self, state: ResearchState):
if state["context"].get("errors"):
st.error("Analysis completed with errors")
with st.expander("Error Details"):
for error in state["context"]["errors"]:
st.markdown(f"- {error}")
else:
st.success("Analysis completed successfully βœ…")
with st.expander("Full Report"):
st.markdown(state["messages"][-1].content)
if __name__ == "__main__":
ResearchInterface()
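# To run locally (assumes Streamlit is installed and DEEPSEEK_API_KEY is set):
#   streamlit run app.py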