Update workflow.py
workflow.py  CHANGED  (+27 -32)
@@ -23,19 +23,18 @@ class AgentState(TypedDict):
 
 class ResearchWorkflow:
     """
-    A multi-step research workflow that leverages …
-    …
-    Supported domains include:
+    A multi-step research workflow that leverages Retrieval-Augmented Generation (RAG).
+    Supports domains including:
     - Biomedical Research
     - Legal Research
     - Environmental and Energy Studies
     - Competitive Programming and Theoretical Computer Science
     - Social Sciences
+    This implementation normalizes the domain and uses domain-specific prompts and fallbacks.
     """
     def __init__(self) -> None:
         self.processor = EnhancedCognitiveProcessor()
-        # …
-        self.workflow = StateGraph(AgentState)
+        self.workflow = StateGraph(AgentState)  # Supply state schema
         self._build_workflow()
         self.app = self.workflow.compile()
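For context on the "# Supply state schema" comment: langgraph's StateGraph takes the state type at construction time. Below is a minimal, self-contained sketch of that pattern; the AgentState fields shown are assumptions inferred from how the nodes in this file read the state, since the real TypedDict is defined just above this hunk and is not part of the diff.

from typing import Dict, List, TypedDict
from langgraph.graph import END, StateGraph

class AgentState(TypedDict):
    messages: List  # message history, read via state["messages"][-1].content
    context: Dict   # raw_query, domain, documents, refine_count, refinement_history

def ingest(state: AgentState) -> dict:
    # a node returns a dict of state keys to update
    return {"context": {**state["context"], "ingested": True}}

graph = StateGraph(AgentState)  # schema supplied at construction, as in the commit
graph.add_node("ingest", ingest)
graph.set_entry_point("ingest")
graph.add_edge("ingest", END)
app = graph.compile()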
@@ -55,7 +54,7 @@ class ResearchWorkflow:
         )
         self.workflow.add_edge("validate", END)
         self.workflow.add_edge("refine", "retrieve")
-        # Extended node for multi-modal enhancement
+        # Extended node for multi-modal enhancement
         self.workflow.add_node("enhance", self.enhance_analysis)
         self.workflow.add_edge("validate", "enhance")
         self.workflow.add_edge("enhance", END)
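The ")" that opens this hunk closes a call that sits above the hunk; given the validate/refine loop, it is plausibly an add_conditional_edges router. A hypothetical sketch of that pattern follows; the router name and logic are assumptions, not taken from the commit. Note also that with both add_edge("validate", END) and add_edge("validate", "enhance") present, validation appears to fan out to both targets unconditionally; a conditional router is the usual way to pick one path.

def route_after_validation(state) -> str:
    # hypothetical router: loop back on failure, otherwise continue
    return "refine" if state["messages"][-1].content.startswith("INVALID") else "enhance"

self.workflow.add_conditional_edges(
    "validate",
    route_after_validation,
    {"refine": "refine", "enhance": "enhance"},
)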
@@ -63,8 +62,8 @@ class ResearchWorkflow:
     def ingest_query(self, state: Dict) -> Dict:
         try:
             query = state["messages"][-1].content
-            # …
-            domain = state.get("context", {}).get("domain", "Biomedical Research")
+            # Normalize domain string to lower-case; default to 'biomedical research'
+            domain = state.get("context", {}).get("domain", "Biomedical Research").strip().lower()
             new_context = {
                 "raw_query": query,
                 "domain": domain,
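The normalization added here makes later dictionary lookups whitespace- and case-insensitive. A tiny illustration with hypothetical values:

state = {"context": {"domain": "  Legal Research "}}
domain = state.get("context", {}).get("domain", "Biomedical Research").strip().lower()
assert domain == "legal research"  # matches the lower-case DOMAIN_PROMPTS keys used in analyze_content

One caveat: .strip().lower() assumes the stored value is a string; a context that explicitly carries "domain": None would raise AttributeError, since .get only falls back to the default when the key is absent.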
@@ -84,9 +83,8 @@ class ResearchWorkflow:
     def retrieve_documents(self, state: Dict) -> Dict:
         try:
             query = state["context"]["raw_query"]
-            # …
-            …
-            docs = []
+            # Simulate retrieval; for now, an empty list indicates no external documents found.
+            docs = []
             logger.info(f"Retrieved {len(docs)} documents for query.")
             return {
                 "messages": [AIMessage(content=f"Retrieved {len(docs)} documents")],
@@ -95,7 +93,7 @@ class ResearchWorkflow:
                     "retrieval_time": time.time(),
                     "refine_count": state["context"].get("refine_count", 0),
                     "refinement_history": state["context"].get("refinement_history", []),
-                    "domain": state["context"].get("domain", "…
+                    "domain": state["context"].get("domain", "biomedical research")
                 }
             }
         except Exception as e:
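docs = [] leaves retrieval stubbed out. If a retriever is wired in later, a LangChain-style call shape would look roughly like this; the self.retriever attribute is hypothetical and not in this commit:

# hypothetical: swap the stub for a real vector-store lookup
retriever = getattr(self, "retriever", None)
docs = retriever.invoke(query) if retriever is not None else []

BaseRetriever.invoke(query) returns Document objects exposing page_content, which is exactly what analyze_content joins below.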
@@ -104,16 +102,18 @@ class ResearchWorkflow:
 
     def analyze_content(self, state: Dict) -> Dict:
         try:
-            …
+            # Normalize domain and use it for prompt generation
+            domain = state["context"].get("domain", "biomedical research").strip().lower()
             docs = state["context"].get("documents", [])
-            # …
+            # Use retrieved documents if available; else, use raw query as fallback.
             if docs:
                 docs_text = "\n\n".join([d.page_content for d in docs])
             else:
                 docs_text = state["context"].get("raw_query", "")
-                logger.info("No documents retrieved; …
-            …
-            …
+                logger.info("No documents retrieved; using dynamic synthesis (RAG mode).")
+            # Get domain-specific prompt; ensure fallback prompts exist for all supported domains.
+            domain_prompt = ResearchConfig.DOMAIN_PROMPTS.get(domain, "Consider relevant legal cases and statutory interpretations.")
+            # Build the final prompt with domain tag for clarity.
             full_prompt = f"Domain: {state['context'].get('domain', 'Biomedical Research')}\n" \
                           f"{domain_prompt}\n\n" + \
                           ResearchConfig.ANALYSIS_TEMPLATE.format(context=docs_text)
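The lookup ResearchConfig.DOMAIN_PROMPTS.get(domain, ...) only hits when the mapping's keys are lower-case, matching the normalization above; note the hard-coded fallback is legal-flavored even though five domains are advertised in the docstring. A hypothetical shape for the mapping, since the committed ResearchConfig is outside this diff:

# hypothetical DOMAIN_PROMPTS; keys lower-cased to match the normalized domain
DOMAIN_PROMPTS = {
    "biomedical research": "Emphasize clinical evidence, mechanisms, and study design.",
    "legal research": "Consider relevant legal cases and statutory interpretations.",
    "environmental and energy studies": "Weigh policy context, lifecycle costs, and grid impact.",
    "competitive programming and theoretical computer science": "Stress complexity bounds and proof sketches.",
    "social sciences": "Attend to sampling, measurement validity, and confounders.",
}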
@@ -134,13 +134,13 @@ class ResearchWorkflow:
         try:
             analysis = state["messages"][-1].content
             validation_prompt = (
-                f"Validate the following …
-                "…
-                "1. Technical accuracy\n"
-                "2. …
+                f"Validate the following analysis for correctness, clarity, and legal grounding:\n{analysis}\n\n"
+                "Criteria:\n"
+                "1. Technical and legal accuracy\n"
+                "2. Evidence and citation support\n"
                 "3. Logical consistency\n"
                 "4. Methodological soundness\n\n"
-                "Respond with 'VALID: […
+                "Respond with 'VALID: [justification]' or 'INVALID: [justification]'."
             )
             response = self.processor.process_query(validation_prompt)
             logger.info("Output validation completed.")
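Pinning the reply to 'VALID: [justification]' or 'INVALID: [justification]' keeps downstream routing a plain string check. A hypothetical parse, since the committed router sits outside this diff:

verdict = response.strip()
is_valid = verdict.upper().startswith("VALID")  # "INVALID: ..." correctly fails this test
justification = verdict.split(":", 1)[1].strip() if ":" in verdict else verdict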
@@ -160,15 +160,15 @@ class ResearchWorkflow:
             current_analysis = state["messages"][-1].content
             refinement_history.append(current_analysis)
             difficulty_level = max(0, 3 - state["context"]["refine_count"])
-            domain = state["context"].get("domain", "…
+            domain = state["context"].get("domain", "biomedical research")
             logger.info(f"Refinement iteration: {state['context']['refine_count']}, Difficulty level: {difficulty_level}")
             if state["context"]["refine_count"] >= 3:
                 meta_prompt = (
                     f"Domain: {domain}\n"
                     "You are given the following series of refinement outputs:\n" +
                     "\n---\n".join(refinement_history) +
-                    "\n\nSynthesize …
-                    "Focus on …
+                    "\n\nSynthesize these into a final, concise legal analysis report, highlighting key precedents and statutory interpretations. "
+                    "Focus on improving accuracy and relevance for legal research."
                 )
                 meta_response = self.processor.process_query(meta_prompt)
                 logger.info("Meta-refinement completed.")
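The gradient difficulty_level = max(0, 3 - refine_count) steps down one level per refinement pass and bottoms out exactly when the refine_count >= 3 meta-refinement branch takes over:

for refine_count in range(5):
    print(refine_count, max(0, 3 - refine_count))  # 0->3, 1->2, 2->1, 3->0, 4->0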
@@ -180,12 +180,8 @@ class ResearchWorkflow:
             refinement_prompt = (
                 f"Domain: {domain}\n"
                 f"Refine this analysis (current difficulty level: {difficulty_level}):\n{current_analysis}\n\n"
-                "First, …
-                "Then, improve the …
-                "1. Technical precision\n"
-                "2. Empirical grounding\n"
-                "3. Theoretical coherence\n\n"
-                "Use a structured difficulty gradient approach to produce a simpler yet more accurate variant, addressing the identified weaknesses."
+                "First, identify weaknesses such as lack of legal grounding or misinterpretation of cases. "
+                "Then, improve the analysis with clear references to legal precedents and statutory language."
             )
             response = self.processor.process_query(refinement_prompt)
             logger.info("Refinement completed.")
@@ -235,4 +231,3 @@ class ResearchWorkflow:
         except Exception as e:
             logger.exception("Error during multi-modal enhancement.")
             return self._error_state(f"Enhancement Error: {str(e)}")
-