# NeuroResearch_AI / workflow.py
# (source: mgbam — "Update workflow.py", commit 1ac0e39, verified)
# workflow.py
import time
from datetime import datetime
from typing import Dict, Any, Sequence
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from typing_extensions import TypedDict, Annotated
from processor import EnhancedCognitiveProcessor
from config import ResearchConfig
import logging
logger = logging.getLogger(__name__)
# Define the state schema
class AgentState(TypedDict):
    """Shared LangGraph state passed between all workflow nodes."""

    # Conversation history; the add_messages reducer appends node output
    # to the channel instead of replacing it.
    messages: Annotated[Sequence[AIMessage | HumanMessage | ToolMessage], add_messages]
    # Mutable workflow data (raw_query, domain, documents, refine_count,
    # refinement_history, ...). No reducer: each node's return value
    # replaces the previous context wholesale.
    context: Dict[str, Any]
    # Run-level bookkeeping (timestamps, error/status flags).
    metadata: Dict[str, Any]
class ResearchWorkflow:
    """
    A multi-step research workflow employing Retrieval-Augmented Generation (RAG)
    with an additional verification step.

    Graph topology::

        ingest -> retrieve -> analyze --(valid)--> validate -> verify -> enhance -> END
                       ^                 |
                       +---- refine <--(invalid)

    This workflow supports multiple domains (e.g., Biomedical, Legal, Environmental,
    Competitive Programming, Social Sciences) and integrates domain-specific prompts,
    iterative refinement, and a final verification step to reduce hallucinations.
    """

    # Number of refinement iterations allowed before the quality gate forces
    # the "valid" branch and refine_results switches to meta-synthesis.
    MAX_REFINEMENTS = 3

    def __init__(self) -> None:
        """Build and compile the LangGraph state machine."""
        self.processor = EnhancedCognitiveProcessor()
        self.workflow = StateGraph(AgentState)
        self._build_workflow()
        self.app = self.workflow.compile()

    def _build_workflow(self) -> None:
        """Register all nodes and wire the edges of the research graph."""
        self.workflow.add_node("ingest", self.ingest_query)
        self.workflow.add_node("retrieve", self.retrieve_documents)
        self.workflow.add_node("analyze", self.analyze_content)
        self.workflow.add_node("validate", self.validate_output)
        self.workflow.add_node("refine", self.refine_results)
        # New verify node to further cross-check the output
        self.workflow.add_node("verify", self.verify_output)
        # Extended node for multi-modal enhancement
        self.workflow.add_node("enhance", self.enhance_analysis)
        self.workflow.set_entry_point("ingest")
        self.workflow.add_edge("ingest", "retrieve")
        self.workflow.add_edge("retrieve", "analyze")
        # Route to validation on success, otherwise loop through refinement.
        self.workflow.add_conditional_edges(
            "analyze",
            self._quality_check,
            {"valid": "validate", "invalid": "refine"}
        )
        self.workflow.add_edge("validate", "verify")
        self.workflow.add_edge("refine", "retrieve")
        self.workflow.add_edge("verify", "enhance")
        self.workflow.add_edge("enhance", END)

    def ingest_query(self, state: Dict) -> Dict:
        """Normalize the incoming query and domain, and seed the workflow context."""
        try:
            query = state["messages"][-1].content
            # Normalize the domain string; default to 'biomedical research'
            domain = state.get("context", {}).get("domain", "Biomedical Research").strip().lower()
            new_context = {
                "raw_query": query,
                "domain": domain,
                "refine_count": 0,
                "refinement_history": []
            }
            logger.info(f"Query ingested. Domain: {domain}")
            return {
                "messages": [AIMessage(content="Query ingested successfully")],
                "context": new_context,
                "metadata": {"timestamp": datetime.now().isoformat()}
            }
        except Exception as e:
            logger.exception("Error during query ingestion.")
            return self._error_state(f"Ingestion Error: {str(e)}")

    def retrieve_documents(self, state: Dict) -> Dict:
        """Fetch supporting documents for the query (placeholder: returns none)."""
        try:
            query = state["context"]["raw_query"]
            # Placeholder retrieval: currently returns an empty list (simulate no documents)
            docs = []
            logger.info(f"Retrieved {len(docs)} documents for query.")
            return {
                "messages": [AIMessage(content=f"Retrieved {len(docs)} documents")],
                "context": {
                    # Bug fix: propagate raw_query. The context channel has no
                    # reducer, so this dict replaces the previous one; without
                    # raw_query here, analyze_content's no-document fallback
                    # produced an empty prompt after every refine -> retrieve loop.
                    "raw_query": query,
                    "documents": docs,
                    "retrieval_time": time.time(),
                    "refine_count": state["context"].get("refine_count", 0),
                    "refinement_history": state["context"].get("refinement_history", []),
                    "domain": state["context"].get("domain", "biomedical research")
                }
            }
        except Exception as e:
            logger.exception("Error during document retrieval.")
            return self._error_state(f"Retrieval Error: {str(e)}")

    def analyze_content(self, state: Dict) -> Dict:
        """Run the domain-specific RAG analysis over retrieved docs (or the raw query)."""
        try:
            domain = state["context"].get("domain", "biomedical research").strip().lower()
            docs = state["context"].get("documents", [])
            if docs:
                docs_text = "\n\n".join([d.page_content for d in docs])
            else:
                # No documents: synthesize directly from the user's query text.
                docs_text = state["context"].get("raw_query", "")
                logger.info("No documents retrieved; switching to dynamic synthesis (RAG mode).")
            # Use domain-specific prompt; for legal research, inject legal-specific guidance.
            domain_prompt = ResearchConfig.DOMAIN_PROMPTS.get(domain,
                "Provide an analysis based on the provided context.")
            full_prompt = f"Domain: {state['context'].get('domain', 'Biomedical Research')}\n" \
                          f"{domain_prompt}\n\n" + \
                          ResearchConfig.ANALYSIS_TEMPLATE.format(context=docs_text)
            response = self.processor.process_query(full_prompt)
            if "error" in response:
                logger.error("Backend response error during analysis.")
                return self._error_state(response["error"])
            logger.info("Content analysis completed using RAG approach.")
            return {
                "messages": [AIMessage(content=response.get('choices', [{}])[0].get('message', {}).get('content', ''))],
                "context": state["context"]
            }
        except Exception as e:
            logger.exception("Error during content analysis.")
            return self._error_state(f"Analysis Error: {str(e)}")

    def validate_output(self, state: Dict) -> Dict:
        """Ask the backend to validate the analysis; append its verdict to the message."""
        try:
            analysis = state["messages"][-1].content
            validation_prompt = (
                f"Validate the following analysis for accuracy and domain-specific relevance:\n{analysis}\n\n"
                "Criteria:\n"
                "1. Factual and technical accuracy\n"
                "2. For legal research: inclusion of relevant precedents and statutory interpretations; "
                "for other domains: appropriate domain insights\n"
                "3. Logical consistency\n"
                "4. Methodological soundness\n\n"
                "Respond with 'VALID: [justification]' or 'INVALID: [justification]'."
            )
            response = self.processor.process_query(validation_prompt)
            logger.info("Output validation completed.")
            return {
                "messages": [AIMessage(content=analysis + f"\n\nValidation: {response.get('choices', [{}])[0].get('message', {}).get('content', '')}")],
                "context": state["context"]
            }
        except Exception as e:
            logger.exception("Error during output validation.")
            return self._error_state(f"Validation Error: {str(e)}")

    def verify_output(self, state: Dict) -> Dict:
        """Cross-check the validated analysis against external references (fact-check pass)."""
        try:
            # New verify step: cross-check the analysis using an external fact-checking prompt.
            analysis = state["messages"][-1].content
            verification_prompt = (
                f"Verify the following analysis by comparing it with established external legal databases and reference texts:\n{analysis}\n\n"
                "Identify any discrepancies or hallucinations and provide a brief correction if necessary."
            )
            response = self.processor.process_query(verification_prompt)
            logger.info("Output verification completed.")
            # Here, you can merge the verification feedback with the analysis.
            verified_analysis = analysis + "\n\nVerification Feedback: " + response.get('choices', [{}])[0].get('message', {}).get('content', '')
            return {
                "messages": [AIMessage(content=verified_analysis)],
                "context": state["context"]
            }
        except Exception as e:
            logger.exception("Error during output verification.")
            return self._error_state(f"Verification Error: {str(e)}")

    def refine_results(self, state: Dict) -> Dict:
        """Iteratively refine a rejected analysis; after MAX_REFINEMENTS, meta-synthesize."""
        try:
            current_count = state["context"].get("refine_count", 0)
            state["context"]["refine_count"] = current_count + 1
            refinement_history = state["context"].setdefault("refinement_history", [])
            current_analysis = state["messages"][-1].content
            refinement_history.append(current_analysis)
            # Difficulty decreases as refinements accumulate (floor at 0).
            difficulty_level = max(0, self.MAX_REFINEMENTS - state["context"]["refine_count"])
            domain = state["context"].get("domain", "biomedical research")
            logger.info(f"Refinement iteration: {state['context']['refine_count']}, Difficulty level: {difficulty_level}")
            if state["context"]["refine_count"] >= self.MAX_REFINEMENTS:
                # Limit reached: collapse the whole refinement history into one report.
                meta_prompt = (
                    f"Domain: {domain}\n"
                    "You are given the following series of refinement outputs:\n" +
                    "\n---\n".join(refinement_history) +
                    "\n\nSynthesize these into a final, concise analysis report with improved accuracy and verifiable details."
                )
                meta_response = self.processor.process_query(meta_prompt)
                logger.info("Meta-refinement completed.")
                return {
                    "messages": [AIMessage(content=meta_response.get('choices', [{}])[0].get('message', {}).get('content', ''))],
                    "context": state["context"]
                }
            else:
                refinement_prompt = (
                    f"Domain: {domain}\n"
                    f"Refine this analysis (current difficulty level: {difficulty_level}):\n{current_analysis}\n\n"
                    "Identify and correct any weaknesses or hallucinations in the analysis, providing verifiable details."
                )
                response = self.processor.process_query(refinement_prompt)
                logger.info("Refinement completed.")
                return {
                    "messages": [AIMessage(content=response.get('choices', [{}])[0].get('message', {}).get('content', ''))],
                    "context": state["context"]
                }
        except Exception as e:
            logger.exception("Error during refinement.")
            return self._error_state(f"Refinement Error: {str(e)}")

    def _quality_check(self, state: Dict) -> str:
        """Conditional-edge router for the analyze node: returns 'valid' or 'invalid'."""
        refine_count = state["context"].get("refine_count", 0)
        if refine_count >= self.MAX_REFINEMENTS:
            logger.warning("Refinement limit reached. Forcing valid outcome.")
            return "valid"
        content = state["messages"][-1].content
        # Bug fix: 'INVALID' contains the substring 'VALID', so the negative
        # verdict must be checked first; previously every 'INVALID: ...'
        # response was routed down the "valid" branch.
        if "INVALID" in content:
            quality = "invalid"
        elif "VALID" in content:
            quality = "valid"
        else:
            quality = "invalid"
        logger.info(f"Quality check returned: {quality}")
        return quality

    def _error_state(self, message: str) -> Dict:
        """Log an error and return a terminal error state payload."""
        logger.error(message)
        return {
            "messages": [AIMessage(content=f"❌ {message}")],
            "context": {"error": True},
            "metadata": {"status": "error"}
        }

    def enhance_analysis(self, state: Dict) -> Dict:
        """Append multi-modal sections (images, code artifacts) to the final analysis."""
        try:
            analysis = state["messages"][-1].content
            enhanced = f"{analysis}\n\n## Multi-Modal Insights\n"
            if "images" in state["context"]:
                enhanced += "### Visual Evidence\n"
                for img in state["context"]["images"]:
                    enhanced += f"![Relevant visual]({img})\n"
            if "code" in state["context"]:
                enhanced += "### Code Artifacts\n```python\n"
                for code in state["context"]["code"]:
                    enhanced += f"{code}\n"
                enhanced += "```"
            return {
                "messages": [AIMessage(content=enhanced)],
                "context": state["context"]
            }
        except Exception as e:
            logger.exception("Error during multi-modal enhancement.")
            return self._error_state(f"Enhancement Error: {str(e)}")