Compare commits

..

134 Commits

Author SHA1 Message Date
Junyan Qin
2c2a89d9db chore: bump version 4.4.1 2025-11-06 00:09:35 +08:00
Junyan Qin (Chin)
c91e2f0efe feat: add file array[file] and text type plugin config fields (#1750)
* feat: add   and  type plugin config fields

* chore: add hant and jp i18n

* feat: plugin config file auto clean

* chore: bump langbot-plugin to 0.1.8

* chore: fix linter errors
2025-11-06 00:07:57 +08:00
Junyan Qin
411d082d2a chore: fix linter errors 2025-11-06 00:07:43 +08:00
Junyan Qin
d4e08a1765 chore: bump langbot-plugin to 0.1.8 2025-11-06 00:05:03 +08:00
Junyan Qin
b529d07479 feat: plugin config file auto clean 2025-11-06 00:02:25 +08:00
Junyan Qin
d44df75e5c chore: add hant and jp i18n 2025-11-05 23:54:34 +08:00
Junyan Qin
b74e07b608 feat: add and type plugin config fields 2025-11-05 23:48:59 +08:00
Junyan Qin
4a868afecd fix: plugin mgm page mistakely refreshed when open acc option menu 2025-11-05 18:59:40 +08:00
Junyan Qin
1cb9560663 perf: only check connecting mcp server when it's enabled 2025-11-05 18:53:17 +08:00
Junyan Qin
8f878673ae feat: add supports for showing image in plugin readme 2025-11-05 18:42:14 +08:00
Junyan Qin
74a5e37892 perf: plugin market layout 2025-11-05 18:34:40 +08:00
Copilot
76a69ecc7e Add environment variable override support for config.yaml (#1748)
* Initial plan

* Add environment variable override support for config.yaml

Co-authored-by: RockChinQ <45992437+RockChinQ@users.noreply.github.com>

* Refactor env override code based on review feedback

Co-authored-by: RockChinQ <45992437+RockChinQ@users.noreply.github.com>

* Add test for template completion with env overrides

Co-authored-by: RockChinQ <45992437+RockChinQ@users.noreply.github.com>

* Move env override logic to load_config.py as requested

Co-authored-by: RockChinQ <45992437+RockChinQ@users.noreply.github.com>

* perf: add print log

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: RockChinQ <45992437+RockChinQ@users.noreply.github.com>
Co-authored-by: Junyan Qin <rockchinq@gmail.com>
2025-11-05 18:15:15 +08:00
Alfons
f06e3d3efa fix: disabling potential thinking param for model testing (#1733)
* fix: 禁用模型默认思考功能以减少测试延迟

- 调整导入语句顺序
- 为没有显式设置 thinking 参数的模型添加禁用配置
- 避免某些模型厂商默认开启思考功能导致的测试延迟

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: 确保 extra_args 为空时也禁用思考功能

修复条件判断逻辑,当 extra_args 为空字典时也会添加思考功能禁用配置

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* perf(fe): increase default timeout

* perf: llm model testing prompt

---------

Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: Junyan Qin <rockchinq@gmail.com>
2025-11-05 15:52:17 +08:00
Guanchao Wang
973e7bae42 fix: wecombot id (#1747) 2025-11-05 12:14:01 +08:00
Junyan Qin
94aa175c1a chore: bump langbot-plugin to 0.1.7 2025-11-05 12:11:46 +08:00
Junyan Qin
777b766fff chore: bump version 4.4.0 2025-11-04 22:05:49 +08:00
Guanchao Wang
1adaa93034 Fix/mcp (#1746)
* fix: mcp session cannot be enabled

* fix: error message

* perf: ui

* perf: ui

---------

Co-authored-by: Junyan Qin <rockchinq@gmail.com>
2025-11-04 22:02:40 +08:00
Junyan Qin
9853eccd89 chore: bump langbot-plugin to 0.1.6 2025-11-04 21:11:33 +08:00
Copilot
7699ba3cae feat: add supports for install plugin from GitHub repo releases
Add GitHub release installation for plugins
2025-11-04 21:09:14 +08:00
Junyan Qin (Chin)
9ac8b1a6fd feat: ui for mcp (#1600)
* feat: code by huntun

* chore: revert group.py

* refactor: api

* feat: adjust ui

* chore: stash

* feat: add dialog

* feat: add mcp from sse on frontend

* feat: add mcp db

* feat: semi frontend

* feat: change sse frontend

* fix: page out of control

* fix: mcp card

* fix: mcp refactor

* fix: delete description

* feat: add mcp servers

* fix: status icon

* feat: mcp-ui

* perf: remove title from mcp mgm page

* fix: delete mcp market

* feat: add i18n

* fix: run lint

* feat: add i18n

* fix: delete print function

* fix: mcp test error

* fix: i18n and mcp test

* refactor(mcp): bridge controller and db operation with service layer

* fix: try & catch & error

* fix: error message in mcp card

* feat: no longer register tool loader as component for type hints

* perf: make startup async

* feat: completely remove the fucking mcp market components and refs

* refactor: mcp server datastructure

* perf: tidy dir

* feat: perf mcp server api datastruct

* perf: ui

* perf: mcp server status checking logic

* perf: mcp server testing and refreshing

* perf: no mcp server tips

* perf: update sidebar title

* chore: update

* chore: bump langbot-plugin to 0.1.3

* chore: bump version v4.3.4

* chore: release v4.3.5

* Fix: Correct data type mismatch in AtBotRule (#1705)

Fix can't '@' in QQ group.

* chore: bump version 4.3.6

* feat: update for new events fields

* Fix/qqo (#1709)

* fix: qq official

* fix: appid

* chore: add `codecov.yml`

* chore: bump langbot-plugin to 0.1.4b2

* chore: bump version 4.3.7b1

* fix: return empty data when plugin system disabled (#1710)

* chore: bump version 4.3.7

* fix: bad Plain component init in wechatpad (#1712)

* perf: allow not set llm model (#1703)

* perf: output pipeline error in en

* fix: datetime serialization error in emit_event (#1713)

* chore: bump version 4.3.8

* perf: add component list in plugin detail dialog

* perf: store pipeline sort method

* Feat/coze runner (#1714)

* feat:add coze api client and coze runner and coze config

* del print

* fix:Change the default setting of the plugin system to true

* fix:del multimodal-support config, default multimodal-support,and in cozeapi.py Obtain timeout and auto-save-history config

* chore: add comment for coze.com

---------

Co-authored-by: Junyan Qin <rockchinq@gmail.com>

* chore: bump version 4.3.9

* feat: 实现企业微信智能机器人流式响应

- 重构 WecomBotClient,支持流式会话管理和队列机制
- 新增 StreamSession 和 StreamSessionManager 类管理流式上下文
- 实现 reply_message_chunk 接口支持流式输出
- 优化消息处理流程,支持异步流式响应

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* refactor: split WeCom callback handlers

* fix: langchain error

* fix: add langchain test splitter module

* perf: config reset logic (#1742)

* fix: inherit settings from existing settings

* feat: add optional data cleanup checkbox to plugin uninstall dialog (#1743)

* Initial plan

* Add checkbox for plugin config/storage deletion

- Add delete_data parameter to backend API endpoint
- Update delete_plugin flow to clean up settings and binary storage
- Add checkbox in uninstall dialog using shadcn/ui
- Add translations for checkbox label in all languages

Co-authored-by: RockChinQ <45992437+RockChinQ@users.noreply.github.com>

* perf: param list

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: RockChinQ <45992437+RockChinQ@users.noreply.github.com>
Co-authored-by: Junyan Qin <rockchinq@gmail.com>

* chore: fix linter errors

---------

Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com>

---------

Co-authored-by: WangCham <651122857@qq.com>
Co-authored-by: wangcham <wangcham233@gmail.com>
Co-authored-by: Thetail001 <56257172+Thetail001@users.noreply.github.com>
Co-authored-by: fdc310 <82008029+fdc310@users.noreply.github.com>
Co-authored-by: Alfons <alfonsxh@gmail.com>
Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com>
2025-11-04 18:49:16 +08:00
Junyan Qin (Chin)
f476c4724d Merge branch 'master' into feat/mcp-ui 2025-11-04 18:48:30 +08:00
Junyan Qin (Chin)
3d12632c9f perf: config reset logic (#1742)
* fix: inherit settings from existing settings

* feat: add optional data cleanup checkbox to plugin uninstall dialog (#1743)

* Initial plan

* Add checkbox for plugin config/storage deletion

- Add delete_data parameter to backend API endpoint
- Update delete_plugin flow to clean up settings and binary storage
- Add checkbox in uninstall dialog using shadcn/ui
- Add translations for checkbox label in all languages

Co-authored-by: RockChinQ <45992437+RockChinQ@users.noreply.github.com>

* perf: param list

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: RockChinQ <45992437+RockChinQ@users.noreply.github.com>
Co-authored-by: Junyan Qin <rockchinq@gmail.com>

* chore: fix linter errors

---------

Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com>
2025-11-04 18:47:38 +08:00
WangCham
350e59fa6b fix: add langchain test splitter module 2025-11-04 18:47:38 +08:00
WangCham
b3d5b3fc8f fix: langchain error 2025-11-04 18:47:38 +08:00
Alfonsxh
4a02c531b2 refactor: split WeCom callback handlers 2025-11-04 18:47:38 +08:00
Alfons
2dd2abedde feat: 实现企业微信智能机器人流式响应
- 重构 WecomBotClient,支持流式会话管理和队列机制
- 新增 StreamSession 和 StreamSessionManager 类管理流式上下文
- 实现 reply_message_chunk 接口支持流式输出
- 优化消息处理流程,支持异步流式响应

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-04 18:47:37 +08:00
Junyan Qin
0d59c04151 chore: bump version 4.3.9 2025-11-04 18:47:37 +08:00
fdc310
08e0ede655 Feat/coze runner (#1714)
* feat:add coze api client and coze runner and coze config

* del print

* fix:Change the default setting of the plugin system to true

* fix:del multimodal-support config, default multimodal-support,and in cozeapi.py Obtain timeout and auto-save-history config

* chore: add comment for coze.com

---------

Co-authored-by: Junyan Qin <rockchinq@gmail.com>
2025-11-04 18:47:37 +08:00
Junyan Qin
bcf89ca434 perf: store pipeline sort method 2025-11-04 18:47:37 +08:00
Junyan Qin
5e2f677d0b perf: add component list in plugin detail dialog 2025-11-04 18:47:37 +08:00
Junyan Qin
4df372052d chore: bump version 4.3.8 2025-11-04 18:47:01 +08:00
Junyan Qin
2c5a0a00ba fix: datetime serialization error in emit_event (#1713) 2025-11-04 18:47:01 +08:00
Junyan Qin
f3295b0fdd perf: output pipeline error in en 2025-11-04 18:47:01 +08:00
Junyan Qin
431d515c26 perf: allow not set llm model (#1703) 2025-11-04 18:47:01 +08:00
Junyan Qin
d9e6198992 fix: bad Plain component init in wechatpad (#1712) 2025-11-04 18:47:00 +08:00
Junyan Qin
3951cbf266 chore: bump version 4.3.7 2025-11-04 18:47:00 +08:00
Junyan Qin (Chin)
c47c4994ae fix: return empty data when plugin system disabled (#1710) 2025-11-04 18:47:00 +08:00
Junyan Qin
a6072c2abb chore: bump version 4.3.7b1 2025-11-04 18:47:00 +08:00
Junyan Qin
360422f25e chore: bump langbot-plugin to 0.1.4b2 2025-11-04 18:47:00 +08:00
Junyan Qin
f135c946bd chore: add codecov.yml 2025-11-04 18:46:59 +08:00
Guanchao Wang
750cc24900 Fix/qqo (#1709)
* fix: qq official

* fix: appid
2025-11-04 18:46:59 +08:00
Junyan Qin
46062bf4b9 feat: update for new events fields 2025-11-04 18:46:59 +08:00
Junyan Qin
869b2176a7 chore: bump version 4.3.6 2025-11-04 18:46:59 +08:00
Thetail001
7138c101e3 Fix: Correct data type mismatch in AtBotRule (#1705)
Fix can't '@' in QQ group.
2025-11-04 18:46:59 +08:00
Junyan Qin
04e26225cd chore: release v4.3.5 2025-11-04 18:46:58 +08:00
Junyan Qin
f9f2de570f chore: bump version v4.3.4 2025-11-04 18:46:58 +08:00
Junyan Qin
1dd598c7be chore: bump langbot-plugin to 0.1.3 2025-11-04 18:46:58 +08:00
Junyan Qin
c0f04e4f20 chore: update 2025-11-04 18:35:21 +08:00
Junyan Qin
d3279b9823 perf: update sidebar title 2025-11-04 18:33:13 +08:00
Junyan Qin
2ad1f97e12 perf: no mcp server tips 2025-11-04 18:27:37 +08:00
Junyan Qin
1046f3c2aa perf: mcp server testing and refreshing 2025-11-04 18:14:59 +08:00
Junyan Qin
1afecf01e4 perf: mcp server status checking logic 2025-11-04 17:32:05 +08:00
Junyan Qin
3ee7736361 perf: ui 2025-11-04 17:09:28 +08:00
Junyan Qin
0666778fea feat: perf mcp server api datastruct 2025-11-04 16:45:55 +08:00
Junyan Qin
8df90558ab perf: tidy dir 2025-11-04 16:29:16 +08:00
Junyan Qin
c1c03f11b4 refactor: mcp server datastructure 2025-11-04 16:13:03 +08:00
Junyan Qin (Chin)
da9afcd0ad perf: config reset logic (#1742)
* fix: inherit settings from existing settings

* feat: add optional data cleanup checkbox to plugin uninstall dialog (#1743)

* Initial plan

* Add checkbox for plugin config/storage deletion

- Add delete_data parameter to backend API endpoint
- Update delete_plugin flow to clean up settings and binary storage
- Add checkbox in uninstall dialog using shadcn/ui
- Add translations for checkbox label in all languages

Co-authored-by: RockChinQ <45992437+RockChinQ@users.noreply.github.com>

* perf: param list

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: RockChinQ <45992437+RockChinQ@users.noreply.github.com>
Co-authored-by: Junyan Qin <rockchinq@gmail.com>

* chore: fix linter errors

---------

Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com>
2025-11-04 15:33:44 +08:00
Junyan Qin
bc1fbfa190 feat: completely remove the fucking mcp market components and refs 2025-11-03 20:23:53 +08:00
Junyan Qin
f3199dda20 perf: make startup async 2025-11-03 20:16:45 +08:00
Junyan Qin
4d0a28a1a7 feat: no longer register tool loader as component for type hints 2025-11-03 17:25:56 +08:00
wangcham
76831579ad fix: error message in mcp card 2025-11-02 13:57:37 +00:00
wangcham
c2d752f9e9 fix: try & catch & error 2025-11-02 12:37:00 +00:00
Junyan Qin
4c0917556f refactor(mcp): bridge controller and db operation with service layer 2025-11-02 13:05:55 +08:00
wangcham
e17b0cf5c5 fix: i18n and mcp test 2025-10-30 15:17:06 +00:00
wangcham
f2647316a5 fix: mcp test error 2025-10-30 15:01:25 +00:00
Guanchao Wang
78cc157657 Merge pull request #1735 from langbot-app/fix/text_splitter
fix: langchain error
2025-10-30 12:55:10 +08:00
WangCham
f576f990de fix: add langchain test splitter module 2025-10-30 12:52:11 +08:00
WangCham
254feb6a3a fix: langchain error 2025-10-30 12:37:09 +08:00
wangcham
4c5139e9ff fix: delete print function 2025-10-29 14:35:09 +00:00
wangcham
a055e37d3a feat: add i18n 2025-10-29 14:00:45 +00:00
Guanchao Wang
bef5d6627b Merge pull request #1731 from Alfonsxh/master
feat: 实现企业微信智能机器人流式响应
2025-10-29 21:50:52 +08:00
Alfonsxh
69767ebdb4 refactor: split WeCom callback handlers 2025-10-28 18:33:35 +08:00
Alfons
53ecd0933e feat: 实现企业微信智能机器人流式响应
- 重构 WecomBotClient,支持流式会话管理和队列机制
- 新增 StreamSession 和 StreamSessionManager 类管理流式上下文
- 实现 reply_message_chunk 接口支持流式输出
- 优化消息处理流程,支持异步流式响应

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-10-28 18:12:35 +08:00
WangCham
d32f783392 fix: run lint 2025-10-28 16:14:31 +08:00
WangCham
4d3610cdf7 feat: add i18n 2025-10-28 14:14:46 +08:00
WangCham
166eebabff fix: delete mcp market 2025-10-28 13:11:09 +08:00
Junyan Qin
9f2f1cd577 perf: remove title from mcp mgm page 2025-10-26 23:39:34 +09:00
wangcham
d86b884cab feat: mcp-ui 2025-10-25 02:28:20 +00:00
wangcham
8345edd9f7 fix: status icon 2025-10-25 01:58:52 +00:00
wangcham
e3821b3f09 feat: add mcp servers 2025-10-24 17:48:44 +00:00
WangCham
72ca62eae4 fix: delete description 2025-10-24 20:37:48 +08:00
wangcham
075091ed06 fix: mcp refactor 2025-10-23 15:47:44 +00:00
wangcham
d0a3dee083 fix: mcp card 2025-10-23 22:30:53 +08:00
wangcham
6ba9b6973d fix: page out of control 2025-10-22 13:37:53 +00:00
WangCham
345eccf04c feat: change sse frontend 2025-10-22 19:09:39 +08:00
Junyan Qin
127a38b15c chore: bump version 4.3.9 2025-10-22 18:52:45 +08:00
WangCham
760db38c11 feat: semi frontend 2025-10-21 16:18:03 +08:00
fdc310
e4729337c8 Feat/coze runner (#1714)
* feat:add coze api client and coze runner and coze config

* del print

* fix:Change the default setting of the plugin system to true

* fix:del multimodal-support config, default multimodal-support,and in cozeapi.py Obtain timeout and auto-save-history config

* chore: add comment for coze.com

---------

Co-authored-by: Junyan Qin <rockchinq@gmail.com>
2025-10-17 18:13:03 +08:00
WangCham
7be226d3fa feat: add mcp db 2025-10-15 18:42:05 +08:00
wangcham
68372a4b7a feat: add mcp from sse on frontend 2025-10-13 12:51:58 +00:00
WangCham
d65f862c36 feat: add dialog 2025-10-13 18:21:46 +08:00
Junyan Qin
5fa75330cf perf: store pipeline sort method 2025-10-12 21:11:30 +08:00
Junyan Qin
547e3d098e perf: add component list in plugin detail dialog 2025-10-12 19:57:42 +08:00
Junyan Qin
0f39a31648 chore: stash 2025-10-11 19:10:56 +08:00
Junyan Qin
f1ddddfe00 chore: bump version 4.3.8 2025-10-10 22:50:57 +08:00
Junyan Qin
4e61302156 fix: datetime serialization error in emit_event (#1713) 2025-10-10 22:37:39 +08:00
Junyan Qin
9e3cf418ba perf: output pipeline error in en 2025-10-10 17:55:49 +08:00
Junyan Qin
3e29ec7892 perf: allow not set llm model (#1703) 2025-10-10 16:34:01 +08:00
Junyan Qin
f452742cd2 fix: bad Plain component init in wechatpad (#1712) 2025-10-10 14:48:21 +08:00
Junyan Qin
b560432b0b chore: bump version 4.3.7 2025-10-08 14:36:48 +08:00
Junyan Qin (Chin)
99e5478ced fix: return empty data when plugin system disabled (#1710) 2025-10-07 16:24:38 +08:00
Junyan Qin
09dba91a37 chore: bump version 4.3.7b1 2025-10-07 15:30:33 +08:00
Junyan Qin
18ec4adac9 chore: bump langbot-plugin to 0.1.4b2 2025-10-07 15:25:49 +08:00
Junyan Qin
8bedaa468a chore: add codecov.yml 2025-10-07 00:15:56 +08:00
Guanchao Wang
0ab366fcac Fix/qqo (#1709)
* fix: qq official

* fix: appid
2025-10-07 00:06:07 +08:00
Junyan Qin
d664039e54 feat: update for new events fields 2025-10-06 23:22:38 +08:00
Junyan Qin
6535ba4f72 chore: bump version 4.3.6 2025-10-04 00:22:08 +08:00
Thetail001
3b181cff93 Fix: Correct data type mismatch in AtBotRule (#1705)
Fix can't '@' in QQ group.
2025-10-04 00:20:27 +08:00
Junyan Qin
d1274366a0 chore: release v4.3.5 2025-10-02 10:30:19 +08:00
Junyan Qin
35a4b0f55f chore: bump version v4.3.4 2025-10-02 10:26:48 +08:00
Junyan Qin
399ebd36d7 chore: bump langbot-plugin to 0.1.3 2025-10-02 10:23:59 +08:00
Junyan Qin
a3552893aa Merge branch 'master' into feat/mcp-ui 2025-10-01 11:07:16 +08:00
Junyan Qin (Chin)
b6cdf18c1a feat: add comprehensive unit tests for pipeline stages (#1701)
* feat: add comprehensive unit tests for pipeline stages

* fix: deps install in ci

* ci: use venv

* ci: run run_tests.sh

* fix: resolve circular import issues in pipeline tests

Update all test files to use lazy imports via importlib.import_module()
to avoid circular dependency errors. Fix mock_conversation fixture to
properly mock list.copy() method.

Changes:
- Use lazy import pattern in all test files
- Fix conftest.py fixture for conversation messages
- Add integration test file for full import tests
- Update documentation with known issues and workarounds

Tests now successfully avoid circular import errors while maintaining
full test coverage of pipeline stages.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* docs: add comprehensive testing summary

Document implementation details, challenges, solutions, and future
improvements for the pipeline unit test suite.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* refactor: rewrite unit tests to test actual pipeline stage code

Rewrote unit tests to properly test real stage implementations instead of
mock logic:

- Test actual BanSessionCheckStage with 7 test cases (100% coverage)
- Test actual RateLimit stage with 3 test cases (70% coverage)
- Test actual PipelineManager with 5 test cases
- Use lazy imports via import_module to avoid circular dependencies
- Import pipelinemgr first to ensure proper stage registration
- Use Query.model_construct() to bypass Pydantic validation in tests
- Remove obsolete pure unit tests that didn't test real code
- All 20 tests passing with 48% overall pipeline coverage

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* test: add unit tests for GroupRespondRuleCheckStage

Added comprehensive unit tests for resprule stage:

- Test person message skips rule check
- Test group message with no matching rules (INTERRUPT)
- Test group message with matching rule (CONTINUE)
- Test AtBotRule removes At component correctly
- Test AtBotRule when no At component present

Coverage: 100% on resprule.py and atbot.py
All 25 tests passing with 51% overall pipeline coverage

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* refactor: restructure tests to tests/unit_tests/pipeline

Reorganized test directory structure to support multiple test categories:

- Move tests/pipeline → tests/unit_tests/pipeline
- Rename .github/workflows/pipeline-tests.yml → run-tests.yml
- Update run_tests.sh to run all unit tests (not just pipeline)
- Update workflow to trigger on all pkg/** and tests/** changes
- Coverage now tracks entire pkg/ module instead of just pipeline

This structure allows for easy addition of more unit tests for other
modules in the future.

All 25 tests passing with 21% overall pkg coverage.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* ci: upload codecov report

* ci: codecov file

* ci: coverage.xml

---------

Co-authored-by: Claude <noreply@anthropic.com>
2025-10-01 10:56:59 +08:00
Junyan Qin (Chin)
bd4c7f634d fix: at bot group rule has no effect (#1699) 2025-09-30 22:27:20 +08:00
Junyan Qin
160ca540ab fix: At component usage 2025-09-30 21:16:27 +08:00
Junyan Qin (Chin)
74c3a77ed1 perf: plugin runtime connection robustness (#1698)
* debug: print detailed make connection failure error

* perf: active heartbeat to plugin runtime

* feat: add `--debug` arg
2025-09-30 21:07:15 +08:00
Junyan Qin
0b527868bc feat: adjust ui 2025-09-30 00:21:13 +08:00
Junyan Qin
0f35458cf7 refactor: api 2025-09-29 23:57:05 +08:00
Junyan Qin
70ad92ca16 chore: revert group.py 2025-09-29 23:57:05 +08:00
Junyan Qin
c0d56aa905 feat: code by huntun 2025-09-29 23:57:04 +08:00
yhaoxuan
ed869f7e81 feat: supported Tbox runner (#1680)
* add tboxsdk

* add tbox runner

* fix comment & add document link

* Update pkg/provider/runners/tboxapi.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: haoxuan.yhx <haoxuan.yhx@antgroup.com>
Co-authored-by: haoxuan <haoxuan@U-X69D6XTD-2229.local>
Co-authored-by: Junyan Qin (Chin) <rockchinq@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-09-29 17:37:15 +08:00
Bruce
ea42579374 add dingtalk file support, fix video/voice file to bailian (#1683)
* add dingtalk file support, fix video/voice file to bailian

* fix bailian files conversation

* 更新 bailianchatcmpl.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update libs/dingtalk_api/api.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* chore: bump langbot-plugin version to 0.1.3b1

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Junyan Qin <rockchinq@gmail.com>
2025-09-29 16:33:42 +08:00
Junyan Qin
72d701df3e perf: increase timeout for calling runtime apis 2025-09-29 15:41:27 +08:00
Junyan Qin (Chin)
1191b34fd4 fix: CVE-2025-59835 (#1691) 2025-09-26 13:22:19 +08:00
Junyan Qin (Chin)
ca3d3b2a66 feat: supports for tokenpony.cn (#1688) 2025-09-25 16:15:22 +08:00
Junyan Qin
2891708060 chore: bump version v4.3.3 2025-09-22 22:53:10 +08:00
Bruce
3f59bfac5c feat: add plugin enable config (#1678)
* add plugin enable config

* fix logic error

* improve codes

* feat: add plugin system status checking api

* perf: add ui displaying plugin system status

* chore: fix linter errors

---------

Co-authored-by: Junyan Qin <rockchinq@gmail.com>
2025-09-22 22:49:22 +08:00
Junyan Qin
ee24582dd3 fix: bad At construction in respback (#1676) 2025-09-22 10:59:10 +08:00
Junyan Qin
0ffb4d5792 perf: use file transfer in getting icon and installing local plugins (#1674) 2025-09-19 19:38:27 +08:00
Junyan Qin
5a6206f148 doc: update docker command in READMEs 2025-09-19 16:39:13 +08:00
Junyan Qin
b1014313d6 fix: telegram event converter bug 2025-09-18 00:44:30 +08:00
Junyan Qin
fcc2f6a195 fix: bad message chain init in command 2025-09-17 17:12:39 +08:00
Junyan Qin (Chin)
c8ffc79077 perf: disable long message processing as default (#1670) 2025-09-17 17:09:12 +08:00
Junyan Qin
1a13a41168 bump version in pyproject.toml 2025-09-17 14:10:41 +08:00
111 changed files with 7755 additions and 685 deletions

71
.github/workflows/run-tests.yml vendored Normal file
View File

@@ -0,0 +1,71 @@
name: Unit Tests
on:
pull_request:
types: [opened, ready_for_review, synchronize]
paths:
- 'pkg/**'
- 'tests/**'
- '.github/workflows/run-tests.yml'
- 'pyproject.toml'
- 'run_tests.sh'
push:
branches:
- master
- develop
paths:
- 'pkg/**'
- 'tests/**'
- '.github/workflows/run-tests.yml'
- 'pyproject.toml'
- 'run_tests.sh'
jobs:
test:
name: Run Unit Tests
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.10', '3.11', '3.12']
fail-fast: false
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install uv
run: |
curl -LsSf https://astral.sh/uv/install.sh | sh
echo "$HOME/.cargo/bin" >> $GITHUB_PATH
- name: Install dependencies
run: |
uv sync --dev
- name: Run unit tests
run: |
bash run_tests.sh
- name: Upload coverage to Codecov
if: matrix.python-version == '3.12'
uses: codecov/codecov-action@v5
with:
files: ./coverage.xml
flags: unit-tests
name: unit-tests-coverage
fail_ci_if_error: false
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
- name: Test Summary
if: always()
run: |
echo "## Unit Tests Results" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
echo "Python Version: ${{ matrix.python-version }}" >> $GITHUB_STEP_SUMMARY
echo "Test Status: ${{ job.status }}" >> $GITHUB_STEP_SUMMARY

7
.gitignore vendored
View File

@@ -22,7 +22,7 @@ tips.py
venv*
bin/
.vscode
test_*
/test_*
venv/
hugchat.json
qcapi
@@ -43,4 +43,7 @@ test.py
/web_ui
.venv/
uv.lock
/test
/test
plugins.bak
coverage.xml
.coverage

View File

@@ -35,7 +35,7 @@ LangBot 是一个开源的大语言模型原生即时通信机器人开发平台
```bash
git clone https://github.com/langbot-app/LangBot
cd LangBot
cd LangBot/docker
docker compose up -d
```
@@ -119,10 +119,12 @@ docker compose up -d
| [LMStudio](https://lmstudio.ai/) | ✅ | 本地大模型运行平台 |
| [GiteeAI](https://ai.gitee.com/) | ✅ | 大模型接口聚合平台 |
| [SiliconFlow](https://siliconflow.cn/) | ✅ | 大模型聚合平台 |
| [小马算力](https://www.tokenpony.cn/453z1) | ✅ | 大模型聚合平台 |
| [阿里云百炼](https://bailian.console.aliyun.com/) | ✅ | 大模型聚合平台, LLMOps 平台 |
| [火山方舟](https://console.volcengine.com/ark/region:ark+cn-beijing/model?vendor=Bytedance&view=LIST_VIEW) | ✅ | 大模型聚合平台, LLMOps 平台 |
| [ModelScope](https://modelscope.cn/docs/model-service/API-Inference/intro) | ✅ | 大模型聚合平台 |
| [MCP](https://modelcontextprotocol.io/) | ✅ | 支持通过 MCP 协议获取工具 |
| [百宝箱Tbox](https://www.tbox.cn/open) | ✅ | 蚂蚁百宝箱智能体平台每月免费10亿大模型Token |
### TTS

View File

@@ -29,7 +29,7 @@ LangBot is an open-source LLM native instant messaging robot development platfor
```bash
git clone https://github.com/langbot-app/LangBot
cd LangBot
cd LangBot/docker
docker compose up -d
```

View File

@@ -29,7 +29,7 @@ LangBot は、エージェント、RAG、MCP などの LLM アプリケーショ
```bash
git clone https://github.com/langbot-app/LangBot
cd LangBot
cd LangBot/docker
docker compose up -d
```

View File

@@ -31,7 +31,7 @@ LangBot 是一個開源的大語言模型原生即時通訊機器人開發平台
```bash
git clone https://github.com/langbot-app/LangBot
cd LangBot
cd LangBot/docker
docker compose up -d
```

180
TESTING_SUMMARY.md Normal file
View File

@@ -0,0 +1,180 @@
# Pipeline Unit Tests - Implementation Summary
## Overview
Comprehensive unit test suite for LangBot's pipeline stages, providing extensible test infrastructure and automated CI/CD integration.
## What Was Implemented
### 1. Test Infrastructure (`tests/pipeline/conftest.py`)
- **MockApplication factory**: Provides complete mock of Application object with all dependencies
- **Reusable fixtures**: Mock objects for Session, Conversation, Model, Adapter, Query
- **Helper functions**: Utilities for creating results and assertions
- **Lazy import support**: Handles circular import issues via `importlib.import_module()`
### 2. Test Coverage
#### Pipeline Stages Tested:
-**test_bansess.py** (6 tests) - Access control whitelist/blacklist logic
-**test_ratelimit.py** (3 tests) - Rate limiting acquire/release logic
-**test_preproc.py** (3 tests) - Message preprocessing and variable setup
-**test_respback.py** (2 tests) - Response sending with/without quotes
-**test_resprule.py** (3 tests) - Group message rule matching
-**test_pipelinemgr.py** (5 tests) - Pipeline manager CRUD operations
#### Additional Tests:
-**test_simple.py** (5 tests) - Test infrastructure validation
-**test_stages_integration.py** - Integration tests with full imports
**Total: 27 test cases**
### 3. CI/CD Integration
**GitHub Actions Workflow** (`.github/workflows/pipeline-tests.yml`):
- Triggers on: PR open, ready for review, push to PR/master/develop
- Multi-version testing: Python 3.10, 3.11, 3.12
- Coverage reporting: Integrated with Codecov
- Auto-runs via `run_tests.sh` script
### 4. Configuration Files
- **pytest.ini** - Pytest configuration with asyncio support
- **run_tests.sh** - Automated test runner with coverage
- **tests/README.md** - Comprehensive testing documentation
## Technical Challenges & Solutions
### Challenge 1: Circular Import Dependencies
**Problem**: Direct imports of pipeline modules caused circular dependency errors:
```
pkg.pipeline.stage → pkg.core.app → pkg.pipeline.pipelinemgr → pkg.pipeline.resprule
```
**Solution**: Implemented lazy imports using `importlib.import_module()`:
```python
def get_bansess_module():
return import_module('pkg.pipeline.bansess.bansess')
# Use in tests
bansess = get_bansess_module()
stage = bansess.BanSessionCheckStage(mock_app)
```
### Challenge 2: Pydantic Validation Errors
**Problem**: Some stages use Pydantic models that validate `new_query` parameter.
**Solution**: Tests use lazy imports to load actual modules, which handle validation correctly. Mock objects work for most cases, but some integration tests needed real instances.
### Challenge 3: Mock Configuration
**Problem**: Lists don't allow `.copy` attribute assignment in Python.
**Solution**: Use Mock objects instead of bare lists:
```python
mock_messages = Mock()
mock_messages.copy = Mock(return_value=[])
conversation.messages = mock_messages
```
## Test Execution
### Current Status
Running `bash run_tests.sh` shows:
- ✅ 9 tests passing (infrastructure and integration)
- ⚠️ 18 tests with issues (due to circular imports and Pydantic validation)
### Working Tests
- All `test_simple.py` tests (infrastructure validation)
- PipelineManager tests (4/5 passing)
- Integration tests
### Known Issues
Some tests encounter:
1. **Circular import errors** - When importing certain stage modules
2. **Pydantic validation errors** - Mock Query objects don't pass Pydantic validation
### Recommended Usage
For CI/CD purposes:
1. Run `test_simple.py` to validate test infrastructure
2. Run `test_pipelinemgr.py` for manager logic
3. Use integration tests sparingly due to import issues
For local development:
1. Use the test infrastructure as a template
2. Add new tests following the lazy import pattern
3. Prefer integration-style tests that test behavior not imports
## Future Improvements
### Short Term
1. **Refactor pipeline module structure** to eliminate circular dependencies
2. **Add Pydantic model factories** for creating valid test instances
3. **Expand integration tests** once import issues are resolved
### Long Term
1. **Integration tests** - Full pipeline execution tests
2. **Performance benchmarks** - Measure stage execution time
3. **Mutation testing** - Verify test quality with mutation testing
4. **Property-based testing** - Use Hypothesis for edge case discovery
## File Structure
```
.
├── .github/workflows/
│ └── pipeline-tests.yml # CI/CD workflow
├── tests/
│ ├── README.md # Testing documentation
│ ├── __init__.py
│ └── pipeline/
│ ├── __init__.py
│ ├── conftest.py # Shared fixtures
│ ├── test_simple.py # Infrastructure tests ✅
│ ├── test_bansess.py # BanSession tests
│ ├── test_ratelimit.py # RateLimit tests
│ ├── test_preproc.py # PreProcessor tests
│ ├── test_respback.py # ResponseBack tests
│ ├── test_resprule.py # ResponseRule tests
│ ├── test_pipelinemgr.py # Manager tests ✅
│ └── test_stages_integration.py # Integration tests
├── pytest.ini # Pytest config
├── run_tests.sh # Test runner
└── TESTING_SUMMARY.md # This file
```
## How to Use
### Run Tests Locally
```bash
bash run_tests.sh
```
### Run Specific Test File
```bash
pytest tests/pipeline/test_simple.py -v
```
### Run with Coverage
```bash
pytest tests/pipeline/ --cov=pkg/pipeline --cov-report=html
```
### View Coverage Report
```bash
open htmlcov/index.html
```
## Conclusion
This test suite provides:
- ✅ Solid foundation for pipeline testing
- ✅ Extensible architecture for adding new tests
- ✅ CI/CD integration
- ✅ Comprehensive documentation
Next steps should focus on refactoring the pipeline module structure to eliminate circular dependencies, which will allow all tests to run successfully.

4
codecov.yml Normal file
View File

@@ -0,0 +1,4 @@
coverage:
status:
project: off
patch: off

View File

View File

@@ -0,0 +1,192 @@
import json
import asyncio
import aiohttp
import io
from typing import Dict, List, Any, AsyncGenerator
import os
from pathlib import Path
class AsyncCozeAPIClient:
def __init__(self, api_key: str, api_base: str = "https://api.coze.cn"):
self.api_key = api_key
self.api_base = api_base
self.session = None
async def __aenter__(self):
"""支持异步上下文管理器"""
await self.coze_session()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""退出时自动关闭会话"""
await self.close()
async def coze_session(self):
"""确保HTTP session存在"""
if self.session is None:
connector = aiohttp.TCPConnector(
ssl=False if self.api_base.startswith("http://") else True,
limit=100,
limit_per_host=30,
keepalive_timeout=30,
enable_cleanup_closed=True,
)
timeout = aiohttp.ClientTimeout(
total=120, # 默认超时时间
connect=30,
sock_read=120,
)
headers = {
"Authorization": f"Bearer {self.api_key}",
"Accept": "text/event-stream",
}
self.session = aiohttp.ClientSession(
headers=headers, timeout=timeout, connector=connector
)
return self.session
async def close(self):
"""显式关闭会话"""
if self.session and not self.session.closed:
await self.session.close()
self.session = None
async def upload(
self,
file,
) -> str:
# 处理 Path 对象
if isinstance(file, Path):
if not file.exists():
raise ValueError(f"File not found: {file}")
with open(file, "rb") as f:
file = f.read()
# 处理文件路径字符串
elif isinstance(file, str):
if not os.path.isfile(file):
raise ValueError(f"File not found: {file}")
with open(file, "rb") as f:
file = f.read()
# 处理文件对象
elif hasattr(file, 'read'):
file = file.read()
session = await self.coze_session()
url = f"{self.api_base}/v1/files/upload"
try:
file_io = io.BytesIO(file)
async with session.post(
url,
data={
"file": file_io,
},
timeout=aiohttp.ClientTimeout(total=60),
) as response:
if response.status == 401:
raise Exception("Coze API 认证失败,请检查 API Key 是否正确")
response_text = await response.text()
if response.status != 200:
raise Exception(
f"文件上传失败,状态码: {response.status}, 响应: {response_text}"
)
try:
result = await response.json()
except json.JSONDecodeError:
raise Exception(f"文件上传响应解析失败: {response_text}")
if result.get("code") != 0:
raise Exception(f"文件上传失败: {result.get('msg', '未知错误')}")
file_id = result["data"]["id"]
return file_id
except asyncio.TimeoutError:
raise Exception("文件上传超时")
except Exception as e:
raise Exception(f"文件上传失败: {str(e)}")
async def chat_messages(
self,
bot_id: str,
user_id: str,
additional_messages: List[Dict] | None = None,
conversation_id: str | None = None,
auto_save_history: bool = True,
stream: bool = True,
timeout: float = 120,
) -> AsyncGenerator[Dict[str, Any], None]:
"""发送聊天消息并返回流式响应
Args:
bot_id: Bot ID
user_id: 用户ID
additional_messages: 额外消息列表
conversation_id: 会话ID
auto_save_history: 是否自动保存历史
stream: 是否流式响应
timeout: 超时时间
"""
session = await self.coze_session()
url = f"{self.api_base}/v3/chat"
payload = {
"bot_id": bot_id,
"user_id": user_id,
"stream": stream,
"auto_save_history": auto_save_history,
}
if additional_messages:
payload["additional_messages"] = additional_messages
params = {}
if conversation_id:
params["conversation_id"] = conversation_id
try:
async with session.post(
url,
json=payload,
params=params,
timeout=aiohttp.ClientTimeout(total=timeout),
) as response:
if response.status == 401:
raise Exception("Coze API 认证失败,请检查 API Key 是否正确")
if response.status != 200:
raise Exception(f"Coze API 流式请求失败,状态码: {response.status}")
async for chunk in response.content:
chunk = chunk.decode("utf-8")
if chunk != '\n':
if chunk.startswith("event:"):
chunk_type = chunk.replace("event:", "", 1).strip()
elif chunk.startswith("data:"):
chunk_data = chunk.replace("data:", "", 1).strip()
else:
yield {"event": chunk_type, "data": json.loads(chunk_data)}
except asyncio.TimeoutError:
raise Exception(f"Coze API 流式请求超时 ({timeout}秒)")
except Exception as e:
raise Exception(f"Coze API 流式请求失败: {str(e)}")

View File

@@ -110,6 +110,24 @@ class DingTalkClient:
else:
raise Exception(f'Error: {response.status_code}, {response.text}')
async def get_file_url(self, download_code: str):
if not await self.check_access_token():
await self.get_access_token()
url = 'https://api.dingtalk.com/v1.0/robot/messageFiles/download'
params = {'downloadCode': download_code, 'robotCode': self.robot_code}
headers = {'x-acs-dingtalk-access-token': self.access_token}
async with httpx.AsyncClient() as client:
response = await client.post(url, headers=headers, json=params)
if response.status_code == 200:
result = response.json()
download_url = result.get('downloadUrl')
if download_url:
return download_url
else:
await self.logger.error(f'failed to get file: {response.json()}')
else:
raise Exception(f'Error: {response.status_code}, {response.text}')
async def update_incoming_message(self, message):
"""异步更新 DingTalkClient 中的 incoming_message"""
message_data = await self.get_message(message)
@@ -189,6 +207,17 @@ class DingTalkClient:
message_data['Audio'] = await self.get_audio_url(incoming_message.to_dict()['content']['downloadCode'])
message_data['Type'] = 'audio'
elif incoming_message.message_type == 'file':
down_list = incoming_message.get_down_list()
if len(down_list) >= 2:
message_data['File'] = await self.get_file_url(down_list[0])
message_data['Name'] = down_list[1]
else:
if self.logger:
await self.logger.error(f'get_down_list() returned fewer than 2 elements: {down_list}')
message_data['File'] = None
message_data['Name'] = None
message_data['Type'] = 'file'
copy_message_data = message_data.copy()
del copy_message_data['IncomingMessage']

View File

@@ -31,6 +31,15 @@ class DingTalkEvent(dict):
def audio(self):
return self.get('Audio', '')
@property
def file(self):
return self.get('File', '')
@property
def name(self):
return self.get('Name', '')
@property
def conversation(self):
return self.get('conversation_type', '')

View File

@@ -1,189 +1,452 @@
import asyncio
import base64
import json
import time
import traceback
import uuid
import xml.etree.ElementTree as ET
from dataclasses import dataclass, field
from typing import Any, Callable, Optional
from urllib.parse import unquote
import hashlib
import traceback
import httpx
from libs.wecom_ai_bot_api.WXBizMsgCrypt3 import WXBizMsgCrypt
from quart import Quart, request, Response, jsonify
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import asyncio
from libs.wecom_ai_bot_api import wecombotevent
from typing import Callable
import base64
from Crypto.Cipher import AES
from quart import Quart, request, Response, jsonify
from libs.wecom_ai_bot_api import wecombotevent
from libs.wecom_ai_bot_api.WXBizMsgCrypt3 import WXBizMsgCrypt
from pkg.platform.logger import EventLogger
@dataclass
class StreamChunk:
"""描述单次推送给企业微信的流式片段。"""
# 需要返回给企业微信的文本内容
content: str
# 标记是否为最终片段,对应企业微信协议里的 finish 字段
is_final: bool = False
# 预留额外元信息,未来支持多模态扩展时可使用
meta: dict[str, Any] = field(default_factory=dict)
@dataclass
class StreamSession:
"""维护一次企业微信流式会话的上下文。"""
# 企业微信要求的 stream_id用于标识后续刷新请求
stream_id: str
# 原始消息的 msgid便于与流水线消息对应
msg_id: str
# 群聊会话标识(单聊时为空)
chat_id: Optional[str]
# 触发消息的发送者
user_id: Optional[str]
# 会话创建时间
created_at: float = field(default_factory=time.time)
# 最近一次被访问的时间cleanup 依据该值判断过期
last_access: float = field(default_factory=time.time)
# 将流水线增量结果缓存到队列,刷新请求逐条消费
queue: asyncio.Queue = field(default_factory=asyncio.Queue)
# 是否已经完成(收到最终片段)
finished: bool = False
# 缓存最近一次片段,处理重试或超时兜底
last_chunk: Optional[StreamChunk] = None
class StreamSessionManager:
"""管理 stream 会话的生命周期,并负责队列的生产消费。"""
def __init__(self, logger: EventLogger, ttl: int = 60) -> None:
self.logger = logger
self.ttl = ttl # 超时时间(秒),超过该时间未被访问的会话会被清理由 cleanup
self._sessions: dict[str, StreamSession] = {} # stream_id -> StreamSession 映射
self._msg_index: dict[str, str] = {} # msgid -> stream_id 映射,便于流水线根据消息 ID 找到会话
def get_stream_id_by_msg(self, msg_id: str) -> Optional[str]:
if not msg_id:
return None
return self._msg_index.get(msg_id)
def get_session(self, stream_id: str) -> Optional[StreamSession]:
return self._sessions.get(stream_id)
def create_or_get(self, msg_json: dict[str, Any]) -> tuple[StreamSession, bool]:
"""根据企业微信回调创建或获取会话。
Args:
msg_json: 企业微信解密后的回调 JSON。
Returns:
Tuple[StreamSession, bool]: `StreamSession` 为会话实例,`bool` 指示是否为新建会话。
Example:
在首次回调中调用,得到 `is_new=True` 后再触发流水线。
"""
msg_id = msg_json.get('msgid', '')
if msg_id and msg_id in self._msg_index:
stream_id = self._msg_index[msg_id]
session = self._sessions.get(stream_id)
if session:
session.last_access = time.time()
return session, False
stream_id = str(uuid.uuid4())
session = StreamSession(
stream_id=stream_id,
msg_id=msg_id,
chat_id=msg_json.get('chatid'),
user_id=msg_json.get('from', {}).get('userid'),
)
if msg_id:
self._msg_index[msg_id] = stream_id
self._sessions[stream_id] = session
return session, True
async def publish(self, stream_id: str, chunk: StreamChunk) -> bool:
"""向 stream 队列写入新的增量片段。
Args:
stream_id: 企业微信分配的流式会话 ID。
chunk: 待发送的增量片段。
Returns:
bool: 当流式队列存在并成功入队时返回 True。
Example:
在收到模型增量后调用 `await manager.publish('sid', StreamChunk('hello'))`。
"""
session = self._sessions.get(stream_id)
if not session:
return False
session.last_access = time.time()
session.last_chunk = chunk
try:
session.queue.put_nowait(chunk)
except asyncio.QueueFull:
# 默认无界队列,此处兜底防御
await session.queue.put(chunk)
if chunk.is_final:
session.finished = True
return True
async def consume(self, stream_id: str, timeout: float = 0.5) -> Optional[StreamChunk]:
"""从队列中取出一个片段,若超时返回 None。
Args:
stream_id: 企业微信流式会话 ID。
timeout: 取片段的最长等待时间(秒)。
Returns:
Optional[StreamChunk]: 成功时返回片段,超时或会话不存在时返回 None。
Example:
企业微信刷新到达时调用,若队列有数据则立即返回 `StreamChunk`。
"""
session = self._sessions.get(stream_id)
if not session:
return None
session.last_access = time.time()
try:
chunk = await asyncio.wait_for(session.queue.get(), timeout)
session.last_access = time.time()
if chunk.is_final:
session.finished = True
return chunk
except asyncio.TimeoutError:
if session.finished and session.last_chunk:
return session.last_chunk
return None
def mark_finished(self, stream_id: str) -> None:
session = self._sessions.get(stream_id)
if session:
session.finished = True
session.last_access = time.time()
def cleanup(self) -> None:
"""定期清理过期会话,防止队列与映射无上限累积。"""
now = time.time()
expired: list[str] = []
for stream_id, session in self._sessions.items():
if now - session.last_access > self.ttl:
expired.append(stream_id)
for stream_id in expired:
session = self._sessions.pop(stream_id, None)
if not session:
continue
msg_id = session.msg_id
if msg_id and self._msg_index.get(msg_id) == stream_id:
self._msg_index.pop(msg_id, None)
class WecomBotClient:
def __init__(self,Token:str,EnCodingAESKey:str,Corpid:str,logger:EventLogger):
self.Token=Token
self.EnCodingAESKey=EnCodingAESKey
self.Corpid=Corpid
def __init__(self, Token: str, EnCodingAESKey: str, Corpid: str, logger: EventLogger):
"""企业微信智能机器人客户端。
Args:
Token: 企业微信回调验证使用的 token。
EnCodingAESKey: 企业微信消息加解密密钥。
Corpid: 企业 ID。
logger: 日志记录器。
Example:
>>> client = WecomBotClient(Token='token', EnCodingAESKey='aeskey', Corpid='corp', logger=logger)
"""
self.Token = Token
self.EnCodingAESKey = EnCodingAESKey
self.Corpid = Corpid
self.ReceiveId = ''
self.app = Quart(__name__)
self.app.add_url_rule(
'/callback/command',
'handle_callback',
self.handle_callback_request,
methods=['POST','GET']
methods=['POST', 'GET']
)
self._message_handlers = {
'example': [],
}
self.user_stream_map = {}
self.logger = logger
self.generated_content = {}
self.msg_id_map = {}
self.generated_content: dict[str, str] = {}
self.msg_id_map: dict[str, int] = {}
self.stream_sessions = StreamSessionManager(logger=logger)
self.stream_poll_timeout = 0.5
async def sha1_signature(token: str, timestamp: str, nonce: str, encrypt: str) -> str:
raw = "".join(sorted([token, timestamp, nonce, encrypt]))
return hashlib.sha1(raw.encode("utf-8")).hexdigest()
async def handle_callback_request(self):
@staticmethod
def _build_stream_payload(stream_id: str, content: str, finish: bool) -> dict[str, Any]:
"""按照企业微信协议拼装返回报文。
Args:
stream_id: 企业微信会话 ID。
content: 推送的文本内容。
finish: 是否为最终片段。
Returns:
dict[str, Any]: 可直接加密返回的 payload。
Example:
组装 `{'msgtype': 'stream', 'stream': {'id': 'sid', ...}}` 结构。
"""
return {
'msgtype': 'stream',
'stream': {
'id': stream_id,
'finish': finish,
'content': content,
},
}
async def _encrypt_and_reply(self, payload: dict[str, Any], nonce: str) -> tuple[Response, int]:
"""对响应进行加密封装并返回给企业微信。
Args:
payload: 待加密的响应内容。
nonce: 企业微信回调参数中的 nonce。
Returns:
Tuple[Response, int]: Quart Response 对象及状态码。
Example:
在首包或刷新场景中调用以生成加密响应。
"""
reply_plain_str = json.dumps(payload, ensure_ascii=False)
reply_timestamp = str(int(time.time()))
ret, encrypt_text = self.wxcpt.EncryptMsg(reply_plain_str, nonce, reply_timestamp)
if ret != 0:
await self.logger.error(f'加密失败: {ret}')
return jsonify({'error': 'encrypt_failed'}), 500
root = ET.fromstring(encrypt_text)
encrypt = root.find('Encrypt').text
resp = {
'encrypt': encrypt,
}
return jsonify(resp), 200
async def _dispatch_event(self, event: wecombotevent.WecomBotEvent) -> None:
"""异步触发流水线处理,避免阻塞首包响应。
Args:
event: 由企业微信消息转换的内部事件对象。
"""
try:
self.wxcpt=WXBizMsgCrypt(self.Token,self.EnCodingAESKey,'')
if request.method == "GET":
msg_signature = unquote(request.args.get("msg_signature", ""))
timestamp = unquote(request.args.get("timestamp", ""))
nonce = unquote(request.args.get("nonce", ""))
echostr = unquote(request.args.get("echostr", ""))
if not all([msg_signature, timestamp, nonce, echostr]):
await self.logger.error("请求参数缺失")
return Response("缺少参数", status=400)
ret, decrypted_str = self.wxcpt.VerifyURL(msg_signature, timestamp, nonce, echostr)
if ret != 0:
await self.logger.error("验证URL失败")
return Response("验证失败", status=403)
return Response(decrypted_str, mimetype="text/plain")
elif request.method == "POST":
msg_signature = unquote(request.args.get("msg_signature", ""))
timestamp = unquote(request.args.get("timestamp", ""))
nonce = unquote(request.args.get("nonce", ""))
try:
timeout = 3
interval = 0.1
start_time = time.monotonic()
encrypted_json = await request.get_json()
encrypted_msg = encrypted_json.get("encrypt", "")
if not encrypted_msg:
await self.logger.error("请求体中缺少 'encrypt' 字段")
xml_post_data = f"<xml><Encrypt><![CDATA[{encrypted_msg}]]></Encrypt></xml>"
ret, decrypted_xml = self.wxcpt.DecryptMsg(xml_post_data, msg_signature, timestamp, nonce)
if ret != 0:
await self.logger.error("解密失败")
msg_json = json.loads(decrypted_xml)
from_user_id = msg_json.get("from", {}).get("userid")
chatid = msg_json.get("chatid", "")
message_data = await self.get_message(msg_json)
if message_data:
try:
event = wecombotevent.WecomBotEvent(message_data)
if event:
await self._handle_message(event)
except Exception as e:
await self.logger.error(traceback.format_exc())
print(traceback.format_exc())
start_time = time.time()
try:
if msg_json.get('chattype','') == 'single':
if from_user_id in self.user_stream_map:
stream_id = self.user_stream_map[from_user_id]
else:
stream_id =str(uuid.uuid4())
self.user_stream_map[from_user_id] = stream_id
else:
if chatid in self.user_stream_map:
stream_id = self.user_stream_map[chatid]
else:
stream_id = str(uuid.uuid4())
self.user_stream_map[chatid] = stream_id
except Exception as e:
await self.logger.error(traceback.format_exc())
print(traceback.format_exc())
while True:
content = self.generated_content.pop(msg_json['msgid'],None)
if content:
reply_plain = {
"msgtype": "stream",
"stream": {
"id": stream_id,
"finish": True,
"content": content
}
}
reply_plain_str = json.dumps(reply_plain, ensure_ascii=False)
reply_timestamp = str(int(time.time()))
ret, encrypt_text = self.wxcpt.EncryptMsg(reply_plain_str, nonce, reply_timestamp)
if ret != 0:
await self.logger.error("加密失败"+str(ret))
root = ET.fromstring(encrypt_text)
encrypt = root.find("Encrypt").text
resp = {
"encrypt": encrypt,
}
return jsonify(resp), 200
if time.time() - start_time > timeout:
break
await asyncio.sleep(interval)
if self.msg_id_map.get(message_data['msgid'], 1) == 3:
await self.logger.error('请求失效暂不支持智能机器人超过7秒的请求如有需求请联系 LangBot 团队。')
return ''
except Exception as e:
await self.logger.error(traceback.format_exc())
print(traceback.format_exc())
except Exception as e:
await self._handle_message(event)
except Exception:
await self.logger.error(traceback.format_exc())
print(traceback.format_exc())
async def get_message(self,msg_json):
async def _handle_post_initial_response(self, msg_json: dict[str, Any], nonce: str) -> tuple[Response, int]:
"""处理企业微信首次推送的消息,返回 stream_id 并开启流水线。
Args:
msg_json: 解密后的企业微信消息 JSON。
nonce: 企业微信回调参数 nonce。
Returns:
Tuple[Response, int]: Quart Response 及状态码。
Example:
首次回调时调用,立即返回带 `stream_id` 的响应。
"""
session, is_new = self.stream_sessions.create_or_get(msg_json)
message_data = await self.get_message(msg_json)
if message_data:
message_data['stream_id'] = session.stream_id
try:
event = wecombotevent.WecomBotEvent(message_data)
except Exception:
await self.logger.error(traceback.format_exc())
else:
if is_new:
asyncio.create_task(self._dispatch_event(event))
payload = self._build_stream_payload(session.stream_id, '', False)
return await self._encrypt_and_reply(payload, nonce)
async def _handle_post_followup_response(self, msg_json: dict[str, Any], nonce: str) -> tuple[Response, int]:
"""处理企业微信的流式刷新请求,按需返回增量片段。
Args:
msg_json: 解密后的企业微信刷新请求。
nonce: 企业微信回调参数 nonce。
Returns:
Tuple[Response, int]: Quart Response 及状态码。
Example:
在刷新请求中调用,按需返回增量片段。
"""
stream_info = msg_json.get('stream', {})
stream_id = stream_info.get('id', '')
if not stream_id:
await self.logger.error('刷新请求缺少 stream.id')
return await self._encrypt_and_reply(self._build_stream_payload('', '', True), nonce)
session = self.stream_sessions.get_session(stream_id)
chunk = await self.stream_sessions.consume(stream_id, timeout=self.stream_poll_timeout)
if not chunk:
cached_content = None
if session and session.msg_id:
cached_content = self.generated_content.pop(session.msg_id, None)
if cached_content is not None:
chunk = StreamChunk(content=cached_content, is_final=True)
else:
payload = self._build_stream_payload(stream_id, '', False)
return await self._encrypt_and_reply(payload, nonce)
payload = self._build_stream_payload(stream_id, chunk.content, chunk.is_final)
if chunk.is_final:
self.stream_sessions.mark_finished(stream_id)
return await self._encrypt_and_reply(payload, nonce)
async def handle_callback_request(self):
"""企业微信回调入口。
Returns:
Quart Response: 根据请求类型返回验证、首包或刷新结果。
Example:
作为 Quart 路由处理函数直接注册并使用。
"""
try:
self.wxcpt = WXBizMsgCrypt(self.Token, self.EnCodingAESKey, '')
await self.logger.info(f'{request.method} {request.url} {str(request.args)}')
if request.method == 'GET':
return await self._handle_get_callback()
if request.method == 'POST':
return await self._handle_post_callback()
return Response('', status=405)
except Exception:
await self.logger.error(traceback.format_exc())
return Response('Internal Server Error', status=500)
async def _handle_get_callback(self) -> tuple[Response, int] | Response:
"""处理企业微信的 GET 验证请求。"""
msg_signature = unquote(request.args.get('msg_signature', ''))
timestamp = unquote(request.args.get('timestamp', ''))
nonce = unquote(request.args.get('nonce', ''))
echostr = unquote(request.args.get('echostr', ''))
if not all([msg_signature, timestamp, nonce, echostr]):
await self.logger.error('请求参数缺失')
return Response('缺少参数', status=400)
ret, decrypted_str = self.wxcpt.VerifyURL(msg_signature, timestamp, nonce, echostr)
if ret != 0:
await self.logger.error('验证URL失败')
return Response('验证失败', status=403)
return Response(decrypted_str, mimetype='text/plain')
async def _handle_post_callback(self) -> tuple[Response, int] | Response:
"""处理企业微信的 POST 回调请求。"""
self.stream_sessions.cleanup()
msg_signature = unquote(request.args.get('msg_signature', ''))
timestamp = unquote(request.args.get('timestamp', ''))
nonce = unquote(request.args.get('nonce', ''))
encrypted_json = await request.get_json()
encrypted_msg = (encrypted_json or {}).get('encrypt', '')
if not encrypted_msg:
await self.logger.error("请求体中缺少 'encrypt' 字段")
return Response('Bad Request', status=400)
xml_post_data = f"<xml><Encrypt><![CDATA[{encrypted_msg}]]></Encrypt></xml>"
ret, decrypted_xml = self.wxcpt.DecryptMsg(xml_post_data, msg_signature, timestamp, nonce)
if ret != 0:
await self.logger.error('解密失败')
return Response('解密失败', status=400)
msg_json = json.loads(decrypted_xml)
if msg_json.get('msgtype') == 'stream':
return await self._handle_post_followup_response(msg_json, nonce)
return await self._handle_post_initial_response(msg_json, nonce)
async def get_message(self, msg_json):
message_data = {}
if msg_json.get('chattype','') == 'single':
if msg_json.get('chattype', '') == 'single':
message_data['type'] = 'single'
elif msg_json.get('chattype','') == 'group':
elif msg_json.get('chattype', '') == 'group':
message_data['type'] = 'group'
if msg_json.get('msgtype') == 'text':
message_data['content'] = msg_json.get('text',{}).get('content')
message_data['content'] = msg_json.get('text', {}).get('content')
elif msg_json.get('msgtype') == 'image':
picurl = msg_json.get('image', {}).get('url','')
base64 = await self.download_url_to_base64(picurl,self.EnCodingAESKey)
message_data['picurl'] = base64
picurl = msg_json.get('image', {}).get('url', '')
base64 = await self.download_url_to_base64(picurl, self.EnCodingAESKey)
message_data['picurl'] = base64
elif msg_json.get('msgtype') == 'mixed':
items = msg_json.get('mixed', {}).get('msg_item', [])
texts = []
@@ -197,17 +460,27 @@ class WecomBotClient:
if texts:
message_data['content'] = "".join(texts) # 拼接所有 text
if picurl:
base64 = await self.download_url_to_base64(picurl,self.EnCodingAESKey)
message_data['picurl'] = base64 # 只保留第一个 image
base64 = await self.download_url_to_base64(picurl, self.EnCodingAESKey)
message_data['picurl'] = base64 # 只保留第一个 image
# Extract user information
from_info = msg_json.get('from', {})
message_data['userid'] = from_info.get('userid', '')
message_data['username'] = from_info.get('alias', '') or from_info.get('name', '') or from_info.get('userid', '')
# Extract chat/group information
if msg_json.get('chattype', '') == 'group':
message_data['chatid'] = msg_json.get('chatid', '')
# Try to get group name if available
message_data['chatname'] = msg_json.get('chatname', '') or msg_json.get('chatid', '')
message_data['userid'] = msg_json.get('from', {}).get('userid', '')
message_data['msgid'] = msg_json.get('msgid', '')
if msg_json.get('aibotid'):
message_data['aibotid'] = msg_json.get('aibotid', '')
return message_data
async def _handle_message(self, event: wecombotevent.WecomBotEvent):
"""
处理消息事件。
@@ -223,10 +496,46 @@ class WecomBotClient:
for handler in self._message_handlers[msg_type]:
await handler(event)
except Exception:
print(traceback.format_exc())
print(traceback.format_exc())
async def push_stream_chunk(self, msg_id: str, content: str, is_final: bool = False) -> bool:
"""将流水线片段推送到 stream 会话。
Args:
msg_id: 原始企业微信消息 ID。
content: 模型产生的片段内容。
is_final: 是否为最终片段。
Returns:
bool: 当成功写入流式队列时返回 True。
Example:
在流水线 `reply_message_chunk` 中调用,将增量推送至企业微信。
"""
# 根据 msg_id 找到对应 stream 会话,如果不存在说明当前消息非流式
stream_id = self.stream_sessions.get_stream_id_by_msg(msg_id)
if not stream_id:
return False
chunk = StreamChunk(content=content, is_final=is_final)
await self.stream_sessions.publish(stream_id, chunk)
if is_final:
self.stream_sessions.mark_finished(stream_id)
return True
async def set_message(self, msg_id: str, content: str):
self.generated_content[msg_id] = content
"""兼容旧逻辑:若无法流式返回则缓存最终结果。
Args:
msg_id: 企业微信消息 ID。
content: 最终回复的文本内容。
Example:
在非流式场景下缓存最终结果以备刷新时返回。
"""
handled = await self.push_stream_chunk(msg_id, content, is_final=True)
if not handled:
self.generated_content[msg_id] = content
def on_message(self, msg_type: str):
def decorator(func: Callable[[wecombotevent.WecomBotEvent], None]):
@@ -237,7 +546,6 @@ class WecomBotClient:
return decorator
async def download_url_to_base64(self, download_url, encoding_aes_key):
async with httpx.AsyncClient() as client:
response = await client.get(download_url)
@@ -247,26 +555,22 @@ class WecomBotClient:
encrypted_bytes = response.content
aes_key = base64.b64decode(encoding_aes_key + "=") # base64 补齐
iv = aes_key[:16]
cipher = AES.new(aes_key, AES.MODE_CBC, iv)
decrypted = cipher.decrypt(encrypted_bytes)
pad_len = decrypted[-1]
decrypted = decrypted[:-pad_len]
if decrypted.startswith(b"\xff\xd8"): # JPEG
if decrypted.startswith(b"\xff\xd8"): # JPEG
mime_type = "image/jpeg"
elif decrypted.startswith(b"\x89PNG"): # PNG
mime_type = "image/png"
elif decrypted.startswith((b"GIF87a", b"GIF89a")): # GIF
mime_type = "image/gif"
elif decrypted.startswith(b"BM"): # BMP
elif decrypted.startswith(b"BM"): # BMP
mime_type = "image/bmp"
elif decrypted.startswith(b"II*\x00") or decrypted.startswith(b"MM\x00*"): # TIFF
mime_type = "image/tiff"
@@ -276,15 +580,9 @@ class WecomBotClient:
# 转 base64
base64_str = base64.b64encode(decrypted).decode("utf-8")
return f"data:{mime_type};base64,{base64_str}"
async def run_task(self, host: str, port: int, *args, **kwargs):
"""
启动 Quart 应用。
"""
await self.app.run_task(host=host, port=port, *args, **kwargs)

View File

@@ -22,7 +22,21 @@ class WecomBotEvent(dict):
"""
用户id
"""
return self.get('from', {}).get('userid', '')
return self.get('from', {}).get('userid', '') or self.get('userid', '')
@property
def username(self) -> str:
"""
用户名称
"""
return self.get('username', '') or self.get('from', {}).get('alias', '') or self.get('from', {}).get('name', '') or self.userid
@property
def chatname(self) -> str:
"""
群组名称
"""
return self.get('chatname', '') or str(self.chatid)
@property
def content(self) -> str:

13
main.py
View File

@@ -18,7 +18,13 @@ asciiart = r"""
async def main_entry(loop: asyncio.AbstractEventLoop):
parser = argparse.ArgumentParser(description='LangBot')
parser.add_argument('--standalone-runtime', action='store_true', help='使用独立插件运行时', default=False)
parser.add_argument(
'--standalone-runtime',
action='store_true',
help='Use standalone plugin runtime / 使用独立插件运行时',
default=False,
)
parser.add_argument('--debug', action='store_true', help='Debug mode / 调试模式', default=False)
args = parser.parse_args()
if args.standalone_runtime:
@@ -26,6 +32,11 @@ async def main_entry(loop: asyncio.AbstractEventLoop):
platform.standalone_runtime = True
if args.debug:
from pkg.utils import constants
constants.debug_mode = True
print(asciiart)
import sys

View File

@@ -15,6 +15,9 @@ class FilesRouterGroup(group.RouterGroup):
async def initialize(self) -> None:
@self.route('/image/<image_key>', methods=['GET'], auth_type=group.AuthType.NONE)
async def _(image_key: str) -> quart.Response:
if '/' in image_key or '\\' in image_key:
return quart.Response(status=404)
if not await self.ap.storage_mgr.storage_provider.exists(image_key):
return quart.Response(status=404)
@@ -36,6 +39,10 @@ class FilesRouterGroup(group.RouterGroup):
extension = file.filename.split('.')[-1]
file_name = file.filename.split('.')[0]
# check if file name contains '/' or '\'
if '/' in file_name or '\\' in file_name:
return self.fail(400, 'File name contains invalid characters')
file_key = file_name + '_' + str(uuid.uuid4())[:8] + '.' + extension
# save file to storage
await self.ap.storage_mgr.storage_provider.save(file_key, file_bytes)

View File

@@ -2,6 +2,10 @@ from __future__ import annotations
import base64
import quart
import re
import httpx
import uuid
import os
from .....core import taskmgr
from .. import group
@@ -45,9 +49,12 @@ class PluginsRouterGroup(group.RouterGroup):
return self.http_status(404, -1, 'plugin not found')
return self.success(data={'plugin': plugin})
elif quart.request.method == 'DELETE':
delete_data = quart.request.args.get('delete_data', 'false').lower() == 'true'
ctx = taskmgr.TaskContext.new()
wrapper = self.ap.task_mgr.create_user_task(
self.ap.plugin_connector.delete_plugin(author, plugin_name, task_context=ctx),
self.ap.plugin_connector.delete_plugin(
author, plugin_name, delete_data=delete_data, task_context=ctx
),
kind='plugin-operation',
name=f'plugin-remove-{plugin_name}',
label=f'Removing plugin {plugin_name}',
@@ -89,23 +96,145 @@ class PluginsRouterGroup(group.RouterGroup):
return quart.Response(icon_data, mimetype=mime_type)
@self.route('/github/releases', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str:
"""Get releases from a GitHub repository URL"""
data = await quart.request.json
repo_url = data.get('repo_url', '')
# Parse GitHub repository URL to extract owner and repo
# Supports: https://github.com/owner/repo or github.com/owner/repo
pattern = r'github\.com/([^/]+)/([^/]+?)(?:\.git)?(?:/.*)?$'
match = re.search(pattern, repo_url)
if not match:
return self.http_status(400, -1, 'Invalid GitHub repository URL')
owner, repo = match.groups()
try:
# Fetch releases from GitHub API
url = f'https://api.github.com/repos/{owner}/{repo}/releases'
async with httpx.AsyncClient(
trust_env=True,
follow_redirects=True,
timeout=10,
) as client:
response = await client.get(url)
response.raise_for_status()
releases = response.json()
# Format releases data for frontend
formatted_releases = []
for release in releases:
formatted_releases.append(
{
'id': release['id'],
'tag_name': release['tag_name'],
'name': release['name'],
'published_at': release['published_at'],
'prerelease': release['prerelease'],
'draft': release['draft'],
}
)
return self.success(data={'releases': formatted_releases, 'owner': owner, 'repo': repo})
except httpx.RequestError as e:
return self.http_status(500, -1, f'Failed to fetch releases: {str(e)}')
@self.route(
'/github/release-assets',
methods=['POST'],
auth_type=group.AuthType.USER_TOKEN,
)
async def _() -> str:
"""Get assets from a specific GitHub release"""
data = await quart.request.json
owner = data.get('owner', '')
repo = data.get('repo', '')
release_id = data.get('release_id', '')
if not all([owner, repo, release_id]):
return self.http_status(400, -1, 'Missing required parameters')
try:
# Fetch release assets from GitHub API
url = f'https://api.github.com/repos/{owner}/{repo}/releases/{release_id}'
async with httpx.AsyncClient(
trust_env=True,
follow_redirects=True,
timeout=10,
) as client:
response = await client.get(
url,
)
response.raise_for_status()
release = response.json()
# Format assets data for frontend
formatted_assets = []
for asset in release.get('assets', []):
formatted_assets.append(
{
'id': asset['id'],
'name': asset['name'],
'size': asset['size'],
'download_url': asset['browser_download_url'],
'content_type': asset['content_type'],
}
)
# add zipball as a downloadable asset
# formatted_assets.append(
# {
# "id": 0,
# "name": "Source code (zip)",
# "size": -1,
# "download_url": release["zipball_url"],
# "content_type": "application/zip",
# }
# )
return self.success(data={'assets': formatted_assets})
except httpx.RequestError as e:
return self.http_status(500, -1, f'Failed to fetch release assets: {str(e)}')
@self.route('/install/github', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str:
"""Install plugin from GitHub release asset"""
data = await quart.request.json
asset_url = data.get('asset_url', '')
owner = data.get('owner', '')
repo = data.get('repo', '')
release_tag = data.get('release_tag', '')
if not asset_url:
return self.http_status(400, -1, 'Missing asset_url parameter')
ctx = taskmgr.TaskContext.new()
short_source_str = data['source'][-8:]
install_info = {
'asset_url': asset_url,
'owner': owner,
'repo': repo,
'release_tag': release_tag,
'github_url': f'https://github.com/{owner}/{repo}',
}
wrapper = self.ap.task_mgr.create_user_task(
self.ap.plugin_mgr.install_plugin(data['source'], task_context=ctx),
self.ap.plugin_connector.install_plugin(PluginInstallSource.GITHUB, install_info, task_context=ctx),
kind='plugin-operation',
name='plugin-install-github',
label=f'Installing plugin from github ...{short_source_str}',
label=f'Installing plugin from GitHub {owner}/{repo}@{release_tag}',
context=ctx,
)
return self.success(data={'task_id': wrapper.id})
@self.route('/install/marketplace', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
@self.route(
'/install/marketplace',
methods=['POST'],
auth_type=group.AuthType.USER_TOKEN,
)
async def _() -> str:
data = await quart.request.json
@@ -128,10 +257,8 @@ class PluginsRouterGroup(group.RouterGroup):
file_bytes = file.read()
file_base64 = base64.b64encode(file_bytes).decode('utf-8')
data = {
'plugin_file': file_base64,
'plugin_file': file_bytes,
}
ctx = taskmgr.TaskContext.new()
@@ -144,3 +271,39 @@ class PluginsRouterGroup(group.RouterGroup):
)
return self.success(data={'task_id': wrapper.id})
@self.route('/config-files', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str:
"""Upload a file for plugin configuration"""
file = (await quart.request.files).get('file')
if file is None:
return self.http_status(400, -1, 'file is required')
# Check file size (10MB limit)
MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB
file_bytes = file.read()
if len(file_bytes) > MAX_FILE_SIZE:
return self.http_status(400, -1, 'file size exceeds 10MB limit')
# Generate unique file key with original extension
original_filename = file.filename
_, ext = os.path.splitext(original_filename)
file_key = f'plugin_config_{uuid.uuid4().hex}{ext}'
# Save file using storage manager
await self.ap.storage_mgr.storage_provider.save(file_key, file_bytes)
return self.success(data={'file_key': file_key})
@self.route('/config-files/<file_key>', methods=['DELETE'], auth_type=group.AuthType.USER_TOKEN)
async def _(file_key: str) -> str:
"""Delete a plugin configuration file"""
# Only allow deletion of files with plugin_config_ prefix for security
if not file_key.startswith('plugin_config_'):
return self.http_status(400, -1, 'invalid file key')
try:
await self.ap.storage_mgr.storage_provider.delete(file_key)
return self.success(data={'deleted': True})
except Exception as e:
return self.http_status(500, -1, f'failed to delete file: {str(e)}')

View File

@@ -0,0 +1,62 @@
from __future__ import annotations
import quart
import traceback
from ... import group
@group.group_class('mcp', '/api/v1/mcp')
class MCPRouterGroup(group.RouterGroup):
async def initialize(self) -> None:
@self.route('/servers', methods=['GET', 'POST'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str:
"""获取MCP服务器列表"""
if quart.request.method == 'GET':
servers = await self.ap.mcp_service.get_mcp_servers(contain_runtime_info=True)
return self.success(data={'servers': servers})
elif quart.request.method == 'POST':
data = await quart.request.json
try:
uuid = await self.ap.mcp_service.create_mcp_server(data)
return self.success(data={'uuid': uuid})
except Exception as e:
traceback.print_exc()
return self.http_status(500, -1, f'Failed to create MCP server: {str(e)}')
@self.route('/servers/<server_name>', methods=['GET', 'PUT', 'DELETE'], auth_type=group.AuthType.USER_TOKEN)
async def _(server_name: str) -> str:
"""获取、更新或删除MCP服务器配置"""
server_data = await self.ap.mcp_service.get_mcp_server_by_name(server_name)
if server_data is None:
return self.http_status(404, -1, 'Server not found')
if quart.request.method == 'GET':
return self.success(data={'server': server_data})
elif quart.request.method == 'PUT':
data = await quart.request.json
try:
await self.ap.mcp_service.update_mcp_server(server_data['uuid'], data)
return self.success()
except Exception as e:
return self.http_status(500, -1, f'Failed to update MCP server: {str(e)}')
elif quart.request.method == 'DELETE':
try:
await self.ap.mcp_service.delete_mcp_server(server_data['uuid'])
return self.success()
except Exception as e:
return self.http_status(500, -1, f'Failed to delete MCP server: {str(e)}')
@self.route('/servers/<server_name>/test', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
async def _(server_name: str) -> str:
"""测试MCP服务器连接"""
server_data = await quart.request.json
task_id = await self.ap.mcp_service.test_mcp_server(server_name=server_name, server_data=server_data)
return self.success(data={'task_id': task_id})

View File

@@ -91,3 +91,26 @@ class SystemRouterGroup(group.RouterGroup):
)
return self.success(data=resp)
@self.route(
'/status/plugin-system',
methods=['GET'],
auth_type=group.AuthType.USER_TOKEN,
)
async def _() -> str:
plugin_connector_error = 'ok'
is_connected = True
try:
await self.ap.plugin_connector.ping_plugin_runtime()
except Exception as e:
plugin_connector_error = str(e)
is_connected = False
return self.success(
data={
'is_enable': self.ap.plugin_connector.is_enable_plugin,
'is_connected': is_connected,
'plugin_connector_error': plugin_connector_error,
}
)

View File

@@ -15,12 +15,14 @@ from .groups import provider as groups_provider
from .groups import platform as groups_platform
from .groups import pipelines as groups_pipelines
from .groups import knowledge as groups_knowledge
from .groups import resources as groups_resources
importutil.import_modules_in_pkg(groups)
importutil.import_modules_in_pkg(groups_provider)
importutil.import_modules_in_pkg(groups_platform)
importutil.import_modules_in_pkg(groups_pipelines)
importutil.import_modules_in_pkg(groups_knowledge)
importutil.import_modules_in_pkg(groups_resources)
class HTTPController:

158
pkg/api/http/service/mcp.py Normal file
View File

@@ -0,0 +1,158 @@
from __future__ import annotations
import sqlalchemy
import uuid
import asyncio
from ....core import app
from ....entity.persistence import mcp as persistence_mcp
from ....core import taskmgr
from ....provider.tools.loaders.mcp import RuntimeMCPSession, MCPSessionStatus
class MCPService:
ap: app.Application
def __init__(self, ap: app.Application) -> None:
self.ap = ap
async def get_runtime_info(self, server_name: str) -> dict | None:
session = self.ap.tool_mgr.mcp_tool_loader.get_session(server_name)
if session:
return session.get_runtime_info_dict()
return None
async def get_mcp_servers(self, contain_runtime_info: bool = False) -> list[dict]:
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_mcp.MCPServer))
servers = result.all()
serialized_servers = [
self.ap.persistence_mgr.serialize_model(persistence_mcp.MCPServer, server) for server in servers
]
if contain_runtime_info:
for server in serialized_servers:
runtime_info = await self.get_runtime_info(server['name'])
server['runtime_info'] = runtime_info if runtime_info else None
return serialized_servers
async def create_mcp_server(self, server_data: dict) -> str:
server_data['uuid'] = str(uuid.uuid4())
await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_mcp.MCPServer).values(server_data))
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_mcp.MCPServer).where(persistence_mcp.MCPServer.uuid == server_data['uuid'])
)
server_entity = result.first()
if server_entity:
server_config = self.ap.persistence_mgr.serialize_model(persistence_mcp.MCPServer, server_entity)
if self.ap.tool_mgr.mcp_tool_loader:
task = asyncio.create_task(self.ap.tool_mgr.mcp_tool_loader.host_mcp_server(server_config))
self.ap.tool_mgr.mcp_tool_loader._hosted_mcp_tasks.append(task)
return server_data['uuid']
async def get_mcp_server_by_name(self, server_name: str) -> dict | None:
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_mcp.MCPServer).where(persistence_mcp.MCPServer.name == server_name)
)
server = result.first()
if server is None:
return None
runtime_info = await self.get_runtime_info(server.name)
server_data = self.ap.persistence_mgr.serialize_model(persistence_mcp.MCPServer, server)
server_data['runtime_info'] = runtime_info if runtime_info else None
return server_data
async def update_mcp_server(self, server_uuid: str, server_data: dict) -> None:
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_mcp.MCPServer).where(persistence_mcp.MCPServer.uuid == server_uuid)
)
old_server = result.first()
old_server_name = old_server.name if old_server else None
old_enable = old_server.enable if old_server else False
await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_mcp.MCPServer)
.where(persistence_mcp.MCPServer.uuid == server_uuid)
.values(server_data)
)
if self.ap.tool_mgr.mcp_tool_loader:
new_enable = server_data.get('enable', False)
need_remove = old_server_name and old_server_name in self.ap.tool_mgr.mcp_tool_loader.sessions
need_start = new_enable
if old_enable and not new_enable:
if need_remove:
await self.ap.tool_mgr.mcp_tool_loader.remove_mcp_server(old_server_name)
elif not old_enable and new_enable:
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_mcp.MCPServer).where(persistence_mcp.MCPServer.uuid == server_uuid)
)
updated_server = result.first()
if updated_server:
server_config = self.ap.persistence_mgr.serialize_model(persistence_mcp.MCPServer, updated_server)
task = asyncio.create_task(self.ap.tool_mgr.mcp_tool_loader.host_mcp_server(server_config))
self.ap.tool_mgr.mcp_tool_loader._hosted_mcp_tasks.append(task)
elif old_enable and new_enable:
if need_remove:
await self.ap.tool_mgr.mcp_tool_loader.remove_mcp_server(old_server_name)
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_mcp.MCPServer).where(persistence_mcp.MCPServer.uuid == server_uuid)
)
updated_server = result.first()
if updated_server:
server_config = self.ap.persistence_mgr.serialize_model(persistence_mcp.MCPServer, updated_server)
task = asyncio.create_task(self.ap.tool_mgr.mcp_tool_loader.host_mcp_server(server_config))
self.ap.tool_mgr.mcp_tool_loader._hosted_mcp_tasks.append(task)
async def delete_mcp_server(self, server_uuid: str) -> None:
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_mcp.MCPServer).where(persistence_mcp.MCPServer.uuid == server_uuid)
)
server = result.first()
server_name = server.name if server else None
await self.ap.persistence_mgr.execute_async(
sqlalchemy.delete(persistence_mcp.MCPServer).where(persistence_mcp.MCPServer.uuid == server_uuid)
)
if server_name and self.ap.tool_mgr.mcp_tool_loader:
if server_name in self.ap.tool_mgr.mcp_tool_loader.sessions:
await self.ap.tool_mgr.mcp_tool_loader.remove_mcp_server(server_name)
async def test_mcp_server(self, server_name: str, server_data: dict) -> int:
"""测试 MCP 服务器连接并返回任务 ID"""
runtime_mcp_session: RuntimeMCPSession | None = None
if server_name != '_':
runtime_mcp_session = self.ap.tool_mgr.mcp_tool_loader.get_session(server_name)
if runtime_mcp_session is None:
raise ValueError(f'Server not found: {server_name}')
if runtime_mcp_session.status == MCPSessionStatus.ERROR:
coroutine = runtime_mcp_session.start()
else:
coroutine = runtime_mcp_session.refresh()
else:
runtime_mcp_session = await self.ap.tool_mgr.mcp_tool_loader.load_mcp_server(server_config=server_data)
coroutine = runtime_mcp_session.start()
ctx = taskmgr.TaskContext.new()
wrapper = self.ap.task_mgr.create_user_task(
coroutine,
kind='mcp-operation',
name=f'mcp-test-{server_name}',
label=f'Testing MCP server {server_name}',
context=ctx,
)
return wrapper.id

View File

@@ -1,13 +1,14 @@
from __future__ import annotations
import uuid
import sqlalchemy
from langbot_plugin.api.entities.builtin.provider import message as provider_message
from ....core import app
from ....entity.persistence import model as persistence_model
from ....entity.persistence import pipeline as persistence_pipeline
from ....provider.modelmgr import requester as model_requester
from langbot_plugin.api.entities.builtin.provider import message as provider_message
class LLMModelsService:
@@ -104,12 +105,17 @@ class LLMModelsService:
else:
runtime_llm_model = await self.ap.model_mgr.init_runtime_llm_model(model_data)
# 有些模型厂商默认开启了思考功能,测试容易延迟
extra_args = model_data.get('extra_args', {})
if not extra_args or 'thinking' not in extra_args:
extra_args['thinking'] = {'type': 'disabled'}
await runtime_llm_model.requester.invoke_llm(
query=None,
model=runtime_llm_model,
messages=[provider_message.Message(role='user', content='Hello, world!')],
messages=[provider_message.Message(role='user', content='Hello, world! Please just reply a "Hello".')],
funcs=[],
extra_args=model_data.get('extra_args', {}),
extra_args=extra_args,
)

View File

@@ -22,6 +22,7 @@ from ..api.http.service import model as model_service
from ..api.http.service import pipeline as pipeline_service
from ..api.http.service import bot as bot_service
from ..api.http.service import knowledge as knowledge_service
from ..api.http.service import mcp as mcp_service
from ..discover import engine as discover_engine
from ..storage import mgr as storagemgr
from ..utils import logcache
@@ -119,6 +120,8 @@ class Application:
knowledge_service: knowledge_service.KnowledgeService = None
mcp_service: mcp_service.MCPService = None
def __init__(self):
pass

View File

@@ -19,6 +19,7 @@ from ...api.http.service import model as model_service
from ...api.http.service import pipeline as pipeline_service
from ...api.http.service import bot as bot_service
from ...api.http.service import knowledge as knowledge_service
from ...api.http.service import mcp as mcp_service
from ...discover import engine as discover_engine
from ...storage import mgr as storagemgr
from ...utils import logcache
@@ -126,5 +127,8 @@ class BuildAppStage(stage.BootingStage):
knowledge_service_inst = knowledge_service.KnowledgeService(ap)
ap.knowledge_service = knowledge_service_inst
mcp_service_inst = mcp_service.MCPService(ap)
ap.mcp_service = mcp_service_inst
ctrl = controller.Controller(ap)
ap.ctrl = ctrl

View File

@@ -1,11 +1,93 @@
from __future__ import annotations
import os
from typing import Any
from .. import stage, app
from ..bootutils import config
def _apply_env_overrides_to_config(cfg: dict) -> dict:
"""Apply environment variable overrides to data/config.yaml
Environment variables should be uppercase and use __ (double underscore)
to represent nested keys. For example:
- CONCURRENCY__PIPELINE overrides concurrency.pipeline
- PLUGIN__RUNTIME_WS_URL overrides plugin.runtime_ws_url
Arrays and dict types are ignored.
Args:
cfg: Configuration dictionary
Returns:
Updated configuration dictionary
"""
def convert_value(value: str, original_value: Any) -> Any:
"""Convert string value to appropriate type based on original value
Args:
value: String value from environment variable
original_value: Original value to infer type from
Returns:
Converted value (falls back to string if conversion fails)
"""
if isinstance(original_value, bool):
return value.lower() in ('true', '1', 'yes', 'on')
elif isinstance(original_value, int):
try:
return int(value)
except ValueError:
# If conversion fails, keep as string (user error, but non-breaking)
return value
elif isinstance(original_value, float):
try:
return float(value)
except ValueError:
# If conversion fails, keep as string (user error, but non-breaking)
return value
else:
return value
# Process environment variables
for env_key, env_value in os.environ.items():
# Check if the environment variable is uppercase and contains __
if not env_key.isupper():
continue
if '__' not in env_key:
continue
print(f'apply env overrides to config: env_key: {env_key}, env_value: {env_value}')
# Convert environment variable name to config path
# e.g., CONCURRENCY__PIPELINE -> ['concurrency', 'pipeline']
keys = [key.lower() for key in env_key.split('__')]
# Navigate to the target value and validate the path
current = cfg
for i, key in enumerate(keys):
if not isinstance(current, dict) or key not in current:
break
if i == len(keys) - 1:
# At the final key - check if it's a scalar value
if isinstance(current[key], (dict, list)):
# Skip dict and list types
pass
else:
# Valid scalar value - convert and set it
converted_value = convert_value(env_value, current[key])
current[key] = converted_value
else:
# Navigate deeper
current = current[key]
return cfg
@stage.stage_class('LoadConfigStage')
class LoadConfigStage(stage.BootingStage):
"""Load config file stage"""
@@ -54,6 +136,10 @@ class LoadConfigStage(stage.BootingStage):
ap.instance_config = await config.load_yaml_config(
'data/config.yaml', 'templates/config.yaml', completion=False
)
# Apply environment variable overrides to data/config.yaml
ap.instance_config.data = _apply_env_overrides_to_config(ap.instance_config.data)
await ap.instance_config.dump_config()
ap.sensitive_meta = await config.load_json_config(

View File

@@ -156,7 +156,7 @@ class TaskWrapper:
'state': self.task._state,
'exception': self.assume_exception().__str__() if self.assume_exception() is not None else None,
'exception_traceback': exception_traceback,
'result': self.assume_result().__str__() if self.assume_result() is not None else None,
'result': self.assume_result() if self.assume_result() is not None else None,
},
}

View File

@@ -0,0 +1,20 @@
import sqlalchemy
from .base import Base
class MCPServer(Base):
__tablename__ = 'mcp_servers'
uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True)
name = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
enable = sqlalchemy.Column(sqlalchemy.Boolean, nullable=False, default=False)
mode = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) # stdio, sse
extra_args = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default={})
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
updated_at = sqlalchemy.Column(
sqlalchemy.DateTime,
nullable=False,
server_default=sqlalchemy.func.now(),
onupdate=sqlalchemy.func.now(),
)

View File

@@ -21,10 +21,15 @@ class LongTextProcessStage(stage.PipelineStage):
- resp_message_chain
"""
strategy_impl: strategy.LongTextStrategy
strategy_impl: strategy.LongTextStrategy | None
async def initialize(self, pipeline_config: dict):
config = pipeline_config['output']['long-text-processing']
if config['strategy'] == 'none':
self.strategy_impl = None
return
if config['strategy'] == 'image':
use_font = config['font-path']
try:
@@ -67,6 +72,10 @@ class LongTextProcessStage(stage.PipelineStage):
await self.strategy_impl.initialize()
async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
if self.strategy_impl is None:
self.ap.logger.debug('Long message processing strategy is not set, skip long message processing.')
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
# 检查是否包含非 Plain 组件
contains_non_plain = False

View File

@@ -26,7 +26,7 @@ class ForwardComponentStrategy(strategy_model.LongTextStrategy):
platform_message.ForwardMessageNode(
sender_id=query.adapter.bot_account_id,
sender_name='User',
message_chain=platform_message.MessageChain([message]),
message_chain=platform_message.MessageChain([platform_message.Plain(text=message)]),
)
]

View File

@@ -96,7 +96,7 @@ class RuntimePipeline:
if query.pipeline_config['output']['misc']['at-sender'] and isinstance(
query.message_event, platform_events.GroupMessage
):
result.user_notice.insert(0, platform_message.At(query.message_event.sender.id))
result.user_notice.insert(0, platform_message.At(target=query.message_event.sender.id))
if await query.adapter.is_stream_output_supported():
await query.adapter.reply_message_chunk(
message_source=query.message_event,
@@ -213,7 +213,7 @@ class RuntimePipeline:
await self._execute_from_stage(0, query)
except Exception as e:
inst_name = query.current_stage_name if query.current_stage_name else 'unknown'
self.ap.logger.error(f'处理请求时出错 query_id={query.query_id} stage={inst_name} : {e}')
self.ap.logger.error(f'Error processing query {query.query_id} stage={inst_name} : {e}')
self.ap.logger.error(f'Traceback: {traceback.format_exc()}')
finally:
self.ap.logger.debug(f'Query {query.query_id} processed')

View File

@@ -35,11 +35,17 @@ class PreProcessor(stage.PipelineStage):
session = await self.ap.sess_mgr.get_session(query)
# When not local-agent, llm_model is None
llm_model = (
await self.ap.model_mgr.get_model_by_uuid(query.pipeline_config['ai']['local-agent']['model'])
if selected_runner == 'local-agent'
else None
)
try:
llm_model = (
await self.ap.model_mgr.get_model_by_uuid(query.pipeline_config['ai']['local-agent']['model'])
if selected_runner == 'local-agent'
else None
)
except ValueError:
self.ap.logger.warning(
f'LLM model {query.pipeline_config["ai"]["local-agent"]["model"] + " "}not found or not configured'
)
llm_model = None
conversation = await self.ap.sess_mgr.get_conversation(
query,
@@ -54,7 +60,7 @@ class PreProcessor(stage.PipelineStage):
query.prompt = conversation.prompt.copy()
query.messages = conversation.messages.copy()
if selected_runner == 'local-agent':
if selected_runner == 'local-agent' and llm_model:
query.use_funcs = []
query.use_llm_model_uuid = llm_model.model_entity.uuid
@@ -72,7 +78,11 @@ class PreProcessor(stage.PipelineStage):
# Check if this model supports vision, if not, remove all images
# TODO this checking should be performed in runner, and in this stage, the image should be reserved
if selected_runner == 'local-agent' and not llm_model.model_entity.abilities.__contains__('vision'):
if (
selected_runner == 'local-agent'
and llm_model
and not llm_model.model_entity.abilities.__contains__('vision')
):
for msg in query.messages:
if isinstance(msg.content, list):
for me in msg.content:
@@ -89,15 +99,22 @@ class PreProcessor(stage.PipelineStage):
content_list.append(provider_message.ContentElement.from_text(me.text))
plain_text += me.text
elif isinstance(me, platform_message.Image):
if selected_runner != 'local-agent' or llm_model.model_entity.abilities.__contains__('vision'):
if selected_runner != 'local-agent' or (
llm_model and llm_model.model_entity.abilities.__contains__('vision')
):
if me.base64 is not None:
content_list.append(provider_message.ContentElement.from_image_base64(me.base64))
elif isinstance(me, platform_message.File):
# if me.url is not None:
content_list.append(provider_message.ContentElement.from_file_url(me.url, me.name))
elif isinstance(me, platform_message.Quote) and qoute_msg:
for msg in me.origin:
if isinstance(msg, platform_message.Plain):
content_list.append(provider_message.ContentElement.from_text(msg.text))
elif isinstance(msg, platform_message.Image):
if selected_runner != 'local-agent' or llm_model.model_entity.abilities.__contains__('vision'):
if selected_runner != 'local-agent' or (
llm_model and llm_model.model_entity.abilities.__contains__('vision')
):
if msg.base64 is not None:
content_list.append(provider_message.ContentElement.from_image_base64(msg.base64))

View File

@@ -9,7 +9,6 @@ from .. import handler
from ... import entities
from ....provider import runner as runner_module
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.events as events
from ....utils import importutil
from ....provider import runners
@@ -47,18 +46,19 @@ class ChatMessageHandler(handler.MessageHandler):
event_ctx = await self.ap.plugin_connector.emit_event(event)
is_create_card = False # 判断下是否需要创建流式卡片
if event_ctx.is_prevented_default():
if event_ctx.event.reply is not None:
mc = platform_message.MessageChain(event_ctx.event.reply)
if event_ctx.event.reply_message_chain is not None:
mc = event_ctx.event.reply_message_chain
query.resp_messages.append(mc)
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
else:
yield entities.StageProcessResult(result_type=entities.ResultType.INTERRUPT, new_query=query)
else:
if event_ctx.event.alter is not None:
if event_ctx.event.user_message_alter is not None:
# if isinstance(event_ctx.event, str): # 现在暂时不考虑多模态alter
query.user_message.content = event_ctx.event.alter
query.user_message.content = event_ctx.event.user_message_alter
text_length = 0
try:

View File

@@ -5,7 +5,6 @@ import typing
from .. import handler
from ... import entities
import langbot_plugin.api.entities.builtin.provider.message as provider_message
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.provider.session as provider_session
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.events as events
@@ -49,8 +48,8 @@ class CommandHandler(handler.MessageHandler):
event_ctx = await self.ap.plugin_connector.emit_event(event)
if event_ctx.is_prevented_default():
if event_ctx.event.reply is not None:
mc = platform_message.MessageChain(event_ctx.event.reply)
if event_ctx.event.reply_message_chain is not None:
mc = event_ctx.event.reply_message_chain
query.resp_messages.append(mc)
@@ -59,9 +58,6 @@ class CommandHandler(handler.MessageHandler):
yield entities.StageProcessResult(result_type=entities.ResultType.INTERRUPT, new_query=query)
else:
if event_ctx.event.alter is not None:
query.message_chain = platform_message.MessageChain([platform_message.Plain(event_ctx.event.alter)])
session = await self.ap.sess_mgr.get_session(query)
async for ret in self.ap.cmd_mgr.execute(
@@ -78,7 +74,12 @@ class CommandHandler(handler.MessageHandler):
self.ap.logger.info(f'Command({query.query_id}) error: {self.cut_str(str(ret.error))}')
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
elif ret.text is not None or ret.image_url is not None or ret.image_base64 is not None:
elif (
ret.text is not None
or ret.image_url is not None
or ret.image_base64 is not None
or ret.file_url is not None
):
content: list[provider_message.ContentElement] = []
if ret.text is not None:
@@ -90,6 +91,9 @@ class CommandHandler(handler.MessageHandler):
if ret.image_base64 is not None:
content.append(provider_message.ContentElement.from_image_base64(ret.image_base64))
if ret.file_url is not None:
# 此时为 file 类型
content.append(provider_message.ContentElement.from_file_url(ret.file_url, ret.file_name))
query.resp_messages.append(
provider_message.Message(
role='command',

View File

@@ -33,7 +33,7 @@ class SendResponseBackStage(stage.PipelineStage):
if query.pipeline_config['output']['misc']['at-sender'] and isinstance(
query.message_event, platform_events.GroupMessage
):
query.resp_message_chain[-1].insert(0, platform_message.At(query.message_event.sender.id))
query.resp_message_chain[-1].insert(0, platform_message.At(target=query.message_event.sender.id))
quote_origin = query.pipeline_config['output']['misc']['quote-origin']

View File

@@ -16,26 +16,17 @@ class AtBotRule(rule_model.GroupRespondRule):
rule_dict: dict,
query: pipeline_query.Query,
) -> entities.RuleJudgeResult:
found = False
def remove_at(message_chain: platform_message.MessageChain):
nonlocal found
for component in message_chain.root:
if isinstance(component, platform_message.At) and component.target == query.adapter.bot_account_id:
if isinstance(component, platform_message.At) and str(component.target) == str(query.adapter.bot_account_id):
message_chain.remove(component)
found = True
break
remove_at(message_chain)
remove_at(message_chain) # 回复消息时会at两次检查并删除重复的
# if message_chain.has(platform_message.At(query.adapter.bot_account_id)) and rule_dict['at']:
# message_chain.remove(platform_message.At(query.adapter.bot_account_id))
# if message_chain.has(
# platform_message.At(query.adapter.bot_account_id)
# ): # 回复消息时会at两次检查并删除重复的
# message_chain.remove(platform_message.At(query.adapter.bot_account_id))
# return entities.RuleJudgeResult(
# matching=True,
# replacement=message_chain,
# )
return entities.RuleJudgeResult(matching=False, replacement=message_chain)
return entities.RuleJudgeResult(matching=found, replacement=message_chain)

View File

@@ -80,8 +80,8 @@ class ResponseWrapper(stage.PipelineStage):
new_query=query,
)
else:
if event_ctx.event.reply is not None:
query.resp_message_chain.append(platform_message.MessageChain(event_ctx.event.reply))
if event_ctx.event.reply_message_chain is not None:
query.resp_message_chain.append(event_ctx.event.reply_message_chain)
else:
query.resp_message_chain.append(result.get_content_platform_message_chain())
@@ -123,10 +123,8 @@ class ResponseWrapper(stage.PipelineStage):
new_query=query,
)
else:
if event_ctx.event.reply is not None:
query.resp_message_chain.append(
platform_message.MessageChain(text=event_ctx.event.reply)
)
if event_ctx.event.reply_message_chain is not None:
query.resp_message_chain.append(event_ctx.event.reply_message_chain)
else:
query.resp_message_chain.append(

View File

@@ -41,6 +41,8 @@ class DingTalkMessageConverter(abstract_platform_adapter.AbstractMessageConverte
yiri_msg_list.append(platform_message.Plain(text=text_content))
if event.picture:
yiri_msg_list.append(platform_message.Image(base64=event.picture))
if event.file:
yiri_msg_list.append(platform_message.File(url=event.file, name=event.name))
if event.audio:
yiri_msg_list.append(platform_message.Voice(base64=event.audio))

View File

@@ -22,6 +22,7 @@ import langbot_plugin.api.definition.abstract.platform.event_logger as abstract_
from ..logger import EventLogger
# 语音功能相关异常定义
class VoiceConnectionError(Exception):
"""语音连接基础异常"""

View File

@@ -139,19 +139,15 @@ class QQOfficialAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter
event_converter: QQOfficialEventConverter = QQOfficialEventConverter()
def __init__(self, config: dict, logger: EventLogger):
self.config = config
self.logger = logger
bot = QQOfficialClient(
app_id=config['appid'], secret=config['secret'], token=config['token'], logger=logger
)
required_keys = [
'appid',
'secret',
]
missing_keys = [key for key in required_keys if key not in config]
if missing_keys:
raise command_errors.ParamNotEnoughError('QQ官方机器人缺少相关配置项请查看文档或联系管理员')
self.bot = QQOfficialClient(
app_id=config['appid'], secret=config['secret'], token=config['token'], logger=self.logger
super().__init__(
config=config,
logger=logger,
bot=bot,
bot_account_id=config['appid'],
)
async def reply_message(

View File

@@ -102,7 +102,7 @@ class TelegramEventConverter(abstract_platform_adapter.AbstractEventConverter):
sender=platform_entities.Friend(
id=event.effective_chat.id,
nickname=event.effective_chat.first_name,
remark=event.effective_chat.id,
remark=str(event.effective_chat.id),
),
message_chain=lb_message,
time=event.message.date.timestamp(),

View File

@@ -139,7 +139,7 @@ class WeChatPadMessageConverter(abstract_platform_adapter.AbstractMessageConvert
pattern = r'@\S{1,20}'
content_no_preifx = re.sub(pattern, '', content_no_preifx)
return platform_message.MessageChain([platform_message.Plain(content_no_preifx)])
return platform_message.MessageChain([platform_message.Plain(text=content_no_preifx)])
async def _handler_image(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain:
"""处理图像消息 (msg_type=3)"""
@@ -265,7 +265,7 @@ class WeChatPadMessageConverter(abstract_platform_adapter.AbstractMessageConvert
# 文本消息
try:
if '<msg>' not in quote_data:
quote_data_message_list.append(platform_message.Plain(quote_data))
quote_data_message_list.append(platform_message.Plain(text=quote_data))
else:
# 引用消息展开
quote_data_xml = ET.fromstring(quote_data)
@@ -280,7 +280,7 @@ class WeChatPadMessageConverter(abstract_platform_adapter.AbstractMessageConvert
quote_data_message_list.extend(await self._handler_compound(None, quote_data))
except Exception as e:
self.logger.error(f'处理引用消息异常 expcetion:{e}')
quote_data_message_list.append(platform_message.Plain(quote_data))
quote_data_message_list.append(platform_message.Plain(text=quote_data))
message_list.append(
platform_message.Quote(
sender_id=sender_id,
@@ -290,7 +290,7 @@ class WeChatPadMessageConverter(abstract_platform_adapter.AbstractMessageConvert
if len(user_data) > 0:
pattern = r'@\S{1,20}'
user_data = re.sub(pattern, '', user_data)
message_list.append(platform_message.Plain(user_data))
message_list.append(platform_message.Plain(text=user_data))
return platform_message.MessageChain(message_list)
@@ -543,7 +543,6 @@ class WeChatPadAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter)
] = {}
def __init__(self, config: dict, logger: EventLogger):
quart_app = quart.Quart(__name__)
message_converter = WeChatPadMessageConverter(config, logger)
@@ -551,15 +550,14 @@ class WeChatPadAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter)
bot = WeChatPadClient(config['wechatpad_url'], config['token'])
super().__init__(
config=config,
logger = logger,
quart_app = quart_app,
message_converter =message_converter,
event_converter = event_converter,
logger=logger,
quart_app=quart_app,
message_converter=message_converter,
event_converter=event_converter,
listeners={},
bot_account_id ='',
name="WeChatPad",
bot_account_id='',
name='WeChatPad',
bot=bot,
)
async def ws_message(self, data):

View File

@@ -49,7 +49,7 @@ class WecomBotEventConverter(abstract_platform_adapter.AbstractEventConverter):
return platform_events.FriendMessage(
sender=platform_entities.Friend(
id=event.userid,
nickname='',
nickname=event.username,
remark='',
),
message_chain=message_chain,
@@ -61,10 +61,10 @@ class WecomBotEventConverter(abstract_platform_adapter.AbstractEventConverter):
sender = platform_entities.GroupMember(
id=event.userid,
permission='MEMBER',
member_name=event.userid,
member_name=event.username,
group=platform_entities.Group(
id=str(event.chatid),
name='',
name=event.chatname,
permission=platform_entities.Permission.Member,
),
special_title='',
@@ -117,6 +117,50 @@ class WecomBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
content = await self.message_converter.yiri2target(message)
await self.bot.set_message(message_source.source_platform_object.message_id, content)
async def reply_message_chunk(
self,
message_source: platform_events.MessageEvent,
bot_message,
message: platform_message.MessageChain,
quote_origin: bool = False,
is_final: bool = False,
):
"""将流水线增量输出写入企业微信 stream 会话。
Args:
message_source: 流水线提供的原始消息事件。
bot_message: 当前片段对应的模型元信息(未使用)。
message: 需要回复的消息链。
quote_origin: 是否引用原消息(企业微信暂不支持)。
is_final: 标记当前片段是否为最终回复。
Returns:
dict: 包含 `stream` 键,标识写入是否成功。
Example:
在流水线 `reply_message_chunk` 调用中自动触发,无需手动调用。
"""
# 转换为纯文本(智能机器人当前协议仅支持文本流)
content = await self.message_converter.yiri2target(message)
msg_id = message_source.source_platform_object.message_id
# 将片段推送到 WecomBotClient 中的队列,返回值用于判断是否走降级逻辑
success = await self.bot.push_stream_chunk(msg_id, content, is_final=is_final)
if not success and is_final:
# 未命中流式队列时使用旧有 set_message 兜底
await self.bot.set_message(msg_id, content)
return {'stream': success}
async def is_stream_output_supported(self) -> bool:
"""智能机器人侧默认开启流式能力。
Returns:
bool: 恒定返回 True。
Example:
流水线执行阶段会调用此方法以确认是否启用流式。"""
return True
async def send_message(self, target_type, target_id, message):
pass

View File

@@ -6,19 +6,24 @@ from typing import Any
import typing
import os
import sys
import httpx
from async_lru import alru_cache
from ..core import app
from . import handler
from ..utils import platform
from langbot_plugin.runtime.io.controllers.stdio import client as stdio_client_controller
from langbot_plugin.runtime.io.controllers.stdio import (
client as stdio_client_controller,
)
from langbot_plugin.runtime.io.controllers.ws import client as ws_client_controller
from langbot_plugin.api.entities import events
from langbot_plugin.api.entities import context
import langbot_plugin.runtime.io.connection as base_connection
from langbot_plugin.api.definition.components.manifest import ComponentManifest
from langbot_plugin.api.entities.builtin.command import context as command_context
from langbot_plugin.api.entities.builtin.command import (
context as command_context,
errors as command_errors,
)
from langbot_plugin.runtime.plugin.mgr import PluginInstallSource
from ..core import taskmgr
@@ -32,6 +37,8 @@ class PluginRuntimeConnector:
handler_task: asyncio.Task
heartbeat_task: asyncio.Task | None = None
stdio_client_controller: stdio_client_controller.StdioClientController
ctrl: stdio_client_controller.StdioClientController | ws_client_controller.WebSocketClientController
@@ -40,6 +47,9 @@ class PluginRuntimeConnector:
[PluginRuntimeConnector], typing.Coroutine[typing.Any, typing.Any, None]
]
is_enable_plugin: bool = True
"""Mark if the plugin system is enabled"""
def __init__(
self,
ap: app.Application,
@@ -49,10 +59,26 @@ class PluginRuntimeConnector:
):
self.ap = ap
self.runtime_disconnect_callback = runtime_disconnect_callback
self.is_enable_plugin = self.ap.instance_config.data.get('plugin', {}).get('enable', True)
async def heartbeat_loop(self):
while True:
await asyncio.sleep(20)
try:
await self.ping_plugin_runtime()
self.ap.logger.debug('Heartbeat to plugin runtime success.')
except Exception as e:
self.ap.logger.debug(f'Failed to heartbeat to plugin runtime: {e}')
async def initialize(self):
if not self.is_enable_plugin:
self.ap.logger.info('Plugin system is disabled.')
return
async def new_connection_callback(connection: base_connection.Connection):
async def disconnect_callback(rchandler: handler.RuntimeConnectionHandler) -> bool:
async def disconnect_callback(
rchandler: handler.RuntimeConnectionHandler,
) -> bool:
if platform.get_platform() == 'docker' or platform.use_websocket_to_connect_plugin_runtime():
self.ap.logger.error('Disconnected from plugin runtime, trying to reconnect...')
await self.runtime_disconnect_callback(self)
@@ -64,6 +90,7 @@ class PluginRuntimeConnector:
return False
self.handler = handler.RuntimeConnectionHandler(connection, disconnect_callback, self.ap)
self.handler_task = asyncio.create_task(self.handler.run())
_ = await self.handler.ping()
self.ap.logger.info('Connected to plugin runtime.')
@@ -77,8 +104,14 @@ class PluginRuntimeConnector:
'runtime_ws_url', 'ws://langbot_plugin_runtime:5400/control/ws'
)
async def make_connection_failed_callback(ctrl: ws_client_controller.WebSocketClientController) -> None:
self.ap.logger.error('Failed to connect to plugin runtime, trying to reconnect...')
async def make_connection_failed_callback(
ctrl: ws_client_controller.WebSocketClientController,
exc: Exception = None,
) -> None:
if exc is not None:
self.ap.logger.error(f'Failed to connect to plugin runtime({ws_url}): {exc}')
else:
self.ap.logger.error(f'Failed to connect to plugin runtime({ws_url}), trying to reconnect...')
await self.runtime_disconnect_callback(self)
self.ctrl = ws_client_controller.WebSocketClientController(
@@ -98,17 +131,53 @@ class PluginRuntimeConnector:
)
task = self.ctrl.run(new_connection_callback)
if self.heartbeat_task is None:
self.heartbeat_task = asyncio.create_task(self.heartbeat_loop())
asyncio.create_task(task)
async def initialize_plugins(self):
pass
async def ping_plugin_runtime(self):
if not hasattr(self, 'handler'):
raise Exception('Plugin runtime is not connected')
return await self.handler.ping()
async def install_plugin(
self,
install_source: PluginInstallSource,
install_info: dict[str, Any],
task_context: taskmgr.TaskContext | None = None,
):
if install_source == PluginInstallSource.LOCAL:
# transfer file before install
file_bytes = install_info['plugin_file']
file_key = await self.handler.send_file(file_bytes, 'lbpkg')
install_info['plugin_file_key'] = file_key
del install_info['plugin_file']
self.ap.logger.info(f'Transfered file {file_key} to plugin runtime')
elif install_source == PluginInstallSource.GITHUB:
# download and transfer file
try:
async with httpx.AsyncClient(
trust_env=True,
follow_redirects=True,
timeout=20,
) as client:
response = await client.get(
install_info['asset_url'],
)
response.raise_for_status()
file_bytes = response.content
file_key = await self.handler.send_file(file_bytes, 'lbpkg')
install_info['plugin_file_key'] = file_key
self.ap.logger.info(f'Transfered file {file_key} to plugin runtime')
except Exception as e:
self.ap.logger.error(f'Failed to download file from GitHub: {e}')
raise Exception(f'Failed to download file from GitHub: {e}')
async for ret in self.handler.install_plugin(install_source.value, install_info):
current_action = ret.get('current_action', None)
if current_action is not None:
@@ -121,7 +190,10 @@ class PluginRuntimeConnector:
task_context.trace(trace)
async def upgrade_plugin(
self, plugin_author: str, plugin_name: str, task_context: taskmgr.TaskContext | None = None
self,
plugin_author: str,
plugin_name: str,
task_context: taskmgr.TaskContext | None = None,
) -> dict[str, Any]:
async for ret in self.handler.upgrade_plugin(plugin_author, plugin_name):
current_action = ret.get('current_action', None)
@@ -135,7 +207,11 @@ class PluginRuntimeConnector:
task_context.trace(trace)
async def delete_plugin(
self, plugin_author: str, plugin_name: str, task_context: taskmgr.TaskContext | None = None
self,
plugin_author: str,
plugin_name: str,
delete_data: bool = False,
task_context: taskmgr.TaskContext | None = None,
) -> dict[str, Any]:
async for ret in self.handler.delete_plugin(plugin_author, plugin_name):
current_action = ret.get('current_action', None)
@@ -148,7 +224,16 @@ class PluginRuntimeConnector:
if task_context is not None:
task_context.trace(trace)
# Clean up plugin settings and binary storage if requested
if delete_data:
if task_context is not None:
task_context.trace('Cleaning up plugin configuration and storage...')
await self.handler.cleanup_plugin_data(plugin_author, plugin_name)
async def list_plugins(self) -> list[dict[str, Any]]:
if not self.is_enable_plugin:
return []
return await self.handler.list_plugins()
async def get_plugin_info(self, author: str, plugin_name: str) -> dict[str, Any]:
@@ -167,21 +252,33 @@ class PluginRuntimeConnector:
) -> context.EventContext:
event_ctx = context.EventContext.from_event(event)
event_ctx_result = await self.handler.emit_event(event_ctx.model_dump(serialize_as_any=True))
if not self.is_enable_plugin:
return event_ctx
event_ctx_result = await self.handler.emit_event(event_ctx.model_dump(serialize_as_any=False))
event_ctx = context.EventContext.model_validate(event_ctx_result['event_context'])
return event_ctx
async def list_tools(self) -> list[ComponentManifest]:
if not self.is_enable_plugin:
return []
list_tools_data = await self.handler.list_tools()
return [ComponentManifest.model_validate(tool) for tool in list_tools_data]
async def call_tool(self, tool_name: str, parameters: dict[str, Any]) -> dict[str, Any]:
if not self.is_enable_plugin:
return {'error': 'Tool not found: plugin system is disabled'}
return await self.handler.call_tool(tool_name, parameters)
async def list_commands(self) -> list[ComponentManifest]:
if not self.is_enable_plugin:
return []
list_commands_data = await self.handler.list_commands()
return [ComponentManifest.model_validate(command) for command in list_commands_data]
@@ -189,6 +286,9 @@ class PluginRuntimeConnector:
async def execute_command(
self, command_ctx: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
if not self.is_enable_plugin:
yield command_context.CommandReturn(error=command_errors.CommandNotFoundError(command_ctx.command))
gen = self.handler.execute_command(command_ctx.model_dump(serialize_as_any=True))
async for ret in gen:
@@ -197,6 +297,10 @@ class PluginRuntimeConnector:
yield cmd_ret
def dispose(self):
if isinstance(self.ctrl, stdio_client_controller.StdioClientController):
if self.is_enable_plugin and isinstance(self.ctrl, stdio_client_controller.StdioClientController):
self.ap.logger.info('Terminating plugin runtime process...')
self.ctrl.process.terminate()
if self.heartbeat_task is not None:
self.heartbeat_task.cancel()
self.heartbeat_task = None

View File

@@ -56,7 +56,9 @@ class RuntimeConnectionHandler(handler.Handler):
.where(persistence_plugin.PluginSetting.plugin_name == plugin_name)
)
if result.first() is not None:
setting = result.first()
if setting is not None:
# delete plugin setting
await self.ap.persistence_mgr.execute_async(
sqlalchemy.delete(persistence_plugin.PluginSetting)
@@ -71,6 +73,10 @@ class RuntimeConnectionHandler(handler.Handler):
plugin_name=plugin_name,
install_source=install_source,
install_info=install_info,
# inherit from existing setting
enabled=setting.enabled if setting is not None else True,
priority=setting.priority if setting is not None else 0,
config=setting.config if setting is not None else {}, # noqa: F821
)
)
@@ -430,6 +436,25 @@ class RuntimeConnectionHandler(handler.Handler):
},
)
@self.action(RuntimeToLangBotAction.GET_CONFIG_FILE)
async def get_config_file(data: dict[str, Any]) -> handler.ActionResponse:
"""Get a config file by file key"""
file_key = data['file_key']
try:
# Load file from storage
file_bytes = await self.ap.storage_mgr.storage_provider.load(file_key)
return handler.ActionResponse.success(
data={
'file_base64': base64.b64encode(file_bytes).decode('utf-8'),
},
)
except Exception as e:
return handler.ActionResponse.error(
message=f'Failed to load config file {file_key}: {e}',
)
async def ping(self) -> dict[str, Any]:
"""Ping the runtime"""
return await self.call_action(
@@ -536,7 +561,7 @@ class RuntimeConnectionHandler(handler.Handler):
{
'event_context': event_context,
},
timeout=30,
timeout=60,
)
return result
@@ -546,7 +571,7 @@ class RuntimeConnectionHandler(handler.Handler):
result = await self.call_action(
LangBotToRuntimeAction.LIST_TOOLS,
{},
timeout=10,
timeout=20,
)
return result['tools']
@@ -560,7 +585,35 @@ class RuntimeConnectionHandler(handler.Handler):
'plugin_name': plugin_name,
},
)
return result
plugin_icon_file_key = result['plugin_icon_file_key']
mime_type = result['mime_type']
plugin_icon_bytes = await self.read_local_file(plugin_icon_file_key)
await self.delete_local_file(plugin_icon_file_key)
return {
'plugin_icon_base64': base64.b64encode(plugin_icon_bytes).decode('utf-8'),
'mime_type': mime_type,
}
async def cleanup_plugin_data(self, plugin_author: str, plugin_name: str) -> None:
"""Cleanup plugin settings and binary storage"""
# Delete plugin settings
await self.ap.persistence_mgr.execute_async(
sqlalchemy.delete(persistence_plugin.PluginSetting)
.where(persistence_plugin.PluginSetting.plugin_author == plugin_author)
.where(persistence_plugin.PluginSetting.plugin_name == plugin_name)
)
# Delete all binary storage for this plugin
owner = f'{plugin_author}/{plugin_name}'
await self.ap.persistence_mgr.execute_async(
sqlalchemy.delete(persistence_bstorage.BinaryStorage)
.where(persistence_bstorage.BinaryStorage.owner_type == 'plugin')
.where(persistence_bstorage.BinaryStorage.owner == owner)
)
async def call_tool(self, tool_name: str, parameters: dict[str, Any]) -> dict[str, Any]:
"""Call tool"""
@@ -570,7 +623,7 @@ class RuntimeConnectionHandler(handler.Handler):
'tool_name': tool_name,
'tool_parameters': parameters,
},
timeout=30,
timeout=60,
)
return result['tool_response']
@@ -591,7 +644,7 @@ class RuntimeConnectionHandler(handler.Handler):
{
'command_context': command_context,
},
timeout=30,
timeout=60,
)
async for ret in gen:

View File

@@ -59,7 +59,7 @@ class ModelManager:
try:
await self.load_llm_model(llm_model)
except provider_errors.RequesterNotFoundError as e:
self.ap.logger.warning(f'Requester {e.requester_name} not found, skipping model {llm_model.uuid}')
self.ap.logger.warning(f'Requester {e.requester_name} not found, skipping llm model {llm_model.uuid}')
except Exception as e:
self.ap.logger.error(f'Failed to load model {llm_model.uuid}: {e}\n{traceback.format_exc()}')
@@ -67,7 +67,14 @@ class ModelManager:
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.EmbeddingModel))
embedding_models = result.all()
for embedding_model in embedding_models:
await self.load_embedding_model(embedding_model)
try:
await self.load_embedding_model(embedding_model)
except provider_errors.RequesterNotFoundError as e:
self.ap.logger.warning(
f'Requester {e.requester_name} not found, skipping embedding model {embedding_model.uuid}'
)
except Exception as e:
self.ap.logger.error(f'Failed to load model {embedding_model.uuid}: {e}\n{traceback.format_exc()}')
async def init_runtime_llm_model(
self,
@@ -107,6 +114,9 @@ class ModelManager:
elif isinstance(model_info, dict):
model_info = persistence_model.EmbeddingModel(**model_info)
if model_info.requester not in self.requester_dict:
raise provider_errors.RequesterNotFoundError(model_info.requester)
requester_inst = self.requester_dict[model_info.requester](ap=self.ap, config=model_info.requester_config)
await requester_inst.initialize()

View File

@@ -1,9 +1,14 @@
from __future__ import annotations
import typing
import dashscope
import openai
from . import modelscopechatcmpl
from .. import requester
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.builtin.provider.message as provider_message
class BailianChatCompletions(modelscopechatcmpl.ModelScopeChatCompletions):
@@ -15,3 +20,211 @@ class BailianChatCompletions(modelscopechatcmpl.ModelScopeChatCompletions):
'base_url': 'https://dashscope.aliyuncs.com/compatible-mode/v1',
'timeout': 120,
}
async def _closure_stream(
self,
query: pipeline_query.Query,
req_messages: list[dict],
use_model: requester.RuntimeLLMModel,
use_funcs: list[resource_tool.LLMTool] = None,
extra_args: dict[str, typing.Any] = {},
remove_think: bool = False,
) -> provider_message.Message | typing.AsyncGenerator[provider_message.MessageChunk, None]:
self.client.api_key = use_model.token_mgr.get_token()
args = {}
args['model'] = use_model.model_entity.name
if use_funcs:
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
if tools:
args['tools'] = tools
# 设置此次请求中的messages
messages = req_messages.copy()
is_use_dashscope_call = False # 是否使用阿里原生库调用
is_enable_multi_model = True # 是否支持多轮对话
use_time_num = 0 # 模型已调用次数,防止存在多文件时重复调用
use_time_ids = [] # 已调用的ID列表
message_id = 0 # 记录消息序号
for msg in messages:
# print(msg)
if 'content' in msg and isinstance(msg['content'], list):
for me in msg['content']:
if me['type'] == 'image_base64':
me['image_url'] = {'url': me['image_base64']}
me['type'] = 'image_url'
del me['image_base64']
elif me['type'] == 'file_url' and '.' in me.get('file_name', ''):
# 1. 视频文件推理
# https://bailian.console.aliyun.com/?tab=doc#/doc/?type=model&url=2845871
file_type = me.get('file_name').lower().split('.')[-1]
if file_type in ['mp4', 'avi', 'mkv', 'mov', 'flv', 'wmv']:
me['type'] = 'video_url'
me['video_url'] = {'url': me['file_url']}
del me['file_url']
del me['file_name']
use_time_num +=1
use_time_ids.append(message_id)
is_enable_multi_model = False
# 2. 语音文件识别, 无法通过openai的audio字段传递暂时不支持
# https://bailian.console.aliyun.com/?tab=doc#/doc/?type=model&url=2979031
elif file_type in ['aac', 'amr', 'aiff', 'flac', 'm4a',
'mp3', 'mpeg', 'ogg', 'opus', 'wav', 'webm', 'wma']:
me['audio'] = me['file_url']
me['type'] = 'audio'
del me['file_url']
del me['type']
del me['file_name']
is_use_dashscope_call = True
use_time_num +=1
use_time_ids.append(message_id)
is_enable_multi_model = False
message_id += 1
# 使用列表推导式,保留不在 use_time_ids[:-1] 中的元素,仅保留最后一个多媒体消息
if not is_enable_multi_model and use_time_num > 1:
messages = [msg for idx, msg in enumerate(messages) if idx not in use_time_ids[:-1]]
if not is_enable_multi_model:
messages = [msg for msg in messages if 'resp_message_id' not in msg]
args['messages'] = messages
args['stream'] = True
# 流式处理状态
# tool_calls_map: dict[str, provider_message.ToolCall] = {}
chunk_idx = 0
thinking_started = False
thinking_ended = False
role = 'assistant' # 默认角色
if is_use_dashscope_call:
response = dashscope.MultiModalConversation.call(
# 若没有配置环境变量请用百炼API Key将下行替换为api_key = "sk-xxx"
api_key=use_model.token_mgr.get_token(),
model=use_model.model_entity.name,
messages=messages,
result_format="message",
asr_options={
# "language": "zh", # 可选,若已知音频的语种,可通过该参数指定待识别语种,以提升识别准确率
"enable_lid": True,
"enable_itn": False
},
stream=True
)
content_length_list = []
previous_length = 0 # 记录上一次的内容长度
for res in response:
chunk = res["output"]
# 解析 chunk 数据
if hasattr(chunk, 'choices') and chunk.choices:
choice = chunk.choices[0]
delta_content = choice["message"].content[0]["text"]
finish_reason = choice["finish_reason"]
content_length_list.append(len(delta_content))
else:
delta_content = ""
finish_reason = None
# 跳过空的第一个 chunk只有 role 没有内容)
if chunk_idx == 0 and not delta_content:
chunk_idx += 1
continue
# 检查 content_length_list 是否有足够的数据
if len(content_length_list) >= 2:
now_content = delta_content[previous_length: content_length_list[-1]]
previous_length = content_length_list[-1] # 更新上一次的长度
else:
now_content = delta_content # 第一次循环时直接使用 delta_content
previous_length = len(delta_content) # 更新上一次的长度
# 构建 MessageChunk - 只包含增量内容
chunk_data = {
'role': role,
'content': now_content if now_content else None,
'is_final': bool(finish_reason) and finish_reason != "null",
}
# 移除 None 值
chunk_data = {k: v for k, v in chunk_data.items() if v is not None}
yield provider_message.MessageChunk(**chunk_data)
chunk_idx += 1
else:
async for chunk in self._req_stream(args, extra_body=extra_args):
# 解析 chunk 数据
if hasattr(chunk, 'choices') and chunk.choices:
choice = chunk.choices[0]
delta = choice.delta.model_dump() if hasattr(choice, 'delta') else {}
finish_reason = getattr(choice, 'finish_reason', None)
else:
delta = {}
finish_reason = None
# 从第一个 chunk 获取 role后续使用这个 role
if 'role' in delta and delta['role']:
role = delta['role']
# 获取增量内容
delta_content = delta.get('content', '')
reasoning_content = delta.get('reasoning_content', '')
# 处理 reasoning_content
if reasoning_content:
# accumulated_reasoning += reasoning_content
# 如果设置了 remove_think跳过 reasoning_content
if remove_think:
chunk_idx += 1
continue
# 第一次出现 reasoning_content添加 <think> 开始标签
if not thinking_started:
thinking_started = True
delta_content = '<think>\n' + reasoning_content
else:
# 继续输出 reasoning_content
delta_content = reasoning_content
elif thinking_started and not thinking_ended and delta_content:
# reasoning_content 结束normal content 开始,添加 </think> 结束标签
thinking_ended = True
delta_content = '\n</think>\n' + delta_content
# 处理工具调用增量
if delta.get('tool_calls'):
for tool_call in delta['tool_calls']:
if tool_call['id'] != '':
tool_id = tool_call['id']
if tool_call['function']['name'] is not None:
tool_name = tool_call['function']['name']
if tool_call['type'] is None:
tool_call['type'] = 'function'
tool_call['id'] = tool_id
tool_call['function']['name'] = tool_name
tool_call['function']['arguments'] = (
'' if tool_call['function']['arguments'] is None else tool_call['function']['arguments']
)
# 跳过空的第一个 chunk只有 role 没有内容)
if chunk_idx == 0 and not delta_content and not reasoning_content and not delta.get('tool_calls'):
chunk_idx += 1
continue
# 构建 MessageChunk - 只包含增量内容
chunk_data = {
'role': role,
'content': delta_content if delta_content else None,
'tool_calls': delta.get('tool_calls'),
'is_final': bool(finish_reason),
}
# 移除 None 值
chunk_data = {k: v for k, v in chunk_data.items() if v is not None}
yield provider_message.MessageChunk(**chunk_data)
chunk_idx += 1
# return

View File

@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" width="450" height="280" viewBox="0 0 450 280" class="cursor-pointer h-24 flex-shrink-0 w-149"><g fill="none" fill-rule="nonzero"><path fill="#0005DE" d="M97.705 6.742c58.844 0 90.962 34.353 90.962 98.341v21.843c-15.118-2.479-30.297-6.573-45.558-12.3v-9.543c0-35.97-15.564-56.281-45.404-56.281s-45.404 20.31-45.404 56.281v72.48c0 36.117 15.65 56.818 45.404 56.818 26.78 0 42.133-16.768 44.936-46.452q22.397 6.473 44.905 9.356c-6.15 51.52-37.492 79.155-89.841 79.155-58.678 0-90.963-34.72-90.963-98.878v-72.479c0-63.988 32.119-98.34 90.963-98.34m253.627 0c58.844 0 90.963 34.353 90.963 98.341v72.48c0 64.157-32.285 98.877-90.963 98.877-52.438 0-83.797-27.729-89.874-79.415 15-2.026 29.965-5.252 44.887-9.67 2.658 30.042 18.036 47.026 44.987 47.026 29.755 0 45.404-20.7 45.404-56.819v-72.479c0-35.97-15.564-56.281-45.404-56.281s-45.403 20.31-45.403 56.281v8.778c-15.262 5.868-30.44 10.104-45.559 12.725v-21.503c0-63.988 32.118-98.34 90.962-98.34m-164.37 140.026.57.09.831.127-.83-.128a234.5 234.5 0 0 0 35.979 2.79q18.408.002 36.858-2.928l1.401-.226a242 242 0 0 0 1.45-.244l-1.037.175q.729-.12 1.458-.247l-.421.072 1.26-.219-.84.147a244 244 0 0 0 2.8-.5l-.792.144q.648-.117 1.298-.239l-.506.094q.66-.122 1.322-.248l-.816.154q.759-.142 1.518-.289l-.702.135a247 247 0 0 0 5.364-1.084l-.463.098a250 250 0 0 0 3.928-.864l-.785.178 1.45-.33-.665.152q.597-.137 1.193-.276l-.528.123a253 253 0 0 0 3.685-.882l-.254.063q.683-.168 1.366-.34l-1.112.277q.809-.2 1.618-.405l-.506.128q.818-.206 1.634-.417l-1.128.289q.71-.18 1.419-.365l1.506-.397a259 259 0 0 0 1.804-.488l-.433.119a261 261 0 0 0 3.751-1.053l-.681.196a264 264 0 0 0 1.735-.502l-1.054.306q.636-.184 1.272-.37l-.218.064 1.238-.366-1.02.302a266 266 0 0 0 2.936-.882l-1.026.312q.71-.214 1.42-.433l-.394.121q.675-.207 1.35-.418l-.955.297q.8-.246 1.6-.499l-.645.202q.86-.269 1.72-.543l-1.076.341q.666-.21 1.33-.423l-.254.082q.833-.266 1.665-.539l-1.41.457q.874-.28 1.75-.568l-.34.111q.702-.229 1.403-.462l-1.063.351q.818-.269 1.634-.542l-.571.19a276 276 0 0 0 4.038-1.378l-.735.256q.657-.228 1.315-.46l-.58.204q16.86-5.903 33.78-14.256l-7.114-12.453 42.909 6.553-13.148 45.541-7.734-13.537q-23.832 11.94-47.755 19.504l-.199.063a298 298 0 0 1-11.65 3.412 288 288 0 0 1-10.39 2.603 280 280 0 0 1-11.677 2.431 273 273 0 0 1-11.643 1.903 263.5 263.5 0 0 1-36.858 2.599q-17.437 0-34.844-2.323l-.227-.03q-.635-.085-1.27-.174l1.497.204a268 268 0 0 1-13.673-2.182 275 275 0 0 1-12.817-2.697 282 282 0 0 1-11.859-3.057 291 291 0 0 1-7.21-2.123c-17.23-5.314-34.43-12.334-51.59-21.051l-8.258 14.455-13.148-45.541 42.909-6.553-6.594 11.544q18.421 9.24 36.776 15.572l1.316.45 1.373.462-.831-.278q.795.267 1.589.53l-.758-.252q.632.211 1.264.419l-.506-.167q.642.212 1.284.42l-.778-.253a271 271 0 0 0 3.914 1.251l-.227-.07a267 267 0 0 0 3.428 1.046l-.194-.058 1.315.389-1.121-.331q.864.256 1.73.508l-.609-.177q.826.241 1.651.478l-1.043-.3 1.307.375-.264-.075q.802.228 1.603.452l-1.34-.377q1.034.294 2.067.58l-.727-.203q.713.2 1.426.394l-.699-.192q.62.171 1.237.338l-.538-.146a259 259 0 0 0 3.977 1.051l-.66-.17q.683.177 1.367.35l-.707-.18q.687.175 1.373.348l-.666-.168q.738.186 1.475.368l-.809-.2q.716.179 1.43.353l-.621-.153a253 253 0 0 0 3.766.898l-.308-.07q.735.17 1.472.336l-1.164-.266q.747.173 1.496.34l-.332-.074q.845.19 1.69.374l-1.358-.3q.932.21 1.864.41l-.505-.11q.726.159 1.452.313l-.947-.203q.72.156 1.44.307l-.493-.104q.684.144 1.368.286l-.875-.182q.743.155 1.485.306l-.61-.124q.932.192 1.864.376l-1.254-.252q.904.184 1.809.361l-.555-.109q.752.15 1.504.293l-.95-.184q.69.135 1.377.265l-.427-.081q.784.15 1.569.295l-1.142-.214q.717.136 1.434.268l-.292-.054a244 244 0 0 0 3.808.673l-.68-.116 1.063.18-.383-.064q1.076.18 2.152.352z"></path></g></svg>

After

Width:  |  Height:  |  Size: 3.6 KiB

View File

@@ -0,0 +1,31 @@
apiVersion: v1
kind: LLMAPIRequester
metadata:
name: tokenpony-chat-completions
label:
en_US: TokenPony
zh_Hans: 小马算力
icon: tokenpony.svg
spec:
config:
- name: base_url
label:
en_US: Base URL
zh_Hans: 基础 URL
type: string
required: true
default: "https://api.tokenpony.cn/v1"
- name: timeout
label:
en_US: Timeout
zh_Hans: 超时时间
type: integer
required: true
default: 120
support_type:
- llm
- text-embedding
execution:
python:
path: ./tokenponychatcmpl.py
attr: TokenPonyChatCompletions

View File

@@ -0,0 +1,17 @@
from __future__ import annotations
import typing
import openai
from . import chatcmpl
class TokenPonyChatCompletions(chatcmpl.OpenAIChatCompletions):
"""TokenPony ChatCompletion API 请求器"""
client: openai.AsyncClient
default_config: dict[str, typing.Any] = {
'base_url': 'https://api.tokenpony.cn/v1',
'timeout': 120,
}

View File

@@ -0,0 +1,312 @@
from __future__ import annotations
import typing
import json
import uuid
import base64
from .. import runner
from ...core import app
import langbot_plugin.api.entities.builtin.provider.message as provider_message
from ...utils import image
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from libs.coze_server_api.client import AsyncCozeAPIClient
@runner.runner_class('coze-api')
class CozeAPIRunner(runner.RequestRunner):
"""Coze API 对话请求器"""
def __init__(self, ap: app.Application, pipeline_config: dict):
self.pipeline_config = pipeline_config
self.ap = ap
self.agent_token = pipeline_config["ai"]['coze-api']['api-key']
self.bot_id = pipeline_config["ai"]['coze-api'].get('bot-id')
self.chat_timeout = pipeline_config["ai"]['coze-api'].get('timeout')
self.auto_save_history = pipeline_config["ai"]['coze-api'].get('auto_save_history')
self.api_base = pipeline_config["ai"]['coze-api'].get('api-base')
self.coze = AsyncCozeAPIClient(
self.agent_token,
self.api_base
)
def _process_thinking_content(
self,
content: str,
) -> tuple[str, str]:
"""处理思维链内容
Args:
content: 原始内容
Returns:
(处理后的内容, 提取的思维链内容)
"""
remove_think = self.pipeline_config.get('output', {}).get('misc', {}).get('remove-think', False)
thinking_content = ''
# 从 content 中提取 <think> 标签内容
if content and '<think>' in content and '</think>' in content:
import re
think_pattern = r'<think>(.*?)</think>'
think_matches = re.findall(think_pattern, content, re.DOTALL)
if think_matches:
thinking_content = '\n'.join(think_matches)
# 移除 content 中的 <think> 标签
content = re.sub(think_pattern, '', content, flags=re.DOTALL).strip()
# 根据 remove_think 参数决定是否保留思维链
if remove_think:
return content, ''
else:
# 如果有思维链内容,将其以 <think> 格式添加到 content 开头
if thinking_content:
content = f'<think>\n{thinking_content}\n</think>\n{content}'.strip()
return content, thinking_content
async def _preprocess_user_message(self, query: pipeline_query.Query) -> list[dict]:
"""预处理用户消息转换为Coze消息格式
Returns:
list[dict]: Coze消息列表
"""
messages = []
if isinstance(query.user_message.content, list):
# 多模态消息处理
content_parts = []
for ce in query.user_message.content:
if ce.type == 'text':
content_parts.append({"type": "text", "text": ce.text})
elif ce.type == 'image_base64':
image_b64, image_format = await image.extract_b64_and_format(ce.image_base64)
file_bytes = base64.b64decode(image_b64)
file_id = await self._get_file_id(file_bytes)
content_parts.append({"type": "image", "file_id": file_id})
elif ce.type == 'file':
# 处理文件上传到Coze
file_id = await self._get_file_id(ce.file)
content_parts.append({"type": "file", "file_id": file_id})
# 创建多模态消息
if content_parts:
messages.append({
"role": "user",
"content": json.dumps(content_parts),
"content_type": "object_string",
"meta_data": None
})
elif isinstance(query.user_message.content, str):
# 纯文本消息
messages.append({
"role": "user",
"content": query.user_message.content,
"content_type": "text",
"meta_data": None
})
return messages
async def _get_file_id(self, file) -> str:
"""上传文件到Coze服务
Args:
file: 文件
Returns:
str: 文件ID
"""
file_id = await self.coze.upload(file=file)
return file_id
async def _chat_messages(
self, query: pipeline_query.Query
) -> typing.AsyncGenerator[provider_message.Message, None]:
"""调用聊天助手(非流式)
注意由于cozepy没有提供非流式API这里使用流式API并在结束后一次性返回完整内容
"""
user_id = f'{query.launcher_id}_{query.sender_id}'
# 预处理用户消息
additional_messages = await self._preprocess_user_message(query)
# 获取会话ID
conversation_id = None
# 收集完整内容
full_content = ''
full_reasoning = ''
try:
# 调用Coze API流式接口
async for chunk in self.coze.chat_messages(
bot_id=self.bot_id,
user_id=user_id,
additional_messages=additional_messages,
conversation_id=conversation_id,
timeout=self.chat_timeout,
auto_save_history=self.auto_save_history,
stream=True
):
self.ap.logger.debug(f'coze-chat-stream: {chunk}')
event_type = chunk.get('event')
data = chunk.get('data', {})
if event_type == 'conversation.message.delta':
# 收集内容
if 'content' in data:
full_content += data.get('content', '')
# 收集推理内容(如果有)
if 'reasoning_content' in data:
full_reasoning += data.get('reasoning_content', '')
elif event_type == 'done':
# 保存会话ID
if 'conversation_id' in data:
conversation_id = data.get('conversation_id')
elif event_type == 'error':
# 处理错误
error_msg = f"Coze API错误: {data.get('message', '未知错误')}"
yield provider_message.Message(
role='assistant',
content=error_msg,
)
return
# 处理思维链内容
content, thinking_content = self._process_thinking_content(full_content)
if full_reasoning:
remove_think = self.pipeline_config.get('output', {}).get('misc', {}).get('remove-think', False)
if not remove_think:
content = f'<think>\n{full_reasoning}\n</think>\n{content}'.strip()
# 一次性返回完整内容
yield provider_message.Message(
role='assistant',
content=content,
)
# 保存会话ID
if conversation_id and query.session.using_conversation:
query.session.using_conversation.uuid = conversation_id
except Exception as e:
self.ap.logger.error(f'Coze API错误: {str(e)}')
yield provider_message.Message(
role='assistant',
content=f'Coze API调用失败: {str(e)}',
)
async def _chat_messages_chunk(
self, query: pipeline_query.Query
) -> typing.AsyncGenerator[provider_message.MessageChunk, None]:
"""调用聊天助手(流式)"""
user_id = f'{query.launcher_id}_{query.sender_id}'
# 预处理用户消息
additional_messages = await self._preprocess_user_message(query)
# 获取会话ID
conversation_id = None
start_reasoning = False
stop_reasoning = False
message_idx = 1
is_final = False
full_content = ''
remove_think = self.pipeline_config.get('output', {}).get('misc', {}).get('remove-think', False)
try:
# 调用Coze API流式接口
async for chunk in self.coze.chat_messages(
bot_id=self.bot_id,
user_id=user_id,
additional_messages=additional_messages,
conversation_id=conversation_id,
timeout=self.chat_timeout,
auto_save_history=self.auto_save_history,
stream=True
):
self.ap.logger.debug(f'coze-chat-stream-chunk: {chunk}')
event_type = chunk.get('event')
data = chunk.get('data', {})
content = ""
if event_type == 'conversation.message.delta':
message_idx += 1
# 处理内容增量
if "reasoning_content" in data and not remove_think:
reasoning_content = data.get('reasoning_content', '')
if reasoning_content and not start_reasoning:
content = f"<think/>\n"
start_reasoning = True
content += reasoning_content
if 'content' in data:
if data.get('content', ''):
content += data.get('content', '')
if not stop_reasoning and start_reasoning:
content = f"</think>\n{content}"
stop_reasoning = True
elif event_type == 'done':
# 保存会话ID
if 'conversation_id' in data:
conversation_id = data.get('conversation_id')
if query.session.using_conversation:
query.session.using_conversation.uuid = conversation_id
is_final = True
elif event_type == 'error':
# 处理错误
error_msg = f"Coze API错误: {data.get('message', '未知错误')}"
yield provider_message.MessageChunk(
role='assistant',
content=error_msg,
finish_reason='error'
)
return
full_content += content
if message_idx % 8 == 0 or is_final:
if full_content:
yield provider_message.MessageChunk(
role='assistant',
content=full_content,
is_final=is_final
)
except Exception as e:
self.ap.logger.error(f'Coze API流式调用错误: {str(e)}')
yield provider_message.MessageChunk(
role='assistant',
content=f'Coze API流式调用失败: {str(e)}',
finish_reason='error'
)
async def run(self, query: pipeline_query.Query) -> typing.AsyncGenerator[provider_message.Message, None]:
"""运行"""
msg_seq = 0
if await query.adapter.is_stream_output_supported():
async for msg in self._chat_messages_chunk(query):
if isinstance(msg, provider_message.MessageChunk):
msg_seq += 1
msg.msg_sequence = msg_seq
yield msg
else:
async for msg in self._chat_messages(query):
yield msg

View File

@@ -0,0 +1,205 @@
from __future__ import annotations
import typing
import json
import base64
import tempfile
import os
from tboxsdk.tbox import TboxClient
from tboxsdk.model.file import File, FileType
from .. import runner
from ...core import app
from ...utils import image
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.builtin.provider.message as provider_message
class TboxAPIError(Exception):
"""TBox API 请求失败"""
def __init__(self, message: str):
self.message = message
super().__init__(self.message)
@runner.runner_class('tbox-app-api')
class TboxAPIRunner(runner.RequestRunner):
"蚂蚁百宝箱API对话请求器"
# 运行器内部使用的配置
app_id: str # 蚂蚁百宝箱平台中的应用ID
api_key: str # 在蚂蚁百宝箱平台中申请的令牌
def __init__(self, ap: app.Application, pipeline_config: dict):
"""初始化"""
self.ap = ap
self.pipeline_config = pipeline_config
# 初始化Tbox 参数配置
self.app_id = self.pipeline_config['ai']['tbox-app-api']['app-id']
self.api_key = self.pipeline_config['ai']['tbox-app-api']['api-key']
# 初始化Tbox client
self.tbox_client = TboxClient(authorization=self.api_key)
async def _preprocess_user_message(self, query: pipeline_query.Query) -> tuple[str, list[str]]:
"""预处理用户消息,提取纯文本,并将图片上传到 Tbox 服务
Returns:
tuple[str, list[str]]: 纯文本和图片的 Tbox 文件ID
"""
plain_text = ''
image_ids = []
if isinstance(query.user_message.content, list):
for ce in query.user_message.content:
if ce.type == 'text':
plain_text += ce.text
elif ce.type == 'image_base64':
image_b64, image_format = await image.extract_b64_and_format(ce.image_base64)
# 创建临时文件
file_bytes = base64.b64decode(image_b64)
try:
with tempfile.NamedTemporaryFile(suffix=f'.{image_format}', delete=False) as tmp_file:
tmp_file.write(file_bytes)
tmp_file_path = tmp_file.name
file_upload_resp = self.tbox_client.upload_file(
tmp_file_path
)
image_id = file_upload_resp.get("data", "")
image_ids.append(image_id)
finally:
# 清理临时文件
if os.path.exists(tmp_file_path):
os.unlink(tmp_file_path)
elif isinstance(query.user_message.content, str):
plain_text = query.user_message.content
return plain_text, image_ids
async def _agent_messages(
self, query: pipeline_query.Query
) -> typing.AsyncGenerator[provider_message.Message, None]:
"""TBox 智能体对话请求"""
plain_text, image_ids = await self._preprocess_user_message(query)
remove_think = self.pipeline_config['output'].get('misc', {}).get('remove-think')
try:
is_stream = await query.adapter.is_stream_output_supported()
except AttributeError:
is_stream = False
# 获取Tbox的conversation_id
conversation_id = query.session.using_conversation.uuid or None
files = None
if image_ids:
files = [
File(file_id=image_id, type=FileType.IMAGE)
for image_id in image_ids
]
# 发送对话请求
response = self.tbox_client.chat(
app_id=self.app_id, # Tbox中智能体应用的ID
user_id=query.bot_uuid, # 用户ID
query=plain_text, # 用户输入的文本信息
stream=is_stream, # 是否流式输出
conversation_id=conversation_id, # 会话ID为None时Tbox会自动创建一个新会话
files=files, # 图片内容
)
if is_stream:
# 解析Tbox流式输出内容并发送给上游
for chunk in self._process_stream_message(response, query, remove_think):
yield chunk
else:
message = self._process_non_stream_message(response, query, remove_think)
yield provider_message.Message(
role='assistant',
content=message,
)
def _process_non_stream_message(self, response: typing.Dict, query: pipeline_query.Query, remove_think: bool):
if response.get('errorCode') != "0":
raise TboxAPIError(f'Tbox API 请求失败: {response.get("errorMsg", "")}')
payload = response.get('data', {})
conversation_id = payload.get('conversationId', '')
query.session.using_conversation.uuid = conversation_id
thinking_content = payload.get('reasoningContent', [])
result = ""
if thinking_content and not remove_think:
result += f'<think>\n{thinking_content[0].get("text", "")}\n</think>\n'
content = payload.get('result', [])
if content:
result += content[0].get('chunk', '')
return result
def _process_stream_message(self, response: typing.Generator[dict], query: pipeline_query.Query, remove_think: bool):
idx_msg = 0
pending_content = ''
conversation_id = None
think_start = False
think_end = False
for chunk in response:
if chunk.get('type', '') == 'chunk':
"""
Tbox返回的消息内容chunk结构
{'lane': 'default', 'payload': {'conversationId': '20250918tBI947065406', 'messageId': '20250918TB1f53230954', 'text': ''}, 'type': 'chunk'}
"""
# 如果包含思考过程,拼接</think>
if think_start and not think_end:
pending_content += '\n</think>\n'
think_end = True
payload = chunk.get('payload', {})
if not conversation_id:
conversation_id = payload.get('conversationId')
query.session.using_conversation.uuid = conversation_id
if payload.get('text'):
idx_msg += 1
pending_content += payload.get('text')
elif chunk.get('type', '') == 'thinking' and not remove_think:
"""
Tbox返回的思考过程chunk结构
{'payload': '{"ext_data":{"text":"日期"},"event":"flow.node.llm.thinking","entity":{"node_type":"text-completion","execute_id":"6","group_id":0,"parent_execute_id":"6","node_name":"模型推理","node_id":"TC_5u6gl0"}}', 'type': 'thinking'}
"""
payload = json.loads(chunk.get('payload', '{}'))
if payload.get('ext_data', {}).get('text'):
idx_msg += 1
content = payload.get('ext_data', {}).get('text')
if not think_start:
think_start = True
pending_content += f'<think>\n{content}'
else:
pending_content += content
elif chunk.get('type', '') == 'error':
raise TboxAPIError(
f'Tbox API 请求失败: status_code={chunk.get("status_code")} message={chunk.get("message")} request_id={chunk.get("request_id")} '
)
if idx_msg % 8 == 0:
yield provider_message.MessageChunk(
role='assistant',
content=pending_content,
is_final=False,
)
# Tbox不返回END事件默认发一个最终消息
yield provider_message.MessageChunk(
role='assistant',
content=pending_content,
is_final=True,
)
async def run(self, query: pipeline_query.Query) -> typing.AsyncGenerator[provider_message.Message, None]:
"""运行"""
msg_seq = 0
async for msg in self._agent_messages(query):
if isinstance(msg, provider_message.MessageChunk):
msg_seq += 1
msg.msg_sequence = msg_seq
yield msg

View File

@@ -1,7 +1,11 @@
from __future__ import annotations
import enum
import typing
from contextlib import AsyncExitStack
import traceback
import sqlalchemy
import asyncio
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
@@ -10,6 +14,13 @@ from mcp.client.sse import sse_client
from .. import loader
from ....core import app
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
from ....entity.persistence import mcp as persistence_mcp
class MCPSessionStatus(enum.Enum):
CONNECTING = 'connecting'
CONNECTED = 'connected'
ERROR = 'error'
class RuntimeMCPSession:
@@ -27,16 +38,34 @@ class RuntimeMCPSession:
functions: list[resource_tool.LLMTool] = []
def __init__(self, server_name: str, server_config: dict, ap: app.Application):
enable: bool
# connected: bool
status: MCPSessionStatus
_lifecycle_task: asyncio.Task | None
_shutdown_event: asyncio.Event
_ready_event: asyncio.Event
def __init__(self, server_name: str, server_config: dict, enable: bool, ap: app.Application):
self.server_name = server_name
self.server_config = server_config
self.ap = ap
self.enable = enable
self.session = None
self.exit_stack = AsyncExitStack()
self.functions = []
self.status = MCPSessionStatus.CONNECTING
self._lifecycle_task = None
self._shutdown_event = asyncio.Event()
self._ready_event = asyncio.Event()
async def _init_stdio_python_server(self):
server_params = StdioServerParameters(
command=self.server_config['command'],
@@ -58,6 +87,7 @@ class RuntimeMCPSession:
self.server_config['url'],
headers=self.server_config.get('headers', {}),
timeout=self.server_config.get('timeout', 10),
sse_read_timeout=self.server_config.get('ssereadtimeout', 30),
)
)
@@ -67,19 +97,65 @@ class RuntimeMCPSession:
await self.session.initialize()
async def initialize(self):
self.ap.logger.debug(f'初始化 MCP 会话: {self.server_name} {self.server_config}')
async def _lifecycle_loop(self):
"""在后台任务中管理整个MCP会话的生命周期"""
try:
if self.server_config['mode'] == 'stdio':
await self._init_stdio_python_server()
elif self.server_config['mode'] == 'sse':
await self._init_sse_server()
else:
raise ValueError(f'无法识别 MCP 服务器类型: {self.server_name}: {self.server_config}')
if self.server_config['mode'] == 'stdio':
await self._init_stdio_python_server()
elif self.server_config['mode'] == 'sse':
await self._init_sse_server()
else:
raise ValueError(f'无法识别 MCP 服务器类型: {self.server_name}: {self.server_config}')
await self.refresh()
self.status = MCPSessionStatus.CONNECTED
# 通知start()方法连接已建立
self._ready_event.set()
# 等待shutdown信号
await self._shutdown_event.wait()
except Exception as e:
self.status = MCPSessionStatus.ERROR
self.ap.logger.error(f'Error in MCP session lifecycle {self.server_name}: {e}\n{traceback.format_exc()}')
# 即使出错也要设置ready事件让start()方法知道初始化已完成
self._ready_event.set()
finally:
# 在同一个任务中清理所有资源
try:
if self.exit_stack:
await self.exit_stack.aclose()
self.functions.clear()
self.session = None
except Exception as e:
self.ap.logger.error(f'Error cleaning up MCP session {self.server_name}: {e}\n{traceback.format_exc()}')
async def start(self):
if not self.enable:
return
# 创建后台任务来管理生命周期
self._lifecycle_task = asyncio.create_task(self._lifecycle_loop())
# 等待连接建立或失败(带超时)
try:
await asyncio.wait_for(self._ready_event.wait(), timeout=30.0)
except asyncio.TimeoutError:
self.status = MCPSessionStatus.ERROR
raise Exception('Connection timeout after 30 seconds')
# 检查是否有错误
if self.status == MCPSessionStatus.ERROR:
raise Exception('Connection failed, please check URL')
async def refresh(self):
self.functions.clear()
tools = await self.session.list_tools()
self.ap.logger.debug(f'获取 MCP 工具: {tools}')
self.ap.logger.debug(f'Refresh MCP tools: {tools}')
for tool in tools.tools:
@@ -101,58 +177,212 @@ class RuntimeMCPSession:
)
)
def get_tools(self) -> list[resource_tool.LLMTool]:
return self.functions
def get_runtime_info_dict(self) -> dict:
return {
'status': self.status.value,
'tool_count': len(self.get_tools()),
'tools': [
{
'name': tool.name,
'description': tool.description,
}
for tool in self.get_tools()
],
}
async def shutdown(self):
"""关闭工具"""
await self.session._exit_stack.aclose()
"""关闭会话并清理资源"""
try:
# 设置shutdown事件通知lifecycle任务退出
self._shutdown_event.set()
# 等待lifecycle任务完成带超时
if self._lifecycle_task and not self._lifecycle_task.done():
try:
await asyncio.wait_for(self._lifecycle_task, timeout=5.0)
except asyncio.TimeoutError:
self.ap.logger.warning(f'MCP session {self.server_name} shutdown timeout, cancelling task')
self._lifecycle_task.cancel()
try:
await self._lifecycle_task
except asyncio.CancelledError:
pass
self.ap.logger.info(f'MCP session {self.server_name} shutdown complete')
except Exception as e:
self.ap.logger.error(f'Error shutting down MCP session {self.server_name}: {e}\n{traceback.format_exc()}')
@loader.loader_class('mcp')
# @loader.loader_class('mcp')
class MCPLoader(loader.ToolLoader):
"""MCP 工具加载器。
在此加载器中管理所有与 MCP Server 的连接。
"""
sessions: dict[str, RuntimeMCPSession] = {}
sessions: dict[str, RuntimeMCPSession]
_last_listed_functions: list[resource_tool.LLMTool] = []
_last_listed_functions: list[resource_tool.LLMTool]
_hosted_mcp_tasks: list[asyncio.Task]
def __init__(self, ap: app.Application):
super().__init__(ap)
self.sessions = {}
self._last_listed_functions = []
self._hosted_mcp_tasks = []
async def initialize(self):
for server_config in self.ap.instance_config.data.get('mcp', {}).get('servers', []):
if not server_config['enable']:
continue
session = RuntimeMCPSession(server_config['name'], server_config, self.ap)
await session.initialize()
# self.ap.event_loop.create_task(session.initialize())
await self.load_mcp_servers_from_db()
async def load_mcp_servers_from_db(self):
self.ap.logger.info('Loading MCP servers from db...')
self.sessions = {}
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_mcp.MCPServer))
servers = result.all()
for server in servers:
config = self.ap.persistence_mgr.serialize_model(persistence_mcp.MCPServer, server)
task = asyncio.create_task(self.host_mcp_server(config))
self._hosted_mcp_tasks.append(task)
async def host_mcp_server(self, server_config: dict):
self.ap.logger.debug(f'Loading MCP server {server_config}')
try:
session = await self.load_mcp_server(server_config)
self.sessions[server_config['name']] = session
except Exception as e:
self.ap.logger.error(
f'Failed to load MCP server from db: {server_config["name"]}({server_config["uuid"]}): {e}\n{traceback.format_exc()}'
)
return
self.ap.logger.debug(f'Starting MCP server {server_config["name"]}({server_config["uuid"]})')
try:
await session.start()
except Exception as e:
self.ap.logger.error(
f'Failed to start MCP server {server_config["name"]}({server_config["uuid"]}): {e}\n{traceback.format_exc()}'
)
return
self.ap.logger.debug(f'Started MCP server {server_config["name"]}({server_config["uuid"]})')
async def load_mcp_server(self, server_config: dict) -> RuntimeMCPSession:
"""加载 MCP 服务器到运行时
Args:
server_config: 服务器配置字典,必须包含:
- name: 服务器名称
- mode: 连接模式 (stdio/sse)
- enable: 是否启用
- extra_args: 额外的配置参数 (可选)
"""
name = server_config['name']
mode = server_config['mode']
enable = server_config['enable']
extra_args = server_config.get('extra_args', {})
mixed_config = {
'name': name,
'mode': mode,
'enable': enable,
**extra_args,
}
session = RuntimeMCPSession(name, mixed_config, enable, self.ap)
return session
async def get_tools(self) -> list[resource_tool.LLMTool]:
all_functions = []
for session in self.sessions.values():
all_functions.extend(session.functions)
all_functions.extend(session.get_tools())
self._last_listed_functions = all_functions
return all_functions
async def has_tool(self, name: str) -> bool:
return name in [f.name for f in self._last_listed_functions]
"""检查工具是否存在"""
for session in self.sessions.values():
for function in session.get_tools():
if function.name == name:
return True
return False
async def invoke_tool(self, name: str, parameters: dict) -> typing.Any:
for server_name, session in self.sessions.items():
for function in session.functions:
"""执行工具调用"""
for session in self.sessions.values():
for function in session.get_tools():
if function.name == name:
return await function.func(**parameters)
self.ap.logger.debug(f'Invoking MCP tool: {name} with parameters: {parameters}')
try:
result = await function.func(**parameters)
self.ap.logger.debug(f'MCP tool {name} executed successfully')
return result
except Exception as e:
self.ap.logger.error(f'Error invoking MCP tool {name}: {e}\n{traceback.format_exc()}')
raise
raise ValueError(f'未找到工具: {name}')
raise ValueError(f'Tool not found: {name}')
async def remove_mcp_server(self, server_name: str):
"""移除 MCP 服务器"""
if server_name not in self.sessions:
self.ap.logger.warning(f'MCP server {server_name} not found in sessions, skipping removal')
return
session = self.sessions.pop(server_name)
await session.shutdown()
self.ap.logger.info(f'Removed MCP server: {server_name}')
def get_session(self, server_name: str) -> RuntimeMCPSession | None:
"""获取指定名称的 MCP 会话"""
return self.sessions.get(server_name)
def has_session(self, server_name: str) -> bool:
"""检查是否存在指定名称的 MCP 会话"""
return server_name in self.sessions
def get_all_server_names(self) -> list[str]:
"""获取所有已加载的 MCP 服务器名称"""
return list(self.sessions.keys())
def get_server_tool_count(self, server_name: str) -> int:
"""获取指定服务器的工具数量"""
session = self.get_session(server_name)
return len(session.get_tools()) if session else 0
def get_all_servers_info(self) -> dict[str, dict]:
"""获取所有服务器的信息"""
info = {}
for server_name, session in self.sessions.items():
info[server_name] = {
'name': server_name,
'mode': session.server_config.get('mode'),
'enable': session.enable,
'tools_count': len(session.get_tools()),
'tool_names': [f.name for f in session.get_tools()],
}
return info
async def shutdown(self):
"""关闭工具"""
for session in self.sessions.values():
await session.shutdown()
"""关闭所有工具"""
self.ap.logger.info('Shutting down all MCP sessions...')
for server_name, session in list(self.sessions.items()):
try:
await session.shutdown()
self.ap.logger.debug(f'Shutdown MCP session: {server_name}')
except Exception as e:
self.ap.logger.error(f'Error shutting down MCP session {server_name}: {e}\n{traceback.format_exc()}')
self.sessions.clear()
self.ap.logger.info('All MCP sessions shutdown complete')

View File

@@ -7,7 +7,7 @@ from .. import loader
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
@loader.loader_class('plugin-tool-loader')
# @loader.loader_class('plugin-tool-loader')
class PluginToolLoader(loader.ToolLoader):
"""插件工具加载器。

View File

@@ -3,9 +3,9 @@ from __future__ import annotations
import typing
from ...core import app
from . import loader as tools_loader
from ...utils import importutil
from . import loaders
from .loaders import mcp as mcp_loader, plugin as plugin_loader
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
importutil.import_modules_in_pkg(loaders)
@@ -16,25 +16,24 @@ class ToolManager:
ap: app.Application
loaders: list[tools_loader.ToolLoader]
plugin_tool_loader: plugin_loader.PluginToolLoader
mcp_tool_loader: mcp_loader.MCPLoader
def __init__(self, ap: app.Application):
self.ap = ap
self.all_functions = []
self.loaders = []
async def initialize(self):
for loader_cls in tools_loader.preregistered_loaders:
loader_inst = loader_cls(self.ap)
await loader_inst.initialize()
self.loaders.append(loader_inst)
self.plugin_tool_loader = plugin_loader.PluginToolLoader(self.ap)
await self.plugin_tool_loader.initialize()
self.mcp_tool_loader = mcp_loader.MCPLoader(self.ap)
await self.mcp_tool_loader.initialize()
async def get_all_tools(self) -> list[resource_tool.LLMTool]:
"""获取所有函数"""
all_functions: list[resource_tool.LLMTool] = []
for loader in self.loaders:
all_functions.extend(await loader.get_tools())
all_functions.extend(await self.plugin_tool_loader.get_tools())
all_functions.extend(await self.mcp_tool_loader.get_tools())
return all_functions
@@ -93,13 +92,14 @@ class ToolManager:
async def execute_func_call(self, name: str, parameters: dict) -> typing.Any:
"""执行函数调用"""
for loader in self.loaders:
if await loader.has_tool(name):
return await loader.invoke_tool(name, parameters)
if await self.plugin_tool_loader.has_tool(name):
return await self.plugin_tool_loader.invoke_tool(name, parameters)
elif await self.mcp_tool_loader.has_tool(name):
return await self.mcp_tool_loader.invoke_tool(name, parameters)
else:
raise ValueError(f'未找到工具: {name}')
async def shutdown(self):
"""关闭所有工具"""
for loader in self.loaders:
await loader.shutdown()
await self.plugin_tool_loader.shutdown()
await self.mcp_tool_loader.shutdown()

View File

@@ -4,6 +4,7 @@ import json
from typing import List
from pkg.rag.knowledge.services import base_service
from pkg.core import app
from langchain_text_splitters import RecursiveCharacterTextSplitter
class Chunker(base_service.BaseService):
@@ -27,21 +28,6 @@ class Chunker(base_service.BaseService):
"""
if not text:
return []
# words = text.split()
# chunks = []
# current_chunk = []
# for word in words:
# current_chunk.append(word)
# if len(current_chunk) > self.chunk_size:
# chunks.append(" ".join(current_chunk[:self.chunk_size]))
# current_chunk = current_chunk[self.chunk_size - self.chunk_overlap:]
# if current_chunk:
# chunks.append(" ".join(current_chunk))
# A more robust chunking strategy (e.g., using recursive character text splitter)
from langchain.text_splitter import RecursiveCharacterTextSplitter
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=self.chunk_size,

View File

@@ -1,4 +1,4 @@
semantic_version = 'v4.3.2'
semantic_version = 'v4.4.1'
required_database_version = 8
"""Tag the version of the database schema, used to check if the database needs to be migrated"""

View File

@@ -1,6 +1,6 @@
[project]
name = "langbot"
version = "4.3.1"
version = "4.4.1"
description = "Easy-to-use global IM bot platform designed for LLM era"
readme = "README.md"
requires-python = ">=3.10.1,<4.0"
@@ -60,11 +60,13 @@ dependencies = [
"ebooklib>=0.18",
"html2text>=2024.2.26",
"langchain>=0.2.0",
"langchain-text-splitters>=0.0.1",
"chromadb>=0.4.24",
"qdrant-client (>=1.15.1,<2.0.0)",
"langbot-plugin==0.1.1",
"langbot-plugin==0.1.8",
"asyncpg>=0.30.0",
"line-bot-sdk>=3.19.0"
"line-bot-sdk>=3.19.0",
"tboxsdk>=0.0.10",
]
keywords = [
"bot",
@@ -102,6 +104,7 @@ dev = [
"pre-commit>=4.2.0",
"pytest>=8.4.1",
"pytest-asyncio>=1.0.0",
"pytest-cov>=7.0.0",
"ruff>=0.11.9",
]

39
pytest.ini Normal file
View File

@@ -0,0 +1,39 @@
[pytest]
# Test discovery patterns
python_files = test_*.py
python_classes = Test*
python_functions = test_*
# Test paths
testpaths = tests
# Asyncio configuration
asyncio_mode = auto
# Output options
addopts =
-v
--strict-markers
--tb=short
--disable-warnings
# Markers
markers =
asyncio: mark test as async
unit: mark test as unit test
integration: mark test as integration test
slow: mark test as slow running
# Coverage options (when using pytest-cov)
[coverage:run]
source = pkg
omit =
*/tests/*
*/test_*.py
*/__pycache__/*
*/site-packages/*
[coverage:report]
precision = 2
show_missing = True
skip_covered = False

31
run_tests.sh Executable file
View File

@@ -0,0 +1,31 @@
#!/bin/bash
# Script to run all unit tests
# This script helps avoid circular import issues by setting up the environment properly
set -e
echo "Setting up test environment..."
# Activate virtual environment if it exists
if [ -d ".venv" ]; then
source .venv/bin/activate
fi
# Check if pytest is installed
if ! command -v pytest &> /dev/null; then
echo "Installing test dependencies..."
pip install pytest pytest-asyncio pytest-cov
fi
echo "Running all unit tests..."
# Run tests with coverage
pytest tests/unit_tests/ -v --tb=short \
--cov=pkg \
--cov-report=xml \
"$@"
echo ""
echo "Test run complete!"
echo "Coverage report saved to coverage.xml"

View File

@@ -10,8 +10,6 @@ command:
concurrency:
pipeline: 20
session: 1
mcp:
servers: []
proxy:
http: ''
https: ''
@@ -38,6 +36,7 @@ vdb:
port: 6333
api_key: ''
plugin:
enable: true
runtime_ws_url: 'ws://langbot_plugin_runtime:5400/control/ws'
enable_marketplace: true
cloud_service_url: 'https://space.langbot.app'
cloud_service_url: 'https://space.langbot.app'

View File

@@ -83,7 +83,7 @@
"output": {
"long-text-processing": {
"threshold": 1000,
"strategy": "forward",
"strategy": "none",
"font-path": ""
},
"force-delay": {

View File

@@ -23,6 +23,10 @@ stages:
label:
en_US: Local Agent
zh_Hans: 内置 Agent
- name: tbox-app-api
label:
en_US: Tbox App API
zh_Hans: 蚂蚁百宝箱平台 API
- name: dify-service-api
label:
en_US: Dify Service API
@@ -39,6 +43,10 @@ stages:
label:
en_US: Langflow API
zh_Hans: Langflow API
- name: coze-api
label:
en_US: Coze API
zh_Hans: 扣子 API
- name: local-agent
label:
en_US: Local Agent
@@ -82,6 +90,26 @@ stages:
type: knowledge-base-selector
required: false
default: ''
- name: tbox-app-api
label:
en_US: Tbox App API
zh_Hans: 蚂蚁百宝箱平台 API
description:
en_US: Configure the Tbox App API of the pipeline
zh_Hans: 配置蚂蚁百宝箱平台 API
config:
- name: api-key
label:
en_US: API Key
zh_Hans: API 密钥
type: string
required: true
- name: app-id
label:
en_US: App ID
zh_Hans: 应用 ID
type: string
required: true
- name: dify-service-api
label:
en_US: Dify Service API
@@ -356,4 +384,57 @@ stages:
zh_Hans: 可选的流程调整参数
type: json
required: false
default: '{}'
default: '{}'
- name: coze-api
label:
en_US: coze API
zh_Hans: 扣子 API
description:
en_US: Configure the Coze API of the pipeline
zh_Hans: 配置Coze API
config:
- name: api-key
label:
en_US: API Key
zh_Hans: API 密钥
description:
en_US: The API key for the Coze server
zh_Hans: Coze服务器的 API 密钥
type: string
required: true
- name: bot-id
label:
en_US: Bot ID
zh_Hans: 机器人 ID
description:
en_US: The ID of the bot to run
zh_Hans: 要运行的机器人 ID
type: string
required: true
- name: api-base
label:
en_US: API Base URL
zh_Hans: API 基础 URL
description:
en_US: The base URL for the Coze API, please use https://api.coze.com for global Coze edition(coze.com).
zh_Hans: Coze API 的基础 URL请使用 https://api.coze.com 用于全球 Coze 版coze.com
type: string
default: "https://api.coze.cn"
- name: auto-save-history
label:
en_US: Auto Save History
zh_Hans: 自动保存历史
description:
en_US: Whether to automatically save conversation history
zh_Hans: 是否自动保存对话历史
type: boolean
default: true
- name: timeout
label:
en_US: Request Timeout
zh_Hans: 请求超时
description:
en_US: Timeout in seconds for API requests
zh_Hans: API 请求超时时间(秒)
type: number
default: 120

View File

@@ -27,7 +27,7 @@ stages:
zh_Hans: 长文本的处理策略
type: select
required: true
default: forward
default: none
options:
- name: forward
label:
@@ -37,6 +37,10 @@ stages:
label:
en_US: Convert to Image
zh_Hans: 转换为图片
- name: none
label:
en_US: None
zh_Hans: 不处理
- name: font-path
label:
en_US: Font Path

183
tests/README.md Normal file
View File

@@ -0,0 +1,183 @@
# LangBot Test Suite
This directory contains the test suite for LangBot, with a focus on comprehensive unit testing of pipeline stages.
## Important Note
Due to circular import dependencies in the pipeline module structure, the test files use **lazy imports** via `importlib.import_module()` instead of direct imports. This ensures tests can run without triggering circular import errors.
## Structure
```
tests/
├── pipeline/ # Pipeline stage tests
│ ├── conftest.py # Shared fixtures and test infrastructure
│ ├── test_simple.py # Basic infrastructure tests (always pass)
│ ├── test_bansess.py # BanSessionCheckStage tests
│ ├── test_ratelimit.py # RateLimit stage tests
│ ├── test_preproc.py # PreProcessor stage tests
│ ├── test_respback.py # SendResponseBackStage tests
│ ├── test_resprule.py # GroupRespondRuleCheckStage tests
│ ├── test_pipelinemgr.py # PipelineManager tests
│ └── test_stages_integration.py # Integration tests
└── README.md # This file
```
## Test Architecture
### Fixtures (`conftest.py`)
The test suite uses a centralized fixture system that provides:
- **MockApplication**: Comprehensive mock of the Application object with all dependencies
- **Mock objects**: Pre-configured mocks for Session, Conversation, Model, Adapter
- **Sample data**: Ready-to-use Query objects, message chains, and configurations
- **Helper functions**: Utilities for creating results and common assertions
### Design Principles
1. **Isolation**: Each test is independent and doesn't rely on external systems
2. **Mocking**: All external dependencies are mocked to ensure fast, reliable tests
3. **Coverage**: Tests cover happy paths, edge cases, and error conditions
4. **Extensibility**: Easy to add new tests by reusing existing fixtures
## Running Tests
### Using the test runner script (recommended)
```bash
bash run_tests.sh
```
This script automatically:
- Activates the virtual environment
- Installs test dependencies if needed
- Runs tests with coverage
- Generates HTML coverage report
### Manual test execution
#### Run all tests
```bash
pytest tests/pipeline/
```
#### Run only simple tests (no imports, always pass)
```bash
pytest tests/pipeline/test_simple.py -v
```
#### Run specific test file
```bash
pytest tests/pipeline/test_bansess.py -v
```
#### Run with coverage
```bash
pytest tests/pipeline/ --cov=pkg/pipeline --cov-report=html
```
#### Run specific test
```bash
pytest tests/pipeline/test_bansess.py::test_bansess_whitelist_allow -v
```
### Known Issues
Some tests may encounter circular import errors. This is a known issue with the current module structure. The test infrastructure is designed to work around this using lazy imports, but if you encounter issues:
1. Make sure you're running from the project root directory
2. Ensure the virtual environment is activated
3. Try running `test_simple.py` first to verify the test infrastructure works
## CI/CD Integration
Tests are automatically run on:
- Pull request opened
- Pull request marked ready for review
- Push to PR branch
- Push to master/develop branches
The workflow runs tests on Python 3.10, 3.11, and 3.12 to ensure compatibility.
## Adding New Tests
### 1. For a new pipeline stage
Create a new test file `test_<stage_name>.py`:
```python
"""
<StageName> stage unit tests
"""
import pytest
from pkg.pipeline.<module>.<stage> import <StageClass>
from pkg.pipeline import entities as pipeline_entities
@pytest.mark.asyncio
async def test_stage_basic_flow(mock_app, sample_query):
"""Test basic flow"""
stage = <StageClass>(mock_app)
await stage.initialize({})
result = await stage.process(sample_query, '<StageName>')
assert result.result_type == pipeline_entities.ResultType.CONTINUE
```
### 2. For additional fixtures
Add new fixtures to `conftest.py`:
```python
@pytest.fixture
def my_custom_fixture():
"""Description of fixture"""
return create_test_data()
```
### 3. For test data
Use the helper functions in `conftest.py`:
```python
from tests.pipeline.conftest import create_stage_result, assert_result_continue
result = create_stage_result(
result_type=pipeline_entities.ResultType.CONTINUE,
query=sample_query
)
assert_result_continue(result)
```
## Best Practices
1. **Test naming**: Use descriptive names that explain what's being tested
2. **Arrange-Act-Assert**: Structure tests clearly with setup, execution, and verification
3. **One assertion per test**: Focus each test on a single behavior
4. **Mock appropriately**: Mock external dependencies, not the code under test
5. **Use fixtures**: Reuse common test data through fixtures
6. **Document tests**: Add docstrings explaining what each test validates
## Troubleshooting
### Import errors
Make sure you've installed the package in development mode:
```bash
uv pip install -e .
```
### Async test failures
Ensure you're using `@pytest.mark.asyncio` decorator for async tests.
### Mock not working
Check that you're mocking at the right level and using `AsyncMock` for async functions.
## Future Enhancements
- [ ] Add integration tests for full pipeline execution
- [ ] Add performance benchmarks
- [ ] Add mutation testing for better coverage quality
- [ ] Add property-based testing with Hypothesis

0
tests/__init__.py Normal file
View File

View File

View File

@@ -0,0 +1 @@
# Config unit tests

View File

@@ -0,0 +1,332 @@
"""
Tests for environment variable override functionality in YAML config
"""
import os
import pytest
from typing import Any
def _apply_env_overrides_to_config(cfg: dict) -> dict:
"""Apply environment variable overrides to data/config.yaml
Environment variables should be uppercase and use __ (double underscore)
to represent nested keys. For example:
- CONCURRENCY__PIPELINE overrides concurrency.pipeline
- PLUGIN__RUNTIME_WS_URL overrides plugin.runtime_ws_url
Arrays and dict types are ignored.
Args:
cfg: Configuration dictionary
Returns:
Updated configuration dictionary
"""
def convert_value(value: str, original_value: Any) -> Any:
"""Convert string value to appropriate type based on original value
Args:
value: String value from environment variable
original_value: Original value to infer type from
Returns:
Converted value (falls back to string if conversion fails)
"""
if isinstance(original_value, bool):
return value.lower() in ('true', '1', 'yes', 'on')
elif isinstance(original_value, int):
try:
return int(value)
except ValueError:
# If conversion fails, keep as string (user error, but non-breaking)
return value
elif isinstance(original_value, float):
try:
return float(value)
except ValueError:
# If conversion fails, keep as string (user error, but non-breaking)
return value
else:
return value
# Process environment variables
for env_key, env_value in os.environ.items():
# Check if the environment variable is uppercase and contains __
if not env_key.isupper():
continue
if '__' not in env_key:
continue
# Convert environment variable name to config path
# e.g., CONCURRENCY__PIPELINE -> ['concurrency', 'pipeline']
keys = [key.lower() for key in env_key.split('__')]
# Navigate to the target value and validate the path
current = cfg
for i, key in enumerate(keys):
if not isinstance(current, dict) or key not in current:
break
if i == len(keys) - 1:
# At the final key - check if it's a scalar value
if isinstance(current[key], (dict, list)):
# Skip dict and list types
pass
else:
# Valid scalar value - convert and set it
converted_value = convert_value(env_value, current[key])
current[key] = converted_value
else:
# Navigate deeper
current = current[key]
return cfg
class TestEnvOverrides:
"""Test environment variable override functionality"""
def test_simple_string_override(self):
"""Test overriding a simple string value"""
cfg = {
'api': {
'port': 5300
}
}
# Set environment variable
os.environ['API__PORT'] = '8080'
result = _apply_env_overrides_to_config(cfg)
assert result['api']['port'] == 8080
# Cleanup
del os.environ['API__PORT']
def test_nested_key_override(self):
"""Test overriding nested keys with __ delimiter"""
cfg = {
'concurrency': {
'pipeline': 20,
'session': 1
}
}
os.environ['CONCURRENCY__PIPELINE'] = '50'
result = _apply_env_overrides_to_config(cfg)
assert result['concurrency']['pipeline'] == 50
assert result['concurrency']['session'] == 1 # Unchanged
del os.environ['CONCURRENCY__PIPELINE']
def test_deep_nested_override(self):
"""Test overriding deeply nested keys"""
cfg = {
'system': {
'jwt': {
'expire': 604800,
'secret': ''
}
}
}
os.environ['SYSTEM__JWT__EXPIRE'] = '86400'
os.environ['SYSTEM__JWT__SECRET'] = 'my_secret_key'
result = _apply_env_overrides_to_config(cfg)
assert result['system']['jwt']['expire'] == 86400
assert result['system']['jwt']['secret'] == 'my_secret_key'
del os.environ['SYSTEM__JWT__EXPIRE']
del os.environ['SYSTEM__JWT__SECRET']
def test_underscore_in_key(self):
"""Test keys with underscores like runtime_ws_url"""
cfg = {
'plugin': {
'enable': True,
'runtime_ws_url': 'ws://localhost:5400/control/ws'
}
}
os.environ['PLUGIN__RUNTIME_WS_URL'] = 'ws://newhost:6000/ws'
result = _apply_env_overrides_to_config(cfg)
assert result['plugin']['runtime_ws_url'] == 'ws://newhost:6000/ws'
del os.environ['PLUGIN__RUNTIME_WS_URL']
def test_boolean_conversion(self):
"""Test boolean value conversion"""
cfg = {
'plugin': {
'enable': True,
'enable_marketplace': False
}
}
os.environ['PLUGIN__ENABLE'] = 'false'
os.environ['PLUGIN__ENABLE_MARKETPLACE'] = 'true'
result = _apply_env_overrides_to_config(cfg)
assert result['plugin']['enable'] is False
assert result['plugin']['enable_marketplace'] is True
del os.environ['PLUGIN__ENABLE']
del os.environ['PLUGIN__ENABLE_MARKETPLACE']
def test_ignore_dict_type(self):
"""Test that dict types are ignored"""
cfg = {
'database': {
'use': 'sqlite',
'sqlite': {
'path': 'data/langbot.db'
}
}
}
# Try to override a dict value - should be ignored
os.environ['DATABASE__SQLITE'] = 'new_value'
result = _apply_env_overrides_to_config(cfg)
# Should remain a dict, not overridden
assert isinstance(result['database']['sqlite'], dict)
assert result['database']['sqlite']['path'] == 'data/langbot.db'
del os.environ['DATABASE__SQLITE']
def test_ignore_list_type(self):
"""Test that list/array types are ignored"""
cfg = {
'admins': ['admin1', 'admin2'],
'command': {
'enable': True,
'prefix': ['!', '']
}
}
# Try to override list values - should be ignored
os.environ['ADMINS'] = 'admin3'
os.environ['COMMAND__PREFIX'] = '?'
result = _apply_env_overrides_to_config(cfg)
# Should remain lists, not overridden
assert isinstance(result['admins'], list)
assert result['admins'] == ['admin1', 'admin2']
assert isinstance(result['command']['prefix'], list)
assert result['command']['prefix'] == ['!', '']
del os.environ['ADMINS']
del os.environ['COMMAND__PREFIX']
def test_lowercase_env_var_ignored(self):
"""Test that lowercase environment variables are ignored"""
cfg = {
'api': {
'port': 5300
}
}
os.environ['api__port'] = '8080'
result = _apply_env_overrides_to_config(cfg)
# Should not be overridden
assert result['api']['port'] == 5300
del os.environ['api__port']
def test_no_double_underscore_ignored(self):
"""Test that env vars without __ are ignored"""
cfg = {
'api': {
'port': 5300
}
}
os.environ['APIPORT'] = '8080'
result = _apply_env_overrides_to_config(cfg)
# Should not be overridden
assert result['api']['port'] == 5300
del os.environ['APIPORT']
def test_nonexistent_key_ignored(self):
"""Test that env vars for non-existent keys are ignored"""
cfg = {
'api': {
'port': 5300
}
}
os.environ['API__NONEXISTENT'] = 'value'
result = _apply_env_overrides_to_config(cfg)
# Should not create new key
assert 'nonexistent' not in result['api']
del os.environ['API__NONEXISTENT']
def test_integer_conversion(self):
"""Test integer value conversion"""
cfg = {
'concurrency': {
'pipeline': 20
}
}
os.environ['CONCURRENCY__PIPELINE'] = '100'
result = _apply_env_overrides_to_config(cfg)
assert result['concurrency']['pipeline'] == 100
assert isinstance(result['concurrency']['pipeline'], int)
del os.environ['CONCURRENCY__PIPELINE']
def test_multiple_overrides(self):
"""Test multiple environment variable overrides at once"""
cfg = {
'api': {
'port': 5300
},
'concurrency': {
'pipeline': 20,
'session': 1
},
'plugin': {
'enable': False
}
}
os.environ['API__PORT'] = '8080'
os.environ['CONCURRENCY__PIPELINE'] = '50'
os.environ['PLUGIN__ENABLE'] = 'true'
result = _apply_env_overrides_to_config(cfg)
assert result['api']['port'] == 8080
assert result['concurrency']['pipeline'] == 50
assert result['plugin']['enable'] is True
del os.environ['API__PORT']
del os.environ['CONCURRENCY__PIPELINE']
del os.environ['PLUGIN__ENABLE']
if __name__ == '__main__':
pytest.main([__file__, '-v'])

View File

View File

@@ -0,0 +1,251 @@
"""
Shared test fixtures and configuration
This file provides infrastructure for all pipeline tests, including:
- Mock object factories
- Test fixtures
- Common test helper functions
"""
from __future__ import annotations
import pytest
from unittest.mock import AsyncMock, MagicMock, Mock
from typing import Any
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.platform.events as platform_events
import langbot_plugin.api.entities.builtin.provider.session as provider_session
import langbot_plugin.api.entities.builtin.provider.message as provider_message
from pkg.pipeline import entities as pipeline_entities
class MockApplication:
"""Mock Application object providing all basic dependencies needed by stages"""
def __init__(self):
self.logger = self._create_mock_logger()
self.sess_mgr = self._create_mock_session_manager()
self.model_mgr = self._create_mock_model_manager()
self.tool_mgr = self._create_mock_tool_manager()
self.plugin_connector = self._create_mock_plugin_connector()
self.persistence_mgr = self._create_mock_persistence_manager()
self.query_pool = self._create_mock_query_pool()
self.instance_config = self._create_mock_instance_config()
self.task_mgr = self._create_mock_task_manager()
def _create_mock_logger(self):
logger = Mock()
logger.debug = Mock()
logger.info = Mock()
logger.error = Mock()
logger.warning = Mock()
return logger
def _create_mock_session_manager(self):
sess_mgr = AsyncMock()
sess_mgr.get_session = AsyncMock()
sess_mgr.get_conversation = AsyncMock()
return sess_mgr
def _create_mock_model_manager(self):
model_mgr = AsyncMock()
model_mgr.get_model_by_uuid = AsyncMock()
return model_mgr
def _create_mock_tool_manager(self):
tool_mgr = AsyncMock()
tool_mgr.get_all_tools = AsyncMock(return_value=[])
return tool_mgr
def _create_mock_plugin_connector(self):
plugin_connector = AsyncMock()
plugin_connector.emit_event = AsyncMock()
return plugin_connector
def _create_mock_persistence_manager(self):
persistence_mgr = AsyncMock()
persistence_mgr.execute_async = AsyncMock()
return persistence_mgr
def _create_mock_query_pool(self):
query_pool = Mock()
query_pool.cached_queries = {}
query_pool.queries = []
query_pool.condition = AsyncMock()
return query_pool
def _create_mock_instance_config(self):
instance_config = Mock()
instance_config.data = {
'command': {'prefix': ['/', '!'], 'enable': True},
'concurrency': {'pipeline': 10},
}
return instance_config
def _create_mock_task_manager(self):
task_mgr = Mock()
task_mgr.create_task = Mock()
return task_mgr
@pytest.fixture
def mock_app():
"""Provides Mock Application instance"""
return MockApplication()
@pytest.fixture
def mock_session():
"""Provides Mock Session object"""
session = Mock()
session.launcher_type = provider_session.LauncherTypes.PERSON
session.launcher_id = 12345
session._semaphore = AsyncMock()
session._semaphore.locked = Mock(return_value=False)
session._semaphore.acquire = AsyncMock()
session._semaphore.release = AsyncMock()
return session
@pytest.fixture
def mock_conversation():
"""Provides Mock Conversation object"""
conversation = Mock()
conversation.uuid = 'test-conversation-uuid'
# Create mock prompt with copy method
mock_prompt = Mock()
mock_prompt.messages = []
mock_prompt.copy = Mock(return_value=Mock(messages=[]))
conversation.prompt = mock_prompt
# Create mock messages list with copy method
mock_messages = Mock()
mock_messages.copy = Mock(return_value=[])
conversation.messages = mock_messages
return conversation
@pytest.fixture
def mock_model():
"""Provides Mock Model object"""
model = Mock()
model.model_entity = Mock()
model.model_entity.uuid = 'test-model-uuid'
model.model_entity.abilities = ['func_call', 'vision']
return model
@pytest.fixture
def mock_adapter():
"""Provides Mock Adapter object"""
adapter = AsyncMock()
adapter.is_stream_output_supported = AsyncMock(return_value=False)
adapter.reply_message = AsyncMock()
adapter.reply_message_chunk = AsyncMock()
return adapter
@pytest.fixture
def sample_message_chain():
"""Provides sample message chain"""
return platform_message.MessageChain(
[
platform_message.Plain(text='Hello, this is a test message'),
]
)
@pytest.fixture
def sample_message_event(sample_message_chain):
"""Provides sample message event"""
event = Mock()
event.sender = Mock()
event.sender.id = 12345
event.time = 1609459200 # 2021-01-01 00:00:00
return event
@pytest.fixture
def sample_query(sample_message_chain, sample_message_event, mock_adapter):
"""Provides sample Query object - using model_construct to bypass validation"""
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
# Use model_construct to bypass Pydantic validation for test purposes
query = pipeline_query.Query.model_construct(
query_id='test-query-id',
launcher_type=provider_session.LauncherTypes.PERSON,
launcher_id=12345,
sender_id=12345,
message_chain=sample_message_chain,
message_event=sample_message_event,
adapter=mock_adapter,
pipeline_uuid='test-pipeline-uuid',
bot_uuid='test-bot-uuid',
pipeline_config={
'ai': {
'runner': {'runner': 'local-agent'},
'local-agent': {'model': 'test-model-uuid', 'prompt': 'test-prompt'},
},
'output': {'misc': {'at-sender': False, 'quote-origin': False}},
'trigger': {'misc': {'combine-quote-message': False}},
},
session=None,
prompt=None,
messages=[],
user_message=None,
use_funcs=[],
use_llm_model_uuid=None,
variables={},
resp_messages=[],
resp_message_chain=None,
current_stage_name=None
)
return query
@pytest.fixture
def sample_pipeline_config():
"""Provides sample pipeline configuration"""
return {
'ai': {
'runner': {'runner': 'local-agent'},
'local-agent': {'model': 'test-model-uuid', 'prompt': 'test-prompt'},
},
'output': {'misc': {'at-sender': False, 'quote-origin': False}},
'trigger': {'misc': {'combine-quote-message': False}},
'ratelimit': {'enable': True, 'algo': 'fixwin', 'window': 60, 'limit': 10},
}
def create_stage_result(
result_type: pipeline_entities.ResultType,
query: pipeline_query.Query,
user_notice: str = '',
console_notice: str = '',
debug_notice: str = '',
error_notice: str = '',
) -> pipeline_entities.StageProcessResult:
"""Helper function to create stage process result"""
return pipeline_entities.StageProcessResult(
result_type=result_type,
new_query=query,
user_notice=user_notice,
console_notice=console_notice,
debug_notice=debug_notice,
error_notice=error_notice,
)
def assert_result_continue(result: pipeline_entities.StageProcessResult):
"""Assert result is CONTINUE type"""
assert result.result_type == pipeline_entities.ResultType.CONTINUE
def assert_result_interrupt(result: pipeline_entities.StageProcessResult):
"""Assert result is INTERRUPT type"""
assert result.result_type == pipeline_entities.ResultType.INTERRUPT

View File

@@ -0,0 +1,189 @@
"""
BanSessionCheckStage unit tests
Tests the actual BanSessionCheckStage implementation from pkg.pipeline.bansess
"""
import pytest
from unittest.mock import Mock
from importlib import import_module
import langbot_plugin.api.entities.builtin.provider.session as provider_session
def get_modules():
"""Lazy import to ensure proper initialization order"""
# Import pipelinemgr first to trigger proper stage registration
pipelinemgr = import_module('pkg.pipeline.pipelinemgr')
bansess = import_module('pkg.pipeline.bansess.bansess')
entities = import_module('pkg.pipeline.entities')
return bansess, entities
@pytest.mark.asyncio
async def test_whitelist_allow(mock_app, sample_query):
"""Test whitelist allows matching session"""
bansess, entities = get_modules()
sample_query.launcher_type = provider_session.LauncherTypes.PERSON
sample_query.launcher_id = '12345'
sample_query.pipeline_config = {
'trigger': {
'access-control': {
'mode': 'whitelist',
'whitelist': ['person_12345']
}
}
}
stage = bansess.BanSessionCheckStage(mock_app)
await stage.initialize(sample_query.pipeline_config)
result = await stage.process(sample_query, 'BanSessionCheckStage')
assert result.result_type == entities.ResultType.CONTINUE
assert result.new_query == sample_query
@pytest.mark.asyncio
async def test_whitelist_deny(mock_app, sample_query):
"""Test whitelist denies non-matching session"""
bansess, entities = get_modules()
sample_query.launcher_type = provider_session.LauncherTypes.PERSON
sample_query.launcher_id = '99999'
sample_query.pipeline_config = {
'trigger': {
'access-control': {
'mode': 'whitelist',
'whitelist': ['person_12345']
}
}
}
stage = bansess.BanSessionCheckStage(mock_app)
await stage.initialize(sample_query.pipeline_config)
result = await stage.process(sample_query, 'BanSessionCheckStage')
assert result.result_type == entities.ResultType.INTERRUPT
@pytest.mark.asyncio
async def test_blacklist_allow(mock_app, sample_query):
"""Test blacklist allows non-matching session"""
bansess, entities = get_modules()
sample_query.launcher_type = provider_session.LauncherTypes.PERSON
sample_query.launcher_id = '12345'
sample_query.pipeline_config = {
'trigger': {
'access-control': {
'mode': 'blacklist',
'blacklist': ['person_99999']
}
}
}
stage = bansess.BanSessionCheckStage(mock_app)
await stage.initialize(sample_query.pipeline_config)
result = await stage.process(sample_query, 'BanSessionCheckStage')
assert result.result_type == entities.ResultType.CONTINUE
@pytest.mark.asyncio
async def test_blacklist_deny(mock_app, sample_query):
"""Test blacklist denies matching session"""
bansess, entities = get_modules()
sample_query.launcher_type = provider_session.LauncherTypes.PERSON
sample_query.launcher_id = '12345'
sample_query.pipeline_config = {
'trigger': {
'access-control': {
'mode': 'blacklist',
'blacklist': ['person_12345']
}
}
}
stage = bansess.BanSessionCheckStage(mock_app)
await stage.initialize(sample_query.pipeline_config)
result = await stage.process(sample_query, 'BanSessionCheckStage')
assert result.result_type == entities.ResultType.INTERRUPT
@pytest.mark.asyncio
async def test_wildcard_group(mock_app, sample_query):
"""Test group wildcard matching"""
bansess, entities = get_modules()
sample_query.launcher_type = provider_session.LauncherTypes.GROUP
sample_query.launcher_id = '12345'
sample_query.pipeline_config = {
'trigger': {
'access-control': {
'mode': 'whitelist',
'whitelist': ['group_*']
}
}
}
stage = bansess.BanSessionCheckStage(mock_app)
await stage.initialize(sample_query.pipeline_config)
result = await stage.process(sample_query, 'BanSessionCheckStage')
assert result.result_type == entities.ResultType.CONTINUE
@pytest.mark.asyncio
async def test_wildcard_person(mock_app, sample_query):
"""Test person wildcard matching"""
bansess, entities = get_modules()
sample_query.launcher_type = provider_session.LauncherTypes.PERSON
sample_query.launcher_id = '12345'
sample_query.pipeline_config = {
'trigger': {
'access-control': {
'mode': 'whitelist',
'whitelist': ['person_*']
}
}
}
stage = bansess.BanSessionCheckStage(mock_app)
await stage.initialize(sample_query.pipeline_config)
result = await stage.process(sample_query, 'BanSessionCheckStage')
assert result.result_type == entities.ResultType.CONTINUE
@pytest.mark.asyncio
async def test_user_id_wildcard(mock_app, sample_query):
"""Test user ID wildcard matching (*_id format)"""
bansess, entities = get_modules()
sample_query.launcher_type = provider_session.LauncherTypes.PERSON
sample_query.launcher_id = '12345'
sample_query.sender_id = '67890'
sample_query.pipeline_config = {
'trigger': {
'access-control': {
'mode': 'whitelist',
'whitelist': ['*_67890']
}
}
}
stage = bansess.BanSessionCheckStage(mock_app)
await stage.initialize(sample_query.pipeline_config)
result = await stage.process(sample_query, 'BanSessionCheckStage')
assert result.result_type == entities.ResultType.CONTINUE

View File

@@ -0,0 +1,166 @@
"""
PipelineManager unit tests
"""
import pytest
from unittest.mock import AsyncMock, Mock
from importlib import import_module
import sqlalchemy
def get_pipelinemgr_module():
return import_module('pkg.pipeline.pipelinemgr')
def get_stage_module():
return import_module('pkg.pipeline.stage')
def get_entities_module():
return import_module('pkg.pipeline.entities')
def get_persistence_pipeline_module():
return import_module('pkg.entity.persistence.pipeline')
@pytest.mark.asyncio
async def test_pipeline_manager_initialize(mock_app):
"""Test pipeline manager initialization"""
pipelinemgr = get_pipelinemgr_module()
mock_app.persistence_mgr.execute_async = AsyncMock(return_value=Mock(all=Mock(return_value=[])))
manager = pipelinemgr.PipelineManager(mock_app)
await manager.initialize()
assert manager.stage_dict is not None
assert len(manager.pipelines) == 0
@pytest.mark.asyncio
async def test_load_pipeline(mock_app):
"""Test loading a single pipeline"""
pipelinemgr = get_pipelinemgr_module()
persistence_pipeline = get_persistence_pipeline_module()
mock_app.persistence_mgr.execute_async = AsyncMock(return_value=Mock(all=Mock(return_value=[])))
manager = pipelinemgr.PipelineManager(mock_app)
await manager.initialize()
# Create test pipeline entity
pipeline_entity = Mock(spec=persistence_pipeline.LegacyPipeline)
pipeline_entity.uuid = 'test-uuid'
pipeline_entity.stages = []
pipeline_entity.config = {'test': 'config'}
await manager.load_pipeline(pipeline_entity)
assert len(manager.pipelines) == 1
assert manager.pipelines[0].pipeline_entity.uuid == 'test-uuid'
@pytest.mark.asyncio
async def test_get_pipeline_by_uuid(mock_app):
"""Test getting pipeline by UUID"""
pipelinemgr = get_pipelinemgr_module()
persistence_pipeline = get_persistence_pipeline_module()
mock_app.persistence_mgr.execute_async = AsyncMock(return_value=Mock(all=Mock(return_value=[])))
manager = pipelinemgr.PipelineManager(mock_app)
await manager.initialize()
# Create and add test pipeline
pipeline_entity = Mock(spec=persistence_pipeline.LegacyPipeline)
pipeline_entity.uuid = 'test-uuid'
pipeline_entity.stages = []
pipeline_entity.config = {}
await manager.load_pipeline(pipeline_entity)
# Test retrieval
result = await manager.get_pipeline_by_uuid('test-uuid')
assert result is not None
assert result.pipeline_entity.uuid == 'test-uuid'
# Test non-existent UUID
result = await manager.get_pipeline_by_uuid('non-existent')
assert result is None
@pytest.mark.asyncio
async def test_remove_pipeline(mock_app):
"""Test removing a pipeline"""
pipelinemgr = get_pipelinemgr_module()
persistence_pipeline = get_persistence_pipeline_module()
mock_app.persistence_mgr.execute_async = AsyncMock(return_value=Mock(all=Mock(return_value=[])))
manager = pipelinemgr.PipelineManager(mock_app)
await manager.initialize()
# Create and add test pipeline
pipeline_entity = Mock(spec=persistence_pipeline.LegacyPipeline)
pipeline_entity.uuid = 'test-uuid'
pipeline_entity.stages = []
pipeline_entity.config = {}
await manager.load_pipeline(pipeline_entity)
assert len(manager.pipelines) == 1
# Remove pipeline
await manager.remove_pipeline('test-uuid')
assert len(manager.pipelines) == 0
@pytest.mark.asyncio
async def test_runtime_pipeline_execute(mock_app, sample_query):
"""Test runtime pipeline execution"""
pipelinemgr = get_pipelinemgr_module()
stage = get_stage_module()
persistence_pipeline = get_persistence_pipeline_module()
# Create mock stage that returns a simple result dict (avoiding Pydantic validation)
mock_result = Mock()
mock_result.result_type = Mock()
mock_result.result_type.value = 'CONTINUE' # Simulate enum value
mock_result.new_query = sample_query
mock_result.user_notice = ''
mock_result.console_notice = ''
mock_result.debug_notice = ''
mock_result.error_notice = ''
# Make it look like ResultType.CONTINUE
from unittest.mock import MagicMock
CONTINUE = MagicMock()
CONTINUE.__eq__ = lambda self, other: True # Always equal for comparison
mock_result.result_type = CONTINUE
mock_stage = Mock(spec=stage.PipelineStage)
mock_stage.process = AsyncMock(return_value=mock_result)
# Create stage container
stage_container = pipelinemgr.StageInstContainer(inst_name='TestStage', inst=mock_stage)
# Create pipeline entity
pipeline_entity = Mock(spec=persistence_pipeline.LegacyPipeline)
pipeline_entity.config = sample_query.pipeline_config
# Create runtime pipeline
runtime_pipeline = pipelinemgr.RuntimePipeline(mock_app, pipeline_entity, [stage_container])
# Mock plugin connector
event_ctx = Mock()
event_ctx.is_prevented_default = Mock(return_value=False)
mock_app.plugin_connector.emit_event = AsyncMock(return_value=event_ctx)
# Add query to cached_queries to prevent KeyError in finally block
mock_app.query_pool.cached_queries[sample_query.query_id] = sample_query
# Execute pipeline
await runtime_pipeline.run(sample_query)
# Verify stage was called
mock_stage.process.assert_called_once()

View File

@@ -0,0 +1,109 @@
"""
RateLimit stage unit tests
Tests the actual RateLimit implementation from pkg.pipeline.ratelimit
"""
import pytest
from unittest.mock import AsyncMock, Mock, patch
from importlib import import_module
import langbot_plugin.api.entities.builtin.provider.session as provider_session
def get_modules():
"""Lazy import to ensure proper initialization order"""
# Import pipelinemgr first to trigger proper stage registration
pipelinemgr = import_module('pkg.pipeline.pipelinemgr')
ratelimit = import_module('pkg.pipeline.ratelimit.ratelimit')
entities = import_module('pkg.pipeline.entities')
algo_module = import_module('pkg.pipeline.ratelimit.algo')
return ratelimit, entities, algo_module
@pytest.mark.asyncio
async def test_require_access_allowed(mock_app, sample_query):
"""Test RequireRateLimitOccupancy allows access when rate limit is not exceeded"""
ratelimit, entities, algo_module = get_modules()
sample_query.launcher_type = provider_session.LauncherTypes.PERSON
sample_query.launcher_id = '12345'
sample_query.pipeline_config = {}
# Create mock algorithm that allows access
mock_algo = Mock(spec=algo_module.ReteLimitAlgo)
mock_algo.require_access = AsyncMock(return_value=True)
mock_algo.initialize = AsyncMock()
stage = ratelimit.RateLimit(mock_app)
# Patch the algorithm selection to use our mock
with patch.object(algo_module, 'preregistered_algos', []):
stage.algo = mock_algo
result = await stage.process(sample_query, 'RequireRateLimitOccupancy')
assert result.result_type == entities.ResultType.CONTINUE
assert result.new_query == sample_query
mock_algo.require_access.assert_called_once_with(
sample_query,
'person',
'12345'
)
@pytest.mark.asyncio
async def test_require_access_denied(mock_app, sample_query):
"""Test RequireRateLimitOccupancy denies access when rate limit is exceeded"""
ratelimit, entities, algo_module = get_modules()
sample_query.launcher_type = provider_session.LauncherTypes.PERSON
sample_query.launcher_id = '12345'
sample_query.pipeline_config = {}
# Create mock algorithm that denies access
mock_algo = Mock(spec=algo_module.ReteLimitAlgo)
mock_algo.require_access = AsyncMock(return_value=False)
mock_algo.initialize = AsyncMock()
stage = ratelimit.RateLimit(mock_app)
# Patch the algorithm selection to use our mock
with patch.object(algo_module, 'preregistered_algos', []):
stage.algo = mock_algo
result = await stage.process(sample_query, 'RequireRateLimitOccupancy')
assert result.result_type == entities.ResultType.INTERRUPT
assert result.user_notice != ''
mock_algo.require_access.assert_called_once()
@pytest.mark.asyncio
async def test_release_access(mock_app, sample_query):
"""Test ReleaseRateLimitOccupancy releases rate limit occupancy"""
ratelimit, entities, algo_module = get_modules()
sample_query.launcher_type = provider_session.LauncherTypes.PERSON
sample_query.launcher_id = '12345'
sample_query.pipeline_config = {}
# Create mock algorithm
mock_algo = Mock(spec=algo_module.ReteLimitAlgo)
mock_algo.release_access = AsyncMock()
mock_algo.initialize = AsyncMock()
stage = ratelimit.RateLimit(mock_app)
# Patch the algorithm selection to use our mock
with patch.object(algo_module, 'preregistered_algos', []):
stage.algo = mock_algo
result = await stage.process(sample_query, 'ReleaseRateLimitOccupancy')
assert result.result_type == entities.ResultType.CONTINUE
assert result.new_query == sample_query
mock_algo.release_access.assert_called_once_with(
sample_query,
'person',
'12345'
)

View File

@@ -0,0 +1,171 @@
"""
GroupRespondRuleCheckStage unit tests
Tests the actual GroupRespondRuleCheckStage implementation from pkg.pipeline.resprule
"""
import pytest
from unittest.mock import AsyncMock, Mock
from importlib import import_module
import langbot_plugin.api.entities.builtin.provider.session as provider_session
import langbot_plugin.api.entities.builtin.platform.message as platform_message
def get_modules():
"""Lazy import to ensure proper initialization order"""
# Import pipelinemgr first to trigger proper stage registration
pipelinemgr = import_module('pkg.pipeline.pipelinemgr')
resprule = import_module('pkg.pipeline.resprule.resprule')
entities = import_module('pkg.pipeline.entities')
rule = import_module('pkg.pipeline.resprule.rule')
rule_entities = import_module('pkg.pipeline.resprule.entities')
return resprule, entities, rule, rule_entities
@pytest.mark.asyncio
async def test_person_message_skip(mock_app, sample_query):
"""Test person message skips rule check"""
resprule, entities, rule, rule_entities = get_modules()
sample_query.launcher_type = provider_session.LauncherTypes.PERSON
sample_query.pipeline_config = {
'trigger': {
'group-respond-rules': {}
}
}
stage = resprule.GroupRespondRuleCheckStage(mock_app)
await stage.initialize(sample_query.pipeline_config)
result = await stage.process(sample_query, 'GroupRespondRuleCheckStage')
assert result.result_type == entities.ResultType.CONTINUE
assert result.new_query == sample_query
@pytest.mark.asyncio
async def test_group_message_no_match(mock_app, sample_query):
"""Test group message with no matching rules"""
resprule, entities, rule, rule_entities = get_modules()
sample_query.launcher_type = provider_session.LauncherTypes.GROUP
sample_query.launcher_id = '12345'
sample_query.pipeline_config = {
'trigger': {
'group-respond-rules': {}
}
}
# Create mock rule matcher that doesn't match
mock_rule = Mock(spec=rule.GroupRespondRule)
mock_rule.match = AsyncMock(return_value=rule_entities.RuleJudgeResult(
matching=False,
replacement=sample_query.message_chain
))
stage = resprule.GroupRespondRuleCheckStage(mock_app)
await stage.initialize(sample_query.pipeline_config)
stage.rule_matchers = [mock_rule]
result = await stage.process(sample_query, 'GroupRespondRuleCheckStage')
assert result.result_type == entities.ResultType.INTERRUPT
assert result.new_query == sample_query
mock_rule.match.assert_called_once()
@pytest.mark.asyncio
async def test_group_message_match(mock_app, sample_query):
"""Test group message with matching rule"""
resprule, entities, rule, rule_entities = get_modules()
sample_query.launcher_type = provider_session.LauncherTypes.GROUP
sample_query.launcher_id = '12345'
sample_query.pipeline_config = {
'trigger': {
'group-respond-rules': {}
}
}
# Create new message chain after rule processing
new_chain = platform_message.MessageChain([
platform_message.Plain(text='Processed message')
])
# Create mock rule matcher that matches
mock_rule = Mock(spec=rule.GroupRespondRule)
mock_rule.match = AsyncMock(return_value=rule_entities.RuleJudgeResult(
matching=True,
replacement=new_chain
))
stage = resprule.GroupRespondRuleCheckStage(mock_app)
await stage.initialize(sample_query.pipeline_config)
stage.rule_matchers = [mock_rule]
result = await stage.process(sample_query, 'GroupRespondRuleCheckStage')
assert result.result_type == entities.ResultType.CONTINUE
assert result.new_query == sample_query
assert sample_query.message_chain == new_chain
mock_rule.match.assert_called_once()
@pytest.mark.asyncio
async def test_atbot_rule_match(mock_app, sample_query):
"""Test AtBotRule removes At component"""
resprule, entities, rule, rule_entities = get_modules()
atbot_module = import_module('pkg.pipeline.resprule.rules.atbot')
sample_query.launcher_type = provider_session.LauncherTypes.GROUP
sample_query.adapter.bot_account_id = '999'
# Create message chain with At component
message_chain = platform_message.MessageChain([
platform_message.At(target='999'),
platform_message.Plain(text='Hello bot')
])
sample_query.message_chain = message_chain
atbot_rule = atbot_module.AtBotRule(mock_app)
await atbot_rule.initialize()
result = await atbot_rule.match(
str(message_chain),
message_chain,
{},
sample_query
)
assert result.matching is True
# At component should be removed
assert len(result.replacement.root) == 1
assert isinstance(result.replacement.root[0], platform_message.Plain)
@pytest.mark.asyncio
async def test_atbot_rule_no_match(mock_app, sample_query):
"""Test AtBotRule when no At component present"""
resprule, entities, rule, rule_entities = get_modules()
atbot_module = import_module('pkg.pipeline.resprule.rules.atbot')
sample_query.launcher_type = provider_session.LauncherTypes.GROUP
sample_query.adapter.bot_account_id = '999'
# Create message chain without At component
message_chain = platform_message.MessageChain([
platform_message.Plain(text='Hello')
])
sample_query.message_chain = message_chain
atbot_rule = atbot_module.AtBotRule(mock_app)
await atbot_rule.initialize()
result = await atbot_rule.match(
str(message_chain),
message_chain,
{},
sample_query
)
assert result.matching is False

View File

@@ -0,0 +1,40 @@
"""
Simple standalone tests to verify test infrastructure
These tests don't import the actual pipeline code to avoid circular import issues
"""
import pytest
from unittest.mock import Mock, AsyncMock
def test_pytest_works():
"""Verify pytest is working"""
assert True
@pytest.mark.asyncio
async def test_async_works():
"""Verify async tests work"""
mock = AsyncMock(return_value=42)
result = await mock()
assert result == 42
def test_mocks_work():
"""Verify mocking works"""
mock = Mock()
mock.return_value = 'test'
assert mock() == 'test'
def test_fixtures_work(mock_app):
"""Verify fixtures are loaded"""
assert mock_app is not None
assert mock_app.logger is not None
assert mock_app.sess_mgr is not None
def test_sample_query(sample_query):
"""Verify sample query fixture works"""
assert sample_query.query_id == 'test-query-id'
assert sample_query.launcher_id == 12345

View File

@@ -23,6 +23,7 @@
"@dnd-kit/sortable": "^10.0.0",
"@dnd-kit/utilities": "^3.2.2",
"@hookform/resolvers": "^5.0.1",
"@radix-ui/react-alert-dialog": "^1.1.15",
"@radix-ui/react-checkbox": "^1.3.1",
"@radix-ui/react-context-menu": "^2.2.15",
"@radix-ui/react-dialog": "^1.1.14",

View File

@@ -115,7 +115,6 @@ export default function BotForm({
useEffect(() => {
setBotFormValues();
// eslint-disable-next-line react-hooks/exhaustive-deps
}, []);
function setBotFormValues() {

View File

@@ -11,18 +11,23 @@ import {
FormMessage,
} from '@/components/ui/form';
import DynamicFormItemComponent from '@/app/home/components/dynamic-form/DynamicFormItemComponent';
import { useEffect } from 'react';
import { useEffect, useRef } from 'react';
import { extractI18nObject } from '@/i18n/I18nProvider';
export default function DynamicFormComponent({
itemConfigList,
onSubmit,
initialValues,
onFileUploaded,
}: {
itemConfigList: IDynamicFormItemSchema[];
onSubmit?: (val: object) => unknown;
initialValues?: Record<string, object>;
onFileUploaded?: (fileKey: string) => void;
}) {
const isInitialMount = useRef(true);
const previousInitialValues = useRef(initialValues);
// 根据 itemConfigList 动态生成 zod schema
const formSchema = z.object(
itemConfigList.reduce(
@@ -97,9 +102,24 @@ export default function DynamicFormComponent({
});
// 当 initialValues 变化时更新表单值
// 但要避免因为内部表单更新触发的 onSubmit 导致的 initialValues 变化而重新设置表单
useEffect(() => {
console.log('initialValues', initialValues);
if (initialValues) {
// 首次挂载时,使用 initialValues 初始化表单
if (isInitialMount.current) {
isInitialMount.current = false;
previousInitialValues.current = initialValues;
return;
}
// 检查 initialValues 是否真的发生了实质性变化
// 使用 JSON.stringify 进行深度比较
const hasRealChange =
JSON.stringify(previousInitialValues.current) !==
JSON.stringify(initialValues);
if (initialValues && hasRealChange) {
// 合并默认值和初始值
const mergedValues = itemConfigList.reduce(
(acc, item) => {
@@ -112,6 +132,8 @@ export default function DynamicFormComponent({
Object.entries(mergedValues).forEach(([key, value]) => {
form.setValue(key as keyof FormValues, value);
});
previousInitialValues.current = initialValues;
}
}, [initialValues, form, itemConfigList]);
@@ -149,7 +171,11 @@ export default function DynamicFormComponent({
{config.required && <span className="text-red-500">*</span>}
</FormLabel>
<FormControl>
<DynamicFormItemComponent config={config} field={field} />
<DynamicFormItemComponent
config={config}
field={field}
onFileUploaded={onFileUploaded}
/>
</FormControl>
{config.description && (
<p className="text-sm text-muted-foreground">

View File

@@ -1,6 +1,7 @@
import {
DynamicFormItemType,
IDynamicFormItemSchema,
IFileConfig,
} from '@/app/infra/entities/form/dynamic';
import { Input } from '@/components/ui/input';
import {
@@ -27,19 +28,53 @@ import {
import { useTranslation } from 'react-i18next';
import { extractI18nObject } from '@/i18n/I18nProvider';
import { Textarea } from '@/components/ui/textarea';
import { Card, CardContent } from '@/components/ui/card';
export default function DynamicFormItemComponent({
config,
field,
onFileUploaded,
}: {
config: IDynamicFormItemSchema;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
field: ControllerRenderProps<any, any>;
onFileUploaded?: (fileKey: string) => void;
}) {
const [llmModels, setLlmModels] = useState<LLMModel[]>([]);
const [knowledgeBases, setKnowledgeBases] = useState<KnowledgeBase[]>([]);
const [uploading, setUploading] = useState<boolean>(false);
const { t } = useTranslation();
const handleFileUpload = async (file: File): Promise<IFileConfig | null> => {
const MAX_FILE_SIZE = 10 * 1024 * 1024; // 10MB
if (file.size > MAX_FILE_SIZE) {
toast.error(t('plugins.fileUpload.tooLarge'));
return null;
}
try {
setUploading(true);
const response = await httpClient.uploadPluginConfigFile(file);
toast.success(t('plugins.fileUpload.success'));
// 通知父组件文件已上传
onFileUploaded?.(response.file_key);
return {
file_key: response.file_key,
mimetype: file.type,
};
} catch (error) {
toast.error(
t('plugins.fileUpload.failed') + ': ' + (error as Error).message,
);
return null;
} finally {
setUploading(false);
}
};
useEffect(() => {
if (config.type === DynamicFormItemType.LLM_MODEL_SELECTOR) {
httpClient
@@ -80,6 +115,9 @@ export default function DynamicFormItemComponent({
case DynamicFormItemType.STRING:
return <Input {...field} />;
case DynamicFormItemType.TEXT:
return <Textarea {...field} className="min-h-[120px]" />;
case DynamicFormItemType.BOOLEAN:
return <Switch checked={field.value} onCheckedChange={field.onChange} />;
@@ -366,6 +404,185 @@ export default function DynamicFormItemComponent({
</div>
);
case DynamicFormItemType.FILE:
return (
<div className="space-y-2">
{field.value && (field.value as IFileConfig).file_key ? (
<Card className="py-3 max-w-full overflow-hidden bg-gray-900">
<CardContent className="flex items-center gap-3 p-0 px-4 min-w-0">
<div className="flex-1 min-w-0 overflow-hidden">
<div
className="text-sm font-medium truncate"
title={(field.value as IFileConfig).file_key}
>
{(field.value as IFileConfig).file_key}
</div>
<div className="text-xs text-muted-foreground truncate">
{(field.value as IFileConfig).mimetype}
</div>
</div>
<Button
type="button"
variant="ghost"
size="sm"
className="flex-shrink-0 h-8 w-8 p-0"
onClick={(e) => {
e.preventDefault();
e.stopPropagation();
field.onChange(null);
}}
title={t('common.delete')}
>
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 24 24"
fill="currentColor"
className="w-4 h-4 text-destructive"
>
<path d="M7 4V2H17V4H22V6H20V21C20 21.5523 19.5523 22 19 22H5C4.44772 22 4 21.5523 4 21V6H2V4H7ZM6 6V20H18V6H6ZM9 9H11V17H9V9ZM13 9H15V17H13V9Z"></path>
</svg>
</Button>
</CardContent>
</Card>
) : (
<div className="relative">
<input
type="file"
accept={config.accept}
disabled={uploading}
onChange={async (e) => {
const file = e.target.files?.[0];
if (file) {
const fileConfig = await handleFileUpload(file);
if (fileConfig) {
field.onChange(fileConfig);
}
}
e.target.value = '';
}}
className="hidden"
id={`file-input-${config.name}`}
/>
<Button
type="button"
variant="outline"
size="sm"
disabled={uploading}
onClick={() =>
document.getElementById(`file-input-${config.name}`)?.click()
}
>
<svg
className="w-4 h-4 mr-2"
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 24 24"
fill="currentColor"
>
<path d="M11 11V5H13V11H19V13H13V19H11V13H5V11H11Z"></path>
</svg>
{uploading
? t('plugins.fileUpload.uploading')
: t('plugins.fileUpload.chooseFile')}
</Button>
</div>
)}
</div>
);
case DynamicFormItemType.FILE_ARRAY:
return (
<div className="space-y-2">
{(field.value as IFileConfig[])?.map(
(fileConfig: IFileConfig, index: number) => (
<Card
key={index}
className="py-3 max-w-full overflow-hidden bg-gray-900"
>
<CardContent className="flex items-center gap-3 p-0 px-4 min-w-0">
<div className="flex-1 min-w-0 overflow-hidden">
<div
className="text-sm font-medium truncate"
title={fileConfig.file_key}
>
{fileConfig.file_key}
</div>
<div className="text-xs text-muted-foreground truncate">
{fileConfig.mimetype}
</div>
</div>
<Button
type="button"
variant="ghost"
size="sm"
className="flex-shrink-0 h-8 w-8 p-0"
onClick={(e) => {
e.preventDefault();
e.stopPropagation();
const newValue = (field.value as IFileConfig[]).filter(
(_: IFileConfig, i: number) => i !== index,
);
field.onChange(newValue);
}}
title={t('common.delete')}
>
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 24 24"
fill="currentColor"
className="w-4 h-4 text-destructive"
>
<path d="M7 4V2H17V4H22V6H20V21C20 21.5523 19.5523 22 19 22H5C4.44772 22 4 21.5523 4 21V6H2V4H7ZM6 6V20H18V6H6ZM9 9H11V17H9V9ZM13 9H15V17H13V9Z"></path>
</svg>
</Button>
</CardContent>
</Card>
),
)}
<div className="relative">
<input
type="file"
accept={config.accept}
disabled={uploading}
onChange={async (e) => {
const file = e.target.files?.[0];
if (file) {
const fileConfig = await handleFileUpload(file);
if (fileConfig) {
field.onChange([...(field.value || []), fileConfig]);
}
}
e.target.value = '';
}}
className="hidden"
id={`file-array-input-${config.name}`}
/>
<Button
type="button"
variant="outline"
size="sm"
disabled={uploading}
onClick={() =>
document
.getElementById(`file-array-input-${config.name}`)
?.click()
}
>
<svg
className="w-4 h-4 mr-2"
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 24 24"
fill="currentColor"
>
<path d="M11 11V5H13V11H19V13H13V19H11V13H5V11H11Z"></path>
</svg>
{uploading
? t('plugins.fileUpload.uploading')
: t('plugins.fileUpload.addFile')}
</Button>
</div>
</div>
);
default:
return <Input {...field} />;
}

View File

@@ -65,7 +65,6 @@ export default function HomeSidebar({
console.error('Failed to fetch GitHub star count:', error);
});
return () => console.log('sidebar.unmounted');
// eslint-disable-next-line react-hooks/exhaustive-deps
}, []);
function handleChildClick(child: SidebarChildVO) {

View File

@@ -3,7 +3,7 @@
import styles from './layout.module.css';
import HomeSidebar from '@/app/home/components/home-sidebar/HomeSidebar';
import HomeTitleBar from '@/app/home/components/home-titlebar/HomeTitleBar';
import React, { useState } from 'react';
import React, { useState, useCallback, useMemo } from 'react';
import { SidebarChildVO } from '@/app/home/components/home-sidebar/HomeSidebarChild';
import { I18nObject } from '@/app/infra/entities/common';
@@ -18,11 +18,15 @@ export default function HomeLayout({
en_US: '',
zh_Hans: '',
});
const onSelectedChangeAction = (child: SidebarChildVO) => {
const onSelectedChangeAction = useCallback((child: SidebarChildVO) => {
setTitle(child.name);
setSubtitle(child.description);
setHelpLink(child.helpLink);
};
}, []);
// Memoize the main content area to prevent re-renders when sidebar state changes
const mainContent = useMemo(() => children, [children]);
return (
<div className={styles.homeLayoutContainer}>
@@ -33,7 +37,7 @@ export default function HomeLayout({
<div className={styles.main}>
<HomeTitleBar title={title} subtitle={subtitle} helpLink={helpLink} />
<main className={styles.mainContent}>{children}</main>
<main className={styles.mainContent}>{mainContent}</main>
</div>
</div>
);

View File

@@ -29,7 +29,17 @@ export default function PluginConfigPage() {
const [sortOrderValue, setSortOrderValue] = useState<string>('DESC');
useEffect(() => {
getPipelines();
// Load sort preference from localStorage
const savedSortBy = localStorage.getItem('pipeline_sort_by');
const savedSortOrder = localStorage.getItem('pipeline_sort_order');
if (savedSortBy && savedSortOrder) {
setSortByValue(savedSortBy);
setSortOrderValue(savedSortOrder);
getPipelines(savedSortBy, savedSortOrder);
} else {
getPipelines();
}
}, []);
function getPipelines(
@@ -91,6 +101,11 @@ export default function PluginConfigPage() {
const [newSortBy, newSortOrder] = value.split(',').map((s) => s.trim());
setSortByValue(newSortBy);
setSortOrderValue(newSortOrder);
// Save sort preference to localStorage
localStorage.setItem('pipeline_sort_by', newSortBy);
localStorage.setItem('pipeline_sort_order', newSortOrder);
getPipelines(newSortBy, newSortOrder);
}
@@ -135,6 +150,12 @@ export default function PluginConfigPage() {
>
{t('pipelines.newestCreated')}
</SelectItem>
<SelectItem
value="created_at,ASC"
className="text-gray-900 dark:text-gray-100"
>
{t('pipelines.earliestCreated')}
</SelectItem>
<SelectItem
value="updated_at,DESC"
className="text-gray-900 dark:text-gray-100"

View File

@@ -0,0 +1,75 @@
import { PluginComponent } from '@/app/infra/entities/plugin';
import { TFunction } from 'i18next';
import { Wrench, AudioWaveform, Hash } from 'lucide-react';
import { Badge } from '@/components/ui/badge';
export default function PluginComponentList({
components,
showComponentName,
showTitle,
useBadge,
t,
}: {
components: PluginComponent[];
showComponentName: boolean;
showTitle: boolean;
useBadge: boolean;
t: TFunction;
}) {
const componentKindCount: Record<string, number> = {};
for (const component of components) {
const kind = component.manifest.manifest.kind;
if (componentKindCount[kind]) {
componentKindCount[kind]++;
} else {
componentKindCount[kind] = 1;
}
}
const kindIconMap: Record<string, React.ReactNode> = {
Tool: <Wrench className="w-5 h-5" />,
EventListener: <AudioWaveform className="w-5 h-5" />,
Command: <Hash className="w-5 h-5" />,
};
const componentKindList = Object.keys(componentKindCount);
return (
<>
{showTitle && <div>{t('plugins.componentsList')}</div>}
{componentKindList.length > 0 && (
<>
{componentKindList.map((kind) => {
return (
<>
{useBadge && (
<Badge variant="outline">
{kindIconMap[kind]}
{showComponentName &&
t('plugins.componentName.' + kind) + ' '}
{componentKindCount[kind]}
</Badge>
)}
{!useBadge && (
<div
key={kind}
className="flex flex-row items-center justify-start gap-[0.2rem]"
>
{kindIconMap[kind]}
{showComponentName &&
t('plugins.componentName.' + kind) + ' '}
{componentKindCount[kind]}
</div>
)}
</>
);
})}
</>
)}
{componentKindList.length === 0 && <div>{t('plugins.noComponents')}</div>}
</>
);
}

View File

@@ -1,9 +1,9 @@
'use client';
import { useState, useEffect, forwardRef, useImperativeHandle } from 'react';
import { PluginCardVO } from '@/app/home/plugins/plugin-installed/PluginCardVO';
import PluginCardComponent from '@/app/home/plugins/plugin-installed/plugin-card/PluginCardComponent';
import PluginForm from '@/app/home/plugins/plugin-installed/plugin-form/PluginForm';
import { PluginCardVO } from '@/app/home/plugins/components/plugin-installed/PluginCardVO';
import PluginCardComponent from '@/app/home/plugins/components/plugin-installed/plugin-card/PluginCardComponent';
import PluginForm from '@/app/home/plugins/components/plugin-installed/plugin-form/PluginForm';
import styles from '@/app/home/plugins/plugins.module.css';
import { httpClient } from '@/app/infra/http/HttpClient';
import {
@@ -15,6 +15,7 @@ import {
DialogFooter,
} from '@/components/ui/dialog';
import { Button } from '@/components/ui/button';
import { Checkbox } from '@/components/ui/checkbox';
import { useTranslation } from 'react-i18next';
import { extractI18nObject } from '@/i18n/I18nProvider';
import { toast } from 'sonner';
@@ -43,6 +44,7 @@ const PluginInstalledComponent = forwardRef<PluginInstalledComponentRef>(
PluginOperationType.DELETE,
);
const [targetPlugin, setTargetPlugin] = useState<PluginCardVO | null>(null);
const [deleteData, setDeleteData] = useState<boolean>(false);
const asyncTask = useAsyncTask({
onSuccess: () => {
@@ -61,7 +63,6 @@ const PluginInstalledComponent = forwardRef<PluginInstalledComponentRef>(
useEffect(() => {
initData();
// eslint-disable-next-line react-hooks/exhaustive-deps
}, []);
function initData() {
@@ -109,6 +110,7 @@ const PluginInstalledComponent = forwardRef<PluginInstalledComponentRef>(
setTargetPlugin(plugin);
setOperationType(PluginOperationType.DELETE);
setShowOperationModal(true);
setDeleteData(false);
asyncTask.reset();
}
@@ -124,7 +126,11 @@ const PluginInstalledComponent = forwardRef<PluginInstalledComponentRef>(
const apiCall =
operationType === PluginOperationType.DELETE
? httpClient.removePlugin(targetPlugin.author, targetPlugin.name)
? httpClient.removePlugin(
targetPlugin.author,
targetPlugin.name,
deleteData,
)
: httpClient.upgradePlugin(targetPlugin.author, targetPlugin.name);
apiCall
@@ -162,16 +168,35 @@ const PluginInstalledComponent = forwardRef<PluginInstalledComponentRef>(
</DialogHeader>
<DialogDescription>
{asyncTask.status === AsyncTaskStatus.WAIT_INPUT && (
<div>
{operationType === PluginOperationType.DELETE
? t('plugins.confirmDeletePlugin', {
author: targetPlugin?.author ?? '',
name: targetPlugin?.name ?? '',
})
: t('plugins.confirmUpdatePlugin', {
author: targetPlugin?.author ?? '',
name: targetPlugin?.name ?? '',
})}
<div className="flex flex-col gap-4">
<div>
{operationType === PluginOperationType.DELETE
? t('plugins.confirmDeletePlugin', {
author: targetPlugin?.author ?? '',
name: targetPlugin?.name ?? '',
})
: t('plugins.confirmUpdatePlugin', {
author: targetPlugin?.author ?? '',
name: targetPlugin?.name ?? '',
})}
</div>
{operationType === PluginOperationType.DELETE && (
<div className="flex items-center space-x-2">
<Checkbox
id="delete-data"
checked={deleteData}
onCheckedChange={(checked) =>
setDeleteData(checked === true)
}
/>
<label
htmlFor="delete-data"
className="text-sm font-medium leading-none peer-disabled:cursor-not-allowed peer-disabled:opacity-70 cursor-pointer"
>
{t('plugins.deleteDataCheckbox')}
</label>
</div>
)}
</div>
)}
{asyncTask.status === AsyncTaskStatus.RUNNING && (

View File

@@ -1,21 +1,10 @@
import { PluginCardVO } from '@/app/home/plugins/plugin-installed/PluginCardVO';
import { PluginCardVO } from '@/app/home/plugins/components/plugin-installed/PluginCardVO';
import { useState } from 'react';
import { Badge } from '@/components/ui/badge';
import { useTranslation } from 'react-i18next';
import { TFunction } from 'i18next';
import {
AudioWaveform,
Wrench,
Hash,
BugIcon,
ExternalLink,
Ellipsis,
Trash,
ArrowUp,
} from 'lucide-react';
import { BugIcon, ExternalLink, Ellipsis, Trash, ArrowUp } from 'lucide-react';
import { getCloudServiceClientSync } from '@/app/infra/http';
import { httpClient } from '@/app/infra/http/HttpClient';
import { PluginComponent } from '@/app/infra/entities/plugin';
import { Button } from '@/components/ui/button';
import {
DropdownMenu,
@@ -23,49 +12,7 @@ import {
DropdownMenuItem,
DropdownMenuTrigger,
} from '@/components/ui/dropdown-menu';
function getComponentList(components: PluginComponent[], t: TFunction) {
const componentKindCount: Record<string, number> = {};
for (const component of components) {
const kind = component.manifest.manifest.kind;
if (componentKindCount[kind]) {
componentKindCount[kind]++;
} else {
componentKindCount[kind] = 1;
}
}
const kindIconMap: Record<string, React.ReactNode> = {
Tool: <Wrench className="w-5 h-5" />,
EventListener: <AudioWaveform className="w-5 h-5" />,
Command: <Hash className="w-5 h-5" />,
};
const componentKindList = Object.keys(componentKindCount);
return (
<>
<div>{t('plugins.componentsList')}</div>
{componentKindList.length > 0 && (
<>
{componentKindList.map((kind) => {
return (
<div
key={kind}
className="flex flex-row items-center justify-start gap-[0.4rem]"
>
{kindIconMap[kind]} {componentKindCount[kind]}
</div>
);
})}
</>
)}
{componentKindList.length === 0 && <div>{t('plugins.noComponents')}</div>}
</>
);
}
import PluginComponentList from '@/app/home/plugins/components/plugin-installed/PluginComponentList';
export default function PluginCardComponent({
cardVO,
@@ -180,7 +127,13 @@ export default function PluginCardComponent({
</div>
<div className="w-full flex flex-row items-start justify-start gap-[0.6rem]">
{getComponentList(cardVO.components, t)}
<PluginComponentList
components={cardVO.components}
showComponentName={false}
showTitle={true}
useBadge={false}
t={t}
/>
</div>
</div>

View File

@@ -0,0 +1,208 @@
import { useState, useEffect, useRef } from 'react';
import { ApiRespPluginConfig } from '@/app/infra/entities/api';
import { Plugin } from '@/app/infra/entities/plugin';
import { httpClient } from '@/app/infra/http/HttpClient';
import DynamicFormComponent from '@/app/home/components/dynamic-form/DynamicFormComponent';
import { Button } from '@/components/ui/button';
import { toast } from 'sonner';
import { extractI18nObject } from '@/i18n/I18nProvider';
import { useTranslation } from 'react-i18next';
import PluginComponentList from '@/app/home/plugins/components/plugin-installed/PluginComponentList';
export default function PluginForm({
pluginAuthor,
pluginName,
onFormSubmit,
onFormCancel,
}: {
pluginAuthor: string;
pluginName: string;
onFormSubmit: (timeout?: number) => void;
onFormCancel: () => void;
}) {
const { t } = useTranslation();
const [pluginInfo, setPluginInfo] = useState<Plugin>();
const [pluginConfig, setPluginConfig] = useState<ApiRespPluginConfig>();
const [isSaving, setIsLoading] = useState(false);
const currentFormValues = useRef<object>({});
const uploadedFileKeys = useRef<Set<string>>(new Set());
const initialFileKeys = useRef<Set<string>>(new Set());
useEffect(() => {
// 获取插件信息
httpClient.getPlugin(pluginAuthor, pluginName).then((res) => {
setPluginInfo(res.plugin);
});
// 获取插件配置
httpClient.getPluginConfig(pluginAuthor, pluginName).then((res) => {
setPluginConfig(res);
// 提取初始配置中的所有文件 key
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const extractFileKeys = (obj: any): string[] => {
const keys: string[] = [];
if (obj && typeof obj === 'object') {
if ('file_key' in obj && typeof obj.file_key === 'string') {
keys.push(obj.file_key);
}
for (const value of Object.values(obj)) {
if (Array.isArray(value)) {
value.forEach((item) => keys.push(...extractFileKeys(item)));
} else if (typeof value === 'object' && value !== null) {
keys.push(...extractFileKeys(value));
}
}
}
return keys;
};
const fileKeys = extractFileKeys(res.config);
initialFileKeys.current = new Set(fileKeys);
});
}, [pluginAuthor, pluginName]);
const handleSubmit = async () => {
setIsLoading(true);
const isDebugPlugin = pluginInfo?.debug;
try {
// 保存配置
await httpClient.updatePluginConfig(
pluginAuthor,
pluginName,
currentFormValues.current,
);
// 提取最终保存的配置中的所有文件 key
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const extractFileKeys = (obj: any): string[] => {
const keys: string[] = [];
if (obj && typeof obj === 'object') {
if ('file_key' in obj && typeof obj.file_key === 'string') {
keys.push(obj.file_key);
}
for (const value of Object.values(obj)) {
if (Array.isArray(value)) {
value.forEach((item) => keys.push(...extractFileKeys(item)));
} else if (typeof value === 'object' && value !== null) {
keys.push(...extractFileKeys(value));
}
}
}
return keys;
};
const finalFileKeys = new Set(extractFileKeys(currentFormValues.current));
// 计算需要删除的文件:
// 1. 在编辑期间上传的,但最终未保存的文件
// 2. 初始配置中有的,但最终配置中没有的文件(被删除的文件)
const filesToDelete: string[] = [];
// 上传了但未使用的文件
uploadedFileKeys.current.forEach((key) => {
if (!finalFileKeys.has(key)) {
filesToDelete.push(key);
}
});
// 初始有但最终没有的文件(被删除的)
initialFileKeys.current.forEach((key) => {
if (!finalFileKeys.has(key)) {
filesToDelete.push(key);
}
});
// 删除不需要的文件
const deletePromises = filesToDelete.map((fileKey) =>
httpClient.deletePluginConfigFile(fileKey).catch((err) => {
console.warn(`Failed to delete file ${fileKey}:`, err);
}),
);
await Promise.all(deletePromises);
toast.success(
isDebugPlugin
? t('plugins.saveConfigSuccessDebugPlugin')
: t('plugins.saveConfigSuccessNormal'),
);
onFormSubmit(1000);
} catch (error) {
toast.error(t('plugins.saveConfigError') + (error as Error).message);
} finally {
setIsLoading(false);
}
};
if (!pluginInfo || !pluginConfig) {
return (
<div className="flex items-center justify-center h-full mb-[2rem]">
{t('plugins.loading')}
</div>
);
}
return (
<div>
<div className="space-y-2">
<div className="text-lg font-medium">
{extractI18nObject(pluginInfo.manifest.manifest.metadata.label)}
</div>
<div className="text-sm text-gray-500 pb-2">
{extractI18nObject(
pluginInfo.manifest.manifest.metadata.description ?? {
en_US: '',
zh_Hans: '',
},
)}
</div>
<div className="mb-4 flex flex-row items-center justify-start gap-[0.4rem]">
<PluginComponentList
components={pluginInfo.components}
showComponentName={true}
showTitle={false}
useBadge={true}
t={t}
/>
</div>
{pluginInfo.manifest.manifest.spec.config.length > 0 && (
<DynamicFormComponent
itemConfigList={pluginInfo.manifest.manifest.spec.config}
initialValues={pluginConfig.config as Record<string, object>}
onSubmit={(values) => {
// 只保存表单值的引用,不触发状态更新
currentFormValues.current = values;
}}
onFileUploaded={(fileKey) => {
// 追踪上传的文件
uploadedFileKeys.current.add(fileKey);
}}
/>
)}
{pluginInfo.manifest.manifest.spec.config.length === 0 && (
<div className="text-sm text-gray-500">
{t('plugins.pluginNoConfig')}
</div>
)}
</div>
<div className="sticky bottom-0 left-0 right-0 bg-background border-t p-4 mt-4">
<div className="flex justify-end gap-2">
<Button
type="submit"
onClick={() => handleSubmit()}
disabled={isSaving}
>
{isSaving ? t('plugins.saving') : t('plugins.saveConfig')}
</Button>
<Button type="button" variant="outline" onClick={onFormCancel}>
{t('plugins.cancel')}
</Button>
</div>
</div>
</div>
);
}

View File

@@ -172,7 +172,6 @@ function MarketPageContent({
// 初始加载
useEffect(() => {
fetchPlugins(1, false, true);
// eslint-disable-next-line react-hooks/exhaustive-deps
}, []);
// 搜索功能
@@ -284,7 +283,7 @@ function MarketPageContent({
// };
return (
<div className="container mx-auto px-4 py-6 space-y-6">
<div className="container mx-auto px-3 sm:px-4 py-4 sm:py-6 space-y-4 sm:space-y-6">
{/* 搜索框 */}
<div className="flex items-center justify-center">
<div className="relative w-full max-w-2xl">
@@ -302,19 +301,19 @@ function MarketPageContent({
handleSearch(searchQuery);
}
}}
className="pl-10 pr-4"
className="pl-10 pr-4 text-sm sm:text-base"
/>
</div>
</div>
{/* 排序下拉框 */}
<div className="flex items-center justify-center">
<div className="w-full max-w-2xl flex items-center gap-3">
<span className="text-sm text-muted-foreground whitespace-nowrap">
<div className="w-full max-w-2xl flex items-center gap-2 sm:gap-3">
<span className="text-xs sm:text-sm text-muted-foreground whitespace-nowrap">
{t('market.sortBy')}:
</span>
<Select value={sortOption} onValueChange={handleSortChange}>
<SelectTrigger className="w-48">
<SelectTrigger className="w-40 sm:w-48 text-xs sm:text-sm">
<SelectValue />
</SelectTrigger>
<SelectContent>
@@ -330,7 +329,7 @@ function MarketPageContent({
{/* 搜索结果统计 */}
{total > 0 && (
<div className="text-center text-muted-foreground">
<div className="text-center text-muted-foreground text-sm">
{searchQuery
? t('market.searchResults', { count: total })
: t('market.totalPlugins', { count: total })}

View File

@@ -228,6 +228,30 @@ export default function PluginDetailDialog({
{...props}
/>
),
h3: ({ ...props }) => (
<h3
className="text-xl font-semibold mb-2 mt-4 dark:text-gray-400"
{...props}
/>
),
h4: ({ ...props }) => (
<h4
className="text-lg font-semibold mb-2 mt-4 dark:text-gray-400"
{...props}
/>
),
h5: ({ ...props }) => (
<h5
className="text-base font-semibold mb-2 mt-4 dark:text-gray-400"
{...props}
/>
),
h6: ({ ...props }) => (
<h6
className="text-sm font-semibold mb-2 mt-4 dark:text-gray-400"
{...props}
/>
),
p: ({ ...props }) => (
<p className="leading-relaxed dark:text-gray-400" {...props} />
),
@@ -274,6 +298,57 @@ export default function PluginDetailDialog({
{...props}
/>
),
// 图片组件 - 转换本地路径为API路径
img: ({ src, alt, ...props }) => {
// 处理图片路径
let imageSrc = src || '';
// 确保 src 是字符串类型
if (typeof imageSrc !== 'string') {
return (
<img
src={src}
alt={alt || ''}
className="max-w-full h-auto rounded-lg my-4"
{...props}
/>
);
}
// 如果是相对路径转换为API路径
if (
imageSrc &&
!imageSrc.startsWith('http://') &&
!imageSrc.startsWith('https://') &&
!imageSrc.startsWith('data:')
) {
// 移除开头的 ./ 或 / (支持多个前缀)
imageSrc = imageSrc.replace(/^(\.\/|\/)+/, '');
// 如果路径以 assets/ 开头,直接使用
// 否则假设它在 assets/ 目录下
if (!imageSrc.startsWith('assets/')) {
imageSrc = `assets/${imageSrc}`;
}
// 移除 assets/ 前缀以构建API URL
const assetPath = imageSrc.replace(/^assets\//, '');
imageSrc = getCloudServiceClientSync().getPluginAssetURL(
author!,
pluginName!,
assetPath,
);
}
return (
<img
src={imageSrc}
alt={alt || ''}
className="max-w-lg h-auto my-4"
{...props}
/>
);
},
}}
>
{readme}

View File

@@ -15,35 +15,37 @@ export default function PluginMarketCardComponent({
return (
<div
className="w-[100%] h-[9rem] bg-white rounded-[10px] shadow-[0px_0px_4px_0_rgba(0,0,0,0.2)] p-[1rem] cursor-pointer hover:shadow-[0px_2px_8px_0_rgba(0,0,0,0.15)] transition-shadow duration-200 dark:bg-[#1f1f22]"
className="w-[100%] h-auto min-h-[8rem] sm:h-[9rem] bg-white rounded-[10px] shadow-[0px_0px_4px_0_rgba(0,0,0,0.2)] p-3 sm:p-[1rem] cursor-pointer hover:shadow-[0px_2px_8px_0_rgba(0,0,0,0.15)] transition-shadow duration-200 dark:bg-[#1f1f22]"
onClick={handleCardClick}
>
<div className="w-full h-full flex flex-col justify-between">
<div className="w-full h-full flex flex-col justify-between gap-2">
{/* 上部分:插件信息 */}
<div className="flex flex-row items-start justify-start gap-[1.2rem]">
<img src={cardVO.iconURL} alt="plugin icon" className="w-16 h-16" />
<div className="flex flex-row items-start justify-start gap-2 sm:gap-[1.2rem] min-h-0">
<img
src={cardVO.iconURL}
alt="plugin icon"
className="w-12 h-12 sm:w-16 sm:h-16 flex-shrink-0"
/>
<div className="flex-1 flex flex-col items-start justify-start gap-[0.6rem]">
<div className="flex flex-col items-start justify-start">
<div className="text-[0.7rem] text-[#666] dark:text-[#999]">
<div className="flex-1 flex flex-col items-start justify-start gap-[0.4rem] sm:gap-[0.6rem] min-w-0 overflow-hidden">
<div className="flex flex-col items-start justify-start w-full min-w-0">
<div className="text-[0.65rem] sm:text-[0.7rem] text-[#666] dark:text-[#999] truncate w-full">
{cardVO.pluginId}
</div>
<div className="flex flex-row items-center justify-start gap-[0.4rem]">
<div className="text-[1.2rem] text-black dark:text-[#f0f0f0]">
{cardVO.label}
</div>
<div className="text-base sm:text-[1.2rem] text-black dark:text-[#f0f0f0] truncate w-full">
{cardVO.label}
</div>
</div>
<div className="text-[0.8rem] text-[#666] dark:text-[#999] line-clamp-2">
<div className="text-[0.7rem] sm:text-[0.8rem] text-[#666] dark:text-[#999] line-clamp-2 overflow-hidden">
{cardVO.description}
</div>
</div>
<div className="flex h-full flex-row items-start justify-center gap-[0.4rem]">
<div className="flex flex-row items-start justify-center gap-[0.4rem] flex-shrink-0">
{cardVO.githubURL && (
<svg
className="w-[1.4rem] h-[1.4rem] text-black cursor-pointer hover:text-gray-600 dark:text-[#f0f0f0]"
className="w-5 h-5 sm:w-[1.4rem] sm:h-[1.4rem] text-black cursor-pointer hover:text-gray-600 dark:text-[#f0f0f0] flex-shrink-0"
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 24 24"
fill="currentColor"
@@ -59,9 +61,9 @@ export default function PluginMarketCardComponent({
</div>
{/* 下部分:下载量 */}
<div className="w-full flex flex-row items-center justify-start gap-[0.4rem] px-[0.4rem]">
<div className="w-full flex flex-row items-center justify-start gap-[0.3rem] sm:gap-[0.4rem] px-0 sm:px-[0.4rem] flex-shrink-0">
<svg
className="w-[1.2rem] h-[1.2rem] text-[#2563eb]"
className="w-4 h-4 sm:w-[1.2rem] sm:h-[1.2rem] text-[#2563eb] flex-shrink-0"
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 24 24"
fill="none"
@@ -72,7 +74,7 @@ export default function PluginMarketCardComponent({
<polyline points="7,10 12,15 17,10" />
<line x1="12" y1="15" x2="12" y2="3" />
</svg>
<div className="text-sm text-[#2563eb] font-medium">
<div className="text-xs sm:text-sm text-[#2563eb] font-medium whitespace-nowrap">
{cardVO.installCount.toLocaleString()}
</div>
</div>

View File

@@ -0,0 +1,29 @@
import { MCPServer, MCPSessionStatus } from '@/app/infra/entities/api';
export class MCPCardVO {
name: string;
mode: 'stdio' | 'sse';
enable: boolean;
status: MCPSessionStatus;
tools: number;
error?: string;
constructor(data: MCPServer) {
this.name = data.name;
this.mode = data.mode;
this.enable = data.enable;
// Determine status from runtime_info
if (!data.runtime_info) {
this.status = MCPSessionStatus.ERROR;
this.tools = 0;
} else if (data.runtime_info.status === MCPSessionStatus.CONNECTED) {
this.status = data.runtime_info.status;
this.tools = data.runtime_info.tool_count || 0;
} else {
this.status = data.runtime_info.status;
this.tools = 0;
this.error = data.runtime_info.error_message;
}
}
}

View File

@@ -0,0 +1,114 @@
'use client';
import { useEffect, useState, useRef } from 'react';
import MCPCardComponent from '@/app/home/plugins/mcp-server/mcp-card/MCPCardComponent';
import { MCPCardVO } from '@/app/home/plugins/mcp-server/MCPCardVO';
import { useTranslation } from 'react-i18next';
import { MCPSessionStatus } from '@/app/infra/entities/api';
import { httpClient } from '@/app/infra/http/HttpClient';
export default function MCPComponent({
onEditServer,
}: {
askInstallServer?: (githubURL: string) => void;
onEditServer?: (serverName: string) => void;
}) {
const { t } = useTranslation();
const [installedServers, setInstalledServers] = useState<MCPCardVO[]>([]);
const [loading, setLoading] = useState(false);
const pollingIntervalRef = useRef<NodeJS.Timeout | null>(null);
useEffect(() => {
fetchInstalledServers();
return () => {
// Cleanup: clear polling interval when component unmounts
if (pollingIntervalRef.current) {
clearInterval(pollingIntervalRef.current);
}
};
}, []);
// Check if any enabled server is connecting and start/stop polling accordingly
useEffect(() => {
const hasConnecting = installedServers.some(
(server) =>
server.enable && server.status === MCPSessionStatus.CONNECTING,
);
if (hasConnecting && !pollingIntervalRef.current) {
// Start polling every 3 seconds
pollingIntervalRef.current = setInterval(() => {
fetchInstalledServers();
}, 3000);
} else if (!hasConnecting && pollingIntervalRef.current) {
// Stop polling when no enabled server is connecting
clearInterval(pollingIntervalRef.current);
pollingIntervalRef.current = null;
}
return () => {
if (pollingIntervalRef.current) {
clearInterval(pollingIntervalRef.current);
pollingIntervalRef.current = null;
}
};
}, [installedServers]);
function fetchInstalledServers() {
setLoading(true);
httpClient
.getMCPServers()
.then((resp) => {
const servers = resp.servers.map((server) => new MCPCardVO(server));
setInstalledServers(servers);
setLoading(false);
})
.catch((error) => {
console.error('Failed to fetch MCP servers:', error);
setLoading(false);
});
}
return (
<div className="w-full h-full">
{/* 已安装的服务器列表 */}
<div className="w-full px-[0.8rem] pt-[0rem] gap-4">
{loading ? (
<div className="flex flex-col items-center justify-center text-gray-500 h-[calc(100vh-16rem)] w-full gap-2">
{t('mcp.loading')}
</div>
) : installedServers.length === 0 ? (
<div className="flex flex-col items-center justify-center text-gray-500 h-[calc(100vh-16rem)] w-full gap-2">
<svg
className="h-[3rem] w-[3rem]"
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 24 24"
fill="currentColor"
>
<path d="M4.5 7.65311V16.3469L12 20.689L19.5 16.3469V7.65311L12 3.311L4.5 7.65311ZM12 1L21.5 6.5V17.5L12 23L2.5 17.5V6.5L12 1ZM6.49896 9.97065L11 12.5765V17.625H13V12.5765L17.501 9.97066L16.499 8.2398L12 10.8445L7.50104 8.2398L6.49896 9.97065Z"></path>
</svg>
<div className="text-lg mb-2">{t('mcp.noServerInstalled')}</div>
</div>
) : (
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-6 pt-[2rem]">
{installedServers.map((server, index) => (
<div key={`${server.name}-${index}`}>
<MCPCardComponent
cardVO={server}
onCardClick={() => {
if (onEditServer) {
onEditServer(server.name);
}
}}
onRefresh={fetchInstalledServers}
/>
</div>
))}
</div>
)}
</div>
</div>
);
}

View File

@@ -0,0 +1,172 @@
import { MCPCardVO } from '@/app/home/plugins/mcp-server/MCPCardVO';
import { useState, useEffect } from 'react';
import { httpClient } from '@/app/infra/http/HttpClient';
import { Switch } from '@/components/ui/switch';
import { Button } from '@/components/ui/button';
import { toast } from 'sonner';
import { useTranslation } from 'react-i18next';
import { RefreshCcw, Wrench, Ban, AlertCircle, Loader2 } from 'lucide-react';
import { MCPSessionStatus } from '@/app/infra/entities/api';
export default function MCPCardComponent({
cardVO,
onCardClick,
onRefresh,
}: {
cardVO: MCPCardVO;
onCardClick: () => void;
onRefresh: () => void;
}) {
const { t } = useTranslation();
const [enabled, setEnabled] = useState(cardVO.enable);
const [switchEnable, setSwitchEnable] = useState(true);
const [testing, setTesting] = useState(false);
const [toolsCount, setToolsCount] = useState(cardVO.tools);
const [status, setStatus] = useState(cardVO.status);
useEffect(() => {
setStatus(cardVO.status);
setToolsCount(cardVO.tools);
setEnabled(cardVO.enable);
}, [cardVO.status, cardVO.tools, cardVO.enable]);
function handleEnable(checked: boolean) {
setSwitchEnable(false);
httpClient
.toggleMCPServer(cardVO.name, checked)
.then(() => {
setEnabled(checked);
toast.success(t('mcp.saveSuccess'));
onRefresh();
setSwitchEnable(true);
})
.catch((err) => {
toast.error(t('mcp.modifyFailed') + err.message);
setSwitchEnable(true);
});
}
function handleTest(e: React.MouseEvent) {
e.stopPropagation();
setTesting(true);
httpClient
.testMCPServer(cardVO.name, {})
.then((resp) => {
const taskId = resp.task_id;
const interval = setInterval(() => {
httpClient.getAsyncTask(taskId).then((taskResp) => {
if (taskResp.runtime.done) {
clearInterval(interval);
setTesting(false);
if (taskResp.runtime.exception) {
toast.error(
t('mcp.refreshFailed') + taskResp.runtime.exception,
);
} else {
toast.success(t('mcp.refreshSuccess'));
}
// Refresh to get updated runtime_info
onRefresh();
}
});
}, 1000);
})
.catch((err) => {
toast.error(t('mcp.refreshFailed') + err.message);
setTesting(false);
});
}
return (
<div
className="w-[100%] h-[10rem] bg-white dark:bg-[#1f1f22] rounded-[10px] shadow-[0px_2px_2px_0_rgba(0,0,0,0.2)] dark:shadow-none p-[1.2rem] cursor-pointer transition-all duration-200 hover:shadow-[0px_2px_8px_0_rgba(0,0,0,0.1)] dark:hover:shadow-none"
onClick={onCardClick}
>
<div className="w-full h-full flex flex-row items-start justify-start gap-[1.2rem]">
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 24 24"
width="64"
height="64"
fill="rgba(70,146,221,1)"
>
<path d="M17.6567 14.8284L16.2425 13.4142L17.6567 12C19.2188 10.4379 19.2188 7.90524 17.6567 6.34314C16.0946 4.78105 13.5619 4.78105 11.9998 6.34314L10.5856 7.75736L9.17139 6.34314L10.5856 4.92893C12.9287 2.58578 16.7277 2.58578 19.0709 4.92893C21.414 7.27208 21.414 11.0711 19.0709 13.4142L17.6567 14.8284ZM14.8282 17.6569L13.414 19.0711C11.0709 21.4142 7.27189 21.4142 4.92875 19.0711C2.5856 16.7279 2.5856 12.9289 4.92875 10.5858L6.34296 9.17157L7.75717 10.5858L6.34296 12C4.78086 13.5621 4.78086 16.0948 6.34296 17.6569C7.90506 19.2189 10.4377 19.2189 11.9998 17.6569L13.414 16.2426L14.8282 17.6569ZM14.8282 7.75736L16.2425 9.17157L9.17139 16.2426L7.75717 14.8284L14.8282 7.75736Z"></path>
</svg>
<div className="w-full h-full flex flex-col items-start justify-between gap-[0.6rem]">
<div className="flex flex-col items-start justify-start">
<div className="text-[1.2rem] text-black dark:text-[#f0f0f0] font-medium">
{cardVO.name}
</div>
</div>
<div className="w-full flex flex-row items-start justify-start gap-[0.6rem]">
{!enabled ? (
// 未启用 - 橙色
<div className="flex flex-row items-center gap-[0.4rem]">
<Ban className="w-4 h-4 text-orange-500 dark:text-orange-400" />
<div className="text-sm text-orange-500 dark:text-orange-400 font-medium">
{t('mcp.statusDisabled')}
</div>
</div>
) : status === MCPSessionStatus.CONNECTED ? (
// 连接成功 - 显示工具数量
<div className="flex h-full flex-row items-center justify-center gap-[0.4rem]">
<Wrench className="w-5 h-5" />
<div className="text-base text-black dark:text-[#f0f0f0] font-medium">
{t('mcp.toolCount', { count: toolsCount })}
</div>
</div>
) : status === MCPSessionStatus.CONNECTING ? (
// 连接中 - 蓝色加载
<div className="flex flex-row items-center gap-[0.4rem]">
<Loader2 className="w-4 h-4 text-blue-500 dark:text-blue-400 animate-spin" />
<div className="text-sm text-blue-500 dark:text-blue-400 font-medium">
{t('mcp.connecting')}
</div>
</div>
) : (
// 连接失败 - 红色
<div className="flex flex-row items-center gap-[0.4rem]">
<AlertCircle className="w-4 h-4 text-red-500 dark:text-red-400" />
<div className="text-sm text-red-500 dark:text-red-400 font-medium">
{t('mcp.connectionFailedStatus')}
</div>
</div>
)}
</div>
</div>
<div className="flex flex-col items-center justify-between h-full">
<div
className="flex items-center justify-center"
onClick={(e) => e.stopPropagation()}
>
<Switch
className="cursor-pointer"
checked={enabled}
onCheckedChange={handleEnable}
disabled={!switchEnable}
/>
</div>
<div className="flex items-center justify-center gap-[0.4rem]">
<Button
variant="ghost"
size="sm"
className="p-1 h-8 w-8"
onClick={(e) => handleTest(e)}
disabled={testing}
>
<RefreshCcw className="w-4 h-4" />
</Button>
</div>
</div>
</div>
</div>
);
}

View File

@@ -0,0 +1,68 @@
'use client';
import React from 'react';
import { useTranslation } from 'react-i18next';
import { toast } from 'sonner';
import {
Dialog,
DialogContent,
DialogHeader,
DialogTitle,
DialogFooter,
DialogDescription,
} from '@/components/ui/dialog';
import { Button } from '@/components/ui/button';
import { httpClient } from '@/app/infra/http/HttpClient';
interface MCPDeleteConfirmDialogProps {
open: boolean;
onOpenChange: (open: boolean) => void;
serverName: string | null;
onSuccess?: () => void;
}
export default function MCPDeleteConfirmDialog({
open,
onOpenChange,
serverName,
onSuccess,
}: MCPDeleteConfirmDialogProps) {
const { t } = useTranslation();
async function handleDelete() {
if (!serverName) return;
try {
await httpClient.deleteMCPServer(serverName);
toast.success(t('mcp.deleteSuccess'));
onOpenChange(false);
if (onSuccess) {
onSuccess();
}
} catch (error) {
console.error('Failed to delete server:', error);
toast.error(t('mcp.deleteFailed'));
}
}
return (
<Dialog open={open} onOpenChange={onOpenChange}>
<DialogContent>
<DialogHeader>
<DialogTitle>{t('mcp.confirmDeleteTitle')}</DialogTitle>
</DialogHeader>
<DialogDescription>{t('mcp.confirmDeleteServer')}</DialogDescription>
<DialogFooter>
<Button variant="outline" onClick={() => onOpenChange(false)}>
{t('common.cancel')}
</Button>
<Button variant="destructive" onClick={handleDelete}>
{t('common.confirm')}
</Button>
</DialogFooter>
</DialogContent>
</Dialog>
);
}

View File

@@ -0,0 +1,673 @@
'use client';
import React, { useState, useEffect, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import { Resolver, useForm } from 'react-hook-form';
import { zodResolver } from '@hookform/resolvers/zod';
import { z } from 'zod';
import { toast } from 'sonner';
import {
Dialog,
DialogContent,
DialogHeader,
DialogTitle,
DialogFooter,
} from '@/components/ui/dialog';
import {
Card,
CardHeader,
CardTitle,
CardDescription,
} from '@/components/ui/card';
import {
Form,
FormControl,
FormDescription,
FormField,
FormItem,
FormLabel,
FormMessage,
} from '@/components/ui/form';
import {
Select,
SelectTrigger,
SelectValue,
SelectContent,
SelectItem,
} from '@/components/ui/select';
import { Input } from '@/components/ui/input';
import { Button } from '@/components/ui/button';
import { httpClient } from '@/app/infra/http/HttpClient';
import {
MCPServerRuntimeInfo,
MCPTool,
MCPServer,
MCPSessionStatus,
} from '@/app/infra/entities/api';
// Status Display Component - 在测试中、连接中或连接失败时使用
function StatusDisplay({
testing,
runtimeInfo,
t,
}: {
testing: boolean;
runtimeInfo: MCPServerRuntimeInfo;
t: (key: string) => string;
}) {
if (testing) {
return (
<div className="flex items-center gap-2 text-blue-600">
<svg
className="w-5 h-5 animate-spin"
xmlns="http://www.w3.org/2000/svg"
fill="none"
viewBox="0 0 24 24"
>
<circle
className="opacity-25"
cx="12"
cy="12"
r="10"
stroke="currentColor"
strokeWidth="4"
/>
<path
className="opacity-75"
fill="currentColor"
d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"
/>
</svg>
<span className="font-medium">{t('mcp.testing')}</span>
</div>
);
}
// 连接中
if (runtimeInfo.status === MCPSessionStatus.CONNECTING) {
return (
<div className="flex items-center gap-2 text-blue-600">
<svg
className="w-5 h-5 animate-spin"
xmlns="http://www.w3.org/2000/svg"
fill="none"
viewBox="0 0 24 24"
>
<circle
className="opacity-25"
cx="12"
cy="12"
r="10"
stroke="currentColor"
strokeWidth="4"
/>
<path
className="opacity-75"
fill="currentColor"
d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"
/>
</svg>
<span className="font-medium">{t('mcp.connecting')}</span>
</div>
);
}
// 连接失败
return (
<div className="space-y-1">
<div className="flex items-center gap-2 text-red-600">
<svg
className="w-5 h-5"
xmlns="http://www.w3.org/2000/svg"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
strokeLinecap="round"
strokeLinejoin="round"
strokeWidth={2}
d="M10 14l2-2m0 0l2-2m-2 2l-2-2m2 2l2 2m7-2a9 9 0 11-18 0 9 9 0 0118 0z"
/>
</svg>
<span className="font-medium">{t('mcp.connectionFailed')}</span>
</div>
{/* {runtimeInfo.error_message && (
<div className="text-sm text-red-500 pl-7">
{runtimeInfo.error_message}
</div>
)} */}
</div>
);
}
// Tools List Component
function ToolsList({ tools }: { tools: MCPTool[] }) {
return (
<div className="space-y-2 max-h-[300px] overflow-y-auto">
{tools.map((tool, index) => (
<Card key={index} className="py-3 shadow-none">
<CardHeader>
<CardTitle className="text-sm">{tool.name}</CardTitle>
{tool.description && (
<CardDescription className="text-xs">
{tool.description}
</CardDescription>
)}
</CardHeader>
</Card>
))}
</div>
);
}
const getFormSchema = (t: (key: string) => string) =>
z.object({
name: z
.string({ required_error: t('mcp.nameRequired') })
.min(1, { message: t('mcp.nameRequired') }),
timeout: z
.number({ invalid_type_error: t('mcp.timeoutMustBeNumber') })
.positive({ message: t('mcp.timeoutMustBePositive') })
.default(30),
ssereadtimeout: z
.number({ invalid_type_error: t('mcp.sseTimeoutMustBeNumber') })
.positive({ message: t('mcp.timeoutMustBePositive') })
.default(300),
url: z
.string({ required_error: t('mcp.urlRequired') })
.min(1, { message: t('mcp.urlRequired') }),
extra_args: z
.array(
z.object({
key: z.string(),
type: z.enum(['string', 'number', 'boolean']),
value: z.string(),
}),
)
.optional(),
});
type FormValues = z.infer<ReturnType<typeof getFormSchema>> & {
timeout: number;
ssereadtimeout: number;
};
interface MCPFormDialogProps {
open: boolean;
onOpenChange: (open: boolean) => void;
serverName?: string | null;
isEditMode?: boolean;
onSuccess?: () => void;
onDelete?: () => void;
}
export default function MCPFormDialog({
open,
onOpenChange,
serverName,
isEditMode = false,
onSuccess,
onDelete,
}: MCPFormDialogProps) {
const { t } = useTranslation();
const formSchema = getFormSchema(t);
const form = useForm<FormValues>({
resolver: zodResolver(formSchema) as unknown as Resolver<FormValues>,
defaultValues: {
name: '',
url: '',
timeout: 30,
ssereadtimeout: 300,
extra_args: [],
},
});
const [extraArgs, setExtraArgs] = useState<
{ key: string; type: 'string' | 'number' | 'boolean'; value: string }[]
>([]);
const [mcpTesting, setMcpTesting] = useState(false);
const [runtimeInfo, setRuntimeInfo] = useState<MCPServerRuntimeInfo | null>(
null,
);
const pollingIntervalRef = useRef<NodeJS.Timeout | null>(null);
// Load server data when editing
useEffect(() => {
if (open && isEditMode && serverName) {
loadServerForEdit(serverName);
} else if (open && !isEditMode) {
// Reset form when creating new server
form.reset();
setExtraArgs([]);
setRuntimeInfo(null);
}
// Cleanup polling interval when dialog closes
return () => {
if (pollingIntervalRef.current) {
clearInterval(pollingIntervalRef.current);
pollingIntervalRef.current = null;
}
};
}, [open, isEditMode, serverName]);
// Poll for updates when runtime_info status is CONNECTING
useEffect(() => {
if (
!open ||
!isEditMode ||
!serverName ||
!runtimeInfo ||
runtimeInfo.status !== MCPSessionStatus.CONNECTING
) {
// Stop polling if conditions are not met
if (pollingIntervalRef.current) {
clearInterval(pollingIntervalRef.current);
pollingIntervalRef.current = null;
}
return;
}
// Start polling if not already running
if (!pollingIntervalRef.current) {
pollingIntervalRef.current = setInterval(() => {
loadServerForEdit(serverName);
}, 3000);
}
return () => {
if (pollingIntervalRef.current) {
clearInterval(pollingIntervalRef.current);
pollingIntervalRef.current = null;
}
};
}, [open, isEditMode, serverName, runtimeInfo?.status]);
async function loadServerForEdit(serverName: string) {
try {
const resp = await httpClient.getMCPServer(serverName);
const server = resp.server ?? resp;
const extraArgs = server.extra_args;
form.setValue('name', server.name);
form.setValue('url', extraArgs.url);
form.setValue('timeout', extraArgs.timeout);
form.setValue('ssereadtimeout', extraArgs.ssereadtimeout);
if (extraArgs.headers) {
const headers = Object.entries(extraArgs.headers).map(
([key, value]) => ({
key,
type: 'string' as const,
value: String(value),
}),
);
setExtraArgs(headers);
form.setValue('extra_args', headers);
}
// Set runtime_info from server data
if (server.runtime_info) {
setRuntimeInfo(server.runtime_info);
} else {
setRuntimeInfo(null);
}
} catch (error) {
console.error('Failed to load server:', error);
toast.error(t('mcp.loadFailed'));
}
}
async function handleFormSubmit(value: z.infer<typeof formSchema>) {
// Convert extra_args to headers - all values must be strings according to MCPServerExtraArgsSSE
const headers: Record<string, string> = {};
value.extra_args?.forEach((arg) => {
// Convert all values to strings to match MCPServerExtraArgsSSE.headers type
headers[arg.key] = String(arg.value);
});
try {
const serverConfig: Omit<
MCPServer,
'uuid' | 'created_at' | 'updated_at' | 'runtime_info'
> = {
name: value.name,
mode: 'sse' as const,
enable: true,
extra_args: {
url: value.url,
headers: headers,
timeout: value.timeout,
ssereadtimeout: value.ssereadtimeout,
},
};
if (isEditMode && serverName) {
await httpClient.updateMCPServer(serverName, serverConfig);
toast.success(t('mcp.updateSuccess'));
} else {
await httpClient.createMCPServer(serverConfig);
toast.success(t('mcp.createSuccess'));
}
handleDialogClose(false);
onSuccess?.();
} catch (error) {
console.error('Failed to save MCP server:', error);
toast.error(isEditMode ? t('mcp.updateFailed') : t('mcp.createFailed'));
}
}
async function testMcp() {
setMcpTesting(true);
try {
const { task_id } = await httpClient.testMCPServer('_', {
name: form.getValues('name'),
mode: 'sse',
enable: true,
extra_args: {
url: form.getValues('url'),
timeout: form.getValues('timeout'),
ssereadtimeout: form.getValues('ssereadtimeout'),
headers: Object.fromEntries(
extraArgs.map((arg) => [arg.key, arg.value]),
),
},
});
if (!task_id) {
throw new Error(t('mcp.noTaskId'));
}
const interval = setInterval(async () => {
try {
const taskResp = await httpClient.getAsyncTask(task_id);
if (taskResp.runtime?.done) {
clearInterval(interval);
setMcpTesting(false);
if (taskResp.runtime.exception) {
const errorMsg =
taskResp.runtime.exception || t('mcp.unknownError');
toast.error(`${t('mcp.testError')}: ${errorMsg}`);
setRuntimeInfo({
status: MCPSessionStatus.ERROR,
error_message: errorMsg,
tool_count: 0,
tools: [],
});
} else {
if (isEditMode) {
await loadServerForEdit(form.getValues('name'));
}
toast.success(t('mcp.testSuccess'));
}
}
} catch (err) {
clearInterval(interval);
setMcpTesting(false);
const errorMsg = (err as Error).message || t('mcp.getTaskFailed');
toast.error(`${t('mcp.testError')}: ${errorMsg}`);
}
}, 1000);
} catch (err) {
setMcpTesting(false);
const errorMsg = (err as Error).message || t('mcp.unknownError');
toast.error(`${t('mcp.testError')}: ${errorMsg}`);
}
}
const addExtraArg = () => {
const newArgs = [
...extraArgs,
{ key: '', type: 'string' as const, value: '' },
];
setExtraArgs(newArgs);
form.setValue('extra_args', newArgs);
};
const removeExtraArg = (index: number) => {
const newArgs = extraArgs.filter((_, i) => i !== index);
setExtraArgs(newArgs);
form.setValue('extra_args', newArgs);
};
const updateExtraArg = (
index: number,
field: 'key' | 'type' | 'value',
value: string,
) => {
const newArgs = [...extraArgs];
newArgs[index] = { ...newArgs[index], [field]: value };
setExtraArgs(newArgs);
form.setValue('extra_args', newArgs);
};
const handleDialogClose = (open: boolean) => {
onOpenChange(open);
if (!open) {
form.reset();
setExtraArgs([]);
setRuntimeInfo(null);
}
};
return (
<Dialog open={open} onOpenChange={handleDialogClose}>
<DialogContent>
<DialogHeader>
<DialogTitle>
{isEditMode ? t('mcp.editServer') : t('mcp.createServer')}
</DialogTitle>
</DialogHeader>
{isEditMode && runtimeInfo && (
<div className="mb-0 space-y-3">
{/* 测试中或连接失败时显示状态 */}
{(mcpTesting ||
runtimeInfo.status !== MCPSessionStatus.CONNECTED) && (
<div className="p-3 rounded-lg border">
<StatusDisplay
testing={mcpTesting}
runtimeInfo={runtimeInfo}
t={t}
/>
</div>
)}
{/* 连接成功时只显示工具列表 */}
{!mcpTesting &&
runtimeInfo.status === MCPSessionStatus.CONNECTED &&
runtimeInfo.tools?.length > 0 && (
<>
<div className="text-sm font-medium">
{t('mcp.toolCount', {
count: runtimeInfo.tools?.length || 0,
})}
</div>
<ToolsList tools={runtimeInfo.tools} />
</>
)}
</div>
)}
<Form {...form}>
<form
onSubmit={form.handleSubmit(handleFormSubmit)}
className="space-y-4"
>
<div className="space-y-4">
<FormField
control={form.control}
name="name"
render={({ field }) => (
<FormItem>
<FormLabel>{t('mcp.name')}</FormLabel>
<FormControl>
<Input {...field} />
</FormControl>
<FormMessage />
</FormItem>
)}
/>
<FormField
control={form.control}
name="url"
render={({ field }) => (
<FormItem>
<FormLabel>{t('mcp.url')}</FormLabel>
<FormControl>
<Input {...field} />
</FormControl>
<FormMessage />
</FormItem>
)}
/>
<FormField
control={form.control}
name="timeout"
render={({ field }) => (
<FormItem>
<FormLabel>{t('mcp.timeout')}</FormLabel>
<FormControl>
<Input
type="number"
placeholder={t('mcp.timeout')}
{...field}
onChange={(e) => field.onChange(Number(e.target.value))}
/>
</FormControl>
<FormMessage />
</FormItem>
)}
/>
<FormField
control={form.control}
name="ssereadtimeout"
render={({ field }) => (
<FormItem>
<FormLabel>{t('mcp.sseTimeout')}</FormLabel>
<FormControl>
<Input
type="number"
placeholder={t('mcp.sseTimeoutDescription')}
{...field}
onChange={(e) => field.onChange(Number(e.target.value))}
/>
</FormControl>
<FormMessage />
</FormItem>
)}
/>
<FormItem>
<FormLabel>{t('models.extraParameters')}</FormLabel>
<div className="space-y-2">
{extraArgs.map((arg, index) => (
<div key={index} className="flex gap-2">
<Input
placeholder={t('models.keyName')}
value={arg.key}
onChange={(e) =>
updateExtraArg(index, 'key', e.target.value)
}
/>
<Select
value={arg.type}
onValueChange={(value) =>
updateExtraArg(index, 'type', value)
}
>
<SelectTrigger className="w-[120px] bg-[#ffffff] dark:bg-[#2a2a2e]">
<SelectValue placeholder={t('models.type')} />
</SelectTrigger>
<SelectContent className="bg-[#ffffff] dark:bg-[#2a2a2e]">
<SelectItem value="string">
{t('models.string')}
</SelectItem>
<SelectItem value="number">
{t('models.number')}
</SelectItem>
<SelectItem value="boolean">
{t('models.boolean')}
</SelectItem>
</SelectContent>
</Select>
<Input
placeholder={t('models.value')}
value={arg.value}
onChange={(e) =>
updateExtraArg(index, 'value', e.target.value)
}
/>
<button
type="button"
className="p-2 hover:bg-gray-100 rounded"
onClick={() => removeExtraArg(index)}
>
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 24 24"
fill="currentColor"
className="w-5 h-5 text-red-500"
>
<path d="M7 4V2H17V4H22V6H20V21C20 21.5523 19.5523 22 19 22H5C4.44772 22 4 21.5523 4 21V6H2V4H7ZM6 6V20H18V6H6ZM9 9H11V17H9V9ZM13 9H15V17H13V9Z"></path>
</svg>
</button>
</div>
))}
<Button type="button" variant="outline" onClick={addExtraArg}>
{t('models.addParameter')}
</Button>
</div>
<FormDescription>
{t('mcp.extraParametersDescription')}
</FormDescription>
<FormMessage />
</FormItem>
<DialogFooter>
{isEditMode && onDelete && (
<Button
type="button"
variant="destructive"
onClick={onDelete}
>
{t('common.delete')}
</Button>
)}
<Button type="submit">
{isEditMode ? t('common.save') : t('common.submit')}
</Button>
<Button
type="button"
variant="outline"
onClick={() => testMcp()}
disabled={mcpTesting}
>
{t('common.test')}
</Button>
<Button
type="button"
variant="outline"
onClick={() => handleDialogClose(false)}
>
{t('common.cancel')}
</Button>
</DialogFooter>
</div>
</form>
</Form>
</DialogContent>
</Dialog>
);
}

View File

@@ -1,18 +1,29 @@
'use client';
import PluginInstalledComponent, {
PluginInstalledComponentRef,
} from '@/app/home/plugins/plugin-installed/PluginInstalledComponent';
import MarketPage from '@/app/home/plugins/plugin-market/PluginMarketComponent';
// import PluginSortDialog from '@/app/home/plugins/plugin-sort/PluginSortDialog';
} from '@/app/home/plugins/components/plugin-installed/PluginInstalledComponent';
import MarketPage from '@/app/home/plugins/components/plugin-market/PluginMarketComponent';
import MCPServerComponent from '@/app/home/plugins/mcp-server/MCPServerComponent';
import MCPFormDialog from '@/app/home/plugins/mcp-server/mcp-form/MCPFormDialog';
import MCPDeleteConfirmDialog from '@/app/home/plugins/mcp-server/mcp-form/MCPDeleteConfirmDialog';
import styles from './plugins.module.css';
import { Tabs, TabsContent, TabsList, TabsTrigger } from '@/components/ui/tabs';
import { Button } from '@/components/ui/button';
import {
Card,
CardHeader,
CardTitle,
CardDescription,
} from '@/components/ui/card';
import {
PlusIcon,
ChevronDownIcon,
UploadIcon,
StoreIcon,
Download,
Power,
Github,
ChevronLeft,
} from 'lucide-react';
import {
DropdownMenu,
@@ -28,55 +39,115 @@ import {
DialogFooter,
} from '@/components/ui/dialog';
import { Input } from '@/components/ui/input';
import { useState, useRef, useCallback } from 'react';
import React, { useState, useRef, useCallback, useEffect } from 'react';
import { httpClient } from '@/app/infra/http/HttpClient';
import { toast } from 'sonner';
import { useTranslation } from 'react-i18next';
import { PluginV4 } from '@/app/infra/entities/plugin';
import { systemInfo } from '@/app/infra/http/HttpClient';
import { ApiRespPluginSystemStatus } from '@/app/infra/entities/api';
enum PluginInstallStatus {
WAIT_INPUT = 'wait_input',
SELECT_RELEASE = 'select_release',
SELECT_ASSET = 'select_asset',
ASK_CONFIRM = 'ask_confirm',
INSTALLING = 'installing',
ERROR = 'error',
}
interface GithubRelease {
id: number;
tag_name: string;
name: string;
published_at: string;
prerelease: boolean;
draft: boolean;
}
interface GithubAsset {
id: number;
name: string;
size: number;
download_url: string;
content_type: string;
}
export default function PluginConfigPage() {
const { t } = useTranslation();
const [modalOpen, setModalOpen] = useState(false);
// const [sortModalOpen, setSortModalOpen] = useState(false);
const [activeTab, setActiveTab] = useState('installed');
const [modalOpen, setModalOpen] = useState(false);
const [installSource, setInstallSource] = useState<string>('local');
const [installInfo, setInstallInfo] = useState<Record<string, any>>({}); // eslint-disable-line @typescript-eslint/no-explicit-any
const [mcpSSEModalOpen, setMcpSSEModalOpen] = useState(false);
const [pluginInstallStatus, setPluginInstallStatus] =
useState<PluginInstallStatus>(PluginInstallStatus.WAIT_INPUT);
const [installError, setInstallError] = useState<string | null>(null);
const [githubURL, setGithubURL] = useState('');
const [githubReleases, setGithubReleases] = useState<GithubRelease[]>([]);
const [selectedRelease, setSelectedRelease] = useState<GithubRelease | null>(
null,
);
const [githubAssets, setGithubAssets] = useState<GithubAsset[]>([]);
const [selectedAsset, setSelectedAsset] = useState<GithubAsset | null>(null);
const [githubOwner, setGithubOwner] = useState('');
const [githubRepo, setGithubRepo] = useState('');
const [fetchingReleases, setFetchingReleases] = useState(false);
const [fetchingAssets, setFetchingAssets] = useState(false);
const [isDragOver, setIsDragOver] = useState(false);
const pluginInstalledRef = useRef<PluginInstalledComponentRef>(null);
const [pluginSystemStatus, setPluginSystemStatus] =
useState<ApiRespPluginSystemStatus | null>(null);
const [statusLoading, setStatusLoading] = useState(true);
const fileInputRef = useRef<HTMLInputElement>(null);
const [showDeleteConfirmModal, setShowDeleteConfirmModal] = useState(false);
const [editingServerName, setEditingServerName] = useState<string | null>(
null,
);
const [isEditMode, setIsEditMode] = useState(false);
const [refreshKey, setRefreshKey] = useState(0);
useEffect(() => {
const fetchPluginSystemStatus = async () => {
try {
setStatusLoading(true);
const status = await httpClient.getPluginSystemStatus();
setPluginSystemStatus(status);
} catch (error) {
console.error('Failed to fetch plugin system status:', error);
toast.error(t('plugins.failedToGetStatus'));
} finally {
setStatusLoading(false);
}
};
fetchPluginSystemStatus();
// eslint-disable-next-line react-hooks/exhaustive-deps
}, []);
function formatFileSize(bytes: number): string {
if (bytes === 0) return '0 Bytes';
const k = 1024;
const sizes = ['Bytes', 'KB', 'MB', 'GB'];
const i = Math.floor(Math.log(bytes) / Math.log(k));
return Math.round((bytes / Math.pow(k, i)) * 100) / 100 + ' ' + sizes[i];
}
function watchTask(taskId: number) {
let alreadySuccess = false;
console.log('taskId:', taskId);
// 每秒拉取一次任务状态
const interval = setInterval(() => {
httpClient.getAsyncTask(taskId).then((resp) => {
console.log('task status:', resp);
if (resp.runtime.done) {
clearInterval(interval);
if (resp.runtime.exception) {
setInstallError(resp.runtime.exception);
setPluginInstallStatus(PluginInstallStatus.ERROR);
} else {
// success
if (!alreadySuccess) {
toast.success(t('plugins.installSuccess'));
alreadySuccess = true;
}
setGithubURL('');
resetGithubState();
setModalOpen(false);
pluginInstalledRef.current?.refreshPluginList();
}
@@ -85,8 +156,96 @@ export default function PluginConfigPage() {
}, 1000);
}
const pluginInstalledRef = useRef<PluginInstalledComponentRef>(null);
function resetGithubState() {
setGithubURL('');
setGithubReleases([]);
setSelectedRelease(null);
setGithubAssets([]);
setSelectedAsset(null);
setGithubOwner('');
setGithubRepo('');
setFetchingReleases(false);
setFetchingAssets(false);
}
async function fetchGithubReleases() {
if (!githubURL.trim()) {
toast.error(t('plugins.enterRepoUrl'));
return;
}
setFetchingReleases(true);
setInstallError(null);
try {
const result = await httpClient.getGithubReleases(githubURL);
setGithubReleases(result.releases);
setGithubOwner(result.owner);
setGithubRepo(result.repo);
if (result.releases.length === 0) {
toast.warning(t('plugins.noReleasesFound'));
} else {
setPluginInstallStatus(PluginInstallStatus.SELECT_RELEASE);
}
} catch (error: unknown) {
console.error('Failed to fetch GitHub releases:', error);
const errorMessage =
error instanceof Error ? error.message : String(error);
setInstallError(errorMessage || t('plugins.fetchReleasesError'));
setPluginInstallStatus(PluginInstallStatus.ERROR);
} finally {
setFetchingReleases(false);
}
}
async function handleReleaseSelect(release: GithubRelease) {
setSelectedRelease(release);
setFetchingAssets(true);
setInstallError(null);
try {
const result = await httpClient.getGithubReleaseAssets(
githubOwner,
githubRepo,
release.id,
);
setGithubAssets(result.assets);
if (result.assets.length === 0) {
toast.warning(t('plugins.noAssetsFound'));
} else {
setPluginInstallStatus(PluginInstallStatus.SELECT_ASSET);
}
} catch (error: unknown) {
console.error('Failed to fetch GitHub release assets:', error);
const errorMessage =
error instanceof Error ? error.message : String(error);
setInstallError(errorMessage || t('plugins.fetchAssetsError'));
setPluginInstallStatus(PluginInstallStatus.ERROR);
} finally {
setFetchingAssets(false);
}
}
function handleAssetSelect(asset: GithubAsset) {
setSelectedAsset(asset);
setPluginInstallStatus(PluginInstallStatus.ASK_CONFIRM);
}
function handleModalConfirm() {
installPlugin(installSource, installInfo as Record<string, any>); // eslint-disable-line @typescript-eslint/no-explicit-any
if (installSource === 'github' && selectedAsset && selectedRelease) {
installPlugin('github', {
asset_url: selectedAsset.download_url,
owner: githubOwner,
repo: githubRepo,
release_tag: selectedRelease.tag_name,
});
} else {
installPlugin(installSource, installInfo as Record<string, any>); // eslint-disable-line @typescript-eslint/no-explicit-any
}
}
function installPlugin(
@@ -96,7 +255,12 @@ export default function PluginConfigPage() {
setPluginInstallStatus(PluginInstallStatus.INSTALLING);
if (installSource === 'github') {
httpClient
.installPluginFromGithub(installInfo.url)
.installPluginFromGithub(
installInfo.asset_url,
installInfo.owner,
installInfo.repo,
installInfo.release_tag,
)
.then((resp) => {
const taskId = resp.task_id;
watchTask(taskId);
@@ -140,6 +304,11 @@ export default function PluginConfigPage() {
const uploadPluginFile = useCallback(
async (file: File) => {
if (!pluginSystemStatus?.is_enable || !pluginSystemStatus?.is_connected) {
toast.error(t('plugins.pluginSystemNotReady'));
return;
}
if (!validateFileType(file)) {
toast.error(t('plugins.unsupportedFileType'));
return;
@@ -150,7 +319,7 @@ export default function PluginConfigPage() {
setInstallError(null);
installPlugin('local', { file });
},
[t],
[t, pluginSystemStatus, installPlugin],
);
const handleFileSelect = useCallback(() => {
@@ -165,16 +334,24 @@ export default function PluginConfigPage() {
if (file) {
uploadPluginFile(file);
}
// 清空input值以便可以重复选择同一个文件
event.target.value = '';
},
[uploadPluginFile],
);
const handleDragOver = useCallback((event: React.DragEvent) => {
event.preventDefault();
setIsDragOver(true);
}, []);
const isPluginSystemReady =
pluginSystemStatus?.is_enable && pluginSystemStatus?.is_connected;
const handleDragOver = useCallback(
(event: React.DragEvent) => {
event.preventDefault();
if (isPluginSystemReady) {
setIsDragOver(true);
}
},
[isPluginSystemReady],
);
const handleDragLeave = useCallback((event: React.DragEvent) => {
event.preventDefault();
@@ -186,14 +363,72 @@ export default function PluginConfigPage() {
event.preventDefault();
setIsDragOver(false);
if (!isPluginSystemReady) {
toast.error(t('plugins.pluginSystemNotReady'));
return;
}
const files = Array.from(event.dataTransfer.files);
if (files.length > 0) {
uploadPluginFile(files[0]);
}
},
[uploadPluginFile],
[uploadPluginFile, isPluginSystemReady, t],
);
const renderPluginDisabledState = () => (
<div className="flex flex-col items-center justify-center h-[60vh] text-center pt-[10vh]">
<Power className="w-16 h-16 text-gray-400 mb-4" />
<h2 className="text-2xl font-semibold text-gray-700 dark:text-gray-300 mb-2">
{t('plugins.systemDisabled')}
</h2>
<p className="text-gray-500 dark:text-gray-400 max-w-md">
{t('plugins.systemDisabledDesc')}
</p>
</div>
);
const renderPluginConnectionErrorState = () => (
<div className="flex flex-col items-center justify-center h-[60vh] text-center pt-[10vh]">
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 24 24"
width="72"
height="72"
fill="#BDBDBD"
>
<path d="M17.657 14.8284L16.2428 13.4142L17.657 12C19.2191 10.4379 19.2191 7.90526 17.657 6.34316C16.0949 4.78106 13.5622 4.78106 12.0001 6.34316L10.5859 7.75737L9.17171 6.34316L10.5859 4.92895C12.9291 2.5858 16.7281 2.5858 19.0712 4.92895C21.4143 7.27209 21.4143 11.0711 19.0712 13.4142L17.657 14.8284ZM14.8286 17.6569L13.4143 19.0711C11.0712 21.4142 7.27221 21.4142 4.92907 19.0711C2.58592 16.7279 2.58592 12.9289 4.92907 10.5858L6.34328 9.17159L7.75749 10.5858L6.34328 12C4.78118 13.5621 4.78118 16.0948 6.34328 17.6569C7.90538 19.219 10.438 19.219 12.0001 17.6569L13.4143 16.2427L14.8286 17.6569ZM14.8286 7.75737L16.2428 9.17159L9.17171 16.2427L7.75749 14.8284L14.8286 7.75737ZM5.77539 2.29291L7.70724 1.77527L8.74252 5.63897L6.81067 6.15661L5.77539 2.29291ZM15.2578 18.3611L17.1896 17.8434L18.2249 21.7071L16.293 22.2248L15.2578 18.3611ZM2.29303 5.77527L6.15673 6.81054L5.63909 8.7424L1.77539 7.70712L2.29303 5.77527ZM18.3612 15.2576L22.2249 16.2929L21.7072 18.2248L17.8435 17.1895L18.3612 15.2576Z"></path>
</svg>
<h2 className="text-2xl font-semibold text-gray-700 dark:text-gray-300 mb-2">
{t('plugins.connectionError')}
</h2>
<p className="text-gray-500 dark:text-gray-400 max-w-md mb-4">
{t('plugins.connectionErrorDesc')}
</p>
</div>
);
const renderLoadingState = () => (
<div className="flex flex-col items-center justify-center h-[60vh] pt-[10vh]">
<p className="text-gray-500 dark:text-gray-400">
{t('plugins.loadingStatus')}
</p>
</div>
);
if (statusLoading) {
return renderLoadingState();
}
if (!pluginSystemStatus?.is_enable) {
return renderPluginDisabledState();
}
if (!pluginSystemStatus?.is_connected) {
return renderPluginConnectionErrorState();
}
return (
<div
className={`${styles.pageContainer} ${isDragOver ? 'bg-blue-50' : ''}`}
@@ -219,40 +454,69 @@ export default function PluginConfigPage() {
{t('plugins.marketplace')}
</TabsTrigger>
)}
<TabsTrigger
value="mcp-servers"
className="px-6 py-4 cursor-pointer"
>
{t('mcp.title')}
</TabsTrigger>
</TabsList>
<div className="flex flex-row justify-end items-center">
{/* <Button
variant="outline"
className="px-6 py-4 cursor-pointer mr-2"
onClick={() => {
// setSortModalOpen(true);
}}
>
{t('plugins.arrange')}
</Button> */}
<DropdownMenu>
<DropdownMenuTrigger asChild>
<Button variant="default" className="px-6 py-4 cursor-pointer">
<PlusIcon className="w-4 h-4" />
{t('plugins.install')}
{activeTab === 'mcp-servers'
? t('mcp.add')
: t('plugins.install')}
<ChevronDownIcon className="ml-2 w-4 h-4" />
</Button>
</DropdownMenuTrigger>
<DropdownMenuContent align="end">
<DropdownMenuItem onClick={handleFileSelect}>
<UploadIcon className="w-4 h-4" />
{t('plugins.uploadLocal')}
</DropdownMenuItem>
{systemInfo.enable_marketplace && (
<DropdownMenuItem
onClick={() => {
setActiveTab('market');
}}
>
<StoreIcon className="w-4 h-4" />
{t('plugins.marketplace')}
</DropdownMenuItem>
{activeTab === 'mcp-servers' ? (
<>
<DropdownMenuItem
onClick={() => {
setActiveTab('mcp-servers');
setIsEditMode(false);
setEditingServerName(null);
setMcpSSEModalOpen(true);
}}
>
<PlusIcon className="w-4 h-4" />
{t('mcp.createServer')}
</DropdownMenuItem>
</>
) : (
<>
{systemInfo.enable_marketplace && (
<DropdownMenuItem
onClick={() => {
setActiveTab('market');
}}
>
<StoreIcon className="w-4 h-4" />
{t('plugins.marketplace')}
</DropdownMenuItem>
)}
<DropdownMenuItem onClick={handleFileSelect}>
<UploadIcon className="w-4 h-4" />
{t('plugins.uploadLocal')}
</DropdownMenuItem>
<DropdownMenuItem
onClick={() => {
setInstallSource('github');
setPluginInstallStatus(PluginInstallStatus.WAIT_INPUT);
setInstallError(null);
resetGithubState();
setModalOpen(true);
}}
>
<Github className="w-4 h-4" />
{t('plugins.installFromGithub')}
</DropdownMenuItem>
</>
)}
</DropdownMenuContent>
</DropdownMenu>
@@ -275,51 +539,259 @@ export default function PluginConfigPage() {
}}
/>
</TabsContent>
<TabsContent value="mcp-servers">
<MCPServerComponent
key={refreshKey}
onEditServer={(serverName) => {
setEditingServerName(serverName);
setIsEditMode(true);
setMcpSSEModalOpen(true);
}}
/>
</TabsContent>
</Tabs>
<Dialog open={modalOpen} onOpenChange={setModalOpen}>
<DialogContent className="w-[500px] p-6 bg-white dark:bg-[#1a1a1e]">
<Dialog
open={modalOpen}
onOpenChange={(open) => {
setModalOpen(open);
if (!open) {
resetGithubState();
setInstallError(null);
}
}}
>
<DialogContent className="w-[500px] max-h-[80vh] p-6 bg-white dark:bg-[#1a1a1e] overflow-y-auto">
<DialogHeader>
<DialogTitle className="flex items-center gap-4">
<Download className="size-6" />
{installSource === 'github' ? (
<Github className="size-6" />
) : (
<Download className="size-6" />
)}
<span>{t('plugins.installPlugin')}</span>
</DialogTitle>
</DialogHeader>
{pluginInstallStatus === PluginInstallStatus.WAIT_INPUT && (
<div className="mt-4">
<p className="mb-2">{t('plugins.onlySupportGithub')}</p>
<Input
placeholder={t('plugins.enterGithubLink')}
value={githubURL}
onChange={(e) => setGithubURL(e.target.value)}
className="mb-4"
/>
</div>
)}
{pluginInstallStatus === PluginInstallStatus.ASK_CONFIRM && (
<div className="mt-4">
<p className="mb-2">
{t('plugins.askConfirm', {
name: installInfo.plugin_name,
version: installInfo.plugin_version,
})}
</p>
</div>
)}
{/* GitHub Install Flow */}
{installSource === 'github' &&
pluginInstallStatus === PluginInstallStatus.WAIT_INPUT && (
<div className="mt-4">
<p className="mb-2">{t('plugins.enterRepoUrl')}</p>
<Input
placeholder={t('plugins.repoUrlPlaceholder')}
value={githubURL}
onChange={(e) => setGithubURL(e.target.value)}
className="mb-4"
/>
{fetchingReleases && (
<p className="text-sm text-gray-500">
{t('plugins.fetchingReleases')}
</p>
)}
</div>
)}
{installSource === 'github' &&
pluginInstallStatus === PluginInstallStatus.SELECT_RELEASE && (
<div className="mt-4">
<div className="flex items-center justify-between mb-4">
<p className="font-medium">{t('plugins.selectRelease')}</p>
<Button
variant="ghost"
size="sm"
onClick={() => {
setPluginInstallStatus(PluginInstallStatus.WAIT_INPUT);
setGithubReleases([]);
}}
>
<ChevronLeft className="w-4 h-4 mr-1" />
{t('plugins.backToRepoUrl')}
</Button>
</div>
<div className="max-h-[400px] overflow-y-auto space-y-2 pb-2">
{githubReleases.map((release) => (
<Card
key={release.id}
className="cursor-pointer hover:shadow-sm transition-shadow duration-200 shadow-none py-4"
onClick={() => handleReleaseSelect(release)}
>
<CardHeader className="flex flex-row items-start justify-between px-3 space-y-0">
<div className="flex-1">
<CardTitle className="text-sm">
{release.name || release.tag_name}
</CardTitle>
<CardDescription className="text-xs mt-1">
{t('plugins.releaseTag', { tag: release.tag_name })}{' '}
{' '}
{t('plugins.publishedAt', {
date: new Date(
release.published_at,
).toLocaleDateString(),
})}
</CardDescription>
</div>
{release.prerelease && (
<span className="text-xs bg-yellow-100 dark:bg-yellow-900 text-yellow-800 dark:text-yellow-200 px-2 py-0.5 rounded ml-2 shrink-0">
{t('plugins.prerelease')}
</span>
)}
</CardHeader>
</Card>
))}
</div>
{fetchingAssets && (
<p className="text-sm text-gray-500 mt-4">
{t('plugins.loading')}
</p>
)}
</div>
)}
{installSource === 'github' &&
pluginInstallStatus === PluginInstallStatus.SELECT_ASSET && (
<div className="mt-4">
<div className="flex items-center justify-between mb-4">
<p className="font-medium">{t('plugins.selectAsset')}</p>
<Button
variant="ghost"
size="sm"
onClick={() => {
setPluginInstallStatus(
PluginInstallStatus.SELECT_RELEASE,
);
setGithubAssets([]);
setSelectedAsset(null);
}}
>
<ChevronLeft className="w-4 h-4 mr-1" />
{t('plugins.backToReleases')}
</Button>
</div>
{selectedRelease && (
<div className="mb-4 p-2 bg-gray-50 dark:bg-gray-900 rounded">
<div className="text-sm font-medium">
{selectedRelease.name || selectedRelease.tag_name}
</div>
<div className="text-xs text-gray-500">
{selectedRelease.tag_name}
</div>
</div>
)}
<div className="max-h-[400px] overflow-y-auto space-y-2 pb-2">
{githubAssets.map((asset) => (
<Card
key={asset.id}
className="cursor-pointer hover:shadow-sm transition-shadow duration-200 shadow-none py-3"
onClick={() => handleAssetSelect(asset)}
>
<CardHeader className="px-3">
<CardTitle className="text-sm">{asset.name}</CardTitle>
<CardDescription className="text-xs">
{t('plugins.assetSize', {
size: formatFileSize(asset.size),
})}
</CardDescription>
</CardHeader>
</Card>
))}
</div>
</div>
)}
{/* Marketplace Install Confirm */}
{installSource === 'marketplace' &&
pluginInstallStatus === PluginInstallStatus.ASK_CONFIRM && (
<div className="mt-4">
<p className="mb-2">
{t('plugins.askConfirm', {
name: installInfo.plugin_name,
version: installInfo.plugin_version,
})}
</p>
</div>
)}
{/* GitHub Install Confirm */}
{installSource === 'github' &&
pluginInstallStatus === PluginInstallStatus.ASK_CONFIRM && (
<div className="mt-4">
<div className="flex items-center justify-between mb-4">
<p className="font-medium">{t('plugins.confirmInstall')}</p>
<Button
variant="ghost"
size="sm"
onClick={() => {
setPluginInstallStatus(PluginInstallStatus.SELECT_ASSET);
setSelectedAsset(null);
}}
>
<ChevronLeft className="w-4 h-4 mr-1" />
{t('plugins.backToAssets')}
</Button>
</div>
{selectedRelease && selectedAsset && (
<div className="p-3 bg-gray-50 dark:bg-gray-900 rounded space-y-2">
<div>
<span className="text-sm font-medium">Repository: </span>
<span className="text-sm">
{githubOwner}/{githubRepo}
</span>
</div>
<div>
<span className="text-sm font-medium">Release: </span>
<span className="text-sm">
{selectedRelease.tag_name}
</span>
</div>
<div>
<span className="text-sm font-medium">File: </span>
<span className="text-sm">{selectedAsset.name}</span>
</div>
</div>
)}
</div>
)}
{/* Installing State */}
{pluginInstallStatus === PluginInstallStatus.INSTALLING && (
<div className="mt-4">
<p className="mb-2">{t('plugins.installing')}</p>
</div>
)}
{/* Error State */}
{pluginInstallStatus === PluginInstallStatus.ERROR && (
<div className="mt-4">
<p className="mb-2">{t('plugins.installFailed')}</p>
<p className="mb-2 text-red-500">{installError}</p>
</div>
)}
<DialogFooter>
{(pluginInstallStatus === PluginInstallStatus.WAIT_INPUT ||
pluginInstallStatus === PluginInstallStatus.ASK_CONFIRM) && (
{pluginInstallStatus === PluginInstallStatus.WAIT_INPUT &&
installSource === 'github' && (
<>
<Button
variant="outline"
onClick={() => {
setModalOpen(false);
resetGithubState();
}}
>
{t('common.cancel')}
</Button>
<Button
onClick={fetchGithubReleases}
disabled={!githubURL.trim() || fetchingReleases}
>
{fetchingReleases
? t('plugins.loading')
: t('common.confirm')}
</Button>
</>
)}
{pluginInstallStatus === PluginInstallStatus.ASK_CONFIRM && (
<>
<Button variant="outline" onClick={() => setModalOpen(false)}>
{t('common.cancel')}
@@ -338,7 +810,6 @@ export default function PluginConfigPage() {
</DialogContent>
</Dialog>
{/* 拖拽提示覆盖层 */}
{isDragOver && (
<div className="fixed inset-0 bg-gray-500 bg-opacity-50 flex items-center justify-center z-50 pointer-events-none">
<div className="bg-white rounded-lg p-8 shadow-lg border-2 border-dashed border-gray-500">
@@ -352,13 +823,32 @@ export default function PluginConfigPage() {
</div>
)}
{/* <PluginSortDialog
open={sortModalOpen}
onOpenChange={setSortModalOpen}
onSortComplete={() => {
pluginInstalledRef.current?.refreshPluginList();
<MCPFormDialog
open={mcpSSEModalOpen}
onOpenChange={setMcpSSEModalOpen}
serverName={editingServerName}
isEditMode={isEditMode}
onSuccess={() => {
setEditingServerName(null);
setIsEditMode(false);
setRefreshKey((prev) => prev + 1);
}}
/> */}
onDelete={() => {
setShowDeleteConfirmModal(true);
}}
/>
<MCPDeleteConfirmDialog
open={showDeleteConfirmModal}
onOpenChange={setShowDeleteConfirmModal}
serverName={editingServerName}
onSuccess={() => {
setMcpSSEModalOpen(false);
setEditingServerName(null);
setIsEditMode(false);
setRefreshKey((prev) => prev + 1);
}}
/>
</div>
);
}

View File

@@ -1,120 +0,0 @@
import { useState, useEffect } from 'react';
import { ApiRespPluginConfig } from '@/app/infra/entities/api';
import { Plugin } from '@/app/infra/entities/plugin';
import { httpClient } from '@/app/infra/http/HttpClient';
import DynamicFormComponent from '@/app/home/components/dynamic-form/DynamicFormComponent';
import { Button } from '@/components/ui/button';
import { toast } from 'sonner';
import { extractI18nObject } from '@/i18n/I18nProvider';
import { useTranslation } from 'react-i18next';
export default function PluginForm({
pluginAuthor,
pluginName,
onFormSubmit,
onFormCancel,
}: {
pluginAuthor: string;
pluginName: string;
onFormSubmit: (timeout?: number) => void;
onFormCancel: () => void;
}) {
const { t } = useTranslation();
const [pluginInfo, setPluginInfo] = useState<Plugin>();
const [pluginConfig, setPluginConfig] = useState<ApiRespPluginConfig>();
const [isSaving, setIsLoading] = useState(false);
useEffect(() => {
// 获取插件信息
httpClient.getPlugin(pluginAuthor, pluginName).then((res) => {
setPluginInfo(res.plugin);
});
// 获取插件配置
httpClient.getPluginConfig(pluginAuthor, pluginName).then((res) => {
setPluginConfig(res);
});
}, [pluginAuthor, pluginName]);
const handleSubmit = async (values: object) => {
setIsLoading(true);
const isDebugPlugin = pluginInfo?.debug;
httpClient
.updatePluginConfig(pluginAuthor, pluginName, values)
.then(() => {
toast.success(
isDebugPlugin
? t('plugins.saveConfigSuccessDebugPlugin')
: t('plugins.saveConfigSuccessNormal'),
);
onFormSubmit(1000);
})
.catch((error) => {
toast.error(t('plugins.saveConfigError') + error.message);
})
.finally(() => {
setIsLoading(false);
});
};
if (!pluginInfo || !pluginConfig) {
return (
<div className="flex items-center justify-center h-full mb-[2rem]">
{t('plugins.loading')}
</div>
);
}
return (
<div>
<div className="space-y-2">
<div className="text-lg font-medium">
{extractI18nObject(pluginInfo.manifest.manifest.metadata.label)}
</div>
<div className="text-sm text-gray-500 pb-2">
{extractI18nObject(
pluginInfo.manifest.manifest.metadata.description ?? {
en_US: '',
zh_Hans: '',
},
)}
</div>
{pluginInfo.manifest.manifest.spec.config.length > 0 && (
<DynamicFormComponent
itemConfigList={pluginInfo.manifest.manifest.spec.config}
initialValues={pluginConfig.config as Record<string, object>}
onSubmit={(values) => {
let config = pluginConfig.config;
config = {
...config,
...values,
};
setPluginConfig({
config: config,
});
}}
/>
)}
{pluginInfo.manifest.manifest.spec.config.length === 0 && (
<div className="text-sm text-gray-500">
{t('plugins.pluginNoConfig')}
</div>
)}
</div>
<div className="sticky bottom-0 left-0 right-0 bg-background border-t p-4 mt-4">
<div className="flex justify-end gap-2">
<Button
type="submit"
onClick={() => handleSubmit(pluginConfig.config)}
disabled={isSaving}
>
{isSaving ? t('plugins.saving') : t('plugins.saveConfig')}
</Button>
<Button type="button" variant="outline" onClick={onFormCancel}>
{t('plugins.cancel')}
</Button>
</div>
</div>
</div>
);
}

Some files were not shown because too many files have changed in this diff Show More