# mgbam's picture
# Update app.py
# 80d22c8 verified
# raw
# history blame
# 18.2 kB
# ------------------------------
# Imports & Dependencies
# ------------------------------
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langgraph.graph import END, StateGraph
from langgraph.prebuilt import ToolNode
from langgraph.graph.message import add_messages
from typing_extensions import TypedDict, Annotated
from typing import Sequence, List, Dict, Any
import chromadb
import re
import os
import streamlit as st
import requests
import time
import hashlib
from langchain.tools.retriever import create_retriever_tool
from datetime import datetime
# ------------------------------
# Data Definitions
# ------------------------------
# Sample corpus indexed into the research collection and listed in the sidebar.
research_texts = [
    "Research Report: Results of a New AI Model Improving Image Recognition Accuracy to 98%",
    "Academic Paper Summary: Why Transformers Became the Mainstream Architecture in Natural Language Processing",
    "Latest Trends in Machine Learning Methods Using Quantum Computing",
    "Advancements in Neuromorphic Computing for Energy-Efficient AI Systems",
    "Cross-Modal Learning: Integrating Visual and Textual Representations for Multimodal AI",
]
# Sample corpus indexed into the development collection and listed in the sidebar.
development_texts = [
    "Project A: UI Design Completed, API Integration in Progress",
    "Project B: Testing New Feature X, Bug Fixes Needed",
    "Product Y: In the Performance Optimization Stage Before Release",
    "Framework Z: Version 3.2 Released with Enhanced Distributed Training Support",
    "DevOps Pipeline: Automated CI/CD Implementation for ML Model Deployment",
]
# ------------------------------
# Configuration Class
# ------------------------------
class AppConfig:
    """Runtime settings for the assistant, validated at construction time.

    Attributes are plain instance fields so the rest of the module can read
    them as ``config.NAME``. Constructing the object calls
    :meth:`validate_config`, which stops the Streamlit script when the
    DeepSeek API key is missing.
    """

    def __init__(self):
        # Secrets pasted into a settings UI often carry stray whitespace or a
        # trailing newline, which is not valid inside an HTTP header value.
        # Normalise here so every "Bearer <key>" header built later is clean.
        raw_key = os.environ.get("DEEPSEEK_API_KEY")
        self.DEEPSEEK_API_KEY = raw_key.strip() if raw_key else raw_key
        self.CHROMA_PATH = "chroma_db"          # directory of the persistent Chroma store
        self.MAX_RETRIES = 3                    # attempts made by the retrying API helper
        self.RETRY_DELAY = 1.5                  # base of the exponential back-off, in seconds
        self.DOCUMENT_CHUNK_SIZE = 300          # splitter chunk size (characters)
        self.DOCUMENT_OVERLAP = 50              # splitter overlap between chunks (characters)
        self.SEARCH_K = 5
        self.SEARCH_TYPE = "mmr"
        self.validate_config()

    def validate_config(self):
        """Show setup instructions and halt the script if no API key is set.

        An empty or whitespace-only key counts as missing.
        """
        if not self.DEEPSEEK_API_KEY:
            # The message body stays at column 0: it is rendered as Markdown,
            # where leading indentation would change the formatting.
            st.error("""
**Critical Configuration Missing**
🔑 DeepSeek API key not found in environment variables.
Please configure through Hugging Face Space secrets:
1. Go to Space Settings → Repository secrets
2. Add secret: Name=DEEPSEEK_API_KEY, Value=your_api_key
3. Rebuild Space
""")
            st.stop()
