cc-api / response_formatter.py
Severian's picture
Update response_formatter.py
e1e379a verified
raw
history blame
8.14 kB
from typing import Dict, Optional, Tuple, List, Any
import re
import xml.etree.ElementTree as ET
from datetime import datetime
import json
from logging import logger
class ToolType:
    """String identifiers for agent tools.

    ``ResponseFormatter._create_tool_element`` compares incoming tool names
    against these values with ``==`` / ``in``, so each constant must match the
    tool name string exactly.
    """

    # Tools with a dedicated structured formatter.
    CENSUS = "get_census_data"
    MERMAID = "mermaid_diagram"

    # Tools rendered by the shared health-data formatter.
    WISQARS = "wisqars"
    WONDER = "wonder"
    NCHS = "nchs"

    # Tools without special handling (rendered as generic content).
    DUCKDUCKGO = "duckduckgo_search"
    REDDIT_NEWS = "reddit_x_gnews_newswire_crunchbase"
    PUBMED = "pubmed_search"
    HEATMAP = "heatmap_code"
    ONESTEP = "onestep"
    DQS = "dqs_nhis_adult_summary_health_statistics"
class ResponseFormatter:
    """Render agent output as a ``(terminal, xml)`` pair of strings.

    Public ``format_*`` methods return a terminal-oriented string (plain text,
    or a JSON document for thoughts) and an ``<agent_response>`` XML string
    built with :mod:`xml.etree.ElementTree`.

    ElementTree neither validates tag names nor removes characters that XML 1.0
    forbids. It also raises ``TypeError`` at serialization time when element
    text or attribute values are not strings. The ``_xml_text`` and
    ``_safe_tag`` helpers guard against these so the emitted XML can always be
    serialized and parsed back.
    """

    # Characters not permitted anywhere in an XML 1.0 document: C0 controls
    # other than tab/LF/CR, lone surrogates, and U+FFFE / U+FFFF.
    _INVALID_XML_CHARS = re.compile(r"[\x00-\x08\x0b\x0c\x0e-\x1f\ud800-\udfff\ufffe\uffff]")
    # Anything other than word characters, '.' and '-' is dropped from tag names.
    _INVALID_TAG_CHARS = re.compile(r"[^\w.\-]")
    # Mermaid cleanup patterns (fences, "tool response:" prefix, trailing period).
    _MERMAID_FENCE = re.compile(r'```mermaid\s*|\s*```')
    _TOOL_RESPONSE_PREFIX = re.compile(r'tool response:.*?{')
    _TRAILING_PERIOD = re.compile(r'}\s*\.$')
    # Markdown cleanup patterns used by _clean_markdown.
    _CODE_FENCE_BLOCK = re.compile(r'```.*?```', re.DOTALL)
    _MARKDOWN_CHARS = re.compile(r'[*_`#]')
    _EXCESS_NEWLINES = re.compile(r'\n{3,}')

    @staticmethod
    def _xml_text(value: Any) -> Optional[str]:
        """Return *value* as XML-safe text, or ``None`` when *value* is ``None``.

        Non-string values are converted with ``str()``, because ElementTree
        would otherwise fail when serializing them. Characters forbidden by
        XML 1.0 are removed so the output remains parseable.
        """
        if value is None:
            return None
        return ResponseFormatter._INVALID_XML_CHARS.sub("", str(value))

    @staticmethod
    def _safe_tag(key: Any) -> str:
        """Turn an arbitrary dict key into a usable XML tag name (best effort).

        Underscores are removed, matching this module's existing
        ``key.replace("_", "")`` naming (``median_income`` -> ``medianincome``).
        Characters that cannot appear in a tag are also removed. If the result
        is empty or does not start with a letter, it is prefixed with ``field``.
        """
        tag = ResponseFormatter._INVALID_TAG_CHARS.sub("", str(key)).replace("_", "")
        if not tag or not tag[0].isalpha():
            tag = "field" + tag
        return tag

    @staticmethod
    def _append_content(tool_elem: ET.Element, data: Any) -> None:
        """Append a ``<content>`` child holding the markdown-stripped ``str(data)``."""
        content_elem = ET.SubElement(tool_elem, "content")
        content_elem.text = ResponseFormatter._xml_text(
            ResponseFormatter._clean_markdown(str(data))
        )

    @staticmethod
    def format_thought(
        thought: str,
        observation: str,
        citations: Optional[List[Dict]] = None,
        metadata: Optional[Dict] = None,
        tool_outputs: Optional[List[Dict]] = None
    ) -> Tuple[str, str]:
        """Format an agent thought for both terminal and XML output.

        Args:
            thought: The agent's reasoning text.
            observation: Observation text; omitted from the XML when falsy.
            citations: Dicts whose items become attributes of ``<citation>``
                elements (values are stringified).
            metadata: Extra data included only in the terminal JSON.
            tool_outputs: Dicts with optional ``type`` and ``content`` keys.
                They are included verbatim in the terminal JSON and as
                ``<tool_output>`` elements in the XML.

        Returns:
            ``(terminal_json, xml)``. ``terminal_json`` is a JSON object with
            ``type``, ``content``, ``metadata`` and, if given, ``tool_outputs``.
        """
        xml_text = ResponseFormatter._xml_text

        terminal_output = {
            "type": "agent_thought",
            "content": thought,
            "metadata": metadata or {}
        }
        if tool_outputs:
            terminal_output["tool_outputs"] = tool_outputs

        root = ET.Element("agent_response")
        ET.SubElement(root, "thought").text = xml_text(thought)
        if observation:
            ET.SubElement(root, "observation").text = xml_text(observation)
        if tool_outputs:
            tools_elem = ET.SubElement(root, "tool_outputs")
            for tool_output in tool_outputs:
                tool_elem = ET.SubElement(tools_elem, "tool_output")
                # Attribute values must be strings; a missing or None type becomes "".
                tool_elem.set("type", xml_text(tool_output.get("type", "")) or "")
                # Non-string content (e.g. a dict) is stringified instead of
                # crashing ET.tostring.
                tool_elem.text = xml_text(tool_output.get("content", ""))
        if citations:
            cites_elem = ET.SubElement(root, "citations")
            for citation in citations:
                cite_elem = ET.SubElement(cites_elem, "citation")
                for key, value in citation.items():
                    cite_elem.set(str(key), xml_text(str(value)))

        xml_output = ET.tostring(root, encoding='unicode')
        return json.dumps(terminal_output), xml_output

    @staticmethod
    def _create_tool_element(parent: ET.Element, tool_name: str, tool_data: Dict) -> ET.Element:
        """Append a ``<tool name=...>`` element to *parent* and fill it for *tool_name*.

        Census, Mermaid and the health tools (WISQARS, WONDER, NCHS) get
        structured children. Any other tool gets a single ``<content>`` child
        holding the markdown-stripped ``str()`` of *tool_data*.

        Returns:
            The newly created ``<tool>`` element.
        """
        tool_elem = ET.SubElement(parent, "tool")
        tool_elem.set("name", tool_name)
        if tool_name == ToolType.CENSUS:
            ResponseFormatter._format_census_data(tool_elem, tool_data)
        elif tool_name == ToolType.MERMAID:
            ResponseFormatter._format_mermaid_data(tool_elem, tool_data)
        elif tool_name in (ToolType.WISQARS, ToolType.WONDER, ToolType.NCHS):
            ResponseFormatter._format_health_data(tool_elem, tool_data)
        else:
            ResponseFormatter._append_content(tool_elem, tool_data)
        return tool_elem

    @staticmethod
    def _format_census_data(tool_elem: ET.Element, data: Dict) -> None:
        """Render census output as ``<census_tracts><tract id=...>`` children.

        ``data["llm_result"]`` must be a JSON string, or an already decoded
        dict, containing a ``tracts`` list of per-tract dicts. Each tract's
        ``tract`` value becomes the ``id`` attribute. Every other key becomes a
        child element whose text is ``str(value)``.

        If ``llm_result`` is missing or malformed, the whole payload is emitted
        as a ``<content>`` element instead, so no data is silently dropped.
        """
        xml_text = ResponseFormatter._xml_text
        try:
            result = data["llm_result"]
            if isinstance(result, str):
                result = json.loads(result)
            # Build detached so a failure part-way through leaves no
            # half-filled <census_tracts> next to the fallback <content>.
            tracts_elem = ET.Element("census_tracts")
            for tract_data in result.get("tracts", []):
                tract_elem = ET.SubElement(tracts_elem, "tract")
                tract_elem.set("id", xml_text(str(tract_data.get("tract", ""))))
                for key, value in tract_data.items():
                    if key != "tract":
                        detail_elem = ET.SubElement(tract_elem, ResponseFormatter._safe_tag(key))
                        detail_elem.text = xml_text(str(value))
        except (KeyError, TypeError, ValueError, AttributeError):
            # KeyError: no llm_result; ValueError: invalid JSON;
            # TypeError/AttributeError: unexpected payload shape.
            ResponseFormatter._append_content(tool_elem, data)
            return
        tool_elem.append(tracts_elem)

    @staticmethod
    def _format_mermaid_data(tool_elem: ET.Element, data: Dict) -> None:
        """Append a ``<diagram>`` child holding cleaned Mermaid source.

        *data* may be a dict, read from ``content`` or else ``mermaid_diagram``,
        or a plain string. Markdown fences, a leading ``tool response: ... {``
        prefix, and a period after a closing ``}`` are removed. On an
        unexpected error the problem is logged and a ``<content>`` element with
        a fixed error message is appended instead.
        """
        try:
            if isinstance(data, dict):
                content = data.get("content", data.get("mermaid_diagram", ""))
            elif isinstance(data, str):
                content = data
            else:
                content = ""
            # Guard against None or non-string values, which would break re.sub.
            if content is None:
                content = ""
            elif not isinstance(content, str):
                content = str(content)
            content = ResponseFormatter._MERMAID_FENCE.sub('', content)
            content = ResponseFormatter._TOOL_RESPONSE_PREFIX.sub('{', content)
            content = ResponseFormatter._TRAILING_PERIOD.sub('}', content)
            # Create the element only after cleaning succeeds, so an error
            # never leaves an empty <diagram> behind.
            diagram_elem = ET.SubElement(tool_elem, "diagram")
            diagram_elem.text = ResponseFormatter._xml_text(content.strip())
        except Exception:
            # Best-effort boundary: log with traceback, then degrade gracefully.
            # A local logger is used so this path does not depend on any
            # module-level logger name.
            import logging
            logging.getLogger(__name__).exception("Error formatting mermaid data")
            content_elem = ET.SubElement(tool_elem, "content")
            content_elem.text = "Error formatting diagram"

    @staticmethod
    def _format_health_data(tool_elem: ET.Element, data: Dict) -> None:
        """Render health-statistics output (WISQARS, WONDER, NCHS) as nested elements.

        Each top-level key of *data* becomes a child element. A dict value adds
        one more level of children; any other value becomes the element's text
        via ``str()``. A payload that is not a dict is emitted as ``<content>``
        rather than being dropped.
        """
        if not isinstance(data, dict):
            ResponseFormatter._append_content(tool_elem, data)
            return
        xml_text = ResponseFormatter._xml_text
        safe_tag = ResponseFormatter._safe_tag
        categories = []
        for key, value in data.items():
            category_elem = ET.Element(safe_tag(key))
            if isinstance(value, dict):
                for sub_key, sub_value in value.items():
                    sub_elem = ET.SubElement(category_elem, safe_tag(sub_key))
                    sub_elem.text = xml_text(str(sub_value))
            else:
                category_elem.text = xml_text(str(value))
            categories.append(category_elem)
        tool_elem.extend(categories)

    @staticmethod
    def _extract_tool_outputs(observation: str) -> Dict[str, Any]:
        """Pull tool results that carry an ``llm_result`` out of a JSON observation.

        *observation* should be a JSON object string. Each string value that
        contains ``llm_result`` is decoded as JSON when possible and kept as
        the raw string otherwise; all other values are ignored.

        Returns:
            The extracted mapping. It is ``{}`` when *observation* is not a
            string, is not valid JSON, or does not decode to an object.
        """
        if not isinstance(observation, str):
            return {}
        try:
            data = json.loads(observation)
        except ValueError:
            return {}
        if not isinstance(data, dict):
            return {}
        tool_outputs: Dict[str, Any] = {}
        for key, value in data.items():
            if isinstance(value, str) and "llm_result" in value:
                try:
                    tool_outputs[key] = json.loads(value)
                except ValueError:
                    tool_outputs[key] = value
        return tool_outputs

    @staticmethod
    def format_message(message: str) -> Tuple[str, str]:
        """Format an agent message for both terminal and XML output.

        Returns:
            ``(stripped_message, xml)``, where the XML wraps the stripped
            message in ``<agent_response><message>``.
        """
        text = message.strip()
        root = ET.Element("agent_response")
        ET.SubElement(root, "message").text = ResponseFormatter._xml_text(text)
        return text, ET.tostring(root, encoding='unicode')

    @staticmethod
    def format_error(error: str) -> Tuple[str, str]:
        """Format an error for both terminal and XML output.

        *error* may be a string or any object with a useful ``str()``, such as
        an exception instance.

        Returns:
            ``("Error: <error>", xml)``, where the XML wraps the error text in
            ``<agent_response><error>``.
        """
        root = ET.Element("agent_response")
        # Stringify so exception objects don't make ET.tostring raise TypeError.
        ET.SubElement(root, "error").text = ResponseFormatter._xml_text(error)
        return f"Error: {error}", ET.tostring(root, encoding='unicode')

    @staticmethod
    def _clean_markdown(text: str) -> str:
        """Strip markdown from *text*.

        Removes fenced code blocks entirely, deletes every ``*``, ``_``, `````
        and ``#`` character, trims surrounding whitespace, and collapses runs
        of three or more newlines to a single blank line.
        """
        text = ResponseFormatter._CODE_FENCE_BLOCK.sub('', text)
        text = ResponseFormatter._MARKDOWN_CHARS.sub('', text)
        return ResponseFormatter._EXCESS_NEWLINES.sub('\n\n', text.strip())