mgbam's picture
Update app.py
81de628 verified
raw
history blame
17.3 kB
# ------------------------------
# Imports & Dependencies
# ------------------------------
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langgraph.graph import END, StateGraph
from langgraph.prebuilt import ToolNode
from langgraph.graph.message import add_messages
from typing_extensions import TypedDict, Annotated
from typing import Sequence, List, Dict
import chromadb
import re
import os
import streamlit as st
import requests
import hashlib
from langchain.tools.retriever import create_retriever_tool
from langchain.schema import Document
# ------------------------------
# Configuration
# ------------------------------
# DeepSeek key: sent as the Bearer token on every chat-completion request.
DEEPSEEK_API_KEY = os.environ.get("DEEPSEEK_API_KEY")
# Validate API key configuration; st.stop() halts this script run so nothing
# below executes without a key.
if not DEEPSEEK_API_KEY:
    st.error("""
**Critical Configuration Missing**
DeepSeek API key not found. Please ensure you have:
1. Created a Hugging Face Space secret named DEEPSEEK_API_KEY
2. Added your valid API key to the Space secrets
3. Restarted the Space after configuration
""")
    st.stop()
# Document embeddings are computed with OpenAIEmbeddings, which is constructed
# without an explicit key and therefore reads OPENAI_API_KEY from the
# environment. Check it here so a missing key produces a readable message
# instead of an exception while the vector stores are being built.
if not os.environ.get("OPENAI_API_KEY"):
    st.error("""
**Critical Configuration Missing**
OpenAI API key not found. Document embeddings use OpenAI, so please ensure you have:
1. Created a Hugging Face Space secret named OPENAI_API_KEY
2. Added your valid API key to the Space secrets
3. Restarted the Space after configuration
""")
    st.stop()

# ------------------------------
# ChromaDB Client Configuration
# ------------------------------
# On-disk location shared by the directory creation and the persistent client.
CHROMA_DIR = "chroma_db"
os.makedirs(CHROMA_DIR, exist_ok=True)
chroma_client = chromadb.PersistentClient(
    path=CHROMA_DIR,
    settings=chromadb.config.Settings(anonymized_telemetry=False),
)
# ------------------------------
# Document Processing Utilities
# ------------------------------
def deduplicate_docs(docs: List[Document]) -> List[Document]:
    """Return *docs* with repeated page contents removed.

    Two documents are duplicates when the SHA-256 digest of their
    ``page_content`` is equal; metadata is not compared. The first occurrence
    is kept and the input order is preserved.
    """
    digests_seen = set()
    kept = []
    for candidate in docs:
        digest = hashlib.sha256(candidate.page_content.encode()).hexdigest()
        if digest in digests_seen:
            continue
        digests_seen.add(digest)
        kept.append(candidate)
    return kept
# ------------------------------
# Data Preparation
# ------------------------------
research_texts = [
    "Research Report: Results of a New AI Model Improving Image Recognition Accuracy to 98%",
    "Academic Paper Summary: Why Transformers Became the Mainstream Architecture in Natural Language Processing",
    "Latest Trends in Machine Learning Methods Using Quantum Computing",
]
development_texts = [
    "Project A: UI Design Completed, API Integration in Progress",
    "Project B: Testing New Feature X, Bug Fixes Needed",
    "Product Y: In the Performance Optimization Stage Before Release",
]


def _tagged_metadata(source, id_prefix, texts):
    """Build one metadata dict per text: ``{"source": ..., "doc_id": "<prefix>_<i>"}``."""
    return [
        {"source": source, "doc_id": f"{id_prefix}_{position}"}
        for position, _ in enumerate(texts)
    ]


# Split each text into chunks measured in characters (length_function=len);
# add_start_index=True records each chunk's offset in its metadata.
splitter = RecursiveCharacterTextSplitter(
    chunk_size=150, chunk_overlap=20, length_function=len, add_start_index=True
)
research_docs = splitter.create_documents(
    research_texts, metadatas=_tagged_metadata("research", "res", research_texts)
)
development_docs = splitter.create_documents(
    development_texts, metadatas=_tagged_metadata("development", "dev", development_texts)
)
# ------------------------------
# Vector Store Initialization
# ------------------------------
# NOTE: OpenAIEmbeddings authenticates against OpenAI (OPENAI_API_KEY), not
# with the DeepSeek key used for chat completions elsewhere in this file.
# `dimensions` is a declared OpenAIEmbeddings field, so it is passed directly;
# langchain_openai rejects declared fields supplied through model_kwargs.
embeddings = OpenAIEmbeddings(
    model="text-embedding-3-large",
    dimensions=1024,
)


