mirror of
https://github.com/langbot-app/LangBot.git
synced 2025-11-25 03:15:06 +08:00
perf: make startup async
This commit is contained in:
@@ -38,10 +38,7 @@ class RuntimeMCPServer:
|
||||
}
|
||||
|
||||
self.session = RuntimeMCPSession(
|
||||
self.mcp_server_entity.name,
|
||||
mixed_config,
|
||||
self.mcp_server_entity.enable,
|
||||
self.ap
|
||||
self.mcp_server_entity.name, mixed_config, self.mcp_server_entity.enable, self.ap
|
||||
)
|
||||
await self.session.initialize()
|
||||
await self.session.start()
|
||||
@@ -59,12 +56,7 @@ class RuntimeMCPServer:
|
||||
**self.mcp_server_entity.extra_args,
|
||||
}
|
||||
|
||||
test_session = RuntimeMCPSession(
|
||||
self.mcp_server_entity.name,
|
||||
mixed_config,
|
||||
enable=True,
|
||||
ap=self.ap
|
||||
)
|
||||
test_session = RuntimeMCPSession(self.mcp_server_entity.name, mixed_config, enable=True, ap=self.ap)
|
||||
await test_session.start()
|
||||
|
||||
# 获取工具列表作为测试
|
||||
@@ -104,68 +96,12 @@ class RuntimeMCPServer:
|
||||
await self.session.shutdown()
|
||||
|
||||
|
||||
|
||||
class MCPService:
|
||||
ap: app.Application
|
||||
|
||||
def __init__(self, ap: app.Application) -> None:
|
||||
self.ap = ap
|
||||
|
||||
def _convert_server_entity_to_config(
|
||||
self, server_entity: persistence_mcp.MCPServer | sqlalchemy.Row[persistence_mcp.MCPServer]
|
||||
) -> dict:
|
||||
"""将数据库实体转换为 loader 需要的配置字典
|
||||
|
||||
Args:
|
||||
server_entity: 数据库查询返回的服务器实体或 Row 对象
|
||||
|
||||
Returns:
|
||||
包含服务器配置的字典
|
||||
"""
|
||||
if isinstance(server_entity, sqlalchemy.Row):
|
||||
server = persistence_mcp.MCPServer(**server_entity._mapping)
|
||||
else:
|
||||
server = server_entity
|
||||
|
||||
return {
|
||||
'name': server.name,
|
||||
'mode': server.mode,
|
||||
'enable': server.enable,
|
||||
'extra_args': server.extra_args,
|
||||
}
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""初始化 MCP Service,从数据库加载所有 MCP 服务器到运行时"""
|
||||
self.ap.logger.info('Initializing MCP Service and loading servers from database...')
|
||||
|
||||
if not self.ap.tool_mgr or not self.ap.tool_mgr.mcp_tool_loader:
|
||||
self.ap.logger.warning('MCP tool loader not available, skipping MCP servers initialization')
|
||||
return
|
||||
|
||||
try:
|
||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_mcp.MCPServer))
|
||||
servers = result.all()
|
||||
|
||||
loaded_count = 0
|
||||
failed_count = 0
|
||||
|
||||
for server in servers:
|
||||
try:
|
||||
# 将数据库实体转换为配置字典后传递给 loader
|
||||
server_config = self._convert_server_entity_to_config(server)
|
||||
await self.ap.tool_mgr.mcp_tool_loader.load_mcp_server(server_config)
|
||||
loaded_count += 1
|
||||
self.ap.logger.debug(f'Loaded MCP server: {server_config["name"]}')
|
||||
except Exception as e:
|
||||
failed_count += 1
|
||||
|
||||
server_name = getattr(server, 'name', 'unknown')
|
||||
self.ap.logger.error(f'Failed to load MCP server {server_name}: {e}\n{traceback.format_exc()}')
|
||||
|
||||
self.ap.logger.info(f'MCP Service initialization complete. Loaded: {loaded_count}, Failed: {failed_count}')
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'Failed to initialize MCP Service: {e}\n{traceback.format_exc()}')
|
||||
|
||||
async def get_mcp_servers(self) -> list[dict]:
|
||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_mcp.MCPServer))
|
||||
|
||||
@@ -180,9 +116,10 @@ class MCPService:
|
||||
sqlalchemy.select(persistence_mcp.MCPServer).where(persistence_mcp.MCPServer.uuid == server_data['uuid'])
|
||||
)
|
||||
server_entity = result.first()
|
||||
if server_entity and self.ap.tool_mgr.mcp_tool_loader:
|
||||
server_config = self._convert_server_entity_to_config(server_entity)
|
||||
await self.ap.tool_mgr.mcp_tool_loader.load_mcp_server(server_config)
|
||||
if server_entity:
|
||||
server_config = self.ap.persistence_mgr.serialize_model(persistence_mcp.MCPServer, server_entity)
|
||||
if self.ap.tool_mgr.mcp_tool_loader:
|
||||
await self.ap.tool_mgr.mcp_tool_loader.load_mcp_server(server_config)
|
||||
|
||||
return server_data['uuid']
|
||||
|
||||
@@ -205,14 +142,12 @@ class MCPService:
|
||||
return self.ap.persistence_mgr.serialize_model(persistence_mcp.MCPServer, server)
|
||||
|
||||
async def update_mcp_server(self, server_uuid: str, server_data: dict) -> None:
|
||||
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_mcp.MCPServer).where(persistence_mcp.MCPServer.uuid == server_uuid)
|
||||
)
|
||||
old_server = result.first()
|
||||
old_server_name = old_server.name if old_server else None
|
||||
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(persistence_mcp.MCPServer)
|
||||
.where(persistence_mcp.MCPServer.uuid == server_uuid)
|
||||
@@ -220,41 +155,36 @@ class MCPService:
|
||||
)
|
||||
|
||||
if self.ap.tool_mgr.mcp_tool_loader:
|
||||
|
||||
if old_server_name and old_server_name in self.ap.tool_mgr.mcp_tool_loader.sessions:
|
||||
await self.ap.tool_mgr.mcp_tool_loader.remove_mcp_server(old_server_name)
|
||||
|
||||
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_mcp.MCPServer).where(persistence_mcp.MCPServer.uuid == server_uuid)
|
||||
)
|
||||
updated_server = result.first()
|
||||
if updated_server:
|
||||
# convert entity to config dict
|
||||
server_config = self._convert_server_entity_to_config(updated_server)
|
||||
server_config = self.ap.persistence_mgr.serialize_model(persistence_mcp.MCPServer, updated_server)
|
||||
await self.ap.tool_mgr.mcp_tool_loader.load_mcp_server(server_config)
|
||||
|
||||
async def delete_mcp_server(self, server_uuid: str) -> None:
|
||||
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_mcp.MCPServer).where(persistence_mcp.MCPServer.uuid == server_uuid)
|
||||
)
|
||||
server = result.first()
|
||||
server_name = server.name if server else None
|
||||
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.delete(persistence_mcp.MCPServer).where(persistence_mcp.MCPServer.uuid == server_uuid)
|
||||
)
|
||||
|
||||
|
||||
if server_name and self.ap.tool_mgr.mcp_tool_loader:
|
||||
if server_name in self.ap.tool_mgr.mcp_tool_loader.sessions:
|
||||
await self.ap.tool_mgr.mcp_tool_loader.remove_mcp_server(server_name)
|
||||
|
||||
async def test_mcp_server(self, server_uuid: str) -> str:
|
||||
"""测试 MCP 服务器连接并返回任务 ID"""
|
||||
|
||||
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_mcp.MCPServer).where(persistence_mcp.MCPServer.uuid == server_uuid)
|
||||
)
|
||||
@@ -262,7 +192,6 @@ class MCPService:
|
||||
if server is None:
|
||||
raise ValueError(f'Server not found: {server_uuid}')
|
||||
|
||||
|
||||
if isinstance(server, sqlalchemy.Row):
|
||||
server_entity = persistence_mcp.MCPServer(**server._mapping)
|
||||
else:
|
||||
@@ -270,5 +199,4 @@ class MCPService:
|
||||
|
||||
runtime_server = RuntimeMCPServer(ap=self.ap, mcp_server_entity=server_entity)
|
||||
|
||||
|
||||
return await runtime_server.test_connection()
|
||||
|
||||
@@ -129,7 +129,6 @@ class BuildAppStage(stage.BootingStage):
|
||||
|
||||
mcp_service_inst = mcp_service.MCPService(ap)
|
||||
ap.mcp_service = mcp_service_inst
|
||||
await mcp_service_inst.initialize()
|
||||
|
||||
ctrl = controller.Controller(ap)
|
||||
ap.ctrl = ctrl
|
||||
|
||||
@@ -3,6 +3,8 @@ from __future__ import annotations
|
||||
import typing
|
||||
from contextlib import AsyncExitStack
|
||||
import traceback
|
||||
import sqlalchemy
|
||||
import asyncio
|
||||
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.stdio import stdio_client
|
||||
@@ -11,6 +13,7 @@ from mcp.client.sse import sse_client
|
||||
from .. import loader
|
||||
from ....core import app
|
||||
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
||||
from ....entity.persistence import mcp as persistence_mcp
|
||||
|
||||
|
||||
class RuntimeMCPSession:
|
||||
@@ -77,8 +80,6 @@ class RuntimeMCPSession:
|
||||
if not self.enable:
|
||||
return
|
||||
|
||||
self.ap.logger.debug(f'初始化 MCP 会话: {self.server_name} {self.server_config}')
|
||||
|
||||
if self.server_config['mode'] == 'stdio':
|
||||
await self._init_stdio_python_server()
|
||||
elif self.server_config['mode'] == 'sse':
|
||||
@@ -110,6 +111,9 @@ class RuntimeMCPSession:
|
||||
)
|
||||
)
|
||||
|
||||
def get_tools(self) -> list[resource_tool.LLMTool]:
|
||||
return self.functions
|
||||
|
||||
async def shutdown(self):
|
||||
"""关闭会话并清理资源"""
|
||||
try:
|
||||
@@ -128,20 +132,59 @@ class MCPLoader(loader.ToolLoader):
|
||||
在此加载器中管理所有与 MCP Server 的连接。
|
||||
"""
|
||||
|
||||
sessions: dict[str, RuntimeMCPSession] = {}
|
||||
sessions: dict[str, RuntimeMCPSession]
|
||||
|
||||
_last_listed_functions: list[resource_tool.LLMTool] = []
|
||||
_last_listed_functions: list[resource_tool.LLMTool]
|
||||
|
||||
_startup_load_tasks: list[asyncio.Task]
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
super().__init__(ap)
|
||||
self.sessions = {}
|
||||
self._last_listed_functions = []
|
||||
self._startup_load_tasks = []
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
await self.load_mcp_servers_from_db()
|
||||
|
||||
async def init_runtime_mcp_session(self, server_config: dict):
|
||||
"""从服务器配置创建运行时会话
|
||||
async def load_mcp_servers_from_db(self):
|
||||
self.ap.logger.info('Loading MCP servers from db...')
|
||||
|
||||
self.sessions = {}
|
||||
|
||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_mcp.MCPServer))
|
||||
servers = result.all()
|
||||
|
||||
for server in servers:
|
||||
server_config = self.ap.persistence_mgr.serialize_model(persistence_mcp.MCPServer, server)
|
||||
|
||||
async def load_mcp_server_task():
|
||||
self.ap.logger.debug(f'Loading MCP server {server_config}')
|
||||
try:
|
||||
session = await self.load_mcp_server(server_config)
|
||||
self.sessions[server_config['name']] = session
|
||||
except Exception as e:
|
||||
self.ap.logger.error(
|
||||
f'Failed to load MCP server from db: {server_config["name"]}({server_config["uuid"]}): {e}\n{traceback.format_exc()}'
|
||||
)
|
||||
return
|
||||
|
||||
self.ap.logger.debug(f'Starting MCP server {server_config["name"]}({server_config["uuid"]})')
|
||||
try:
|
||||
await session.start()
|
||||
except Exception as e:
|
||||
self.ap.logger.error(
|
||||
f'Failed to start MCP server {server_config["name"]}({server_config["uuid"]}): {e}\n{traceback.format_exc()}'
|
||||
)
|
||||
return
|
||||
|
||||
self.ap.logger.debug(f'Started MCP server {server_config["name"]}({server_config["uuid"]})')
|
||||
|
||||
task = asyncio.create_task(load_mcp_server_task())
|
||||
self._startup_load_tasks.append(task)
|
||||
|
||||
async def load_mcp_server(self, server_config: dict) -> RuntimeMCPSession:
|
||||
"""加载 MCP 服务器到运行时
|
||||
|
||||
Args:
|
||||
server_config: 服务器配置字典,必须包含:
|
||||
@@ -150,6 +193,7 @@ class MCPLoader(loader.ToolLoader):
|
||||
- enable: 是否启用
|
||||
- extra_args: 额外的配置参数 (可选)
|
||||
"""
|
||||
|
||||
name = server_config['name']
|
||||
mode = server_config['mode']
|
||||
enable = server_config['enable']
|
||||
@@ -167,25 +211,11 @@ class MCPLoader(loader.ToolLoader):
|
||||
|
||||
return session
|
||||
|
||||
async def load_mcp_server(self, server_config: dict):
|
||||
"""加载 MCP 服务器到运行时
|
||||
|
||||
Args:
|
||||
server_config: 服务器配置字典,必须包含:
|
||||
- name: 服务器名称
|
||||
- mode: 连接模式 (stdio/sse)
|
||||
- enable: 是否启用
|
||||
- extra_args: 额外的配置参数 (可选)
|
||||
"""
|
||||
session = await self.init_runtime_mcp_session(server_config)
|
||||
await session.start()
|
||||
self.sessions[server_config['name']] = session
|
||||
|
||||
async def get_tools(self) -> list[resource_tool.LLMTool]:
|
||||
all_functions = []
|
||||
|
||||
for session in self.sessions.values():
|
||||
all_functions.extend(session.functions)
|
||||
all_functions.extend(session.get_tools())
|
||||
|
||||
self._last_listed_functions = all_functions
|
||||
|
||||
@@ -194,7 +224,7 @@ class MCPLoader(loader.ToolLoader):
|
||||
async def has_tool(self, name: str) -> bool:
|
||||
"""检查工具是否存在"""
|
||||
for session in self.sessions.values():
|
||||
for function in session.functions:
|
||||
for function in session.get_tools():
|
||||
if function.name == name:
|
||||
return True
|
||||
return False
|
||||
@@ -202,7 +232,7 @@ class MCPLoader(loader.ToolLoader):
|
||||
async def invoke_tool(self, name: str, parameters: dict) -> typing.Any:
|
||||
"""执行工具调用"""
|
||||
for session in self.sessions.values():
|
||||
for function in session.functions:
|
||||
for function in session.get_tools():
|
||||
if function.name == name:
|
||||
self.ap.logger.debug(f'Invoking MCP tool: {name} with parameters: {parameters}')
|
||||
try:
|
||||
@@ -254,7 +284,7 @@ class MCPLoader(loader.ToolLoader):
|
||||
def get_server_tool_count(self, server_name: str) -> int:
|
||||
"""获取指定服务器的工具数量"""
|
||||
session = self.get_session(server_name)
|
||||
return len(session.functions) if session else 0
|
||||
return len(session.get_tools()) if session else 0
|
||||
|
||||
def get_all_servers_info(self) -> dict[str, dict]:
|
||||
"""获取所有服务器的信息"""
|
||||
@@ -264,8 +294,8 @@ class MCPLoader(loader.ToolLoader):
|
||||
'name': server_name,
|
||||
'mode': session.server_config.get('mode'),
|
||||
'enable': session.enable,
|
||||
'tools_count': len(session.functions),
|
||||
'tool_names': [f.name for f in session.functions],
|
||||
'tools_count': len(session.get_tools()),
|
||||
'tool_names': [f.name for f in session.get_tools()],
|
||||
}
|
||||
return info
|
||||
|
||||
|
||||
Reference in New Issue
Block a user