mirror of
https://github.com/langbot-app/LangBot.git
synced 2025-11-25 19:37:36 +08:00
refactor(mcp): bridge controller and db operation with service layer
This commit is contained in:
@@ -1,143 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import quart
|
||||
import datetime
|
||||
|
||||
from .. import group
|
||||
|
||||
|
||||
@group.group_class('market', '/api/v1/market')
|
||||
class MarketRouterGroup(group.RouterGroup):
|
||||
async def initialize(self) -> None:
|
||||
@self.route('/plugins', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _() -> str:
|
||||
"""获取插件市场列表"""
|
||||
# data = await quart.request.json
|
||||
# page = data.get('page', 1)
|
||||
# page_size = data.get('page_size', 10)
|
||||
# query = data.get('query', '')
|
||||
# sort_by = data.get('sort_by', 'stars')
|
||||
# sort_order = data.get('sort_order', 'DESC')
|
||||
|
||||
# # 这里是获取插件列表的实现
|
||||
# # 实际项目中这部分会连接到真实的插件市场API或数据库
|
||||
# # 这里我们只是返回一些假数据作为示例
|
||||
|
||||
# # 模拟延迟
|
||||
# import asyncio
|
||||
|
||||
# await asyncio.sleep(0.5)
|
||||
|
||||
# 返回结果
|
||||
return self.success(data={'plugins': [], 'total': 0})
|
||||
|
||||
@self.route('/mcp', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _() -> str:
|
||||
"""获取MCP服务器市场列表"""
|
||||
data = await quart.request.json
|
||||
page = data.get('page', 1)
|
||||
page_size = data.get('page_size', 10)
|
||||
query = data.get('query', '')
|
||||
sort_by = data.get('sort_by', 'stars')
|
||||
sort_order = data.get('sort_order', 'DESC')
|
||||
|
||||
# 这里是获取MCP服务器列表的实现
|
||||
# 实际项目中这部分会连接到真实的MCP市场API或数据库
|
||||
# 这里我们只是返回一些假数据作为示例
|
||||
|
||||
# 模拟延迟
|
||||
import asyncio
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# 生成假数据
|
||||
servers = []
|
||||
|
||||
# 只在有搜索关键词或排序时才返回数据
|
||||
if query or sort_by:
|
||||
now = datetime.datetime.now().isoformat()
|
||||
yesterday = (datetime.datetime.now() - datetime.timedelta(days=1)).isoformat()
|
||||
|
||||
test_servers = [
|
||||
{
|
||||
'ID': 1,
|
||||
'CreatedAt': yesterday,
|
||||
'UpdatedAt': now,
|
||||
'DeletedAt': None,
|
||||
'name': 'Google Maps MCP',
|
||||
'author': 'langbot-community',
|
||||
'description': 'Google Maps integration for LangBot, providing geocoding and directions capabilities.',
|
||||
'repository': 'langbot-community/google-maps-mcp',
|
||||
'artifacts_path': '',
|
||||
'stars': 124,
|
||||
'downloads': 342,
|
||||
'status': 'initialized',
|
||||
'synced_at': now,
|
||||
'pushed_at': now,
|
||||
'version': '1.0.0',
|
||||
},
|
||||
{
|
||||
'ID': 2,
|
||||
'CreatedAt': yesterday,
|
||||
'UpdatedAt': now,
|
||||
'DeletedAt': None,
|
||||
'name': 'Weather MCP',
|
||||
'author': 'langbot-community',
|
||||
'description': 'Weather integration for LangBot, providing current weather and forecasts.',
|
||||
'repository': 'langbot-community/weather-mcp',
|
||||
'artifacts_path': '',
|
||||
'stars': 85,
|
||||
'downloads': 215,
|
||||
'status': 'initialized',
|
||||
'synced_at': now,
|
||||
'pushed_at': yesterday,
|
||||
'version': '1.1.0',
|
||||
},
|
||||
{
|
||||
'ID': 3,
|
||||
'CreatedAt': yesterday,
|
||||
'UpdatedAt': now,
|
||||
'DeletedAt': None,
|
||||
'name': 'Serper Search MCP',
|
||||
'author': 'langbot-developers',
|
||||
'description': 'Serper Search integration for LangBot, providing advanced web search capabilities.',
|
||||
'repository': 'langbot-developers/serper-search-mcp',
|
||||
'artifacts_path': '',
|
||||
'stars': 67,
|
||||
'downloads': 178,
|
||||
'status': 'initialized',
|
||||
'synced_at': now,
|
||||
'pushed_at': yesterday,
|
||||
'version': '0.9.0',
|
||||
},
|
||||
]
|
||||
|
||||
# 应用搜索过滤
|
||||
if query:
|
||||
query = query.lower()
|
||||
servers = [
|
||||
s
|
||||
for s in test_servers
|
||||
if query in s['name'].lower()
|
||||
or query in s['description'].lower()
|
||||
or query in s['author'].lower()
|
||||
]
|
||||
else:
|
||||
servers = test_servers
|
||||
|
||||
# 应用排序
|
||||
reverse = sort_order.upper() == 'DESC'
|
||||
if sort_by == 'stars':
|
||||
servers = sorted(servers, key=lambda s: s['stars'], reverse=reverse)
|
||||
elif sort_by == 'created_at':
|
||||
servers = sorted(servers, key=lambda s: s['CreatedAt'], reverse=reverse)
|
||||
elif sort_by == 'pushed_at':
|
||||
servers = sorted(servers, key=lambda s: s['pushed_at'], reverse=reverse)
|
||||
|
||||
# 应用分页
|
||||
start_idx = (page - 1) * page_size
|
||||
end_idx = start_idx + page_size
|
||||
servers = servers[start_idx:end_idx]
|
||||
|
||||
# 返回结果
|
||||
return self.success(data={'servers': servers, 'total': len(servers)})
|
||||
@@ -1,214 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import time
|
||||
import traceback
|
||||
import uuid
|
||||
|
||||
import quart
|
||||
import asyncio
|
||||
|
||||
import sqlalchemy
|
||||
|
||||
from pkg.entity.persistence.mcp import MCPServer
|
||||
|
||||
from .....core import taskmgr
|
||||
from .. import group
|
||||
|
||||
from sqlalchemy import insert
|
||||
|
||||
@group.group_class('mcp', '/api/v1/mcp')
|
||||
class MCPRouterGroup(group.RouterGroup):
|
||||
async def initialize(self) -> None:
|
||||
@self.route('/servers', methods=['GET', 'POST'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _() -> str:
|
||||
"""获取MCP服务器列表"""
|
||||
if quart.request.method == 'GET':
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(MCPServer).order_by(MCPServer.created_at.desc())
|
||||
)
|
||||
raw_results = result.all()
|
||||
servers = [self.ap.persistence_mgr.serialize_model(MCPServer, row) for row in raw_results]
|
||||
|
||||
servers_with_status = []
|
||||
# 获取MCP工具加载器
|
||||
mcp_loader = None
|
||||
for loader in self.ap.tool_mgr.loaders:
|
||||
if loader.__class__.__name__ == 'MCPLoader':
|
||||
mcp_loader = loader
|
||||
break
|
||||
|
||||
for server in servers:
|
||||
# 设置状态
|
||||
if server['enable']:
|
||||
status = 'enabled'
|
||||
else:
|
||||
status = 'disabled'
|
||||
|
||||
# 构建 config 对象 (前端期望的格式)
|
||||
extra_args = server.get('extra_args', {})
|
||||
config = {
|
||||
'name': server['name'],
|
||||
'mode': server['mode'],
|
||||
'enable': server['enable'],
|
||||
}
|
||||
|
||||
# 根据模式添加相应的配置
|
||||
if server['mode'] == 'sse':
|
||||
config['url'] = extra_args.get('url', '')
|
||||
config['headers'] = extra_args.get('headers', {})
|
||||
config['timeout'] = extra_args.get('timeout', 60)
|
||||
elif server['mode'] == 'stdio':
|
||||
config['command'] = extra_args.get('command', '')
|
||||
config['args'] = extra_args.get('args', [])
|
||||
config['env'] = extra_args.get('env', {})
|
||||
|
||||
# 从运行中的会话获取工具数量
|
||||
tools_count = 0
|
||||
if mcp_loader and hasattr(mcp_loader, 'sessions') and server['name'] in mcp_loader.sessions:
|
||||
session = mcp_loader.sessions[server['name']]
|
||||
tools_count = len(session.functions)
|
||||
|
||||
server_info = {
|
||||
'name': server['name'],
|
||||
'mode': server['mode'],
|
||||
'enable': server['enable'],
|
||||
'status': status,
|
||||
'tools': tools_count, # 从运行中的会话获取工具数量
|
||||
'config': config,
|
||||
}
|
||||
servers_with_status.append(server_info)
|
||||
|
||||
return self.success(data={'servers': servers_with_status})
|
||||
|
||||
elif quart.request.method == 'POST':
|
||||
data = await quart.request.json
|
||||
data = data['source']
|
||||
try:
|
||||
# 检查服务器名称是否重复
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(MCPServer).where(MCPServer.name == data['name'])
|
||||
)
|
||||
if result.first() is not None:
|
||||
return self.http_status(400, -1, 'Server name already exists')
|
||||
|
||||
# 创建新服务器配置
|
||||
new_server = {
|
||||
'uuid': str(uuid.uuid4()),
|
||||
'name': data['name'],
|
||||
'mode': 'sse',
|
||||
'enable': data.get('enable', False),
|
||||
'extra_args': {
|
||||
'url':data.get('url',''),
|
||||
'headers':data.get('headers',{}),
|
||||
'timeout':data.get('timeout',60),
|
||||
'ssereadtimeout':data.get('ssereadtimeout',300),
|
||||
},
|
||||
}
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.insert(MCPServer).values(new_server)
|
||||
)
|
||||
|
||||
return self.success()
|
||||
|
||||
except Exception:
|
||||
print(traceback.format_exc())
|
||||
|
||||
@self.route('/servers/<server_name>', methods=['GET', 'PUT', 'DELETE'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _(server_name: str) -> str:
|
||||
"""获取、更新或删除MCP服务器配置"""
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(MCPServer).where(MCPServer.name == server_name)
|
||||
)
|
||||
server = result.first()
|
||||
if server is None:
|
||||
return self.http_status(404, -1, 'Server not found')
|
||||
|
||||
if quart.request.method == 'GET':
|
||||
server_data = self.ap.persistence_mgr.serialize_model(MCPServer, server)
|
||||
return self.success(data={'server': server_data})
|
||||
|
||||
elif quart.request.method == 'PUT':
|
||||
data = await quart.request.json
|
||||
update_data = {
|
||||
'enable': data.get('enable', server.enable),
|
||||
}
|
||||
|
||||
extra_args = server.extra_args or {}
|
||||
if server.mode == 'sse':
|
||||
extra_args.update({
|
||||
'url': data.get('url', extra_args.get('url','')),
|
||||
'headers': data.get('headers', extra_args.get('headers',{})),
|
||||
'timeout': data.get('timeout', extra_args.get('timeout',60)),
|
||||
'ssereadtimeout': data.get('ssereadtimeout', extra_args.get('ssereadtimeout',300)),
|
||||
})
|
||||
update_data['extra_args'] = extra_args
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(MCPServer).where(MCPServer.name == server_name).values(update_data)
|
||||
)
|
||||
|
||||
return self.success()
|
||||
|
||||
elif quart.request.method == 'DELETE':
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.delete(MCPServer).where(MCPServer.name == server_name)
|
||||
)
|
||||
return self.success()
|
||||
|
||||
@self.route('/servers/<server_name>/test', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _(server_name: str) -> str:
|
||||
"""测试MCP服务器连接"""
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(MCPServer).where(MCPServer.name == server_name)
|
||||
)
|
||||
server = result.first()
|
||||
if server is None:
|
||||
return self.http_status(404, -1, 'Server not found')
|
||||
|
||||
# 创建测试任务
|
||||
ctx = taskmgr.TaskContext.new()
|
||||
wrapper = self.ap.task_mgr.create_user_task(
|
||||
self._test_mcp_server(server, ctx),
|
||||
kind='mcp-operation',
|
||||
name=f'mcp-test-{server_name}',
|
||||
label=f'Testing MCP server {server_name}',
|
||||
context=ctx,
|
||||
)
|
||||
return self.success(data={'task_id': wrapper.id})
|
||||
|
||||
async def _test_mcp_server(self, server: MCPServer, ctx: taskmgr.TaskContext):
|
||||
"""测试MCP服务器连接"""
|
||||
try:
|
||||
from .....provider.tools.loaders.mcp import RuntimeMCPSession
|
||||
|
||||
ctx.current_action = f'Testing connection to {server.name}'
|
||||
# 创建临时会话进行测试
|
||||
session = RuntimeMCPSession(server.name, {
|
||||
'name': server.name,
|
||||
'mode': server.mode,
|
||||
'enable': server.enable,
|
||||
'url': server.extra_args.get('url',''),
|
||||
'headers': server.extra_args.get('headers',{}),
|
||||
'timeout': server.extra_args.get('timeout',60),
|
||||
},enable=True, ap=self.ap)
|
||||
await session.start()
|
||||
|
||||
# 获取工具列表作为测试
|
||||
tools_count = len(session.functions)
|
||||
|
||||
tool_name_list = []
|
||||
for function in session.functions:
|
||||
tool_name_list.append(function.name)
|
||||
ctx.current_action = f'Successfully connected. Found {tools_count} tools.'
|
||||
|
||||
# 关闭测试会话
|
||||
await session.shutdown()
|
||||
|
||||
return {'status': 'success', 'tools_count': tools_count,'tools_names_lists':tool_name_list}
|
||||
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
ctx.current_action = f'Connection test failed: {str(e)}'
|
||||
raise e
|
||||
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import quart
|
||||
import asyncio
|
||||
|
||||
from ......core import taskmgr
|
||||
|
||||
from ... import group
|
||||
|
||||
|
||||
@@ -14,342 +13,107 @@ class MCPRouterGroup(group.RouterGroup):
|
||||
async def _() -> str:
|
||||
"""获取MCP服务器列表"""
|
||||
if quart.request.method == 'GET':
|
||||
if not self.ap or not self.ap.provider_cfg or not self.ap.provider_cfg.data:
|
||||
return self.success(data={'servers': []})
|
||||
|
||||
servers = self.ap.provider_cfg.data.get('mcp', {}).get('servers', [])
|
||||
|
||||
# 获取每个服务器的状态和工具信息
|
||||
mcp_loader = None
|
||||
for loader_name, loader in self.ap.tool_mgr.loaders.items():
|
||||
if loader_name == 'mcp':
|
||||
mcp_loader = loader
|
||||
break
|
||||
servers = await self.ap.mcp_service.get_mcp_servers()
|
||||
|
||||
servers_with_status = []
|
||||
# 获取MCP工具加载器
|
||||
mcp_loader = self.ap.tool_mgr.mcp_tool_loader
|
||||
|
||||
for server in servers:
|
||||
# 从运行中的会话获取工具数量
|
||||
tools_count = 0
|
||||
if mcp_loader:
|
||||
session = mcp_loader.sessions.get(server['name'])
|
||||
if session:
|
||||
tools_count = len(session.functions)
|
||||
|
||||
server_info = {
|
||||
'name': server['name'],
|
||||
'mode': server['mode'],
|
||||
'enable': server['enable'],
|
||||
'config': server,
|
||||
'status': 'disconnected',
|
||||
'tools': [],
|
||||
'error': None,
|
||||
**server,
|
||||
'tools': tools_count,
|
||||
}
|
||||
|
||||
# 检查服务器连接状态
|
||||
if mcp_loader and server['name'] in mcp_loader.sessions:
|
||||
session = mcp_loader.sessions[server['name']]
|
||||
server_info['status'] = 'connected'
|
||||
server_info['tools'] = [
|
||||
{'name': func.name, 'description': func.description, 'parameters': func.parameters}
|
||||
for func in session.functions
|
||||
]
|
||||
elif server['enable']:
|
||||
server_info['status'] = 'error'
|
||||
server_info['error'] = 'Failed to connect'
|
||||
|
||||
servers_with_status.append(server_info)
|
||||
|
||||
return self.success(data={'servers': servers_with_status})
|
||||
|
||||
elif quart.request.method == 'POST':
|
||||
data = await quart.request.json
|
||||
data = data['source']
|
||||
|
||||
# 验证必填字段
|
||||
required_fields = ['name', 'mode']
|
||||
for field in required_fields:
|
||||
if field not in data:
|
||||
return self.http_status(400, -1, f'Missing required field: {field}')
|
||||
uuid = await self.ap.mcp_service.create_mcp_server(data)
|
||||
|
||||
# 检查provider_cfg是否可用
|
||||
if not self.ap or not self.ap.provider_cfg or not self.ap.provider_cfg.data:
|
||||
return self.http_status(500, -1, 'Provider configuration not available')
|
||||
|
||||
# 获取当前配置
|
||||
mcp_config = self.ap.provider_cfg.data.get('mcp', {'servers': []})
|
||||
servers = mcp_config['servers']
|
||||
|
||||
# 检查服务器名称是否重复
|
||||
for server in servers:
|
||||
if server['name'] == data['name']:
|
||||
return self.http_status(400, -1, 'Server name already exists')
|
||||
|
||||
# 创建新服务器配置
|
||||
new_server = {
|
||||
'name': data['name'],
|
||||
'mode': data['mode'],
|
||||
'enable': data.get('enable', True),
|
||||
}
|
||||
|
||||
# 根据模式添加配置
|
||||
if data['mode'] == 'stdio':
|
||||
new_server.update(
|
||||
{'command': data.get('command', ''), 'args': data.get('args', []), 'env': data.get('env', {})}
|
||||
)
|
||||
elif data['mode'] == 'sse':
|
||||
new_server.update(
|
||||
{
|
||||
'url': data.get('url', ''),
|
||||
'headers': data.get('headers', {}),
|
||||
'timeout': data.get('timeout', 10),
|
||||
}
|
||||
)
|
||||
|
||||
# 添加到配置
|
||||
servers.append(new_server)
|
||||
self.ap.provider_cfg.data['mcp'] = mcp_config
|
||||
|
||||
# 保存配置
|
||||
await self.ap.provider_cfg.dump_config()
|
||||
|
||||
# 如果启用,尝试重新加载MCP loader
|
||||
if new_server['enable']:
|
||||
ctx = taskmgr.TaskContext.new()
|
||||
wrapper = self.ap.task_mgr.create_user_task(
|
||||
self._reload_mcp_loader(ctx),
|
||||
kind='mcp-operation',
|
||||
name=f'mcp-reload-{new_server["name"]}',
|
||||
label=f'Reloading MCP loader for {new_server["name"]}',
|
||||
context=ctx,
|
||||
)
|
||||
return self.success(data={'task_id': wrapper.id})
|
||||
else:
|
||||
return self.success()
|
||||
else:
|
||||
return self.http_status(405, -1, 'Method not allowed')
|
||||
return self.success(data={'uuid': uuid})
|
||||
|
||||
@self.route('/servers/<server_name>', methods=['GET', 'PUT', 'DELETE'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _(server_name: str) -> str:
|
||||
"""获取、更新或删除MCP服务器配置"""
|
||||
if not self.ap or not self.ap.provider_cfg or not self.ap.provider_cfg.data:
|
||||
return self.http_status(500, -1, 'Provider configuration not available')
|
||||
|
||||
mcp_config = self.ap.provider_cfg.data.get('mcp', {'servers': []})
|
||||
servers = mcp_config['servers']
|
||||
|
||||
# 查找服务器
|
||||
server_index = None
|
||||
for i, server in enumerate(servers):
|
||||
if server['name'] == server_name:
|
||||
server_index = i
|
||||
break
|
||||
|
||||
if server_index is None:
|
||||
server_data = await self.ap.mcp_service.get_mcp_server_by_name(server_name)
|
||||
if server_data is None:
|
||||
return self.http_status(404, -1, 'Server not found')
|
||||
|
||||
if quart.request.method == 'GET':
|
||||
return self.success(data={'server': servers[server_index]})
|
||||
return self.success(data={'server': server_data})
|
||||
|
||||
elif quart.request.method == 'PUT':
|
||||
data = await quart.request.json
|
||||
server = servers[server_index]
|
||||
|
||||
# 更新配置
|
||||
server.update(
|
||||
{
|
||||
'enable': data.get('enable', server.get('enable', True)),
|
||||
}
|
||||
)
|
||||
|
||||
# 根据模式更新特定配置
|
||||
if server['mode'] == 'stdio':
|
||||
server.update(
|
||||
{
|
||||
'command': data.get('command', server.get('command', '')),
|
||||
'args': data.get('args', server.get('args', [])),
|
||||
'env': data.get('env', server.get('env', {})),
|
||||
}
|
||||
)
|
||||
elif server['mode'] == 'sse':
|
||||
server.update(
|
||||
{
|
||||
'url': data.get('url', server.get('url', '')),
|
||||
'headers': data.get('headers', server.get('headers', {})),
|
||||
'timeout': data.get('timeout', server.get('timeout', 10)),
|
||||
}
|
||||
)
|
||||
|
||||
# 保存配置
|
||||
await self.ap.provider_cfg.dump_config()
|
||||
|
||||
# 重新加载MCP loader
|
||||
ctx = taskmgr.TaskContext.new()
|
||||
wrapper = self.ap.task_mgr.create_user_task(
|
||||
self._reload_mcp_loader(ctx),
|
||||
kind='mcp-operation',
|
||||
name=f'mcp-reload-{server_name}',
|
||||
label=f'Reloading MCP loader for {server_name}',
|
||||
context=ctx,
|
||||
)
|
||||
return self.success(data={'task_id': wrapper.id})
|
||||
await self.ap.mcp_service.update_mcp_server(server_data['uuid'], data)
|
||||
return self.success()
|
||||
|
||||
elif quart.request.method == 'DELETE':
|
||||
# 删除服务器
|
||||
servers.pop(server_index)
|
||||
self.ap.provider_cfg.data['mcp'] = mcp_config
|
||||
|
||||
# 保存配置
|
||||
await self.ap.provider_cfg.dump_config()
|
||||
|
||||
# 重新加载MCP loader
|
||||
ctx = taskmgr.TaskContext.new()
|
||||
wrapper = self.ap.task_mgr.create_user_task(
|
||||
self._reload_mcp_loader(ctx),
|
||||
kind='mcp-operation',
|
||||
name=f'mcp-remove-{server_name}',
|
||||
label=f'Removing MCP server {server_name}',
|
||||
context=ctx,
|
||||
)
|
||||
return self.success(data={'task_id': wrapper.id})
|
||||
|
||||
@self.route('/servers/<server_name>/toggle', methods=['PUT'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _(server_name: str) -> str:
|
||||
"""切换MCP服务器启用状态"""
|
||||
data = await quart.request.json
|
||||
target_enabled = data.get('target_enabled')
|
||||
|
||||
if not self.ap or not self.ap.provider_cfg or not self.ap.provider_cfg.data:
|
||||
return self.http_status(500, -1, 'Provider configuration not available')
|
||||
|
||||
mcp_config = self.ap.provider_cfg.data.get('mcp', {'servers': []})
|
||||
servers = mcp_config['servers']
|
||||
|
||||
# 查找并更新服务器
|
||||
for server in servers:
|
||||
if server['name'] == server_name:
|
||||
server['enable'] = target_enabled
|
||||
break
|
||||
else:
|
||||
return self.http_status(404, -1, 'Server not found')
|
||||
|
||||
# 保存配置
|
||||
await self.ap.provider_cfg.dump_config()
|
||||
|
||||
# 重新加载MCP loader
|
||||
ctx = taskmgr.TaskContext.new()
|
||||
wrapper = self.ap.task_mgr.create_user_task(
|
||||
self._reload_mcp_loader(ctx),
|
||||
kind='mcp-operation',
|
||||
name=f'mcp-toggle-{server_name}',
|
||||
label=f'Toggling MCP server {server_name}',
|
||||
context=ctx,
|
||||
)
|
||||
return self.success(data={'task_id': wrapper.id})
|
||||
await self.ap.mcp_service.delete_mcp_server(server_data['uuid'])
|
||||
return self.success()
|
||||
|
||||
@self.route('/servers/<server_name>/test', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _(server_name: str) -> str:
|
||||
"""测试MCP服务器连接"""
|
||||
if not self.ap or not self.ap.provider_cfg or not self.ap.provider_cfg.data:
|
||||
return self.http_status(500, -1, 'Provider configuration not available')
|
||||
|
||||
mcp_config = self.ap.provider_cfg.data.get('mcp', {'servers': []})
|
||||
servers = mcp_config['servers']
|
||||
|
||||
# 查找服务器配置
|
||||
server_config = None
|
||||
for server in servers:
|
||||
if server['name'] == server_name:
|
||||
server_config = server
|
||||
break
|
||||
|
||||
if server_config is None:
|
||||
server_data = await self.ap.mcp_service.get_mcp_server_by_name(server_name)
|
||||
if server_data is None:
|
||||
return self.http_status(404, -1, 'Server not found')
|
||||
|
||||
# 创建测试任务
|
||||
ctx = taskmgr.TaskContext.new()
|
||||
wrapper = self.ap.task_mgr.create_user_task(
|
||||
self._test_mcp_server(server_config, ctx),
|
||||
kind='mcp-operation',
|
||||
name=f'mcp-test-{server_name}',
|
||||
label=f'Testing MCP server {server_name}',
|
||||
context=ctx,
|
||||
)
|
||||
return self.success(data={'task_id': wrapper.id})
|
||||
|
||||
@self.route('/install/github', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _() -> str:
|
||||
"""从GitHub安装MCP服务器"""
|
||||
data = await quart.request.json
|
||||
source = data.get('source')
|
||||
# TODO 这里移到service去
|
||||
# # 创建测试任务
|
||||
# ctx = taskmgr.TaskContext.new()
|
||||
# wrapper = self.ap.task_mgr.create_user_task(
|
||||
# self._test_mcp_server(server, ctx),
|
||||
# kind='mcp-operation',
|
||||
# name=f'mcp-test-{server_name}',
|
||||
# label=f'Testing MCP server {server_name}',
|
||||
# context=ctx,
|
||||
# )
|
||||
# return self.success(data={'task_id': wrapper.id})
|
||||
|
||||
if not source:
|
||||
return self.http_status(400, -1, 'Missing source parameter')
|
||||
# async def _test_mcp_server(self, server: persistence_mcp.MCPServer, ctx: taskmgr.TaskContext):
|
||||
# """测试MCP服务器连接"""
|
||||
# try:
|
||||
|
||||
# 创建安装任务
|
||||
ctx = taskmgr.TaskContext.new()
|
||||
wrapper = self.ap.task_mgr.create_user_task(
|
||||
self._install_mcp_from_github(source, ctx),
|
||||
kind='mcp-operation',
|
||||
name='install-mcp-github',
|
||||
label=f'Installing MCP from GitHub: {source}',
|
||||
context=ctx,
|
||||
)
|
||||
return self.success(data={'task_id': wrapper.id})
|
||||
# ctx.current_action = f'Testing connection to {server.name}'
|
||||
# # 创建临时会话进行测试
|
||||
# session = RuntimeMCPSession(server.name, {
|
||||
# 'name': server.name,
|
||||
# 'mode': server.mode,
|
||||
# 'enable': server.enable,
|
||||
# 'url': server.extra_args.get('url',''),
|
||||
# 'headers': server.extra_args.get('headers',{}),
|
||||
# 'timeout': server.extra_args.get('timeout',60),
|
||||
# },enable=True, ap=self.ap)
|
||||
# await session.start()
|
||||
|
||||
async def _reload_mcp_loader(self, ctx: taskmgr.TaskContext):
|
||||
"""重新加载MCP loader"""
|
||||
try:
|
||||
ctx.current_action = 'Stopping existing MCP sessions'
|
||||
# 停止现有的MCP会话
|
||||
mcp_loader = None
|
||||
for loader_name, loader in self.ap.tool_mgr.loaders.items():
|
||||
if loader_name == 'mcp':
|
||||
mcp_loader = loader
|
||||
break
|
||||
# # 获取工具列表作为测试
|
||||
# tools_count = len(session.functions)
|
||||
|
||||
if mcp_loader:
|
||||
await mcp_loader.shutdown()
|
||||
# tool_name_list = []
|
||||
# for function in session.functions:
|
||||
# tool_name_list.append(function.name)
|
||||
# ctx.current_action = f'Successfully connected. Found {tools_count} tools.'
|
||||
|
||||
ctx.current_action = 'Reloading MCP configuration'
|
||||
# 重新加载MCP loader
|
||||
await self.ap.tool_mgr.reload_loader('mcp')
|
||||
# # 关闭测试会话
|
||||
# await session.shutdown()
|
||||
|
||||
ctx.current_action = 'MCP loader reloaded successfully'
|
||||
# return {'status': 'success', 'tools_count': tools_count,'tools_names_lists':tool_name_list}
|
||||
|
||||
except Exception as e:
|
||||
ctx.current_action = f'Failed to reload MCP loader: {str(e)}'
|
||||
raise e
|
||||
|
||||
async def _test_mcp_server(self, server_config: dict, ctx: taskmgr.TaskContext):
|
||||
"""测试MCP服务器连接"""
|
||||
try:
|
||||
from ......provider.tools.loaders.mcp import RuntimeMCPSession
|
||||
|
||||
ctx.current_action = f'Testing connection to {server_config["name"]}'
|
||||
|
||||
# 创建临时会话进行测试
|
||||
session = RuntimeMCPSession(server_config['name'], server_config, self.ap)
|
||||
await session.initialize()
|
||||
|
||||
# 获取工具列表作为测试
|
||||
tools_count = len(session.functions)
|
||||
ctx.current_action = f'Successfully connected. Found {tools_count} tools.'
|
||||
|
||||
# 关闭测试会话
|
||||
await session.shutdown()
|
||||
|
||||
return {'status': 'success', 'tools_count': tools_count}
|
||||
|
||||
except Exception as e:
|
||||
ctx.current_action = f'Connection test failed: {str(e)}'
|
||||
raise e
|
||||
|
||||
async def _install_mcp_from_github(self, source: str, ctx: taskmgr.TaskContext):
|
||||
"""从GitHub安装MCP服务器的实现"""
|
||||
try:
|
||||
ctx.current_action = f'Installing MCP server from {source}'
|
||||
|
||||
# 这里是安装逻辑的占位符
|
||||
# 实际实现将包括克隆仓库、解析配置、安装依赖等步骤
|
||||
|
||||
# 模拟安装过程
|
||||
|
||||
await asyncio.sleep(2) # 模拟安装过程
|
||||
|
||||
# 返回成功结果
|
||||
return {'status': 'success', 'message': f'Successfully installed MCP server from {source}'}
|
||||
|
||||
except Exception as e:
|
||||
ctx.current_action = f'Failed to install MCP server: {str(e)}'
|
||||
raise e
|
||||
# except Exception as e:
|
||||
# print(traceback.format_exc())
|
||||
# ctx.current_action = f'Connection test failed: {str(e)}'
|
||||
# raise e
|
||||
|
||||
@@ -15,12 +15,14 @@ from .groups import provider as groups_provider
|
||||
from .groups import platform as groups_platform
|
||||
from .groups import pipelines as groups_pipelines
|
||||
from .groups import knowledge as groups_knowledge
|
||||
from .groups import resources as groups_resources
|
||||
|
||||
importutil.import_modules_in_pkg(groups)
|
||||
importutil.import_modules_in_pkg(groups_provider)
|
||||
importutil.import_modules_in_pkg(groups_platform)
|
||||
importutil.import_modules_in_pkg(groups_pipelines)
|
||||
importutil.import_modules_in_pkg(groups_knowledge)
|
||||
importutil.import_modules_in_pkg(groups_resources)
|
||||
|
||||
|
||||
class HTTPController:
|
||||
|
||||
63
pkg/api/http/service/mcp.py
Normal file
63
pkg/api/http/service/mcp.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlalchemy
|
||||
import uuid
|
||||
|
||||
from ....core import app
|
||||
from ....entity.persistence import mcp as persistence_mcp
|
||||
|
||||
|
||||
class MCPService:
|
||||
ap: app.Application
|
||||
|
||||
def __init__(self, ap: app.Application) -> None:
|
||||
self.ap = ap
|
||||
|
||||
async def get_mcp_servers(self) -> list[dict]:
|
||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_mcp.MCPServer))
|
||||
|
||||
servers = result.all()
|
||||
return [self.ap.persistence_mgr.serialize_model(persistence_mcp.MCPServer, server) for server in servers]
|
||||
|
||||
async def create_mcp_server(self, server_data: dict) -> str:
|
||||
server_data['uuid'] = str(uuid.uuid4())
|
||||
await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_mcp.MCPServer).values(server_data))
|
||||
server = await self.get_mcp_server(server_data['uuid'])
|
||||
|
||||
# TODO: load runtime mcp server session
|
||||
|
||||
return server['uuid']
|
||||
|
||||
async def get_mcp_server(self, server_uuid: str) -> dict | None:
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_mcp.MCPServer).where(persistence_mcp.MCPServer.uuid == server_uuid)
|
||||
)
|
||||
server = result.first()
|
||||
if server is None:
|
||||
return None
|
||||
return self.ap.persistence_mgr.serialize_model(persistence_mcp.MCPServer, server)
|
||||
|
||||
async def get_mcp_server_by_name(self, server_name: str) -> dict | None:
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_mcp.MCPServer).where(persistence_mcp.MCPServer.name == server_name)
|
||||
)
|
||||
server = result.first()
|
||||
if server is None:
|
||||
return None
|
||||
return self.ap.persistence_mgr.serialize_model(persistence_mcp.MCPServer, server)
|
||||
|
||||
async def update_mcp_server(self, server_uuid: str, server_data: dict) -> None:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(persistence_mcp.MCPServer)
|
||||
.where(persistence_mcp.MCPServer.uuid == server_uuid)
|
||||
.values(server_data)
|
||||
)
|
||||
|
||||
# TODO: reload runtime mcp server session
|
||||
|
||||
async def delete_mcp_server(self, server_uuid: str) -> None:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.delete(persistence_mcp.MCPServer).where(persistence_mcp.MCPServer.uuid == server_uuid)
|
||||
)
|
||||
|
||||
# TODO: remove runtime mcp server session
|
||||
@@ -22,6 +22,7 @@ from ..api.http.service import model as model_service
|
||||
from ..api.http.service import pipeline as pipeline_service
|
||||
from ..api.http.service import bot as bot_service
|
||||
from ..api.http.service import knowledge as knowledge_service
|
||||
from ..api.http.service import mcp as mcp_service
|
||||
from ..discover import engine as discover_engine
|
||||
from ..storage import mgr as storagemgr
|
||||
from ..utils import logcache
|
||||
@@ -119,6 +120,8 @@ class Application:
|
||||
|
||||
knowledge_service: knowledge_service.KnowledgeService = None
|
||||
|
||||
mcp_service: mcp_service.MCPService = None
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ from ...api.http.service import model as model_service
|
||||
from ...api.http.service import pipeline as pipeline_service
|
||||
from ...api.http.service import bot as bot_service
|
||||
from ...api.http.service import knowledge as knowledge_service
|
||||
from ...api.http.service import mcp as mcp_service
|
||||
from ...discover import engine as discover_engine
|
||||
from ...storage import mgr as storagemgr
|
||||
from ...utils import logcache
|
||||
@@ -126,5 +127,8 @@ class BuildAppStage(stage.BootingStage):
|
||||
knowledge_service_inst = knowledge_service.KnowledgeService(ap)
|
||||
ap.knowledge_service = knowledge_service_inst
|
||||
|
||||
mcp_service_inst = mcp_service.MCPService(ap)
|
||||
ap.mcp_service = mcp_service_inst
|
||||
|
||||
ctrl = controller.Controller(ap)
|
||||
ap.ctrl = ctrl
|
||||
|
||||
@@ -3,9 +3,9 @@ from __future__ import annotations
|
||||
import typing
|
||||
|
||||
from ...core import app
|
||||
from . import loader as tools_loader
|
||||
from ...utils import importutil
|
||||
from . import loaders
|
||||
from .loaders import mcp as mcp_loader, plugin as plugin_loader
|
||||
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
||||
|
||||
importutil.import_modules_in_pkg(loaders)
|
||||
@@ -16,25 +16,24 @@ class ToolManager:
|
||||
|
||||
ap: app.Application
|
||||
|
||||
loaders: list[tools_loader.ToolLoader]
|
||||
plugin_tool_loader: plugin_loader.PluginToolLoader
|
||||
mcp_tool_loader: mcp_loader.MCPLoader
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
self.all_functions = []
|
||||
self.loaders = []
|
||||
|
||||
async def initialize(self):
|
||||
for loader_cls in tools_loader.preregistered_loaders:
|
||||
loader_inst = loader_cls(self.ap)
|
||||
await loader_inst.initialize()
|
||||
self.loaders.append(loader_inst)
|
||||
self.plugin_tool_loader = plugin_loader.PluginToolLoader(self.ap)
|
||||
await self.plugin_tool_loader.initialize()
|
||||
self.mcp_tool_loader = mcp_loader.MCPLoader(self.ap)
|
||||
await self.mcp_tool_loader.initialize()
|
||||
|
||||
async def get_all_tools(self) -> list[resource_tool.LLMTool]:
|
||||
"""获取所有函数"""
|
||||
all_functions: list[resource_tool.LLMTool] = []
|
||||
|
||||
for loader in self.loaders:
|
||||
all_functions.extend(await loader.get_tools())
|
||||
all_functions.extend(await self.plugin_tool_loader.get_tools())
|
||||
all_functions.extend(await self.mcp_tool_loader.get_tools())
|
||||
|
||||
return all_functions
|
||||
|
||||
@@ -93,13 +92,14 @@ class ToolManager:
|
||||
async def execute_func_call(self, name: str, parameters: dict) -> typing.Any:
|
||||
"""执行函数调用"""
|
||||
|
||||
for loader in self.loaders:
|
||||
if await loader.has_tool(name):
|
||||
return await loader.invoke_tool(name, parameters)
|
||||
if await self.plugin_tool_loader.has_tool(name):
|
||||
return await self.plugin_tool_loader.invoke_tool(name, parameters)
|
||||
elif await self.mcp_tool_loader.has_tool(name):
|
||||
return await self.mcp_tool_loader.invoke_tool(name, parameters)
|
||||
else:
|
||||
raise ValueError(f'未找到工具: {name}')
|
||||
|
||||
async def shutdown(self):
|
||||
"""关闭所有工具"""
|
||||
for loader in self.loaders:
|
||||
await loader.shutdown()
|
||||
await self.plugin_tool_loader.shutdown()
|
||||
await self.mcp_tool_loader.shutdown()
|
||||
|
||||
Reference in New Issue
Block a user