def _stable_chunk_ids(docs):
    """Return one deterministic id per chunk, formatted "<doc_id>-<start_index>".

    Relies on the ``doc_id`` metadata assigned at document creation and the
    ``start_index`` added by the splitter (add_start_index=True).
    """
    return [f"{doc.metadata['doc_id']}-{doc.metadata['start_index']}" for doc in docs]


# The Chroma client is persistent and this module is executed again on every
# Streamlit rerun. Passing explicit, deterministic ids makes each re-index
# upsert the same entries instead of appending another copy of every chunk.
# NOTE(review): entries written earlier under random ids are not removed by
# this; clear chroma_db once if an existing store already holds duplicates.
research_vectorstore = Chroma.from_documents(
    documents=research_docs,
    embedding=embeddings,
    ids=_stable_chunk_ids(research_docs),
    client=chroma_client,
    collection_name="research_collection",
    collection_metadata={"hnsw:space": "cosine"},
)
development_vectorstore = Chroma.from_documents(
    documents=development_docs,
    embedding=embeddings,
    ids=_stable_chunk_ids(development_docs),
    client=chroma_client,
    collection_name="development_collection",
    collection_metadata={"hnsw:space": "cosine"},
)
# ------------------------------
# Retriever Tools Configuration
# ------------------------------
# Research store: MMR search, keeping k of the fetch_k nearest chunks.
research_retriever = research_vectorstore.as_retriever(
    search_type="mmr",
    search_kwargs=dict(k=5, fetch_k=10),
)
# Development store: plain similarity search for the k nearest chunks.
development_retriever = development_vectorstore.as_retriever(
    search_type="similarity",
    search_kwargs=dict(k=5),
)
# Tool wrappers around the two retrievers (name + description per tool).
tools = [
    create_retriever_tool(
        retriever=research_retriever,
        name="research_database",
        description="Searches through academic papers and research reports for technical AI advancements",
    ),
    create_retriever_tool(
        retriever=development_retriever,
        name="development_database",
        description="Accesses current project statuses and development timelines",
    ),
]
# ------------------------------
# Agent State Definition
# ------------------------------
class AgentState(TypedDict):
    """State shared by every node of the LangGraph workflow."""

    # Conversation history. The add_messages reducer merges each node's returned
    # {"messages": [...]} into this list rather than replacing it, which is why
    # the nodes in this file return only the messages they add.
    messages: Annotated[Sequence[AIMessage | HumanMessage | ToolMessage], add_messages]
# ------------------------------
# Core Agent Function
# ------------------------------
def _latest_query(messages):
    """Return the query text the agent should route.

    Scans the history newest-first and returns the content of the first
    HumanMessage, or the rewritten text of a "Revised Query: ..." AIMessage
    produced by rewrite(), whichever comes first. Returns "" if neither exists.
    """
    for msg in reversed(messages):
        if isinstance(msg, HumanMessage):
            return msg.content
        if (
            isinstance(msg, AIMessage)
            and isinstance(msg.content, str)
            and msg.content.startswith("Revised Query: ")
        ):
            return msg.content.removeprefix("Revised Query: ")
    return ""


def _search_keywords(response_text, marker):
    """Return the keywords that follow *marker* in the model's reply.

    Only the first non-empty line after the marker is used, and surrounding
    square brackets (echoed from the prompt's "[keywords]" template) are
    stripped, so trailing explanation text is not sent to the retriever.
    """
    remainder = response_text.split(marker, 1)[1].strip()
    first_line = remainder.splitlines()[0] if remainder else ""
    return first_line.strip().strip("[]").strip()


