feat: pass session and query_id to tool call

This commit is contained in:
Junyan Qin
2025-11-20 21:17:47 +08:00
parent 3182214663
commit 2bf593fa6b
8 changed files with 34 additions and 22 deletions

View File

@@ -55,17 +55,6 @@ class SystemRouterGroup(group.RouterGroup):
return self.success(data=exec(py_code, {'ap': ap}))
@self.route('/debug/tools/call', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str:
if not constants.debug_mode:
return self.http_status(403, 403, 'Forbidden')
data = await quart.request.json
return self.success(
data=await self.ap.tool_mgr.execute_func_call(data['tool_name'], data['tool_parameters'])
)
@self.route(
'/debug/plugin/action',
methods=['POST'],

View File

@@ -8,6 +8,7 @@ import os
import sys
import httpx
from async_lru import alru_cache
from langbot_plugin.api.entities.builtin.pipeline.query import provider_session
from ..core import app
from . import handler
@@ -321,13 +322,20 @@ class PluginRuntimeConnector:
return tools
async def call_tool(
self, tool_name: str, parameters: dict[str, Any], bound_plugins: list[str] | None = None
self,
tool_name: str,
parameters: dict[str, Any],
session: provider_session.Session,
query_id: int,
bound_plugins: list[str] | None = None,
) -> dict[str, Any]:
if not self.is_enable_plugin:
return {'error': 'Tool not found: plugin system is disabled'}
# Pass include_plugins to runtime for validation
return await self.handler.call_tool(tool_name, parameters, include_plugins=bound_plugins)
return await self.handler.call_tool(
tool_name, parameters, session.model_dump(serialize_as_any=True), query_id, include_plugins=bound_plugins
)
async def list_commands(self, bound_plugins: list[str] | None = None) -> list[ComponentManifest]:
if not self.is_enable_plugin:

View File

@@ -620,7 +620,12 @@ class RuntimeConnectionHandler(handler.Handler):
)
async def call_tool(
self, tool_name: str, parameters: dict[str, Any], include_plugins: list[str] | None = None
self,
tool_name: str,
parameters: dict[str, Any],
session: dict[str, Any],
query_id: int,
include_plugins: list[str] | None = None,
) -> dict[str, Any]:
"""Call tool"""
result = await self.call_action(
@@ -628,6 +633,8 @@ class RuntimeConnectionHandler(handler.Handler):
{
'tool_name': tool_name,
'tool_parameters': parameters,
'session': session,
'query_id': query_id,
'include_plugins': include_plugins,
},
timeout=60,

View File

@@ -196,7 +196,7 @@ class LocalAgentRunner(runner.RequestRunner):
parameters = json.loads(func.arguments)
func_ret = await self.ap.tool_mgr.execute_func_call(func.name, parameters)
func_ret = await self.ap.tool_mgr.execute_func_call(func.name, parameters, query=query)
if is_stream:
msg = provider_message.MessageChunk(
role='tool',

View File

@@ -3,6 +3,8 @@ from __future__ import annotations
import abc
import typing
from langbot_plugin.api.entities.events import pipeline_query
from ...core import app
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
@@ -45,7 +47,7 @@ class ToolLoader(abc.ABC):
pass
@abc.abstractmethod
async def invoke_tool(self, name: str, parameters: dict) -> typing.Any:
async def invoke_tool(self, name: str, parameters: dict, query: pipeline_query.Query) -> typing.Any:
"""执行工具调用"""
pass

View File

@@ -4,6 +4,7 @@ import enum
import typing
from contextlib import AsyncExitStack
import traceback
from langbot_plugin.api.entities.events import pipeline_query
import sqlalchemy
import asyncio
@@ -329,7 +330,7 @@ class MCPLoader(loader.ToolLoader):
return True
return False
async def invoke_tool(self, name: str, parameters: dict) -> typing.Any:
async def invoke_tool(self, name: str, parameters: dict, query: pipeline_query.Query) -> typing.Any:
"""执行工具调用"""
for session in self.sessions.values():
for function in session.get_tools():

View File

@@ -3,6 +3,8 @@ from __future__ import annotations
import typing
import traceback
from langbot_plugin.api.entities.events import pipeline_query
from .. import loader
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
@@ -43,9 +45,11 @@ class PluginToolLoader(loader.ToolLoader):
return tool
return None
async def invoke_tool(self, name: str, parameters: dict) -> typing.Any:
async def invoke_tool(self, name: str, parameters: dict, query: pipeline_query.Query) -> typing.Any:
try:
return await self.ap.plugin_connector.call_tool(name, parameters)
return await self.ap.plugin_connector.call_tool(
name, parameters, session=query.session, query_id=query.query_id
)
except Exception as e:
self.ap.logger.error(f'执行函数 {name} 时发生错误: {e}')
traceback.print_exc()

View File

@@ -7,6 +7,7 @@ from langbot.pkg.utils import importutil
from langbot.pkg.provider.tools import loaders
from langbot.pkg.provider.tools.loaders import mcp as mcp_loader, plugin as plugin_loader
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
from langbot_plugin.api.entities.events import pipeline_query
importutil.import_modules_in_pkg(loaders)
@@ -91,13 +92,13 @@ class ToolManager:
return tools
async def execute_func_call(self, name: str, parameters: dict) -> typing.Any:
async def execute_func_call(self, name: str, parameters: dict, query: pipeline_query.Query) -> typing.Any:
"""执行函数调用"""
if await self.plugin_tool_loader.has_tool(name):
return await self.plugin_tool_loader.invoke_tool(name, parameters)
return await self.plugin_tool_loader.invoke_tool(name, parameters, query)
elif await self.mcp_tool_loader.has_tool(name):
return await self.mcp_tool_loader.invoke_tool(name, parameters)
return await self.mcp_tool_loader.invoke_tool(name, parameters, query)
else:
raise ValueError(f'未找到工具: {name}')