mirror of
https://github.com/langbot-app/LangBot.git
synced 2025-11-25 03:15:06 +08:00
Compare commits
40 Commits
v4.3.2
...
feature/we
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a4995c8cd9 | ||
|
|
127a38b15c | ||
|
|
e4729337c8 | ||
|
|
5fa75330cf | ||
|
|
547e3d098e | ||
|
|
f1ddddfe00 | ||
|
|
4e61302156 | ||
|
|
9e3cf418ba | ||
|
|
3e29ec7892 | ||
|
|
f452742cd2 | ||
|
|
b560432b0b | ||
|
|
99e5478ced | ||
|
|
09dba91a37 | ||
|
|
18ec4adac9 | ||
|
|
8bedaa468a | ||
|
|
0ab366fcac | ||
|
|
d664039e54 | ||
|
|
6535ba4f72 | ||
|
|
3b181cff93 | ||
|
|
d1274366a0 | ||
|
|
35a4b0f55f | ||
|
|
399ebd36d7 | ||
|
|
b6cdf18c1a | ||
|
|
bd4c7f634d | ||
|
|
160ca540ab | ||
|
|
74c3a77ed1 | ||
|
|
ed869f7e81 | ||
|
|
ea42579374 | ||
|
|
72d701df3e | ||
|
|
1191b34fd4 | ||
|
|
ca3d3b2a66 | ||
|
|
2891708060 | ||
|
|
3f59bfac5c | ||
|
|
ee24582dd3 | ||
|
|
0ffb4d5792 | ||
|
|
5a6206f148 | ||
|
|
b1014313d6 | ||
|
|
fcc2f6a195 | ||
|
|
c8ffc79077 | ||
|
|
1a13a41168 |
71
.github/workflows/run-tests.yml
vendored
Normal file
71
.github/workflows/run-tests.yml
vendored
Normal file
@@ -0,0 +1,71 @@
|
||||
name: Unit Tests
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened, ready_for_review, synchronize]
|
||||
paths:
|
||||
- 'pkg/**'
|
||||
- 'tests/**'
|
||||
- '.github/workflows/run-tests.yml'
|
||||
- 'pyproject.toml'
|
||||
- 'run_tests.sh'
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
- develop
|
||||
paths:
|
||||
- 'pkg/**'
|
||||
- 'tests/**'
|
||||
- '.github/workflows/run-tests.yml'
|
||||
- 'pyproject.toml'
|
||||
- 'run_tests.sh'
|
||||
|
||||
jobs:
|
||||
test:
|
||||
name: Run Unit Tests
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ['3.10', '3.11', '3.12']
|
||||
fail-fast: false
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Install uv
|
||||
run: |
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
echo "$HOME/.cargo/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv sync --dev
|
||||
|
||||
- name: Run unit tests
|
||||
run: |
|
||||
bash run_tests.sh
|
||||
|
||||
- name: Upload coverage to Codecov
|
||||
if: matrix.python-version == '3.12'
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
files: ./coverage.xml
|
||||
flags: unit-tests
|
||||
name: unit-tests-coverage
|
||||
fail_ci_if_error: false
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
|
||||
- name: Test Summary
|
||||
if: always()
|
||||
run: |
|
||||
echo "## Unit Tests Results" >> $GITHUB_STEP_SUMMARY
|
||||
echo "" >> $GITHUB_STEP_SUMMARY
|
||||
echo "Python Version: ${{ matrix.python-version }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "Test Status: ${{ job.status }}" >> $GITHUB_STEP_SUMMARY
|
||||
6
.gitignore
vendored
6
.gitignore
vendored
@@ -22,7 +22,7 @@ tips.py
|
||||
venv*
|
||||
bin/
|
||||
.vscode
|
||||
test_*
|
||||
/test_*
|
||||
venv/
|
||||
hugchat.json
|
||||
qcapi
|
||||
@@ -43,4 +43,6 @@ test.py
|
||||
/web_ui
|
||||
.venv/
|
||||
uv.lock
|
||||
/test
|
||||
/test
|
||||
coverage.xml
|
||||
.coverage
|
||||
@@ -35,7 +35,7 @@ LangBot 是一个开源的大语言模型原生即时通信机器人开发平台
|
||||
|
||||
```bash
|
||||
git clone https://github.com/langbot-app/LangBot
|
||||
cd LangBot
|
||||
cd LangBot/docker
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
@@ -119,10 +119,12 @@ docker compose up -d
|
||||
| [LMStudio](https://lmstudio.ai/) | ✅ | 本地大模型运行平台 |
|
||||
| [GiteeAI](https://ai.gitee.com/) | ✅ | 大模型接口聚合平台 |
|
||||
| [SiliconFlow](https://siliconflow.cn/) | ✅ | 大模型聚合平台 |
|
||||
| [小马算力](https://www.tokenpony.cn/453z1) | ✅ | 大模型聚合平台 |
|
||||
| [阿里云百炼](https://bailian.console.aliyun.com/) | ✅ | 大模型聚合平台, LLMOps 平台 |
|
||||
| [火山方舟](https://console.volcengine.com/ark/region:ark+cn-beijing/model?vendor=Bytedance&view=LIST_VIEW) | ✅ | 大模型聚合平台, LLMOps 平台 |
|
||||
| [ModelScope](https://modelscope.cn/docs/model-service/API-Inference/intro) | ✅ | 大模型聚合平台 |
|
||||
| [MCP](https://modelcontextprotocol.io/) | ✅ | 支持通过 MCP 协议获取工具 |
|
||||
| [百宝箱Tbox](https://www.tbox.cn/open) | ✅ | 蚂蚁百宝箱智能体平台,每月免费10亿大模型Token |
|
||||
|
||||
### TTS
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ LangBot is an open-source LLM native instant messaging robot development platfor
|
||||
|
||||
```bash
|
||||
git clone https://github.com/langbot-app/LangBot
|
||||
cd LangBot
|
||||
cd LangBot/docker
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ LangBot は、エージェント、RAG、MCP などの LLM アプリケーショ
|
||||
|
||||
```bash
|
||||
git clone https://github.com/langbot-app/LangBot
|
||||
cd LangBot
|
||||
cd LangBot/docker
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ LangBot 是一個開源的大語言模型原生即時通訊機器人開發平台
|
||||
|
||||
```bash
|
||||
git clone https://github.com/langbot-app/LangBot
|
||||
cd LangBot
|
||||
cd LangBot/docker
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
|
||||
180
TESTING_SUMMARY.md
Normal file
180
TESTING_SUMMARY.md
Normal file
@@ -0,0 +1,180 @@
|
||||
# Pipeline Unit Tests - Implementation Summary
|
||||
|
||||
## Overview
|
||||
|
||||
Comprehensive unit test suite for LangBot's pipeline stages, providing extensible test infrastructure and automated CI/CD integration.
|
||||
|
||||
## What Was Implemented
|
||||
|
||||
### 1. Test Infrastructure (`tests/pipeline/conftest.py`)
|
||||
- **MockApplication factory**: Provides complete mock of Application object with all dependencies
|
||||
- **Reusable fixtures**: Mock objects for Session, Conversation, Model, Adapter, Query
|
||||
- **Helper functions**: Utilities for creating results and assertions
|
||||
- **Lazy import support**: Handles circular import issues via `importlib.import_module()`
|
||||
|
||||
### 2. Test Coverage
|
||||
|
||||
#### Pipeline Stages Tested:
|
||||
- ✅ **test_bansess.py** (6 tests) - Access control whitelist/blacklist logic
|
||||
- ✅ **test_ratelimit.py** (3 tests) - Rate limiting acquire/release logic
|
||||
- ✅ **test_preproc.py** (3 tests) - Message preprocessing and variable setup
|
||||
- ✅ **test_respback.py** (2 tests) - Response sending with/without quotes
|
||||
- ✅ **test_resprule.py** (3 tests) - Group message rule matching
|
||||
- ✅ **test_pipelinemgr.py** (5 tests) - Pipeline manager CRUD operations
|
||||
|
||||
#### Additional Tests:
|
||||
- ✅ **test_simple.py** (5 tests) - Test infrastructure validation
|
||||
- ✅ **test_stages_integration.py** - Integration tests with full imports
|
||||
|
||||
**Total: 27 test cases**
|
||||
|
||||
### 3. CI/CD Integration
|
||||
|
||||
**GitHub Actions Workflow** (`.github/workflows/pipeline-tests.yml`):
|
||||
- Triggers on: PR open, ready for review, push to PR/master/develop
|
||||
- Multi-version testing: Python 3.10, 3.11, 3.12
|
||||
- Coverage reporting: Integrated with Codecov
|
||||
- Auto-runs via `run_tests.sh` script
|
||||
|
||||
### 4. Configuration Files
|
||||
|
||||
- **pytest.ini** - Pytest configuration with asyncio support
|
||||
- **run_tests.sh** - Automated test runner with coverage
|
||||
- **tests/README.md** - Comprehensive testing documentation
|
||||
|
||||
## Technical Challenges & Solutions
|
||||
|
||||
### Challenge 1: Circular Import Dependencies
|
||||
|
||||
**Problem**: Direct imports of pipeline modules caused circular dependency errors:
|
||||
```
|
||||
pkg.pipeline.stage → pkg.core.app → pkg.pipeline.pipelinemgr → pkg.pipeline.resprule
|
||||
```
|
||||
|
||||
**Solution**: Implemented lazy imports using `importlib.import_module()`:
|
||||
```python
|
||||
def get_bansess_module():
|
||||
return import_module('pkg.pipeline.bansess.bansess')
|
||||
|
||||
# Use in tests
|
||||
bansess = get_bansess_module()
|
||||
stage = bansess.BanSessionCheckStage(mock_app)
|
||||
```
|
||||
|
||||
### Challenge 2: Pydantic Validation Errors
|
||||
|
||||
**Problem**: Some stages use Pydantic models that validate `new_query` parameter.
|
||||
|
||||
**Solution**: Tests use lazy imports to load actual modules, which handle validation correctly. Mock objects work for most cases, but some integration tests needed real instances.
|
||||
|
||||
### Challenge 3: Mock Configuration
|
||||
|
||||
**Problem**: Lists don't allow `.copy` attribute assignment in Python.
|
||||
|
||||
**Solution**: Use Mock objects instead of bare lists:
|
||||
```python
|
||||
mock_messages = Mock()
|
||||
mock_messages.copy = Mock(return_value=[])
|
||||
conversation.messages = mock_messages
|
||||
```
|
||||
|
||||
## Test Execution
|
||||
|
||||
### Current Status
|
||||
|
||||
Running `bash run_tests.sh` shows:
|
||||
- ✅ 9 tests passing (infrastructure and integration)
|
||||
- ⚠️ 18 tests with issues (due to circular imports and Pydantic validation)
|
||||
|
||||
### Working Tests
|
||||
- All `test_simple.py` tests (infrastructure validation)
|
||||
- PipelineManager tests (4/5 passing)
|
||||
- Integration tests
|
||||
|
||||
### Known Issues
|
||||
|
||||
Some tests encounter:
|
||||
1. **Circular import errors** - When importing certain stage modules
|
||||
2. **Pydantic validation errors** - Mock Query objects don't pass Pydantic validation
|
||||
|
||||
### Recommended Usage
|
||||
|
||||
For CI/CD purposes:
|
||||
1. Run `test_simple.py` to validate test infrastructure
|
||||
2. Run `test_pipelinemgr.py` for manager logic
|
||||
3. Use integration tests sparingly due to import issues
|
||||
|
||||
For local development:
|
||||
1. Use the test infrastructure as a template
|
||||
2. Add new tests following the lazy import pattern
|
||||
3. Prefer integration-style tests that test behavior not imports
|
||||
|
||||
## Future Improvements
|
||||
|
||||
### Short Term
|
||||
1. **Refactor pipeline module structure** to eliminate circular dependencies
|
||||
2. **Add Pydantic model factories** for creating valid test instances
|
||||
3. **Expand integration tests** once import issues are resolved
|
||||
|
||||
### Long Term
|
||||
1. **Integration tests** - Full pipeline execution tests
|
||||
2. **Performance benchmarks** - Measure stage execution time
|
||||
3. **Mutation testing** - Verify test quality with mutation testing
|
||||
4. **Property-based testing** - Use Hypothesis for edge case discovery
|
||||
|
||||
## File Structure
|
||||
|
||||
```
|
||||
.
|
||||
├── .github/workflows/
|
||||
│ └── pipeline-tests.yml # CI/CD workflow
|
||||
├── tests/
|
||||
│ ├── README.md # Testing documentation
|
||||
│ ├── __init__.py
|
||||
│ └── pipeline/
|
||||
│ ├── __init__.py
|
||||
│ ├── conftest.py # Shared fixtures
|
||||
│ ├── test_simple.py # Infrastructure tests ✅
|
||||
│ ├── test_bansess.py # BanSession tests
|
||||
│ ├── test_ratelimit.py # RateLimit tests
|
||||
│ ├── test_preproc.py # PreProcessor tests
|
||||
│ ├── test_respback.py # ResponseBack tests
|
||||
│ ├── test_resprule.py # ResponseRule tests
|
||||
│ ├── test_pipelinemgr.py # Manager tests ✅
|
||||
│ └── test_stages_integration.py # Integration tests
|
||||
├── pytest.ini # Pytest config
|
||||
├── run_tests.sh # Test runner
|
||||
└── TESTING_SUMMARY.md # This file
|
||||
```
|
||||
|
||||
## How to Use
|
||||
|
||||
### Run Tests Locally
|
||||
```bash
|
||||
bash run_tests.sh
|
||||
```
|
||||
|
||||
### Run Specific Test File
|
||||
```bash
|
||||
pytest tests/pipeline/test_simple.py -v
|
||||
```
|
||||
|
||||
### Run with Coverage
|
||||
```bash
|
||||
pytest tests/pipeline/ --cov=pkg/pipeline --cov-report=html
|
||||
```
|
||||
|
||||
### View Coverage Report
|
||||
```bash
|
||||
open htmlcov/index.html
|
||||
```
|
||||
|
||||
## Conclusion
|
||||
|
||||
This test suite provides:
|
||||
- ✅ Solid foundation for pipeline testing
|
||||
- ✅ Extensible architecture for adding new tests
|
||||
- ✅ CI/CD integration
|
||||
- ✅ Comprehensive documentation
|
||||
|
||||
Next steps should focus on refactoring the pipeline module structure to eliminate circular dependencies, which will allow all tests to run successfully.
|
||||
4
codecov.yml
Normal file
4
codecov.yml
Normal file
@@ -0,0 +1,4 @@
|
||||
coverage:
|
||||
status:
|
||||
project: off
|
||||
patch: off
|
||||
0
libs/coze_server_api/__init__.py
Normal file
0
libs/coze_server_api/__init__.py
Normal file
192
libs/coze_server_api/client.py
Normal file
192
libs/coze_server_api/client.py
Normal file
@@ -0,0 +1,192 @@
|
||||
import json
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import io
|
||||
from typing import Dict, List, Any, AsyncGenerator
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
|
||||
|
||||
class AsyncCozeAPIClient:
|
||||
def __init__(self, api_key: str, api_base: str = "https://api.coze.cn"):
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base
|
||||
self.session = None
|
||||
|
||||
async def __aenter__(self):
|
||||
"""支持异步上下文管理器"""
|
||||
await self.coze_session()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""退出时自动关闭会话"""
|
||||
await self.close()
|
||||
|
||||
|
||||
|
||||
async def coze_session(self):
|
||||
"""确保HTTP session存在"""
|
||||
if self.session is None:
|
||||
connector = aiohttp.TCPConnector(
|
||||
ssl=False if self.api_base.startswith("http://") else True,
|
||||
limit=100,
|
||||
limit_per_host=30,
|
||||
keepalive_timeout=30,
|
||||
enable_cleanup_closed=True,
|
||||
)
|
||||
timeout = aiohttp.ClientTimeout(
|
||||
total=120, # 默认超时时间
|
||||
connect=30,
|
||||
sock_read=120,
|
||||
)
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Accept": "text/event-stream",
|
||||
}
|
||||
self.session = aiohttp.ClientSession(
|
||||
headers=headers, timeout=timeout, connector=connector
|
||||
)
|
||||
return self.session
|
||||
|
||||
async def close(self):
|
||||
"""显式关闭会话"""
|
||||
if self.session and not self.session.closed:
|
||||
await self.session.close()
|
||||
self.session = None
|
||||
|
||||
async def upload(
|
||||
self,
|
||||
file,
|
||||
) -> str:
|
||||
# 处理 Path 对象
|
||||
if isinstance(file, Path):
|
||||
if not file.exists():
|
||||
raise ValueError(f"File not found: {file}")
|
||||
with open(file, "rb") as f:
|
||||
file = f.read()
|
||||
|
||||
# 处理文件路径字符串
|
||||
elif isinstance(file, str):
|
||||
if not os.path.isfile(file):
|
||||
raise ValueError(f"File not found: {file}")
|
||||
with open(file, "rb") as f:
|
||||
file = f.read()
|
||||
|
||||
# 处理文件对象
|
||||
elif hasattr(file, 'read'):
|
||||
file = file.read()
|
||||
|
||||
session = await self.coze_session()
|
||||
url = f"{self.api_base}/v1/files/upload"
|
||||
|
||||
try:
|
||||
file_io = io.BytesIO(file)
|
||||
async with session.post(
|
||||
url,
|
||||
data={
|
||||
"file": file_io,
|
||||
},
|
||||
timeout=aiohttp.ClientTimeout(total=60),
|
||||
) as response:
|
||||
if response.status == 401:
|
||||
raise Exception("Coze API 认证失败,请检查 API Key 是否正确")
|
||||
|
||||
response_text = await response.text()
|
||||
|
||||
|
||||
if response.status != 200:
|
||||
raise Exception(
|
||||
f"文件上传失败,状态码: {response.status}, 响应: {response_text}"
|
||||
)
|
||||
try:
|
||||
result = await response.json()
|
||||
except json.JSONDecodeError:
|
||||
raise Exception(f"文件上传响应解析失败: {response_text}")
|
||||
|
||||
if result.get("code") != 0:
|
||||
raise Exception(f"文件上传失败: {result.get('msg', '未知错误')}")
|
||||
|
||||
file_id = result["data"]["id"]
|
||||
return file_id
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
raise Exception("文件上传超时")
|
||||
except Exception as e:
|
||||
raise Exception(f"文件上传失败: {str(e)}")
|
||||
|
||||
|
||||
async def chat_messages(
|
||||
self,
|
||||
bot_id: str,
|
||||
user_id: str,
|
||||
additional_messages: List[Dict] | None = None,
|
||||
conversation_id: str | None = None,
|
||||
auto_save_history: bool = True,
|
||||
stream: bool = True,
|
||||
timeout: float = 120,
|
||||
) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
"""发送聊天消息并返回流式响应
|
||||
|
||||
Args:
|
||||
bot_id: Bot ID
|
||||
user_id: 用户ID
|
||||
additional_messages: 额外消息列表
|
||||
conversation_id: 会话ID
|
||||
auto_save_history: 是否自动保存历史
|
||||
stream: 是否流式响应
|
||||
timeout: 超时时间
|
||||
"""
|
||||
session = await self.coze_session()
|
||||
url = f"{self.api_base}/v3/chat"
|
||||
|
||||
payload = {
|
||||
"bot_id": bot_id,
|
||||
"user_id": user_id,
|
||||
"stream": stream,
|
||||
"auto_save_history": auto_save_history,
|
||||
}
|
||||
|
||||
if additional_messages:
|
||||
payload["additional_messages"] = additional_messages
|
||||
|
||||
params = {}
|
||||
if conversation_id:
|
||||
params["conversation_id"] = conversation_id
|
||||
|
||||
|
||||
try:
|
||||
async with session.post(
|
||||
url,
|
||||
json=payload,
|
||||
params=params,
|
||||
timeout=aiohttp.ClientTimeout(total=timeout),
|
||||
) as response:
|
||||
if response.status == 401:
|
||||
raise Exception("Coze API 认证失败,请检查 API Key 是否正确")
|
||||
|
||||
if response.status != 200:
|
||||
raise Exception(f"Coze API 流式请求失败,状态码: {response.status}")
|
||||
|
||||
|
||||
async for chunk in response.content:
|
||||
chunk = chunk.decode("utf-8")
|
||||
if chunk != '\n':
|
||||
if chunk.startswith("event:"):
|
||||
chunk_type = chunk.replace("event:", "", 1).strip()
|
||||
elif chunk.startswith("data:"):
|
||||
chunk_data = chunk.replace("data:", "", 1).strip()
|
||||
else:
|
||||
yield {"event": chunk_type, "data": json.loads(chunk_data)}
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
raise Exception(f"Coze API 流式请求超时 ({timeout}秒)")
|
||||
except Exception as e:
|
||||
raise Exception(f"Coze API 流式请求失败: {str(e)}")
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -110,6 +110,24 @@ class DingTalkClient:
|
||||
else:
|
||||
raise Exception(f'Error: {response.status_code}, {response.text}')
|
||||
|
||||
async def get_file_url(self, download_code: str):
|
||||
if not await self.check_access_token():
|
||||
await self.get_access_token()
|
||||
url = 'https://api.dingtalk.com/v1.0/robot/messageFiles/download'
|
||||
params = {'downloadCode': download_code, 'robotCode': self.robot_code}
|
||||
headers = {'x-acs-dingtalk-access-token': self.access_token}
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(url, headers=headers, json=params)
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
download_url = result.get('downloadUrl')
|
||||
if download_url:
|
||||
return download_url
|
||||
else:
|
||||
await self.logger.error(f'failed to get file: {response.json()}')
|
||||
else:
|
||||
raise Exception(f'Error: {response.status_code}, {response.text}')
|
||||
|
||||
async def update_incoming_message(self, message):
|
||||
"""异步更新 DingTalkClient 中的 incoming_message"""
|
||||
message_data = await self.get_message(message)
|
||||
@@ -189,6 +207,17 @@ class DingTalkClient:
|
||||
message_data['Audio'] = await self.get_audio_url(incoming_message.to_dict()['content']['downloadCode'])
|
||||
|
||||
message_data['Type'] = 'audio'
|
||||
elif incoming_message.message_type == 'file':
|
||||
down_list = incoming_message.get_down_list()
|
||||
if len(down_list) >= 2:
|
||||
message_data['File'] = await self.get_file_url(down_list[0])
|
||||
message_data['Name'] = down_list[1]
|
||||
else:
|
||||
if self.logger:
|
||||
await self.logger.error(f'get_down_list() returned fewer than 2 elements: {down_list}')
|
||||
message_data['File'] = None
|
||||
message_data['Name'] = None
|
||||
message_data['Type'] = 'file'
|
||||
|
||||
copy_message_data = message_data.copy()
|
||||
del copy_message_data['IncomingMessage']
|
||||
|
||||
@@ -31,6 +31,15 @@ class DingTalkEvent(dict):
|
||||
def audio(self):
|
||||
return self.get('Audio', '')
|
||||
|
||||
@property
|
||||
def file(self):
|
||||
return self.get('File', '')
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self.get('Name', '')
|
||||
|
||||
|
||||
@property
|
||||
def conversation(self):
|
||||
return self.get('conversation_type', '')
|
||||
|
||||
13
main.py
13
main.py
@@ -18,7 +18,13 @@ asciiart = r"""
|
||||
|
||||
async def main_entry(loop: asyncio.AbstractEventLoop):
|
||||
parser = argparse.ArgumentParser(description='LangBot')
|
||||
parser.add_argument('--standalone-runtime', action='store_true', help='使用独立插件运行时', default=False)
|
||||
parser.add_argument(
|
||||
'--standalone-runtime',
|
||||
action='store_true',
|
||||
help='Use standalone plugin runtime / 使用独立插件运行时',
|
||||
default=False,
|
||||
)
|
||||
parser.add_argument('--debug', action='store_true', help='Debug mode / 调试模式', default=False)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.standalone_runtime:
|
||||
@@ -26,6 +32,11 @@ async def main_entry(loop: asyncio.AbstractEventLoop):
|
||||
|
||||
platform.standalone_runtime = True
|
||||
|
||||
if args.debug:
|
||||
from pkg.utils import constants
|
||||
|
||||
constants.debug_mode = True
|
||||
|
||||
print(asciiart)
|
||||
|
||||
import sys
|
||||
|
||||
@@ -15,6 +15,9 @@ class FilesRouterGroup(group.RouterGroup):
|
||||
async def initialize(self) -> None:
|
||||
@self.route('/image/<image_key>', methods=['GET'], auth_type=group.AuthType.NONE)
|
||||
async def _(image_key: str) -> quart.Response:
|
||||
if '/' in image_key or '\\' in image_key:
|
||||
return quart.Response(status=404)
|
||||
|
||||
if not await self.ap.storage_mgr.storage_provider.exists(image_key):
|
||||
return quart.Response(status=404)
|
||||
|
||||
@@ -36,6 +39,10 @@ class FilesRouterGroup(group.RouterGroup):
|
||||
extension = file.filename.split('.')[-1]
|
||||
file_name = file.filename.split('.')[0]
|
||||
|
||||
# check if file name contains '/' or '\'
|
||||
if '/' in file_name or '\\' in file_name:
|
||||
return self.fail(400, 'File name contains invalid characters')
|
||||
|
||||
file_key = file_name + '_' + str(uuid.uuid4())[:8] + '.' + extension
|
||||
# save file to storage
|
||||
await self.ap.storage_mgr.storage_provider.save(file_key, file_bytes)
|
||||
|
||||
229
pkg/api/http/controller/groups/pipelines/websocket.py
Normal file
229
pkg/api/http/controller/groups/pipelines/websocket.py
Normal file
@@ -0,0 +1,229 @@
|
||||
"""流水线调试 WebSocket 路由
|
||||
|
||||
提供基于 WebSocket 的实时双向通信,用于流水线调试。
|
||||
支持 person 和 group 两种会话类型的隔离。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
import quart
|
||||
|
||||
from ... import group
|
||||
from ....service.websocket_pool import WebSocketConnection
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def handle_client_event(connection: WebSocketConnection, message: dict, ap):
|
||||
"""处理客户端发送的事件
|
||||
|
||||
Args:
|
||||
connection: WebSocket 连接对象
|
||||
message: 客户端消息 {'type': 'xxx', 'data': {...}}
|
||||
ap: Application 实例
|
||||
"""
|
||||
event_type = message.get('type')
|
||||
data = message.get('data', {})
|
||||
|
||||
pipeline_uuid = connection.pipeline_uuid
|
||||
session_type = connection.session_type
|
||||
|
||||
try:
|
||||
webchat_adapter = ap.platform_mgr.webchat_proxy_bot.adapter
|
||||
|
||||
if event_type == 'send_message':
|
||||
# 发送消息到指定会话
|
||||
message_chain_obj = data.get('message_chain', [])
|
||||
client_message_id = data.get('client_message_id')
|
||||
|
||||
if not message_chain_obj:
|
||||
await connection.send('error', {'error': 'message_chain is required', 'error_code': 'INVALID_REQUEST'})
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"Received send_message: pipeline={pipeline_uuid}, "
|
||||
f"session={session_type}, "
|
||||
f"client_msg_id={client_message_id}"
|
||||
)
|
||||
|
||||
# 调用 webchat_adapter.send_webchat_message
|
||||
# 消息将通过 reply_message_chunk 自动推送到 WebSocket
|
||||
result = None
|
||||
async for msg in webchat_adapter.send_webchat_message(
|
||||
pipeline_uuid=pipeline_uuid, session_type=session_type, message_chain_obj=message_chain_obj, is_stream=True
|
||||
):
|
||||
result = msg
|
||||
|
||||
# 发送确认
|
||||
if result:
|
||||
await connection.send(
|
||||
'message_sent',
|
||||
{
|
||||
'client_message_id': client_message_id,
|
||||
'server_message_id': result.get('id'),
|
||||
'timestamp': result.get('timestamp'),
|
||||
},
|
||||
)
|
||||
|
||||
elif event_type == 'load_history':
|
||||
# 加载指定会话的历史消息
|
||||
before_message_id = data.get('before_message_id')
|
||||
limit = data.get('limit', 50)
|
||||
|
||||
logger.info(f"Loading history: pipeline={pipeline_uuid}, session={session_type}, limit={limit}")
|
||||
|
||||
# 从对应会话获取历史消息
|
||||
messages = webchat_adapter.get_webchat_messages(pipeline_uuid, session_type)
|
||||
|
||||
# 简单分页:返回最后 limit 条
|
||||
if before_message_id:
|
||||
# TODO: 实现基于 message_id 的分页
|
||||
history_messages = messages[-limit:]
|
||||
else:
|
||||
history_messages = messages[-limit:] if len(messages) > limit else messages
|
||||
|
||||
await connection.send(
|
||||
'history', {'messages': history_messages, 'has_more': len(messages) > len(history_messages)}
|
||||
)
|
||||
|
||||
elif event_type == 'interrupt':
|
||||
# 中断消息
|
||||
message_id = data.get('message_id')
|
||||
logger.info(f"Interrupt requested: message_id={message_id}")
|
||||
|
||||
# TODO: 实现中断逻辑
|
||||
await connection.send('interrupted', {'message_id': message_id, 'partial_content': ''})
|
||||
|
||||
elif event_type == 'ping':
|
||||
# 心跳
|
||||
connection.last_ping = datetime.now()
|
||||
await connection.send('pong', {'timestamp': data.get('timestamp')})
|
||||
|
||||
else:
|
||||
logger.warning(f"Unknown event type: {event_type}")
|
||||
await connection.send('error', {'error': f'Unknown event type: {event_type}', 'error_code': 'UNKNOWN_EVENT'})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling event {event_type}: {e}", exc_info=True)
|
||||
await connection.send(
|
||||
'error',
|
||||
{'error': f'Internal server error: {str(e)}', 'error_code': 'INTERNAL_ERROR', 'details': {'event_type': event_type}},
|
||||
)
|
||||
|
||||
|
||||
@group.group_class('pipeline-websocket', '/api/v1/pipelines/<pipeline_uuid>/chat')
|
||||
class PipelineWebSocketRouterGroup(group.RouterGroup):
|
||||
"""流水线调试 WebSocket 路由组"""
|
||||
|
||||
async def initialize(self) -> None:
|
||||
@self.route('/ws')
|
||||
async def websocket_handler(pipeline_uuid: str):
|
||||
"""WebSocket 连接处理 - 会话隔离
|
||||
|
||||
连接流程:
|
||||
1. 客户端建立 WebSocket 连接
|
||||
2. 客户端发送 connect 事件(携带 session_type 和 token)
|
||||
3. 服务端验证并创建连接对象
|
||||
4. 进入消息循环,处理客户端事件
|
||||
5. 断开时清理连接
|
||||
|
||||
Args:
|
||||
pipeline_uuid: 流水线 UUID
|
||||
"""
|
||||
websocket = quart.websocket._get_current_object()
|
||||
connection_id = str(uuid.uuid4())
|
||||
session_key = None
|
||||
connection = None
|
||||
|
||||
try:
|
||||
# 1. 等待客户端发送 connect 事件
|
||||
first_message = await websocket.receive_json()
|
||||
|
||||
if first_message.get('type') != 'connect':
|
||||
await websocket.send_json(
|
||||
{'type': 'error', 'data': {'error': 'First message must be connect event', 'error_code': 'INVALID_HANDSHAKE'}}
|
||||
)
|
||||
await websocket.close(1008)
|
||||
return
|
||||
|
||||
connect_data = first_message.get('data', {})
|
||||
session_type = connect_data.get('session_type')
|
||||
token = connect_data.get('token')
|
||||
|
||||
# 验证参数
|
||||
if session_type not in ['person', 'group']:
|
||||
await websocket.send_json(
|
||||
{'type': 'error', 'data': {'error': 'session_type must be person or group', 'error_code': 'INVALID_SESSION_TYPE'}}
|
||||
)
|
||||
await websocket.close(1008)
|
||||
return
|
||||
|
||||
# 验证 token
|
||||
if not token:
|
||||
await websocket.send_json(
|
||||
{'type': 'error', 'data': {'error': 'token is required', 'error_code': 'MISSING_TOKEN'}}
|
||||
)
|
||||
await websocket.close(1008)
|
||||
return
|
||||
|
||||
# 验证用户身份
|
||||
try:
|
||||
user = await self.ap.user_service.verify_token(token)
|
||||
if not user:
|
||||
await websocket.send_json({'type': 'error', 'data': {'error': 'Unauthorized', 'error_code': 'UNAUTHORIZED'}})
|
||||
await websocket.close(1008)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(f"Token verification failed: {e}")
|
||||
await websocket.send_json(
|
||||
{'type': 'error', 'data': {'error': 'Token verification failed', 'error_code': 'AUTH_ERROR'}}
|
||||
)
|
||||
await websocket.close(1008)
|
||||
return
|
||||
|
||||
# 2. 创建连接对象并加入连接池
|
||||
connection = WebSocketConnection(
|
||||
connection_id=connection_id,
|
||||
websocket=websocket,
|
||||
pipeline_uuid=pipeline_uuid,
|
||||
session_type=session_type,
|
||||
created_at=datetime.now(),
|
||||
last_ping=datetime.now(),
|
||||
)
|
||||
|
||||
session_key = connection.session_key
|
||||
ws_pool = self.ap.ws_pool
|
||||
ws_pool.add_connection(connection)
|
||||
|
||||
# 3. 发送连接成功事件
|
||||
await connection.send(
|
||||
'connected', {'connection_id': connection_id, 'session_type': session_type, 'pipeline_uuid': pipeline_uuid}
|
||||
)
|
||||
|
||||
logger.info(f"WebSocket connected: {connection_id} [pipeline={pipeline_uuid}, session={session_type}]")
|
||||
|
||||
# 4. 进入消息处理循环
|
||||
while True:
|
||||
try:
|
||||
message = await websocket.receive_json()
|
||||
await handle_client_event(connection, message, self.ap)
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"WebSocket connection cancelled: {connection_id}")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error receiving message from {connection_id}: {e}")
|
||||
break
|
||||
|
||||
except quart.exceptions.WebsocketDisconnected:
|
||||
logger.info(f"WebSocket disconnected: {connection_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket error for {connection_id}: {e}", exc_info=True)
|
||||
finally:
|
||||
# 清理连接
|
||||
if connection and session_key:
|
||||
ws_pool = self.ap.ws_pool
|
||||
await ws_pool.remove_connection(connection_id, session_key)
|
||||
logger.info(f"WebSocket connection cleaned up: {connection_id}")
|
||||
@@ -128,10 +128,8 @@ class PluginsRouterGroup(group.RouterGroup):
|
||||
|
||||
file_bytes = file.read()
|
||||
|
||||
file_base64 = base64.b64encode(file_bytes).decode('utf-8')
|
||||
|
||||
data = {
|
||||
'plugin_file': file_base64,
|
||||
'plugin_file': file_bytes,
|
||||
}
|
||||
|
||||
ctx = taskmgr.TaskContext.new()
|
||||
|
||||
@@ -91,3 +91,26 @@ class SystemRouterGroup(group.RouterGroup):
|
||||
)
|
||||
|
||||
return self.success(data=resp)
|
||||
|
||||
@self.route(
|
||||
'/status/plugin-system',
|
||||
methods=['GET'],
|
||||
auth_type=group.AuthType.USER_TOKEN,
|
||||
)
|
||||
async def _() -> str:
|
||||
plugin_connector_error = 'ok'
|
||||
is_connected = True
|
||||
|
||||
try:
|
||||
await self.ap.plugin_connector.ping_plugin_runtime()
|
||||
except Exception as e:
|
||||
plugin_connector_error = str(e)
|
||||
is_connected = False
|
||||
|
||||
return self.success(
|
||||
data={
|
||||
'is_enable': self.ap.plugin_connector.is_enable_plugin,
|
||||
'is_connected': is_connected,
|
||||
'plugin_connector_error': plugin_connector_error,
|
||||
}
|
||||
)
|
||||
|
||||
211
pkg/api/http/service/websocket_pool.py
Normal file
211
pkg/api/http/service/websocket_pool.py
Normal file
@@ -0,0 +1,211 @@
|
||||
"""WebSocket 连接池管理
|
||||
|
||||
用于管理流水线调试的 WebSocket 连接,支持会话隔离和消息广播。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
import quart
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class WebSocketConnection:
|
||||
"""单个 WebSocket 连接"""
|
||||
|
||||
connection_id: str
|
||||
websocket: quart.websocket.WebSocket
|
||||
pipeline_uuid: str
|
||||
session_type: str # 'person' 或 'group'
|
||||
created_at: datetime
|
||||
last_ping: datetime
|
||||
|
||||
@property
|
||||
def session_key(self) -> str:
|
||||
"""会话唯一标识: pipeline_uuid:session_type"""
|
||||
return f"{self.pipeline_uuid}:{self.session_type}"
|
||||
|
||||
async def send(self, event_type: str, data: dict):
|
||||
"""发送事件到客户端
|
||||
|
||||
Args:
|
||||
event_type: 事件类型
|
||||
data: 事件数据
|
||||
"""
|
||||
try:
|
||||
await self.websocket.send_json({"type": event_type, "data": data})
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send message to {self.connection_id}: {e}")
|
||||
raise
|
||||
|
||||
|
||||
class WebSocketConnectionPool:
|
||||
"""WebSocket 连接池 - 按会话隔离
|
||||
|
||||
连接池结构:
|
||||
connections[session_key][connection_id] = WebSocketConnection
|
||||
其中 session_key = f"{pipeline_uuid}:{session_type}"
|
||||
|
||||
这样可以确保:
|
||||
- person 和 group 会话完全隔离
|
||||
- 不同 pipeline 的会话隔离
|
||||
- 同一会话的多个连接可以同步接收消息(多标签页)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.connections: dict[str, dict[str, WebSocketConnection]] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
def add_connection(self, conn: WebSocketConnection):
|
||||
"""添加连接到指定会话
|
||||
|
||||
Args:
|
||||
conn: WebSocket 连接对象
|
||||
"""
|
||||
session_key = conn.session_key
|
||||
|
||||
if session_key not in self.connections:
|
||||
self.connections[session_key] = {}
|
||||
|
||||
self.connections[session_key][conn.connection_id] = conn
|
||||
|
||||
logger.info(
|
||||
f"WebSocket connection added: {conn.connection_id} "
|
||||
f"to session {session_key} "
|
||||
f"(total: {len(self.connections[session_key])} connections)"
|
||||
)
|
||||
|
||||
async def remove_connection(self, connection_id: str, session_key: str):
|
||||
"""从指定会话移除连接
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID
|
||||
session_key: 会话标识
|
||||
"""
|
||||
async with self._lock:
|
||||
if session_key in self.connections:
|
||||
conn = self.connections[session_key].pop(connection_id, None)
|
||||
|
||||
# 如果该会话没有连接了,清理会话
|
||||
if not self.connections[session_key]:
|
||||
del self.connections[session_key]
|
||||
|
||||
if conn:
|
||||
logger.info(
|
||||
f"WebSocket connection removed: {connection_id} "
|
||||
f"from session {session_key} "
|
||||
f"(remaining: {len(self.connections.get(session_key, {}))} connections)"
|
||||
)
|
||||
|
||||
def get_connection(self, connection_id: str, session_key: str) -> Optional[WebSocketConnection]:
|
||||
"""获取指定连接
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID
|
||||
session_key: 会话标识
|
||||
|
||||
Returns:
|
||||
WebSocketConnection 或 None
|
||||
"""
|
||||
return self.connections.get(session_key, {}).get(connection_id)
|
||||
|
||||
def get_connections_by_session(self, pipeline_uuid: str, session_type: str) -> list[WebSocketConnection]:
|
||||
"""获取指定会话的所有连接
|
||||
|
||||
Args:
|
||||
pipeline_uuid: 流水线 UUID
|
||||
session_type: 会话类型 ('person' 或 'group')
|
||||
|
||||
Returns:
|
||||
连接列表
|
||||
"""
|
||||
session_key = f"{pipeline_uuid}:{session_type}"
|
||||
return list(self.connections.get(session_key, {}).values())
|
||||
|
||||
async def broadcast_to_session(self, pipeline_uuid: str, session_type: str, event_type: str, data: dict):
|
||||
"""广播消息到指定会话的所有连接
|
||||
|
||||
Args:
|
||||
pipeline_uuid: 流水线 UUID
|
||||
session_type: 会话类型 ('person' 或 'group')
|
||||
event_type: 事件类型
|
||||
data: 事件数据
|
||||
"""
|
||||
connections = self.get_connections_by_session(pipeline_uuid, session_type)
|
||||
|
||||
if not connections:
|
||||
logger.debug(f"No connections for session {pipeline_uuid}:{session_type}, skipping broadcast")
|
||||
return
|
||||
|
||||
logger.debug(
|
||||
f"Broadcasting {event_type} to session {pipeline_uuid}:{session_type}, " f"{len(connections)} connections"
|
||||
)
|
||||
|
||||
# 并发发送到所有连接,忽略失败的连接
|
||||
results = await asyncio.gather(*[conn.send(event_type, data) for conn in connections], return_exceptions=True)
|
||||
|
||||
# 统计失败的连接
|
||||
failed_count = sum(1 for result in results if isinstance(result, Exception))
|
||||
if failed_count > 0:
|
||||
logger.warning(f"Failed to send to {failed_count}/{len(connections)} connections")
|
||||
|
||||
def get_all_sessions(self) -> list[str]:
|
||||
"""获取所有活跃会话的 session_key 列表
|
||||
|
||||
Returns:
|
||||
会话标识列表
|
||||
"""
|
||||
return list(self.connections.keys())
|
||||
|
||||
def get_connection_count(self, pipeline_uuid: str, session_type: str) -> int:
|
||||
"""获取指定会话的连接数量
|
||||
|
||||
Args:
|
||||
pipeline_uuid: 流水线 UUID
|
||||
session_type: 会话类型
|
||||
|
||||
Returns:
|
||||
连接数量
|
||||
"""
|
||||
session_key = f"{pipeline_uuid}:{session_type}"
|
||||
return len(self.connections.get(session_key, {}))
|
||||
|
||||
async def cleanup_stale_connections(self, timeout_seconds: int = 120):
|
||||
"""清理超时的连接
|
||||
|
||||
Args:
|
||||
timeout_seconds: 超时时间(秒)
|
||||
"""
|
||||
now = datetime.now()
|
||||
stale_connections = []
|
||||
|
||||
# 查找超时连接
|
||||
for session_key, session_conns in self.connections.items():
|
||||
for conn_id, conn in session_conns.items():
|
||||
elapsed = (now - conn.last_ping).total_seconds()
|
||||
if elapsed > timeout_seconds:
|
||||
stale_connections.append((conn_id, session_key))
|
||||
|
||||
# 移除超时连接
|
||||
for conn_id, session_key in stale_connections:
|
||||
logger.warning(f"Removing stale connection: {conn_id} from {session_key}")
|
||||
await self.remove_connection(conn_id, session_key)
|
||||
|
||||
# 尝试关闭 WebSocket
|
||||
try:
|
||||
conn = self.get_connection(conn_id, session_key)
|
||||
if conn:
|
||||
await conn.websocket.close(1000, "Connection timeout")
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing stale connection {conn_id}: {e}")
|
||||
|
||||
if stale_connections:
|
||||
logger.info(f"Cleaned up {len(stale_connections)} stale connections")
|
||||
@@ -105,6 +105,10 @@ class Application:
|
||||
|
||||
storage_mgr: storagemgr.StorageMgr = None
|
||||
|
||||
# ========= WebSocket =========
|
||||
|
||||
ws_pool = None # WebSocketConnectionPool
|
||||
|
||||
# ========= HTTP Services =========
|
||||
|
||||
user_service: user_service.UserService = None
|
||||
|
||||
@@ -19,6 +19,7 @@ from ...api.http.service import model as model_service
|
||||
from ...api.http.service import pipeline as pipeline_service
|
||||
from ...api.http.service import bot as bot_service
|
||||
from ...api.http.service import knowledge as knowledge_service
|
||||
from ...api.http.service import websocket_pool
|
||||
from ...discover import engine as discover_engine
|
||||
from ...storage import mgr as storagemgr
|
||||
from ...utils import logcache
|
||||
@@ -87,10 +88,18 @@ class BuildAppStage(stage.BootingStage):
|
||||
await llm_tool_mgr_inst.initialize()
|
||||
ap.tool_mgr = llm_tool_mgr_inst
|
||||
|
||||
# Initialize WebSocket connection pool
|
||||
ws_pool_inst = websocket_pool.WebSocketConnectionPool()
|
||||
ap.ws_pool = ws_pool_inst
|
||||
|
||||
im_mgr_inst = im_mgr.PlatformManager(ap=ap)
|
||||
await im_mgr_inst.initialize()
|
||||
ap.platform_mgr = im_mgr_inst
|
||||
|
||||
# Inject WebSocket pool into WebChatAdapter
|
||||
if hasattr(ap.platform_mgr, 'webchat_proxy_bot') and ap.platform_mgr.webchat_proxy_bot:
|
||||
ap.platform_mgr.webchat_proxy_bot.adapter.set_ws_pool(ws_pool_inst)
|
||||
|
||||
pipeline_mgr = pipelinemgr.PipelineManager(ap)
|
||||
await pipeline_mgr.initialize()
|
||||
ap.pipeline_mgr = pipeline_mgr
|
||||
|
||||
@@ -21,10 +21,15 @@ class LongTextProcessStage(stage.PipelineStage):
|
||||
- resp_message_chain
|
||||
"""
|
||||
|
||||
strategy_impl: strategy.LongTextStrategy
|
||||
strategy_impl: strategy.LongTextStrategy | None
|
||||
|
||||
async def initialize(self, pipeline_config: dict):
|
||||
config = pipeline_config['output']['long-text-processing']
|
||||
|
||||
if config['strategy'] == 'none':
|
||||
self.strategy_impl = None
|
||||
return
|
||||
|
||||
if config['strategy'] == 'image':
|
||||
use_font = config['font-path']
|
||||
try:
|
||||
@@ -67,6 +72,10 @@ class LongTextProcessStage(stage.PipelineStage):
|
||||
await self.strategy_impl.initialize()
|
||||
|
||||
async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||
if self.strategy_impl is None:
|
||||
self.ap.logger.debug('Long message processing strategy is not set, skip long message processing.')
|
||||
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
|
||||
# 检查是否包含非 Plain 组件
|
||||
contains_non_plain = False
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ class ForwardComponentStrategy(strategy_model.LongTextStrategy):
|
||||
platform_message.ForwardMessageNode(
|
||||
sender_id=query.adapter.bot_account_id,
|
||||
sender_name='User',
|
||||
message_chain=platform_message.MessageChain([message]),
|
||||
message_chain=platform_message.MessageChain([platform_message.Plain(text=message)]),
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@@ -96,7 +96,7 @@ class RuntimePipeline:
|
||||
if query.pipeline_config['output']['misc']['at-sender'] and isinstance(
|
||||
query.message_event, platform_events.GroupMessage
|
||||
):
|
||||
result.user_notice.insert(0, platform_message.At(query.message_event.sender.id))
|
||||
result.user_notice.insert(0, platform_message.At(target=query.message_event.sender.id))
|
||||
if await query.adapter.is_stream_output_supported():
|
||||
await query.adapter.reply_message_chunk(
|
||||
message_source=query.message_event,
|
||||
@@ -213,7 +213,7 @@ class RuntimePipeline:
|
||||
await self._execute_from_stage(0, query)
|
||||
except Exception as e:
|
||||
inst_name = query.current_stage_name if query.current_stage_name else 'unknown'
|
||||
self.ap.logger.error(f'处理请求时出错 query_id={query.query_id} stage={inst_name} : {e}')
|
||||
self.ap.logger.error(f'Error processing query {query.query_id} stage={inst_name} : {e}')
|
||||
self.ap.logger.error(f'Traceback: {traceback.format_exc()}')
|
||||
finally:
|
||||
self.ap.logger.debug(f'Query {query.query_id} processed')
|
||||
|
||||
@@ -35,11 +35,17 @@ class PreProcessor(stage.PipelineStage):
|
||||
session = await self.ap.sess_mgr.get_session(query)
|
||||
|
||||
# When not local-agent, llm_model is None
|
||||
llm_model = (
|
||||
await self.ap.model_mgr.get_model_by_uuid(query.pipeline_config['ai']['local-agent']['model'])
|
||||
if selected_runner == 'local-agent'
|
||||
else None
|
||||
)
|
||||
try:
|
||||
llm_model = (
|
||||
await self.ap.model_mgr.get_model_by_uuid(query.pipeline_config['ai']['local-agent']['model'])
|
||||
if selected_runner == 'local-agent'
|
||||
else None
|
||||
)
|
||||
except ValueError:
|
||||
self.ap.logger.warning(
|
||||
f'LLM model {query.pipeline_config["ai"]["local-agent"]["model"] + " "}not found or not configured'
|
||||
)
|
||||
llm_model = None
|
||||
|
||||
conversation = await self.ap.sess_mgr.get_conversation(
|
||||
query,
|
||||
@@ -54,7 +60,7 @@ class PreProcessor(stage.PipelineStage):
|
||||
query.prompt = conversation.prompt.copy()
|
||||
query.messages = conversation.messages.copy()
|
||||
|
||||
if selected_runner == 'local-agent':
|
||||
if selected_runner == 'local-agent' and llm_model:
|
||||
query.use_funcs = []
|
||||
query.use_llm_model_uuid = llm_model.model_entity.uuid
|
||||
|
||||
@@ -72,7 +78,11 @@ class PreProcessor(stage.PipelineStage):
|
||||
|
||||
# Check if this model supports vision, if not, remove all images
|
||||
# TODO this checking should be performed in runner, and in this stage, the image should be reserved
|
||||
if selected_runner == 'local-agent' and not llm_model.model_entity.abilities.__contains__('vision'):
|
||||
if (
|
||||
selected_runner == 'local-agent'
|
||||
and llm_model
|
||||
and not llm_model.model_entity.abilities.__contains__('vision')
|
||||
):
|
||||
for msg in query.messages:
|
||||
if isinstance(msg.content, list):
|
||||
for me in msg.content:
|
||||
@@ -89,15 +99,22 @@ class PreProcessor(stage.PipelineStage):
|
||||
content_list.append(provider_message.ContentElement.from_text(me.text))
|
||||
plain_text += me.text
|
||||
elif isinstance(me, platform_message.Image):
|
||||
if selected_runner != 'local-agent' or llm_model.model_entity.abilities.__contains__('vision'):
|
||||
if selected_runner != 'local-agent' or (
|
||||
llm_model and llm_model.model_entity.abilities.__contains__('vision')
|
||||
):
|
||||
if me.base64 is not None:
|
||||
content_list.append(provider_message.ContentElement.from_image_base64(me.base64))
|
||||
elif isinstance(me, platform_message.File):
|
||||
# if me.url is not None:
|
||||
content_list.append(provider_message.ContentElement.from_file_url(me.url, me.name))
|
||||
elif isinstance(me, platform_message.Quote) and qoute_msg:
|
||||
for msg in me.origin:
|
||||
if isinstance(msg, platform_message.Plain):
|
||||
content_list.append(provider_message.ContentElement.from_text(msg.text))
|
||||
elif isinstance(msg, platform_message.Image):
|
||||
if selected_runner != 'local-agent' or llm_model.model_entity.abilities.__contains__('vision'):
|
||||
if selected_runner != 'local-agent' or (
|
||||
llm_model and llm_model.model_entity.abilities.__contains__('vision')
|
||||
):
|
||||
if msg.base64 is not None:
|
||||
content_list.append(provider_message.ContentElement.from_image_base64(msg.base64))
|
||||
|
||||
|
||||
@@ -9,7 +9,6 @@ from .. import handler
|
||||
from ... import entities
|
||||
from ....provider import runner as runner_module
|
||||
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.events as events
|
||||
from ....utils import importutil
|
||||
from ....provider import runners
|
||||
@@ -47,18 +46,19 @@ class ChatMessageHandler(handler.MessageHandler):
|
||||
event_ctx = await self.ap.plugin_connector.emit_event(event)
|
||||
|
||||
is_create_card = False # 判断下是否需要创建流式卡片
|
||||
|
||||
if event_ctx.is_prevented_default():
|
||||
if event_ctx.event.reply is not None:
|
||||
mc = platform_message.MessageChain(event_ctx.event.reply)
|
||||
if event_ctx.event.reply_message_chain is not None:
|
||||
mc = event_ctx.event.reply_message_chain
|
||||
query.resp_messages.append(mc)
|
||||
|
||||
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
else:
|
||||
yield entities.StageProcessResult(result_type=entities.ResultType.INTERRUPT, new_query=query)
|
||||
else:
|
||||
if event_ctx.event.alter is not None:
|
||||
if event_ctx.event.user_message_alter is not None:
|
||||
# if isinstance(event_ctx.event, str): # 现在暂时不考虑多模态alter
|
||||
query.user_message.content = event_ctx.event.alter
|
||||
query.user_message.content = event_ctx.event.user_message_alter
|
||||
|
||||
text_length = 0
|
||||
try:
|
||||
|
||||
@@ -5,7 +5,6 @@ import typing
|
||||
from .. import handler
|
||||
from ... import entities
|
||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
import langbot_plugin.api.entities.events as events
|
||||
@@ -49,8 +48,8 @@ class CommandHandler(handler.MessageHandler):
|
||||
event_ctx = await self.ap.plugin_connector.emit_event(event)
|
||||
|
||||
if event_ctx.is_prevented_default():
|
||||
if event_ctx.event.reply is not None:
|
||||
mc = platform_message.MessageChain(event_ctx.event.reply)
|
||||
if event_ctx.event.reply_message_chain is not None:
|
||||
mc = event_ctx.event.reply_message_chain
|
||||
|
||||
query.resp_messages.append(mc)
|
||||
|
||||
@@ -59,9 +58,6 @@ class CommandHandler(handler.MessageHandler):
|
||||
yield entities.StageProcessResult(result_type=entities.ResultType.INTERRUPT, new_query=query)
|
||||
|
||||
else:
|
||||
if event_ctx.event.alter is not None:
|
||||
query.message_chain = platform_message.MessageChain([platform_message.Plain(event_ctx.event.alter)])
|
||||
|
||||
session = await self.ap.sess_mgr.get_session(query)
|
||||
|
||||
async for ret in self.ap.cmd_mgr.execute(
|
||||
@@ -78,7 +74,12 @@ class CommandHandler(handler.MessageHandler):
|
||||
self.ap.logger.info(f'Command({query.query_id}) error: {self.cut_str(str(ret.error))}')
|
||||
|
||||
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
elif ret.text is not None or ret.image_url is not None or ret.image_base64 is not None:
|
||||
elif (
|
||||
ret.text is not None
|
||||
or ret.image_url is not None
|
||||
or ret.image_base64 is not None
|
||||
or ret.file_url is not None
|
||||
):
|
||||
content: list[provider_message.ContentElement] = []
|
||||
|
||||
if ret.text is not None:
|
||||
@@ -90,6 +91,9 @@ class CommandHandler(handler.MessageHandler):
|
||||
if ret.image_base64 is not None:
|
||||
content.append(provider_message.ContentElement.from_image_base64(ret.image_base64))
|
||||
|
||||
if ret.file_url is not None:
|
||||
# 此时为 file 类型
|
||||
content.append(provider_message.ContentElement.from_file_url(ret.file_url, ret.file_name))
|
||||
query.resp_messages.append(
|
||||
provider_message.Message(
|
||||
role='command',
|
||||
|
||||
@@ -33,7 +33,7 @@ class SendResponseBackStage(stage.PipelineStage):
|
||||
if query.pipeline_config['output']['misc']['at-sender'] and isinstance(
|
||||
query.message_event, platform_events.GroupMessage
|
||||
):
|
||||
query.resp_message_chain[-1].insert(0, platform_message.At(query.message_event.sender.id))
|
||||
query.resp_message_chain[-1].insert(0, platform_message.At(target=query.message_event.sender.id))
|
||||
|
||||
quote_origin = query.pipeline_config['output']['misc']['quote-origin']
|
||||
|
||||
|
||||
@@ -16,26 +16,17 @@ class AtBotRule(rule_model.GroupRespondRule):
|
||||
rule_dict: dict,
|
||||
query: pipeline_query.Query,
|
||||
) -> entities.RuleJudgeResult:
|
||||
found = False
|
||||
|
||||
def remove_at(message_chain: platform_message.MessageChain):
|
||||
nonlocal found
|
||||
for component in message_chain.root:
|
||||
if isinstance(component, platform_message.At) and component.target == query.adapter.bot_account_id:
|
||||
if isinstance(component, platform_message.At) and str(component.target) == str(query.adapter.bot_account_id):
|
||||
message_chain.remove(component)
|
||||
found = True
|
||||
break
|
||||
|
||||
remove_at(message_chain)
|
||||
remove_at(message_chain) # 回复消息时会at两次,检查并删除重复的
|
||||
|
||||
# if message_chain.has(platform_message.At(query.adapter.bot_account_id)) and rule_dict['at']:
|
||||
# message_chain.remove(platform_message.At(query.adapter.bot_account_id))
|
||||
|
||||
# if message_chain.has(
|
||||
# platform_message.At(query.adapter.bot_account_id)
|
||||
# ): # 回复消息时会at两次,检查并删除重复的
|
||||
# message_chain.remove(platform_message.At(query.adapter.bot_account_id))
|
||||
|
||||
# return entities.RuleJudgeResult(
|
||||
# matching=True,
|
||||
# replacement=message_chain,
|
||||
# )
|
||||
|
||||
return entities.RuleJudgeResult(matching=False, replacement=message_chain)
|
||||
return entities.RuleJudgeResult(matching=found, replacement=message_chain)
|
||||
|
||||
@@ -80,8 +80,8 @@ class ResponseWrapper(stage.PipelineStage):
|
||||
new_query=query,
|
||||
)
|
||||
else:
|
||||
if event_ctx.event.reply is not None:
|
||||
query.resp_message_chain.append(platform_message.MessageChain(event_ctx.event.reply))
|
||||
if event_ctx.event.reply_message_chain is not None:
|
||||
query.resp_message_chain.append(event_ctx.event.reply_message_chain)
|
||||
|
||||
else:
|
||||
query.resp_message_chain.append(result.get_content_platform_message_chain())
|
||||
@@ -123,10 +123,8 @@ class ResponseWrapper(stage.PipelineStage):
|
||||
new_query=query,
|
||||
)
|
||||
else:
|
||||
if event_ctx.event.reply is not None:
|
||||
query.resp_message_chain.append(
|
||||
platform_message.MessageChain(text=event_ctx.event.reply)
|
||||
)
|
||||
if event_ctx.event.reply_message_chain is not None:
|
||||
query.resp_message_chain.append(event_ctx.event.reply_message_chain)
|
||||
|
||||
else:
|
||||
query.resp_message_chain.append(
|
||||
|
||||
@@ -41,6 +41,8 @@ class DingTalkMessageConverter(abstract_platform_adapter.AbstractMessageConverte
|
||||
yiri_msg_list.append(platform_message.Plain(text=text_content))
|
||||
if event.picture:
|
||||
yiri_msg_list.append(platform_message.Image(base64=event.picture))
|
||||
if event.file:
|
||||
yiri_msg_list.append(platform_message.File(url=event.file, name=event.name))
|
||||
if event.audio:
|
||||
yiri_msg_list.append(platform_message.Voice(base64=event.audio))
|
||||
|
||||
|
||||
@@ -139,19 +139,15 @@ class QQOfficialAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter
|
||||
event_converter: QQOfficialEventConverter = QQOfficialEventConverter()
|
||||
|
||||
def __init__(self, config: dict, logger: EventLogger):
|
||||
self.config = config
|
||||
self.logger = logger
|
||||
bot = QQOfficialClient(
|
||||
app_id=config['appid'], secret=config['secret'], token=config['token'], logger=logger
|
||||
)
|
||||
|
||||
required_keys = [
|
||||
'appid',
|
||||
'secret',
|
||||
]
|
||||
missing_keys = [key for key in required_keys if key not in config]
|
||||
if missing_keys:
|
||||
raise command_errors.ParamNotEnoughError('QQ官方机器人缺少相关配置项,请查看文档或联系管理员')
|
||||
|
||||
self.bot = QQOfficialClient(
|
||||
app_id=config['appid'], secret=config['secret'], token=config['token'], logger=self.logger
|
||||
super().__init__(
|
||||
config=config,
|
||||
logger=logger,
|
||||
bot=bot,
|
||||
bot_account_id=config['appid'],
|
||||
)
|
||||
|
||||
async def reply_message(
|
||||
|
||||
@@ -102,7 +102,7 @@ class TelegramEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||
sender=platform_entities.Friend(
|
||||
id=event.effective_chat.id,
|
||||
nickname=event.effective_chat.first_name,
|
||||
remark=event.effective_chat.id,
|
||||
remark=str(event.effective_chat.id),
|
||||
),
|
||||
message_chain=lb_message,
|
||||
time=event.message.date.timestamp(),
|
||||
|
||||
@@ -58,6 +58,7 @@ class WebChatAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
debug_messages: dict[str, list[dict]] = pydantic.Field(default_factory=dict, exclude=True)
|
||||
|
||||
ap: app.Application = pydantic.Field(exclude=True)
|
||||
ws_pool: typing.Any = pydantic.Field(exclude=True, default=None) # WebSocketConnectionPool
|
||||
|
||||
def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger, **kwargs):
|
||||
super().__init__(
|
||||
@@ -72,6 +73,15 @@ class WebChatAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
self.bot_account_id = 'webchatbot'
|
||||
|
||||
self.debug_messages = {}
|
||||
self.ws_pool = None
|
||||
|
||||
def set_ws_pool(self, ws_pool):
|
||||
"""设置 WebSocket 连接池
|
||||
|
||||
Args:
|
||||
ws_pool: WebSocketConnectionPool 实例
|
||||
"""
|
||||
self.ws_pool = ws_pool
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
@@ -130,7 +140,7 @@ class WebChatAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
quote_origin: bool = False,
|
||||
is_final: bool = False,
|
||||
) -> dict:
|
||||
"""回复消息"""
|
||||
"""Reply message chunk - supports both SSE (legacy) and WebSocket"""
|
||||
message_data = WebChatMessage(
|
||||
id=-1,
|
||||
role='assistant',
|
||||
@@ -139,24 +149,32 @@ class WebChatAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
timestamp=datetime.now().isoformat(),
|
||||
)
|
||||
|
||||
# notify waiter
|
||||
session = (
|
||||
self.webchat_group_session
|
||||
if isinstance(message_source, platform_events.GroupMessage)
|
||||
else self.webchat_person_session
|
||||
)
|
||||
if message_source.message_chain.message_id not in session.resp_waiters:
|
||||
# session.resp_waiters[message_source.message_chain.message_id] = asyncio.Queue()
|
||||
queue = session.resp_queues[message_source.message_chain.message_id]
|
||||
# Determine session type
|
||||
if isinstance(message_source, platform_events.GroupMessage):
|
||||
session_type = 'group'
|
||||
session = self.webchat_group_session
|
||||
else: # FriendMessage
|
||||
session_type = 'person'
|
||||
session = self.webchat_person_session
|
||||
|
||||
# if isinstance(message_source, platform_events.FriendMessage):
|
||||
# queue = self.webchat_person_session.resp_queues[message_source.message_chain.message_id]
|
||||
# elif isinstance(message_source, platform_events.GroupMessage):
|
||||
# queue = self.webchat_group_session.resp_queues[message_source.message_chain.message_id]
|
||||
if is_final and bot_message.tool_calls is None:
|
||||
message_data.is_final = True
|
||||
# print(message_data)
|
||||
await queue.put(message_data)
|
||||
# Legacy SSE support: put message into queue
|
||||
if message_source.message_chain.message_id in session.resp_queues:
|
||||
queue = session.resp_queues[message_source.message_chain.message_id]
|
||||
if is_final and bot_message.tool_calls is None:
|
||||
message_data.is_final = True
|
||||
await queue.put(message_data)
|
||||
|
||||
# WebSocket support: broadcast to all connections
|
||||
if self.ws_pool:
|
||||
pipeline_uuid = self.ap.platform_mgr.webchat_proxy_bot.bot_entity.use_pipeline_uuid
|
||||
|
||||
# Determine event type
|
||||
event_type = 'message_complete' if (is_final and bot_message.tool_calls is None) else 'message_chunk'
|
||||
|
||||
# Broadcast to specified session only
|
||||
await self.ws_pool.broadcast_to_session(
|
||||
pipeline_uuid=pipeline_uuid, session_type=session_type, event_type=event_type, data=message_data.model_dump()
|
||||
)
|
||||
|
||||
return message_data.model_dump()
|
||||
|
||||
|
||||
@@ -139,7 +139,7 @@ class WeChatPadMessageConverter(abstract_platform_adapter.AbstractMessageConvert
|
||||
pattern = r'@\S{1,20}'
|
||||
content_no_preifx = re.sub(pattern, '', content_no_preifx)
|
||||
|
||||
return platform_message.MessageChain([platform_message.Plain(content_no_preifx)])
|
||||
return platform_message.MessageChain([platform_message.Plain(text=content_no_preifx)])
|
||||
|
||||
async def _handler_image(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain:
|
||||
"""处理图像消息 (msg_type=3)"""
|
||||
@@ -265,7 +265,7 @@ class WeChatPadMessageConverter(abstract_platform_adapter.AbstractMessageConvert
|
||||
# 文本消息
|
||||
try:
|
||||
if '<msg>' not in quote_data:
|
||||
quote_data_message_list.append(platform_message.Plain(quote_data))
|
||||
quote_data_message_list.append(platform_message.Plain(text=quote_data))
|
||||
else:
|
||||
# 引用消息展开
|
||||
quote_data_xml = ET.fromstring(quote_data)
|
||||
@@ -280,7 +280,7 @@ class WeChatPadMessageConverter(abstract_platform_adapter.AbstractMessageConvert
|
||||
quote_data_message_list.extend(await self._handler_compound(None, quote_data))
|
||||
except Exception as e:
|
||||
self.logger.error(f'处理引用消息异常 expcetion:{e}')
|
||||
quote_data_message_list.append(platform_message.Plain(quote_data))
|
||||
quote_data_message_list.append(platform_message.Plain(text=quote_data))
|
||||
message_list.append(
|
||||
platform_message.Quote(
|
||||
sender_id=sender_id,
|
||||
@@ -290,7 +290,7 @@ class WeChatPadMessageConverter(abstract_platform_adapter.AbstractMessageConvert
|
||||
if len(user_data) > 0:
|
||||
pattern = r'@\S{1,20}'
|
||||
user_data = re.sub(pattern, '', user_data)
|
||||
message_list.append(platform_message.Plain(user_data))
|
||||
message_list.append(platform_message.Plain(text=user_data))
|
||||
|
||||
return platform_message.MessageChain(message_list)
|
||||
|
||||
@@ -543,7 +543,6 @@ class WeChatPadAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter)
|
||||
] = {}
|
||||
|
||||
def __init__(self, config: dict, logger: EventLogger):
|
||||
|
||||
quart_app = quart.Quart(__name__)
|
||||
|
||||
message_converter = WeChatPadMessageConverter(config, logger)
|
||||
@@ -551,15 +550,14 @@ class WeChatPadAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter)
|
||||
bot = WeChatPadClient(config['wechatpad_url'], config['token'])
|
||||
super().__init__(
|
||||
config=config,
|
||||
logger = logger,
|
||||
quart_app = quart_app,
|
||||
message_converter =message_converter,
|
||||
event_converter = event_converter,
|
||||
logger=logger,
|
||||
quart_app=quart_app,
|
||||
message_converter=message_converter,
|
||||
event_converter=event_converter,
|
||||
listeners={},
|
||||
bot_account_id ='',
|
||||
name="WeChatPad",
|
||||
bot_account_id='',
|
||||
name='WeChatPad',
|
||||
bot=bot,
|
||||
|
||||
)
|
||||
|
||||
async def ws_message(self, data):
|
||||
|
||||
@@ -18,7 +18,7 @@ from langbot_plugin.api.entities import events
|
||||
from langbot_plugin.api.entities import context
|
||||
import langbot_plugin.runtime.io.connection as base_connection
|
||||
from langbot_plugin.api.definition.components.manifest import ComponentManifest
|
||||
from langbot_plugin.api.entities.builtin.command import context as command_context
|
||||
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
|
||||
from langbot_plugin.runtime.plugin.mgr import PluginInstallSource
|
||||
from ..core import taskmgr
|
||||
|
||||
@@ -32,6 +32,8 @@ class PluginRuntimeConnector:
|
||||
|
||||
handler_task: asyncio.Task
|
||||
|
||||
heartbeat_task: asyncio.Task | None = None
|
||||
|
||||
stdio_client_controller: stdio_client_controller.StdioClientController
|
||||
|
||||
ctrl: stdio_client_controller.StdioClientController | ws_client_controller.WebSocketClientController
|
||||
@@ -40,6 +42,9 @@ class PluginRuntimeConnector:
|
||||
[PluginRuntimeConnector], typing.Coroutine[typing.Any, typing.Any, None]
|
||||
]
|
||||
|
||||
is_enable_plugin: bool = True
|
||||
"""Mark if the plugin system is enabled"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ap: app.Application,
|
||||
@@ -49,8 +54,22 @@ class PluginRuntimeConnector:
|
||||
):
|
||||
self.ap = ap
|
||||
self.runtime_disconnect_callback = runtime_disconnect_callback
|
||||
self.is_enable_plugin = self.ap.instance_config.data.get('plugin', {}).get('enable', True)
|
||||
|
||||
async def heartbeat_loop(self):
|
||||
while True:
|
||||
await asyncio.sleep(10)
|
||||
try:
|
||||
await self.ping_plugin_runtime()
|
||||
self.ap.logger.debug('Heartbeat to plugin runtime success.')
|
||||
except Exception as e:
|
||||
self.ap.logger.debug(f'Failed to heartbeat to plugin runtime: {e}')
|
||||
|
||||
async def initialize(self):
|
||||
if not self.is_enable_plugin:
|
||||
self.ap.logger.info('Plugin system is disabled.')
|
||||
return
|
||||
|
||||
async def new_connection_callback(connection: base_connection.Connection):
|
||||
async def disconnect_callback(rchandler: handler.RuntimeConnectionHandler) -> bool:
|
||||
if platform.get_platform() == 'docker' or platform.use_websocket_to_connect_plugin_runtime():
|
||||
@@ -64,6 +83,7 @@ class PluginRuntimeConnector:
|
||||
return False
|
||||
|
||||
self.handler = handler.RuntimeConnectionHandler(connection, disconnect_callback, self.ap)
|
||||
|
||||
self.handler_task = asyncio.create_task(self.handler.run())
|
||||
_ = await self.handler.ping()
|
||||
self.ap.logger.info('Connected to plugin runtime.')
|
||||
@@ -77,8 +97,13 @@ class PluginRuntimeConnector:
|
||||
'runtime_ws_url', 'ws://langbot_plugin_runtime:5400/control/ws'
|
||||
)
|
||||
|
||||
async def make_connection_failed_callback(ctrl: ws_client_controller.WebSocketClientController) -> None:
|
||||
self.ap.logger.error('Failed to connect to plugin runtime, trying to reconnect...')
|
||||
async def make_connection_failed_callback(
|
||||
ctrl: ws_client_controller.WebSocketClientController, exc: Exception = None
|
||||
) -> None:
|
||||
if exc is not None:
|
||||
self.ap.logger.error(f'Failed to connect to plugin runtime({ws_url}): {exc}')
|
||||
else:
|
||||
self.ap.logger.error(f'Failed to connect to plugin runtime({ws_url}), trying to reconnect...')
|
||||
await self.runtime_disconnect_callback(self)
|
||||
|
||||
self.ctrl = ws_client_controller.WebSocketClientController(
|
||||
@@ -98,17 +123,34 @@ class PluginRuntimeConnector:
|
||||
)
|
||||
task = self.ctrl.run(new_connection_callback)
|
||||
|
||||
if self.heartbeat_task is None:
|
||||
self.heartbeat_task = asyncio.create_task(self.heartbeat_loop())
|
||||
|
||||
asyncio.create_task(task)
|
||||
|
||||
async def initialize_plugins(self):
|
||||
pass
|
||||
|
||||
async def ping_plugin_runtime(self):
|
||||
if not hasattr(self, 'handler'):
|
||||
raise Exception('Plugin runtime is not connected')
|
||||
|
||||
return await self.handler.ping()
|
||||
|
||||
async def install_plugin(
|
||||
self,
|
||||
install_source: PluginInstallSource,
|
||||
install_info: dict[str, Any],
|
||||
task_context: taskmgr.TaskContext | None = None,
|
||||
):
|
||||
if install_source == PluginInstallSource.LOCAL:
|
||||
# transfer file before install
|
||||
file_bytes = install_info['plugin_file']
|
||||
file_key = await self.handler.send_file(file_bytes, 'lbpkg')
|
||||
install_info['plugin_file_key'] = file_key
|
||||
del install_info['plugin_file']
|
||||
self.ap.logger.info(f'Transfered file {file_key} to plugin runtime')
|
||||
|
||||
async for ret in self.handler.install_plugin(install_source.value, install_info):
|
||||
current_action = ret.get('current_action', None)
|
||||
if current_action is not None:
|
||||
@@ -149,6 +191,9 @@ class PluginRuntimeConnector:
|
||||
task_context.trace(trace)
|
||||
|
||||
async def list_plugins(self) -> list[dict[str, Any]]:
|
||||
if not self.is_enable_plugin:
|
||||
return []
|
||||
|
||||
return await self.handler.list_plugins()
|
||||
|
||||
async def get_plugin_info(self, author: str, plugin_name: str) -> dict[str, Any]:
|
||||
@@ -167,21 +212,33 @@ class PluginRuntimeConnector:
|
||||
) -> context.EventContext:
|
||||
event_ctx = context.EventContext.from_event(event)
|
||||
|
||||
event_ctx_result = await self.handler.emit_event(event_ctx.model_dump(serialize_as_any=True))
|
||||
if not self.is_enable_plugin:
|
||||
return event_ctx
|
||||
|
||||
event_ctx_result = await self.handler.emit_event(event_ctx.model_dump(serialize_as_any=False))
|
||||
|
||||
event_ctx = context.EventContext.model_validate(event_ctx_result['event_context'])
|
||||
|
||||
return event_ctx
|
||||
|
||||
async def list_tools(self) -> list[ComponentManifest]:
|
||||
if not self.is_enable_plugin:
|
||||
return []
|
||||
|
||||
list_tools_data = await self.handler.list_tools()
|
||||
|
||||
return [ComponentManifest.model_validate(tool) for tool in list_tools_data]
|
||||
|
||||
async def call_tool(self, tool_name: str, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
if not self.is_enable_plugin:
|
||||
return {'error': 'Tool not found: plugin system is disabled'}
|
||||
|
||||
return await self.handler.call_tool(tool_name, parameters)
|
||||
|
||||
async def list_commands(self) -> list[ComponentManifest]:
|
||||
if not self.is_enable_plugin:
|
||||
return []
|
||||
|
||||
list_commands_data = await self.handler.list_commands()
|
||||
|
||||
return [ComponentManifest.model_validate(command) for command in list_commands_data]
|
||||
@@ -189,6 +246,9 @@ class PluginRuntimeConnector:
|
||||
async def execute_command(
|
||||
self, command_ctx: command_context.ExecuteContext
|
||||
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||
if not self.is_enable_plugin:
|
||||
yield command_context.CommandReturn(error=command_errors.CommandNotFoundError(command_ctx.command))
|
||||
|
||||
gen = self.handler.execute_command(command_ctx.model_dump(serialize_as_any=True))
|
||||
|
||||
async for ret in gen:
|
||||
@@ -197,6 +257,10 @@ class PluginRuntimeConnector:
|
||||
yield cmd_ret
|
||||
|
||||
def dispose(self):
|
||||
if isinstance(self.ctrl, stdio_client_controller.StdioClientController):
|
||||
if self.is_enable_plugin and isinstance(self.ctrl, stdio_client_controller.StdioClientController):
|
||||
self.ap.logger.info('Terminating plugin runtime process...')
|
||||
self.ctrl.process.terminate()
|
||||
|
||||
if self.heartbeat_task is not None:
|
||||
self.heartbeat_task.cancel()
|
||||
self.heartbeat_task = None
|
||||
|
||||
@@ -536,7 +536,7 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
{
|
||||
'event_context': event_context,
|
||||
},
|
||||
timeout=30,
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
return result
|
||||
@@ -546,7 +546,7 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
result = await self.call_action(
|
||||
LangBotToRuntimeAction.LIST_TOOLS,
|
||||
{},
|
||||
timeout=10,
|
||||
timeout=20,
|
||||
)
|
||||
|
||||
return result['tools']
|
||||
@@ -560,7 +560,18 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
'plugin_name': plugin_name,
|
||||
},
|
||||
)
|
||||
return result
|
||||
|
||||
plugin_icon_file_key = result['plugin_icon_file_key']
|
||||
mime_type = result['mime_type']
|
||||
|
||||
plugin_icon_bytes = await self.read_local_file(plugin_icon_file_key)
|
||||
|
||||
await self.delete_local_file(plugin_icon_file_key)
|
||||
|
||||
return {
|
||||
'plugin_icon_base64': base64.b64encode(plugin_icon_bytes).decode('utf-8'),
|
||||
'mime_type': mime_type,
|
||||
}
|
||||
|
||||
async def call_tool(self, tool_name: str, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Call tool"""
|
||||
@@ -570,7 +581,7 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
'tool_name': tool_name,
|
||||
'tool_parameters': parameters,
|
||||
},
|
||||
timeout=30,
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
return result['tool_response']
|
||||
@@ -591,7 +602,7 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
{
|
||||
'command_context': command_context,
|
||||
},
|
||||
timeout=30,
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
async for ret in gen:
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import dashscope
|
||||
import openai
|
||||
|
||||
from . import modelscopechatcmpl
|
||||
from .. import requester
|
||||
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
|
||||
|
||||
class BailianChatCompletions(modelscopechatcmpl.ModelScopeChatCompletions):
|
||||
@@ -15,3 +20,211 @@ class BailianChatCompletions(modelscopechatcmpl.ModelScopeChatCompletions):
|
||||
'base_url': 'https://dashscope.aliyuncs.com/compatible-mode/v1',
|
||||
'timeout': 120,
|
||||
}
|
||||
|
||||
async def _closure_stream(
|
||||
self,
|
||||
query: pipeline_query.Query,
|
||||
req_messages: list[dict],
|
||||
use_model: requester.RuntimeLLMModel,
|
||||
use_funcs: list[resource_tool.LLMTool] = None,
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
remove_think: bool = False,
|
||||
) -> provider_message.Message | typing.AsyncGenerator[provider_message.MessageChunk, None]:
|
||||
self.client.api_key = use_model.token_mgr.get_token()
|
||||
|
||||
args = {}
|
||||
args['model'] = use_model.model_entity.name
|
||||
|
||||
if use_funcs:
|
||||
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
|
||||
|
||||
if tools:
|
||||
args['tools'] = tools
|
||||
|
||||
# 设置此次请求中的messages
|
||||
messages = req_messages.copy()
|
||||
|
||||
is_use_dashscope_call = False # 是否使用阿里原生库调用
|
||||
is_enable_multi_model = True # 是否支持多轮对话
|
||||
use_time_num = 0 # 模型已调用次数,防止存在多文件时重复调用
|
||||
use_time_ids = [] # 已调用的ID列表
|
||||
message_id = 0 # 记录消息序号
|
||||
|
||||
for msg in messages:
|
||||
# print(msg)
|
||||
if 'content' in msg and isinstance(msg['content'], list):
|
||||
for me in msg['content']:
|
||||
if me['type'] == 'image_base64':
|
||||
me['image_url'] = {'url': me['image_base64']}
|
||||
me['type'] = 'image_url'
|
||||
del me['image_base64']
|
||||
elif me['type'] == 'file_url' and '.' in me.get('file_name', ''):
|
||||
# 1. 视频文件推理
|
||||
# https://bailian.console.aliyun.com/?tab=doc#/doc/?type=model&url=2845871
|
||||
file_type = me.get('file_name').lower().split('.')[-1]
|
||||
if file_type in ['mp4', 'avi', 'mkv', 'mov', 'flv', 'wmv']:
|
||||
me['type'] = 'video_url'
|
||||
me['video_url'] = {'url': me['file_url']}
|
||||
del me['file_url']
|
||||
del me['file_name']
|
||||
use_time_num +=1
|
||||
use_time_ids.append(message_id)
|
||||
is_enable_multi_model = False
|
||||
# 2. 语音文件识别, 无法通过openai的audio字段传递,暂时不支持
|
||||
# https://bailian.console.aliyun.com/?tab=doc#/doc/?type=model&url=2979031
|
||||
elif file_type in ['aac', 'amr', 'aiff', 'flac', 'm4a',
|
||||
'mp3', 'mpeg', 'ogg', 'opus', 'wav', 'webm', 'wma']:
|
||||
me['audio'] = me['file_url']
|
||||
me['type'] = 'audio'
|
||||
del me['file_url']
|
||||
del me['type']
|
||||
del me['file_name']
|
||||
is_use_dashscope_call = True
|
||||
use_time_num +=1
|
||||
use_time_ids.append(message_id)
|
||||
is_enable_multi_model = False
|
||||
message_id += 1
|
||||
|
||||
# 使用列表推导式,保留不在 use_time_ids[:-1] 中的元素,仅保留最后一个多媒体消息
|
||||
if not is_enable_multi_model and use_time_num > 1:
|
||||
messages = [msg for idx, msg in enumerate(messages) if idx not in use_time_ids[:-1]]
|
||||
|
||||
if not is_enable_multi_model:
|
||||
messages = [msg for msg in messages if 'resp_message_id' not in msg]
|
||||
|
||||
args['messages'] = messages
|
||||
args['stream'] = True
|
||||
|
||||
# 流式处理状态
|
||||
# tool_calls_map: dict[str, provider_message.ToolCall] = {}
|
||||
chunk_idx = 0
|
||||
thinking_started = False
|
||||
thinking_ended = False
|
||||
role = 'assistant' # 默认角色
|
||||
|
||||
if is_use_dashscope_call:
|
||||
response = dashscope.MultiModalConversation.call(
|
||||
# 若没有配置环境变量,请用百炼API Key将下行替换为:api_key = "sk-xxx"
|
||||
api_key=use_model.token_mgr.get_token(),
|
||||
model=use_model.model_entity.name,
|
||||
messages=messages,
|
||||
result_format="message",
|
||||
asr_options={
|
||||
# "language": "zh", # 可选,若已知音频的语种,可通过该参数指定待识别语种,以提升识别准确率
|
||||
"enable_lid": True,
|
||||
"enable_itn": False
|
||||
},
|
||||
stream=True
|
||||
)
|
||||
content_length_list = []
|
||||
previous_length = 0 # 记录上一次的内容长度
|
||||
for res in response:
|
||||
chunk = res["output"]
|
||||
# 解析 chunk 数据
|
||||
if hasattr(chunk, 'choices') and chunk.choices:
|
||||
choice = chunk.choices[0]
|
||||
delta_content = choice["message"].content[0]["text"]
|
||||
finish_reason = choice["finish_reason"]
|
||||
content_length_list.append(len(delta_content))
|
||||
else:
|
||||
delta_content = ""
|
||||
finish_reason = None
|
||||
|
||||
# 跳过空的第一个 chunk(只有 role 没有内容)
|
||||
if chunk_idx == 0 and not delta_content:
|
||||
chunk_idx += 1
|
||||
continue
|
||||
|
||||
# 检查 content_length_list 是否有足够的数据
|
||||
if len(content_length_list) >= 2:
|
||||
now_content = delta_content[previous_length: content_length_list[-1]]
|
||||
previous_length = content_length_list[-1] # 更新上一次的长度
|
||||
else:
|
||||
now_content = delta_content # 第一次循环时直接使用 delta_content
|
||||
previous_length = len(delta_content) # 更新上一次的长度
|
||||
|
||||
# 构建 MessageChunk - 只包含增量内容
|
||||
chunk_data = {
|
||||
'role': role,
|
||||
'content': now_content if now_content else None,
|
||||
'is_final': bool(finish_reason) and finish_reason != "null",
|
||||
}
|
||||
|
||||
# 移除 None 值
|
||||
chunk_data = {k: v for k, v in chunk_data.items() if v is not None}
|
||||
yield provider_message.MessageChunk(**chunk_data)
|
||||
chunk_idx += 1
|
||||
else:
|
||||
async for chunk in self._req_stream(args, extra_body=extra_args):
|
||||
# 解析 chunk 数据
|
||||
if hasattr(chunk, 'choices') and chunk.choices:
|
||||
choice = chunk.choices[0]
|
||||
delta = choice.delta.model_dump() if hasattr(choice, 'delta') else {}
|
||||
finish_reason = getattr(choice, 'finish_reason', None)
|
||||
else:
|
||||
delta = {}
|
||||
finish_reason = None
|
||||
|
||||
# 从第一个 chunk 获取 role,后续使用这个 role
|
||||
if 'role' in delta and delta['role']:
|
||||
role = delta['role']
|
||||
|
||||
# 获取增量内容
|
||||
delta_content = delta.get('content', '')
|
||||
reasoning_content = delta.get('reasoning_content', '')
|
||||
|
||||
# 处理 reasoning_content
|
||||
if reasoning_content:
|
||||
# accumulated_reasoning += reasoning_content
|
||||
# 如果设置了 remove_think,跳过 reasoning_content
|
||||
if remove_think:
|
||||
chunk_idx += 1
|
||||
continue
|
||||
|
||||
# 第一次出现 reasoning_content,添加 <think> 开始标签
|
||||
if not thinking_started:
|
||||
thinking_started = True
|
||||
delta_content = '<think>\n' + reasoning_content
|
||||
else:
|
||||
# 继续输出 reasoning_content
|
||||
delta_content = reasoning_content
|
||||
elif thinking_started and not thinking_ended and delta_content:
|
||||
# reasoning_content 结束,normal content 开始,添加 </think> 结束标签
|
||||
thinking_ended = True
|
||||
delta_content = '\n</think>\n' + delta_content
|
||||
|
||||
# 处理工具调用增量
|
||||
if delta.get('tool_calls'):
|
||||
for tool_call in delta['tool_calls']:
|
||||
if tool_call['id'] != '':
|
||||
tool_id = tool_call['id']
|
||||
if tool_call['function']['name'] is not None:
|
||||
tool_name = tool_call['function']['name']
|
||||
|
||||
if tool_call['type'] is None:
|
||||
tool_call['type'] = 'function'
|
||||
tool_call['id'] = tool_id
|
||||
tool_call['function']['name'] = tool_name
|
||||
tool_call['function']['arguments'] = (
|
||||
'' if tool_call['function']['arguments'] is None else tool_call['function']['arguments']
|
||||
)
|
||||
|
||||
# 跳过空的第一个 chunk(只有 role 没有内容)
|
||||
if chunk_idx == 0 and not delta_content and not reasoning_content and not delta.get('tool_calls'):
|
||||
chunk_idx += 1
|
||||
continue
|
||||
|
||||
# 构建 MessageChunk - 只包含增量内容
|
||||
chunk_data = {
|
||||
'role': role,
|
||||
'content': delta_content if delta_content else None,
|
||||
'tool_calls': delta.get('tool_calls'),
|
||||
'is_final': bool(finish_reason),
|
||||
}
|
||||
|
||||
# 移除 None 值
|
||||
chunk_data = {k: v for k, v in chunk_data.items() if v is not None}
|
||||
|
||||
yield provider_message.MessageChunk(**chunk_data)
|
||||
chunk_idx += 1
|
||||
# return
|
||||
|
||||
1
pkg/provider/modelmgr/requesters/tokenpony.svg
Normal file
1
pkg/provider/modelmgr/requesters/tokenpony.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="450" height="280" viewBox="0 0 450 280" class="cursor-pointer h-24 flex-shrink-0 w-149"><g fill="none" fill-rule="nonzero"><path fill="#0005DE" d="M97.705 6.742c58.844 0 90.962 34.353 90.962 98.341v21.843c-15.118-2.479-30.297-6.573-45.558-12.3v-9.543c0-35.97-15.564-56.281-45.404-56.281s-45.404 20.31-45.404 56.281v72.48c0 36.117 15.65 56.818 45.404 56.818 26.78 0 42.133-16.768 44.936-46.452q22.397 6.473 44.905 9.356c-6.15 51.52-37.492 79.155-89.841 79.155-58.678 0-90.963-34.72-90.963-98.878v-72.479c0-63.988 32.119-98.34 90.963-98.34m253.627 0c58.844 0 90.963 34.353 90.963 98.341v72.48c0 64.157-32.285 98.877-90.963 98.877-52.438 0-83.797-27.729-89.874-79.415 15-2.026 29.965-5.252 44.887-9.67 2.658 30.042 18.036 47.026 44.987 47.026 29.755 0 45.404-20.7 45.404-56.819v-72.479c0-35.97-15.564-56.281-45.404-56.281s-45.403 20.31-45.403 56.281v8.778c-15.262 5.868-30.44 10.104-45.559 12.725v-21.503c0-63.988 32.118-98.34 90.962-98.34m-164.37 140.026.57.09.831.127-.83-.128a234.5 234.5 0 0 0 35.979 2.79q18.408.002 36.858-2.928l1.401-.226a242 242 0 0 0 1.45-.244l-1.037.175q.729-.12 1.458-.247l-.421.072 1.26-.219-.84.147a244 244 0 0 0 2.8-.5l-.792.144q.648-.117 1.298-.239l-.506.094q.66-.122 1.322-.248l-.816.154q.759-.142 1.518-.289l-.702.135a247 247 0 0 0 5.364-1.084l-.463.098a250 250 0 0 0 3.928-.864l-.785.178 1.45-.33-.665.152q.597-.137 1.193-.276l-.528.123a253 253 0 0 0 3.685-.882l-.254.063q.683-.168 1.366-.34l-1.112.277q.809-.2 1.618-.405l-.506.128q.818-.206 1.634-.417l-1.128.289q.71-.18 1.419-.365l1.506-.397a259 259 0 0 0 1.804-.488l-.433.119a261 261 0 0 0 3.751-1.053l-.681.196a264 264 0 0 0 1.735-.502l-1.054.306q.636-.184 1.272-.37l-.218.064 1.238-.366-1.02.302a266 266 0 0 0 2.936-.882l-1.026.312q.71-.214 1.42-.433l-.394.121q.675-.207 1.35-.418l-.955.297q.8-.246 1.6-.499l-.645.202q.86-.269 1.72-.543l-1.076.341q.666-.21 1.33-.423l-.254.082q.833-.266 1.665-.539l-1.41.457q.874-.28 1.75-.568l-.34.111q.702-.229 1.403-.462l-1.063.351q.818-.269 1.634-.542l-.571.19a276 276 0 0 0 4.038-1.378l-.735.256q.657-.228 1.315-.46l-.58.204q16.86-5.903 33.78-14.256l-7.114-12.453 42.909 6.553-13.148 45.541-7.734-13.537q-23.832 11.94-47.755 19.504l-.199.063a298 298 0 0 1-11.65 3.412 288 288 0 0 1-10.39 2.603 280 280 0 0 1-11.677 2.431 273 273 0 0 1-11.643 1.903 263.5 263.5 0 0 1-36.858 2.599q-17.437 0-34.844-2.323l-.227-.03q-.635-.085-1.27-.174l1.497.204a268 268 0 0 1-13.673-2.182 275 275 0 0 1-12.817-2.697 282 282 0 0 1-11.859-3.057 291 291 0 0 1-7.21-2.123c-17.23-5.314-34.43-12.334-51.59-21.051l-8.258 14.455-13.148-45.541 42.909-6.553-6.594 11.544q18.421 9.24 36.776 15.572l1.316.45 1.373.462-.831-.278q.795.267 1.589.53l-.758-.252q.632.211 1.264.419l-.506-.167q.642.212 1.284.42l-.778-.253a271 271 0 0 0 3.914 1.251l-.227-.07a267 267 0 0 0 3.428 1.046l-.194-.058 1.315.389-1.121-.331q.864.256 1.73.508l-.609-.177q.826.241 1.651.478l-1.043-.3 1.307.375-.264-.075q.802.228 1.603.452l-1.34-.377q1.034.294 2.067.58l-.727-.203q.713.2 1.426.394l-.699-.192q.62.171 1.237.338l-.538-.146a259 259 0 0 0 3.977 1.051l-.66-.17q.683.177 1.367.35l-.707-.18q.687.175 1.373.348l-.666-.168q.738.186 1.475.368l-.809-.2q.716.179 1.43.353l-.621-.153a253 253 0 0 0 3.766.898l-.308-.07q.735.17 1.472.336l-1.164-.266q.747.173 1.496.34l-.332-.074q.845.19 1.69.374l-1.358-.3q.932.21 1.864.41l-.505-.11q.726.159 1.452.313l-.947-.203q.72.156 1.44.307l-.493-.104q.684.144 1.368.286l-.875-.182q.743.155 1.485.306l-.61-.124q.932.192 1.864.376l-1.254-.252q.904.184 1.809.361l-.555-.109q.752.15 1.504.293l-.95-.184q.69.135 1.377.265l-.427-.081q.784.15 1.569.295l-1.142-.214q.717.136 1.434.268l-.292-.054a244 244 0 0 0 3.808.673l-.68-.116 1.063.18-.383-.064q1.076.18 2.152.352z"></path></g></svg>
|
||||
|
After Width: | Height: | Size: 3.6 KiB |
31
pkg/provider/modelmgr/requesters/tokenpony.yaml
Normal file
31
pkg/provider/modelmgr/requesters/tokenpony.yaml
Normal file
@@ -0,0 +1,31 @@
|
||||
apiVersion: v1
|
||||
kind: LLMAPIRequester
|
||||
metadata:
|
||||
name: tokenpony-chat-completions
|
||||
label:
|
||||
en_US: TokenPony
|
||||
zh_Hans: 小马算力
|
||||
icon: tokenpony.svg
|
||||
spec:
|
||||
config:
|
||||
- name: base_url
|
||||
label:
|
||||
en_US: Base URL
|
||||
zh_Hans: 基础 URL
|
||||
type: string
|
||||
required: true
|
||||
default: "https://api.tokenpony.cn/v1"
|
||||
- name: timeout
|
||||
label:
|
||||
en_US: Timeout
|
||||
zh_Hans: 超时时间
|
||||
type: integer
|
||||
required: true
|
||||
default: 120
|
||||
support_type:
|
||||
- llm
|
||||
- text-embedding
|
||||
execution:
|
||||
python:
|
||||
path: ./tokenponychatcmpl.py
|
||||
attr: TokenPonyChatCompletions
|
||||
17
pkg/provider/modelmgr/requesters/tokenponychatcmpl.py
Normal file
17
pkg/provider/modelmgr/requesters/tokenponychatcmpl.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import openai
|
||||
|
||||
from . import chatcmpl
|
||||
|
||||
|
||||
class TokenPonyChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
"""TokenPony ChatCompletion API 请求器"""
|
||||
|
||||
client: openai.AsyncClient
|
||||
|
||||
default_config: dict[str, typing.Any] = {
|
||||
'base_url': 'https://api.tokenpony.cn/v1',
|
||||
'timeout': 120,
|
||||
}
|
||||
312
pkg/provider/runners/cozeapi.py
Normal file
312
pkg/provider/runners/cozeapi.py
Normal file
@@ -0,0 +1,312 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import json
|
||||
import uuid
|
||||
import base64
|
||||
|
||||
from .. import runner
|
||||
from ...core import app
|
||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
from ...utils import image
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
from libs.coze_server_api.client import AsyncCozeAPIClient
|
||||
|
||||
@runner.runner_class('coze-api')
|
||||
class CozeAPIRunner(runner.RequestRunner):
|
||||
"""Coze API 对话请求器"""
|
||||
|
||||
def __init__(self, ap: app.Application, pipeline_config: dict):
|
||||
self.pipeline_config = pipeline_config
|
||||
self.ap = ap
|
||||
self.agent_token = pipeline_config["ai"]['coze-api']['api-key']
|
||||
self.bot_id = pipeline_config["ai"]['coze-api'].get('bot-id')
|
||||
self.chat_timeout = pipeline_config["ai"]['coze-api'].get('timeout')
|
||||
self.auto_save_history = pipeline_config["ai"]['coze-api'].get('auto_save_history')
|
||||
self.api_base = pipeline_config["ai"]['coze-api'].get('api-base')
|
||||
|
||||
self.coze = AsyncCozeAPIClient(
|
||||
self.agent_token,
|
||||
self.api_base
|
||||
)
|
||||
|
||||
def _process_thinking_content(
|
||||
self,
|
||||
content: str,
|
||||
) -> tuple[str, str]:
|
||||
"""处理思维链内容
|
||||
|
||||
Args:
|
||||
content: 原始内容
|
||||
Returns:
|
||||
(处理后的内容, 提取的思维链内容)
|
||||
"""
|
||||
remove_think = self.pipeline_config.get('output', {}).get('misc', {}).get('remove-think', False)
|
||||
thinking_content = ''
|
||||
# 从 content 中提取 <think> 标签内容
|
||||
if content and '<think>' in content and '</think>' in content:
|
||||
import re
|
||||
|
||||
think_pattern = r'<think>(.*?)</think>'
|
||||
think_matches = re.findall(think_pattern, content, re.DOTALL)
|
||||
if think_matches:
|
||||
thinking_content = '\n'.join(think_matches)
|
||||
# 移除 content 中的 <think> 标签
|
||||
content = re.sub(think_pattern, '', content, flags=re.DOTALL).strip()
|
||||
|
||||
# 根据 remove_think 参数决定是否保留思维链
|
||||
if remove_think:
|
||||
return content, ''
|
||||
else:
|
||||
# 如果有思维链内容,将其以 <think> 格式添加到 content 开头
|
||||
if thinking_content:
|
||||
content = f'<think>\n{thinking_content}\n</think>\n{content}'.strip()
|
||||
return content, thinking_content
|
||||
|
||||
async def _preprocess_user_message(self, query: pipeline_query.Query) -> list[dict]:
|
||||
"""预处理用户消息,转换为Coze消息格式
|
||||
|
||||
Returns:
|
||||
list[dict]: Coze消息列表
|
||||
"""
|
||||
messages = []
|
||||
|
||||
if isinstance(query.user_message.content, list):
|
||||
# 多模态消息处理
|
||||
content_parts = []
|
||||
|
||||
for ce in query.user_message.content:
|
||||
if ce.type == 'text':
|
||||
content_parts.append({"type": "text", "text": ce.text})
|
||||
elif ce.type == 'image_base64':
|
||||
image_b64, image_format = await image.extract_b64_and_format(ce.image_base64)
|
||||
file_bytes = base64.b64decode(image_b64)
|
||||
file_id = await self._get_file_id(file_bytes)
|
||||
content_parts.append({"type": "image", "file_id": file_id})
|
||||
elif ce.type == 'file':
|
||||
# 处理文件,上传到Coze
|
||||
file_id = await self._get_file_id(ce.file)
|
||||
content_parts.append({"type": "file", "file_id": file_id})
|
||||
|
||||
# 创建多模态消息
|
||||
if content_parts:
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": json.dumps(content_parts),
|
||||
"content_type": "object_string",
|
||||
"meta_data": None
|
||||
})
|
||||
|
||||
elif isinstance(query.user_message.content, str):
|
||||
# 纯文本消息
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": query.user_message.content,
|
||||
"content_type": "text",
|
||||
"meta_data": None
|
||||
})
|
||||
|
||||
return messages
|
||||
|
||||
async def _get_file_id(self, file) -> str:
|
||||
"""上传文件到Coze服务
|
||||
Args:
|
||||
file: 文件
|
||||
Returns:
|
||||
str: 文件ID
|
||||
"""
|
||||
file_id = await self.coze.upload(file=file)
|
||||
return file_id
|
||||
|
||||
async def _chat_messages(
|
||||
self, query: pipeline_query.Query
|
||||
) -> typing.AsyncGenerator[provider_message.Message, None]:
|
||||
"""调用聊天助手(非流式)
|
||||
|
||||
注意:由于cozepy没有提供非流式API,这里使用流式API并在结束后一次性返回完整内容
|
||||
"""
|
||||
user_id = f'{query.launcher_id}_{query.sender_id}'
|
||||
|
||||
# 预处理用户消息
|
||||
additional_messages = await self._preprocess_user_message(query)
|
||||
|
||||
# 获取会话ID
|
||||
conversation_id = None
|
||||
|
||||
# 收集完整内容
|
||||
full_content = ''
|
||||
full_reasoning = ''
|
||||
|
||||
try:
|
||||
# 调用Coze API流式接口
|
||||
async for chunk in self.coze.chat_messages(
|
||||
bot_id=self.bot_id,
|
||||
user_id=user_id,
|
||||
additional_messages=additional_messages,
|
||||
conversation_id=conversation_id,
|
||||
timeout=self.chat_timeout,
|
||||
auto_save_history=self.auto_save_history,
|
||||
stream=True
|
||||
):
|
||||
self.ap.logger.debug(f'coze-chat-stream: {chunk}')
|
||||
|
||||
event_type = chunk.get('event')
|
||||
data = chunk.get('data', {})
|
||||
|
||||
if event_type == 'conversation.message.delta':
|
||||
# 收集内容
|
||||
if 'content' in data:
|
||||
full_content += data.get('content', '')
|
||||
|
||||
# 收集推理内容(如果有)
|
||||
if 'reasoning_content' in data:
|
||||
full_reasoning += data.get('reasoning_content', '')
|
||||
|
||||
elif event_type == 'done':
|
||||
# 保存会话ID
|
||||
if 'conversation_id' in data:
|
||||
conversation_id = data.get('conversation_id')
|
||||
|
||||
elif event_type == 'error':
|
||||
# 处理错误
|
||||
error_msg = f"Coze API错误: {data.get('message', '未知错误')}"
|
||||
yield provider_message.Message(
|
||||
role='assistant',
|
||||
content=error_msg,
|
||||
)
|
||||
return
|
||||
|
||||
# 处理思维链内容
|
||||
content, thinking_content = self._process_thinking_content(full_content)
|
||||
if full_reasoning:
|
||||
remove_think = self.pipeline_config.get('output', {}).get('misc', {}).get('remove-think', False)
|
||||
if not remove_think:
|
||||
content = f'<think>\n{full_reasoning}\n</think>\n{content}'.strip()
|
||||
|
||||
# 一次性返回完整内容
|
||||
yield provider_message.Message(
|
||||
role='assistant',
|
||||
content=content,
|
||||
)
|
||||
|
||||
# 保存会话ID
|
||||
if conversation_id and query.session.using_conversation:
|
||||
query.session.using_conversation.uuid = conversation_id
|
||||
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'Coze API错误: {str(e)}')
|
||||
yield provider_message.Message(
|
||||
role='assistant',
|
||||
content=f'Coze API调用失败: {str(e)}',
|
||||
)
|
||||
|
||||
|
||||
async def _chat_messages_chunk(
|
||||
self, query: pipeline_query.Query
|
||||
) -> typing.AsyncGenerator[provider_message.MessageChunk, None]:
|
||||
"""调用聊天助手(流式)"""
|
||||
user_id = f'{query.launcher_id}_{query.sender_id}'
|
||||
|
||||
# 预处理用户消息
|
||||
additional_messages = await self._preprocess_user_message(query)
|
||||
|
||||
# 获取会话ID
|
||||
conversation_id = None
|
||||
|
||||
start_reasoning = False
|
||||
stop_reasoning = False
|
||||
message_idx = 1
|
||||
is_final = False
|
||||
full_content = ''
|
||||
remove_think = self.pipeline_config.get('output', {}).get('misc', {}).get('remove-think', False)
|
||||
|
||||
|
||||
|
||||
try:
|
||||
# 调用Coze API流式接口
|
||||
async for chunk in self.coze.chat_messages(
|
||||
bot_id=self.bot_id,
|
||||
user_id=user_id,
|
||||
additional_messages=additional_messages,
|
||||
conversation_id=conversation_id,
|
||||
timeout=self.chat_timeout,
|
||||
auto_save_history=self.auto_save_history,
|
||||
stream=True
|
||||
):
|
||||
self.ap.logger.debug(f'coze-chat-stream-chunk: {chunk}')
|
||||
|
||||
event_type = chunk.get('event')
|
||||
data = chunk.get('data', {})
|
||||
content = ""
|
||||
|
||||
if event_type == 'conversation.message.delta':
|
||||
message_idx += 1
|
||||
# 处理内容增量
|
||||
if "reasoning_content" in data and not remove_think:
|
||||
|
||||
reasoning_content = data.get('reasoning_content', '')
|
||||
if reasoning_content and not start_reasoning:
|
||||
content = f"<think/>\n"
|
||||
start_reasoning = True
|
||||
content += reasoning_content
|
||||
|
||||
if 'content' in data:
|
||||
if data.get('content', ''):
|
||||
content += data.get('content', '')
|
||||
if not stop_reasoning and start_reasoning:
|
||||
content = f"</think>\n{content}"
|
||||
stop_reasoning = True
|
||||
|
||||
|
||||
elif event_type == 'done':
|
||||
# 保存会话ID
|
||||
if 'conversation_id' in data:
|
||||
conversation_id = data.get('conversation_id')
|
||||
if query.session.using_conversation:
|
||||
query.session.using_conversation.uuid = conversation_id
|
||||
is_final = True
|
||||
|
||||
|
||||
elif event_type == 'error':
|
||||
# 处理错误
|
||||
error_msg = f"Coze API错误: {data.get('message', '未知错误')}"
|
||||
yield provider_message.MessageChunk(
|
||||
role='assistant',
|
||||
content=error_msg,
|
||||
finish_reason='error'
|
||||
)
|
||||
return
|
||||
full_content += content
|
||||
if message_idx % 8 == 0 or is_final:
|
||||
if full_content:
|
||||
yield provider_message.MessageChunk(
|
||||
role='assistant',
|
||||
content=full_content,
|
||||
is_final=is_final
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'Coze API流式调用错误: {str(e)}')
|
||||
yield provider_message.MessageChunk(
|
||||
role='assistant',
|
||||
content=f'Coze API流式调用失败: {str(e)}',
|
||||
finish_reason='error'
|
||||
)
|
||||
|
||||
|
||||
async def run(self, query: pipeline_query.Query) -> typing.AsyncGenerator[provider_message.Message, None]:
|
||||
"""运行"""
|
||||
msg_seq = 0
|
||||
if await query.adapter.is_stream_output_supported():
|
||||
async for msg in self._chat_messages_chunk(query):
|
||||
if isinstance(msg, provider_message.MessageChunk):
|
||||
msg_seq += 1
|
||||
msg.msg_sequence = msg_seq
|
||||
yield msg
|
||||
else:
|
||||
async for msg in self._chat_messages(query):
|
||||
yield msg
|
||||
|
||||
|
||||
|
||||
|
||||
205
pkg/provider/runners/tboxapi.py
Normal file
205
pkg/provider/runners/tboxapi.py
Normal file
@@ -0,0 +1,205 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import json
|
||||
import base64
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
from tboxsdk.tbox import TboxClient
|
||||
from tboxsdk.model.file import File, FileType
|
||||
|
||||
from .. import runner
|
||||
from ...core import app
|
||||
from ...utils import image
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
|
||||
|
||||
class TboxAPIError(Exception):
|
||||
"""TBox API 请求失败"""
|
||||
|
||||
def __init__(self, message: str):
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
@runner.runner_class('tbox-app-api')
|
||||
class TboxAPIRunner(runner.RequestRunner):
|
||||
"蚂蚁百宝箱API对话请求器"
|
||||
|
||||
# 运行器内部使用的配置
|
||||
app_id: str # 蚂蚁百宝箱平台中的应用ID
|
||||
api_key: str # 在蚂蚁百宝箱平台中申请的令牌
|
||||
|
||||
def __init__(self, ap: app.Application, pipeline_config: dict):
|
||||
"""初始化"""
|
||||
self.ap = ap
|
||||
self.pipeline_config = pipeline_config
|
||||
|
||||
# 初始化Tbox 参数配置
|
||||
self.app_id = self.pipeline_config['ai']['tbox-app-api']['app-id']
|
||||
self.api_key = self.pipeline_config['ai']['tbox-app-api']['api-key']
|
||||
|
||||
# 初始化Tbox client
|
||||
self.tbox_client = TboxClient(authorization=self.api_key)
|
||||
|
||||
async def _preprocess_user_message(self, query: pipeline_query.Query) -> tuple[str, list[str]]:
|
||||
"""预处理用户消息,提取纯文本,并将图片上传到 Tbox 服务
|
||||
|
||||
Returns:
|
||||
tuple[str, list[str]]: 纯文本和图片的 Tbox 文件ID
|
||||
"""
|
||||
plain_text = ''
|
||||
image_ids = []
|
||||
|
||||
if isinstance(query.user_message.content, list):
|
||||
for ce in query.user_message.content:
|
||||
if ce.type == 'text':
|
||||
plain_text += ce.text
|
||||
elif ce.type == 'image_base64':
|
||||
image_b64, image_format = await image.extract_b64_and_format(ce.image_base64)
|
||||
# 创建临时文件
|
||||
file_bytes = base64.b64decode(image_b64)
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(suffix=f'.{image_format}', delete=False) as tmp_file:
|
||||
tmp_file.write(file_bytes)
|
||||
tmp_file_path = tmp_file.name
|
||||
file_upload_resp = self.tbox_client.upload_file(
|
||||
tmp_file_path
|
||||
)
|
||||
image_id = file_upload_resp.get("data", "")
|
||||
image_ids.append(image_id)
|
||||
finally:
|
||||
# 清理临时文件
|
||||
if os.path.exists(tmp_file_path):
|
||||
os.unlink(tmp_file_path)
|
||||
elif isinstance(query.user_message.content, str):
|
||||
plain_text = query.user_message.content
|
||||
|
||||
return plain_text, image_ids
|
||||
|
||||
async def _agent_messages(
|
||||
self, query: pipeline_query.Query
|
||||
) -> typing.AsyncGenerator[provider_message.Message, None]:
|
||||
"""TBox 智能体对话请求"""
|
||||
|
||||
plain_text, image_ids = await self._preprocess_user_message(query)
|
||||
remove_think = self.pipeline_config['output'].get('misc', {}).get('remove-think')
|
||||
|
||||
try:
|
||||
is_stream = await query.adapter.is_stream_output_supported()
|
||||
except AttributeError:
|
||||
is_stream = False
|
||||
|
||||
# 获取Tbox的conversation_id
|
||||
conversation_id = query.session.using_conversation.uuid or None
|
||||
|
||||
files = None
|
||||
if image_ids:
|
||||
files = [
|
||||
File(file_id=image_id, type=FileType.IMAGE)
|
||||
for image_id in image_ids
|
||||
]
|
||||
|
||||
# 发送对话请求
|
||||
response = self.tbox_client.chat(
|
||||
app_id=self.app_id, # Tbox中智能体应用的ID
|
||||
user_id=query.bot_uuid, # 用户ID
|
||||
query=plain_text, # 用户输入的文本信息
|
||||
stream=is_stream, # 是否流式输出
|
||||
conversation_id=conversation_id, # 会话ID,为None时Tbox会自动创建一个新会话
|
||||
files=files, # 图片内容
|
||||
)
|
||||
|
||||
if is_stream:
|
||||
# 解析Tbox流式输出内容,并发送给上游
|
||||
for chunk in self._process_stream_message(response, query, remove_think):
|
||||
yield chunk
|
||||
else:
|
||||
message = self._process_non_stream_message(response, query, remove_think)
|
||||
yield provider_message.Message(
|
||||
role='assistant',
|
||||
content=message,
|
||||
)
|
||||
|
||||
def _process_non_stream_message(self, response: typing.Dict, query: pipeline_query.Query, remove_think: bool):
|
||||
if response.get('errorCode') != "0":
|
||||
raise TboxAPIError(f'Tbox API 请求失败: {response.get("errorMsg", "")}')
|
||||
payload = response.get('data', {})
|
||||
conversation_id = payload.get('conversationId', '')
|
||||
query.session.using_conversation.uuid = conversation_id
|
||||
thinking_content = payload.get('reasoningContent', [])
|
||||
result = ""
|
||||
if thinking_content and not remove_think:
|
||||
result += f'<think>\n{thinking_content[0].get("text", "")}\n</think>\n'
|
||||
content = payload.get('result', [])
|
||||
if content:
|
||||
result += content[0].get('chunk', '')
|
||||
return result
|
||||
|
||||
def _process_stream_message(self, response: typing.Generator[dict], query: pipeline_query.Query, remove_think: bool):
|
||||
idx_msg = 0
|
||||
pending_content = ''
|
||||
conversation_id = None
|
||||
think_start = False
|
||||
think_end = False
|
||||
for chunk in response:
|
||||
if chunk.get('type', '') == 'chunk':
|
||||
"""
|
||||
Tbox返回的消息内容chunk结构
|
||||
{'lane': 'default', 'payload': {'conversationId': '20250918tBI947065406', 'messageId': '20250918TB1f53230954', 'text': '️'}, 'type': 'chunk'}
|
||||
"""
|
||||
# 如果包含思考过程,拼接</think>
|
||||
if think_start and not think_end:
|
||||
pending_content += '\n</think>\n'
|
||||
think_end = True
|
||||
|
||||
payload = chunk.get('payload', {})
|
||||
if not conversation_id:
|
||||
conversation_id = payload.get('conversationId')
|
||||
query.session.using_conversation.uuid = conversation_id
|
||||
if payload.get('text'):
|
||||
idx_msg += 1
|
||||
pending_content += payload.get('text')
|
||||
elif chunk.get('type', '') == 'thinking' and not remove_think:
|
||||
"""
|
||||
Tbox返回的思考过程chunk结构
|
||||
{'payload': '{"ext_data":{"text":"日期"},"event":"flow.node.llm.thinking","entity":{"node_type":"text-completion","execute_id":"6","group_id":0,"parent_execute_id":"6","node_name":"模型推理","node_id":"TC_5u6gl0"}}', 'type': 'thinking'}
|
||||
"""
|
||||
payload = json.loads(chunk.get('payload', '{}'))
|
||||
if payload.get('ext_data', {}).get('text'):
|
||||
idx_msg += 1
|
||||
content = payload.get('ext_data', {}).get('text')
|
||||
if not think_start:
|
||||
think_start = True
|
||||
pending_content += f'<think>\n{content}'
|
||||
else:
|
||||
pending_content += content
|
||||
elif chunk.get('type', '') == 'error':
|
||||
raise TboxAPIError(
|
||||
f'Tbox API 请求失败: status_code={chunk.get("status_code")} message={chunk.get("message")} request_id={chunk.get("request_id")} '
|
||||
)
|
||||
|
||||
if idx_msg % 8 == 0:
|
||||
yield provider_message.MessageChunk(
|
||||
role='assistant',
|
||||
content=pending_content,
|
||||
is_final=False,
|
||||
)
|
||||
|
||||
# Tbox不返回END事件,默认发一个最终消息
|
||||
yield provider_message.MessageChunk(
|
||||
role='assistant',
|
||||
content=pending_content,
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
async def run(self, query: pipeline_query.Query) -> typing.AsyncGenerator[provider_message.Message, None]:
|
||||
"""运行"""
|
||||
msg_seq = 0
|
||||
async for msg in self._agent_messages(query):
|
||||
if isinstance(msg, provider_message.MessageChunk):
|
||||
msg_seq += 1
|
||||
msg.msg_sequence = msg_seq
|
||||
yield msg
|
||||
@@ -1,4 +1,4 @@
|
||||
semantic_version = 'v4.3.2'
|
||||
semantic_version = 'v4.3.9'
|
||||
|
||||
required_database_version = 8
|
||||
"""Tag the version of the database schema, used to check if the database needs to be migrated"""
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "langbot"
|
||||
version = "4.3.1"
|
||||
version = "4.3.9"
|
||||
description = "Easy-to-use global IM bot platform designed for LLM era"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10.1,<4.0"
|
||||
@@ -62,9 +62,10 @@ dependencies = [
|
||||
"langchain>=0.2.0",
|
||||
"chromadb>=0.4.24",
|
||||
"qdrant-client (>=1.15.1,<2.0.0)",
|
||||
"langbot-plugin==0.1.1",
|
||||
"langbot-plugin==0.1.4",
|
||||
"asyncpg>=0.30.0",
|
||||
"line-bot-sdk>=3.19.0"
|
||||
"line-bot-sdk>=3.19.0",
|
||||
"tboxsdk>=0.0.10",
|
||||
]
|
||||
keywords = [
|
||||
"bot",
|
||||
@@ -102,6 +103,7 @@ dev = [
|
||||
"pre-commit>=4.2.0",
|
||||
"pytest>=8.4.1",
|
||||
"pytest-asyncio>=1.0.0",
|
||||
"pytest-cov>=7.0.0",
|
||||
"ruff>=0.11.9",
|
||||
]
|
||||
|
||||
|
||||
39
pytest.ini
Normal file
39
pytest.ini
Normal file
@@ -0,0 +1,39 @@
|
||||
[pytest]
|
||||
# Test discovery patterns
|
||||
python_files = test_*.py
|
||||
python_classes = Test*
|
||||
python_functions = test_*
|
||||
|
||||
# Test paths
|
||||
testpaths = tests
|
||||
|
||||
# Asyncio configuration
|
||||
asyncio_mode = auto
|
||||
|
||||
# Output options
|
||||
addopts =
|
||||
-v
|
||||
--strict-markers
|
||||
--tb=short
|
||||
--disable-warnings
|
||||
|
||||
# Markers
|
||||
markers =
|
||||
asyncio: mark test as async
|
||||
unit: mark test as unit test
|
||||
integration: mark test as integration test
|
||||
slow: mark test as slow running
|
||||
|
||||
# Coverage options (when using pytest-cov)
|
||||
[coverage:run]
|
||||
source = pkg
|
||||
omit =
|
||||
*/tests/*
|
||||
*/test_*.py
|
||||
*/__pycache__/*
|
||||
*/site-packages/*
|
||||
|
||||
[coverage:report]
|
||||
precision = 2
|
||||
show_missing = True
|
||||
skip_covered = False
|
||||
31
run_tests.sh
Executable file
31
run_tests.sh
Executable file
@@ -0,0 +1,31 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Script to run all unit tests
|
||||
# This script helps avoid circular import issues by setting up the environment properly
|
||||
|
||||
set -e
|
||||
|
||||
echo "Setting up test environment..."
|
||||
|
||||
# Activate virtual environment if it exists
|
||||
if [ -d ".venv" ]; then
|
||||
source .venv/bin/activate
|
||||
fi
|
||||
|
||||
# Check if pytest is installed
|
||||
if ! command -v pytest &> /dev/null; then
|
||||
echo "Installing test dependencies..."
|
||||
pip install pytest pytest-asyncio pytest-cov
|
||||
fi
|
||||
|
||||
echo "Running all unit tests..."
|
||||
|
||||
# Run tests with coverage
|
||||
pytest tests/unit_tests/ -v --tb=short \
|
||||
--cov=pkg \
|
||||
--cov-report=xml \
|
||||
"$@"
|
||||
|
||||
echo ""
|
||||
echo "Test run complete!"
|
||||
echo "Coverage report saved to coverage.xml"
|
||||
@@ -38,6 +38,7 @@ vdb:
|
||||
port: 6333
|
||||
api_key: ''
|
||||
plugin:
|
||||
enable: true
|
||||
runtime_ws_url: 'ws://langbot_plugin_runtime:5400/control/ws'
|
||||
enable_marketplace: true
|
||||
cloud_service_url: 'https://space.langbot.app'
|
||||
cloud_service_url: 'https://space.langbot.app'
|
||||
|
||||
@@ -83,7 +83,7 @@
|
||||
"output": {
|
||||
"long-text-processing": {
|
||||
"threshold": 1000,
|
||||
"strategy": "forward",
|
||||
"strategy": "none",
|
||||
"font-path": ""
|
||||
},
|
||||
"force-delay": {
|
||||
|
||||
@@ -23,6 +23,10 @@ stages:
|
||||
label:
|
||||
en_US: Local Agent
|
||||
zh_Hans: 内置 Agent
|
||||
- name: tbox-app-api
|
||||
label:
|
||||
en_US: Tbox App API
|
||||
zh_Hans: 蚂蚁百宝箱平台 API
|
||||
- name: dify-service-api
|
||||
label:
|
||||
en_US: Dify Service API
|
||||
@@ -39,6 +43,10 @@ stages:
|
||||
label:
|
||||
en_US: Langflow API
|
||||
zh_Hans: Langflow API
|
||||
- name: coze-api
|
||||
label:
|
||||
en_US: Coze API
|
||||
zh_Hans: 扣子 API
|
||||
- name: local-agent
|
||||
label:
|
||||
en_US: Local Agent
|
||||
@@ -82,6 +90,26 @@ stages:
|
||||
type: knowledge-base-selector
|
||||
required: false
|
||||
default: ''
|
||||
- name: tbox-app-api
|
||||
label:
|
||||
en_US: Tbox App API
|
||||
zh_Hans: 蚂蚁百宝箱平台 API
|
||||
description:
|
||||
en_US: Configure the Tbox App API of the pipeline
|
||||
zh_Hans: 配置蚂蚁百宝箱平台 API
|
||||
config:
|
||||
- name: api-key
|
||||
label:
|
||||
en_US: API Key
|
||||
zh_Hans: API 密钥
|
||||
type: string
|
||||
required: true
|
||||
- name: app-id
|
||||
label:
|
||||
en_US: App ID
|
||||
zh_Hans: 应用 ID
|
||||
type: string
|
||||
required: true
|
||||
- name: dify-service-api
|
||||
label:
|
||||
en_US: Dify Service API
|
||||
@@ -356,4 +384,57 @@ stages:
|
||||
zh_Hans: 可选的流程调整参数
|
||||
type: json
|
||||
required: false
|
||||
default: '{}'
|
||||
default: '{}'
|
||||
- name: coze-api
|
||||
label:
|
||||
en_US: coze API
|
||||
zh_Hans: 扣子 API
|
||||
description:
|
||||
en_US: Configure the Coze API of the pipeline
|
||||
zh_Hans: 配置Coze API
|
||||
config:
|
||||
- name: api-key
|
||||
label:
|
||||
en_US: API Key
|
||||
zh_Hans: API 密钥
|
||||
description:
|
||||
en_US: The API key for the Coze server
|
||||
zh_Hans: Coze服务器的 API 密钥
|
||||
type: string
|
||||
required: true
|
||||
- name: bot-id
|
||||
label:
|
||||
en_US: Bot ID
|
||||
zh_Hans: 机器人 ID
|
||||
description:
|
||||
en_US: The ID of the bot to run
|
||||
zh_Hans: 要运行的机器人 ID
|
||||
type: string
|
||||
required: true
|
||||
- name: api-base
|
||||
label:
|
||||
en_US: API Base URL
|
||||
zh_Hans: API 基础 URL
|
||||
description:
|
||||
en_US: The base URL for the Coze API, please use https://api.coze.com for global Coze edition(coze.com).
|
||||
zh_Hans: Coze API 的基础 URL,请使用 https://api.coze.com 用于全球 Coze 版(coze.com)
|
||||
type: string
|
||||
default: "https://api.coze.cn"
|
||||
- name: auto-save-history
|
||||
label:
|
||||
en_US: Auto Save History
|
||||
zh_Hans: 自动保存历史
|
||||
description:
|
||||
en_US: Whether to automatically save conversation history
|
||||
zh_Hans: 是否自动保存对话历史
|
||||
type: boolean
|
||||
default: true
|
||||
- name: timeout
|
||||
label:
|
||||
en_US: Request Timeout
|
||||
zh_Hans: 请求超时
|
||||
description:
|
||||
en_US: Timeout in seconds for API requests
|
||||
zh_Hans: API 请求超时时间(秒)
|
||||
type: number
|
||||
default: 120
|
||||
@@ -27,7 +27,7 @@ stages:
|
||||
zh_Hans: 长文本的处理策略
|
||||
type: select
|
||||
required: true
|
||||
default: forward
|
||||
default: none
|
||||
options:
|
||||
- name: forward
|
||||
label:
|
||||
@@ -37,6 +37,10 @@ stages:
|
||||
label:
|
||||
en_US: Convert to Image
|
||||
zh_Hans: 转换为图片
|
||||
- name: none
|
||||
label:
|
||||
en_US: None
|
||||
zh_Hans: 不处理
|
||||
- name: font-path
|
||||
label:
|
||||
en_US: Font Path
|
||||
|
||||
183
tests/README.md
Normal file
183
tests/README.md
Normal file
@@ -0,0 +1,183 @@
|
||||
# LangBot Test Suite
|
||||
|
||||
This directory contains the test suite for LangBot, with a focus on comprehensive unit testing of pipeline stages.
|
||||
|
||||
## Important Note
|
||||
|
||||
Due to circular import dependencies in the pipeline module structure, the test files use **lazy imports** via `importlib.import_module()` instead of direct imports. This ensures tests can run without triggering circular import errors.
|
||||
|
||||
## Structure
|
||||
|
||||
```
|
||||
tests/
|
||||
├── pipeline/ # Pipeline stage tests
|
||||
│ ├── conftest.py # Shared fixtures and test infrastructure
|
||||
│ ├── test_simple.py # Basic infrastructure tests (always pass)
|
||||
│ ├── test_bansess.py # BanSessionCheckStage tests
|
||||
│ ├── test_ratelimit.py # RateLimit stage tests
|
||||
│ ├── test_preproc.py # PreProcessor stage tests
|
||||
│ ├── test_respback.py # SendResponseBackStage tests
|
||||
│ ├── test_resprule.py # GroupRespondRuleCheckStage tests
|
||||
│ ├── test_pipelinemgr.py # PipelineManager tests
|
||||
│ └── test_stages_integration.py # Integration tests
|
||||
└── README.md # This file
|
||||
```
|
||||
|
||||
## Test Architecture
|
||||
|
||||
### Fixtures (`conftest.py`)
|
||||
|
||||
The test suite uses a centralized fixture system that provides:
|
||||
|
||||
- **MockApplication**: Comprehensive mock of the Application object with all dependencies
|
||||
- **Mock objects**: Pre-configured mocks for Session, Conversation, Model, Adapter
|
||||
- **Sample data**: Ready-to-use Query objects, message chains, and configurations
|
||||
- **Helper functions**: Utilities for creating results and common assertions
|
||||
|
||||
### Design Principles
|
||||
|
||||
1. **Isolation**: Each test is independent and doesn't rely on external systems
|
||||
2. **Mocking**: All external dependencies are mocked to ensure fast, reliable tests
|
||||
3. **Coverage**: Tests cover happy paths, edge cases, and error conditions
|
||||
4. **Extensibility**: Easy to add new tests by reusing existing fixtures
|
||||
|
||||
## Running Tests
|
||||
|
||||
### Using the test runner script (recommended)
|
||||
```bash
|
||||
bash run_tests.sh
|
||||
```
|
||||
|
||||
This script automatically:
|
||||
- Activates the virtual environment
|
||||
- Installs test dependencies if needed
|
||||
- Runs tests with coverage
|
||||
- Generates HTML coverage report
|
||||
|
||||
### Manual test execution
|
||||
|
||||
#### Run all tests
|
||||
```bash
|
||||
pytest tests/pipeline/
|
||||
```
|
||||
|
||||
#### Run only simple tests (no imports, always pass)
|
||||
```bash
|
||||
pytest tests/pipeline/test_simple.py -v
|
||||
```
|
||||
|
||||
#### Run specific test file
|
||||
```bash
|
||||
pytest tests/pipeline/test_bansess.py -v
|
||||
```
|
||||
|
||||
#### Run with coverage
|
||||
```bash
|
||||
pytest tests/pipeline/ --cov=pkg/pipeline --cov-report=html
|
||||
```
|
||||
|
||||
#### Run specific test
|
||||
```bash
|
||||
pytest tests/pipeline/test_bansess.py::test_bansess_whitelist_allow -v
|
||||
```
|
||||
|
||||
### Known Issues
|
||||
|
||||
Some tests may encounter circular import errors. This is a known issue with the current module structure. The test infrastructure is designed to work around this using lazy imports, but if you encounter issues:
|
||||
|
||||
1. Make sure you're running from the project root directory
|
||||
2. Ensure the virtual environment is activated
|
||||
3. Try running `test_simple.py` first to verify the test infrastructure works
|
||||
|
||||
## CI/CD Integration
|
||||
|
||||
Tests are automatically run on:
|
||||
- Pull request opened
|
||||
- Pull request marked ready for review
|
||||
- Push to PR branch
|
||||
- Push to master/develop branches
|
||||
|
||||
The workflow runs tests on Python 3.10, 3.11, and 3.12 to ensure compatibility.
|
||||
|
||||
## Adding New Tests
|
||||
|
||||
### 1. For a new pipeline stage
|
||||
|
||||
Create a new test file `test_<stage_name>.py`:
|
||||
|
||||
```python
|
||||
"""
|
||||
<StageName> stage unit tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pkg.pipeline.<module>.<stage> import <StageClass>
|
||||
from pkg.pipeline import entities as pipeline_entities
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stage_basic_flow(mock_app, sample_query):
|
||||
"""Test basic flow"""
|
||||
stage = <StageClass>(mock_app)
|
||||
await stage.initialize({})
|
||||
|
||||
result = await stage.process(sample_query, '<StageName>')
|
||||
|
||||
assert result.result_type == pipeline_entities.ResultType.CONTINUE
|
||||
```
|
||||
|
||||
### 2. For additional fixtures
|
||||
|
||||
Add new fixtures to `conftest.py`:
|
||||
|
||||
```python
|
||||
@pytest.fixture
|
||||
def my_custom_fixture():
|
||||
"""Description of fixture"""
|
||||
return create_test_data()
|
||||
```
|
||||
|
||||
### 3. For test data
|
||||
|
||||
Use the helper functions in `conftest.py`:
|
||||
|
||||
```python
|
||||
from tests.pipeline.conftest import create_stage_result, assert_result_continue
|
||||
|
||||
result = create_stage_result(
|
||||
result_type=pipeline_entities.ResultType.CONTINUE,
|
||||
query=sample_query
|
||||
)
|
||||
|
||||
assert_result_continue(result)
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Test naming**: Use descriptive names that explain what's being tested
|
||||
2. **Arrange-Act-Assert**: Structure tests clearly with setup, execution, and verification
|
||||
3. **One assertion per test**: Focus each test on a single behavior
|
||||
4. **Mock appropriately**: Mock external dependencies, not the code under test
|
||||
5. **Use fixtures**: Reuse common test data through fixtures
|
||||
6. **Document tests**: Add docstrings explaining what each test validates
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Import errors
|
||||
Make sure you've installed the package in development mode:
|
||||
```bash
|
||||
uv pip install -e .
|
||||
```
|
||||
|
||||
### Async test failures
|
||||
Ensure you're using `@pytest.mark.asyncio` decorator for async tests.
|
||||
|
||||
### Mock not working
|
||||
Check that you're mocking at the right level and using `AsyncMock` for async functions.
|
||||
|
||||
## Future Enhancements
|
||||
|
||||
- [ ] Add integration tests for full pipeline execution
|
||||
- [ ] Add performance benchmarks
|
||||
- [ ] Add mutation testing for better coverage quality
|
||||
- [ ] Add property-based testing with Hypothesis
|
||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
0
tests/unit_tests/__init__.py
Normal file
0
tests/unit_tests/__init__.py
Normal file
0
tests/unit_tests/pipeline/__init__.py
Normal file
0
tests/unit_tests/pipeline/__init__.py
Normal file
251
tests/unit_tests/pipeline/conftest.py
Normal file
251
tests/unit_tests/pipeline/conftest.py
Normal file
@@ -0,0 +1,251 @@
|
||||
"""
|
||||
Shared test fixtures and configuration
|
||||
|
||||
This file provides infrastructure for all pipeline tests, including:
|
||||
- Mock object factories
|
||||
- Test fixtures
|
||||
- Common test helper functions
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock
|
||||
from typing import Any
|
||||
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
|
||||
from pkg.pipeline import entities as pipeline_entities
|
||||
|
||||
|
||||
class MockApplication:
|
||||
"""Mock Application object providing all basic dependencies needed by stages"""
|
||||
|
||||
def __init__(self):
|
||||
self.logger = self._create_mock_logger()
|
||||
self.sess_mgr = self._create_mock_session_manager()
|
||||
self.model_mgr = self._create_mock_model_manager()
|
||||
self.tool_mgr = self._create_mock_tool_manager()
|
||||
self.plugin_connector = self._create_mock_plugin_connector()
|
||||
self.persistence_mgr = self._create_mock_persistence_manager()
|
||||
self.query_pool = self._create_mock_query_pool()
|
||||
self.instance_config = self._create_mock_instance_config()
|
||||
self.task_mgr = self._create_mock_task_manager()
|
||||
|
||||
def _create_mock_logger(self):
|
||||
logger = Mock()
|
||||
logger.debug = Mock()
|
||||
logger.info = Mock()
|
||||
logger.error = Mock()
|
||||
logger.warning = Mock()
|
||||
return logger
|
||||
|
||||
def _create_mock_session_manager(self):
|
||||
sess_mgr = AsyncMock()
|
||||
sess_mgr.get_session = AsyncMock()
|
||||
sess_mgr.get_conversation = AsyncMock()
|
||||
return sess_mgr
|
||||
|
||||
def _create_mock_model_manager(self):
|
||||
model_mgr = AsyncMock()
|
||||
model_mgr.get_model_by_uuid = AsyncMock()
|
||||
return model_mgr
|
||||
|
||||
def _create_mock_tool_manager(self):
|
||||
tool_mgr = AsyncMock()
|
||||
tool_mgr.get_all_tools = AsyncMock(return_value=[])
|
||||
return tool_mgr
|
||||
|
||||
def _create_mock_plugin_connector(self):
|
||||
plugin_connector = AsyncMock()
|
||||
plugin_connector.emit_event = AsyncMock()
|
||||
return plugin_connector
|
||||
|
||||
def _create_mock_persistence_manager(self):
|
||||
persistence_mgr = AsyncMock()
|
||||
persistence_mgr.execute_async = AsyncMock()
|
||||
return persistence_mgr
|
||||
|
||||
def _create_mock_query_pool(self):
|
||||
query_pool = Mock()
|
||||
query_pool.cached_queries = {}
|
||||
query_pool.queries = []
|
||||
query_pool.condition = AsyncMock()
|
||||
return query_pool
|
||||
|
||||
def _create_mock_instance_config(self):
|
||||
instance_config = Mock()
|
||||
instance_config.data = {
|
||||
'command': {'prefix': ['/', '!'], 'enable': True},
|
||||
'concurrency': {'pipeline': 10},
|
||||
}
|
||||
return instance_config
|
||||
|
||||
def _create_mock_task_manager(self):
|
||||
task_mgr = Mock()
|
||||
task_mgr.create_task = Mock()
|
||||
return task_mgr
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app():
|
||||
"""Provides Mock Application instance"""
|
||||
return MockApplication()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session():
|
||||
"""Provides Mock Session object"""
|
||||
session = Mock()
|
||||
session.launcher_type = provider_session.LauncherTypes.PERSON
|
||||
session.launcher_id = 12345
|
||||
session._semaphore = AsyncMock()
|
||||
session._semaphore.locked = Mock(return_value=False)
|
||||
session._semaphore.acquire = AsyncMock()
|
||||
session._semaphore.release = AsyncMock()
|
||||
return session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_conversation():
|
||||
"""Provides Mock Conversation object"""
|
||||
conversation = Mock()
|
||||
conversation.uuid = 'test-conversation-uuid'
|
||||
|
||||
# Create mock prompt with copy method
|
||||
mock_prompt = Mock()
|
||||
mock_prompt.messages = []
|
||||
mock_prompt.copy = Mock(return_value=Mock(messages=[]))
|
||||
conversation.prompt = mock_prompt
|
||||
|
||||
# Create mock messages list with copy method
|
||||
mock_messages = Mock()
|
||||
mock_messages.copy = Mock(return_value=[])
|
||||
conversation.messages = mock_messages
|
||||
|
||||
return conversation
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_model():
|
||||
"""Provides Mock Model object"""
|
||||
model = Mock()
|
||||
model.model_entity = Mock()
|
||||
model.model_entity.uuid = 'test-model-uuid'
|
||||
model.model_entity.abilities = ['func_call', 'vision']
|
||||
return model
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_adapter():
|
||||
"""Provides Mock Adapter object"""
|
||||
adapter = AsyncMock()
|
||||
adapter.is_stream_output_supported = AsyncMock(return_value=False)
|
||||
adapter.reply_message = AsyncMock()
|
||||
adapter.reply_message_chunk = AsyncMock()
|
||||
return adapter
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_message_chain():
|
||||
"""Provides sample message chain"""
|
||||
return platform_message.MessageChain(
|
||||
[
|
||||
platform_message.Plain(text='Hello, this is a test message'),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_message_event(sample_message_chain):
|
||||
"""Provides sample message event"""
|
||||
event = Mock()
|
||||
event.sender = Mock()
|
||||
event.sender.id = 12345
|
||||
event.time = 1609459200 # 2021-01-01 00:00:00
|
||||
return event
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_query(sample_message_chain, sample_message_event, mock_adapter):
|
||||
"""Provides sample Query object - using model_construct to bypass validation"""
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
|
||||
# Use model_construct to bypass Pydantic validation for test purposes
|
||||
query = pipeline_query.Query.model_construct(
|
||||
query_id='test-query-id',
|
||||
launcher_type=provider_session.LauncherTypes.PERSON,
|
||||
launcher_id=12345,
|
||||
sender_id=12345,
|
||||
message_chain=sample_message_chain,
|
||||
message_event=sample_message_event,
|
||||
adapter=mock_adapter,
|
||||
pipeline_uuid='test-pipeline-uuid',
|
||||
bot_uuid='test-bot-uuid',
|
||||
pipeline_config={
|
||||
'ai': {
|
||||
'runner': {'runner': 'local-agent'},
|
||||
'local-agent': {'model': 'test-model-uuid', 'prompt': 'test-prompt'},
|
||||
},
|
||||
'output': {'misc': {'at-sender': False, 'quote-origin': False}},
|
||||
'trigger': {'misc': {'combine-quote-message': False}},
|
||||
},
|
||||
session=None,
|
||||
prompt=None,
|
||||
messages=[],
|
||||
user_message=None,
|
||||
use_funcs=[],
|
||||
use_llm_model_uuid=None,
|
||||
variables={},
|
||||
resp_messages=[],
|
||||
resp_message_chain=None,
|
||||
current_stage_name=None
|
||||
)
|
||||
return query
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_pipeline_config():
|
||||
"""Provides sample pipeline configuration"""
|
||||
return {
|
||||
'ai': {
|
||||
'runner': {'runner': 'local-agent'},
|
||||
'local-agent': {'model': 'test-model-uuid', 'prompt': 'test-prompt'},
|
||||
},
|
||||
'output': {'misc': {'at-sender': False, 'quote-origin': False}},
|
||||
'trigger': {'misc': {'combine-quote-message': False}},
|
||||
'ratelimit': {'enable': True, 'algo': 'fixwin', 'window': 60, 'limit': 10},
|
||||
}
|
||||
|
||||
|
||||
def create_stage_result(
|
||||
result_type: pipeline_entities.ResultType,
|
||||
query: pipeline_query.Query,
|
||||
user_notice: str = '',
|
||||
console_notice: str = '',
|
||||
debug_notice: str = '',
|
||||
error_notice: str = '',
|
||||
) -> pipeline_entities.StageProcessResult:
|
||||
"""Helper function to create stage process result"""
|
||||
return pipeline_entities.StageProcessResult(
|
||||
result_type=result_type,
|
||||
new_query=query,
|
||||
user_notice=user_notice,
|
||||
console_notice=console_notice,
|
||||
debug_notice=debug_notice,
|
||||
error_notice=error_notice,
|
||||
)
|
||||
|
||||
|
||||
def assert_result_continue(result: pipeline_entities.StageProcessResult):
|
||||
"""Assert result is CONTINUE type"""
|
||||
assert result.result_type == pipeline_entities.ResultType.CONTINUE
|
||||
|
||||
|
||||
def assert_result_interrupt(result: pipeline_entities.StageProcessResult):
|
||||
"""Assert result is INTERRUPT type"""
|
||||
assert result.result_type == pipeline_entities.ResultType.INTERRUPT
|
||||
189
tests/unit_tests/pipeline/test_bansess.py
Normal file
189
tests/unit_tests/pipeline/test_bansess.py
Normal file
@@ -0,0 +1,189 @@
|
||||
"""
|
||||
BanSessionCheckStage unit tests
|
||||
|
||||
Tests the actual BanSessionCheckStage implementation from pkg.pipeline.bansess
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock
|
||||
from importlib import import_module
|
||||
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||
|
||||
|
||||
def get_modules():
|
||||
"""Lazy import to ensure proper initialization order"""
|
||||
# Import pipelinemgr first to trigger proper stage registration
|
||||
pipelinemgr = import_module('pkg.pipeline.pipelinemgr')
|
||||
bansess = import_module('pkg.pipeline.bansess.bansess')
|
||||
entities = import_module('pkg.pipeline.entities')
|
||||
return bansess, entities
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_whitelist_allow(mock_app, sample_query):
|
||||
"""Test whitelist allows matching session"""
|
||||
bansess, entities = get_modules()
|
||||
|
||||
sample_query.launcher_type = provider_session.LauncherTypes.PERSON
|
||||
sample_query.launcher_id = '12345'
|
||||
sample_query.pipeline_config = {
|
||||
'trigger': {
|
||||
'access-control': {
|
||||
'mode': 'whitelist',
|
||||
'whitelist': ['person_12345']
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
stage = bansess.BanSessionCheckStage(mock_app)
|
||||
await stage.initialize(sample_query.pipeline_config)
|
||||
|
||||
result = await stage.process(sample_query, 'BanSessionCheckStage')
|
||||
|
||||
assert result.result_type == entities.ResultType.CONTINUE
|
||||
assert result.new_query == sample_query
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_whitelist_deny(mock_app, sample_query):
|
||||
"""Test whitelist denies non-matching session"""
|
||||
bansess, entities = get_modules()
|
||||
|
||||
sample_query.launcher_type = provider_session.LauncherTypes.PERSON
|
||||
sample_query.launcher_id = '99999'
|
||||
sample_query.pipeline_config = {
|
||||
'trigger': {
|
||||
'access-control': {
|
||||
'mode': 'whitelist',
|
||||
'whitelist': ['person_12345']
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
stage = bansess.BanSessionCheckStage(mock_app)
|
||||
await stage.initialize(sample_query.pipeline_config)
|
||||
|
||||
result = await stage.process(sample_query, 'BanSessionCheckStage')
|
||||
|
||||
assert result.result_type == entities.ResultType.INTERRUPT
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_blacklist_allow(mock_app, sample_query):
|
||||
"""Test blacklist allows non-matching session"""
|
||||
bansess, entities = get_modules()
|
||||
|
||||
sample_query.launcher_type = provider_session.LauncherTypes.PERSON
|
||||
sample_query.launcher_id = '12345'
|
||||
sample_query.pipeline_config = {
|
||||
'trigger': {
|
||||
'access-control': {
|
||||
'mode': 'blacklist',
|
||||
'blacklist': ['person_99999']
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
stage = bansess.BanSessionCheckStage(mock_app)
|
||||
await stage.initialize(sample_query.pipeline_config)
|
||||
|
||||
result = await stage.process(sample_query, 'BanSessionCheckStage')
|
||||
|
||||
assert result.result_type == entities.ResultType.CONTINUE
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_blacklist_deny(mock_app, sample_query):
|
||||
"""Test blacklist denies matching session"""
|
||||
bansess, entities = get_modules()
|
||||
|
||||
sample_query.launcher_type = provider_session.LauncherTypes.PERSON
|
||||
sample_query.launcher_id = '12345'
|
||||
sample_query.pipeline_config = {
|
||||
'trigger': {
|
||||
'access-control': {
|
||||
'mode': 'blacklist',
|
||||
'blacklist': ['person_12345']
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
stage = bansess.BanSessionCheckStage(mock_app)
|
||||
await stage.initialize(sample_query.pipeline_config)
|
||||
|
||||
result = await stage.process(sample_query, 'BanSessionCheckStage')
|
||||
|
||||
assert result.result_type == entities.ResultType.INTERRUPT
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wildcard_group(mock_app, sample_query):
|
||||
"""Test group wildcard matching"""
|
||||
bansess, entities = get_modules()
|
||||
|
||||
sample_query.launcher_type = provider_session.LauncherTypes.GROUP
|
||||
sample_query.launcher_id = '12345'
|
||||
sample_query.pipeline_config = {
|
||||
'trigger': {
|
||||
'access-control': {
|
||||
'mode': 'whitelist',
|
||||
'whitelist': ['group_*']
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
stage = bansess.BanSessionCheckStage(mock_app)
|
||||
await stage.initialize(sample_query.pipeline_config)
|
||||
|
||||
result = await stage.process(sample_query, 'BanSessionCheckStage')
|
||||
|
||||
assert result.result_type == entities.ResultType.CONTINUE
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wildcard_person(mock_app, sample_query):
|
||||
"""Test person wildcard matching"""
|
||||
bansess, entities = get_modules()
|
||||
|
||||
sample_query.launcher_type = provider_session.LauncherTypes.PERSON
|
||||
sample_query.launcher_id = '12345'
|
||||
sample_query.pipeline_config = {
|
||||
'trigger': {
|
||||
'access-control': {
|
||||
'mode': 'whitelist',
|
||||
'whitelist': ['person_*']
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
stage = bansess.BanSessionCheckStage(mock_app)
|
||||
await stage.initialize(sample_query.pipeline_config)
|
||||
|
||||
result = await stage.process(sample_query, 'BanSessionCheckStage')
|
||||
|
||||
assert result.result_type == entities.ResultType.CONTINUE
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_id_wildcard(mock_app, sample_query):
|
||||
"""Test user ID wildcard matching (*_id format)"""
|
||||
bansess, entities = get_modules()
|
||||
|
||||
sample_query.launcher_type = provider_session.LauncherTypes.PERSON
|
||||
sample_query.launcher_id = '12345'
|
||||
sample_query.sender_id = '67890'
|
||||
sample_query.pipeline_config = {
|
||||
'trigger': {
|
||||
'access-control': {
|
||||
'mode': 'whitelist',
|
||||
'whitelist': ['*_67890']
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
stage = bansess.BanSessionCheckStage(mock_app)
|
||||
await stage.initialize(sample_query.pipeline_config)
|
||||
|
||||
result = await stage.process(sample_query, 'BanSessionCheckStage')
|
||||
|
||||
assert result.result_type == entities.ResultType.CONTINUE
|
||||
166
tests/unit_tests/pipeline/test_pipelinemgr.py
Normal file
166
tests/unit_tests/pipeline/test_pipelinemgr.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""
|
||||
PipelineManager unit tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
from importlib import import_module
|
||||
import sqlalchemy
|
||||
|
||||
|
||||
def get_pipelinemgr_module():
|
||||
return import_module('pkg.pipeline.pipelinemgr')
|
||||
|
||||
|
||||
def get_stage_module():
|
||||
return import_module('pkg.pipeline.stage')
|
||||
|
||||
|
||||
def get_entities_module():
|
||||
return import_module('pkg.pipeline.entities')
|
||||
|
||||
|
||||
def get_persistence_pipeline_module():
|
||||
return import_module('pkg.entity.persistence.pipeline')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_manager_initialize(mock_app):
|
||||
"""Test pipeline manager initialization"""
|
||||
pipelinemgr = get_pipelinemgr_module()
|
||||
|
||||
mock_app.persistence_mgr.execute_async = AsyncMock(return_value=Mock(all=Mock(return_value=[])))
|
||||
|
||||
manager = pipelinemgr.PipelineManager(mock_app)
|
||||
await manager.initialize()
|
||||
|
||||
assert manager.stage_dict is not None
|
||||
assert len(manager.pipelines) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_pipeline(mock_app):
|
||||
"""Test loading a single pipeline"""
|
||||
pipelinemgr = get_pipelinemgr_module()
|
||||
persistence_pipeline = get_persistence_pipeline_module()
|
||||
|
||||
mock_app.persistence_mgr.execute_async = AsyncMock(return_value=Mock(all=Mock(return_value=[])))
|
||||
|
||||
manager = pipelinemgr.PipelineManager(mock_app)
|
||||
await manager.initialize()
|
||||
|
||||
# Create test pipeline entity
|
||||
pipeline_entity = Mock(spec=persistence_pipeline.LegacyPipeline)
|
||||
pipeline_entity.uuid = 'test-uuid'
|
||||
pipeline_entity.stages = []
|
||||
pipeline_entity.config = {'test': 'config'}
|
||||
|
||||
await manager.load_pipeline(pipeline_entity)
|
||||
|
||||
assert len(manager.pipelines) == 1
|
||||
assert manager.pipelines[0].pipeline_entity.uuid == 'test-uuid'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_pipeline_by_uuid(mock_app):
|
||||
"""Test getting pipeline by UUID"""
|
||||
pipelinemgr = get_pipelinemgr_module()
|
||||
persistence_pipeline = get_persistence_pipeline_module()
|
||||
|
||||
mock_app.persistence_mgr.execute_async = AsyncMock(return_value=Mock(all=Mock(return_value=[])))
|
||||
|
||||
manager = pipelinemgr.PipelineManager(mock_app)
|
||||
await manager.initialize()
|
||||
|
||||
# Create and add test pipeline
|
||||
pipeline_entity = Mock(spec=persistence_pipeline.LegacyPipeline)
|
||||
pipeline_entity.uuid = 'test-uuid'
|
||||
pipeline_entity.stages = []
|
||||
pipeline_entity.config = {}
|
||||
|
||||
await manager.load_pipeline(pipeline_entity)
|
||||
|
||||
# Test retrieval
|
||||
result = await manager.get_pipeline_by_uuid('test-uuid')
|
||||
assert result is not None
|
||||
assert result.pipeline_entity.uuid == 'test-uuid'
|
||||
|
||||
# Test non-existent UUID
|
||||
result = await manager.get_pipeline_by_uuid('non-existent')
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_pipeline(mock_app):
|
||||
"""Test removing a pipeline"""
|
||||
pipelinemgr = get_pipelinemgr_module()
|
||||
persistence_pipeline = get_persistence_pipeline_module()
|
||||
|
||||
mock_app.persistence_mgr.execute_async = AsyncMock(return_value=Mock(all=Mock(return_value=[])))
|
||||
|
||||
manager = pipelinemgr.PipelineManager(mock_app)
|
||||
await manager.initialize()
|
||||
|
||||
# Create and add test pipeline
|
||||
pipeline_entity = Mock(spec=persistence_pipeline.LegacyPipeline)
|
||||
pipeline_entity.uuid = 'test-uuid'
|
||||
pipeline_entity.stages = []
|
||||
pipeline_entity.config = {}
|
||||
|
||||
await manager.load_pipeline(pipeline_entity)
|
||||
assert len(manager.pipelines) == 1
|
||||
|
||||
# Remove pipeline
|
||||
await manager.remove_pipeline('test-uuid')
|
||||
assert len(manager.pipelines) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_pipeline_execute(mock_app, sample_query):
|
||||
"""Test runtime pipeline execution"""
|
||||
pipelinemgr = get_pipelinemgr_module()
|
||||
stage = get_stage_module()
|
||||
persistence_pipeline = get_persistence_pipeline_module()
|
||||
|
||||
# Create mock stage that returns a simple result dict (avoiding Pydantic validation)
|
||||
mock_result = Mock()
|
||||
mock_result.result_type = Mock()
|
||||
mock_result.result_type.value = 'CONTINUE' # Simulate enum value
|
||||
mock_result.new_query = sample_query
|
||||
mock_result.user_notice = ''
|
||||
mock_result.console_notice = ''
|
||||
mock_result.debug_notice = ''
|
||||
mock_result.error_notice = ''
|
||||
|
||||
# Make it look like ResultType.CONTINUE
|
||||
from unittest.mock import MagicMock
|
||||
CONTINUE = MagicMock()
|
||||
CONTINUE.__eq__ = lambda self, other: True # Always equal for comparison
|
||||
mock_result.result_type = CONTINUE
|
||||
|
||||
mock_stage = Mock(spec=stage.PipelineStage)
|
||||
mock_stage.process = AsyncMock(return_value=mock_result)
|
||||
|
||||
# Create stage container
|
||||
stage_container = pipelinemgr.StageInstContainer(inst_name='TestStage', inst=mock_stage)
|
||||
|
||||
# Create pipeline entity
|
||||
pipeline_entity = Mock(spec=persistence_pipeline.LegacyPipeline)
|
||||
pipeline_entity.config = sample_query.pipeline_config
|
||||
|
||||
# Create runtime pipeline
|
||||
runtime_pipeline = pipelinemgr.RuntimePipeline(mock_app, pipeline_entity, [stage_container])
|
||||
|
||||
# Mock plugin connector
|
||||
event_ctx = Mock()
|
||||
event_ctx.is_prevented_default = Mock(return_value=False)
|
||||
mock_app.plugin_connector.emit_event = AsyncMock(return_value=event_ctx)
|
||||
|
||||
# Add query to cached_queries to prevent KeyError in finally block
|
||||
mock_app.query_pool.cached_queries[sample_query.query_id] = sample_query
|
||||
|
||||
# Execute pipeline
|
||||
await runtime_pipeline.run(sample_query)
|
||||
|
||||
# Verify stage was called
|
||||
mock_stage.process.assert_called_once()
|
||||
109
tests/unit_tests/pipeline/test_ratelimit.py
Normal file
109
tests/unit_tests/pipeline/test_ratelimit.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""
|
||||
RateLimit stage unit tests
|
||||
|
||||
Tests the actual RateLimit implementation from pkg.pipeline.ratelimit
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
from importlib import import_module
|
||||
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||
|
||||
|
||||
def get_modules():
|
||||
"""Lazy import to ensure proper initialization order"""
|
||||
# Import pipelinemgr first to trigger proper stage registration
|
||||
pipelinemgr = import_module('pkg.pipeline.pipelinemgr')
|
||||
ratelimit = import_module('pkg.pipeline.ratelimit.ratelimit')
|
||||
entities = import_module('pkg.pipeline.entities')
|
||||
algo_module = import_module('pkg.pipeline.ratelimit.algo')
|
||||
return ratelimit, entities, algo_module
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_require_access_allowed(mock_app, sample_query):
|
||||
"""Test RequireRateLimitOccupancy allows access when rate limit is not exceeded"""
|
||||
ratelimit, entities, algo_module = get_modules()
|
||||
|
||||
sample_query.launcher_type = provider_session.LauncherTypes.PERSON
|
||||
sample_query.launcher_id = '12345'
|
||||
sample_query.pipeline_config = {}
|
||||
|
||||
# Create mock algorithm that allows access
|
||||
mock_algo = Mock(spec=algo_module.ReteLimitAlgo)
|
||||
mock_algo.require_access = AsyncMock(return_value=True)
|
||||
mock_algo.initialize = AsyncMock()
|
||||
|
||||
stage = ratelimit.RateLimit(mock_app)
|
||||
|
||||
# Patch the algorithm selection to use our mock
|
||||
with patch.object(algo_module, 'preregistered_algos', []):
|
||||
stage.algo = mock_algo
|
||||
|
||||
result = await stage.process(sample_query, 'RequireRateLimitOccupancy')
|
||||
|
||||
assert result.result_type == entities.ResultType.CONTINUE
|
||||
assert result.new_query == sample_query
|
||||
mock_algo.require_access.assert_called_once_with(
|
||||
sample_query,
|
||||
'person',
|
||||
'12345'
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_require_access_denied(mock_app, sample_query):
|
||||
"""Test RequireRateLimitOccupancy denies access when rate limit is exceeded"""
|
||||
ratelimit, entities, algo_module = get_modules()
|
||||
|
||||
sample_query.launcher_type = provider_session.LauncherTypes.PERSON
|
||||
sample_query.launcher_id = '12345'
|
||||
sample_query.pipeline_config = {}
|
||||
|
||||
# Create mock algorithm that denies access
|
||||
mock_algo = Mock(spec=algo_module.ReteLimitAlgo)
|
||||
mock_algo.require_access = AsyncMock(return_value=False)
|
||||
mock_algo.initialize = AsyncMock()
|
||||
|
||||
stage = ratelimit.RateLimit(mock_app)
|
||||
|
||||
# Patch the algorithm selection to use our mock
|
||||
with patch.object(algo_module, 'preregistered_algos', []):
|
||||
stage.algo = mock_algo
|
||||
|
||||
result = await stage.process(sample_query, 'RequireRateLimitOccupancy')
|
||||
|
||||
assert result.result_type == entities.ResultType.INTERRUPT
|
||||
assert result.user_notice != ''
|
||||
mock_algo.require_access.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_release_access(mock_app, sample_query):
|
||||
"""Test ReleaseRateLimitOccupancy releases rate limit occupancy"""
|
||||
ratelimit, entities, algo_module = get_modules()
|
||||
|
||||
sample_query.launcher_type = provider_session.LauncherTypes.PERSON
|
||||
sample_query.launcher_id = '12345'
|
||||
sample_query.pipeline_config = {}
|
||||
|
||||
# Create mock algorithm
|
||||
mock_algo = Mock(spec=algo_module.ReteLimitAlgo)
|
||||
mock_algo.release_access = AsyncMock()
|
||||
mock_algo.initialize = AsyncMock()
|
||||
|
||||
stage = ratelimit.RateLimit(mock_app)
|
||||
|
||||
# Patch the algorithm selection to use our mock
|
||||
with patch.object(algo_module, 'preregistered_algos', []):
|
||||
stage.algo = mock_algo
|
||||
|
||||
result = await stage.process(sample_query, 'ReleaseRateLimitOccupancy')
|
||||
|
||||
assert result.result_type == entities.ResultType.CONTINUE
|
||||
assert result.new_query == sample_query
|
||||
mock_algo.release_access.assert_called_once_with(
|
||||
sample_query,
|
||||
'person',
|
||||
'12345'
|
||||
)
|
||||
171
tests/unit_tests/pipeline/test_resprule.py
Normal file
171
tests/unit_tests/pipeline/test_resprule.py
Normal file
@@ -0,0 +1,171 @@
|
||||
"""
|
||||
GroupRespondRuleCheckStage unit tests
|
||||
|
||||
Tests the actual GroupRespondRuleCheckStage implementation from pkg.pipeline.resprule
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
from importlib import import_module
|
||||
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
|
||||
|
||||
def get_modules():
|
||||
"""Lazy import to ensure proper initialization order"""
|
||||
# Import pipelinemgr first to trigger proper stage registration
|
||||
pipelinemgr = import_module('pkg.pipeline.pipelinemgr')
|
||||
resprule = import_module('pkg.pipeline.resprule.resprule')
|
||||
entities = import_module('pkg.pipeline.entities')
|
||||
rule = import_module('pkg.pipeline.resprule.rule')
|
||||
rule_entities = import_module('pkg.pipeline.resprule.entities')
|
||||
return resprule, entities, rule, rule_entities
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_person_message_skip(mock_app, sample_query):
|
||||
"""Test person message skips rule check"""
|
||||
resprule, entities, rule, rule_entities = get_modules()
|
||||
|
||||
sample_query.launcher_type = provider_session.LauncherTypes.PERSON
|
||||
sample_query.pipeline_config = {
|
||||
'trigger': {
|
||||
'group-respond-rules': {}
|
||||
}
|
||||
}
|
||||
|
||||
stage = resprule.GroupRespondRuleCheckStage(mock_app)
|
||||
await stage.initialize(sample_query.pipeline_config)
|
||||
|
||||
result = await stage.process(sample_query, 'GroupRespondRuleCheckStage')
|
||||
|
||||
assert result.result_type == entities.ResultType.CONTINUE
|
||||
assert result.new_query == sample_query
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_message_no_match(mock_app, sample_query):
|
||||
"""Test group message with no matching rules"""
|
||||
resprule, entities, rule, rule_entities = get_modules()
|
||||
|
||||
sample_query.launcher_type = provider_session.LauncherTypes.GROUP
|
||||
sample_query.launcher_id = '12345'
|
||||
sample_query.pipeline_config = {
|
||||
'trigger': {
|
||||
'group-respond-rules': {}
|
||||
}
|
||||
}
|
||||
|
||||
# Create mock rule matcher that doesn't match
|
||||
mock_rule = Mock(spec=rule.GroupRespondRule)
|
||||
mock_rule.match = AsyncMock(return_value=rule_entities.RuleJudgeResult(
|
||||
matching=False,
|
||||
replacement=sample_query.message_chain
|
||||
))
|
||||
|
||||
stage = resprule.GroupRespondRuleCheckStage(mock_app)
|
||||
await stage.initialize(sample_query.pipeline_config)
|
||||
stage.rule_matchers = [mock_rule]
|
||||
|
||||
result = await stage.process(sample_query, 'GroupRespondRuleCheckStage')
|
||||
|
||||
assert result.result_type == entities.ResultType.INTERRUPT
|
||||
assert result.new_query == sample_query
|
||||
mock_rule.match.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_message_match(mock_app, sample_query):
|
||||
"""Test group message with matching rule"""
|
||||
resprule, entities, rule, rule_entities = get_modules()
|
||||
|
||||
sample_query.launcher_type = provider_session.LauncherTypes.GROUP
|
||||
sample_query.launcher_id = '12345'
|
||||
sample_query.pipeline_config = {
|
||||
'trigger': {
|
||||
'group-respond-rules': {}
|
||||
}
|
||||
}
|
||||
|
||||
# Create new message chain after rule processing
|
||||
new_chain = platform_message.MessageChain([
|
||||
platform_message.Plain(text='Processed message')
|
||||
])
|
||||
|
||||
# Create mock rule matcher that matches
|
||||
mock_rule = Mock(spec=rule.GroupRespondRule)
|
||||
mock_rule.match = AsyncMock(return_value=rule_entities.RuleJudgeResult(
|
||||
matching=True,
|
||||
replacement=new_chain
|
||||
))
|
||||
|
||||
stage = resprule.GroupRespondRuleCheckStage(mock_app)
|
||||
await stage.initialize(sample_query.pipeline_config)
|
||||
stage.rule_matchers = [mock_rule]
|
||||
|
||||
result = await stage.process(sample_query, 'GroupRespondRuleCheckStage')
|
||||
|
||||
assert result.result_type == entities.ResultType.CONTINUE
|
||||
assert result.new_query == sample_query
|
||||
assert sample_query.message_chain == new_chain
|
||||
mock_rule.match.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_atbot_rule_match(mock_app, sample_query):
|
||||
"""Test AtBotRule removes At component"""
|
||||
resprule, entities, rule, rule_entities = get_modules()
|
||||
atbot_module = import_module('pkg.pipeline.resprule.rules.atbot')
|
||||
|
||||
sample_query.launcher_type = provider_session.LauncherTypes.GROUP
|
||||
sample_query.adapter.bot_account_id = '999'
|
||||
|
||||
# Create message chain with At component
|
||||
message_chain = platform_message.MessageChain([
|
||||
platform_message.At(target='999'),
|
||||
platform_message.Plain(text='Hello bot')
|
||||
])
|
||||
sample_query.message_chain = message_chain
|
||||
|
||||
atbot_rule = atbot_module.AtBotRule(mock_app)
|
||||
await atbot_rule.initialize()
|
||||
|
||||
result = await atbot_rule.match(
|
||||
str(message_chain),
|
||||
message_chain,
|
||||
{},
|
||||
sample_query
|
||||
)
|
||||
|
||||
assert result.matching is True
|
||||
# At component should be removed
|
||||
assert len(result.replacement.root) == 1
|
||||
assert isinstance(result.replacement.root[0], platform_message.Plain)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_atbot_rule_no_match(mock_app, sample_query):
|
||||
"""Test AtBotRule when no At component present"""
|
||||
resprule, entities, rule, rule_entities = get_modules()
|
||||
atbot_module = import_module('pkg.pipeline.resprule.rules.atbot')
|
||||
|
||||
sample_query.launcher_type = provider_session.LauncherTypes.GROUP
|
||||
sample_query.adapter.bot_account_id = '999'
|
||||
|
||||
# Create message chain without At component
|
||||
message_chain = platform_message.MessageChain([
|
||||
platform_message.Plain(text='Hello')
|
||||
])
|
||||
sample_query.message_chain = message_chain
|
||||
|
||||
atbot_rule = atbot_module.AtBotRule(mock_app)
|
||||
await atbot_rule.initialize()
|
||||
|
||||
result = await atbot_rule.match(
|
||||
str(message_chain),
|
||||
message_chain,
|
||||
{},
|
||||
sample_query
|
||||
)
|
||||
|
||||
assert result.matching is False
|
||||
40
tests/unit_tests/pipeline/test_simple.py
Normal file
40
tests/unit_tests/pipeline/test_simple.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""
|
||||
Simple standalone tests to verify test infrastructure
|
||||
These tests don't import the actual pipeline code to avoid circular import issues
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, AsyncMock
|
||||
|
||||
|
||||
def test_pytest_works():
|
||||
"""Verify pytest is working"""
|
||||
assert True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_works():
|
||||
"""Verify async tests work"""
|
||||
mock = AsyncMock(return_value=42)
|
||||
result = await mock()
|
||||
assert result == 42
|
||||
|
||||
|
||||
def test_mocks_work():
|
||||
"""Verify mocking works"""
|
||||
mock = Mock()
|
||||
mock.return_value = 'test'
|
||||
assert mock() == 'test'
|
||||
|
||||
|
||||
def test_fixtures_work(mock_app):
|
||||
"""Verify fixtures are loaded"""
|
||||
assert mock_app is not None
|
||||
assert mock_app.logger is not None
|
||||
assert mock_app.sess_mgr is not None
|
||||
|
||||
|
||||
def test_sample_query(sample_query):
|
||||
"""Verify sample query fixture works"""
|
||||
assert sample_query.query_id == 'test-query-id'
|
||||
assert sample_query.launcher_id == 12345
|
||||
@@ -11,6 +11,10 @@ import { Message } from '@/app/infra/entities/message';
|
||||
import { toast } from 'sonner';
|
||||
import AtBadge from './AtBadge';
|
||||
import { Switch } from '@/components/ui/switch';
|
||||
import {
|
||||
PipelineWebSocketClient,
|
||||
SessionType,
|
||||
} from '@/app/infra/websocket/PipelineWebSocketClient';
|
||||
|
||||
interface MessageComponent {
|
||||
type: 'At' | 'Plain';
|
||||
@@ -31,17 +35,27 @@ export default function DebugDialog({
|
||||
}: DebugDialogProps) {
|
||||
const { t } = useTranslation();
|
||||
const [selectedPipelineId, setSelectedPipelineId] = useState(pipelineId);
|
||||
const [sessionType, setSessionType] = useState<'person' | 'group'>('person');
|
||||
const [sessionType, setSessionType] = useState<SessionType>('person');
|
||||
const [messages, setMessages] = useState<Message[]>([]);
|
||||
const [inputValue, setInputValue] = useState('');
|
||||
const [showAtPopover, setShowAtPopover] = useState(false);
|
||||
const [hasAt, setHasAt] = useState(false);
|
||||
const [isHovering, setIsHovering] = useState(false);
|
||||
const [isStreaming, setIsStreaming] = useState(true);
|
||||
const messagesEndRef = useRef<HTMLDivElement>(null);
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
const popoverRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
// WebSocket states
|
||||
const [wsClient, setWsClient] = useState<PipelineWebSocketClient | null>(
|
||||
null,
|
||||
);
|
||||
const [connectionStatus, setConnectionStatus] = useState<
|
||||
'disconnected' | 'connecting' | 'connected'
|
||||
>('disconnected');
|
||||
const [pendingMessages, setPendingMessages] = useState<
|
||||
Map<string, Message>
|
||||
>(new Map());
|
||||
|
||||
const scrollToBottom = useCallback(() => {
|
||||
// 使用setTimeout确保在DOM更新后执行滚动
|
||||
setTimeout(() => {
|
||||
@@ -57,37 +71,161 @@ export default function DebugDialog({
|
||||
}, 0);
|
||||
}, []);
|
||||
|
||||
const loadMessages = useCallback(
|
||||
async (pipelineId: string) => {
|
||||
try {
|
||||
const response = await httpClient.getWebChatHistoryMessages(
|
||||
pipelineId,
|
||||
sessionType,
|
||||
);
|
||||
setMessages(response.messages);
|
||||
} catch (error) {
|
||||
console.error('Failed to load messages:', error);
|
||||
}
|
||||
},
|
||||
[sessionType],
|
||||
);
|
||||
// 在useEffect中监听messages变化时滚动
|
||||
// Scroll to bottom when messages change
|
||||
useEffect(() => {
|
||||
scrollToBottom();
|
||||
}, [messages, scrollToBottom]);
|
||||
|
||||
// WebSocket connection setup
|
||||
useEffect(() => {
|
||||
if (open) {
|
||||
setSelectedPipelineId(pipelineId);
|
||||
loadMessages(pipelineId);
|
||||
}
|
||||
}, [open, pipelineId]);
|
||||
if (!open) return;
|
||||
|
||||
useEffect(() => {
|
||||
if (open) {
|
||||
loadMessages(selectedPipelineId);
|
||||
}
|
||||
}, [sessionType, selectedPipelineId, open, loadMessages]);
|
||||
const client = new PipelineWebSocketClient(
|
||||
selectedPipelineId,
|
||||
sessionType,
|
||||
);
|
||||
|
||||
// Setup event handlers
|
||||
client.onConnected = (data) => {
|
||||
console.log('[DebugDialog] WebSocket connected:', data);
|
||||
setConnectionStatus('connected');
|
||||
// Load history messages after connection
|
||||
client.loadHistory();
|
||||
};
|
||||
|
||||
client.onHistory = (data) => {
|
||||
console.log('[DebugDialog] History loaded:', data?.messages.length);
|
||||
if (data) {
|
||||
setMessages(data.messages);
|
||||
}
|
||||
};
|
||||
|
||||
client.onMessageSent = (data) => {
|
||||
console.log('[DebugDialog] Message sent confirmed:', data);
|
||||
if (data) {
|
||||
// Update client message ID to server message ID
|
||||
const clientMsgId = data.client_message_id;
|
||||
const serverMsgId = data.server_message_id;
|
||||
|
||||
setMessages((prev) =>
|
||||
prev.map((msg) =>
|
||||
msg.id === -1 && pendingMessages.has(clientMsgId)
|
||||
? { ...msg, id: serverMsgId }
|
||||
: msg,
|
||||
),
|
||||
);
|
||||
|
||||
setPendingMessages((prev) => {
|
||||
const newMap = new Map(prev);
|
||||
newMap.delete(clientMsgId);
|
||||
return newMap;
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
client.onMessageStart = (data) => {
|
||||
console.log('[DebugDialog] Message start:', data);
|
||||
if (data) {
|
||||
const placeholderMessage: Message = {
|
||||
id: data.message_id,
|
||||
role: 'assistant',
|
||||
content: '',
|
||||
message_chain: [],
|
||||
timestamp: data.timestamp,
|
||||
};
|
||||
setMessages((prev) => [...prev, placeholderMessage]);
|
||||
}
|
||||
};
|
||||
|
||||
client.onMessageChunk = (data) => {
|
||||
if (data) {
|
||||
// Update streaming message (content is cumulative)
|
||||
setMessages((prev) =>
|
||||
prev.map((msg) =>
|
||||
msg.id === data.message_id
|
||||
? {
|
||||
...msg,
|
||||
content: data.content,
|
||||
message_chain: data.message_chain,
|
||||
}
|
||||
: msg,
|
||||
),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
client.onMessageComplete = (data) => {
|
||||
console.log('[DebugDialog] Message complete:', data);
|
||||
if (data) {
|
||||
// Mark message as complete
|
||||
setMessages((prev) =>
|
||||
prev.map((msg) =>
|
||||
msg.id === data.message_id
|
||||
? {
|
||||
...msg,
|
||||
content: data.final_content,
|
||||
message_chain: data.message_chain,
|
||||
}
|
||||
: msg,
|
||||
),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
client.onMessageError = (data) => {
|
||||
console.error('[DebugDialog] Message error:', data);
|
||||
if (data) {
|
||||
toast.error(`Message error: ${data.error}`);
|
||||
}
|
||||
};
|
||||
|
||||
client.onPluginMessage = (data) => {
|
||||
console.log('[DebugDialog] Plugin message:', data);
|
||||
if (data) {
|
||||
const pluginMessage: Message = {
|
||||
id: data.message_id,
|
||||
role: 'assistant',
|
||||
content: data.content,
|
||||
message_chain: data.message_chain,
|
||||
timestamp: data.timestamp,
|
||||
};
|
||||
setMessages((prev) => [...prev, pluginMessage]);
|
||||
}
|
||||
};
|
||||
|
||||
client.onError = (data) => {
|
||||
console.error('[DebugDialog] WebSocket error:', data);
|
||||
if (data) {
|
||||
toast.error(`WebSocket error: ${data.error}`);
|
||||
}
|
||||
};
|
||||
|
||||
client.onDisconnected = () => {
|
||||
console.log('[DebugDialog] WebSocket disconnected');
|
||||
setConnectionStatus('disconnected');
|
||||
};
|
||||
|
||||
// Connect to WebSocket
|
||||
setConnectionStatus('connecting');
|
||||
client
|
||||
.connect(httpClient.getSessionSync())
|
||||
.then(() => {
|
||||
console.log('[DebugDialog] WebSocket connection established');
|
||||
})
|
||||
.catch((err) => {
|
||||
console.error('[DebugDialog] Failed to connect WebSocket:', err);
|
||||
toast.error('Failed to connect to server');
|
||||
setConnectionStatus('disconnected');
|
||||
});
|
||||
|
||||
setWsClient(client);
|
||||
|
||||
// Cleanup on unmount or session type change
|
||||
return () => {
|
||||
console.log('[DebugDialog] Cleaning up WebSocket connection');
|
||||
client.disconnect();
|
||||
};
|
||||
}, [open, selectedPipelineId, sessionType]); // Reconnect when session type changes
|
||||
|
||||
useEffect(() => {
|
||||
const handleClickOutside = (event: MouseEvent) => {
|
||||
@@ -150,6 +288,12 @@ export default function DebugDialog({
|
||||
const sendMessage = async () => {
|
||||
if (!inputValue.trim() && !hasAt) return;
|
||||
|
||||
// Check WebSocket connection
|
||||
if (!wsClient || connectionStatus !== 'connected') {
|
||||
toast.error('Not connected to server');
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const messageChain = [];
|
||||
|
||||
@@ -170,7 +314,7 @@ export default function DebugDialog({
|
||||
});
|
||||
|
||||
if (hasAt) {
|
||||
// for showing
|
||||
// For display
|
||||
text_content = '@webchatbot' + text_content;
|
||||
}
|
||||
|
||||
@@ -181,97 +325,26 @@ export default function DebugDialog({
|
||||
timestamp: new Date().toISOString(),
|
||||
message_chain: messageChain,
|
||||
};
|
||||
// 根据isStreaming状态决定使用哪种传输方式
|
||||
if (isStreaming) {
|
||||
// streaming
|
||||
// 创建初始bot消息
|
||||
const placeholderRandomId = Math.floor(Math.random() * 1000000);
|
||||
const botMessagePlaceholder: Message = {
|
||||
id: placeholderRandomId,
|
||||
role: 'assistant',
|
||||
content: 'Generating...',
|
||||
timestamp: new Date().toISOString(),
|
||||
message_chain: [{ type: 'Plain', text: 'Generating...' }],
|
||||
};
|
||||
|
||||
// 添加用户消息和初始bot消息到状态
|
||||
// Add user message to UI immediately
|
||||
setMessages((prevMessages) => [...prevMessages, userMessage]);
|
||||
setInputValue('');
|
||||
setHasAt(false);
|
||||
|
||||
setMessages((prevMessages) => [
|
||||
...prevMessages,
|
||||
userMessage,
|
||||
botMessagePlaceholder,
|
||||
]);
|
||||
setInputValue('');
|
||||
setHasAt(false);
|
||||
try {
|
||||
await httpClient.sendStreamingWebChatMessage(
|
||||
sessionType,
|
||||
messageChain,
|
||||
selectedPipelineId,
|
||||
(data) => {
|
||||
// 处理流式响应数据
|
||||
console.log('data', data);
|
||||
if (data.message) {
|
||||
// 更新完整内容
|
||||
// Send via WebSocket
|
||||
const clientMessageId = wsClient.sendMessage(messageChain);
|
||||
|
||||
setMessages((prevMessages) => {
|
||||
const updatedMessages = [...prevMessages];
|
||||
const botMessageIndex = updatedMessages.findIndex(
|
||||
(message) => message.id === placeholderRandomId,
|
||||
);
|
||||
if (botMessageIndex !== -1) {
|
||||
updatedMessages[botMessageIndex] = {
|
||||
...updatedMessages[botMessageIndex],
|
||||
content: data.message.content,
|
||||
message_chain: [
|
||||
{ type: 'Plain', text: data.message.content },
|
||||
],
|
||||
};
|
||||
}
|
||||
return updatedMessages;
|
||||
});
|
||||
}
|
||||
},
|
||||
() => {},
|
||||
(error) => {
|
||||
// 处理错误
|
||||
console.error('Streaming error:', error);
|
||||
if (sessionType === 'person') {
|
||||
toast.error(t('pipelines.debugDialog.sendFailed'));
|
||||
}
|
||||
},
|
||||
);
|
||||
} catch (error) {
|
||||
console.error('Failed to send streaming message:', error);
|
||||
if (sessionType === 'person') {
|
||||
toast.error(t('pipelines.debugDialog.sendFailed'));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// non-streaming
|
||||
setMessages((prevMessages) => [...prevMessages, userMessage]);
|
||||
setInputValue('');
|
||||
setHasAt(false);
|
||||
// Track pending message for ID mapping
|
||||
setPendingMessages((prev) => {
|
||||
const newMap = new Map(prev);
|
||||
newMap.set(clientMessageId, userMessage);
|
||||
return newMap;
|
||||
});
|
||||
|
||||
const response = await httpClient.sendWebChatMessage(
|
||||
sessionType,
|
||||
messageChain,
|
||||
selectedPipelineId,
|
||||
180000,
|
||||
);
|
||||
|
||||
setMessages((prevMessages) => [...prevMessages, response.message]);
|
||||
}
|
||||
} catch (
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
error: any
|
||||
) {
|
||||
console.log(error, 'type of error', typeof error);
|
||||
console.error('Failed to send message:', error);
|
||||
|
||||
if (!error.message.includes('timeout') && sessionType === 'person') {
|
||||
toast.error(t('pipelines.debugDialog.sendFailed'));
|
||||
}
|
||||
console.log('[DebugDialog] Message sent:', clientMessageId);
|
||||
} catch (error) {
|
||||
console.error('[DebugDialog] Failed to send message:', error);
|
||||
toast.error(t('pipelines.debugDialog.sendFailed'));
|
||||
} finally {
|
||||
inputRef.current?.focus();
|
||||
}
|
||||
@@ -390,12 +463,6 @@ export default function DebugDialog({
|
||||
</ScrollArea>
|
||||
|
||||
<div className="p-4 pb-0 bg-white dark:bg-black flex gap-2">
|
||||
<div className="flex items-center gap-2">
|
||||
<span className="text-sm text-gray-600">
|
||||
{t('pipelines.debugDialog.streaming')}
|
||||
</span>
|
||||
<Switch checked={isStreaming} onCheckedChange={setIsStreaming} />
|
||||
</div>
|
||||
<div className="flex-1 flex items-center gap-2">
|
||||
{hasAt && (
|
||||
<AtBadge targetName="webchatbot" onRemove={handleAtRemove} />
|
||||
|
||||
@@ -29,7 +29,17 @@ export default function PluginConfigPage() {
|
||||
const [sortOrderValue, setSortOrderValue] = useState<string>('DESC');
|
||||
|
||||
useEffect(() => {
|
||||
getPipelines();
|
||||
// Load sort preference from localStorage
|
||||
const savedSortBy = localStorage.getItem('pipeline_sort_by');
|
||||
const savedSortOrder = localStorage.getItem('pipeline_sort_order');
|
||||
|
||||
if (savedSortBy && savedSortOrder) {
|
||||
setSortByValue(savedSortBy);
|
||||
setSortOrderValue(savedSortOrder);
|
||||
getPipelines(savedSortBy, savedSortOrder);
|
||||
} else {
|
||||
getPipelines();
|
||||
}
|
||||
}, []);
|
||||
|
||||
function getPipelines(
|
||||
@@ -91,6 +101,11 @@ export default function PluginConfigPage() {
|
||||
const [newSortBy, newSortOrder] = value.split(',').map((s) => s.trim());
|
||||
setSortByValue(newSortBy);
|
||||
setSortOrderValue(newSortOrder);
|
||||
|
||||
// Save sort preference to localStorage
|
||||
localStorage.setItem('pipeline_sort_by', newSortBy);
|
||||
localStorage.setItem('pipeline_sort_order', newSortOrder);
|
||||
|
||||
getPipelines(newSortBy, newSortOrder);
|
||||
}
|
||||
|
||||
@@ -135,6 +150,12 @@ export default function PluginConfigPage() {
|
||||
>
|
||||
{t('pipelines.newestCreated')}
|
||||
</SelectItem>
|
||||
<SelectItem
|
||||
value="created_at,ASC"
|
||||
className="text-gray-900 dark:text-gray-100"
|
||||
>
|
||||
{t('pipelines.earliestCreated')}
|
||||
</SelectItem>
|
||||
<SelectItem
|
||||
value="updated_at,DESC"
|
||||
className="text-gray-900 dark:text-gray-100"
|
||||
|
||||
@@ -0,0 +1,75 @@
|
||||
import { PluginComponent } from '@/app/infra/entities/plugin';
|
||||
import { TFunction } from 'i18next';
|
||||
import { Wrench, AudioWaveform, Hash } from 'lucide-react';
|
||||
import { Badge } from '@/components/ui/badge';
|
||||
|
||||
export default function PluginComponentList({
|
||||
components,
|
||||
showComponentName,
|
||||
showTitle,
|
||||
useBadge,
|
||||
t,
|
||||
}: {
|
||||
components: PluginComponent[];
|
||||
showComponentName: boolean;
|
||||
showTitle: boolean;
|
||||
useBadge: boolean;
|
||||
t: TFunction;
|
||||
}) {
|
||||
const componentKindCount: Record<string, number> = {};
|
||||
|
||||
for (const component of components) {
|
||||
const kind = component.manifest.manifest.kind;
|
||||
if (componentKindCount[kind]) {
|
||||
componentKindCount[kind]++;
|
||||
} else {
|
||||
componentKindCount[kind] = 1;
|
||||
}
|
||||
}
|
||||
|
||||
const kindIconMap: Record<string, React.ReactNode> = {
|
||||
Tool: <Wrench className="w-5 h-5" />,
|
||||
EventListener: <AudioWaveform className="w-5 h-5" />,
|
||||
Command: <Hash className="w-5 h-5" />,
|
||||
};
|
||||
|
||||
const componentKindList = Object.keys(componentKindCount);
|
||||
|
||||
return (
|
||||
<>
|
||||
{showTitle && <div>{t('plugins.componentsList')}</div>}
|
||||
{componentKindList.length > 0 && (
|
||||
<>
|
||||
{componentKindList.map((kind) => {
|
||||
return (
|
||||
<>
|
||||
{useBadge && (
|
||||
<Badge variant="outline">
|
||||
{kindIconMap[kind]}
|
||||
{showComponentName &&
|
||||
t('plugins.componentName.' + kind) + ' '}
|
||||
{componentKindCount[kind]}
|
||||
</Badge>
|
||||
)}
|
||||
|
||||
{!useBadge && (
|
||||
<div
|
||||
key={kind}
|
||||
className="flex flex-row items-center justify-start gap-[0.2rem]"
|
||||
>
|
||||
{kindIconMap[kind]}
|
||||
{showComponentName &&
|
||||
t('plugins.componentName.' + kind) + ' '}
|
||||
{componentKindCount[kind]}
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
})}
|
||||
</>
|
||||
)}
|
||||
|
||||
{componentKindList.length === 0 && <div>{t('plugins.noComponents')}</div>}
|
||||
</>
|
||||
);
|
||||
}
|
||||
@@ -1,9 +1,9 @@
|
||||
'use client';
|
||||
|
||||
import { useState, useEffect, forwardRef, useImperativeHandle } from 'react';
|
||||
import { PluginCardVO } from '@/app/home/plugins/plugin-installed/PluginCardVO';
|
||||
import PluginCardComponent from '@/app/home/plugins/plugin-installed/plugin-card/PluginCardComponent';
|
||||
import PluginForm from '@/app/home/plugins/plugin-installed/plugin-form/PluginForm';
|
||||
import { PluginCardVO } from '@/app/home/plugins/components/plugin-installed/PluginCardVO';
|
||||
import PluginCardComponent from '@/app/home/plugins/components/plugin-installed/plugin-card/PluginCardComponent';
|
||||
import PluginForm from '@/app/home/plugins/components/plugin-installed/plugin-form/PluginForm';
|
||||
import styles from '@/app/home/plugins/plugins.module.css';
|
||||
import { httpClient } from '@/app/infra/http/HttpClient';
|
||||
import {
|
||||
@@ -1,21 +1,10 @@
|
||||
import { PluginCardVO } from '@/app/home/plugins/plugin-installed/PluginCardVO';
|
||||
import { PluginCardVO } from '@/app/home/plugins/components/plugin-installed/PluginCardVO';
|
||||
import { useState } from 'react';
|
||||
import { Badge } from '@/components/ui/badge';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { TFunction } from 'i18next';
|
||||
import {
|
||||
AudioWaveform,
|
||||
Wrench,
|
||||
Hash,
|
||||
BugIcon,
|
||||
ExternalLink,
|
||||
Ellipsis,
|
||||
Trash,
|
||||
ArrowUp,
|
||||
} from 'lucide-react';
|
||||
import { BugIcon, ExternalLink, Ellipsis, Trash, ArrowUp } from 'lucide-react';
|
||||
import { getCloudServiceClientSync } from '@/app/infra/http';
|
||||
import { httpClient } from '@/app/infra/http/HttpClient';
|
||||
import { PluginComponent } from '@/app/infra/entities/plugin';
|
||||
import { Button } from '@/components/ui/button';
|
||||
import {
|
||||
DropdownMenu,
|
||||
@@ -23,49 +12,7 @@ import {
|
||||
DropdownMenuItem,
|
||||
DropdownMenuTrigger,
|
||||
} from '@/components/ui/dropdown-menu';
|
||||
|
||||
function getComponentList(components: PluginComponent[], t: TFunction) {
|
||||
const componentKindCount: Record<string, number> = {};
|
||||
|
||||
for (const component of components) {
|
||||
const kind = component.manifest.manifest.kind;
|
||||
if (componentKindCount[kind]) {
|
||||
componentKindCount[kind]++;
|
||||
} else {
|
||||
componentKindCount[kind] = 1;
|
||||
}
|
||||
}
|
||||
|
||||
const kindIconMap: Record<string, React.ReactNode> = {
|
||||
Tool: <Wrench className="w-5 h-5" />,
|
||||
EventListener: <AudioWaveform className="w-5 h-5" />,
|
||||
Command: <Hash className="w-5 h-5" />,
|
||||
};
|
||||
|
||||
const componentKindList = Object.keys(componentKindCount);
|
||||
|
||||
return (
|
||||
<>
|
||||
<div>{t('plugins.componentsList')}</div>
|
||||
{componentKindList.length > 0 && (
|
||||
<>
|
||||
{componentKindList.map((kind) => {
|
||||
return (
|
||||
<div
|
||||
key={kind}
|
||||
className="flex flex-row items-center justify-start gap-[0.4rem]"
|
||||
>
|
||||
{kindIconMap[kind]} {componentKindCount[kind]}
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</>
|
||||
)}
|
||||
|
||||
{componentKindList.length === 0 && <div>{t('plugins.noComponents')}</div>}
|
||||
</>
|
||||
);
|
||||
}
|
||||
import PluginComponentList from '@/app/home/plugins/components/plugin-installed/PluginComponentList';
|
||||
|
||||
export default function PluginCardComponent({
|
||||
cardVO,
|
||||
@@ -180,7 +127,13 @@ export default function PluginCardComponent({
|
||||
</div>
|
||||
|
||||
<div className="w-full flex flex-row items-start justify-start gap-[0.6rem]">
|
||||
{getComponentList(cardVO.components, t)}
|
||||
<PluginComponentList
|
||||
components={cardVO.components}
|
||||
showComponentName={false}
|
||||
showTitle={true}
|
||||
useBadge={false}
|
||||
t={t}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -7,6 +7,7 @@ import { Button } from '@/components/ui/button';
|
||||
import { toast } from 'sonner';
|
||||
import { extractI18nObject } from '@/i18n/I18nProvider';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import PluginComponentList from '@/app/home/plugins/components/plugin-installed/PluginComponentList';
|
||||
|
||||
export default function PluginForm({
|
||||
pluginAuthor,
|
||||
@@ -78,6 +79,17 @@ export default function PluginForm({
|
||||
},
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="mb-4 flex flex-row items-center justify-start gap-[0.4rem]">
|
||||
<PluginComponentList
|
||||
components={pluginInfo.components}
|
||||
showComponentName={true}
|
||||
showTitle={false}
|
||||
useBadge={true}
|
||||
t={t}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{pluginInfo.manifest.manifest.spec.config.length > 0 && (
|
||||
<DynamicFormComponent
|
||||
itemConfigList={pluginInfo.manifest.manifest.spec.config}
|
||||
@@ -1,8 +1,8 @@
|
||||
'use client';
|
||||
import PluginInstalledComponent, {
|
||||
PluginInstalledComponentRef,
|
||||
} from '@/app/home/plugins/plugin-installed/PluginInstalledComponent';
|
||||
import MarketPage from '@/app/home/plugins/plugin-market/PluginMarketComponent';
|
||||
} from '@/app/home/plugins/components/plugin-installed/PluginInstalledComponent';
|
||||
import MarketPage from '@/app/home/plugins/components/plugin-market/PluginMarketComponent';
|
||||
// import PluginSortDialog from '@/app/home/plugins/plugin-sort/PluginSortDialog';
|
||||
import styles from './plugins.module.css';
|
||||
import { Tabs, TabsContent, TabsList, TabsTrigger } from '@/components/ui/tabs';
|
||||
@@ -13,6 +13,7 @@ import {
|
||||
UploadIcon,
|
||||
StoreIcon,
|
||||
Download,
|
||||
Power,
|
||||
} from 'lucide-react';
|
||||
import {
|
||||
DropdownMenu,
|
||||
@@ -28,12 +29,13 @@ import {
|
||||
DialogFooter,
|
||||
} from '@/components/ui/dialog';
|
||||
import { Input } from '@/components/ui/input';
|
||||
import { useState, useRef, useCallback } from 'react';
|
||||
import { useState, useRef, useCallback, useEffect } from 'react';
|
||||
import { httpClient } from '@/app/infra/http/HttpClient';
|
||||
import { toast } from 'sonner';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PluginV4 } from '@/app/infra/entities/plugin';
|
||||
import { systemInfo } from '@/app/infra/http/HttpClient';
|
||||
import { ApiRespPluginSystemStatus } from '@/app/infra/entities/api';
|
||||
|
||||
enum PluginInstallStatus {
|
||||
WAIT_INPUT = 'wait_input',
|
||||
@@ -54,9 +56,29 @@ export default function PluginConfigPage() {
|
||||
const [installError, setInstallError] = useState<string | null>(null);
|
||||
const [githubURL, setGithubURL] = useState('');
|
||||
const [isDragOver, setIsDragOver] = useState(false);
|
||||
const [pluginSystemStatus, setPluginSystemStatus] =
|
||||
useState<ApiRespPluginSystemStatus | null>(null);
|
||||
const [statusLoading, setStatusLoading] = useState(true);
|
||||
const pluginInstalledRef = useRef<PluginInstalledComponentRef>(null);
|
||||
const fileInputRef = useRef<HTMLInputElement>(null);
|
||||
|
||||
useEffect(() => {
|
||||
const fetchPluginSystemStatus = async () => {
|
||||
try {
|
||||
setStatusLoading(true);
|
||||
const status = await httpClient.getPluginSystemStatus();
|
||||
setPluginSystemStatus(status);
|
||||
} catch (error) {
|
||||
console.error('Failed to fetch plugin system status:', error);
|
||||
toast.error(t('plugins.failedToGetStatus'));
|
||||
} finally {
|
||||
setStatusLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
fetchPluginSystemStatus();
|
||||
}, [t]);
|
||||
|
||||
function watchTask(taskId: number) {
|
||||
let alreadySuccess = false;
|
||||
console.log('taskId:', taskId);
|
||||
@@ -140,6 +162,11 @@ export default function PluginConfigPage() {
|
||||
|
||||
const uploadPluginFile = useCallback(
|
||||
async (file: File) => {
|
||||
if (!pluginSystemStatus?.is_enable || !pluginSystemStatus?.is_connected) {
|
||||
toast.error(t('plugins.pluginSystemNotReady'));
|
||||
return;
|
||||
}
|
||||
|
||||
if (!validateFileType(file)) {
|
||||
toast.error(t('plugins.unsupportedFileType'));
|
||||
return;
|
||||
@@ -150,7 +177,7 @@ export default function PluginConfigPage() {
|
||||
setInstallError(null);
|
||||
installPlugin('local', { file });
|
||||
},
|
||||
[t],
|
||||
[t, pluginSystemStatus],
|
||||
);
|
||||
|
||||
const handleFileSelect = useCallback(() => {
|
||||
@@ -171,10 +198,18 @@ export default function PluginConfigPage() {
|
||||
[uploadPluginFile],
|
||||
);
|
||||
|
||||
const handleDragOver = useCallback((event: React.DragEvent) => {
|
||||
event.preventDefault();
|
||||
setIsDragOver(true);
|
||||
}, []);
|
||||
const isPluginSystemReady =
|
||||
pluginSystemStatus?.is_enable && pluginSystemStatus?.is_connected;
|
||||
|
||||
const handleDragOver = useCallback(
|
||||
(event: React.DragEvent) => {
|
||||
event.preventDefault();
|
||||
if (isPluginSystemReady) {
|
||||
setIsDragOver(true);
|
||||
}
|
||||
},
|
||||
[isPluginSystemReady],
|
||||
);
|
||||
|
||||
const handleDragLeave = useCallback((event: React.DragEvent) => {
|
||||
event.preventDefault();
|
||||
@@ -186,14 +221,76 @@ export default function PluginConfigPage() {
|
||||
event.preventDefault();
|
||||
setIsDragOver(false);
|
||||
|
||||
if (!isPluginSystemReady) {
|
||||
toast.error(t('plugins.pluginSystemNotReady'));
|
||||
return;
|
||||
}
|
||||
|
||||
const files = Array.from(event.dataTransfer.files);
|
||||
if (files.length > 0) {
|
||||
uploadPluginFile(files[0]);
|
||||
}
|
||||
},
|
||||
[uploadPluginFile],
|
||||
[uploadPluginFile, isPluginSystemReady, t],
|
||||
);
|
||||
|
||||
// 插件系统未启用的状态显示
|
||||
const renderPluginDisabledState = () => (
|
||||
<div className="flex flex-col items-center justify-center h-[60vh] text-center pt-[10vh]">
|
||||
<Power className="w-16 h-16 text-gray-400 mb-4" />
|
||||
<h2 className="text-2xl font-semibold text-gray-700 dark:text-gray-300 mb-2">
|
||||
{t('plugins.systemDisabled')}
|
||||
</h2>
|
||||
<p className="text-gray-500 dark:text-gray-400 max-w-md">
|
||||
{t('plugins.systemDisabledDesc')}
|
||||
</p>
|
||||
</div>
|
||||
);
|
||||
|
||||
// 插件系统连接异常的状态显示
|
||||
const renderPluginConnectionErrorState = () => (
|
||||
<div className="flex flex-col items-center justify-center h-[60vh] text-center pt-[10vh]">
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
viewBox="0 0 24 24"
|
||||
width="72"
|
||||
height="72"
|
||||
fill="#BDBDBD"
|
||||
>
|
||||
<path d="M17.657 14.8284L16.2428 13.4142L17.657 12C19.2191 10.4379 19.2191 7.90526 17.657 6.34316C16.0949 4.78106 13.5622 4.78106 12.0001 6.34316L10.5859 7.75737L9.17171 6.34316L10.5859 4.92895C12.9291 2.5858 16.7281 2.5858 19.0712 4.92895C21.4143 7.27209 21.4143 11.0711 19.0712 13.4142L17.657 14.8284ZM14.8286 17.6569L13.4143 19.0711C11.0712 21.4142 7.27221 21.4142 4.92907 19.0711C2.58592 16.7279 2.58592 12.9289 4.92907 10.5858L6.34328 9.17159L7.75749 10.5858L6.34328 12C4.78118 13.5621 4.78118 16.0948 6.34328 17.6569C7.90538 19.219 10.438 19.219 12.0001 17.6569L13.4143 16.2427L14.8286 17.6569ZM14.8286 7.75737L16.2428 9.17159L9.17171 16.2427L7.75749 14.8284L14.8286 7.75737ZM5.77539 2.29291L7.70724 1.77527L8.74252 5.63897L6.81067 6.15661L5.77539 2.29291ZM15.2578 18.3611L17.1896 17.8434L18.2249 21.7071L16.293 22.2248L15.2578 18.3611ZM2.29303 5.77527L6.15673 6.81054L5.63909 8.7424L1.77539 7.70712L2.29303 5.77527ZM18.3612 15.2576L22.2249 16.2929L21.7072 18.2248L17.8435 17.1895L18.3612 15.2576Z"></path>
|
||||
</svg>
|
||||
|
||||
<h2 className="text-2xl font-semibold text-gray-700 dark:text-gray-300 mb-2">
|
||||
{t('plugins.connectionError')}
|
||||
</h2>
|
||||
<p className="text-gray-500 dark:text-gray-400 max-w-md mb-4">
|
||||
{t('plugins.connectionErrorDesc')}
|
||||
</p>
|
||||
</div>
|
||||
);
|
||||
|
||||
// 加载状态显示
|
||||
const renderLoadingState = () => (
|
||||
<div className="flex flex-col items-center justify-center h-[60vh] pt-[10vh]">
|
||||
<p className="text-gray-500 dark:text-gray-400">
|
||||
{t('plugins.loadingStatus')}
|
||||
</p>
|
||||
</div>
|
||||
);
|
||||
|
||||
// 根据状态返回不同的内容
|
||||
if (statusLoading) {
|
||||
return renderLoadingState();
|
||||
}
|
||||
|
||||
if (!pluginSystemStatus?.is_enable) {
|
||||
return renderPluginDisabledState();
|
||||
}
|
||||
|
||||
if (!pluginSystemStatus?.is_connected) {
|
||||
return renderPluginConnectionErrorState();
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
className={`${styles.pageContainer} ${isDragOver ? 'bg-blue-50' : ''}`}
|
||||
|
||||
@@ -215,6 +215,12 @@ export interface ApiRespSystemInfo {
|
||||
enable_marketplace: boolean;
|
||||
}
|
||||
|
||||
export interface ApiRespPluginSystemStatus {
|
||||
is_enable: boolean;
|
||||
is_connected: boolean;
|
||||
plugin_connector_error: string;
|
||||
}
|
||||
|
||||
export interface ApiRespAsyncTasks {
|
||||
tasks: AsyncTask[];
|
||||
}
|
||||
|
||||
@@ -32,6 +32,7 @@ import {
|
||||
ApiRespProviderEmbeddingModels,
|
||||
ApiRespProviderEmbeddingModel,
|
||||
EmbeddingModel,
|
||||
ApiRespPluginSystemStatus,
|
||||
} from '@/app/infra/entities/api';
|
||||
import { GetBotLogsRequest } from '@/app/infra/http/requestParam/bots/GetBotLogsRequest';
|
||||
import { GetBotLogsResponse } from '@/app/infra/http/requestParam/bots/GetBotLogsResponse';
|
||||
@@ -500,6 +501,10 @@ export class BackendClient extends BaseHttpClient {
|
||||
return this.get(`/api/v1/system/tasks/${id}`);
|
||||
}
|
||||
|
||||
public getPluginSystemStatus(): Promise<ApiRespPluginSystemStatus> {
|
||||
return this.get('/api/v1/system/status/plugin-system');
|
||||
}
|
||||
|
||||
// ============ User API ============
|
||||
public checkIfInited(): Promise<{ initialized: boolean }> {
|
||||
return this.get('/api/v1/user/init');
|
||||
|
||||
394
web/src/app/infra/websocket/PipelineWebSocketClient.ts
Normal file
394
web/src/app/infra/websocket/PipelineWebSocketClient.ts
Normal file
@@ -0,0 +1,394 @@
|
||||
/**
|
||||
* Pipeline WebSocket Client
|
||||
*
|
||||
* Provides real-time bidirectional communication for pipeline debugging.
|
||||
* Supports person and group session isolation.
|
||||
*/
|
||||
|
||||
import { Message } from '@/app/infra/entities/message';
|
||||
|
||||
export type SessionType = 'person' | 'group';
|
||||
|
||||
export interface WebSocketEventData {
|
||||
// Connected event
|
||||
connected?: {
|
||||
connection_id: string;
|
||||
session_type: SessionType;
|
||||
pipeline_uuid: string;
|
||||
};
|
||||
|
||||
// History event
|
||||
history?: {
|
||||
messages: Message[];
|
||||
has_more: boolean;
|
||||
};
|
||||
|
||||
// Message sent confirmation
|
||||
message_sent?: {
|
||||
client_message_id: string;
|
||||
server_message_id: number;
|
||||
timestamp: string;
|
||||
};
|
||||
|
||||
// Message start
|
||||
message_start?: {
|
||||
message_id: number;
|
||||
role: 'assistant';
|
||||
timestamp: string;
|
||||
reply_to: number;
|
||||
};
|
||||
|
||||
// Message chunk
|
||||
message_chunk?: {
|
||||
message_id: number;
|
||||
content: string;
|
||||
message_chain: object[];
|
||||
timestamp: string;
|
||||
};
|
||||
|
||||
// Message complete
|
||||
message_complete?: {
|
||||
message_id: number;
|
||||
final_content: string;
|
||||
message_chain: object[];
|
||||
timestamp: string;
|
||||
};
|
||||
|
||||
// Message error
|
||||
message_error?: {
|
||||
message_id: number;
|
||||
error: string;
|
||||
error_code?: string;
|
||||
};
|
||||
|
||||
// Interrupted
|
||||
interrupted?: {
|
||||
message_id: number;
|
||||
partial_content: string;
|
||||
};
|
||||
|
||||
// Plugin message
|
||||
plugin_message?: {
|
||||
message_id: number;
|
||||
role: 'assistant';
|
||||
content: string;
|
||||
message_chain: object[];
|
||||
timestamp: string;
|
||||
source: 'plugin';
|
||||
};
|
||||
|
||||
// Error
|
||||
error?: {
|
||||
error: string;
|
||||
error_code: string;
|
||||
details?: object;
|
||||
};
|
||||
|
||||
// Pong
|
||||
pong?: {
|
||||
timestamp: number;
|
||||
};
|
||||
}
|
||||
|
||||
export class PipelineWebSocketClient {
|
||||
private ws: WebSocket | null = null;
|
||||
private pipelineId: string;
|
||||
private sessionType: SessionType;
|
||||
private reconnectAttempts = 0;
|
||||
private maxReconnectAttempts = 5;
|
||||
private pingInterval: NodeJS.Timeout | null = null;
|
||||
private reconnectTimeout: NodeJS.Timeout | null = null;
|
||||
private isManualDisconnect = false;
|
||||
|
||||
// Event callbacks
|
||||
public onConnected?: (data: WebSocketEventData['connected']) => void;
|
||||
public onHistory?: (data: WebSocketEventData['history']) => void;
|
||||
public onMessageSent?: (data: WebSocketEventData['message_sent']) => void;
|
||||
public onMessageStart?: (data: WebSocketEventData['message_start']) => void;
|
||||
public onMessageChunk?: (data: WebSocketEventData['message_chunk']) => void;
|
||||
public onMessageComplete?: (
|
||||
data: WebSocketEventData['message_complete'],
|
||||
) => void;
|
||||
public onMessageError?: (data: WebSocketEventData['message_error']) => void;
|
||||
public onInterrupted?: (data: WebSocketEventData['interrupted']) => void;
|
||||
public onPluginMessage?: (data: WebSocketEventData['plugin_message']) => void;
|
||||
public onError?: (data: WebSocketEventData['error']) => void;
|
||||
public onDisconnected?: () => void;
|
||||
|
||||
constructor(pipelineId: string, sessionType: SessionType) {
|
||||
this.pipelineId = pipelineId;
|
||||
this.sessionType = sessionType;
|
||||
}
|
||||
|
||||
/**
|
||||
* Connect to WebSocket server
|
||||
*/
|
||||
connect(token: string): Promise<void> {
|
||||
return new Promise((resolve, reject) => {
|
||||
this.isManualDisconnect = false;
|
||||
|
||||
const wsUrl = this.buildWebSocketUrl();
|
||||
console.log(`[WebSocket] Connecting to ${wsUrl}...`);
|
||||
|
||||
try {
|
||||
this.ws = new WebSocket(wsUrl);
|
||||
} catch (error) {
|
||||
console.error('[WebSocket] Failed to create WebSocket:', error);
|
||||
reject(error);
|
||||
return;
|
||||
}
|
||||
|
||||
this.ws.onopen = () => {
|
||||
console.log('[WebSocket] Connection opened');
|
||||
|
||||
// Send connect event with session type and token
|
||||
this.send('connect', {
|
||||
pipeline_uuid: this.pipelineId,
|
||||
session_type: this.sessionType,
|
||||
token,
|
||||
});
|
||||
|
||||
// Start ping interval
|
||||
this.startPing();
|
||||
|
||||
// Reset reconnect attempts on successful connection
|
||||
this.reconnectAttempts = 0;
|
||||
|
||||
resolve();
|
||||
};
|
||||
|
||||
this.ws.onmessage = (event) => {
|
||||
this.handleMessage(event);
|
||||
};
|
||||
|
||||
this.ws.onerror = (error) => {
|
||||
console.error('[WebSocket] Error:', error);
|
||||
reject(error);
|
||||
};
|
||||
|
||||
this.ws.onclose = (event) => {
|
||||
console.log(
|
||||
`[WebSocket] Connection closed: code=${event.code}, reason=${event.reason}`,
|
||||
);
|
||||
this.handleDisconnect();
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle incoming WebSocket message
|
||||
*/
|
||||
private handleMessage(event: MessageEvent) {
|
||||
try {
|
||||
const message = JSON.parse(event.data);
|
||||
const { type, data } = message;
|
||||
|
||||
console.log(`[WebSocket] Received: ${type}`, data);
|
||||
|
||||
switch (type) {
|
||||
case 'connected':
|
||||
this.onConnected?.(data);
|
||||
break;
|
||||
case 'history':
|
||||
this.onHistory?.(data);
|
||||
break;
|
||||
case 'message_sent':
|
||||
this.onMessageSent?.(data);
|
||||
break;
|
||||
case 'message_start':
|
||||
this.onMessageStart?.(data);
|
||||
break;
|
||||
case 'message_chunk':
|
||||
this.onMessageChunk?.(data);
|
||||
break;
|
||||
case 'message_complete':
|
||||
this.onMessageComplete?.(data);
|
||||
break;
|
||||
case 'message_error':
|
||||
this.onMessageError?.(data);
|
||||
break;
|
||||
case 'interrupted':
|
||||
this.onInterrupted?.(data);
|
||||
break;
|
||||
case 'plugin_message':
|
||||
this.onPluginMessage?.(data);
|
||||
break;
|
||||
case 'error':
|
||||
this.onError?.(data);
|
||||
break;
|
||||
case 'pong':
|
||||
// Heartbeat response, no action needed
|
||||
break;
|
||||
default:
|
||||
console.warn(`[WebSocket] Unknown message type: ${type}`);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('[WebSocket] Failed to parse message:', error);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Send message to server
|
||||
*/
|
||||
sendMessage(messageChain: object[]): string {
|
||||
const clientMessageId = this.generateMessageId();
|
||||
this.send('send_message', {
|
||||
message_chain: messageChain,
|
||||
client_message_id: clientMessageId,
|
||||
});
|
||||
return clientMessageId;
|
||||
}
|
||||
|
||||
/**
|
||||
* Load history messages
|
||||
*/
|
||||
loadHistory(beforeMessageId?: number, limit?: number) {
|
||||
this.send('load_history', {
|
||||
before_message_id: beforeMessageId,
|
||||
limit,
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Interrupt streaming message
|
||||
*/
|
||||
interrupt(messageId: number) {
|
||||
this.send('interrupt', { message_id: messageId });
|
||||
}
|
||||
|
||||
/**
|
||||
* Send event to server
|
||||
*/
|
||||
private send(type: string, data: object) {
|
||||
if (this.ws?.readyState === WebSocket.OPEN) {
|
||||
const message = JSON.stringify({ type, data });
|
||||
console.log(`[WebSocket] Sending: ${type}`, data);
|
||||
this.ws.send(message);
|
||||
} else {
|
||||
console.warn(
|
||||
`[WebSocket] Cannot send message, connection not open (state: ${this.ws?.readyState})`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Start ping interval (heartbeat)
|
||||
*/
|
||||
private startPing() {
|
||||
this.stopPing();
|
||||
this.pingInterval = setInterval(() => {
|
||||
this.send('ping', { timestamp: Date.now() });
|
||||
}, 30000); // Ping every 30 seconds
|
||||
}
|
||||
|
||||
/**
|
||||
* Stop ping interval
|
||||
*/
|
||||
private stopPing() {
|
||||
if (this.pingInterval) {
|
||||
clearInterval(this.pingInterval);
|
||||
this.pingInterval = null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle disconnection
|
||||
*/
|
||||
private handleDisconnect() {
|
||||
this.stopPing();
|
||||
this.onDisconnected?.();
|
||||
|
||||
// Auto reconnect if not manual disconnect
|
||||
if (
|
||||
!this.isManualDisconnect &&
|
||||
this.reconnectAttempts < this.maxReconnectAttempts
|
||||
) {
|
||||
const delay = Math.min(2000 * Math.pow(2, this.reconnectAttempts), 30000);
|
||||
console.log(
|
||||
`[WebSocket] Reconnecting in ${delay}ms (attempt ${this.reconnectAttempts + 1}/${this.maxReconnectAttempts})`,
|
||||
);
|
||||
|
||||
this.reconnectTimeout = setTimeout(() => {
|
||||
this.reconnectAttempts++;
|
||||
// Note: Need to get token again, should be handled by caller
|
||||
console.warn(
|
||||
'[WebSocket] Auto-reconnect requires token, please reconnect manually',
|
||||
);
|
||||
}, delay);
|
||||
} else if (this.reconnectAttempts >= this.maxReconnectAttempts) {
|
||||
console.error(
|
||||
'[WebSocket] Max reconnect attempts reached, giving up',
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Disconnect from server
|
||||
*/
|
||||
disconnect() {
|
||||
this.isManualDisconnect = true;
|
||||
this.stopPing();
|
||||
|
||||
if (this.reconnectTimeout) {
|
||||
clearTimeout(this.reconnectTimeout);
|
||||
this.reconnectTimeout = null;
|
||||
}
|
||||
|
||||
if (this.ws) {
|
||||
this.ws.close(1000, 'Client disconnect');
|
||||
this.ws = null;
|
||||
}
|
||||
|
||||
console.log('[WebSocket] Disconnected');
|
||||
}
|
||||
|
||||
/**
|
||||
* Build WebSocket URL
|
||||
*/
|
||||
private buildWebSocketUrl(): string {
|
||||
// Get current base URL
|
||||
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
|
||||
const host = window.location.host;
|
||||
|
||||
return `${protocol}//${host}/api/v1/pipelines/${this.pipelineId}/chat/ws`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate unique client message ID
|
||||
*/
|
||||
private generateMessageId(): string {
|
||||
return `${Date.now()}-${Math.random().toString(36).substr(2, 9)}`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get connection state
|
||||
*/
|
||||
getState():
|
||||
| 'CONNECTING'
|
||||
| 'OPEN'
|
||||
| 'CLOSING'
|
||||
| 'CLOSED'
|
||||
| 'DISCONNECTED' {
|
||||
if (!this.ws) return 'DISCONNECTED';
|
||||
|
||||
switch (this.ws.readyState) {
|
||||
case WebSocket.CONNECTING:
|
||||
return 'CONNECTING';
|
||||
case WebSocket.OPEN:
|
||||
return 'OPEN';
|
||||
case WebSocket.CLOSING:
|
||||
return 'CLOSING';
|
||||
case WebSocket.CLOSED:
|
||||
return 'CLOSED';
|
||||
default:
|
||||
return 'DISCONNECTED';
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if connected
|
||||
*/
|
||||
isConnected(): boolean {
|
||||
return this.ws?.readyState === WebSocket.OPEN;
|
||||
}
|
||||
}
|
||||
@@ -182,6 +182,17 @@ const enUS = {
|
||||
pluginSortSuccess: 'Plugin sort successful',
|
||||
pluginSortError: 'Plugin sort failed: ',
|
||||
pluginNoConfig: 'The plugin has no configuration items.',
|
||||
systemDisabled: 'Plugin System Disabled',
|
||||
systemDisabledDesc:
|
||||
'Plugin system is not enabled, please modify the configuration according to the documentation',
|
||||
connectionError: 'Plugin System Connection Error',
|
||||
connectionErrorDesc:
|
||||
'Please check the plugin system configuration or contact the administrator.',
|
||||
errorDetails: 'Error Details',
|
||||
loadingStatus: 'Checking plugin system status...',
|
||||
failedToGetStatus: 'Failed to get plugin system status',
|
||||
pluginSystemNotReady:
|
||||
'Plugin system is not ready, cannot perform this operation',
|
||||
deleting: 'Deleting...',
|
||||
deletePlugin: 'Delete Plugin',
|
||||
cancel: 'Cancel',
|
||||
@@ -195,9 +206,11 @@ const enUS = {
|
||||
deleteConfirm: 'Delete Confirmation',
|
||||
deleteSuccess: 'Delete successful',
|
||||
modifyFailed: 'Modify failed: ',
|
||||
eventCount: 'Events: {{count}}',
|
||||
toolCount: 'Tools: {{count}}',
|
||||
starCount: 'Stars: {{count}}',
|
||||
componentName: {
|
||||
Tool: 'Tool',
|
||||
EventListener: 'Event Listener',
|
||||
Command: 'Command',
|
||||
},
|
||||
uploadLocal: 'Upload Local',
|
||||
debugging: 'Debugging',
|
||||
uploadLocalPlugin: 'Upload Local Plugin',
|
||||
@@ -287,6 +300,7 @@ const enUS = {
|
||||
defaultBadge: 'Default',
|
||||
sortBy: 'Sort by',
|
||||
newestCreated: 'Newest Created',
|
||||
earliestCreated: 'Earliest Created',
|
||||
recentlyEdited: 'Recently Edited',
|
||||
earliestEdited: 'Earliest Edited',
|
||||
basicInfo: 'Basic',
|
||||
|
||||
@@ -183,6 +183,17 @@ const jaJP = {
|
||||
pluginSortSuccess: 'プラグインの並び替えに成功しました',
|
||||
pluginSortError: 'プラグインの並び替えに失敗しました:',
|
||||
pluginNoConfig: 'プラグインに設定項目がありません。',
|
||||
systemDisabled: 'プラグインシステムが無効になっています',
|
||||
systemDisabledDesc:
|
||||
'プラグインシステムが無効になっています。プラグインシステムを有効にするか、ドキュメントに従って設定を変更してください',
|
||||
connectionError: 'プラグインシステム接続エラー',
|
||||
connectionErrorDesc:
|
||||
'プラグインシステム設定を確認するか、管理者に連絡してください',
|
||||
errorDetails: 'エラー詳細',
|
||||
loadingStatus: 'プラグインシステム状態を確認中...',
|
||||
failedToGetStatus: 'プラグインシステム状態の取得に失敗しました',
|
||||
pluginSystemNotReady:
|
||||
'プラグインシステムが準備されていません。この操作を実行できません',
|
||||
deleting: '削除中...',
|
||||
deletePlugin: 'プラグインを削除',
|
||||
cancel: 'キャンセル',
|
||||
@@ -196,9 +207,11 @@ const jaJP = {
|
||||
deleteConfirm: '削除の確認',
|
||||
deleteSuccess: '削除に成功しました',
|
||||
modifyFailed: '変更に失敗しました:',
|
||||
eventCount: 'イベント:{{count}}',
|
||||
toolCount: 'ツール:{{count}}',
|
||||
starCount: 'スター:{{count}}',
|
||||
componentName: {
|
||||
Tool: 'ツール',
|
||||
EventListener: 'イベント監視器',
|
||||
Command: 'コマンド',
|
||||
},
|
||||
uploadLocal: 'ローカルアップロード',
|
||||
debugging: 'デバッグ中',
|
||||
uploadLocalPlugin: 'ローカルプラグインのアップロード',
|
||||
@@ -289,6 +302,7 @@ const jaJP = {
|
||||
defaultBadge: 'デフォルト',
|
||||
sortBy: '並び順',
|
||||
newestCreated: '最新作成',
|
||||
earliestCreated: '最古作成',
|
||||
recentlyEdited: '最近編集',
|
||||
earliestEdited: '最古編集',
|
||||
basicInfo: '基本情報',
|
||||
|
||||
@@ -178,6 +178,14 @@ const zhHans = {
|
||||
pluginSortSuccess: '插件排序成功',
|
||||
pluginSortError: '插件排序失败:',
|
||||
pluginNoConfig: '插件没有配置项。',
|
||||
systemDisabled: '插件系统未启用',
|
||||
systemDisabledDesc: '尚未启用插件系统,请根据文档修改配置',
|
||||
connectionError: '插件系统连接异常',
|
||||
connectionErrorDesc: '请检查插件系统配置或联系管理员',
|
||||
errorDetails: '错误详情',
|
||||
loadingStatus: '正在检查插件系统状态...',
|
||||
failedToGetStatus: '获取插件系统状态失败',
|
||||
pluginSystemNotReady: '插件系统未就绪,无法执行此操作',
|
||||
deleting: '删除中...',
|
||||
deletePlugin: '删除插件',
|
||||
cancel: '取消',
|
||||
@@ -191,9 +199,11 @@ const zhHans = {
|
||||
deleteConfirm: '删除确认',
|
||||
deleteSuccess: '删除成功',
|
||||
modifyFailed: '修改失败:',
|
||||
eventCount: '事件:{{count}}',
|
||||
toolCount: '工具:{{count}}',
|
||||
starCount: '星标:{{count}}',
|
||||
componentName: {
|
||||
Tool: '工具',
|
||||
EventListener: '事件监听器',
|
||||
Command: '命令',
|
||||
},
|
||||
uploadLocal: '本地上传',
|
||||
debugging: '调试中',
|
||||
uploadLocalPlugin: '上传本地插件',
|
||||
@@ -277,6 +287,7 @@ const zhHans = {
|
||||
defaultBadge: '默认',
|
||||
sortBy: '排序方式',
|
||||
newestCreated: '最新创建',
|
||||
earliestCreated: '最早创建',
|
||||
recentlyEdited: '最近编辑',
|
||||
earliestEdited: '最早编辑',
|
||||
basicInfo: '基础信息',
|
||||
|
||||
@@ -178,6 +178,14 @@ const zhHant = {
|
||||
pluginSortSuccess: '外掛排序成功',
|
||||
pluginSortError: '外掛排序失敗:',
|
||||
pluginNoConfig: '外掛沒有設定項目。',
|
||||
systemDisabled: '外掛系統未啟用',
|
||||
systemDisabledDesc: '尚未啟用外掛系統,請根據文檔修改配置',
|
||||
connectionError: '外掛系統連接異常',
|
||||
connectionErrorDesc: '請檢查外掛系統配置或聯絡管理員',
|
||||
errorDetails: '錯誤詳情',
|
||||
loadingStatus: '正在檢查外掛系統狀態...',
|
||||
failedToGetStatus: '取得外掛系統狀態失敗',
|
||||
pluginSystemNotReady: '外掛系統未就緒,無法執行此操作',
|
||||
deleting: '刪除中...',
|
||||
deletePlugin: '刪除外掛',
|
||||
cancel: '取消',
|
||||
@@ -189,9 +197,11 @@ const zhHant = {
|
||||
close: '關閉',
|
||||
deleteConfirm: '刪除確認',
|
||||
modifyFailed: '修改失敗:',
|
||||
eventCount: '事件:{{count}}',
|
||||
toolCount: '工具:{{count}}',
|
||||
starCount: '星標:{{count}}',
|
||||
componentName: {
|
||||
Tool: '工具',
|
||||
EventListener: '事件監聽器',
|
||||
Command: '命令',
|
||||
},
|
||||
uploadLocal: '本地上傳',
|
||||
debugging: '調試中',
|
||||
uploadLocalPlugin: '上傳本地插件',
|
||||
@@ -275,6 +285,7 @@ const zhHant = {
|
||||
defaultBadge: '預設',
|
||||
sortBy: '排序方式',
|
||||
newestCreated: '最新建立',
|
||||
earliestCreated: '最早建立',
|
||||
recentlyEdited: '最近編輯',
|
||||
earliestEdited: '最早編輯',
|
||||
basicInfo: '基本資訊',
|
||||
|
||||
Reference in New Issue
Block a user