def agent(state: AgentState):
    """Main decision-making agent handling user queries.

    Asks the DeepSeek chat API to classify the latest query. If the reply
    contains "SEARCH_RESEARCH:" or "SEARCH_DEV:", the keywords after the marker
    are run against the matching retriever, and the deduplicated hits are
    attached to the returned AIMessage as ``additional_kwargs["documents"]``.
    Any other reply is returned verbatim as the answer.

    Returns:
        A partial state update ``{"messages": [AIMessage]}``. HTTP and other
        errors are reported as an AIMessage instead of being raised.
    """
    print("\n--- AGENT EXECUTION START ---")
    messages = state["messages"]
    try:
        # After rewrite() the newest message is the revised query, not a
        # HumanMessage, so look back through the history for the query text.
        user_message = _latest_query(messages)
        # Construct analysis prompt
        prompt = f"""Analyze this user query and determine the appropriate action:
Query: {user_message}
Response Format:
- If research-related (technical details, academic concepts), respond:
SEARCH_RESEARCH: [keywords]
- If development-related (project status, timelines), respond:
SEARCH_DEV: [keywords]
- If general question, answer directly
- If unclear, request clarification
"""
        headers = {
            "Accept": "application/json",
            "Authorization": f"Bearer {DEEPSEEK_API_KEY}",
            "Content-Type": "application/json",
        }
        data = {
            "model": "deepseek-chat",
            "messages": [{"role": "user", "content": prompt}],
            "temperature": 0.5,
            "max_tokens": 256,
        }
        response = requests.post(
            "https://api.deepseek.com/v1/chat/completions",
            headers=headers,
            json=data,
            timeout=30,
        )
        response.raise_for_status()
        response_text = response.json()['choices'][0]['message']['content']
        print(f"Agent Decision: {response_text}")

        # (marker the model is told to emit, retriever to query, label used in
        # the action message). Checked in order, so research wins if both appear.
        routes = (
            ("SEARCH_RESEARCH:", research_retriever, "research_database"),
            ("SEARCH_DEV:", development_retriever, "development_database"),
        )
        for marker, retriever, tool_label in routes:
            if marker in response_text:
                query = _search_keywords(response_text, marker)
                unique_results = deduplicate_docs(retriever.invoke(query))
                return {
                    "messages": [
                        AIMessage(
                            content=f'Action: {tool_label}\nQuery: "{query}"\nResults: {len(unique_results)} relevant documents',
                            additional_kwargs={"documents": unique_results},
                        )
                    ]
                }
        # No search marker: the model answered (or asked for clarification) directly.
        return {"messages": [AIMessage(content=response_text)]}
    except requests.exceptions.HTTPError as e:
        error_msg = f"API Error: {e.response.status_code} - {e.response.text}"
        if "insufficient balance" in e.response.text.lower():
            error_msg += "\n\nPlease check your DeepSeek account balance."
        return {"messages": [AIMessage(content=error_msg)]}
    except Exception as e:
        return {"messages": [AIMessage(content=f"Processing Error: {str(e)}")]}
# ------------------------------
# Document Evaluation Functions
# ------------------------------
def simple_grade_documents(state: AgentState):
    """Choose the node that follows retrieval.

    Returns "generate" when the newest message carries a non-empty
    ``additional_kwargs["documents"]`` value, otherwise "rewrite".
    """
    newest = state["messages"][-1]
    has_documents = bool(newest.additional_kwargs.get("documents"))
    if not has_documents:
        print("--- No Valid Documents Found ---")
        return "rewrite"
    print("--- Relevant Documents Found ---")
    return "generate"