# Module-level settings object read by the components below.
# Constructing it runs AppConfig.validate_config() as a side effect.
config = AppConfig()
# ------------------------------
# ChromaDB Manager
# ------------------------------
class ChromaManager:
    """Owns the persistent Chroma client and the two document collections.

    Attributes:
        client: ``chromadb.PersistentClient`` rooted at ``config.CHROMA_PATH``.
        embeddings: Embedding function shared by both collections.
        research_collection: Vector store built from ``research_data``.
        dev_collection: Vector store built from ``development_data``.
    """

    def __init__(self, research_data: List[str], development_data: List[str]):
        os.makedirs(config.CHROMA_PATH, exist_ok=True)
        self.client = chromadb.PersistentClient(path=config.CHROMA_PATH)
        self.embeddings = OpenAIEmbeddings(model="text-embedding-3-large")
        self.research_collection = self.create_collection(
            research_data,
            "research_collection",
            {"category": "research", "version": "1.2"},
        )
        self.dev_collection = self.create_collection(
            development_data,
            "development_collection",
            {"category": "development", "version": "1.1"},
        )

    def create_collection(self, documents: List[str], name: str, metadata: dict) -> Chroma:
        """Split ``documents`` into chunks and store them in collection ``name``.

        Args:
            documents: Raw texts to index.
            name: Chroma collection name.
            metadata: Collection-level metadata passed through to Chroma.

        Returns:
            The ``Chroma`` vector store wrapping the collection.
        """
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=config.DOCUMENT_CHUNK_SIZE,
            chunk_overlap=config.DOCUMENT_OVERLAP,
            # Third separator is the ideographic full stop; the previous
            # literal was its mis-decoded (mojibake) form and never matched.
            separators=["\n\n", "\n", "。", " "],
        )
        docs = text_splitter.create_documents(documents)

        # The store is persistent and this method runs each time the script is
        # executed. Without explicit ids every run would insert the same
        # chunks again under fresh random ids. Content-derived ids make the
        # insert idempotent, and keying the dict on them also drops identical
        # chunks within one batch (duplicate ids in one call are rejected).
        chunks_by_id = {}
        for doc in docs:
            digest = hashlib.sha256(doc.page_content.encode("utf-8")).hexdigest()
            chunks_by_id.setdefault(f"{name}-{digest}", doc)

        return Chroma.from_documents(
            documents=list(chunks_by_id.values()),
            embedding=self.embeddings,
            ids=list(chunks_by_id),
            client=self.client,
            collection_name=name,
            collection_metadata=metadata,
        )
# Initialize Chroma with data.
# This runs at import time: it constructs ChromaManager, which builds both
# collections from the in-file sample texts.
chroma_manager = ChromaManager(research_texts, development_texts)
# ------------------------------
# Document Processing
# ------------------------------
class DocumentProcessor:
    """Stateless helpers for post-processing retrieved documents.

    Both helpers only require that each item exposes a ``page_content``
    string attribute.
    """

    # (keywords, bullet) pairs: a bullet is emitted when any of its keywords
    # occurs in a document's lower-cased text.
    _KEY_POINT_RULES = (
        (
            ("quantum", "qpu", "qubit"),
            "- Quantum computing integration showing promising results",
        ),
        (
            ("image", "recognition", "vision"),
            "- Computer vision models achieving state-of-the-art accuracy",
        ),
        (
            ("transformer", "language", "llm"),
            "- NLP architectures evolving with memory-augmented transformers",
        ),
    )

    @staticmethod
    def deduplicate_documents(docs: List[Any]) -> List[Any]:
        """Return ``docs`` without repeated ``page_content``, keeping first occurrences.

        Args:
            docs: Objects with a ``page_content`` string attribute.

        Returns:
            A new list in the original order; ``[]`` for empty input.
        """
        seen = set()
        unique_docs = []
        for doc in docs:
            content_hash = hashlib.md5(doc.page_content.encode()).hexdigest()
            if content_hash not in seen:
                seen.add(content_hash)
                unique_docs.append(doc)
        return unique_docs

    @staticmethod
    def extract_key_points(docs: List[Any]) -> str:
        """Summarise which topic areas appear in ``docs`` as Markdown bullets.

        Args:
            docs: Objects with a ``page_content`` string attribute.

        Returns:
            Newline-joined bullets, each at most once, in first-seen order;
            an empty string when nothing matches.
        """
        # A dict is used as an insertion-ordered set. De-duplicating through
        # set() made the bullet order depend on per-process string hashing,
        # so the same documents could yield differently ordered text per run.
        found = {}
        for doc in docs:
            content = doc.page_content.lower()
            for keywords, bullet in DocumentProcessor._KEY_POINT_RULES:
                if any(kw in content for kw in keywords):
                    found.setdefault(bullet, None)
        return "\n".join(found)
