feat: new plugin system (#1495)

* deps: add `langbot-plugin`

* feat: connector for plugin runtime

* feat(plugin): basic communication

* feat: listing plugins

* feat: switch tool entities and format

* feat: switch Query to langbot-plugin definition

* chore: delete Query class

* feat: switch message platform adapters to sdk

* chore: remove adapter meta manifest from components.yaml

* feat: preliminary migration of events entities

* fix: serialization bug in events emitting

* feat: minor changes adapt to event emitting

* feat: adapt more events

* feat: switch all event emitting logic to new method

* refactor: use `emit_event` from connector

* feat: runtime reconnecting

* feat: add Tool component

* feat: switch command entities to sdk

* feat: command execution via plugin

* feat: `reply_message` api

* feat: get bot uuid api

* feat: query-based apis

* refactor: switch llm_entities to plugin sdk

* feat: backward call apis

* perf: longer timeout for emit_event

* feat: binary storage api

* feat(ui): list plugins

* feat: get plugin info

* feat: kill runtime process when exit in stdio mode

* perf: dispose process

* chore: bump langbot-plugin version to 0.1.1a1

* fix: message chain init

* feat: `get_bot_info` api

* feat: set cloud_service_url

* feat: refactor webui httpclient

* fix: bot switching

* feat: tag debugging plugins in webui

* feat: plugin installation

* feat: plugin installation webui

* feat: trace plugin installation

* feat: marketplace page

* perf: frontend

* fix: i18n fallback

* feat: plugin operations

* feat: plugin deletion and upgrade

* feat: setting plugin config

* feat: bump version of langbot-plugin

* chore: remove plugin reorder functionality

* chore: bump version 4.3.0b1

* chore: bump langbot_plugin version

* fix: conflict in table `plugin_settings`

* chore: bump version to '4.3.0b2'

* chore: bump version 4.3.0b3

* Update package.json (#1627)

* feat: change standalone runtime tag env

* fix: use --standalone-runtime

* feat: update docker launch method

* fix: change tag of image to `latest`

* perf: inline code display style in markdown

* fix: syntax errors

* fix: wrong migration target version

* fix: set plugin enabled=true as default

* fix: replace message_chain.has usage

* fix: dark mode for plugins management page

* fix: minor bugs

* fix: tool call params in localagent

* chore: bump version 4.3.0b4

* feat: available for disabling marketplace(offline env)

* perf: display installed plugin icon

* refactor: market plugin detail dialog

* perf: dark theme

* fix: cloudServiceClient api

* feat: supports for command return image base64

* chore: bump langbot_plugin to 0.1.1b6

* del self.ap error

* fix: dingtalk pydantic.BaseModel norm

* fix: wechatpad pydantic.BaseModel norm

* chore: move docker-compose.yaml for plugin edition

---------

Co-authored-by: How-Sean Xin <mcjiekejiemi@163.com>
Co-authored-by: fdc <2213070223@qq.com>
This commit is contained in:
Junyan Qin (Chin)
2025-09-12 23:00:49 +08:00
committed by GitHub
194 changed files with 5773 additions and 6629 deletions

View File

@@ -9,7 +9,6 @@ spec:
components: components:
ComponentTemplate: ComponentTemplate:
fromFiles: fromFiles:
- pkg/platform/adapter.yaml
- pkg/provider/modelmgr/requester.yaml - pkg/provider/modelmgr/requester.yaml
MessagePlatformAdapter: MessagePlatformAdapter:
fromDirs: fromDirs:

View File

@@ -1,3 +1,4 @@
# This file is deprecated, and will be replaced by docker/docker-compose.yaml in next version.
version: "3" version: "3"
services: services:
@@ -13,4 +14,4 @@ services:
ports: ports:
- 5300:5300 # 供 WebUI 使用 - 5300:5300 # 供 WebUI 使用
- 2280-2290:2280-2290 # 供消息平台适配器方向连接 - 2280-2290:2280-2290 # 供消息平台适配器方向连接
# 根据具体环境配置网络 # 根据具体环境配置网络

View File

@@ -0,0 +1,36 @@
version: "3"
services:
langbot_plugin_runtime:
image: rockchin/langbot:latest
container_name: langbot_plugin_runtime
volumes:
- ./data/plugins:/app/data/plugins
ports:
- 5401:5401
restart: on-failure
environment:
- TZ=Asia/Shanghai
command: ["uv", "run", "-m", "langbot_plugin.cli.__init__", "rt"]
networks:
- langbot_network
langbot:
image: rockchin/langbot:latest
container_name: langbot
volumes:
- ./data:/app/data
- ./plugins:/app/plugins
restart: on-failure
environment:
- TZ=Asia/Shanghai
ports:
- 5300:5300 # For web ui
- 2280-2290:2280-2290 # For platform webhook
networks:
- langbot_network
networks:
langbot_network:
driver: bridge

View File

@@ -3,7 +3,7 @@ from quart import request
import httpx import httpx
from quart import Quart from quart import Quart
from typing import Callable, Dict, Any from typing import Callable, Dict, Any
from pkg.platform.types import events as platform_events import langbot_plugin.api.entities.builtin.platform.events as platform_events
from .qqofficialevent import QQOfficialEvent from .qqofficialevent import QQOfficialEvent
import json import json
import traceback import traceback

View File

@@ -4,7 +4,7 @@ from quart import Quart, jsonify, request
from slack_sdk.web.async_client import AsyncWebClient from slack_sdk.web.async_client import AsyncWebClient
from .slackevent import SlackEvent from .slackevent import SlackEvent
from typing import Callable from typing import Callable
from pkg.platform.types import events as platform_events import langbot_plugin.api.entities.builtin.platform.events as platform_events
class SlackClient: class SlackClient:

View File

@@ -8,7 +8,7 @@ from quart import Quart
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from typing import Callable, Dict, Any from typing import Callable, Dict, Any
from .wecomevent import WecomEvent from .wecomevent import WecomEvent
from pkg.platform.types import message as platform_message import langbot_plugin.api.entities.builtin.platform.message as platform_message
import aiofiles import aiofiles

View File

@@ -8,7 +8,7 @@ from quart import Quart
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from typing import Callable from typing import Callable
from .wecomcsevent import WecomCSEvent from .wecomcsevent import WecomCSEvent
from pkg.platform.types import message as platform_message import langbot_plugin.api.entities.builtin.platform.message as platform_message
import aiofiles import aiofiles

16
main.py
View File

@@ -19,8 +19,14 @@ asciiart = r"""
async def main_entry(loop: asyncio.AbstractEventLoop): async def main_entry(loop: asyncio.AbstractEventLoop):
parser = argparse.ArgumentParser(description='LangBot') parser = argparse.ArgumentParser(description='LangBot')
parser.add_argument('--skip-plugin-deps-check', action='store_true', help='跳过插件依赖项检查', default=False) parser.add_argument('--skip-plugin-deps-check', action='store_true', help='跳过插件依赖项检查', default=False)
parser.add_argument('--standalone-runtime', action='store_true', help='使用独立插件运行时', default=False)
args = parser.parse_args() args = parser.parse_args()
if args.standalone_runtime:
from pkg.utils import platform
platform.standalone_runtime = True
print(asciiart) print(asciiart)
import sys import sys
@@ -47,13 +53,13 @@ async def main_entry(loop: asyncio.AbstractEventLoop):
if not args.skip_plugin_deps_check: if not args.skip_plugin_deps_check:
await deps.precheck_plugin_deps() await deps.precheck_plugin_deps()
# 检查pydantic版本如果没有 pydantic.v1则把 pydantic 映射为 v1 # # 检查pydantic版本如果没有 pydantic.v1则把 pydantic 映射为 v1
import pydantic.version # import pydantic.version
if pydantic.version.VERSION < '2.0': # if pydantic.version.VERSION < '2.0':
import pydantic # import pydantic
sys.modules['pydantic.v1'] = pydantic # sys.modules['pydantic.v1'] = pydantic
# 检查配置文件 # 检查配置文件

View File

@@ -44,9 +44,9 @@ class WebChatDebugRouterGroup(group.RouterGroup):
'Content-Type': 'text/event-stream', 'Content-Type': 'text/event-stream',
'Transfer-Encoding': 'chunked', 'Transfer-Encoding': 'chunked',
'Cache-Control': 'no-cache', 'Cache-Control': 'no-cache',
'Connection': 'keep-alive' 'Connection': 'keep-alive',
} }
return quart.Response(stream_generator(generator), mimetype='text/event-stream',headers=headers) return quart.Response(stream_generator(generator), mimetype='text/event-stream', headers=headers)
else: # non-stream else: # non-stream
result = None result = None

View File

@@ -1,10 +1,11 @@
from __future__ import annotations from __future__ import annotations
import base64
import quart import quart
from .....core import taskmgr from .....core import taskmgr
from .. import group from .. import group
from langbot_plugin.runtime.plugin.mgr import PluginInstallSource
@group.group_class('plugins', '/api/v1/plugins') @group.group_class('plugins', '/api/v1/plugins')
@@ -12,35 +13,22 @@ class PluginsRouterGroup(group.RouterGroup):
async def initialize(self) -> None: async def initialize(self) -> None:
@self.route('', methods=['GET'], auth_type=group.AuthType.USER_TOKEN) @self.route('', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str: async def _() -> str:
plugins = self.ap.plugin_mgr.plugins() plugins = await self.ap.plugin_connector.list_plugins()
plugins_data = [plugin.model_dump() for plugin in plugins] return self.success(data={'plugins': plugins})
return self.success(data={'plugins': plugins_data})
@self.route( @self.route(
'/<author>/<plugin_name>/toggle', '/<author>/<plugin_name>/upgrade',
methods=['PUT'],
auth_type=group.AuthType.USER_TOKEN,
)
async def _(author: str, plugin_name: str) -> str:
data = await quart.request.json
target_enabled = data.get('target_enabled')
await self.ap.plugin_mgr.update_plugin_switch(plugin_name, target_enabled)
return self.success()
@self.route(
'/<author>/<plugin_name>/update',
methods=['POST'], methods=['POST'],
auth_type=group.AuthType.USER_TOKEN, auth_type=group.AuthType.USER_TOKEN,
) )
async def _(author: str, plugin_name: str) -> str: async def _(author: str, plugin_name: str) -> str:
ctx = taskmgr.TaskContext.new() ctx = taskmgr.TaskContext.new()
wrapper = self.ap.task_mgr.create_user_task( wrapper = self.ap.task_mgr.create_user_task(
self.ap.plugin_mgr.update_plugin(plugin_name, task_context=ctx), self.ap.plugin_connector.upgrade_plugin(author, plugin_name, task_context=ctx),
kind='plugin-operation', kind='plugin-operation',
name=f'plugin-update-{plugin_name}', name=f'plugin-upgrade-{plugin_name}',
label=f'Updating plugin {plugin_name}', label=f'Upgrading plugin {plugin_name}',
context=ctx, context=ctx,
) )
return self.success(data={'task_id': wrapper.id}) return self.success(data={'task_id': wrapper.id})
@@ -52,14 +40,14 @@ class PluginsRouterGroup(group.RouterGroup):
) )
async def _(author: str, plugin_name: str) -> str: async def _(author: str, plugin_name: str) -> str:
if quart.request.method == 'GET': if quart.request.method == 'GET':
plugin = self.ap.plugin_mgr.get_plugin(author, plugin_name) plugin = await self.ap.plugin_connector.get_plugin_info(author, plugin_name)
if plugin is None: if plugin is None:
return self.http_status(404, -1, 'plugin not found') return self.http_status(404, -1, 'plugin not found')
return self.success(data={'plugin': plugin.model_dump()}) return self.success(data={'plugin': plugin})
elif quart.request.method == 'DELETE': elif quart.request.method == 'DELETE':
ctx = taskmgr.TaskContext.new() ctx = taskmgr.TaskContext.new()
wrapper = self.ap.task_mgr.create_user_task( wrapper = self.ap.task_mgr.create_user_task(
self.ap.plugin_mgr.uninstall_plugin(plugin_name, task_context=ctx), self.ap.plugin_connector.delete_plugin(author, plugin_name, task_context=ctx),
kind='plugin-operation', kind='plugin-operation',
name=f'plugin-remove-{plugin_name}', name=f'plugin-remove-{plugin_name}',
label=f'Removing plugin {plugin_name}', label=f'Removing plugin {plugin_name}',
@@ -74,23 +62,32 @@ class PluginsRouterGroup(group.RouterGroup):
auth_type=group.AuthType.USER_TOKEN, auth_type=group.AuthType.USER_TOKEN,
) )
async def _(author: str, plugin_name: str) -> quart.Response: async def _(author: str, plugin_name: str) -> quart.Response:
plugin = self.ap.plugin_mgr.get_plugin(author, plugin_name) plugin = await self.ap.plugin_connector.get_plugin_info(author, plugin_name)
if plugin is None: if plugin is None:
return self.http_status(404, -1, 'plugin not found') return self.http_status(404, -1, 'plugin not found')
if quart.request.method == 'GET': if quart.request.method == 'GET':
return self.success(data={'config': plugin.plugin_config}) return self.success(data={'config': plugin['plugin_config']})
elif quart.request.method == 'PUT': elif quart.request.method == 'PUT':
data = await quart.request.json data = await quart.request.json
await self.ap.plugin_mgr.set_plugin_config(plugin, data) await self.ap.plugin_connector.set_plugin_config(author, plugin_name, data)
return self.success(data={}) return self.success(data={})
@self.route('/reorder', methods=['PUT'], auth_type=group.AuthType.USER_TOKEN) @self.route(
async def _() -> str: '/<author>/<plugin_name>/icon',
data = await quart.request.json methods=['GET'],
await self.ap.plugin_mgr.reorder_plugins(data.get('plugins')) auth_type=group.AuthType.NONE,
return self.success() )
async def _(author: str, plugin_name: str) -> quart.Response:
icon_data = await self.ap.plugin_connector.get_plugin_icon(author, plugin_name)
icon_base64 = icon_data['plugin_icon_base64']
mime_type = icon_data['mime_type']
icon_data = base64.b64decode(icon_base64)
return quart.Response(icon_data, mimetype=mime_type)
@self.route('/install/github', methods=['POST'], auth_type=group.AuthType.USER_TOKEN) @self.route('/install/github', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str: async def _() -> str:
@@ -102,7 +99,47 @@ class PluginsRouterGroup(group.RouterGroup):
self.ap.plugin_mgr.install_plugin(data['source'], task_context=ctx), self.ap.plugin_mgr.install_plugin(data['source'], task_context=ctx),
kind='plugin-operation', kind='plugin-operation',
name='plugin-install-github', name='plugin-install-github',
label=f'Installing plugin ...{short_source_str}', label=f'Installing plugin from github ...{short_source_str}',
context=ctx,
)
return self.success(data={'task_id': wrapper.id})
@self.route('/install/marketplace', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str:
data = await quart.request.json
ctx = taskmgr.TaskContext.new()
wrapper = self.ap.task_mgr.create_user_task(
self.ap.plugin_connector.install_plugin(PluginInstallSource.MARKETPLACE, data, task_context=ctx),
kind='plugin-operation',
name='plugin-install-marketplace',
label=f'Installing plugin from marketplace ...{data}',
context=ctx,
)
return self.success(data={'task_id': wrapper.id})
@self.route('/install/local', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str:
file = (await quart.request.files).get('file')
if file is None:
return self.http_status(400, -1, 'file is required')
file_bytes = file.read()
file_base64 = base64.b64encode(file_bytes).decode('utf-8')
data = {
'plugin_file': file_base64,
}
ctx = taskmgr.TaskContext.new()
wrapper = self.ap.task_mgr.create_user_task(
self.ap.plugin_connector.install_plugin(PluginInstallSource.LOCAL, data, task_context=ctx),
kind='plugin-operation',
name='plugin-install-local',
label=f'Installing plugin from local ...{file.filename}',
context=ctx, context=ctx,
) )

View File

@@ -14,6 +14,12 @@ class SystemRouterGroup(group.RouterGroup):
'version': constants.semantic_version, 'version': constants.semantic_version,
'debug': constants.debug_mode, 'debug': constants.debug_mode,
'enabled_platform_count': len(self.ap.platform_mgr.get_running_adapters()), 'enabled_platform_count': len(self.ap.platform_mgr.get_running_adapters()),
'enable_marketplace': self.ap.instance_config.data['plugin'].get('enable_marketplace', True),
'cloud_service_url': (
self.ap.instance_config.data['plugin']['cloud_service_url']
if 'cloud_service_url' in self.ap.instance_config.data['plugin']
else 'https://space.langbot.app'
),
} }
) )
@@ -35,16 +41,7 @@ class SystemRouterGroup(group.RouterGroup):
return self.success(data=task.to_dict()) return self.success(data=task.to_dict())
@self.route('/reload', methods=['POST'], auth_type=group.AuthType.USER_TOKEN) @self.route('/debug/exec', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str:
json_data = await quart.request.json
scope = json_data.get('scope')
await self.ap.reload(scope=scope)
return self.success()
@self.route('/_debug/exec', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str: async def _() -> str:
if not constants.debug_mode: if not constants.debug_mode:
return self.http_status(403, 403, 'Forbidden') return self.http_status(403, 403, 'Forbidden')
@@ -54,3 +51,39 @@ class SystemRouterGroup(group.RouterGroup):
ap = self.ap ap = self.ap
return self.success(data=exec(py_code, {'ap': ap})) return self.success(data=exec(py_code, {'ap': ap}))
@self.route('/debug/tools/call', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str:
if not constants.debug_mode:
return self.http_status(403, 403, 'Forbidden')
data = await quart.request.json
return self.success(
data=await self.ap.tool_mgr.execute_func_call(data['tool_name'], data['tool_parameters'])
)
@self.route(
'/debug/plugin/action',
methods=['POST'],
auth_type=group.AuthType.USER_TOKEN,
)
async def _() -> str:
if not constants.debug_mode:
return self.http_status(403, 403, 'Forbidden')
data = await quart.request.json
class AnoymousAction:
value = 'anonymous_action'
def __init__(self, value: str):
self.value = value
resp = await self.ap.plugin_connector.handler.call_action(
AnoymousAction(data['action']),
data['data'],
timeout=data.get('timeout', 10),
)
return self.success(data=resp)

View File

@@ -71,15 +71,15 @@ class UserRouterGroup(group.RouterGroup):
@self.route('/change-password', methods=['POST'], auth_type=group.AuthType.USER_TOKEN) @self.route('/change-password', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
async def _(user_email: str) -> str: async def _(user_email: str) -> str:
json_data = await quart.request.json json_data = await quart.request.json
current_password = json_data['current_password'] current_password = json_data['current_password']
new_password = json_data['new_password'] new_password = json_data['new_password']
try: try:
await self.ap.user_service.change_password(user_email, current_password, new_password) await self.ap.user_service.change_password(user_email, current_password, new_password)
except argon2.exceptions.VerifyMismatchError: except argon2.exceptions.VerifyMismatchError:
return self.http_status(400, -1, 'Current password is incorrect') return self.http_status(400, -1, 'Current password is incorrect')
except ValueError as e: except ValueError as e:
return self.http_status(400, -1, str(e)) return self.http_status(400, -1, str(e))
return self.success(data={'user': user_email}) return self.success(data={'user': user_email})

View File

@@ -17,16 +17,20 @@ class BotService:
def __init__(self, ap: app.Application) -> None: def __init__(self, ap: app.Application) -> None:
self.ap = ap self.ap = ap
async def get_bots(self) -> list[dict]: async def get_bots(self, include_secret: bool = True) -> list[dict]:
"""Get all bots""" """获取所有机器人"""
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_bot.Bot)) result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_bot.Bot))
bots = result.all() bots = result.all()
return [self.ap.persistence_mgr.serialize_model(persistence_bot.Bot, bot) for bot in bots] masked_columns = []
if not include_secret:
masked_columns = ['adapter_config']
async def get_bot(self, bot_uuid: str) -> dict | None: return [self.ap.persistence_mgr.serialize_model(persistence_bot.Bot, bot, masked_columns) for bot in bots]
"""Get bot"""
async def get_bot(self, bot_uuid: str, include_secret: bool = True) -> dict | None:
"""获取机器人"""
result = await self.ap.persistence_mgr.execute_async( result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_bot.Bot).where(persistence_bot.Bot.uuid == bot_uuid) sqlalchemy.select(persistence_bot.Bot).where(persistence_bot.Bot.uuid == bot_uuid)
) )
@@ -36,7 +40,27 @@ class BotService:
if bot is None: if bot is None:
return None return None
return self.ap.persistence_mgr.serialize_model(persistence_bot.Bot, bot) masked_columns = []
if not include_secret:
masked_columns = ['adapter_config']
return self.ap.persistence_mgr.serialize_model(persistence_bot.Bot, bot, masked_columns)
async def get_runtime_bot_info(self, bot_uuid: str, include_secret: bool = True) -> dict:
"""获取机器人运行时信息"""
persistence_bot = await self.get_bot(bot_uuid, include_secret)
if persistence_bot is None:
raise Exception('Bot not found')
adapter_runtime_values = {}
runtime_bot = await self.ap.platform_mgr.get_bot_by_uuid(bot_uuid)
if runtime_bot is not None:
adapter_runtime_values['bot_account_id'] = runtime_bot.adapter.bot_account_id
persistence_bot['adapter_runtime_values'] = adapter_runtime_values
return persistence_bot
async def create_bot(self, bot_data: dict) -> str: async def create_bot(self, bot_data: dict) -> str:
"""Create bot""" """Create bot"""

View File

@@ -7,7 +7,7 @@ from ....core import app
from ....entity.persistence import model as persistence_model from ....entity.persistence import model as persistence_model
from ....entity.persistence import pipeline as persistence_pipeline from ....entity.persistence import pipeline as persistence_pipeline
from ....provider.modelmgr import requester as model_requester from ....provider.modelmgr import requester as model_requester
from ....provider import entities as llm_entities from langbot_plugin.api.entities.builtin.provider import message as provider_message
class LLMModelsService: class LLMModelsService:
@@ -16,11 +16,19 @@ class LLMModelsService:
def __init__(self, ap: app.Application) -> None: def __init__(self, ap: app.Application) -> None:
self.ap = ap self.ap = ap
async def get_llm_models(self) -> list[dict]: async def get_llm_models(self, include_secret: bool = True) -> list[dict]:
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.LLMModel)) result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.LLMModel))
models = result.all() models = result.all()
return [self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model) for model in models]
masked_columns = []
if not include_secret:
masked_columns = ['api_keys']
return [
self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model, masked_columns)
for model in models
]
async def create_llm_model(self, model_data: dict) -> str: async def create_llm_model(self, model_data: dict) -> str:
model_data['uuid'] = str(uuid.uuid4()) model_data['uuid'] = str(uuid.uuid4())
@@ -99,7 +107,7 @@ class LLMModelsService:
await runtime_llm_model.requester.invoke_llm( await runtime_llm_model.requester.invoke_llm(
query=None, query=None,
model=runtime_llm_model, model=runtime_llm_model,
messages=[llm_entities.Message(role='user', content='Hello, world!')], messages=[provider_message.Message(role='user', content='Hello, world!')],
funcs=[], funcs=[],
extra_args=model_data.get('extra_args', {}), extra_args=model_data.get('extra_args', {}),
) )

View File

@@ -85,15 +85,15 @@ class UserService:
async def change_password(self, user_email: str, current_password: str, new_password: str) -> None: async def change_password(self, user_email: str, current_password: str, new_password: str) -> None:
ph = argon2.PasswordHasher() ph = argon2.PasswordHasher()
user_obj = await self.get_user_by_email(user_email) user_obj = await self.get_user_by_email(user_email)
if user_obj is None: if user_obj is None:
raise ValueError('User not found') raise ValueError('User not found')
ph.verify(user_obj.password, current_password) ph.verify(user_obj.password, current_password)
hashed_password = ph.hash(new_password) hashed_password = ph.hash(new_password)
await self.ap.persistence_mgr.execute_async( await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(user.User).where(user.User.user == user_email).values(password=hashed_password) sqlalchemy.update(user.User).where(user.User.user == user_email).values(password=hashed_password)
) )

View File

@@ -2,9 +2,12 @@ from __future__ import annotations
import typing import typing
from ..core import app, entities as core_entities from ..core import app
from . import entities, operator, errors from . import operator
from ..utils import importutil from ..utils import importutil
import langbot_plugin.api.entities.builtin.provider.session as provider_session
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
# 引入所有算子以便注册 # 引入所有算子以便注册
from . import operators from . import operators
@@ -13,13 +16,11 @@ importutil.import_modules_in_pkg(operators)
class CommandManager: class CommandManager:
"""命令管理器"""
ap: app.Application ap: app.Application
cmd_list: list[operator.CommandOperator] cmd_list: list[operator.CommandOperator]
""" """
运行时命令列表,扁平存储,各个对象包含对应的子节点引用 Runtime command list, flat storage, each object contains a reference to the corresponding child node
""" """
def __init__(self, ap: app.Application): def __init__(self, ap: app.Application):
@@ -55,43 +56,28 @@ class CommandManager:
async def _execute( async def _execute(
self, self,
context: entities.ExecuteContext, context: command_context.ExecuteContext,
operator_list: list[operator.CommandOperator], operator_list: list[operator.CommandOperator],
operator: operator.CommandOperator = None, operator: operator.CommandOperator = None,
) -> typing.AsyncGenerator[entities.CommandReturn, None]: ) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
"""执行命令""" """执行命令"""
found = False command_list = await self.ap.plugin_connector.list_commands()
if len(context.crt_params) > 0: # 查找下一个参数是否对应此节点的某个子节点名
for oper in operator_list:
if (context.crt_params[0] == oper.name or context.crt_params[0] in oper.alias) and (
oper.parent_class is None or oper.parent_class == operator.__class__
):
found = True
context.crt_command = context.crt_params[0] for command in command_list:
context.crt_params = context.crt_params[1:] if command.metadata.name == context.command:
async for ret in self.ap.plugin_connector.execute_command(context):
async for ret in self._execute(context, oper.children, oper): yield ret
yield ret break
break else:
yield command_context.CommandReturn(error=command_errors.CommandNotFoundError(context.command))
if not found: # 如果下一个参数未在此节点的子节点中找到,则执行此节点或者报错
if operator is None:
yield entities.CommandReturn(error=errors.CommandNotFoundError(context.crt_params[0]))
else:
if operator.lowest_privilege > context.privilege:
yield entities.CommandReturn(error=errors.CommandPrivilegeError(operator.name))
else:
async for ret in operator.execute(context):
yield ret
async def execute( async def execute(
self, self,
command_text: str, command_text: str,
query: core_entities.Query, query: pipeline_query.Query,
session: core_entities.Session, session: provider_session.Session,
) -> typing.AsyncGenerator[entities.CommandReturn, None]: ) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
"""执行命令""" """执行命令"""
privilege = 1 privilege = 1
@@ -99,8 +85,8 @@ class CommandManager:
if f'{query.launcher_type.value}_{query.launcher_id}' in self.ap.instance_config.data['admins']: if f'{query.launcher_type.value}_{query.launcher_id}' in self.ap.instance_config.data['admins']:
privilege = 2 privilege = 2
ctx = entities.ExecuteContext( ctx = command_context.ExecuteContext(
query=query, query_id=query.query_id,
session=session, session=session,
command_text=command_text, command_text=command_text,
command='', command='',
@@ -110,5 +96,9 @@ class CommandManager:
privilege=privilege, privilege=privilege,
) )
ctx.command = ctx.params[0]
ctx.shift()
async for ret in self._execute(ctx, self.cmd_list): async for ret in self._execute(ctx, self.cmd_list):
yield ret yield ret

View File

@@ -1,74 +0,0 @@
from __future__ import annotations
import typing
import pydantic.v1 as pydantic
from ..core import entities as core_entities
from . import errors
from ..platform.types import message as platform_message
class CommandReturn(pydantic.BaseModel):
"""命令返回值"""
text: typing.Optional[str] = None
"""文本
"""
image: typing.Optional[platform_message.Image] = None
"""弃用"""
image_url: typing.Optional[str] = None
"""图片链接
"""
error: typing.Optional[errors.CommandError] = None
"""错误
"""
class Config:
arbitrary_types_allowed = True
class ExecuteContext(pydantic.BaseModel):
"""单次命令执行上下文"""
query: core_entities.Query
"""本次消息的请求对象"""
session: core_entities.Session
"""本次消息所属的会话对象"""
command_text: str
"""命令完整文本"""
command: str
"""命令名称"""
crt_command: str
"""当前命令
多级命令中crt_command为当前命令command为根命令。
例如:!plugin on Webwlkr
处理到plugin时command为plugincrt_command为plugin
处理到on时command为plugincrt_command为on
"""
params: list[str]
"""命令参数
整个命令以空格分割后的参数列表
"""
crt_params: list[str]
"""当前命令参数
多级命令中crt_params为当前命令参数params为根命令参数。
例如:!plugin on Webwlkr
处理到plugin时params为['on', 'Webwlkr']crt_params为['on', 'Webwlkr']
处理到on时params为['on', 'Webwlkr']crt_params为['Webwlkr']
"""
privilege: int
"""发起人权限"""

View File

@@ -1,26 +0,0 @@
class CommandError(Exception):
def __init__(self, message: str = None):
self.message = message
def __str__(self):
return self.message
class CommandNotFoundError(CommandError):
def __init__(self, message: str = None):
super().__init__('未知命令: ' + message)
class CommandPrivilegeError(CommandError):
def __init__(self, message: str = None):
super().__init__('权限不足: ' + message)
class ParamNotEnoughError(CommandError):
def __init__(self, message: str = None):
super().__init__('参数不足: ' + message)
class CommandOperationError(CommandError):
def __init__(self, message: str = None):
super().__init__('操作失败: ' + message)

View File

@@ -4,7 +4,7 @@ import typing
import abc import abc
from ..core import app from ..core import app
from . import entities from langbot_plugin.api.entities.builtin.command import context as command_context
preregistered_operators: list[typing.Type[CommandOperator]] = [] preregistered_operators: list[typing.Type[CommandOperator]] = []
@@ -95,16 +95,18 @@ class CommandOperator(metaclass=abc.ABCMeta):
pass pass
@abc.abstractmethod @abc.abstractmethod
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
"""实现此方法以执行命令 """实现此方法以执行命令
支持多次yield以返回多个结果。 支持多次yield以返回多个结果。
例如:一个安装插件的命令,可能会有下载、解压、安装等多个步骤,每个步骤都可以返回一个结果。 例如:一个安装插件的命令,可能会有下载、解压、安装等多个步骤,每个步骤都可以返回一个结果。
Args: Args:
context (entities.ExecuteContext): 命令执行上下文 context (command_context.ExecuteContext): 命令执行上下文
Yields: Yields:
entities.CommandReturn: 命令返回封装 command_context.CommandReturn: 命令返回封装
""" """
pass pass

View File

@@ -2,14 +2,17 @@ from __future__ import annotations
import typing import typing
from .. import operator, entities, errors from .. import operator
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
@operator.operator_class(name='cmd', help='显示命令列表', usage='!cmd\n!cmd <命令名称>') @operator.operator_class(name='cmd', help='显示命令列表', usage='!cmd\n!cmd <命令名称>')
class CmdOperator(operator.CommandOperator): class CmdOperator(operator.CommandOperator):
"""命令列表""" """命令列表"""
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
"""执行""" """执行"""
if len(context.crt_params) == 0: if len(context.crt_params) == 0:
reply_str = '当前所有命令: \n\n' reply_str = '当前所有命令: \n\n'
@@ -20,7 +23,7 @@ class CmdOperator(operator.CommandOperator):
reply_str += '\n使用 !cmd <命令名称> 查看命令的详细帮助' reply_str += '\n使用 !cmd <命令名称> 查看命令的详细帮助'
yield entities.CommandReturn(text=reply_str.strip()) yield command_context.CommandReturn(text=reply_str.strip())
else: else:
cmd_name = context.crt_params[0] cmd_name = context.crt_params[0]
@@ -33,9 +36,9 @@ class CmdOperator(operator.CommandOperator):
break break
if cmd is None: if cmd is None:
yield entities.CommandReturn(error=errors.CommandNotFoundError(cmd_name)) yield command_context.CommandReturn(error=command_errors.CommandNotFoundError(cmd_name))
else: else:
reply_str = f'{cmd.name}: {cmd.help}\n\n' reply_str = f'{cmd.name}: {cmd.help}\n\n'
reply_str += f'使用方法: \n{cmd.usage}' reply_str += f'使用方法: \n{cmd.usage}'
yield entities.CommandReturn(text=reply_str.strip()) yield command_context.CommandReturn(text=reply_str.strip())

View File

@@ -2,23 +2,26 @@ from __future__ import annotations
import typing import typing
from .. import operator, entities, errors from .. import operator
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
@operator.operator_class(name='del', help='删除当前会话的历史记录', usage='!del <序号>\n!del all') @operator.operator_class(name='del', help='删除当前会话的历史记录', usage='!del <序号>\n!del all')
class DelOperator(operator.CommandOperator): class DelOperator(operator.CommandOperator):
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
if context.session.conversations: if context.session.conversations:
delete_index = 0 delete_index = 0
if len(context.crt_params) > 0: if len(context.crt_params) > 0:
try: try:
delete_index = int(context.crt_params[0]) delete_index = int(context.crt_params[0])
except Exception: except Exception:
yield entities.CommandReturn(error=errors.CommandOperationError('索引必须是整数')) yield command_context.CommandReturn(error=command_errors.CommandOperationError('索引必须是整数'))
return return
if delete_index < 0 or delete_index >= len(context.session.conversations): if delete_index < 0 or delete_index >= len(context.session.conversations):
yield entities.CommandReturn(error=errors.CommandOperationError('索引超出范围')) yield command_context.CommandReturn(error=command_errors.CommandOperationError('索引超出范围'))
return return
# 倒序 # 倒序
@@ -29,15 +32,17 @@ class DelOperator(operator.CommandOperator):
del context.session.conversations[to_delete_index] del context.session.conversations[to_delete_index]
yield entities.CommandReturn(text=f'已删除对话: {delete_index}') yield command_context.CommandReturn(text=f'已删除对话: {delete_index}')
else: else:
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话')) yield command_context.CommandReturn(error=command_errors.CommandOperationError('当前没有对话'))
@operator.operator_class(name='all', help='删除此会话的所有历史记录', parent_class=DelOperator) @operator.operator_class(name='all', help='删除此会话的所有历史记录', parent_class=DelOperator)
class DelAllOperator(operator.CommandOperator): class DelAllOperator(operator.CommandOperator):
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
context.session.conversations = [] context.session.conversations = []
context.session.using_conversation = None context.session.using_conversation = None
yield entities.CommandReturn(text='已删除所有对话') yield command_context.CommandReturn(text='已删除所有对话')

View File

@@ -1,19 +1,20 @@
from __future__ import annotations from __future__ import annotations
from typing import AsyncGenerator from typing import AsyncGenerator
from .. import operator, entities from .. import operator
from langbot_plugin.api.entities.builtin.command import context as command_context
@operator.operator_class(name='func', help='查看所有已注册的内容函数', usage='!func') @operator.operator_class(name='func', help='查看所有已注册的内容函数', usage='!func')
class FuncOperator(operator.CommandOperator): class FuncOperator(operator.CommandOperator):
async def execute(self, context: entities.ExecuteContext) -> AsyncGenerator[entities.CommandReturn, None]: async def execute(
self, context: command_context.ExecuteContext
) -> AsyncGenerator[command_context.CommandReturn, None]:
reply_str = '当前已启用的内容函数: \n\n' reply_str = '当前已启用的内容函数: \n\n'
index = 1 index = 1
all_functions = await self.ap.tool_mgr.get_all_functions( all_functions = await self.ap.tool_mgr.get_all_tools()
plugin_enabled=True,
)
for func in all_functions: for func in all_functions:
reply_str += '{}. {}:\n{}\n\n'.format( reply_str += '{}. {}:\n{}\n\n'.format(
@@ -23,4 +24,4 @@ class FuncOperator(operator.CommandOperator):
) )
index += 1 index += 1
yield entities.CommandReturn(text=reply_str) yield command_context.CommandReturn(text=reply_str)

View File

@@ -2,14 +2,17 @@ from __future__ import annotations
import typing import typing
from .. import operator, entities from .. import operator
from langbot_plugin.api.entities.builtin.command import context as command_context
@operator.operator_class(name='help', help='显示帮助', usage='!help\n!help <命令名称>') @operator.operator_class(name='help', help='显示帮助', usage='!help\n!help <命令名称>')
class HelpOperator(operator.CommandOperator): class HelpOperator(operator.CommandOperator):
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
help = 'LangBot - 大语言模型原生即时通信机器人平台\n链接https://langbot.app' help = 'LangBot - 大语言模型原生即时通信机器人平台\n链接https://langbot.app'
help += '\n发送命令 !cmd 可查看命令列表' help += '\n发送命令 !cmd 可查看命令列表'
yield entities.CommandReturn(text=help) yield command_context.CommandReturn(text=help)

View File

@@ -3,26 +3,31 @@ from __future__ import annotations
import typing import typing
from .. import operator, entities, errors from .. import operator
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
@operator.operator_class(name='last', help='切换到前一个对话', usage='!last') @operator.operator_class(name='last', help='切换到前一个对话', usage='!last')
class LastOperator(operator.CommandOperator): class LastOperator(operator.CommandOperator):
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
if context.session.conversations: if context.session.conversations:
# 找到当前会话的上一个会话 # 找到当前会话的上一个会话
for index in range(len(context.session.conversations) - 1, -1, -1): for index in range(len(context.session.conversations) - 1, -1, -1):
if context.session.conversations[index] == context.session.using_conversation: if context.session.conversations[index] == context.session.using_conversation:
if index == 0: if index == 0:
yield entities.CommandReturn(error=errors.CommandOperationError('已经是第一个对话了')) yield command_context.CommandReturn(
error=command_errors.CommandOperationError('已经是第一个对话了')
)
return return
else: else:
context.session.using_conversation = context.session.conversations[index - 1] context.session.using_conversation = context.session.conversations[index - 1]
time_str = context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S') time_str = context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S')
yield entities.CommandReturn( yield command_context.CommandReturn(
text=f'已切换到上一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].readable_str()}' text=f'已切换到上一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].readable_str()}'
) )
return return
else: else:
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话')) yield command_context.CommandReturn(error=command_errors.CommandOperationError('当前没有对话'))

View File

@@ -2,19 +2,22 @@ from __future__ import annotations
import typing import typing
from .. import operator, entities, errors from .. import operator
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
@operator.operator_class(name='list', help='列出此会话中的所有历史对话', usage='!list\n!list <页码>') @operator.operator_class(name='list', help='列出此会话中的所有历史对话', usage='!list\n!list <页码>')
class ListOperator(operator.CommandOperator): class ListOperator(operator.CommandOperator):
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
page = 0 page = 0
if len(context.crt_params) > 0: if len(context.crt_params) > 0:
try: try:
page = int(context.crt_params[0] - 1) page = int(context.crt_params[0] - 1)
except Exception: except Exception:
yield entities.CommandReturn(error=errors.CommandOperationError('页码应为整数')) yield command_context.CommandReturn(error=command_errors.CommandOperationError('页码应为整数'))
return return
record_per_page = 10 record_per_page = 10
@@ -45,4 +48,4 @@ class ListOperator(operator.CommandOperator):
else: else:
content += f'\n当前会话: {using_conv_index} {context.session.using_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S")}: {context.session.using_conversation.messages[0].readable_str() if len(context.session.using_conversation.messages) > 0 else "无内容"}' content += f'\n当前会话: {using_conv_index} {context.session.using_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S")}: {context.session.using_conversation.messages[0].readable_str() if len(context.session.using_conversation.messages) > 0 else "无内容"}'
yield entities.CommandReturn(text=f'{page + 1} 页 (时间倒序):\n{content}') yield command_context.CommandReturn(text=f'{page + 1} 页 (时间倒序):\n{content}')

View File

@@ -2,26 +2,31 @@ from __future__ import annotations
import typing import typing
from .. import operator, entities, errors from .. import operator
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
@operator.operator_class(name='next', help='切换到后一个对话', usage='!next') @operator.operator_class(name='next', help='切换到后一个对话', usage='!next')
class NextOperator(operator.CommandOperator): class NextOperator(operator.CommandOperator):
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
if context.session.conversations: if context.session.conversations:
# 找到当前会话的下一个会话 # 找到当前会话的下一个会话
for index in range(len(context.session.conversations)): for index in range(len(context.session.conversations)):
if context.session.conversations[index] == context.session.using_conversation: if context.session.conversations[index] == context.session.using_conversation:
if index == len(context.session.conversations) - 1: if index == len(context.session.conversations) - 1:
yield entities.CommandReturn(error=errors.CommandOperationError('已经是最后一个对话了')) yield command_context.CommandReturn(
error=command_errors.CommandOperationError('已经是最后一个对话了')
)
return return
else: else:
context.session.using_conversation = context.session.conversations[index + 1] context.session.using_conversation = context.session.conversations[index + 1]
time_str = context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S') time_str = context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S')
yield entities.CommandReturn( yield command_context.CommandReturn(
text=f'已切换到后一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].content}' text=f'已切换到后一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].content}'
) )
return return
else: else:
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话')) yield command_context.CommandReturn(error=command_errors.CommandOperationError('当前没有对话'))

View File

@@ -2,7 +2,8 @@ from __future__ import annotations
import typing import typing
import traceback import traceback
from .. import operator, entities, errors from .. import operator
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
@operator.operator_class( @operator.operator_class(
@@ -11,7 +12,9 @@ from .. import operator, entities, errors
usage='!plugin\n!plugin get <插件仓库地址>\n!plugin update\n!plugin del <插件名>\n!plugin on <插件名>\n!plugin off <插件名>', usage='!plugin\n!plugin get <插件仓库地址>\n!plugin update\n!plugin del <插件名>\n!plugin on <插件名>\n!plugin off <插件名>',
) )
class PluginOperator(operator.CommandOperator): class PluginOperator(operator.CommandOperator):
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
plugin_list = self.ap.plugin_mgr.plugins() plugin_list = self.ap.plugin_mgr.plugins()
reply_str = '所有插件({}):\n'.format(len(plugin_list)) reply_str = '所有插件({}):\n'.format(len(plugin_list))
idx = 0 idx = 0
@@ -27,32 +30,36 @@ class PluginOperator(operator.CommandOperator):
idx += 1 idx += 1
yield entities.CommandReturn(text=reply_str) yield command_context.CommandReturn(text=reply_str)
@operator.operator_class(name='get', help='安装插件', privilege=2, parent_class=PluginOperator) @operator.operator_class(name='get', help='安装插件', privilege=2, parent_class=PluginOperator)
class PluginGetOperator(operator.CommandOperator): class PluginGetOperator(operator.CommandOperator):
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
if len(context.crt_params) == 0: if len(context.crt_params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件仓库地址')) yield command_context.CommandReturn(error=command_errors.ParamNotEnoughError('请提供插件仓库地址'))
else: else:
repo = context.crt_params[0] repo = context.crt_params[0]
yield entities.CommandReturn(text='正在安装插件...') yield command_context.CommandReturn(text='正在安装插件...')
try: try:
await self.ap.plugin_mgr.install_plugin(repo) await self.ap.plugin_mgr.install_plugin(repo)
yield entities.CommandReturn(text='插件安装成功,请重启程序以加载插件') yield command_context.CommandReturn(text='插件安装成功,请重启程序以加载插件')
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError('插件安装失败: ' + str(e))) yield command_context.CommandReturn(error=command_errors.CommandError('插件安装失败: ' + str(e)))
@operator.operator_class(name='update', help='更新插件', privilege=2, parent_class=PluginOperator) @operator.operator_class(name='update', help='更新插件', privilege=2, parent_class=PluginOperator)
class PluginUpdateOperator(operator.CommandOperator): class PluginUpdateOperator(operator.CommandOperator):
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
if len(context.crt_params) == 0: if len(context.crt_params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称')) yield command_context.CommandReturn(error=command_errors.ParamNotEnoughError('请提供插件名称'))
else: else:
plugin_name = context.crt_params[0] plugin_name = context.crt_params[0]
@@ -60,24 +67,26 @@ class PluginUpdateOperator(operator.CommandOperator):
plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name) plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name)
if plugin_container is not None: if plugin_container is not None:
yield entities.CommandReturn(text='正在更新插件...') yield command_context.CommandReturn(text='正在更新插件...')
await self.ap.plugin_mgr.update_plugin(plugin_name) await self.ap.plugin_mgr.update_plugin(plugin_name)
yield entities.CommandReturn(text='插件更新成功,请重启程序以加载插件') yield command_context.CommandReturn(text='插件更新成功,请重启程序以加载插件')
else: else:
yield entities.CommandReturn(error=errors.CommandError('插件更新失败: 未找到插件')) yield command_context.CommandReturn(error=command_errors.CommandError('插件更新失败: 未找到插件'))
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError('插件更新失败: ' + str(e))) yield command_context.CommandReturn(error=command_errors.CommandError('插件更新失败: ' + str(e)))
@operator.operator_class(name='all', help='更新所有插件', privilege=2, parent_class=PluginUpdateOperator) @operator.operator_class(name='all', help='更新所有插件', privilege=2, parent_class=PluginUpdateOperator)
class PluginUpdateAllOperator(operator.CommandOperator): class PluginUpdateAllOperator(operator.CommandOperator):
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
try: try:
plugins = [p.plugin_name for p in self.ap.plugin_mgr.plugins()] plugins = [p.plugin_name for p in self.ap.plugin_mgr.plugins()]
if plugins: if plugins:
yield entities.CommandReturn(text='正在更新插件...') yield command_context.CommandReturn(text='正在更新插件...')
updated = [] updated = []
try: try:
for plugin_name in plugins: for plugin_name in plugins:
@@ -85,20 +94,22 @@ class PluginUpdateAllOperator(operator.CommandOperator):
updated.append(plugin_name) updated.append(plugin_name)
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError('插件更新失败: ' + str(e))) yield command_context.CommandReturn(error=command_errors.CommandError('插件更新失败: ' + str(e)))
yield entities.CommandReturn(text='已更新插件: {}'.format(', '.join(updated))) yield command_context.CommandReturn(text='已更新插件: {}'.format(', '.join(updated)))
else: else:
yield entities.CommandReturn(text='没有可更新的插件') yield command_context.CommandReturn(text='没有可更新的插件')
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError('插件更新失败: ' + str(e))) yield command_context.CommandReturn(error=command_errors.CommandError('插件更新失败: ' + str(e)))
@operator.operator_class(name='del', help='删除插件', privilege=2, parent_class=PluginOperator) @operator.operator_class(name='del', help='删除插件', privilege=2, parent_class=PluginOperator)
class PluginDelOperator(operator.CommandOperator): class PluginDelOperator(operator.CommandOperator):
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
if len(context.crt_params) == 0: if len(context.crt_params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称')) yield command_context.CommandReturn(error=command_errors.ParamNotEnoughError('请提供插件名称'))
else: else:
plugin_name = context.crt_params[0] plugin_name = context.crt_params[0]
@@ -106,51 +117,55 @@ class PluginDelOperator(operator.CommandOperator):
plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name) plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name)
if plugin_container is not None: if plugin_container is not None:
yield entities.CommandReturn(text='正在删除插件...') yield command_context.CommandReturn(text='正在删除插件...')
await self.ap.plugin_mgr.uninstall_plugin(plugin_name) await self.ap.plugin_mgr.uninstall_plugin(plugin_name)
yield entities.CommandReturn(text='插件删除成功,请重启程序以加载插件') yield command_context.CommandReturn(text='插件删除成功,请重启程序以加载插件')
else: else:
yield entities.CommandReturn(error=errors.CommandError('插件删除失败: 未找到插件')) yield command_context.CommandReturn(error=command_errors.CommandError('插件删除失败: 未找到插件'))
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError('插件删除失败: ' + str(e))) yield command_context.CommandReturn(error=command_errors.CommandError('插件删除失败: ' + str(e)))
@operator.operator_class(name='on', help='启用插件', privilege=2, parent_class=PluginOperator) @operator.operator_class(name='on', help='启用插件', privilege=2, parent_class=PluginOperator)
class PluginEnableOperator(operator.CommandOperator): class PluginEnableOperator(operator.CommandOperator):
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
if len(context.crt_params) == 0: if len(context.crt_params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称')) yield command_context.CommandReturn(error=command_errors.ParamNotEnoughError('请提供插件名称'))
else: else:
plugin_name = context.crt_params[0] plugin_name = context.crt_params[0]
try: try:
if await self.ap.plugin_mgr.update_plugin_switch(plugin_name, True): if await self.ap.plugin_mgr.update_plugin_switch(plugin_name, True):
yield entities.CommandReturn(text='已启用插件: {}'.format(plugin_name)) yield command_context.CommandReturn(text='已启用插件: {}'.format(plugin_name))
else: else:
yield entities.CommandReturn( yield command_context.CommandReturn(
error=errors.CommandError('插件状态修改失败: 未找到插件 {}'.format(plugin_name)) error=command_errors.CommandError('插件状态修改失败: 未找到插件 {}'.format(plugin_name))
) )
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError('插件状态修改失败: ' + str(e))) yield command_context.CommandReturn(error=command_errors.CommandError('插件状态修改失败: ' + str(e)))
@operator.operator_class(name='off', help='禁用插件', privilege=2, parent_class=PluginOperator) @operator.operator_class(name='off', help='禁用插件', privilege=2, parent_class=PluginOperator)
class PluginDisableOperator(operator.CommandOperator): class PluginDisableOperator(operator.CommandOperator):
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
if len(context.crt_params) == 0: if len(context.crt_params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称')) yield command_context.CommandReturn(error=command_errors.ParamNotEnoughError('请提供插件名称'))
else: else:
plugin_name = context.crt_params[0] plugin_name = context.crt_params[0]
try: try:
if await self.ap.plugin_mgr.update_plugin_switch(plugin_name, False): if await self.ap.plugin_mgr.update_plugin_switch(plugin_name, False):
yield entities.CommandReturn(text='已禁用插件: {}'.format(plugin_name)) yield command_context.CommandReturn(text='已禁用插件: {}'.format(plugin_name))
else: else:
yield entities.CommandReturn( yield command_context.CommandReturn(
error=errors.CommandError('插件状态修改失败: 未找到插件 {}'.format(plugin_name)) error=command_errors.CommandError('插件状态修改失败: 未找到插件 {}'.format(plugin_name))
) )
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError('插件状态修改失败: ' + str(e))) yield command_context.CommandReturn(error=command_errors.CommandError('插件状态修改失败: ' + str(e)))

View File

@@ -2,19 +2,22 @@ from __future__ import annotations
import typing import typing
from .. import operator, entities, errors from .. import operator
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
@operator.operator_class(name='prompt', help='查看当前对话的前文', usage='!prompt') @operator.operator_class(name='prompt', help='查看当前对话的前文', usage='!prompt')
class PromptOperator(operator.CommandOperator): class PromptOperator(operator.CommandOperator):
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
"""执行""" """执行"""
if context.session.using_conversation is None: if context.session.using_conversation is None:
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话')) yield command_context.CommandReturn(error=command_errors.CommandOperationError('当前没有对话'))
else: else:
reply_str = '当前对话所有内容:\n\n' reply_str = '当前对话所有内容:\n\n'
for msg in context.session.using_conversation.messages: for msg in context.session.using_conversation.messages:
reply_str += f'{msg.role}: {msg.content}\n' reply_str += f'{msg.role}: {msg.content}\n'
yield entities.CommandReturn(text=reply_str) yield command_context.CommandReturn(text=reply_str)

View File

@@ -2,15 +2,18 @@ from __future__ import annotations
import typing import typing
from .. import operator, entities, errors from .. import operator
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
@operator.operator_class(name='resend', help='重发当前会话的最后一条消息', usage='!resend') @operator.operator_class(name='resend', help='重发当前会话的最后一条消息', usage='!resend')
class ResendOperator(operator.CommandOperator): class ResendOperator(operator.CommandOperator):
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
# 回滚到最后一条用户message前 # 回滚到最后一条用户message前
if context.session.using_conversation is None: if context.session.using_conversation is None:
yield entities.CommandReturn(error=errors.CommandError('当前没有对话')) yield command_context.CommandReturn(error=command_errors.CommandError('当前没有对话'))
else: else:
conv_msg = context.session.using_conversation.messages conv_msg = context.session.using_conversation.messages
@@ -23,4 +26,4 @@ class ResendOperator(operator.CommandOperator):
conv_msg.pop() conv_msg.pop()
# 不重发了,提示用户已删除就行了 # 不重发了,提示用户已删除就行了
yield entities.CommandReturn(text='已删除最后一次请求记录') yield command_context.CommandReturn(text='已删除最后一次请求记录')

View File

@@ -2,13 +2,16 @@ from __future__ import annotations
import typing import typing
from .. import operator, entities from .. import operator
from langbot_plugin.api.entities.builtin.command import context as command_context
@operator.operator_class(name='reset', help='重置当前会话', usage='!reset') @operator.operator_class(name='reset', help='重置当前会话', usage='!reset')
class ResetOperator(operator.CommandOperator): class ResetOperator(operator.CommandOperator):
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
"""执行""" """执行"""
context.session.using_conversation = None context.session.using_conversation = None
yield entities.CommandReturn(text='已重置当前会话') yield command_context.CommandReturn(text='已重置当前会话')

View File

@@ -1,11 +0,0 @@
from __future__ import annotations
import typing
from .. import operator, entities
@operator.operator_class(name='update', help='更新程序', usage='!update', privilege=2)
class UpdateCommand(operator.CommandOperator):
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
yield entities.CommandReturn(text='不再支持通过命令更新,请查看 LangBot 文档。')

View File

@@ -2,12 +2,15 @@ from __future__ import annotations
import typing import typing
from .. import operator, entities from .. import operator
from langbot_plugin.api.entities.builtin.command import context as command_context
@operator.operator_class(name='version', help='显示版本信息', usage='!version') @operator.operator_class(name='version', help='显示版本信息', usage='!version')
class VersionCommand(operator.CommandOperator): class VersionCommand(operator.CommandOperator):
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
reply_str = f'当前版本: \n{self.ap.ver_mgr.get_current_version()}' reply_str = f'当前版本: \n{self.ap.ver_mgr.get_current_version()}'
try: try:
@@ -16,4 +19,4 @@ class VersionCommand(operator.CommandOperator):
except Exception: except Exception:
pass pass
yield entities.CommandReturn(text=reply_str.strip()) yield command_context.CommandReturn(text=reply_str.strip())

View File

@@ -3,7 +3,6 @@ from __future__ import annotations
import logging import logging
import asyncio import asyncio
import traceback import traceback
import sys
import os import os
from ..platform import botmgr as im_mgr from ..platform import botmgr as im_mgr
@@ -12,7 +11,7 @@ from ..provider.modelmgr import modelmgr as llm_model_mgr
from ..provider.tools import toolmgr as llm_tool_mgr from ..provider.tools import toolmgr as llm_tool_mgr
from ..config import manager as config_mgr from ..config import manager as config_mgr
from ..command import cmdmgr from ..command import cmdmgr
from ..plugin import manager as plugin_mgr from ..plugin import connector as plugin_connector
from ..pipeline import pool from ..pipeline import pool
from ..pipeline import controller, pipelinemgr from ..pipeline import controller, pipelinemgr
from ..utils import version as version_mgr, proxy as proxy_mgr, announce as announce_mgr from ..utils import version as version_mgr, proxy as proxy_mgr, announce as announce_mgr
@@ -80,7 +79,7 @@ class Application:
# ========================= # =========================
plugin_mgr: plugin_mgr.PluginManager = None plugin_connector: plugin_connector.PluginRuntimeConnector = None
query_pool: pool.QueryPool = None query_pool: pool.QueryPool = None
@@ -128,7 +127,7 @@ class Application:
async def run(self): async def run(self):
try: try:
await self.plugin_mgr.initialize_plugins() await self.plugin_connector.initialize_plugins()
# 后续可能会允许动态重启其他任务 # 后续可能会允许动态重启其他任务
# 故为了防止程序在非 Ctrl-C 情况下退出,这里创建一个不会结束的协程 # 故为了防止程序在非 Ctrl-C 情况下退出,这里创建一个不会结束的协程
@@ -169,6 +168,9 @@ class Application:
self.logger.error(f'Application runtime fatal exception: {e}') self.logger.error(f'Application runtime fatal exception: {e}')
self.logger.debug(f'Traceback: {traceback.format_exc()}') self.logger.debug(f'Traceback: {traceback.format_exc()}')
def dispose(self):
self.plugin_connector.dispose()
async def print_web_access_info(self): async def print_web_access_info(self):
"""Print access webui tips""" """Print access webui tips"""
@@ -195,59 +197,3 @@ class Application:
""".strip() """.strip()
for line in tips.split('\n'): for line in tips.split('\n'):
self.logger.info(line) self.logger.info(line)
async def reload(
self,
scope: core_entities.LifecycleControlScope,
):
match scope:
case core_entities.LifecycleControlScope.PLATFORM.value:
self.logger.info('Hot reload scope=' + scope)
await self.platform_mgr.shutdown()
self.platform_mgr = im_mgr.PlatformManager(self)
await self.platform_mgr.initialize()
self.task_mgr.create_task(
self.platform_mgr.run(),
name='platform-manager',
scopes=[
core_entities.LifecycleControlScope.APPLICATION,
core_entities.LifecycleControlScope.PLATFORM,
],
)
case core_entities.LifecycleControlScope.PLUGIN.value:
self.logger.info('Hot reload scope=' + scope)
await self.plugin_mgr.destroy_plugins()
# 删除 sys.module 中所有的 plugins/* 下的模块
for mod in list(sys.modules.keys()):
if mod.startswith('plugins.'):
del sys.modules[mod]
self.plugin_mgr = plugin_mgr.PluginManager(self)
await self.plugin_mgr.initialize()
await self.plugin_mgr.initialize_plugins()
await self.plugin_mgr.load_plugins()
await self.plugin_mgr.initialize_plugins()
case core_entities.LifecycleControlScope.PROVIDER.value:
self.logger.info('Hot reload scope=' + scope)
await self.tool_mgr.shutdown()
llm_model_mgr_inst = llm_model_mgr.ModelManager(self)
await llm_model_mgr_inst.initialize()
self.model_mgr = llm_model_mgr_inst
llm_session_mgr_inst = llm_session_mgr.SessionManager(self)
await llm_session_mgr_inst.initialize()
self.sess_mgr = llm_session_mgr_inst
llm_tool_mgr_inst = llm_tool_mgr.ToolManager(self)
await llm_tool_mgr_inst.initialize()
self.tool_mgr = llm_tool_mgr_inst
case _:
pass

View File

@@ -51,8 +51,8 @@ async def main(loop: asyncio.AbstractEventLoop):
import signal import signal
def signal_handler(sig, frame): def signal_handler(sig, frame):
app_inst.dispose()
print('[Signal] Program exit.') print('[Signal] Program exit.')
# ap.shutdown()
os._exit(0) os._exit(0)
signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGINT, signal_handler)

View File

@@ -1,18 +1,6 @@
from __future__ import annotations from __future__ import annotations
import enum import enum
import typing
import datetime
import asyncio
import pydantic.v1 as pydantic
from ..provider import entities as llm_entities
from ..provider.modelmgr import requester
from ..provider.tools import entities as tools_entities
from ..platform import adapter as msadapter
from ..platform.types import message as platform_message
from ..platform.types import events as platform_events
class LifecycleControlScope(enum.Enum): class LifecycleControlScope(enum.Enum):
@@ -20,159 +8,3 @@ class LifecycleControlScope(enum.Enum):
PLATFORM = 'platform' PLATFORM = 'platform'
PLUGIN = 'plugin' PLUGIN = 'plugin'
PROVIDER = 'provider' PROVIDER = 'provider'
class LauncherTypes(enum.Enum):
"""一个请求的发起者类型"""
PERSON = 'person'
"""私聊"""
GROUP = 'group'
"""群聊"""
class Query(pydantic.BaseModel):
"""一次请求的信息封装"""
query_id: int
"""请求ID添加进请求池时生成"""
launcher_type: LauncherTypes
"""会话类型platform处理阶段设置"""
launcher_id: typing.Union[int, str]
"""会话IDplatform处理阶段设置"""
sender_id: typing.Union[int, str]
"""发送者IDplatform处理阶段设置"""
message_event: platform_events.MessageEvent
"""事件platform收到的原始事件"""
message_chain: platform_message.MessageChain
"""消息链platform收到的原始消息链"""
bot_uuid: typing.Optional[str] = None
"""机器人UUID。"""
pipeline_uuid: typing.Optional[str] = None
"""流水线UUID。"""
pipeline_config: typing.Optional[dict[str, typing.Any]] = None
"""流水线配置,由 Pipeline 在运行开始时设置。"""
adapter: msadapter.MessagePlatformAdapter
"""消息平台适配器对象单个app中可能启用了多个消息平台适配器此对象表明发起此query的适配器"""
session: typing.Optional[Session] = None
"""会话对象,由前置处理器阶段设置"""
messages: typing.Optional[list[llm_entities.Message]] = []
"""历史消息列表,由前置处理器阶段设置"""
prompt: typing.Optional[llm_entities.Prompt] = None
"""情景预设内容,由前置处理器阶段设置"""
user_message: typing.Optional[llm_entities.Message] = None
"""此次请求的用户消息对象,由前置处理器阶段设置"""
variables: typing.Optional[dict[str, typing.Any]] = None
"""变量由前置处理器阶段设置。在prompt中嵌入或由 Runner 传递到 LLMOps 平台。"""
use_llm_model: typing.Optional[requester.RuntimeLLMModel] = None
"""使用的对话模型,由前置处理器阶段设置"""
use_funcs: typing.Optional[list[tools_entities.LLMFunction]] = None
"""使用的函数,由前置处理器阶段设置"""
resp_messages: (
typing.Optional[list[llm_entities.Message]]
| typing.Optional[list[platform_message.MessageChain]]
| typing.Optional[list[llm_entities.MessageChunk]]
) = []
"""由Process阶段生成的回复消息对象列表"""
resp_message_chain: typing.Optional[list[platform_message.MessageChain]] = None
"""回复消息链从resp_messages包装而得"""
# ======= 内部保留 =======
current_stage: typing.Optional['pkg.pipeline.pipelinemgr.StageInstContainer'] = None
"""当前所处阶段"""
class Config:
arbitrary_types_allowed = True
# ========== 插件可调用的 API请求 API ==========
def set_variable(self, key: str, value: typing.Any):
"""设置变量"""
if self.variables is None:
self.variables = {}
self.variables[key] = value
def get_variable(self, key: str) -> typing.Any:
"""获取变量"""
if self.variables is None:
return None
return self.variables.get(key)
def get_variables(self) -> dict[str, typing.Any]:
"""获取所有变量"""
if self.variables is None:
return {}
return self.variables
class Conversation(pydantic.BaseModel):
"""对话,包含于 Session 中,一个 Session 可以有多个历史 Conversation但只有一个当前使用的 Conversation"""
prompt: llm_entities.Prompt
messages: list[llm_entities.Message]
create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
use_llm_model: typing.Optional[requester.RuntimeLLMModel] = None
use_funcs: typing.Optional[list[tools_entities.LLMFunction]]
pipeline_uuid: str
"""流水线UUID。"""
bot_uuid: str
"""机器人UUID。"""
uuid: typing.Optional[str] = None
"""该对话的 uuid在创建时不会自动生成。而是当使用 Dify API 等由外部管理对话信息的服务时,用于绑定外部的会话。具体如何使用,取决于 Runner。"""
class Config:
arbitrary_types_allowed = True
class Session(pydantic.BaseModel):
"""会话,一个 Session 对应一个 {launcher_type.value}_{launcher_id}"""
launcher_type: LauncherTypes
launcher_id: typing.Union[int, str]
sender_id: typing.Optional[typing.Union[int, str]] = 0
use_prompt_name: typing.Optional[str] = 'default'
using_conversation: typing.Optional[Conversation] = None
conversations: typing.Optional[list[Conversation]] = pydantic.Field(default_factory=list)
create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
semaphore: typing.Optional[asyncio.Semaphore] = None
"""当前会话的信号量,用于限制并发"""
class Config:
arbitrary_types_allowed = True

View File

@@ -1,10 +1,11 @@
from __future__ import annotations from __future__ import annotations
import asyncio
from .. import stage, app from .. import stage, app
from ...utils import version, proxy, announce from ...utils import version, proxy, announce
from ...pipeline import pool, controller, pipelinemgr from ...pipeline import pool, controller, pipelinemgr
from ...plugin import manager as plugin_mgr from ...plugin import connector as plugin_connector
from ...command import cmdmgr from ...command import cmdmgr
from ...provider.session import sessionmgr as llm_session_mgr from ...provider.session import sessionmgr as llm_session_mgr
from ...provider.modelmgr import modelmgr as llm_model_mgr from ...provider.modelmgr import modelmgr as llm_model_mgr
@@ -62,10 +63,13 @@ class BuildAppStage(stage.BootingStage):
ap.persistence_mgr = persistence_mgr_inst ap.persistence_mgr = persistence_mgr_inst
await persistence_mgr_inst.initialize() await persistence_mgr_inst.initialize()
plugin_mgr_inst = plugin_mgr.PluginManager(ap) async def runtime_disconnect_callback(connector: plugin_connector.PluginRuntimeConnector) -> None:
await plugin_mgr_inst.initialize() await asyncio.sleep(3)
ap.plugin_mgr = plugin_mgr_inst await plugin_connector_inst.initialize()
await plugin_mgr_inst.load_plugins()
plugin_connector_inst = plugin_connector.PluginRuntimeConnector(ap, runtime_disconnect_callback)
await plugin_connector_inst.initialize()
ap.plugin_connector = plugin_connector_inst
cmd_mgr_inst = cmdmgr.CommandManager(ap) cmd_mgr_inst = cmdmgr.CommandManager(ap)
await cmd_mgr_inst.initialize() await cmd_mgr_inst.initialize()

View File

@@ -0,0 +1,22 @@
import sqlalchemy
from .base import Base
class BinaryStorage(Base):
"""Current for plugin use only"""
__tablename__ = 'binary_storages'
unique_key = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True)
key = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
owner_type = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
owner = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
value = sqlalchemy.Column(sqlalchemy.LargeBinary, nullable=False)
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
updated_at = sqlalchemy.Column(
sqlalchemy.DateTime,
nullable=False,
server_default=sqlalchemy.func.now(),
onupdate=sqlalchemy.func.now(),
)

View File

@@ -13,6 +13,8 @@ class PluginSetting(Base):
enabled = sqlalchemy.Column(sqlalchemy.Boolean, nullable=False, default=True) enabled = sqlalchemy.Column(sqlalchemy.Boolean, nullable=False, default=True)
priority = sqlalchemy.Column(sqlalchemy.Integer, nullable=False, default=0) priority = sqlalchemy.Column(sqlalchemy.Integer, nullable=False, default=0)
config = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default=dict) config = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default=dict)
install_source = sqlalchemy.Column(sqlalchemy.String(255), nullable=False, default='github')
install_info = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default=dict)
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now()) created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
updated_at = sqlalchemy.Column( updated_at = sqlalchemy.Column(
sqlalchemy.DateTime, sqlalchemy.DateTime,

View File

@@ -44,6 +44,38 @@ class PersistenceManager:
await self.create_tables() await self.create_tables()
# run migrations
database_version = await self.execute_async(
sqlalchemy.select(metadata.Metadata).where(metadata.Metadata.key == 'database_version')
)
database_version = int(database_version.fetchone()[1])
required_database_version = constants.required_database_version
if database_version < required_database_version:
migrations = migration.preregistered_db_migrations
migrations.sort(key=lambda x: x.number)
last_migration_number = database_version
for migration_cls in migrations:
migration_instance = migration_cls(self.ap)
if (
migration_instance.number > database_version
and migration_instance.number <= required_database_version
):
await migration_instance.upgrade()
await self.execute_async(
sqlalchemy.update(metadata.Metadata)
.where(metadata.Metadata.key == 'database_version')
.values({'value': str(migration_instance.number)})
)
last_migration_number = migration_instance.number
self.ap.logger.info(f'Migration {migration_instance.number} completed.')
self.ap.logger.info(f'Successfully upgraded database to version {last_migration_number}.')
async def create_tables(self): async def create_tables(self):
# create tables # create tables
async with self.get_db_engine().connect() as conn: async with self.get_db_engine().connect() as conn:
@@ -87,38 +119,6 @@ class PersistenceManager:
# ================================= # =================================
# run migrations
database_version = await self.execute_async(
sqlalchemy.select(metadata.Metadata).where(metadata.Metadata.key == 'database_version')
)
database_version = int(database_version.fetchone()[1])
required_database_version = constants.required_database_version
if database_version < required_database_version:
migrations = migration.preregistered_db_migrations
migrations.sort(key=lambda x: x.number)
last_migration_number = database_version
for migration_cls in migrations:
migration_instance = migration_cls(self.ap)
if (
migration_instance.number > database_version
and migration_instance.number <= required_database_version
):
await migration_instance.upgrade()
await self.execute_async(
sqlalchemy.update(metadata.Metadata)
.where(metadata.Metadata.key == 'database_version')
.values({'value': str(migration_instance.number)})
)
last_migration_number = migration_instance.number
self.ap.logger.info(f'Migration {migration_instance.number} completed.')
self.ap.logger.info(f'Successfully upgraded database to version {last_migration_number}.')
async def execute_async(self, *args, **kwargs) -> sqlalchemy.engine.cursor.CursorResult: async def execute_async(self, *args, **kwargs) -> sqlalchemy.engine.cursor.CursorResult:
async with self.get_db_engine().connect() as conn: async with self.get_db_engine().connect() as conn:
result = await conn.execute(*args, **kwargs) result = await conn.execute(*args, **kwargs)
@@ -128,10 +128,13 @@ class PersistenceManager:
def get_db_engine(self) -> sqlalchemy_asyncio.AsyncEngine: def get_db_engine(self) -> sqlalchemy_asyncio.AsyncEngine:
return self.db.get_engine() return self.db.get_engine()
def serialize_model(self, model: typing.Type[sqlalchemy.Base], data: sqlalchemy.Base) -> dict: def serialize_model(
self, model: typing.Type[sqlalchemy.Base], data: sqlalchemy.Base, masked_columns: list[str] = []
) -> dict:
return { return {
column.name: getattr(data, column.name) column.name: getattr(data, column.name)
if not isinstance(getattr(data, column.name), (datetime.datetime)) if not isinstance(getattr(data, column.name), (datetime.datetime))
else getattr(data, column.name).isoformat() else getattr(data, column.name).isoformat()
for column in model.__table__.columns for column in model.__table__.columns
if column.name not in masked_columns
} }

View File

@@ -0,0 +1,20 @@
from .. import migration
@migration.migration_class(4)
class DBMigratePluginConfig(migration.DBMigration):
"""插件配置"""
async def upgrade(self):
"""升级"""
if 'plugin' not in self.ap.instance_config.data:
self.ap.instance_config.data['plugin'] = {
'runtime_ws_url': 'ws://localhost:5400/control/ws',
}
await self.ap.instance_config.dump_config()
async def downgrade(self):
"""降级"""
pass

View File

@@ -0,0 +1,32 @@
import sqlalchemy
from .. import migration
@migration.migration_class(6)
class DBMigratePluginInstallSource(migration.DBMigration):
"""插件安装来源"""
async def upgrade(self):
"""升级"""
# 查询表结构获取所有列名(异步执行 SQL
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.text('PRAGMA table_info(plugin_settings);'))
# fetchall() 是同步方法,无需 await
columns = [row[1] for row in result.fetchall()]
# 检查并添加 install_source 列
if 'install_source' not in columns:
await self.ap.persistence_mgr.execute_async(
sqlalchemy.text(
"ALTER TABLE plugin_settings ADD COLUMN install_source VARCHAR(255) NOT NULL DEFAULT 'github'"
)
)
# 检查并添加 install_info 列
if 'install_info' not in columns:
await self.ap.persistence_mgr.execute_async(
sqlalchemy.text("ALTER TABLE plugin_settings ADD COLUMN install_info JSON NOT NULL DEFAULT '{}'")
)
async def downgrade(self):
"""降级"""
pass

View File

@@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from .. import stage, entities from .. import stage, entities
from ...core import entities as core_entities import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@stage.stage_class('BanSessionCheckStage') @stage.stage_class('BanSessionCheckStage')
@@ -14,7 +14,7 @@ class BanSessionCheckStage(stage.PipelineStage):
async def initialize(self, pipeline_config: dict): async def initialize(self, pipeline_config: dict):
pass pass
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
found = False found = False
mode = query.pipeline_config['trigger']['access-control']['mode'] mode = query.pipeline_config['trigger']['access-control']['mode']

View File

@@ -3,12 +3,11 @@ from __future__ import annotations
from ...core import app from ...core import app
from .. import stage, entities from .. import stage, entities
from ...core import entities as core_entities
from . import filter as filter_model, entities as filter_entities from . import filter as filter_model, entities as filter_entities
from ...provider import entities as llm_entities from langbot_plugin.api.entities.builtin.provider import message as provider_message
from ...platform.types import message as platform_message import langbot_plugin.api.entities.builtin.platform.message as platform_message
from ...utils import importutil from ...utils import importutil
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from . import filters from . import filters
importutil.import_modules_in_pkg(filters) importutil.import_modules_in_pkg(filters)
@@ -58,7 +57,7 @@ class ContentFilterStage(stage.PipelineStage):
async def _pre_process( async def _pre_process(
self, self,
message: str, message: str,
query: core_entities.Query, query: pipeline_query.Query,
) -> entities.StageProcessResult: ) -> entities.StageProcessResult:
"""请求llm前处理消息 """请求llm前处理消息
只要有一个不通过就不放行,只放行 PASS 的消息 只要有一个不通过就不放行,只放行 PASS 的消息
@@ -86,14 +85,14 @@ class ContentFilterStage(stage.PipelineStage):
elif result.level == filter_entities.ResultLevel.PASS: # 传到下一个 elif result.level == filter_entities.ResultLevel.PASS: # 传到下一个
message = result.replacement message = result.replacement
query.message_chain = platform_message.MessageChain(platform_message.Plain(message)) query.message_chain = platform_message.MessageChain([platform_message.Plain(text=message)])
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
async def _post_process( async def _post_process(
self, self,
message: str, message: str,
query: core_entities.Query, query: pipeline_query.Query,
) -> entities.StageProcessResult: ) -> entities.StageProcessResult:
"""请求llm后处理响应 """请求llm后处理响应
只要是 PASS 或者 MASKED 的就通过此 filter将其 replacement 设置为message进入下一个 filter 只要是 PASS 或者 MASKED 的就通过此 filter将其 replacement 设置为message进入下一个 filter
@@ -123,7 +122,7 @@ class ContentFilterStage(stage.PipelineStage):
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
"""处理""" """处理"""
if stage_inst_name == 'PreContentFilterStage': if stage_inst_name == 'PreContentFilterStage':
contain_non_text = False contain_non_text = False
@@ -142,7 +141,7 @@ class ContentFilterStage(stage.PipelineStage):
return await self._pre_process(str(query.message_chain).strip(), query) return await self._pre_process(str(query.message_chain).strip(), query)
elif stage_inst_name == 'PostContentFilterStage': elif stage_inst_name == 'PostContentFilterStage':
# 仅处理 query.resp_messages[-1].content 是 str 的情况 # 仅处理 query.resp_messages[-1].content 是 str 的情况
if isinstance(query.resp_messages[-1], llm_entities.Message) and isinstance( if isinstance(query.resp_messages[-1], provider_message.Message) and isinstance(
query.resp_messages[-1].content, str query.resp_messages[-1].content, str
): ):
return await self._post_process(query.resp_messages[-1].content, query) return await self._post_process(query.resp_messages[-1].content, query)

View File

@@ -1,6 +1,6 @@
import enum import enum
import pydantic.v1 as pydantic import pydantic
class ResultLevel(enum.Enum): class ResultLevel(enum.Enum):

View File

@@ -3,9 +3,9 @@ from __future__ import annotations
import abc import abc
import typing import typing
from ...core import app, entities as core_entities from ...core import app
from . import entities from . import entities
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
preregistered_filters: list[typing.Type[ContentFilter]] = [] preregistered_filters: list[typing.Type[ContentFilter]] = []
@@ -60,8 +60,8 @@ class ContentFilter(metaclass=abc.ABCMeta):
pass pass
@abc.abstractmethod @abc.abstractmethod
async def process(self, query: core_entities.Query, message: str = None, image_url=None) -> entities.FilterResult: async def process(self, query: pipeline_query.Query, message: str = None, image_url=None) -> entities.FilterResult:
"""Process message """处理消息
It is divided into two stages, depending on the value of enable_stages. It is divided into two stages, depending on the value of enable_stages.
For content filters, you do not need to consider the stage of the message, you only need to check the message content. For content filters, you do not need to consider the stage of the message, you only need to check the message content.

View File

@@ -4,8 +4,7 @@ import aiohttp
from .. import entities from .. import entities
from .. import filter as filter_model from .. import filter as filter_model
from ....core import entities as core_entities import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
BAIDU_EXAMINE_URL = 'https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token={}' BAIDU_EXAMINE_URL = 'https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token={}'
BAIDU_EXAMINE_TOKEN_URL = 'https://aip.baidubce.com/oauth/2.0/token' BAIDU_EXAMINE_TOKEN_URL = 'https://aip.baidubce.com/oauth/2.0/token'
@@ -27,7 +26,7 @@ class BaiduCloudExamine(filter_model.ContentFilter):
) as resp: ) as resp:
return (await resp.json())['access_token'] return (await resp.json())['access_token']
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult: async def process(self, query: pipeline_query.Query, message: str) -> entities.FilterResult:
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.post( async with session.post(
BAIDU_EXAMINE_URL.format(await self._get_token()), BAIDU_EXAMINE_URL.format(await self._get_token()),

View File

@@ -3,7 +3,7 @@ import re
from .. import filter as filter_model from .. import filter as filter_model
from .. import entities from .. import entities
from ....core import entities as core_entities import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@filter_model.filter_class('ban-word-filter') @filter_model.filter_class('ban-word-filter')
@@ -13,7 +13,7 @@ class BanWordFilter(filter_model.ContentFilter):
async def initialize(self): async def initialize(self):
pass pass
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult: async def process(self, query: pipeline_query.Query, message: str) -> entities.FilterResult:
found = False found = False
for word in self.ap.sensitive_meta.data['words']: for word in self.ap.sensitive_meta.data['words']:

View File

@@ -3,7 +3,7 @@ import re
from .. import entities from .. import entities
from .. import filter as filter_model from .. import filter as filter_model
from ....core import entities as core_entities import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@filter_model.filter_class('content-ignore') @filter_model.filter_class('content-ignore')
@@ -16,7 +16,7 @@ class ContentIgnore(filter_model.ContentFilter):
entities.EnableStage.PRE, entities.EnableStage.PRE,
] ]
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult: async def process(self, query: pipeline_query.Query, message: str) -> entities.FilterResult:
if 'prefix' in query.pipeline_config['trigger']['ignore-rules']: if 'prefix' in query.pipeline_config['trigger']['ignore-rules']:
for rule in query.pipeline_config['trigger']['ignore-rules']['prefix']: for rule in query.pipeline_config['trigger']['ignore-rules']['prefix']:
if message.startswith(rule): if message.startswith(rule):

View File

@@ -3,7 +3,10 @@ from __future__ import annotations
import asyncio import asyncio
import traceback import traceback
from ..core import app, entities from ..core import app
from ..core import entities as core_entities
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
class Controller: class Controller:
@@ -22,19 +25,19 @@ class Controller:
"""事件处理循环""" """事件处理循环"""
try: try:
while True: while True:
selected_query: entities.Query = None selected_query: pipeline_query.Query = None
# 取请求 # 取请求
async with self.ap.query_pool: async with self.ap.query_pool:
queries: list[entities.Query] = self.ap.query_pool.queries queries: list[pipeline_query.Query] = self.ap.query_pool.queries
for query in queries: for query in queries:
session = await self.ap.sess_mgr.get_session(query) session = await self.ap.sess_mgr.get_session(query)
self.ap.logger.debug(f'Checking query {query} session {session}') self.ap.logger.debug(f'Checking query {query} session {session}')
if not session.semaphore.locked(): if not session._semaphore.locked():
selected_query = query selected_query = query
await session.semaphore.acquire() await session._semaphore.acquire()
break break
@@ -46,7 +49,7 @@ class Controller:
if selected_query: if selected_query:
async def _process_query(selected_query: entities.Query): async def _process_query(selected_query: pipeline_query.Query):
async with self.semaphore: # 总并发上限 async with self.semaphore: # 总并发上限
# find pipeline # find pipeline
# Here firstly find the bot, then find the pipeline, in case the bot adapter's config is not the latest one. # Here firstly find the bot, then find the pipeline, in case the bot adapter's config is not the latest one.
@@ -59,7 +62,7 @@ class Controller:
await pipeline.run(selected_query) await pipeline.run(selected_query)
async with self.ap.query_pool: async with self.ap.query_pool:
(await self.ap.sess_mgr.get_session(selected_query)).semaphore.release() (await self.ap.sess_mgr.get_session(selected_query))._semaphore.release()
# 通知其他协程,有新的请求可以处理了 # 通知其他协程,有新的请求可以处理了
self.ap.query_pool.condition.notify_all() self.ap.query_pool.condition.notify_all()
@@ -68,8 +71,8 @@ class Controller:
kind='query', kind='query',
name=f'query-{selected_query.query_id}', name=f'query-{selected_query.query_id}',
scopes=[ scopes=[
entities.LifecycleControlScope.APPLICATION, core_entities.LifecycleControlScope.APPLICATION,
entities.LifecycleControlScope.PLATFORM, core_entities.LifecycleControlScope.PLATFORM,
], ],
) )

View File

@@ -3,10 +3,10 @@ from __future__ import annotations
import enum import enum
import typing import typing
import pydantic.v1 as pydantic import pydantic
from ..platform.types import message as platform_message
from ..core import entities import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.builtin.platform.message as platform_message
class ResultType(enum.Enum): class ResultType(enum.Enum):
@@ -20,7 +20,7 @@ class ResultType(enum.Enum):
class StageProcessResult(pydantic.BaseModel): class StageProcessResult(pydantic.BaseModel):
result_type: ResultType result_type: ResultType
new_query: entities.Query new_query: pipeline_query.Query
user_notice: typing.Optional[ user_notice: typing.Optional[
typing.Union[ typing.Union[

View File

@@ -5,10 +5,9 @@ import traceback
from . import strategy from . import strategy
from .. import stage, entities from .. import stage, entities
from ...core import entities as core_entities import langbot_plugin.api.entities.builtin.platform.message as platform_message
from ...platform.types import message as platform_message
from ...utils import importutil from ...utils import importutil
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from . import strategies from . import strategies
importutil.import_modules_in_pkg(strategies) importutil.import_modules_in_pkg(strategies)
@@ -67,8 +66,8 @@ class LongTextProcessStage(stage.PipelineStage):
await self.strategy_impl.initialize() await self.strategy_impl.initialize()
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
# Check if it contains non-Plain components # 检查是否包含非 Plain 组件
contains_non_plain = False contains_non_plain = False
for msg in query.resp_message_chain[-1]: for msg in query.resp_message_chain[-1]:

View File

@@ -3,9 +3,9 @@ from __future__ import annotations
from .. import strategy as strategy_model from .. import strategy as strategy_model
from ....core import entities as core_entities
from ....platform.types import message as platform_message
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.builtin.platform.message as platform_message
ForwardMessageDiaplay = platform_message.ForwardMessageDiaplay ForwardMessageDiaplay = platform_message.ForwardMessageDiaplay
Forward = platform_message.Forward Forward = platform_message.Forward
@@ -13,7 +13,7 @@ Forward = platform_message.Forward
@strategy_model.strategy_class('forward') @strategy_model.strategy_class('forward')
class ForwardComponentStrategy(strategy_model.LongTextStrategy): class ForwardComponentStrategy(strategy_model.LongTextStrategy):
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]: async def process(self, message: str, query: pipeline_query.Query) -> list[platform_message.MessageComponent]:
display = ForwardMessageDiaplay( display = ForwardMessageDiaplay(
title='Group chat history', title='Group chat history',
brief='[Chat history]', brief='[Chat history]',

View File

@@ -8,10 +8,10 @@ import re
from PIL import Image, ImageDraw, ImageFont from PIL import Image, ImageDraw, ImageFont
import functools import functools
from ....platform.types import message as platform_message
from .. import strategy as strategy_model from .. import strategy as strategy_model
from ....core import entities as core_entities import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.builtin.platform.message as platform_message
@strategy_model.strategy_class('image') @strategy_model.strategy_class('image')
@@ -27,7 +27,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
encoding='utf-8', encoding='utf-8',
) )
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]: async def process(self, message: str, query: pipeline_query.Query) -> list[platform_message.MessageComponent]:
img_path = self.text_to_image( img_path = self.text_to_image(
text_str=message, text_str=message,
save_as='temp/{}.png'.format(int(time.time())), save_as='temp/{}.png'.format(int(time.time())),
@@ -131,7 +131,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
text_str: str, text_str: str,
save_as='temp.png', save_as='temp.png',
width=800, width=800,
query: core_entities.Query = None, query: pipeline_query.Query = None,
): ):
text_str = text_str.replace('\t', ' ') text_str = text_str.replace('\t', ' ')

View File

@@ -4,8 +4,9 @@ import typing
from ...core import app from ...core import app
from ...core import entities as core_entities
from ...platform.types import message as platform_message import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
preregistered_strategies: list[typing.Type[LongTextStrategy]] = [] preregistered_strategies: list[typing.Type[LongTextStrategy]] = []
@@ -49,8 +50,8 @@ class LongTextStrategy(metaclass=abc.ABCMeta):
pass pass
@abc.abstractmethod @abc.abstractmethod
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]: async def process(self, message: str, query: pipeline_query.Query) -> list[platform_message.MessageComponent]:
"""Process long text """处理长文本
If the text length exceeds the threshold, this method will be called. If the text length exceeds the threshold, this method will be called.

View File

@@ -1,10 +1,9 @@
from __future__ import annotations from __future__ import annotations
from .. import stage, entities from .. import stage, entities
from ...core import entities as core_entities
from . import truncator from . import truncator
from ...utils import importutil from ...utils import importutil
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from . import truncators from . import truncators
importutil.import_modules_in_pkg(truncators) importutil.import_modules_in_pkg(truncators)
@@ -29,8 +28,8 @@ class ConversationMessageTruncator(stage.PipelineStage):
else: else:
raise ValueError(f'Unknown truncator: {use_method}') raise ValueError(f'Unknown truncator: {use_method}')
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
"""Process""" """处理"""
query = await self.trun.truncate(query) query = await self.trun.truncate(query)
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)

View File

@@ -3,8 +3,8 @@ from __future__ import annotations
import typing import typing
import abc import abc
from ...core import entities as core_entities, app from ...core import app
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
preregistered_truncators: list[typing.Type[Truncator]] = [] preregistered_truncators: list[typing.Type[Truncator]] = []
@@ -47,7 +47,7 @@ class Truncator(abc.ABC):
pass pass
@abc.abstractmethod @abc.abstractmethod
async def truncate(self, query: core_entities.Query) -> core_entities.Query: async def truncate(self, query: pipeline_query.Query) -> pipeline_query.Query:
"""截断 """截断
一般只需要操作query.messages也可以扩展操作query.prompt, query.user_message。 一般只需要操作query.messages也可以扩展操作query.prompt, query.user_message。

View File

@@ -1,15 +1,15 @@
from __future__ import annotations from __future__ import annotations
from .. import truncator from .. import truncator
from ....core import entities as core_entities import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@truncator.truncator_class('round') @truncator.truncator_class('round')
class RoundTruncator(truncator.Truncator): class RoundTruncator(truncator.Truncator):
"""Truncate the conversation message chain to adapt to the LLM message length limit.""" """Truncate the conversation message chain to adapt to the LLM message length limit."""
async def truncate(self, query: core_entities.Query) -> core_entities.Query: async def truncate(self, query: pipeline_query.Query) -> pipeline_query.Query:
"""Truncate""" """截断"""
max_round = query.pipeline_config['ai']['local-agent']['max-round'] max_round = query.pipeline_config['ai']['local-agent']['max-round']
temp_messages = [] temp_messages = []

View File

@@ -5,14 +5,18 @@ import traceback
import sqlalchemy import sqlalchemy
from ..core import app, entities from ..core import app
from . import entities as pipeline_entities from . import entities as pipeline_entities
from ..entity.persistence import pipeline as persistence_pipeline from ..entity.persistence import pipeline as persistence_pipeline
from . import stage from . import stage
from ..platform.types import message as platform_message, events as platform_events import langbot_plugin.api.entities.builtin.platform.message as platform_message
from ..plugin import events import langbot_plugin.api.entities.builtin.platform.events as platform_events
import langbot_plugin.api.entities.events as events
from ..utils import importutil from ..utils import importutil
import langbot_plugin.api.entities.builtin.provider.session as provider_session
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from . import ( from . import (
resprule, resprule,
bansess, bansess,
@@ -75,17 +79,17 @@ class RuntimePipeline:
self.pipeline_entity = pipeline_entity self.pipeline_entity = pipeline_entity
self.stage_containers = stage_containers self.stage_containers = stage_containers
async def run(self, query: entities.Query): async def run(self, query: pipeline_query.Query):
query.pipeline_config = self.pipeline_entity.config query.pipeline_config = self.pipeline_entity.config
await self.process_query(query) await self.process_query(query)
async def _check_output(self, query: entities.Query, result: pipeline_entities.StageProcessResult): async def _check_output(self, query: pipeline_query.Query, result: pipeline_entities.StageProcessResult):
"""检查输出""" """检查输出"""
if result.user_notice: if result.user_notice:
# 处理str类型 # 处理str类型
if isinstance(result.user_notice, str): if isinstance(result.user_notice, str):
result.user_notice = platform_message.MessageChain(platform_message.Plain(result.user_notice)) result.user_notice = platform_message.MessageChain([platform_message.Plain(text=result.user_notice)])
elif isinstance(result.user_notice, list): elif isinstance(result.user_notice, list):
result.user_notice = platform_message.MessageChain(*result.user_notice) result.user_notice = platform_message.MessageChain(*result.user_notice)
@@ -99,7 +103,7 @@ class RuntimePipeline:
bot_message=query.resp_messages[-1], bot_message=query.resp_messages[-1],
message=result.user_notice, message=result.user_notice,
quote_origin=query.pipeline_config['output']['misc']['quote-origin'], quote_origin=query.pipeline_config['output']['misc']['quote-origin'],
is_final=[msg.is_final for msg in query.resp_messages][0] is_final=[msg.is_final for msg in query.resp_messages][0],
) )
else: else:
await query.adapter.reply_message( await query.adapter.reply_message(
@@ -117,7 +121,7 @@ class RuntimePipeline:
async def _execute_from_stage( async def _execute_from_stage(
self, self,
stage_index: int, stage_index: int,
query: entities.Query, query: pipeline_query.Query,
): ):
"""从指定阶段开始执行,实现了责任链模式和基于生成器的阶段分叉功能。 """从指定阶段开始执行,实现了责任链模式和基于生成器的阶段分叉功能。
@@ -144,7 +148,7 @@ class RuntimePipeline:
while i < len(self.stage_containers): while i < len(self.stage_containers):
stage_container = self.stage_containers[i] stage_container = self.stage_containers[i]
query.current_stage = stage_container # 标记到 Query 对象里 query.current_stage_name = stage_container.inst_name # 标记到 Query 对象里
result = stage_container.inst.process(query, stage_container.inst_name) result = stage_container.inst.process(query, stage_container.inst_name)
@@ -181,26 +185,26 @@ class RuntimePipeline:
i += 1 i += 1
async def process_query(self, query: entities.Query): async def process_query(self, query: pipeline_query.Query):
"""处理请求""" """处理请求"""
try: try:
# ======== 触发 MessageReceived 事件 ======== # ======== 触发 MessageReceived 事件 ========
event_type = ( event_type = (
events.PersonMessageReceived events.PersonMessageReceived
if query.launcher_type == entities.LauncherTypes.PERSON if query.launcher_type == provider_session.LauncherTypes.PERSON
else events.GroupMessageReceived else events.GroupMessageReceived
) )
event_ctx = await self.ap.plugin_mgr.emit_event( event_obj = event_type(
event=event_type( query=query,
launcher_type=query.launcher_type.value, launcher_type=query.launcher_type.value,
launcher_id=query.launcher_id, launcher_id=query.launcher_id,
sender_id=query.sender_id, sender_id=query.sender_id,
message_chain=query.message_chain, message_chain=query.message_chain,
query=query,
)
) )
event_ctx = await self.ap.plugin_connector.emit_event(event_obj)
if event_ctx.is_prevented_default(): if event_ctx.is_prevented_default():
return return
@@ -208,11 +212,12 @@ class RuntimePipeline:
await self._execute_from_stage(0, query) await self._execute_from_stage(0, query)
except Exception as e: except Exception as e:
inst_name = query.current_stage.inst_name if query.current_stage else 'unknown' inst_name = query.current_stage_name if query.current_stage_name else 'unknown'
self.ap.logger.error(f'处理请求时出错 query_id={query.query_id} stage={inst_name} : {e}') self.ap.logger.error(f'处理请求时出错 query_id={query.query_id} stage={inst_name} : {e}')
self.ap.logger.error(f'Traceback: {traceback.format_exc()}') self.ap.logger.error(f'Traceback: {traceback.format_exc()}')
finally: finally:
self.ap.logger.debug(f'Query {query.query_id} processed') self.ap.logger.debug(f'Query {query.query_id} processed')
del self.ap.query_pool.cached_queries[query.query_id]
class PipelineManager: class PipelineManager:

View File

@@ -3,10 +3,11 @@ from __future__ import annotations
import asyncio import asyncio
import typing import typing
from ..core import entities import langbot_plugin.api.entities.builtin.platform.message as platform_message
from ..platform import adapter as msadapter import langbot_plugin.api.entities.builtin.platform.events as platform_events
from ..platform.types import message as platform_message import langbot_plugin.api.entities.builtin.provider.session as provider_session
from ..platform.types import events as platform_events import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
class QueryPool: class QueryPool:
@@ -16,7 +17,10 @@ class QueryPool:
pool_lock: asyncio.Lock pool_lock: asyncio.Lock
queries: list[entities.Query] queries: list[pipeline_query.Query]
cached_queries: dict[int, pipeline_query.Query]
"""Cached queries, used for plugin backward api call, will be removed after the query completely processed"""
condition: asyncio.Condition condition: asyncio.Condition
@@ -24,34 +28,38 @@ class QueryPool:
self.query_id_counter = 0 self.query_id_counter = 0
self.pool_lock = asyncio.Lock() self.pool_lock = asyncio.Lock()
self.queries = [] self.queries = []
self.cached_queries = {}
self.condition = asyncio.Condition(self.pool_lock) self.condition = asyncio.Condition(self.pool_lock)
async def add_query( async def add_query(
self, self,
bot_uuid: str, bot_uuid: str,
launcher_type: entities.LauncherTypes, launcher_type: provider_session.LauncherTypes,
launcher_id: typing.Union[int, str], launcher_id: typing.Union[int, str],
sender_id: typing.Union[int, str], sender_id: typing.Union[int, str],
message_event: platform_events.MessageEvent, message_event: platform_events.MessageEvent,
message_chain: platform_message.MessageChain, message_chain: platform_message.MessageChain,
adapter: msadapter.MessagePlatformAdapter, adapter: abstract_platform_adapter.AbstractMessagePlatformAdapter,
pipeline_uuid: typing.Optional[str] = None, pipeline_uuid: typing.Optional[str] = None,
) -> entities.Query: ) -> pipeline_query.Query:
async with self.condition: async with self.condition:
query = entities.Query( query_id = self.query_id_counter
query = pipeline_query.Query(
bot_uuid=bot_uuid, bot_uuid=bot_uuid,
query_id=self.query_id_counter, query_id=query_id,
launcher_type=launcher_type, launcher_type=launcher_type,
launcher_id=launcher_id, launcher_id=launcher_id,
sender_id=sender_id, sender_id=sender_id,
message_event=message_event, message_event=message_event,
message_chain=message_chain, message_chain=message_chain,
variables={},
resp_messages=[], resp_messages=[],
resp_message_chain=[], resp_message_chain=[],
adapter=adapter, adapter=adapter,
pipeline_uuid=pipeline_uuid, pipeline_uuid=pipeline_uuid,
) )
self.queries.append(query) self.queries.append(query)
self.cached_queries[query_id] = query
self.query_id_counter += 1 self.query_id_counter += 1
self.condition.notify_all() self.condition.notify_all()

View File

@@ -3,10 +3,10 @@ from __future__ import annotations
import datetime import datetime
from .. import stage, entities from .. import stage, entities
from ...core import entities as core_entities from langbot_plugin.api.entities.builtin.provider import message as provider_message
from ...provider import entities as llm_entities import langbot_plugin.api.entities.events as events
from ...plugin import events import langbot_plugin.api.entities.builtin.platform.message as platform_message
from ...platform.types import message as platform_message import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@stage.stage_class('PreProcessor') @stage.stage_class('PreProcessor')
@@ -26,7 +26,7 @@ class PreProcessor(stage.PipelineStage):
async def process( async def process(
self, self,
query: core_entities.Query, query: pipeline_query.Query,
stage_inst_name: str, stage_inst_name: str,
) -> entities.StageProcessResult: ) -> entities.StageProcessResult:
"""Process""" """Process"""
@@ -49,80 +49,73 @@ class PreProcessor(stage.PipelineStage):
query.bot_uuid, query.bot_uuid,
) )
conversation.use_llm_model = llm_model # 设置query
# Set query
query.session = session query.session = session
query.prompt = conversation.prompt.copy() query.prompt = conversation.prompt.copy()
query.messages = conversation.messages.copy() query.messages = conversation.messages.copy()
query.use_llm_model = llm_model query.use_llm_model_uuid = llm_model.model_entity.uuid
if selected_runner == 'local-agent': if selected_runner == 'local-agent':
query.use_funcs = ( query.use_funcs = []
conversation.use_funcs if query.use_llm_model.model_entity.abilities.__contains__('func_call') else None
)
query.variables = { if llm_model.model_entity.abilities.__contains__('func_call'):
query.use_funcs = await self.ap.tool_mgr.get_all_tools()
variables = {
'session_id': f'{query.session.launcher_type.value}_{query.session.launcher_id}', 'session_id': f'{query.session.launcher_type.value}_{query.session.launcher_id}',
'conversation_id': conversation.uuid, 'conversation_id': conversation.uuid,
'msg_create_time': ( 'msg_create_time': (
int(query.message_event.time) if query.message_event.time else int(datetime.datetime.now().timestamp()) int(query.message_event.time) if query.message_event.time else int(datetime.datetime.now().timestamp())
), ),
} }
query.variables.update(variables)
# Check if this model supports vision, if not, remove all images # Check if this model supports vision, if not, remove all images
# TODO this checking should be performed in runner, and in this stage, the image should be reserved # TODO this checking should be performed in runner, and in this stage, the image should be reserved
if selected_runner == 'local-agent' and not query.use_llm_model.model_entity.abilities.__contains__('vision'): if selected_runner == 'local-agent' and not llm_model.model_entity.abilities.__contains__('vision'):
for msg in query.messages: for msg in query.messages:
if isinstance(msg.content, list): if isinstance(msg.content, list):
for me in msg.content: for me in msg.content:
if me.type == 'image_url': if me.type == 'image_url':
msg.content.remove(me) msg.content.remove(me)
content_list: list[llm_entities.ContentElement] = [] content_list: list[provider_message.ContentElement] = []
plain_text = '' plain_text = ''
qoute_msg = query.pipeline_config['trigger'].get('misc', '').get('combine-quote-message') qoute_msg = query.pipeline_config['trigger'].get('misc', '').get('combine-quote-message')
# tidy the content_list
# combine all text content into one, and put it in the first position
for me in query.message_chain: for me in query.message_chain:
if isinstance(me, platform_message.Plain): if isinstance(me, platform_message.Plain):
content_list.append(provider_message.ContentElement.from_text(me.text))
plain_text += me.text plain_text += me.text
elif isinstance(me, platform_message.Image): elif isinstance(me, platform_message.Image):
if selected_runner != 'local-agent' or query.use_llm_model.model_entity.abilities.__contains__( if selected_runner != 'local-agent' or llm_model.model_entity.abilities.__contains__('vision'):
'vision'
):
if me.base64 is not None: if me.base64 is not None:
content_list.append(llm_entities.ContentElement.from_image_base64(me.base64)) content_list.append(provider_message.ContentElement.from_image_base64(me.base64))
elif isinstance(me, platform_message.Quote) and qoute_msg: elif isinstance(me, platform_message.Quote) and qoute_msg:
for msg in me.origin: for msg in me.origin:
if isinstance(msg, platform_message.Plain): if isinstance(msg, platform_message.Plain):
content_list.append(llm_entities.ContentElement.from_text(msg.text)) content_list.append(provider_message.ContentElement.from_text(msg.text))
elif isinstance(msg, platform_message.Image): elif isinstance(msg, platform_message.Image):
if selected_runner != 'local-agent' or query.use_llm_model.model_entity.abilities.__contains__( if selected_runner != 'local-agent' or llm_model.model_entity.abilities.__contains__('vision'):
'vision'
):
if msg.base64 is not None: if msg.base64 is not None:
content_list.append(llm_entities.ContentElement.from_image_base64(msg.base64)) content_list.append(provider_message.ContentElement.from_image_base64(msg.base64))
content_list.insert(0, llm_entities.ContentElement.from_text(plain_text))
query.variables['user_message_text'] = plain_text query.variables['user_message_text'] = plain_text
query.user_message = llm_entities.Message(role='user', content=content_list) query.user_message = provider_message.Message(role='user', content=content_list)
# =========== Trigger event PromptPreProcessing # =========== 触发事件 PromptPreProcessing
event_ctx = await self.ap.plugin_mgr.emit_event( event = events.PromptPreProcessing(
event=events.PromptPreProcessing( session_name=f'{query.session.launcher_type.value}_{query.session.launcher_id}',
session_name=f'{query.session.launcher_type.value}_{query.session.launcher_id}', default_prompt=query.prompt.messages,
default_prompt=query.prompt.messages, prompt=query.messages,
prompt=query.messages, query=query,
query=query,
)
) )
event_ctx = await self.ap.plugin_connector.emit_event(event)
query.prompt.messages = event_ctx.event.default_prompt query.prompt.messages = event_ctx.event.default_prompt
query.messages = event_ctx.event.prompt query.messages = event_ctx.event.prompt

View File

@@ -3,8 +3,8 @@ from __future__ import annotations
import abc import abc
from ...core import app from ...core import app
from ...core import entities as core_entities
from .. import entities from .. import entities
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
class MessageHandler(metaclass=abc.ABCMeta): class MessageHandler(metaclass=abc.ABCMeta):
@@ -19,7 +19,7 @@ class MessageHandler(metaclass=abc.ABCMeta):
@abc.abstractmethod @abc.abstractmethod
async def handle( async def handle(
self, self,
query: core_entities.Query, query: pipeline_query.Query,
) -> entities.StageProcessResult: ) -> entities.StageProcessResult:
raise NotImplementedError raise NotImplementedError

View File

@@ -7,13 +7,15 @@ import traceback
from .. import handler from .. import handler
from ... import entities from ... import entities
from ....core import entities as core_entities
from ....provider import runner as runner_module from ....provider import runner as runner_module
from ....plugin import events
from ....platform.types import message as platform_message import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.events as events
from ....utils import importutil from ....utils import importutil
from ....provider import runners from ....provider import runners
import langbot_plugin.api.entities.builtin.provider.session as provider_session
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
importutil.import_modules_in_pkg(runners) importutil.import_modules_in_pkg(runners)
@@ -21,7 +23,7 @@ importutil.import_modules_in_pkg(runners)
class ChatMessageHandler(handler.MessageHandler): class ChatMessageHandler(handler.MessageHandler):
async def handle( async def handle(
self, self,
query: core_entities.Query, query: pipeline_query.Query,
) -> typing.AsyncGenerator[entities.StageProcessResult, None]: ) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
"""处理""" """处理"""
# 调API # 调API
@@ -30,19 +32,20 @@ class ChatMessageHandler(handler.MessageHandler):
# 触发插件事件 # 触发插件事件
event_class = ( event_class = (
events.PersonNormalMessageReceived events.PersonNormalMessageReceived
if query.launcher_type == core_entities.LauncherTypes.PERSON if query.launcher_type == provider_session.LauncherTypes.PERSON
else events.GroupNormalMessageReceived else events.GroupNormalMessageReceived
) )
event_ctx = await self.ap.plugin_mgr.emit_event( event = event_class(
event=event_class( launcher_type=query.launcher_type.value,
launcher_type=query.launcher_type.value, launcher_id=query.launcher_id,
launcher_id=query.launcher_id, sender_id=query.sender_id,
sender_id=query.sender_id, text_message=str(query.message_chain),
text_message=str(query.message_chain), query=query,
query=query,
)
) )
event_ctx = await self.ap.plugin_connector.emit_event(event)
is_create_card = False # 判断下是否需要创建流式卡片 is_create_card = False # 判断下是否需要创建流式卡片
if event_ctx.is_prevented_default(): if event_ctx.is_prevented_default():
if event_ctx.event.reply is not None: if event_ctx.event.reply is not None:
@@ -120,4 +123,4 @@ class ChatMessageHandler(handler.MessageHandler):
) )
finally: finally:
# TODO statistics # TODO statistics
pass pass

View File

@@ -4,16 +4,17 @@ import typing
from .. import handler from .. import handler
from ... import entities from ... import entities
from ....core import entities as core_entities import langbot_plugin.api.entities.builtin.provider.message as provider_message
from ....provider import entities as llm_entities import langbot_plugin.api.entities.builtin.platform.message as platform_message
from ....plugin import events import langbot_plugin.api.entities.builtin.provider.session as provider_session
from ....platform.types import message as platform_message import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.events as events
class CommandHandler(handler.MessageHandler): class CommandHandler(handler.MessageHandler):
async def handle( async def handle(
self, self,
query: core_entities.Query, query: pipeline_query.Query,
) -> typing.AsyncGenerator[entities.StageProcessResult, None]: ) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
"""Process""" """Process"""
@@ -28,23 +29,23 @@ class CommandHandler(handler.MessageHandler):
event_class = ( event_class = (
events.PersonCommandSent events.PersonCommandSent
if query.launcher_type == core_entities.LauncherTypes.PERSON if query.launcher_type == provider_session.LauncherTypes.PERSON
else events.GroupCommandSent else events.GroupCommandSent
) )
event_ctx = await self.ap.plugin_mgr.emit_event( event = event_class(
event=event_class( launcher_type=query.launcher_type.value,
launcher_type=query.launcher_type.value, launcher_id=query.launcher_id,
launcher_id=query.launcher_id, sender_id=query.sender_id,
sender_id=query.sender_id, command=spt[0],
command=spt[0], params=spt[1:] if len(spt) > 1 else [],
params=spt[1:] if len(spt) > 1 else [], text_message=str(query.message_chain),
text_message=str(query.message_chain), is_admin=(privilege == 2),
is_admin=(privilege == 2), query=query,
query=query,
)
) )
event_ctx = await self.ap.plugin_connector.emit_event(event)
if event_ctx.is_prevented_default(): if event_ctx.is_prevented_default():
if event_ctx.event.reply is not None: if event_ctx.event.reply is not None:
mc = platform_message.MessageChain(event_ctx.event.reply) mc = platform_message.MessageChain(event_ctx.event.reply)
@@ -64,7 +65,7 @@ class CommandHandler(handler.MessageHandler):
async for ret in self.ap.cmd_mgr.execute(command_text=command_text, query=query, session=session): async for ret in self.ap.cmd_mgr.execute(command_text=command_text, query=query, session=session):
if ret.error is not None: if ret.error is not None:
query.resp_messages.append( query.resp_messages.append(
llm_entities.Message( provider_message.Message(
role='command', role='command',
content=str(ret.error), content=str(ret.error),
) )
@@ -73,17 +74,20 @@ class CommandHandler(handler.MessageHandler):
self.ap.logger.info(f'Command({query.query_id}) error: {self.cut_str(str(ret.error))}') self.ap.logger.info(f'Command({query.query_id}) error: {self.cut_str(str(ret.error))}')
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
elif ret.text is not None or ret.image_url is not None: elif ret.text is not None or ret.image_url is not None or ret.image_base64 is not None:
content: list[llm_entities.ContentElement] = [] content: list[provider_message.ContentElement] = []
if ret.text is not None: if ret.text is not None:
content.append(llm_entities.ContentElement.from_text(ret.text)) content.append(provider_message.ContentElement.from_text(ret.text))
if ret.image_url is not None: if ret.image_url is not None:
content.append(llm_entities.ContentElement.from_image_url(ret.image_url)) content.append(provider_message.ContentElement.from_image_url(ret.image_url))
if ret.image_base64 is not None:
content.append(provider_message.ContentElement.from_image_base64(ret.image_base64))
query.resp_messages.append( query.resp_messages.append(
llm_entities.Message( provider_message.Message(
role='command', role='command',
content=content, content=content,
) )

View File

@@ -1,10 +1,10 @@
from __future__ import annotations from __future__ import annotations
from ...core import entities as core_entities
from . import handler from . import handler
from .handlers import chat, command from .handlers import chat, command
from .. import entities from .. import entities
from .. import stage from .. import stage
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@stage.stage_class('MessageProcessor') @stage.stage_class('MessageProcessor')
@@ -30,7 +30,7 @@ class Processor(stage.PipelineStage):
async def process( async def process(
self, self,
query: core_entities.Query, query: pipeline_query.Query,
stage_inst_name: str, stage_inst_name: str,
) -> entities.StageProcessResult: ) -> entities.StageProcessResult:
"""Process""" """Process"""

View File

@@ -2,7 +2,8 @@ from __future__ import annotations
import abc import abc
import typing import typing
from ...core import app, entities as core_entities from ...core import app
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
preregistered_algos: list[typing.Type[ReteLimitAlgo]] = [] preregistered_algos: list[typing.Type[ReteLimitAlgo]] = []
@@ -33,7 +34,7 @@ class ReteLimitAlgo(metaclass=abc.ABCMeta):
@abc.abstractmethod @abc.abstractmethod
async def require_access( async def require_access(
self, self,
query: core_entities.Query, query: pipeline_query.Query,
launcher_type: str, launcher_type: str,
launcher_id: typing.Union[int, str], launcher_id: typing.Union[int, str],
) -> bool: ) -> bool:
@@ -53,7 +54,7 @@ class ReteLimitAlgo(metaclass=abc.ABCMeta):
@abc.abstractmethod @abc.abstractmethod
async def release_access( async def release_access(
self, self,
query: core_entities.Query, query: pipeline_query.Query,
launcher_type: str, launcher_type: str,
launcher_id: typing.Union[int, str], launcher_id: typing.Union[int, str],
): ):

View File

@@ -3,7 +3,7 @@ import asyncio
import time import time
import typing import typing
from .. import algo from .. import algo
from ....core import entities as core_entities import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
# 固定窗口算法 # 固定窗口算法
@@ -32,7 +32,7 @@ class FixedWindowAlgo(algo.ReteLimitAlgo):
async def require_access( async def require_access(
self, self,
query: core_entities.Query, query: pipeline_query.Query,
launcher_type: str, launcher_type: str,
launcher_id: typing.Union[int, str], launcher_id: typing.Union[int, str],
) -> bool: ) -> bool:
@@ -91,7 +91,7 @@ class FixedWindowAlgo(algo.ReteLimitAlgo):
async def release_access( async def release_access(
self, self,
query: core_entities.Query, query: pipeline_query.Query,
launcher_type: str, launcher_type: str,
launcher_id: typing.Union[int, str], launcher_id: typing.Union[int, str],
): ):

View File

@@ -4,9 +4,10 @@ import typing
from .. import entities, stage from .. import entities, stage
from . import algo from . import algo
from ...core import entities as core_entities
from ...utils import importutil from ...utils import importutil
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from . import algos from . import algos
importutil.import_modules_in_pkg(algos) importutil.import_modules_in_pkg(algos)
@@ -39,7 +40,7 @@ class RateLimit(stage.PipelineStage):
async def process( async def process(
self, self,
query: core_entities.Query, query: pipeline_query.Query,
stage_inst_name: str, stage_inst_name: str,
) -> typing.Union[ ) -> typing.Union[
entities.StageProcessResult, entities.StageProcessResult,

View File

@@ -4,22 +4,19 @@ import random
import asyncio import asyncio
from ...platform.types import events as platform_events import langbot_plugin.api.entities.builtin.platform.events as platform_events
from ...platform.types import message as platform_message import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.provider.message as provider_message
from ...provider import entities as llm_entities
from .. import stage, entities from .. import stage, entities
from ...core import entities as core_entities import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@stage.stage_class('SendResponseBackStage') @stage.stage_class('SendResponseBackStage')
class SendResponseBackStage(stage.PipelineStage): class SendResponseBackStage(stage.PipelineStage):
"""发送响应消息""" """发送响应消息"""
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
"""处理""" """处理"""
random_range = ( random_range = (
@@ -40,7 +37,7 @@ class SendResponseBackStage(stage.PipelineStage):
quote_origin = query.pipeline_config['output']['misc']['quote-origin'] quote_origin = query.pipeline_config['output']['misc']['quote-origin']
has_chunks = any(isinstance(msg, llm_entities.MessageChunk) for msg in query.resp_messages) has_chunks = any(isinstance(msg, provider_message.MessageChunk) for msg in query.resp_messages)
# TODO 命令与流式的兼容性问题 # TODO 命令与流式的兼容性问题
if await query.adapter.is_stream_output_supported() and has_chunks: if await query.adapter.is_stream_output_supported() and has_chunks:
is_final = [msg.is_final for msg in query.resp_messages][0] is_final = [msg.is_final for msg in query.resp_messages][0]

View File

@@ -1,6 +1,6 @@
import pydantic.v1 as pydantic import pydantic
from ...platform.types import message as platform_message import langbot_plugin.api.entities.builtin.platform.message as platform_message
class RuleJudgeResult(pydantic.BaseModel): class RuleJudgeResult(pydantic.BaseModel):

View File

@@ -4,9 +4,10 @@ from __future__ import annotations
from . import rule from . import rule
from .. import stage, entities from .. import stage, entities
from ...core import entities as core_entities
from ...utils import importutil from ...utils import importutil
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from . import rules from . import rules
importutil.import_modules_in_pkg(rules) importutil.import_modules_in_pkg(rules)
@@ -32,7 +33,7 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
await rule_inst.initialize() await rule_inst.initialize()
self.rule_matchers.append(rule_inst) self.rule_matchers.append(rule_inst)
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
if query.launcher_type.value != 'group': # 只处理群消息 if query.launcher_type.value != 'group': # 只处理群消息
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)

View File

@@ -2,10 +2,11 @@ from __future__ import annotations
import abc import abc
import typing import typing
from ...core import app, entities as core_entities from ...core import app
from . import entities from . import entities
from ...platform.types import message as platform_message import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
preregisetered_rules: list[typing.Type[GroupRespondRule]] = [] preregisetered_rules: list[typing.Type[GroupRespondRule]] = []
@@ -39,7 +40,7 @@ class GroupRespondRule(metaclass=abc.ABCMeta):
message_text: str, message_text: str,
message_chain: platform_message.MessageChain, message_chain: platform_message.MessageChain,
rule_dict: dict, rule_dict: dict,
query: core_entities.Query, query: pipeline_query.Query,
) -> entities.RuleJudgeResult: ) -> entities.RuleJudgeResult:
"""判断消息是否匹配规则""" """判断消息是否匹配规则"""
raise NotImplementedError raise NotImplementedError

View File

@@ -3,8 +3,8 @@ from __future__ import annotations
from .. import rule as rule_model from .. import rule as rule_model
from .. import entities from .. import entities
from ....core import entities as core_entities import langbot_plugin.api.entities.builtin.platform.message as platform_message
from ....platform.types import message as platform_message import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@rule_model.rule_class('at-bot') @rule_model.rule_class('at-bot')
@@ -14,19 +14,28 @@ class AtBotRule(rule_model.GroupRespondRule):
message_text: str, message_text: str,
message_chain: platform_message.MessageChain, message_chain: platform_message.MessageChain,
rule_dict: dict, rule_dict: dict,
query: core_entities.Query, query: pipeline_query.Query,
) -> entities.RuleJudgeResult: ) -> entities.RuleJudgeResult:
if message_chain.has(platform_message.At(query.adapter.bot_account_id)) and rule_dict['at']: def remove_at(message_chain: platform_message.MessageChain):
message_chain.remove(platform_message.At(query.adapter.bot_account_id)) for component in message_chain.root:
if isinstance(component, platform_message.At) and component.target == query.adapter.bot_account_id:
message_chain.remove(component)
break
if message_chain.has( remove_at(message_chain)
platform_message.At(query.adapter.bot_account_id) remove_at(message_chain) # 回复消息时会at两次检查并删除重复的
): # 回复消息时会at两次检查并删除重复的
message_chain.remove(platform_message.At(query.adapter.bot_account_id))
return entities.RuleJudgeResult( # if message_chain.has(platform_message.At(query.adapter.bot_account_id)) and rule_dict['at']:
matching=True, # message_chain.remove(platform_message.At(query.adapter.bot_account_id))
replacement=message_chain,
) # if message_chain.has(
# platform_message.At(query.adapter.bot_account_id)
# ): # 回复消息时会at两次检查并删除重复的
# message_chain.remove(platform_message.At(query.adapter.bot_account_id))
# return entities.RuleJudgeResult(
# matching=True,
# replacement=message_chain,
# )
return entities.RuleJudgeResult(matching=False, replacement=message_chain) return entities.RuleJudgeResult(matching=False, replacement=message_chain)

View File

@@ -1,7 +1,7 @@
from .. import rule as rule_model from .. import rule as rule_model
from .. import entities from .. import entities
from ....core import entities as core_entities import langbot_plugin.api.entities.builtin.platform.message as platform_message
from ....platform.types import message as platform_message import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@rule_model.rule_class('prefix') @rule_model.rule_class('prefix')
@@ -11,7 +11,7 @@ class PrefixRule(rule_model.GroupRespondRule):
message_text: str, message_text: str,
message_chain: platform_message.MessageChain, message_chain: platform_message.MessageChain,
rule_dict: dict, rule_dict: dict,
query: core_entities.Query, query: pipeline_query.Query,
) -> entities.RuleJudgeResult: ) -> entities.RuleJudgeResult:
prefixes = rule_dict['prefix'] prefixes = rule_dict['prefix']

View File

@@ -3,8 +3,8 @@ import random
from .. import rule as rule_model from .. import rule as rule_model
from .. import entities from .. import entities
from ....core import entities as core_entities import langbot_plugin.api.entities.builtin.platform.message as platform_message
from ....platform.types import message as platform_message import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@rule_model.rule_class('random') @rule_model.rule_class('random')
@@ -14,7 +14,7 @@ class RandomRespRule(rule_model.GroupRespondRule):
message_text: str, message_text: str,
message_chain: platform_message.MessageChain, message_chain: platform_message.MessageChain,
rule_dict: dict, rule_dict: dict,
query: core_entities.Query, query: pipeline_query.Query,
) -> entities.RuleJudgeResult: ) -> entities.RuleJudgeResult:
random_rate = rule_dict['random'] random_rate = rule_dict['random']

View File

@@ -3,8 +3,8 @@ import re
from .. import rule as rule_model from .. import rule as rule_model
from .. import entities from .. import entities
from ....core import entities as core_entities import langbot_plugin.api.entities.builtin.platform.message as platform_message
from ....platform.types import message as platform_message import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@rule_model.rule_class('regexp') @rule_model.rule_class('regexp')
@@ -14,7 +14,7 @@ class RegExpRule(rule_model.GroupRespondRule):
message_text: str, message_text: str,
message_chain: platform_message.MessageChain, message_chain: platform_message.MessageChain,
rule_dict: dict, rule_dict: dict,
query: core_entities.Query, query: pipeline_query.Query,
) -> entities.RuleJudgeResult: ) -> entities.RuleJudgeResult:
regexps = rule_dict['regexp'] regexps = rule_dict['regexp']

View File

@@ -3,8 +3,9 @@ from __future__ import annotations
import abc import abc
import typing import typing
from ..core import app, entities as core_entities from ..core import app
from . import entities from . import entities
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
preregistered_stages: dict[str, type[PipelineStage]] = {} preregistered_stages: dict[str, type[PipelineStage]] = {}
@@ -33,7 +34,7 @@ class PipelineStage(metaclass=abc.ABCMeta):
@abc.abstractmethod @abc.abstractmethod
async def process( async def process(
self, self,
query: core_entities.Query, query: pipeline_query.Query,
stage_inst_name: str, stage_inst_name: str,
) -> typing.Union[ ) -> typing.Union[
entities.StageProcessResult, entities.StageProcessResult,

View File

@@ -2,12 +2,12 @@ from __future__ import annotations
import typing import typing
from ...core import entities as core_entities
from .. import entities from .. import entities
from .. import stage from .. import stage
from ...plugin import events
from ...platform.types import message as platform_message import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.events as events
@stage.stage_class('ResponseWrapper') @stage.stage_class('ResponseWrapper')
@@ -25,7 +25,7 @@ class ResponseWrapper(stage.PipelineStage):
async def process( async def process(
self, self,
query: core_entities.Query, query: pipeline_query.Query,
stage_inst_name: str, stage_inst_name: str,
) -> typing.AsyncGenerator[entities.StageProcessResult, None]: ) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
"""处理""" """处理"""
@@ -58,21 +58,22 @@ class ResponseWrapper(stage.PipelineStage):
reply_text = str(result.get_content_platform_message_chain()) reply_text = str(result.get_content_platform_message_chain())
# ============= 触发插件事件 =============== # ============= 触发插件事件 ===============
event_ctx = await self.ap.plugin_mgr.emit_event( event = events.NormalMessageResponded(
event=events.NormalMessageResponded( launcher_type=query.launcher_type.value,
launcher_type=query.launcher_type.value, launcher_id=query.launcher_id,
launcher_id=query.launcher_id, sender_id=query.sender_id,
sender_id=query.sender_id, session=session,
session=session, prefix='',
prefix='', response_text=reply_text,
response_text=reply_text, finish_reason='stop',
finish_reason='stop', funcs_called=[fc.function.name for fc in result.tool_calls]
funcs_called=[fc.function.name for fc in result.tool_calls] if result.tool_calls is not None
if result.tool_calls is not None else [],
else [], query=query,
query=query,
)
) )
event_ctx = await self.ap.plugin_connector.emit_event(event)
if event_ctx.is_prevented_default(): if event_ctx.is_prevented_default():
yield entities.StageProcessResult( yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT, result_type=entities.ResultType.INTERRUPT,
@@ -96,26 +97,26 @@ class ResponseWrapper(stage.PipelineStage):
reply_text = f'调用函数 {".".join(function_names)}...' reply_text = f'调用函数 {".".join(function_names)}...'
query.resp_message_chain.append( query.resp_message_chain.append(
platform_message.MessageChain([platform_message.Plain(reply_text)]) platform_message.MessageChain([platform_message.Plain(text=reply_text)])
) )
if query.pipeline_config['output']['misc']['track-function-calls']: if query.pipeline_config['output']['misc']['track-function-calls']:
event_ctx = await self.ap.plugin_mgr.emit_event( event = events.NormalMessageResponded(
event=events.NormalMessageResponded( launcher_type=query.launcher_type.value,
launcher_type=query.launcher_type.value, launcher_id=query.launcher_id,
launcher_id=query.launcher_id, sender_id=query.sender_id,
sender_id=query.sender_id, session=session,
session=session, prefix='',
prefix='', response_text=reply_text,
response_text=reply_text, finish_reason='stop',
finish_reason='stop', funcs_called=[fc.function.name for fc in result.tool_calls]
funcs_called=[fc.function.name for fc in result.tool_calls] if result.tool_calls is not None
if result.tool_calls is not None else [],
else [], query=query,
query=query,
)
) )
event_ctx = await self.ap.plugin_connector.emit_event(event)
if event_ctx.is_prevented_default(): if event_ctx.is_prevented_default():
yield entities.StageProcessResult( yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT, result_type=entities.ResultType.INTERRUPT,
@@ -124,12 +125,12 @@ class ResponseWrapper(stage.PipelineStage):
else: else:
if event_ctx.event.reply is not None: if event_ctx.event.reply is not None:
query.resp_message_chain.append( query.resp_message_chain.append(
platform_message.MessageChain(event_ctx.event.reply) platform_message.MessageChain(text=event_ctx.event.reply)
) )
else: else:
query.resp_message_chain.append( query.resp_message_chain.append(
platform_message.MessageChain([platform_message.Plain(reply_text)]) platform_message.MessageChain([platform_message.Plain(text=reply_text)])
) )
yield entities.StageProcessResult( yield entities.StageProcessResult(

View File

@@ -1,190 +0,0 @@
from __future__ import annotations
# MessageSource的适配器
import typing
import abc
from ..core import app
from .types import message as platform_message
from .types import events as platform_events
from .logger import EventLogger
class MessagePlatformAdapter(metaclass=abc.ABCMeta):
"""消息平台适配器基类"""
name: str
bot_account_id: int
"""机器人账号ID需要在初始化时设置"""
config: dict
ap: app.Application
logger: EventLogger
def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
"""初始化适配器
Args:
config (dict): 对应的配置
ap (app.Application): 应用上下文
"""
self.config = config
self.ap = ap
self.logger = logger
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
"""主动发送消息
Args:
target_type (str): 目标类型,`person`或`group`
target_id (str): 目标ID
message (platform.types.MessageChain): 消息链
"""
raise NotImplementedError
async def reply_message(
self,
message_source: platform_events.MessageEvent,
message: platform_message.MessageChain,
quote_origin: bool = False,
):
"""回复消息
Args:
message_source (platform.types.MessageEvent): 消息源事件
message (platform.types.MessageChain): 消息链
quote_origin (bool, optional): 是否引用原消息. Defaults to False.
"""
raise NotImplementedError
async def reply_message_chunk(
self,
message_source: platform_events.MessageEvent,
bot_message: dict,
message: platform_message.MessageChain,
quote_origin: bool = False,
is_final: bool = False,
):
"""回复消息(流式输出)
Args:
message_source (platform.types.MessageEvent): 消息源事件
message_id (int): 消息ID
message (platform.types.MessageChain): 消息链
quote_origin (bool, optional): 是否引用原消息. Defaults to False.
is_final (bool, optional): 流式是否结束. Defaults to False.
"""
raise NotImplementedError
async def create_message_card(self, message_id: typing.Type[str, int], event: platform_events.MessageEvent) -> bool:
"""创建卡片消息
Args:
message_id (str): 消息ID
event (platform_events.MessageEvent): 消息源事件
"""
return False
async def is_muted(self, group_id: int) -> bool:
"""获取账号是否在指定群被禁言"""
raise NotImplementedError
def register_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None],
):
"""注册事件监听器
Args:
event_type (typing.Type[platform.types.Event]): 事件类型
callback (typing.Callable[[platform.types.Event], None]): 回调函数,接收一个参数,为事件
"""
raise NotImplementedError
def unregister_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None],
):
"""注销事件监听器
Args:
event_type (typing.Type[platform.types.Event]): 事件类型
callback (typing.Callable[[platform.types.Event], None]): 回调函数,接收一个参数,为事件
"""
raise NotImplementedError
async def run_async(self):
"""异步运行"""
raise NotImplementedError
async def is_stream_output_supported(self) -> bool:
"""是否支持流式输出"""
return False
async def kill(self) -> bool:
"""关闭适配器
Returns:
bool: 是否成功关闭热重载时若此函数返回False则不会重载MessageSource底层
"""
raise NotImplementedError
class MessageConverter:
"""消息链转换器基类"""
@staticmethod
def yiri2target(message_chain: platform_message.MessageChain):
"""将源平台消息链转换为目标平台消息链
Args:
message_chain (platform.types.MessageChain): 源平台消息链
Returns:
typing.Any: 目标平台消息链
"""
raise NotImplementedError
@staticmethod
def target2yiri(message_chain: typing.Any) -> platform_message.MessageChain:
"""将目标平台消息链转换为源平台消息链
Args:
message_chain (typing.Any): 目标平台消息链
Returns:
platform.types.MessageChain: 源平台消息链
"""
raise NotImplementedError
class EventConverter:
"""事件转换器基类"""
@staticmethod
def yiri2target(event: typing.Type[platform_events.Event]):
"""将源平台事件转换为目标平台事件
Args:
event (typing.Type[platform.types.Event]): 源平台事件
Returns:
typing.Any: 目标平台事件
"""
raise NotImplementedError
@staticmethod
def target2yiri(event: typing.Any) -> platform_events.Event:
"""将目标平台事件的调用参数转换为源平台的事件参数对象
Args:
event (typing.Any): 目标平台事件
Returns:
typing.Type[platform.types.Event]: 源平台事件
"""
raise NotImplementedError

View File

@@ -1,14 +0,0 @@
apiVersion: v1
kind: ComponentTemplate
metadata:
name: MessagePlatformAdapter
label:
en_US: Message Platform Adapter
zh_Hans: 消息平台适配器模板类
spec:
type:
- python
execution:
python:
path: ./adapter.py
attr: MessagePlatformAdapter

View File

@@ -1,15 +1,10 @@
from __future__ import annotations from __future__ import annotations
import sys
import asyncio import asyncio
import traceback import traceback
import sqlalchemy import sqlalchemy
# FriendMessage, Image, MessageChain, Plain
from . import adapter as msadapter
from ..core import app, entities as core_entities, taskmgr from ..core import app, entities as core_entities, taskmgr
from .types import events as platform_events, message as platform_message
from ..discover import engine from ..discover import engine
@@ -19,10 +14,10 @@ from ..entity.errors import platform as platform_errors
from .logger import EventLogger from .logger import EventLogger
# 处理 3.4 移除了 YiriMirai 之后,插件的兼容性问题 import langbot_plugin.api.entities.builtin.provider.session as provider_session
from . import types as mirai import langbot_plugin.api.entities.builtin.platform.events as platform_events
import langbot_plugin.api.entities.builtin.platform.message as platform_message
sys.modules['mirai'] = mirai import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
class RuntimeBot: class RuntimeBot:
@@ -34,7 +29,7 @@ class RuntimeBot:
enable: bool enable: bool
adapter: msadapter.MessagePlatformAdapter adapter: abstract_platform_adapter.AbstractMessagePlatformAdapter
task_wrapper: taskmgr.TaskWrapper task_wrapper: taskmgr.TaskWrapper
@@ -46,7 +41,7 @@ class RuntimeBot:
self, self,
ap: app.Application, ap: app.Application,
bot_entity: persistence_bot.Bot, bot_entity: persistence_bot.Bot,
adapter: msadapter.MessagePlatformAdapter, adapter: abstract_platform_adapter.AbstractMessagePlatformAdapter,
logger: EventLogger, logger: EventLogger,
): ):
self.ap = ap self.ap = ap
@@ -59,7 +54,7 @@ class RuntimeBot:
async def initialize(self): async def initialize(self):
async def on_friend_message( async def on_friend_message(
event: platform_events.FriendMessage, event: platform_events.FriendMessage,
adapter: msadapter.MessagePlatformAdapter, adapter: abstract_platform_adapter.AbstractMessagePlatformAdapter,
): ):
image_components = [ image_components = [
component for component in event.message_chain if isinstance(component, platform_message.Image) component for component in event.message_chain if isinstance(component, platform_message.Image)
@@ -73,7 +68,7 @@ class RuntimeBot:
await self.ap.query_pool.add_query( await self.ap.query_pool.add_query(
bot_uuid=self.bot_entity.uuid, bot_uuid=self.bot_entity.uuid,
launcher_type=core_entities.LauncherTypes.PERSON, launcher_type=provider_session.LauncherTypes.PERSON,
launcher_id=event.sender.id, launcher_id=event.sender.id,
sender_id=event.sender.id, sender_id=event.sender.id,
message_event=event, message_event=event,
@@ -84,7 +79,7 @@ class RuntimeBot:
async def on_group_message( async def on_group_message(
event: platform_events.GroupMessage, event: platform_events.GroupMessage,
adapter: msadapter.MessagePlatformAdapter, adapter: abstract_platform_adapter.AbstractMessagePlatformAdapter,
): ):
image_components = [ image_components = [
component for component in event.message_chain if isinstance(component, platform_message.Image) component for component in event.message_chain if isinstance(component, platform_message.Image)
@@ -98,7 +93,7 @@ class RuntimeBot:
await self.ap.query_pool.add_query( await self.ap.query_pool.add_query(
bot_uuid=self.bot_entity.uuid, bot_uuid=self.bot_entity.uuid,
launcher_type=core_entities.LauncherTypes.GROUP, launcher_type=provider_session.LauncherTypes.GROUP,
launcher_id=event.group.id, launcher_id=event.group.id,
sender_id=event.sender.id, sender_id=event.sender.id,
message_event=event, message_event=event,
@@ -153,7 +148,7 @@ class PlatformManager:
adapter_components: list[engine.Component] adapter_components: list[engine.Component]
adapter_dict: dict[str, type[msadapter.MessagePlatformAdapter]] adapter_dict: dict[str, type[abstract_platform_adapter.AbstractMessagePlatformAdapter]]
def __init__(self, ap: app.Application = None): def __init__(self, ap: app.Application = None):
self.ap = ap self.ap = ap
@@ -163,7 +158,7 @@ class PlatformManager:
async def initialize(self): async def initialize(self):
self.adapter_components = self.ap.discover.get_components_by_kind('MessagePlatformAdapter') self.adapter_components = self.ap.discover.get_components_by_kind('MessagePlatformAdapter')
adapter_dict: dict[str, type[msadapter.MessagePlatformAdapter]] = {} adapter_dict: dict[str, type[abstract_platform_adapter.AbstractMessagePlatformAdapter]] = {}
for component in self.adapter_components: for component in self.adapter_components:
adapter_dict[component.metadata.name] = component.get_python_component_class() adapter_dict[component.metadata.name] = component.get_python_component_class()
self.adapter_dict = adapter_dict self.adapter_dict = adapter_dict
@@ -174,8 +169,9 @@ class PlatformManager:
webchat_logger = EventLogger(name='webchat-adapter', ap=self.ap) webchat_logger = EventLogger(name='webchat-adapter', ap=self.ap)
webchat_adapter_inst = webchat_adapter_class( webchat_adapter_inst = webchat_adapter_class(
{}, {},
self.ap,
webchat_logger, webchat_logger,
ap=self.ap,
is_stream=False,
) )
self.webchat_proxy_bot = RuntimeBot( self.webchat_proxy_bot = RuntimeBot(
@@ -195,7 +191,7 @@ class PlatformManager:
await self.load_bots_from_db() await self.load_bots_from_db()
def get_running_adapters(self) -> list[msadapter.MessagePlatformAdapter]: def get_running_adapters(self) -> list[abstract_platform_adapter.AbstractMessagePlatformAdapter]:
return [bot.adapter for bot in self.bots if bot.enable] return [bot.adapter for bot in self.bots if bot.enable]
async def load_bots_from_db(self): async def load_bots_from_db(self):
@@ -233,7 +229,6 @@ class PlatformManager:
adapter_inst = self.adapter_dict[bot_entity.adapter]( adapter_inst = self.adapter_dict[bot_entity.adapter](
bot_entity.adapter_config, bot_entity.adapter_config,
self.ap,
logger, logger,
) )
@@ -276,43 +271,6 @@ class PlatformManager:
return component return component
return None return None
async def write_back_config(
self,
adapter_name: str,
adapter_inst: msadapter.MessagePlatformAdapter,
config: dict,
):
# index = -2
# for i, adapter in enumerate(self.adapters):
# if adapter == adapter_inst:
# index = i
# break
# if index == -2:
# raise Exception('平台适配器未找到')
# # 只修改启用的适配器
# real_index = -1
# for i, adapter in enumerate(self.ap.platform_cfg.data['platform-adapters']):
# if adapter['enable']:
# index -= 1
# if index == -1:
# real_index = i
# break
# new_cfg = {
# 'adapter': adapter_name,
# 'enable': True,
# **config
# }
# self.ap.platform_cfg.data['platform-adapters'][real_index] = new_cfg
# await self.ap.platform_cfg.dump_config()
# TODO implement this
pass
async def run(self): async def run(self):
# This method will only be called when the application launching # This method will only be called when the application launching
await self.webchat_proxy_bot.run() await self.webchat_proxy_bot.run()

View File

@@ -9,7 +9,8 @@ import traceback
import uuid import uuid
from ..core import app from ..core import app
from .types import message as platform_message import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.definition.abstract.platform.event_logger as abstract_platform_event_logger
class EventLogLevel(enum.Enum): class EventLogLevel(enum.Enum):
@@ -55,7 +56,7 @@ MAX_LOG_COUNT = 200
DELETE_COUNT_PER_TIME = 50 DELETE_COUNT_PER_TIME = 50
class EventLogger: class EventLogger(abstract_platform_event_logger.AbstractEventLogger):
"""used for logging bot events""" """used for logging bot events"""
ap: app.Application ap: app.Application

View File

@@ -5,17 +5,17 @@ import traceback
import datetime import datetime
import aiocqhttp import aiocqhttp
import pydantic
from .. import adapter import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
from ...core import app import langbot_plugin.api.entities.builtin.platform.message as platform_message
from ..types import message as platform_message import langbot_plugin.api.entities.builtin.platform.events as platform_events
from ..types import events as platform_events import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
from ..types import entities as platform_entities
from ...utils import image from ...utils import image
from ..logger import EventLogger import langbot_plugin.api.definition.abstract.platform.event_logger as abstract_platform_logger
class AiocqhttpMessageConverter(adapter.MessageConverter): class AiocqhttpMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
@staticmethod @staticmethod
async def yiri2target( async def yiri2target(
message_chain: platform_message.MessageChain, message_chain: platform_message.MessageChain,
@@ -266,20 +266,21 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
await process_message_data(msg_data, reply_list) await process_message_data(msg_data, reply_list)
reply_msg = platform_message.Quote( reply_msg = platform_message.Quote(
message_id=msg.data['id'], sender_id=msg_datas['sender']['user_id'], origin=reply_list message_id=msg.data['id'], sender_id=msg_datas['user_id'], origin=reply_list
) )
yiri_msg_list.append(reply_msg) yiri_msg_list.append(reply_msg)
# 这里下载所有文件会导致下载文件过多,暂时不下载 elif msg.type == 'file':
# elif msg.type == 'file': pass
# # file_name = msg.data['file'] # file_name = msg.data['file']
# file_id = msg.data['file_id'] # file_id = msg.data['file_id']
# file_data = await bot.get_file(file_id=file_id) # file_data = await bot.get_file(file_id=file_id)
# file_name = file_data.get('file_name') # file_name = file_data.get('file_name')
# file_path = file_data.get('file') # file_path = file_data.get('file')
# file_url = file_data.get('file_url') # _ = file_path
# file_size = file_data.get('file_size') # file_url = file_data.get('file_url')
# yiri_msg_list.append(platform_message.File(id=file_id, name=file_name,url=file_url,size=file_size)) # file_size = file_data.get('file_size')
# yiri_msg_list.append(platform_message.File(id=file_id, name=file_name,url=file_url,size=file_size))
elif msg.type == 'face': elif msg.type == 'face':
face_id = msg.data['id'] face_id = msg.data['id']
face_name = msg.data['raw']['faceText'] face_name = msg.data['raw']['faceText']
@@ -298,7 +299,7 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
return chain return chain
class AiocqhttpEventConverter(adapter.EventConverter): class AiocqhttpEventConverter(abstract_platform_adapter.AbstractEventConverter):
@staticmethod @staticmethod
async def yiri2target(event: platform_events.MessageEvent, bot_account_id: int): async def yiri2target(event: platform_events.MessageEvent, bot_account_id: int):
return event.source_platform_object return event.source_platform_object
@@ -348,23 +349,19 @@ class AiocqhttpEventConverter(adapter.EventConverter):
) )
class AiocqhttpAdapter(adapter.MessagePlatformAdapter): class AiocqhttpAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
bot: aiocqhttp.CQHttp bot: aiocqhttp.CQHttp = pydantic.Field(exclude=True, default_factory=aiocqhttp.CQHttp)
bot_account_id: int
message_converter: AiocqhttpMessageConverter = AiocqhttpMessageConverter() message_converter: AiocqhttpMessageConverter = AiocqhttpMessageConverter()
event_converter: AiocqhttpEventConverter = AiocqhttpEventConverter() event_converter: AiocqhttpEventConverter = AiocqhttpEventConverter()
config: dict
ap: app.Application
on_websocket_connection_event_cache: typing.List[typing.Callable[[aiocqhttp.Event], None]] = [] on_websocket_connection_event_cache: typing.List[typing.Callable[[aiocqhttp.Event], None]] = []
def __init__(self, config: dict, ap: app.Application, logger: EventLogger): def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger):
self.config = config super().__init__(
self.logger = logger config=config,
logger=logger,
)
async def shutdown_trigger_placeholder(): async def shutdown_trigger_placeholder():
while True: while True:
@@ -372,7 +369,6 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter):
self.config['shutdown_trigger'] = shutdown_trigger_placeholder self.config['shutdown_trigger'] = shutdown_trigger_placeholder
self.ap = ap
self.on_websocket_connection_event_cache = [] self.on_websocket_connection_event_cache = []
if 'access-token' in config: if 'access-token' in config:
@@ -408,7 +404,9 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter):
def register_listener( def register_listener(
self, self,
event_type: typing.Type[platform_events.Event], event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
): ):
async def on_message(event: aiocqhttp.Event): async def on_message(event: aiocqhttp.Event):
self.bot_account_id = event.self_id self.bot_account_id = event.self_id
@@ -439,7 +437,9 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter):
def unregister_listener( def unregister_listener(
self, self,
event_type: typing.Type[platform_events.Event], event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
): ):
return super().unregister_listener(event_type, callback) return super().unregister_listener(event_type, callback)

View File

@@ -1,19 +1,16 @@
from re import S
import traceback import traceback
import typing import typing
from libs.dingtalk_api.dingtalkevent import DingTalkEvent from libs.dingtalk_api.dingtalkevent import DingTalkEvent
from pkg.platform.types import message as platform_message import langbot_plugin.api.entities.builtin.platform.message as platform_message
from pkg.platform.adapter import MessagePlatformAdapter import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
from .. import adapter import langbot_plugin.api.entities.builtin.platform.events as platform_events
from ...core import app import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
from ..types import events as platform_events
from ..types import entities as platform_entities
from libs.dingtalk_api.api import DingTalkClient from libs.dingtalk_api.api import DingTalkClient
import datetime import datetime
from ..logger import EventLogger from ..logger import EventLogger
class DingTalkMessageConverter(adapter.MessageConverter): class DingTalkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
@staticmethod @staticmethod
async def yiri2target(message_chain: platform_message.MessageChain): async def yiri2target(message_chain: platform_message.MessageChain):
content = '' content = ''
@@ -52,7 +49,7 @@ class DingTalkMessageConverter(adapter.MessageConverter):
return chain return chain
class DingTalkEventConverter(adapter.EventConverter): class DingTalkEventConverter(abstract_platform_adapter.AbstractEventConverter):
@staticmethod @staticmethod
async def yiri2target(event: platform_events.MessageEvent): async def yiri2target(event: platform_events.MessageEvent):
return event.source_platform_object return event.source_platform_object
@@ -96,22 +93,18 @@ class DingTalkEventConverter(adapter.EventConverter):
) )
class DingTalkAdapter(adapter.MessagePlatformAdapter): class DingTalkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
bot: DingTalkClient bot: DingTalkClient
ap: app.Application
bot_account_id: str bot_account_id: str
message_converter: DingTalkMessageConverter = DingTalkMessageConverter() message_converter: DingTalkMessageConverter = DingTalkMessageConverter()
event_converter: DingTalkEventConverter = DingTalkEventConverter() event_converter: DingTalkEventConverter = DingTalkEventConverter()
config: dict config: dict
card_instance_id_dict: dict # 回复卡片消息字典key为消息idvalue为回复卡片实例id用于在流式消息时判断是否发送到指定卡片 card_instance_id_dict: (
seq: int # 消息顺序直接以seq作为标识 dict # 回复卡片消息字典key为消息idvalue为回复卡片实例id用于在流式消息时判断是否发送到指定卡片
)
def __init__(self, config: dict, logger: EventLogger):
def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
self.config = config
self.ap = ap
self.logger = logger
self.card_instance_id_dict = {}
# self.seq = 1
required_keys = [ required_keys = [
'client_id', 'client_id',
'client_secret', 'client_secret',
@@ -121,16 +114,23 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter):
missing_keys = [key for key in required_keys if key not in config] missing_keys = [key for key in required_keys if key not in config]
if missing_keys: if missing_keys:
raise Exception('钉钉缺少相关配置项,请查看文档或联系管理员') raise Exception('钉钉缺少相关配置项,请查看文档或联系管理员')
bot = DingTalkClient(
client_id=config['client_id'],
client_secret=config['client_secret'],
robot_name=config['robot_name'],
robot_code=config['robot_code'],
markdown_card=config['markdown_card'],
logger=logger,
)
bot_account_id = config['robot_name']
super().__init__(
config=config,
logger=logger,
card_instance_id_dict={},
bot_account_id=bot_account_id,
bot=bot,
listeners={},
self.bot_account_id = self.config['robot_name']
self.bot = DingTalkClient(
client_id=config['client_id'],
client_secret=config['client_secret'],
robot_name=config['robot_name'],
robot_code=config['robot_code'],
markdown_card=config['markdown_card'],
logger=self.logger,
) )
async def reply_message( async def reply_message(
@@ -165,12 +165,11 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter):
msg_seq = bot_message.msg_sequence msg_seq = bot_message.msg_sequence
if (msg_seq - 1) % 8 == 0 or is_final: if (msg_seq - 1) % 8 == 0 or is_final:
content, at = await DingTalkMessageConverter.yiri2target(message) content, at = await DingTalkMessageConverter.yiri2target(message)
card_instance, card_instance_id = self.card_instance_id_dict[message_id] card_instance, card_instance_id = self.card_instance_id_dict[message_id]
if not content and bot_message.content: if not content and bot_message.content:
content = bot_message.content # 兼容直接传入content的情况 content = bot_message.content # 兼容直接传入content的情况
# print(card_instance_id) # print(card_instance_id)
if content: if content:
await self.bot.send_card_message(card_instance, card_instance_id, content, is_final) await self.bot.send_card_message(card_instance, card_instance_id, content, is_final)
@@ -202,7 +201,9 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter):
def register_listener( def register_listener(
self, self,
event_type: typing.Type[platform_events.Event], event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
): ):
async def on_message(event: DingTalkEvent): async def on_message(event: DingTalkEvent):
try: try:
@@ -224,9 +225,14 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter):
async def kill(self) -> bool: async def kill(self) -> bool:
return False return False
async def is_muted(self) -> bool:
return False
async def unregister_listener( async def unregister_listener(
self, self,
event_type: type, event_type: type,
callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None], callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
): ):
return super().unregister_listener(event_type, callback) return super().unregister_listener(event_type, callback)

View File

@@ -12,13 +12,14 @@ import asyncio
from enum import Enum from enum import Enum
import aiohttp import aiohttp
import pydantic
from .. import adapter import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
from ...core import app import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.platform.events as platform_events
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
import langbot_plugin.api.definition.abstract.platform.event_logger as abstract_platform_logger
from ..logger import EventLogger from ..logger import EventLogger
from ..types import message as platform_message
from ..types import events as platform_events
from ..types import entities as platform_entities
# 语音功能相关异常定义 # 语音功能相关异常定义
@@ -582,7 +583,7 @@ class VoiceConnectionManager:
await self.stop_monitoring() await self.stop_monitoring()
class DiscordMessageConverter(adapter.MessageConverter): class DiscordMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
@staticmethod @staticmethod
async def yiri2target( async def yiri2target(
message_chain: platform_message.MessageChain, message_chain: platform_message.MessageChain,
@@ -736,7 +737,7 @@ class DiscordMessageConverter(adapter.MessageConverter):
return platform_message.MessageChain(element_list) return platform_message.MessageChain(element_list)
class DiscordEventConverter(adapter.EventConverter): class DiscordEventConverter(abstract_platform_adapter.AbstractEventConverter):
@staticmethod @staticmethod
async def yiri2target(event: platform_events.Event) -> discord.Message: async def yiri2target(event: platform_events.Event) -> discord.Message:
pass pass
@@ -778,32 +779,26 @@ class DiscordEventConverter(adapter.EventConverter):
) )
class DiscordAdapter(adapter.MessagePlatformAdapter): class DiscordAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
bot: discord.Client bot: discord.Client = pydantic.Field(exclude=True)
bot_account_id: str # 用于在流水线中识别at是否是本bot直接以bot_name作为标识
config: dict
ap: app.Application
message_converter: DiscordMessageConverter = DiscordMessageConverter() message_converter: DiscordMessageConverter = DiscordMessageConverter()
event_converter: DiscordEventConverter = DiscordEventConverter() event_converter: DiscordEventConverter = DiscordEventConverter()
listeners: typing.Dict[ listeners: typing.Dict[
typing.Type[platform_events.Event], typing.Type[platform_events.Event],
typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], typing.Callable[[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None],
] = {} ] = {}
def __init__(self, config: dict, ap: app.Application, logger: EventLogger): voice_manager: VoiceConnectionManager | None = pydantic.Field(exclude=True, default=None)
self.config = config
self.ap = ap
self.logger = logger
self.bot_account_id = self.config['client_id'] def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger, **kwargs):
bot_account_id = config['client_id']
listeners = {}
# 初始化语音连接管理器 # 初始化语音连接管理器
self.voice_manager: VoiceConnectionManager = None # self.voice_manager: VoiceConnectionManager = None
adapter_self = self adapter_self = self
@@ -823,7 +818,17 @@ class DiscordAdapter(adapter.MessagePlatformAdapter):
if os.getenv('http_proxy'): if os.getenv('http_proxy'):
args['proxy'] = os.getenv('http_proxy') args['proxy'] = os.getenv('http_proxy')
self.bot = MyClient(intents=intents, **args) bot = MyClient(intents=intents, **args)
super().__init__(
config=config,
logger=logger,
bot_account_id=bot_account_id,
listeners=listeners,
bot=bot,
voice_manager=None,
**kwargs,
)
# Voice functionality methods # Voice functionality methods
async def join_voice_channel(self, guild_id: int, channel_id: int, user_id: int = None) -> discord.VoiceClient: async def join_voice_channel(self, guild_id: int, channel_id: int, user_id: int = None) -> discord.VoiceClient:
@@ -1029,7 +1034,14 @@ class DiscordAdapter(adapter.MessagePlatformAdapter):
if quote_origin: if quote_origin:
args['reference'] = message_source.source_platform_object args['reference'] = message_source.source_platform_object
if message.has(platform_message.At): has_at = False
for component in message.root:
if isinstance(component, platform_message.At):
has_at = True
break
if has_at:
args['mention_author'] = True args['mention_author'] = True
await message_source.source_platform_object.channel.send(**args) await message_source.source_platform_object.channel.send(**args)
@@ -1040,14 +1052,18 @@ class DiscordAdapter(adapter.MessagePlatformAdapter):
def register_listener( def register_listener(
self, self,
event_type: typing.Type[platform_events.Event], event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
): ):
self.listeners[event_type] = callback self.listeners[event_type] = callback
def unregister_listener( def unregister_listener(
self, self,
event_type: typing.Type[platform_events.Event], event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
): ):
self.listeners.pop(event_type) self.listeners.pop(event_type)

View File

@@ -17,14 +17,14 @@ import aiohttp
import lark_oapi.ws.exception import lark_oapi.ws.exception
import quart import quart
from lark_oapi.api.im.v1 import * from lark_oapi.api.im.v1 import *
import pydantic
from lark_oapi.api.cardkit.v1 import * from lark_oapi.api.cardkit.v1 import *
from .. import adapter import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
from ...core import app import langbot_plugin.api.entities.builtin.platform.message as platform_message
from ..types import message as platform_message import langbot_plugin.api.entities.builtin.platform.events as platform_events
from ..types import events as platform_events import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
from ..types import entities as platform_entities import langbot_plugin.api.definition.abstract.platform.event_logger as abstract_platform_logger
from ..logger import EventLogger
class AESCipher(object): class AESCipher(object):
@@ -53,7 +53,7 @@ class AESCipher(object):
return self.decrypt(enc).decode('utf8') return self.decrypt(enc).decode('utf8')
class LarkMessageConverter(adapter.MessageConverter): class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
@staticmethod @staticmethod
async def yiri2target( async def yiri2target(
message_chain: platform_message.MessageChain, api_client: lark_oapi.Client message_chain: platform_message.MessageChain, api_client: lark_oapi.Client
@@ -277,7 +277,7 @@ class LarkMessageConverter(adapter.MessageConverter):
return platform_message.MessageChain(lb_msg_list) return platform_message.MessageChain(lb_msg_list)
class LarkEventConverter(adapter.EventConverter): class LarkEventConverter(abstract_platform_adapter.AbstractEventConverter):
@staticmethod @staticmethod
async def yiri2target( async def yiri2target(
event: platform_events.MessageEvent, event: platform_events.MessageEvent,
@@ -325,49 +325,37 @@ CARD_ID_CACHE_SIZE = 500
CARD_ID_CACHE_MAX_LIFETIME = 20 * 60 # 20分钟 CARD_ID_CACHE_MAX_LIFETIME = 20 * 60 # 20分钟
class LarkAdapter(adapter.MessagePlatformAdapter): class LarkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
bot: lark_oapi.ws.Client bot: lark_oapi.ws.Client = pydantic.Field(exclude=True)
api_client: lark_oapi.Client api_client: lark_oapi.Client = pydantic.Field(exclude=True)
bot_account_id: str # 用于在流水线中识别at是否是本bot直接以bot_name作为标识 bot_account_id: str # 用于在流水线中识别at是否是本bot直接以bot_name作为标识
lark_tenant_key: str # 飞书企业key lark_tenant_key: str = pydantic.Field(exclude=True, default='') # 飞书企业key
message_converter: LarkMessageConverter = LarkMessageConverter() message_converter: LarkMessageConverter = LarkMessageConverter()
event_converter: LarkEventConverter = LarkEventConverter() event_converter: LarkEventConverter = LarkEventConverter()
listeners: typing.Dict[ listeners: typing.Dict[
typing.Type[platform_events.Event], typing.Type[platform_events.Event],
typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], typing.Callable[[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None],
] ]
config: dict quart_app: quart.Quart = pydantic.Field(exclude=True)
quart_app: quart.Quart
ap: app.Application
card_id_dict: dict[str, str] # 消息id到卡片id的映射便于创建卡片后的发送消息到指定卡片 card_id_dict: dict[str, str] # 消息id到卡片id的映射便于创建卡片后的发送消息到指定卡片
seq: int # 用于在发送卡片消息中识别消息顺序直接以seq作为标识 seq: int # 用于在发送卡片消息中识别消息顺序直接以seq作为标识
def __init__(self, config: dict, ap: app.Application, logger: EventLogger): def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger, **kwargs):
self.config = config quart_app = quart.Quart(__name__)
self.ap = ap
self.logger = logger
self.quart_app = quart.Quart(__name__)
self.listeners = {}
self.card_id_dict = {}
self.seq = 1
@quart_app.route('/lark/callback', methods=['POST'])
@self.quart_app.route('/lark/callback', methods=['POST'])
async def lark_callback(): async def lark_callback():
try: try:
data = await quart.request.json data = await quart.request.json
self.ap.logger.debug(f'Lark callback event: {data}')
if 'encrypt' in data: if 'encrypt' in data:
cipher = AESCipher(self.config['encrypt-key']) cipher = AESCipher(config['encrypt-key'])
data = cipher.decrypt_string(data['encrypt']) data = cipher.decrypt_string(data['encrypt'])
data = json.loads(data) data = json.loads(data)
@@ -414,10 +402,24 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
lark_oapi.EventDispatcherHandler.builder('', '').register_p2_im_message_receive_v1(sync_on_message).build() lark_oapi.EventDispatcherHandler.builder('', '').register_p2_im_message_receive_v1(sync_on_message).build()
) )
self.bot_account_id = config['bot_name'] bot_account_id = config['bot_name']
self.bot = lark_oapi.ws.Client(config['app_id'], config['app_secret'], event_handler=event_handler) bot = lark_oapi.ws.Client(config['app_id'], config['app_secret'], event_handler=event_handler)
self.api_client = lark_oapi.Client.builder().app_id(config['app_id']).app_secret(config['app_secret']).build() api_client = lark_oapi.Client.builder().app_id(config['app_id']).app_secret(config['app_secret']).build()
super().__init__(
config=config,
logger=logger,
lark_tenant_key=config.get('lark_tenant_key', ''),
card_id_dict={},
seq=1,
listeners={},
quart_app=quart_app,
bot=bot,
api_client=api_client,
bot_account_id=bot_account_id,
**kwargs,
)
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain): async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
pass pass
@@ -430,151 +432,177 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
async def create_card_id(self, message_id): async def create_card_id(self, message_id):
try: try:
self.ap.logger.debug('飞书支持stream输出,创建卡片......') # self.logger.debug('飞书支持stream输出,创建卡片......')
card_data = {"schema": "2.0", "config": {"update_multi": True, "streaming_mode": True, card_data = {
"streaming_config": {"print_step": {"default": 1}, 'schema': '2.0',
"print_frequency_ms": {"default": 70}, 'config': {
"print_strategy": "fast"}}, 'update_multi': True,
"body": {"direction": "vertical", "padding": "12px 12px 12px 12px", "elements": [{"tag": "div", 'streaming_mode': True,
"text": { 'streaming_config': {
"tag": "plain_text", 'print_step': {'default': 1},
"content": "LangBot", 'print_frequency_ms': {'default': 70},
"text_size": "normal", 'print_strategy': 'fast',
"text_align": "left", },
"text_color": "default"}, },
"icon": { 'body': {
"tag": "custom_icon", 'direction': 'vertical',
"img_key": "img_v3_02p3_05c65d5d-9bad-440a-a2fb-c89571bfd5bg"}}, 'padding': '12px 12px 12px 12px',
{ 'elements': [
"tag": "markdown", {
"content": "", 'tag': 'div',
"text_align": "left", 'text': {
"text_size": "normal", 'tag': 'plain_text',
"margin": "0px 0px 0px 0px", 'content': 'LangBot',
"element_id": "streaming_txt"}, 'text_size': 'normal',
{ 'text_align': 'left',
"tag": "markdown", 'text_color': 'default',
"content": "", },
"text_align": "left", 'icon': {
"text_size": "normal", 'tag': 'custom_icon',
"margin": "0px 0px 0px 0px"}, 'img_key': 'img_v3_02p3_05c65d5d-9bad-440a-a2fb-c89571bfd5bg',
{ },
"tag": "column_set", },
"horizontal_spacing": "8px", {
"horizontal_align": "left", 'tag': 'markdown',
"columns": [ 'content': '',
{ 'text_align': 'left',
"tag": "column", 'text_size': 'normal',
"width": "weighted", 'margin': '0px 0px 0px 0px',
"elements": [ 'element_id': 'streaming_txt',
{ },
"tag": "markdown", {
"content": "", 'tag': 'markdown',
"text_align": "left", 'content': '',
"text_size": "normal", 'text_align': 'left',
"margin": "0px 0px 0px 0px"}, 'text_size': 'normal',
{ 'margin': '0px 0px 0px 0px',
"tag": "markdown", },
"content": "", {
"text_align": "left", 'tag': 'column_set',
"text_size": "normal", 'horizontal_spacing': '8px',
"margin": "0px 0px 0px 0px"}, 'horizontal_align': 'left',
{ 'columns': [
"tag": "markdown", {
"content": "", 'tag': 'column',
"text_align": "left", 'width': 'weighted',
"text_size": "normal", 'elements': [
"margin": "0px 0px 0px 0px"}], {
"padding": "0px 0px 0px 0px", 'tag': 'markdown',
"direction": "vertical", 'content': '',
"horizontal_spacing": "8px", 'text_align': 'left',
"vertical_spacing": "2px", 'text_size': 'normal',
"horizontal_align": "left", 'margin': '0px 0px 0px 0px',
"vertical_align": "top", },
"margin": "0px 0px 0px 0px", {
"weight": 1}], 'tag': 'markdown',
"margin": "0px 0px 0px 0px"}, 'content': '',
{"tag": "hr", 'text_align': 'left',
"margin": "0px 0px 0px 0px"}, 'text_size': 'normal',
{ 'margin': '0px 0px 0px 0px',
"tag": "column_set", },
"horizontal_spacing": "12px", {
"horizontal_align": "right", 'tag': 'markdown',
"columns": [ 'content': '',
{ 'text_align': 'left',
"tag": "column", 'text_size': 'normal',
"width": "weighted", 'margin': '0px 0px 0px 0px',
"elements": [ },
{ ],
"tag": "markdown", 'padding': '0px 0px 0px 0px',
"content": "<font color=\"grey-600\">以上内容由 AI 生成,仅供参考。更多详细、准确信息可点击引用链接查看</font>", 'direction': 'vertical',
"text_align": "left", 'horizontal_spacing': '8px',
"text_size": "notation", 'vertical_spacing': '2px',
"margin": "4px 0px 0px 0px", 'horizontal_align': 'left',
"icon": { 'vertical_align': 'top',
"tag": "standard_icon", 'margin': '0px 0px 0px 0px',
"token": "robot_outlined", 'weight': 1,
"color": "grey"}}], }
"padding": "0px 0px 0px 0px", ],
"direction": "vertical", 'margin': '0px 0px 0px 0px',
"horizontal_spacing": "8px", },
"vertical_spacing": "8px", {'tag': 'hr', 'margin': '0px 0px 0px 0px'},
"horizontal_align": "left", {
"vertical_align": "top", 'tag': 'column_set',
"margin": "0px 0px 0px 0px", 'horizontal_spacing': '12px',
"weight": 1}, 'horizontal_align': 'right',
{ 'columns': [
"tag": "column", {
"width": "20px", 'tag': 'column',
"elements": [ 'width': 'weighted',
{ 'elements': [
"tag": "button", {
"text": { 'tag': 'markdown',
"tag": "plain_text", 'content': '<font color="grey-600">以上内容由 AI 生成,仅供参考。更多详细、准确信息可点击引用链接查看</font>',
"content": ""}, 'text_align': 'left',
"type": "text", 'text_size': 'notation',
"width": "fill", 'margin': '4px 0px 0px 0px',
"size": "medium", 'icon': {
"icon": { 'tag': 'standard_icon',
"tag": "standard_icon", 'token': 'robot_outlined',
"token": "thumbsup_outlined"}, 'color': 'grey',
"hover_tips": { },
"tag": "plain_text", }
"content": "有帮助"}, ],
"margin": "0px 0px 0px 0px"}], 'padding': '0px 0px 0px 0px',
"padding": "0px 0px 0px 0px", 'direction': 'vertical',
"direction": "vertical", 'horizontal_spacing': '8px',
"horizontal_spacing": "8px", 'vertical_spacing': '8px',
"vertical_spacing": "8px", 'horizontal_align': 'left',
"horizontal_align": "left", 'vertical_align': 'top',
"vertical_align": "top", 'margin': '0px 0px 0px 0px',
"margin": "0px 0px 0px 0px"}, 'weight': 1,
{ },
"tag": "column", {
"width": "30px", 'tag': 'column',
"elements": [ 'width': '20px',
{ 'elements': [
"tag": "button", {
"text": { 'tag': 'button',
"tag": "plain_text", 'text': {'tag': 'plain_text', 'content': ''},
"content": ""}, 'type': 'text',
"type": "text", 'width': 'fill',
"width": "default", 'size': 'medium',
"size": "medium", 'icon': {'tag': 'standard_icon', 'token': 'thumbsup_outlined'},
"icon": { 'hover_tips': {'tag': 'plain_text', 'content': '有帮助'},
"tag": "standard_icon", 'margin': '0px 0px 0px 0px',
"token": "thumbdown_outlined"}, }
"hover_tips": { ],
"tag": "plain_text", 'padding': '0px 0px 0px 0px',
"content": "无帮助"}, 'direction': 'vertical',
"margin": "0px 0px 0px 0px"}], 'horizontal_spacing': '8px',
"padding": "0px 0px 0px 0px", 'vertical_spacing': '8px',
"vertical_spacing": "8px", 'horizontal_align': 'left',
"horizontal_align": "left", 'vertical_align': 'top',
"vertical_align": "top", 'margin': '0px 0px 0px 0px',
"margin": "0px 0px 0px 0px"}], },
"margin": "0px 0px 4px 0px"}]}} {
'tag': 'column',
'width': '30px',
'elements': [
{
'tag': 'button',
'text': {'tag': 'plain_text', 'content': ''},
'type': 'text',
'width': 'default',
'size': 'medium',
'icon': {'tag': 'standard_icon', 'token': 'thumbdown_outlined'},
'hover_tips': {'tag': 'plain_text', 'content': '无帮助'},
'margin': '0px 0px 0px 0px',
}
],
'padding': '0px 0px 0px 0px',
'vertical_spacing': '8px',
'horizontal_align': 'left',
'vertical_align': 'top',
'margin': '0px 0px 0px 0px',
},
],
'margin': '0px 0px 4px 0px',
},
],
},
}
# delay / fast 创建卡片模板delay 延迟打印fast 实时打印,可以自定义更好看的消息模板 # delay / fast 创建卡片模板delay 延迟打印fast 实时打印,可以自定义更好看的消息模板
request: CreateCardRequest = ( request: CreateCardRequest = (
@@ -592,15 +620,13 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
f'client.cardkit.v1.card.create failed, code: {response.code}, msg: {response.msg}, log_id: {response.get_log_id()}, resp: \n{json.dumps(json.loads(response.raw.content), indent=4, ensure_ascii=False)}' f'client.cardkit.v1.card.create failed, code: {response.code}, msg: {response.msg}, log_id: {response.get_log_id()}, resp: \n{json.dumps(json.loads(response.raw.content), indent=4, ensure_ascii=False)}'
) )
self.ap.logger.debug(f'飞书卡片创建成功,卡片ID: {response.data.card_id}')
self.card_id_dict[message_id] = response.data.card_id self.card_id_dict[message_id] = response.data.card_id
card_id = response.data.card_id card_id = response.data.card_id
return card_id return card_id
except Exception as e: except Exception as e:
self.ap.logger.error(f'飞书卡片创建失败,错误信息: {e}') raise e
async def create_message_card(self, message_id, event) -> str: async def create_message_card(self, message_id, event) -> str:
""" """
创建卡片消息。 创建卡片消息。
@@ -612,7 +638,7 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
content = { content = {
'type': 'card', 'type': 'card',
'data': {'card_id': card_id, 'template_variable': {'content': 'Thinking...'}}, 'data': {'card_id': card_id, 'template_variable': {'content': 'Thinking...'}},
} # 当收到消息时发送消息模板,可添加模板变量,详情查看飞书中接口文档 } # 当收到消息时发送消息模板,可添加模板变量,详情查看飞书中接口文档
request: ReplyMessageRequest = ( request: ReplyMessageRequest = (
ReplyMessageRequest.builder() ReplyMessageRequest.builder()
.message_id(event.message_chain.message_id) .message_id(event.message_chain.message_id)
@@ -685,10 +711,8 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
message_id = bot_message.resp_message_id message_id = bot_message.resp_message_id
msg_seq = bot_message.msg_sequence msg_seq = bot_message.msg_sequence
if msg_seq % 8 == 0 or is_final: if msg_seq % 8 == 0 or is_final:
lark_message = await self.message_converter.yiri2target(message, self.api_client) lark_message = await self.message_converter.yiri2target(message, self.api_client)
text_message = '' text_message = ''
for ele in lark_message[0]: for ele in lark_message[0]:
if ele['tag'] == 'text': if ele['tag'] == 'text':
@@ -734,14 +758,18 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
def register_listener( def register_listener(
self, self,
event_type: typing.Type[platform_events.Event], event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
): ):
self.listeners[event_type] = callback self.listeners[event_type] = callback
def unregister_listener( def unregister_listener(
self, self,
event_type: typing.Type[platform_events.Event], event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
): ):
self.listeners.pop(event_type) self.listeners.pop(event_type)
@@ -778,4 +806,4 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
# 所以要设置_auto_reconnect=False,让其不重连。 # 所以要设置_auto_reconnect=False,让其不重连。
self.bot._auto_reconnect = False self.bot._auto_reconnect = False
await self.bot._disconnect() await self.bot._disconnect()
return False return False

View File

Before

Width:  |  Height:  |  Size: 25 KiB

After

Width:  |  Height:  |  Size: 25 KiB

View File

@@ -11,19 +11,19 @@ import threading
import quart import quart
import aiohttp import aiohttp
from .. import adapter import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
from ...core import app from ....core import app
from ..types import message as platform_message import langbot_plugin.api.entities.builtin.platform.message as platform_message
from ..types import events as platform_events import langbot_plugin.api.entities.builtin.platform.events as platform_events
from ..types import entities as platform_entities import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
from ...utils import image from ....utils import image
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from typing import Optional, Tuple from typing import Optional, Tuple
from functools import partial from functools import partial
from ..logger import EventLogger from ...logger import EventLogger
class GewechatMessageConverter(adapter.MessageConverter): class GewechatMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
def __init__(self, config: dict): def __init__(self, config: dict):
self.config = config self.config = config
@@ -398,7 +398,7 @@ class GewechatMessageConverter(adapter.MessageConverter):
return from_user_name.endswith('@chatroom') return from_user_name.endswith('@chatroom')
class GewechatEventConverter(adapter.EventConverter): class GewechatEventConverter(abstract_platform_adapter.AbstractEventConverter):
def __init__(self, config: dict): def __init__(self, config: dict):
self.config = config self.config = config
self.message_converter = GewechatMessageConverter(config) self.message_converter = GewechatMessageConverter(config)
@@ -458,7 +458,7 @@ class GewechatEventConverter(adapter.EventConverter):
) )
class GeWeChatAdapter(adapter.MessagePlatformAdapter): class GeWeChatAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
name: str = 'gewechat' # 定义适配器名称 name: str = 'gewechat' # 定义适配器名称
bot: gewechat_client.GewechatClient bot: gewechat_client.GewechatClient
@@ -475,7 +475,7 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter):
listeners: typing.Dict[ listeners: typing.Dict[
typing.Type[platform_events.Event], typing.Type[platform_events.Event],
typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], typing.Callable[[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None],
] = {} ] = {}
def __init__(self, config: dict, ap: app.Application, logger: EventLogger): def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
@@ -491,7 +491,7 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter):
async def gewechat_callback(): async def gewechat_callback():
data = await quart.request.json data = await quart.request.json
# print(json.dumps(data, indent=4, ensure_ascii=False)) # print(json.dumps(data, indent=4, ensure_ascii=False))
self.ap.logger.debug(f'Gewechat callback event: {data}') await self.logger.debug(f'Gewechat callback event: {data}')
if 'data' in data: if 'data' in data:
data['Data'] = data['data'] data['Data'] = data['data']
@@ -601,7 +601,7 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter):
if handler := handler_map.get(msg['type']): if handler := handler_map.get(msg['type']):
handler(msg) handler(msg)
else: else:
self.ap.logger.warning(f'未处理的消息类型: {msg["type"]}') await self.logger.warning(f'未处理的消息类型: {msg["type"]}')
continue continue
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain): async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
@@ -625,14 +625,18 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter):
def register_listener( def register_listener(
self, self,
event_type: typing.Type[platform_events.Event], event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
): ):
self.listeners[event_type] = callback self.listeners[event_type] = callback
def unregister_listener( def unregister_listener(
self, self,
event_type: typing.Type[platform_events.Event], event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
): ):
pass pass
@@ -656,9 +660,7 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter):
self.config['app_id'] = app_id self.config['app_id'] = app_id
self.ap.logger.info(f'Gewechat 登录成功app_id: {app_id}') print(f'Gewechat 登录成功app_id: {app_id}')
self.ap.platform_mgr.write_back_config('gewechat', self, self.config)
# 获取 nickname # 获取 nickname
profile = self.bot.get_profile(self.config['app_id']) profile = self.bot.get_profile(self.config['app_id'])

View File

Before

Width:  |  Height:  |  Size: 274 KiB

After

Width:  |  Height:  |  Size: 274 KiB

View File

@@ -9,15 +9,15 @@ import traceback
import nakuru import nakuru
import nakuru.entities.components as nkc import nakuru.entities.components as nkc
from .. import adapter as adapter_model from ....pipeline.longtext.strategies import forward
from ...pipeline.longtext.strategies import forward import langbot_plugin.api.entities.builtin.platform.message as platform_message
from ...platform.types import message as platform_message import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
from ...platform.types import entities as platform_entities import langbot_plugin.api.entities.builtin.platform.events as platform_events
from ...platform.types import events as platform_events import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
from ..logger import EventLogger from ...logger import EventLogger
class NakuruProjectMessageConverter(adapter_model.MessageConverter): class NakuruProjectMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
"""消息转换器""" """消息转换器"""
@staticmethod @staticmethod
@@ -109,7 +109,7 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter):
return chain return chain
class NakuruProjectEventConverter(adapter_model.EventConverter): class NakuruProjectEventConverter(abstract_platform_adapter.AbstractEventConverter):
"""事件转换器""" """事件转换器"""
@staticmethod @staticmethod
@@ -164,7 +164,7 @@ class NakuruProjectEventConverter(adapter_model.EventConverter):
raise Exception('未支持转换的事件类型: ' + str(event)) raise Exception('未支持转换的事件类型: ' + str(event))
class NakuruAdapter(adapter_model.MessagePlatformAdapter): class NakuruAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
"""nakuru-project适配器""" """nakuru-project适配器"""
bot: nakuru.CQHTTP bot: nakuru.CQHTTP
@@ -256,13 +256,15 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter):
def register_listener( def register_listener(
self, self,
event_type: typing.Type[platform_events.Event], event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[platform_events.Event, adapter_model.MessagePlatformAdapter], None], callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
): ):
try: try:
source_cls = NakuruProjectEventConverter.yiri2target(event_type) source_cls = NakuruProjectEventConverter.yiri2target(event_type)
# 包装函数 # 包装函数
async def listener_wrapper(app: nakuru.CQHTTP, source: source_cls): async def listener_wrapper(app: nakuru.CQHTTP, source: source_cls): # type: ignore
await callback(self.event_converter.target2yiri(source), self) await callback(self.event_converter.target2yiri(source), self)
# 将包装函数和原函数的对应关系存入列表 # 将包装函数和原函数的对应关系存入列表
@@ -283,7 +285,9 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter):
def unregister_listener( def unregister_listener(
self, self,
event_type: typing.Type[platform_events.Event], event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[platform_events.Event, adapter_model.MessagePlatformAdapter], None], callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
): ):
nakuru_event_name = self.event_converter.yiri2target(event_type).__name__ nakuru_event_name = self.event_converter.yiri2target(event_type).__name__
@@ -322,7 +326,6 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter):
except Exception: except Exception:
raise Exception('获取go-cqhttp账号信息失败, 请检查是否已启动go-cqhttp并配置正确') raise Exception('获取go-cqhttp账号信息失败, 请检查是否已启动go-cqhttp并配置正确')
await self.bot._run() await self.bot._run()
self.ap.logger.info('运行 Nakuru 适配器')
while True: while True:
await asyncio.sleep(1) await asyncio.sleep(1)

View File

@@ -10,14 +10,14 @@ import botpy
import botpy.message as botpy_message import botpy.message as botpy_message
import botpy.types.message as botpy_message_type import botpy.types.message as botpy_message_type
from .. import adapter as adapter_model import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
from ...pipeline.longtext.strategies import forward from ....pipeline.longtext.strategies import forward
from ...core import app from ....core import app
from ...config import manager as cfg_mgr from ....config import manager as cfg_mgr
from ...platform.types import entities as platform_entities import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
from ...platform.types import events as platform_events import langbot_plugin.api.entities.builtin.platform.events as platform_events
from ...platform.types import message as platform_message import langbot_plugin.api.entities.builtin.platform.message as platform_message
from ..logger import EventLogger from ...logger import EventLogger
class OfficialGroupMessage(platform_events.GroupMessage): class OfficialGroupMessage(platform_events.GroupMessage):
@@ -133,7 +133,7 @@ class OpenIDMapping(typing.Generic[K, V]):
return value return value
class OfficialMessageConverter(adapter_model.MessageConverter): class OfficialMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
"""QQ 官方消息转换器""" """QQ 官方消息转换器"""
@staticmethod @staticmethod
@@ -237,7 +237,7 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
return chain return chain
class OfficialEventConverter(adapter_model.EventConverter): class OfficialEventConverter(abstract_platform_adapter.AbstractEventConverter):
"""事件转换器""" """事件转换器"""
def __init__(self): def __init__(self):
@@ -333,7 +333,7 @@ class OfficialEventConverter(adapter_model.EventConverter):
) )
class OfficialAdapter(adapter_model.MessagePlatformAdapter): class OfficialAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
"""QQ 官方消息适配器""" """QQ 官方消息适配器"""
bot: botpy.Client = None bot: botpy.Client = None
@@ -484,7 +484,9 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter):
def register_listener( def register_listener(
self, self,
event_type: typing.Type[platform_events.Event], event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[platform_events.Event, adapter_model.MessagePlatformAdapter], None], callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
): ):
try: try:
@@ -507,7 +509,9 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter):
def unregister_listener( def unregister_listener(
self, self,
event_type: typing.Type[platform_events.Event], event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[platform_events.Event, adapter_model.MessagePlatformAdapter], None], callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
): ):
delattr(self.bot, event_handler_mapping[event_type]) delattr(self.bot, event_handler_mapping[event_type])
@@ -519,7 +523,7 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter):
self.cfg['ret_coro'] = True self.cfg['ret_coro'] = True
self.ap.logger.info('运行 QQ 官方适配器') await self.logger.info('运行 QQ 官方适配器')
await (await self.bot.start(**self.cfg)) await (await self.bot.start(**self.cfg))
async def kill(self) -> bool: async def kill(self) -> bool:

View File

Before

Width:  |  Height:  |  Size: 1.7 KiB

After

Width:  |  Height:  |  Size: 1.7 KiB

View File

@@ -4,19 +4,18 @@ import asyncio
import traceback import traceback
import datetime import datetime
from pkg.platform.adapter import MessagePlatformAdapter import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
from pkg.platform.types import events as platform_events, message as platform_message
from libs.official_account_api.oaevent import OAEvent from libs.official_account_api.oaevent import OAEvent
from libs.official_account_api.api import OAClient from libs.official_account_api.api import OAClient
from libs.official_account_api.api import OAClientForLongerResponse from libs.official_account_api.api import OAClientForLongerResponse
from .. import adapter import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
from ...core import app import langbot_plugin.api.entities.builtin.platform.message as platform_message
from ..types import entities as platform_entities import langbot_plugin.api.entities.builtin.platform.events as platform_events
from ...command.errors import ParamNotEnoughError from langbot_plugin.api.entities.builtin.command import errors as command_errors
from ..logger import EventLogger from ..logger import EventLogger
class OAMessageConverter(adapter.MessageConverter): class OAMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
@staticmethod @staticmethod
async def yiri2target(message_chain: platform_message.MessageChain): async def yiri2target(message_chain: platform_message.MessageChain):
for msg in message_chain: for msg in message_chain:
@@ -34,7 +33,7 @@ class OAMessageConverter(adapter.MessageConverter):
return chain return chain
class OAEventConverter(adapter.EventConverter): class OAEventConverter(abstract_platform_adapter.AbstractEventConverter):
@staticmethod @staticmethod
async def target2yiri(event: OAEvent): async def target2yiri(event: OAEvent):
if event.type == 'text': if event.type == 'text':
@@ -56,17 +55,15 @@ class OAEventConverter(adapter.EventConverter):
return None return None
class OfficialAccountAdapter(adapter.MessagePlatformAdapter): class OfficialAccountAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
bot: OAClient | OAClientForLongerResponse bot: OAClient | OAClientForLongerResponse
ap: app.Application
bot_account_id: str bot_account_id: str
message_converter: OAMessageConverter = OAMessageConverter() message_converter: OAMessageConverter = OAMessageConverter()
event_converter: OAEventConverter = OAEventConverter() event_converter: OAEventConverter = OAEventConverter()
config: dict config: dict
def __init__(self, config: dict, ap: app.Application, logger: EventLogger): def __init__(self, config: dict, logger: EventLogger):
self.config = config self.config = config
self.ap = ap
self.logger = logger self.logger = logger
required_keys = [ required_keys = [
@@ -78,7 +75,7 @@ class OfficialAccountAdapter(adapter.MessagePlatformAdapter):
] ]
missing_keys = [key for key in required_keys if key not in config] missing_keys = [key for key in required_keys if key not in config]
if missing_keys: if missing_keys:
raise ParamNotEnoughError('微信公众号缺少相关配置项,请查看文档或联系管理员') raise command_errors.ParamNotEnoughError('微信公众号缺少相关配置项,请查看文档或联系管理员')
if self.config['Mode'] == 'drop': if self.config['Mode'] == 'drop':
self.bot = OAClient( self.bot = OAClient(
@@ -119,7 +116,9 @@ class OfficialAccountAdapter(adapter.MessagePlatformAdapter):
def register_listener( def register_listener(
self, self,
event_type: type, event_type: type,
callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None], callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
): ):
async def on_message(event: OAEvent): async def on_message(event: OAEvent):
self.bot_account_id = event.receiver_id self.bot_account_id = event.receiver_id
@@ -150,6 +149,8 @@ class OfficialAccountAdapter(adapter.MessagePlatformAdapter):
async def unregister_listener( async def unregister_listener(
self, self,
event_type: type, event_type: type,
callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None], callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
): ):
return super().unregister_listener(event_type, callback) return super().unregister_listener(event_type, callback)

View File

@@ -5,19 +5,18 @@ import traceback
import datetime import datetime
from pkg.platform.adapter import MessagePlatformAdapter import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
from pkg.platform.types import events as platform_events, message as platform_message import langbot_plugin.api.entities.builtin.platform.message as platform_message
from .. import adapter import langbot_plugin.api.entities.builtin.platform.events as platform_events
from ...core import app import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
from ..types import entities as platform_entities from langbot_plugin.api.entities.builtin.command import errors as command_errors
from ...command.errors import ParamNotEnoughError
from libs.qq_official_api.api import QQOfficialClient from libs.qq_official_api.api import QQOfficialClient
from libs.qq_official_api.qqofficialevent import QQOfficialEvent from libs.qq_official_api.qqofficialevent import QQOfficialEvent
from ...utils import image from ...utils import image
from ..logger import EventLogger from ..logger import EventLogger
class QQOfficialMessageConverter(adapter.MessageConverter): class QQOfficialMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
@staticmethod @staticmethod
async def yiri2target(message_chain: platform_message.MessageChain): async def yiri2target(message_chain: platform_message.MessageChain):
content_list = [] content_list = []
@@ -46,7 +45,7 @@ class QQOfficialMessageConverter(adapter.MessageConverter):
return chain return chain
class QQOfficialEventConverter(adapter.EventConverter): class QQOfficialEventConverter(abstract_platform_adapter.AbstractEventConverter):
@staticmethod @staticmethod
async def yiri2target(event: platform_events.MessageEvent) -> QQOfficialEvent: async def yiri2target(event: platform_events.MessageEvent) -> QQOfficialEvent:
return event.source_platform_object return event.source_platform_object
@@ -132,17 +131,15 @@ class QQOfficialEventConverter(adapter.EventConverter):
) )
class QQOfficialAdapter(adapter.MessagePlatformAdapter): class QQOfficialAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
bot: QQOfficialClient bot: QQOfficialClient
ap: app.Application
config: dict config: dict
bot_account_id: str bot_account_id: str
message_converter: QQOfficialMessageConverter = QQOfficialMessageConverter() message_converter: QQOfficialMessageConverter = QQOfficialMessageConverter()
event_converter: QQOfficialEventConverter = QQOfficialEventConverter() event_converter: QQOfficialEventConverter = QQOfficialEventConverter()
def __init__(self, config: dict, ap: app.Application, logger: EventLogger): def __init__(self, config: dict, logger: EventLogger):
self.config = config self.config = config
self.ap = ap
self.logger = logger self.logger = logger
required_keys = [ required_keys = [
@@ -151,7 +148,7 @@ class QQOfficialAdapter(adapter.MessagePlatformAdapter):
] ]
missing_keys = [key for key in required_keys if key not in config] missing_keys = [key for key in required_keys if key not in config]
if missing_keys: if missing_keys:
raise ParamNotEnoughError('QQ官方机器人缺少相关配置项请查看文档或联系管理员') raise command_errors.ParamNotEnoughError('QQ官方机器人缺少相关配置项请查看文档或联系管理员')
self.bot = QQOfficialClient( self.bot = QQOfficialClient(
app_id=config['appid'], secret=config['secret'], token=config['token'], logger=self.logger app_id=config['appid'], secret=config['secret'], token=config['token'], logger=self.logger
@@ -215,7 +212,9 @@ class QQOfficialAdapter(adapter.MessagePlatformAdapter):
def register_listener( def register_listener(
self, self,
event_type: typing.Type[platform_events.Event], event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
): ):
async def on_message(event: QQOfficialEvent): async def on_message(event: QQOfficialEvent):
self.bot_account_id = 'justbot' self.bot_account_id = 'justbot'
@@ -248,6 +247,8 @@ class QQOfficialAdapter(adapter.MessagePlatformAdapter):
def unregister_listener( def unregister_listener(
self, self,
event_type: type, event_type: type,
callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None], callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
): ):
return super().unregister_listener(event_type, callback) return super().unregister_listener(event_type, callback)

View File

@@ -6,18 +6,17 @@ import traceback
import datetime import datetime
from libs.slack_api.api import SlackClient from libs.slack_api.api import SlackClient
from pkg.platform.adapter import MessagePlatformAdapter import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
from pkg.platform.types import events as platform_events, message as platform_message
from libs.slack_api.slackevent import SlackEvent from libs.slack_api.slackevent import SlackEvent
from pkg.core import app import langbot_plugin.api.entities.builtin.platform.events as platform_events
from .. import adapter import langbot_plugin.api.entities.builtin.platform.message as platform_message
from ..types import entities as platform_entities import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
from ...command.errors import ParamNotEnoughError from langbot_plugin.api.entities.builtin.command import errors as command_errors
from ...utils import image from ...utils import image
from ..logger import EventLogger from ..logger import EventLogger
class SlackMessageConverter(adapter.MessageConverter): class SlackMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
@staticmethod @staticmethod
async def yiri2target(message_chain: platform_message.MessageChain): async def yiri2target(message_chain: platform_message.MessageChain):
content_list = [] content_list = []
@@ -44,7 +43,7 @@ class SlackMessageConverter(adapter.MessageConverter):
return chain return chain
class SlackEventConverter(adapter.EventConverter): class SlackEventConverter(abstract_platform_adapter.AbstractEventConverter):
@staticmethod @staticmethod
async def yiri2target(event: platform_events.MessageEvent) -> SlackEvent: async def yiri2target(event: platform_events.MessageEvent) -> SlackEvent:
return event.source_platform_object return event.source_platform_object
@@ -84,17 +83,15 @@ class SlackEventConverter(adapter.EventConverter):
) )
class SlackAdapter(adapter.MessagePlatformAdapter): class SlackAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
bot: SlackClient bot: SlackClient
ap: app.Application
bot_account_id: str bot_account_id: str
message_converter: SlackMessageConverter = SlackMessageConverter() message_converter: SlackMessageConverter = SlackMessageConverter()
event_converter: SlackEventConverter = SlackEventConverter() event_converter: SlackEventConverter = SlackEventConverter()
config: dict config: dict
def __init__(self, config: dict, ap: app.Application, logger: EventLogger): def __init__(self, config: dict, logger: EventLogger):
self.config = config self.config = config
self.ap = ap
self.logger = logger self.logger = logger
required_keys = [ required_keys = [
'bot_token', 'bot_token',
@@ -102,7 +99,7 @@ class SlackAdapter(adapter.MessagePlatformAdapter):
] ]
missing_keys = [key for key in required_keys if key not in config] missing_keys = [key for key in required_keys if key not in config]
if missing_keys: if missing_keys:
raise ParamNotEnoughError('Slack机器人缺少相关配置项请查看文档或联系管理员') raise command_errors.ParamNotEnoughError('Slack机器人缺少相关配置项请查看文档或联系管理员')
self.bot = SlackClient( self.bot = SlackClient(
bot_token=self.config['bot_token'], signing_secret=self.config['signing_secret'], logger=self.logger bot_token=self.config['bot_token'], signing_secret=self.config['signing_secret'], logger=self.logger
@@ -135,7 +132,9 @@ class SlackAdapter(adapter.MessagePlatformAdapter):
def register_listener( def register_listener(
self, self,
event_type: typing.Type[platform_events.Event], event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
): ):
async def on_message(event: SlackEvent): async def on_message(event: SlackEvent):
self.bot_account_id = 'SlackBot' self.bot_account_id = 'SlackBot'
@@ -166,6 +165,8 @@ class SlackAdapter(adapter.MessagePlatformAdapter):
async def unregister_listener( async def unregister_listener(
self, self,
event_type: type, event_type: type,
callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None], callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
): ):
return super().unregister_listener(event_type, callback) return super().unregister_listener(event_type, callback)

View File

@@ -10,18 +10,16 @@ import typing
import traceback import traceback
import base64 import base64
import aiohttp import aiohttp
import pydantic
from lark_oapi.api.im.v1 import * import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
import langbot_plugin.api.entities.builtin.platform.message as platform_message
from .. import adapter import langbot_plugin.api.entities.builtin.platform.events as platform_events
from ...core import app import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
from ..types import message as platform_message import langbot_plugin.api.definition.abstract.platform.event_logger as abstract_platform_logger
from ..types import events as platform_events
from ..types import entities as platform_entities
from ..logger import EventLogger
class TelegramMessageConverter(adapter.MessageConverter): class TelegramMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
@staticmethod @staticmethod
async def yiri2target(message_chain: platform_message.MessageChain, bot: telegram.Bot) -> list[dict]: async def yiri2target(message_chain: platform_message.MessageChain, bot: telegram.Bot) -> list[dict]:
components = [] components = []
@@ -90,7 +88,7 @@ class TelegramMessageConverter(adapter.MessageConverter):
return platform_message.MessageChain(message_components) return platform_message.MessageChain(message_components)
class TelegramEventConverter(adapter.EventConverter): class TelegramEventConverter(abstract_platform_adapter.AbstractEventConverter):
@staticmethod @staticmethod
async def yiri2target(event: platform_events.MessageEvent, bot: telegram.Bot): async def yiri2target(event: platform_events.MessageEvent, bot: telegram.Bot):
return event.source_platform_object return event.source_platform_object
@@ -132,17 +130,14 @@ class TelegramEventConverter(adapter.EventConverter):
) )
class TelegramAdapter(adapter.MessagePlatformAdapter): class TelegramAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
bot: telegram.Bot bot: telegram.Bot = pydantic.Field(exclude=True)
application: telegram.ext.Application application: telegram.ext.Application = pydantic.Field(exclude=True)
bot_account_id: str
message_converter: TelegramMessageConverter = TelegramMessageConverter() message_converter: TelegramMessageConverter = TelegramMessageConverter()
event_converter: TelegramEventConverter = TelegramEventConverter() event_converter: TelegramEventConverter = TelegramEventConverter()
config: dict config: dict
ap: app.Application
msg_stream_id: dict # 流式消息id字典key为流式消息idvalue为首次消息源id用于在流式消息时判断编辑那条消息 msg_stream_id: dict # 流式消息id字典key为流式消息idvalue为首次消息源id用于在流式消息时判断编辑那条消息
@@ -150,16 +145,10 @@ class TelegramAdapter(adapter.MessagePlatformAdapter):
listeners: typing.Dict[ listeners: typing.Dict[
typing.Type[platform_events.Event], typing.Type[platform_events.Event],
typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], typing.Callable[[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None],
] = {} ] = {}
def __init__(self, config: dict, ap: app.Application, logger: EventLogger): def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger):
self.config = config
self.ap = ap
self.logger = logger
self.msg_stream_id = {}
# self.seq = 1
async def telegram_callback(update: Update, context: ContextTypes.DEFAULT_TYPE): async def telegram_callback(update: Update, context: ContextTypes.DEFAULT_TYPE):
if update.message.from_user.is_bot: if update.message.from_user.is_bot:
return return
@@ -171,10 +160,18 @@ class TelegramAdapter(adapter.MessagePlatformAdapter):
except Exception: except Exception:
await self.logger.error(f'Error in telegram callback: {traceback.format_exc()}') await self.logger.error(f'Error in telegram callback: {traceback.format_exc()}')
self.application = ApplicationBuilder().token(self.config['token']).build() application = ApplicationBuilder().token(config['token']).build()
self.bot = self.application.bot bot = application.bot
self.application.add_handler( application.add_handler(MessageHandler(filters.TEXT | (filters.COMMAND) | filters.PHOTO, telegram_callback))
MessageHandler(filters.TEXT | (filters.COMMAND) | filters.PHOTO, telegram_callback) super().__init__(
config=config,
logger=logger,
msg_stream_id={},
seq=1,
bot=bot,
application=application,
bot_account_id='',
listeners={},
) )
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain): async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
@@ -278,14 +275,18 @@ class TelegramAdapter(adapter.MessagePlatformAdapter):
def register_listener( def register_listener(
self, self,
event_type: typing.Type[platform_events.Event], event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
): ):
self.listeners[event_type] = callback self.listeners[event_type] = callback
def unregister_listener( def unregister_listener(
self, self,
event_type: typing.Type[platform_events.Event], event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
): ):
self.listeners.pop(event_type) self.listeners.pop(event_type)

View File

@@ -3,17 +3,19 @@ import logging
import typing import typing
from datetime import datetime from datetime import datetime
from pydantic import BaseModel import pydantic
from .. import adapter as msadapter import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
from ..types import events as platform_events, message as platform_message, entities as platform_entities import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.platform.events as platform_events
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
import langbot_plugin.api.definition.abstract.platform.event_logger as abstract_platform_logger
from ...core import app from ...core import app
from ..logger import EventLogger
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class WebChatMessage(BaseModel): class WebChatMessage(pydantic.BaseModel):
id: int id: int
role: str role: str
content: str content: str
@@ -41,30 +43,35 @@ class WebChatSession:
return self.message_lists[pipeline_uuid] return self.message_lists[pipeline_uuid]
class WebChatAdapter(msadapter.MessagePlatformAdapter): class WebChatAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
"""WebChat调试适配器用于流水线调试""" """WebChat调试适配器用于流水线调试"""
webchat_person_session: WebChatSession webchat_person_session: WebChatSession = pydantic.Field(exclude=True, default_factory=WebChatSession)
webchat_group_session: WebChatSession webchat_group_session: WebChatSession = pydantic.Field(exclude=True, default_factory=WebChatSession)
listeners: typing.Dict[ listeners: dict[
typing.Type[platform_events.Event], typing.Type[platform_events.Event],
typing.Callable[[platform_events.Event, msadapter.MessagePlatformAdapter], None], typing.Callable[[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None],
] = {} ] = pydantic.Field(default_factory=dict, exclude=True)
is_stream: bool is_stream: bool = pydantic.Field(exclude=True)
debug_messages: dict[str, list[dict]] = pydantic.Field(default_factory=dict, exclude=True)
def __init__(self, config: dict, ap: app.Application, logger: EventLogger): ap: app.Application = pydantic.Field(exclude=True)
self.ap = ap
self.logger = logger def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger, **kwargs):
self.config = config super().__init__(
config=config,
logger=logger,
**kwargs,
)
self.webchat_person_session = WebChatSession(id='webchatperson') self.webchat_person_session = WebChatSession(id='webchatperson')
self.webchat_group_session = WebChatSession(id='webchatgroup') self.webchat_group_session = WebChatSession(id='webchatgroup')
self.bot_account_id = 'webchatbot' self.bot_account_id = 'webchatbot'
self.is_stream = False self.debug_messages = {}
async def send_message( async def send_message(
self, self,
@@ -159,7 +166,9 @@ class WebChatAdapter(msadapter.MessagePlatformAdapter):
def register_listener( def register_listener(
self, self,
event_type: typing.Type[platform_events.Event], event_type: typing.Type[platform_events.Event],
func: typing.Callable[[platform_events.Event, msadapter.MessagePlatformAdapter], typing.Awaitable[None]], func: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], typing.Awaitable[None]
],
): ):
"""注册事件监听器""" """注册事件监听器"""
self.listeners[event_type] = func self.listeners[event_type] = func
@@ -167,11 +176,16 @@ class WebChatAdapter(msadapter.MessagePlatformAdapter):
def unregister_listener( def unregister_listener(
self, self,
event_type: typing.Type[platform_events.Event], event_type: typing.Type[platform_events.Event],
func: typing.Callable[[platform_events.Event, msadapter.MessagePlatformAdapter], typing.Awaitable[None]], func: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], typing.Awaitable[None]
],
): ):
"""取消注册事件监听器""" """取消注册事件监听器"""
del self.listeners[event_type] del self.listeners[event_type]
async def is_muted(self, group_id: int) -> bool:
return False
async def run_async(self): async def run_async(self):
"""运行适配器""" """运行适配器"""
await self.logger.info('WebChat调试适配器已启动') await self.logger.info('WebChat调试适配器已启动')
@@ -221,7 +235,7 @@ class WebChatAdapter(msadapter.MessagePlatformAdapter):
message_chain.insert(0, platform_message.Source(id=message_id, time=datetime.now().timestamp())) message_chain.insert(0, platform_message.Source(id=message_id, time=datetime.now().timestamp()))
if session_type == 'person': if session_type == 'person':
sender = platform_entities.Friend(id='webchatperson', nickname='User') sender = platform_entities.Friend(id='webchatperson', nickname='User', remark='User')
event = platform_events.FriendMessage( event = platform_events.FriendMessage(
sender=sender, message_chain=message_chain, time=datetime.now().timestamp() sender=sender, message_chain=message_chain, time=datetime.now().timestamp()
) )

View File

@@ -16,24 +16,30 @@ import threading
import quart import quart
from .. import adapter
from ...core import app
from ..types import message as platform_message
from ..types import events as platform_events
from ..types import entities as platform_entities
from ..logger import EventLogger from ..logger import EventLogger
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from typing import Optional, Tuple from typing import Optional, Tuple
from functools import partial from functools import partial
import logging import logging
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.platform.events as platform_events
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
import langbot_plugin.api.definition.abstract.platform.event_logger as abstract_platform_logger
class WeChatPadMessageConverter(adapter.MessageConverter): class WeChatPadMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
def __init__(self, config: dict, logger: logging.Logger): def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger):
self.bot = WeChatPadClient(config['wechatpad_url'], config['token'])
self.config = config self.config = config
self.bot = WeChatPadClient(self.config['wechatpad_url'], self.config['token'])
self.logger = logger self.logger = logger
# super().__init__(
# config = config,
# bot = bot,
# logger = logger,
# )
@staticmethod @staticmethod
async def yiri2target(message_chain: platform_message.MessageChain) -> list[dict]: async def yiri2target(message_chain: platform_message.MessageChain) -> list[dict]:
content_list = [] content_list = []
@@ -447,11 +453,16 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
return from_user_name.endswith('@chatroom') return from_user_name.endswith('@chatroom')
class WeChatPadEventConverter(adapter.EventConverter): class WeChatPadEventConverter(abstract_platform_adapter.AbstractEventConverter):
def __init__(self, config: dict, logger: logging.Logger): def __init__(self, config: dict, logger: logging.Logger):
self.config = config self.config = config
self.message_converter = WeChatPadMessageConverter(config, logger)
self.logger = logger self.logger = logger
self.message_converter = WeChatPadMessageConverter(self.config, self.logger)
# super().__init__(
# config=config,
# message_converter=message_converter,
# logger = logger,
# )
@staticmethod @staticmethod
async def yiri2target(event: platform_events.MessageEvent) -> dict: async def yiri2target(event: platform_events.MessageEvent) -> dict:
@@ -511,7 +522,7 @@ class WeChatPadEventConverter(adapter.EventConverter):
) )
class WeChatPadAdapter(adapter.MessagePlatformAdapter): class WeChatPadAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
name: str = 'WeChatPad' # 定义适配器名称 name: str = 'WeChatPad' # 定义适配器名称
bot: WeChatPadClient bot: WeChatPadClient
@@ -521,29 +532,38 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
config: dict config: dict
ap: app.Application logger: EventLogger
message_converter: WeChatPadMessageConverter message_converter: WeChatPadMessageConverter
event_converter: WeChatPadEventConverter event_converter: WeChatPadEventConverter
listeners: typing.Dict[ listeners: typing.Dict[
typing.Type[platform_events.Event], typing.Type[platform_events.Event],
typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], typing.Callable[[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None],
] = {} ] = {}
def __init__(self, config: dict, ap: app.Application, logger: EventLogger): def __init__(self, config: dict, logger: EventLogger):
self.config = config
self.ap = ap
self.logger = logger
self.quart_app = quart.Quart(__name__)
self.message_converter = WeChatPadMessageConverter(config, ap.logger) quart_app = quart.Quart(__name__)
self.event_converter = WeChatPadEventConverter(config, ap.logger)
message_converter = WeChatPadMessageConverter(config, logger)
event_converter = WeChatPadEventConverter(config, logger)
bot = WeChatPadClient(config['wechatpad_url'], config['token'])
super().__init__(
config=config,
logger = logger,
quart_app = quart_app,
message_converter =message_converter,
event_converter = event_converter,
listeners={},
bot_account_id ='',
name="WeChatPad",
bot=bot,
)
async def ws_message(self, data): async def ws_message(self, data):
"""处理接收到的消息""" """处理接收到的消息"""
# self.ap.logger.debug(f"Gewechat callback event: {data}")
# print(data)
try: try:
event = await self.event_converter.target2yiri(data.copy(), self.bot_account_id) event = await self.event_converter.target2yiri(data.copy(), self.bot_account_id)
@@ -609,9 +629,8 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
if handler := handler_map.get(msg['type']): if handler := handler_map.get(msg['type']):
handler(msg) handler(msg)
# self.ap.logger.warning(f"未处理的消息类型: {ret}")
else: else:
self.ap.logger.warning(f'未处理的消息类型: {msg["type"]}') self.logger.warning(f'未处理的消息类型: {msg["type"]}')
continue continue
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain): async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
@@ -635,14 +654,18 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
def register_listener( def register_listener(
self, self,
event_type: typing.Type[platform_events.Event], event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
): ):
self.listeners[event_type] = callback self.listeners[event_type] = callback
def unregister_listener( def unregister_listener(
self, self,
event_type: typing.Type[platform_events.Event], event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
): ):
pass pass
@@ -653,7 +676,6 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
if self.config['token']: if self.config['token']:
self.bot = WeChatPadClient(self.config['wechatpad_url'], self.config['token']) self.bot = WeChatPadClient(self.config['wechatpad_url'], self.config['token'])
data = self.bot.get_login_status() data = self.bot.get_login_status()
self.ap.logger.info(data)
if data['Code'] == 300 and data['Text'] == '你已退出微信': if data['Code'] == 300 and data['Text'] == '你已退出微信':
response = requests.post( response = requests.post(
f'{self.config["wechatpad_url"]}/admin/GenAuthKey1?key={self.config["admin_key"]}', f'{self.config["wechatpad_url"]}/admin/GenAuthKey1?key={self.config["admin_key"]}',
@@ -673,7 +695,7 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
self.config['token'] = response.json()['Data'][0] self.config['token'] = response.json()['Data'][0]
self.bot = WeChatPadClient(self.config['wechatpad_url'], self.config['token'], logger=self.logger) self.bot = WeChatPadClient(self.config['wechatpad_url'], self.config['token'], logger=self.logger)
self.ap.logger.info(self.config['token']) await self.logger.info(self.config['token'])
thread_1 = threading.Event() thread_1 = threading.Event()
def wechat_login_process(): def wechat_login_process():
@@ -681,10 +703,9 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
# login_data =self.bot.get_login_qr() # login_data =self.bot.get_login_qr()
# url = login_data['Data']["QrCodeUrl"] # url = login_data['Data']["QrCodeUrl"]
# self.ap.logger.info(login_data)
profile = self.bot.get_profile() profile = self.bot.get_profile()
self.ap.logger.info(profile) # self.logger.info(profile)
self.bot_account_id = profile['Data']['userInfo']['nickName']['str'] self.bot_account_id = profile['Data']['userInfo']['nickName']['str']
self.config['wxid'] = profile['Data']['userInfo']['userName']['str'] self.config['wxid'] = profile['Data']['userInfo']['userName']['str']
@@ -696,27 +717,26 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
def connect_websocket_sync() -> None: def connect_websocket_sync() -> None:
thread_1.wait() thread_1.wait()
uri = f'{self.config["wechatpad_ws"]}/GetSyncMsg?key={self.config["token"]}' uri = f'{self.config["wechatpad_ws"]}/GetSyncMsg?key={self.config["token"]}'
self.ap.logger.info(f'Connecting to WebSocket: {uri}') print(f'Connecting to WebSocket: {uri}')
def on_message(ws, message): def on_message(ws, message):
try: try:
data = json.loads(message) data = json.loads(message)
self.ap.logger.debug(f'Received message: {data}')
# 这里需要确保ws_message是同步的或者使用asyncio.run调用异步方法 # 这里需要确保ws_message是同步的或者使用asyncio.run调用异步方法
asyncio.run(self.ws_message(data)) asyncio.run(self.ws_message(data))
except json.JSONDecodeError: except json.JSONDecodeError:
self.ap.logger.error(f'Non-JSON message: {message[:100]}...') self.logger.error(f'Non-JSON message: {message[:100]}...')
def on_error(ws, error): def on_error(ws, error):
self.ap.logger.error(f'WebSocket error: {str(error)[:200]}') self.logger.error(f'WebSocket error: {str(error)[:200]}')
def on_close(ws, close_status_code, close_msg): def on_close(ws, close_status_code, close_msg):
self.ap.logger.info('WebSocket closed, reconnecting...') self.logger.info('WebSocket closed, reconnecting...')
time.sleep(5) time.sleep(5)
connect_websocket_sync() # 自动重连 connect_websocket_sync() # 自动重连
def on_open(ws): def on_open(ws):
self.ap.logger.info('WebSocket connected successfully!') self.logger.info('WebSocket connected successfully!')
ws = websocket.WebSocketApp( ws = websocket.WebSocketApp(
uri, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open uri, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open
@@ -727,10 +747,9 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
# connect_websocket_sync() # connect_websocket_sync()
# 这行代码会在WebSocket连接断开后才会执行 # 这行代码会在WebSocket连接断开后才会执行
# self.ap.logger.info("WebSocket client thread started")
thread = threading.Thread(target=connect_websocket_sync, name='WebSocketClientThread', daemon=True) thread = threading.Thread(target=connect_websocket_sync, name='WebSocketClientThread', daemon=True)
thread.start() thread.start()
self.ap.logger.info('WebSocket client thread started') self.logger.info('WebSocket client thread started')
async def kill(self) -> bool: async def kill(self) -> bool:
pass pass

Some files were not shown because too many files have changed in this diff Show More