# NOTE: the lines "Spaces:" / "Sleeping" / "Sleeping" that previously appeared
# here were Hugging Face Spaces page residue accidentally captured into the
# file, not Python code; converted to this comment so the module parses.
# ------------------------------
# Imports & Dependencies
# ------------------------------
import hashlib
import os
import re

import chromadb
import requests
import streamlit as st
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.tools.retriever import create_retriever_tool
from langchain_community.vectorstores import Chroma
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
from langchain_openai import OpenAIEmbeddings
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode
from typing import Sequence, List, Dict
from typing_extensions import TypedDict, Annotated
# ------------------------------
# Configuration
# ------------------------------
# Get DeepSeek API key from environment variables.
DEEPSEEK_API_KEY = os.environ.get("DEEPSEEK_API_KEY")

# Fail fast with actionable guidance when the key is missing: every LLM call
# below depends on it, so continuing would only surface confusing errors later.
if not DEEPSEEK_API_KEY:
    st.error("""
    **Critical Configuration Missing**
    DeepSeek API key not found. Please ensure you have:
    1. Created a Hugging Face Space secret named DEEPSEEK_API_KEY
    2. Added your valid API key to the Space secrets
    3. Restarted the Space after configuration
    """)
    st.stop()

# Create directory for Chroma persistence (no-op if it already exists).
os.makedirs("chroma_db", exist_ok=True)

# ------------------------------
# ChromaDB Client Configuration
# ------------------------------
# Persistent on-disk client; telemetry disabled so no usage data leaves the Space.
chroma_client = chromadb.PersistentClient(
    path="chroma_db",
    settings=chromadb.config.Settings(anonymized_telemetry=False)
)
# ------------------------------
# Document Processing Utilities
# ------------------------------
def deduplicate_docs(docs: List[Document]) -> List[Document]:
    """Drop documents whose page_content duplicates an earlier one.

    Equality is decided by the SHA-256 digest of the raw content, so two
    documents with identical text (even with different metadata) count as
    duplicates. The first occurrence wins and input order is preserved.
    """
    fingerprints = set()
    kept: List[Document] = []
    for candidate in docs:
        digest = hashlib.sha256(candidate.page_content.encode()).hexdigest()
        if digest in fingerprints:
            continue  # exact-content duplicate of an earlier document
        fingerprints.add(digest)
        kept.append(candidate)
    return kept
# ------------------------------
# Data Preparation
# ------------------------------
research_texts = [
    "Research Report: Results of a New AI Model Improving Image Recognition Accuracy to 98%",
    "Academic Paper Summary: Why Transformers Became the Mainstream Architecture in Natural Language Processing",
    "Latest Trends in Machine Learning Methods Using Quantum Computing"
]
development_texts = [
    "Project A: UI Design Completed, API Integration in Progress",
    "Project B: Testing New Feature X, Bug Fixes Needed",
    "Product Y: In the Performance Optimization Stage Before Release"
]

# Chunker for both corpora; start indices go into metadata so every chunk can
# be traced back to its offset in the source text.
splitter = RecursiveCharacterTextSplitter(
    chunk_size=150,
    chunk_overlap=20,
    length_function=len,
    add_start_index=True
)

# Tag each text with its corpus of origin plus a stable per-text identifier.
research_docs = splitter.create_documents(
    research_texts,
    metadatas=[{"source": "research", "doc_id": f"res_{idx}"} for idx, _ in enumerate(research_texts)]
)
development_docs = splitter.create_documents(
    development_texts,
    metadatas=[{"source": "development", "doc_id": f"dev_{idx}"} for idx, _ in enumerate(development_texts)]
)
# ------------------------------
# Vector Store Initialization
# ------------------------------
# NOTE: `dimensions` is a first-class parameter of OpenAIEmbeddings; passing it
# through model_kwargs makes langchain_openai raise/warn ("Parameters
# {'dimensions'} should be specified explicitly"), so it is supplied directly.
embeddings = OpenAIEmbeddings(
    model="text-embedding-3-large",
    dimensions=1024
)

