mgbam committed on
Commit
a1bc85b
·
verified ·
1 Parent(s): e610129

Update workflow.py

Browse files
Files changed (1) hide show
  1. workflow.py +27 -32
workflow.py CHANGED
@@ -23,19 +23,18 @@ class AgentState(TypedDict):
23
 
24
  class ResearchWorkflow:
25
  """
26
- A multi-step research workflow that leverages a Retrieval-Augmented Generation (RAG) strategy.
27
- It dynamically retrieves external data and integrates it with the raw query to generate domain-specific analyses.
28
- Supported domains include:
29
  - Biomedical Research
30
  - Legal Research
31
  - Environmental and Energy Studies
32
  - Competitive Programming and Theoretical Computer Science
33
  - Social Sciences
 
34
  """
35
  def __init__(self) -> None:
36
  self.processor = EnhancedCognitiveProcessor()
37
- # Provide the state schema to the StateGraph constructor.
38
- self.workflow = StateGraph(AgentState)
39
  self._build_workflow()
40
  self.app = self.workflow.compile()
41
 
@@ -55,7 +54,7 @@ class ResearchWorkflow:
55
  )
56
  self.workflow.add_edge("validate", END)
57
  self.workflow.add_edge("refine", "retrieve")
58
- # Extended node for multi-modal enhancement.
59
  self.workflow.add_node("enhance", self.enhance_analysis)
60
  self.workflow.add_edge("validate", "enhance")
61
  self.workflow.add_edge("enhance", END)
@@ -63,8 +62,8 @@ class ResearchWorkflow:
63
  def ingest_query(self, state: Dict) -> Dict:
64
  try:
65
  query = state["messages"][-1].content
66
- # Get the domain from state; default to Biomedical Research if not provided.
67
- domain = state.get("context", {}).get("domain", "Biomedical Research")
68
  new_context = {
69
  "raw_query": query,
70
  "domain": domain,
@@ -84,9 +83,8 @@ class ResearchWorkflow:
84
  def retrieve_documents(self, state: Dict) -> Dict:
85
  try:
86
  query = state["context"]["raw_query"]
87
- # For demonstration, we use an empty list to simulate retrieval failure.
88
- # In a full implementation, integrate a retriever (e.g., via LangChain, LlamaIndex, or a vector DB).
89
- docs = []
90
  logger.info(f"Retrieved {len(docs)} documents for query.")
91
  return {
92
  "messages": [AIMessage(content=f"Retrieved {len(docs)} documents")],
@@ -95,7 +93,7 @@ class ResearchWorkflow:
95
  "retrieval_time": time.time(),
96
  "refine_count": state["context"].get("refine_count", 0),
97
  "refinement_history": state["context"].get("refinement_history", []),
98
- "domain": state["context"].get("domain", "Biomedical Research")
99
  }
100
  }
101
  except Exception as e:
@@ -104,16 +102,18 @@ class ResearchWorkflow:
104
 
105
  def analyze_content(self, state: Dict) -> Dict:
106
  try:
107
- domain = state["context"].get("domain", "Biomedical Research").strip().lower()
 
108
  docs = state["context"].get("documents", [])
109
- # If documents are present, use their content; otherwise, fall back to the raw query.
110
  if docs:
111
  docs_text = "\n\n".join([d.page_content for d in docs])
112
  else:
113
  docs_text = state["context"].get("raw_query", "")
114
- logger.info("No documents retrieved; switching to dynamic synthesis using RAG.")
115
- domain_prompt = ResearchConfig.DOMAIN_PROMPTS.get(domain, "")
116
- # Combine the domain prompt with either retrieved text or raw query.
 
117
  full_prompt = f"Domain: {state['context'].get('domain', 'Biomedical Research')}\n" \
118
  f"{domain_prompt}\n\n" + \
119
  ResearchConfig.ANALYSIS_TEMPLATE.format(context=docs_text)
@@ -134,13 +134,13 @@ class ResearchWorkflow:
134
  try:
135
  analysis = state["messages"][-1].content
