Spaces:
Running
Running
import os | |
import json | |
import requests | |
import logging | |
from typing import Dict, List, Optional, Any, Union | |
# Configure root logging once at import time: records at INFO and above,
# each prefixed with timestamp, logger name and level.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-scoped logger shared by every function and class in this file.
logger = logging.getLogger(__name__)
class MCPClient:
    """
    Client for interacting with MCP (Model Context Protocol) servers.

    Implements a simplified HTTP subset of the protocol, sufficient for TTS
    and other basic tools. It uses these endpoints:

    - ``POST {server_url}/connect``
    - ``GET {server_url}/tools/list``
    - ``POST {server_url}/tools/call``
    - ``POST {server_url}/disconnect``

    The session id returned by ``/connect`` is sent on later requests in the
    ``X-MCP-Session`` header.

    Failures never raise out of the public methods. They are logged and
    reported as ``False`` / ``[]`` / ``{"error": ...}`` instead.

    The client can be used as a context manager; ``close()`` runs on exit.
    """

    def __init__(
        self,
        server_url: str,
        request_timeout: float = 10,
        tool_timeout: float = 30,
        disconnect_timeout: float = 5,
    ):
        """
        Initialize an MCP client for a specific server URL.

        No network traffic happens here. The connection is opened lazily by
        the first ``list_tools``/``call_tool`` call, or explicitly via
        ``connect``.

        Args:
            server_url: Base URL of the MCP server. Trailing slashes are
                stripped so endpoint URLs are never built with ``//``.
            request_timeout: Seconds to wait for ``/connect`` and
                ``/tools/list``.
            tool_timeout: Seconds to wait for ``/tools/call``. This is longer
                by default because tool execution can be slow.
            disconnect_timeout: Seconds to wait for ``/disconnect``.
        """
        self.server_url = self._normalize_url(server_url)
        self.session_id: Optional[str] = None
        self.request_timeout = request_timeout
        self.tool_timeout = tool_timeout
        self.disconnect_timeout = disconnect_timeout
        logger.info(f"Initialized MCP Client for server: {self.server_url}")

    def __enter__(self) -> "MCPClient":
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Always release the session; exceptions from the with-body propagate.
        self.close()

    @staticmethod
    def _normalize_url(server_url: str) -> str:
        """Strip trailing slashes so ``f"{url}/path"`` never yields ``//path``."""
        return server_url.rstrip("/") if isinstance(server_url, str) else server_url

    def _headers(self) -> Dict[str, Optional[str]]:
        """Return the headers that identify the current session to the server."""
        return {"X-MCP-Session": self.session_id}

    def _ensure_connected(self) -> bool:
        """Connect if no session id is held yet; return True when usable."""
        return bool(self.session_id) or self.connect()

    def connect(self) -> bool:
        """
        Establish a connection with the MCP server.

        If the server answers 200 but omits ``session_id``, this still returns
        True. ``session_id`` then stays None, so the next
        ``list_tools``/``call_tool`` call will try to connect again.

        Returns:
            bool: True if the server accepted the connection, False otherwise.
        """
        try:
            # A real MCP implementation would use the MCP initialization
            # handshake; this is a simplified version for demonstration purposes.
            response = requests.post(
                f"{self.server_url}/connect",
                json={"client": "Serverless-TextGen-Hub", "version": "1.0.0"},
                timeout=self.request_timeout,
            )
            if response.status_code != 200:
                logger.error(f"Failed to connect to MCP server: {response.status_code} - {response.text}")
                return False
            result = response.json()
            if not isinstance(result, dict):
                logger.error(
                    f"Error connecting to MCP server: expected a JSON object, got {type(result).__name__}"
                )
                return False
            self.session_id = result.get("session_id")
            logger.info(f"Connected to MCP server with session ID: {self.session_id}")
            return True
        # Best-effort boundary: network errors, timeouts and invalid JSON are
        # logged and reported as False rather than raised.
        except Exception as e:
            logger.error(f"Error connecting to MCP server: {e}")
            return False

    def list_tools(self) -> List[Dict]:
        """
        List available tools from the MCP server, connecting first if needed.

        Returns:
            List[Dict]: Tool definitions from the ``"tools"`` field of the
            server response. Returns an empty list on any failure or if the
            response is not shaped like ``{"tools": [...]}``.
        """
        if not self._ensure_connected():
            return []
        try:
            # A real MCP implementation would use the tools/list method.
            response = requests.get(
                f"{self.server_url}/tools/list",
                headers=self._headers(),
                timeout=self.request_timeout,
            )
            if response.status_code != 200:
                logger.error(f"Failed to list tools: {response.status_code} - {response.text}")
                return []
            result = response.json()
            tools = result.get("tools", []) if isinstance(result, dict) else None
            if not isinstance(tools, list):
                logger.error("Failed to list tools: unexpected response format from MCP server")
                return []
            logger.info(f"Retrieved {len(tools)} tools from MCP server")
            return tools
        except Exception as e:
            logger.error(f"Error listing tools: {e}")
            return []

    def call_tool(self, tool_name: str, args: Dict) -> Dict:
        """
        Call a tool on the MCP server, connecting first if needed.

        Args:
            tool_name: Name of the tool to call.
            args: Arguments to pass to the tool. They are sent as the
                ``"arguments"`` field of the request body.

        Returns:
            Dict: The decoded JSON object returned by the server. On failure
            it returns ``{"error": <message>}``. That includes a non-object
            response, so callers can always treat the result as a dict.
        """
        if not self._ensure_connected():
            return {"error": "Not connected to MCP server"}
        try:
            # A real MCP implementation would use the tools/call method.
            response = requests.post(
                f"{self.server_url}/tools/call",
                headers=self._headers(),
                json={"name": tool_name, "arguments": args},
                timeout=self.tool_timeout,
            )
            if response.status_code != 200:
                error_msg = f"Failed to call tool {tool_name}: {response.status_code} - {response.text}"
                logger.error(error_msg)
                return {"error": error_msg}
            result = response.json()
            if not isinstance(result, dict):
                error_msg = (
                    f"Failed to call tool {tool_name}: expected a JSON object, "
                    f"got {type(result).__name__}"
                )
                logger.error(error_msg)
                return {"error": error_msg}
            logger.info(f"Successfully called tool {tool_name}")
            return result
        except Exception as e:
            error_msg = f"Error calling tool {tool_name}: {e}"
            logger.error(error_msg)
            return {"error": error_msg}

    def close(self) -> None:
        """
        Ask the server to end the current session and forget the session id.

        Does nothing when no session is held, so it is safe to call
        repeatedly. The session id is cleared even if the disconnect request
        fails.
        """
        if not self.session_id:
            return
        try:
            # A real MCP implementation would use the shutdown method.
            response = requests.post(
                f"{self.server_url}/disconnect",
                headers=self._headers(),
                timeout=self.disconnect_timeout,
            )
            if response.status_code == 200:
                logger.info("Disconnected from MCP server")
            else:
                logger.warning(f"MCP server returned {response.status_code} on disconnect")
        except Exception as e:
            logger.error(f"Error disconnecting from MCP server: {e}")
        finally:
            self.session_id = None