# Two separate persisted collections (research vs. development), both using
# cosine distance for similarity search.
research_vectorstore = Chroma.from_documents(
    documents=research_docs,
    embedding=embeddings,
    client=chroma_client,
    collection_name="research_collection",
    collection_metadata={"hnsw:space": "cosine"}
)
development_vectorstore = Chroma.from_documents(
    documents=development_docs,
    embedding=embeddings,
    client=chroma_client,
    collection_name="development_collection",
    collection_metadata={"hnsw:space": "cosine"}
)
# ------------------------------
# Retriever Tools Configuration
# ------------------------------
# Research side uses MMR (diverse results from a wider candidate pool);
# development side uses plain similarity search.
research_retriever = research_vectorstore.as_retriever(
    search_type="mmr",
    search_kwargs={"k": 5, "fetch_k": 10}
)
development_retriever = development_vectorstore.as_retriever(
    search_type="similarity",
    search_kwargs={"k": 5}
)

# Wrap each retriever as a LangChain tool so the graph's ToolNode can call it.
research_tool = create_retriever_tool(
    research_retriever,
    "research_database",
    "Searches through academic papers and research reports for technical AI advancements"
)
development_tool = create_retriever_tool(
    development_retriever,
    "development_database",
    "Accesses current project statuses and development timelines"
)
tools = [research_tool, development_tool]
# ------------------------------
# Agent State Definition
# ------------------------------
class AgentState(TypedDict):
    """Shared LangGraph state: a conversation merged via ``add_messages``.

    ``add_messages`` appends messages returned by each node instead of
    replacing the list, so the full exchange accumulates across the graph.
    """
    messages: Annotated[Sequence[AIMessage | HumanMessage | ToolMessage], add_messages]
# ------------------------------
# Core Agent Function
# ------------------------------
def agent(state: AgentState):
    """Main decision-making agent handling user queries.

    Sends the latest user message to DeepSeek with a routing prompt. Depending
    on the model's reply, either runs one of the two retrievers and returns an
    AIMessage carrying the deduplicated hits in ``additional_kwargs`` (so the
    grading node can inspect them), or returns the model's direct answer.
    Errors are converted into AIMessages rather than raised, keeping the
    graph's message flow intact.
    """
    print("\n--- AGENT EXECUTION START ---")
    messages = state["messages"]

    def _search_response(action: str, retriever, query: str):
        """Run *retriever* on *query* and package deduplicated results for the graph."""
        unique_results = deduplicate_docs(retriever.invoke(query))
        return {
            "messages": [
                AIMessage(
                    content=f'Action: {action}\nQuery: "{query}"\nResults: {len(unique_results)} relevant documents',
                    additional_kwargs={"documents": unique_results}
                )
            ]
        }

    try:
        # Only a direct user turn is routed; anything else yields an empty query.
        user_message = messages[-1].content if isinstance(messages[-1], HumanMessage) else ""

        # Routing prompt: the model answers with a SEARCH_* directive or directly.
        prompt = f"""Analyze this user query and determine the appropriate action:
        Query: {user_message}
        Response Format:
        - If research-related (technical details, academic concepts), respond:
        SEARCH_RESEARCH: [keywords]
        - If development-related (project status, timelines), respond:
        SEARCH_DEV: [keywords]
        - If general question, answer directly
        - If unclear, request clarification
        """

        headers = {
            "Accept": "application/json",
            "Authorization": f"Bearer {DEEPSEEK_API_KEY}",
            "Content-Type": "application/json"
        }
        data = {
            "model": "deepseek-chat",
            "messages": [{"role": "user", "content": prompt}],
            "temperature": 0.5,
            "max_tokens": 256
        }

        response = requests.post(
            "https://api.deepseek.com/v1/chat/completions",
            headers=headers,
            json=data,
            timeout=30
        )
        response.raise_for_status()

        response_text = response.json()['choices'][0]['message']['content']
        print(f"Agent Decision: {response_text}")

        # Route on the model's directive; both search branches share _search_response.
        if "SEARCH_RESEARCH:" in response_text:
            return _search_response(
                "research_database",
                research_retriever,
                response_text.split("SEARCH_RESEARCH:")[1].strip(),
            )
        if "SEARCH_DEV:" in response_text:
            return _search_response(
                "development_database",
                development_retriever,
                response_text.split("SEARCH_DEV:")[1].strip(),
            )
        # No directive: pass the model's answer straight through.
        return {"messages": [AIMessage(content=response_text)]}
    except requests.exceptions.HTTPError as e:
        error_msg = f"API Error: {e.response.status_code} - {e.response.text}"
        if "insufficient balance" in e.response.text.lower():
            error_msg += "\n\nPlease check your DeepSeek account balance."
        return {"messages": [AIMessage(content=error_msg)]}
    except Exception as e:
        return {"messages": [AIMessage(content=f"Processing Error: {str(e)}")]}