# ------------------------------
# Enhanced Agent Components
# ------------------------------
class EnhancedAgent:
    """Holds per-session counters and a retrying JSON POST helper."""

    def __init__(self):
        # Fresh containers per instance; nothing in this class updates them.
        self.session_stats = {
            "processing_times": [],
            "doc_counts": [],
            "error_count": 0,
        }

    def api_request_with_retry(self, endpoint: str, payload: Dict) -> Dict:
        """POST ``payload`` as JSON to ``endpoint``, retrying on HTTP 429.

        Args:
            endpoint: Full URL to post to.
            payload: JSON-serialisable request body.

        Returns:
            The decoded JSON response body.

        Raises:
            requests.exceptions.HTTPError: For any non-429 HTTP error status.
            Exception: When every one of ``config.MAX_RETRIES`` attempts was
                rate-limited; the last HTTP error is chained as the cause.
        """
        headers = {
            "Authorization": f"Bearer {config.DEEPSEEK_API_KEY}",
            "Content-Type": "application/json",
        }
        last_rate_limit = None
        for attempt in range(config.MAX_RETRIES):
            try:
                # TLS certificate verification is left at the library default.
                # The request carries a bearer token, so it must not go to an
                # unverified peer (the previous code passed verify=False).
                response = requests.post(
                    endpoint,
                    headers=headers,
                    json=payload,
                    timeout=30,
                )
                response.raise_for_status()
                return response.json()
            except requests.exceptions.HTTPError as e:
                if e.response is None or e.response.status_code != 429:
                    raise
                last_rate_limit = e
                # Exponential back-off, skipped after the final attempt since
                # nothing follows it but the failure below.
                if attempt < config.MAX_RETRIES - 1:
                    time.sleep(config.RETRY_DELAY ** (attempt + 1))
        raise Exception(
            f"API request failed after {config.MAX_RETRIES} attempts"
        ) from last_rate_limit
# ------------------------------
# Workflow Configuration
# ------------------------------
class AgentState(TypedDict):
    """State dictionary passed between the workflow nodes."""

    # Conversation messages. The `add_messages` annotation is presumably the
    # LangGraph reducer that merges each node's returned messages into this
    # sequence - confirm against the LangGraph documentation.
    messages: Annotated[Sequence[AIMessage | HumanMessage | ToolMessage], add_messages]
def agent(state: AgentState):
    """Route the user's question: search a knowledge base or answer directly.

    The first message in ``state`` is sent to DeepSeek with routing
    instructions. If the reply contains ``SEARCH_RESEARCH:`` or
    ``SEARCH_DEV:``, the text after the marker is used to query the matching
    Chroma collection, and the hits are embedded in an ``Action: ...`` message.
    Otherwise the model's reply is returned unchanged.

    Args:
        state: Graph state; ``state["messages"][0]`` is the user question,
            either a message object or a ``(role, text)`` tuple.

    Returns:
        ``{"messages": [AIMessage]}`` holding the action/results text, the
        direct answer, or an ``API Error: ...`` description if anything
        inside the request/lookup step raised.
    """
    print("---CALL AGENT---")
    messages = state["messages"]
    first_message = messages[0]
    user_message = first_message[1] if isinstance(first_message, tuple) else first_message.content
    prompt = f"""Given this user question: "{user_message}"
If about research/academic topics, respond EXACTLY:
SEARCH_RESEARCH: <search terms>
If about development status, respond EXACTLY:
SEARCH_DEV: <search terms>
Otherwise, answer directly."""
    headers = {
        "Accept": "application/json",
        "Authorization": f"Bearer {config.DEEPSEEK_API_KEY}",
        "Content-Type": "application/json",
    }
    data = {
        "model": "deepseek-chat",
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0.7,
        "max_tokens": 1024,
    }
    try:
        # TLS verification stays at the library default: this request carries
        # the API key, so verify=False (used previously) would expose it to a
        # man-in-the-middle.
        response = requests.post(
            "https://api.deepseek.com/v1/chat/completions",
            headers=headers,
            json=data,
            timeout=30,
        )
        response.raise_for_status()
        response_text = response.json()['choices'][0]['message']['content']

        # (marker, tool name, collection), checked in order: research first.
        routes = (
            ("SEARCH_RESEARCH:", "research_db_tool", chroma_manager.research_collection),
            ("SEARCH_DEV:", "development_db_tool", chroma_manager.dev_collection),
        )
        for marker, tool_name, collection in routes:
            if marker in response_text:
                query = response_text.split(marker)[1].strip()
                results = collection.as_retriever().invoke(query)
                return {"messages": [AIMessage(content=f'Action: {tool_name}\n{{"query": "{query}"}}\n\nResults: {str(results)}')]}
        return {"messages": [AIMessage(content=response_text)]}
    except Exception as e:
        # Node boundary: report the failure as a message rather than raising.
        error_msg = f"API Error: {str(e)}"
        if "Insufficient Balance" in str(e):
            error_msg += "\n\nPlease check your DeepSeek API account balance."
        return {"messages": [AIMessage(content=error_msg)]}
