diff --git a/components.yaml b/components.yaml index b91b8813..5d8e75d2 100644 --- a/components.yaml +++ b/components.yaml @@ -9,7 +9,6 @@ spec: components: ComponentTemplate: fromFiles: - - pkg/platform/adapter.yaml - pkg/provider/modelmgr/requester.yaml MessagePlatformAdapter: fromDirs: diff --git a/docker-compose.yaml b/docker-compose.yaml index 6f75e85d..e1231d66 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -1,3 +1,4 @@ +# This file is deprecated, and will be replaced by docker/docker-compose.yaml in next version. version: "3" services: @@ -13,4 +14,4 @@ services: ports: - 5300:5300 # 供 WebUI 使用 - 2280-2290:2280-2290 # 供消息平台适配器方向连接 - # 根据具体环境配置网络 + # 根据具体环境配置网络 \ No newline at end of file diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml new file mode 100644 index 00000000..107a9e26 --- /dev/null +++ b/docker/docker-compose.yaml @@ -0,0 +1,36 @@ +version: "3" + +services: + + langbot_plugin_runtime: + image: rockchin/langbot:latest + container_name: langbot_plugin_runtime + volumes: + - ./data/plugins:/app/data/plugins + ports: + - 5401:5401 + restart: on-failure + environment: + - TZ=Asia/Shanghai + command: ["uv", "run", "-m", "langbot_plugin.cli.__init__", "rt"] + networks: + - langbot_network + + langbot: + image: rockchin/langbot:latest + container_name: langbot + volumes: + - ./data:/app/data + - ./plugins:/app/plugins + restart: on-failure + environment: + - TZ=Asia/Shanghai + ports: + - 5300:5300 # For web ui + - 2280-2290:2280-2290 # For platform webhook + networks: + - langbot_network + +networks: + langbot_network: + driver: bridge diff --git a/libs/qq_official_api/api.py b/libs/qq_official_api/api.py index cb5f658a..c5728437 100644 --- a/libs/qq_official_api/api.py +++ b/libs/qq_official_api/api.py @@ -3,7 +3,7 @@ from quart import request import httpx from quart import Quart from typing import Callable, Dict, Any -from pkg.platform.types import events as platform_events +import langbot_plugin.api.entities.builtin.platform.events as platform_events from .qqofficialevent import QQOfficialEvent import json import traceback diff --git a/libs/slack_api/api.py b/libs/slack_api/api.py index 746d15da..241a42cf 100644 --- a/libs/slack_api/api.py +++ b/libs/slack_api/api.py @@ -4,7 +4,7 @@ from quart import Quart, jsonify, request from slack_sdk.web.async_client import AsyncWebClient from .slackevent import SlackEvent from typing import Callable -from pkg.platform.types import events as platform_events +import langbot_plugin.api.entities.builtin.platform.events as platform_events class SlackClient: diff --git a/libs/wecom_api/api.py b/libs/wecom_api/api.py index c1328b0d..352a550c 100644 --- a/libs/wecom_api/api.py +++ b/libs/wecom_api/api.py @@ -8,7 +8,7 @@ from quart import Quart import xml.etree.ElementTree as ET from typing import Callable, Dict, Any from .wecomevent import WecomEvent -from pkg.platform.types import message as platform_message +import langbot_plugin.api.entities.builtin.platform.message as platform_message import aiofiles diff --git a/libs/wecom_customer_service_api/api.py b/libs/wecom_customer_service_api/api.py index 32fab7f7..f912326e 100644 --- a/libs/wecom_customer_service_api/api.py +++ b/libs/wecom_customer_service_api/api.py @@ -8,7 +8,7 @@ from quart import Quart import xml.etree.ElementTree as ET from typing import Callable from .wecomcsevent import WecomCSEvent -from pkg.platform.types import message as platform_message +import langbot_plugin.api.entities.builtin.platform.message as platform_message import aiofiles diff --git a/main.py b/main.py index 1909e343..bf6cd39a 100644 --- a/main.py +++ b/main.py @@ -19,8 +19,14 @@ asciiart = r""" async def main_entry(loop: asyncio.AbstractEventLoop): parser = argparse.ArgumentParser(description='LangBot') parser.add_argument('--skip-plugin-deps-check', action='store_true', help='跳过插件依赖项检查', default=False) + parser.add_argument('--standalone-runtime', action='store_true', help='使用独立插件运行时', default=False) args = parser.parse_args() + if args.standalone_runtime: + from pkg.utils import platform + + platform.standalone_runtime = True + print(asciiart) import sys @@ -47,13 +53,13 @@ async def main_entry(loop: asyncio.AbstractEventLoop): if not args.skip_plugin_deps_check: await deps.precheck_plugin_deps() - # 检查pydantic版本,如果没有 pydantic.v1,则把 pydantic 映射为 v1 - import pydantic.version + # # 检查pydantic版本,如果没有 pydantic.v1,则把 pydantic 映射为 v1 + # import pydantic.version - if pydantic.version.VERSION < '2.0': - import pydantic + # if pydantic.version.VERSION < '2.0': + # import pydantic - sys.modules['pydantic.v1'] = pydantic + # sys.modules['pydantic.v1'] = pydantic # 检查配置文件 diff --git a/pkg/api/http/controller/groups/pipelines/webchat.py b/pkg/api/http/controller/groups/pipelines/webchat.py index 7eea471a..13f955d8 100644 --- a/pkg/api/http/controller/groups/pipelines/webchat.py +++ b/pkg/api/http/controller/groups/pipelines/webchat.py @@ -44,9 +44,9 @@ class WebChatDebugRouterGroup(group.RouterGroup): 'Content-Type': 'text/event-stream', 'Transfer-Encoding': 'chunked', 'Cache-Control': 'no-cache', - 'Connection': 'keep-alive' + 'Connection': 'keep-alive', } - return quart.Response(stream_generator(generator), mimetype='text/event-stream',headers=headers) + return quart.Response(stream_generator(generator), mimetype='text/event-stream', headers=headers) else: # non-stream result = None diff --git a/pkg/api/http/controller/groups/plugins.py b/pkg/api/http/controller/groups/plugins.py index b7e0a5e9..28a3c44c 100644 --- a/pkg/api/http/controller/groups/plugins.py +++ b/pkg/api/http/controller/groups/plugins.py @@ -1,10 +1,11 @@ from __future__ import annotations - +import base64 import quart from .....core import taskmgr from .. import group +from langbot_plugin.runtime.plugin.mgr import PluginInstallSource @group.group_class('plugins', '/api/v1/plugins') @@ -12,35 +13,22 @@ class PluginsRouterGroup(group.RouterGroup): async def initialize(self) -> None: @self.route('', methods=['GET'], auth_type=group.AuthType.USER_TOKEN) async def _() -> str: - plugins = self.ap.plugin_mgr.plugins() + plugins = await self.ap.plugin_connector.list_plugins() - plugins_data = [plugin.model_dump() for plugin in plugins] - - return self.success(data={'plugins': plugins_data}) + return self.success(data={'plugins': plugins}) @self.route( - '///toggle', - methods=['PUT'], - auth_type=group.AuthType.USER_TOKEN, - ) - async def _(author: str, plugin_name: str) -> str: - data = await quart.request.json - target_enabled = data.get('target_enabled') - await self.ap.plugin_mgr.update_plugin_switch(plugin_name, target_enabled) - return self.success() - - @self.route( - '///update', + '///upgrade', methods=['POST'], auth_type=group.AuthType.USER_TOKEN, ) async def _(author: str, plugin_name: str) -> str: ctx = taskmgr.TaskContext.new() wrapper = self.ap.task_mgr.create_user_task( - self.ap.plugin_mgr.update_plugin(plugin_name, task_context=ctx), + self.ap.plugin_connector.upgrade_plugin(author, plugin_name, task_context=ctx), kind='plugin-operation', - name=f'plugin-update-{plugin_name}', - label=f'Updating plugin {plugin_name}', + name=f'plugin-upgrade-{plugin_name}', + label=f'Upgrading plugin {plugin_name}', context=ctx, ) return self.success(data={'task_id': wrapper.id}) @@ -52,14 +40,14 @@ class PluginsRouterGroup(group.RouterGroup): ) async def _(author: str, plugin_name: str) -> str: if quart.request.method == 'GET': - plugin = self.ap.plugin_mgr.get_plugin(author, plugin_name) + plugin = await self.ap.plugin_connector.get_plugin_info(author, plugin_name) if plugin is None: return self.http_status(404, -1, 'plugin not found') - return self.success(data={'plugin': plugin.model_dump()}) + return self.success(data={'plugin': plugin}) elif quart.request.method == 'DELETE': ctx = taskmgr.TaskContext.new() wrapper = self.ap.task_mgr.create_user_task( - self.ap.plugin_mgr.uninstall_plugin(plugin_name, task_context=ctx), + self.ap.plugin_connector.delete_plugin(author, plugin_name, task_context=ctx), kind='plugin-operation', name=f'plugin-remove-{plugin_name}', label=f'Removing plugin {plugin_name}', @@ -74,23 +62,32 @@ class PluginsRouterGroup(group.RouterGroup): auth_type=group.AuthType.USER_TOKEN, ) async def _(author: str, plugin_name: str) -> quart.Response: - plugin = self.ap.plugin_mgr.get_plugin(author, plugin_name) + plugin = await self.ap.plugin_connector.get_plugin_info(author, plugin_name) if plugin is None: return self.http_status(404, -1, 'plugin not found') + if quart.request.method == 'GET': - return self.success(data={'config': plugin.plugin_config}) + return self.success(data={'config': plugin['plugin_config']}) elif quart.request.method == 'PUT': data = await quart.request.json - await self.ap.plugin_mgr.set_plugin_config(plugin, data) + await self.ap.plugin_connector.set_plugin_config(author, plugin_name, data) return self.success(data={}) - @self.route('/reorder', methods=['PUT'], auth_type=group.AuthType.USER_TOKEN) - async def _() -> str: - data = await quart.request.json - await self.ap.plugin_mgr.reorder_plugins(data.get('plugins')) - return self.success() + @self.route( + '///icon', + methods=['GET'], + auth_type=group.AuthType.NONE, + ) + async def _(author: str, plugin_name: str) -> quart.Response: + icon_data = await self.ap.plugin_connector.get_plugin_icon(author, plugin_name) + icon_base64 = icon_data['plugin_icon_base64'] + mime_type = icon_data['mime_type'] + + icon_data = base64.b64decode(icon_base64) + + return quart.Response(icon_data, mimetype=mime_type) @self.route('/install/github', methods=['POST'], auth_type=group.AuthType.USER_TOKEN) async def _() -> str: @@ -102,7 +99,47 @@ class PluginsRouterGroup(group.RouterGroup): self.ap.plugin_mgr.install_plugin(data['source'], task_context=ctx), kind='plugin-operation', name='plugin-install-github', - label=f'Installing plugin ...{short_source_str}', + label=f'Installing plugin from github ...{short_source_str}', + context=ctx, + ) + + return self.success(data={'task_id': wrapper.id}) + + @self.route('/install/marketplace', methods=['POST'], auth_type=group.AuthType.USER_TOKEN) + async def _() -> str: + data = await quart.request.json + + ctx = taskmgr.TaskContext.new() + wrapper = self.ap.task_mgr.create_user_task( + self.ap.plugin_connector.install_plugin(PluginInstallSource.MARKETPLACE, data, task_context=ctx), + kind='plugin-operation', + name='plugin-install-marketplace', + label=f'Installing plugin from marketplace ...{data}', + context=ctx, + ) + + return self.success(data={'task_id': wrapper.id}) + + @self.route('/install/local', methods=['POST'], auth_type=group.AuthType.USER_TOKEN) + async def _() -> str: + file = (await quart.request.files).get('file') + if file is None: + return self.http_status(400, -1, 'file is required') + + file_bytes = file.read() + + file_base64 = base64.b64encode(file_bytes).decode('utf-8') + + data = { + 'plugin_file': file_base64, + } + + ctx = taskmgr.TaskContext.new() + wrapper = self.ap.task_mgr.create_user_task( + self.ap.plugin_connector.install_plugin(PluginInstallSource.LOCAL, data, task_context=ctx), + kind='plugin-operation', + name='plugin-install-local', + label=f'Installing plugin from local ...{file.filename}', context=ctx, ) diff --git a/pkg/api/http/controller/groups/system.py b/pkg/api/http/controller/groups/system.py index c4cab602..ee107401 100644 --- a/pkg/api/http/controller/groups/system.py +++ b/pkg/api/http/controller/groups/system.py @@ -14,6 +14,12 @@ class SystemRouterGroup(group.RouterGroup): 'version': constants.semantic_version, 'debug': constants.debug_mode, 'enabled_platform_count': len(self.ap.platform_mgr.get_running_adapters()), + 'enable_marketplace': self.ap.instance_config.data['plugin'].get('enable_marketplace', True), + 'cloud_service_url': ( + self.ap.instance_config.data['plugin']['cloud_service_url'] + if 'cloud_service_url' in self.ap.instance_config.data['plugin'] + else 'https://space.langbot.app' + ), } ) @@ -35,16 +41,7 @@ class SystemRouterGroup(group.RouterGroup): return self.success(data=task.to_dict()) - @self.route('/reload', methods=['POST'], auth_type=group.AuthType.USER_TOKEN) - async def _() -> str: - json_data = await quart.request.json - - scope = json_data.get('scope') - - await self.ap.reload(scope=scope) - return self.success() - - @self.route('/_debug/exec', methods=['POST'], auth_type=group.AuthType.USER_TOKEN) + @self.route('/debug/exec', methods=['POST'], auth_type=group.AuthType.USER_TOKEN) async def _() -> str: if not constants.debug_mode: return self.http_status(403, 403, 'Forbidden') @@ -54,3 +51,39 @@ class SystemRouterGroup(group.RouterGroup): ap = self.ap return self.success(data=exec(py_code, {'ap': ap})) + + @self.route('/debug/tools/call', methods=['POST'], auth_type=group.AuthType.USER_TOKEN) + async def _() -> str: + if not constants.debug_mode: + return self.http_status(403, 403, 'Forbidden') + + data = await quart.request.json + + return self.success( + data=await self.ap.tool_mgr.execute_func_call(data['tool_name'], data['tool_parameters']) + ) + + @self.route( + '/debug/plugin/action', + methods=['POST'], + auth_type=group.AuthType.USER_TOKEN, + ) + async def _() -> str: + if not constants.debug_mode: + return self.http_status(403, 403, 'Forbidden') + + data = await quart.request.json + + class AnoymousAction: + value = 'anonymous_action' + + def __init__(self, value: str): + self.value = value + + resp = await self.ap.plugin_connector.handler.call_action( + AnoymousAction(data['action']), + data['data'], + timeout=data.get('timeout', 10), + ) + + return self.success(data=resp) diff --git a/pkg/api/http/controller/groups/user.py b/pkg/api/http/controller/groups/user.py index b84b2292..f1525a6b 100644 --- a/pkg/api/http/controller/groups/user.py +++ b/pkg/api/http/controller/groups/user.py @@ -71,15 +71,15 @@ class UserRouterGroup(group.RouterGroup): @self.route('/change-password', methods=['POST'], auth_type=group.AuthType.USER_TOKEN) async def _(user_email: str) -> str: json_data = await quart.request.json - + current_password = json_data['current_password'] new_password = json_data['new_password'] - + try: await self.ap.user_service.change_password(user_email, current_password, new_password) except argon2.exceptions.VerifyMismatchError: return self.http_status(400, -1, 'Current password is incorrect') except ValueError as e: return self.http_status(400, -1, str(e)) - + return self.success(data={'user': user_email}) diff --git a/pkg/api/http/service/bot.py b/pkg/api/http/service/bot.py index adf19d03..3ced0e51 100644 --- a/pkg/api/http/service/bot.py +++ b/pkg/api/http/service/bot.py @@ -17,16 +17,20 @@ class BotService: def __init__(self, ap: app.Application) -> None: self.ap = ap - async def get_bots(self) -> list[dict]: - """Get all bots""" + async def get_bots(self, include_secret: bool = True) -> list[dict]: + """获取所有机器人""" result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_bot.Bot)) bots = result.all() - return [self.ap.persistence_mgr.serialize_model(persistence_bot.Bot, bot) for bot in bots] + masked_columns = [] + if not include_secret: + masked_columns = ['adapter_config'] - async def get_bot(self, bot_uuid: str) -> dict | None: - """Get bot""" + return [self.ap.persistence_mgr.serialize_model(persistence_bot.Bot, bot, masked_columns) for bot in bots] + + async def get_bot(self, bot_uuid: str, include_secret: bool = True) -> dict | None: + """获取机器人""" result = await self.ap.persistence_mgr.execute_async( sqlalchemy.select(persistence_bot.Bot).where(persistence_bot.Bot.uuid == bot_uuid) ) @@ -36,7 +40,27 @@ class BotService: if bot is None: return None - return self.ap.persistence_mgr.serialize_model(persistence_bot.Bot, bot) + masked_columns = [] + if not include_secret: + masked_columns = ['adapter_config'] + + return self.ap.persistence_mgr.serialize_model(persistence_bot.Bot, bot, masked_columns) + + async def get_runtime_bot_info(self, bot_uuid: str, include_secret: bool = True) -> dict: + """获取机器人运行时信息""" + persistence_bot = await self.get_bot(bot_uuid, include_secret) + if persistence_bot is None: + raise Exception('Bot not found') + + adapter_runtime_values = {} + + runtime_bot = await self.ap.platform_mgr.get_bot_by_uuid(bot_uuid) + if runtime_bot is not None: + adapter_runtime_values['bot_account_id'] = runtime_bot.adapter.bot_account_id + + persistence_bot['adapter_runtime_values'] = adapter_runtime_values + + return persistence_bot async def create_bot(self, bot_data: dict) -> str: """Create bot""" diff --git a/pkg/api/http/service/model.py b/pkg/api/http/service/model.py index d3f3d5d8..036c1b9c 100644 --- a/pkg/api/http/service/model.py +++ b/pkg/api/http/service/model.py @@ -7,7 +7,7 @@ from ....core import app from ....entity.persistence import model as persistence_model from ....entity.persistence import pipeline as persistence_pipeline from ....provider.modelmgr import requester as model_requester -from ....provider import entities as llm_entities +from langbot_plugin.api.entities.builtin.provider import message as provider_message class LLMModelsService: @@ -16,11 +16,19 @@ class LLMModelsService: def __init__(self, ap: app.Application) -> None: self.ap = ap - async def get_llm_models(self) -> list[dict]: + async def get_llm_models(self, include_secret: bool = True) -> list[dict]: result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.LLMModel)) models = result.all() - return [self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model) for model in models] + + masked_columns = [] + if not include_secret: + masked_columns = ['api_keys'] + + return [ + self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model, masked_columns) + for model in models + ] async def create_llm_model(self, model_data: dict) -> str: model_data['uuid'] = str(uuid.uuid4()) @@ -99,7 +107,7 @@ class LLMModelsService: await runtime_llm_model.requester.invoke_llm( query=None, model=runtime_llm_model, - messages=[llm_entities.Message(role='user', content='Hello, world!')], + messages=[provider_message.Message(role='user', content='Hello, world!')], funcs=[], extra_args=model_data.get('extra_args', {}), ) diff --git a/pkg/api/http/service/user.py b/pkg/api/http/service/user.py index 7a1f0323..b2403d15 100644 --- a/pkg/api/http/service/user.py +++ b/pkg/api/http/service/user.py @@ -85,15 +85,15 @@ class UserService: async def change_password(self, user_email: str, current_password: str, new_password: str) -> None: ph = argon2.PasswordHasher() - + user_obj = await self.get_user_by_email(user_email) if user_obj is None: raise ValueError('User not found') - + ph.verify(user_obj.password, current_password) - + hashed_password = ph.hash(new_password) - + await self.ap.persistence_mgr.execute_async( sqlalchemy.update(user.User).where(user.User.user == user_email).values(password=hashed_password) ) diff --git a/pkg/command/cmdmgr.py b/pkg/command/cmdmgr.py index 1bd03fcf..17da8161 100644 --- a/pkg/command/cmdmgr.py +++ b/pkg/command/cmdmgr.py @@ -2,9 +2,12 @@ from __future__ import annotations import typing -from ..core import app, entities as core_entities -from . import entities, operator, errors +from ..core import app +from . import operator from ..utils import importutil +import langbot_plugin.api.entities.builtin.provider.session as provider_session +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query +from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors # 引入所有算子以便注册 from . import operators @@ -13,13 +16,11 @@ importutil.import_modules_in_pkg(operators) class CommandManager: - """命令管理器""" - ap: app.Application cmd_list: list[operator.CommandOperator] """ - 运行时命令列表,扁平存储,各个对象包含对应的子节点引用 + Runtime command list, flat storage, each object contains a reference to the corresponding child node """ def __init__(self, ap: app.Application): @@ -55,43 +56,28 @@ class CommandManager: async def _execute( self, - context: entities.ExecuteContext, + context: command_context.ExecuteContext, operator_list: list[operator.CommandOperator], operator: operator.CommandOperator = None, - ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + ) -> typing.AsyncGenerator[command_context.CommandReturn, None]: """执行命令""" - found = False - if len(context.crt_params) > 0: # 查找下一个参数是否对应此节点的某个子节点名 - for oper in operator_list: - if (context.crt_params[0] == oper.name or context.crt_params[0] in oper.alias) and ( - oper.parent_class is None or oper.parent_class == operator.__class__ - ): - found = True + command_list = await self.ap.plugin_connector.list_commands() - context.crt_command = context.crt_params[0] - context.crt_params = context.crt_params[1:] - - async for ret in self._execute(context, oper.children, oper): - yield ret - break - - if not found: # 如果下一个参数未在此节点的子节点中找到,则执行此节点或者报错 - if operator is None: - yield entities.CommandReturn(error=errors.CommandNotFoundError(context.crt_params[0])) - else: - if operator.lowest_privilege > context.privilege: - yield entities.CommandReturn(error=errors.CommandPrivilegeError(operator.name)) - else: - async for ret in operator.execute(context): - yield ret + for command in command_list: + if command.metadata.name == context.command: + async for ret in self.ap.plugin_connector.execute_command(context): + yield ret + break + else: + yield command_context.CommandReturn(error=command_errors.CommandNotFoundError(context.command)) async def execute( self, command_text: str, - query: core_entities.Query, - session: core_entities.Session, - ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + query: pipeline_query.Query, + session: provider_session.Session, + ) -> typing.AsyncGenerator[command_context.CommandReturn, None]: """执行命令""" privilege = 1 @@ -99,8 +85,8 @@ class CommandManager: if f'{query.launcher_type.value}_{query.launcher_id}' in self.ap.instance_config.data['admins']: privilege = 2 - ctx = entities.ExecuteContext( - query=query, + ctx = command_context.ExecuteContext( + query_id=query.query_id, session=session, command_text=command_text, command='', @@ -110,5 +96,9 @@ class CommandManager: privilege=privilege, ) + ctx.command = ctx.params[0] + + ctx.shift() + async for ret in self._execute(ctx, self.cmd_list): yield ret diff --git a/pkg/command/entities.py b/pkg/command/entities.py deleted file mode 100644 index cccd588e..00000000 --- a/pkg/command/entities.py +++ /dev/null @@ -1,74 +0,0 @@ -from __future__ import annotations - -import typing - -import pydantic.v1 as pydantic - -from ..core import entities as core_entities -from . import errors -from ..platform.types import message as platform_message - - -class CommandReturn(pydantic.BaseModel): - """命令返回值""" - - text: typing.Optional[str] = None - """文本 - """ - - image: typing.Optional[platform_message.Image] = None - """弃用""" - - image_url: typing.Optional[str] = None - """图片链接 - """ - - error: typing.Optional[errors.CommandError] = None - """错误 - """ - - class Config: - arbitrary_types_allowed = True - - -class ExecuteContext(pydantic.BaseModel): - """单次命令执行上下文""" - - query: core_entities.Query - """本次消息的请求对象""" - - session: core_entities.Session - """本次消息所属的会话对象""" - - command_text: str - """命令完整文本""" - - command: str - """命令名称""" - - crt_command: str - """当前命令 - - 多级命令中crt_command为当前命令,command为根命令。 - 例如:!plugin on Webwlkr - 处理到plugin时,command为plugin,crt_command为plugin - 处理到on时,command为plugin,crt_command为on - """ - - params: list[str] - """命令参数 - - 整个命令以空格分割后的参数列表 - """ - - crt_params: list[str] - """当前命令参数 - - 多级命令中crt_params为当前命令参数,params为根命令参数。 - 例如:!plugin on Webwlkr - 处理到plugin时,params为['on', 'Webwlkr'],crt_params为['on', 'Webwlkr'] - 处理到on时,params为['on', 'Webwlkr'],crt_params为['Webwlkr'] - """ - - privilege: int - """发起人权限""" diff --git a/pkg/command/errors.py b/pkg/command/errors.py deleted file mode 100644 index df05b3d1..00000000 --- a/pkg/command/errors.py +++ /dev/null @@ -1,26 +0,0 @@ -class CommandError(Exception): - def __init__(self, message: str = None): - self.message = message - - def __str__(self): - return self.message - - -class CommandNotFoundError(CommandError): - def __init__(self, message: str = None): - super().__init__('未知命令: ' + message) - - -class CommandPrivilegeError(CommandError): - def __init__(self, message: str = None): - super().__init__('权限不足: ' + message) - - -class ParamNotEnoughError(CommandError): - def __init__(self, message: str = None): - super().__init__('参数不足: ' + message) - - -class CommandOperationError(CommandError): - def __init__(self, message: str = None): - super().__init__('操作失败: ' + message) diff --git a/pkg/command/operator.py b/pkg/command/operator.py index 9ee3de37..0157cf28 100644 --- a/pkg/command/operator.py +++ b/pkg/command/operator.py @@ -4,7 +4,7 @@ import typing import abc from ..core import app -from . import entities +from langbot_plugin.api.entities.builtin.command import context as command_context preregistered_operators: list[typing.Type[CommandOperator]] = [] @@ -95,16 +95,18 @@ class CommandOperator(metaclass=abc.ABCMeta): pass @abc.abstractmethod - async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: + async def execute( + self, context: command_context.ExecuteContext + ) -> typing.AsyncGenerator[command_context.CommandReturn, None]: """实现此方法以执行命令 支持多次yield以返回多个结果。 例如:一个安装插件的命令,可能会有下载、解压、安装等多个步骤,每个步骤都可以返回一个结果。 Args: - context (entities.ExecuteContext): 命令执行上下文 + context (command_context.ExecuteContext): 命令执行上下文 Yields: - entities.CommandReturn: 命令返回封装 + command_context.CommandReturn: 命令返回封装 """ pass diff --git a/pkg/command/operators/cmd.py b/pkg/command/operators/cmd.py index f5a69a7b..cb0c3554 100644 --- a/pkg/command/operators/cmd.py +++ b/pkg/command/operators/cmd.py @@ -2,14 +2,17 @@ from __future__ import annotations import typing -from .. import operator, entities, errors +from .. import operator +from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors @operator.operator_class(name='cmd', help='显示命令列表', usage='!cmd\n!cmd <命令名称>') class CmdOperator(operator.CommandOperator): """命令列表""" - async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: + async def execute( + self, context: command_context.ExecuteContext + ) -> typing.AsyncGenerator[command_context.CommandReturn, None]: """执行""" if len(context.crt_params) == 0: reply_str = '当前所有命令: \n\n' @@ -20,7 +23,7 @@ class CmdOperator(operator.CommandOperator): reply_str += '\n使用 !cmd <命令名称> 查看命令的详细帮助' - yield entities.CommandReturn(text=reply_str.strip()) + yield command_context.CommandReturn(text=reply_str.strip()) else: cmd_name = context.crt_params[0] @@ -33,9 +36,9 @@ class CmdOperator(operator.CommandOperator): break if cmd is None: - yield entities.CommandReturn(error=errors.CommandNotFoundError(cmd_name)) + yield command_context.CommandReturn(error=command_errors.CommandNotFoundError(cmd_name)) else: reply_str = f'{cmd.name}: {cmd.help}\n\n' reply_str += f'使用方法: \n{cmd.usage}' - yield entities.CommandReturn(text=reply_str.strip()) + yield command_context.CommandReturn(text=reply_str.strip()) diff --git a/pkg/command/operators/delc.py b/pkg/command/operators/delc.py index 7e72ff3c..06db3d1e 100644 --- a/pkg/command/operators/delc.py +++ b/pkg/command/operators/delc.py @@ -2,23 +2,26 @@ from __future__ import annotations import typing -from .. import operator, entities, errors +from .. import operator +from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors @operator.operator_class(name='del', help='删除当前会话的历史记录', usage='!del <序号>\n!del all') class DelOperator(operator.CommandOperator): - async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: + async def execute( + self, context: command_context.ExecuteContext + ) -> typing.AsyncGenerator[command_context.CommandReturn, None]: if context.session.conversations: delete_index = 0 if len(context.crt_params) > 0: try: delete_index = int(context.crt_params[0]) except Exception: - yield entities.CommandReturn(error=errors.CommandOperationError('索引必须是整数')) + yield command_context.CommandReturn(error=command_errors.CommandOperationError('索引必须是整数')) return if delete_index < 0 or delete_index >= len(context.session.conversations): - yield entities.CommandReturn(error=errors.CommandOperationError('索引超出范围')) + yield command_context.CommandReturn(error=command_errors.CommandOperationError('索引超出范围')) return # 倒序 @@ -29,15 +32,17 @@ class DelOperator(operator.CommandOperator): del context.session.conversations[to_delete_index] - yield entities.CommandReturn(text=f'已删除对话: {delete_index}') + yield command_context.CommandReturn(text=f'已删除对话: {delete_index}') else: - yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话')) + yield command_context.CommandReturn(error=command_errors.CommandOperationError('当前没有对话')) @operator.operator_class(name='all', help='删除此会话的所有历史记录', parent_class=DelOperator) class DelAllOperator(operator.CommandOperator): - async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: + async def execute( + self, context: command_context.ExecuteContext + ) -> typing.AsyncGenerator[command_context.CommandReturn, None]: context.session.conversations = [] context.session.using_conversation = None - yield entities.CommandReturn(text='已删除所有对话') + yield command_context.CommandReturn(text='已删除所有对话') diff --git a/pkg/command/operators/func.py b/pkg/command/operators/func.py index 648cc5e2..e7828a51 100644 --- a/pkg/command/operators/func.py +++ b/pkg/command/operators/func.py @@ -1,19 +1,20 @@ from __future__ import annotations from typing import AsyncGenerator -from .. import operator, entities +from .. import operator +from langbot_plugin.api.entities.builtin.command import context as command_context @operator.operator_class(name='func', help='查看所有已注册的内容函数', usage='!func') class FuncOperator(operator.CommandOperator): - async def execute(self, context: entities.ExecuteContext) -> AsyncGenerator[entities.CommandReturn, None]: + async def execute( + self, context: command_context.ExecuteContext + ) -> AsyncGenerator[command_context.CommandReturn, None]: reply_str = '当前已启用的内容函数: \n\n' index = 1 - all_functions = await self.ap.tool_mgr.get_all_functions( - plugin_enabled=True, - ) + all_functions = await self.ap.tool_mgr.get_all_tools() for func in all_functions: reply_str += '{}. {}:\n{}\n\n'.format( @@ -23,4 +24,4 @@ class FuncOperator(operator.CommandOperator): ) index += 1 - yield entities.CommandReturn(text=reply_str) + yield command_context.CommandReturn(text=reply_str) diff --git a/pkg/command/operators/help.py b/pkg/command/operators/help.py index 91ad66dc..609f05ad 100644 --- a/pkg/command/operators/help.py +++ b/pkg/command/operators/help.py @@ -2,14 +2,17 @@ from __future__ import annotations import typing -from .. import operator, entities +from .. import operator +from langbot_plugin.api.entities.builtin.command import context as command_context @operator.operator_class(name='help', help='显示帮助', usage='!help\n!help <命令名称>') class HelpOperator(operator.CommandOperator): - async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: + async def execute( + self, context: command_context.ExecuteContext + ) -> typing.AsyncGenerator[command_context.CommandReturn, None]: help = 'LangBot - 大语言模型原生即时通信机器人平台\n链接:https://langbot.app' help += '\n发送命令 !cmd 可查看命令列表' - yield entities.CommandReturn(text=help) + yield command_context.CommandReturn(text=help) diff --git a/pkg/command/operators/last.py b/pkg/command/operators/last.py index 25b1fc6a..3f92e2e2 100644 --- a/pkg/command/operators/last.py +++ b/pkg/command/operators/last.py @@ -3,26 +3,31 @@ from __future__ import annotations import typing -from .. import operator, entities, errors +from .. import operator +from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors @operator.operator_class(name='last', help='切换到前一个对话', usage='!last') class LastOperator(operator.CommandOperator): - async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: + async def execute( + self, context: command_context.ExecuteContext + ) -> typing.AsyncGenerator[command_context.CommandReturn, None]: if context.session.conversations: # 找到当前会话的上一个会话 for index in range(len(context.session.conversations) - 1, -1, -1): if context.session.conversations[index] == context.session.using_conversation: if index == 0: - yield entities.CommandReturn(error=errors.CommandOperationError('已经是第一个对话了')) + yield command_context.CommandReturn( + error=command_errors.CommandOperationError('已经是第一个对话了') + ) return else: context.session.using_conversation = context.session.conversations[index - 1] time_str = context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S') - yield entities.CommandReturn( + yield command_context.CommandReturn( text=f'已切换到上一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].readable_str()}' ) return else: - yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话')) + yield command_context.CommandReturn(error=command_errors.CommandOperationError('当前没有对话')) diff --git a/pkg/command/operators/list.py b/pkg/command/operators/list.py index 70ff3945..ca1bf8e9 100644 --- a/pkg/command/operators/list.py +++ b/pkg/command/operators/list.py @@ -2,19 +2,22 @@ from __future__ import annotations import typing -from .. import operator, entities, errors +from .. import operator +from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors @operator.operator_class(name='list', help='列出此会话中的所有历史对话', usage='!list\n!list <页码>') class ListOperator(operator.CommandOperator): - async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: + async def execute( + self, context: command_context.ExecuteContext + ) -> typing.AsyncGenerator[command_context.CommandReturn, None]: page = 0 if len(context.crt_params) > 0: try: page = int(context.crt_params[0] - 1) except Exception: - yield entities.CommandReturn(error=errors.CommandOperationError('页码应为整数')) + yield command_context.CommandReturn(error=command_errors.CommandOperationError('页码应为整数')) return record_per_page = 10 @@ -45,4 +48,4 @@ class ListOperator(operator.CommandOperator): else: content += f'\n当前会话: {using_conv_index} {context.session.using_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S")}: {context.session.using_conversation.messages[0].readable_str() if len(context.session.using_conversation.messages) > 0 else "无内容"}' - yield entities.CommandReturn(text=f'第 {page + 1} 页 (时间倒序):\n{content}') + yield command_context.CommandReturn(text=f'第 {page + 1} 页 (时间倒序):\n{content}') diff --git a/pkg/command/operators/next.py b/pkg/command/operators/next.py index 938c8331..87cc565c 100644 --- a/pkg/command/operators/next.py +++ b/pkg/command/operators/next.py @@ -2,26 +2,31 @@ from __future__ import annotations import typing -from .. import operator, entities, errors +from .. import operator +from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors @operator.operator_class(name='next', help='切换到后一个对话', usage='!next') class NextOperator(operator.CommandOperator): - async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: + async def execute( + self, context: command_context.ExecuteContext + ) -> typing.AsyncGenerator[command_context.CommandReturn, None]: if context.session.conversations: # 找到当前会话的下一个会话 for index in range(len(context.session.conversations)): if context.session.conversations[index] == context.session.using_conversation: if index == len(context.session.conversations) - 1: - yield entities.CommandReturn(error=errors.CommandOperationError('已经是最后一个对话了')) + yield command_context.CommandReturn( + error=command_errors.CommandOperationError('已经是最后一个对话了') + ) return else: context.session.using_conversation = context.session.conversations[index + 1] time_str = context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S') - yield entities.CommandReturn( + yield command_context.CommandReturn( text=f'已切换到后一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].content}' ) return else: - yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话')) + yield command_context.CommandReturn(error=command_errors.CommandOperationError('当前没有对话')) diff --git a/pkg/command/operators/plugin.py b/pkg/command/operators/plugin.py index 40ec0e3a..1c135bd4 100644 --- a/pkg/command/operators/plugin.py +++ b/pkg/command/operators/plugin.py @@ -2,7 +2,8 @@ from __future__ import annotations import typing import traceback -from .. import operator, entities, errors +from .. import operator +from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors @operator.operator_class( @@ -11,7 +12,9 @@ from .. import operator, entities, errors usage='!plugin\n!plugin get <插件仓库地址>\n!plugin update\n!plugin del <插件名>\n!plugin on <插件名>\n!plugin off <插件名>', ) class PluginOperator(operator.CommandOperator): - async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: + async def execute( + self, context: command_context.ExecuteContext + ) -> typing.AsyncGenerator[command_context.CommandReturn, None]: plugin_list = self.ap.plugin_mgr.plugins() reply_str = '所有插件({}):\n'.format(len(plugin_list)) idx = 0 @@ -27,32 +30,36 @@ class PluginOperator(operator.CommandOperator): idx += 1 - yield entities.CommandReturn(text=reply_str) + yield command_context.CommandReturn(text=reply_str) @operator.operator_class(name='get', help='安装插件', privilege=2, parent_class=PluginOperator) class PluginGetOperator(operator.CommandOperator): - async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: + async def execute( + self, context: command_context.ExecuteContext + ) -> typing.AsyncGenerator[command_context.CommandReturn, None]: if len(context.crt_params) == 0: - yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件仓库地址')) + yield command_context.CommandReturn(error=command_errors.ParamNotEnoughError('请提供插件仓库地址')) else: repo = context.crt_params[0] - yield entities.CommandReturn(text='正在安装插件...') + yield command_context.CommandReturn(text='正在安装插件...') try: await self.ap.plugin_mgr.install_plugin(repo) - yield entities.CommandReturn(text='插件安装成功,请重启程序以加载插件') + yield command_context.CommandReturn(text='插件安装成功,请重启程序以加载插件') except Exception as e: traceback.print_exc() - yield entities.CommandReturn(error=errors.CommandError('插件安装失败: ' + str(e))) + yield command_context.CommandReturn(error=command_errors.CommandError('插件安装失败: ' + str(e))) @operator.operator_class(name='update', help='更新插件', privilege=2, parent_class=PluginOperator) class PluginUpdateOperator(operator.CommandOperator): - async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: + async def execute( + self, context: command_context.ExecuteContext + ) -> typing.AsyncGenerator[command_context.CommandReturn, None]: if len(context.crt_params) == 0: - yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称')) + yield command_context.CommandReturn(error=command_errors.ParamNotEnoughError('请提供插件名称')) else: plugin_name = context.crt_params[0] @@ -60,24 +67,26 @@ class PluginUpdateOperator(operator.CommandOperator): plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name) if plugin_container is not None: - yield entities.CommandReturn(text='正在更新插件...') + yield command_context.CommandReturn(text='正在更新插件...') await self.ap.plugin_mgr.update_plugin(plugin_name) - yield entities.CommandReturn(text='插件更新成功,请重启程序以加载插件') + yield command_context.CommandReturn(text='插件更新成功,请重启程序以加载插件') else: - yield entities.CommandReturn(error=errors.CommandError('插件更新失败: 未找到插件')) + yield command_context.CommandReturn(error=command_errors.CommandError('插件更新失败: 未找到插件')) except Exception as e: traceback.print_exc() - yield entities.CommandReturn(error=errors.CommandError('插件更新失败: ' + str(e))) + yield command_context.CommandReturn(error=command_errors.CommandError('插件更新失败: ' + str(e))) @operator.operator_class(name='all', help='更新所有插件', privilege=2, parent_class=PluginUpdateOperator) class PluginUpdateAllOperator(operator.CommandOperator): - async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: + async def execute( + self, context: command_context.ExecuteContext + ) -> typing.AsyncGenerator[command_context.CommandReturn, None]: try: plugins = [p.plugin_name for p in self.ap.plugin_mgr.plugins()] if plugins: - yield entities.CommandReturn(text='正在更新插件...') + yield command_context.CommandReturn(text='正在更新插件...') updated = [] try: for plugin_name in plugins: @@ -85,20 +94,22 @@ class PluginUpdateAllOperator(operator.CommandOperator): updated.append(plugin_name) except Exception as e: traceback.print_exc() - yield entities.CommandReturn(error=errors.CommandError('插件更新失败: ' + str(e))) - yield entities.CommandReturn(text='已更新插件: {}'.format(', '.join(updated))) + yield command_context.CommandReturn(error=command_errors.CommandError('插件更新失败: ' + str(e))) + yield command_context.CommandReturn(text='已更新插件: {}'.format(', '.join(updated))) else: - yield entities.CommandReturn(text='没有可更新的插件') + yield command_context.CommandReturn(text='没有可更新的插件') except Exception as e: traceback.print_exc() - yield entities.CommandReturn(error=errors.CommandError('插件更新失败: ' + str(e))) + yield command_context.CommandReturn(error=command_errors.CommandError('插件更新失败: ' + str(e))) @operator.operator_class(name='del', help='删除插件', privilege=2, parent_class=PluginOperator) class PluginDelOperator(operator.CommandOperator): - async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: + async def execute( + self, context: command_context.ExecuteContext + ) -> typing.AsyncGenerator[command_context.CommandReturn, None]: if len(context.crt_params) == 0: - yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称')) + yield command_context.CommandReturn(error=command_errors.ParamNotEnoughError('请提供插件名称')) else: plugin_name = context.crt_params[0] @@ -106,51 +117,55 @@ class PluginDelOperator(operator.CommandOperator): plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name) if plugin_container is not None: - yield entities.CommandReturn(text='正在删除插件...') + yield command_context.CommandReturn(text='正在删除插件...') await self.ap.plugin_mgr.uninstall_plugin(plugin_name) - yield entities.CommandReturn(text='插件删除成功,请重启程序以加载插件') + yield command_context.CommandReturn(text='插件删除成功,请重启程序以加载插件') else: - yield entities.CommandReturn(error=errors.CommandError('插件删除失败: 未找到插件')) + yield command_context.CommandReturn(error=command_errors.CommandError('插件删除失败: 未找到插件')) except Exception as e: traceback.print_exc() - yield entities.CommandReturn(error=errors.CommandError('插件删除失败: ' + str(e))) + yield command_context.CommandReturn(error=command_errors.CommandError('插件删除失败: ' + str(e))) @operator.operator_class(name='on', help='启用插件', privilege=2, parent_class=PluginOperator) class PluginEnableOperator(operator.CommandOperator): - async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: + async def execute( + self, context: command_context.ExecuteContext + ) -> typing.AsyncGenerator[command_context.CommandReturn, None]: if len(context.crt_params) == 0: - yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称')) + yield command_context.CommandReturn(error=command_errors.ParamNotEnoughError('请提供插件名称')) else: plugin_name = context.crt_params[0] try: if await self.ap.plugin_mgr.update_plugin_switch(plugin_name, True): - yield entities.CommandReturn(text='已启用插件: {}'.format(plugin_name)) + yield command_context.CommandReturn(text='已启用插件: {}'.format(plugin_name)) else: - yield entities.CommandReturn( - error=errors.CommandError('插件状态修改失败: 未找到插件 {}'.format(plugin_name)) + yield command_context.CommandReturn( + error=command_errors.CommandError('插件状态修改失败: 未找到插件 {}'.format(plugin_name)) ) except Exception as e: traceback.print_exc() - yield entities.CommandReturn(error=errors.CommandError('插件状态修改失败: ' + str(e))) + yield command_context.CommandReturn(error=command_errors.CommandError('插件状态修改失败: ' + str(e))) @operator.operator_class(name='off', help='禁用插件', privilege=2, parent_class=PluginOperator) class PluginDisableOperator(operator.CommandOperator): - async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: + async def execute( + self, context: command_context.ExecuteContext + ) -> typing.AsyncGenerator[command_context.CommandReturn, None]: if len(context.crt_params) == 0: - yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称')) + yield command_context.CommandReturn(error=command_errors.ParamNotEnoughError('请提供插件名称')) else: plugin_name = context.crt_params[0] try: if await self.ap.plugin_mgr.update_plugin_switch(plugin_name, False): - yield entities.CommandReturn(text='已禁用插件: {}'.format(plugin_name)) + yield command_context.CommandReturn(text='已禁用插件: {}'.format(plugin_name)) else: - yield entities.CommandReturn( - error=errors.CommandError('插件状态修改失败: 未找到插件 {}'.format(plugin_name)) + yield command_context.CommandReturn( + error=command_errors.CommandError('插件状态修改失败: 未找到插件 {}'.format(plugin_name)) ) except Exception as e: traceback.print_exc() - yield entities.CommandReturn(error=errors.CommandError('插件状态修改失败: ' + str(e))) + yield command_context.CommandReturn(error=command_errors.CommandError('插件状态修改失败: ' + str(e))) diff --git a/pkg/command/operators/prompt.py b/pkg/command/operators/prompt.py index fdcba2bd..b43be2cf 100644 --- a/pkg/command/operators/prompt.py +++ b/pkg/command/operators/prompt.py @@ -2,19 +2,22 @@ from __future__ import annotations import typing -from .. import operator, entities, errors +from .. import operator +from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors @operator.operator_class(name='prompt', help='查看当前对话的前文', usage='!prompt') class PromptOperator(operator.CommandOperator): - async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: + async def execute( + self, context: command_context.ExecuteContext + ) -> typing.AsyncGenerator[command_context.CommandReturn, None]: """执行""" if context.session.using_conversation is None: - yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话')) + yield command_context.CommandReturn(error=command_errors.CommandOperationError('当前没有对话')) else: reply_str = '当前对话所有内容:\n\n' for msg in context.session.using_conversation.messages: reply_str += f'{msg.role}: {msg.content}\n' - yield entities.CommandReturn(text=reply_str) + yield command_context.CommandReturn(text=reply_str) diff --git a/pkg/command/operators/resend.py b/pkg/command/operators/resend.py index 39789fef..14bfee99 100644 --- a/pkg/command/operators/resend.py +++ b/pkg/command/operators/resend.py @@ -2,15 +2,18 @@ from __future__ import annotations import typing -from .. import operator, entities, errors +from .. import operator +from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors @operator.operator_class(name='resend', help='重发当前会话的最后一条消息', usage='!resend') class ResendOperator(operator.CommandOperator): - async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: + async def execute( + self, context: command_context.ExecuteContext + ) -> typing.AsyncGenerator[command_context.CommandReturn, None]: # 回滚到最后一条用户message前 if context.session.using_conversation is None: - yield entities.CommandReturn(error=errors.CommandError('当前没有对话')) + yield command_context.CommandReturn(error=command_errors.CommandError('当前没有对话')) else: conv_msg = context.session.using_conversation.messages @@ -23,4 +26,4 @@ class ResendOperator(operator.CommandOperator): conv_msg.pop() # 不重发了,提示用户已删除就行了 - yield entities.CommandReturn(text='已删除最后一次请求记录') + yield command_context.CommandReturn(text='已删除最后一次请求记录') diff --git a/pkg/command/operators/reset.py b/pkg/command/operators/reset.py index 008143a1..0c85fb32 100644 --- a/pkg/command/operators/reset.py +++ b/pkg/command/operators/reset.py @@ -2,13 +2,16 @@ from __future__ import annotations import typing -from .. import operator, entities +from .. import operator +from langbot_plugin.api.entities.builtin.command import context as command_context @operator.operator_class(name='reset', help='重置当前会话', usage='!reset') class ResetOperator(operator.CommandOperator): - async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: + async def execute( + self, context: command_context.ExecuteContext + ) -> typing.AsyncGenerator[command_context.CommandReturn, None]: """执行""" context.session.using_conversation = None - yield entities.CommandReturn(text='已重置当前会话') + yield command_context.CommandReturn(text='已重置当前会话') diff --git a/pkg/command/operators/update.py b/pkg/command/operators/update.py deleted file mode 100644 index 29b8f560..00000000 --- a/pkg/command/operators/update.py +++ /dev/null @@ -1,11 +0,0 @@ -from __future__ import annotations - -import typing - -from .. import operator, entities - - -@operator.operator_class(name='update', help='更新程序', usage='!update', privilege=2) -class UpdateCommand(operator.CommandOperator): - async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: - yield entities.CommandReturn(text='不再支持通过命令更新,请查看 LangBot 文档。') diff --git a/pkg/command/operators/version.py b/pkg/command/operators/version.py index 200875aa..5b3c3358 100644 --- a/pkg/command/operators/version.py +++ b/pkg/command/operators/version.py @@ -2,12 +2,15 @@ from __future__ import annotations import typing -from .. import operator, entities +from .. import operator +from langbot_plugin.api.entities.builtin.command import context as command_context @operator.operator_class(name='version', help='显示版本信息', usage='!version') class VersionCommand(operator.CommandOperator): - async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: + async def execute( + self, context: command_context.ExecuteContext + ) -> typing.AsyncGenerator[command_context.CommandReturn, None]: reply_str = f'当前版本: \n{self.ap.ver_mgr.get_current_version()}' try: @@ -16,4 +19,4 @@ class VersionCommand(operator.CommandOperator): except Exception: pass - yield entities.CommandReturn(text=reply_str.strip()) + yield command_context.CommandReturn(text=reply_str.strip()) diff --git a/pkg/core/app.py b/pkg/core/app.py index 21816cfc..27b780f6 100644 --- a/pkg/core/app.py +++ b/pkg/core/app.py @@ -3,7 +3,6 @@ from __future__ import annotations import logging import asyncio import traceback -import sys import os from ..platform import botmgr as im_mgr @@ -12,7 +11,7 @@ from ..provider.modelmgr import modelmgr as llm_model_mgr from ..provider.tools import toolmgr as llm_tool_mgr from ..config import manager as config_mgr from ..command import cmdmgr -from ..plugin import manager as plugin_mgr +from ..plugin import connector as plugin_connector from ..pipeline import pool from ..pipeline import controller, pipelinemgr from ..utils import version as version_mgr, proxy as proxy_mgr, announce as announce_mgr @@ -80,7 +79,7 @@ class Application: # ========================= - plugin_mgr: plugin_mgr.PluginManager = None + plugin_connector: plugin_connector.PluginRuntimeConnector = None query_pool: pool.QueryPool = None @@ -128,7 +127,7 @@ class Application: async def run(self): try: - await self.plugin_mgr.initialize_plugins() + await self.plugin_connector.initialize_plugins() # 后续可能会允许动态重启其他任务 # 故为了防止程序在非 Ctrl-C 情况下退出,这里创建一个不会结束的协程 @@ -169,6 +168,9 @@ class Application: self.logger.error(f'Application runtime fatal exception: {e}') self.logger.debug(f'Traceback: {traceback.format_exc()}') + def dispose(self): + self.plugin_connector.dispose() + async def print_web_access_info(self): """Print access webui tips""" @@ -195,59 +197,3 @@ class Application: """.strip() for line in tips.split('\n'): self.logger.info(line) - - async def reload( - self, - scope: core_entities.LifecycleControlScope, - ): - match scope: - case core_entities.LifecycleControlScope.PLATFORM.value: - self.logger.info('Hot reload scope=' + scope) - await self.platform_mgr.shutdown() - - self.platform_mgr = im_mgr.PlatformManager(self) - - await self.platform_mgr.initialize() - - self.task_mgr.create_task( - self.platform_mgr.run(), - name='platform-manager', - scopes=[ - core_entities.LifecycleControlScope.APPLICATION, - core_entities.LifecycleControlScope.PLATFORM, - ], - ) - case core_entities.LifecycleControlScope.PLUGIN.value: - self.logger.info('Hot reload scope=' + scope) - await self.plugin_mgr.destroy_plugins() - - # 删除 sys.module 中所有的 plugins/* 下的模块 - for mod in list(sys.modules.keys()): - if mod.startswith('plugins.'): - del sys.modules[mod] - - self.plugin_mgr = plugin_mgr.PluginManager(self) - await self.plugin_mgr.initialize() - - await self.plugin_mgr.initialize_plugins() - - await self.plugin_mgr.load_plugins() - await self.plugin_mgr.initialize_plugins() - case core_entities.LifecycleControlScope.PROVIDER.value: - self.logger.info('Hot reload scope=' + scope) - - await self.tool_mgr.shutdown() - - llm_model_mgr_inst = llm_model_mgr.ModelManager(self) - await llm_model_mgr_inst.initialize() - self.model_mgr = llm_model_mgr_inst - - llm_session_mgr_inst = llm_session_mgr.SessionManager(self) - await llm_session_mgr_inst.initialize() - self.sess_mgr = llm_session_mgr_inst - - llm_tool_mgr_inst = llm_tool_mgr.ToolManager(self) - await llm_tool_mgr_inst.initialize() - self.tool_mgr = llm_tool_mgr_inst - case _: - pass diff --git a/pkg/core/boot.py b/pkg/core/boot.py index b8243d4a..11a2d5e2 100644 --- a/pkg/core/boot.py +++ b/pkg/core/boot.py @@ -51,8 +51,8 @@ async def main(loop: asyncio.AbstractEventLoop): import signal def signal_handler(sig, frame): + app_inst.dispose() print('[Signal] Program exit.') - # ap.shutdown() os._exit(0) signal.signal(signal.SIGINT, signal_handler) diff --git a/pkg/core/entities.py b/pkg/core/entities.py index 1efee3fc..4383f07f 100644 --- a/pkg/core/entities.py +++ b/pkg/core/entities.py @@ -1,18 +1,6 @@ from __future__ import annotations import enum -import typing -import datetime -import asyncio - -import pydantic.v1 as pydantic - -from ..provider import entities as llm_entities -from ..provider.modelmgr import requester -from ..provider.tools import entities as tools_entities -from ..platform import adapter as msadapter -from ..platform.types import message as platform_message -from ..platform.types import events as platform_events class LifecycleControlScope(enum.Enum): @@ -20,159 +8,3 @@ class LifecycleControlScope(enum.Enum): PLATFORM = 'platform' PLUGIN = 'plugin' PROVIDER = 'provider' - - -class LauncherTypes(enum.Enum): - """一个请求的发起者类型""" - - PERSON = 'person' - """私聊""" - - GROUP = 'group' - """群聊""" - - -class Query(pydantic.BaseModel): - """一次请求的信息封装""" - - query_id: int - """请求ID,添加进请求池时生成""" - - launcher_type: LauncherTypes - """会话类型,platform处理阶段设置""" - - launcher_id: typing.Union[int, str] - """会话ID,platform处理阶段设置""" - - sender_id: typing.Union[int, str] - """发送者ID,platform处理阶段设置""" - - message_event: platform_events.MessageEvent - """事件,platform收到的原始事件""" - - message_chain: platform_message.MessageChain - """消息链,platform收到的原始消息链""" - - bot_uuid: typing.Optional[str] = None - """机器人UUID。""" - - pipeline_uuid: typing.Optional[str] = None - """流水线UUID。""" - - pipeline_config: typing.Optional[dict[str, typing.Any]] = None - """流水线配置,由 Pipeline 在运行开始时设置。""" - - adapter: msadapter.MessagePlatformAdapter - """消息平台适配器对象,单个app中可能启用了多个消息平台适配器,此对象表明发起此query的适配器""" - - session: typing.Optional[Session] = None - """会话对象,由前置处理器阶段设置""" - - messages: typing.Optional[list[llm_entities.Message]] = [] - """历史消息列表,由前置处理器阶段设置""" - - prompt: typing.Optional[llm_entities.Prompt] = None - """情景预设内容,由前置处理器阶段设置""" - - user_message: typing.Optional[llm_entities.Message] = None - """此次请求的用户消息对象,由前置处理器阶段设置""" - - variables: typing.Optional[dict[str, typing.Any]] = None - """变量,由前置处理器阶段设置。在prompt中嵌入或由 Runner 传递到 LLMOps 平台。""" - - use_llm_model: typing.Optional[requester.RuntimeLLMModel] = None - """使用的对话模型,由前置处理器阶段设置""" - - use_funcs: typing.Optional[list[tools_entities.LLMFunction]] = None - """使用的函数,由前置处理器阶段设置""" - - resp_messages: ( - typing.Optional[list[llm_entities.Message]] - | typing.Optional[list[platform_message.MessageChain]] - | typing.Optional[list[llm_entities.MessageChunk]] - ) = [] - """由Process阶段生成的回复消息对象列表""" - - resp_message_chain: typing.Optional[list[platform_message.MessageChain]] = None - """回复消息链,从resp_messages包装而得""" - - # ======= 内部保留 ======= - current_stage: typing.Optional['pkg.pipeline.pipelinemgr.StageInstContainer'] = None - """当前所处阶段""" - - class Config: - arbitrary_types_allowed = True - - # ========== 插件可调用的 API(请求 API) ========== - - def set_variable(self, key: str, value: typing.Any): - """设置变量""" - if self.variables is None: - self.variables = {} - self.variables[key] = value - - def get_variable(self, key: str) -> typing.Any: - """获取变量""" - if self.variables is None: - return None - return self.variables.get(key) - - def get_variables(self) -> dict[str, typing.Any]: - """获取所有变量""" - if self.variables is None: - return {} - return self.variables - - -class Conversation(pydantic.BaseModel): - """对话,包含于 Session 中,一个 Session 可以有多个历史 Conversation,但只有一个当前使用的 Conversation""" - - prompt: llm_entities.Prompt - - messages: list[llm_entities.Message] - - create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now) - - update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now) - - use_llm_model: typing.Optional[requester.RuntimeLLMModel] = None - - use_funcs: typing.Optional[list[tools_entities.LLMFunction]] - - pipeline_uuid: str - """流水线UUID。""" - - bot_uuid: str - """机器人UUID。""" - - uuid: typing.Optional[str] = None - """该对话的 uuid,在创建时不会自动生成。而是当使用 Dify API 等由外部管理对话信息的服务时,用于绑定外部的会话。具体如何使用,取决于 Runner。""" - - class Config: - arbitrary_types_allowed = True - - -class Session(pydantic.BaseModel): - """会话,一个 Session 对应一个 {launcher_type.value}_{launcher_id}""" - - launcher_type: LauncherTypes - - launcher_id: typing.Union[int, str] - - sender_id: typing.Optional[typing.Union[int, str]] = 0 - - use_prompt_name: typing.Optional[str] = 'default' - - using_conversation: typing.Optional[Conversation] = None - - conversations: typing.Optional[list[Conversation]] = pydantic.Field(default_factory=list) - - create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now) - - update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now) - - semaphore: typing.Optional[asyncio.Semaphore] = None - """当前会话的信号量,用于限制并发""" - - class Config: - arbitrary_types_allowed = True diff --git a/pkg/core/stages/build_app.py b/pkg/core/stages/build_app.py index 0f28f0c8..54a64ae8 100644 --- a/pkg/core/stages/build_app.py +++ b/pkg/core/stages/build_app.py @@ -1,10 +1,11 @@ from __future__ import annotations +import asyncio from .. import stage, app from ...utils import version, proxy, announce from ...pipeline import pool, controller, pipelinemgr -from ...plugin import manager as plugin_mgr +from ...plugin import connector as plugin_connector from ...command import cmdmgr from ...provider.session import sessionmgr as llm_session_mgr from ...provider.modelmgr import modelmgr as llm_model_mgr @@ -62,10 +63,13 @@ class BuildAppStage(stage.BootingStage): ap.persistence_mgr = persistence_mgr_inst await persistence_mgr_inst.initialize() - plugin_mgr_inst = plugin_mgr.PluginManager(ap) - await plugin_mgr_inst.initialize() - ap.plugin_mgr = plugin_mgr_inst - await plugin_mgr_inst.load_plugins() + async def runtime_disconnect_callback(connector: plugin_connector.PluginRuntimeConnector) -> None: + await asyncio.sleep(3) + await plugin_connector_inst.initialize() + + plugin_connector_inst = plugin_connector.PluginRuntimeConnector(ap, runtime_disconnect_callback) + await plugin_connector_inst.initialize() + ap.plugin_connector = plugin_connector_inst cmd_mgr_inst = cmdmgr.CommandManager(ap) await cmd_mgr_inst.initialize() diff --git a/pkg/entity/persistence/bstorage.py b/pkg/entity/persistence/bstorage.py new file mode 100644 index 00000000..674dee29 --- /dev/null +++ b/pkg/entity/persistence/bstorage.py @@ -0,0 +1,22 @@ +import sqlalchemy + +from .base import Base + + +class BinaryStorage(Base): + """Current for plugin use only""" + + __tablename__ = 'binary_storages' + + unique_key = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True) + key = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) + owner_type = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) + owner = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) + value = sqlalchemy.Column(sqlalchemy.LargeBinary, nullable=False) + created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now()) + updated_at = sqlalchemy.Column( + sqlalchemy.DateTime, + nullable=False, + server_default=sqlalchemy.func.now(), + onupdate=sqlalchemy.func.now(), + ) diff --git a/pkg/entity/persistence/plugin.py b/pkg/entity/persistence/plugin.py index e777441f..61629586 100644 --- a/pkg/entity/persistence/plugin.py +++ b/pkg/entity/persistence/plugin.py @@ -13,6 +13,8 @@ class PluginSetting(Base): enabled = sqlalchemy.Column(sqlalchemy.Boolean, nullable=False, default=True) priority = sqlalchemy.Column(sqlalchemy.Integer, nullable=False, default=0) config = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default=dict) + install_source = sqlalchemy.Column(sqlalchemy.String(255), nullable=False, default='github') + install_info = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default=dict) created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now()) updated_at = sqlalchemy.Column( sqlalchemy.DateTime, diff --git a/pkg/persistence/mgr.py b/pkg/persistence/mgr.py index 3aa21ad2..9b926733 100644 --- a/pkg/persistence/mgr.py +++ b/pkg/persistence/mgr.py @@ -44,6 +44,38 @@ class PersistenceManager: await self.create_tables() + # run migrations + database_version = await self.execute_async( + sqlalchemy.select(metadata.Metadata).where(metadata.Metadata.key == 'database_version') + ) + + database_version = int(database_version.fetchone()[1]) + required_database_version = constants.required_database_version + + if database_version < required_database_version: + migrations = migration.preregistered_db_migrations + migrations.sort(key=lambda x: x.number) + + last_migration_number = database_version + + for migration_cls in migrations: + migration_instance = migration_cls(self.ap) + + if ( + migration_instance.number > database_version + and migration_instance.number <= required_database_version + ): + await migration_instance.upgrade() + await self.execute_async( + sqlalchemy.update(metadata.Metadata) + .where(metadata.Metadata.key == 'database_version') + .values({'value': str(migration_instance.number)}) + ) + last_migration_number = migration_instance.number + self.ap.logger.info(f'Migration {migration_instance.number} completed.') + + self.ap.logger.info(f'Successfully upgraded database to version {last_migration_number}.') + async def create_tables(self): # create tables async with self.get_db_engine().connect() as conn: @@ -87,38 +119,6 @@ class PersistenceManager: # ================================= - # run migrations - database_version = await self.execute_async( - sqlalchemy.select(metadata.Metadata).where(metadata.Metadata.key == 'database_version') - ) - - database_version = int(database_version.fetchone()[1]) - required_database_version = constants.required_database_version - - if database_version < required_database_version: - migrations = migration.preregistered_db_migrations - migrations.sort(key=lambda x: x.number) - - last_migration_number = database_version - - for migration_cls in migrations: - migration_instance = migration_cls(self.ap) - - if ( - migration_instance.number > database_version - and migration_instance.number <= required_database_version - ): - await migration_instance.upgrade() - await self.execute_async( - sqlalchemy.update(metadata.Metadata) - .where(metadata.Metadata.key == 'database_version') - .values({'value': str(migration_instance.number)}) - ) - last_migration_number = migration_instance.number - self.ap.logger.info(f'Migration {migration_instance.number} completed.') - - self.ap.logger.info(f'Successfully upgraded database to version {last_migration_number}.') - async def execute_async(self, *args, **kwargs) -> sqlalchemy.engine.cursor.CursorResult: async with self.get_db_engine().connect() as conn: result = await conn.execute(*args, **kwargs) @@ -128,10 +128,13 @@ class PersistenceManager: def get_db_engine(self) -> sqlalchemy_asyncio.AsyncEngine: return self.db.get_engine() - def serialize_model(self, model: typing.Type[sqlalchemy.Base], data: sqlalchemy.Base) -> dict: + def serialize_model( + self, model: typing.Type[sqlalchemy.Base], data: sqlalchemy.Base, masked_columns: list[str] = [] + ) -> dict: return { column.name: getattr(data, column.name) if not isinstance(getattr(data, column.name), (datetime.datetime)) else getattr(data, column.name).isoformat() for column in model.__table__.columns + if column.name not in masked_columns } diff --git a/pkg/persistence/migrations/dbm004_plugin_config.py b/pkg/persistence/migrations/dbm004_plugin_config.py new file mode 100644 index 00000000..fc7a175a --- /dev/null +++ b/pkg/persistence/migrations/dbm004_plugin_config.py @@ -0,0 +1,20 @@ +from .. import migration + + +@migration.migration_class(4) +class DBMigratePluginConfig(migration.DBMigration): + """插件配置""" + + async def upgrade(self): + """升级""" + + if 'plugin' not in self.ap.instance_config.data: + self.ap.instance_config.data['plugin'] = { + 'runtime_ws_url': 'ws://localhost:5400/control/ws', + } + + await self.ap.instance_config.dump_config() + + async def downgrade(self): + """降级""" + pass diff --git a/pkg/persistence/migrations/dbm006_plugin_install_source.py b/pkg/persistence/migrations/dbm006_plugin_install_source.py new file mode 100644 index 00000000..37f74929 --- /dev/null +++ b/pkg/persistence/migrations/dbm006_plugin_install_source.py @@ -0,0 +1,32 @@ +import sqlalchemy +from .. import migration + + +@migration.migration_class(6) +class DBMigratePluginInstallSource(migration.DBMigration): + """插件安装来源""" + + async def upgrade(self): + """升级""" + # 查询表结构获取所有列名(异步执行 SQL) + result = await self.ap.persistence_mgr.execute_async(sqlalchemy.text('PRAGMA table_info(plugin_settings);')) + # fetchall() 是同步方法,无需 await + columns = [row[1] for row in result.fetchall()] + + # 检查并添加 install_source 列 + if 'install_source' not in columns: + await self.ap.persistence_mgr.execute_async( + sqlalchemy.text( + "ALTER TABLE plugin_settings ADD COLUMN install_source VARCHAR(255) NOT NULL DEFAULT 'github'" + ) + ) + + # 检查并添加 install_info 列 + if 'install_info' not in columns: + await self.ap.persistence_mgr.execute_async( + sqlalchemy.text("ALTER TABLE plugin_settings ADD COLUMN install_info JSON NOT NULL DEFAULT '{}'") + ) + + async def downgrade(self): + """降级""" + pass diff --git a/pkg/pipeline/bansess/bansess.py b/pkg/pipeline/bansess/bansess.py index ed6ddf74..a1cad2b1 100644 --- a/pkg/pipeline/bansess/bansess.py +++ b/pkg/pipeline/bansess/bansess.py @@ -1,7 +1,7 @@ from __future__ import annotations from .. import stage, entities -from ...core import entities as core_entities +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query @stage.stage_class('BanSessionCheckStage') @@ -14,7 +14,7 @@ class BanSessionCheckStage(stage.PipelineStage): async def initialize(self, pipeline_config: dict): pass - async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: + async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult: found = False mode = query.pipeline_config['trigger']['access-control']['mode'] diff --git a/pkg/pipeline/cntfilter/cntfilter.py b/pkg/pipeline/cntfilter/cntfilter.py index e035c1d0..26b00411 100644 --- a/pkg/pipeline/cntfilter/cntfilter.py +++ b/pkg/pipeline/cntfilter/cntfilter.py @@ -3,12 +3,11 @@ from __future__ import annotations from ...core import app from .. import stage, entities -from ...core import entities as core_entities from . import filter as filter_model, entities as filter_entities -from ...provider import entities as llm_entities -from ...platform.types import message as platform_message +from langbot_plugin.api.entities.builtin.provider import message as provider_message +import langbot_plugin.api.entities.builtin.platform.message as platform_message from ...utils import importutil - +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query from . import filters importutil.import_modules_in_pkg(filters) @@ -58,7 +57,7 @@ class ContentFilterStage(stage.PipelineStage): async def _pre_process( self, message: str, - query: core_entities.Query, + query: pipeline_query.Query, ) -> entities.StageProcessResult: """请求llm前处理消息 只要有一个不通过就不放行,只放行 PASS 的消息 @@ -86,14 +85,14 @@ class ContentFilterStage(stage.PipelineStage): elif result.level == filter_entities.ResultLevel.PASS: # 传到下一个 message = result.replacement - query.message_chain = platform_message.MessageChain(platform_message.Plain(message)) + query.message_chain = platform_message.MessageChain([platform_message.Plain(text=message)]) return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) async def _post_process( self, message: str, - query: core_entities.Query, + query: pipeline_query.Query, ) -> entities.StageProcessResult: """请求llm后处理响应 只要是 PASS 或者 MASKED 的就通过此 filter,将其 replacement 设置为message,进入下一个 filter @@ -123,7 +122,7 @@ class ContentFilterStage(stage.PipelineStage): return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) - async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: + async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult: """处理""" if stage_inst_name == 'PreContentFilterStage': contain_non_text = False @@ -142,7 +141,7 @@ class ContentFilterStage(stage.PipelineStage): return await self._pre_process(str(query.message_chain).strip(), query) elif stage_inst_name == 'PostContentFilterStage': # 仅处理 query.resp_messages[-1].content 是 str 的情况 - if isinstance(query.resp_messages[-1], llm_entities.Message) and isinstance( + if isinstance(query.resp_messages[-1], provider_message.Message) and isinstance( query.resp_messages[-1].content, str ): return await self._post_process(query.resp_messages[-1].content, query) diff --git a/pkg/pipeline/cntfilter/entities.py b/pkg/pipeline/cntfilter/entities.py index 5e804c0d..607eba9a 100644 --- a/pkg/pipeline/cntfilter/entities.py +++ b/pkg/pipeline/cntfilter/entities.py @@ -1,6 +1,6 @@ import enum -import pydantic.v1 as pydantic +import pydantic class ResultLevel(enum.Enum): diff --git a/pkg/pipeline/cntfilter/filter.py b/pkg/pipeline/cntfilter/filter.py index 36d8a7f4..58804fd8 100644 --- a/pkg/pipeline/cntfilter/filter.py +++ b/pkg/pipeline/cntfilter/filter.py @@ -3,9 +3,9 @@ from __future__ import annotations import abc import typing -from ...core import app, entities as core_entities +from ...core import app from . import entities - +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query preregistered_filters: list[typing.Type[ContentFilter]] = [] @@ -60,8 +60,8 @@ class ContentFilter(metaclass=abc.ABCMeta): pass @abc.abstractmethod - async def process(self, query: core_entities.Query, message: str = None, image_url=None) -> entities.FilterResult: - """Process message + async def process(self, query: pipeline_query.Query, message: str = None, image_url=None) -> entities.FilterResult: + """处理消息 It is divided into two stages, depending on the value of enable_stages. For content filters, you do not need to consider the stage of the message, you only need to check the message content. diff --git a/pkg/pipeline/cntfilter/filters/baiduexamine.py b/pkg/pipeline/cntfilter/filters/baiduexamine.py index 9637aec2..4213e662 100644 --- a/pkg/pipeline/cntfilter/filters/baiduexamine.py +++ b/pkg/pipeline/cntfilter/filters/baiduexamine.py @@ -4,8 +4,7 @@ import aiohttp from .. import entities from .. import filter as filter_model -from ....core import entities as core_entities - +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query BAIDU_EXAMINE_URL = 'https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token={}' BAIDU_EXAMINE_TOKEN_URL = 'https://aip.baidubce.com/oauth/2.0/token' @@ -27,7 +26,7 @@ class BaiduCloudExamine(filter_model.ContentFilter): ) as resp: return (await resp.json())['access_token'] - async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult: + async def process(self, query: pipeline_query.Query, message: str) -> entities.FilterResult: async with aiohttp.ClientSession() as session: async with session.post( BAIDU_EXAMINE_URL.format(await self._get_token()), diff --git a/pkg/pipeline/cntfilter/filters/banwords.py b/pkg/pipeline/cntfilter/filters/banwords.py index b03e79a9..05b25013 100644 --- a/pkg/pipeline/cntfilter/filters/banwords.py +++ b/pkg/pipeline/cntfilter/filters/banwords.py @@ -3,7 +3,7 @@ import re from .. import filter as filter_model from .. import entities -from ....core import entities as core_entities +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query @filter_model.filter_class('ban-word-filter') @@ -13,7 +13,7 @@ class BanWordFilter(filter_model.ContentFilter): async def initialize(self): pass - async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult: + async def process(self, query: pipeline_query.Query, message: str) -> entities.FilterResult: found = False for word in self.ap.sensitive_meta.data['words']: diff --git a/pkg/pipeline/cntfilter/filters/cntignore.py b/pkg/pipeline/cntfilter/filters/cntignore.py index b80d90eb..731ab392 100644 --- a/pkg/pipeline/cntfilter/filters/cntignore.py +++ b/pkg/pipeline/cntfilter/filters/cntignore.py @@ -3,7 +3,7 @@ import re from .. import entities from .. import filter as filter_model -from ....core import entities as core_entities +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query @filter_model.filter_class('content-ignore') @@ -16,7 +16,7 @@ class ContentIgnore(filter_model.ContentFilter): entities.EnableStage.PRE, ] - async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult: + async def process(self, query: pipeline_query.Query, message: str) -> entities.FilterResult: if 'prefix' in query.pipeline_config['trigger']['ignore-rules']: for rule in query.pipeline_config['trigger']['ignore-rules']['prefix']: if message.startswith(rule): diff --git a/pkg/pipeline/controller.py b/pkg/pipeline/controller.py index 6679bd88..b1dde4a6 100644 --- a/pkg/pipeline/controller.py +++ b/pkg/pipeline/controller.py @@ -3,7 +3,10 @@ from __future__ import annotations import asyncio import traceback -from ..core import app, entities +from ..core import app +from ..core import entities as core_entities + +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query class Controller: @@ -22,19 +25,19 @@ class Controller: """事件处理循环""" try: while True: - selected_query: entities.Query = None + selected_query: pipeline_query.Query = None # 取请求 async with self.ap.query_pool: - queries: list[entities.Query] = self.ap.query_pool.queries + queries: list[pipeline_query.Query] = self.ap.query_pool.queries for query in queries: session = await self.ap.sess_mgr.get_session(query) self.ap.logger.debug(f'Checking query {query} session {session}') - if not session.semaphore.locked(): + if not session._semaphore.locked(): selected_query = query - await session.semaphore.acquire() + await session._semaphore.acquire() break @@ -46,7 +49,7 @@ class Controller: if selected_query: - async def _process_query(selected_query: entities.Query): + async def _process_query(selected_query: pipeline_query.Query): async with self.semaphore: # 总并发上限 # find pipeline # Here firstly find the bot, then find the pipeline, in case the bot adapter's config is not the latest one. @@ -59,7 +62,7 @@ class Controller: await pipeline.run(selected_query) async with self.ap.query_pool: - (await self.ap.sess_mgr.get_session(selected_query)).semaphore.release() + (await self.ap.sess_mgr.get_session(selected_query))._semaphore.release() # 通知其他协程,有新的请求可以处理了 self.ap.query_pool.condition.notify_all() @@ -68,8 +71,8 @@ class Controller: kind='query', name=f'query-{selected_query.query_id}', scopes=[ - entities.LifecycleControlScope.APPLICATION, - entities.LifecycleControlScope.PLATFORM, + core_entities.LifecycleControlScope.APPLICATION, + core_entities.LifecycleControlScope.PLATFORM, ], ) diff --git a/pkg/pipeline/entities.py b/pkg/pipeline/entities.py index dd6434c0..5426685e 100644 --- a/pkg/pipeline/entities.py +++ b/pkg/pipeline/entities.py @@ -3,10 +3,10 @@ from __future__ import annotations import enum import typing -import pydantic.v1 as pydantic -from ..platform.types import message as platform_message +import pydantic -from ..core import entities +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query +import langbot_plugin.api.entities.builtin.platform.message as platform_message class ResultType(enum.Enum): @@ -20,7 +20,7 @@ class ResultType(enum.Enum): class StageProcessResult(pydantic.BaseModel): result_type: ResultType - new_query: entities.Query + new_query: pipeline_query.Query user_notice: typing.Optional[ typing.Union[ diff --git a/pkg/pipeline/longtext/longtext.py b/pkg/pipeline/longtext/longtext.py index 03457212..4b461bd6 100644 --- a/pkg/pipeline/longtext/longtext.py +++ b/pkg/pipeline/longtext/longtext.py @@ -5,10 +5,9 @@ import traceback from . import strategy from .. import stage, entities -from ...core import entities as core_entities -from ...platform.types import message as platform_message +import langbot_plugin.api.entities.builtin.platform.message as platform_message from ...utils import importutil - +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query from . import strategies importutil.import_modules_in_pkg(strategies) @@ -67,8 +66,8 @@ class LongTextProcessStage(stage.PipelineStage): await self.strategy_impl.initialize() - async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: - # Check if it contains non-Plain components + async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult: + # 检查是否包含非 Plain 组件 contains_non_plain = False for msg in query.resp_message_chain[-1]: diff --git a/pkg/pipeline/longtext/strategies/forward.py b/pkg/pipeline/longtext/strategies/forward.py index cb772339..201622ce 100644 --- a/pkg/pipeline/longtext/strategies/forward.py +++ b/pkg/pipeline/longtext/strategies/forward.py @@ -3,9 +3,9 @@ from __future__ import annotations from .. import strategy as strategy_model -from ....core import entities as core_entities -from ....platform.types import message as platform_message +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query +import langbot_plugin.api.entities.builtin.platform.message as platform_message ForwardMessageDiaplay = platform_message.ForwardMessageDiaplay Forward = platform_message.Forward @@ -13,7 +13,7 @@ Forward = platform_message.Forward @strategy_model.strategy_class('forward') class ForwardComponentStrategy(strategy_model.LongTextStrategy): - async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]: + async def process(self, message: str, query: pipeline_query.Query) -> list[platform_message.MessageComponent]: display = ForwardMessageDiaplay( title='Group chat history', brief='[Chat history]', diff --git a/pkg/pipeline/longtext/strategies/image.py b/pkg/pipeline/longtext/strategies/image.py index f96f7265..110f1f81 100644 --- a/pkg/pipeline/longtext/strategies/image.py +++ b/pkg/pipeline/longtext/strategies/image.py @@ -8,10 +8,10 @@ import re from PIL import Image, ImageDraw, ImageFont import functools -from ....platform.types import message as platform_message from .. import strategy as strategy_model -from ....core import entities as core_entities +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query +import langbot_plugin.api.entities.builtin.platform.message as platform_message @strategy_model.strategy_class('image') @@ -27,7 +27,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy): encoding='utf-8', ) - async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]: + async def process(self, message: str, query: pipeline_query.Query) -> list[platform_message.MessageComponent]: img_path = self.text_to_image( text_str=message, save_as='temp/{}.png'.format(int(time.time())), @@ -131,7 +131,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy): text_str: str, save_as='temp.png', width=800, - query: core_entities.Query = None, + query: pipeline_query.Query = None, ): text_str = text_str.replace('\t', ' ') diff --git a/pkg/pipeline/longtext/strategy.py b/pkg/pipeline/longtext/strategy.py index 5b521067..cb8ce7e1 100644 --- a/pkg/pipeline/longtext/strategy.py +++ b/pkg/pipeline/longtext/strategy.py @@ -4,8 +4,9 @@ import typing from ...core import app -from ...core import entities as core_entities -from ...platform.types import message as platform_message + +import langbot_plugin.api.entities.builtin.platform.message as platform_message +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query preregistered_strategies: list[typing.Type[LongTextStrategy]] = [] @@ -49,8 +50,8 @@ class LongTextStrategy(metaclass=abc.ABCMeta): pass @abc.abstractmethod - async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]: - """Process long text + async def process(self, message: str, query: pipeline_query.Query) -> list[platform_message.MessageComponent]: + """处理长文本 If the text length exceeds the threshold, this method will be called. diff --git a/pkg/pipeline/msgtrun/msgtrun.py b/pkg/pipeline/msgtrun/msgtrun.py index 1c5ee17d..00a9bfbf 100644 --- a/pkg/pipeline/msgtrun/msgtrun.py +++ b/pkg/pipeline/msgtrun/msgtrun.py @@ -1,10 +1,9 @@ from __future__ import annotations from .. import stage, entities -from ...core import entities as core_entities from . import truncator from ...utils import importutil - +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query from . import truncators importutil.import_modules_in_pkg(truncators) @@ -29,8 +28,8 @@ class ConversationMessageTruncator(stage.PipelineStage): else: raise ValueError(f'Unknown truncator: {use_method}') - async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: - """Process""" + async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult: + """处理""" query = await self.trun.truncate(query) return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) diff --git a/pkg/pipeline/msgtrun/truncator.py b/pkg/pipeline/msgtrun/truncator.py index 9e8b8a6c..180982d3 100644 --- a/pkg/pipeline/msgtrun/truncator.py +++ b/pkg/pipeline/msgtrun/truncator.py @@ -3,8 +3,8 @@ from __future__ import annotations import typing import abc -from ...core import entities as core_entities, app - +from ...core import app +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query preregistered_truncators: list[typing.Type[Truncator]] = [] @@ -47,7 +47,7 @@ class Truncator(abc.ABC): pass @abc.abstractmethod - async def truncate(self, query: core_entities.Query) -> core_entities.Query: + async def truncate(self, query: pipeline_query.Query) -> pipeline_query.Query: """截断 一般只需要操作query.messages,也可以扩展操作query.prompt, query.user_message。 diff --git a/pkg/pipeline/msgtrun/truncators/round.py b/pkg/pipeline/msgtrun/truncators/round.py index 2acb1d8c..400706b6 100644 --- a/pkg/pipeline/msgtrun/truncators/round.py +++ b/pkg/pipeline/msgtrun/truncators/round.py @@ -1,15 +1,15 @@ from __future__ import annotations from .. import truncator -from ....core import entities as core_entities +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query @truncator.truncator_class('round') class RoundTruncator(truncator.Truncator): """Truncate the conversation message chain to adapt to the LLM message length limit.""" - async def truncate(self, query: core_entities.Query) -> core_entities.Query: - """Truncate""" + async def truncate(self, query: pipeline_query.Query) -> pipeline_query.Query: + """截断""" max_round = query.pipeline_config['ai']['local-agent']['max-round'] temp_messages = [] diff --git a/pkg/pipeline/pipelinemgr.py b/pkg/pipeline/pipelinemgr.py index abf80e16..3e126314 100644 --- a/pkg/pipeline/pipelinemgr.py +++ b/pkg/pipeline/pipelinemgr.py @@ -5,14 +5,18 @@ import traceback import sqlalchemy -from ..core import app, entities +from ..core import app from . import entities as pipeline_entities from ..entity.persistence import pipeline as persistence_pipeline from . import stage -from ..platform.types import message as platform_message, events as platform_events -from ..plugin import events +import langbot_plugin.api.entities.builtin.platform.message as platform_message +import langbot_plugin.api.entities.builtin.platform.events as platform_events +import langbot_plugin.api.entities.events as events from ..utils import importutil +import langbot_plugin.api.entities.builtin.provider.session as provider_session +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query + from . import ( resprule, bansess, @@ -75,17 +79,17 @@ class RuntimePipeline: self.pipeline_entity = pipeline_entity self.stage_containers = stage_containers - async def run(self, query: entities.Query): + async def run(self, query: pipeline_query.Query): query.pipeline_config = self.pipeline_entity.config await self.process_query(query) - async def _check_output(self, query: entities.Query, result: pipeline_entities.StageProcessResult): + async def _check_output(self, query: pipeline_query.Query, result: pipeline_entities.StageProcessResult): """检查输出""" if result.user_notice: # 处理str类型 if isinstance(result.user_notice, str): - result.user_notice = platform_message.MessageChain(platform_message.Plain(result.user_notice)) + result.user_notice = platform_message.MessageChain([platform_message.Plain(text=result.user_notice)]) elif isinstance(result.user_notice, list): result.user_notice = platform_message.MessageChain(*result.user_notice) @@ -99,7 +103,7 @@ class RuntimePipeline: bot_message=query.resp_messages[-1], message=result.user_notice, quote_origin=query.pipeline_config['output']['misc']['quote-origin'], - is_final=[msg.is_final for msg in query.resp_messages][0] + is_final=[msg.is_final for msg in query.resp_messages][0], ) else: await query.adapter.reply_message( @@ -117,7 +121,7 @@ class RuntimePipeline: async def _execute_from_stage( self, stage_index: int, - query: entities.Query, + query: pipeline_query.Query, ): """从指定阶段开始执行,实现了责任链模式和基于生成器的阶段分叉功能。 @@ -144,7 +148,7 @@ class RuntimePipeline: while i < len(self.stage_containers): stage_container = self.stage_containers[i] - query.current_stage = stage_container # 标记到 Query 对象里 + query.current_stage_name = stage_container.inst_name # 标记到 Query 对象里 result = stage_container.inst.process(query, stage_container.inst_name) @@ -181,26 +185,26 @@ class RuntimePipeline: i += 1 - async def process_query(self, query: entities.Query): + async def process_query(self, query: pipeline_query.Query): """处理请求""" try: # ======== 触发 MessageReceived 事件 ======== event_type = ( events.PersonMessageReceived - if query.launcher_type == entities.LauncherTypes.PERSON + if query.launcher_type == provider_session.LauncherTypes.PERSON else events.GroupMessageReceived ) - event_ctx = await self.ap.plugin_mgr.emit_event( - event=event_type( - launcher_type=query.launcher_type.value, - launcher_id=query.launcher_id, - sender_id=query.sender_id, - message_chain=query.message_chain, - query=query, - ) + event_obj = event_type( + query=query, + launcher_type=query.launcher_type.value, + launcher_id=query.launcher_id, + sender_id=query.sender_id, + message_chain=query.message_chain, ) + event_ctx = await self.ap.plugin_connector.emit_event(event_obj) + if event_ctx.is_prevented_default(): return @@ -208,11 +212,12 @@ class RuntimePipeline: await self._execute_from_stage(0, query) except Exception as e: - inst_name = query.current_stage.inst_name if query.current_stage else 'unknown' + inst_name = query.current_stage_name if query.current_stage_name else 'unknown' self.ap.logger.error(f'处理请求时出错 query_id={query.query_id} stage={inst_name} : {e}') self.ap.logger.error(f'Traceback: {traceback.format_exc()}') finally: self.ap.logger.debug(f'Query {query.query_id} processed') + del self.ap.query_pool.cached_queries[query.query_id] class PipelineManager: diff --git a/pkg/pipeline/pool.py b/pkg/pipeline/pool.py index 6975e53c..eb7df66b 100644 --- a/pkg/pipeline/pool.py +++ b/pkg/pipeline/pool.py @@ -3,10 +3,11 @@ from __future__ import annotations import asyncio import typing -from ..core import entities -from ..platform import adapter as msadapter -from ..platform.types import message as platform_message -from ..platform.types import events as platform_events +import langbot_plugin.api.entities.builtin.platform.message as platform_message +import langbot_plugin.api.entities.builtin.platform.events as platform_events +import langbot_plugin.api.entities.builtin.provider.session as provider_session +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query +import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter class QueryPool: @@ -16,7 +17,10 @@ class QueryPool: pool_lock: asyncio.Lock - queries: list[entities.Query] + queries: list[pipeline_query.Query] + + cached_queries: dict[int, pipeline_query.Query] + """Cached queries, used for plugin backward api call, will be removed after the query completely processed""" condition: asyncio.Condition @@ -24,34 +28,38 @@ class QueryPool: self.query_id_counter = 0 self.pool_lock = asyncio.Lock() self.queries = [] + self.cached_queries = {} self.condition = asyncio.Condition(self.pool_lock) async def add_query( self, bot_uuid: str, - launcher_type: entities.LauncherTypes, + launcher_type: provider_session.LauncherTypes, launcher_id: typing.Union[int, str], sender_id: typing.Union[int, str], message_event: platform_events.MessageEvent, message_chain: platform_message.MessageChain, - adapter: msadapter.MessagePlatformAdapter, + adapter: abstract_platform_adapter.AbstractMessagePlatformAdapter, pipeline_uuid: typing.Optional[str] = None, - ) -> entities.Query: + ) -> pipeline_query.Query: async with self.condition: - query = entities.Query( + query_id = self.query_id_counter + query = pipeline_query.Query( bot_uuid=bot_uuid, - query_id=self.query_id_counter, + query_id=query_id, launcher_type=launcher_type, launcher_id=launcher_id, sender_id=sender_id, message_event=message_event, message_chain=message_chain, + variables={}, resp_messages=[], resp_message_chain=[], adapter=adapter, pipeline_uuid=pipeline_uuid, ) self.queries.append(query) + self.cached_queries[query_id] = query self.query_id_counter += 1 self.condition.notify_all() diff --git a/pkg/pipeline/preproc/preproc.py b/pkg/pipeline/preproc/preproc.py index 1aada6b3..bd150998 100644 --- a/pkg/pipeline/preproc/preproc.py +++ b/pkg/pipeline/preproc/preproc.py @@ -3,10 +3,10 @@ from __future__ import annotations import datetime from .. import stage, entities -from ...core import entities as core_entities -from ...provider import entities as llm_entities -from ...plugin import events -from ...platform.types import message as platform_message +from langbot_plugin.api.entities.builtin.provider import message as provider_message +import langbot_plugin.api.entities.events as events +import langbot_plugin.api.entities.builtin.platform.message as platform_message +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query @stage.stage_class('PreProcessor') @@ -26,7 +26,7 @@ class PreProcessor(stage.PipelineStage): async def process( self, - query: core_entities.Query, + query: pipeline_query.Query, stage_inst_name: str, ) -> entities.StageProcessResult: """Process""" @@ -49,80 +49,73 @@ class PreProcessor(stage.PipelineStage): query.bot_uuid, ) - conversation.use_llm_model = llm_model - - # Set query + # 设置query query.session = session query.prompt = conversation.prompt.copy() query.messages = conversation.messages.copy() - query.use_llm_model = llm_model + query.use_llm_model_uuid = llm_model.model_entity.uuid if selected_runner == 'local-agent': - query.use_funcs = ( - conversation.use_funcs if query.use_llm_model.model_entity.abilities.__contains__('func_call') else None - ) + query.use_funcs = [] - query.variables = { + if llm_model.model_entity.abilities.__contains__('func_call'): + query.use_funcs = await self.ap.tool_mgr.get_all_tools() + + variables = { 'session_id': f'{query.session.launcher_type.value}_{query.session.launcher_id}', 'conversation_id': conversation.uuid, 'msg_create_time': ( int(query.message_event.time) if query.message_event.time else int(datetime.datetime.now().timestamp()) ), } + query.variables.update(variables) # Check if this model supports vision, if not, remove all images # TODO this checking should be performed in runner, and in this stage, the image should be reserved - if selected_runner == 'local-agent' and not query.use_llm_model.model_entity.abilities.__contains__('vision'): + if selected_runner == 'local-agent' and not llm_model.model_entity.abilities.__contains__('vision'): for msg in query.messages: if isinstance(msg.content, list): for me in msg.content: if me.type == 'image_url': msg.content.remove(me) - content_list: list[llm_entities.ContentElement] = [] + content_list: list[provider_message.ContentElement] = [] plain_text = '' qoute_msg = query.pipeline_config['trigger'].get('misc', '').get('combine-quote-message') - # tidy the content_list - # combine all text content into one, and put it in the first position for me in query.message_chain: if isinstance(me, platform_message.Plain): + content_list.append(provider_message.ContentElement.from_text(me.text)) plain_text += me.text elif isinstance(me, platform_message.Image): - if selected_runner != 'local-agent' or query.use_llm_model.model_entity.abilities.__contains__( - 'vision' - ): + if selected_runner != 'local-agent' or llm_model.model_entity.abilities.__contains__('vision'): if me.base64 is not None: - content_list.append(llm_entities.ContentElement.from_image_base64(me.base64)) + content_list.append(provider_message.ContentElement.from_image_base64(me.base64)) elif isinstance(me, platform_message.Quote) and qoute_msg: for msg in me.origin: if isinstance(msg, platform_message.Plain): - content_list.append(llm_entities.ContentElement.from_text(msg.text)) + content_list.append(provider_message.ContentElement.from_text(msg.text)) elif isinstance(msg, platform_message.Image): - if selected_runner != 'local-agent' or query.use_llm_model.model_entity.abilities.__contains__( - 'vision' - ): + if selected_runner != 'local-agent' or llm_model.model_entity.abilities.__contains__('vision'): if msg.base64 is not None: - content_list.append(llm_entities.ContentElement.from_image_base64(msg.base64)) - - content_list.insert(0, llm_entities.ContentElement.from_text(plain_text)) + content_list.append(provider_message.ContentElement.from_image_base64(msg.base64)) query.variables['user_message_text'] = plain_text - query.user_message = llm_entities.Message(role='user', content=content_list) - # =========== Trigger event PromptPreProcessing + query.user_message = provider_message.Message(role='user', content=content_list) + # =========== 触发事件 PromptPreProcessing - event_ctx = await self.ap.plugin_mgr.emit_event( - event=events.PromptPreProcessing( - session_name=f'{query.session.launcher_type.value}_{query.session.launcher_id}', - default_prompt=query.prompt.messages, - prompt=query.messages, - query=query, - ) + event = events.PromptPreProcessing( + session_name=f'{query.session.launcher_type.value}_{query.session.launcher_id}', + default_prompt=query.prompt.messages, + prompt=query.messages, + query=query, ) + event_ctx = await self.ap.plugin_connector.emit_event(event) + query.prompt.messages = event_ctx.event.default_prompt query.messages = event_ctx.event.prompt diff --git a/pkg/pipeline/process/handler.py b/pkg/pipeline/process/handler.py index 837b72e2..b70a8e04 100644 --- a/pkg/pipeline/process/handler.py +++ b/pkg/pipeline/process/handler.py @@ -3,8 +3,8 @@ from __future__ import annotations import abc from ...core import app -from ...core import entities as core_entities from .. import entities +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query class MessageHandler(metaclass=abc.ABCMeta): @@ -19,7 +19,7 @@ class MessageHandler(metaclass=abc.ABCMeta): @abc.abstractmethod async def handle( self, - query: core_entities.Query, + query: pipeline_query.Query, ) -> entities.StageProcessResult: raise NotImplementedError diff --git a/pkg/pipeline/process/handlers/chat.py b/pkg/pipeline/process/handlers/chat.py index fee54427..b6da5fa6 100644 --- a/pkg/pipeline/process/handlers/chat.py +++ b/pkg/pipeline/process/handlers/chat.py @@ -7,13 +7,15 @@ import traceback from .. import handler from ... import entities -from ....core import entities as core_entities from ....provider import runner as runner_module -from ....plugin import events -from ....platform.types import message as platform_message +import langbot_plugin.api.entities.builtin.platform.message as platform_message +import langbot_plugin.api.entities.events as events from ....utils import importutil from ....provider import runners +import langbot_plugin.api.entities.builtin.provider.session as provider_session +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query + importutil.import_modules_in_pkg(runners) @@ -21,7 +23,7 @@ importutil.import_modules_in_pkg(runners) class ChatMessageHandler(handler.MessageHandler): async def handle( self, - query: core_entities.Query, + query: pipeline_query.Query, ) -> typing.AsyncGenerator[entities.StageProcessResult, None]: """处理""" # 调API @@ -30,19 +32,20 @@ class ChatMessageHandler(handler.MessageHandler): # 触发插件事件 event_class = ( events.PersonNormalMessageReceived - if query.launcher_type == core_entities.LauncherTypes.PERSON + if query.launcher_type == provider_session.LauncherTypes.PERSON else events.GroupNormalMessageReceived ) - event_ctx = await self.ap.plugin_mgr.emit_event( - event=event_class( - launcher_type=query.launcher_type.value, - launcher_id=query.launcher_id, - sender_id=query.sender_id, - text_message=str(query.message_chain), - query=query, - ) + event = event_class( + launcher_type=query.launcher_type.value, + launcher_id=query.launcher_id, + sender_id=query.sender_id, + text_message=str(query.message_chain), + query=query, ) + + event_ctx = await self.ap.plugin_connector.emit_event(event) + is_create_card = False # 判断下是否需要创建流式卡片 if event_ctx.is_prevented_default(): if event_ctx.event.reply is not None: @@ -120,4 +123,4 @@ class ChatMessageHandler(handler.MessageHandler): ) finally: # TODO statistics - pass \ No newline at end of file + pass diff --git a/pkg/pipeline/process/handlers/command.py b/pkg/pipeline/process/handlers/command.py index 7348d6b8..92cebe02 100644 --- a/pkg/pipeline/process/handlers/command.py +++ b/pkg/pipeline/process/handlers/command.py @@ -4,16 +4,17 @@ import typing from .. import handler from ... import entities -from ....core import entities as core_entities -from ....provider import entities as llm_entities -from ....plugin import events -from ....platform.types import message as platform_message +import langbot_plugin.api.entities.builtin.provider.message as provider_message +import langbot_plugin.api.entities.builtin.platform.message as platform_message +import langbot_plugin.api.entities.builtin.provider.session as provider_session +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query +import langbot_plugin.api.entities.events as events class CommandHandler(handler.MessageHandler): async def handle( self, - query: core_entities.Query, + query: pipeline_query.Query, ) -> typing.AsyncGenerator[entities.StageProcessResult, None]: """Process""" @@ -28,23 +29,23 @@ class CommandHandler(handler.MessageHandler): event_class = ( events.PersonCommandSent - if query.launcher_type == core_entities.LauncherTypes.PERSON + if query.launcher_type == provider_session.LauncherTypes.PERSON else events.GroupCommandSent ) - event_ctx = await self.ap.plugin_mgr.emit_event( - event=event_class( - launcher_type=query.launcher_type.value, - launcher_id=query.launcher_id, - sender_id=query.sender_id, - command=spt[0], - params=spt[1:] if len(spt) > 1 else [], - text_message=str(query.message_chain), - is_admin=(privilege == 2), - query=query, - ) + event = event_class( + launcher_type=query.launcher_type.value, + launcher_id=query.launcher_id, + sender_id=query.sender_id, + command=spt[0], + params=spt[1:] if len(spt) > 1 else [], + text_message=str(query.message_chain), + is_admin=(privilege == 2), + query=query, ) + event_ctx = await self.ap.plugin_connector.emit_event(event) + if event_ctx.is_prevented_default(): if event_ctx.event.reply is not None: mc = platform_message.MessageChain(event_ctx.event.reply) @@ -64,7 +65,7 @@ class CommandHandler(handler.MessageHandler): async for ret in self.ap.cmd_mgr.execute(command_text=command_text, query=query, session=session): if ret.error is not None: query.resp_messages.append( - llm_entities.Message( + provider_message.Message( role='command', content=str(ret.error), ) @@ -73,17 +74,20 @@ class CommandHandler(handler.MessageHandler): self.ap.logger.info(f'Command({query.query_id}) error: {self.cut_str(str(ret.error))}') yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) - elif ret.text is not None or ret.image_url is not None: - content: list[llm_entities.ContentElement] = [] + elif ret.text is not None or ret.image_url is not None or ret.image_base64 is not None: + content: list[provider_message.ContentElement] = [] if ret.text is not None: - content.append(llm_entities.ContentElement.from_text(ret.text)) + content.append(provider_message.ContentElement.from_text(ret.text)) if ret.image_url is not None: - content.append(llm_entities.ContentElement.from_image_url(ret.image_url)) + content.append(provider_message.ContentElement.from_image_url(ret.image_url)) + + if ret.image_base64 is not None: + content.append(provider_message.ContentElement.from_image_base64(ret.image_base64)) query.resp_messages.append( - llm_entities.Message( + provider_message.Message( role='command', content=content, ) diff --git a/pkg/pipeline/process/process.py b/pkg/pipeline/process/process.py index a08b8c08..27632f28 100644 --- a/pkg/pipeline/process/process.py +++ b/pkg/pipeline/process/process.py @@ -1,10 +1,10 @@ from __future__ import annotations -from ...core import entities as core_entities from . import handler from .handlers import chat, command from .. import entities from .. import stage +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query @stage.stage_class('MessageProcessor') @@ -30,7 +30,7 @@ class Processor(stage.PipelineStage): async def process( self, - query: core_entities.Query, + query: pipeline_query.Query, stage_inst_name: str, ) -> entities.StageProcessResult: """Process""" diff --git a/pkg/pipeline/ratelimit/algo.py b/pkg/pipeline/ratelimit/algo.py index 3bcc347a..efbc326b 100644 --- a/pkg/pipeline/ratelimit/algo.py +++ b/pkg/pipeline/ratelimit/algo.py @@ -2,7 +2,8 @@ from __future__ import annotations import abc import typing -from ...core import app, entities as core_entities +from ...core import app +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query preregistered_algos: list[typing.Type[ReteLimitAlgo]] = [] @@ -33,7 +34,7 @@ class ReteLimitAlgo(metaclass=abc.ABCMeta): @abc.abstractmethod async def require_access( self, - query: core_entities.Query, + query: pipeline_query.Query, launcher_type: str, launcher_id: typing.Union[int, str], ) -> bool: @@ -53,7 +54,7 @@ class ReteLimitAlgo(metaclass=abc.ABCMeta): @abc.abstractmethod async def release_access( self, - query: core_entities.Query, + query: pipeline_query.Query, launcher_type: str, launcher_id: typing.Union[int, str], ): diff --git a/pkg/pipeline/ratelimit/algos/fixedwin.py b/pkg/pipeline/ratelimit/algos/fixedwin.py index cc816f73..6a2a8e97 100644 --- a/pkg/pipeline/ratelimit/algos/fixedwin.py +++ b/pkg/pipeline/ratelimit/algos/fixedwin.py @@ -3,7 +3,7 @@ import asyncio import time import typing from .. import algo -from ....core import entities as core_entities +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query # 固定窗口算法 @@ -32,7 +32,7 @@ class FixedWindowAlgo(algo.ReteLimitAlgo): async def require_access( self, - query: core_entities.Query, + query: pipeline_query.Query, launcher_type: str, launcher_id: typing.Union[int, str], ) -> bool: @@ -91,7 +91,7 @@ class FixedWindowAlgo(algo.ReteLimitAlgo): async def release_access( self, - query: core_entities.Query, + query: pipeline_query.Query, launcher_type: str, launcher_id: typing.Union[int, str], ): diff --git a/pkg/pipeline/ratelimit/ratelimit.py b/pkg/pipeline/ratelimit/ratelimit.py index 23de4ec6..cab62b8d 100644 --- a/pkg/pipeline/ratelimit/ratelimit.py +++ b/pkg/pipeline/ratelimit/ratelimit.py @@ -4,9 +4,10 @@ import typing from .. import entities, stage from . import algo -from ...core import entities as core_entities from ...utils import importutil +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query + from . import algos importutil.import_modules_in_pkg(algos) @@ -39,7 +40,7 @@ class RateLimit(stage.PipelineStage): async def process( self, - query: core_entities.Query, + query: pipeline_query.Query, stage_inst_name: str, ) -> typing.Union[ entities.StageProcessResult, diff --git a/pkg/pipeline/respback/respback.py b/pkg/pipeline/respback/respback.py index ece4e392..331a36aa 100644 --- a/pkg/pipeline/respback/respback.py +++ b/pkg/pipeline/respback/respback.py @@ -4,22 +4,19 @@ import random import asyncio -from ...platform.types import events as platform_events -from ...platform.types import message as platform_message - -from ...provider import entities as llm_entities - - +import langbot_plugin.api.entities.builtin.platform.events as platform_events +import langbot_plugin.api.entities.builtin.platform.message as platform_message +import langbot_plugin.api.entities.builtin.provider.message as provider_message from .. import stage, entities -from ...core import entities as core_entities +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query @stage.stage_class('SendResponseBackStage') class SendResponseBackStage(stage.PipelineStage): """发送响应消息""" - async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: + async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult: """处理""" random_range = ( @@ -40,7 +37,7 @@ class SendResponseBackStage(stage.PipelineStage): quote_origin = query.pipeline_config['output']['misc']['quote-origin'] - has_chunks = any(isinstance(msg, llm_entities.MessageChunk) for msg in query.resp_messages) + has_chunks = any(isinstance(msg, provider_message.MessageChunk) for msg in query.resp_messages) # TODO 命令与流式的兼容性问题 if await query.adapter.is_stream_output_supported() and has_chunks: is_final = [msg.is_final for msg in query.resp_messages][0] diff --git a/pkg/pipeline/resprule/entities.py b/pkg/pipeline/resprule/entities.py index a0ba7807..71973c8a 100644 --- a/pkg/pipeline/resprule/entities.py +++ b/pkg/pipeline/resprule/entities.py @@ -1,6 +1,6 @@ -import pydantic.v1 as pydantic +import pydantic -from ...platform.types import message as platform_message +import langbot_plugin.api.entities.builtin.platform.message as platform_message class RuleJudgeResult(pydantic.BaseModel): diff --git a/pkg/pipeline/resprule/resprule.py b/pkg/pipeline/resprule/resprule.py index 0193f2ce..1a3560ff 100644 --- a/pkg/pipeline/resprule/resprule.py +++ b/pkg/pipeline/resprule/resprule.py @@ -4,9 +4,10 @@ from __future__ import annotations from . import rule from .. import stage, entities -from ...core import entities as core_entities from ...utils import importutil +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query + from . import rules importutil.import_modules_in_pkg(rules) @@ -32,7 +33,7 @@ class GroupRespondRuleCheckStage(stage.PipelineStage): await rule_inst.initialize() self.rule_matchers.append(rule_inst) - async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: + async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult: if query.launcher_type.value != 'group': # 只处理群消息 return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) diff --git a/pkg/pipeline/resprule/rule.py b/pkg/pipeline/resprule/rule.py index 3fdb0386..34e89a72 100644 --- a/pkg/pipeline/resprule/rule.py +++ b/pkg/pipeline/resprule/rule.py @@ -2,10 +2,11 @@ from __future__ import annotations import abc import typing -from ...core import app, entities as core_entities +from ...core import app from . import entities -from ...platform.types import message as platform_message +import langbot_plugin.api.entities.builtin.platform.message as platform_message +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query preregisetered_rules: list[typing.Type[GroupRespondRule]] = [] @@ -39,7 +40,7 @@ class GroupRespondRule(metaclass=abc.ABCMeta): message_text: str, message_chain: platform_message.MessageChain, rule_dict: dict, - query: core_entities.Query, + query: pipeline_query.Query, ) -> entities.RuleJudgeResult: """判断消息是否匹配规则""" raise NotImplementedError diff --git a/pkg/pipeline/resprule/rules/atbot.py b/pkg/pipeline/resprule/rules/atbot.py index 340b92c7..b35fb5e4 100644 --- a/pkg/pipeline/resprule/rules/atbot.py +++ b/pkg/pipeline/resprule/rules/atbot.py @@ -3,8 +3,8 @@ from __future__ import annotations from .. import rule as rule_model from .. import entities -from ....core import entities as core_entities -from ....platform.types import message as platform_message +import langbot_plugin.api.entities.builtin.platform.message as platform_message +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query @rule_model.rule_class('at-bot') @@ -14,19 +14,28 @@ class AtBotRule(rule_model.GroupRespondRule): message_text: str, message_chain: platform_message.MessageChain, rule_dict: dict, - query: core_entities.Query, + query: pipeline_query.Query, ) -> entities.RuleJudgeResult: - if message_chain.has(platform_message.At(query.adapter.bot_account_id)) and rule_dict['at']: - message_chain.remove(platform_message.At(query.adapter.bot_account_id)) + def remove_at(message_chain: platform_message.MessageChain): + for component in message_chain.root: + if isinstance(component, platform_message.At) and component.target == query.adapter.bot_account_id: + message_chain.remove(component) + break - if message_chain.has( - platform_message.At(query.adapter.bot_account_id) - ): # 回复消息时会at两次,检查并删除重复的 - message_chain.remove(platform_message.At(query.adapter.bot_account_id)) + remove_at(message_chain) + remove_at(message_chain) # 回复消息时会at两次,检查并删除重复的 - return entities.RuleJudgeResult( - matching=True, - replacement=message_chain, - ) + # if message_chain.has(platform_message.At(query.adapter.bot_account_id)) and rule_dict['at']: + # message_chain.remove(platform_message.At(query.adapter.bot_account_id)) + + # if message_chain.has( + # platform_message.At(query.adapter.bot_account_id) + # ): # 回复消息时会at两次,检查并删除重复的 + # message_chain.remove(platform_message.At(query.adapter.bot_account_id)) + + # return entities.RuleJudgeResult( + # matching=True, + # replacement=message_chain, + # ) return entities.RuleJudgeResult(matching=False, replacement=message_chain) diff --git a/pkg/pipeline/resprule/rules/prefix.py b/pkg/pipeline/resprule/rules/prefix.py index c712d3e8..72f0de77 100644 --- a/pkg/pipeline/resprule/rules/prefix.py +++ b/pkg/pipeline/resprule/rules/prefix.py @@ -1,7 +1,7 @@ from .. import rule as rule_model from .. import entities -from ....core import entities as core_entities -from ....platform.types import message as platform_message +import langbot_plugin.api.entities.builtin.platform.message as platform_message +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query @rule_model.rule_class('prefix') @@ -11,7 +11,7 @@ class PrefixRule(rule_model.GroupRespondRule): message_text: str, message_chain: platform_message.MessageChain, rule_dict: dict, - query: core_entities.Query, + query: pipeline_query.Query, ) -> entities.RuleJudgeResult: prefixes = rule_dict['prefix'] diff --git a/pkg/pipeline/resprule/rules/random.py b/pkg/pipeline/resprule/rules/random.py index d2f782ab..2bfe8b71 100644 --- a/pkg/pipeline/resprule/rules/random.py +++ b/pkg/pipeline/resprule/rules/random.py @@ -3,8 +3,8 @@ import random from .. import rule as rule_model from .. import entities -from ....core import entities as core_entities -from ....platform.types import message as platform_message +import langbot_plugin.api.entities.builtin.platform.message as platform_message +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query @rule_model.rule_class('random') @@ -14,7 +14,7 @@ class RandomRespRule(rule_model.GroupRespondRule): message_text: str, message_chain: platform_message.MessageChain, rule_dict: dict, - query: core_entities.Query, + query: pipeline_query.Query, ) -> entities.RuleJudgeResult: random_rate = rule_dict['random'] diff --git a/pkg/pipeline/resprule/rules/regexp.py b/pkg/pipeline/resprule/rules/regexp.py index daac0869..41e1df8e 100644 --- a/pkg/pipeline/resprule/rules/regexp.py +++ b/pkg/pipeline/resprule/rules/regexp.py @@ -3,8 +3,8 @@ import re from .. import rule as rule_model from .. import entities -from ....core import entities as core_entities -from ....platform.types import message as platform_message +import langbot_plugin.api.entities.builtin.platform.message as platform_message +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query @rule_model.rule_class('regexp') @@ -14,7 +14,7 @@ class RegExpRule(rule_model.GroupRespondRule): message_text: str, message_chain: platform_message.MessageChain, rule_dict: dict, - query: core_entities.Query, + query: pipeline_query.Query, ) -> entities.RuleJudgeResult: regexps = rule_dict['regexp'] diff --git a/pkg/pipeline/stage.py b/pkg/pipeline/stage.py index 18a94b73..0ff1af7e 100644 --- a/pkg/pipeline/stage.py +++ b/pkg/pipeline/stage.py @@ -3,8 +3,9 @@ from __future__ import annotations import abc import typing -from ..core import app, entities as core_entities +from ..core import app from . import entities +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query preregistered_stages: dict[str, type[PipelineStage]] = {} @@ -33,7 +34,7 @@ class PipelineStage(metaclass=abc.ABCMeta): @abc.abstractmethod async def process( self, - query: core_entities.Query, + query: pipeline_query.Query, stage_inst_name: str, ) -> typing.Union[ entities.StageProcessResult, diff --git a/pkg/pipeline/wrapper/wrapper.py b/pkg/pipeline/wrapper/wrapper.py index 3299a226..439595e9 100644 --- a/pkg/pipeline/wrapper/wrapper.py +++ b/pkg/pipeline/wrapper/wrapper.py @@ -2,12 +2,12 @@ from __future__ import annotations import typing - -from ...core import entities as core_entities from .. import entities from .. import stage -from ...plugin import events -from ...platform.types import message as platform_message + +import langbot_plugin.api.entities.builtin.platform.message as platform_message +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query +import langbot_plugin.api.entities.events as events @stage.stage_class('ResponseWrapper') @@ -25,7 +25,7 @@ class ResponseWrapper(stage.PipelineStage): async def process( self, - query: core_entities.Query, + query: pipeline_query.Query, stage_inst_name: str, ) -> typing.AsyncGenerator[entities.StageProcessResult, None]: """处理""" @@ -58,21 +58,22 @@ class ResponseWrapper(stage.PipelineStage): reply_text = str(result.get_content_platform_message_chain()) # ============= 触发插件事件 =============== - event_ctx = await self.ap.plugin_mgr.emit_event( - event=events.NormalMessageResponded( - launcher_type=query.launcher_type.value, - launcher_id=query.launcher_id, - sender_id=query.sender_id, - session=session, - prefix='', - response_text=reply_text, - finish_reason='stop', - funcs_called=[fc.function.name for fc in result.tool_calls] - if result.tool_calls is not None - else [], - query=query, - ) + event = events.NormalMessageResponded( + launcher_type=query.launcher_type.value, + launcher_id=query.launcher_id, + sender_id=query.sender_id, + session=session, + prefix='', + response_text=reply_text, + finish_reason='stop', + funcs_called=[fc.function.name for fc in result.tool_calls] + if result.tool_calls is not None + else [], + query=query, ) + + event_ctx = await self.ap.plugin_connector.emit_event(event) + if event_ctx.is_prevented_default(): yield entities.StageProcessResult( result_type=entities.ResultType.INTERRUPT, @@ -96,26 +97,26 @@ class ResponseWrapper(stage.PipelineStage): reply_text = f'调用函数 {".".join(function_names)}...' query.resp_message_chain.append( - platform_message.MessageChain([platform_message.Plain(reply_text)]) + platform_message.MessageChain([platform_message.Plain(text=reply_text)]) ) if query.pipeline_config['output']['misc']['track-function-calls']: - event_ctx = await self.ap.plugin_mgr.emit_event( - event=events.NormalMessageResponded( - launcher_type=query.launcher_type.value, - launcher_id=query.launcher_id, - sender_id=query.sender_id, - session=session, - prefix='', - response_text=reply_text, - finish_reason='stop', - funcs_called=[fc.function.name for fc in result.tool_calls] - if result.tool_calls is not None - else [], - query=query, - ) + event = events.NormalMessageResponded( + launcher_type=query.launcher_type.value, + launcher_id=query.launcher_id, + sender_id=query.sender_id, + session=session, + prefix='', + response_text=reply_text, + finish_reason='stop', + funcs_called=[fc.function.name for fc in result.tool_calls] + if result.tool_calls is not None + else [], + query=query, ) + event_ctx = await self.ap.plugin_connector.emit_event(event) + if event_ctx.is_prevented_default(): yield entities.StageProcessResult( result_type=entities.ResultType.INTERRUPT, @@ -124,12 +125,12 @@ class ResponseWrapper(stage.PipelineStage): else: if event_ctx.event.reply is not None: query.resp_message_chain.append( - platform_message.MessageChain(event_ctx.event.reply) + platform_message.MessageChain(text=event_ctx.event.reply) ) else: query.resp_message_chain.append( - platform_message.MessageChain([platform_message.Plain(reply_text)]) + platform_message.MessageChain([platform_message.Plain(text=reply_text)]) ) yield entities.StageProcessResult( diff --git a/pkg/platform/adapter.py b/pkg/platform/adapter.py deleted file mode 100644 index e064ef80..00000000 --- a/pkg/platform/adapter.py +++ /dev/null @@ -1,190 +0,0 @@ -from __future__ import annotations - -# MessageSource的适配器 -import typing -import abc - - -from ..core import app -from .types import message as platform_message -from .types import events as platform_events -from .logger import EventLogger - - -class MessagePlatformAdapter(metaclass=abc.ABCMeta): - """消息平台适配器基类""" - - name: str - - bot_account_id: int - """机器人账号ID,需要在初始化时设置""" - - config: dict - - ap: app.Application - - logger: EventLogger - - def __init__(self, config: dict, ap: app.Application, logger: EventLogger): - """初始化适配器 - - Args: - config (dict): 对应的配置 - ap (app.Application): 应用上下文 - """ - self.config = config - self.ap = ap - self.logger = logger - - async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain): - """主动发送消息 - - Args: - target_type (str): 目标类型,`person`或`group` - target_id (str): 目标ID - message (platform.types.MessageChain): 消息链 - """ - raise NotImplementedError - - async def reply_message( - self, - message_source: platform_events.MessageEvent, - message: platform_message.MessageChain, - quote_origin: bool = False, - ): - """回复消息 - - Args: - message_source (platform.types.MessageEvent): 消息源事件 - message (platform.types.MessageChain): 消息链 - quote_origin (bool, optional): 是否引用原消息. Defaults to False. - """ - raise NotImplementedError - - async def reply_message_chunk( - self, - message_source: platform_events.MessageEvent, - bot_message: dict, - message: platform_message.MessageChain, - quote_origin: bool = False, - is_final: bool = False, - ): - """回复消息(流式输出) - Args: - message_source (platform.types.MessageEvent): 消息源事件 - message_id (int): 消息ID - message (platform.types.MessageChain): 消息链 - quote_origin (bool, optional): 是否引用原消息. Defaults to False. - is_final (bool, optional): 流式是否结束. Defaults to False. - """ - raise NotImplementedError - - async def create_message_card(self, message_id: typing.Type[str, int], event: platform_events.MessageEvent) -> bool: - """创建卡片消息 - Args: - message_id (str): 消息ID - event (platform_events.MessageEvent): 消息源事件 - """ - return False - - async def is_muted(self, group_id: int) -> bool: - """获取账号是否在指定群被禁言""" - raise NotImplementedError - - def register_listener( - self, - event_type: typing.Type[platform_events.Event], - callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None], - ): - """注册事件监听器 - - Args: - event_type (typing.Type[platform.types.Event]): 事件类型 - callback (typing.Callable[[platform.types.Event], None]): 回调函数,接收一个参数,为事件 - """ - raise NotImplementedError - - def unregister_listener( - self, - event_type: typing.Type[platform_events.Event], - callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None], - ): - """注销事件监听器 - - Args: - event_type (typing.Type[platform.types.Event]): 事件类型 - callback (typing.Callable[[platform.types.Event], None]): 回调函数,接收一个参数,为事件 - """ - raise NotImplementedError - - async def run_async(self): - """异步运行""" - raise NotImplementedError - - async def is_stream_output_supported(self) -> bool: - """是否支持流式输出""" - return False - - async def kill(self) -> bool: - """关闭适配器 - - Returns: - bool: 是否成功关闭,热重载时若此函数返回False则不会重载MessageSource底层 - """ - raise NotImplementedError - - -class MessageConverter: - """消息链转换器基类""" - - @staticmethod - def yiri2target(message_chain: platform_message.MessageChain): - """将源平台消息链转换为目标平台消息链 - - Args: - message_chain (platform.types.MessageChain): 源平台消息链 - - Returns: - typing.Any: 目标平台消息链 - """ - raise NotImplementedError - - @staticmethod - def target2yiri(message_chain: typing.Any) -> platform_message.MessageChain: - """将目标平台消息链转换为源平台消息链 - - Args: - message_chain (typing.Any): 目标平台消息链 - - Returns: - platform.types.MessageChain: 源平台消息链 - """ - raise NotImplementedError - - -class EventConverter: - """事件转换器基类""" - - @staticmethod - def yiri2target(event: typing.Type[platform_events.Event]): - """将源平台事件转换为目标平台事件 - - Args: - event (typing.Type[platform.types.Event]): 源平台事件 - - Returns: - typing.Any: 目标平台事件 - """ - raise NotImplementedError - - @staticmethod - def target2yiri(event: typing.Any) -> platform_events.Event: - """将目标平台事件的调用参数转换为源平台的事件参数对象 - - Args: - event (typing.Any): 目标平台事件 - - Returns: - typing.Type[platform.types.Event]: 源平台事件 - """ - raise NotImplementedError diff --git a/pkg/platform/adapter.yaml b/pkg/platform/adapter.yaml deleted file mode 100644 index d32b412d..00000000 --- a/pkg/platform/adapter.yaml +++ /dev/null @@ -1,14 +0,0 @@ -apiVersion: v1 -kind: ComponentTemplate -metadata: - name: MessagePlatformAdapter - label: - en_US: Message Platform Adapter - zh_Hans: 消息平台适配器模板类 -spec: - type: - - python -execution: - python: - path: ./adapter.py - attr: MessagePlatformAdapter diff --git a/pkg/platform/botmgr.py b/pkg/platform/botmgr.py index 1da5eec8..ee9bd040 100644 --- a/pkg/platform/botmgr.py +++ b/pkg/platform/botmgr.py @@ -1,15 +1,10 @@ from __future__ import annotations -import sys import asyncio import traceback import sqlalchemy -# FriendMessage, Image, MessageChain, Plain -from . import adapter as msadapter - from ..core import app, entities as core_entities, taskmgr -from .types import events as platform_events, message as platform_message from ..discover import engine @@ -19,10 +14,10 @@ from ..entity.errors import platform as platform_errors from .logger import EventLogger -# 处理 3.4 移除了 YiriMirai 之后,插件的兼容性问题 -from . import types as mirai - -sys.modules['mirai'] = mirai +import langbot_plugin.api.entities.builtin.provider.session as provider_session +import langbot_plugin.api.entities.builtin.platform.events as platform_events +import langbot_plugin.api.entities.builtin.platform.message as platform_message +import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter class RuntimeBot: @@ -34,7 +29,7 @@ class RuntimeBot: enable: bool - adapter: msadapter.MessagePlatformAdapter + adapter: abstract_platform_adapter.AbstractMessagePlatformAdapter task_wrapper: taskmgr.TaskWrapper @@ -46,7 +41,7 @@ class RuntimeBot: self, ap: app.Application, bot_entity: persistence_bot.Bot, - adapter: msadapter.MessagePlatformAdapter, + adapter: abstract_platform_adapter.AbstractMessagePlatformAdapter, logger: EventLogger, ): self.ap = ap @@ -59,7 +54,7 @@ class RuntimeBot: async def initialize(self): async def on_friend_message( event: platform_events.FriendMessage, - adapter: msadapter.MessagePlatformAdapter, + adapter: abstract_platform_adapter.AbstractMessagePlatformAdapter, ): image_components = [ component for component in event.message_chain if isinstance(component, platform_message.Image) @@ -73,7 +68,7 @@ class RuntimeBot: await self.ap.query_pool.add_query( bot_uuid=self.bot_entity.uuid, - launcher_type=core_entities.LauncherTypes.PERSON, + launcher_type=provider_session.LauncherTypes.PERSON, launcher_id=event.sender.id, sender_id=event.sender.id, message_event=event, @@ -84,7 +79,7 @@ class RuntimeBot: async def on_group_message( event: platform_events.GroupMessage, - adapter: msadapter.MessagePlatformAdapter, + adapter: abstract_platform_adapter.AbstractMessagePlatformAdapter, ): image_components = [ component for component in event.message_chain if isinstance(component, platform_message.Image) @@ -98,7 +93,7 @@ class RuntimeBot: await self.ap.query_pool.add_query( bot_uuid=self.bot_entity.uuid, - launcher_type=core_entities.LauncherTypes.GROUP, + launcher_type=provider_session.LauncherTypes.GROUP, launcher_id=event.group.id, sender_id=event.sender.id, message_event=event, @@ -153,7 +148,7 @@ class PlatformManager: adapter_components: list[engine.Component] - adapter_dict: dict[str, type[msadapter.MessagePlatformAdapter]] + adapter_dict: dict[str, type[abstract_platform_adapter.AbstractMessagePlatformAdapter]] def __init__(self, ap: app.Application = None): self.ap = ap @@ -163,7 +158,7 @@ class PlatformManager: async def initialize(self): self.adapter_components = self.ap.discover.get_components_by_kind('MessagePlatformAdapter') - adapter_dict: dict[str, type[msadapter.MessagePlatformAdapter]] = {} + adapter_dict: dict[str, type[abstract_platform_adapter.AbstractMessagePlatformAdapter]] = {} for component in self.adapter_components: adapter_dict[component.metadata.name] = component.get_python_component_class() self.adapter_dict = adapter_dict @@ -174,8 +169,9 @@ class PlatformManager: webchat_logger = EventLogger(name='webchat-adapter', ap=self.ap) webchat_adapter_inst = webchat_adapter_class( {}, - self.ap, webchat_logger, + ap=self.ap, + is_stream=False, ) self.webchat_proxy_bot = RuntimeBot( @@ -195,7 +191,7 @@ class PlatformManager: await self.load_bots_from_db() - def get_running_adapters(self) -> list[msadapter.MessagePlatformAdapter]: + def get_running_adapters(self) -> list[abstract_platform_adapter.AbstractMessagePlatformAdapter]: return [bot.adapter for bot in self.bots if bot.enable] async def load_bots_from_db(self): @@ -233,7 +229,6 @@ class PlatformManager: adapter_inst = self.adapter_dict[bot_entity.adapter]( bot_entity.adapter_config, - self.ap, logger, ) @@ -276,43 +271,6 @@ class PlatformManager: return component return None - async def write_back_config( - self, - adapter_name: str, - adapter_inst: msadapter.MessagePlatformAdapter, - config: dict, - ): - # index = -2 - - # for i, adapter in enumerate(self.adapters): - # if adapter == adapter_inst: - # index = i - # break - - # if index == -2: - # raise Exception('平台适配器未找到') - - # # 只修改启用的适配器 - # real_index = -1 - - # for i, adapter in enumerate(self.ap.platform_cfg.data['platform-adapters']): - # if adapter['enable']: - # index -= 1 - # if index == -1: - # real_index = i - # break - - # new_cfg = { - # 'adapter': adapter_name, - # 'enable': True, - # **config - # } - # self.ap.platform_cfg.data['platform-adapters'][real_index] = new_cfg - # await self.ap.platform_cfg.dump_config() - - # TODO implement this - pass - async def run(self): # This method will only be called when the application launching await self.webchat_proxy_bot.run() diff --git a/pkg/platform/logger.py b/pkg/platform/logger.py index a2ea2e25..05fce394 100644 --- a/pkg/platform/logger.py +++ b/pkg/platform/logger.py @@ -9,7 +9,8 @@ import traceback import uuid from ..core import app -from .types import message as platform_message +import langbot_plugin.api.entities.builtin.platform.message as platform_message +import langbot_plugin.api.definition.abstract.platform.event_logger as abstract_platform_event_logger class EventLogLevel(enum.Enum): @@ -55,7 +56,7 @@ MAX_LOG_COUNT = 200 DELETE_COUNT_PER_TIME = 50 -class EventLogger: +class EventLogger(abstract_platform_event_logger.AbstractEventLogger): """used for logging bot events""" ap: app.Application diff --git a/pkg/platform/sources/aiocqhttp.py b/pkg/platform/sources/aiocqhttp.py index 70f0ac9d..ba673796 100644 --- a/pkg/platform/sources/aiocqhttp.py +++ b/pkg/platform/sources/aiocqhttp.py @@ -5,17 +5,17 @@ import traceback import datetime import aiocqhttp +import pydantic -from .. import adapter -from ...core import app -from ..types import message as platform_message -from ..types import events as platform_events -from ..types import entities as platform_entities +import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter +import langbot_plugin.api.entities.builtin.platform.message as platform_message +import langbot_plugin.api.entities.builtin.platform.events as platform_events +import langbot_plugin.api.entities.builtin.platform.entities as platform_entities from ...utils import image -from ..logger import EventLogger +import langbot_plugin.api.definition.abstract.platform.event_logger as abstract_platform_logger -class AiocqhttpMessageConverter(adapter.MessageConverter): +class AiocqhttpMessageConverter(abstract_platform_adapter.AbstractMessageConverter): @staticmethod async def yiri2target( message_chain: platform_message.MessageChain, @@ -266,20 +266,21 @@ class AiocqhttpMessageConverter(adapter.MessageConverter): await process_message_data(msg_data, reply_list) reply_msg = platform_message.Quote( - message_id=msg.data['id'], sender_id=msg_datas['sender']['user_id'], origin=reply_list + message_id=msg.data['id'], sender_id=msg_datas['user_id'], origin=reply_list ) yiri_msg_list.append(reply_msg) - # 这里下载所有文件会导致下载文件过多,暂时不下载 - # elif msg.type == 'file': - # # file_name = msg.data['file'] - # file_id = msg.data['file_id'] - # file_data = await bot.get_file(file_id=file_id) - # file_name = file_data.get('file_name') - # file_path = file_data.get('file') - # file_url = file_data.get('file_url') - # file_size = file_data.get('file_size') - # yiri_msg_list.append(platform_message.File(id=file_id, name=file_name,url=file_url,size=file_size)) + elif msg.type == 'file': + pass + # file_name = msg.data['file'] + # file_id = msg.data['file_id'] + # file_data = await bot.get_file(file_id=file_id) + # file_name = file_data.get('file_name') + # file_path = file_data.get('file') + # _ = file_path + # file_url = file_data.get('file_url') + # file_size = file_data.get('file_size') + # yiri_msg_list.append(platform_message.File(id=file_id, name=file_name,url=file_url,size=file_size)) elif msg.type == 'face': face_id = msg.data['id'] face_name = msg.data['raw']['faceText'] @@ -298,7 +299,7 @@ class AiocqhttpMessageConverter(adapter.MessageConverter): return chain -class AiocqhttpEventConverter(adapter.EventConverter): +class AiocqhttpEventConverter(abstract_platform_adapter.AbstractEventConverter): @staticmethod async def yiri2target(event: platform_events.MessageEvent, bot_account_id: int): return event.source_platform_object @@ -348,23 +349,19 @@ class AiocqhttpEventConverter(adapter.EventConverter): ) -class AiocqhttpAdapter(adapter.MessagePlatformAdapter): - bot: aiocqhttp.CQHttp - - bot_account_id: int +class AiocqhttpAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter): + bot: aiocqhttp.CQHttp = pydantic.Field(exclude=True, default_factory=aiocqhttp.CQHttp) message_converter: AiocqhttpMessageConverter = AiocqhttpMessageConverter() event_converter: AiocqhttpEventConverter = AiocqhttpEventConverter() - config: dict - - ap: app.Application - on_websocket_connection_event_cache: typing.List[typing.Callable[[aiocqhttp.Event], None]] = [] - def __init__(self, config: dict, ap: app.Application, logger: EventLogger): - self.config = config - self.logger = logger + def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger): + super().__init__( + config=config, + logger=logger, + ) async def shutdown_trigger_placeholder(): while True: @@ -372,7 +369,6 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter): self.config['shutdown_trigger'] = shutdown_trigger_placeholder - self.ap = ap self.on_websocket_connection_event_cache = [] if 'access-token' in config: @@ -408,7 +404,9 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter): def register_listener( self, event_type: typing.Type[platform_events.Event], - callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], + callback: typing.Callable[ + [platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None + ], ): async def on_message(event: aiocqhttp.Event): self.bot_account_id = event.self_id @@ -439,7 +437,9 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter): def unregister_listener( self, event_type: typing.Type[platform_events.Event], - callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], + callback: typing.Callable[ + [platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None + ], ): return super().unregister_listener(event_type, callback) diff --git a/pkg/platform/sources/dingtalk.py b/pkg/platform/sources/dingtalk.py index 77b467d5..eb15775f 100644 --- a/pkg/platform/sources/dingtalk.py +++ b/pkg/platform/sources/dingtalk.py @@ -1,19 +1,16 @@ -from re import S import traceback import typing from libs.dingtalk_api.dingtalkevent import DingTalkEvent -from pkg.platform.types import message as platform_message -from pkg.platform.adapter import MessagePlatformAdapter -from .. import adapter -from ...core import app -from ..types import events as platform_events -from ..types import entities as platform_entities +import langbot_plugin.api.entities.builtin.platform.message as platform_message +import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter +import langbot_plugin.api.entities.builtin.platform.events as platform_events +import langbot_plugin.api.entities.builtin.platform.entities as platform_entities from libs.dingtalk_api.api import DingTalkClient import datetime from ..logger import EventLogger -class DingTalkMessageConverter(adapter.MessageConverter): +class DingTalkMessageConverter(abstract_platform_adapter.AbstractMessageConverter): @staticmethod async def yiri2target(message_chain: platform_message.MessageChain): content = '' @@ -52,7 +49,7 @@ class DingTalkMessageConverter(adapter.MessageConverter): return chain -class DingTalkEventConverter(adapter.EventConverter): +class DingTalkEventConverter(abstract_platform_adapter.AbstractEventConverter): @staticmethod async def yiri2target(event: platform_events.MessageEvent): return event.source_platform_object @@ -96,22 +93,18 @@ class DingTalkEventConverter(adapter.EventConverter): ) -class DingTalkAdapter(adapter.MessagePlatformAdapter): +class DingTalkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter): bot: DingTalkClient - ap: app.Application bot_account_id: str message_converter: DingTalkMessageConverter = DingTalkMessageConverter() event_converter: DingTalkEventConverter = DingTalkEventConverter() config: dict - card_instance_id_dict: dict # 回复卡片消息字典,key为消息id,value为回复卡片实例id,用于在流式消息时判断是否发送到指定卡片 - seq: int # 消息顺序,直接以seq作为标识 + card_instance_id_dict: ( + dict # 回复卡片消息字典,key为消息id,value为回复卡片实例id,用于在流式消息时判断是否发送到指定卡片 + ) + + def __init__(self, config: dict, logger: EventLogger): - def __init__(self, config: dict, ap: app.Application, logger: EventLogger): - self.config = config - self.ap = ap - self.logger = logger - self.card_instance_id_dict = {} - # self.seq = 1 required_keys = [ 'client_id', 'client_secret', @@ -121,16 +114,23 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter): missing_keys = [key for key in required_keys if key not in config] if missing_keys: raise Exception('钉钉缺少相关配置项,请查看文档或联系管理员') + bot = DingTalkClient( + client_id=config['client_id'], + client_secret=config['client_secret'], + robot_name=config['robot_name'], + robot_code=config['robot_code'], + markdown_card=config['markdown_card'], + logger=logger, + ) + bot_account_id = config['robot_name'] + super().__init__( + config=config, + logger=logger, + card_instance_id_dict={}, + bot_account_id=bot_account_id, + bot=bot, + listeners={}, - self.bot_account_id = self.config['robot_name'] - - self.bot = DingTalkClient( - client_id=config['client_id'], - client_secret=config['client_secret'], - robot_name=config['robot_name'], - robot_code=config['robot_code'], - markdown_card=config['markdown_card'], - logger=self.logger, ) async def reply_message( @@ -165,12 +165,11 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter): msg_seq = bot_message.msg_sequence if (msg_seq - 1) % 8 == 0 or is_final: - content, at = await DingTalkMessageConverter.yiri2target(message) card_instance, card_instance_id = self.card_instance_id_dict[message_id] if not content and bot_message.content: - content = bot_message.content # 兼容直接传入content的情况 + content = bot_message.content # 兼容直接传入content的情况 # print(card_instance_id) if content: await self.bot.send_card_message(card_instance, card_instance_id, content, is_final) @@ -202,7 +201,9 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter): def register_listener( self, event_type: typing.Type[platform_events.Event], - callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], + callback: typing.Callable[ + [platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None + ], ): async def on_message(event: DingTalkEvent): try: @@ -224,9 +225,14 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter): async def kill(self) -> bool: return False + async def is_muted(self) -> bool: + return False + async def unregister_listener( self, event_type: type, - callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None], + callback: typing.Callable[ + [platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None + ], ): return super().unregister_listener(event_type, callback) diff --git a/pkg/platform/sources/discord.py b/pkg/platform/sources/discord.py index 9e26f239..933961de 100644 --- a/pkg/platform/sources/discord.py +++ b/pkg/platform/sources/discord.py @@ -12,13 +12,14 @@ import asyncio from enum import Enum import aiohttp +import pydantic -from .. import adapter -from ...core import app +import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter +import langbot_plugin.api.entities.builtin.platform.message as platform_message +import langbot_plugin.api.entities.builtin.platform.events as platform_events +import langbot_plugin.api.entities.builtin.platform.entities as platform_entities +import langbot_plugin.api.definition.abstract.platform.event_logger as abstract_platform_logger from ..logger import EventLogger -from ..types import message as platform_message -from ..types import events as platform_events -from ..types import entities as platform_entities # 语音功能相关异常定义 @@ -582,7 +583,7 @@ class VoiceConnectionManager: await self.stop_monitoring() -class DiscordMessageConverter(adapter.MessageConverter): +class DiscordMessageConverter(abstract_platform_adapter.AbstractMessageConverter): @staticmethod async def yiri2target( message_chain: platform_message.MessageChain, @@ -736,7 +737,7 @@ class DiscordMessageConverter(adapter.MessageConverter): return platform_message.MessageChain(element_list) -class DiscordEventConverter(adapter.EventConverter): +class DiscordEventConverter(abstract_platform_adapter.AbstractEventConverter): @staticmethod async def yiri2target(event: platform_events.Event) -> discord.Message: pass @@ -778,32 +779,26 @@ class DiscordEventConverter(adapter.EventConverter): ) -class DiscordAdapter(adapter.MessagePlatformAdapter): - bot: discord.Client - - bot_account_id: str # 用于在流水线中识别at是否是本bot,直接以bot_name作为标识 - - config: dict - - ap: app.Application +class DiscordAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter): + bot: discord.Client = pydantic.Field(exclude=True) message_converter: DiscordMessageConverter = DiscordMessageConverter() event_converter: DiscordEventConverter = DiscordEventConverter() listeners: typing.Dict[ typing.Type[platform_events.Event], - typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], + typing.Callable[[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None], ] = {} - def __init__(self, config: dict, ap: app.Application, logger: EventLogger): - self.config = config - self.ap = ap - self.logger = logger + voice_manager: VoiceConnectionManager | None = pydantic.Field(exclude=True, default=None) - self.bot_account_id = self.config['client_id'] + def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger, **kwargs): + bot_account_id = config['client_id'] + + listeners = {} # 初始化语音连接管理器 - self.voice_manager: VoiceConnectionManager = None + # self.voice_manager: VoiceConnectionManager = None adapter_self = self @@ -823,7 +818,17 @@ class DiscordAdapter(adapter.MessagePlatformAdapter): if os.getenv('http_proxy'): args['proxy'] = os.getenv('http_proxy') - self.bot = MyClient(intents=intents, **args) + bot = MyClient(intents=intents, **args) + + super().__init__( + config=config, + logger=logger, + bot_account_id=bot_account_id, + listeners=listeners, + bot=bot, + voice_manager=None, + **kwargs, + ) # Voice functionality methods async def join_voice_channel(self, guild_id: int, channel_id: int, user_id: int = None) -> discord.VoiceClient: @@ -1029,7 +1034,14 @@ class DiscordAdapter(adapter.MessagePlatformAdapter): if quote_origin: args['reference'] = message_source.source_platform_object - if message.has(platform_message.At): + has_at = False + + for component in message.root: + if isinstance(component, platform_message.At): + has_at = True + break + + if has_at: args['mention_author'] = True await message_source.source_platform_object.channel.send(**args) @@ -1040,14 +1052,18 @@ class DiscordAdapter(adapter.MessagePlatformAdapter): def register_listener( self, event_type: typing.Type[platform_events.Event], - callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], + callback: typing.Callable[ + [platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None + ], ): self.listeners[event_type] = callback def unregister_listener( self, event_type: typing.Type[platform_events.Event], - callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], + callback: typing.Callable[ + [platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None + ], ): self.listeners.pop(event_type) diff --git a/pkg/platform/sources/lark.py b/pkg/platform/sources/lark.py index 975730b5..23257e6f 100644 --- a/pkg/platform/sources/lark.py +++ b/pkg/platform/sources/lark.py @@ -17,14 +17,14 @@ import aiohttp import lark_oapi.ws.exception import quart from lark_oapi.api.im.v1 import * +import pydantic from lark_oapi.api.cardkit.v1 import * -from .. import adapter -from ...core import app -from ..types import message as platform_message -from ..types import events as platform_events -from ..types import entities as platform_entities -from ..logger import EventLogger +import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter +import langbot_plugin.api.entities.builtin.platform.message as platform_message +import langbot_plugin.api.entities.builtin.platform.events as platform_events +import langbot_plugin.api.entities.builtin.platform.entities as platform_entities +import langbot_plugin.api.definition.abstract.platform.event_logger as abstract_platform_logger class AESCipher(object): @@ -53,7 +53,7 @@ class AESCipher(object): return self.decrypt(enc).decode('utf8') -class LarkMessageConverter(adapter.MessageConverter): +class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter): @staticmethod async def yiri2target( message_chain: platform_message.MessageChain, api_client: lark_oapi.Client @@ -277,7 +277,7 @@ class LarkMessageConverter(adapter.MessageConverter): return platform_message.MessageChain(lb_msg_list) -class LarkEventConverter(adapter.EventConverter): +class LarkEventConverter(abstract_platform_adapter.AbstractEventConverter): @staticmethod async def yiri2target( event: platform_events.MessageEvent, @@ -325,49 +325,37 @@ CARD_ID_CACHE_SIZE = 500 CARD_ID_CACHE_MAX_LIFETIME = 20 * 60 # 20分钟 -class LarkAdapter(adapter.MessagePlatformAdapter): - bot: lark_oapi.ws.Client - api_client: lark_oapi.Client +class LarkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter): + bot: lark_oapi.ws.Client = pydantic.Field(exclude=True) + api_client: lark_oapi.Client = pydantic.Field(exclude=True) bot_account_id: str # 用于在流水线中识别at是否是本bot,直接以bot_name作为标识 - lark_tenant_key: str # 飞书企业key + lark_tenant_key: str = pydantic.Field(exclude=True, default='') # 飞书企业key message_converter: LarkMessageConverter = LarkMessageConverter() event_converter: LarkEventConverter = LarkEventConverter() listeners: typing.Dict[ typing.Type[platform_events.Event], - typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], + typing.Callable[[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None], ] - config: dict - quart_app: quart.Quart - ap: app.Application - + quart_app: quart.Quart = pydantic.Field(exclude=True) card_id_dict: dict[str, str] # 消息id到卡片id的映射,便于创建卡片后的发送消息到指定卡片 seq: int # 用于在发送卡片消息中识别消息顺序,直接以seq作为标识 - def __init__(self, config: dict, ap: app.Application, logger: EventLogger): - self.config = config - self.ap = ap - self.logger = logger - self.quart_app = quart.Quart(__name__) - self.listeners = {} - self.card_id_dict = {} - self.seq = 1 + def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger, **kwargs): + quart_app = quart.Quart(__name__) - - @self.quart_app.route('/lark/callback', methods=['POST']) + @quart_app.route('/lark/callback', methods=['POST']) async def lark_callback(): try: data = await quart.request.json - self.ap.logger.debug(f'Lark callback event: {data}') - if 'encrypt' in data: - cipher = AESCipher(self.config['encrypt-key']) + cipher = AESCipher(config['encrypt-key']) data = cipher.decrypt_string(data['encrypt']) data = json.loads(data) @@ -414,10 +402,24 @@ class LarkAdapter(adapter.MessagePlatformAdapter): lark_oapi.EventDispatcherHandler.builder('', '').register_p2_im_message_receive_v1(sync_on_message).build() ) - self.bot_account_id = config['bot_name'] + bot_account_id = config['bot_name'] - self.bot = lark_oapi.ws.Client(config['app_id'], config['app_secret'], event_handler=event_handler) - self.api_client = lark_oapi.Client.builder().app_id(config['app_id']).app_secret(config['app_secret']).build() + bot = lark_oapi.ws.Client(config['app_id'], config['app_secret'], event_handler=event_handler) + api_client = lark_oapi.Client.builder().app_id(config['app_id']).app_secret(config['app_secret']).build() + + super().__init__( + config=config, + logger=logger, + lark_tenant_key=config.get('lark_tenant_key', ''), + card_id_dict={}, + seq=1, + listeners={}, + quart_app=quart_app, + bot=bot, + api_client=api_client, + bot_account_id=bot_account_id, + **kwargs, + ) async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain): pass @@ -430,151 +432,177 @@ class LarkAdapter(adapter.MessagePlatformAdapter): async def create_card_id(self, message_id): try: - self.ap.logger.debug('飞书支持stream输出,创建卡片......') + # self.logger.debug('飞书支持stream输出,创建卡片......') - card_data = {"schema": "2.0", "config": {"update_multi": True, "streaming_mode": True, - "streaming_config": {"print_step": {"default": 1}, - "print_frequency_ms": {"default": 70}, - "print_strategy": "fast"}}, - "body": {"direction": "vertical", "padding": "12px 12px 12px 12px", "elements": [{"tag": "div", - "text": { - "tag": "plain_text", - "content": "LangBot", - "text_size": "normal", - "text_align": "left", - "text_color": "default"}, - "icon": { - "tag": "custom_icon", - "img_key": "img_v3_02p3_05c65d5d-9bad-440a-a2fb-c89571bfd5bg"}}, - { - "tag": "markdown", - "content": "", - "text_align": "left", - "text_size": "normal", - "margin": "0px 0px 0px 0px", - "element_id": "streaming_txt"}, - { - "tag": "markdown", - "content": "", - "text_align": "left", - "text_size": "normal", - "margin": "0px 0px 0px 0px"}, - { - "tag": "column_set", - "horizontal_spacing": "8px", - "horizontal_align": "left", - "columns": [ - { - "tag": "column", - "width": "weighted", - "elements": [ - { - "tag": "markdown", - "content": "", - "text_align": "left", - "text_size": "normal", - "margin": "0px 0px 0px 0px"}, - { - "tag": "markdown", - "content": "", - "text_align": "left", - "text_size": "normal", - "margin": "0px 0px 0px 0px"}, - { - "tag": "markdown", - "content": "", - "text_align": "left", - "text_size": "normal", - "margin": "0px 0px 0px 0px"}], - "padding": "0px 0px 0px 0px", - "direction": "vertical", - "horizontal_spacing": "8px", - "vertical_spacing": "2px", - "horizontal_align": "left", - "vertical_align": "top", - "margin": "0px 0px 0px 0px", - "weight": 1}], - "margin": "0px 0px 0px 0px"}, - {"tag": "hr", - "margin": "0px 0px 0px 0px"}, - { - "tag": "column_set", - "horizontal_spacing": "12px", - "horizontal_align": "right", - "columns": [ - { - "tag": "column", - "width": "weighted", - "elements": [ - { - "tag": "markdown", - "content": "以上内容由 AI 生成,仅供参考。更多详细、准确信息可点击引用链接查看", - "text_align": "left", - "text_size": "notation", - "margin": "4px 0px 0px 0px", - "icon": { - "tag": "standard_icon", - "token": "robot_outlined", - "color": "grey"}}], - "padding": "0px 0px 0px 0px", - "direction": "vertical", - "horizontal_spacing": "8px", - "vertical_spacing": "8px", - "horizontal_align": "left", - "vertical_align": "top", - "margin": "0px 0px 0px 0px", - "weight": 1}, - { - "tag": "column", - "width": "20px", - "elements": [ - { - "tag": "button", - "text": { - "tag": "plain_text", - "content": ""}, - "type": "text", - "width": "fill", - "size": "medium", - "icon": { - "tag": "standard_icon", - "token": "thumbsup_outlined"}, - "hover_tips": { - "tag": "plain_text", - "content": "有帮助"}, - "margin": "0px 0px 0px 0px"}], - "padding": "0px 0px 0px 0px", - "direction": "vertical", - "horizontal_spacing": "8px", - "vertical_spacing": "8px", - "horizontal_align": "left", - "vertical_align": "top", - "margin": "0px 0px 0px 0px"}, - { - "tag": "column", - "width": "30px", - "elements": [ - { - "tag": "button", - "text": { - "tag": "plain_text", - "content": ""}, - "type": "text", - "width": "default", - "size": "medium", - "icon": { - "tag": "standard_icon", - "token": "thumbdown_outlined"}, - "hover_tips": { - "tag": "plain_text", - "content": "无帮助"}, - "margin": "0px 0px 0px 0px"}], - "padding": "0px 0px 0px 0px", - "vertical_spacing": "8px", - "horizontal_align": "left", - "vertical_align": "top", - "margin": "0px 0px 0px 0px"}], - "margin": "0px 0px 4px 0px"}]}} + card_data = { + 'schema': '2.0', + 'config': { + 'update_multi': True, + 'streaming_mode': True, + 'streaming_config': { + 'print_step': {'default': 1}, + 'print_frequency_ms': {'default': 70}, + 'print_strategy': 'fast', + }, + }, + 'body': { + 'direction': 'vertical', + 'padding': '12px 12px 12px 12px', + 'elements': [ + { + 'tag': 'div', + 'text': { + 'tag': 'plain_text', + 'content': 'LangBot', + 'text_size': 'normal', + 'text_align': 'left', + 'text_color': 'default', + }, + 'icon': { + 'tag': 'custom_icon', + 'img_key': 'img_v3_02p3_05c65d5d-9bad-440a-a2fb-c89571bfd5bg', + }, + }, + { + 'tag': 'markdown', + 'content': '', + 'text_align': 'left', + 'text_size': 'normal', + 'margin': '0px 0px 0px 0px', + 'element_id': 'streaming_txt', + }, + { + 'tag': 'markdown', + 'content': '', + 'text_align': 'left', + 'text_size': 'normal', + 'margin': '0px 0px 0px 0px', + }, + { + 'tag': 'column_set', + 'horizontal_spacing': '8px', + 'horizontal_align': 'left', + 'columns': [ + { + 'tag': 'column', + 'width': 'weighted', + 'elements': [ + { + 'tag': 'markdown', + 'content': '', + 'text_align': 'left', + 'text_size': 'normal', + 'margin': '0px 0px 0px 0px', + }, + { + 'tag': 'markdown', + 'content': '', + 'text_align': 'left', + 'text_size': 'normal', + 'margin': '0px 0px 0px 0px', + }, + { + 'tag': 'markdown', + 'content': '', + 'text_align': 'left', + 'text_size': 'normal', + 'margin': '0px 0px 0px 0px', + }, + ], + 'padding': '0px 0px 0px 0px', + 'direction': 'vertical', + 'horizontal_spacing': '8px', + 'vertical_spacing': '2px', + 'horizontal_align': 'left', + 'vertical_align': 'top', + 'margin': '0px 0px 0px 0px', + 'weight': 1, + } + ], + 'margin': '0px 0px 0px 0px', + }, + {'tag': 'hr', 'margin': '0px 0px 0px 0px'}, + { + 'tag': 'column_set', + 'horizontal_spacing': '12px', + 'horizontal_align': 'right', + 'columns': [ + { + 'tag': 'column', + 'width': 'weighted', + 'elements': [ + { + 'tag': 'markdown', + 'content': '以上内容由 AI 生成,仅供参考。更多详细、准确信息可点击引用链接查看', + 'text_align': 'left', + 'text_size': 'notation', + 'margin': '4px 0px 0px 0px', + 'icon': { + 'tag': 'standard_icon', + 'token': 'robot_outlined', + 'color': 'grey', + }, + } + ], + 'padding': '0px 0px 0px 0px', + 'direction': 'vertical', + 'horizontal_spacing': '8px', + 'vertical_spacing': '8px', + 'horizontal_align': 'left', + 'vertical_align': 'top', + 'margin': '0px 0px 0px 0px', + 'weight': 1, + }, + { + 'tag': 'column', + 'width': '20px', + 'elements': [ + { + 'tag': 'button', + 'text': {'tag': 'plain_text', 'content': ''}, + 'type': 'text', + 'width': 'fill', + 'size': 'medium', + 'icon': {'tag': 'standard_icon', 'token': 'thumbsup_outlined'}, + 'hover_tips': {'tag': 'plain_text', 'content': '有帮助'}, + 'margin': '0px 0px 0px 0px', + } + ], + 'padding': '0px 0px 0px 0px', + 'direction': 'vertical', + 'horizontal_spacing': '8px', + 'vertical_spacing': '8px', + 'horizontal_align': 'left', + 'vertical_align': 'top', + 'margin': '0px 0px 0px 0px', + }, + { + 'tag': 'column', + 'width': '30px', + 'elements': [ + { + 'tag': 'button', + 'text': {'tag': 'plain_text', 'content': ''}, + 'type': 'text', + 'width': 'default', + 'size': 'medium', + 'icon': {'tag': 'standard_icon', 'token': 'thumbdown_outlined'}, + 'hover_tips': {'tag': 'plain_text', 'content': '无帮助'}, + 'margin': '0px 0px 0px 0px', + } + ], + 'padding': '0px 0px 0px 0px', + 'vertical_spacing': '8px', + 'horizontal_align': 'left', + 'vertical_align': 'top', + 'margin': '0px 0px 0px 0px', + }, + ], + 'margin': '0px 0px 4px 0px', + }, + ], + }, + } # delay / fast 创建卡片模板,delay 延迟打印,fast 实时打印,可以自定义更好看的消息模板 request: CreateCardRequest = ( @@ -592,15 +620,13 @@ class LarkAdapter(adapter.MessagePlatformAdapter): f'client.cardkit.v1.card.create failed, code: {response.code}, msg: {response.msg}, log_id: {response.get_log_id()}, resp: \n{json.dumps(json.loads(response.raw.content), indent=4, ensure_ascii=False)}' ) - self.ap.logger.debug(f'飞书卡片创建成功,卡片ID: {response.data.card_id}') self.card_id_dict[message_id] = response.data.card_id card_id = response.data.card_id return card_id except Exception as e: - self.ap.logger.error(f'飞书卡片创建失败,错误信息: {e}') - + raise e async def create_message_card(self, message_id, event) -> str: """ 创建卡片消息。 @@ -612,7 +638,7 @@ class LarkAdapter(adapter.MessagePlatformAdapter): content = { 'type': 'card', 'data': {'card_id': card_id, 'template_variable': {'content': 'Thinking...'}}, - } # 当收到消息时发送消息模板,可添加模板变量,详情查看飞书中接口文档 + } # 当收到消息时发送消息模板,可添加模板变量,详情查看飞书中接口文档 request: ReplyMessageRequest = ( ReplyMessageRequest.builder() .message_id(event.message_chain.message_id) @@ -685,10 +711,8 @@ class LarkAdapter(adapter.MessagePlatformAdapter): message_id = bot_message.resp_message_id msg_seq = bot_message.msg_sequence if msg_seq % 8 == 0 or is_final: - lark_message = await self.message_converter.yiri2target(message, self.api_client) - text_message = '' for ele in lark_message[0]: if ele['tag'] == 'text': @@ -734,14 +758,18 @@ class LarkAdapter(adapter.MessagePlatformAdapter): def register_listener( self, event_type: typing.Type[platform_events.Event], - callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], + callback: typing.Callable[ + [platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None + ], ): self.listeners[event_type] = callback def unregister_listener( self, event_type: typing.Type[platform_events.Event], - callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], + callback: typing.Callable[ + [platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None + ], ): self.listeners.pop(event_type) @@ -778,4 +806,4 @@ class LarkAdapter(adapter.MessagePlatformAdapter): # 所以要设置_auto_reconnect=False,让其不重连。 self.bot._auto_reconnect = False await self.bot._disconnect() - return False \ No newline at end of file + return False diff --git a/pkg/platform/sources/gewechat.png b/pkg/platform/sources/legacy/gewechat.png similarity index 100% rename from pkg/platform/sources/gewechat.png rename to pkg/platform/sources/legacy/gewechat.png diff --git a/pkg/platform/sources/gewechat.py b/pkg/platform/sources/legacy/gewechat.py similarity index 95% rename from pkg/platform/sources/gewechat.py rename to pkg/platform/sources/legacy/gewechat.py index 01d9f946..cd5dcf22 100644 --- a/pkg/platform/sources/gewechat.py +++ b/pkg/platform/sources/legacy/gewechat.py @@ -11,19 +11,19 @@ import threading import quart import aiohttp -from .. import adapter -from ...core import app -from ..types import message as platform_message -from ..types import events as platform_events -from ..types import entities as platform_entities -from ...utils import image +import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter +from ....core import app +import langbot_plugin.api.entities.builtin.platform.message as platform_message +import langbot_plugin.api.entities.builtin.platform.events as platform_events +import langbot_plugin.api.entities.builtin.platform.entities as platform_entities +from ....utils import image import xml.etree.ElementTree as ET from typing import Optional, Tuple from functools import partial -from ..logger import EventLogger +from ...logger import EventLogger -class GewechatMessageConverter(adapter.MessageConverter): +class GewechatMessageConverter(abstract_platform_adapter.AbstractMessageConverter): def __init__(self, config: dict): self.config = config @@ -398,7 +398,7 @@ class GewechatMessageConverter(adapter.MessageConverter): return from_user_name.endswith('@chatroom') -class GewechatEventConverter(adapter.EventConverter): +class GewechatEventConverter(abstract_platform_adapter.AbstractEventConverter): def __init__(self, config: dict): self.config = config self.message_converter = GewechatMessageConverter(config) @@ -458,7 +458,7 @@ class GewechatEventConverter(adapter.EventConverter): ) -class GeWeChatAdapter(adapter.MessagePlatformAdapter): +class GeWeChatAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter): name: str = 'gewechat' # 定义适配器名称 bot: gewechat_client.GewechatClient @@ -475,7 +475,7 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter): listeners: typing.Dict[ typing.Type[platform_events.Event], - typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], + typing.Callable[[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None], ] = {} def __init__(self, config: dict, ap: app.Application, logger: EventLogger): @@ -491,7 +491,7 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter): async def gewechat_callback(): data = await quart.request.json # print(json.dumps(data, indent=4, ensure_ascii=False)) - self.ap.logger.debug(f'Gewechat callback event: {data}') + await self.logger.debug(f'Gewechat callback event: {data}') if 'data' in data: data['Data'] = data['data'] @@ -601,7 +601,7 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter): if handler := handler_map.get(msg['type']): handler(msg) else: - self.ap.logger.warning(f'未处理的消息类型: {msg["type"]}') + await self.logger.warning(f'未处理的消息类型: {msg["type"]}') continue async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain): @@ -625,14 +625,18 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter): def register_listener( self, event_type: typing.Type[platform_events.Event], - callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], + callback: typing.Callable[ + [platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None + ], ): self.listeners[event_type] = callback def unregister_listener( self, event_type: typing.Type[platform_events.Event], - callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], + callback: typing.Callable[ + [platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None + ], ): pass @@ -656,9 +660,7 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter): self.config['app_id'] = app_id - self.ap.logger.info(f'Gewechat 登录成功,app_id: {app_id}') - - self.ap.platform_mgr.write_back_config('gewechat', self, self.config) + print(f'Gewechat 登录成功,app_id: {app_id}') # 获取 nickname profile = self.bot.get_profile(self.config['app_id']) diff --git a/pkg/platform/sources/gewechat.yaml b/pkg/platform/sources/legacy/gewechat.yaml similarity index 100% rename from pkg/platform/sources/gewechat.yaml rename to pkg/platform/sources/legacy/gewechat.yaml diff --git a/pkg/platform/sources/nakuru.png b/pkg/platform/sources/legacy/nakuru.png similarity index 100% rename from pkg/platform/sources/nakuru.png rename to pkg/platform/sources/legacy/nakuru.png diff --git a/pkg/platform/sources/nakuru.py b/pkg/platform/sources/legacy/nakuru.py similarity index 92% rename from pkg/platform/sources/nakuru.py rename to pkg/platform/sources/legacy/nakuru.py index 16ad54db..52609792 100644 --- a/pkg/platform/sources/nakuru.py +++ b/pkg/platform/sources/legacy/nakuru.py @@ -9,15 +9,15 @@ import traceback import nakuru import nakuru.entities.components as nkc -from .. import adapter as adapter_model -from ...pipeline.longtext.strategies import forward -from ...platform.types import message as platform_message -from ...platform.types import entities as platform_entities -from ...platform.types import events as platform_events -from ..logger import EventLogger +from ....pipeline.longtext.strategies import forward +import langbot_plugin.api.entities.builtin.platform.message as platform_message +import langbot_plugin.api.entities.builtin.platform.entities as platform_entities +import langbot_plugin.api.entities.builtin.platform.events as platform_events +import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter +from ...logger import EventLogger -class NakuruProjectMessageConverter(adapter_model.MessageConverter): +class NakuruProjectMessageConverter(abstract_platform_adapter.AbstractMessageConverter): """消息转换器""" @staticmethod @@ -109,7 +109,7 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter): return chain -class NakuruProjectEventConverter(adapter_model.EventConverter): +class NakuruProjectEventConverter(abstract_platform_adapter.AbstractEventConverter): """事件转换器""" @staticmethod @@ -164,7 +164,7 @@ class NakuruProjectEventConverter(adapter_model.EventConverter): raise Exception('未支持转换的事件类型: ' + str(event)) -class NakuruAdapter(adapter_model.MessagePlatformAdapter): +class NakuruAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter): """nakuru-project适配器""" bot: nakuru.CQHTTP @@ -256,13 +256,15 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter): def register_listener( self, event_type: typing.Type[platform_events.Event], - callback: typing.Callable[[platform_events.Event, adapter_model.MessagePlatformAdapter], None], + callback: typing.Callable[ + [platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None + ], ): try: source_cls = NakuruProjectEventConverter.yiri2target(event_type) # 包装函数 - async def listener_wrapper(app: nakuru.CQHTTP, source: source_cls): + async def listener_wrapper(app: nakuru.CQHTTP, source: source_cls): # type: ignore await callback(self.event_converter.target2yiri(source), self) # 将包装函数和原函数的对应关系存入列表 @@ -283,7 +285,9 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter): def unregister_listener( self, event_type: typing.Type[platform_events.Event], - callback: typing.Callable[[platform_events.Event, adapter_model.MessagePlatformAdapter], None], + callback: typing.Callable[ + [platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None + ], ): nakuru_event_name = self.event_converter.yiri2target(event_type).__name__ @@ -322,7 +326,6 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter): except Exception: raise Exception('获取go-cqhttp账号信息失败, 请检查是否已启动go-cqhttp并配置正确') await self.bot._run() - self.ap.logger.info('运行 Nakuru 适配器') while True: await asyncio.sleep(1) diff --git a/pkg/platform/sources/nakuru.yaml b/pkg/platform/sources/legacy/nakuru.yaml similarity index 100% rename from pkg/platform/sources/nakuru.yaml rename to pkg/platform/sources/legacy/nakuru.yaml diff --git a/pkg/platform/sources/qqbotpy.py b/pkg/platform/sources/legacy/qqbotpy.py similarity index 94% rename from pkg/platform/sources/qqbotpy.py rename to pkg/platform/sources/legacy/qqbotpy.py index d4a4d526..90e4c2d7 100644 --- a/pkg/platform/sources/qqbotpy.py +++ b/pkg/platform/sources/legacy/qqbotpy.py @@ -10,14 +10,14 @@ import botpy import botpy.message as botpy_message import botpy.types.message as botpy_message_type -from .. import adapter as adapter_model -from ...pipeline.longtext.strategies import forward -from ...core import app -from ...config import manager as cfg_mgr -from ...platform.types import entities as platform_entities -from ...platform.types import events as platform_events -from ...platform.types import message as platform_message -from ..logger import EventLogger +import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter +from ....pipeline.longtext.strategies import forward +from ....core import app +from ....config import manager as cfg_mgr +import langbot_plugin.api.entities.builtin.platform.entities as platform_entities +import langbot_plugin.api.entities.builtin.platform.events as platform_events +import langbot_plugin.api.entities.builtin.platform.message as platform_message +from ...logger import EventLogger class OfficialGroupMessage(platform_events.GroupMessage): @@ -133,7 +133,7 @@ class OpenIDMapping(typing.Generic[K, V]): return value -class OfficialMessageConverter(adapter_model.MessageConverter): +class OfficialMessageConverter(abstract_platform_adapter.AbstractMessageConverter): """QQ 官方消息转换器""" @staticmethod @@ -237,7 +237,7 @@ class OfficialMessageConverter(adapter_model.MessageConverter): return chain -class OfficialEventConverter(adapter_model.EventConverter): +class OfficialEventConverter(abstract_platform_adapter.AbstractEventConverter): """事件转换器""" def __init__(self): @@ -333,7 +333,7 @@ class OfficialEventConverter(adapter_model.EventConverter): ) -class OfficialAdapter(adapter_model.MessagePlatformAdapter): +class OfficialAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter): """QQ 官方消息适配器""" bot: botpy.Client = None @@ -484,7 +484,9 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter): def register_listener( self, event_type: typing.Type[platform_events.Event], - callback: typing.Callable[[platform_events.Event, adapter_model.MessagePlatformAdapter], None], + callback: typing.Callable[ + [platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None + ], ): try: @@ -507,7 +509,9 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter): def unregister_listener( self, event_type: typing.Type[platform_events.Event], - callback: typing.Callable[[platform_events.Event, adapter_model.MessagePlatformAdapter], None], + callback: typing.Callable[ + [platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None + ], ): delattr(self.bot, event_handler_mapping[event_type]) @@ -519,7 +523,7 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter): self.cfg['ret_coro'] = True - self.ap.logger.info('运行 QQ 官方适配器') + await self.logger.info('运行 QQ 官方适配器') await (await self.bot.start(**self.cfg)) async def kill(self) -> bool: diff --git a/pkg/platform/sources/qqbotpy.svg b/pkg/platform/sources/legacy/qqbotpy.svg similarity index 100% rename from pkg/platform/sources/qqbotpy.svg rename to pkg/platform/sources/legacy/qqbotpy.svg diff --git a/pkg/platform/sources/qqbotpy.yaml b/pkg/platform/sources/legacy/qqbotpy.yaml similarity index 100% rename from pkg/platform/sources/qqbotpy.yaml rename to pkg/platform/sources/legacy/qqbotpy.yaml diff --git a/pkg/platform/sources/officialaccount.py b/pkg/platform/sources/officialaccount.py index 3fc1e393..01a2c868 100644 --- a/pkg/platform/sources/officialaccount.py +++ b/pkg/platform/sources/officialaccount.py @@ -4,19 +4,18 @@ import asyncio import traceback import datetime -from pkg.platform.adapter import MessagePlatformAdapter -from pkg.platform.types import events as platform_events, message as platform_message +import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter from libs.official_account_api.oaevent import OAEvent from libs.official_account_api.api import OAClient from libs.official_account_api.api import OAClientForLongerResponse -from .. import adapter -from ...core import app -from ..types import entities as platform_entities -from ...command.errors import ParamNotEnoughError +import langbot_plugin.api.entities.builtin.platform.entities as platform_entities +import langbot_plugin.api.entities.builtin.platform.message as platform_message +import langbot_plugin.api.entities.builtin.platform.events as platform_events +from langbot_plugin.api.entities.builtin.command import errors as command_errors from ..logger import EventLogger -class OAMessageConverter(adapter.MessageConverter): +class OAMessageConverter(abstract_platform_adapter.AbstractMessageConverter): @staticmethod async def yiri2target(message_chain: platform_message.MessageChain): for msg in message_chain: @@ -34,7 +33,7 @@ class OAMessageConverter(adapter.MessageConverter): return chain -class OAEventConverter(adapter.EventConverter): +class OAEventConverter(abstract_platform_adapter.AbstractEventConverter): @staticmethod async def target2yiri(event: OAEvent): if event.type == 'text': @@ -56,17 +55,15 @@ class OAEventConverter(adapter.EventConverter): return None -class OfficialAccountAdapter(adapter.MessagePlatformAdapter): +class OfficialAccountAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter): bot: OAClient | OAClientForLongerResponse - ap: app.Application bot_account_id: str message_converter: OAMessageConverter = OAMessageConverter() event_converter: OAEventConverter = OAEventConverter() config: dict - def __init__(self, config: dict, ap: app.Application, logger: EventLogger): + def __init__(self, config: dict, logger: EventLogger): self.config = config - self.ap = ap self.logger = logger required_keys = [ @@ -78,7 +75,7 @@ class OfficialAccountAdapter(adapter.MessagePlatformAdapter): ] missing_keys = [key for key in required_keys if key not in config] if missing_keys: - raise ParamNotEnoughError('微信公众号缺少相关配置项,请查看文档或联系管理员') + raise command_errors.ParamNotEnoughError('微信公众号缺少相关配置项,请查看文档或联系管理员') if self.config['Mode'] == 'drop': self.bot = OAClient( @@ -119,7 +116,9 @@ class OfficialAccountAdapter(adapter.MessagePlatformAdapter): def register_listener( self, event_type: type, - callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None], + callback: typing.Callable[ + [platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None + ], ): async def on_message(event: OAEvent): self.bot_account_id = event.receiver_id @@ -150,6 +149,8 @@ class OfficialAccountAdapter(adapter.MessagePlatformAdapter): async def unregister_listener( self, event_type: type, - callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None], + callback: typing.Callable[ + [platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None + ], ): return super().unregister_listener(event_type, callback) diff --git a/pkg/platform/sources/qqofficial.py b/pkg/platform/sources/qqofficial.py index 63ab531f..28a09d8c 100644 --- a/pkg/platform/sources/qqofficial.py +++ b/pkg/platform/sources/qqofficial.py @@ -5,19 +5,18 @@ import traceback import datetime -from pkg.platform.adapter import MessagePlatformAdapter -from pkg.platform.types import events as platform_events, message as platform_message -from .. import adapter -from ...core import app -from ..types import entities as platform_entities -from ...command.errors import ParamNotEnoughError +import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter +import langbot_plugin.api.entities.builtin.platform.message as platform_message +import langbot_plugin.api.entities.builtin.platform.events as platform_events +import langbot_plugin.api.entities.builtin.platform.entities as platform_entities +from langbot_plugin.api.entities.builtin.command import errors as command_errors from libs.qq_official_api.api import QQOfficialClient from libs.qq_official_api.qqofficialevent import QQOfficialEvent from ...utils import image from ..logger import EventLogger -class QQOfficialMessageConverter(adapter.MessageConverter): +class QQOfficialMessageConverter(abstract_platform_adapter.AbstractMessageConverter): @staticmethod async def yiri2target(message_chain: platform_message.MessageChain): content_list = [] @@ -46,7 +45,7 @@ class QQOfficialMessageConverter(adapter.MessageConverter): return chain -class QQOfficialEventConverter(adapter.EventConverter): +class QQOfficialEventConverter(abstract_platform_adapter.AbstractEventConverter): @staticmethod async def yiri2target(event: platform_events.MessageEvent) -> QQOfficialEvent: return event.source_platform_object @@ -132,17 +131,15 @@ class QQOfficialEventConverter(adapter.EventConverter): ) -class QQOfficialAdapter(adapter.MessagePlatformAdapter): +class QQOfficialAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter): bot: QQOfficialClient - ap: app.Application config: dict bot_account_id: str message_converter: QQOfficialMessageConverter = QQOfficialMessageConverter() event_converter: QQOfficialEventConverter = QQOfficialEventConverter() - def __init__(self, config: dict, ap: app.Application, logger: EventLogger): + def __init__(self, config: dict, logger: EventLogger): self.config = config - self.ap = ap self.logger = logger required_keys = [ @@ -151,7 +148,7 @@ class QQOfficialAdapter(adapter.MessagePlatformAdapter): ] missing_keys = [key for key in required_keys if key not in config] if missing_keys: - raise ParamNotEnoughError('QQ官方机器人缺少相关配置项,请查看文档或联系管理员') + raise command_errors.ParamNotEnoughError('QQ官方机器人缺少相关配置项,请查看文档或联系管理员') self.bot = QQOfficialClient( app_id=config['appid'], secret=config['secret'], token=config['token'], logger=self.logger @@ -215,7 +212,9 @@ class QQOfficialAdapter(adapter.MessagePlatformAdapter): def register_listener( self, event_type: typing.Type[platform_events.Event], - callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], + callback: typing.Callable[ + [platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None + ], ): async def on_message(event: QQOfficialEvent): self.bot_account_id = 'justbot' @@ -248,6 +247,8 @@ class QQOfficialAdapter(adapter.MessagePlatformAdapter): def unregister_listener( self, event_type: type, - callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None], + callback: typing.Callable[ + [platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None + ], ): return super().unregister_listener(event_type, callback) diff --git a/pkg/platform/sources/slack.py b/pkg/platform/sources/slack.py index 1bd5aa2d..e08cc8c0 100644 --- a/pkg/platform/sources/slack.py +++ b/pkg/platform/sources/slack.py @@ -6,18 +6,17 @@ import traceback import datetime from libs.slack_api.api import SlackClient -from pkg.platform.adapter import MessagePlatformAdapter -from pkg.platform.types import events as platform_events, message as platform_message +import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter from libs.slack_api.slackevent import SlackEvent -from pkg.core import app -from .. import adapter -from ..types import entities as platform_entities -from ...command.errors import ParamNotEnoughError +import langbot_plugin.api.entities.builtin.platform.events as platform_events +import langbot_plugin.api.entities.builtin.platform.message as platform_message +import langbot_plugin.api.entities.builtin.platform.entities as platform_entities +from langbot_plugin.api.entities.builtin.command import errors as command_errors from ...utils import image from ..logger import EventLogger -class SlackMessageConverter(adapter.MessageConverter): +class SlackMessageConverter(abstract_platform_adapter.AbstractMessageConverter): @staticmethod async def yiri2target(message_chain: platform_message.MessageChain): content_list = [] @@ -44,7 +43,7 @@ class SlackMessageConverter(adapter.MessageConverter): return chain -class SlackEventConverter(adapter.EventConverter): +class SlackEventConverter(abstract_platform_adapter.AbstractEventConverter): @staticmethod async def yiri2target(event: platform_events.MessageEvent) -> SlackEvent: return event.source_platform_object @@ -84,17 +83,15 @@ class SlackEventConverter(adapter.EventConverter): ) -class SlackAdapter(adapter.MessagePlatformAdapter): +class SlackAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter): bot: SlackClient - ap: app.Application bot_account_id: str message_converter: SlackMessageConverter = SlackMessageConverter() event_converter: SlackEventConverter = SlackEventConverter() config: dict - def __init__(self, config: dict, ap: app.Application, logger: EventLogger): + def __init__(self, config: dict, logger: EventLogger): self.config = config - self.ap = ap self.logger = logger required_keys = [ 'bot_token', @@ -102,7 +99,7 @@ class SlackAdapter(adapter.MessagePlatformAdapter): ] missing_keys = [key for key in required_keys if key not in config] if missing_keys: - raise ParamNotEnoughError('Slack机器人缺少相关配置项,请查看文档或联系管理员') + raise command_errors.ParamNotEnoughError('Slack机器人缺少相关配置项,请查看文档或联系管理员') self.bot = SlackClient( bot_token=self.config['bot_token'], signing_secret=self.config['signing_secret'], logger=self.logger @@ -135,7 +132,9 @@ class SlackAdapter(adapter.MessagePlatformAdapter): def register_listener( self, event_type: typing.Type[platform_events.Event], - callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], + callback: typing.Callable[ + [platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None + ], ): async def on_message(event: SlackEvent): self.bot_account_id = 'SlackBot' @@ -166,6 +165,8 @@ class SlackAdapter(adapter.MessagePlatformAdapter): async def unregister_listener( self, event_type: type, - callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None], + callback: typing.Callable[ + [platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None + ], ): return super().unregister_listener(event_type, callback) diff --git a/pkg/platform/sources/telegram.py b/pkg/platform/sources/telegram.py index 8aee12d7..458d8094 100644 --- a/pkg/platform/sources/telegram.py +++ b/pkg/platform/sources/telegram.py @@ -10,18 +10,16 @@ import typing import traceback import base64 import aiohttp +import pydantic -from lark_oapi.api.im.v1 import * - -from .. import adapter -from ...core import app -from ..types import message as platform_message -from ..types import events as platform_events -from ..types import entities as platform_entities -from ..logger import EventLogger +import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter +import langbot_plugin.api.entities.builtin.platform.message as platform_message +import langbot_plugin.api.entities.builtin.platform.events as platform_events +import langbot_plugin.api.entities.builtin.platform.entities as platform_entities +import langbot_plugin.api.definition.abstract.platform.event_logger as abstract_platform_logger -class TelegramMessageConverter(adapter.MessageConverter): +class TelegramMessageConverter(abstract_platform_adapter.AbstractMessageConverter): @staticmethod async def yiri2target(message_chain: platform_message.MessageChain, bot: telegram.Bot) -> list[dict]: components = [] @@ -90,7 +88,7 @@ class TelegramMessageConverter(adapter.MessageConverter): return platform_message.MessageChain(message_components) -class TelegramEventConverter(adapter.EventConverter): +class TelegramEventConverter(abstract_platform_adapter.AbstractEventConverter): @staticmethod async def yiri2target(event: platform_events.MessageEvent, bot: telegram.Bot): return event.source_platform_object @@ -132,17 +130,14 @@ class TelegramEventConverter(adapter.EventConverter): ) -class TelegramAdapter(adapter.MessagePlatformAdapter): - bot: telegram.Bot - application: telegram.ext.Application - - bot_account_id: str +class TelegramAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter): + bot: telegram.Bot = pydantic.Field(exclude=True) + application: telegram.ext.Application = pydantic.Field(exclude=True) message_converter: TelegramMessageConverter = TelegramMessageConverter() event_converter: TelegramEventConverter = TelegramEventConverter() config: dict - ap: app.Application msg_stream_id: dict # 流式消息id字典,key为流式消息id,value为首次消息源id,用于在流式消息时判断编辑那条消息 @@ -150,16 +145,10 @@ class TelegramAdapter(adapter.MessagePlatformAdapter): listeners: typing.Dict[ typing.Type[platform_events.Event], - typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], + typing.Callable[[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None], ] = {} - def __init__(self, config: dict, ap: app.Application, logger: EventLogger): - self.config = config - self.ap = ap - self.logger = logger - self.msg_stream_id = {} - # self.seq = 1 - + def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger): async def telegram_callback(update: Update, context: ContextTypes.DEFAULT_TYPE): if update.message.from_user.is_bot: return @@ -171,10 +160,18 @@ class TelegramAdapter(adapter.MessagePlatformAdapter): except Exception: await self.logger.error(f'Error in telegram callback: {traceback.format_exc()}') - self.application = ApplicationBuilder().token(self.config['token']).build() - self.bot = self.application.bot - self.application.add_handler( - MessageHandler(filters.TEXT | (filters.COMMAND) | filters.PHOTO, telegram_callback) + application = ApplicationBuilder().token(config['token']).build() + bot = application.bot + application.add_handler(MessageHandler(filters.TEXT | (filters.COMMAND) | filters.PHOTO, telegram_callback)) + super().__init__( + config=config, + logger=logger, + msg_stream_id={}, + seq=1, + bot=bot, + application=application, + bot_account_id='', + listeners={}, ) async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain): @@ -278,14 +275,18 @@ class TelegramAdapter(adapter.MessagePlatformAdapter): def register_listener( self, event_type: typing.Type[platform_events.Event], - callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], + callback: typing.Callable[ + [platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None + ], ): self.listeners[event_type] = callback def unregister_listener( self, event_type: typing.Type[platform_events.Event], - callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], + callback: typing.Callable[ + [platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None + ], ): self.listeners.pop(event_type) diff --git a/pkg/platform/sources/webchat.py b/pkg/platform/sources/webchat.py index c43c4628..7fd54d1e 100644 --- a/pkg/platform/sources/webchat.py +++ b/pkg/platform/sources/webchat.py @@ -3,17 +3,19 @@ import logging import typing from datetime import datetime -from pydantic import BaseModel +import pydantic -from .. import adapter as msadapter -from ..types import events as platform_events, message as platform_message, entities as platform_entities +import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter +import langbot_plugin.api.entities.builtin.platform.message as platform_message +import langbot_plugin.api.entities.builtin.platform.events as platform_events +import langbot_plugin.api.entities.builtin.platform.entities as platform_entities +import langbot_plugin.api.definition.abstract.platform.event_logger as abstract_platform_logger from ...core import app -from ..logger import EventLogger logger = logging.getLogger(__name__) -class WebChatMessage(BaseModel): +class WebChatMessage(pydantic.BaseModel): id: int role: str content: str @@ -41,30 +43,35 @@ class WebChatSession: return self.message_lists[pipeline_uuid] -class WebChatAdapter(msadapter.MessagePlatformAdapter): +class WebChatAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter): """WebChat调试适配器,用于流水线调试""" - webchat_person_session: WebChatSession - webchat_group_session: WebChatSession + webchat_person_session: WebChatSession = pydantic.Field(exclude=True, default_factory=WebChatSession) + webchat_group_session: WebChatSession = pydantic.Field(exclude=True, default_factory=WebChatSession) - listeners: typing.Dict[ + listeners: dict[ typing.Type[platform_events.Event], - typing.Callable[[platform_events.Event, msadapter.MessagePlatformAdapter], None], - ] = {} + typing.Callable[[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None], + ] = pydantic.Field(default_factory=dict, exclude=True) - is_stream: bool + is_stream: bool = pydantic.Field(exclude=True) + debug_messages: dict[str, list[dict]] = pydantic.Field(default_factory=dict, exclude=True) - def __init__(self, config: dict, ap: app.Application, logger: EventLogger): - self.ap = ap - self.logger = logger - self.config = config + ap: app.Application = pydantic.Field(exclude=True) + + def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger, **kwargs): + super().__init__( + config=config, + logger=logger, + **kwargs, + ) self.webchat_person_session = WebChatSession(id='webchatperson') self.webchat_group_session = WebChatSession(id='webchatgroup') self.bot_account_id = 'webchatbot' - self.is_stream = False + self.debug_messages = {} async def send_message( self, @@ -159,7 +166,9 @@ class WebChatAdapter(msadapter.MessagePlatformAdapter): def register_listener( self, event_type: typing.Type[platform_events.Event], - func: typing.Callable[[platform_events.Event, msadapter.MessagePlatformAdapter], typing.Awaitable[None]], + func: typing.Callable[ + [platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], typing.Awaitable[None] + ], ): """注册事件监听器""" self.listeners[event_type] = func @@ -167,11 +176,16 @@ class WebChatAdapter(msadapter.MessagePlatformAdapter): def unregister_listener( self, event_type: typing.Type[platform_events.Event], - func: typing.Callable[[platform_events.Event, msadapter.MessagePlatformAdapter], typing.Awaitable[None]], + func: typing.Callable[ + [platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], typing.Awaitable[None] + ], ): """取消注册事件监听器""" del self.listeners[event_type] + async def is_muted(self, group_id: int) -> bool: + return False + async def run_async(self): """运行适配器""" await self.logger.info('WebChat调试适配器已启动') @@ -221,7 +235,7 @@ class WebChatAdapter(msadapter.MessagePlatformAdapter): message_chain.insert(0, platform_message.Source(id=message_id, time=datetime.now().timestamp())) if session_type == 'person': - sender = platform_entities.Friend(id='webchatperson', nickname='User') + sender = platform_entities.Friend(id='webchatperson', nickname='User', remark='User') event = platform_events.FriendMessage( sender=sender, message_chain=message_chain, time=datetime.now().timestamp() ) diff --git a/pkg/platform/sources/wechatpad.py b/pkg/platform/sources/wechatpad.py index 819ae400..e35bad63 100644 --- a/pkg/platform/sources/wechatpad.py +++ b/pkg/platform/sources/wechatpad.py @@ -16,24 +16,30 @@ import threading import quart -from .. import adapter -from ...core import app -from ..types import message as platform_message -from ..types import events as platform_events -from ..types import entities as platform_entities from ..logger import EventLogger import xml.etree.ElementTree as ET from typing import Optional, Tuple from functools import partial import logging +import langbot_plugin.api.entities.builtin.platform.message as platform_message +import langbot_plugin.api.entities.builtin.platform.events as platform_events +import langbot_plugin.api.entities.builtin.platform.entities as platform_entities +import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter +import langbot_plugin.api.definition.abstract.platform.event_logger as abstract_platform_logger -class WeChatPadMessageConverter(adapter.MessageConverter): - def __init__(self, config: dict, logger: logging.Logger): +class WeChatPadMessageConverter(abstract_platform_adapter.AbstractMessageConverter): + def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger): + self.bot = WeChatPadClient(config['wechatpad_url'], config['token']) self.config = config - self.bot = WeChatPadClient(self.config['wechatpad_url'], self.config['token']) self.logger = logger + # super().__init__( + # config = config, + # bot = bot, + # logger = logger, + # ) + @staticmethod async def yiri2target(message_chain: platform_message.MessageChain) -> list[dict]: content_list = [] @@ -447,11 +453,16 @@ class WeChatPadMessageConverter(adapter.MessageConverter): return from_user_name.endswith('@chatroom') -class WeChatPadEventConverter(adapter.EventConverter): +class WeChatPadEventConverter(abstract_platform_adapter.AbstractEventConverter): def __init__(self, config: dict, logger: logging.Logger): self.config = config - self.message_converter = WeChatPadMessageConverter(config, logger) self.logger = logger + self.message_converter = WeChatPadMessageConverter(self.config, self.logger) + # super().__init__( + # config=config, + # message_converter=message_converter, + # logger = logger, + # ) @staticmethod async def yiri2target(event: platform_events.MessageEvent) -> dict: @@ -511,7 +522,7 @@ class WeChatPadEventConverter(adapter.EventConverter): ) -class WeChatPadAdapter(adapter.MessagePlatformAdapter): +class WeChatPadAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter): name: str = 'WeChatPad' # 定义适配器名称 bot: WeChatPadClient @@ -521,29 +532,38 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter): config: dict - ap: app.Application + logger: EventLogger message_converter: WeChatPadMessageConverter event_converter: WeChatPadEventConverter listeners: typing.Dict[ typing.Type[platform_events.Event], - typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], + typing.Callable[[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None], ] = {} - def __init__(self, config: dict, ap: app.Application, logger: EventLogger): - self.config = config - self.ap = ap - self.logger = logger - self.quart_app = quart.Quart(__name__) + def __init__(self, config: dict, logger: EventLogger): - self.message_converter = WeChatPadMessageConverter(config, ap.logger) - self.event_converter = WeChatPadEventConverter(config, ap.logger) + quart_app = quart.Quart(__name__) + + message_converter = WeChatPadMessageConverter(config, logger) + event_converter = WeChatPadEventConverter(config, logger) + bot = WeChatPadClient(config['wechatpad_url'], config['token']) + super().__init__( + config=config, + logger = logger, + quart_app = quart_app, + message_converter =message_converter, + event_converter = event_converter, + listeners={}, + bot_account_id ='', + name="WeChatPad", + bot=bot, + + ) async def ws_message(self, data): """处理接收到的消息""" - # self.ap.logger.debug(f"Gewechat callback event: {data}") - # print(data) try: event = await self.event_converter.target2yiri(data.copy(), self.bot_account_id) @@ -609,9 +629,8 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter): if handler := handler_map.get(msg['type']): handler(msg) - # self.ap.logger.warning(f"未处理的消息类型: {ret}") else: - self.ap.logger.warning(f'未处理的消息类型: {msg["type"]}') + self.logger.warning(f'未处理的消息类型: {msg["type"]}') continue async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain): @@ -635,14 +654,18 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter): def register_listener( self, event_type: typing.Type[platform_events.Event], - callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], + callback: typing.Callable[ + [platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None + ], ): self.listeners[event_type] = callback def unregister_listener( self, event_type: typing.Type[platform_events.Event], - callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], + callback: typing.Callable[ + [platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None + ], ): pass @@ -653,7 +676,6 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter): if self.config['token']: self.bot = WeChatPadClient(self.config['wechatpad_url'], self.config['token']) data = self.bot.get_login_status() - self.ap.logger.info(data) if data['Code'] == 300 and data['Text'] == '你已退出微信': response = requests.post( f'{self.config["wechatpad_url"]}/admin/GenAuthKey1?key={self.config["admin_key"]}', @@ -673,7 +695,7 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter): self.config['token'] = response.json()['Data'][0] self.bot = WeChatPadClient(self.config['wechatpad_url'], self.config['token'], logger=self.logger) - self.ap.logger.info(self.config['token']) + await self.logger.info(self.config['token']) thread_1 = threading.Event() def wechat_login_process(): @@ -681,10 +703,9 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter): # login_data =self.bot.get_login_qr() # url = login_data['Data']["QrCodeUrl"] - # self.ap.logger.info(login_data) profile = self.bot.get_profile() - self.ap.logger.info(profile) + # self.logger.info(profile) self.bot_account_id = profile['Data']['userInfo']['nickName']['str'] self.config['wxid'] = profile['Data']['userInfo']['userName']['str'] @@ -696,27 +717,26 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter): def connect_websocket_sync() -> None: thread_1.wait() uri = f'{self.config["wechatpad_ws"]}/GetSyncMsg?key={self.config["token"]}' - self.ap.logger.info(f'Connecting to WebSocket: {uri}') + print(f'Connecting to WebSocket: {uri}') def on_message(ws, message): try: data = json.loads(message) - self.ap.logger.debug(f'Received message: {data}') # 这里需要确保ws_message是同步的,或者使用asyncio.run调用异步方法 asyncio.run(self.ws_message(data)) except json.JSONDecodeError: - self.ap.logger.error(f'Non-JSON message: {message[:100]}...') + self.logger.error(f'Non-JSON message: {message[:100]}...') def on_error(ws, error): - self.ap.logger.error(f'WebSocket error: {str(error)[:200]}') + self.logger.error(f'WebSocket error: {str(error)[:200]}') def on_close(ws, close_status_code, close_msg): - self.ap.logger.info('WebSocket closed, reconnecting...') + self.logger.info('WebSocket closed, reconnecting...') time.sleep(5) connect_websocket_sync() # 自动重连 def on_open(ws): - self.ap.logger.info('WebSocket connected successfully!') + self.logger.info('WebSocket connected successfully!') ws = websocket.WebSocketApp( uri, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open @@ -727,10 +747,9 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter): # connect_websocket_sync() # 这行代码会在WebSocket连接断开后才会执行 - # self.ap.logger.info("WebSocket client thread started") thread = threading.Thread(target=connect_websocket_sync, name='WebSocketClientThread', daemon=True) thread.start() - self.ap.logger.info('WebSocket client thread started') + self.logger.info('WebSocket client thread started') async def kill(self) -> bool: pass diff --git a/pkg/platform/sources/wecom.py b/pkg/platform/sources/wecom.py index 7be05a85..392db801 100644 --- a/pkg/platform/sources/wecom.py +++ b/pkg/platform/sources/wecom.py @@ -6,18 +6,17 @@ import traceback import datetime from libs.wecom_api.api import WecomClient -from pkg.platform.adapter import MessagePlatformAdapter -from pkg.platform.types import events as platform_events, message as platform_message +import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter from libs.wecom_api.wecomevent import WecomEvent -from .. import adapter -from ...core import app -from ..types import entities as platform_entities -from ...command.errors import ParamNotEnoughError +from langbot_plugin.api.entities.builtin.command import errors as command_errors from ...utils import image from ..logger import EventLogger +import langbot_plugin.api.entities.builtin.platform.message as platform_message +import langbot_plugin.api.entities.builtin.platform.events as platform_events +import langbot_plugin.api.entities.builtin.platform.entities as platform_entities -class WecomMessageConverter(adapter.MessageConverter): +class WecomMessageConverter(abstract_platform_adapter.AbstractMessageConverter): @staticmethod async def yiri2target(message_chain: platform_message.MessageChain, bot: WecomClient): content_list = [] @@ -71,7 +70,7 @@ class WecomMessageConverter(adapter.MessageConverter): return chain -class WecomEventConverter: +class WecomEventConverter(abstract_platform_adapter.AbstractEventConverter): @staticmethod async def yiri2target(event: platform_events.Event, bot_account_id: int, bot: WecomClient) -> WecomEvent: # only for extracting user information @@ -127,17 +126,15 @@ class WecomEventConverter: return platform_events.FriendMessage(sender=friend, message_chain=yiri_chain, time=event.timestamp) -class WecomAdapter(adapter.MessagePlatformAdapter): +class WecomAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter): bot: WecomClient - ap: app.Application bot_account_id: str message_converter: WecomMessageConverter = WecomMessageConverter() event_converter: WecomEventConverter = WecomEventConverter() config: dict - def __init__(self, config: dict, ap: app.Application, logger: EventLogger): + def __init__(self, config: dict, logger: EventLogger): self.config = config - self.ap = ap self.logger = logger required_keys = [ @@ -149,7 +146,7 @@ class WecomAdapter(adapter.MessagePlatformAdapter): ] missing_keys = [key for key in required_keys if key not in config] if missing_keys: - raise ParamNotEnoughError('企业微信缺少相关配置项,请查看文档或联系管理员') + raise command_errors.ParamNotEnoughError('企业微信缺少相关配置项,请查看文档或联系管理员') self.bot = WecomClient( corpid=config['corpid'], @@ -195,7 +192,9 @@ class WecomAdapter(adapter.MessagePlatformAdapter): def register_listener( self, event_type: typing.Type[platform_events.Event], - callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], + callback: typing.Callable[ + [platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None + ], ): async def on_message(event: WecomEvent): self.bot_account_id = event.receiver_id @@ -227,6 +226,8 @@ class WecomAdapter(adapter.MessagePlatformAdapter): async def unregister_listener( self, event_type: type, - callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None], + callback: typing.Callable[ + [platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None + ], ): return super().unregister_listener(event_type, callback) diff --git a/pkg/platform/sources/wecomcs.py b/pkg/platform/sources/wecomcs.py index da84ac6d..7ce3a064 100644 --- a/pkg/platform/sources/wecomcs.py +++ b/pkg/platform/sources/wecomcs.py @@ -4,19 +4,19 @@ import asyncio import traceback import datetime +import pydantic from libs.wecom_customer_service_api.api import WecomCSClient -from pkg.platform.adapter import MessagePlatformAdapter -from pkg.platform.types import events as platform_events, message as platform_message +import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter from libs.wecom_customer_service_api.wecomcsevent import WecomCSEvent -from pkg.core import app -from .. import adapter -from ..types import entities as platform_entities -from ...command.errors import ParamNotEnoughError -from ..logger import EventLogger +import langbot_plugin.api.entities.builtin.platform.entities as platform_entities +import langbot_plugin.api.entities.builtin.platform.message as platform_message +import langbot_plugin.api.entities.builtin.platform.events as platform_events +from langbot_plugin.api.entities.builtin.command import errors as command_errors +import langbot_plugin.api.definition.abstract.platform.event_logger as abstract_platform_logger -class WecomMessageConverter(adapter.MessageConverter): +class WecomMessageConverter(abstract_platform_adapter.AbstractMessageConverter): @staticmethod async def yiri2target(message_chain: platform_message.MessageChain, bot: WecomCSClient): content_list = [] @@ -69,7 +69,7 @@ class WecomMessageConverter(adapter.MessageConverter): return chain -class WecomEventConverter: +class WecomEventConverter(abstract_platform_adapter.AbstractEventConverter): @staticmethod async def yiri2target(event: platform_events.Event, bot_account_id: int, bot: WecomCSClient) -> WecomCSEvent: # only for extracting user information @@ -117,19 +117,12 @@ class WecomEventConverter: ) -class WecomCSAdapter(adapter.MessagePlatformAdapter): - bot: WecomCSClient - ap: app.Application - bot_account_id: str +class WecomCSAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter): + bot: WecomCSClient = pydantic.Field(exclude=True) message_converter: WecomMessageConverter = WecomMessageConverter() event_converter: WecomEventConverter = WecomEventConverter() - config: dict - - def __init__(self, config: dict, ap: app.Application, logger: EventLogger): - self.config = config - self.ap = ap - self.logger = logger + def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger): required_keys = [ 'corpid', 'secret', @@ -138,14 +131,22 @@ class WecomCSAdapter(adapter.MessagePlatformAdapter): ] missing_keys = [key for key in required_keys if key not in config] if missing_keys: - raise ParamNotEnoughError('企业微信客服缺少相关配置项,请查看文档或联系管理员') + raise command_errors.ParamNotEnoughError('企业微信客服缺少相关配置项,请查看文档或联系管理员') - self.bot = WecomCSClient( + bot = WecomCSClient( corpid=config['corpid'], secret=config['secret'], token=config['token'], EncodingAESKey=config['EncodingAESKey'], - logger=self.logger, + logger=logger, + ) + + super().__init__( + config=config, + logger=logger, + bot_account_id='', + listeners={}, + bot=bot, ) async def reply_message( @@ -172,7 +173,9 @@ class WecomCSAdapter(adapter.MessagePlatformAdapter): def register_listener( self, event_type: typing.Type[platform_events.Event], - callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], + callback: typing.Callable[ + [platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None + ], ): async def on_message(event: WecomCSEvent): self.bot_account_id = event.receiver_id @@ -201,9 +204,14 @@ class WecomCSAdapter(adapter.MessagePlatformAdapter): async def kill(self) -> bool: return False + async def is_muted(self, group_id: int) -> bool: + return False + async def unregister_listener( self, event_type: type, - callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None], + callback: typing.Callable[ + [platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None + ], ): return super().unregister_listener(event_type, callback) diff --git a/pkg/platform/types/__init__.py b/pkg/platform/types/__init__.py deleted file mode 100644 index 998b0fb8..00000000 --- a/pkg/platform/types/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .entities import * -from .events import * -from .message import * diff --git a/pkg/platform/types/base.py b/pkg/platform/types/base.py deleted file mode 100644 index da58d4ed..00000000 --- a/pkg/platform/types/base.py +++ /dev/null @@ -1,107 +0,0 @@ -from typing import Dict, List, Type - -import pydantic.v1.main as pdm -from pydantic.v1 import BaseModel - - -class PlatformMetaclass(pdm.ModelMetaclass): - """此类是平台中使用的 pydantic 模型的元类的基类。""" - - -def to_camel(name: str) -> str: - """将下划线命名风格转换为小驼峰命名。""" - if name[:2] == '__': # 不处理双下划线开头的特殊命名。 - return name - name_parts = name.split('_') - return ''.join(name_parts[:1] + [x.title() for x in name_parts[1:]]) - - -class PlatformBaseModel(BaseModel, metaclass=PlatformMetaclass): - """模型基类。 - - 启用了三项配置: - 1. 允许解析时传入额外的值,并将额外值保存在模型中。 - 2. 允许通过别名访问字段。 - 3. 自动生成小驼峰风格的别名。 - """ - - def __init__(self, *args, **kwargs): - """""" - super().__init__(*args, **kwargs) - - def __repr__(self) -> str: - return ( - self.__class__.__name__ + '(' + ', '.join((f'{k}={repr(v)}' for k, v in self.__dict__.items() if v)) + ')' - ) - - class Config: - extra = 'allow' - allow_population_by_field_name = True - alias_generator = to_camel - - -class PlatformIndexedMetaclass(PlatformMetaclass): - """可以通过子类名获取子类的类的元类。""" - - __indexedbases__: List[Type['PlatformIndexedModel']] = [] - __indexedmodel__ = None - - def __new__(cls, name, bases, attrs, **kwargs): - new_cls = super().__new__(cls, name, bases, attrs, **kwargs) - # 第一类:PlatformIndexedModel - if name == 'PlatformIndexedModel': - cls.__indexedmodel__ = new_cls - new_cls.__indexes__ = {} - return new_cls - # 第二类:PlatformIndexedModel 的直接子类,这些是可以通过子类名获取子类的类。 - if cls.__indexedmodel__ in bases: - cls.__indexedbases__.append(new_cls) - new_cls.__indexes__ = {} - return new_cls - # 第三类:PlatformIndexedModel 的直接子类的子类,这些添加到直接子类的索引中。 - for base in cls.__indexedbases__: - if issubclass(new_cls, base): - base.__indexes__[name] = new_cls - return new_cls - - def __getitem__(cls, name): - return cls.get_subtype(name) - - -class PlatformIndexedModel(PlatformBaseModel, metaclass=PlatformIndexedMetaclass): - """可以通过子类名获取子类的类。""" - - __indexes__: Dict[str, Type['PlatformIndexedModel']] - - @classmethod - def get_subtype(cls, name: str) -> Type['PlatformIndexedModel']: - """根据类名称,获取相应的子类类型。 - - Args: - name: 类名称。 - - Returns: - Type['PlatformIndexedModel']: 子类类型。 - """ - try: - type_ = cls.__indexes__.get(name) - if not (type_ and issubclass(type_, cls)): - raise ValueError(f'`{name}` 不是 `{cls.__name__}` 的子类!') - return type_ - except AttributeError: - raise ValueError(f'`{name}` 不是 `{cls.__name__}` 的子类!') from None - - @classmethod - def parse_subtype(cls, obj: dict) -> 'PlatformIndexedModel': - """通过字典,构造对应的模型对象。 - - Args: - obj: 一个字典,包含了模型对象的属性。 - - Returns: - PlatformIndexedModel: 构造的对象。 - """ - if cls in PlatformIndexedModel.__subclasses__(): - ModelType = cls.get_subtype(obj['type']) - return ModelType.parse_obj(obj) - return super().parse_obj(obj) diff --git a/pkg/platform/types/entities.py b/pkg/platform/types/entities.py deleted file mode 100644 index d989ffce..00000000 --- a/pkg/platform/types/entities.py +++ /dev/null @@ -1,88 +0,0 @@ -# -*- coding: utf-8 -*- -""" -此模块提供实体和配置项模型。 -""" - -import abc -from datetime import datetime -from enum import Enum -import typing - -import pydantic.v1 as pydantic - - -class Entity(pydantic.BaseModel): - """实体,表示一个用户或群。""" - - id: int - """ID。""" - - @abc.abstractmethod - def get_name(self) -> str: - """名称。""" - - -class Friend(Entity): - """私聊对象。""" - - id: typing.Union[int, str] - """ID。""" - nickname: typing.Optional[str] - """昵称。""" - remark: typing.Optional[str] - """备注。""" - - def get_name(self) -> str: - return self.nickname or self.remark or '' - - -class Permission(str, Enum): - """群成员身份权限。""" - - Member = 'MEMBER' - """成员。""" - Administrator = 'ADMINISTRATOR' - """管理员。""" - Owner = 'OWNER' - """群主。""" - - def __repr__(self) -> str: - return repr(self.value) - - -class Group(Entity): - """群。""" - - id: typing.Union[int, str] - """群号。""" - name: str - """群名称。""" - permission: Permission - """Bot 在群中的权限。""" - - def get_name(self) -> str: - return self.name - - -class GroupMember(Entity): - """群成员。""" - - id: typing.Union[int, str] - """群员 ID。""" - member_name: str - """群员名称。""" - permission: Permission - """在群中的权限。""" - group: Group - """群。""" - special_title: str = '' - """群头衔。""" - join_timestamp: datetime = datetime.utcfromtimestamp(0) - """加入群的时间。""" - last_speak_timestamp: datetime = datetime.utcfromtimestamp(0) - """最后一次发言的时间。""" - mute_time_remaining: int = 0 - """禁言剩余时间。""" - - def get_name(self) -> str: - return self.member_name diff --git a/pkg/platform/types/events.py b/pkg/platform/types/events.py deleted file mode 100644 index 5ffccb9b..00000000 --- a/pkg/platform/types/events.py +++ /dev/null @@ -1,106 +0,0 @@ -# -*- coding: utf-8 -*- -""" -此模块提供事件模型。 -""" - -import typing - -import pydantic.v1 as pydantic - -from . import entities as platform_entities -from . import message as platform_message - - -class Event(pydantic.BaseModel): - """事件基类。 - - Args: - type: 事件名。 - """ - - type: str - """事件名。""" - - def __repr__(self): - return ( - self.__class__.__name__ - + '(' - + ', '.join((f'{k}={repr(v)}' for k, v in self.__dict__.items() if k != 'type' and v)) - + ')' - ) - - @classmethod - def parse_subtype(cls, obj: dict) -> 'Event': - try: - return typing.cast(Event, super().parse_subtype(obj)) - except ValueError: - return Event(type=obj['type']) - - @classmethod - def get_subtype(cls, name: str) -> typing.Type['Event']: - try: - return typing.cast(typing.Type[Event], super().get_subtype(name)) - except ValueError: - return Event - - -############################### -# Message Event -class MessageEvent(Event): - """消息事件。 - - Args: - type: 事件名。 - message_chain: 消息内容。 - """ - - type: str - """事件名。""" - message_chain: platform_message.MessageChain - """消息内容。""" - - time: float | None = None - """消息发送时间戳。""" - - source_platform_object: typing.Optional[typing.Any] = None - """原消息平台对象。 - 供消息平台适配器开发者使用,如果回复用户时需要使用原消息事件对象的信息, - 那么可以将其存到这个字段以供之后取出使用。""" - - -class FriendMessage(MessageEvent): - """私聊消息。 - - Args: - type: 事件名。 - sender: 发送消息的好友。 - message_chain: 消息内容。 - """ - - type: str = 'FriendMessage' - """事件名。""" - sender: platform_entities.Friend - """发送消息的好友。""" - message_chain: platform_message.MessageChain - """消息内容。""" - - -class GroupMessage(MessageEvent): - """群消息。 - - Args: - type: 事件名。 - sender: 发送消息的群成员。 - message_chain: 消息内容。 - """ - - type: str = 'GroupMessage' - """事件名。""" - sender: platform_entities.GroupMember - """发送消息的群成员。""" - message_chain: platform_message.MessageChain - """消息内容。""" - - @property - def group(self) -> platform_entities.Group: - return self.sender.group diff --git a/pkg/platform/types/message.py b/pkg/platform/types/message.py deleted file mode 100644 index ecd7cc96..00000000 --- a/pkg/platform/types/message.py +++ /dev/null @@ -1,978 +0,0 @@ -import itertools -import logging -import typing -from datetime import datetime -from pathlib import Path -import base64 - -import aiofiles -import httpx -import pydantic.v1 as pydantic - -from . import entities as platform_entities -from .base import PlatformBaseModel, PlatformIndexedMetaclass, PlatformIndexedModel - -logger = logging.getLogger(__name__) - - -class MessageComponentMetaclass(PlatformIndexedMetaclass): - """消息组件元类。""" - - __message_component__ = None - - def __new__(cls, name, bases, attrs, **kwargs): - new_cls = super().__new__(cls, name, bases, attrs, **kwargs) - if name == 'MessageComponent': - cls.__message_component__ = new_cls - - if not cls.__message_component__: - return new_cls - - for base in bases: - if issubclass(base, cls.__message_component__): - # 获取字段名 - if hasattr(new_cls, '__fields__'): - # 忽略 type 字段 - new_cls.__parameter_names__ = list(new_cls.__fields__)[1:] - else: - new_cls.__parameter_names__ = [] - break - - return new_cls - - -class MessageComponent(PlatformIndexedModel, metaclass=MessageComponentMetaclass): - """消息组件。""" - - type: str - """消息组件类型。""" - - def __str__(self): - return '' - - def __repr__(self): - return ( - self.__class__.__name__ - + '(' - + ', '.join((f'{k}={repr(v)}' for k, v in self.__dict__.items() if k != 'type' and v)) - + ')' - ) - - def __init__(self, *args, **kwargs): - # 解析参数列表,将位置参数转化为具名参数 - parameter_names = self.__parameter_names__ - if len(args) > len(parameter_names): - raise TypeError(f'`{self.type}`需要{len(parameter_names)}个参数,但传入了{len(args)}个。') - for name, value in zip(parameter_names, args): - if name in kwargs: - raise TypeError(f'在 `{self.type}` 中,具名参数 `{name}` 与位置参数重复。') - kwargs[name] = value - - super().__init__(**kwargs) - - -TMessageComponent = typing.TypeVar('TMessageComponent', bound=MessageComponent) - - -class MessageChain(PlatformBaseModel): - """消息链。 - - 一个构造消息链的例子: - ```py - message_chain = MessageChain([ - AtAll(), - Plain("Hello World!"), - ]) - ``` - - `Plain` 可以省略。 - ```py - message_chain = MessageChain([ - AtAll(), - "Hello World!", - ]) - ``` - - 在调用 API 时,参数中需要 MessageChain 的,也可以使用 `List[MessageComponent]` 代替。 - 例如,以下两种写法是等价的: - ```py - await bot.send_friend_message(12345678, [ - Plain("Hello World!") - ]) - ``` - ```py - await bot.send_friend_message(12345678, MessageChain([ - Plain("Hello World!") - ])) - ``` - - 可以使用 `in` 运算检查消息链中: - 1. 是否有某个消息组件。 - 2. 是否有某个类型的消息组件。 - - ```py - if AtAll in message_chain: - print('AtAll') - - if At(bot.qq) in message_chain: - print('At Me') - ``` - - """ - - __root__: typing.List[MessageComponent] - - @staticmethod - def _parse_message_chain(msg_chain: typing.Iterable): - result = [] - for msg in msg_chain: - if isinstance(msg, dict): - result.append(MessageComponent.parse_subtype(msg)) - elif isinstance(msg, MessageComponent): - result.append(msg) - elif isinstance(msg, str): - result.append(Plain(msg)) - else: - raise TypeError(f'消息链中元素需为 dict 或 str 或 MessageComponent,当前类型:{type(msg)}') - return result - - @pydantic.validator('__root__', always=True, pre=True) - def _parse_component(cls, msg_chain): - if isinstance(msg_chain, (str, MessageComponent)): - msg_chain = [msg_chain] - if not msg_chain: - msg_chain = [] - return cls._parse_message_chain(msg_chain) - - @classmethod - def parse_obj(cls, msg_chain: typing.Iterable): - """通过列表形式的消息链,构造对应的 `MessageChain` 对象。 - - Args: - msg_chain: 列表形式的消息链。 - """ - result = cls._parse_message_chain(msg_chain) - return cls(__root__=result) - - def __init__(self, __root__: typing.Iterable[MessageComponent] = None): - super().__init__(__root__=__root__) - - def __str__(self): - return ''.join(str(component) for component in self.__root__) - - def __repr__(self): - return f'{self.__class__.__name__}({self.__root__!r})' - - def __iter__(self): - yield from self.__root__ - - def get_first(self, t: typing.Type[TMessageComponent]) -> typing.Optional[TMessageComponent]: - """获取消息链中第一个符合类型的消息组件。""" - for component in self: - if isinstance(component, t): - return component - return None - - @typing.overload - def __getitem__(self, index: int) -> MessageComponent: ... - - @typing.overload - def __getitem__(self, index: slice) -> typing.List[MessageComponent]: ... - - @typing.overload - def __getitem__(self, index: typing.Type[TMessageComponent]) -> typing.List[TMessageComponent]: ... - - @typing.overload - def __getitem__( - self, index: typing.Tuple[typing.Type[TMessageComponent], int] - ) -> typing.List[TMessageComponent]: ... - - def __getitem__( - self, - index: typing.Union[ - int, - slice, - typing.Type[TMessageComponent], - typing.Tuple[typing.Type[TMessageComponent], int], - ], - ) -> typing.Union[MessageComponent, typing.List[MessageComponent], typing.List[TMessageComponent]]: - return self.get(index) - - def __setitem__( - self, - key: typing.Union[int, slice], - value: typing.Union[MessageComponent, str, typing.Iterable[typing.Union[MessageComponent, str]]], - ): - if isinstance(value, str): - value = Plain(value) - if isinstance(value, typing.Iterable): - value = (Plain(c) if isinstance(c, str) else c for c in value) - self.__root__[key] = value # type: ignore - - def __delitem__(self, key: typing.Union[int, slice]): - del self.__root__[key] - - def __reversed__(self) -> typing.Iterable[MessageComponent]: - return reversed(self.__root__) - - def has( - self, - sub: typing.Union[MessageComponent, typing.Type[MessageComponent], 'MessageChain', str], - ) -> bool: - """判断消息链中: - 1. 是否有某个消息组件。 - 2. 是否有某个类型的消息组件。 - - Args: - sub (`Union[MessageComponent, Type[MessageComponent], 'MessageChain', str]`): - 若为 `MessageComponent`,则判断该组件是否在消息链中。 - 若为 `Type[MessageComponent]`,则判断该组件类型是否在消息链中。 - - Returns: - bool: 是否找到。 - """ - if isinstance(sub, type): # 检测消息链中是否有某种类型的对象 - for i in self: - if type(i) is sub: - return True - return False - if isinstance(sub, MessageComponent): # 检查消息链中是否有某个组件 - for i in self: - if i == sub: - return True - return False - raise TypeError(f'类型不匹配,当前类型:{type(sub)}') - - def __contains__(self, sub) -> bool: - return self.has(sub) - - def __ge__(self, other): - return other in self - - def __len__(self) -> int: - return len(self.__root__) - - def __add__(self, other: typing.Union['MessageChain', MessageComponent, str]) -> 'MessageChain': - if isinstance(other, MessageChain): - return self.__class__(self.__root__ + other.__root__) - if isinstance(other, str): - return self.__class__(self.__root__ + [Plain(other)]) - if isinstance(other, MessageComponent): - return self.__class__(self.__root__ + [other]) - return NotImplemented - - def __radd__(self, other: typing.Union[MessageComponent, str]) -> 'MessageChain': - if isinstance(other, MessageComponent): - return self.__class__([other] + self.__root__) - if isinstance(other, str): - return self.__class__([typing.cast(MessageComponent, Plain(other))] + self.__root__) - return NotImplemented - - def __mul__(self, other: int): - if isinstance(other, int): - return self.__class__(self.__root__ * other) - return NotImplemented - - def __rmul__(self, other: int): - return self.__mul__(other) - - def __iadd__(self, other: typing.Iterable[typing.Union[MessageComponent, str]]): - self.extend(other) - - def __imul__(self, other: int): - if isinstance(other, int): - self.__root__ *= other - return NotImplemented - - def index( - self, - x: typing.Union[MessageComponent, typing.Type[MessageComponent]], - i: int = 0, - j: int = -1, - ) -> int: - """返回 x 在消息链中首次出现项的索引号(索引号在 i 或其后且在 j 之前)。 - - Args: - x (`Union[MessageComponent, Type[MessageComponent]]`): - 要查找的消息元素或消息元素类型。 - i: 从哪个位置开始查找。 - j: 查找到哪个位置结束。 - - Returns: - int: 如果找到,则返回索引号。 - - Raises: - ValueError: 没有找到。 - TypeError: 类型不匹配。 - """ - if isinstance(x, type): - l = len(self) - if i < 0: - i += l - if i < 0: - i = 0 - if j < 0: - j += l - if j > l: - j = l - for index in range(i, j): - if type(self[index]) is x: - return index - raise ValueError('消息链中不存在该类型的组件。') - if isinstance(x, MessageComponent): - return self.__root__.index(x, i, j) - raise TypeError(f'类型不匹配,当前类型:{type(x)}') - - def count(self, x: typing.Union[MessageComponent, typing.Type[MessageComponent]]) -> int: - """返回消息链中 x 出现的次数。 - - Args: - x (`Union[MessageComponent, Type[MessageComponent]]`): - 要查找的消息元素或消息元素类型。 - - Returns: - int: 次数。 - """ - if isinstance(x, type): - return sum(1 for i in self if type(i) is x) - if isinstance(x, MessageComponent): - return self.__root__.count(x) - raise TypeError(f'类型不匹配,当前类型:{type(x)}') - - def extend(self, x: typing.Iterable[typing.Union[MessageComponent, str]]): - """将另一个消息链中的元素添加到消息链末尾。 - - Args: - x: 另一个消息链,也可为消息元素或字符串元素的序列。 - """ - self.__root__.extend(Plain(c) if isinstance(c, str) else c for c in x) - - def append(self, x: typing.Union[MessageComponent, str]): - """将一个消息元素或字符串元素添加到消息链末尾。 - - Args: - x: 消息元素或字符串元素。 - """ - self.__root__.append(Plain(x) if isinstance(x, str) else x) - - def insert(self, i: int, x: typing.Union[MessageComponent, str]): - """将一个消息元素或字符串添加到消息链中指定位置。 - - Args: - i: 插入位置。 - x: 消息元素或字符串元素。 - """ - self.__root__.insert(i, Plain(x) if isinstance(x, str) else x) - - def pop(self, i: int = -1) -> MessageComponent: - """从消息链中移除并返回指定位置的元素。 - - Args: - i: 移除位置。默认为末尾。 - - Returns: - MessageComponent: 移除的元素。 - """ - return self.__root__.pop(i) - - def remove(self, x: typing.Union[MessageComponent, typing.Type[MessageComponent]]): - """从消息链中移除指定元素或指定类型的一个元素。 - - Args: - x: 指定的元素或元素类型。 - """ - if isinstance(x, type): - self.pop(self.index(x)) - if isinstance(x, MessageComponent): - self.__root__.remove(x) - - def exclude( - self, - x: typing.Union[MessageComponent, typing.Type[MessageComponent]], - count: int = -1, - ) -> 'MessageChain': - """返回移除指定元素或指定类型的元素后剩余的消息链。 - - Args: - x: 指定的元素或元素类型。 - count: 至多移除的数量。默认为全部移除。 - - Returns: - MessageChain: 剩余的消息链。 - """ - - def _exclude(): - nonlocal count - x_is_type = isinstance(x, type) - for c in self: - if count > 0 and ((x_is_type and type(c) is x) or c == x): - count -= 1 - continue - yield c - - return self.__class__(_exclude()) - - def reverse(self): - """将消息链原地翻转。""" - self.__root__.reverse() - - @classmethod - def join(cls, *args: typing.Iterable[typing.Union[str, MessageComponent]]): - return cls(Plain(c) if isinstance(c, str) else c for c in itertools.chain(*args)) - - @property - def source(self) -> typing.Optional['Source']: - """获取消息链中的 `Source` 对象。""" - return self.get_first(Source) - - @property - def message_id(self) -> int: - """获取消息链的 message_id,若无法获取,返回 -1。""" - source = self.source - return source.id if source else -1 - - -TMessage = typing.Union[ - MessageChain, - typing.Iterable[typing.Union[MessageComponent, str]], - MessageComponent, - str, -] -"""可以转化为 MessageChain 的类型。""" - - -class Source(MessageComponent): - """源。包含消息的基本信息。""" - - type: str = 'Source' - """消息组件类型。""" - id: typing.Union[int, str] - """消息的识别号,用于引用回复(Source 类型永远为 MessageChain 的第一个元素)。""" - time: datetime - """消息时间。""" - - -class Plain(MessageComponent): - """纯文本。""" - - type: str = 'Plain' - """消息组件类型。""" - text: str - """文字消息。""" - - def __str__(self): - return self.text - - def __repr__(self): - return f'Plain({self.text!r})' - - -class Quote(MessageComponent): - """引用。""" - - type: str = 'Quote' - """消息组件类型。""" - id: typing.Optional[int] = None - """被引用回复的原消息的 message_id。""" - group_id: typing.Optional[typing.Union[int, str]] = None - """被引用回复的原消息所接收的群号,当为好友消息时为0。""" - sender_id: typing.Optional[typing.Union[int, str]] = None - """被引用回复的原消息的发送者的ID。""" - target_id: typing.Optional[typing.Union[int, str]] = None - """被引用回复的原消息的接收者者的ID或群ID。""" - origin: MessageChain - """被引用回复的原消息的消息链对象。""" - - @pydantic.validator('origin', always=True, pre=True) - def origin_formater(cls, v): - return MessageChain.parse_obj(v) - - -class At(MessageComponent): - """At某人。""" - - type: str = 'At' - """消息组件类型。""" - target: typing.Union[int, str] - """群员 ID。""" - display: typing.Optional[str] = None - """At时显示的文字,发送消息时无效,自动使用群名片。""" - - def __eq__(self, other): - return isinstance(other, At) and self.target == other.target - - def __str__(self): - return f'@{self.display or self.target}' - - -class AtAll(MessageComponent): - """At全体。""" - - type: str = 'AtAll' - """消息组件类型。""" - - def __str__(self): - return '@全体成员' - - -class Image(MessageComponent): - """图片。""" - - type: str = 'Image' - """消息组件类型。""" - image_id: typing.Optional[str] = None - """图片的 image_id,不为空时将忽略 url 属性。""" - url: typing.Optional[pydantic.HttpUrl] = None - """图片的 URL,发送时可作网络图片的链接;接收时为图片的链接,可用于图片下载。""" - path: typing.Union[str, Path, None] = None - """图片的路径,发送本地图片。""" - base64: typing.Optional[str] = None - """图片的 Base64 编码。""" - - def __eq__(self, other): - return isinstance(other, Image) and self.type == other.type and self.uuid == other.uuid - - def __str__(self): - return '[图片]' - - @pydantic.validator('path') - def validate_path(cls, path: typing.Union[str, Path, None]): - """修复 path 参数的行为,使之相对于 LangBot 的启动路径。""" - if path: - try: - return str(Path(path).resolve(strict=True)) - except FileNotFoundError: - raise ValueError(f'无效路径:{path}') - else: - return path - - @property - def uuid(self): - image_id = self.image_id - if image_id[0] == '{': # 群图片 - image_id = image_id[1:37] - elif image_id[0] == '/': # 好友图片 - image_id = image_id[1:] - return image_id - - async def get_bytes(self) -> typing.Tuple[bytes, str]: - """获取图片的 bytes 和 mime type""" - if self.url: - async with httpx.AsyncClient() as client: - response = await client.get(self.url) - response.raise_for_status() - return response.content, response.headers.get('Content-Type') - elif self.base64: - mime_type = 'image/jpeg' - - split_index = self.base64.find(';base64,') - if split_index == -1: - raise ValueError('Invalid base64 string') - - mime_type = self.base64[5:split_index] - base64_data = self.base64[split_index + 8 :] - - return base64.b64decode(base64_data), mime_type - elif self.path: - async with aiofiles.open(self.path, 'rb') as f: - return await f.read(), 'image/jpeg' - else: - raise ValueError('Can not get bytes from image') - - @classmethod - async def from_local( - cls, - filename: typing.Union[str, Path, None] = None, - content: typing.Optional[bytes] = None, - ) -> 'Image': - """从本地文件路径加载图片,以 base64 的形式传递。 - - Args: - filename: 从本地文件路径加载图片,与 `content` 二选一。 - content: 从本地文件内容加载图片,与 `filename` 二选一。 - - Returns: - Image: 图片对象。 - """ - if content: - pass - elif filename: - path = Path(filename) - import aiofiles - - async with aiofiles.open(path, 'rb') as f: - content = await f.read() - else: - raise ValueError('请指定图片路径或图片内容!') - import base64 - - img = cls(base64=base64.b64encode(content).decode()) - return img - - @classmethod - def from_unsafe_path(cls, path: typing.Union[str, Path]) -> 'Image': - """从不安全的路径加载图片。 - - Args: - path: 从不安全的路径加载图片。 - - Returns: - Image: 图片对象。 - """ - return cls.construct(path=str(path)) - - -class Unknown(MessageComponent): - """未知。""" - - type: str = 'Unknown' - """消息组件类型。""" - text: str - """文本。""" - - def __str__(self): - return f'Unknown Message: {self.text}' - - -class Voice(MessageComponent): - """语音。""" - - type: str = 'Voice' - """消息组件类型。""" - voice_id: typing.Optional[str] = None - """语音的 voice_id,不为空时将忽略 url 属性。""" - url: typing.Optional[str] = None - """语音的 URL,发送时可作网络语音的链接;接收时为语音文件的链接,可用于语音下载。""" - path: typing.Optional[str] = None - """语音的路径,发送本地语音。""" - base64: typing.Optional[str] = None - """语音的 Base64 编码。""" - length: typing.Optional[int] = None - """语音的长度,单位为秒。""" - - @pydantic.validator('path') - def validate_path(cls, path: typing.Optional[str]): - """修复 path 参数的行为,使之相对于 LangBot 的启动路径。""" - if path: - try: - return str(Path(path).resolve(strict=True)) - except FileNotFoundError: - raise ValueError(f'无效路径:{path}') - else: - return path - - def __str__(self): - return '[语音]' - - async def download( - self, - filename: typing.Union[str, Path, None] = None, - directory: typing.Union[str, Path, None] = None, - ): - """下载语音到本地。 - - Args: - filename: 下载到本地的文件路径。与 `directory` 二选一。 - directory: 下载到本地的文件夹路径。与 `filename` 二选一。 - """ - if not self.url: - logger.warning(f'语音 `{self.voice_id}` 无 url 参数,下载失败。') - return - - import httpx - - async with httpx.AsyncClient() as client: - response = await client.get(self.url) - response.raise_for_status() - content = response.content - - if filename: - path = Path(filename) - path.parent.mkdir(parents=True, exist_ok=True) - elif directory: - path = Path(directory) - path.mkdir(parents=True, exist_ok=True) - path = path / f'{self.voice_id}.silk' - else: - raise ValueError('请指定文件路径或文件夹路径!') - - import aiofiles - - async with aiofiles.open(path, 'wb') as f: - await f.write(content) - - @classmethod - async def from_local( - cls, - filename: typing.Union[str, Path, None] = None, - content: typing.Optional[bytes] = None, - ) -> 'Voice': - """从本地文件路径加载语音,以 base64 的形式传递。 - - Args: - filename: 从本地文件路径加载语音,与 `content` 二选一。 - content: 从本地文件内容加载语音,与 `filename` 二选一。 - """ - if content: - pass - if filename: - path = Path(filename) - import aiofiles - - async with aiofiles.open(path, 'rb') as f: - content = await f.read() - else: - raise ValueError('请指定语音路径或语音内容!') - import base64 - - img = cls(base64=base64.b64encode(content).decode()) - return img - - -class ForwardMessageNode(pydantic.BaseModel): - """合并转发中的一条消息。""" - - sender_id: typing.Optional[typing.Union[int, str]] = None - """发送人ID。""" - sender_name: typing.Optional[str] = None - """显示名称。""" - message_chain: typing.Optional[MessageChain] = None - """消息内容。""" - message_id: typing.Optional[int] = None - """消息的 message_id。""" - time: typing.Optional[datetime] = None - """发送时间。""" - - @pydantic.validator('message_chain', check_fields=False) - def _validate_message_chain(cls, value: typing.Union[MessageChain, list]): - if isinstance(value, list): - return MessageChain.parse_obj(value) - return value - - @classmethod - def create( - cls, - sender: typing.Union[platform_entities.Friend, platform_entities.GroupMember], - message: MessageChain, - ) -> 'ForwardMessageNode': - """从消息链生成转发消息。 - - Args: - sender: 发送人。 - message: 消息内容。 - - Returns: - ForwardMessageNode: 生成的一条消息。 - """ - return ForwardMessageNode(sender_id=sender.id, sender_name=sender.get_name(), message_chain=message) - - -class ForwardMessageDiaplay(pydantic.BaseModel): - title: str = '群聊的聊天记录' - brief: str = '[聊天记录]' - source: str = '聊天记录' - preview: typing.List[str] = [] - summary: str = '查看x条转发消息' - - -class Forward(MessageComponent): - """合并转发。""" - - type: str = 'Forward' - """消息组件类型。""" - display: ForwardMessageDiaplay - """显示信息""" - node_list: typing.List[ForwardMessageNode] - """转发消息节点列表。""" - - def __init__(self, *args, **kwargs): - if len(args) == 1: - self.node_list = args[0] - super().__init__(**kwargs) - super().__init__(*args, **kwargs) - - def __str__(self): - return '[聊天记录]' - - -class File(MessageComponent): - """文件。""" - - type: str = 'File' - """消息组件类型。""" - id: str = '' - """文件识别 ID。""" - name: str - """文件名称。""" - size: int = 0 - """文件大小。""" - url: str - """文件路径""" - - def __str__(self): - return f'[文件]{self.name}' - - -class Face(MessageComponent): - """系统表情 - 此处将超级表情骰子/划拳,一同归类于face - 当face_type为rps(划拳)时 face_id 对应的是手势 - 当face_type为dice(骰子)时 face_id 对应的是点数 - """ - - type: str = 'Face' - """表情类型""" - face_type: str = 'face' - """表情id""" - face_id: int = 0 - """表情名""" - face_name: str = '' - - def __str__(self): - if self.face_type == 'face': - return f'[表情]{self.face_name}' - elif self.face_type == 'dice': - return f'[表情]{self.face_id}点的{self.face_name}' - elif self.face_type == 'rps': - return f'[表情]{self.face_name}({self.rps_data(self.face_id)})' - - def rps_data(self, face_id): - rps_dict = { - 1: '布', - 2: '剪刀', - 3: '石头', - } - return rps_dict[face_id] - - -# ================ 个人微信专用组件 ================ - - -class WeChatMiniPrograms(MessageComponent): - """小程序。个人微信专用组件。""" - - type: str = 'WeChatMiniPrograms' - """小程序id""" - mini_app_id: str - """小程序归属用户id""" - user_name: str - """小程序名称""" - display_name: typing.Optional[str] = '' - """打开地址""" - page_path: typing.Optional[str] = '' - """小程序标题""" - title: typing.Optional[str] = '' - """首页图片""" - image_url: typing.Optional[str] = '' - - -class WeChatForwardMiniPrograms(MessageComponent): - """转发小程序。个人微信专用组件。""" - - type: str = 'WeChatForwardMiniPrograms' - """xml数据""" - xml_data: str - """首页图片""" - image_url: typing.Optional[str] = None - - def __str__(self): - return self.xml_data - - -class WeChatEmoji(MessageComponent): - """emoji表情。个人微信专用组件。""" - - type: str = 'WeChatEmoji' - """emojimd5""" - emoji_md5: str - """emoji大小""" - emoji_size: int - - -class WeChatLink(MessageComponent): - """发送链接。个人微信专用组件。""" - - type: str = 'WeChatLink' - """标题""" - link_title: str = '' - """链接描述""" - link_desc: str = '' - """链接地址""" - link_url: str = '' - """链接略缩图""" - link_thumb_url: str = '' - - -class WeChatForwardLink(MessageComponent): - """转发链接。个人微信专用组件。""" - - type: str = 'WeChatForwardLink' - """xml数据""" - xml_data: str - - def __str__(self): - return self.xml_data - - -class WeChatForwardImage(MessageComponent): - """转发图片。个人微信专用组件。""" - - type: str = 'WeChatForwardImage' - """xml数据""" - xml_data: str - - def __str__(self): - return self.xml_data - - -class WeChatForwardFile(MessageComponent): - """转发文件。个人微信专用组件。""" - - type: str = 'WeChatForwardFile' - """xml数据""" - xml_data: str - - def __str__(self): - return self.xml_data - - -class WeChatAppMsg(MessageComponent): - """通用appmsg发送。个人微信专用组件。""" - - type: str = 'WeChatAppMsg' - """xml数据""" - app_msg: str - - def __str__(self): - return self.app_msg - - -class WeChatForwardQuote(MessageComponent): - """转发引用消息。个人微信专用组件。""" - - type: str = 'WeChatForwardQuote' - """xml数据""" - app_msg: str - - def __str__(self): - return self.app_msg - - -class WeChatFile(MessageComponent): - """文件。""" - - type: str = 'File' - """消息组件类型。""" - file_id: str = '' - """文件识别 ID。""" - file_name: str = '' - """文件名称。""" - file_size: int = 0 - """文件大小。""" - file_path: str = '' - """文件地址""" - file_base64: str = '' - """base64""" - - def __str__(self): - return f'[文件]{self.file_name}' diff --git a/pkg/plugin/connector.py b/pkg/plugin/connector.py new file mode 100644 index 00000000..da7de024 --- /dev/null +++ b/pkg/plugin/connector.py @@ -0,0 +1,200 @@ +# For connect to plugin runtime. +from __future__ import annotations + +import asyncio +from typing import Any +import typing +import os +import sys + +from async_lru import alru_cache + +from ..core import app +from . import handler +from ..utils import platform +from langbot_plugin.runtime.io.controllers.stdio import client as stdio_client_controller +from langbot_plugin.runtime.io.controllers.ws import client as ws_client_controller +from langbot_plugin.api.entities import events +from langbot_plugin.api.entities import context +import langbot_plugin.runtime.io.connection as base_connection +from langbot_plugin.api.definition.components.manifest import ComponentManifest +from langbot_plugin.api.entities.builtin.command import context as command_context +from langbot_plugin.runtime.plugin.mgr import PluginInstallSource +from ..core import taskmgr + + +class PluginRuntimeConnector: + """Plugin runtime connector""" + + ap: app.Application + + handler: handler.RuntimeConnectionHandler + + handler_task: asyncio.Task + + stdio_client_controller: stdio_client_controller.StdioClientController + + ctrl: stdio_client_controller.StdioClientController | ws_client_controller.WebSocketClientController + + runtime_disconnect_callback: typing.Callable[ + [PluginRuntimeConnector], typing.Coroutine[typing.Any, typing.Any, None] + ] + + def __init__( + self, + ap: app.Application, + runtime_disconnect_callback: typing.Callable[ + [PluginRuntimeConnector], typing.Coroutine[typing.Any, typing.Any, None] + ], + ): + self.ap = ap + self.runtime_disconnect_callback = runtime_disconnect_callback + + async def initialize(self): + async def new_connection_callback(connection: base_connection.Connection): + async def disconnect_callback(rchandler: handler.RuntimeConnectionHandler) -> bool: + if platform.get_platform() == 'docker' or platform.use_websocket_to_connect_plugin_runtime(): + self.ap.logger.error('Disconnected from plugin runtime, trying to reconnect...') + await self.runtime_disconnect_callback(self) + return False + else: + self.ap.logger.error( + 'Disconnected from plugin runtime, cannot automatically reconnect while LangBot connects to plugin runtime via stdio, please restart LangBot.' + ) + return False + + self.handler = handler.RuntimeConnectionHandler(connection, disconnect_callback, self.ap) + self.handler_task = asyncio.create_task(self.handler.run()) + _ = await self.handler.ping() + self.ap.logger.info('Connected to plugin runtime.') + await self.handler_task + + task: asyncio.Task | None = None + + if platform.get_platform() == 'docker' or platform.use_websocket_to_connect_plugin_runtime(): # use websocket + self.ap.logger.info('use websocket to connect to plugin runtime') + ws_url = self.ap.instance_config.data['plugin']['runtime_ws_url'] + + async def make_connection_failed_callback(ctrl: ws_client_controller.WebSocketClientController) -> None: + self.ap.logger.error('Failed to connect to plugin runtime, trying to reconnect...') + await self.runtime_disconnect_callback(self) + + self.ctrl = ws_client_controller.WebSocketClientController( + ws_url=ws_url, + make_connection_failed_callback=make_connection_failed_callback, + ) + task = self.ctrl.run(new_connection_callback) + else: # stdio + self.ap.logger.info('use stdio to connect to plugin runtime') + # cmd: lbp rt -s + python_path = sys.executable + env = os.environ.copy() + self.ctrl = stdio_client_controller.StdioClientController( + command=python_path, + args=['-m', 'langbot_plugin.cli.__init__', 'rt', '-s'], + env=env, + ) + task = self.ctrl.run(new_connection_callback) + + asyncio.create_task(task) + + async def initialize_plugins(self): + pass + + async def install_plugin( + self, + install_source: PluginInstallSource, + install_info: dict[str, Any], + task_context: taskmgr.TaskContext | None = None, + ): + async for ret in self.handler.install_plugin(install_source.value, install_info): + current_action = ret.get('current_action', None) + if current_action is not None: + if task_context is not None: + task_context.set_current_action(current_action) + + trace = ret.get('trace', None) + if trace is not None: + if task_context is not None: + task_context.trace(trace) + + async def upgrade_plugin( + self, plugin_author: str, plugin_name: str, task_context: taskmgr.TaskContext | None = None + ) -> dict[str, Any]: + async for ret in self.handler.upgrade_plugin(plugin_author, plugin_name): + current_action = ret.get('current_action', None) + if current_action is not None: + if task_context is not None: + task_context.set_current_action(current_action) + + trace = ret.get('trace', None) + if trace is not None: + if task_context is not None: + task_context.trace(trace) + + async def delete_plugin( + self, plugin_author: str, plugin_name: str, task_context: taskmgr.TaskContext | None = None + ) -> dict[str, Any]: + async for ret in self.handler.delete_plugin(plugin_author, plugin_name): + current_action = ret.get('current_action', None) + if current_action is not None: + if task_context is not None: + task_context.set_current_action(current_action) + + trace = ret.get('trace', None) + if trace is not None: + if task_context is not None: + task_context.trace(trace) + + async def list_plugins(self) -> list[dict[str, Any]]: + return await self.handler.list_plugins() + + async def get_plugin_info(self, author: str, plugin_name: str) -> dict[str, Any]: + return await self.handler.get_plugin_info(author, plugin_name) + + async def set_plugin_config(self, plugin_author: str, plugin_name: str, config: dict[str, Any]) -> dict[str, Any]: + return await self.handler.set_plugin_config(plugin_author, plugin_name, config) + + @alru_cache(ttl=5 * 60) # 5 minutes + async def get_plugin_icon(self, plugin_author: str, plugin_name: str) -> dict[str, Any]: + return await self.handler.get_plugin_icon(plugin_author, plugin_name) + + async def emit_event( + self, + event: events.BaseEventModel, + ) -> context.EventContext: + event_ctx = context.EventContext.from_event(event) + + event_ctx_result = await self.handler.emit_event(event_ctx.model_dump(serialize_as_any=True)) + + event_ctx = context.EventContext.model_validate(event_ctx_result['event_context']) + + return event_ctx + + async def list_tools(self) -> list[ComponentManifest]: + list_tools_data = await self.handler.list_tools() + + return [ComponentManifest.model_validate(tool) for tool in list_tools_data] + + async def call_tool(self, tool_name: str, parameters: dict[str, Any]) -> dict[str, Any]: + return await self.handler.call_tool(tool_name, parameters) + + async def list_commands(self) -> list[ComponentManifest]: + list_commands_data = await self.handler.list_commands() + + return [ComponentManifest.model_validate(command) for command in list_commands_data] + + async def execute_command( + self, command_ctx: command_context.ExecuteContext + ) -> typing.AsyncGenerator[command_context.CommandReturn, None]: + gen = self.handler.execute_command(command_ctx.model_dump(serialize_as_any=True)) + + async for ret in gen: + cmd_ret = command_context.CommandReturn.model_validate(ret) + + yield cmd_ret + + def dispose(self): + if isinstance(self.ctrl, stdio_client_controller.StdioClientController): + self.ap.logger.info('Terminating plugin runtime process...') + self.ctrl.process.terminate() diff --git a/pkg/plugin/context.py b/pkg/plugin/context.py deleted file mode 100644 index dfd691f3..00000000 --- a/pkg/plugin/context.py +++ /dev/null @@ -1,388 +0,0 @@ -from __future__ import annotations - -import typing -import abc -import pydantic.v1 as pydantic -import enum - -from . import events -from ..provider.tools import entities as tools_entities -from ..core import app -from ..discover import engine as discover_engine -from ..platform.types import message as platform_message -from ..platform import adapter as platform_adapter - - -def register( - name: str, description: str, version: str, author: str -) -> typing.Callable[[typing.Type[BasePlugin]], typing.Type[BasePlugin]]: - """注册插件类 - - 使用示例: - - @register( - name="插件名称", - description="插件描述", - version="插件版本", - author="插件作者" - ) - class MyPlugin(BasePlugin): - pass - """ - pass - - -def handler( - event: typing.Type[events.BaseEventModel], -) -> typing.Callable[[typing.Callable], typing.Callable]: - """注册事件监听器 - - 使用示例: - - class MyPlugin(BasePlugin): - - @handler(NormalMessageResponded) - async def on_normal_message_responded(self, ctx: EventContext): - pass - """ - pass - - -def llm_func( - name: str = None, -) -> typing.Callable: - """注册内容函数 - - 使用示例: - - class MyPlugin(BasePlugin): - - @llm_func("access_the_web_page") - async def _(self, query, url: str, brief_len: int): - \"""Call this function to search about the question before you answer any questions. - - Do not search through google.com at any time. - - If you need to search somthing, visit https://www.sogou.com/web?query=. - - If user ask you to open a url (start with http:// or https://), visit it directly. - - Summary the plain content result by yourself, DO NOT directly output anything in the result you got. - - Args: - url(str): url to visit - brief_len(int): max length of the plain text content, recommend 1024-4096, prefer 4096 - - Returns: - str: plain text content of the web page or error message(starts with 'error:') - \""" - """ - pass - - -class BasePlugin(metaclass=abc.ABCMeta): - """插件基类""" - - host: APIHost - """API宿主""" - - ap: app.Application - """应用程序对象""" - - config: dict - """插件配置""" - - def __init__(self, host: APIHost): - """初始化阶段被调用""" - self.host = host - self.config = {} - - async def initialize(self): - """初始化阶段被调用""" - pass - - async def destroy(self): - """释放/禁用插件时被调用""" - pass - - def __del__(self): - """释放/禁用插件时被调用""" - pass - - -class APIHost: - """LangBot API 宿主""" - - ap: app.Application - - def __init__(self, ap: app.Application): - self.ap = ap - - async def initialize(self): - pass - - # ========== 插件可调用的 API(主程序API) ========== - - def get_platform_adapters(self) -> list[platform_adapter.MessagePlatformAdapter]: - """获取已启用的消息平台适配器列表 - - Returns: - list[platform.adapter.MessageSourceAdapter]: 已启用的消息平台适配器列表 - """ - return self.ap.platform_mgr.get_running_adapters() - - async def send_active_message( - self, - adapter: platform_adapter.MessagePlatformAdapter, - target_type: str, - target_id: str, - message: platform_message.MessageChain, - ): - """发送主动消息 - - Args: - adapter (platform.adapter.MessageSourceAdapter): 消息平台适配器对象,调用 host.get_platform_adapters() 获取并取用其中某个 - target_type (str): 目标类型,`person`或`group` - target_id (str): 目标ID - message (platform.types.MessageChain): 消息链 - """ - await adapter.send_message( - target_type=target_type, - target_id=target_id, - message=message, - ) - - def require_ver( - self, - ge: str, - le: str = 'v999.999.999', - ) -> bool: - """插件版本要求装饰器 - - Args: - ge (str): 最低版本要求 - le (str, optional): 最高版本要求 - - Returns: - bool: 是否满足要求, False时为无法获取版本号,True时为满足要求,报错为不满足要求 - """ - langbot_version = '' - - try: - langbot_version = self.ap.ver_mgr.get_current_version() # 从updater模块获取版本号 - except Exception: - return False - - if self.ap.ver_mgr.compare_version_str(langbot_version, ge) < 0 or ( - self.ap.ver_mgr.compare_version_str(langbot_version, le) > 0 - ): - raise Exception( - 'LangBot 版本不满足要求,某些功能(可能是由插件提供的)无法正常使用。(要求版本:{}-{},但当前版本:{})'.format( - ge, le, langbot_version - ) - ) - - return True - - -class EventContext: - """事件上下文, 保存此次事件运行的信息""" - - eid = 0 - """事件编号""" - - host: APIHost = None - """API宿主""" - - event: events.BaseEventModel = None - """此次事件的对象,具体类型为handler注册时指定监听的类型,可查看events.py中的定义""" - - __prevent_default__ = False - """是否阻止默认行为""" - - __prevent_postorder__ = False - """是否阻止后续插件的执行""" - - __return_value__ = {} - """ 返回值 - 示例: - { - "example": [ - 'value1', - 'value2', - 3, - 4, - { - 'key1': 'value1', - }, - ['value1', 'value2'] - ] - } - """ - - # ========== 插件可调用的 API ========== - - def add_return(self, key: str, ret): - """添加返回值""" - if key not in self.__return_value__: - self.__return_value__[key] = [] - self.__return_value__[key].append(ret) - - async def reply(self, message_chain: platform_message.MessageChain): - """回复此次消息请求 - - Args: - message_chain (platform.types.MessageChain): 源平台的消息链,若用户使用的不是源平台适配器,程序也能自动转换为目标平台消息链 - """ - # TODO 添加 at_sender 和 quote_origin 参数 - await self.event.query.adapter.reply_message( - message_source=self.event.query.message_event, message=message_chain - ) - - async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain): - """主动发送消息 - - Args: - target_type (str): 目标类型,`person`或`group` - target_id (str): 目标ID - message (platform.types.MessageChain): 源平台的消息链,若用户使用的不是源平台适配器,程序也能自动转换为目标平台消息链 - """ - await self.event.query.adapter.send_message(target_type=target_type, target_id=target_id, message=message) - - def prevent_postorder(self): - """阻止后续插件执行""" - self.__prevent_postorder__ = True - - def prevent_default(self): - """阻止默认行为""" - self.__prevent_default__ = True - - # ========== 以下是内部保留方法,插件不应调用 ========== - - def get_return(self, key: str) -> list: - """获取key的所有返回值""" - if key in self.__return_value__: - return self.__return_value__[key] - return None - - def get_return_value(self, key: str): - """获取key的首个返回值""" - if key in self.__return_value__: - return self.__return_value__[key][0] - return None - - def is_prevented_default(self): - """是否阻止默认行为""" - return self.__prevent_default__ - - def is_prevented_postorder(self): - """是否阻止后序插件执行""" - return self.__prevent_postorder__ - - def __init__(self, host: APIHost, event: events.BaseEventModel): - self.eid = EventContext.eid - self.host = host - self.event = event - self.__prevent_default__ = False - self.__prevent_postorder__ = False - self.__return_value__ = {} - EventContext.eid += 1 - - -class RuntimeContainerStatus(enum.Enum): - """插件容器状态""" - - MOUNTED = 'mounted' - """已加载进内存,所有位于运行时记录中的 RuntimeContainer 至少是这个状态""" - - INITIALIZED = 'initialized' - """已初始化""" - - -class RuntimeContainer(pydantic.BaseModel): - """运行时的插件容器 - - 运行期间存储单个插件的信息 - """ - - plugin_name: str - """插件名称""" - - plugin_label: discover_engine.I18nString - """插件标签""" - - plugin_description: discover_engine.I18nString - """插件描述""" - - plugin_version: str - """插件版本""" - - plugin_author: str - """插件作者""" - - plugin_repository: str - """插件源码地址""" - - main_file: str - """插件主文件路径""" - - pkg_path: str - """插件包路径""" - - plugin_class: typing.Type[BasePlugin] = None - """插件类""" - - enabled: typing.Optional[bool] = True - """是否启用""" - - priority: typing.Optional[int] = 0 - """优先级""" - - config_schema: typing.Optional[list[dict]] = [] - """插件配置模板""" - - plugin_config: typing.Optional[dict] = {} - """插件配置""" - - plugin_inst: typing.Optional[BasePlugin] = None - """插件实例""" - - event_handlers: dict[ - typing.Type[events.BaseEventModel], - typing.Callable[[BasePlugin, EventContext], typing.Awaitable[None]], - ] = {} - """事件处理器""" - - tools: list[tools_entities.LLMFunction] = [] - """内容函数""" - - status: RuntimeContainerStatus = RuntimeContainerStatus.MOUNTED - """插件状态""" - - class Config: - arbitrary_types_allowed = True - - def model_dump(self, *args, **kwargs): - return { - 'name': self.plugin_name, - 'label': self.plugin_label.to_dict(), - 'description': self.plugin_description.to_dict(), - 'version': self.plugin_version, - 'author': self.plugin_author, - 'repository': self.plugin_repository, - 'main_file': self.main_file, - 'pkg_path': self.pkg_path, - 'enabled': self.enabled, - 'priority': self.priority, - 'config_schema': self.config_schema, - 'event_handlers': { - event_name.__name__: handler.__name__ for event_name, handler in self.event_handlers.items() - }, - 'tools': [ - { - 'name': function.name, - 'human_desc': function.human_desc, - 'description': function.description, - 'parameters': function.parameters, - 'func': function.func.__name__, - } - for function in self.tools - ], - 'status': self.status.value, - } diff --git a/pkg/plugin/errors.py b/pkg/plugin/errors.py deleted file mode 100644 index 8da223db..00000000 --- a/pkg/plugin/errors.py +++ /dev/null @@ -1,21 +0,0 @@ -from __future__ import annotations - - -class PluginSystemError(Exception): - message: str - - def __init__(self, message: str): - self.message = message - - def __str__(self): - return self.message - - -class PluginNotFoundError(PluginSystemError): - def __init__(self, message: str): - super().__init__(f'未找到插件: {message}') - - -class PluginInstallerError(PluginSystemError): - def __init__(self, message: str): - super().__init__(f'安装器操作错误: {message}') diff --git a/pkg/plugin/events.py b/pkg/plugin/events.py deleted file mode 100644 index 61e84714..00000000 --- a/pkg/plugin/events.py +++ /dev/null @@ -1,170 +0,0 @@ -from __future__ import annotations - -import typing - -import pydantic.v1 as pydantic - -from ..core import entities as core_entities -from ..provider import entities as llm_entities -from ..platform.types import message as platform_message - - -class BaseEventModel(pydantic.BaseModel): - """事件模型基类""" - - query: typing.Union[core_entities.Query, None] - """此次请求的query对象,非请求过程的事件时为None""" - - class Config: - arbitrary_types_allowed = True - - -class PersonMessageReceived(BaseEventModel): - """收到任何私聊消息时""" - - launcher_type: str - """发起对象类型(group/person)""" - - launcher_id: typing.Union[int, str] - """发起对象ID(群号/QQ号)""" - - sender_id: typing.Union[int, str] - """发送者ID(QQ号)""" - - message_chain: platform_message.MessageChain - - -class GroupMessageReceived(BaseEventModel): - """收到任何群聊消息时""" - - launcher_type: str - - launcher_id: typing.Union[int, str] - - sender_id: typing.Union[int, str] - - message_chain: platform_message.MessageChain - - -class PersonNormalMessageReceived(BaseEventModel): - """判断为应该处理的私聊普通消息时触发""" - - launcher_type: str - - launcher_id: typing.Union[int, str] - - sender_id: typing.Union[int, str] - - text_message: str - - alter: typing.Optional[str] = None - """修改后的消息文本""" - - reply: typing.Optional[list] = None - """回复消息组件列表""" - - -class PersonCommandSent(BaseEventModel): - """判断为应该处理的私聊命令时触发""" - - launcher_type: str - - launcher_id: typing.Union[int, str] - - sender_id: typing.Union[int, str] - - command: str - - params: list[str] - - text_message: str - - is_admin: bool - - alter: typing.Optional[str] = None - """修改后的完整命令文本""" - - reply: typing.Optional[list] = None - """回复消息组件列表""" - - -class GroupNormalMessageReceived(BaseEventModel): - """判断为应该处理的群聊普通消息时触发""" - - launcher_type: str - - launcher_id: typing.Union[int, str] - - sender_id: typing.Union[int, str] - - text_message: str - - alter: typing.Optional[str] = None - """修改后的消息文本""" - - reply: typing.Optional[list] = None - """回复消息组件列表""" - - -class GroupCommandSent(BaseEventModel): - """判断为应该处理的群聊命令时触发""" - - launcher_type: str - - launcher_id: typing.Union[int, str] - - sender_id: typing.Union[int, str] - - command: str - - params: list[str] - - text_message: str - - is_admin: bool - - alter: typing.Optional[str] = None - """修改后的完整命令文本""" - - reply: typing.Optional[list] = None - """回复消息组件列表""" - - -class NormalMessageResponded(BaseEventModel): - """回复普通消息时触发""" - - launcher_type: str - - launcher_id: typing.Union[int, str] - - sender_id: typing.Union[int, str] - - session: core_entities.Session - """会话对象""" - - prefix: str - """回复消息的前缀""" - - response_text: str - """回复消息的文本""" - - finish_reason: str - """响应结束原因""" - - funcs_called: list[str] - """调用的函数列表""" - - reply: typing.Optional[list] = None - """回复消息组件列表""" - - -class PromptPreProcessing(BaseEventModel): - """会话中的Prompt预处理时触发""" - - session_name: str - - default_prompt: list[llm_entities.Message] - """此对话的情景预设,可修改""" - - prompt: list[llm_entities.Message] - """此对话现有消息记录,可修改""" diff --git a/pkg/plugin/handler.py b/pkg/plugin/handler.py new file mode 100644 index 00000000..36d11d09 --- /dev/null +++ b/pkg/plugin/handler.py @@ -0,0 +1,581 @@ +from __future__ import annotations + +import typing +from typing import Any +import base64 +import traceback + +import sqlalchemy + +from langbot_plugin.runtime.io import handler +from langbot_plugin.runtime.io.connection import Connection +from langbot_plugin.entities.io.actions.enums import ( + CommonAction, + RuntimeToLangBotAction, + LangBotToRuntimeAction, + PluginToRuntimeAction, +) +import langbot_plugin.api.entities.builtin.platform.message as platform_message +import langbot_plugin.api.entities.builtin.provider.message as provider_message +import langbot_plugin.api.entities.builtin.resource.tool as resource_tool + +from ..entity.persistence import plugin as persistence_plugin +from ..entity.persistence import bstorage as persistence_bstorage + +from ..core import app +from ..utils import constants + + +class RuntimeConnectionHandler(handler.Handler): + """Runtime connection handler""" + + ap: app.Application + + def __init__( + self, + connection: Connection, + disconnect_callback: typing.Callable[[], typing.Coroutine[typing.Any, typing.Any, bool]], + ap: app.Application, + ): + super().__init__(connection, disconnect_callback) + self.ap = ap + + @self.action(RuntimeToLangBotAction.INITIALIZE_PLUGIN_SETTINGS) + async def initialize_plugin_settings(data: dict[str, Any]) -> handler.ActionResponse: + """Initialize plugin settings""" + # check if exists plugin setting + plugin_author = data['plugin_author'] + plugin_name = data['plugin_name'] + install_source = data['install_source'] + install_info = data['install_info'] + + try: + result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(persistence_plugin.PluginSetting) + .where(persistence_plugin.PluginSetting.plugin_author == plugin_author) + .where(persistence_plugin.PluginSetting.plugin_name == plugin_name) + ) + + if result.first() is not None: + # delete plugin setting + await self.ap.persistence_mgr.execute_async( + sqlalchemy.delete(persistence_plugin.PluginSetting) + .where(persistence_plugin.PluginSetting.plugin_author == plugin_author) + .where(persistence_plugin.PluginSetting.plugin_name == plugin_name) + ) + + # create plugin setting + await self.ap.persistence_mgr.execute_async( + sqlalchemy.insert(persistence_plugin.PluginSetting).values( + plugin_author=plugin_author, + plugin_name=plugin_name, + install_source=install_source, + install_info=install_info, + ) + ) + + return handler.ActionResponse.success( + data={}, + ) + except Exception as e: + traceback.print_exc() + return handler.ActionResponse.error( + message=f'Failed to initialize plugin settings: {e}', + ) + + @self.action(RuntimeToLangBotAction.GET_PLUGIN_SETTINGS) + async def get_plugin_settings(data: dict[str, Any]) -> handler.ActionResponse: + """Get plugin settings""" + + plugin_author = data['plugin_author'] + plugin_name = data['plugin_name'] + + result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(persistence_plugin.PluginSetting) + .where(persistence_plugin.PluginSetting.plugin_author == plugin_author) + .where(persistence_plugin.PluginSetting.plugin_name == plugin_name) + ) + + data = { + 'enabled': True, + 'priority': 0, + 'plugin_config': {}, + 'install_source': 'local', + 'install_info': {}, + } + + setting = result.first() + + if setting is not None: + data['enabled'] = setting.enabled + data['priority'] = setting.priority + data['plugin_config'] = setting.config + data['install_source'] = setting.install_source + data['install_info'] = setting.install_info + + return handler.ActionResponse.success( + data=data, + ) + + @self.action(PluginToRuntimeAction.REPLY_MESSAGE) + async def reply_message(data: dict[str, Any]) -> handler.ActionResponse: + """Reply message""" + query_id = data['query_id'] + message_chain = data['message_chain'] + quote_origin = data['quote_origin'] + + if query_id not in self.ap.query_pool.cached_queries: + return handler.ActionResponse.error( + message=f'Query with query_id {query_id} not found', + ) + + query = self.ap.query_pool.cached_queries[query_id] + + message_chain_obj = platform_message.MessageChain.model_validate(message_chain) + + await query.adapter.reply_message( + query.message_event, + message_chain_obj, + quote_origin, + ) + + return handler.ActionResponse.success( + data={}, + ) + + @self.action(PluginToRuntimeAction.GET_BOT_UUID) + async def get_bot_uuid(data: dict[str, Any]) -> handler.ActionResponse: + """Get bot uuid""" + query_id = data['query_id'] + if query_id not in self.ap.query_pool.cached_queries: + return handler.ActionResponse.error( + message=f'Query with query_id {query_id} not found', + ) + + query = self.ap.query_pool.cached_queries[query_id] + + return handler.ActionResponse.success( + data={ + 'bot_uuid': query.bot_uuid, + }, + ) + + @self.action(PluginToRuntimeAction.SET_QUERY_VAR) + async def set_query_var(data: dict[str, Any]) -> handler.ActionResponse: + """Set query var""" + query_id = data['query_id'] + key = data['key'] + value = data['value'] + + if query_id not in self.ap.query_pool.cached_queries: + return handler.ActionResponse.error( + message=f'Query with query_id {query_id} not found', + ) + + query = self.ap.query_pool.cached_queries[query_id] + + query.variables[key] = value + + return handler.ActionResponse.success( + data={}, + ) + + @self.action(PluginToRuntimeAction.GET_QUERY_VAR) + async def get_query_var(data: dict[str, Any]) -> handler.ActionResponse: + """Get query var""" + query_id = data['query_id'] + key = data['key'] + + if query_id not in self.ap.query_pool.cached_queries: + return handler.ActionResponse.error( + message=f'Query with query_id {query_id} not found', + ) + + query = self.ap.query_pool.cached_queries[query_id] + + return handler.ActionResponse.success( + data={ + 'value': query.variables[key], + }, + ) + + @self.action(PluginToRuntimeAction.GET_QUERY_VARS) + async def get_query_vars(data: dict[str, Any]) -> handler.ActionResponse: + """Get query vars""" + query_id = data['query_id'] + if query_id not in self.ap.query_pool.cached_queries: + return handler.ActionResponse.error( + message=f'Query with query_id {query_id} not found', + ) + + query = self.ap.query_pool.cached_queries[query_id] + + return handler.ActionResponse.success( + data={ + 'vars': query.variables, + }, + ) + + @self.action(PluginToRuntimeAction.GET_LANGBOT_VERSION) + async def get_langbot_version(data: dict[str, Any]) -> handler.ActionResponse: + """Get langbot version""" + return handler.ActionResponse.success( + data={ + 'version': constants.semantic_version, + }, + ) + + @self.action(PluginToRuntimeAction.GET_BOTS) + async def get_bots(data: dict[str, Any]) -> handler.ActionResponse: + """Get bots""" + bots = await self.ap.bot_service.get_bots(include_secret=False) + return handler.ActionResponse.success( + data={ + 'bots': bots, + }, + ) + + @self.action(PluginToRuntimeAction.GET_BOT_INFO) + async def get_bot_info(data: dict[str, Any]) -> handler.ActionResponse: + """Get bot info""" + bot_uuid = data['bot_uuid'] + bot = await self.ap.bot_service.get_runtime_bot_info(bot_uuid, include_secret=False) + return handler.ActionResponse.success( + data={ + 'bot': bot, + }, + ) + + @self.action(PluginToRuntimeAction.SEND_MESSAGE) + async def send_message(data: dict[str, Any]) -> handler.ActionResponse: + """Send message""" + bot_uuid = data['bot_uuid'] + target_type = data['target_type'] + target_id = data['target_id'] + message_chain = data['message_chain'] + + message_chain_obj = platform_message.MessageChain.model_validate(message_chain) + + bot = await self.ap.platform_mgr.get_bot_by_uuid(bot_uuid) + if bot is None: + return handler.ActionResponse.error( + message=f'Bot with bot_uuid {bot_uuid} not found', + ) + + await bot.adapter.send_message( + target_type, + target_id, + message_chain_obj, + ) + + return handler.ActionResponse.success( + data={}, + ) + + @self.action(PluginToRuntimeAction.GET_LLM_MODELS) + async def get_llm_models(data: dict[str, Any]) -> handler.ActionResponse: + """Get llm models""" + llm_models = await self.ap.model_service.get_llm_models(include_secret=False) + return handler.ActionResponse.success( + data={ + 'llm_models': llm_models, + }, + ) + + @self.action(PluginToRuntimeAction.INVOKE_LLM) + async def invoke_llm(data: dict[str, Any]) -> handler.ActionResponse: + """Invoke llm""" + llm_model_uuid = data['llm_model_uuid'] + messages = data['messages'] + funcs = data.get('funcs', []) + extra_args = data.get('extra_args', {}) + + llm_model = await self.ap.model_mgr.get_model_by_uuid(llm_model_uuid) + if llm_model is None: + return handler.ActionResponse.error( + message=f'LLM model with llm_model_uuid {llm_model_uuid} not found', + ) + + messages_obj = [provider_message.Message.model_validate(message) for message in messages] + funcs_obj = [resource_tool.LLMTool.model_validate(func) for func in funcs] + + result = await llm_model.requester.invoke_llm( + query=None, + model=llm_model, + messages=messages_obj, + funcs=funcs_obj, + extra_args=extra_args, + ) + + return handler.ActionResponse.success( + data={ + 'message': result.model_dump(), + }, + ) + + @self.action(RuntimeToLangBotAction.SET_BINARY_STORAGE) + async def set_binary_storage(data: dict[str, Any]) -> handler.ActionResponse: + """Set binary storage""" + key = data['key'] + owner_type = data['owner_type'] + owner = data['owner'] + value = base64.b64decode(data['value_base64']) + + result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(persistence_bstorage.BinaryStorage) + .where(persistence_bstorage.BinaryStorage.key == key) + .where(persistence_bstorage.BinaryStorage.owner_type == owner_type) + .where(persistence_bstorage.BinaryStorage.owner == owner) + ) + + if result.first() is not None: + await self.ap.persistence_mgr.execute_async( + sqlalchemy.update(persistence_bstorage.BinaryStorage) + .where(persistence_bstorage.BinaryStorage.key == key) + .where(persistence_bstorage.BinaryStorage.owner_type == owner_type) + .where(persistence_bstorage.BinaryStorage.owner == owner) + .values(value=value) + ) + else: + await self.ap.persistence_mgr.execute_async( + sqlalchemy.insert(persistence_bstorage.BinaryStorage).values( + unique_key=f'{owner_type}:{owner}:{key}', + key=key, + owner_type=owner_type, + owner=owner, + value=value, + ) + ) + + return handler.ActionResponse.success( + data={}, + ) + + @self.action(RuntimeToLangBotAction.GET_BINARY_STORAGE) + async def get_binary_storage(data: dict[str, Any]) -> handler.ActionResponse: + """Get binary storage""" + key = data['key'] + owner_type = data['owner_type'] + owner = data['owner'] + + result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(persistence_bstorage.BinaryStorage) + .where(persistence_bstorage.BinaryStorage.key == key) + .where(persistence_bstorage.BinaryStorage.owner_type == owner_type) + .where(persistence_bstorage.BinaryStorage.owner == owner) + ) + + storage = result.first() + if storage is None: + return handler.ActionResponse.error( + message=f'Storage with key {key} not found', + ) + + return handler.ActionResponse.success( + data={ + 'value_base64': base64.b64encode(storage.value).decode('utf-8'), + }, + ) + + @self.action(RuntimeToLangBotAction.DELETE_BINARY_STORAGE) + async def delete_binary_storage(data: dict[str, Any]) -> handler.ActionResponse: + """Delete binary storage""" + key = data['key'] + owner_type = data['owner_type'] + owner = data['owner'] + + await self.ap.persistence_mgr.execute_async( + sqlalchemy.delete(persistence_bstorage.BinaryStorage) + .where(persistence_bstorage.BinaryStorage.key == key) + .where(persistence_bstorage.BinaryStorage.owner_type == owner_type) + .where(persistence_bstorage.BinaryStorage.owner == owner) + ) + + return handler.ActionResponse.success( + data={}, + ) + + @self.action(RuntimeToLangBotAction.GET_BINARY_STORAGE_KEYS) + async def get_binary_storage_keys(data: dict[str, Any]) -> handler.ActionResponse: + """Get binary storage keys""" + owner_type = data['owner_type'] + owner = data['owner'] + + result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(persistence_bstorage.BinaryStorage.key) + .where(persistence_bstorage.BinaryStorage.owner_type == owner_type) + .where(persistence_bstorage.BinaryStorage.owner == owner) + ) + + return handler.ActionResponse.success( + data={ + 'keys': result.scalars().all(), + }, + ) + + async def ping(self) -> dict[str, Any]: + """Ping the runtime""" + return await self.call_action( + CommonAction.PING, + {}, + timeout=10, + ) + + async def install_plugin( + self, install_source: str, install_info: dict[str, Any] + ) -> typing.AsyncGenerator[dict[str, Any], None]: + """Install plugin""" + gen = self.call_action_generator( + LangBotToRuntimeAction.INSTALL_PLUGIN, + { + 'install_source': install_source, + 'install_info': install_info, + }, + timeout=120, + ) + + async for ret in gen: + yield ret + + async def upgrade_plugin(self, plugin_author: str, plugin_name: str) -> typing.AsyncGenerator[dict[str, Any], None]: + """Upgrade plugin""" + gen = self.call_action_generator( + LangBotToRuntimeAction.UPGRADE_PLUGIN, + { + 'plugin_author': plugin_author, + 'plugin_name': plugin_name, + }, + timeout=120, + ) + + async for ret in gen: + yield ret + + async def delete_plugin(self, plugin_author: str, plugin_name: str) -> typing.AsyncGenerator[dict[str, Any], None]: + """Delete plugin""" + gen = self.call_action_generator( + LangBotToRuntimeAction.DELETE_PLUGIN, + { + 'plugin_author': plugin_author, + 'plugin_name': plugin_name, + }, + ) + + async for ret in gen: + yield ret + + async def list_plugins(self) -> list[dict[str, Any]]: + """List plugins""" + result = await self.call_action( + LangBotToRuntimeAction.LIST_PLUGINS, + {}, + timeout=10, + ) + + return result['plugins'] + + async def get_plugin_info(self, author: str, plugin_name: str) -> dict[str, Any]: + """Get plugin""" + result = await self.call_action( + LangBotToRuntimeAction.GET_PLUGIN_INFO, + { + 'author': author, + 'plugin_name': plugin_name, + }, + timeout=10, + ) + return result['plugin'] + + async def set_plugin_config(self, plugin_author: str, plugin_name: str, config: dict[str, Any]) -> dict[str, Any]: + """Set plugin config""" + # update plugin setting + await self.ap.persistence_mgr.execute_async( + sqlalchemy.update(persistence_plugin.PluginSetting) + .where(persistence_plugin.PluginSetting.plugin_author == plugin_author) + .where(persistence_plugin.PluginSetting.plugin_name == plugin_name) + .values(config=config) + ) + + # restart plugin + gen = self.call_action_generator( + LangBotToRuntimeAction.RESTART_PLUGIN, + { + 'plugin_author': plugin_author, + 'plugin_name': plugin_name, + }, + ) + async for ret in gen: + pass + + return {} + + async def emit_event( + self, + event_context: dict[str, Any], + ) -> dict[str, Any]: + """Emit event""" + result = await self.call_action( + LangBotToRuntimeAction.EMIT_EVENT, + { + 'event_context': event_context, + }, + timeout=30, + ) + + return result + + async def list_tools(self) -> list[dict[str, Any]]: + """List tools""" + result = await self.call_action( + LangBotToRuntimeAction.LIST_TOOLS, + {}, + timeout=10, + ) + + return result['tools'] + + async def get_plugin_icon(self, plugin_author: str, plugin_name: str) -> dict[str, Any]: + """Get plugin icon""" + result = await self.call_action( + LangBotToRuntimeAction.GET_PLUGIN_ICON, + { + 'plugin_author': plugin_author, + 'plugin_name': plugin_name, + }, + ) + return result + + async def call_tool(self, tool_name: str, parameters: dict[str, Any]) -> dict[str, Any]: + """Call tool""" + result = await self.call_action( + LangBotToRuntimeAction.CALL_TOOL, + { + 'tool_name': tool_name, + 'tool_parameters': parameters, + }, + timeout=30, + ) + + return result['tool_response'] + + async def list_commands(self) -> list[dict[str, Any]]: + """List commands""" + result = await self.call_action( + LangBotToRuntimeAction.LIST_COMMANDS, + {}, + timeout=10, + ) + return result['commands'] + + async def execute_command(self, command_context: dict[str, Any]) -> typing.AsyncGenerator[dict[str, Any], None]: + """Execute command""" + gen = self.call_action_generator( + LangBotToRuntimeAction.EXECUTE_COMMAND, + { + 'command_context': command_context, + }, + timeout=30, + ) + + async for ret in gen: + yield ret diff --git a/pkg/plugin/host.py b/pkg/plugin/host.py deleted file mode 100644 index 0adb0078..00000000 --- a/pkg/plugin/host.py +++ /dev/null @@ -1,9 +0,0 @@ -# 此模块已过时 -# 请从 pkg.plugin.context 引入 BasePlugin, EventContext 和 APIHost -# 最早将于 v3.4 移除此模块 - -from .events import * - - -def emit(*args, **kwargs): - print('插件调用了已弃用的函数 pkg.plugin.host.emit()') diff --git a/pkg/plugin/installer.py b/pkg/plugin/installer.py deleted file mode 100644 index 159967dc..00000000 --- a/pkg/plugin/installer.py +++ /dev/null @@ -1,45 +0,0 @@ -from __future__ import annotations - -import abc - -from ..core import app, taskmgr - - -class PluginInstaller(metaclass=abc.ABCMeta): - """插件安装器抽象类""" - - ap: app.Application - - def __init__(self, ap: app.Application): - self.ap = ap - - async def initialize(self): - pass - - @abc.abstractmethod - async def install_plugin( - self, - plugin_source: str, - task_context: taskmgr.TaskContext = taskmgr.TaskContext.placeholder(), - ): - """安装插件""" - raise NotImplementedError - - @abc.abstractmethod - async def uninstall_plugin( - self, - plugin_name: str, - task_context: taskmgr.TaskContext = taskmgr.TaskContext.placeholder(), - ): - """卸载插件""" - raise NotImplementedError - - @abc.abstractmethod - async def update_plugin( - self, - plugin_name: str, - plugin_source: str = None, - task_context: taskmgr.TaskContext = taskmgr.TaskContext.placeholder(), - ): - """更新插件""" - raise NotImplementedError diff --git a/pkg/plugin/installers/__init__.py b/pkg/plugin/installers/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/pkg/plugin/installers/github.py b/pkg/plugin/installers/github.py deleted file mode 100644 index df247219..00000000 --- a/pkg/plugin/installers/github.py +++ /dev/null @@ -1,143 +0,0 @@ -from __future__ import annotations - -import re -import os -import zipfile -import ssl -import certifi - -import aiohttp -import aiofiles -import aiofiles.os as aiofiles_os -import aioshutil - -from .. import installer, errors -from ...utils import pkgmgr -from ...core import taskmgr - - -class GitHubRepoInstaller(installer.PluginInstaller): - """GitHub仓库插件安装器""" - - def get_github_plugin_repo_label(self, repo_url: str) -> list[str]: - """获取username, repo""" - repo = re.findall( - r'(?:https?://github\.com/|git@github\.com:)([^/]+/[^/]+?)(?:\.git|/|$)', - repo_url, - ) - if len(repo) > 0: - return repo[0].split('/') - else: - return None - - async def download_plugin_source_code( - self, - repo_url: str, - target_path: str, - task_context: taskmgr.TaskContext = taskmgr.TaskContext.placeholder(), - ) -> str: - """下载插件源码(全异步)""" - repo = self.get_github_plugin_repo_label(repo_url) - if repo is None: - raise errors.PluginInstallerError('仅支持GitHub仓库地址') - - target_path += repo[1] - self.ap.logger.debug('正在下载源码...') - task_context.trace('下载源码...', 'download-plugin-source-code') - - zipball_url = f'https://api.github.com/repos/{"/".join(repo)}/zipball/HEAD' - zip_resp: bytes = None - - # 创建自定义SSL上下文,使用certifi提供的根证书 - ssl_context = ssl.create_default_context(cafile=certifi.where()) - - async with aiohttp.ClientSession(trust_env=True) as session: - async with session.get( - url=zipball_url, - timeout=aiohttp.ClientTimeout(total=300), - ssl=ssl_context, # 使用自定义SSL上下文来验证证书 - ) as resp: - if resp.status != 200: - raise errors.PluginInstallerError(f'下载源码失败: {await resp.text()}') - zip_resp = await resp.read() - - if await aiofiles_os.path.exists('temp/' + target_path): - await aioshutil.rmtree('temp/' + target_path) - - if await aiofiles_os.path.exists(target_path): - await aioshutil.rmtree(target_path) - - await aiofiles_os.makedirs('temp/' + target_path) - - async with aiofiles.open('temp/' + target_path + '/source.zip', 'wb') as f: - await f.write(zip_resp) - - self.ap.logger.debug('解压中...') - task_context.trace('解压中...', 'unzip-plugin-source-code') - - with zipfile.ZipFile('temp/' + target_path + '/source.zip', 'r') as zip_ref: - zip_ref.extractall('temp/' + target_path) - await aiofiles_os.remove('temp/' + target_path + '/source.zip') - - import glob - - unzip_dir = glob.glob('temp/' + target_path + '/*')[0] - await aioshutil.copytree(unzip_dir, target_path + '/') - await aioshutil.rmtree(unzip_dir) - - self.ap.logger.debug('源码下载完成。') - return repo[1] - - async def install_requirements(self, path: str): - if os.path.exists(path + '/requirements.txt'): - pkgmgr.install_requirements(path + '/requirements.txt') - - async def install_plugin( - self, - plugin_source: str, - task_context: taskmgr.TaskContext = taskmgr.TaskContext.placeholder(), - ): - """安装插件""" - task_context.trace('下载插件源码...', 'install-plugin') - repo_label = await self.download_plugin_source_code(plugin_source, 'plugins/', task_context) - task_context.trace('安装插件依赖...', 'install-plugin') - await self.install_requirements('plugins/' + repo_label) - task_context.trace('完成.', 'install-plugin') - - # Caution: in the v4.0, plugin without manifest will not be able to be updated - # await self.ap.plugin_mgr.setting.record_installed_plugin_source( - # "plugins/" + repo_label + '/', plugin_source - # ) - - async def uninstall_plugin( - self, - plugin_name: str, - task_context: taskmgr.TaskContext = taskmgr.TaskContext.placeholder(), - ): - """卸载插件""" - plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name) - if plugin_container is None: - raise errors.PluginInstallerError('插件不存在或未成功加载') - else: - task_context.trace('删除插件目录...', 'uninstall-plugin') - await aioshutil.rmtree(plugin_container.pkg_path) - task_context.trace('完成, 重新加载以生效.', 'uninstall-plugin') - - async def update_plugin( - self, - plugin_name: str, - plugin_source: str = None, - task_context: taskmgr.TaskContext = taskmgr.TaskContext.placeholder(), - ): - """更新插件""" - task_context.trace('更新插件...', 'update-plugin') - plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name) - if plugin_container is None: - raise errors.PluginInstallerError('插件不存在或未成功加载') - else: - if plugin_container.plugin_repository: - plugin_source = plugin_container.plugin_repository - task_context.trace('转交安装任务.', 'update-plugin') - await self.install_plugin(plugin_source, task_context) - else: - raise errors.PluginInstallerError('插件无源码信息,无法更新') diff --git a/pkg/plugin/loader.py b/pkg/plugin/loader.py deleted file mode 100644 index 191d8bc1..00000000 --- a/pkg/plugin/loader.py +++ /dev/null @@ -1,25 +0,0 @@ -from __future__ import annotations - -import abc - -from ..core import app -from . import context - - -class PluginLoader(metaclass=abc.ABCMeta): - """插件加载器抽象类""" - - ap: app.Application - - plugins: list[context.RuntimeContainer] - - def __init__(self, ap: app.Application): - self.ap = ap - self.plugins = [] - - async def initialize(self): - pass - - @abc.abstractmethod - async def load_plugins(self): - pass diff --git a/pkg/plugin/loaders/__init__.py b/pkg/plugin/loaders/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/pkg/plugin/loaders/classic.py b/pkg/plugin/loaders/classic.py deleted file mode 100644 index 7bc5631b..00000000 --- a/pkg/plugin/loaders/classic.py +++ /dev/null @@ -1,198 +0,0 @@ -from __future__ import annotations - -import typing -import pkgutil -import importlib -import traceback - -from .. import loader, events, context, models -from ...core import entities as core_entities -from ...provider.tools import entities as tools_entities -from ...utils import funcschema -from ...discover import engine as discover_engine - - -class PluginLoader(loader.PluginLoader): - """加载 plugins/ 目录下的插件""" - - _current_pkg_path = '' - - _current_module_path = '' - - _current_container: context.RuntimeContainer = None - - plugins: list[context.RuntimeContainer] = [] - - def __init__(self, ap): - self.ap = ap - self.plugins = [] - self._current_pkg_path = '' - self._current_module_path = '' - self._current_container = None - - async def initialize(self): - """初始化""" - - def register( - self, name: str, description: str, version: str, author: str - ) -> typing.Callable[[typing.Type[context.BasePlugin]], typing.Type[context.BasePlugin]]: - self.ap.logger.debug(f'注册插件 {name} {version} by {author}') - container = context.RuntimeContainer( - plugin_name=name, - plugin_label=discover_engine.I18nString(en_US=name, zh_Hans=name), - plugin_description=discover_engine.I18nString(en_US=description, zh_Hans=description), - plugin_version=version, - plugin_author=author, - plugin_repository='', - pkg_path=self._current_pkg_path, - main_file=self._current_module_path, - event_handlers={}, - tools=[], - ) - - self._current_container = container - - def wrapper(cls: context.BasePlugin) -> typing.Type[context.BasePlugin]: - container.plugin_class = cls - return cls - - return wrapper - - # 过时 - # 最早将于 v3.4 版本移除 - def on(self, event: typing.Type[events.BaseEventModel]) -> typing.Callable[[typing.Callable], typing.Callable]: - """注册过时的事件处理器""" - self.ap.logger.debug(f'注册事件处理器 {event.__name__}') - - def wrapper(func: typing.Callable) -> typing.Callable: - async def handler(plugin: context.BasePlugin, ctx: context.EventContext) -> None: - args = { - 'host': ctx.host, - 'event': ctx, - } - - # 把 ctx.event 所有的属性都放到 args 里 - # for k, v in ctx.event.dict().items(): - # args[k] = v - for attr_name in ctx.event.__dict__.keys(): - args[attr_name] = getattr(ctx.event, attr_name) - - func(plugin, **args) - - self._current_container.event_handlers[event] = handler - - return func - - return wrapper - - # 过时 - # 最早将于 v3.4 版本移除 - def func( - self, - name: str = None, - ) -> typing.Callable: - """注册过时的内容函数""" - self.ap.logger.debug(f'注册内容函数 {name}') - - def wrapper(func: typing.Callable) -> typing.Callable: - function_schema = funcschema.get_func_schema(func) - function_name = self._current_container.plugin_name + '-' + (func.__name__ if name is None else name) - - async def handler(plugin: context.BasePlugin, query: core_entities.Query, *args, **kwargs): - return func(*args, **kwargs) - - llm_function = tools_entities.LLMFunction( - name=function_name, - human_desc='', - description=function_schema['description'], - parameters=function_schema['parameters'], - func=handler, - ) - - self._current_container.tools.append(llm_function) - - return func - - return wrapper - - def handler(self, event: typing.Type[events.BaseEventModel]) -> typing.Callable[[typing.Callable], typing.Callable]: - """注册事件处理器""" - self.ap.logger.debug(f'注册事件处理器 {event.__name__}') - - def wrapper(func: typing.Callable) -> typing.Callable: - if ( - self._current_container is None - ): # None indicates this plugin is registered through manifest, so ignore it here - return func - - self._current_container.event_handlers[event] = func - - return func - - return wrapper - - def llm_func( - self, - name: str = None, - ) -> typing.Callable: - """注册内容函数""" - self.ap.logger.debug(f'注册内容函数 {name}') - - def wrapper(func: typing.Callable) -> typing.Callable: - if ( - self._current_container is None - ): # None indicates this plugin is registered through manifest, so ignore it here - return func - - function_schema = funcschema.get_func_schema(func) - function_name = self._current_container.plugin_name + '-' + (func.__name__ if name is None else name) - - llm_function = tools_entities.LLMFunction( - name=function_name, - human_desc='', - description=function_schema['description'], - parameters=function_schema['parameters'], - func=func, - ) - - self._current_container.tools.append(llm_function) - - return func - - return wrapper - - async def _walk_plugin_path(self, module, prefix='', path_prefix=''): - """遍历插件路径""" - for item in pkgutil.iter_modules(module.__path__): - if item.ispkg: - await self._walk_plugin_path( - __import__(module.__name__ + '.' + item.name, fromlist=['']), - prefix + item.name + '.', - path_prefix + item.name + '/', - ) - else: - try: - self._current_pkg_path = 'plugins/' + path_prefix - self._current_module_path = 'plugins/' + path_prefix + item.name + '.py' - - self._current_container = None - - importlib.import_module(module.__name__ + '.' + item.name) - - if self._current_container is not None: - self.plugins.append(self._current_container) - self.ap.logger.debug(f'插件 {self._current_container} 已加载') - except Exception: - self.ap.logger.error(f'加载插件模块 {prefix + item.name} 时发生错误') - traceback.print_exc() - - async def load_plugins(self): - """加载插件""" - setattr(models, 'register', self.register) - setattr(models, 'on', self.on) - setattr(models, 'func', self.func) - - setattr(context, 'register', self.register) - setattr(context, 'handler', self.handler) - setattr(context, 'llm_func', self.llm_func) - await self._walk_plugin_path(__import__('plugins', fromlist=[''])) diff --git a/pkg/plugin/loaders/manifest.py b/pkg/plugin/loaders/manifest.py deleted file mode 100644 index cce6c9e3..00000000 --- a/pkg/plugin/loaders/manifest.py +++ /dev/null @@ -1,96 +0,0 @@ -from __future__ import annotations - -import typing -import os -import traceback - -from ...core import app -from .. import context, events -from .. import loader -from ...utils import funcschema -from ...provider.tools import entities as tools_entities - - -class PluginManifestLoader(loader.PluginLoader): - """通过插件清单发现插件""" - - _current_container: context.RuntimeContainer = None - - def __init__(self, ap: app.Application): - super().__init__(ap) - - def handler(self, event: typing.Type[events.BaseEventModel]) -> typing.Callable[[typing.Callable], typing.Callable]: - """注册事件处理器""" - self.ap.logger.debug(f'注册事件处理器 {event.__name__}') - - def wrapper(func: typing.Callable) -> typing.Callable: - self._current_container.event_handlers[event] = func - - return func - - return wrapper - - def llm_func( - self, - name: str = None, - ) -> typing.Callable: - """注册内容函数""" - self.ap.logger.debug(f'注册内容函数 {name}') - - def wrapper(func: typing.Callable) -> typing.Callable: - function_schema = funcschema.get_func_schema(func) - function_name = self._current_container.plugin_name + '-' + (func.__name__ if name is None else name) - - llm_function = tools_entities.LLMFunction( - name=function_name, - human_desc='', - description=function_schema['description'], - parameters=function_schema['parameters'], - func=func, - ) - - self._current_container.tools.append(llm_function) - - return func - - return wrapper - - async def load_plugins(self): - """加载插件""" - setattr(context, 'handler', self.handler) - setattr(context, 'llm_func', self.llm_func) - - plugin_manifests = self.ap.discover.get_components_by_kind('Plugin') - - for plugin_manifest in plugin_manifests: - try: - config_schema = plugin_manifest.spec['config'] if 'config' in plugin_manifest.spec else [] - - current_plugin_container = context.RuntimeContainer( - plugin_name=plugin_manifest.metadata.name, - plugin_label=plugin_manifest.metadata.label, - plugin_description=plugin_manifest.metadata.description, - plugin_version=plugin_manifest.metadata.version, - plugin_author=plugin_manifest.metadata.author, - plugin_repository=plugin_manifest.metadata.repository, - main_file=os.path.join(plugin_manifest.rel_dir, plugin_manifest.execution.python.path), - pkg_path=plugin_manifest.rel_dir, - config_schema=config_schema, - event_handlers={}, - tools=[], - ) - - self._current_container = current_plugin_container - - # extract the plugin class - # this step will load the plugin module, - # so the event handlers and tools will be registered - plugin_class = plugin_manifest.get_python_component_class() - current_plugin_container.plugin_class = plugin_class - - # TODO load component extensions - - self.plugins.append(current_plugin_container) - except Exception: - self.ap.logger.error(f'加载插件 {plugin_manifest.metadata.name} 时发生错误') - traceback.print_exc() diff --git a/pkg/plugin/manager.py b/pkg/plugin/manager.py deleted file mode 100644 index bf2027f4..00000000 --- a/pkg/plugin/manager.py +++ /dev/null @@ -1,308 +0,0 @@ -from __future__ import annotations - -import traceback - -import sqlalchemy - -from ..core import app, taskmgr -from . import context, loader, events, installer, models -from .loaders import classic, manifest -from .installers import github -from ..entity.persistence import plugin as persistence_plugin - - -class PluginManager: - """插件管理器""" - - ap: app.Application - - loaders: list[loader.PluginLoader] - - installer: installer.PluginInstaller - - api_host: context.APIHost - - plugin_containers: list[context.RuntimeContainer] - - def plugins( - self, - enabled: bool = None, - status: context.RuntimeContainerStatus = None, - ) -> list[context.RuntimeContainer]: - """获取插件列表""" - plugins = self.plugin_containers - - if enabled is not None: - plugins = [plugin for plugin in plugins if plugin.enabled == enabled] - - if status is not None: - plugins = [plugin for plugin in plugins if plugin.status == status] - - return plugins - - def get_plugin( - self, - author: str, - plugin_name: str, - ) -> context.RuntimeContainer: - """通过作者和插件名获取插件""" - for plugin in self.plugins(): - if plugin.plugin_author == author and plugin.plugin_name == plugin_name: - return plugin - return None - - def __init__(self, ap: app.Application): - self.ap = ap - self.loaders = [ - classic.PluginLoader(ap), - manifest.PluginManifestLoader(ap), - ] - self.installer = github.GitHubRepoInstaller(ap) - self.api_host = context.APIHost(ap) - self.plugin_containers = [] - - async def initialize(self): - for loader in self.loaders: - await loader.initialize() - await self.installer.initialize() - await self.api_host.initialize() - - setattr(models, 'require_ver', self.api_host.require_ver) - - async def load_plugins(self): - self.ap.logger.info('Loading all plugins...') - - for loader in self.loaders: - await loader.load_plugins() - self.plugin_containers.extend(loader.plugins) - - await self.load_plugin_settings(self.plugin_containers) - - # 按优先级倒序 - self.plugin_containers.sort(key=lambda x: x.priority, reverse=False) - - self.ap.logger.debug(f'优先级排序后的插件列表 {self.plugin_containers}') - - async def load_plugin_settings(self, plugin_containers: list[context.RuntimeContainer]): - for plugin_container in plugin_containers: - result = await self.ap.persistence_mgr.execute_async( - sqlalchemy.select(persistence_plugin.PluginSetting) - .where(persistence_plugin.PluginSetting.plugin_author == plugin_container.plugin_author) - .where(persistence_plugin.PluginSetting.plugin_name == plugin_container.plugin_name) - ) - - setting = result.first() - - if setting is None: - new_setting_data = { - 'plugin_author': plugin_container.plugin_author, - 'plugin_name': plugin_container.plugin_name, - 'enabled': plugin_container.enabled, - 'priority': plugin_container.priority, - 'config': plugin_container.plugin_config, - } - - await self.ap.persistence_mgr.execute_async( - sqlalchemy.insert(persistence_plugin.PluginSetting).values(**new_setting_data) - ) - continue - else: - plugin_container.enabled = setting.enabled - plugin_container.priority = setting.priority - plugin_container.plugin_config = setting.config - - async def dump_plugin_container_setting(self, plugin_container: context.RuntimeContainer): - """保存单个插件容器的设置到数据库""" - await self.ap.persistence_mgr.execute_async( - sqlalchemy.update(persistence_plugin.PluginSetting) - .where(persistence_plugin.PluginSetting.plugin_author == plugin_container.plugin_author) - .where(persistence_plugin.PluginSetting.plugin_name == plugin_container.plugin_name) - .values( - enabled=plugin_container.enabled, - priority=plugin_container.priority, - config=plugin_container.plugin_config, - ) - ) - - async def initialize_plugin(self, plugin: context.RuntimeContainer): - self.ap.logger.debug(f'初始化插件 {plugin.plugin_name}') - plugin.plugin_inst = plugin.plugin_class(self.api_host) - plugin.plugin_inst.config = plugin.plugin_config - plugin.plugin_inst.ap = self.ap - plugin.plugin_inst.host = self.api_host - await plugin.plugin_inst.initialize() - plugin.status = context.RuntimeContainerStatus.INITIALIZED - - async def initialize_plugins(self): - for plugin in self.plugins(): - if not plugin.enabled: - self.ap.logger.debug(f'插件 {plugin.plugin_name} 未启用,跳过初始化') - continue - try: - await self.initialize_plugin(plugin) - except Exception as e: - self.ap.logger.error(f'插件 {plugin.plugin_name} 初始化失败: {e}') - self.ap.logger.exception(e) - continue - - async def destroy_plugin(self, plugin: context.RuntimeContainer): - if plugin.status != context.RuntimeContainerStatus.INITIALIZED: - return - - self.ap.logger.debug(f'释放插件 {plugin.plugin_name}') - plugin.plugin_inst.__del__() - await plugin.plugin_inst.destroy() - plugin.plugin_inst = None - plugin.status = context.RuntimeContainerStatus.MOUNTED - - async def destroy_plugins(self): - for plugin in self.plugins(): - if plugin.status != context.RuntimeContainerStatus.INITIALIZED: - self.ap.logger.debug(f'插件 {plugin.plugin_name} 未初始化,跳过释放') - continue - - try: - await self.destroy_plugin(plugin) - except Exception as e: - self.ap.logger.error(f'插件 {plugin.plugin_name} 释放失败: {e}') - self.ap.logger.exception(e) - continue - - async def install_plugin( - self, - plugin_source: str, - task_context: taskmgr.TaskContext = taskmgr.TaskContext.placeholder(), - ): - """安装插件""" - await self.installer.install_plugin(plugin_source, task_context) - - # TODO statistics - - task_context.trace('重载插件..', 'reload-plugin') - await self.ap.reload(scope='plugin') - - async def uninstall_plugin( - self, - plugin_name: str, - task_context: taskmgr.TaskContext = taskmgr.TaskContext.placeholder(), - ): - """卸载插件""" - - plugin_container = self.get_plugin_by_name(plugin_name) - - if plugin_container is None: - raise ValueError(f'插件 {plugin_name} 不存在') - - await self.destroy_plugin(plugin_container) - await self.installer.uninstall_plugin(plugin_name, task_context) - - # TODO statistics - - task_context.trace('重载插件..', 'reload-plugin') - await self.ap.reload(scope='plugin') - - async def update_plugin( - self, - plugin_name: str, - plugin_source: str = None, - task_context: taskmgr.TaskContext = taskmgr.TaskContext.placeholder(), - ): - """更新插件""" - await self.installer.update_plugin(plugin_name, plugin_source, task_context) - - # TODO statistics - - task_context.trace('重载插件..', 'reload-plugin') - await self.ap.reload(scope='plugin') - - def get_plugin_by_name(self, plugin_name: str) -> context.RuntimeContainer: - """通过插件名获取插件""" - for plugin in self.plugins(): - if plugin.plugin_name == plugin_name: - return plugin - return None - - async def emit_event(self, event: events.BaseEventModel) -> context.EventContext: - """触发事件""" - - ctx = context.EventContext(host=self.api_host, event=event) - - emitted_plugins: list[context.RuntimeContainer] = [] - - for plugin in self.plugins(enabled=True, status=context.RuntimeContainerStatus.INITIALIZED): - if event.__class__ in plugin.event_handlers: - self.ap.logger.debug(f'插件 {plugin.plugin_name} 处理事件 {event.__class__.__name__}') - - is_prevented_default_before_call = ctx.is_prevented_default() - - try: - await plugin.event_handlers[event.__class__](plugin.plugin_inst, ctx) - except Exception as e: - self.ap.logger.error( - f'插件 {plugin.plugin_name} 处理事件 {event.__class__.__name__} 时发生错误: {e}' - ) - self.ap.logger.debug(f'Traceback: {traceback.format_exc()}') - - emitted_plugins.append(plugin) - - if not is_prevented_default_before_call and ctx.is_prevented_default(): - self.ap.logger.debug(f'插件 {plugin.plugin_name} 阻止了默认行为执行') - - if ctx.is_prevented_postorder(): - self.ap.logger.debug(f'插件 {plugin.plugin_name} 阻止了后序插件的执行') - break - - for key in ctx.__return_value__.keys(): - if hasattr(ctx.event, key): - setattr(ctx.event, key, ctx.__return_value__[key][0]) - - self.ap.logger.debug(f'事件 {event.__class__.__name__}({ctx.eid}) 处理完成,返回值 {ctx.__return_value__}') - - # TODO statistics - - return ctx - - async def update_plugin_switch(self, plugin_name: str, new_status: bool): - if self.get_plugin_by_name(plugin_name) is not None: - for plugin in self.plugins(): - if plugin.plugin_name == plugin_name: - if plugin.enabled == new_status: - return False - - # 初始化/释放插件 - if new_status: - await self.initialize_plugin(plugin) - else: - await self.destroy_plugin(plugin) - - plugin.enabled = new_status - - await self.dump_plugin_container_setting(plugin) - - break - - return True - else: - return False - - async def reorder_plugins(self, plugins: list[dict]): - for plugin in plugins: - plugin_name = plugin.get('name') - plugin_priority = plugin.get('priority') - - for plugin in self.plugin_containers: - if plugin.plugin_name == plugin_name: - plugin.priority = plugin_priority - break - - self.plugin_containers.sort(key=lambda x: x.priority, reverse=False) - - for plugin in self.plugin_containers: - await self.dump_plugin_container_setting(plugin) - - async def set_plugin_config(self, plugin_container: context.RuntimeContainer, new_config: dict): - plugin_container.plugin_config = new_config - - plugin_container.plugin_inst.config = new_config - - await self.dump_plugin_container_setting(plugin_container) diff --git a/pkg/plugin/models.py b/pkg/plugin/models.py deleted file mode 100644 index dbde89a9..00000000 --- a/pkg/plugin/models.py +++ /dev/null @@ -1,28 +0,0 @@ -# 此模块已过时,请引入 pkg.plugin.context 中的 register, handler 和 llm_func 来注册插件、事件处理函数和内容函数 -# 各个事件模型请从 pkg.plugin.events 引入 -# 最早将于 v3.4 移除此模块 - -from __future__ import annotations - -import typing - -from .context import BasePlugin as Plugin -from .events import * - - -def register( - name: str, description: str, version: str, author -) -> typing.Callable[[typing.Type[Plugin]], typing.Type[Plugin]]: - pass - - -def on( - event: typing.Type[BaseEventModel], -) -> typing.Callable[[typing.Callable], typing.Callable]: - pass - - -def func( - name: str = None, -) -> typing.Callable: - pass diff --git a/pkg/provider/entities.py b/pkg/provider/entities.py deleted file mode 100644 index 4c4a65c1..00000000 --- a/pkg/provider/entities.py +++ /dev/null @@ -1,224 +0,0 @@ -from __future__ import annotations - -import typing -import pydantic.v1 as pydantic - -from pkg.provider import entities - - -from ..platform.types import message as platform_message - - -class FunctionCall(pydantic.BaseModel): - name: str - - arguments: str - - -class ToolCall(pydantic.BaseModel): - id: str - - type: str - - function: FunctionCall - - -class ImageURLContentObject(pydantic.BaseModel): - url: str - - def __str__(self): - return self.url[:128] + ('...' if len(self.url) > 128 else '') - - -class ContentElement(pydantic.BaseModel): - type: str - """内容类型""" - - text: typing.Optional[str] = None - - image_url: typing.Optional[ImageURLContentObject] = None - - image_base64: typing.Optional[str] = None - - def __str__(self): - if self.type == 'text': - return self.text - elif self.type == 'image_url': - return f'[图片]({self.image_url})' - else: - return '未知内容' - - @classmethod - def from_text(cls, text: str): - return cls(type='text', text=text) - - @classmethod - def from_image_url(cls, image_url: str): - return cls(type='image_url', image_url=ImageURLContentObject(url=image_url)) - - @classmethod - def from_image_base64(cls, image_base64: str): - return cls(type='image_base64', image_base64=image_base64) - - -class Message(pydantic.BaseModel): - """消息""" - - role: str # user, system, assistant, tool, command, plugin - """消息的角色""" - - name: typing.Optional[str] = None - """名称,仅函数调用返回时设置""" - - content: typing.Optional[list[ContentElement]] | typing.Optional[str] = None - """内容""" - - tool_calls: typing.Optional[list[ToolCall]] = None - """工具调用""" - - tool_call_id: typing.Optional[str] = None - - def readable_str(self) -> str: - if self.content is not None: - return str(self.role) + ': ' + str(self.get_content_platform_message_chain()) - elif self.tool_calls is not None: - return f'调用工具: {self.tool_calls[0].id}' - else: - return '未知消息' - - def get_content_platform_message_chain(self, prefix_text: str = '') -> platform_message.MessageChain | None: - """将内容转换为平台消息 MessageChain 对象 - - Args: - prefix_text (str): 首个文字组件的前缀文本 - """ - - if self.content is None: - return None - elif isinstance(self.content, str): - return platform_message.MessageChain([platform_message.Plain(prefix_text + self.content)]) - elif isinstance(self.content, list): - mc = [] - for ce in self.content: - if ce.type == 'text': - mc.append(platform_message.Plain(ce.text)) - elif ce.type == 'image_url': - if ce.image_url.url.startswith('http'): - mc.append(platform_message.Image(url=ce.image_url.url)) - else: # base64 - b64_str = ce.image_url.url - - if b64_str.startswith('data:'): - b64_str = b64_str.split(',')[1] - - mc.append(platform_message.Image(base64=b64_str)) - - # 找第一个文字组件 - if prefix_text: - for i, c in enumerate(mc): - if isinstance(c, platform_message.Plain): - mc[i] = platform_message.Plain(prefix_text + c.text) - break - else: - mc.insert(0, platform_message.Plain(prefix_text)) - - return platform_message.MessageChain(mc) - - -class MessageChunk(pydantic.BaseModel): - """消息""" - - resp_message_id: typing.Optional[str] = None - """消息id""" - - role: str # user, system, assistant, tool, command, plugin - """消息的角色""" - - name: typing.Optional[str] = None - """名称,仅函数调用返回时设置""" - - all_content: typing.Optional[str] = None - """所有内容""" - - content: typing.Optional[list[ContentElement]] | typing.Optional[str] = None - """内容""" - - tool_calls: typing.Optional[list[ToolCall]] = None - """工具调用""" - - tool_call_id: typing.Optional[str] = None - - is_final: bool = False - """是否是结束""" - - msg_sequence: int = 0 - """消息迭代次数""" - - def readable_str(self) -> str: - if self.content is not None: - return str(self.role) + ': ' + str(self.get_content_platform_message_chain()) - elif self.tool_calls is not None: - return f'调用工具: {self.tool_calls[0].id}' - else: - return '未知消息' - - def get_content_platform_message_chain(self, prefix_text: str = '') -> platform_message.MessageChain | None: - """将内容转换为平台消息 MessageChain 对象 - - Args: - prefix_text (str): 首个文字组件的前缀文本 - """ - - if self.content is None: - return None - elif isinstance(self.content, str): - return platform_message.MessageChain([platform_message.Plain(prefix_text + self.content)]) - elif isinstance(self.content, list): - mc = [] - for ce in self.content: - if ce.type == 'text': - mc.append(platform_message.Plain(ce.text)) - elif ce.type == 'image_url': - if ce.image_url.url.startswith('http'): - mc.append(platform_message.Image(url=ce.image_url.url)) - else: # base64 - b64_str = ce.image_url.url - - if b64_str.startswith('data:'): - b64_str = b64_str.split(',')[1] - - mc.append(platform_message.Image(base64=b64_str)) - - # 找第一个文字组件 - if prefix_text: - for i, c in enumerate(mc): - if isinstance(c, platform_message.Plain): - mc[i] = platform_message.Plain(prefix_text + c.text) - break - else: - mc.insert(0, platform_message.Plain(prefix_text)) - - return platform_message.MessageChain(mc) - - -class ToolCallChunk(pydantic.BaseModel): - """工具调用""" - - id: str - """工具调用ID""" - - type: str - """工具调用类型""" - - function: FunctionCall - """函数调用""" - - -class Prompt(pydantic.BaseModel): - """供AI使用的Prompt""" - - name: str - """名称""" - - messages: list[entities.Message] - """消息列表""" diff --git a/pkg/provider/modelmgr/entities.py b/pkg/provider/modelmgr/entities.py index 7bc02a32..efe9c112 100644 --- a/pkg/provider/modelmgr/entities.py +++ b/pkg/provider/modelmgr/entities.py @@ -2,7 +2,7 @@ from __future__ import annotations import typing -import pydantic.v1 as pydantic +import pydantic from . import requester from . import token diff --git a/pkg/provider/modelmgr/modelmgr.py b/pkg/provider/modelmgr/modelmgr.py index 2c92eacc..d649b41e 100644 --- a/pkg/provider/modelmgr/modelmgr.py +++ b/pkg/provider/modelmgr/modelmgr.py @@ -3,7 +3,7 @@ from __future__ import annotations import sqlalchemy import traceback -from . import entities, requester +from . import requester from ...core import app from ...discover import engine from . import token @@ -16,14 +16,6 @@ FETCH_MODEL_LIST_URL = 'https://api.qchatgpt.rockchin.top/api/v2/fetch/model_lis class ModelManager: """模型管理器""" - model_list: list[entities.LLMModelInfo] # deprecated - - requesters: dict[str, requester.ProviderAPIRequester] # deprecated - - token_mgrs: dict[str, token.TokenManager] # deprecated - - # ====== 4.0 ====== - ap: app.Application llm_models: list[requester.RuntimeLLMModel] @@ -36,9 +28,6 @@ class ModelManager: def __init__(self, ap: app.Application): self.ap = ap - self.model_list = [] - self.requesters = {} - self.token_mgrs = {} self.llm_models = [] self.embedding_models = [] self.requester_components = [] @@ -149,13 +138,6 @@ class ModelManager: runtime_embedding_model = await self.init_runtime_embedding_model(model_info) self.embedding_models.append(runtime_embedding_model) - async def get_model_by_name(self, name: str) -> entities.LLMModelInfo: # deprecated - """通过名称获取模型""" - for model in self.model_list: - if model.name == name: - return model - raise ValueError(f'无法确定模型 {name} 的信息') - async def get_model_by_uuid(self, uuid: str) -> requester.RuntimeLLMModel: """通过uuid获取 LLM 模型""" for model in self.llm_models: diff --git a/pkg/provider/modelmgr/requester.py b/pkg/provider/modelmgr/requester.py index 6af8ba70..52d73eea 100644 --- a/pkg/provider/modelmgr/requester.py +++ b/pkg/provider/modelmgr/requester.py @@ -4,11 +4,11 @@ import abc import typing from ...core import app -from ...core import entities as core_entities -from .. import entities as llm_entities -from ..tools import entities as tools_entities from ...entity.persistence import model as persistence_model +import langbot_plugin.api.entities.builtin.resource.tool as resource_tool from . import token +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query +import langbot_plugin.api.entities.builtin.provider.message as provider_message class RuntimeLLMModel: @@ -79,13 +79,13 @@ class ProviderAPIRequester(metaclass=abc.ABCMeta): @abc.abstractmethod async def invoke_llm( self, - query: core_entities.Query, + query: pipeline_query.Query, model: RuntimeLLMModel, - messages: typing.List[llm_entities.Message], - funcs: typing.List[tools_entities.LLMFunction] = None, + messages: typing.List[provider_message.Message], + funcs: typing.List[resource_tool.LLMTool] = None, extra_args: dict[str, typing.Any] = {}, remove_think: bool = False, - ) -> llm_entities.Message: + ) -> provider_message.Message: """调用API Args: @@ -102,42 +102,41 @@ class ProviderAPIRequester(metaclass=abc.ABCMeta): async def invoke_llm_stream( self, - query: core_entities.Query, + query: pipeline_query.Query, model: RuntimeLLMModel, - messages: typing.List[llm_entities.Message], - funcs: typing.List[tools_entities.LLMFunction] = None, + messages: typing.List[provider_message.Message], + funcs: typing.List[resource_tool.LLMTool] = None, extra_args: dict[str, typing.Any] = {}, remove_think: bool = False, - ) -> llm_entities.MessageChunk: + ) -> provider_message.MessageChunk: """调用API Args: model (RuntimeLLMModel): 使用的模型信息 - messages (typing.List[llm_entities.Message]): 消息对象列表 - funcs (typing.List[tools_entities.LLMFunction], optional): 使用的工具函数列表. Defaults to None. + messages (typing.List[provider_message.Message]): 消息对象列表 + funcs (typing.List[resource_tool.LLMTool], optional): 使用的工具函数列表. Defaults to None. extra_args (dict[str, typing.Any], optional): 额外的参数. Defaults to {}. remove_think (bool, optional): 是否移除思考中的消息. Defaults to False. Returns: - typing.AsyncGenerator[llm_entities.MessageChunk]: 返回消息对象 + typing.AsyncGenerator[provider_message.MessageChunk]: 返回消息对象 """ pass async def invoke_embedding( self, model: RuntimeEmbeddingModel, - input_text: list[str], + input_text: typing.List[str], extra_args: dict[str, typing.Any] = {}, - ) -> list[list[float]]: + ) -> typing.List[typing.List[float]]: """调用 Embedding API Args: - query (core_entities.Query): 请求上下文 model (RuntimeEmbeddingModel): 使用的模型信息 - input_text (list[str]): 输入文本 + input_text (typing.List[str]): 输入文本 extra_args (dict[str, typing.Any], optional): 额外的参数. Defaults to {}. Returns: - list[list[float]]: 返回的 embedding 向量 + typing.List[typing.List[float]]: 返回的 embedding 向量 """ pass diff --git a/pkg/provider/modelmgr/requesters/anthropicmsgs.py b/pkg/provider/modelmgr/requesters/anthropicmsgs.py index cb0c7ce1..3a1b9384 100644 --- a/pkg/provider/modelmgr/requesters/anthropicmsgs.py +++ b/pkg/provider/modelmgr/requesters/anthropicmsgs.py @@ -9,10 +9,10 @@ import httpx from .. import errors, requester -from ....core import entities as core_entities -from ... import entities as llm_entities -from ...tools import entities as tools_entities from ....utils import image +import langbot_plugin.api.entities.builtin.resource.tool as resource_tool +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query +import langbot_plugin.api.entities.builtin.provider.message as provider_message class AnthropicMessages(requester.ProviderAPIRequester): @@ -49,13 +49,13 @@ class AnthropicMessages(requester.ProviderAPIRequester): async def invoke_llm( self, - query: core_entities.Query, + query: pipeline_query.Query, model: requester.RuntimeLLMModel, - messages: typing.List[llm_entities.Message], - funcs: typing.List[tools_entities.LLMFunction] = None, + messages: typing.List[provider_message.Message], + funcs: typing.List[resource_tool.LLMTool] = None, extra_args: dict[str, typing.Any] = {}, remove_think: bool = False, - ) -> llm_entities.Message: + ) -> provider_message.Message: self.client.api_key = model.token_mgr.get_token() args = extra_args.copy() @@ -75,7 +75,7 @@ class AnthropicMessages(requester.ProviderAPIRequester): if system_role_message: messages.pop(i) - if isinstance(system_role_message, llm_entities.Message) and isinstance(system_role_message.content, str): + if isinstance(system_role_message, provider_message.Message) and isinstance(system_role_message.content, str): args['system'] = system_role_message.content req_messages = [] @@ -161,16 +161,16 @@ class AnthropicMessages(requester.ProviderAPIRequester): args['content'] += block.text elif block.type == 'tool_use': assert type(block) is anthropic.types.tool_use_block.ToolUseBlock - tool_call = llm_entities.ToolCall( + tool_call = provider_message.ToolCall( id=block.id, type='function', - function=llm_entities.FunctionCall(name=block.name, arguments=json.dumps(block.input)), + function=provider_message.FunctionCall(name=block.name, arguments=json.dumps(block.input)), ) if 'tool_calls' not in args: args['tool_calls'] = [] args['tool_calls'].append(tool_call) - return llm_entities.Message(**args) + return provider_message.Message(**args) except anthropic.AuthenticationError as e: raise errors.RequesterError(f'api-key 无效: {e.message}') except anthropic.BadRequestError as e: @@ -183,13 +183,13 @@ class AnthropicMessages(requester.ProviderAPIRequester): async def invoke_llm_stream( self, - query: core_entities.Query, + query: pipeline_query.Query, model: requester.RuntimeLLMModel, - messages: typing.List[llm_entities.Message], - funcs: typing.List[tools_entities.LLMFunction] = None, + messages: typing.List[provider_message.Message], + funcs: typing.List[resource_tool.LLMTool] = None, extra_args: dict[str, typing.Any] = {}, remove_think: bool = False, - ) -> llm_entities.Message: + ) -> provider_message.Message: self.client.api_key = model.token_mgr.get_token() args = extra_args.copy() @@ -210,7 +210,7 @@ class AnthropicMessages(requester.ProviderAPIRequester): if system_role_message: messages.pop(i) - if isinstance(system_role_message, llm_entities.Message) and isinstance(system_role_message.content, str): + if isinstance(system_role_message, provider_message.Message) and isinstance(system_role_message.content, str): args['system'] = system_role_message.content req_messages = [] @@ -356,7 +356,7 @@ class AnthropicMessages(requester.ProviderAPIRequester): # assert type(chunk) is anthropic.types.message.Chunk - yield llm_entities.MessageChunk(**args) + yield provider_message.MessageChunk(**args) # return llm_entities.Message(**args) except anthropic.AuthenticationError as e: diff --git a/pkg/provider/modelmgr/requesters/chatcmpl.py b/pkg/provider/modelmgr/requesters/chatcmpl.py index 7afda84f..b940859e 100644 --- a/pkg/provider/modelmgr/requesters/chatcmpl.py +++ b/pkg/provider/modelmgr/requesters/chatcmpl.py @@ -8,9 +8,9 @@ import openai.types.chat.chat_completion as chat_completion import httpx from .. import errors, requester -from ....core import entities as core_entities -from ... import entities as llm_entities -from ...tools import entities as tools_entities +import langbot_plugin.api.entities.builtin.resource.tool as resource_tool +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query +import langbot_plugin.api.entities.builtin.provider.message as provider_message class OpenAIChatCompletions(requester.ProviderAPIRequester): @@ -50,7 +50,7 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester): self, chat_completion: chat_completion.ChatCompletion, remove_think: bool = False, - ) -> llm_entities.Message: + ) -> provider_message.Message: chatcmpl_message = chat_completion.choices[0].message.model_dump() # 确保 role 字段存在且不为 None @@ -71,7 +71,8 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester): if 'reasoning_content' in chatcmpl_message: del chatcmpl_message['reasoning_content'] - message = llm_entities.Message(**chatcmpl_message) + message = provider_message.Message(**chatcmpl_message) + return message async def _process_thinking_content( @@ -122,13 +123,13 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester): async def _closure_stream( self, - query: core_entities.Query, + query: pipeline_query.Query, req_messages: list[dict], use_model: requester.RuntimeLLMModel, - use_funcs: list[tools_entities.LLMFunction] = None, + use_funcs: list[resource_tool.LLMTool] = None, extra_args: dict[str, typing.Any] = {}, remove_think: bool = False, - ) -> llm_entities.MessageChunk: + ) -> provider_message.MessageChunk: self.client.api_key = use_model.token_mgr.get_token() args = {} @@ -155,12 +156,12 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester): args['stream'] = True # 流式处理状态 - tool_calls_map: dict[str, llm_entities.ToolCall] = {} + # tool_calls_map: dict[str, provider_message.ToolCall] = {} chunk_idx = 0 thinking_started = False thinking_ended = False role = 'assistant' # 默认角色 - tool_id = "" + tool_id = '' tool_name = '' # accumulated_reasoning = '' # 仅用于判断何时结束思维链 @@ -223,8 +224,6 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester): if tool_call['type'] is None: tool_call['type'] = 'function' - - # 跳过空的第一个 chunk(只有 role 没有内容) if chunk_idx == 0 and not delta_content and not reasoning_content and not delta.get('tool_calls'): chunk_idx += 1 @@ -240,18 +239,18 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester): # 移除 None 值 chunk_data = {k: v for k, v in chunk_data.items() if v is not None} - yield llm_entities.MessageChunk(**chunk_data) + yield provider_message.MessageChunk(**chunk_data) chunk_idx += 1 async def _closure( self, - query: core_entities.Query, + query: pipeline_query.Query, req_messages: list[dict], use_model: requester.RuntimeLLMModel, - use_funcs: list[tools_entities.LLMFunction] = None, + use_funcs: list[resource_tool.LLMTool] = None, extra_args: dict[str, typing.Any] = {}, remove_think: bool = False, - ) -> llm_entities.Message: + ) -> provider_message.Message: self.client.api_key = use_model.token_mgr.get_token() args = {} @@ -287,13 +286,13 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester): async def invoke_llm( self, - query: core_entities.Query, + query: pipeline_query.Query, model: requester.RuntimeLLMModel, - messages: typing.List[llm_entities.Message], - funcs: typing.List[tools_entities.LLMFunction] = None, + messages: typing.List[provider_message.Message], + funcs: typing.List[resource_tool.LLMTool] = None, extra_args: dict[str, typing.Any] = {}, remove_think: bool = False, - ) -> llm_entities.Message: + ) -> provider_message.Message: req_messages = [] # req_messages 仅用于类内,外部同步由 query.messages 进行 for m in messages: msg_dict = m.dict(exclude_none=True) @@ -361,13 +360,13 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester): async def invoke_llm_stream( self, - query: core_entities.Query, + query: pipeline_query.Query, model: requester.RuntimeLLMModel, - messages: typing.List[llm_entities.Message], - funcs: typing.List[tools_entities.LLMFunction] = None, + messages: typing.List[provider_message.Message], + funcs: typing.List[resource_tool.LLMTool] = None, extra_args: dict[str, typing.Any] = {}, remove_think: bool = False, - ) -> llm_entities.MessageChunk: + ) -> provider_message.MessageChunk: req_messages = [] # req_messages 仅用于类内,外部同步由 query.messages 进行 for m in messages: msg_dict = m.dict(exclude_none=True) diff --git a/pkg/provider/modelmgr/requesters/deepseekchatcmpl.py b/pkg/provider/modelmgr/requesters/deepseekchatcmpl.py index 4866caf4..83b2bfa4 100644 --- a/pkg/provider/modelmgr/requesters/deepseekchatcmpl.py +++ b/pkg/provider/modelmgr/requesters/deepseekchatcmpl.py @@ -4,9 +4,9 @@ import typing from . import chatcmpl from .. import errors, requester -from ....core import entities as core_entities -from ... import entities as llm_entities -from ...tools import entities as tools_entities +import langbot_plugin.api.entities.builtin.resource.tool as resource_tool +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query +import langbot_plugin.api.entities.builtin.provider.message as provider_message class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions): @@ -19,13 +19,13 @@ class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions): async def _closure( self, - query: core_entities.Query, + query: pipeline_query.Query, req_messages: list[dict], use_model: requester.RuntimeLLMModel, - use_funcs: list[tools_entities.LLMFunction] = None, + use_funcs: list[resource_tool.LLMTool] = None, extra_args: dict[str, typing.Any] = {}, remove_think: bool = False, - ) -> llm_entities.Message: + ) -> provider_message.Message: self.client.api_key = use_model.token_mgr.get_token() args = {} diff --git a/pkg/provider/modelmgr/requesters/geminichatcmpl.py b/pkg/provider/modelmgr/requesters/geminichatcmpl.py index df2db312..9741e6b3 100644 --- a/pkg/provider/modelmgr/requesters/geminichatcmpl.py +++ b/pkg/provider/modelmgr/requesters/geminichatcmpl.py @@ -6,10 +6,10 @@ from . import chatcmpl import uuid -from .. import errors, requester -from ....core import entities as core_entities -from ... import entities as llm_entities -from ...tools import entities as tools_entities +from .. import requester +import langbot_plugin.api.entities.builtin.provider.message as provider_message +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query +import langbot_plugin.api.entities.builtin.resource.tool as resource_tool class GeminiChatCompletions(chatcmpl.OpenAIChatCompletions): @@ -20,16 +20,15 @@ class GeminiChatCompletions(chatcmpl.OpenAIChatCompletions): 'timeout': 120, } - async def _closure_stream( self, - query: core_entities.Query, + query: pipeline_query.Query, req_messages: list[dict], use_model: requester.RuntimeLLMModel, - use_funcs: list[tools_entities.LLMFunction] = None, + use_funcs: list[resource_tool.LLMTool] = None, extra_args: dict[str, typing.Any] = {}, remove_think: bool = False, - ) -> llm_entities.MessageChunk: + ) -> provider_message.MessageChunk: self.client.api_key = use_model.token_mgr.get_token() args = {} @@ -56,12 +55,12 @@ class GeminiChatCompletions(chatcmpl.OpenAIChatCompletions): args['stream'] = True # 流式处理状态 - tool_calls_map: dict[str, llm_entities.ToolCall] = {} + # tool_calls_map: dict[str, provider_message.ToolCall] = {} chunk_idx = 0 thinking_started = False thinking_ended = False role = 'assistant' # 默认角色 - tool_id = "" + tool_id = '' tool_name = '' # accumulated_reasoning = '' # 仅用于判断何时结束思维链 @@ -117,15 +116,13 @@ class GeminiChatCompletions(chatcmpl.OpenAIChatCompletions): for tool_call in delta['tool_calls']: if tool_call['id'] == '' and tool_id == '': tool_id = str(uuid.uuid4()) - if tool_call['function']['name']: + if tool_call['function']['name']: tool_name = tool_call['function']['name'] tool_call['id'] = tool_id tool_call['function']['name'] = tool_name if tool_call['type'] is None: tool_call['type'] = 'function' - - # 跳过空的第一个 chunk(只有 role 没有内容) if chunk_idx == 0 and not delta_content and not reasoning_content and not delta.get('tool_calls'): chunk_idx += 1 @@ -141,5 +138,5 @@ class GeminiChatCompletions(chatcmpl.OpenAIChatCompletions): # 移除 None 值 chunk_data = {k: v for k, v in chunk_data.items() if v is not None} - yield llm_entities.MessageChunk(**chunk_data) - chunk_idx += 1 \ No newline at end of file + yield provider_message.MessageChunk(**chunk_data) + chunk_idx += 1 diff --git a/pkg/provider/modelmgr/requesters/giteeaichatcmpl.py b/pkg/provider/modelmgr/requesters/giteeaichatcmpl.py index f8cf15ca..4e295e9f 100644 --- a/pkg/provider/modelmgr/requesters/giteeaichatcmpl.py +++ b/pkg/provider/modelmgr/requesters/giteeaichatcmpl.py @@ -4,12 +4,6 @@ from __future__ import annotations import typing from . import ppiochatcmpl -from .. import requester -from ....core import entities as core_entities -from ... import entities as llm_entities -from ...tools import entities as tools_entities -import re -import openai.types.chat.chat_completion as chat_completion class GiteeAIChatCompletions(ppiochatcmpl.PPIOChatCompletions): @@ -19,4 +13,3 @@ class GiteeAIChatCompletions(ppiochatcmpl.PPIOChatCompletions): 'base_url': 'https://ai.gitee.com/v1', 'timeout': 120, } - diff --git a/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py b/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py index 82d8df70..8684a677 100644 --- a/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py +++ b/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py @@ -1,18 +1,16 @@ from __future__ import annotations import asyncio -import json import typing import openai import openai.types.chat.chat_completion as chat_completion -import openai.types.chat.chat_completion_message_tool_call as chat_completion_message_tool_call import httpx from .. import entities, errors, requester -from ....core import entities as core_entities -from ... import entities as llm_entities -from ...tools import entities as tools_entities +import langbot_plugin.api.entities.builtin.resource.tool as resource_tool +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query +import langbot_plugin.api.entities.builtin.provider.message as provider_message class ModelScopeChatCompletions(requester.ProviderAPIRequester): @@ -35,7 +33,7 @@ class ModelScopeChatCompletions(requester.ProviderAPIRequester): async def _req( self, - query: core_entities.Query, + query: pipeline_query.Query, args: dict, extra_body: dict = {}, remove_think: bool = False, @@ -113,25 +111,26 @@ class ModelScopeChatCompletions(requester.ProviderAPIRequester): async def _make_msg( self, chat_completion: list[dict[str, typing.Any]], - ) -> llm_entities.Message: + ) -> provider_message.Message: chatcmpl_message = chat_completion[0] # 确保 role 字段存在且不为 None if 'role' not in chatcmpl_message or chatcmpl_message['role'] is None: chatcmpl_message['role'] = 'assistant' - message = llm_entities.Message(**chatcmpl_message) + + message = provider_message.Message(**chatcmpl_message) return message async def _closure( self, - query: core_entities.Query, + query: pipeline_query.Query, req_messages: list[dict], use_model: requester.RuntimeLLMModel, - use_funcs: list[tools_entities.LLMFunction] = None, + use_funcs: list[resource_tool.LLMTool] = None, extra_args: dict[str, typing.Any] = {}, - remove_think:bool = False, - ) -> llm_entities.Message: + remove_think: bool = False, + ) -> provider_message.Message: self.client.api_key = use_model.token_mgr.get_token() args = {} @@ -173,16 +172,15 @@ class ModelScopeChatCompletions(requester.ProviderAPIRequester): async for chunk in await self.client.chat.completions.create(**args, extra_body=extra_body): yield chunk - async def _closure_stream( self, - query: core_entities.Query, + query: pipeline_query.Query, req_messages: list[dict], use_model: requester.RuntimeLLMModel, - use_funcs: list[tools_entities.LLMFunction] = None, + use_funcs: list[resource_tool.LLMTool] = None, extra_args: dict[str, typing.Any] = {}, remove_think: bool = False, - ) -> llm_entities.Message | typing.AsyncGenerator[llm_entities.MessageChunk, None]: + ) -> provider_message.Message | typing.AsyncGenerator[provider_message.MessageChunk, None]: self.client.api_key = use_model.token_mgr.get_token() args = {} @@ -209,9 +207,8 @@ class ModelScopeChatCompletions(requester.ProviderAPIRequester): args['messages'] = messages args['stream'] = True - # 流式处理状态 - tool_calls_map: dict[str, llm_entities.ToolCall] = {} + # tool_calls_map: dict[str, provider_message.ToolCall] = {} chunk_idx = 0 thinking_started = False thinking_ended = False @@ -275,8 +272,9 @@ class ModelScopeChatCompletions(requester.ProviderAPIRequester): tool_call['type'] = 'function' tool_call['id'] = tool_id tool_call['function']['name'] = tool_name - tool_call['function']['arguments'] = "" if tool_call['function']['arguments'] is None else tool_call['function']['arguments'] - + tool_call['function']['arguments'] = ( + '' if tool_call['function']['arguments'] is None else tool_call['function']['arguments'] + ) # 跳过空的第一个 chunk(只有 role 没有内容) if chunk_idx == 0 and not delta_content and not reasoning_content and not delta.get('tool_calls'): @@ -294,19 +292,19 @@ class ModelScopeChatCompletions(requester.ProviderAPIRequester): # 移除 None 值 chunk_data = {k: v for k, v in chunk_data.items() if v is not None} - yield llm_entities.MessageChunk(**chunk_data) + yield provider_message.MessageChunk(**chunk_data) chunk_idx += 1 # return async def invoke_llm( self, - query: core_entities.Query, + query: pipeline_query.Query, model: entities.LLMModelInfo, - messages: typing.List[llm_entities.Message], - funcs: typing.List[tools_entities.LLMFunction] = None, + messages: typing.List[provider_message.Message], + funcs: typing.List[resource_tool.LLMTool] = None, extra_args: dict[str, typing.Any] = {}, remove_think: bool = False, - ) -> llm_entities.Message: + ) -> provider_message.Message: req_messages = [] # req_messages 仅用于类内,外部同步由 query.messages 进行 for m in messages: msg_dict = m.dict(exclude_none=True) @@ -320,7 +318,12 @@ class ModelScopeChatCompletions(requester.ProviderAPIRequester): try: return await self._closure( - query=query, req_messages=req_messages, use_model=model, use_funcs=funcs, extra_args=extra_args, remove_think=remove_think + query=query, + req_messages=req_messages, + use_model=model, + use_funcs=funcs, + extra_args=extra_args, + remove_think=remove_think, ) except asyncio.TimeoutError: raise errors.RequesterError('请求超时') @@ -340,13 +343,13 @@ class ModelScopeChatCompletions(requester.ProviderAPIRequester): async def invoke_llm_stream( self, - query: core_entities.Query, + query: pipeline_query.Query, model: requester.RuntimeLLMModel, - messages: typing.List[llm_entities.Message], - funcs: typing.List[tools_entities.LLMFunction] = None, + messages: typing.List[provider_message.Message], + funcs: typing.List[resource_tool.LLMTool] = None, extra_args: dict[str, typing.Any] = {}, remove_think: bool = False, - ) -> llm_entities.MessageChunk: + ) -> provider_message.MessageChunk: req_messages = [] # req_messages 仅用于类内,外部同步由 query.messages 进行 for m in messages: msg_dict = m.dict(exclude_none=True) diff --git a/pkg/provider/modelmgr/requesters/moonshotchatcmpl.py b/pkg/provider/modelmgr/requesters/moonshotchatcmpl.py index 494b2b0f..aa3d0f4f 100644 --- a/pkg/provider/modelmgr/requesters/moonshotchatcmpl.py +++ b/pkg/provider/modelmgr/requesters/moonshotchatcmpl.py @@ -5,9 +5,9 @@ import typing from . import chatcmpl from .. import requester -from ....core import entities as core_entities -from ... import entities as llm_entities -from ...tools import entities as tools_entities +import langbot_plugin.api.entities.builtin.resource.tool as resource_tool +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query +import langbot_plugin.api.entities.builtin.provider.message as provider_message class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions): @@ -20,13 +20,13 @@ class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions): async def _closure( self, - query: core_entities.Query, + query: pipeline_query.Query, req_messages: list[dict], use_model: requester.RuntimeLLMModel, - use_funcs: list[tools_entities.LLMFunction] = None, + use_funcs: list[resource_tool.LLMTool] = None, extra_args: dict[str, typing.Any] = {}, remove_think: bool = False, - ) -> llm_entities.Message: + ) -> provider_message.Message: self.client.api_key = use_model.token_mgr.get_token() args = {} diff --git a/pkg/provider/modelmgr/requesters/ollamachat.py b/pkg/provider/modelmgr/requesters/ollamachat.py index e993cab2..97361f89 100644 --- a/pkg/provider/modelmgr/requesters/ollamachat.py +++ b/pkg/provider/modelmgr/requesters/ollamachat.py @@ -10,9 +10,9 @@ import json import ollama from .. import errors, requester -from ... import entities as llm_entities -from ...tools import entities as tools_entities -from ....core import entities as core_entities +import langbot_plugin.api.entities.builtin.resource.tool as resource_tool +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query +import langbot_plugin.api.entities.builtin.provider.message as provider_message REQUESTER_NAME: str = 'ollama-chat' @@ -39,13 +39,13 @@ class OllamaChatCompletions(requester.ProviderAPIRequester): async def _closure( self, - query: core_entities.Query, + query: pipeline_query.Query, req_messages: list[dict], use_model: requester.RuntimeLLMModel, - use_funcs: list[tools_entities.LLMFunction] = None, + use_funcs: list[resource_tool.LLMTool] = None, extra_args: dict[str, typing.Any] = {}, remove_think: bool = False, - ) -> llm_entities.Message: + ) -> provider_message.Message: args = extra_args.copy() args['model'] = use_model.model_entity.name @@ -74,27 +74,27 @@ class OllamaChatCompletions(requester.ProviderAPIRequester): args['tools'] = tools resp = await self._req(args) - message: llm_entities.Message = await self._make_msg(resp) + message: provider_message.Message = await self._make_msg(resp) return message - async def _make_msg(self, chat_completions: ollama.ChatResponse) -> llm_entities.Message: + async def _make_msg(self, chat_completions: ollama.ChatResponse) -> provider_message.Message: message: ollama.Message = chat_completions.message if message is None: raise ValueError("chat_completions must contain a 'message' field") - ret_msg: llm_entities.Message = None + ret_msg: provider_message.Message = None if message.content is not None: - ret_msg = llm_entities.Message(role='assistant', content=message.content) + ret_msg = provider_message.Message(role='assistant', content=message.content) if message.tool_calls is not None and len(message.tool_calls) > 0: - tool_calls: list[llm_entities.ToolCall] = [] + tool_calls: list[provider_message.ToolCall] = [] for tool_call in message.tool_calls: tool_calls.append( - llm_entities.ToolCall( + provider_message.ToolCall( id=uuid.uuid4().hex, type='function', - function=llm_entities.FunctionCall( + function=provider_message.FunctionCall( name=tool_call.function.name, arguments=json.dumps(tool_call.function.arguments), ), @@ -106,13 +106,13 @@ class OllamaChatCompletions(requester.ProviderAPIRequester): async def invoke_llm( self, - query: core_entities.Query, + query: pipeline_query.Query, model: requester.RuntimeLLMModel, - messages: typing.List[llm_entities.Message], - funcs: typing.List[tools_entities.LLMFunction] = None, + messages: typing.List[provider_message.Message], + funcs: typing.List[resource_tool.LLMTool] = None, extra_args: dict[str, typing.Any] = {}, remove_think: bool = False, - ) -> llm_entities.Message: + ) -> provider_message.Message: req_messages: list = [] for m in messages: msg_dict: dict = m.dict(exclude_none=True) @@ -139,8 +139,10 @@ class OllamaChatCompletions(requester.ProviderAPIRequester): input_text: list[str], extra_args: dict[str, typing.Any] = {}, ) -> list[list[float]]: - return (await self.client.embed( - model=model.model_entity.name, - input=input_text, - **extra_args, - )).embeddings + return ( + await self.client.embed( + model=model.model_entity.name, + input=input_text, + **extra_args, + ) + ).embeddings diff --git a/pkg/provider/modelmgr/requesters/ppiochatcmpl.py b/pkg/provider/modelmgr/requesters/ppiochatcmpl.py index 4af1cde0..9658312b 100644 --- a/pkg/provider/modelmgr/requesters/ppiochatcmpl.py +++ b/pkg/provider/modelmgr/requesters/ppiochatcmpl.py @@ -4,12 +4,12 @@ import openai import typing from . import chatcmpl -import openai.types.chat.chat_completion as chat_completion from .. import requester -from ....core import entities as core_entities -from ... import entities as llm_entities -from ...tools import entities as tools_entities +import openai.types.chat.chat_completion as chat_completion import re +import langbot_plugin.api.entities.builtin.provider.message as provider_message +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query +import langbot_plugin.api.entities.builtin.resource.tool as resource_tool class PPIOChatCompletions(chatcmpl.OpenAIChatCompletions): @@ -28,7 +28,7 @@ class PPIOChatCompletions(chatcmpl.OpenAIChatCompletions): self, chat_completion: chat_completion.ChatCompletion, remove_think: bool, - ) -> llm_entities.Message: + ) -> provider_message.Message: chatcmpl_message = chat_completion.choices[0].message.model_dump() # print(chatcmpl_message.keys(), chatcmpl_message.values()) @@ -39,23 +39,23 @@ class PPIOChatCompletions(chatcmpl.OpenAIChatCompletions): reasoning_content = chatcmpl_message['reasoning_content'] if 'reasoning_content' in chatcmpl_message else None # deepseek的reasoner模型 - chatcmpl_message["content"] = await self._process_thinking_content( - chatcmpl_message['content'],reasoning_content,remove_think) + chatcmpl_message['content'] = await self._process_thinking_content( + chatcmpl_message['content'], reasoning_content, remove_think + ) # 移除 reasoning_content 字段,避免传递给 Message if 'reasoning_content' in chatcmpl_message: del chatcmpl_message['reasoning_content'] - - message = llm_entities.Message(**chatcmpl_message) + message = provider_message.Message(**chatcmpl_message) return message async def _process_thinking_content( - self, - content: str, - reasoning_content: str = None, - remove_think: bool = False, + self, + content: str, + reasoning_content: str = None, + remove_think: bool = False, ) -> tuple[str, str]: """处理思维链内容 @@ -68,21 +68,17 @@ class PPIOChatCompletions(chatcmpl.OpenAIChatCompletions): 处理后的内容 """ if remove_think: - content = re.sub( - r'.*?', '', content, flags=re.DOTALL - ) + content = re.sub(r'.*?', '', content, flags=re.DOTALL) else: if reasoning_content is not None: - content = ( - '\n' + reasoning_content + '\n\n' + content - ) + content = '\n' + reasoning_content + '\n\n' + content return content async def _make_msg_chunk( - self, - delta: dict[str, typing.Any], - idx: int, - ) -> llm_entities.MessageChunk: + self, + delta: dict[str, typing.Any], + idx: int, + ) -> provider_message.MessageChunk: # 处理流式chunk和完整响应的差异 # print(chat_completion.choices[0]) @@ -100,19 +96,19 @@ class PPIOChatCompletions(chatcmpl.OpenAIChatCompletions): if reasoning_content is not None: delta['content'] += reasoning_content - message = llm_entities.MessageChunk(**delta) + message = provider_message.MessageChunk(**delta) return message async def _closure_stream( self, - query: core_entities.Query, + query: pipeline_query.Query, req_messages: list[dict], use_model: requester.RuntimeLLMModel, - use_funcs: list[tools_entities.LLMFunction] = None, + use_funcs: list[resource_tool.LLMTool] = None, extra_args: dict[str, typing.Any] = {}, remove_think: bool = False, - ) -> llm_entities.Message | typing.AsyncGenerator[llm_entities.MessageChunk, None]: + ) -> provider_message.Message | typing.AsyncGenerator[provider_message.MessageChunk, None]: self.client.api_key = use_model.token_mgr.get_token() args = {} @@ -139,7 +135,7 @@ class PPIOChatCompletions(chatcmpl.OpenAIChatCompletions): args['messages'] = messages args['stream'] = True - tool_calls_map: dict[str, llm_entities.ToolCall] = {} + # tool_calls_map: dict[str, provider_message.ToolCall] = {} chunk_idx = 0 thinking_started = False thinking_ended = False @@ -176,8 +172,7 @@ class PPIOChatCompletions(chatcmpl.OpenAIChatCompletions): elif thinking_started and not thinking_ended: continue - - delta_tool_calls = None + # delta_tool_calls = None if delta.get('tool_calls'): for tool_call in delta['tool_calls']: if tool_call['id'] and tool_call['function']['name']: @@ -194,7 +189,7 @@ class PPIOChatCompletions(chatcmpl.OpenAIChatCompletions): tool_call['type'] = 'function' # 跳过空的第一个 chunk(只有 role 没有内容) - if chunk_idx == 0 and not delta_content and not delta.get('tool_calls'): + if chunk_idx == 0 and not delta_content and not delta.get('tool_calls'): chunk_idx += 1 continue @@ -209,5 +204,5 @@ class PPIOChatCompletions(chatcmpl.OpenAIChatCompletions): # 移除 None 值 chunk_data = {k: v for k, v in chunk_data.items() if v is not None} - yield llm_entities.MessageChunk(**chunk_data) + yield provider_message.MessageChunk(**chunk_data) chunk_idx += 1 diff --git a/pkg/provider/runner.py b/pkg/provider/runner.py index a74a2dc5..1fff4a76 100644 --- a/pkg/provider/runner.py +++ b/pkg/provider/runner.py @@ -3,8 +3,9 @@ from __future__ import annotations import abc import typing -from ..core import app, entities as core_entities -from . import entities as llm_entities +from ..core import app +import langbot_plugin.api.entities.builtin.provider.message as provider_message +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query preregistered_runners: list[typing.Type[RequestRunner]] = [] @@ -35,6 +36,6 @@ class RequestRunner(abc.ABC): self.pipeline_config = pipeline_config @abc.abstractmethod - async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: + async def run(self, query: pipeline_query.Query) -> typing.AsyncGenerator[provider_message.Message, None]: """运行请求""" pass diff --git a/pkg/provider/runners/dashscopeapi.py b/pkg/provider/runners/dashscopeapi.py index 737bc312..d9cf25a5 100644 --- a/pkg/provider/runners/dashscopeapi.py +++ b/pkg/provider/runners/dashscopeapi.py @@ -6,8 +6,9 @@ import re import dashscope from .. import runner -from ...core import app, entities as core_entities -from .. import entities as llm_entities +from ...core import app +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query +import langbot_plugin.api.entities.builtin.provider.message as provider_message class DashscopeAPIError(Exception): @@ -65,7 +66,7 @@ class DashScopeAPIRunner(runner.RequestRunner): # 使用 re.sub() 进行替换 return pattern.sub(replacement, text) - async def _preprocess_user_message(self, query: core_entities.Query) -> tuple[str, list[str]]: + async def _preprocess_user_message(self, query: pipeline_query.Query) -> tuple[str, list[str]]: """预处理用户消息,提取纯文本,阿里云提供的上传文件方法过于复杂,暂不支持上传文件(包括图片)""" plain_text = '' image_ids = [] @@ -89,7 +90,9 @@ class DashScopeAPIRunner(runner.RequestRunner): return plain_text, image_ids - async def _agent_messages(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: + async def _agent_messages( + self, query: pipeline_query.Query + ) -> typing.AsyncGenerator[provider_message.Message, None]: """Dashscope 智能体对话请求""" # 局部变量 @@ -103,7 +106,7 @@ class DashScopeAPIRunner(runner.RequestRunner): think_end = False plain_text, image_ids = await self._preprocess_user_message(query) - has_thoughts = True # 获取思考过程 + has_thoughts = True # 获取思考过程 remove_think = self.pipeline_config['output'].get('misc', '').get('remove-think') if remove_think: has_thoughts = False @@ -141,13 +144,13 @@ class DashScopeAPIRunner(runner.RequestRunner): if stream_think[0].get('thought'): if not think_start: think_start = True - pending_content += f"\n{stream_think[0].get('thought')}" + pending_content += f'\n{stream_think[0].get("thought")}' else: # 继续输出 reasoning_content pending_content += stream_think[0].get('thought') - elif stream_think[0].get('thought') == "" and not think_end: + elif stream_think[0].get('thought') == '' and not think_end: think_end = True - pending_content += "\n\n" + pending_content += '\n\n' if stream_output.get('text') is not None: pending_content += stream_output.get('text') # 是否是流式最后一个chunk @@ -166,7 +169,7 @@ class DashScopeAPIRunner(runner.RequestRunner): pending_content = self._replace_references(pending_content, references_dict) if idx_chunk % 8 == 0 or is_final: - yield llm_entities.MessageChunk( + yield provider_message.MessageChunk( role='assistant', content=pending_content, is_final=is_final, @@ -188,13 +191,13 @@ class DashScopeAPIRunner(runner.RequestRunner): if stream_think[0].get('thought'): if not think_start: think_start = True - pending_content += f"\n{stream_think[0].get('thought')}" + pending_content += f'\n{stream_think[0].get("thought")}' else: # 继续输出 reasoning_content pending_content += stream_think[0].get('thought') - elif stream_think[0].get('thought') == "" and not think_end: + elif stream_think[0].get('thought') == '' and not think_end: think_end = True - pending_content += "\n\n" + pending_content += '\n\n' if stream_output.get('text') is not None: pending_content += stream_output.get('text') @@ -213,12 +216,14 @@ class DashScopeAPIRunner(runner.RequestRunner): # 将参考资料替换到文本中 pending_content = self._replace_references(pending_content, references_dict) - yield llm_entities.Message( + yield provider_message.Message( role='assistant', content=pending_content, ) - async def _workflow_messages(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: + async def _workflow_messages( + self, query: pipeline_query.Query + ) -> typing.AsyncGenerator[provider_message.Message, None]: """Dashscope 工作流对话请求""" # 局部变量 @@ -242,7 +247,7 @@ class DashScopeAPIRunner(runner.RequestRunner): incremental_output=True, # 增量输出,使用流式输出需要开启增量输出 session_id=query.session.using_conversation.uuid, # 会话ID用于,多轮对话 biz_params=biz_params, # 工作流应用的自定义输入参数传递 - flow_stream_mode="message_format" # 消息模式,输出/结束节点的流式结果 + flow_stream_mode='message_format', # 消息模式,输出/结束节点的流式结果 # rag_options={ # 主要用于文件交互,暂不支持 # "session_file_ids": ["FILE_ID1"], # FILE_ID1 替换为实际的临时文件ID,逗号隔开多个 # } @@ -267,7 +272,7 @@ class DashScopeAPIRunner(runner.RequestRunner): # 获取流式传输的output stream_output = chunk.get('output', {}) if stream_output.get('workflow_message') is not None: - pending_content += stream_output.get('workflow_message').get('message').get('content') + pending_content += stream_output.get('workflow_message').get('message').get('content') # if stream_output.get('text') is not None: # pending_content += stream_output.get('text') @@ -285,7 +290,7 @@ class DashScopeAPIRunner(runner.RequestRunner): # 将参考资料替换到文本中 pending_content = self._replace_references(pending_content, references_dict) if idx_chunk % 8 == 0 or is_final: - yield llm_entities.MessageChunk( + yield provider_message.MessageChunk( role='assistant', content=pending_content, is_final=is_final, @@ -325,23 +330,23 @@ class DashScopeAPIRunner(runner.RequestRunner): # 将参考资料替换到文本中 pending_content = self._replace_references(pending_content, references_dict) - yield llm_entities.Message( + yield provider_message.Message( role='assistant', content=pending_content, ) - async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: + async def run(self, query: pipeline_query.Query) -> typing.AsyncGenerator[provider_message.Message, None]: """运行""" msg_seq = 0 if self.app_type == 'agent': async for msg in self._agent_messages(query): - if isinstance(msg, llm_entities.MessageChunk): + if isinstance(msg, provider_message.MessageChunk): msg_seq += 1 msg.msg_sequence = msg_seq yield msg elif self.app_type == 'workflow': async for msg in self._workflow_messages(query): - if isinstance(msg, llm_entities.MessageChunk): + if isinstance(msg, provider_message.MessageChunk): msg_seq += 1 msg.msg_sequence = msg_seq yield msg diff --git a/pkg/provider/runners/difysvapi.py b/pkg/provider/runners/difysvapi.py index 9eb14a6c..b98a9bc3 100644 --- a/pkg/provider/runners/difysvapi.py +++ b/pkg/provider/runners/difysvapi.py @@ -7,10 +7,10 @@ import base64 from .. import runner -from ...core import app, entities as core_entities -from .. import entities as llm_entities +from ...core import app +import langbot_plugin.api.entities.builtin.provider.message as provider_message from ...utils import image - +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query from libs.dify_service_api.v1 import client, errors @@ -70,7 +70,7 @@ class DifyServiceAPIRunner(runner.RequestRunner): content = f'\n{thinking_content}\n\n{content}'.strip() return content, thinking_content - async def _preprocess_user_message(self, query: core_entities.Query) -> tuple[str, list[str]]: + async def _preprocess_user_message(self, query: pipeline_query.Query) -> tuple[str, list[str]]: """预处理用户消息,提取纯文本,并将图片上传到 Dify 服务 Returns: @@ -98,7 +98,9 @@ class DifyServiceAPIRunner(runner.RequestRunner): return plain_text, image_ids - async def _chat_messages(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: + async def _chat_messages( + self, query: pipeline_query.Query + ) -> typing.AsyncGenerator[provider_message.Message, None]: """调用聊天助手""" cov_id = query.session.using_conversation.uuid or '' query.variables['conversation_id'] = cov_id @@ -142,7 +144,7 @@ class DifyServiceAPIRunner(runner.RequestRunner): if chunk['data']['node_type'] == 'answer': content, _ = self._process_thinking_content(chunk['data']['outputs']['answer']) - yield llm_entities.Message( + yield provider_message.Message( role='assistant', content=content, ) @@ -151,7 +153,7 @@ class DifyServiceAPIRunner(runner.RequestRunner): basic_mode_pending_chunk += chunk['answer'] elif chunk['event'] == 'message_end': content, _ = self._process_thinking_content(basic_mode_pending_chunk) - yield llm_entities.Message( + yield provider_message.Message( role='assistant', content=content, ) @@ -163,8 +165,8 @@ class DifyServiceAPIRunner(runner.RequestRunner): query.session.using_conversation.uuid = chunk['conversation_id'] async def _agent_chat_messages( - self, query: core_entities.Query - ) -> typing.AsyncGenerator[llm_entities.Message, None]: + self, query: pipeline_query.Query + ) -> typing.AsyncGenerator[provider_message.Message, None]: """调用聊天助手""" cov_id = query.session.using_conversation.uuid or '' query.variables['conversation_id'] = cov_id @@ -210,7 +212,7 @@ class DifyServiceAPIRunner(runner.RequestRunner): if pending_agent_message.strip() != '': pending_agent_message = pending_agent_message.replace('Action:', '') content, _ = self._process_thinking_content(pending_agent_message) - yield llm_entities.Message( + yield provider_message.Message( role='assistant', content=content, ) @@ -221,13 +223,13 @@ class DifyServiceAPIRunner(runner.RequestRunner): continue if chunk['tool']: - msg = llm_entities.Message( + msg = provider_message.Message( role='assistant', tool_calls=[ - llm_entities.ToolCall( + provider_message.ToolCall( id=chunk['id'], type='function', - function=llm_entities.FunctionCall( + function=provider_message.FunctionCall( name=chunk['tool'], arguments=json.dumps({}), ), @@ -244,9 +246,9 @@ class DifyServiceAPIRunner(runner.RequestRunner): image_url = base_url + chunk['url'] - yield llm_entities.Message( + yield provider_message.Message( role='assistant', - content=[llm_entities.ContentElement.from_image_url(image_url)], + content=[provider_message.ContentElement.from_image_url(image_url)], ) if chunk['event'] == 'error': raise errors.DifyAPIError('dify 服务错误: ' + chunk['message']) @@ -256,7 +258,9 @@ class DifyServiceAPIRunner(runner.RequestRunner): query.session.using_conversation.uuid = chunk['conversation_id'] - async def _workflow_messages(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: + async def _workflow_messages( + self, query: pipeline_query.Query + ) -> typing.AsyncGenerator[provider_message.Message, None]: """调用工作流""" if not query.session.using_conversation.uuid: @@ -300,14 +304,14 @@ class DifyServiceAPIRunner(runner.RequestRunner): if chunk['data']['node_type'] == 'start' or chunk['data']['node_type'] == 'end': continue - msg = llm_entities.Message( + msg = provider_message.Message( role='assistant', content=None, tool_calls=[ - llm_entities.ToolCall( + provider_message.ToolCall( id=chunk['data']['node_id'], type='function', - function=llm_entities.FunctionCall( + function=provider_message.FunctionCall( name=chunk['data']['title'], arguments=json.dumps({}), ), @@ -322,7 +326,7 @@ class DifyServiceAPIRunner(runner.RequestRunner): raise errors.DifyAPIError(chunk['data']['error']) content, _ = self._process_thinking_content(chunk['data']['outputs']['summary']) - msg = llm_entities.Message( + msg = provider_message.Message( role='assistant', content=content, ) @@ -330,8 +334,8 @@ class DifyServiceAPIRunner(runner.RequestRunner): yield msg async def _chat_messages_chunk( - self, query: core_entities.Query - ) -> typing.AsyncGenerator[llm_entities.MessageChunk, None]: + self, query: pipeline_query.Query + ) -> typing.AsyncGenerator[provider_message.MessageChunk, None]: """调用聊天助手""" cov_id = query.session.using_conversation.uuid or '' query.variables['conversation_id'] = cov_id @@ -402,7 +406,7 @@ class DifyServiceAPIRunner(runner.RequestRunner): if is_final or message_idx % 8 == 0: # content, _ = self._process_thinking_content(basic_mode_pending_chunk) - yield llm_entities.MessageChunk( + yield provider_message.MessageChunk( role='assistant', content=basic_mode_pending_chunk, is_final=is_final, @@ -414,8 +418,8 @@ class DifyServiceAPIRunner(runner.RequestRunner): query.session.using_conversation.uuid = chunk['conversation_id'] async def _agent_chat_messages_chunk( - self, query: core_entities.Query - ) -> typing.AsyncGenerator[llm_entities.MessageChunk, None]: + self, query: pipeline_query.Query + ) -> typing.AsyncGenerator[provider_message.MessageChunk, None]: """调用聊天助手""" cov_id = query.session.using_conversation.uuid or '' query.variables['conversation_id'] = cov_id @@ -488,13 +492,13 @@ class DifyServiceAPIRunner(runner.RequestRunner): continue message_idx += 1 if chunk['tool']: - msg = llm_entities.MessageChunk( + msg = provider_message.MessageChunk( role='assistant', tool_calls=[ - llm_entities.ToolCall( + provider_message.ToolCall( id=chunk['id'], type='function', - function=llm_entities.FunctionCall( + function=provider_message.FunctionCall( name=chunk['tool'], arguments=json.dumps({}), ), @@ -512,16 +516,16 @@ class DifyServiceAPIRunner(runner.RequestRunner): image_url = base_url + chunk['url'] - yield llm_entities.MessageChunk( + yield provider_message.MessageChunk( role='assistant', - content=[llm_entities.ContentElement.from_image_url(image_url)], + content=[provider_message.ContentElement.from_image_url(image_url)], is_final=is_final, ) if chunk['event'] == 'error': raise errors.DifyAPIError('dify 服务错误: ' + chunk['message']) if message_idx % 8 == 0 or is_final: - yield llm_entities.MessageChunk( + yield provider_message.MessageChunk( role='assistant', content=pending_agent_message, is_final=is_final, @@ -533,8 +537,8 @@ class DifyServiceAPIRunner(runner.RequestRunner): query.session.using_conversation.uuid = chunk['conversation_id'] async def _workflow_messages_chunk( - self, query: core_entities.Query - ) -> typing.AsyncGenerator[llm_entities.MessageChunk, None]: + self, query: pipeline_query.Query + ) -> typing.AsyncGenerator[provider_message.MessageChunk, None]: """调用工作流""" if not query.session.using_conversation.uuid: @@ -608,14 +612,14 @@ class DifyServiceAPIRunner(runner.RequestRunner): if chunk['data']['node_type'] == 'start' or chunk['data']['node_type'] == 'end': continue messsage_idx += 1 - msg = llm_entities.MessageChunk( + msg = provider_message.MessageChunk( role='assistant', content=None, tool_calls=[ - llm_entities.ToolCall( + provider_message.ToolCall( id=chunk['data']['node_id'], type='function', - function=llm_entities.FunctionCall( + function=provider_message.FunctionCall( name=chunk['data']['title'], arguments=json.dumps({}), ), @@ -626,13 +630,13 @@ class DifyServiceAPIRunner(runner.RequestRunner): yield msg if messsage_idx % 8 == 0 or is_final: - yield llm_entities.MessageChunk( + yield provider_message.MessageChunk( role='assistant', content=workflow_contents, is_final=is_final, ) - async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: + async def run(self, query: pipeline_query.Query) -> typing.AsyncGenerator[provider_message.Message, None]: """运行请求""" if await query.adapter.is_stream_output_supported(): msg_idx = 0 diff --git a/pkg/provider/runners/localagent.py b/pkg/provider/runners/localagent.py index 2500b363..7ab1e739 100644 --- a/pkg/provider/runners/localagent.py +++ b/pkg/provider/runners/localagent.py @@ -4,8 +4,8 @@ import json import copy import typing from .. import runner -from ...core import entities as core_entities -from .. import entities as llm_entities +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query +import langbot_plugin.api.entities.builtin.provider.message as provider_message rag_combined_prompt_template = """ @@ -32,11 +32,11 @@ class LocalAgentRunner(runner.RequestRunner): def __init__(self): self.active_calls: dict[str, dict] = {} - self.completed_calls: list[llm_entities.ToolCall] = [] + self.completed_calls: list[provider_message.ToolCall] = [] async def run( - self, query: core_entities.Query - ) -> typing.AsyncGenerator[llm_entities.Message | llm_entities.MessageChunk, None]: + self, query: pipeline_query.Query + ) -> typing.AsyncGenerator[provider_message.Message | provider_message.MessageChunk, None]: """运行请求""" pending_tool_calls = [] @@ -94,34 +94,36 @@ class LocalAgentRunner(runner.RequestRunner): except AttributeError: is_stream = False - remove_think = self.pipeline_config['output'].get('misc', '').get('remove-think') + remove_think = query.pipeline_config['output'].get('misc', '').get('remove-think') + + use_llm_model = await self.ap.model_mgr.get_model_by_uuid(query.use_llm_model_uuid) if not is_stream: # 非流式输出,直接请求 - msg = await query.use_llm_model.requester.invoke_llm( + msg = await use_llm_model.requester.invoke_llm( query, - query.use_llm_model, + use_llm_model, req_messages, query.use_funcs, - extra_args=query.use_llm_model.model_entity.extra_args, + extra_args=use_llm_model.model_entity.extra_args, remove_think=remove_think, ) yield msg final_msg = msg else: # 流式输出,需要处理工具调用 - tool_calls_map: dict[str, llm_entities.ToolCall] = {} + tool_calls_map: dict[str, provider_message.ToolCall] = {} msg_idx = 0 accumulated_content = '' # 从开始累积的所有内容 last_role = 'assistant' msg_sequence = 1 - async for msg in query.use_llm_model.requester.invoke_llm_stream( + async for msg in use_llm_model.requester.invoke_llm_stream( query, - query.use_llm_model, + use_llm_model, req_messages, query.use_funcs, - extra_args=query.use_llm_model.model_entity.extra_args, + extra_args=use_llm_model.model_entity.extra_args, remove_think=remove_think, ): msg_idx = msg_idx + 1 @@ -138,10 +140,10 @@ class LocalAgentRunner(runner.RequestRunner): if msg.tool_calls: for tool_call in msg.tool_calls: if tool_call.id not in tool_calls_map: - tool_calls_map[tool_call.id] = llm_entities.ToolCall( + tool_calls_map[tool_call.id] = provider_message.ToolCall( id=tool_call.id, type=tool_call.type, - function=llm_entities.FunctionCall( + function=provider_message.FunctionCall( name=tool_call.function.name if tool_call.function else '', arguments='' ), ) @@ -152,7 +154,7 @@ class LocalAgentRunner(runner.RequestRunner): # 每8个chunk或最后一个chunk时,输出所有累积的内容 if msg_idx % 8 == 0 or msg.is_final: msg_sequence += 1 - yield llm_entities.MessageChunk( + yield provider_message.MessageChunk( role=last_role, content=accumulated_content, # 输出所有累积内容 tool_calls=list(tool_calls_map.values()) if (tool_calls_map and msg.is_final) else None, @@ -161,7 +163,7 @@ class LocalAgentRunner(runner.RequestRunner): ) # 创建最终消息用于后续处理 - final_msg = llm_entities.MessageChunk( + final_msg = provider_message.MessageChunk( role=last_role, content=accumulated_content, tool_calls=list(tool_calls_map.values()) if tool_calls_map else None, @@ -170,8 +172,7 @@ class LocalAgentRunner(runner.RequestRunner): pending_tool_calls = final_msg.tool_calls first_content = final_msg.content - if isinstance(final_msg, llm_entities.MessageChunk): - + if isinstance(final_msg, provider_message.MessageChunk): first_end_sequence = final_msg.msg_sequence req_messages.append(final_msg) @@ -184,15 +185,15 @@ class LocalAgentRunner(runner.RequestRunner): parameters = json.loads(func.arguments) - func_ret = await self.ap.tool_mgr.execute_func_call(query, func.name, parameters) + func_ret = await self.ap.tool_mgr.execute_func_call(func.name, parameters) if is_stream: - msg = llm_entities.MessageChunk( + msg = provider_message.MessageChunk( role='tool', content=json.dumps(func_ret, ensure_ascii=False), tool_call_id=tool_call.id, ) else: - msg = llm_entities.Message( + msg = provider_message.Message( role='tool', content=json.dumps(func_ret, ensure_ascii=False), tool_call_id=tool_call.id, @@ -203,7 +204,7 @@ class LocalAgentRunner(runner.RequestRunner): req_messages.append(msg) except Exception as e: # 工具调用出错,添加一个报错信息到 req_messages - err_msg = llm_entities.Message(role='tool', content=f'err: {e}', tool_call_id=tool_call.id) + err_msg = provider_message.Message(role='tool', content=f'err: {e}', tool_call_id=tool_call.id) yield err_msg @@ -216,12 +217,12 @@ class LocalAgentRunner(runner.RequestRunner): last_role = 'assistant' msg_sequence = first_end_sequence - async for msg in query.use_llm_model.requester.invoke_llm_stream( + async for msg in use_llm_model.requester.invoke_llm_stream( query, - query.use_llm_model, + use_llm_model, req_messages, query.use_funcs, - extra_args=query.use_llm_model.model_entity.extra_args, + extra_args=use_llm_model.model_entity.extra_args, remove_think=remove_think, ): msg_idx += 1 @@ -242,10 +243,10 @@ class LocalAgentRunner(runner.RequestRunner): if msg.tool_calls: for tool_call in msg.tool_calls: if tool_call.id not in tool_calls_map: - tool_calls_map[tool_call.id] = llm_entities.ToolCall( + tool_calls_map[tool_call.id] = provider_message.ToolCall( id=tool_call.id, type=tool_call.type, - function=llm_entities.FunctionCall( + function=provider_message.FunctionCall( name=tool_call.function.name if tool_call.function else '', arguments='' ), ) @@ -256,7 +257,7 @@ class LocalAgentRunner(runner.RequestRunner): # 每8个chunk或最后一个chunk时,输出所有累积的内容 if msg_idx % 8 == 0 or msg.is_final: msg_sequence += 1 - yield llm_entities.MessageChunk( + yield provider_message.MessageChunk( role=last_role, content=accumulated_content, # 输出所有累积内容 tool_calls=list(tool_calls_map.values()) if (tool_calls_map and msg.is_final) else None, @@ -264,21 +265,20 @@ class LocalAgentRunner(runner.RequestRunner): msg_sequence=msg_sequence, ) - final_msg = llm_entities.MessageChunk( + final_msg = provider_message.MessageChunk( role=last_role, content=accumulated_content, tool_calls=list(tool_calls_map.values()) if tool_calls_map else None, msg_sequence=msg_sequence, - ) else: # 处理完所有调用,再次请求 - msg = await query.use_llm_model.requester.invoke_llm( + msg = await use_llm_model.requester.invoke_llm( query, - query.use_llm_model, + use_llm_model, req_messages, query.use_funcs, - extra_args=query.use_llm_model.model_entity.extra_args, + extra_args=use_llm_model.model_entity.extra_args, remove_think=remove_think, ) diff --git a/pkg/provider/runners/n8nsvapi.py b/pkg/provider/runners/n8nsvapi.py index 7044cce1..d2b5aa78 100644 --- a/pkg/provider/runners/n8nsvapi.py +++ b/pkg/provider/runners/n8nsvapi.py @@ -6,8 +6,9 @@ import uuid import aiohttp from .. import runner -from ...core import app, entities as core_entities -from .. import entities as llm_entities +from ...core import app +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query +import langbot_plugin.api.entities.builtin.provider.message as provider_message class N8nAPIError(Exception): @@ -49,7 +50,7 @@ class N8nServiceAPIRunner(runner.RequestRunner): self.header_name = self.pipeline_config['ai']['n8n-service-api'].get('header-name', '') self.header_value = self.pipeline_config['ai']['n8n-service-api'].get('header-value', '') - async def _preprocess_user_message(self, query: core_entities.Query) -> str: + async def _preprocess_user_message(self, query: pipeline_query.Query) -> str: """预处理用户消息,提取纯文本 Returns: @@ -67,7 +68,7 @@ class N8nServiceAPIRunner(runner.RequestRunner): return plain_text - async def _call_webhook(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: + async def _call_webhook(self, query: pipeline_query.Query) -> typing.AsyncGenerator[provider_message.Message, None]: """调用n8n webhook""" # 生成会话ID(如果不存在) if not query.session.using_conversation.uuid: @@ -145,7 +146,7 @@ class N8nServiceAPIRunner(runner.RequestRunner): output_content = json.dumps(response_data, ensure_ascii=False) # 返回消息 - yield llm_entities.Message( + yield provider_message.Message( role='assistant', content=output_content, ) @@ -153,7 +154,7 @@ class N8nServiceAPIRunner(runner.RequestRunner): self.ap.logger.error(f'n8n webhook call exception: {str(e)}') raise N8nAPIError(f'n8n webhook call exception: {str(e)}') - async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: + async def run(self, query: pipeline_query.Query) -> typing.AsyncGenerator[provider_message.Message, None]: """运行请求""" async for msg in self._call_webhook(query): yield msg diff --git a/pkg/provider/session/sessionmgr.py b/pkg/provider/session/sessionmgr.py index f54b50e7..11d0254c 100644 --- a/pkg/provider/session/sessionmgr.py +++ b/pkg/provider/session/sessionmgr.py @@ -2,8 +2,10 @@ from __future__ import annotations import asyncio -from ...core import app, entities as core_entities -from ...provider import entities as provider_entities +from ...core import app +from langbot_plugin.api.entities.builtin.provider import message as provider_message, prompt as provider_prompt +import langbot_plugin.api.entities.builtin.provider.session as provider_session +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query class SessionManager: @@ -11,7 +13,7 @@ class SessionManager: ap: app.Application - session_list: list[core_entities.Session] + session_list: list[provider_session.Session] def __init__(self, ap: app.Application): self.ap = ap @@ -20,7 +22,7 @@ class SessionManager: async def initialize(self): pass - async def get_session(self, query: core_entities.Query) -> core_entities.Session: + async def get_session(self, query: pipeline_query.Query) -> provider_session.Session: """获取会话""" for session in self.session_list: if query.launcher_type == session.launcher_type and query.launcher_id == session.launcher_id: @@ -28,22 +30,22 @@ class SessionManager: session_concurrency = self.ap.instance_config.data['concurrency']['session'] - session = core_entities.Session( + session = provider_session.Session( launcher_type=query.launcher_type, launcher_id=query.launcher_id, - semaphore=asyncio.Semaphore(session_concurrency), ) + session._semaphore = asyncio.Semaphore(session_concurrency) self.session_list.append(session) return session async def get_conversation( self, - query: core_entities.Query, - session: core_entities.Session, + query: pipeline_query.Query, + session: provider_session.Session, prompt_config: list[dict], pipeline_uuid: str, bot_uuid: str, - ) -> core_entities.Conversation: + ) -> provider_session.Conversation: """获取对话或创建对话""" if not session.conversations: @@ -53,20 +55,17 @@ class SessionManager: prompt_messages = [] for prompt_message in prompt_config: - prompt_messages.append(provider_entities.Message(**prompt_message)) + prompt_messages.append(provider_message.Message(**prompt_message)) - prompt = provider_entities.Prompt( + prompt = provider_prompt.Prompt( name='default', messages=prompt_messages, ) if session.using_conversation is None or session.using_conversation.pipeline_uuid != pipeline_uuid: - conversation = core_entities.Conversation( + conversation = provider_session.Conversation( prompt=prompt, messages=[], - use_funcs=await self.ap.tool_mgr.get_all_functions( - plugin_enabled=True, - ), pipeline_uuid=pipeline_uuid, bot_uuid=bot_uuid, ) diff --git a/pkg/provider/tools/entities.py b/pkg/provider/tools/entities.py deleted file mode 100644 index 102e03d3..00000000 --- a/pkg/provider/tools/entities.py +++ /dev/null @@ -1,31 +0,0 @@ -from __future__ import annotations - -import typing - -import pydantic.v1 as pydantic - - -class LLMFunction(pydantic.BaseModel): - """函数""" - - name: str - """函数名""" - - human_desc: str - - description: str - """给LLM识别的函数描述""" - - parameters: dict - - func: typing.Callable - """供调用的python异步方法 - - 此异步方法第一个参数接收当前请求的query对象,可以从其中取出session等信息。 - query参数不在parameters中,但在调用时会自动传入。 - 但在当前版本中,插件提供的内容函数都是同步的,且均为请求无关的,故在此版本的实现(以及考虑了向后兼容性的版本)中, - 对插件的内容函数进行封装并存到这里来。 - """ - - class Config: - arbitrary_types_allowed = True diff --git a/pkg/provider/tools/loader.py b/pkg/provider/tools/loader.py index 76b7d248..f3d65fd2 100644 --- a/pkg/provider/tools/loader.py +++ b/pkg/provider/tools/loader.py @@ -3,8 +3,8 @@ from __future__ import annotations import abc import typing -from ...core import app, entities as core_entities -from . import entities as tools_entities +from ...core import app +import langbot_plugin.api.entities.builtin.resource.tool as resource_tool preregistered_loaders: list[typing.Type[ToolLoader]] = [] @@ -35,7 +35,7 @@ class ToolLoader(abc.ABC): pass @abc.abstractmethod - async def get_tools(self, enabled: bool = True) -> list[tools_entities.LLMFunction]: + async def get_tools(self) -> list[resource_tool.LLMTool]: """获取所有工具""" pass @@ -45,7 +45,7 @@ class ToolLoader(abc.ABC): pass @abc.abstractmethod - async def invoke_tool(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any: + async def invoke_tool(self, name: str, parameters: dict) -> typing.Any: """执行工具调用""" pass diff --git a/pkg/provider/tools/loaders/mcp.py b/pkg/provider/tools/loaders/mcp.py index f3223f42..36fa9751 100644 --- a/pkg/provider/tools/loaders/mcp.py +++ b/pkg/provider/tools/loaders/mcp.py @@ -7,8 +7,9 @@ from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client from mcp.client.sse import sse_client -from .. import loader, entities as tools_entities -from ....core import app, entities as core_entities +from .. import loader +from ....core import app +import langbot_plugin.api.entities.builtin.resource.tool as resource_tool class RuntimeMCPSession: @@ -24,7 +25,7 @@ class RuntimeMCPSession: exit_stack: AsyncExitStack - functions: list[tools_entities.LLMFunction] = [] + functions: list[resource_tool.LLMTool] = [] def __init__(self, server_name: str, server_config: dict, ap: app.Application): self.server_name = server_name @@ -82,7 +83,7 @@ class RuntimeMCPSession: for tool in tools.tools: - async def func(query: core_entities.Query, *, _tool=tool, **kwargs): + async def func(*, _tool=tool, **kwargs): result = await self.session.call_tool(_tool.name, kwargs) if result.isError: raise Exception(result.content[0].text) @@ -91,7 +92,7 @@ class RuntimeMCPSession: func.__name__ = tool.name self.functions.append( - tools_entities.LLMFunction( + resource_tool.LLMTool( name=tool.name, human_desc=tool.description, description=tool.description, @@ -114,7 +115,7 @@ class MCPLoader(loader.ToolLoader): sessions: dict[str, RuntimeMCPSession] = {} - _last_listed_functions: list[tools_entities.LLMFunction] = [] + _last_listed_functions: list[resource_tool.LLMTool] = [] def __init__(self, ap: app.Application): super().__init__(ap) @@ -130,7 +131,7 @@ class MCPLoader(loader.ToolLoader): # self.ap.event_loop.create_task(session.initialize()) self.sessions[server_config['name']] = session - async def get_tools(self, enabled: bool = True) -> list[tools_entities.LLMFunction]: + async def get_tools(self) -> list[resource_tool.LLMTool]: all_functions = [] for session in self.sessions.values(): @@ -143,11 +144,11 @@ class MCPLoader(loader.ToolLoader): async def has_tool(self, name: str) -> bool: return name in [f.name for f in self._last_listed_functions] - async def invoke_tool(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any: + async def invoke_tool(self, name: str, parameters: dict) -> typing.Any: for server_name, session in self.sessions.items(): for function in session.functions: if function.name == name: - return await function.func(query, **parameters) + return await function.func(**parameters) raise ValueError(f'未找到工具: {name}') diff --git a/pkg/provider/tools/loaders/plugin.py b/pkg/provider/tools/loaders/plugin.py index b7df2d67..94296470 100644 --- a/pkg/provider/tools/loaders/plugin.py +++ b/pkg/provider/tools/loaders/plugin.py @@ -3,9 +3,8 @@ from __future__ import annotations import typing import traceback -from .. import loader, entities as tools_entities -from ....core import entities as core_entities -from ....plugin import context as plugin_context +from .. import loader +import langbot_plugin.api.entities.builtin.resource.tool as resource_tool @loader.loader_class('plugin-tool-loader') @@ -15,63 +14,42 @@ class PluginToolLoader(loader.ToolLoader): 本加载器中不存储工具信息,仅负责从插件系统中获取工具信息。 """ - async def get_tools(self, enabled: bool = True) -> list[tools_entities.LLMFunction]: + async def get_tools(self) -> list[resource_tool.LLMTool]: # 从插件系统获取工具(内容函数) - all_functions: list[tools_entities.LLMFunction] = [] + all_functions: list[resource_tool.LLMTool] = [] - for plugin in self.ap.plugin_mgr.plugins( - enabled=enabled, status=plugin_context.RuntimeContainerStatus.INITIALIZED - ): - all_functions.extend(plugin.tools) + for tool in await self.ap.plugin_connector.list_tools(): + tool_obj = resource_tool.LLMTool( + name=tool.metadata.name, + human_desc=tool.metadata.description.en_US, + description=tool.spec['llm_prompt'], + parameters=tool.spec['parameters'], + func=lambda parameters: {}, + ) + all_functions.append(tool_obj) return all_functions async def has_tool(self, name: str) -> bool: """检查工具是否存在""" - for plugin in self.ap.plugin_mgr.plugins( - enabled=True, status=plugin_context.RuntimeContainerStatus.INITIALIZED - ): - for function in plugin.tools: - if function.name == name: - return True + for tool in await self.ap.plugin_connector.list_tools(): + if tool.metadata.name == name: + return True return False - async def _get_function_and_plugin( - self, name: str - ) -> typing.Tuple[tools_entities.LLMFunction, plugin_context.BasePlugin]: - """获取函数和插件实例""" - for plugin in self.ap.plugin_mgr.plugins( - enabled=True, status=plugin_context.RuntimeContainerStatus.INITIALIZED - ): - for function in plugin.tools: - if function.name == name: - return function, plugin.plugin_inst - return None, None + async def _get_tool(self, name: str) -> resource_tool.LLMTool: + for tool in await self.ap.plugin_connector.list_tools(): + if tool.metadata.name == name: + return tool + return None - async def invoke_tool(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any: + async def invoke_tool(self, name: str, parameters: dict) -> typing.Any: try: - function, plugin = await self._get_function_and_plugin(name) - if function is None: - return None - - parameters = parameters.copy() - - parameters = {'query': query, **parameters} - - return await function.func(plugin, **parameters) + return await self.ap.plugin_connector.call_tool(name, parameters) except Exception as e: self.ap.logger.error(f'执行函数 {name} 时发生错误: {e}') traceback.print_exc() return f'error occurred when executing function {name}: {e}' - finally: - plugin = None - - for p in self.ap.plugin_mgr.plugins(): - if function in p.tools: - plugin = p - break - - # TODO statistics async def shutdown(self): """关闭工具""" diff --git a/pkg/provider/tools/toolmgr.py b/pkg/provider/tools/toolmgr.py index b1d43d08..43960aba 100644 --- a/pkg/provider/tools/toolmgr.py +++ b/pkg/provider/tools/toolmgr.py @@ -2,10 +2,11 @@ from __future__ import annotations import typing -from ...core import app, entities as core_entities -from . import entities, loader as tools_loader +from ...core import app +from . import loader as tools_loader from ...utils import importutil from . import loaders +import langbot_plugin.api.entities.builtin.resource.tool as resource_tool importutil.import_modules_in_pkg(loaders) @@ -28,16 +29,16 @@ class ToolManager: await loader_inst.initialize() self.loaders.append(loader_inst) - async def get_all_functions(self, plugin_enabled: bool = None) -> list[entities.LLMFunction]: + async def get_all_tools(self) -> list[resource_tool.LLMTool]: """获取所有函数""" - all_functions: list[entities.LLMFunction] = [] + all_functions: list[resource_tool.LLMTool] = [] for loader in self.loaders: - all_functions.extend(await loader.get_tools(plugin_enabled)) + all_functions.extend(await loader.get_tools()) return all_functions - async def generate_tools_for_openai(self, use_funcs: list[entities.LLMFunction]) -> list: + async def generate_tools_for_openai(self, use_funcs: list[resource_tool.LLMTool]) -> list: """生成函数列表""" tools = [] @@ -54,7 +55,7 @@ class ToolManager: return tools - async def generate_tools_for_anthropic(self, use_funcs: list[entities.LLMFunction]) -> list: + async def generate_tools_for_anthropic(self, use_funcs: list[resource_tool.LLMTool]) -> list: """为anthropic生成函数列表 e.g. @@ -89,12 +90,12 @@ class ToolManager: return tools - async def execute_func_call(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any: + async def execute_func_call(self, name: str, parameters: dict) -> typing.Any: """执行函数调用""" for loader in self.loaders: if await loader.has_tool(name): - return await loader.invoke_tool(query, name, parameters) + return await loader.invoke_tool(name, parameters) else: raise ValueError(f'未找到工具: {name}') diff --git a/pkg/utils/announce.py b/pkg/utils/announce.py index 69d19368..8778a04f 100644 --- a/pkg/utils/announce.py +++ b/pkg/utils/announce.py @@ -6,7 +6,7 @@ import os import base64 import logging -import pydantic.v1 as pydantic +import pydantic import requests from ..core import app @@ -59,7 +59,7 @@ class AnnouncementManager: return [Announcement(**item) for item in json.loads(content)] except (requests.RequestException, json.JSONDecodeError, KeyError) as e: - self.ap.logger.warning(f"获取公告失败: {e}") + self.ap.logger.warning(f'获取公告失败: {e}') pass return [] # 请求失败时返回空列表 diff --git a/pkg/utils/constants.py b/pkg/utils/constants.py index 74fe232b..0da3cc9d 100644 --- a/pkg/utils/constants.py +++ b/pkg/utils/constants.py @@ -1,7 +1,7 @@ -semantic_version = 'v4.2.2' +semantic_version = 'v4.3.0.beta3' -required_database_version = 5 -"""Tag the version of the database schema, used to check if the database needs to be migrated""" +required_database_version = 6 +"""标记本版本所需要的数据库结构版本,用于判断数据库迁移""" debug_mode = False diff --git a/pkg/utils/platform.py b/pkg/utils/platform.py index 0d4a1f26..b3f7a6df 100644 --- a/pkg/utils/platform.py +++ b/pkg/utils/platform.py @@ -5,7 +5,18 @@ import sys def get_platform() -> str: """获取当前平台""" # 检查是不是在 docker 里 - if os.path.exists('/.dockerenv'): + + DOCKER_ENV = os.environ.get('DOCKER_ENV', 'false') + + if os.path.exists('/.dockerenv') or DOCKER_ENV == 'true': return 'docker' return sys.platform + + +standalone_runtime = False + + +def use_websocket_to_connect_plugin_runtime() -> bool: + """是否使用 websocket 连接插件运行时""" + return standalone_runtime diff --git a/pyproject.toml b/pyproject.toml index ef1e174b..f934f8e6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "langbot" -version = "4.2.2" +version = "4.3.0.beta4" description = "高稳定、支持扩展、多模态 - 大模型原生即时通信机器人平台" readme = "README.md" requires-python = ">=3.10.1" @@ -50,6 +50,7 @@ dependencies = [ "ruff>=0.11.9", "pre-commit>=4.2.0", "uv>=0.7.11", + "mypy>=1.16.0", "PyPDF2>=3.0.1", "python-docx>=1.1.0", "pandas>=2.2.2", @@ -60,6 +61,7 @@ dependencies = [ "html2text>=2024.2.26", "langchain>=0.2.0", "chromadb>=0.4.24", + "langbot-plugin==0.1.1b6", ] keywords = [ "bot", diff --git a/templates/config.yaml b/templates/config.yaml index 3faa2fd7..a0c22c8a 100644 --- a/templates/config.yaml +++ b/templates/config.yaml @@ -20,3 +20,7 @@ system: jwt: expire: 604800 secret: '' +plugin: + runtime_ws_url: 'ws://langbot_plugin_runtime:5400/control/ws' + enable_marketplace: true + cloud_service_url: 'https://space.langbot.app' diff --git a/web/package.json b/web/package.json index 4749f463..26cb7ca7 100644 --- a/web/package.json +++ b/web/package.json @@ -26,7 +26,7 @@ "@radix-ui/react-checkbox": "^1.3.1", "@radix-ui/react-context-menu": "^2.2.15", "@radix-ui/react-dialog": "^1.1.14", - "@radix-ui/react-dropdown-menu": "^2.1.15", + "@radix-ui/react-dropdown-menu": "^2.1.16", "@radix-ui/react-hover-card": "^1.1.13", "@radix-ui/react-label": "^2.1.6", "@radix-ui/react-popover": "^1.1.14", @@ -56,7 +56,9 @@ "react-dom": "^19.0.0", "react-hook-form": "^7.56.3", "react-i18next": "^15.5.1", + "react-markdown": "^10.1.0", "react-photo-view": "^1.2.7", + "remark-gfm": "^4.0.1", "sonner": "^2.0.3", "tailwind-merge": "^3.2.0", "tailwindcss": "^4.1.5", diff --git a/web/src/app/home/bots/components/bot-form/BotForm.tsx b/web/src/app/home/bots/components/bot-form/BotForm.tsx index 92811a65..f6aa21c0 100644 --- a/web/src/app/home/bots/components/bot-form/BotForm.tsx +++ b/web/src/app/home/bots/components/bot-form/BotForm.tsx @@ -47,7 +47,7 @@ import { SelectValue, } from '@/components/ui/select'; import { Switch } from '@/components/ui/switch'; -import { i18nObj } from '@/i18n/I18nProvider'; +import { extractI18nObject } from '@/i18n/I18nProvider'; const getFormSchema = (t: (key: string) => string) => z.object({ @@ -162,7 +162,7 @@ export default function BotForm({ setAdapterNameList( adaptersRes.adapters.map((item) => { return { - label: i18nObj(item.label), + label: extractI18nObject(item.label), value: item.name, }; }), @@ -183,7 +183,7 @@ export default function BotForm({ setAdapterDescriptionList( adaptersRes.adapters.reduce( (acc, item) => { - acc[item.name] = i18nObj(item.description); + acc[item.name] = extractI18nObject(item.description); return acc; }, {} as Record, diff --git a/web/src/app/home/bots/page.tsx b/web/src/app/home/bots/page.tsx index ad130fae..59dc83c3 100644 --- a/web/src/app/home/bots/page.tsx +++ b/web/src/app/home/bots/page.tsx @@ -9,7 +9,7 @@ import { httpClient } from '@/app/infra/http/HttpClient'; import { Bot, Adapter } from '@/app/infra/entities/api'; import { toast } from 'sonner'; import { useTranslation } from 'react-i18next'; -import { i18nObj } from '@/i18n/I18nProvider'; +import { extractI18nObject } from '@/i18n/I18nProvider'; import BotDetailDialog from '@/app/home/bots/BotDetailDialog'; export default function BotConfigPage() { @@ -27,7 +27,7 @@ export default function BotConfigPage() { const adapterListResp = await httpClient.getAdapters(); const adapterList = adapterListResp.adapters.map((adapter: Adapter) => { return { - label: i18nObj(adapter.label), + label: extractI18nObject(adapter.label), value: adapter.name, }; }); diff --git a/web/src/app/home/components/dynamic-form/DynamicFormComponent.tsx b/web/src/app/home/components/dynamic-form/DynamicFormComponent.tsx index de040db3..6d13e99c 100644 --- a/web/src/app/home/components/dynamic-form/DynamicFormComponent.tsx +++ b/web/src/app/home/components/dynamic-form/DynamicFormComponent.tsx @@ -12,7 +12,7 @@ import { } from '@/components/ui/form'; import DynamicFormItemComponent from '@/app/home/components/dynamic-form/DynamicFormItemComponent'; import { useEffect } from 'react'; -import { i18nObj } from '@/i18n/I18nProvider'; +import { extractI18nObject } from '@/i18n/I18nProvider'; export default function DynamicFormComponent({ itemConfigList, @@ -145,7 +145,7 @@ export default function DynamicFormComponent({ render={({ field }) => ( - {i18nObj(config.label)}{' '} + {extractI18nObject(config.label)}{' '} {config.required && *} @@ -153,7 +153,7 @@ export default function DynamicFormComponent({ {config.description && (

- {i18nObj(config.description)} + {extractI18nObject(config.description)}

)} diff --git a/web/src/app/home/components/dynamic-form/DynamicFormItemComponent.tsx b/web/src/app/home/components/dynamic-form/DynamicFormItemComponent.tsx index a762f7d0..60062c22 100644 --- a/web/src/app/home/components/dynamic-form/DynamicFormItemComponent.tsx +++ b/web/src/app/home/components/dynamic-form/DynamicFormItemComponent.tsx @@ -25,7 +25,7 @@ import { HoverCardTrigger, } from '@/components/ui/hover-card'; import { useTranslation } from 'react-i18next'; -import { i18nObj } from '@/i18n/I18nProvider'; +import { extractI18nObject } from '@/i18n/I18nProvider'; import { Textarea } from '@/components/ui/textarea'; export default function DynamicFormItemComponent({ @@ -140,7 +140,7 @@ export default function DynamicFormItemComponent({ {config.options?.map((option) => ( - {i18nObj(option.label)} + {extractI18nObject(option.label)} ))} diff --git a/web/src/app/home/components/dynamic-form/DynamicFormItemConfig.ts b/web/src/app/home/components/dynamic-form/DynamicFormItemConfig.ts index 74fd4a0b..6b52ece0 100644 --- a/web/src/app/home/components/dynamic-form/DynamicFormItemConfig.ts +++ b/web/src/app/home/components/dynamic-form/DynamicFormItemConfig.ts @@ -3,16 +3,16 @@ import { DynamicFormItemType, IDynamicFormItemOption, } from '@/app/infra/entities/form/dynamic'; -import { I18nLabel } from '@/app/infra/entities/common'; +import { I18nObject } from '@/app/infra/entities/common'; export class DynamicFormItemConfig implements IDynamicFormItemSchema { id: string; name: string; default: string | number | boolean | Array; - label: I18nLabel; + label: I18nObject; required: boolean; type: DynamicFormItemType; - description?: I18nLabel; + description?: I18nObject; options?: IDynamicFormItemOption[]; constructor(params: IDynamicFormItemSchema) { diff --git a/web/src/app/home/components/dynamic-form/N8nAuthFormComponent.tsx b/web/src/app/home/components/dynamic-form/N8nAuthFormComponent.tsx index 1c71befc..5605cae2 100644 --- a/web/src/app/home/components/dynamic-form/N8nAuthFormComponent.tsx +++ b/web/src/app/home/components/dynamic-form/N8nAuthFormComponent.tsx @@ -12,7 +12,7 @@ import { } from '@/components/ui/form'; import { IDynamicFormItemSchema } from '@/app/infra/entities/form/dynamic'; import DynamicFormItemComponent from '@/app/home/components/dynamic-form/DynamicFormItemComponent'; -import { i18nObj } from '@/i18n/I18nProvider'; +import { extractI18nObject } from '@/i18n/I18nProvider'; /** * N8n认证表单组件 @@ -182,7 +182,7 @@ export default function N8nAuthFormComponent({ render={({ field }) => ( - {i18nObj(config.label)}{' '} + {extractI18nObject(config.label)}{' '} {config.required && *} @@ -190,7 +190,7 @@ export default function N8nAuthFormComponent({ {config.description && (

- {i18nObj(config.description)} + {extractI18nObject(config.description)}

)} diff --git a/web/src/app/home/components/home-sidebar/HomeSidebar.tsx b/web/src/app/home/components/home-sidebar/HomeSidebar.tsx index 4009c77a..b9489668 100644 --- a/web/src/app/home/components/home-sidebar/HomeSidebar.tsx +++ b/web/src/app/home/components/home-sidebar/HomeSidebar.tsx @@ -9,7 +9,8 @@ import { import { useRouter, usePathname } from 'next/navigation'; import { sidebarConfigList } from '@/app/home/components/home-sidebar/sidbarConfigList'; import langbotIcon from '@/app/assets/langbot-logo.webp'; -import { systemInfo, spaceClient } from '@/app/infra/http/HttpClient'; +import { systemInfo } from '@/app/infra/http/HttpClient'; +import { getCloudServiceClientSync } from '@/app/infra/http'; import { useTranslation } from 'react-i18next'; import { Moon, Sun, Monitor } from 'lucide-react'; import { useTheme } from 'next-themes'; @@ -54,7 +55,7 @@ export default function HomeSidebar({ localStorage.setItem('userEmail', 'test@example.com'); } - spaceClient + getCloudServiceClientSync() .get('/api/v1/dist/info/repo') .then((response) => { const data = response as { repo: { stargazers_count: number } }; diff --git a/web/src/app/home/components/home-sidebar/HomeSidebarChild.tsx b/web/src/app/home/components/home-sidebar/HomeSidebarChild.tsx index 8529d410..031bc8db 100644 --- a/web/src/app/home/components/home-sidebar/HomeSidebarChild.tsx +++ b/web/src/app/home/components/home-sidebar/HomeSidebarChild.tsx @@ -1,5 +1,5 @@ import styles from './HomeSidebar.module.css'; -import { I18nLabel } from '@/app/infra/entities/common'; +import { I18nObject } from '@/app/infra/entities/common'; export interface ISidebarChildVO { id: string; @@ -7,7 +7,7 @@ export interface ISidebarChildVO { name: string; route: string; description: string; - helpLink: I18nLabel; + helpLink: I18nObject; } export class SidebarChildVO { @@ -16,7 +16,7 @@ export class SidebarChildVO { name: string; route: string; description: string; - helpLink: I18nLabel; + helpLink: I18nObject; constructor(props: ISidebarChildVO) { this.id = props.id; diff --git a/web/src/app/home/components/home-titlebar/HomeTitleBar.tsx b/web/src/app/home/components/home-titlebar/HomeTitleBar.tsx index 56e849fa..0749b8fe 100644 --- a/web/src/app/home/components/home-titlebar/HomeTitleBar.tsx +++ b/web/src/app/home/components/home-titlebar/HomeTitleBar.tsx @@ -1,6 +1,6 @@ -import { i18nObj } from '@/i18n/I18nProvider'; +import { extractI18nObject } from '@/i18n/I18nProvider'; import styles from './HomeTittleBar.module.css'; -import { I18nLabel } from '@/app/infra/entities/common'; +import { I18nObject } from '@/app/infra/entities/common'; export default function HomeTitleBar({ title, @@ -9,7 +9,7 @@ export default function HomeTitleBar({ }: { title: string; subtitle: string; - helpLink: I18nLabel; + helpLink: I18nObject; }) { return (
@@ -19,7 +19,7 @@ export default function HomeTitleBar({
{ - window.open(i18nObj(helpLink), '_blank'); + window.open(extractI18nObject(helpLink), '_blank'); }} className="cursor-pointer" > diff --git a/web/src/app/home/layout.tsx b/web/src/app/home/layout.tsx index 7dd68b25..d84eb6af 100644 --- a/web/src/app/home/layout.tsx +++ b/web/src/app/home/layout.tsx @@ -5,7 +5,7 @@ import HomeSidebar from '@/app/home/components/home-sidebar/HomeSidebar'; import HomeTitleBar from '@/app/home/components/home-titlebar/HomeTitleBar'; import React, { useState } from 'react'; import { SidebarChildVO } from '@/app/home/components/home-sidebar/HomeSidebarChild'; -import { I18nLabel } from '@/app/infra/entities/common'; +import { I18nObject } from '@/app/infra/entities/common'; export default function HomeLayout({ children, @@ -14,7 +14,7 @@ export default function HomeLayout({ }>) { const [title, setTitle] = useState(''); const [subtitle, setSubtitle] = useState(''); - const [helpLink, setHelpLink] = useState({ + const [helpLink, setHelpLink] = useState({ en_US: '', zh_Hans: '', }); diff --git a/web/src/app/home/models/component/embedding-form/EmbeddingForm.tsx b/web/src/app/home/models/component/embedding-form/EmbeddingForm.tsx index eee02b5a..5acb9ac4 100644 --- a/web/src/app/home/models/component/embedding-form/EmbeddingForm.tsx +++ b/web/src/app/home/models/component/embedding-form/EmbeddingForm.tsx @@ -38,7 +38,7 @@ import { SelectValue, } from '@/components/ui/select'; import { toast } from 'sonner'; -import { i18nObj } from '@/i18n/I18nProvider'; +import { extractI18nObject } from '@/i18n/I18nProvider'; const getExtraArgSchema = (t: (key: string) => string) => z @@ -184,7 +184,7 @@ export default function EmbeddingForm({ setRequesterNameList( requesterNameList.requesters.map((item) => { return { - label: i18nObj(item.label), + label: extractI18nObject(item.label), value: item.name, }; }), diff --git a/web/src/app/home/models/component/llm-form/LLMForm.tsx b/web/src/app/home/models/component/llm-form/LLMForm.tsx index 434e68a4..6d46da6d 100644 --- a/web/src/app/home/models/component/llm-form/LLMForm.tsx +++ b/web/src/app/home/models/component/llm-form/LLMForm.tsx @@ -39,7 +39,7 @@ import { } from '@/components/ui/select'; import { Checkbox } from '@/components/ui/checkbox'; import { toast } from 'sonner'; -import { i18nObj } from '@/i18n/I18nProvider'; +import { extractI18nObject } from '@/i18n/I18nProvider'; const getExtraArgSchema = (t: (key: string) => string) => z @@ -201,7 +201,7 @@ export default function LLMForm({ setRequesterNameList( requesterNameList.requesters.map((item) => { return { - label: i18nObj(item.label), + label: extractI18nObject(item.label), value: item.name, }; }), diff --git a/web/src/app/home/models/page.tsx b/web/src/app/home/models/page.tsx index 66099125..9a3a1597 100644 --- a/web/src/app/home/models/page.tsx +++ b/web/src/app/home/models/page.tsx @@ -17,7 +17,7 @@ import { } from '@/components/ui/dialog'; import { toast } from 'sonner'; import { useTranslation } from 'react-i18next'; -import { i18nObj } from '@/i18n/I18nProvider'; +import { extractI18nObject } from '@/i18n/I18nProvider'; import { EmbeddingCardVO } from '@/app/home/models/component/embedding-card/EmbeddingCardVO'; import EmbeddingCard from '@/app/home/models/component/embedding-card/EmbeddingCard'; import EmbeddingForm from '@/app/home/models/component/embedding-form/EmbeddingForm'; @@ -45,7 +45,7 @@ export default function LLMConfigPage() { const requesterNameListResp = await httpClient.getProviderRequesters('llm'); const requesterNameList = requesterNameListResp.requesters.map((item) => { return { - label: i18nObj(item.label), + label: extractI18nObject(item.label), value: item.name, }; }); @@ -102,7 +102,7 @@ export default function LLMConfigPage() { await httpClient.getProviderRequesters('text-embedding'); const requesterNameList = requesterNameListResp.requesters.map((item) => { return { - label: i18nObj(item.label), + label: extractI18nObject(item.label), value: item.name, }; }); diff --git a/web/src/app/home/pipelines/components/pipeline-form/PipelineFormComponent.tsx b/web/src/app/home/pipelines/components/pipeline-form/PipelineFormComponent.tsx index 7d26ad30..0ff2a2cd 100644 --- a/web/src/app/home/pipelines/components/pipeline-form/PipelineFormComponent.tsx +++ b/web/src/app/home/pipelines/components/pipeline-form/PipelineFormComponent.tsx @@ -30,7 +30,7 @@ import { } from '@/components/ui/dialog'; import { toast } from 'sonner'; import { useTranslation } from 'react-i18next'; -import { i18nObj } from '@/i18n/I18nProvider'; +import { extractI18nObject } from '@/i18n/I18nProvider'; export default function PipelineFormComponent({ isDefaultPipeline, @@ -229,10 +229,12 @@ export default function PipelineFormComponent({ if (stage.name === 'runner') { return (
-
{i18nObj(stage.label)}
+
+ {extractI18nObject(stage.label)} +
{stage.description && (
- {i18nObj(stage.description)} + {extractI18nObject(stage.description)}
)} -
{i18nObj(stage.label)}
+
+ {extractI18nObject(stage.label)} +
{stage.description && (
- {i18nObj(stage.description)} + {extractI18nObject(stage.description)}
)} -
{i18nObj(stage.label)}
+
+ {extractI18nObject(stage.label)} +
{stage.description && (
- {i18nObj(stage.description)} + {extractI18nObject(stage.description)}
)} ('local'); + const [installInfo, setInstallInfo] = useState>({}); // eslint-disable-line @typescript-eslint/no-explicit-any const [pluginInstallStatus, setPluginInstallStatus] = useState(PluginInstallStatus.WAIT_INPUT); const [installError, setInstallError] = useState(null); const [githubURL, setGithubURL] = useState(''); + const [isDragOver, setIsDragOver] = useState(false); const pluginInstalledRef = useRef(null); + const fileInputRef = useRef(null); + + function watchTask(taskId: number) { + let alreadySuccess = false; + console.log('taskId:', taskId); + + // 每秒拉取一次任务状态 + const interval = setInterval(() => { + httpClient.getAsyncTask(taskId).then((resp) => { + console.log('task status:', resp); + if (resp.runtime.done) { + clearInterval(interval); + if (resp.runtime.exception) { + setInstallError(resp.runtime.exception); + setPluginInstallStatus(PluginInstallStatus.ERROR); + } else { + // success + if (!alreadySuccess) { + toast.success(t('plugins.installSuccess')); + alreadySuccess = true; + } + setGithubURL(''); + setModalOpen(false); + pluginInstalledRef.current?.refreshPluginList(); + } + } + }); + }, 1000); + } function handleModalConfirm() { - installPlugin(githubURL); + installPlugin(installSource, installInfo as Record); // eslint-disable-line @typescript-eslint/no-explicit-any } - function installPlugin(url: string) { + + function installPlugin( + installSource: string, + installInfo: Record, // eslint-disable-line @typescript-eslint/no-explicit-any + ) { setPluginInstallStatus(PluginInstallStatus.INSTALLING); - httpClient - .installPluginFromGithub(url) - .then((resp) => { - const taskId = resp.task_id; - - let alreadySuccess = false; - console.log('taskId:', taskId); - - // 每秒拉取一次任务状态 - const interval = setInterval(() => { - httpClient.getAsyncTask(taskId).then((resp) => { - console.log('task status:', resp); - if (resp.runtime.done) { - clearInterval(interval); - if (resp.runtime.exception) { - setInstallError(resp.runtime.exception); - setPluginInstallStatus(PluginInstallStatus.ERROR); - } else { - // success - if (!alreadySuccess) { - toast.success(t('plugins.installSuccess')); - alreadySuccess = true; - } - setGithubURL(''); - setModalOpen(false); - pluginInstalledRef.current?.refreshPluginList(); - } - } - }); - }, 1000); - }) - .catch((err) => { - console.log('error when install plugin:', err); - setInstallError(err.message); - setPluginInstallStatus(PluginInstallStatus.ERROR); - }); + if (installSource === 'github') { + httpClient + .installPluginFromGithub(installInfo.url) + .then((resp) => { + const taskId = resp.task_id; + watchTask(taskId); + }) + .catch((err) => { + console.log('error when install plugin:', err); + setInstallError(err.message); + setPluginInstallStatus(PluginInstallStatus.ERROR); + }); + } else if (installSource === 'local') { + httpClient + .installPluginFromLocal(installInfo.file) + .then((resp) => { + const taskId = resp.task_id; + watchTask(taskId); + }) + .catch((err) => { + console.log('error when install plugin:', err); + setInstallError(err.message); + setPluginInstallStatus(PluginInstallStatus.ERROR); + }); + } else if (installSource === 'marketplace') { + httpClient + .installPluginFromMarketplace( + installInfo.plugin_author, + installInfo.plugin_name, + installInfo.plugin_version, + ) + .then((resp) => { + const taskId = resp.task_id; + watchTask(taskId); + }); + } } + const validateFileType = (file: File): boolean => { + const allowedExtensions = ['.lbpkg', '.zip']; + const fileName = file.name.toLowerCase(); + return allowedExtensions.some((ext) => fileName.endsWith(ext)); + }; + + const uploadPluginFile = useCallback( + async (file: File) => { + if (!validateFileType(file)) { + toast.error(t('plugins.unsupportedFileType')); + return; + } + + setModalOpen(true); + setPluginInstallStatus(PluginInstallStatus.INSTALLING); + setInstallError(null); + installPlugin('local', { file }); + }, + [t], + ); + + const handleFileSelect = useCallback(() => { + if (fileInputRef.current) { + fileInputRef.current.click(); + } + }, []); + + const handleFileChange = useCallback( + (event: React.ChangeEvent) => { + const file = event.target.files?.[0]; + if (file) { + uploadPluginFile(file); + } + // 清空input值,以便可以重复选择同一个文件 + event.target.value = ''; + }, + [uploadPluginFile], + ); + + const handleDragOver = useCallback((event: React.DragEvent) => { + event.preventDefault(); + setIsDragOver(true); + }, []); + + const handleDragLeave = useCallback((event: React.DragEvent) => { + event.preventDefault(); + setIsDragOver(false); + }, []); + + const handleDrop = useCallback( + (event: React.DragEvent) => { + event.preventDefault(); + setIsDragOver(false); + + const files = Array.from(event.dataTransfer.files); + if (files.length > 0) { + uploadPluginFile(files[0]); + } + }, + [uploadPluginFile], + ); + return ( -
- +
+ +
{t('plugins.installed')} - - {t('plugins.marketplace')} - + {systemInfo.enable_marketplace && ( + + {t('plugins.marketplace')} + + )}
- - + */} + + + + + + + + {t('plugins.uploadLocal')} + + {systemInfo.enable_marketplace && ( + { + setActiveTab('market'); + }} + > + + {t('plugins.marketplace')} + + )} + +
- { - setGithubURL(githubURL); + { + setInstallSource('marketplace'); + setInstallInfo({ + plugin_author: plugin.author, + plugin_name: plugin.name, + plugin_version: plugin.latest_version, + }); + setPluginInstallStatus(PluginInstallStatus.ASK_CONFIRM); setModalOpen(true); - setPluginInstallStatus(PluginInstallStatus.WAIT_INPUT); - setInstallError(null); }} /> @@ -137,8 +281,8 @@ export default function PluginConfigPage() { - - {t('plugins.installFromGithub')} + + {t('plugins.installPlugin')} {pluginInstallStatus === PluginInstallStatus.WAIT_INPUT && ( @@ -152,6 +296,16 @@ export default function PluginConfigPage() { />
)} + {pluginInstallStatus === PluginInstallStatus.ASK_CONFIRM && ( +
+

+ {t('plugins.askConfirm', { + name: installInfo.plugin_name, + version: installInfo.plugin_version, + })} +

+
+ )} {pluginInstallStatus === PluginInstallStatus.INSTALLING && (

{t('plugins.installing')}

@@ -164,12 +318,13 @@ export default function PluginConfigPage() {
)} - {pluginInstallStatus === PluginInstallStatus.WAIT_INPUT && ( + {(pluginInstallStatus === PluginInstallStatus.WAIT_INPUT || + pluginInstallStatus === PluginInstallStatus.ASK_CONFIRM) && ( <> - @@ -183,13 +338,27 @@ export default function PluginConfigPage() { - +
+
+ +

+ {t('plugins.dragToUpload')} +

+
+
+
+ )} + + {/* { pluginInstalledRef.current?.refreshPluginList(); }} - /> + /> */}
); } diff --git a/web/src/app/home/plugins/plugin-installed/PluginCardVO.ts b/web/src/app/home/plugins/plugin-installed/PluginCardVO.ts index 0e880543..9712cead 100644 --- a/web/src/app/home/plugins/plugin-installed/PluginCardVO.ts +++ b/web/src/app/home/plugins/plugin-installed/PluginCardVO.ts @@ -1,38 +1,46 @@ +import { PluginComponent } from '@/app/infra/entities/plugin'; + export interface IPluginCardVO { author: string; + label: string; name: string; description: string; version: string; enabled: boolean; priority: number; + install_source: string; + install_info: Record; // eslint-disable-line @typescript-eslint/no-explicit-any status: string; - tools: object[]; - event_handlers: object; - repository: string; + components: PluginComponent[]; + debug: boolean; } export class PluginCardVO implements IPluginCardVO { author: string; + label: string; name: string; description: string; version: string; enabled: boolean; priority: number; + debug: boolean; + install_source: string; + install_info: Record; // eslint-disable-line @typescript-eslint/no-explicit-any status: string; - tools: object[]; - event_handlers: object; - repository: string; + components: PluginComponent[]; constructor(prop: IPluginCardVO) { this.author = prop.author; + this.label = prop.label; this.description = prop.description; this.enabled = prop.enabled; - this.event_handlers = prop.event_handlers; + this.components = prop.components; this.name = prop.name; this.priority = prop.priority; - this.repository = prop.repository; this.status = prop.status; - this.tools = prop.tools; this.version = prop.version; + this.debug = prop.debug; + this.install_source = prop.install_source; + this.install_info = prop.install_info; } } diff --git a/web/src/app/home/plugins/plugin-installed/PluginInstalledComponent.tsx b/web/src/app/home/plugins/plugin-installed/PluginInstalledComponent.tsx index 3e5dd9b5..5581fc7a 100644 --- a/web/src/app/home/plugins/plugin-installed/PluginInstalledComponent.tsx +++ b/web/src/app/home/plugins/plugin-installed/PluginInstalledComponent.tsx @@ -11,14 +11,24 @@ import { DialogContent, DialogHeader, DialogTitle, + DialogDescription, + DialogFooter, } from '@/components/ui/dialog'; +import { Button } from '@/components/ui/button'; import { useTranslation } from 'react-i18next'; -import { i18nObj } from '@/i18n/I18nProvider'; +import { extractI18nObject } from '@/i18n/I18nProvider'; +import { toast } from 'sonner'; +import { useAsyncTask, AsyncTaskStatus } from '@/hooks/useAsyncTask'; export interface PluginInstalledComponentRef { refreshPluginList: () => void; } +enum PluginOperationType { + DELETE = 'DELETE', + UPDATE = 'UPDATE', +} + // eslint-disable-next-line react/display-name const PluginInstalledComponent = forwardRef( (props, ref) => { @@ -28,6 +38,26 @@ const PluginInstalledComponent = forwardRef( const [selectedPlugin, setSelectedPlugin] = useState( null, ); + const [showOperationModal, setShowOperationModal] = useState(false); + const [operationType, setOperationType] = useState( + PluginOperationType.DELETE, + ); + const [targetPlugin, setTargetPlugin] = useState(null); + + const asyncTask = useAsyncTask({ + onSuccess: () => { + const successMessage = + operationType === PluginOperationType.DELETE + ? t('plugins.deleteSuccess') + : t('plugins.updateSuccess'); + toast.success(successMessage); + setShowOperationModal(false); + getPluginList(); + }, + onError: () => { + // Error is already handled in the hook state + }, + }); useEffect(() => { initData(); @@ -43,16 +73,23 @@ const PluginInstalledComponent = forwardRef( setPluginList( value.plugins.map((plugin) => { return new PluginCardVO({ - author: plugin.author, - description: i18nObj(plugin.description), + author: plugin.manifest.manifest.metadata.author ?? '', + label: extractI18nObject(plugin.manifest.manifest.metadata.label), + description: extractI18nObject( + plugin.manifest.manifest.metadata.description ?? { + en_US: '', + zh_Hans: '', + }, + ), + debug: plugin.debug, enabled: plugin.enabled, - name: plugin.name, - version: plugin.version, + name: plugin.manifest.manifest.metadata.name, + version: plugin.manifest.manifest.metadata.version ?? '', status: plugin.status, - tools: plugin.tools, - event_handlers: plugin.event_handlers, - repository: plugin.repository, + components: plugin.components, priority: plugin.priority, + install_source: plugin.install_source, + install_info: plugin.install_info, }); }), ); @@ -68,8 +105,149 @@ const PluginInstalledComponent = forwardRef( setModalOpen(true); } + function handlePluginDelete(plugin: PluginCardVO) { + setTargetPlugin(plugin); + setOperationType(PluginOperationType.DELETE); + setShowOperationModal(true); + asyncTask.reset(); + } + + function handlePluginUpdate(plugin: PluginCardVO) { + setTargetPlugin(plugin); + setOperationType(PluginOperationType.UPDATE); + setShowOperationModal(true); + asyncTask.reset(); + } + + function executeOperation() { + if (!targetPlugin) return; + + const apiCall = + operationType === PluginOperationType.DELETE + ? httpClient.removePlugin(targetPlugin.author, targetPlugin.name) + : httpClient.upgradePlugin(targetPlugin.author, targetPlugin.name); + + apiCall + .then((res) => { + asyncTask.startTask(res.task_id); + }) + .catch((error) => { + const errorMessage = + operationType === PluginOperationType.DELETE + ? t('plugins.deleteError') + error.message + : t('plugins.updateError') + error.message; + toast.error(errorMessage); + }); + } + return ( <> + { + if (!open) { + setShowOperationModal(false); + setTargetPlugin(null); + asyncTask.reset(); + } + }} + > + + + + {operationType === PluginOperationType.DELETE + ? t('plugins.deleteConfirm') + : t('plugins.updateConfirm')} + + + + {asyncTask.status === AsyncTaskStatus.WAIT_INPUT && ( +
+ {operationType === PluginOperationType.DELETE + ? t('plugins.confirmDeletePlugin', { + author: targetPlugin?.author ?? '', + name: targetPlugin?.name ?? '', + }) + : t('plugins.confirmUpdatePlugin', { + author: targetPlugin?.author ?? '', + name: targetPlugin?.name ?? '', + })} +
+ )} + {asyncTask.status === AsyncTaskStatus.RUNNING && ( +
+ {operationType === PluginOperationType.DELETE + ? t('plugins.deleting') + : t('plugins.updating')} +
+ )} + {asyncTask.status === AsyncTaskStatus.ERROR && ( +
+ {operationType === PluginOperationType.DELETE + ? t('plugins.deleteError') + : t('plugins.updateError')} +
{asyncTask.error}
+
+ )} +
+ + {asyncTask.status === AsyncTaskStatus.WAIT_INPUT && ( + + )} + {asyncTask.status === AsyncTaskStatus.WAIT_INPUT && ( + + )} + {asyncTask.status === AsyncTaskStatus.RUNNING && ( + + )} + {asyncTask.status === AsyncTaskStatus.ERROR && ( + + )} + +
+
+ {pluginList.length === 0 ? (
( { + onFormSubmit={(timeout?: number) => { setModalOpen(false); - getPluginList(); + if (timeout) { + setTimeout(() => { + getPluginList(); + }, timeout); + } else { + getPluginList(); + } }} onFormCancel={() => { setModalOpen(false); @@ -113,6 +297,8 @@ const PluginInstalledComponent = forwardRef( handlePluginClick(vo)} + onDeleteClick={() => handlePluginDelete(vo)} + onUpgradeClick={() => handlePluginUpdate(vo)} />
); diff --git a/web/src/app/home/plugins/plugin-installed/plugin-card/PluginCardComponent.tsx b/web/src/app/home/plugins/plugin-installed/plugin-card/PluginCardComponent.tsx index fed9c53f..a3e7596d 100644 --- a/web/src/app/home/plugins/plugin-installed/plugin-card/PluginCardComponent.tsx +++ b/web/src/app/home/plugins/plugin-installed/plugin-card/PluginCardComponent.tsx @@ -1,146 +1,231 @@ import { PluginCardVO } from '@/app/home/plugins/plugin-installed/PluginCardVO'; import { useState } from 'react'; -import { httpClient } from '@/app/infra/http/HttpClient'; import { Badge } from '@/components/ui/badge'; -import { Switch } from '@/components/ui/switch'; -import { toast } from 'sonner'; import { useTranslation } from 'react-i18next'; +import { TFunction } from 'i18next'; +import { + AudioWaveform, + Wrench, + Hash, + BugIcon, + ExternalLink, + Ellipsis, + Trash, + ArrowUp, +} from 'lucide-react'; +import { getCloudServiceClientSync } from '@/app/infra/http'; +import { httpClient } from '@/app/infra/http/HttpClient'; +import { PluginComponent } from '@/app/infra/entities/plugin'; +import { Button } from '@/components/ui/button'; +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuItem, + DropdownMenuTrigger, +} from '@/components/ui/dropdown-menu'; + +function getComponentList(components: PluginComponent[], t: TFunction) { + const componentKindCount: Record = {}; + + for (const component of components) { + const kind = component.manifest.manifest.kind; + if (componentKindCount[kind]) { + componentKindCount[kind]++; + } else { + componentKindCount[kind] = 1; + } + } + + const kindIconMap: Record = { + Tool: , + EventListener: , + Command: , + }; + + const componentKindList = Object.keys(componentKindCount); + + return ( + <> +
{t('plugins.componentsList')}
+ {componentKindList.length > 0 && ( + <> + {componentKindList.map((kind) => { + return ( +
+ {kindIconMap[kind]} {componentKindCount[kind]} +
+ ); + })} + + )} + + {componentKindList.length === 0 &&
{t('plugins.noComponents')}
} + + ); +} export default function PluginCardComponent({ cardVO, onCardClick, + onDeleteClick, + onUpgradeClick, }: { cardVO: PluginCardVO; onCardClick: () => void; + onDeleteClick: (cardVO: PluginCardVO) => void; + onUpgradeClick: (cardVO: PluginCardVO) => void; }) { const { t } = useTranslation(); - const [enabled, setEnabled] = useState(cardVO.enabled); - const [switchEnable, setSwitchEnable] = useState(true); + const [dropdownOpen, setDropdownOpen] = useState(false); - function handleEnable(e: React.MouseEvent) { - e.stopPropagation(); // 阻止事件冒泡 - setSwitchEnable(false); - httpClient - .togglePlugin(cardVO.author, cardVO.name, !enabled) - .then(() => { - setEnabled(!enabled); - }) - .catch((err) => { - toast.error(t('plugins.modifyFailed') + err.message); - }) - .finally(() => { - setSwitchEnable(true); - }); - } return ( -
-
- - - + <> +
+
+ {/* + + */} + plugin icon -
-
+
-
- {cardVO.author} /{' '} -
-
-
- {cardVO.name} +
+
+ {cardVO.author} / {cardVO.name}
- - v{cardVO.version} - +
+
+ {cardVO.label} +
+ + v{cardVO.version} + + {cardVO.debug && ( + + + {t('plugins.debugging')} + + )} + {!cardVO.debug && ( + <> + {cardVO.install_source === 'github' && ( + { + e.stopPropagation(); + window.open( + cardVO.install_info.github_url, + '_blank', + ); + }} + > + {t('plugins.fromGithub')} + + + )} + {cardVO.install_source === 'local' && ( + + {t('plugins.fromLocal')} + + )} + {cardVO.install_source === 'marketplace' && ( + { + e.stopPropagation(); + window.open( + getCloudServiceClientSync().getPluginMarketplaceURL( + cardVO.author, + cardVO.name, + ), + '_blank', + ); + }} + > + {t('plugins.fromMarketplace')} + + + )} + + )} +
+
+ +
+ {cardVO.description}
-
- {cardVO.description} +
+ {getComponentList(cardVO.components, t)}
-
-
- - - -
- {t('plugins.eventCount', { - count: Object.keys(cardVO.event_handlers).length, - })} -
-
+
+
-
- - - -
- {t('plugins.toolCount', { count: cardVO.tools.length })} -
+
+ + + + + + {/**upgrade */} + {cardVO.install_source === 'marketplace' && ( + { + e.stopPropagation(); + onUpgradeClick(cardVO); + setDropdownOpen(false); + }} + > + + {t('plugins.update')} + + )} + { + e.stopPropagation(); + onDeleteClick(cardVO); + setDropdownOpen(false); + }} + > + + {t('plugins.delete')} + + +
- -
-
- handleEnable(e)} - disabled={!switchEnable} - /> -
- - {cardVO.repository && - cardVO.repository.trim() && - cardVO.repository.startsWith('http') && ( -
- { - e.stopPropagation(); // 阻止事件冒泡 - if ( - cardVO.repository && - cardVO.repository.trim() && - cardVO.repository.startsWith('http') - ) { - window.open(cardVO.repository, '_blank'); - } - }} - > - - -
- )} -
-
+ ); } diff --git a/web/src/app/home/plugins/plugin-installed/plugin-form/PluginForm.tsx b/web/src/app/home/plugins/plugin-installed/plugin-form/PluginForm.tsx index 7ffd6796..09a79d2f 100644 --- a/web/src/app/home/plugins/plugin-installed/plugin-form/PluginForm.tsx +++ b/web/src/app/home/plugins/plugin-installed/plugin-form/PluginForm.tsx @@ -1,26 +1,13 @@ import { useState, useEffect } from 'react'; -import { ApiRespPluginConfig, Plugin } from '@/app/infra/entities/api'; +import { ApiRespPluginConfig } from '@/app/infra/entities/api'; +import { Plugin } from '@/app/infra/entities/plugin'; import { httpClient } from '@/app/infra/http/HttpClient'; import DynamicFormComponent from '@/app/home/components/dynamic-form/DynamicFormComponent'; import { Button } from '@/components/ui/button'; -import { - Dialog, - DialogContent, - DialogDescription, - DialogHeader, - DialogTitle, - DialogFooter, -} from '@/components/ui/dialog'; import { toast } from 'sonner'; -import { i18nObj } from '@/i18n/I18nProvider'; +import { extractI18nObject } from '@/i18n/I18nProvider'; import { useTranslation } from 'react-i18next'; -enum PluginRemoveStatus { - WAIT_INPUT = 'WAIT_INPUT', - REMOVING = 'REMOVING', - ERROR = 'ERROR', -} - export default function PluginForm({ pluginAuthor, pluginName, @@ -29,7 +16,7 @@ export default function PluginForm({ }: { pluginAuthor: string; pluginName: string; - onFormSubmit: () => void; + onFormSubmit: (timeout?: number) => void; onFormCancel: () => void; }) { const { t } = useTranslation(); @@ -37,13 +24,6 @@ export default function PluginForm({ const [pluginConfig, setPluginConfig] = useState(); const [isSaving, setIsLoading] = useState(false); - const [showDeleteConfirmModal, setShowDeleteConfirmModal] = useState(false); - const [pluginRemoveStatus, setPluginRemoveStatus] = - useState(PluginRemoveStatus.WAIT_INPUT); - const [pluginRemoveError, setPluginRemoveError] = useState( - null, - ); - useEffect(() => { // 获取插件信息 httpClient.getPlugin(pluginAuthor, pluginName).then((res) => { @@ -57,14 +37,19 @@ export default function PluginForm({ const handleSubmit = async (values: object) => { setIsLoading(true); + const isDebugPlugin = pluginInfo?.debug; httpClient .updatePluginConfig(pluginAuthor, pluginName, values) .then(() => { - onFormSubmit(); - toast.success('保存成功'); + toast.success( + isDebugPlugin + ? t('plugins.saveConfigSuccessDebugPlugin') + : t('plugins.saveConfigSuccessNormal'), + ); + onFormSubmit(1000); }) .catch((error) => { - toast.error('保存失败:' + error.message); + toast.error(t('plugins.saveConfigError') + error.message); }) .finally(() => { setIsLoading(false); @@ -72,124 +57,30 @@ export default function PluginForm({ }; if (!pluginInfo || !pluginConfig) { - return
{t('plugins.loading')}
; - } - - function deletePlugin() { - setPluginRemoveStatus(PluginRemoveStatus.REMOVING); - httpClient - .removePlugin(pluginAuthor, pluginName) - .then((res) => { - const taskId = res.task_id; - - let alreadySuccess = false; - - const interval = setInterval(() => { - httpClient.getAsyncTask(taskId).then((res) => { - if (res.runtime.done) { - clearInterval(interval); - if (res.runtime.exception) { - setPluginRemoveError(res.runtime.exception); - setPluginRemoveStatus(PluginRemoveStatus.ERROR); - } else { - // success - if (!alreadySuccess) { - toast.success('插件删除成功'); - alreadySuccess = true; - } - setPluginRemoveStatus(PluginRemoveStatus.WAIT_INPUT); - setShowDeleteConfirmModal(false); - onFormSubmit(); - } - } - }); - }, 1000); - }) - .catch((error) => { - setPluginRemoveError(error.message); - setPluginRemoveStatus(PluginRemoveStatus.ERROR); - }); + return ( +
+ {t('plugins.loading')} +
+ ); } return (
- - - - {t('plugins.deleteConfirm')} - - - {pluginRemoveStatus === PluginRemoveStatus.WAIT_INPUT && ( -
- {t('plugins.confirmDeletePlugin', { - author: pluginAuthor, - name: pluginName, - })} -
- )} - {pluginRemoveStatus === PluginRemoveStatus.REMOVING && ( -
{t('plugins.deleting')}
- )} - {pluginRemoveStatus === PluginRemoveStatus.ERROR && ( -
- {t('plugins.deleteError')} -
{pluginRemoveError}
-
- )} -
- - {pluginRemoveStatus === PluginRemoveStatus.WAIT_INPUT && ( - - )} - {pluginRemoveStatus === PluginRemoveStatus.WAIT_INPUT && ( - - )} - {pluginRemoveStatus === PluginRemoveStatus.REMOVING && ( - - )} - {pluginRemoveStatus === PluginRemoveStatus.ERROR && ( - - )} - -
-
-
-
{pluginInfo.name}
-
- {i18nObj(pluginInfo.description)} +
+ {extractI18nObject(pluginInfo.manifest.manifest.metadata.label)}
- {pluginInfo.config_schema.length > 0 && ( +
+ {extractI18nObject( + pluginInfo.manifest.manifest.metadata.description ?? { + en_US: '', + zh_Hans: '', + }, + )} +
+ {pluginInfo.manifest.manifest.spec.config.length > 0 && ( } onSubmit={(values) => { let config = pluginConfig.config; @@ -203,7 +94,7 @@ export default function PluginForm({ }} /> )} - {pluginInfo.config_schema.length === 0 && ( + {pluginInfo.manifest.manifest.spec.config.length === 0 && (
{t('plugins.pluginNoConfig')}
@@ -212,19 +103,6 @@ export default function PluginForm({
- - + )} +
+
+
+ ); + + const PluginDescription = () => ( +
+

+ {extractI18nObject(plugin!.description) || t('market.noDescription')} +

+
+ ); + + const PluginOptions = () => ( +
+ +
+ ); + + const ReadmeContent = () => ( +
+ ( +
+ + + ), + thead: ({ ...props }) => ( + + ), + tbody: ({ ...props }) => ( + + ), + th: ({ ...props }) => ( + + ), + // 删除线支持 + del: ({ ...props }) => ( + + ), + // Todo 列表支持 + input: ({ type, checked, ...props }) => { + if (type === 'checkbox') { + return ( + + ); + } + return ; + }, + ul: ({ ...props }) => ( +
    + ), + ol: ({ ...props }) => ( +
      + ), + li: ({ ...props }) =>
    1. , + h1: ({ ...props }) => ( +

      + ), + h2: ({ ...props }) => ( +

      + ), + p: ({ ...props }) => ( +

      + ), + code: ({ className, children, ...props }) => { + const match = /language-(\w+)/.exec(className || ''); + const isCodeBlock = match ? true : false; + + // 如果是代码块(有语言标识),由 pre 标签处理样式,淡灰色底,黑色字 + if (isCodeBlock) { + return ( + + {children} + + ); + } + + // 内联代码样式 - 淡灰色底 + return ( + + {children} + + ); + }, + pre: ({ ...props }) => ( +

      +          ),
      +        }}
      +      >
      +        {readme}
      +      
      +    
      +  );
      +
      +  return (
      +    
      +      
      +        {isLoading ? (
      +          
      + + {t('market.loading')} +
      + ) : plugin ? ( +
      + {/* 插件信息区域 */} +
      +
      +
      + + +
      +
      + +
      +
      +
      + + {/* README 区域 */} +
      +
      + {isLoadingReadme ? ( +
      + + + {t('market.loading')} + +
      + ) : ( + + )} +
      +
      +
      + ) : null} +
      +
      + ); +} diff --git a/web/src/app/home/plugins/plugin-market/plugin-market-card/PluginMarketCardComponent.tsx b/web/src/app/home/plugins/plugin-market/plugin-market-card/PluginMarketCardComponent.tsx index fc4a4812..6fb6618b 100644 --- a/web/src/app/home/plugins/plugin-market/plugin-market-card/PluginMarketCardComponent.tsx +++ b/web/src/app/home/plugins/plugin-market/plugin-market-card/PluginMarketCardComponent.tsx @@ -1,86 +1,79 @@ -import { PluginMarketCardVO } from '@/app/home/plugins/plugin-market/plugin-market-card/PluginMarketCardVO'; -import { Button } from '@/components/ui/button'; -import { useTranslation } from 'react-i18next'; +import { PluginMarketCardVO } from './PluginMarketCardVO'; export default function PluginMarketCardComponent({ cardVO, - installPlugin, + onPluginClick, }: { cardVO: PluginMarketCardVO; - installPlugin: (pluginURL: string) => void; + onPluginClick?: (author: string, pluginName: string) => void; }) { - const { t } = useTranslation(); - - function handleInstallClick(pluginURL: string) { - installPlugin(pluginURL); + function handleCardClick() { + if (onPluginClick) { + onPluginClick(cardVO.author, cardVO.pluginName); + } } return ( -
      -
      - - - +
      +
      + {/* 上部分:插件信息 */} +
      + plugin icon -
      -
      +
      -
      - {cardVO.author} /{' '} +
      + {cardVO.pluginId}
      - {cardVO.name} + {cardVO.label}
      -
      +
      {cardVO.description}
      -
      -
      +
      + {cardVO.githubURL && ( - - -
      - {t('plugins.starCount', { count: cardVO.starCount })} -
      -
      - -
      - window.open(cardVO.githubURL, '_blank')} + onClick={(e) => { + e.stopPropagation(); + window.open(cardVO.githubURL, '_blank'); + }} > - -
      + )} +
      +
      + + {/* 下部分:下载量 */} +
      + + + + + +
      + {cardVO.installCount.toLocaleString()}
      diff --git a/web/src/app/home/plugins/plugin-market/plugin-market-card/PluginMarketCardVO.ts b/web/src/app/home/plugins/plugin-market/plugin-market-card/PluginMarketCardVO.ts index fe0a1e75..b4c38bbe 100644 --- a/web/src/app/home/plugins/plugin-market/plugin-market-card/PluginMarketCardVO.ts +++ b/web/src/app/home/plugins/plugin-market/plugin-market-card/PluginMarketCardVO.ts @@ -1,9 +1,11 @@ export interface IPluginMarketCardVO { pluginId: string; author: string; - name: string; + pluginName: string; + label: string; description: string; - starCount: number; + installCount: number; + iconURL: string; githubURL: string; version: string; } @@ -11,18 +13,22 @@ export interface IPluginMarketCardVO { export class PluginMarketCardVO implements IPluginMarketCardVO { pluginId: string; description: string; - name: string; + label: string; author: string; + pluginName: string; + iconURL: string; githubURL: string; - starCount: number; + installCount: number; version: string; constructor(prop: IPluginMarketCardVO) { this.description = prop.description; - this.name = prop.name; + this.label = prop.label; this.author = prop.author; + this.pluginName = prop.pluginName; + this.iconURL = prop.iconURL; this.githubURL = prop.githubURL; - this.starCount = prop.starCount; + this.installCount = prop.installCount; this.pluginId = prop.pluginId; this.version = prop.version; } diff --git a/web/src/app/home/plugins/plugin-sort/PluginSortDialog.tsx b/web/src/app/home/plugins/plugin-sort/PluginSortDialog.tsx index ad6874eb..998ae93e 100644 --- a/web/src/app/home/plugins/plugin-sort/PluginSortDialog.tsx +++ b/web/src/app/home/plugins/plugin-sort/PluginSortDialog.tsx @@ -1,209 +1,215 @@ -'use client'; +// 'use client'; -import * as React from 'react'; -import { useState, useEffect } from 'react'; -import { PluginCardVO } from '@/app/home/plugins/plugin-installed/PluginCardVO'; -import { httpClient } from '@/app/infra/http/HttpClient'; -import { PluginReorderElement } from '@/app/infra/entities/api'; -import { toast } from 'sonner'; -import { - Dialog, - DialogContent, - DialogHeader, - DialogTitle, - DialogFooter, -} from '@/components/ui/dialog'; -import { Button } from '@/components/ui/button'; -import { - DndContext, - closestCenter, - KeyboardSensor, - PointerSensor, - useSensor, - useSensors, - DragEndEvent, -} from '@dnd-kit/core'; -import { - arrayMove, - SortableContext, - sortableKeyboardCoordinates, - useSortable, - verticalListSortingStrategy, -} from '@dnd-kit/sortable'; -import { CSS } from '@dnd-kit/utilities'; -import { useTranslation } from 'react-i18next'; -import { i18nObj } from '@/i18n/I18nProvider'; +// import * as React from 'react'; +// import { useState, useEffect } from 'react'; +// import { PluginCardVO } from '@/app/home/plugins/plugin-installed/PluginCardVO'; +// import { httpClient } from '@/app/infra/http/HttpClient'; +// import { PluginReorderElement } from '@/app/infra/entities/api'; +// import { toast } from 'sonner'; +// import { +// Dialog, +// DialogContent, +// DialogHeader, +// DialogTitle, +// DialogFooter, +// } from '@/components/ui/dialog'; +// import { Button } from '@/components/ui/button'; +// import { +// DndContext, +// closestCenter, +// KeyboardSensor, +// PointerSensor, +// useSensor, +// useSensors, +// DragEndEvent, +// } from '@dnd-kit/core'; +// import { +// arrayMove, +// SortableContext, +// sortableKeyboardCoordinates, +// useSortable, +// verticalListSortingStrategy, +// } from '@dnd-kit/sortable'; +// import { CSS } from '@dnd-kit/utilities'; +// import { useTranslation } from 'react-i18next'; +// import { extractI18nObject } from '@/i18n/I18nProvider'; -interface PluginSortDialogProps { - open: boolean; - onOpenChange: (open: boolean) => void; - onSortComplete: () => void; -} +// interface PluginSortDialogProps { +// open: boolean; +// onOpenChange: (open: boolean) => void; +// onSortComplete: () => void; +// } -function SortablePluginItem({ plugin }: { plugin: PluginCardVO }) { - const { attributes, listeners, setNodeRef, transform, transition } = - useSortable({ - id: `${plugin.author}-${plugin.name}`, - }); +// function SortablePluginItem({ plugin }: { plugin: PluginCardVO }) { +// const { attributes, listeners, setNodeRef, transform, transition } = +// useSortable({ +// id: `${plugin.author}-${plugin.name}`, +// }); - const style = { - transform: CSS.Transform.toString(transform), - transition, - }; +// const style = { +// transform: CSS.Transform.toString(transform), +// transition, +// }; - return ( -
      -
      -
      - {plugin.author} -
      -
      {plugin.name}
      -
      - {plugin.description} -
      -
      -
      - ); -} +// return ( +//
      +//
      +//
      +// {plugin.author} +//
      +//
      {plugin.name}
      +//
      +// {plugin.description} +//
      +//
      +//
      +// ); +// } -export default function PluginSortDialog({ - open, - onOpenChange, - onSortComplete, -}: PluginSortDialogProps) { - const { t } = useTranslation(); - const [sortedPlugins, setSortedPlugins] = useState([]); - const [isLoading, setIsLoading] = useState(false); +// export default function PluginSortDialog({ +// open, +// onOpenChange, +// onSortComplete, +// }: PluginSortDialogProps) { +// const { t } = useTranslation(); +// const [sortedPlugins, setSortedPlugins] = useState([]); +// const [isLoading, setIsLoading] = useState(false); - function getPluginList() { - httpClient.getPlugins().then((value) => { - setSortedPlugins( - value.plugins.map((plugin) => { - return new PluginCardVO({ - author: plugin.author, - description: i18nObj(plugin.description), - enabled: plugin.enabled, - name: plugin.name, - version: plugin.version, - status: plugin.status, - tools: plugin.tools, - event_handlers: plugin.event_handlers, - repository: plugin.repository, - priority: plugin.priority, - }); - }), - ); - }); - } +// function getPluginList() { +// httpClient.getPlugins().then((value) => { +// setSortedPlugins( +// value.plugins.map((plugin) => { +// return new PluginCardVO({ +// author: plugin.manifest.manifest.metadata.author ?? '', +// description: extractI18nObject( +// plugin.manifest.manifest.metadata.description ?? { +// en_US: '', +// zh_Hans: '', +// }, +// ), +// enabled: plugin.enabled, +// name: plugin.manifest.manifest.metadata.name, +// version: plugin.manifest.manifest.metadata.version ?? '', +// status: plugin.status, +// components: plugin.components, +// install_source: plugin.install_source, +// install_info: plugin.install_info, +// priority: plugin.priority, +// debug: plugin.debug, +// }); +// }), +// ); +// }); +// } - useEffect(() => { - if (open) { - getPluginList(); - } - }, [open]); +// useEffect(() => { +// if (open) { +// getPluginList(); +// } +// }, [open]); - const sensors = useSensors( - useSensor(PointerSensor), - useSensor(KeyboardSensor, { - coordinateGetter: sortableKeyboardCoordinates, - }), - ); +// const sensors = useSensors( +// useSensor(PointerSensor), +// useSensor(KeyboardSensor, { +// coordinateGetter: sortableKeyboardCoordinates, +// }), +// ); - function handleDragEnd(event: DragEndEvent) { - const { active, over } = event; - console.log('Drag end event:', { active, over }); +// function handleDragEnd(event: DragEndEvent) { +// const { active, over } = event; +// console.log('Drag end event:', { active, over }); - if (over && active.id !== over.id) { - setSortedPlugins((items) => { - const oldIndex = items.findIndex( - (item) => `${item.author}-${item.name}` === active.id, - ); - const newIndex = items.findIndex( - (item) => `${item.author}-${item.name}` === over.id, - ); +// if (over && active.id !== over.id) { +// setSortedPlugins((items) => { +// const oldIndex = items.findIndex( +// (item) => `${item.author}-${item.name}` === active.id, +// ); +// const newIndex = items.findIndex( +// (item) => `${item.author}-${item.name}` === over.id, +// ); - const newItems = arrayMove(items, oldIndex, newIndex); +// const newItems = arrayMove(items, oldIndex, newIndex); - return newItems; - }); - } - } +// return newItems; +// }); +// } +// } - function handleSave() { - setIsLoading(true); +// function handleSave() { +// setIsLoading(true); - const reorderElements: PluginReorderElement[] = sortedPlugins.map( - (plugin, index) => ({ - author: plugin.author, - name: plugin.name, - priority: index, - }), - ); +// const reorderElements: PluginReorderElement[] = sortedPlugins.map( +// (plugin, index) => ({ +// author: plugin.author, +// name: plugin.name, +// priority: index, +// }), +// ); - httpClient - .reorderPlugins(reorderElements) - .then(() => { - toast.success(t('plugins.pluginSortSuccess')); - onSortComplete(); - onOpenChange(false); - }) - .catch((err) => { - toast.error(t('plugins.pluginSortError') + err.message); - }) - .finally(() => { - setIsLoading(false); - }); - } +// httpClient +// .reorderPlugins(reorderElements) +// .then(() => { +// toast.success(t('plugins.pluginSortSuccess')); +// onSortComplete(); +// onOpenChange(false); +// }) +// .catch((err) => { +// toast.error(t('plugins.pluginSortError') + err.message); +// }) +// .finally(() => { +// setIsLoading(false); +// }); +// } - return ( - - - - {t('plugins.pluginSort')} - -
      -

      - {t('plugins.pluginSortDescription')} -

      - - `${plugin.author}-${plugin.name}`, - )} - strategy={verticalListSortingStrategy} - > - {sortedPlugins.map((plugin) => ( - - ))} - - -
      - - - - -
      -
      - ); -} +// return ( +// +// +// +// {t('plugins.pluginSort')} +// +//
      +//

      +// {t('plugins.pluginSortDescription')} +//

      +// +// `${plugin.author}-${plugin.name}`, +// )} +// strategy={verticalListSortingStrategy} +// > +// {sortedPlugins.map((plugin) => ( +// +// ))} +// +// +//
      +// +// +// +// +//
      +//
      +// ); +// } diff --git a/web/src/app/home/plugins/plugins.module.css b/web/src/app/home/plugins/plugins.module.css index 54ede1c6..a65be354 100644 --- a/web/src/app/home/plugins/plugins.module.css +++ b/web/src/app/home/plugins/plugins.module.css @@ -13,7 +13,7 @@ padding-right: 0.8rem; padding-top: 2rem; display: grid; - grid-template-columns: repeat(auto-fill, minmax(24rem, 1fr)); + grid-template-columns: repeat(auto-fill, minmax(30rem, 1fr)); gap: 2rem; justify-items: stretch; align-items: start; diff --git a/web/src/app/infra/entities/api/index.ts b/web/src/app/infra/entities/api/index.ts index d687dce5..5b187cae 100644 --- a/web/src/app/infra/entities/api/index.ts +++ b/web/src/app/infra/entities/api/index.ts @@ -1,7 +1,8 @@ import { IDynamicFormItemSchema } from '@/app/infra/entities/form/dynamic'; import { PipelineConfigTab } from '@/app/infra/entities/pipeline'; -import { I18nLabel } from '@/app/infra/entities/common'; +import { I18nObject } from '@/app/infra/entities/common'; import { Message } from '@/app/infra/entities/message'; +import { Plugin, PluginV4 } from '@/app/infra/entities/plugin'; export interface ApiResponse { code: number; @@ -23,8 +24,8 @@ export interface ApiRespProviderRequester { export interface Requester { name: string; - label: I18nLabel; - description: I18nLabel; + label: I18nObject; + description: I18nObject; icon?: string; spec: { config: IDynamicFormItemSchema[]; @@ -113,8 +114,8 @@ export interface ApiRespPlatformAdapter { export interface Adapter { name: string; - label: I18nLabel; - description: I18nLabel; + label: I18nObject; + description: I18nObject; icon?: string; spec: { config: IDynamicFormItemSchema[]; @@ -179,22 +180,22 @@ export interface ApiRespPlugin { plugin: Plugin; } -export interface Plugin { - author: string; - name: string; - description: I18nLabel; - label: I18nLabel; - version: string; - enabled: boolean; - priority: number; - status: string; - tools: object[]; - event_handlers: object; - main_file: string; - pkg_path: string; - repository: string; - config_schema: IDynamicFormItemSchema[]; -} +// export interface Plugin { +// author: string; +// name: string; +// description: I18nLabel; +// label: I18nLabel; +// version: string; +// enabled: boolean; +// priority: number; +// status: string; +// tools: object[]; +// event_handlers: object; +// main_file: string; +// pkg_path: string; +// repository: string; +// config_schema: IDynamicFormItemSchema[]; +// } export interface ApiRespPluginConfig { config: object; @@ -210,6 +211,8 @@ export interface PluginReorderElement { export interface ApiRespSystemInfo { debug: boolean; version: string; + cloud_service_url: string; + enable_marketplace: boolean; } export interface ApiRespAsyncTasks { @@ -241,26 +244,13 @@ export interface ApiRespUserToken { token: string; } -export interface MarketPlugin { - ID: number; - CreatedAt: string; // ISO 8601 格式日期 - UpdatedAt: string; - DeletedAt: string | null; - name: string; - author: string; - description: string; - repository: string; // GitHub 仓库路径 - artifacts_path: string; - stars: number; - downloads: number; - status: 'initialized' | 'mounted'; // 可根据实际状态值扩展联合类型 - synced_at: string; - pushed_at: string; // 最后一次代码推送时间 +export interface ApiRespMarketplacePlugins { + plugins: PluginV4[]; + total: number; } -export interface MarketPluginResponse { - plugins: MarketPlugin[]; - total: number; +export interface ApiRespMarketplacePluginDetail { + plugin: PluginV4; } interface GetPipelineConfig { diff --git a/web/src/app/infra/entities/common.ts b/web/src/app/infra/entities/common.ts index 02cd99f8..64331738 100644 --- a/web/src/app/infra/entities/common.ts +++ b/web/src/app/infra/entities/common.ts @@ -1,5 +1,21 @@ -export interface I18nLabel { +export interface I18nObject { en_US: string; zh_Hans: string; + zh_Hant?: string; ja_JP?: string; } + +export interface ComponentManifest { + apiVersion: string; + kind: string; + metadata: { + name: string; + label: I18nObject; + description?: I18nObject; + icon?: string; + repository?: string; + version?: string; + author?: string; + }; + spec: Record; // eslint-disable-line @typescript-eslint/no-explicit-any +} diff --git a/web/src/app/infra/entities/form/dynamic.ts b/web/src/app/infra/entities/form/dynamic.ts index 6d6de096..2d733f49 100644 --- a/web/src/app/infra/entities/form/dynamic.ts +++ b/web/src/app/infra/entities/form/dynamic.ts @@ -1,13 +1,13 @@ -import { I18nLabel } from '@/app/infra/entities/common'; +import { I18nObject } from '@/app/infra/entities/common'; export interface IDynamicFormItemSchema { id: string; default: string | number | boolean | Array; - label: I18nLabel; + label: I18nObject; name: string; required: boolean; type: DynamicFormItemType; - description?: I18nLabel; + description?: I18nObject; options?: IDynamicFormItemOption[]; } @@ -26,5 +26,5 @@ export enum DynamicFormItemType { export interface IDynamicFormItemOption { name: string; - label: I18nLabel; + label: I18nObject; } diff --git a/web/src/app/infra/entities/pipeline/index.ts b/web/src/app/infra/entities/pipeline/index.ts index 29a5f6af..cc411c9f 100644 --- a/web/src/app/infra/entities/pipeline/index.ts +++ b/web/src/app/infra/entities/pipeline/index.ts @@ -1,4 +1,4 @@ -import { I18nLabel } from '@/app/infra/entities/common'; +import { I18nObject } from '@/app/infra/entities/common'; import { IDynamicFormItemSchema } from '@/app/infra/entities/form/dynamic'; export interface PipelineFormEntity { @@ -11,13 +11,13 @@ export interface PipelineFormEntity { export interface PipelineConfigTab { name: string; - label: I18nLabel; + label: I18nObject; stages: PipelineConfigStage[]; } export interface PipelineConfigStage { name: string; - label: I18nLabel; - description?: I18nLabel; + label: I18nObject; + description?: I18nObject; config: IDynamicFormItemSchema[]; } diff --git a/web/src/app/infra/entities/plugin/index.ts b/web/src/app/infra/entities/plugin/index.ts new file mode 100644 index 00000000..1b06abd9 --- /dev/null +++ b/web/src/app/infra/entities/plugin/index.ts @@ -0,0 +1,46 @@ +import { ComponentManifest, I18nObject } from '@/app/infra/entities/common'; + +export interface Plugin { + status: 'intialized' | 'mounted' | 'unmounted'; + priority: number; + plugin_config: object; + manifest: { + manifest: ComponentManifest; + }; + debug: boolean; + enabled: boolean; + install_source: string; + install_info: Record; // eslint-disable-line @typescript-eslint/no-explicit-any + components: PluginComponent[]; +} + +export interface PluginComponent { + component_config: object; + manifest: { + manifest: ComponentManifest; + }; +} + +// marketplace plugin v4 +export enum PluginV4Status { + Any = 'any', + Live = 'live', + Deleted = 'deleted', +} + +export interface PluginV4 { + id: number; + plugin_id: string; + author: string; + name: string; + label: I18nObject; + description: I18nObject; + icon: string; + repository: string; + tags: string[]; + install_count: number; + latest_version: string; + status: PluginV4Status; + created_at: string; + updated_at: string; +} diff --git a/web/src/app/infra/http/BackendClient.ts b/web/src/app/infra/http/BackendClient.ts new file mode 100644 index 00000000..10f28da2 --- /dev/null +++ b/web/src/app/infra/http/BackendClient.ts @@ -0,0 +1,541 @@ +import { BaseHttpClient } from './BaseHttpClient'; +import { + ApiRespProviderRequesters, + ApiRespProviderRequester, + ApiRespProviderLLMModels, + ApiRespProviderLLMModel, + LLMModel, + ApiRespPipelines, + Pipeline, + ApiRespPlatformAdapters, + ApiRespPlatformAdapter, + ApiRespPlatformBots, + ApiRespPlatformBot, + Bot, + ApiRespPlugins, + ApiRespPlugin, + ApiRespPluginConfig, + AsyncTaskCreatedResp, + ApiRespSystemInfo, + ApiRespAsyncTasks, + ApiRespUserToken, + GetPipelineResponseData, + GetPipelineMetadataResponseData, + AsyncTask, + ApiRespWebChatMessage, + ApiRespWebChatMessages, + ApiRespKnowledgeBases, + ApiRespKnowledgeBase, + KnowledgeBase, + ApiRespKnowledgeBaseFiles, + ApiRespKnowledgeBaseRetrieve, + ApiRespProviderEmbeddingModels, + ApiRespProviderEmbeddingModel, + EmbeddingModel, +} from '@/app/infra/entities/api'; +import { GetBotLogsRequest } from '@/app/infra/http/requestParam/bots/GetBotLogsRequest'; +import { GetBotLogsResponse } from '@/app/infra/http/requestParam/bots/GetBotLogsResponse'; + +/** + * 后端服务客户端 + * 负责与后端 API 的所有交互 + */ +export class BackendClient extends BaseHttpClient { + constructor(baseURL: string) { + super(baseURL, false); + } + + // ============ Provider API ============ + public getProviderRequesters( + model_type: string, + ): Promise { + return this.get('/api/v1/provider/requesters', { type: model_type }); + } + + public getProviderRequester(name: string): Promise { + return this.get(`/api/v1/provider/requesters/${name}`); + } + + public getProviderRequesterIconURL(name: string): string { + if (this.instance.defaults.baseURL === '/') { + // 获取用户访问的URL + const url = window.location.href; + const baseURL = url.split('/').slice(0, 3).join('/'); + return `${baseURL}/api/v1/provider/requesters/${name}/icon`; + } + return ( + this.instance.defaults.baseURL + + `/api/v1/provider/requesters/${name}/icon` + ); + } + + // ============ Provider Model LLM ============ + public getProviderLLMModels(): Promise { + return this.get('/api/v1/provider/models/llm'); + } + + public getProviderLLMModel(uuid: string): Promise { + return this.get(`/api/v1/provider/models/llm/${uuid}`); + } + + public createProviderLLMModel(model: LLMModel): Promise { + return this.post('/api/v1/provider/models/llm', model); + } + + public deleteProviderLLMModel(uuid: string): Promise { + return this.delete(`/api/v1/provider/models/llm/${uuid}`); + } + + public updateProviderLLMModel( + uuid: string, + model: LLMModel, + ): Promise { + return this.put(`/api/v1/provider/models/llm/${uuid}`, model); + } + + public testLLMModel(uuid: string, model: LLMModel): Promise { + return this.post(`/api/v1/provider/models/llm/${uuid}/test`, model); + } + + // ============ Provider Model Embedding ============ + public getProviderEmbeddingModels(): Promise { + return this.get('/api/v1/provider/models/embedding'); + } + + public getProviderEmbeddingModel( + uuid: string, + ): Promise { + return this.get(`/api/v1/provider/models/embedding/${uuid}`); + } + + public createProviderEmbeddingModel(model: EmbeddingModel): Promise { + return this.post('/api/v1/provider/models/embedding', model); + } + + public deleteProviderEmbeddingModel(uuid: string): Promise { + return this.delete(`/api/v1/provider/models/embedding/${uuid}`); + } + + public updateProviderEmbeddingModel( + uuid: string, + model: EmbeddingModel, + ): Promise { + return this.put(`/api/v1/provider/models/embedding/${uuid}`, model); + } + + public testEmbeddingModel( + uuid: string, + model: EmbeddingModel, + ): Promise { + return this.post(`/api/v1/provider/models/embedding/${uuid}/test`, model); + } + + // ============ Pipeline API ============ + public getGeneralPipelineMetadata(): Promise { + // as designed, this method will be deprecated, and only for developer to check the prefered config schema + return this.get('/api/v1/pipelines/_/metadata'); + } + + public getPipelines( + sortBy?: string, + sortOrder?: string, + ): Promise { + const params = new URLSearchParams(); + if (sortBy) params.append('sort_by', sortBy); + if (sortOrder) params.append('sort_order', sortOrder); + const queryString = params.toString(); + return this.get(`/api/v1/pipelines${queryString ? `?${queryString}` : ''}`); + } + + public getPipeline(uuid: string): Promise { + return this.get(`/api/v1/pipelines/${uuid}`); + } + + public createPipeline(pipeline: Pipeline): Promise<{ + uuid: string; + }> { + return this.post('/api/v1/pipelines', pipeline); + } + + public updatePipeline(uuid: string, pipeline: Pipeline): Promise { + return this.put(`/api/v1/pipelines/${uuid}`, pipeline); + } + + public deletePipeline(uuid: string): Promise { + return this.delete(`/api/v1/pipelines/${uuid}`); + } + + // ============ Debug WebChat API ============ + + // ============ Debug WebChat API ============ + public sendWebChatMessage( + sessionType: string, + messageChain: object[], + pipelineId: string, + timeout: number = 15000, + ): Promise { + return this.post( + `/api/v1/pipelines/${pipelineId}/chat/send`, + { + session_type: sessionType, + message: messageChain, + }, + { + timeout, + }, + ); + } + + public async sendStreamingWebChatMessage( + sessionType: string, + messageChain: object[], + pipelineId: string, + onMessage: (data: ApiRespWebChatMessage) => void, + onComplete: () => void, + onError: (error: Error) => void, + ): Promise { + try { + // 构造完整的URL,处理相对路径的情况 + let url = `${this.baseURL}/api/v1/pipelines/${pipelineId}/chat/send`; + if (this.baseURL === '/') { + // 获取用户访问的完整URL + const baseURL = window.location.origin; + url = `${baseURL}/api/v1/pipelines/${pipelineId}/chat/send`; + } + + // 使用fetch发送流式请求,因为axios在浏览器环境中不直接支持流式响应 + const response = await fetch(url, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${this.getSessionSync()}`, + }, + body: JSON.stringify({ + session_type: sessionType, + message: messageChain, + is_stream: true, + }), + }); + + if (!response.ok) { + throw new Error(`HTTP error! status: ${response.status}`); + } + + if (!response.body) { + throw new Error('ReadableStream not supported'); + } + + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + let buffer = ''; + + // 读取流式响应 + try { + while (true) { + const { done, value } = await reader.read(); + + if (done) { + onComplete(); + break; + } + + // 解码数据 + buffer += decoder.decode(value, { stream: true }); + + // 处理完整的JSON对象 + const lines = buffer.split('\n\n'); + buffer = lines.pop() || ''; + + for (const line of lines) { + if (line.startsWith('data:')) { + try { + const data = JSON.parse(line.slice(5)); + + if (data.type === 'end') { + // 流传输结束 + reader.cancel(); + onComplete(); + return; + } + if (data.type === 'start') { + console.log(data.type); + } + + if (data.message) { + // 处理消息数据 + onMessage(data); + } + } catch (error) { + console.error('Error parsing streaming data:', error); + } + } + } + } + } finally { + reader.releaseLock(); + } + } catch (error) { + onError(error as Error); + } + } + + public getWebChatHistoryMessages( + pipelineId: string, + sessionType: string, + ): Promise { + return this.get( + `/api/v1/pipelines/${pipelineId}/chat/messages/${sessionType}`, + ); + } + + public resetWebChatSession( + pipelineId: string, + sessionType: string, + ): Promise<{ message: string }> { + return this.post( + `/api/v1/pipelines/${pipelineId}/chat/reset/${sessionType}`, + ); + } + + // ============ Platform API ============ + public getAdapters(): Promise { + return this.get('/api/v1/platform/adapters'); + } + + public getAdapter(name: string): Promise { + return this.get(`/api/v1/platform/adapters/${name}`); + } + + public getAdapterIconURL(name: string): string { + if (this.instance.defaults.baseURL === '/') { + // 获取用户访问的URL + const url = window.location.href; + const baseURL = url.split('/').slice(0, 3).join('/'); + return `${baseURL}/api/v1/platform/adapters/${name}/icon`; + } + return ( + this.instance.defaults.baseURL + `/api/v1/platform/adapters/${name}/icon` + ); + } + + // ============ Platform Bots ============ + public getBots(): Promise { + return this.get('/api/v1/platform/bots'); + } + + public getBot(uuid: string): Promise { + return this.get(`/api/v1/platform/bots/${uuid}`); + } + + public createBot(bot: Bot): Promise<{ uuid: string }> { + return this.post('/api/v1/platform/bots', bot); + } + + public updateBot(uuid: string, bot: Bot): Promise { + return this.put(`/api/v1/platform/bots/${uuid}`, bot); + } + + public deleteBot(uuid: string): Promise { + return this.delete(`/api/v1/platform/bots/${uuid}`); + } + + public getBotLogs( + botId: string, + request: GetBotLogsRequest, + ): Promise { + return this.post(`/api/v1/platform/bots/${botId}/logs`, request); + } + + // ============ File management API ============ + public uploadDocumentFile(file: File): Promise<{ file_id: string }> { + const formData = new FormData(); + formData.append('file', file); + + return this.request<{ file_id: string }>({ + method: 'post', + url: '/api/v1/files/documents', + data: formData, + headers: { + 'Content-Type': 'multipart/form-data', + }, + }); + } + + // ============ Knowledge Base API ============ + public getKnowledgeBases(): Promise { + return this.get('/api/v1/knowledge/bases'); + } + + public getKnowledgeBase(uuid: string): Promise { + return this.get(`/api/v1/knowledge/bases/${uuid}`); + } + + public createKnowledgeBase(base: KnowledgeBase): Promise<{ uuid: string }> { + return this.post('/api/v1/knowledge/bases', base); + } + + public updateKnowledgeBase( + uuid: string, + base: KnowledgeBase, + ): Promise<{ uuid: string }> { + return this.put(`/api/v1/knowledge/bases/${uuid}`, base); + } + + public uploadKnowledgeBaseFile( + uuid: string, + file_id: string, + ): Promise { + return this.post(`/api/v1/knowledge/bases/${uuid}/files`, { + file_id, + }); + } + + public getKnowledgeBaseFiles( + uuid: string, + ): Promise { + return this.get(`/api/v1/knowledge/bases/${uuid}/files`); + } + + public deleteKnowledgeBaseFile( + uuid: string, + file_id: string, + ): Promise { + return this.delete(`/api/v1/knowledge/bases/${uuid}/files/${file_id}`); + } + + public deleteKnowledgeBase(uuid: string): Promise { + return this.delete(`/api/v1/knowledge/bases/${uuid}`); + } + + public retrieveKnowledgeBase( + uuid: string, + query: string, + ): Promise { + return this.post(`/api/v1/knowledge/bases/${uuid}/retrieve`, { query }); + } + + // ============ Plugins API ============ + public getPlugins(): Promise { + return this.get('/api/v1/plugins'); + } + + public getPlugin(author: string, name: string): Promise { + return this.get(`/api/v1/plugins/${author}/${name}`); + } + + public getPluginConfig( + author: string, + name: string, + ): Promise { + return this.get(`/api/v1/plugins/${author}/${name}/config`); + } + + public updatePluginConfig( + author: string, + name: string, + config: object, + ): Promise { + return this.put(`/api/v1/plugins/${author}/${name}/config`, config); + } + + public getPluginIconURL(author: string, name: string): string { + if (this.instance.defaults.baseURL === '/') { + const url = window.location.href; + const baseURL = url.split('/').slice(0, 3).join('/'); + return `${baseURL}/api/v1/plugins/${author}/${name}/icon`; + } + return ( + this.instance.defaults.baseURL + `/api/v1/plugins/${author}/${name}/icon` + ); + } + + public installPluginFromGithub( + source: string, + ): Promise { + return this.post('/api/v1/plugins/install/github', { source }); + } + + public installPluginFromLocal(file: File): Promise { + const formData = new FormData(); + formData.append('file', file); + return this.postFile('/api/v1/plugins/install/local', formData); + } + + public installPluginFromMarketplace( + author: string, + name: string, + version: string, + ): Promise { + return this.post('/api/v1/plugins/install/marketplace', { + plugin_author: author, + plugin_name: name, + plugin_version: version, + }); + } + + public removePlugin( + author: string, + name: string, + ): Promise { + return this.delete(`/api/v1/plugins/${author}/${name}`); + } + + public upgradePlugin( + author: string, + name: string, + ): Promise { + return this.post(`/api/v1/plugins/${author}/${name}/upgrade`); + } + + // ============ System API ============ + public getSystemInfo(): Promise { + return this.get('/api/v1/system/info'); + } + + public getAsyncTasks(): Promise { + return this.get('/api/v1/system/tasks'); + } + + public getAsyncTask(id: number): Promise { + return this.get(`/api/v1/system/tasks/${id}`); + } + + // ============ User API ============ + public checkIfInited(): Promise<{ initialized: boolean }> { + return this.get('/api/v1/user/init'); + } + + public initUser(user: string, password: string): Promise { + return this.post('/api/v1/user/init', { user, password }); + } + + public authUser(user: string, password: string): Promise { + return this.post('/api/v1/user/auth', { user, password }); + } + + public checkUserToken(): Promise { + return this.get('/api/v1/user/check-token'); + } + + public resetPassword( + user: string, + recoveryKey: string, + newPassword: string, + ): Promise<{ user: string }> { + return this.post('/api/v1/user/reset-password', { + user, + recovery_key: recoveryKey, + new_password: newPassword, + }); + } + + public changePassword( + currentPassword: string, + newPassword: string, + ): Promise<{ user: string }> { + return this.post('/api/v1/user/change-password', { + current_password: currentPassword, + new_password: newPassword, + }); + } +} diff --git a/web/src/app/infra/http/BaseHttpClient.ts b/web/src/app/infra/http/BaseHttpClient.ts new file mode 100644 index 00000000..019a54e6 --- /dev/null +++ b/web/src/app/infra/http/BaseHttpClient.ts @@ -0,0 +1,211 @@ +import axios, { + AxiosInstance, + AxiosRequestConfig, + AxiosResponse, + AxiosError, +} from 'axios'; + +type JSONValue = string | number | boolean | JSONObject | JSONArray | null; +interface JSONObject { + [key: string]: JSONValue; +} +type JSONArray = Array; + +export interface ResponseData { + code: number; + message: string; + data: T; + timestamp: number; +} + +export interface RequestConfig extends AxiosRequestConfig { + isSSR?: boolean; // 服务端渲染标识 + retry?: number; // 重试次数 +} + +/** + * 基础 HTTP 客户端类 + * 提供通用的 HTTP 请求方法和拦截器配置 + */ +export abstract class BaseHttpClient { + protected instance: AxiosInstance; + protected disableToken: boolean = false; + protected baseURL: string; + + constructor(baseURL: string, disableToken?: boolean) { + this.baseURL = baseURL; + this.disableToken = disableToken || false; + + this.instance = axios.create({ + baseURL: baseURL, + timeout: 15000, + headers: { + 'Content-Type': 'application/json', + }, + }); + + this.initInterceptors(); + } + + // 外部获取baseURL的方法 + public getBaseUrl(): string { + return this.baseURL; + } + + // 更新 baseURL + public updateBaseURL(newBaseURL: string): void { + this.baseURL = newBaseURL; + this.instance.defaults.baseURL = newBaseURL; + } + + // 同步获取Session + protected getSessionSync(): string | null { + if (typeof window !== 'undefined') { + return localStorage.getItem('token'); + } + return null; + } + + // 拦截器配置 + protected initInterceptors(): void { + // 请求拦截 + this.instance.interceptors.request.use( + async (config) => { + // 客户端添加认证头 + if (typeof window !== 'undefined' && !this.disableToken) { + const session = this.getSessionSync(); + if (session) { + config.headers.Authorization = `Bearer ${session}`; + } + } + + return config; + }, + (error) => Promise.reject(error), + ); + + // 响应拦截 + this.instance.interceptors.response.use( + (response: AxiosResponse) => { + return response; + }, + (error: AxiosError) => { + // 统一错误处理 + if (error.response) { + const { status, data } = error.response; + const errMessage = data?.message || error.message; + + switch (status) { + case 401: + console.log('401 error: ', errMessage, error.request); + console.log('responseURL', error.request.responseURL); + if (typeof window !== 'undefined') { + localStorage.removeItem('token'); + if (!error.request.responseURL.includes('/check-token')) { + window.location.href = '/login'; + } + } + break; + case 403: + console.error('Permission denied:', errMessage); + break; + case 500: + console.error('Server error:', errMessage); + break; + } + + return Promise.reject({ + code: data?.code || status, + message: errMessage, + data: data?.data || null, + }); + } + + return Promise.reject({ + code: -1, + message: error.message || 'Network Error', + data: null, + }); + }, + ); + } + + // 转换下划线为驼峰 + protected convertKeysToCamel(obj: JSONValue): JSONValue { + if (Array.isArray(obj)) { + return obj.map((v) => this.convertKeysToCamel(v)); + } else if (obj !== null && typeof obj === 'object') { + return Object.keys(obj).reduce((acc, key) => { + const camelKey = key.replace(/_([a-z])/g, (_, letter) => + letter.toUpperCase(), + ); + acc[camelKey] = this.convertKeysToCamel((obj as JSONObject)[key]); + return acc; + }, {} as JSONObject); + } + return obj; + } + + // 错误处理 + protected handleError(error: object): never { + if (axios.isCancel(error)) { + throw { code: -2, message: 'Request canceled', data: null }; + } + throw error; + } + + // 核心请求方法 + public async request(config: RequestConfig): Promise { + try { + const response = await this.instance.request>(config); + return response.data.data; + } catch (error) { + return this.handleError(error as object); + } + } + + // 快捷方法 + public get( + url: string, + params?: object, + config?: RequestConfig, + ): Promise { + return this.request({ method: 'get', url, params, ...config }); + } + + public post( + url: string, + data?: object, + config?: RequestConfig, + ): Promise { + return this.request({ method: 'post', url, data, ...config }); + } + + public put( + url: string, + data?: object, + config?: RequestConfig, + ): Promise { + return this.request({ method: 'put', url, data, ...config }); + } + + public delete(url: string, config?: RequestConfig): Promise { + return this.request({ method: 'delete', url, ...config }); + } + + public postFile( + url: string, + formData: FormData, + config?: RequestConfig, + ): Promise { + return this.request({ + method: 'post', + url, + data: formData, + headers: { + 'Content-Type': 'multipart/form-data', + }, + ...config, + }); + } +} diff --git a/web/src/app/infra/http/CloudServiceClient.ts b/web/src/app/infra/http/CloudServiceClient.ts new file mode 100644 index 00000000..f7491d5a --- /dev/null +++ b/web/src/app/infra/http/CloudServiceClient.ts @@ -0,0 +1,75 @@ +import { BaseHttpClient } from './BaseHttpClient'; +import { + ApiRespMarketplacePluginDetail, + ApiRespMarketplacePlugins, +} from '@/app/infra/entities/api'; + +/** + * 云服务客户端 + * 负责与 cloud service 的所有交互 + */ +export class CloudServiceClient extends BaseHttpClient { + constructor(baseURL: string = '') { + // cloud service 不需要 token 认证 + super(baseURL, true); + } + + public getMarketplacePlugins( + page: number, + page_size: number, + sort_by?: string, + sort_order?: string, + ): Promise { + return this.get('/api/v1/marketplace/plugins', { + page, + page_size, + sort_by, + sort_order, + }); + } + + public searchMarketplacePlugins( + query: string, + page: number, + page_size: number, + sort_by?: string, + sort_order?: string, + ): Promise { + return this.post( + '/api/v1/marketplace/plugins/search', + { + query, + page, + page_size, + sort_by, + sort_order, + }, + ); + } + + public getPluginDetail( + author: string, + pluginName: string, + ): Promise { + return this.get( + `/api/v1/marketplace/plugins/${author}/${pluginName}`, + ); + } + + public getPluginREADME( + author: string, + pluginName: string, + ): Promise<{ readme: string }> { + return this.get<{ readme: string }>( + `/api/v1/marketplace/plugins/${author}/${pluginName}/resources/README`, + ); + } + + public getPluginIconURL(author: string, name: string): string { + return `${this.baseURL}/api/v1/marketplace/plugins/${author}/${name}/resources/icon`; + } + + public getPluginMarketplaceURL(author: string, name: string): string { + return `${this.baseURL}/market?author=${author}&plugin=${name}`; + } +} diff --git a/web/src/app/infra/http/HttpClient.ts b/web/src/app/infra/http/HttpClient.ts index 9f5967d0..4e6f864f 100644 --- a/web/src/app/infra/http/HttpClient.ts +++ b/web/src/app/infra/http/HttpClient.ts @@ -1,741 +1,17 @@ -import axios, { - AxiosInstance, - AxiosRequestConfig, - AxiosResponse, - AxiosError, -} from 'axios'; -import { - ApiRespProviderRequesters, - ApiRespProviderRequester, - ApiRespProviderLLMModels, - ApiRespProviderLLMModel, - LLMModel, - ApiRespProviderEmbeddingModels, - ApiRespProviderEmbeddingModel, - EmbeddingModel, - ApiRespPipelines, - Pipeline, - ApiRespPlatformAdapters, - ApiRespPlatformAdapter, - ApiRespPlatformBots, - ApiRespPlatformBot, - Bot, - ApiRespPlugins, - ApiRespPlugin, - ApiRespPluginConfig, - PluginReorderElement, - AsyncTaskCreatedResp, - ApiRespSystemInfo, - ApiRespAsyncTasks, - ApiRespUserToken, - MarketPluginResponse, - GetPipelineResponseData, - GetPipelineMetadataResponseData, - AsyncTask, - ApiRespWebChatMessage, - ApiRespWebChatMessages, - ApiRespKnowledgeBases, - ApiRespKnowledgeBase, - KnowledgeBase, - ApiRespKnowledgeBaseFiles, - ApiRespKnowledgeBaseRetrieve, -} from '@/app/infra/entities/api'; -import { GetBotLogsRequest } from '@/app/infra/http/requestParam/bots/GetBotLogsRequest'; -import { GetBotLogsResponse } from '@/app/infra/http/requestParam/bots/GetBotLogsResponse'; - -type JSONValue = string | number | boolean | JSONObject | JSONArray | null; -interface JSONObject { - [key: string]: JSONValue; -} -type JSONArray = Array; - -export interface ResponseData { - code: number; - message: string; - data: T; - timestamp: number; -} - -export interface RequestConfig extends AxiosRequestConfig { - isSSR?: boolean; // 服务端渲染标识 - retry?: number; // 重试次数 -} - -export let systemInfo: ApiRespSystemInfo | null = null; - -class HttpClient { - private instance: AxiosInstance; - private disableToken: boolean = false; - private baseURL: string; - // 暂不需要SSR - // private ssrInstance: AxiosInstance | null = null - - constructor(baseURL: string, disableToken?: boolean) { - this.baseURL = baseURL; - this.instance = axios.create({ - baseURL: baseURL, - timeout: 15000, - headers: { - 'Content-Type': 'application/json', - }, - }); - this.disableToken = disableToken || false; - this.initInterceptors(); - - if (systemInfo === null && baseURL != 'https://space.langbot.app') { - this.getSystemInfo().then((res) => { - systemInfo = res; - }); - } - } - - // 外部获取baseURL的方法 - getBaseUrl(): string { - return this.baseURL; - } - - // 获取Session - private async getSession() { - // NOT IMPLEMENT - return ''; - } - - // 同步获取Session - private getSessionSync() { - // NOT IMPLEMENT - return localStorage.getItem('token'); - } - - // 拦截器配置 - private initInterceptors() { - // 请求拦截 - this.instance.interceptors.request.use( - async (config) => { - // 服务端请求自动携带 cookie, Langbot暂时用不到SSR相关 - // if (typeof window === 'undefined' && config.isSSR) { } - // cookie not required - // const { cookies } = await import('next/headers') - // config.headers.Cookie = cookies().toString() - - // 客户端添加认证头 - if (typeof window !== 'undefined' && !this.disableToken) { - const session = this.getSessionSync(); - config.headers.Authorization = `Bearer ${session}`; - } - - return config; - }, - (error) => Promise.reject(error), - ); - - // 响应拦截 - this.instance.interceptors.response.use( - (response: AxiosResponse) => { - // 响应拦截处理写在这里,暂无业务需要 - - return response; - }, - (error: AxiosError) => { - // 统一错误处理 - if (error.response) { - const { status, data } = error.response; - const errMessage = data?.message || error.message; - - switch (status) { - case 401: - console.log('401 error: ', errMessage, error.request); - console.log('responseURL', error.request.responseURL); - localStorage.removeItem('token'); - if (!error.request.responseURL.includes('/check-token')) { - window.location.href = '/login'; - } - break; - case 403: - console.error('Permission denied:', errMessage); - break; - case 500: - // NOTE: move to component layer for customized message? - // toast.error(errMessage); - console.error('Server error:', errMessage); - break; - } - - return Promise.reject({ - code: data?.code || status, - message: errMessage, - data: data?.data || null, - }); - } - - return Promise.reject({ - code: -1, - message: error.message || 'Network Error', - data: null, - }); - }, - ); - } - - // 转换下划线为驼峰 - private convertKeysToCamel(obj: JSONValue): JSONValue { - if (Array.isArray(obj)) { - return obj.map((v) => this.convertKeysToCamel(v)); - } else if (obj !== null && typeof obj === 'object') { - return Object.keys(obj).reduce((acc, key) => { - const camelKey = key.replace(/_([a-z])/g, (_, letter) => - letter.toUpperCase(), - ); - acc[camelKey] = this.convertKeysToCamel((obj as JSONObject)[key]); - return acc; - }, {} as JSONObject); - } - return obj; - } - - // 核心请求方法 - public async request(config: RequestConfig): Promise { - try { - // 这里未来如果需要SSR可以将前面替换为SSR的instance - const instance = config.isSSR ? this.instance : this.instance; - const response = await instance.request>(config); - return response.data.data; - } catch (error) { - return this.handleError(error as object); - } - } - - private handleError(error: object): never { - if (axios.isCancel(error)) { - throw { code: -2, message: 'Request canceled', data: null }; - } - throw error; - } - - // 快捷方法 - public get( - url: string, - params?: object, - config?: RequestConfig, - ) { - return this.request({ method: 'get', url, params, ...config }); - } - - public post(url: string, data?: object, config?: RequestConfig) { - return this.request({ method: 'post', url, data, ...config }); - } - - public put(url: string, data?: object, config?: RequestConfig) { - return this.request({ method: 'put', url, data, ...config }); - } - - public delete(url: string, config?: RequestConfig) { - return this.request({ method: 'delete', url, ...config }); - } - - // real api request implementation - // ============ Provider API ============ - public getProviderRequesters( - model_type: string, - ): Promise { - return this.get('/api/v1/provider/requesters', { type: model_type }); - } - - public getProviderRequester(name: string): Promise { - return this.get(`/api/v1/provider/requesters/${name}`); - } - - public getProviderRequesterIconURL(name: string): string { - if (this.instance.defaults.baseURL === '/') { - // 获取用户访问的URL - const url = window.location.href; - const baseURL = url.split('/').slice(0, 3).join('/'); - return `${baseURL}/api/v1/provider/requesters/${name}/icon`; - } - return ( - this.instance.defaults.baseURL + - `/api/v1/provider/requesters/${name}/icon` - ); - } - - // ============ Provider Model LLM ============ - public getProviderLLMModels(): Promise { - return this.get('/api/v1/provider/models/llm'); - } - - public getProviderLLMModel(uuid: string): Promise { - return this.get(`/api/v1/provider/models/llm/${uuid}`); - } - - public createProviderLLMModel(model: LLMModel): Promise { - return this.post('/api/v1/provider/models/llm', model); - } - - public deleteProviderLLMModel(uuid: string): Promise { - return this.delete(`/api/v1/provider/models/llm/${uuid}`); - } - - public updateProviderLLMModel( - uuid: string, - model: LLMModel, - ): Promise { - return this.put(`/api/v1/provider/models/llm/${uuid}`, model); - } - - public testLLMModel(uuid: string, model: LLMModel): Promise { - return this.post(`/api/v1/provider/models/llm/${uuid}/test`, model); - } - - // ============ Provider Model Embedding ============ - public getProviderEmbeddingModels(): Promise { - return this.get('/api/v1/provider/models/embedding'); - } - - public getProviderEmbeddingModel( - uuid: string, - ): Promise { - return this.get(`/api/v1/provider/models/embedding/${uuid}`); - } - - public createProviderEmbeddingModel(model: EmbeddingModel): Promise { - return this.post('/api/v1/provider/models/embedding', model); - } - - public deleteProviderEmbeddingModel(uuid: string): Promise { - return this.delete(`/api/v1/provider/models/embedding/${uuid}`); - } - - public updateProviderEmbeddingModel( - uuid: string, - model: EmbeddingModel, - ): Promise { - return this.put(`/api/v1/provider/models/embedding/${uuid}`, model); - } - - public testEmbeddingModel( - uuid: string, - model: EmbeddingModel, - ): Promise { - return this.post(`/api/v1/provider/models/embedding/${uuid}/test`, model); - } - - // ============ Pipeline API ============ - public getGeneralPipelineMetadata(): Promise { - // as designed, this method will be deprecated, and only for developer to check the prefered config schema - return this.get('/api/v1/pipelines/_/metadata'); - } - - public getPipelines( - sortBy?: string, - sortOrder?: string, - ): Promise { - const params = new URLSearchParams(); - if (sortBy) params.append('sort_by', sortBy); - if (sortOrder) params.append('sort_order', sortOrder); - const queryString = params.toString(); - return this.get(`/api/v1/pipelines${queryString ? `?${queryString}` : ''}`); - } - - public getPipeline(uuid: string): Promise { - return this.get(`/api/v1/pipelines/${uuid}`); - } - - public createPipeline(pipeline: Pipeline): Promise<{ - uuid: string; - }> { - return this.post('/api/v1/pipelines', pipeline); - } - - public updatePipeline(uuid: string, pipeline: Pipeline): Promise { - return this.put(`/api/v1/pipelines/${uuid}`, pipeline); - } - - public deletePipeline(uuid: string): Promise { - return this.delete(`/api/v1/pipelines/${uuid}`); - } - - // ============ Debug WebChat API ============ - public sendWebChatMessage( - sessionType: string, - messageChain: object[], - pipelineId: string, - timeout: number = 15000, - ): Promise { - return this.post( - `/api/v1/pipelines/${pipelineId}/chat/send`, - { - session_type: sessionType, - message: messageChain, - }, - { - timeout, - }, - ); - } - - public async sendStreamingWebChatMessage( - sessionType: string, - messageChain: object[], - pipelineId: string, - onMessage: (data: ApiRespWebChatMessage) => void, - onComplete: () => void, - onError: (error: Error) => void, - ): Promise { - try { - // 构造完整的URL,处理相对路径的情况 - let url = `${this.baseURL}/api/v1/pipelines/${pipelineId}/chat/send`; - if (this.baseURL === '/') { - // 获取用户访问的完整URL - const baseURL = window.location.origin; - url = `${baseURL}/api/v1/pipelines/${pipelineId}/chat/send`; - } - - // 使用fetch发送流式请求,因为axios在浏览器环境中不直接支持流式响应 - const response = await fetch(url, { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - Authorization: `Bearer ${this.getSessionSync()}`, - }, - body: JSON.stringify({ - session_type: sessionType, - message: messageChain, - is_stream: true, - }), - }); - - if (!response.ok) { - throw new Error(`HTTP error! status: ${response.status}`); - } - - if (!response.body) { - throw new Error('ReadableStream not supported'); - } - - const reader = response.body.getReader(); - const decoder = new TextDecoder(); - let buffer = ''; - - // 读取流式响应 - try { - while (true) { - const { done, value } = await reader.read(); - - if (done) { - onComplete(); - break; - } - - // 解码数据 - buffer += decoder.decode(value, { stream: true }); - - // 处理完整的JSON对象 - const lines = buffer.split('\n\n'); - buffer = lines.pop() || ''; - - for (const line of lines) { - if (line.startsWith('data:')) { - try { - const data = JSON.parse(line.slice(5)); - - if (data.type === 'end') { - // 流传输结束 - reader.cancel(); - onComplete(); - return; - } - if (data.type === 'start') { - console.log(data.type); - } - - if (data.message) { - // 处理消息数据 - onMessage(data); - } - } catch (error) { - console.error('Error parsing streaming data:', error); - } - } - } - } - } finally { - reader.releaseLock(); - } - } catch (error) { - onError(error as Error); - } - } - - public getWebChatHistoryMessages( - pipelineId: string, - sessionType: string, - ): Promise { - return this.get( - `/api/v1/pipelines/${pipelineId}/chat/messages/${sessionType}`, - ); - } - - public resetWebChatSession( - pipelineId: string, - sessionType: string, - ): Promise<{ message: string }> { - return this.post( - `/api/v1/pipelines/${pipelineId}/chat/reset/${sessionType}`, - ); - } - - // ============ Platform API ============ - public getAdapters(): Promise { - return this.get('/api/v1/platform/adapters'); - } - - public getAdapter(name: string): Promise { - return this.get(`/api/v1/platform/adapters/${name}`); - } - - public getAdapterIconURL(name: string): string { - if (this.instance.defaults.baseURL === '/') { - // 获取用户访问的URL - const url = window.location.href; - const baseURL = url.split('/').slice(0, 3).join('/'); - return `${baseURL}/api/v1/platform/adapters/${name}/icon`; - } - return ( - this.instance.defaults.baseURL + `/api/v1/platform/adapters/${name}/icon` - ); - } - - // ============ Platform Bots ============ - public getBots(): Promise { - return this.get('/api/v1/platform/bots'); - } - - public getBot(uuid: string): Promise { - return this.get(`/api/v1/platform/bots/${uuid}`); - } - - public createBot(bot: Bot): Promise<{ uuid: string }> { - return this.post('/api/v1/platform/bots', bot); - } - - public updateBot(uuid: string, bot: Bot): Promise { - return this.put(`/api/v1/platform/bots/${uuid}`, bot); - } - - public deleteBot(uuid: string): Promise { - return this.delete(`/api/v1/platform/bots/${uuid}`); - } - - public getBotLogs( - botId: string, - request: GetBotLogsRequest, - ): Promise { - return this.post(`/api/v1/platform/bots/${botId}/logs`, request); - } - - // ============ File management API ============ - public uploadDocumentFile(file: File): Promise<{ file_id: string }> { - const formData = new FormData(); - formData.append('file', file); - - return this.request<{ file_id: string }>({ - method: 'post', - url: '/api/v1/files/documents', - data: formData, - headers: { - 'Content-Type': 'multipart/form-data', - }, - }); - } - - // ============ Knowledge Base API ============ - public getKnowledgeBases(): Promise { - return this.get('/api/v1/knowledge/bases'); - } - - public getKnowledgeBase(uuid: string): Promise { - return this.get(`/api/v1/knowledge/bases/${uuid}`); - } - - public createKnowledgeBase(base: KnowledgeBase): Promise<{ uuid: string }> { - return this.post('/api/v1/knowledge/bases', base); - } - - public updateKnowledgeBase( - uuid: string, - base: KnowledgeBase, - ): Promise<{ uuid: string }> { - return this.put(`/api/v1/knowledge/bases/${uuid}`, base); - } - - public uploadKnowledgeBaseFile( - uuid: string, - file_id: string, - ): Promise { - return this.post(`/api/v1/knowledge/bases/${uuid}/files`, { - file_id, - }); - } - - public getKnowledgeBaseFiles( - uuid: string, - ): Promise { - return this.get(`/api/v1/knowledge/bases/${uuid}/files`); - } - - public deleteKnowledgeBaseFile( - uuid: string, - file_id: string, - ): Promise { - return this.delete(`/api/v1/knowledge/bases/${uuid}/files/${file_id}`); - } - - public deleteKnowledgeBase(uuid: string): Promise { - return this.delete(`/api/v1/knowledge/bases/${uuid}`); - } - - public retrieveKnowledgeBase( - uuid: string, - query: string, - ): Promise { - return this.post(`/api/v1/knowledge/bases/${uuid}/retrieve`, { query }); - } - - // ============ Plugins API ============ - public getPlugins(): Promise { - return this.get('/api/v1/plugins'); - } - - public getPlugin(author: string, name: string): Promise { - return this.get(`/api/v1/plugins/${author}/${name}`); - } - - public getPluginConfig( - author: string, - name: string, - ): Promise { - return this.get(`/api/v1/plugins/${author}/${name}/config`); - } - - public updatePluginConfig( - author: string, - name: string, - config: object, - ): Promise { - return this.put(`/api/v1/plugins/${author}/${name}/config`, config); - } - - public togglePlugin( - author: string, - name: string, - target_enabled: boolean, - ): Promise { - return this.put(`/api/v1/plugins/${author}/${name}/toggle`, { - target_enabled, - }); - } - - public reorderPlugins(plugins: PluginReorderElement[]): Promise { - return this.put('/api/v1/plugins/reorder', { plugins }); - } - - public updatePlugin( - author: string, - name: string, - ): Promise { - return this.post(`/api/v1/plugins/${author}/${name}/update`); - } - - public getMarketPlugins( - page: number, - page_size: number, - query: string, - sort_by: string = 'stars', - sort_order: string = 'DESC', - ): Promise { - return this.post(`/api/v1/market/plugins`, { - page, - page_size, - query, - sort_by, - sort_order, - }); - } - - public installPluginFromGithub( - source: string, - ): Promise { - return this.post('/api/v1/plugins/install/github', { source }); - } - - public removePlugin( - author: string, - name: string, - ): Promise { - return this.delete(`/api/v1/plugins/${author}/${name}`); - } - - // ============ System API ============ - public getSystemInfo(): Promise { - return this.get('/api/v1/system/info'); - } - - public getAsyncTasks(): Promise { - return this.get('/api/v1/system/tasks'); - } - - public getAsyncTask(id: number): Promise { - return this.get(`/api/v1/system/tasks/${id}`); - } - - // ============ User API ============ - public checkIfInited(): Promise<{ initialized: boolean }> { - return this.get('/api/v1/user/init'); - } - - public initUser(user: string, password: string): Promise { - return this.post('/api/v1/user/init', { user, password }); - } - - public authUser(user: string, password: string): Promise { - return this.post('/api/v1/user/auth', { user, password }); - } - - public checkUserToken(): Promise { - return this.get('/api/v1/user/check-token'); - } - - public resetPassword( - user: string, - recoveryKey: string, - newPassword: string, - ): Promise<{ user: string }> { - return this.post('/api/v1/user/reset-password', { - user, - recovery_key: recoveryKey, - new_password: newPassword, - }); - } - - public changePassword( - currentPassword: string, - newPassword: string, - ): Promise<{ user: string }> { - return this.post('/api/v1/user/change-password', { - current_password: currentPassword, - new_password: newPassword, - }); - } -} - -const getBaseURL = (): string => { - if (typeof window !== 'undefined' && process.env.NEXT_PUBLIC_API_BASE_URL) { - return process.env.NEXT_PUBLIC_API_BASE_URL; - } - - return '/'; -}; - -export const httpClient = new HttpClient(getBaseURL()); - -// 临时写法,未来两种Client都继承自HttpClient父类,不允许共享方法 -export const spaceClient = new HttpClient('https://space.langbot.app'); +/** + * @deprecated 此文件仅用于向后兼容。请使用新的 client: + * - import { backendClient } from '@/app/infra/http' + * - import { getCloudServiceClient } from '@/app/infra/http' + */ + +// 重新导出新的客户端实现,保持向后兼容 +export { + backendClient as httpClient, + systemInfo, + type ResponseData, + type RequestConfig, +} from './index'; + +// 为了兼容性,重新导出 BackendClient 作为 HttpClient +import { BackendClient } from './BackendClient'; +export const HttpClient = BackendClient; diff --git a/web/src/app/infra/http/README.md b/web/src/app/infra/http/README.md new file mode 100644 index 00000000..2a2e976b --- /dev/null +++ b/web/src/app/infra/http/README.md @@ -0,0 +1,72 @@ +# HTTP Client 架构说明 + +## 概述 + +HTTP Client 已经重构为更清晰的架构,将通用方法与业务逻辑分离,并为不同的服务创建了独立的客户端。 + +## 文件结构 + +- **BaseHttpClient.ts** - 基础 HTTP 客户端类,包含所有通用的 HTTP 方法和拦截器配置 +- **BackendClient.ts** - 后端服务客户端,处理与后端 API 的所有交互 +- **CloudServiceClient.ts** - 云服务客户端,处理与 cloud service 的交互(如插件市场) +- **index.ts** - 主入口文件,管理客户端实例的创建和导出 +- **HttpClient.ts** - 仅用于向后兼容的文件(已废弃) + +## 使用方法 + +### 新的推荐用法 + +```typescript +// 使用后端客户端 +import { backendClient } from '@/app/infra/http'; + +// 获取模型列表 +const models = await backendClient.getProviderLLMModels(); + +// 使用云服务客户端(异步方式,确保 URL 已初始化) +import { getCloudServiceClient } from '@/app/infra/http'; + +const cloudClient = await getCloudServiceClient(); +const marketPlugins = await cloudClient.getMarketPlugins(1, 10, 'search term'); + +// 使用云服务客户端(同步方式,可能使用默认 URL) +import { cloudServiceClient } from '@/app/infra/http'; + +const marketPlugins = await cloudServiceClient.getMarketPlugins( + 1, + 10, + 'search term', +); +``` + +### 向后兼容(不推荐) + +```typescript +// 旧的用法仍然可以工作 +import { httpClient, spaceClient } from '@/app/infra/http/HttpClient'; + +// httpClient 现在指向 backendClient +const models = await httpClient.getProviderLLMModels(); + +// spaceClient 现在指向 cloudServiceClient +const marketPlugins = await spaceClient.getMarketPlugins(1, 10, 'search term'); +``` + +## 特点 + +1. **清晰的职责分离** + - BaseHttpClient:通用 HTTP 功能 + - BackendClient:后端 API 业务逻辑 + - CloudServiceClient:云服务 API 业务逻辑 + +2. **自动初始化** + - 应用启动时自动从后端获取 cloud service URL + - 云服务客户端会自动更新 baseURL + +3. **类型安全** + - 所有方法都有完整的 TypeScript 类型定义 + - 请求和响应类型都从 `@/app/infra/entities/api` 导入 + +4. **向后兼容** + - 旧代码无需修改即可继续工作 + - 逐步迁移到新的 API diff --git a/web/src/app/infra/http/index.ts b/web/src/app/infra/http/index.ts new file mode 100644 index 00000000..df46dca0 --- /dev/null +++ b/web/src/app/infra/http/index.ts @@ -0,0 +1,87 @@ +import { BackendClient } from './BackendClient'; +import { CloudServiceClient } from './CloudServiceClient'; +import { ApiRespSystemInfo } from '@/app/infra/entities/api'; + +// 系统信息 +export let systemInfo: ApiRespSystemInfo = { + debug: false, + version: '', + enable_marketplace: true, + cloud_service_url: '', +}; + +/** + * 获取基础 URL + */ +const getBaseURL = (): string => { + if (typeof window !== 'undefined' && process.env.NEXT_PUBLIC_API_BASE_URL) { + return process.env.NEXT_PUBLIC_API_BASE_URL; + } + return '/'; +}; + +// 创建后端客户端实例 +export const backendClient = new BackendClient(getBaseURL()); + +// 创建云服务客户端实例(初始化时使用默认 URL) +export const cloudServiceClient = new CloudServiceClient( + 'https://space.langbot.app', +); + +// 应用启动时自动初始化系统信息 +if (typeof window !== 'undefined' && systemInfo.cloud_service_url === '') { + backendClient + .getSystemInfo() + .then((info) => { + systemInfo = info; + cloudServiceClient.updateBaseURL(info.cloud_service_url); + }) + .catch((error) => { + console.error('Failed to initialize system info on startup:', error); + }); +} + +/** + * 获取云服务客户端 + * 如果 cloud service URL 尚未初始化,会自动从后端获取 + */ +export const getCloudServiceClient = async (): Promise => { + if (systemInfo.cloud_service_url === '') { + try { + systemInfo = await backendClient.getSystemInfo(); + // 更新 cloud service client 的 baseURL + cloudServiceClient.updateBaseURL(systemInfo.cloud_service_url); + } catch (error) { + console.error('Failed to get system info:', error); + // 如果获取失败,继续使用默认 URL + } + } + return cloudServiceClient; +}; + +/** + * 获取云服务客户端(同步版本) + * 注意:如果 cloud service URL 尚未初始化,将使用默认 URL + */ +export const getCloudServiceClientSync = (): CloudServiceClient => { + return cloudServiceClient; +}; + +/** + * 手动初始化系统信息 + * 可以在应用启动时调用此方法预先获取系统信息 + */ +export const initializeSystemInfo = async (): Promise => { + try { + systemInfo = await backendClient.getSystemInfo(); + cloudServiceClient.updateBaseURL(systemInfo.cloud_service_url); + } catch (error) { + console.error('Failed to initialize system info:', error); + } +}; + +// 导出类型,以便其他地方使用 +export type { ResponseData, RequestConfig } from './BaseHttpClient'; +export { BaseHttpClient } from './BaseHttpClient'; +export { BackendClient } from './BackendClient'; +export { CloudServiceClient } from './CloudServiceClient'; diff --git a/web/src/hooks/useAsyncTask.ts b/web/src/hooks/useAsyncTask.ts new file mode 100644 index 00000000..085c3b97 --- /dev/null +++ b/web/src/hooks/useAsyncTask.ts @@ -0,0 +1,99 @@ +import { useState, useEffect, useRef } from 'react'; +import { httpClient } from '@/app/infra/http/HttpClient'; +import { AsyncTask } from '@/app/infra/entities/api'; + +export enum AsyncTaskStatus { + WAIT_INPUT = 'WAIT_INPUT', + RUNNING = 'RUNNING', + SUCCESS = 'SUCCESS', + ERROR = 'ERROR', +} + +export interface UseAsyncTaskOptions { + onSuccess?: () => void; + onError?: (error: string) => void; + pollInterval?: number; +} + +export interface UseAsyncTaskResult { + status: AsyncTaskStatus; + error: string | null; + startTask: (taskId: number) => void; + reset: () => void; +} + +export function useAsyncTask( + options: UseAsyncTaskOptions = {}, +): UseAsyncTaskResult { + const { onSuccess, onError, pollInterval = 1000 } = options; + + const [status, setStatus] = useState( + AsyncTaskStatus.WAIT_INPUT, + ); + const [error, setError] = useState(null); + const intervalRef = useRef(null); + const alreadySuccessRef = useRef(false); + + const clearPollingInterval = () => { + if (intervalRef.current) { + clearInterval(intervalRef.current); + intervalRef.current = null; + } + }; + + const reset = () => { + clearPollingInterval(); + setStatus(AsyncTaskStatus.WAIT_INPUT); + setError(null); + alreadySuccessRef.current = false; + }; + + const startTask = (taskId: number) => { + setStatus(AsyncTaskStatus.RUNNING); + setError(null); + alreadySuccessRef.current = false; + + const interval = setInterval(() => { + httpClient + .getAsyncTask(taskId) + .then((res: AsyncTask) => { + if (res.runtime.done) { + clearPollingInterval(); + if (res.runtime.exception) { + setError(res.runtime.exception); + setStatus(AsyncTaskStatus.ERROR); + onError?.(res.runtime.exception); + } else { + if (!alreadySuccessRef.current) { + alreadySuccessRef.current = true; + setStatus(AsyncTaskStatus.SUCCESS); + onSuccess?.(); + } + } + } + }) + .catch((error) => { + clearPollingInterval(); + const errorMessage = error.message || 'Unknown error'; + setError(errorMessage); + setStatus(AsyncTaskStatus.ERROR); + onError?.(errorMessage); + }); + }, pollInterval); + + intervalRef.current = interval; + }; + + useEffect(() => { + return () => { + clearPollingInterval(); + }; + }, []); + + return { + status, + error, + startTask, + reset, + }; +} diff --git a/web/src/i18n/I18nProvider.tsx b/web/src/i18n/I18nProvider.tsx index ef3ea0b7..55fcd4c8 100644 --- a/web/src/i18n/I18nProvider.tsx +++ b/web/src/i18n/I18nProvider.tsx @@ -2,7 +2,8 @@ import { ReactNode } from 'react'; import '@/i18n'; -import { I18nLabel } from '@/app/infra/entities/common'; +import { I18nObject } from '@/app/infra/entities/common'; +import i18n from 'i18next'; interface I18nProviderProps { children: ReactNode; @@ -11,10 +12,28 @@ interface I18nProviderProps { export default function I18nProvider({ children }: I18nProviderProps) { return <>{children}; } -export function i18nObj(i18nLabel: I18nLabel): string { - const language = localStorage.getItem('langbot_language'); - if ((language === 'zh-Hans' && i18nLabel.zh_Hans) || !i18nLabel.en_US) { - return i18nLabel.zh_Hans; - } - return i18nLabel.en_US; -} +// export function extractI18nObject(i18nLabel: I18nObject): string { +// const language = localStorage.getItem('langbot_language'); +// if ((language === 'zh-Hans' && i18nLabel.zh_Hans) || !i18nLabel.en_US) { +// return i18nLabel.zh_Hans; +// } +// return i18nLabel.en_US; +// } + +export const extractI18nObject = (i18nObject: I18nObject): string => { + // 根据当前语言返回对应的值, fallback优先级:en_US、zh_Hans、zh_Hant、ja_JP + const language = i18n.language.replace('-', '_'); + console.log('language:', language); + console.log('i18nObject:', i18nObject); + if (language === 'en_US' && i18nObject.en_US) return i18nObject.en_US; + if (language === 'zh_Hans' && i18nObject.zh_Hans) return i18nObject.zh_Hans; + if (language === 'zh_Hant' && i18nObject.zh_Hant) return i18nObject.zh_Hant; + if (language === 'ja_JP' && i18nObject.ja_JP) return i18nObject.ja_JP; + return ( + i18nObject.en_US || + i18nObject.zh_Hans || + i18nObject.zh_Hant || + i18nObject.ja_JP || + '' + ); +}; diff --git a/web/src/i18n/locales/en-US.ts b/web/src/i18n/locales/en-US.ts index be09bfe8..32a86763 100644 --- a/web/src/i18n/locales/en-US.ts +++ b/web/src/i18n/locales/en-US.ts @@ -160,7 +160,7 @@ const enUS = { marketplace: 'Marketplace', arrange: 'Sort Plugins', install: 'Install', - installFromGithub: 'Install Plugin from GitHub', + installPlugin: 'Install Plugin', onlySupportGithub: 'Currently only supports installation from GitHub', enterGithubLink: 'Enter GitHub link of the plugin', installing: 'Installing plugin...', @@ -188,15 +188,88 @@ const enUS = { saveConfig: 'Save Config', saving: 'Saving...', confirmDeletePlugin: - 'Are you sure you want to delete the plugin ({{author}}/{{name}})?', + 'Are you sure you want to delete the plugin ({{author}}/{{name}})? This will also delete the plugin configuration.', confirmDelete: 'Confirm Delete', deleteError: 'Delete failed: ', close: 'Close', deleteConfirm: 'Delete Confirmation', + deleteSuccess: 'Delete successful', modifyFailed: 'Modify failed: ', eventCount: 'Events: {{count}}', toolCount: 'Tools: {{count}}', starCount: 'Stars: {{count}}', + uploadLocal: 'Upload Local', + debugging: 'Debugging', + uploadLocalPlugin: 'Upload Local Plugin', + dragToUpload: 'Drag plugin file here to upload', + unsupportedFileType: + 'Unsupported file type, only .lbpkg and .zip files are supported', + uploadingPlugin: 'Uploading plugin...', + uploadSuccess: 'Upload successful', + uploadFailed: 'Upload failed', + selectFileToUpload: 'Select plugin file to upload', + askConfirm: 'Are you sure to install plugin "{{name}}" ({{version}})?', + fromGithub: 'From GitHub', + fromLocal: 'From Local', + fromMarketplace: 'From Marketplace', + componentsList: 'Components: ', + noComponents: 'No components', + delete: 'Delete Plugin', + update: 'Update Plugin', + updateConfirm: 'Update Confirmation', + confirmUpdatePlugin: + 'Are you sure you want to update the plugin ({{author}}/{{name}})?', + confirmUpdate: 'Confirm Update', + updating: 'Updating...', + updateSuccess: 'Plugin updated successfully', + updateError: 'Update failed: ', + saveConfigSuccessNormal: 'Configuration saved successfully', + saveConfigSuccessDebugPlugin: + 'Configuration saved successfully, please manually restart the plugin', + saveConfigError: 'Configuration save failed: ', + }, + market: { + searchPlaceholder: 'Search plugins...', + searchResults: 'Found {{count}} plugins', + totalPlugins: 'Total {{count}} plugins', + noPlugins: 'No plugins available', + noResults: 'No relevant plugins found', + loadingMore: 'Loading more...', + loading: 'Loading...', + allLoaded: 'All plugins displayed', + install: 'Install', + installConfirm: + 'Are you sure you want to install plugin "{{name}}" ({{version}})?', + downloadComplete: 'Plugin "{{name}}" download completed', + installFailed: 'Installation failed, please try again later', + loadFailed: 'Failed to get plugin list, please try again later', + noDescription: 'No description available', + notFound: 'Plugin information not found', + sortBy: 'Sort by', + sort: { + recentlyAdded: 'Recently Added', + recentlyUpdated: 'Recently Updated', + mostDownloads: 'Most Downloads', + leastDownloads: 'Least Downloads', + }, + downloads: 'downloads', + download: 'Download', + repository: 'Repository', + downloadFailed: 'Download failed', + noReadme: 'This plugin does not provide README documentation', + description: 'Description', + tags: 'Tags', + submissionTitle: 'You have a plugin submission under review: {{name}}', + submissionPending: 'Your plugin submission is under review: {{name}}', + submissionApproved: 'Your plugin submission has been approved: {{name}}', + submissionRejected: 'Your plugin submission has been rejected: {{name}}', + clickToRevoke: 'Revoke', + revokeSuccess: 'Revoke success', + revokeFailed: 'Revoke failed', + submissionDetails: 'Plugin Submission Details', + markAsRead: 'Mark as Read', + markAsReadSuccess: 'Marked as read', + markAsReadFailed: 'Mark as read failed', }, pipelines: { title: 'Pipelines', diff --git a/web/src/i18n/locales/ja-JP.ts b/web/src/i18n/locales/ja-JP.ts index 43b78f2a..36b17f00 100644 --- a/web/src/i18n/locales/ja-JP.ts +++ b/web/src/i18n/locales/ja-JP.ts @@ -161,7 +161,7 @@ const jaJP = { marketplace: 'プラグインマーケット', arrange: '並び替え', install: 'インストール', - installFromGithub: 'GitHubからプラグインをインストール', + installPlugin: 'プラグインをインストール', onlySupportGithub: '現在はGitHubからのインストールのみサポートしています', enterGithubLink: 'プラグインのGitHubリンクを入力してください', installing: 'プラグインをインストール中...', @@ -189,15 +189,89 @@ const jaJP = { saveConfig: '設定を保存', saving: '保存中...', confirmDeletePlugin: - 'プラグイン「{{author}}/{{name}}」を削除してもよろしいですか?', + 'プラグイン「{{author}}/{{name}}」を削除してもよろしいですか?この操作により、プラグインの設定も削除されます。', confirmDelete: '削除を確認', deleteError: '削除に失敗しました:', close: '閉じる', deleteConfirm: '削除の確認', + deleteSuccess: '削除に成功しました', modifyFailed: '変更に失敗しました:', eventCount: 'イベント:{{count}}', toolCount: 'ツール:{{count}}', starCount: 'スター:{{count}}', + uploadLocal: 'ローカルアップロード', + debugging: 'デバッグ中', + uploadLocalPlugin: 'ローカルプラグインのアップロード', + dragToUpload: 'ファイルをここにドラッグしてアップロード', + unsupportedFileType: + 'サポートされていないファイルタイプです。.lbpkg と .zip ファイルのみサポートされています', + uploadingPlugin: 'プラグインをアップロード中...', + uploadSuccess: 'アップロード成功', + uploadFailed: 'アップロード失敗', + selectFileToUpload: 'アップロードするプラグインファイルを選択', + askConfirm: 'プラグイン "{{name}}" ({{version}}) をインストールしますか?', + fromGithub: 'GitHubから', + fromLocal: 'ローカルから', + fromMarketplace: 'プラグインマーケットから', + componentsList: '部品:', + noComponents: '部品がありません', + delete: 'プラグインを削除', + update: 'プラグインを更新', + updateConfirm: '更新の確認', + confirmUpdatePlugin: + 'プラグイン「{{author}}/{{name}}」を更新してもよろしいですか?', + confirmUpdate: '更新を確認', + updating: '更新中...', + updateSuccess: 'プラグインの更新に成功しました', + updateError: '更新に失敗しました:', + saveConfigSuccessNormal: '設定を保存しました', + saveConfigSuccessDebugPlugin: + '設定を保存しました。手動でプラグインを再起動してください', + saveConfigError: '設定の保存に失敗しました:', + }, + market: { + searchPlaceholder: 'プラグインを検索...', + searchResults: '{{count}} 個のプラグインが見つかりました', + totalPlugins: '合計 {{count}} 個のプラグイン', + noPlugins: '利用可能なプラグインがありません', + noResults: '関連するプラグインが見つかりません', + loadingMore: 'さらに読み込み中...', + loading: '読み込み中...', + allLoaded: 'すべてのプラグインが表示されました', + install: 'インストール', + installConfirm: + 'プラグイン "{{name}}" ({{version}}) をインストールしますか?', + downloadComplete: 'プラグイン "{{name}}" のダウンロードが完了しました', + installFailed: 'インストールに失敗しました。後でもう一度お試しください', + loadFailed: + 'プラグインリストの取得に失敗しました。後でもう一度お試しください', + noDescription: '説明がありません', + notFound: 'プラグイン情報が見つかりません', + sortBy: '並び順', + sort: { + recentlyAdded: '最近追加', + recentlyUpdated: '最近更新', + mostDownloads: 'ダウンロード数多', + leastDownloads: 'ダウンロード数少', + }, + downloads: '回ダウンロード', + download: 'ダウンロード', + repository: 'リポジトリ', + downloadFailed: 'ダウンロード失敗', + noReadme: 'このプラグインはREADMEドキュメントを提供していません', + description: '説明', + tags: 'タグ', + submissionTitle: 'プラグインの提出が審査中です: {{name}}', + submissionPending: 'プラグインの提出が審査中です: {{name}}', + submissionApproved: 'プラグインの提出が承認されました: {{name}}', + submissionRejected: 'プラグインの提出が拒否されました: {{name}}', + clickToRevoke: '取り消し', + revokeSuccess: '取り消し成功', + revokeFailed: '取り消し失敗', + submissionDetails: 'プラグイン提出詳細', + markAsRead: '既読', + markAsReadSuccess: '既読に設定しました', + markAsReadFailed: '既読に設定に失敗しました', }, pipelines: { title: 'パイプライン', diff --git a/web/src/i18n/locales/zh-Hans.ts b/web/src/i18n/locales/zh-Hans.ts index adf20e26..f891dee9 100644 --- a/web/src/i18n/locales/zh-Hans.ts +++ b/web/src/i18n/locales/zh-Hans.ts @@ -156,7 +156,7 @@ const zhHans = { marketplace: '插件市场', arrange: '编排', install: '安装', - installFromGithub: '从 GitHub 安装插件', + installPlugin: '安装插件', onlySupportGithub: '目前仅支持从 GitHub 安装', enterGithubLink: '请输入插件的Github链接', installing: '正在安装插件...', @@ -183,15 +183,84 @@ const zhHans = { cancel: '取消', saveConfig: '保存配置', saving: '保存中...', - confirmDeletePlugin: '你确定要删除插件({{author}}/{{name}})吗?', + confirmDeletePlugin: + '你确定要删除插件({{author}}/{{name}})吗?这将同时删除插件的配置。', confirmDelete: '确认删除', deleteError: '删除失败:', close: '关闭', deleteConfirm: '删除确认', + deleteSuccess: '删除成功', modifyFailed: '修改失败:', eventCount: '事件:{{count}}', toolCount: '工具:{{count}}', starCount: '星标:{{count}}', + uploadLocal: '本地上传', + debugging: '调试中', + uploadLocalPlugin: '上传本地插件', + dragToUpload: '拖拽文件到此处上传', + unsupportedFileType: '不支持的文件类型,仅支持 .lbpkg 和 .zip 文件', + uploadingPlugin: '正在上传插件...', + uploadSuccess: '上传成功', + uploadFailed: '上传失败', + selectFileToUpload: '选择要上传的插件文件', + askConfirm: '确定要安装插件 "{{name}}" ({{version}}) 吗?', + fromGithub: '来自 GitHub', + fromLocal: '本地安装', + fromMarketplace: '来自市场', + componentsList: '组件: ', + noComponents: '无组件', + delete: '删除插件', + update: '更新插件', + updateConfirm: '更新确认', + confirmUpdatePlugin: '你确定要更新插件({{author}}/{{name}})吗?', + confirmUpdate: '确认更新', + updating: '更新中...', + updateSuccess: '插件更新成功', + updateError: '更新失败:', + saveConfigSuccessNormal: '保存配置成功', + saveConfigSuccessDebugPlugin: '保存配置成功,请手动重启插件', + saveConfigError: '保存配置失败:', + }, + market: { + searchPlaceholder: '搜索插件...', + searchResults: '搜索到 {{count}} 个插件', + totalPlugins: '共 {{count}} 个插件', + noPlugins: '暂无插件', + noResults: '未找到相关插件', + loadingMore: '加载更多...', + loading: '加载中...', + allLoaded: '已显示全部插件', + install: '安装', + installConfirm: '确定要安装插件 "{{name}}" ({{version}}) 吗?', + downloadComplete: '插件 "{{name}}" 下载完成', + installFailed: '安装失败,请稍后重试', + loadFailed: '获取插件列表失败,请稍后重试', + noDescription: '暂无描述', + notFound: '插件信息未找到', + sortBy: '排序方式', + sort: { + recentlyAdded: '最近新增', + recentlyUpdated: '最近更新', + mostDownloads: '最多下载', + leastDownloads: '最少下载', + }, + downloads: '次下载', + download: '下载', + repository: '代码仓库', + downloadFailed: '下载失败', + noReadme: '该插件没有提供 README 文档', + description: '描述', + tags: '标签', + submissionTitle: '您有插件提交正在审核中: {{name}}', + submissionApproved: '您的插件提交已通过审核: {{name}}', + submissionRejected: '您的插件提交已被拒绝: {{name}}', + clickToRevoke: '撤回', + revokeSuccess: '撤回成功', + revokeFailed: '撤回失败', + submissionDetails: '插件提交详情', + markAsRead: '已读', + markAsReadSuccess: '已标记为已读', + markAsReadFailed: '标记为已读失败', }, pipelines: { title: '流水线', diff --git a/web/src/i18n/locales/zh-Hant.ts b/web/src/i18n/locales/zh-Hant.ts index 0c90e1a3..5f9ce702 100644 --- a/web/src/i18n/locales/zh-Hant.ts +++ b/web/src/i18n/locales/zh-Hant.ts @@ -192,6 +192,73 @@ const zhHant = { eventCount: '事件:{{count}}', toolCount: '工具:{{count}}', starCount: '星標:{{count}}', + uploadLocal: '本地上傳', + debugging: '調試中', + uploadLocalPlugin: '上傳本地插件', + dragToUpload: '拖拽文件到此處上傳', + unsupportedFileType: '不支持的文件類型,僅支持 .lbpkg 和 .zip 文件', + uploadingPlugin: '正在上傳插件...', + uploadSuccess: '上傳成功', + uploadFailed: '上傳失敗', + selectFileToUpload: '選擇要上傳的插件文件', + askConfirm: '確定要安裝插件 "{{name}}" ({{version}}) 嗎?', + fromGithub: '來自 GitHub', + fromLocal: '本地安裝', + fromMarketplace: '來自市場', + componentsList: '組件: ', + noComponents: '無組件', + delete: '刪除插件', + update: '更新插件', + updateConfirm: '更新確認', + confirmUpdatePlugin: '您確定要更新插件({{author}}/{{name}})嗎?', + confirmUpdate: '確認更新', + updating: '更新中...', + updateSuccess: '插件更新成功', + updateError: '更新失敗:', + saveConfigSuccessNormal: '儲存配置成功', + saveConfigSuccessDebugPlugin: '儲存配置成功,請手動重啟插件', + saveConfigError: '儲存配置失敗:', + }, + market: { + searchPlaceholder: '搜尋插件...', + searchResults: '搜尋到 {{count}} 個插件', + totalPlugins: '共 {{count}} 個插件', + noPlugins: '暫無插件', + noResults: '未找到相關插件', + loadingMore: '載入更多...', + loading: '載入中...', + allLoaded: '已顯示全部插件', + install: '安裝', + installConfirm: '確定要安裝插件 "{{name}}" ({{version}}) 嗎?', + downloadComplete: '插件 "{{name}}" 下載完成', + installFailed: '安裝失敗,請稍後重試', + loadFailed: '取得插件列表失敗,請稍後重試', + noDescription: '暫無描述', + notFound: '插件資訊未找到', + sortBy: '排序方式', + sort: { + recentlyAdded: '最近新增', + recentlyUpdated: '最近更新', + mostDownloads: '最多下載', + leastDownloads: '最少下載', + }, + downloads: '次下載', + download: '下載', + repository: '代碼倉庫', + downloadFailed: '下載失敗', + noReadme: '該插件沒有提供 README 文件', + description: '描述', + tags: '標籤', + submissionTitle: '您有插件提交正在審核中: {{name}}', + submissionApproved: '您的插件提交已通過審核: {{name}}', + submissionRejected: '您的插件提交已被拒絕: {{name}}', + clickToRevoke: '撤回', + revokeSuccess: '撤回成功', + revokeFailed: '撤回失敗', + submissionDetails: '插件提交詳情', + markAsRead: '已讀', + markAsReadSuccess: '已標記為已讀', + markAsReadFailed: '標記為已讀失敗', }, pipelines: { title: '流程線',

+ ), + td: ({ ...props }) => ( + + ), + tr: ({ ...props }) => ( +