# Hugging Face Spaces page header ("Spaces: Running") - scrape residue, not part of the program.
# ------------------------------ | |
# Imports & Dependencies | |
# ------------------------------ | |
from langchain_openai import OpenAIEmbeddings | |
from langchain_community.vectorstores import Chroma | |
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langgraph.graph import END, StateGraph | |
from langgraph.prebuilt import ToolNode | |
from langgraph.graph.message import add_messages | |
from typing_extensions import TypedDict, Annotated | |
from typing import Sequence, List, Dict, Any | |
import chromadb | |
import re | |
import os | |
import streamlit as st | |
import requests | |
import time | |
import hashlib | |
from langchain.tools.retriever import create_retriever_tool | |
from datetime import datetime | |
# ------------------------------ | |
# Data Definitions | |
# ------------------------------ | |
# Seed corpus for the "research" knowledge base: one short text per entry.
# Passed to ChromaManager for indexing and listed verbatim in the sidebar.
research_texts = [
    "Research Report: Results of a New AI Model Improving Image Recognition Accuracy to 98%",
    "Academic Paper Summary: Why Transformers Became the Mainstream Architecture in Natural Language Processing",
    "Latest Trends in Machine Learning Methods Using Quantum Computing",
    "Advancements in Neuromorphic Computing for Energy-Efficient AI Systems",
    "Cross-Modal Learning: Integrating Visual and Textual Representations for Multimodal AI"
]
# Seed corpus for the "development" knowledge base: one status line per entry.
# Passed to ChromaManager for indexing and listed verbatim in the sidebar.
development_texts = [
    "Project A: UI Design Completed, API Integration in Progress",
    "Project B: Testing New Feature X, Bug Fixes Needed",
    "Product Y: In the Performance Optimization Stage Before Release",
    "Framework Z: Version 3.2 Released with Enhanced Distributed Training Support",
    "DevOps Pipeline: Automated CI/CD Implementation for ML Model Deployment"
]
# ------------------------------ | |
# Configuration Class | |
# ------------------------------ | |
class AppConfig:
    """Application settings, resolved once when the module is imported.

    Attributes:
        DEEPSEEK_API_KEY: Value of the ``DEEPSEEK_API_KEY`` environment
            variable, or ``None`` when it is unset.
        CHROMA_PATH: Directory used for the persistent Chroma database.
        MAX_RETRIES: Number of attempts made by the retrying API helper.
        RETRY_DELAY: Base of the exponential back-off, in seconds.
        DOCUMENT_CHUNK_SIZE: Chunk size handed to the text splitter.
        DOCUMENT_OVERLAP: Chunk overlap handed to the text splitter.
        SEARCH_K: Intended number of search results.
        SEARCH_TYPE: Intended retriever search type.
    """

    def __init__(self):
        self.DEEPSEEK_API_KEY = os.environ.get("DEEPSEEK_API_KEY")
        self.CHROMA_PATH = "chroma_db"
        self.MAX_RETRIES = 3
        self.RETRY_DELAY = 1.5
        self.DOCUMENT_CHUNK_SIZE = 300
        self.DOCUMENT_OVERLAP = 50
        self.SEARCH_K = 5
        self.SEARCH_TYPE = "mmr"
        self.validate_config()

    def validate_config(self):
        """Show setup instructions and stop the Streamlit script if the key is missing."""
        if self.DEEPSEEK_API_KEY:
            return
        # The message previously contained mis-decoded glyphs; the arrow is
        # restored from context and the key emoji is a reconstruction.
        st.error(
            "\n"
            "**Critical Configuration Missing**\n"
            "🔑 DeepSeek API key not found in environment variables.\n"
            "Please configure through Hugging Face Space secrets:\n"
            "1. Go to Space Settings → Repository secrets\n"
            "2. Add secret: Name=DEEPSEEK_API_KEY, Value=your_api_key\n"
            "3. Rebuild Space\n"
        )
        st.stop()