# ------------------------------
# Document Evaluation Functions
# ------------------------------
def simple_grade_documents(state: AgentState):
    """Route to 'generate' when the newest message carries retrieved documents,
    otherwise to 'rewrite'."""
    newest = state["messages"][-1]
    if newest.additional_kwargs.get("documents"):
        print("--- Relevant Documents Found ---")
        return "generate"
    print("--- No Valid Documents Found ---")
    return "rewrite"
def generate(state: AgentState):
    """Generate the final answer by synthesizing the retrieved documents.

    Pulls the original user question and the documents stashed in the last
    message's ``additional_kwargs``, asks DeepSeek for a structured synthesis,
    and appends a "Sources:" footer. Failures are returned as an AIMessage
    instead of being raised.
    """
    print("\n--- GENERATING FINAL ANSWER ---")
    messages = state["messages"]
    try:
        # First human message in the state is the original question.
        user_question = next(msg.content for msg in messages if isinstance(msg, HumanMessage))
        documents = messages[-1].additional_kwargs.get("documents", [])

        # Unique source labels, sorted so the footer is deterministic across
        # runs (list(set(...)) ordering is not reproducible).
        sources = sorted({
            doc.metadata.get('source', 'unknown')
            for doc in documents
        })

        # Synthesis prompt with explicit quality requirements.
        prompt = f"""Synthesize a technical answer using these documents:
        Question: {user_question}
        Documents:
        {[doc.page_content for doc in documents]}
        Requirements:
        1. Highlight quantitative metrics
        2. Cite document sources (research/development)
        3. Note temporal context
        4. List potential applications
        5. Mention limitations/gaps
        """

        headers = {
            "Accept": "application/json",
            "Authorization": f"Bearer {DEEPSEEK_API_KEY}",
            "Content-Type": "application/json"
        }
        data = {
            "model": "deepseek-chat",
            "messages": [{"role": "user", "content": prompt}],
            "temperature": 0.3,
            "max_tokens": 1024
        }

        # Longer timeout than the router: synthesis responses are larger.
        response = requests.post(
            "https://api.deepseek.com/v1/chat/completions",
            headers=headers,
            json=data,
            timeout=45
        )
        response.raise_for_status()

        response_text = response.json()['choices'][0]['message']['content']
        formatted_answer = f"{response_text}\n\nSources: {', '.join(sources)}"
        return {"messages": [AIMessage(content=formatted_answer)]}
    except Exception as e:
        return {"messages": [AIMessage(content=f"Generation Error: {str(e)}")]}
def rewrite(state: AgentState):
    """Ask DeepSeek to clarify the original user query, preserving intent."""
    print("\n--- REWRITING QUERY ---")
    messages = state["messages"]
    try:
        # Recover the user's original phrasing from the accumulated state.
        original_query = next(msg.content for msg in messages if isinstance(msg, HumanMessage))

        request_headers = {
            "Accept": "application/json",
            "Authorization": f"Bearer {DEEPSEEK_API_KEY}",
            "Content-Type": "application/json"
        }
        payload = {
            "model": "deepseek-chat",
            "messages": [{
                "role": "user",
                "content": f"Clarify this query while preserving technical intent: {original_query}"
            }],
            "temperature": 0.5,
            "max_tokens": 256
        }
        api_response = requests.post(
            "https://api.deepseek.com/v1/chat/completions",
            headers=request_headers,
            json=payload,
            timeout=30
        )
        api_response.raise_for_status()

        rewritten = api_response.json()['choices'][0]['message']['content']
        return {"messages": [AIMessage(content=f"Revised Query: {rewritten}")]}
    except Exception as e:
        return {"messages": [AIMessage(content=f"Rewriting Error: {str(e)}")]}
# ------------------------------
# Workflow Configuration
# ------------------------------
workflow = StateGraph(AgentState)

# Node registration: decision agent, tool-backed retrieval, synthesis, rewrite.
workflow.add_node("agent", agent)
workflow.add_node("retrieve", ToolNode(tools))
workflow.add_node("generate", generate)
workflow.add_node("rewrite", rewrite)

workflow.set_entry_point("agent")


def _route_after_agent(state):
    """Go to retrieval when the agent's reply mentions a tool name, else end."""
    latest_content = state["messages"][-1].content
    for tool in tools:
        if tool.name in latest_content:
            return "tools"
    return END


