Spaces:
Running
Running
| import os | |
| import json | |
| import requests | |
| import logging | |
| from typing import Dict, List, Optional, Any, Union | |
# Module-wide logging: configure the root logger at import time (INFO level,
# timestamped "time - logger name - level - message" lines).
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Logger named after this module; every class/function in the file logs through it.
logger = logging.getLogger(__name__)
class MCPClient:
    """
    Client for interacting with MCP (Model Context Protocol) servers.

    Implements a simplified, HTTP-based subset of the MCP protocol that is
    sufficient for TTS and other basic tools. Endpoints used, relative to
    ``server_url``:

    * ``POST connect``    -- opens a session; the reply may carry ``session_id``
    * ``GET  tools/list`` -- reply ``{"tools": [...]}`` (a bare list is accepted too)
    * ``POST tools/call`` -- runs a tool with ``{"name": ..., "arguments": ...}``
    * ``POST disconnect`` -- closes the session

    Errors are not raised to the caller: they are logged and reported as
    ``False`` / ``[]`` / ``{"error": ...}`` depending on the method. The client
    can also be used as a context manager whose exit calls :meth:`close`.
    """

    def __init__(self, server_url: str):
        """
        Initialize an MCP client for a specific server URL.

        No network traffic happens here; the session is opened lazily by the
        first ``list_tools``/``call_tool`` call, or by an explicit ``connect``.

        Args:
            server_url: The URL of the MCP server to connect to. A trailing
                slash is tolerated and is not duplicated in endpoint URLs.
        """
        self.server_url = server_url
        # Session token returned by the server's connect endpoint; None while
        # no session is open.
        self.session_id = None
        logger.info(f"Initialized MCP Client for server: {server_url}")

    def __enter__(self) -> "MCPClient":
        """Return the client itself; the session is still opened lazily."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Close the session; exceptions raised in the ``with`` body propagate."""
        self.close()

    def _endpoint(self, path: str) -> str:
        """Join ``server_url`` and ``path`` with exactly one ``/`` between them."""
        return f"{self.server_url.rstrip('/')}/{path.lstrip('/')}"

    def _session_headers(self) -> Dict[str, str]:
        """Return the session header, or no headers when no session is open."""
        # Send the ID as a string: the server's JSON reply may carry e.g. an int.
        return {"X-MCP-Session": str(self.session_id)} if self.session_id else {}

    def _ensure_session(self) -> bool:
        """Connect if no session is open yet; return False if that attempt failed."""
        return bool(self.session_id) or self.connect()

    @staticmethod
    def _as_dict(payload: Any) -> Dict:
        """Return ``payload`` if it is a dict, otherwise wrap it as ``{"result": payload}``."""
        return payload if isinstance(payload, dict) else {"result": payload}

    def connect(self) -> bool:
        """
        Establish a session with the MCP server.

        Returns:
            bool: True if the server answered with HTTP 200 and a JSON object,
            False otherwise (non-200 status, network error, malformed reply).
            A 200 reply without ``session_id`` still counts as success, but a
            warning is logged and the next request will try to connect again.
        """
        try:
            # A real MCP implementation would use the MCP initialization
            # protocol; this is a simplified version for demonstration.
            response = requests.post(
                self._endpoint("connect"),
                json={"client": "Serverless-TextGen-Hub", "version": "1.0.0"},
                timeout=10,
            )
            if response.status_code != 200:
                logger.error(f"Failed to connect to MCP server: {response.status_code} - {response.text}")
                return False
            result = response.json()
            if not isinstance(result, dict):
                logger.error(f"Error connecting to MCP server: unexpected reply type {type(result).__name__}")
                return False
            self.session_id = result.get("session_id")
            if not self.session_id:
                logger.warning("MCP server did not return a session ID")
            logger.info(f"Connected to MCP server with session ID: {self.session_id}")
            return True
        except Exception as e:
            # Best effort: any failure (network error, invalid JSON, ...) is
            # reported as False rather than raised.
            logger.error(f"Error connecting to MCP server: {e}")
            return False

    def list_tools(self) -> List[Dict]:
        """
        List available tools from the MCP server, connecting first if needed.

        Returns:
            List[Dict]: Tool definitions from the server; ``[]`` on any failure.
        """
        if not self._ensure_session():
            return []
        try:
            # In a real MCP implementation, this would use the tools/list method.
            response = requests.get(
                self._endpoint("tools/list"),
                headers=self._session_headers(),
                timeout=10,
            )
            if response.status_code != 200:
                logger.error(f"Failed to list tools: {response.status_code} - {response.text}")
                return []
            payload = response.json()
            # Accept the {"tools": [...]} envelope as well as a bare list; a
            # missing or null "tools" key means "no tools".
            if isinstance(payload, dict):
                tools = payload.get("tools") or []
            else:
                tools = payload
            if not isinstance(tools, list):
                logger.error(f"Error listing tools: unexpected reply type {type(tools).__name__}")
                return []
            logger.info(f"Retrieved {len(tools)} tools from MCP server")
            return tools
        except Exception as e:
            logger.error(f"Error listing tools: {e}")
            return []

    def call_tool(self, tool_name: str, args: Dict) -> Dict:
        """
        Call a tool on the MCP server, connecting first if needed.

        Args:
            tool_name: Name of the tool to call.
            args: Arguments to pass to the tool.

        Returns:
            Dict: The server's JSON reply. A reply that is not a JSON object is
            wrapped as ``{"result": <reply>}`` so callers can always use dict
            access. On failure, ``{"error": <message>}``.
        """
        if not self._ensure_session():
            return {"error": "Not connected to MCP server"}
        try:
            # In a real MCP implementation, this would use the tools/call method.
            response = requests.post(
                self._endpoint("tools/call"),
                headers=self._session_headers(),
                json={"name": tool_name, "arguments": args},
                timeout=30,  # Longer timeout than connect/list: tool calls can be slow.
            )
            if response.status_code != 200:
                error_msg = f"Failed to call tool {tool_name}: {response.status_code} - {response.text}"
                logger.error(error_msg)
                return {"error": error_msg}
            result = self._as_dict(response.json())
            logger.info(f"Successfully called tool {tool_name}")
            return result
        except Exception as e:
            error_msg = f"Error calling tool {tool_name}: {e}"
            logger.error(error_msg)
            return {"error": error_msg}

    def close(self):
        """
        End the session (best effort) and forget it locally.

        Does nothing when no session is open. The local session ID is cleared
        even if the disconnect request fails.
        """
        if not self.session_id:
            return
        try:
            # For a real MCP implementation, this would use the shutdown method.
            response = requests.post(
                self._endpoint("disconnect"),
                headers=self._session_headers(),
                timeout=5,
            )
            if response.status_code == 200:
                logger.info("Disconnected from MCP server")
            else:
                logger.warning(f"MCP server rejected disconnect: {response.status_code} - {response.text}")
        except Exception as e:
            logger.error(f"Error disconnecting from MCP server: {e}")
        finally:
            self.session_id = None
