Update workflow.py

workflow.py CHANGED (+40 -26)
@@ -23,18 +23,13 @@ class AgentState(TypedDict):
 
 class ResearchWorkflow:
     """
-    A multi-step research workflow
-
-
-    - Legal Research
-    - Environmental and Energy Studies
-    - Competitive Programming and Theoretical Computer Science
-    - Social Sciences
-    This implementation normalizes the domain and uses domain-specific prompts and fallbacks.
+    A multi-step research workflow employing Retrieval-Augmented Generation (RAG) with an additional verification step.
+    This workflow supports multiple domains (e.g., Biomedical, Legal, Environmental, Competitive Programming, Social Sciences)
+    and integrates domain-specific prompts, iterative refinement, and a final verification to reduce hallucinations.
     """
     def __init__(self) -> None:
         self.processor = EnhancedCognitiveProcessor()
-        self.workflow = StateGraph(AgentState)
+        self.workflow = StateGraph(AgentState)
         self._build_workflow()
         self.app = self.workflow.compile()
 
@@ -44,6 +39,8 @@ class ResearchWorkflow:
         self.workflow.add_node("analyze", self.analyze_content)
         self.workflow.add_node("validate", self.validate_output)
         self.workflow.add_node("refine", self.refine_results)
+        # New verify node to further cross-check the output
+        self.workflow.add_node("verify", self.verify_output)
         self.workflow.set_entry_point("ingest")
         self.workflow.add_edge("ingest", "retrieve")
         self.workflow.add_edge("retrieve", "analyze")
@@ -52,17 +49,17 @@
             self._quality_check,
             {"valid": "validate", "invalid": "refine"}
         )
-        self.workflow.add_edge("validate",
+        self.workflow.add_edge("validate", "verify")
         self.workflow.add_edge("refine", "retrieve")
         # Extended node for multi-modal enhancement
         self.workflow.add_node("enhance", self.enhance_analysis)
-        self.workflow.add_edge("
+        self.workflow.add_edge("verify", "enhance")
         self.workflow.add_edge("enhance", END)
 
     def ingest_query(self, state: Dict) -> Dict:
         try:
             query = state["messages"][-1].content
-            # Normalize domain string
+            # Normalize the domain string; default to 'biomedical research'
             domain = state.get("context", {}).get("domain", "Biomedical Research").strip().lower()
             new_context = {
                 "raw_query": query,
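For orientation, the wiring after this change produces the graph sketched below. This is a minimal runnable sketch assuming langgraph's StateGraph API; the node callables are stubs standing in for the Space's actual methods, and the quality check here always reports "valid" for brevity.

from typing import TypedDict
from langgraph.graph import StateGraph, END

class AgentState(TypedDict):
    messages: list
    context: dict

def stub(name: str):
    # Placeholder node: the real nodes call EnhancedCognitiveProcessor and update state.
    return lambda state: state

wf = StateGraph(AgentState)
for node in ("ingest", "retrieve", "analyze", "validate", "refine", "verify", "enhance"):
    wf.add_node(node, stub(node))
wf.set_entry_point("ingest")
wf.add_edge("ingest", "retrieve")
wf.add_edge("retrieve", "analyze")
wf.add_conditional_edges("analyze", lambda s: "valid", {"valid": "validate", "invalid": "refine"})
wf.add_edge("validate", "verify")   # new hop introduced by this commit
wf.add_edge("refine", "retrieve")
wf.add_edge("verify", "enhance")    # new hop introduced by this commit
wf.add_edge("enhance", END)
app = wf.compile()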
@@ -83,7 +80,7 @@
     def retrieve_documents(self, state: Dict) -> Dict:
         try:
             query = state["context"]["raw_query"]
-            #
+            # Placeholder retrieval: currently returns an empty list (simulate no documents)
             docs = []
             logger.info(f"Retrieved {len(docs)} documents for query.")
             return {
@@ -102,18 +99,16 @@
 
     def analyze_content(self, state: Dict) -> Dict:
         try:
-            # Normalize domain and use it for prompt generation
             domain = state["context"].get("domain", "biomedical research").strip().lower()
             docs = state["context"].get("documents", [])
-            # Use retrieved documents if available; else, use raw query as fallback.
             if docs:
                 docs_text = "\n\n".join([d.page_content for d in docs])
             else:
                 docs_text = state["context"].get("raw_query", "")
-                logger.info("No documents retrieved;
-            #
-            domain_prompt = ResearchConfig.DOMAIN_PROMPTS.get(domain,
-
+                logger.info("No documents retrieved; switching to dynamic synthesis (RAG mode).")
+            # Use domain-specific prompt; for legal research, inject legal-specific guidance.
+            domain_prompt = ResearchConfig.DOMAIN_PROMPTS.get(domain,
+                                                              "Provide an analysis based on the provided context.")
             full_prompt = f"Domain: {state['context'].get('domain', 'Biomedical Research')}\n" \
                           f"{domain_prompt}\n\n" + \
                           ResearchConfig.ANALYSIS_TEMPLATE.format(context=docs_text)
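A detail worth noting: the .get(domain, ...) fallback only fires when the normalized domain string is absent from ResearchConfig.DOMAIN_PROMPTS, so the dict must be keyed by the same stripped, lower-cased form that ingest_query produces. A hypothetical excerpt illustrating that contract (the prompt texts are invented for the example):

class ResearchConfig:
    # Keys must be lower-case to match domain.strip().lower() from ingest_query.
    DOMAIN_PROMPTS = {
        "biomedical research": "Ground claims in clinical evidence and mechanisms.",
        "legal research": "Cite controlling precedents and statutory language.",
        "social sciences": "Weigh methodology and sample limitations.",
    }

domain = "  Legal Research ".strip().lower()   # -> "legal research"
prompt = ResearchConfig.DOMAIN_PROMPTS.get(
    domain, "Provide an analysis based on the provided context.")
assert prompt.startswith("Cite")               # key matched, fallback not used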
@@ -134,10 +129,11 @@
         try:
             analysis = state["messages"][-1].content
             validation_prompt = (
-                f"Validate the following analysis for
+                f"Validate the following analysis for accuracy and domain-specific relevance:\n{analysis}\n\n"
                 "Criteria:\n"
-                "1.
-                "2.
+                "1. Factual and technical accuracy\n"
+                "2. For legal research: inclusion of relevant precedents and statutory interpretations; "
+                "for other domains: appropriate domain insights\n"
                 "3. Logical consistency\n"
                 "4. Methodological soundness\n\n"
                 "Respond with 'VALID: [justification]' or 'INVALID: [justification]'."
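_quality_check itself is not part of this diff, but the {"valid": ..., "invalid": ...} edge map together with the 'VALID:'/'INVALID:' response contract implies a router of roughly this shape (a guess at the interface, not the Space's actual code):

def _quality_check(self, state: dict) -> str:
    # Route on the validator's reply; anything not starting with VALID
    # (including malformed output) falls through to the refine branch.
    verdict = state["messages"][-1].content.strip().upper()
    return "valid" if verdict.startswith("VALID") else "invalid"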
@@ -152,6 +148,26 @@
             logger.exception("Error during output validation.")
             return self._error_state(f"Validation Error: {str(e)}")
 
+    def verify_output(self, state: Dict) -> Dict:
+        try:
+            # New verify step: cross-check the analysis using an external fact-checking prompt.
+            analysis = state["messages"][-1].content
+            verification_prompt = (
+                f"Verify the following analysis by comparing it with established external legal databases and reference texts:\n{analysis}\n\n"
+                "Identify any discrepancies or hallucinations and provide a brief correction if necessary."
+            )
+            response = self.processor.process_query(verification_prompt)
+            logger.info("Output verification completed.")
+            # Here, you can merge the verification feedback with the analysis.
+            verified_analysis = analysis + "\n\nVerification Feedback: " + response.get('choices', [{}])[0].get('message', {}).get('content', '')
+            return {
+                "messages": [AIMessage(content=verified_analysis)],
+                "context": state["context"]
+            }
+        except Exception as e:
+            logger.exception("Error during output verification.")
+            return self._error_state(f"Verification Error: {str(e)}")
+
     def refine_results(self, state: Dict) -> Dict:
         try:
             current_count = state["context"].get("refine_count", 0)
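The feedback-merging line in verify_output assumes an OpenAI-style choices[0].message.content payload from process_query. Note that response.get('choices', [{}])[0] still raises IndexError when the API returns an empty (rather than missing) choices list; a slightly more defensive equivalent, with a helper name of our own invention:

def extract_feedback(response: dict) -> str:
    # Same traversal as the chained .get() calls, but also tolerant of
    # an empty (not just missing) "choices" list.
    choices = response.get("choices") or [{}]
    return choices[0].get("message", {}).get("content", "")

print(extract_feedback({"choices": [{"message": {"content": "No discrepancies found."}}]}))
print(extract_feedback({"choices": []}))   # -> "" instead of IndexError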
@@ -167,8 +183,7 @@
                 f"Domain: {domain}\n"
                 "You are given the following series of refinement outputs:\n" +
                 "\n---\n".join(refinement_history) +
-                "\n\nSynthesize these into a final, concise
-                "Focus on improving accuracy and relevance for legal research."
+                "\n\nSynthesize these into a final, concise analysis report with improved accuracy and verifiable details."
             )
             meta_response = self.processor.process_query(meta_prompt)
             logger.info("Meta-refinement completed.")
@@ -180,8 +195,7 @@
             refinement_prompt = (
                 f"Domain: {domain}\n"
                 f"Refine this analysis (current difficulty level: {difficulty_level}):\n{current_analysis}\n\n"
-                "
-                "Then, improve the analysis with clear references to legal precedents and statutory language."
+                "Identify and correct any weaknesses or hallucinations in the analysis, providing verifiable details."
             )
             response = self.processor.process_query(refinement_prompt)
             logger.info("Refinement completed.")
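Finally, a minimal invocation sketch, assuming the module's langchain-style message types and the state shape that ingest_query reads (the query string is invented for illustration):

from langchain_core.messages import HumanMessage

workflow = ResearchWorkflow()
result = workflow.app.invoke({
    "messages": [HumanMessage(content="Survey recent CRISPR delivery methods.")],
    "context": {"domain": "Biomedical Research"},  # normalized to lower-case downstream
})
print(result["messages"][-1].content)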