def get_mcp_servers(env_var: str = "MCP_CONFIG") -> Dict[str, Dict[str, str]]:
    """
    Load the MCP server configuration from an environment variable.

    The variable must contain a JSON object that maps server names to
    server-configuration objects, e.g. ``{"tts": {"url": "https://..."}}``.
    Entries whose value is not a JSON object are skipped with a warning, so
    every returned value supports ``.get(...)``.

    Args:
        env_var: Name of the environment variable to read.

    Returns:
        Dict[str, Dict[str, str]]: Map of server names to server
        configurations. Empty if the variable is unset or empty, is not
        valid JSON, or is not a JSON object.
    """
    mcp_config = os.getenv(env_var)
    if not mcp_config:
        logger.warning("No MCP configuration found")
        return {}
    try:
        servers = json.loads(mcp_config)
    # json.JSONDecodeError subclasses ValueError; RecursionError covers
    # pathologically nested input. Either way the config is unusable.
    except (ValueError, RecursionError) as e:
        logger.error(f"Error loading MCP configuration: {e}")
        return {}
    if not isinstance(servers, dict):
        logger.error(
            f"Error loading MCP configuration: expected a JSON object, got {type(servers).__name__}"
        )
        return {}
    valid_servers: Dict[str, Dict[str, str]] = {}
    for name, config in servers.items():
        if isinstance(config, dict):
            valid_servers[name] = config
        else:
            logger.warning(f"Skipping MCP server {name!r}: configuration is not a JSON object")
    logger.info(f"Loaded {len(valid_servers)} MCP servers from configuration")
    return valid_servers
