# Page residue from the hosting site ("Spaces: Sleeping"), commented out so the module parses.
# workflow.py | |
import logging
import re
import time
from datetime import datetime
from typing import Annotated, Any, Dict, List, Sequence, TypedDict

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages

from config import ResearchConfig
from processor import EnhancedCognitiveProcessor
# Module-scoped logger; records are tagged with this module's dotted name.
logger: logging.Logger = logging.getLogger(__name__)
class AgentState(TypedDict):
    """State schema shared by every node of ``ResearchWorkflow``.

    LangGraph creates one channel per key. Nodes return partial updates:
    ``messages`` is combined through the ``add_messages`` reducer, while
    ``context`` and ``metadata`` take whatever value a node last returned.
    """

    # Conversation so far; add_messages merges each node's new messages in.
    messages: Annotated[Sequence[BaseMessage], add_messages]
    # Per-run working data: raw_query, domain, documents, refine_count, ...
    context: Dict[str, Any]
    # Run-level metadata such as the ingestion timestamp or error status.
    metadata: Dict[str, Any]


class ResearchWorkflow:
    """
    Defines a multi-step research workflow using a state graph.

    Topology:
        ingest -> retrieve -> analyze, then ``_quality_check`` routes to
          * "valid"   -> validate -> enhance -> END
          * "invalid" -> refine -> retrieve (loops back)
          * "error"   -> END (an earlier node produced an error state)
    """

    # Refinement passes allowed before the quality gate forces "valid";
    # refine_results switches to meta-synthesis on the pass that reaches it.
    MAX_REFINEMENTS: int = 3
    # Whole-word match so that "INVALID" is not mistaken for "VALID".
    _VALID_MARKER = re.compile(r"\bVALID\b")

    def __init__(self) -> None:
        self.processor = EnhancedCognitiveProcessor()
        # StateGraph needs a state schema to know its channels and reducers.
        self.workflow = StateGraph(AgentState)
        self._build_workflow()
        self.app = self.workflow.compile()

    def _build_workflow(self) -> None:
        """Register every node and wire the edges described on the class."""
        graph = self.workflow
        graph.add_node("ingest", self.ingest_query)
        graph.add_node("retrieve", self.retrieve_documents)
        graph.add_node("analyze", self.analyze_content)
        graph.add_node("validate", self.validate_output)
        graph.add_node("refine", self.refine_results)
        # Extended node for multi-modal enhancement
        graph.add_node("enhance", self.enhance_analysis)

        graph.set_entry_point("ingest")
        graph.add_edge("ingest", "retrieve")
        graph.add_edge("retrieve", "analyze")
        graph.add_conditional_edges(
            "analyze",
            self._quality_check,
            {"valid": "validate", "invalid": "refine", "error": END},
        )
        # NOTE(review): neither retrieve_documents nor analyze_content reads
        # the message refine_results produces, so the next analysis is rebuilt
        # from the documents alone — confirm this loop is intended.
        graph.add_edge("refine", "retrieve")
        # validate gets exactly one successor: a second unconditional edge
        # (e.g. straight to END) would make LangGraph run both branches.
        graph.add_edge("validate", "enhance")
        graph.add_edge("enhance", END)

    def ingest_query(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """Capture the user's query and initialise the per-run context.

        The query is the content of the last message in ``state["messages"]``.
        Keys the caller placed in ``state["context"]`` (e.g. "images"/"code"
        consumed by ``enhance_analysis``) are kept; the refinement
        bookkeeping is reset.
        """
        try:
            query = state["messages"][-1].content
            incoming = state.get("context") or {}
            # Retrieve the domain from the state's context (defaulting to Biomedical Research)
            domain = incoming.get("domain", "Biomedical Research")
            new_context = {
                **incoming,
                "raw_query": query,
                "domain": domain,
                "refine_count": 0,
                "refinement_history": [],
            }
            logger.info("Query ingested. Domain: %s", domain)
            return {
                "messages": [AIMessage(content="Query ingested successfully")],
                "context": new_context,
                "metadata": {"timestamp": datetime.now().isoformat()},
            }
        except Exception as e:
            logger.exception("Error during query ingestion.")
            return self._error_state(f"Ingestion Error: {str(e)}")

    def retrieve_documents(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """Attach retrieved documents to the context (currently a stub: none).

        The existing context is merged, not rebuilt, so ``raw_query`` and any
        caller-supplied keys survive the refine -> retrieve loop.
        """
        if self._has_error(state):
            return {"context": state["context"]}
        try:
            context = state["context"]
            # Placeholder: hand `query` to a real retriever. Reading it here
            # also fails fast if ingestion never populated the context.
            query = context["raw_query"]
            # For demonstration, we use an empty document list.
            docs: List[Any] = []
            logger.info("Retrieved %d documents for query.", len(docs))
            return {
                "messages": [AIMessage(content=f"Retrieved {len(docs)} documents")],
                "context": {**context, "documents": docs, "retrieval_time": time.time()},
            }
        except Exception as e:
            logger.exception("Error during document retrieval.")
            return self._error_state(f"Retrieval Error: {str(e)}")

    def analyze_content(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """Produce an analysis: a canned fallback for configured domains,
        otherwise a backend call over the retrieved documents' text.
        """
        if self._has_error(state):
            # Keep the error message as the latest one; the gate routes to END.
            return {"context": state["context"]}
        try:
            context = state["context"]
            domain = context.get("domain", "Biomedical Research").strip().lower()
            fallback_analyses = ResearchConfig.DOMAIN_FALLBACKS
            if domain in fallback_analyses:
                logger.info("Using fallback analysis for domain: %s", context.get("domain"))
                return {
                    "messages": [AIMessage(content=fallback_analyses[domain].strip())],
                    "context": context,
                }

            docs = context.get("documents", [])
            docs_text = "\n\n".join(d.page_content for d in docs)
            domain_prompt = ResearchConfig.DOMAIN_PROMPTS.get(domain, "")
            full_prompt = f"{domain_prompt}\n\n" + ResearchConfig.ANALYSIS_TEMPLATE.format(context=docs_text)
            response = self.processor.process_query(full_prompt)
            if "error" in response:
                logger.error("Backend response error during analysis.")
                return self._error_state(response["error"])
            logger.info("Content analysis completed.")
            return {
                "messages": [AIMessage(content=self._extract_content(response))],
                "context": context,
            }
        except Exception as e:
            logger.exception("Error during content analysis.")
            return self._error_state(f"Analysis Error: {str(e)}")

    def validate_output(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """Ask the backend to critique the latest analysis and append its verdict.

        Returns a single message: the analysis followed by
        ``"Validation: <verdict>"``. A backend error is reported in the
        verdict instead of leaving it silently blank.
        """
        try:
            analysis = state["messages"][-1].content
            validation_prompt = (
                f"Validate the following research analysis:\n{analysis}\n\n"
                "Check for:\n"
                "1. Technical accuracy\n"
                "2. Citation support (are claims backed by evidence?)\n"
                "3. Logical consistency\n"
                "4. Methodological soundness\n\n"
                "Respond with 'VALID: [brief justification]' or 'INVALID: [brief justification]'."
            )
            response = self.processor.process_query(validation_prompt)
            if "error" in response:
                # Keep the analysis, but make the failed validation visible.
                logger.error("Backend response error during validation.")
                verdict = f"unavailable ({response['error']})"
            else:
                logger.info("Output validation completed.")
                verdict = self._extract_content(response)
            return {"messages": [AIMessage(content=f"{analysis}\n\nValidation: {verdict}")]}
        except Exception as e:
            logger.exception("Error during output validation.")
            return self._error_state(f"Validation Error: {str(e)}")

    def refine_results(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """Run one refinement pass over the latest analysis.

        Increments ``refine_count`` and appends the analysis to
        ``refinement_history``. On the pass that reaches ``MAX_REFINEMENTS``
        the backend is instead asked to synthesise the whole history.
        """
        try:
            # Work on copies so the state object LangGraph passed in is not
            # mutated in place; the new context is returned as the update.
            context = dict(state["context"])
            refine_count = context.get("refine_count", 0) + 1
            current_analysis = state["messages"][-1].content
            history = [*context.get("refinement_history", []), current_analysis]
            context["refine_count"] = refine_count
            context["refinement_history"] = history

            difficulty_level = max(0, self.MAX_REFINEMENTS - refine_count)
            logger.info("Refinement iteration: %d, Difficulty level: %d", refine_count, difficulty_level)

            if refine_count >= self.MAX_REFINEMENTS:
                prompt = (
                    "You are given the following series of refinement outputs:\n"
                    + "\n---\n".join(history)
                    + "\n\nSynthesize the above into a final, concise, and high-quality technical analysis report. "
                    "Focus on the key findings and improvements made across the iterations. Do not introduce new ideas; just synthesize the improvements. Ensure the report is well-structured and easy to understand."
                )
                done_message = "Meta-refinement completed."
            else:
                prompt = (
                    f"Refine this analysis (current difficulty level: {difficulty_level}):\n{current_analysis}\n\n"
                    "First, critically evaluate the analysis and identify its weaknesses, such as inaccuracies, unsupported claims, or lack of clarity. Summarize these weaknesses in a short paragraph.\n\n"
                    "Then, improve the following aspects:\n"
                    "1. Technical precision\n"
                    "2. Empirical grounding\n"
                    "3. Theoretical coherence\n\n"
                    "Use a structured difficulty gradient approach (similar to LADDER) to produce a simpler yet more accurate variant, addressing the weaknesses identified."
                )
                done_message = "Refinement completed."

            response = self.processor.process_query(prompt)
            if "error" in response:
                logger.error("Backend response error during refinement.")
                return self._error_state(response["error"])
            logger.info(done_message)
            return {
                "messages": [AIMessage(content=self._extract_content(response))],
                "context": context,
            }
        except Exception as e:
            logger.exception("Error during refinement.")
            return self._error_state(f"Refinement Error: {str(e)}")

    def _quality_check(self, state: Dict[str, Any]) -> str:
        """Route after ``analyze``: return "error", "valid" or "invalid".

        "error" when an earlier node flagged ``context["error"]``; "valid" once
        ``MAX_REFINEMENTS`` passes have run or the latest message contains the
        whole word VALID (so "INVALID" does not count); otherwise "invalid".
        """
        if self._has_error(state):
            logger.warning("Error state detected. Ending workflow.")
            return "error"
        refine_count = (state.get("context") or {}).get("refine_count", 0)
        if refine_count >= self.MAX_REFINEMENTS:
            logger.warning("Refinement limit reached. Forcing valid outcome.")
            return "valid"
        content = state["messages"][-1].content
        text = content if isinstance(content, str) else str(content)
        quality = "valid" if self._VALID_MARKER.search(text) else "invalid"
        logger.info("Quality check returned: %s", quality)
        return quality

    def _error_state(self, message: str) -> Dict[str, Any]:
        """Log *message* and build the state update used for any node failure.

        The context is replaced by ``{"error": True}``, which ``_has_error``
        detects downstream.
        """
        logger.error(message)
        return {
            # A real message object: downstream nodes read `.content`, and
            # add_messages rejects bare dicts without a role/type.
            "messages": [AIMessage(content=f"❌ {message}")],
            "context": {"error": True},
            "metadata": {"status": "error"},
        }

    @staticmethod
    def _has_error(state: Dict[str, Any]) -> bool:
        """Return True when an earlier node has set ``context["error"]``."""
        return bool((state.get("context") or {}).get("error"))

    @staticmethod
    def _extract_content(response: Dict[str, Any]) -> str:
        """Return the text at ``choices[0].message.content`` of a backend response.

        NOTE(review): the key layout suggests an OpenAI-style chat-completion
        payload from ``process_query`` — confirm. Missing keys, an empty
        ``choices`` list, or ``None`` content all yield ``""``.
        """
        choices = response.get("choices") or [{}]
        message = choices[0].get("message") or {}
        return message.get("content") or ""

    def enhance_analysis(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """Append a "Multi-Modal Insights" section built from the context's
        optional "images" and "code" entries to the latest message.
        """
        if self._has_error(state):
            return {"context": state["context"]}
        try:
            context = state["context"]
            analysis = state["messages"][-1].content
            parts = [f"{analysis}\n\n## Multi-Modal Insights\n"]
            if "images" in context:
                parts.append("### Visual Evidence\n")
                # Rendered as Markdown images; presumably entries are
                # URLs/paths — confirm with whoever populates "images".
                parts.extend(f"![Image]({img})\n" for img in context["images"])
            if "code" in context:
                parts.append("### Code Artifacts\n```python\n")
                parts.extend(f"{code}\n" for code in context["code"])
                parts.append("```")
            return {
                "messages": [AIMessage(content="".join(parts))],
                "context": context,
            }
        except Exception as e:
            logger.exception("Error during multi-modal enhancement.")
            return self._error_state(f"Enhancement Error: {str(e)}")