refactor: switch pipeline_cfg related fields to new pipeline config

This commit is contained in:
Junyan Qin
2025-04-03 17:57:51 +08:00
parent 472d472bc1
commit 7f66efcdd5
13 changed files with 62 additions and 38 deletions

View File

@@ -24,9 +24,9 @@ class BanSessionCheckStage(stage.PipelineStage):
found = False
mode = self.ap.pipeline_cfg.data['access-control']['mode']
mode = query.pipeline_config['trigger']['access-control']['mode']
sess_list = self.ap.pipeline_cfg.data['access-control'][mode]
sess_list = query.pipeline_config['trigger']['access-control'][mode]
if (query.launcher_type.value == 'group' and 'group_*' in sess_list) \
or (query.launcher_type.value == 'person' and 'person_*' in sess_list):

View File

@@ -41,11 +41,12 @@ class ContentFilterStage(stage.PipelineStage):
"content-ignore",
]
if self.ap.pipeline_cfg.data['check-sensitive-words']:
if pipeline_config['safety']['content-filter']['check-sensitive-words']:
filters_required.append("ban-word-filter")
if self.ap.pipeline_cfg.data['baidu-cloud-examine']['enable']:
filters_required.append("baidu-cloud-examine")
# TODO revert it
# if self.ap.pipeline_cfg.data['baidu-cloud-examine']['enable']:
# filters_required.append("baidu-cloud-examine")
for filter in filter_model.preregistered_filters:
if filter.name in filters_required:
@@ -65,7 +66,7 @@ class ContentFilterStage(stage.PipelineStage):
只要有一个不通过就不放行,只放行 PASS 的消息
"""
if not self.ap.pipeline_cfg.data['income-msg-check']:
if query.pipeline_config['safety']['content-filter']['scope'] == 'output-msg':
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
@@ -73,7 +74,7 @@ class ContentFilterStage(stage.PipelineStage):
else:
for filter in self.filter_chain:
if filter_entities.EnableStage.PRE in filter.enable_stages:
result = await filter.process(message)
result = await filter.process(query, message)
if result.level in [
filter_entities.ResultLevel.BLOCK,
@@ -105,7 +106,7 @@ class ContentFilterStage(stage.PipelineStage):
"""请求llm后处理响应
只要是 PASS 或者 MASKED 的就通过此 filter将其 replacement 设置为message进入下一个 filter
"""
if message is None:
if query.pipeline_config['safety']['content-filter']['scope'] == 'income-msg':
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
@@ -114,7 +115,7 @@ class ContentFilterStage(stage.PipelineStage):
message = message.strip()
for filter in self.filter_chain:
if filter_entities.EnableStage.POST in filter.enable_stages:
result = await filter.process(message)
result = await filter.process(query, message)
if result.level == filter_entities.ResultLevel.BLOCK:
return entities.StageProcessResult(

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import abc
import typing
from ...core import app
from ...core import app, entities as core_entities
from . import entities
from ...provider import entities as llm_entities
@@ -64,7 +64,7 @@ class ContentFilter(metaclass=abc.ABCMeta):
pass
@abc.abstractmethod
async def process(self, message: str=None, image_url=None) -> entities.FilterResult:
async def process(self, query: core_entities.Query, message: str=None, image_url=None) -> entities.FilterResult:
"""处理消息
分为前后阶段,具体取决于 enable_stages 的值。

View File