def simple_grade_documents(state: AgentState):
    """Pick the next node from the newest message's text.

    Returns ``"generate"`` when that text contains ``"Results: [Document"``,
    otherwise ``"rewrite"``.
    """
    newest = state["messages"][-1]
    if "Results: [Document" in newest.content:
        return "generate"
    return "rewrite"
# Matches the quoted ``page_content=...`` literal inside a Document repr,
# accepting single or double quotes and backslash escapes.
_GENERATE_PAGE_CONTENT_RE = re.compile(
    r"""page_content=('(?:[^'\\]|\\.)*'|"(?:[^"\\]|\\.)*")"""
)


def _parse_retrieved_documents(results_text: str) -> list:
    """Rebuild documents from the ``str(list_of_documents)`` text in a message.

    Only ``page_content`` is recovered (metadata is not); that is the only
    attribute the downstream ``DocumentProcessor`` helpers read.
    """
    import ast
    from langchain_core.documents import Document

    documents = []
    for literal in _GENERATE_PAGE_CONTENT_RE.findall(results_text):
        try:
            text = ast.literal_eval(literal)
        except (ValueError, SyntaxError):
            # Best effort: skip a fragment that is not a valid string literal.
            continue
        documents.append(Document(page_content=text))
    return documents


def generate(state: AgentState):
    """Produce the final structured summary from the retrieved documents.

    Documents are recovered from the ``Results: [...]`` part of the newest
    message, de-duplicated, reduced to key-point bullets, and sent to
    DeepSeek with a formatting prompt.

    Args:
        state: Graph state; only the last message is read.

    Returns:
        ``{"messages": [AIMessage]}`` with the summary, or a
        ``Generation Error: ...`` message if the API call fails.
    """
    last_message = state["messages"][-1]
    docs_content = []
    if "Results: [" in last_message.content:
        docs_str = last_message.content.split("Results: ", 1)[1]
        # SECURITY: this text was previously passed to eval(). It comes from
        # model/retriever output, so eval() could run arbitrary code, and it
        # also raised NameError because `Document` is not imported here.
        # Parse the string literals instead.
        docs_content = _parse_retrieved_documents(docs_str)
    processed_info = DocumentProcessor.extract_key_points(
        DocumentProcessor.deduplicate_documents(docs_content)
    )
    prompt = f"""Generate structured research summary:
Key Information:
{processed_info}
Include:
1. Section headings
2. Bullet points
3. Significance
4. Applications"""
    try:
        response = requests.post(
            "https://api.deepseek.com/v1/chat/completions",
            headers={
                "Authorization": f"Bearer {config.DEEPSEEK_API_KEY}",
                "Content-Type": "application/json",
            },
            json={
                "model": "deepseek-chat",
                "messages": [{"role": "user", "content": prompt}],
                "temperature": 0.7,
                "max_tokens": 1024,
            },
            timeout=30,
        )
        response.raise_for_status()
        return {"messages": [AIMessage(content=response.json()['choices'][0]['message']['content'])]}
    except Exception as e:
        # Node boundary: surface the failure as a message rather than raising.
        return {"messages": [AIMessage(content=f"Generation Error: {str(e)}")]}
def rewrite(state: AgentState):
    """Ask DeepSeek to rephrase the original question for clarity.

    Args:
        state: Graph state; ``state["messages"][0].content`` is the question.

    Returns:
        ``{"messages": [AIMessage]}`` with the rewritten question, or a
        ``Rewrite Error: ...`` message if the API call fails.
    """
    messages = state["messages"]
    original_question = messages[0].content
    try:
        response = requests.post(
            "https://api.deepseek.com/v1/chat/completions",
            headers={
                "Authorization": f"Bearer {config.DEEPSEEK_API_KEY}",
                "Content-Type": "application/json",
            },
            json={
                "model": "deepseek-chat",
                "messages": [{
                    "role": "user",
                    "content": f"Rewrite for clarity: {original_question}",
                }],
                "temperature": 0.7,
                "max_tokens": 1024,
            },
            timeout=30,
        )
        response.raise_for_status()
        rewritten = response.json()['choices'][0]['message']['content']
        # The original return statement lacked the closing "]" of the message
        # list, a syntax error that kept the whole module from compiling.
        return {"messages": [AIMessage(content=rewritten)]}
    except Exception as e:
        # Node boundary: surface the failure as a message rather than raising.
        return {"messages": [AIMessage(content=f"Rewrite Error: {str(e)}")]}