def generate(state: AgentState):
    """Generate the final answer from the documents attached by the agent.

    Uses the first HumanMessage as the question and the documents stored in
    the newest message's ``additional_kwargs["documents"]``. It asks DeepSeek
    to synthesize an answer and appends a "Sources:" footer listing the
    distinct ``source`` metadata values.

    Returns:
        ``{"messages": [AIMessage]}``; any failure is reported as a
        "Generation Error: ..." AIMessage rather than raised.
    """
    print("\n--- GENERATING FINAL ANSWER ---")
    messages = state["messages"]
    try:
        user_question = next(
            (msg.content for msg in messages if isinstance(msg, HumanMessage)), ""
        )
        documents = messages[-1].additional_kwargs.get("documents", [])
        # sorted() keeps the "Sources:" footer stable between runs.
        sources = sorted({doc.metadata.get('source', 'unknown') for doc in documents})
        # One labelled line per document, so the model can cite each source
        # (requirement 2). Interpolating a Python list repr gave it no source
        # information and added repr quoting/escaping to the text.
        document_block = "\n".join(
            f"[{number}] ({doc.metadata.get('source', 'unknown')}) {doc.page_content}"
            for number, doc in enumerate(documents, start=1)
        )
        prompt = f"""Synthesize a technical answer using these documents:
Question: {user_question}
Documents:
{document_block}
Requirements:
1. Highlight quantitative metrics
2. Cite document sources (research/development)
3. Note temporal context
4. List potential applications
5. Mention limitations/gaps
"""
        headers = {
            "Accept": "application/json",
            "Authorization": f"Bearer {DEEPSEEK_API_KEY}",
            "Content-Type": "application/json",
        }
        data = {
            "model": "deepseek-chat",
            "messages": [{"role": "user", "content": prompt}],
            "temperature": 0.3,
            "max_tokens": 1024,
        }
        response = requests.post(
            "https://api.deepseek.com/v1/chat/completions",
            headers=headers,
            json=data,
            timeout=45,
        )
        response.raise_for_status()
        response_text = response.json()['choices'][0]['message']['content']
        formatted_answer = f"{response_text}\n\nSources: {', '.join(sources)}"
        return {"messages": [AIMessage(content=formatted_answer)]}
    except Exception as e:
        return {"messages": [AIMessage(content=f"Generation Error: {str(e)}")]}
def rewrite(state: AgentState):
    """Ask DeepSeek to restate the user's original question more clearly.

    The first HumanMessage in the history is sent with a "clarify" instruction.
    The model's reply is returned as an AIMessage prefixed with
    "Revised Query: ". Any failure, including a history with no HumanMessage,
    is returned as a "Rewriting Error: ..." AIMessage.
    """
    print("\n--- REWRITING QUERY ---")
    history = state["messages"]
    try:
        first_question = next(m.content for m in history if isinstance(m, HumanMessage))
        instruction = f"Clarify this query while preserving technical intent: {first_question}"
        request_headers = {
            "Accept": "application/json",
            "Authorization": f"Bearer {DEEPSEEK_API_KEY}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": "deepseek-chat",
            "messages": [{"role": "user", "content": instruction}],
            "temperature": 0.5,
            "max_tokens": 256,
        }
        reply = requests.post(
            "https://api.deepseek.com/v1/chat/completions",
            headers=request_headers,
            json=payload,
            timeout=30,
        )
        reply.raise_for_status()
        clarified = reply.json()["choices"][0]["message"]["content"]
        return {"messages": [AIMessage(content=f"Revised Query: {clarified}")]}
    except Exception as exc:
        return {"messages": [AIMessage(content=f"Rewriting Error: {str(exc)}")]}
# ------------------------------
# Workflow Configuration
# ------------------------------
def _route_after_agent(state: AgentState):
    """Return "tools" when agent() attached search results, otherwise END.

    agent() puts retriever hits (possibly an empty list) in
    ``additional_kwargs["documents"]`` of its search-action messages. Direct
    answers and error messages carry no such key, so they end the run. Routing
    on this key avoids misrouting free-text answers that merely mention a
    tool name.
    """
    return "tools" if "documents" in state["messages"][-1].additional_kwargs else END


def _retrieval_checkpoint(state: AgentState):
    """No-op "retrieve" step: agent() has already queried the vector stores.

    A ToolNode would only execute ``tool_calls`` on the last AIMessage, and
    agent() never emits any, so this step adds no messages.
    """
    return {"messages": []}


