# Spaces: Running
"""
AI Research Assistant
"""
# ------------------------------
# Core Imports & Configuration
# ------------------------------
import html
import os
import re
import time
from typing import Sequence, Tuple

import chromadb
import requests
import streamlit as st
from chromadb.config import Settings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.tools.retriever import create_retriever_tool
from langchain_community.vectorstores import Chroma
from langchain_core.messages import HumanMessage, AIMessage
from langchain_openai import OpenAIEmbeddings
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode
from typing_extensions import TypedDict, Annotated
# ------------------------------
# Type Definitions
# ------------------------------
class AgentState(TypedDict):
    """State dictionary passed between the nodes of the workflow graph."""

    # Conversation history; annotated with ``add_messages`` as the reducer
    # for values that nodes return under the "messages" key.
    messages: Annotated[Sequence[AIMessage | HumanMessage], add_messages]
# ------------------------------
# Configuration & Constants
# ------------------------------
class Config:
    """Static configuration shared by the assistant and the UI."""

    # DeepSeek API key, read from the environment once at import time;
    # ``None`` when the variable is not set.
    API_KEY = os.environ.get("DEEPSEEK_API_KEY")

    # Directory handed to the persistent Chroma client.
    CHROMA_PATH = "chroma_db"

    # Keyword arguments forwarded to RecursiveCharacterTextSplitter.
    TEXT_SPLITTER_CONFIG = {
        "chunk_size": 512,
        "chunk_overlap": 128,
        "separators": ["\n\n", "\n", ". ", "! ", "? "],
    }
