File size: 10,602 Bytes
f0840f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
# workflow.py

import logging
import time
from datetime import datetime
from typing import Annotated, Dict, TypedDict

from langchain_core.messages import AIMessage, HumanMessage
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages

from config import ResearchConfig
from processor import EnhancedCognitiveProcessor

logger = logging.getLogger(__name__)

class _ResearchState(TypedDict, total=False):
    # Shared graph state. The `add_messages` reducer appends each node's
    # output messages instead of overwriting the list; `context` and
    # `metadata` use the default overwrite behavior.
    messages: Annotated[list, add_messages]
    context: Dict
    metadata: Dict


class ResearchWorkflow:
    """
    Defines a multi-step research workflow using a state graph.

    Flow: ingest -> retrieve -> analyze, then a quality gate routes either to
    validate -> enhance -> END or to refine -> retrieve (a refinement loop
    capped at 3 iterations by `_quality_check`).
    """

    def __init__(self) -> None:
        self.processor = EnhancedCognitiveProcessor()
        # Bug fix: StateGraph requires a state schema; the original bare
        # `StateGraph()` call raises a TypeError on construction.
        self.workflow = StateGraph(_ResearchState)
        self._build_workflow()
        self.app = self.workflow.compile()

    def _build_workflow(self) -> None:
        """Register the graph nodes and wire their edges."""
        self.workflow.add_node("ingest", self.ingest_query)
        self.workflow.add_node("retrieve", self.retrieve_documents)
        self.workflow.add_node("analyze", self.analyze_content)
        self.workflow.add_node("validate", self.validate_output)
        self.workflow.add_node("refine", self.refine_results)
        # Extended node for multi-modal enhancement
        self.workflow.add_node("enhance", self.enhance_analysis)
        self.workflow.set_entry_point("ingest")
        self.workflow.add_edge("ingest", "retrieve")
        self.workflow.add_edge("retrieve", "analyze")
        self.workflow.add_conditional_edges(
            "analyze",
            self._quality_check,
            {"valid": "validate", "invalid": "refine"}
        )
        # Bug fix: the original also added an edge "validate" -> END, fanning
        # "validate" out to two successors (END and "enhance"). The intended
        # flow is validate -> enhance -> END, so the direct END edge is gone.
        self.workflow.add_edge("validate", "enhance")
        self.workflow.add_edge("enhance", END)
        self.workflow.add_edge("refine", "retrieve")

    def ingest_query(self, state: Dict) -> Dict:
        """Extract the user query and seed a fresh workflow context.

        Reads the latest message as the raw query and an optional
        ``context.domain`` (defaulting to "Biomedical Research").
        """
        try:
            query = state["messages"][-1].content
            # Retrieve the domain from the state's context (defaulting to Biomedical Research)
            domain = state.get("context", {}).get("domain", "Biomedical Research")
            new_context = {
                "raw_query": query,
                "domain": domain,
                "refine_count": 0,
                "refinement_history": [],
            }
            logger.info(f"Query ingested. Domain: {domain}")
            return {
                "messages": [AIMessage(content="Query ingested successfully")],
                "context": new_context,
                "metadata": {"timestamp": datetime.now().isoformat()}
            }
        except Exception as e:
            logger.exception("Error during query ingestion.")
            return self._error_state(f"Ingestion Error: {str(e)}")

    def retrieve_documents(self, state: Dict) -> Dict:
        """Fetch documents for the raw query (placeholder: returns none).

        For demonstration this uses an empty document list; replace `docs`
        with real retrieval logic as needed.
        """
        try:
            query = state["context"]["raw_query"]
            docs = []
            logger.info(f"Retrieved {len(docs)} documents for query.")
            return {
                "messages": [AIMessage(content=f"Retrieved {len(docs)} documents")],
                "context": {
                    # Bug fix: carry raw_query forward. The original dropped
                    # it when rebuilding the context, so the refine -> retrieve
                    # loop raised KeyError on its second pass.
                    "raw_query": query,
                    "documents": docs,
                    "retrieval_time": time.time(),
                    "refine_count": state["context"].get("refine_count", 0),
                    "refinement_history": state["context"].get("refinement_history", []),
                    "domain": state["context"].get("domain", "Biomedical Research")
                }
            }
        except Exception as e:
            logger.exception("Error during document retrieval.")
            return self._error_state(f"Retrieval Error: {str(e)}")

    def analyze_content(self, state: Dict) -> Dict:
        """Produce the analysis message for the current domain.

        Uses a canned fallback analysis when the (lower-cased) domain appears
        in ResearchConfig.DOMAIN_FALLBACKS; otherwise builds a prompt from the
        retrieved documents and queries the backend processor.
        NOTE(review): assumes DOMAIN_FALLBACKS / DOMAIN_PROMPTS are keyed by
        lower-case domain names — confirm against config.py.
        """
        try:
            domain = state["context"].get("domain", "Biomedical Research").strip().lower()
            fallback_analyses = ResearchConfig.DOMAIN_FALLBACKS
            if domain in fallback_analyses:
                logger.info(f"Using fallback analysis for domain: {state['context'].get('domain')}")
                return {
                    "messages": [AIMessage(content=fallback_analyses[domain].strip())],
                    "context": state["context"]
                }
            # Documents are expected to expose `.page_content` (LangChain style).
            docs = state["context"].get("documents", [])
            docs_text = "\n\n".join(d.page_content for d in docs)
            domain_prompt = ResearchConfig.DOMAIN_PROMPTS.get(domain, "")
            full_prompt = f"{domain_prompt}\n\n" + ResearchConfig.ANALYSIS_TEMPLATE.format(context=docs_text)
            response = self.processor.process_query(full_prompt)
            if "error" in response:
                logger.error("Backend response error during analysis.")
                return self._error_state(response["error"])
            logger.info("Content analysis completed.")
            return {
                "messages": [AIMessage(content=response.get('choices', [{}])[0].get('message', {}).get('content', ''))],
                "context": state["context"]
            }
        except Exception as e:
            logger.exception("Error during content analysis.")
            return self._error_state(f"Analysis Error: {str(e)}")

    def validate_output(self, state: Dict) -> Dict:
        """Ask the backend to validate the latest analysis.

        Appends a "Validation: ..." verdict to the analysis text. The context
        channel is passed through explicitly so downstream enhancement still
        sees any images/code the caller supplied.
        """
        try:
            analysis = state["messages"][-1].content
            validation_prompt = (
                f"Validate the following research analysis:\n{analysis}\n\n"
                "Check for:\n"
                "1. Technical accuracy\n"
                "2. Citation support (are claims backed by evidence?)\n"
                "3. Logical consistency\n"
                "4. Methodological soundness\n\n"
                "Respond with 'VALID: [brief justification]' or 'INVALID: [brief justification]'."
            )
            response = self.processor.process_query(validation_prompt)
            logger.info("Output validation completed.")
            return {
                "messages": [AIMessage(content=analysis + f"\n\nValidation: {response.get('choices', [{}])[0].get('message', {}).get('content', '')}")],
                "context": state.get("context", {})
            }
        except Exception as e:
            logger.exception("Error during output validation.")
            return self._error_state(f"Validation Error: {str(e)}")

    def refine_results(self, state: Dict) -> Dict:
        """Iteratively refine the analysis, synthesizing after 3 iterations.

        Each pass records the current analysis in `refinement_history` and
        bumps `refine_count`. On the third iteration the history is collapsed
        into a final meta-synthesis instead of another refinement.
        """
        try:
            current_count = state["context"].get("refine_count", 0)
            state["context"]["refine_count"] = current_count + 1
            refinement_history = state["context"].setdefault("refinement_history", [])
            current_analysis = state["messages"][-1].content
            refinement_history.append(current_analysis)
            # Difficulty decreases as refinement progresses (LADDER-style gradient).
            difficulty_level = max(0, 3 - state["context"]["refine_count"])
            logger.info(f"Refinement iteration: {state['context']['refine_count']}, Difficulty level: {difficulty_level}")

            if state["context"]["refine_count"] >= 3:
                meta_prompt = (
                    "You are given the following series of refinement outputs:\n" +
                    "\n---\n".join(refinement_history) +
                    "\n\nSynthesize the above into a final, concise, and high-quality technical analysis report. "
                    "Focus on the key findings and improvements made across the iterations. Do not introduce new ideas; just synthesize the improvements. Ensure the report is well-structured and easy to understand."
                )
                meta_response = self.processor.process_query(meta_prompt)
                logger.info("Meta-refinement completed.")
                return {
                    "messages": [AIMessage(content=meta_response.get('choices', [{}])[0].get('message', {}).get('content', ''))],
                    "context": state["context"]
                }
            refinement_prompt = (
                f"Refine this analysis (current difficulty level: {difficulty_level}):\n{current_analysis}\n\n"
                "First, critically evaluate the analysis and identify its weaknesses, such as inaccuracies, unsupported claims, or lack of clarity. Summarize these weaknesses in a short paragraph.\n\n"
                "Then, improve the following aspects:\n"
                "1. Technical precision\n"
                "2. Empirical grounding\n"
                "3. Theoretical coherence\n\n"
                "Use a structured difficulty gradient approach (similar to LADDER) to produce a simpler yet more accurate variant, addressing the weaknesses identified."
            )
            response = self.processor.process_query(refinement_prompt)
            logger.info("Refinement completed.")
            return {
                "messages": [AIMessage(content=response.get('choices', [{}])[0].get('message', {}).get('content', ''))],
                "context": state["context"]
            }
        except Exception as e:
            logger.exception("Error during refinement.")
            return self._error_state(f"Refinement Error: {str(e)}")

    def _quality_check(self, state: Dict) -> str:
        """Route analyze output: 'valid' -> validate, 'invalid' -> refine.

        Forces 'valid' once the refinement budget (3 iterations) is spent so
        the graph always terminates.
        """
        refine_count = state["context"].get("refine_count", 0)
        if refine_count >= 3:
            logger.warning("Refinement limit reached. Forcing valid outcome.")
            return "valid"
        content = state["messages"][-1].content
        # Bug fix: "VALID" is a substring of "INVALID", so the original
        # `"VALID" in content` test classified every INVALID verdict as valid
        # and the refine branch was unreachable. Check INVALID first.
        if "INVALID" in content:
            quality = "invalid"
        elif "VALID" in content:
            quality = "valid"
        else:
            quality = "invalid"
        logger.info(f"Quality check returned: {quality}")
        return quality

    def _error_state(self, message: str) -> Dict:
        """Build a terminal error payload and log the message as an error."""
        logger.error(message)
        return {
            # Consistency fix: every other node emits AIMessage; a bare dict
            # here is not reliably coerced by the add_messages reducer.
            "messages": [AIMessage(content=f"❌ {message}")],
            "context": {"error": True},
            "metadata": {"status": "error"}
        }

    def enhance_analysis(self, state: Dict) -> Dict:
        """Append multi-modal sections (images, code) to the final analysis.

        Only adds the "Visual Evidence" / "Code Artifacts" sections when the
        corresponding keys are present in the context.
        """
        try:
            analysis = state["messages"][-1].content
            enhanced = f"{analysis}\n\n## Multi-Modal Insights\n"
            if "images" in state["context"]:
                enhanced += "### Visual Evidence\n"
                for img in state["context"]["images"]:
                    enhanced += f"![Relevant visual]({img})\n"
            if "code" in state["context"]:
                enhanced += "### Code Artifacts\n```python\n"
                for code in state["context"]["code"]:
                    enhanced += f"{code}\n"
                enhanced += "```"
            return {
                # Consistency fix: use AIMessage like the other nodes rather
                # than a bare {"content": ...} dict.
                "messages": [AIMessage(content=enhanced)],
                "context": state["context"]
            }
        except Exception as e:
            logger.exception("Error during multi-modal enhancement.")
            return self._error_state(f"Enhancement Error: {str(e)}")