File size: 7,854 Bytes
c5a20a4
ea82e64
e45083a
109f11f
 
038f313
109f11f
 
 
75d7afe
109f11f
 
 
 
 
75bf974
109f11f
 
 
e45083a
109f11f
 
 
 
 
 
e45083a
109f11f
 
 
 
 
 
 
e45083a
109f11f
 
 
 
 
 
 
e45083a
109f11f
 
 
 
 
 
 
 
e45083a
109f11f
e45083a
 
109f11f
 
 
e45083a
109f11f
 
 
 
 
 
e45083a
 
109f11f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e45083a
109f11f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e45083a
 
109f11f
 
e45083a
109f11f
 
 
 
 
 
 
e45083a
109f11f
 
 
e45083a
109f11f
 
 
 
e45083a
 
109f11f
e45083a
 
109f11f
 
 
 
 
 
 
 
e45083a
109f11f
 
e45083a
1cee504
109f11f
 
 
e45083a
109f11f
 
 
e45083a
109f11f
 
 
 
e45083a
109f11f
 
 
e45083a
109f11f
 
 
 
75bf974
109f11f
8f939dc
109f11f
 
 
 
 
 
 
 
 
 
 
e45083a
 
109f11f
 
 
e45083a
109f11f
 
e45083a
109f11f
 
 
e45083a
109f11f
 
e45083a
109f11f
 
 
 
 
 
e45083a
109f11f
 
11de92c
109f11f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
import os
import json
import requests
import logging
from typing import Dict, List, Optional, Any, Union

# Configure logging at import time: INFO level, with each record formatted as
# "<timestamp> - <logger name> - <level> - <message>".
# NOTE(review): basicConfig() acts on the root logger, so importing this module
# may affect logging for the whole process — confirm that is intended.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Module-level logger used by everything defined in this file.
logger = logging.getLogger(__name__)

