From 7f66efcdd5a85a88d86b40d06b0b0334677755d3 Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Thu, 3 Apr 2025 17:57:51 +0800 Subject: [PATCH] refactor: switch pipeline_cfg related fields to new pipeline config --- pkg/pipeline/bansess/bansess.py | 4 ++-- pkg/pipeline/cntfilter/cntfilter.py | 15 +++++++------- pkg/pipeline/cntfilter/filter.py | 4 ++-- .../cntfilter/filters/baiduexamine.py | 3 ++- pkg/pipeline/cntfilter/filters/banwords.py | 4 ++-- pkg/pipeline/cntfilter/filters/cntignore.py | 11 +++++----- pkg/pipeline/msgtrun/msgtrun.py | 2 +- pkg/pipeline/msgtrun/truncators/round.py | 2 +- pkg/pipeline/ratelimit/algo.py | 6 +++--- pkg/pipeline/ratelimit/algos/fixedwin.py | 20 ++++++++++--------- pkg/pipeline/ratelimit/ratelimit.py | 4 +++- pkg/pipeline/resprule/resprule.py | 9 +++++---- templates/metadata/pipeline/safety.yaml | 16 +++++++++++++++ 13 files changed, 62 insertions(+), 38 deletions(-) diff --git a/pkg/pipeline/bansess/bansess.py b/pkg/pipeline/bansess/bansess.py index 1ca42397..38fb9794 100644 --- a/pkg/pipeline/bansess/bansess.py +++ b/pkg/pipeline/bansess/bansess.py @@ -24,9 +24,9 @@ class BanSessionCheckStage(stage.PipelineStage): found = False - mode = self.ap.pipeline_cfg.data['access-control']['mode'] + mode = query.pipeline_config['trigger']['access-control']['mode'] - sess_list = self.ap.pipeline_cfg.data['access-control'][mode] + sess_list = query.pipeline_config['trigger']['access-control'][mode] if (query.launcher_type.value == 'group' and 'group_*' in sess_list) \ or (query.launcher_type.value == 'person' and 'person_*' in sess_list): diff --git a/pkg/pipeline/cntfilter/cntfilter.py b/pkg/pipeline/cntfilter/cntfilter.py index 6a0c3776..dbf7c52e 100644 --- a/pkg/pipeline/cntfilter/cntfilter.py +++ b/pkg/pipeline/cntfilter/cntfilter.py @@ -41,11 +41,12 @@ class ContentFilterStage(stage.PipelineStage): "content-ignore", ] - if self.ap.pipeline_cfg.data['check-sensitive-words']: + if pipeline_config['safety']['content-filter']['check-sensitive-words']: filters_required.append("ban-word-filter") - if self.ap.pipeline_cfg.data['baidu-cloud-examine']['enable']: - filters_required.append("baidu-cloud-examine") + # TODO revert it + # if self.ap.pipeline_cfg.data['baidu-cloud-examine']['enable']: + # filters_required.append("baidu-cloud-examine") for filter in filter_model.preregistered_filters: if filter.name in filters_required: @@ -65,7 +66,7 @@ class ContentFilterStage(stage.PipelineStage): 只要有一个不通过就不放行,只放行 PASS 的消息 """ - if not self.ap.pipeline_cfg.data['income-msg-check']: + if query.pipeline_config['safety']['content-filter']['scope'] == 'output-msg': return entities.StageProcessResult( result_type=entities.ResultType.CONTINUE, new_query=query @@ -73,7 +74,7 @@ class ContentFilterStage(stage.PipelineStage): else: for filter in self.filter_chain: if filter_entities.EnableStage.PRE in filter.enable_stages: - result = await filter.process(message) + result = await filter.process(query, message) if result.level in [ filter_entities.ResultLevel.BLOCK, @@ -105,7 +106,7 @@ class ContentFilterStage(stage.PipelineStage): """请求llm后处理响应 只要是 PASS 或者 MASKED 的就通过此 filter,将其 replacement 设置为message,进入下一个 filter """ - if message is None: + if query.pipeline_config['safety']['content-filter']['scope'] == 'income-msg': return entities.StageProcessResult( result_type=entities.ResultType.CONTINUE, new_query=query @@ -114,7 +115,7 @@ class ContentFilterStage(stage.PipelineStage): message = message.strip() for filter in self.filter_chain: if filter_entities.EnableStage.POST in filter.enable_stages: - result = await filter.process(message) + result = await filter.process(query, message) if result.level == filter_entities.ResultLevel.BLOCK: return entities.StageProcessResult( diff --git a/pkg/pipeline/cntfilter/filter.py b/pkg/pipeline/cntfilter/filter.py index 8eceb877..970e11f1 100644 --- a/pkg/pipeline/cntfilter/filter.py +++ b/pkg/pipeline/cntfilter/filter.py @@ -3,7 +3,7 @@ from __future__ import annotations import abc import typing -from ...core import app +from ...core import app, entities as core_entities from . import entities from ...provider import entities as llm_entities @@ -64,7 +64,7 @@ class ContentFilter(metaclass=abc.ABCMeta): pass @abc.abstractmethod - async def process(self, message: str=None, image_url=None) -> entities.FilterResult: + async def process(self, query: core_entities.Query, message: str=None, image_url=None) -> entities.FilterResult: """处理消息 分为前后阶段,具体取决于 enable_stages 的值。 diff --git a/pkg/pipeline/cntfilter/filters/baiduexamine.py b/pkg/pipeline/cntfilter/filters/baiduexamine.py index 8c5b77cd..800f0099 100644 --- a/pkg/pipeline/cntfilter/filters/baiduexamine.py +++ b/pkg/pipeline/cntfilter/filters/baiduexamine.py @@ -4,6 +4,7 @@ import aiohttp from .. import entities from .. import filter as filter_model +from ....core import entities as core_entities BAIDU_EXAMINE_URL = "https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token={}" @@ -26,7 +27,7 @@ class BaiduCloudExamine(filter_model.ContentFilter): ) as resp: return (await resp.json())['access_token'] - async def process(self, message: str) -> entities.FilterResult: + async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult: async with aiohttp.ClientSession() as session: async with session.post( diff --git a/pkg/pipeline/cntfilter/filters/banwords.py b/pkg/pipeline/cntfilter/filters/banwords.py index 1430c2ed..cd3d412c 100644 --- a/pkg/pipeline/cntfilter/filters/banwords.py +++ b/pkg/pipeline/cntfilter/filters/banwords.py @@ -3,7 +3,7 @@ import re from .. import filter as filter_model from .. import entities -from ....config import manager as cfg_mgr +from ....core import entities as core_entities @filter_model.filter_class("ban-word-filter") @@ -13,7 +13,7 @@ class BanWordFilter(filter_model.ContentFilter): async def initialize(self): pass - async def process(self, message: str) -> entities.FilterResult: + async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult: found = False for word in self.ap.sensitive_meta.data['words']: diff --git a/pkg/pipeline/cntfilter/filters/cntignore.py b/pkg/pipeline/cntfilter/filters/cntignore.py index 781f6397..381d5c51 100644 --- a/pkg/pipeline/cntfilter/filters/cntignore.py +++ b/pkg/pipeline/cntfilter/filters/cntignore.py @@ -3,6 +3,7 @@ import re from .. import entities from .. import filter as filter_model +from ....core import entities as core_entities @filter_model.filter_class("content-ignore") @@ -15,9 +16,9 @@ class ContentIgnore(filter_model.ContentFilter): entities.EnableStage.PRE, ] - async def process(self, message: str) -> entities.FilterResult: - if 'prefix' in self.ap.pipeline_cfg.data['ignore-rules']: - for rule in self.ap.pipeline_cfg.data['ignore-rules']['prefix']: + async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult: + if 'prefix' in query.pipeline_config['trigger']['ignore-rules']: + for rule in query.pipeline_config['trigger']['ignore-rules']['prefix']: if message.startswith(rule): return entities.FilterResult( level=entities.ResultLevel.BLOCK, @@ -26,8 +27,8 @@ class ContentIgnore(filter_model.ContentFilter): console_notice='根据 ignore_rules 中的 prefix 规则,忽略消息' ) - if 'regexp' in self.ap.pipeline_cfg.data['ignore-rules']: - for rule in self.ap.pipeline_cfg.data['ignore-rules']['regexp']: + if 'regexp' in query.pipeline_config['trigger']['ignore-rules']: + for rule in query.pipeline_config['trigger']['ignore-rules']['regexp']: if re.search(rule, message): return entities.FilterResult( level=entities.ResultLevel.BLOCK, diff --git a/pkg/pipeline/msgtrun/msgtrun.py b/pkg/pipeline/msgtrun/msgtrun.py index a1116eb4..b3fb593a 100644 --- a/pkg/pipeline/msgtrun/msgtrun.py +++ b/pkg/pipeline/msgtrun/msgtrun.py @@ -15,7 +15,7 @@ class ConversationMessageTruncator(stage.PipelineStage): trun: truncator.Truncator async def initialize(self, pipeline_config: dict): - use_method = self.ap.pipeline_cfg.data['msg-truncate']['method'] + use_method = "round" for trun in truncator.preregistered_truncators: if trun.name == use_method: diff --git a/pkg/pipeline/msgtrun/truncators/round.py b/pkg/pipeline/msgtrun/truncators/round.py index 646f2856..46fce5f3 100644 --- a/pkg/pipeline/msgtrun/truncators/round.py +++ b/pkg/pipeline/msgtrun/truncators/round.py @@ -12,7 +12,7 @@ class RoundTruncator(truncator.Truncator): async def truncate(self, query: core_entities.Query) -> core_entities.Query: """截断 """ - max_round = self.ap.pipeline_cfg.data['msg-truncate']['round']['max-round'] + max_round = query.pipeline_config['ai']['local-agent']['max-round'] temp_messages = [] diff --git a/pkg/pipeline/ratelimit/algo.py b/pkg/pipeline/ratelimit/algo.py index 9b418dd2..d9baa801 100644 --- a/pkg/pipeline/ratelimit/algo.py +++ b/pkg/pipeline/ratelimit/algo.py @@ -2,7 +2,7 @@ from __future__ import annotations import abc import typing -from ...core import app +from ...core import app, entities as core_entities preregistered_algos: list[typing.Type[ReteLimitAlgo]] = [] @@ -31,7 +31,7 @@ class ReteLimitAlgo(metaclass=abc.ABCMeta): pass @abc.abstractmethod - async def require_access(self, launcher_type: str, launcher_id: typing.Union[int, str]) -> bool: + async def require_access(self, query: core_entities.Query, launcher_type: str, launcher_id: typing.Union[int, str]) -> bool: """进入处理流程 这个方法对等待是友好的,意味着算法可以实现在这里等待一段时间以控制速率。 @@ -46,7 +46,7 @@ class ReteLimitAlgo(metaclass=abc.ABCMeta): raise NotImplementedError @abc.abstractmethod - async def release_access(self, launcher_type: str, launcher_id: typing.Union[int, str]): + async def release_access(self, query: core_entities.Query, launcher_type: str, launcher_id: typing.Union[int, str]): """退出处理流程 Args: diff --git a/pkg/pipeline/ratelimit/algos/fixedwin.py b/pkg/pipeline/ratelimit/algos/fixedwin.py index 3cc1ab94..f17e93b8 100644 --- a/pkg/pipeline/ratelimit/algos/fixedwin.py +++ b/pkg/pipeline/ratelimit/algos/fixedwin.py @@ -3,6 +3,7 @@ import asyncio import time import typing from .. import algo +from ....core import entities as core_entities # 固定窗口算法 class SessionContainer: @@ -30,7 +31,7 @@ class FixedWindowAlgo(algo.ReteLimitAlgo): self.containers_lock = asyncio.Lock() self.containers = {} - async def require_access(self, launcher_type: str, launcher_id: typing.Union[int, str]) -> bool: + async def require_access(self, query: core_entities.Query, launcher_type: str, launcher_id: typing.Union[int, str]) -> bool: # 加锁,找容器 container: SessionContainer = None @@ -47,12 +48,13 @@ class FixedWindowAlgo(algo.ReteLimitAlgo): async with container.wait_lock: # 获取窗口大小和限制 - window_size = self.ap.pipeline_cfg.data['rate-limit']['fixwin']['default']['window-size'] - limitation = self.ap.pipeline_cfg.data['rate-limit']['fixwin']['default']['limit'] + window_size = query.pipeline_config['safety']['rate-limit']['window-length'] + limitation = query.pipeline_config['safety']['rate-limit']['limitation'] - if session_name in self.ap.pipeline_cfg.data['rate-limit']['fixwin']: - window_size = self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name]['window-size'] - limitation = self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name]['limit'] + # TODO revert it + # if session_name in self.ap.pipeline_cfg.data['rate-limit']['fixwin']: + # window_size = self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name]['window-size'] + # limitation = self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name]['limit'] # 获取当前时间戳 now = int(time.time()) @@ -65,9 +67,9 @@ class FixedWindowAlgo(algo.ReteLimitAlgo): # 如果访问次数超过了限制 if count >= limitation: - if self.ap.pipeline_cfg.data['rate-limit']['strategy'] == 'drop': + if query.pipeline_config['safety']['rate-limit']['strategy'] == 'drop': return False - elif self.ap.pipeline_cfg.data['rate-limit']['strategy'] == 'wait': + elif query.pipeline_config['safety']['rate-limit']['strategy'] == 'wait': # 等待下一窗口 await asyncio.sleep(window_size - time.time() % window_size) @@ -84,5 +86,5 @@ class FixedWindowAlgo(algo.ReteLimitAlgo): # 返回True return True - async def release_access(self, launcher_type: str, launcher_id: typing.Union[int, str]): + async def release_access(self, query: core_entities.Query, launcher_type: str, launcher_id: typing.Union[int, str]): pass diff --git a/pkg/pipeline/ratelimit/ratelimit.py b/pkg/pipeline/ratelimit/ratelimit.py index 01bde395..c74db978 100644 --- a/pkg/pipeline/ratelimit/ratelimit.py +++ b/pkg/pipeline/ratelimit/ratelimit.py @@ -20,7 +20,7 @@ class RateLimit(stage.PipelineStage): async def initialize(self, pipeline_config: dict): - algo_name = self.ap.pipeline_cfg.data['rate-limit']['algo'] + algo_name = 'fixwin' algo_class = None @@ -46,6 +46,7 @@ class RateLimit(stage.PipelineStage): """ if stage_inst_name == "RequireRateLimitOccupancy": if await self.algo.require_access( + query, query.launcher_type.value, query.launcher_id, ): @@ -62,6 +63,7 @@ class RateLimit(stage.PipelineStage): ) elif stage_inst_name == "ReleaseRateLimitOccupancy": await self.algo.release_access( + query, query.launcher_type.value, query.launcher_id, ) diff --git a/pkg/pipeline/resprule/resprule.py b/pkg/pipeline/resprule/resprule.py index 7e4b8f99..08ba49e8 100644 --- a/pkg/pipeline/resprule/resprule.py +++ b/pkg/pipeline/resprule/resprule.py @@ -39,12 +39,13 @@ class GroupRespondRuleCheckStage(stage.PipelineStage): new_query=query ) - rules = self.ap.pipeline_cfg.data['respond-rules'] + rules = query.pipeline_config['trigger']['group-respond-rules'] - use_rule = rules['default'] + use_rule = rules - if str(query.launcher_id) in rules: - use_rule = rules[str(query.launcher_id)] + # TODO revert it + # if str(query.launcher_id) in rules: + # use_rule = rules[str(query.launcher_id)] for rule_matcher in self.rule_matchers: # 任意一个匹配就放行 res = await rule_matcher.match(str(query.message_chain), query.message_chain, use_rule, query) diff --git a/templates/metadata/pipeline/safety.yaml b/templates/metadata/pipeline/safety.yaml index 09f8025b..ba59e067 100644 --- a/templates/metadata/pipeline/safety.yaml +++ b/templates/metadata/pipeline/safety.yaml @@ -54,3 +54,19 @@ stages: type: integer required: true default: 60 + - name: strategy + label: + en_US: Strategy + zh_CN: 策略 + type: select + required: true + default: drop + options: + - name: drop + label: + en_US: Drop + zh_CN: 丢弃 + - name: wait + label: + en_US: Wait + zh_CN: 等待 \ No newline at end of file