# workflow.py
import logging
import time
from datetime import datetime
from typing import Any, Dict, Sequence

from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from typing_extensions import Annotated, TypedDict

from processor import EnhancedCognitiveProcessor
from config import ResearchConfig

logger = logging.getLogger(__name__)


class AgentState(TypedDict):
    """State schema shared by every node in the workflow graph."""

    # `add_messages` makes each node's returned messages APPEND to the
    # conversation instead of replacing it.
    messages: Annotated[Sequence[AIMessage | HumanMessage | ToolMessage], add_messages]
    context: Dict[str, Any]
    metadata: Dict[str, Any]


class ResearchWorkflow:
    """
    A multi-step research workflow that leverages Retrieval-Augmented Generation (RAG).

    Supports domains including:
      - Biomedical Research
      - Legal Research
      - Environmental and Energy Studies
      - Competitive Programming and Theoretical Computer Science
      - Social Sciences

    The domain string is normalized (lower-cased) at ingestion and drives
    domain-specific prompts from ``ResearchConfig.DOMAIN_PROMPTS``, with a
    domain-neutral fallback for unknown domains.

    Graph shape::

        ingest -> retrieve -> analyze --(valid)--> validate -> enhance -> END
                      ^                 --(invalid)--> refine --+
                      +-------------------------------------------+
    """

    # Maximum refinement passes before the quality gate forces a "valid"
    # outcome (prevents an infinite analyze/refine loop).
    MAX_REFINEMENTS = 3

    def __init__(self) -> None:
        self.processor = EnhancedCognitiveProcessor()
        self.workflow = StateGraph(AgentState)  # Supply the state schema.
        self._build_workflow()
        self.app = self.workflow.compile()

    def _build_workflow(self) -> None:
        """Register all nodes and wire the edges of the state graph."""
        self.workflow.add_node("ingest", self.ingest_query)
        self.workflow.add_node("retrieve", self.retrieve_documents)
        self.workflow.add_node("analyze", self.analyze_content)
        self.workflow.add_node("validate", self.validate_output)
        self.workflow.add_node("refine", self.refine_results)
        # Extended node for multi-modal enhancement.
        self.workflow.add_node("enhance", self.enhance_analysis)

        self.workflow.set_entry_point("ingest")
        self.workflow.add_edge("ingest", "retrieve")
        self.workflow.add_edge("retrieve", "analyze")
        self.workflow.add_conditional_edges(
            "analyze",
            self._quality_check,
            {"valid": "validate", "invalid": "refine"},
        )
        # BUGFIX: the original wired "validate" to BOTH END and "enhance",
        # which makes LangGraph fan out from "validate" in parallel. The
        # single intended successor is "enhance", which then terminates.
        self.workflow.add_edge("validate", "enhance")
        self.workflow.add_edge("enhance", END)
        self.workflow.add_edge("refine", "retrieve")

    @staticmethod
    def _extract_content(response: Dict) -> str:
        """Pull the assistant text out of an OpenAI-style chat-completion payload.

        Returns "" if the payload is missing any of the expected keys.
        """
        return response.get("choices", [{}])[0].get("message", {}).get("content", "")

    def ingest_query(self, state: Dict) -> Dict:
        """Entry node: capture the user query and initialize the workflow context."""
        try:
            query = state["messages"][-1].content
            # Normalize domain string to lower-case; default to 'biomedical research'.
            domain = state.get("context", {}).get("domain", "Biomedical Research").strip().lower()
            new_context = {
                "raw_query": query,
                "domain": domain,
                "refine_count": 0,
                "refinement_history": [],
            }
            logger.info(f"Query ingested. Domain: {domain}")
            return {
                "messages": [AIMessage(content="Query ingested successfully")],
                "context": new_context,
                "metadata": {"timestamp": datetime.now().isoformat()},
            }
        except Exception as e:
            logger.exception("Error during query ingestion.")
            return self._error_state(f"Ingestion Error: {str(e)}")

    def retrieve_documents(self, state: Dict) -> Dict:
        """Fetch supporting documents for the query (currently a stub that returns none)."""
        try:
            # Raises KeyError (caught below) if ingestion never populated the context.
            query = state["context"]["raw_query"]
            # Simulate retrieval; an empty list indicates no external documents found.
            docs = []
            logger.info(f"Retrieved {len(docs)} documents for query.")
            # BUGFIX: carry the entire prior context forward. The original
            # rebuilt the dict key-by-key and silently dropped "raw_query",
            # which analyze_content relies on as its fallback prompt text
            # whenever retrieval comes back empty.
            new_context = {
                **state["context"],
                "documents": docs,
                "retrieval_time": time.time(),
            }
            return {
                "messages": [AIMessage(content=f"Retrieved {len(docs)} documents")],
                "context": new_context,
            }
        except Exception as e:
            logger.exception("Error during document retrieval.")
            return self._error_state(f"Retrieval Error: {str(e)}")

    def analyze_content(self, state: Dict) -> Dict:
        """Run the domain-tagged analysis prompt through the backend processor."""
        try:
            # Normalize domain and use it for prompt selection.
            domain = state["context"].get("domain", "biomedical research").strip().lower()
            docs = state["context"].get("documents", [])
            # Use retrieved documents if available; else fall back to the raw query.
            if docs:
                docs_text = "\n\n".join(d.page_content for d in docs)
            else:
                docs_text = state["context"].get("raw_query", "")
                logger.info("No documents retrieved; using dynamic synthesis (RAG mode).")
            # BUGFIX: the original fallback prompt was legal-research-specific
            # ("Consider relevant legal cases..."), which is wrong for every
            # other supported domain. Use a domain-neutral fallback instead.
            domain_prompt = ResearchConfig.DOMAIN_PROMPTS.get(
                domain,
                "Provide a rigorous, evidence-based analysis appropriate to the stated domain.",
            )
            # Build the final prompt with a domain tag for clarity.
            full_prompt = (
                f"Domain: {state['context'].get('domain', 'Biomedical Research')}\n"
                f"{domain_prompt}\n\n"
                + ResearchConfig.ANALYSIS_TEMPLATE.format(context=docs_text)
            )
            response = self.processor.process_query(full_prompt)
            if "error" in response:
                logger.error("Backend response error during analysis.")
                return self._error_state(response["error"])
            logger.info("Content analysis completed using RAG approach.")
            return {
                "messages": [AIMessage(content=self._extract_content(response))],
                "context": state["context"],
            }
        except Exception as e:
            logger.exception("Error during content analysis.")
            return self._error_state(f"Analysis Error: {str(e)}")

    def validate_output(self, state: Dict) -> Dict:
        """Ask the backend to grade the analysis and append its verdict."""
        try:
            analysis = state["messages"][-1].content
            # NOTE: criteria generalized from the original legal-only wording
            # so validation applies to every supported domain.
            validation_prompt = (
                f"Validate the following analysis for correctness, clarity, and evidential grounding:\n{analysis}\n\n"
                "Criteria:\n"
                "1. Technical accuracy\n"
                "2. Evidence and citation support\n"
                "3. Logical consistency\n"
                "4. Methodological soundness\n\n"
                "Respond with 'VALID: [justification]' or 'INVALID: [justification]'."
            )
            response = self.processor.process_query(validation_prompt)
            logger.info("Output validation completed.")
            verdict = self._extract_content(response)
            return {
                "messages": [AIMessage(content=analysis + f"\n\nValidation: {verdict}")],
                "context": state["context"],
            }
        except Exception as e:
            logger.exception("Error during output validation.")
            return self._error_state(f"Validation Error: {str(e)}")

    def refine_results(self, state: Dict) -> Dict:
        """Iteratively improve a rejected analysis; meta-synthesize after the last pass."""
        try:
            current_count = state["context"].get("refine_count", 0)
            state["context"]["refine_count"] = current_count + 1
            refinement_history = state["context"].setdefault("refinement_history", [])
            current_analysis = state["messages"][-1].content
            refinement_history.append(current_analysis)
            # Difficulty decreases as refinement passes accumulate.
            difficulty_level = max(0, self.MAX_REFINEMENTS - state["context"]["refine_count"])
            domain = state["context"].get("domain", "biomedical research")
            logger.info(
                f"Refinement iteration: {state['context']['refine_count']}, "
                f"Difficulty level: {difficulty_level}"
            )
            # NOTE: prompts generalized from the original legal-only wording
            # so refinement applies to every supported domain.
            if state["context"]["refine_count"] >= self.MAX_REFINEMENTS:
                # Final pass: synthesize the whole refinement history.
                meta_prompt = (
                    f"Domain: {domain}\n"
                    "You are given the following series of refinement outputs:\n"
                    + "\n---\n".join(refinement_history)
                    + "\n\nSynthesize these into a final, concise analysis report, "
                    "highlighting the key findings and the most credible supporting evidence. "
                    "Focus on improving accuracy and relevance for the stated domain."
                )
                meta_response = self.processor.process_query(meta_prompt)
                logger.info("Meta-refinement completed.")
                return {
                    "messages": [AIMessage(content=self._extract_content(meta_response))],
                    "context": state["context"],
                }
            refinement_prompt = (
                f"Domain: {domain}\n"
                f"Refine this analysis (current difficulty level: {difficulty_level}):\n{current_analysis}\n\n"
                "First, identify weaknesses such as unsupported claims or missing evidence. "
                "Then, improve the analysis with clear references to credible sources and established results."
            )
            response = self.processor.process_query(refinement_prompt)
            logger.info("Refinement completed.")
            return {
                "messages": [AIMessage(content=self._extract_content(response))],
                "context": state["context"],
            }
        except Exception as e:
            logger.exception("Error during refinement.")
            return self._error_state(f"Refinement Error: {str(e)}")

    def _quality_check(self, state: Dict) -> str:
        """Routing predicate for the analyze node: 'valid' or 'invalid'."""
        refine_count = state["context"].get("refine_count", 0)
        if refine_count >= self.MAX_REFINEMENTS:
            logger.warning("Refinement limit reached. Forcing valid outcome.")
            return "valid"
        content = state["messages"][-1].content
        # BUGFIX: "VALID" is a substring of "INVALID", so the original
        # `"valid" if "VALID" in content` classified every INVALID verdict
        # as valid. Check for the negative marker first.
        if "INVALID" in content:
            quality = "invalid"
        elif "VALID" in content:
            quality = "valid"
        else:
            quality = "invalid"
        logger.info(f"Quality check returned: {quality}")
        return quality

    def _error_state(self, message: str) -> Dict:
        """Build the uniform terminal error state returned by every node on failure."""
        logger.error(message)
        return {
            "messages": [AIMessage(content=f"❌ {message}")],
            "context": {"error": True},
            "metadata": {"status": "error"},
        }

    def enhance_analysis(self, state: Dict) -> Dict:
        """Append multi-modal sections (images, code) from the context, if present."""
        try:
            analysis = state["messages"][-1].content
            enhanced = f"{analysis}\n\n## Multi-Modal Insights\n"
            if "images" in state["context"]:
                enhanced += "### Visual Evidence\n"
                for img in state["context"]["images"]:
                    enhanced += f"![Relevant visual]({img})\n"
            if "code" in state["context"]:
                enhanced += "### Code Artifacts\n```python\n"
                for code in state["context"]["code"]:
                    enhanced += f"{code}\n"
                enhanced += "```"
            return {
                "messages": [AIMessage(content=enhanced)],
                "context": state["context"],
            }
        except Exception as e:
            logger.exception("Error during multi-modal enhancement.")
            return self._error_state(f"Enhancement Error: {str(e)}")