class MCPClient:
    """
    Client for interacting with MCP (Model Context Protocol) servers.
    Implements a subset of the MCP protocol sufficient for TTS and other basic tools.

    This is a simplified HTTP/JSON stand-in rather than the real MCP wire
    protocol. It talks to four endpoints below ``server_url``: ``/connect``,
    ``/tools/list``, ``/tools/call`` and ``/disconnect``. The session ID
    returned by ``/connect`` is sent back in the ``X-MCP-Session`` header.

    All network methods are best-effort: failures are logged and reported via
    the return value instead of being raised.

    The client may be used as a context manager, in which case ``close()`` is
    called automatically on exit::

        with MCPClient(url) as client:
            tools = client.list_tools()
    """

    def __init__(self, server_url: str):
        """
        Initialize an MCP client for a specific server URL

        No network traffic happens here; the connection is established lazily
        by ``connect()`` or by the first ``list_tools()`` / ``call_tool()``.

        Args:
            server_url: The URL of the MCP server to connect to
        """
        self.server_url = server_url
        self.session_id: Optional[str] = None
        logger.info("Initialized MCP Client for server: %s", server_url)

    def __enter__(self) -> "MCPClient":
        """Return the client itself; connecting stays lazy."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Close the session on leaving the ``with`` block; never suppresses exceptions."""
        self.close()

    def _endpoint(self, path: str) -> str:
        """
        Build the full URL for an endpoint path.

        Joins with exactly one slash so that a ``server_url`` configured with a
        trailing slash does not yield URLs such as ``http://host//connect``.
        """
        return f"{self.server_url.rstrip('/')}/{path.lstrip('/')}"

    def _session_headers(self) -> Dict[str, str]:
        """
        Build the session header for authenticated requests.

        ``requests`` requires header values to be str/bytes, so a session ID
        that the server sent as a non-string JSON value (e.g. a number) is
        coerced to ``str``. With no session ID the header is omitted.
        """
        if self.session_id is None:
            return {}
        return {"X-MCP-Session": str(self.session_id)}

    def connect(self) -> bool:
        """
        Establish connection with the MCP server

        On HTTP 200 the ``session_id`` field of the JSON response is stored on
        the client.

        Returns:
            bool: True if connection was successful, False otherwise
        """
        try:
            # For a real MCP implementation, this would use the MCP initialization protocol
            # This is a simplified version for demonstration purposes
            response = requests.post(
                self._endpoint("connect"),
                json={"client": "Serverless-TextGen-Hub", "version": "1.0.0"},
                timeout=10,
            )

            if response.status_code != 200:
                logger.error(
                    "Failed to connect to MCP server: %s - %s",
                    response.status_code,
                    response.text,
                )
                return False

            result = response.json()
            self.session_id = result.get("session_id")
            logger.info("Connected to MCP server with session ID: %s", self.session_id)
            return True
        except Exception as e:
            # Deliberately broad: transport errors, bad JSON and unexpected
            # payload shapes are all reported to the caller as "not connected".
            logger.error("Error connecting to MCP server: %s", e)
            return False

    def list_tools(self) -> List[Dict]:
        """
        List available tools from the MCP server

        Connects first if there is no session yet.

        Returns:
            List[Dict]: List of tool definitions from the server, or an empty
            list if connecting or the request fails, or if the ``tools``
            field of the response is not a list
        """
        if not self.session_id and not self.connect():
            return []

        try:
            # In a real MCP implementation, this would use the tools/list method
            response = requests.get(
                self._endpoint("tools/list"),
                headers=self._session_headers(),
                timeout=10,
            )

            if response.status_code != 200:
                logger.error(
                    "Failed to list tools: %s - %s",
                    response.status_code,
                    response.text,
                )
                return []

            tools = response.json().get("tools", [])
            if not isinstance(tools, list):
                # Keep the declared return type even if the server misbehaves.
                logger.error(
                    "Error listing tools: expected a list of tools, got %s",
                    type(tools).__name__,
                )
                return []

            logger.info("Retrieved %d tools from MCP server", len(tools))
            return tools
        except Exception as e:
            # Deliberately broad: any failure degrades to "no tools".
            logger.error("Error listing tools: %s", e)
            return []

    def call_tool(self, tool_name: str, args: Dict) -> Dict:
        """
        Call a tool on the MCP server

        Connects first if there is no session yet.

        Args:
            tool_name: Name of the tool to call
            args: Arguments to pass to the tool

        Returns:
            Dict: The decoded JSON response body on HTTP 200; otherwise a dict
            of the form ``{"error": "<message>"}``
        """
        if not self.session_id and not self.connect():
            return {"error": "Not connected to MCP server"}

        try:
            # In a real MCP implementation, this would use the tools/call method
            response = requests.post(
                self._endpoint("tools/call"),
                headers=self._session_headers(),
                json={"name": tool_name, "arguments": args},
                timeout=30,  # Longer timeout for tool calls
            )

            if response.status_code != 200:
                error_msg = f"Failed to call tool {tool_name}: {response.status_code} - {response.text}"
                logger.error(error_msg)
                return {"error": error_msg}

            result = response.json()
            logger.info("Successfully called tool %s", tool_name)
            return result
        except Exception as e:
            # Deliberately broad: every failure is reported through the
            # returned error dict rather than raised.
            error_msg = f"Error calling tool {tool_name}: {e}"
            logger.error(error_msg)
            return {"error": error_msg}

    def close(self):
        """
        Clean up the client connection

        Does nothing when there is no session. Otherwise posts to the
        disconnect endpoint (errors are logged, not raised) and always clears
        the stored session ID.
        """
        if not self.session_id:
            return

        try:
            # For a real MCP implementation, this would use the shutdown method
            requests.post(
                self._endpoint("disconnect"),
                headers=self._session_headers(),
                timeout=5,
            )
            logger.info("Disconnected from MCP server")
        except Exception as e:
            logger.error("Error disconnecting from MCP server: %s", e)
        finally:
            self.session_id = None