config = AppConfig()
# ------------------------------ | |
# ChromaDB Manager | |
# ------------------------------ | |
class ChromaManager:
    """Owns the persistent Chroma client and the two indexed collections.

    Attributes:
        client: ``chromadb.PersistentClient`` rooted at ``config.CHROMA_PATH``.
        embeddings: Embedding model shared by both collections.
        research_collection: Vector store built from the research texts.
        dev_collection: Vector store built from the development texts.
    """

    def __init__(self, research_data: List[str], development_data: List[str]):
        os.makedirs(config.CHROMA_PATH, exist_ok=True)
        self.client = chromadb.PersistentClient(path=config.CHROMA_PATH)
        self.embeddings = OpenAIEmbeddings(model="text-embedding-3-large")
        self.research_collection = self.create_collection(
            research_data,
            "research_collection",
            {"category": "research", "version": "1.2"},
        )
        self.dev_collection = self.create_collection(
            development_data,
            "development_collection",
            {"category": "development", "version": "1.1"},
        )

    def create_collection(self, documents: List[str], name: str, metadata: dict) -> Chroma:
        """Split ``documents`` into chunks and index them into collection ``name``.

        Args:
            documents: Raw texts to index.
            name: Chroma collection name.
            metadata: Collection-level metadata.

        Returns:
            The LangChain ``Chroma`` vector store wrapping the collection.
        """
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=config.DOCUMENT_CHUNK_SIZE,
            chunk_overlap=config.DOCUMENT_OVERLAP,
            # The third separator was a mis-decoded glyph; restored to the
            # ideographic full stop it most plausibly was.
            separators=["\n\n", "\n", "。", " "],
        )
        docs = text_splitter.create_documents(documents)
        # The client is persistent and this runs on every script execution.
        # Without explicit ids each run stores the same chunks again under new
        # random ids. Stable ids (position + content digest) are intended to
        # make repeated runs overwrite instead of accumulate duplicates.
        ids = [
            f"{name}-{index}-{hashlib.sha256(doc.page_content.encode('utf-8')).hexdigest()[:16]}"
            for index, doc in enumerate(docs)
        ]
        return Chroma.from_documents(
            documents=docs,
            embedding=self.embeddings,
            ids=ids,
            client=self.client,
            collection_name=name,
            collection_metadata=metadata,
        )


# Initialize Chroma with data
chroma_manager = ChromaManager(research_texts, development_texts)
# ------------------------------ | |
# Document Processing | |
# ------------------------------ | |
class DocumentProcessor:
    """Stateless helpers for post-processing retrieved documents.

    Both helpers only read the ``page_content`` attribute of each document.
    """

    @staticmethod
    def deduplicate_documents(docs: List[Any]) -> List[Any]:
        """Return ``docs`` without repeated contents, keeping first occurrences.

        Args:
            docs: Objects exposing a string ``page_content`` attribute.

        Returns:
            A new list in the original order with later duplicates removed.
        """
        seen = set()
        unique_docs = []
        for doc in docs:
            content = doc.page_content
            if content in seen:
                continue
            seen.add(content)
            unique_docs.append(doc)
        return unique_docs

    @staticmethod
    def extract_key_points(docs: List[Any]) -> str:
        """Map keyword hits in ``docs`` to canned key-point bullet lines.

        Args:
            docs: Objects exposing a string ``page_content`` attribute.

        Returns:
            The distinct bullet lines joined by newlines, in order of first
            appearance; an empty string when nothing matches.
        """
        topic_rules = (
            (("quantum", "qpu", "qubit"),
             "- Quantum computing integration showing promising results"),
            (("image", "recognition", "vision"),
             "- Computer vision models achieving state-of-the-art accuracy"),
            (("transformer", "language", "llm"),
             "- NLP architectures evolving with memory-augmented transformers"),
        )
        # A dict keeps insertion order, so the output is deterministic
        # (a set would reorder the lines from run to run).
        key_points = {}
        for doc in docs:
            content = doc.page_content.lower()
            for keywords, point in topic_rules:
                if any(kw in content for kw in keywords):
                    key_points.setdefault(point, None)
        return "\n".join(key_points)
