mirror of
https://github.com/langbot-app/LangBot.git
synced 2025-11-25 19:37:36 +08:00
refactor: switch pipeline_cfg related fields to new pipeline config
This commit is contained in:
@@ -24,9 +24,9 @@ class BanSessionCheckStage(stage.PipelineStage):
|
||||
|
||||
found = False
|
||||
|
||||
mode = self.ap.pipeline_cfg.data['access-control']['mode']
|
||||
mode = query.pipeline_config['trigger']['access-control']['mode']
|
||||
|
||||
sess_list = self.ap.pipeline_cfg.data['access-control'][mode]
|
||||
sess_list = query.pipeline_config['trigger']['access-control'][mode]
|
||||
|
||||
if (query.launcher_type.value == 'group' and 'group_*' in sess_list) \
|
||||
or (query.launcher_type.value == 'person' and 'person_*' in sess_list):
|
||||
|
||||
@@ -41,11 +41,12 @@ class ContentFilterStage(stage.PipelineStage):
|
||||
"content-ignore",
|
||||
]
|
||||
|
||||
if self.ap.pipeline_cfg.data['check-sensitive-words']:
|
||||
if pipeline_config['safety']['content-filter']['check-sensitive-words']:
|
||||
filters_required.append("ban-word-filter")
|
||||
|
||||
if self.ap.pipeline_cfg.data['baidu-cloud-examine']['enable']:
|
||||
filters_required.append("baidu-cloud-examine")
|
||||
# TODO revert it
|
||||
# if self.ap.pipeline_cfg.data['baidu-cloud-examine']['enable']:
|
||||
# filters_required.append("baidu-cloud-examine")
|
||||
|
||||
for filter in filter_model.preregistered_filters:
|
||||
if filter.name in filters_required:
|
||||
@@ -65,7 +66,7 @@ class ContentFilterStage(stage.PipelineStage):
|
||||
只要有一个不通过就不放行,只放行 PASS 的消息
|
||||
"""
|
||||
|
||||
if not self.ap.pipeline_cfg.data['income-msg-check']:
|
||||
if query.pipeline_config['safety']['content-filter']['scope'] == 'output-msg':
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
@@ -73,7 +74,7 @@ class ContentFilterStage(stage.PipelineStage):
|
||||
else:
|
||||
for filter in self.filter_chain:
|
||||
if filter_entities.EnableStage.PRE in filter.enable_stages:
|
||||
result = await filter.process(message)
|
||||
result = await filter.process(query, message)
|
||||
|
||||
if result.level in [
|
||||
filter_entities.ResultLevel.BLOCK,
|
||||
@@ -105,7 +106,7 @@ class ContentFilterStage(stage.PipelineStage):
|
||||
"""请求llm后处理响应
|
||||
只要是 PASS 或者 MASKED 的就通过此 filter,将其 replacement 设置为message,进入下一个 filter
|
||||
"""
|
||||
if message is None:
|
||||
if query.pipeline_config['safety']['content-filter']['scope'] == 'income-msg':
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
@@ -114,7 +115,7 @@ class ContentFilterStage(stage.PipelineStage):
|
||||
message = message.strip()
|
||||
for filter in self.filter_chain:
|
||||
if filter_entities.EnableStage.POST in filter.enable_stages:
|
||||
result = await filter.process(message)
|
||||
result = await filter.process(query, message)
|
||||
|
||||
if result.level == filter_entities.ResultLevel.BLOCK:
|
||||
return entities.StageProcessResult(
|
||||
|
||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import abc
|
||||
import typing
|
||||
|
||||
from ...core import app
|
||||
from ...core import app, entities as core_entities
|
||||
from . import entities
|
||||
from ...provider import entities as llm_entities
|
||||
|
||||
@@ -64,7 +64,7 @@ class ContentFilter(metaclass=abc.ABCMeta):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def process(self, message: str=None, image_url=None) -> entities.FilterResult:
|
||||
async def process(self, query: core_entities.Query, message: str=None, image_url=None) -> entities.FilterResult:
|
||||
"""处理消息
|
||||
|
||||
分为前后阶段,具体取决于 enable_stages 的值。
|
||||
|
||||
@@ -4,6 +4,7 @@ import aiohttp
|
||||
|
||||
from .. import entities
|
||||
from .. import filter as filter_model
|
||||
from ....core import entities as core_entities
|
||||
|
||||
|
||||
BAIDU_EXAMINE_URL = "https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token={}"
|
||||
@@ -26,7 +27,7 @@ class BaiduCloudExamine(filter_model.ContentFilter):
|
||||
) as resp:
|
||||
return (await resp.json())['access_token']
|
||||
|
||||
async def process(self, message: str) -> entities.FilterResult:
|
||||
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
|
||||
@@ -3,7 +3,7 @@ import re
|
||||
|
||||
from .. import filter as filter_model
|
||||
from .. import entities
|
||||
from ....config import manager as cfg_mgr
|
||||
from ....core import entities as core_entities
|
||||
|
||||
|
||||
@filter_model.filter_class("ban-word-filter")
|
||||
@@ -13,7 +13,7 @@ class BanWordFilter(filter_model.ContentFilter):
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def process(self, message: str) -> entities.FilterResult:
|
||||
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
|
||||
found = False
|
||||
|
||||
for word in self.ap.sensitive_meta.data['words']:
|
||||
|
||||
@@ -3,6 +3,7 @@ import re
|
||||
|
||||
from .. import entities
|
||||
from .. import filter as filter_model
|
||||
from ....core import entities as core_entities
|
||||
|
||||
|
||||
@filter_model.filter_class("content-ignore")
|
||||
@@ -15,9 +16,9 @@ class ContentIgnore(filter_model.ContentFilter):
|
||||
entities.EnableStage.PRE,
|
||||
]
|
||||
|
||||
async def process(self, message: str) -> entities.FilterResult:
|
||||
if 'prefix' in self.ap.pipeline_cfg.data['ignore-rules']:
|
||||
for rule in self.ap.pipeline_cfg.data['ignore-rules']['prefix']:
|
||||
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
|
||||
if 'prefix' in query.pipeline_config['trigger']['ignore-rules']:
|
||||
for rule in query.pipeline_config['trigger']['ignore-rules']['prefix']:
|
||||
if message.startswith(rule):
|
||||
return entities.FilterResult(
|
||||
level=entities.ResultLevel.BLOCK,
|
||||
@@ -26,8 +27,8 @@ class ContentIgnore(filter_model.ContentFilter):
|
||||
console_notice='根据 ignore_rules 中的 prefix 规则,忽略消息'
|
||||
)
|
||||
|
||||
if 'regexp' in self.ap.pipeline_cfg.data['ignore-rules']:
|
||||
for rule in self.ap.pipeline_cfg.data['ignore-rules']['regexp']:
|
||||
if 'regexp' in query.pipeline_config['trigger']['ignore-rules']:
|
||||
for rule in query.pipeline_config['trigger']['ignore-rules']['regexp']:
|
||||
if re.search(rule, message):
|
||||
return entities.FilterResult(
|
||||
level=entities.ResultLevel.BLOCK,
|
||||
|
||||
@@ -15,7 +15,7 @@ class ConversationMessageTruncator(stage.PipelineStage):
|
||||
trun: truncator.Truncator
|
||||
|
||||
async def initialize(self, pipeline_config: dict):
|
||||
use_method = self.ap.pipeline_cfg.data['msg-truncate']['method']
|
||||
use_method = "round"
|
||||
|
||||
for trun in truncator.preregistered_truncators:
|
||||
if trun.name == use_method:
|
||||
|
||||
@@ -12,7 +12,7 @@ class RoundTruncator(truncator.Truncator):
|
||||
async def truncate(self, query: core_entities.Query) -> core_entities.Query:
|
||||
"""截断
|
||||
"""
|
||||
max_round = self.ap.pipeline_cfg.data['msg-truncate']['round']['max-round']
|
||||
max_round = query.pipeline_config['ai']['local-agent']['max-round']
|
||||
|
||||
temp_messages = []
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
import abc
|
||||
import typing
|
||||
|
||||
from ...core import app
|
||||
from ...core import app, entities as core_entities
|
||||
|
||||
|
||||
preregistered_algos: list[typing.Type[ReteLimitAlgo]] = []
|
||||
@@ -31,7 +31,7 @@ class ReteLimitAlgo(metaclass=abc.ABCMeta):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def require_access(self, launcher_type: str, launcher_id: typing.Union[int, str]) -> bool:
|
||||
async def require_access(self, query: core_entities.Query, launcher_type: str, launcher_id: typing.Union[int, str]) -> bool:
|
||||
"""进入处理流程
|
||||
|
||||
这个方法对等待是友好的,意味着算法可以实现在这里等待一段时间以控制速率。
|
||||
@@ -46,7 +46,7 @@ class ReteLimitAlgo(metaclass=abc.ABCMeta):
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
async def release_access(self, launcher_type: str, launcher_id: typing.Union[int, str]):
|
||||
async def release_access(self, query: core_entities.Query, launcher_type: str, launcher_id: typing.Union[int, str]):
|
||||
"""退出处理流程
|
||||
|
||||
Args:
|
||||
|
||||
@@ -3,6 +3,7 @@ import asyncio
|
||||
import time
|
||||
import typing
|
||||
from .. import algo
|
||||
from ....core import entities as core_entities
|
||||
|
||||
# 固定窗口算法
|
||||
class SessionContainer:
|
||||
@@ -30,7 +31,7 @@ class FixedWindowAlgo(algo.ReteLimitAlgo):
|
||||
self.containers_lock = asyncio.Lock()
|
||||
self.containers = {}
|
||||
|
||||
async def require_access(self, launcher_type: str, launcher_id: typing.Union[int, str]) -> bool:
|
||||
async def require_access(self, query: core_entities.Query, launcher_type: str, launcher_id: typing.Union[int, str]) -> bool:
|
||||
# 加锁,找容器
|
||||
container: SessionContainer = None
|
||||
|
||||
@@ -47,12 +48,13 @@ class FixedWindowAlgo(algo.ReteLimitAlgo):
|
||||
async with container.wait_lock:
|
||||
|
||||
# 获取窗口大小和限制
|
||||
window_size = self.ap.pipeline_cfg.data['rate-limit']['fixwin']['default']['window-size']
|
||||
limitation = self.ap.pipeline_cfg.data['rate-limit']['fixwin']['default']['limit']
|
||||
window_size = query.pipeline_config['safety']['rate-limit']['window-length']
|
||||
limitation = query.pipeline_config['safety']['rate-limit']['limitation']
|
||||
|
||||
if session_name in self.ap.pipeline_cfg.data['rate-limit']['fixwin']:
|
||||
window_size = self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name]['window-size']
|
||||
limitation = self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name]['limit']
|
||||
# TODO revert it
|
||||
# if session_name in self.ap.pipeline_cfg.data['rate-limit']['fixwin']:
|
||||
# window_size = self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name]['window-size']
|
||||
# limitation = self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name]['limit']
|
||||
|
||||
# 获取当前时间戳
|
||||
now = int(time.time())
|
||||
@@ -65,9 +67,9 @@ class FixedWindowAlgo(algo.ReteLimitAlgo):
|
||||
|
||||
# 如果访问次数超过了限制
|
||||
if count >= limitation:
|
||||
if self.ap.pipeline_cfg.data['rate-limit']['strategy'] == 'drop':
|
||||
if query.pipeline_config['safety']['rate-limit']['strategy'] == 'drop':
|
||||
return False
|
||||
elif self.ap.pipeline_cfg.data['rate-limit']['strategy'] == 'wait':
|
||||
elif query.pipeline_config['safety']['rate-limit']['strategy'] == 'wait':
|
||||
# 等待下一窗口
|
||||
await asyncio.sleep(window_size - time.time() % window_size)
|
||||
|
||||
@@ -84,5 +86,5 @@ class FixedWindowAlgo(algo.ReteLimitAlgo):
|
||||
# 返回True
|
||||
return True
|
||||
|
||||
async def release_access(self, launcher_type: str, launcher_id: typing.Union[int, str]):
|
||||
async def release_access(self, query: core_entities.Query, launcher_type: str, launcher_id: typing.Union[int, str]):
|
||||
pass
|
||||
|
||||
@@ -20,7 +20,7 @@ class RateLimit(stage.PipelineStage):
|
||||
|
||||
async def initialize(self, pipeline_config: dict):
|
||||
|
||||
algo_name = self.ap.pipeline_cfg.data['rate-limit']['algo']
|
||||
algo_name = 'fixwin'
|
||||
|
||||
algo_class = None
|
||||
|
||||
@@ -46,6 +46,7 @@ class RateLimit(stage.PipelineStage):
|
||||
"""
|
||||
if stage_inst_name == "RequireRateLimitOccupancy":
|
||||
if await self.algo.require_access(
|
||||
query,
|
||||
query.launcher_type.value,
|
||||
query.launcher_id,
|
||||
):
|
||||
@@ -62,6 +63,7 @@ class RateLimit(stage.PipelineStage):
|
||||
)
|
||||
elif stage_inst_name == "ReleaseRateLimitOccupancy":
|
||||
await self.algo.release_access(
|
||||
query,
|
||||
query.launcher_type.value,
|
||||
query.launcher_id,
|
||||
)
|
||||
|
||||
@@ -39,12 +39,13 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
|
||||
new_query=query
|
||||
)
|
||||
|
||||
rules = self.ap.pipeline_cfg.data['respond-rules']
|
||||
rules = query.pipeline_config['trigger']['group-respond-rules']
|
||||
|
||||
use_rule = rules['default']
|
||||
use_rule = rules
|
||||
|
||||
if str(query.launcher_id) in rules:
|
||||
use_rule = rules[str(query.launcher_id)]
|
||||
# TODO revert it
|
||||
# if str(query.launcher_id) in rules:
|
||||
# use_rule = rules[str(query.launcher_id)]
|
||||
|
||||
for rule_matcher in self.rule_matchers: # 任意一个匹配就放行
|
||||
res = await rule_matcher.match(str(query.message_chain), query.message_chain, use_rule, query)
|
||||
|
||||
@@ -54,3 +54,19 @@ stages:
|
||||
type: integer
|
||||
required: true
|
||||
default: 60
|
||||
- name: strategy
|
||||
label:
|
||||
en_US: Strategy
|
||||
zh_CN: 策略
|
||||
type: select
|
||||
required: true
|
||||
default: drop
|
||||
options:
|
||||
- name: drop
|
||||
label:
|
||||
en_US: Drop
|
||||
zh_CN: 丢弃
|
||||
- name: wait
|
||||
label:
|
||||
en_US: Wait
|
||||
zh_CN: 等待
|
||||
Reference in New Issue
Block a user