refactor(mcp): bridge controller and db operation with service layer

This commit is contained in:
Junyan Qin
2025-11-02 13:05:55 +08:00
parent e17b0cf5c5
commit 4c0917556f
9 changed files with 152 additions and 673 deletions

View File

@@ -1,143 +0,0 @@
from __future__ import annotations
import quart
import datetime
from .. import group
@group.group_class('market', '/api/v1/market')
class MarketRouterGroup(group.RouterGroup):
async def initialize(self) -> None:
@self.route('/plugins', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str:
"""获取插件市场列表"""
# data = await quart.request.json
# page = data.get('page', 1)
# page_size = data.get('page_size', 10)
# query = data.get('query', '')
# sort_by = data.get('sort_by', 'stars')
# sort_order = data.get('sort_order', 'DESC')
# # 这里是获取插件列表的实现
# # 实际项目中这部分会连接到真实的插件市场API或数据库
# # 这里我们只是返回一些假数据作为示例
# # 模拟延迟
# import asyncio
# await asyncio.sleep(0.5)
# 返回结果
return self.success(data={'plugins': [], 'total': 0})
@self.route('/mcp', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str:
"""获取MCP服务器市场列表"""
data = await quart.request.json
page = data.get('page', 1)
page_size = data.get('page_size', 10)
query = data.get('query', '')
sort_by = data.get('sort_by', 'stars')
sort_order = data.get('sort_order', 'DESC')
# 这里是获取MCP服务器列表的实现
# 实际项目中这部分会连接到真实的MCP市场API或数据库
# 这里我们只是返回一些假数据作为示例
# 模拟延迟
import asyncio
await asyncio.sleep(0.5)
# 生成假数据
servers = []
# 只在有搜索关键词或排序时才返回数据
if query or sort_by:
now = datetime.datetime.now().isoformat()
yesterday = (datetime.datetime.now() - datetime.timedelta(days=1)).isoformat()
test_servers = [
{
'ID': 1,
'CreatedAt': yesterday,
'UpdatedAt': now,
'DeletedAt': None,
'name': 'Google Maps MCP',
'author': 'langbot-community',
'description': 'Google Maps integration for LangBot, providing geocoding and directions capabilities.',
'repository': 'langbot-community/google-maps-mcp',
'artifacts_path': '',
'stars': 124,
'downloads': 342,
'status': 'initialized',
'synced_at': now,
'pushed_at': now,
'version': '1.0.0',
},
{
'ID': 2,
'CreatedAt': yesterday,
'UpdatedAt': now,
'DeletedAt': None,
'name': 'Weather MCP',
'author': 'langbot-community',
'description': 'Weather integration for LangBot, providing current weather and forecasts.',
'repository': 'langbot-community/weather-mcp',
'artifacts_path': '',
'stars': 85,
'downloads': 215,
'status': 'initialized',
'synced_at': now,
'pushed_at': yesterday,
'version': '1.1.0',
},
{
'ID': 3,
'CreatedAt': yesterday,
'UpdatedAt': now,
'DeletedAt': None,
'name': 'Serper Search MCP',
'author': 'langbot-developers',
'description': 'Serper Search integration for LangBot, providing advanced web search capabilities.',
'repository': 'langbot-developers/serper-search-mcp',
'artifacts_path': '',
'stars': 67,
'downloads': 178,
'status': 'initialized',
'synced_at': now,
'pushed_at': yesterday,
'version': '0.9.0',
},
]
# 应用搜索过滤
if query:
query = query.lower()
servers = [
s
for s in test_servers
if query in s['name'].lower()
or query in s['description'].lower()
or query in s['author'].lower()
]
else:
servers = test_servers
# 应用排序
reverse = sort_order.upper() == 'DESC'
if sort_by == 'stars':
servers = sorted(servers, key=lambda s: s['stars'], reverse=reverse)
elif sort_by == 'created_at':
servers = sorted(servers, key=lambda s: s['CreatedAt'], reverse=reverse)
elif sort_by == 'pushed_at':
servers = sorted(servers, key=lambda s: s['pushed_at'], reverse=reverse)
# 应用分页
start_idx = (page - 1) * page_size
end_idx = start_idx + page_size
servers = servers[start_idx:end_idx]
# 返回结果
return self.success(data={'servers': servers, 'total': len(servers)})

View File

@@ -1,214 +0,0 @@
from __future__ import annotations
import time
import traceback
import uuid
import quart
import asyncio
import sqlalchemy
from pkg.entity.persistence.mcp import MCPServer
from .....core import taskmgr
from .. import group
from sqlalchemy import insert
@group.group_class('mcp', '/api/v1/mcp')
class MCPRouterGroup(group.RouterGroup):
async def initialize(self) -> None:
@self.route('/servers', methods=['GET', 'POST'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str:
"""获取MCP服务器列表"""
if quart.request.method == 'GET':
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(MCPServer).order_by(MCPServer.created_at.desc())
)
raw_results = result.all()
servers = [self.ap.persistence_mgr.serialize_model(MCPServer, row) for row in raw_results]
servers_with_status = []
# 获取MCP工具加载器
mcp_loader = None
for loader in self.ap.tool_mgr.loaders:
if loader.__class__.__name__ == 'MCPLoader':
mcp_loader = loader
break
for server in servers:
# 设置状态
if server['enable']:
status = 'enabled'
else:
status = 'disabled'
# 构建 config 对象 (前端期望的格式)
extra_args = server.get('extra_args', {})
config = {
'name': server['name'],
'mode': server['mode'],
'enable': server['enable'],
}
# 根据模式添加相应的配置
if server['mode'] == 'sse':
config['url'] = extra_args.get('url', '')
config['headers'] = extra_args.get('headers', {})
config['timeout'] = extra_args.get('timeout', 60)
elif server['mode'] == 'stdio':
config['command'] = extra_args.get('command', '')
config['args'] = extra_args.get('args', [])
config['env'] = extra_args.get('env', {})
# 从运行中的会话获取工具数量
tools_count = 0
if mcp_loader and hasattr(mcp_loader, 'sessions') and server['name'] in mcp_loader.sessions:
session = mcp_loader.sessions[server['name']]
tools_count = len(session.functions)
server_info = {
'name': server['name'],
'mode': server['mode'],
'enable': server['enable'],
'status': status,
'tools': tools_count, # 从运行中的会话获取工具数量
'config': config,
}
servers_with_status.append(server_info)
return self.success(data={'servers': servers_with_status})
elif quart.request.method == 'POST':
data = await quart.request.json
data = data['source']
try:
# 检查服务器名称是否重复
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(MCPServer).where(MCPServer.name == data['name'])
)
if result.first() is not None:
return self.http_status(400, -1, 'Server name already exists')
# 创建新服务器配置
new_server = {
'uuid': str(uuid.uuid4()),
'name': data['name'],
'mode': 'sse',
'enable': data.get('enable', False),
'extra_args': {
'url':data.get('url',''),
'headers':data.get('headers',{}),
'timeout':data.get('timeout',60),
'ssereadtimeout':data.get('ssereadtimeout',300),
},
}
await self.ap.persistence_mgr.execute_async(
sqlalchemy.insert(MCPServer).values(new_server)
)
return self.success()
except Exception:
print(traceback.format_exc())
@self.route('/servers/<server_name>', methods=['GET', 'PUT', 'DELETE'], auth_type=group.AuthType.USER_TOKEN)
async def _(server_name: str) -> str:
"""获取、更新或删除MCP服务器配置"""
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(MCPServer).where(MCPServer.name == server_name)
)
server = result.first()
if server is None:
return self.http_status(404, -1, 'Server not found')
if quart.request.method == 'GET':
server_data = self.ap.persistence_mgr.serialize_model(MCPServer, server)
return self.success(data={'server': server_data})
elif quart.request.method == 'PUT':
data = await quart.request.json
update_data = {
'enable': data.get('enable', server.enable),
}
extra_args = server.extra_args or {}
if server.mode == 'sse':
extra_args.update({
'url': data.get('url', extra_args.get('url','')),
'headers': data.get('headers', extra_args.get('headers',{})),
'timeout': data.get('timeout', extra_args.get('timeout',60)),
'ssereadtimeout': data.get('ssereadtimeout', extra_args.get('ssereadtimeout',300)),
})
update_data['extra_args'] = extra_args
await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(MCPServer).where(MCPServer.name == server_name).values(update_data)
)
return self.success()
elif quart.request.method == 'DELETE':
await self.ap.persistence_mgr.execute_async(
sqlalchemy.delete(MCPServer).where(MCPServer.name == server_name)
)
return self.success()
@self.route('/servers/<server_name>/test', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
async def _(server_name: str) -> str:
"""测试MCP服务器连接"""
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(MCPServer).where(MCPServer.name == server_name)
)
server = result.first()
if server is None:
return self.http_status(404, -1, 'Server not found')
# 创建测试任务
ctx = taskmgr.TaskContext.new()
wrapper = self.ap.task_mgr.create_user_task(
self._test_mcp_server(server, ctx),
kind='mcp-operation',
name=f'mcp-test-{server_name}',
label=f'Testing MCP server {server_name}',
context=ctx,
)
return self.success(data={'task_id': wrapper.id})
async def _test_mcp_server(self, server: MCPServer, ctx: taskmgr.TaskContext):
"""测试MCP服务器连接"""
try:
from .....provider.tools.loaders.mcp import RuntimeMCPSession
ctx.current_action = f'Testing connection to {server.name}'
# 创建临时会话进行测试
session = RuntimeMCPSession(server.name, {
'name': server.name,
'mode': server.mode,
'enable': server.enable,
'url': server.extra_args.get('url',''),
'headers': server.extra_args.get('headers',{}),
'timeout': server.extra_args.get('timeout',60),
},enable=True, ap=self.ap)
await session.start()
# 获取工具列表作为测试
tools_count = len(session.functions)
tool_name_list = []
for function in session.functions:
tool_name_list.append(function.name)
ctx.current_action = f'Successfully connected. Found {tools_count} tools.'
# 关闭测试会话
await session.shutdown()
return {'status': 'success', 'tools_count': tools_count,'tools_names_lists':tool_name_list}
except Exception as e:
print(traceback.format_exc())
ctx.current_action = f'Connection test failed: {str(e)}'
raise e

