File size: 11,741 Bytes
f0840f2
 
 
 
904690b
f0840f2
904690b
f0840f2
 
904690b
f0840f2
 
 
 
904690b
 
c234528
904690b
c234528
904690b
 
 
f0840f2
 
 
3c1cff1
 
 
 
 
 
 
 
f0840f2
 
 
3c1cff1
904690b
f0840f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3c1cff1
f0840f2
 
 
 
 
 
 
3c1cff1
f0840f2
c2e09d4
 
 
 
 
 
f0840f2
 
 
 
 
 
 
 
 
 
 
 
 
3c1cff1
 
 
f0840f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c2e09d4
3c1cff1
c2e09d4
f0840f2
c2e09d4
 
3c1cff1
c2e09d4
3c1cff1
 
 
 
c2e09d4
 
 
 
 
 
 
 
 
f0840f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
904690b
 
f0840f2
 
 
 
 
 
 
 
 
 
 
 
 
3589b0d
f0840f2
 
 
3589b0d
f0840f2
 
 
 
 
 
 
 
 
 
 
 
 
3589b0d
f0840f2
 
 
 
 
 
3c1cff1
f0840f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
904690b
f0840f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
904690b
f0840f2
 
 
 
 
e610129
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
# workflow.py

import time
from datetime import datetime
from typing import Dict, Any, Sequence

from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from typing_extensions import TypedDict, Annotated

from processor import EnhancedCognitiveProcessor
from config import ResearchConfig

import logging
logger = logging.getLogger(__name__)

# Define the state schema
class AgentState(TypedDict):
    """State schema handed to ``StateGraph`` and passed between workflow nodes."""
    # Conversation messages; annotated with LangGraph's ``add_messages`` so the
    # graph uses that function to combine a node's returned messages with the
    # existing ones.
    messages: Annotated[Sequence[AIMessage | HumanMessage | ToolMessage], add_messages]
    # Working data shared by the nodes (e.g. "raw_query", "domain", "documents",
    # "refine_count", "refinement_history", or "error" after a failure).
    # No reducer is annotated here — presumably a node's returned value replaces
    # the previous one wholesale; confirm against the LangGraph state docs.
    context: Dict[str, Any]
    # Run bookkeeping written by nodes (e.g. "timestamp", "status").
    metadata: Dict[str, Any]

