# workflow.py

import time
from datetime import datetime
from typing import Dict, Any, Sequence

from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from typing_extensions import TypedDict, Annotated

from processor import EnhancedCognitiveProcessor
from config import ResearchConfig

import logging
logger = logging.getLogger(__name__)

# Define the state schema
class AgentState(TypedDict):
    messages: Annotated[Sequence[AIMessage | HumanMessage | ToolMessage], add_messages]
    context: Dict[str, Any]
    metadata: Dict[str, Any]
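    # Only `messages` is annotated with a reducer (add_messages), so message
    # updates from nodes are appended. `context` and `metadata` have no reducer:
    # a node that returns a new dict for either key replaces the previous value
    # wholesale, so nodes must carry forward any keys they still need.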

class ResearchWorkflow:
    """
    A multi-step research workflow that leverages Retrieval-Augmented Generation (RAG).
    Supports domains including:
      - Biomedical Research
      - Legal Research
      - Environmental and Energy Studies
      - Competitive Programming and Theoretical Computer Science
      - Social Sciences
    This implementation normalizes the domain and uses domain-specific prompts and fallbacks.
    """
    def __init__(self) -> None:
        self.processor = EnhancedCognitiveProcessor()
        self.workflow = StateGraph(AgentState)  # Supply state schema
        self._build_workflow()
        self.app = self.workflow.compile()

    def _build_workflow(self) -> None:
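        # Graph topology built below:
        #   ingest -> retrieve -> analyze --(valid)--> validate -> enhance -> END
        #                                 --(invalid)-> refine -> retrieve (refinement loop)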
        self.workflow.add_node("ingest", self.ingest_query)
        self.workflow.add_node("retrieve", self.retrieve_documents)
        self.workflow.add_node("analyze", self.analyze_content)
        self.workflow.add_node("validate", self.validate_output)
        self.workflow.add_node("refine", self.refine_results)
        self.workflow.set_entry_point("ingest")
        self.workflow.add_edge("ingest", "retrieve")
        self.workflow.add_edge("retrieve", "analyze")
        self.workflow.add_conditional_edges(
            "analyze",
            self._quality_check,
            {"valid": "validate", "invalid": "refine"}
        )
        self.workflow.add_edge("refine", "retrieve")
        # Extended node for multi-modal enhancement; "validate" flows through it to END.
        self.workflow.add_node("enhance", self.enhance_analysis)
        self.workflow.add_edge("validate", "enhance")
        self.workflow.add_edge("enhance", END)

    def ingest_query(self, state: Dict) -> Dict:
        try:
            query = state["messages"][-1].content
            # Normalize domain string to lower-case; default to 'biomedical research'
            domain = state.get("context", {}).get("domain", "Biomedical Research").strip().lower()
            new_context = {
                "raw_query": query,
                "domain": domain,
                "refine_count": 0,
                "refinement_history": []
            }
            logger.info(f"Query ingested. Domain: {domain}")
            return {
                "messages": [AIMessage(content="Query ingested successfully")],
                "context": new_context,
                "metadata": {"timestamp": datetime.now().isoformat()}
            }
        except Exception as e:
            logger.exception("Error during query ingestion.")
            return self._error_state(f"Ingestion Error: {str(e)}")

    def retrieve_documents(self, state: Dict) -> Dict:
        try:
            query = state["context"]["raw_query"]
            # Simulate retrieval; for now, an empty list indicates no external documents found.
            docs = []
            logger.info(f"Retrieved {len(docs)} documents for query.")
            return {
                "messages": [AIMessage(content=f"Retrieved {len(docs)} documents")],
                "context": {
                    "documents": docs,
                    "retrieval_time": time.time(),
                    "refine_count": state["context"].get("refine_count", 0),
                    "refinement_history": state["context"].get("refinement_history", []),
                    "domain": state["context"].get("domain", "biomedical research")
                }
            }
        except Exception as e:
            logger.exception("Error during document retrieval.")
            return self._error_state(f"Retrieval Error: {str(e)}")

    def analyze_content(self, state: Dict) -> Dict:
        try:
            # Normalize domain and use it for prompt generation
            domain = state["context"].get("domain", "biomedical research").strip().lower()
            docs = state["context"].get("documents", [])
            # Use retrieved documents if available; else, use raw query as fallback.
            if docs:
                docs_text = "\n\n".join([d.page_content for d in docs])
            else:
                docs_text = state["context"].get("raw_query", "")
                logger.info("No documents retrieved; using dynamic synthesis (RAG mode).")
            # Get the domain-specific prompt; fall back to a domain-neutral instruction if none is configured.
            domain_prompt = ResearchConfig.DOMAIN_PROMPTS.get(
                domain,
                "Analyze the material using methods and sources appropriate to the stated domain."
            )
            # Build the final prompt with a domain tag for clarity.
            full_prompt = f"Domain: {domain}\n" \
                          f"{domain_prompt}\n\n" + \
                          ResearchConfig.ANALYSIS_TEMPLATE.format(context=docs_text)
            response = self.processor.process_query(full_prompt)
            if "error" in response:
                logger.error("Backend response error during analysis.")
                return self._error_state(response["error"])
            logger.info("Content analysis completed using RAG approach.")
            return {
                "messages": [AIMessage(content=response.get('choices', [{}])[0].get('message', {}).get('content', ''))],
                "context": state["context"]
            }
        except Exception as e:
            logger.exception("Error during content analysis.")
            return self._error_state(f"Analysis Error: {str(e)}")

    def validate_output(self, state: Dict) -> Dict:
        try:
            analysis = state["messages"][-1].content
            validation_prompt = (
                f"Validate the following analysis for correctness, clarity, and evidentiary grounding:\n{analysis}\n\n"
                "Criteria:\n"
                "1. Technical accuracy\n"
                "2. Evidence and citation support\n"
                "3. Logical consistency\n"
                "4. Methodological soundness\n\n"
                "Respond with 'VALID: [justification]' or 'INVALID: [justification]'."
            )
            response = self.processor.process_query(validation_prompt)
            logger.info("Output validation completed.")
            return {
                "messages": [AIMessage(content=analysis + f"\n\nValidation: {response.get('choices', [{}])[0].get('message', {}).get('content', '')}")],
                "context": state["context"]
            }
        except Exception as e:
            logger.exception("Error during output validation.")
            return self._error_state(f"Validation Error: {str(e)}")

    def refine_results(self, state: Dict) -> Dict:
        try:
            current_count = state["context"].get("refine_count", 0)
            state["context"]["refine_count"] = current_count + 1
            refinement_history = state["context"].setdefault("refinement_history", [])
            current_analysis = state["messages"][-1].content
            refinement_history.append(current_analysis)
            difficulty_level = max(0, 3 - state["context"]["refine_count"])
            domain = state["context"].get("domain", "biomedical research")
            logger.info(f"Refinement iteration: {state['context']['refine_count']}, Difficulty level: {difficulty_level}")
            if state["context"]["refine_count"] >= 3:
                meta_prompt = (
                    f"Domain: {domain}\n"
                    "You are given the following series of refinement outputs:\n" +
                    "\n---\n".join(refinement_history) +
                    "\n\nSynthesize these into a final, concise analysis report, highlighting the key findings and supporting sources. "
                    f"Focus on improving accuracy and relevance for {domain} research."
                )
                meta_response = self.processor.process_query(meta_prompt)
                logger.info("Meta-refinement completed.")
                return {
                    "messages": [AIMessage(content=meta_response.get('choices', [{}])[0].get('message', {}).get('content', ''))],
                    "context": state["context"]
                }
            else:
                refinement_prompt = (
                    f"Domain: {domain}\n"
                    f"Refine this analysis (current difficulty level: {difficulty_level}):\n{current_analysis}\n\n"
                    "First, identify weaknesses such as unsupported claims, missing evidence, or misinterpreted sources. "
                    "Then, improve the analysis with clear references to the relevant literature and primary sources."
                )
                response = self.processor.process_query(refinement_prompt)
                logger.info("Refinement completed.")
                return {
                    "messages": [AIMessage(content=response.get('choices', [{}])[0].get('message', {}).get('content', ''))],
                    "context": state["context"]
                }
        except Exception as e:
            logger.exception("Error during refinement.")
            return self._error_state(f"Refinement Error: {str(e)}")

    def _quality_check(self, state: Dict) -> str:
        refine_count = state["context"].get("refine_count", 0)
        if refine_count >= 3:
            logger.warning("Refinement limit reached. Forcing valid outcome.")
            return "valid"
        content = state["messages"][-1].content
        # Check "INVALID" first: the substring "VALID" also appears inside "INVALID".
        quality = "invalid" if "INVALID" in content else ("valid" if "VALID" in content else "invalid")
        logger.info(f"Quality check returned: {quality}")
        return quality

    def _error_state(self, message: str) -> Dict:
        logger.error(message)
        return {
            "messages": [AIMessage(content=f"❌ {message}")],
            "context": {"error": True},
            "metadata": {"status": "error"}
        }

    def enhance_analysis(self, state: Dict) -> Dict:
        try:
            analysis = state["messages"][-1].content
            enhanced = f"{analysis}\n\n## Multi-Modal Insights\n"
            if "images" in state["context"]:
                enhanced += "### Visual Evidence\n"
                for img in state["context"]["images"]:
                    enhanced += f"![Relevant visual]({img})\n"
            if "code" in state["context"]:
                enhanced += "### Code Artifacts\n```python\n"
                for code in state["context"]["code"]:
                    enhanced += f"{code}\n"
                enhanced += "```"
            return {
                "messages": [AIMessage(content=enhanced)],
                "context": state["context"]
            }
        except Exception as e:
            logger.exception("Error during multi-modal enhancement.")
            return self._error_state(f"Enhancement Error: {str(e)}")