Spaces:
Running
Running
from typing import Dict, Optional, Tuple, List, Any, Set, Union | |
import re | |
import xml.etree.ElementTree as ET | |
from datetime import datetime | |
import json | |
import logging | |
from enum import Enum | |
# Module-level logger at INFO. A console handler is attached only when the
# logger has no handlers yet, so re-executing this module (e.g. on reload)
# does not stack duplicate handlers.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

if not logger.handlers:
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )
    ch = logging.StreamHandler()
    ch.setLevel(logging.INFO)
    ch.setFormatter(formatter)
    logger.addHandler(ch)
class StreamingFormatter:
    """Mutable per-message state used while formatting a streamed response.

    Holds the ids of events already formatted, per-message tool outputs,
    citations and metadata, the id of the message currently being streamed,
    and the text accumulated for that message so far.
    """

    def __init__(self):
        self.processed_events: Set[str] = set()
        self.current_tool_outputs: List[Dict] = []
        self.current_citations: List[Dict] = []
        self.current_metadata: Dict[str, Any] = {}
        self.current_message_id: Optional[str] = None
        self.current_message_buffer: str = ""

    def reset(self):
        """Reset the formatter state.

        Collections are emptied in place (the same objects are kept), the
        message id is cleared and the text buffer is emptied.
        """
        for container in (
            self.processed_events,
            self.current_tool_outputs,
            self.current_citations,
            self.current_metadata,
        ):
            container.clear()
        self.current_message_id = None
        self.current_message_buffer = ""

    def append_to_buffer(self, text: str):
        """Add *text* to the end of the accumulated message text."""
        self.current_message_buffer = self.current_message_buffer + text

    def get_and_clear_buffer(self) -> str:
        """Return the accumulated message text and leave the buffer empty."""
        content, self.current_message_buffer = self.current_message_buffer, ""
        return content
class ToolType(Enum):
    """Enum for supported tool types; each member's value is a tool-name string."""

    DUCKDUCKGO = "ddgo_search"
    REDDIT_NEWS = "reddit_x_gnews_newswire_crunchbase"
    PUBMED = "pubmed_search"
    CENSUS = "get_census_data"
    HEATMAP = "heatmap_code"
    MERMAID = "mermaid_output"
    WISQARS = "wisqars"
    WONDER = "wonder"
    NCHS = "nchs"
    ONESTEP = "onestep"
    DQS = "dqs_nhis_adult_summary_health_statistics"

    @classmethod
    def get_tool_type(cls, tool_name: str) -> Optional['ToolType']:
        """Get the enum member matching *tool_name*, or ``None``.

        The member value is tried first (e.g. ``"mermaid_output"``). After
        that, the member name is matched case-insensitively (e.g.
        ``"mermaid"`` or ``"MERMAID"``). Non-string input yields ``None``.
        """
        # Must be a classmethod: the body takes `cls`, and callers invoke it
        # as ToolType.get_tool_type(name).
        if not isinstance(tool_name, str):
            return None
        try:
            return cls(tool_name)  # lookup by value
        except ValueError:
            pass
        try:
            return cls[tool_name.upper()]  # lookup by member name
        except KeyError:
            return None
