mirror of
https://github.com/langbot-app/LangBot.git
synced 2025-11-25 03:15:06 +08:00
feat: new plugin system (#1495)
* deps: add `langbot-plugin` * feat: connector for plugin runtime * feat(plugin): basic communication * feat: listing plugins * feat: switch tool entities and format * feat: switch Query to langbot-plugin definition * chore: delete Query class * feat: switch message platform adapters to sdk * chore: remove adapter meta manifest from components.yaml * feat: preliminary migration of events entities * fix: serialization bug in events emitting * feat: minor changes adapt to event emitting * feat: adapt more events * feat: switch all event emitting logic to new method * refactor: use `emit_event` from connector * feat: runtime reconnecting * feat: add Tool component * feat: switch command entities to sdk * feat: command execution via plugin * feat: `reply_message` api * feat: get bot uuid api * feat: query-based apis * refactor: switch llm_entities to plugin sdk * feat: backward call apis * perf: longer timeout for emit_event * feat: binary storage api * feat(ui): list plugins * feat: get plugin info * feat: kill runtime process when exit in stdio mode * perf: dispose process * chore: bump langbot-plugin version to 0.1.1a1 * fix: message chain init * feat: `get_bot_info` api * feat: set cloud_service_url * feat: refactor webui httpclient * fix: bot switching * feat: tag debugging plugins in webui * feat: plugin installation * feat: plugin installation webui * feat: trace plugin installation * feat: marketplace page * perf: frontend * fix: i18n fallback * feat: plugin operations * feat: plugin deletion and upgrade * feat: setting plugin config * feat: bump version of langbot-plugin * chore: remove plugin reorder functionality * chore: bump version 4.3.0b1 * chore: bump langbot_plugin version * fix: conflict in table `plugin_settings` * chore: bump version to '4.3.0b2' * chore: bump version 4.3.0b3 * Update package.json (#1627) * feat: change standalone runtime tag env * fix: use --standalone-runtime * feat: update docker launch method * fix: change tag of image to `latest` * perf: inline code display style in markdown * fix: syntax errors * fix: wrong migration target version * fix: set plugin enabled=true as default * fix: replace message_chain.has usage * fix: dark mode for plugins management page * fix: minor bugs * fix: tool call params in localagent * chore: bump version 4.3.0b4 * feat: available for disabling marketplace(offline env) * perf: display installed plugin icon * refactor: market plugin detail dialog * perf: dark theme * fix: cloudServiceClient api * feat: supports for command return image base64 * chore: bump langbot_plugin to 0.1.1b6 * del self.ap error * fix: dingtalk pydantic.BaseModel norm * fix: wechatpad pydantic.BaseModel norm * chore: move docker-compose.yaml for plugin edition --------- Co-authored-by: How-Sean Xin <mcjiekejiemi@163.com> Co-authored-by: fdc <2213070223@qq.com>
This commit is contained in:
@@ -9,7 +9,6 @@ spec:
|
||||
components:
|
||||
ComponentTemplate:
|
||||
fromFiles:
|
||||
- pkg/platform/adapter.yaml
|
||||
- pkg/provider/modelmgr/requester.yaml
|
||||
MessagePlatformAdapter:
|
||||
fromDirs:
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# This file is deprecated, and will be replaced by docker/docker-compose.yaml in next version.
|
||||
version: "3"
|
||||
|
||||
services:
|
||||
@@ -13,4 +14,4 @@ services:
|
||||
ports:
|
||||
- 5300:5300 # 供 WebUI 使用
|
||||
- 2280-2290:2280-2290 # 供消息平台适配器方向连接
|
||||
# 根据具体环境配置网络
|
||||
# 根据具体环境配置网络
|
||||
36
docker/docker-compose.yaml
Normal file
36
docker/docker-compose.yaml
Normal file
@@ -0,0 +1,36 @@
|
||||
version: "3"
|
||||
|
||||
services:
|
||||
|
||||
langbot_plugin_runtime:
|
||||
image: rockchin/langbot:latest
|
||||
container_name: langbot_plugin_runtime
|
||||
volumes:
|
||||
- ./data/plugins:/app/data/plugins
|
||||
ports:
|
||||
- 5401:5401
|
||||
restart: on-failure
|
||||
environment:
|
||||
- TZ=Asia/Shanghai
|
||||
command: ["uv", "run", "-m", "langbot_plugin.cli.__init__", "rt"]
|
||||
networks:
|
||||
- langbot_network
|
||||
|
||||
langbot:
|
||||
image: rockchin/langbot:latest
|
||||
container_name: langbot
|
||||
volumes:
|
||||
- ./data:/app/data
|
||||
- ./plugins:/app/plugins
|
||||
restart: on-failure
|
||||
environment:
|
||||
- TZ=Asia/Shanghai
|
||||
ports:
|
||||
- 5300:5300 # For web ui
|
||||
- 2280-2290:2280-2290 # For platform webhook
|
||||
networks:
|
||||
- langbot_network
|
||||
|
||||
networks:
|
||||
langbot_network:
|
||||
driver: bridge
|
||||
@@ -3,7 +3,7 @@ from quart import request
|
||||
import httpx
|
||||
from quart import Quart
|
||||
from typing import Callable, Dict, Any
|
||||
from pkg.platform.types import events as platform_events
|
||||
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||
from .qqofficialevent import QQOfficialEvent
|
||||
import json
|
||||
import traceback
|
||||
|
||||
@@ -4,7 +4,7 @@ from quart import Quart, jsonify, request
|
||||
from slack_sdk.web.async_client import AsyncWebClient
|
||||
from .slackevent import SlackEvent
|
||||
from typing import Callable
|
||||
from pkg.platform.types import events as platform_events
|
||||
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||
|
||||
|
||||
class SlackClient:
|
||||
|
||||
@@ -8,7 +8,7 @@ from quart import Quart
|
||||
import xml.etree.ElementTree as ET
|
||||
from typing import Callable, Dict, Any
|
||||
from .wecomevent import WecomEvent
|
||||
from pkg.platform.types import message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import aiofiles
|
||||
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from quart import Quart
|
||||
import xml.etree.ElementTree as ET
|
||||
from typing import Callable
|
||||
from .wecomcsevent import WecomCSEvent
|
||||
from pkg.platform.types import message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import aiofiles
|
||||
|
||||
|
||||
|
||||
16
main.py
16
main.py
@@ -19,8 +19,14 @@ asciiart = r"""
|
||||
async def main_entry(loop: asyncio.AbstractEventLoop):
|
||||
parser = argparse.ArgumentParser(description='LangBot')
|
||||
parser.add_argument('--skip-plugin-deps-check', action='store_true', help='跳过插件依赖项检查', default=False)
|
||||
parser.add_argument('--standalone-runtime', action='store_true', help='使用独立插件运行时', default=False)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.standalone_runtime:
|
||||
from pkg.utils import platform
|
||||
|
||||
platform.standalone_runtime = True
|
||||
|
||||
print(asciiart)
|
||||
|
||||
import sys
|
||||
@@ -47,13 +53,13 @@ async def main_entry(loop: asyncio.AbstractEventLoop):
|
||||
if not args.skip_plugin_deps_check:
|
||||
await deps.precheck_plugin_deps()
|
||||
|
||||
# 检查pydantic版本,如果没有 pydantic.v1,则把 pydantic 映射为 v1
|
||||
import pydantic.version
|
||||
# # 检查pydantic版本,如果没有 pydantic.v1,则把 pydantic 映射为 v1
|
||||
# import pydantic.version
|
||||
|
||||
if pydantic.version.VERSION < '2.0':
|
||||
import pydantic
|
||||
# if pydantic.version.VERSION < '2.0':
|
||||
# import pydantic
|
||||
|
||||
sys.modules['pydantic.v1'] = pydantic
|
||||
# sys.modules['pydantic.v1'] = pydantic
|
||||
|
||||
# 检查配置文件
|
||||
|
||||
|
||||
@@ -44,9 +44,9 @@ class WebChatDebugRouterGroup(group.RouterGroup):
|
||||
'Content-Type': 'text/event-stream',
|
||||
'Transfer-Encoding': 'chunked',
|
||||
'Cache-Control': 'no-cache',
|
||||
'Connection': 'keep-alive'
|
||||
'Connection': 'keep-alive',
|
||||
}
|
||||
return quart.Response(stream_generator(generator), mimetype='text/event-stream',headers=headers)
|
||||
return quart.Response(stream_generator(generator), mimetype='text/event-stream', headers=headers)
|
||||
|
||||
else: # non-stream
|
||||
result = None
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
import base64
|
||||
import quart
|
||||
|
||||
from .....core import taskmgr
|
||||
from .. import group
|
||||
from langbot_plugin.runtime.plugin.mgr import PluginInstallSource
|
||||
|
||||
|
||||
@group.group_class('plugins', '/api/v1/plugins')
|
||||
@@ -12,35 +13,22 @@ class PluginsRouterGroup(group.RouterGroup):
|
||||
async def initialize(self) -> None:
|
||||
@self.route('', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _() -> str:
|
||||
plugins = self.ap.plugin_mgr.plugins()
|
||||
plugins = await self.ap.plugin_connector.list_plugins()
|
||||
|
||||
plugins_data = [plugin.model_dump() for plugin in plugins]
|
||||
|
||||
return self.success(data={'plugins': plugins_data})
|
||||
return self.success(data={'plugins': plugins})
|
||||
|
||||
@self.route(
|
||||
'/<author>/<plugin_name>/toggle',
|
||||
methods=['PUT'],
|
||||
auth_type=group.AuthType.USER_TOKEN,
|
||||
)
|
||||
async def _(author: str, plugin_name: str) -> str:
|
||||
data = await quart.request.json
|
||||
target_enabled = data.get('target_enabled')
|
||||
await self.ap.plugin_mgr.update_plugin_switch(plugin_name, target_enabled)
|
||||
return self.success()
|
||||
|
||||
@self.route(
|
||||
'/<author>/<plugin_name>/update',
|
||||
'/<author>/<plugin_name>/upgrade',
|
||||
methods=['POST'],
|
||||
auth_type=group.AuthType.USER_TOKEN,
|
||||
)
|
||||
async def _(author: str, plugin_name: str) -> str:
|
||||
ctx = taskmgr.TaskContext.new()
|
||||
wrapper = self.ap.task_mgr.create_user_task(
|
||||
self.ap.plugin_mgr.update_plugin(plugin_name, task_context=ctx),
|
||||
self.ap.plugin_connector.upgrade_plugin(author, plugin_name, task_context=ctx),
|
||||
kind='plugin-operation',
|
||||
name=f'plugin-update-{plugin_name}',
|
||||
label=f'Updating plugin {plugin_name}',
|
||||
name=f'plugin-upgrade-{plugin_name}',
|
||||
label=f'Upgrading plugin {plugin_name}',
|
||||
context=ctx,
|
||||
)
|
||||
return self.success(data={'task_id': wrapper.id})
|
||||
@@ -52,14 +40,14 @@ class PluginsRouterGroup(group.RouterGroup):
|
||||
)
|
||||
async def _(author: str, plugin_name: str) -> str:
|
||||
if quart.request.method == 'GET':
|
||||
plugin = self.ap.plugin_mgr.get_plugin(author, plugin_name)
|
||||
plugin = await self.ap.plugin_connector.get_plugin_info(author, plugin_name)
|
||||
if plugin is None:
|
||||
return self.http_status(404, -1, 'plugin not found')
|
||||
return self.success(data={'plugin': plugin.model_dump()})
|
||||
return self.success(data={'plugin': plugin})
|
||||
elif quart.request.method == 'DELETE':
|
||||
ctx = taskmgr.TaskContext.new()
|
||||
wrapper = self.ap.task_mgr.create_user_task(
|
||||
self.ap.plugin_mgr.uninstall_plugin(plugin_name, task_context=ctx),
|
||||
self.ap.plugin_connector.delete_plugin(author, plugin_name, task_context=ctx),
|
||||
kind='plugin-operation',
|
||||
name=f'plugin-remove-{plugin_name}',
|
||||
label=f'Removing plugin {plugin_name}',
|
||||
@@ -74,23 +62,32 @@ class PluginsRouterGroup(group.RouterGroup):
|
||||
auth_type=group.AuthType.USER_TOKEN,
|
||||
)
|
||||
async def _(author: str, plugin_name: str) -> quart.Response:
|
||||
plugin = self.ap.plugin_mgr.get_plugin(author, plugin_name)
|
||||
plugin = await self.ap.plugin_connector.get_plugin_info(author, plugin_name)
|
||||
if plugin is None:
|
||||
return self.http_status(404, -1, 'plugin not found')
|
||||
|
||||
if quart.request.method == 'GET':
|
||||
return self.success(data={'config': plugin.plugin_config})
|
||||
return self.success(data={'config': plugin['plugin_config']})
|
||||
elif quart.request.method == 'PUT':
|
||||
data = await quart.request.json
|
||||
|
||||
await self.ap.plugin_mgr.set_plugin_config(plugin, data)
|
||||
await self.ap.plugin_connector.set_plugin_config(author, plugin_name, data)
|
||||
|
||||
return self.success(data={})
|
||||
|
||||
@self.route('/reorder', methods=['PUT'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _() -> str:
|
||||
data = await quart.request.json
|
||||
await self.ap.plugin_mgr.reorder_plugins(data.get('plugins'))
|
||||
return self.success()
|
||||
@self.route(
|
||||
'/<author>/<plugin_name>/icon',
|
||||
methods=['GET'],
|
||||
auth_type=group.AuthType.NONE,
|
||||
)
|
||||
async def _(author: str, plugin_name: str) -> quart.Response:
|
||||
icon_data = await self.ap.plugin_connector.get_plugin_icon(author, plugin_name)
|
||||
icon_base64 = icon_data['plugin_icon_base64']
|
||||
mime_type = icon_data['mime_type']
|
||||
|
||||
icon_data = base64.b64decode(icon_base64)
|
||||
|
||||
return quart.Response(icon_data, mimetype=mime_type)
|
||||
|
||||
@self.route('/install/github', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _() -> str:
|
||||
@@ -102,7 +99,47 @@ class PluginsRouterGroup(group.RouterGroup):
|
||||
self.ap.plugin_mgr.install_plugin(data['source'], task_context=ctx),
|
||||
kind='plugin-operation',
|
||||
name='plugin-install-github',
|
||||
label=f'Installing plugin ...{short_source_str}',
|
||||
label=f'Installing plugin from github ...{short_source_str}',
|
||||
context=ctx,
|
||||
)
|
||||
|
||||
return self.success(data={'task_id': wrapper.id})
|
||||
|
||||
@self.route('/install/marketplace', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _() -> str:
|
||||
data = await quart.request.json
|
||||
|
||||
ctx = taskmgr.TaskContext.new()
|
||||
wrapper = self.ap.task_mgr.create_user_task(
|
||||
self.ap.plugin_connector.install_plugin(PluginInstallSource.MARKETPLACE, data, task_context=ctx),
|
||||
kind='plugin-operation',
|
||||
name='plugin-install-marketplace',
|
||||
label=f'Installing plugin from marketplace ...{data}',
|
||||
context=ctx,
|
||||
)
|
||||
|
||||
return self.success(data={'task_id': wrapper.id})
|
||||
|
||||
@self.route('/install/local', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _() -> str:
|
||||
file = (await quart.request.files).get('file')
|
||||
if file is None:
|
||||
return self.http_status(400, -1, 'file is required')
|
||||
|
||||
file_bytes = file.read()
|
||||
|
||||
file_base64 = base64.b64encode(file_bytes).decode('utf-8')
|
||||
|
||||
data = {
|
||||
'plugin_file': file_base64,
|
||||
}
|
||||
|
||||
ctx = taskmgr.TaskContext.new()
|
||||
wrapper = self.ap.task_mgr.create_user_task(
|
||||
self.ap.plugin_connector.install_plugin(PluginInstallSource.LOCAL, data, task_context=ctx),
|
||||
kind='plugin-operation',
|
||||
name='plugin-install-local',
|
||||
label=f'Installing plugin from local ...{file.filename}',
|
||||
context=ctx,
|
||||
)
|
||||
|
||||
|
||||
@@ -14,6 +14,12 @@ class SystemRouterGroup(group.RouterGroup):
|
||||
'version': constants.semantic_version,
|
||||
'debug': constants.debug_mode,
|
||||
'enabled_platform_count': len(self.ap.platform_mgr.get_running_adapters()),
|
||||
'enable_marketplace': self.ap.instance_config.data['plugin'].get('enable_marketplace', True),
|
||||
'cloud_service_url': (
|
||||
self.ap.instance_config.data['plugin']['cloud_service_url']
|
||||
if 'cloud_service_url' in self.ap.instance_config.data['plugin']
|
||||
else 'https://space.langbot.app'
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -35,16 +41,7 @@ class SystemRouterGroup(group.RouterGroup):
|
||||
|
||||
return self.success(data=task.to_dict())
|
||||
|
||||
@self.route('/reload', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _() -> str:
|
||||
json_data = await quart.request.json
|
||||
|
||||
scope = json_data.get('scope')
|
||||
|
||||
await self.ap.reload(scope=scope)
|
||||
return self.success()
|
||||
|
||||
@self.route('/_debug/exec', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
||||
@self.route('/debug/exec', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _() -> str:
|
||||
if not constants.debug_mode:
|
||||
return self.http_status(403, 403, 'Forbidden')
|
||||
@@ -54,3 +51,39 @@ class SystemRouterGroup(group.RouterGroup):
|
||||
ap = self.ap
|
||||
|
||||
return self.success(data=exec(py_code, {'ap': ap}))
|
||||
|
||||
@self.route('/debug/tools/call', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _() -> str:
|
||||
if not constants.debug_mode:
|
||||
return self.http_status(403, 403, 'Forbidden')
|
||||
|
||||
data = await quart.request.json
|
||||
|
||||
return self.success(
|
||||
data=await self.ap.tool_mgr.execute_func_call(data['tool_name'], data['tool_parameters'])
|
||||
)
|
||||
|
||||
@self.route(
|
||||
'/debug/plugin/action',
|
||||
methods=['POST'],
|
||||
auth_type=group.AuthType.USER_TOKEN,
|
||||
)
|
||||
async def _() -> str:
|
||||
if not constants.debug_mode:
|
||||
return self.http_status(403, 403, 'Forbidden')
|
||||
|
||||
data = await quart.request.json
|
||||
|
||||
class AnoymousAction:
|
||||
value = 'anonymous_action'
|
||||
|
||||
def __init__(self, value: str):
|
||||
self.value = value
|
||||
|
||||
resp = await self.ap.plugin_connector.handler.call_action(
|
||||
AnoymousAction(data['action']),
|
||||
data['data'],
|
||||
timeout=data.get('timeout', 10),
|
||||
)
|
||||
|
||||
return self.success(data=resp)
|
||||
|
||||
@@ -71,15 +71,15 @@ class UserRouterGroup(group.RouterGroup):
|
||||
@self.route('/change-password', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _(user_email: str) -> str:
|
||||
json_data = await quart.request.json
|
||||
|
||||
|
||||
current_password = json_data['current_password']
|
||||
new_password = json_data['new_password']
|
||||
|
||||
|
||||
try:
|
||||
await self.ap.user_service.change_password(user_email, current_password, new_password)
|
||||
except argon2.exceptions.VerifyMismatchError:
|
||||
return self.http_status(400, -1, 'Current password is incorrect')
|
||||
except ValueError as e:
|
||||
return self.http_status(400, -1, str(e))
|
||||
|
||||
|
||||
return self.success(data={'user': user_email})
|
||||
|
||||
@@ -17,16 +17,20 @@ class BotService:
|
||||
def __init__(self, ap: app.Application) -> None:
|
||||
self.ap = ap
|
||||
|
||||
async def get_bots(self) -> list[dict]:
|
||||
"""Get all bots"""
|
||||
async def get_bots(self, include_secret: bool = True) -> list[dict]:
|
||||
"""获取所有机器人"""
|
||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_bot.Bot))
|
||||
|
||||
bots = result.all()
|
||||
|
||||
return [self.ap.persistence_mgr.serialize_model(persistence_bot.Bot, bot) for bot in bots]
|
||||
masked_columns = []
|
||||
if not include_secret:
|
||||
masked_columns = ['adapter_config']
|
||||
|
||||
async def get_bot(self, bot_uuid: str) -> dict | None:
|
||||
"""Get bot"""
|
||||
return [self.ap.persistence_mgr.serialize_model(persistence_bot.Bot, bot, masked_columns) for bot in bots]
|
||||
|
||||
async def get_bot(self, bot_uuid: str, include_secret: bool = True) -> dict | None:
|
||||
"""获取机器人"""
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_bot.Bot).where(persistence_bot.Bot.uuid == bot_uuid)
|
||||
)
|
||||
@@ -36,7 +40,27 @@ class BotService:
|
||||
if bot is None:
|
||||
return None
|
||||
|
||||
return self.ap.persistence_mgr.serialize_model(persistence_bot.Bot, bot)
|
||||
masked_columns = []
|
||||
if not include_secret:
|
||||
masked_columns = ['adapter_config']
|
||||
|
||||
return self.ap.persistence_mgr.serialize_model(persistence_bot.Bot, bot, masked_columns)
|
||||
|
||||
async def get_runtime_bot_info(self, bot_uuid: str, include_secret: bool = True) -> dict:
|
||||
"""获取机器人运行时信息"""
|
||||
persistence_bot = await self.get_bot(bot_uuid, include_secret)
|
||||
if persistence_bot is None:
|
||||
raise Exception('Bot not found')
|
||||
|
||||
adapter_runtime_values = {}
|
||||
|
||||
runtime_bot = await self.ap.platform_mgr.get_bot_by_uuid(bot_uuid)
|
||||
if runtime_bot is not None:
|
||||
adapter_runtime_values['bot_account_id'] = runtime_bot.adapter.bot_account_id
|
||||
|
||||
persistence_bot['adapter_runtime_values'] = adapter_runtime_values
|
||||
|
||||
return persistence_bot
|
||||
|
||||
async def create_bot(self, bot_data: dict) -> str:
|
||||
"""Create bot"""
|
||||
|
||||
@@ -7,7 +7,7 @@ from ....core import app
|
||||
from ....entity.persistence import model as persistence_model
|
||||
from ....entity.persistence import pipeline as persistence_pipeline
|
||||
from ....provider.modelmgr import requester as model_requester
|
||||
from ....provider import entities as llm_entities
|
||||
from langbot_plugin.api.entities.builtin.provider import message as provider_message
|
||||
|
||||
|
||||
class LLMModelsService:
|
||||
@@ -16,11 +16,19 @@ class LLMModelsService:
|
||||
def __init__(self, ap: app.Application) -> None:
|
||||
self.ap = ap
|
||||
|
||||
async def get_llm_models(self) -> list[dict]:
|
||||
async def get_llm_models(self, include_secret: bool = True) -> list[dict]:
|
||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.LLMModel))
|
||||
|
||||
models = result.all()
|
||||
return [self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model) for model in models]
|
||||
|
||||
masked_columns = []
|
||||
if not include_secret:
|
||||
masked_columns = ['api_keys']
|
||||
|
||||
return [
|
||||
self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model, masked_columns)
|
||||
for model in models
|
||||
]
|
||||
|
||||
async def create_llm_model(self, model_data: dict) -> str:
|
||||
model_data['uuid'] = str(uuid.uuid4())
|
||||
@@ -99,7 +107,7 @@ class LLMModelsService:
|
||||
await runtime_llm_model.requester.invoke_llm(
|
||||
query=None,
|
||||
model=runtime_llm_model,
|
||||
messages=[llm_entities.Message(role='user', content='Hello, world!')],
|
||||
messages=[provider_message.Message(role='user', content='Hello, world!')],
|
||||
funcs=[],
|
||||
extra_args=model_data.get('extra_args', {}),
|
||||
)
|
||||
|
||||
@@ -85,15 +85,15 @@ class UserService:
|
||||
|
||||
async def change_password(self, user_email: str, current_password: str, new_password: str) -> None:
|
||||
ph = argon2.PasswordHasher()
|
||||
|
||||
|
||||
user_obj = await self.get_user_by_email(user_email)
|
||||
if user_obj is None:
|
||||
raise ValueError('User not found')
|
||||
|
||||
|
||||
ph.verify(user_obj.password, current_password)
|
||||
|
||||
|
||||
hashed_password = ph.hash(new_password)
|
||||
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(user.User).where(user.User.user == user_email).values(password=hashed_password)
|
||||
)
|
||||
|
||||
@@ -2,9 +2,12 @@ from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from ..core import app, entities as core_entities
|
||||
from . import entities, operator, errors
|
||||
from ..core import app
|
||||
from . import operator
|
||||
from ..utils import importutil
|
||||
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
|
||||
|
||||
# 引入所有算子以便注册
|
||||
from . import operators
|
||||
@@ -13,13 +16,11 @@ importutil.import_modules_in_pkg(operators)
|
||||
|
||||
|
||||
class CommandManager:
|
||||
"""命令管理器"""
|
||||
|
||||
ap: app.Application
|
||||
|
||||
cmd_list: list[operator.CommandOperator]
|
||||
"""
|
||||
运行时命令列表,扁平存储,各个对象包含对应的子节点引用
|
||||
Runtime command list, flat storage, each object contains a reference to the corresponding child node
|
||||
"""
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
@@ -55,43 +56,28 @@ class CommandManager:
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
context: entities.ExecuteContext,
|
||||
context: command_context.ExecuteContext,
|
||||
operator_list: list[operator.CommandOperator],
|
||||
operator: operator.CommandOperator = None,
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||
"""执行命令"""
|
||||
|
||||
found = False
|
||||
if len(context.crt_params) > 0: # 查找下一个参数是否对应此节点的某个子节点名
|
||||
for oper in operator_list:
|
||||
if (context.crt_params[0] == oper.name or context.crt_params[0] in oper.alias) and (
|
||||
oper.parent_class is None or oper.parent_class == operator.__class__
|
||||
):
|
||||
found = True
|
||||
command_list = await self.ap.plugin_connector.list_commands()
|
||||
|
||||
context.crt_command = context.crt_params[0]
|
||||
context.crt_params = context.crt_params[1:]
|
||||
|
||||
async for ret in self._execute(context, oper.children, oper):
|
||||
yield ret
|
||||
break
|
||||
|
||||
if not found: # 如果下一个参数未在此节点的子节点中找到,则执行此节点或者报错
|
||||
if operator is None:
|
||||
yield entities.CommandReturn(error=errors.CommandNotFoundError(context.crt_params[0]))
|
||||
else:
|
||||
if operator.lowest_privilege > context.privilege:
|
||||
yield entities.CommandReturn(error=errors.CommandPrivilegeError(operator.name))
|
||||
else:
|
||||
async for ret in operator.execute(context):
|
||||
yield ret
|
||||
for command in command_list:
|
||||
if command.metadata.name == context.command:
|
||||
async for ret in self.ap.plugin_connector.execute_command(context):
|
||||
yield ret
|
||||
break
|
||||
else:
|
||||
yield command_context.CommandReturn(error=command_errors.CommandNotFoundError(context.command))
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
command_text: str,
|
||||
query: core_entities.Query,
|
||||
session: core_entities.Session,
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
query: pipeline_query.Query,
|
||||
session: provider_session.Session,
|
||||
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||
"""执行命令"""
|
||||
|
||||
privilege = 1
|
||||
@@ -99,8 +85,8 @@ class CommandManager:
|
||||
if f'{query.launcher_type.value}_{query.launcher_id}' in self.ap.instance_config.data['admins']:
|
||||
privilege = 2
|
||||
|
||||
ctx = entities.ExecuteContext(
|
||||
query=query,
|
||||
ctx = command_context.ExecuteContext(
|
||||
query_id=query.query_id,
|
||||
session=session,
|
||||
command_text=command_text,
|
||||
command='',
|
||||
@@ -110,5 +96,9 @@ class CommandManager:
|
||||
privilege=privilege,
|
||||
)
|
||||
|
||||
ctx.command = ctx.params[0]
|
||||
|
||||
ctx.shift()
|
||||
|
||||
async for ret in self._execute(ctx, self.cmd_list):
|
||||
yield ret
|
||||
|
||||
@@ -1,74 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
import pydantic.v1 as pydantic
|
||||
|
||||
from ..core import entities as core_entities
|
||||
from . import errors
|
||||
from ..platform.types import message as platform_message
|
||||
|
||||
|
||||
class CommandReturn(pydantic.BaseModel):
|
||||
"""命令返回值"""
|
||||
|
||||
text: typing.Optional[str] = None
|
||||
"""文本
|
||||
"""
|
||||
|
||||
image: typing.Optional[platform_message.Image] = None
|
||||
"""弃用"""
|
||||
|
||||
image_url: typing.Optional[str] = None
|
||||
"""图片链接
|
||||
"""
|
||||
|
||||
error: typing.Optional[errors.CommandError] = None
|
||||
"""错误
|
||||
"""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class ExecuteContext(pydantic.BaseModel):
|
||||
"""单次命令执行上下文"""
|
||||
|
||||
query: core_entities.Query
|
||||
"""本次消息的请求对象"""
|
||||
|
||||
session: core_entities.Session
|
||||
"""本次消息所属的会话对象"""
|
||||
|
||||
command_text: str
|
||||
"""命令完整文本"""
|
||||
|
||||
command: str
|
||||
"""命令名称"""
|
||||
|
||||
crt_command: str
|
||||
"""当前命令
|
||||
|
||||
多级命令中crt_command为当前命令,command为根命令。
|
||||
例如:!plugin on Webwlkr
|
||||
处理到plugin时,command为plugin,crt_command为plugin
|
||||
处理到on时,command为plugin,crt_command为on
|
||||
"""
|
||||
|
||||
params: list[str]
|
||||
"""命令参数
|
||||
|
||||
整个命令以空格分割后的参数列表
|
||||
"""
|
||||
|
||||
crt_params: list[str]
|
||||
"""当前命令参数
|
||||
|
||||
多级命令中crt_params为当前命令参数,params为根命令参数。
|
||||
例如:!plugin on Webwlkr
|
||||
处理到plugin时,params为['on', 'Webwlkr'],crt_params为['on', 'Webwlkr']
|
||||
处理到on时,params为['on', 'Webwlkr'],crt_params为['Webwlkr']
|
||||
"""
|
||||
|
||||
privilege: int
|
||||
"""发起人权限"""
|
||||
@@ -1,26 +0,0 @@
|
||||
class CommandError(Exception):
|
||||
def __init__(self, message: str = None):
|
||||
self.message = message
|
||||
|
||||
def __str__(self):
|
||||
return self.message
|
||||
|
||||
|
||||
class CommandNotFoundError(CommandError):
|
||||
def __init__(self, message: str = None):
|
||||
super().__init__('未知命令: ' + message)
|
||||
|
||||
|
||||
class CommandPrivilegeError(CommandError):
|
||||
def __init__(self, message: str = None):
|
||||
super().__init__('权限不足: ' + message)
|
||||
|
||||
|
||||
class ParamNotEnoughError(CommandError):
|
||||
def __init__(self, message: str = None):
|
||||
super().__init__('参数不足: ' + message)
|
||||
|
||||
|
||||
class CommandOperationError(CommandError):
|
||||
def __init__(self, message: str = None):
|
||||
super().__init__('操作失败: ' + message)
|
||||
@@ -4,7 +4,7 @@ import typing
|
||||
import abc
|
||||
|
||||
from ..core import app
|
||||
from . import entities
|
||||
from langbot_plugin.api.entities.builtin.command import context as command_context
|
||||
|
||||
|
||||
preregistered_operators: list[typing.Type[CommandOperator]] = []
|
||||
@@ -95,16 +95,18 @@ class CommandOperator(metaclass=abc.ABCMeta):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(
|
||||
self, context: command_context.ExecuteContext
|
||||
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||
"""实现此方法以执行命令
|
||||
|
||||
支持多次yield以返回多个结果。
|
||||
例如:一个安装插件的命令,可能会有下载、解压、安装等多个步骤,每个步骤都可以返回一个结果。
|
||||
|
||||
Args:
|
||||
context (entities.ExecuteContext): 命令执行上下文
|
||||
context (command_context.ExecuteContext): 命令执行上下文
|
||||
|
||||
Yields:
|
||||
entities.CommandReturn: 命令返回封装
|
||||
command_context.CommandReturn: 命令返回封装
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -2,14 +2,17 @@ from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from .. import operator, entities, errors
|
||||
from .. import operator
|
||||
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
|
||||
|
||||
|
||||
@operator.operator_class(name='cmd', help='显示命令列表', usage='!cmd\n!cmd <命令名称>')
|
||||
class CmdOperator(operator.CommandOperator):
|
||||
"""命令列表"""
|
||||
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(
|
||||
self, context: command_context.ExecuteContext
|
||||
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||
"""执行"""
|
||||
if len(context.crt_params) == 0:
|
||||
reply_str = '当前所有命令: \n\n'
|
||||
@@ -20,7 +23,7 @@ class CmdOperator(operator.CommandOperator):
|
||||
|
||||
reply_str += '\n使用 !cmd <命令名称> 查看命令的详细帮助'
|
||||
|
||||
yield entities.CommandReturn(text=reply_str.strip())
|
||||
yield command_context.CommandReturn(text=reply_str.strip())
|
||||
|
||||
else:
|
||||
cmd_name = context.crt_params[0]
|
||||
@@ -33,9 +36,9 @@ class CmdOperator(operator.CommandOperator):
|
||||
break
|
||||
|
||||
if cmd is None:
|
||||
yield entities.CommandReturn(error=errors.CommandNotFoundError(cmd_name))
|
||||
yield command_context.CommandReturn(error=command_errors.CommandNotFoundError(cmd_name))
|
||||
else:
|
||||
reply_str = f'{cmd.name}: {cmd.help}\n\n'
|
||||
reply_str += f'使用方法: \n{cmd.usage}'
|
||||
|
||||
yield entities.CommandReturn(text=reply_str.strip())
|
||||
yield command_context.CommandReturn(text=reply_str.strip())
|
||||
|
||||
@@ -2,23 +2,26 @@ from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from .. import operator, entities, errors
|
||||
from .. import operator
|
||||
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
|
||||
|
||||
|
||||
@operator.operator_class(name='del', help='删除当前会话的历史记录', usage='!del <序号>\n!del all')
|
||||
class DelOperator(operator.CommandOperator):
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(
|
||||
self, context: command_context.ExecuteContext
|
||||
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||
if context.session.conversations:
|
||||
delete_index = 0
|
||||
if len(context.crt_params) > 0:
|
||||
try:
|
||||
delete_index = int(context.crt_params[0])
|
||||
except Exception:
|
||||
yield entities.CommandReturn(error=errors.CommandOperationError('索引必须是整数'))
|
||||
yield command_context.CommandReturn(error=command_errors.CommandOperationError('索引必须是整数'))
|
||||
return
|
||||
|
||||
if delete_index < 0 or delete_index >= len(context.session.conversations):
|
||||
yield entities.CommandReturn(error=errors.CommandOperationError('索引超出范围'))
|
||||
yield command_context.CommandReturn(error=command_errors.CommandOperationError('索引超出范围'))
|
||||
return
|
||||
|
||||
# 倒序
|
||||
@@ -29,15 +32,17 @@ class DelOperator(operator.CommandOperator):
|
||||
|
||||
del context.session.conversations[to_delete_index]
|
||||
|
||||
yield entities.CommandReturn(text=f'已删除对话: {delete_index}')
|
||||
yield command_context.CommandReturn(text=f'已删除对话: {delete_index}')
|
||||
else:
|
||||
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))
|
||||
yield command_context.CommandReturn(error=command_errors.CommandOperationError('当前没有对话'))
|
||||
|
||||
|
||||
@operator.operator_class(name='all', help='删除此会话的所有历史记录', parent_class=DelOperator)
|
||||
class DelAllOperator(operator.CommandOperator):
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(
|
||||
self, context: command_context.ExecuteContext
|
||||
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||
context.session.conversations = []
|
||||
context.session.using_conversation = None
|
||||
|
||||
yield entities.CommandReturn(text='已删除所有对话')
|
||||
yield command_context.CommandReturn(text='已删除所有对话')
|
||||
|
||||
@@ -1,19 +1,20 @@
|
||||
from __future__ import annotations
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from .. import operator, entities
|
||||
from .. import operator
|
||||
from langbot_plugin.api.entities.builtin.command import context as command_context
|
||||
|
||||
|
||||
@operator.operator_class(name='func', help='查看所有已注册的内容函数', usage='!func')
|
||||
class FuncOperator(operator.CommandOperator):
|
||||
async def execute(self, context: entities.ExecuteContext) -> AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(
|
||||
self, context: command_context.ExecuteContext
|
||||
) -> AsyncGenerator[command_context.CommandReturn, None]:
|
||||
reply_str = '当前已启用的内容函数: \n\n'
|
||||
|
||||
index = 1
|
||||
|
||||
all_functions = await self.ap.tool_mgr.get_all_functions(
|
||||
plugin_enabled=True,
|
||||
)
|
||||
all_functions = await self.ap.tool_mgr.get_all_tools()
|
||||
|
||||
for func in all_functions:
|
||||
reply_str += '{}. {}:\n{}\n\n'.format(
|
||||
@@ -23,4 +24,4 @@ class FuncOperator(operator.CommandOperator):
|
||||
)
|
||||
index += 1
|
||||
|
||||
yield entities.CommandReturn(text=reply_str)
|
||||
yield command_context.CommandReturn(text=reply_str)
|
||||
|
||||
@@ -2,14 +2,17 @@ from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from .. import operator, entities
|
||||
from .. import operator
|
||||
from langbot_plugin.api.entities.builtin.command import context as command_context
|
||||
|
||||
|
||||
@operator.operator_class(name='help', help='显示帮助', usage='!help\n!help <命令名称>')
|
||||
class HelpOperator(operator.CommandOperator):
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(
|
||||
self, context: command_context.ExecuteContext
|
||||
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||
help = 'LangBot - 大语言模型原生即时通信机器人平台\n链接:https://langbot.app'
|
||||
|
||||
help += '\n发送命令 !cmd 可查看命令列表'
|
||||
|
||||
yield entities.CommandReturn(text=help)
|
||||
yield command_context.CommandReturn(text=help)
|
||||
|
||||
@@ -3,26 +3,31 @@ from __future__ import annotations
|
||||
import typing
|
||||
|
||||
|
||||
from .. import operator, entities, errors
|
||||
from .. import operator
|
||||
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
|
||||
|
||||
|
||||
@operator.operator_class(name='last', help='切换到前一个对话', usage='!last')
|
||||
class LastOperator(operator.CommandOperator):
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(
|
||||
self, context: command_context.ExecuteContext
|
||||
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||
if context.session.conversations:
|
||||
# 找到当前会话的上一个会话
|
||||
for index in range(len(context.session.conversations) - 1, -1, -1):
|
||||
if context.session.conversations[index] == context.session.using_conversation:
|
||||
if index == 0:
|
||||
yield entities.CommandReturn(error=errors.CommandOperationError('已经是第一个对话了'))
|
||||
yield command_context.CommandReturn(
|
||||
error=command_errors.CommandOperationError('已经是第一个对话了')
|
||||
)
|
||||
return
|
||||
else:
|
||||
context.session.using_conversation = context.session.conversations[index - 1]
|
||||
time_str = context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S')
|
||||
|
||||
yield entities.CommandReturn(
|
||||
yield command_context.CommandReturn(
|
||||
text=f'已切换到上一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].readable_str()}'
|
||||
)
|
||||
return
|
||||
else:
|
||||
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))
|
||||
yield command_context.CommandReturn(error=command_errors.CommandOperationError('当前没有对话'))
|
||||
|
||||
@@ -2,19 +2,22 @@ from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from .. import operator, entities, errors
|
||||
from .. import operator
|
||||
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
|
||||
|
||||
|
||||
@operator.operator_class(name='list', help='列出此会话中的所有历史对话', usage='!list\n!list <页码>')
|
||||
class ListOperator(operator.CommandOperator):
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(
|
||||
self, context: command_context.ExecuteContext
|
||||
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||
page = 0
|
||||
|
||||
if len(context.crt_params) > 0:
|
||||
try:
|
||||
page = int(context.crt_params[0] - 1)
|
||||
except Exception:
|
||||
yield entities.CommandReturn(error=errors.CommandOperationError('页码应为整数'))
|
||||
yield command_context.CommandReturn(error=command_errors.CommandOperationError('页码应为整数'))
|
||||
return
|
||||
|
||||
record_per_page = 10
|
||||
@@ -45,4 +48,4 @@ class ListOperator(operator.CommandOperator):
|
||||
else:
|
||||
content += f'\n当前会话: {using_conv_index} {context.session.using_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S")}: {context.session.using_conversation.messages[0].readable_str() if len(context.session.using_conversation.messages) > 0 else "无内容"}'
|
||||
|
||||
yield entities.CommandReturn(text=f'第 {page + 1} 页 (时间倒序):\n{content}')
|
||||
yield command_context.CommandReturn(text=f'第 {page + 1} 页 (时间倒序):\n{content}')
|
||||
|
||||
@@ -2,26 +2,31 @@ from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from .. import operator, entities, errors
|
||||
from .. import operator
|
||||
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
|
||||
|
||||
|
||||
@operator.operator_class(name='next', help='切换到后一个对话', usage='!next')
|
||||
class NextOperator(operator.CommandOperator):
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(
|
||||
self, context: command_context.ExecuteContext
|
||||
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||
if context.session.conversations:
|
||||
# 找到当前会话的下一个会话
|
||||
for index in range(len(context.session.conversations)):
|
||||
if context.session.conversations[index] == context.session.using_conversation:
|
||||
if index == len(context.session.conversations) - 1:
|
||||
yield entities.CommandReturn(error=errors.CommandOperationError('已经是最后一个对话了'))
|
||||
yield command_context.CommandReturn(
|
||||
error=command_errors.CommandOperationError('已经是最后一个对话了')
|
||||
)
|
||||
return
|
||||
else:
|
||||
context.session.using_conversation = context.session.conversations[index + 1]
|
||||
time_str = context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S')
|
||||
|
||||
yield entities.CommandReturn(
|
||||
yield command_context.CommandReturn(
|
||||
text=f'已切换到后一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].content}'
|
||||
)
|
||||
return
|
||||
else:
|
||||
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))
|
||||
yield command_context.CommandReturn(error=command_errors.CommandOperationError('当前没有对话'))
|
||||
|
||||
@@ -2,7 +2,8 @@ from __future__ import annotations
|
||||
import typing
|
||||
import traceback
|
||||
|
||||
from .. import operator, entities, errors
|
||||
from .. import operator
|
||||
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
|
||||
|
||||
|
||||
@operator.operator_class(
|
||||
@@ -11,7 +12,9 @@ from .. import operator, entities, errors
|
||||
usage='!plugin\n!plugin get <插件仓库地址>\n!plugin update\n!plugin del <插件名>\n!plugin on <插件名>\n!plugin off <插件名>',
|
||||
)
|
||||
class PluginOperator(operator.CommandOperator):
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(
|
||||
self, context: command_context.ExecuteContext
|
||||
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||
plugin_list = self.ap.plugin_mgr.plugins()
|
||||
reply_str = '所有插件({}):\n'.format(len(plugin_list))
|
||||
idx = 0
|
||||
@@ -27,32 +30,36 @@ class PluginOperator(operator.CommandOperator):
|
||||
|
||||
idx += 1
|
||||
|
||||
yield entities.CommandReturn(text=reply_str)
|
||||
yield command_context.CommandReturn(text=reply_str)
|
||||
|
||||
|
||||
@operator.operator_class(name='get', help='安装插件', privilege=2, parent_class=PluginOperator)
|
||||
class PluginGetOperator(operator.CommandOperator):
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(
|
||||
self, context: command_context.ExecuteContext
|
||||
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||
if len(context.crt_params) == 0:
|
||||
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件仓库地址'))
|
||||
yield command_context.CommandReturn(error=command_errors.ParamNotEnoughError('请提供插件仓库地址'))
|
||||
else:
|
||||
repo = context.crt_params[0]
|
||||
|
||||
yield entities.CommandReturn(text='正在安装插件...')
|
||||
yield command_context.CommandReturn(text='正在安装插件...')
|
||||
|
||||
try:
|
||||
await self.ap.plugin_mgr.install_plugin(repo)
|
||||
yield entities.CommandReturn(text='插件安装成功,请重启程序以加载插件')
|
||||
yield command_context.CommandReturn(text='插件安装成功,请重启程序以加载插件')
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
yield entities.CommandReturn(error=errors.CommandError('插件安装失败: ' + str(e)))
|
||||
yield command_context.CommandReturn(error=command_errors.CommandError('插件安装失败: ' + str(e)))
|
||||
|
||||
|
||||
@operator.operator_class(name='update', help='更新插件', privilege=2, parent_class=PluginOperator)
|
||||
class PluginUpdateOperator(operator.CommandOperator):
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(
|
||||
self, context: command_context.ExecuteContext
|
||||
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||
if len(context.crt_params) == 0:
|
||||
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
|
||||
yield command_context.CommandReturn(error=command_errors.ParamNotEnoughError('请提供插件名称'))
|
||||
else:
|
||||
plugin_name = context.crt_params[0]
|
||||
|
||||
@@ -60,24 +67,26 @@ class PluginUpdateOperator(operator.CommandOperator):
|
||||
plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name)
|
||||
|
||||
if plugin_container is not None:
|
||||
yield entities.CommandReturn(text='正在更新插件...')
|
||||
yield command_context.CommandReturn(text='正在更新插件...')
|
||||
await self.ap.plugin_mgr.update_plugin(plugin_name)
|
||||
yield entities.CommandReturn(text='插件更新成功,请重启程序以加载插件')
|
||||
yield command_context.CommandReturn(text='插件更新成功,请重启程序以加载插件')
|
||||
else:
|
||||
yield entities.CommandReturn(error=errors.CommandError('插件更新失败: 未找到插件'))
|
||||
yield command_context.CommandReturn(error=command_errors.CommandError('插件更新失败: 未找到插件'))
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
yield entities.CommandReturn(error=errors.CommandError('插件更新失败: ' + str(e)))
|
||||
yield command_context.CommandReturn(error=command_errors.CommandError('插件更新失败: ' + str(e)))
|
||||
|
||||
|
||||
@operator.operator_class(name='all', help='更新所有插件', privilege=2, parent_class=PluginUpdateOperator)
|
||||
class PluginUpdateAllOperator(operator.CommandOperator):
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(
|
||||
self, context: command_context.ExecuteContext
|
||||
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||
try:
|
||||
plugins = [p.plugin_name for p in self.ap.plugin_mgr.plugins()]
|
||||
|
||||
if plugins:
|
||||
yield entities.CommandReturn(text='正在更新插件...')
|
||||
yield command_context.CommandReturn(text='正在更新插件...')
|
||||
updated = []
|
||||
try:
|
||||
for plugin_name in plugins:
|
||||
@@ -85,20 +94,22 @@ class PluginUpdateAllOperator(operator.CommandOperator):
|
||||
updated.append(plugin_name)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
yield entities.CommandReturn(error=errors.CommandError('插件更新失败: ' + str(e)))
|
||||
yield entities.CommandReturn(text='已更新插件: {}'.format(', '.join(updated)))
|
||||
yield command_context.CommandReturn(error=command_errors.CommandError('插件更新失败: ' + str(e)))
|
||||
yield command_context.CommandReturn(text='已更新插件: {}'.format(', '.join(updated)))
|
||||
else:
|
||||
yield entities.CommandReturn(text='没有可更新的插件')
|
||||
yield command_context.CommandReturn(text='没有可更新的插件')
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
yield entities.CommandReturn(error=errors.CommandError('插件更新失败: ' + str(e)))
|
||||
yield command_context.CommandReturn(error=command_errors.CommandError('插件更新失败: ' + str(e)))
|
||||
|
||||
|
||||
@operator.operator_class(name='del', help='删除插件', privilege=2, parent_class=PluginOperator)
|
||||
class PluginDelOperator(operator.CommandOperator):
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(
|
||||
self, context: command_context.ExecuteContext
|
||||
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||
if len(context.crt_params) == 0:
|
||||
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
|
||||
yield command_context.CommandReturn(error=command_errors.ParamNotEnoughError('请提供插件名称'))
|
||||
else:
|
||||
plugin_name = context.crt_params[0]
|
||||
|
||||
@@ -106,51 +117,55 @@ class PluginDelOperator(operator.CommandOperator):
|
||||
plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name)
|
||||
|
||||
if plugin_container is not None:
|
||||
yield entities.CommandReturn(text='正在删除插件...')
|
||||
yield command_context.CommandReturn(text='正在删除插件...')
|
||||
await self.ap.plugin_mgr.uninstall_plugin(plugin_name)
|
||||
yield entities.CommandReturn(text='插件删除成功,请重启程序以加载插件')
|
||||
yield command_context.CommandReturn(text='插件删除成功,请重启程序以加载插件')
|
||||
else:
|
||||
yield entities.CommandReturn(error=errors.CommandError('插件删除失败: 未找到插件'))
|
||||
yield command_context.CommandReturn(error=command_errors.CommandError('插件删除失败: 未找到插件'))
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
yield entities.CommandReturn(error=errors.CommandError('插件删除失败: ' + str(e)))
|
||||
yield command_context.CommandReturn(error=command_errors.CommandError('插件删除失败: ' + str(e)))
|
||||
|
||||
|
||||
@operator.operator_class(name='on', help='启用插件', privilege=2, parent_class=PluginOperator)
|
||||
class PluginEnableOperator(operator.CommandOperator):
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(
|
||||
self, context: command_context.ExecuteContext
|
||||
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||
if len(context.crt_params) == 0:
|
||||
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
|
||||
yield command_context.CommandReturn(error=command_errors.ParamNotEnoughError('请提供插件名称'))
|
||||
else:
|
||||
plugin_name = context.crt_params[0]
|
||||
|
||||
try:
|
||||
if await self.ap.plugin_mgr.update_plugin_switch(plugin_name, True):
|
||||
yield entities.CommandReturn(text='已启用插件: {}'.format(plugin_name))
|
||||
yield command_context.CommandReturn(text='已启用插件: {}'.format(plugin_name))
|
||||
else:
|
||||
yield entities.CommandReturn(
|
||||
error=errors.CommandError('插件状态修改失败: 未找到插件 {}'.format(plugin_name))
|
||||
yield command_context.CommandReturn(
|
||||
error=command_errors.CommandError('插件状态修改失败: 未找到插件 {}'.format(plugin_name))
|
||||
)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
yield entities.CommandReturn(error=errors.CommandError('插件状态修改失败: ' + str(e)))
|
||||
yield command_context.CommandReturn(error=command_errors.CommandError('插件状态修改失败: ' + str(e)))
|
||||
|
||||
|
||||
@operator.operator_class(name='off', help='禁用插件', privilege=2, parent_class=PluginOperator)
|
||||
class PluginDisableOperator(operator.CommandOperator):
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(
|
||||
self, context: command_context.ExecuteContext
|
||||
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||
if len(context.crt_params) == 0:
|
||||
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
|
||||
yield command_context.CommandReturn(error=command_errors.ParamNotEnoughError('请提供插件名称'))
|
||||
else:
|
||||
plugin_name = context.crt_params[0]
|
||||
|
||||
try:
|
||||
if await self.ap.plugin_mgr.update_plugin_switch(plugin_name, False):
|
||||
yield entities.CommandReturn(text='已禁用插件: {}'.format(plugin_name))
|
||||
yield command_context.CommandReturn(text='已禁用插件: {}'.format(plugin_name))
|
||||
else:
|
||||
yield entities.CommandReturn(
|
||||
error=errors.CommandError('插件状态修改失败: 未找到插件 {}'.format(plugin_name))
|
||||
yield command_context.CommandReturn(
|
||||
error=command_errors.CommandError('插件状态修改失败: 未找到插件 {}'.format(plugin_name))
|
||||
)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
yield entities.CommandReturn(error=errors.CommandError('插件状态修改失败: ' + str(e)))
|
||||
yield command_context.CommandReturn(error=command_errors.CommandError('插件状态修改失败: ' + str(e)))
|
||||
|
||||
@@ -2,19 +2,22 @@ from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from .. import operator, entities, errors
|
||||
from .. import operator
|
||||
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
|
||||
|
||||
|
||||
@operator.operator_class(name='prompt', help='查看当前对话的前文', usage='!prompt')
|
||||
class PromptOperator(operator.CommandOperator):
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(
|
||||
self, context: command_context.ExecuteContext
|
||||
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||
"""执行"""
|
||||
if context.session.using_conversation is None:
|
||||
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))
|
||||
yield command_context.CommandReturn(error=command_errors.CommandOperationError('当前没有对话'))
|
||||
else:
|
||||
reply_str = '当前对话所有内容:\n\n'
|
||||
|
||||
for msg in context.session.using_conversation.messages:
|
||||
reply_str += f'{msg.role}: {msg.content}\n'
|
||||
|
||||
yield entities.CommandReturn(text=reply_str)
|
||||
yield command_context.CommandReturn(text=reply_str)
|
||||
|
||||
@@ -2,15 +2,18 @@ from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from .. import operator, entities, errors
|
||||
from .. import operator
|
||||
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
|
||||
|
||||
|
||||
@operator.operator_class(name='resend', help='重发当前会话的最后一条消息', usage='!resend')
|
||||
class ResendOperator(operator.CommandOperator):
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(
|
||||
self, context: command_context.ExecuteContext
|
||||
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||
# 回滚到最后一条用户message前
|
||||
if context.session.using_conversation is None:
|
||||
yield entities.CommandReturn(error=errors.CommandError('当前没有对话'))
|
||||
yield command_context.CommandReturn(error=command_errors.CommandError('当前没有对话'))
|
||||
else:
|
||||
conv_msg = context.session.using_conversation.messages
|
||||
|
||||
@@ -23,4 +26,4 @@ class ResendOperator(operator.CommandOperator):
|
||||
conv_msg.pop()
|
||||
|
||||
# 不重发了,提示用户已删除就行了
|
||||
yield entities.CommandReturn(text='已删除最后一次请求记录')
|
||||
yield command_context.CommandReturn(text='已删除最后一次请求记录')
|
||||
|
||||
@@ -2,13 +2,16 @@ from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from .. import operator, entities
|
||||
from .. import operator
|
||||
from langbot_plugin.api.entities.builtin.command import context as command_context
|
||||
|
||||
|
||||
@operator.operator_class(name='reset', help='重置当前会话', usage='!reset')
|
||||
class ResetOperator(operator.CommandOperator):
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(
|
||||
self, context: command_context.ExecuteContext
|
||||
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||
"""执行"""
|
||||
context.session.using_conversation = None
|
||||
|
||||
yield entities.CommandReturn(text='已重置当前会话')
|
||||
yield command_context.CommandReturn(text='已重置当前会话')
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from .. import operator, entities
|
||||
|
||||
|
||||
@operator.operator_class(name='update', help='更新程序', usage='!update', privilege=2)
|
||||
class UpdateCommand(operator.CommandOperator):
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
yield entities.CommandReturn(text='不再支持通过命令更新,请查看 LangBot 文档。')
|
||||
@@ -2,12 +2,15 @@ from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from .. import operator, entities
|
||||
from .. import operator
|
||||
from langbot_plugin.api.entities.builtin.command import context as command_context
|
||||
|
||||
|
||||
@operator.operator_class(name='version', help='显示版本信息', usage='!version')
|
||||
class VersionCommand(operator.CommandOperator):
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(
|
||||
self, context: command_context.ExecuteContext
|
||||
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||
reply_str = f'当前版本: \n{self.ap.ver_mgr.get_current_version()}'
|
||||
|
||||
try:
|
||||
@@ -16,4 +19,4 @@ class VersionCommand(operator.CommandOperator):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
yield entities.CommandReturn(text=reply_str.strip())
|
||||
yield command_context.CommandReturn(text=reply_str.strip())
|
||||
|
||||
@@ -3,7 +3,6 @@ from __future__ import annotations
|
||||
import logging
|
||||
import asyncio
|
||||
import traceback
|
||||
import sys
|
||||
import os
|
||||
|
||||
from ..platform import botmgr as im_mgr
|
||||
@@ -12,7 +11,7 @@ from ..provider.modelmgr import modelmgr as llm_model_mgr
|
||||
from ..provider.tools import toolmgr as llm_tool_mgr
|
||||
from ..config import manager as config_mgr
|
||||
from ..command import cmdmgr
|
||||
from ..plugin import manager as plugin_mgr
|
||||
from ..plugin import connector as plugin_connector
|
||||
from ..pipeline import pool
|
||||
from ..pipeline import controller, pipelinemgr
|
||||
from ..utils import version as version_mgr, proxy as proxy_mgr, announce as announce_mgr
|
||||
@@ -80,7 +79,7 @@ class Application:
|
||||
|
||||
# =========================
|
||||
|
||||
plugin_mgr: plugin_mgr.PluginManager = None
|
||||
plugin_connector: plugin_connector.PluginRuntimeConnector = None
|
||||
|
||||
query_pool: pool.QueryPool = None
|
||||
|
||||
@@ -128,7 +127,7 @@ class Application:
|
||||
|
||||
async def run(self):
|
||||
try:
|
||||
await self.plugin_mgr.initialize_plugins()
|
||||
await self.plugin_connector.initialize_plugins()
|
||||
|
||||
# 后续可能会允许动态重启其他任务
|
||||
# 故为了防止程序在非 Ctrl-C 情况下退出,这里创建一个不会结束的协程
|
||||
@@ -169,6 +168,9 @@ class Application:
|
||||
self.logger.error(f'Application runtime fatal exception: {e}')
|
||||
self.logger.debug(f'Traceback: {traceback.format_exc()}')
|
||||
|
||||
def dispose(self):
|
||||
self.plugin_connector.dispose()
|
||||
|
||||
async def print_web_access_info(self):
|
||||
"""Print access webui tips"""
|
||||
|
||||
@@ -195,59 +197,3 @@ class Application:
|
||||
""".strip()
|
||||
for line in tips.split('\n'):
|
||||
self.logger.info(line)
|
||||
|
||||
async def reload(
|
||||
self,
|
||||
scope: core_entities.LifecycleControlScope,
|
||||
):
|
||||
match scope:
|
||||
case core_entities.LifecycleControlScope.PLATFORM.value:
|
||||
self.logger.info('Hot reload scope=' + scope)
|
||||
await self.platform_mgr.shutdown()
|
||||
|
||||
self.platform_mgr = im_mgr.PlatformManager(self)
|
||||
|
||||
await self.platform_mgr.initialize()
|
||||
|
||||
self.task_mgr.create_task(
|
||||
self.platform_mgr.run(),
|
||||
name='platform-manager',
|
||||
scopes=[
|
||||
core_entities.LifecycleControlScope.APPLICATION,
|
||||
core_entities.LifecycleControlScope.PLATFORM,
|
||||
],
|
||||
)
|
||||
case core_entities.LifecycleControlScope.PLUGIN.value:
|
||||
self.logger.info('Hot reload scope=' + scope)
|
||||
await self.plugin_mgr.destroy_plugins()
|
||||
|
||||
# 删除 sys.module 中所有的 plugins/* 下的模块
|
||||
for mod in list(sys.modules.keys()):
|
||||
if mod.startswith('plugins.'):
|
||||
del sys.modules[mod]
|
||||
|
||||
self.plugin_mgr = plugin_mgr.PluginManager(self)
|
||||
await self.plugin_mgr.initialize()
|
||||
|
||||
await self.plugin_mgr.initialize_plugins()
|
||||
|
||||
await self.plugin_mgr.load_plugins()
|
||||
await self.plugin_mgr.initialize_plugins()
|
||||
case core_entities.LifecycleControlScope.PROVIDER.value:
|
||||
self.logger.info('Hot reload scope=' + scope)
|
||||
|
||||
await self.tool_mgr.shutdown()
|
||||
|
||||
llm_model_mgr_inst = llm_model_mgr.ModelManager(self)
|
||||
await llm_model_mgr_inst.initialize()
|
||||
self.model_mgr = llm_model_mgr_inst
|
||||
|
||||
llm_session_mgr_inst = llm_session_mgr.SessionManager(self)
|
||||
await llm_session_mgr_inst.initialize()
|
||||
self.sess_mgr = llm_session_mgr_inst
|
||||
|
||||
llm_tool_mgr_inst = llm_tool_mgr.ToolManager(self)
|
||||
await llm_tool_mgr_inst.initialize()
|
||||
self.tool_mgr = llm_tool_mgr_inst
|
||||
case _:
|
||||
pass
|
||||
|
||||
@@ -51,8 +51,8 @@ async def main(loop: asyncio.AbstractEventLoop):
|
||||
import signal
|
||||
|
||||
def signal_handler(sig, frame):
|
||||
app_inst.dispose()
|
||||
print('[Signal] Program exit.')
|
||||
# ap.shutdown()
|
||||
os._exit(0)
|
||||
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
@@ -1,18 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import typing
|
||||
import datetime
|
||||
import asyncio
|
||||
|
||||
import pydantic.v1 as pydantic
|
||||
|
||||
from ..provider import entities as llm_entities
|
||||
from ..provider.modelmgr import requester
|
||||
from ..provider.tools import entities as tools_entities
|
||||
from ..platform import adapter as msadapter
|
||||
from ..platform.types import message as platform_message
|
||||
from ..platform.types import events as platform_events
|
||||
|
||||
|
||||
class LifecycleControlScope(enum.Enum):
|
||||
@@ -20,159 +8,3 @@ class LifecycleControlScope(enum.Enum):
|
||||
PLATFORM = 'platform'
|
||||
PLUGIN = 'plugin'
|
||||
PROVIDER = 'provider'
|
||||
|
||||
|
||||
class LauncherTypes(enum.Enum):
|
||||
"""一个请求的发起者类型"""
|
||||
|
||||
PERSON = 'person'
|
||||
"""私聊"""
|
||||
|
||||
GROUP = 'group'
|
||||
"""群聊"""
|
||||
|
||||
|
||||
class Query(pydantic.BaseModel):
|
||||
"""一次请求的信息封装"""
|
||||
|
||||
query_id: int
|
||||
"""请求ID,添加进请求池时生成"""
|
||||
|
||||
launcher_type: LauncherTypes
|
||||
"""会话类型,platform处理阶段设置"""
|
||||
|
||||
launcher_id: typing.Union[int, str]
|
||||
"""会话ID,platform处理阶段设置"""
|
||||
|
||||
sender_id: typing.Union[int, str]
|
||||
"""发送者ID,platform处理阶段设置"""
|
||||
|
||||
message_event: platform_events.MessageEvent
|
||||
"""事件,platform收到的原始事件"""
|
||||
|
||||
message_chain: platform_message.MessageChain
|
||||
"""消息链,platform收到的原始消息链"""
|
||||
|
||||
bot_uuid: typing.Optional[str] = None
|
||||
"""机器人UUID。"""
|
||||
|
||||
pipeline_uuid: typing.Optional[str] = None
|
||||
"""流水线UUID。"""
|
||||
|
||||
pipeline_config: typing.Optional[dict[str, typing.Any]] = None
|
||||
"""流水线配置,由 Pipeline 在运行开始时设置。"""
|
||||
|
||||
adapter: msadapter.MessagePlatformAdapter
|
||||
"""消息平台适配器对象,单个app中可能启用了多个消息平台适配器,此对象表明发起此query的适配器"""
|
||||
|
||||
session: typing.Optional[Session] = None
|
||||
"""会话对象,由前置处理器阶段设置"""
|
||||
|
||||
messages: typing.Optional[list[llm_entities.Message]] = []
|
||||
"""历史消息列表,由前置处理器阶段设置"""
|
||||
|
||||
prompt: typing.Optional[llm_entities.Prompt] = None
|
||||
"""情景预设内容,由前置处理器阶段设置"""
|
||||
|
||||
user_message: typing.Optional[llm_entities.Message] = None
|
||||
"""此次请求的用户消息对象,由前置处理器阶段设置"""
|
||||
|
||||
variables: typing.Optional[dict[str, typing.Any]] = None
|
||||
"""变量,由前置处理器阶段设置。在prompt中嵌入或由 Runner 传递到 LLMOps 平台。"""
|
||||
|
||||
use_llm_model: typing.Optional[requester.RuntimeLLMModel] = None
|
||||
"""使用的对话模型,由前置处理器阶段设置"""
|
||||
|
||||
use_funcs: typing.Optional[list[tools_entities.LLMFunction]] = None
|
||||
"""使用的函数,由前置处理器阶段设置"""
|
||||
|
||||
resp_messages: (
|
||||
typing.Optional[list[llm_entities.Message]]
|
||||
| typing.Optional[list[platform_message.MessageChain]]
|
||||
| typing.Optional[list[llm_entities.MessageChunk]]
|
||||
) = []
|
||||
"""由Process阶段生成的回复消息对象列表"""
|
||||
|
||||
resp_message_chain: typing.Optional[list[platform_message.MessageChain]] = None
|
||||
"""回复消息链,从resp_messages包装而得"""
|
||||
|
||||
# ======= 内部保留 =======
|
||||
current_stage: typing.Optional['pkg.pipeline.pipelinemgr.StageInstContainer'] = None
|
||||
"""当前所处阶段"""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
# ========== 插件可调用的 API(请求 API) ==========
|
||||
|
||||
def set_variable(self, key: str, value: typing.Any):
|
||||
"""设置变量"""
|
||||
if self.variables is None:
|
||||
self.variables = {}
|
||||
self.variables[key] = value
|
||||
|
||||
def get_variable(self, key: str) -> typing.Any:
|
||||
"""获取变量"""
|
||||
if self.variables is None:
|
||||
return None
|
||||
return self.variables.get(key)
|
||||
|
||||
def get_variables(self) -> dict[str, typing.Any]:
|
||||
"""获取所有变量"""
|
||||
if self.variables is None:
|
||||
return {}
|
||||
return self.variables
|
||||
|
||||
|
||||
class Conversation(pydantic.BaseModel):
|
||||
"""对话,包含于 Session 中,一个 Session 可以有多个历史 Conversation,但只有一个当前使用的 Conversation"""
|
||||
|
||||
prompt: llm_entities.Prompt
|
||||
|
||||
messages: list[llm_entities.Message]
|
||||
|
||||
create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
|
||||
|
||||
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
|
||||
|
||||
use_llm_model: typing.Optional[requester.RuntimeLLMModel] = None
|
||||
|
||||
use_funcs: typing.Optional[list[tools_entities.LLMFunction]]
|
||||
|
||||
pipeline_uuid: str
|
||||
"""流水线UUID。"""
|
||||
|
||||
bot_uuid: str
|
||||
"""机器人UUID。"""
|
||||
|
||||
uuid: typing.Optional[str] = None
|
||||
"""该对话的 uuid,在创建时不会自动生成。而是当使用 Dify API 等由外部管理对话信息的服务时,用于绑定外部的会话。具体如何使用,取决于 Runner。"""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class Session(pydantic.BaseModel):
|
||||
"""会话,一个 Session 对应一个 {launcher_type.value}_{launcher_id}"""
|
||||
|
||||
launcher_type: LauncherTypes
|
||||
|
||||
launcher_id: typing.Union[int, str]
|
||||
|
||||
sender_id: typing.Optional[typing.Union[int, str]] = 0
|
||||
|
||||
use_prompt_name: typing.Optional[str] = 'default'
|
||||
|
||||
using_conversation: typing.Optional[Conversation] = None
|
||||
|
||||
conversations: typing.Optional[list[Conversation]] = pydantic.Field(default_factory=list)
|
||||
|
||||
create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
|
||||
|
||||
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
|
||||
|
||||
semaphore: typing.Optional[asyncio.Semaphore] = None
|
||||
"""当前会话的信号量,用于限制并发"""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
from .. import stage, app
|
||||
from ...utils import version, proxy, announce
|
||||
from ...pipeline import pool, controller, pipelinemgr
|
||||
from ...plugin import manager as plugin_mgr
|
||||
from ...plugin import connector as plugin_connector
|
||||
from ...command import cmdmgr
|
||||
from ...provider.session import sessionmgr as llm_session_mgr
|
||||
from ...provider.modelmgr import modelmgr as llm_model_mgr
|
||||
@@ -62,10 +63,13 @@ class BuildAppStage(stage.BootingStage):
|
||||
ap.persistence_mgr = persistence_mgr_inst
|
||||
await persistence_mgr_inst.initialize()
|
||||
|
||||
plugin_mgr_inst = plugin_mgr.PluginManager(ap)
|
||||
await plugin_mgr_inst.initialize()
|
||||
ap.plugin_mgr = plugin_mgr_inst
|
||||
await plugin_mgr_inst.load_plugins()
|
||||
async def runtime_disconnect_callback(connector: plugin_connector.PluginRuntimeConnector) -> None:
|
||||
await asyncio.sleep(3)
|
||||
await plugin_connector_inst.initialize()
|
||||
|
||||
plugin_connector_inst = plugin_connector.PluginRuntimeConnector(ap, runtime_disconnect_callback)
|
||||
await plugin_connector_inst.initialize()
|
||||
ap.plugin_connector = plugin_connector_inst
|
||||
|
||||
cmd_mgr_inst = cmdmgr.CommandManager(ap)
|
||||
await cmd_mgr_inst.initialize()
|
||||
|
||||
22
pkg/entity/persistence/bstorage.py
Normal file
22
pkg/entity/persistence/bstorage.py
Normal file
@@ -0,0 +1,22 @@
|
||||
import sqlalchemy
|
||||
|
||||
from .base import Base
|
||||
|
||||
|
||||
class BinaryStorage(Base):
|
||||
"""Current for plugin use only"""
|
||||
|
||||
__tablename__ = 'binary_storages'
|
||||
|
||||
unique_key = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True)
|
||||
key = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||
owner_type = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||
owner = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||
value = sqlalchemy.Column(sqlalchemy.LargeBinary, nullable=False)
|
||||
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
|
||||
updated_at = sqlalchemy.Column(
|
||||
sqlalchemy.DateTime,
|
||||
nullable=False,
|
||||
server_default=sqlalchemy.func.now(),
|
||||
onupdate=sqlalchemy.func.now(),
|
||||
)
|
||||
@@ -13,6 +13,8 @@ class PluginSetting(Base):
|
||||
enabled = sqlalchemy.Column(sqlalchemy.Boolean, nullable=False, default=True)
|
||||
priority = sqlalchemy.Column(sqlalchemy.Integer, nullable=False, default=0)
|
||||
config = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default=dict)
|
||||
install_source = sqlalchemy.Column(sqlalchemy.String(255), nullable=False, default='github')
|
||||
install_info = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default=dict)
|
||||
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
|
||||
updated_at = sqlalchemy.Column(
|
||||
sqlalchemy.DateTime,
|
||||
|
||||
@@ -44,6 +44,38 @@ class PersistenceManager:
|
||||
|
||||
await self.create_tables()
|
||||
|
||||
# run migrations
|
||||
database_version = await self.execute_async(
|
||||
sqlalchemy.select(metadata.Metadata).where(metadata.Metadata.key == 'database_version')
|
||||
)
|
||||
|
||||
database_version = int(database_version.fetchone()[1])
|
||||
required_database_version = constants.required_database_version
|
||||
|
||||
if database_version < required_database_version:
|
||||
migrations = migration.preregistered_db_migrations
|
||||
migrations.sort(key=lambda x: x.number)
|
||||
|
||||
last_migration_number = database_version
|
||||
|
||||
for migration_cls in migrations:
|
||||
migration_instance = migration_cls(self.ap)
|
||||
|
||||
if (
|
||||
migration_instance.number > database_version
|
||||
and migration_instance.number <= required_database_version
|
||||
):
|
||||
await migration_instance.upgrade()
|
||||
await self.execute_async(
|
||||
sqlalchemy.update(metadata.Metadata)
|
||||
.where(metadata.Metadata.key == 'database_version')
|
||||
.values({'value': str(migration_instance.number)})
|
||||
)
|
||||
last_migration_number = migration_instance.number
|
||||
self.ap.logger.info(f'Migration {migration_instance.number} completed.')
|
||||
|
||||
self.ap.logger.info(f'Successfully upgraded database to version {last_migration_number}.')
|
||||
|
||||
async def create_tables(self):
|
||||
# create tables
|
||||
async with self.get_db_engine().connect() as conn:
|
||||
@@ -87,38 +119,6 @@ class PersistenceManager:
|
||||
|
||||
# =================================
|
||||
|
||||
# run migrations
|
||||
database_version = await self.execute_async(
|
||||
sqlalchemy.select(metadata.Metadata).where(metadata.Metadata.key == 'database_version')
|
||||
)
|
||||
|
||||
database_version = int(database_version.fetchone()[1])
|
||||
required_database_version = constants.required_database_version
|
||||
|
||||
if database_version < required_database_version:
|
||||
migrations = migration.preregistered_db_migrations
|
||||
migrations.sort(key=lambda x: x.number)
|
||||
|
||||
last_migration_number = database_version
|
||||
|
||||
for migration_cls in migrations:
|
||||
migration_instance = migration_cls(self.ap)
|
||||
|
||||
if (
|
||||
migration_instance.number > database_version
|
||||
and migration_instance.number <= required_database_version
|
||||
):
|
||||
await migration_instance.upgrade()
|
||||
await self.execute_async(
|
||||
sqlalchemy.update(metadata.Metadata)
|
||||
.where(metadata.Metadata.key == 'database_version')
|
||||
.values({'value': str(migration_instance.number)})
|
||||
)
|
||||
last_migration_number = migration_instance.number
|
||||
self.ap.logger.info(f'Migration {migration_instance.number} completed.')
|
||||
|
||||
self.ap.logger.info(f'Successfully upgraded database to version {last_migration_number}.')
|
||||
|
||||
async def execute_async(self, *args, **kwargs) -> sqlalchemy.engine.cursor.CursorResult:
|
||||
async with self.get_db_engine().connect() as conn:
|
||||
result = await conn.execute(*args, **kwargs)
|
||||
@@ -128,10 +128,13 @@ class PersistenceManager:
|
||||
def get_db_engine(self) -> sqlalchemy_asyncio.AsyncEngine:
|
||||
return self.db.get_engine()
|
||||
|
||||
def serialize_model(self, model: typing.Type[sqlalchemy.Base], data: sqlalchemy.Base) -> dict:
|
||||
def serialize_model(
|
||||
self, model: typing.Type[sqlalchemy.Base], data: sqlalchemy.Base, masked_columns: list[str] = []
|
||||
) -> dict:
|
||||
return {
|
||||
column.name: getattr(data, column.name)
|
||||
if not isinstance(getattr(data, column.name), (datetime.datetime))
|
||||
else getattr(data, column.name).isoformat()
|
||||
for column in model.__table__.columns
|
||||
if column.name not in masked_columns
|
||||
}
|
||||
|
||||
20
pkg/persistence/migrations/dbm004_plugin_config.py
Normal file
20
pkg/persistence/migrations/dbm004_plugin_config.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from .. import migration
|
||||
|
||||
|
||||
@migration.migration_class(4)
|
||||
class DBMigratePluginConfig(migration.DBMigration):
|
||||
"""插件配置"""
|
||||
|
||||
async def upgrade(self):
|
||||
"""升级"""
|
||||
|
||||
if 'plugin' not in self.ap.instance_config.data:
|
||||
self.ap.instance_config.data['plugin'] = {
|
||||
'runtime_ws_url': 'ws://localhost:5400/control/ws',
|
||||
}
|
||||
|
||||
await self.ap.instance_config.dump_config()
|
||||
|
||||
async def downgrade(self):
|
||||
"""降级"""
|
||||
pass
|
||||
32
pkg/persistence/migrations/dbm006_plugin_install_source.py
Normal file
32
pkg/persistence/migrations/dbm006_plugin_install_source.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import sqlalchemy
|
||||
from .. import migration
|
||||
|
||||
|
||||
@migration.migration_class(6)
|
||||
class DBMigratePluginInstallSource(migration.DBMigration):
|
||||
"""插件安装来源"""
|
||||
|
||||
async def upgrade(self):
|
||||
"""升级"""
|
||||
# 查询表结构获取所有列名(异步执行 SQL)
|
||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.text('PRAGMA table_info(plugin_settings);'))
|
||||
# fetchall() 是同步方法,无需 await
|
||||
columns = [row[1] for row in result.fetchall()]
|
||||
|
||||
# 检查并添加 install_source 列
|
||||
if 'install_source' not in columns:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(
|
||||
"ALTER TABLE plugin_settings ADD COLUMN install_source VARCHAR(255) NOT NULL DEFAULT 'github'"
|
||||
)
|
||||
)
|
||||
|
||||
# 检查并添加 install_info 列
|
||||
if 'install_info' not in columns:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text("ALTER TABLE plugin_settings ADD COLUMN install_info JSON NOT NULL DEFAULT '{}'")
|
||||
)
|
||||
|
||||
async def downgrade(self):
|
||||
"""降级"""
|
||||
pass
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .. import stage, entities
|
||||
from ...core import entities as core_entities
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
|
||||
|
||||
@stage.stage_class('BanSessionCheckStage')
|
||||
@@ -14,7 +14,7 @@ class BanSessionCheckStage(stage.PipelineStage):
|
||||
async def initialize(self, pipeline_config: dict):
|
||||
pass
|
||||
|
||||
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||
async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||
found = False
|
||||
|
||||
mode = query.pipeline_config['trigger']['access-control']['mode']
|
||||
|
||||
@@ -3,12 +3,11 @@ from __future__ import annotations
|
||||
from ...core import app
|
||||
|
||||
from .. import stage, entities
|
||||
from ...core import entities as core_entities
|
||||
from . import filter as filter_model, entities as filter_entities
|
||||
from ...provider import entities as llm_entities
|
||||
from ...platform.types import message as platform_message
|
||||
from langbot_plugin.api.entities.builtin.provider import message as provider_message
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
from ...utils import importutil
|
||||
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
from . import filters
|
||||
|
||||
importutil.import_modules_in_pkg(filters)
|
||||
@@ -58,7 +57,7 @@ class ContentFilterStage(stage.PipelineStage):
|
||||
async def _pre_process(
|
||||
self,
|
||||
message: str,
|
||||
query: core_entities.Query,
|
||||
query: pipeline_query.Query,
|
||||
) -> entities.StageProcessResult:
|
||||
"""请求llm前处理消息
|
||||
只要有一个不通过就不放行,只放行 PASS 的消息
|
||||
@@ -86,14 +85,14 @@ class ContentFilterStage(stage.PipelineStage):
|
||||
elif result.level == filter_entities.ResultLevel.PASS: # 传到下一个
|
||||
message = result.replacement
|
||||
|
||||
query.message_chain = platform_message.MessageChain(platform_message.Plain(message))
|
||||
query.message_chain = platform_message.MessageChain([platform_message.Plain(text=message)])
|
||||
|
||||
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
|
||||
async def _post_process(
|
||||
self,
|
||||
message: str,
|
||||
query: core_entities.Query,
|
||||
query: pipeline_query.Query,
|
||||
) -> entities.StageProcessResult:
|
||||
"""请求llm后处理响应
|
||||
只要是 PASS 或者 MASKED 的就通过此 filter,将其 replacement 设置为message,进入下一个 filter
|
||||
@@ -123,7 +122,7 @@ class ContentFilterStage(stage.PipelineStage):
|
||||
|
||||
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
|
||||
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||
async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||
"""处理"""
|
||||
if stage_inst_name == 'PreContentFilterStage':
|
||||
contain_non_text = False
|
||||
@@ -142,7 +141,7 @@ class ContentFilterStage(stage.PipelineStage):
|
||||
return await self._pre_process(str(query.message_chain).strip(), query)
|
||||
elif stage_inst_name == 'PostContentFilterStage':
|
||||
# 仅处理 query.resp_messages[-1].content 是 str 的情况
|
||||
if isinstance(query.resp_messages[-1], llm_entities.Message) and isinstance(
|
||||
if isinstance(query.resp_messages[-1], provider_message.Message) and isinstance(
|
||||
query.resp_messages[-1].content, str
|
||||
):
|
||||
return await self._post_process(query.resp_messages[-1].content, query)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import enum
|
||||
|
||||
import pydantic.v1 as pydantic
|
||||
import pydantic
|
||||
|
||||
|
||||
class ResultLevel(enum.Enum):
|
||||
|
||||
@@ -3,9 +3,9 @@ from __future__ import annotations
|
||||
import abc
|
||||
import typing
|
||||
|
||||
from ...core import app, entities as core_entities
|
||||
from ...core import app
|
||||
from . import entities
|
||||
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
|
||||
preregistered_filters: list[typing.Type[ContentFilter]] = []
|
||||
|
||||
@@ -60,8 +60,8 @@ class ContentFilter(metaclass=abc.ABCMeta):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def process(self, query: core_entities.Query, message: str = None, image_url=None) -> entities.FilterResult:
|
||||
"""Process message
|
||||
async def process(self, query: pipeline_query.Query, message: str = None, image_url=None) -> entities.FilterResult:
|
||||
"""处理消息
|
||||
|
||||
It is divided into two stages, depending on the value of enable_stages.
|
||||
For content filters, you do not need to consider the stage of the message, you only need to check the message content.
|
||||
|
||||
@@ -4,8 +4,7 @@ import aiohttp
|
||||
|
||||
from .. import entities
|
||||
from .. import filter as filter_model
|
||||
from ....core import entities as core_entities
|
||||
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
|
||||
BAIDU_EXAMINE_URL = 'https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token={}'
|
||||
BAIDU_EXAMINE_TOKEN_URL = 'https://aip.baidubce.com/oauth/2.0/token'
|
||||
@@ -27,7 +26,7 @@ class BaiduCloudExamine(filter_model.ContentFilter):
|
||||
) as resp:
|
||||
return (await resp.json())['access_token']
|
||||
|
||||
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
|
||||
async def process(self, query: pipeline_query.Query, message: str) -> entities.FilterResult:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
BAIDU_EXAMINE_URL.format(await self._get_token()),
|
||||
|
||||
@@ -3,7 +3,7 @@ import re
|
||||
|
||||
from .. import filter as filter_model
|
||||
from .. import entities
|
||||
from ....core import entities as core_entities
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
|
||||
|
||||
@filter_model.filter_class('ban-word-filter')
|
||||
@@ -13,7 +13,7 @@ class BanWordFilter(filter_model.ContentFilter):
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
|
||||
async def process(self, query: pipeline_query.Query, message: str) -> entities.FilterResult:
|
||||
found = False
|
||||
|
||||
for word in self.ap.sensitive_meta.data['words']:
|
||||
|
||||
@@ -3,7 +3,7 @@ import re
|
||||
|
||||
from .. import entities
|
||||
from .. import filter as filter_model
|
||||
from ....core import entities as core_entities
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
|
||||
|
||||
@filter_model.filter_class('content-ignore')
|
||||
@@ -16,7 +16,7 @@ class ContentIgnore(filter_model.ContentFilter):
|
||||
entities.EnableStage.PRE,
|
||||
]
|
||||
|
||||
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
|
||||
async def process(self, query: pipeline_query.Query, message: str) -> entities.FilterResult:
|
||||
if 'prefix' in query.pipeline_config['trigger']['ignore-rules']:
|
||||
for rule in query.pipeline_config['trigger']['ignore-rules']['prefix']:
|
||||
if message.startswith(rule):
|
||||
|
||||
@@ -3,7 +3,10 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import traceback
|
||||
|
||||
from ..core import app, entities
|
||||
from ..core import app
|
||||
from ..core import entities as core_entities
|
||||
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
|
||||
|
||||
class Controller:
|
||||
@@ -22,19 +25,19 @@ class Controller:
|
||||
"""事件处理循环"""
|
||||
try:
|
||||
while True:
|
||||
selected_query: entities.Query = None
|
||||
selected_query: pipeline_query.Query = None
|
||||
|
||||
# 取请求
|
||||
async with self.ap.query_pool:
|
||||
queries: list[entities.Query] = self.ap.query_pool.queries
|
||||
queries: list[pipeline_query.Query] = self.ap.query_pool.queries
|
||||
|
||||
for query in queries:
|
||||
session = await self.ap.sess_mgr.get_session(query)
|
||||
self.ap.logger.debug(f'Checking query {query} session {session}')
|
||||
|
||||
if not session.semaphore.locked():
|
||||
if not session._semaphore.locked():
|
||||
selected_query = query
|
||||
await session.semaphore.acquire()
|
||||
await session._semaphore.acquire()
|
||||
|
||||
break
|
||||
|
||||
@@ -46,7 +49,7 @@ class Controller:
|
||||
|
||||
if selected_query:
|
||||
|
||||
async def _process_query(selected_query: entities.Query):
|
||||
async def _process_query(selected_query: pipeline_query.Query):
|
||||
async with self.semaphore: # 总并发上限
|
||||
# find pipeline
|
||||
# Here firstly find the bot, then find the pipeline, in case the bot adapter's config is not the latest one.
|
||||
@@ -59,7 +62,7 @@ class Controller:
|
||||
await pipeline.run(selected_query)
|
||||
|
||||
async with self.ap.query_pool:
|
||||
(await self.ap.sess_mgr.get_session(selected_query)).semaphore.release()
|
||||
(await self.ap.sess_mgr.get_session(selected_query))._semaphore.release()
|
||||
# 通知其他协程,有新的请求可以处理了
|
||||
self.ap.query_pool.condition.notify_all()
|
||||
|
||||
@@ -68,8 +71,8 @@ class Controller:
|
||||
kind='query',
|
||||
name=f'query-{selected_query.query_id}',
|
||||
scopes=[
|
||||
entities.LifecycleControlScope.APPLICATION,
|
||||
entities.LifecycleControlScope.PLATFORM,
|
||||
core_entities.LifecycleControlScope.APPLICATION,
|
||||
core_entities.LifecycleControlScope.PLATFORM,
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@@ -3,10 +3,10 @@ from __future__ import annotations
|
||||
import enum
|
||||
import typing
|
||||
|
||||
import pydantic.v1 as pydantic
|
||||
from ..platform.types import message as platform_message
|
||||
import pydantic
|
||||
|
||||
from ..core import entities
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
|
||||
|
||||
class ResultType(enum.Enum):
|
||||
@@ -20,7 +20,7 @@ class ResultType(enum.Enum):
|
||||
class StageProcessResult(pydantic.BaseModel):
|
||||
result_type: ResultType
|
||||
|
||||
new_query: entities.Query
|
||||
new_query: pipeline_query.Query
|
||||
|
||||
user_notice: typing.Optional[
|
||||
typing.Union[
|
||||
|
||||
@@ -5,10 +5,9 @@ import traceback
|
||||
|
||||
from . import strategy
|
||||
from .. import stage, entities
|
||||
from ...core import entities as core_entities
|
||||
from ...platform.types import message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
from ...utils import importutil
|
||||
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
from . import strategies
|
||||
|
||||
importutil.import_modules_in_pkg(strategies)
|
||||
@@ -67,8 +66,8 @@ class LongTextProcessStage(stage.PipelineStage):
|
||||
|
||||
await self.strategy_impl.initialize()
|
||||
|
||||
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||
# Check if it contains non-Plain components
|
||||
async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||
# 检查是否包含非 Plain 组件
|
||||
contains_non_plain = False
|
||||
|
||||
for msg in query.resp_message_chain[-1]:
|
||||
|
||||
@@ -3,9 +3,9 @@ from __future__ import annotations
|
||||
|
||||
|
||||
from .. import strategy as strategy_model
|
||||
from ....core import entities as core_entities
|
||||
from ....platform.types import message as platform_message
|
||||
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
|
||||
ForwardMessageDiaplay = platform_message.ForwardMessageDiaplay
|
||||
Forward = platform_message.Forward
|
||||
@@ -13,7 +13,7 @@ Forward = platform_message.Forward
|
||||
|
||||
@strategy_model.strategy_class('forward')
|
||||
class ForwardComponentStrategy(strategy_model.LongTextStrategy):
|
||||
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
|
||||
async def process(self, message: str, query: pipeline_query.Query) -> list[platform_message.MessageComponent]:
|
||||
display = ForwardMessageDiaplay(
|
||||
title='Group chat history',
|
||||
brief='[Chat history]',
|
||||
|
||||
@@ -8,10 +8,10 @@ import re
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
import functools
|
||||
from ....platform.types import message as platform_message
|
||||
|
||||
from .. import strategy as strategy_model
|
||||
from ....core import entities as core_entities
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
|
||||
|
||||
@strategy_model.strategy_class('image')
|
||||
@@ -27,7 +27,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
|
||||
encoding='utf-8',
|
||||
)
|
||||
|
||||
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
|
||||
async def process(self, message: str, query: pipeline_query.Query) -> list[platform_message.MessageComponent]:
|
||||
img_path = self.text_to_image(
|
||||
text_str=message,
|
||||
save_as='temp/{}.png'.format(int(time.time())),
|
||||
@@ -131,7 +131,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
|
||||
text_str: str,
|
||||
save_as='temp.png',
|
||||
width=800,
|
||||
query: core_entities.Query = None,
|
||||
query: pipeline_query.Query = None,
|
||||
):
|
||||
text_str = text_str.replace('\t', ' ')
|
||||
|
||||
|
||||
@@ -4,8 +4,9 @@ import typing
|
||||
|
||||
|
||||
from ...core import app
|
||||
from ...core import entities as core_entities
|
||||
from ...platform.types import message as platform_message
|
||||
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
|
||||
|
||||
preregistered_strategies: list[typing.Type[LongTextStrategy]] = []
|
||||
@@ -49,8 +50,8 @@ class LongTextStrategy(metaclass=abc.ABCMeta):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
|
||||
"""Process long text
|
||||
async def process(self, message: str, query: pipeline_query.Query) -> list[platform_message.MessageComponent]:
|
||||
"""处理长文本
|
||||
|
||||
If the text length exceeds the threshold, this method will be called.
|
||||
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .. import stage, entities
|
||||
from ...core import entities as core_entities
|
||||
from . import truncator
|
||||
from ...utils import importutil
|
||||
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
from . import truncators
|
||||
|
||||
importutil.import_modules_in_pkg(truncators)
|
||||
@@ -29,8 +28,8 @@ class ConversationMessageTruncator(stage.PipelineStage):
|
||||
else:
|
||||
raise ValueError(f'Unknown truncator: {use_method}')
|
||||
|
||||
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||
"""Process"""
|
||||
async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||
"""处理"""
|
||||
query = await self.trun.truncate(query)
|
||||
|
||||
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
|
||||
@@ -3,8 +3,8 @@ from __future__ import annotations
|
||||
import typing
|
||||
import abc
|
||||
|
||||
from ...core import entities as core_entities, app
|
||||
|
||||
from ...core import app
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
|
||||
preregistered_truncators: list[typing.Type[Truncator]] = []
|
||||
|
||||
@@ -47,7 +47,7 @@ class Truncator(abc.ABC):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def truncate(self, query: core_entities.Query) -> core_entities.Query:
|
||||
async def truncate(self, query: pipeline_query.Query) -> pipeline_query.Query:
|
||||
"""截断
|
||||
|
||||
一般只需要操作query.messages,也可以扩展操作query.prompt, query.user_message。
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .. import truncator
|
||||
from ....core import entities as core_entities
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
|
||||
|
||||
@truncator.truncator_class('round')
|
||||
class RoundTruncator(truncator.Truncator):
|
||||
"""Truncate the conversation message chain to adapt to the LLM message length limit."""
|
||||
|
||||
async def truncate(self, query: core_entities.Query) -> core_entities.Query:
|
||||
"""Truncate"""
|
||||
async def truncate(self, query: pipeline_query.Query) -> pipeline_query.Query:
|
||||
"""截断"""
|
||||
max_round = query.pipeline_config['ai']['local-agent']['max-round']
|
||||
|
||||
temp_messages = []
|
||||
|
||||
@@ -5,14 +5,18 @@ import traceback
|
||||
|
||||
import sqlalchemy
|
||||
|
||||
from ..core import app, entities
|
||||
from ..core import app
|
||||
from . import entities as pipeline_entities
|
||||
from ..entity.persistence import pipeline as persistence_pipeline
|
||||
from . import stage
|
||||
from ..platform.types import message as platform_message, events as platform_events
|
||||
from ..plugin import events
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||
import langbot_plugin.api.entities.events as events
|
||||
from ..utils import importutil
|
||||
|
||||
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
|
||||
from . import (
|
||||
resprule,
|
||||
bansess,
|
||||
@@ -75,17 +79,17 @@ class RuntimePipeline:
|
||||
self.pipeline_entity = pipeline_entity
|
||||
self.stage_containers = stage_containers
|
||||
|
||||
async def run(self, query: entities.Query):
|
||||
async def run(self, query: pipeline_query.Query):
|
||||
query.pipeline_config = self.pipeline_entity.config
|
||||
await self.process_query(query)
|
||||
|
||||
async def _check_output(self, query: entities.Query, result: pipeline_entities.StageProcessResult):
|
||||
async def _check_output(self, query: pipeline_query.Query, result: pipeline_entities.StageProcessResult):
|
||||
"""检查输出"""
|
||||
if result.user_notice:
|
||||
# 处理str类型
|
||||
|
||||
if isinstance(result.user_notice, str):
|
||||
result.user_notice = platform_message.MessageChain(platform_message.Plain(result.user_notice))
|
||||
result.user_notice = platform_message.MessageChain([platform_message.Plain(text=result.user_notice)])
|
||||
elif isinstance(result.user_notice, list):
|
||||
result.user_notice = platform_message.MessageChain(*result.user_notice)
|
||||
|
||||
@@ -99,7 +103,7 @@ class RuntimePipeline:
|
||||
bot_message=query.resp_messages[-1],
|
||||
message=result.user_notice,
|
||||
quote_origin=query.pipeline_config['output']['misc']['quote-origin'],
|
||||
is_final=[msg.is_final for msg in query.resp_messages][0]
|
||||
is_final=[msg.is_final for msg in query.resp_messages][0],
|
||||
)
|
||||
else:
|
||||
await query.adapter.reply_message(
|
||||
@@ -117,7 +121,7 @@ class RuntimePipeline:
|
||||
async def _execute_from_stage(
|
||||
self,
|
||||
stage_index: int,
|
||||
query: entities.Query,
|
||||
query: pipeline_query.Query,
|
||||
):
|
||||
"""从指定阶段开始执行,实现了责任链模式和基于生成器的阶段分叉功能。
|
||||
|
||||
@@ -144,7 +148,7 @@ class RuntimePipeline:
|
||||
while i < len(self.stage_containers):
|
||||
stage_container = self.stage_containers[i]
|
||||
|
||||
query.current_stage = stage_container # 标记到 Query 对象里
|
||||
query.current_stage_name = stage_container.inst_name # 标记到 Query 对象里
|
||||
|
||||
result = stage_container.inst.process(query, stage_container.inst_name)
|
||||
|
||||
@@ -181,26 +185,26 @@ class RuntimePipeline:
|
||||
|
||||
i += 1
|
||||
|
||||
async def process_query(self, query: entities.Query):
|
||||
async def process_query(self, query: pipeline_query.Query):
|
||||
"""处理请求"""
|
||||
try:
|
||||
# ======== 触发 MessageReceived 事件 ========
|
||||
event_type = (
|
||||
events.PersonMessageReceived
|
||||
if query.launcher_type == entities.LauncherTypes.PERSON
|
||||
if query.launcher_type == provider_session.LauncherTypes.PERSON
|
||||
else events.GroupMessageReceived
|
||||
)
|
||||
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=event_type(
|
||||
launcher_type=query.launcher_type.value,
|
||||
launcher_id=query.launcher_id,
|
||||
sender_id=query.sender_id,
|
||||
message_chain=query.message_chain,
|
||||
query=query,
|
||||
)
|
||||
event_obj = event_type(
|
||||
query=query,
|
||||
launcher_type=query.launcher_type.value,
|
||||
launcher_id=query.launcher_id,
|
||||
sender_id=query.sender_id,
|
||||
message_chain=query.message_chain,
|
||||
)
|
||||
|
||||
event_ctx = await self.ap.plugin_connector.emit_event(event_obj)
|
||||
|
||||
if event_ctx.is_prevented_default():
|
||||
return
|
||||
|
||||
@@ -208,11 +212,12 @@ class RuntimePipeline:
|
||||
|
||||
await self._execute_from_stage(0, query)
|
||||
except Exception as e:
|
||||
inst_name = query.current_stage.inst_name if query.current_stage else 'unknown'
|
||||
inst_name = query.current_stage_name if query.current_stage_name else 'unknown'
|
||||
self.ap.logger.error(f'处理请求时出错 query_id={query.query_id} stage={inst_name} : {e}')
|
||||
self.ap.logger.error(f'Traceback: {traceback.format_exc()}')
|
||||
finally:
|
||||
self.ap.logger.debug(f'Query {query.query_id} processed')
|
||||
del self.ap.query_pool.cached_queries[query.query_id]
|
||||
|
||||
|
||||
class PipelineManager:
|
||||
|
||||
@@ -3,10 +3,11 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import typing
|
||||
|
||||
from ..core import entities
|
||||
from ..platform import adapter as msadapter
|
||||
from ..platform.types import message as platform_message
|
||||
from ..platform.types import events as platform_events
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
|
||||
|
||||
|
||||
class QueryPool:
|
||||
@@ -16,7 +17,10 @@ class QueryPool:
|
||||
|
||||
pool_lock: asyncio.Lock
|
||||
|
||||
queries: list[entities.Query]
|
||||
queries: list[pipeline_query.Query]
|
||||
|
||||
cached_queries: dict[int, pipeline_query.Query]
|
||||
"""Cached queries, used for plugin backward api call, will be removed after the query completely processed"""
|
||||
|
||||
condition: asyncio.Condition
|
||||
|
||||
@@ -24,34 +28,38 @@ class QueryPool:
|
||||
self.query_id_counter = 0
|
||||
self.pool_lock = asyncio.Lock()
|
||||
self.queries = []
|
||||
self.cached_queries = {}
|
||||
self.condition = asyncio.Condition(self.pool_lock)
|
||||
|
||||
async def add_query(
|
||||
self,
|
||||
bot_uuid: str,
|
||||
launcher_type: entities.LauncherTypes,
|
||||
launcher_type: provider_session.LauncherTypes,
|
||||
launcher_id: typing.Union[int, str],
|
||||
sender_id: typing.Union[int, str],
|
||||
message_event: platform_events.MessageEvent,
|
||||
message_chain: platform_message.MessageChain,
|
||||
adapter: msadapter.MessagePlatformAdapter,
|
||||
adapter: abstract_platform_adapter.AbstractMessagePlatformAdapter,
|
||||
pipeline_uuid: typing.Optional[str] = None,
|
||||
) -> entities.Query:
|
||||
) -> pipeline_query.Query:
|
||||
async with self.condition:
|
||||
query = entities.Query(
|
||||
query_id = self.query_id_counter
|
||||
query = pipeline_query.Query(
|
||||
bot_uuid=bot_uuid,
|
||||
query_id=self.query_id_counter,
|
||||
query_id=query_id,
|
||||
launcher_type=launcher_type,
|
||||
launcher_id=launcher_id,
|
||||
sender_id=sender_id,
|
||||
message_event=message_event,
|
||||
message_chain=message_chain,
|
||||
variables={},
|
||||
resp_messages=[],
|
||||
resp_message_chain=[],
|
||||
adapter=adapter,
|
||||
pipeline_uuid=pipeline_uuid,
|
||||
)
|
||||
self.queries.append(query)
|
||||
self.cached_queries[query_id] = query
|
||||
self.query_id_counter += 1
|
||||
self.condition.notify_all()
|
||||
|
||||
|
||||
@@ -3,10 +3,10 @@ from __future__ import annotations
|
||||
import datetime
|
||||
|
||||
from .. import stage, entities
|
||||
from ...core import entities as core_entities
|
||||
from ...provider import entities as llm_entities
|
||||
from ...plugin import events
|
||||
from ...platform.types import message as platform_message
|
||||
from langbot_plugin.api.entities.builtin.provider import message as provider_message
|
||||
import langbot_plugin.api.entities.events as events
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
|
||||
|
||||
@stage.stage_class('PreProcessor')
|
||||
@@ -26,7 +26,7 @@ class PreProcessor(stage.PipelineStage):
|
||||
|
||||
async def process(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
query: pipeline_query.Query,
|
||||
stage_inst_name: str,
|
||||
) -> entities.StageProcessResult:
|
||||
"""Process"""
|
||||
@@ -49,80 +49,73 @@ class PreProcessor(stage.PipelineStage):
|
||||
query.bot_uuid,
|
||||
)
|
||||
|
||||
conversation.use_llm_model = llm_model
|
||||
|
||||
# Set query
|
||||
# 设置query
|
||||
query.session = session
|
||||
query.prompt = conversation.prompt.copy()
|
||||
query.messages = conversation.messages.copy()
|
||||
|
||||
query.use_llm_model = llm_model
|
||||
query.use_llm_model_uuid = llm_model.model_entity.uuid
|
||||
|
||||
if selected_runner == 'local-agent':
|
||||
query.use_funcs = (
|
||||
conversation.use_funcs if query.use_llm_model.model_entity.abilities.__contains__('func_call') else None
|
||||
)
|
||||
query.use_funcs = []
|
||||
|
||||
query.variables = {
|
||||
if llm_model.model_entity.abilities.__contains__('func_call'):
|
||||
query.use_funcs = await self.ap.tool_mgr.get_all_tools()
|
||||
|
||||
variables = {
|
||||
'session_id': f'{query.session.launcher_type.value}_{query.session.launcher_id}',
|
||||
'conversation_id': conversation.uuid,
|
||||
'msg_create_time': (
|
||||
int(query.message_event.time) if query.message_event.time else int(datetime.datetime.now().timestamp())
|
||||
),
|
||||
}
|
||||
query.variables.update(variables)
|
||||
|
||||
# Check if this model supports vision, if not, remove all images
|
||||
# TODO this checking should be performed in runner, and in this stage, the image should be reserved
|
||||
if selected_runner == 'local-agent' and not query.use_llm_model.model_entity.abilities.__contains__('vision'):
|
||||
if selected_runner == 'local-agent' and not llm_model.model_entity.abilities.__contains__('vision'):
|
||||
for msg in query.messages:
|
||||
if isinstance(msg.content, list):
|
||||
for me in msg.content:
|
||||
if me.type == 'image_url':
|
||||
msg.content.remove(me)
|
||||
|
||||
content_list: list[llm_entities.ContentElement] = []
|
||||
content_list: list[provider_message.ContentElement] = []
|
||||
|
||||
plain_text = ''
|
||||
qoute_msg = query.pipeline_config['trigger'].get('misc', '').get('combine-quote-message')
|
||||
|
||||
# tidy the content_list
|
||||
# combine all text content into one, and put it in the first position
|
||||
for me in query.message_chain:
|
||||
if isinstance(me, platform_message.Plain):
|
||||
content_list.append(provider_message.ContentElement.from_text(me.text))
|
||||
plain_text += me.text
|
||||
elif isinstance(me, platform_message.Image):
|
||||
if selected_runner != 'local-agent' or query.use_llm_model.model_entity.abilities.__contains__(
|
||||
'vision'
|
||||
):
|
||||
if selected_runner != 'local-agent' or llm_model.model_entity.abilities.__contains__('vision'):
|
||||
if me.base64 is not None:
|
||||
content_list.append(llm_entities.ContentElement.from_image_base64(me.base64))
|
||||
content_list.append(provider_message.ContentElement.from_image_base64(me.base64))
|
||||
elif isinstance(me, platform_message.Quote) and qoute_msg:
|
||||
for msg in me.origin:
|
||||
if isinstance(msg, platform_message.Plain):
|
||||
content_list.append(llm_entities.ContentElement.from_text(msg.text))
|
||||
content_list.append(provider_message.ContentElement.from_text(msg.text))
|
||||
elif isinstance(msg, platform_message.Image):
|
||||
if selected_runner != 'local-agent' or query.use_llm_model.model_entity.abilities.__contains__(
|
||||
'vision'
|
||||
):
|
||||
if selected_runner != 'local-agent' or llm_model.model_entity.abilities.__contains__('vision'):
|
||||
if msg.base64 is not None:
|
||||
content_list.append(llm_entities.ContentElement.from_image_base64(msg.base64))
|
||||
|
||||
content_list.insert(0, llm_entities.ContentElement.from_text(plain_text))
|
||||
content_list.append(provider_message.ContentElement.from_image_base64(msg.base64))
|
||||
|
||||
query.variables['user_message_text'] = plain_text
|
||||
|
||||
query.user_message = llm_entities.Message(role='user', content=content_list)
|
||||
# =========== Trigger event PromptPreProcessing
|
||||
query.user_message = provider_message.Message(role='user', content=content_list)
|
||||
# =========== 触发事件 PromptPreProcessing
|
||||
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=events.PromptPreProcessing(
|
||||
session_name=f'{query.session.launcher_type.value}_{query.session.launcher_id}',
|
||||
default_prompt=query.prompt.messages,
|
||||
prompt=query.messages,
|
||||
query=query,
|
||||
)
|
||||
event = events.PromptPreProcessing(
|
||||
session_name=f'{query.session.launcher_type.value}_{query.session.launcher_id}',
|
||||
default_prompt=query.prompt.messages,
|
||||
prompt=query.messages,
|
||||
query=query,
|
||||
)
|
||||
|
||||
event_ctx = await self.ap.plugin_connector.emit_event(event)
|
||||
|
||||
query.prompt.messages = event_ctx.event.default_prompt
|
||||
query.messages = event_ctx.event.prompt
|
||||
|
||||
|
||||
@@ -3,8 +3,8 @@ from __future__ import annotations
|
||||
import abc
|
||||
|
||||
from ...core import app
|
||||
from ...core import entities as core_entities
|
||||
from .. import entities
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
|
||||
|
||||
class MessageHandler(metaclass=abc.ABCMeta):
|
||||
@@ -19,7 +19,7 @@ class MessageHandler(metaclass=abc.ABCMeta):
|
||||
@abc.abstractmethod
|
||||
async def handle(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
query: pipeline_query.Query,
|
||||
) -> entities.StageProcessResult:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -7,13 +7,15 @@ import traceback
|
||||
|
||||
from .. import handler
|
||||
from ... import entities
|
||||
from ....core import entities as core_entities
|
||||
from ....provider import runner as runner_module
|
||||
from ....plugin import events
|
||||
|
||||
from ....platform.types import message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.events as events
|
||||
from ....utils import importutil
|
||||
from ....provider import runners
|
||||
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
|
||||
|
||||
importutil.import_modules_in_pkg(runners)
|
||||
|
||||
@@ -21,7 +23,7 @@ importutil.import_modules_in_pkg(runners)
|
||||
class ChatMessageHandler(handler.MessageHandler):
|
||||
async def handle(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
query: pipeline_query.Query,
|
||||
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
|
||||
"""处理"""
|
||||
# 调API
|
||||
@@ -30,19 +32,20 @@ class ChatMessageHandler(handler.MessageHandler):
|
||||
# 触发插件事件
|
||||
event_class = (
|
||||
events.PersonNormalMessageReceived
|
||||
if query.launcher_type == core_entities.LauncherTypes.PERSON
|
||||
if query.launcher_type == provider_session.LauncherTypes.PERSON
|
||||
else events.GroupNormalMessageReceived
|
||||
)
|
||||
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=event_class(
|
||||
launcher_type=query.launcher_type.value,
|
||||
launcher_id=query.launcher_id,
|
||||
sender_id=query.sender_id,
|
||||
text_message=str(query.message_chain),
|
||||
query=query,
|
||||
)
|
||||
event = event_class(
|
||||
launcher_type=query.launcher_type.value,
|
||||
launcher_id=query.launcher_id,
|
||||
sender_id=query.sender_id,
|
||||
text_message=str(query.message_chain),
|
||||
query=query,
|
||||
)
|
||||
|
||||
event_ctx = await self.ap.plugin_connector.emit_event(event)
|
||||
|
||||
is_create_card = False # 判断下是否需要创建流式卡片
|
||||
if event_ctx.is_prevented_default():
|
||||
if event_ctx.event.reply is not None:
|
||||
@@ -120,4 +123,4 @@ class ChatMessageHandler(handler.MessageHandler):
|
||||
)
|
||||
finally:
|
||||
# TODO statistics
|
||||
pass
|
||||
pass
|
||||
|
||||
@@ -4,16 +4,17 @@ import typing
|
||||
|
||||
from .. import handler
|
||||
from ... import entities
|
||||
from ....core import entities as core_entities
|
||||
from ....provider import entities as llm_entities
|
||||
from ....plugin import events
|
||||
from ....platform.types import message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
import langbot_plugin.api.entities.events as events
|
||||
|
||||
|
||||
class CommandHandler(handler.MessageHandler):
|
||||
async def handle(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
query: pipeline_query.Query,
|
||||
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
|
||||
"""Process"""
|
||||
|
||||
@@ -28,23 +29,23 @@ class CommandHandler(handler.MessageHandler):
|
||||
|
||||
event_class = (
|
||||
events.PersonCommandSent
|
||||
if query.launcher_type == core_entities.LauncherTypes.PERSON
|
||||
if query.launcher_type == provider_session.LauncherTypes.PERSON
|
||||
else events.GroupCommandSent
|
||||
)
|
||||
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=event_class(
|
||||
launcher_type=query.launcher_type.value,
|
||||
launcher_id=query.launcher_id,
|
||||
sender_id=query.sender_id,
|
||||
command=spt[0],
|
||||
params=spt[1:] if len(spt) > 1 else [],
|
||||
text_message=str(query.message_chain),
|
||||
is_admin=(privilege == 2),
|
||||
query=query,
|
||||
)
|
||||
event = event_class(
|
||||
launcher_type=query.launcher_type.value,
|
||||
launcher_id=query.launcher_id,
|
||||
sender_id=query.sender_id,
|
||||
command=spt[0],
|
||||
params=spt[1:] if len(spt) > 1 else [],
|
||||
text_message=str(query.message_chain),
|
||||
is_admin=(privilege == 2),
|
||||
query=query,
|
||||
)
|
||||
|
||||
event_ctx = await self.ap.plugin_connector.emit_event(event)
|
||||
|
||||
if event_ctx.is_prevented_default():
|
||||
if event_ctx.event.reply is not None:
|
||||
mc = platform_message.MessageChain(event_ctx.event.reply)
|
||||
@@ -64,7 +65,7 @@ class CommandHandler(handler.MessageHandler):
|
||||
async for ret in self.ap.cmd_mgr.execute(command_text=command_text, query=query, session=session):
|
||||
if ret.error is not None:
|
||||
query.resp_messages.append(
|
||||
llm_entities.Message(
|
||||
provider_message.Message(
|
||||
role='command',
|
||||
content=str(ret.error),
|
||||
)
|
||||
@@ -73,17 +74,20 @@ class CommandHandler(handler.MessageHandler):
|
||||
self.ap.logger.info(f'Command({query.query_id}) error: {self.cut_str(str(ret.error))}')
|
||||
|
||||
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
elif ret.text is not None or ret.image_url is not None:
|
||||
content: list[llm_entities.ContentElement] = []
|
||||
elif ret.text is not None or ret.image_url is not None or ret.image_base64 is not None:
|
||||
content: list[provider_message.ContentElement] = []
|
||||
|
||||
if ret.text is not None:
|
||||
content.append(llm_entities.ContentElement.from_text(ret.text))
|
||||
content.append(provider_message.ContentElement.from_text(ret.text))
|
||||
|
||||
if ret.image_url is not None:
|
||||
content.append(llm_entities.ContentElement.from_image_url(ret.image_url))
|
||||
content.append(provider_message.ContentElement.from_image_url(ret.image_url))
|
||||
|
||||
if ret.image_base64 is not None:
|
||||
content.append(provider_message.ContentElement.from_image_base64(ret.image_base64))
|
||||
|
||||
query.resp_messages.append(
|
||||
llm_entities.Message(
|
||||
provider_message.Message(
|
||||
role='command',
|
||||
content=content,
|
||||
)
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ...core import entities as core_entities
|
||||
from . import handler
|
||||
from .handlers import chat, command
|
||||
from .. import entities
|
||||
from .. import stage
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
|
||||
|
||||
@stage.stage_class('MessageProcessor')
|
||||
@@ -30,7 +30,7 @@ class Processor(stage.PipelineStage):
|
||||
|
||||
async def process(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
query: pipeline_query.Query,
|
||||
stage_inst_name: str,
|
||||
) -> entities.StageProcessResult:
|
||||
"""Process"""
|
||||
|
||||
@@ -2,7 +2,8 @@ from __future__ import annotations
|
||||
import abc
|
||||
import typing
|
||||
|
||||
from ...core import app, entities as core_entities
|
||||
from ...core import app
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
|
||||
|
||||
preregistered_algos: list[typing.Type[ReteLimitAlgo]] = []
|
||||
@@ -33,7 +34,7 @@ class ReteLimitAlgo(metaclass=abc.ABCMeta):
|
||||
@abc.abstractmethod
|
||||
async def require_access(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
query: pipeline_query.Query,
|
||||
launcher_type: str,
|
||||
launcher_id: typing.Union[int, str],
|
||||
) -> bool:
|
||||
@@ -53,7 +54,7 @@ class ReteLimitAlgo(metaclass=abc.ABCMeta):
|
||||
@abc.abstractmethod
|
||||
async def release_access(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
query: pipeline_query.Query,
|
||||
launcher_type: str,
|
||||
launcher_id: typing.Union[int, str],
|
||||
):
|
||||
|
||||
@@ -3,7 +3,7 @@ import asyncio
|
||||
import time
|
||||
import typing
|
||||
from .. import algo
|
||||
from ....core import entities as core_entities
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
|
||||
|
||||
# 固定窗口算法
|
||||
@@ -32,7 +32,7 @@ class FixedWindowAlgo(algo.ReteLimitAlgo):
|
||||
|
||||
async def require_access(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
query: pipeline_query.Query,
|
||||
launcher_type: str,
|
||||
launcher_id: typing.Union[int, str],
|
||||
) -> bool:
|
||||
@@ -91,7 +91,7 @@ class FixedWindowAlgo(algo.ReteLimitAlgo):
|
||||
|
||||
async def release_access(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
query: pipeline_query.Query,
|
||||
launcher_type: str,
|
||||
launcher_id: typing.Union[int, str],
|
||||
):
|
||||
|
||||
@@ -4,9 +4,10 @@ import typing
|
||||
|
||||
from .. import entities, stage
|
||||
from . import algo
|
||||
from ...core import entities as core_entities
|
||||
from ...utils import importutil
|
||||
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
|
||||
from . import algos
|
||||
|
||||
importutil.import_modules_in_pkg(algos)
|
||||
@@ -39,7 +40,7 @@ class RateLimit(stage.PipelineStage):
|
||||
|
||||
async def process(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
query: pipeline_query.Query,
|
||||
stage_inst_name: str,
|
||||
) -> typing.Union[
|
||||
entities.StageProcessResult,
|
||||
|
||||
@@ -4,22 +4,19 @@ import random
|
||||
import asyncio
|
||||
|
||||
|
||||
from ...platform.types import events as platform_events
|
||||
from ...platform.types import message as platform_message
|
||||
|
||||
from ...provider import entities as llm_entities
|
||||
|
||||
|
||||
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
|
||||
from .. import stage, entities
|
||||
from ...core import entities as core_entities
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
|
||||
|
||||
@stage.stage_class('SendResponseBackStage')
|
||||
class SendResponseBackStage(stage.PipelineStage):
|
||||
"""发送响应消息"""
|
||||
|
||||
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||
async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||
"""处理"""
|
||||
|
||||
random_range = (
|
||||
@@ -40,7 +37,7 @@ class SendResponseBackStage(stage.PipelineStage):
|
||||
|
||||
quote_origin = query.pipeline_config['output']['misc']['quote-origin']
|
||||
|
||||
has_chunks = any(isinstance(msg, llm_entities.MessageChunk) for msg in query.resp_messages)
|
||||
has_chunks = any(isinstance(msg, provider_message.MessageChunk) for msg in query.resp_messages)
|
||||
# TODO 命令与流式的兼容性问题
|
||||
if await query.adapter.is_stream_output_supported() and has_chunks:
|
||||
is_final = [msg.is_final for msg in query.resp_messages][0]
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import pydantic.v1 as pydantic
|
||||
import pydantic
|
||||
|
||||
from ...platform.types import message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
|
||||
|
||||
class RuleJudgeResult(pydantic.BaseModel):
|
||||
|
||||
@@ -4,9 +4,10 @@ from __future__ import annotations
|
||||
from . import rule
|
||||
|
||||
from .. import stage, entities
|
||||
from ...core import entities as core_entities
|
||||
from ...utils import importutil
|
||||
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
|
||||
from . import rules
|
||||
|
||||
importutil.import_modules_in_pkg(rules)
|
||||
@@ -32,7 +33,7 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
|
||||
await rule_inst.initialize()
|
||||
self.rule_matchers.append(rule_inst)
|
||||
|
||||
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||
async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||
if query.launcher_type.value != 'group': # 只处理群消息
|
||||
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
|
||||
|
||||
@@ -2,10 +2,11 @@ from __future__ import annotations
|
||||
import abc
|
||||
import typing
|
||||
|
||||
from ...core import app, entities as core_entities
|
||||
from ...core import app
|
||||
from . import entities
|
||||
|
||||
from ...platform.types import message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
|
||||
|
||||
preregisetered_rules: list[typing.Type[GroupRespondRule]] = []
|
||||
@@ -39,7 +40,7 @@ class GroupRespondRule(metaclass=abc.ABCMeta):
|
||||
message_text: str,
|
||||
message_chain: platform_message.MessageChain,
|
||||
rule_dict: dict,
|
||||
query: core_entities.Query,
|
||||
query: pipeline_query.Query,
|
||||
) -> entities.RuleJudgeResult:
|
||||
"""判断消息是否匹配规则"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -3,8 +3,8 @@ from __future__ import annotations
|
||||
|
||||
from .. import rule as rule_model
|
||||
from .. import entities
|
||||
from ....core import entities as core_entities
|
||||
from ....platform.types import message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
|
||||
|
||||
@rule_model.rule_class('at-bot')
|
||||
@@ -14,19 +14,28 @@ class AtBotRule(rule_model.GroupRespondRule):
|
||||
message_text: str,
|
||||
message_chain: platform_message.MessageChain,
|
||||
rule_dict: dict,
|
||||
query: core_entities.Query,
|
||||
query: pipeline_query.Query,
|
||||
) -> entities.RuleJudgeResult:
|
||||
if message_chain.has(platform_message.At(query.adapter.bot_account_id)) and rule_dict['at']:
|
||||
message_chain.remove(platform_message.At(query.adapter.bot_account_id))
|
||||
def remove_at(message_chain: platform_message.MessageChain):
|
||||
for component in message_chain.root:
|
||||
if isinstance(component, platform_message.At) and component.target == query.adapter.bot_account_id:
|
||||
message_chain.remove(component)
|
||||
break
|
||||
|
||||
if message_chain.has(
|
||||
platform_message.At(query.adapter.bot_account_id)
|
||||
): # 回复消息时会at两次,检查并删除重复的
|
||||
message_chain.remove(platform_message.At(query.adapter.bot_account_id))
|
||||
remove_at(message_chain)
|
||||
remove_at(message_chain) # 回复消息时会at两次,检查并删除重复的
|
||||
|
||||
return entities.RuleJudgeResult(
|
||||
matching=True,
|
||||
replacement=message_chain,
|
||||
)
|
||||
# if message_chain.has(platform_message.At(query.adapter.bot_account_id)) and rule_dict['at']:
|
||||
# message_chain.remove(platform_message.At(query.adapter.bot_account_id))
|
||||
|
||||
# if message_chain.has(
|
||||
# platform_message.At(query.adapter.bot_account_id)
|
||||
# ): # 回复消息时会at两次,检查并删除重复的
|
||||
# message_chain.remove(platform_message.At(query.adapter.bot_account_id))
|
||||
|
||||
# return entities.RuleJudgeResult(
|
||||
# matching=True,
|
||||
# replacement=message_chain,
|
||||
# )
|
||||
|
||||
return entities.RuleJudgeResult(matching=False, replacement=message_chain)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from .. import rule as rule_model
|
||||
from .. import entities
|
||||
from ....core import entities as core_entities
|
||||
from ....platform.types import message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
|
||||
|
||||
@rule_model.rule_class('prefix')
|
||||
@@ -11,7 +11,7 @@ class PrefixRule(rule_model.GroupRespondRule):
|
||||
message_text: str,
|
||||
message_chain: platform_message.MessageChain,
|
||||
rule_dict: dict,
|
||||
query: core_entities.Query,
|
||||
query: pipeline_query.Query,
|
||||
) -> entities.RuleJudgeResult:
|
||||
prefixes = rule_dict['prefix']
|
||||
|
||||
|
||||
@@ -3,8 +3,8 @@ import random
|
||||
|
||||
from .. import rule as rule_model
|
||||
from .. import entities
|
||||
from ....core import entities as core_entities
|
||||
from ....platform.types import message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
|
||||
|
||||
@rule_model.rule_class('random')
|
||||
@@ -14,7 +14,7 @@ class RandomRespRule(rule_model.GroupRespondRule):
|
||||
message_text: str,
|
||||
message_chain: platform_message.MessageChain,
|
||||
rule_dict: dict,
|
||||
query: core_entities.Query,
|
||||
query: pipeline_query.Query,
|
||||
) -> entities.RuleJudgeResult:
|
||||
random_rate = rule_dict['random']
|
||||
|
||||
|
||||
@@ -3,8 +3,8 @@ import re
|
||||
|
||||
from .. import rule as rule_model
|
||||
from .. import entities
|
||||
from ....core import entities as core_entities
|
||||
from ....platform.types import message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
|
||||
|
||||
@rule_model.rule_class('regexp')
|
||||
@@ -14,7 +14,7 @@ class RegExpRule(rule_model.GroupRespondRule):
|
||||
message_text: str,
|
||||
message_chain: platform_message.MessageChain,
|
||||
rule_dict: dict,
|
||||
query: core_entities.Query,
|
||||
query: pipeline_query.Query,
|
||||
) -> entities.RuleJudgeResult:
|
||||
regexps = rule_dict['regexp']
|
||||
|
||||
|
||||
@@ -3,8 +3,9 @@ from __future__ import annotations
|
||||
import abc
|
||||
import typing
|
||||
|
||||
from ..core import app, entities as core_entities
|
||||
from ..core import app
|
||||
from . import entities
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
|
||||
|
||||
preregistered_stages: dict[str, type[PipelineStage]] = {}
|
||||
@@ -33,7 +34,7 @@ class PipelineStage(metaclass=abc.ABCMeta):
|
||||
@abc.abstractmethod
|
||||
async def process(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
query: pipeline_query.Query,
|
||||
stage_inst_name: str,
|
||||
) -> typing.Union[
|
||||
entities.StageProcessResult,
|
||||
|
||||
@@ -2,12 +2,12 @@ from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
|
||||
from ...core import entities as core_entities
|
||||
from .. import entities
|
||||
from .. import stage
|
||||
from ...plugin import events
|
||||
from ...platform.types import message as platform_message
|
||||
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
import langbot_plugin.api.entities.events as events
|
||||
|
||||
|
||||
@stage.stage_class('ResponseWrapper')
|
||||
@@ -25,7 +25,7 @@ class ResponseWrapper(stage.PipelineStage):
|
||||
|
||||
async def process(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
query: pipeline_query.Query,
|
||||
stage_inst_name: str,
|
||||
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
|
||||
"""处理"""
|
||||
@@ -58,21 +58,22 @@ class ResponseWrapper(stage.PipelineStage):
|
||||
reply_text = str(result.get_content_platform_message_chain())
|
||||
|
||||
# ============= 触发插件事件 ===============
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=events.NormalMessageResponded(
|
||||
launcher_type=query.launcher_type.value,
|
||||
launcher_id=query.launcher_id,
|
||||
sender_id=query.sender_id,
|
||||
session=session,
|
||||
prefix='',
|
||||
response_text=reply_text,
|
||||
finish_reason='stop',
|
||||
funcs_called=[fc.function.name for fc in result.tool_calls]
|
||||
if result.tool_calls is not None
|
||||
else [],
|
||||
query=query,
|
||||
)
|
||||
event = events.NormalMessageResponded(
|
||||
launcher_type=query.launcher_type.value,
|
||||
launcher_id=query.launcher_id,
|
||||
sender_id=query.sender_id,
|
||||
session=session,
|
||||
prefix='',
|
||||
response_text=reply_text,
|
||||
finish_reason='stop',
|
||||
funcs_called=[fc.function.name for fc in result.tool_calls]
|
||||
if result.tool_calls is not None
|
||||
else [],
|
||||
query=query,
|
||||
)
|
||||
|
||||
event_ctx = await self.ap.plugin_connector.emit_event(event)
|
||||
|
||||
if event_ctx.is_prevented_default():
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT,
|
||||
@@ -96,26 +97,26 @@ class ResponseWrapper(stage.PipelineStage):
|
||||
reply_text = f'调用函数 {".".join(function_names)}...'
|
||||
|
||||
query.resp_message_chain.append(
|
||||
platform_message.MessageChain([platform_message.Plain(reply_text)])
|
||||
platform_message.MessageChain([platform_message.Plain(text=reply_text)])
|
||||
)
|
||||
|
||||
if query.pipeline_config['output']['misc']['track-function-calls']:
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=events.NormalMessageResponded(
|
||||
launcher_type=query.launcher_type.value,
|
||||
launcher_id=query.launcher_id,
|
||||
sender_id=query.sender_id,
|
||||
session=session,
|
||||
prefix='',
|
||||
response_text=reply_text,
|
||||
finish_reason='stop',
|
||||
funcs_called=[fc.function.name for fc in result.tool_calls]
|
||||
if result.tool_calls is not None
|
||||
else [],
|
||||
query=query,
|
||||
)
|
||||
event = events.NormalMessageResponded(
|
||||
launcher_type=query.launcher_type.value,
|
||||
launcher_id=query.launcher_id,
|
||||
sender_id=query.sender_id,
|
||||
session=session,
|
||||
prefix='',
|
||||
response_text=reply_text,
|
||||
finish_reason='stop',
|
||||
funcs_called=[fc.function.name for fc in result.tool_calls]
|
||||
if result.tool_calls is not None
|
||||
else [],
|
||||
query=query,
|
||||
)
|
||||
|
||||
event_ctx = await self.ap.plugin_connector.emit_event(event)
|
||||
|
||||
if event_ctx.is_prevented_default():
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT,
|
||||
@@ -124,12 +125,12 @@ class ResponseWrapper(stage.PipelineStage):
|
||||
else:
|
||||
if event_ctx.event.reply is not None:
|
||||
query.resp_message_chain.append(
|
||||
platform_message.MessageChain(event_ctx.event.reply)
|
||||
platform_message.MessageChain(text=event_ctx.event.reply)
|
||||
)
|
||||
|
||||
else:
|
||||
query.resp_message_chain.append(
|
||||
platform_message.MessageChain([platform_message.Plain(reply_text)])
|
||||
platform_message.MessageChain([platform_message.Plain(text=reply_text)])
|
||||
)
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
|
||||
@@ -1,190 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
# MessageSource的适配器
|
||||
import typing
|
||||
import abc
|
||||
|
||||
|
||||
from ..core import app
|
||||
from .types import message as platform_message
|
||||
from .types import events as platform_events
|
||||
from .logger import EventLogger
|
||||
|
||||
|
||||
class MessagePlatformAdapter(metaclass=abc.ABCMeta):
|
||||
"""消息平台适配器基类"""
|
||||
|
||||
name: str
|
||||
|
||||
bot_account_id: int
|
||||
"""机器人账号ID,需要在初始化时设置"""
|
||||
|
||||
config: dict
|
||||
|
||||
ap: app.Application
|
||||
|
||||
logger: EventLogger
|
||||
|
||||
def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
|
||||
"""初始化适配器
|
||||
|
||||
Args:
|
||||
config (dict): 对应的配置
|
||||
ap (app.Application): 应用上下文
|
||||
"""
|
||||
self.config = config
|
||||
self.ap = ap
|
||||
self.logger = logger
|
||||
|
||||
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
|
||||
"""主动发送消息
|
||||
|
||||
Args:
|
||||
target_type (str): 目标类型,`person`或`group`
|
||||
target_id (str): 目标ID
|
||||
message (platform.types.MessageChain): 消息链
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def reply_message(
|
||||
self,
|
||||
message_source: platform_events.MessageEvent,
|
||||
message: platform_message.MessageChain,
|
||||
quote_origin: bool = False,
|
||||
):
|
||||
"""回复消息
|
||||
|
||||
Args:
|
||||
message_source (platform.types.MessageEvent): 消息源事件
|
||||
message (platform.types.MessageChain): 消息链
|
||||
quote_origin (bool, optional): 是否引用原消息. Defaults to False.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def reply_message_chunk(
|
||||
self,
|
||||
message_source: platform_events.MessageEvent,
|
||||
bot_message: dict,
|
||||
message: platform_message.MessageChain,
|
||||
quote_origin: bool = False,
|
||||
is_final: bool = False,
|
||||
):
|
||||
"""回复消息(流式输出)
|
||||
Args:
|
||||
message_source (platform.types.MessageEvent): 消息源事件
|
||||
message_id (int): 消息ID
|
||||
message (platform.types.MessageChain): 消息链
|
||||
quote_origin (bool, optional): 是否引用原消息. Defaults to False.
|
||||
is_final (bool, optional): 流式是否结束. Defaults to False.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def create_message_card(self, message_id: typing.Type[str, int], event: platform_events.MessageEvent) -> bool:
|
||||
"""创建卡片消息
|
||||
Args:
|
||||
message_id (str): 消息ID
|
||||
event (platform_events.MessageEvent): 消息源事件
|
||||
"""
|
||||
return False
|
||||
|
||||
async def is_muted(self, group_id: int) -> bool:
|
||||
"""获取账号是否在指定群被禁言"""
|
||||
raise NotImplementedError
|
||||
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None],
|
||||
):
|
||||
"""注册事件监听器
|
||||
|
||||
Args:
|
||||
event_type (typing.Type[platform.types.Event]): 事件类型
|
||||
callback (typing.Callable[[platform.types.Event], None]): 回调函数,接收一个参数,为事件
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def unregister_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None],
|
||||
):
|
||||
"""注销事件监听器
|
||||
|
||||
Args:
|
||||
event_type (typing.Type[platform.types.Event]): 事件类型
|
||||
callback (typing.Callable[[platform.types.Event], None]): 回调函数,接收一个参数,为事件
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def run_async(self):
|
||||
"""异步运行"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def is_stream_output_supported(self) -> bool:
|
||||
"""是否支持流式输出"""
|
||||
return False
|
||||
|
||||
async def kill(self) -> bool:
|
||||
"""关闭适配器
|
||||
|
||||
Returns:
|
||||
bool: 是否成功关闭,热重载时若此函数返回False则不会重载MessageSource底层
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MessageConverter:
|
||||
"""消息链转换器基类"""
|
||||
|
||||
@staticmethod
|
||||
def yiri2target(message_chain: platform_message.MessageChain):
|
||||
"""将源平台消息链转换为目标平台消息链
|
||||
|
||||
Args:
|
||||
message_chain (platform.types.MessageChain): 源平台消息链
|
||||
|
||||
Returns:
|
||||
typing.Any: 目标平台消息链
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def target2yiri(message_chain: typing.Any) -> platform_message.MessageChain:
|
||||
"""将目标平台消息链转换为源平台消息链
|
||||
|
||||
Args:
|
||||
message_chain (typing.Any): 目标平台消息链
|
||||
|
||||
Returns:
|
||||
platform.types.MessageChain: 源平台消息链
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class EventConverter:
|
||||
"""事件转换器基类"""
|
||||
|
||||
@staticmethod
|
||||
def yiri2target(event: typing.Type[platform_events.Event]):
|
||||
"""将源平台事件转换为目标平台事件
|
||||
|
||||
Args:
|
||||
event (typing.Type[platform.types.Event]): 源平台事件
|
||||
|
||||
Returns:
|
||||
typing.Any: 目标平台事件
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def target2yiri(event: typing.Any) -> platform_events.Event:
|
||||
"""将目标平台事件的调用参数转换为源平台的事件参数对象
|
||||
|
||||
Args:
|
||||
event (typing.Any): 目标平台事件
|
||||
|
||||
Returns:
|
||||
typing.Type[platform.types.Event]: 源平台事件
|
||||
"""
|
||||
raise NotImplementedError
|
||||
@@ -1,14 +0,0 @@
|
||||
apiVersion: v1
|
||||
kind: ComponentTemplate
|
||||
metadata:
|
||||
name: MessagePlatformAdapter
|
||||
label:
|
||||
en_US: Message Platform Adapter
|
||||
zh_Hans: 消息平台适配器模板类
|
||||
spec:
|
||||
type:
|
||||
- python
|
||||
execution:
|
||||
python:
|
||||
path: ./adapter.py
|
||||
attr: MessagePlatformAdapter
|
||||
@@ -1,15 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import asyncio
|
||||
import traceback
|
||||
import sqlalchemy
|
||||
|
||||
# FriendMessage, Image, MessageChain, Plain
|
||||
from . import adapter as msadapter
|
||||
|
||||
from ..core import app, entities as core_entities, taskmgr
|
||||
from .types import events as platform_events, message as platform_message
|
||||
|
||||
from ..discover import engine
|
||||
|
||||
@@ -19,10 +14,10 @@ from ..entity.errors import platform as platform_errors
|
||||
|
||||
from .logger import EventLogger
|
||||
|
||||
# 处理 3.4 移除了 YiriMirai 之后,插件的兼容性问题
|
||||
from . import types as mirai
|
||||
|
||||
sys.modules['mirai'] = mirai
|
||||
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
|
||||
|
||||
|
||||
class RuntimeBot:
|
||||
@@ -34,7 +29,7 @@ class RuntimeBot:
|
||||
|
||||
enable: bool
|
||||
|
||||
adapter: msadapter.MessagePlatformAdapter
|
||||
adapter: abstract_platform_adapter.AbstractMessagePlatformAdapter
|
||||
|
||||
task_wrapper: taskmgr.TaskWrapper
|
||||
|
||||
@@ -46,7 +41,7 @@ class RuntimeBot:
|
||||
self,
|
||||
ap: app.Application,
|
||||
bot_entity: persistence_bot.Bot,
|
||||
adapter: msadapter.MessagePlatformAdapter,
|
||||
adapter: abstract_platform_adapter.AbstractMessagePlatformAdapter,
|
||||
logger: EventLogger,
|
||||
):
|
||||
self.ap = ap
|
||||
@@ -59,7 +54,7 @@ class RuntimeBot:
|
||||
async def initialize(self):
|
||||
async def on_friend_message(
|
||||
event: platform_events.FriendMessage,
|
||||
adapter: msadapter.MessagePlatformAdapter,
|
||||
adapter: abstract_platform_adapter.AbstractMessagePlatformAdapter,
|
||||
):
|
||||
image_components = [
|
||||
component for component in event.message_chain if isinstance(component, platform_message.Image)
|
||||
@@ -73,7 +68,7 @@ class RuntimeBot:
|
||||
|
||||
await self.ap.query_pool.add_query(
|
||||
bot_uuid=self.bot_entity.uuid,
|
||||
launcher_type=core_entities.LauncherTypes.PERSON,
|
||||
launcher_type=provider_session.LauncherTypes.PERSON,
|
||||
launcher_id=event.sender.id,
|
||||
sender_id=event.sender.id,
|
||||
message_event=event,
|
||||
@@ -84,7 +79,7 @@ class RuntimeBot:
|
||||
|
||||
async def on_group_message(
|
||||
event: platform_events.GroupMessage,
|
||||
adapter: msadapter.MessagePlatformAdapter,
|
||||
adapter: abstract_platform_adapter.AbstractMessagePlatformAdapter,
|
||||
):
|
||||
image_components = [
|
||||
component for component in event.message_chain if isinstance(component, platform_message.Image)
|
||||
@@ -98,7 +93,7 @@ class RuntimeBot:
|
||||
|
||||
await self.ap.query_pool.add_query(
|
||||
bot_uuid=self.bot_entity.uuid,
|
||||
launcher_type=core_entities.LauncherTypes.GROUP,
|
||||
launcher_type=provider_session.LauncherTypes.GROUP,
|
||||
launcher_id=event.group.id,
|
||||
sender_id=event.sender.id,
|
||||
message_event=event,
|
||||
@@ -153,7 +148,7 @@ class PlatformManager:
|
||||
|
||||
adapter_components: list[engine.Component]
|
||||
|
||||
adapter_dict: dict[str, type[msadapter.MessagePlatformAdapter]]
|
||||
adapter_dict: dict[str, type[abstract_platform_adapter.AbstractMessagePlatformAdapter]]
|
||||
|
||||
def __init__(self, ap: app.Application = None):
|
||||
self.ap = ap
|
||||
@@ -163,7 +158,7 @@ class PlatformManager:
|
||||
|
||||
async def initialize(self):
|
||||
self.adapter_components = self.ap.discover.get_components_by_kind('MessagePlatformAdapter')
|
||||
adapter_dict: dict[str, type[msadapter.MessagePlatformAdapter]] = {}
|
||||
adapter_dict: dict[str, type[abstract_platform_adapter.AbstractMessagePlatformAdapter]] = {}
|
||||
for component in self.adapter_components:
|
||||
adapter_dict[component.metadata.name] = component.get_python_component_class()
|
||||
self.adapter_dict = adapter_dict
|
||||
@@ -174,8 +169,9 @@ class PlatformManager:
|
||||
webchat_logger = EventLogger(name='webchat-adapter', ap=self.ap)
|
||||
webchat_adapter_inst = webchat_adapter_class(
|
||||
{},
|
||||
self.ap,
|
||||
webchat_logger,
|
||||
ap=self.ap,
|
||||
is_stream=False,
|
||||
)
|
||||
|
||||
self.webchat_proxy_bot = RuntimeBot(
|
||||
@@ -195,7 +191,7 @@ class PlatformManager:
|
||||
|
||||
await self.load_bots_from_db()
|
||||
|
||||
def get_running_adapters(self) -> list[msadapter.MessagePlatformAdapter]:
|
||||
def get_running_adapters(self) -> list[abstract_platform_adapter.AbstractMessagePlatformAdapter]:
|
||||
return [bot.adapter for bot in self.bots if bot.enable]
|
||||
|
||||
async def load_bots_from_db(self):
|
||||
@@ -233,7 +229,6 @@ class PlatformManager:
|
||||
|
||||
adapter_inst = self.adapter_dict[bot_entity.adapter](
|
||||
bot_entity.adapter_config,
|
||||
self.ap,
|
||||
logger,
|
||||
)
|
||||
|
||||
@@ -276,43 +271,6 @@ class PlatformManager:
|
||||
return component
|
||||
return None
|
||||
|
||||
async def write_back_config(
|
||||
self,
|
||||
adapter_name: str,
|
||||
adapter_inst: msadapter.MessagePlatformAdapter,
|
||||
config: dict,
|
||||
):
|
||||
# index = -2
|
||||
|
||||
# for i, adapter in enumerate(self.adapters):
|
||||
# if adapter == adapter_inst:
|
||||
# index = i
|
||||
# break
|
||||
|
||||
# if index == -2:
|
||||
# raise Exception('平台适配器未找到')
|
||||
|
||||
# # 只修改启用的适配器
|
||||
# real_index = -1
|
||||
|
||||
# for i, adapter in enumerate(self.ap.platform_cfg.data['platform-adapters']):
|
||||
# if adapter['enable']:
|
||||
# index -= 1
|
||||
# if index == -1:
|
||||
# real_index = i
|
||||
# break
|
||||
|
||||
# new_cfg = {
|
||||
# 'adapter': adapter_name,
|
||||
# 'enable': True,
|
||||
# **config
|
||||
# }
|
||||
# self.ap.platform_cfg.data['platform-adapters'][real_index] = new_cfg
|
||||
# await self.ap.platform_cfg.dump_config()
|
||||
|
||||
# TODO implement this
|
||||
pass
|
||||
|
||||
async def run(self):
|
||||
# This method will only be called when the application launching
|
||||
await self.webchat_proxy_bot.run()
|
||||
|
||||
@@ -9,7 +9,8 @@ import traceback
|
||||
import uuid
|
||||
|
||||
from ..core import app
|
||||
from .types import message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.definition.abstract.platform.event_logger as abstract_platform_event_logger
|
||||
|
||||
|
||||
class EventLogLevel(enum.Enum):
|
||||
@@ -55,7 +56,7 @@ MAX_LOG_COUNT = 200
|
||||
DELETE_COUNT_PER_TIME = 50
|
||||
|
||||
|
||||
class EventLogger:
|
||||
class EventLogger(abstract_platform_event_logger.AbstractEventLogger):
|
||||
"""used for logging bot events"""
|
||||
|
||||
ap: app.Application
|
||||
|
||||
@@ -5,17 +5,17 @@ import traceback
|
||||
import datetime
|
||||
|
||||
import aiocqhttp
|
||||
import pydantic
|
||||
|
||||
from .. import adapter
|
||||
from ...core import app
|
||||
from ..types import message as platform_message
|
||||
from ..types import events as platform_events
|
||||
from ..types import entities as platform_entities
|
||||
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
|
||||
from ...utils import image
|
||||
from ..logger import EventLogger
|
||||
import langbot_plugin.api.definition.abstract.platform.event_logger as abstract_platform_logger
|
||||
|
||||
|
||||
class AiocqhttpMessageConverter(adapter.MessageConverter):
|
||||
class AiocqhttpMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||
@staticmethod
|
||||
async def yiri2target(
|
||||
message_chain: platform_message.MessageChain,
|
||||
@@ -266,20 +266,21 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
|
||||
await process_message_data(msg_data, reply_list)
|
||||
|
||||
reply_msg = platform_message.Quote(
|
||||
message_id=msg.data['id'], sender_id=msg_datas['sender']['user_id'], origin=reply_list
|
||||
message_id=msg.data['id'], sender_id=msg_datas['user_id'], origin=reply_list
|
||||
)
|
||||
yiri_msg_list.append(reply_msg)
|
||||
|
||||
# 这里下载所有文件会导致下载文件过多,暂时不下载
|
||||
# elif msg.type == 'file':
|
||||
# # file_name = msg.data['file']
|
||||
# file_id = msg.data['file_id']
|
||||
# file_data = await bot.get_file(file_id=file_id)
|
||||
# file_name = file_data.get('file_name')
|
||||
# file_path = file_data.get('file')
|
||||
# file_url = file_data.get('file_url')
|
||||
# file_size = file_data.get('file_size')
|
||||
# yiri_msg_list.append(platform_message.File(id=file_id, name=file_name,url=file_url,size=file_size))
|
||||
elif msg.type == 'file':
|
||||
pass
|
||||
# file_name = msg.data['file']
|
||||
# file_id = msg.data['file_id']
|
||||
# file_data = await bot.get_file(file_id=file_id)
|
||||
# file_name = file_data.get('file_name')
|
||||
# file_path = file_data.get('file')
|
||||
# _ = file_path
|
||||
# file_url = file_data.get('file_url')
|
||||
# file_size = file_data.get('file_size')
|
||||
# yiri_msg_list.append(platform_message.File(id=file_id, name=file_name,url=file_url,size=file_size))
|
||||
elif msg.type == 'face':
|
||||
face_id = msg.data['id']
|
||||
face_name = msg.data['raw']['faceText']
|
||||
@@ -298,7 +299,7 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
|
||||
return chain
|
||||
|
||||
|
||||
class AiocqhttpEventConverter(adapter.EventConverter):
|
||||
class AiocqhttpEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||
@staticmethod
|
||||
async def yiri2target(event: platform_events.MessageEvent, bot_account_id: int):
|
||||
return event.source_platform_object
|
||||
@@ -348,23 +349,19 @@ class AiocqhttpEventConverter(adapter.EventConverter):
|
||||
)
|
||||
|
||||
|
||||
class AiocqhttpAdapter(adapter.MessagePlatformAdapter):
|
||||
bot: aiocqhttp.CQHttp
|
||||
|
||||
bot_account_id: int
|
||||
class AiocqhttpAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
bot: aiocqhttp.CQHttp = pydantic.Field(exclude=True, default_factory=aiocqhttp.CQHttp)
|
||||
|
||||
message_converter: AiocqhttpMessageConverter = AiocqhttpMessageConverter()
|
||||
event_converter: AiocqhttpEventConverter = AiocqhttpEventConverter()
|
||||
|
||||
config: dict
|
||||
|
||||
ap: app.Application
|
||||
|
||||
on_websocket_connection_event_cache: typing.List[typing.Callable[[aiocqhttp.Event], None]] = []
|
||||
|
||||
def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
|
||||
self.config = config
|
||||
self.logger = logger
|
||||
def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger):
|
||||
super().__init__(
|
||||
config=config,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
async def shutdown_trigger_placeholder():
|
||||
while True:
|
||||
@@ -372,7 +369,6 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter):
|
||||
|
||||
self.config['shutdown_trigger'] = shutdown_trigger_placeholder
|
||||
|
||||
self.ap = ap
|
||||
self.on_websocket_connection_event_cache = []
|
||||
|
||||
if 'access-token' in config:
|
||||
@@ -408,7 +404,9 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter):
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
|
||||
],
|
||||
):
|
||||
async def on_message(event: aiocqhttp.Event):
|
||||
self.bot_account_id = event.self_id
|
||||
@@ -439,7 +437,9 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter):
|
||||
def unregister_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
|
||||
],
|
||||
):
|
||||
return super().unregister_listener(event_type, callback)
|
||||
|
||||
|
||||
@@ -1,19 +1,16 @@
|
||||
from re import S
|
||||
import traceback
|
||||
import typing
|
||||
from libs.dingtalk_api.dingtalkevent import DingTalkEvent
|
||||
from pkg.platform.types import message as platform_message
|
||||
from pkg.platform.adapter import MessagePlatformAdapter
|
||||
from .. import adapter
|
||||
from ...core import app
|
||||
from ..types import events as platform_events
|
||||
from ..types import entities as platform_entities
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
|
||||
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
|
||||
from libs.dingtalk_api.api import DingTalkClient
|
||||
import datetime
|
||||
from ..logger import EventLogger
|
||||
|
||||
|
||||
class DingTalkMessageConverter(adapter.MessageConverter):
|
||||
class DingTalkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||
@staticmethod
|
||||
async def yiri2target(message_chain: platform_message.MessageChain):
|
||||
content = ''
|
||||
@@ -52,7 +49,7 @@ class DingTalkMessageConverter(adapter.MessageConverter):
|
||||
return chain
|
||||
|
||||
|
||||
class DingTalkEventConverter(adapter.EventConverter):
|
||||
class DingTalkEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||
@staticmethod
|
||||
async def yiri2target(event: platform_events.MessageEvent):
|
||||
return event.source_platform_object
|
||||
@@ -96,22 +93,18 @@ class DingTalkEventConverter(adapter.EventConverter):
|
||||
)
|
||||
|
||||
|
||||
class DingTalkAdapter(adapter.MessagePlatformAdapter):
|
||||
class DingTalkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
bot: DingTalkClient
|
||||
ap: app.Application
|
||||
bot_account_id: str
|
||||
message_converter: DingTalkMessageConverter = DingTalkMessageConverter()
|
||||
event_converter: DingTalkEventConverter = DingTalkEventConverter()
|
||||
config: dict
|
||||
card_instance_id_dict: dict # 回复卡片消息字典,key为消息id,value为回复卡片实例id,用于在流式消息时判断是否发送到指定卡片
|
||||
seq: int # 消息顺序,直接以seq作为标识
|
||||
card_instance_id_dict: (
|
||||
dict # 回复卡片消息字典,key为消息id,value为回复卡片实例id,用于在流式消息时判断是否发送到指定卡片
|
||||
)
|
||||
|
||||
def __init__(self, config: dict, logger: EventLogger):
|
||||
|
||||
def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
|
||||
self.config = config
|
||||
self.ap = ap
|
||||
self.logger = logger
|
||||
self.card_instance_id_dict = {}
|
||||
# self.seq = 1
|
||||
required_keys = [
|
||||
'client_id',
|
||||
'client_secret',
|
||||
@@ -121,16 +114,23 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter):
|
||||
missing_keys = [key for key in required_keys if key not in config]
|
||||
if missing_keys:
|
||||
raise Exception('钉钉缺少相关配置项,请查看文档或联系管理员')
|
||||
bot = DingTalkClient(
|
||||
client_id=config['client_id'],
|
||||
client_secret=config['client_secret'],
|
||||
robot_name=config['robot_name'],
|
||||
robot_code=config['robot_code'],
|
||||
markdown_card=config['markdown_card'],
|
||||
logger=logger,
|
||||
)
|
||||
bot_account_id = config['robot_name']
|
||||
super().__init__(
|
||||
config=config,
|
||||
logger=logger,
|
||||
card_instance_id_dict={},
|
||||
bot_account_id=bot_account_id,
|
||||
bot=bot,
|
||||
listeners={},
|
||||
|
||||
self.bot_account_id = self.config['robot_name']
|
||||
|
||||
self.bot = DingTalkClient(
|
||||
client_id=config['client_id'],
|
||||
client_secret=config['client_secret'],
|
||||
robot_name=config['robot_name'],
|
||||
robot_code=config['robot_code'],
|
||||
markdown_card=config['markdown_card'],
|
||||
logger=self.logger,
|
||||
)
|
||||
|
||||
async def reply_message(
|
||||
@@ -165,12 +165,11 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter):
|
||||
msg_seq = bot_message.msg_sequence
|
||||
|
||||
if (msg_seq - 1) % 8 == 0 or is_final:
|
||||
|
||||
content, at = await DingTalkMessageConverter.yiri2target(message)
|
||||
|
||||
card_instance, card_instance_id = self.card_instance_id_dict[message_id]
|
||||
if not content and bot_message.content:
|
||||
content = bot_message.content # 兼容直接传入content的情况
|
||||
content = bot_message.content # 兼容直接传入content的情况
|
||||
# print(card_instance_id)
|
||||
if content:
|
||||
await self.bot.send_card_message(card_instance, card_instance_id, content, is_final)
|
||||
@@ -202,7 +201,9 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter):
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
|
||||
],
|
||||
):
|
||||
async def on_message(event: DingTalkEvent):
|
||||
try:
|
||||
@@ -224,9 +225,14 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter):
|
||||
async def kill(self) -> bool:
|
||||
return False
|
||||
|
||||
async def is_muted(self) -> bool:
|
||||
return False
|
||||
|
||||
async def unregister_listener(
|
||||
self,
|
||||
event_type: type,
|
||||
callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None],
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
|
||||
],
|
||||
):
|
||||
return super().unregister_listener(event_type, callback)
|
||||
|
||||
@@ -12,13 +12,14 @@ import asyncio
|
||||
from enum import Enum
|
||||
|
||||
import aiohttp
|
||||
import pydantic
|
||||
|
||||
from .. import adapter
|
||||
from ...core import app
|
||||
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
|
||||
import langbot_plugin.api.definition.abstract.platform.event_logger as abstract_platform_logger
|
||||
from ..logger import EventLogger
|
||||
from ..types import message as platform_message
|
||||
from ..types import events as platform_events
|
||||
from ..types import entities as platform_entities
|
||||
|
||||
|
||||
# 语音功能相关异常定义
|
||||
@@ -582,7 +583,7 @@ class VoiceConnectionManager:
|
||||
await self.stop_monitoring()
|
||||
|
||||
|
||||
class DiscordMessageConverter(adapter.MessageConverter):
|
||||
class DiscordMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||
@staticmethod
|
||||
async def yiri2target(
|
||||
message_chain: platform_message.MessageChain,
|
||||
@@ -736,7 +737,7 @@ class DiscordMessageConverter(adapter.MessageConverter):
|
||||
return platform_message.MessageChain(element_list)
|
||||
|
||||
|
||||
class DiscordEventConverter(adapter.EventConverter):
|
||||
class DiscordEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||
@staticmethod
|
||||
async def yiri2target(event: platform_events.Event) -> discord.Message:
|
||||
pass
|
||||
@@ -778,32 +779,26 @@ class DiscordEventConverter(adapter.EventConverter):
|
||||
)
|
||||
|
||||
|
||||
class DiscordAdapter(adapter.MessagePlatformAdapter):
|
||||
bot: discord.Client
|
||||
|
||||
bot_account_id: str # 用于在流水线中识别at是否是本bot,直接以bot_name作为标识
|
||||
|
||||
config: dict
|
||||
|
||||
ap: app.Application
|
||||
class DiscordAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
bot: discord.Client = pydantic.Field(exclude=True)
|
||||
|
||||
message_converter: DiscordMessageConverter = DiscordMessageConverter()
|
||||
event_converter: DiscordEventConverter = DiscordEventConverter()
|
||||
|
||||
listeners: typing.Dict[
|
||||
typing.Type[platform_events.Event],
|
||||
typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
||||
typing.Callable[[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None],
|
||||
] = {}
|
||||
|
||||
def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
|
||||
self.config = config
|
||||
self.ap = ap
|
||||
self.logger = logger
|
||||
voice_manager: VoiceConnectionManager | None = pydantic.Field(exclude=True, default=None)
|
||||
|
||||
self.bot_account_id = self.config['client_id']
|
||||
def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger, **kwargs):
|
||||
bot_account_id = config['client_id']
|
||||
|
||||
listeners = {}
|
||||
|
||||
# 初始化语音连接管理器
|
||||
self.voice_manager: VoiceConnectionManager = None
|
||||
# self.voice_manager: VoiceConnectionManager = None
|
||||
|
||||
adapter_self = self
|
||||
|
||||
@@ -823,7 +818,17 @@ class DiscordAdapter(adapter.MessagePlatformAdapter):
|
||||
if os.getenv('http_proxy'):
|
||||
args['proxy'] = os.getenv('http_proxy')
|
||||
|
||||
self.bot = MyClient(intents=intents, **args)
|
||||
bot = MyClient(intents=intents, **args)
|
||||
|
||||
super().__init__(
|
||||
config=config,
|
||||
logger=logger,
|
||||
bot_account_id=bot_account_id,
|
||||
listeners=listeners,
|
||||
bot=bot,
|
||||
voice_manager=None,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Voice functionality methods
|
||||
async def join_voice_channel(self, guild_id: int, channel_id: int, user_id: int = None) -> discord.VoiceClient:
|
||||
@@ -1029,7 +1034,14 @@ class DiscordAdapter(adapter.MessagePlatformAdapter):
|
||||
if quote_origin:
|
||||
args['reference'] = message_source.source_platform_object
|
||||
|
||||
if message.has(platform_message.At):
|
||||
has_at = False
|
||||
|
||||
for component in message.root:
|
||||
if isinstance(component, platform_message.At):
|
||||
has_at = True
|
||||
break
|
||||
|
||||
if has_at:
|
||||
args['mention_author'] = True
|
||||
|
||||
await message_source.source_platform_object.channel.send(**args)
|
||||
@@ -1040,14 +1052,18 @@ class DiscordAdapter(adapter.MessagePlatformAdapter):
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
|
||||
],
|
||||
):
|
||||
self.listeners[event_type] = callback
|
||||
|
||||
def unregister_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
|
||||
],
|
||||
):
|
||||
self.listeners.pop(event_type)
|
||||
|
||||
|
||||
@@ -17,14 +17,14 @@ import aiohttp
|
||||
import lark_oapi.ws.exception
|
||||
import quart
|
||||
from lark_oapi.api.im.v1 import *
|
||||
import pydantic
|
||||
from lark_oapi.api.cardkit.v1 import *
|
||||
|
||||
from .. import adapter
|
||||
from ...core import app
|
||||
from ..types import message as platform_message
|
||||
from ..types import events as platform_events
|
||||
from ..types import entities as platform_entities
|
||||
from ..logger import EventLogger
|
||||
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
|
||||
import langbot_plugin.api.definition.abstract.platform.event_logger as abstract_platform_logger
|
||||
|
||||
|
||||
class AESCipher(object):
|
||||
@@ -53,7 +53,7 @@ class AESCipher(object):
|
||||
return self.decrypt(enc).decode('utf8')
|
||||
|
||||
|
||||
class LarkMessageConverter(adapter.MessageConverter):
|
||||
class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||
@staticmethod
|
||||
async def yiri2target(
|
||||
message_chain: platform_message.MessageChain, api_client: lark_oapi.Client
|
||||
@@ -277,7 +277,7 @@ class LarkMessageConverter(adapter.MessageConverter):
|
||||
return platform_message.MessageChain(lb_msg_list)
|
||||
|
||||
|
||||
class LarkEventConverter(adapter.EventConverter):
|
||||
class LarkEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||
@staticmethod
|
||||
async def yiri2target(
|
||||
event: platform_events.MessageEvent,
|
||||
@@ -325,49 +325,37 @@ CARD_ID_CACHE_SIZE = 500
|
||||
CARD_ID_CACHE_MAX_LIFETIME = 20 * 60 # 20分钟
|
||||
|
||||
|
||||
class LarkAdapter(adapter.MessagePlatformAdapter):
|
||||
bot: lark_oapi.ws.Client
|
||||
api_client: lark_oapi.Client
|
||||
class LarkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
bot: lark_oapi.ws.Client = pydantic.Field(exclude=True)
|
||||
api_client: lark_oapi.Client = pydantic.Field(exclude=True)
|
||||
|
||||
bot_account_id: str # 用于在流水线中识别at是否是本bot,直接以bot_name作为标识
|
||||
lark_tenant_key: str # 飞书企业key
|
||||
lark_tenant_key: str = pydantic.Field(exclude=True, default='') # 飞书企业key
|
||||
|
||||
message_converter: LarkMessageConverter = LarkMessageConverter()
|
||||
event_converter: LarkEventConverter = LarkEventConverter()
|
||||
|
||||
listeners: typing.Dict[
|
||||
typing.Type[platform_events.Event],
|
||||
typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
||||
typing.Callable[[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None],
|
||||
]
|
||||
|
||||
config: dict
|
||||
quart_app: quart.Quart
|
||||
ap: app.Application
|
||||
|
||||
quart_app: quart.Quart = pydantic.Field(exclude=True)
|
||||
|
||||
card_id_dict: dict[str, str] # 消息id到卡片id的映射,便于创建卡片后的发送消息到指定卡片
|
||||
|
||||
seq: int # 用于在发送卡片消息中识别消息顺序,直接以seq作为标识
|
||||
|
||||
def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
|
||||
self.config = config
|
||||
self.ap = ap
|
||||
self.logger = logger
|
||||
self.quart_app = quart.Quart(__name__)
|
||||
self.listeners = {}
|
||||
self.card_id_dict = {}
|
||||
self.seq = 1
|
||||
def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger, **kwargs):
|
||||
quart_app = quart.Quart(__name__)
|
||||
|
||||
|
||||
@self.quart_app.route('/lark/callback', methods=['POST'])
|
||||
@quart_app.route('/lark/callback', methods=['POST'])
|
||||
async def lark_callback():
|
||||
try:
|
||||
data = await quart.request.json
|
||||
|
||||
self.ap.logger.debug(f'Lark callback event: {data}')
|
||||
|
||||
if 'encrypt' in data:
|
||||
cipher = AESCipher(self.config['encrypt-key'])
|
||||
cipher = AESCipher(config['encrypt-key'])
|
||||
data = cipher.decrypt_string(data['encrypt'])
|
||||
data = json.loads(data)
|
||||
|
||||
@@ -414,10 +402,24 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
|
||||
lark_oapi.EventDispatcherHandler.builder('', '').register_p2_im_message_receive_v1(sync_on_message).build()
|
||||
)
|
||||
|
||||
self.bot_account_id = config['bot_name']
|
||||
bot_account_id = config['bot_name']
|
||||
|
||||
self.bot = lark_oapi.ws.Client(config['app_id'], config['app_secret'], event_handler=event_handler)
|
||||
self.api_client = lark_oapi.Client.builder().app_id(config['app_id']).app_secret(config['app_secret']).build()
|
||||
bot = lark_oapi.ws.Client(config['app_id'], config['app_secret'], event_handler=event_handler)
|
||||
api_client = lark_oapi.Client.builder().app_id(config['app_id']).app_secret(config['app_secret']).build()
|
||||
|
||||
super().__init__(
|
||||
config=config,
|
||||
logger=logger,
|
||||
lark_tenant_key=config.get('lark_tenant_key', ''),
|
||||
card_id_dict={},
|
||||
seq=1,
|
||||
listeners={},
|
||||
quart_app=quart_app,
|
||||
bot=bot,
|
||||
api_client=api_client,
|
||||
bot_account_id=bot_account_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
|
||||
pass
|
||||
@@ -430,151 +432,177 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
|
||||
|
||||
async def create_card_id(self, message_id):
|
||||
try:
|
||||
self.ap.logger.debug('飞书支持stream输出,创建卡片......')
|
||||
# self.logger.debug('飞书支持stream输出,创建卡片......')
|
||||
|
||||
card_data = {"schema": "2.0", "config": {"update_multi": True, "streaming_mode": True,
|
||||
"streaming_config": {"print_step": {"default": 1},
|
||||
"print_frequency_ms": {"default": 70},
|
||||
"print_strategy": "fast"}},
|
||||
"body": {"direction": "vertical", "padding": "12px 12px 12px 12px", "elements": [{"tag": "div",
|
||||
"text": {
|
||||
"tag": "plain_text",
|
||||
"content": "LangBot",
|
||||
"text_size": "normal",
|
||||
"text_align": "left",
|
||||
"text_color": "default"},
|
||||
"icon": {
|
||||
"tag": "custom_icon",
|
||||
"img_key": "img_v3_02p3_05c65d5d-9bad-440a-a2fb-c89571bfd5bg"}},
|
||||
{
|
||||
"tag": "markdown",
|
||||
"content": "",
|
||||
"text_align": "left",
|
||||
"text_size": "normal",
|
||||
"margin": "0px 0px 0px 0px",
|
||||
"element_id": "streaming_txt"},
|
||||
{
|
||||
"tag": "markdown",
|
||||
"content": "",
|
||||
"text_align": "left",
|
||||
"text_size": "normal",
|
||||
"margin": "0px 0px 0px 0px"},
|
||||
{
|
||||
"tag": "column_set",
|
||||
"horizontal_spacing": "8px",
|
||||
"horizontal_align": "left",
|
||||
"columns": [
|
||||
{
|
||||
"tag": "column",
|
||||
"width": "weighted",
|
||||
"elements": [
|
||||
{
|
||||
"tag": "markdown",
|
||||
"content": "",
|
||||
"text_align": "left",
|
||||
"text_size": "normal",
|
||||
"margin": "0px 0px 0px 0px"},
|
||||
{
|
||||
"tag": "markdown",
|
||||
"content": "",
|
||||
"text_align": "left",
|
||||
"text_size": "normal",
|
||||
"margin": "0px 0px 0px 0px"},
|
||||
{
|
||||
"tag": "markdown",
|
||||
"content": "",
|
||||
"text_align": "left",
|
||||
"text_size": "normal",
|
||||
"margin": "0px 0px 0px 0px"}],
|
||||
"padding": "0px 0px 0px 0px",
|
||||
"direction": "vertical",
|
||||
"horizontal_spacing": "8px",
|
||||
"vertical_spacing": "2px",
|
||||
"horizontal_align": "left",
|
||||
"vertical_align": "top",
|
||||
"margin": "0px 0px 0px 0px",
|
||||
"weight": 1}],
|
||||
"margin": "0px 0px 0px 0px"},
|
||||
{"tag": "hr",
|
||||
"margin": "0px 0px 0px 0px"},
|
||||
{
|
||||
"tag": "column_set",
|
||||
"horizontal_spacing": "12px",
|
||||
"horizontal_align": "right",
|
||||
"columns": [
|
||||
{
|
||||
"tag": "column",
|
||||
"width": "weighted",
|
||||
"elements": [
|
||||
{
|
||||
"tag": "markdown",
|
||||
"content": "<font color=\"grey-600\">以上内容由 AI 生成,仅供参考。更多详细、准确信息可点击引用链接查看</font>",
|
||||
"text_align": "left",
|
||||
"text_size": "notation",
|
||||
"margin": "4px 0px 0px 0px",
|
||||
"icon": {
|
||||
"tag": "standard_icon",
|
||||
"token": "robot_outlined",
|
||||
"color": "grey"}}],
|
||||
"padding": "0px 0px 0px 0px",
|
||||
"direction": "vertical",
|
||||
"horizontal_spacing": "8px",
|
||||
"vertical_spacing": "8px",
|
||||
"horizontal_align": "left",
|
||||
"vertical_align": "top",
|
||||
"margin": "0px 0px 0px 0px",
|
||||
"weight": 1},
|
||||
{
|
||||
"tag": "column",
|
||||
"width": "20px",
|
||||
"elements": [
|
||||
{
|
||||
"tag": "button",
|
||||
"text": {
|
||||
"tag": "plain_text",
|
||||
"content": ""},
|
||||
"type": "text",
|
||||
"width": "fill",
|
||||
"size": "medium",
|
||||
"icon": {
|
||||
"tag": "standard_icon",
|
||||
"token": "thumbsup_outlined"},
|
||||
"hover_tips": {
|
||||
"tag": "plain_text",
|
||||
"content": "有帮助"},
|
||||
"margin": "0px 0px 0px 0px"}],
|
||||
"padding": "0px 0px 0px 0px",
|
||||
"direction": "vertical",
|
||||
"horizontal_spacing": "8px",
|
||||
"vertical_spacing": "8px",
|
||||
"horizontal_align": "left",
|
||||
"vertical_align": "top",
|
||||
"margin": "0px 0px 0px 0px"},
|
||||
{
|
||||
"tag": "column",
|
||||
"width": "30px",
|
||||
"elements": [
|
||||
{
|
||||
"tag": "button",
|
||||
"text": {
|
||||
"tag": "plain_text",
|
||||
"content": ""},
|
||||
"type": "text",
|
||||
"width": "default",
|
||||
"size": "medium",
|
||||
"icon": {
|
||||
"tag": "standard_icon",
|
||||
"token": "thumbdown_outlined"},
|
||||
"hover_tips": {
|
||||
"tag": "plain_text",
|
||||
"content": "无帮助"},
|
||||
"margin": "0px 0px 0px 0px"}],
|
||||
"padding": "0px 0px 0px 0px",
|
||||
"vertical_spacing": "8px",
|
||||
"horizontal_align": "left",
|
||||
"vertical_align": "top",
|
||||
"margin": "0px 0px 0px 0px"}],
|
||||
"margin": "0px 0px 4px 0px"}]}}
|
||||
card_data = {
|
||||
'schema': '2.0',
|
||||
'config': {
|
||||
'update_multi': True,
|
||||
'streaming_mode': True,
|
||||
'streaming_config': {
|
||||
'print_step': {'default': 1},
|
||||
'print_frequency_ms': {'default': 70},
|
||||
'print_strategy': 'fast',
|
||||
},
|
||||
},
|
||||
'body': {
|
||||
'direction': 'vertical',
|
||||
'padding': '12px 12px 12px 12px',
|
||||
'elements': [
|
||||
{
|
||||
'tag': 'div',
|
||||
'text': {
|
||||
'tag': 'plain_text',
|
||||
'content': 'LangBot',
|
||||
'text_size': 'normal',
|
||||
'text_align': 'left',
|
||||
'text_color': 'default',
|
||||
},
|
||||
'icon': {
|
||||
'tag': 'custom_icon',
|
||||
'img_key': 'img_v3_02p3_05c65d5d-9bad-440a-a2fb-c89571bfd5bg',
|
||||
},
|
||||
},
|
||||
{
|
||||
'tag': 'markdown',
|
||||
'content': '',
|
||||
'text_align': 'left',
|
||||
'text_size': 'normal',
|
||||
'margin': '0px 0px 0px 0px',
|
||||
'element_id': 'streaming_txt',
|
||||
},
|
||||
{
|
||||
'tag': 'markdown',
|
||||
'content': '',
|
||||
'text_align': 'left',
|
||||
'text_size': 'normal',
|
||||
'margin': '0px 0px 0px 0px',
|
||||
},
|
||||
{
|
||||
'tag': 'column_set',
|
||||
'horizontal_spacing': '8px',
|
||||
'horizontal_align': 'left',
|
||||
'columns': [
|
||||
{
|
||||
'tag': 'column',
|
||||
'width': 'weighted',
|
||||
'elements': [
|
||||
{
|
||||
'tag': 'markdown',
|
||||
'content': '',
|
||||
'text_align': 'left',
|
||||
'text_size': 'normal',
|
||||
'margin': '0px 0px 0px 0px',
|
||||
},
|
||||
{
|
||||
'tag': 'markdown',
|
||||
'content': '',
|
||||
'text_align': 'left',
|
||||
'text_size': 'normal',
|
||||
'margin': '0px 0px 0px 0px',
|
||||
},
|
||||
{
|
||||
'tag': 'markdown',
|
||||
'content': '',
|
||||
'text_align': 'left',
|
||||
'text_size': 'normal',
|
||||
'margin': '0px 0px 0px 0px',
|
||||
},
|
||||
],
|
||||
'padding': '0px 0px 0px 0px',
|
||||
'direction': 'vertical',
|
||||
'horizontal_spacing': '8px',
|
||||
'vertical_spacing': '2px',
|
||||
'horizontal_align': 'left',
|
||||
'vertical_align': 'top',
|
||||
'margin': '0px 0px 0px 0px',
|
||||
'weight': 1,
|
||||
}
|
||||
],
|
||||
'margin': '0px 0px 0px 0px',
|
||||
},
|
||||
{'tag': 'hr', 'margin': '0px 0px 0px 0px'},
|
||||
{
|
||||
'tag': 'column_set',
|
||||
'horizontal_spacing': '12px',
|
||||
'horizontal_align': 'right',
|
||||
'columns': [
|
||||
{
|
||||
'tag': 'column',
|
||||
'width': 'weighted',
|
||||
'elements': [
|
||||
{
|
||||
'tag': 'markdown',
|
||||
'content': '<font color="grey-600">以上内容由 AI 生成,仅供参考。更多详细、准确信息可点击引用链接查看</font>',
|
||||
'text_align': 'left',
|
||||
'text_size': 'notation',
|
||||
'margin': '4px 0px 0px 0px',
|
||||
'icon': {
|
||||
'tag': 'standard_icon',
|
||||
'token': 'robot_outlined',
|
||||
'color': 'grey',
|
||||
},
|
||||
}
|
||||
],
|
||||
'padding': '0px 0px 0px 0px',
|
||||
'direction': 'vertical',
|
||||
'horizontal_spacing': '8px',
|
||||
'vertical_spacing': '8px',
|
||||
'horizontal_align': 'left',
|
||||
'vertical_align': 'top',
|
||||
'margin': '0px 0px 0px 0px',
|
||||
'weight': 1,
|
||||
},
|
||||
{
|
||||
'tag': 'column',
|
||||
'width': '20px',
|
||||
'elements': [
|
||||
{
|
||||
'tag': 'button',
|
||||
'text': {'tag': 'plain_text', 'content': ''},
|
||||
'type': 'text',
|
||||
'width': 'fill',
|
||||
'size': 'medium',
|
||||
'icon': {'tag': 'standard_icon', 'token': 'thumbsup_outlined'},
|
||||
'hover_tips': {'tag': 'plain_text', 'content': '有帮助'},
|
||||
'margin': '0px 0px 0px 0px',
|
||||
}
|
||||
],
|
||||
'padding': '0px 0px 0px 0px',
|
||||
'direction': 'vertical',
|
||||
'horizontal_spacing': '8px',
|
||||
'vertical_spacing': '8px',
|
||||
'horizontal_align': 'left',
|
||||
'vertical_align': 'top',
|
||||
'margin': '0px 0px 0px 0px',
|
||||
},
|
||||
{
|
||||
'tag': 'column',
|
||||
'width': '30px',
|
||||
'elements': [
|
||||
{
|
||||
'tag': 'button',
|
||||
'text': {'tag': 'plain_text', 'content': ''},
|
||||
'type': 'text',
|
||||
'width': 'default',
|
||||
'size': 'medium',
|
||||
'icon': {'tag': 'standard_icon', 'token': 'thumbdown_outlined'},
|
||||
'hover_tips': {'tag': 'plain_text', 'content': '无帮助'},
|
||||
'margin': '0px 0px 0px 0px',
|
||||
}
|
||||
],
|
||||
'padding': '0px 0px 0px 0px',
|
||||
'vertical_spacing': '8px',
|
||||
'horizontal_align': 'left',
|
||||
'vertical_align': 'top',
|
||||
'margin': '0px 0px 0px 0px',
|
||||
},
|
||||
],
|
||||
'margin': '0px 0px 4px 0px',
|
||||
},
|
||||
],
|
||||
},
|
||||
}
|
||||
# delay / fast 创建卡片模板,delay 延迟打印,fast 实时打印,可以自定义更好看的消息模板
|
||||
|
||||
request: CreateCardRequest = (
|
||||
@@ -592,15 +620,13 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
|
||||
f'client.cardkit.v1.card.create failed, code: {response.code}, msg: {response.msg}, log_id: {response.get_log_id()}, resp: \n{json.dumps(json.loads(response.raw.content), indent=4, ensure_ascii=False)}'
|
||||
)
|
||||
|
||||
self.ap.logger.debug(f'飞书卡片创建成功,卡片ID: {response.data.card_id}')
|
||||
self.card_id_dict[message_id] = response.data.card_id
|
||||
|
||||
card_id = response.data.card_id
|
||||
return card_id
|
||||
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'飞书卡片创建失败,错误信息: {e}')
|
||||
|
||||
raise e
|
||||
async def create_message_card(self, message_id, event) -> str:
|
||||
"""
|
||||
创建卡片消息。
|
||||
@@ -612,7 +638,7 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
|
||||
content = {
|
||||
'type': 'card',
|
||||
'data': {'card_id': card_id, 'template_variable': {'content': 'Thinking...'}},
|
||||
} # 当收到消息时发送消息模板,可添加模板变量,详情查看飞书中接口文档
|
||||
} # 当收到消息时发送消息模板,可添加模板变量,详情查看飞书中接口文档
|
||||
request: ReplyMessageRequest = (
|
||||
ReplyMessageRequest.builder()
|
||||
.message_id(event.message_chain.message_id)
|
||||
@@ -685,10 +711,8 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
|
||||
message_id = bot_message.resp_message_id
|
||||
msg_seq = bot_message.msg_sequence
|
||||
if msg_seq % 8 == 0 or is_final:
|
||||
|
||||
lark_message = await self.message_converter.yiri2target(message, self.api_client)
|
||||
|
||||
|
||||
text_message = ''
|
||||
for ele in lark_message[0]:
|
||||
if ele['tag'] == 'text':
|
||||
@@ -734,14 +758,18 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
|
||||
],
|
||||
):
|
||||
self.listeners[event_type] = callback
|
||||
|
||||
def unregister_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
|
||||
],
|
||||
):
|
||||
self.listeners.pop(event_type)
|
||||
|
||||
@@ -778,4 +806,4 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
|
||||
# 所以要设置_auto_reconnect=False,让其不重连。
|
||||
self.bot._auto_reconnect = False
|
||||
await self.bot._disconnect()
|
||||
return False
|
||||
return False
|
||||
|
||||
|
Before Width: | Height: | Size: 25 KiB After Width: | Height: | Size: 25 KiB |
@@ -11,19 +11,19 @@ import threading
|
||||
import quart
|
||||
import aiohttp
|
||||
|
||||
from .. import adapter
|
||||
from ...core import app
|
||||
from ..types import message as platform_message
|
||||
from ..types import events as platform_events
|
||||
from ..types import entities as platform_entities
|
||||
from ...utils import image
|
||||
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
|
||||
from ....core import app
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
|
||||
from ....utils import image
|
||||
import xml.etree.ElementTree as ET
|
||||
from typing import Optional, Tuple
|
||||
from functools import partial
|
||||
from ..logger import EventLogger
|
||||
from ...logger import EventLogger
|
||||
|
||||
|
||||
class GewechatMessageConverter(adapter.MessageConverter):
|
||||
class GewechatMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||
def __init__(self, config: dict):
|
||||
self.config = config
|
||||
|
||||
@@ -398,7 +398,7 @@ class GewechatMessageConverter(adapter.MessageConverter):
|
||||
return from_user_name.endswith('@chatroom')
|
||||
|
||||
|
||||
class GewechatEventConverter(adapter.EventConverter):
|
||||
class GewechatEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||
def __init__(self, config: dict):
|
||||
self.config = config
|
||||
self.message_converter = GewechatMessageConverter(config)
|
||||
@@ -458,7 +458,7 @@ class GewechatEventConverter(adapter.EventConverter):
|
||||
)
|
||||
|
||||
|
||||
class GeWeChatAdapter(adapter.MessagePlatformAdapter):
|
||||
class GeWeChatAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
name: str = 'gewechat' # 定义适配器名称
|
||||
|
||||
bot: gewechat_client.GewechatClient
|
||||
@@ -475,7 +475,7 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter):
|
||||
|
||||
listeners: typing.Dict[
|
||||
typing.Type[platform_events.Event],
|
||||
typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
||||
typing.Callable[[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None],
|
||||
] = {}
|
||||
|
||||
def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
|
||||
@@ -491,7 +491,7 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter):
|
||||
async def gewechat_callback():
|
||||
data = await quart.request.json
|
||||
# print(json.dumps(data, indent=4, ensure_ascii=False))
|
||||
self.ap.logger.debug(f'Gewechat callback event: {data}')
|
||||
await self.logger.debug(f'Gewechat callback event: {data}')
|
||||
|
||||
if 'data' in data:
|
||||
data['Data'] = data['data']
|
||||
@@ -601,7 +601,7 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter):
|
||||
if handler := handler_map.get(msg['type']):
|
||||
handler(msg)
|
||||
else:
|
||||
self.ap.logger.warning(f'未处理的消息类型: {msg["type"]}')
|
||||
await self.logger.warning(f'未处理的消息类型: {msg["type"]}')
|
||||
continue
|
||||
|
||||
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
|
||||
@@ -625,14 +625,18 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter):
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
|
||||
],
|
||||
):
|
||||
self.listeners[event_type] = callback
|
||||
|
||||
def unregister_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
|
||||
],
|
||||
):
|
||||
pass
|
||||
|
||||
@@ -656,9 +660,7 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter):
|
||||
|
||||
self.config['app_id'] = app_id
|
||||
|
||||
self.ap.logger.info(f'Gewechat 登录成功,app_id: {app_id}')
|
||||
|
||||
self.ap.platform_mgr.write_back_config('gewechat', self, self.config)
|
||||
print(f'Gewechat 登录成功,app_id: {app_id}')
|
||||
|
||||
# 获取 nickname
|
||||
profile = self.bot.get_profile(self.config['app_id'])
|
||||
|
Before Width: | Height: | Size: 274 KiB After Width: | Height: | Size: 274 KiB |
@@ -9,15 +9,15 @@ import traceback
|
||||
import nakuru
|
||||
import nakuru.entities.components as nkc
|
||||
|
||||
from .. import adapter as adapter_model
|
||||
from ...pipeline.longtext.strategies import forward
|
||||
from ...platform.types import message as platform_message
|
||||
from ...platform.types import entities as platform_entities
|
||||
from ...platform.types import events as platform_events
|
||||
from ..logger import EventLogger
|
||||
from ....pipeline.longtext.strategies import forward
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
|
||||
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
|
||||
from ...logger import EventLogger
|
||||
|
||||
|
||||
class NakuruProjectMessageConverter(adapter_model.MessageConverter):
|
||||
class NakuruProjectMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||
"""消息转换器"""
|
||||
|
||||
@staticmethod
|
||||
@@ -109,7 +109,7 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter):
|
||||
return chain
|
||||
|
||||
|
||||
class NakuruProjectEventConverter(adapter_model.EventConverter):
|
||||
class NakuruProjectEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||
"""事件转换器"""
|
||||
|
||||
@staticmethod
|
||||
@@ -164,7 +164,7 @@ class NakuruProjectEventConverter(adapter_model.EventConverter):
|
||||
raise Exception('未支持转换的事件类型: ' + str(event))
|
||||
|
||||
|
||||
class NakuruAdapter(adapter_model.MessagePlatformAdapter):
|
||||
class NakuruAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
"""nakuru-project适配器"""
|
||||
|
||||
bot: nakuru.CQHTTP
|
||||
@@ -256,13 +256,15 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter):
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[[platform_events.Event, adapter_model.MessagePlatformAdapter], None],
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
|
||||
],
|
||||
):
|
||||
try:
|
||||
source_cls = NakuruProjectEventConverter.yiri2target(event_type)
|
||||
|
||||
# 包装函数
|
||||
async def listener_wrapper(app: nakuru.CQHTTP, source: source_cls):
|
||||
async def listener_wrapper(app: nakuru.CQHTTP, source: source_cls): # type: ignore
|
||||
await callback(self.event_converter.target2yiri(source), self)
|
||||
|
||||
# 将包装函数和原函数的对应关系存入列表
|
||||
@@ -283,7 +285,9 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter):
|
||||
def unregister_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[[platform_events.Event, adapter_model.MessagePlatformAdapter], None],
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
|
||||
],
|
||||
):
|
||||
nakuru_event_name = self.event_converter.yiri2target(event_type).__name__
|
||||
|
||||
@@ -322,7 +326,6 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter):
|
||||
except Exception:
|
||||
raise Exception('获取go-cqhttp账号信息失败, 请检查是否已启动go-cqhttp并配置正确')
|
||||
await self.bot._run()
|
||||
self.ap.logger.info('运行 Nakuru 适配器')
|
||||
while True:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
@@ -10,14 +10,14 @@ import botpy
|
||||
import botpy.message as botpy_message
|
||||
import botpy.types.message as botpy_message_type
|
||||
|
||||
from .. import adapter as adapter_model
|
||||
from ...pipeline.longtext.strategies import forward
|
||||
from ...core import app
|
||||
from ...config import manager as cfg_mgr
|
||||
from ...platform.types import entities as platform_entities
|
||||
from ...platform.types import events as platform_events
|
||||
from ...platform.types import message as platform_message
|
||||
from ..logger import EventLogger
|
||||
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
|
||||
from ....pipeline.longtext.strategies import forward
|
||||
from ....core import app
|
||||
from ....config import manager as cfg_mgr
|
||||
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
|
||||
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
from ...logger import EventLogger
|
||||
|
||||
|
||||
class OfficialGroupMessage(platform_events.GroupMessage):
|
||||
@@ -133,7 +133,7 @@ class OpenIDMapping(typing.Generic[K, V]):
|
||||
return value
|
||||
|
||||
|
||||
class OfficialMessageConverter(adapter_model.MessageConverter):
|
||||
class OfficialMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||
"""QQ 官方消息转换器"""
|
||||
|
||||
@staticmethod
|
||||
@@ -237,7 +237,7 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
|
||||
return chain
|
||||
|
||||
|
||||
class OfficialEventConverter(adapter_model.EventConverter):
|
||||
class OfficialEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||
"""事件转换器"""
|
||||
|
||||
def __init__(self):
|
||||
@@ -333,7 +333,7 @@ class OfficialEventConverter(adapter_model.EventConverter):
|
||||
)
|
||||
|
||||
|
||||
class OfficialAdapter(adapter_model.MessagePlatformAdapter):
|
||||
class OfficialAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
"""QQ 官方消息适配器"""
|
||||
|
||||
bot: botpy.Client = None
|
||||
@@ -484,7 +484,9 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter):
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[[platform_events.Event, adapter_model.MessagePlatformAdapter], None],
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
|
||||
],
|
||||
):
|
||||
try:
|
||||
|
||||
@@ -507,7 +509,9 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter):
|
||||
def unregister_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[[platform_events.Event, adapter_model.MessagePlatformAdapter], None],
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
|
||||
],
|
||||
):
|
||||
delattr(self.bot, event_handler_mapping[event_type])
|
||||
|
||||
@@ -519,7 +523,7 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter):
|
||||
|
||||
self.cfg['ret_coro'] = True
|
||||
|
||||
self.ap.logger.info('运行 QQ 官方适配器')
|
||||
await self.logger.info('运行 QQ 官方适配器')
|
||||
await (await self.bot.start(**self.cfg))
|
||||
|
||||
async def kill(self) -> bool:
|
||||
|
Before Width: | Height: | Size: 1.7 KiB After Width: | Height: | Size: 1.7 KiB |
@@ -4,19 +4,18 @@ import asyncio
|
||||
import traceback
|
||||
|
||||
import datetime
|
||||
from pkg.platform.adapter import MessagePlatformAdapter
|
||||
from pkg.platform.types import events as platform_events, message as platform_message
|
||||
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
|
||||
from libs.official_account_api.oaevent import OAEvent
|
||||
from libs.official_account_api.api import OAClient
|
||||
from libs.official_account_api.api import OAClientForLongerResponse
|
||||
from .. import adapter
|
||||
from ...core import app
|
||||
from ..types import entities as platform_entities
|
||||
from ...command.errors import ParamNotEnoughError
|
||||
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||
from langbot_plugin.api.entities.builtin.command import errors as command_errors
|
||||
from ..logger import EventLogger
|
||||
|
||||
|
||||
class OAMessageConverter(adapter.MessageConverter):
|
||||
class OAMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||
@staticmethod
|
||||
async def yiri2target(message_chain: platform_message.MessageChain):
|
||||
for msg in message_chain:
|
||||
@@ -34,7 +33,7 @@ class OAMessageConverter(adapter.MessageConverter):
|
||||
return chain
|
||||
|
||||
|
||||
class OAEventConverter(adapter.EventConverter):
|
||||
class OAEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||
@staticmethod
|
||||
async def target2yiri(event: OAEvent):
|
||||
if event.type == 'text':
|
||||
@@ -56,17 +55,15 @@ class OAEventConverter(adapter.EventConverter):
|
||||
return None
|
||||
|
||||
|
||||
class OfficialAccountAdapter(adapter.MessagePlatformAdapter):
|
||||
class OfficialAccountAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
bot: OAClient | OAClientForLongerResponse
|
||||
ap: app.Application
|
||||
bot_account_id: str
|
||||
message_converter: OAMessageConverter = OAMessageConverter()
|
||||
event_converter: OAEventConverter = OAEventConverter()
|
||||
config: dict
|
||||
|
||||
def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
|
||||
def __init__(self, config: dict, logger: EventLogger):
|
||||
self.config = config
|
||||
self.ap = ap
|
||||
self.logger = logger
|
||||
|
||||
required_keys = [
|
||||
@@ -78,7 +75,7 @@ class OfficialAccountAdapter(adapter.MessagePlatformAdapter):
|
||||
]
|
||||
missing_keys = [key for key in required_keys if key not in config]
|
||||
if missing_keys:
|
||||
raise ParamNotEnoughError('微信公众号缺少相关配置项,请查看文档或联系管理员')
|
||||
raise command_errors.ParamNotEnoughError('微信公众号缺少相关配置项,请查看文档或联系管理员')
|
||||
|
||||
if self.config['Mode'] == 'drop':
|
||||
self.bot = OAClient(
|
||||
@@ -119,7 +116,9 @@ class OfficialAccountAdapter(adapter.MessagePlatformAdapter):
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: type,
|
||||
callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None],
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
|
||||
],
|
||||
):
|
||||
async def on_message(event: OAEvent):
|
||||
self.bot_account_id = event.receiver_id
|
||||
@@ -150,6 +149,8 @@ class OfficialAccountAdapter(adapter.MessagePlatformAdapter):
|
||||
async def unregister_listener(
|
||||
self,
|
||||
event_type: type,
|
||||
callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None],
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
|
||||
],
|
||||
):
|
||||
return super().unregister_listener(event_type, callback)
|
||||
|
||||
@@ -5,19 +5,18 @@ import traceback
|
||||
|
||||
import datetime
|
||||
|
||||
from pkg.platform.adapter import MessagePlatformAdapter
|
||||
from pkg.platform.types import events as platform_events, message as platform_message
|
||||
from .. import adapter
|
||||
from ...core import app
|
||||
from ..types import entities as platform_entities
|
||||
from ...command.errors import ParamNotEnoughError
|
||||
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
|
||||
from langbot_plugin.api.entities.builtin.command import errors as command_errors
|
||||
from libs.qq_official_api.api import QQOfficialClient
|
||||
from libs.qq_official_api.qqofficialevent import QQOfficialEvent
|
||||
from ...utils import image
|
||||
from ..logger import EventLogger
|
||||
|
||||
|
||||
class QQOfficialMessageConverter(adapter.MessageConverter):
|
||||
class QQOfficialMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||
@staticmethod
|
||||
async def yiri2target(message_chain: platform_message.MessageChain):
|
||||
content_list = []
|
||||
@@ -46,7 +45,7 @@ class QQOfficialMessageConverter(adapter.MessageConverter):
|
||||
return chain
|
||||
|
||||
|
||||
class QQOfficialEventConverter(adapter.EventConverter):
|
||||
class QQOfficialEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||
@staticmethod
|
||||
async def yiri2target(event: platform_events.MessageEvent) -> QQOfficialEvent:
|
||||
return event.source_platform_object
|
||||
@@ -132,17 +131,15 @@ class QQOfficialEventConverter(adapter.EventConverter):
|
||||
)
|
||||
|
||||
|
||||
class QQOfficialAdapter(adapter.MessagePlatformAdapter):
|
||||
class QQOfficialAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
bot: QQOfficialClient
|
||||
ap: app.Application
|
||||
config: dict
|
||||
bot_account_id: str
|
||||
message_converter: QQOfficialMessageConverter = QQOfficialMessageConverter()
|
||||
event_converter: QQOfficialEventConverter = QQOfficialEventConverter()
|
||||
|
||||
def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
|
||||
def __init__(self, config: dict, logger: EventLogger):
|
||||
self.config = config
|
||||
self.ap = ap
|
||||
self.logger = logger
|
||||
|
||||
required_keys = [
|
||||
@@ -151,7 +148,7 @@ class QQOfficialAdapter(adapter.MessagePlatformAdapter):
|
||||
]
|
||||
missing_keys = [key for key in required_keys if key not in config]
|
||||
if missing_keys:
|
||||
raise ParamNotEnoughError('QQ官方机器人缺少相关配置项,请查看文档或联系管理员')
|
||||
raise command_errors.ParamNotEnoughError('QQ官方机器人缺少相关配置项,请查看文档或联系管理员')
|
||||
|
||||
self.bot = QQOfficialClient(
|
||||
app_id=config['appid'], secret=config['secret'], token=config['token'], logger=self.logger
|
||||
@@ -215,7 +212,9 @@ class QQOfficialAdapter(adapter.MessagePlatformAdapter):
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
|
||||
],
|
||||
):
|
||||
async def on_message(event: QQOfficialEvent):
|
||||
self.bot_account_id = 'justbot'
|
||||
@@ -248,6 +247,8 @@ class QQOfficialAdapter(adapter.MessagePlatformAdapter):
|
||||
def unregister_listener(
|
||||
self,
|
||||
event_type: type,
|
||||
callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None],
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
|
||||
],
|
||||
):
|
||||
return super().unregister_listener(event_type, callback)
|
||||
|
||||
@@ -6,18 +6,17 @@ import traceback
|
||||
import datetime
|
||||
|
||||
from libs.slack_api.api import SlackClient
|
||||
from pkg.platform.adapter import MessagePlatformAdapter
|
||||
from pkg.platform.types import events as platform_events, message as platform_message
|
||||
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
|
||||
from libs.slack_api.slackevent import SlackEvent
|
||||
from pkg.core import app
|
||||
from .. import adapter
|
||||
from ..types import entities as platform_entities
|
||||
from ...command.errors import ParamNotEnoughError
|
||||
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
|
||||
from langbot_plugin.api.entities.builtin.command import errors as command_errors
|
||||
from ...utils import image
|
||||
from ..logger import EventLogger
|
||||
|
||||
|
||||
class SlackMessageConverter(adapter.MessageConverter):
|
||||
class SlackMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||
@staticmethod
|
||||
async def yiri2target(message_chain: platform_message.MessageChain):
|
||||
content_list = []
|
||||
@@ -44,7 +43,7 @@ class SlackMessageConverter(adapter.MessageConverter):
|
||||
return chain
|
||||
|
||||
|
||||
class SlackEventConverter(adapter.EventConverter):
|
||||
class SlackEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||
@staticmethod
|
||||
async def yiri2target(event: platform_events.MessageEvent) -> SlackEvent:
|
||||
return event.source_platform_object
|
||||
@@ -84,17 +83,15 @@ class SlackEventConverter(adapter.EventConverter):
|
||||
)
|
||||
|
||||
|
||||
class SlackAdapter(adapter.MessagePlatformAdapter):
|
||||
class SlackAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
bot: SlackClient
|
||||
ap: app.Application
|
||||
bot_account_id: str
|
||||
message_converter: SlackMessageConverter = SlackMessageConverter()
|
||||
event_converter: SlackEventConverter = SlackEventConverter()
|
||||
config: dict
|
||||
|
||||
def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
|
||||
def __init__(self, config: dict, logger: EventLogger):
|
||||
self.config = config
|
||||
self.ap = ap
|
||||
self.logger = logger
|
||||
required_keys = [
|
||||
'bot_token',
|
||||
@@ -102,7 +99,7 @@ class SlackAdapter(adapter.MessagePlatformAdapter):
|
||||
]
|
||||
missing_keys = [key for key in required_keys if key not in config]
|
||||
if missing_keys:
|
||||
raise ParamNotEnoughError('Slack机器人缺少相关配置项,请查看文档或联系管理员')
|
||||
raise command_errors.ParamNotEnoughError('Slack机器人缺少相关配置项,请查看文档或联系管理员')
|
||||
|
||||
self.bot = SlackClient(
|
||||
bot_token=self.config['bot_token'], signing_secret=self.config['signing_secret'], logger=self.logger
|
||||
@@ -135,7 +132,9 @@ class SlackAdapter(adapter.MessagePlatformAdapter):
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
|
||||
],
|
||||
):
|
||||
async def on_message(event: SlackEvent):
|
||||
self.bot_account_id = 'SlackBot'
|
||||
@@ -166,6 +165,8 @@ class SlackAdapter(adapter.MessagePlatformAdapter):
|
||||
async def unregister_listener(
|
||||
self,
|
||||
event_type: type,
|
||||
callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None],
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
|
||||
],
|
||||
):
|
||||
return super().unregister_listener(event_type, callback)
|
||||
|
||||
@@ -10,18 +10,16 @@ import typing
|
||||
import traceback
|
||||
import base64
|
||||
import aiohttp
|
||||
import pydantic
|
||||
|
||||
from lark_oapi.api.im.v1 import *
|
||||
|
||||
from .. import adapter
|
||||
from ...core import app
|
||||
from ..types import message as platform_message
|
||||
from ..types import events as platform_events
|
||||
from ..types import entities as platform_entities
|
||||
from ..logger import EventLogger
|
||||
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
|
||||
import langbot_plugin.api.definition.abstract.platform.event_logger as abstract_platform_logger
|
||||
|
||||
|
||||
class TelegramMessageConverter(adapter.MessageConverter):
|
||||
class TelegramMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||
@staticmethod
|
||||
async def yiri2target(message_chain: platform_message.MessageChain, bot: telegram.Bot) -> list[dict]:
|
||||
components = []
|
||||
@@ -90,7 +88,7 @@ class TelegramMessageConverter(adapter.MessageConverter):
|
||||
return platform_message.MessageChain(message_components)
|
||||
|
||||
|
||||
class TelegramEventConverter(adapter.EventConverter):
|
||||
class TelegramEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||
@staticmethod
|
||||
async def yiri2target(event: platform_events.MessageEvent, bot: telegram.Bot):
|
||||
return event.source_platform_object
|
||||
@@ -132,17 +130,14 @@ class TelegramEventConverter(adapter.EventConverter):
|
||||
)
|
||||
|
||||
|
||||
class TelegramAdapter(adapter.MessagePlatformAdapter):
|
||||
bot: telegram.Bot
|
||||
application: telegram.ext.Application
|
||||
|
||||
bot_account_id: str
|
||||
class TelegramAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
bot: telegram.Bot = pydantic.Field(exclude=True)
|
||||
application: telegram.ext.Application = pydantic.Field(exclude=True)
|
||||
|
||||
message_converter: TelegramMessageConverter = TelegramMessageConverter()
|
||||
event_converter: TelegramEventConverter = TelegramEventConverter()
|
||||
|
||||
config: dict
|
||||
ap: app.Application
|
||||
|
||||
msg_stream_id: dict # 流式消息id字典,key为流式消息id,value为首次消息源id,用于在流式消息时判断编辑那条消息
|
||||
|
||||
@@ -150,16 +145,10 @@ class TelegramAdapter(adapter.MessagePlatformAdapter):
|
||||
|
||||
listeners: typing.Dict[
|
||||
typing.Type[platform_events.Event],
|
||||
typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
||||
typing.Callable[[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None],
|
||||
] = {}
|
||||
|
||||
def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
|
||||
self.config = config
|
||||
self.ap = ap
|
||||
self.logger = logger
|
||||
self.msg_stream_id = {}
|
||||
# self.seq = 1
|
||||
|
||||
def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger):
|
||||
async def telegram_callback(update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||
if update.message.from_user.is_bot:
|
||||
return
|
||||
@@ -171,10 +160,18 @@ class TelegramAdapter(adapter.MessagePlatformAdapter):
|
||||
except Exception:
|
||||
await self.logger.error(f'Error in telegram callback: {traceback.format_exc()}')
|
||||
|
||||
self.application = ApplicationBuilder().token(self.config['token']).build()
|
||||
self.bot = self.application.bot
|
||||
self.application.add_handler(
|
||||
MessageHandler(filters.TEXT | (filters.COMMAND) | filters.PHOTO, telegram_callback)
|
||||
application = ApplicationBuilder().token(config['token']).build()
|
||||
bot = application.bot
|
||||
application.add_handler(MessageHandler(filters.TEXT | (filters.COMMAND) | filters.PHOTO, telegram_callback))
|
||||
super().__init__(
|
||||
config=config,
|
||||
logger=logger,
|
||||
msg_stream_id={},
|
||||
seq=1,
|
||||
bot=bot,
|
||||
application=application,
|
||||
bot_account_id='',
|
||||
listeners={},
|
||||
)
|
||||
|
||||
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
|
||||
@@ -278,14 +275,18 @@ class TelegramAdapter(adapter.MessagePlatformAdapter):
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
|
||||
],
|
||||
):
|
||||
self.listeners[event_type] = callback
|
||||
|
||||
def unregister_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
|
||||
],
|
||||
):
|
||||
self.listeners.pop(event_type)
|
||||
|
||||
|
||||
@@ -3,17 +3,19 @@ import logging
|
||||
import typing
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
import pydantic
|
||||
|
||||
from .. import adapter as msadapter
|
||||
from ..types import events as platform_events, message as platform_message, entities as platform_entities
|
||||
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
|
||||
import langbot_plugin.api.definition.abstract.platform.event_logger as abstract_platform_logger
|
||||
from ...core import app
|
||||
from ..logger import EventLogger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WebChatMessage(BaseModel):
|
||||
class WebChatMessage(pydantic.BaseModel):
|
||||
id: int
|
||||
role: str
|
||||
content: str
|
||||
@@ -41,30 +43,35 @@ class WebChatSession:
|
||||
return self.message_lists[pipeline_uuid]
|
||||
|
||||
|
||||
class WebChatAdapter(msadapter.MessagePlatformAdapter):
|
||||
class WebChatAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
"""WebChat调试适配器,用于流水线调试"""
|
||||
|
||||
webchat_person_session: WebChatSession
|
||||
webchat_group_session: WebChatSession
|
||||
webchat_person_session: WebChatSession = pydantic.Field(exclude=True, default_factory=WebChatSession)
|
||||
webchat_group_session: WebChatSession = pydantic.Field(exclude=True, default_factory=WebChatSession)
|
||||
|
||||
listeners: typing.Dict[
|
||||
listeners: dict[
|
||||
typing.Type[platform_events.Event],
|
||||
typing.Callable[[platform_events.Event, msadapter.MessagePlatformAdapter], None],
|
||||
] = {}
|
||||
typing.Callable[[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None],
|
||||
] = pydantic.Field(default_factory=dict, exclude=True)
|
||||
|
||||
is_stream: bool
|
||||
is_stream: bool = pydantic.Field(exclude=True)
|
||||
debug_messages: dict[str, list[dict]] = pydantic.Field(default_factory=dict, exclude=True)
|
||||
|
||||
def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
|
||||
self.ap = ap
|
||||
self.logger = logger
|
||||
self.config = config
|
||||
ap: app.Application = pydantic.Field(exclude=True)
|
||||
|
||||
def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger, **kwargs):
|
||||
super().__init__(
|
||||
config=config,
|
||||
logger=logger,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self.webchat_person_session = WebChatSession(id='webchatperson')
|
||||
self.webchat_group_session = WebChatSession(id='webchatgroup')
|
||||
|
||||
self.bot_account_id = 'webchatbot'
|
||||
|
||||
self.is_stream = False
|
||||
self.debug_messages = {}
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
@@ -159,7 +166,9 @@ class WebChatAdapter(msadapter.MessagePlatformAdapter):
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
func: typing.Callable[[platform_events.Event, msadapter.MessagePlatformAdapter], typing.Awaitable[None]],
|
||||
func: typing.Callable[
|
||||
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], typing.Awaitable[None]
|
||||
],
|
||||
):
|
||||
"""注册事件监听器"""
|
||||
self.listeners[event_type] = func
|
||||
@@ -167,11 +176,16 @@ class WebChatAdapter(msadapter.MessagePlatformAdapter):
|
||||
def unregister_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
func: typing.Callable[[platform_events.Event, msadapter.MessagePlatformAdapter], typing.Awaitable[None]],
|
||||
func: typing.Callable[
|
||||
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], typing.Awaitable[None]
|
||||
],
|
||||
):
|
||||
"""取消注册事件监听器"""
|
||||
del self.listeners[event_type]
|
||||
|
||||
async def is_muted(self, group_id: int) -> bool:
|
||||
return False
|
||||
|
||||
async def run_async(self):
|
||||
"""运行适配器"""
|
||||
await self.logger.info('WebChat调试适配器已启动')
|
||||
@@ -221,7 +235,7 @@ class WebChatAdapter(msadapter.MessagePlatformAdapter):
|
||||
message_chain.insert(0, platform_message.Source(id=message_id, time=datetime.now().timestamp()))
|
||||
|
||||
if session_type == 'person':
|
||||
sender = platform_entities.Friend(id='webchatperson', nickname='User')
|
||||
sender = platform_entities.Friend(id='webchatperson', nickname='User', remark='User')
|
||||
event = platform_events.FriendMessage(
|
||||
sender=sender, message_chain=message_chain, time=datetime.now().timestamp()
|
||||
)
|
||||
|
||||
@@ -16,24 +16,30 @@ import threading
|
||||
|
||||
import quart
|
||||
|
||||
from .. import adapter
|
||||
from ...core import app
|
||||
from ..types import message as platform_message
|
||||
from ..types import events as platform_events
|
||||
from ..types import entities as platform_entities
|
||||
from ..logger import EventLogger
|
||||
import xml.etree.ElementTree as ET
|
||||
from typing import Optional, Tuple
|
||||
from functools import partial
|
||||
import logging
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
|
||||
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
|
||||
import langbot_plugin.api.definition.abstract.platform.event_logger as abstract_platform_logger
|
||||
|
||||
|
||||
class WeChatPadMessageConverter(adapter.MessageConverter):
|
||||
def __init__(self, config: dict, logger: logging.Logger):
|
||||
class WeChatPadMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||
def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger):
|
||||
self.bot = WeChatPadClient(config['wechatpad_url'], config['token'])
|
||||
self.config = config
|
||||
self.bot = WeChatPadClient(self.config['wechatpad_url'], self.config['token'])
|
||||
self.logger = logger
|
||||
|
||||
# super().__init__(
|
||||
# config = config,
|
||||
# bot = bot,
|
||||
# logger = logger,
|
||||
# )
|
||||
|
||||
@staticmethod
|
||||
async def yiri2target(message_chain: platform_message.MessageChain) -> list[dict]:
|
||||
content_list = []
|
||||
@@ -447,11 +453,16 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
|
||||
return from_user_name.endswith('@chatroom')
|
||||
|
||||
|
||||
class WeChatPadEventConverter(adapter.EventConverter):
|
||||
class WeChatPadEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||
def __init__(self, config: dict, logger: logging.Logger):
|
||||
self.config = config
|
||||
self.message_converter = WeChatPadMessageConverter(config, logger)
|
||||
self.logger = logger
|
||||
self.message_converter = WeChatPadMessageConverter(self.config, self.logger)
|
||||
# super().__init__(
|
||||
# config=config,
|
||||
# message_converter=message_converter,
|
||||
# logger = logger,
|
||||
# )
|
||||
|
||||
@staticmethod
|
||||
async def yiri2target(event: platform_events.MessageEvent) -> dict:
|
||||
@@ -511,7 +522,7 @@ class WeChatPadEventConverter(adapter.EventConverter):
|
||||
)
|
||||
|
||||
|
||||
class WeChatPadAdapter(adapter.MessagePlatformAdapter):
|
||||
class WeChatPadAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
name: str = 'WeChatPad' # 定义适配器名称
|
||||
|
||||
bot: WeChatPadClient
|
||||
@@ -521,29 +532,38 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
|
||||
|
||||
config: dict
|
||||
|
||||
ap: app.Application
|
||||
logger: EventLogger
|
||||
|
||||
message_converter: WeChatPadMessageConverter
|
||||
event_converter: WeChatPadEventConverter
|
||||
|
||||
listeners: typing.Dict[
|
||||
typing.Type[platform_events.Event],
|
||||
typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
||||
typing.Callable[[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None],
|
||||
] = {}
|
||||
|
||||
def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
|
||||
self.config = config
|
||||
self.ap = ap
|
||||
self.logger = logger
|
||||
self.quart_app = quart.Quart(__name__)
|
||||
def __init__(self, config: dict, logger: EventLogger):
|
||||
|
||||
self.message_converter = WeChatPadMessageConverter(config, ap.logger)
|
||||
self.event_converter = WeChatPadEventConverter(config, ap.logger)
|
||||
quart_app = quart.Quart(__name__)
|
||||
|
||||
message_converter = WeChatPadMessageConverter(config, logger)
|
||||
event_converter = WeChatPadEventConverter(config, logger)
|
||||
bot = WeChatPadClient(config['wechatpad_url'], config['token'])
|
||||
super().__init__(
|
||||
config=config,
|
||||
logger = logger,
|
||||
quart_app = quart_app,
|
||||
message_converter =message_converter,
|
||||
event_converter = event_converter,
|
||||
listeners={},
|
||||
bot_account_id ='',
|
||||
name="WeChatPad",
|
||||
bot=bot,
|
||||
|
||||
)
|
||||
|
||||
async def ws_message(self, data):
|
||||
"""处理接收到的消息"""
|
||||
# self.ap.logger.debug(f"Gewechat callback event: {data}")
|
||||
# print(data)
|
||||
|
||||
try:
|
||||
event = await self.event_converter.target2yiri(data.copy(), self.bot_account_id)
|
||||
@@ -609,9 +629,8 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
|
||||
|
||||
if handler := handler_map.get(msg['type']):
|
||||
handler(msg)
|
||||
# self.ap.logger.warning(f"未处理的消息类型: {ret}")
|
||||
else:
|
||||
self.ap.logger.warning(f'未处理的消息类型: {msg["type"]}')
|
||||
self.logger.warning(f'未处理的消息类型: {msg["type"]}')
|
||||
continue
|
||||
|
||||
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
|
||||
@@ -635,14 +654,18 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
|
||||
],
|
||||
):
|
||||
self.listeners[event_type] = callback
|
||||
|
||||
def unregister_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
|
||||
],
|
||||
):
|
||||
pass
|
||||
|
||||
@@ -653,7 +676,6 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
|
||||
if self.config['token']:
|
||||
self.bot = WeChatPadClient(self.config['wechatpad_url'], self.config['token'])
|
||||
data = self.bot.get_login_status()
|
||||
self.ap.logger.info(data)
|
||||
if data['Code'] == 300 and data['Text'] == '你已退出微信':
|
||||
response = requests.post(
|
||||
f'{self.config["wechatpad_url"]}/admin/GenAuthKey1?key={self.config["admin_key"]}',
|
||||
@@ -673,7 +695,7 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
|
||||
self.config['token'] = response.json()['Data'][0]
|
||||
|
||||
self.bot = WeChatPadClient(self.config['wechatpad_url'], self.config['token'], logger=self.logger)
|
||||
self.ap.logger.info(self.config['token'])
|
||||
await self.logger.info(self.config['token'])
|
||||
thread_1 = threading.Event()
|
||||
|
||||
def wechat_login_process():
|
||||
@@ -681,10 +703,9 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
|
||||
# login_data =self.bot.get_login_qr()
|
||||
|
||||
# url = login_data['Data']["QrCodeUrl"]
|
||||
# self.ap.logger.info(login_data)
|
||||
|
||||
profile = self.bot.get_profile()
|
||||
self.ap.logger.info(profile)
|
||||
# self.logger.info(profile)
|
||||
|
||||
self.bot_account_id = profile['Data']['userInfo']['nickName']['str']
|
||||
self.config['wxid'] = profile['Data']['userInfo']['userName']['str']
|
||||
@@ -696,27 +717,26 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
|
||||
def connect_websocket_sync() -> None:
|
||||
thread_1.wait()
|
||||
uri = f'{self.config["wechatpad_ws"]}/GetSyncMsg?key={self.config["token"]}'
|
||||
self.ap.logger.info(f'Connecting to WebSocket: {uri}')
|
||||
print(f'Connecting to WebSocket: {uri}')
|
||||
|
||||
def on_message(ws, message):
|
||||
try:
|
||||
data = json.loads(message)
|
||||
self.ap.logger.debug(f'Received message: {data}')
|
||||
# 这里需要确保ws_message是同步的,或者使用asyncio.run调用异步方法
|
||||
asyncio.run(self.ws_message(data))
|
||||
except json.JSONDecodeError:
|
||||
self.ap.logger.error(f'Non-JSON message: {message[:100]}...')
|
||||
self.logger.error(f'Non-JSON message: {message[:100]}...')
|
||||
|
||||
def on_error(ws, error):
|
||||
self.ap.logger.error(f'WebSocket error: {str(error)[:200]}')
|
||||
self.logger.error(f'WebSocket error: {str(error)[:200]}')
|
||||
|
||||
def on_close(ws, close_status_code, close_msg):
|
||||
self.ap.logger.info('WebSocket closed, reconnecting...')
|
||||
self.logger.info('WebSocket closed, reconnecting...')
|
||||
time.sleep(5)
|
||||
connect_websocket_sync() # 自动重连
|
||||
|
||||
def on_open(ws):
|
||||
self.ap.logger.info('WebSocket connected successfully!')
|
||||
self.logger.info('WebSocket connected successfully!')
|
||||
|
||||
ws = websocket.WebSocketApp(
|
||||
uri, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open
|
||||
@@ -727,10 +747,9 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
|
||||
# connect_websocket_sync()
|
||||
|
||||
# 这行代码会在WebSocket连接断开后才会执行
|
||||
# self.ap.logger.info("WebSocket client thread started")
|
||||
thread = threading.Thread(target=connect_websocket_sync, name='WebSocketClientThread', daemon=True)
|
||||
thread.start()
|
||||
self.ap.logger.info('WebSocket client thread started')
|
||||
self.logger.info('WebSocket client thread started')
|
||||
|
||||
async def kill(self) -> bool:
|
||||
pass
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user