136
  validation_prompt = (
137
- f"Validate the following research analysis:\n{analysis}\n\n"
138
- "Check for:\n"
139
- "1. Technical accuracy\n"
140
- "2. Citation support (are claims backed by evidence?)\n"
141
  "3. Logical consistency\n"
142
  "4. Methodological soundness\n\n"
143
- "Respond with 'VALID: [brief justification]' or 'INVALID: [brief justification]'."
144
  )
145
  response = self.processor.process_query(validation_prompt)
146
  logger.info("Output validation completed.")
@@ -160,15 +160,15 @@ class ResearchWorkflow:
160
  current_analysis = state["messages"][-1].content
161
  refinement_history.append(current_analysis)
162
  difficulty_level = max(0, 3 - state["context"]["refine_count"])
163
- domain = state["context"].get("domain", "Biomedical Research")
164
  logger.info(f"Refinement iteration: {state['context']['refine_count']}, Difficulty level: {difficulty_level}")
165
  if state["context"]["refine_count"] >= 3:
166
  meta_prompt = (
167
  f"Domain: {domain}\n"
168
  "You are given the following series of refinement outputs:\n" +
169
  "\n---\n".join(refinement_history) +
170
- "\n\nSynthesize the above into a final, concise, and high-quality technical analysis report. "
171
- "Focus on the key findings and improvements made across the iterations. Do not introduce new ideas; just synthesize the improvements. Ensure the report is well-structured and easy to understand."
172
  )
173
  meta_response = self.processor.process_query(meta_prompt)
174
  logger.info("Meta-refinement completed.")
@@ -180,12 +180,8 @@ class ResearchWorkflow:
180
  refinement_prompt = (
181
  f"Domain: {domain}\n"
182
  f"Refine this analysis (current difficulty level: {difficulty_level}):\n{current_analysis}\n\n"
183
- "First, critically evaluate the analysis and identify its weaknesses, such as inaccuracies, unsupported claims, or lack of clarity. Summarize these weaknesses in a short paragraph.\n\n"
184
- "Then, improve the following aspects:\n"
185
- "1. Technical precision\n"
186
- "2. Empirical grounding\n"
187
- "3. Theoretical coherence\n\n"
188
- "Use a structured difficulty gradient approach to produce a simpler yet more accurate variant, addressing the identified weaknesses."
189
  )
190
  response = self.processor.process_query(refinement_prompt)
191
  logger.info("Refinement completed.")
@@ -235,4 +231,3 @@ class ResearchWorkflow:
235
  except Exception as e:
236
  logger.exception("Error during multi-modal enhancement.")
237
  return self._error_state(f"Enhancement Error: {str(e)}")
238
-
 
23
 
24
  class ResearchWorkflow:
25
  """
26
+ A multi-step research workflow that leverages Retrieval-Augmented Generation (RAG).
27
+ Supports domains including:
 
28
  - Biomedical Research
29
  - Legal Research
30
  - Environmental and Energy Studies
31
  - Competitive Programming and Theoretical Computer Science
32
  - Social Sciences
33
+ This implementation normalizes the domain and uses domain-specific prompts and fallbacks.
34
  """
35
  def __init__(self) -> None:
36
  self.processor = EnhancedCognitiveProcessor()
37
+ self.workflow = StateGraph(AgentState) # Supply state schema
 
38
  self._build_workflow()
39
  self.app = self.workflow.compile()
40
 
 
54
  )
55
  self.workflow.add_edge("validate", END)
56
  self.workflow.add_edge("refine", "retrieve")
57
+ # Extended node for multi-modal enhancement
58
  self.workflow.add_node("enhance", self.enhance_analysis)
59
  self.workflow.add_edge("validate", "enhance")
60
  self.workflow.add_edge("enhance", END)
 
62
  def ingest_query(self, state: Dict) -> Dict:
63
  try:
64
  query = state["messages"][-1].content