# ------------------------------ | |
# Enhanced Agent Components | |
# ------------------------------ | |
class EnhancedAgent:
    """Holds per-session counters and a rate-limit-aware DeepSeek POST helper."""

    def __init__(self):
        # Counters are created empty here; nothing in this class updates them.
        self.session_stats = {
            "processing_times": [],
            "doc_counts": [],
            "error_count": 0
        }

    def api_request_with_retry(self, endpoint: str, payload: Dict) -> Dict:
        """POST ``payload`` as JSON to ``endpoint``, retrying on HTTP 429.

        Args:
            endpoint: Full URL of the API endpoint.
            payload: JSON-serialisable request body.

        Returns:
            The decoded JSON response body.

        Raises:
            requests.exceptions.HTTPError: For any non-429 HTTP error status.
            Exception: When every attempt was answered with HTTP 429.
        """
        headers = {
            "Authorization": f"Bearer {config.DEEPSEEK_API_KEY}",
            "Content-Type": "application/json"
        }
        last_error = None
        for attempt in range(config.MAX_RETRIES):
            try:
                # TLS verification stays at the library default: the request
                # carries the bearer token, so it must not be disabled.
                response = requests.post(
                    endpoint,
                    headers=headers,
                    json=payload,
                    timeout=30,
                )
                response.raise_for_status()
                return response.json()
            except requests.exceptions.HTTPError as e:
                if e.response is None or e.response.status_code != 429:
                    raise
                last_error = e
            # Back off exponentially, but not after the final attempt.
            if attempt < config.MAX_RETRIES - 1:
                time.sleep(config.RETRY_DELAY ** (attempt + 1))
        raise Exception(
            f"API request failed after {config.MAX_RETRIES} attempts"
        ) from last_error
# ------------------------------ | |
# Workflow Configuration | |
# ------------------------------ | |
class AgentState(TypedDict):
    """State dictionary passed between the nodes of the workflow graph."""

    # Conversation so far. The field is annotated with LangGraph's
    # `add_messages`, so node return values are presumably merged into this
    # sequence rather than replacing it - confirm against the LangGraph docs.
    messages: Annotated[Sequence[AIMessage | HumanMessage | ToolMessage], add_messages]
def agent(state: AgentState):
    """Route the user's question: search a knowledge base or answer directly.

    The first message in the state is sent to DeepSeek with routing
    instructions. If the reply contains ``SEARCH_RESEARCH:`` or
    ``SEARCH_DEV:``, the matching collection is queried and the hits are
    embedded in an ``Action: ...`` message. Otherwise the reply is returned.

    Args:
        state: Graph state; only ``state["messages"][0]`` is read.

    Returns:
        ``{"messages": [AIMessage]}`` with the action-and-results text, the
        direct answer, or an ``API Error: ...`` description on any failure.
    """
    print("---CALL AGENT---")
    messages = state["messages"]
    # NOTE(review): always reads the first message, so text produced by the
    # `rewrite` node is not what gets routed - confirm this is intended.
    first_message = messages[0]
    user_message = first_message[1] if isinstance(first_message, tuple) else first_message.content
    prompt = f"""Given this user question: "{user_message}"
If about research/academic topics, respond EXACTLY:
SEARCH_RESEARCH: <search terms>
If about development status, respond EXACTLY:
SEARCH_DEV: <search terms>
Otherwise, answer directly."""
    headers = {
        "Accept": "application/json",
        "Authorization": f"Bearer {config.DEEPSEEK_API_KEY}",
        "Content-Type": "application/json"
    }
    data = {
        "model": "deepseek-chat",
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0.7,
        "max_tokens": 1024
    }
    try:
        # TLS verification stays at the library default (as in `generate` and
        # `rewrite`): the request carries the API key.
        response = requests.post(
            "https://api.deepseek.com/v1/chat/completions",
            headers=headers,
            json=data,
            timeout=30
        )
        response.raise_for_status()
        response_text = response.json()['choices'][0]['message']['content']
        # (marker in the model reply, tool name to report, collection to query);
        # order matters: the research marker is checked first.
        routes = (
            ("SEARCH_RESEARCH:", "research_db_tool", chroma_manager.research_collection),
            ("SEARCH_DEV:", "development_db_tool", chroma_manager.dev_collection),
        )
        for marker, tool_name, collection in routes:
            if marker in response_text:
                query = response_text.split(marker)[1].strip()
                results = collection.as_retriever().invoke(query)
                return {"messages": [AIMessage(content=f'Action: {tool_name}\n{{"query": "{query}"}}\n\nResults: {str(results)}')]}
        return {"messages": [AIMessage(content=response_text)]}
    except Exception as e:
        # Best-effort node: failures are reported as a message, not raised.
        error_msg = f"API Error: {str(e)}"
        if "Insufficient Balance" in str(e):
            error_msg += "\n\nPlease check your DeepSeek API account balance."
        return {"messages": [AIMessage(content=error_msg)]}