def get_mcp_servers() -> Dict[str, Dict[str, str]]:
    """
    Load the MCP server configuration from the ``MCP_CONFIG`` environment variable.

    ``MCP_CONFIG`` must hold a JSON object mapping server names to server
    configurations, e.g. ``{"tts": {"url": "http://..."}}``.

    Returns:
        Dict[str, Dict[str, str]]: Map of server names to server
        configurations; ``{}`` when the variable is unset or empty, is not
        valid JSON, or does not decode to a JSON object.
    """
    mcp_config = os.getenv("MCP_CONFIG")
    if not mcp_config:
        logger.warning("No MCP configuration found")
        return {}
    try:
        servers = json.loads(mcp_config)
    except ValueError as e:  # json.JSONDecodeError is a ValueError subclass
        logger.error(f"Error loading MCP configuration: {e}")
        return {}
    # Callers index the result by server name, so anything other than a JSON
    # object (list, string, number, null) is unusable: reject it here instead
    # of letting it fail later with a confusing TypeError/AttributeError.
    if not isinstance(servers, dict):
        logger.error(
            f"Error loading MCP configuration: expected a JSON object, got {type(servers).__name__}"
        )
        return {}
    logger.info(f"Loaded {len(servers)} MCP servers from configuration")
    return servers
# Substrings that mark a tool as text-to-speech (matched against the
# lower-cased tool name).
_TTS_TOOL_HINTS = ("text_to_audio", "tts", "text_to_speech", "speech")