# ------------------------------
# Core System Components
# ------------------------------
class ResearchAssistant:
    """Assistant that analyzes a query, optionally retrieves, then synthesizes.

    Construction builds, in order: the embedding model, two Chroma vector
    stores (research and development), one retriever tool per store, and the
    compiled graph exposed as ``self.workflow``.
    """

    # DeepSeek chat-completions endpoint and model used by both LLM nodes.
    _CHAT_URL = "https://api.deepseek.com/v1/chat/completions"
    _CHAT_MODEL = "deepseek-chat"

    # Substrings that route a request through the "retrieve" node.
    _RETRIEVAL_KEYWORDS = ("research", "study", "project", "develop", "trend")

    def __init__(self):
        self.embeddings = OpenAIEmbeddings(model="text-embedding-3-large")
        self.vector_stores = self._init_vector_stores()
        self.tools = self._create_tools()
        self.workflow = self._build_workflow()

    def _init_vector_stores(self) -> "Tuple[Chroma, Chroma]":
        """Build the research and development vector stores.

        Returns:
            A ``(research_store, development_store)`` pair backed by one
            persistent Chroma client rooted at ``Config.CHROMA_PATH``.
        """
        splitter = RecursiveCharacterTextSplitter(**Config.TEXT_SPLITTER_CONFIG)
        research_docs = splitter.create_documents([
            "Research Report: New AI Model Achieves 98% Image Recognition Accuracy",
            "Transformers: The New NLP Architecture Standard",
            "Quantum Machine Learning: Emerging Trends and Applications",
        ])
        development_docs = splitter.create_documents([
            "Project A: UI Design Finalized, API Integration Phase",
            "Project B: Feature Testing and Bug Fixes",
            "Product Y: Performance Optimization Pre-Release",
        ])
        client = chromadb.PersistentClient(
            path=Config.CHROMA_PATH,
            settings=Settings(anonymized_telemetry=False),
        )
        return (
            self._build_store(client, "research", research_docs),
            self._build_store(client, "development", development_docs),
        )

    def _build_store(self, client, collection_name: str, documents) -> "Chroma":
        """Create one Chroma collection from ``documents`` with stable ids."""
        # The client persists to disk, so without explicit ids every new
        # ResearchAssistant would add another copy of the same seed documents.
        # Deterministic ids keep the collection at one record per document.
        ids = [f"{collection_name}-{index}" for index in range(len(documents))]
        return Chroma.from_documents(
            documents,
            self.embeddings,
            ids=ids,
            client=client,
            collection_name=collection_name,
        )

    def _create_tools(self):
        """Create one retriever tool per vector store.

        Returns:
            ``[research_db, development_db]`` retriever tools.
        """
        research_store, development_store = self.vector_stores
        # ``score_threshold`` is only applied when the retriever is created
        # with the "similarity_score_threshold" search type.
        retriever_options = {
            "search_type": "similarity_score_threshold",
            "search_kwargs": {"k": 3, "score_threshold": 0.7},
        }
        return [
            create_retriever_tool(
                research_store.as_retriever(**retriever_options),
                "research_db",
                "Access technical research papers and reports",
            ),
            create_retriever_tool(
                development_store.as_retriever(**retriever_options),
                "development_db",
                "Retrieve project development status updates",
            ),
        ]

    def _build_workflow(self):
        """Construct and compile the analyze -> [retrieve] -> synthesize graph."""
        workflow = StateGraph(AgentState)
        workflow.add_node("analyze", self.analyze_query)
        # NOTE(review): ToolNode is expected to run tool calls attached to the
        # latest AI message, while analyze_query returns plain text -- confirm
        # that this node actually performs a retrieval.
        workflow.add_node("retrieve", ToolNode(self.tools))
        workflow.add_node("synthesize", self.synthesize_response)
        workflow.set_entry_point("analyze")
        workflow.add_conditional_edges(
            "analyze",
            self._needs_retrieval,
            {"retrieve": "retrieve", "direct": "synthesize"},
        )
        workflow.add_edge("retrieve", "synthesize")
        workflow.add_edge("synthesize", END)
        return workflow.compile()

    def _needs_retrieval(self, state: "AgentState") -> str:
        """Pick the branch to follow after the "analyze" node.

        Returns:
            ``"retrieve"`` when the lower-cased content of the latest message
            contains any retrieval keyword, otherwise ``"direct"``.
        """
        latest = state["messages"][-1].content.lower()
        if any(keyword in latest for keyword in self._RETRIEVAL_KEYWORDS):
            return "retrieve"
        return "direct"

    def _chat_completion(self, prompt: str, temperature: float, timeout: int) -> str:
        """Send one user prompt to the chat endpoint and return the reply text.

        Raises:
            requests.HTTPError: when the response status is an error
                (via ``raise_for_status``). Other request or parsing
                failures propagate unchanged.
        """
        response = requests.post(
            self._CHAT_URL,
            headers={
                "Authorization": f"Bearer {Config.API_KEY}",
                "Content-Type": "application/json",
            },
            json={
                "model": self._CHAT_MODEL,
                "messages": [{"role": "user", "content": prompt}],
                "temperature": temperature,
            },
            timeout=timeout,
        )
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"]

    def analyze_query(self, state: "AgentState"):
        """Ask the model to categorize the latest message.

        Returns:
            ``{"messages": [AIMessage]}`` holding either the model's analysis
            or, on any failure, an "Analysis Error" message.
        """
        try:
            user_input = state["messages"][-1].content
            prompt = (
                "Analyze this query and format as:\n"
                "CATEGORY: [RESEARCH|DEVELOPMENT|GENERAL]\n"
                "KEY_TERMS: comma-separated list\n"
                f"{user_input}"
            )
            analysis = self._chat_completion(prompt, temperature=0.3, timeout=15)
            return {"messages": [AIMessage(content=analysis)]}
        except Exception as e:
            # Graph-node boundary: report the failure as a message instead of
            # aborting the whole workflow.
            return {"messages": [AIMessage(
                content=f"Analysis Error: {str(e)}. Please rephrase your question."
            )]}

    def synthesize_response(self, state: "AgentState"):
        """Ask the model to synthesize a final answer from all messages so far.

        Returns:
            ``{"messages": [AIMessage]}`` holding either the synthesized
            answer or, on any failure, a "Synthesis Error" message.
        """
        try:
            context = "\n".join([msg.content for msg in state["messages"]])
            prompt = (
                "Synthesize this information:\n"
                f"{context}\n"
                "Include:\n"
                "1. Key findings\n"
                "2. Supporting evidence\n"
                "3. Technical details\n"
                "4. Potential applications"
            )
            answer = self._chat_completion(prompt, temperature=0.5, timeout=20)
            return {"messages": [AIMessage(content=answer)]}
        except Exception as e:
            # Graph-node boundary: report the failure as a message instead of
            # aborting the whole workflow.
            return {"messages": [AIMessage(
                content=f"Synthesis Error: {str(e)}. Please try again later."
            )]}