def simple_grade_documents(state: AgentState):
    """Choose the next node from the newest message.

    Returns ``"generate"`` when that message's content contains the text
    ``Results: [Document``, otherwise ``"rewrite"``.
    """
    newest = state["messages"][-1]
    if "Results: [Document" in newest.content:
        return "generate"
    return "rewrite"
def _documents_from_results(results_repr: str) -> List[Any]:
    """Rebuild documents from the ``str(list_of_documents)`` text in a message.

    The text is parsed as a Python expression but never executed: only
    ``Document(...)`` calls inside a list literal are accepted, and their
    keyword arguments must be plain literals.

    Returns:
        The reconstructed documents; an empty list when the text is not a
        parseable list of ``Document(...)`` entries.
    """
    import ast

    from langchain_core.documents import Document

    try:
        tree = ast.parse(results_repr.strip(), mode="eval")
    except (SyntaxError, ValueError):
        return []
    if not isinstance(tree.body, ast.List):
        return []
    documents = []
    for node in tree.body.elts:
        if not isinstance(node, ast.Call) or getattr(node.func, "id", None) != "Document":
            continue
        fields = {}
        for keyword in node.keywords:
            if keyword.arg is None:
                continue
            try:
                fields[keyword.arg] = ast.literal_eval(keyword.value)
            except (ValueError, TypeError, SyntaxError):
                continue
        page_content = fields.get("page_content")
        if not isinstance(page_content, str):
            continue
        metadata = fields.get("metadata")
        documents.append(
            Document(
                page_content=page_content,
                metadata=metadata if isinstance(metadata, dict) else {},
            )
        )
    return documents


def generate(state: AgentState):
    """Ask DeepSeek for a structured summary of the retrieved documents.

    Documents are recovered from the ``Results: [...]`` section of the newest
    message, deduplicated, reduced to key points, and placed in the prompt.

    Args:
        state: Graph state; only the newest message is read.

    Returns:
        ``{"messages": [AIMessage]}`` with the generated summary, or a
        ``Generation Error: ...`` description if the API call fails.
    """
    last_message = state["messages"][-1]
    docs_content = []
    if "Results: [" in last_message.content:
        # Keep everything after the first marker, so document text that
        # itself contains "Results: " is not cut off.
        _, _, docs_str = last_message.content.partition("Results: ")
        # Parsed instead of eval()'d: `Document` is not a name in this module,
        # and the text is message content that must not be executed.
        docs_content = _documents_from_results(docs_str)
    processed_info = DocumentProcessor.extract_key_points(
        DocumentProcessor.deduplicate_documents(docs_content)
    )
    prompt = f"""Generate structured research summary:
Key Information:
{processed_info}
Include:
1. Section headings
2. Bullet points
3. Significance
4. Applications"""
    try:
        response = requests.post(
            "https://api.deepseek.com/v1/chat/completions",
            headers={
                "Authorization": f"Bearer {config.DEEPSEEK_API_KEY}",
                "Content-Type": "application/json"
            },
            json={
                "model": "deepseek-chat",
                "messages": [{"role": "user", "content": prompt}],
                "temperature": 0.7,
                "max_tokens": 1024
            },
            timeout=30
        )
        response.raise_for_status()
        return {"messages": [AIMessage(content=response.json()['choices'][0]['message']['content'])]}
    except Exception as e:
        # Best-effort node: failures are reported as a message, not raised.
        return {"messages": [AIMessage(content=f"Generation Error: {str(e)}")]}
def rewrite(state: AgentState):
    """Ask DeepSeek to rephrase the original question for clarity.

    Args:
        state: Graph state; only ``state["messages"][0]`` is read.

    Returns:
        ``{"messages": [AIMessage]}`` with the rewritten question, or a
        ``Rewrite Error: ...`` description if the API call fails.
    """
    original_question = state["messages"][0].content
    request_body = {
        "model": "deepseek-chat",
        "messages": [{
            "role": "user",
            "content": f"Rewrite for clarity: {original_question}"
        }],
        "temperature": 0.7,
        "max_tokens": 1024
    }
    try:
        response = requests.post(
            "https://api.deepseek.com/v1/chat/completions",
            headers={
                "Authorization": f"Bearer {config.DEEPSEEK_API_KEY}",
                "Content-Type": "application/json"
            },
            json=request_body,
            timeout=30
        )
        response.raise_for_status()
        rewritten = response.json()['choices'][0]['message']['content']
        # The list bracket was previously left unclosed here, which made the
        # whole module a SyntaxError.
        return {"messages": [AIMessage(content=rewritten)]}
    except Exception as e:
        # Best-effort node: failures are reported as a message, not raised.
        return {"messages": [AIMessage(content=f"Rewrite Error: {str(e)}")]}