def get_mcp_servers() -> Dict[str, Dict[str, str]]:
    """
    Load MCP server configuration from environment variable

    Reads the ``MCP_CONFIG`` environment variable and parses it as JSON. The
    top-level value must be a JSON object; callers look servers up by key.

    Returns:
        Dict[str, Dict[str, str]]: Map of server names to server configurations.
        An empty dict is returned (and the reason logged) when the variable is
        unset or empty, is not valid JSON, or is not a JSON object.
    """
    mcp_config = os.getenv("MCP_CONFIG")
    if not mcp_config:
        logger.warning("No MCP configuration found")
        return {}

    try:
        servers = json.loads(mcp_config)
    except (ValueError, RecursionError) as e:
        # json.JSONDecodeError is a ValueError; RecursionError covers
        # pathologically deep nesting.
        logger.error(f"Error loading MCP configuration: {e}")
        return {}

    if not isinstance(servers, dict):
        # Valid JSON of the wrong shape (list, string, number, ...) would break
        # callers that index the result by server name.
        logger.error(
            "Error loading MCP configuration: expected a JSON object, got %s",
            type(servers).__name__,
        )
        return {}

    logger.info(f"Loaded {len(servers)} MCP servers from configuration")
    return servers


def text_to_speech(text: str, server_name: Optional[str] = None) -> Optional[str]:
    """
    Convert text to speech using an MCP TTS server

    Looks up ``server_name`` in the configured MCP servers, picks the first
    tool whose name contains a known TTS keyword, and calls it with
    ``{"text": text, "speed": 1.0}``. The audio is read from the first truthy
    value among the ``audio``, ``content`` and ``result`` fields of the
    response.

    Args:
        text: The text to convert to speech
        server_name: Name of the MCP server to use for TTS

    Returns:
        Optional[str]: Data URL containing the audio, or None if conversion failed
        (server not configured, no URL, no TTS tool, tool error, or an
        unexpected result format). A string that is not already a
        ``data:audio`` URL is assumed to be base64-encoded WAV.
    """
    servers = get_mcp_servers()

    if not server_name or server_name not in servers:
        logger.warning(f"TTS server {server_name} not configured")
        return None

    # A malformed config entry (not an object) is treated like a missing URL.
    server_config = servers[server_name]
    server_url = server_config.get("url") if isinstance(server_config, dict) else None
    if not server_url:
        logger.warning(f"No URL found for TTS server {server_name}")
        return None

    # NOTE(review): matching is by substring, so "speech" would also match a
    # tool such as "speech_to_text" — confirm this is acceptable.
    tts_keywords = ("text_to_audio", "tts", "text_to_speech", "speech")

    client = MCPClient(server_url)

    try:
        # List available tools to find the TTS tool
        tools = client.list_tools()

        # Find a TTS tool - look for common TTS tool names. Entries that are
        # not dicts or lack a string "name" are skipped instead of raising.
        tts_tool_name = next(
            (
                tool["name"]
                for tool in tools
                if isinstance(tool, dict)
                and isinstance(tool.get("name"), str)
                and any(keyword in tool["name"].lower() for keyword in tts_keywords)
            ),
            None,
        )

        if not tts_tool_name:
            logger.warning(f"No TTS tool found on server {server_name}")
            return None

        # Call the TTS tool
        result = client.call_tool(tts_tool_name, {"text": text, "speed": 1.0})

        # The server's JSON body is returned as decoded, so it may not be an object.
        if not isinstance(result, dict):
            logger.error(f"Unexpected TTS result format: {type(result)}")
            return None

        if "error" in result:
            logger.error(f"TTS error: {result['error']}")
            return None

        # Process the result - usually a base64 encoded WAV
        audio_data = result.get("audio") or result.get("content") or result.get("result")

        if not isinstance(audio_data, str):
            logger.error(f"Unexpected TTS result format: {type(audio_data)}")
            return None

        if audio_data.startswith("data:audio"):
            # Already a data URL
            return audio_data

        # Assume it's base64 encoded
        return f"data:audio/wav;base64,{audio_data}"

    finally:
        # Always release the server-side session, even on early return or error.
        client.close()