# ------------------------------
# Professional UI Interface
# ------------------------------
def main():
    """Render the Streamlit page and run submitted queries through the assistant.

    The left column holds the query form and, after a submission, the
    workflow's final message; the right column lists the knowledge-base
    contents. Any exception raised while processing a query is shown with
    ``st.error`` instead of propagating.
    """
    st.set_page_config(
        page_title="Research Assistant Pro",
        layout="wide",
        initial_sidebar_state="expanded",
    )

    # Dark theme implementation
    st.markdown("""
<style>
.stApp {
background-color: #0f1114;
color: #ffffff;
}
.stTextInput input, .stTextArea textarea {
background-color: #1e1e24 !important;
color: #ffffff !important;
}
.stButton>button {
background: #2563eb;
transition: all 0.2s;
}
.stButton>button:hover {
background: #1d4ed8;
transform: scale(1.02);
}
.result-card {
background: #1a1a1f;
border-radius: 8px;
padding: 1.5rem;
margin: 1rem 0;
}
</style>
""", unsafe_allow_html=True)

    # NOTE(review): the leading "π" may be a mis-encoded emoji -- confirm the
    # intended glyph.
    st.title("π Research Assistant Pro")
    st.write("Advanced AI-Powered Research Analysis")

    col1, col2 = st.columns([1, 2])

    with col1:
        with st.form("query_form"):
            query = st.text_area(
                "Research Query:",
                height=150,
                placeholder="Enter your research question...",
            )
            submitted = st.form_submit_button("Analyze")

        if submitted and query:
            with st.spinner("Processing..."):
                try:
                    assistant = ResearchAssistant()
                    result = assistant.workflow.invoke({"messages": [
                        HumanMessage(content=query)
                    ]})
                    # The answer is model-generated text rendered with
                    # unsafe_allow_html=True, so escape it before embedding
                    # it in the card markup.
                    safe_answer = html.escape(str(result["messages"][-1].content))
                    with st.expander("Analysis Details", expanded=True):
                        st.markdown(
                            "\n"
                            '<div class="result-card">\n'
                            f"{safe_answer}\n"
                            "</div>\n",
                            unsafe_allow_html=True,
                        )
                except Exception as e:
                    # UI boundary: surface the failure to the user.
                    st.error(f"Processing Error: {str(e)}")

    with col2:
        st.subheader("Knowledge Base")
        with st.expander("Research Documents"):
            st.info(
                "\n"
                "- Advanced Image Recognition Systems\n"
                "- Transformer Architecture Analysis\n"
                "- Quantum ML Research\n"
            )
        with st.expander("Development Updates"):
            st.info(
                "\n"
                "- Project A: API Integration Phase\n"
                "- Project B: Feature Testing\n"
                "- Product Y: Optimization Stage\n"
            )
if __name__ == "__main__":
    # Show an error and stop the script when no DeepSeek API key is configured.
    if not Config.API_KEY:
        st.error(
            "\n"
            "π Configuration Required:\n"
            "Set DEEPSEEK_API_KEY environment variable\n"
        )
        st.stop()
    main()