View File

@@ -1,9 +1,8 @@
from __future__ import annotations
import quart
import asyncio
from ......core import taskmgr
from ... import group
@@ -14,342 +13,107 @@ class MCPRouterGroup(group.RouterGroup):
async def _() -> str:
"""获取MCP服务器列表"""
if quart.request.method == 'GET':
if not self.ap or not self.ap.provider_cfg or not self.ap.provider_cfg.data:
return self.success(data={'servers': []})
servers = self.ap.provider_cfg.data.get('mcp', {}).get('servers', [])
# 获取每个服务器的状态和工具信息
mcp_loader = None
for loader_name, loader in self.ap.tool_mgr.loaders.items():
if loader_name == 'mcp':
mcp_loader = loader
break
servers = await self.ap.mcp_service.get_mcp_servers()
servers_with_status = []
# 获取MCP工具加载器
mcp_loader = self.ap.tool_mgr.mcp_tool_loader
for server in servers:
# 从运行中的会话获取工具数量
tools_count = 0
if mcp_loader:
session = mcp_loader.sessions.get(server['name'])
if session:
tools_count = len(session.functions)
server_info = {
'name': server['name'],
'mode': server['mode'],
'enable': server['enable'],
'config': server,
'status': 'disconnected',
'tools': [],
'error': None,
**server,
'tools': tools_count,
}
# 检查服务器连接状态
if mcp_loader and server['name'] in mcp_loader.sessions:
session = mcp_loader.sessions[server['name']]
server_info['status'] = 'connected'
server_info['tools'] = [
{'name': func.name, 'description': func.description, 'parameters': func.parameters}
for func in session.functions
]
elif server['enable']:
server_info['status'] = 'error'
server_info['error'] = 'Failed to connect'
servers_with_status.append(server_info)
return self.success(data={'servers': servers_with_status})
elif quart.request.method == 'POST':
data = await quart.request.json
data = data['source']
# 验证必填字段
required_fields = ['name', 'mode']
for field in required_fields:
if field not in data:
return self.http_status(400, -1, f'Missing required field: {field}')
uuid = await self.ap.mcp_service.create_mcp_server(data)
# 检查provider_cfg是否可用
if not self.ap or not self.ap.provider_cfg or not self.ap.provider_cfg.data:
return self.http_status(500, -1, 'Provider configuration not available')
# 获取当前配置
mcp_config = self.ap.provider_cfg.data.get('mcp', {'servers': []})
servers = mcp_config['servers']
# 检查服务器名称是否重复
for server in servers:
if server['name'] == data['name']:
return self.http_status(400, -1, 'Server name already exists')
# 创建新服务器配置
new_server = {
'name': data['name'],
'mode': data['mode'],
'enable': data.get('enable', True),
}
# 根据模式添加配置
if data['mode'] == 'stdio':
new_server.update(
{'command': data.get('command', ''), 'args': data.get('args', []), 'env': data.get('env', {})}
)
elif data['mode'] == 'sse':
new_server.update(
{
'url': data.get('url', ''),
'headers': data.get('headers', {}),
'timeout': data.get('timeout', 10),
}
)
# 添加到配置
servers.append(new_server)
self.ap.provider_cfg.data['mcp'] = mcp_config
# 保存配置
await self.ap.provider_cfg.dump_config()
# 如果启用尝试重新加载MCP loader
if new_server['enable']:
ctx = taskmgr.TaskContext.new()
wrapper = self.ap.task_mgr.create_user_task(
self._reload_mcp_loader(ctx),
kind='mcp-operation',
name=f'mcp-reload-{new_server["name"]}',
label=f'Reloading MCP loader for {new_server["name"]}',
context=ctx,
)
return self.success(data={'task_id': wrapper.id})
else:
return self.success()
else:
return self.http_status(405, -1, 'Method not allowed')
return self.success(data={'uuid': uuid})
@self.route('/servers/<server_name>', methods=['GET', 'PUT', 'DELETE'], auth_type=group.AuthType.USER_TOKEN)
async def _(server_name: str) -> str:
"""获取、更新或删除MCP服务器配置"""
if not self.ap or not self.ap.provider_cfg or not self.ap.provider_cfg.data:
return self.http_status(500, -1, 'Provider configuration not available')
mcp_config = self.ap.provider_cfg.data.get('mcp', {'servers': []})
servers = mcp_config['servers']
# 查找服务器
server_index = None
for i, server in enumerate(servers):
if server['name'] == server_name:
server_index = i
break
if server_index is None:
server_data = await self.ap.mcp_service.get_mcp_server_by_name(server_name)
if server_data is None:
return self.http_status(404, -1, 'Server not found')
if quart.request.method == 'GET':
return self.success(data={'server': servers[server_index]})
return self.success(data={'server': server_data})
elif quart.request.method == 'PUT':
data = await quart.request.json
server = servers[server_index]
# 更新配置
server.update(
{
'enable': data.get('enable', server.get('enable', True)),
}
)
# 根据模式更新特定配置
if server['mode'] == 'stdio':
server.update(
{
'command': data.get('command', server.get('command', '')),
'args': data.get('args', server.get('args', [])),
'env': data.get('env', server.get('env', {})),
}
)
elif server['mode'] == 'sse':
server.update(
{
'url': data.get('url', server.get('url', '')),
'headers': data.get('headers', server.get('headers', {})),
'timeout': data.get('timeout', server.get('timeout', 10)),
}
)
# 保存配置
await self.ap.provider_cfg.dump_config()
# 重新加载MCP loader
ctx = taskmgr.TaskContext.new()
wrapper = self.ap.task_mgr.create_user_task(
self._reload_mcp_loader(ctx),
kind='mcp-operation',
name=f'mcp-reload-{server_name}',
label=f'Reloading MCP loader for {server_name}',
context=ctx,
)
return self.success(data={'task_id': wrapper.id})
await self.ap.mcp_service.update_mcp_server(server_data['uuid'], data)
return self.success()
elif quart.request.method == 'DELETE':
# 删除服务器
servers.pop(server_index)
self.ap.provider_cfg.data['mcp'] = mcp_config
# 保存配置
await self.ap.provider_cfg.dump_config()
# 重新加载MCP loader
ctx = taskmgr.TaskContext.new()
wrapper = self.ap.task_mgr.create_user_task(
self._reload_mcp_loader(ctx),
kind='mcp-operation',
name=f'mcp-remove-{server_name}',
label=f'Removing MCP server {server_name}',
context=ctx,
)
return self.success(data={'task_id': wrapper.id})
@self.route('/servers/<server_name>/toggle', methods=['PUT'], auth_type=group.AuthType.USER_TOKEN)
async def _(server_name: str) -> str:
"""切换MCP服务器启用状态"""
data = await quart.request.json
target_enabled = data.get('target_enabled')
if not self.ap or not self.ap.provider_cfg or not self.ap.provider_cfg.data:
return self.http_status(500, -1, 'Provider configuration not available')
mcp_config = self.ap.provider_cfg.data.get('mcp', {'servers': []})
servers = mcp_config['servers']
# 查找并更新服务器
for server in servers:
if server['name'] == server_name:
server['enable'] = target_enabled
break
else:
return self.http_status(404, -1, 'Server not found')
# 保存配置
await self.ap.provider_cfg.dump_config()
# 重新加载MCP loader
ctx = taskmgr.TaskContext.new()
wrapper = self.ap.task_mgr.create_user_task(
self._reload_mcp_loader(ctx),
kind='mcp-operation',
name=f'mcp-toggle-{server_name}',
label=f'Toggling MCP server {server_name}',
context=ctx,
)
return self.success(data={'task_id': wrapper.id})
await self.ap.mcp_service.delete_mcp_server(server_data['uuid'])
return self.success()
@self.route('/servers/<server_name>/test', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
async def _(server_name: str) -> str:
"""测试MCP服务器连接"""
if not self.ap or not self.ap.provider_cfg or not self.ap.provider_cfg.data:
return self.http_status(500, -1, 'Provider configuration not available')
mcp_config = self.ap.provider_cfg.data.get('mcp', {'servers': []})
servers = mcp_config['servers']
# 查找服务器配置
server_config = None
for server in servers:
if server['name'] == server_name:
server_config = server
break
if server_config is None:
server_data = await self.ap.mcp_service.get_mcp_server_by_name(server_name)
if server_data is None:
return self.http_status(404, -1, 'Server not found')
# 创建测试任务
ctx = taskmgr.TaskContext.new()
wrapper = self.ap.task_mgr.create_user_task(
self._test_mcp_server(server_config, ctx),
kind='mcp-operation',
name=f'mcp-test-{server_name}',
label=f'Testing MCP server {server_name}',
context=ctx,
)
return self.success(data={'task_id': wrapper.id})
@self.route('/install/github', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str:
"""从GitHub安装MCP服务器"""
data = await quart.request.json
source = data.get('source')
# TODO 这里移到service去
# # 创建测试任务
# ctx = taskmgr.TaskContext.new()
# wrapper = self.ap.task_mgr.create_user_task(
# self._test_mcp_server(server, ctx),
# kind='mcp-operation',
# name=f'mcp-test-{server_name}',
# label=f'Testing MCP server {server_name}',
# context=ctx,
# )
# return self.success(data={'task_id': wrapper.id})
if not source:
return self.http_status(400, -1, 'Missing source parameter')
# async def _test_mcp_server(self, server: persistence_mcp.MCPServer, ctx: taskmgr.TaskContext):
# """测试MCP服务器连接"""
# try:
# 创建安装任务
ctx = taskmgr.TaskContext.new()
wrapper = self.ap.task_mgr.create_user_task(
self._install_mcp_from_github(source, ctx),
kind='mcp-operation',
name='install-mcp-github',
label=f'Installing MCP from GitHub: {source}',
context=ctx,
)
return self.success(data={'task_id': wrapper.id})
# ctx.current_action = f'Testing connection to {server.name}'
# # 创建临时会话进行测试
# session = RuntimeMCPSession(server.name, {
# 'name': server.name,
# 'mode': server.mode,
# 'enable': server.enable,
# 'url': server.extra_args.get('url',''),
# 'headers': server.extra_args.get('headers',{}),
# 'timeout': server.extra_args.get('timeout',60),
# },enable=True, ap=self.ap)
# await session.start()
async def _reload_mcp_loader(self, ctx: taskmgr.TaskContext):
"""重新加载MCP loader"""
try:
ctx.current_action = 'Stopping existing MCP sessions'
# 停止现有的MCP会话
mcp_loader = None
for loader_name, loader in self.ap.tool_mgr.loaders.items():
if loader_name == 'mcp':
mcp_loader = loader
break
# # 获取工具列表作为测试
# tools_count = len(session.functions)
if mcp_loader:
await mcp_loader.shutdown()
# tool_name_list = []
# for function in session.functions:
# tool_name_list.append(function.name)
# ctx.current_action = f'Successfully connected. Found {tools_count} tools.'
ctx.current_action = 'Reloading MCP configuration'
# 重新加载MCP loader
await self.ap.tool_mgr.reload_loader('mcp')
# # 关闭测试会话
# await session.shutdown()
ctx.current_action = 'MCP loader reloaded successfully'
# return {'status': 'success', 'tools_count': tools_count,'tools_names_lists':tool_name_list}
except Exception as e:
ctx.current_action = f'Failed to reload MCP loader: {str(e)}'
raise e
async def _test_mcp_server(self, server_config: dict, ctx: taskmgr.TaskContext):
"""测试MCP服务器连接"""
try:
from ......provider.tools.loaders.mcp import RuntimeMCPSession
ctx.current_action = f'Testing connection to {server_config["name"]}'
# 创建临时会话进行测试
session = RuntimeMCPSession(server_config['name'], server_config, self.ap)
await session.initialize()
# 获取工具列表作为测试
tools_count = len(session.functions)
ctx.current_action = f'Successfully connected. Found {tools_count} tools.'
# 关闭测试会话
await session.shutdown()
return {'status': 'success', 'tools_count': tools_count}
except Exception as e:
ctx.current_action = f'Connection test failed: {str(e)}'
raise e
async def _install_mcp_from_github(self, source: str, ctx: taskmgr.TaskContext):
"""从GitHub安装MCP服务器的实现"""
try:
ctx.current_action = f'Installing MCP server from {source}'
# 这里是安装逻辑的占位符
# 实际实现将包括克隆仓库、解析配置、安装依赖等步骤
# 模拟安装过程
await asyncio.sleep(2) # 模拟安装过程
# 返回成功结果
return {'status': 'success', 'message': f'Successfully installed MCP server from {source}'}
except Exception as e:
ctx.current_action = f'Failed to install MCP server: {str(e)}'
raise e
# except Exception as e:
# print(traceback.format_exc())
# ctx.current_action = f'Connection test failed: {str(e)}'
# raise e

View File

@@ -15,12 +15,14 @@ from .groups import provider as groups_provider
from .groups import platform as groups_platform
from .groups import pipelines as groups_pipelines
from .groups import knowledge as groups_knowledge
from .groups import resources as groups_resources
importutil.import_modules_in_pkg(groups)
importutil.import_modules_in_pkg(groups_provider)
importutil.import_modules_in_pkg(groups_platform)
importutil.import_modules_in_pkg(groups_pipelines)
importutil.import_modules_in_pkg(groups_knowledge)
importutil.import_modules_in_pkg(groups_resources)
class HTTPController:

View File

@@ -0,0 +1,63 @@
from __future__ import annotations
import sqlalchemy
import uuid
from ....core import app
from ....entity.persistence import mcp as persistence_mcp
class MCPService:
ap: app.Application
def __init__(self, ap: app.Application) -> None:
self.ap = ap
async def get_mcp_servers(self) -> list[dict]:
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_mcp.MCPServer))
servers = result.all()
return [self.ap.persistence_mgr.serialize_model(persistence_mcp.MCPServer, server) for server in servers]
async def create_mcp_server(self, server_data: dict) -> str:
server_data['uuid'] = str(uuid.uuid4())
await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_mcp.MCPServer).values(server_data))
server = await self.get_mcp_server(server_data['uuid'])
# TODO: load runtime mcp server session
return server['uuid']
async def get_mcp_server(self, server_uuid: str) -> dict | None:
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_mcp.MCPServer).where(persistence_mcp.MCPServer.uuid == server_uuid)
)
server = result.first()
if server is None:
return None
return self.ap.persistence_mgr.serialize_model(persistence_mcp.MCPServer, server)
async def get_mcp_server_by_name(self, server_name: str) -> dict | None:
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_mcp.MCPServer).where(persistence_mcp.MCPServer.name == server_name)
)
server = result.first()
if server is None:
return None
return self.ap.persistence_mgr.serialize_model(persistence_mcp.MCPServer, server)
async def update_mcp_server(self, server_uuid: str, server_data: dict) -> None:
await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_mcp.MCPServer)
.where(persistence_mcp.MCPServer.uuid == server_uuid)
.values(server_data)
)
# TODO: reload runtime mcp server session
async def delete_mcp_server(self, server_uuid: str) -> None:
await self.ap.persistence_mgr.execute_async(
sqlalchemy.delete(persistence_mcp.MCPServer).where(persistence_mcp.MCPServer.uuid == server_uuid)
)
# TODO: remove runtime mcp server session

