mirror of
https://github.com/langbot-app/LangBot.git
synced 2025-11-26 03:44:58 +08:00
feat: add mcp db
This commit is contained in:
@@ -1,11 +1,18 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
|
||||||
import quart
|
import quart
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
|
import sqlalchemy
|
||||||
|
|
||||||
|
from pkg.entity.persistence.mcp import MCPServer
|
||||||
|
|
||||||
from .....core import taskmgr
|
from .....core import taskmgr
|
||||||
from .. import group
|
from .. import group
|
||||||
|
|
||||||
|
from sqlalchemy import insert
|
||||||
|
|
||||||
@group.group_class('mcp', '/api/v1/mcp')
|
@group.group_class('mcp', '/api/v1/mcp')
|
||||||
class MCPRouterGroup(group.RouterGroup):
|
class MCPRouterGroup(group.RouterGroup):
|
||||||
@@ -14,312 +21,137 @@ class MCPRouterGroup(group.RouterGroup):
|
|||||||
async def _() -> str:
|
async def _() -> str:
|
||||||
"""获取MCP服务器列表"""
|
"""获取MCP服务器列表"""
|
||||||
if quart.request.method == 'GET':
|
if quart.request.method == 'GET':
|
||||||
if not self.ap or not self.ap.provider_cfg or not self.ap.provider_cfg.data:
|
result = await self.ap.persistence_mgr.execute_async(
|
||||||
return self.success(data={'servers': []})
|
sqlalchemy.select(MCPServer).order_by(MCPServer.created_at.desc())
|
||||||
|
)
|
||||||
servers = self.ap.provider_cfg.data.get('mcp', {}).get('servers', [])
|
servers = [self.ap.persistence_mgr.serialize_model(MCPServer, row) for row in result.scalars().all()]
|
||||||
|
|
||||||
# 获取每个服务器的状态和工具信息
|
|
||||||
mcp_loader = None
|
|
||||||
for loader_name, loader in self.ap.tool_mgr.loaders.items():
|
|
||||||
if loader_name == 'mcp':
|
|
||||||
mcp_loader = loader
|
|
||||||
break
|
|
||||||
|
|
||||||
servers_with_status = []
|
servers_with_status = []
|
||||||
for server in servers:
|
for server in servers:
|
||||||
|
if servers['enable']:
|
||||||
|
status = 'enabled'
|
||||||
|
else:
|
||||||
|
status = 'disabled'
|
||||||
|
|
||||||
|
# 这里先写成开关状态,先不写连接状态
|
||||||
server_info = {
|
server_info = {
|
||||||
'name': server['name'],
|
'name': server['name'],
|
||||||
'mode': server['mode'],
|
'mode': server['mode'],
|
||||||
'enable': server['enable'],
|
'enable': server['enable'],
|
||||||
'config': server,
|
'description': server.get('description',''),
|
||||||
'status': 'disconnected',
|
'extra_args': server.get('extra_args',{}),
|
||||||
'tools': [],
|
'status': status,
|
||||||
'error': None,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# 检查服务器连接状态
|
|
||||||
if mcp_loader and server['name'] in mcp_loader.sessions:
|
|
||||||
session = mcp_loader.sessions[server['name']]
|
|
||||||
server_info['status'] = 'connected'
|
|
||||||
server_info['tools'] = [
|
|
||||||
{'name': func.name, 'description': func.description, 'parameters': func.parameters}
|
|
||||||
for func in session.functions
|
|
||||||
]
|
|
||||||
elif server['enable']:
|
|
||||||
server_info['status'] = 'error'
|
|
||||||
server_info['error'] = 'Failed to connect'
|
|
||||||
|
|
||||||
servers_with_status.append(server_info)
|
servers_with_status.append(server_info)
|
||||||
|
|
||||||
return self.success(data={'servers': servers_with_status})
|
return self.success(data={'servers': servers_with_status})
|
||||||
|
|
||||||
elif quart.request.method == 'POST':
|
elif quart.request.method == 'POST':
|
||||||
data = await quart.request.json
|
data = await quart.request.json
|
||||||
|
|
||||||
# 验证必填字段
|
|
||||||
required_fields = ['name', 'mode']
|
|
||||||
for field in required_fields:
|
|
||||||
if field not in data:
|
|
||||||
return self.http_status(400, -1, f'Missing required field: {field}')
|
|
||||||
|
|
||||||
# 检查provider_cfg是否可用
|
|
||||||
if not self.ap or not self.ap.provider_cfg or not self.ap.provider_cfg.data:
|
|
||||||
return self.http_status(500, -1, 'Provider configuration not available')
|
|
||||||
|
|
||||||
# 获取当前配置
|
|
||||||
mcp_config = self.ap.provider_cfg.data.get('mcp', {'servers': []})
|
|
||||||
servers = mcp_config['servers']
|
|
||||||
|
|
||||||
# 检查服务器名称是否重复
|
# 检查服务器名称是否重复
|
||||||
for server in servers:
|
result = await self.ap.persistence_mgr.execute_async(
|
||||||
if server['name'] == data['name']:
|
sqlalchemy.select(MCPServer).where(MCPServer.name == data['name'])
|
||||||
return self.http_status(400, -1, 'Server name already exists')
|
)
|
||||||
|
if result.first() is not None:
|
||||||
|
return self.http_status(400, -1, 'Server name already exists')
|
||||||
|
|
||||||
# 创建新服务器配置
|
# 创建新服务器配置
|
||||||
new_server = {
|
new_server = {
|
||||||
|
'uuid': str(uuid.uuid4()),
|
||||||
'name': data['name'],
|
'name': data['name'],
|
||||||
'mode': data['mode'],
|
'mode': data['mode'],
|
||||||
'enable': data.get('enable', True),
|
'enable': data.get('enable', False),
|
||||||
|
'description': data.get('description',''),
|
||||||
|
'extra_args': {
|
||||||
|
'url':data.get('url',''),
|
||||||
|
'headers':data.get('headers',{}),
|
||||||
|
'timeout':data.get('timeout',60),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
# 根据模式添加配置
|
await self.ap.persistence_mgr.execute_async(
|
||||||
if data['mode'] == 'stdio':
|
sqlalchemy.insert(MCPServer).values(new_server)
|
||||||
new_server.update(
|
)
|
||||||
{'command': data.get('command', ''), 'args': data.get('args', []), 'env': data.get('env', {})}
|
|
||||||
)
|
|
||||||
elif data['mode'] == 'sse':
|
|
||||||
new_server.update(
|
|
||||||
{
|
|
||||||
'url': data.get('url', ''),
|
|
||||||
'headers': data.get('headers', {}),
|
|
||||||
'timeout': data.get('timeout', 10),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# 添加到配置
|
|
||||||
servers.append(new_server)
|
|
||||||
self.ap.provider_cfg.data['mcp'] = mcp_config
|
|
||||||
|
|
||||||
# 保存配置
|
|
||||||
await self.ap.provider_cfg.dump_config()
|
|
||||||
|
|
||||||
# 如果启用,尝试重新加载MCP loader
|
|
||||||
if new_server['enable']:
|
|
||||||
ctx = taskmgr.TaskContext.new()
|
|
||||||
wrapper = self.ap.task_mgr.create_user_task(
|
|
||||||
self._reload_mcp_loader(ctx),
|
|
||||||
kind='mcp-operation',
|
|
||||||
name=f'mcp-reload-{new_server["name"]}',
|
|
||||||
label=f'Reloading MCP loader for {new_server["name"]}',
|
|
||||||
context=ctx,
|
|
||||||
)
|
|
||||||
return self.success(data={'task_id': wrapper.id})
|
|
||||||
else:
|
|
||||||
return self.success()
|
|
||||||
else:
|
|
||||||
return self.http_status(405, -1, 'Method not allowed')
|
|
||||||
|
|
||||||
@self.route('/servers/<server_name>', methods=['GET', 'PUT', 'DELETE'], auth_type=group.AuthType.USER_TOKEN)
|
@self.route('/servers/<server_name>', methods=['GET', 'PUT', 'DELETE'], auth_type=group.AuthType.USER_TOKEN)
|
||||||
async def _(server_name: str) -> str:
|
async def _(server_name: str) -> str:
|
||||||
"""获取、更新或删除MCP服务器配置"""
|
"""获取、更新或删除MCP服务器配置"""
|
||||||
if not self.ap or not self.ap.provider_cfg or not self.ap.provider_cfg.data:
|
result = await self.ap.persistence_mgr.execute_async(
|
||||||
return self.http_status(500, -1, 'Provider configuration not available')
|
sqlalchemy.select(MCPServer).where(MCPServer.name == server_name)
|
||||||
|
)
|
||||||
mcp_config = self.ap.provider_cfg.data.get('mcp', {'servers': []})
|
server = result.first()
|
||||||
servers = mcp_config['servers']
|
if server is None:
|
||||||
|
|
||||||
# 查找服务器
|
|
||||||
server_index = None
|
|
||||||
for i, server in enumerate(servers):
|
|
||||||
if server['name'] == server_name:
|
|
||||||
server_index = i
|
|
||||||
break
|
|
||||||
|
|
||||||
if server_index is None:
|
|
||||||
return self.http_status(404, -1, 'Server not found')
|
return self.http_status(404, -1, 'Server not found')
|
||||||
|
|
||||||
if quart.request.method == 'GET':
|
if quart.request.method == 'GET':
|
||||||
return self.success(data={'server': servers[server_index]})
|
server_data = self.ap.persistence_mgr.serialize_model(MCPServer, server)
|
||||||
|
return self.success(data={'server': server_data})
|
||||||
|
|
||||||
elif quart.request.method == 'PUT':
|
elif quart.request.method == 'PUT':
|
||||||
data = await quart.request.json
|
data = await quart.request.json
|
||||||
server = servers[server_index]
|
update_data = {
|
||||||
|
'enable': data.get('enable', server.enable),
|
||||||
|
'description': data.get('description', server.description),
|
||||||
|
}
|
||||||
|
|
||||||
# 更新配置
|
extra_args = server.extra_args or {}
|
||||||
server.update(
|
if server.mode == 'sse':
|
||||||
{
|
extra_args.update({
|
||||||
'enable': data.get('enable', server.get('enable', True)),
|
'url': data.get('url', extra_args.get('url','')),
|
||||||
}
|
'headers': data.get('headers', extra_args.get('headers',{})),
|
||||||
|
'timeout': data.get('timeout', extra_args.get('timeout',60)),
|
||||||
|
})
|
||||||
|
update_data['extra_args'] = extra_args
|
||||||
|
|
||||||
|
await self.ap.persistence_mgr.execute_async(
|
||||||
|
sqlalchemy.update(MCPServer).where(MCPServer.name == server_name).values(update_data)
|
||||||
)
|
)
|
||||||
|
|
||||||
# 根据模式更新特定配置
|
return self.success()
|
||||||
if server['mode'] == 'stdio':
|
|
||||||
server.update(
|
|
||||||
{
|
|
||||||
'command': data.get('command', server.get('command', '')),
|
|
||||||
'args': data.get('args', server.get('args', [])),
|
|
||||||
'env': data.get('env', server.get('env', {})),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
elif server['mode'] == 'sse':
|
|
||||||
server.update(
|
|
||||||
{
|
|
||||||
'url': data.get('url', server.get('url', '')),
|
|
||||||
'headers': data.get('headers', server.get('headers', {})),
|
|
||||||
'timeout': data.get('timeout', server.get('timeout', 10)),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# 保存配置
|
|
||||||
await self.ap.provider_cfg.dump_config()
|
|
||||||
|
|
||||||
# 重新加载MCP loader
|
|
||||||
ctx = taskmgr.TaskContext.new()
|
|
||||||
wrapper = self.ap.task_mgr.create_user_task(
|
|
||||||
self._reload_mcp_loader(ctx),
|
|
||||||
kind='mcp-operation',
|
|
||||||
name=f'mcp-reload-{server_name}',
|
|
||||||
label=f'Reloading MCP loader for {server_name}',
|
|
||||||
context=ctx,
|
|
||||||
)
|
|
||||||
return self.success(data={'task_id': wrapper.id})
|
|
||||||
|
|
||||||
elif quart.request.method == 'DELETE':
|
elif quart.request.method == 'DELETE':
|
||||||
# 删除服务器
|
await self.ap.persistence_mgr.execute_async(
|
||||||
servers.pop(server_index)
|
sqlalchemy.delete(MCPServer).where(MCPServer.name == server_name)
|
||||||
self.ap.provider_cfg.data['mcp'] = mcp_config
|
|
||||||
|
|
||||||
# 保存配置
|
|
||||||
await self.ap.provider_cfg.dump_config()
|
|
||||||
|
|
||||||
# 重新加载MCP loader
|
|
||||||
ctx = taskmgr.TaskContext.new()
|
|
||||||
wrapper = self.ap.task_mgr.create_user_task(
|
|
||||||
self._reload_mcp_loader(ctx),
|
|
||||||
kind='mcp-operation',
|
|
||||||
name=f'mcp-remove-{server_name}',
|
|
||||||
label=f'Removing MCP server {server_name}',
|
|
||||||
context=ctx,
|
|
||||||
)
|
)
|
||||||
return self.success(data={'task_id': wrapper.id})
|
return self.success()
|
||||||
|
|
||||||
@self.route('/servers/<server_name>/toggle', methods=['PUT'], auth_type=group.AuthType.USER_TOKEN)
|
|
||||||
async def _(server_name: str) -> str:
|
|
||||||
"""切换MCP服务器启用状态"""
|
|
||||||
data = await quart.request.json
|
|
||||||
target_enabled = data.get('target_enabled')
|
|
||||||
|
|
||||||
if not self.ap or not self.ap.provider_cfg or not self.ap.provider_cfg.data:
|
|
||||||
return self.http_status(500, -1, 'Provider configuration not available')
|
|
||||||
|
|
||||||
mcp_config = self.ap.provider_cfg.data.get('mcp', {'servers': []})
|
|
||||||
servers = mcp_config['servers']
|
|
||||||
|
|
||||||
# 查找并更新服务器
|
|
||||||
for server in servers:
|
|
||||||
if server['name'] == server_name:
|
|
||||||
server['enable'] = target_enabled
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
return self.http_status(404, -1, 'Server not found')
|
|
||||||
|
|
||||||
# 保存配置
|
|
||||||
await self.ap.provider_cfg.dump_config()
|
|
||||||
|
|
||||||
# 重新加载MCP loader
|
|
||||||
ctx = taskmgr.TaskContext.new()
|
|
||||||
wrapper = self.ap.task_mgr.create_user_task(
|
|
||||||
self._reload_mcp_loader(ctx),
|
|
||||||
kind='mcp-operation',
|
|
||||||
name=f'mcp-toggle-{server_name}',
|
|
||||||
label=f'Toggling MCP server {server_name}',
|
|
||||||
context=ctx,
|
|
||||||
)
|
|
||||||
return self.success(data={'task_id': wrapper.id})
|
|
||||||
|
|
||||||
@self.route('/servers/<server_name>/test', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
@self.route('/servers/<server_name>/test', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
||||||
async def _(server_name: str) -> str:
|
async def _(server_name: str) -> str:
|
||||||
"""测试MCP服务器连接"""
|
"""测试MCP服务器连接"""
|
||||||
if not self.ap or not self.ap.provider_cfg or not self.ap.provider_cfg.data:
|
result = await self.ap.persistence_mgr.execute_async(
|
||||||
return self.http_status(500, -1, 'Provider configuration not available')
|
sqlalchemy.select(MCPServer).where(MCPServer.name == server_name)
|
||||||
|
)
|
||||||
mcp_config = self.ap.provider_cfg.data.get('mcp', {'servers': []})
|
server = result.first()
|
||||||
servers = mcp_config['servers']
|
if server is None:
|
||||||
|
|
||||||
# 查找服务器配置
|
|
||||||
server_config = None
|
|
||||||
for server in servers:
|
|
||||||
if server['name'] == server_name:
|
|
||||||
server_config = server
|
|
||||||
break
|
|
||||||
|
|
||||||
if server_config is None:
|
|
||||||
return self.http_status(404, -1, 'Server not found')
|
return self.http_status(404, -1, 'Server not found')
|
||||||
|
|
||||||
# 创建测试任务
|
# 创建测试任务
|
||||||
ctx = taskmgr.TaskContext.new()
|
ctx = taskmgr.TaskContext.new()
|
||||||
wrapper = self.ap.task_mgr.create_user_task(
|
wrapper = self.ap.task_mgr.create_user_task(
|
||||||
self._test_mcp_server(server_config, ctx),
|
self._test_mcp_server(server, ctx),
|
||||||
kind='mcp-operation',
|
kind='mcp-operation',
|
||||||
name=f'mcp-test-{server_name}',
|
name=f'mcp-test-{server_name}',
|
||||||
label=f'Testing MCP server {server_name}',
|
label=f'Testing MCP server {server_name}',
|
||||||
context=ctx,
|
context=ctx,
|
||||||
)
|
)
|
||||||
return self.success(data={'task_id': wrapper.id})
|
return self.success(data={'task_id': wrapper.id})
|
||||||
|
|
||||||
@self.route('/install/github', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
async def _test_mcp_server(self, server: MCPServer, ctx: taskmgr.TaskContext):
|
||||||
async def _() -> str:
|
|
||||||
"""从GitHub安装MCP服务器"""
|
|
||||||
data = await quart.request.json
|
|
||||||
source = data.get('source')
|
|
||||||
|
|
||||||
if not source:
|
|
||||||
return self.http_status(400, -1, 'Missing source parameter')
|
|
||||||
|
|
||||||
# 创建安装任务
|
|
||||||
ctx = taskmgr.TaskContext.new()
|
|
||||||
wrapper = self.ap.task_mgr.create_user_task(
|
|
||||||
self._install_mcp_from_github(source, ctx),
|
|
||||||
kind='mcp-operation',
|
|
||||||
name='install-mcp-github',
|
|
||||||
label=f'Installing MCP from GitHub: {source}',
|
|
||||||
context=ctx,
|
|
||||||
)
|
|
||||||
return self.success(data={'task_id': wrapper.id})
|
|
||||||
|
|
||||||
async def _reload_mcp_loader(self, ctx: taskmgr.TaskContext):
|
|
||||||
"""重新加载MCP loader"""
|
|
||||||
try:
|
|
||||||
ctx.current_action = 'Stopping existing MCP sessions'
|
|
||||||
# 停止现有的MCP会话
|
|
||||||
mcp_loader = None
|
|
||||||
for loader_name, loader in self.ap.tool_mgr.loaders.items():
|
|
||||||
if loader_name == 'mcp':
|
|
||||||
mcp_loader = loader
|
|
||||||
break
|
|
||||||
|
|
||||||
if mcp_loader:
|
|
||||||
await mcp_loader.shutdown()
|
|
||||||
|
|
||||||
ctx.current_action = 'Reloading MCP configuration'
|
|
||||||
# 重新加载MCP loader
|
|
||||||
await self.ap.tool_mgr.reload_loader('mcp')
|
|
||||||
|
|
||||||
ctx.current_action = 'MCP loader reloaded successfully'
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
ctx.current_action = f'Failed to reload MCP loader: {str(e)}'
|
|
||||||
raise e
|
|
||||||
|
|
||||||
async def _test_mcp_server(self, server_config: dict, ctx: taskmgr.TaskContext):
|
|
||||||
"""测试MCP服务器连接"""
|
"""测试MCP服务器连接"""
|
||||||
try:
|
try:
|
||||||
from .....provider.tools.loaders.mcp import RuntimeMCPSession
|
from .....provider.tools.loaders.mcp import RuntimeMCPSession
|
||||||
|
|
||||||
ctx.current_action = f'Testing connection to {server_config["name"]}'
|
ctx.current_action = f'Testing connection to {server.name}'
|
||||||
|
|
||||||
# 创建临时会话进行测试
|
# 创建临时会话进行测试
|
||||||
session = RuntimeMCPSession(server_config['name'], server_config, self.ap)
|
session = RuntimeMCPSession(server.name, {
|
||||||
|
'name': server.name,
|
||||||
|
'mode': server.mode,
|
||||||
|
'enable': server.enable,
|
||||||
|
'description': server.description,
|
||||||
|
'extra_args': server.extra_args or {},
|
||||||
|
}, self.ap)
|
||||||
await session.initialize()
|
await session.initialize()
|
||||||
|
|
||||||
# 获取工具列表作为测试
|
# 获取工具列表作为测试
|
||||||
@@ -334,22 +166,5 @@ class MCPRouterGroup(group.RouterGroup):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
ctx.current_action = f'Connection test failed: {str(e)}'
|
ctx.current_action = f'Connection test failed: {str(e)}'
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
async def _install_mcp_from_github(self, source: str, ctx: taskmgr.TaskContext):
|
|
||||||
"""从GitHub安装MCP服务器的实现"""
|
|
||||||
try:
|
|
||||||
ctx.current_action = f'Installing MCP server from {source}'
|
|
||||||
|
|
||||||
# 这里是安装逻辑的占位符
|
|
||||||
# 实际实现将包括克隆仓库、解析配置、安装依赖等步骤
|
|
||||||
|
|
||||||
# 模拟安装过程
|
|
||||||
|
|
||||||
await asyncio.sleep(2) # 模拟安装过程
|
|
||||||
|
|
||||||
# 返回成功结果
|
|
||||||
return {'status': 'success', 'message': f'Successfully installed MCP server from {source}'}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
ctx.current_action = f'Failed to install MCP server: {str(e)}'
|
|
||||||
raise e
|
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ class MCPServer(Base):
|
|||||||
|
|
||||||
uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True)
|
uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True)
|
||||||
name = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
name = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||||
|
description = sqlalchemy.Column(sqlalchemy.Text, nullable=True)
|
||||||
enable = sqlalchemy.Column(sqlalchemy.Boolean, nullable=False, default=False)
|
enable = sqlalchemy.Column(sqlalchemy.Boolean, nullable=False, default=False)
|
||||||
mode = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) # stdio, sse
|
mode = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) # stdio, sse
|
||||||
extra_args = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default={})
|
extra_args = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default={})
|
||||||
|
|||||||
Reference in New Issue
Block a user