workflow = StateGraph(AgentState)
# Node Registration
workflow.add_node("agent", agent)
workflow.add_node("retrieve", _retrieval_checkpoint)
workflow.add_node("generate", generate)
workflow.add_node("rewrite", rewrite)
# Workflow Structure
workflow.set_entry_point("agent")
workflow.add_conditional_edges(
    "agent",
    _route_after_agent,
    {"tools": "retrieve", END: END},
)
workflow.add_conditional_edges(
    "retrieve",
    simple_grade_documents,
    {"generate": "generate", "rewrite": "rewrite"},
)
workflow.add_edge("generate", END)
# NOTE(review): a retrieve -> rewrite -> agent cycle that never finds documents
# is bounded only by LangGraph's recursion limit.
workflow.add_edge("rewrite", "agent")
app = workflow.compile()
# ------------------------------
# Streamlit UI Implementation
# ------------------------------
def main():
    """Render the Streamlit interface and run the workflow on demand.

    When "Execute Analysis" is clicked with a non-empty query, the compiled
    graph is streamed, and status labels are updated per node. The last message
    any node produced is stored in ``st.session_state.final_answer`` and
    rendered below the form. That message may come from "generate", or from
    "agent" for direct answers, clarification requests and API errors.
    """
    st.set_page_config(
        page_title="AI Research Assistant",
        layout="centered",
        initial_sidebar_state="expanded",
    )
    # Dark Theme Configuration
    st.markdown("""
<style>
.stApp {
background-color: #0E1117;
color: #FAFAFA;
}
.stTextArea textarea {
background-color: #262730 !important;
color: #FAFAFA !important;
border: 1px solid #3D4051;
}
.stButton>button {
background-color: #2E8B57;
color: white;
border-radius: 4px;
padding: 0.5rem 1rem;
transition: all 0.3s;
}
.stButton>button:hover {
background-color: #3CB371;
transform: scale(1.02);
}
.stAlert {
background-color: #1A1D23 !important;
border: 1px solid #3D4051;
}
.stExpander {
background-color: #1A1D23;
border: 1px solid #3D4051;
}
.data-source {
padding: 0.5rem;
margin: 0.5rem 0;
background-color: #1A1D23;
border-left: 3px solid #2E8B57;
border-radius: 4px;
}
</style>
""", unsafe_allow_html=True)
    # Sidebar Configuration
    with st.sidebar:
        st.header("Technical Databases")
        with st.expander("Research Corpus", expanded=True):
            st.markdown("""
- AI Model Architectures
- Machine Learning Advances
- Quantum Computing Applications
- Algorithmic Breakthroughs
""")
        with st.expander("Development Tracking", expanded=True):
            st.markdown("""
- Project Milestones
- System Architecture
- Deployment Status
- Performance Metrics
""")
    # Main Interface
    st.title("🧠 AI Research Assistant")
    st.caption("Technical Analysis and Development Tracking System")
    query = st.text_area(
        "Enter Technical Query:",
        height=150,
        placeholder="Example: Compare transformer architectures for medical imaging analysis...",
    )
    if st.button("Execute Analysis", use_container_width=True):
        if not query:
            st.warning("Please input a technical query")
            return
        # Drop the previous run's answer so it is never shown for this query.
        st.session_state.pop("final_answer", None)
        with st.status("Processing...", expanded=True) as status:
            try:
                # Output of the most recent node that added a message. This is
                # the answer whether the graph ended at "generate" or "agent".
                last_output = None
                for event in app.stream({"messages": [HumanMessage(content=query)]}):
                    if 'agent' in event:
                        status.update(label="Decision Making", state="running")
                        st.session_state.agent_step = event['agent']
                    if 'retrieve' in event:
                        status.update(label="Document Retrieval", state="running")
                        st.session_state.retrieved = event['retrieve']
                    if 'generate' in event:
                        status.update(label="Synthesizing Answer", state="running")
                    for node_output in event.values():
                        if isinstance(node_output, dict) and node_output.get("messages"):
                            last_output = node_output
                if last_output is not None:
                    st.session_state.final_answer = last_output
                status.update(label="Analysis Complete", state="complete")
            except Exception as e:
                status.update(label="Processing Failed", state="error")
                st.error(f"""
**System Error**
{str(e)}
Please verify:
- API key validity
- Network connectivity
- Query complexity
""")
    if 'final_answer' in st.session_state:
        answer = st.session_state.final_answer['messages'][-1].content
        with st.container():
            st.subheader("Technical Analysis")
            st.markdown("---")
            st.markdown(answer)
            if "Sources:" in answer:
                st.markdown("""
<div class="data-source">
ℹ️ Document sources are derived from the internal research database
</div>
""", unsafe_allow_html=True)
# Script entry point: `streamlit run` executes this file as __main__, so the
# UI is rendered on every script run.
if __name__ == "__main__":
    main()