"""Data models and an SSE event parser for Dify chat stream responses."""

from logger_config import setup_logger
from typing import Dict, Any, Optional, List, Union
from dataclasses import dataclass, asdict
from enum import Enum
import json
from dify_client_python.dify_client.models.stream import (
    StreamEvent,
    StreamResponse,
    build_chat_stream_response
)
import re

logger = setup_logger()


class EventType(Enum):
    """Names of the stream events this module distinguishes."""

    AGENT_THOUGHT = "agent_thought"
    AGENT_MESSAGE = "agent_message"
    MESSAGE_END = "message_end"
    PING = "ping"


@dataclass
class ToolCall:
    """One tool invocation: its name, input, optional output and labels."""

    tool_name: str
    tool_input: Dict[str, Any]
    tool_output: Optional[str]
    tool_labels: Dict[str, Dict[str, str]]


@dataclass
class Citation:
    """A cited dataset segment together with its score and text content."""

    dataset_id: str
    dataset_name: str
    document_id: str
    document_name: str
    segment_id: str
    score: float
    content: str


@dataclass
class ProcessedResponse:
    """Normalized view of a single stream event."""

    event_type: EventType
    task_id: str
    message_id: str
    conversation_id: str
    content: str
    tool_calls: List[ToolCall]
    citations: List[Citation]
    metadata: Dict[str, Any]
    created_at: int


class EnumEncoder(json.JSONEncoder):
    """JSON encoder that also serializes enums and objects with ``dict()``."""

    def default(self, obj):
        """Return a JSON-serializable stand-in for *obj*.

        Enum members are encoded as their ``value``; any object exposing a
        ``dict`` attribute is encoded as the result of calling it. Everything
        else is delegated to ``json.JSONEncoder.default``, which raises
        ``TypeError``.
        """
        if isinstance(obj, Enum):
            return obj.value
        if hasattr(obj, 'dict'):
            return obj.dict()
        return super().default(obj)


class SSEParser:
    """Parse raw SSE event text into dictionaries, tidying mermaid output."""

    # Matches a ``data:`` field name at the start of a line only, so that the
    # text "data:" inside a JSON string value is never mistaken for the field.
    _DATA_FIELD_RE = re.compile(r'^[ \t]*data:', re.MULTILINE)

    def __init__(self):
        self.logger = setup_logger("sse_parser")

    def parse_sse_event(self, data: str) -> Optional[Dict]:
        """Parse SSE event data with improved mermaid handling.

        Args:
            data: Raw event text. It may be a bare JSON document or carry a
                ``data:`` field, optionally preceded by other field lines.

        Returns:
            The decoded JSON object, with a mermaid ``observation`` rewritten
            to its cleaned form when one is present. ``None`` when the payload
            is not valid JSON, is not a JSON object, or any other error
            occurs; failures are logged rather than raised.
        """
        self.logger.debug("Parsing SSE event")
        try:
            payload = self._extract_data_payload(data)
            parsed_data = json.loads(payload)
        except json.JSONDecodeError as e:
            self.logger.error(f"JSON decode error: {str(e)}")
            return None
        except Exception as e:
            self.logger.error(f"Parse error: {str(e)}")
            return None

        # The declared contract is Optional[Dict]; scalars and arrays cannot
        # be handled as events, so reject them uniformly.
        if not isinstance(parsed_data, dict):
            self.logger.warning(
                f"Ignoring non-object SSE payload of type "
                f"{type(parsed_data).__name__}"
            )
            return None

        self._normalize_mermaid_observation(parsed_data)
        return parsed_data

    def _extract_data_payload(self, data: str) -> str:
        """Return the text following the first line-initial ``data:`` field.

        When no line starts with ``data:`` the input is returned unchanged,
        which keeps bare JSON payloads intact even if one of their string
        values happens to contain the text ``data:``.
        """
        match = self._DATA_FIELD_RE.search(data)
        if match is None:
            return data
        return data[match.end():].strip()

    def _normalize_mermaid_observation(self, event: Dict[str, Any]) -> None:
        """Rewrite ``event["observation"]`` in place with cleaned mermaid text.

        Only acts when the observation is a non-empty string that mentions
        ``mermaid_diagram`` and decodes to a JSON object holding a truthy
        ``mermaid_diagram`` entry. The rewritten observation is a JSON string
        containing only that key. Best effort: failures are logged and the
        event is left untouched.
        """
        observation = event.get("observation")
        if not observation or not isinstance(observation, str):
            return
        if "mermaid_diagram" not in observation:
            return

        try:
            tool_data = json.loads(observation)
        except json.JSONDecodeError:
            self.logger.warning("Failed to parse mermaid diagram content")
            return
        if not isinstance(tool_data, dict):
            return

        mermaid_content = tool_data.get("mermaid_diagram", "")
        if not mermaid_content:
            return

        try:
            # Clean and format mermaid content
            cleaned_content = self.clean_mermaid_content(mermaid_content)
            event["observation"] = json.dumps({
                "mermaid_diagram": cleaned_content
            })
        except Exception as e:
            self.logger.error(f"Error processing observation: {str(e)}")

    def clean_mermaid_content(self, content: str) -> str:
        """Clean and format mermaid diagram content.

        Unwraps a ``{"mermaid_diagram": ...}`` JSON wrapper (before and after
        textual cleanup), strips markdown code fences and a leading
        ``tool response:`` prefix, and collapses whitespace runs into single
        spaces.

        Args:
            content: Raw diagram text, possibly wrapped in JSON or markdown.

        Returns:
            The cleaned text. Non-string input is logged and returned
            unchanged; on an unexpected error the text as processed so far is
            returned.
        """
        if not isinstance(content, str):
            self.logger.error(
                f"Error cleaning mermaid content: expected str, got "
                f"{type(content).__name__}"
            )
            return content

        try:
            # If content is a JSON string, parse it. A failed decode here must
            # not abort the cleanup: the steps below may repair the wrapper.
            if content.strip().startswith('{'):
                content = self._unwrap_mermaid_json(content)

            # Remove markdown code blocks
            content = re.sub(r'```mermaid\s*|\s*```', '', content)

            # Remove "tool response:" and any JSON wrapper
            content = re.sub(r'tool response:.*?{', '{', content)
            content = re.sub(r'}\s*\.$', '}', content)

            # If still JSON, extract mermaid content
            if content.strip().startswith('{'):
                content = self._unwrap_mermaid_json(content)

            # Final cleanup
            # NOTE(review): this also flattens newlines into spaces; confirm
            # the consumer accepts single-line mermaid source.
            content = re.sub(r'\s+', ' ', content.strip())
            return content
        except Exception as e:
            self.logger.error(f"Error cleaning mermaid content: {e}")
            return content

    @staticmethod
    def _unwrap_mermaid_json(text: str) -> str:
        """Return the string under ``mermaid_diagram`` if *text* wraps one.

        *text* is returned unchanged when it is not valid JSON, is not a JSON
        object, lacks the key, or holds a non-string value under it.
        """
        try:
            decoded = json.loads(text)
        except json.JSONDecodeError:
            return text
        if isinstance(decoded, dict):
            inner = decoded.get("mermaid_diagram")
            if isinstance(inner, str):
                return inner
        return text