# workflow.py

import time
from datetime import datetime
from typing import Dict, Any, Sequence

from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from typing_extensions import TypedDict, Annotated

from processor import EnhancedCognitiveProcessor
from config import ResearchConfig

import logging
logger = logging.getLogger(__name__)

# Define the state schema
class AgentState(TypedDict):
    messages: Annotated[Sequence[AIMessage | HumanMessage | ToolMessage], add_messages]
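    # Note: add_messages appends each node's returned messages to the running list.
    # context and metadata have no reducer, so a node's return value replaces them
    # wholesale; nodes must carry forward any context keys they want to keep.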
    context: Dict[str, Any]
    metadata: Dict[str, Any]

class ResearchWorkflow:
    """
    A multi-step research workflow employing Retrieval-Augmented Generation (RAG) with an additional verification step.
    This workflow supports multiple domains (e.g., Biomedical, Legal, Environmental, Competitive Programming, Social Sciences)
    and integrates domain-specific prompts, iterative refinement, and a final verification to reduce hallucinations.
    """
    def __init__(self) -> None:
        self.processor = EnhancedCognitiveProcessor()
        self.workflow = StateGraph(AgentState)
        self._build_workflow()
        self.app = self.workflow.compile()

    def _build_workflow(self) -> None:
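        # Graph shape (for reference):
        #   ingest -> retrieve -> analyze --valid--> validate -> verify -> enhance -> END
        #                                \-invalid-> refine -> retrieve (loop)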
        self.workflow.add_node("ingest", self.ingest_query)
        self.workflow.add_node("retrieve", self.retrieve_documents)
        self.workflow.add_node("analyze", self.analyze_content)
        self.workflow.add_node("validate", self.validate_output)
        self.workflow.add_node("refine", self.refine_results)
        # New verify node to further cross-check the output
        self.workflow.add_node("verify", self.verify_output)
        self.workflow.set_entry_point("ingest")
        self.workflow.add_edge("ingest", "retrieve")
        self.workflow.add_edge("retrieve", "analyze")
        self.workflow.add_conditional_edges(
            "analyze",
            self._quality_check,
            {"valid": "validate", "invalid": "refine"}
        )
        self.workflow.add_edge("validate", "verify")
        self.workflow.add_edge("refine", "retrieve")
        # Extended node for multi-modal enhancement
        self.workflow.add_node("enhance", self.enhance_analysis)
        self.workflow.add_edge("verify", "enhance")
        self.workflow.add_edge("enhance", END)

    def ingest_query(self, state: Dict) -> Dict:
        try:
            query = state["messages"][-1].content
            # Normalize the domain string; default to 'biomedical research'
            domain = state.get("context", {}).get("domain", "Biomedical Research").strip().lower()
            new_context = {
                "raw_query": query,
                "domain": domain,
                "refine_count": 0,
                "refinement_history": []
            }
            logger.info(f"Query ingested. Domain: {domain}")
            return {
                "messages": [AIMessage(content="Query ingested successfully")],
                "context": new_context,
                "metadata": {"timestamp": datetime.now().isoformat()}
            }
        except Exception as e:
            logger.exception("Error during query ingestion.")
            return self._error_state(f"Ingestion Error: {str(e)}")

    def retrieve_documents(self, state: Dict) -> Dict:
        try:
            query = state["context"]["raw_query"]
            # Placeholder retrieval: currently returns an empty list (simulate no documents)
            docs = []
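            # Hedged sketch of what real retrieval might look like with a
            # LangChain-style vector store (names are illustrative, not part
            # of this project):
            #   retriever = vector_store.as_retriever(search_kwargs={"k": 5})
            #   docs = retriever.invoke(query)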
            logger.info(f"Retrieved {len(docs)} documents for query.")
            return {
                "messages": [AIMessage(content=f"Retrieved {len(docs)} documents")],
                "context": {
                    "documents": docs,
                    "retrieval_time": time.time(),
                    "refine_count": state["context"].get("refine_count", 0),
                    "refinement_history": state["context"].get("refinement_history", []),
                    "domain": state["context"].get("domain", "biomedical research")
                }
            }
        except Exception as e:
            logger.exception("Error during document retrieval.")
            return self._error_state(f"Retrieval Error: {str(e)}")

    def analyze_content(self, state: Dict) -> Dict:
        try:
            domain = state["context"].get("domain", "biomedical research").strip().lower()
            docs = state["context"].get("documents", [])
            if docs:
                docs_text = "\n\n".join([d.page_content for d in docs])
            else:
                docs_text = state["context"].get("raw_query", "")
                logger.info("No documents retrieved; switching to dynamic synthesis (RAG mode).")
            # Use the domain-specific prompt if one is configured; otherwise fall
            # back to a generic analysis instruction.
            domain_prompt = ResearchConfig.DOMAIN_PROMPTS.get(
                domain, "Provide an analysis based on the provided context."
            )
            full_prompt = (
                f"Domain: {domain}\n"
                f"{domain_prompt}\n\n"
                + ResearchConfig.ANALYSIS_TEMPLATE.format(context=docs_text)
            )
            response = self.processor.process_query(full_prompt)
            if "error" in response:
                logger.error("Backend response error during analysis.")
                return self._error_state(response["error"])
            logger.info("Content analysis completed using RAG approach.")
            return {
                "messages": [AIMessage(content=response.get('choices', [{}])[0].get('message', {}).get('content', ''))],
                "context": state["context"]
            }
        except Exception as e:
            logger.exception("Error during content analysis.")
            return self._error_state(f"Analysis Error: {str(e)}")

    def validate_output(self, state: Dict) -> Dict:
        try:
            analysis = state["messages"][-1].content
            validation_prompt = (
                f"Validate the following analysis for accuracy and domain-specific relevance:\n{analysis}\n\n"
                "Criteria:\n"
                "1. Factual and technical accuracy\n"
                "2. For legal research: inclusion of relevant precedents and statutory interpretations; "
                "for other domains: appropriate domain insights\n"
                "3. Logical consistency\n"
                "4. Methodological soundness\n\n"
                "Respond with 'VALID: [justification]' or 'INVALID: [justification]'."
            )
            response = self.processor.process_query(validation_prompt)
            logger.info("Output validation completed.")
            return {
                "messages": [AIMessage(content=analysis + f"\n\nValidation: {response.get('choices', [{}])[0].get('message', {}).get('content', '')}")],
                "context": state["context"]
            }
        except Exception as e:
            logger.exception("Error during output validation.")
            return self._error_state(f"Validation Error: {str(e)}")

    def verify_output(self, state: Dict) -> Dict:
        try:
            # Verify step: cross-check the analysis with a fact-checking prompt,
            # scoped to the query's domain rather than any single field.
            analysis = state["messages"][-1].content
            domain = state["context"].get("domain", "biomedical research")
            verification_prompt = (
                f"Verify the following {domain} analysis against established external databases and reference texts:\n{analysis}\n\n"
                "Identify any discrepancies or hallucinations and provide a brief correction if necessary."
            )
            response = self.processor.process_query(verification_prompt)
            logger.info("Output verification completed.")
            # Append the verification feedback to the analysis so downstream nodes see both.
            verified_analysis = (
                analysis
                + "\n\nVerification Feedback: "
                + response.get('choices', [{}])[0].get('message', {}).get('content', '')
            )
            return {
                "messages": [AIMessage(content=verified_analysis)],
                "context": state["context"]
            }
        except Exception as e:
            logger.exception("Error during output verification.")
            return self._error_state(f"Verification Error: {str(e)}")

    def refine_results(self, state: Dict) -> Dict:
        try:
            current_count = state["context"].get("refine_count", 0)
            state["context"]["refine_count"] = current_count + 1
            refinement_history = state["context"].setdefault("refinement_history", [])
            current_analysis = state["messages"][-1].content
            refinement_history.append(current_analysis)
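            # Difficulty decreases with each pass (3, 2, 1, then 0) and is surfaced
            # in the refinement prompt below.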
            difficulty_level = max(0, 3 - state["context"]["refine_count"])
            domain = state["context"].get("domain", "biomedical research")
            logger.info(f"Refinement iteration: {state['context']['refine_count']}, Difficulty level: {difficulty_level}")
            if state["context"]["refine_count"] >= 3:
                meta_prompt = (
                    f"Domain: {domain}\n"
                    "You are given the following series of refinement outputs:\n" +
                    "\n---\n".join(refinement_history) +
                    "\n\nSynthesize these into a final, concise analysis report with improved accuracy and verifiable details."
                )
                meta_response = self.processor.process_query(meta_prompt)
                logger.info("Meta-refinement completed.")
                return {
                    "messages": [AIMessage(content=meta_response.get('choices', [{}])[0].get('message', {}).get('content', ''))],
                    "context": state["context"]
                }
            else:
                refinement_prompt = (
                    f"Domain: {domain}\n"
                    f"Refine this analysis (current difficulty level: {difficulty_level}):\n{current_analysis}\n\n"
                    "Identify and correct any weaknesses or hallucinations in the analysis, providing verifiable details."
                )
                response = self.processor.process_query(refinement_prompt)
                logger.info("Refinement completed.")
                return {
                    "messages": [AIMessage(content=response.get('choices', [{}])[0].get('message', {}).get('content', ''))],
                    "context": state["context"]
                }
        except Exception as e:
            logger.exception("Error during refinement.")
            return self._error_state(f"Refinement Error: {str(e)}")

    def _quality_check(self, state: Dict) -> str:
        refine_count = state["context"].get("refine_count", 0)
        if refine_count >= 3:
            logger.warning("Refinement limit reached. Forcing valid outcome.")
            return "valid"
        content = state["messages"][-1].content.upper()
        # Check for "INVALID" first, since "INVALID" contains "VALID" as a substring.
        if "INVALID" in content:
            quality = "invalid"
        else:
            quality = "valid" if "VALID" in content else "invalid"
        logger.info(f"Quality check returned: {quality}")
        return quality

    def _error_state(self, message: str) -> Dict:
        logger.error(message)
        return {
            "messages": [AIMessage(content=f"❌ {message}")],
            "context": {"error": True},
            "metadata": {"status": "error"}
        }

    def enhance_analysis(self, state: Dict) -> Dict:
        try:
            analysis = state["messages"][-1].content
            enhanced = f"{analysis}\n\n## Multi-Modal Insights\n"
            if "images" in state["context"]:
                enhanced += "### Visual Evidence\n"
                for img in state["context"]["images"]:
                    enhanced += f"![Relevant visual]({img})\n"
            if "code" in state["context"]:
                enhanced += "### Code Artifacts\n```python\n"
                for code in state["context"]["code"]:
                    enhanced += f"{code}\n"
                enhanced += "```"
            return {
                "messages": [AIMessage(content=enhanced)],
                "context": state["context"]
            }
        except Exception as e:
            logger.exception("Error during multi-modal enhancement.")
            return self._error_state(f"Enhancement Error: {str(e)}")
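
# Minimal usage sketch (illustrative, not part of the original module): builds the
# workflow and runs a single query through the compiled LangGraph app. It assumes
# the processor and config modules are importable and that the backend behind
# EnhancedCognitiveProcessor is reachable; the query and domain are only examples.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    wf = ResearchWorkflow()
    initial_state: AgentState = {
        "messages": [HumanMessage(content="Summarize recent advances in CRISPR delivery methods.")],
        "context": {"domain": "biomedical research"},
        "metadata": {},
    }
    result = wf.app.invoke(initial_state)
    print(result["messages"][-1].content)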