# Anchored at the start of the text by .match(): "Action: " plus the rest of
# the first line.
tools_pattern = re.compile(r"Action: .*")


def custom_tools_condition(state: AgentState):
    """Return ``"tools"`` when the newest message starts with ``Action: ``, else ``END``."""
    newest_text = state["messages"][-1].content
    if tools_pattern.match(newest_text):
        return "tools"
    return END
# ------------------------------
# Workflow Graph Setup
# ------------------------------
# Retriever tools served by the "retrieve" node, one per collection.
_retriever_tools = [
    create_retriever_tool(
        chroma_manager.research_collection.as_retriever(),
        "research_db_tool",
        "Search research database",
    ),
    create_retriever_tool(
        chroma_manager.dev_collection.as_retriever(),
        "development_db_tool",
        "Search development database",
    ),
]

workflow = StateGraph(AgentState)

# Nodes.
# NOTE(review): agent() emits plain-text "Action: ..." messages; confirm that
# ToolNode can act on those rather than requiring structured tool calls.
workflow.add_node("agent", agent)
workflow.add_node("retrieve", ToolNode(_retriever_tools))
workflow.add_node("rewrite", rewrite)
workflow.add_node("generate", generate)

# Edges: agent -> (retrieve | END); retrieve -> (generate | rewrite);
# generate -> END; rewrite -> agent.
workflow.set_entry_point("agent")
workflow.add_conditional_edges(
    "agent",
    custom_tools_condition,
    {"tools": "retrieve", END: END},
)
workflow.add_conditional_edges(
    "retrieve",
    simple_grade_documents,
    {"generate": "generate", "rewrite": "rewrite"},
)
workflow.add_edge("generate", END)
workflow.add_edge("rewrite", "agent")

app = workflow.compile()
# ------------------------------
# Streamlit UI
# ------------------------------
class UITheme:
    """Colour palette plus the CSS that applies it to the Streamlit page."""

    primary_color = "#2E86C1"
    secondary_color = "#28B463"
    background_color = "#1A1A1A"
    text_color = "#EAECEE"

    @classmethod
    def apply(cls):
        """Inject the theme stylesheet into the page via ``st.markdown``."""
        primary = cls.primary_color
        secondary = cls.secondary_color
        background = cls.background_color
        text = cls.text_color
        # The stylesheet body stays at column 0 so the emitted text is
        # unchanged; doubled braces are literal CSS braces in the f-string.
        stylesheet = f"""
<style>
.stApp {{ background-color: {background}; color: {text}; }}
.stTextArea textarea {{
background-color: #2D2D2D !important;
color: {text} !important;
border: 1px solid {primary};
}}
.stButton > button {{
background-color: {primary};
color: white;
border: none;
padding: 12px 28px;
border-radius: 6px;
transition: all 0.3s ease;
font-weight: 500;
}}
.stButton > button:hover {{
background-color: {secondary};
transform: translateY(-1px);
box-shadow: 0 4px 12px rgba(0,0,0,0.2);
}}
.data-box {{
background-color: #2D2D2D;
border-left: 4px solid {primary};
padding: 18px;
margin: 14px 0;
border-radius: 8px;
box-shadow: 0 2px 8px rgba(0,0,0,0.15);
}}
</style>
"""
        st.markdown(stylesheet, unsafe_allow_html=True)
# Matches the quoted ``page_content=...`` literal inside a Document repr,
# accepting single or double quotes and backslash escapes.
_UI_PAGE_CONTENT_RE = re.compile(
    r"""page_content=('(?:[^'\\]|\\.)*'|"(?:[^"\\]|\\.)*")"""
)


