chore: stash

This commit is contained in:
Junyan Qin
2025-10-11 19:10:56 +08:00
parent a3552893aa
commit 0f39a31648
4 changed files with 447 additions and 12 deletions

View File

@@ -0,0 +1,355 @@
from __future__ import annotations
import quart
import asyncio
from ......core import taskmgr
from ... import group
@group.group_class('mcp', '/api/v1/mcp')
class MCPRouterGroup(group.RouterGroup):
async def initialize(self) -> None:
@self.route('/servers', methods=['GET', 'POST'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str:
"""获取MCP服务器列表"""
if quart.request.method == 'GET':
if not self.ap or not self.ap.provider_cfg or not self.ap.provider_cfg.data:
return self.success(data={'servers': []})
servers = self.ap.provider_cfg.data.get('mcp', {}).get('servers', [])
# 获取每个服务器的状态和工具信息
mcp_loader = None
for loader_name, loader in self.ap.tool_mgr.loaders.items():
if loader_name == 'mcp':
mcp_loader = loader
break
servers_with_status = []
for server in servers:
server_info = {
'name': server['name'],
'mode': server['mode'],
'enable': server['enable'],
'config': server,
'status': 'disconnected',
'tools': [],
'error': None,
}
# 检查服务器连接状态
if mcp_loader and server['name'] in mcp_loader.sessions:
session = mcp_loader.sessions[server['name']]
server_info['status'] = 'connected'
server_info['tools'] = [
{'name': func.name, 'description': func.description, 'parameters': func.parameters}
for func in session.functions
]
elif server['enable']:
server_info['status'] = 'error'
server_info['error'] = 'Failed to connect'
servers_with_status.append(server_info)
return self.success(data={'servers': servers_with_status})
elif quart.request.method == 'POST':
data = await quart.request.json
# 验证必填字段
required_fields = ['name', 'mode']
for field in required_fields:
if field not in data:
return self.http_status(400, -1, f'Missing required field: {field}')
# 检查provider_cfg是否可用
if not self.ap or not self.ap.provider_cfg or not self.ap.provider_cfg.data:
return self.http_status(500, -1, 'Provider configuration not available')
# 获取当前配置
mcp_config = self.ap.provider_cfg.data.get('mcp', {'servers': []})
servers = mcp_config['servers']
# 检查服务器名称是否重复
for server in servers:
if server['name'] == data['name']:
return self.http_status(400, -1, 'Server name already exists')
# 创建新服务器配置
new_server = {
'name': data['name'],
'mode': data['mode'],
'enable': data.get('enable', True),
}
# 根据模式添加配置
if data['mode'] == 'stdio':
new_server.update(
{'command': data.get('command', ''), 'args': data.get('args', []), 'env': data.get('env', {})}
)
elif data['mode'] == 'sse':
new_server.update(
{
'url': data.get('url', ''),
'headers': data.get('headers', {}),
'timeout': data.get('timeout', 10),
}
)
# 添加到配置
servers.append(new_server)
self.ap.provider_cfg.data['mcp'] = mcp_config
# 保存配置
await self.ap.provider_cfg.dump_config()
# 如果启用尝试重新加载MCP loader
if new_server['enable']:
ctx = taskmgr.TaskContext.new()
wrapper = self.ap.task_mgr.create_user_task(
self._reload_mcp_loader(ctx),
kind='mcp-operation',
name=f'mcp-reload-{new_server["name"]}',
label=f'Reloading MCP loader for {new_server["name"]}',
context=ctx,
)
return self.success(data={'task_id': wrapper.id})
else:
return self.success()
else:
return self.http_status(405, -1, 'Method not allowed')
@self.route('/servers/<server_name>', methods=['GET', 'PUT', 'DELETE'], auth_type=group.AuthType.USER_TOKEN)
async def _(server_name: str) -> str:
"""获取、更新或删除MCP服务器配置"""
if not self.ap or not self.ap.provider_cfg or not self.ap.provider_cfg.data:
return self.http_status(500, -1, 'Provider configuration not available')
mcp_config = self.ap.provider_cfg.data.get('mcp', {'servers': []})
servers = mcp_config['servers']
# 查找服务器
server_index = None
for i, server in enumerate(servers):
if server['name'] == server_name:
server_index = i
break
if server_index is None:
return self.http_status(404, -1, 'Server not found')
if quart.request.method == 'GET':
return self.success(data={'server': servers[server_index]})
elif quart.request.method == 'PUT':
data = await quart.request.json
server = servers[server_index]
# 更新配置
server.update(
{
'enable': data.get('enable', server.get('enable', True)),
}
)
# 根据模式更新特定配置
if server['mode'] == 'stdio':
server.update(
{
'command': data.get('command', server.get('command', '')),
'args': data.get('args', server.get('args', [])),
'env': data.get('env', server.get('env', {})),
}
)
elif server['mode'] == 'sse':
server.update(
{
'url': data.get('url', server.get('url', '')),
'headers': data.get('headers', server.get('headers', {})),
'timeout': data.get('timeout', server.get('timeout', 10)),
}
)
# 保存配置
await self.ap.provider_cfg.dump_config()
# 重新加载MCP loader
ctx = taskmgr.TaskContext.new()
wrapper = self.ap.task_mgr.create_user_task(
self._reload_mcp_loader(ctx),
kind='mcp-operation',
name=f'mcp-reload-{server_name}',
label=f'Reloading MCP loader for {server_name}',
context=ctx,
)
return self.success(data={'task_id': wrapper.id})
elif quart.request.method == 'DELETE':
# 删除服务器
servers.pop(server_index)
self.ap.provider_cfg.data['mcp'] = mcp_config
# 保存配置
await self.ap.provider_cfg.dump_config()
# 重新加载MCP loader
ctx = taskmgr.TaskContext.new()
wrapper = self.ap.task_mgr.create_user_task(
self._reload_mcp_loader(ctx),
kind='mcp-operation',
name=f'mcp-remove-{server_name}',
label=f'Removing MCP server {server_name}',
context=ctx,
)
return self.success(data={'task_id': wrapper.id})
@self.route('/servers/<server_name>/toggle', methods=['PUT'], auth_type=group.AuthType.USER_TOKEN)
async def _(server_name: str) -> str:
"""切换MCP服务器启用状态"""
data = await quart.request.json
target_enabled = data.get('target_enabled')
if not self.ap or not self.ap.provider_cfg or not self.ap.provider_cfg.data:
return self.http_status(500, -1, 'Provider configuration not available')
mcp_config = self.ap.provider_cfg.data.get('mcp', {'servers': []})
servers = mcp_config['servers']
# 查找并更新服务器
for server in servers:
if server['name'] == server_name:
server['enable'] = target_enabled
break
else:
return self.http_status(404, -1, 'Server not found')
# 保存配置
await self.ap.provider_cfg.dump_config()
# 重新加载MCP loader
ctx = taskmgr.TaskContext.new()
wrapper = self.ap.task_mgr.create_user_task(
self._reload_mcp_loader(ctx),
kind='mcp-operation',
name=f'mcp-toggle-{server_name}',
label=f'Toggling MCP server {server_name}',
context=ctx,
)
return self.success(data={'task_id': wrapper.id})
@self.route('/servers/<server_name>/test', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
async def _(server_name: str) -> str:
"""测试MCP服务器连接"""
if not self.ap or not self.ap.provider_cfg or not self.ap.provider_cfg.data:
return self.http_status(500, -1, 'Provider configuration not available')
mcp_config = self.ap.provider_cfg.data.get('mcp', {'servers': []})
servers = mcp_config['servers']
# 查找服务器配置
server_config = None
for server in servers:
if server['name'] == server_name:
server_config = server
break
if server_config is None:
return self.http_status(404, -1, 'Server not found')
# 创建测试任务
ctx = taskmgr.TaskContext.new()
wrapper = self.ap.task_mgr.create_user_task(
self._test_mcp_server(server_config, ctx),
kind='mcp-operation',
name=f'mcp-test-{server_name}',
label=f'Testing MCP server {server_name}',
context=ctx,
)
return self.success(data={'task_id': wrapper.id})
@self.route('/install/github', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str:
"""从GitHub安装MCP服务器"""
data = await quart.request.json
source = data.get('source')
if not source:
return self.http_status(400, -1, 'Missing source parameter')
# 创建安装任务
ctx = taskmgr.TaskContext.new()
wrapper = self.ap.task_mgr.create_user_task(
self._install_mcp_from_github(source, ctx),
kind='mcp-operation',
name='install-mcp-github',
label=f'Installing MCP from GitHub: {source}',
context=ctx,
)
return self.success(data={'task_id': wrapper.id})
async def _reload_mcp_loader(self, ctx: taskmgr.TaskContext):
"""重新加载MCP loader"""
try:
ctx.current_action = 'Stopping existing MCP sessions'
# 停止现有的MCP会话
mcp_loader = None
for loader_name, loader in self.ap.tool_mgr.loaders.items():
if loader_name == 'mcp':
mcp_loader = loader
break
if mcp_loader:
await mcp_loader.shutdown()
ctx.current_action = 'Reloading MCP configuration'
# 重新加载MCP loader
await self.ap.tool_mgr.reload_loader('mcp')
ctx.current_action = 'MCP loader reloaded successfully'
except Exception as e:
ctx.current_action = f'Failed to reload MCP loader: {str(e)}'
raise e
async def _test_mcp_server(self, server_config: dict, ctx: taskmgr.TaskContext):
"""测试MCP服务器连接"""
try:
from ......provider.tools.loaders.mcp import RuntimeMCPSession
ctx.current_action = f'Testing connection to {server_config["name"]}'
# 创建临时会话进行测试
session = RuntimeMCPSession(server_config['name'], server_config, self.ap)
await session.initialize()
# 获取工具列表作为测试
tools_count = len(session.functions)
ctx.current_action = f'Successfully connected. Found {tools_count} tools.'
# 关闭测试会话
await session.shutdown()
return {'status': 'success', 'tools_count': tools_count}
except Exception as e:
ctx.current_action = f'Connection test failed: {str(e)}'
raise e
async def _install_mcp_from_github(self, source: str, ctx: taskmgr.TaskContext):
"""从GitHub安装MCP服务器的实现"""
try:
ctx.current_action = f'Installing MCP server from {source}'
# 这里是安装逻辑的占位符
# 实际实现将包括克隆仓库、解析配置、安装依赖等步骤
# 模拟安装过程
await asyncio.sleep(2) # 模拟安装过程
# 返回成功结果
return {'status': 'success', 'message': f'Successfully installed MCP server from {source}'}
except Exception as e:
ctx.current_action = f'Failed to install MCP server: {str(e)}'
raise e

View File

@@ -0,0 +1,20 @@
import sqlalchemy
from .base import Base
class MCPServer(Base):
__tablename__ = 'mcp_servers'
uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True)
name = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
enable = sqlalchemy.Column(sqlalchemy.Boolean, nullable=False, default=False)
mode = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) # stdio, sse
extra_args = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default={})
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
updated_at = sqlalchemy.Column(
sqlalchemy.DateTime,
nullable=False,
server_default=sqlalchemy.func.now(),
onupdate=sqlalchemy.func.now(),
)

View File

@@ -59,7 +59,7 @@ class ModelManager:
try:
await self.load_llm_model(llm_model)
except provider_errors.RequesterNotFoundError as e:
self.ap.logger.warning(f'Requester {e.requester_name} not found, skipping model {llm_model.uuid}')
self.ap.logger.warning(f'Requester {e.requester_name} not found, skipping llm model {llm_model.uuid}')
except Exception as e:
self.ap.logger.error(f'Failed to load model {llm_model.uuid}: {e}\n{traceback.format_exc()}')
@@ -67,7 +67,14 @@ class ModelManager:
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.EmbeddingModel))
embedding_models = result.all()
for embedding_model in embedding_models:
try:
await self.load_embedding_model(embedding_model)
except provider_errors.RequesterNotFoundError as e:
self.ap.logger.warning(
f'Requester {e.requester_name} not found, skipping embedding model {embedding_model.uuid}'
)
except Exception as e:
self.ap.logger.error(f'Failed to load model {embedding_model.uuid}: {e}\n{traceback.format_exc()}')
async def init_runtime_llm_model(
self,
@@ -107,6 +114,9 @@ class ModelManager:
elif isinstance(model_info, dict):
model_info = persistence_model.EmbeddingModel(**model_info)
if model_info.requester not in self.requester_dict:
raise provider_errors.RequesterNotFoundError(model_info.requester)
requester_inst = self.requester_dict[model_info.requester](ap=self.ap, config=model_info.requester_config)
await requester_inst.initialize()

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
import typing
from contextlib import AsyncExitStack
import traceback
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
@@ -9,7 +10,9 @@ from mcp.client.sse import sse_client
from .. import loader
from ....core import app
from ....entity.persistence import mcp as persistence_mcp
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
import sqlalchemy
class RuntimeMCPSession:
@@ -27,11 +30,13 @@ class RuntimeMCPSession:
functions: list[resource_tool.LLMTool] = []
def __init__(self, server_name: str, server_config: dict, ap: app.Application):
enable: bool
def __init__(self, server_name: str, server_config: dict, enable: bool, ap: app.Application):
self.server_name = server_name
self.server_config = server_config
self.ap = ap
self.enable = enable
self.session = None
self.exit_stack = AsyncExitStack()
@@ -68,6 +73,12 @@ class RuntimeMCPSession:
await self.session.initialize()
async def initialize(self):
pass
async def start(self):
if not self.enable:
return
self.ap.logger.debug(f'初始化 MCP 会话: {self.server_name} {self.server_config}')
if self.server_config['mode'] == 'stdio':
@@ -123,13 +134,45 @@ class MCPLoader(loader.ToolLoader):
self._last_listed_functions = []
async def initialize(self):
for server_config in self.ap.instance_config.data.get('mcp', {}).get('servers', []):
if not server_config['enable']:
continue
session = RuntimeMCPSession(server_config['name'], server_config, self.ap)
await self.load_mcp_servers_from_db()
async def load_mcp_servers_from_db(self):
self.ap.logger.info('Loading MCP servers from db...')
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_mcp.MCPServer))
servers = result.all()
for server in servers:
try:
await self.load_mcp_server(server)
except Exception as e:
self.ap.logger.error(f'Failed to load MCP server {server.name}: {e}\n{traceback.format_exc()}')
async def init_runtime_mcp_session(
self,
server_entity: persistence_mcp.MCPServer | sqlalchemy.Row[persistence_mcp.MCPServer] | dict,
):
if isinstance(server_entity, sqlalchemy.Row):
server_entity = persistence_mcp.MCPServer(**server_entity._mapping)
elif isinstance(server_entity, dict):
server_entity = persistence_mcp.MCPServer(**server_entity)
mixed_config = {
'name': server_entity.name,
'mode': server_entity.mode,
'enable': server_entity.enable,
**server_entity.extra_args,
}
session = RuntimeMCPSession(server_entity.name, mixed_config, server_entity.enable, self.ap)
await session.initialize()
# self.ap.event_loop.create_task(session.initialize())
self.sessions[server_config['name']] = session
return session
async def load_mcp_server(
self,
server_entity: persistence_mcp.MCPServer | sqlalchemy.Row[persistence_mcp.MCPServer] | dict,
):
session = await self.init_runtime_mcp_session(server_entity)
self.sessions[server_entity.name] = session
async def get_tools(self) -> list[resource_tool.LLMTool]:
all_functions = []
@@ -150,7 +193,14 @@ class MCPLoader(loader.ToolLoader):
if function.name == name:
return await function.func(**parameters)
raise ValueError(f'未找到工具: {name}')
raise ValueError(f'Tool not found: {name}')
async def remove_mcp_server(self, server_name: str):
if server_name not in self.sessions:
raise ValueError(f'MCP server {server_name} not found')
session = self.sessions.pop(server_name)
await session.shutdown()
async def shutdown(self):
"""关闭工具"""