def _find_tts_tool(tools: List[Dict]) -> Optional[Dict]:
    """
    Return the first tool whose name looks like a text-to-speech tool.

    A tool matches when its lower-cased name contains one of the TTS keywords.
    Names containing "to_text" (e.g. "speech_to_text") are rejected. Otherwise
    a speech-recognition tool could be picked just because its name contains
    "speech". Entries that are not dicts, or whose "name" is not a string,
    are ignored instead of raising.
    """
    keywords = ("text_to_audio", "tts", "text_to_speech", "speech")
    for tool in tools:
        if not isinstance(tool, dict):
            continue
        name = tool.get("name")
        if not isinstance(name, str):
            continue
        lowered = name.lower()
        if "to_text" in lowered:
            continue
        if any(keyword in lowered for keyword in keywords):
            return tool
    return None


def _audio_result_to_data_url(result: Dict) -> Optional[str]:
    """
    Extract audio from a tool-call result and return it as a data URL.

    The audio is taken from the first truthy value among the "audio",
    "content" and "result" keys. Two forms are accepted:

    - A string. It is returned unchanged if it is already a ``data:audio``
      URL; otherwise it is assumed to be base64-encoded WAV.
    - An MCP-style content list. The first item of the form
      ``{"type": "audio", "data": <base64>, "mimeType": ...}`` is used, and
      its mimeType (default ``audio/wav``) becomes the data-URL media type.

    Returns None (after logging) for anything else, including an empty string.
    """
    audio_data = result.get("audio") or result.get("content") or result.get("result")
    if isinstance(audio_data, str):
        if not audio_data:
            logger.error("TTS result contained no audio data")
            return None
        if audio_data.startswith("data:audio"):
            # Already a data URL
            return audio_data
        return f"data:audio/wav;base64,{audio_data}"
    if isinstance(audio_data, list):
        for item in audio_data:
            if (
                isinstance(item, dict)
                and item.get("type") == "audio"
                and isinstance(item.get("data"), str)
                and item["data"]
            ):
                mime_type = item.get("mimeType") or "audio/wav"
                return f"data:{mime_type};base64,{item['data']}"
    logger.error(f"Unexpected TTS result format: {type(audio_data)}")
    return None


def text_to_speech(text: str, server_name: Optional[str] = None, speed: float = 1.0) -> Optional[str]:
    """
    Convert text to speech using an MCP TTS server.

    The server's URL is looked up in the MCP configuration. The first
    TTS-looking tool the server advertises is then called with
    ``{"text": text, "speed": speed}``. The client session is always closed
    afterwards.

    Args:
        text: The text to convert to speech.
        server_name: Name of the configured MCP server to use for TTS.
        speed: Value passed as the tool's "speed" argument (default 1.0).

    Returns:
        Optional[str]: Data URL containing the audio, or None if the server
        is not configured, no TTS tool is found, or conversion failed.
    """
    servers = get_mcp_servers()
    if not server_name or server_name not in servers:
        logger.warning(f"TTS server {server_name} not configured")
        return None
    server_config = servers[server_name]
    # Guard against hand-built configs whose entry is not an object.
    server_url = server_config.get("url") if isinstance(server_config, dict) else None
    if not server_url:
        logger.warning(f"No URL found for TTS server {server_name}")
        return None
    client = MCPClient(server_url)
    try:
        tts_tool = _find_tts_tool(client.list_tools())
        if not tts_tool:
            logger.warning(f"No TTS tool found on server {server_name}")
            return None
        result = client.call_tool(tts_tool["name"], {"text": text, "speed": speed})
        if not isinstance(result, dict):
            logger.error(f"Unexpected TTS result format: {type(result)}")
            return None
        if "error" in result:
            logger.error(f"TTS error: {result['error']}")
            return None
        return _audio_result_to_data_url(result)
    finally:
        client.close()