def _documents_from_results(results_text):
    """Rebuild documents (page_content only) from a ``Results: [...]`` string."""
    import ast
    from langchain_core.documents import Document

    documents = []
    for literal in _UI_PAGE_CONTENT_RE.findall(results_text):
        try:
            text = ast.literal_eval(literal)
        except (ValueError, SyntaxError):
            # Best effort: skip a fragment that is not a valid string literal.
            continue
        documents.append(Document(page_content=text))
    return documents


def _stream_workflow(question, graph, run_config):
    """Run the compiled graph on ``question`` and return its streamed events."""
    return list(graph.stream({"messages": [HumanMessage(content=question)]}, run_config))


def _render_sidebar():
    """List the sample texts of both knowledge bases in the sidebar."""
    with st.sidebar:
        st.header("📂 Knowledge Bases")
        with st.expander("Research Database", expanded=True):
            for text in research_texts:
                st.markdown(f'<div class="data-box research-box">{text}</div>', unsafe_allow_html=True)
        with st.expander("Development Database"):
            for text in development_texts:
                st.markdown(f'<div class="data-box dev-box">{text}</div>', unsafe_allow_html=True)


def _run_analysis(query):
    """Execute the workflow for ``query`` and render documents and summary."""
    with st.status("Processing Workflow...", expanded=True) as status:
        try:
            start_time = time.time()
            # The original called an undefined process_question(); stream the
            # compiled graph directly instead.
            events = _stream_workflow(query, app, {"configurable": {"thread_id": "1"}})
            processed_data = []
            for event in events:
                if 'agent' in event:
                    content = event['agent']['messages'][0].content
                    # Test for the same token that is split on, so a bare
                    # "Results:" cannot cause an IndexError.
                    if "Results: " in content:
                        # SECURITY: this text used to go through eval(); it is
                        # model/retriever output, so it is parsed instead.
                        docs = _documents_from_results(content.split("Results: ", 1)[1])
                        unique_docs = DocumentProcessor.deduplicate_documents(docs)
                        key_points = DocumentProcessor.extract_key_points(unique_docs)
                        processed_data.append(key_points)
                        with st.expander("📄 Retrieved Documents", expanded=False):
                            st.info(f"Found {len(unique_docs)} unique documents")
                            st.write(docs)
                elif 'generate' in event:
                    final_answer = event['generate']['messages'][0].content
                    status.update(label="✅ Analysis Complete", state="complete")
                    st.markdown("## 📝 Research Summary")
                    st.markdown(final_answer)
                    st.caption(f"⏱️ Processed in {time.time()-start_time:.2f}s | {len(processed_data)} clusters")
        except Exception as e:
            # UI boundary: show the failure and append it to a local log file.
            status.update(label="❌ Processing Failed", state="error")
            st.error(f"**Error:** {str(e)}\n\nCheck API key and network connection")
            with open("error_log.txt", "a", encoding="utf-8") as f:
                f.write(f"{datetime.now()} | {str(e)}\n")


def main():
    """Render the Streamlit page: sidebar, query box, analysis, usage guide."""
    # Page configuration has to precede every other Streamlit call in the
    # script, so it now runs before the theme CSS is injected.
    st.set_page_config(
        page_title="AI Research Assistant Pro",
        layout="wide",
        initial_sidebar_state="expanded",
        menu_items={
            'Get Help': 'https://example.com/docs',
            'Report a bug': 'https://example.com/issues',
            'About': "v2.1 | Enhanced Research Assistant"
        }
    )
    UITheme.apply()
    _render_sidebar()
    st.title("🔬 AI Research Assistant Pro")
    st.markdown("---")
    query = st.text_area(
        "Research Query Input",
        height=120,
        placeholder="Enter your research question...",
        help="Be specific about domains for better results"
    )
    col1, col2 = st.columns([1, 2])
    with col1:
        if st.button("🚀 Analyze Documents", use_container_width=True):
            # No early return here: returning skipped the usage-guide column
            # whenever the warning was shown.
            if query and query.strip():
                _run_analysis(query)
            else:
                st.warning("⚠️ Please enter a research question")
    with col2:
        st.markdown("""
## 📘 Usage Guide
**1. Query Formulation**
- Specify domains (e.g., "quantum NLP")
- Include timeframes for recent advances
**2. Results Interpretation**
- Expand sections for source documents
- Key points show technical breakthroughs
- Summary includes commercial implications
**3. Advanced Features**
- Use keyboard shortcuts for efficiency
- Click documents for raw context
- Export via screenshot/PDF
""")
# Render the UI only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()