mgbam committed
Commit 1ac0e39 · verified · 1 Parent(s): a1bc85b

Update workflow.py

Files changed (1)
  1. workflow.py +40 -26
workflow.py CHANGED
@@ -23,18 +23,13 @@ class AgentState(TypedDict):
 
 class ResearchWorkflow:
     """
-    A multi-step research workflow that leverages Retrieval-Augmented Generation (RAG).
-    Supports domains including:
-    - Biomedical Research
-    - Legal Research
-    - Environmental and Energy Studies
-    - Competitive Programming and Theoretical Computer Science
-    - Social Sciences
-    This implementation normalizes the domain and uses domain-specific prompts and fallbacks.
+    A multi-step research workflow employing Retrieval-Augmented Generation (RAG) with an additional verification step.
+    This workflow supports multiple domains (e.g., Biomedical, Legal, Environmental, Competitive Programming, Social Sciences)
+    and integrates domain-specific prompts, iterative refinement, and a final verification to reduce hallucinations.
     """
     def __init__(self) -> None:
         self.processor = EnhancedCognitiveProcessor()
-        self.workflow = StateGraph(AgentState)  # Supply state schema
+        self.workflow = StateGraph(AgentState)
         self._build_workflow()
         self.app = self.workflow.compile()
 
@@ -44,6 +39,8 @@ class ResearchWorkflow:
         self.workflow.add_node("analyze", self.analyze_content)
         self.workflow.add_node("validate", self.validate_output)
         self.workflow.add_node("refine", self.refine_results)
+        # New verify node to further cross-check the output
+        self.workflow.add_node("verify", self.verify_output)
         self.workflow.set_entry_point("ingest")
         self.workflow.add_edge("ingest", "retrieve")
         self.workflow.add_edge("retrieve", "analyze")
@@ -52,17 +49,17 @@ class ResearchWorkflow:
             self._quality_check,
             {"valid": "validate", "invalid": "refine"}
         )
-        self.workflow.add_edge("validate", END)
+        self.workflow.add_edge("validate", "verify")
         self.workflow.add_edge("refine", "retrieve")
         # Extended node for multi-modal enhancement
         self.workflow.add_node("enhance", self.enhance_analysis)
-        self.workflow.add_edge("validate", "enhance")
+        self.workflow.add_edge("verify", "enhance")
         self.workflow.add_edge("enhance", END)
 
     def ingest_query(self, state: Dict) -> Dict:
         try:
             query = state["messages"][-1].content