65
+ # Normalize domain string to lower-case; default to 'biomedical research'
66
+ domain = state.get("context", {}).get("domain", "Biomedical Research").strip().lower()
67
  new_context = {
68
  "raw_query": query,
69
  "domain": domain,
 
83
  def retrieve_documents(self, state: Dict) -> Dict:
84
  try:
85
  query = state["context"]["raw_query"]
86
+ # Simulate retrieval; for now, an empty list indicates no external documents found.
87
+ docs = []
 
88
  logger.info(f"Retrieved {len(docs)} documents for query.")
89
  return {
90
  "messages": [AIMessage(content=f"Retrieved {len(docs)} documents")],
 
93
  "retrieval_time": time.time(),
94
  "refine_count": state["context"].get("refine_count", 0),
95
  "refinement_history": state["context"].get("refinement_history", []),
96
+ "domain": state["context"].get("domain", "biomedical research")
97
  }
98
  }
99
  except Exception as e:
 
102
 
103
  def analyze_content(self, state: Dict) -> Dict:
104
  try:
105
+ # Normalize domain and use it for prompt generation
106
+ domain = state["context"].get("domain", "biomedical research").strip().lower()
107
  docs = state["context"].get("documents", [])
108
+ # Use retrieved documents if available; else, use raw query as fallback.
109
  if docs:
110
  docs_text = "\n\n".join([d.page_content for d in docs])
111
  else:
112
  docs_text = state["context"].get("raw_query", "")
113
+ logger.info("No documents retrieved; using dynamic synthesis (RAG mode).")
114
+ # Get domain-specific prompt; ensure fallback prompts exist for all supported domains.
115
+ domain_prompt = ResearchConfig.DOMAIN_PROMPTS.get(domain, "Consider relevant legal cases and statutory interpretations.")
116
+ # Build the final prompt with domain tag for clarity.
117
  full_prompt = f"Domain: {state['context'].get('domain', 'Biomedical Research')}\n" \
118
  f"{domain_prompt}\n\n" + \
119
  ResearchConfig.ANALYSIS_TEMPLATE.format(context=docs_text)
 
134
  try:
135
  analysis = state["messages"][-1].content
136
  validation_prompt = (
137
+ f"Validate the following analysis for correctness, clarity, and legal grounding:\n{analysis}\n\n"
138
+ "Criteria:\n"
139
+ "1. Technical and legal accuracy\n"
140
+ "2. Evidence and citation support\n"
141
  "3. Logical consistency\n"
142
  "4. Methodological soundness\n\n"
143
+ "Respond with 'VALID: [justification]' or 'INVALID: [justification]'."
144
  )
145
  response = self.processor.process_query(validation_prompt)
146
  logger.info("Output validation completed.")
 
160
  current_analysis = state["messages"][-1].content
161
  refinement_history.append(current_analysis)
162
  difficulty_level = max(0, 3 - state["context"]["refine_count"])
163
+ domain = state["context"].get("domain", "biomedical research")
164
  logger.info(f"Refinement iteration: {state['context']['refine_count']}, Difficulty level: {difficulty_level}")
165
  if state["context"]["refine_count"] >= 3:
166
  meta_prompt = (
167
  f"Domain: {domain}\n"
168
  "You are given the following series of refinement outputs:\n" +
169
  "\n---\n".join(refinement_history) +
170
+ "\n\nSynthesize these into a final, concise legal analysis report, highlighting key precedents and statutory interpretations. "
171
+ "Focus on improving accuracy and relevance for legal research."
172
  )
173
  meta_response = self.processor.process_query(meta_prompt)
174
  logger.info("Meta-refinement completed.")
 
180
  refinement_prompt = (
181
  f"Domain: {domain}\n"
182
  f"Refine this analysis (current difficulty level: {difficulty_level}):\n{current_analysis}\n\n"
183
+ "First, identify weaknesses such as lack of legal grounding or misinterpretation of cases. "
184
+ "Then, improve the analysis with clear references to legal precedents and statutory language."
 
 
 
 
185
  )
186
  response = self.processor.process_query(refinement_prompt)
187
  logger.info("Refinement completed.")
 
231
  except Exception as e:
232
  logger.exception("Error during multi-modal enhancement.")
233
  return self._error_state(f"Enhancement Error: {str(e)}")