diff --git a/pkg/api/http/service/mcp.py b/pkg/api/http/service/mcp.py index 4250d756..ea120718 100644 --- a/pkg/api/http/service/mcp.py +++ b/pkg/api/http/service/mcp.py @@ -38,10 +38,7 @@ class RuntimeMCPServer: } self.session = RuntimeMCPSession( - self.mcp_server_entity.name, - mixed_config, - self.mcp_server_entity.enable, - self.ap + self.mcp_server_entity.name, mixed_config, self.mcp_server_entity.enable, self.ap ) await self.session.initialize() await self.session.start() @@ -59,12 +56,7 @@ class RuntimeMCPServer: **self.mcp_server_entity.extra_args, } - test_session = RuntimeMCPSession( - self.mcp_server_entity.name, - mixed_config, - enable=True, - ap=self.ap - ) + test_session = RuntimeMCPSession(self.mcp_server_entity.name, mixed_config, enable=True, ap=self.ap) await test_session.start() # 获取工具列表作为测试 @@ -104,68 +96,12 @@ class RuntimeMCPServer: await self.session.shutdown() - class MCPService: ap: app.Application def __init__(self, ap: app.Application) -> None: self.ap = ap - def _convert_server_entity_to_config( - self, server_entity: persistence_mcp.MCPServer | sqlalchemy.Row[persistence_mcp.MCPServer] - ) -> dict: - """将数据库实体转换为 loader 需要的配置字典 - - Args: - server_entity: 数据库查询返回的服务器实体或 Row 对象 - - Returns: - 包含服务器配置的字典 - """ - if isinstance(server_entity, sqlalchemy.Row): - server = persistence_mcp.MCPServer(**server_entity._mapping) - else: - server = server_entity - - return { - 'name': server.name, - 'mode': server.mode, - 'enable': server.enable, - 'extra_args': server.extra_args, - } - - async def initialize(self) -> None: - """初始化 MCP Service,从数据库加载所有 MCP 服务器到运行时""" - self.ap.logger.info('Initializing MCP Service and loading servers from database...') - - if not self.ap.tool_mgr or not self.ap.tool_mgr.mcp_tool_loader: - self.ap.logger.warning('MCP tool loader not available, skipping MCP servers initialization') - return - - try: - result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_mcp.MCPServer)) - servers = result.all() - - loaded_count = 0 - failed_count = 0 - - for server in servers: - try: - # 将数据库实体转换为配置字典后传递给 loader - server_config = self._convert_server_entity_to_config(server) - await self.ap.tool_mgr.mcp_tool_loader.load_mcp_server(server_config) - loaded_count += 1 - self.ap.logger.debug(f'Loaded MCP server: {server_config["name"]}') - except Exception as e: - failed_count += 1 - - server_name = getattr(server, 'name', 'unknown') - self.ap.logger.error(f'Failed to load MCP server {server_name}: {e}\n{traceback.format_exc()}') - - self.ap.logger.info(f'MCP Service initialization complete. Loaded: {loaded_count}, Failed: {failed_count}') - except Exception as e: - self.ap.logger.error(f'Failed to initialize MCP Service: {e}\n{traceback.format_exc()}') - async def get_mcp_servers(self) -> list[dict]: result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_mcp.MCPServer)) @@ -180,9 +116,10 @@ class MCPService: sqlalchemy.select(persistence_mcp.MCPServer).where(persistence_mcp.MCPServer.uuid == server_data['uuid']) ) server_entity = result.first() - if server_entity and self.ap.tool_mgr.mcp_tool_loader: - server_config = self._convert_server_entity_to_config(server_entity) - await self.ap.tool_mgr.mcp_tool_loader.load_mcp_server(server_config) + if server_entity: + server_config = self.ap.persistence_mgr.serialize_model(persistence_mcp.MCPServer, server_entity) + if self.ap.tool_mgr.mcp_tool_loader: + await self.ap.tool_mgr.mcp_tool_loader.load_mcp_server(server_config) return server_data['uuid'] @@ -205,14 +142,12 @@ class MCPService: return self.ap.persistence_mgr.serialize_model(persistence_mcp.MCPServer, server) async def update_mcp_server(self, server_uuid: str, server_data: dict) -> None: - result = await self.ap.persistence_mgr.execute_async( sqlalchemy.select(persistence_mcp.MCPServer).where(persistence_mcp.MCPServer.uuid == server_uuid) ) old_server = result.first() old_server_name = old_server.name if old_server else None - await self.ap.persistence_mgr.execute_async( sqlalchemy.update(persistence_mcp.MCPServer) .where(persistence_mcp.MCPServer.uuid == server_uuid) @@ -220,41 +155,36 @@ class MCPService: ) if self.ap.tool_mgr.mcp_tool_loader: - if old_server_name and old_server_name in self.ap.tool_mgr.mcp_tool_loader.sessions: await self.ap.tool_mgr.mcp_tool_loader.remove_mcp_server(old_server_name) - result = await self.ap.persistence_mgr.execute_async( sqlalchemy.select(persistence_mcp.MCPServer).where(persistence_mcp.MCPServer.uuid == server_uuid) ) updated_server = result.first() if updated_server: # convert entity to config dict - server_config = self._convert_server_entity_to_config(updated_server) + server_config = self.ap.persistence_mgr.serialize_model(persistence_mcp.MCPServer, updated_server) await self.ap.tool_mgr.mcp_tool_loader.load_mcp_server(server_config) async def delete_mcp_server(self, server_uuid: str) -> None: - result = await self.ap.persistence_mgr.execute_async( sqlalchemy.select(persistence_mcp.MCPServer).where(persistence_mcp.MCPServer.uuid == server_uuid) ) server = result.first() server_name = server.name if server else None - await self.ap.persistence_mgr.execute_async( sqlalchemy.delete(persistence_mcp.MCPServer).where(persistence_mcp.MCPServer.uuid == server_uuid) ) - if server_name and self.ap.tool_mgr.mcp_tool_loader: if server_name in self.ap.tool_mgr.mcp_tool_loader.sessions: await self.ap.tool_mgr.mcp_tool_loader.remove_mcp_server(server_name) async def test_mcp_server(self, server_uuid: str) -> str: """测试 MCP 服务器连接并返回任务 ID""" - + result = await self.ap.persistence_mgr.execute_async( sqlalchemy.select(persistence_mcp.MCPServer).where(persistence_mcp.MCPServer.uuid == server_uuid) ) @@ -262,7 +192,6 @@ class MCPService: if server is None: raise ValueError(f'Server not found: {server_uuid}') - if isinstance(server, sqlalchemy.Row): server_entity = persistence_mcp.MCPServer(**server._mapping) else: @@ -270,5 +199,4 @@ class MCPService: runtime_server = RuntimeMCPServer(ap=self.ap, mcp_server_entity=server_entity) - return await runtime_server.test_connection() diff --git a/pkg/core/stages/build_app.py b/pkg/core/stages/build_app.py index f2991315..8df32755 100644 --- a/pkg/core/stages/build_app.py +++ b/pkg/core/stages/build_app.py @@ -129,7 +129,6 @@ class BuildAppStage(stage.BootingStage): mcp_service_inst = mcp_service.MCPService(ap) ap.mcp_service = mcp_service_inst - await mcp_service_inst.initialize() ctrl = controller.Controller(ap) ap.ctrl = ctrl diff --git a/pkg/provider/tools/loaders/mcp.py b/pkg/provider/tools/loaders/mcp.py index 0751ca42..fc9db294 100644 --- a/pkg/provider/tools/loaders/mcp.py +++ b/pkg/provider/tools/loaders/mcp.py @@ -3,6 +3,8 @@ from __future__ import annotations import typing from contextlib import AsyncExitStack import traceback +import sqlalchemy +import asyncio from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client @@ -11,6 +13,7 @@ from mcp.client.sse import sse_client from .. import loader from ....core import app import langbot_plugin.api.entities.builtin.resource.tool as resource_tool +from ....entity.persistence import mcp as persistence_mcp class RuntimeMCPSession: @@ -77,8 +80,6 @@ class RuntimeMCPSession: if not self.enable: return - self.ap.logger.debug(f'初始化 MCP 会话: {self.server_name} {self.server_config}') - if self.server_config['mode'] == 'stdio': await self._init_stdio_python_server() elif self.server_config['mode'] == 'sse': @@ -110,6 +111,9 @@ class RuntimeMCPSession: ) ) + def get_tools(self) -> list[resource_tool.LLMTool]: + return self.functions + async def shutdown(self): """关闭会话并清理资源""" try: @@ -128,20 +132,59 @@ class MCPLoader(loader.ToolLoader): 在此加载器中管理所有与 MCP Server 的连接。 """ - sessions: dict[str, RuntimeMCPSession] = {} + sessions: dict[str, RuntimeMCPSession] - _last_listed_functions: list[resource_tool.LLMTool] = [] + _last_listed_functions: list[resource_tool.LLMTool] + + _startup_load_tasks: list[asyncio.Task] def __init__(self, ap: app.Application): super().__init__(ap) self.sessions = {} self._last_listed_functions = [] + self._startup_load_tasks = [] async def initialize(self): - pass + await self.load_mcp_servers_from_db() - async def init_runtime_mcp_session(self, server_config: dict): - """从服务器配置创建运行时会话 + async def load_mcp_servers_from_db(self): + self.ap.logger.info('Loading MCP servers from db...') + + self.sessions = {} + + result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_mcp.MCPServer)) + servers = result.all() + + for server in servers: + server_config = self.ap.persistence_mgr.serialize_model(persistence_mcp.MCPServer, server) + + async def load_mcp_server_task(): + self.ap.logger.debug(f'Loading MCP server {server_config}') + try: + session = await self.load_mcp_server(server_config) + self.sessions[server_config['name']] = session + except Exception as e: + self.ap.logger.error( + f'Failed to load MCP server from db: {server_config["name"]}({server_config["uuid"]}): {e}\n{traceback.format_exc()}' + ) + return + + self.ap.logger.debug(f'Starting MCP server {server_config["name"]}({server_config["uuid"]})') + try: + await session.start() + except Exception as e: + self.ap.logger.error( + f'Failed to start MCP server {server_config["name"]}({server_config["uuid"]}): {e}\n{traceback.format_exc()}' + ) + return + + self.ap.logger.debug(f'Started MCP server {server_config["name"]}({server_config["uuid"]})') + + task = asyncio.create_task(load_mcp_server_task()) + self._startup_load_tasks.append(task) + + async def load_mcp_server(self, server_config: dict) -> RuntimeMCPSession: + """加载 MCP 服务器到运行时 Args: server_config: 服务器配置字典,必须包含: @@ -150,6 +193,7 @@ class MCPLoader(loader.ToolLoader): - enable: 是否启用 - extra_args: 额外的配置参数 (可选) """ + name = server_config['name'] mode = server_config['mode'] enable = server_config['enable'] @@ -167,25 +211,11 @@ class MCPLoader(loader.ToolLoader): return session - async def load_mcp_server(self, server_config: dict): - """加载 MCP 服务器到运行时 - - Args: - server_config: 服务器配置字典,必须包含: - - name: 服务器名称 - - mode: 连接模式 (stdio/sse) - - enable: 是否启用 - - extra_args: 额外的配置参数 (可选) - """ - session = await self.init_runtime_mcp_session(server_config) - await session.start() - self.sessions[server_config['name']] = session - async def get_tools(self) -> list[resource_tool.LLMTool]: all_functions = [] for session in self.sessions.values(): - all_functions.extend(session.functions) + all_functions.extend(session.get_tools()) self._last_listed_functions = all_functions @@ -194,7 +224,7 @@ class MCPLoader(loader.ToolLoader): async def has_tool(self, name: str) -> bool: """检查工具是否存在""" for session in self.sessions.values(): - for function in session.functions: + for function in session.get_tools(): if function.name == name: return True return False @@ -202,7 +232,7 @@ class MCPLoader(loader.ToolLoader): async def invoke_tool(self, name: str, parameters: dict) -> typing.Any: """执行工具调用""" for session in self.sessions.values(): - for function in session.functions: + for function in session.get_tools(): if function.name == name: self.ap.logger.debug(f'Invoking MCP tool: {name} with parameters: {parameters}') try: @@ -254,7 +284,7 @@ class MCPLoader(loader.ToolLoader): def get_server_tool_count(self, server_name: str) -> int: """获取指定服务器的工具数量""" session = self.get_session(server_name) - return len(session.functions) if session else 0 + return len(session.get_tools()) if session else 0 def get_all_servers_info(self) -> dict[str, dict]: """获取所有服务器的信息""" @@ -264,8 +294,8 @@ class MCPLoader(loader.ToolLoader): 'name': server_name, 'mode': session.server_config.get('mode'), 'enable': session.enable, - 'tools_count': len(session.functions), - 'tool_names': [f.name for f in session.functions], + 'tools_count': len(session.get_tools()), + 'tool_names': [f.name for f in session.get_tools()], } return info