def _find_tts_tool(tools: List[Dict]) -> Optional[Dict]:
    """Return the first tool whose name contains one of ``_TTS_TOOL_HINTS``, or None."""
    for tool in tools:
        # Skip malformed entries (non-dicts, missing/non-string names) instead
        # of raising KeyError/AttributeError.
        name = tool.get("name") if isinstance(tool, dict) else None
        if isinstance(name, str) and any(hint in name.lower() for hint in _TTS_TOOL_HINTS):
            return tool
    return None


def _extract_audio_data_url(result: Any) -> Optional[str]:
    """
    Convert a TTS tool result into a ``data:`` URL.

    Accepted shapes:
      * ``{"audio" | "content" | "result": "data:audio/..."}`` -- returned as-is;
      * ``{"audio" | "content" | "result": "<base64>"}`` -- assumed to be WAV;
      * ``{"content": [{"type": "audio", "data": "<base64>", "mimeType": ...}]}``
        (MCP-style content list) -- the first audio item is used.

    Returns None (and logs) for error results and any other shape.
    """
    if not isinstance(result, dict):
        logger.error(f"Unexpected TTS result format: {type(result)}")
        return None
    if result.get("isError"):
        logger.error(f"TTS error: {result.get('content')}")
        return None

    audio_data = result.get("audio") or result.get("content") or result.get("result")
    mime_type = "audio/wav"
    if isinstance(audio_data, list):
        audio_item = next(
            (
                item
                for item in audio_data
                if isinstance(item, dict)
                and item.get("type") == "audio"
                and isinstance(item.get("data"), str)
            ),
            None,
        )
        if audio_item is not None:
            item_mime = audio_item.get("mimeType")
            if isinstance(item_mime, str) and item_mime:
                mime_type = item_mime
            audio_data = audio_item["data"]

    # An empty string would yield a "data:...;base64," URL with no audio.
    if isinstance(audio_data, str) and audio_data:
        if audio_data.startswith("data:audio"):
            return audio_data  # already a data URL
        return f"data:{mime_type};base64,{audio_data}"  # assume bare base64
    logger.error(f"Unexpected TTS result format: {type(audio_data)}")
    return None


def text_to_speech(text: str, server_name: Optional[str] = None) -> Optional[str]:
    """
    Convert text to speech using an MCP TTS server.

    Looks ``server_name`` up in the configured MCP servers, picks the first
    tool whose name looks like a TTS tool, calls it with
    ``{"text": text, "speed": 1.0}`` and turns the reply into a data URL.

    Args:
        text: The text to convert to speech.
        server_name: Name of the configured MCP server to use for TTS. Its
            configuration may be ``{"url": "..."}`` or just the URL string.

    Returns:
        Optional[str]: ``data:audio/...`` URL containing the audio, or None if
        the server is not configured, exposes no TTS tool, or the call failed.
    """
    servers = get_mcp_servers()
    if not server_name or server_name not in servers:
        logger.warning(f"TTS server {server_name} not configured")
        return None

    server_config = servers[server_name]
    # Accept either {"url": "..."} or, as a shorthand, the URL string itself;
    # any other shape means "no URL" rather than an AttributeError.
    if isinstance(server_config, str):
        server_url = server_config
    elif isinstance(server_config, dict):
        server_url = server_config.get("url")
    else:
        server_url = None
    if not server_url:
        logger.warning(f"No URL found for TTS server {server_name}")
        return None

    client = MCPClient(server_url)
    try:
        tts_tool = _find_tts_tool(client.list_tools())
        if tts_tool is None:
            logger.warning(f"No TTS tool found on server {server_name}")
            return None

        result = client.call_tool(tts_tool["name"], {"text": text, "speed": 1.0})
        if isinstance(result, dict) and "error" in result:
            logger.error(f"TTS error: {result['error']}")
            return None
        return _extract_audio_data_url(result)
    finally:
        # Always release the server-side session, even on early return.
        client.close()