# cc-api / response_formatter.py
# Severian's picture
# Update response_formatter.py
# 3c1c117 verified
from typing import Dict, Optional, Tuple, List, Any, Set, Union
import re
import xml.etree.ElementTree as ET
from datetime import datetime
import json
import logging
from enum import Enum
# Module-level logger for this file; records below INFO are dropped.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Attach a console (stream) handler only when the logger has no handlers yet.
if not logger.handlers:
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    ch = logging.StreamHandler()
    ch.setFormatter(formatter)
    ch.setLevel(logging.INFO)
    logger.addHandler(ch)
class StreamingFormatter:
    """Mutable state tracked while the events of one message are formatted.

    Holds the set of already-processed event ids, collected tool outputs,
    citations and metadata, the id of the active message, and a text buffer
    that accumulates message chunks.
    """

    def __init__(self):
        # Start with no active message, an empty buffer and empty collections.
        self.current_message_id = None
        self.current_message_buffer = ""
        self.processed_events = set()
        self.current_tool_outputs = []
        self.current_citations = []
        self.current_metadata = {}

    def reset(self):
        """Return the formatter to its initial, empty state.

        The collections are emptied in place (same objects are kept), so
        references held elsewhere observe the cleared contents.
        """
        containers = (
            self.processed_events,
            self.current_tool_outputs,
            self.current_citations,
            self.current_metadata,
        )
        for container in containers:
            container.clear()
        self.current_message_buffer = ""
        self.current_message_id = None

    def append_to_buffer(self, text: str):
        """Add ``text`` to the end of the message buffer."""
        self.current_message_buffer = self.current_message_buffer + text

    def get_and_clear_buffer(self) -> str:
        """Return everything buffered so far and leave the buffer empty."""
        content, self.current_message_buffer = self.current_message_buffer, ""
        return content
class ToolType(Enum):
    """Enum for supported tool types.

    Each member's value is the tool-name string that identifies the tool.
    """

    DUCKDUCKGO = "ddgo_search"
    REDDIT_NEWS = "reddit_x_gnews_newswire_crunchbase"
    PUBMED = "pubmed_search"
    CENSUS = "get_census_data"
    HEATMAP = "heatmap_code"
    MERMAID = "mermaid_output"
    WISQARS = "wisqars"
    WONDER = "wonder"
    NCHS = "nchs"
    ONESTEP = "onestep"
    DQS = "dqs_nhis_adult_summary_health_statistics"

    @classmethod
    def get_tool_type(cls, tool_name: str) -> Optional['ToolType']:
        """Resolve a tool-name string to its enum member.

        The lookup first matches ``tool_name`` against member values (e.g.
        ``"mermaid_output"``) and then, for backward compatibility, against
        member names case-insensitively (e.g. ``"mermaid"`` / ``"MERMAID"``).

        Args:
            tool_name: Tool identifier to resolve.

        Returns:
            The matching member, or ``None`` when nothing matches or
            ``tool_name`` is not a string.
        """
        if not isinstance(tool_name, str):
            return None
        # Value lookup: the values are the tool-name strings themselves.
        try:
            return cls(tool_name)
        except ValueError:
            pass
        # Name lookup: the only form the previous implementation accepted.
        try:
            return cls[tool_name.upper()]
        except KeyError:
            return None