workflow.add_conditional_edges(
    "agent",
    _route_after_agent,
    {"tools": "retrieve", END: END}
)
workflow.add_conditional_edges(
    "retrieve",
    simple_grade_documents,
    {"generate": "generate", "rewrite": "rewrite"}
)
workflow.add_edge("generate", END)
workflow.add_edge("rewrite", "agent")
app = workflow.compile()
# ------------------------------
# Streamlit UI Implementation
# ------------------------------
def main():
    """Main application interface.

    Renders the dark-themed Streamlit UI, streams the compiled LangGraph
    `app` over the user's query, tracks intermediate node outputs in
    `st.session_state`, and displays the final synthesized answer.
    """
    st.set_page_config(
        page_title="AI Research Assistant",
        layout="centered",
        initial_sidebar_state="expanded"
    )
    # Dark Theme Configuration (injected CSS overrides Streamlit defaults).
    st.markdown("""
    <style>
    .stApp {
        background-color: #0E1117;
        color: #FAFAFA;
    }
    .stTextArea textarea {
        background-color: #262730 !important;
        color: #FAFAFA !important;
        border: 1px solid #3D4051;
    }
    .stButton>button {
        background-color: #2E8B57;
        color: white;
        border-radius: 4px;
        padding: 0.5rem 1rem;
        transition: all 0.3s;
    }
    .stButton>button:hover {
        background-color: #3CB371;
        transform: scale(1.02);
    }
    .stAlert {
        background-color: #1A1D23 !important;
        border: 1px solid #3D4051;
    }
    .stExpander {
        background-color: #1A1D23;
        border: 1px solid #3D4051;
    }
    .data-source {
        padding: 0.5rem;
        margin: 0.5rem 0;
        background-color: #1A1D23;
        border-left: 3px solid #2E8B57;
        border-radius: 4px;
    }
    </style>
    """, unsafe_allow_html=True)
    # Sidebar: static description of the two corpora backing the retrievers.
    with st.sidebar:
        st.header("Technical Databases")
        with st.expander("Research Corpus", expanded=True):
            st.markdown("""
            - AI Model Architectures
            - Machine Learning Advances
            - Quantum Computing Applications
            - Algorithmic Breakthroughs
            """)
        with st.expander("Development Tracking", expanded=True):
            st.markdown("""
            - Project Milestones
            - System Architecture
            - Deployment Status
            - Performance Metrics
            """)
    # Main Interface: title, caption, and the query input area.
    st.title("🧠 AI Research Assistant")
    st.caption("Technical Analysis and Development Tracking System")
    query = st.text_area(
        "Enter Technical Query:",
        height=150,
        placeholder="Example: Compare transformer architectures for medical imaging analysis..."
    )
    if st.button("Execute Analysis", use_container_width=True):
        # Guard clause: nothing to analyze without a query.
        if not query:
            st.warning("Please input a technical query")
            return
        with st.status("Processing...", expanded=True) as status:
            try:
                events = []
                # Stream graph execution; each event is keyed by the node that
                # produced it, so the status label mirrors workflow progress.
                for event in app.stream({"messages": [HumanMessage(content=query)]}):
                    events.append(event)
                    if 'agent' in event:
                        status.update(label="Decision Making", state="running")
                        st.session_state.agent_step = event['agent']
                    if 'retrieve' in event:
                        status.update(label="Document Retrieval", state="running")
                        st.session_state.retrieved = event['retrieve']
                    if 'generate' in event:
                        status.update(label="Synthesizing Answer", state="running")
                        # Stored in session_state so the answer survives reruns.
                        st.session_state.final_answer = event['generate']
                status.update(label="Analysis Complete", state="complete")
            except Exception as e:
                status.update(label="Processing Failed", state="error")
                st.error(f"""
                **System Error**
                {str(e)}
                Please verify:
                - API key validity
                - Network connectivity
                - Query complexity
                """)
    # Render the last completed answer, if any (persists across reruns).
    if 'final_answer' in st.session_state:
        answer = st.session_state.final_answer['messages'][0].content
        with st.container():
            st.subheader("Technical Analysis")
            st.markdown("---")
            st.markdown(answer)
            # The generate() node appends a "Sources:" footer; show provenance note.
            if "Sources:" in answer:
                st.markdown("""
                <div class="data-source">
                ℹ️ Document sources are derived from the internal research database
                </div>
                """, unsafe_allow_html=True)
if __name__ == "__main__":
    main()