View File

@@ -22,6 +22,7 @@ from ..api.http.service import model as model_service
from ..api.http.service import pipeline as pipeline_service
from ..api.http.service import bot as bot_service
from ..api.http.service import knowledge as knowledge_service
from ..api.http.service import mcp as mcp_service
from ..discover import engine as discover_engine
from ..storage import mgr as storagemgr
from ..utils import logcache
@@ -119,6 +120,8 @@ class Application:
knowledge_service: knowledge_service.KnowledgeService = None
mcp_service: mcp_service.MCPService = None
def __init__(self):
pass

View File

@@ -19,6 +19,7 @@ from ...api.http.service import model as model_service
from ...api.http.service import pipeline as pipeline_service
from ...api.http.service import bot as bot_service
from ...api.http.service import knowledge as knowledge_service
from ...api.http.service import mcp as mcp_service
from ...discover import engine as discover_engine
from ...storage import mgr as storagemgr
from ...utils import logcache
@@ -126,5 +127,8 @@ class BuildAppStage(stage.BootingStage):
knowledge_service_inst = knowledge_service.KnowledgeService(ap)
ap.knowledge_service = knowledge_service_inst
mcp_service_inst = mcp_service.MCPService(ap)
ap.mcp_service = mcp_service_inst
ctrl = controller.Controller(ap)
ap.ctrl = ctrl

