mirror of
https://github.com/langbot-app/LangBot.git
synced 2025-11-25 03:15:06 +08:00
162 lines
5.0 KiB
Python
162 lines
5.0 KiB
Python
from __future__ import annotations
|
|
|
|
import typing
|
|
from contextlib import AsyncExitStack
|
|
|
|
from mcp import ClientSession, StdioServerParameters
|
|
from mcp.client.stdio import stdio_client
|
|
from mcp.client.sse import sse_client
|
|
|
|
from .. import loader, entities as tools_entities
|
|
from ....core import app, entities as core_entities
|
|
|
|
|
|
class RuntimeMCPSession:
    """Runtime session for a single MCP server.

    Connects to the server over the transport selected by
    ``server_config["mode"]`` (``"stdio"`` or ``"sse"``), lists the tools the
    server advertises and exposes each of them as an ``LLMFunction`` in
    ``functions``.
    """

    ap: app.Application

    server_name: str

    server_config: dict

    # ``None`` until ``initialize()`` has connected to the server.
    session: ClientSession | None

    # Every async context entered for this session (transport and
    # ClientSession) is registered here so that ``shutdown()`` can unwind
    # them all in one place.
    exit_stack: AsyncExitStack

    # Populated by ``initialize()``; always rebound per instance in __init__.
    functions: list[tools_entities.LLMFunction]

    def __init__(self, server_name: str, server_config: dict, ap: app.Application):
        """Store the configuration; no connection is opened here.

        Args:
            server_name: Name of the server, used in log and error messages.
            server_config: Server configuration dict. ``initialize()`` reads
                ``mode`` plus ``command``/``args``/``env`` (stdio) or
                ``url``/``headers``/``timeout`` (sse) from it.
            ap: Application object; its ``logger`` is used for debug output.
        """
        self.server_name = server_name
        self.server_config = server_config
        self.ap = ap

        self.session = None

        self.exit_stack = AsyncExitStack()
        self.functions = []

    async def _open_session(self, read_stream, write_stream) -> None:
        """Create a ``ClientSession`` over the given streams and initialize it."""
        self.session = await self.exit_stack.enter_async_context(
            ClientSession(read_stream, write_stream)
        )

        await self.session.initialize()

    async def _init_stdio_python_server(self):
        """Connect through the stdio transport described by the config."""
        server_params = StdioServerParameters(
            command=self.server_config["command"],
            # ``args`` and ``env`` are optional, mirroring how the SSE branch
            # treats its optional keys.
            args=self.server_config.get("args", []),
            env=self.server_config.get("env"),
        )

        read_stream, write_stream = await self.exit_stack.enter_async_context(
            stdio_client(server_params)
        )

        await self._open_session(read_stream, write_stream)

    async def _init_sse_server(self):
        """Connect through the SSE transport described by the config."""
        read_stream, write_stream = await self.exit_stack.enter_async_context(
            sse_client(
                self.server_config["url"],
                headers=self.server_config.get("headers", {}),
                timeout=self.server_config.get("timeout", 10),
            )
        )

        await self._open_session(read_stream, write_stream)

    def _make_tool_func(self, tool_name: str):
        """Build the async callable that invokes the MCP tool ``tool_name``.

        A factory is used instead of defining the function directly inside the
        loop over tools: a closure created in a loop looks up the loop variable
        when it is *called*, so every function would otherwise invoke the last
        tool that was listed.

        Args:
            tool_name: Name of the tool to call on the session.

        Returns:
            An ``async def func(query, **kwargs)`` whose ``__name__`` is
            ``tool_name``. It returns the text of the first content item of
            the call result and raises ``Exception`` with that text when the
            result is flagged as an error.
        """

        async def func(query: core_entities.Query, **kwargs):
            result = await self.session.call_tool(tool_name, kwargs)
            if result.isError:
                raise Exception(result.content[0].text)
            return result.content[0].text

        func.__name__ = tool_name
        return func

    async def initialize(self):
        """Connect to the server and register its tools in ``functions``.

        Raises:
            ValueError: If ``server_config["mode"]`` is neither ``"stdio"``
                nor ``"sse"``.
        """
        self.ap.logger.debug(f"初始化 MCP 会话: {self.server_name} {self.server_config}")

        mode = self.server_config["mode"]
        if mode == "stdio":
            await self._init_stdio_python_server()
        elif mode == "sse":
            await self._init_sse_server()
        else:
            raise ValueError(f"无法识别 MCP 服务器类型: {self.server_name}: {self.server_config}")

        tools = await self.session.list_tools()

        self.ap.logger.debug(f"获取 MCP 工具: {tools}")

        for tool in tools.tools:
            self.functions.append(
                tools_entities.LLMFunction(
                    name=tool.name,
                    human_desc=tool.description,
                    description=tool.description,
                    parameters=tool.inputSchema,
                    func=self._make_tool_func(tool.name),
                )
            )

    async def shutdown(self):
        """Close the session and the transport it runs on.

        Unwinds ``self.exit_stack``, which holds both the transport and the
        ``ClientSession`` entered during initialization. This is a no-op when
        ``initialize()`` never ran.
        """
        await self.exit_stack.aclose()
|
|
|
|
@loader.loader_class("mcp")
class MCPLoader(loader.ToolLoader):
    """MCP tool loader.

    All connections to MCP servers are managed inside this loader.
    """

    sessions: dict[str, RuntimeMCPSession] = {}

    _last_listed_functions: list[tools_entities.LLMFunction] = []

    def __init__(self, ap: app.Application):
        """Start with no sessions and an empty tool listing."""
        super().__init__(ap)
        self.sessions = {}
        self._last_listed_functions = []

    async def initialize(self):
        """Create and initialize a session for every enabled configured server.

        Servers are read from ``mcp.servers`` in the instance config. Each
        session is awaited in turn and is only registered in ``sessions``
        once its own ``initialize()`` has returned.
        """
        mcp_section = self.ap.instance_config.data.get("mcp", {})

        for entry in mcp_section.get("servers", []):
            if entry["enable"]:
                runtime_session = RuntimeMCPSession(entry["name"], entry, self.ap)
                await runtime_session.initialize()
                self.sessions[entry["name"]] = runtime_session

    async def get_tools(self, enabled: bool=True) -> list[tools_entities.LLMFunction]:
        """Return the functions of all sessions and remember them for ``has_tool``.

        The ``enabled`` argument is accepted but not consulted here.
        """
        collected = [
            function
            for session in self.sessions.values()
            for function in session.functions
        ]

        self._last_listed_functions = collected

        return collected

    async def has_tool(self, name: str) -> bool:
        """Report whether ``name`` was among the functions last returned by ``get_tools``."""
        return any(function.name == name for function in self._last_listed_functions)

    async def invoke_tool(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any:
        """Call the first function named ``name`` found across the sessions.

        Raises:
            ValueError: If no session provides a function with that name.
        """
        match = next(
            (
                function
                for session in self.sessions.values()
                for function in session.functions
                if function.name == name
            ),
            None,
        )

        if match is None:
            raise ValueError(f"未找到工具: {name}")

        return await match.func(query, **parameters)

    async def shutdown(self):
        """Shut down every session in turn."""
        for runtime_session in self.sessions.values():
            await runtime_session.shutdown()
|