From 209f16af7609ea80b7f8f442bdbffbed4dc82a17 Mon Sep 17 00:00:00 2001 From: "Junyan Qin (Chin)" Date: Tue, 29 Apr 2025 17:24:07 +0800 Subject: [PATCH] style: introduce ruff as linter and formatter (#1356) * style: remove necessary imports * style: fix F841 * style: fix F401 * style: fix F811 * style: fix E402 * style: fix E721 * style: fix E722 * style: fix E722 * style: fix F541 * style: ruff format * style: all passed * style: add ruff in deps * style: more ignores in ruff.toml * style: add pre-commit --- .gitignore | 1 + .pre-commit-config.yaml | 9 + libs/dify_service_api/__init__.py | 6 +- libs/dify_service_api/test.py | 27 +- libs/dify_service_api/v1/client.py | 85 ++--- libs/dify_service_api/v1/client_test.py | 6 +- libs/dingtalk_api/EchoHandler.py | 5 +- libs/dingtalk_api/api.py | 171 +++++---- libs/dingtalk_api/dingtalkevent.py | 32 +- libs/official_account_api/api.py | 234 +++++++------ libs/official_account_api/oaevent.py | 31 +- libs/qq_official_api/api.py | 210 ++++++------ libs/qq_official_api/qqofficialevent.py | 64 ++-- libs/wecom_api/WXBizMsgCrypt3.py | 33 +- libs/wecom_api/api.py | 280 ++++++++------- libs/wecom_api/ierror.py | 4 +- libs/wecom_api/wecomevent.py | 34 +- main.py | 29 +- pkg/api/http/controller/group.py | 46 ++- pkg/api/http/controller/groups/logs.py | 17 +- pkg/api/http/controller/groups/pipelines.py | 30 +- .../controller/groups/platform/adapters.py | 28 +- .../http/controller/groups/platform/bots.py | 17 +- pkg/api/http/controller/groups/plugins.py | 87 ++--- .../http/controller/groups/provider/models.py | 17 +- .../controller/groups/provider/requesters.py | 26 +- pkg/api/http/controller/groups/stats.py | 22 +- pkg/api/http/controller/groups/system.py | 45 ++- pkg/api/http/controller/groups/user.py | 29 +- pkg/api/http/controller/main.py | 87 ++--- pkg/api/http/service/bot.py | 34 +- pkg/api/http/service/model.py | 37 +- pkg/api/http/service/pipeline.py | 67 ++-- pkg/api/http/service/user.py | 12 +- pkg/audit/__init__.py | 2 +- pkg/audit/center/apigroup.py | 18 +- pkg/audit/center/groups/main.py | 38 +- pkg/audit/center/groups/plugin.py | 50 ++- pkg/audit/center/groups/usage.py | 73 ++-- pkg/audit/center/v2.py | 15 +- pkg/audit/identifier.py | 30 +- pkg/command/cmdmgr.py | 59 ++-- pkg/command/entities.py | 12 +- pkg/command/errors.py | 17 +- pkg/command/operator.py | 17 +- pkg/command/operators/cmd.py | 39 +-- pkg/command/operators/delc.py | 56 ++- pkg/command/operators/func.py | 9 +- pkg/command/operators/help.py | 12 +- pkg/command/operators/last.py | 43 ++- pkg/command/operators/list.py | 32 +- pkg/command/operators/model.py | 70 ++-- pkg/command/operators/next.py | 43 ++- pkg/command/operators/ollama.py | 71 ++-- pkg/command/operators/plugin.py | 195 ++++++----- pkg/command/operators/prompt.py | 23 +- pkg/command/operators/resend.py | 16 +- pkg/command/operators/reset.py | 17 +- pkg/command/operators/update.py | 24 +- pkg/command/operators/version.py | 20 +- pkg/config/impls/json.py | 23 +- pkg/config/impls/pymodule.py | 6 +- pkg/config/impls/yaml.py | 25 +- pkg/config/manager.py | 45 +-- pkg/config/model.py | 2 +- pkg/core/app.py | 64 +++- pkg/core/boot.py | 29 +- pkg/core/bootutils/config.py | 4 +- pkg/core/bootutils/deps.py | 83 ++--- pkg/core/bootutils/files.py | 18 +- pkg/core/bootutils/log.py | 36 +- pkg/core/entities.py | 45 ++- pkg/core/migration.py | 18 +- .../m001_sensitive_word_migration.py | 20 +- .../m002_openai_config_migration.py | 19 +- ...m003_anthropic_requester_cfg_completion.py | 19 +- .../m004_moonshot_cfg_completion.py | 15 +- .../m005_deepseek_cfg_completion.py | 17 +- pkg/core/migrations/m006_vision_config.py | 8 +- pkg/core/migrations/m007_qcg_center_url.py | 14 +- .../m008_ad_fixwin_config_migrate.py | 20 +- pkg/core/migrations/m009_msg_truncator_cfg.py | 8 +- .../m010_ollama_requester_config.py | 10 +- .../migrations/m011_command_prefix_config.py | 8 +- pkg/core/migrations/m012_runner_config.py | 4 +- pkg/core/migrations/m013_http_api_config.py | 23 +- .../migrations/m014_force_delay_config.py | 8 +- pkg/core/migrations/m015_gitee_ai_config.py | 17 +- pkg/core/migrations/m016_dify_service_api.py | 15 +- .../m017_dify_api_timeout_params.py | 12 +- pkg/core/migrations/m018_xai_config.py | 12 +- pkg/core/migrations/m019_zhipuai_config.py | 12 +- pkg/core/migrations/m020_wecom_config.py | 28 +- pkg/core/migrations/m021_lark_config.py | 26 +- pkg/core/migrations/m022_lmstudio_config.py | 10 +- .../migrations/m023_siliconflow_config.py | 18 +- pkg/core/migrations/m024_discord_config.py | 18 +- pkg/core/migrations/m025_gewechat_config.py | 26 +- pkg/core/migrations/m026_qqofficial_config.py | 22 +- .../m027_wx_official_account_config.py | 26 +- .../m028_aliyun_requester_config.py | 14 +- .../m029_dashscope_app_api_config.py | 22 +- pkg/core/migrations/m030_lark_config_cmpl.py | 6 +- pkg/core/migrations/m031_dingtalk_config.py | 22 +- pkg/core/migrations/m032_volcark_config.py | 14 +- .../migrations/m033_dify_thinking_config.py | 17 +- .../m034_gewechat_file_url_config.py | 6 +- pkg/core/migrations/m035_wxoa_mode.py | 2 +- .../migrations/m036_wxoa_loading_message.py | 2 +- pkg/core/migrations/m037_mcp_config.py | 6 +- pkg/core/note.py | 15 +- pkg/core/notes/n001_classic_msgs.py | 10 +- .../notes/n002_selection_mode_on_windows.py | 14 +- pkg/core/notes/n003_print_version.py | 12 +- pkg/core/stage.py | 12 +- pkg/core/stages/build_app.py | 27 +- pkg/core/stages/genkeys.py | 11 +- pkg/core/stages/load_config.py | 85 +++-- pkg/core/stages/migrate.py | 39 +-- pkg/core/stages/setup_logger.py | 13 +- pkg/core/stages/show_notes.py | 14 +- pkg/core/taskmgr.py | 114 +++--- pkg/discover/engine.py | 103 ++++-- pkg/entity/persistence/base.py | 1 - pkg/entity/persistence/bot.py | 12 +- pkg/entity/persistence/metadata.py | 1 + pkg/entity/persistence/model.py | 13 +- pkg/entity/persistence/pipeline.py | 24 +- pkg/entity/persistence/plugin.py | 12 +- pkg/entity/persistence/user.py | 12 +- pkg/persistence/database.py | 1 + pkg/persistence/databases/sqlite.py | 8 +- pkg/persistence/mgr.py | 80 +++-- pkg/persistence/migration.py | 2 + .../migrations/dbm001_migrate_v3_config.py | 4 +- pkg/pipeline/bansess/bansess.py | 26 +- pkg/pipeline/cntfilter/cntfilter.py | 78 ++--- pkg/pipeline/cntfilter/entities.py | 9 +- pkg/pipeline/cntfilter/filter.py | 18 +- .../cntfilter/filters/baiduexamine.py | 44 ++- pkg/pipeline/cntfilter/filters/banwords.py | 15 +- pkg/pipeline/cntfilter/filters/cntignore.py | 16 +- pkg/pipeline/controller.py | 51 +-- pkg/pipeline/entities.py | 11 +- pkg/pipeline/longtext/longtext.py | 69 ++-- pkg/pipeline/longtext/strategies/forward.py | 26 +- pkg/pipeline/longtext/strategies/image.py | 62 ++-- pkg/pipeline/longtext/strategy.py | 13 +- pkg/pipeline/msgtrun/msgtrun.py | 25 +- pkg/pipeline/msgtrun/truncator.py | 8 +- pkg/pipeline/msgtrun/truncators/round.py | 10 +- pkg/pipeline/pipelinemgr.py | 155 ++++++--- pkg/pipeline/pool.py | 2 +- pkg/pipeline/preproc/preproc.py | 49 +-- pkg/pipeline/process/handler.py | 1 - pkg/pipeline/process/handlers/chat.py | 54 +-- pkg/pipeline/process/handlers/command.py | 59 ++-- pkg/pipeline/process/process.py | 19 +- pkg/pipeline/ratelimit/algo.py | 27 +- pkg/pipeline/ratelimit/algos/fixedwin.py | 30 +- pkg/pipeline/ratelimit/ratelimit.py | 24 +- pkg/pipeline/respback/respback.py | 40 +-- pkg/pipeline/resprule/entities.py | 1 - pkg/pipeline/resprule/resprule.py | 32 +- pkg/pipeline/resprule/rule.py | 11 +- pkg/pipeline/resprule/rules/atbot.py | 20 +- pkg/pipeline/resprule/rules/prefix.py | 14 +- pkg/pipeline/resprule/rules/random.py | 12 +- pkg/pipeline/resprule/rules/regexp.py | 12 +- pkg/pipeline/stage.py | 12 +- pkg/pipeline/wrapper/wrapper.py | 98 +++--- pkg/platform/adapter.py | 30 +- pkg/platform/botmgr.py | 101 +++--- pkg/platform/sources/aiocqhttp.py | 131 ++++--- pkg/platform/sources/dingtalk.py | 113 +++--- pkg/platform/sources/discord.py | 104 +++--- pkg/platform/sources/gewechat.py | 324 ++++++++++-------- pkg/platform/sources/lark.py | 168 +++++---- pkg/platform/sources/nakuru.py | 154 +++++---- pkg/platform/sources/officialaccount.py | 108 +++--- pkg/platform/sources/qqbotpy.py | 202 ++++++----- pkg/platform/sources/qqofficial.py | 263 +++++++------- pkg/platform/sources/telegram.py | 134 ++++---- pkg/platform/sources/wecom.py | 143 ++++---- pkg/platform/types/base.py | 15 +- pkg/platform/types/entities.py | 16 +- pkg/platform/types/events.py | 25 +- pkg/platform/types/message.py | 213 +++++++----- pkg/plugin/__init__.py | 2 +- pkg/plugin/context.py | 86 +++-- pkg/plugin/errors.py | 9 +- pkg/plugin/events.py | 14 +- pkg/plugin/host.py | 6 +- pkg/plugin/installer.py | 12 +- pkg/plugin/installers/github.py | 92 ++--- pkg/plugin/loader.py | 5 +- pkg/plugin/loaders/classic.py | 114 +++--- pkg/plugin/loaders/manifest.py | 36 +- pkg/plugin/manager.py | 161 ++++----- pkg/plugin/models.py | 10 +- pkg/provider/entities.py | 27 +- pkg/provider/modelmgr/errors.py | 2 +- pkg/provider/modelmgr/modelmgr.py | 55 +-- pkg/provider/modelmgr/requester.py | 17 +- .../modelmgr/requesters/anthropicmsgs.py | 108 +++--- .../modelmgr/requesters/bailianchatcmpl.py | 2 - pkg/provider/modelmgr/requesters/chatcmpl.py | 68 ++-- .../modelmgr/requesters/deepseekchatcmpl.py | 16 +- .../modelmgr/requesters/giteeaichatcmpl.py | 18 +- .../modelmgr/requesters/lmstudiochatcmpl.py | 2 - .../modelmgr/requesters/moonshotchatcmpl.py | 19 +- .../modelmgr/requesters/ollamachat.py | 61 ++-- .../requesters/siliconflowchatcmpl.py | 2 - .../modelmgr/requesters/volcarkchatcmpl.py | 2 - .../modelmgr/requesters/xaichatcmpl.py | 2 - .../modelmgr/requesters/zhipuaichatcmpl.py | 2 - pkg/provider/modelmgr/token.py | 7 +- pkg/provider/runner.py | 16 +- pkg/provider/runners/dashscopeapi.py | 228 ++++++------ pkg/provider/runners/difysvapi.py | 200 ++++++----- pkg/provider/runners/localagent.py | 34 +- pkg/provider/session/sessionmgr.py | 21 +- pkg/provider/tools/entities.py | 4 - pkg/provider/tools/loader.py | 15 +- pkg/provider/tools/loaders/mcp.py | 74 ++-- pkg/provider/tools/loaders/plugin.py | 30 +- pkg/provider/tools/toolmgr.py | 34 +- pkg/utils/announce.py | 67 ++-- pkg/utils/constants.py | 4 +- pkg/utils/funcschema.py | 53 ++- pkg/utils/image.py | 100 +++--- pkg/utils/importutil.py | 43 +++ pkg/utils/ip.py | 11 +- pkg/utils/logcache.py | 9 +- pkg/utils/pkgmgr.py | 27 +- pkg/utils/proxy.py | 26 +- pkg/utils/version.py | 134 ++++---- requirements.txt | 4 +- res/scripts/publish_announcement.py | 24 +- ruff.toml | 38 ++ 240 files changed, 5307 insertions(+), 4689 deletions(-) create mode 100644 .pre-commit-config.yaml create mode 100644 pkg/utils/importutil.py create mode 100644 ruff.toml diff --git a/.gitignore b/.gitignore index 17271201..0a14ca5b 100644 --- a/.gitignore +++ b/.gitignore @@ -29,6 +29,7 @@ qcapi claude.json bard.json /*yaml +!.pre-commit-config.yaml !components.yaml !/docker-compose.yaml data/labels/instance_id.json diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..6efb3e3e --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,9 @@ +repos: +- repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.11.7 + hooks: + # Run the linter. + - id: ruff + # Run the formatter. + - id: ruff-format \ No newline at end of file diff --git a/libs/dify_service_api/__init__.py b/libs/dify_service_api/__init__.py index 5f178abb..bd6f6d4f 100644 --- a/libs/dify_service_api/__init__.py +++ b/libs/dify_service_api/__init__.py @@ -1,2 +1,4 @@ -from .v1 import client -from .v1 import errors \ No newline at end of file +from .v1 import client as client +from .v1 import errors as errors + +__all__ = ['client', 'errors'] diff --git a/libs/dify_service_api/test.py b/libs/dify_service_api/test.py index faf7571a..b7e2281c 100644 --- a/libs/dify_service_api/test.py +++ b/libs/dify_service_api/test.py @@ -8,25 +8,33 @@ import json class TestDifyClient: async def test_chat_messages(self): - cln = client.AsyncDifyServiceClient(api_key=os.getenv("DIFY_API_KEY"), base_url=os.getenv("DIFY_BASE_URL")) + cln = client.AsyncDifyServiceClient( + api_key=os.getenv('DIFY_API_KEY'), base_url=os.getenv('DIFY_BASE_URL') + ) - async for chunk in cln.chat_messages(inputs={}, query="调用工具查看现在几点?", user="test"): + async for chunk in cln.chat_messages( + inputs={}, query='调用工具查看现在几点?', user='test' + ): print(json.dumps(chunk, ensure_ascii=False, indent=4)) async def test_upload_file(self): - cln = client.AsyncDifyServiceClient(api_key=os.getenv("DIFY_API_KEY"), base_url=os.getenv("DIFY_BASE_URL")) + cln = client.AsyncDifyServiceClient( + api_key=os.getenv('DIFY_API_KEY'), base_url=os.getenv('DIFY_BASE_URL') + ) - file_bytes = open("img.png", "rb").read() + file_bytes = open('img.png', 'rb').read() print(type(file_bytes)) - file = ("img2.png", file_bytes, "image/png") + file = ('img2.png', file_bytes, 'image/png') - resp = await cln.upload_file(file=file, user="test") + resp = await cln.upload_file(file=file, user='test') print(json.dumps(resp, ensure_ascii=False, indent=4)) async def test_workflow_run(self): - cln = client.AsyncDifyServiceClient(api_key=os.getenv("DIFY_API_KEY"), base_url=os.getenv("DIFY_BASE_URL")) + cln = client.AsyncDifyServiceClient( + api_key=os.getenv('DIFY_API_KEY'), base_url=os.getenv('DIFY_BASE_URL') + ) # resp = await cln.workflow_run(inputs={}, user="test") # # print(json.dumps(resp, ensure_ascii=False, indent=4)) @@ -34,11 +42,12 @@ class TestDifyClient: chunks = [] ignored_events = ['text_chunk'] - async for chunk in cln.workflow_run(inputs={}, user="test"): + async for chunk in cln.workflow_run(inputs={}, user='test'): if chunk['event'] in ignored_events: continue chunks.append(chunk) print(json.dumps(chunks, ensure_ascii=False, indent=4)) -if __name__ == "__main__": + +if __name__ == '__main__': asyncio.run(TestDifyClient().test_chat_messages()) diff --git a/libs/dify_service_api/v1/client.py b/libs/dify_service_api/v1/client.py index 70a804b7..35defe2c 100644 --- a/libs/dify_service_api/v1/client.py +++ b/libs/dify_service_api/v1/client.py @@ -12,11 +12,11 @@ class AsyncDifyServiceClient: api_key: str base_url: str - + def __init__( self, api_key: str, - base_url: str = "https://api.dify.ai/v1", + base_url: str = 'https://api.dify.ai/v1', ) -> None: self.api_key = api_key self.base_url = base_url @@ -26,76 +26,81 @@ class AsyncDifyServiceClient: inputs: dict[str, typing.Any], query: str, user: str, - response_mode: str = "streaming", # 当前不支持 blocking - conversation_id: str = "", + response_mode: str = 'streaming', # 当前不支持 blocking + conversation_id: str = '', files: list[dict[str, typing.Any]] = [], timeout: float = 30.0, ) -> typing.AsyncGenerator[dict[str, typing.Any], None]: """发送消息""" - if response_mode != "streaming": - raise DifyAPIError("当前仅支持 streaming 模式") - + if response_mode != 'streaming': + raise DifyAPIError('当前仅支持 streaming 模式') + async with httpx.AsyncClient( base_url=self.base_url, trust_env=True, timeout=timeout, ) as client: async with client.stream( - "POST", - "/chat-messages", - headers={"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}, + 'POST', + '/chat-messages', + headers={ + 'Authorization': f'Bearer {self.api_key}', + 'Content-Type': 'application/json', + }, json={ - "inputs": inputs, - "query": query, - "user": user, - "response_mode": response_mode, - "conversation_id": conversation_id, - "files": files, + 'inputs': inputs, + 'query': query, + 'user': user, + 'response_mode': response_mode, + 'conversation_id': conversation_id, + 'files': files, }, ) as r: async for chunk in r.aiter_lines(): if r.status_code != 200: - raise DifyAPIError(f"{r.status_code} {chunk}") - if chunk.strip() == "": + raise DifyAPIError(f'{r.status_code} {chunk}') + if chunk.strip() == '': continue - if chunk.startswith("data:"): + if chunk.startswith('data:'): yield json.loads(chunk[5:]) - + async def workflow_run( self, inputs: dict[str, typing.Any], user: str, - response_mode: str = "streaming", # 当前不支持 blocking + response_mode: str = 'streaming', # 当前不支持 blocking files: list[dict[str, typing.Any]] = [], timeout: float = 30.0, ) -> typing.AsyncGenerator[dict[str, typing.Any], None]: """运行工作流""" - if response_mode != "streaming": - raise DifyAPIError("当前仅支持 streaming 模式") - + if response_mode != 'streaming': + raise DifyAPIError('当前仅支持 streaming 模式') + async with httpx.AsyncClient( base_url=self.base_url, trust_env=True, timeout=timeout, ) as client: - async with client.stream( - "POST", - "/workflows/run", - headers={"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}, + 'POST', + '/workflows/run', + headers={ + 'Authorization': f'Bearer {self.api_key}', + 'Content-Type': 'application/json', + }, json={ - "inputs": inputs, - "user": user, - "response_mode": response_mode, - "files": files, + 'inputs': inputs, + 'user': user, + 'response_mode': response_mode, + 'files': files, }, ) as r: async for chunk in r.aiter_lines(): if r.status_code != 200: - raise DifyAPIError(f"{r.status_code} {chunk}") - if chunk.strip() == "": + raise DifyAPIError(f'{r.status_code} {chunk}') + if chunk.strip() == '': continue - if chunk.startswith("data:"): + if chunk.startswith('data:'): yield json.loads(chunk[5:]) async def upload_file( @@ -112,15 +117,15 @@ class AsyncDifyServiceClient: ) as client: # multipart/form-data response = await client.post( - "/files/upload", - headers={"Authorization": f"Bearer {self.api_key}"}, + '/files/upload', + headers={'Authorization': f'Bearer {self.api_key}'}, files={ - "file": file, - "user": (None, user), + 'file': file, + 'user': (None, user), }, ) if response.status_code != 201: - raise DifyAPIError(f"{response.status_code} {response.text}") + raise DifyAPIError(f'{response.status_code} {response.text}') return response.json() diff --git a/libs/dify_service_api/v1/client_test.py b/libs/dify_service_api/v1/client_test.py index 58ef53b4..2695b2ea 100644 --- a/libs/dify_service_api/v1/client_test.py +++ b/libs/dify_service_api/v1/client_test.py @@ -7,11 +7,11 @@ import os class TestDifyClient: async def test_chat_messages(self): - cln = client.DifyClient(api_key=os.getenv("DIFY_API_KEY")) + cln = client.DifyClient(api_key=os.getenv('DIFY_API_KEY')) - resp = await cln.chat_messages(inputs={}, query="Who are you?", user_id="test") + resp = await cln.chat_messages(inputs={}, query='Who are you?', user_id='test') print(resp) -if __name__ == "__main__": +if __name__ == '__main__': asyncio.run(TestDifyClient().test_chat_messages()) diff --git a/libs/dingtalk_api/EchoHandler.py b/libs/dingtalk_api/EchoHandler.py index 4cf0f563..793c3d6d 100644 --- a/libs/dingtalk_api/EchoHandler.py +++ b/libs/dingtalk_api/EchoHandler.py @@ -1,8 +1,8 @@ import asyncio -import json import dingtalk_stream from dingtalk_stream import AckMessage + class EchoTextHandler(dingtalk_stream.ChatbotHandler): def __init__(self, client): self.msg_id = '' @@ -10,6 +10,7 @@ class EchoTextHandler(dingtalk_stream.ChatbotHandler): self.client = client # 用于更新 DingTalkClient 中的 incoming_message """处理钉钉消息""" + async def process(self, callback: dingtalk_stream.CallbackMessage): incoming_message = dingtalk_stream.ChatbotMessage.from_dict(callback.data) if incoming_message.message_id != self.msg_id: @@ -26,6 +27,8 @@ class EchoTextHandler(dingtalk_stream.ChatbotHandler): return self.incoming_message + async def get_dingtalk_client(client_id, client_secret): from api import DingTalkClient # 延迟导入,避免循环导入 + return DingTalkClient(client_id, client_secret) diff --git a/libs/dingtalk_api/api.py b/libs/dingtalk_api/api.py index fa4d0421..b908fd4f 100644 --- a/libs/dingtalk_api/api.py +++ b/libs/dingtalk_api/api.py @@ -10,7 +10,9 @@ import traceback class DingTalkClient: - def __init__(self, client_id: str, client_secret: str,robot_name:str,robot_code:str): + def __init__( + self, client_id: str, client_secret: str, robot_name: str, robot_code: str + ): """初始化 WebSocket 连接并自动启动""" self.credential = dingtalk_stream.Credential(client_id, client_secret) self.client = dingtalk_stream.DingTalkStreamClient(self.credential) @@ -18,106 +20,91 @@ class DingTalkClient: self.secret = client_secret # 在 DingTalkClient 中传入自己作为参数,避免循环导入 self.EchoTextHandler = EchoTextHandler(self) - self.client.register_callback_handler(dingtalk_stream.chatbot.ChatbotMessage.TOPIC, self.EchoTextHandler) + self.client.register_callback_handler( + dingtalk_stream.chatbot.ChatbotMessage.TOPIC, self.EchoTextHandler + ) self._message_handlers = { - "example":[], + 'example': [], } self.access_token = '' self.robot_name = robot_name self.robot_code = robot_code self.access_token_expiry_time = '' - - async def get_access_token(self): - url = "https://api.dingtalk.com/v1.0/oauth2/accessToken" - headers = { - "Content-Type": "application/json" - } - data = { - "appKey": self.key, - "appSecret": self.secret - } + url = 'https://api.dingtalk.com/v1.0/oauth2/accessToken' + headers = {'Content-Type': 'application/json'} + data = {'appKey': self.key, 'appSecret': self.secret} async with httpx.AsyncClient() as client: try: - response = await client.post(url,json=data,headers=headers) + response = await client.post(url, json=data, headers=headers) if response.status_code == 200: response_data = response.json() - self.access_token = response_data.get("accessToken") - expires_in = int(response_data.get("expireIn",7200)) + self.access_token = response_data.get('accessToken') + expires_in = int(response_data.get('expireIn', 7200)) self.access_token_expiry_time = time.time() + expires_in - 60 except Exception as e: raise Exception(e) - async def is_token_expired(self): """检查token是否过期""" if self.access_token_expiry_time is None: return True return time.time() > self.access_token_expiry_time - + async def check_access_token(self): if not self.access_token or await self.is_token_expired(): return False return bool(self.access_token and self.access_token.strip()) - async def download_image(self,download_code:str): + async def download_image(self, download_code: str): if not await self.check_access_token(): await self.get_access_token() url = 'https://api.dingtalk.com/v1.0/robot/messageFiles/download' - params = { - "downloadCode":download_code, - "robotCode":self.robot_code - } - headers ={ - "x-acs-dingtalk-access-token": self.access_token - } + params = {'downloadCode': download_code, 'robotCode': self.robot_code} + headers = {'x-acs-dingtalk-access-token': self.access_token} async with httpx.AsyncClient() as client: response = await client.post(url, headers=headers, json=params) if response.status_code == 200: result = response.json() - download_url = result.get("downloadUrl") + download_url = result.get('downloadUrl') else: - raise Exception(f"Error: {response.status_code}, {response.text}") + raise Exception(f'Error: {response.status_code}, {response.text}') if download_url: return await self.download_url_to_base64(download_url) - async def download_url_to_base64(self,download_url): + async def download_url_to_base64(self, download_url): async with httpx.AsyncClient() as client: response = await client.get(download_url) - + if response.status_code == 200: - file_bytes = response.content - base64_str = base64.b64encode(file_bytes).decode('utf-8') # 返回字符串格式 + base64_str = base64.b64encode(file_bytes).decode( + 'utf-8' + ) # 返回字符串格式 return base64_str else: - raise Exception("获取文件失败") - - async def get_audio_url(self,download_code:str): + raise Exception('获取文件失败') + + async def get_audio_url(self, download_code: str): if not await self.check_access_token(): await self.get_access_token() url = 'https://api.dingtalk.com/v1.0/robot/messageFiles/download' - params = { - "downloadCode":download_code, - "robotCode":self.robot_code - } - headers ={ - "x-acs-dingtalk-access-token": self.access_token - } + params = {'downloadCode': download_code, 'robotCode': self.robot_code} + headers = {'x-acs-dingtalk-access-token': self.access_token} async with httpx.AsyncClient() as client: response = await client.post(url, headers=headers, json=params) if response.status_code == 200: result = response.json() - download_url = result.get("downloadUrl") + download_url = result.get('downloadUrl') if download_url: return await self.download_url_to_base64(download_url) else: - raise Exception("获取音频失败") + raise Exception('获取音频失败') else: - raise Exception(f"Error: {response.status_code}, {response.text}") - + raise Exception(f'Error: {response.status_code}, {response.text}') + async def update_incoming_message(self, message): """异步更新 DingTalkClient 中的 incoming_message""" message_data = await self.get_message(message) @@ -125,24 +112,21 @@ class DingTalkClient: event = DingTalkEvent.from_payload(message_data) if event: await self._handle_message(event) - - async def send_message(self,content:str,incoming_message): - self.EchoTextHandler.reply_text(content,incoming_message) - + async def send_message(self, content: str, incoming_message): + self.EchoTextHandler.reply_text(content, incoming_message) async def get_incoming_message(self): """获取收到的消息""" return await self.EchoTextHandler.get_incoming_message() - - def on_message(self, msg_type: str): def decorator(func: Callable[[DingTalkEvent], None]): if msg_type not in self._message_handlers: self._message_handlers[msg_type] = [] self._message_handlers[msg_type].append(func) return func + return decorator async def _handle_message(self, event: DingTalkEvent): @@ -154,40 +138,44 @@ class DingTalkClient: for handler in self._message_handlers[msg_type]: await handler(event) - - async def get_message(self,incoming_message:dingtalk_stream.chatbot.ChatbotMessage): + async def get_message( + self, incoming_message: dingtalk_stream.chatbot.ChatbotMessage + ): try: - # print(json.dumps(incoming_message.to_dict(), indent=4, ensure_ascii=False)) message_data = { - "IncomingMessage":incoming_message, + 'IncomingMessage': incoming_message, } if str(incoming_message.conversation_type) == '1': - message_data["conversation_type"] = 'FriendMessage' + message_data['conversation_type'] = 'FriendMessage' elif str(incoming_message.conversation_type) == '2': - message_data["conversation_type"] = 'GroupMessage' + message_data['conversation_type'] = 'GroupMessage' - if incoming_message.message_type == 'richText': - data = incoming_message.rich_text_content.to_dict() for item in data['richText']: if 'text' in item: - message_data["Content"] = item['text'] + message_data['Content'] = item['text'] if incoming_message.get_image_list()[0]: - message_data["Picture"] = await self.download_image(incoming_message.get_image_list()[0]) - message_data["Type"] = 'text' - + message_data['Picture'] = await self.download_image( + incoming_message.get_image_list()[0] + ) + message_data['Type'] = 'text' + elif incoming_message.message_type == 'text': message_data['Content'] = incoming_message.get_text_list()[0] - message_data["Type"] = 'text' + message_data['Type'] = 'text' elif incoming_message.message_type == 'picture': - message_data['Picture'] = await self.download_image(incoming_message.get_image_list()[0]) - + message_data['Picture'] = await self.download_image( + incoming_message.get_image_list()[0] + ) + message_data['Type'] = 'image' elif incoming_message.message_type == 'audio': - message_data['Audio'] = await self.get_audio_url(incoming_message.to_dict()['content']['downloadCode']) + message_data['Audio'] = await self.get_audio_url( + incoming_message.to_dict()['content']['downloadCode'] + ) message_data['Type'] = 'audio' @@ -196,56 +184,55 @@ class DingTalkClient: # print("message_data:", json.dumps(copy_message_data, indent=4, ensure_ascii=False)) except Exception: traceback.print_exc() - + return message_data - async def send_proactive_message_to_one(self,target_id:str,content:str): + async def send_proactive_message_to_one(self, target_id: str, content: str): if not await self.check_access_token(): await self.get_access_token() url = 'https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend' - headers ={ - "x-acs-dingtalk-access-token":self.access_token, - "Content-Type":"application/json", + headers = { + 'x-acs-dingtalk-access-token': self.access_token, + 'Content-Type': 'application/json', } - data ={ - "robotCode":self.robot_code, - "userIds":[target_id], - "msgKey": "sampleText", - "msgParam": json.dumps({"content":content}), + data = { + 'robotCode': self.robot_code, + 'userIds': [target_id], + 'msgKey': 'sampleText', + 'msgParam': json.dumps({'content': content}), } try: async with httpx.AsyncClient() as client: - response = await client.post(url,headers=headers,json=data) + await client.post(url, headers=headers, json=data) except Exception: traceback.print_exc() - - async def send_proactive_message_to_group(self,target_id:str,content:str): + async def send_proactive_message_to_group(self, target_id: str, content: str): if not await self.check_access_token(): await self.get_access_token() url = 'https://api.dingtalk.com/v1.0/robot/groupMessages/send' - headers ={ - "x-acs-dingtalk-access-token":self.access_token, - "Content-Type":"application/json", + headers = { + 'x-acs-dingtalk-access-token': self.access_token, + 'Content-Type': 'application/json', } - data ={ - "robotCode":self.robot_code, - "openConversationId":target_id, - "msgKey": "sampleText", - "msgParam": json.dumps({"content":content}), + data = { + 'robotCode': self.robot_code, + 'openConversationId': target_id, + 'msgKey': 'sampleText', + 'msgParam': json.dumps({'content': content}), } try: async with httpx.AsyncClient() as client: - response = await client.post(url,headers=headers,json=data) + await client.post(url, headers=headers, json=data) except Exception: traceback.print_exc() - + async def start(self): """启动 WebSocket 连接,监听消息""" - await self.client.start() + await self.client.start() diff --git a/libs/dingtalk_api/dingtalkevent.py b/libs/dingtalk_api/dingtalkevent.py index 4feca010..df968e74 100644 --- a/libs/dingtalk_api/dingtalkevent.py +++ b/libs/dingtalk_api/dingtalkevent.py @@ -1,41 +1,39 @@ from typing import Dict, Any, Optional import dingtalk_stream + class DingTalkEvent(dict): @staticmethod - def from_payload(payload: Dict[str, Any]) -> Optional["DingTalkEvent"]: + def from_payload(payload: Dict[str, Any]) -> Optional['DingTalkEvent']: try: event = DingTalkEvent(payload) return event except KeyError: return None - - - @property - def content(self): - return self.get("Content","") @property - def incoming_message(self) -> Optional["dingtalk_stream.chatbot.ChatbotMessage"]: - return self.get("IncomingMessage") + def content(self): + return self.get('Content', '') + + @property + def incoming_message(self) -> Optional['dingtalk_stream.chatbot.ChatbotMessage']: + return self.get('IncomingMessage') @property def type(self): - return self.get("Type","") - + return self.get('Type', '') + @property def picture(self): - return self.get("Picture","") - + return self.get('Picture', '') + @property def audio(self): - return self.get("Audio","") + return self.get('Audio', '') @property def conversation(self): - return self.get("conversation_type","") - - + return self.get('conversation_type', '') def __getattr__(self, key: str) -> Optional[Any]: """ @@ -66,4 +64,4 @@ class DingTalkEvent(dict): Returns: str: 字符串表示。 """ - return f"" + return f'' diff --git a/libs/official_account_api/api.py b/libs/official_account_api/api.py index a8d318dc..fc392c30 100644 --- a/libs/official_account_api/api.py +++ b/libs/official_account_api/api.py @@ -1,20 +1,14 @@ # 微信公众号的加解密算法与企业微信一样,所以直接使用企业微信的加解密算法文件 -from collections import deque import time import traceback from ..wecom_api.WXBizMsgCrypt3 import WXBizMsgCrypt import xml.etree.ElementTree as ET -from quart import Quart,request +from quart import Quart, request import hashlib -from typing import Callable, Dict, Any +from typing import Callable from .oaevent import OAEvent -import httpx import asyncio -import time -import xml.etree.ElementTree as ET -from pkg.platform.sources import officialaccount as oa - xml_template = """ @@ -28,9 +22,8 @@ xml_template = """ """ -class OAClient(): - - def __init__(self,token:str,EncodingAESKey:str,AppID:str,Appsecret:str): +class OAClient: + def __init__(self, token: str, EncodingAESKey: str, AppID: str, Appsecret: str): self.token = token self.aes = EncodingAESKey self.appid = AppID @@ -38,121 +31,130 @@ class OAClient(): self.base_url = 'https://api.weixin.qq.com' self.access_token = '' self.app = Quart(__name__) - self.app.add_url_rule('/callback/command', 'handle_callback', self.handle_callback_request, methods=['GET', 'POST']) + self.app.add_url_rule( + '/callback/command', + 'handle_callback', + self.handle_callback_request, + methods=['GET', 'POST'], + ) self._message_handlers = { - "example":[], + 'example': [], } self.access_token_expiry_time = None self.msg_id_map = {} self.generated_content = {} async def handle_callback_request(self): - try: # 每隔100毫秒查询是否生成ai回答 start_time = time.time() - signature = request.args.get("signature", "") - timestamp = request.args.get("timestamp", "") - nonce = request.args.get("nonce", "") - echostr = request.args.get("echostr", "") - msg_signature = request.args.get("msg_signature","") + signature = request.args.get('signature', '') + timestamp = request.args.get('timestamp', '') + nonce = request.args.get('nonce', '') + echostr = request.args.get('echostr', '') + msg_signature = request.args.get('msg_signature', '') if msg_signature is None: - raise Exception("msg_signature不在请求体中") + raise Exception('msg_signature不在请求体中') if request.method == 'GET': # 校验签名 - check_str = "".join(sorted([self.token, timestamp, nonce])) - check_signature = hashlib.sha1(check_str.encode("utf-8")).hexdigest() - + check_str = ''.join(sorted([self.token, timestamp, nonce])) + check_signature = hashlib.sha1(check_str.encode('utf-8')).hexdigest() + if check_signature == signature: return echostr # 验证成功返回echostr else: - raise Exception("拒绝请求") - elif request.method == "POST": + raise Exception('拒绝请求') + elif request.method == 'POST': encryt_msg = await request.data - wxcpt = WXBizMsgCrypt(self.token,self.aes,self.appid) - ret,xml_msg = wxcpt.DecryptMsg(encryt_msg,msg_signature,timestamp,nonce) + wxcpt = WXBizMsgCrypt(self.token, self.aes, self.appid) + ret, xml_msg = wxcpt.DecryptMsg( + encryt_msg, msg_signature, timestamp, nonce + ) xml_msg = xml_msg.decode('utf-8') if ret != 0: - raise Exception("消息解密失败") + raise Exception('消息解密失败') message_data = await self.get_message(xml_msg) - if message_data : + if message_data: event = OAEvent.from_payload(message_data) if event: await self._handle_message(event) root = ET.fromstring(xml_msg) - from_user = root.find("FromUserName").text # 发送者 - to_user = root.find("ToUserName").text # 机器人 - + from_user = root.find('FromUserName').text # 发送者 + to_user = root.find('ToUserName').text # 机器人 + timeout = 4.80 interval = 0.1 while True: - content = self.generated_content.pop(message_data["MsgId"], None) + content = self.generated_content.pop(message_data['MsgId'], None) if content: response_xml = xml_template.format( to_user=from_user, from_user=to_user, create_time=int(time.time()), - content = content + content=content, ) return response_xml - + if time.time() - start_time >= timeout: break - + await asyncio.sleep(interval) - if self.msg_id_map.get(message_data["MsgId"], 1) == 3: - + if self.msg_id_map.get(message_data['MsgId'], 1) == 3: # response_xml = xml_template.format( # to_user=from_user, # from_user=to_user, # create_time=int(time.time()), # content = "请求失效:暂不支持公众号超过15秒的请求,如有需求,请联系 LangBot 团队。" # ) - print("请求失效:暂不支持公众号超过15秒的请求,如有需求,请联系 LangBot 团队。") + print( + '请求失效:暂不支持公众号超过15秒的请求,如有需求,请联系 LangBot 团队。' + ) return '' - except Exception as e: + except Exception: traceback.print_exc() - async def get_message(self, xml_msg: str): - root = ET.fromstring(xml_msg) message_data = { - "ToUserName": root.find("ToUserName").text, - "FromUserName": root.find("FromUserName").text, - "CreateTime": int(root.find("CreateTime").text), - "MsgType": root.find("MsgType").text, - "Content": root.find("Content").text if root.find("Content") is not None else None, - "MsgId": int(root.find("MsgId").text) if root.find("MsgId") is not None else None, + 'ToUserName': root.find('ToUserName').text, + 'FromUserName': root.find('FromUserName').text, + 'CreateTime': int(root.find('CreateTime').text), + 'MsgType': root.find('MsgType').text, + 'Content': root.find('Content').text + if root.find('Content') is not None + else None, + 'MsgId': int(root.find('MsgId').text) + if root.find('MsgId') is not None + else None, } return message_data - async def run_task(self, host: str, port: int, *args, **kwargs): """ 启动 Quart 应用。 """ await self.app.run_task(host=host, port=port, *args, **kwargs) - def on_message(self, msg_type: str): """ 注册消息类型处理器。 """ + def decorator(func: Callable[[OAEvent], None]): if msg_type not in self._message_handlers: self._message_handlers[msg_type] = [] self._message_handlers[msg_type].append(func) return func + return decorator async def _handle_message(self, event: OAEvent): @@ -170,14 +172,19 @@ class OAClient(): for handler in self._message_handlers[msg_type]: await handler(event) - async def set_message(self,msg_id:int,content:str): + async def set_message(self, msg_id: int, content: str): self.generated_content[msg_id] = content - -class OAClientForLongerResponse(): - - def __init__(self,token:str,EncodingAESKey:str,AppID:str,Appsecret:str,LoadingMessage:str): +class OAClientForLongerResponse: + def __init__( + self, + token: str, + EncodingAESKey: str, + AppID: str, + Appsecret: str, + LoadingMessage: str, + ): self.token = token self.aes = EncodingAESKey self.appid = AppID @@ -185,9 +192,14 @@ class OAClientForLongerResponse(): self.base_url = 'https://api.weixin.qq.com' self.access_token = '' self.app = Quart(__name__) - self.app.add_url_rule('/callback/command', 'handle_callback', self.handle_callback_request, methods=['GET', 'POST']) + self.app.add_url_rule( + '/callback/command', + 'handle_callback', + self.handle_callback_request, + methods=['GET', 'POST'], + ) self._message_handlers = { - "example":[], + 'example': [], } self.access_token_expiry_time = None self.loading_message = LoadingMessage @@ -196,50 +208,55 @@ class OAClientForLongerResponse(): async def handle_callback_request(self): try: - start_time = time.time() - signature = request.args.get("signature", "") - timestamp = request.args.get("timestamp", "") - nonce = request.args.get("nonce", "") - echostr = request.args.get("echostr", "") - msg_signature = request.args.get("msg_signature", "") + signature = request.args.get('signature', '') + timestamp = request.args.get('timestamp', '') + nonce = request.args.get('nonce', '') + echostr = request.args.get('echostr', '') + msg_signature = request.args.get('msg_signature', '') if msg_signature is None: - raise Exception("msg_signature不在请求体中") + raise Exception('msg_signature不在请求体中') if request.method == 'GET': - check_str = "".join(sorted([self.token, timestamp, nonce])) - check_signature = hashlib.sha1(check_str.encode("utf-8")).hexdigest() - return echostr if check_signature == signature else "拒绝请求" + check_str = ''.join(sorted([self.token, timestamp, nonce])) + check_signature = hashlib.sha1(check_str.encode('utf-8')).hexdigest() + return echostr if check_signature == signature else '拒绝请求' - elif request.method == "POST": + elif request.method == 'POST': encryt_msg = await request.data wxcpt = WXBizMsgCrypt(self.token, self.aes, self.appid) - ret, xml_msg = wxcpt.DecryptMsg(encryt_msg, msg_signature, timestamp, nonce) + ret, xml_msg = wxcpt.DecryptMsg( + encryt_msg, msg_signature, timestamp, nonce + ) xml_msg = xml_msg.decode('utf-8') if ret != 0: - raise Exception("消息解密失败") + raise Exception('消息解密失败') # 解析 XML root = ET.fromstring(xml_msg) - from_user = root.find("FromUserName").text - to_user = root.find("ToUserName").text - - - if self.msg_queue.get(from_user) and self.msg_queue[from_user][0]["content"]: + from_user = root.find('FromUserName').text + to_user = root.find('ToUserName').text + if ( + self.msg_queue.get(from_user) + and self.msg_queue[from_user][0]['content'] + ): queue_top = self.msg_queue[from_user].pop(0) - queue_content = queue_top["content"] + queue_content = queue_top['content'] # 弹出用户消息 - if self.user_msg_queue.get(from_user) and self.user_msg_queue[from_user]: + if ( + self.user_msg_queue.get(from_user) + and self.user_msg_queue[from_user] + ): self.user_msg_queue[from_user].pop(0) response_xml = xml_template.format( to_user=from_user, from_user=to_user, create_time=int(time.time()), - content=queue_content + content=queue_content, ) return response_xml @@ -248,65 +265,67 @@ class OAClientForLongerResponse(): to_user=from_user, from_user=to_user, create_time=int(time.time()), - content=self.loading_message + content=self.loading_message, ) - - if self.user_msg_queue.get(from_user) and self.user_msg_queue[from_user][0]["content"]: + + if ( + self.user_msg_queue.get(from_user) + and self.user_msg_queue[from_user][0]['content'] + ): return response_xml else: message_data = await self.get_message(xml_msg) - + if message_data: event = OAEvent.from_payload(message_data) if event: - self.user_msg_queue.setdefault(from_user,[]).append( + self.user_msg_queue.setdefault(from_user, []).append( { - "content":event.message, + 'content': event.message, } ) await self._handle_message(event) return response_xml - except Exception as e: + except Exception: traceback.print_exc() - - async def get_message(self, xml_msg: str): - root = ET.fromstring(xml_msg) message_data = { - "ToUserName": root.find("ToUserName").text, - "FromUserName": root.find("FromUserName").text, - "CreateTime": int(root.find("CreateTime").text), - "MsgType": root.find("MsgType").text, - "Content": root.find("Content").text if root.find("Content") is not None else None, - "MsgId": int(root.find("MsgId").text) if root.find("MsgId") is not None else None, + 'ToUserName': root.find('ToUserName').text, + 'FromUserName': root.find('FromUserName').text, + 'CreateTime': int(root.find('CreateTime').text), + 'MsgType': root.find('MsgType').text, + 'Content': root.find('Content').text + if root.find('Content') is not None + else None, + 'MsgId': int(root.find('MsgId').text) + if root.find('MsgId') is not None + else None, } return message_data - - async def run_task(self, host: str, port: int, *args, **kwargs): """ 启动 Quart 应用。 """ await self.app.run_task(host=host, port=port, *args, **kwargs) - - def on_message(self, msg_type: str): """ 注册消息类型处理器。 """ + def decorator(func: Callable[[OAEvent], None]): if msg_type not in self._message_handlers: self._message_handlers[msg_type] = [] self._message_handlers[msg_type].append(func) return func + return decorator async def _handle_message(self, event: OAEvent): @@ -319,22 +338,13 @@ class OAClientForLongerResponse(): for handler in self._message_handlers[msg_type]: await handler(event) - async def set_message(self,from_user:int,message_id:int,content:str): - if from_user not in self.msg_queue: + async def set_message(self, from_user: int, message_id: int, content: str): + if from_user not in self.msg_queue: self.msg_queue[from_user] = [] - + self.msg_queue[from_user].append( { - "msg_id":message_id, - "content":content, + 'msg_id': message_id, + 'content': content, } ) - - - - - - - - - diff --git a/libs/official_account_api/oaevent.py b/libs/official_account_api/oaevent.py index ebbccd7e..d4de3914 100644 --- a/libs/official_account_api/oaevent.py +++ b/libs/official_account_api/oaevent.py @@ -9,7 +9,7 @@ class OAEvent(dict): """ @staticmethod - def from_payload(payload: Dict[str, Any]) -> Optional["OAEvent"]: + def from_payload(payload: Dict[str, Any]) -> Optional['OAEvent']: """ 从微信公众号事件数据构造 `WecomEvent` 对象。 @@ -34,14 +34,14 @@ class OAEvent(dict): Returns: str: 事件类型。 """ - return self.get("MsgType", "") - + return self.get('MsgType', '') + @property def picurl(self) -> str: """ 图片链接 """ - return self.get("PicUrl","") + return self.get('PicUrl', '') @property def detail_type(self) -> str: @@ -53,8 +53,8 @@ class OAEvent(dict): Returns: str: 事件详细类型。 """ - if self.type == "event": - return self.get("Event", "") + if self.type == 'event': + return self.get('Event', '') return self.type @property @@ -65,15 +65,14 @@ class OAEvent(dict): Returns: str: 事件名。 """ - return f"{self.type}.{self.detail_type}" + return f'{self.type}.{self.detail_type}' @property def user_id(self) -> Optional[str]: """ 发送方账号 """ - return self.get("FromUserName") - + return self.get('FromUserName') @property def receiver_id(self) -> Optional[str]: @@ -83,7 +82,7 @@ class OAEvent(dict): Returns: Optional[str]: 接收者 ID。 """ - return self.get("ToUserName") + return self.get('ToUserName') @property def message_id(self) -> Optional[str]: @@ -93,7 +92,7 @@ class OAEvent(dict): Returns: Optional[str]: 消息 ID。 """ - return self.get("MsgId") + return self.get('MsgId') @property def message(self) -> Optional[str]: @@ -103,7 +102,7 @@ class OAEvent(dict): Returns: Optional[str]: 消息内容。 """ - return self.get("Content") + return self.get('Content') @property def media_id(self) -> Optional[str]: @@ -113,7 +112,7 @@ class OAEvent(dict): Returns: Optional[str]: 媒体文件 ID。 """ - return self.get("MediaId") + return self.get('MediaId') @property def timestamp(self) -> Optional[int]: @@ -123,7 +122,7 @@ class OAEvent(dict): Returns: Optional[int]: 时间戳。 """ - return self.get("CreateTime") + return self.get('CreateTime') @property def event_key(self) -> Optional[str]: @@ -133,7 +132,7 @@ class OAEvent(dict): Returns: Optional[str]: 事件 Key。 """ - return self.get("EventKey") + return self.get('EventKey') def __getattr__(self, key: str) -> Optional[Any]: """ @@ -164,4 +163,4 @@ class OAEvent(dict): Returns: str: 字符串表示。 """ - return f"" + return f'' diff --git a/libs/qq_official_api/api.py b/libs/qq_official_api/api.py index 62f252df..89360881 100644 --- a/libs/qq_official_api/api.py +++ b/libs/qq_official_api/api.py @@ -1,24 +1,16 @@ import time from quart import request -import base64 -import binascii import httpx from quart import Quart -import xml.etree.ElementTree as ET from typing import Callable, Dict, Any -from pkg.platform.types import events as platform_events, message as platform_message -import aiofiles +from pkg.platform.types import events as platform_events from .qqofficialevent import QQOfficialEvent import json -import hmac -import base64 -import hashlib import traceback from cryptography.hazmat.primitives.asymmetric import ed25519 -from .qqofficialevent import QQOfficialEvent + def handle_validation(body: dict, bot_secret: str): - # bot正确的secert是32位的,此处仅为了适配演示demo while len(bot_secret) < 32: bot_secret = bot_secret * 2 @@ -36,29 +28,26 @@ def handle_validation(body: dict, bot_secret: str): signature_hex = signature.hex() - response = { - "plain_token": body['d']['plain_token'], - "signature": signature_hex - } + response = {'plain_token': body['d']['plain_token'], 'signature': signature_hex} return response + class QQOfficialClient: def __init__(self, secret: str, token: str, app_id: str): self.app = Quart(__name__) self.app.add_url_rule( - "/callback/command", - "handle_callback", + '/callback/command', + 'handle_callback', self.handle_callback_request, - methods=["GET", "POST"], + methods=['GET', 'POST'], ) self.secret = secret self.token = token self.app_id = app_id - self._message_handlers = { - } - self.base_url = "https://api.sgroup.qq.com" - self.access_token = "" + self._message_handlers = {} + self.base_url = 'https://api.sgroup.qq.com' + self.access_token = '' self.access_token_expiry_time = None async def check_access_token(self): @@ -66,30 +55,29 @@ class QQOfficialClient: if not self.access_token or await self.is_token_expired(): return False return bool(self.access_token and self.access_token.strip()) - + async def get_access_token(self): """获取access_token""" - url = "https://bots.qq.com/app/getAppAccessToken" + url = 'https://bots.qq.com/app/getAppAccessToken' async with httpx.AsyncClient() as client: params = { - "appId":self.app_id, - "clientSecret":self.secret, + 'appId': self.app_id, + 'clientSecret': self.secret, } headers = { - "content-type":"application/json", + 'content-type': 'application/json', } try: - response = await client.post(url,json=params,headers=headers) + response = await client.post(url, json=params, headers=headers) if response.status_code == 200: response_data = response.json() - access_token = response_data.get("access_token") - expires_in = int(response_data.get("expires_in",7200)) + access_token = response_data.get('access_token') + expires_in = int(response_data.get('expires_in', 7200)) self.access_token_expiry_time = time.time() + expires_in - 60 if access_token: self.access_token = access_token except Exception as e: - raise Exception(f"获取access_token失败: {e}") - + raise Exception(f'获取access_token失败: {e}') async def handle_callback_request(self): """处理回调请求""" @@ -98,27 +86,24 @@ class QQOfficialClient: body = await request.get_data() payload = json.loads(body) - # 验证是否为回调验证请求 - if payload.get("op") == 13: + if payload.get('op') == 13: # 生成签名 response = handle_validation(payload, self.secret) return response - if payload.get("op") == 0: - message_data = await self.get_message(payload) - if message_data: - event = QQOfficialEvent.from_payload(message_data) - await self._handle_message(event) - - return {"code": 0, "message": "success"} + if payload.get('op') == 0: + message_data = await self.get_message(payload) + if message_data: + event = QQOfficialEvent.from_payload(message_data) + await self._handle_message(event) + + return {'code': 0, 'message': 'success'} except Exception as e: traceback.print_exc() - return {"error": str(e)}, 400 - - + return {'error': str(e)}, 400 async def run_task(self, host: str, port: int, *args, **kwargs): """启动 Quart 应用""" @@ -135,133 +120,140 @@ class QQOfficialClient: return decorator - async def _handle_message(self, event:QQOfficialEvent): + async def _handle_message(self, event: QQOfficialEvent): """处理消息事件""" msg_type = event.t if msg_type in self._message_handlers: for handler in self._message_handlers[msg_type]: await handler(event) - - async def get_message(self,msg:dict) -> Dict[str,Any]: + async def get_message(self, msg: dict) -> Dict[str, Any]: """获取消息""" message_data = { - "t": msg.get("t",{}), - "user_openid": msg.get("d",{}).get("author",{}).get("user_openid",{}), - "timestamp": msg.get("d",{}).get("timestamp",{}), - "d_author_id": msg.get("d",{}).get("author",{}).get("id",{}), - "content": msg.get("d",{}).get("content",{}), - "d_id": msg.get("d",{}).get("id",{}), - "id": msg.get("id",{}), - "channel_id": msg.get("d",{}).get("channel_id",{}), - "username": msg.get("d",{}).get("author",{}).get("username",{}), - "guild_id": msg.get("d",{}).get("guild_id",{}), - "member_openid": msg.get("d",{}).get("author",{}).get("openid",{}), - "group_openid": msg.get("d",{}).get("group_openid",{}) + 't': msg.get('t', {}), + 'user_openid': msg.get('d', {}).get('author', {}).get('user_openid', {}), + 'timestamp': msg.get('d', {}).get('timestamp', {}), + 'd_author_id': msg.get('d', {}).get('author', {}).get('id', {}), + 'content': msg.get('d', {}).get('content', {}), + 'd_id': msg.get('d', {}).get('id', {}), + 'id': msg.get('id', {}), + 'channel_id': msg.get('d', {}).get('channel_id', {}), + 'username': msg.get('d', {}).get('author', {}).get('username', {}), + 'guild_id': msg.get('d', {}).get('guild_id', {}), + 'member_openid': msg.get('d', {}).get('author', {}).get('openid', {}), + 'group_openid': msg.get('d', {}).get('group_openid', {}), } - attachments = msg.get("d", {}).get("attachments", []) - image_attachments = [attachment['url'] for attachment in attachments if await self.is_image(attachment)] - image_attachments_type = [attachment['content_type'] for attachment in attachments if await self.is_image(attachment)] + attachments = msg.get('d', {}).get('attachments', []) + image_attachments = [ + attachment['url'] + for attachment in attachments + if await self.is_image(attachment) + ] + image_attachments_type = [ + attachment['content_type'] + for attachment in attachments + if await self.is_image(attachment) + ] if image_attachments: - message_data["image_attachments"] = image_attachments[0] - message_data["content_type"] = image_attachments_type[0] + message_data['image_attachments'] = image_attachments[0] + message_data['content_type'] = image_attachments_type[0] else: - - message_data["image_attachments"] = None - - return message_data - + message_data['image_attachments'] = None - async def is_image(self,attachment:dict) -> bool: + return message_data + + async def is_image(self, attachment: dict) -> bool: """判断是否为图片附件""" - content_type = attachment.get("content_type","") - return content_type.startswith("image/") - - - async def send_private_text_msg(self,user_openid:str,content:str,msg_id:str): + content_type = attachment.get('content_type', '') + return content_type.startswith('image/') + + async def send_private_text_msg(self, user_openid: str, content: str, msg_id: str): """发送私聊消息""" if not await self.check_access_token(): - await self.get_access_token() + await self.get_access_token() - url = self.base_url + "/v2/users/" + user_openid + "/messages" + url = self.base_url + '/v2/users/' + user_openid + '/messages' async with httpx.AsyncClient() as client: headers = { - "Authorization": f"QQBot {self.access_token}", - "Content-Type": "application/json", + 'Authorization': f'QQBot {self.access_token}', + 'Content-Type': 'application/json', } data = { - "content": content, - "msg_type": 0, - "msg_id": msg_id, + 'content': content, + 'msg_type': 0, + 'msg_id': msg_id, } - response = await client.post(url,headers=headers,json=data) + response = await client.post(url, headers=headers, json=data) if response.status_code == 200: return else: raise ValueError(response) - - async def send_group_text_msg(self,group_openid:str,content:str,msg_id:str): + async def send_group_text_msg(self, group_openid: str, content: str, msg_id: str): """发送群聊消息""" if not await self.check_access_token(): await self.get_access_token() - url = self.base_url + "/v2/groups/" + group_openid + "/messages" + url = self.base_url + '/v2/groups/' + group_openid + '/messages' async with httpx.AsyncClient() as client: headers = { - "Authorization": f"QQBot {self.access_token}", - "Content-Type": "application/json", + 'Authorization': f'QQBot {self.access_token}', + 'Content-Type': 'application/json', } data = { - "content": content, - "msg_type": 0, - "msg_id": msg_id, + 'content': content, + 'msg_type': 0, + 'msg_id': msg_id, } - response = await client.post(url,headers=headers,json=data) + response = await client.post(url, headers=headers, json=data) if response.status_code == 200: return else: raise Exception(response.read().decode()) - async def send_channle_group_text_msg(self,channel_id:str,content:str,msg_id:str): + async def send_channle_group_text_msg( + self, channel_id: str, content: str, msg_id: str + ): """发送频道群聊消息""" if not await self.check_access_token(): - await self.get_access_token() + await self.get_access_token() - url = self.base_url + "/channels/" + channel_id + "/messages" + url = self.base_url + '/channels/' + channel_id + '/messages' async with httpx.AsyncClient() as client: headers = { - "Authorization": f"QQBot {self.access_token}", - "Content-Type": "application/json", + 'Authorization': f'QQBot {self.access_token}', + 'Content-Type': 'application/json', } params = { - "content": content, - "msg_type": 0, - "msg_id": msg_id, + 'content': content, + 'msg_type': 0, + 'msg_id': msg_id, } - response = await client.post(url,headers=headers,json=params) + response = await client.post(url, headers=headers, json=params) if response.status_code == 200: return True else: raise Exception(response) - async def send_channle_private_text_msg(self,guild_id:str,content:str,msg_id:str): + async def send_channle_private_text_msg( + self, guild_id: str, content: str, msg_id: str + ): """发送频道私聊消息""" if not await self.check_access_token(): - await self.get_access_token() + await self.get_access_token() - url = self.base_url + "/dms/" + guild_id + "/messages" + url = self.base_url + '/dms/' + guild_id + '/messages' async with httpx.AsyncClient() as client: headers = { - "Authorization": f"QQBot {self.access_token}", - "Content-Type": "application/json", + 'Authorization': f'QQBot {self.access_token}', + 'Content-Type': 'application/json', } params = { - "content": content, - "msg_type": 0, - "msg_id": msg_id, + 'content': content, + 'msg_type': 0, + 'msg_id': msg_id, } - response = await client.post(url,headers=headers,json=params) + response = await client.post(url, headers=headers, json=params) if response.status_code == 200: return True else: diff --git a/libs/qq_official_api/qqofficialevent.py b/libs/qq_official_api/qqofficialevent.py index 41e842f1..7c29b9d8 100644 --- a/libs/qq_official_api/qqofficialevent.py +++ b/libs/qq_official_api/qqofficialevent.py @@ -1,114 +1,112 @@ from typing import Dict, Any, Optional + class QQOfficialEvent(dict): @staticmethod - def from_payload(payload: Dict[str, Any]) -> Optional["QQOfficialEvent"]: + def from_payload(payload: Dict[str, Any]) -> Optional['QQOfficialEvent']: try: event = QQOfficialEvent(payload) return event except KeyError: return None - @property def t(self) -> str: """ 事件类型 """ - return self.get("t", "") - + return self.get('t', '') + @property def user_openid(self) -> str: """ 用户openid """ - return self.get("user_openid",{}) - + return self.get('user_openid', {}) + @property def timestamp(self) -> str: """ 时间戳 """ - return self.get("timestamp",{}) - - + return self.get('timestamp', {}) + @property def d_author_id(self) -> str: """ 作者id """ - return self.get("id",{}) - + return self.get('id', {}) + @property def content(self) -> str: """ 内容 """ - return self.get("content",'') - + return self.get('content', '') + @property def d_id(self) -> str: """ d_id """ - return self.get("d_id",{}) - + return self.get('d_id', {}) + @property def id(self) -> str: """ 消息id,msg_id """ - return self.get("id",{}) - + return self.get('id', {}) + @property def channel_id(self) -> str: """ 频道id """ - return self.get("channel_id",{}) - + return self.get('channel_id', {}) + @property def username(self) -> str: """ 用户名 """ - return self.get("username",{}) - + return self.get('username', {}) + @property def guild_id(self) -> str: """ 频道id """ - return self.get("guild_id",{}) - + return self.get('guild_id', {}) + @property def member_openid(self) -> str: """ 成员openid """ - return self.get("openid",{}) - + return self.get('openid', {}) + @property def attachments(self) -> str: """ 附件url """ - url = self.get("image_attachments", "") - if url and not url.startswith("https://"): - url = "https://" + url + url = self.get('image_attachments', '') + if url and not url.startswith('https://'): + url = 'https://' + url return url - + @property def group_openid(self) -> str: """ 群组id """ - return self.get("group_openid",{}) - + return self.get('group_openid', {}) + @property def content_type(self) -> str: """ 文件类型 """ - return self.get("content_type","") - + return self.get('content_type', '') diff --git a/libs/wecom_api/WXBizMsgCrypt3.py b/libs/wecom_api/WXBizMsgCrypt3.py index 0123c7d1..a9a7bc89 100644 --- a/libs/wecom_api/WXBizMsgCrypt3.py +++ b/libs/wecom_api/WXBizMsgCrypt3.py @@ -1,10 +1,11 @@ #!/usr/bin/env python # -*- encoding:utf-8 -*- -""" 对企业微信发送给企业后台的消息加解密示例代码. +"""对企业微信发送给企业后台的消息加解密示例代码. @copyright: Copyright (c) 1998-2014 Tencent Inc. """ + # ------------------------------------------------------------------------ import logging import base64 @@ -49,7 +50,7 @@ class SHA1: sortlist = [token, timestamp, nonce, encrypt] sortlist.sort() sha = hashlib.sha1() - sha.update("".join(sortlist).encode()) + sha.update(''.join(sortlist).encode()) return ierror.WXBizMsgCrypt_OK, sha.hexdigest() except Exception as e: logger = logging.getLogger() @@ -75,7 +76,7 @@ class XMLParse: """ try: xml_tree = ET.fromstring(xmltext) - encrypt = xml_tree.find("Encrypt") + encrypt = xml_tree.find('Encrypt') return ierror.WXBizMsgCrypt_OK, encrypt.text except Exception as e: logger = logging.getLogger() @@ -100,13 +101,13 @@ class XMLParse: return resp_xml -class PKCS7Encoder(): +class PKCS7Encoder: """提供基于PKCS7算法的加解密接口""" block_size = 32 def encode(self, text): - """ 对需要加密的明文进行填充补位 + """对需要加密的明文进行填充补位 @param text: 需要进行填充补位操作的明文 @return: 补齐明文字符串 """ @@ -134,7 +135,6 @@ class Prpcrypt(object): """提供接收和推送给企业微信消息的加解密接口""" def __init__(self, key): - # self.key = base64.b64decode(key+"=") self.key = key # 设置加解密模式为AES的CBC模式 @@ -147,7 +147,12 @@ class Prpcrypt(object): """ # 16位随机字符串添加到明文开头 text = text.encode() - text = self.get_random_str() + struct.pack("I", socket.htonl(len(text))) + text + receiveid.encode() + text = ( + self.get_random_str() + + struct.pack('I', socket.htonl(len(text))) + + text + + receiveid.encode() + ) # 使用自定义的填充方式对明文进行补位填充 pkcs7 = PKCS7Encoder() @@ -183,9 +188,9 @@ class Prpcrypt(object): # plain_text = pkcs7.encode(plain_text) # 去除16位随机字符串 content = plain_text[16:-pad] - xml_len = socket.ntohl(struct.unpack("I", content[: 4])[0]) - xml_content = content[4: xml_len + 4] - from_receiveid = content[xml_len + 4:] + xml_len = socket.ntohl(struct.unpack('I', content[:4])[0]) + xml_content = content[4 : xml_len + 4] + from_receiveid = content[xml_len + 4 :] except Exception as e: logger = logging.getLogger() logger.error(e) @@ -196,7 +201,7 @@ class Prpcrypt(object): return 0, xml_content def get_random_str(self): - """ 随机生成16位字符串 + """随机生成16位字符串 @return: 16位字符串 """ return str(random.randint(1000000000000000, 9999999999999999)).encode() @@ -206,10 +211,10 @@ class WXBizMsgCrypt(object): # 构造函数 def __init__(self, sToken, sEncodingAESKey, sReceiveId): try: - self.key = base64.b64decode(sEncodingAESKey + "=") + self.key = base64.b64decode(sEncodingAESKey + '=') assert len(self.key) == 32 - except: - throw_exception("[error]: EncodingAESKey unvalid !", FormatException) + except Exception: + throw_exception('[error]: EncodingAESKey unvalid !', FormatException) # return ierror.WXBizMsgCrypt_IllegalAesKey,None self.m_sToken = sToken self.m_sReceiveId = sReceiveId diff --git a/libs/wecom_api/api.py b/libs/wecom_api/api.py index 61458f8e..8993885b 100644 --- a/libs/wecom_api/api.py +++ b/libs/wecom_api/api.py @@ -7,15 +7,22 @@ from quart import Quart import xml.etree.ElementTree as ET from typing import Callable, Dict, Any from .wecomevent import WecomEvent -from pkg.platform.types import events as platform_events, message as platform_message +from pkg.platform.types import message as platform_message import aiofiles -class WecomClient(): - def __init__(self,corpid:str,secret:str,token:str,EncodingAESKey:str,contacts_secret:str): +class WecomClient: + def __init__( + self, + corpid: str, + secret: str, + token: str, + EncodingAESKey: str, + contacts_secret: str, + ): self.corpid = corpid self.secret = secret - self.access_token_for_contacts ='' + self.access_token_for_contacts = '' self.token = token self.aes = EncodingAESKey self.base_url = 'https://qyapi.weixin.qq.com/cgi-bin' @@ -23,19 +30,26 @@ class WecomClient(): self.secret_for_contacts = contacts_secret self.app = Quart(__name__) self.wxcpt = WXBizMsgCrypt(self.token, self.aes, self.corpid) - self.app.add_url_rule('/callback/command', 'handle_callback', self.handle_callback_request, methods=['GET', 'POST']) + self.app.add_url_rule( + '/callback/command', + 'handle_callback', + self.handle_callback_request, + methods=['GET', 'POST'], + ) self._message_handlers = { - "example":[], + 'example': [], } - #access——token操作 + # access——token操作 async def check_access_token(self): return bool(self.access_token and self.access_token.strip()) async def check_access_token_for_contacts(self): - return bool(self.access_token_for_contacts and self.access_token_for_contacts.strip()) + return bool( + self.access_token_for_contacts and self.access_token_for_contacts.strip() + ) - async def get_access_token(self,secret): + async def get_access_token(self, secret): url = f'https://qyapi.weixin.qq.com/cgi-bin/gettoken?corpid={self.corpid}&corpsecret={secret}' async with httpx.AsyncClient() as client: response = await client.get(url) @@ -43,146 +57,163 @@ class WecomClient(): if 'access_token' in data: return data['access_token'] else: - raise Exception(f"未获取access token: {data}") + raise Exception(f'未获取access token: {data}') async def get_users(self): if not self.check_access_token_for_contacts(): - self.access_token_for_contacts = await self.get_access_token(self.secret_for_contacts) + self.access_token_for_contacts = await self.get_access_token( + self.secret_for_contacts + ) - url = self.base_url+'/user/list_id?access_token='+self.access_token_for_contacts + url = ( + self.base_url + + '/user/list_id?access_token=' + + self.access_token_for_contacts + ) async with httpx.AsyncClient() as client: params = { - "cursor":"", - "limit":10000, + 'cursor': '', + 'limit': 10000, } - response = await client.post(url,json=params) + response = await client.post(url, json=params) data = response.json() if data['errcode'] == 0: dept_users = data['dept_user'] userid = [] for user in dept_users: - userid.append(user["userid"]) + userid.append(user['userid']) return userid else: - raise Exception("未获取用户") - - async def send_to_all(self,content:str,agent_id:int): - if not self.check_access_token_for_contacts(): - self.access_token_for_contacts = await self.get_access_token(self.secret_for_contacts) + raise Exception('未获取用户') - url = self.base_url+'/message/send?access_token='+self.access_token_for_contacts + async def send_to_all(self, content: str, agent_id: int): + if not self.check_access_token_for_contacts(): + self.access_token_for_contacts = await self.get_access_token( + self.secret_for_contacts + ) + + url = ( + self.base_url + + '/message/send?access_token=' + + self.access_token_for_contacts + ) user_ids = await self.get_users() - user_ids_string = "|".join(user_ids) + user_ids_string = '|'.join(user_ids) async with httpx.AsyncClient() as client: params = { - "touser" : user_ids_string, - "msgtype" : "text", - "agentid" : agent_id, - "text" : { - "content" : content, - }, - "safe":0, - "enable_id_trans": 0, - "enable_duplicate_check": 0, - "duplicate_check_interval": 1800 + 'touser': user_ids_string, + 'msgtype': 'text', + 'agentid': agent_id, + 'text': { + 'content': content, + }, + 'safe': 0, + 'enable_id_trans': 0, + 'enable_duplicate_check': 0, + 'duplicate_check_interval': 1800, } - response = await client.post(url,json=params) + response = await client.post(url, json=params) data = response.json() if data['errcode'] != 0: - raise Exception("Failed to send message: "+str(data)) + raise Exception('Failed to send message: ' + str(data)) - async def send_image(self,user_id:str,agent_id:int,media_id:str): + async def send_image(self, user_id: str, agent_id: int, media_id: str): if not await self.check_access_token(): self.access_token = await self.get_access_token(self.secret) - url = self.base_url+'/media/upload?access_token='+self.access_token + url = self.base_url + '/media/upload?access_token=' + self.access_token async with httpx.AsyncClient() as client: params = { - "touser" : user_id, - "toparty" : "", - "totag":"", - "agentid" : agent_id, - "msgtype" : "image", - "image" : { - "media_id" : media_id, + 'touser': user_id, + 'toparty': '', + 'totag': '', + 'agentid': agent_id, + 'msgtype': 'image', + 'image': { + 'media_id': media_id, }, - "safe":0, - "enable_id_trans": 0, - "enable_duplicate_check": 0, - "duplicate_check_interval": 1800 + 'safe': 0, + 'enable_id_trans': 0, + 'enable_duplicate_check': 0, + 'duplicate_check_interval': 1800, } try: - response = await client.post(url,json=params) + response = await client.post(url, json=params) data = response.json() except Exception as e: - raise Exception("Failed to send image: "+str(e)) + raise Exception('Failed to send image: ' + str(e)) # 企业微信错误码40014和42001,代表accesstoken问题 if data['errcode'] == 40014 or data['errcode'] == 42001: self.access_token = await self.get_access_token(self.secret) - return await self.send_image(user_id,agent_id,media_id) + return await self.send_image(user_id, agent_id, media_id) if data['errcode'] != 0: - raise Exception("Failed to send image: "+str(data)) - - async def send_private_msg(self,user_id:str, agent_id:int,content:str): + raise Exception('Failed to send image: ' + str(data)) + + async def send_private_msg(self, user_id: str, agent_id: int, content: str): if not await self.check_access_token(): self.access_token = await self.get_access_token(self.secret) - url = self.base_url+'/message/send?access_token='+self.access_token + url = self.base_url + '/message/send?access_token=' + self.access_token async with httpx.AsyncClient() as client: - params={ - "touser" : user_id, - "msgtype" : "text", - "agentid" : agent_id, - "text" : { - "content" : content, + params = { + 'touser': user_id, + 'msgtype': 'text', + 'agentid': agent_id, + 'text': { + 'content': content, }, - "safe":0, - "enable_id_trans": 0, - "enable_duplicate_check": 0, - "duplicate_check_interval": 1800 + 'safe': 0, + 'enable_id_trans': 0, + 'enable_duplicate_check': 0, + 'duplicate_check_interval': 1800, } - response = await client.post(url,json=params) + response = await client.post(url, json=params) data = response.json() if data['errcode'] == 40014 or data['errcode'] == 42001: self.access_token = await self.get_access_token(self.secret) - return await self.send_private_msg(user_id,agent_id,content) + return await self.send_private_msg(user_id, agent_id, content) if data['errcode'] != 0: - raise Exception("Failed to send message: "+str(data)) + raise Exception('Failed to send message: ' + str(data)) async def handle_callback_request(self): """ 处理回调请求,包括 GET 验证和 POST 消息接收。 """ try: + msg_signature = request.args.get('msg_signature') + timestamp = request.args.get('timestamp') + nonce = request.args.get('nonce') - msg_signature = request.args.get("msg_signature") - timestamp = request.args.get("timestamp") - nonce = request.args.get("nonce") - - if request.method == "GET": - echostr = request.args.get("echostr") - ret, reply_echo_str = self.wxcpt.VerifyURL(msg_signature, timestamp, nonce, echostr) + if request.method == 'GET': + echostr = request.args.get('echostr') + ret, reply_echo_str = self.wxcpt.VerifyURL( + msg_signature, timestamp, nonce, echostr + ) if ret != 0: - raise Exception(f"验证失败,错误码: {ret}") + raise Exception(f'验证失败,错误码: {ret}') return reply_echo_str - elif request.method == "POST": + elif request.method == 'POST': encrypt_msg = await request.data - ret, xml_msg = self.wxcpt.DecryptMsg(encrypt_msg, msg_signature, timestamp, nonce) + ret, xml_msg = self.wxcpt.DecryptMsg( + encrypt_msg, msg_signature, timestamp, nonce + ) if ret != 0: - raise Exception(f"消息解密失败,错误码: {ret}") + raise Exception(f'消息解密失败,错误码: {ret}') # 解析消息并处理 message_data = await self.get_message(xml_msg) if message_data: - event = WecomEvent.from_payload(message_data) # 转换为 WecomEvent 对象 + event = WecomEvent.from_payload( + message_data + ) # 转换为 WecomEvent 对象 if event: await self._handle_message(event) - return "success" + return 'success' except Exception as e: - return f"Error processing request: {str(e)}", 400 + return f'Error processing request: {str(e)}', 400 async def run_task(self, host: str, port: int, *args, **kwargs): """ @@ -194,11 +225,13 @@ class WecomClient(): """ 注册消息类型处理器。 """ + def decorator(func: Callable[[WecomEvent], None]): if msg_type not in self._message_handlers: self._message_handlers[msg_type] = [] self._message_handlers[msg_type].append(func) return func + return decorator async def _handle_message(self, event: WecomEvent): @@ -216,38 +249,47 @@ class WecomClient(): """ root = ET.fromstring(xml_msg) message_data = { - "ToUserName": root.find("ToUserName").text, - "FromUserName": root.find("FromUserName").text, - "CreateTime": int(root.find("CreateTime").text), - "MsgType": root.find("MsgType").text, - "Content": root.find("Content").text if root.find("Content") is not None else None, - "MsgId": int(root.find("MsgId").text) if root.find("MsgId") is not None else None, - "AgentID": int(root.find("AgentID").text) if root.find("AgentID") is not None else None, + 'ToUserName': root.find('ToUserName').text, + 'FromUserName': root.find('FromUserName').text, + 'CreateTime': int(root.find('CreateTime').text), + 'MsgType': root.find('MsgType').text, + 'Content': root.find('Content').text + if root.find('Content') is not None + else None, + 'MsgId': int(root.find('MsgId').text) + if root.find('MsgId') is not None + else None, + 'AgentID': int(root.find('AgentID').text) + if root.find('AgentID') is not None + else None, } - if message_data["MsgType"] == "image": - message_data["MediaId"] = root.find("MediaId").text if root.find("MediaId") is not None else None - message_data["PicUrl"] = root.find("PicUrl").text if root.find("PicUrl") is not None else None - + if message_data['MsgType'] == 'image': + message_data['MediaId'] = ( + root.find('MediaId').text if root.find('MediaId') is not None else None + ) + message_data['PicUrl'] = ( + root.find('PicUrl').text if root.find('PicUrl') is not None else None + ) + return message_data - + @staticmethod async def get_image_type(image_bytes: bytes) -> str: """ 通过图片的magic numbers判断图片类型 """ magic_numbers = { - b'\xFF\xD8\xFF': 'jpg', - b'\x89\x50\x4E\x47': 'png', + b'\xff\xd8\xff': 'jpg', + b'\x89\x50\x4e\x47': 'png', b'\x47\x49\x46': 'gif', - b'\x42\x4D': 'bmp', - b'\x00\x00\x01\x00': 'ico' + b'\x42\x4d': 'bmp', + b'\x00\x00\x01\x00': 'ico', } - + for magic, ext in magic_numbers.items(): if image_bytes.startswith(magic): return ext return 'jpg' # 默认返回jpg - async def upload_to_work(self, image: platform_message.Image): """ @@ -256,9 +298,14 @@ class WecomClient(): if not await self.check_access_token(): self.access_token = await self.get_access_token(self.secret) - url = self.base_url + '/media/upload?access_token=' + self.access_token + '&type=file' + url = ( + self.base_url + + '/media/upload?access_token=' + + self.access_token + + '&type=file' + ) file_bytes = None - file_name = "uploaded_file.txt" + file_name = 'uploaded_file.txt' # 获取文件的二进制数据 if image.path: @@ -277,20 +324,22 @@ class WecomClient(): padded_base64 = base64_data + '=' * padding file_bytes = base64.b64decode(padded_base64) except binascii.Error as e: - raise ValueError(f"Invalid base64 string: {str(e)}") + raise ValueError(f'Invalid base64 string: {str(e)}') else: - raise ValueError("image对象出错") + raise ValueError('image对象出错') # 设置 multipart/form-data 格式的文件 - boundary = "-------------------------acebdf13572468" - headers = { - 'Content-Type': f'multipart/form-data; boundary={boundary}' - } + boundary = '-------------------------acebdf13572468' + headers = {'Content-Type': f'multipart/form-data; boundary={boundary}'} body = ( - f"--{boundary}\r\n" - f"Content-Disposition: form-data; name=\"media\"; filename=\"{file_name}\"; filelength={len(file_bytes)}\r\n" - f"Content-Type: application/octet-stream\r\n\r\n" - ).encode('utf-8') + file_bytes + f"\r\n--{boundary}--\r\n".encode('utf-8') + ( + f'--{boundary}\r\n' + f'Content-Disposition: form-data; name="media"; filename="{file_name}"; filelength={len(file_bytes)}\r\n' + f'Content-Type: application/octet-stream\r\n\r\n' + ).encode('utf-8') + + file_bytes + + f'\r\n--{boundary}--\r\n'.encode('utf-8') + ) # 上传文件 async with httpx.AsyncClient() as client: @@ -300,19 +349,18 @@ class WecomClient(): self.access_token = await self.get_access_token(self.secret) media_id = await self.upload_to_work(image) if data.get('errcode', 0) != 0: - raise Exception("failed to upload file") + raise Exception('failed to upload file') media_id = data.get('media_id') return media_id - async def download_image_to_bytes(self,url:str) -> bytes: + async def download_image_to_bytes(self, url: str) -> bytes: async with httpx.AsyncClient() as client: response = await client.get(url) response.raise_for_status() return response.content - #进行media_id的获取 + # 进行media_id的获取 async def get_media_id(self, image: platform_message.Image): - media_id = await self.upload_to_work(image=image) return media_id diff --git a/libs/wecom_api/ierror.py b/libs/wecom_api/ierror.py index 8985b886..6c7ca122 100644 --- a/libs/wecom_api/ierror.py +++ b/libs/wecom_api/ierror.py @@ -4,7 +4,7 @@ # Author: jonyqin # Created Time: Thu 11 Sep 2014 01:53:58 PM CST # File Name: ierror.py -# Description:定义错误码含义 +# Description:定义错误码含义 ######################################################################### WXBizMsgCrypt_OK = 0 WXBizMsgCrypt_ValidateSignature_Error = -40001 @@ -17,4 +17,4 @@ WXBizMsgCrypt_DecryptAES_Error = -40007 WXBizMsgCrypt_IllegalBuffer = -40008 WXBizMsgCrypt_EncodeBase64_Error = -40009 WXBizMsgCrypt_DecodeBase64_Error = -40010 -WXBizMsgCrypt_GenReturnXml_Error = -40011 \ No newline at end of file +WXBizMsgCrypt_GenReturnXml_Error = -40011 diff --git a/libs/wecom_api/wecomevent.py b/libs/wecom_api/wecomevent.py index 3606cdf5..a0c2c7da 100644 --- a/libs/wecom_api/wecomevent.py +++ b/libs/wecom_api/wecomevent.py @@ -9,7 +9,7 @@ class WecomEvent(dict): """ @staticmethod - def from_payload(payload: Dict[str, Any]) -> Optional["WecomEvent"]: + def from_payload(payload: Dict[str, Any]) -> Optional['WecomEvent']: """ 从企业微信事件数据构造 `WecomEvent` 对象。 @@ -34,14 +34,14 @@ class WecomEvent(dict): Returns: str: 事件类型。 """ - return self.get("MsgType", "") - + return self.get('MsgType', '') + @property def picurl(self) -> str: """ 图片链接 """ - return self.get("PicUrl") + return self.get('PicUrl') @property def detail_type(self) -> str: @@ -53,8 +53,8 @@ class WecomEvent(dict): Returns: str: 事件详细类型。 """ - if self.type == "event": - return self.get("Event", "") + if self.type == 'event': + return self.get('Event', '') return self.type @property @@ -65,7 +65,7 @@ class WecomEvent(dict): Returns: str: 事件名。 """ - return f"{self.type}.{self.detail_type}" + return f'{self.type}.{self.detail_type}' @property def user_id(self) -> Optional[str]: @@ -75,8 +75,8 @@ class WecomEvent(dict): Returns: Optional[str]: 用户 ID。 """ - return self.get("FromUserName") - + return self.get('FromUserName') + @property def agent_id(self) -> Optional[int]: """ @@ -85,7 +85,7 @@ class WecomEvent(dict): Returns: Optional[int]: 机器人 ID。 """ - return self.get("AgentID") + return self.get('AgentID') @property def receiver_id(self) -> Optional[str]: @@ -95,7 +95,7 @@ class WecomEvent(dict): Returns: Optional[str]: 接收者 ID。 """ - return self.get("ToUserName") + return self.get('ToUserName') @property def message_id(self) -> Optional[str]: @@ -105,7 +105,7 @@ class WecomEvent(dict): Returns: Optional[str]: 消息 ID。 """ - return self.get("MsgId") + return self.get('MsgId') @property def message(self) -> Optional[str]: @@ -115,7 +115,7 @@ class WecomEvent(dict): Returns: Optional[str]: 消息内容。 """ - return self.get("Content") + return self.get('Content') @property def media_id(self) -> Optional[str]: @@ -125,7 +125,7 @@ class WecomEvent(dict): Returns: Optional[str]: 媒体文件 ID。 """ - return self.get("MediaId") + return self.get('MediaId') @property def timestamp(self) -> Optional[int]: @@ -135,7 +135,7 @@ class WecomEvent(dict): Returns: Optional[int]: 时间戳。 """ - return self.get("CreateTime") + return self.get('CreateTime') @property def event_key(self) -> Optional[str]: @@ -145,7 +145,7 @@ class WecomEvent(dict): Returns: Optional[str]: 事件 Key。 """ - return self.get("EventKey") + return self.get('EventKey') def __getattr__(self, key: str) -> Optional[Any]: """ @@ -176,4 +176,4 @@ class WecomEvent(dict): Returns: str: 字符串表示。 """ - return f"" + return f'' diff --git a/main.py b/main.py index 5c86b7ca..8be603c6 100644 --- a/main.py +++ b/main.py @@ -1,3 +1,4 @@ +import asyncio # LangBot 终端启动入口 # 在此层级解决依赖项检查。 # LangBot/main.py @@ -14,9 +15,6 @@ asciiart = r""" """ -import asyncio - - async def main_entry(loop: asyncio.AbstractEventLoop): print(asciiart) @@ -29,20 +27,22 @@ async def main_entry(loop: asyncio.AbstractEventLoop): missing_deps = await deps.check_deps() if missing_deps: - print("以下依赖包未安装,将自动安装,请完成后重启程序:") + print('以下依赖包未安装,将自动安装,请完成后重启程序:') for dep in missing_deps: - print("-", dep) + print('-', dep) await deps.install_deps(missing_deps) - print("已自动安装缺失的依赖包,请重启程序。") + print('已自动安装缺失的依赖包,请重启程序。') sys.exit(0) - + # check plugin deps await deps.precheck_plugin_deps() # 检查pydantic版本,如果没有 pydantic.v1,则把 pydantic 映射为 v1 import pydantic.version + if pydantic.version.VERSION < '2.0': import pydantic + sys.modules['pydantic.v1'] = pydantic # 检查配置文件 @@ -52,11 +52,12 @@ async def main_entry(loop: asyncio.AbstractEventLoop): generated_files = await files.generate_files() if generated_files: - print("以下文件不存在,已自动生成:") + print('以下文件不存在,已自动生成:') for file in generated_files: - print("-", file) + print('-', file) from pkg.core import boot + await boot.main(loop) @@ -66,8 +67,8 @@ if __name__ == '__main__': # 必须大于 3.10.1 if sys.version_info < (3, 10, 1): - print("需要 Python 3.10.1 及以上版本,当前 Python 版本为:", sys.version) - input("按任意键退出...") + print('需要 Python 3.10.1 及以上版本,当前 Python 版本为:', sys.version) + input('按任意键退出...') exit(1) # 检查本目录是否有main.py,且包含LangBot字符串 @@ -78,11 +79,11 @@ if __name__ == '__main__': else: with open('main.py', 'r', encoding='utf-8') as f: content = f.read() - if "LangBot/main.py" not in content: + if 'LangBot/main.py' not in content: invalid_pwd = True if invalid_pwd: - print("请在 LangBot 项目根目录下以命令形式运行此程序。") - input("按任意键退出...") + print('请在 LangBot 项目根目录下以命令形式运行此程序。') + input('按任意键退出...') exit(1) loop = asyncio.new_event_loop() diff --git a/pkg/api/http/controller/group.py b/pkg/api/http/controller/group.py index 7186802f..efbb7247 100644 --- a/pkg/api/http/controller/group.py +++ b/pkg/api/http/controller/group.py @@ -13,6 +13,7 @@ from ....core import app preregistered_groups: list[type[RouterGroup]] = [] """RouterGroup 的预注册列表""" + def group_class(name: str, path: str) -> None: """注册一个 RouterGroup""" @@ -27,12 +28,12 @@ def group_class(name: str, path: str) -> None: class AuthType(enum.Enum): """认证类型""" + NONE = 'none' USER_TOKEN = 'user-token' class RouterGroup(abc.ABC): - name: str path: str @@ -49,17 +50,24 @@ class RouterGroup(abc.ABC): async def initialize(self) -> None: pass - def route(self, rule: str, auth_type: AuthType = AuthType.USER_TOKEN, **options: typing.Any) -> typing.Callable[[RouteCallable], RouteCallable]: # decorator + def route( + self, + rule: str, + auth_type: AuthType = AuthType.USER_TOKEN, + **options: typing.Any, + ) -> typing.Callable[[RouteCallable], RouteCallable]: # decorator """注册一个路由""" + def decorator(f: RouteCallable) -> RouteCallable: nonlocal rule rule = self.path + rule async def handler_error(*args, **kwargs): - if auth_type == AuthType.USER_TOKEN: # 从Authorization头中获取token - token = quart.request.headers.get('Authorization', '').replace('Bearer ', '') + token = quart.request.headers.get('Authorization', '').replace( + 'Bearer ', '' + ) if not token: return self.http_status(401, -1, '未提供有效的用户令牌') @@ -75,11 +83,11 @@ class RouterGroup(abc.ABC): try: return await f(*args, **kwargs) - except Exception as e: # 自动 500 + except Exception: # 自动 500 traceback.print_exc() # return self.http_status(500, -2, str(e)) return self.http_status(500, -2, 'internal server error') - + new_f = handler_error new_f.__name__ = (self.name + rule).replace('/', '__') new_f.__doc__ = f.__doc__ @@ -91,20 +99,24 @@ class RouterGroup(abc.ABC): def success(self, data: typing.Any = None) -> quart.Response: """返回一个 200 响应""" - return quart.jsonify({ - 'code': 0, - 'msg': 'ok', - 'data': data, - }) - + return quart.jsonify( + { + 'code': 0, + 'msg': 'ok', + 'data': data, + } + ) + def fail(self, code: int, msg: str) -> quart.Response: """返回一个异常响应""" - return quart.jsonify({ - 'code': code, - 'msg': msg, - }) - + return quart.jsonify( + { + 'code': code, + 'msg': msg, + } + ) + def http_status(self, status: int, code: int, msg: str) -> quart.Response: """返回一个指定状态码的响应""" return self.fail(code, msg), status diff --git a/pkg/api/http/controller/groups/logs.py b/pkg/api/http/controller/groups/logs.py index 4244d889..b0643cb6 100644 --- a/pkg/api/http/controller/groups/logs.py +++ b/pkg/api/http/controller/groups/logs.py @@ -1,32 +1,29 @@ from __future__ import annotations -import traceback import quart -from .....core import app from .. import group @group.group_class('logs', '/api/v1/logs') class LogsRouterGroup(group.RouterGroup): - async def initialize(self) -> None: @self.route('', methods=['GET'], auth_type=group.AuthType.USER_TOKEN) async def _() -> str: - start_page_number = int(quart.request.args.get('start_page_number', 0)) start_offset = int(quart.request.args.get('start_offset', 0)) - logs_str, end_page_number, end_offset = self.ap.log_cache.get_log_by_pointer( - start_page_number=start_page_number, - start_offset=start_offset + logs_str, end_page_number, end_offset = ( + self.ap.log_cache.get_log_by_pointer( + start_page_number=start_page_number, start_offset=start_offset + ) ) return self.success( data={ - "logs": logs_str, - "end_page_number": end_page_number, - "end_offset": end_offset + 'logs': logs_str, + 'end_page_number': end_page_number, + 'end_offset': end_offset, } ) diff --git a/pkg/api/http/controller/groups/pipelines.py b/pkg/api/http/controller/groups/pipelines.py index 400da593..02564e58 100644 --- a/pkg/api/http/controller/groups/pipelines.py +++ b/pkg/api/http/controller/groups/pipelines.py @@ -3,46 +3,41 @@ from __future__ import annotations import quart from .. import group -from .....entity.persistence import pipeline @group.group_class('pipelines', '/api/v1/pipelines') class PipelinesRouterGroup(group.RouterGroup): - async def initialize(self) -> None: - @self.route('', methods=['GET', 'POST']) async def _() -> str: if quart.request.method == 'GET': - return self.success(data={ - 'pipelines': await self.ap.pipeline_service.get_pipelines() - }) + return self.success( + data={'pipelines': await self.ap.pipeline_service.get_pipelines()} + ) elif quart.request.method == 'POST': json_data = await quart.request.json - pipeline_uuid = await self.ap.pipeline_service.create_pipeline(json_data) + pipeline_uuid = await self.ap.pipeline_service.create_pipeline( + json_data + ) - return self.success(data={ - 'uuid': pipeline_uuid - }) + return self.success(data={'uuid': pipeline_uuid}) @self.route('/_/metadata', methods=['GET']) async def _() -> str: - return self.success(data={ - 'configs': await self.ap.pipeline_service.get_pipeline_metadata() - }) + return self.success( + data={'configs': await self.ap.pipeline_service.get_pipeline_metadata()} + ) @self.route('/', methods=['GET', 'PUT', 'DELETE']) async def _(pipeline_uuid: str) -> str: if quart.request.method == 'GET': pipeline = await self.ap.pipeline_service.get_pipeline(pipeline_uuid) - + if pipeline is None: return self.http_status(404, -1, 'pipeline not found') - return self.success(data={ - 'pipeline': pipeline - }) + return self.success(data={'pipeline': pipeline}) elif quart.request.method == 'PUT': json_data = await quart.request.json @@ -53,4 +48,3 @@ class PipelinesRouterGroup(group.RouterGroup): await self.ap.pipeline_service.delete_pipeline(pipeline_uuid) return self.success() - diff --git a/pkg/api/http/controller/groups/platform/adapters.py b/pkg/api/http/controller/groups/platform/adapters.py index fa8b2d9c..511ae003 100644 --- a/pkg/api/http/controller/groups/platform/adapters.py +++ b/pkg/api/http/controller/groups/platform/adapters.py @@ -5,29 +5,31 @@ from ... import group @group.group_class('adapters', '/api/v1/platform/adapters') class AdaptersRouterGroup(group.RouterGroup): - async def initialize(self) -> None: @self.route('', methods=['GET']) async def _() -> str: - return self.success(data={ - 'adapters': self.ap.platform_mgr.get_available_adapters_info() - }) - + return self.success( + data={'adapters': self.ap.platform_mgr.get_available_adapters_info()} + ) + @self.route('/', methods=['GET']) async def _(adapter_name: str) -> str: - adapter_info = self.ap.platform_mgr.get_available_adapter_info_by_name(adapter_name) + adapter_info = self.ap.platform_mgr.get_available_adapter_info_by_name( + adapter_name + ) if adapter_info is None: return self.http_status(404, -1, 'adapter not found') - return self.success(data={ - 'adapter': adapter_info - }) - + return self.success(data={'adapter': adapter_info}) + @self.route('//icon', methods=['GET']) async def _(adapter_name: str) -> quart.Response: - - adapter_manifest = self.ap.platform_mgr.get_available_adapter_manifest_by_name(adapter_name) + adapter_manifest = ( + self.ap.platform_mgr.get_available_adapter_manifest_by_name( + adapter_name + ) + ) if adapter_manifest is None: return self.http_status(404, -1, 'adapter not found') @@ -37,4 +39,4 @@ class AdaptersRouterGroup(group.RouterGroup): if icon_path is None: return self.http_status(404, -1, 'icon not found') - return await quart.send_file(icon_path) \ No newline at end of file + return await quart.send_file(icon_path) diff --git a/pkg/api/http/controller/groups/platform/bots.py b/pkg/api/http/controller/groups/platform/bots.py index fe20aa53..af248fac 100644 --- a/pkg/api/http/controller/groups/platform/bots.py +++ b/pkg/api/http/controller/groups/platform/bots.py @@ -5,34 +5,27 @@ from ... import group @group.group_class('bots', '/api/v1/platform/bots') class BotsRouterGroup(group.RouterGroup): - async def initialize(self) -> None: @self.route('', methods=['GET', 'POST']) async def _() -> str: if quart.request.method == 'GET': - return self.success(data={ - 'bots': await self.ap.bot_service.get_bots() - }) + return self.success(data={'bots': await self.ap.bot_service.get_bots()}) elif quart.request.method == 'POST': json_data = await quart.request.json bot_uuid = await self.ap.bot_service.create_bot(json_data) - return self.success(data={ - 'uuid': bot_uuid - }) - + return self.success(data={'uuid': bot_uuid}) + @self.route('/', methods=['GET', 'PUT', 'DELETE']) async def _(bot_uuid: str) -> str: if quart.request.method == 'GET': bot = await self.ap.bot_service.get_bot(bot_uuid) if bot is None: return self.http_status(404, -1, 'bot not found') - return self.success(data={ - 'bot': bot - }) + return self.success(data={'bot': bot}) elif quart.request.method == 'PUT': json_data = await quart.request.json await self.ap.bot_service.update_bot(bot_uuid, json_data) return self.success() elif quart.request.method == 'DELETE': await self.ap.bot_service.delete_bot(bot_uuid) - return self.success() \ No newline at end of file + return self.success() diff --git a/pkg/api/http/controller/groups/plugins.py b/pkg/api/http/controller/groups/plugins.py index 330231c2..1deecca6 100644 --- a/pkg/api/http/controller/groups/plugins.py +++ b/pkg/api/http/controller/groups/plugins.py @@ -1,17 +1,14 @@ from __future__ import annotations -import traceback - import quart -from .....core import app, taskmgr +from .....core import taskmgr from .. import group @group.group_class('plugins', '/api/v1/plugins') class PluginsRouterGroup(group.RouterGroup): - async def initialize(self) -> None: @self.route('', methods=['GET'], auth_type=group.AuthType.USER_TOKEN) async def _() -> str: @@ -19,63 +16,69 @@ class PluginsRouterGroup(group.RouterGroup): plugins_data = [plugin.model_dump() for plugin in plugins] - return self.success(data={ - 'plugins': plugins_data - }) - - @self.route('///toggle', methods=['PUT'], auth_type=group.AuthType.USER_TOKEN) + return self.success(data={'plugins': plugins_data}) + + @self.route( + '///toggle', + methods=['PUT'], + auth_type=group.AuthType.USER_TOKEN, + ) async def _(author: str, plugin_name: str) -> str: data = await quart.request.json target_enabled = data.get('target_enabled') await self.ap.plugin_mgr.update_plugin_switch(plugin_name, target_enabled) return self.success() - - @self.route('///update', methods=['POST'], auth_type=group.AuthType.USER_TOKEN) + + @self.route( + '///update', + methods=['POST'], + auth_type=group.AuthType.USER_TOKEN, + ) async def _(author: str, plugin_name: str) -> str: ctx = taskmgr.TaskContext.new() wrapper = self.ap.task_mgr.create_user_task( self.ap.plugin_mgr.update_plugin(plugin_name, task_context=ctx), - kind="plugin-operation", - name=f"plugin-update-{plugin_name}", - label=f"更新插件 {plugin_name}", - context=ctx + kind='plugin-operation', + name=f'plugin-update-{plugin_name}', + label=f'更新插件 {plugin_name}', + context=ctx, ) - return self.success(data={ - 'task_id': wrapper.id - }) - - @self.route('//', methods=['GET', 'DELETE'], auth_type=group.AuthType.USER_TOKEN) + return self.success(data={'task_id': wrapper.id}) + + @self.route( + '//', + methods=['GET', 'DELETE'], + auth_type=group.AuthType.USER_TOKEN, + ) async def _(author: str, plugin_name: str) -> str: if quart.request.method == 'GET': plugin = self.ap.plugin_mgr.get_plugin(author, plugin_name) if plugin is None: return self.http_status(404, -1, 'plugin not found') - return self.success(data={ - 'plugin': plugin.model_dump() - }) + return self.success(data={'plugin': plugin.model_dump()}) elif quart.request.method == 'DELETE': ctx = taskmgr.TaskContext.new() wrapper = self.ap.task_mgr.create_user_task( self.ap.plugin_mgr.uninstall_plugin(plugin_name, task_context=ctx), - kind="plugin-operation", + kind='plugin-operation', name=f'plugin-remove-{plugin_name}', label=f'删除插件 {plugin_name}', - context=ctx + context=ctx, ) - return self.success(data={ - 'task_id': wrapper.id - }) - - @self.route('///config', methods=['GET', 'PUT'], auth_type=group.AuthType.USER_TOKEN) + return self.success(data={'task_id': wrapper.id}) + + @self.route( + '///config', + methods=['GET', 'PUT'], + auth_type=group.AuthType.USER_TOKEN, + ) async def _(author: str, plugin_name: str) -> quart.Response: plugin = self.ap.plugin_mgr.get_plugin(author, plugin_name) if plugin is None: return self.http_status(404, -1, 'plugin not found') if quart.request.method == 'GET': - return self.success(data={ - 'config': plugin.plugin_config - }) + return self.success(data={'config': plugin.plugin_config}) elif quart.request.method == 'PUT': data = await quart.request.json @@ -88,21 +91,21 @@ class PluginsRouterGroup(group.RouterGroup): data = await quart.request.json await self.ap.plugin_mgr.reorder_plugins(data.get('plugins')) return self.success() - - @self.route('/install/github', methods=['POST'], auth_type=group.AuthType.USER_TOKEN) + + @self.route( + '/install/github', methods=['POST'], auth_type=group.AuthType.USER_TOKEN + ) async def _() -> str: data = await quart.request.json - + ctx = taskmgr.TaskContext.new() short_source_str = data['source'][-8:] wrapper = self.ap.task_mgr.create_user_task( self.ap.plugin_mgr.install_plugin(data['source'], task_context=ctx), - kind="plugin-operation", - name=f'plugin-install-github', + kind='plugin-operation', + name='plugin-install-github', label=f'安装插件 ...{short_source_str}', - context=ctx + context=ctx, ) - return self.success(data={ - 'task_id': wrapper.id - }) + return self.success(data={'task_id': wrapper.id}) diff --git a/pkg/api/http/controller/groups/provider/models.py b/pkg/api/http/controller/groups/provider/models.py index 81e9078b..eaf96047 100644 --- a/pkg/api/http/controller/groups/provider/models.py +++ b/pkg/api/http/controller/groups/provider/models.py @@ -1,28 +1,23 @@ import quart -import uuid from ... import group -from ......entity.persistence import model @group.group_class('models/llm', '/api/v1/provider/models/llm') class LLMModelsRouterGroup(group.RouterGroup): - async def initialize(self) -> None: @self.route('', methods=['GET', 'POST']) async def _() -> str: if quart.request.method == 'GET': - return self.success(data={ - 'models': await self.ap.model_service.get_llm_models() - }) + return self.success( + data={'models': await self.ap.model_service.get_llm_models()} + ) elif quart.request.method == 'POST': json_data = await quart.request.json model_uuid = await self.ap.model_service.create_llm_model(json_data) - return self.success(data={ - 'uuid': model_uuid - }) + return self.success(data={'uuid': model_uuid}) @self.route('/', methods=['GET', 'DELETE']) async def _(model_uuid: str) -> str: @@ -32,9 +27,7 @@ class LLMModelsRouterGroup(group.RouterGroup): if model is None: return self.http_status(404, -1, 'model not found') - return self.success(data={ - 'model': model - }) + return self.success(data={'model': model}) # elif quart.request.method == 'PUT': # json_data = await quart.request.json diff --git a/pkg/api/http/controller/groups/provider/requesters.py b/pkg/api/http/controller/groups/provider/requesters.py index 18939d32..f95dfdb4 100644 --- a/pkg/api/http/controller/groups/provider/requesters.py +++ b/pkg/api/http/controller/groups/provider/requesters.py @@ -5,29 +5,31 @@ from ... import group @group.group_class('provider/requesters', '/api/v1/provider/requesters') class RequestersRouterGroup(group.RouterGroup): - async def initialize(self) -> None: @self.route('', methods=['GET']) async def _() -> quart.Response: - return self.success(data={ - 'requesters': self.ap.model_mgr.get_available_requesters_info() - }) - + return self.success( + data={'requesters': self.ap.model_mgr.get_available_requesters_info()} + ) + @self.route('/', methods=['GET']) async def _(requester_name: str) -> quart.Response: - - requester_info = self.ap.model_mgr.get_available_requester_info_by_name(requester_name) + requester_info = self.ap.model_mgr.get_available_requester_info_by_name( + requester_name + ) if requester_info is None: return self.http_status(404, -1, 'requester not found') - return self.success(data={ - 'requester': requester_info - }) - + return self.success(data={'requester': requester_info}) + @self.route('//icon', methods=['GET']) async def _(requester_name: str) -> quart.Response: - requester_manifest = self.ap.model_mgr.get_available_requester_manifest_by_name(requester_name) + requester_manifest = ( + self.ap.model_mgr.get_available_requester_manifest_by_name( + requester_name + ) + ) if requester_manifest is None: return self.http_status(404, -1, 'requester not found') diff --git a/pkg/api/http/controller/groups/stats.py b/pkg/api/http/controller/groups/stats.py index 43d56f27..7b1d4353 100644 --- a/pkg/api/http/controller/groups/stats.py +++ b/pkg/api/http/controller/groups/stats.py @@ -1,23 +1,21 @@ -import quart -import asyncio - -from .....core import app, taskmgr from .. import group @group.group_class('stats', '/api/v1/stats') class StatsRouterGroup(group.RouterGroup): - async def initialize(self) -> None: @self.route('/basic', methods=['GET'], auth_type=group.AuthType.USER_TOKEN) async def _() -> str: - conv_count = 0 for session in self.ap.sess_mgr.session_list: - conv_count += len(session.conversations if session.conversations is not None else []) + conv_count += len( + session.conversations if session.conversations is not None else [] + ) - return self.success(data={ - 'active_session_count': len(self.ap.sess_mgr.session_list), - 'conversation_count': conv_count, - 'query_count': self.ap.query_pool.query_id_counter, - }) + return self.success( + data={ + 'active_session_count': len(self.ap.sess_mgr.session_list), + 'conversation_count': conv_count, + 'query_count': self.ap.query_pool.query_id_counter, + } + ) diff --git a/pkg/api/http/controller/groups/system.py b/pkg/api/http/controller/groups/system.py index 04ace284..c586ea27 100644 --- a/pkg/api/http/controller/groups/system.py +++ b/pkg/api/http/controller/groups/system.py @@ -1,63 +1,62 @@ import quart -import asyncio -from .....core import app, taskmgr from .. import group from .....utils import constants @group.group_class('system', '/api/v1/system') class SystemRouterGroup(group.RouterGroup): - async def initialize(self) -> None: @self.route('/info', methods=['GET'], auth_type=group.AuthType.NONE) async def _() -> str: return self.success( data={ - "version": constants.semantic_version, - "debug": constants.debug_mode, - "enabled_platform_count": len(self.ap.platform_mgr.get_running_adapters()) + 'version': constants.semantic_version, + 'debug': constants.debug_mode, + 'enabled_platform_count': len( + self.ap.platform_mgr.get_running_adapters() + ), } ) @self.route('/tasks', methods=['GET'], auth_type=group.AuthType.USER_TOKEN) async def _() -> str: - task_type = quart.request.args.get("type") + task_type = quart.request.args.get('type') if task_type == '': task_type = None - return self.success( - data=self.ap.task_mgr.get_tasks_dict(task_type) - ) - - @self.route('/tasks/', methods=['GET'], auth_type=group.AuthType.USER_TOKEN) + return self.success(data=self.ap.task_mgr.get_tasks_dict(task_type)) + + @self.route( + '/tasks/', methods=['GET'], auth_type=group.AuthType.USER_TOKEN + ) async def _(task_id: str) -> str: task = self.ap.task_mgr.get_task_by_id(int(task_id)) if task is None: - return self.http_status(404, 404, "Task not found") - + return self.http_status(404, 404, 'Task not found') + return self.success(data=task.to_dict()) - + @self.route('/reload', methods=['POST'], auth_type=group.AuthType.USER_TOKEN) async def _() -> str: json_data = await quart.request.json - scope = json_data.get("scope") + scope = json_data.get('scope') - await self.ap.reload( - scope=scope - ) + await self.ap.reload(scope=scope) return self.success() - @self.route('/_debug/exec', methods=['POST'], auth_type=group.AuthType.USER_TOKEN) + @self.route( + '/_debug/exec', methods=['POST'], auth_type=group.AuthType.USER_TOKEN + ) async def _() -> str: if not constants.debug_mode: - return self.http_status(403, 403, "Forbidden") - + return self.http_status(403, 403, 'Forbidden') + py_code = await quart.request.data ap = self.ap - return self.success(data=exec(py_code, {"ap": ap})) + return self.success(data=exec(py_code, {'ap': ap})) diff --git a/pkg/api/http/controller/groups/user.py b/pkg/api/http/controller/groups/user.py index 3cd08240..4c330782 100644 --- a/pkg/api/http/controller/groups/user.py +++ b/pkg/api/http/controller/groups/user.py @@ -1,22 +1,19 @@ import quart -import jwt import argon2 from .. import group -from .....entity.persistence import user @group.group_class('user', '/api/v1/user') class UserRouterGroup(group.RouterGroup): - async def initialize(self) -> None: @self.route('/init', methods=['GET', 'POST'], auth_type=group.AuthType.NONE) async def _() -> str: if quart.request.method == 'GET': - return self.success(data={ - 'initialized': await self.ap.user_service.is_initialized() - }) - + return self.success( + data={'initialized': await self.ap.user_service.is_initialized()} + ) + if await self.ap.user_service.is_initialized(): return self.fail(1, '系统已初始化') @@ -28,24 +25,24 @@ class UserRouterGroup(group.RouterGroup): await self.ap.user_service.create_user(user_email, password) return self.success() - + @self.route('/auth', methods=['POST'], auth_type=group.AuthType.NONE) async def _() -> str: json_data = await quart.request.json try: - token = await self.ap.user_service.authenticate(json_data['user'], json_data['password']) + token = await self.ap.user_service.authenticate( + json_data['user'], json_data['password'] + ) except argon2.exceptions.VerifyMismatchError: return self.fail(1, '用户名或密码错误') - return self.success(data={ - 'token': token - }) + return self.success(data={'token': token}) - @self.route('/check-token', methods=['GET'], auth_type=group.AuthType.USER_TOKEN) + @self.route( + '/check-token', methods=['GET'], auth_type=group.AuthType.USER_TOKEN + ) async def _(user_email: str) -> str: token = await self.ap.user_service.generate_jwt_token(user_email) - return self.success(data={ - 'token': token - }) + return self.success(data={'token': token}) diff --git a/pkg/api/http/controller/main.py b/pkg/api/http/controller/main.py index 0d6bcd15..e5e42df6 100644 --- a/pkg/api/http/controller/main.py +++ b/pkg/api/http/controller/main.py @@ -7,15 +7,19 @@ import quart import quart_cors from ....core import app, entities as core_entities +from ....utils import importutil -from .groups import logs, system, plugins, stats, user, pipelines -from .groups.provider import models, requesters -from .groups.platform import bots, adapters +from . import groups from . import group +from .groups import provider as groups_provider +from .groups import platform as groups_platform + +importutil.import_modules_in_pkg(groups) +importutil.import_modules_in_pkg(groups_provider) +importutil.import_modules_in_pkg(groups_platform) class HTTPController: - ap: app.Application quart_app: quart.Quart @@ -23,7 +27,7 @@ class HTTPController: def __init__(self, ap: app.Application) -> None: self.ap = ap self.quart_app = quart.Quart(__name__) - quart_cors.cors(self.quart_app, allow_origin="*") + quart_cors.cors(self.quart_app, allow_origin='*') async def initialize(self) -> None: await self.register_routes() @@ -37,11 +41,9 @@ class HTTPController: async def exception_handler(*args, **kwargs): try: - await self.quart_app.run_task( - *args, **kwargs - ) + await self.quart_app.run_task(*args, **kwargs) except Exception as e: - self.ap.logger.error(f"启动 HTTP 服务失败: {e}") + self.ap.logger.error(f'启动 HTTP 服务失败: {e}') self.ap.task_mgr.create_task( exception_handler( @@ -49,63 +51,62 @@ class HTTPController: port=self.ap.instance_config.data['api']['port'], shutdown_trigger=shutdown_trigger_placeholder, ), - name="http-api-quart", + name='http-api-quart', scopes=[core_entities.LifecycleControlScope.APPLICATION], ) # await asyncio.sleep(5) async def register_routes(self) -> None: - - @self.quart_app.route("/healthz") + @self.quart_app.route('/healthz') async def healthz(): - return {"code": 0, "msg": "ok"} + return {'code': 0, 'msg': 'ok'} for g in group.preregistered_groups: ginst = g(self.ap, self.quart_app) await ginst.initialize() - frontend_path = "web/out" + frontend_path = 'web/out' - @self.quart_app.route("/") + @self.quart_app.route('/') async def index(): - return await quart.send_from_directory(frontend_path, "index.html", mimetype="text/html") + return await quart.send_from_directory( + frontend_path, 'index.html', mimetype='text/html' + ) - @self.quart_app.route("/") + @self.quart_app.route('/') async def static_file(path: str): if not os.path.exists(os.path.join(frontend_path, path)): - if os.path.exists(os.path.join(frontend_path, path+".html")): + if os.path.exists(os.path.join(frontend_path, path + '.html')): path += '.html' else: return await quart.send_from_directory(frontend_path, '404.html') mimetype = None - if path.endswith(".html"): - mimetype = "text/html" - elif path.endswith(".js"): - mimetype = "application/javascript" - elif path.endswith(".css"): - mimetype = "text/css" - elif path.endswith(".png"): - mimetype = "image/png" - elif path.endswith(".jpg"): - mimetype = "image/jpeg" - elif path.endswith(".jpeg"): - mimetype = "image/jpeg" - elif path.endswith(".gif"): - mimetype = "image/gif" - elif path.endswith(".svg"): - mimetype = "image/svg+xml" - elif path.endswith(".ico"): - mimetype = "image/x-icon" - elif path.endswith(".json"): - mimetype = "application/json" - elif path.endswith(".txt"): - mimetype = "text/plain" + if path.endswith('.html'): + mimetype = 'text/html' + elif path.endswith('.js'): + mimetype = 'application/javascript' + elif path.endswith('.css'): + mimetype = 'text/css' + elif path.endswith('.png'): + mimetype = 'image/png' + elif path.endswith('.jpg'): + mimetype = 'image/jpeg' + elif path.endswith('.jpeg'): + mimetype = 'image/jpeg' + elif path.endswith('.gif'): + mimetype = 'image/gif' + elif path.endswith('.svg'): + mimetype = 'image/svg+xml' + elif path.endswith('.ico'): + mimetype = 'image/x-icon' + elif path.endswith('.json'): + mimetype = 'application/json' + elif path.endswith('.txt'): + mimetype = 'text/plain' return await quart.send_from_directory( - frontend_path, - path, - mimetype=mimetype + frontend_path, path, mimetype=mimetype ) diff --git a/pkg/api/http/service/bot.py b/pkg/api/http/service/bot.py index 557b63f1..23e9fa5b 100644 --- a/pkg/api/http/service/bot.py +++ b/pkg/api/http/service/bot.py @@ -1,7 +1,6 @@ from __future__ import annotations import uuid -import datetime import sqlalchemy from ....core import app @@ -29,13 +28,15 @@ class BotService: self.ap.persistence_mgr.serialize_model(persistence_bot.Bot, bot) for bot in bots ] - + async def get_bot(self, bot_uuid: str) -> dict | None: """获取机器人""" result = await self.ap.persistence_mgr.execute_async( - sqlalchemy.select(persistence_bot.Bot).where(persistence_bot.Bot.uuid == bot_uuid) + sqlalchemy.select(persistence_bot.Bot).where( + persistence_bot.Bot.uuid == bot_uuid + ) ) - + bot = result.first() if bot is None: @@ -50,7 +51,9 @@ class BotService: # checkout the default pipeline result = await self.ap.persistence_mgr.execute_async( - sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(persistence_pipeline.LegacyPipeline.is_default == True) + sqlalchemy.select(persistence_pipeline.LegacyPipeline).where( + persistence_pipeline.LegacyPipeline.is_default == True + ) ) pipeline = result.first() if pipeline is not None: @@ -64,7 +67,7 @@ class BotService: bot = await self.get_bot(bot_data['uuid']) await self.ap.platform_mgr.load_bot(bot) - + return bot_data['uuid'] async def update_bot(self, bot_uuid: str, bot_data: dict) -> None: @@ -75,19 +78,24 @@ class BotService: # set use_pipeline_name if 'use_pipeline_uuid' in bot_data: result = await self.ap.persistence_mgr.execute_async( - sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(persistence_pipeline.LegacyPipeline.uuid == bot_data['use_pipeline_uuid']) + sqlalchemy.select(persistence_pipeline.LegacyPipeline).where( + persistence_pipeline.LegacyPipeline.uuid + == bot_data['use_pipeline_uuid'] + ) ) pipeline = result.first() if pipeline is not None: bot_data['use_pipeline_name'] = pipeline.name else: - raise Exception("Pipeline not found") + raise Exception('Pipeline not found') await self.ap.persistence_mgr.execute_async( - sqlalchemy.update(persistence_bot.Bot).values(bot_data).where(persistence_bot.Bot.uuid == bot_uuid) + sqlalchemy.update(persistence_bot.Bot) + .values(bot_data) + .where(persistence_bot.Bot.uuid == bot_uuid) ) await self.ap.platform_mgr.remove_bot(bot_uuid) - + # select from db bot = await self.get_bot(bot_uuid) @@ -100,7 +108,7 @@ class BotService: """删除机器人""" await self.ap.platform_mgr.remove_bot(bot_uuid) await self.ap.persistence_mgr.execute_async( - sqlalchemy.delete(persistence_bot.Bot).where(persistence_bot.Bot.uuid == bot_uuid) + sqlalchemy.delete(persistence_bot.Bot).where( + persistence_bot.Bot.uuid == bot_uuid + ) ) - - diff --git a/pkg/api/http/service/model.py b/pkg/api/http/service/model.py index e74e975d..8a71bf1c 100644 --- a/pkg/api/http/service/model.py +++ b/pkg/api/http/service/model.py @@ -1,7 +1,6 @@ from __future__ import annotations import uuid -import datetime import sqlalchemy from ....core import app @@ -10,7 +9,6 @@ from ....entity.persistence import pipeline as persistence_pipeline class ModelsService: - ap: app.Application def __init__(self, ap: app.Application) -> None: @@ -26,15 +24,12 @@ class ModelsService: self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model) for model in models ] - - async def create_llm_model(self, model_data: dict) -> str: + async def create_llm_model(self, model_data: dict) -> str: model_data['uuid'] = str(uuid.uuid4()) await self.ap.persistence_mgr.execute_async( - sqlalchemy.insert(persistence_model.LLMModel).values( - **model_data - ) + sqlalchemy.insert(persistence_model.LLMModel).values(**model_data) ) llm_model = await self.get_llm_model(model_data['uuid']) @@ -43,22 +38,24 @@ class ModelsService: # check if default pipeline has no model bound result = await self.ap.persistence_mgr.execute_async( - sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(persistence_pipeline.LegacyPipeline.is_default == True) + sqlalchemy.select(persistence_pipeline.LegacyPipeline).where( + persistence_pipeline.LegacyPipeline.is_default == True + ) ) pipeline = result.first() if pipeline is not None and pipeline.config['ai']['local-agent']['model'] == '': pipeline_config = pipeline.config pipeline_config['ai']['local-agent']['model'] = model_data['uuid'] - pipeline_data = { - "config": pipeline_config - } - await self.ap.pipeline_service.update_pipeline(pipeline.uuid, pipeline_data) + pipeline_data = {'config': pipeline_config} + await self.ap.pipeline_service.update_pipeline(pipeline.uuid, pipeline_data) return model_data['uuid'] async def get_llm_model(self, model_uuid: str) -> dict | None: result = await self.ap.persistence_mgr.execute_async( - sqlalchemy.select(persistence_model.LLMModel).where(persistence_model.LLMModel.uuid == model_uuid) + sqlalchemy.select(persistence_model.LLMModel).where( + persistence_model.LLMModel.uuid == model_uuid + ) ) model = result.first() @@ -66,14 +63,18 @@ class ModelsService: if model is None: return None - return self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model) + return self.ap.persistence_mgr.serialize_model( + persistence_model.LLMModel, model + ) async def update_llm_model(self, model_uuid: str, model_data: dict) -> None: if 'uuid' in model_data: del model_data['uuid'] - + await self.ap.persistence_mgr.execute_async( - sqlalchemy.update(persistence_model.LLMModel).where(persistence_model.LLMModel.uuid == model_uuid).values(**model_data) + sqlalchemy.update(persistence_model.LLMModel) + .where(persistence_model.LLMModel.uuid == model_uuid) + .values(**model_data) ) await self.ap.model_mgr.remove_llm_model(model_uuid) @@ -84,7 +85,9 @@ class ModelsService: async def delete_llm_model(self, model_uuid: str) -> None: await self.ap.persistence_mgr.execute_async( - sqlalchemy.delete(persistence_model.LLMModel).where(persistence_model.LLMModel.uuid == model_uuid) + sqlalchemy.delete(persistence_model.LLMModel).where( + persistence_model.LLMModel.uuid == model_uuid + ) ) await self.ap.model_mgr.remove_llm_model(model_uuid) diff --git a/pkg/api/http/service/pipeline.py b/pkg/api/http/service/pipeline.py index 6e8a66dc..0dd73ef2 100644 --- a/pkg/api/http/service/pipeline.py +++ b/pkg/api/http/service/pipeline.py @@ -2,7 +2,6 @@ from __future__ import annotations import uuid import json -import datetime import sqlalchemy from ....core import app @@ -10,69 +9,79 @@ from ....entity.persistence import pipeline as persistence_pipeline default_stage_order = [ - "GroupRespondRuleCheckStage", # 群响应规则检查 - "BanSessionCheckStage", # 封禁会话检查 - "PreContentFilterStage", # 内容过滤前置阶段 - "PreProcessor", # 预处理器 - "ConversationMessageTruncator", # 会话消息截断器 - "RequireRateLimitOccupancy", # 请求速率限制占用 - "MessageProcessor", # 处理器 - "ReleaseRateLimitOccupancy", # 释放速率限制占用 - "PostContentFilterStage", # 内容过滤后置阶段 - "ResponseWrapper", # 响应包装器 - "LongTextProcessStage", # 长文本处理 - "SendResponseBackStage", # 发送响应 + 'GroupRespondRuleCheckStage', # 群响应规则检查 + 'BanSessionCheckStage', # 封禁会话检查 + 'PreContentFilterStage', # 内容过滤前置阶段 + 'PreProcessor', # 预处理器 + 'ConversationMessageTruncator', # 会话消息截断器 + 'RequireRateLimitOccupancy', # 请求速率限制占用 + 'MessageProcessor', # 处理器 + 'ReleaseRateLimitOccupancy', # 释放速率限制占用 + 'PostContentFilterStage', # 内容过滤后置阶段 + 'ResponseWrapper', # 响应包装器 + 'LongTextProcessStage', # 长文本处理 + 'SendResponseBackStage', # 发送响应 ] class PipelineService: ap: app.Application - + def __init__(self, ap: app.Application) -> None: self.ap = ap - + async def get_pipeline_metadata(self) -> dict: return [ self.ap.pipeline_config_meta_trigger.data, self.ap.pipeline_config_meta_safety.data, self.ap.pipeline_config_meta_ai.data, - self.ap.pipeline_config_meta_output.data + self.ap.pipeline_config_meta_output.data, ] async def get_pipelines(self) -> list[dict]: result = await self.ap.persistence_mgr.execute_async( sqlalchemy.select(persistence_pipeline.LegacyPipeline) ) - + pipelines = result.all() return [ - self.ap.persistence_mgr.serialize_model(persistence_pipeline.LegacyPipeline, pipeline) + self.ap.persistence_mgr.serialize_model( + persistence_pipeline.LegacyPipeline, pipeline + ) for pipeline in pipelines ] - + async def get_pipeline(self, pipeline_uuid: str) -> dict | None: result = await self.ap.persistence_mgr.execute_async( - sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(persistence_pipeline.LegacyPipeline.uuid == pipeline_uuid) + sqlalchemy.select(persistence_pipeline.LegacyPipeline).where( + persistence_pipeline.LegacyPipeline.uuid == pipeline_uuid + ) ) - + pipeline = result.first() if pipeline is None: return None - return self.ap.persistence_mgr.serialize_model(persistence_pipeline.LegacyPipeline, pipeline) + return self.ap.persistence_mgr.serialize_model( + persistence_pipeline.LegacyPipeline, pipeline + ) async def create_pipeline(self, pipeline_data: dict, default: bool = False) -> str: pipeline_data['uuid'] = str(uuid.uuid4()) pipeline_data['for_version'] = self.ap.ver_mgr.get_current_version() pipeline_data['stages'] = default_stage_order.copy() pipeline_data['is_default'] = default - pipeline_data['config'] = json.load(open('templates/default-pipeline-config.json', 'r', encoding='utf-8')) + pipeline_data['config'] = json.load( + open('templates/default-pipeline-config.json', 'r', encoding='utf-8') + ) await self.ap.persistence_mgr.execute_async( - sqlalchemy.insert(persistence_pipeline.LegacyPipeline).values(**pipeline_data) + sqlalchemy.insert(persistence_pipeline.LegacyPipeline).values( + **pipeline_data + ) ) - + pipeline = await self.get_pipeline(pipeline_data['uuid']) await self.ap.pipeline_mgr.load_pipeline(pipeline) @@ -90,7 +99,9 @@ class PipelineService: del pipeline_data['is_default'] await self.ap.persistence_mgr.execute_async( - sqlalchemy.update(persistence_pipeline.LegacyPipeline).where(persistence_pipeline.LegacyPipeline.uuid == pipeline_uuid).values(**pipeline_data) + sqlalchemy.update(persistence_pipeline.LegacyPipeline) + .where(persistence_pipeline.LegacyPipeline.uuid == pipeline_uuid) + .values(**pipeline_data) ) await self.ap.pipeline_mgr.remove_pipeline(pipeline_uuid) @@ -101,6 +112,8 @@ class PipelineService: async def delete_pipeline(self, pipeline_uuid: str) -> None: await self.ap.persistence_mgr.execute_async( - sqlalchemy.delete(persistence_pipeline.LegacyPipeline).where(persistence_pipeline.LegacyPipeline.uuid == pipeline_uuid) + sqlalchemy.delete(persistence_pipeline.LegacyPipeline).where( + persistence_pipeline.LegacyPipeline.uuid == pipeline_uuid + ) ) await self.ap.pipeline_mgr.remove_pipeline(pipeline_uuid) diff --git a/pkg/api/http/service/user.py b/pkg/api/http/service/user.py index 4b31f1c0..2edfd874 100644 --- a/pkg/api/http/service/user.py +++ b/pkg/api/http/service/user.py @@ -11,7 +11,6 @@ from ....utils import constants class UserService: - ap: app.Application def __init__(self, ap: app.Application) -> None: @@ -24,7 +23,7 @@ class UserService: result_list = result.all() return result_list is not None and len(result_list) > 0 - + async def create_user(self, user_email: str, password: str) -> None: ph = argon2.PasswordHasher() @@ -32,8 +31,7 @@ class UserService: await self.ap.persistence_mgr.execute_async( sqlalchemy.insert(user.User).values( - user=user_email, - password=hashed_password + user=user_email, password=hashed_password ) ) @@ -61,12 +59,12 @@ class UserService: payload = { 'user': user_email, - 'iss': 'LangBot-'+constants.edition, - 'exp': datetime.datetime.now() + datetime.timedelta(seconds=jwt_expire) + 'iss': 'LangBot-' + constants.edition, + 'exp': datetime.datetime.now() + datetime.timedelta(seconds=jwt_expire), } return jwt.encode(payload, jwt_secret, algorithm='HS256') - + async def verify_jwt_token(self, token: str) -> str: jwt_secret = self.ap.instance_config.data['system']['jwt']['secret'] diff --git a/pkg/audit/__init__.py b/pkg/audit/__init__.py index c1a8353b..5f89c8b8 100644 --- a/pkg/audit/__init__.py +++ b/pkg/audit/__init__.py @@ -1,3 +1,3 @@ """ 审计相关操作 -""" \ No newline at end of file +""" diff --git a/pkg/audit/center/apigroup.py b/pkg/audit/center/apigroup.py index 4b20a09a..ac5e5117 100644 --- a/pkg/audit/center/apigroup.py +++ b/pkg/audit/center/apigroup.py @@ -3,11 +3,9 @@ from __future__ import annotations import abc import uuid import json -import logging import asyncio import aiohttp -import requests from ...core import app, entities as core_entities @@ -38,22 +36,22 @@ class APIGroup(metaclass=abc.ABCMeta): """ 执行请求 """ - self._runtime_info["account_id"] = "-1" + self._runtime_info['account_id'] = '-1' url = self.prefix + path data = json.dumps(data) - headers["Content-Type"] = "application/json" + headers['Content-Type'] = 'application/json' try: async with aiohttp.ClientSession() as session: async with session.request( method, url, data=data, params=params, headers=headers, **kwargs ) as resp: - self.ap.logger.debug("data: %s", data) - self.ap.logger.debug("ret: %s", await resp.text()) + self.ap.logger.debug('data: %s', data) + self.ap.logger.debug('ret: %s', await resp.text()) except Exception as e: - self.ap.logger.debug(f"上报失败: {e}") + self.ap.logger.debug(f'上报失败: {e}') async def do( self, @@ -68,8 +66,8 @@ class APIGroup(metaclass=abc.ABCMeta): return self.ap.task_mgr.create_task( self._do(method, path, data, params, headers, **kwargs), - kind="telemetry-operation", - name=f"{method} {path}", + kind='telemetry-operation', + name=f'{method} {path}', scopes=[core_entities.LifecycleControlScope.APPLICATION], ).task @@ -80,7 +78,7 @@ class APIGroup(metaclass=abc.ABCMeta): def basic_info(self): """获取基本信息""" basic_info = APIGroup._basic_info.copy() - basic_info["rid"] = self.gen_rid() + basic_info['rid'] = self.gen_rid() return basic_info def runtime_info(self): diff --git a/pkg/audit/center/groups/main.py b/pkg/audit/center/groups/main.py index 854437a1..2c2302d1 100644 --- a/pkg/audit/center/groups/main.py +++ b/pkg/audit/center/groups/main.py @@ -9,7 +9,7 @@ class V2MainDataAPI(apigroup.APIGroup): def __init__(self, prefix: str, ap: app.Application): self.ap = ap - super().__init__(prefix+"/main", ap) + super().__init__(prefix + '/main', ap) async def do(self, *args, **kwargs): if not self.ap.instance_config.data['telemetry']['report']: @@ -25,31 +25,31 @@ class V2MainDataAPI(apigroup.APIGroup): ): """提交更新记录""" return await self.do( - "POST", - "/update", + 'POST', + '/update', data={ - "basic": self.basic_info(), - "update_info": { - "spent_seconds": spent_seconds, - "infer_reason": infer_reason, - "old_version": old_version, - "new_version": new_version, - } - } + 'basic': self.basic_info(), + 'update_info': { + 'spent_seconds': spent_seconds, + 'infer_reason': infer_reason, + 'old_version': old_version, + 'new_version': new_version, + }, + }, ) - + async def post_announcement_showed( self, ids: list[int], ): """提交公告已阅""" return await self.do( - "POST", - "/announcement", + 'POST', + '/announcement', data={ - "basic": self.basic_info(), - "announcement_info": { - "ids": ids, - } - } + 'basic': self.basic_info(), + 'announcement_info': { + 'ids': ids, + }, + }, ) diff --git a/pkg/audit/center/groups/plugin.py b/pkg/audit/center/groups/plugin.py index d6ed0b02..4978e5d0 100644 --- a/pkg/audit/center/groups/plugin.py +++ b/pkg/audit/center/groups/plugin.py @@ -9,39 +9,33 @@ class V2PluginDataAPI(apigroup.APIGroup): def __init__(self, prefix: str, ap: app.Application): self.ap = ap - super().__init__(prefix+"/plugin", ap) + super().__init__(prefix + '/plugin', ap) async def do(self, *args, **kwargs): if not self.ap.instance_config.data['telemetry']['report']: return None return await super().do(*args, **kwargs) - async def post_install_record( - self, - plugin: dict - ): + async def post_install_record(self, plugin: dict): """提交插件安装记录""" return await self.do( - "POST", - "/install", + 'POST', + '/install', data={ - "basic": self.basic_info(), - "plugin": plugin, - } + 'basic': self.basic_info(), + 'plugin': plugin, + }, ) - async def post_remove_record( - self, - plugin: dict - ): + async def post_remove_record(self, plugin: dict): """提交插件卸载记录""" return await self.do( - "POST", - "/remove", + 'POST', + '/remove', data={ - "basic": self.basic_info(), - "plugin": plugin, - } + 'basic': self.basic_info(), + 'plugin': plugin, + }, ) async def post_update_record( @@ -52,14 +46,14 @@ class V2PluginDataAPI(apigroup.APIGroup): ): """提交插件更新记录""" return await self.do( - "POST", - "/update", + 'POST', + '/update', data={ - "basic": self.basic_info(), - "plugin": plugin, - "update_info": { - "old_version": old_version, - "new_version": new_version, - } - } + 'basic': self.basic_info(), + 'plugin': plugin, + 'update_info': { + 'old_version': old_version, + 'new_version': new_version, + }, + }, ) diff --git a/pkg/audit/center/groups/usage.py b/pkg/audit/center/groups/usage.py index 79bc56f5..bdbb27eb 100644 --- a/pkg/audit/center/groups/usage.py +++ b/pkg/audit/center/groups/usage.py @@ -9,7 +9,7 @@ class V2UsageDataAPI(apigroup.APIGroup): def __init__(self, prefix: str, ap: app.Application): self.ap = ap - super().__init__(prefix+"/usage", ap) + super().__init__(prefix + '/usage', ap) async def do(self, *args, **kwargs): if not self.ap.instance_config.data['telemetry']['report']: @@ -28,25 +28,25 @@ class V2UsageDataAPI(apigroup.APIGroup): ): """提交请求记录""" return await self.do( - "POST", - "/query", + 'POST', + '/query', data={ - "basic": self.basic_info(), - "runtime": self.runtime_info(), - "session_info": { - "type": session_type, - "id": session_id, + 'basic': self.basic_info(), + 'runtime': self.runtime_info(), + 'session_info': { + 'type': session_type, + 'id': session_id, }, - "query_info": { - "ability_provider": query_ability_provider, - "usage": usage, - "model_name": model_name, - "response_seconds": response_seconds, - "retry_times": retry_times, - } - } + 'query_info': { + 'ability_provider': query_ability_provider, + 'usage': usage, + 'model_name': model_name, + 'response_seconds': response_seconds, + 'retry_times': retry_times, + }, + }, ) - + async def post_event_record( self, plugins: list[dict], @@ -54,18 +54,18 @@ class V2UsageDataAPI(apigroup.APIGroup): ): """提交事件触发记录""" return await self.do( - "POST", - "/event", + 'POST', + '/event', data={ - "basic": self.basic_info(), - "runtime": self.runtime_info(), - "plugins": plugins, - "event_info": { - "name": event_name, - } - } + 'basic': self.basic_info(), + 'runtime': self.runtime_info(), + 'plugins': plugins, + 'event_info': { + 'name': event_name, + }, + }, ) - + async def post_function_record( self, plugin: dict, @@ -74,15 +74,14 @@ class V2UsageDataAPI(apigroup.APIGroup): ): """提交内容函数使用记录""" return await self.do( - "POST", - "/function", + 'POST', + '/function', data={ - "basic": self.basic_info(), - "plugin": plugin, - "function_info": { - "name": function_name, - "description": function_description, - } - } + 'basic': self.basic_info(), + 'plugin': plugin, + 'function_info': { + 'name': function_name, + 'description': function_description, + }, + }, ) - diff --git a/pkg/audit/center/v2.py b/pkg/audit/center/v2.py index 234e6d22..e9df6f91 100644 --- a/pkg/audit/center/v2.py +++ b/pkg/audit/center/v2.py @@ -11,7 +11,7 @@ from ...core import app class V2CenterAPI: """中央服务器 v2 API 交互类""" - + main: main.V2MainDataAPI = None """主 API 组""" @@ -21,15 +21,20 @@ class V2CenterAPI: plugin: plugin.V2PluginDataAPI = None """插件 API 组""" - def __init__(self, ap: app.Application, backend_url: str, basic_info: dict = None, runtime_info: dict = None): + def __init__( + self, + ap: app.Application, + backend_url: str, + basic_info: dict = None, + runtime_info: dict = None, + ): """初始化""" - logging.debug("basic_info: %s, runtime_info: %s", basic_info, runtime_info) - + logging.debug('basic_info: %s, runtime_info: %s', basic_info, runtime_info) + apigroup.APIGroup._basic_info = basic_info apigroup.APIGroup._runtime_info = runtime_info self.main = main.V2MainDataAPI(backend_url, ap) self.usage = usage.V2UsageDataAPI(backend_url, ap) self.plugin = plugin.V2PluginDataAPI(backend_url, ap) - diff --git a/pkg/audit/identifier.py b/pkg/audit/identifier.py index 3e2ec57d..b10d093c 100644 --- a/pkg/audit/identifier.py +++ b/pkg/audit/identifier.py @@ -16,6 +16,7 @@ identifier = { HOST_ID_FILE = os.path.expanduser('~/.langbot/host_id.json') INSTANCE_ID_FILE = 'data/labels/instance_id.json' + def init(): global identifier @@ -23,14 +24,11 @@ def init(): os.mkdir(os.path.expanduser('~/.langbot')) if not os.path.exists(HOST_ID_FILE): - new_host_id = 'host_'+str(uuid.uuid4()) + new_host_id = 'host_' + str(uuid.uuid4()) new_host_create_ts = int(time.time()) with open(HOST_ID_FILE, 'w') as f: - json.dump({ - 'host_id': new_host_id, - 'host_create_ts': new_host_create_ts - }, f) + json.dump({'host_id': new_host_id, 'host_create_ts': new_host_create_ts}, f) identifier['host_id'] = new_host_id identifier['host_create_ts'] = new_host_create_ts @@ -51,20 +49,25 @@ def init(): instance_id = {} with open(INSTANCE_ID_FILE, 'r') as f: instance_id = json.load(f) - - if instance_id['host_id'] != identifier['host_id']: # 如果实例 id 不是当前主机的,删除 + + if ( + instance_id['host_id'] != identifier['host_id'] + ): # 如果实例 id 不是当前主机的,删除 os.remove(INSTANCE_ID_FILE) if not os.path.exists(INSTANCE_ID_FILE): - new_instance_id = 'instance_'+str(uuid.uuid4()) + new_instance_id = 'instance_' + str(uuid.uuid4()) new_instance_create_ts = int(time.time()) with open(INSTANCE_ID_FILE, 'w') as f: - json.dump({ - 'host_id': identifier['host_id'], - 'instance_id': new_instance_id, - 'instance_create_ts': new_instance_create_ts - }, f) + json.dump( + { + 'host_id': identifier['host_id'], + 'instance_id': new_instance_id, + 'instance_create_ts': new_instance_create_ts, + }, + f, + ) identifier['instance_id'] = new_instance_id identifier['instance_create_ts'] = new_instance_create_ts @@ -80,6 +83,7 @@ def init(): identifier['instance_id'] = loaded_instance_id identifier['instance_create_ts'] = loaded_instance_create_ts + def print_out(): global identifier print(identifier) diff --git a/pkg/command/cmdmgr.py b/pkg/command/cmdmgr.py index 3275b6fc..10d76067 100644 --- a/pkg/command/cmdmgr.py +++ b/pkg/command/cmdmgr.py @@ -3,17 +3,17 @@ from __future__ import annotations import typing from ..core import app, entities as core_entities -from ..provider import entities as llm_entities from . import entities, operator, errors -from ..config import manager as cfg_mgr +from ..utils import importutil # 引入所有算子以便注册 -from .operators import func, plugin, reset, list as list_cmd, last, next, delc, resend, prompt, cmd, help, version, update, ollama, model +from . import operators + +importutil.import_modules_in_pkg(operators) class CommandManager: - """命令管理器 - """ + """命令管理器""" ap: app.Application @@ -26,14 +26,13 @@ class CommandManager: self.ap = ap async def initialize(self): - # 设置各个类的路径 def set_path(cls: operator.CommandOperator, ancestors: list[str]): cls.path = '.'.join(ancestors + [cls.name]) for op in operator.preregistered_operators: if op.parent_class == cls: set_path(op, ancestors + [cls.name]) - + for cls in operator.preregistered_operators: if cls.parent_class is None: set_path(cls, []) @@ -41,14 +40,18 @@ class CommandManager: # 应用命令权限配置 for cls in operator.preregistered_operators: if cls.path in self.ap.instance_config.data['command']['privilege']: - cls.lowest_privilege = self.ap.instance_config.data['command']['privilege'][cls.path] + cls.lowest_privilege = self.ap.instance_config.data['command'][ + 'privilege' + ][cls.path] # 实例化所有类 self.cmd_list = [cls(self.ap) for cls in operator.preregistered_operators] # 设置所有类的子节点 for cmd in self.cmd_list: - cmd.children = [child for child in self.cmd_list if child.parent_class == cmd.__class__] + cmd.children = [ + child for child in self.cmd_list if child.parent_class == cmd.__class__ + ] # 初始化所有类 for cmd in self.cmd_list: @@ -58,27 +61,25 @@ class CommandManager: self, context: entities.ExecuteContext, operator_list: list[operator.CommandOperator], - operator: operator.CommandOperator = None + operator: operator.CommandOperator = None, ) -> typing.AsyncGenerator[entities.CommandReturn, None]: - """执行命令 - """ + """执行命令""" found = False if len(context.crt_params) > 0: # 查找下一个参数是否对应此节点的某个子节点名 for oper in operator_list: - if (context.crt_params[0] == oper.name \ - or context.crt_params[0] in oper.alias) \ - and (oper.parent_class is None or oper.parent_class == operator.__class__): + if ( + context.crt_params[0] == oper.name + or context.crt_params[0] in oper.alias + ) and ( + oper.parent_class is None or oper.parent_class == operator.__class__ + ): found = True context.crt_command = context.crt_params[0] context.crt_params = context.crt_params[1:] - async for ret in self._execute( - context, - oper.children, - oper - ): + async for ret in self._execute(context, oper.children, oper): yield ret break @@ -96,19 +97,20 @@ class CommandManager: async for ret in operator.execute(context): yield ret - async def execute( self, command_text: str, query: core_entities.Query, - session: core_entities.Session + session: core_entities.Session, ) -> typing.AsyncGenerator[entities.CommandReturn, None]: - """执行命令 - """ + """执行命令""" privilege = 1 - if f'{query.launcher_type.value}_{query.launcher_id}' in self.ap.instance_config.data['admins']: + if ( + f'{query.launcher_type.value}_{query.launcher_id}' + in self.ap.instance_config.data['admins'] + ): privilege = 2 ctx = entities.ExecuteContext( @@ -119,11 +121,8 @@ class CommandManager: crt_command='', params=command_text.split(' '), crt_params=command_text.split(' '), - privilege=privilege + privilege=privilege, ) - async for ret in self._execute( - ctx, - self.cmd_list - ): + async for ret in self._execute(ctx, self.cmd_list): yield ret diff --git a/pkg/command/entities.py b/pkg/command/entities.py index 538766bf..cccd588e 100644 --- a/pkg/command/entities.py +++ b/pkg/command/entities.py @@ -4,14 +4,13 @@ import typing import pydantic.v1 as pydantic -from ..core import app, entities as core_entities -from . import errors, operator +from ..core import entities as core_entities +from . import errors from ..platform.types import message as platform_message class CommandReturn(pydantic.BaseModel): - """命令返回值 - """ + """命令返回值""" text: typing.Optional[str] = None """文本 @@ -24,7 +23,7 @@ class CommandReturn(pydantic.BaseModel): """图片链接 """ - error: typing.Optional[errors.CommandError]= None + error: typing.Optional[errors.CommandError] = None """错误 """ @@ -33,8 +32,7 @@ class CommandReturn(pydantic.BaseModel): class ExecuteContext(pydantic.BaseModel): - """单次命令执行上下文 - """ + """单次命令执行上下文""" query: core_entities.Query """本次消息的请求对象""" diff --git a/pkg/command/errors.py b/pkg/command/errors.py index 5bc253f6..df05b3d1 100644 --- a/pkg/command/errors.py +++ b/pkg/command/errors.py @@ -1,33 +1,26 @@ - - class CommandError(Exception): - def __init__(self, message: str = None): self.message = message - + def __str__(self): return self.message class CommandNotFoundError(CommandError): - def __init__(self, message: str = None): - super().__init__("未知命令: "+message) + super().__init__('未知命令: ' + message) class CommandPrivilegeError(CommandError): - def __init__(self, message: str = None): - super().__init__("权限不足: "+message) + super().__init__('权限不足: ' + message) class ParamNotEnoughError(CommandError): - def __init__(self, message: str = None): - super().__init__("参数不足: "+message) + super().__init__('参数不足: ' + message) class CommandOperationError(CommandError): - def __init__(self, message: str = None): - super().__init__("操作失败: "+message) + super().__init__('操作失败: ' + message) diff --git a/pkg/command/operator.py b/pkg/command/operator.py index 5e3b1a8f..7072edf7 100644 --- a/pkg/command/operator.py +++ b/pkg/command/operator.py @@ -3,7 +3,7 @@ from __future__ import annotations import typing import abc -from ..core import app, entities as core_entities +from ..core import app from . import entities @@ -13,14 +13,14 @@ preregistered_operators: list[typing.Type[CommandOperator]] = [] def operator_class( name: str, - help: str = "", + help: str = '', usage: str = None, alias: list[str] = [], - privilege: int=1, # 1为普通用户,2为管理员 - parent_class: typing.Type[CommandOperator] = None + privilege: int = 1, # 1为普通用户,2为管理员 + parent_class: typing.Type[CommandOperator] = None, ) -> typing.Callable[[typing.Type[CommandOperator]], typing.Type[CommandOperator]]: """命令类装饰器 - + Args: name (str): 名称 help (str, optional): 帮助信息. Defaults to "". @@ -35,7 +35,7 @@ def operator_class( def decorator(cls: typing.Type[CommandOperator]) -> typing.Type[CommandOperator]: assert issubclass(cls, CommandOperator) - + cls.name = name cls.alias = alias cls.help = help @@ -96,14 +96,13 @@ class CommandOperator(metaclass=abc.ABCMeta): @abc.abstractmethod async def execute( - self, - context: entities.ExecuteContext + self, context: entities.ExecuteContext ) -> typing.AsyncGenerator[entities.CommandReturn, None]: """实现此方法以执行命令 支持多次yield以返回多个结果。 例如:一个安装插件的命令,可能会有下载、解压、安装等多个步骤,每个步骤都可以返回一个结果。 - + Args: context (entities.ExecuteContext): 命令执行上下文 diff --git a/pkg/command/operators/cmd.py b/pkg/command/operators/cmd.py index 17b5ed08..a13d5b35 100644 --- a/pkg/command/operators/cmd.py +++ b/pkg/command/operators/cmd.py @@ -2,49 +2,46 @@ from __future__ import annotations import typing -from .. import operator, entities, cmdmgr, errors +from .. import operator, entities, errors -@operator.operator_class( - name="cmd", - help='显示命令列表', - usage='!cmd\n!cmd <命令名称>' -) +@operator.operator_class(name='cmd', help='显示命令列表', usage='!cmd\n!cmd <命令名称>') class CmdOperator(operator.CommandOperator): - """命令列表 - """ + """命令列表""" async def execute( - self, - context: entities.ExecuteContext + self, context: entities.ExecuteContext ) -> typing.AsyncGenerator[entities.CommandReturn, None]: - """执行 - """ + """执行""" if len(context.crt_params) == 0: - reply_str = "当前所有命令: \n\n" + reply_str = '当前所有命令: \n\n' for cmd in self.ap.cmd_mgr.cmd_list: if cmd.parent_class is None: - reply_str += f"{cmd.name}: {cmd.help}\n" - - reply_str += "\n使用 !cmd <命令名称> 查看命令的详细帮助" + reply_str += f'{cmd.name}: {cmd.help}\n' + + reply_str += '\n使用 !cmd <命令名称> 查看命令的详细帮助' yield entities.CommandReturn(text=reply_str.strip()) - + else: cmd_name = context.crt_params[0] cmd = None for _cmd in self.ap.cmd_mgr.cmd_list: - if (cmd_name == _cmd.name or cmd_name in _cmd.alias) and (_cmd.parent_class is None): + if (cmd_name == _cmd.name or cmd_name in _cmd.alias) and ( + _cmd.parent_class is None + ): cmd = _cmd break if cmd is None: - yield entities.CommandReturn(error=errors.CommandNotFoundError(cmd_name)) + yield entities.CommandReturn( + error=errors.CommandNotFoundError(cmd_name) + ) else: - reply_str = f"{cmd.name}: {cmd.help}\n\n" - reply_str += f"使用方法: \n{cmd.usage}" + reply_str = f'{cmd.name}: {cmd.help}\n\n' + reply_str += f'使用方法: \n{cmd.usage}' yield entities.CommandReturn(text=reply_str.strip()) diff --git a/pkg/command/operators/delc.py b/pkg/command/operators/delc.py index db865ff7..9ae507f5 100644 --- a/pkg/command/operators/delc.py +++ b/pkg/command/operators/delc.py @@ -1,62 +1,60 @@ from __future__ import annotations import typing -import datetime -from .. import operator, entities, cmdmgr, errors +from .. import operator, entities, errors @operator.operator_class( - name="del", - help="删除当前会话的历史记录", - usage='!del <序号>\n!del all' + name='del', help='删除当前会话的历史记录', usage='!del <序号>\n!del all' ) class DelOperator(operator.CommandOperator): - async def execute( - self, - context: entities.ExecuteContext + self, context: entities.ExecuteContext ) -> typing.AsyncGenerator[entities.CommandReturn, None]: - if context.session.conversations: delete_index = 0 if len(context.crt_params) > 0: try: delete_index = int(context.crt_params[0]) - except: - yield entities.CommandReturn(error=errors.CommandOperationError('索引必须是整数')) + except Exception: + yield entities.CommandReturn( + error=errors.CommandOperationError('索引必须是整数') + ) return - - if delete_index < 0 or delete_index >= len(context.session.conversations): - yield entities.CommandReturn(error=errors.CommandOperationError('索引超出范围')) - return - - # 倒序 - to_delete_index = len(context.session.conversations)-1-delete_index - if context.session.conversations[to_delete_index] == context.session.using_conversation: + if delete_index < 0 or delete_index >= len(context.session.conversations): + yield entities.CommandReturn( + error=errors.CommandOperationError('索引超出范围') + ) + return + + # 倒序 + to_delete_index = len(context.session.conversations) - 1 - delete_index + + if ( + context.session.conversations[to_delete_index] + == context.session.using_conversation + ): context.session.using_conversation = None del context.session.conversations[to_delete_index] - yield entities.CommandReturn(text=f"已删除对话: {delete_index}") + yield entities.CommandReturn(text=f'已删除对话: {delete_index}') else: - yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话')) + yield entities.CommandReturn( + error=errors.CommandOperationError('当前没有对话') + ) @operator.operator_class( - name="all", - help="删除此会话的所有历史记录", - parent_class=DelOperator + name='all', help='删除此会话的所有历史记录', parent_class=DelOperator ) class DelAllOperator(operator.CommandOperator): - async def execute( - self, - context: entities.ExecuteContext + self, context: entities.ExecuteContext ) -> typing.AsyncGenerator[entities.CommandReturn, None]: - context.session.conversations = [] context.session.using_conversation = None - yield entities.CommandReturn(text="已删除所有对话") \ No newline at end of file + yield entities.CommandReturn(text='已删除所有对话') diff --git a/pkg/command/operators/func.py b/pkg/command/operators/func.py index ae2ba4c1..9cb3fd32 100644 --- a/pkg/command/operators/func.py +++ b/pkg/command/operators/func.py @@ -1,16 +1,15 @@ from __future__ import annotations from typing import AsyncGenerator -from .. import operator, entities, cmdmgr -from ...plugin import context as plugin_context +from .. import operator, entities -@operator.operator_class(name="func", help="查看所有已注册的内容函数", usage='!func') +@operator.operator_class(name='func', help='查看所有已注册的内容函数', usage='!func') class FuncOperator(operator.CommandOperator): async def execute( self, context: entities.ExecuteContext ) -> AsyncGenerator[entities.CommandReturn, None]: - reply_str = "当前已启用的内容函数: \n\n" + reply_str = '当前已启用的内容函数: \n\n' index = 1 @@ -19,7 +18,7 @@ class FuncOperator(operator.CommandOperator): ) for func in all_functions: - reply_str += "{}. {}:\n{}\n\n".format( + reply_str += '{}. {}:\n{}\n\n'.format( index, func.name, func.description, diff --git a/pkg/command/operators/help.py b/pkg/command/operators/help.py index d8b42137..c718d4b9 100644 --- a/pkg/command/operators/help.py +++ b/pkg/command/operators/help.py @@ -2,19 +2,13 @@ from __future__ import annotations import typing -from .. import operator, entities, cmdmgr, errors +from .. import operator, entities -@operator.operator_class( - name='help', - help='显示帮助', - usage='!help\n!help <命令名称>' -) +@operator.operator_class(name='help', help='显示帮助', usage='!help\n!help <命令名称>') class HelpOperator(operator.CommandOperator): - async def execute( - self, - context: entities.ExecuteContext + self, context: entities.ExecuteContext ) -> typing.AsyncGenerator[entities.CommandReturn, None]: help = 'LangBot - 大语言模型原生即时通信机器人平台\n链接:https://langbot.app' diff --git a/pkg/command/operators/last.py b/pkg/command/operators/last.py index e7a14c83..7e2f2453 100644 --- a/pkg/command/operators/last.py +++ b/pkg/command/operators/last.py @@ -1,36 +1,43 @@ from __future__ import annotations import typing -import datetime -from .. import operator, entities, cmdmgr, errors +from .. import operator, entities, errors -@operator.operator_class( - name="last", - help="切换到前一个对话", - usage='!last' -) +@operator.operator_class(name='last', help='切换到前一个对话', usage='!last') class LastOperator(operator.CommandOperator): - async def execute( - self, - context: entities.ExecuteContext + self, context: entities.ExecuteContext ) -> typing.AsyncGenerator[entities.CommandReturn, None]: - if context.session.conversations: # 找到当前会话的上一个会话 - for index in range(len(context.session.conversations)-1, -1, -1): - if context.session.conversations[index] == context.session.using_conversation: + for index in range(len(context.session.conversations) - 1, -1, -1): + if ( + context.session.conversations[index] + == context.session.using_conversation + ): if index == 0: - yield entities.CommandReturn(error=errors.CommandOperationError('已经是第一个对话了')) + yield entities.CommandReturn( + error=errors.CommandOperationError('已经是第一个对话了') + ) return else: - context.session.using_conversation = context.session.conversations[index-1] - time_str = context.session.using_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S") + context.session.using_conversation = ( + context.session.conversations[index - 1] + ) + time_str = ( + context.session.using_conversation.create_time.strftime( + '%Y-%m-%d %H:%M:%S' + ) + ) - yield entities.CommandReturn(text=f"已切换到上一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].readable_str()}") + yield entities.CommandReturn( + text=f'已切换到上一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].readable_str()}' + ) return else: - yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话')) \ No newline at end of file + yield entities.CommandReturn( + error=errors.CommandOperationError('当前没有对话') + ) diff --git a/pkg/command/operators/list.py b/pkg/command/operators/list.py index ff90d4dd..1aa63c94 100644 --- a/pkg/command/operators/list.py +++ b/pkg/command/operators/list.py @@ -1,30 +1,26 @@ from __future__ import annotations import typing -import datetime -from .. import operator, entities, cmdmgr, errors +from .. import operator, entities, errors @operator.operator_class( - name="list", - help="列出此会话中的所有历史对话", - usage='!list\n!list <页码>' + name='list', help='列出此会话中的所有历史对话', usage='!list\n!list <页码>' ) class ListOperator(operator.CommandOperator): - async def execute( - self, - context: entities.ExecuteContext + self, context: entities.ExecuteContext ) -> typing.AsyncGenerator[entities.CommandReturn, None]: - page = 0 if len(context.crt_params) > 0: try: - page = int(context.crt_params[0]-1) - except: - yield entities.CommandReturn(error=errors.CommandOperationError('页码应为整数')) + page = int(context.crt_params[0] - 1) + except Exception: + yield entities.CommandReturn( + error=errors.CommandOperationError('页码应为整数') + ) return record_per_page = 10 @@ -36,21 +32,21 @@ class ListOperator(operator.CommandOperator): using_conv_index = 0 for conv in context.session.conversations[::-1]: - time_str = conv.create_time.strftime("%Y-%m-%d %H:%M:%S") + time_str = conv.create_time.strftime('%Y-%m-%d %H:%M:%S') if conv == context.session.using_conversation: using_conv_index = index if index >= page * record_per_page and index < (page + 1) * record_per_page: - content += f"{index} {time_str}: {conv.messages[0].readable_str() if len(conv.messages) > 0 else '无内容'}\n" + content += f'{index} {time_str}: {conv.messages[0].readable_str() if len(conv.messages) > 0 else "无内容"}\n' index += 1 if content == '': content = '无' else: if context.session.using_conversation is None: - content += "\n当前处于新会话" + content += '\n当前处于新会话' else: - content += f"\n当前会话: {using_conv_index} {context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S')}: {context.session.using_conversation.messages[0].readable_str() if len(context.session.using_conversation.messages) > 0 else '无内容'}" - - yield entities.CommandReturn(text=f"第 {page + 1} 页 (时间倒序):\n{content}") + content += f'\n当前会话: {using_conv_index} {context.session.using_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S")}: {context.session.using_conversation.messages[0].readable_str() if len(context.session.using_conversation.messages) > 0 else "无内容"}' + + yield entities.CommandReturn(text=f'第 {page + 1} 页 (时间倒序):\n{content}') diff --git a/pkg/command/operators/model.py b/pkg/command/operators/model.py index f46c9590..cc3ef5b9 100644 --- a/pkg/command/operators/model.py +++ b/pkg/command/operators/model.py @@ -2,42 +2,44 @@ from __future__ import annotations import typing -from .. import operator, entities, cmdmgr, errors +from .. import operator, entities, errors + @operator.operator_class( - name="model", + name='model', help='显示和切换模型列表', usage='!model\n!model show <模型名>\n!model set <模型名>', - privilege=2 + privilege=2, ) class ModelOperator(operator.CommandOperator): """Model命令""" - - async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: + + async def execute( + self, context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: content = '模型列表:\n' model_list = self.ap.model_mgr.model_list for model in model_list: - content += f"\n名称: {model.name}\n" - content += f"请求器: {model.requester.name}\n" + content += f'\n名称: {model.name}\n' + content += f'请求器: {model.requester.name}\n' - content += f"\n当前对话使用模型: {context.query.use_model.name}\n" - content += f"新对话默认使用模型: {self.ap.provider_cfg.data.get('model')}\n" + content += f'\n当前对话使用模型: {context.query.use_model.name}\n' + content += f'新对话默认使用模型: {self.ap.provider_cfg.data.get("model")}\n' yield entities.CommandReturn(text=content.strip()) @operator.operator_class( - name="show", - help='显示模型详情', - privilege=2, - parent_class=ModelOperator + name='show', help='显示模型详情', privilege=2, parent_class=ModelOperator ) class ModelShowOperator(operator.CommandOperator): """Model Show命令""" - - async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: + + async def execute( + self, context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: model_name = context.crt_params[0] model = None @@ -47,29 +49,31 @@ class ModelShowOperator(operator.CommandOperator): break if model is None: - yield entities.CommandReturn(error=errors.CommandError(f"未找到模型 {model_name}")) + yield entities.CommandReturn( + error=errors.CommandError(f'未找到模型 {model_name}') + ) else: - content = f"模型详情\n" - content += f"名称: {model.name}\n" + content = '模型详情\n' + content += f'名称: {model.name}\n' if model.model_name is not None: - content += f"请求模型名称: {model.model_name}\n" - content += f"请求器: {model.requester.name}\n" - content += f"密钥组: {model.token_mgr.name}\n" - content += f"支持视觉: {model.vision_supported}\n" - content += f"支持工具: {model.tool_call_supported}\n" + content += f'请求模型名称: {model.model_name}\n' + content += f'请求器: {model.requester.name}\n' + content += f'密钥组: {model.token_mgr.name}\n' + content += f'支持视觉: {model.vision_supported}\n' + content += f'支持工具: {model.tool_call_supported}\n' yield entities.CommandReturn(text=content.strip()) + @operator.operator_class( - name="set", - help='设置默认使用模型', - privilege=2, - parent_class=ModelOperator + name='set', help='设置默认使用模型', privilege=2, parent_class=ModelOperator ) class ModelSetOperator(operator.CommandOperator): """Model Set命令""" - - async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: + + async def execute( + self, context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: model_name = context.crt_params[0] model = None @@ -79,8 +83,12 @@ class ModelSetOperator(operator.CommandOperator): break if model is None: - yield entities.CommandReturn(error=errors.CommandError(f"未找到模型 {model_name}")) + yield entities.CommandReturn( + error=errors.CommandError(f'未找到模型 {model_name}') + ) else: self.ap.provider_cfg.data['model'] = model_name await self.ap.provider_cfg.dump_config() - yield entities.CommandReturn(text=f"已设置当前使用模型为 {model_name},重置会话以生效") + yield entities.CommandReturn( + text=f'已设置当前使用模型为 {model_name},重置会话以生效' + ) diff --git a/pkg/command/operators/next.py b/pkg/command/operators/next.py index 8f4b5a5a..ef5ae103 100644 --- a/pkg/command/operators/next.py +++ b/pkg/command/operators/next.py @@ -1,35 +1,42 @@ from __future__ import annotations import typing -import datetime -from .. import operator, entities, cmdmgr, errors +from .. import operator, entities, errors -@operator.operator_class( - name="next", - help="切换到后一个对话", - usage='!next' -) +@operator.operator_class(name='next', help='切换到后一个对话', usage='!next') class NextOperator(operator.CommandOperator): - async def execute( - self, - context: entities.ExecuteContext + self, context: entities.ExecuteContext ) -> typing.AsyncGenerator[entities.CommandReturn, None]: - if context.session.conversations: # 找到当前会话的下一个会话 for index in range(len(context.session.conversations)): - if context.session.conversations[index] == context.session.using_conversation: - if index == len(context.session.conversations)-1: - yield entities.CommandReturn(error=errors.CommandOperationError('已经是最后一个对话了')) + if ( + context.session.conversations[index] + == context.session.using_conversation + ): + if index == len(context.session.conversations) - 1: + yield entities.CommandReturn( + error=errors.CommandOperationError('已经是最后一个对话了') + ) return else: - context.session.using_conversation = context.session.conversations[index+1] - time_str = context.session.using_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S") + context.session.using_conversation = ( + context.session.conversations[index + 1] + ) + time_str = ( + context.session.using_conversation.create_time.strftime( + '%Y-%m-%d %H:%M:%S' + ) + ) - yield entities.CommandReturn(text=f"已切换到后一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].content}") + yield entities.CommandReturn( + text=f'已切换到后一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].content}' + ) return else: - yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话')) \ No newline at end of file + yield entities.CommandReturn( + error=errors.CommandOperationError('当前没有对话') + ) diff --git a/pkg/command/operators/ollama.py b/pkg/command/operators/ollama.py index f5ed382d..7e65d440 100644 --- a/pkg/command/operators/ollama.py +++ b/pkg/command/operators/ollama.py @@ -2,31 +2,32 @@ from __future__ import annotations import json import typing -import traceback import ollama from .. import operator, entities, errors @operator.operator_class( - name="ollama", - help="ollama平台操作", - usage="!ollama\n!ollama show <模型名>\n!ollama pull <模型名>\n!ollama del <模型名>" + name='ollama', + help='ollama平台操作', + usage='!ollama\n!ollama show <模型名>\n!ollama pull <模型名>\n!ollama del <模型名>', ) class OllamaOperator(operator.CommandOperator): async def execute( - self, context: entities.ExecuteContext + self, context: entities.ExecuteContext ) -> typing.AsyncGenerator[entities.CommandReturn, None]: try: content: str = '模型列表:\n' model_list: list = ollama.list().get('models', []) for model in model_list: - content += f"名称: {model['name']}\n" - content += f"修改时间: {model['modified_at']}\n" - content += f"大小: {bytes_to_mb(model['size'])}MB\n\n" - yield entities.CommandReturn(text=f"{content.strip()}") - except ollama.ResponseError as e: - yield entities.CommandReturn(error=errors.CommandError(f"无法获取模型列表,请确认 Ollama 服务正常")) + content += f'名称: {model["name"]}\n' + content += f'修改时间: {model["modified_at"]}\n' + content += f'大小: {bytes_to_mb(model["size"])}MB\n\n' + yield entities.CommandReturn(text=f'{content.strip()}') + except ollama.ResponseError: + yield entities.CommandReturn( + error=errors.CommandError('无法获取模型列表,请确认 Ollama 服务正常') + ) def bytes_to_mb(num_bytes): @@ -35,14 +36,11 @@ def bytes_to_mb(num_bytes): @operator.operator_class( - name="show", - help="ollama模型详情", - privilege=2, - parent_class=OllamaOperator + name='show', help='ollama模型详情', privilege=2, parent_class=OllamaOperator ) class OllamaShowOperator(operator.CommandOperator): async def execute( - self, context: entities.ExecuteContext + self, context: entities.ExecuteContext ) -> typing.AsyncGenerator[entities.CommandReturn, None]: content: str = '模型详情:\n' try: @@ -53,31 +51,36 @@ class OllamaShowOperator(operator.CommandOperator): for key in ['license', 'modelfile']: show[key] = ignore_show - for key in ['tokenizer.chat_template.rag', 'tokenizer.chat_template.tool_use']: + for key in [ + 'tokenizer.chat_template.rag', + 'tokenizer.chat_template.tool_use', + ]: model_info[key] = ignore_show content += json.dumps(show, indent=4) yield entities.CommandReturn(text=content.strip()) - except ollama.ResponseError as e: - yield entities.CommandReturn(error=errors.CommandError(f"无法获取模型详情,请确认 Ollama 服务正常")) + except ollama.ResponseError: + yield entities.CommandReturn( + error=errors.CommandError('无法获取模型详情,请确认 Ollama 服务正常') + ) + @operator.operator_class( - name="pull", - help="ollama模型拉取", - privilege=2, - parent_class=OllamaOperator + name='pull', help='ollama模型拉取', privilege=2, parent_class=OllamaOperator ) class OllamaPullOperator(operator.CommandOperator): async def execute( - self, context: entities.ExecuteContext + self, context: entities.ExecuteContext ) -> typing.AsyncGenerator[entities.CommandReturn, None]: try: model_list: list = ollama.list().get('models', []) if context.crt_params[0] in [model['name'] for model in model_list]: - yield entities.CommandReturn(text="模型已存在") + yield entities.CommandReturn(text='模型已存在') return - except ollama.ResponseError as e: - yield entities.CommandReturn(error=errors.CommandError(f"无法获取模型列表,请确认 Ollama 服务正常")) + except ollama.ResponseError: + yield entities.CommandReturn( + error=errors.CommandError('无法获取模型列表,请确认 Ollama 服务正常') + ) return on_progress: bool = False @@ -99,23 +102,21 @@ class OllamaPullOperator(operator.CommandOperator): if percentage_completed > progress_count: progress_count += 10 yield entities.CommandReturn( - text=f"下载进度: {completed}/{total} ({percentage_completed:.2f}%)") + text=f'下载进度: {completed}/{total} ({percentage_completed:.2f}%)' + ) except ollama.ResponseError as e: - yield entities.CommandReturn(text=f"拉取失败: {e.error}") + yield entities.CommandReturn(text=f'拉取失败: {e.error}') @operator.operator_class( - name="del", - help="ollama模型删除", - privilege=2, - parent_class=OllamaOperator + name='del', help='ollama模型删除', privilege=2, parent_class=OllamaOperator ) class OllamaDelOperator(operator.CommandOperator): async def execute( - self, context: entities.ExecuteContext + self, context: entities.ExecuteContext ) -> typing.AsyncGenerator[entities.CommandReturn, None]: try: ret: str = ollama.delete(model=context.crt_params[0])['status'] except ollama.ResponseError as e: - ret = f"{e.error}" + ret = f'{e.error}' yield entities.CommandReturn(text=ret) diff --git a/pkg/command/operators/plugin.py b/pkg/command/operators/plugin.py index c36fbbc3..1bf4c7af 100644 --- a/pkg/command/operators/plugin.py +++ b/pkg/command/operators/plugin.py @@ -2,31 +2,30 @@ from __future__ import annotations import typing import traceback -from .. import operator, entities, cmdmgr, errors -from ...core import app +from .. import operator, entities, errors @operator.operator_class( - name="plugin", - help="插件操作", - usage="!plugin\n!plugin get <插件仓库地址>\n!plugin update\n!plugin del <插件名>\n!plugin on <插件名>\n!plugin off <插件名>" + name='plugin', + help='插件操作', + usage='!plugin\n!plugin get <插件仓库地址>\n!plugin update\n!plugin del <插件名>\n!plugin on <插件名>\n!plugin off <插件名>', ) class PluginOperator(operator.CommandOperator): - async def execute( - self, - context: entities.ExecuteContext + self, context: entities.ExecuteContext ) -> typing.AsyncGenerator[entities.CommandReturn, None]: - plugin_list = self.ap.plugin_mgr.plugins() - reply_str = "所有插件({}):\n".format(len(plugin_list)) + reply_str = '所有插件({}):\n'.format(len(plugin_list)) idx = 0 for plugin in plugin_list: - reply_str += "\n#{} {} {}\n{}\nv{}\n作者: {}\n"\ - .format((idx+1), plugin.plugin_name, - "[已禁用]" if not plugin.enabled else "", - plugin.plugin_description, - plugin.plugin_version, plugin.plugin_author) + reply_str += '\n#{} {} {}\n{}\nv{}\n作者: {}\n'.format( + (idx + 1), + plugin.plugin_name, + '[已禁用]' if not plugin.enabled else '', + plugin.plugin_description, + plugin.plugin_version, + plugin.plugin_author, + ) idx += 1 @@ -34,48 +33,42 @@ class PluginOperator(operator.CommandOperator): @operator.operator_class( - name="get", - help="安装插件", - privilege=2, - parent_class=PluginOperator + name='get', help='安装插件', privilege=2, parent_class=PluginOperator ) class PluginGetOperator(operator.CommandOperator): - async def execute( - self, - context: entities.ExecuteContext + self, context: entities.ExecuteContext ) -> typing.AsyncGenerator[entities.CommandReturn, None]: - if len(context.crt_params) == 0: - yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件仓库地址')) + yield entities.CommandReturn( + error=errors.ParamNotEnoughError('请提供插件仓库地址') + ) else: repo = context.crt_params[0] - yield entities.CommandReturn(text="正在安装插件...") + yield entities.CommandReturn(text='正在安装插件...') try: await self.ap.plugin_mgr.install_plugin(repo) - yield entities.CommandReturn(text="插件安装成功,请重启程序以加载插件") + yield entities.CommandReturn(text='插件安装成功,请重启程序以加载插件') except Exception as e: traceback.print_exc() - yield entities.CommandReturn(error=errors.CommandError("插件安装失败: "+str(e))) + yield entities.CommandReturn( + error=errors.CommandError('插件安装失败: ' + str(e)) + ) @operator.operator_class( - name="update", - help="更新插件", - privilege=2, - parent_class=PluginOperator + name='update', help='更新插件', privilege=2, parent_class=PluginOperator ) class PluginUpdateOperator(operator.CommandOperator): - async def execute( - self, - context: entities.ExecuteContext + self, context: entities.ExecuteContext ) -> typing.AsyncGenerator[entities.CommandReturn, None]: - if len(context.crt_params) == 0: - yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称')) + yield entities.CommandReturn( + error=errors.ParamNotEnoughError('请提供插件名称') + ) else: plugin_name = context.crt_params[0] @@ -83,36 +76,34 @@ class PluginUpdateOperator(operator.CommandOperator): plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name) if plugin_container is not None: - yield entities.CommandReturn(text="正在更新插件...") + yield entities.CommandReturn(text='正在更新插件...') await self.ap.plugin_mgr.update_plugin(plugin_name) - yield entities.CommandReturn(text="插件更新成功,请重启程序以加载插件") + yield entities.CommandReturn( + text='插件更新成功,请重启程序以加载插件' + ) else: - yield entities.CommandReturn(error=errors.CommandError("插件更新失败: 未找到插件")) + yield entities.CommandReturn( + error=errors.CommandError('插件更新失败: 未找到插件') + ) except Exception as e: traceback.print_exc() - yield entities.CommandReturn(error=errors.CommandError("插件更新失败: "+str(e))) + yield entities.CommandReturn( + error=errors.CommandError('插件更新失败: ' + str(e)) + ) + @operator.operator_class( - name="all", - help="更新所有插件", - privilege=2, - parent_class=PluginUpdateOperator + name='all', help='更新所有插件', privilege=2, parent_class=PluginUpdateOperator ) class PluginUpdateAllOperator(operator.CommandOperator): - async def execute( - self, - context: entities.ExecuteContext + self, context: entities.ExecuteContext ) -> typing.AsyncGenerator[entities.CommandReturn, None]: - try: - plugins = [ - p.plugin_name - for p in self.ap.plugin_mgr.plugins() - ] + plugins = [p.plugin_name for p in self.ap.plugin_mgr.plugins()] if plugins: - yield entities.CommandReturn(text="正在更新插件...") + yield entities.CommandReturn(text='正在更新插件...') updated = [] try: for plugin_name in plugins: @@ -120,30 +111,32 @@ class PluginUpdateAllOperator(operator.CommandOperator): updated.append(plugin_name) except Exception as e: traceback.print_exc() - yield entities.CommandReturn(error=errors.CommandError("插件更新失败: "+str(e))) - yield entities.CommandReturn(text="已更新插件: {}".format(", ".join(updated))) + yield entities.CommandReturn( + error=errors.CommandError('插件更新失败: ' + str(e)) + ) + yield entities.CommandReturn( + text='已更新插件: {}'.format(', '.join(updated)) + ) else: - yield entities.CommandReturn(text="没有可更新的插件") + yield entities.CommandReturn(text='没有可更新的插件') except Exception as e: traceback.print_exc() - yield entities.CommandReturn(error=errors.CommandError("插件更新失败: "+str(e))) + yield entities.CommandReturn( + error=errors.CommandError('插件更新失败: ' + str(e)) + ) @operator.operator_class( - name="del", - help="删除插件", - privilege=2, - parent_class=PluginOperator + name='del', help='删除插件', privilege=2, parent_class=PluginOperator ) class PluginDelOperator(operator.CommandOperator): - async def execute( - self, - context: entities.ExecuteContext + self, context: entities.ExecuteContext ) -> typing.AsyncGenerator[entities.CommandReturn, None]: - if len(context.crt_params) == 0: - yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称')) + yield entities.CommandReturn( + error=errors.ParamNotEnoughError('请提供插件名称') + ) else: plugin_name = context.crt_params[0] @@ -151,67 +144,81 @@ class PluginDelOperator(operator.CommandOperator): plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name) if plugin_container is not None: - yield entities.CommandReturn(text="正在删除插件...") + yield entities.CommandReturn(text='正在删除插件...') await self.ap.plugin_mgr.uninstall_plugin(plugin_name) - yield entities.CommandReturn(text="插件删除成功,请重启程序以加载插件") + yield entities.CommandReturn( + text='插件删除成功,请重启程序以加载插件' + ) else: - yield entities.CommandReturn(error=errors.CommandError("插件删除失败: 未找到插件")) + yield entities.CommandReturn( + error=errors.CommandError('插件删除失败: 未找到插件') + ) except Exception as e: traceback.print_exc() - yield entities.CommandReturn(error=errors.CommandError("插件删除失败: "+str(e))) + yield entities.CommandReturn( + error=errors.CommandError('插件删除失败: ' + str(e)) + ) @operator.operator_class( - name="on", - help="启用插件", - privilege=2, - parent_class=PluginOperator + name='on', help='启用插件', privilege=2, parent_class=PluginOperator ) class PluginEnableOperator(operator.CommandOperator): - async def execute( - self, - context: entities.ExecuteContext + self, context: entities.ExecuteContext ) -> typing.AsyncGenerator[entities.CommandReturn, None]: - if len(context.crt_params) == 0: - yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称')) + yield entities.CommandReturn( + error=errors.ParamNotEnoughError('请提供插件名称') + ) else: plugin_name = context.crt_params[0] try: if await self.ap.plugin_mgr.update_plugin_switch(plugin_name, True): - yield entities.CommandReturn(text="已启用插件: {}".format(plugin_name)) + yield entities.CommandReturn( + text='已启用插件: {}'.format(plugin_name) + ) else: - yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: 未找到插件 {}".format(plugin_name))) + yield entities.CommandReturn( + error=errors.CommandError( + '插件状态修改失败: 未找到插件 {}'.format(plugin_name) + ) + ) except Exception as e: traceback.print_exc() - yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: "+str(e))) + yield entities.CommandReturn( + error=errors.CommandError('插件状态修改失败: ' + str(e)) + ) @operator.operator_class( - name="off", - help="禁用插件", - privilege=2, - parent_class=PluginOperator + name='off', help='禁用插件', privilege=2, parent_class=PluginOperator ) class PluginDisableOperator(operator.CommandOperator): - async def execute( - self, - context: entities.ExecuteContext + self, context: entities.ExecuteContext ) -> typing.AsyncGenerator[entities.CommandReturn, None]: - if len(context.crt_params) == 0: - yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称')) + yield entities.CommandReturn( + error=errors.ParamNotEnoughError('请提供插件名称') + ) else: plugin_name = context.crt_params[0] try: if await self.ap.plugin_mgr.update_plugin_switch(plugin_name, False): - yield entities.CommandReturn(text="已禁用插件: {}".format(plugin_name)) + yield entities.CommandReturn( + text='已禁用插件: {}'.format(plugin_name) + ) else: - yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: 未找到插件 {}".format(plugin_name))) + yield entities.CommandReturn( + error=errors.CommandError( + '插件状态修改失败: 未找到插件 {}'.format(plugin_name) + ) + ) except Exception as e: traceback.print_exc() - yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: "+str(e))) + yield entities.CommandReturn( + error=errors.CommandError('插件状态修改失败: ' + str(e)) + ) diff --git a/pkg/command/operators/prompt.py b/pkg/command/operators/prompt.py index 29d688a6..41f42de4 100644 --- a/pkg/command/operators/prompt.py +++ b/pkg/command/operators/prompt.py @@ -2,28 +2,23 @@ from __future__ import annotations import typing -from .. import operator, entities, cmdmgr, errors +from .. import operator, entities, errors -@operator.operator_class( - name="prompt", - help="查看当前对话的前文", - usage='!prompt' -) +@operator.operator_class(name='prompt', help='查看当前对话的前文', usage='!prompt') class PromptOperator(operator.CommandOperator): - async def execute( - self, - context: entities.ExecuteContext + self, context: entities.ExecuteContext ) -> typing.AsyncGenerator[entities.CommandReturn, None]: - """执行 - """ + """执行""" if context.session.using_conversation is None: - yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话')) + yield entities.CommandReturn( + error=errors.CommandOperationError('当前没有对话') + ) else: reply_str = '当前对话所有内容:\n\n' for msg in context.session.using_conversation.messages: - reply_str += f"{msg.role}: {msg.content}\n" + reply_str += f'{msg.role}: {msg.content}\n' - yield entities.CommandReturn(text=reply_str) \ No newline at end of file + yield entities.CommandReturn(text=reply_str) diff --git a/pkg/command/operators/resend.py b/pkg/command/operators/resend.py index 6d930413..44e5a35c 100644 --- a/pkg/command/operators/resend.py +++ b/pkg/command/operators/resend.py @@ -2,26 +2,22 @@ from __future__ import annotations import typing -from .. import operator, entities, cmdmgr, errors +from .. import operator, entities, errors @operator.operator_class( - name="resend", - help="重发当前会话的最后一条消息", - usage='!resend' + name='resend', help='重发当前会话的最后一条消息', usage='!resend' ) class ResendOperator(operator.CommandOperator): - async def execute( - self, - context: entities.ExecuteContext + self, context: entities.ExecuteContext ) -> typing.AsyncGenerator[entities.CommandReturn, None]: # 回滚到最后一条用户message前 if context.session.using_conversation is None: - yield entities.CommandReturn(error=errors.CommandError("当前没有对话")) + yield entities.CommandReturn(error=errors.CommandError('当前没有对话')) else: conv_msg = context.session.using_conversation.messages - + # 倒序一直删到最后一条用户message while len(conv_msg) > 0 and conv_msg[-1].role != 'user': conv_msg.pop() @@ -31,4 +27,4 @@ class ResendOperator(operator.CommandOperator): conv_msg.pop() # 不重发了,提示用户已删除就行了 - yield entities.CommandReturn(text="已删除最后一次请求记录") + yield entities.CommandReturn(text='已删除最后一次请求记录') diff --git a/pkg/command/operators/reset.py b/pkg/command/operators/reset.py index 5d1402ac..7ef54e08 100644 --- a/pkg/command/operators/reset.py +++ b/pkg/command/operators/reset.py @@ -2,22 +2,15 @@ from __future__ import annotations import typing -from .. import operator, entities, cmdmgr, errors +from .. import operator, entities -@operator.operator_class( - name="reset", - help="重置当前会话", - usage='!reset' -) +@operator.operator_class(name='reset', help='重置当前会话', usage='!reset') class ResetOperator(operator.CommandOperator): - async def execute( - self, - context: entities.ExecuteContext + self, context: entities.ExecuteContext ) -> typing.AsyncGenerator[entities.CommandReturn, None]: - """执行 - """ + """执行""" context.session.using_conversation = None - yield entities.CommandReturn(text="已重置当前会话") + yield entities.CommandReturn(text='已重置当前会话') diff --git a/pkg/command/operators/update.py b/pkg/command/operators/update.py index 524a26dd..775ee26a 100644 --- a/pkg/command/operators/update.py +++ b/pkg/command/operators/update.py @@ -3,28 +3,22 @@ from __future__ import annotations import typing import traceback -from .. import operator, entities, cmdmgr, errors +from .. import operator, entities, errors -@operator.operator_class( - name="update", - help="更新程序", - usage='!update', - privilege=2 -) +@operator.operator_class(name='update', help='更新程序', usage='!update', privilege=2) class UpdateCommand(operator.CommandOperator): - async def execute( - self, - context: entities.ExecuteContext + self, context: entities.ExecuteContext ) -> typing.AsyncGenerator[entities.CommandReturn, None]: - try: - yield entities.CommandReturn(text="正在进行更新...") + yield entities.CommandReturn(text='正在进行更新...') if await self.ap.ver_mgr.update_all(): - yield entities.CommandReturn(text="更新完成,请重启程序以应用更新") + yield entities.CommandReturn(text='更新完成,请重启程序以应用更新') else: - yield entities.CommandReturn(text="当前已是最新版本") + yield entities.CommandReturn(text='当前已是最新版本') except Exception as e: traceback.print_exc() - yield entities.CommandReturn(error=errors.CommandError("更新失败: "+str(e))) \ No newline at end of file + yield entities.CommandReturn( + error=errors.CommandError('更新失败: ' + str(e)) + ) diff --git a/pkg/command/operators/version.py b/pkg/command/operators/version.py index a5d7a81b..267b1113 100644 --- a/pkg/command/operators/version.py +++ b/pkg/command/operators/version.py @@ -2,26 +2,20 @@ from __future__ import annotations import typing -from .. import operator, cmdmgr, entities, errors +from .. import operator, entities -@operator.operator_class( - name="version", - help="显示版本信息", - usage='!version' -) +@operator.operator_class(name='version', help='显示版本信息', usage='!version') class VersionCommand(operator.CommandOperator): - async def execute( - self, - context: entities.ExecuteContext + self, context: entities.ExecuteContext ) -> typing.AsyncGenerator[entities.CommandReturn, None]: - reply_str = f"当前版本: \n{self.ap.ver_mgr.get_current_version()}" + reply_str = f'当前版本: \n{self.ap.ver_mgr.get_current_version()}' try: if await self.ap.ver_mgr.is_new_version_available(): - reply_str += "\n\n有新版本可用。" - except: + reply_str += '\n\n有新版本可用。' + except Exception: pass - yield entities.CommandReturn(text=reply_str.strip()) \ No newline at end of file + yield entities.CommandReturn(text=reply_str.strip()) diff --git a/pkg/config/impls/json.py b/pkg/config/impls/json.py index e414e451..07fc533c 100644 --- a/pkg/config/impls/json.py +++ b/pkg/config/impls/json.py @@ -9,7 +9,10 @@ class JSONConfigFile(file_model.ConfigFile): """JSON配置文件""" def __init__( - self, config_file_name: str, template_file_name: str = None, template_data: dict = None + self, + config_file_name: str, + template_file_name: str = None, + template_data: dict = None, ) -> None: self.config_file_name = config_file_name self.template_file_name = template_file_name @@ -22,28 +25,26 @@ class JSONConfigFile(file_model.ConfigFile): if self.template_file_name is not None: shutil.copyfile(self.template_file_name, self.config_file_name) elif self.template_data is not None: - with open(self.config_file_name, "w", encoding="utf-8") as f: + with open(self.config_file_name, 'w', encoding='utf-8') as f: json.dump(self.template_data, f, indent=4, ensure_ascii=False) else: - raise ValueError("template_file_name or template_data must be provided") - - async def load(self, completion: bool=True) -> dict: + raise ValueError('template_file_name or template_data must be provided') + async def load(self, completion: bool = True) -> dict: if not self.exists(): await self.create() if self.template_file_name is not None: - with open(self.template_file_name, "r", encoding="utf-8") as f: + with open(self.template_file_name, 'r', encoding='utf-8') as f: self.template_data = json.load(f) - with open(self.config_file_name, "r", encoding="utf-8") as f: + with open(self.config_file_name, 'r', encoding='utf-8') as f: try: cfg = json.load(f) except json.JSONDecodeError as e: - raise Exception(f"配置文件 {self.config_file_name} 语法错误: {e}") + raise Exception(f'配置文件 {self.config_file_name} 语法错误: {e}') if completion: - for key in self.template_data: if key not in cfg: cfg[key] = self.template_data[key] @@ -51,9 +52,9 @@ class JSONConfigFile(file_model.ConfigFile): return cfg async def save(self, cfg: dict): - with open(self.config_file_name, "w", encoding="utf-8") as f: + with open(self.config_file_name, 'w', encoding='utf-8') as f: json.dump(cfg, f, indent=4, ensure_ascii=False) def save_sync(self, cfg: dict): - with open(self.config_file_name, "w", encoding="utf-8") as f: + with open(self.config_file_name, 'w', encoding='utf-8') as f: json.dump(cfg, f, indent=4, ensure_ascii=False) diff --git a/pkg/config/impls/pymodule.py b/pkg/config/impls/pymodule.py index 67e5867d..2311992e 100644 --- a/pkg/config/impls/pymodule.py +++ b/pkg/config/impls/pymodule.py @@ -25,10 +25,10 @@ class PythonModuleConfigFile(file_model.ConfigFile): async def create(self): shutil.copyfile(self.template_file_name, self.config_file_name) - async def load(self, completion: bool=True) -> dict: + async def load(self, completion: bool = True) -> dict: module_name = os.path.splitext(os.path.basename(self.config_file_name))[0] module = importlib.import_module(module_name) - + cfg = {} allowed_types = (int, float, str, bool, list, dict) @@ -63,4 +63,4 @@ class PythonModuleConfigFile(file_model.ConfigFile): logging.warning('Python模块配置文件不支持保存') def save_sync(self, data: dict): - logging.warning('Python模块配置文件不支持保存') \ No newline at end of file + logging.warning('Python模块配置文件不支持保存') diff --git a/pkg/config/impls/yaml.py b/pkg/config/impls/yaml.py index f4518003..55045186 100644 --- a/pkg/config/impls/yaml.py +++ b/pkg/config/impls/yaml.py @@ -9,7 +9,10 @@ class YAMLConfigFile(file_model.ConfigFile): """YAML配置文件""" def __init__( - self, config_file_name: str, template_file_name: str = None, template_data: dict = None + self, + config_file_name: str, + template_file_name: str = None, + template_data: dict = None, ) -> None: self.config_file_name = config_file_name self.template_file_name = template_file_name @@ -22,28 +25,26 @@ class YAMLConfigFile(file_model.ConfigFile): if self.template_file_name is not None: shutil.copyfile(self.template_file_name, self.config_file_name) elif self.template_data is not None: - with open(self.config_file_name, "w", encoding="utf-8") as f: + with open(self.config_file_name, 'w', encoding='utf-8') as f: yaml.dump(self.template_data, f, indent=4, allow_unicode=True) else: - raise ValueError("template_file_name or template_data must be provided") - - async def load(self, completion: bool=True) -> dict: + raise ValueError('template_file_name or template_data must be provided') + async def load(self, completion: bool = True) -> dict: if not self.exists(): await self.create() if self.template_file_name is not None: - with open(self.template_file_name, "r", encoding="utf-8") as f: + with open(self.template_file_name, 'r', encoding='utf-8') as f: self.template_data = yaml.load(f, Loader=yaml.FullLoader) - with open(self.config_file_name, "r", encoding="utf-8") as f: + with open(self.config_file_name, 'r', encoding='utf-8') as f: try: cfg = yaml.load(f, Loader=yaml.FullLoader) except yaml.YAMLError as e: - raise Exception(f"配置文件 {self.config_file_name} 语法错误: {e}") + raise Exception(f'配置文件 {self.config_file_name} 语法错误: {e}') if completion: - for key in self.template_data: if key not in cfg: cfg[key] = self.template_data[key] @@ -51,9 +52,9 @@ class YAMLConfigFile(file_model.ConfigFile): return cfg async def save(self, cfg: dict): - with open(self.config_file_name, "w", encoding="utf-8") as f: + with open(self.config_file_name, 'w', encoding='utf-8') as f: yaml.dump(cfg, f, indent=4, allow_unicode=True) def save_sync(self, cfg: dict): - with open(self.config_file_name, "w", encoding="utf-8") as f: - yaml.dump(cfg, f, indent=4, allow_unicode=True) \ No newline at end of file + with open(self.config_file_name, 'w', encoding='utf-8') as f: + yaml.dump(cfg, f, indent=4, allow_unicode=True) diff --git a/pkg/config/manager.py b/pkg/config/manager.py index 4421003c..2385c6b5 100644 --- a/pkg/config/manager.py +++ b/pkg/config/manager.py @@ -6,7 +6,7 @@ from .impls import pymodule, json as json_file, yaml as yaml_file class ConfigManager: """配置文件管理器""" - + name: str = None """配置管理器名""" @@ -31,7 +31,7 @@ class ConfigManager: self.file = cfg_file self.data = {} - async def load_config(self, completion: bool=True): + async def load_config(self, completion: bool = True): self.data = await self.file.load(completion=completion) async def dump_config(self): @@ -41,9 +41,11 @@ class ConfigManager: self.file.save_sync(self.data) -async def load_python_module_config(config_name: str, template_name: str, completion: bool=True) -> ConfigManager: +async def load_python_module_config( + config_name: str, template_name: str, completion: bool = True +) -> ConfigManager: """加载Python模块配置文件 - + Args: config_name (str): 配置文件名 template_name (str): 模板文件名 @@ -52,10 +54,7 @@ async def load_python_module_config(config_name: str, template_name: str, comple Returns: ConfigManager: 配置文件管理器 """ - cfg_inst = pymodule.PythonModuleConfigFile( - config_name, - template_name - ) + cfg_inst = pymodule.PythonModuleConfigFile(config_name, template_name) cfg_mgr = ConfigManager(cfg_inst) await cfg_mgr.load_config(completion=completion) @@ -63,20 +62,21 @@ async def load_python_module_config(config_name: str, template_name: str, comple return cfg_mgr -async def load_json_config(config_name: str, template_name: str=None, template_data: dict=None, completion: bool=True) -> ConfigManager: +async def load_json_config( + config_name: str, + template_name: str = None, + template_data: dict = None, + completion: bool = True, +) -> ConfigManager: """加载JSON配置文件 - + Args: config_name (str): 配置文件名 template_name (str): 模板文件名 template_data (dict): 模板数据 completion (bool): 是否自动补全内存中的配置文件 """ - cfg_inst = json_file.JSONConfigFile( - config_name, - template_name, - template_data - ) + cfg_inst = json_file.JSONConfigFile(config_name, template_name, template_data) cfg_mgr = ConfigManager(cfg_inst) await cfg_mgr.load_config(completion=completion) @@ -84,9 +84,14 @@ async def load_json_config(config_name: str, template_name: str=None, template_d return cfg_mgr -async def load_yaml_config(config_name: str, template_name: str=None, template_data: dict=None, completion: bool=True) -> ConfigManager: +async def load_yaml_config( + config_name: str, + template_name: str = None, + template_data: dict = None, + completion: bool = True, +) -> ConfigManager: """加载YAML配置文件 - + Args: config_name (str): 配置文件名 template_name (str): 模板文件名 @@ -96,11 +101,7 @@ async def load_yaml_config(config_name: str, template_name: str=None, template_d Returns: ConfigManager: 配置文件管理器 """ - cfg_inst = yaml_file.YAMLConfigFile( - config_name, - template_name, - template_data - ) + cfg_inst = yaml_file.YAMLConfigFile(config_name, template_name, template_data) cfg_mgr = ConfigManager(cfg_inst) await cfg_mgr.load_config(completion=completion) diff --git a/pkg/config/model.py b/pkg/config/model.py index 153123e3..f3536804 100644 --- a/pkg/config/model.py +++ b/pkg/config/model.py @@ -22,7 +22,7 @@ class ConfigFile(metaclass=abc.ABCMeta): pass @abc.abstractmethod - async def load(self, completion: bool=True) -> dict: + async def load(self, completion: bool = True) -> dict: pass @abc.abstractmethod diff --git a/pkg/core/app.py b/pkg/core/app.py index 9e337efb..6e631e5c 100644 --- a/pkg/core/app.py +++ b/pkg/core/app.py @@ -2,9 +2,7 @@ from __future__ import annotations import logging import asyncio -import threading import traceback -import enum import sys import os @@ -29,7 +27,6 @@ from ..discover import engine as discover_engine from ..utils import logcache, ip from . import taskmgr from . import entities as core_entities -from .bootutils import config class Application: @@ -123,33 +120,55 @@ class Application: async def run(self): try: await self.plugin_mgr.initialize_plugins() + # 后续可能会允许动态重启其他任务 # 故为了防止程序在非 Ctrl-C 情况下退出,这里创建一个不会结束的协程 async def never_ending(): while True: await asyncio.sleep(1) - self.task_mgr.create_task(self.platform_mgr.run(), name="platform-manager", scopes=[core_entities.LifecycleControlScope.APPLICATION, core_entities.LifecycleControlScope.PLATFORM]) - self.task_mgr.create_task(self.ctrl.run(), name="query-controller", scopes=[core_entities.LifecycleControlScope.APPLICATION]) - self.task_mgr.create_task(self.http_ctrl.run(), name="http-api-controller", scopes=[core_entities.LifecycleControlScope.APPLICATION]) - self.task_mgr.create_task(never_ending(), name="never-ending-task", scopes=[core_entities.LifecycleControlScope.APPLICATION]) + self.task_mgr.create_task( + self.platform_mgr.run(), + name='platform-manager', + scopes=[ + core_entities.LifecycleControlScope.APPLICATION, + core_entities.LifecycleControlScope.PLATFORM, + ], + ) + self.task_mgr.create_task( + self.ctrl.run(), + name='query-controller', + scopes=[core_entities.LifecycleControlScope.APPLICATION], + ) + self.task_mgr.create_task( + self.http_ctrl.run(), + name='http-api-controller', + scopes=[core_entities.LifecycleControlScope.APPLICATION], + ) + self.task_mgr.create_task( + never_ending(), + name='never-ending-task', + scopes=[core_entities.LifecycleControlScope.APPLICATION], + ) await self.print_web_access_info() await self.task_mgr.wait_all() except asyncio.CancelledError: pass except Exception as e: - self.logger.error(f"应用运行致命异常: {e}") - self.logger.debug(f"Traceback: {traceback.format_exc()}") + self.logger.error(f'应用运行致命异常: {e}') + self.logger.debug(f'Traceback: {traceback.format_exc()}') async def print_web_access_info(self): """打印访问 webui 的提示""" - if not os.path.exists(os.path.join(".", "web/out")): - self.logger.warning("WebUI 文件缺失,请根据文档获取:https://docs.langbot.app/webui/intro.html") + if not os.path.exists(os.path.join('.', 'web/out')): + self.logger.warning( + 'WebUI 文件缺失,请根据文档获取:https://docs.langbot.app/webui/intro.html' + ) return - host_ip = "127.0.0.1" + host_ip = '127.0.0.1' public_ip = await ip.get_myip() @@ -170,7 +189,7 @@ class Application: 🤯 WebUI 仍处于 Beta 测试阶段,如有问题或建议请反馈到 https://github.com/RockChinQ/LangBot/issues ======================================= """.strip() - for line in tips.split("\n"): + for line in tips.split('\n'): self.logger.info(line) async def reload( @@ -179,21 +198,28 @@ class Application: ): match scope: case core_entities.LifecycleControlScope.PLATFORM.value: - self.logger.info("执行热重载 scope="+scope) + self.logger.info('执行热重载 scope=' + scope) await self.platform_mgr.shutdown() self.platform_mgr = im_mgr.PlatformManager(self) await self.platform_mgr.initialize() - self.task_mgr.create_task(self.platform_mgr.run(), name="platform-manager", scopes=[core_entities.LifecycleControlScope.APPLICATION, core_entities.LifecycleControlScope.PLATFORM]) + self.task_mgr.create_task( + self.platform_mgr.run(), + name='platform-manager', + scopes=[ + core_entities.LifecycleControlScope.APPLICATION, + core_entities.LifecycleControlScope.PLATFORM, + ], + ) case core_entities.LifecycleControlScope.PLUGIN.value: - self.logger.info("执行热重载 scope="+scope) + self.logger.info('执行热重载 scope=' + scope) await self.plugin_mgr.destroy_plugins() # 删除 sys.module 中所有的 plugins/* 下的模块 for mod in list(sys.modules.keys()): - if mod.startswith("plugins."): + if mod.startswith('plugins.'): del sys.modules[mod] self.plugin_mgr = plugin_mgr.PluginManager(self) @@ -204,7 +230,7 @@ class Application: await self.plugin_mgr.load_plugins() await self.plugin_mgr.initialize_plugins() case core_entities.LifecycleControlScope.PROVIDER.value: - self.logger.info("执行热重载 scope="+scope) + self.logger.info('执行热重载 scope=' + scope) await self.tool_mgr.shutdown() @@ -220,4 +246,4 @@ class Application: await llm_tool_mgr_inst.initialize() self.tool_mgr = llm_tool_mgr_inst case _: - pass \ No newline at end of file + pass diff --git a/pkg/core/boot.py b/pkg/core/boot.py index 307fa95c..e3f2a9da 100644 --- a/pkg/core/boot.py +++ b/pkg/core/boot.py @@ -7,29 +7,30 @@ import os from . import app from ..audit import identifier from . import stage -from ..utils import constants +from ..utils import constants, importutil # 引入启动阶段实现以便注册 -from .stages import load_config, setup_logger, build_app, migrate, show_notes, genkeys +from . import stages + +importutil.import_modules_in_pkg(stages) stage_order = [ - "LoadConfigStage", - "MigrationStage", - "GenKeysStage", - "SetupLoggerStage", - "BuildAppStage", - "ShowNotesStage" + 'LoadConfigStage', + 'MigrationStage', + 'GenKeysStage', + 'SetupLoggerStage', + 'BuildAppStage', + 'ShowNotesStage', ] async def make_app(loop: asyncio.AbstractEventLoop) -> app.Application: - # 生成标识符 identifier.init() # 确定是否为调试模式 - if "DEBUG" in os.environ and os.environ["DEBUG"] in ["true", "1"]: + if 'DEBUG' in os.environ and os.environ['DEBUG'] in ['true', '1']: constants.debug_mode = True ap = app.Application() @@ -50,21 +51,17 @@ async def make_app(loop: asyncio.AbstractEventLoop) -> app.Application: async def main(loop: asyncio.AbstractEventLoop): try: - # 挂系统信号处理 import signal - ap: app.Application - def signal_handler(sig, frame): - print("[Signal] 程序退出.") + print('[Signal] 程序退出.') # ap.shutdown() os._exit(0) signal.signal(signal.SIGINT, signal_handler) app_inst = await make_app(loop) - ap = app_inst await app_inst.run() - except Exception as e: + except Exception: traceback.print_exc() diff --git a/pkg/core/bootutils/config.py b/pkg/core/bootutils/config.py index 794c329a..cea4af45 100644 --- a/pkg/core/bootutils/config.py +++ b/pkg/core/bootutils/config.py @@ -1,11 +1,9 @@ from __future__ import annotations -import json from ...config import manager as config_mgr -from ...config.impls import pymodule load_python_module_config = config_mgr.load_python_module_config load_json_config = config_mgr.load_json_config -load_yaml_config = config_mgr.load_yaml_config \ No newline at end of file +load_yaml_config = config_mgr.load_yaml_config diff --git a/pkg/core/bootutils/deps.py b/pkg/core/bootutils/deps.py index 8d8e7e00..b0ba7983 100644 --- a/pkg/core/bootutils/deps.py +++ b/pkg/core/bootutils/deps.py @@ -5,39 +5,39 @@ from ...utils import pkgmgr # 检查依赖,防止用户未安装 # 左边为引入名称,右边为依赖名称 required_deps = { - "requests": "requests", - "openai": "openai", - "anthropic": "anthropic", - "colorlog": "colorlog", - "aiocqhttp": "aiocqhttp", - "botpy": "qq-botpy-rc", - "PIL": "pillow", - "nakuru": "nakuru-project-idk", - "tiktoken": "tiktoken", - "yaml": "pyyaml", - "aiohttp": "aiohttp", - "psutil": "psutil", - "async_lru": "async-lru", - "ollama": "ollama", - "quart": "quart", - "quart_cors": "quart-cors", - "sqlalchemy": "sqlalchemy[asyncio]", - "aiosqlite": "aiosqlite", - "aiofiles": "aiofiles", - "aioshutil": "aioshutil", - "argon2": "argon2-cffi", - "jwt": "pyjwt", - "Crypto": "pycryptodome", - "lark_oapi": "lark-oapi", - "discord": "discord.py", - "cryptography": "cryptography", - "gewechat_client": "gewechat-client", - "dingtalk_stream": "dingtalk_stream", - "dashscope": "dashscope", - "telegram": "python-telegram-bot", - "certifi": "certifi", - "mcp": "mcp", - "sqlmodel": "sqlmodel", + 'requests': 'requests', + 'openai': 'openai', + 'anthropic': 'anthropic', + 'colorlog': 'colorlog', + 'aiocqhttp': 'aiocqhttp', + 'botpy': 'qq-botpy-rc', + 'PIL': 'pillow', + 'nakuru': 'nakuru-project-idk', + 'tiktoken': 'tiktoken', + 'yaml': 'pyyaml', + 'aiohttp': 'aiohttp', + 'psutil': 'psutil', + 'async_lru': 'async-lru', + 'ollama': 'ollama', + 'quart': 'quart', + 'quart_cors': 'quart-cors', + 'sqlalchemy': 'sqlalchemy[asyncio]', + 'aiosqlite': 'aiosqlite', + 'aiofiles': 'aiofiles', + 'aioshutil': 'aioshutil', + 'argon2': 'argon2-cffi', + 'jwt': 'pyjwt', + 'Crypto': 'pycryptodome', + 'lark_oapi': 'lark-oapi', + 'discord': 'discord.py', + 'cryptography': 'cryptography', + 'gewechat_client': 'gewechat-client', + 'dingtalk_stream': 'dingtalk_stream', + 'dashscope': 'dashscope', + 'telegram': 'python-telegram-bot', + 'certifi': 'certifi', + 'mcp': 'mcp', + 'sqlmodel': 'sqlmodel', } @@ -52,20 +52,25 @@ async def check_deps() -> list[str]: missing_deps.append(dep) return missing_deps + async def install_deps(deps: list[str]): global required_deps - + for dep in deps: - pip.main(["install", required_deps[dep]]) + pip.main(['install', required_deps[dep]]) + async def precheck_plugin_deps(): print('[Startup] Prechecking plugin dependencies...') # 只有在plugins目录存在时才执行插件依赖安装 - if os.path.exists("plugins"): - for dir in os.listdir("plugins"): - subdir = os.path.join("plugins", dir) + if os.path.exists('plugins'): + for dir in os.listdir('plugins'): + subdir = os.path.join('plugins', dir) if not os.path.isdir(subdir): continue if 'requirements.txt' in os.listdir(subdir): - pkgmgr.install_requirements(os.path.join(subdir, 'requirements.txt'), extra_params=['-q', '-q', '-q']) + pkgmgr.install_requirements( + os.path.join(subdir, 'requirements.txt'), + extra_params=['-q', '-q', '-q'], + ) diff --git a/pkg/core/bootutils/files.py b/pkg/core/bootutils/files.py index 9a2dff71..3599e41b 100644 --- a/pkg/core/bootutils/files.py +++ b/pkg/core/bootutils/files.py @@ -2,23 +2,23 @@ from __future__ import annotations import os import shutil -import sys required_files = { - "plugins/__init__.py": "templates/__init__.py", - "data/config.yaml": "templates/config.yaml", + 'plugins/__init__.py': 'templates/__init__.py', + 'data/config.yaml': 'templates/config.yaml', } required_paths = [ - "temp", - "data", - "data/metadata", - "data/logs", - "data/labels", - "plugins" + 'temp', + 'data', + 'data/metadata', + 'data/logs', + 'data/labels', + 'plugins', ] + async def generate_files() -> list[str]: global required_files, required_paths diff --git a/pkg/core/bootutils/log.py b/pkg/core/bootutils/log.py index 7cbb7412..df65e1ba 100644 --- a/pkg/core/bootutils/log.py +++ b/pkg/core/bootutils/log.py @@ -1,5 +1,4 @@ import logging -import os import sys import time @@ -9,11 +8,11 @@ from ...utils import constants log_colors_config = { - "DEBUG": "green", # cyan white - "INFO": "white", - "WARNING": "yellow", - "ERROR": "red", - "CRITICAL": "cyan", + 'DEBUG': 'green', # cyan white + 'INFO': 'white', + 'WARNING': 'yellow', + 'ERROR': 'red', + 'CRITICAL': 'cyan', } @@ -27,26 +26,31 @@ async def init_logging(extra_handlers: list[logging.Handler] = None) -> logging. if constants.debug_mode: level = logging.DEBUG - log_file_name = "data/logs/langbot-%s.log" % time.strftime( - "%Y-%m-%d", time.localtime() + log_file_name = 'data/logs/langbot-%s.log' % time.strftime( + '%Y-%m-%d', time.localtime() ) - qcg_logger = logging.getLogger("langbot") + qcg_logger = logging.getLogger('langbot') qcg_logger.setLevel(level) color_formatter = colorlog.ColoredFormatter( - fmt="%(log_color)s[%(asctime)s.%(msecs)03d] %(filename)s (%(lineno)d) - [%(levelname)s] : %(message)s", - datefmt="%m-%d %H:%M:%S", + fmt='%(log_color)s[%(asctime)s.%(msecs)03d] %(filename)s (%(lineno)d) - [%(levelname)s] : %(message)s', + datefmt='%m-%d %H:%M:%S', log_colors=log_colors_config, ) stream_handler = logging.StreamHandler(sys.stdout) # stream_handler.setLevel(level) # stream_handler.setFormatter(color_formatter) - stream_handler.stream = open(sys.stdout.fileno(), mode='w', encoding='utf-8', buffering=1) + stream_handler.stream = open( + sys.stdout.fileno(), mode='w', encoding='utf-8', buffering=1 + ) - log_handlers: list[logging.Handler] = [stream_handler, logging.FileHandler(log_file_name, encoding='utf-8')] + log_handlers: list[logging.Handler] = [ + stream_handler, + logging.FileHandler(log_file_name, encoding='utf-8'), + ] log_handlers += extra_handlers if extra_handlers is not None else [] for handler in log_handlers: @@ -54,13 +58,13 @@ async def init_logging(extra_handlers: list[logging.Handler] = None) -> logging. handler.setFormatter(color_formatter) qcg_logger.addHandler(handler) - qcg_logger.debug("日志初始化完成,日志级别:%s" % level) + qcg_logger.debug('日志初始化完成,日志级别:%s' % level) logging.basicConfig( level=logging.CRITICAL, # 设置日志输出格式 - format="[DEPR][%(asctime)s.%(msecs)03d] %(pathname)s (%(lineno)d) - [%(levelname)s] :\n%(message)s", + format='[DEPR][%(asctime)s.%(msecs)03d] %(pathname)s (%(lineno)d) - [%(levelname)s] :\n%(message)s', # 日志输出的格式 # -8表示占位符,让输出左对齐,输出长度都为8位 - datefmt="%Y-%m-%d %H:%M:%S", # 时间输出的格式 + datefmt='%Y-%m-%d %H:%M:%S', # 时间输出的格式 handlers=[logging.NullHandler()], ) diff --git a/pkg/core/entities.py b/pkg/core/entities.py index 1753495b..5ffd0029 100644 --- a/pkg/core/entities.py +++ b/pkg/core/entities.py @@ -8,21 +8,18 @@ import asyncio import pydantic.v1 as pydantic from ..provider import entities as llm_entities -from ..provider.modelmgr import entities, modelmgr, requester +from ..provider.modelmgr import requester from ..provider.tools import entities as tools_entities from ..platform import adapter as msadapter from ..platform.types import message as platform_message from ..platform.types import events as platform_events -from ..platform.types import entities as platform_entities - class LifecycleControlScope(enum.Enum): - - APPLICATION = "application" - PLATFORM = "platform" - PLUGIN = "plugin" - PROVIDER = "provider" + APPLICATION = 'application' + PLATFORM = 'platform' + PLUGIN = 'plugin' + PROVIDER = 'provider' class LauncherTypes(enum.Enum): @@ -89,14 +86,17 @@ class Query(pydantic.BaseModel): use_funcs: typing.Optional[list[tools_entities.LLMFunction]] = None """使用的函数,由前置处理器阶段设置""" - resp_messages: typing.Optional[list[llm_entities.Message]] | typing.Optional[list[platform_message.MessageChain]] = [] + resp_messages: ( + typing.Optional[list[llm_entities.Message]] + | typing.Optional[list[platform_message.MessageChain]] + ) = [] """由Process阶段生成的回复消息对象列表""" resp_message_chain: typing.Optional[list[platform_message.MessageChain]] = None """回复消息链,从resp_messages包装而得""" # ======= 内部保留 ======= - current_stage: "pkg.pipeline.pipelinemgr.StageInstContainer" = None + current_stage = None # pkg.pipeline.pipelinemgr.StageInstContainer """当前所处阶段""" class Config: @@ -109,13 +109,13 @@ class Query(pydantic.BaseModel): if self.variables is None: self.variables = {} self.variables[key] = value - + def get_variable(self, key: str) -> typing.Any: """获取变量""" if self.variables is None: return None return self.variables.get(key) - + def get_variables(self) -> dict[str, typing.Any]: """获取所有变量""" if self.variables is None: @@ -130,9 +130,13 @@ class Conversation(pydantic.BaseModel): messages: list[llm_entities.Message] - create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now) + create_time: typing.Optional[datetime.datetime] = pydantic.Field( + default_factory=datetime.datetime.now + ) - update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now) + update_time: typing.Optional[datetime.datetime] = pydantic.Field( + default_factory=datetime.datetime.now + ) use_llm_model: requester.RuntimeLLMModel @@ -147,6 +151,7 @@ class Conversation(pydantic.BaseModel): class Session(pydantic.BaseModel): """会话,一个 Session 对应一个 {launcher_type.value}_{launcher_id}""" + launcher_type: LauncherTypes launcher_id: typing.Union[int, str] @@ -157,11 +162,17 @@ class Session(pydantic.BaseModel): using_conversation: typing.Optional[Conversation] = None - conversations: typing.Optional[list[Conversation]] = pydantic.Field(default_factory=list) + conversations: typing.Optional[list[Conversation]] = pydantic.Field( + default_factory=list + ) - create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now) + create_time: typing.Optional[datetime.datetime] = pydantic.Field( + default_factory=datetime.datetime.now + ) - update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now) + update_time: typing.Optional[datetime.datetime] = pydantic.Field( + default_factory=datetime.datetime.now + ) semaphore: typing.Optional[asyncio.Semaphore] = None """当前会话的信号量,用于限制并发""" diff --git a/pkg/core/migration.py b/pkg/core/migration.py index 2c5c7597..e97c0cf3 100644 --- a/pkg/core/migration.py +++ b/pkg/core/migration.py @@ -9,21 +9,21 @@ from . import app preregistered_migrations: list[typing.Type[Migration]] = [] """当前阶段暂不支持扩展""" + def migration_class(name: str, number: int): - """注册一个迁移 - """ + """注册一个迁移""" + def decorator(cls: typing.Type[Migration]) -> typing.Type[Migration]: cls.name = name cls.number = number preregistered_migrations.append(cls) return cls - + return decorator class Migration(abc.ABC): - """一个版本的迁移 - """ + """一个版本的迁移""" name: str @@ -33,15 +33,13 @@ class Migration(abc.ABC): def __init__(self, ap: app.Application): self.ap = ap - + @abc.abstractmethod async def need_migrate(self) -> bool: - """判断当前环境是否需要运行此迁移 - """ + """判断当前环境是否需要运行此迁移""" pass @abc.abstractmethod async def run(self): - """执行迁移 - """ + """执行迁移""" pass diff --git a/pkg/core/migrations/m001_sensitive_word_migration.py b/pkg/core/migrations/m001_sensitive_word_migration.py index 6e435eeb..72200346 100644 --- a/pkg/core/migrations/m001_sensitive_word_migration.py +++ b/pkg/core/migrations/m001_sensitive_word_migration.py @@ -1,26 +1,26 @@ from __future__ import annotations import os -import sys from .. import migration -@migration.migration_class("sensitive-word-migration", 1) +@migration.migration_class('sensitive-word-migration', 1) class SensitiveWordMigration(migration.Migration): - """敏感词迁移 - """ + """敏感词迁移""" async def need_migrate(self) -> bool: - """判断当前环境是否需要运行此迁移 - """ - return os.path.exists("data/config/sensitive-words.json") and not os.path.exists("data/metadata/sensitive-words.json") + """判断当前环境是否需要运行此迁移""" + return os.path.exists( + 'data/config/sensitive-words.json' + ) and not os.path.exists('data/metadata/sensitive-words.json') async def run(self): - """执行迁移 - """ + """执行迁移""" # 移动文件 - os.rename("data/config/sensitive-words.json", "data/metadata/sensitive-words.json") + os.rename( + 'data/config/sensitive-words.json', 'data/metadata/sensitive-words.json' + ) # 重新加载配置 await self.ap.sensitive_meta.load_config() diff --git a/pkg/core/migrations/m002_openai_config_migration.py b/pkg/core/migrations/m002_openai_config_migration.py index 2f2553ef..6892110f 100644 --- a/pkg/core/migrations/m002_openai_config_migration.py +++ b/pkg/core/migrations/m002_openai_config_migration.py @@ -3,19 +3,16 @@ from __future__ import annotations from .. import migration -@migration.migration_class("openai-config-migration", 2) +@migration.migration_class('openai-config-migration', 2) class OpenAIConfigMigration(migration.Migration): - """OpenAI配置迁移 - """ + """OpenAI配置迁移""" async def need_migrate(self) -> bool: - """判断当前环境是否需要运行此迁移 - """ + """判断当前环境是否需要运行此迁移""" return 'openai-config' in self.ap.provider_cfg.data async def run(self): - """执行迁移 - """ + """执行迁移""" old_openai_config = self.ap.provider_cfg.data['openai-config'].copy() if 'keys' not in self.ap.provider_cfg.data: @@ -26,7 +23,9 @@ class OpenAIConfigMigration(migration.Migration): self.ap.provider_cfg.data['keys']['openai'] = old_openai_config['api-keys'] - self.ap.provider_cfg.data['model'] = old_openai_config['chat-completions-params']['model'] + self.ap.provider_cfg.data['model'] = old_openai_config[ + 'chat-completions-params' + ]['model'] del old_openai_config['chat-completions-params']['model'] @@ -35,7 +34,7 @@ class OpenAIConfigMigration(migration.Migration): if 'openai-chat-completions' not in self.ap.provider_cfg.data['requester']: self.ap.provider_cfg.data['requester']['openai-chat-completions'] = {} - + self.ap.provider_cfg.data['requester']['openai-chat-completions'] = { 'base-url': old_openai_config['base_url'], 'args': old_openai_config['chat-completions-params'], @@ -44,4 +43,4 @@ class OpenAIConfigMigration(migration.Migration): del self.ap.provider_cfg.data['openai-config'] - await self.ap.provider_cfg.dump_config() \ No newline at end of file + await self.ap.provider_cfg.dump_config() diff --git a/pkg/core/migrations/m003_anthropic_requester_cfg_completion.py b/pkg/core/migrations/m003_anthropic_requester_cfg_completion.py index 101b03d0..19369679 100644 --- a/pkg/core/migrations/m003_anthropic_requester_cfg_completion.py +++ b/pkg/core/migrations/m003_anthropic_requester_cfg_completion.py @@ -3,26 +3,23 @@ from __future__ import annotations from .. import migration -@migration.migration_class("anthropic-requester-config-completion", 3) +@migration.migration_class('anthropic-requester-config-completion', 3) class AnthropicRequesterConfigCompletionMigration(migration.Migration): - """OpenAI配置迁移 - """ + """OpenAI配置迁移""" async def need_migrate(self) -> bool: - """判断当前环境是否需要运行此迁移 - """ - return 'anthropic-messages' not in self.ap.provider_cfg.data['requester'] \ + """判断当前环境是否需要运行此迁移""" + return ( + 'anthropic-messages' not in self.ap.provider_cfg.data['requester'] or 'anthropic' not in self.ap.provider_cfg.data['keys'] + ) async def run(self): - """执行迁移 - """ + """执行迁移""" if 'anthropic-messages' not in self.ap.provider_cfg.data['requester']: self.ap.provider_cfg.data['requester']['anthropic-messages'] = { 'base-url': 'https://api.anthropic.com', - 'args': { - 'max_tokens': 1024 - }, + 'args': {'max_tokens': 1024}, 'timeout': 120, } diff --git a/pkg/core/migrations/m004_moonshot_cfg_completion.py b/pkg/core/migrations/m004_moonshot_cfg_completion.py index b1f7e9ed..de086159 100644 --- a/pkg/core/migrations/m004_moonshot_cfg_completion.py +++ b/pkg/core/migrations/m004_moonshot_cfg_completion.py @@ -3,20 +3,19 @@ from __future__ import annotations from .. import migration -@migration.migration_class("moonshot-config-completion", 4) +@migration.migration_class('moonshot-config-completion', 4) class MoonshotConfigCompletionMigration(migration.Migration): - """OpenAI配置迁移 - """ + """OpenAI配置迁移""" async def need_migrate(self) -> bool: - """判断当前环境是否需要运行此迁移 - """ - return 'moonshot-chat-completions' not in self.ap.provider_cfg.data['requester'] \ + """判断当前环境是否需要运行此迁移""" + return ( + 'moonshot-chat-completions' not in self.ap.provider_cfg.data['requester'] or 'moonshot' not in self.ap.provider_cfg.data['keys'] + ) async def run(self): - """执行迁移 - """ + """执行迁移""" if 'moonshot-chat-completions' not in self.ap.provider_cfg.data['requester']: self.ap.provider_cfg.data['requester']['moonshot-chat-completions'] = { 'base-url': 'https://api.moonshot.cn/v1', diff --git a/pkg/core/migrations/m005_deepseek_cfg_completion.py b/pkg/core/migrations/m005_deepseek_cfg_completion.py index bd8aa2ee..d4d82e3f 100644 --- a/pkg/core/migrations/m005_deepseek_cfg_completion.py +++ b/pkg/core/migrations/m005_deepseek_cfg_completion.py @@ -3,20 +3,19 @@ from __future__ import annotations from .. import migration -@migration.migration_class("deepseek-config-completion", 5) +@migration.migration_class('deepseek-config-completion', 5) class DeepseekConfigCompletionMigration(migration.Migration): - """OpenAI配置迁移 - """ + """OpenAI配置迁移""" async def need_migrate(self) -> bool: - """判断当前环境是否需要运行此迁移 - """ - return 'deepseek-chat-completions' not in self.ap.provider_cfg.data['requester'] \ + """判断当前环境是否需要运行此迁移""" + return ( + 'deepseek-chat-completions' not in self.ap.provider_cfg.data['requester'] or 'deepseek' not in self.ap.provider_cfg.data['keys'] + ) async def run(self): - """执行迁移 - """ + """执行迁移""" if 'deepseek-chat-completions' not in self.ap.provider_cfg.data['requester']: self.ap.provider_cfg.data['requester']['deepseek-chat-completions'] = { 'base-url': 'https://api.deepseek.com', @@ -27,4 +26,4 @@ class DeepseekConfigCompletionMigration(migration.Migration): if 'deepseek' not in self.ap.provider_cfg.data['keys']: self.ap.provider_cfg.data['keys']['deepseek'] = [] - await self.ap.provider_cfg.dump_config() \ No newline at end of file + await self.ap.provider_cfg.dump_config() diff --git a/pkg/core/migrations/m006_vision_config.py b/pkg/core/migrations/m006_vision_config.py index 8084611e..ea824d44 100644 --- a/pkg/core/migrations/m006_vision_config.py +++ b/pkg/core/migrations/m006_vision_config.py @@ -3,17 +3,17 @@ from __future__ import annotations from .. import migration -@migration.migration_class("vision-config", 6) +@migration.migration_class('vision-config', 6) class VisionConfigMigration(migration.Migration): """迁移""" async def need_migrate(self) -> bool: """判断当前环境是否需要运行此迁移""" - return "enable-vision" not in self.ap.provider_cfg.data + return 'enable-vision' not in self.ap.provider_cfg.data async def run(self): """执行迁移""" - if "enable-vision" not in self.ap.provider_cfg.data: - self.ap.provider_cfg.data["enable-vision"] = False + if 'enable-vision' not in self.ap.provider_cfg.data: + self.ap.provider_cfg.data['enable-vision'] = False await self.ap.provider_cfg.dump_config() diff --git a/pkg/core/migrations/m007_qcg_center_url.py b/pkg/core/migrations/m007_qcg_center_url.py index cecd6b11..b3fcd853 100644 --- a/pkg/core/migrations/m007_qcg_center_url.py +++ b/pkg/core/migrations/m007_qcg_center_url.py @@ -3,18 +3,20 @@ from __future__ import annotations from .. import migration -@migration.migration_class("qcg-center-url-config", 7) +@migration.migration_class('qcg-center-url-config', 7) class QCGCenterURLConfigMigration(migration.Migration): """迁移""" async def need_migrate(self) -> bool: """判断当前环境是否需要运行此迁移""" - return "qcg-center-url" not in self.ap.system_cfg.data + return 'qcg-center-url' not in self.ap.system_cfg.data async def run(self): """执行迁移""" - - if "qcg-center-url" not in self.ap.system_cfg.data: - self.ap.system_cfg.data["qcg-center-url"] = "https://api.qchatgpt.rockchin.top/api/v2" - + + if 'qcg-center-url' not in self.ap.system_cfg.data: + self.ap.system_cfg.data['qcg-center-url'] = ( + 'https://api.qchatgpt.rockchin.top/api/v2' + ) + await self.ap.system_cfg.dump_config() diff --git a/pkg/core/migrations/m008_ad_fixwin_config_migrate.py b/pkg/core/migrations/m008_ad_fixwin_config_migrate.py index ccd6fbd7..96fd58e7 100644 --- a/pkg/core/migrations/m008_ad_fixwin_config_migrate.py +++ b/pkg/core/migrations/m008_ad_fixwin_config_migrate.py @@ -3,27 +3,27 @@ from __future__ import annotations from .. import migration -@migration.migration_class("ad-fixwin-cfg-migration", 8) +@migration.migration_class('ad-fixwin-cfg-migration', 8) class AdFixwinConfigMigration(migration.Migration): """迁移""" async def need_migrate(self) -> bool: """判断当前环境是否需要运行此迁移""" return isinstance( - self.ap.pipeline_cfg.data["rate-limit"]["fixwin"]["default"], - int + self.ap.pipeline_cfg.data['rate-limit']['fixwin']['default'], int ) async def run(self): """执行迁移""" - - for session_name in self.ap.pipeline_cfg.data["rate-limit"]["fixwin"]: + for session_name in self.ap.pipeline_cfg.data['rate-limit']['fixwin']: temp_dict = { - "window-size": 60, - "limit": self.ap.pipeline_cfg.data["rate-limit"]["fixwin"][session_name] + 'window-size': 60, + 'limit': self.ap.pipeline_cfg.data['rate-limit']['fixwin'][ + session_name + ], } - - self.ap.pipeline_cfg.data["rate-limit"]["fixwin"][session_name] = temp_dict - await self.ap.pipeline_cfg.dump_config() \ No newline at end of file + self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name] = temp_dict + + await self.ap.pipeline_cfg.dump_config() diff --git a/pkg/core/migrations/m009_msg_truncator_cfg.py b/pkg/core/migrations/m009_msg_truncator_cfg.py index 369b60eb..066af126 100644 --- a/pkg/core/migrations/m009_msg_truncator_cfg.py +++ b/pkg/core/migrations/m009_msg_truncator_cfg.py @@ -3,7 +3,7 @@ from __future__ import annotations from .. import migration -@migration.migration_class("msg-truncator-cfg-migration", 9) +@migration.migration_class('msg-truncator-cfg-migration', 9) class MsgTruncatorConfigMigration(migration.Migration): """迁移""" @@ -13,12 +13,10 @@ class MsgTruncatorConfigMigration(migration.Migration): async def run(self): """执行迁移""" - + self.ap.pipeline_cfg.data['msg-truncate'] = { 'method': 'round', - 'round': { - 'max-round': 10 - } + 'round': {'max-round': 10}, } await self.ap.pipeline_cfg.dump_config() diff --git a/pkg/core/migrations/m010_ollama_requester_config.py b/pkg/core/migrations/m010_ollama_requester_config.py index 56e49663..8e2e15eb 100644 --- a/pkg/core/migrations/m010_ollama_requester_config.py +++ b/pkg/core/migrations/m010_ollama_requester_config.py @@ -3,7 +3,7 @@ from __future__ import annotations from .. import migration -@migration.migration_class("ollama-requester-config", 10) +@migration.migration_class('ollama-requester-config', 10) class MsgTruncatorConfigMigration(migration.Migration): """迁移""" @@ -13,11 +13,11 @@ class MsgTruncatorConfigMigration(migration.Migration): async def run(self): """执行迁移""" - + self.ap.provider_cfg.data['requester']['ollama-chat'] = { - "base-url": "http://127.0.0.1:11434", - "args": {}, - "timeout": 600 + 'base-url': 'http://127.0.0.1:11434', + 'args': {}, + 'timeout': 600, } await self.ap.provider_cfg.dump_config() diff --git a/pkg/core/migrations/m011_command_prefix_config.py b/pkg/core/migrations/m011_command_prefix_config.py index 6a9e1118..6165ae47 100644 --- a/pkg/core/migrations/m011_command_prefix_config.py +++ b/pkg/core/migrations/m011_command_prefix_config.py @@ -3,7 +3,7 @@ from __future__ import annotations from .. import migration -@migration.migration_class("command-prefix-config", 11) +@migration.migration_class('command-prefix-config', 11) class CommandPrefixConfigMigration(migration.Migration): """迁移""" @@ -13,9 +13,7 @@ class CommandPrefixConfigMigration(migration.Migration): async def run(self): """执行迁移""" - - self.ap.command_cfg.data['command-prefix'] = [ - "!", "!" - ] + + self.ap.command_cfg.data['command-prefix'] = ['!', '!'] await self.ap.command_cfg.dump_config() diff --git a/pkg/core/migrations/m012_runner_config.py b/pkg/core/migrations/m012_runner_config.py index fa236bb7..e7f0e67a 100644 --- a/pkg/core/migrations/m012_runner_config.py +++ b/pkg/core/migrations/m012_runner_config.py @@ -3,7 +3,7 @@ from __future__ import annotations from .. import migration -@migration.migration_class("runner-config", 12) +@migration.migration_class('runner-config', 12) class RunnerConfigMigration(migration.Migration): """迁移""" @@ -13,7 +13,7 @@ class RunnerConfigMigration(migration.Migration): async def run(self): """执行迁移""" - + self.ap.provider_cfg.data['runner'] = 'local-agent' await self.ap.provider_cfg.dump_config() diff --git a/pkg/core/migrations/m013_http_api_config.py b/pkg/core/migrations/m013_http_api_config.py index c5fe55ba..55aff2b9 100644 --- a/pkg/core/migrations/m013_http_api_config.py +++ b/pkg/core/migrations/m013_http_api_config.py @@ -3,29 +3,30 @@ from __future__ import annotations from .. import migration -@migration.migration_class("http-api-config", 13) +@migration.migration_class('http-api-config', 13) class HttpApiConfigMigration(migration.Migration): """迁移""" async def need_migrate(self) -> bool: """判断当前环境是否需要运行此迁移""" - return 'http-api' not in self.ap.system_cfg.data or "persistence" not in self.ap.system_cfg.data + return ( + 'http-api' not in self.ap.system_cfg.data + or 'persistence' not in self.ap.system_cfg.data + ) async def run(self): """执行迁移""" - + self.ap.system_cfg.data['http-api'] = { - "enable": True, - "host": "0.0.0.0", - "port": 5300, - "jwt-expire": 604800 + 'enable': True, + 'host': '0.0.0.0', + 'port': 5300, + 'jwt-expire': 604800, } self.ap.system_cfg.data['persistence'] = { - "sqlite": { - "path": "data/persistence.db" - }, - "use": "sqlite" + 'sqlite': {'path': 'data/persistence.db'}, + 'use': 'sqlite', } await self.ap.system_cfg.dump_config() diff --git a/pkg/core/migrations/m014_force_delay_config.py b/pkg/core/migrations/m014_force_delay_config.py index 55521c9c..005a2ca2 100644 --- a/pkg/core/migrations/m014_force_delay_config.py +++ b/pkg/core/migrations/m014_force_delay_config.py @@ -3,20 +3,20 @@ from __future__ import annotations from .. import migration -@migration.migration_class("force-delay-config", 14) +@migration.migration_class('force-delay-config', 14) class ForceDelayConfigMigration(migration.Migration): """迁移""" async def need_migrate(self) -> bool: """判断当前环境是否需要运行此迁移""" - return type(self.ap.platform_cfg.data['force-delay']) == list + return isinstance(self.ap.platform_cfg.data['force-delay'], list) async def run(self): """执行迁移""" self.ap.platform_cfg.data['force-delay'] = { - "min": self.ap.platform_cfg.data['force-delay'][0], - "max": self.ap.platform_cfg.data['force-delay'][1] + 'min': self.ap.platform_cfg.data['force-delay'][0], + 'max': self.ap.platform_cfg.data['force-delay'][1], } await self.ap.platform_cfg.dump_config() diff --git a/pkg/core/migrations/m015_gitee_ai_config.py b/pkg/core/migrations/m015_gitee_ai_config.py index b41071ad..7dd9b853 100644 --- a/pkg/core/migrations/m015_gitee_ai_config.py +++ b/pkg/core/migrations/m015_gitee_ai_config.py @@ -3,24 +3,25 @@ from __future__ import annotations from .. import migration -@migration.migration_class("gitee-ai-config", 15) +@migration.migration_class('gitee-ai-config', 15) class GiteeAIConfigMigration(migration.Migration): """迁移""" async def need_migrate(self) -> bool: """判断当前环境是否需要运行此迁移""" - return 'gitee-ai-chat-completions' not in self.ap.provider_cfg.data['requester'] or 'gitee-ai' not in self.ap.provider_cfg.data['keys'] + return ( + 'gitee-ai-chat-completions' not in self.ap.provider_cfg.data['requester'] + or 'gitee-ai' not in self.ap.provider_cfg.data['keys'] + ) async def run(self): """执行迁移""" self.ap.provider_cfg.data['requester']['gitee-ai-chat-completions'] = { - "base-url": "https://ai.gitee.com/v1", - "args": {}, - "timeout": 120 + 'base-url': 'https://ai.gitee.com/v1', + 'args': {}, + 'timeout': 120, } - self.ap.provider_cfg.data['keys']['gitee-ai'] = [ - "XXXXX" - ] + self.ap.provider_cfg.data['keys']['gitee-ai'] = ['XXXXX'] await self.ap.provider_cfg.dump_config() diff --git a/pkg/core/migrations/m016_dify_service_api.py b/pkg/core/migrations/m016_dify_service_api.py index 123879f8..e7c4dc6d 100644 --- a/pkg/core/migrations/m016_dify_service_api.py +++ b/pkg/core/migrations/m016_dify_service_api.py @@ -3,7 +3,7 @@ from __future__ import annotations from .. import migration -@migration.migration_class("dify-service-api-config", 16) +@migration.migration_class('dify-service-api-config', 16) class DifyServiceAPICfgMigration(migration.Migration): """迁移""" @@ -14,15 +14,10 @@ class DifyServiceAPICfgMigration(migration.Migration): async def run(self): """执行迁移""" self.ap.provider_cfg.data['dify-service-api'] = { - "base-url": "https://api.dify.ai/v1", - "app-type": "chat", - "chat": { - "api-key": "app-1234567890" - }, - "workflow": { - "api-key": "app-1234567890", - "output-key": "summary" - } + 'base-url': 'https://api.dify.ai/v1', + 'app-type': 'chat', + 'chat': {'api-key': 'app-1234567890'}, + 'workflow': {'api-key': 'app-1234567890', 'output-key': 'summary'}, } await self.ap.provider_cfg.dump_config() diff --git a/pkg/core/migrations/m017_dify_api_timeout_params.py b/pkg/core/migrations/m017_dify_api_timeout_params.py index a0e502a4..7ce9133c 100644 --- a/pkg/core/migrations/m017_dify_api_timeout_params.py +++ b/pkg/core/migrations/m017_dify_api_timeout_params.py @@ -3,22 +3,26 @@ from __future__ import annotations from .. import migration -@migration.migration_class("dify-api-timeout-params", 17) +@migration.migration_class('dify-api-timeout-params', 17) class DifyAPITimeoutParamsMigration(migration.Migration): """迁移""" async def need_migrate(self) -> bool: """判断当前环境是否需要运行此迁移""" - return 'timeout' not in self.ap.provider_cfg.data['dify-service-api']['chat'] or 'timeout' not in self.ap.provider_cfg.data['dify-service-api']['workflow'] \ + return ( + 'timeout' not in self.ap.provider_cfg.data['dify-service-api']['chat'] + or 'timeout' + not in self.ap.provider_cfg.data['dify-service-api']['workflow'] or 'agent' not in self.ap.provider_cfg.data['dify-service-api'] + ) async def run(self): """执行迁移""" self.ap.provider_cfg.data['dify-service-api']['chat']['timeout'] = 120 self.ap.provider_cfg.data['dify-service-api']['workflow']['timeout'] = 120 self.ap.provider_cfg.data['dify-service-api']['agent'] = { - "api-key": "app-1234567890", - "timeout": 120 + 'api-key': 'app-1234567890', + 'timeout': 120, } await self.ap.provider_cfg.dump_config() diff --git a/pkg/core/migrations/m018_xai_config.py b/pkg/core/migrations/m018_xai_config.py index bf422451..db5ed5bf 100644 --- a/pkg/core/migrations/m018_xai_config.py +++ b/pkg/core/migrations/m018_xai_config.py @@ -3,7 +3,7 @@ from __future__ import annotations from .. import migration -@migration.migration_class("xai-config", 18) +@migration.migration_class('xai-config', 18) class XaiConfigMigration(migration.Migration): """迁移""" @@ -14,12 +14,10 @@ class XaiConfigMigration(migration.Migration): async def run(self): """执行迁移""" self.ap.provider_cfg.data['requester']['xai-chat-completions'] = { - "base-url": "https://api.x.ai/v1", - "args": {}, - "timeout": 120 + 'base-url': 'https://api.x.ai/v1', + 'args': {}, + 'timeout': 120, } - self.ap.provider_cfg.data['keys']['xai'] = [ - "xai-1234567890" - ] + self.ap.provider_cfg.data['keys']['xai'] = ['xai-1234567890'] await self.ap.provider_cfg.dump_config() diff --git a/pkg/core/migrations/m019_zhipuai_config.py b/pkg/core/migrations/m019_zhipuai_config.py index 67f33340..081d8dcf 100644 --- a/pkg/core/migrations/m019_zhipuai_config.py +++ b/pkg/core/migrations/m019_zhipuai_config.py @@ -3,7 +3,7 @@ from __future__ import annotations from .. import migration -@migration.migration_class("zhipuai-config", 19) +@migration.migration_class('zhipuai-config', 19) class ZhipuaiConfigMigration(migration.Migration): """迁移""" @@ -14,12 +14,10 @@ class ZhipuaiConfigMigration(migration.Migration): async def run(self): """执行迁移""" self.ap.provider_cfg.data['requester']['zhipuai-chat-completions'] = { - "base-url": "https://open.bigmodel.cn/api/paas/v4", - "args": {}, - "timeout": 120 + 'base-url': 'https://open.bigmodel.cn/api/paas/v4', + 'args': {}, + 'timeout': 120, } - self.ap.provider_cfg.data['keys']['zhipuai'] = [ - "xxxxxxx" - ] + self.ap.provider_cfg.data['keys']['zhipuai'] = ['xxxxxxx'] await self.ap.provider_cfg.dump_config() diff --git a/pkg/core/migrations/m020_wecom_config.py b/pkg/core/migrations/m020_wecom_config.py index 9581cb91..3e833d3e 100644 --- a/pkg/core/migrations/m020_wecom_config.py +++ b/pkg/core/migrations/m020_wecom_config.py @@ -3,13 +3,13 @@ from __future__ import annotations from .. import migration -@migration.migration_class("wecom-config", 20) +@migration.migration_class('wecom-config', 20) class WecomConfigMigration(migration.Migration): """迁移""" async def need_migrate(self) -> bool: """判断当前环境是否需要运行此迁移""" - + # for adapter in self.ap.platform_cfg.data['platform-adapters']: # if adapter['adapter'] == 'wecom': # return False @@ -19,16 +19,18 @@ class WecomConfigMigration(migration.Migration): async def run(self): """执行迁移""" - self.ap.platform_cfg.data['platform-adapters'].append({ - "adapter": "wecom", - "enable": False, - "host": "0.0.0.0", - "port": 2290, - "corpid": "", - "secret": "", - "token": "", - "EncodingAESKey": "", - "contacts_secret": "" - }) + self.ap.platform_cfg.data['platform-adapters'].append( + { + 'adapter': 'wecom', + 'enable': False, + 'host': '0.0.0.0', + 'port': 2290, + 'corpid': '', + 'secret': '', + 'token': '', + 'EncodingAESKey': '', + 'contacts_secret': '', + } + ) await self.ap.platform_cfg.dump_config() diff --git a/pkg/core/migrations/m021_lark_config.py b/pkg/core/migrations/m021_lark_config.py index 49d9bb8f..04f29db4 100644 --- a/pkg/core/migrations/m021_lark_config.py +++ b/pkg/core/migrations/m021_lark_config.py @@ -3,13 +3,13 @@ from __future__ import annotations from .. import migration -@migration.migration_class("lark-config", 21) +@migration.migration_class('lark-config', 21) class LarkConfigMigration(migration.Migration): """迁移""" async def need_migrate(self) -> bool: """判断当前环境是否需要运行此迁移""" - + # for adapter in self.ap.platform_cfg.data['platform-adapters']: # if adapter['adapter'] == 'lark': # return False @@ -19,15 +19,17 @@ class LarkConfigMigration(migration.Migration): async def run(self): """执行迁移""" - self.ap.platform_cfg.data['platform-adapters'].append({ - "adapter": "lark", - "enable": False, - "app_id": "cli_abcdefgh", - "app_secret": "XXXXXXXXXX", - "bot_name": "LangBot", - "enable-webhook": False, - "port": 2285, - "encrypt-key": "xxxxxxxxx" - }) + self.ap.platform_cfg.data['platform-adapters'].append( + { + 'adapter': 'lark', + 'enable': False, + 'app_id': 'cli_abcdefgh', + 'app_secret': 'XXXXXXXXXX', + 'bot_name': 'LangBot', + 'enable-webhook': False, + 'port': 2285, + 'encrypt-key': 'xxxxxxxxx', + } + ) await self.ap.platform_cfg.dump_config() diff --git a/pkg/core/migrations/m022_lmstudio_config.py b/pkg/core/migrations/m022_lmstudio_config.py index 5506b37b..bffc6bb8 100644 --- a/pkg/core/migrations/m022_lmstudio_config.py +++ b/pkg/core/migrations/m022_lmstudio_config.py @@ -3,21 +3,21 @@ from __future__ import annotations from .. import migration -@migration.migration_class("lmstudio-config", 22) +@migration.migration_class('lmstudio-config', 22) class LmStudioConfigMigration(migration.Migration): """迁移""" async def need_migrate(self) -> bool: """判断当前环境是否需要运行此迁移""" - + return 'lmstudio-chat-completions' not in self.ap.provider_cfg.data['requester'] async def run(self): """执行迁移""" self.ap.provider_cfg.data['requester']['lmstudio-chat-completions'] = { - "base-url": "http://127.0.0.1:1234/v1", - "args": {}, - "timeout": 120 + 'base-url': 'http://127.0.0.1:1234/v1', + 'args': {}, + 'timeout': 120, } await self.ap.provider_cfg.dump_config() diff --git a/pkg/core/migrations/m023_siliconflow_config.py b/pkg/core/migrations/m023_siliconflow_config.py index a0e65c6a..fdf696eb 100644 --- a/pkg/core/migrations/m023_siliconflow_config.py +++ b/pkg/core/migrations/m023_siliconflow_config.py @@ -3,25 +3,25 @@ from __future__ import annotations from .. import migration -@migration.migration_class("siliconflow-config", 23) +@migration.migration_class('siliconflow-config', 23) class SiliconFlowConfigMigration(migration.Migration): """迁移""" async def need_migrate(self) -> bool: """判断当前环境是否需要运行此迁移""" - - return 'siliconflow-chat-completions' not in self.ap.provider_cfg.data['requester'] + + return ( + 'siliconflow-chat-completions' not in self.ap.provider_cfg.data['requester'] + ) async def run(self): """执行迁移""" - self.ap.provider_cfg.data['keys']['siliconflow'] = [ - "xxxxxxx" - ] + self.ap.provider_cfg.data['keys']['siliconflow'] = ['xxxxxxx'] self.ap.provider_cfg.data['requester']['siliconflow-chat-completions'] = { - "base-url": "https://api.siliconflow.cn/v1", - "args": {}, - "timeout": 120 + 'base-url': 'https://api.siliconflow.cn/v1', + 'args': {}, + 'timeout': 120, } await self.ap.provider_cfg.dump_config() diff --git a/pkg/core/migrations/m024_discord_config.py b/pkg/core/migrations/m024_discord_config.py index fcfac6e6..ebcae232 100644 --- a/pkg/core/migrations/m024_discord_config.py +++ b/pkg/core/migrations/m024_discord_config.py @@ -3,13 +3,13 @@ from __future__ import annotations from .. import migration -@migration.migration_class("discord-config", 24) +@migration.migration_class('discord-config', 24) class DiscordConfigMigration(migration.Migration): """迁移""" async def need_migrate(self) -> bool: """判断当前环境是否需要运行此迁移""" - + # for adapter in self.ap.platform_cfg.data['platform-adapters']: # if adapter['adapter'] == 'discord': # return False @@ -19,11 +19,13 @@ class DiscordConfigMigration(migration.Migration): async def run(self): """执行迁移""" - self.ap.platform_cfg.data['platform-adapters'].append({ - "adapter": "discord", - "enable": False, - "client_id": "1234567890", - "token": "XXXXXXXXXX" - }) + self.ap.platform_cfg.data['platform-adapters'].append( + { + 'adapter': 'discord', + 'enable': False, + 'client_id': '1234567890', + 'token': 'XXXXXXXXXX', + } + ) await self.ap.platform_cfg.dump_config() diff --git a/pkg/core/migrations/m025_gewechat_config.py b/pkg/core/migrations/m025_gewechat_config.py index 65b5c1d5..bb729854 100644 --- a/pkg/core/migrations/m025_gewechat_config.py +++ b/pkg/core/migrations/m025_gewechat_config.py @@ -3,13 +3,13 @@ from __future__ import annotations from .. import migration -@migration.migration_class("gewechat-config", 25) +@migration.migration_class('gewechat-config', 25) class GewechatConfigMigration(migration.Migration): """迁移""" async def need_migrate(self) -> bool: """判断当前环境是否需要运行此迁移""" - + # for adapter in self.ap.platform_cfg.data['platform-adapters']: # if adapter['adapter'] == 'gewechat': # return False @@ -19,15 +19,17 @@ class GewechatConfigMigration(migration.Migration): async def run(self): """执行迁移""" - self.ap.platform_cfg.data['platform-adapters'].append({ - "adapter": "gewechat", - "enable": False, - "gewechat_url": "http://your-gewechat-server:2531", - "gewechat_file_url": "http://your-gewechat-server:2532", - "port": 2286, - "callback_url": "http://your-callback-url:2286/gewechat/callback", - "app_id": "", - "token": "" - }) + self.ap.platform_cfg.data['platform-adapters'].append( + { + 'adapter': 'gewechat', + 'enable': False, + 'gewechat_url': 'http://your-gewechat-server:2531', + 'gewechat_file_url': 'http://your-gewechat-server:2532', + 'port': 2286, + 'callback_url': 'http://your-callback-url:2286/gewechat/callback', + 'app_id': '', + 'token': '', + } + ) await self.ap.platform_cfg.dump_config() diff --git a/pkg/core/migrations/m026_qqofficial_config.py b/pkg/core/migrations/m026_qqofficial_config.py index b4745806..90674341 100644 --- a/pkg/core/migrations/m026_qqofficial_config.py +++ b/pkg/core/migrations/m026_qqofficial_config.py @@ -3,13 +3,13 @@ from __future__ import annotations from .. import migration -@migration.migration_class("qqofficial-config", 26) +@migration.migration_class('qqofficial-config', 26) class QQOfficialConfigMigration(migration.Migration): """迁移""" async def need_migrate(self) -> bool: """判断当前环境是否需要运行此迁移""" - + # for adapter in self.ap.platform_cfg.data['platform-adapters']: # if adapter['adapter'] == 'qqofficial': # return False @@ -19,13 +19,15 @@ class QQOfficialConfigMigration(migration.Migration): async def run(self): """执行迁移""" - self.ap.platform_cfg.data['platform-adapters'].append({ - "adapter": "qqofficial", - "enable": False, - "appid": "", - "secret": "", - "port": 2284, - "token": "" - }) + self.ap.platform_cfg.data['platform-adapters'].append( + { + 'adapter': 'qqofficial', + 'enable': False, + 'appid': '', + 'secret': '', + 'port': 2284, + 'token': '', + } + ) await self.ap.platform_cfg.dump_config() diff --git a/pkg/core/migrations/m027_wx_official_account_config.py b/pkg/core/migrations/m027_wx_official_account_config.py index 5abaad87..7c5b0e35 100644 --- a/pkg/core/migrations/m027_wx_official_account_config.py +++ b/pkg/core/migrations/m027_wx_official_account_config.py @@ -3,13 +3,13 @@ from __future__ import annotations from .. import migration -@migration.migration_class("wx-official-account-config", 27) +@migration.migration_class('wx-official-account-config', 27) class WXOfficialAccountConfigMigration(migration.Migration): """迁移""" async def need_migrate(self) -> bool: """判断当前环境是否需要运行此迁移""" - + # for adapter in self.ap.platform_cfg.data['platform-adapters']: # if adapter['adapter'] == 'officialaccount': # return False @@ -19,15 +19,17 @@ class WXOfficialAccountConfigMigration(migration.Migration): async def run(self): """执行迁移""" - self.ap.platform_cfg.data['platform-adapters'].append({ - "adapter": "officialaccount", - "enable": False, - "token": "", - "EncodingAESKey": "", - "AppID": "", - "AppSecret": "", - "host": "0.0.0.0", - "port": 2287 - }) + self.ap.platform_cfg.data['platform-adapters'].append( + { + 'adapter': 'officialaccount', + 'enable': False, + 'token': '', + 'EncodingAESKey': '', + 'AppID': '', + 'AppSecret': '', + 'host': '0.0.0.0', + 'port': 2287, + } + ) await self.ap.platform_cfg.dump_config() diff --git a/pkg/core/migrations/m028_aliyun_requester_config.py b/pkg/core/migrations/m028_aliyun_requester_config.py index f28bc04f..8d80727a 100644 --- a/pkg/core/migrations/m028_aliyun_requester_config.py +++ b/pkg/core/migrations/m028_aliyun_requester_config.py @@ -3,25 +3,23 @@ from __future__ import annotations from .. import migration -@migration.migration_class("bailian-requester-config", 28) +@migration.migration_class('bailian-requester-config', 28) class BailianRequesterConfigMigration(migration.Migration): """迁移""" async def need_migrate(self) -> bool: """判断当前环境是否需要运行此迁移""" - + return 'bailian-chat-completions' not in self.ap.provider_cfg.data['requester'] async def run(self): """执行迁移""" - self.ap.provider_cfg.data['keys']['bailian'] = [ - "sk-xxxxxxx" - ] + self.ap.provider_cfg.data['keys']['bailian'] = ['sk-xxxxxxx'] self.ap.provider_cfg.data['requester']['bailian-chat-completions'] = { - "base-url": "https://dashscope.aliyuncs.com/compatible-mode/v1", - "args": {}, - "timeout": 120 + 'base-url': 'https://dashscope.aliyuncs.com/compatible-mode/v1', + 'args': {}, + 'timeout': 120, } await self.ap.provider_cfg.dump_config() diff --git a/pkg/core/migrations/m029_dashscope_app_api_config.py b/pkg/core/migrations/m029_dashscope_app_api_config.py index 3a069bac..5a61fe0d 100644 --- a/pkg/core/migrations/m029_dashscope_app_api_config.py +++ b/pkg/core/migrations/m029_dashscope_app_api_config.py @@ -3,7 +3,7 @@ from __future__ import annotations from .. import migration -@migration.migration_class("dashscope-app-api-config", 29) +@migration.migration_class('dashscope-app-api-config', 29) class DashscopeAppAPICfgMigration(migration.Migration): """迁移""" @@ -14,20 +14,14 @@ class DashscopeAppAPICfgMigration(migration.Migration): async def run(self): """执行迁移""" self.ap.provider_cfg.data['dashscope-app-api'] = { - "app-type": "agent", - "api-key": "sk-1234567890", - "agent": { - "app-id": "Your_app_id", - "references_quote": "参考资料来自:" + 'app-type': 'agent', + 'api-key': 'sk-1234567890', + 'agent': {'app-id': 'Your_app_id', 'references_quote': '参考资料来自:'}, + 'workflow': { + 'app-id': 'Your_app_id', + 'references_quote': '参考资料来自:', + 'biz_params': {'city': '北京', 'date': '2023-08-10'}, }, - "workflow": { - "app-id": "Your_app_id", - "references_quote": "参考资料来自:", - "biz_params": { - "city": "北京", - "date": "2023-08-10" - } - } } await self.ap.provider_cfg.dump_config() diff --git a/pkg/core/migrations/m030_lark_config_cmpl.py b/pkg/core/migrations/m030_lark_config_cmpl.py index e016af7b..37e8fabe 100644 --- a/pkg/core/migrations/m030_lark_config_cmpl.py +++ b/pkg/core/migrations/m030_lark_config_cmpl.py @@ -3,13 +3,13 @@ from __future__ import annotations from .. import migration -@migration.migration_class("lark-config-cmpl", 30) +@migration.migration_class('lark-config-cmpl', 30) class LarkConfigCmplMigration(migration.Migration): """迁移""" async def need_migrate(self) -> bool: """判断当前环境是否需要运行此迁移""" - + for adapter in self.ap.platform_cfg.data['platform-adapters']: if adapter['adapter'] == 'lark': if 'enable-webhook' not in adapter: @@ -26,6 +26,6 @@ class LarkConfigCmplMigration(migration.Migration): if 'port' not in adapter: adapter['port'] = 2285 if 'encrypt-key' not in adapter: - adapter['encrypt-key'] = "xxxxxxxxx" + adapter['encrypt-key'] = 'xxxxxxxxx' await self.ap.platform_cfg.dump_config() diff --git a/pkg/core/migrations/m031_dingtalk_config.py b/pkg/core/migrations/m031_dingtalk_config.py index 7dbc4735..22ba0bbf 100644 --- a/pkg/core/migrations/m031_dingtalk_config.py +++ b/pkg/core/migrations/m031_dingtalk_config.py @@ -3,13 +3,13 @@ from __future__ import annotations from .. import migration -@migration.migration_class("dingtalk-config", 31) +@migration.migration_class('dingtalk-config', 31) class DingTalkConfigMigration(migration.Migration): """迁移""" async def need_migrate(self) -> bool: """判断当前环境是否需要运行此迁移""" - + # for adapter in self.ap.platform_cfg.data['platform-adapters']: # if adapter['adapter'] == 'dingtalk': # return False @@ -19,13 +19,15 @@ class DingTalkConfigMigration(migration.Migration): async def run(self): """执行迁移""" - self.ap.platform_cfg.data['platform-adapters'].append({ - "adapter": "dingtalk", - "enable": False, - "client_id": "", - "client_secret": "", - "robot_code": "", - "robot_name": "" - }) + self.ap.platform_cfg.data['platform-adapters'].append( + { + 'adapter': 'dingtalk', + 'enable': False, + 'client_id': '', + 'client_secret': '', + 'robot_code': '', + 'robot_name': '', + } + ) await self.ap.platform_cfg.dump_config() diff --git a/pkg/core/migrations/m032_volcark_config.py b/pkg/core/migrations/m032_volcark_config.py index a07e5686..ae8feb52 100644 --- a/pkg/core/migrations/m032_volcark_config.py +++ b/pkg/core/migrations/m032_volcark_config.py @@ -3,25 +3,23 @@ from __future__ import annotations from .. import migration -@migration.migration_class("volcark-requester-config", 32) +@migration.migration_class('volcark-requester-config', 32) class VolcArkRequesterConfigMigration(migration.Migration): """迁移""" async def need_migrate(self) -> bool: """判断当前环境是否需要运行此迁移""" - + return 'volcark-chat-completions' not in self.ap.provider_cfg.data['requester'] async def run(self): """执行迁移""" - self.ap.provider_cfg.data['keys']['volcark'] = [ - "xxxxxxxx" - ] + self.ap.provider_cfg.data['keys']['volcark'] = ['xxxxxxxx'] self.ap.provider_cfg.data['requester']['volcark-chat-completions'] = { - "base-url": "https://ark.cn-beijing.volces.com/api/v3", - "args": {}, - "timeout": 120 + 'base-url': 'https://ark.cn-beijing.volces.com/api/v3', + 'args': {}, + 'timeout': 120, } await self.ap.provider_cfg.dump_config() diff --git a/pkg/core/migrations/m033_dify_thinking_config.py b/pkg/core/migrations/m033_dify_thinking_config.py index 1f663b46..d25a4aad 100644 --- a/pkg/core/migrations/m033_dify_thinking_config.py +++ b/pkg/core/migrations/m033_dify_thinking_config.py @@ -3,24 +3,27 @@ from __future__ import annotations from .. import migration -@migration.migration_class("dify-thinking-config", 33) +@migration.migration_class('dify-thinking-config', 33) class DifyThinkingConfigMigration(migration.Migration): """迁移""" async def need_migrate(self) -> bool: """判断当前环境是否需要运行此迁移""" - - if 'options' not in self.ap.provider_cfg.data["dify-service-api"]: + + if 'options' not in self.ap.provider_cfg.data['dify-service-api']: return True - if 'convert-thinking-tips' not in self.ap.provider_cfg.data["dify-service-api"]["options"]: + if ( + 'convert-thinking-tips' + not in self.ap.provider_cfg.data['dify-service-api']['options'] + ): return True return False - + async def run(self): """执行迁移""" - self.ap.provider_cfg.data["dify-service-api"]["options"] = { - "convert-thinking-tips": "plain" + self.ap.provider_cfg.data['dify-service-api']['options'] = { + 'convert-thinking-tips': 'plain' } await self.ap.provider_cfg.dump_config() diff --git a/pkg/core/migrations/m034_gewechat_file_url_config.py b/pkg/core/migrations/m034_gewechat_file_url_config.py index 44bbd65e..8c3e0a83 100644 --- a/pkg/core/migrations/m034_gewechat_file_url_config.py +++ b/pkg/core/migrations/m034_gewechat_file_url_config.py @@ -5,7 +5,7 @@ from urllib.parse import urlparse from .. import migration -@migration.migration_class("gewechat-file-url-config", 34) +@migration.migration_class('gewechat-file-url-config', 34) class GewechatFileUrlConfigMigration(migration.Migration): """迁移""" @@ -24,6 +24,8 @@ class GewechatFileUrlConfigMigration(migration.Migration): if adapter['adapter'] == 'gewechat': if 'gewechat_file_url' not in adapter: parsed_url = urlparse(adapter['gewechat_url']) - adapter['gewechat_file_url'] = f"{parsed_url.scheme}://{parsed_url.hostname}:2532" + adapter['gewechat_file_url'] = ( + f'{parsed_url.scheme}://{parsed_url.hostname}:2532' + ) await self.ap.platform_cfg.dump_config() diff --git a/pkg/core/migrations/m035_wxoa_mode.py b/pkg/core/migrations/m035_wxoa_mode.py index ce0ce628..6b675e30 100644 --- a/pkg/core/migrations/m035_wxoa_mode.py +++ b/pkg/core/migrations/m035_wxoa_mode.py @@ -3,7 +3,7 @@ from __future__ import annotations from .. import migration -@migration.migration_class("wxoa-mode", 35) +@migration.migration_class('wxoa-mode', 35) class WxoaModeMigration(migration.Migration): """迁移""" diff --git a/pkg/core/migrations/m036_wxoa_loading_message.py b/pkg/core/migrations/m036_wxoa_loading_message.py index 682be435..29ecba20 100644 --- a/pkg/core/migrations/m036_wxoa_loading_message.py +++ b/pkg/core/migrations/m036_wxoa_loading_message.py @@ -3,7 +3,7 @@ from __future__ import annotations from .. import migration -@migration.migration_class("wxoa-loading-message", 36) +@migration.migration_class('wxoa-loading-message', 36) class WxoaLoadingMessageMigration(migration.Migration): """迁移""" diff --git a/pkg/core/migrations/m037_mcp_config.py b/pkg/core/migrations/m037_mcp_config.py index f045f0ff..3752193e 100644 --- a/pkg/core/migrations/m037_mcp_config.py +++ b/pkg/core/migrations/m037_mcp_config.py @@ -3,7 +3,7 @@ from __future__ import annotations from .. import migration -@migration.migration_class("mcp-config", 37) +@migration.migration_class('mcp-config', 37) class MCPConfigMigration(migration.Migration): """迁移""" @@ -13,8 +13,6 @@ class MCPConfigMigration(migration.Migration): async def run(self): """执行迁移""" - self.ap.provider_cfg.data['mcp'] = { - "servers": [] - } + self.ap.provider_cfg.data['mcp'] = {'servers': []} await self.ap.provider_cfg.dump_config() diff --git a/pkg/core/note.py b/pkg/core/note.py index 6ffbff51..07171581 100644 --- a/pkg/core/note.py +++ b/pkg/core/note.py @@ -7,9 +7,10 @@ from . import app preregistered_notes: list[typing.Type[LaunchNote]] = [] + def note_class(name: str, number: int): - """注册一个启动信息 - """ + """注册一个启动信息""" + def decorator(cls: typing.Type[LaunchNote]) -> typing.Type[LaunchNote]: cls.name = name cls.number = number @@ -20,8 +21,8 @@ def note_class(name: str, number: int): class LaunchNote(abc.ABC): - """启动信息 - """ + """启动信息""" + name: str number: int @@ -33,12 +34,10 @@ class LaunchNote(abc.ABC): @abc.abstractmethod async def need_show(self) -> bool: - """判断当前环境是否需要显示此启动信息 - """ + """判断当前环境是否需要显示此启动信息""" pass @abc.abstractmethod async def yield_note(self) -> typing.AsyncGenerator[typing.Tuple[str, int], None]: - """生成启动信息 - """ + """生成启动信息""" pass diff --git a/pkg/core/notes/n001_classic_msgs.py b/pkg/core/notes/n001_classic_msgs.py index bdc5c44e..3f3bd8e0 100644 --- a/pkg/core/notes/n001_classic_msgs.py +++ b/pkg/core/notes/n001_classic_msgs.py @@ -2,19 +2,17 @@ from __future__ import annotations import typing -from .. import note, app +from .. import note -@note.note_class("ClassicNotes", 1) +@note.note_class('ClassicNotes', 1) class ClassicNotes(note.LaunchNote): - """经典启动信息 - """ + """经典启动信息""" async def need_show(self) -> bool: return True async def yield_note(self) -> typing.AsyncGenerator[typing.Tuple[str, int], None]: - yield await self.ap.ann_mgr.show_announcements() - yield await self.ap.ver_mgr.show_version_update() \ No newline at end of file + yield await self.ap.ver_mgr.show_version_update() diff --git a/pkg/core/notes/n002_selection_mode_on_windows.py b/pkg/core/notes/n002_selection_mode_on_windows.py index 961d697d..23bff24a 100644 --- a/pkg/core/notes/n002_selection_mode_on_windows.py +++ b/pkg/core/notes/n002_selection_mode_on_windows.py @@ -2,20 +2,20 @@ from __future__ import annotations import typing import os -import sys import logging -from .. import note, app +from .. import note -@note.note_class("SelectionModeOnWindows", 2) +@note.note_class('SelectionModeOnWindows', 2) class SelectionModeOnWindows(note.LaunchNote): - """Windows 上的选择模式提示信息 - """ + """Windows 上的选择模式提示信息""" async def need_show(self) -> bool: return os.name == 'nt' async def yield_note(self) -> typing.AsyncGenerator[typing.Tuple[str, int], None]: - - yield """您正在使用 Windows 系统,若窗口左上角显示处于”选择“模式,程序将被暂停运行,此时请右键窗口中空白区域退出选择模式。""", logging.INFO + yield ( + """您正在使用 Windows 系统,若窗口左上角显示处于”选择“模式,程序将被暂停运行,此时请右键窗口中空白区域退出选择模式。""", + logging.INFO, + ) diff --git a/pkg/core/notes/n003_print_version.py b/pkg/core/notes/n003_print_version.py index 91208fdf..18eebf4f 100644 --- a/pkg/core/notes/n003_print_version.py +++ b/pkg/core/notes/n003_print_version.py @@ -1,21 +1,17 @@ from __future__ import annotations import typing -import os -import sys import logging -from .. import note, app +from .. import note -@note.note_class("PrintVersion", 3) +@note.note_class('PrintVersion', 3) class PrintVersion(note.LaunchNote): - """Print Version Information - """ + """Print Version Information""" async def need_show(self) -> bool: return True async def yield_note(self) -> typing.AsyncGenerator[typing.Tuple[str, int], None]: - - yield f"Current Version: {self.ap.ver_mgr.get_current_version()}", logging.INFO + yield f'Current Version: {self.ap.ver_mgr.get_current_version()}', logging.INFO diff --git a/pkg/core/stage.py b/pkg/core/stage.py index f1c65295..220c474d 100644 --- a/pkg/core/stage.py +++ b/pkg/core/stage.py @@ -12,9 +12,8 @@ preregistered_stages: dict[str, typing.Type[BootingStage]] = {} 当前阶段暂不支持扩展 """ -def stage_class( - name: str -): + +def stage_class(name: str): def decorator(cls: typing.Type[BootingStage]) -> typing.Type[BootingStage]: preregistered_stages[name] = cls return cls @@ -23,12 +22,11 @@ def stage_class( class BootingStage(abc.ABC): - """启动阶段 - """ + """启动阶段""" + name: str = None @abc.abstractmethod async def run(self, ap: app.Application): - """启动 - """ + """启动""" pass diff --git a/pkg/core/stages/build_app.py b/pkg/core/stages/build_app.py index 7d62e9c9..5dee9386 100644 --- a/pkg/core/stages/build_app.py +++ b/pkg/core/stages/build_app.py @@ -1,6 +1,5 @@ from __future__ import annotations -import sys from .. import stage, app from ...utils import version, proxy, announce, platform @@ -24,26 +23,22 @@ from ...utils import logcache from .. import taskmgr -@stage.stage_class("BuildAppStage") +@stage.stage_class('BuildAppStage') class BuildAppStage(stage.BootingStage): - """构建应用阶段 - """ + """构建应用阶段""" async def run(self, ap: app.Application): - """构建app对象的各个组件对象并初始化 - """ + """构建app对象的各个组件对象并初始化""" ap.task_mgr = taskmgr.AsyncTaskManager(ap) discover = discover_engine.ComponentDiscoveryEngine(ap) - discover.discover_blueprint( - "components.yaml" - ) + discover.discover_blueprint('components.yaml') ap.discover = discover proxy_mgr = proxy.ProxyManager(ap) await proxy_mgr.initialize() ap.proxy_mgr = proxy_mgr - + ver_mgr = version.VersionManager(ap) await ver_mgr.initialize() ap.ver_mgr = ver_mgr @@ -52,14 +47,14 @@ class BuildAppStage(stage.BootingStage): ap, backend_url=ap.instance_config.data['telemetry']['url'], basic_info={ - "host_id": identifier.identifier["host_id"], - "instance_id": identifier.identifier["instance_id"], - "semantic_version": ver_mgr.get_current_version(), - "platform": platform.get_platform(), + 'host_id': identifier.identifier['host_id'], + 'instance_id': identifier.identifier['instance_id'], + 'semantic_version': ver_mgr.get_current_version(), + 'platform': platform.get_platform(), }, runtime_info={ - "admin_id": "{}".format(ap.instance_config.data["admins"]), - "msg_source": str([]), + 'admin_id': '{}'.format(ap.instance_config.data['admins']), + 'msg_source': str([]), }, ) ap.ctr_mgr = center_v2_api diff --git a/pkg/core/stages/genkeys.py b/pkg/core/stages/genkeys.py index 843f1532..c24ebd70 100644 --- a/pkg/core/stages/genkeys.py +++ b/pkg/core/stages/genkeys.py @@ -1,20 +1,17 @@ from __future__ import annotations import secrets -import os from .. import stage, app -@stage.stage_class("GenKeysStage") +@stage.stage_class('GenKeysStage') class GenKeysStage(stage.BootingStage): - """生成密钥阶段 - """ + """生成密钥阶段""" async def run(self, ap: app.Application): - """启动 - """ - + """启动""" + if not ap.instance_config.data['system']['jwt']['secret']: ap.instance_config.data['system']['jwt']['secret'] = secrets.token_hex(16) await ap.instance_config.dump_config() diff --git a/pkg/core/stages/load_config.py b/pkg/core/stages/load_config.py index edfe5915..ac2f0c37 100644 --- a/pkg/core/stages/load_config.py +++ b/pkg/core/stages/load_config.py @@ -7,45 +7,80 @@ from .. import stage, app from ..bootutils import config -@stage.stage_class("LoadConfigStage") +@stage.stage_class('LoadConfigStage') class LoadConfigStage(stage.BootingStage): - """加载配置文件阶段 - """ + """加载配置文件阶段""" async def run(self, ap: app.Application): - """启动 - """ + """启动""" # ======= deprecated ======= - if os.path.exists("data/config/command.json"): - ap.command_cfg = await config.load_json_config("data/config/command.json", "templates/legacy/command.json", completion=False) + if os.path.exists('data/config/command.json'): + ap.command_cfg = await config.load_json_config( + 'data/config/command.json', + 'templates/legacy/command.json', + completion=False, + ) - if os.path.exists("data/config/pipeline.json"): - ap.pipeline_cfg = await config.load_json_config("data/config/pipeline.json", "templates/legacy/pipeline.json", completion=False) + if os.path.exists('data/config/pipeline.json'): + ap.pipeline_cfg = await config.load_json_config( + 'data/config/pipeline.json', + 'templates/legacy/pipeline.json', + completion=False, + ) - if os.path.exists("data/config/platform.json"): - ap.platform_cfg = await config.load_json_config("data/config/platform.json", "templates/legacy/platform.json", completion=False) + if os.path.exists('data/config/platform.json'): + ap.platform_cfg = await config.load_json_config( + 'data/config/platform.json', + 'templates/legacy/platform.json', + completion=False, + ) - if os.path.exists("data/config/provider.json"): - ap.provider_cfg = await config.load_json_config("data/config/provider.json", "templates/legacy/provider.json", completion=False) + if os.path.exists('data/config/provider.json'): + ap.provider_cfg = await config.load_json_config( + 'data/config/provider.json', + 'templates/legacy/provider.json', + completion=False, + ) - if os.path.exists("data/config/system.json"): - ap.system_cfg = await config.load_json_config("data/config/system.json", "templates/legacy/system.json", completion=False) + if os.path.exists('data/config/system.json'): + ap.system_cfg = await config.load_json_config( + 'data/config/system.json', + 'templates/legacy/system.json', + completion=False, + ) - if os.path.exists("data/metadata/instance-secret.json"): - ap.instance_secret_meta = await config.load_json_config("data/metadata/instance-secret.json", template_data={ - 'jwt_secret': secrets.token_hex(16) - }) + if os.path.exists('data/metadata/instance-secret.json'): + ap.instance_secret_meta = await config.load_json_config( + 'data/metadata/instance-secret.json', + template_data={'jwt_secret': secrets.token_hex(16)}, + ) await ap.instance_secret_meta.dump_config() # ======= deprecated ======= - ap.instance_config = await config.load_yaml_config("data/config.yaml", "templates/config.yaml", completion=False) + ap.instance_config = await config.load_yaml_config( + 'data/config.yaml', 'templates/config.yaml', completion=False + ) await ap.instance_config.dump_config() - ap.sensitive_meta = await config.load_json_config("data/metadata/sensitive-words.json", "templates/metadata/sensitive-words.json") + ap.sensitive_meta = await config.load_json_config( + 'data/metadata/sensitive-words.json', + 'templates/metadata/sensitive-words.json', + ) await ap.sensitive_meta.dump_config() - ap.pipeline_config_meta_trigger = await config.load_yaml_config("templates/metadata/pipeline/trigger.yaml", "templates/metadata/pipeline/trigger.yaml") - ap.pipeline_config_meta_safety = await config.load_yaml_config("templates/metadata/pipeline/safety.yaml", "templates/metadata/pipeline/safety.yaml") - ap.pipeline_config_meta_ai = await config.load_yaml_config("templates/metadata/pipeline/ai.yaml", "templates/metadata/pipeline/ai.yaml") - ap.pipeline_config_meta_output = await config.load_yaml_config("templates/metadata/pipeline/output.yaml", "templates/metadata/pipeline/output.yaml") + ap.pipeline_config_meta_trigger = await config.load_yaml_config( + 'templates/metadata/pipeline/trigger.yaml', + 'templates/metadata/pipeline/trigger.yaml', + ) + ap.pipeline_config_meta_safety = await config.load_yaml_config( + 'templates/metadata/pipeline/safety.yaml', + 'templates/metadata/pipeline/safety.yaml', + ) + ap.pipeline_config_meta_ai = await config.load_yaml_config( + 'templates/metadata/pipeline/ai.yaml', 'templates/metadata/pipeline/ai.yaml' + ) + ap.pipeline_config_meta_output = await config.load_yaml_config( + 'templates/metadata/pipeline/output.yaml', + 'templates/metadata/pipeline/output.yaml', + ) diff --git a/pkg/core/stages/migrate.py b/pkg/core/stages/migrate.py index e902431a..02b03256 100644 --- a/pkg/core/stages/migrate.py +++ b/pkg/core/stages/migrate.py @@ -1,37 +1,30 @@ from __future__ import annotations -import importlib -import os from .. import stage, app from .. import migration -from ..migrations import m001_sensitive_word_migration, m002_openai_config_migration, m003_anthropic_requester_cfg_completion, m004_moonshot_cfg_completion -from ..migrations import m005_deepseek_cfg_completion, m006_vision_config, m007_qcg_center_url, m008_ad_fixwin_config_migrate, m009_msg_truncator_cfg -from ..migrations import m010_ollama_requester_config, m011_command_prefix_config, m012_runner_config, m013_http_api_config, m014_force_delay_config -from ..migrations import m015_gitee_ai_config, m016_dify_service_api, m017_dify_api_timeout_params, m018_xai_config, m019_zhipuai_config -from ..migrations import m020_wecom_config, m021_lark_config, m022_lmstudio_config, m023_siliconflow_config, m024_discord_config, m025_gewechat_config -from ..migrations import m026_qqofficial_config, m027_wx_official_account_config, m028_aliyun_requester_config -from ..migrations import m029_dashscope_app_api_config, m030_lark_config_cmpl, m031_dingtalk_config, m032_volcark_config -from ..migrations import m033_dify_thinking_config, m034_gewechat_file_url_config, m035_wxoa_mode, m036_wxoa_loading_message -from ..migrations import m037_mcp_config +from ...utils import importutil +from .. import migrations + +importutil.import_modules_in_pkg(migrations) -@stage.stage_class("MigrationStage") +@stage.stage_class('MigrationStage') class MigrationStage(stage.BootingStage): - """迁移阶段 - """ + """迁移阶段""" async def run(self, ap: app.Application): - """启动 - """ + """启动""" - if any([ - ap.command_cfg is None, - ap.pipeline_cfg is None, - ap.platform_cfg is None, - ap.provider_cfg is None, - ap.system_cfg is None, - ]): # only run migration when version is 3.x + if any( + [ + ap.command_cfg is None, + ap.pipeline_cfg is None, + ap.platform_cfg is None, + ap.provider_cfg is None, + ap.system_cfg is None, + ] + ): # only run migration when version is 3.x return migrations = migration.preregistered_migrations diff --git a/pkg/core/stages/setup_logger.py b/pkg/core/stages/setup_logger.py index 8f385d1f..0c630175 100644 --- a/pkg/core/stages/setup_logger.py +++ b/pkg/core/stages/setup_logger.py @@ -1,8 +1,6 @@ from __future__ import annotations import logging -import asyncio -from datetime import datetime from .. import stage, app from ..bootutils import log @@ -12,6 +10,7 @@ class PersistenceHandler(logging.Handler, object): """ 保存日志到数据库 """ + ap: app.Application def __init__(self, name, ap: app.Application): @@ -28,19 +27,17 @@ class PersistenceHandler(logging.Handler, object): msg = self.format(record) if self.ap.log_cache is not None: self.ap.log_cache.add_log(msg) - + except Exception: self.handleError(record) -@stage.stage_class("SetupLoggerStage") +@stage.stage_class('SetupLoggerStage') class SetupLoggerStage(stage.BootingStage): - """设置日志器阶段 - """ + """设置日志器阶段""" async def run(self, ap: app.Application): - """启动 - """ + """启动""" persistence_handler = PersistenceHandler('LoggerHandler', ap) extra_handlers = [] diff --git a/pkg/core/stages/show_notes.py b/pkg/core/stages/show_notes.py index 63d8f580..e7c98b42 100644 --- a/pkg/core/stages/show_notes.py +++ b/pkg/core/stages/show_notes.py @@ -1,16 +1,18 @@ from __future__ import annotations from .. import stage, app, note -from ..notes import n001_classic_msgs, n002_selection_mode_on_windows, n003_print_version +from ...utils import importutil + +from .. import notes + +importutil.import_modules_in_pkg(notes) -@stage.stage_class("ShowNotesStage") +@stage.stage_class('ShowNotesStage') class ShowNotesStage(stage.BootingStage): - """显示启动信息阶段 - """ + """显示启动信息阶段""" async def run(self, ap: app.Application): - # 排序 note.preregistered_notes.sort(key=lambda x: x.number) @@ -24,5 +26,5 @@ class ShowNotesStage(stage.BootingStage): msg, level = ret if msg: ap.logger.log(level, msg) - except Exception as e: + except Exception: continue diff --git a/pkg/core/taskmgr.py b/pkg/core/taskmgr.py index d5887019..ae2394cf 100644 --- a/pkg/core/taskmgr.py +++ b/pkg/core/taskmgr.py @@ -3,7 +3,6 @@ from __future__ import annotations import asyncio import typing import datetime -import traceback from . import app from . import entities as core_entities @@ -19,11 +18,11 @@ class TaskContext: """记录日志""" def __init__(self): - self.current_action = "default" - self.log = "" + self.current_action = 'default' + self.log = '' def _log(self, msg: str): - self.log += msg + "\n" + self.log += msg + '\n' def set_current_action(self, action: str): self.current_action = action @@ -37,16 +36,16 @@ class TaskContext: self.set_current_action(action) self._log( - f"{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')} | {self.current_action} | {msg}" + f'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")} | {self.current_action} | {msg}' ) def to_dict(self) -> dict: - return {"current_action": self.current_action, "log": self.log} - + return {'current_action': self.current_action, 'log': self.log} + @staticmethod def new() -> TaskContext: return TaskContext() - + @staticmethod def placeholder() -> TaskContext: global placeholder_context @@ -69,16 +68,16 @@ class TaskWrapper: id: int """任务ID""" - task_type: str = "system" # 任务类型: system 或 user + task_type: str = 'system' # 任务类型: system 或 user """任务类型""" - kind: str = "system_task" # 由发起者确定任务种类,通常同质化的任务种类相同 + kind: str = 'system_task' # 由发起者确定任务种类,通常同质化的任务种类相同 """任务种类""" - name: str = "" + name: str = '' """任务唯一名称""" - label: str = "" + label: str = '' """任务显示名称""" task_context: TaskContext @@ -100,12 +99,14 @@ class TaskWrapper: self, ap: app.Application, coro: typing.Coroutine, - task_type: str = "system", - kind: str = "system_task", - name: str = "", - label: str = "", + task_type: str = 'system', + kind: str = 'system_task', + name: str = '', + label: str = '', context: TaskContext = None, - scopes: list[core_entities.LifecycleControlScope] = [core_entities.LifecycleControlScope.APPLICATION], + scopes: list[core_entities.LifecycleControlScope] = [ + core_entities.LifecycleControlScope.APPLICATION + ], ): self.id = TaskWrapper._id_index TaskWrapper._id_index += 1 @@ -115,7 +116,7 @@ class TaskWrapper: self.task_type = task_type self.kind = kind self.name = name - self.label = label if label != "" else name + self.label = label if label != '' else name self.task.set_name(name) self.scopes = scopes @@ -125,43 +126,46 @@ class TaskWrapper: if self.task_stack is None: self.task_stack = self.task.get_stack() return exception - except: + except Exception: return None def assume_result(self): try: return self.task.result() - except: + except Exception: return None def to_dict(self) -> dict: - exception_traceback = None if self.assume_exception() is not None: exception_traceback = 'Traceback (most recent call last):\n' for frame in self.task_stack: - exception_traceback += f" File \"{frame.f_code.co_filename}\", line {frame.f_lineno}, in {frame.f_code.co_name}\n" + exception_traceback += f' File "{frame.f_code.co_filename}", line {frame.f_lineno}, in {frame.f_code.co_name}\n' - exception_traceback += f" {self.assume_exception().__str__()}\n" + exception_traceback += f' {self.assume_exception().__str__()}\n' return { - "id": self.id, - "task_type": self.task_type, - "kind": self.kind, - "name": self.name, - "label": self.label, - "scopes": [scope.value for scope in self.scopes], - "task_context": self.task_context.to_dict(), - "runtime": { - "done": self.task.done(), - "state": self.task._state, - "exception": self.assume_exception().__str__() if self.assume_exception() is not None else None, - "exception_traceback": exception_traceback, - "result": self.assume_result().__str__() if self.assume_result() is not None else None, + 'id': self.id, + 'task_type': self.task_type, + 'kind': self.kind, + 'name': self.name, + 'label': self.label, + 'scopes': [scope.value for scope in self.scopes], + 'task_context': self.task_context.to_dict(), + 'runtime': { + 'done': self.task.done(), + 'state': self.task._state, + 'exception': self.assume_exception().__str__() + if self.assume_exception() is not None + else None, + 'exception_traceback': exception_traceback, + 'result': self.assume_result().__str__() + if self.assume_result() is not None + else None, }, } - + def cancel(self): self.task.cancel() @@ -182,27 +186,33 @@ class AsyncTaskManager: def create_task( self, coro: typing.Coroutine, - task_type: str = "system", - kind: str = "system-task", - name: str = "", - label: str = "", + task_type: str = 'system', + kind: str = 'system-task', + name: str = '', + label: str = '', context: TaskContext = None, - scopes: list[core_entities.LifecycleControlScope] = [core_entities.LifecycleControlScope.APPLICATION], + scopes: list[core_entities.LifecycleControlScope] = [ + core_entities.LifecycleControlScope.APPLICATION + ], ) -> TaskWrapper: - wrapper = TaskWrapper(self.ap, coro, task_type, kind, name, label, context, scopes) + wrapper = TaskWrapper( + self.ap, coro, task_type, kind, name, label, context, scopes + ) self.tasks.append(wrapper) return wrapper def create_user_task( self, coro: typing.Coroutine, - kind: str = "user-task", - name: str = "", - label: str = "", + kind: str = 'user-task', + name: str = '', + label: str = '', context: TaskContext = None, - scopes: list[core_entities.LifecycleControlScope] = [core_entities.LifecycleControlScope.APPLICATION], + scopes: list[core_entities.LifecycleControlScope] = [ + core_entities.LifecycleControlScope.APPLICATION + ], ) -> TaskWrapper: - return self.create_task(coro, "user", kind, name, label, context, scopes) + return self.create_task(coro, 'user', kind, name, label, context, scopes) async def wait_all(self): await asyncio.gather(*[t.task for t in self.tasks], return_exceptions=True) @@ -215,12 +225,12 @@ class AsyncTaskManager: type: str = None, ) -> dict: return { - "tasks": [ + 'tasks': [ t.to_dict() for t in self.tasks if type is None or t.task_type == type ], - "id_index": TaskWrapper._id_index, + 'id_index': TaskWrapper._id_index, } - + def get_task_by_id(self, id: int) -> TaskWrapper | None: for t in self.tasks: if t.id == id: @@ -229,9 +239,7 @@ class AsyncTaskManager: def cancel_by_scope(self, scope: core_entities.LifecycleControlScope): for wrapper in self.tasks: - if not wrapper.task.done() and scope in wrapper.scopes: - wrapper.task.cancel() def cancel_task(self, task_id: int): diff --git a/pkg/discover/engine.py b/pkg/discover/engine.py index e79a97ad..be23a4ac 100644 --- a/pkg/discover/engine.py +++ b/pkg/discover/engine.py @@ -3,8 +3,6 @@ from __future__ import annotations import typing import importlib import os -import inspect -import mimetypes import yaml import pydantic @@ -61,11 +59,9 @@ class Metadata(pydantic.BaseModel): def __init__(self, **kwargs): super().__init__(**kwargs) - + if self.description is None: - self.description = I18nString( - en_US='' - ) + self.description = I18nString(en_US='') if self.icon is None: self.icon = '' @@ -118,47 +114,60 @@ class Component(pydantic.BaseModel): _execution: Execution """组件执行""" - def __init__(self, owner: str, manifest: typing.Dict[str, typing.Any], rel_path: str): + def __init__( + self, owner: str, manifest: typing.Dict[str, typing.Any], rel_path: str + ): super().__init__( owner=owner, manifest=manifest, rel_path=rel_path, - rel_dir=os.path.dirname(rel_path) + rel_dir=os.path.dirname(rel_path), ) self._metadata = Metadata(**manifest['metadata']) self._spec = manifest['spec'] - self._execution = Execution(**manifest['execution']) if 'execution' in manifest else None + self._execution = ( + Execution(**manifest['execution']) if 'execution' in manifest else None + ) @classmethod def is_component_manifest(cls, manifest: typing.Dict[str, typing.Any]) -> bool: """判断是否为组件清单""" - return 'apiVersion' in manifest and 'kind' in manifest and 'metadata' in manifest and 'spec' in manifest + return ( + 'apiVersion' in manifest + and 'kind' in manifest + and 'metadata' in manifest + and 'spec' in manifest + ) @property def kind(self) -> str: """组件类型""" return self.manifest['kind'] - + @property def metadata(self) -> Metadata: """组件元数据""" return self._metadata - + @property def spec(self) -> typing.Dict[str, typing.Any]: """组件规格""" return self._spec - + @property def execution(self) -> Execution: """组件可执行文件信息""" return self._execution - + @property def icon_rel_path(self) -> str: """图标相对路径""" - return os.path.join(self.rel_dir, self.metadata.icon) if self.metadata.icon is not None and self.metadata.icon.strip() != '' else None - + return ( + os.path.join(self.rel_dir, self.metadata.icon) + if self.metadata.icon is not None and self.metadata.icon.strip() != '' + else None + ) + def get_python_component_class(self) -> typing.Type[typing.Any]: """获取Python组件类""" module_path = os.path.join(self.rel_dir, self.execution.python.path) @@ -167,7 +176,7 @@ class Component(pydantic.BaseModel): module_path = module_path.replace('/', '.').replace('\\', '.') module = importlib.import_module(module_path) return getattr(module, self.execution.python.attr) - + def to_plain_dict(self) -> dict: """转换为平铺字典""" return { @@ -175,7 +184,7 @@ class Component(pydantic.BaseModel): 'label': self.metadata.label.to_dict(), 'description': self.metadata.description.to_dict(), 'icon': self.metadata.icon, - 'spec': self.spec + 'spec': self.spec, } @@ -191,24 +200,28 @@ class ComponentDiscoveryEngine: def __init__(self, ap: app.Application): self.ap = ap - def load_component_manifest(self, path: str, owner: str = 'builtin', no_save: bool = False) -> Component | None: + def load_component_manifest( + self, path: str, owner: str = 'builtin', no_save: bool = False + ) -> Component | None: """加载组件清单""" with open(path, 'r', encoding='utf-8') as f: manifest = yaml.safe_load(f) if not Component.is_component_manifest(manifest): return None - comp = Component( - owner=owner, - manifest=manifest, - rel_path=path - ) + comp = Component(owner=owner, manifest=manifest, rel_path=path) if not no_save: if comp.kind not in self.components: self.components[comp.kind] = [] self.components[comp.kind].append(comp) return comp - - def load_component_manifests_in_dir(self, path: str, owner: str = 'builtin', no_save: bool = False, max_depth: int = 1) -> typing.List[Component]: + + def load_component_manifests_in_dir( + self, + path: str, + owner: str = 'builtin', + no_save: bool = False, + max_depth: int = 1, + ) -> typing.List[Component]: """加载目录中的组件清单""" components: typing.List[Component] = [] @@ -216,17 +229,25 @@ class ComponentDiscoveryEngine: if depth > max_depth: return for file in os.listdir(path): - if (not os.path.isdir(os.path.join(path, file))) and (file.endswith('.yaml') or file.endswith('.yml')): - comp = self.load_component_manifest(os.path.join(path, file), owner, no_save) + if (not os.path.isdir(os.path.join(path, file))) and ( + file.endswith('.yaml') or file.endswith('.yml') + ): + comp = self.load_component_manifest( + os.path.join(path, file), owner, no_save + ) if comp is not None: components.append(comp) elif os.path.isdir(os.path.join(path, file)): - recursive_load_component_manifests_in_dir(os.path.join(path, file), depth + 1) + recursive_load_component_manifests_in_dir( + os.path.join(path, file), depth + 1 + ) recursive_load_component_manifests_in_dir(path) return components - - def load_blueprint_comp_group(self, group: dict, owner: str = 'builtin', no_save: bool = False) -> typing.List[Component]: + + def load_blueprint_comp_group( + self, group: dict, owner: str = 'builtin', no_save: bool = False + ) -> typing.List[Component]: """加载蓝图组件组""" components: typing.List[Component] = [] if 'fromFiles' in group: @@ -238,12 +259,18 @@ class ComponentDiscoveryEngine: for dir in group['fromDirs']: path = dir['path'] max_depth = dir['maxDepth'] if 'maxDepth' in dir else 1 - components.extend(self.load_component_manifests_in_dir(path, owner, no_save, max_depth)) + components.extend( + self.load_component_manifests_in_dir( + path, owner, no_save, max_depth + ) + ) return components def discover_blueprint(self, blueprint_manifest_path: str, owner: str = 'builtin'): """发现蓝图""" - blueprint_manifest = self.load_component_manifest(blueprint_manifest_path, owner, no_save=True) + blueprint_manifest = self.load_component_manifest( + blueprint_manifest_path, owner, no_save=True + ) if blueprint_manifest is None: raise ValueError(f'Invalid blueprint manifest: {blueprint_manifest_path}') assert blueprint_manifest.kind == 'Blueprint', '`Kind` must be `Blueprint`' @@ -251,13 +278,15 @@ class ComponentDiscoveryEngine: # load ComponentTemplate first if 'ComponentTemplate' in blueprint_manifest.spec['components']: - components['ComponentTemplate'] = self.load_blueprint_comp_group(blueprint_manifest.spec['components']['ComponentTemplate'], owner) + components['ComponentTemplate'] = self.load_blueprint_comp_group( + blueprint_manifest.spec['components']['ComponentTemplate'], owner + ) for name, component in blueprint_manifest.spec['components'].items(): if name == 'ComponentTemplate': continue components[name] = self.load_blueprint_comp_group(component, owner) - + self.ap.logger.debug(f'Components: {components}') return blueprint_manifest, components @@ -268,7 +297,9 @@ class ComponentDiscoveryEngine: return [] return self.components[kind] - def find_components(self, kind: str, component_list: typing.List[Component]) -> typing.List[Component]: + def find_components( + self, kind: str, component_list: typing.List[Component] + ) -> typing.List[Component]: """查找组件""" result: typing.List[Component] = [] for component in component_list: diff --git a/pkg/entity/persistence/base.py b/pkg/entity/persistence/base.py index 9d9ea759..b0d8b5db 100644 --- a/pkg/entity/persistence/base.py +++ b/pkg/entity/persistence/base.py @@ -1,5 +1,4 @@ import sqlalchemy.orm -import pydantic class Base(sqlalchemy.orm.DeclarativeBase): diff --git a/pkg/entity/persistence/bot.py b/pkg/entity/persistence/bot.py index 0cd7bce7..86932cac 100644 --- a/pkg/entity/persistence/bot.py +++ b/pkg/entity/persistence/bot.py @@ -5,6 +5,7 @@ from .base import Base class Bot(Base): """机器人""" + __tablename__ = 'bots' uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True) @@ -15,5 +16,12 @@ class Bot(Base): enable = sqlalchemy.Column(sqlalchemy.Boolean, nullable=False, default=False) use_pipeline_name = sqlalchemy.Column(sqlalchemy.String(255), nullable=True) use_pipeline_uuid = sqlalchemy.Column(sqlalchemy.String(255), nullable=True) - created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now()) - updated_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now(), onupdate=sqlalchemy.func.now()) + created_at = sqlalchemy.Column( + sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now() + ) + updated_at = sqlalchemy.Column( + sqlalchemy.DateTime, + nullable=False, + server_default=sqlalchemy.func.now(), + onupdate=sqlalchemy.func.now(), + ) diff --git a/pkg/entity/persistence/metadata.py b/pkg/entity/persistence/metadata.py index e1ebaefd..d9e03663 100644 --- a/pkg/entity/persistence/metadata.py +++ b/pkg/entity/persistence/metadata.py @@ -13,6 +13,7 @@ initial_metadata = [ class Metadata(Base): """数据库元数据""" + __tablename__ = 'metadata' key = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True) diff --git a/pkg/entity/persistence/model.py b/pkg/entity/persistence/model.py index 13700f25..65e016f3 100644 --- a/pkg/entity/persistence/model.py +++ b/pkg/entity/persistence/model.py @@ -1,11 +1,11 @@ import sqlalchemy -import datetime from .base import Base class LLMModel(Base): """LLM 模型""" + __tablename__ = 'llm_models' uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True) @@ -16,5 +16,12 @@ class LLMModel(Base): api_keys = sqlalchemy.Column(sqlalchemy.JSON, nullable=False) abilities = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default=[]) extra_args = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default={}) - created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now()) - updated_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now(), onupdate=sqlalchemy.func.now()) \ No newline at end of file + created_at = sqlalchemy.Column( + sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now() + ) + updated_at = sqlalchemy.Column( + sqlalchemy.DateTime, + nullable=False, + server_default=sqlalchemy.func.now(), + onupdate=sqlalchemy.func.now(), + ) diff --git a/pkg/entity/persistence/pipeline.py b/pkg/entity/persistence/pipeline.py index aaa393a7..ca854203 100644 --- a/pkg/entity/persistence/pipeline.py +++ b/pkg/entity/persistence/pipeline.py @@ -5,13 +5,21 @@ from .base import Base class LegacyPipeline(Base): """旧版流水线""" + __tablename__ = 'legacy_pipelines' uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True) name = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) description = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) - created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now()) - updated_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now(), onupdate=sqlalchemy.func.now()) + created_at = sqlalchemy.Column( + sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now() + ) + updated_at = sqlalchemy.Column( + sqlalchemy.DateTime, + nullable=False, + server_default=sqlalchemy.func.now(), + onupdate=sqlalchemy.func.now(), + ) for_version = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) is_default = sqlalchemy.Column(sqlalchemy.Boolean, nullable=False, default=False) @@ -21,13 +29,21 @@ class LegacyPipeline(Base): class PipelineRunRecord(Base): """流水线运行记录""" + __tablename__ = 'pipeline_run_records' uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True) pipeline_uuid = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) status = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) - created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now()) - updated_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now(), onupdate=sqlalchemy.func.now()) + created_at = sqlalchemy.Column( + sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now() + ) + updated_at = sqlalchemy.Column( + sqlalchemy.DateTime, + nullable=False, + server_default=sqlalchemy.func.now(), + onupdate=sqlalchemy.func.now(), + ) started_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False) finished_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False) result = sqlalchemy.Column(sqlalchemy.JSON, nullable=False) diff --git a/pkg/entity/persistence/plugin.py b/pkg/entity/persistence/plugin.py index b1e2cac4..94d6b8b4 100644 --- a/pkg/entity/persistence/plugin.py +++ b/pkg/entity/persistence/plugin.py @@ -5,6 +5,7 @@ from .base import Base class PluginSetting(Base): """插件配置""" + __tablename__ = 'plugin_settings' plugin_author = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True) @@ -12,5 +13,12 @@ class PluginSetting(Base): enabled = sqlalchemy.Column(sqlalchemy.Boolean, nullable=False, default=True) priority = sqlalchemy.Column(sqlalchemy.Integer, nullable=False, default=0) config = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default=dict) - created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now()) - updated_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now(), onupdate=sqlalchemy.func.now()) + created_at = sqlalchemy.Column( + sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now() + ) + updated_at = sqlalchemy.Column( + sqlalchemy.DateTime, + nullable=False, + server_default=sqlalchemy.func.now(), + onupdate=sqlalchemy.func.now(), + ) diff --git a/pkg/entity/persistence/user.py b/pkg/entity/persistence/user.py index 23d309c4..a0d9f168 100644 --- a/pkg/entity/persistence/user.py +++ b/pkg/entity/persistence/user.py @@ -1,5 +1,4 @@ import sqlalchemy -import sqlmodel from .base import Base @@ -10,5 +9,12 @@ class User(Base): id = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True) user = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) password = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) - created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now()) - updated_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now(), onupdate=sqlalchemy.func.now()) \ No newline at end of file + created_at = sqlalchemy.Column( + sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now() + ) + updated_at = sqlalchemy.Column( + sqlalchemy.DateTime, + nullable=False, + server_default=sqlalchemy.func.now(), + onupdate=sqlalchemy.func.now(), + ) diff --git a/pkg/persistence/database.py b/pkg/persistence/database.py index 0dd82817..528c6a34 100644 --- a/pkg/persistence/database.py +++ b/pkg/persistence/database.py @@ -9,6 +9,7 @@ from ..core import app preregistered_managers: list[type[BaseDatabaseManager]] = [] + def manager_class(name: str) -> None: """注册一个数据库管理类""" diff --git a/pkg/persistence/databases/sqlite.py b/pkg/persistence/databases/sqlite.py index 0bc3db32..1b12def8 100644 --- a/pkg/persistence/databases/sqlite.py +++ b/pkg/persistence/databases/sqlite.py @@ -5,10 +5,12 @@ import sqlalchemy.ext.asyncio as sqlalchemy_asyncio from .. import database -@database.manager_class("sqlite") +@database.manager_class('sqlite') class SQLiteDatabaseManager(database.BaseDatabaseManager): """SQLite 数据库管理类""" - + async def initialize(self) -> None: sqlite_path = 'data/langbot.db' - self.engine = sqlalchemy_asyncio.create_async_engine(f"sqlite+aiosqlite:///{sqlite_path}") + self.engine = sqlalchemy_asyncio.create_async_engine( + f'sqlite+aiosqlite:///{sqlite_path}' + ) diff --git a/pkg/persistence/mgr.py b/pkg/persistence/mgr.py index f17b169e..e8f953ab 100644 --- a/pkg/persistence/mgr.py +++ b/pkg/persistence/mgr.py @@ -1,6 +1,5 @@ from __future__ import annotations -import asyncio import datetime import typing import json @@ -10,12 +9,16 @@ import sqlalchemy.ext.asyncio as sqlalchemy_asyncio import sqlalchemy from . import database, migration -from ..entity.persistence import base, user, model, pipeline, bot, plugin, metadata +from ..entity.persistence import base, pipeline, metadata +from ..entity import persistence from ..core import app -from .databases import sqlite -from ..utils import constants -from .migrations import dbm001_migrate_v3_config +from ..utils import constants, importutil from ..api.http.service import pipeline as pipeline_service +from . import databases, migrations + +importutil.import_modules_in_pkg(databases) +importutil.import_modules_in_pkg(migrations) +importutil.import_modules_in_pkg(persistence) class PersistenceManager: @@ -33,9 +36,8 @@ class PersistenceManager: self.meta = base.Base.metadata async def initialize(self): + self.ap.logger.info('Initializing database...') - self.ap.logger.info("Initializing database...") - for manager in database.preregistered_managers: self.db = manager(self.ap) await self.db.initialize() @@ -43,7 +45,6 @@ class PersistenceManager: await self.create_tables() async def create_tables(self): - # create tables async with self.get_db_engine().connect() as conn: await conn.run_sync(self.meta.create_all) @@ -53,26 +54,28 @@ class PersistenceManager: # ======= write initial data ======= # write initial metadata - self.ap.logger.info("Creating initial metadata...") + self.ap.logger.info('Creating initial metadata...') for item in metadata.initial_metadata: # check if the item exists result = await self.execute_async( - sqlalchemy.select(metadata.Metadata).where(metadata.Metadata.key == item['key']) + sqlalchemy.select(metadata.Metadata).where( + metadata.Metadata.key == item['key'] + ) ) row = result.first() if row is None: await self.execute_async( sqlalchemy.insert(metadata.Metadata).values(item) ) - - # write default pipeline - result = await self.execute_async( - sqlalchemy.select(pipeline.LegacyPipeline) - ) - if result.first() is None: - self.ap.logger.info("Creating default pipeline...") - pipeline_config = json.load(open('templates/default-pipeline-config.json', 'r', encoding='utf-8')) + # write default pipeline + result = await self.execute_async(sqlalchemy.select(pipeline.LegacyPipeline)) + if result.first() is None: + self.ap.logger.info('Creating default pipeline...') + + pipeline_config = json.load( + open('templates/default-pipeline-config.json', 'r', encoding='utf-8') + ) pipeline_data = { 'uuid': str(uuid.uuid4()), @@ -91,7 +94,9 @@ class PersistenceManager: # run migrations database_version = await self.execute_async( - sqlalchemy.select(metadata.Metadata).where(metadata.Metadata.key == 'database_version') + sqlalchemy.select(metadata.Metadata).where( + metadata.Metadata.key == 'database_version' + ) ) database_version = int(database_version.fetchone()[1]) @@ -106,24 +111,27 @@ class PersistenceManager: for migration_cls in migrations: migration_instance = migration_cls(self.ap) - if migration_instance.number > database_version and migration_instance.number <= required_database_version: + if ( + migration_instance.number > database_version + and migration_instance.number <= required_database_version + ): await migration_instance.upgrade() await self.execute_async( - sqlalchemy.update(metadata.Metadata).where(metadata.Metadata.key == 'database_version').values( - { - 'value': str(migration_instance.number) - } - ) + sqlalchemy.update(metadata.Metadata) + .where(metadata.Metadata.key == 'database_version') + .values({'value': str(migration_instance.number)}) ) last_migration_number = migration_instance.number - self.ap.logger.info(f'Migration {migration_instance.number} completed.') - - self.ap.logger.info(f'Successfully upgraded database to version {last_migration_number}.') + self.ap.logger.info( + f'Migration {migration_instance.number} completed.' + ) + + self.ap.logger.info( + f'Successfully upgraded database to version {last_migration_number}.' + ) async def execute_async( - self, - *args, - **kwargs + self, *args, **kwargs ) -> sqlalchemy.engine.cursor.CursorResult: async with self.get_db_engine().connect() as conn: result = await conn.execute(*args, **kwargs) @@ -132,9 +140,13 @@ class PersistenceManager: def get_db_engine(self) -> sqlalchemy_asyncio.AsyncEngine: return self.db.get_engine() - - def serialize_model(self, model: typing.Type[sqlalchemy.Base], data: sqlalchemy.Base) -> dict: + + def serialize_model( + self, model: typing.Type[sqlalchemy.Base], data: sqlalchemy.Base + ) -> dict: return { - column.name: getattr(data, column.name) if not isinstance(getattr(data, column.name), (datetime.datetime)) else getattr(data, column.name).isoformat() + column.name: getattr(data, column.name) + if not isinstance(getattr(data, column.name), (datetime.datetime)) + else getattr(data, column.name).isoformat() for column in model.__table__.columns } diff --git a/pkg/persistence/migration.py b/pkg/persistence/migration.py index 81a3aac3..c191b686 100644 --- a/pkg/persistence/migration.py +++ b/pkg/persistence/migration.py @@ -8,6 +8,7 @@ from ..core import app preregistered_db_migrations: list[typing.Type[DBMigration]] = [] + def migration_class(number: int): """迁移类装饰器""" @@ -15,6 +16,7 @@ def migration_class(number: int): cls.number = number preregistered_db_migrations.append(cls) return cls + return wrapper diff --git a/pkg/persistence/migrations/dbm001_migrate_v3_config.py b/pkg/persistence/migrations/dbm001_migrate_v3_config.py index afed5eea..6aee2854 100644 --- a/pkg/persistence/migrations/dbm001_migrate_v3_config.py +++ b/pkg/persistence/migrations/dbm001_migrate_v3_config.py @@ -1,5 +1,3 @@ -from .. import migration - # TODO fill this # @migration.migration_class(1) # class DBMigrationV3(migration.DBMigration): @@ -10,4 +8,4 @@ from .. import migration # pass # async def downgrade(self): -# """降级""" \ No newline at end of file +# """降级""" diff --git a/pkg/pipeline/bansess/bansess.py b/pkg/pipeline/bansess/bansess.py index 38fb9794..dad6a3ab 100644 --- a/pkg/pipeline/bansess/bansess.py +++ b/pkg/pipeline/bansess/bansess.py @@ -1,15 +1,13 @@ from __future__ import annotations -import re from .. import stage, entities from ...core import entities as core_entities -from ...config import manager as cfg_mgr @stage.stage_class('BanSessionCheckStage') class BanSessionCheckStage(stage.PipelineStage): """访问控制处理阶段 - + 仅检查query中群号或个人号是否在访问控制列表中。 """ @@ -17,26 +15,24 @@ class BanSessionCheckStage(stage.PipelineStage): pass async def process( - self, - query: core_entities.Query, - stage_inst_name: str + self, query: core_entities.Query, stage_inst_name: str ) -> entities.StageProcessResult: - found = False mode = query.pipeline_config['trigger']['access-control']['mode'] sess_list = query.pipeline_config['trigger']['access-control'][mode] - if (query.launcher_type.value == 'group' and 'group_*' in sess_list) \ - or (query.launcher_type.value == 'person' and 'person_*' in sess_list): + if (query.launcher_type.value == 'group' and 'group_*' in sess_list) or ( + query.launcher_type.value == 'person' and 'person_*' in sess_list + ): found = True else: for sess in sess_list: - if sess == f"{query.launcher_type.value}_{query.launcher_id}": + if sess == f'{query.launcher_type.value}_{query.launcher_id}': found = True break - + ctn = False if mode == 'whitelist': @@ -45,7 +41,11 @@ class BanSessionCheckStage(stage.PipelineStage): ctn = not found return entities.StageProcessResult( - result_type=entities.ResultType.CONTINUE if ctn else entities.ResultType.INTERRUPT, + result_type=entities.ResultType.CONTINUE + if ctn + else entities.ResultType.INTERRUPT, new_query=query, - console_notice=f'根据访问控制忽略消息: {query.launcher_type.value}_{query.launcher_id}' if not ctn else '' + console_notice=f'根据访问控制忽略消息: {query.launcher_type.value}_{query.launcher_id}' + if not ctn + else '', ) diff --git a/pkg/pipeline/cntfilter/cntfilter.py b/pkg/pipeline/cntfilter/cntfilter.py index dbf7c52e..6547cb16 100644 --- a/pkg/pipeline/cntfilter/cntfilter.py +++ b/pkg/pipeline/cntfilter/cntfilter.py @@ -4,20 +4,21 @@ from ...core import app from .. import stage, entities from ...core import entities as core_entities -from ...config import manager as cfg_mgr from . import filter as filter_model, entities as filter_entities -from .filters import cntignore, banwords, baiduexamine from ...provider import entities as llm_entities from ...platform.types import message as platform_message -from ...platform.types import events as platform_events -from ...platform.types import entities as platform_entities +from ...utils import importutil + +from . import filters + +importutil.import_modules_in_pkg(filters) @stage.stage_class('PostContentFilterStage') @stage.stage_class('PreContentFilterStage') class ContentFilterStage(stage.PipelineStage): """内容过滤阶段 - + 前置: 检查消息是否符合规则,不符合则拦截。 改写: @@ -36,13 +37,12 @@ class ContentFilterStage(stage.PipelineStage): super().__init__(ap) async def initialize(self, pipeline_config: dict): - filters_required = [ - "content-ignore", + 'content-ignore', ] if pipeline_config['safety']['content-filter']['check-sensitive-words']: - filters_required.append("ban-word-filter") + filters_required.append('ban-word-filter') # TODO revert it # if self.ap.pipeline_cfg.data['baidu-cloud-examine']['enable']: @@ -50,9 +50,7 @@ class ContentFilterStage(stage.PipelineStage): for filter in filter_model.preregistered_filters: if filter.name in filters_required: - self.filter_chain.append( - filter(self.ap) - ) + self.filter_chain.append(filter(self.ap)) for filter in self.filter_chain: await filter.initialize() @@ -68,8 +66,7 @@ class ContentFilterStage(stage.PipelineStage): if query.pipeline_config['safety']['content-filter']['scope'] == 'output-msg': return entities.StageProcessResult( - result_type=entities.ResultType.CONTINUE, - new_query=query + result_type=entities.ResultType.CONTINUE, new_query=query ) else: for filter in self.filter_chain: @@ -78,26 +75,25 @@ class ContentFilterStage(stage.PipelineStage): if result.level in [ filter_entities.ResultLevel.BLOCK, - filter_entities.ResultLevel.MASKED + filter_entities.ResultLevel.MASKED, ]: return entities.StageProcessResult( result_type=entities.ResultType.INTERRUPT, new_query=query, user_notice=result.user_notice, - console_notice=result.console_notice + console_notice=result.console_notice, ) elif result.level == filter_entities.ResultLevel.PASS: # 传到下一个 message = result.replacement - + query.message_chain = platform_message.MessageChain( platform_message.Plain(message) ) return entities.StageProcessResult( - result_type=entities.ResultType.CONTINUE, - new_query=query + result_type=entities.ResultType.CONTINUE, new_query=query ) - + async def _post_process( self, message: str, @@ -108,8 +104,7 @@ class ContentFilterStage(stage.PipelineStage): """ if query.pipeline_config['safety']['content-filter']['scope'] == 'income-msg': return entities.StageProcessResult( - result_type=entities.ResultType.CONTINUE, - new_query=query + result_type=entities.ResultType.CONTINUE, new_query=query ) else: message = message.strip() @@ -122,30 +117,25 @@ class ContentFilterStage(stage.PipelineStage): result_type=entities.ResultType.INTERRUPT, new_query=query, user_notice=result.user_notice, - console_notice=result.console_notice + console_notice=result.console_notice, ) elif result.level in [ filter_entities.ResultLevel.PASS, - filter_entities.ResultLevel.MASKED + filter_entities.ResultLevel.MASKED, ]: message = result.replacement query.resp_messages[-1].content = message return entities.StageProcessResult( - result_type=entities.ResultType.CONTINUE, - new_query=query + result_type=entities.ResultType.CONTINUE, new_query=query ) async def process( - self, - query: core_entities.Query, - stage_inst_name: str + self, query: core_entities.Query, stage_inst_name: str ) -> entities.StageProcessResult: - """处理 - """ + """处理""" if stage_inst_name == 'PreContentFilterStage': - contain_non_text = False text_components = [platform_message.Plain, platform_message.Source] @@ -156,28 +146,24 @@ class ContentFilterStage(stage.PipelineStage): break if contain_non_text: - self.ap.logger.debug(f"消息中包含非文本消息,跳过内容过滤器检查。") + self.ap.logger.debug('消息中包含非文本消息,跳过内容过滤器检查。') return entities.StageProcessResult( - result_type=entities.ResultType.CONTINUE, - new_query=query + result_type=entities.ResultType.CONTINUE, new_query=query ) - return await self._pre_process( - str(query.message_chain).strip(), - query - ) + return await self._pre_process(str(query.message_chain).strip(), query) elif stage_inst_name == 'PostContentFilterStage': # 仅处理 query.resp_messages[-1].content 是 str 的情况 - if isinstance(query.resp_messages[-1], llm_entities.Message) and isinstance(query.resp_messages[-1].content, str): - return await self._post_process( - query.resp_messages[-1].content, - query - ) + if isinstance(query.resp_messages[-1], llm_entities.Message) and isinstance( + query.resp_messages[-1].content, str + ): + return await self._post_process(query.resp_messages[-1].content, query) else: - self.ap.logger.debug(f"resp_messages[-1] 不是 Message 类型或 query.resp_messages[-1].content 不是 str 类型,跳过内容过滤器检查。") + self.ap.logger.debug( + 'resp_messages[-1] 不是 Message 类型或 query.resp_messages[-1].content 不是 str 类型,跳过内容过滤器检查。' + ) return entities.StageProcessResult( - result_type=entities.ResultType.CONTINUE, - new_query=query + result_type=entities.ResultType.CONTINUE, new_query=query ) else: raise ValueError(f'未知的 stage_inst_name: {stage_inst_name}') diff --git a/pkg/pipeline/cntfilter/entities.py b/pkg/pipeline/cntfilter/entities.py index b4bc0f7e..5e804c0d 100644 --- a/pkg/pipeline/cntfilter/entities.py +++ b/pkg/pipeline/cntfilter/entities.py @@ -1,14 +1,11 @@ - -import typing import enum import pydantic.v1 as pydantic -from ...provider import entities as llm_entities - class ResultLevel(enum.Enum): """结果等级""" + PASS = enum.auto() """通过""" @@ -24,6 +21,7 @@ class ResultLevel(enum.Enum): class EnableStage(enum.Enum): """启用阶段""" + PRE = enum.auto() """预处理""" @@ -55,14 +53,15 @@ class FilterResult(pydantic.BaseModel): class ManagerResultLevel(enum.Enum): """处理器结果等级""" + CONTINUE = enum.auto() """继续""" INTERRUPT = enum.auto() """中断""" -class FilterManagerResult(pydantic.BaseModel): +class FilterManagerResult(pydantic.BaseModel): level: ManagerResultLevel replacement: str diff --git a/pkg/pipeline/cntfilter/filter.py b/pkg/pipeline/cntfilter/filter.py index 970e11f1..ae7ceb79 100644 --- a/pkg/pipeline/cntfilter/filter.py +++ b/pkg/pipeline/cntfilter/filter.py @@ -5,14 +5,13 @@ import typing from ...core import app, entities as core_entities from . import entities -from ...provider import entities as llm_entities preregistered_filters: list[typing.Type[ContentFilter]] = [] def filter_class( - name: str + name: str, ) -> typing.Callable[[typing.Type[ContentFilter]], typing.Type[ContentFilter]]: """内容过滤器类装饰器 @@ -22,6 +21,7 @@ def filter_class( Returns: typing.Callable[[typing.Type[ContentFilter]], typing.Type[ContentFilter]]: 装饰器 """ + def decorator(cls: typing.Type[ContentFilter]) -> typing.Type[ContentFilter]: assert issubclass(cls, ContentFilter) @@ -53,23 +53,21 @@ class ContentFilter(metaclass=abc.ABCMeta): entity.EnableStage.PRE: 消息请求AI前,此时需要检查的内容是用户的输入消息。 entity.EnableStage.POST: 消息请求AI后,此时需要检查的内容是AI的回复消息。 """ - return [ - entities.EnableStage.PRE, - entities.EnableStage.POST - ] + return [entities.EnableStage.PRE, entities.EnableStage.POST] async def initialize(self): - """初始化过滤器 - """ + """初始化过滤器""" pass @abc.abstractmethod - async def process(self, query: core_entities.Query, message: str=None, image_url=None) -> entities.FilterResult: + async def process( + self, query: core_entities.Query, message: str = None, image_url=None + ) -> entities.FilterResult: """处理消息 分为前后阶段,具体取决于 enable_stages 的值。 对于内容过滤器来说,不需要考虑消息所处的阶段,只需要检查消息内容即可。 - + Args: message (str): 需要检查的内容 image_url (str): 要检查的图片的 URL diff --git a/pkg/pipeline/cntfilter/filters/baiduexamine.py b/pkg/pipeline/cntfilter/filters/baiduexamine.py index 800f0099..c3776bc9 100644 --- a/pkg/pipeline/cntfilter/filters/baiduexamine.py +++ b/pkg/pipeline/cntfilter/filters/baiduexamine.py @@ -7,11 +7,11 @@ from .. import filter as filter_model from ....core import entities as core_entities -BAIDU_EXAMINE_URL = "https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token={}" -BAIDU_EXAMINE_TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token" +BAIDU_EXAMINE_URL = 'https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token={}' +BAIDU_EXAMINE_TOKEN_URL = 'https://aip.baidubce.com/oauth/2.0/token' -@filter_model.filter_class("baidu-cloud-examine") +@filter_model.filter_class('baidu-cloud-examine') class BaiduCloudExamine(filter_model.ContentFilter): """百度云内容审核""" @@ -20,44 +20,52 @@ class BaiduCloudExamine(filter_model.ContentFilter): async with session.post( BAIDU_EXAMINE_TOKEN_URL, params={ - "grant_type": "client_credentials", - "client_id": self.ap.pipeline_cfg.data['baidu-cloud-examine']['api-key'], - "client_secret": self.ap.pipeline_cfg.data['baidu-cloud-examine']['api-secret'] - } + 'grant_type': 'client_credentials', + 'client_id': self.ap.pipeline_cfg.data['baidu-cloud-examine'][ + 'api-key' + ], + 'client_secret': self.ap.pipeline_cfg.data['baidu-cloud-examine'][ + 'api-secret' + ], + }, ) as resp: return (await resp.json())['access_token'] - async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult: - + async def process( + self, query: core_entities.Query, message: str + ) -> entities.FilterResult: async with aiohttp.ClientSession() as session: async with session.post( BAIDU_EXAMINE_URL.format(await self._get_token()), - headers={'Content-Type': 'application/x-www-form-urlencoded', 'Accept': 'application/json'}, - data=f"text={message}".encode('utf-8') + headers={ + 'Content-Type': 'application/x-www-form-urlencoded', + 'Accept': 'application/json', + }, + data=f'text={message}'.encode('utf-8'), ) as resp: result = await resp.json() - if "error_code" in result: + if 'error_code' in result: return entities.FilterResult( level=entities.ResultLevel.BLOCK, replacement=message, user_notice='', - console_notice=f"百度云判定出错,错误信息:{result['error_msg']}" + console_notice=f'百度云判定出错,错误信息:{result["error_msg"]}', ) else: - conclusion = result["conclusion"] + conclusion = result['conclusion'] - if conclusion in ("合规"): + if conclusion in ('合规'): return entities.FilterResult( level=entities.ResultLevel.PASS, replacement=message, user_notice='', - console_notice=f"百度云判定结果:{conclusion}" + console_notice=f'百度云判定结果:{conclusion}', ) else: return entities.FilterResult( level=entities.ResultLevel.BLOCK, replacement=message, - user_notice="消息中存在不合适的内容, 请修改", - console_notice=f"百度云判定结果:{conclusion}" + user_notice='消息中存在不合适的内容, 请修改', + console_notice=f'百度云判定结果:{conclusion}', ) diff --git a/pkg/pipeline/cntfilter/filters/banwords.py b/pkg/pipeline/cntfilter/filters/banwords.py index cd3d412c..598fa299 100644 --- a/pkg/pipeline/cntfilter/filters/banwords.py +++ b/pkg/pipeline/cntfilter/filters/banwords.py @@ -6,14 +6,16 @@ from .. import entities from ....core import entities as core_entities -@filter_model.filter_class("ban-word-filter") +@filter_model.filter_class('ban-word-filter') class BanWordFilter(filter_model.ContentFilter): """根据内容过滤""" async def initialize(self): pass - async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult: + async def process( + self, query: core_entities.Query, message: str + ) -> entities.FilterResult: found = False for word in self.ap.sensitive_meta.data['words']: @@ -23,9 +25,10 @@ class BanWordFilter(filter_model.ContentFilter): found = True for i in range(len(match)): - if self.ap.sensitive_meta.data['mask_word'] == "": + if self.ap.sensitive_meta.data['mask_word'] == '': message = message.replace( - match[i], self.ap.sensitive_meta.data['mask'] * len(match[i]) + match[i], + self.ap.sensitive_meta.data['mask'] * len(match[i]), ) else: message = message.replace( @@ -36,5 +39,5 @@ class BanWordFilter(filter_model.ContentFilter): level=entities.ResultLevel.MASKED if found else entities.ResultLevel.PASS, replacement=message, user_notice='消息中存在不合适的内容, 请修改' if found else '', - console_notice='' - ) \ No newline at end of file + console_notice='', + ) diff --git a/pkg/pipeline/cntfilter/filters/cntignore.py b/pkg/pipeline/cntfilter/filters/cntignore.py index 381d5c51..cb563593 100644 --- a/pkg/pipeline/cntfilter/filters/cntignore.py +++ b/pkg/pipeline/cntfilter/filters/cntignore.py @@ -6,7 +6,7 @@ from .. import filter as filter_model from ....core import entities as core_entities -@filter_model.filter_class("content-ignore") +@filter_model.filter_class('content-ignore') class ContentIgnore(filter_model.ContentFilter): """根据内容忽略消息""" @@ -16,7 +16,9 @@ class ContentIgnore(filter_model.ContentFilter): entities.EnableStage.PRE, ] - async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult: + async def process( + self, query: core_entities.Query, message: str + ) -> entities.FilterResult: if 'prefix' in query.pipeline_config['trigger']['ignore-rules']: for rule in query.pipeline_config['trigger']['ignore-rules']['prefix']: if message.startswith(rule): @@ -24,9 +26,9 @@ class ContentIgnore(filter_model.ContentFilter): level=entities.ResultLevel.BLOCK, replacement='', user_notice='', - console_notice='根据 ignore_rules 中的 prefix 规则,忽略消息' + console_notice='根据 ignore_rules 中的 prefix 规则,忽略消息', ) - + if 'regexp' in query.pipeline_config['trigger']['ignore-rules']: for rule in query.pipeline_config['trigger']['ignore-rules']['regexp']: if re.search(rule, message): @@ -34,12 +36,12 @@ class ContentIgnore(filter_model.ContentFilter): level=entities.ResultLevel.BLOCK, replacement='', user_notice='', - console_notice='根据 ignore_rules 中的 regexp 规则,忽略消息' + console_notice='根据 ignore_rules 中的 regexp 规则,忽略消息', ) return entities.FilterResult( level=entities.ResultLevel.PASS, replacement=message, user_notice='', - console_notice='' - ) \ No newline at end of file + console_notice='', + ) diff --git a/pkg/pipeline/controller.py b/pkg/pipeline/controller.py index a7cf2153..2ad1690f 100644 --- a/pkg/pipeline/controller.py +++ b/pkg/pipeline/controller.py @@ -1,18 +1,14 @@ from __future__ import annotations import asyncio -import typing import traceback from ..core import app, entities -from . import entities as pipeline_entities -from ..plugin import events -from ..platform.types import message as platform_message class Controller: - """总控制器 - """ + """总控制器""" + ap: app.Application semaphore: asyncio.Semaphore = None @@ -20,11 +16,12 @@ class Controller: def __init__(self, ap: app.Application): self.ap = ap - self.semaphore = asyncio.Semaphore(self.ap.instance_config.data['concurrency']['pipeline']) + self.semaphore = asyncio.Semaphore( + self.ap.instance_config.data['concurrency']['pipeline'] + ) async def consumer(self): - """事件处理循环 - """ + """事件处理循环""" try: while True: selected_query: entities.Query = None @@ -35,7 +32,9 @@ class Controller: for query in queries: session = await self.ap.sess_mgr.get_session(query) - self.ap.logger.debug(f"Checking query {query} session {session}") + self.ap.logger.debug( + f'Checking query {query} session {session}' + ) if not session.semaphore.locked(): selected_query = query @@ -56,30 +55,40 @@ class Controller: # find pipeline # Here firstly find the bot, then find the pipeline, in case the bot adapter's config is not the latest one. # Like aiocqhttp, once a client is connected, even the adapter was updated and restarted, the existing client connection will not be affected. - bot = await self.ap.platform_mgr.get_bot_by_uuid(selected_query.bot_uuid) + bot = await self.ap.platform_mgr.get_bot_by_uuid( + selected_query.bot_uuid + ) if bot: - pipeline = await self.ap.pipeline_mgr.get_pipeline_by_uuid(bot.bot_entity.use_pipeline_uuid) + pipeline = ( + await self.ap.pipeline_mgr.get_pipeline_by_uuid( + bot.bot_entity.use_pipeline_uuid + ) + ) if pipeline: await pipeline.run(selected_query) - + async with self.ap.query_pool: - (await self.ap.sess_mgr.get_session(selected_query)).semaphore.release() + ( + await self.ap.sess_mgr.get_session(selected_query) + ).semaphore.release() # 通知其他协程,有新的请求可以处理了 self.ap.query_pool.condition.notify_all() self.ap.task_mgr.create_task( _process_query(selected_query), - kind="query", - name=f"query-{selected_query.query_id}", - scopes=[entities.LifecycleControlScope.APPLICATION, entities.LifecycleControlScope.PLATFORM], + kind='query', + name=f'query-{selected_query.query_id}', + scopes=[ + entities.LifecycleControlScope.APPLICATION, + entities.LifecycleControlScope.PLATFORM, + ], ) except Exception as e: # traceback.print_exc() - self.ap.logger.error(f"控制器循环出错: {e}") - self.ap.logger.error(f"Traceback: {traceback.format_exc()}") + self.ap.logger.error(f'控制器循环出错: {e}') + self.ap.logger.error(f'Traceback: {traceback.format_exc()}') async def run(self): - """运行控制器 - """ + """运行控制器""" await self.consumer() diff --git a/pkg/pipeline/entities.py b/pkg/pipeline/entities.py index ffcc4654..dd6434c0 100644 --- a/pkg/pipeline/entities.py +++ b/pkg/pipeline/entities.py @@ -10,7 +10,6 @@ from ..core import entities class ResultType(enum.Enum): - CONTINUE = enum.auto() """继续流水线""" @@ -19,12 +18,18 @@ class ResultType(enum.Enum): class StageProcessResult(pydantic.BaseModel): - result_type: ResultType new_query: entities.Query - user_notice: typing.Optional[typing.Union[str, list[platform_message.MessageComponent], platform_message.MessageChain, None]] = [] + user_notice: typing.Optional[ + typing.Union[ + str, + list[platform_message.MessageComponent], + platform_message.MessageChain, + None, + ] + ] = [] """只要设置了就会发送给用户""" console_notice: typing.Optional[str] = '' diff --git a/pkg/pipeline/longtext/longtext.py b/pkg/pipeline/longtext/longtext.py index ac03ad42..ab20f3eb 100644 --- a/pkg/pipeline/longtext/longtext.py +++ b/pkg/pipeline/longtext/longtext.py @@ -2,18 +2,19 @@ from __future__ import annotations import os import traceback -from PIL import Image, ImageDraw, ImageFont -from ...core import app from . import strategy -from .strategies import image, forward from .. import stage, entities from ...core import entities as core_entities -from ...config import manager as cfg_mgr from ...platform.types import message as platform_message +from ...utils import importutil + +from . import strategies + +importutil.import_modules_in_pkg(strategies) -@stage.stage_class("LongTextProcessStage") +@stage.stage_class('LongTextProcessStage') class LongTextProcessStage(stage.PipelineStage): """长消息处理阶段 @@ -31,34 +32,48 @@ class LongTextProcessStage(stage.PipelineStage): # 检查是否存在 if not os.path.exists(use_font): # 若是windows系统,使用微软雅黑 - if os.name == "nt": - use_font = "C:/Windows/Fonts/msyh.ttc" + if os.name == 'nt': + use_font = 'C:/Windows/Fonts/msyh.ttc' if not os.path.exists(use_font): - self.ap.logger.warn("未找到字体文件,且无法使用Windows自带字体,更换为转发消息组件以发送长消息,您可以在配置文件中调整相关设置。") - config['blob_message_strategy'] = "forward" + self.ap.logger.warn( + '未找到字体文件,且无法使用Windows自带字体,更换为转发消息组件以发送长消息,您可以在配置文件中调整相关设置。' + ) + config['blob_message_strategy'] = 'forward' else: - self.ap.logger.info("使用Windows自带字体:" + use_font) + self.ap.logger.info('使用Windows自带字体:' + use_font) config['font-path'] = use_font else: - self.ap.logger.warn("未找到字体文件,且无法使用系统自带字体,更换为转发消息组件以发送长消息,您可以在配置文件中调整相关设置。") + self.ap.logger.warn( + '未找到字体文件,且无法使用系统自带字体,更换为转发消息组件以发送长消息,您可以在配置文件中调整相关设置。' + ) - pipeline_config['output']['long-text-processing']['strategy'] = "forward" - except: + pipeline_config['output']['long-text-processing'][ + 'strategy' + ] = 'forward' + except Exception: traceback.print_exc() - self.ap.logger.error("加载字体文件失败({}),更换为转发消息组件以发送长消息,您可以在配置文件中调整相关设置。".format(use_font)) + self.ap.logger.error( + '加载字体文件失败({}),更换为转发消息组件以发送长消息,您可以在配置文件中调整相关设置。'.format( + use_font + ) + ) - pipeline_config['output']['long-text-processing']['strategy'] = "forward" + pipeline_config['output']['long-text-processing']['strategy'] = ( + 'forward' + ) for strategy_cls in strategy.preregistered_strategies: if strategy_cls.name == config['strategy']: self.strategy_impl = strategy_cls(self.ap) break else: - raise ValueError(f"未找到名为 {config['strategy']} 的长消息处理策略") + raise ValueError(f'未找到名为 {config["strategy"]} 的长消息处理策略') await self.strategy_impl.initialize() - - async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: + + async def process( + self, query: core_entities.Query, stage_inst_name: str + ) -> entities.StageProcessResult: # 检查是否包含非 Plain 组件 contains_non_plain = False @@ -66,13 +81,19 @@ class LongTextProcessStage(stage.PipelineStage): if not isinstance(msg, platform_message.Plain): contains_non_plain = True break - + if contains_non_plain: - self.ap.logger.debug("消息中包含非 Plain 组件,跳过长消息处理。") - elif len(str(query.resp_message_chain[-1])) > query.pipeline_config['output']['long-text-processing']['threshold']: - query.resp_message_chain[-1] = platform_message.MessageChain(await self.strategy_impl.process(str(query.resp_message_chain[-1]), query)) + self.ap.logger.debug('消息中包含非 Plain 组件,跳过长消息处理。') + elif ( + len(str(query.resp_message_chain[-1])) + > query.pipeline_config['output']['long-text-processing']['threshold'] + ): + query.resp_message_chain[-1] = platform_message.MessageChain( + await self.strategy_impl.process( + str(query.resp_message_chain[-1]), query + ) + ) return entities.StageProcessResult( - result_type=entities.ResultType.CONTINUE, - new_query=query + result_type=entities.ResultType.CONTINUE, new_query=query ) diff --git a/pkg/pipeline/longtext/strategies/forward.py b/pkg/pipeline/longtext/strategies/forward.py index 7abb9c6e..57084d76 100644 --- a/pkg/pipeline/longtext/strategies/forward.py +++ b/pkg/pipeline/longtext/strategies/forward.py @@ -1,8 +1,6 @@ # 转发消息组件 from __future__ import annotations -import typing -import pydantic.v1 as pydantic from .. import strategy as strategy_model from ....core import entities as core_entities @@ -13,29 +11,27 @@ ForwardMessageDiaplay = platform_message.ForwardMessageDiaplay Forward = platform_message.Forward -@strategy_model.strategy_class("forward") +@strategy_model.strategy_class('forward') class ForwardComponentStrategy(strategy_model.LongTextStrategy): - - async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]: + async def process( + self, message: str, query: core_entities.Query + ) -> list[platform_message.MessageComponent]: display = ForwardMessageDiaplay( - title="群聊的聊天记录", - brief="[聊天记录]", - source="聊天记录", - preview=["QQ用户: "+message], - summary="查看1条转发消息" + title='群聊的聊天记录', + brief='[聊天记录]', + source='聊天记录', + preview=['QQ用户: ' + message], + summary='查看1条转发消息', ) node_list = [ platform_message.ForwardMessageNode( sender_id=query.adapter.bot_account_id, sender_name='QQ用户', - message_chain=platform_message.MessageChain([message]) + message_chain=platform_message.MessageChain([message]), ) ] - forward = Forward( - display=display, - node_list=node_list - ) + forward = Forward(display=display, node_list=node_list) return [forward] diff --git a/pkg/pipeline/longtext/strategies/image.py b/pkg/pipeline/longtext/strategies/image.py index b30d3a81..26c4b731 100644 --- a/pkg/pipeline/longtext/strategies/image.py +++ b/pkg/pipeline/longtext/strategies/image.py @@ -1,6 +1,5 @@ from __future__ import annotations -import typing import os import base64 import time @@ -15,26 +14,30 @@ from .. import strategy as strategy_model from ....core import entities as core_entities -@strategy_model.strategy_class("image") +@strategy_model.strategy_class('image') class Text2ImageStrategy(strategy_model.LongTextStrategy): - async def initialize(self): pass @functools.lru_cache(maxsize=16) def get_font(self, query: core_entities.Query): - return ImageFont.truetype(query.pipeline_config['output']['long-text-processing']['font-path'], 32, encoding="utf-8") - - async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]: + return ImageFont.truetype( + query.pipeline_config['output']['long-text-processing']['font-path'], + 32, + encoding='utf-8', + ) + + async def process( + self, message: str, query: core_entities.Query + ) -> list[platform_message.MessageComponent]: img_path = self.text_to_image( text_str=message, save_as='temp/{}.png'.format(int(time.time())), - query=query + query=query, ) compressed_path, size = self.compress_image( - img_path, - outfile="temp/{}_compressed.png".format(int(time.time())) + img_path, outfile='temp/{}_compressed.png'.format(int(time.time())) ) with open(compressed_path, 'rb') as f: @@ -93,13 +96,11 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy): resultIndex.append(v) return resultIndex - def get_size(self, file): # 获取文件大小:KB size = os.path.getsize(file) return size / 1024 - def get_outfile(self, infile, outfile): if outfile: return outfile @@ -107,7 +108,6 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy): outfile = '{}-out{}'.format(dir, suffix) return outfile - def compress_image(self, infile, outfile='', kb=100, step=20, quality=90): """不改变图片尺寸压缩到指定大小 :param infile: 压缩源文件 @@ -130,24 +130,28 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy): o_size = self.get_size(outfile) return outfile, self.get_size(outfile) + def text_to_image( + self, + text_str: str, + save_as='temp.png', + width=800, + query: core_entities.Query = None, + ): + text_str = text_str.replace('\t', ' ') - def text_to_image(self, text_str: str, save_as="temp.png", width=800, query: core_entities.Query = None): - - text_str = text_str.replace("\t", " ") - # 分行 lines = text_str.split('\n') # 计算并分割 final_lines = [] - text_width = width-80 + text_width = width - 80 - self.ap.logger.debug("lines: {}, text_width: {}".format(lines, text_width)) + self.ap.logger.debug('lines: {}, text_width: {}'.format(lines, text_width)) for line in lines: # 如果长了就分割 line_width = self.get_font(query).getlength(line) - self.ap.logger.debug("line_width: {}".format(line_width)) + self.ap.logger.debug('line_width: {}'.format(line_width)) if line_width < text_width: final_lines.append(line) continue @@ -161,7 +165,10 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy): numbers = self.indexNumber(rest_text) for number in numbers: - if number[1] < point < number[1] + len(number[0]) and number[1] != 0: + if ( + number[1] < point < number[1] + len(number[0]) + and number[1] != 0 + ): point = number[1] break @@ -174,16 +181,23 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy): else: continue # 准备画布 - img = Image.new('RGBA', (width, max(280, len(final_lines) * 35 + 65)), (255, 255, 255, 255)) + img = Image.new( + 'RGBA', (width, max(280, len(final_lines) * 35 + 65)), (255, 255, 255, 255) + ) draw = ImageDraw.Draw(img, mode='RGBA') - self.ap.logger.debug("正在绘制图片...") + self.ap.logger.debug('正在绘制图片...') # 绘制正文 line_number = 0 offset_x = 20 offset_y = 30 for final_line in final_lines: - draw.text((offset_x, offset_y + 35 * line_number), final_line, fill=(0, 0, 0), font=self.text_render_font) + draw.text( + (offset_x, offset_y + 35 * line_number), + final_line, + fill=(0, 0, 0), + font=self.text_render_font, + ) # 遍历此行,检查是否有emoji idx_in_line = 0 for ch in final_line: @@ -196,7 +210,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy): line_number += 1 - self.ap.logger.debug("正在保存图片...") + self.ap.logger.debug('正在保存图片...') img.save(save_as) return save_as diff --git a/pkg/pipeline/longtext/strategy.py b/pkg/pipeline/longtext/strategy.py index 6f66bbff..4e141045 100644 --- a/pkg/pipeline/longtext/strategy.py +++ b/pkg/pipeline/longtext/strategy.py @@ -12,7 +12,7 @@ preregistered_strategies: list[typing.Type[LongTextStrategy]] = [] def strategy_class( - name: str + name: str, ) -> typing.Callable[[typing.Type[LongTextStrategy]], typing.Type[LongTextStrategy]]: """长文本处理策略类装饰器 @@ -36,8 +36,7 @@ def strategy_class( class LongTextStrategy(metaclass=abc.ABCMeta): - """长文本处理策略抽象类 - """ + """长文本处理策略抽象类""" name: str @@ -45,12 +44,14 @@ class LongTextStrategy(metaclass=abc.ABCMeta): def __init__(self, ap: app.Application): self.ap = ap - + async def initialize(self): pass - + @abc.abstractmethod - async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]: + async def process( + self, message: str, query: core_entities.Query + ) -> list[platform_message.MessageComponent]: """处理长文本 在 platform.json 中配置 long-text-process 字段,只要 文本长度超过了 threshold 就会调用此方法 diff --git a/pkg/pipeline/msgtrun/msgtrun.py b/pkg/pipeline/msgtrun/msgtrun.py index b3fb593a..2595e289 100644 --- a/pkg/pipeline/msgtrun/msgtrun.py +++ b/pkg/pipeline/msgtrun/msgtrun.py @@ -3,33 +3,38 @@ from __future__ import annotations from .. import stage, entities from ...core import entities as core_entities from . import truncator -from .truncators import round +from ...utils import importutil + +from . import truncators + +importutil.import_modules_in_pkg(truncators) -@stage.stage_class("ConversationMessageTruncator") +@stage.stage_class('ConversationMessageTruncator') class ConversationMessageTruncator(stage.PipelineStage): """会话消息截断器 用于截断会话消息链,以适应平台消息长度限制。 """ + trun: truncator.Truncator async def initialize(self, pipeline_config: dict): - use_method = "round" + use_method = 'round' for trun in truncator.preregistered_truncators: if trun.name == use_method: self.trun = trun(self.ap) break else: - raise ValueError(f"未知的截断器: {use_method}") + raise ValueError(f'未知的截断器: {use_method}') - async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: - """处理 - """ + async def process( + self, query: core_entities.Query, stage_inst_name: str + ) -> entities.StageProcessResult: + """处理""" query = await self.trun.truncate(query) return entities.StageProcessResult( - result_type=entities.ResultType.CONTINUE, - new_query=query - ) \ No newline at end of file + result_type=entities.ResultType.CONTINUE, new_query=query + ) diff --git a/pkg/pipeline/msgtrun/truncator.py b/pkg/pipeline/msgtrun/truncator.py index 4afaf9fb..9e8b8a6c 100644 --- a/pkg/pipeline/msgtrun/truncator.py +++ b/pkg/pipeline/msgtrun/truncator.py @@ -10,7 +10,7 @@ preregistered_truncators: list[typing.Type[Truncator]] = [] def truncator_class( - name: str + name: str, ) -> typing.Callable[[typing.Type[Truncator]], typing.Type[Truncator]]: """截断器类装饰器 @@ -20,6 +20,7 @@ def truncator_class( Returns: typing.Callable[[typing.Type[Truncator]], typing.Type[Truncator]]: 装饰器 """ + def decorator(cls: typing.Type[Truncator]) -> typing.Type[Truncator]: assert issubclass(cls, Truncator) @@ -33,13 +34,12 @@ def truncator_class( class Truncator(abc.ABC): - """消息截断器基类 - """ + """消息截断器基类""" name: str ap: app.Application - + def __init__(self, ap: app.Application): self.ap = ap diff --git a/pkg/pipeline/msgtrun/truncators/round.py b/pkg/pipeline/msgtrun/truncators/round.py index 46fce5f3..fa72a0e1 100644 --- a/pkg/pipeline/msgtrun/truncators/round.py +++ b/pkg/pipeline/msgtrun/truncators/round.py @@ -4,14 +4,12 @@ from .. import truncator from ....core import entities as core_entities -@truncator.truncator_class("round") +@truncator.truncator_class('round') class RoundTruncator(truncator.Truncator): - """前文回合数阶段器 - """ + """前文回合数阶段器""" async def truncate(self, query: core_entities.Query) -> core_entities.Query: - """截断 - """ + """截断""" max_round = query.pipeline_config['ai']['local-agent']['max-round'] temp_messages = [] @@ -26,7 +24,7 @@ class RoundTruncator(truncator.Truncator): current_round += 1 else: break - + query.messages = temp_messages[::-1] return query diff --git a/pkg/pipeline/pipelinemgr.py b/pkg/pipeline/pipelinemgr.py index b7eaaab4..8ca0d592 100644 --- a/pkg/pipeline/pipelinemgr.py +++ b/pkg/pipeline/pipelinemgr.py @@ -11,22 +11,39 @@ from ..entity.persistence import pipeline as persistence_pipeline from . import stage from ..platform.types import message as platform_message, events as platform_events from ..plugin import events +from ..utils import importutil -from .resprule import resprule -from .bansess import bansess -from .cntfilter import cntfilter -from .process import process -from .longtext import longtext -from .respback import respback -from .wrapper import wrapper -from .preproc import preproc -from .ratelimit import ratelimit -from .msgtrun import msgtrun +from . import ( + resprule, + bansess, + cntfilter, + process, + longtext, + respback, + wrapper, + preproc, + ratelimit, + msgtrun, +) + +importutil.import_modules_in_pkgs( + [ + resprule, + bansess, + cntfilter, + process, + longtext, + respback, + wrapper, + preproc, + ratelimit, + msgtrun, + ] +) -class StageInstContainer(): - """阶段实例容器 - """ +class StageInstContainer: + """阶段实例容器""" inst_name: str @@ -48,7 +65,12 @@ class RuntimePipeline: stage_containers: list[StageInstContainer] """阶段实例容器""" - def __init__(self, ap: app.Application, pipeline_entity: persistence_pipeline.LegacyPipeline, stage_containers: list[StageInstContainer]): + def __init__( + self, + ap: app.Application, + pipeline_entity: persistence_pipeline.LegacyPipeline, + stage_containers: list[StageInstContainer], + ): self.ap = ap self.pipeline_entity = pipeline_entity self.stage_containers = stage_containers @@ -57,9 +79,10 @@ class RuntimePipeline: query.pipeline_config = self.pipeline_entity.config await self.process_query(query) - async def _check_output(self, query: entities.Query, result: pipeline_entities.StageProcessResult): - """检查输出 - """ + async def _check_output( + self, query: entities.Query, result: pipeline_entities.StageProcessResult + ): + """检查输出""" if result.user_notice: # 处理str类型 @@ -68,22 +91,19 @@ class RuntimePipeline: platform_message.Plain(result.user_notice) ) elif isinstance(result.user_notice, list): - result.user_notice = platform_message.MessageChain( - *result.user_notice - ) + result.user_notice = platform_message.MessageChain(*result.user_notice) - if query.pipeline_config['output']['misc']['at-sender'] and isinstance(query.message_event, platform_events.GroupMessage): + if query.pipeline_config['output']['misc']['at-sender'] and isinstance( + query.message_event, platform_events.GroupMessage + ): result.user_notice.insert( - 0, - platform_message.At( - query.message_event.sender.id - ) + 0, platform_message.At(query.message_event.sender.id) ) await query.adapter.reply_message( message_source=query.message_event, message=result.user_notice, - quote_origin=query.pipeline_config['output']['misc']['quote-origin'] + quote_origin=query.pipeline_config['output']['misc']['quote-origin'], ) if result.debug_notice: self.ap.logger.debug(result.debug_notice) @@ -123,32 +143,44 @@ class RuntimePipeline: stage_container = self.stage_containers[i] query.current_stage = stage_container # 标记到 Query 对象里 - + result = stage_container.inst.process(query, stage_container.inst_name) if isinstance(result, typing.Coroutine): result = await result if isinstance(result, pipeline_entities.StageProcessResult): # 直接返回结果 - self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {result}") + self.ap.logger.debug( + f'Stage {stage_container.inst_name} processed query {query} res {result}' + ) await self._check_output(query, result) if result.result_type == pipeline_entities.ResultType.INTERRUPT: - self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}") + self.ap.logger.debug( + f'Stage {stage_container.inst_name} interrupted query {query}' + ) break elif result.result_type == pipeline_entities.ResultType.CONTINUE: query = result.new_query elif isinstance(result, typing.AsyncGenerator): # 生成器 - self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} gen") + self.ap.logger.debug( + f'Stage {stage_container.inst_name} processed query {query} gen' + ) async for sub_result in result: - self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {sub_result}") + self.ap.logger.debug( + f'Stage {stage_container.inst_name} processed query {query} res {sub_result}' + ) await self._check_output(query, sub_result) if sub_result.result_type == pipeline_entities.ResultType.INTERRUPT: - self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}") + self.ap.logger.debug( + f'Stage {stage_container.inst_name} interrupted query {query}' + ) break - elif sub_result.result_type == pipeline_entities.ResultType.CONTINUE: + elif ( + sub_result.result_type == pipeline_entities.ResultType.CONTINUE + ): query = sub_result.new_query await self._execute_from_stage(i + 1, query) break @@ -156,12 +188,14 @@ class RuntimePipeline: i += 1 async def process_query(self, query: entities.Query): - """处理请求 - """ + """处理请求""" try: - # ======== 触发 MessageReceived 事件 ======== - event_type = events.PersonMessageReceived if query.launcher_type == entities.LauncherTypes.PERSON else events.GroupMessageReceived + event_type = ( + events.PersonMessageReceived + if query.launcher_type == entities.LauncherTypes.PERSON + else events.GroupMessageReceived + ) event_ctx = await self.ap.plugin_mgr.emit_event( event=event_type( @@ -169,22 +203,26 @@ class RuntimePipeline: launcher_id=query.launcher_id, sender_id=query.sender_id, message_chain=query.message_chain, - query=query + query=query, ) ) if event_ctx.is_prevented_default(): return - - self.ap.logger.debug(f"Processing query {query}") + + self.ap.logger.debug(f'Processing query {query}') await self._execute_from_stage(0, query) except Exception as e: - inst_name = query.current_stage.inst_name if query.current_stage else 'unknown' - self.ap.logger.error(f"处理请求时出错 query_id={query.query_id} stage={inst_name} : {e}") - self.ap.logger.debug(f"Traceback: {traceback.format_exc()}") + inst_name = ( + query.current_stage.inst_name if query.current_stage else 'unknown' + ) + self.ap.logger.error( + f'处理请求时出错 query_id={query.query_id} stage={inst_name} : {e}' + ) + self.ap.logger.debug(f'Traceback: {traceback.format_exc()}') finally: - self.ap.logger.debug(f"Query {query} processed") + self.ap.logger.debug(f'Query {query} processed') class PipelineManager: @@ -203,7 +241,9 @@ class PipelineManager: self.pipelines = [] async def initialize(self): - self.stage_dict = {name: cls for name, cls in stage.preregistered_stages.items()} + self.stage_dict = { + name: cls for name, cls in stage.preregistered_stages.items() + } await self.load_pipelines_from_db() @@ -220,24 +260,31 @@ class PipelineManager: for pipeline in pipelines: await self.load_pipeline(pipeline) - async def load_pipeline(self, pipeline_entity: persistence_pipeline.LegacyPipeline | sqlalchemy.Row[persistence_pipeline.LegacyPipeline] | dict): - + async def load_pipeline( + self, + pipeline_entity: persistence_pipeline.LegacyPipeline + | sqlalchemy.Row[persistence_pipeline.LegacyPipeline] + | dict, + ): if isinstance(pipeline_entity, sqlalchemy.Row): - pipeline_entity = persistence_pipeline.LegacyPipeline(**pipeline_entity._mapping) + pipeline_entity = persistence_pipeline.LegacyPipeline( + **pipeline_entity._mapping + ) elif isinstance(pipeline_entity, dict): pipeline_entity = persistence_pipeline.LegacyPipeline(**pipeline_entity) # initialize stage containers according to pipeline_entity.stages stage_containers: list[StageInstContainer] = [] for stage_name in pipeline_entity.stages: - stage_containers.append(StageInstContainer( - inst_name=stage_name, - inst=self.stage_dict[stage_name](self.ap) - )) + stage_containers.append( + StageInstContainer( + inst_name=stage_name, inst=self.stage_dict[stage_name](self.ap) + ) + ) for stage_container in stage_containers: await stage_container.inst.initialize(pipeline_entity.config) - + runtime_pipeline = RuntimePipeline(self.ap, pipeline_entity, stage_containers) self.pipelines.append(runtime_pipeline) @@ -251,4 +298,4 @@ class PipelineManager: for pipeline in self.pipelines: if pipeline.pipeline_entity.uuid == uuid: self.pipelines.remove(pipeline) - return \ No newline at end of file + return diff --git a/pkg/pipeline/pool.py b/pkg/pipeline/pool.py index df4d0741..3da4e19b 100644 --- a/pkg/pipeline/pool.py +++ b/pkg/pipeline/pool.py @@ -47,7 +47,7 @@ class QueryPool: message_chain=message_chain, resp_messages=[], resp_message_chain=[], - adapter=adapter + adapter=adapter, ) self.queries.append(query) self.query_id_counter += 1 diff --git a/pkg/pipeline/preproc/preproc.py b/pkg/pipeline/preproc/preproc.py index 42bb5b4c..bab1127d 100644 --- a/pkg/pipeline/preproc/preproc.py +++ b/pkg/pipeline/preproc/preproc.py @@ -9,7 +9,7 @@ from ...plugin import events from ...platform.types import message as platform_message -@stage.stage_class("PreProcessor") +@stage.stage_class('PreProcessor') class PreProcessor(stage.PipelineStage): """请求预处理阶段 @@ -29,11 +29,12 @@ class PreProcessor(stage.PipelineStage): query: core_entities.Query, stage_inst_name: str, ) -> entities.StageProcessResult: - """处理 - """ + """处理""" session = await self.ap.sess_mgr.get_session(query) - conversation = await self.ap.sess_mgr.get_conversation(query, session, query.pipeline_config['ai']['local-agent']['prompt']) + conversation = await self.ap.sess_mgr.get_conversation( + query, session, query.pipeline_config['ai']['local-agent']['prompt'] + ) # 设置query query.session = session @@ -42,17 +43,26 @@ class PreProcessor(stage.PipelineStage): query.use_llm_model = conversation.use_llm_model - query.use_funcs = conversation.use_funcs if query.use_llm_model.model_entity.abilities.__contains__('tool_call') else None + query.use_funcs = ( + conversation.use_funcs + if query.use_llm_model.model_entity.abilities.__contains__('tool_call') + else None + ) query.variables = { - "session_id": f"{query.session.launcher_type.value}_{query.session.launcher_id}", - "conversation_id": conversation.uuid, - "msg_create_time": int(query.message_event.time) if query.message_event.time else int(datetime.datetime.now().timestamp()), + 'session_id': f'{query.session.launcher_type.value}_{query.session.launcher_id}', + 'conversation_id': conversation.uuid, + 'msg_create_time': int(query.message_event.time) + if query.message_event.time + else int(datetime.datetime.now().timestamp()), } # Check if this model supports vision, if not, remove all images # TODO this checking should be performed in runner, and in this stage, the image should be reserved - if query.pipeline_config['ai']['runner']['runner'] == 'local-agent' and not query.use_llm_model.model_entity.abilities.__contains__('vision'): + if ( + query.pipeline_config['ai']['runner']['runner'] == 'local-agent' + and not query.use_llm_model.model_entity.abilities.__contains__('vision') + ): for msg in query.messages: if isinstance(msg.content, list): for me in msg.content: @@ -61,16 +71,17 @@ class PreProcessor(stage.PipelineStage): content_list = [] - plain_text = "" + plain_text = '' for me in query.message_chain: if isinstance(me, platform_message.Plain): - content_list.append( - llm_entities.ContentElement.from_text(me.text) - ) + content_list.append(llm_entities.ContentElement.from_text(me.text)) plain_text += me.text elif isinstance(me, platform_message.Image): - if query.pipeline_config['ai']['runner']['runner'] != 'local-agent' or query.use_llm_model.model_entity.abilities.__contains__('vision'): + if ( + query.pipeline_config['ai']['runner']['runner'] != 'local-agent' + or query.use_llm_model.model_entity.abilities.__contains__('vision') + ): if me.base64 is not None: content_list.append( llm_entities.ContentElement.from_image_base64(me.base64) @@ -78,10 +89,7 @@ class PreProcessor(stage.PipelineStage): query.variables['user_message_text'] = plain_text - query.user_message = llm_entities.Message( - role='user', - content=content_list - ) + query.user_message = llm_entities.Message(role='user', content=content_list) # =========== 触发事件 PromptPreProcessing event_ctx = await self.ap.plugin_mgr.emit_event( @@ -89,7 +97,7 @@ class PreProcessor(stage.PipelineStage): session_name=f'{query.session.launcher_type.value}_{query.session.launcher_id}', default_prompt=query.prompt.messages, prompt=query.messages, - query=query + query=query, ) ) @@ -97,6 +105,5 @@ class PreProcessor(stage.PipelineStage): query.messages = event_ctx.event.prompt return entities.StageProcessResult( - result_type=entities.ResultType.CONTINUE, - new_query=query + result_type=entities.ResultType.CONTINUE, new_query=query ) diff --git a/pkg/pipeline/process/handler.py b/pkg/pipeline/process/handler.py index 879b4cfe..8a32bcfb 100644 --- a/pkg/pipeline/process/handler.py +++ b/pkg/pipeline/process/handler.py @@ -8,7 +8,6 @@ from .. import entities class MessageHandler(metaclass=abc.ABCMeta): - ap: app.Application def __init__(self, ap: app.Application): diff --git a/pkg/pipeline/process/handlers/chat.py b/pkg/pipeline/process/handlers/chat.py index 9d231dda..7943d8d1 100644 --- a/pkg/pipeline/process/handlers/chat.py +++ b/pkg/pipeline/process/handlers/chat.py @@ -3,33 +3,36 @@ from __future__ import annotations import typing import time import traceback -import json from .. import handler from ... import entities from ....core import entities as core_entities -from ....provider import entities as llm_entities from ....provider import runner as runner_module -from ....provider.runners import localagent, difysvapi, dashscopeapi from ....plugin import events from ....platform.types import message as platform_message +from ....utils import importutil +from ....provider import runners + +importutil.import_modules_in_pkg(runners) class ChatMessageHandler(handler.MessageHandler): - async def handle( self, query: core_entities.Query, ) -> typing.AsyncGenerator[entities.StageProcessResult, None]: - """处理 - """ + """处理""" # 调API # 生成器 # 触发插件事件 - event_class = events.PersonNormalMessageReceived if query.launcher_type == core_entities.LauncherTypes.PERSON else events.GroupNormalMessageReceived + event_class = ( + events.PersonNormalMessageReceived + if query.launcher_type == core_entities.LauncherTypes.PERSON + else events.GroupNormalMessageReceived + ) event_ctx = await self.ap.plugin_mgr.emit_event( event=event_class( @@ -37,7 +40,7 @@ class ChatMessageHandler(handler.MessageHandler): launcher_id=query.launcher_id, sender_id=query.sender_id, text_message=str(query.message_chain), - query=query + query=query, ) ) @@ -48,16 +51,13 @@ class ChatMessageHandler(handler.MessageHandler): query.resp_messages.append(mc) yield entities.StageProcessResult( - result_type=entities.ResultType.CONTINUE, - new_query=query + result_type=entities.ResultType.CONTINUE, new_query=query ) else: yield entities.StageProcessResult( - result_type=entities.ResultType.INTERRUPT, - new_query=query + result_type=entities.ResultType.INTERRUPT, new_query=query ) else: - if event_ctx.event.alter is not None: # if isinstance(event_ctx.event, str): # 现在暂时不考虑多模态alter query.user_message.content = event_ctx.event.alter @@ -67,48 +67,52 @@ class ChatMessageHandler(handler.MessageHandler): start_time = time.time() try: - for r in runner_module.preregistered_runners: - if r.name == query.pipeline_config["ai"]["runner"]["runner"]: + if r.name == query.pipeline_config['ai']['runner']['runner']: runner = r(self.ap, query.pipeline_config) break else: - raise ValueError(f"未找到请求运行器: {query.pipeline_config['ai']['runner']['runner']}") + raise ValueError( + f'未找到请求运行器: {query.pipeline_config["ai"]["runner"]["runner"]}' + ) async for result in runner.run(query): query.resp_messages.append(result) - self.ap.logger.info(f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}') + self.ap.logger.info( + f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}' + ) if result.content is not None: text_length += len(result.content) yield entities.StageProcessResult( - result_type=entities.ResultType.CONTINUE, - new_query=query + result_type=entities.ResultType.CONTINUE, new_query=query ) query.session.using_conversation.messages.append(query.user_message) query.session.using_conversation.messages.extend(query.resp_messages) except Exception as e: - - self.ap.logger.error(f'对话({query.query_id})请求失败: {type(e).__name__} {str(e)}') + self.ap.logger.error( + f'对话({query.query_id})请求失败: {type(e).__name__} {str(e)}' + ) - hide_exception_info = query.pipeline_config['output']['misc']['hide-exception'] + hide_exception_info = query.pipeline_config['output']['misc'][ + 'hide-exception' + ] yield entities.StageProcessResult( result_type=entities.ResultType.INTERRUPT, new_query=query, user_notice='请求失败' if hide_exception_info else f'{e}', error_notice=f'{e}', - debug_notice=traceback.format_exc() + debug_notice=traceback.format_exc(), ) finally: - await self.ap.ctr_mgr.usage.post_query_record( session_type=query.session.launcher_type.value, session_id=str(query.session.launcher_id), - query_ability_provider="LangBot.Chat", + query_ability_provider='LangBot.Chat', usage=text_length, model_name=query.use_model.name, response_seconds=int(time.time() - start_time), diff --git a/pkg/pipeline/process/handlers/command.py b/pkg/pipeline/process/handlers/command.py index b0316e1f..af1357b5 100644 --- a/pkg/pipeline/process/handlers/command.py +++ b/pkg/pipeline/process/handlers/command.py @@ -11,24 +11,29 @@ from ....platform.types import message as platform_message class CommandHandler(handler.MessageHandler): - async def handle( self, query: core_entities.Query, ) -> typing.AsyncGenerator[entities.StageProcessResult, None]: - """处理 - """ + """处理""" command_text = str(query.message_chain).strip()[1:] privilege = 1 - - if f'{query.launcher_type.value}_{query.launcher_id}' in self.ap.instance_config.data['admins']: + + if ( + f'{query.launcher_type.value}_{query.launcher_id}' + in self.ap.instance_config.data['admins'] + ): privilege = 2 spt = command_text.split(' ') - event_class = events.PersonCommandSent if query.launcher_type == core_entities.LauncherTypes.PERSON else events.GroupCommandSent + event_class = ( + events.PersonCommandSent + if query.launcher_type == core_entities.LauncherTypes.PERSON + else events.GroupCommandSent + ) event_ctx = await self.ap.plugin_mgr.emit_event( event=event_class( @@ -38,41 +43,35 @@ class CommandHandler(handler.MessageHandler): command=spt[0], params=spt[1:] if len(spt) > 1 else [], text_message=str(query.message_chain), - is_admin=(privilege==2), - query=query + is_admin=(privilege == 2), + query=query, ) ) if event_ctx.is_prevented_default(): - if event_ctx.event.reply is not None: mc = platform_message.MessageChain(event_ctx.event.reply) query.resp_messages.append(mc) yield entities.StageProcessResult( - result_type=entities.ResultType.CONTINUE, - new_query=query + result_type=entities.ResultType.CONTINUE, new_query=query ) else: yield entities.StageProcessResult( - result_type=entities.ResultType.INTERRUPT, - new_query=query + result_type=entities.ResultType.INTERRUPT, new_query=query ) else: - if event_ctx.event.alter is not None: - query.message_chain = platform_message.MessageChain([ - platform_message.Plain(event_ctx.event.alter) - ]) + query.message_chain = platform_message.MessageChain( + [platform_message.Plain(event_ctx.event.alter)] + ) session = await self.ap.sess_mgr.get_session(query) async for ret in self.ap.cmd_mgr.execute( - command_text=command_text, - query=query, - session=session + command_text=command_text, query=query, session=session ): if ret.error is not None: query.resp_messages.append( @@ -82,20 +81,18 @@ class CommandHandler(handler.MessageHandler): ) ) - self.ap.logger.info(f'命令({query.query_id})报错: {self.cut_str(str(ret.error))}') + self.ap.logger.info( + f'命令({query.query_id})报错: {self.cut_str(str(ret.error))}' + ) yield entities.StageProcessResult( - result_type=entities.ResultType.CONTINUE, - new_query=query + result_type=entities.ResultType.CONTINUE, new_query=query ) elif ret.text is not None or ret.image_url is not None: - - content: list[llm_entities.ContentElement]= [] + content: list[llm_entities.ContentElement] = [] if ret.text is not None: - content.append( - llm_entities.ContentElement.from_text(ret.text) - ) + content.append(llm_entities.ContentElement.from_text(ret.text)) if ret.image_url is not None: content.append( @@ -112,11 +109,9 @@ class CommandHandler(handler.MessageHandler): self.ap.logger.info(f'命令返回: {self.cut_str(str(content[0]))}') yield entities.StageProcessResult( - result_type=entities.ResultType.CONTINUE, - new_query=query + result_type=entities.ResultType.CONTINUE, new_query=query ) else: yield entities.StageProcessResult( - result_type=entities.ResultType.INTERRUPT, - new_query=query + result_type=entities.ResultType.INTERRUPT, new_query=query ) diff --git a/pkg/pipeline/process/process.py b/pkg/pipeline/process/process.py index 11f43d3c..64903552 100644 --- a/pkg/pipeline/process/process.py +++ b/pkg/pipeline/process/process.py @@ -1,18 +1,16 @@ from __future__ import annotations -from ...core import app, entities as core_entities +from ...core import entities as core_entities from . import handler from .handlers import chat, command from .. import entities -from .. import stage, entities -from ...core import entities as core_entities -from ...config import manager as cfg_mgr +from .. import stage -@stage.stage_class("MessageProcessor") +@stage.stage_class('MessageProcessor') class Processor(stage.PipelineStage): """请求实际处理阶段 - + 通过命令处理器和聊天处理器处理消息。 改写: @@ -35,11 +33,12 @@ class Processor(stage.PipelineStage): query: core_entities.Query, stage_inst_name: str, ) -> entities.StageProcessResult: - """处理 - """ + """处理""" message_text = str(query.message_chain).strip() - self.ap.logger.info(f"处理 {query.launcher_type.value}_{query.launcher_id} 的请求({query.query_id}): {message_text}") + self.ap.logger.info( + f'处理 {query.launcher_type.value}_{query.launcher_id} 的请求({query.query_id}): {message_text}' + ) async def generator(): cmd_prefix = self.ap.instance_config.data['command']['prefix'] @@ -50,5 +49,5 @@ class Processor(stage.PipelineStage): else: async for result in self.chat_handler.handle(query): yield result - + return generator() diff --git a/pkg/pipeline/ratelimit/algo.py b/pkg/pipeline/ratelimit/algo.py index d9baa801..3bcc347a 100644 --- a/pkg/pipeline/ratelimit/algo.py +++ b/pkg/pipeline/ratelimit/algo.py @@ -7,19 +7,19 @@ from ...core import app, entities as core_entities preregistered_algos: list[typing.Type[ReteLimitAlgo]] = [] + def algo_class(name: str): - def decorator(cls: typing.Type[ReteLimitAlgo]) -> typing.Type[ReteLimitAlgo]: cls.name = name preregistered_algos.append(cls) return cls - + return decorator class ReteLimitAlgo(metaclass=abc.ABCMeta): """限流算法抽象类""" - + name: str = None ap: app.Application @@ -31,11 +31,16 @@ class ReteLimitAlgo(metaclass=abc.ABCMeta): pass @abc.abstractmethod - async def require_access(self, query: core_entities.Query, launcher_type: str, launcher_id: typing.Union[int, str]) -> bool: + async def require_access( + self, + query: core_entities.Query, + launcher_type: str, + launcher_id: typing.Union[int, str], + ) -> bool: """进入处理流程 这个方法对等待是友好的,意味着算法可以实现在这里等待一段时间以控制速率。 - + Args: launcher_type (str): 请求者类型 群聊为 group 私聊为 person launcher_id (int): 请求者ID @@ -44,15 +49,19 @@ class ReteLimitAlgo(metaclass=abc.ABCMeta): bool: 是否允许进入处理流程,若返回false,则直接丢弃该请求 """ raise NotImplementedError - + @abc.abstractmethod - async def release_access(self, query: core_entities.Query, launcher_type: str, launcher_id: typing.Union[int, str]): + async def release_access( + self, + query: core_entities.Query, + launcher_type: str, + launcher_id: typing.Union[int, str], + ): """退出处理流程 Args: launcher_type (str): 请求者类型 群聊为 group 私聊为 person launcher_id (int): 请求者ID """ - + raise NotImplementedError - \ No newline at end of file diff --git a/pkg/pipeline/ratelimit/algos/fixedwin.py b/pkg/pipeline/ratelimit/algos/fixedwin.py index f17e93b8..32079a97 100644 --- a/pkg/pipeline/ratelimit/algos/fixedwin.py +++ b/pkg/pipeline/ratelimit/algos/fixedwin.py @@ -5,9 +5,9 @@ import typing from .. import algo from ....core import entities as core_entities + # 固定窗口算法 class SessionContainer: - wait_lock: asyncio.Lock records: dict[int, int] @@ -18,9 +18,8 @@ class SessionContainer: self.records = {} -@algo.algo_class("fixwin") +@algo.algo_class('fixwin') class FixedWindowAlgo(algo.ReteLimitAlgo): - containers_lock: asyncio.Lock """访问记录容器锁""" @@ -31,7 +30,12 @@ class FixedWindowAlgo(algo.ReteLimitAlgo): self.containers_lock = asyncio.Lock() self.containers = {} - async def require_access(self, query: core_entities.Query, launcher_type: str, launcher_id: typing.Union[int, str]) -> bool: + async def require_access( + self, + query: core_entities.Query, + launcher_type: str, + launcher_id: typing.Union[int, str], + ) -> bool: # 加锁,找容器 container: SessionContainer = None @@ -46,7 +50,6 @@ class FixedWindowAlgo(algo.ReteLimitAlgo): # 等待锁 async with container.wait_lock: - # 获取窗口大小和限制 window_size = query.pipeline_config['safety']['rate-limit']['window-length'] limitation = query.pipeline_config['safety']['rate-limit']['limitation'] @@ -69,13 +72,15 @@ class FixedWindowAlgo(algo.ReteLimitAlgo): if count >= limitation: if query.pipeline_config['safety']['rate-limit']['strategy'] == 'drop': return False - elif query.pipeline_config['safety']['rate-limit']['strategy'] == 'wait': + elif ( + query.pipeline_config['safety']['rate-limit']['strategy'] == 'wait' + ): # 等待下一窗口 await asyncio.sleep(window_size - time.time() % window_size) - + now = int(time.time()) now = now - now % window_size - + if now not in container.records: container.records = {} container.records[now] = 1 @@ -85,6 +90,11 @@ class FixedWindowAlgo(algo.ReteLimitAlgo): # 返回True return True - - async def release_access(self, query: core_entities.Query, launcher_type: str, launcher_id: typing.Union[int, str]): + + async def release_access( + self, + query: core_entities.Query, + launcher_type: str, + launcher_id: typing.Union[int, str], + ): pass diff --git a/pkg/pipeline/ratelimit/ratelimit.py b/pkg/pipeline/ratelimit/ratelimit.py index c74db978..23de4ec6 100644 --- a/pkg/pipeline/ratelimit/ratelimit.py +++ b/pkg/pipeline/ratelimit/ratelimit.py @@ -4,22 +4,25 @@ import typing from .. import entities, stage from . import algo -from .algos import fixedwin from ...core import entities as core_entities +from ...utils import importutil + +from . import algos + +importutil.import_modules_in_pkg(algos) -@stage.stage_class("RequireRateLimitOccupancy") -@stage.stage_class("ReleaseRateLimitOccupancy") +@stage.stage_class('RequireRateLimitOccupancy') +@stage.stage_class('ReleaseRateLimitOccupancy') class RateLimit(stage.PipelineStage): """限速器控制阶段 - + 不改写query,只检查是否需要限速。 """ algo: algo.ReteLimitAlgo async def initialize(self, pipeline_config: dict): - algo_name = 'fixwin' algo_class = None @@ -42,9 +45,8 @@ class RateLimit(stage.PipelineStage): entities.StageProcessResult, typing.AsyncGenerator[entities.StageProcessResult, None], ]: - """处理 - """ - if stage_inst_name == "RequireRateLimitOccupancy": + """处理""" + if stage_inst_name == 'RequireRateLimitOccupancy': if await self.algo.require_access( query, query.launcher_type.value, @@ -58,10 +60,10 @@ class RateLimit(stage.PipelineStage): return entities.StageProcessResult( result_type=entities.ResultType.INTERRUPT, new_query=query, - console_notice=f"根据限速规则忽略 {query.launcher_type.value}:{query.launcher_id} 消息", - user_notice=f"请求数超过限速器设定值,已丢弃本消息。" + console_notice=f'根据限速规则忽略 {query.launcher_type.value}:{query.launcher_id} 消息', + user_notice='请求数超过限速器设定值,已丢弃本消息。', ) - elif stage_inst_name == "ReleaseRateLimitOccupancy": + elif stage_inst_name == 'ReleaseRateLimitOccupancy': await self.algo.release_access( query, query.launcher_type.value, diff --git a/pkg/pipeline/respback/respback.py b/pkg/pipeline/respback/respback.py index 8c074d89..42c141c8 100644 --- a/pkg/pipeline/respback/respback.py +++ b/pkg/pipeline/respback/respback.py @@ -4,41 +4,38 @@ import random import asyncio -from ...core import app from ...platform.types import events as platform_events from ...platform.types import message as platform_message from .. import stage, entities from ...core import entities as core_entities -from ...config import manager as cfg_mgr -@stage.stage_class("SendResponseBackStage") +@stage.stage_class('SendResponseBackStage') class SendResponseBackStage(stage.PipelineStage): - """发送响应消息 - """ + """发送响应消息""" - async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: - """处理 - """ + async def process( + self, query: core_entities.Query, stage_inst_name: str + ) -> entities.StageProcessResult: + """处理""" - random_range = (query.pipeline_config['output']['force-delay']['min'], query.pipeline_config['output']['force-delay']['max']) + random_range = ( + query.pipeline_config['output']['force-delay']['min'], + query.pipeline_config['output']['force-delay']['max'], + ) random_delay = random.uniform(*random_range) - self.ap.logger.debug( - "根据规则强制延迟回复: %s s", - random_delay - ) + self.ap.logger.debug('根据规则强制延迟回复: %s s', random_delay) await asyncio.sleep(random_delay) - if query.pipeline_config['output']['misc']['at-sender'] and isinstance(query.message_event, platform_events.GroupMessage): + if query.pipeline_config['output']['misc']['at-sender'] and isinstance( + query.message_event, platform_events.GroupMessage + ): query.resp_message_chain[-1].insert( - 0, - platform_message.At( - query.message_event.sender.id - ) + 0, platform_message.At(query.message_event.sender.id) ) quote_origin = query.pipeline_config['output']['misc']['quote-origin'] @@ -46,10 +43,9 @@ class SendResponseBackStage(stage.PipelineStage): await query.adapter.reply_message( message_source=query.message_event, message=query.resp_message_chain[-1], - quote_origin=quote_origin + quote_origin=quote_origin, ) return entities.StageProcessResult( - result_type=entities.ResultType.CONTINUE, - new_query=query - ) \ No newline at end of file + result_type=entities.ResultType.CONTINUE, new_query=query + ) diff --git a/pkg/pipeline/resprule/entities.py b/pkg/pipeline/resprule/entities.py index a334e843..a0ba7807 100644 --- a/pkg/pipeline/resprule/entities.py +++ b/pkg/pipeline/resprule/entities.py @@ -4,7 +4,6 @@ from ...platform.types import message as platform_message class RuleJudgeResult(pydantic.BaseModel): - matching: bool = False replacement: platform_message.MessageChain = None diff --git a/pkg/pipeline/resprule/resprule.py b/pkg/pipeline/resprule/resprule.py index 08ba49e8..99402351 100644 --- a/pkg/pipeline/resprule/resprule.py +++ b/pkg/pipeline/resprule/resprule.py @@ -1,16 +1,18 @@ from __future__ import annotations -from ...core import app -from . import entities as rule_entities, rule -from .rules import atbot, prefix, regexp, random +from . import rule from .. import stage, entities from ...core import entities as core_entities -from ...config import manager as cfg_mgr +from ...utils import importutil + +from . import rules + +importutil.import_modules_in_pkg(rules) -@stage.stage_class("GroupRespondRuleCheckStage") +@stage.stage_class('GroupRespondRuleCheckStage') class GroupRespondRuleCheckStage(stage.PipelineStage): """群组响应规则检查器 @@ -21,8 +23,7 @@ class GroupRespondRuleCheckStage(stage.PipelineStage): """检查器实例""" async def initialize(self, pipeline_config: dict): - """初始化检查器 - """ + """初始化检查器""" self.rule_matchers = [] @@ -31,12 +32,12 @@ class GroupRespondRuleCheckStage(stage.PipelineStage): await rule_inst.initialize() self.rule_matchers.append(rule_inst) - async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: - + async def process( + self, query: core_entities.Query, stage_inst_name: str + ) -> entities.StageProcessResult: if query.launcher_type.value != 'group': # 只处理群消息 return entities.StageProcessResult( - result_type=entities.ResultType.CONTINUE, - new_query=query + result_type=entities.ResultType.CONTINUE, new_query=query ) rules = query.pipeline_config['trigger']['group-respond-rules'] @@ -48,7 +49,9 @@ class GroupRespondRuleCheckStage(stage.PipelineStage): # use_rule = rules[str(query.launcher_id)] for rule_matcher in self.rule_matchers: # 任意一个匹配就放行 - res = await rule_matcher.match(str(query.message_chain), query.message_chain, use_rule, query) + res = await rule_matcher.match( + str(query.message_chain), query.message_chain, use_rule, query + ) if res.matching: query.message_chain = res.replacement @@ -56,8 +59,7 @@ class GroupRespondRuleCheckStage(stage.PipelineStage): result_type=entities.ResultType.CONTINUE, new_query=query, ) - + return entities.StageProcessResult( - result_type=entities.ResultType.INTERRUPT, - new_query=query + result_type=entities.ResultType.INTERRUPT, new_query=query ) diff --git a/pkg/pipeline/resprule/rule.py b/pkg/pipeline/resprule/rule.py index ad69d8a0..3fdb0386 100644 --- a/pkg/pipeline/resprule/rule.py +++ b/pkg/pipeline/resprule/rule.py @@ -10,17 +10,19 @@ from ...platform.types import message as platform_message preregisetered_rules: list[typing.Type[GroupRespondRule]] = [] + def rule_class(name: str): def decorator(cls: typing.Type[GroupRespondRule]) -> typing.Type[GroupRespondRule]: cls.name = name preregisetered_rules.append(cls) return cls + return decorator class GroupRespondRule(metaclass=abc.ABCMeta): - """群组响应规则的抽象类 - """ + """群组响应规则的抽象类""" + name: str ap: app.Application @@ -37,8 +39,7 @@ class GroupRespondRule(metaclass=abc.ABCMeta): message_text: str, message_chain: platform_message.MessageChain, rule_dict: dict, - query: core_entities.Query + query: core_entities.Query, ) -> entities.RuleJudgeResult: - """判断消息是否匹配规则 - """ + """判断消息是否匹配规则""" raise NotImplementedError diff --git a/pkg/pipeline/resprule/rules/atbot.py b/pkg/pipeline/resprule/rules/atbot.py index a0b7a7c8..0f4845f8 100644 --- a/pkg/pipeline/resprule/rules/atbot.py +++ b/pkg/pipeline/resprule/rules/atbot.py @@ -7,21 +7,24 @@ from ....core import entities as core_entities from ....platform.types import message as platform_message -@rule_model.rule_class("at-bot") +@rule_model.rule_class('at-bot') class AtBotRule(rule_model.GroupRespondRule): - async def match( self, message_text: str, message_chain: platform_message.MessageChain, rule_dict: dict, - query: core_entities.Query + query: core_entities.Query, ) -> entities.RuleJudgeResult: - - if message_chain.has(platform_message.At(query.adapter.bot_account_id)) and rule_dict['at']: + if ( + message_chain.has(platform_message.At(query.adapter.bot_account_id)) + and rule_dict['at'] + ): message_chain.remove(platform_message.At(query.adapter.bot_account_id)) - if message_chain.has(platform_message.At(query.adapter.bot_account_id)): # 回复消息时会at两次,检查并删除重复的 + if message_chain.has( + platform_message.At(query.adapter.bot_account_id) + ): # 回复消息时会at两次,检查并删除重复的 message_chain.remove(platform_message.At(query.adapter.bot_account_id)) return entities.RuleJudgeResult( @@ -29,7 +32,4 @@ class AtBotRule(rule_model.GroupRespondRule): replacement=message_chain, ) - return entities.RuleJudgeResult( - matching=False, - replacement = message_chain - ) + return entities.RuleJudgeResult(matching=False, replacement=message_chain) diff --git a/pkg/pipeline/resprule/rules/prefix.py b/pkg/pipeline/resprule/rules/prefix.py index fb7bbcfc..c712d3e8 100644 --- a/pkg/pipeline/resprule/rules/prefix.py +++ b/pkg/pipeline/resprule/rules/prefix.py @@ -1,36 +1,30 @@ - from .. import rule as rule_model from .. import entities from ....core import entities as core_entities from ....platform.types import message as platform_message -@rule_model.rule_class("prefix") +@rule_model.rule_class('prefix') class PrefixRule(rule_model.GroupRespondRule): - async def match( self, message_text: str, message_chain: platform_message.MessageChain, rule_dict: dict, - query: core_entities.Query + query: core_entities.Query, ) -> entities.RuleJudgeResult: prefixes = rule_dict['prefix'] for prefix in prefixes: if message_text.startswith(prefix): - # 查找第一个plain元素 for me in message_chain: if isinstance(me, platform_message.Plain): - me.text = me.text[len(prefix):] + me.text = me.text[len(prefix) :] return entities.RuleJudgeResult( matching=True, replacement=message_chain, ) - return entities.RuleJudgeResult( - matching=False, - replacement=message_chain - ) + return entities.RuleJudgeResult(matching=False, replacement=message_chain) diff --git a/pkg/pipeline/resprule/rules/random.py b/pkg/pipeline/resprule/rules/random.py index 0178f2c4..535bfe6b 100644 --- a/pkg/pipeline/resprule/rules/random.py +++ b/pkg/pipeline/resprule/rules/random.py @@ -7,19 +7,17 @@ from ....core import entities as core_entities from ....platform.types import message as platform_message -@rule_model.rule_class("random") +@rule_model.rule_class('random') class RandomRespRule(rule_model.GroupRespondRule): - async def match( self, message_text: str, message_chain: platform_message.MessageChain, rule_dict: dict, - query: core_entities.Query + query: core_entities.Query, ) -> entities.RuleJudgeResult: random_rate = rule_dict['random'] - + return entities.RuleJudgeResult( - matching=random.random() < random_rate, - replacement=message_chain - ) \ No newline at end of file + matching=random.random() < random_rate, replacement=message_chain + ) diff --git a/pkg/pipeline/resprule/rules/regexp.py b/pkg/pipeline/resprule/rules/regexp.py index f5f5b3f6..daac0869 100644 --- a/pkg/pipeline/resprule/rules/regexp.py +++ b/pkg/pipeline/resprule/rules/regexp.py @@ -7,15 +7,14 @@ from ....core import entities as core_entities from ....platform.types import message as platform_message -@rule_model.rule_class("regexp") +@rule_model.rule_class('regexp') class RegExpRule(rule_model.GroupRespondRule): - async def match( self, message_text: str, message_chain: platform_message.MessageChain, rule_dict: dict, - query: core_entities.Query + query: core_entities.Query, ) -> entities.RuleJudgeResult: regexps = rule_dict['regexp'] @@ -27,8 +26,5 @@ class RegExpRule(rule_model.GroupRespondRule): matching=True, replacement=message_chain, ) - - return entities.RuleJudgeResult( - matching=False, - replacement=message_chain - ) + + return entities.RuleJudgeResult(matching=False, replacement=message_chain) diff --git a/pkg/pipeline/stage.py b/pkg/pipeline/stage.py index 859286d9..18636e9f 100644 --- a/pkg/pipeline/stage.py +++ b/pkg/pipeline/stage.py @@ -11,17 +11,15 @@ preregistered_stages: dict[str, PipelineStage] = {} def stage_class(name: str): - def decorator(cls): preregistered_stages[name] = cls return cls - + return decorator class PipelineStage(metaclass=abc.ABCMeta): - """流水线阶段 - """ + """流水线阶段""" ap: app.Application @@ -29,8 +27,7 @@ class PipelineStage(metaclass=abc.ABCMeta): self.ap = ap async def initialize(self, pipeline_config: dict): - """初始化 - """ + """初始化""" pass @abc.abstractmethod @@ -42,6 +39,5 @@ class PipelineStage(metaclass=abc.ABCMeta): entities.StageProcessResult, typing.AsyncGenerator[entities.StageProcessResult, None], ]: - """处理 - """ + """处理""" raise NotImplementedError diff --git a/pkg/pipeline/wrapper/wrapper.py b/pkg/pipeline/wrapper/wrapper.py index 6b12ca65..bca02527 100644 --- a/pkg/pipeline/wrapper/wrapper.py +++ b/pkg/pipeline/wrapper/wrapper.py @@ -3,21 +3,19 @@ from __future__ import annotations import typing -from ...core import app, entities as core_entities -from .. import entities -from .. import stage, entities from ...core import entities as core_entities -from ...config import manager as cfg_mgr +from .. import entities +from .. import stage from ...plugin import events from ...platform.types import message as platform_message -@stage.stage_class("ResponseWrapper") +@stage.stage_class('ResponseWrapper') class ResponseWrapper(stage.PipelineStage): """回复包装阶段 把回复的 message 包装成人类识读的形式。 - + 改写: - resp_message_chain """ @@ -30,36 +28,36 @@ class ResponseWrapper(stage.PipelineStage): query: core_entities.Query, stage_inst_name: str, ) -> typing.AsyncGenerator[entities.StageProcessResult, None]: - """处理 - """ + """处理""" # 如果 resp_messages[-1] 已经是 MessageChain 了 if isinstance(query.resp_messages[-1], platform_message.MessageChain): query.resp_message_chain.append(query.resp_messages[-1]) yield entities.StageProcessResult( - result_type=entities.ResultType.CONTINUE, - new_query=query + result_type=entities.ResultType.CONTINUE, new_query=query ) else: - if query.resp_messages[-1].role == 'command': - query.resp_message_chain.append(query.resp_messages[-1].get_content_platform_message_chain(prefix_text='[bot] ')) + query.resp_message_chain.append( + query.resp_messages[-1].get_content_platform_message_chain( + prefix_text='[bot] ' + ) + ) yield entities.StageProcessResult( - result_type=entities.ResultType.CONTINUE, - new_query=query + result_type=entities.ResultType.CONTINUE, new_query=query ) elif query.resp_messages[-1].role == 'plugin': - query.resp_message_chain.append(query.resp_messages[-1].get_content_platform_message_chain()) + query.resp_message_chain.append( + query.resp_messages[-1].get_content_platform_message_chain() + ) yield entities.StageProcessResult( - result_type=entities.ResultType.CONTINUE, - new_query=query + result_type=entities.ResultType.CONTINUE, new_query=query ) else: - if query.resp_messages[-1].role == 'assistant': result = query.resp_messages[-1] session = await self.ap.sess_mgr.get_session(query) @@ -79,39 +77,51 @@ class ResponseWrapper(stage.PipelineStage): prefix='', response_text=reply_text, finish_reason='stop', - funcs_called=[fc.function.name for fc in result.tool_calls] if result.tool_calls is not None else [], - query=query + funcs_called=[ + fc.function.name for fc in result.tool_calls + ] + if result.tool_calls is not None + else [], + query=query, ) ) if event_ctx.is_prevented_default(): yield entities.StageProcessResult( result_type=entities.ResultType.INTERRUPT, - new_query=query + new_query=query, ) else: if event_ctx.event.reply is not None: - - query.resp_message_chain.append(platform_message.MessageChain(event_ctx.event.reply)) + query.resp_message_chain.append( + platform_message.MessageChain(event_ctx.event.reply) + ) else: - - query.resp_message_chain.append(result.get_content_platform_message_chain()) + query.resp_message_chain.append( + result.get_content_platform_message_chain() + ) yield entities.StageProcessResult( result_type=entities.ResultType.CONTINUE, - new_query=query + new_query=query, ) - if result.tool_calls is not None and len(result.tool_calls) > 0: # 有函数调用 - + if ( + result.tool_calls is not None and len(result.tool_calls) > 0 + ): # 有函数调用 function_names = [tc.function.name for tc in result.tool_calls] reply_text = f'调用函数 {".".join(function_names)}...' - query.resp_message_chain.append(platform_message.MessageChain([platform_message.Plain(reply_text)])) + query.resp_message_chain.append( + platform_message.MessageChain( + [platform_message.Plain(reply_text)] + ) + ) - if query.pipeline_config['output']['misc']['track-function-calls']: - + if query.pipeline_config['output']['misc'][ + 'track-function-calls' + ]: event_ctx = await self.ap.plugin_mgr.emit_event( event=events.NormalMessageResponded( launcher_type=query.launcher_type.value, @@ -121,26 +131,36 @@ class ResponseWrapper(stage.PipelineStage): prefix='', response_text=reply_text, finish_reason='stop', - funcs_called=[fc.function.name for fc in result.tool_calls] if result.tool_calls is not None else [], - query=query + funcs_called=[ + fc.function.name for fc in result.tool_calls + ] + if result.tool_calls is not None + else [], + query=query, ) ) if event_ctx.is_prevented_default(): yield entities.StageProcessResult( result_type=entities.ResultType.INTERRUPT, - new_query=query + new_query=query, ) else: if event_ctx.event.reply is not None: - - query.resp_message_chain.append(platform_message.MessageChain(event_ctx.event.reply)) + query.resp_message_chain.append( + platform_message.MessageChain( + event_ctx.event.reply + ) + ) else: - - query.resp_message_chain.append(platform_message.MessageChain([platform_message.Plain(reply_text)])) + query.resp_message_chain.append( + platform_message.MessageChain( + [platform_message.Plain(reply_text)] + ) + ) yield entities.StageProcessResult( result_type=entities.ResultType.CONTINUE, - new_query=query + new_query=query, ) diff --git a/pkg/platform/adapter.py b/pkg/platform/adapter.py index 42ea75e0..61ff32cd 100644 --- a/pkg/platform/adapter.py +++ b/pkg/platform/adapter.py @@ -17,7 +17,7 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta): bot_account_id: int """机器人账号ID,需要在初始化时设置""" - + config: dict ap: app.Application @@ -32,14 +32,11 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta): self.config = config self.ap = ap - async def send_message( - self, - target_type: str, - target_id: str, - message: platform_message.MessageChain + async def send_message( + self, target_type: str, target_id: str, message: platform_message.MessageChain ): """主动发送消息 - + Args: target_type (str): 目标类型,`person`或`group` target_id (str): 目标ID @@ -51,7 +48,7 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta): self, message_source: platform_events.MessageEvent, message: platform_message.MessageChain, - quote_origin: bool = False + quote_origin: bool = False, ): """回复消息 @@ -69,23 +66,27 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta): def register_listener( self, event_type: typing.Type[platform_message.Event], - callback: typing.Callable[[platform_message.Event, MessagePlatformAdapter], None] + callback: typing.Callable[ + [platform_message.Event, MessagePlatformAdapter], None + ], ): """注册事件监听器 - + Args: event_type (typing.Type[platform.types.Event]): 事件类型 callback (typing.Callable[[platform.types.Event], None]): 回调函数,接收一个参数,为事件 """ raise NotImplementedError - + def unregister_listener( self, event_type: typing.Type[platform_message.Event], - callback: typing.Callable[[platform_message.Event, MessagePlatformAdapter], None] + callback: typing.Callable[ + [platform_message.Event, MessagePlatformAdapter], None + ], ): """注销事件监听器 - + Args: event_type (typing.Type[platform.types.Event]): 事件类型 callback (typing.Callable[[platform.types.Event], None]): 回调函数,接收一个参数,为事件 @@ -98,7 +99,7 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta): async def kill(self) -> bool: """关闭适配器 - + Returns: bool: 是否成功关闭,热重载时若此函数返回False则不会重载MessageSource底层 """ @@ -107,6 +108,7 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta): class MessageConverter: """消息链转换器基类""" + @staticmethod def yiri2target(message_chain: platform_message.MessageChain): """将源平台消息链转换为目标平台消息链 diff --git a/pkg/platform/botmgr.py b/pkg/platform/botmgr.py index 6073243d..507f067e 100644 --- a/pkg/platform/botmgr.py +++ b/pkg/platform/botmgr.py @@ -1,23 +1,16 @@ from __future__ import annotations -import json -import os import sys -import logging import asyncio import traceback import sqlalchemy -from .sources import qqofficial # FriendMessage, Image, MessageChain, Plain from . import adapter as msadapter from ..core import app, entities as core_entities, taskmgr -from ..plugin import events -from .types import message as platform_message from .types import events as platform_events -from .types import entities as platform_entities from ..discover import engine @@ -25,6 +18,7 @@ from ..entity.persistence import bot as persistence_bot # 处理 3.4 移除了 YiriMirai 之后,插件的兼容性问题 from . import types as mirai + sys.modules['mirai'] = mirai @@ -43,7 +37,12 @@ class RuntimeBot: task_context: taskmgr.TaskContext - def __init__(self, ap: app.Application, bot_entity: persistence_bot.Bot, adapter: msadapter.MessagePlatformAdapter): + def __init__( + self, + ap: app.Application, + bot_entity: persistence_bot.Bot, + adapter: msadapter.MessagePlatformAdapter, + ): self.ap = ap self.bot_entity = bot_entity self.enable = bot_entity.enable @@ -51,9 +50,10 @@ class RuntimeBot: self.task_context = taskmgr.TaskContext() async def initialize(self): - - async def on_friend_message(event: platform_events.FriendMessage, adapter: msadapter.MessagePlatformAdapter): - + async def on_friend_message( + event: platform_events.FriendMessage, + adapter: msadapter.MessagePlatformAdapter, + ): await self.ap.query_pool.add_query( bot_uuid=self.bot_entity.uuid, launcher_type=core_entities.LauncherTypes.PERSON, @@ -64,8 +64,10 @@ class RuntimeBot: adapter=adapter, ) - async def on_group_message(event: platform_events.GroupMessage, adapter: msadapter.MessagePlatformAdapter): - + async def on_group_message( + event: platform_events.GroupMessage, + adapter: msadapter.MessagePlatformAdapter, + ): await self.ap.query_pool.add_query( bot_uuid=self.bot_entity.uuid, launcher_type=core_entities.LauncherTypes.GROUP, @@ -76,17 +78,10 @@ class RuntimeBot: adapter=adapter, ) - self.adapter.register_listener( - platform_events.FriendMessage, - on_friend_message - ) - self.adapter.register_listener( - platform_events.GroupMessage, - on_group_message - ) + self.adapter.register_listener(platform_events.FriendMessage, on_friend_message) + self.adapter.register_listener(platform_events.GroupMessage, on_group_message) async def run(self): - async def exception_wrapper(): try: self.task_context.set_current_action('Running...') @@ -98,16 +93,19 @@ class RuntimeBot: return self.task_context.set_current_action('Exited with error.') self.task_context.log(f'平台适配器运行出错: {e}') - self.task_context.log(f"Traceback: {traceback.format_exc()}") + self.task_context.log(f'Traceback: {traceback.format_exc()}') self.ap.logger.error(f'平台适配器运行出错: {e}') - self.ap.logger.debug(f"Traceback: {traceback.format_exc()}") + self.ap.logger.debug(f'Traceback: {traceback.format_exc()}') self.task_wrapper = self.ap.task_mgr.create_task( exception_wrapper(), - kind="platform-adapter", - name=f"platform-adapter-{self.adapter.__class__.__name__}", + kind='platform-adapter', + name=f'platform-adapter-{self.adapter.__class__.__name__}', context=self.task_context, - scopes=[core_entities.LifecycleControlScope.APPLICATION, core_entities.LifecycleControlScope.PLATFORM] + scopes=[ + core_entities.LifecycleControlScope.APPLICATION, + core_entities.LifecycleControlScope.PLATFORM, + ], ) async def shutdown(self): @@ -118,7 +116,6 @@ class RuntimeBot: # 控制QQ消息输入输出的类 class PlatformManager: - # ====== 4.0 ====== ap: app.Application = None @@ -129,18 +126,20 @@ class PlatformManager: adapter_dict: dict[str, type[msadapter.MessagePlatformAdapter]] def __init__(self, ap: app.Application = None): - self.ap = ap self.bots = [] self.adapter_components = [] self.adapter_dict = {} - - async def initialize(self): - self.adapter_components = self.ap.discover.get_components_by_kind('MessagePlatformAdapter') + async def initialize(self): + self.adapter_components = self.ap.discover.get_components_by_kind( + 'MessagePlatformAdapter' + ) adapter_dict: dict[str, type[msadapter.MessagePlatformAdapter]] = {} for component in self.adapter_components: - adapter_dict[component.metadata.name] = component.get_python_component_class() + adapter_dict[component.metadata.name] = ( + component.get_python_component_class() + ) self.adapter_dict = adapter_dict await self.load_bots_from_db() @@ -158,12 +157,15 @@ class PlatformManager: ) bots = result.all() - + for bot in bots: # load all bots here, enable or disable will be handled in runtime await self.load_bot(bot) - async def load_bot(self, bot_entity: persistence_bot.Bot | sqlalchemy.Row[persistence_bot.Bot] | dict) -> RuntimeBot: + async def load_bot( + self, + bot_entity: persistence_bot.Bot | sqlalchemy.Row[persistence_bot.Bot] | dict, + ) -> RuntimeBot: """加载机器人""" if isinstance(bot_entity, sqlalchemy.Row): bot_entity = persistence_bot.Bot(**bot_entity._mapping) @@ -171,14 +173,11 @@ class PlatformManager: bot_entity = persistence_bot.Bot(**bot_entity) adapter_inst = self.adapter_dict[bot_entity.adapter]( - bot_entity.adapter_config, - self.ap + bot_entity.adapter_config, self.ap ) runtime_bot = RuntimeBot( - ap=self.ap, - bot_entity=bot_entity, - adapter=adapter_inst + ap=self.ap, bot_entity=bot_entity, adapter=adapter_inst ) await runtime_bot.initialize() @@ -186,7 +185,7 @@ class PlatformManager: self.bots.append(runtime_bot) return runtime_bot - + async def get_bot_by_uuid(self, bot_uuid: str) -> RuntimeBot | None: for bot in self.bots: if bot.bot_entity.uuid == bot_uuid: @@ -202,24 +201,28 @@ class PlatformManager: return def get_available_adapters_info(self) -> list[dict]: - return [ - component.to_plain_dict() - for component in self.adapter_components - ] + return [component.to_plain_dict() for component in self.adapter_components] def get_available_adapter_info_by_name(self, name: str) -> dict | None: for component in self.adapter_components: if component.metadata.name == name: return component.to_plain_dict() return None - - def get_available_adapter_manifest_by_name(self, name: str) -> engine.Component | None: + + def get_available_adapter_manifest_by_name( + self, name: str + ) -> engine.Component | None: for component in self.adapter_components: if component.metadata.name == name: return component return None - async def write_back_config(self, adapter_name: str, adapter_inst: msadapter.MessagePlatformAdapter, config: dict): + async def write_back_config( + self, + adapter_name: str, + adapter_inst: msadapter.MessagePlatformAdapter, + config: dict, + ): # index = -2 # for i, adapter in enumerate(self.adapters): @@ -251,7 +254,7 @@ class PlatformManager: # TODO implement this pass - async def run(self): + async def run(self): # This method will only be called when the application launching for bot in self.bots: if bot.enable: diff --git a/pkg/platform/sources/aiocqhttp.py b/pkg/platform/sources/aiocqhttp.py index 9149e427..48116507 100644 --- a/pkg/platform/sources/aiocqhttp.py +++ b/pkg/platform/sources/aiocqhttp.py @@ -2,24 +2,23 @@ from __future__ import annotations import typing import asyncio import traceback -import time import datetime import aiocqhttp -import aiohttp from .. import adapter -from ...pipeline.longtext.strategies import forward from ...core import app from ..types import message as platform_message from ..types import events as platform_events from ..types import entities as platform_entities from ...utils import image -class AiocqhttpMessageConverter(adapter.MessageConverter): +class AiocqhttpMessageConverter(adapter.MessageConverter): @staticmethod - async def yiri2target(message_chain: platform_message.MessageChain) -> typing.Tuple[list, int, datetime.datetime]: + async def yiri2target( + message_chain: platform_message.MessageChain, + ) -> typing.Tuple[list, int, datetime.datetime]: msg_list = aiocqhttp.Message() msg_id = 0 @@ -35,7 +34,7 @@ class AiocqhttpMessageConverter(adapter.MessageConverter): arg = '' if msg.base64: arg = msg.base64 - msg_list.append(aiocqhttp.MessageSegment.image(f"base64://{arg}")) + msg_list.append(aiocqhttp.MessageSegment.image(f'base64://{arg}')) elif msg.url: arg = msg.url msg_list.append(aiocqhttp.MessageSegment.image(arg)) @@ -45,12 +44,12 @@ class AiocqhttpMessageConverter(adapter.MessageConverter): elif type(msg) is platform_message.At: msg_list.append(aiocqhttp.MessageSegment.at(msg.target)) elif type(msg) is platform_message.AtAll: - msg_list.append(aiocqhttp.MessageSegment.at("all")) + msg_list.append(aiocqhttp.MessageSegment.at('all')) elif type(msg) is platform_message.Voice: arg = '' if msg.base64: arg = msg.base64 - msg_list.append(aiocqhttp.MessageSegment.record(f"base64://{arg}")) + msg_list.append(aiocqhttp.MessageSegment.record(f'base64://{arg}')) elif msg.url: arg = msg.url msg_list.append(aiocqhttp.MessageSegment.record(arg)) @@ -58,10 +57,15 @@ class AiocqhttpMessageConverter(adapter.MessageConverter): arg = msg.path msg_list.append(aiocqhttp.MessageSegment.record(msg.path)) elif type(msg) is platform_message.Forward: - for node in msg.node_list: - msg_list.extend((await AiocqhttpMessageConverter.yiri2target(node.message_chain))[0]) - + msg_list.extend( + ( + await AiocqhttpMessageConverter.yiri2target( + node.message_chain + ) + )[0] + ) + else: msg_list.append(aiocqhttp.MessageSegment.text(str(msg))) @@ -78,20 +82,26 @@ class AiocqhttpMessageConverter(adapter.MessageConverter): ) for msg in message: - if msg.type == "at": - if msg.data["qq"] == "all": + if msg.type == 'at': + if msg.data['qq'] == 'all': yiri_msg_list.append(platform_message.AtAll()) else: yiri_msg_list.append( platform_message.At( - target=msg.data["qq"], + target=msg.data['qq'], ) ) - elif msg.type == "text": - yiri_msg_list.append(platform_message.Plain(text=msg.data["text"])) - elif msg.type == "image": - image_base64, image_format = await image.qq_image_url_to_base64(msg.data['url']) - yiri_msg_list.append(platform_message.Image(base64=f"data:image/{image_format};base64,{image_base64}")) + elif msg.type == 'text': + yiri_msg_list.append(platform_message.Plain(text=msg.data['text'])) + elif msg.type == 'image': + image_base64, image_format = await image.qq_image_url_to_base64( + msg.data['url'] + ) + yiri_msg_list.append( + platform_message.Image( + base64=f'data:image/{image_format};base64,{image_base64}' + ) + ) chain = platform_message.MessageChain(yiri_msg_list) @@ -99,7 +109,6 @@ class AiocqhttpMessageConverter(adapter.MessageConverter): class AiocqhttpEventConverter(adapter.EventConverter): - @staticmethod async def yiri2target(event: platform_events.MessageEvent, bot_account_id: int): return event.source_platform_object @@ -110,49 +119,50 @@ class AiocqhttpEventConverter(adapter.EventConverter): event.message, event.message_id ) - if event.message_type == "group": - permission = "MEMBER" + if event.message_type == 'group': + permission = 'MEMBER' - if "role" in event.sender: - if event.sender["role"] == "admin": - permission = "ADMINISTRATOR" - elif event.sender["role"] == "owner": - permission = "OWNER" + if 'role' in event.sender: + if event.sender['role'] == 'admin': + permission = 'ADMINISTRATOR' + elif event.sender['role'] == 'owner': + permission = 'OWNER' converted_event = platform_events.GroupMessage( sender=platform_entities.GroupMember( - id=event.sender["user_id"], # message_seq 放哪? - member_name=event.sender["nickname"], + id=event.sender['user_id'], # message_seq 放哪? + member_name=event.sender['nickname'], permission=permission, group=platform_entities.Group( id=event.group_id, - name=event.sender["nickname"], + name=event.sender['nickname'], permission=platform_entities.Permission.Member, ), - special_title=event.sender["title"] if "title" in event.sender else "", + special_title=event.sender['title'] + if 'title' in event.sender + else '', join_timestamp=0, last_speak_timestamp=0, mute_time_remaining=0, ), message_chain=yiri_chain, time=event.time, - source_platform_object=event + source_platform_object=event, ) return converted_event - elif event.message_type == "private": + elif event.message_type == 'private': return platform_events.FriendMessage( sender=platform_entities.Friend( - id=event.sender["user_id"], - nickname=event.sender["nickname"], - remark="", + id=event.sender['user_id'], + nickname=event.sender['nickname'], + remark='', ), message_chain=yiri_chain, time=event.time, - source_platform_object=event + source_platform_object=event, ) class AiocqhttpAdapter(adapter.MessagePlatformAdapter): - bot: aiocqhttp.CQHttp bot_account_id: int @@ -170,14 +180,14 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter): async def shutdown_trigger_placeholder(): while True: await asyncio.sleep(1) - + self.config['shutdown_trigger'] = shutdown_trigger_placeholder self.ap = ap - if "access-token" in config: - self.bot = aiocqhttp.CQHttp(access_token=config["access-token"]) - del self.config["access-token"] + if 'access-token' in config: + self.bot = aiocqhttp.CQHttp(access_token=config['access-token']) + del self.config['access-token'] else: self.bot = aiocqhttp.CQHttp() @@ -186,9 +196,9 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter): ): aiocq_msg = (await AiocqhttpMessageConverter.yiri2target(message))[0] - if target_type == "group": + if target_type == 'group': await self.bot.send_group_msg(group_id=int(target_id), message=aiocq_msg) - elif target_type == "person": + elif target_type == 'person': await self.bot.send_private_msg(user_id=int(target_id), message=aiocq_msg) async def reply_message( @@ -196,16 +206,17 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter): message_source: platform_events.MessageEvent, message: platform_message.MessageChain, quote_origin: bool = False, - ): - aiocq_event = await AiocqhttpEventConverter.yiri2target(message_source, self.bot_account_id) + ): + aiocq_event = await AiocqhttpEventConverter.yiri2target( + message_source, self.bot_account_id + ) aiocq_msg = (await AiocqhttpMessageConverter.yiri2target(message))[0] if quote_origin: - aiocq_msg = aiocqhttp.MessageSegment.reply(aiocq_event.message_id) + aiocq_msg + aiocq_msg = ( + aiocqhttp.MessageSegment.reply(aiocq_event.message_id) + aiocq_msg + ) - return await self.bot.send( - aiocq_event, - aiocq_msg - ) + return await self.bot.send(aiocq_event, aiocq_msg) async def is_muted(self, group_id: int) -> bool: return False @@ -213,24 +224,30 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter): def register_listener( self, event_type: typing.Type[platform_events.Event], - callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], + callback: typing.Callable[ + [platform_events.Event, adapter.MessagePlatformAdapter], None + ], ): async def on_message(event: aiocqhttp.Event): self.bot_account_id = event.self_id try: - return await callback(await self.event_converter.target2yiri(event), self) - except: + return await callback( + await self.event_converter.target2yiri(event), self + ) + except Exception: traceback.print_exc() if event_type == platform_events.GroupMessage: - self.bot.on_message("group")(on_message) + self.bot.on_message('group')(on_message) elif event_type == platform_events.FriendMessage: - self.bot.on_message("private")(on_message) + self.bot.on_message('private')(on_message) def unregister_listener( self, event_type: typing.Type[platform_events.Event], - callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], + callback: typing.Callable[ + [platform_events.Event, adapter.MessagePlatformAdapter], None + ], ): return super().unregister_listener(event_type, callback) diff --git a/pkg/platform/sources/dingtalk.py b/pkg/platform/sources/dingtalk.py index aa768039..52f4a832 100644 --- a/pkg/platform/sources/dingtalk.py +++ b/pkg/platform/sources/dingtalk.py @@ -1,37 +1,30 @@ - import traceback import typing from libs.dingtalk_api.dingtalkevent import DingTalkEvent from pkg.platform.types import message as platform_message from pkg.platform.adapter import MessagePlatformAdapter -from pkg.platform.types import events as platform_events, message as platform_message -from pkg.core import app from .. import adapter -from ...pipeline.longtext.strategies import forward from ...core import app -from ..types import message as platform_message from ..types import events as platform_events from ..types import entities as platform_entities -from ...command.errors import ParamNotEnoughError from libs.dingtalk_api.api import DingTalkClient import datetime class DingTalkMessageConverter(adapter.MessageConverter): - @staticmethod - async def yiri2target( - message_chain:platform_message.MessageChain - ): + async def yiri2target(message_chain: platform_message.MessageChain): for msg in message_chain: if type(msg) is platform_message.Plain: return msg.text @staticmethod - async def target2yiri(event:DingTalkEvent, bot_name:str): + async def target2yiri(event: DingTalkEvent, bot_name: str): yiri_msg_list = [] yiri_msg_list.append( - platform_message.Source(id = event.incoming_message.message_id,time=datetime.datetime.now()) + platform_message.Source( + id=event.incoming_message.message_id, time=datetime.datetime.now() + ) ) for atUser in event.incoming_message.at_users: @@ -39,7 +32,7 @@ class DingTalkMessageConverter(adapter.MessageConverter): yiri_msg_list.append(platform_message.At(target=bot_name)) if event.content: - text_content = event.content.replace("@"+bot_name, '') + text_content = event.content.replace('@' + bot_name, '') yiri_msg_list.append(platform_message.Plain(text=text_content)) if event.picture: yiri_msg_list.append(platform_message.Image(base64=event.picture)) @@ -47,60 +40,51 @@ class DingTalkMessageConverter(adapter.MessageConverter): yiri_msg_list.append(platform_message.Voice(base64=event.audio)) chain = platform_message.MessageChain(yiri_msg_list) - + return chain class DingTalkEventConverter(adapter.EventConverter): - @staticmethod - async def yiri2target( - event:platform_events.MessageEvent - ): + async def yiri2target(event: platform_events.MessageEvent): return event.source_platform_object @staticmethod - async def target2yiri( - event:DingTalkEvent, - bot_name:str - ): - + async def target2yiri(event: DingTalkEvent, bot_name: str): message_chain = await DingTalkMessageConverter.target2yiri(event, bot_name) - if event.conversation == 'FriendMessage': - return platform_events.FriendMessage( sender=platform_entities.Friend( id=event.incoming_message.sender_id, - nickname = event.incoming_message.sender_nick, - remark="" + nickname=event.incoming_message.sender_nick, + remark='', ), - message_chain = message_chain, - time = event.incoming_message.create_at, + message_chain=message_chain, + time=event.incoming_message.create_at, source_platform_object=event, ) elif event.conversation == 'GroupMessage': sender = platform_entities.GroupMember( - id = event.incoming_message.sender_id, + id=event.incoming_message.sender_id, member_name=event.incoming_message.sender_nick, - permission= 'MEMBER', - group = platform_entities.Group( - id = event.incoming_message.conversation_id, - name = event.incoming_message.conversation_title, - permission=platform_entities.Permission.Member + permission='MEMBER', + group=platform_entities.Group( + id=event.incoming_message.conversation_id, + name=event.incoming_message.conversation_title, + permission=platform_entities.Permission.Member, ), special_title='', join_timestamp=0, last_speak_timestamp=0, - mute_time_remaining=0 + mute_time_remaining=0, ) time = event.incoming_message.create_at return platform_events.GroupMessage( - sender =sender, - message_chain = message_chain, - time = time, - source_platform_object=event + sender=sender, + message_chain=message_chain, + time=time, + source_platform_object=event, ) @@ -112,28 +96,28 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter): event_converter: DingTalkEventConverter = DingTalkEventConverter() config: dict - def __init__(self,config:dict,ap:app.Application): + def __init__(self, config: dict, ap: app.Application): self.config = config self.ap = ap required_keys = [ - "client_id", - "client_secret", - "robot_name", - "robot_code", + 'client_id', + 'client_secret', + 'robot_name', + 'robot_code', ] missing_keys = [key for key in required_keys if key not in config] if missing_keys: - raise ParamNotEnoughError("钉钉缺少相关配置项,请查看文档或联系管理员") + raise Exception('钉钉缺少相关配置项,请查看文档或联系管理员') + + self.bot_account_id = self.config['robot_name'] - self.bot_account_id = self.config["robot_name"] - self.bot = DingTalkClient( - client_id=config["client_id"], - client_secret=config["client_secret"], - robot_name = config["robot_name"], - robot_code=config["robot_code"] + client_id=config['client_id'], + client_secret=config['client_secret'], + robot_name=config['robot_name'], + robot_code=config['robot_code'], ) - + async def reply_message( self, message_source: platform_events.MessageEvent, @@ -146,17 +130,16 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter): incoming_message = event.incoming_message content = await DingTalkMessageConverter.yiri2target(message) - await self.bot.send_message(content,incoming_message) - + await self.bot.send_message(content, incoming_message) async def send_message( self, target_type: str, target_id: str, message: platform_message.MessageChain ): content = await DingTalkMessageConverter.yiri2target(message) if target_type == 'person': - await self.bot.send_proactive_message_to_one(target_id,content) + await self.bot.send_proactive_message_to_one(target_id, content) if target_type == 'group': - await self.bot.send_proactive_message_to_group(target_id,content) + await self.bot.send_proactive_message_to_group(target_id, content) def register_listener( self, @@ -168,15 +151,18 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter): async def on_message(event: DingTalkEvent): try: return await callback( - await self.event_converter.target2yiri(event, self.config["robot_name"]), self + await self.event_converter.target2yiri( + event, self.config['robot_name'] + ), + self, ) - except: + except Exception: traceback.print_exc() if event_type == platform_events.FriendMessage: - self.bot.on_message("FriendMessage")(on_message) + self.bot.on_message('FriendMessage')(on_message) elif event_type == platform_events.GroupMessage: - self.bot.on_message("GroupMessage")(on_message) + self.bot.on_message('GroupMessage')(on_message) async def run_async(self): await self.bot.start() @@ -187,7 +173,8 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter): async def unregister_listener( self, event_type: type, - callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None], + callback: typing.Callable[ + [platform_events.Event, MessagePlatformAdapter], None + ], ): return super().unregister_listener(event_type, callback) - diff --git a/pkg/platform/sources/discord.py b/pkg/platform/sources/discord.py index 961b031a..07dd586f 100644 --- a/pkg/platform/sources/discord.py +++ b/pkg/platform/sources/discord.py @@ -3,39 +3,32 @@ from __future__ import annotations import discord import typing -import asyncio -import traceback -import time import re import base64 import uuid -import json import os import datetime import aiohttp from .. import adapter -from ...pipeline.longtext.strategies import forward from ...core import app from ..types import message as platform_message from ..types import events as platform_events from ..types import entities as platform_entities -from ...utils import image class DiscordMessageConverter(adapter.MessageConverter): - @staticmethod async def yiri2target( - message_chain: platform_message.MessageChain + message_chain: platform_message.MessageChain, ) -> typing.Tuple[str, typing.List[discord.File]]: for ele in message_chain: if isinstance(ele, platform_message.At): message_chain.remove(ele) break - text_string = "" + text_string = '' image_files = [] for ele in message_chain: @@ -49,46 +42,49 @@ class DiscordMessageConverter(adapter.MessageConverter): async with session.get(ele.url) as response: image_bytes = await response.read() elif ele.path: - with open(ele.path, "rb") as f: + with open(ele.path, 'rb') as f: image_bytes = f.read() - image_files.append(discord.File(fp=image_bytes, filename=f"{uuid.uuid4()}.png")) + image_files.append( + discord.File(fp=image_bytes, filename=f'{uuid.uuid4()}.png') + ) elif isinstance(ele, platform_message.Plain): text_string += ele.text elif isinstance(ele, platform_message.Forward): for node in ele.node_list: - text_string, image_files = await DiscordMessageConverter.yiri2target(node.message_chain) + ( + text_string, + image_files, + ) = await DiscordMessageConverter.yiri2target(node.message_chain) text_string += text_string image_files.extend(image_files) return text_string, image_files @staticmethod - async def target2yiri( - message: discord.Message - ) -> platform_message.MessageChain: + async def target2yiri(message: discord.Message) -> platform_message.MessageChain: lb_msg_list = [] msg_create_time = datetime.datetime.fromtimestamp( int(message.created_at.timestamp()) ) - lb_msg_list.append( - platform_message.Source(id=message.id, time=msg_create_time) - ) + lb_msg_list.append(platform_message.Source(id=message.id, time=msg_create_time)) element_list = [] - def text_element_recur(text_ele: str) -> list[platform_message.MessageComponent]: - if text_ele == "": + def text_element_recur( + text_ele: str, + ) -> list[platform_message.MessageComponent]: + if text_ele == '': return [] # <@1234567890> # @everyone # @here - at_pattern = re.compile(r"(@everyone|@here|<@[\d]+>)") + at_pattern = re.compile(r'(@everyone|@here|<@[\d]+>)') at_matches = at_pattern.findall(text_ele) - + if len(at_matches) > 0: mid_at = at_matches[0] @@ -96,18 +92,19 @@ class DiscordMessageConverter(adapter.MessageConverter): mid_at_component = [] - if mid_at == "@everyone" or mid_at == "@here": + if mid_at == '@everyone' or mid_at == '@here': mid_at_component.append(platform_message.AtAll()) else: mid_at_component.append(platform_message.At(target=mid_at[2:-1])) - return text_element_recur(text_split[0]) + \ - mid_at_component + \ - text_element_recur(text_split[1]) + return ( + text_element_recur(text_split[0]) + + mid_at_component + + text_element_recur(text_split[1]) + ) else: return [platform_message.Plain(text=text_ele)] - element_list.extend(text_element_recur(message.content)) # attachments @@ -115,28 +112,27 @@ class DiscordMessageConverter(adapter.MessageConverter): async with aiohttp.ClientSession(trust_env=True) as session: async with session.get(attachment.url) as response: image_data = await response.read() - image_base64 = base64.b64encode(image_data).decode("utf-8") - image_format = response.headers["Content-Type"] - element_list.append(platform_message.Image(base64=f"data:{image_format};base64,{image_base64}")) + image_base64 = base64.b64encode(image_data).decode('utf-8') + image_format = response.headers['Content-Type'] + element_list.append( + platform_message.Image( + base64=f'data:{image_format};base64,{image_base64}' + ) + ) return platform_message.MessageChain(element_list) class DiscordEventConverter(adapter.EventConverter): - @staticmethod - async def yiri2target( - event: platform_events.Event - ) -> discord.Message: + async def yiri2target(event: platform_events.Event) -> discord.Message: pass @staticmethod - async def target2yiri( - event: discord.Message - ) -> platform_events.Event: + async def target2yiri(event: discord.Message) -> platform_events.Event: message_chain = await DiscordMessageConverter.target2yiri(event) - if type(event.channel) == discord.DMChannel: + if isinstance(event.channel, discord.DMChannel): return platform_events.FriendMessage( sender=platform_entities.Friend( id=event.author.id, @@ -147,7 +143,7 @@ class DiscordEventConverter(adapter.EventConverter): time=event.created_at.timestamp(), source_platform_object=event, ) - elif type(event.channel) == discord.TextChannel: + elif isinstance(event.channel, discord.TextChannel): return platform_events.GroupMessage( sender=platform_entities.GroupMember( id=event.author.id, @@ -158,7 +154,7 @@ class DiscordEventConverter(adapter.EventConverter): name=event.channel.name, permission=platform_entities.Permission.Member, ), - special_title="", + special_title='', join_timestamp=0, last_speak_timestamp=0, mute_time_remaining=0, @@ -170,7 +166,6 @@ class DiscordEventConverter(adapter.EventConverter): class DiscordAdapter(adapter.MessagePlatformAdapter): - bot: discord.Client bot_account_id: str # 用于在流水线中识别at是否是本bot,直接以bot_name作为标识 @@ -191,12 +186,11 @@ class DiscordAdapter(adapter.MessagePlatformAdapter): self.config = config self.ap = ap - self.bot_account_id = self.config["client_id"] + self.bot_account_id = self.config['client_id'] adapter_self = self class MyClient(discord.Client): - async def on_message(self: discord.Client, message: discord.Message): if message.author.id == self.user.id or message.author.bot: return @@ -209,11 +203,11 @@ class DiscordAdapter(adapter.MessagePlatformAdapter): args = {} - if os.getenv("http_proxy"): - args["proxy"] = os.getenv("http_proxy") + if os.getenv('http_proxy'): + args['proxy'] = os.getenv('http_proxy') self.bot = MyClient(intents=intents, **args) - + async def send_message( self, target_type: str, target_id: str, message: platform_message.MessageChain ): @@ -229,17 +223,17 @@ class DiscordAdapter(adapter.MessagePlatformAdapter): assert isinstance(message_source.source_platform_object, discord.Message) args = { - "content": msg_to_send, + 'content': msg_to_send, } if len(image_files) > 0: - args["files"] = image_files + args['files'] = image_files if quote_origin: - args["reference"] = message_source.source_platform_object + args['reference'] = message_source.source_platform_object if message.has(platform_message.At): - args["mention_author"] = True + args['mention_author'] = True await message_source.source_platform_object.channel.send(**args) @@ -249,20 +243,24 @@ class DiscordAdapter(adapter.MessagePlatformAdapter): def register_listener( self, event_type: typing.Type[platform_events.Event], - callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], + callback: typing.Callable[ + [platform_events.Event, adapter.MessagePlatformAdapter], None + ], ): self.listeners[event_type] = callback def unregister_listener( self, event_type: typing.Type[platform_events.Event], - callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], + callback: typing.Callable[ + [platform_events.Event, adapter.MessagePlatformAdapter], None + ], ): self.listeners.pop(event_type) async def run_async(self): async with self.bot: - await self.bot.start(self.config["token"], reconnect=True) + await self.bot.start(self.config['token'], reconnect=True) async def kill(self) -> bool: await self.bot.close() diff --git a/pkg/platform/sources/gewechat.py b/pkg/platform/sources/gewechat.py index 869e05c6..2c9e5733 100644 --- a/pkg/platform/sources/gewechat.py +++ b/pkg/platform/sources/gewechat.py @@ -8,18 +8,13 @@ import traceback import time import re import base64 -import uuid -import json -import os import copy -import datetime import threading import quart import aiohttp from .. import adapter -from ...pipeline.longtext.strategies import forward from ...core import app from ..types import message as platform_message from ..types import events as platform_events @@ -29,109 +24,123 @@ import xml.etree.ElementTree as ET class GewechatMessageConverter(adapter.MessageConverter): - def __init__(self, config: dict): self.config = config @staticmethod - async def yiri2target( - message_chain: platform_message.MessageChain - ) -> list[dict]: + async def yiri2target(message_chain: platform_message.MessageChain) -> list[dict]: content_list = [] for component in message_chain: if isinstance(component, platform_message.At): - content_list.append({"type": "at", "target": component.target}) + content_list.append({'type': 'at', 'target': component.target}) elif isinstance(component, platform_message.Plain): - content_list.append({"type": "text", "content": component.text}) + content_list.append({'type': 'text', 'content': component.text}) elif isinstance(component, platform_message.Image): if not component.url: pass - content_list.append({"type": "image", "image": component.url}) - + content_list.append({'type': 'image', 'image': component.url}) elif isinstance(component, platform_message.Voice): - content_list.append({"type": "voice", "url": component.url, "length": component.length}) + content_list.append( + {'type': 'voice', 'url': component.url, 'length': component.length} + ) elif isinstance(component, platform_message.Forward): for node in component.node_list: - content_list.extend(await GewechatMessageConverter.yiri2target(node.message_chain)) + content_list.extend( + await GewechatMessageConverter.yiri2target(node.message_chain) + ) return content_list async def target2yiri( - self, - message: dict, - bot_account_id: str + self, message: dict, bot_account_id: str ) -> platform_message.MessageChain: - - - - if message["Data"]["MsgType"] == 1: + if message['Data']['MsgType'] == 1: # 检查消息开头,如果有 wxid_sbitaz0mt65n22:\n 则删掉 - regex = re.compile(r"^wxid_.*:") + regex = re.compile(r'^wxid_.*:') # print(message) - line_split = message["Data"]["Content"]["string"].split("\n") + line_split = message['Data']['Content']['string'].split('\n') if len(line_split) > 0 and regex.match(line_split[0]): - message["Data"]["Content"]["string"] = "\n".join(line_split[1:]) - + message['Data']['Content']['string'] = '\n'.join(line_split[1:]) # 正则表达式模式,匹配'@'后跟任意数量的非空白字符 pattern = r'@\S+' - at_string = f"@{bot_account_id}" + at_string = f'@{bot_account_id}' content_list = [] - if at_string in message["Data"]["Content"]["string"]: + if at_string in message['Data']['Content']['string']: content_list.append(platform_message.At(target=bot_account_id)) - content_list.append(platform_message.Plain(message["Data"]["Content"]["string"].replace(at_string, '', 1))) + content_list.append( + platform_message.Plain( + message['Data']['Content']['string'].replace(at_string, '', 1) + ) + ) # 更优雅的替换改名后@机器人,仅仅限于单独AT的情况 - elif "PushContent" in message['Data'] and '在群聊中@了你' in message["Data"]["PushContent"]: - if '@所有人' in message["Data"]["Content"]["string"]: # at全员时候传入atll不当作at自己 + elif ( + 'PushContent' in message['Data'] + and '在群聊中@了你' in message['Data']['PushContent'] + ): + if ( + '@所有人' in message['Data']['Content']['string'] + ): # at全员时候传入atll不当作at自己 content_list.append(platform_message.AtAll()) else: content_list.append(platform_message.At(target=bot_account_id)) - content_list.append(platform_message.Plain(re.sub(pattern, '', message["Data"]["Content"]["string"]))) + content_list.append( + platform_message.Plain( + re.sub(pattern, '', message['Data']['Content']['string']) + ) + ) else: - content_list = [platform_message.Plain(message["Data"]["Content"]["string"])] + content_list = [ + platform_message.Plain(message['Data']['Content']['string']) + ] return platform_message.MessageChain(content_list) - - elif message["Data"]["MsgType"] == 3: - image_xml = message["Data"]["Content"]["string"] - if not image_xml: - return platform_message.MessageChain([ - platform_message.Plain(text="[图片内容为空]") - ]) + elif message['Data']['MsgType'] == 3: + image_xml = message['Data']['Content']['string'] + if not image_xml: + return platform_message.MessageChain( + [platform_message.Plain(text='[图片内容为空]')] + ) try: base64_str, image_format = await image.get_gewechat_image_base64( - gewechat_url=self.config["gewechat_url"], - gewechat_file_url=self.config["gewechat_file_url"], - app_id=self.config["app_id"], + gewechat_url=self.config['gewechat_url'], + gewechat_file_url=self.config['gewechat_file_url'], + app_id=self.config['app_id'], xml_content=image_xml, - token=self.config["token"], + token=self.config['token'], image_type=2, ) - return platform_message.MessageChain([ - platform_message.Image( - base64=f"data:image/{image_format};base64,{base64_str}" - ) - ]) + return platform_message.MessageChain( + [ + platform_message.Image( + base64=f'data:image/{image_format};base64,{base64_str}' + ) + ] + ) except Exception as e: - print(f"处理图片消息失败: {str(e)}") - return platform_message.MessageChain([ - platform_message.Plain(text=f"[图片处理失败]") - ]) - elif message["Data"]["MsgType"] == 34: - audio_base64 = message["Data"]["ImgBuf"]["buffer"] + print(f'处理图片消息失败: {str(e)}') + return platform_message.MessageChain( + [platform_message.Plain(text='[图片处理失败]')] + ) + elif message['Data']['MsgType'] == 34: + audio_base64 = message['Data']['ImgBuf']['buffer'] return platform_message.MessageChain( - [platform_message.Voice(base64=f"data:audio/silk;base64,{audio_base64}")] + [ + platform_message.Voice( + base64=f'data:audio/silk;base64,{audio_base64}' + ) + ] ) - elif message["Data"]["MsgType"] == 49: + elif message['Data']['MsgType'] == 49: # 支持微信聊天记录的消息类型,将 XML 内容转换为 MessageChain 传递 try: - content = message["Data"]["Content"]["string"] + content = message['Data']['Content']['string'] # 有三种可能的消息结构weid开头,私聊直接和直接 if content.startswith('wxid'): xml_list = content.split('\n')[2:] @@ -145,140 +154,145 @@ class GewechatMessageConverter(adapter.MessageConverter): content_data = ET.fromstring(xml_data) # print(xml_data) # 拿到细分消息类型,按照gewe接口中描述 - ''' + """ 小程序:33/36 引用消息:57 转账消息:2000 红包消息:2001 视频号消息:51 - ''' + """ appmsg_data = content_data.find('.//appmsg') data_type = appmsg_data.find('.//type').text if data_type == '57': user_data = appmsg_data.find('.//title').text # 拿到用户消息 - quote_data = appmsg_data.find('.//refermsg').find('.//content').text # 引用原文 - sender_id = appmsg_data.find('.//refermsg').find('.//chatusr').text # 引用用户id + quote_data = ( + appmsg_data.find('.//refermsg').find('.//content').text + ) # 引用原文 + sender_id = ( + appmsg_data.find('.//refermsg').find('.//chatusr').text + ) # 引用用户id from_name = message['Data']['FromUserName']['string'] - message_list =[] - if message['Wxid'] == sender_id and from_name.endswith('@chatroom'): # 因为引用机制暂时无法响应用户,所以当引用用户是机器人是构建一个at激活机器人 + message_list = [] + if ( + message['Wxid'] == sender_id and from_name.endswith('@chatroom') + ): # 因为引用机制暂时无法响应用户,所以当引用用户是机器人是构建一个at激活机器人 message_list.append(platform_message.At(target=bot_account_id)) - message_list.append(platform_message.Quote( + message_list.append( + platform_message.Quote( sender_id=sender_id, origin=platform_message.MessageChain( [platform_message.Plain(quote_data)] - ))) + ), + ) + ) message_list.append(platform_message.Plain(user_data)) return platform_message.MessageChain(message_list) elif data_type == '51': return platform_message.MessageChain( - [platform_message.Plain(text=f'[视频号消息]')] + [platform_message.Plain(text='[视频号消息]')] ) # print(content_data) elif data_type == '2000': return platform_message.MessageChain( - [platform_message.Plain(text=f'[转账消息]')] + [platform_message.Plain(text='[转账消息]')] ) elif data_type == '2001': return platform_message.MessageChain( - [platform_message.Plain(text=f'[红包消息]')] + [platform_message.Plain(text='[红包消息]')] ) elif data_type == '5': return platform_message.MessageChain( - [platform_message.Plain(text=f'[公众号消息]')] + [platform_message.Plain(text='[公众号消息]')] ) elif data_type == '33' or data_type == '36': return platform_message.MessageChain( - [platform_message.Plain(text=f'[小程序消息]')] + [platform_message.Plain(text='[小程序消息]')] ) # print(data_type.text) else: - - try: content_bytes = content.encode('utf-8') decoded_content = base64.b64decode(content_bytes) return platform_message.MessageChain( [platform_message.Unknown(content=decoded_content)] ) - except Exception as e: + except Exception: return platform_message.MessageChain( [platform_message.Plain(text=content)] ) except Exception as e: - print(f"Error processing type 49 message: {str(e)}") + print(f'Error processing type 49 message: {str(e)}') return platform_message.MessageChain( - [platform_message.Plain(text="[无法解析的消息]")] + [platform_message.Plain(text='[无法解析的消息]')] ) -class GewechatEventConverter(adapter.EventConverter): +class GewechatEventConverter(adapter.EventConverter): def __init__(self, config: dict): self.config = config self.message_converter = GewechatMessageConverter(config) @staticmethod - async def yiri2target( - event: platform_events.MessageEvent - ) -> dict: + async def yiri2target(event: platform_events.MessageEvent) -> dict: pass async def target2yiri( - self, - event: dict, - bot_account_id: str + self, event: dict, bot_account_id: str ) -> platform_events.MessageEvent: # print(event) # 排除自己发消息回调回答问题 if event['Wxid'] == event['Data']['FromUserName']['string']: return None # 排除公众号以及微信团队消息 - if event['Data']['FromUserName']['string'].startswith('gh_')\ - or event['Data']['FromUserName']['string'].startswith('weixin'): + if event['Data']['FromUserName']['string'].startswith('gh_') or event['Data'][ + 'FromUserName' + ]['string'].startswith('weixin'): return None - message_chain = await self.message_converter.target2yiri(copy.deepcopy(event), bot_account_id) + message_chain = await self.message_converter.target2yiri( + copy.deepcopy(event), bot_account_id + ) if not message_chain: return None - - if '@chatroom' in event["Data"]["FromUserName"]["string"]: + + if '@chatroom' in event['Data']['FromUserName']['string']: # 找出开头的 wxid_ 字符串,以:结尾 - sender_wxid = event["Data"]["Content"]["string"].split(":")[0] + sender_wxid = event['Data']['Content']['string'].split(':')[0] return platform_events.GroupMessage( sender=platform_entities.GroupMember( id=sender_wxid, - member_name=event["Data"]["FromUserName"]["string"], + member_name=event['Data']['FromUserName']['string'], permission=platform_entities.Permission.Member, group=platform_entities.Group( - id=event["Data"]["FromUserName"]["string"], - name=event["Data"]["FromUserName"]["string"], + id=event['Data']['FromUserName']['string'], + name=event['Data']['FromUserName']['string'], permission=platform_entities.Permission.Member, ), - special_title="", + special_title='', join_timestamp=0, last_speak_timestamp=0, mute_time_remaining=0, ), message_chain=message_chain, - time=event["Data"]["CreateTime"], + time=event['Data']['CreateTime'], source_platform_object=event, ) else: return platform_events.FriendMessage( sender=platform_entities.Friend( - id=event["Data"]["FromUserName"]["string"], - nickname=event["Data"]["FromUserName"]["string"], + id=event['Data']['FromUserName']['string'], + nickname=event['Data']['FromUserName']['string'], remark='', ), message_chain=message_chain, - time=event["Data"]["CreateTime"], + time=event['Data']['CreateTime'], source_platform_object=event, ) class GeWeChatAdapter(adapter.MessagePlatformAdapter): - - name: str = "gewechat" # 定义适配器名称 + name: str = 'gewechat' # 定义适配器名称 bot: gewechat_client.GewechatClient quart_app: quart.Quart @@ -296,7 +310,7 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter): typing.Type[platform_events.Event], typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], ] = {} - + def __init__(self, config: dict, ap: app.Application): self.config = config self.ap = ap @@ -310,21 +324,21 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter): async def gewechat_callback(): data = await quart.request.json # print(json.dumps(data, indent=4, ensure_ascii=False)) - + if 'data' in data: data['Data'] = data['data'] if 'type_name' in data: data['TypeName'] = data['type_name'] # print(json.dumps(data, indent=4, ensure_ascii=False)) - if 'testMsg' in data: return 'ok' elif 'TypeName' in data and data['TypeName'] == 'AddMsg': try: - - event = await self.event_converter.target2yiri(data.copy(), self.bot_account_id) - except Exception as e: + event = await self.event_converter.target2yiri( + data.copy(), self.bot_account_id + ) + except Exception: traceback.print_exc() if event.__class__ in self.listeners: @@ -333,65 +347,67 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter): return 'ok' async def send_message( - self, - target_type: str, - target_id: str, - message: platform_message.MessageChain + self, target_type: str, target_id: str, message: platform_message.MessageChain ): geweap_msg = await self.message_converter.yiri2target(message) # 此处加上群消息at处理 - ats = [item["target"] for item in geweap_msg if item["type"] == "at"] - + ats = [item['target'] for item in geweap_msg if item['type'] == 'at'] for msg in geweap_msg: # at主动发送消息 if msg['type'] == 'text': if ats: member_info = self.bot.get_chatroom_member_detail( - self.config["app_id"], - target_id, - ats[::-1] - )["data"] + self.config['app_id'], target_id, ats[::-1] + )['data'] for member in member_info: msg['content'] = f'@{member["nickName"]} {msg["content"]}' - self.bot.post_text(app_id=self.config['app_id'], to_wxid=target_id, content=msg['content'], - ats=",".join(ats)) + self.bot.post_text( + app_id=self.config['app_id'], + to_wxid=target_id, + content=msg['content'], + ats=','.join(ats), + ) elif msg['type'] == 'image': - - self.bot.post_image(app_id=self.config['app_id'], to_wxid=target_id, img_url=msg["image"]) - - + self.bot.post_image( + app_id=self.config['app_id'], + to_wxid=target_id, + img_url=msg['image'], + ) async def reply_message( self, message_source: platform_events.MessageEvent, message: platform_message.MessageChain, - quote_origin: bool = False + quote_origin: bool = False, ): content_list = await self.message_converter.yiri2target(message) - ats = [item["target"] for item in content_list if item["type"] == "at"] + ats = [item['target'] for item in content_list if item['type'] == 'at'] for msg in content_list: - if msg["type"] == "text": - + if msg['type'] == 'text': if ats: member_info = self.bot.get_chatroom_member_detail( - self.config["app_id"], - message_source.source_platform_object["Data"]["FromUserName"]["string"], - ats[::-1] - )["data"] + self.config['app_id'], + message_source.source_platform_object['Data']['FromUserName'][ + 'string' + ], + ats[::-1], + )['data'] for member in member_info: msg['content'] = f'@{member["nickName"]} {msg["content"]}' self.bot.post_text( - app_id=self.config["app_id"], - to_wxid=message_source.source_platform_object["Data"]["FromUserName"]["string"], - content=msg["content"], - ats=",".join(ats) + app_id=self.config['app_id'], + to_wxid=message_source.source_platform_object['Data'][ + 'FromUserName' + ]['string'], + content=msg['content'], + ats=','.join(ats), ) async def is_muted(self, group_id: int) -> bool: @@ -400,51 +416,57 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter): def register_listener( self, event_type: typing.Type[platform_events.Event], - callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None] + callback: typing.Callable[ + [platform_events.Event, adapter.MessagePlatformAdapter], None + ], ): self.listeners[event_type] = callback def unregister_listener( self, event_type: typing.Type[platform_events.Event], - callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None] + callback: typing.Callable[ + [platform_events.Event, adapter.MessagePlatformAdapter], None + ], ): pass async def run_async(self): - - if not self.config["token"]: + if not self.config['token']: async with aiohttp.ClientSession() as session: async with session.post( - f"{self.config['gewechat_url']}/v2/api/tools/getTokenId", - json={"app_id": self.config["app_id"]} + f'{self.config["gewechat_url"]}/v2/api/tools/getTokenId', + json={'app_id': self.config['app_id']}, ) as response: if response.status != 200: - raise Exception(f"获取gewechat token失败: {await response.text()}") - self.config["token"] = (await response.json())["data"] + raise Exception( + f'获取gewechat token失败: {await response.text()}' + ) + self.config['token'] = (await response.json())['data'] self.bot = gewechat_client.GewechatClient( - f"{self.config['gewechat_url']}/v2/api", - self.config["token"] + f'{self.config["gewechat_url"]}/v2/api', self.config['token'] ) - app_id, error_msg = self.bot.login(self.config["app_id"]) + app_id, error_msg = self.bot.login(self.config['app_id']) if error_msg: - raise Exception(f"Gewechat 登录失败: {error_msg}") + raise Exception(f'Gewechat 登录失败: {error_msg}') - self.config["app_id"] = app_id + self.config['app_id'] = app_id - self.ap.logger.info(f"Gewechat 登录成功,app_id: {app_id}") + self.ap.logger.info(f'Gewechat 登录成功,app_id: {app_id}') await self.ap.platform_mgr.write_back_config('gewechat', self, self.config) # 获取 nickname - profile = self.bot.get_profile(self.config["app_id"]) - self.bot_account_id = profile["data"]["nickName"] + profile = self.bot.get_profile(self.config['app_id']) + self.bot_account_id = profile['data']['nickName'] def thread_set_callback(): time.sleep(3) - ret = self.bot.set_callback(self.config["token"], self.config["callback_url"]) + ret = self.bot.set_callback( + self.config['token'], self.config['callback_url'] + ) print('设置 Gewechat 回调:', ret) threading.Thread(target=thread_set_callback).start() @@ -455,7 +477,7 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter): await self.quart_app.run_task( host='0.0.0.0', - port=self.config["port"], + port=self.config['port'], shutdown_trigger=shutdown_trigger_placeholder, ) diff --git a/pkg/platform/sources/lark.py b/pkg/platform/sources/lark.py index 4c87640b..1396ba3b 100644 --- a/pkg/platform/sources/lark.py +++ b/pkg/platform/sources/lark.py @@ -5,56 +5,53 @@ import lark_oapi import typing import asyncio import traceback -import time import re import base64 import uuid import json import datetime import hashlib -import base64 from Crypto.Cipher import AES import aiohttp import lark_oapi.ws.exception import quart -from flask import jsonify from lark_oapi.api.im.v1 import * -from lark_oapi.api.verification.v1 import GetVerificationRequest from .. import adapter -from ...pipeline.longtext.strategies import forward from ...core import app from ..types import message as platform_message from ..types import events as platform_events from ..types import entities as platform_entities -from ...utils import image -class AESCipher(object): +class AESCipher(object): def __init__(self, key): self.bs = AES.block_size - self.key=hashlib.sha256(AESCipher.str_to_bytes(key)).digest() + self.key = hashlib.sha256(AESCipher.str_to_bytes(key)).digest() + @staticmethod def str_to_bytes(data): - u_type = type(b"".decode('utf8')) + u_type = type(b''.decode('utf8')) if isinstance(data, u_type): return data.encode('utf8') return data + @staticmethod def _unpad(s): - return s[:-ord(s[len(s) - 1:])] + return s[: -ord(s[len(s) - 1 :])] + def decrypt(self, enc): - iv = enc[:AES.block_size] + iv = enc[: AES.block_size] cipher = AES.new(self.key, AES.MODE_CBC, iv) - return self._unpad(cipher.decrypt(enc[AES.block_size:])) + return self._unpad(cipher.decrypt(enc[AES.block_size :])) + def decrypt_string(self, enc): enc = base64.b64decode(enc) - return self.decrypt(enc).decode('utf8') + return self.decrypt(enc).decode('utf8') class LarkMessageConverter(adapter.MessageConverter): - @staticmethod async def yiri2target( message_chain: platform_message.MessageChain, api_client: lark_oapi.Client @@ -65,15 +62,14 @@ class LarkMessageConverter(adapter.MessageConverter): for msg in message_chain: if isinstance(msg, platform_message.Plain): - pending_paragraph.append({"tag": "md", "text": msg.text}) + pending_paragraph.append({'tag': 'md', 'text': msg.text}) elif isinstance(msg, platform_message.At): pending_paragraph.append( - {"tag": "at", "user_id": msg.target, "style": []} + {'tag': 'at', 'user_id': msg.target, 'style': []} ) elif isinstance(msg, platform_message.AtAll): - pending_paragraph.append({"tag": "at", "user_id": "all", "style": []}) + pending_paragraph.append({'tag': 'at', 'user_id': 'all', 'style': []}) elif isinstance(msg, platform_message.Image): - image_bytes = None if msg.base64: @@ -83,14 +79,14 @@ class LarkMessageConverter(adapter.MessageConverter): async with session.get(msg.url) as response: image_bytes = await response.read() elif msg.path: - with open(msg.path, "rb") as f: + with open(msg.path, 'rb') as f: image_bytes = f.read() request: CreateImageRequest = ( CreateImageRequest.builder() .request_body( CreateImageRequestBody.builder() - .image_type("message") + .image_type('message') .image(image_bytes) .build() ) @@ -103,7 +99,7 @@ class LarkMessageConverter(adapter.MessageConverter): if not response.success(): raise Exception( - f"client.im.v1.image.create failed, code: {response.code}, msg: {response.msg}, log_id: {response.get_log_id()}, resp: \n{json.dumps(json.loads(response.raw.content), indent=4, ensure_ascii=False)}" + f'client.im.v1.image.create failed, code: {response.code}, msg: {response.msg}, log_id: {response.get_log_id()}, resp: \n{json.dumps(json.loads(response.raw.content), indent=4, ensure_ascii=False)}' ) image_key = response.data.image_key @@ -112,15 +108,19 @@ class LarkMessageConverter(adapter.MessageConverter): message_elements.append( [ { - "tag": "img", - "image_key": image_key, + 'tag': 'img', + 'image_key': image_key, } ] ) pending_paragraph = [] elif isinstance(msg, platform_message.Forward): for node in msg.node_list: - message_elements.extend(await LarkMessageConverter.yiri2target(node.message_chain, api_client)) + message_elements.extend( + await LarkMessageConverter.yiri2target( + node.message_chain, api_client + ) + ) if pending_paragraph: message_elements.append(pending_paragraph) @@ -144,15 +144,15 @@ class LarkMessageConverter(adapter.MessageConverter): platform_message.Source(id=message.message_id, time=msg_create_time) ) - if message.message_type == "text": + if message.message_type == 'text': element_list = [] def text_element_recur(text_ele: dict) -> list[dict]: - if text_ele["text"] == "": + if text_ele['text'] == '': return [] - at_pattern = re.compile(r"@_user_[\d]+") - at_matches = at_pattern.findall(text_ele["text"]) + at_pattern = re.compile(r'@_user_[\d]+') + at_matches = at_pattern.findall(text_ele['text']) name_mapping = {} for mathc in at_matches: @@ -165,7 +165,7 @@ class LarkMessageConverter(adapter.MessageConverter): return [text_ele] # 只处理第一个,剩下的递归处理 - text_split = text_ele["text"].split(list(name_mapping.keys())[0]) + text_split = text_ele['text'].split(list(name_mapping.keys())[0]) new_list = [] @@ -173,58 +173,58 @@ class LarkMessageConverter(adapter.MessageConverter): right_text = text_split[1] new_list.extend( - text_element_recur({"tag": "text", "text": left_text, "style": []}) + text_element_recur({'tag': 'text', 'text': left_text, 'style': []}) ) new_list.append( { - "tag": "at", - "user_id": list(name_mapping.keys())[0], - "user_name": name_mapping[list(name_mapping.keys())[0]], - "style": [], + 'tag': 'at', + 'user_id': list(name_mapping.keys())[0], + 'user_name': name_mapping[list(name_mapping.keys())[0]], + 'style': [], } ) new_list.extend( - text_element_recur({"tag": "text", "text": right_text, "style": []}) + text_element_recur({'tag': 'text', 'text': right_text, 'style': []}) ) return new_list element_list = text_element_recur( - {"tag": "text", "text": message_content["text"], "style": []} + {'tag': 'text', 'text': message_content['text'], 'style': []} ) - message_content = {"title": "", "content": element_list} + message_content = {'title': '', 'content': element_list} - elif message.message_type == "post": + elif message.message_type == 'post': new_list = [] - for ele in message_content["content"]: + for ele in message_content['content']: if type(ele) is dict: new_list.append(ele) elif type(ele) is list: new_list.extend(ele) - message_content["content"] = new_list - elif message.message_type == "image": - message_content["content"] = [ - {"tag": "img", "image_key": message_content["image_key"], "style": []} + message_content['content'] = new_list + elif message.message_type == 'image': + message_content['content'] = [ + {'tag': 'img', 'image_key': message_content['image_key'], 'style': []} ] - for ele in message_content["content"]: - if ele["tag"] == "text": - lb_msg_list.append(platform_message.Plain(text=ele["text"])) - elif ele["tag"] == "at": - lb_msg_list.append(platform_message.At(target=ele["user_name"])) - elif ele["tag"] == "img": - image_key = ele["image_key"] + for ele in message_content['content']: + if ele['tag'] == 'text': + lb_msg_list.append(platform_message.Plain(text=ele['text'])) + elif ele['tag'] == 'at': + lb_msg_list.append(platform_message.At(target=ele['user_name'])) + elif ele['tag'] == 'img': + image_key = ele['image_key'] request: GetMessageResourceRequest = ( GetMessageResourceRequest.builder() .message_id(message.message_id) .file_key(image_key) - .type("image") + .type('image') .build() ) @@ -234,17 +234,17 @@ class LarkMessageConverter(adapter.MessageConverter): if not response.success(): raise Exception( - f"client.im.v1.message_resource.get failed, code: {response.code}, msg: {response.msg}, log_id: {response.get_log_id()}, resp: \n{json.dumps(json.loads(response.raw.content), indent=4, ensure_ascii=False)}" + f'client.im.v1.message_resource.get failed, code: {response.code}, msg: {response.msg}, log_id: {response.get_log_id()}, resp: \n{json.dumps(json.loads(response.raw.content), indent=4, ensure_ascii=False)}' ) image_bytes = response.file.read() image_base64 = base64.b64encode(image_bytes).decode() - image_format = response.raw.headers["content-type"] + image_format = response.raw.headers['content-type'] lb_msg_list.append( platform_message.Image( - base64=f"data:{image_format};base64,{image_base64}" + base64=f'data:{image_format};base64,{image_base64}' ) ) @@ -252,7 +252,6 @@ class LarkMessageConverter(adapter.MessageConverter): class LarkEventConverter(adapter.EventConverter): - @staticmethod async def yiri2target( event: platform_events.MessageEvent, @@ -267,17 +266,17 @@ class LarkEventConverter(adapter.EventConverter): event.event.message, api_client ) - if event.event.message.chat_type == "p2p": + if event.event.message.chat_type == 'p2p': return platform_events.FriendMessage( sender=platform_entities.Friend( id=event.event.sender.sender_id.open_id, nickname=event.event.sender.sender_id.union_id, - remark="", + remark='', ), message_chain=message_chain, time=event.event.message.create_time, ) - elif event.event.message.chat_type == "group": + elif event.event.message.chat_type == 'group': return platform_events.GroupMessage( sender=platform_entities.GroupMember( id=event.event.sender.sender_id.open_id, @@ -285,10 +284,10 @@ class LarkEventConverter(adapter.EventConverter): permission=platform_entities.Permission.Member, group=platform_entities.Group( id=event.event.message.chat_id, - name="", + name='', permission=platform_entities.Permission.Member, ), - special_title="", + special_title='', join_timestamp=0, last_speak_timestamp=0, mute_time_remaining=0, @@ -299,7 +298,6 @@ class LarkEventConverter(adapter.EventConverter): class LarkAdapter(adapter.MessagePlatformAdapter): - bot: lark_oapi.ws.Client api_client: lark_oapi.Client @@ -333,17 +331,15 @@ class LarkAdapter(adapter.MessagePlatformAdapter): data = cipher.decrypt_string(data['encrypt']) data = json.loads(data) - type = data.get("type") - if type is None : + type = data.get('type') + if type is None: context = EventContext(data) type = context.header.event_type - + if 'url_verification' == type: - print(data.get("challenge")) + print(data.get('challenge')) # todo 验证verification token - return { - "challenge": data.get("challenge") - } + return {'challenge': data.get('challenge')} context = EventContext(data) type = context.header.event_type p2v1 = P2ImMessageReceiveV1() @@ -355,20 +351,21 @@ class LarkAdapter(adapter.MessagePlatformAdapter): p2v1.schema = context.schema if 'im.message.receive_v1' == type: try: - event = await self.event_converter.target2yiri(p2v1, self.api_client) - except Exception as e: + event = await self.event_converter.target2yiri( + p2v1, self.api_client + ) + except Exception: traceback.print_exc() if event.__class__ in self.listeners: await self.listeners[event.__class__](event, self) - return {"code": 200, "message": "ok"} - except Exception as e: + return {'code': 200, 'message': 'ok'} + except Exception: traceback.print_exc() - return {"code": 500, "message": "error"} + return {'code': 500, 'message': 'error'} async def on_message(event: lark_oapi.im.v1.P2ImMessageReceiveV1): - lb_event = await self.event_converter.target2yiri(event, self.api_client) await self.listeners[type(lb_event)](lb_event, self) @@ -377,20 +374,20 @@ class LarkAdapter(adapter.MessagePlatformAdapter): asyncio.create_task(on_message(event)) event_handler = ( - lark_oapi.EventDispatcherHandler.builder("", "") + lark_oapi.EventDispatcherHandler.builder('', '') .register_p2_im_message_receive_v1(sync_on_message) .build() ) - self.bot_account_id = config["bot_name"] + self.bot_account_id = config['bot_name'] self.bot = lark_oapi.ws.Client( - config["app_id"], config["app_secret"], event_handler=event_handler + config['app_id'], config['app_secret'], event_handler=event_handler ) self.api_client = ( lark_oapi.Client.builder() - .app_id(config["app_id"]) - .app_secret(config["app_secret"]) + .app_id(config['app_id']) + .app_secret(config['app_secret']) .build() ) @@ -405,7 +402,6 @@ class LarkAdapter(adapter.MessagePlatformAdapter): message: platform_message.MessageChain, quote_origin: bool = False, ): - # 不再需要了,因为message_id已经被包含到message_chain中 # lark_event = await self.event_converter.yiri2target(message_source) lark_message = await self.message_converter.yiri2target( @@ -413,9 +409,9 @@ class LarkAdapter(adapter.MessagePlatformAdapter): ) final_content = { - "zh_cn": { - "title": "", - "content": lark_message, + 'zh_cn': { + 'title': '', + 'content': lark_message, }, } @@ -425,7 +421,7 @@ class LarkAdapter(adapter.MessagePlatformAdapter): .request_body( ReplyMessageRequestBody.builder() .content(json.dumps(final_content)) - .msg_type("post") + .msg_type('post') .reply_in_thread(False) .uuid(str(uuid.uuid4())) .build() @@ -439,7 +435,7 @@ class LarkAdapter(adapter.MessagePlatformAdapter): if not response.success(): raise Exception( - f"client.im.v1.message.reply failed, code: {response.code}, msg: {response.msg}, log_id: {response.get_log_id()}, resp: \n{json.dumps(json.loads(response.raw.content), indent=4, ensure_ascii=False)}" + f'client.im.v1.message.reply failed, code: {response.code}, msg: {response.msg}, log_id: {response.get_log_id()}, resp: \n{json.dumps(json.loads(response.raw.content), indent=4, ensure_ascii=False)}' ) async def is_muted(self, group_id: int) -> bool: @@ -479,6 +475,7 @@ class LarkAdapter(adapter.MessagePlatformAdapter): else: raise e else: + async def shutdown_trigger_placeholder(): while True: await asyncio.sleep(1) @@ -488,5 +485,6 @@ class LarkAdapter(adapter.MessagePlatformAdapter): port=port, shutdown_trigger=shutdown_trigger_placeholder, ) + async def kill(self) -> bool: return False diff --git a/pkg/platform/sources/nakuru.py b/pkg/platform/sources/nakuru.py index 8dcf6e52..7038af1d 100644 --- a/pkg/platform/sources/nakuru.py +++ b/pkg/platform/sources/nakuru.py @@ -4,7 +4,6 @@ import asyncio import typing import traceback -import logging import nakuru @@ -19,6 +18,7 @@ from ...platform.types import events as platform_events class NakuruProjectMessageConverter(adapter_model.MessageConverter): """消息转换器""" + @staticmethod def yiri2target(message_chain: platform_message.MessageChain) -> list: msg_list = [] @@ -29,10 +29,12 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter): elif type(message_chain) is str: msg_list = [platform_message.Plain(message_chain)] else: - raise Exception("Unknown message type: " + str(message_chain) + str(type(message_chain))) - + raise Exception( + 'Unknown message type: ' + str(message_chain) + str(type(message_chain)) + ) + nakuru_msg_list = [] - + # 遍历并转换 for component in msg_list: if type(component) is platform_message.Plain: @@ -61,33 +63,43 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter): # 遍历并转换 for yiri_forward_node in yiri_forward_node_list: try: - content_list = NakuruProjectMessageConverter.yiri2target(yiri_forward_node.message_chain) + content_list = NakuruProjectMessageConverter.yiri2target( + yiri_forward_node.message_chain + ) nakuru_forward_node = nkc.Node( name=yiri_forward_node.sender_name, uin=yiri_forward_node.sender_id, - time=int(yiri_forward_node.time.timestamp()) if yiri_forward_node.time is not None else None, - content=content_list + time=int(yiri_forward_node.time.timestamp()) + if yiri_forward_node.time is not None + else None, + content=content_list, ) nakuru_forward_node_list.append(nakuru_forward_node) - except Exception as e: + except Exception: import traceback + traceback.print_exc() nakuru_msg_list.append(nakuru_forward_node_list) else: nakuru_msg_list.append(nkc.Plain(str(component))) - + return nakuru_msg_list @staticmethod - def target2yiri(message_chain: typing.Any, message_id: int = -1) -> platform_message.MessageChain: + def target2yiri( + message_chain: typing.Any, message_id: int = -1 + ) -> platform_message.MessageChain: """将Yiri的消息链转换为YiriMirai的消息链""" assert type(message_chain) is list yiri_msg_list = [] import datetime + # 添加Source组件以标记message_id等信息 - yiri_msg_list.append(platform_message.Source(id=message_id, time=datetime.datetime.now())) + yiri_msg_list.append( + platform_message.Source(id=message_id, time=datetime.datetime.now()) + ) for component in message_chain: if type(component) is nkc.Plain: yiri_msg_list.append(platform_message.Plain(text=component.text)) @@ -106,6 +118,7 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter): class NakuruProjectEventConverter(adapter_model.EventConverter): """事件转换器""" + @staticmethod def yiri2target(event: typing.Type[platform_events.Event]): if event is platform_events.GroupMessage: @@ -113,28 +126,30 @@ class NakuruProjectEventConverter(adapter_model.EventConverter): elif event is platform_events.FriendMessage: return nakuru.FriendMessage else: - raise Exception("未支持转换的事件类型: " + str(event)) + raise Exception('未支持转换的事件类型: ' + str(event)) @staticmethod def target2yiri(event: typing.Any) -> platform_events.Event: - yiri_chain = NakuruProjectMessageConverter.target2yiri(event.message, event.message_id) + yiri_chain = NakuruProjectMessageConverter.target2yiri( + event.message, event.message_id + ) if type(event) is nakuru.FriendMessage: # 私聊消息事件 return platform_events.FriendMessage( sender=platform_entities.Friend( id=event.sender.user_id, nickname=event.sender.nickname, - remark=event.sender.nickname + remark=event.sender.nickname, ), message_chain=yiri_chain, - time=event.time + time=event.time, ) elif type(event) is nakuru.GroupMessage: # 群聊消息事件 - permission = "MEMBER" + permission = 'MEMBER' - if event.sender.role == "admin": - permission = "ADMINISTRATOR" - elif event.sender.role == "owner": - permission = "OWNER" + if event.sender.role == 'admin': + permission = 'ADMINISTRATOR' + elif event.sender.role == 'owner': + permission = 'OWNER' return platform_events.GroupMessage( sender=platform_entities.GroupMember( @@ -144,7 +159,7 @@ class NakuruProjectEventConverter(adapter_model.EventConverter): group=platform_entities.Group( id=event.group_id, name=event.sender.nickname, - permission=platform_entities.Permission.Member + permission=platform_entities.Permission.Member, ), special_title=event.sender.title, join_timestamp=0, @@ -152,14 +167,15 @@ class NakuruProjectEventConverter(adapter_model.EventConverter): mute_time_remaining=0, ), message_chain=yiri_chain, - time=event.time + time=event.time, ) else: - raise Exception("未支持转换的事件类型: " + str(event)) + raise Exception('未支持转换的事件类型: ' + str(event)) class NakuruAdapter(adapter_model.MessagePlatformAdapter): """nakuru-project适配器""" + bot: nakuru.CQHTTP bot_account_id: int @@ -186,12 +202,14 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter): target_type: str, target_id: str, message: typing.Union[platform_message.MessageChain, list], - converted: bool = False + converted: bool = False, ): task = None - converted_msg = self.message_converter.yiri2target(message) if not converted else message - + converted_msg = ( + self.message_converter.yiri2target(message) if not converted else message + ) + # 检查是否有转发消息 has_forward = False for msg in converted_msg: @@ -200,19 +218,19 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter): converted_msg = msg break if has_forward: - if target_type == "group": + if target_type == 'group': task = self.bot.sendGroupForwardMessage(int(target_id), converted_msg) - elif target_type == "person": + elif target_type == 'person': task = self.bot.sendPrivateForwardMessage(int(target_id), converted_msg) else: - raise Exception("Unknown target type: " + target_type) + raise Exception('Unknown target type: ' + target_type) else: - if target_type == "group": + if target_type == 'group': task = self.bot.sendGroupMessage(int(target_id), converted_msg) - elif target_type == "person": + elif target_type == 'person': task = self.bot.sendFriendMessage(int(target_id), converted_msg) else: - raise Exception("Unknown target type: " + target_type) + raise Exception('Unknown target type: ' + target_type) await task @@ -220,45 +238,45 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter): self, message_source: platform_events.MessageEvent, message: platform_message.MessageChain, - quote_origin: bool = False + quote_origin: bool = False, ): message = self.message_converter.yiri2target(message) if quote_origin: # 在前方添加引用组件 - message.insert(0, nkc.Reply( + message.insert( + 0, + nkc.Reply( id=message_source.message_chain.message_id, - ) + ), ) if type(message_source) is platform_events.GroupMessage: await self.send_message( - "group", - message_source.sender.group.id, - message, - converted=True + 'group', message_source.sender.group.id, message, converted=True ) elif type(message_source) is platform_events.FriendMessage: await self.send_message( - "person", - message_source.sender.id, - message, - converted=True + 'person', message_source.sender.id, message, converted=True ) else: - raise Exception("Unknown message source type: " + str(type(message_source))) + raise Exception('Unknown message source type: ' + str(type(message_source))) def is_muted(self, group_id: int) -> bool: import time + # 检查是否被禁言 - group_member_info = asyncio.run(self.bot.getGroupMemberInfo(group_id, self.bot_account_id)) + group_member_info = asyncio.run( + self.bot.getGroupMemberInfo(group_id, self.bot_account_id) + ) return group_member_info.shut_up_timestamp > int(time.time()) def register_listener( self, event_type: typing.Type[platform_events.Event], - callback: typing.Callable[[platform_events.Event, adapter_model.MessagePlatformAdapter], None] + callback: typing.Callable[ + [platform_events.Event, adapter_model.MessagePlatformAdapter], None + ], ): try: - source_cls = NakuruProjectEventConverter.yiri2target(event_type) # 包装函数 @@ -268,9 +286,9 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter): # 将包装函数和原函数的对应关系存入列表 self.listener_list.append( { - "event_type": event_type, - "callable": callback, - "wrapper": listener_wrapper, + 'event_type': event_type, + 'callable': callback, + 'wrapper': listener_wrapper, } ) @@ -283,7 +301,9 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter): def unregister_listener( self, event_type: typing.Type[platform_events.Event], - callback: typing.Callable[[platform_events.Event, adapter_model.MessagePlatformAdapter], None] + callback: typing.Callable[ + [platform_events.Event, adapter_model.MessagePlatformAdapter], None + ], ): nakuru_event_name = self.event_converter.yiri2target(event_type).__name__ @@ -292,13 +312,16 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter): # 从本对象的监听器列表中查找并删除 target_wrapper = None for listener in self.listener_list: - if listener["event_type"] == event_type and listener["callable"] == callback: - target_wrapper = listener["wrapper"] + if ( + listener['event_type'] == event_type + and listener['callable'] == callback + ): + target_wrapper = listener['wrapper'] self.listener_list.remove(listener) break if target_wrapper is None: - raise Exception("未找到对应的监听器") + raise Exception('未找到对应的监听器') for func in self.bot.event[nakuru_event_name]: if func.callable != target_wrapper: @@ -309,23 +332,30 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter): async def run_async(self): try: import requests + resp = requests.get( - url="http://{}:{}/get_login_info".format(self.cfg['host'], self.cfg['http_port']), + url='http://{}:{}/get_login_info'.format( + self.cfg['host'], self.cfg['http_port'] + ), headers={ - 'Authorization': "Bearer " + self.cfg['token'] if 'token' in self.cfg else "" + 'Authorization': 'Bearer ' + self.cfg['token'] + if 'token' in self.cfg + else '' }, timeout=5, - proxies=None + proxies=None, ) if resp.status_code == 403: - raise Exception("go-cqhttp拒绝访问,请检查配置文件中nakuru适配器的配置") + raise Exception('go-cqhttp拒绝访问,请检查配置文件中nakuru适配器的配置') self.bot_account_id = int(resp.json()['data']['user_id']) - except Exception as e: - raise Exception("获取go-cqhttp账号信息失败, 请检查是否已启动go-cqhttp并配置正确") + except Exception: + raise Exception( + '获取go-cqhttp账号信息失败, 请检查是否已启动go-cqhttp并配置正确' + ) await self.bot._run() - self.ap.logger.info("运行 Nakuru 适配器") + self.ap.logger.info('运行 Nakuru 适配器') while True: await asyncio.sleep(1) async def kill(self) -> bool: - return False \ No newline at end of file + return False diff --git a/pkg/platform/sources/officialaccount.py b/pkg/platform/sources/officialaccount.py index 0816824f..6e7eaf2f 100644 --- a/pkg/platform/sources/officialaccount.py +++ b/pkg/platform/sources/officialaccount.py @@ -4,20 +4,13 @@ import asyncio import traceback import datetime -from pkg.core import app from pkg.platform.adapter import MessagePlatformAdapter from pkg.platform.types import events as platform_events, message as platform_message -from collections import deque from libs.official_account_api.oaevent import OAEvent -from pkg.platform.adapter import MessagePlatformAdapter -from pkg.platform.types import events as platform_events, message as platform_message from libs.official_account_api.api import OAClient from libs.official_account_api.api import OAClientForLongerResponse -from pkg.core import app from .. import adapter from ...core import app -from ..types import message as platform_message -from ..types import events as platform_events from ..types import entities as platform_entities from ...command.errors import ParamNotEnoughError @@ -28,10 +21,9 @@ class OAMessageConverter(adapter.MessageConverter): for msg in message_chain: if type(msg) is platform_message.Plain: return msg.text - @staticmethod - async def target2yiri(message:str,message_id =-1): + async def target2yiri(message: str, message_id=-1): yiri_msg_list = [] yiri_msg_list.append( platform_message.Source(id=message_id, time=datetime.datetime.now()) @@ -41,12 +33,12 @@ class OAMessageConverter(adapter.MessageConverter): chain = platform_message.MessageChain(yiri_msg_list) return chain - + class OAEventConverter(adapter.EventConverter): @staticmethod - async def target2yiri(event:OAEvent): - if event.type == "text": + async def target2yiri(event: OAEvent): + if event.type == 'text': yiri_chain = await OAMessageConverter.target2yiri( event.message, event.message_id ) @@ -54,91 +46,101 @@ class OAEventConverter(adapter.EventConverter): friend = platform_entities.Friend( id=event.user_id, nickname=str(event.user_id), - remark="", + remark='', ) return platform_events.FriendMessage( - sender=friend, message_chain=yiri_chain, time=event.timestamp, source_platform_object=event + sender=friend, + message_chain=yiri_chain, + time=event.timestamp, + source_platform_object=event, ) else: return None -class OfficialAccountAdapter(adapter.MessagePlatformAdapter): - bot : OAClient | OAClientForLongerResponse - ap : app.Application +class OfficialAccountAdapter(adapter.MessagePlatformAdapter): + bot: OAClient | OAClientForLongerResponse + ap: app.Application bot_account_id: str message_converter: OAMessageConverter = OAMessageConverter() event_converter: OAEventConverter = OAEventConverter() config: dict - def __init__(self, config: dict, ap: app.Application): self.config = config - + self.ap = ap required_keys = [ - "token", - "EncodingAESKey", - "AppSecret", - "AppID", - "Mode", + 'token', + 'EncodingAESKey', + 'AppSecret', + 'AppID', + 'Mode', ] missing_keys = [key for key in required_keys if key not in config] if missing_keys: - raise ParamNotEnoughError("微信公众号缺少相关配置项,请查看文档或联系管理员") - - - if self.config['Mode'] == "drop": + raise ParamNotEnoughError( + '微信公众号缺少相关配置项,请查看文档或联系管理员' + ) + + if self.config['Mode'] == 'drop': self.bot = OAClient( token=config['token'], EncodingAESKey=config['EncodingAESKey'], Appsecret=config['AppSecret'], - AppID=config['AppID'], + AppID=config['AppID'], ) - elif self.config['Mode'] == "passive": + elif self.config['Mode'] == 'passive': self.bot = OAClientForLongerResponse( token=config['token'], EncodingAESKey=config['EncodingAESKey'], Appsecret=config['AppSecret'], - AppID=config['AppID'], - LoadingMessage=config['LoadingMessage'] + AppID=config['AppID'], + LoadingMessage=config['LoadingMessage'], ) else: - raise KeyError("请设置微信公众号通信模式") + raise KeyError('请设置微信公众号通信模式') - - async def reply_message(self, message_source: platform_events.FriendMessage, message: platform_message.MessageChain, quote_origin: bool = False): - - content = await OAMessageConverter.yiri2target( - message - ) - if type(self.bot) == OAClient: - await self.bot.set_message(message_source.message_chain.message_id,content) - if type(self.bot) == OAClientForLongerResponse: + async def reply_message( + self, + message_source: platform_events.FriendMessage, + message: platform_message.MessageChain, + quote_origin: bool = False, + ): + content = await OAMessageConverter.yiri2target(message) + if isinstance(self.bot, OAClient): + await self.bot.set_message(message_source.message_chain.message_id, content) + elif isinstance(self.bot, OAClientForLongerResponse): from_user = message_source.sender.id - await self.bot.set_message(from_user,message_source.message_chain.message_id,content) + await self.bot.set_message( + from_user, message_source.message_chain.message_id, content + ) - async def send_message( self, target_type: str, target_id: str, message: platform_message.MessageChain ): pass - - def register_listener(self, event_type: type, callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None]): + def register_listener( + self, + event_type: type, + callback: typing.Callable[ + [platform_events.Event, MessagePlatformAdapter], None + ], + ): async def on_message(event: OAEvent): self.bot_account_id = event.receiver_id try: return await callback( await self.event_converter.target2yiri(event), self ) - except: + except Exception: traceback.print_exc() if event_type == platform_events.FriendMessage: - self.bot.on_message("text")(on_message) + self.bot.on_message('text')(on_message) elif event_type == platform_events.GroupMessage: pass @@ -148,8 +150,8 @@ class OfficialAccountAdapter(adapter.MessagePlatformAdapter): await asyncio.sleep(1) await self.bot.run_task( - host=self.config["host"], - port=self.config["port"], + host=self.config['host'], + port=self.config['port'], shutdown_trigger=shutdown_trigger_placeholder, ) @@ -159,8 +161,8 @@ class OfficialAccountAdapter(adapter.MessagePlatformAdapter): async def unregister_listener( self, event_type: type, - callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None], + callback: typing.Callable[ + [platform_events.Event, MessagePlatformAdapter], None + ], ): return super().unregister_listener(event_type, callback) - - \ No newline at end of file diff --git a/pkg/platform/sources/qqbotpy.py b/pkg/platform/sources/qqbotpy.py index 716da80f..a91f86dd 100644 --- a/pkg/platform/sources/qqbotpy.py +++ b/pkg/platform/sources/qqbotpy.py @@ -22,12 +22,20 @@ from ...platform.types import message as platform_message class OfficialGroupMessage(platform_events.GroupMessage): pass + class OfficialFriendMessage(platform_events.FriendMessage): pass + event_handler_mapping = { - platform_events.GroupMessage: ["on_at_message_create", "on_group_at_message_create"], - platform_events.FriendMessage: ["on_direct_message_create", "on_c2c_message_create"], + platform_events.GroupMessage: [ + 'on_at_message_create', + 'on_group_at_message_create', + ], + platform_events.FriendMessage: [ + 'on_direct_message_create', + 'on_c2c_message_create', + ], } @@ -53,9 +61,10 @@ def char_to_value(char): return ord(char) - ord('0') elif 'A' <= char <= 'Z': return ord(char) - ord('A') + 10 - + return ord(char) - ord('a') + 36 + def digest(s: str) -> int: """计算字符串的hash值。""" # 取末尾的8位 @@ -69,19 +78,24 @@ def digest(s: str) -> int: return number -K = typing.TypeVar("K") -V = typing.TypeVar("V") + +K = typing.TypeVar('K') +V = typing.TypeVar('V') class OpenIDMapping(typing.Generic[K, V]): - map: dict[K, V] dump_func: typing.Callable digest_func: typing.Callable[[K], V] - def __init__(self, map: dict[K, V], dump_func: typing.Callable, digest_func: typing.Callable[[K], V] = digest): + def __init__( + self, + map: dict[K, V], + dump_func: typing.Callable, + digest_func: typing.Callable[[K], V] = digest, + ): self.map = map self.dump_func = dump_func @@ -104,12 +118,11 @@ class OpenIDMapping(typing.Generic[K, V]): def getkey(self, value: V) -> K: return list(self.map.keys())[list(self.map.values()).index(value)] - + def save_openid(self, key: K) -> V: - if key in self.map: return self.map[key] - + value = self.digest_func(key) self.map[key] = value @@ -135,7 +148,7 @@ class OfficialMessageConverter(adapter_model.MessageConverter): msg_list = [platform_message.Plain(text=message_chain)] else: raise Exception( - "Unknown message type: " + str(message_chain) + str(type(message_chain)) + 'Unknown message type: ' + str(message_chain) + str(type(message_chain)) ) offcial_messages: list[dict] = [] @@ -154,23 +167,23 @@ class OfficialMessageConverter(adapter_model.MessageConverter): # 遍历并转换 for component in msg_list: if type(component) is platform_message.Plain: - offcial_messages.append({"type": "text", "content": component.text}) + offcial_messages.append({'type': 'text', 'content': component.text}) elif type(component) is platform_message.Image: if component.url is not None: - offcial_messages.append({"type": "image", "content": component.url}) + offcial_messages.append({'type': 'image', 'content': component.url}) elif component.path is not None: offcial_messages.append( - {"type": "file_image", "content": component.path} + {'type': 'file_image', 'content': component.path} ) elif type(component) is platform_message.At: - offcial_messages.append({"type": "at", "content": ""}) + offcial_messages.append({'type': 'at', 'content': ''}) elif type(component) is platform_message.AtAll: print( - "上层组件要求发送 AtAll 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。" + '上层组件要求发送 AtAll 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。' ) elif type(component) is platform_message.Voice: print( - "上层组件要求发送 Voice 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。" + '上层组件要求发送 Voice 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。' ) elif type(component) is forward.Forward: # 转发消息 @@ -185,7 +198,7 @@ class OfficialMessageConverter(adapter_model.MessageConverter): offcial_messages.extend( OfficialMessageConverter.yiri2target(message_chain) ) - except Exception as e: + except Exception: import traceback traceback.print_exc() @@ -194,7 +207,12 @@ class OfficialMessageConverter(adapter_model.MessageConverter): @staticmethod def extract_message_chain_from_obj( - message: typing.Union[botpy_message.Message, botpy_message.DirectMessage, botpy_message.GroupMessage, botpy_message.C2CMessage], + message: typing.Union[ + botpy_message.Message, + botpy_message.DirectMessage, + botpy_message.GroupMessage, + botpy_message.C2CMessage, + ], message_id: str = None, bot_account_id: int = 0, ) -> platform_message.MessageChain: @@ -210,7 +228,7 @@ class OfficialMessageConverter(adapter_model.MessageConverter): if type(message) not in [botpy_message.DirectMessage, botpy_message.C2CMessage]: yiri_msg_list.append(platform_message.At(target=bot_account_id)) - if hasattr(message, "mentions"): + if hasattr(message, 'mentions'): for mention in message.mentions: if mention.bot: continue @@ -218,15 +236,15 @@ class OfficialMessageConverter(adapter_model.MessageConverter): yiri_msg_list.append(platform_message.At(target=mention.id)) for attachment in message.attachments: - if attachment.content_type.startswith("image"): + if attachment.content_type.startswith('image'): yiri_msg_list.append(platform_message.Image(url=attachment.url)) else: logging.warning( - "不支持的附件类型:" + attachment.content_type + ",忽略此附件。" + '不支持的附件类型:' + attachment.content_type + ',忽略此附件。' ) - content = re.sub(r"<@!\d+>", "", str(message.content)) - if content.strip() != "": + content = re.sub(r'<@!\d+>', '', str(message.content)) + if content.strip() != '': yiri_msg_list.append(platform_message.Plain(text=content)) chain = platform_message.MessageChain(yiri_msg_list) @@ -247,21 +265,25 @@ class OfficialEventConverter(adapter_model.EventConverter): return botpy_message.DirectMessage else: raise Exception( - "未支持转换的事件类型(YiriMirai -> Official): " + str(event) + '未支持转换的事件类型(YiriMirai -> Official): ' + str(event) ) def target2yiri( self, - event: typing.Union[botpy_message.Message, botpy_message.DirectMessage, botpy_message.GroupMessage, botpy_message.C2CMessage], + event: typing.Union[ + botpy_message.Message, + botpy_message.DirectMessage, + botpy_message.GroupMessage, + botpy_message.C2CMessage, + ], ) -> platform_events.Event: + if isinstance(event, botpy_message.Message): # 频道内,转群聊事件 + permission = 'MEMBER' - if type(event) == botpy_message.Message: # 频道内,转群聊事件 - permission = "MEMBER" - - if "2" in event.member.roles: - permission = "ADMINISTRATOR" - elif "4" in event.member.roles: - permission = "OWNER" + if '2' in event.member.roles: + permission = 'ADMINISTRATOR' + elif '4' in event.member.roles: + permission = 'OWNER' return platform_events.GroupMessage( sender=platform_entities.GroupMember( @@ -273,10 +295,10 @@ class OfficialEventConverter(adapter_model.EventConverter): name=event.author.username, permission=platform_entities.Permission.Member, ), - special_title="", + special_title='', join_timestamp=int( datetime.datetime.strptime( - event.member.joined_at, "%Y-%m-%dT%H:%M:%S%z" + event.member.joined_at, '%Y-%m-%dT%H:%M:%S%z' ).timestamp() ), last_speak_timestamp=datetime.datetime.now().timestamp(), @@ -287,11 +309,11 @@ class OfficialEventConverter(adapter_model.EventConverter): ), time=int( datetime.datetime.strptime( - event.timestamp, "%Y-%m-%dT%H:%M:%S%z" + event.timestamp, '%Y-%m-%dT%H:%M:%S%z' ).timestamp() ), ) - elif type(event) == botpy_message.DirectMessage: # 频道私聊,转私聊事件 + elif isinstance(event, botpy_message.DirectMessage): # 频道私聊,转私聊事件 return platform_events.FriendMessage( sender=platform_entities.Friend( id=event.guild_id, @@ -303,25 +325,24 @@ class OfficialEventConverter(adapter_model.EventConverter): ), time=int( datetime.datetime.strptime( - event.timestamp, "%Y-%m-%dT%H:%M:%S%z" + event.timestamp, '%Y-%m-%dT%H:%M:%S%z' ).timestamp() ), ) - elif type(event) == botpy_message.GroupMessage: # 群聊,转群聊事件 - + elif isinstance(event, botpy_message.GroupMessage): # 群聊,转群聊事件 author_member_id = event.author.member_openid return OfficialGroupMessage( sender=platform_entities.GroupMember( id=author_member_id, member_name=author_member_id, - permission="MEMBER", + permission='MEMBER', group=platform_entities.Group( id=event.group_openid, name=author_member_id, permission=platform_entities.Permission.Member, ), - special_title="", + special_title='', join_timestamp=int(0), last_speak_timestamp=datetime.datetime.now().timestamp(), mute_time_remaining=0, @@ -331,12 +352,11 @@ class OfficialEventConverter(adapter_model.EventConverter): ), time=int( datetime.datetime.strptime( - event.timestamp, "%Y-%m-%dT%H:%M:%S%z" + event.timestamp, '%Y-%m-%dT%H:%M:%S%z' ).timestamp() ), ) - elif type(event) == botpy_message.C2CMessage: # 私聊,转私聊事件 - + elif isinstance(event, botpy_message.C2CMessage): # 私聊,转私聊事件 user_id_alter = event.author.user_openid return OfficialFriendMessage( @@ -350,7 +370,7 @@ class OfficialEventConverter(adapter_model.EventConverter): ), time=int( datetime.datetime.strptime( - event.timestamp, "%Y-%m-%dT%H:%M:%S%z" + event.timestamp, '%Y-%m-%dT%H:%M:%S%z' ).timestamp() ), ) @@ -391,10 +411,10 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter): switchs = {} - for intent in cfg["intents"]: + for intent in cfg['intents']: switchs[intent] = True - del cfg["intents"] + del cfg['intents'] intents = botpy.Intents(**switchs) @@ -408,21 +428,21 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter): for msg in message_list: args = {} - if msg["type"] == "text": - args["content"] = msg["content"] - elif msg["type"] == "image": - args["image"] = msg["content"] - elif msg["type"] == "file_image": - args["file_image"] = msg["content"] + if msg['type'] == 'text': + args['content'] = msg['content'] + elif msg['type'] == 'image': + args['image'] = msg['content'] + elif msg['type'] == 'file_image': + args['file_image'] = msg['content'] else: continue - if target_type == "group": - args["channel_id"] = str(target_id) + if target_type == 'group': + args['channel_id'] = str(target_id) await self.bot.api.post_message(**args) - elif target_type == "person": - args["guild_id"] = str(target_id) + elif target_type == 'person': + args['guild_id'] = str(target_id) await self.bot.api.post_dms(**args) @@ -432,86 +452,82 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter): message: platform_message.MessageChain, quote_origin: bool = False, ): - message_list = self.message_converter.yiri2target(message) for msg in message_list: args = {} - if msg["type"] == "text": - args["content"] = msg["content"] - elif msg["type"] == "image": - args["image"] = msg["content"] - elif msg["type"] == "file_image": - args["file_image"] = msg["content"] + if msg['type'] == 'text': + args['content'] = msg['content'] + elif msg['type'] == 'image': + args['image'] = msg['content'] + elif msg['type'] == 'file_image': + args['file_image'] = msg['content'] else: continue if quote_origin: - args["message_reference"] = botpy_message_type.Reference( + args['message_reference'] = botpy_message_type.Reference( message_id=cached_message_ids[ str(message_source.message_chain.message_id) ] ) - if type(message_source) == platform_events.GroupMessage: - args["channel_id"] = str(message_source.sender.group.id) - args["msg_id"] = cached_message_ids[ + if isinstance(message_source, platform_events.GroupMessage): + args['channel_id'] = str(message_source.sender.group.id) + args['msg_id'] = cached_message_ids[ str(message_source.message_chain.message_id) ] await self.bot.api.post_message(**args) - elif type(message_source) == platform_events.FriendMessage: - args["guild_id"] = str(message_source.sender.id) - args["msg_id"] = cached_message_ids[ + elif isinstance(message_source, platform_events.FriendMessage): + args['guild_id'] = str(message_source.sender.id) + args['msg_id'] = cached_message_ids[ str(message_source.message_chain.message_id) ] await self.bot.api.post_dms(**args) - elif type(message_source) == OfficialGroupMessage: - - if "file_image" in args: # 暂不支持发送文件图片 + elif isinstance(message_source, OfficialGroupMessage): + if 'file_image' in args: # 暂不支持发送文件图片 continue - args["group_openid"] = message_source.sender.group.id + args['group_openid'] = message_source.sender.group.id - if "image" in args: + if 'image' in args: uploadMedia = await self.bot.api.post_group_file( - group_openid=args["group_openid"], + group_openid=args['group_openid'], file_type=1, - url=str(args['image']) + url=str(args['image']), ) del args['image'] args['media'] = uploadMedia args['msg_type'] = 7 - args["msg_id"] = cached_message_ids[ + args['msg_id'] = cached_message_ids[ str(message_source.message_chain.message_id) ] - args["msg_seq"] = self.group_msg_seq + args['msg_seq'] = self.group_msg_seq self.group_msg_seq += 1 await self.bot.api.post_group_message(**args) - elif type(message_source) == OfficialFriendMessage: - if "file_image" in args: + elif isinstance(message_source, OfficialFriendMessage): + if 'file_image' in args: continue - args["openid"] = message_source.sender.id + args['openid'] = message_source.sender.id - if "image" in args: + if 'image' in args: uploadMedia = await self.bot.api.post_c2c_file( - openid=args["openid"], - file_type=1, - url=str(args['image']) + openid=args['openid'], file_type=1, url=str(args['image']) ) del args['image'] args['media'] = uploadMedia args['msg_type'] = 7 - args["msg_id"] = cached_message_ids[ + args['msg_id'] = cached_message_ids[ str(message_source.message_chain.message_id) ] - args["msg_seq"] = self.c2c_msg_seq + args['msg_seq'] = self.c2c_msg_seq self.c2c_msg_seq += 1 await self.bot.api.post_c2c_message(**args) @@ -526,7 +542,6 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter): [platform_events.Event, adapter_model.MessagePlatformAdapter], None ], ): - try: async def wrapper( @@ -534,7 +549,7 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter): botpy_message.Message, botpy_message.DirectMessage, botpy_message.GroupMessage, - ] + ], ): self.cached_official_messages[str(message.id)] = message await callback(self.event_converter.target2yiri(message), self) @@ -555,7 +570,6 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter): delattr(self.bot, event_handler_mapping[event_type]) async def run_async(self): - self.metadata = self.ap.adapter_qq_botpy_meta self.message_converter = OfficialMessageConverter() @@ -563,7 +577,7 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter): self.cfg['ret_coro'] = True - self.ap.logger.info("运行 QQ 官方适配器") + self.ap.logger.info('运行 QQ 官方适配器') await (await self.bot.start(**self.cfg)) async def kill(self) -> bool: diff --git a/pkg/platform/sources/qqofficial.py b/pkg/platform/sources/qqofficial.py index bfef2135..06893485 100644 --- a/pkg/platform/sources/qqofficial.py +++ b/pkg/platform/sources/qqofficial.py @@ -7,12 +7,8 @@ import datetime from pkg.platform.adapter import MessagePlatformAdapter from pkg.platform.types import events as platform_events, message as platform_message -from pkg.core import app from .. import adapter - from ...core import app -from ..types import message as platform_message -from ..types import events as platform_events from ..types import entities as platform_entities from ...command.errors import ParamNotEnoughError from libs.qq_official_api.api import QQOfficialClient @@ -21,157 +17,164 @@ from ...utils import image class QQOfficialMessageConverter(adapter.MessageConverter): - @staticmethod async def yiri2target(message_chain: platform_message.MessageChain): content_list = [] - #只实现了发文字 + # 只实现了发文字 for msg in message_chain: if type(msg) is platform_message.Plain: - content_list.append({ - "type":"text", - "content":msg.text, - }) - + content_list.append( + { + 'type': 'text', + 'content': msg.text, + } + ) + return content_list - + @staticmethod - async def target2yiri(message:str,message_id:str,pic_url:str,content_type): + async def target2yiri(message: str, message_id: str, pic_url: str, content_type): yiri_msg_list = [] yiri_msg_list.append( - platform_message.Source(id=message_id,time=datetime.datetime.now()) + platform_message.Source(id=message_id, time=datetime.datetime.now()) ) if pic_url is not None: - base64_url = await image.get_qq_official_image_base64(pic_url=pic_url,content_type=content_type) - yiri_msg_list.append( - platform_message.Image(base64=base64_url) + base64_url = await image.get_qq_official_image_base64( + pic_url=pic_url, content_type=content_type ) + yiri_msg_list.append(platform_message.Image(base64=base64_url)) yiri_msg_list.append(platform_message.Plain(text=message)) chain = platform_message.MessageChain(yiri_msg_list) return chain + class QQOfficialEventConverter(adapter.EventConverter): + @staticmethod + async def yiri2target(event: platform_events.MessageEvent) -> QQOfficialEvent: + return event.source_platform_object @staticmethod - async def yiri2target(event:platform_events.MessageEvent) -> QQOfficialEvent: - return event.source_platform_object - - @staticmethod - async def target2yiri(event:QQOfficialEvent): + async def target2yiri(event: QQOfficialEvent): """ QQ官方消息转换为LB对象 """ yiri_chain = await QQOfficialMessageConverter.target2yiri( - message=event.content,message_id=event.d_id,pic_url=event.attachments,content_type=event.content_type + message=event.content, + message_id=event.d_id, + pic_url=event.attachments, + content_type=event.content_type, ) - + if event.t == 'C2C_MESSAGE_CREATE': friend = platform_entities.Friend( - id = event.user_openid, - nickname = event.t, - remark = "", + id=event.user_openid, + nickname=event.t, + remark='', ) return platform_events.FriendMessage( - sender = friend,message_chain = yiri_chain,time = int( + sender=friend, + message_chain=yiri_chain, + time=int( datetime.datetime.strptime( - event.timestamp, "%Y-%m-%dT%H:%M:%S%z" + event.timestamp, '%Y-%m-%dT%H:%M:%S%z' ).timestamp() ), - source_platform_object=event + source_platform_object=event, ) - + if event.t == 'DIRECT_MESSAGE_CREATE': friend = platform_entities.Friend( - id = event.guild_id, - nickname = event.t, - remark = "", + id=event.guild_id, + nickname=event.t, + remark='', ) return platform_events.FriendMessage( - sender = friend,message_chain = yiri_chain, - source_platform_object=event + sender=friend, message_chain=yiri_chain, source_platform_object=event ) if event.t == 'GROUP_AT_MESSAGE_CREATE': - yiri_chain.insert(0, platform_message.At(target="justbot")) + yiri_chain.insert(0, platform_message.At(target='justbot')) sender = platform_entities.GroupMember( - id = event.group_openid, - member_name= event.t, - permission= 'MEMBER', - group = platform_entities.Group( - id = event.group_openid, - name = 'MEMBER', - permission= platform_entities.Permission.Member - ), - special_title='', - join_timestamp=0, - last_speak_timestamp=0, - mute_time_remaining=0 - ) - time = int( - datetime.datetime.strptime( - event.timestamp, "%Y-%m-%dT%H:%M:%S%z" - ).timestamp() - ) - return platform_events.GroupMessage( - sender = sender, - message_chain=yiri_chain, - time = time, - source_platform_object=event - ) - if event.t =='AT_MESSAGE_CREATE': - yiri_chain.insert(0, platform_message.At(target="justbot")) - sender = platform_entities.GroupMember( - id = event.channel_id, + id=event.group_openid, member_name=event.t, - permission= 'MEMBER', - group = platform_entities.Group( - id = event.channel_id, - name = 'MEMBER', - permission=platform_entities.Permission.Member + permission='MEMBER', + group=platform_entities.Group( + id=event.group_openid, + name='MEMBER', + permission=platform_entities.Permission.Member, ), special_title='', join_timestamp=0, last_speak_timestamp=0, - mute_time_remaining=0 + mute_time_remaining=0, ) time = int( - datetime.datetime.strptime( - event.timestamp, "%Y-%m-%dT%H:%M:%S%z" - ).timestamp() - ) + datetime.datetime.strptime( + event.timestamp, '%Y-%m-%dT%H:%M:%S%z' + ).timestamp() + ) return platform_events.GroupMessage( - sender =sender, - message_chain = yiri_chain, - time = time, - source_platform_object=event + sender=sender, + message_chain=yiri_chain, + time=time, + source_platform_object=event, + ) + if event.t == 'AT_MESSAGE_CREATE': + yiri_chain.insert(0, platform_message.At(target='justbot')) + sender = platform_entities.GroupMember( + id=event.channel_id, + member_name=event.t, + permission='MEMBER', + group=platform_entities.Group( + id=event.channel_id, + name='MEMBER', + permission=platform_entities.Permission.Member, + ), + special_title='', + join_timestamp=0, + last_speak_timestamp=0, + mute_time_remaining=0, + ) + time = int( + datetime.datetime.strptime( + event.timestamp, '%Y-%m-%dT%H:%M:%S%z' + ).timestamp() + ) + return platform_events.GroupMessage( + sender=sender, + message_chain=yiri_chain, + time=time, + source_platform_object=event, ) class QQOfficialAdapter(adapter.MessagePlatformAdapter): - bot:QQOfficialClient - ap:app.Application - config:dict - bot_account_id:str + bot: QQOfficialClient + ap: app.Application + config: dict + bot_account_id: str message_converter: QQOfficialMessageConverter = QQOfficialMessageConverter() event_converter: QQOfficialEventConverter = QQOfficialEventConverter() - def __init__(self, config:dict, ap:app.Application): + def __init__(self, config: dict, ap: app.Application): self.config = config self.ap = ap required_keys = [ - "appid", - "secret", + 'appid', + 'secret', ] missing_keys = [key for key in required_keys if key not in config] if missing_keys: - raise ParamNotEnoughError("QQ官方机器人缺少相关配置项,请查看文档或联系管理员") - + raise ParamNotEnoughError( + 'QQ官方机器人缺少相关配置项,请查看文档或联系管理员' + ) + self.bot = QQOfficialClient( - app_id=config["appid"], - secret=config["secret"], - token=config["token"], + app_id=config['appid'], + secret=config['secret'], + token=config['token'], ) async def reply_message( @@ -186,31 +189,45 @@ class QQOfficialAdapter(adapter.MessagePlatformAdapter): content_list = await QQOfficialMessageConverter.yiri2target(message) - #私聊消息 + # 私聊消息 if qq_official_event.t == 'C2C_MESSAGE_CREATE': for content in content_list: - if content["type"] == 'text': - await self.bot.send_private_text_msg(qq_official_event.user_openid,content['content'],qq_official_event.d_id) + if content['type'] == 'text': + await self.bot.send_private_text_msg( + qq_official_event.user_openid, + content['content'], + qq_official_event.d_id, + ) - #群聊消息 + # 群聊消息 if qq_official_event.t == 'GROUP_AT_MESSAGE_CREATE': for content in content_list: - if content["type"] == 'text': - await self.bot.send_group_text_msg(qq_official_event.group_openid,content['content'],qq_official_event.d_id) - - #频道群聊 + if content['type'] == 'text': + await self.bot.send_group_text_msg( + qq_official_event.group_openid, + content['content'], + qq_official_event.d_id, + ) + + # 频道群聊 if qq_official_event.t == 'AT_MESSAGE_CREATE': for content in content_list: - if content["type"] == 'text': - await self.bot.send_channle_group_text_msg(qq_official_event.channel_id,content['content'],qq_official_event.d_id) + if content['type'] == 'text': + await self.bot.send_channle_group_text_msg( + qq_official_event.channel_id, + content['content'], + qq_official_event.d_id, + ) - #频道私聊 + # 频道私聊 if qq_official_event.t == 'DIRECT_MESSAGE_CREATE': for content in content_list: - if content["type"] == 'text': - await self.bot.send_channle_private_text_msg(qq_official_event.guild_id,content['content'],qq_official_event.d_id) - - + if content['type'] == 'text': + await self.bot.send_channle_private_text_msg( + qq_official_event.guild_id, + content['content'], + qq_official_event.d_id, + ) async def send_message( self, target_type: str, target_id: str, message: platform_message.MessageChain @@ -224,22 +241,21 @@ class QQOfficialAdapter(adapter.MessagePlatformAdapter): [platform_events.Event, adapter.MessagePlatformAdapter], None ], ): - async def on_message(event:QQOfficialEvent): - self.bot_account_id = "justbot" + async def on_message(event: QQOfficialEvent): + self.bot_account_id = 'justbot' try: return await callback( - await self.event_converter.target2yiri(event),self + await self.event_converter.target2yiri(event), self ) - except: + except Exception: traceback.print_exc() - - if event_type == platform_events.FriendMessage: - self.bot.on_message("DIRECT_MESSAGE_CREATE")(on_message) - self.bot.on_message("C2C_MESSAGE_CREATE")(on_message) - elif event_type == platform_events.GroupMessage: - self.bot.on_message("GROUP_AT_MESSAGE_CREATE")(on_message) - self.bot.on_message("AT_MESSAGE_CREATE")(on_message) + if event_type == platform_events.FriendMessage: + self.bot.on_message('DIRECT_MESSAGE_CREATE')(on_message) + self.bot.on_message('C2C_MESSAGE_CREATE')(on_message) + elif event_type == platform_events.GroupMessage: + self.bot.on_message('GROUP_AT_MESSAGE_CREATE')(on_message) + self.bot.on_message('AT_MESSAGE_CREATE')(on_message) async def run_async(self): async def shutdown_trigger_placeholder(): @@ -248,17 +264,18 @@ class QQOfficialAdapter(adapter.MessagePlatformAdapter): await self.bot.run_task( host='0.0.0.0', - port=self.config["port"], + port=self.config['port'], shutdown_trigger=shutdown_trigger_placeholder, - ) - + ) + async def kill(self) -> bool: return False - + def unregister_listener( self, event_type: type, - callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None], + callback: typing.Callable[ + [platform_events.Event, MessagePlatformAdapter], None + ], ): return super().unregister_listener(event_type, callback) - diff --git a/pkg/platform/sources/telegram.py b/pkg/platform/sources/telegram.py index 49822673..b35b7e7a 100644 --- a/pkg/platform/sources/telegram.py +++ b/pkg/platform/sources/telegram.py @@ -3,48 +3,33 @@ from __future__ import annotations import telegram import telegram.ext from telegram import Update -from telegram.ext import ApplicationBuilder, ContextTypes, CommandHandler, MessageHandler, filters +from telegram.ext import ApplicationBuilder, ContextTypes, MessageHandler, filters import typing -import asyncio import traceback -import time -import re -import base64 -import uuid -import json -import datetime -import hashlib import base64 import aiohttp -from Crypto.Cipher import AES -from flask import jsonify from lark_oapi.api.im.v1 import * -from lark_oapi.api.verification.v1 import GetVerificationRequest from .. import adapter -from ...pipeline.longtext.strategies import forward from ...core import app from ..types import message as platform_message from ..types import events as platform_events from ..types import entities as platform_entities -from ...utils import image class TelegramMessageConverter(adapter.MessageConverter): @staticmethod - async def yiri2target(message_chain: platform_message.MessageChain, bot: telegram.Bot) -> list[dict]: + async def yiri2target( + message_chain: platform_message.MessageChain, bot: telegram.Bot + ) -> list[dict]: components = [] for component in message_chain: if isinstance(component, platform_message.Plain): - components.append({ - "type": "text", - "text": component.text - }) + components.append({'type': 'text', 'text': component.text}) elif isinstance(component, platform_message.Image): - photo_bytes = None if component.base64: @@ -54,24 +39,25 @@ class TelegramMessageConverter(adapter.MessageConverter): async with session.get(component.url) as response: photo_bytes = await response.read() elif component.path: - with open(component.path, "rb") as f: + with open(component.path, 'rb') as f: photo_bytes = f.read() - - components.append({ - "type": "photo", - "photo": photo_bytes - }) + + components.append({'type': 'photo', 'photo': photo_bytes}) elif isinstance(component, platform_message.Forward): for node in component.node_list: - components.extend(await TelegramMessageConverter.yiri2target(node.message_chain, bot)) + components.extend( + await TelegramMessageConverter.yiri2target( + node.message_chain, bot + ) + ) return components - - @staticmethod - async def target2yiri(message: telegram.Message, bot: telegram.Bot, bot_account_id: str): - - message_components = [] + @staticmethod + async def target2yiri( + message: telegram.Message, bot: telegram.Bot, bot_account_id: str + ): + message_components = [] def parse_message_text(text: str) -> list[platform_message.MessageComponent]: msg_components = [] @@ -86,7 +72,7 @@ class TelegramMessageConverter(adapter.MessageConverter): if message.text: message_text = message.text message_components.extend(parse_message_text(message_text)) - + if message.photo: message_components.extend(parse_message_text(message.caption)) @@ -100,21 +86,26 @@ class TelegramMessageConverter(adapter.MessageConverter): file_bytes = await response.read() file_format = 'image/jpeg' - message_components.append(platform_message.Image(base64=f"data:{file_format};base64,{base64.b64encode(file_bytes).decode('utf-8')}")) - + message_components.append( + platform_message.Image( + base64=f'data:{file_format};base64,{base64.b64encode(file_bytes).decode("utf-8")}' + ) + ) + return platform_message.MessageChain(message_components) - + class TelegramEventConverter(adapter.EventConverter): @staticmethod async def yiri2target(event: platform_events.MessageEvent, bot: telegram.Bot): return event.source_platform_object - + @staticmethod async def target2yiri(event: Update, bot: telegram.Bot, bot_account_id: str): + lb_message = await TelegramMessageConverter.target2yiri( + event.message, bot, bot_account_id + ) - lb_message = await TelegramMessageConverter.target2yiri(event.message, bot, bot_account_id) - if event.effective_chat.type == 'private': return platform_events.FriendMessage( sender=platform_entities.Friend( @@ -124,7 +115,7 @@ class TelegramEventConverter(adapter.EventConverter): ), message_chain=lb_message, time=event.message.date.timestamp(), - source_platform_object=event + source_platform_object=event, ) elif event.effective_chat.type == 'group': return platform_events.GroupMessage( @@ -137,19 +128,18 @@ class TelegramEventConverter(adapter.EventConverter): name=event.effective_chat.title, permission=platform_entities.Permission.Member, ), - special_title="", + special_title='', join_timestamp=0, last_speak_timestamp=0, mute_time_remaining=0, ), message_chain=lb_message, time=event.message.date.timestamp(), - source_platform_object=event + source_platform_object=event, ) - + class TelegramAdapter(adapter.MessagePlatformAdapter): - bot: telegram.Bot application: telegram.ext.Application @@ -165,26 +155,31 @@ class TelegramAdapter(adapter.MessagePlatformAdapter): typing.Type[platform_events.Event], typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], ] = {} - + def __init__(self, config: dict, ap: app.Application): self.config = config self.ap = ap - - async def telegram_callback(update: Update, context: ContextTypes.DEFAULT_TYPE): + async def telegram_callback(update: Update, context: ContextTypes.DEFAULT_TYPE): if update.message.from_user.is_bot: return try: - lb_event = await self.event_converter.target2yiri(update, self.bot, self.bot_account_id) + lb_event = await self.event_converter.target2yiri( + update, self.bot, self.bot_account_id + ) await self.listeners[type(lb_event)](lb_event, self) - except Exception as e: + except Exception: print(traceback.format_exc()) - + self.application = ApplicationBuilder().token(self.config['token']).build() self.bot = self.application.bot - self.application.add_handler(MessageHandler(filters.TEXT | (filters.COMMAND) | filters.PHOTO , telegram_callback)) - + self.application.add_handler( + MessageHandler( + filters.TEXT | (filters.COMMAND) | filters.PHOTO, telegram_callback + ) + ) + async def send_message( self, target_type: str, target_id: str, message: platform_message.MessageChain ): @@ -198,45 +193,48 @@ class TelegramAdapter(adapter.MessagePlatformAdapter): ): assert isinstance(message_source.source_platform_object, Update) components = await TelegramMessageConverter.yiri2target(message, self.bot) - + for component in components: if component['type'] == 'text': - args = { - "chat_id": message_source.source_platform_object.effective_chat.id, - "text": component['text'], + 'chat_id': message_source.source_platform_object.effective_chat.id, + 'text': component['text'], } if quote_origin: - args['reply_to_message_id'] = message_source.source_platform_object.message.id + args['reply_to_message_id'] = ( + message_source.source_platform_object.message.id + ) await self.bot.send_message(**args) - + async def is_muted(self, group_id: int) -> bool: return False - + def register_listener( self, event_type: typing.Type[platform_events.Event], - callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], + callback: typing.Callable[ + [platform_events.Event, adapter.MessagePlatformAdapter], None + ], ): self.listeners[event_type] = callback - + def unregister_listener( self, event_type: typing.Type[platform_events.Event], - callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], + callback: typing.Callable[ + [platform_events.Event, adapter.MessagePlatformAdapter], None + ], ): self.listeners.pop(event_type) - + async def run_async(self): await self.application.initialize() self.bot_account_id = (await self.bot.get_me()).username - await self.application.updater.start_polling( - allowed_updates=Update.ALL_TYPES - ) + await self.application.updater.start_polling(allowed_updates=Update.ALL_TYPES) await self.application.start() - + async def kill(self) -> bool: await self.application.stop() - return True \ No newline at end of file + return True diff --git a/pkg/platform/sources/wecom.py b/pkg/platform/sources/wecom.py index 40632595..53878062 100644 --- a/pkg/platform/sources/wecom.py +++ b/pkg/platform/sources/wecom.py @@ -9,17 +9,14 @@ from libs.wecom_api.api import WecomClient from pkg.platform.adapter import MessagePlatformAdapter from pkg.platform.types import events as platform_events, message as platform_message from libs.wecom_api.wecomevent import WecomEvent -from pkg.core import app from .. import adapter from ...core import app -from ..types import message as platform_message -from ..types import events as platform_events from ..types import entities as platform_entities from ...command.errors import ParamNotEnoughError from ...utils import image -class WecomMessageConverter(adapter.MessageConverter): +class WecomMessageConverter(adapter.MessageConverter): @staticmethod async def yiri2target( message_chain: platform_message.MessageChain, bot: WecomClient @@ -28,23 +25,35 @@ class WecomMessageConverter(adapter.MessageConverter): for msg in message_chain: if type(msg) is platform_message.Plain: - content_list.append({ - "type": "text", - "content": msg.text, - }) + content_list.append( + { + 'type': 'text', + 'content': msg.text, + } + ) elif type(msg) is platform_message.Image: - content_list.append({ - "type": "image", - "media_id": await bot.get_media_id(msg), - }) + content_list.append( + { + 'type': 'image', + 'media_id': await bot.get_media_id(msg), + } + ) elif type(msg) is platform_message.Forward: for node in msg.node_list: - content_list.extend((await WecomMessageConverter.yiri2target(node.message_chain, bot))) + content_list.extend( + ( + await WecomMessageConverter.yiri2target( + node.message_chain, bot + ) + ) + ) else: - content_list.append({ - "type": "text", - "content": str(msg), - }) + content_list.append( + { + 'type': 'text', + 'content': str(msg), + } + ) return content_list @@ -67,14 +76,17 @@ class WecomMessageConverter(adapter.MessageConverter): platform_message.Source(id=message_id, time=datetime.datetime.now()) ) image_base64, image_format = await image.get_wecom_image_base64(pic_url=picurl) - yiri_msg_list.append(platform_message.Image(base64=f"data:image/{image_format};base64,{image_base64}")) + yiri_msg_list.append( + platform_message.Image( + base64=f'data:image/{image_format};base64,{image_base64}' + ) + ) chain = platform_message.MessageChain(yiri_msg_list) - + return chain class WecomEventConverter: - @staticmethod async def yiri2target( event: platform_events.Event, bot_account_id: int, bot: WecomClient @@ -85,18 +97,17 @@ class WecomEventConverter: pass if type(event) is platform_events.FriendMessage: - payload = { - "MsgType": "text", - "Content": '', - "FromUserName": event.sender.id, - "ToUserName": bot_account_id, - "CreateTime": int(datetime.datetime.now().timestamp()), - "AgentID": event.sender.nickname, + 'MsgType': 'text', + 'Content': '', + 'FromUserName': event.sender.id, + 'ToUserName': bot_account_id, + 'CreateTime': int(datetime.datetime.now().timestamp()), + 'AgentID': event.sender.nickname, } wecom_event = WecomEvent.from_payload(payload=payload) if not wecom_event: - raise ValueError("无法从 message_data 构造 WecomEvent 对象") + raise ValueError('无法从 message_data 构造 WecomEvent 对象') return wecom_event @@ -112,24 +123,24 @@ class WecomEventConverter: platform_events.FriendMessage: 转换后的 FriendMessage 对象。 """ # 转换消息链 - if event.type == "text": + if event.type == 'text': yiri_chain = await WecomMessageConverter.target2yiri( event.message, event.message_id ) friend = platform_entities.Friend( - id=f"u{event.user_id}", + id=f'u{event.user_id}', nickname=str(event.agent_id), - remark="", + remark='', ) return platform_events.FriendMessage( sender=friend, message_chain=yiri_chain, time=event.timestamp ) - elif event.type == "image": + elif event.type == 'image': friend = platform_entities.Friend( - id=f"u{event.user_id}", + id=f'u{event.user_id}', nickname=str(event.agent_id), - remark="", + remark='', ) yiri_chain = await WecomMessageConverter.target2yiri_image( @@ -142,7 +153,6 @@ class WecomEventConverter: class WecomAdapter(adapter.MessagePlatformAdapter): - bot: WecomClient ap: app.Application bot_account_id: str @@ -156,22 +166,22 @@ class WecomAdapter(adapter.MessagePlatformAdapter): self.ap = ap required_keys = [ - "corpid", - "secret", - "token", - "EncodingAESKey", - "contacts_secret", + 'corpid', + 'secret', + 'token', + 'EncodingAESKey', + 'contacts_secret', ] missing_keys = [key for key in required_keys if key not in config] if missing_keys: - raise ParamNotEnoughError("企业微信缺少相关配置项,请查看文档或联系管理员") + raise ParamNotEnoughError('企业微信缺少相关配置项,请查看文档或联系管理员') self.bot = WecomClient( - corpid=config["corpid"], - secret=config["secret"], - token=config["token"], - EncodingAESKey=config["EncodingAESKey"], - contacts_secret=config["contacts_secret"], + corpid=config['corpid'], + secret=config['secret'], + token=config['token'], + EncodingAESKey=config['EncodingAESKey'], + contacts_secret=config['contacts_secret'], ) async def reply_message( @@ -180,7 +190,6 @@ class WecomAdapter(adapter.MessagePlatformAdapter): message: platform_message.MessageChain, quote_origin: bool = False, ): - Wecom_event = await WecomEventConverter.yiri2target( message_source, self.bot_account_id, self.bot ) @@ -189,11 +198,15 @@ class WecomAdapter(adapter.MessagePlatformAdapter): # 删掉开头的u fixed_user_id = fixed_user_id[1:] for content in content_list: - if content["type"] == "text": - await self.bot.send_private_msg(fixed_user_id, Wecom_event.agent_id, content["content"]) - elif content["type"] == "image": - await self.bot.send_image(fixed_user_id, Wecom_event.agent_id, content["media_id"]) - + if content['type'] == 'text': + await self.bot.send_private_msg( + fixed_user_id, Wecom_event.agent_id, content['content'] + ) + elif content['type'] == 'image': + await self.bot.send_image( + fixed_user_id, Wecom_event.agent_id, content['media_id'] + ) + async def send_message( self, target_type: str, target_id: str, message: platform_message.MessageChain ): @@ -201,15 +214,17 @@ class WecomAdapter(adapter.MessagePlatformAdapter): 构造target_id的方式为前半部分为账户id,后半部分为agent_id,中间使用“|”符号隔开。 """ content_list = await WecomMessageConverter.yiri2target(message, self.bot) - parts = target_id.split("|") + parts = target_id.split('|') user_id = parts[0] agent_id = int(parts[1]) if target_type == 'person': for content in content_list: - if content["type"] == "text": - await self.bot.send_private_msg(user_id,agent_id,content["content"]) - if content["type"] == "image": - await self.bot.send_image(user_id,agent_id,content["media"]) + if content['type'] == 'text': + await self.bot.send_private_msg( + user_id, agent_id, content['content'] + ) + if content['type'] == 'image': + await self.bot.send_image(user_id, agent_id, content['media']) def register_listener( self, @@ -224,12 +239,12 @@ class WecomAdapter(adapter.MessagePlatformAdapter): return await callback( await self.event_converter.target2yiri(event), self ) - except: + except Exception: traceback.print_exc() if event_type == platform_events.FriendMessage: - self.bot.on_message("text")(on_message) - self.bot.on_message("image")(on_message) + self.bot.on_message('text')(on_message) + self.bot.on_message('image')(on_message) elif event_type == platform_events.GroupMessage: pass @@ -239,8 +254,8 @@ class WecomAdapter(adapter.MessagePlatformAdapter): await asyncio.sleep(1) await self.bot.run_task( - host=self.config["host"], - port=self.config["port"], + host=self.config['host'], + port=self.config['port'], shutdown_trigger=shutdown_trigger_placeholder, ) @@ -250,6 +265,8 @@ class WecomAdapter(adapter.MessagePlatformAdapter): async def unregister_listener( self, event_type: type, - callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None], + callback: typing.Callable[ + [platform_events.Event, MessagePlatformAdapter], None + ], ): return super().unregister_listener(event_type, callback) diff --git a/pkg/platform/types/base.py b/pkg/platform/types/base.py index ce87d36c..9e31bafe 100644 --- a/pkg/platform/types/base.py +++ b/pkg/platform/types/base.py @@ -1,4 +1,3 @@ - from typing import Dict, List, Type import pydantic.v1.main as pdm @@ -25,14 +24,18 @@ class PlatformBaseModel(BaseModel, metaclass=PlatformMetaclass): 2. 允许通过别名访问字段。 3. 自动生成小驼峰风格的别名。 """ + def __init__(self, *args, **kwargs): """""" super().__init__(*args, **kwargs) def __repr__(self) -> str: - return self.__class__.__name__ + '(' + ', '.join( - (f'{k}={repr(v)}' for k, v in self.__dict__.items() if v) - ) + ')' + return ( + self.__class__.__name__ + + '(' + + ', '.join((f'{k}={repr(v)}' for k, v in self.__dict__.items() if v)) + + ')' + ) class Config: extra = 'allow' @@ -42,6 +45,7 @@ class PlatformBaseModel(BaseModel, metaclass=PlatformMetaclass): class PlatformIndexedMetaclass(PlatformMetaclass): """可以通过子类名获取子类的类的元类。""" + __indexedbases__: List[Type['PlatformIndexedModel']] = [] __indexedmodel__ = None @@ -69,6 +73,7 @@ class PlatformIndexedMetaclass(PlatformMetaclass): class PlatformIndexedModel(PlatformBaseModel, metaclass=PlatformIndexedMetaclass): """可以通过子类名获取子类的类。""" + __indexes__: Dict[str, Type['PlatformIndexedModel']] @classmethod @@ -86,7 +91,7 @@ class PlatformIndexedModel(PlatformBaseModel, metaclass=PlatformIndexedMetaclass if not (type_ and issubclass(type_, cls)): raise ValueError(f'`{name}` 不是 `{cls.__name__}` 的子类!') return type_ - except AttributeError as e: + except AttributeError: raise ValueError(f'`{name}` 不是 `{cls.__name__}` 的子类!') from None @classmethod diff --git a/pkg/platform/types/entities.py b/pkg/platform/types/entities.py index 33fbefe9..d989ffce 100644 --- a/pkg/platform/types/entities.py +++ b/pkg/platform/types/entities.py @@ -2,6 +2,7 @@ """ 此模块提供实体和配置项模型。 """ + import abc from datetime import datetime from enum import Enum @@ -12,8 +13,10 @@ import pydantic.v1 as pydantic class Entity(pydantic.BaseModel): """实体,表示一个用户或群。""" + id: int """ID。""" + @abc.abstractmethod def get_name(self) -> str: """名称。""" @@ -21,31 +24,35 @@ class Entity(pydantic.BaseModel): class Friend(Entity): """私聊对象。""" + id: typing.Union[int, str] """ID。""" nickname: typing.Optional[str] """昵称。""" remark: typing.Optional[str] """备注。""" + def get_name(self) -> str: return self.nickname or self.remark or '' - class Permission(str, Enum): """群成员身份权限。""" - Member = "MEMBER" + + Member = 'MEMBER' """成员。""" - Administrator = "ADMINISTRATOR" + Administrator = 'ADMINISTRATOR' """管理员。""" - Owner = "OWNER" + Owner = 'OWNER' """群主。""" + def __repr__(self) -> str: return repr(self.value) class Group(Entity): """群。""" + id: typing.Union[int, str] """群号。""" name: str @@ -59,6 +66,7 @@ class Group(Entity): class GroupMember(Entity): """群成员。""" + id: typing.Union[int, str] """群员 ID。""" member_name: str diff --git a/pkg/platform/types/events.py b/pkg/platform/types/events.py index 40507315..1a724beb 100644 --- a/pkg/platform/types/events.py +++ b/pkg/platform/types/events.py @@ -2,8 +2,7 @@ """ 此模块提供事件模型。 """ -from datetime import datetime -from enum import Enum + import typing import pydantic.v1 as pydantic @@ -18,15 +17,23 @@ class Event(pydantic.BaseModel): Args: type: 事件名。 """ + type: str """事件名。""" + def __repr__(self): - return self.__class__.__name__ + '(' + ', '.join( - ( - f'{k}={repr(v)}' - for k, v in self.__dict__.items() if k != 'type' and v + return ( + self.__class__.__name__ + + '(' + + ', '.join( + ( + f'{k}={repr(v)}' + for k, v in self.__dict__.items() + if k != 'type' and v + ) ) - ) + ')' + + ')' + ) @classmethod def parse_subtype(cls, obj: dict) -> 'Event': @@ -52,6 +59,7 @@ class MessageEvent(Event): type: 事件名。 message_chain: 消息内容。 """ + type: str """事件名。""" message_chain: platform_message.MessageChain @@ -74,6 +82,7 @@ class FriendMessage(MessageEvent): sender: 发送消息的好友。 message_chain: 消息内容。 """ + type: str = 'FriendMessage' """事件名。""" sender: platform_entities.Friend @@ -90,12 +99,14 @@ class GroupMessage(MessageEvent): sender: 发送消息的群成员。 message_chain: 消息内容。 """ + type: str = 'GroupMessage' """事件名。""" sender: platform_entities.GroupMember """发送消息的群成员。""" message_chain: platform_message.MessageChain """消息内容。""" + @property def group(self) -> platform_entities.Group: return self.sender.group diff --git a/pkg/platform/types/message.py b/pkg/platform/types/message.py index b99a28b3..529ce4c1 100644 --- a/pkg/platform/types/message.py +++ b/pkg/platform/types/message.py @@ -1,7 +1,6 @@ import itertools import logging from datetime import datetime -from enum import Enum from pathlib import Path import typing @@ -16,6 +15,7 @@ logger = logging.getLogger(__name__) class MessageComponentMetaclass(PlatformIndexedMetaclass): """消息组件元类。""" + __message_component__ = None def __new__(cls, name, bases, attrs, **kwargs): @@ -41,18 +41,26 @@ class MessageComponentMetaclass(PlatformIndexedMetaclass): class MessageComponent(PlatformIndexedModel, metaclass=MessageComponentMetaclass): """消息组件。""" + type: str """消息组件类型。""" + def __str__(self): return '' def __repr__(self): - return self.__class__.__name__ + '(' + ', '.join( - ( - f'{k}={repr(v)}' - for k, v in self.__dict__.items() if k != 'type' and v + return ( + self.__class__.__name__ + + '(' + + ', '.join( + ( + f'{k}={repr(v)}' + for k, v in self.__dict__.items() + if k != 'type' and v + ) ) - ) + ')' + + ')' + ) def __init__(self, *args, **kwargs): # 解析参数列表,将位置参数转化为具名参数 @@ -63,7 +71,9 @@ class MessageComponent(PlatformIndexedModel, metaclass=MessageComponentMetaclass ) for name, value in zip(parameter_names, args): if name in kwargs: - raise TypeError(f'在 `{self.type}` 中,具名参数 `{name}` 与位置参数重复。') + raise TypeError( + f'在 `{self.type}` 中,具名参数 `{name}` 与位置参数重复。' + ) kwargs[name] = value super().__init__(**kwargs) @@ -117,6 +127,7 @@ class MessageChain(PlatformBaseModel): ``` """ + __root__: typing.List[MessageComponent] @staticmethod @@ -131,10 +142,10 @@ class MessageChain(PlatformBaseModel): result.append(Plain(msg)) else: raise TypeError( - f"消息链中元素需为 dict 或 str 或 MessageComponent,当前类型:{type(msg)}" + f'消息链中元素需为 dict 或 str 或 MessageComponent,当前类型:{type(msg)}' ) return result - + @pydantic.validator('__root__', always=True, pre=True) def _parse_component(cls, msg_chain): if isinstance(msg_chain, (str, MessageComponent)): @@ -157,7 +168,7 @@ class MessageChain(PlatformBaseModel): super().__init__(__root__=__root__) def __str__(self): - return "".join(str(component) for component in self.__root__) + return ''.join(str(component) for component in self.__root__) def __repr__(self): return f'{self.__class__.__name__}({self.__root__!r})' @@ -165,8 +176,9 @@ class MessageChain(PlatformBaseModel): def __iter__(self): yield from self.__root__ - def get_first(self, - t: typing.Type[TMessageComponent]) -> typing.Optional[TMessageComponent]: + def get_first( + self, t: typing.Type[TMessageComponent] + ) -> typing.Optional[TMessageComponent]: """获取消息链中第一个符合类型的消息组件。""" for component in self: if isinstance(component, t): @@ -174,35 +186,40 @@ class MessageChain(PlatformBaseModel): return None @typing.overload - def __getitem__(self, index: int) -> MessageComponent: - ... + def __getitem__(self, index: int) -> MessageComponent: ... @typing.overload - def __getitem__(self, index: slice) -> typing.List[MessageComponent]: - ... + def __getitem__(self, index: slice) -> typing.List[MessageComponent]: ... @typing.overload - def __getitem__(self, - index: typing.Type[TMessageComponent]) -> typing.List[TMessageComponent]: - ... + def __getitem__( + self, index: typing.Type[TMessageComponent] + ) -> typing.List[TMessageComponent]: ... @typing.overload def __getitem__( self, index: typing.Tuple[typing.Type[TMessageComponent], int] - ) -> typing.List[TMessageComponent]: - ... + ) -> typing.List[TMessageComponent]: ... def __getitem__( - self, index: typing.Union[int, slice, typing.Type[TMessageComponent], - typing.Tuple[typing.Type[TMessageComponent], int]] - ) -> typing.Union[MessageComponent, typing.List[MessageComponent], - typing.List[TMessageComponent]]: + self, + index: typing.Union[ + int, + slice, + typing.Type[TMessageComponent], + typing.Tuple[typing.Type[TMessageComponent], int], + ], + ) -> typing.Union[ + MessageComponent, typing.List[MessageComponent], typing.List[TMessageComponent] + ]: return self.get(index) def __setitem__( - self, key: typing.Union[int, slice], - value: typing.Union[MessageComponent, str, typing.Iterable[typing.Union[MessageComponent, - str]]] + self, + key: typing.Union[int, slice], + value: typing.Union[ + MessageComponent, str, typing.Iterable[typing.Union[MessageComponent, str]] + ], ): if isinstance(value, str): value = Plain(value) @@ -217,8 +234,10 @@ class MessageChain(PlatformBaseModel): return reversed(self.__root__) def has( - self, sub: typing.Union[MessageComponent, typing.Type[MessageComponent], - 'MessageChain', str] + self, + sub: typing.Union[ + MessageComponent, typing.Type[MessageComponent], 'MessageChain', str + ], ) -> bool: """判断消息链中: 1. 是否有某个消息组件。 @@ -242,7 +261,7 @@ class MessageChain(PlatformBaseModel): if i == sub: return True return False - raise TypeError(f"类型不匹配,当前类型:{type(sub)}") + raise TypeError(f'类型不匹配,当前类型:{type(sub)}') def __contains__(self, sub) -> bool: return self.has(sub) @@ -293,7 +312,7 @@ class MessageChain(PlatformBaseModel): self, x: typing.Union[MessageComponent, typing.Type[MessageComponent]], i: int = 0, - j: int = -1 + j: int = -1, ) -> int: """返回 x 在消息链中首次出现项的索引号(索引号在 i 或其后且在 j 之前)。 @@ -323,12 +342,14 @@ class MessageChain(PlatformBaseModel): for index in range(i, j): if type(self[index]) is x: return index - raise ValueError("消息链中不存在该类型的组件。") + raise ValueError('消息链中不存在该类型的组件。') if isinstance(x, MessageComponent): return self.__root__.index(x, i, j) - raise TypeError(f"类型不匹配,当前类型:{type(x)}") + raise TypeError(f'类型不匹配,当前类型:{type(x)}') - def count(self, x: typing.Union[MessageComponent, typing.Type[MessageComponent]]) -> int: + def count( + self, x: typing.Union[MessageComponent, typing.Type[MessageComponent]] + ) -> int: """返回消息链中 x 出现的次数。 Args: @@ -342,7 +363,7 @@ class MessageChain(PlatformBaseModel): return sum(1 for i in self if type(i) is x) if isinstance(x, MessageComponent): return self.__root__.count(x) - raise TypeError(f"类型不匹配,当前类型:{type(x)}") + raise TypeError(f'类型不匹配,当前类型:{type(x)}') def extend(self, x: typing.Iterable[typing.Union[MessageComponent, str]]): """将另一个消息链中的元素添加到消息链末尾。 @@ -394,7 +415,7 @@ class MessageChain(PlatformBaseModel): def exclude( self, x: typing.Union[MessageComponent, typing.Type[MessageComponent]], - count: int = -1 + count: int = -1, ) -> 'MessageChain': """返回移除指定元素或指定类型的元素后剩余的消息链。 @@ -405,6 +426,7 @@ class MessageChain(PlatformBaseModel): Returns: MessageChain: 剩余的消息链。 """ + def _exclude(): nonlocal count x_is_type = isinstance(x, type) @@ -423,8 +445,7 @@ class MessageChain(PlatformBaseModel): @classmethod def join(cls, *args: typing.Iterable[typing.Union[str, MessageComponent]]): return cls( - Plain(c) if isinstance(c, str) else c - for c in itertools.chain(*args) + Plain(c) if isinstance(c, str) else c for c in itertools.chain(*args) ) @property @@ -439,14 +460,19 @@ class MessageChain(PlatformBaseModel): return source.id if source else -1 -TMessage = typing.Union[MessageChain, typing.Iterable[typing.Union[MessageComponent, str]], - MessageComponent, str] +TMessage = typing.Union[ + MessageChain, + typing.Iterable[typing.Union[MessageComponent, str]], + MessageComponent, + str, +] """可以转化为 MessageChain 的类型。""" class Source(MessageComponent): """源。包含消息的基本信息。""" - type: str = "Source" + + type: str = 'Source' """消息组件类型。""" id: typing.Union[int, str] """消息的识别号,用于引用回复(Source 类型永远为 MessageChain 的第一个元素)。""" @@ -456,10 +482,12 @@ class Source(MessageComponent): class Plain(MessageComponent): """纯文本。""" - type: str = "Plain" + + type: str = 'Plain' """消息组件类型。""" text: str """文字消息。""" + def __str__(self): return self.text @@ -469,7 +497,8 @@ class Plain(MessageComponent): class Quote(MessageComponent): """引用。""" - type: str = "Quote" + + type: str = 'Quote' """消息组件类型。""" id: typing.Optional[int] = None """被引用回复的原消息的 message_id。""" @@ -482,37 +511,42 @@ class Quote(MessageComponent): origin: MessageChain """被引用回复的原消息的消息链对象。""" - @pydantic.validator("origin", always=True, pre=True) + @pydantic.validator('origin', always=True, pre=True) def origin_formater(cls, v): return MessageChain.parse_obj(v) class At(MessageComponent): """At某人。""" - type: str = "At" + + type: str = 'At' """消息组件类型。""" target: typing.Union[int, str] """群员 ID。""" display: typing.Optional[str] = None """At时显示的文字,发送消息时无效,自动使用群名片。""" + def __eq__(self, other): return isinstance(other, At) and self.target == other.target def __str__(self): - return f"@{self.display or self.target}" + return f'@{self.display or self.target}' class AtAll(MessageComponent): """At全体。""" - type: str = "AtAll" + + type: str = 'AtAll' """消息组件类型。""" + def __str__(self): - return "@全体成员" + return '@全体成员' class Image(MessageComponent): """图片。""" - type: str = "Image" + + type: str = 'Image' """消息组件类型。""" image_id: typing.Optional[str] = None """图片的 image_id,不为空时将忽略 url 属性。""" @@ -522,10 +556,13 @@ class Image(MessageComponent): """图片的路径,发送本地图片。""" base64: typing.Optional[str] = None """图片的 Base64 编码。""" + def __eq__(self, other): - return isinstance( - other, Image - ) and self.type == other.type and self.uuid == other.uuid + return ( + isinstance(other, Image) + and self.type == other.type + and self.uuid == other.uuid + ) def __str__(self): return '[图片]' @@ -537,7 +574,7 @@ class Image(MessageComponent): try: return str(Path(path).resolve(strict=True)) except FileNotFoundError: - raise ValueError(f"无效路径:{path}") + raise ValueError(f'无效路径:{path}') else: return path @@ -554,7 +591,7 @@ class Image(MessageComponent): self, filename: typing.Union[str, Path, None] = None, directory: typing.Union[str, Path, None] = None, - determine_type: bool = True + determine_type: bool = True, ): """下载图片到本地。 @@ -568,6 +605,7 @@ class Image(MessageComponent): return import httpx + async with httpx.AsyncClient() as client: response = await client.get(self.url) response.raise_for_status() @@ -577,19 +615,20 @@ class Image(MessageComponent): path = Path(filename) if determine_type: import imghdr - path = path.with_suffix( - '.' + str(imghdr.what(None, content)) - ) + + path = path.with_suffix('.' + str(imghdr.what(None, content))) path.parent.mkdir(parents=True, exist_ok=True) elif directory: import imghdr + path = Path(directory) path.mkdir(parents=True, exist_ok=True) path = path / f'{self.uuid}.{imghdr.what(None, content)}' else: - raise ValueError("请指定文件路径或文件夹路径!") + raise ValueError('请指定文件路径或文件夹路径!') import aiofiles + async with aiofiles.open(path, 'wb') as f: await f.write(content) @@ -600,7 +639,7 @@ class Image(MessageComponent): cls, filename: typing.Union[str, Path, None] = None, content: typing.Optional[bytes] = None, - ) -> "Image": + ) -> 'Image': """从本地文件路径加载图片,以 base64 的形式传递。 Args: @@ -615,16 +654,18 @@ class Image(MessageComponent): elif filename: path = Path(filename) import aiofiles + async with aiofiles.open(path, 'rb') as f: content = await f.read() else: - raise ValueError("请指定图片路径或图片内容!") + raise ValueError('请指定图片路径或图片内容!') import base64 + img = cls(base64=base64.b64encode(content).decode()) return img @classmethod - def from_unsafe_path(cls, path: typing.Union[str, Path]) -> "Image": + def from_unsafe_path(cls, path: typing.Union[str, Path]) -> 'Image': """从不安全的路径加载图片。 Args: @@ -638,7 +679,8 @@ class Image(MessageComponent): class Unknown(MessageComponent): """未知。""" - type: str = "Unknown" + + type: str = 'Unknown' """消息组件类型。""" text: str """文本。""" @@ -646,7 +688,8 @@ class Unknown(MessageComponent): class Voice(MessageComponent): """语音。""" - type: str = "Voice" + + type: str = 'Voice' """消息组件类型。""" voice_id: typing.Optional[str] = None """语音的 voice_id,不为空时将忽略 url 属性。""" @@ -658,6 +701,7 @@ class Voice(MessageComponent): """语音的 Base64 编码。""" length: typing.Optional[int] = None """语音的长度,单位为秒。""" + @pydantic.validator('path') def validate_path(cls, path: typing.Optional[str]): """修复 path 参数的行为,使之相对于 LangBot 的启动路径。""" @@ -665,7 +709,7 @@ class Voice(MessageComponent): try: return str(Path(path).resolve(strict=True)) except FileNotFoundError: - raise ValueError(f"无效路径:{path}") + raise ValueError(f'无效路径:{path}') else: return path @@ -675,7 +719,7 @@ class Voice(MessageComponent): async def download( self, filename: typing.Union[str, Path, None] = None, - directory: typing.Union[str, Path, None] = None + directory: typing.Union[str, Path, None] = None, ): """下载语音到本地。 @@ -688,6 +732,7 @@ class Voice(MessageComponent): return import httpx + async with httpx.AsyncClient() as client: response = await client.get(self.url) response.raise_for_status() @@ -701,9 +746,10 @@ class Voice(MessageComponent): path.mkdir(parents=True, exist_ok=True) path = path / f'{self.voice_id}.silk' else: - raise ValueError("请指定文件路径或文件夹路径!") + raise ValueError('请指定文件路径或文件夹路径!') import aiofiles + async with aiofiles.open(path, 'wb') as f: await f.write(content) @@ -712,7 +758,7 @@ class Voice(MessageComponent): cls, filename: typing.Union[str, Path, None] = None, content: typing.Optional[bytes] = None, - ) -> "Voice": + ) -> 'Voice': """从本地文件路径加载语音,以 base64 的形式传递。 Args: @@ -724,17 +770,20 @@ class Voice(MessageComponent): if filename: path = Path(filename) import aiofiles + async with aiofiles.open(path, 'rb') as f: content = await f.read() else: - raise ValueError("请指定语音路径或语音内容!") + raise ValueError('请指定语音路径或语音内容!') import base64 + img = cls(base64=base64.b64encode(content).decode()) return img class ForwardMessageNode(pydantic.BaseModel): """合并转发中的一条消息。""" + sender_id: typing.Optional[typing.Union[int, str]] = None """发送人ID。""" sender_name: typing.Optional[str] = None @@ -745,6 +794,7 @@ class ForwardMessageNode(pydantic.BaseModel): """消息的 message_id。""" time: typing.Optional[datetime] = None """发送时间。""" + @pydantic.validator('message_chain', check_fields=False) def _validate_message_chain(cls, value: typing.Union[MessageChain, list]): if isinstance(value, list): @@ -753,7 +803,9 @@ class ForwardMessageNode(pydantic.BaseModel): @classmethod def create( - cls, sender: typing.Union[platform_entities.Friend, platform_entities.GroupMember], message: MessageChain + cls, + sender: typing.Union[platform_entities.Friend, platform_entities.GroupMember], + message: MessageChain, ) -> 'ForwardMessageNode': """从消息链生成转发消息。 @@ -765,28 +817,28 @@ class ForwardMessageNode(pydantic.BaseModel): ForwardMessageNode: 生成的一条消息。 """ return ForwardMessageNode( - sender_id=sender.id, - sender_name=sender.get_name(), - message_chain=message + sender_id=sender.id, sender_name=sender.get_name(), message_chain=message ) class ForwardMessageDiaplay(pydantic.BaseModel): - title: str = "群聊的聊天记录" - brief: str = "[聊天记录]" - source: str = "聊天记录" + title: str = '群聊的聊天记录' + brief: str = '[聊天记录]' + source: str = '聊天记录' preview: typing.List[str] = [] - summary: str = "查看x条转发消息" + summary: str = '查看x条转发消息' class Forward(MessageComponent): """合并转发。""" - type: str = "Forward" + + type: str = 'Forward' """消息组件类型。""" display: ForwardMessageDiaplay """显示信息""" node_list: typing.List[ForwardMessageNode] """转发消息节点列表。""" + def __init__(self, *args, **kwargs): if len(args) == 1: self.node_list = args[0] @@ -799,7 +851,8 @@ class Forward(MessageComponent): class File(MessageComponent): """文件。""" - type: str = "File" + + type: str = 'File' """消息组件类型。""" id: str """文件识别 ID。""" @@ -807,6 +860,6 @@ class File(MessageComponent): """文件名称。""" size: int """文件大小。""" + def __str__(self): return f'[文件]{self.name}' - diff --git a/pkg/plugin/__init__.py b/pkg/plugin/__init__.py index c543161a..f6bf97d7 100644 --- a/pkg/plugin/__init__.py +++ b/pkg/plugin/__init__.py @@ -1,4 +1,4 @@ """插件支持包 包含插件基类、插件宿主以及部分API接口 -""" \ No newline at end of file +""" diff --git a/pkg/plugin/context.py b/pkg/plugin/context.py index fc06b9ec..cc95adaa 100644 --- a/pkg/plugin/context.py +++ b/pkg/plugin/context.py @@ -14,13 +14,10 @@ from ..platform import adapter as platform_adapter def register( - name: str, - description: str, - version: str, - author: str + name: str, description: str, version: str, author: str ) -> typing.Callable[[typing.Type[BasePlugin]], typing.Type[BasePlugin]]: """注册插件类 - + 使用示例: @register( @@ -34,15 +31,16 @@ def register( """ pass + def handler( - event: typing.Type[events.BaseEventModel] + event: typing.Type[events.BaseEventModel], ) -> typing.Callable[[typing.Callable], typing.Callable]: """注册事件监听器 - + 使用示例: class MyPlugin(BasePlugin): - + @handler(NormalMessageResponded) async def on_normal_message_responded(self, ctx: EventContext): pass @@ -51,14 +49,14 @@ def handler( def llm_func( - name: str=None, + name: str = None, ) -> typing.Callable: """注册内容函数 - + 使用示例: class MyPlugin(BasePlugin): - + @llm_func("access_the_web_page") async def _(self, query, url: str, brief_len: int): \"""Call this function to search about the question before you answer any questions. @@ -98,7 +96,7 @@ class BasePlugin(metaclass=abc.ABCMeta): async def initialize(self): """初始化阶段被调用""" pass - + async def destroy(self): """释放/禁用插件时被调用""" pass @@ -123,12 +121,12 @@ class APIHost: def get_platform_adapters(self) -> list[platform_adapter.MessagePlatformAdapter]: """获取已启用的消息平台适配器列表 - + Returns: list[platform.adapter.MessageSourceAdapter]: 已启用的消息平台适配器列表 """ return self.ap.platform_mgr.get_running_adapters() - + async def send_active_message( self, adapter: platform_adapter.MessagePlatformAdapter, @@ -137,7 +135,7 @@ class APIHost: message: platform_message.MessageChain, ): """发送主动消息 - + Args: adapter (platform.adapter.MessageSourceAdapter): 消息平台适配器对象,调用 host.get_platform_adapters() 获取并取用其中某个 target_type (str): 目标类型,`person`或`group` @@ -153,7 +151,7 @@ class APIHost: def require_ver( self, ge: str, - le: str='v999.999.999', + le: str = 'v999.999.999', ) -> bool: """插件版本要求装饰器 @@ -164,16 +162,23 @@ class APIHost: Returns: bool: 是否满足要求, False时为无法获取版本号,True时为满足要求,报错为不满足要求 """ - langbot_version = "" + langbot_version = '' try: - langbot_version = self.ap.ver_mgr.get_current_version() # 从updater模块获取版本号 - except: + langbot_version = ( + self.ap.ver_mgr.get_current_version() + ) # 从updater模块获取版本号 + except Exception: return False - if self.ap.ver_mgr.compare_version_str(langbot_version, ge) < 0 or \ - (self.ap.ver_mgr.compare_version_str(langbot_version, le) > 0): - raise Exception("LangBot 版本不满足要求,某些功能(可能是由插件提供的)无法正常使用。(要求版本:{}-{},但当前版本:{})".format(ge, le, langbot_version)) + if self.ap.ver_mgr.compare_version_str(langbot_version, ge) < 0 or ( + self.ap.ver_mgr.compare_version_str(langbot_version, le) > 0 + ): + raise Exception( + 'LangBot 版本不满足要求,某些功能(可能是由插件提供的)无法正常使用。(要求版本:{}-{},但当前版本:{})'.format( + ge, le, langbot_version + ) + ) return True @@ -220,36 +225,30 @@ class EventContext: if key not in self.__return_value__: self.__return_value__[key] = [] self.__return_value__[key].append(ret) - + async def reply(self, message_chain: platform_message.MessageChain): """回复此次消息请求 - + Args: message_chain (platform.types.MessageChain): 源平台的消息链,若用户使用的不是源平台适配器,程序也能自动转换为目标平台消息链 """ # TODO 添加 at_sender 和 quote_origin 参数 await self.event.query.adapter.reply_message( - message_source=self.event.query.message_event, - message=message_chain + message_source=self.event.query.message_event, message=message_chain ) - + async def send_message( - self, - target_type: str, - target_id: str, - message: platform_message.MessageChain + self, target_type: str, target_id: str, message: platform_message.MessageChain ): """主动发送消息 - + Args: target_type (str): 目标类型,`person`或`group` target_id (str): 目标ID message (platform.types.MessageChain): 源平台的消息链,若用户使用的不是源平台适配器,程序也能自动转换为目标平台消息链 """ await self.event.query.adapter.send_message( - target_type=target_type, - target_id=target_id, - message=message + target_type=target_type, target_id=target_id, message=message ) def prevent_postorder(self): @@ -281,10 +280,8 @@ class EventContext: def is_prevented_postorder(self): """是否阻止后序插件执行""" return self.__prevent_postorder__ - def __init__(self, host: APIHost, event: events.BaseEventModel): - self.eid = EventContext.eid self.host = host self.event = event @@ -297,16 +294,16 @@ class EventContext: class RuntimeContainerStatus(enum.Enum): """插件容器状态""" - MOUNTED = "mounted" + MOUNTED = 'mounted' """已加载进内存,所有位于运行时记录中的 RuntimeContainer 至少是这个状态""" - INITIALIZED = "initialized" + INITIALIZED = 'initialized' """已初始化""" class RuntimeContainer(pydantic.BaseModel): """运行时的插件容器 - + 运行期间存储单个插件的信息 """ @@ -352,9 +349,10 @@ class RuntimeContainer(pydantic.BaseModel): plugin_inst: typing.Optional[BasePlugin] = None """插件实例""" - event_handlers: dict[typing.Type[events.BaseEventModel], typing.Callable[ - [BasePlugin, EventContext], typing.Awaitable[None] - ]] = {} + event_handlers: dict[ + typing.Type[events.BaseEventModel], + typing.Callable[[BasePlugin, EventContext], typing.Awaitable[None]], + ] = {} """事件处理器""" tools: list[tools_entities.LLMFunction] = [] @@ -378,7 +376,7 @@ class RuntimeContainer(pydantic.BaseModel): 'pkg_path': self.pkg_path, 'enabled': self.enabled, 'priority': self.priority, - "config_schema": self.config_schema, + 'config_schema': self.config_schema, 'event_handlers': { event_name.__name__: handler.__name__ for event_name, handler in self.event_handlers.items() diff --git a/pkg/plugin/errors.py b/pkg/plugin/errors.py index bd6199e3..8da223db 100644 --- a/pkg/plugin/errors.py +++ b/pkg/plugin/errors.py @@ -2,7 +2,6 @@ from __future__ import annotations class PluginSystemError(Exception): - message: str def __init__(self, message: str): @@ -10,15 +9,13 @@ class PluginSystemError(Exception): def __str__(self): return self.message - + class PluginNotFoundError(PluginSystemError): - def __init__(self, message: str): - super().__init__(f"未找到插件: {message}") + super().__init__(f'未找到插件: {message}') class PluginInstallerError(PluginSystemError): - def __init__(self, message: str): - super().__init__(f"安装器操作错误: {message}") + super().__init__(f'安装器操作错误: {message}') diff --git a/pkg/plugin/events.py b/pkg/plugin/events.py index 152ac39f..61e84714 100644 --- a/pkg/plugin/events.py +++ b/pkg/plugin/events.py @@ -27,7 +27,7 @@ class PersonMessageReceived(BaseEventModel): launcher_id: typing.Union[int, str] """发起对象ID(群号/QQ号)""" - + sender_id: typing.Union[int, str] """发送者ID(QQ号)""" @@ -40,7 +40,7 @@ class GroupMessageReceived(BaseEventModel): launcher_type: str launcher_id: typing.Union[int, str] - + sender_id: typing.Union[int, str] message_chain: platform_message.MessageChain @@ -52,7 +52,7 @@ class PersonNormalMessageReceived(BaseEventModel): launcher_type: str launcher_id: typing.Union[int, str] - + sender_id: typing.Union[int, str] text_message: str @@ -70,7 +70,7 @@ class PersonCommandSent(BaseEventModel): launcher_type: str launcher_id: typing.Union[int, str] - + sender_id: typing.Union[int, str] command: str @@ -94,7 +94,7 @@ class GroupNormalMessageReceived(BaseEventModel): launcher_type: str launcher_id: typing.Union[int, str] - + sender_id: typing.Union[int, str] text_message: str @@ -112,7 +112,7 @@ class GroupCommandSent(BaseEventModel): launcher_type: str launcher_id: typing.Union[int, str] - + sender_id: typing.Union[int, str] command: str @@ -136,7 +136,7 @@ class NormalMessageResponded(BaseEventModel): launcher_type: str launcher_id: typing.Union[int, str] - + sender_id: typing.Union[int, str] session: core_entities.Session diff --git a/pkg/plugin/host.py b/pkg/plugin/host.py index 2868875d..0adb0078 100644 --- a/pkg/plugin/host.py +++ b/pkg/plugin/host.py @@ -2,8 +2,8 @@ # 请从 pkg.plugin.context 引入 BasePlugin, EventContext 和 APIHost # 最早将于 v3.4 移除此模块 -from . events import * -from . context import EventContext, APIHost as PluginHost +from .events import * + def emit(*args, **kwargs): - print('插件调用了已弃用的函数 pkg.plugin.host.emit()') \ No newline at end of file + print('插件调用了已弃用的函数 pkg.plugin.host.emit()') diff --git a/pkg/plugin/installer.py b/pkg/plugin/installer.py index b9ffab8b..159967dc 100644 --- a/pkg/plugin/installer.py +++ b/pkg/plugin/installer.py @@ -1,6 +1,5 @@ from __future__ import annotations -import typing import abc from ..core import app, taskmgr @@ -23,8 +22,7 @@ class PluginInstaller(metaclass=abc.ABCMeta): plugin_source: str, task_context: taskmgr.TaskContext = taskmgr.TaskContext.placeholder(), ): - """安装插件 - """ + """安装插件""" raise NotImplementedError @abc.abstractmethod @@ -33,17 +31,15 @@ class PluginInstaller(metaclass=abc.ABCMeta): plugin_name: str, task_context: taskmgr.TaskContext = taskmgr.TaskContext.placeholder(), ): - """卸载插件 - """ + """卸载插件""" raise NotImplementedError @abc.abstractmethod async def update_plugin( self, plugin_name: str, - plugin_source: str=None, + plugin_source: str = None, task_context: taskmgr.TaskContext = taskmgr.TaskContext.placeholder(), ): - """更新插件 - """ + """更新插件""" raise NotImplementedError diff --git a/pkg/plugin/installers/github.py b/pkg/plugin/installers/github.py index ff36cb5b..a867b04d 100644 --- a/pkg/plugin/installers/github.py +++ b/pkg/plugin/installers/github.py @@ -2,7 +2,6 @@ from __future__ import annotations import re import os -import shutil import zipfile import ssl import certifi @@ -18,33 +17,37 @@ from ...core import taskmgr class GitHubRepoInstaller(installer.PluginInstaller): - """GitHub仓库插件安装器 - """ + """GitHub仓库插件安装器""" def get_github_plugin_repo_label(self, repo_url: str) -> list[str]: """获取username, repo""" repo = re.findall( - r"(?:https?://github\.com/|git@github\.com:)([^/]+/[^/]+?)(?:\.git|/|$)", + r'(?:https?://github\.com/|git@github\.com:)([^/]+/[^/]+?)(?:\.git|/|$)', repo_url, ) if len(repo) > 0: - return repo[0].split("/") + return repo[0].split('/') else: return None - async def download_plugin_source_code(self, repo_url: str, target_path: str, task_context: taskmgr.TaskContext = taskmgr.TaskContext.placeholder()) -> str: + async def download_plugin_source_code( + self, + repo_url: str, + target_path: str, + task_context: taskmgr.TaskContext = taskmgr.TaskContext.placeholder(), + ) -> str: """下载插件源码(全异步)""" repo = self.get_github_plugin_repo_label(repo_url) if repo is None: raise errors.PluginInstallerError('仅支持GitHub仓库地址') - + target_path += repo[1] - self.ap.logger.debug("正在下载源码...") - task_context.trace("下载源码...", "download-plugin-source-code") - - zipball_url = f"https://api.github.com/repos/{'/'.join(repo)}/zipball/HEAD" + self.ap.logger.debug('正在下载源码...') + task_context.trace('下载源码...', 'download-plugin-source-code') + + zipball_url = f'https://api.github.com/repos/{"/".join(repo)}/zipball/HEAD' zip_resp: bytes = None - + # 创建自定义SSL上下文,使用certifi提供的根证书 ssl_context = ssl.create_default_context(cafile=certifi.where()) @@ -52,41 +55,44 @@ class GitHubRepoInstaller(installer.PluginInstaller): async with session.get( url=zipball_url, timeout=aiohttp.ClientTimeout(total=300), - ssl=ssl_context # 使用自定义SSL上下文来验证证书 + ssl=ssl_context, # 使用自定义SSL上下文来验证证书 ) as resp: if resp.status != 200: - raise errors.PluginInstallerError(f"下载源码失败: {await resp.text()}") + raise errors.PluginInstallerError( + f'下载源码失败: {await resp.text()}' + ) zip_resp = await resp.read() - - if await aiofiles_os.path.exists("temp/" + target_path): - await aioshutil.rmtree("temp/" + target_path) + + if await aiofiles_os.path.exists('temp/' + target_path): + await aioshutil.rmtree('temp/' + target_path) if await aiofiles_os.path.exists(target_path): await aioshutil.rmtree(target_path) - await aiofiles_os.makedirs("temp/" + target_path) + await aiofiles_os.makedirs('temp/' + target_path) - async with aiofiles.open("temp/" + target_path + "/source.zip", "wb") as f: + async with aiofiles.open('temp/' + target_path + '/source.zip', 'wb') as f: await f.write(zip_resp) - self.ap.logger.debug("解压中...") - task_context.trace("解压中...", "unzip-plugin-source-code") - - with zipfile.ZipFile("temp/" + target_path + "/source.zip", "r") as zip_ref: - zip_ref.extractall("temp/" + target_path) - await aiofiles_os.remove("temp/" + target_path + "/source.zip") + self.ap.logger.debug('解压中...') + task_context.trace('解压中...', 'unzip-plugin-source-code') + + with zipfile.ZipFile('temp/' + target_path + '/source.zip', 'r') as zip_ref: + zip_ref.extractall('temp/' + target_path) + await aiofiles_os.remove('temp/' + target_path + '/source.zip') import glob - unzip_dir = glob.glob("temp/" + target_path + "/*")[0] - await aioshutil.copytree(unzip_dir, target_path + "/") + + unzip_dir = glob.glob('temp/' + target_path + '/*')[0] + await aioshutil.copytree(unzip_dir, target_path + '/') await aioshutil.rmtree(unzip_dir) - - self.ap.logger.debug("源码下载完成。") + + self.ap.logger.debug('源码下载完成。') return repo[1] async def install_requirements(self, path: str): - if os.path.exists(path + "/requirements.txt"): - pkgmgr.install_requirements(path + "/requirements.txt") + if os.path.exists(path + '/requirements.txt'): + pkgmgr.install_requirements(path + '/requirements.txt') async def install_plugin( self, @@ -94,12 +100,14 @@ class GitHubRepoInstaller(installer.PluginInstaller): task_context: taskmgr.TaskContext = taskmgr.TaskContext.placeholder(), ): """安装插件""" - task_context.trace("下载插件源码...", "install-plugin") - repo_label = await self.download_plugin_source_code(plugin_source, "plugins/", task_context) - task_context.trace("安装插件依赖...", "install-plugin") - await self.install_requirements("plugins/" + repo_label) - task_context.trace("完成.", "install-plugin") - + task_context.trace('下载插件源码...', 'install-plugin') + repo_label = await self.download_plugin_source_code( + plugin_source, 'plugins/', task_context + ) + task_context.trace('安装插件依赖...', 'install-plugin') + await self.install_requirements('plugins/' + repo_label) + task_context.trace('完成.', 'install-plugin') + # Caution: in the v4.0, plugin without manifest will not be able to be updated # await self.ap.plugin_mgr.setting.record_installed_plugin_source( # "plugins/" + repo_label + '/', plugin_source @@ -115,9 +123,9 @@ class GitHubRepoInstaller(installer.PluginInstaller): if plugin_container is None: raise errors.PluginInstallerError('插件不存在或未成功加载') else: - task_context.trace("删除插件目录...", "uninstall-plugin") + task_context.trace('删除插件目录...', 'uninstall-plugin') await aioshutil.rmtree(plugin_container.pkg_path) - task_context.trace("完成, 重新加载以生效.", "uninstall-plugin") + task_context.trace('完成, 重新加载以生效.', 'uninstall-plugin') async def update_plugin( self, @@ -126,14 +134,14 @@ class GitHubRepoInstaller(installer.PluginInstaller): task_context: taskmgr.TaskContext = taskmgr.TaskContext.placeholder(), ): """更新插件""" - task_context.trace("更新插件...", "update-plugin") + task_context.trace('更新插件...', 'update-plugin') plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name) if plugin_container is None: raise errors.PluginInstallerError('插件不存在或未成功加载') else: if plugin_container.plugin_repository: plugin_source = plugin_container.plugin_repository - task_context.trace("转交安装任务.", "update-plugin") + task_context.trace('转交安装任务.', 'update-plugin') await self.install_plugin(plugin_source, task_context) else: - raise errors.PluginInstallerError('插件无源码信息,无法更新') \ No newline at end of file + raise errors.PluginInstallerError('插件无源码信息,无法更新') diff --git a/pkg/plugin/loader.py b/pkg/plugin/loader.py index 44ded4ac..191d8bc1 100644 --- a/pkg/plugin/loader.py +++ b/pkg/plugin/loader.py @@ -1,11 +1,9 @@ from __future__ import annotations -from abc import ABCMeta -import typing import abc from ..core import app -from . import context, events +from . import context class PluginLoader(metaclass=abc.ABCMeta): @@ -25,4 +23,3 @@ class PluginLoader(metaclass=abc.ABCMeta): @abc.abstractmethod async def load_plugins(self): pass - diff --git a/pkg/plugin/loaders/classic.py b/pkg/plugin/loaders/classic.py index ae503ba3..857a7b9c 100644 --- a/pkg/plugin/loaders/classic.py +++ b/pkg/plugin/loaders/classic.py @@ -3,7 +3,6 @@ from __future__ import annotations import typing import pkgutil import importlib -import os import traceback from .. import loader, events, context, models @@ -11,7 +10,6 @@ from ...core import entities as core_entities from ...provider.tools import entities as tools_entities from ...utils import funcschema from ...discover import engine as discover_engine -from ...utils import pkgmgr class PluginLoader(loader.PluginLoader): @@ -36,17 +34,17 @@ class PluginLoader(loader.PluginLoader): """初始化""" def register( - self, - name: str, - description: str, - version: str, - author: str - ) -> typing.Callable[[typing.Type[context.BasePlugin]], typing.Type[context.BasePlugin]]: + self, name: str, description: str, version: str, author: str + ) -> typing.Callable[ + [typing.Type[context.BasePlugin]], typing.Type[context.BasePlugin] + ]: self.ap.logger.debug(f'注册插件 {name} {version} by {author}') container = context.RuntimeContainer( plugin_name=name, plugin_label=discover_engine.I18nString(en_US=name, zh_CN=name), - plugin_description=discover_engine.I18nString(en_US=description, zh_CN=description), + plugin_description=discover_engine.I18nString( + en_US=description, zh_CN=description + ), plugin_version=version, plugin_author=author, plugin_repository='', @@ -61,20 +59,21 @@ class PluginLoader(loader.PluginLoader): def wrapper(cls: context.BasePlugin) -> typing.Type[context.BasePlugin]: container.plugin_class = cls return cls - + return wrapper # 过时 # 最早将于 v3.4 版本移除 def on( - self, - event: typing.Type[events.BaseEventModel] + self, event: typing.Type[events.BaseEventModel] ) -> typing.Callable[[typing.Callable], typing.Callable]: """注册过时的事件处理器""" self.ap.logger.debug(f'注册事件处理器 {event.__name__}') + def wrapper(func: typing.Callable) -> typing.Callable: - - async def handler(plugin: context.BasePlugin, ctx: context.EventContext) -> None: + async def handler( + plugin: context.BasePlugin, ctx: context.EventContext + ) -> None: args = { 'host': ctx.host, 'event': ctx, @@ -82,12 +81,12 @@ class PluginLoader(loader.PluginLoader): # 把 ctx.event 所有的属性都放到 args 里 # for k, v in ctx.event.dict().items(): - # args[k] = v + # args[k] = v for attr_name in ctx.event.__dict__.keys(): args[attr_name] = getattr(ctx.event, attr_name) func(plugin, **args) - + self._current_container.event_handlers[event] = handler return func @@ -98,20 +97,21 @@ class PluginLoader(loader.PluginLoader): # 最早将于 v3.4 版本移除 def func( self, - name: str=None, + name: str = None, ) -> typing.Callable: """注册过时的内容函数""" self.ap.logger.debug(f'注册内容函数 {name}') + def wrapper(func: typing.Callable) -> typing.Callable: - function_schema = funcschema.get_func_schema(func) - function_name = self._current_container.plugin_name + '-' + (func.__name__ if name is None else name) + function_name = ( + self._current_container.plugin_name + + '-' + + (func.__name__ if name is None else name) + ) async def handler( - plugin: context.BasePlugin, - query: core_entities.Query, - *args, - **kwargs + plugin: context.BasePlugin, query: core_entities.Query, *args, **kwargs ): return func(*args, **kwargs) @@ -126,18 +126,19 @@ class PluginLoader(loader.PluginLoader): self._current_container.tools.append(llm_function) return func - + return wrapper - + def handler( - self, - event: typing.Type[events.BaseEventModel] + self, event: typing.Type[events.BaseEventModel] ) -> typing.Callable[[typing.Callable], typing.Callable]: """注册事件处理器""" self.ap.logger.debug(f'注册事件处理器 {event.__name__}') - def wrapper(func: typing.Callable) -> typing.Callable: - if self._current_container is None: # None indicates this plugin is registered through manifest, so ignore it here + def wrapper(func: typing.Callable) -> typing.Callable: + if ( + self._current_container is None + ): # None indicates this plugin is registered through manifest, so ignore it here return func self._current_container.event_handlers[event] = func @@ -148,17 +149,23 @@ class PluginLoader(loader.PluginLoader): def llm_func( self, - name: str=None, + name: str = None, ) -> typing.Callable: """注册内容函数""" self.ap.logger.debug(f'注册内容函数 {name}') + def wrapper(func: typing.Callable) -> typing.Callable: - - if self._current_container is None: # None indicates this plugin is registered through manifest, so ignore it here + if ( + self._current_container is None + ): # None indicates this plugin is registered through manifest, so ignore it here return func - + function_schema = funcschema.get_func_schema(func) - function_name = self._current_container.plugin_name + '-' + (func.__name__ if name is None else name) + function_name = ( + self._current_container.plugin_name + + '-' + + (func.__name__ if name is None else name) + ) llm_function = tools_entities.LLMFunction( name=function_name, @@ -171,43 +178,40 @@ class PluginLoader(loader.PluginLoader): self._current_container.tools.append(llm_function) return func - + return wrapper - - async def _walk_plugin_path( - self, - module, - prefix='', - path_prefix='' - ): - """遍历插件路径 - """ + + async def _walk_plugin_path(self, module, prefix='', path_prefix=''): + """遍历插件路径""" for item in pkgutil.iter_modules(module.__path__): if item.ispkg: await self._walk_plugin_path( - __import__(module.__name__ + "." + item.name, fromlist=[""]), - prefix + item.name + ".", - path_prefix + item.name + "/", + __import__(module.__name__ + '.' + item.name, fromlist=['']), + prefix + item.name + '.', + path_prefix + item.name + '/', ) else: try: - self._current_pkg_path = "plugins/" + path_prefix - self._current_module_path = "plugins/" + path_prefix + item.name + ".py" + self._current_pkg_path = 'plugins/' + path_prefix + self._current_module_path = ( + 'plugins/' + path_prefix + item.name + '.py' + ) self._current_container = None - importlib.import_module(module.__name__ + "." + item.name) + importlib.import_module(module.__name__ + '.' + item.name) if self._current_container is not None: self.plugins.append(self._current_container) self.ap.logger.debug(f'插件 {self._current_container} 已加载') - except: - self.ap.logger.error(f'加载插件模块 {prefix + item.name} 时发生错误') + except Exception: + self.ap.logger.error( + f'加载插件模块 {prefix + item.name} 时发生错误' + ) traceback.print_exc() async def load_plugins(self): - """加载插件 - """ + """加载插件""" setattr(models, 'register', self.register) setattr(models, 'on', self.on) setattr(models, 'func', self.func) @@ -215,4 +219,4 @@ class PluginLoader(loader.PluginLoader): setattr(context, 'register', self.register) setattr(context, 'handler', self.handler) setattr(context, 'llm_func', self.llm_func) - await self._walk_plugin_path(__import__("plugins", fromlist=[""])) + await self._walk_plugin_path(__import__('plugins', fromlist=[''])) diff --git a/pkg/plugin/loaders/manifest.py b/pkg/plugin/loaders/manifest.py index 101fdb3a..b634c5b5 100644 --- a/pkg/plugin/loaders/manifest.py +++ b/pkg/plugin/loaders/manifest.py @@ -1,12 +1,11 @@ from __future__ import annotations import typing -import abc import os import traceback from ...core import app -from .. import context, events, models +from .. import context, events from .. import loader from ...utils import funcschema from ...provider.tools import entities as tools_entities @@ -21,13 +20,12 @@ class PluginManifestLoader(loader.PluginLoader): super().__init__(ap) def handler( - self, - event: typing.Type[events.BaseEventModel] + self, event: typing.Type[events.BaseEventModel] ) -> typing.Callable[[typing.Callable], typing.Callable]: """注册事件处理器""" self.ap.logger.debug(f'注册事件处理器 {event.__name__}') + def wrapper(func: typing.Callable) -> typing.Callable: - self._current_container.event_handlers[event] = func return func @@ -36,14 +34,18 @@ class PluginManifestLoader(loader.PluginLoader): def llm_func( self, - name: str=None, + name: str = None, ) -> typing.Callable: """注册内容函数""" self.ap.logger.debug(f'注册内容函数 {name}') + def wrapper(func: typing.Callable) -> typing.Callable: - function_schema = funcschema.get_func_schema(func) - function_name = self._current_container.plugin_name + '-' + (func.__name__ if name is None else name) + function_name = ( + self._current_container.plugin_name + + '-' + + (func.__name__ if name is None else name) + ) llm_function = tools_entities.LLMFunction( name=function_name, @@ -56,7 +58,7 @@ class PluginManifestLoader(loader.PluginLoader): self._current_container.tools.append(llm_function) return func - + return wrapper async def load_plugins(self): @@ -68,7 +70,11 @@ class PluginManifestLoader(loader.PluginLoader): for plugin_manifest in plugin_manifests: try: - config_schema = plugin_manifest.spec['config'] if 'config' in plugin_manifest.spec else [] + config_schema = ( + plugin_manifest.spec['config'] + if 'config' in plugin_manifest.spec + else [] + ) current_plugin_container = context.RuntimeContainer( plugin_name=plugin_manifest.metadata.name, @@ -77,7 +83,9 @@ class PluginManifestLoader(loader.PluginLoader): plugin_version=plugin_manifest.metadata.version, plugin_author=plugin_manifest.metadata.author, plugin_repository=plugin_manifest.metadata.repository, - main_file=os.path.join(plugin_manifest.rel_dir, plugin_manifest.execution.python.path), + main_file=os.path.join( + plugin_manifest.rel_dir, plugin_manifest.execution.python.path + ), pkg_path=plugin_manifest.rel_dir, config_schema=config_schema, event_handlers={}, @@ -95,6 +103,8 @@ class PluginManifestLoader(loader.PluginLoader): # TODO load component extensions self.plugins.append(current_plugin_container) - except Exception as e: - self.ap.logger.error(f'加载插件 {plugin_manifest.metadata.name} 时发生错误') + except Exception: + self.ap.logger.error( + f'加载插件 {plugin_manifest.metadata.name} 时发生错误' + ) traceback.print_exc() diff --git a/pkg/plugin/manager.py b/pkg/plugin/manager.py index f5b5c28f..2322eb5c 100644 --- a/pkg/plugin/manager.py +++ b/pkg/plugin/manager.py @@ -1,10 +1,8 @@ from __future__ import annotations -import typing import traceback import sqlalchemy -import logging from ..core import app, taskmgr from . import context, loader, events, installer, models @@ -28,28 +26,26 @@ class PluginManager: def plugins( self, - enabled: bool=None, - status: context.RuntimeContainerStatus=None, + enabled: bool = None, + status: context.RuntimeContainerStatus = None, ) -> list[context.RuntimeContainer]: - """获取插件列表 - """ + """获取插件列表""" plugins = self.plugin_containers if enabled is not None: plugins = [plugin for plugin in plugins if plugin.enabled == enabled] - + if status is not None: plugins = [plugin for plugin in plugins if plugin.status == status] return plugins - + def get_plugin( self, author: str, plugin_name: str, ) -> context.RuntimeContainer: - """通过作者和插件名获取插件 - """ + """通过作者和插件名获取插件""" for plugin in self.plugins(): if plugin.plugin_author == author and plugin.plugin_name == plugin_name: return plugin @@ -88,20 +84,24 @@ class PluginManager: self.ap.logger.debug(f'优先级排序后的插件列表 {self.plugin_containers}') async def load_plugin_settings( - self, - plugin_containers: list[context.RuntimeContainer] + self, plugin_containers: list[context.RuntimeContainer] ): for plugin_container in plugin_containers: result = await self.ap.persistence_mgr.execute_async( - sqlalchemy.select(persistence_plugin.PluginSetting) \ - .where(persistence_plugin.PluginSetting.plugin_author == plugin_container.plugin_author) - .where(persistence_plugin.PluginSetting.plugin_name == plugin_container.plugin_name) + sqlalchemy.select(persistence_plugin.PluginSetting) + .where( + persistence_plugin.PluginSetting.plugin_author + == plugin_container.plugin_author + ) + .where( + persistence_plugin.PluginSetting.plugin_name + == plugin_container.plugin_name + ) ) setting = result.first() if setting is None: - new_setting_data = { 'plugin_author': plugin_container.plugin_author, 'plugin_name': plugin_container.plugin_name, @@ -111,7 +111,9 @@ class PluginManager: } await self.ap.persistence_mgr.execute_async( - sqlalchemy.insert(persistence_plugin.PluginSetting).values(**new_setting_data) + sqlalchemy.insert(persistence_plugin.PluginSetting).values( + **new_setting_data + ) ) continue else: @@ -120,19 +122,23 @@ class PluginManager: plugin_container.plugin_config = setting.config async def dump_plugin_container_setting( - self, - plugin_container: context.RuntimeContainer + self, plugin_container: context.RuntimeContainer ): - """保存单个插件容器的设置到数据库 - """ + """保存单个插件容器的设置到数据库""" await self.ap.persistence_mgr.execute_async( sqlalchemy.update(persistence_plugin.PluginSetting) - .where(persistence_plugin.PluginSetting.plugin_author == plugin_container.plugin_author) - .where(persistence_plugin.PluginSetting.plugin_name == plugin_container.plugin_name) + .where( + persistence_plugin.PluginSetting.plugin_author + == plugin_container.plugin_author + ) + .where( + persistence_plugin.PluginSetting.plugin_name + == plugin_container.plugin_name + ) .values( enabled=plugin_container.enabled, priority=plugin_container.priority, - config=plugin_container.plugin_config + config=plugin_container.plugin_config, ) ) @@ -160,13 +166,13 @@ class PluginManager: async def destroy_plugin(self, plugin: context.RuntimeContainer): if plugin.status != context.RuntimeContainerStatus.INITIALIZED: return - + self.ap.logger.debug(f'释放插件 {plugin.plugin_name}') plugin.plugin_inst.__del__() await plugin.plugin_inst.destroy() plugin.plugin_inst = None plugin.status = context.RuntimeContainerStatus.MOUNTED - + async def destroy_plugins(self): for plugin in self.plugins(): if plugin.status != context.RuntimeContainerStatus.INITIALIZED: @@ -185,16 +191,15 @@ class PluginManager: plugin_source: str, task_context: taskmgr.TaskContext = taskmgr.TaskContext.placeholder(), ): - """安装插件 - """ + """安装插件""" await self.installer.install_plugin(plugin_source, task_context) await self.ap.ctr_mgr.plugin.post_install_record( { - "name": "unknown", - "remote": plugin_source, - "author": "unknown", - "version": "HEAD" + 'name': 'unknown', + 'remote': plugin_source, + 'author': 'unknown', + 'version': 'HEAD', } ) @@ -206,8 +211,7 @@ class PluginManager: plugin_name: str, task_context: taskmgr.TaskContext = taskmgr.TaskContext.placeholder(), ): - """卸载插件 - """ + """卸载插件""" plugin_container = self.get_plugin_by_name(plugin_name) @@ -219,10 +223,10 @@ class PluginManager: await self.ap.ctr_mgr.plugin.post_remove_record( { - "name": plugin_name, - "remote": plugin_container.plugin_repository, - "author": plugin_container.plugin_author, - "version": plugin_container.plugin_version + 'name': plugin_name, + 'remote': plugin_container.plugin_repository, + 'author': plugin_container.plugin_author, + 'version': plugin_container.plugin_version, } ) @@ -232,80 +236,82 @@ class PluginManager: async def update_plugin( self, plugin_name: str, - plugin_source: str=None, + plugin_source: str = None, task_context: taskmgr.TaskContext = taskmgr.TaskContext.placeholder(), ): - """更新插件 - """ + """更新插件""" await self.installer.update_plugin(plugin_name, plugin_source, task_context) - + plugin_container = self.get_plugin_by_name(plugin_name) await self.ap.ctr_mgr.plugin.post_update_record( plugin={ - "name": plugin_name, - "remote": plugin_container.plugin_repository, - "author": plugin_container.plugin_author, - "version": plugin_container.plugin_version + 'name': plugin_name, + 'remote': plugin_container.plugin_repository, + 'author': plugin_container.plugin_author, + 'version': plugin_container.plugin_version, }, old_version=plugin_container.plugin_version, - new_version="HEAD" + new_version='HEAD', ) task_context.trace('重载插件..', 'reload-plugin') await self.ap.reload(scope='plugin') def get_plugin_by_name(self, plugin_name: str) -> context.RuntimeContainer: - """通过插件名获取插件 - """ + """通过插件名获取插件""" for plugin in self.plugins(): if plugin.plugin_name == plugin_name: return plugin return None async def emit_event(self, event: events.BaseEventModel) -> context.EventContext: - """触发事件 - """ + """触发事件""" + + ctx = context.EventContext(host=self.api_host, event=event) - ctx = context.EventContext( - host=self.api_host, - event=event - ) - emitted_plugins: list[context.RuntimeContainer] = [] for plugin in self.plugins( - enabled=True, - status=context.RuntimeContainerStatus.INITIALIZED + enabled=True, status=context.RuntimeContainerStatus.INITIALIZED ): if event.__class__ in plugin.event_handlers: - self.ap.logger.debug(f'插件 {plugin.plugin_name} 处理事件 {event.__class__.__name__}') - + self.ap.logger.debug( + f'插件 {plugin.plugin_name} 处理事件 {event.__class__.__name__}' + ) + is_prevented_default_before_call = ctx.is_prevented_default() try: await plugin.event_handlers[event.__class__]( - plugin.plugin_inst, - ctx + plugin.plugin_inst, ctx ) except Exception as e: - self.ap.logger.error(f'插件 {plugin.plugin_name} 处理事件 {event.__class__.__name__} 时发生错误: {e}') - self.ap.logger.debug(f"Traceback: {traceback.format_exc()}") - + self.ap.logger.error( + f'插件 {plugin.plugin_name} 处理事件 {event.__class__.__name__} 时发生错误: {e}' + ) + self.ap.logger.debug(f'Traceback: {traceback.format_exc()}') + emitted_plugins.append(plugin) if not is_prevented_default_before_call and ctx.is_prevented_default(): - self.ap.logger.debug(f'插件 {plugin.plugin_name} 阻止了默认行为执行') + self.ap.logger.debug( + f'插件 {plugin.plugin_name} 阻止了默认行为执行' + ) if ctx.is_prevented_postorder(): - self.ap.logger.debug(f'插件 {plugin.plugin_name} 阻止了后序插件的执行') + self.ap.logger.debug( + f'插件 {plugin.plugin_name} 阻止了后序插件的执行' + ) break for key in ctx.__return_value__.keys(): if hasattr(ctx.event, key): setattr(ctx.event, key, ctx.__return_value__[key][0]) - - self.ap.logger.debug(f'事件 {event.__class__.__name__}({ctx.eid}) 处理完成,返回值 {ctx.__return_value__}') + + self.ap.logger.debug( + f'事件 {event.__class__.__name__}({ctx.eid}) 处理完成,返回值 {ctx.__return_value__}' + ) if emitted_plugins: plugins_info: list[dict] = [ @@ -313,13 +319,13 @@ class PluginManager: 'name': plugin.plugin_name, 'remote': plugin.plugin_repository, 'version': plugin.plugin_version, - 'author': plugin.plugin_author - } for plugin in emitted_plugins + 'author': plugin.plugin_author, + } + for plugin in emitted_plugins ] await self.ap.ctr_mgr.usage.post_event_record( - plugins=plugins_info, - event_name=event.__class__.__name__ + plugins=plugins_info, event_name=event.__class__.__name__ ) return ctx @@ -330,7 +336,7 @@ class PluginManager: if plugin.plugin_name == plugin_name: if plugin.enabled == new_status: return False - + # 初始化/释放插件 if new_status: await self.initialize_plugin(plugin) @@ -338,7 +344,7 @@ class PluginManager: await self.destroy_plugin(plugin) plugin.enabled = new_status - + await self.dump_plugin_container_setting(plugin) break @@ -348,7 +354,6 @@ class PluginManager: return False async def reorder_plugins(self, plugins: list[dict]): - for plugin in plugins: plugin_name = plugin.get('name') plugin_priority = plugin.get('priority') @@ -363,7 +368,9 @@ class PluginManager: for plugin in self.plugin_containers: await self.dump_plugin_container_setting(plugin) - async def set_plugin_config(self, plugin_container: context.RuntimeContainer, new_config: dict): + async def set_plugin_config( + self, plugin_container: context.RuntimeContainer, new_config: dict + ): plugin_container.plugin_config = new_config plugin_container.plugin_inst.config = new_config diff --git a/pkg/plugin/models.py b/pkg/plugin/models.py index b8b499f5..dbde89a9 100644 --- a/pkg/plugin/models.py +++ b/pkg/plugin/models.py @@ -9,22 +9,20 @@ import typing from .context import BasePlugin as Plugin from .events import * + def register( - name: str, - description: str, - version: str, - author + name: str, description: str, version: str, author ) -> typing.Callable[[typing.Type[Plugin]], typing.Type[Plugin]]: pass def on( - event: typing.Type[BaseEventModel] + event: typing.Type[BaseEventModel], ) -> typing.Callable[[typing.Callable], typing.Callable]: pass def func( - name: str=None, + name: str = None, ) -> typing.Callable: pass diff --git a/pkg/provider/entities.py b/pkg/provider/entities.py index 0fb75f80..ff95b128 100644 --- a/pkg/provider/entities.py +++ b/pkg/provider/entities.py @@ -1,7 +1,6 @@ from __future__ import annotations import typing -import enum import pydantic.v1 as pydantic from pkg.provider import entities @@ -32,7 +31,6 @@ class ImageURLContentObject(pydantic.BaseModel): class ContentElement(pydantic.BaseModel): - type: str """内容类型""" @@ -57,7 +55,7 @@ class ContentElement(pydantic.BaseModel): @classmethod def from_image_url(cls, image_url: str): return cls(type='image_url', image_url=ImageURLContentObject(url=image_url)) - + @classmethod def from_image_base64(cls, image_base64: str): return cls(type='image_base64', image_base64=image_base64) @@ -82,15 +80,19 @@ class Message(pydantic.BaseModel): def readable_str(self) -> str: if self.content is not None: - return str(self.role) + ": " + str(self.get_content_platform_message_chain()) + return ( + str(self.role) + ': ' + str(self.get_content_platform_message_chain()) + ) elif self.tool_calls is not None: return f'调用工具: {self.tool_calls[0].id}' else: return '未知消息' - def get_content_platform_message_chain(self, prefix_text: str="") -> platform_message.MessageChain | None: + def get_content_platform_message_chain( + self, prefix_text: str = '' + ) -> platform_message.MessageChain | None: """将内容转换为平台消息 MessageChain 对象 - + Args: prefix_text (str): 首个文字组件的前缀文本 """ @@ -98,21 +100,22 @@ class Message(pydantic.BaseModel): if self.content is None: return None elif isinstance(self.content, str): - return platform_message.MessageChain([platform_message.Plain(prefix_text+self.content)]) + return platform_message.MessageChain( + [platform_message.Plain(prefix_text + self.content)] + ) elif isinstance(self.content, list): mc = [] for ce in self.content: if ce.type == 'text': mc.append(platform_message.Plain(ce.text)) elif ce.type == 'image_url': - if ce.image_url.url.startswith("http"): + if ce.image_url.url.startswith('http'): mc.append(platform_message.Image(url=ce.image_url.url)) else: # base64 - b64_str = ce.image_url.url - if b64_str.startswith("data:"): - b64_str = b64_str.split(",")[1] + if b64_str.startswith('data:'): + b64_str = b64_str.split(',')[1] mc.append(platform_message.Image(base64=b64_str)) @@ -120,7 +123,7 @@ class Message(pydantic.BaseModel): if prefix_text: for i, c in enumerate(mc): if isinstance(c, platform_message.Plain): - mc[i] = platform_message.Plain(prefix_text+c.text) + mc[i] = platform_message.Plain(prefix_text + c.text) break else: mc.insert(0, platform_message.Plain(prefix_text)) diff --git a/pkg/provider/modelmgr/errors.py b/pkg/provider/modelmgr/errors.py index d466cf11..dc3b35b6 100644 --- a/pkg/provider/modelmgr/errors.py +++ b/pkg/provider/modelmgr/errors.py @@ -2,4 +2,4 @@ class RequesterError(Exception): """Base class for all Requester errors.""" def __init__(self, message: str): - super().__init__("模型请求失败: "+message) + super().__init__('模型请求失败: ' + message) diff --git a/pkg/provider/modelmgr/modelmgr.py b/pkg/provider/modelmgr/modelmgr.py index 244d5753..25a79fec 100644 --- a/pkg/provider/modelmgr/modelmgr.py +++ b/pkg/provider/modelmgr/modelmgr.py @@ -1,8 +1,6 @@ from __future__ import annotations -import typing import sqlalchemy -import pydantic.v1 as pydantic from . import entities, requester from ...core import app @@ -12,10 +10,8 @@ from ..tools import entities as tools_entities from ...discover import engine from . import token from ...entity.persistence import model as persistence_model -from .requesters import bailianchatcmpl, chatcmpl, anthropicmsgs, moonshotchatcmpl, deepseekchatcmpl, ollamachat, giteeaichatcmpl, volcarkchatcmpl, xaichatcmpl, zhipuaichatcmpl, lmstudiochatcmpl, siliconflowchatcmpl, volcarkchatcmpl - -FETCH_MODEL_LIST_URL = "https://api.qchatgpt.rockchin.top/api/v2/fetch/model_list" +FETCH_MODEL_LIST_URL = 'https://api.qchatgpt.rockchin.top/api/v2/fetch/model_list' class ModelManager: @@ -36,7 +32,7 @@ class ModelManager: requester_components: list[engine.Component] requester_dict: dict[str, type[requester.LLMAPIRequester]] # cache - + def __init__(self, ap: app.Application): self.ap = ap self.model_list = [] @@ -45,14 +41,18 @@ class ModelManager: self.llm_models = [] self.requester_components = [] self.requester_dict = {} - + async def initialize(self): - self.requester_components = self.ap.discover.get_components_by_kind('LLMAPIRequester') + self.requester_components = self.ap.discover.get_components_by_kind( + 'LLMAPIRequester' + ) # forge requester class dict requester_dict: dict[str, type[requester.LLMAPIRequester]] = {} for component in self.requester_components: - requester_dict[component.metadata.name] = component.get_python_component_class() + requester_dict[component.metadata.name] = ( + component.get_python_component_class() + ) self.requester_dict = requester_dict @@ -74,18 +74,22 @@ class ModelManager: # load models for llm_model in llm_models: await self.load_llm_model(llm_model) - - async def load_llm_model(self, model_info: persistence_model.LLMModel | sqlalchemy.Row[persistence_model.LLMModel] | dict): + + async def load_llm_model( + self, + model_info: persistence_model.LLMModel + | sqlalchemy.Row[persistence_model.LLMModel] + | dict, + ): """加载模型""" - + if isinstance(model_info, sqlalchemy.Row): model_info = persistence_model.LLMModel(**model_info._mapping) elif isinstance(model_info, dict): model_info = persistence_model.LLMModel(**model_info) requester_inst = self.requester_dict[model_info.requester]( - ap=self.ap, - config=model_info.requester_config + ap=self.ap, config=model_info.requester_config ) await requester_inst.initialize() @@ -96,24 +100,23 @@ class ModelManager: name=model_info.uuid, tokens=model_info.api_keys, ), - requester=requester_inst + requester=requester_inst, ) self.llm_models.append(runtime_llm_model) async def get_model_by_name(self, name: str) -> entities.LLMModelInfo: # deprecated - """通过名称获取模型 - """ + """通过名称获取模型""" for model in self.model_list: if model.name == name: return model - raise ValueError(f"无法确定模型 {name} 的信息,请在元数据中配置") - + raise ValueError(f'无法确定模型 {name} 的信息,请在元数据中配置') + async def get_model_by_uuid(self, uuid: str) -> entities.LLMModelInfo: """通过uuid获取模型""" for model in self.llm_models: if model.model_entity.uuid == uuid: return model - raise ValueError(f"model {uuid} not found") + raise ValueError(f'model {uuid} not found') async def remove_llm_model(self, model_uuid: str): """移除模型""" @@ -124,10 +127,7 @@ class ModelManager: def get_available_requesters_info(self) -> list[dict]: """获取所有可用的请求器""" - return [ - component.to_plain_dict() - for component in self.requester_components - ] + return [component.to_plain_dict() for component in self.requester_components] def get_available_requester_info_by_name(self, name: str) -> dict | None: """通过名称获取请求器信息""" @@ -135,8 +135,10 @@ class ModelManager: if component.metadata.name == name: return component.to_plain_dict() return None - - def get_available_requester_manifest_by_name(self, name: str) -> engine.Component | None: + + def get_available_requester_manifest_by_name( + self, name: str + ) -> engine.Component | None: """通过名称获取请求器清单""" for component in self.requester_components: if component.metadata.name == name: @@ -151,4 +153,3 @@ class ModelManager: funcs: list[tools_entities.LLMFunction] = None, ) -> llm_entities.Message: pass - diff --git a/pkg/provider/modelmgr/requester.py b/pkg/provider/modelmgr/requester.py index 5ea8d23f..244f4c82 100644 --- a/pkg/provider/modelmgr/requester.py +++ b/pkg/provider/modelmgr/requester.py @@ -22,16 +22,21 @@ class RuntimeLLMModel: requester: LLMAPIRequester """请求器实例""" - - def __init__(self, model_entity: persistence_model.LLMModel, token_mgr: token.TokenManager, requester: LLMAPIRequester): + + def __init__( + self, + model_entity: persistence_model.LLMModel, + token_mgr: token.TokenManager, + requester: LLMAPIRequester, + ): self.model_entity = model_entity self.token_mgr = token_mgr self.requester = requester class LLMAPIRequester(metaclass=abc.ABCMeta): - """LLM API请求器 - """ + """LLM API请求器""" + name: str = None ap: app.Application @@ -42,9 +47,7 @@ class LLMAPIRequester(metaclass=abc.ABCMeta): def __init__(self, ap: app.Application, config: dict[str, typing.Any]): self.ap = ap - self.requester_cfg = { - **self.default_config - } + self.requester_cfg = {**self.default_config} self.requester_cfg.update(config) async def initialize(self): diff --git a/pkg/provider/modelmgr/requesters/anthropicmsgs.py b/pkg/provider/modelmgr/requesters/anthropicmsgs.py index 7edc4405..ca145e9e 100644 --- a/pkg/provider/modelmgr/requesters/anthropicmsgs.py +++ b/pkg/provider/modelmgr/requesters/anthropicmsgs.py @@ -2,16 +2,12 @@ from __future__ import annotations import typing import json -import traceback -import base64 import anthropic import httpx -from ....core import app -from .. import entities, errors, requester +from .. import errors, requester -from .. import entities, errors from ....core import entities as core_entities from ... import entities as llm_entities from ...tools import entities as tools_entities @@ -29,7 +25,6 @@ class AnthropicMessages(requester.LLMAPIRequester): } async def initialize(self): - httpx_client = anthropic._base_client.AsyncHttpxClientWrapper( base_url=self.requester_cfg['base_url'], # cast to a valid type because mypy doesn't understand our type narrowing @@ -40,7 +35,7 @@ class AnthropicMessages(requester.LLMAPIRequester): ) self.client = anthropic.AsyncAnthropic( - api_key="", + api_key='', http_client=httpx_client, ) @@ -55,7 +50,7 @@ class AnthropicMessages(requester.LLMAPIRequester): self.client.api_key = model.token_mgr.get_token() args = extra_args.copy() - args["model"] = model.model_entity.name + args['model'] = model.model_entity.name # 处理消息 @@ -63,14 +58,15 @@ class AnthropicMessages(requester.LLMAPIRequester): system_role_message = None for i, m in enumerate(messages): - if m.role == "system": + if m.role == 'system': system_role_message = m messages.pop(i) break - if isinstance(system_role_message, llm_entities.Message) \ - and isinstance(system_role_message.content, str): + if isinstance(system_role_message, llm_entities.Message) and isinstance( + system_role_message.content, str + ): args['system'] = system_role_message.content req_messages = [] @@ -79,67 +75,64 @@ class AnthropicMessages(requester.LLMAPIRequester): if m.role == 'tool': tool_call_id = m.tool_call_id - req_messages.append({ - "role": "user", - "content": [ - { - "type": "tool_result", - "tool_use_id": tool_call_id, - "content": m.content - } - ] - }) + req_messages.append( + { + 'role': 'user', + 'content': [ + { + 'type': 'tool_result', + 'tool_use_id': tool_call_id, + 'content': m.content, + } + ], + } + ) continue msg_dict = m.dict(exclude_none=True) - if isinstance(m.content, str) and m.content.strip() != "": - msg_dict["content"] = [ - { - "type": "text", - "text": m.content - } - ] + if isinstance(m.content, str) and m.content.strip() != '': + msg_dict['content'] = [{'type': 'text', 'text': m.content}] elif isinstance(m.content, list): - for i, ce in enumerate(m.content): - - if ce.type == "image_base64": - image_b64, image_format = await image.extract_b64_and_format(ce.image_base64) + if ce.type == 'image_base64': + image_b64, image_format = await image.extract_b64_and_format( + ce.image_base64 + ) alter_image_ele = { - "type": "image", - "source": { - "type": "base64", - "media_type": f"image/{image_format}", - "data": image_b64 - } + 'type': 'image', + 'source': { + 'type': 'base64', + 'media_type': f'image/{image_format}', + 'data': image_b64, + }, } - msg_dict["content"][i] = alter_image_ele + msg_dict['content'][i] = alter_image_ele if m.tool_calls: - for tool_call in m.tool_calls: - msg_dict["content"].append({ - "type": "tool_use", - "id": tool_call.id, - "name": tool_call.function.name, - "input": json.loads(tool_call.function.arguments) - }) + msg_dict['content'].append( + { + 'type': 'tool_use', + 'id': tool_call.id, + 'name': tool_call.function.name, + 'input': json.loads(tool_call.function.arguments), + } + ) - del msg_dict["tool_calls"] + del msg_dict['tool_calls'] req_messages.append(msg_dict) - - args["messages"] = req_messages - + args['messages'] = req_messages + if funcs: tools = await self.ap.tool_mgr.generate_tools_for_anthropic(funcs) if tools: - args["tools"] = tools + args['tools'] = tools try: # print(json.dumps(args, indent=4, ensure_ascii=False)) @@ -149,23 +142,24 @@ class AnthropicMessages(requester.LLMAPIRequester): 'content': '', 'role': resp.role, } - + assert type(resp) is anthropic.types.message.Message for block in resp.content: if block.type == 'thinking': - args['content'] = '' + block.thinking + '\n' + args['content'] + args['content'] = ( + '' + block.thinking + '\n' + args['content'] + ) elif block.type == 'text': args['content'] += block.text elif block.type == 'tool_use': assert type(block) is anthropic.types.tool_use_block.ToolUseBlock tool_call = llm_entities.ToolCall( id=block.id, - type="function", + type='function', function=llm_entities.FunctionCall( - name=block.name, - arguments=json.dumps(block.input) - ) + name=block.name, arguments=json.dumps(block.input) + ), ) if 'tool_calls' not in args: args['tool_calls'] = [] diff --git a/pkg/provider/modelmgr/requesters/bailianchatcmpl.py b/pkg/provider/modelmgr/requesters/bailianchatcmpl.py index e20e3376..287eb5b9 100644 --- a/pkg/provider/modelmgr/requesters/bailianchatcmpl.py +++ b/pkg/provider/modelmgr/requesters/bailianchatcmpl.py @@ -4,8 +4,6 @@ import typing import openai from . import chatcmpl -from .. import requester -from ....core import app class BailianChatCompletions(chatcmpl.OpenAIChatCompletions): diff --git a/pkg/provider/modelmgr/requesters/chatcmpl.py b/pkg/provider/modelmgr/requesters/chatcmpl.py index c54d466f..e341d0fb 100644 --- a/pkg/provider/modelmgr/requesters/chatcmpl.py +++ b/pkg/provider/modelmgr/requesters/chatcmpl.py @@ -2,22 +2,15 @@ from __future__ import annotations import asyncio import typing -import json -import base64 -from typing import AsyncGenerator import openai import openai.types.chat.chat_completion as chat_completion -import openai.types.chat.chat_completion_message_tool_call as chat_completion_message_tool_call import httpx -import aiohttp -import async_lru -from .. import entities, errors, requester -from ....core import entities as core_entities, app +from .. import errors, requester +from ....core import entities as core_entities from ... import entities as llm_entities from ...tools import entities as tools_entities -from ....utils import image class OpenAIChatCompletions(requester.LLMAPIRequester): @@ -26,18 +19,17 @@ class OpenAIChatCompletions(requester.LLMAPIRequester): client: openai.AsyncClient default_config: dict[str, typing.Any] = { - "base_url": "https://api.openai.com/v1", - "timeout": 120, + 'base_url': 'https://api.openai.com/v1', + 'timeout': 120, } async def initialize(self): - self.client = openai.AsyncClient( - api_key="", - base_url=self.requester_cfg["base_url"], - timeout=self.requester_cfg["timeout"], + api_key='', + base_url=self.requester_cfg['base_url'], + timeout=self.requester_cfg['timeout'], http_client=httpx.AsyncClient( - trust_env=True, timeout=self.requester_cfg["timeout"] + trust_env=True, timeout=self.requester_cfg['timeout'] ), ) @@ -54,8 +46,8 @@ class OpenAIChatCompletions(requester.LLMAPIRequester): chatcmpl_message = chat_completion.choices[0].message.model_dump() # 确保 role 字段存在且不为 None - if "role" not in chatcmpl_message or chatcmpl_message["role"] is None: - chatcmpl_message["role"] = "assistant" + if 'role' not in chatcmpl_message or chatcmpl_message['role'] is None: + chatcmpl_message['role'] = 'assistant' message = llm_entities.Message(**chatcmpl_message) @@ -72,27 +64,27 @@ class OpenAIChatCompletions(requester.LLMAPIRequester): self.client.api_key = use_model.token_mgr.get_token() args = extra_args.copy() - args["model"] = use_model.model_entity.name + args['model'] = use_model.model_entity.name if use_funcs: tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs) if tools: - args["tools"] = tools + args['tools'] = tools # 设置此次请求中的messages messages = req_messages.copy() # 检查vision for msg in messages: - if "content" in msg and isinstance(msg["content"], list): - for me in msg["content"]: - if me["type"] == "image_base64": - me["image_url"] = {"url": me["image_base64"]} - me["type"] = "image_url" - del me["image_base64"] + if 'content' in msg and isinstance(msg['content'], list): + for me in msg['content']: + if me['type'] == 'image_base64': + me['image_url'] = {'url': me['image_base64']} + me['type'] = 'image_url' + del me['image_base64'] - args["messages"] = messages + args['messages'] = messages # 发送请求 resp = await self._req(args) @@ -113,15 +105,15 @@ class OpenAIChatCompletions(requester.LLMAPIRequester): req_messages = [] # req_messages 仅用于类内,外部同步由 query.messages 进行 for m in messages: msg_dict = m.dict(exclude_none=True) - content = msg_dict.get("content") + content = msg_dict.get('content') if isinstance(content, list): # 检查 content 列表中是否每个部分都是文本 if all( - isinstance(part, dict) and part.get("type") == "text" + isinstance(part, dict) and part.get('type') == 'text' for part in content ): # 将所有文本部分合并为一个字符串 - msg_dict["content"] = "\n".join(part["text"] for part in content) + msg_dict['content'] = '\n'.join(part['text'] for part in content) req_messages.append(msg_dict) try: @@ -133,17 +125,17 @@ class OpenAIChatCompletions(requester.LLMAPIRequester): extra_args=extra_args, ) except asyncio.TimeoutError: - raise errors.RequesterError("请求超时") + raise errors.RequesterError('请求超时') except openai.BadRequestError as e: - if "context_length_exceeded" in e.message: - raise errors.RequesterError(f"上文过长,请重置会话: {e.message}") + if 'context_length_exceeded' in e.message: + raise errors.RequesterError(f'上文过长,请重置会话: {e.message}') else: - raise errors.RequesterError(f"请求参数错误: {e.message}") + raise errors.RequesterError(f'请求参数错误: {e.message}') except openai.AuthenticationError as e: - raise errors.RequesterError(f"无效的 api-key: {e.message}") + raise errors.RequesterError(f'无效的 api-key: {e.message}') except openai.NotFoundError as e: - raise errors.RequesterError(f"请求路径错误: {e.message}") + raise errors.RequesterError(f'请求路径错误: {e.message}') except openai.RateLimitError as e: - raise errors.RequesterError(f"请求过于频繁或余额不足: {e.message}") + raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}') except openai.APIError as e: - raise errors.RequesterError(f"请求错误: {e.message}") + raise errors.RequesterError(f'请求错误: {e.message}') diff --git a/pkg/provider/modelmgr/requesters/deepseekchatcmpl.py b/pkg/provider/modelmgr/requesters/deepseekchatcmpl.py index ee17ac05..49457ac0 100644 --- a/pkg/provider/modelmgr/requesters/deepseekchatcmpl.py +++ b/pkg/provider/modelmgr/requesters/deepseekchatcmpl.py @@ -3,8 +3,8 @@ from __future__ import annotations import typing from . import chatcmpl -from .. import entities, errors, requester -from ....core import entities as core_entities, app +from .. import errors, requester +from ....core import entities as core_entities from ... import entities as llm_entities from ...tools import entities as tools_entities @@ -28,23 +28,23 @@ class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions): self.client.api_key = use_model.token_mgr.get_token() args = extra_args.copy() - args["model"] = use_model.model_entity.name + args['model'] = use_model.model_entity.name if use_funcs: tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs) if tools: - args["tools"] = tools + args['tools'] = tools # 设置此次请求中的messages messages = req_messages # deepseek 不支持多模态,把content都转换成纯文字 for m in messages: - if 'content' in m and isinstance(m["content"], list): - m["content"] = " ".join([c["text"] for c in m["content"]]) + if 'content' in m and isinstance(m['content'], list): + m['content'] = ' '.join([c['text'] for c in m['content']]) - args["messages"] = messages + args['messages'] = messages # 发送请求 resp = await self._req(args) @@ -55,4 +55,4 @@ class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions): # 处理请求结果 message = await self._make_msg(resp) - return message \ No newline at end of file + return message diff --git a/pkg/provider/modelmgr/requesters/giteeaichatcmpl.py b/pkg/provider/modelmgr/requesters/giteeaichatcmpl.py index 35052682..b85cc54d 100644 --- a/pkg/provider/modelmgr/requesters/giteeaichatcmpl.py +++ b/pkg/provider/modelmgr/requesters/giteeaichatcmpl.py @@ -1,17 +1,13 @@ from __future__ import annotations -import json -import asyncio -import aiohttp import typing from . import chatcmpl -from .. import entities, errors, requester -from ....core import app, entities as core_entities +from .. import requester +from ....core import entities as core_entities from ... import entities as llm_entities from ...tools import entities as tools_entities -from .. import entities as modelmgr_entities class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions): @@ -33,20 +29,20 @@ class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions): self.client.api_key = use_model.token_mgr.get_token() args = extra_args.copy() - args["model"] = use_model.model_entity.name + args['model'] = use_model.model_entity.name if use_funcs: tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs) if tools: - args["tools"] = tools + args['tools'] = tools # gitee 不支持多模态,把content都转换成纯文字 for m in req_messages: - if 'content' in m and isinstance(m["content"], list): - m["content"] = " ".join([c["text"] for c in m["content"]]) + if 'content' in m and isinstance(m['content'], list): + m['content'] = ' '.join([c['text'] for c in m['content']]) - args["messages"] = req_messages + args['messages'] = req_messages resp = await self._req(args) diff --git a/pkg/provider/modelmgr/requesters/lmstudiochatcmpl.py b/pkg/provider/modelmgr/requesters/lmstudiochatcmpl.py index 6be76051..c9060c1b 100644 --- a/pkg/provider/modelmgr/requesters/lmstudiochatcmpl.py +++ b/pkg/provider/modelmgr/requesters/lmstudiochatcmpl.py @@ -4,8 +4,6 @@ import typing import openai from . import chatcmpl -from .. import requester -from ....core import app class LmStudioChatCompletions(chatcmpl.OpenAIChatCompletions): diff --git a/pkg/provider/modelmgr/requesters/moonshotchatcmpl.py b/pkg/provider/modelmgr/requesters/moonshotchatcmpl.py index 1ef7d9c9..ac565cdb 100644 --- a/pkg/provider/modelmgr/requesters/moonshotchatcmpl.py +++ b/pkg/provider/modelmgr/requesters/moonshotchatcmpl.py @@ -2,11 +2,10 @@ from __future__ import annotations import typing -from ....core import app from . import chatcmpl -from .. import entities, errors, requester -from ....core import entities as core_entities, app +from .. import requester +from ....core import entities as core_entities from ... import entities as llm_entities from ...tools import entities as tools_entities @@ -30,26 +29,26 @@ class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions): self.client.api_key = use_model.token_mgr.get_token() args = extra_args.copy() - args["model"] = use_model.model_entity.name + args['model'] = use_model.model_entity.name if use_funcs: tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs) if tools: - args["tools"] = tools + args['tools'] = tools # 设置此次请求中的messages messages = req_messages # deepseek 不支持多模态,把content都转换成纯文字 for m in messages: - if 'content' in m and isinstance(m["content"], list): - m["content"] = " ".join([c["text"] for c in m["content"]]) + if 'content' in m and isinstance(m['content'], list): + m['content'] = ' '.join([c['text'] for c in m['content']]) # 删除空的 - messages = [m for m in messages if m["content"].strip() != ""] + messages = [m for m in messages if m['content'].strip() != ''] - args["messages"] = messages + args['messages'] = messages # 发送请求 resp = await self._req(args) @@ -57,4 +56,4 @@ class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions): # 处理请求结果 message = await self._make_msg(resp) - return message \ No newline at end of file + return message diff --git a/pkg/provider/modelmgr/requesters/ollamachat.py b/pkg/provider/modelmgr/requesters/ollamachat.py index ee331036..995dd855 100644 --- a/pkg/provider/modelmgr/requesters/ollamachat.py +++ b/pkg/provider/modelmgr/requesters/ollamachat.py @@ -6,18 +6,15 @@ import typing from typing import Union, Mapping, Any, AsyncIterator import uuid import json -import base64 -import async_lru import ollama -from .. import entities, errors, requester +from .. import errors, requester from ... import entities as llm_entities from ...tools import entities as tools_entities -from ....core import app, entities as core_entities -from ....utils import image +from ....core import entities as core_entities -REQUESTER_NAME: str = "ollama-chat" +REQUESTER_NAME: str = 'ollama-chat' class OllamaChatCompletions(requester.LLMAPIRequester): @@ -26,13 +23,13 @@ class OllamaChatCompletions(requester.LLMAPIRequester): client: ollama.AsyncClient default_config: dict[str, typing.Any] = { - "base_url": "http://127.0.0.1:11434", - "timeout": 120, + 'base_url': 'http://127.0.0.1:11434', + 'timeout': 120, } async def initialize(self): - os.environ["OLLAMA_HOST"] = self.requester_cfg["base_url"] - self.client = ollama.AsyncClient(timeout=self.requester_cfg["timeout"]) + os.environ['OLLAMA_HOST'] = self.requester_cfg['base_url'] + self.client = ollama.AsyncClient(timeout=self.requester_cfg['timeout']) async def _req( self, @@ -49,35 +46,35 @@ class OllamaChatCompletions(requester.LLMAPIRequester): extra_args: dict[str, typing.Any] = {}, ) -> llm_entities.Message: args = extra_args.copy() - args["model"] = use_model.model_entity.name + args['model'] = use_model.model_entity.name messages: list[dict] = req_messages.copy() for msg in messages: - if "content" in msg and isinstance(msg["content"], list): + if 'content' in msg and isinstance(msg['content'], list): text_content: list = [] image_urls: list = [] - for me in msg["content"]: - if me["type"] == "text": - text_content.append(me["text"]) - elif me["type"] == "image_base64": - image_urls.append(me["image_base64"]) + for me in msg['content']: + if me['type'] == 'text': + text_content.append(me['text']) + elif me['type'] == 'image_base64': + image_urls.append(me['image_base64']) - msg["content"] = "\n".join(text_content) - msg["images"] = [url.split(",")[1] for url in image_urls] + msg['content'] = '\n'.join(text_content) + msg['images'] = [url.split(',')[1] for url in image_urls] if ( - "tool_calls" in msg + 'tool_calls' in msg ): # LangBot 内部以 str 存储 tool_calls 的参数,这里需要转换为 dict - for tool_call in msg["tool_calls"]: - tool_call["function"]["arguments"] = json.loads( - tool_call["function"]["arguments"] + for tool_call in msg['tool_calls']: + tool_call['function']['arguments'] = json.loads( + tool_call['function']['arguments'] ) - args["messages"] = messages + args['messages'] = messages - args["tools"] = [] + args['tools'] = [] if user_funcs: tools = await self.ap.tool_mgr.generate_tools_for_openai(user_funcs) if tools: - args["tools"] = tools + args['tools'] = tools resp = await self._req(args) message: llm_entities.Message = await self._make_msg(resp) @@ -93,7 +90,7 @@ class OllamaChatCompletions(requester.LLMAPIRequester): ret_msg: llm_entities.Message = None if message.content is not None: - ret_msg = llm_entities.Message(role="assistant", content=message.content) + ret_msg = llm_entities.Message(role='assistant', content=message.content) if message.tool_calls is not None and len(message.tool_calls) > 0: tool_calls: list[llm_entities.ToolCall] = [] @@ -101,7 +98,7 @@ class OllamaChatCompletions(requester.LLMAPIRequester): tool_calls.append( llm_entities.ToolCall( id=uuid.uuid4().hex, - type="function", + type='function', function=llm_entities.FunctionCall( name=tool_call.function.name, arguments=json.dumps(tool_call.function.arguments), @@ -123,13 +120,13 @@ class OllamaChatCompletions(requester.LLMAPIRequester): req_messages: list = [] for m in messages: msg_dict: dict = m.dict(exclude_none=True) - content: Any = msg_dict.get("content") + content: Any = msg_dict.get('content') if isinstance(content, list): if all( - isinstance(part, dict) and part.get("type") == "text" + isinstance(part, dict) and part.get('type') == 'text' for part in content ): - msg_dict["content"] = "\n".join(part["text"] for part in content) + msg_dict['content'] = '\n'.join(part['text'] for part in content) req_messages.append(msg_dict) try: return await self._closure( @@ -140,4 +137,4 @@ class OllamaChatCompletions(requester.LLMAPIRequester): extra_args=extra_args, ) except asyncio.TimeoutError: - raise errors.RequesterError("请求超时") + raise errors.RequesterError('请求超时') diff --git a/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.py b/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.py index dd5b9a14..3636d9d1 100644 --- a/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.py +++ b/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.py @@ -4,8 +4,6 @@ import typing import openai from . import chatcmpl -from .. import requester -from ....core import app class SiliconFlowChatCompletions(chatcmpl.OpenAIChatCompletions): diff --git a/pkg/provider/modelmgr/requesters/volcarkchatcmpl.py b/pkg/provider/modelmgr/requesters/volcarkchatcmpl.py index 9b5505e1..7eb68956 100644 --- a/pkg/provider/modelmgr/requesters/volcarkchatcmpl.py +++ b/pkg/provider/modelmgr/requesters/volcarkchatcmpl.py @@ -4,8 +4,6 @@ import typing import openai from . import chatcmpl -from .. import requester -from ....core import app class VolcArkChatCompletions(chatcmpl.OpenAIChatCompletions): diff --git a/pkg/provider/modelmgr/requesters/xaichatcmpl.py b/pkg/provider/modelmgr/requesters/xaichatcmpl.py index e08af875..db2022f1 100644 --- a/pkg/provider/modelmgr/requesters/xaichatcmpl.py +++ b/pkg/provider/modelmgr/requesters/xaichatcmpl.py @@ -4,8 +4,6 @@ import typing import openai from . import chatcmpl -from .. import requester -from ....core import app class XaiChatCompletions(chatcmpl.OpenAIChatCompletions): diff --git a/pkg/provider/modelmgr/requesters/zhipuaichatcmpl.py b/pkg/provider/modelmgr/requesters/zhipuaichatcmpl.py index 7bbca164..a1a07068 100644 --- a/pkg/provider/modelmgr/requesters/zhipuaichatcmpl.py +++ b/pkg/provider/modelmgr/requesters/zhipuaichatcmpl.py @@ -3,9 +3,7 @@ from __future__ import annotations import typing import openai -from ....core import app from . import chatcmpl -from .. import requester class ZhipuAIChatCompletions(chatcmpl.OpenAIChatCompletions): diff --git a/pkg/provider/modelmgr/token.py b/pkg/provider/modelmgr/token.py index eeec6986..9f477243 100644 --- a/pkg/provider/modelmgr/token.py +++ b/pkg/provider/modelmgr/token.py @@ -3,9 +3,8 @@ from __future__ import annotations import typing -class TokenManager(): - """鉴权 Token 管理器 - """ +class TokenManager: + """鉴权 Token 管理器""" name: str @@ -20,6 +19,6 @@ class TokenManager(): def get_token(self) -> str: return self.tokens[self.using_token_index] - + def next_token(self): self.using_token_index = (self.using_token_index + 1) % len(self.tokens) diff --git a/pkg/provider/runner.py b/pkg/provider/runner.py index 1762e546..ccfcee73 100644 --- a/pkg/provider/runner.py +++ b/pkg/provider/runner.py @@ -9,9 +9,10 @@ from . import entities as llm_entities preregistered_runners: list[typing.Type[RequestRunner]] = [] + def runner_class(name: str): - """注册一个请求运行器 - """ + """注册一个请求运行器""" + def decorator(cls: typing.Type[RequestRunner]) -> typing.Type[RequestRunner]: cls.name = name preregistered_runners.append(cls) @@ -21,8 +22,8 @@ def runner_class(name: str): class RequestRunner(abc.ABC): - """请求运行器 - """ + """请求运行器""" + name: str = None ap: app.Application @@ -34,7 +35,8 @@ class RequestRunner(abc.ABC): self.pipeline_config = pipeline_config @abc.abstractmethod - async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: - """运行请求 - """ + async def run( + self, query: core_entities.Query + ) -> typing.AsyncGenerator[llm_entities.Message, None]: + """运行请求""" pass diff --git a/pkg/provider/runners/dashscopeapi.py b/pkg/provider/runners/dashscopeapi.py index d5e6a83d..92a1eb18 100644 --- a/pkg/provider/runners/dashscopeapi.py +++ b/pkg/provider/runners/dashscopeapi.py @@ -1,8 +1,6 @@ from __future__ import annotations import typing -import json -import base64 import re import dashscope @@ -10,7 +8,7 @@ import dashscope from .. import runner from ...core import app, entities as core_entities from .. import entities as llm_entities -from ...utils import image + class DashscopeAPIError(Exception): """Dashscope API 请求失败""" @@ -20,49 +18,49 @@ class DashscopeAPIError(Exception): super().__init__(self.message) -@runner.runner_class("dashscope-app-api") +@runner.runner_class('dashscope-app-api') class DashScopeAPIRunner(runner.RequestRunner): "阿里云百炼DashsscopeAPI对话请求器" - + # 运行器内部使用的配置 - app_type: str # 应用类型 - app_id: str # 应用ID - api_key: str # API Key - references_quote: str # 引用资料提示(当展示回答来源功能开启时,这个变量会作为引用资料名前的提示,可在provider.json中配置) + app_type: str # 应用类型 + app_id: str # 应用ID + api_key: str # API Key + references_quote: str # 引用资料提示(当展示回答来源功能开启时,这个变量会作为引用资料名前的提示,可在provider.json中配置) def __init__(self, ap: app.Application, pipeline_config: dict): """初始化""" self.ap = ap self.pipeline_config = pipeline_config - valid_app_types = ["agent", "workflow"] - self.app_type = self.pipeline_config["ai"]["dashscope-app-api"]["app-type"] - #检查配置文件中使用的应用类型是否支持 - if (self.app_type not in valid_app_types): - raise DashscopeAPIError( - f"不支持的 Dashscope 应用类型: {self.app_type}" - ) - - #初始化Dashscope 参数配置 - self.app_id = self.pipeline_config["ai"]["dashscope-app-api"]["app-id"] - self.api_key = self.pipeline_config["ai"]["dashscope-app-api"]["api-key"] - self.references_quote = self.pipeline_config["ai"]["dashscope-app-api"]["references_quote"] - + valid_app_types = ['agent', 'workflow'] + self.app_type = self.pipeline_config['ai']['dashscope-app-api']['app-type'] + # 检查配置文件中使用的应用类型是否支持 + if self.app_type not in valid_app_types: + raise DashscopeAPIError(f'不支持的 Dashscope 应用类型: {self.app_type}') + + # 初始化Dashscope 参数配置 + self.app_id = self.pipeline_config['ai']['dashscope-app-api']['app-id'] + self.api_key = self.pipeline_config['ai']['dashscope-app-api']['api-key'] + self.references_quote = self.pipeline_config['ai']['dashscope-app-api'][ + 'references_quote' + ] + def _replace_references(self, text, references_dict): """阿里云百炼平台的自定义应用支持资料引用,此函数可以将引用标签替换为参考资料""" - + # 匹配 [index_id] 形式的字符串 pattern = re.compile(r'\[(.*?)\]') def replacement(match): # 获取引用编号 - ref_key = match.group(1) + ref_key = match.group(1) if ref_key in references_dict: # 如果有对应的参考资料按照provider.json中的reference_quote返回提示,来自哪个参考资料文件 - return f"({self.references_quote} {references_dict[ref_key]})" + return f'({self.references_quote} {references_dict[ref_key]})' else: # 如果没有对应的参考资料,保留原样 - return match.group(0) + return match.group(0) # 使用 re.sub() 进行替换 return pattern.sub(replacement, text) @@ -71,14 +69,14 @@ class DashScopeAPIRunner(runner.RequestRunner): self, query: core_entities.Query ) -> tuple[str, list[str]]: """预处理用户消息,提取纯文本,阿里云提供的上传文件方法过于复杂,暂不支持上传文件(包括图片)""" - plain_text = "" + plain_text = '' image_ids = [] if isinstance(query.user_message.content, list): for ce in query.user_message.content: - if ce.type == "text": + if ce.type == 'text': plain_text += ce.text # 暂时不支持上传图片,保留代码以便后续扩展 - # elif ce.type == "image_base64": + # elif ce.type == "image_base64": # image_b64, image_format = await image.extract_b64_and_format(ce.image_base64) # file_bytes = base64.b64decode(image_b64) # file = ("img.png", file_bytes, f"image/{image_format}") @@ -92,147 +90,141 @@ class DashScopeAPIRunner(runner.RequestRunner): plain_text = query.user_message.content return plain_text, image_ids - - + async def _agent_messages( self, query: core_entities.Query ) -> typing.AsyncGenerator[llm_entities.Message, None]: """Dashscope 智能体对话请求""" - - #局部变量 - chunk = None # 流式传输的块 - pending_content = "" # 待处理的Agent输出内容 - references_dict = {} # 用于存储引用编号和对应的参考资料 - plain_text = "" # 用户输入的纯文本信息 - image_ids = [] # 用户输入的图片ID列表 (暂不支持) - + + # 局部变量 + chunk = None # 流式传输的块 + pending_content = '' # 待处理的Agent输出内容 + references_dict = {} # 用于存储引用编号和对应的参考资料 + plain_text = '' # 用户输入的纯文本信息 + image_ids = [] # 用户输入的图片ID列表 (暂不支持) + plain_text, image_ids = await self._preprocess_user_message(query) - - #发送对话请求 + + # 发送对话请求 response = dashscope.Application.call( - api_key=self.api_key, # 智能体应用的API Key - app_id=self.app_id, # 智能体应用的ID - prompt=plain_text, # 用户输入的文本信息 - stream=True, # 流式输出 - incremental_output=True, # 增量输出,使用流式输出需要开启增量输出 - session_id=query.session.using_conversation.uuid, # 会话ID用于,多轮对话 + api_key=self.api_key, # 智能体应用的API Key + app_id=self.app_id, # 智能体应用的ID + prompt=plain_text, # 用户输入的文本信息 + stream=True, # 流式输出 + incremental_output=True, # 增量输出,使用流式输出需要开启增量输出 + session_id=query.session.using_conversation.uuid, # 会话ID用于,多轮对话 # rag_options={ # 主要用于文件交互,暂不支持 # "session_file_ids": ["FILE_ID1"], # FILE_ID1 替换为实际的临时文件ID,逗号隔开多个 # } ) for chunk in response: - if chunk.get("status_code") != 200: + if chunk.get('status_code') != 200: raise DashscopeAPIError( - f"Dashscope API 请求失败: status_code={chunk.get('status_code')} message={chunk.get('message')} request_id={chunk.get('request_id')} " + f'Dashscope API 请求失败: status_code={chunk.get("status_code")} message={chunk.get("message")} request_id={chunk.get("request_id")} ' ) if not chunk: continue - - #获取流式传输的output - stream_output = chunk.get("output", {}) - if stream_output.get("text") is not None: - pending_content += stream_output.get("text") - - #保存当前会话的session_id用于下次对话的语境 - query.session.using_conversation.uuid = stream_output.get("session_id") - - #获取模型传出的参考资料列表 - references_dict_list = stream_output.get("doc_references", []) - - #从模型传出的参考资料信息中提取用于替换的字典 + + # 获取流式传输的output + stream_output = chunk.get('output', {}) + if stream_output.get('text') is not None: + pending_content += stream_output.get('text') + + # 保存当前会话的session_id用于下次对话的语境 + query.session.using_conversation.uuid = stream_output.get('session_id') + + # 获取模型传出的参考资料列表 + references_dict_list = stream_output.get('doc_references', []) + + # 从模型传出的参考资料信息中提取用于替换的字典 if references_dict_list is not None: for doc in references_dict_list: - if doc.get("index_id") is not None: - references_dict[doc.get("index_id")] = doc.get("doc_name") - - #将参考资料替换到文本中 + if doc.get('index_id') is not None: + references_dict[doc.get('index_id')] = doc.get('doc_name') + + # 将参考资料替换到文本中 pending_content = self._replace_references(pending_content, references_dict) - + yield llm_entities.Message( - role="assistant", + role='assistant', content=pending_content, ) - - + async def _workflow_messages( self, query: core_entities.Query ) -> typing.AsyncGenerator[llm_entities.Message, None]: """Dashscope 工作流对话请求""" - - #局部变量 - chunk = None # 流式传输的块 - pending_content = "" # 待处理的Agent输出内容 - references_dict = {} # 用于存储引用编号和对应的参考资料 - plain_text = "" # 用户输入的纯文本信息 - image_ids = [] # 用户输入的图片ID列表 (暂不支持) - + + # 局部变量 + chunk = None # 流式传输的块 + pending_content = '' # 待处理的Agent输出内容 + references_dict = {} # 用于存储引用编号和对应的参考资料 + plain_text = '' # 用户输入的纯文本信息 + image_ids = [] # 用户输入的图片ID列表 (暂不支持) + plain_text, image_ids = await self._preprocess_user_message(query) biz_params = {} biz_params.update(query.variables) - - #发送对话请求 + + # 发送对话请求 response = dashscope.Application.call( - api_key=self.api_key, # 智能体应用的API Key - app_id=self.app_id, # 智能体应用的ID - prompt=plain_text, # 用户输入的文本信息 - stream=True, # 流式输出 - incremental_output=True, # 增量输出,使用流式输出需要开启增量输出 - session_id=query.session.using_conversation.uuid, # 会话ID用于,多轮对话 - biz_params=biz_params, # 工作流应用的自定义输入参数传递 + api_key=self.api_key, # 智能体应用的API Key + app_id=self.app_id, # 智能体应用的ID + prompt=plain_text, # 用户输入的文本信息 + stream=True, # 流式输出 + incremental_output=True, # 增量输出,使用流式输出需要开启增量输出 + session_id=query.session.using_conversation.uuid, # 会话ID用于,多轮对话 + biz_params=biz_params, # 工作流应用的自定义输入参数传递 # rag_options={ # 主要用于文件交互,暂不支持 # "session_file_ids": ["FILE_ID1"], # FILE_ID1 替换为实际的临时文件ID,逗号隔开多个 # } ) - - #处理API返回的流式输出 + + # 处理API返回的流式输出 for chunk in response: - if chunk.get("status_code") != 200: + if chunk.get('status_code') != 200: raise DashscopeAPIError( - f"Dashscope API 请求失败: status_code={chunk.get('status_code')} message={chunk.get('message')} request_id={chunk.get('request_id')} " + f'Dashscope API 请求失败: status_code={chunk.get("status_code")} message={chunk.get("message")} request_id={chunk.get("request_id")} ' ) if not chunk: continue - - #获取流式传输的output - stream_output = chunk.get("output", {}) - if stream_output.get("text") is not None: - pending_content += stream_output.get("text") - - #保存当前会话的session_id用于下次对话的语境 - query.session.using_conversation.uuid = stream_output.get("session_id") - - #获取模型传出的参考资料列表 - references_dict_list = stream_output.get("doc_references", []) - - #从模型传出的参考资料信息中提取用于替换的字典 + + # 获取流式传输的output + stream_output = chunk.get('output', {}) + if stream_output.get('text') is not None: + pending_content += stream_output.get('text') + + # 保存当前会话的session_id用于下次对话的语境 + query.session.using_conversation.uuid = stream_output.get('session_id') + + # 获取模型传出的参考资料列表 + references_dict_list = stream_output.get('doc_references', []) + + # 从模型传出的参考资料信息中提取用于替换的字典 if references_dict_list is not None: for doc in references_dict_list: - if doc.get("index_id") is not None: - references_dict[doc.get("index_id")] = doc.get("doc_name") - - #将参考资料替换到文本中 + if doc.get('index_id') is not None: + references_dict[doc.get('index_id')] = doc.get('doc_name') + + # 将参考资料替换到文本中 pending_content = self._replace_references(pending_content, references_dict) - + yield llm_entities.Message( - role="assistant", + role='assistant', content=pending_content, ) - + async def run( self, query: core_entities.Query ) -> typing.AsyncGenerator[llm_entities.Message, None]: """运行""" - if self.app_type == "agent": + if self.app_type == 'agent': async for msg in self._agent_messages(query): yield msg - elif self.app_type == "workflow": + elif self.app_type == 'workflow': async for msg in self._workflow_messages(query): yield msg else: - raise DashscopeAPIError( - f"不支持的 Dashscope 应用类型: {self.app_type}" - ) - - + raise DashscopeAPIError(f'不支持的 Dashscope 应用类型: {self.app_type}') diff --git a/pkg/provider/runners/difysvapi.py b/pkg/provider/runners/difysvapi.py index f48cbd57..1d1576a6 100644 --- a/pkg/provider/runners/difysvapi.py +++ b/pkg/provider/runners/difysvapi.py @@ -5,9 +5,7 @@ import json import uuid import re import base64 -import datetime -import aiohttp from .. import runner from ...core import app, entities as core_entities @@ -17,7 +15,7 @@ from ...utils import image from libs.dify_service_api.v1 import client, errors -@runner.runner_class("dify-service-api") +@runner.runner_class('dify-service-api') class DifyServiceAPIRunner(runner.RequestRunner): """Dify Service API 对话请求器""" @@ -27,38 +25,54 @@ class DifyServiceAPIRunner(runner.RequestRunner): self.ap = ap self.pipeline_config = pipeline_config - valid_app_types = ["chat", "agent", "workflow"] + valid_app_types = ['chat', 'agent', 'workflow'] if ( - self.pipeline_config["ai"]["dify-service-api"]["app-type"] + self.pipeline_config['ai']['dify-service-api']['app-type'] not in valid_app_types ): raise errors.DifyAPIError( - f"不支持的 Dify 应用类型: {self.pipeline_config['ai']['dify-service-api']['app-type']}" + f'不支持的 Dify 应用类型: {self.pipeline_config["ai"]["dify-service-api"]["app-type"]}' ) - api_key = self.pipeline_config["ai"]["dify-service-api"]["api-key"] + api_key = self.pipeline_config['ai']['dify-service-api']['api-key'] self.dify_client = client.AsyncDifyServiceClient( api_key=api_key, - base_url=self.pipeline_config["ai"]["dify-service-api"]["base-url"], + base_url=self.pipeline_config['ai']['dify-service-api']['base-url'], ) def _try_convert_thinking(self, resp_text: str) -> str: """尝试转换 Dify 的思考提示""" - if not resp_text.startswith("
Thinking... "): + if not resp_text.startswith( + '
Thinking... ' + ): return resp_text - if self.pipeline_config["ai"]["dify-service-api"]["thinking-convert"] == "original": + if ( + self.pipeline_config['ai']['dify-service-api']['thinking-convert'] + == 'original' + ): return resp_text - - if self.pipeline_config["ai"]["dify-service-api"]["thinking-convert"] == "remove": - return re.sub(r'
Thinking... .*?
', '', resp_text, flags=re.DOTALL) - - if self.pipeline_config["ai"]["dify-service-api"]["thinking-convert"] == "plain": + + if ( + self.pipeline_config['ai']['dify-service-api']['thinking-convert'] + == 'remove' + ): + return re.sub( + r'
Thinking... .*?
', + '', + resp_text, + flags=re.DOTALL, + ) + + if ( + self.pipeline_config['ai']['dify-service-api']['thinking-convert'] + == 'plain' + ): pattern = r'
Thinking... (.*?)
' thinking_text = re.search(pattern, resp_text, flags=re.DOTALL) content_text = re.sub(pattern, '', resp_text, flags=re.DOTALL) - return f"{thinking_text.group(1)}\n{content_text}" + return f'{thinking_text.group(1)}\n{content_text}' async def _preprocess_user_message( self, query: core_entities.Query @@ -68,22 +82,24 @@ class DifyServiceAPIRunner(runner.RequestRunner): Returns: tuple[str, list[str]]: 纯文本和图片的 Dify 服务图片 ID """ - plain_text = "" + plain_text = '' image_ids = [] if isinstance(query.user_message.content, list): for ce in query.user_message.content: - if ce.type == "text": + if ce.type == 'text': plain_text += ce.text - elif ce.type == "image_base64": - image_b64, image_format = await image.extract_b64_and_format(ce.image_base64) + elif ce.type == 'image_base64': + image_b64, image_format = await image.extract_b64_and_format( + ce.image_base64 + ) file_bytes = base64.b64decode(image_b64) - file = ("img.png", file_bytes, f"image/{image_format}") + file = ('img.png', file_bytes, f'image/{image_format}') file_upload_resp = await self.dify_client.upload_file( file, - f"{query.session.launcher_type.value}_{query.session.launcher_id}", + f'{query.session.launcher_type.value}_{query.session.launcher_id}', ) - image_id = file_upload_resp["id"] + image_id = file_upload_resp['id'] image_ids.append(image_id) elif isinstance(query.user_message.content, str): plain_text = query.user_message.content @@ -94,116 +110,119 @@ class DifyServiceAPIRunner(runner.RequestRunner): self, query: core_entities.Query ) -> typing.AsyncGenerator[llm_entities.Message, None]: """调用聊天助手""" - cov_id = query.session.using_conversation.uuid or "" + cov_id = query.session.using_conversation.uuid or '' plain_text, image_ids = await self._preprocess_user_message(query) files = [ { - "type": "image", - "transfer_method": "local_file", - "upload_file_id": image_id, + 'type': 'image', + 'transfer_method': 'local_file', + 'upload_file_id': image_id, } for image_id in image_ids ] - mode = "basic" # 标记是基础编排还是工作流编排 + mode = 'basic' # 标记是基础编排还是工作流编排 basic_mode_pending_chunk = '' inputs = {} - + inputs.update(query.variables) async for chunk in self.dify_client.chat_messages( inputs=inputs, query=plain_text, - user=f"{query.session.launcher_type.value}_{query.session.launcher_id}", + user=f'{query.session.launcher_type.value}_{query.session.launcher_id}', conversation_id=cov_id, files=files, - timeout=self.pipeline_config["ai"]["dify-service-api"]["timeout"], + timeout=self.pipeline_config['ai']['dify-service-api']['timeout'], ): - self.ap.logger.debug("dify-chat-chunk: " + str(chunk)) + self.ap.logger.debug('dify-chat-chunk: ' + str(chunk)) if chunk['event'] == 'workflow_started': - mode = "workflow" + mode = 'workflow' - if mode == "workflow": + if mode == 'workflow': if chunk['event'] == 'node_finished': if chunk['data']['node_type'] == 'answer': yield llm_entities.Message( - role="assistant", - content=self._try_convert_thinking(chunk['data']['outputs']['answer']), + role='assistant', + content=self._try_convert_thinking( + chunk['data']['outputs']['answer'] + ), ) - elif mode == "basic": + elif mode == 'basic': if chunk['event'] == 'message': basic_mode_pending_chunk += chunk['answer'] elif chunk['event'] == 'message_end': yield llm_entities.Message( - role="assistant", + role='assistant', content=self._try_convert_thinking(basic_mode_pending_chunk), ) basic_mode_pending_chunk = '' - query.session.using_conversation.uuid = chunk["conversation_id"] + query.session.using_conversation.uuid = chunk['conversation_id'] async def _agent_chat_messages( self, query: core_entities.Query ) -> typing.AsyncGenerator[llm_entities.Message, None]: """调用聊天助手""" - cov_id = query.session.using_conversation.uuid or "" + cov_id = query.session.using_conversation.uuid or '' plain_text, image_ids = await self._preprocess_user_message(query) files = [ { - "type": "image", - "transfer_method": "local_file", - "upload_file_id": image_id, + 'type': 'image', + 'transfer_method': 'local_file', + 'upload_file_id': image_id, } for image_id in image_ids ] - ignored_events = ["agent_message"] + ignored_events = ['agent_message'] inputs = {} - + inputs.update(query.variables) async for chunk in self.dify_client.chat_messages( inputs=inputs, query=plain_text, - user=f"{query.session.launcher_type.value}_{query.session.launcher_id}", - response_mode="streaming", + user=f'{query.session.launcher_type.value}_{query.session.launcher_id}', + response_mode='streaming', conversation_id=cov_id, files=files, - timeout=self.pipeline_config["ai"]["dify-service-api"]["timeout"], + timeout=self.pipeline_config['ai']['dify-service-api']['timeout'], ): - self.ap.logger.debug("dify-agent-chunk: " + str(chunk)) + self.ap.logger.debug('dify-agent-chunk: ' + str(chunk)) - if chunk["event"] in ignored_events: + if chunk['event'] in ignored_events: continue - if chunk["event"] == "agent_thought": - - if chunk['tool'] != '' and chunk['observation'] != '': # 工具调用结果,跳过 + if chunk['event'] == 'agent_thought': + if ( + chunk['tool'] != '' and chunk['observation'] != '' + ): # 工具调用结果,跳过 continue if chunk['thought'].strip() != '': # 文字回复内容 msg = llm_entities.Message( - role="assistant", - content=chunk["thought"], + role='assistant', + content=chunk['thought'], ) yield msg if chunk['tool']: msg = llm_entities.Message( - role="assistant", + role='assistant', tool_calls=[ llm_entities.ToolCall( id=chunk['id'], - type="function", + type='function', function=llm_entities.FunctionCall( - name=chunk["tool"], + name=chunk['tool'], arguments=json.dumps({}), ), ) @@ -211,9 +230,7 @@ class DifyServiceAPIRunner(runner.RequestRunner): ) yield msg if chunk['event'] == 'message_file': - if chunk['type'] == 'image' and chunk['belongs_to'] == 'assistant': - base_url = self.dify_client.base_url if base_url.endswith('/v1'): @@ -222,11 +239,11 @@ class DifyServiceAPIRunner(runner.RequestRunner): image_url = base_url + chunk['url'] yield llm_entities.Message( - role="assistant", + role='assistant', content=[llm_entities.ContentElement.from_image_url(image_url)], ) - query.session.using_conversation.uuid = chunk["conversation_id"] + query.session.using_conversation.uuid = chunk['conversation_id'] async def _workflow_messages( self, query: core_entities.Query @@ -235,58 +252,57 @@ class DifyServiceAPIRunner(runner.RequestRunner): if not query.session.using_conversation.uuid: query.session.using_conversation.uuid = str(uuid.uuid4()) - - query.variables["conversation_id"] = query.session.using_conversation.uuid + + query.variables['conversation_id'] = query.session.using_conversation.uuid plain_text, image_ids = await self._preprocess_user_message(query) files = [ { - "type": "image", - "transfer_method": "local_file", - "upload_file_id": image_id, + 'type': 'image', + 'transfer_method': 'local_file', + 'upload_file_id': image_id, } for image_id in image_ids ] - ignored_events = ["text_chunk", "workflow_started"] + ignored_events = ['text_chunk', 'workflow_started'] inputs = { # these variables are legacy variables, we need to keep them for compatibility - "langbot_user_message_text": plain_text, - "langbot_session_id": query.variables["session_id"], - "langbot_conversation_id": query.variables["conversation_id"], - "langbot_msg_create_time": query.variables["msg_create_time"], + 'langbot_user_message_text': plain_text, + 'langbot_session_id': query.variables['session_id'], + 'langbot_conversation_id': query.variables['conversation_id'], + 'langbot_msg_create_time': query.variables['msg_create_time'], } - + inputs.update(query.variables) async for chunk in self.dify_client.workflow_run( inputs=inputs, - user=f"{query.session.launcher_type.value}_{query.session.launcher_id}", + user=f'{query.session.launcher_type.value}_{query.session.launcher_id}', files=files, - timeout=self.pipeline_config["ai"]["dify-service-api"]["timeout"], + timeout=self.pipeline_config['ai']['dify-service-api']['timeout'], ): - self.ap.logger.debug("dify-workflow-chunk: " + str(chunk)) - if chunk["event"] in ignored_events: + self.ap.logger.debug('dify-workflow-chunk: ' + str(chunk)) + if chunk['event'] in ignored_events: continue - if chunk["event"] == "node_started": - + if chunk['event'] == 'node_started': if ( - chunk["data"]["node_type"] == "start" - or chunk["data"]["node_type"] == "end" + chunk['data']['node_type'] == 'start' + or chunk['data']['node_type'] == 'end' ): continue msg = llm_entities.Message( - role="assistant", + role='assistant', content=None, tool_calls=[ llm_entities.ToolCall( - id=chunk["data"]["node_id"], - type="function", + id=chunk['data']['node_id'], + type='function', function=llm_entities.FunctionCall( - name=chunk["data"]["title"], + name=chunk['data']['title'], arguments=json.dumps({}), ), ) @@ -295,13 +311,13 @@ class DifyServiceAPIRunner(runner.RequestRunner): yield msg - elif chunk["event"] == "workflow_finished": + elif chunk['event'] == 'workflow_finished': if chunk['data']['error']: raise errors.DifyAPIError(chunk['data']['error']) msg = llm_entities.Message( - role="assistant", - content=chunk["data"]["outputs"]["summary"], + role='assistant', + content=chunk['data']['outputs']['summary'], ) yield msg @@ -310,16 +326,16 @@ class DifyServiceAPIRunner(runner.RequestRunner): self, query: core_entities.Query ) -> typing.AsyncGenerator[llm_entities.Message, None]: """运行请求""" - if self.pipeline_config["ai"]["dify-service-api"]["app-type"] == "chat": + if self.pipeline_config['ai']['dify-service-api']['app-type'] == 'chat': async for msg in self._chat_messages(query): yield msg - elif self.pipeline_config["ai"]["dify-service-api"]["app-type"] == "agent": + elif self.pipeline_config['ai']['dify-service-api']['app-type'] == 'agent': async for msg in self._agent_chat_messages(query): yield msg - elif self.pipeline_config["ai"]["dify-service-api"]["app-type"] == "workflow": + elif self.pipeline_config['ai']['dify-service-api']['app-type'] == 'workflow': async for msg in self._workflow_messages(query): yield msg else: raise errors.DifyAPIError( - f"不支持的 Dify 应用类型: {self.pipeline_config['ai']['dify-service-api']['app-type']}" + f'不支持的 Dify 应用类型: {self.pipeline_config["ai"]["dify-service-api"]["app-type"]}' ) diff --git a/pkg/provider/runners/localagent.py b/pkg/provider/runners/localagent.py index 68bb2b4f..d6b6f6cd 100644 --- a/pkg/provider/runners/localagent.py +++ b/pkg/provider/runners/localagent.py @@ -4,24 +4,28 @@ import json import typing from .. import runner -from ...core import app, entities as core_entities +from ...core import entities as core_entities from .. import entities as llm_entities -@runner.runner_class("local-agent") +@runner.runner_class('local-agent') class LocalAgentRunner(runner.RequestRunner): - """本地Agent请求运行器 - """ + """本地Agent请求运行器""" - async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: - """运行请求 - """ + async def run( + self, query: core_entities.Query + ) -> typing.AsyncGenerator[llm_entities.Message, None]: + """运行请求""" pending_tool_calls = [] - req_messages = query.prompt.messages.copy() + query.messages.copy() + [query.user_message] + req_messages = ( + query.prompt.messages.copy() + query.messages.copy() + [query.user_message] + ) # 首次请求 - msg = await query.use_llm_model.requester.invoke_llm(query, query.use_llm_model, req_messages, query.use_funcs) + msg = await query.use_llm_model.requester.invoke_llm( + query, query.use_llm_model, req_messages, query.use_funcs + ) yield msg @@ -34,7 +38,7 @@ class LocalAgentRunner(runner.RequestRunner): for tool_call in pending_tool_calls: try: func = tool_call.function - + parameters = json.loads(func.arguments) func_ret = await self.ap.tool_mgr.execute_func_call( @@ -42,7 +46,9 @@ class LocalAgentRunner(runner.RequestRunner): ) msg = llm_entities.Message( - role="tool", content=json.dumps(func_ret, ensure_ascii=False), tool_call_id=tool_call.id + role='tool', + content=json.dumps(func_ret, ensure_ascii=False), + tool_call_id=tool_call.id, ) yield msg @@ -51,7 +57,7 @@ class LocalAgentRunner(runner.RequestRunner): except Exception as e: # 工具调用出错,添加一个报错信息到 req_messages err_msg = llm_entities.Message( - role="tool", content=f"err: {e}", tool_call_id=tool_call.id + role='tool', content=f'err: {e}', tool_call_id=tool_call.id ) yield err_msg @@ -59,7 +65,9 @@ class LocalAgentRunner(runner.RequestRunner): req_messages.append(err_msg) # 处理完所有调用,再次请求 - msg = await query.use_llm_model.requester.invoke_llm(query, query.use_llm_model, req_messages, query.use_funcs) + msg = await query.use_llm_model.requester.invoke_llm( + query, query.use_llm_model, req_messages, query.use_funcs + ) yield msg diff --git a/pkg/provider/session/sessionmgr.py b/pkg/provider/session/sessionmgr.py index eef723da..a0a582ad 100644 --- a/pkg/provider/session/sessionmgr.py +++ b/pkg/provider/session/sessionmgr.py @@ -3,13 +3,11 @@ from __future__ import annotations import asyncio from ...core import app, entities as core_entities -from ...plugin import context as plugin_context from ...provider import entities as provider_entities class SessionManager: - """会话管理器 - """ + """会话管理器""" ap: app.Application @@ -23,10 +21,12 @@ class SessionManager: pass async def get_session(self, query: core_entities.Query) -> core_entities.Session: - """获取会话 - """ + """获取会话""" for session in self.session_list: - if query.launcher_type == session.launcher_type and query.launcher_id == session.launcher_id: + if ( + query.launcher_type == session.launcher_type + and query.launcher_id == session.launcher_id + ): return session session_concurrency = self.ap.instance_config.data['concurrency']['session'] @@ -39,7 +39,12 @@ class SessionManager: self.session_list.append(session) return session - async def get_conversation(self, query: core_entities.Query, session: core_entities.Session, prompt_config: list[dict]) -> core_entities.Conversation: + async def get_conversation( + self, + query: core_entities.Query, + session: core_entities.Session, + prompt_config: list[dict], + ) -> core_entities.Conversation: """获取对话或创建对话""" if not session.conversations: @@ -52,7 +57,7 @@ class SessionManager: prompt_messages.append(provider_entities.Message(**prompt_message)) prompt = provider_entities.Prompt( - name="default", + name='default', messages=prompt_messages, ) diff --git a/pkg/provider/tools/entities.py b/pkg/provider/tools/entities.py index 746ffe92..102e03d3 100644 --- a/pkg/provider/tools/entities.py +++ b/pkg/provider/tools/entities.py @@ -1,13 +1,9 @@ from __future__ import annotations -import abc import typing -import asyncio import pydantic.v1 as pydantic -from ...core import entities as core_entities - class LLMFunction(pydantic.BaseModel): """函数""" diff --git a/pkg/provider/tools/loader.py b/pkg/provider/tools/loader.py index cae4a63f..25bb13eb 100644 --- a/pkg/provider/tools/loader.py +++ b/pkg/provider/tools/loader.py @@ -9,9 +9,10 @@ from . import entities as tools_entities preregistered_loaders: list[typing.Type[ToolLoader]] = [] + def loader_class(name: str): - """注册一个工具加载器 - """ + """注册一个工具加载器""" + def decorator(cls: typing.Type[ToolLoader]) -> typing.Type[ToolLoader]: cls.name = name preregistered_loaders.append(cls) @@ -22,7 +23,7 @@ def loader_class(name: str): class ToolLoader(abc.ABC): """工具加载器""" - + name: str = None ap: app.Application @@ -34,7 +35,7 @@ class ToolLoader(abc.ABC): pass @abc.abstractmethod - async def get_tools(self, enabled: bool=True) -> list[tools_entities.LLMFunction]: + async def get_tools(self, enabled: bool = True) -> list[tools_entities.LLMFunction]: """获取所有工具""" pass @@ -44,11 +45,13 @@ class ToolLoader(abc.ABC): pass @abc.abstractmethod - async def invoke_tool(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any: + async def invoke_tool( + self, query: core_entities.Query, name: str, parameters: dict + ) -> typing.Any: """执行工具调用""" pass @abc.abstractmethod async def shutdown(self): """关闭工具""" - pass \ No newline at end of file + pass diff --git a/pkg/provider/tools/loaders/mcp.py b/pkg/provider/tools/loaders/mcp.py index 6bda7f89..5377709f 100644 --- a/pkg/provider/tools/loaders/mcp.py +++ b/pkg/provider/tools/loaders/mcp.py @@ -30,7 +30,7 @@ class RuntimeMCPSession: self.server_name = server_name self.server_config = server_config self.ap = ap - + self.session = None self.exit_stack = AsyncExitStack() @@ -38,9 +38,9 @@ class RuntimeMCPSession: async def _init_stdio_python_server(self): server_params = StdioServerParameters( - command=self.server_config["command"], - args=self.server_config["args"], - env=self.server_config["env"], + command=self.server_config['command'], + args=self.server_config['args'], + env=self.server_config['env'], ) stdio_transport = await self.exit_stack.enter_async_context( @@ -58,12 +58,12 @@ class RuntimeMCPSession: async def _init_sse_server(self): sse_transport = await self.exit_stack.enter_async_context( sse_client( - self.server_config["url"], - headers=self.server_config.get("headers", {}), - timeout=self.server_config.get("timeout", 10), + self.server_config['url'], + headers=self.server_config.get('headers', {}), + timeout=self.server_config.get('timeout', 10), ) ) - + sseio, write = sse_transport self.session = await self.exit_stack.enter_async_context( @@ -73,18 +73,22 @@ class RuntimeMCPSession: await self.session.initialize() async def initialize(self): - self.ap.logger.debug(f"初始化 MCP 会话: {self.server_name} {self.server_config}") + self.ap.logger.debug( + f'初始化 MCP 会话: {self.server_name} {self.server_config}' + ) - if self.server_config["mode"] == "stdio": + if self.server_config['mode'] == 'stdio': await self._init_stdio_python_server() - elif self.server_config["mode"] == "sse": + elif self.server_config['mode'] == 'sse': await self._init_sse_server() else: - raise ValueError(f"无法识别 MCP 服务器类型: {self.server_name}: {self.server_config}") - + raise ValueError( + f'无法识别 MCP 服务器类型: {self.server_name}: {self.server_config}' + ) + tools = await self.session.list_tools() - self.ap.logger.debug(f"获取 MCP 工具: {tools}") + self.ap.logger.debug(f'获取 MCP 工具: {tools}') for tool in tools.tools: @@ -93,25 +97,28 @@ class RuntimeMCPSession: if result.isError: raise Exception(result.content[0].text) return result.content[0].text - + func.__name__ = tool.name - self.functions.append(tools_entities.LLMFunction( - name=tool.name, - human_desc=tool.description, - description=tool.description, - parameters=tool.inputSchema, - func=func, - )) + self.functions.append( + tools_entities.LLMFunction( + name=tool.name, + human_desc=tool.description, + description=tool.description, + parameters=tool.inputSchema, + func=func, + ) + ) async def shutdown(self): """关闭工具""" await self.session._exit_stack.aclose() -@loader.loader_class("mcp") + +@loader.loader_class('mcp') class MCPLoader(loader.ToolLoader): """MCP 工具加载器。 - + 在此加载器中管理所有与 MCP Server 的连接。 """ @@ -125,16 +132,17 @@ class MCPLoader(loader.ToolLoader): self._last_listed_functions = [] async def initialize(self): - - for server_config in self.ap.instance_config.data.get("mcp", {}).get("servers", []): - if not server_config["enable"]: + for server_config in self.ap.instance_config.data.get('mcp', {}).get( + 'servers', [] + ): + if not server_config['enable']: continue - session = RuntimeMCPSession(server_config["name"], server_config, self.ap) + session = RuntimeMCPSession(server_config['name'], server_config, self.ap) await session.initialize() # self.ap.event_loop.create_task(session.initialize()) - self.sessions[server_config["name"]] = session + self.sessions[server_config['name']] = session - async def get_tools(self, enabled: bool=True) -> list[tools_entities.LLMFunction]: + async def get_tools(self, enabled: bool = True) -> list[tools_entities.LLMFunction]: all_functions = [] for session in self.sessions.values(): @@ -147,13 +155,15 @@ class MCPLoader(loader.ToolLoader): async def has_tool(self, name: str) -> bool: return name in [f.name for f in self._last_listed_functions] - async def invoke_tool(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any: + async def invoke_tool( + self, query: core_entities.Query, name: str, parameters: dict + ) -> typing.Any: for server_name, session in self.sessions.items(): for function in session.functions: if function.name == name: return await function.func(query, **parameters) - raise ValueError(f"未找到工具: {name}") + raise ValueError(f'未找到工具: {name}') async def shutdown(self): """关闭工具""" diff --git a/pkg/provider/tools/loaders/plugin.py b/pkg/provider/tools/loaders/plugin.py index 5b964556..c53403af 100644 --- a/pkg/provider/tools/loaders/plugin.py +++ b/pkg/provider/tools/loaders/plugin.py @@ -4,19 +4,18 @@ import typing import traceback from .. import loader, entities as tools_entities -from ....core import app, entities as core_entities +from ....core import entities as core_entities from ....plugin import context as plugin_context -@loader.loader_class("plugin-tool-loader") +@loader.loader_class('plugin-tool-loader') class PluginToolLoader(loader.ToolLoader): """插件工具加载器。 - + 本加载器中不存储工具信息,仅负责从插件系统中获取工具信息。 """ - async def get_tools(self, enabled: bool=True) -> list[tools_entities.LLMFunction]: - + async def get_tools(self, enabled: bool = True) -> list[tools_entities.LLMFunction]: # 从插件系统获取工具(内容函数) all_functions: list[tools_entities.LLMFunction] = [] @@ -49,23 +48,23 @@ class PluginToolLoader(loader.ToolLoader): return function, plugin.plugin_inst return None, None - async def invoke_tool(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any: - + async def invoke_tool( + self, query: core_entities.Query, name: str, parameters: dict + ) -> typing.Any: try: - function, plugin = await self._get_function_and_plugin(name) if function is None: return None parameters = parameters.copy() - parameters = {"query": query, **parameters} + parameters = {'query': query, **parameters} return await function.func(plugin, **parameters) except Exception as e: - self.ap.logger.error(f"执行函数 {name} 时发生错误: {e}") + self.ap.logger.error(f'执行函数 {name} 时发生错误: {e}') traceback.print_exc() - return f"error occurred when executing function {name}: {e}" + return f'error occurred when executing function {name}: {e}' finally: plugin = None @@ -75,13 +74,12 @@ class PluginToolLoader(loader.ToolLoader): break if plugin is not None: - await self.ap.ctr_mgr.usage.post_function_record( plugin={ - "name": plugin.plugin_name, - "remote": plugin.plugin_repository, - "version": plugin.plugin_version, - "author": plugin.plugin_author, + 'name': plugin.plugin_name, + 'remote': plugin.plugin_repository, + 'version': plugin.plugin_version, + 'author': plugin.plugin_author, }, function_name=function.name, function_description=function.description, diff --git a/pkg/provider/tools/toolmgr.py b/pkg/provider/tools/toolmgr.py index 64befd8c..0f6fdac0 100644 --- a/pkg/provider/tools/toolmgr.py +++ b/pkg/provider/tools/toolmgr.py @@ -1,12 +1,13 @@ from __future__ import annotations import typing -import traceback from ...core import app, entities as core_entities from . import entities, loader as tools_loader -from ...plugin import context as plugin_context -from .loaders import plugin, mcp +from ...utils import importutil +from . import loaders + +importutil.import_modules_in_pkg(loaders) class ToolManager: @@ -22,13 +23,14 @@ class ToolManager: self.loaders = [] async def initialize(self): - for loader_cls in tools_loader.preregistered_loaders: loader_inst = loader_cls(self.ap) await loader_inst.initialize() self.loaders.append(loader_inst) - async def get_all_functions(self, plugin_enabled: bool=None) -> list[entities.LLMFunction]: + async def get_all_functions( + self, plugin_enabled: bool = None + ) -> list[entities.LLMFunction]: """获取所有函数""" all_functions: list[entities.LLMFunction] = [] @@ -37,17 +39,19 @@ class ToolManager: return all_functions - async def generate_tools_for_openai(self, use_funcs: list[entities.LLMFunction]) -> list: + async def generate_tools_for_openai( + self, use_funcs: list[entities.LLMFunction] + ) -> list: """生成函数列表""" tools = [] for function in use_funcs: function_schema = { - "type": "function", - "function": { - "name": function.name, - "description": function.description, - "parameters": function.parameters, + 'type': 'function', + 'function': { + 'name': function.name, + 'description': function.description, + 'parameters': function.parameters, }, } tools.append(function_schema) @@ -83,9 +87,9 @@ class ToolManager: for function in use_funcs: function_schema = { - "name": function.name, - "description": function.description, - "input_schema": function.parameters, + 'name': function.name, + 'description': function.description, + 'input_schema': function.parameters, } tools.append(function_schema) @@ -100,7 +104,7 @@ class ToolManager: if await loader.has_tool(name): return await loader.invoke_tool(query, name, parameters) else: - raise ValueError(f"未找到工具: {name}") + raise ValueError(f'未找到工具: {name}') async def shutdown(self): """关闭所有工具""" diff --git a/pkg/utils/announce.py b/pkg/utils/announce.py index 1fb6e166..47de9325 100644 --- a/pkg/utils/announce.py +++ b/pkg/utils/announce.py @@ -14,7 +14,7 @@ from ..core import app class Announcement(pydantic.BaseModel): """公告""" - + id: int time: str @@ -27,11 +27,11 @@ class Announcement(pydantic.BaseModel): def to_dict(self) -> dict: return { - "id": self.id, - "time": self.time, - "timestamp": self.timestamp, - "content": self.content, - "enabled": self.enabled + 'id': self.id, + 'time': self.time, + 'timestamp': self.timestamp, + 'content': self.content, + 'enabled': self.enabled, } @@ -43,30 +43,28 @@ class AnnouncementManager: def __init__(self, ap: app.Application): self.ap = ap - async def fetch_all( - self - ) -> list[Announcement]: + async def fetch_all(self) -> list[Announcement]: """获取所有公告""" resp = requests.get( - url="https://api.github.com/repos/RockChinQ/LangBot/contents/res/announcement.json", + url='https://api.github.com/repos/RockChinQ/LangBot/contents/res/announcement.json', proxies=self.ap.proxy_mgr.get_forward_proxies(), - timeout=5 + timeout=5, ) obj_json = resp.json() - b64_content = obj_json["content"] + b64_content = obj_json['content'] # 解码 - content = base64.b64decode(b64_content).decode("utf-8") + content = base64.b64decode(b64_content).decode('utf-8') return [Announcement(**item) for item in json.loads(content)] - async def fetch_saved( - self - ) -> list[Announcement]: - if not os.path.exists("data/labels/announcement_saved.json"): - with open("data/labels/announcement_saved.json", "w", encoding="utf-8") as f: - f.write("[]") + async def fetch_saved(self) -> list[Announcement]: + if not os.path.exists('data/labels/announcement_saved.json'): + with open( + 'data/labels/announcement_saved.json', 'w', encoding='utf-8' + ) as f: + f.write('[]') - with open("data/labels/announcement_saved.json", "r", encoding="utf-8") as f: + with open('data/labels/announcement_saved.json', 'r', encoding='utf-8') as f: content = f.read() if not content: @@ -74,19 +72,15 @@ class AnnouncementManager: return [Announcement(**item) for item in json.loads(content)] - async def write_saved( - self, - content: list[Announcement] - ): + async def write_saved(self, content: list[Announcement]): + with open('data/labels/announcement_saved.json', 'w', encoding='utf-8') as f: + f.write( + json.dumps( + [item.to_dict() for item in content], indent=4, ensure_ascii=False + ) + ) - with open("data/labels/announcement_saved.json", "w", encoding="utf-8") as f: - f.write(json.dumps([ - item.to_dict() for item in content - ], indent=4, ensure_ascii=False)) - - async def fetch_new( - self - ) -> list[Announcement]: + async def fetch_new(self) -> list[Announcement]: """获取新公告""" all = await self.fetch_all() saved = await self.fetch_saved() @@ -106,18 +100,15 @@ class AnnouncementManager: await self.write_saved(all) return to_show - async def show_announcements( - self - ) -> typing.Tuple[str, int]: + async def show_announcements(self) -> typing.Tuple[str, int]: """显示公告""" try: announcements = await self.fetch_new() - ann_text = "" + ann_text = '' for ann in announcements: - ann_text += f"[公告] {ann.time}: {ann.content}\n" + ann_text += f'[公告] {ann.time}: {ann.content}\n' if announcements: - await self.ap.ctr_mgr.main.post_announcement_showed( ids=[item.id for item in announcements] ) diff --git a/pkg/utils/constants.py b/pkg/utils/constants.py index f37f2151..7f3d0804 100644 --- a/pkg/utils/constants.py +++ b/pkg/utils/constants.py @@ -1,8 +1,8 @@ -semantic_version = "v4.0.0" +semantic_version = 'v4.0.0' required_database_version = 1 """标记本版本所需要的数据库结构版本,用于判断数据库迁移""" debug_mode = False -edition = 'community' \ No newline at end of file +edition = 'community' diff --git a/pkg/utils/funcschema.py b/pkg/utils/funcschema.py index c39b4886..52dd6efc 100644 --- a/pkg/utils/funcschema.py +++ b/pkg/utils/funcschema.py @@ -1,4 +1,3 @@ -import sys import re import inspect @@ -33,18 +32,18 @@ def get_func_schema(function: callable) -> dict: func_doc = function.__doc__ # Google Style Docstring if func_doc is None: - raise Exception("Function {} has no docstring.".format(function.__name__)) - func_doc = func_doc.strip().replace(' ','').replace('\t', '') + raise Exception('Function {} has no docstring.'.format(function.__name__)) + func_doc = func_doc.strip().replace(' ', '').replace('\t', '') # extract doc of args from docstring doc_spt = func_doc.split('\n\n') desc = doc_spt[0] - args = doc_spt[1] if len(doc_spt) > 1 else "" - returns = doc_spt[2] if len(doc_spt) > 2 else "" + args = doc_spt[1] if len(doc_spt) > 1 else '' + # returns = doc_spt[2] if len(doc_spt) > 2 else "" # extract args # delete the first line of args arg_lines = args.split('\n')[1:] - arg_doc_list = re.findall(r'(\w+)(\((\w+)\))?:\s*(.*)', args) + # arg_doc_list = re.findall(r'(\w+)(\((\w+)\))?:\s*(.*)', args) args_doc = {} for arg_line in arg_lines: doc_tuple = re.findall(r'(\w+)(\(([\w\[\]]+)\))?:\s*(.*)', arg_line) @@ -53,18 +52,16 @@ def get_func_schema(function: callable) -> dict: args_doc[doc_tuple[0][0]] = doc_tuple[0][3] # extract returns - return_doc_list = re.findall(r'(\w+):\s*(.*)', returns) + # return_doc_list = re.findall(r'(\w+):\s*(.*)', returns) params = enumerate(inspect.signature(function).parameters.values()) parameters = { - "type": "object", - "required": [], - "properties": {}, + 'type': 'object', + 'required': [], + 'properties': {}, } - for i, param in params: - # 排除 self, query if param.name in ['self', 'query']: continue @@ -72,24 +69,24 @@ def get_func_schema(function: callable) -> dict: param_type = param.annotation.__name__ type_name_mapping = { - "str": "string", - "int": "integer", - "float": "number", - "bool": "boolean", - "list": "array", - "dict": "object", + 'str': 'string', + 'int': 'integer', + 'float': 'number', + 'bool': 'boolean', + 'list': 'array', + 'dict': 'object', } if param_type in type_name_mapping: param_type = type_name_mapping[param_type] parameters['properties'][param.name] = { - "type": param_type, - "description": args_doc[param.name], + 'type': param_type, + 'description': args_doc[param.name], } # add schema for array - if param_type == "array": + if param_type == 'array': # extract type of array, the int of list[int] # use re array_type_tuple = re.findall(r'list\[(\w+)\]', str(param.annotation)) @@ -102,15 +99,15 @@ def get_func_schema(function: callable) -> dict: if array_type in type_name_mapping: array_type = type_name_mapping[array_type] - parameters['properties'][param.name]["items"] = { - "type": array_type, + parameters['properties'][param.name]['items'] = { + 'type': array_type, } if param.default is inspect.Parameter.empty: - parameters["required"].append(param.name) + parameters['required'].append(param.name) return { - "function": function, - "description": desc, - "parameters": parameters, - } \ No newline at end of file + 'function': function, + 'description': desc, + 'parameters': parameters, + } diff --git a/pkg/utils/image.py b/pkg/utils/image.py index 760c2128..9af766fb 100644 --- a/pkg/utils/image.py +++ b/pkg/utils/image.py @@ -8,23 +8,16 @@ import aiohttp import PIL.Image import httpx -import os -import aiofiles -import pathlib import asyncio -from urllib.parse import urlparse - - - async def get_gewechat_image_base64( - gewechat_url: str, - gewechat_file_url: str, - app_id: str, - xml_content: str, - token: str, - image_type: int = 2, + gewechat_url: str, + gewechat_file_url: str, + app_id: str, + xml_content: str, + token: str, + image_type: int = 2, ) -> typing.Tuple[str, str]: """从gewechat服务器获取图片并转换为base64格式 @@ -43,17 +36,14 @@ async def get_gewechat_image_base64( aiohttp.ClientTimeout: 请求超时(15秒)或连接超时(2秒) Exception: 其他错误 """ - headers = { - 'X-GEWE-TOKEN': token, - 'Content-Type': 'application/json' - } + headers = {'X-GEWE-TOKEN': token, 'Content-Type': 'application/json'} # 设置超时 timeout = aiohttp.ClientTimeout( total=15.0, # 总超时时间15秒 connect=2.0, # 连接超时2秒 sock_connect=2.0, # socket连接超时2秒 - sock_read=15.0 # socket读取超时15秒 + sock_read=15.0, # socket读取超时15秒 ) try: @@ -61,37 +51,37 @@ async def get_gewechat_image_base64( # 获取图片下载链接 try: async with session.post( - f"{gewechat_url}/v2/api/message/downloadImage", - headers=headers, - json={ - "appId": app_id, - "type": image_type, - "xml": xml_content - } + f'{gewechat_url}/v2/api/message/downloadImage', + headers=headers, + json={'appId': app_id, 'type': image_type, 'xml': xml_content}, ) as response: if response.status != 200: # print(response) - raise Exception(f"获取gewechat图片下载失败: {await response.text()}") + raise Exception( + f'获取gewechat图片下载失败: {await response.text()}' + ) resp_data = await response.json() - if resp_data.get("ret") != 200: - raise Exception(f"获取gewechat图片下载链接失败: {resp_data}") + if resp_data.get('ret') != 200: + raise Exception(f'获取gewechat图片下载链接失败: {resp_data}') file_url = resp_data['data']['fileUrl'] except asyncio.TimeoutError: - raise Exception("获取图片下载链接超时") + raise Exception('获取图片下载链接超时') except aiohttp.ClientError as e: - raise Exception(f"获取图片下载链接网络错误: {str(e)}") + raise Exception(f'获取图片下载链接网络错误: {str(e)}') # 解析原始URL并替换端口 base_url = gewechat_file_url - download_url = f"{base_url}/download/{file_url}" + download_url = f'{base_url}/download/{file_url}' # 下载图片 try: async with session.get(download_url) as img_response: if img_response.status != 200: - raise Exception(f"下载图片失败: {await img_response.text()}, URL: {download_url}") + raise Exception( + f'下载图片失败: {await img_response.text()}, URL: {download_url}' + ) image_data = await img_response.read() @@ -105,14 +95,11 @@ async def get_gewechat_image_base64( return base64_str, image_format except asyncio.TimeoutError: - raise Exception(f"下载图片超时, URL: {download_url}") + raise Exception(f'下载图片超时, URL: {download_url}') except aiohttp.ClientError as e: - raise Exception(f"下载图片网络错误: {str(e)}, URL: {download_url}") + raise Exception(f'下载图片网络错误: {str(e)}, URL: {download_url}') except Exception as e: - raise Exception(f"获取图片失败: {str(e)}") from e - - - + raise Exception(f'获取图片失败: {str(e)}') from e async def get_wecom_image_base64(pic_url: str) -> tuple[str, str]: @@ -124,22 +111,26 @@ async def get_wecom_image_base64(pic_url: str) -> tuple[str, str]: async with aiohttp.ClientSession() as session: async with session.get(pic_url) as response: if response.status != 200: - raise Exception(f"Failed to download image: {response.status}") - + raise Exception(f'Failed to download image: {response.status}') + # 读取图片数据 image_data = await response.read() - + # 获取图片格式 content_type = response.headers.get('Content-Type', '') image_format = content_type.split('/')[-1] # 例如 'image/jpeg' -> 'jpeg' - + # 转换为 base64 import base64 + image_base64 = base64.b64encode(image_data).decode('utf-8') - + return image_base64, image_format - -async def get_qq_official_image_base64(pic_url:str,content_type:str) -> tuple[str,str]: + + +async def get_qq_official_image_base64( + pic_url: str, content_type: str +) -> tuple[str, str]: """ 下载QQ官方图片, 并且转换为base64格式 @@ -149,18 +140,18 @@ async def get_qq_official_image_base64(pic_url:str,content_type:str) -> tuple[st response.raise_for_status() # 确保请求成功 image_data = response.content base64_data = base64.b64encode(image_data).decode('utf-8') - - return f"data:{content_type};base64,{base64_data}" + + return f'data:{content_type};base64,{base64_data}' def get_qq_image_downloadable_url(image_url: str) -> tuple[str, dict]: """获取QQ图片的下载链接""" parsed = urlparse(image_url) query = parse_qs(parsed.query) - return f"http://{parsed.netloc}{parsed.path}", query + return f'http://{parsed.netloc}{parsed.path}', query -async def get_qq_image_bytes(image_url: str, query: dict={}) -> tuple[bytes, str]: +async def get_qq_image_bytes(image_url: str, query: dict = {}) -> tuple[bytes, str]: """[弃用]获取QQ图片的bytes""" image_url, query_in_url = get_qq_image_downloadable_url(image_url) query = {**query, **query_in_url} @@ -177,14 +168,12 @@ async def get_qq_image_bytes(image_url: str, query: dict={}) -> tuple[bytes, str elif not content_type.startswith('image/'): pil_img = PIL.Image.open(io.BytesIO(file_bytes)) image_format = pil_img.format.lower() - else: + else: image_format = content_type.split('/')[-1] return file_bytes, image_format -async def qq_image_url_to_base64( - image_url: str -) -> typing.Tuple[str, str]: +async def qq_image_url_to_base64(image_url: str) -> typing.Tuple[str, str]: """[弃用]将QQ图片URL转为base64,并返回图片格式 Args: @@ -204,12 +193,13 @@ async def qq_image_url_to_base64( return base64_str, image_format + async def extract_b64_and_format(image_base64_data: str) -> typing.Tuple[str, str]: """提取base64编码和图片格式 - +  提取出base64编码和图片格式 """ base64_str = image_base64_data.split(',')[-1] image_format = image_base64_data.split(':')[-1].split(';')[0].split('/')[-1] - return base64_str, image_format \ No newline at end of file + return base64_str, image_format diff --git a/pkg/utils/importutil.py b/pkg/utils/importutil.py new file mode 100644 index 00000000..87ca652a --- /dev/null +++ b/pkg/utils/importutil.py @@ -0,0 +1,43 @@ +import importlib +import importlib.util +import os +import typing + + +def import_modules_in_pkg(pkg: typing.Any) -> None: + """ + 导入一个包内的所有模块 + Args: + pkg: 要导入的包对象 + """ + pkg_path = os.path.dirname(pkg.__file__) + import_dir(pkg_path) + + +def import_modules_in_pkgs(pkgs: typing.List) -> None: + for pkg in pkgs: + import_modules_in_pkg(pkg) + + +def import_dot_style_dir(dot_sep_path: str): + sec = dot_sep_path.split('.') + + return import_dir(os.path.join(*sec)) + + +def import_dir(path: str): + for file in os.listdir(path): + if file.endswith('.py') and file != '__init__.py': + full_path = os.path.join(path, file) + rel_path = full_path.replace( + os.path.dirname(os.path.dirname(os.path.dirname(__file__))), '' + ) + rel_path = rel_path[1:] + rel_path = rel_path.replace('/', '.')[:-3] + importlib.import_module(rel_path) + + +if __name__ == '__main__': + from pkg.platform import types + + import_modules_in_pkg(types) diff --git a/pkg/utils/ip.py b/pkg/utils/ip.py index 1250f99e..56a12086 100644 --- a/pkg/utils/ip.py +++ b/pkg/utils/ip.py @@ -1,9 +1,12 @@ import aiohttp + async def get_myip() -> str: try: - async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=10)) as session: - async with session.get("https://ip.useragentinfo.com/myip") as response: + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=10) + ) as session: + async with session.get('https://ip.useragentinfo.com/myip') as response: return await response.text() - except Exception as e: - return '0.0.0.0' \ No newline at end of file + except Exception: + return '0.0.0.0' diff --git a/pkg/utils/logcache.py b/pkg/utils/logcache.py index d3206e9b..84c58f55 100644 --- a/pkg/utils/logcache.py +++ b/pkg/utils/logcache.py @@ -5,8 +5,9 @@ LOG_PAGE_SIZE = 20 MAX_CACHED_PAGES = 10 -class LogPage(): +class LogPage: """日志页""" + number: int """页码""" @@ -51,12 +52,12 @@ class LogCache: start_offset: int, ) -> tuple[str, int, int]: """获取指定页码和偏移量的日志""" - final_logs_str = "" + final_logs_str = '' for page in self.log_pages: if page.number == start_page_number: - final_logs_str += "\n".join(page.logs[start_offset:]) + final_logs_str += '\n'.join(page.logs[start_offset:]) elif page.number > start_page_number: - final_logs_str += "\n".join(page.logs) + final_logs_str += '\n'.join(page.logs) return final_logs_str, page.number, len(page.logs) diff --git a/pkg/utils/pkgmgr.py b/pkg/utils/pkgmgr.py index 4f3186b6..9ce8bdb8 100644 --- a/pkg/utils/pkgmgr.py +++ b/pkg/utils/pkgmgr.py @@ -6,8 +6,17 @@ def install(package): def install_upgrade(package): - pipmain(['install', '--upgrade', package, "-i", "https://pypi.tuna.tsinghua.edu.cn/simple", - "--trusted-host", "pypi.tuna.tsinghua.edu.cn"]) + pipmain( + [ + 'install', + '--upgrade', + package, + '-i', + 'https://pypi.tuna.tsinghua.edu.cn/simple', + '--trusted-host', + 'pypi.tuna.tsinghua.edu.cn', + ] + ) def run_pip(params: list): @@ -15,5 +24,15 @@ def run_pip(params: list): def install_requirements(file, extra_params: list = []): - pipmain(['install', '-r', file, "-i", "https://pypi.tuna.tsinghua.edu.cn/simple", - "--trusted-host", "pypi.tuna.tsinghua.edu.cn"] + extra_params) + pipmain( + [ + 'install', + '-r', + file, + '-i', + 'https://pypi.tuna.tsinghua.edu.cn/simple', + '--trusted-host', + 'pypi.tuna.tsinghua.edu.cn', + ] + + extra_params + ) diff --git a/pkg/utils/proxy.py b/pkg/utils/proxy.py index 838814dd..4f3f7dec 100644 --- a/pkg/utils/proxy.py +++ b/pkg/utils/proxy.py @@ -1,14 +1,12 @@ from __future__ import annotations import os -import sys from ..core import app class ProxyManager: - """代理管理器 - """ + """代理管理器""" ap: app.Application @@ -21,14 +19,24 @@ class ProxyManager: async def initialize(self): self.forward_proxies = { - "http://": os.getenv("HTTP_PROXY") or os.getenv("http_proxy"), - "https://": os.getenv("HTTPS_PROXY") or os.getenv("https_proxy"), + 'http://': os.getenv('HTTP_PROXY') or os.getenv('http_proxy'), + 'https://': os.getenv('HTTPS_PROXY') or os.getenv('https_proxy'), } - if 'http' in self.ap.instance_config.data['proxy'] and self.ap.instance_config.data['proxy']['http']: - self.forward_proxies['http://'] = self.ap.instance_config.data['proxy']['http'] - if 'https' in self.ap.instance_config.data['proxy'] and self.ap.instance_config.data['proxy']['https']: - self.forward_proxies['https://'] = self.ap.instance_config.data['proxy']['https'] + if ( + 'http' in self.ap.instance_config.data['proxy'] + and self.ap.instance_config.data['proxy']['http'] + ): + self.forward_proxies['http://'] = self.ap.instance_config.data['proxy'][ + 'http' + ] + if ( + 'https' in self.ap.instance_config.data['proxy'] + and self.ap.instance_config.data['proxy']['https'] + ): + self.forward_proxies['https://'] = self.ap.instance_config.data['proxy'][ + 'https' + ] # 设置到环境变量 os.environ['HTTP_PROXY'] = self.forward_proxies['http://'] or '' diff --git a/pkg/utils/version.py b/pkg/utils/version.py index 9a206171..ef30c192 100644 --- a/pkg/utils/version.py +++ b/pkg/utils/version.py @@ -12,41 +12,33 @@ from . import constants class VersionManager: - """版本管理器 - """ + """版本管理器""" ap: app.Application - def __init__( - self, - ap: app.Application - ): + def __init__(self, ap: app.Application): self.ap = ap - async def initialize( - self - ): + async def initialize(self): pass - - def get_current_version( - self - ) -> str: + + def get_current_version(self) -> str: current_tag = constants.semantic_version return current_tag - + async def get_release_list(self) -> list: """获取发行列表""" rls_list_resp = requests.get( - url="https://api.github.com/repos/RockChinQ/LangBot/releases", + url='https://api.github.com/repos/RockChinQ/LangBot/releases', proxies=self.ap.proxy_mgr.get_forward_proxies(), - timeout=5 + timeout=5, ) rls_list = rls_list_resp.json() return rls_list - + async def update_all(self): """检查更新并下载源码""" start_time = time.time() @@ -58,10 +50,10 @@ class VersionManager: latest_rls = {} rls_notes = [] - latest_tag_name = "" + latest_tag_name = '' for rls in rls_list: rls_notes.append(rls['name']) # 使用发行名称作为note - if latest_tag_name == "": + if latest_tag_name == '': latest_tag_name = rls['tag_name'] if rls['tag_name'] == current_tag: @@ -69,56 +61,68 @@ class VersionManager: if latest_rls == {}: latest_rls = rls - self.ap.logger.info("更新日志: {}".format(rls_notes)) + self.ap.logger.info('更新日志: {}'.format(rls_notes)) - if latest_rls == {} and not self.is_newer(latest_tag_name, current_tag): # 没有新版本 + if latest_rls == {} and not self.is_newer( + latest_tag_name, current_tag + ): # 没有新版本 return False # 下载最新版本的zip到temp目录 - self.ap.logger.info("开始下载最新版本: {}".format(latest_rls['zipball_url'])) + self.ap.logger.info('开始下载最新版本: {}'.format(latest_rls['zipball_url'])) zip_url = latest_rls['zipball_url'] zip_resp = requests.get( - url=zip_url, - proxies=self.ap.proxy_mgr.get_forward_proxies() + url=zip_url, proxies=self.ap.proxy_mgr.get_forward_proxies() ) zip_data = zip_resp.content # 检查temp/updater目录 - if not os.path.exists("temp"): - os.mkdir("temp") - if not os.path.exists("temp/updater"): - os.mkdir("temp/updater") - with open("temp/updater/{}.zip".format(latest_rls['tag_name']), "wb") as f: + if not os.path.exists('temp'): + os.mkdir('temp') + if not os.path.exists('temp/updater'): + os.mkdir('temp/updater') + with open('temp/updater/{}.zip'.format(latest_rls['tag_name']), 'wb') as f: f.write(zip_data) - self.ap.logger.info("下载最新版本完成: {}".format("temp/updater/{}.zip".format(latest_rls['tag_name']))) + self.ap.logger.info( + '下载最新版本完成: {}'.format( + 'temp/updater/{}.zip'.format(latest_rls['tag_name']) + ) + ) # 解压zip到temp/updater// import zipfile + # 检查目标文件夹 - if os.path.exists("temp/updater/{}".format(latest_rls['tag_name'])): + if os.path.exists('temp/updater/{}'.format(latest_rls['tag_name'])): import shutil - shutil.rmtree("temp/updater/{}".format(latest_rls['tag_name'])) - os.mkdir("temp/updater/{}".format(latest_rls['tag_name'])) - with zipfile.ZipFile("temp/updater/{}.zip".format(latest_rls['tag_name']), 'r') as zip_ref: - zip_ref.extractall("temp/updater/{}".format(latest_rls['tag_name'])) + + shutil.rmtree('temp/updater/{}'.format(latest_rls['tag_name'])) + os.mkdir('temp/updater/{}'.format(latest_rls['tag_name'])) + with zipfile.ZipFile( + 'temp/updater/{}.zip'.format(latest_rls['tag_name']), 'r' + ) as zip_ref: + zip_ref.extractall('temp/updater/{}'.format(latest_rls['tag_name'])) # 覆盖源码 - source_root = "" + source_root = '' # 找到temp/updater//中的第一个子目录路径 - for root, dirs, files in os.walk("temp/updater/{}".format(latest_rls['tag_name'])): - if root != "temp/updater/{}".format(latest_rls['tag_name']): + for root, dirs, files in os.walk( + 'temp/updater/{}'.format(latest_rls['tag_name']) + ): + if root != 'temp/updater/{}'.format(latest_rls['tag_name']): source_root = root break # 覆盖源码 import shutil + for root, dirs, files in os.walk(source_root): # 覆盖所有子文件子目录 for file in files: src = os.path.join(root, file) - dst = src.replace(source_root, ".") + dst = src.replace(source_root, '.') if os.path.exists(dst): os.remove(dst) @@ -128,18 +132,18 @@ class VersionManager: # 检查目标文件是否存在 if not os.path.exists(dst): # 创建目标文件 - open(dst, "w").close() + open(dst, 'w').close() shutil.copy(src, dst) # 把current_tag写入文件 current_tag = latest_rls['tag_name'] - with open("current_tag", "w") as f: + with open('current_tag', 'w') as f: f.write(current_tag) await self.ap.ctr_mgr.main.post_update_record( - spent_seconds=int(time.time()-start_time), - infer_reason="update", + spent_seconds=int(time.time() - start_time), + infer_reason='update', old_version=old_tag, new_version=current_tag, ) @@ -155,23 +159,22 @@ class VersionManager: current_tag = self.get_current_version() # 检查是否有新版本 - latest_tag_name = "" + latest_tag_name = '' for rls in rls_list: - if latest_tag_name == "": + if latest_tag_name == '': latest_tag_name = rls['tag_name'] break return self.is_newer(latest_tag_name, current_tag) - def is_newer(self, new_tag: str, old_tag: str): """判断版本是否更新,忽略第四位版本和第一位版本""" if new_tag == old_tag: return False - new_tag = new_tag.split(".") - old_tag = old_tag.split(".") - + new_tag = new_tag.split('.') + old_tag = old_tag.split('.') + # 判断主版本是否相同 if new_tag[0] != old_tag[0]: return False @@ -180,29 +183,28 @@ class VersionManager: return True # 合成前三段,判断是否相同 - new_tag = ".".join(new_tag[:3]) - old_tag = ".".join(old_tag[:3]) + new_tag = '.'.join(new_tag[:3]) + old_tag = '.'.join(old_tag[:3]) return new_tag != old_tag - def compare_version_str(v0: str, v1: str) -> int: """比较两个版本号""" # 删除版本号前的v - if v0.startswith("v"): + if v0.startswith('v'): v0 = v0[1:] - if v1.startswith("v"): + if v1.startswith('v'): v1 = v1[1:] - v0:list = v0.split(".") - v1:list = v1.split(".") + v0: list = v0.split('.') + v1: list = v1.split('.') # 如果两个版本号节数不同,把短的后面用0补齐 if len(v0) < len(v1): - v0.extend(["0"]*(len(v1)-len(v0))) + v0.extend(['0'] * (len(v1) - len(v0))) elif len(v0) > len(v1): - v1.extend(["0"]*(len(v0)-len(v1))) + v1.extend(['0'] * (len(v0) - len(v1))) # 从高位向低位比较 for i in range(len(v0)): @@ -210,16 +212,16 @@ class VersionManager: return 1 elif int(v0[i]) < int(v1[i]): return -1 - + return 0 - async def show_version_update( - self - ) -> typing.Tuple[str, int]: + async def show_version_update(self) -> typing.Tuple[str, int]: try: - if await self.ap.ver_mgr.is_new_version_available(): - return "有新版本可用,根据文档更新:https://docs.langbot.app/deploy/update.html", logging.INFO - + return ( + '有新版本可用,根据文档更新:https://docs.langbot.app/deploy/update.html', + logging.INFO, + ) + except Exception as e: - return f"检查版本更新时出错: {e}", logging.WARNING + return f'检查版本更新时出错: {e}', logging.WARNING diff --git a/requirements.txt b/requirements.txt index e01eb373..b34c18a3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -37,4 +37,6 @@ mcp sqlmodel # indirect -taskgroup==0.0.0a4 \ No newline at end of file +taskgroup==0.0.0a4 +ruff +pre-commit \ No newline at end of file diff --git a/res/scripts/publish_announcement.py b/res/scripts/publish_announcement.py index 812e83d9..7d2e7d40 100644 --- a/res/scripts/publish_announcement.py +++ b/res/scripts/publish_announcement.py @@ -1,32 +1,32 @@ # 输出工作路径 import os -print("工作路径: " + os.getcwd()) -announcement = input("请输入公告内容: ") - +import time import json +print('工作路径: ' + os.getcwd()) +announcement = input('请输入公告内容: ') + # 读取现有的公告文件 res/announcement.json -with open("res/announcement.json", "r", encoding="utf-8") as f: +with open('res/announcement.json', 'r', encoding='utf-8') as f: announcement_json = json.load(f) # 将公告内容写入公告文件 # 当前自然时间 -import time -now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) +now = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()) # 获取最后一个公告的id -last_id = announcement_json[-1]["id"] if len(announcement_json) > 0 else -1 +last_id = announcement_json[-1]['id'] if len(announcement_json) > 0 else -1 announcement = { - "id": last_id + 1, - "time": now, - "timestamp": int(time.time()), - "content": announcement + 'id': last_id + 1, + 'time': now, + 'timestamp': int(time.time()), + 'content': announcement, } announcement_json.append(announcement) # 将公告写入公告文件 -with open("res/announcement.json", "w", encoding="utf-8") as f: +with open('res/announcement.json', 'w', encoding='utf-8') as f: json.dump(announcement_json, f, indent=4, ensure_ascii=False) diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 00000000..27159c90 --- /dev/null +++ b/ruff.toml @@ -0,0 +1,38 @@ +[lint] + +ignore = [ + "E712", # Comparison to true should be 'if cond is true:' or 'if cond:' (E712) + "F402", # Import `loader` from line 8 shadowed by loop variable + "F403", # * used, unable to detect undefined names + "F405", # may be undefined, or defined from star imports + "E741", # Ambiguous variable name: `l` + "E722", # bare-except + "E721", # type-comparison + "FURB113", # repeated-append + "FURB152", # math-constant + "UP007", # non-pep604-annotation + "UP032", # f-string + "UP045", # non-pep604-annotation-optional + "B005", # strip-with-multi-characters + "B006", # mutable-argument-default + "B007", # unused-loop-control-variable + "B026", # star-arg-unpacking-after-keyword-arg + "B903", # class-as-data-structure + "B904", # raise-without-from-inside-except + "B905", # zip-without-explicit-strict + "N806", # non-lowercase-variable-in-function + "N815", # mixed-case-variable-in-class-scope + "PT011", # pytest-raises-too-broad + "SIM102", # collapsible-if + "SIM103", # needless-bool + "SIM105", # suppressible-exception + "SIM107", # return-in-try-except-finally + "SIM108", # if-else-block-instead-of-if-exp + "SIM113", # enumerate-for-loop + "SIM117", # multiple-with-statements + "SIM210", # if-expr-with-true-false +] + +[format] +# 5. Use single quotes in `ruff format`. +quote-style = "single" \ No newline at end of file