Files
LangBot/pkg/provider/tools/loaders/mcp.py
Junyan Qin (Chin) 209f16af76 style: introduce ruff as linter and formatter (#1356)
* style: remove necessary imports

* style: fix F841

* style: fix F401

* style: fix F811

* style: fix E402

* style: fix E721

* style: fix E722

* style: fix E722

* style: fix F541

* style: ruff format

* style: all passed

* style: add ruff in deps

* style: more ignores in ruff.toml

* style: add pre-commit
2025-04-29 17:24:07 +08:00

172 lines
5.1 KiB
Python

from __future__ import annotations
import typing
from contextlib import AsyncExitStack
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from mcp.client.sse import sse_client
from .. import loader, entities as tools_entities
from ....core import app, entities as core_entities
class RuntimeMCPSession:
    """Runtime MCP session.

    Holds the connection to one MCP server (a stdio subprocess or an SSE
    endpoint, selected by ``server_config['mode']``) and wraps every tool the
    server lists as an ``LLMFunction``.
    """

    # Application instance; used in this class for debug logging.
    ap: app.Application

    # Name of the server this session talks to.
    server_name: str

    # Configuration entry for this server ('mode' plus transport settings).
    server_config: dict

    # MCP client session; stays None until initialize() has connected.
    session: ClientSession | None

    # Owns the transport and client-session contexts entered in initialize().
    exit_stack: AsyncExitStack

    # Tools discovered on the server; filled by initialize().
    functions: list[tools_entities.LLMFunction]

    def __init__(self, server_name: str, server_config: dict, ap: app.Application):
        """Store the configuration; nothing is connected until ``initialize()``."""
        self.server_name = server_name
        self.server_config = server_config
        self.ap = ap
        self.session = None
        self.exit_stack = AsyncExitStack()
        self.functions = []

    async def _start_client_session(self, transport):
        """Open and initialize a ``ClientSession`` over a (read, write) transport pair."""
        read_stream, write_stream = transport
        self.session = await self.exit_stack.enter_async_context(
            ClientSession(read_stream, write_stream)
        )
        await self.session.initialize()

    async def _init_stdio_python_server(self):
        """Connect to a server launched as a subprocess and spoken to over stdio."""
        server_params = StdioServerParameters(
            command=self.server_config['command'],
            # Only 'command' is mandatory; a missing 'args' / 'env' falls back
            # to the SDK defaults instead of raising KeyError.
            args=self.server_config.get('args', []),
            env=self.server_config.get('env'),
        )
        stdio_transport = await self.exit_stack.enter_async_context(
            stdio_client(server_params)
        )
        await self._start_client_session(stdio_transport)

    async def _init_sse_server(self):
        """Connect to a server reachable over SSE at ``server_config['url']``."""
        sse_transport = await self.exit_stack.enter_async_context(
            sse_client(
                self.server_config['url'],
                headers=self.server_config.get('headers', {}),
                timeout=self.server_config.get('timeout', 10),
            )
        )
        await self._start_client_session(sse_transport)

    def _build_tool_func(self, tool_name: str):
        """Return the coroutine function that calls ``tool_name`` on this session.

        A factory is used so that each function captures its own tool name.
        Defining the function directly inside the listing loop would make all
        of them read the loop variable at call time, i.e. call the last tool.
        """

        async def func(query: core_entities.Query, **kwargs):
            result = await self.session.call_tool(tool_name, kwargs)
            # NOTE(review): assumes the first content item is present and has
            # a ``text`` attribute -- confirm for non-text tool results.
            if result.isError:
                raise Exception(result.content[0].text)
            return result.content[0].text

        func.__name__ = tool_name
        return func

    async def initialize(self):
        """Connect to the server and register its tools in ``self.functions``.

        Raises:
            ValueError: if ``server_config['mode']`` is neither ``'stdio'``
                nor ``'sse'``.
        """
        self.ap.logger.debug(
            f'初始化 MCP 会话: {self.server_name} {self.server_config}'
        )

        mode = self.server_config['mode']
        if mode == 'stdio':
            await self._init_stdio_python_server()
        elif mode == 'sse':
            await self._init_sse_server()
        else:
            raise ValueError(
                f'无法识别 MCP 服务器类型: {self.server_name}: {self.server_config}'
            )

        tools = await self.session.list_tools()
        self.ap.logger.debug(f'获取 MCP 工具: {tools}')

        self.functions.extend(
            tools_entities.LLMFunction(
                name=tool.name,
                human_desc=tool.description,
                description=tool.description,
                parameters=tool.inputSchema,
                func=self._build_tool_func(tool.name),
            )
            for tool in tools.tools
        )

    async def shutdown(self):
        """Close the session and its transport.

        Everything opened by ``initialize()`` was entered on
        ``self.exit_stack``, so closing that stack releases the client session
        and the underlying transport. This is also safe when ``initialize()``
        never ran or failed before a session was created.
        """
        await self.exit_stack.aclose()
@loader.loader_class('mcp')
class MCPLoader(loader.ToolLoader):
    """MCP tool loader.

    All connections to MCP servers are managed inside this loader.
    """

    # Connected sessions keyed by server name.
    sessions: dict[str, RuntimeMCPSession]

    # Tools returned by the most recent get_tools() call; has_tool() reads it.
    _last_listed_functions: list[tools_entities.LLMFunction]

    def __init__(self, ap: app.Application):
        """Create an empty loader; servers are connected in ``initialize()``."""
        super().__init__(ap)
        self.sessions = {}
        self._last_listed_functions = []

    async def initialize(self):
        """Connect to every enabled server listed under ``mcp.servers`` in the config."""
        servers = self.ap.instance_config.data.get('mcp', {}).get('servers', [])
        for server_config in servers:
            if not server_config['enable']:
                continue

            session = RuntimeMCPSession(server_config['name'], server_config, self.ap)
            # Awaited inline: servers connect one after another and a failure
            # propagates to the caller before the session is registered.
            await session.initialize()
            self.sessions[server_config['name']] = session

    async def get_tools(self, enabled: bool = True) -> list[tools_entities.LLMFunction]:
        """Return the tools of all sessions and remember them for ``has_tool()``.

        ``enabled`` is accepted but not consulted by this loader.
        """
        all_functions = []
        for session in self.sessions.values():
            all_functions.extend(session.functions)

        self._last_listed_functions = all_functions
        return all_functions

    async def has_tool(self, name: str) -> bool:
        """Tell whether ``name`` was among the tools last returned by ``get_tools()``."""
        return any(f.name == name for f in self._last_listed_functions)

    async def invoke_tool(
        self, query: core_entities.Query, name: str, parameters: dict
    ) -> typing.Any:
        """Call the first tool named ``name`` found across the sessions.

        Raises:
            ValueError: if no session provides a tool with that name.
        """
        for session in self.sessions.values():
            for function in session.functions:
                if function.name == name:
                    return await function.func(query, **parameters)

        raise ValueError(f'未找到工具: {name}')

    async def shutdown(self):
        """Shut down every session.

        All sessions are attempted even if one of them fails, so a single bad
        connection does not leave the others open. Each failure is logged and
        the first one is re-raised once every session has been tried.
        """
        first_error = None
        for server_name, session in self.sessions.items():
            try:
                await session.shutdown()
            except Exception as exc:
                self.ap.logger.debug(f'关闭 MCP 会话失败: {server_name}: {exc}')
                if first_error is None:
                    first_error = exc

        if first_error is not None:
            raise first_error