# Matches message content that begins with an "Action: ..." line.
tools_pattern = re.compile(r"Action: .*")


def custom_tools_condition(state: AgentState):
    """Return ``"tools"`` when the newest message starts with an action line, else ``END``."""
    newest_content = state["messages"][-1].content
    if tools_pattern.match(newest_content):
        return "tools"
    return END
# ------------------------------ | |
# Workflow Graph Setup | |
# ------------------------------ | |
# Retriever tools exposed by the "retrieve" node, one per knowledge base.
_research_tool = create_retriever_tool(
    chroma_manager.research_collection.as_retriever(),
    "research_db_tool",
    "Search research database",
)
_development_tool = create_retriever_tool(
    chroma_manager.dev_collection.as_retriever(),
    "development_db_tool",
    "Search development database",
)

workflow = StateGraph(AgentState)

# Nodes.
workflow.add_node("agent", agent)
workflow.add_node("retrieve", ToolNode([_research_tool, _development_tool]))
workflow.add_node("rewrite", rewrite)
workflow.add_node("generate", generate)

# Edges: agent -> (retrieve | END); retrieve -> (generate | rewrite);
# generate -> END; rewrite -> agent.
workflow.set_entry_point("agent")
workflow.add_conditional_edges(
    "agent",
    custom_tools_condition,
    {"tools": "retrieve", END: END},
)
workflow.add_conditional_edges(
    "retrieve",
    simple_grade_documents,
    {"generate": "generate", "rewrite": "rewrite"},
)
workflow.add_edge("generate", END)
workflow.add_edge("rewrite", "agent")

app = workflow.compile()
# ------------------------------ | |
# Streamlit UI | |
# ------------------------------ | |
class UITheme:
    """Colour palette and the CSS injected into the Streamlit page."""

    primary_color = "#2E86C1"
    secondary_color = "#28B463"
    background_color = "#1A1A1A"
    text_color = "#EAECEE"

    # Must be a classmethod: callers invoke `UITheme.apply()` with no
    # arguments, which fails with a TypeError on a plain function.
    @classmethod
    def apply(cls):
        """Inject the theme's stylesheet into the current Streamlit page."""
        st.markdown(f"""
<style>
.stApp {{ background-color: {cls.background_color}; color: {cls.text_color}; }}
.stTextArea textarea {{
background-color: #2D2D2D !important;
color: {cls.text_color} !important;
border: 1px solid {cls.primary_color};
}}
.stButton > button {{
background-color: {cls.primary_color};
color: white;
border: none;
padding: 12px 28px;
border-radius: 6px;
transition: all 0.3s ease;
font-weight: 500;
}}
.stButton > button:hover {{
background-color: {cls.secondary_color};
transform: translateY(-1px);
box-shadow: 0 4px 12px rgba(0,0,0,0.2);
}}
.data-box {{
background-color: #2D2D2D;
border-left: 4px solid {cls.primary_color};
padding: 18px;
margin: 14px 0;
border-radius: 8px;
box-shadow: 0 2px 8px rgba(0,0,0,0.15);
}}
</style>
""", unsafe_allow_html=True)
def _collect_workflow_events(question, graph, run_config):
    """Run the compiled graph on one question and return its streamed events as a list."""
    initial_state = {"messages": [HumanMessage(content=question)]}
    return list(graph.stream(initial_state, run_config))


def _parse_result_documents(results_repr):
    """Rebuild documents from the ``str(list_of_documents)`` text in a message.

    The text is parsed as a Python expression but never executed: only
    ``Document(...)`` calls inside a list literal are accepted, and their
    keyword arguments must be plain literals. Returns an empty list when the
    text has any other shape.
    """
    import ast

    from langchain_core.documents import Document

    try:
        tree = ast.parse(results_repr.strip(), mode="eval")
    except (SyntaxError, ValueError):
        return []
    if not isinstance(tree.body, ast.List):
        return []
    documents = []
    for node in tree.body.elts:
        if not isinstance(node, ast.Call) or getattr(node.func, "id", None) != "Document":
            continue
        fields = {}
        for keyword in node.keywords:
            if keyword.arg is None:
                continue
            try:
                fields[keyword.arg] = ast.literal_eval(keyword.value)
            except (ValueError, TypeError, SyntaxError):
                continue
        page_content = fields.get("page_content")
        if not isinstance(page_content, str):
            continue
        metadata = fields.get("metadata")
        documents.append(
            Document(
                page_content=page_content,
                metadata=metadata if isinstance(metadata, dict) else {},
            )
        )
    return documents