class ResponseFormatter:
    """Render agent events as (terminal output, XML string) pairs.

    ``__new__`` returns one shared instance. Every caller therefore shares
    the same ``streaming_state`` (processed event ids, current message id and
    message buffer).
    """

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance.streaming_state = StreamingFormatter()
            cls._instance.logger = logger
        return cls._instance

    # -- shared event bookkeeping -------------------------------------------

    def _accept_event(self, event_id: Optional[str], message_id: Optional[str]) -> bool:
        """Return False if *event_id* was already handled; otherwise sync message state.

        When *message_id* differs from the message currently tracked, the
        streaming state is reset (which also forgets processed event ids)
        and *message_id* becomes the current message.
        """
        state = self.streaming_state
        if event_id and event_id in state.processed_events:
            return False
        if message_id != state.current_message_id:
            state.reset()
            state.current_message_id = message_id
        return True

    def _mark_processed(self, event_id: Optional[str]) -> None:
        """Record *event_id* so a redelivered copy is skipped."""
        if event_id:
            self.streaming_state.processed_events.add(event_id)

    @staticmethod
    def _to_xml_text(value: Any) -> str:
        """Convert *value* to a string that ElementTree can serialize.

        ElementTree raises TypeError on non-string text/attribute values, but
        tool outputs may carry dicts or lists (e.g. the heatmap structure).
        Containers are JSON-encoded, ``None`` becomes ``""``, and anything
        else goes through ``str``.
        """
        if value is None:
            return ""
        if isinstance(value, str):
            return value
        if isinstance(value, (dict, list, tuple)):
            try:
                return json.dumps(value, ensure_ascii=False)
            except (TypeError, ValueError):
                pass  # not JSON-serializable; fall back to str()
        return str(value)

    @staticmethod
    def _dedupe_tool_outputs(tool_outputs: List[Dict]) -> List[Dict]:
        """Drop tool outputs whose (type, content) repeats, keeping first-seen order."""
        seen = set()
        unique = []
        for output in tool_outputs:
            output_key = f"{output.get('type')}:{output.get('content')}"
            if output_key not in seen:
                seen.add(output_key)
                unique.append(output)
        return unique

    # -- public formatters ------------------------------------------------------

    def format_thought(
        self,
        thought: str,
        observation: str,
        citations: Optional[List[Dict]] = None,
        metadata: Optional[Dict] = None,
        tool_outputs: Optional[List[Dict]] = None,
        event_id: Optional[str] = None,
        message_id: Optional[str] = None,
    ) -> Optional[Tuple[str, str]]:
        """Format an agent thought for both terminal (JSON) and XML output.

        Returns ``None`` if *event_id* was already processed or if there is
        no thought, observation or tool output. Otherwise returns
        ``(json_string, xml_string)`` and records *event_id* as processed.
        """
        if not self._accept_event(event_id, message_id):
            return None
        if not thought and not observation and not tool_outputs:
            return None

        unique_outputs = self._dedupe_tool_outputs(tool_outputs) if tool_outputs else []

        # Terminal format
        terminal_output = {
            "type": "agent_thought",
            "content": thought,
            "metadata": metadata or {},
        }
        if unique_outputs:
            terminal_output["tool_outputs"] = unique_outputs

        # XML format
        root = ET.Element("agent_response")
        if thought:
            ET.SubElement(root, "thought").text = self._to_xml_text(thought)
        if observation:
            ET.SubElement(root, "observation").text = self._to_xml_text(observation)
        if unique_outputs:
            tools_elem = ET.SubElement(root, "tool_outputs")
            for tool_output in unique_outputs:
                tool_elem = ET.SubElement(tools_elem, "tool_output")
                tool_elem.set("type", self._to_xml_text(tool_output.get("type", "")))
                tool_elem.text = self._to_xml_text(tool_output.get("content", ""))
        if citations:
            cites_elem = ET.SubElement(root, "citations")
            for citation in citations:
                cite_elem = ET.SubElement(cites_elem, "citation")
                for key, value in citation.items():
                    cite_elem.set(str(key), str(value))
        xml_output = ET.tostring(root, encoding='unicode')

        self._mark_processed(event_id)
        return json.dumps(terminal_output), xml_output

    def format_message(
        self,
        message: str,
        event_id: Optional[str] = None,
        message_id: Optional[str] = None,
    ) -> Optional[Tuple[str, str]]:
        """Accumulate a streamed message chunk and format the text so far.

        Returns ``None`` for an already-processed event or while the
        accumulated text is blank. Otherwise returns the stripped
        accumulated text and its XML form.
        """
        if not self._accept_event(event_id, message_id):
            return None

        self.streaming_state.append_to_buffer(message or "")
        # Record the event as soon as its text is buffered. Returning early
        # below without recording it would let a redelivered copy be
        # appended a second time.
        self._mark_processed(event_id)

        terminal_output = self.streaming_state.current_message_buffer.strip()
        if not terminal_output:
            return None

        root = ET.Element("agent_response")
        ET.SubElement(root, "message").text = terminal_output
        xml_output = ET.tostring(root, encoding='unicode')
        return terminal_output, xml_output

    def format_error(
        self,
        error: str,
        event_id: Optional[str] = None,
        message_id: Optional[str] = None,
    ) -> Optional[Tuple[str, str]]:
        """Format an error for both terminal and XML output.

        Returns ``None`` for an already-processed event or an empty error.
        Otherwise returns ``("Error: <error>", xml)`` and records the event.
        """
        if not self._accept_event(event_id, message_id):
            return None
        if not error:
            return None

        terminal_output = f"Error: {error}"

        root = ET.Element("agent_response")
        # _to_xml_text lets non-str errors (e.g. exception objects) serialize.
        ET.SubElement(root, "error").text = self._to_xml_text(error)
        xml_output = ET.tostring(root, encoding='unicode')

        self._mark_processed(event_id)
        return terminal_output, xml_output

    def format_tool_output(
        self,
        tool_type: str,
        content: Union[str, Dict],
        metadata: Optional[Dict] = None,
    ) -> Dict:
        """Format tool output into a ``{"type", "content", "metadata"}`` dict.

        Mermaid and heatmap outputs get their content normalized. Unknown
        tool types are passed through unchanged (with a warning). Any
        failure yields ``{"type": "error", ...}`` with the exception text.
        """
        meta = metadata or {}
        try:
            tool = ToolType.get_tool_type(tool_type)
            if not tool:
                self.logger.warning("Unknown tool type: %s", tool_type)
                return {"type": tool_type, "content": content, "metadata": meta}
            if tool is ToolType.MERMAID:
                return {
                    "type": "mermaid",
                    "content": self._clean_mermaid_content(content),
                    "metadata": meta,
                }
            if tool is ToolType.HEATMAP:
                return {
                    "type": "heatmap",
                    "content": self._format_heatmap_data(content),
                    "metadata": meta,
                }
            # Default formatting for other tools
            return {"type": tool.value, "content": content, "metadata": meta}
        except Exception as e:
            self.logger.exception("Error formatting tool output: %s", e)
            return {"type": "error", "content": str(e), "metadata": meta}

    def _clean_mermaid_content(self, content: Union[str, Dict]) -> str:
        """Strip ```mermaid fences and surrounding whitespace from a diagram.

        A dict is read from its ``"mermaid_diagram"`` key (default ``""``).
        On failure the error is logged and ``str(content)`` is returned.
        """
        try:
            if isinstance(content, dict):
                content = content.get("mermaid_diagram", "")
            content = re.sub(r'```mermaid\s*|\s*```', '', content)
            return content.strip()
        except Exception as e:
            self.logger.error("Error cleaning mermaid content: %s", e)
            return str(content)

    def _format_heatmap_data(self, content: Union[str, Dict]) -> Dict:
        """Normalize heatmap data to ``{"data", "options", "metadata"}``.

        A string is first parsed as JSON. On failure the error is logged and
        ``{"error": <message>}`` is returned.
        """
        try:
            if isinstance(content, str):
                content = json.loads(content)
            return {
                "data": content.get("data", []),
                "options": content.get("options", {}),
                "metadata": content.get("metadata", {}),
            }
        except Exception as e:
            self.logger.error("Error formatting heatmap data: %s", e)
            return {"error": str(e)}
def _clean_markdown(text: str) -> str:
    """Return *text* with markdown syntax stripped.

    Fenced code blocks (including their contents) are dropped. Every ``*``,
    ``_``, backtick and ``#`` character is deleted. The result is trimmed,
    and runs of three or more newlines collapse to a single blank line.
    """
    without_fences = re.sub(r'```.*?```', '', text, flags=re.DOTALL)
    # Deleting a fixed character set: str.translate does it in one pass.
    without_marks = without_fences.translate(str.maketrans('', '', '*_`#'))
    return re.sub(r'\n{3,}', '\n\n', without_marks.strip())