class ResearchWorkflow:
    """
    A multi-step research workflow that leverages a Retrieval-Augmented Generation (RAG) strategy.
    It dynamically retrieves external data and integrates it with the raw query to generate domain-specific analyses.
    Supported domains include:
      - Biomedical Research
      - Legal Research
      - Environmental and Energy Studies
      - Competitive Programming and Theoretical Computer Science
      - Social Sciences

    Graph layout::

        ingest -> retrieve -> analyze -> _quality_check
            "valid"   -> validate -> enhance -> END
            "invalid" -> refine -> retrieve   (at most MAX_REFINEMENTS rounds)
            "error"   -> END

    Every node returns a partial state update. ``context`` is always returned as
    a full copy of the previous context plus the node's changes, so keys written
    by earlier nodes (``raw_query``, caller-supplied ``images`` / ``code`` ...)
    are not lost when the graph stores the update.
    """

    # Number of refinement rounds after which the quality gate is forced open
    # and refinement switches to a final synthesis prompt.
    MAX_REFINEMENTS = 3
    # Domain assumed when the caller does not supply one in the context.
    DEFAULT_DOMAIN = "Biomedical Research"

    def __init__(self) -> None:
        """Create the LLM processor, build the graph and compile it into ``self.app``."""
        self.processor = EnhancedCognitiveProcessor()
        # Provide the state schema to the StateGraph constructor.
        self.workflow = StateGraph(AgentState)
        self._build_workflow()
        self.app = self.workflow.compile()

    def _build_workflow(self) -> None:
        """Register all nodes first, then wire the edges between them."""
        self.workflow.add_node("ingest", self.ingest_query)
        self.workflow.add_node("retrieve", self.retrieve_documents)
        self.workflow.add_node("analyze", self.analyze_content)
        self.workflow.add_node("validate", self.validate_output)
        self.workflow.add_node("refine", self.refine_results)
        # Extended node for multi-modal enhancement.
        self.workflow.add_node("enhance", self.enhance_analysis)

        self.workflow.set_entry_point("ingest")
        self.workflow.add_edge("ingest", "retrieve")
        self.workflow.add_edge("retrieve", "analyze")
        self.workflow.add_conditional_edges(
            "analyze",
            self._quality_check,
            # "error" ends the run instead of feeding a failure into the
            # refinement loop, where the lost refine counter would never let
            # the loop terminate.
            {"valid": "validate", "invalid": "refine", "error": END},
        )
        self.workflow.add_edge("refine", "retrieve")
        # validate has a single successor; enhance is the only node that ends
        # the successful path.
        self.workflow.add_edge("validate", "enhance")
        self.workflow.add_edge("enhance", END)

    @staticmethod
    def _has_error(state: Dict) -> bool:
        """Return True when an earlier node stored an error marker in the context."""
        return bool((state.get("context") or {}).get("error"))

    @staticmethod
    def _extract_content(response: Dict) -> str:
        """
        Pull ``choices[0].message.content`` out of a processor response.

        Returns an empty string when any level is missing, empty or ``None``
        rather than raising.
        """
        choices = response.get("choices") or [{}]
        message = choices[0].get("message") or {}
        return message.get("content") or ""

    def ingest_query(self, state: Dict) -> Dict:
        """
        Seed the context from the latest message.

        Reads the query text from ``state["messages"][-1]`` and the optional
        ``domain`` from the incoming context. Any other keys the caller placed
        in the context are carried over unchanged.

        Returns a state update, or an error state if the query cannot be read.
        """
        try:
            query = state["messages"][-1].content
            incoming = state.get("context") or {}
            # Get the domain from state; fall back to the default if not provided.
            domain = incoming.get("domain", self.DEFAULT_DOMAIN)
            new_context = {
                **incoming,
                "raw_query": query,
                "domain": domain,
                "refine_count": 0,
                "refinement_history": [],
            }
            logger.info("Query ingested. Domain: %s", domain)
            return {
                "messages": [AIMessage(content="Query ingested successfully")],
                "context": new_context,
                "metadata": {"timestamp": datetime.now().isoformat()},
            }
        except Exception as e:
            logger.exception("Error during query ingestion.")
            return self._error_state(f"Ingestion Error: {str(e)}")

    def retrieve_documents(self, state: Dict) -> Dict:
        """
        Attach retrieved documents to the context.

        The retriever is a placeholder that always yields an empty list.
        An upstream error is passed through untouched so its message stays the
        last one in the conversation.
        """
        if self._has_error(state):
            return {"context": state["context"]}
        try:
            context = state["context"]
            # Required input: raises KeyError (-> error state) if ingestion never
            # ran. Unused until a real retriever is plugged in.
            query = context["raw_query"]
            # For demonstration, we use an empty list to simulate retrieval failure.
            # In a full implementation, integrate a retriever (e.g., via LangChain, LlamaIndex, or a vector DB).
            docs = []
            logger.info("Retrieved %d documents for query.", len(docs))
            return {
                "messages": [AIMessage(content=f"Retrieved {len(docs)} documents")],
                "context": {
                    # Keep everything already in the context, notably raw_query.
                    **context,
                    "documents": docs,
                    "retrieval_time": time.time(),
                    "refine_count": context.get("refine_count", 0),
                    "refinement_history": context.get("refinement_history", []),
                    "domain": context.get("domain", self.DEFAULT_DOMAIN),
                },
            }
        except Exception as e:
            logger.exception("Error during document retrieval.")
            return self._error_state(f"Retrieval Error: {str(e)}")

    def analyze_content(self, state: Dict) -> Dict:
        """
        Ask the processor for a domain-specific analysis.

        Uses the text of the retrieved documents when there are any, otherwise
        the raw query. Returns an error state when the processor response
        contains an ``"error"`` key or when anything raises.
        """
        if self._has_error(state):
            # Do not spend a model call on a run that has already failed.
            return {"context": state["context"]}
        try:
            context = state["context"]
            domain_label = context.get("domain", self.DEFAULT_DOMAIN)
            # DOMAIN_PROMPTS is looked up with the normalised (trimmed, lower-case) domain.
            domain_key = domain_label.strip().lower()
            docs = context.get("documents", [])
            # If documents are present, use their content; otherwise, fall back to the raw query.
            if docs:
                docs_text = "\n\n".join(d.page_content for d in docs)
            else:
                docs_text = context.get("raw_query", "")
                logger.info("No documents retrieved; switching to dynamic synthesis using RAG.")
            domain_prompt = ResearchConfig.DOMAIN_PROMPTS.get(domain_key, "")
            # Combine the domain prompt with either retrieved text or raw query.
            full_prompt = (
                f"Domain: {domain_label}\n"
                f"{domain_prompt}\n\n"
                + ResearchConfig.ANALYSIS_TEMPLATE.format(context=docs_text)
            )
            response = self.processor.process_query(full_prompt)
            if "error" in response:
                logger.error("Backend response error during analysis.")
                return self._error_state(response["error"])
            logger.info("Content analysis completed using RAG approach.")
            return {
                "messages": [AIMessage(content=self._extract_content(response))],
                "context": context,
            }
        except Exception as e:
            logger.exception("Error during content analysis.")
            return self._error_state(f"Analysis Error: {str(e)}")

    def validate_output(self, state: Dict) -> Dict:
        """
        Have the processor review the latest analysis.

        The returned message is the analysis followed by a ``Validation:``
        section containing the processor's verdict.
        """
        try:
            analysis = state["messages"][-1].content
            validation_prompt = (
                f"Validate the following research analysis:\n{analysis}\n\n"
                "Check for:\n"
                "1. Technical accuracy\n"
                "2. Citation support (are claims backed by evidence?)\n"
                "3. Logical consistency\n"
                "4. Methodological soundness\n\n"
                "Respond with 'VALID: [brief justification]' or 'INVALID: [brief justification]'."
            )
            response = self.processor.process_query(validation_prompt)
            verdict = self._extract_content(response)
            logger.info("Output validation completed.")
            return {
                "messages": [AIMessage(content=analysis + f"\n\nValidation: {verdict}")],
                "context": state["context"],
            }
        except Exception as e:
            logger.exception("Error during output validation.")
            return self._error_state(f"Validation Error: {str(e)}")

    def refine_results(self, state: Dict) -> Dict:
        """
        Run one refinement round on the latest analysis.

        Increments ``refine_count`` and appends the current analysis to
        ``refinement_history``. Once ``MAX_REFINEMENTS`` rounds have been
        reached the processor is asked to synthesise the whole history;
        before that it is asked to critique and improve the current analysis.
        The incoming state is not modified; an updated context copy is returned.
        """
        try:
            context = state["context"]
            current_analysis = state["messages"][-1].content
            refine_count = context.get("refine_count", 0) + 1
            refinement_history = [*context.get("refinement_history", []), current_analysis]
            new_context = {
                **context,
                "refine_count": refine_count,
                "refinement_history": refinement_history,
            }
            # Counts down 2, 1, 0 as rounds are used up.
            difficulty_level = max(0, self.MAX_REFINEMENTS - refine_count)
            domain = context.get("domain", self.DEFAULT_DOMAIN)
            logger.info(
                "Refinement iteration: %s, Difficulty level: %s",
                refine_count,
                difficulty_level,
            )
            if refine_count >= self.MAX_REFINEMENTS:
                prompt = (
                    f"Domain: {domain}\n"
                    "You are given the following series of refinement outputs:\n" +
                    "\n---\n".join(refinement_history) +
                    "\n\nSynthesize the above into a final, concise, and high-quality technical analysis report. "
                    "Focus on the key findings and improvements made across the iterations. Do not introduce new ideas; just synthesize the improvements. Ensure the report is well-structured and easy to understand."
                )
                done_message = "Meta-refinement completed."
            else:
                prompt = (
                    f"Domain: {domain}\n"
                    f"Refine this analysis (current difficulty level: {difficulty_level}):\n{current_analysis}\n\n"
                    "First, critically evaluate the analysis and identify its weaknesses, such as inaccuracies, unsupported claims, or lack of clarity. Summarize these weaknesses in a short paragraph.\n\n"
                    "Then, improve the following aspects:\n"
                    "1. Technical precision\n"
                    "2. Empirical grounding\n"
                    "3. Theoretical coherence\n\n"
                    "Use a structured difficulty gradient approach to produce a simpler yet more accurate variant, addressing the identified weaknesses."
                )
                done_message = "Refinement completed."
            response = self.processor.process_query(prompt)
            logger.info(done_message)
            return {
                "messages": [AIMessage(content=self._extract_content(response))],
                "context": new_context,
            }
        except Exception as e:
            logger.exception("Error during refinement.")
            return self._error_state(f"Refinement Error: {str(e)}")

    def _quality_check(self, state: Dict) -> str:
        """
        Route the run after analysis.

        Returns:
            ``"error"``   if an earlier node flagged an error in the context;
            ``"valid"``   if the refinement limit is reached, or the latest
                          message contains ``VALID`` and not ``INVALID``;
            ``"invalid"`` otherwise.
        """
        context = state["context"]
        if context.get("error"):
            logger.error("Error state detected. Ending workflow.")
            return "error"
        if context.get("refine_count", 0) >= self.MAX_REFINEMENTS:
            logger.warning("Refinement limit reached. Forcing valid outcome.")
            return "valid"
        content = state["messages"][-1].content
        # "INVALID" contains "VALID", so the negative verdict must be ruled out explicitly.
        passed = "VALID" in content and "INVALID" not in content
        quality = "valid" if passed else "invalid"
        logger.info("Quality check returned: %s", quality)
        return quality

    def _error_state(self, message: str) -> Dict:
        """Log ``message`` and return a state update that marks the run as failed."""
        logger.error(message)
        return {
            "messages": [AIMessage(content=f"❌ {message}")],
            "context": {"error": True},
            "metadata": {"status": "error"}
        }

    def enhance_analysis(self, state: Dict) -> Dict:
        """
        Append a "Multi-Modal Insights" section to the latest message.

        Image entries under ``context["images"]`` are rendered as Markdown
        image links and entries under ``context["code"]`` as one fenced Python
        block. An upstream error is passed through untouched.
        """
        if self._has_error(state):
            return {"context": state["context"]}
        try:
            context = state["context"]
            analysis = state["messages"][-1].content
            parts = [f"{analysis}\n\n## Multi-Modal Insights\n"]
            if "images" in context:
                parts.append("### Visual Evidence\n")
                parts.extend(f"![Relevant visual]({img})\n" for img in context["images"])
            if "code" in context:
                parts.append("### Code Artifacts\n```python\n")
                parts.extend(f"{code}\n" for code in context["code"])
                parts.append("```")
            return {
                "messages": [AIMessage(content="".join(parts))],
                "context": context,
            }
        except Exception as e:
            logger.exception("Error during multi-modal enhancement.")
            return self._error_state(f"Enhancement Error: {str(e)}")