def _run_analysis(query):
    """Execute the workflow for ``query`` and render documents, summary and errors."""
    with st.status("Processing Workflow...", expanded=True) as status:
        try:
            start_time = time.time()
            # Streams the graph directly; the `process_question` helper this
            # code used to call is not defined anywhere in the module.
            events = _collect_workflow_events(query, app, {"configurable": {"thread_id": "1"}})
            processed_data = []
            for event in events:
                if 'agent' in event:
                    content = event['agent']['messages'][0].content
                    if "Results:" in content:
                        # Parsed instead of eval()'d: `Document` is not a name
                        # in this module, and message text must not be executed.
                        docs = _parse_result_documents(content.partition("Results: ")[2])
                        unique_docs = DocumentProcessor.deduplicate_documents(docs)
                        key_points = DocumentProcessor.extract_key_points(unique_docs)
                        processed_data.append(key_points)
                        with st.expander("📄 Retrieved Documents", expanded=False):
                            st.info(f"Found {len(unique_docs)} unique documents")
                            st.write(docs)
                elif 'generate' in event:
                    final_answer = event['generate']['messages'][0].content
                    status.update(label="✅ Analysis Complete", state="complete")
                    st.markdown("## 📝 Research Summary")
                    st.markdown(final_answer)
                    st.caption(f"⏱️ Processed in {time.time()-start_time:.2f}s | {len(processed_data)} clusters")
        except Exception as e:
            # Top-level UI boundary: show the failure and append it to a log file.
            status.update(label="❌ Processing Failed", state="error")
            st.error(f"**Error:** {str(e)}\n\nCheck API key and network connection")
            with open("error_log.txt", "a", encoding="utf-8") as f:
                f.write(f"{datetime.now()} | {str(e)}\n")


def main():
    """Render the Streamlit page: sidebar, query form, results and usage guide.

    Emoji in the labels below replace mis-decoded glyphs; several of them are
    reconstructions chosen from context.
    """
    # Page configuration is issued before any other Streamlit command,
    # including the theme's CSS injection.
    st.set_page_config(
        page_title="AI Research Assistant Pro",
        layout="wide",
        initial_sidebar_state="expanded",
        menu_items={
            'Get Help': 'https://example.com/docs',
            'Report a bug': 'https://example.com/issues',
            'About': "v2.1 | Enhanced Research Assistant"
        }
    )
    UITheme.apply()
    with st.sidebar:
        st.header("📚 Knowledge Bases")
        with st.expander("Research Database", expanded=True):
            for text in research_texts:
                st.markdown(f'<div class="data-box research-box">{text}</div>', unsafe_allow_html=True)
        with st.expander("Development Database"):
            for text in development_texts:
                st.markdown(f'<div class="data-box dev-box">{text}</div>', unsafe_allow_html=True)
    st.title("🔬 AI Research Assistant Pro")
    st.markdown("---")
    query = st.text_area(
        "Research Query Input",
        height=120,
        placeholder="Enter your research question...",
        help="Be specific about domains for better results"
    )
    col1, col2 = st.columns([1, 2])
    with col1:
        if st.button("🔍 Analyze Documents", use_container_width=True):
            if not query:
                # No early return: the usage guide in the second column
                # should still be rendered.
                st.warning("⚠️ Please enter a research question")
            else:
                _run_analysis(query)
    with col2:
        st.markdown(
            "\n"
            "## 📖 Usage Guide\n"
            "**1. Query Formulation**\n"
            "- Specify domains (e.g., \"quantum NLP\")\n"
            "- Include timeframes for recent advances\n"
            "**2. Results Interpretation**\n"
            "- Expand sections for source documents\n"
            "- Key points show technical breakthroughs\n"
            "- Summary includes commercial implications\n"
            "**3. Advanced Features**\n"
            "- Use keyboard shortcuts for efficiency\n"
            "- Click documents for raw context\n"
            "- Export via screenshot/PDF\n"
        )


if __name__ == "__main__":
    main()