class ResponseFormatter:
    """Singleton that renders agent events as ``(terminal, xml)`` string pairs.

    Every instantiation returns the same object, which carries a shared
    ``StreamingFormatter`` (``streaming_state``) and the module ``logger``.
    """

    _instance = None

    def __new__(cls):
        # Create the single shared instance on first use and attach its state.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance.streaming_state = StreamingFormatter()
            cls._instance.logger = logger
        return cls._instance

    # ------------------------------------------------------------------
    # Private helpers shared by the format_* entry points
    # ------------------------------------------------------------------
    def _already_handled(self, event_id: Optional[str], message_id: Optional[str]) -> bool:
        """Return True for an already-processed event; otherwise sync message state.

        When ``message_id`` differs from the active one, the streaming state is
        reset and ``message_id`` becomes the active message.
        """
        state = self.streaming_state
        if event_id and event_id in state.processed_events:
            return True
        if message_id != state.current_message_id:
            state.reset()
            state.current_message_id = message_id
        return False

    def _finalize(self, root: ET.Element, event_id: Optional[str]) -> str:
        """Serialize ``root`` to an XML string and record ``event_id`` as processed."""
        xml_output = ET.tostring(root, encoding='unicode')
        if event_id:
            self.streaming_state.processed_events.add(event_id)
        return xml_output

    @staticmethod
    def _xml_text(value: Any) -> str:
        """Coerce ``value`` to a string that ElementTree can serialize.

        Strings pass through unchanged, ``None`` becomes ``""``, dicts and
        lists are JSON-encoded (falling back to ``str``), anything else uses
        ``str``.
        """
        if value is None:
            return ""
        if isinstance(value, str):
            return value
        if isinstance(value, (dict, list)):
            try:
                return json.dumps(value)
            except (TypeError, ValueError):
                return str(value)
        return str(value)

    @staticmethod
    def _dedupe_tool_outputs(tool_outputs: List[Dict]) -> List[Dict]:
        """Drop repeated tool outputs (same type and content), keeping first-seen order."""
        seen_keys = set()
        unique_outputs = []
        for output in tool_outputs:
            output_key = f"{output.get('type')}:{output.get('content')}"
            if output_key not in seen_keys:
                seen_keys.add(output_key)
                unique_outputs.append(output)
        return unique_outputs

    # ------------------------------------------------------------------
    # Public formatting API
    # ------------------------------------------------------------------
    def format_thought(
        self,
        thought: str,
        observation: str,
        citations: Optional[List[Dict]] = None,
        metadata: Optional[Dict] = None,
        tool_outputs: Optional[List[Dict]] = None,
        event_id: Optional[str] = None,
        message_id: Optional[str] = None
    ) -> Optional[Tuple[str, str]]:
        """Format an agent thought for both terminal and XML output.

        Args:
            thought: Thought text; also the ``content`` of the terminal payload.
            observation: Observation text; emitted in the XML output only.
            citations: Optional dicts; each becomes a ``<citation>`` element
                whose attributes are the dict's items.
            metadata: Optional dict included in the terminal payload.
            tool_outputs: Optional dicts with ``type`` and ``content`` keys;
                duplicates are removed before output.
            event_id: Optional id used to skip events that were already formatted.
            message_id: Id of the message this event belongs to.

        Returns:
            ``(terminal_json, xml)`` or ``None`` when the event was already
            processed or there is nothing to show.
        """
        if self._already_handled(event_id, message_id):
            return None
        # Skip empty thoughts
        if not thought and not observation and not tool_outputs:
            return None

        unique_outputs = self._dedupe_tool_outputs(tool_outputs) if tool_outputs else []

        # Terminal format
        terminal_output = {
            "type": "agent_thought",
            "content": thought,
            "metadata": metadata or {},
        }
        if tool_outputs:
            terminal_output["tool_outputs"] = unique_outputs

        # XML format
        root = ET.Element("agent_response")
        if thought:
            ET.SubElement(root, "thought").text = self._xml_text(thought)
        if observation:
            ET.SubElement(root, "observation").text = self._xml_text(observation)
        if tool_outputs:
            tools_elem = ET.SubElement(root, "tool_outputs")
            for tool_output in unique_outputs:
                tool_elem = ET.SubElement(tools_elem, "tool_output")
                # Content may be a dict (e.g. formatted heatmap data); ElementTree
                # can only serialize strings, so coerce before assigning.
                tool_elem.attrib["type"] = self._xml_text(tool_output.get("type", ""))
                tool_elem.text = self._xml_text(tool_output.get("content", ""))
        if citations:
            cites_elem = ET.SubElement(root, "citations")
            for citation in citations:
                cite_elem = ET.SubElement(cites_elem, "citation")
                for key, value in citation.items():
                    cite_elem.attrib[str(key)] = str(value)

        xml_output = self._finalize(root, event_id)
        return json.dumps(terminal_output), xml_output

    def format_message(
        self,
        message: str,
        event_id: Optional[str] = None,
        message_id: Optional[str] = None
    ) -> Optional[Tuple[str, str]]:
        """Format an agent message chunk for both terminal and XML output.

        The chunk is appended to the streaming buffer and the whole stripped
        buffer is returned, so successive calls for one message yield the
        accumulated text.

        Returns:
            ``(text, xml)`` or ``None`` when the event was already processed
            or the buffer holds only whitespace.
        """
        if self._already_handled(event_id, message_id):
            return None

        state = self.streaming_state
        # A missing chunk is treated as empty text rather than crashing the stream.
        state.append_to_buffer(message if message is not None else "")

        # Only output if we have meaningful content
        terminal_output = state.current_message_buffer.strip()
        if not terminal_output:
            return None

        root = ET.Element("agent_response")
        ET.SubElement(root, "message").text = terminal_output
        return terminal_output, self._finalize(root, event_id)

    def format_error(
        self,
        error: str,
        event_id: Optional[str] = None,
        message_id: Optional[str] = None
    ) -> Optional[Tuple[str, str]]:
        """Format an error for both terminal and XML output.

        Returns:
            ``("Error: <error>", xml)`` or ``None`` when the event was already
            processed or ``error`` is empty.
        """
        if self._already_handled(event_id, message_id):
            return None
        # Skip empty errors
        if not error:
            return None

        terminal_output = f"Error: {error}"
        root = ET.Element("agent_response")
        # Coerce so that exception objects can be serialized as well as strings.
        ET.SubElement(root, "error").text = self._xml_text(error)
        return terminal_output, self._finalize(root, event_id)

    def format_tool_output(
        self,
        tool_type: str,
        content: Union[str, Dict],
        metadata: Optional[Dict] = None
    ) -> Dict:
        """Format tool output into a ``{"type", "content", "metadata"}`` dict.

        Unknown tool types are passed through unchanged (with a warning).
        Mermaid and heatmap tools get their content cleaned/normalized; every
        other known tool keeps its content and reports the enum value as type.
        Any exception is logged and reported as a ``"type": "error"`` dict.
        """
        try:
            tool = ToolType.get_tool_type(tool_type)
            if not tool:
                self.logger.warning("Unknown tool type: %s", tool_type)
                return {
                    "type": tool_type,
                    "content": content,
                    "metadata": metadata or {},
                }

            if tool == ToolType.MERMAID:
                return {
                    "type": "mermaid",
                    "content": self._clean_mermaid_content(content),
                    "metadata": metadata or {},
                }
            if tool == ToolType.HEATMAP:
                return {
                    "type": "heatmap",
                    "content": self._format_heatmap_data(content),
                    "metadata": metadata or {},
                }
            # Default formatting for other tools
            return {
                "type": tool.value,
                "content": content,
                "metadata": metadata or {},
            }
        except Exception as e:
            # Deliberate best-effort: formatting problems are reported, not raised.
            self.logger.error("Error formatting tool output: %s", e)
            return {
                "type": "error",
                "content": str(e),
                "metadata": metadata or {},
            }

    def _clean_mermaid_content(self, content: Union[str, Dict]) -> str:
        """Strip markdown code fences and surrounding whitespace from a mermaid diagram.

        A dict input is reduced to its ``"mermaid_diagram"`` entry first. On
        any failure the error is logged and ``str(content)`` is returned.
        """
        try:
            if isinstance(content, dict):
                content = content.get("mermaid_diagram", "")
            # Remove the opening ```mermaid fence and any closing ``` fence.
            content = re.sub(r'```mermaid\s*|\s*```', '', content)
            return content.strip()
        except Exception as e:
            self.logger.error("Error cleaning mermaid content: %s", e)
            return str(content)

    def _format_heatmap_data(self, content: Union[str, Dict]) -> Dict:
        """Normalize heatmap content to ``{"data", "options", "metadata"}``.

        A string input is parsed as JSON first. On any failure the error is
        logged and ``{"error": <message>}`` is returned.
        """
        try:
            if isinstance(content, str):
                content = json.loads(content)
            return {
                "data": content.get("data", []),
                "options": content.get("options", {}),
                "metadata": content.get("metadata", {}),
            }
        except Exception as e:
            self.logger.error("Error formatting heatmap data: %s", e)
            return {"error": str(e)}

    @staticmethod
    def _clean_markdown(text: str) -> str:
        """Remove fenced code blocks and ``* _ ` #`` characters from ``text``.

        The result is stripped and runs of three or more newlines are
        collapsed to two.
        """
        text = re.sub(r'```.*?```', '', text, flags=re.DOTALL)
        text = re.sub(r'[*_`#]', '', text)
        return re.sub(r'\n{3,}', '\n\n', text.strip())