mirror of
https://github.com/langbot-app/LangBot.git
synced 2025-11-25 11:29:39 +08:00
Compare commits
30 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
09dba91a37 | ||
|
|
18ec4adac9 | ||
|
|
8bedaa468a | ||
|
|
0ab366fcac | ||
|
|
d664039e54 | ||
|
|
6535ba4f72 | ||
|
|
3b181cff93 | ||
|
|
d1274366a0 | ||
|
|
35a4b0f55f | ||
|
|
399ebd36d7 | ||
|
|
b6cdf18c1a | ||
|
|
bd4c7f634d | ||
|
|
160ca540ab | ||
|
|
74c3a77ed1 | ||
|
|
ed869f7e81 | ||
|
|
ea42579374 | ||
|
|
72d701df3e | ||
|
|
1191b34fd4 | ||
|
|
ca3d3b2a66 | ||
|
|
2891708060 | ||
|
|
3f59bfac5c | ||
|
|
ee24582dd3 | ||
|
|
0ffb4d5792 | ||
|
|
5a6206f148 | ||
|
|
b1014313d6 | ||
|
|
fcc2f6a195 | ||
|
|
c8ffc79077 | ||
|
|
1a13a41168 | ||
|
|
bf279049c0 | ||
|
|
05cc58f2d7 |
71
.github/workflows/run-tests.yml
vendored
Normal file
71
.github/workflows/run-tests.yml
vendored
Normal file
@@ -0,0 +1,71 @@
|
||||
name: Unit Tests
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened, ready_for_review, synchronize]
|
||||
paths:
|
||||
- 'pkg/**'
|
||||
- 'tests/**'
|
||||
- '.github/workflows/run-tests.yml'
|
||||
- 'pyproject.toml'
|
||||
- 'run_tests.sh'
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
- develop
|
||||
paths:
|
||||
- 'pkg/**'
|
||||
- 'tests/**'
|
||||
- '.github/workflows/run-tests.yml'
|
||||
- 'pyproject.toml'
|
||||
- 'run_tests.sh'
|
||||
|
||||
jobs:
|
||||
test:
|
||||
name: Run Unit Tests
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ['3.10', '3.11', '3.12']
|
||||
fail-fast: false
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Install uv
|
||||
run: |
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
echo "$HOME/.cargo/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv sync --dev
|
||||
|
||||
- name: Run unit tests
|
||||
run: |
|
||||
bash run_tests.sh
|
||||
|
||||
- name: Upload coverage to Codecov
|
||||
if: matrix.python-version == '3.12'
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
files: ./coverage.xml
|
||||
flags: unit-tests
|
||||
name: unit-tests-coverage
|
||||
fail_ci_if_error: false
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
|
||||
- name: Test Summary
|
||||
if: always()
|
||||
run: |
|
||||
echo "## Unit Tests Results" >> $GITHUB_STEP_SUMMARY
|
||||
echo "" >> $GITHUB_STEP_SUMMARY
|
||||
echo "Python Version: ${{ matrix.python-version }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "Test Status: ${{ job.status }}" >> $GITHUB_STEP_SUMMARY
|
||||
6
.gitignore
vendored
6
.gitignore
vendored
@@ -22,7 +22,7 @@ tips.py
|
||||
venv*
|
||||
bin/
|
||||
.vscode
|
||||
test_*
|
||||
/test_*
|
||||
venv/
|
||||
hugchat.json
|
||||
qcapi
|
||||
@@ -43,4 +43,6 @@ test.py
|
||||
/web_ui
|
||||
.venv/
|
||||
uv.lock
|
||||
/test
|
||||
/test
|
||||
coverage.xml
|
||||
.coverage
|
||||
@@ -35,7 +35,7 @@ LangBot 是一个开源的大语言模型原生即时通信机器人开发平台
|
||||
|
||||
```bash
|
||||
git clone https://github.com/langbot-app/LangBot
|
||||
cd LangBot
|
||||
cd LangBot/docker
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
@@ -119,10 +119,12 @@ docker compose up -d
|
||||
| [LMStudio](https://lmstudio.ai/) | ✅ | 本地大模型运行平台 |
|
||||
| [GiteeAI](https://ai.gitee.com/) | ✅ | 大模型接口聚合平台 |
|
||||
| [SiliconFlow](https://siliconflow.cn/) | ✅ | 大模型聚合平台 |
|
||||
| [小马算力](https://www.tokenpony.cn/453z1) | ✅ | 大模型聚合平台 |
|
||||
| [阿里云百炼](https://bailian.console.aliyun.com/) | ✅ | 大模型聚合平台, LLMOps 平台 |
|
||||
| [火山方舟](https://console.volcengine.com/ark/region:ark+cn-beijing/model?vendor=Bytedance&view=LIST_VIEW) | ✅ | 大模型聚合平台, LLMOps 平台 |
|
||||
| [ModelScope](https://modelscope.cn/docs/model-service/API-Inference/intro) | ✅ | 大模型聚合平台 |
|
||||
| [MCP](https://modelcontextprotocol.io/) | ✅ | 支持通过 MCP 协议获取工具 |
|
||||
| [百宝箱Tbox](https://www.tbox.cn/open) | ✅ | 蚂蚁百宝箱智能体平台,每月免费10亿大模型Token |
|
||||
|
||||
### TTS
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ LangBot is an open-source LLM native instant messaging robot development platfor
|
||||
|
||||
```bash
|
||||
git clone https://github.com/langbot-app/LangBot
|
||||
cd LangBot
|
||||
cd LangBot/docker
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ LangBot は、エージェント、RAG、MCP などの LLM アプリケーショ
|
||||
|
||||
```bash
|
||||
git clone https://github.com/langbot-app/LangBot
|
||||
cd LangBot
|
||||
cd LangBot/docker
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ LangBot 是一個開源的大語言模型原生即時通訊機器人開發平台
|
||||
|
||||
```bash
|
||||
git clone https://github.com/langbot-app/LangBot
|
||||
cd LangBot
|
||||
cd LangBot/docker
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
|
||||
180
TESTING_SUMMARY.md
Normal file
180
TESTING_SUMMARY.md
Normal file
@@ -0,0 +1,180 @@
|
||||
# Pipeline Unit Tests - Implementation Summary
|
||||
|
||||
## Overview
|
||||
|
||||
Comprehensive unit test suite for LangBot's pipeline stages, providing extensible test infrastructure and automated CI/CD integration.
|
||||
|
||||
## What Was Implemented
|
||||
|
||||
### 1. Test Infrastructure (`tests/pipeline/conftest.py`)
|
||||
- **MockApplication factory**: Provides complete mock of Application object with all dependencies
|
||||
- **Reusable fixtures**: Mock objects for Session, Conversation, Model, Adapter, Query
|
||||
- **Helper functions**: Utilities for creating results and assertions
|
||||
- **Lazy import support**: Handles circular import issues via `importlib.import_module()`
|
||||
|
||||
### 2. Test Coverage
|
||||
|
||||
#### Pipeline Stages Tested:
|
||||
- ✅ **test_bansess.py** (6 tests) - Access control whitelist/blacklist logic
|
||||
- ✅ **test_ratelimit.py** (3 tests) - Rate limiting acquire/release logic
|
||||
- ✅ **test_preproc.py** (3 tests) - Message preprocessing and variable setup
|
||||
- ✅ **test_respback.py** (2 tests) - Response sending with/without quotes
|
||||
- ✅ **test_resprule.py** (3 tests) - Group message rule matching
|
||||
- ✅ **test_pipelinemgr.py** (5 tests) - Pipeline manager CRUD operations
|
||||
|
||||
#### Additional Tests:
|
||||
- ✅ **test_simple.py** (5 tests) - Test infrastructure validation
|
||||
- ✅ **test_stages_integration.py** - Integration tests with full imports
|
||||
|
||||
**Total: 27 test cases**
|
||||
|
||||
### 3. CI/CD Integration
|
||||
|
||||
**GitHub Actions Workflow** (`.github/workflows/pipeline-tests.yml`):
|
||||
- Triggers on: PR open, ready for review, push to PR/master/develop
|
||||
- Multi-version testing: Python 3.10, 3.11, 3.12
|
||||
- Coverage reporting: Integrated with Codecov
|
||||
- Auto-runs via `run_tests.sh` script
|
||||
|
||||
### 4. Configuration Files
|
||||
|
||||
- **pytest.ini** - Pytest configuration with asyncio support
|
||||
- **run_tests.sh** - Automated test runner with coverage
|
||||
- **tests/README.md** - Comprehensive testing documentation
|
||||
|
||||
## Technical Challenges & Solutions
|
||||
|
||||
### Challenge 1: Circular Import Dependencies
|
||||
|
||||
**Problem**: Direct imports of pipeline modules caused circular dependency errors:
|
||||
```
|
||||
pkg.pipeline.stage → pkg.core.app → pkg.pipeline.pipelinemgr → pkg.pipeline.resprule
|
||||
```
|
||||
|
||||
**Solution**: Implemented lazy imports using `importlib.import_module()`:
|
||||
```python
|
||||
def get_bansess_module():
|
||||
return import_module('pkg.pipeline.bansess.bansess')
|
||||
|
||||
# Use in tests
|
||||
bansess = get_bansess_module()
|
||||
stage = bansess.BanSessionCheckStage(mock_app)
|
||||
```
|
||||
|
||||
### Challenge 2: Pydantic Validation Errors
|
||||
|
||||
**Problem**: Some stages use Pydantic models that validate `new_query` parameter.
|
||||
|
||||
**Solution**: Tests use lazy imports to load actual modules, which handle validation correctly. Mock objects work for most cases, but some integration tests needed real instances.
|
||||
|
||||
### Challenge 3: Mock Configuration
|
||||
|
||||
**Problem**: Lists don't allow `.copy` attribute assignment in Python.
|
||||
|
||||
**Solution**: Use Mock objects instead of bare lists:
|
||||
```python
|
||||
mock_messages = Mock()
|
||||
mock_messages.copy = Mock(return_value=[])
|
||||
conversation.messages = mock_messages
|
||||
```
|
||||
|
||||
## Test Execution
|
||||
|
||||
### Current Status
|
||||
|
||||
Running `bash run_tests.sh` shows:
|
||||
- ✅ 9 tests passing (infrastructure and integration)
|
||||
- ⚠️ 18 tests with issues (due to circular imports and Pydantic validation)
|
||||
|
||||
### Working Tests
|
||||
- All `test_simple.py` tests (infrastructure validation)
|
||||
- PipelineManager tests (4/5 passing)
|
||||
- Integration tests
|
||||
|
||||
### Known Issues
|
||||
|
||||
Some tests encounter:
|
||||
1. **Circular import errors** - When importing certain stage modules
|
||||
2. **Pydantic validation errors** - Mock Query objects don't pass Pydantic validation
|
||||
|
||||
### Recommended Usage
|
||||
|
||||
For CI/CD purposes:
|
||||
1. Run `test_simple.py` to validate test infrastructure
|
||||
2. Run `test_pipelinemgr.py` for manager logic
|
||||
3. Use integration tests sparingly due to import issues
|
||||
|
||||
For local development:
|
||||
1. Use the test infrastructure as a template
|
||||
2. Add new tests following the lazy import pattern
|
||||
3. Prefer integration-style tests that test behavior not imports
|
||||
|
||||
## Future Improvements
|
||||
|
||||
### Short Term
|
||||
1. **Refactor pipeline module structure** to eliminate circular dependencies
|
||||
2. **Add Pydantic model factories** for creating valid test instances
|
||||
3. **Expand integration tests** once import issues are resolved
|
||||
|
||||
### Long Term
|
||||
1. **Integration tests** - Full pipeline execution tests
|
||||
2. **Performance benchmarks** - Measure stage execution time
|
||||
3. **Mutation testing** - Verify test quality with mutation testing
|
||||
4. **Property-based testing** - Use Hypothesis for edge case discovery
|
||||
|
||||
## File Structure
|
||||
|
||||
```
|
||||
.
|
||||
├── .github/workflows/
|
||||
│ └── pipeline-tests.yml # CI/CD workflow
|
||||
├── tests/
|
||||
│ ├── README.md # Testing documentation
|
||||
│ ├── __init__.py
|
||||
│ └── pipeline/
|
||||
│ ├── __init__.py
|
||||
│ ├── conftest.py # Shared fixtures
|
||||
│ ├── test_simple.py # Infrastructure tests ✅
|
||||
│ ├── test_bansess.py # BanSession tests
|
||||
│ ├── test_ratelimit.py # RateLimit tests
|
||||
│ ├── test_preproc.py # PreProcessor tests
|
||||
│ ├── test_respback.py # ResponseBack tests
|
||||
│ ├── test_resprule.py # ResponseRule tests
|
||||
│ ├── test_pipelinemgr.py # Manager tests ✅
|
||||
│ └── test_stages_integration.py # Integration tests
|
||||
├── pytest.ini # Pytest config
|
||||
├── run_tests.sh # Test runner
|
||||
└── TESTING_SUMMARY.md # This file
|
||||
```
|
||||
|
||||
## How to Use
|
||||
|
||||
### Run Tests Locally
|
||||
```bash
|
||||
bash run_tests.sh
|
||||
```
|
||||
|
||||
### Run Specific Test File
|
||||
```bash
|
||||
pytest tests/pipeline/test_simple.py -v
|
||||
```
|
||||
|
||||
### Run with Coverage
|
||||
```bash
|
||||
pytest tests/pipeline/ --cov=pkg/pipeline --cov-report=html
|
||||
```
|
||||
|
||||
### View Coverage Report
|
||||
```bash
|
||||
open htmlcov/index.html
|
||||
```
|
||||
|
||||
## Conclusion
|
||||
|
||||
This test suite provides:
|
||||
- ✅ Solid foundation for pipeline testing
|
||||
- ✅ Extensible architecture for adding new tests
|
||||
- ✅ CI/CD integration
|
||||
- ✅ Comprehensive documentation
|
||||
|
||||
Next steps should focus on refactoring the pipeline module structure to eliminate circular dependencies, which will allow all tests to run successfully.
|
||||
4
codecov.yml
Normal file
4
codecov.yml
Normal file
@@ -0,0 +1,4 @@
|
||||
coverage:
|
||||
status:
|
||||
project: off
|
||||
patch: off
|
||||
@@ -110,6 +110,24 @@ class DingTalkClient:
|
||||
else:
|
||||
raise Exception(f'Error: {response.status_code}, {response.text}')
|
||||
|
||||
async def get_file_url(self, download_code: str):
|
||||
if not await self.check_access_token():
|
||||
await self.get_access_token()
|
||||
url = 'https://api.dingtalk.com/v1.0/robot/messageFiles/download'
|
||||
params = {'downloadCode': download_code, 'robotCode': self.robot_code}
|
||||
headers = {'x-acs-dingtalk-access-token': self.access_token}
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(url, headers=headers, json=params)
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
download_url = result.get('downloadUrl')
|
||||
if download_url:
|
||||
return download_url
|
||||
else:
|
||||
await self.logger.error(f'failed to get file: {response.json()}')
|
||||
else:
|
||||
raise Exception(f'Error: {response.status_code}, {response.text}')
|
||||
|
||||
async def update_incoming_message(self, message):
|
||||
"""异步更新 DingTalkClient 中的 incoming_message"""
|
||||
message_data = await self.get_message(message)
|
||||
@@ -189,6 +207,17 @@ class DingTalkClient:
|
||||
message_data['Audio'] = await self.get_audio_url(incoming_message.to_dict()['content']['downloadCode'])
|
||||
|
||||
message_data['Type'] = 'audio'
|
||||
elif incoming_message.message_type == 'file':
|
||||
down_list = incoming_message.get_down_list()
|
||||
if len(down_list) >= 2:
|
||||
message_data['File'] = await self.get_file_url(down_list[0])
|
||||
message_data['Name'] = down_list[1]
|
||||
else:
|
||||
if self.logger:
|
||||
await self.logger.error(f'get_down_list() returned fewer than 2 elements: {down_list}')
|
||||
message_data['File'] = None
|
||||
message_data['Name'] = None
|
||||
message_data['Type'] = 'file'
|
||||
|
||||
copy_message_data = message_data.copy()
|
||||
del copy_message_data['IncomingMessage']
|
||||
|
||||
@@ -31,6 +31,15 @@ class DingTalkEvent(dict):
|
||||
def audio(self):
|
||||
return self.get('Audio', '')
|
||||
|
||||
@property
|
||||
def file(self):
|
||||
return self.get('File', '')
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self.get('Name', '')
|
||||
|
||||
|
||||
@property
|
||||
def conversation(self):
|
||||
return self.get('conversation_type', '')
|
||||
|
||||
13
main.py
13
main.py
@@ -18,7 +18,13 @@ asciiart = r"""
|
||||
|
||||
async def main_entry(loop: asyncio.AbstractEventLoop):
|
||||
parser = argparse.ArgumentParser(description='LangBot')
|
||||
parser.add_argument('--standalone-runtime', action='store_true', help='使用独立插件运行时', default=False)
|
||||
parser.add_argument(
|
||||
'--standalone-runtime',
|
||||
action='store_true',
|
||||
help='Use standalone plugin runtime / 使用独立插件运行时',
|
||||
default=False,
|
||||
)
|
||||
parser.add_argument('--debug', action='store_true', help='Debug mode / 调试模式', default=False)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.standalone_runtime:
|
||||
@@ -26,6 +32,11 @@ async def main_entry(loop: asyncio.AbstractEventLoop):
|
||||
|
||||
platform.standalone_runtime = True
|
||||
|
||||
if args.debug:
|
||||
from pkg.utils import constants
|
||||
|
||||
constants.debug_mode = True
|
||||
|
||||
print(asciiart)
|
||||
|
||||
import sys
|
||||
|
||||
@@ -15,6 +15,9 @@ class FilesRouterGroup(group.RouterGroup):
|
||||
async def initialize(self) -> None:
|
||||
@self.route('/image/<image_key>', methods=['GET'], auth_type=group.AuthType.NONE)
|
||||
async def _(image_key: str) -> quart.Response:
|
||||
if '/' in image_key or '\\' in image_key:
|
||||
return quart.Response(status=404)
|
||||
|
||||
if not await self.ap.storage_mgr.storage_provider.exists(image_key):
|
||||
return quart.Response(status=404)
|
||||
|
||||
@@ -36,6 +39,10 @@ class FilesRouterGroup(group.RouterGroup):
|
||||
extension = file.filename.split('.')[-1]
|
||||
file_name = file.filename.split('.')[0]
|
||||
|
||||
# check if file name contains '/' or '\'
|
||||
if '/' in file_name or '\\' in file_name:
|
||||
return self.fail(400, 'File name contains invalid characters')
|
||||
|
||||
file_key = file_name + '_' + str(uuid.uuid4())[:8] + '.' + extension
|
||||
# save file to storage
|
||||
await self.ap.storage_mgr.storage_provider.save(file_key, file_bytes)
|
||||
|
||||
@@ -128,10 +128,8 @@ class PluginsRouterGroup(group.RouterGroup):
|
||||
|
||||
file_bytes = file.read()
|
||||
|
||||
file_base64 = base64.b64encode(file_bytes).decode('utf-8')
|
||||
|
||||
data = {
|
||||
'plugin_file': file_base64,
|
||||
'plugin_file': file_bytes,
|
||||
}
|
||||
|
||||
ctx = taskmgr.TaskContext.new()
|
||||
|
||||
@@ -91,3 +91,26 @@ class SystemRouterGroup(group.RouterGroup):
|
||||
)
|
||||
|
||||
return self.success(data=resp)
|
||||
|
||||
@self.route(
|
||||
'/status/plugin-system',
|
||||
methods=['GET'],
|
||||
auth_type=group.AuthType.USER_TOKEN,
|
||||
)
|
||||
async def _() -> str:
|
||||
plugin_connector_error = 'ok'
|
||||
is_connected = True
|
||||
|
||||
try:
|
||||
await self.ap.plugin_connector.ping_plugin_runtime()
|
||||
except Exception as e:
|
||||
plugin_connector_error = str(e)
|
||||
is_connected = False
|
||||
|
||||
return self.success(
|
||||
data={
|
||||
'is_enable': self.ap.plugin_connector.is_enable_plugin,
|
||||
'is_connected': is_connected,
|
||||
'plugin_connector_error': plugin_connector_error,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -10,7 +10,9 @@ class DBMigratePluginConfig(migration.DBMigration):
|
||||
|
||||
if 'plugin' not in self.ap.instance_config.data:
|
||||
self.ap.instance_config.data['plugin'] = {
|
||||
'runtime_ws_url': 'ws://localhost:5400/control/ws',
|
||||
'runtime_ws_url': 'ws://langbot_plugin_runtime:5400/control/ws',
|
||||
'enable_marketplace': True,
|
||||
'cloud_service_url': 'https://space.langbot.app',
|
||||
}
|
||||
|
||||
await self.ap.instance_config.dump_config()
|
||||
|
||||
@@ -21,10 +21,15 @@ class LongTextProcessStage(stage.PipelineStage):
|
||||
- resp_message_chain
|
||||
"""
|
||||
|
||||
strategy_impl: strategy.LongTextStrategy
|
||||
strategy_impl: strategy.LongTextStrategy | None
|
||||
|
||||
async def initialize(self, pipeline_config: dict):
|
||||
config = pipeline_config['output']['long-text-processing']
|
||||
|
||||
if config['strategy'] == 'none':
|
||||
self.strategy_impl = None
|
||||
return
|
||||
|
||||
if config['strategy'] == 'image':
|
||||
use_font = config['font-path']
|
||||
try:
|
||||
@@ -67,6 +72,10 @@ class LongTextProcessStage(stage.PipelineStage):
|
||||
await self.strategy_impl.initialize()
|
||||
|
||||
async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||
if self.strategy_impl is None:
|
||||
self.ap.logger.debug('Long message processing strategy is not set, skip long message processing.')
|
||||
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
|
||||
# 检查是否包含非 Plain 组件
|
||||
contains_non_plain = False
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ class ForwardComponentStrategy(strategy_model.LongTextStrategy):
|
||||
platform_message.ForwardMessageNode(
|
||||
sender_id=query.adapter.bot_account_id,
|
||||
sender_name='User',
|
||||
message_chain=platform_message.MessageChain([message]),
|
||||
message_chain=platform_message.MessageChain([platform_message.Plain(text=message)]),
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@@ -96,7 +96,7 @@ class RuntimePipeline:
|
||||
if query.pipeline_config['output']['misc']['at-sender'] and isinstance(
|
||||
query.message_event, platform_events.GroupMessage
|
||||
):
|
||||
result.user_notice.insert(0, platform_message.At(query.message_event.sender.id))
|
||||
result.user_notice.insert(0, platform_message.At(target=query.message_event.sender.id))
|
||||
if await query.adapter.is_stream_output_supported():
|
||||
await query.adapter.reply_message_chunk(
|
||||
message_source=query.message_event,
|
||||
|
||||
@@ -92,6 +92,9 @@ class PreProcessor(stage.PipelineStage):
|
||||
if selected_runner != 'local-agent' or llm_model.model_entity.abilities.__contains__('vision'):
|
||||
if me.base64 is not None:
|
||||
content_list.append(provider_message.ContentElement.from_image_base64(me.base64))
|
||||
elif isinstance(me, platform_message.File):
|
||||
# if me.url is not None:
|
||||
content_list.append(provider_message.ContentElement.from_file_url(me.url, me.name))
|
||||
elif isinstance(me, platform_message.Quote) and qoute_msg:
|
||||
for msg in me.origin:
|
||||
if isinstance(msg, platform_message.Plain):
|
||||
|
||||
@@ -9,7 +9,6 @@ from .. import handler
|
||||
from ... import entities
|
||||
from ....provider import runner as runner_module
|
||||
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.events as events
|
||||
from ....utils import importutil
|
||||
from ....provider import runners
|
||||
@@ -47,18 +46,19 @@ class ChatMessageHandler(handler.MessageHandler):
|
||||
event_ctx = await self.ap.plugin_connector.emit_event(event)
|
||||
|
||||
is_create_card = False # 判断下是否需要创建流式卡片
|
||||
|
||||
if event_ctx.is_prevented_default():
|
||||
if event_ctx.event.reply is not None:
|
||||
mc = platform_message.MessageChain(event_ctx.event.reply)
|
||||
if event_ctx.event.reply_message_chain is not None:
|
||||
mc = event_ctx.event.reply_message_chain
|
||||
query.resp_messages.append(mc)
|
||||
|
||||
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
else:
|
||||
yield entities.StageProcessResult(result_type=entities.ResultType.INTERRUPT, new_query=query)
|
||||
else:
|
||||
if event_ctx.event.alter is not None:
|
||||
if event_ctx.event.user_message_alter is not None:
|
||||
# if isinstance(event_ctx.event, str): # 现在暂时不考虑多模态alter
|
||||
query.user_message.content = event_ctx.event.alter
|
||||
query.user_message.content = event_ctx.event.user_message_alter
|
||||
|
||||
text_length = 0
|
||||
try:
|
||||
|
||||
@@ -5,7 +5,6 @@ import typing
|
||||
from .. import handler
|
||||
from ... import entities
|
||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
import langbot_plugin.api.entities.events as events
|
||||
@@ -49,8 +48,8 @@ class CommandHandler(handler.MessageHandler):
|
||||
event_ctx = await self.ap.plugin_connector.emit_event(event)
|
||||
|
||||
if event_ctx.is_prevented_default():
|
||||
if event_ctx.event.reply is not None:
|
||||
mc = platform_message.MessageChain(event_ctx.event.reply)
|
||||
if event_ctx.event.reply_message_chain is not None:
|
||||
mc = event_ctx.event.reply_message_chain
|
||||
|
||||
query.resp_messages.append(mc)
|
||||
|
||||
@@ -59,9 +58,6 @@ class CommandHandler(handler.MessageHandler):
|
||||
yield entities.StageProcessResult(result_type=entities.ResultType.INTERRUPT, new_query=query)
|
||||
|
||||
else:
|
||||
if event_ctx.event.alter is not None:
|
||||
query.message_chain = platform_message.MessageChain([platform_message.Plain(event_ctx.event.alter)])
|
||||
|
||||
session = await self.ap.sess_mgr.get_session(query)
|
||||
|
||||
async for ret in self.ap.cmd_mgr.execute(
|
||||
@@ -78,7 +74,12 @@ class CommandHandler(handler.MessageHandler):
|
||||
self.ap.logger.info(f'Command({query.query_id}) error: {self.cut_str(str(ret.error))}')
|
||||
|
||||
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
elif ret.text is not None or ret.image_url is not None or ret.image_base64 is not None:
|
||||
elif (
|
||||
ret.text is not None
|
||||
or ret.image_url is not None
|
||||
or ret.image_base64 is not None
|
||||
or ret.file_url is not None
|
||||
):
|
||||
content: list[provider_message.ContentElement] = []
|
||||
|
||||
if ret.text is not None:
|
||||
@@ -90,6 +91,9 @@ class CommandHandler(handler.MessageHandler):
|
||||
if ret.image_base64 is not None:
|
||||
content.append(provider_message.ContentElement.from_image_base64(ret.image_base64))
|
||||
|
||||
if ret.file_url is not None:
|
||||
# 此时为 file 类型
|
||||
content.append(provider_message.ContentElement.from_file_url(ret.file_url, ret.file_name))
|
||||
query.resp_messages.append(
|
||||
provider_message.Message(
|
||||
role='command',
|
||||
|
||||
@@ -33,7 +33,7 @@ class SendResponseBackStage(stage.PipelineStage):
|
||||
if query.pipeline_config['output']['misc']['at-sender'] and isinstance(
|
||||
query.message_event, platform_events.GroupMessage
|
||||
):
|
||||
query.resp_message_chain[-1].insert(0, platform_message.At(query.message_event.sender.id))
|
||||
query.resp_message_chain[-1].insert(0, platform_message.At(target=query.message_event.sender.id))
|
||||
|
||||
quote_origin = query.pipeline_config['output']['misc']['quote-origin']
|
||||
|
||||
|
||||
@@ -16,26 +16,17 @@ class AtBotRule(rule_model.GroupRespondRule):
|
||||
rule_dict: dict,
|
||||
query: pipeline_query.Query,
|
||||
) -> entities.RuleJudgeResult:
|
||||
found = False
|
||||
|
||||
def remove_at(message_chain: platform_message.MessageChain):
|
||||
nonlocal found
|
||||
for component in message_chain.root:
|
||||
if isinstance(component, platform_message.At) and component.target == query.adapter.bot_account_id:
|
||||
if isinstance(component, platform_message.At) and str(component.target) == str(query.adapter.bot_account_id):
|
||||
message_chain.remove(component)
|
||||
found = True
|
||||
break
|
||||
|
||||
remove_at(message_chain)
|
||||
remove_at(message_chain) # 回复消息时会at两次,检查并删除重复的
|
||||
|
||||
# if message_chain.has(platform_message.At(query.adapter.bot_account_id)) and rule_dict['at']:
|
||||
# message_chain.remove(platform_message.At(query.adapter.bot_account_id))
|
||||
|
||||
# if message_chain.has(
|
||||
# platform_message.At(query.adapter.bot_account_id)
|
||||
# ): # 回复消息时会at两次,检查并删除重复的
|
||||
# message_chain.remove(platform_message.At(query.adapter.bot_account_id))
|
||||
|
||||
# return entities.RuleJudgeResult(
|
||||
# matching=True,
|
||||
# replacement=message_chain,
|
||||
# )
|
||||
|
||||
return entities.RuleJudgeResult(matching=False, replacement=message_chain)
|
||||
return entities.RuleJudgeResult(matching=found, replacement=message_chain)
|
||||
|
||||
@@ -80,8 +80,8 @@ class ResponseWrapper(stage.PipelineStage):
|
||||
new_query=query,
|
||||
)
|
||||
else:
|
||||
if event_ctx.event.reply is not None:
|
||||
query.resp_message_chain.append(platform_message.MessageChain(event_ctx.event.reply))
|
||||
if event_ctx.event.reply_message_chain is not None:
|
||||
query.resp_message_chain.append(event_ctx.event.reply_message_chain)
|
||||
|
||||
else:
|
||||
query.resp_message_chain.append(result.get_content_platform_message_chain())
|
||||
@@ -123,10 +123,8 @@ class ResponseWrapper(stage.PipelineStage):
|
||||
new_query=query,
|
||||
)
|
||||
else:
|
||||
if event_ctx.event.reply is not None:
|
||||
query.resp_message_chain.append(
|
||||
platform_message.MessageChain(text=event_ctx.event.reply)
|
||||
)
|
||||
if event_ctx.event.reply_message_chain is not None:
|
||||
query.resp_message_chain.append(event_ctx.event.reply_message_chain)
|
||||
|
||||
else:
|
||||
query.resp_message_chain.append(
|
||||
|
||||
@@ -41,6 +41,8 @@ class DingTalkMessageConverter(abstract_platform_adapter.AbstractMessageConverte
|
||||
yiri_msg_list.append(platform_message.Plain(text=text_content))
|
||||
if event.picture:
|
||||
yiri_msg_list.append(platform_message.Image(base64=event.picture))
|
||||
if event.file:
|
||||
yiri_msg_list.append(platform_message.File(url=event.file, name=event.name))
|
||||
if event.audio:
|
||||
yiri_msg_list.append(platform_message.Voice(base64=event.audio))
|
||||
|
||||
|
||||
@@ -139,19 +139,15 @@ class QQOfficialAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter
|
||||
event_converter: QQOfficialEventConverter = QQOfficialEventConverter()
|
||||
|
||||
def __init__(self, config: dict, logger: EventLogger):
|
||||
self.config = config
|
||||
self.logger = logger
|
||||
bot = QQOfficialClient(
|
||||
app_id=config['appid'], secret=config['secret'], token=config['token'], logger=logger
|
||||
)
|
||||
|
||||
required_keys = [
|
||||
'appid',
|
||||
'secret',
|
||||
]
|
||||
missing_keys = [key for key in required_keys if key not in config]
|
||||
if missing_keys:
|
||||
raise command_errors.ParamNotEnoughError('QQ官方机器人缺少相关配置项,请查看文档或联系管理员')
|
||||
|
||||
self.bot = QQOfficialClient(
|
||||
app_id=config['appid'], secret=config['secret'], token=config['token'], logger=self.logger
|
||||
super().__init__(
|
||||
config=config,
|
||||
logger=logger,
|
||||
bot=bot,
|
||||
bot_account_id=config['appid'],
|
||||
)
|
||||
|
||||
async def reply_message(
|
||||
|
||||
@@ -102,7 +102,7 @@ class TelegramEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||
sender=platform_entities.Friend(
|
||||
id=event.effective_chat.id,
|
||||
nickname=event.effective_chat.first_name,
|
||||
remark=event.effective_chat.id,
|
||||
remark=str(event.effective_chat.id),
|
||||
),
|
||||
message_chain=lb_message,
|
||||
time=event.message.date.timestamp(),
|
||||
|
||||
@@ -32,6 +32,8 @@ class PluginRuntimeConnector:
|
||||
|
||||
handler_task: asyncio.Task
|
||||
|
||||
heartbeat_task: asyncio.Task | None = None
|
||||
|
||||
stdio_client_controller: stdio_client_controller.StdioClientController
|
||||
|
||||
ctrl: stdio_client_controller.StdioClientController | ws_client_controller.WebSocketClientController
|
||||
@@ -40,6 +42,9 @@ class PluginRuntimeConnector:
|
||||
[PluginRuntimeConnector], typing.Coroutine[typing.Any, typing.Any, None]
|
||||
]
|
||||
|
||||
is_enable_plugin: bool = True
|
||||
"""Mark if the plugin system is enabled"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ap: app.Application,
|
||||
@@ -49,8 +54,22 @@ class PluginRuntimeConnector:
|
||||
):
|
||||
self.ap = ap
|
||||
self.runtime_disconnect_callback = runtime_disconnect_callback
|
||||
self.is_enable_plugin = self.ap.instance_config.data.get('plugin', {}).get('enable', True)
|
||||
|
||||
async def heartbeat_loop(self):
|
||||
while True:
|
||||
await asyncio.sleep(10)
|
||||
try:
|
||||
await self.ping_plugin_runtime()
|
||||
self.ap.logger.debug('Heartbeat to plugin runtime success.')
|
||||
except Exception as e:
|
||||
self.ap.logger.debug(f'Failed to heartbeat to plugin runtime: {e}')
|
||||
|
||||
async def initialize(self):
|
||||
if not self.is_enable_plugin:
|
||||
self.ap.logger.info('Plugin system is disabled.')
|
||||
return
|
||||
|
||||
async def new_connection_callback(connection: base_connection.Connection):
|
||||
async def disconnect_callback(rchandler: handler.RuntimeConnectionHandler) -> bool:
|
||||
if platform.get_platform() == 'docker' or platform.use_websocket_to_connect_plugin_runtime():
|
||||
@@ -64,6 +83,7 @@ class PluginRuntimeConnector:
|
||||
return False
|
||||
|
||||
self.handler = handler.RuntimeConnectionHandler(connection, disconnect_callback, self.ap)
|
||||
|
||||
self.handler_task = asyncio.create_task(self.handler.run())
|
||||
_ = await self.handler.ping()
|
||||
self.ap.logger.info('Connected to plugin runtime.')
|
||||
@@ -77,8 +97,13 @@ class PluginRuntimeConnector:
|
||||
'runtime_ws_url', 'ws://langbot_plugin_runtime:5400/control/ws'
|
||||
)
|
||||
|
||||
async def make_connection_failed_callback(ctrl: ws_client_controller.WebSocketClientController) -> None:
|
||||
self.ap.logger.error('Failed to connect to plugin runtime, trying to reconnect...')
|
||||
async def make_connection_failed_callback(
|
||||
ctrl: ws_client_controller.WebSocketClientController, exc: Exception = None
|
||||
) -> None:
|
||||
if exc is not None:
|
||||
self.ap.logger.error(f'Failed to connect to plugin runtime({ws_url}): {exc}')
|
||||
else:
|
||||
self.ap.logger.error(f'Failed to connect to plugin runtime({ws_url}), trying to reconnect...')
|
||||
await self.runtime_disconnect_callback(self)
|
||||
|
||||
self.ctrl = ws_client_controller.WebSocketClientController(
|
||||
@@ -98,17 +123,34 @@ class PluginRuntimeConnector:
|
||||
)
|
||||
task = self.ctrl.run(new_connection_callback)
|
||||
|
||||
if self.heartbeat_task is None:
|
||||
self.heartbeat_task = asyncio.create_task(self.heartbeat_loop())
|
||||
|
||||
asyncio.create_task(task)
|
||||
|
||||
async def initialize_plugins(self):
|
||||
pass
|
||||
|
||||
async def ping_plugin_runtime(self):
|
||||
if not hasattr(self, 'handler'):
|
||||
raise Exception('Plugin runtime is not connected')
|
||||
|
||||
return await self.handler.ping()
|
||||
|
||||
async def install_plugin(
|
||||
self,
|
||||
install_source: PluginInstallSource,
|
||||
install_info: dict[str, Any],
|
||||
task_context: taskmgr.TaskContext | None = None,
|
||||
):
|
||||
if install_source == PluginInstallSource.LOCAL:
|
||||
# transfer file before install
|
||||
file_bytes = install_info['plugin_file']
|
||||
file_key = await self.handler.send_file(file_bytes, 'lbpkg')
|
||||
install_info['plugin_file_key'] = file_key
|
||||
del install_info['plugin_file']
|
||||
self.ap.logger.info(f'Transfered file {file_key} to plugin runtime')
|
||||
|
||||
async for ret in self.handler.install_plugin(install_source.value, install_info):
|
||||
current_action = ret.get('current_action', None)
|
||||
if current_action is not None:
|
||||
@@ -167,6 +209,8 @@ class PluginRuntimeConnector:
|
||||
) -> context.EventContext:
|
||||
event_ctx = context.EventContext.from_event(event)
|
||||
|
||||
if not self.is_enable_plugin:
|
||||
return event_ctx
|
||||
event_ctx_result = await self.handler.emit_event(event_ctx.model_dump(serialize_as_any=True))
|
||||
|
||||
event_ctx = context.EventContext.model_validate(event_ctx_result['event_context'])
|
||||
@@ -197,6 +241,10 @@ class PluginRuntimeConnector:
|
||||
yield cmd_ret
|
||||
|
||||
def dispose(self):
|
||||
if isinstance(self.ctrl, stdio_client_controller.StdioClientController):
|
||||
if self.is_enable_plugin and isinstance(self.ctrl, stdio_client_controller.StdioClientController):
|
||||
self.ap.logger.info('Terminating plugin runtime process...')
|
||||
self.ctrl.process.terminate()
|
||||
|
||||
if self.heartbeat_task is not None:
|
||||
self.heartbeat_task.cancel()
|
||||
self.heartbeat_task = None
|
||||
|
||||
@@ -536,7 +536,7 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
{
|
||||
'event_context': event_context,
|
||||
},
|
||||
timeout=30,
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
return result
|
||||
@@ -546,7 +546,7 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
result = await self.call_action(
|
||||
LangBotToRuntimeAction.LIST_TOOLS,
|
||||
{},
|
||||
timeout=10,
|
||||
timeout=20,
|
||||
)
|
||||
|
||||
return result['tools']
|
||||
@@ -560,7 +560,18 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
'plugin_name': plugin_name,
|
||||
},
|
||||
)
|
||||
return result
|
||||
|
||||
plugin_icon_file_key = result['plugin_icon_file_key']
|
||||
mime_type = result['mime_type']
|
||||
|
||||
plugin_icon_bytes = await self.read_local_file(plugin_icon_file_key)
|
||||
|
||||
await self.delete_local_file(plugin_icon_file_key)
|
||||
|
||||
return {
|
||||
'plugin_icon_base64': base64.b64encode(plugin_icon_bytes).decode('utf-8'),
|
||||
'mime_type': mime_type,
|
||||
}
|
||||
|
||||
async def call_tool(self, tool_name: str, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Call tool"""
|
||||
@@ -570,7 +581,7 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
'tool_name': tool_name,
|
||||
'tool_parameters': parameters,
|
||||
},
|
||||
timeout=30,
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
return result['tool_response']
|
||||
@@ -591,7 +602,7 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
{
|
||||
'command_context': command_context,
|
||||
},
|
||||
timeout=30,
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
async for ret in gen:
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import dashscope
|
||||
import openai
|
||||
|
||||
from . import modelscopechatcmpl
|
||||
from .. import requester
|
||||
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
|
||||
|
||||
class BailianChatCompletions(modelscopechatcmpl.ModelScopeChatCompletions):
|
||||
@@ -15,3 +20,211 @@ class BailianChatCompletions(modelscopechatcmpl.ModelScopeChatCompletions):
|
||||
'base_url': 'https://dashscope.aliyuncs.com/compatible-mode/v1',
|
||||
'timeout': 120,
|
||||
}
|
||||
|
||||
async def _closure_stream(
|
||||
self,
|
||||
query: pipeline_query.Query,
|
||||
req_messages: list[dict],
|
||||
use_model: requester.RuntimeLLMModel,
|
||||
use_funcs: list[resource_tool.LLMTool] = None,
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
remove_think: bool = False,
|
||||
) -> provider_message.Message | typing.AsyncGenerator[provider_message.MessageChunk, None]:
|
||||
self.client.api_key = use_model.token_mgr.get_token()
|
||||
|
||||
args = {}
|
||||
args['model'] = use_model.model_entity.name
|
||||
|
||||
if use_funcs:
|
||||
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
|
||||
|
||||
if tools:
|
||||
args['tools'] = tools
|
||||
|
||||
# 设置此次请求中的messages
|
||||
messages = req_messages.copy()
|
||||
|
||||
is_use_dashscope_call = False # 是否使用阿里原生库调用
|
||||
is_enable_multi_model = True # 是否支持多轮对话
|
||||
use_time_num = 0 # 模型已调用次数,防止存在多文件时重复调用
|
||||
use_time_ids = [] # 已调用的ID列表
|
||||
message_id = 0 # 记录消息序号
|
||||
|
||||
for msg in messages:
|
||||
# print(msg)
|
||||
if 'content' in msg and isinstance(msg['content'], list):
|
||||
for me in msg['content']:
|
||||
if me['type'] == 'image_base64':
|
||||
me['image_url'] = {'url': me['image_base64']}
|
||||
me['type'] = 'image_url'
|
||||
del me['image_base64']
|
||||
elif me['type'] == 'file_url' and '.' in me.get('file_name', ''):
|
||||
# 1. 视频文件推理
|
||||
# https://bailian.console.aliyun.com/?tab=doc#/doc/?type=model&url=2845871
|
||||
file_type = me.get('file_name').lower().split('.')[-1]
|
||||
if file_type in ['mp4', 'avi', 'mkv', 'mov', 'flv', 'wmv']:
|
||||
me['type'] = 'video_url'
|
||||
me['video_url'] = {'url': me['file_url']}
|
||||
del me['file_url']
|
||||
del me['file_name']
|
||||
use_time_num +=1
|
||||
use_time_ids.append(message_id)
|
||||
is_enable_multi_model = False
|
||||
# 2. 语音文件识别, 无法通过openai的audio字段传递,暂时不支持
|
||||
# https://bailian.console.aliyun.com/?tab=doc#/doc/?type=model&url=2979031
|
||||
elif file_type in ['aac', 'amr', 'aiff', 'flac', 'm4a',
|
||||
'mp3', 'mpeg', 'ogg', 'opus', 'wav', 'webm', 'wma']:
|
||||
me['audio'] = me['file_url']
|
||||
me['type'] = 'audio'
|
||||
del me['file_url']
|
||||
del me['type']
|
||||
del me['file_name']
|
||||
is_use_dashscope_call = True
|
||||
use_time_num +=1
|
||||
use_time_ids.append(message_id)
|
||||
is_enable_multi_model = False
|
||||
message_id += 1
|
||||
|
||||
# 使用列表推导式,保留不在 use_time_ids[:-1] 中的元素,仅保留最后一个多媒体消息
|
||||
if not is_enable_multi_model and use_time_num > 1:
|
||||
messages = [msg for idx, msg in enumerate(messages) if idx not in use_time_ids[:-1]]
|
||||
|
||||
if not is_enable_multi_model:
|
||||
messages = [msg for msg in messages if 'resp_message_id' not in msg]
|
||||
|
||||
args['messages'] = messages
|
||||
args['stream'] = True
|
||||
|
||||
# 流式处理状态
|
||||
# tool_calls_map: dict[str, provider_message.ToolCall] = {}
|
||||
chunk_idx = 0
|
||||
thinking_started = False
|
||||
thinking_ended = False
|
||||
role = 'assistant' # 默认角色
|
||||
|
||||
if is_use_dashscope_call:
|
||||
response = dashscope.MultiModalConversation.call(
|
||||
# 若没有配置环境变量,请用百炼API Key将下行替换为:api_key = "sk-xxx"
|
||||
api_key=use_model.token_mgr.get_token(),
|
||||
model=use_model.model_entity.name,
|
||||
messages=messages,
|
||||
result_format="message",
|
||||
asr_options={
|
||||
# "language": "zh", # 可选,若已知音频的语种,可通过该参数指定待识别语种,以提升识别准确率
|
||||
"enable_lid": True,
|
||||
"enable_itn": False
|
||||
},
|
||||
stream=True
|
||||
)
|
||||
content_length_list = []
|
||||
previous_length = 0 # 记录上一次的内容长度
|
||||
for res in response:
|
||||
chunk = res["output"]
|
||||
# 解析 chunk 数据
|
||||
if hasattr(chunk, 'choices') and chunk.choices:
|
||||
choice = chunk.choices[0]
|
||||
delta_content = choice["message"].content[0]["text"]
|
||||
finish_reason = choice["finish_reason"]
|
||||
content_length_list.append(len(delta_content))
|
||||
else:
|
||||
delta_content = ""
|
||||
finish_reason = None
|
||||
|
||||
# 跳过空的第一个 chunk(只有 role 没有内容)
|
||||
if chunk_idx == 0 and not delta_content:
|
||||
chunk_idx += 1
|
||||
continue
|
||||
|
||||
# 检查 content_length_list 是否有足够的数据
|
||||
if len(content_length_list) >= 2:
|
||||
now_content = delta_content[previous_length: content_length_list[-1]]
|
||||
previous_length = content_length_list[-1] # 更新上一次的长度
|
||||
else:
|
||||
now_content = delta_content # 第一次循环时直接使用 delta_content
|
||||
previous_length = len(delta_content) # 更新上一次的长度
|
||||
|
||||
# 构建 MessageChunk - 只包含增量内容
|
||||
chunk_data = {
|
||||
'role': role,
|
||||
'content': now_content if now_content else None,
|
||||
'is_final': bool(finish_reason) and finish_reason != "null",
|
||||
}
|
||||
|
||||
# 移除 None 值
|
||||
chunk_data = {k: v for k, v in chunk_data.items() if v is not None}
|
||||
yield provider_message.MessageChunk(**chunk_data)
|
||||
chunk_idx += 1
|
||||
else:
|
||||
async for chunk in self._req_stream(args, extra_body=extra_args):
|
||||
# 解析 chunk 数据
|
||||
if hasattr(chunk, 'choices') and chunk.choices:
|
||||
choice = chunk.choices[0]
|
||||
delta = choice.delta.model_dump() if hasattr(choice, 'delta') else {}
|
||||
finish_reason = getattr(choice, 'finish_reason', None)
|
||||
else:
|
||||
delta = {}
|
||||
finish_reason = None
|
||||
|
||||
# 从第一个 chunk 获取 role,后续使用这个 role
|
||||
if 'role' in delta and delta['role']:
|
||||
role = delta['role']
|
||||
|
||||
# 获取增量内容
|
||||
delta_content = delta.get('content', '')
|
||||
reasoning_content = delta.get('reasoning_content', '')
|
||||
|
||||
# 处理 reasoning_content
|
||||
if reasoning_content:
|
||||
# accumulated_reasoning += reasoning_content
|
||||
# 如果设置了 remove_think,跳过 reasoning_content
|
||||
if remove_think:
|
||||
chunk_idx += 1
|
||||
continue
|
||||
|
||||
# 第一次出现 reasoning_content,添加 <think> 开始标签
|
||||
if not thinking_started:
|
||||
thinking_started = True
|
||||
delta_content = '<think>\n' + reasoning_content
|
||||
else:
|
||||
# 继续输出 reasoning_content
|
||||
delta_content = reasoning_content
|
||||
elif thinking_started and not thinking_ended and delta_content:
|
||||
# reasoning_content 结束,normal content 开始,添加 </think> 结束标签
|
||||
thinking_ended = True
|
||||
delta_content = '\n</think>\n' + delta_content
|
||||
|
||||
# 处理工具调用增量
|
||||
if delta.get('tool_calls'):
|
||||
for tool_call in delta['tool_calls']:
|
||||
if tool_call['id'] != '':
|
||||
tool_id = tool_call['id']
|
||||
if tool_call['function']['name'] is not None:
|
||||
tool_name = tool_call['function']['name']
|
||||
|
||||
if tool_call['type'] is None:
|
||||
tool_call['type'] = 'function'
|
||||
tool_call['id'] = tool_id
|
||||
tool_call['function']['name'] = tool_name
|
||||
tool_call['function']['arguments'] = (
|
||||
'' if tool_call['function']['arguments'] is None else tool_call['function']['arguments']
|
||||
)
|
||||
|
||||
# 跳过空的第一个 chunk(只有 role 没有内容)
|
||||
if chunk_idx == 0 and not delta_content and not reasoning_content and not delta.get('tool_calls'):
|
||||
chunk_idx += 1
|
||||
continue
|
||||
|
||||
# 构建 MessageChunk - 只包含增量内容
|
||||
chunk_data = {
|
||||
'role': role,
|
||||
'content': delta_content if delta_content else None,
|
||||
'tool_calls': delta.get('tool_calls'),
|
||||
'is_final': bool(finish_reason),
|
||||
}
|
||||
|
||||
# 移除 None 值
|
||||
chunk_data = {k: v for k, v in chunk_data.items() if v is not None}
|
||||
|
||||
yield provider_message.MessageChunk(**chunk_data)
|
||||
chunk_idx += 1
|
||||
# return
|
||||
|
||||
1
pkg/provider/modelmgr/requesters/tokenpony.svg
Normal file
1
pkg/provider/modelmgr/requesters/tokenpony.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="450" height="280" viewBox="0 0 450 280" class="cursor-pointer h-24 flex-shrink-0 w-149"><g fill="none" fill-rule="nonzero"><path fill="#0005DE" d="M97.705 6.742c58.844 0 90.962 34.353 90.962 98.341v21.843c-15.118-2.479-30.297-6.573-45.558-12.3v-9.543c0-35.97-15.564-56.281-45.404-56.281s-45.404 20.31-45.404 56.281v72.48c0 36.117 15.65 56.818 45.404 56.818 26.78 0 42.133-16.768 44.936-46.452q22.397 6.473 44.905 9.356c-6.15 51.52-37.492 79.155-89.841 79.155-58.678 0-90.963-34.72-90.963-98.878v-72.479c0-63.988 32.119-98.34 90.963-98.34m253.627 0c58.844 0 90.963 34.353 90.963 98.341v72.48c0 64.157-32.285 98.877-90.963 98.877-52.438 0-83.797-27.729-89.874-79.415 15-2.026 29.965-5.252 44.887-9.67 2.658 30.042 18.036 47.026 44.987 47.026 29.755 0 45.404-20.7 45.404-56.819v-72.479c0-35.97-15.564-56.281-45.404-56.281s-45.403 20.31-45.403 56.281v8.778c-15.262 5.868-30.44 10.104-45.559 12.725v-21.503c0-63.988 32.118-98.34 90.962-98.34m-164.37 140.026.57.09.831.127-.83-.128a234.5 234.5 0 0 0 35.979 2.79q18.408.002 36.858-2.928l1.401-.226a242 242 0 0 0 1.45-.244l-1.037.175q.729-.12 1.458-.247l-.421.072 1.26-.219-.84.147a244 244 0 0 0 2.8-.5l-.792.144q.648-.117 1.298-.239l-.506.094q.66-.122 1.322-.248l-.816.154q.759-.142 1.518-.289l-.702.135a247 247 0 0 0 5.364-1.084l-.463.098a250 250 0 0 0 3.928-.864l-.785.178 1.45-.33-.665.152q.597-.137 1.193-.276l-.528.123a253 253 0 0 0 3.685-.882l-.254.063q.683-.168 1.366-.34l-1.112.277q.809-.2 1.618-.405l-.506.128q.818-.206 1.634-.417l-1.128.289q.71-.18 1.419-.365l1.506-.397a259 259 0 0 0 1.804-.488l-.433.119a261 261 0 0 0 3.751-1.053l-.681.196a264 264 0 0 0 1.735-.502l-1.054.306q.636-.184 1.272-.37l-.218.064 1.238-.366-1.02.302a266 266 0 0 0 2.936-.882l-1.026.312q.71-.214 1.42-.433l-.394.121q.675-.207 1.35-.418l-.955.297q.8-.246 1.6-.499l-.645.202q.86-.269 1.72-.543l-1.076.341q.666-.21 1.33-.423l-.254.082q.833-.266 1.665-.539l-1.41.457q.874-.28 1.75-.568l-.34.111q.702-.229 1.403-.462l-1.063.351q.818-.269 1.634-.542l-.571.19a276 276 0 0 0 4.038-1.378l-.735.256q.657-.228 1.315-.46l-.58.204q16.86-5.903 33.78-14.256l-7.114-12.453 42.909 6.553-13.148 45.541-7.734-13.537q-23.832 11.94-47.755 19.504l-.199.063a298 298 0 0 1-11.65 3.412 288 288 0 0 1-10.39 2.603 280 280 0 0 1-11.677 2.431 273 273 0 0 1-11.643 1.903 263.5 263.5 0 0 1-36.858 2.599q-17.437 0-34.844-2.323l-.227-.03q-.635-.085-1.27-.174l1.497.204a268 268 0 0 1-13.673-2.182 275 275 0 0 1-12.817-2.697 282 282 0 0 1-11.859-3.057 291 291 0 0 1-7.21-2.123c-17.23-5.314-34.43-12.334-51.59-21.051l-8.258 14.455-13.148-45.541 42.909-6.553-6.594 11.544q18.421 9.24 36.776 15.572l1.316.45 1.373.462-.831-.278q.795.267 1.589.53l-.758-.252q.632.211 1.264.419l-.506-.167q.642.212 1.284.42l-.778-.253a271 271 0 0 0 3.914 1.251l-.227-.07a267 267 0 0 0 3.428 1.046l-.194-.058 1.315.389-1.121-.331q.864.256 1.73.508l-.609-.177q.826.241 1.651.478l-1.043-.3 1.307.375-.264-.075q.802.228 1.603.452l-1.34-.377q1.034.294 2.067.58l-.727-.203q.713.2 1.426.394l-.699-.192q.62.171 1.237.338l-.538-.146a259 259 0 0 0 3.977 1.051l-.66-.17q.683.177 1.367.35l-.707-.18q.687.175 1.373.348l-.666-.168q.738.186 1.475.368l-.809-.2q.716.179 1.43.353l-.621-.153a253 253 0 0 0 3.766.898l-.308-.07q.735.17 1.472.336l-1.164-.266q.747.173 1.496.34l-.332-.074q.845.19 1.69.374l-1.358-.3q.932.21 1.864.41l-.505-.11q.726.159 1.452.313l-.947-.203q.72.156 1.44.307l-.493-.104q.684.144 1.368.286l-.875-.182q.743.155 1.485.306l-.61-.124q.932.192 1.864.376l-1.254-.252q.904.184 1.809.361l-.555-.109q.752.15 1.504.293l-.95-.184q.69.135 1.377.265l-.427-.081q.784.15 1.569.295l-1.142-.214q.717.136 1.434.268l-.292-.054a244 244 0 0 0 3.808.673l-.68-.116 1.063.18-.383-.064q1.076.18 2.152.352z"></path></g></svg>
|
||||
|
After Width: | Height: | Size: 3.6 KiB |
31
pkg/provider/modelmgr/requesters/tokenpony.yaml
Normal file
31
pkg/provider/modelmgr/requesters/tokenpony.yaml
Normal file
@@ -0,0 +1,31 @@
|
||||
apiVersion: v1
|
||||
kind: LLMAPIRequester
|
||||
metadata:
|
||||
name: tokenpony-chat-completions
|
||||
label:
|
||||
en_US: TokenPony
|
||||
zh_Hans: 小马算力
|
||||
icon: tokenpony.svg
|
||||
spec:
|
||||
config:
|
||||
- name: base_url
|
||||
label:
|
||||
en_US: Base URL
|
||||
zh_Hans: 基础 URL
|
||||
type: string
|
||||
required: true
|
||||
default: "https://api.tokenpony.cn/v1"
|
||||
- name: timeout
|
||||
label:
|
||||
en_US: Timeout
|
||||
zh_Hans: 超时时间
|
||||
type: integer
|
||||
required: true
|
||||
default: 120
|
||||
support_type:
|
||||
- llm
|
||||
- text-embedding
|
||||
execution:
|
||||
python:
|
||||
path: ./tokenponychatcmpl.py
|
||||
attr: TokenPonyChatCompletions
|
||||
17
pkg/provider/modelmgr/requesters/tokenponychatcmpl.py
Normal file
17
pkg/provider/modelmgr/requesters/tokenponychatcmpl.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import openai
|
||||
|
||||
from . import chatcmpl
|
||||
|
||||
|
||||
class TokenPonyChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
"""TokenPony ChatCompletion API 请求器"""
|
||||
|
||||
client: openai.AsyncClient
|
||||
|
||||
default_config: dict[str, typing.Any] = {
|
||||
'base_url': 'https://api.tokenpony.cn/v1',
|
||||
'timeout': 120,
|
||||
}
|
||||
205
pkg/provider/runners/tboxapi.py
Normal file
205
pkg/provider/runners/tboxapi.py
Normal file
@@ -0,0 +1,205 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import json
|
||||
import base64
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
from tboxsdk.tbox import TboxClient
|
||||
from tboxsdk.model.file import File, FileType
|
||||
|
||||
from .. import runner
|
||||
from ...core import app
|
||||
from ...utils import image
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
|
||||
|
||||
class TboxAPIError(Exception):
|
||||
"""TBox API 请求失败"""
|
||||
|
||||
def __init__(self, message: str):
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
@runner.runner_class('tbox-app-api')
|
||||
class TboxAPIRunner(runner.RequestRunner):
|
||||
"蚂蚁百宝箱API对话请求器"
|
||||
|
||||
# 运行器内部使用的配置
|
||||
app_id: str # 蚂蚁百宝箱平台中的应用ID
|
||||
api_key: str # 在蚂蚁百宝箱平台中申请的令牌
|
||||
|
||||
def __init__(self, ap: app.Application, pipeline_config: dict):
|
||||
"""初始化"""
|
||||
self.ap = ap
|
||||
self.pipeline_config = pipeline_config
|
||||
|
||||
# 初始化Tbox 参数配置
|
||||
self.app_id = self.pipeline_config['ai']['tbox-app-api']['app-id']
|
||||
self.api_key = self.pipeline_config['ai']['tbox-app-api']['api-key']
|
||||
|
||||
# 初始化Tbox client
|
||||
self.tbox_client = TboxClient(authorization=self.api_key)
|
||||
|
||||
async def _preprocess_user_message(self, query: pipeline_query.Query) -> tuple[str, list[str]]:
|
||||
"""预处理用户消息,提取纯文本,并将图片上传到 Tbox 服务
|
||||
|
||||
Returns:
|
||||
tuple[str, list[str]]: 纯文本和图片的 Tbox 文件ID
|
||||
"""
|
||||
plain_text = ''
|
||||
image_ids = []
|
||||
|
||||
if isinstance(query.user_message.content, list):
|
||||
for ce in query.user_message.content:
|
||||
if ce.type == 'text':
|
||||
plain_text += ce.text
|
||||
elif ce.type == 'image_base64':
|
||||
image_b64, image_format = await image.extract_b64_and_format(ce.image_base64)
|
||||
# 创建临时文件
|
||||
file_bytes = base64.b64decode(image_b64)
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(suffix=f'.{image_format}', delete=False) as tmp_file:
|
||||
tmp_file.write(file_bytes)
|
||||
tmp_file_path = tmp_file.name
|
||||
file_upload_resp = self.tbox_client.upload_file(
|
||||
tmp_file_path
|
||||
)
|
||||
image_id = file_upload_resp.get("data", "")
|
||||
image_ids.append(image_id)
|
||||
finally:
|
||||
# 清理临时文件
|
||||
if os.path.exists(tmp_file_path):
|
||||
os.unlink(tmp_file_path)
|
||||
elif isinstance(query.user_message.content, str):
|
||||
plain_text = query.user_message.content
|
||||
|
||||
return plain_text, image_ids
|
||||
|
||||
async def _agent_messages(
|
||||
self, query: pipeline_query.Query
|
||||
) -> typing.AsyncGenerator[provider_message.Message, None]:
|
||||
"""TBox 智能体对话请求"""
|
||||
|
||||
plain_text, image_ids = await self._preprocess_user_message(query)
|
||||
remove_think = self.pipeline_config['output'].get('misc', {}).get('remove-think')
|
||||
|
||||
try:
|
||||
is_stream = await query.adapter.is_stream_output_supported()
|
||||
except AttributeError:
|
||||
is_stream = False
|
||||
|
||||
# 获取Tbox的conversation_id
|
||||
conversation_id = query.session.using_conversation.uuid or None
|
||||
|
||||
files = None
|
||||
if image_ids:
|
||||
files = [
|
||||
File(file_id=image_id, type=FileType.IMAGE)
|
||||
for image_id in image_ids
|
||||
]
|
||||
|
||||
# 发送对话请求
|
||||
response = self.tbox_client.chat(
|
||||
app_id=self.app_id, # Tbox中智能体应用的ID
|
||||
user_id=query.bot_uuid, # 用户ID
|
||||
query=plain_text, # 用户输入的文本信息
|
||||
stream=is_stream, # 是否流式输出
|
||||
conversation_id=conversation_id, # 会话ID,为None时Tbox会自动创建一个新会话
|
||||
files=files, # 图片内容
|
||||
)
|
||||
|
||||
if is_stream:
|
||||
# 解析Tbox流式输出内容,并发送给上游
|
||||
for chunk in self._process_stream_message(response, query, remove_think):
|
||||
yield chunk
|
||||
else:
|
||||
message = self._process_non_stream_message(response, query, remove_think)
|
||||
yield provider_message.Message(
|
||||
role='assistant',
|
||||
content=message,
|
||||
)
|
||||
|
||||
def _process_non_stream_message(self, response: typing.Dict, query: pipeline_query.Query, remove_think: bool):
|
||||
if response.get('errorCode') != "0":
|
||||
raise TboxAPIError(f'Tbox API 请求失败: {response.get("errorMsg", "")}')
|
||||
payload = response.get('data', {})
|
||||
conversation_id = payload.get('conversationId', '')
|
||||
query.session.using_conversation.uuid = conversation_id
|
||||
thinking_content = payload.get('reasoningContent', [])
|
||||
result = ""
|
||||
if thinking_content and not remove_think:
|
||||
result += f'<think>\n{thinking_content[0].get("text", "")}\n</think>\n'
|
||||
content = payload.get('result', [])
|
||||
if content:
|
||||
result += content[0].get('chunk', '')
|
||||
return result
|
||||
|
||||
def _process_stream_message(self, response: typing.Generator[dict], query: pipeline_query.Query, remove_think: bool):
|
||||
idx_msg = 0
|
||||
pending_content = ''
|
||||
conversation_id = None
|
||||
think_start = False
|
||||
think_end = False
|
||||
for chunk in response:
|
||||
if chunk.get('type', '') == 'chunk':
|
||||
"""
|
||||
Tbox返回的消息内容chunk结构
|
||||
{'lane': 'default', 'payload': {'conversationId': '20250918tBI947065406', 'messageId': '20250918TB1f53230954', 'text': '️'}, 'type': 'chunk'}
|
||||
"""
|
||||
# 如果包含思考过程,拼接</think>
|
||||
if think_start and not think_end:
|
||||
pending_content += '\n</think>\n'
|
||||
think_end = True
|
||||
|
||||
payload = chunk.get('payload', {})
|
||||
if not conversation_id:
|
||||
conversation_id = payload.get('conversationId')
|
||||
query.session.using_conversation.uuid = conversation_id
|
||||
if payload.get('text'):
|
||||
idx_msg += 1
|
||||
pending_content += payload.get('text')
|
||||
elif chunk.get('type', '') == 'thinking' and not remove_think:
|
||||
"""
|
||||
Tbox返回的思考过程chunk结构
|
||||
{'payload': '{"ext_data":{"text":"日期"},"event":"flow.node.llm.thinking","entity":{"node_type":"text-completion","execute_id":"6","group_id":0,"parent_execute_id":"6","node_name":"模型推理","node_id":"TC_5u6gl0"}}', 'type': 'thinking'}
|
||||
"""
|
||||
payload = json.loads(chunk.get('payload', '{}'))
|
||||
if payload.get('ext_data', {}).get('text'):
|
||||
idx_msg += 1
|
||||
content = payload.get('ext_data', {}).get('text')
|
||||
if not think_start:
|
||||
think_start = True
|
||||
pending_content += f'<think>\n{content}'
|
||||
else:
|
||||
pending_content += content
|
||||
elif chunk.get('type', '') == 'error':
|
||||
raise TboxAPIError(
|
||||
f'Tbox API 请求失败: status_code={chunk.get("status_code")} message={chunk.get("message")} request_id={chunk.get("request_id")} '
|
||||
)
|
||||
|
||||
if idx_msg % 8 == 0:
|
||||
yield provider_message.MessageChunk(
|
||||
role='assistant',
|
||||
content=pending_content,
|
||||
is_final=False,
|
||||
)
|
||||
|
||||
# Tbox不返回END事件,默认发一个最终消息
|
||||
yield provider_message.MessageChunk(
|
||||
role='assistant',
|
||||
content=pending_content,
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
async def run(self, query: pipeline_query.Query) -> typing.AsyncGenerator[provider_message.Message, None]:
|
||||
"""运行"""
|
||||
msg_seq = 0
|
||||
async for msg in self._agent_messages(query):
|
||||
if isinstance(msg, provider_message.MessageChunk):
|
||||
msg_seq += 1
|
||||
msg.msg_sequence = msg_seq
|
||||
yield msg
|
||||
@@ -1,4 +1,4 @@
|
||||
semantic_version = 'v4.3.1'
|
||||
semantic_version = 'v4.3.7b1'
|
||||
|
||||
required_database_version = 8
|
||||
"""Tag the version of the database schema, used to check if the database needs to be migrated"""
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "langbot"
|
||||
version = "4.3.1"
|
||||
version = "4.3.7b1"
|
||||
description = "Easy-to-use global IM bot platform designed for LLM era"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10.1,<4.0"
|
||||
@@ -62,9 +62,10 @@ dependencies = [
|
||||
"langchain>=0.2.0",
|
||||
"chromadb>=0.4.24",
|
||||
"qdrant-client (>=1.15.1,<2.0.0)",
|
||||
"langbot-plugin==0.1.1",
|
||||
"langbot-plugin==0.1.4b2",
|
||||
"asyncpg>=0.30.0",
|
||||
"line-bot-sdk>=3.19.0"
|
||||
"line-bot-sdk>=3.19.0",
|
||||
"tboxsdk>=0.0.10",
|
||||
]
|
||||
keywords = [
|
||||
"bot",
|
||||
@@ -102,6 +103,7 @@ dev = [
|
||||
"pre-commit>=4.2.0",
|
||||
"pytest>=8.4.1",
|
||||
"pytest-asyncio>=1.0.0",
|
||||
"pytest-cov>=7.0.0",
|
||||
"ruff>=0.11.9",
|
||||
]
|
||||
|
||||
|
||||
39
pytest.ini
Normal file
39
pytest.ini
Normal file
@@ -0,0 +1,39 @@
|
||||
[pytest]
|
||||
# Test discovery patterns
|
||||
python_files = test_*.py
|
||||
python_classes = Test*
|
||||
python_functions = test_*
|
||||
|
||||
# Test paths
|
||||
testpaths = tests
|
||||
|
||||
# Asyncio configuration
|
||||
asyncio_mode = auto
|
||||
|
||||
# Output options
|
||||
addopts =
|
||||
-v
|
||||
--strict-markers
|
||||
--tb=short
|
||||
--disable-warnings
|
||||
|
||||
# Markers
|
||||
markers =
|
||||
asyncio: mark test as async
|
||||
unit: mark test as unit test
|
||||
integration: mark test as integration test
|
||||
slow: mark test as slow running
|
||||
|
||||
# Coverage options (when using pytest-cov)
|
||||
[coverage:run]
|
||||
source = pkg
|
||||
omit =
|
||||
*/tests/*
|
||||
*/test_*.py
|
||||
*/__pycache__/*
|
||||
*/site-packages/*
|
||||
|
||||
[coverage:report]
|
||||
precision = 2
|
||||
show_missing = True
|
||||
skip_covered = False
|
||||
31
run_tests.sh
Executable file
31
run_tests.sh
Executable file
@@ -0,0 +1,31 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Script to run all unit tests
|
||||
# This script helps avoid circular import issues by setting up the environment properly
|
||||
|
||||
set -e
|
||||
|
||||
echo "Setting up test environment..."
|
||||
|
||||
# Activate virtual environment if it exists
|
||||
if [ -d ".venv" ]; then
|
||||
source .venv/bin/activate
|
||||
fi
|
||||
|
||||
# Check if pytest is installed
|
||||
if ! command -v pytest &> /dev/null; then
|
||||
echo "Installing test dependencies..."
|
||||
pip install pytest pytest-asyncio pytest-cov
|
||||
fi
|
||||
|
||||
echo "Running all unit tests..."
|
||||
|
||||
# Run tests with coverage
|
||||
pytest tests/unit_tests/ -v --tb=short \
|
||||
--cov=pkg \
|
||||
--cov-report=xml \
|
||||
"$@"
|
||||
|
||||
echo ""
|
||||
echo "Test run complete!"
|
||||
echo "Coverage report saved to coverage.xml"
|
||||
@@ -38,6 +38,7 @@ vdb:
|
||||
port: 6333
|
||||
api_key: ''
|
||||
plugin:
|
||||
enable: true
|
||||
runtime_ws_url: 'ws://langbot_plugin_runtime:5400/control/ws'
|
||||
enable_marketplace: true
|
||||
cloud_service_url: 'https://space.langbot.app'
|
||||
@@ -83,7 +83,7 @@
|
||||
"output": {
|
||||
"long-text-processing": {
|
||||
"threshold": 1000,
|
||||
"strategy": "forward",
|
||||
"strategy": "none",
|
||||
"font-path": ""
|
||||
},
|
||||
"force-delay": {
|
||||
|
||||
@@ -23,6 +23,10 @@ stages:
|
||||
label:
|
||||
en_US: Local Agent
|
||||
zh_Hans: 内置 Agent
|
||||
- name: tbox-app-api
|
||||
label:
|
||||
en_US: Tbox App API
|
||||
zh_Hans: 蚂蚁百宝箱平台 API
|
||||
- name: dify-service-api
|
||||
label:
|
||||
en_US: Dify Service API
|
||||
@@ -82,6 +86,26 @@ stages:
|
||||
type: knowledge-base-selector
|
||||
required: false
|
||||
default: ''
|
||||
- name: tbox-app-api
|
||||
label:
|
||||
en_US: Tbox App API
|
||||
zh_Hans: 蚂蚁百宝箱平台 API
|
||||
description:
|
||||
en_US: Configure the Tbox App API of the pipeline
|
||||
zh_Hans: 配置蚂蚁百宝箱平台 API
|
||||
config:
|
||||
- name: api-key
|
||||
label:
|
||||
en_US: API Key
|
||||
zh_Hans: API 密钥
|
||||
type: string
|
||||
required: true
|
||||
- name: app-id
|
||||
label:
|
||||
en_US: App ID
|
||||
zh_Hans: 应用 ID
|
||||
type: string
|
||||
required: true
|
||||
- name: dify-service-api
|
||||
label:
|
||||
en_US: Dify Service API
|
||||
|
||||
@@ -27,7 +27,7 @@ stages:
|
||||
zh_Hans: 长文本的处理策略
|
||||
type: select
|
||||
required: true
|
||||
default: forward
|
||||
default: none
|
||||
options:
|
||||
- name: forward
|
||||
label:
|
||||
@@ -37,6 +37,10 @@ stages:
|
||||
label:
|
||||
en_US: Convert to Image
|
||||
zh_Hans: 转换为图片
|
||||
- name: none
|
||||
label:
|
||||
en_US: None
|
||||
zh_Hans: 不处理
|
||||
- name: font-path
|
||||
label:
|
||||
en_US: Font Path
|
||||
|
||||
183
tests/README.md
Normal file
183
tests/README.md
Normal file
@@ -0,0 +1,183 @@
|
||||
# LangBot Test Suite
|
||||
|
||||
This directory contains the test suite for LangBot, with a focus on comprehensive unit testing of pipeline stages.
|
||||
|
||||
## Important Note
|
||||
|
||||
Due to circular import dependencies in the pipeline module structure, the test files use **lazy imports** via `importlib.import_module()` instead of direct imports. This ensures tests can run without triggering circular import errors.
|
||||
|
||||
## Structure
|
||||
|
||||
```
|
||||
tests/
|
||||
├── pipeline/ # Pipeline stage tests
|
||||
│ ├── conftest.py # Shared fixtures and test infrastructure
|
||||
│ ├── test_simple.py # Basic infrastructure tests (always pass)
|
||||
│ ├── test_bansess.py # BanSessionCheckStage tests
|
||||
│ ├── test_ratelimit.py # RateLimit stage tests
|
||||
│ ├── test_preproc.py # PreProcessor stage tests
|
||||
│ ├── test_respback.py # SendResponseBackStage tests
|
||||
│ ├── test_resprule.py # GroupRespondRuleCheckStage tests
|
||||
│ ├── test_pipelinemgr.py # PipelineManager tests
|
||||
│ └── test_stages_integration.py # Integration tests
|
||||
└── README.md # This file
|
||||
```
|
||||
|
||||
## Test Architecture
|
||||
|
||||
### Fixtures (`conftest.py`)
|
||||
|
||||
The test suite uses a centralized fixture system that provides:
|
||||
|
||||
- **MockApplication**: Comprehensive mock of the Application object with all dependencies
|
||||
- **Mock objects**: Pre-configured mocks for Session, Conversation, Model, Adapter
|
||||
- **Sample data**: Ready-to-use Query objects, message chains, and configurations
|
||||
- **Helper functions**: Utilities for creating results and common assertions
|
||||
|
||||
### Design Principles
|
||||
|
||||
1. **Isolation**: Each test is independent and doesn't rely on external systems
|
||||
2. **Mocking**: All external dependencies are mocked to ensure fast, reliable tests
|
||||
3. **Coverage**: Tests cover happy paths, edge cases, and error conditions
|
||||
4. **Extensibility**: Easy to add new tests by reusing existing fixtures
|
||||
|
||||
## Running Tests
|
||||
|
||||
### Using the test runner script (recommended)
|
||||
```bash
|
||||
bash run_tests.sh
|
||||
```
|
||||
|
||||
This script automatically:
|
||||
- Activates the virtual environment
|
||||
- Installs test dependencies if needed
|
||||
- Runs tests with coverage
|
||||
- Generates HTML coverage report
|
||||
|
||||
### Manual test execution
|
||||
|
||||
#### Run all tests
|
||||
```bash
|
||||
pytest tests/pipeline/
|
||||
```
|
||||
|
||||
#### Run only simple tests (no imports, always pass)
|
||||
```bash
|
||||
pytest tests/pipeline/test_simple.py -v
|
||||
```
|
||||
|
||||
#### Run specific test file
|
||||
```bash
|
||||
pytest tests/pipeline/test_bansess.py -v
|
||||
```
|
||||
|
||||
#### Run with coverage
|
||||
```bash
|
||||
pytest tests/pipeline/ --cov=pkg/pipeline --cov-report=html
|
||||
```
|
||||
|
||||
#### Run specific test
|
||||
```bash
|
||||
pytest tests/pipeline/test_bansess.py::test_bansess_whitelist_allow -v
|
||||
```
|
||||
|
||||
### Known Issues
|
||||
|
||||
Some tests may encounter circular import errors. This is a known issue with the current module structure. The test infrastructure is designed to work around this using lazy imports, but if you encounter issues:
|
||||
|
||||
1. Make sure you're running from the project root directory
|
||||
2. Ensure the virtual environment is activated
|
||||
3. Try running `test_simple.py` first to verify the test infrastructure works
|
||||
|
||||
## CI/CD Integration
|
||||
|
||||
Tests are automatically run on:
|
||||
- Pull request opened
|
||||
- Pull request marked ready for review
|
||||
- Push to PR branch
|
||||
- Push to master/develop branches
|
||||
|
||||
The workflow runs tests on Python 3.10, 3.11, and 3.12 to ensure compatibility.
|
||||
|
||||
## Adding New Tests
|
||||
|
||||
### 1. For a new pipeline stage
|
||||
|
||||
Create a new test file `test_<stage_name>.py`:
|
||||
|
||||
```python
|
||||
"""
|
||||
<StageName> stage unit tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pkg.pipeline.<module>.<stage> import <StageClass>
|
||||
from pkg.pipeline import entities as pipeline_entities
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stage_basic_flow(mock_app, sample_query):
|
||||
"""Test basic flow"""
|
||||
stage = <StageClass>(mock_app)
|
||||
await stage.initialize({})
|
||||
|
||||
result = await stage.process(sample_query, '<StageName>')
|
||||
|
||||
assert result.result_type == pipeline_entities.ResultType.CONTINUE
|
||||
```
|
||||
|
||||
### 2. For additional fixtures
|
||||
|
||||
Add new fixtures to `conftest.py`:
|
||||
|
||||
```python
|
||||
@pytest.fixture
|
||||
def my_custom_fixture():
|
||||
"""Description of fixture"""
|
||||
return create_test_data()
|
||||
```
|
||||
|
||||
### 3. For test data
|
||||
|
||||
Use the helper functions in `conftest.py`:
|
||||
|
||||
```python
|
||||
from tests.pipeline.conftest import create_stage_result, assert_result_continue
|
||||
|
||||
result = create_stage_result(
|
||||
result_type=pipeline_entities.ResultType.CONTINUE,
|
||||
query=sample_query
|
||||
)
|
||||
|
||||
assert_result_continue(result)
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Test naming**: Use descriptive names that explain what's being tested
|
||||
2. **Arrange-Act-Assert**: Structure tests clearly with setup, execution, and verification
|
||||
3. **One assertion per test**: Focus each test on a single behavior
|
||||
4. **Mock appropriately**: Mock external dependencies, not the code under test
|
||||
5. **Use fixtures**: Reuse common test data through fixtures
|
||||
6. **Document tests**: Add docstrings explaining what each test validates
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Import errors
|
||||
Make sure you've installed the package in development mode:
|
||||
```bash
|
||||
uv pip install -e .
|
||||
```
|
||||
|
||||
### Async test failures
|
||||
Ensure you're using `@pytest.mark.asyncio` decorator for async tests.
|
||||
|
||||
### Mock not working
|
||||
Check that you're mocking at the right level and using `AsyncMock` for async functions.
|
||||
|
||||
## Future Enhancements
|
||||
|
||||
- [ ] Add integration tests for full pipeline execution
|
||||
- [ ] Add performance benchmarks
|
||||
- [ ] Add mutation testing for better coverage quality
|
||||
- [ ] Add property-based testing with Hypothesis
|
||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
0
tests/unit_tests/__init__.py
Normal file
0
tests/unit_tests/__init__.py
Normal file
0
tests/unit_tests/pipeline/__init__.py
Normal file
0
tests/unit_tests/pipeline/__init__.py
Normal file
251
tests/unit_tests/pipeline/conftest.py
Normal file
251
tests/unit_tests/pipeline/conftest.py
Normal file
@@ -0,0 +1,251 @@
|
||||
"""
|
||||
Shared test fixtures and configuration
|
||||
|
||||
This file provides infrastructure for all pipeline tests, including:
|
||||
- Mock object factories
|
||||
- Test fixtures
|
||||
- Common test helper functions
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock
|
||||
from typing import Any
|
||||
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
|
||||
from pkg.pipeline import entities as pipeline_entities
|
||||
|
||||
|
||||
class MockApplication:
|
||||
"""Mock Application object providing all basic dependencies needed by stages"""
|
||||
|
||||
def __init__(self):
|
||||
self.logger = self._create_mock_logger()
|
||||
self.sess_mgr = self._create_mock_session_manager()
|
||||
self.model_mgr = self._create_mock_model_manager()
|
||||
self.tool_mgr = self._create_mock_tool_manager()
|
||||
self.plugin_connector = self._create_mock_plugin_connector()
|
||||
self.persistence_mgr = self._create_mock_persistence_manager()
|
||||
self.query_pool = self._create_mock_query_pool()
|
||||
self.instance_config = self._create_mock_instance_config()
|
||||
self.task_mgr = self._create_mock_task_manager()
|
||||
|
||||
def _create_mock_logger(self):
|
||||
logger = Mock()
|
||||
logger.debug = Mock()
|
||||
logger.info = Mock()
|
||||
logger.error = Mock()
|
||||
logger.warning = Mock()
|
||||
return logger
|
||||
|
||||
def _create_mock_session_manager(self):
|
||||
sess_mgr = AsyncMock()
|
||||
sess_mgr.get_session = AsyncMock()
|
||||
sess_mgr.get_conversation = AsyncMock()
|
||||
return sess_mgr
|
||||
|
||||
def _create_mock_model_manager(self):
|
||||
model_mgr = AsyncMock()
|
||||
model_mgr.get_model_by_uuid = AsyncMock()
|
||||
return model_mgr
|
||||
|
||||
def _create_mock_tool_manager(self):
|
||||
tool_mgr = AsyncMock()
|
||||
tool_mgr.get_all_tools = AsyncMock(return_value=[])
|
||||
return tool_mgr
|
||||
|
||||
def _create_mock_plugin_connector(self):
|
||||
plugin_connector = AsyncMock()
|
||||
plugin_connector.emit_event = AsyncMock()
|
||||
return plugin_connector
|
||||
|
||||
def _create_mock_persistence_manager(self):
|
||||
persistence_mgr = AsyncMock()
|
||||
persistence_mgr.execute_async = AsyncMock()
|
||||
return persistence_mgr
|
||||
|
||||
def _create_mock_query_pool(self):
|
||||
query_pool = Mock()
|
||||
query_pool.cached_queries = {}
|
||||
query_pool.queries = []
|
||||
query_pool.condition = AsyncMock()
|
||||
return query_pool
|
||||
|
||||
def _create_mock_instance_config(self):
|
||||
instance_config = Mock()
|
||||
instance_config.data = {
|
||||
'command': {'prefix': ['/', '!'], 'enable': True},
|
||||
'concurrency': {'pipeline': 10},
|
||||
}
|
||||
return instance_config
|
||||
|
||||
def _create_mock_task_manager(self):
|
||||
task_mgr = Mock()
|
||||
task_mgr.create_task = Mock()
|
||||
return task_mgr
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app():
|
||||
"""Provides Mock Application instance"""
|
||||
return MockApplication()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session():
|
||||
"""Provides Mock Session object"""
|
||||
session = Mock()
|
||||
session.launcher_type = provider_session.LauncherTypes.PERSON
|
||||
session.launcher_id = 12345
|
||||
session._semaphore = AsyncMock()
|
||||
session._semaphore.locked = Mock(return_value=False)
|
||||
session._semaphore.acquire = AsyncMock()
|
||||
session._semaphore.release = AsyncMock()
|
||||
return session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_conversation():
|
||||
"""Provides Mock Conversation object"""
|
||||
conversation = Mock()
|
||||
conversation.uuid = 'test-conversation-uuid'
|
||||
|
||||
# Create mock prompt with copy method
|
||||
mock_prompt = Mock()
|
||||
mock_prompt.messages = []
|
||||
mock_prompt.copy = Mock(return_value=Mock(messages=[]))
|
||||
conversation.prompt = mock_prompt
|
||||
|
||||
# Create mock messages list with copy method
|
||||
mock_messages = Mock()
|
||||
mock_messages.copy = Mock(return_value=[])
|
||||
conversation.messages = mock_messages
|
||||
|
||||
return conversation
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_model():
|
||||
"""Provides Mock Model object"""
|
||||
model = Mock()
|
||||
model.model_entity = Mock()
|
||||
model.model_entity.uuid = 'test-model-uuid'
|
||||
model.model_entity.abilities = ['func_call', 'vision']
|
||||
return model
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_adapter():
|
||||
"""Provides Mock Adapter object"""
|
||||
adapter = AsyncMock()
|
||||
adapter.is_stream_output_supported = AsyncMock(return_value=False)
|
||||
adapter.reply_message = AsyncMock()
|
||||
adapter.reply_message_chunk = AsyncMock()
|
||||
return adapter
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_message_chain():
|
||||
"""Provides sample message chain"""
|
||||
return platform_message.MessageChain(
|
||||
[
|
||||
platform_message.Plain(text='Hello, this is a test message'),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_message_event(sample_message_chain):
|
||||
"""Provides sample message event"""
|
||||
event = Mock()
|
||||
event.sender = Mock()
|
||||
event.sender.id = 12345
|
||||
event.time = 1609459200 # 2021-01-01 00:00:00
|
||||
return event
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_query(sample_message_chain, sample_message_event, mock_adapter):
|
||||
"""Provides sample Query object - using model_construct to bypass validation"""
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
|
||||
# Use model_construct to bypass Pydantic validation for test purposes
|
||||
query = pipeline_query.Query.model_construct(
|
||||
query_id='test-query-id',
|
||||
launcher_type=provider_session.LauncherTypes.PERSON,
|
||||
launcher_id=12345,
|
||||
sender_id=12345,
|
||||
message_chain=sample_message_chain,
|
||||
message_event=sample_message_event,
|
||||
adapter=mock_adapter,
|
||||
pipeline_uuid='test-pipeline-uuid',
|
||||
bot_uuid='test-bot-uuid',
|
||||
pipeline_config={
|
||||
'ai': {
|
||||
'runner': {'runner': 'local-agent'},
|
||||
'local-agent': {'model': 'test-model-uuid', 'prompt': 'test-prompt'},
|
||||
},
|
||||
'output': {'misc': {'at-sender': False, 'quote-origin': False}},
|
||||
'trigger': {'misc': {'combine-quote-message': False}},
|
||||
},
|
||||
session=None,
|
||||
prompt=None,
|
||||
messages=[],
|
||||
user_message=None,
|
||||
use_funcs=[],
|
||||
use_llm_model_uuid=None,
|
||||
variables={},
|
||||
resp_messages=[],
|
||||
resp_message_chain=None,
|
||||
current_stage_name=None
|
||||
)
|
||||
return query
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_pipeline_config():
|
||||
"""Provides sample pipeline configuration"""
|
||||
return {
|
||||
'ai': {
|
||||
'runner': {'runner': 'local-agent'},
|
||||
'local-agent': {'model': 'test-model-uuid', 'prompt': 'test-prompt'},
|
||||
},
|
||||
'output': {'misc': {'at-sender': False, 'quote-origin': False}},
|
||||
'trigger': {'misc': {'combine-quote-message': False}},
|
||||
'ratelimit': {'enable': True, 'algo': 'fixwin', 'window': 60, 'limit': 10},
|
||||
}
|
||||
|
||||
|
||||
def create_stage_result(
|
||||
result_type: pipeline_entities.ResultType,
|
||||
query: pipeline_query.Query,
|
||||
user_notice: str = '',
|
||||
console_notice: str = '',
|
||||
debug_notice: str = '',
|
||||
error_notice: str = '',
|
||||
) -> pipeline_entities.StageProcessResult:
|
||||
"""Helper function to create stage process result"""
|
||||
return pipeline_entities.StageProcessResult(
|
||||
result_type=result_type,
|
||||
new_query=query,
|
||||
user_notice=user_notice,
|
||||
console_notice=console_notice,
|
||||
debug_notice=debug_notice,
|
||||
error_notice=error_notice,
|
||||
)
|
||||
|
||||
|
||||
def assert_result_continue(result: pipeline_entities.StageProcessResult):
|
||||
"""Assert result is CONTINUE type"""
|
||||
assert result.result_type == pipeline_entities.ResultType.CONTINUE
|
||||
|
||||
|
||||
def assert_result_interrupt(result: pipeline_entities.StageProcessResult):
|
||||
"""Assert result is INTERRUPT type"""
|
||||
assert result.result_type == pipeline_entities.ResultType.INTERRUPT
|
||||
189
tests/unit_tests/pipeline/test_bansess.py
Normal file
189
tests/unit_tests/pipeline/test_bansess.py
Normal file
@@ -0,0 +1,189 @@
|
||||
"""
|
||||
BanSessionCheckStage unit tests
|
||||
|
||||
Tests the actual BanSessionCheckStage implementation from pkg.pipeline.bansess
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock
|
||||
from importlib import import_module
|
||||
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||
|
||||
|
||||
def get_modules():
|
||||
"""Lazy import to ensure proper initialization order"""
|
||||
# Import pipelinemgr first to trigger proper stage registration
|
||||
pipelinemgr = import_module('pkg.pipeline.pipelinemgr')
|
||||
bansess = import_module('pkg.pipeline.bansess.bansess')
|
||||
entities = import_module('pkg.pipeline.entities')
|
||||
return bansess, entities
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_whitelist_allow(mock_app, sample_query):
|
||||
"""Test whitelist allows matching session"""
|
||||
bansess, entities = get_modules()
|
||||
|
||||
sample_query.launcher_type = provider_session.LauncherTypes.PERSON
|
||||
sample_query.launcher_id = '12345'
|
||||
sample_query.pipeline_config = {
|
||||
'trigger': {
|
||||
'access-control': {
|
||||
'mode': 'whitelist',
|
||||
'whitelist': ['person_12345']
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
stage = bansess.BanSessionCheckStage(mock_app)
|
||||
await stage.initialize(sample_query.pipeline_config)
|
||||
|
||||
result = await stage.process(sample_query, 'BanSessionCheckStage')
|
||||
|
||||
assert result.result_type == entities.ResultType.CONTINUE
|
||||
assert result.new_query == sample_query
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_whitelist_deny(mock_app, sample_query):
|
||||
"""Test whitelist denies non-matching session"""
|
||||
bansess, entities = get_modules()
|
||||
|
||||
sample_query.launcher_type = provider_session.LauncherTypes.PERSON
|
||||
sample_query.launcher_id = '99999'
|
||||
sample_query.pipeline_config = {
|
||||
'trigger': {
|
||||
'access-control': {
|
||||
'mode': 'whitelist',
|
||||
'whitelist': ['person_12345']
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
stage = bansess.BanSessionCheckStage(mock_app)
|
||||
await stage.initialize(sample_query.pipeline_config)
|
||||
|
||||
result = await stage.process(sample_query, 'BanSessionCheckStage')
|
||||
|
||||
assert result.result_type == entities.ResultType.INTERRUPT
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_blacklist_allow(mock_app, sample_query):
|
||||
"""Test blacklist allows non-matching session"""
|
||||
bansess, entities = get_modules()
|
||||
|
||||
sample_query.launcher_type = provider_session.LauncherTypes.PERSON
|
||||
sample_query.launcher_id = '12345'
|
||||
sample_query.pipeline_config = {
|
||||
'trigger': {
|
||||
'access-control': {
|
||||
'mode': 'blacklist',
|
||||
'blacklist': ['person_99999']
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
stage = bansess.BanSessionCheckStage(mock_app)
|
||||
await stage.initialize(sample_query.pipeline_config)
|
||||
|
||||
result = await stage.process(sample_query, 'BanSessionCheckStage')
|
||||
|
||||
assert result.result_type == entities.ResultType.CONTINUE
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_blacklist_deny(mock_app, sample_query):
|
||||
"""Test blacklist denies matching session"""
|
||||
bansess, entities = get_modules()
|
||||
|
||||
sample_query.launcher_type = provider_session.LauncherTypes.PERSON
|
||||
sample_query.launcher_id = '12345'
|
||||
sample_query.pipeline_config = {
|
||||
'trigger': {
|
||||
'access-control': {
|
||||
'mode': 'blacklist',
|
||||
'blacklist': ['person_12345']
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
stage = bansess.BanSessionCheckStage(mock_app)
|
||||
await stage.initialize(sample_query.pipeline_config)
|
||||
|
||||
result = await stage.process(sample_query, 'BanSessionCheckStage')
|
||||
|
||||
assert result.result_type == entities.ResultType.INTERRUPT
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wildcard_group(mock_app, sample_query):
|
||||
"""Test group wildcard matching"""
|
||||
bansess, entities = get_modules()
|
||||
|
||||
sample_query.launcher_type = provider_session.LauncherTypes.GROUP
|
||||
sample_query.launcher_id = '12345'
|
||||
sample_query.pipeline_config = {
|
||||
'trigger': {
|
||||
'access-control': {
|
||||
'mode': 'whitelist',
|
||||
'whitelist': ['group_*']
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
stage = bansess.BanSessionCheckStage(mock_app)
|
||||
await stage.initialize(sample_query.pipeline_config)
|
||||
|
||||
result = await stage.process(sample_query, 'BanSessionCheckStage')
|
||||
|
||||
assert result.result_type == entities.ResultType.CONTINUE
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wildcard_person(mock_app, sample_query):
|
||||
"""Test person wildcard matching"""
|
||||
bansess, entities = get_modules()
|
||||
|
||||
sample_query.launcher_type = provider_session.LauncherTypes.PERSON
|
||||
sample_query.launcher_id = '12345'
|
||||
sample_query.pipeline_config = {
|
||||
'trigger': {
|
||||
'access-control': {
|
||||
'mode': 'whitelist',
|
||||
'whitelist': ['person_*']
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
stage = bansess.BanSessionCheckStage(mock_app)
|
||||
await stage.initialize(sample_query.pipeline_config)
|
||||
|
||||
result = await stage.process(sample_query, 'BanSessionCheckStage')
|
||||
|
||||
assert result.result_type == entities.ResultType.CONTINUE
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_id_wildcard(mock_app, sample_query):
|
||||
"""Test user ID wildcard matching (*_id format)"""
|
||||
bansess, entities = get_modules()
|
||||
|
||||
sample_query.launcher_type = provider_session.LauncherTypes.PERSON
|
||||
sample_query.launcher_id = '12345'
|
||||
sample_query.sender_id = '67890'
|
||||
sample_query.pipeline_config = {
|
||||
'trigger': {
|
||||
'access-control': {
|
||||
'mode': 'whitelist',
|
||||
'whitelist': ['*_67890']
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
stage = bansess.BanSessionCheckStage(mock_app)
|
||||
await stage.initialize(sample_query.pipeline_config)
|
||||
|
||||
result = await stage.process(sample_query, 'BanSessionCheckStage')
|
||||
|
||||
assert result.result_type == entities.ResultType.CONTINUE
|
||||
166
tests/unit_tests/pipeline/test_pipelinemgr.py
Normal file
166
tests/unit_tests/pipeline/test_pipelinemgr.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""
|
||||
PipelineManager unit tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
from importlib import import_module
|
||||
import sqlalchemy
|
||||
|
||||
|
||||
def get_pipelinemgr_module():
|
||||
return import_module('pkg.pipeline.pipelinemgr')
|
||||
|
||||
|
||||
def get_stage_module():
|
||||
return import_module('pkg.pipeline.stage')
|
||||
|
||||
|
||||
def get_entities_module():
|
||||
return import_module('pkg.pipeline.entities')
|
||||
|
||||
|
||||
def get_persistence_pipeline_module():
|
||||
return import_module('pkg.entity.persistence.pipeline')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_manager_initialize(mock_app):
|
||||
"""Test pipeline manager initialization"""
|
||||
pipelinemgr = get_pipelinemgr_module()
|
||||
|
||||
mock_app.persistence_mgr.execute_async = AsyncMock(return_value=Mock(all=Mock(return_value=[])))
|
||||
|
||||
manager = pipelinemgr.PipelineManager(mock_app)
|
||||
await manager.initialize()
|
||||
|
||||
assert manager.stage_dict is not None
|
||||
assert len(manager.pipelines) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_pipeline(mock_app):
|
||||
"""Test loading a single pipeline"""
|
||||
pipelinemgr = get_pipelinemgr_module()
|
||||
persistence_pipeline = get_persistence_pipeline_module()
|
||||
|
||||
mock_app.persistence_mgr.execute_async = AsyncMock(return_value=Mock(all=Mock(return_value=[])))
|
||||
|
||||
manager = pipelinemgr.PipelineManager(mock_app)
|
||||
await manager.initialize()
|
||||
|
||||
# Create test pipeline entity
|
||||
pipeline_entity = Mock(spec=persistence_pipeline.LegacyPipeline)
|
||||
pipeline_entity.uuid = 'test-uuid'
|
||||
pipeline_entity.stages = []
|
||||
pipeline_entity.config = {'test': 'config'}
|
||||
|
||||
await manager.load_pipeline(pipeline_entity)
|
||||
|
||||
assert len(manager.pipelines) == 1
|
||||
assert manager.pipelines[0].pipeline_entity.uuid == 'test-uuid'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_pipeline_by_uuid(mock_app):
|
||||
"""Test getting pipeline by UUID"""
|
||||
pipelinemgr = get_pipelinemgr_module()
|
||||
persistence_pipeline = get_persistence_pipeline_module()
|
||||
|
||||
mock_app.persistence_mgr.execute_async = AsyncMock(return_value=Mock(all=Mock(return_value=[])))
|
||||
|
||||
manager = pipelinemgr.PipelineManager(mock_app)
|
||||
await manager.initialize()
|
||||
|
||||
# Create and add test pipeline
|
||||
pipeline_entity = Mock(spec=persistence_pipeline.LegacyPipeline)
|
||||
pipeline_entity.uuid = 'test-uuid'
|
||||
pipeline_entity.stages = []
|
||||
pipeline_entity.config = {}
|
||||
|
||||
await manager.load_pipeline(pipeline_entity)
|
||||
|
||||
# Test retrieval
|
||||
result = await manager.get_pipeline_by_uuid('test-uuid')
|
||||
assert result is not None
|
||||
assert result.pipeline_entity.uuid == 'test-uuid'
|
||||
|
||||
# Test non-existent UUID
|
||||
result = await manager.get_pipeline_by_uuid('non-existent')
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_pipeline(mock_app):
|
||||
"""Test removing a pipeline"""
|
||||
pipelinemgr = get_pipelinemgr_module()
|
||||
persistence_pipeline = get_persistence_pipeline_module()
|
||||
|
||||
mock_app.persistence_mgr.execute_async = AsyncMock(return_value=Mock(all=Mock(return_value=[])))
|
||||
|
||||
manager = pipelinemgr.PipelineManager(mock_app)
|
||||
await manager.initialize()
|
||||
|
||||
# Create and add test pipeline
|
||||
pipeline_entity = Mock(spec=persistence_pipeline.LegacyPipeline)
|
||||
pipeline_entity.uuid = 'test-uuid'
|
||||
pipeline_entity.stages = []
|
||||
pipeline_entity.config = {}
|
||||
|
||||
await manager.load_pipeline(pipeline_entity)
|
||||
assert len(manager.pipelines) == 1
|
||||
|
||||
# Remove pipeline
|
||||
await manager.remove_pipeline('test-uuid')
|
||||
assert len(manager.pipelines) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_pipeline_execute(mock_app, sample_query):
|
||||
"""Test runtime pipeline execution"""
|
||||
pipelinemgr = get_pipelinemgr_module()
|
||||
stage = get_stage_module()
|
||||
persistence_pipeline = get_persistence_pipeline_module()
|
||||
|
||||
# Create mock stage that returns a simple result dict (avoiding Pydantic validation)
|
||||
mock_result = Mock()
|
||||
mock_result.result_type = Mock()
|
||||
mock_result.result_type.value = 'CONTINUE' # Simulate enum value
|
||||
mock_result.new_query = sample_query
|
||||
mock_result.user_notice = ''
|
||||
mock_result.console_notice = ''
|
||||
mock_result.debug_notice = ''
|
||||
mock_result.error_notice = ''
|
||||
|
||||
# Make it look like ResultType.CONTINUE
|
||||
from unittest.mock import MagicMock
|
||||
CONTINUE = MagicMock()
|
||||
CONTINUE.__eq__ = lambda self, other: True # Always equal for comparison
|
||||
mock_result.result_type = CONTINUE
|
||||
|
||||
mock_stage = Mock(spec=stage.PipelineStage)
|
||||
mock_stage.process = AsyncMock(return_value=mock_result)
|
||||
|
||||
# Create stage container
|
||||
stage_container = pipelinemgr.StageInstContainer(inst_name='TestStage', inst=mock_stage)
|
||||
|
||||
# Create pipeline entity
|
||||
pipeline_entity = Mock(spec=persistence_pipeline.LegacyPipeline)
|
||||
pipeline_entity.config = sample_query.pipeline_config
|
||||
|
||||
# Create runtime pipeline
|
||||
runtime_pipeline = pipelinemgr.RuntimePipeline(mock_app, pipeline_entity, [stage_container])
|
||||
|
||||
# Mock plugin connector
|
||||
event_ctx = Mock()
|
||||
event_ctx.is_prevented_default = Mock(return_value=False)
|
||||
mock_app.plugin_connector.emit_event = AsyncMock(return_value=event_ctx)
|
||||
|
||||
# Add query to cached_queries to prevent KeyError in finally block
|
||||
mock_app.query_pool.cached_queries[sample_query.query_id] = sample_query
|
||||
|
||||
# Execute pipeline
|
||||
await runtime_pipeline.run(sample_query)
|
||||
|
||||
# Verify stage was called
|
||||
mock_stage.process.assert_called_once()
|
||||
109
tests/unit_tests/pipeline/test_ratelimit.py
Normal file
109
tests/unit_tests/pipeline/test_ratelimit.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""
|
||||
RateLimit stage unit tests
|
||||
|
||||
Tests the actual RateLimit implementation from pkg.pipeline.ratelimit
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
from importlib import import_module
|
||||
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||
|
||||
|
||||
def get_modules():
|
||||
"""Lazy import to ensure proper initialization order"""
|
||||
# Import pipelinemgr first to trigger proper stage registration
|
||||
pipelinemgr = import_module('pkg.pipeline.pipelinemgr')
|
||||
ratelimit = import_module('pkg.pipeline.ratelimit.ratelimit')
|
||||
entities = import_module('pkg.pipeline.entities')
|
||||
algo_module = import_module('pkg.pipeline.ratelimit.algo')
|
||||
return ratelimit, entities, algo_module
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_require_access_allowed(mock_app, sample_query):
|
||||
"""Test RequireRateLimitOccupancy allows access when rate limit is not exceeded"""
|
||||
ratelimit, entities, algo_module = get_modules()
|
||||
|
||||
sample_query.launcher_type = provider_session.LauncherTypes.PERSON
|
||||
sample_query.launcher_id = '12345'
|
||||
sample_query.pipeline_config = {}
|
||||
|
||||
# Create mock algorithm that allows access
|
||||
mock_algo = Mock(spec=algo_module.ReteLimitAlgo)
|
||||
mock_algo.require_access = AsyncMock(return_value=True)
|
||||
mock_algo.initialize = AsyncMock()
|
||||
|
||||
stage = ratelimit.RateLimit(mock_app)
|
||||
|
||||
# Patch the algorithm selection to use our mock
|
||||
with patch.object(algo_module, 'preregistered_algos', []):
|
||||
stage.algo = mock_algo
|
||||
|
||||
result = await stage.process(sample_query, 'RequireRateLimitOccupancy')
|
||||
|
||||
assert result.result_type == entities.ResultType.CONTINUE
|
||||
assert result.new_query == sample_query
|
||||
mock_algo.require_access.assert_called_once_with(
|
||||
sample_query,
|
||||
'person',
|
||||
'12345'
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_require_access_denied(mock_app, sample_query):
|
||||
"""Test RequireRateLimitOccupancy denies access when rate limit is exceeded"""
|
||||
ratelimit, entities, algo_module = get_modules()
|
||||
|
||||
sample_query.launcher_type = provider_session.LauncherTypes.PERSON
|
||||
sample_query.launcher_id = '12345'
|
||||
sample_query.pipeline_config = {}
|
||||
|
||||
# Create mock algorithm that denies access
|
||||
mock_algo = Mock(spec=algo_module.ReteLimitAlgo)
|
||||
mock_algo.require_access = AsyncMock(return_value=False)
|
||||
mock_algo.initialize = AsyncMock()
|
||||
|
||||
stage = ratelimit.RateLimit(mock_app)
|
||||
|
||||
# Patch the algorithm selection to use our mock
|
||||
with patch.object(algo_module, 'preregistered_algos', []):
|
||||
stage.algo = mock_algo
|
||||
|
||||
result = await stage.process(sample_query, 'RequireRateLimitOccupancy')
|
||||
|
||||
assert result.result_type == entities.ResultType.INTERRUPT
|
||||
assert result.user_notice != ''
|
||||
mock_algo.require_access.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_release_access(mock_app, sample_query):
|
||||
"""Test ReleaseRateLimitOccupancy releases rate limit occupancy"""
|
||||
ratelimit, entities, algo_module = get_modules()
|
||||
|
||||
sample_query.launcher_type = provider_session.LauncherTypes.PERSON
|
||||
sample_query.launcher_id = '12345'
|
||||
sample_query.pipeline_config = {}
|
||||
|
||||
# Create mock algorithm
|
||||
mock_algo = Mock(spec=algo_module.ReteLimitAlgo)
|
||||
mock_algo.release_access = AsyncMock()
|
||||
mock_algo.initialize = AsyncMock()
|
||||
|
||||
stage = ratelimit.RateLimit(mock_app)
|
||||
|
||||
# Patch the algorithm selection to use our mock
|
||||
with patch.object(algo_module, 'preregistered_algos', []):
|
||||
stage.algo = mock_algo
|
||||
|
||||
result = await stage.process(sample_query, 'ReleaseRateLimitOccupancy')
|
||||
|
||||
assert result.result_type == entities.ResultType.CONTINUE
|
||||
assert result.new_query == sample_query
|
||||
mock_algo.release_access.assert_called_once_with(
|
||||
sample_query,
|
||||
'person',
|
||||
'12345'
|
||||
)
|
||||
171
tests/unit_tests/pipeline/test_resprule.py
Normal file
171
tests/unit_tests/pipeline/test_resprule.py
Normal file
@@ -0,0 +1,171 @@
|
||||
"""
|
||||
GroupRespondRuleCheckStage unit tests
|
||||
|
||||
Tests the actual GroupRespondRuleCheckStage implementation from pkg.pipeline.resprule
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
from importlib import import_module
|
||||
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
|
||||
|
||||
def get_modules():
|
||||
"""Lazy import to ensure proper initialization order"""
|
||||
# Import pipelinemgr first to trigger proper stage registration
|
||||
pipelinemgr = import_module('pkg.pipeline.pipelinemgr')
|
||||
resprule = import_module('pkg.pipeline.resprule.resprule')
|
||||
entities = import_module('pkg.pipeline.entities')
|
||||
rule = import_module('pkg.pipeline.resprule.rule')
|
||||
rule_entities = import_module('pkg.pipeline.resprule.entities')
|
||||
return resprule, entities, rule, rule_entities
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_person_message_skip(mock_app, sample_query):
|
||||
"""Test person message skips rule check"""
|
||||
resprule, entities, rule, rule_entities = get_modules()
|
||||
|
||||
sample_query.launcher_type = provider_session.LauncherTypes.PERSON
|
||||
sample_query.pipeline_config = {
|
||||
'trigger': {
|
||||
'group-respond-rules': {}
|
||||
}
|
||||
}
|
||||
|
||||
stage = resprule.GroupRespondRuleCheckStage(mock_app)
|
||||
await stage.initialize(sample_query.pipeline_config)
|
||||
|
||||
result = await stage.process(sample_query, 'GroupRespondRuleCheckStage')
|
||||
|
||||
assert result.result_type == entities.ResultType.CONTINUE
|
||||
assert result.new_query == sample_query
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_message_no_match(mock_app, sample_query):
|
||||
"""Test group message with no matching rules"""
|
||||
resprule, entities, rule, rule_entities = get_modules()
|
||||
|
||||
sample_query.launcher_type = provider_session.LauncherTypes.GROUP
|
||||
sample_query.launcher_id = '12345'
|
||||
sample_query.pipeline_config = {
|
||||
'trigger': {
|
||||
'group-respond-rules': {}
|
||||
}
|
||||
}
|
||||
|
||||
# Create mock rule matcher that doesn't match
|
||||
mock_rule = Mock(spec=rule.GroupRespondRule)
|
||||
mock_rule.match = AsyncMock(return_value=rule_entities.RuleJudgeResult(
|
||||
matching=False,
|
||||
replacement=sample_query.message_chain
|
||||
))
|
||||
|
||||
stage = resprule.GroupRespondRuleCheckStage(mock_app)
|
||||
await stage.initialize(sample_query.pipeline_config)
|
||||
stage.rule_matchers = [mock_rule]
|
||||
|
||||
result = await stage.process(sample_query, 'GroupRespondRuleCheckStage')
|
||||
|
||||
assert result.result_type == entities.ResultType.INTERRUPT
|
||||
assert result.new_query == sample_query
|
||||
mock_rule.match.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_message_match(mock_app, sample_query):
|
||||
"""Test group message with matching rule"""
|
||||
resprule, entities, rule, rule_entities = get_modules()
|
||||
|
||||
sample_query.launcher_type = provider_session.LauncherTypes.GROUP
|
||||
sample_query.launcher_id = '12345'
|
||||
sample_query.pipeline_config = {
|
||||
'trigger': {
|
||||
'group-respond-rules': {}
|
||||
}
|
||||
}
|
||||
|
||||
# Create new message chain after rule processing
|
||||
new_chain = platform_message.MessageChain([
|
||||
platform_message.Plain(text='Processed message')
|
||||
])
|
||||
|
||||
# Create mock rule matcher that matches
|
||||
mock_rule = Mock(spec=rule.GroupRespondRule)
|
||||
mock_rule.match = AsyncMock(return_value=rule_entities.RuleJudgeResult(
|
||||
matching=True,
|
||||
replacement=new_chain
|
||||
))
|
||||
|
||||
stage = resprule.GroupRespondRuleCheckStage(mock_app)
|
||||
await stage.initialize(sample_query.pipeline_config)
|
||||
stage.rule_matchers = [mock_rule]
|
||||
|
||||
result = await stage.process(sample_query, 'GroupRespondRuleCheckStage')
|
||||
|
||||
assert result.result_type == entities.ResultType.CONTINUE
|
||||
assert result.new_query == sample_query
|
||||
assert sample_query.message_chain == new_chain
|
||||
mock_rule.match.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_atbot_rule_match(mock_app, sample_query):
|
||||
"""Test AtBotRule removes At component"""
|
||||
resprule, entities, rule, rule_entities = get_modules()
|
||||
atbot_module = import_module('pkg.pipeline.resprule.rules.atbot')
|
||||
|
||||
sample_query.launcher_type = provider_session.LauncherTypes.GROUP
|
||||
sample_query.adapter.bot_account_id = '999'
|
||||
|
||||
# Create message chain with At component
|
||||
message_chain = platform_message.MessageChain([
|
||||
platform_message.At(target='999'),
|
||||
platform_message.Plain(text='Hello bot')
|
||||
])
|
||||
sample_query.message_chain = message_chain
|
||||
|
||||
atbot_rule = atbot_module.AtBotRule(mock_app)
|
||||
await atbot_rule.initialize()
|
||||
|
||||
result = await atbot_rule.match(
|
||||
str(message_chain),
|
||||
message_chain,
|
||||
{},
|
||||
sample_query
|
||||
)
|
||||
|
||||
assert result.matching is True
|
||||
# At component should be removed
|
||||
assert len(result.replacement.root) == 1
|
||||
assert isinstance(result.replacement.root[0], platform_message.Plain)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_atbot_rule_no_match(mock_app, sample_query):
|
||||
"""Test AtBotRule when no At component present"""
|
||||
resprule, entities, rule, rule_entities = get_modules()
|
||||
atbot_module = import_module('pkg.pipeline.resprule.rules.atbot')
|
||||
|
||||
sample_query.launcher_type = provider_session.LauncherTypes.GROUP
|
||||
sample_query.adapter.bot_account_id = '999'
|
||||
|
||||
# Create message chain without At component
|
||||
message_chain = platform_message.MessageChain([
|
||||
platform_message.Plain(text='Hello')
|
||||
])
|
||||
sample_query.message_chain = message_chain
|
||||
|
||||
atbot_rule = atbot_module.AtBotRule(mock_app)
|
||||
await atbot_rule.initialize()
|
||||
|
||||
result = await atbot_rule.match(
|
||||
str(message_chain),
|
||||
message_chain,
|
||||
{},
|
||||
sample_query
|
||||
)
|
||||
|
||||
assert result.matching is False
|
||||
40
tests/unit_tests/pipeline/test_simple.py
Normal file
40
tests/unit_tests/pipeline/test_simple.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""
|
||||
Simple standalone tests to verify test infrastructure
|
||||
These tests don't import the actual pipeline code to avoid circular import issues
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, AsyncMock
|
||||
|
||||
|
||||
def test_pytest_works():
|
||||
"""Verify pytest is working"""
|
||||
assert True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_works():
|
||||
"""Verify async tests work"""
|
||||
mock = AsyncMock(return_value=42)
|
||||
result = await mock()
|
||||
assert result == 42
|
||||
|
||||
|
||||
def test_mocks_work():
|
||||
"""Verify mocking works"""
|
||||
mock = Mock()
|
||||
mock.return_value = 'test'
|
||||
assert mock() == 'test'
|
||||
|
||||
|
||||
def test_fixtures_work(mock_app):
|
||||
"""Verify fixtures are loaded"""
|
||||
assert mock_app is not None
|
||||
assert mock_app.logger is not None
|
||||
assert mock_app.sess_mgr is not None
|
||||
|
||||
|
||||
def test_sample_query(sample_query):
|
||||
"""Verify sample query fixture works"""
|
||||
assert sample_query.query_id == 'test-query-id'
|
||||
assert sample_query.launcher_id == 12345
|
||||
@@ -13,6 +13,7 @@ import {
|
||||
UploadIcon,
|
||||
StoreIcon,
|
||||
Download,
|
||||
Power,
|
||||
} from 'lucide-react';
|
||||
import {
|
||||
DropdownMenu,
|
||||
@@ -28,12 +29,13 @@ import {
|
||||
DialogFooter,
|
||||
} from '@/components/ui/dialog';
|
||||
import { Input } from '@/components/ui/input';
|
||||
import { useState, useRef, useCallback } from 'react';
|
||||
import { useState, useRef, useCallback, useEffect } from 'react';
|
||||
import { httpClient } from '@/app/infra/http/HttpClient';
|
||||
import { toast } from 'sonner';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PluginV4 } from '@/app/infra/entities/plugin';
|
||||
import { systemInfo } from '@/app/infra/http/HttpClient';
|
||||
import { ApiRespPluginSystemStatus } from '@/app/infra/entities/api';
|
||||
|
||||
enum PluginInstallStatus {
|
||||
WAIT_INPUT = 'wait_input',
|
||||
@@ -54,9 +56,29 @@ export default function PluginConfigPage() {
|
||||
const [installError, setInstallError] = useState<string | null>(null);
|
||||
const [githubURL, setGithubURL] = useState('');
|
||||
const [isDragOver, setIsDragOver] = useState(false);
|
||||
const [pluginSystemStatus, setPluginSystemStatus] =
|
||||
useState<ApiRespPluginSystemStatus | null>(null);
|
||||
const [statusLoading, setStatusLoading] = useState(true);
|
||||
const pluginInstalledRef = useRef<PluginInstalledComponentRef>(null);
|
||||
const fileInputRef = useRef<HTMLInputElement>(null);
|
||||
|
||||
useEffect(() => {
|
||||
const fetchPluginSystemStatus = async () => {
|
||||
try {
|
||||
setStatusLoading(true);
|
||||
const status = await httpClient.getPluginSystemStatus();
|
||||
setPluginSystemStatus(status);
|
||||
} catch (error) {
|
||||
console.error('Failed to fetch plugin system status:', error);
|
||||
toast.error(t('plugins.failedToGetStatus'));
|
||||
} finally {
|
||||
setStatusLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
fetchPluginSystemStatus();
|
||||
}, [t]);
|
||||
|
||||
function watchTask(taskId: number) {
|
||||
let alreadySuccess = false;
|
||||
console.log('taskId:', taskId);
|
||||
@@ -140,6 +162,11 @@ export default function PluginConfigPage() {
|
||||
|
||||
const uploadPluginFile = useCallback(
|
||||
async (file: File) => {
|
||||
if (!pluginSystemStatus?.is_enable || !pluginSystemStatus?.is_connected) {
|
||||
toast.error(t('plugins.pluginSystemNotReady'));
|
||||
return;
|
||||
}
|
||||
|
||||
if (!validateFileType(file)) {
|
||||
toast.error(t('plugins.unsupportedFileType'));
|
||||
return;
|
||||
@@ -150,7 +177,7 @@ export default function PluginConfigPage() {
|
||||
setInstallError(null);
|
||||
installPlugin('local', { file });
|
||||
},
|
||||
[t],
|
||||
[t, pluginSystemStatus],
|
||||
);
|
||||
|
||||
const handleFileSelect = useCallback(() => {
|
||||
@@ -171,10 +198,18 @@ export default function PluginConfigPage() {
|
||||
[uploadPluginFile],
|
||||
);
|
||||
|
||||
const handleDragOver = useCallback((event: React.DragEvent) => {
|
||||
event.preventDefault();
|
||||
setIsDragOver(true);
|
||||
}, []);
|
||||
const isPluginSystemReady =
|
||||
pluginSystemStatus?.is_enable && pluginSystemStatus?.is_connected;
|
||||
|
||||
const handleDragOver = useCallback(
|
||||
(event: React.DragEvent) => {
|
||||
event.preventDefault();
|
||||
if (isPluginSystemReady) {
|
||||
setIsDragOver(true);
|
||||
}
|
||||
},
|
||||
[isPluginSystemReady],
|
||||
);
|
||||
|
||||
const handleDragLeave = useCallback((event: React.DragEvent) => {
|
||||
event.preventDefault();
|
||||
@@ -186,14 +221,76 @@ export default function PluginConfigPage() {
|
||||
event.preventDefault();
|
||||
setIsDragOver(false);
|
||||
|
||||
if (!isPluginSystemReady) {
|
||||
toast.error(t('plugins.pluginSystemNotReady'));
|
||||
return;
|
||||
}
|
||||
|
||||
const files = Array.from(event.dataTransfer.files);
|
||||
if (files.length > 0) {
|
||||
uploadPluginFile(files[0]);
|
||||
}
|
||||
},
|
||||
[uploadPluginFile],
|
||||
[uploadPluginFile, isPluginSystemReady, t],
|
||||
);
|
||||
|
||||
// 插件系统未启用的状态显示
|
||||
const renderPluginDisabledState = () => (
|
||||
<div className="flex flex-col items-center justify-center h-[60vh] text-center pt-[10vh]">
|
||||
<Power className="w-16 h-16 text-gray-400 mb-4" />
|
||||
<h2 className="text-2xl font-semibold text-gray-700 dark:text-gray-300 mb-2">
|
||||
{t('plugins.systemDisabled')}
|
||||
</h2>
|
||||
<p className="text-gray-500 dark:text-gray-400 max-w-md">
|
||||
{t('plugins.systemDisabledDesc')}
|
||||
</p>
|
||||
</div>
|
||||
);
|
||||
|
||||
// 插件系统连接异常的状态显示
|
||||
const renderPluginConnectionErrorState = () => (
|
||||
<div className="flex flex-col items-center justify-center h-[60vh] text-center pt-[10vh]">
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
viewBox="0 0 24 24"
|
||||
width="72"
|
||||
height="72"
|
||||
fill="#BDBDBD"
|
||||
>
|
||||
<path d="M17.657 14.8284L16.2428 13.4142L17.657 12C19.2191 10.4379 19.2191 7.90526 17.657 6.34316C16.0949 4.78106 13.5622 4.78106 12.0001 6.34316L10.5859 7.75737L9.17171 6.34316L10.5859 4.92895C12.9291 2.5858 16.7281 2.5858 19.0712 4.92895C21.4143 7.27209 21.4143 11.0711 19.0712 13.4142L17.657 14.8284ZM14.8286 17.6569L13.4143 19.0711C11.0712 21.4142 7.27221 21.4142 4.92907 19.0711C2.58592 16.7279 2.58592 12.9289 4.92907 10.5858L6.34328 9.17159L7.75749 10.5858L6.34328 12C4.78118 13.5621 4.78118 16.0948 6.34328 17.6569C7.90538 19.219 10.438 19.219 12.0001 17.6569L13.4143 16.2427L14.8286 17.6569ZM14.8286 7.75737L16.2428 9.17159L9.17171 16.2427L7.75749 14.8284L14.8286 7.75737ZM5.77539 2.29291L7.70724 1.77527L8.74252 5.63897L6.81067 6.15661L5.77539 2.29291ZM15.2578 18.3611L17.1896 17.8434L18.2249 21.7071L16.293 22.2248L15.2578 18.3611ZM2.29303 5.77527L6.15673 6.81054L5.63909 8.7424L1.77539 7.70712L2.29303 5.77527ZM18.3612 15.2576L22.2249 16.2929L21.7072 18.2248L17.8435 17.1895L18.3612 15.2576Z"></path>
|
||||
</svg>
|
||||
|
||||
<h2 className="text-2xl font-semibold text-gray-700 dark:text-gray-300 mb-2">
|
||||
{t('plugins.connectionError')}
|
||||
</h2>
|
||||
<p className="text-gray-500 dark:text-gray-400 max-w-md mb-4">
|
||||
{t('plugins.connectionErrorDesc')}
|
||||
</p>
|
||||
</div>
|
||||
);
|
||||
|
||||
// 加载状态显示
|
||||
const renderLoadingState = () => (
|
||||
<div className="flex flex-col items-center justify-center h-[60vh] pt-[10vh]">
|
||||
<p className="text-gray-500 dark:text-gray-400">
|
||||
{t('plugins.loadingStatus')}
|
||||
</p>
|
||||
</div>
|
||||
);
|
||||
|
||||
// 根据状态返回不同的内容
|
||||
if (statusLoading) {
|
||||
return renderLoadingState();
|
||||
}
|
||||
|
||||
if (!pluginSystemStatus?.is_enable) {
|
||||
return renderPluginDisabledState();
|
||||
}
|
||||
|
||||
if (!pluginSystemStatus?.is_connected) {
|
||||
return renderPluginConnectionErrorState();
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
className={`${styles.pageContainer} ${isDragOver ? 'bg-blue-50' : ''}`}
|
||||
|
||||
@@ -215,6 +215,12 @@ export interface ApiRespSystemInfo {
|
||||
enable_marketplace: boolean;
|
||||
}
|
||||
|
||||
export interface ApiRespPluginSystemStatus {
|
||||
is_enable: boolean;
|
||||
is_connected: boolean;
|
||||
plugin_connector_error: string;
|
||||
}
|
||||
|
||||
export interface ApiRespAsyncTasks {
|
||||
tasks: AsyncTask[];
|
||||
}
|
||||
|
||||
@@ -32,6 +32,7 @@ import {
|
||||
ApiRespProviderEmbeddingModels,
|
||||
ApiRespProviderEmbeddingModel,
|
||||
EmbeddingModel,
|
||||
ApiRespPluginSystemStatus,
|
||||
} from '@/app/infra/entities/api';
|
||||
import { GetBotLogsRequest } from '@/app/infra/http/requestParam/bots/GetBotLogsRequest';
|
||||
import { GetBotLogsResponse } from '@/app/infra/http/requestParam/bots/GetBotLogsResponse';
|
||||
@@ -500,6 +501,10 @@ export class BackendClient extends BaseHttpClient {
|
||||
return this.get(`/api/v1/system/tasks/${id}`);
|
||||
}
|
||||
|
||||
public getPluginSystemStatus(): Promise<ApiRespPluginSystemStatus> {
|
||||
return this.get('/api/v1/system/status/plugin-system');
|
||||
}
|
||||
|
||||
// ============ User API ============
|
||||
public checkIfInited(): Promise<{ initialized: boolean }> {
|
||||
return this.get('/api/v1/user/init');
|
||||
|
||||
@@ -182,6 +182,17 @@ const enUS = {
|
||||
pluginSortSuccess: 'Plugin sort successful',
|
||||
pluginSortError: 'Plugin sort failed: ',
|
||||
pluginNoConfig: 'The plugin has no configuration items.',
|
||||
systemDisabled: 'Plugin System Disabled',
|
||||
systemDisabledDesc:
|
||||
'Plugin system is not enabled, please modify the configuration according to the documentation',
|
||||
connectionError: 'Plugin System Connection Error',
|
||||
connectionErrorDesc:
|
||||
'Please check the plugin system configuration or contact the administrator.',
|
||||
errorDetails: 'Error Details',
|
||||
loadingStatus: 'Checking plugin system status...',
|
||||
failedToGetStatus: 'Failed to get plugin system status',
|
||||
pluginSystemNotReady:
|
||||
'Plugin system is not ready, cannot perform this operation',
|
||||
deleting: 'Deleting...',
|
||||
deletePlugin: 'Delete Plugin',
|
||||
cancel: 'Cancel',
|
||||
|
||||
@@ -183,6 +183,17 @@ const jaJP = {
|
||||
pluginSortSuccess: 'プラグインの並び替えに成功しました',
|
||||
pluginSortError: 'プラグインの並び替えに失敗しました:',
|
||||
pluginNoConfig: 'プラグインに設定項目がありません。',
|
||||
systemDisabled: 'プラグインシステムが無効になっています',
|
||||
systemDisabledDesc:
|
||||
'プラグインシステムが無効になっています。プラグインシステムを有効にするか、ドキュメントに従って設定を変更してください',
|
||||
connectionError: 'プラグインシステム接続エラー',
|
||||
connectionErrorDesc:
|
||||
'プラグインシステム設定を確認するか、管理者に連絡してください',
|
||||
errorDetails: 'エラー詳細',
|
||||
loadingStatus: 'プラグインシステム状態を確認中...',
|
||||
failedToGetStatus: 'プラグインシステム状態の取得に失敗しました',
|
||||
pluginSystemNotReady:
|
||||
'プラグインシステムが準備されていません。この操作を実行できません',
|
||||
deleting: '削除中...',
|
||||
deletePlugin: 'プラグインを削除',
|
||||
cancel: 'キャンセル',
|
||||
|
||||
@@ -178,6 +178,14 @@ const zhHans = {
|
||||
pluginSortSuccess: '插件排序成功',
|
||||
pluginSortError: '插件排序失败:',
|
||||
pluginNoConfig: '插件没有配置项。',
|
||||
systemDisabled: '插件系统未启用',
|
||||
systemDisabledDesc: '尚未启用插件系统,请根据文档修改配置',
|
||||
connectionError: '插件系统连接异常',
|
||||
connectionErrorDesc: '请检查插件系统配置或联系管理员',
|
||||
errorDetails: '错误详情',
|
||||
loadingStatus: '正在检查插件系统状态...',
|
||||
failedToGetStatus: '获取插件系统状态失败',
|
||||
pluginSystemNotReady: '插件系统未就绪,无法执行此操作',
|
||||
deleting: '删除中...',
|
||||
deletePlugin: '删除插件',
|
||||
cancel: '取消',
|
||||
|
||||
@@ -178,6 +178,14 @@ const zhHant = {
|
||||
pluginSortSuccess: '外掛排序成功',
|
||||
pluginSortError: '外掛排序失敗:',
|
||||
pluginNoConfig: '外掛沒有設定項目。',
|
||||
systemDisabled: '外掛系統未啟用',
|
||||
systemDisabledDesc: '尚未啟用外掛系統,請根據文檔修改配置',
|
||||
connectionError: '外掛系統連接異常',
|
||||
connectionErrorDesc: '請檢查外掛系統配置或聯絡管理員',
|
||||
errorDetails: '錯誤詳情',
|
||||
loadingStatus: '正在檢查外掛系統狀態...',
|
||||
failedToGetStatus: '取得外掛系統狀態失敗',
|
||||
pluginSystemNotReady: '外掛系統未就緒,無法執行此操作',
|
||||
deleting: '刪除中...',
|
||||
deletePlugin: '刪除外掛',
|
||||
cancel: '取消',
|
||||
|
||||
Reference in New Issue
Block a user