@@ -4,6 +4,7 @@ import aiohttp
from .. import entities
from .. import filter as filter_model
from ....core import entities as core_entities
BAIDU_EXAMINE_URL = "https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token={}"
@@ -26,7 +27,7 @@ class BaiduCloudExamine(filter_model.ContentFilter):
) as resp:
return (await resp.json())['access_token']
async def process(self, message: str) -> entities.FilterResult:
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
async with aiohttp.ClientSession() as session:
async with session.post(

View File

@@ -3,7 +3,7 @@ import re
from .. import filter as filter_model
from .. import entities
from ....config import manager as cfg_mgr
from ....core import entities as core_entities
@filter_model.filter_class("ban-word-filter")
@@ -13,7 +13,7 @@ class BanWordFilter(filter_model.ContentFilter):
async def initialize(self):
pass
async def process(self, message: str) -> entities.FilterResult:
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
found = False
for word in self.ap.sensitive_meta.data['words']:

View File

@@ -3,6 +3,7 @@ import re
from .. import entities
from .. import filter as filter_model
from ....core import entities as core_entities
@filter_model.filter_class("content-ignore")
@@ -15,9 +16,9 @@ class ContentIgnore(filter_model.ContentFilter):
entities.EnableStage.PRE,
]
async def process(self, message: str) -> entities.FilterResult:
if 'prefix' in self.ap.pipeline_cfg.data['ignore-rules']:
for rule in self.ap.pipeline_cfg.data['ignore-rules']['prefix']:
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
if 'prefix' in query.pipeline_config['trigger']['ignore-rules']:
for rule in query.pipeline_config['trigger']['ignore-rules']['prefix']:
if message.startswith(rule):
return entities.FilterResult(
level=entities.ResultLevel.BLOCK,
@@ -26,8 +27,8 @@ class ContentIgnore(filter_model.ContentFilter):
console_notice='根据 ignore_rules 中的 prefix 规则,忽略消息'
)
if 'regexp' in self.ap.pipeline_cfg.data['ignore-rules']:
for rule in self.ap.pipeline_cfg.data['ignore-rules']['regexp']:
if 'regexp' in query.pipeline_config['trigger']['ignore-rules']:
for rule in query.pipeline_config['trigger']['ignore-rules']['regexp']:
if re.search(rule, message):
return entities.FilterResult(
level=entities.ResultLevel.BLOCK,

View File

@@ -15,7 +15,7 @@ class ConversationMessageTruncator(stage.PipelineStage):
trun: truncator.Truncator
async def initialize(self, pipeline_config: dict):
use_method = self.ap.pipeline_cfg.data['msg-truncate']['method']
use_method = "round"
for trun in truncator.preregistered_truncators:
if trun.name == use_method:

View File

@@ -12,7 +12,7 @@ class RoundTruncator(truncator.Truncator):
async def truncate(self, query: core_entities.Query) -> core_entities.Query:
"""截断
"""
max_round = self.ap.pipeline_cfg.data['msg-truncate']['round']['max-round']
max_round = query.pipeline_config['ai']['local-agent']['max-round']
temp_messages = []

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import abc
import typing
from ...core import app
from ...core import app, entities as core_entities
preregistered_algos: list[typing.Type[ReteLimitAlgo]] = []
@@ -31,7 +31,7 @@ class ReteLimitAlgo(metaclass=abc.ABCMeta):
pass
@abc.abstractmethod
async def require_access(self, launcher_type: str, launcher_id: typing.Union[int, str]) -> bool:
async def require_access(self, query: core_entities.Query, launcher_type: str, launcher_id: typing.Union[int, str]) -> bool:
"""进入处理流程
这个方法对等待是友好的,意味着算法可以实现在这里等待一段时间以控制速率。
@@ -46,7 +46,7 @@ class ReteLimitAlgo(metaclass=abc.ABCMeta):
raise NotImplementedError
@abc.abstractmethod
async def release_access(self, launcher_type: str, launcher_id: typing.Union[int, str]):
async def release_access(self, query: core_entities.Query, launcher_type: str, launcher_id: typing.Union[int, str]):
"""退出处理流程
Args:

View File

@@ -3,6 +3,7 @@ import asyncio
import time
import typing
from .. import algo
from ....core import entities as core_entities
# 固定窗口算法
class SessionContainer:
@@ -30,7 +31,7 @@ class FixedWindowAlgo(algo.ReteLimitAlgo):
self.containers_lock = asyncio.Lock()
self.containers = {}
async def require_access(self, launcher_type: str, launcher_id: typing.Union[int, str]) -> bool:
async def require_access(self, query: core_entities.Query, launcher_type: str, launcher_id: typing.Union[int, str]) -> bool:
# 加锁,找容器
container: SessionContainer = None
@@ -47,12 +48,13 @@ class FixedWindowAlgo(algo.ReteLimitAlgo):
async with container.wait_lock:
# 获取窗口大小和限制
window_size = self.ap.pipeline_cfg.data['rate-limit']['fixwin']['default']['window-size']
limitation = self.ap.pipeline_cfg.data['rate-limit']['fixwin']['default']['limit']
window_size = query.pipeline_config['safety']['rate-limit']['window-length']
limitation = query.pipeline_config['safety']['rate-limit']['limitation']
if session_name in self.ap.pipeline_cfg.data['rate-limit']['fixwin']:
window_size = self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name]['window-size']
limitation = self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name]['limit']
# TODO revert it
# if session_name in self.ap.pipeline_cfg.data['rate-limit']['fixwin']:
# window_size = self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name]['window-size']
# limitation = self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name]['limit']
# 获取当前时间戳
now = int(time.time())
@@ -65,9 +67,9 @@ class FixedWindowAlgo(algo.ReteLimitAlgo):
# 如果访问次数超过了限制
if count >= limitation:
if self.ap.pipeline_cfg.data['rate-limit']['strategy'] == 'drop':
if query.pipeline_config['safety']['rate-limit']['strategy'] == 'drop':
return False
elif self.ap.pipeline_cfg.data['rate-limit']['strategy'] == 'wait':
elif query.pipeline_config['safety']['rate-limit']['strategy'] == 'wait':
# 等待下一窗口
await asyncio.sleep(window_size - time.time() % window_size)
@@ -84,5 +86,5 @@ class FixedWindowAlgo(algo.ReteLimitAlgo):
# 返回True
return True
async def release_access(self, launcher_type: str, launcher_id: typing.Union[int, str]):
async def release_access(self, query: core_entities.Query, launcher_type: str, launcher_id: typing.Union[int, str]):
pass

View File

@@ -20,7 +20,7 @@ class RateLimit(stage.PipelineStage):
async def initialize(self, pipeline_config: dict):
algo_name = self.ap.pipeline_cfg.data['rate-limit']['algo']
algo_name = 'fixwin'
algo_class = None
@@ -46,6 +46,7 @@ class RateLimit(stage.PipelineStage):
"""
if stage_inst_name == "RequireRateLimitOccupancy":
if await self.algo.require_access(
query,
query.launcher_type.value,
query.launcher_id,
):
@@ -62,6 +63,7 @@ class RateLimit(stage.PipelineStage):
)
elif stage_inst_name == "ReleaseRateLimitOccupancy":
await self.algo.release_access(
query,
query.launcher_type.value,
query.launcher_id,
)

View File

@@ -39,12 +39,13 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
new_query=query
)
rules = self.ap.pipeline_cfg.data['respond-rules']
rules = query.pipeline_config['trigger']['group-respond-rules']
use_rule = rules['default']
use_rule = rules
if str(query.launcher_id) in rules:
use_rule = rules[str(query.launcher_id)]
# TODO revert it
# if str(query.launcher_id) in rules:
# use_rule = rules[str(query.launcher_id)]
for rule_matcher in self.rule_matchers: # 任意一个匹配就放行
res = await rule_matcher.match(str(query.message_chain), query.message_chain, use_rule, query)