View File

@@ -3,9 +3,9 @@ from __future__ import annotations
import typing
from ...core import app
from . import loader as tools_loader
from ...utils import importutil
from . import loaders
from .loaders import mcp as mcp_loader, plugin as plugin_loader
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
importutil.import_modules_in_pkg(loaders)
@@ -16,25 +16,24 @@ class ToolManager:
ap: app.Application
loaders: list[tools_loader.ToolLoader]
plugin_tool_loader: plugin_loader.PluginToolLoader
mcp_tool_loader: mcp_loader.MCPLoader
def __init__(self, ap: app.Application):
self.ap = ap
self.all_functions = []
self.loaders = []
async def initialize(self):
for loader_cls in tools_loader.preregistered_loaders:
loader_inst = loader_cls(self.ap)
await loader_inst.initialize()
self.loaders.append(loader_inst)
self.plugin_tool_loader = plugin_loader.PluginToolLoader(self.ap)
await self.plugin_tool_loader.initialize()
self.mcp_tool_loader = mcp_loader.MCPLoader(self.ap)
await self.mcp_tool_loader.initialize()
async def get_all_tools(self) -> list[resource_tool.LLMTool]:
"""获取所有函数"""
all_functions: list[resource_tool.LLMTool] = []
for loader in self.loaders:
all_functions.extend(await loader.get_tools())
all_functions.extend(await self.plugin_tool_loader.get_tools())
all_functions.extend(await self.mcp_tool_loader.get_tools())
return all_functions
@@ -93,13 +92,14 @@ class ToolManager:
async def execute_func_call(self, name: str, parameters: dict) -> typing.Any:
"""执行函数调用"""
for loader in self.loaders:
if await loader.has_tool(name):
return await loader.invoke_tool(name, parameters)
if await self.plugin_tool_loader.has_tool(name):
return await self.plugin_tool_loader.invoke_tool(name, parameters)
elif await self.mcp_tool_loader.has_tool(name):
return await self.mcp_tool_loader.invoke_tool(name, parameters)
else:
raise ValueError(f'未找到工具: {name}')
async def shutdown(self):
"""关闭所有工具"""
for loader in self.loaders:
await loader.shutdown()
await self.plugin_tool_loader.shutdown()
await self.mcp_tool_loader.shutdown()