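"""
AI Research Assistant: a Streamlit app backed by a LangGraph workflow.
A DeepSeek chat model classifies each query, searches the research or
development Chroma collections (embedded with OpenAI embeddings) when needed,
and summarizes any retrieved documents into a final answer.
"""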
# ------------------------------
# Imports
# ------------------------------
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langgraph.graph import END, StateGraph
from langgraph.prebuilt import ToolNode
from langgraph.graph.message import add_messages
from typing_extensions import TypedDict, Annotated
from typing import Sequence, List, Dict, Any
import chromadb
import ast
import re
import os
import streamlit as st
import requests
import time
import hashlib
from langchain.tools.retriever import create_retriever_tool
from datetime import datetime
# ------------------------------
# Data
# ------------------------------
research_texts = [
    "Research Report: Results of a New AI Model Improving Image Recognition Accuracy to 98%",
    "Academic Paper Summary: Why Transformers Became the Mainstream Architecture in Natural Language Processing",
    "Latest Trends in Machine Learning Methods Using Quantum Computing",
    "Advancements in Neuromorphic Computing for Energy-Efficient AI Systems",
    "Cross-Modal Learning: Integrating Visual and Textual Representations for Multimodal AI"
]

development_texts = [
    "Project A: UI Design Completed, API Integration in Progress",
    "Project B: Testing New Feature X, Bug Fixes Needed",
    "Product Y: In the Performance Optimization Stage Before Release",
    "Framework Z: Version 3.2 Released with Enhanced Distributed Training Support",
    "DevOps Pipeline: Automated CI/CD Implementation for ML Model Deployment"
]
# ------------------------------
# Configuration
# ------------------------------
class AppConfig:
    def __init__(self):
        self.DEEPSEEK_API_KEY = os.environ.get("DEEPSEEK_API_KEY")
        self.CHROMA_PATH = "chroma_db"
        self.MAX_RETRIES = 3
        self.RETRY_DELAY = 1.5
        self.DOCUMENT_CHUNK_SIZE = 300
        self.DOCUMENT_OVERLAP = 50
        self.SEARCH_K = 5
        self.SEARCH_TYPE = "mmr"

    def validate(self):
        if not self.DEEPSEEK_API_KEY:
            st.error("""
            **Configuration Error**
            Missing DeepSeek API key.
            Configure through Hugging Face Space secrets:
            1. Space Settings → Repository secrets
            2. Add secret: DEEPSEEK_API_KEY=your_key
            3. Rebuild Space
            """)
            st.stop()
# ------------------------------
# Chroma Setup
# ------------------------------
class ChromaManager:
    def __init__(self, config: AppConfig):
        os.makedirs(config.CHROMA_PATH, exist_ok=True)
        self.config = config
        self.client = chromadb.PersistentClient(path=config.CHROMA_PATH)
        # OpenAIEmbeddings reads OPENAI_API_KEY from the environment
        self.embeddings = OpenAIEmbeddings(model="text-embedding-3-large")
        self.research_collection = self._create_collection(
            research_texts,
            "research_collection",
            {"category": "research"}
        )
        self.dev_collection = self._create_collection(
            development_texts,
            "development_collection",
            {"category": "development"}
        )

    def _create_collection(self, documents: List[str], name: str, metadata: dict) -> Chroma:
        # Use the chunking parameters defined in AppConfig rather than hardcoded values
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.config.DOCUMENT_CHUNK_SIZE,
            chunk_overlap=self.config.DOCUMENT_OVERLAP,
            separators=["\n\n", "\n", ". "]
        )
        docs = splitter.create_documents(documents)
        return Chroma.from_documents(
            documents=docs,
            embedding=self.embeddings,
            client=self.client,
            collection_name=name,
            collection_metadata=metadata
        )
# ------------------------------
# Document Processing
# ------------------------------
class DocumentProcessor:
    @staticmethod
    def deduplicate(texts: List[str]) -> List[str]:
        # Drop duplicate snippets (by MD5 hash) while preserving order
        seen = set()
        unique = []
        for text in texts:
            digest = hashlib.md5(text.encode()).hexdigest()
            if digest not in seen:
                seen.add(digest)
                unique.append(text)
        return unique

    @staticmethod
    def extract_keypoints(texts: List[str]) -> str:
        # Map keyword hits in the snippets to one-line key points
        categories = {
            "quantum": ["quantum", "qubit"],
            "vision": ["image", "recognition"],
            "nlp": ["transformer", "language"]
        }
        labels = {
            "quantum": "Quantum computing breakthroughs",
            "vision": "Computer vision advancements",
            "nlp": "NLP architecture innovations"
        }
        return "\n".join(sorted({
            "- " + labels[cat]
            for text in texts
            for cat, kw in categories.items()
            if any(k in text.lower() for k in kw)
        }))
# ------------------------------
# Workflow State
# ------------------------------
class AgentState(TypedDict):
    messages: Annotated[Sequence[AIMessage | HumanMessage | ToolMessage], add_messages]
# ------------------------------
# Workflow Setup
# ------------------------------
class AgentWorkflow:
    def __init__(self, chroma: ChromaManager, config: AppConfig):
        self.chroma = chroma
        self.config = config
        self.workflow = StateGraph(AgentState)

        # Define nodes
        self.workflow.add_node("agent", self.agent)
        self.workflow.add_node("retrieve", ToolNode([
            create_retriever_tool(
                chroma.research_collection.as_retriever(),
                "research_tool",
                "Search research documents"
            ),
            create_retriever_tool(
                chroma.dev_collection.as_retriever(),
                "dev_tool",
                "Search development updates"
            )
        ]))
        self.workflow.add_node("generate", self.generate)
        self.workflow.add_node("rewrite", self.rewrite)

        # Define edges
        self.workflow.set_entry_point("agent")
        self.workflow.add_conditional_edges(
            "agent",
            self._tools_condition,
            {"retrieve": "retrieve", "end": END}
        )
        self.workflow.add_conditional_edges(
            "retrieve",
            self._grade_documents,
            {"generate": "generate", "rewrite": "rewrite"}
        )
        self.workflow.add_edge("generate", END)
        self.workflow.add_edge("rewrite", "agent")
        self.app = self.workflow.compile()
    def agent(self, state: AgentState):
        try:
            messages = state["messages"]
            # Works for the original HumanMessage and for a rewritten query (AIMessage)
            last = messages[-1]
            query = last.content if hasattr(last, "content") else last["content"]
            response = requests.post(
                "https://api.deepseek.com/v1/chat/completions",
                headers={"Authorization": f"Bearer {self.config.DEEPSEEK_API_KEY}"},
                json={
                    "model": "deepseek-chat",
                    "messages": [{
                        "role": "user",
                        "content": f"""Analyze this query: "{query}"
                        Respond EXACTLY as:
                        - SEARCH_RESEARCH: <terms> (for research topics)
                        - SEARCH_DEV: <terms> (for development updates)
                        - DIRECT: <answer> (otherwise)"""
                    }]
                },
                timeout=30
            ).json()
            content = response['choices'][0]['message']['content']
            if "SEARCH_RESEARCH:" in content:
                terms = content.split("SEARCH_RESEARCH:")[1].strip()
                results = self.chroma.research_collection.similarity_search(terms)
                # Serialize only the page contents so downstream nodes can parse them safely
                return {"messages": [AIMessage(content=f"Research Results: {[d.page_content for d in results]!r}")]}
            elif "SEARCH_DEV:" in content:
                terms = content.split("SEARCH_DEV:")[1].strip()
                results = self.chroma.dev_collection.similarity_search(terms)
                return {"messages": [AIMessage(content=f"Development Results: {[d.page_content for d in results]!r}")]}
            return {"messages": [AIMessage(content=content)]}
        except Exception as e:
            return {"messages": [AIMessage(content=f"Error: {str(e)}")]}
    def generate(self, state: AgentState):
        # Parse the list of page contents emitted by the agent node
        raw = state["messages"][-1].content.split("Results: ", 1)[1]
        doc_texts = ast.literal_eval(raw)
        processed = "\n".join(t[:200] for t in DocumentProcessor.deduplicate(doc_texts))
        response = requests.post(
            "https://api.deepseek.com/v1/chat/completions",
            headers={"Authorization": f"Bearer {self.config.DEEPSEEK_API_KEY}"},
            json={
                "model": "deepseek-chat",
                "messages": [{
                    "role": "user",
                    "content": f"Summarize these findings:\n{processed}"
                }]
            },
            timeout=30
        ).json()
        return {"messages": [AIMessage(content=response['choices'][0]['message']['content'])]}

    def rewrite(self, state: AgentState):
        # Rephrase the original user query before retrying the agent
        original = state["messages"][0].content
        response = requests.post(
            "https://api.deepseek.com/v1/chat/completions",
            headers={"Authorization": f"Bearer {self.config.DEEPSEEK_API_KEY}"},
            json={
                "model": "deepseek-chat",
                "messages": [{
                    "role": "user",
                    "content": f"Rephrase this query: {original}"
                }]
            },
            timeout=30
        ).json()
        return {"messages": [AIMessage(content=response['choices'][0]['message']['content'])]}

    def _tools_condition(self, state: AgentState):
        # Continue to retrieval/grading only when the agent produced search results
        return "retrieve" if "Results:" in state["messages"][-1].content else "end"

    def _grade_documents(self, state: AgentState):
        # Regenerate the query if the search returned no documents
        results = ast.literal_eval(state["messages"][-1].content.split("Results: ", 1)[1])
        return "generate" if len(results) > 0 else "rewrite"
# ------------------------------
# Streamlit App
# ------------------------------
def apply_theme():
    st.markdown("""
    <style>
    .stApp { background: #1a1a1a; color: white; }
    .stTextArea textarea { background: #2d2d2d !important; color: white !important; }
    .stButton>button { background: #2E86C1; transition: 0.3s; }
    .stButton>button:hover { background: #1B4F72; transform: scale(1.02); }
    .data-box { background: #2d2d2d; border-left: 4px solid #2E86C1; padding: 15px; margin: 10px 0; }
    </style>
    """, unsafe_allow_html=True)
def main(config: AppConfig, chroma: ChromaManager):
    apply_theme()

    with st.sidebar:
        st.header("Databases")
        with st.expander("Research", expanded=True):
            for text in research_texts:
                st.markdown(f'<div class="data-box">{text}</div>', unsafe_allow_html=True)
        with st.expander("Development"):
            for text in development_texts:
                st.markdown(f'<div class="data-box">{text}</div>', unsafe_allow_html=True)

    st.title("AI Research Assistant")
    query = st.text_area("Enter your query:", height=100)

    if st.button("Analyze"):
        with st.spinner("Processing..."):
            try:
                workflow = AgentWorkflow(chroma, config)
                results = workflow.app.invoke({"messages": [HumanMessage(content=query)]})
                with st.expander("Processing Details", expanded=True):
                    st.write("### Raw Results", results)
                st.success("### Final Answer")
                st.markdown(results['messages'][-1].content)
            except Exception as e:
                st.error(f"Processing failed: {str(e)}")
# ------------------------------
# Initialization
# ------------------------------
if __name__ == "__main__":
    st.set_page_config(
        page_title="AI Research Assistant",
        layout="wide",
        initial_sidebar_state="expanded"
    )
    try:
        config = AppConfig()
        config.validate()
        chroma = ChromaManager(config)
        main(config, chroma)
    except Exception as e:
        st.error(f"Initialization failed: {str(e)}")