-            # Normalize domain string to lower-case; default to 'biomedical research'
+            # Normalize the domain string; default to 'biomedical research'
             domain = state.get("context", {}).get("domain", "Biomedical Research").strip().lower()
             new_context = {
                 "raw_query": query,
@@ -83,7 +80,7 @@ class ResearchWorkflow:
     def retrieve_documents(self, state: Dict) -> Dict:
         try:
             query = state["context"]["raw_query"]
-            # Simulate retrieval; for now, an empty list indicates no external documents found.
+            # Placeholder retrieval: currently returns an empty list (simulate no documents)
             docs = []
             logger.info(f"Retrieved {len(docs)} documents for query.")
             return {
@@ -102,18 +99,16 @@ class ResearchWorkflow:
 
     def analyze_content(self, state: Dict) -> Dict:
         try:
-            # Normalize domain and use it for prompt generation
             domain = state["context"].get("domain", "biomedical research").strip().lower()
             docs = state["context"].get("documents", [])
-            # Use retrieved documents if available; else, use raw query as fallback.
             if docs:
                 docs_text = "\n\n".join([d.page_content for d in docs])
             else:
                 docs_text = state["context"].get("raw_query", "")
-                logger.info("No documents retrieved; using dynamic synthesis (RAG mode).")
-            # Get domain-specific prompt; ensure fallback prompts exist for all supported domains.
-            domain_prompt = ResearchConfig.DOMAIN_PROMPTS.get(domain, "Consider relevant legal cases and statutory interpretations.")
-            # Build the final prompt with domain tag for clarity.
+                logger.info("No documents retrieved; switching to dynamic synthesis (RAG mode).")
+            # Use domain-specific prompt; for legal research, inject legal-specific guidance.
+            domain_prompt = ResearchConfig.DOMAIN_PROMPTS.get(domain,
+                "Provide an analysis based on the provided context.")
             full_prompt = f"Domain: {state['context'].get('domain', 'Biomedical Research')}\n" \
                           f"{domain_prompt}\n\n" + \
                           ResearchConfig.ANALYSIS_TEMPLATE.format(context=docs_text)
@@ -134,10 +129,11 @@ class ResearchWorkflow:
         try:
             analysis = state["messages"][-1].content
             validation_prompt = (
-                f"Validate the following analysis for correctness, clarity, and legal grounding:\n{analysis}\n\n"
+                f"Validate the following analysis for accuracy and domain-specific relevance:\n{analysis}\n\n"
                 "Criteria:\n"
-                "1. Technical and legal accuracy\n"
-                "2. Evidence and citation support\n"
+                "1. Factual and technical accuracy\n"
+                "2. For legal research: inclusion of relevant precedents and statutory interpretations; "
+                "for other domains: appropriate domain insights\n"
                 "3. Logical consistency\n"
                 "4. Methodological soundness\n\n"
                 "Respond with 'VALID: [justification]' or 'INVALID: [justification]'."
@@ -152,6 +148,26 @@ class ResearchWorkflow:
             logger.exception("Error during output validation.")
             return self._error_state(f"Validation Error: {str(e)}")
 
+    def verify_output(self, state: Dict) -> Dict:
+        try:
+            # New verify step: cross-check the analysis using an external fact-checking prompt.
+            analysis = state["messages"][-1].content
+            verification_prompt = (
+                f"Verify the following analysis by comparing it with established external legal databases and reference texts:\n{analysis}\n\n"
+                "Identify any discrepancies or hallucinations and provide a brief correction if necessary."
+            )
+            response = self.processor.process_query(verification_prompt)
+            logger.info("Output verification completed.")
+            # Here, you can merge the verification feedback with the analysis.
+            verified_analysis = analysis + "\n\nVerification Feedback: " + response.get('choices', [{}])[0].get('message', {}).get('content', '')
+            return {
+                "messages": [AIMessage(content=verified_analysis)],
+                "context": state["context"]
+            }
+        except Exception as e:
+            logger.exception("Error during output verification.")
+            return self._error_state(f"Verification Error: {str(e)}")
+
     def refine_results(self, state: Dict) -> Dict:
         try:
             current_count = state["context"].get("refine_count", 0)
@@ -167,8 +183,7 @@ class ResearchWorkflow:
                 f"Domain: {domain}\n"
                 "You are given the following series of refinement outputs:\n" +
                 "\n---\n".join(refinement_history) +
-                "\n\nSynthesize these into a final, concise legal analysis report, highlighting key precedents and statutory interpretations. "
-                "Focus on improving accuracy and relevance for legal research."
+                "\n\nSynthesize these into a final, concise analysis report with improved accuracy and verifiable details."
             )
             meta_response = self.processor.process_query(meta_prompt)
             logger.info("Meta-refinement completed.")
@@ -180,8 +195,7 @@ class ResearchWorkflow:
             refinement_prompt = (
                 f"Domain: {domain}\n"
                 f"Refine this analysis (current difficulty level: {difficulty_level}):\n{current_analysis}\n\n"
-                "First, identify weaknesses such as lack of legal grounding or misinterpretation of cases. "
-                "Then, improve the analysis with clear references to legal precedents and statutory language."
+                "Identify and correct any weaknesses or hallucinations in the analysis, providing verifiable details."
             )
             response = self.processor.process_query(refinement_prompt)
             logger.info("Refinement completed.")
 