2024-01-25 22:35:15 +08:00
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
import asyncio
|
|
|
|
|
|
import time
|
2025-01-12 05:09:53 -05:00
|
|
|
|
import typing
|
2024-01-25 22:35:15 +08:00
|
|
|
|
from .. import algo
|
2025-06-15 22:04:31 +08:00
|
|
|
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
2024-01-25 22:35:15 +08:00
|
|
|
|
|
2025-04-29 17:24:07 +08:00
|
|
|
|
|
2024-05-26 10:29:10 +08:00
|
|
|
|
# Fixed-window rate-limiting algorithm


class SessionContainer:
    """Mutable rate-limit state kept for a single session."""

    # Per-session lock; FixedWindowAlgo.require_access holds it while checking
    # and updating ``records`` for this session.
    wait_lock: asyncio.Lock

    # Access records: key is the start timestamp of a window (a multiple of
    # the window length), value is the number of accesses seen in that window.
    records: dict[int, int]

    def __init__(self):
        # Each instance gets its own, initially empty, state.
        self.records = {}
        self.wait_lock = asyncio.Lock()
|
|
|
|
|
|
|
|
|
|
|
|
|
2025-04-29 17:24:07 +08:00
|
|
|
|
@algo.algo_class('fixwin')
class FixedWindowAlgo(algo.ReteLimitAlgo):
    """Fixed-window rate limiter.

    Time is split into consecutive windows of ``window-length`` seconds, each
    starting at a multiple of the window length. A session may make at most
    ``limitation`` accesses per window; what happens to an access beyond that
    is decided by the pipeline's rate-limit ``strategy``.
    """

    # Lock guarding lookup/creation of entries in ``containers``.
    containers_lock: asyncio.Lock

    # Per-session access-record containers, keyed by f'{launcher_type}_{launcher_id}'.
    containers: dict[str, SessionContainer]

    async def initialize(self):
        self.containers_lock = asyncio.Lock()
        self.containers = {}

    @staticmethod
    def _window_start(window_size) -> int:
        """Return the start timestamp (whole seconds) of the window containing now."""
        now = int(time.time())
        return now - now % window_size

    async def _wait_for_next_window(self, window_start, window_size) -> int:
        """Sleep until the window after ``window_start`` has begun; return the current window start.

        The clock is re-checked after every sleep: ``asyncio.sleep`` is timed on
        the event loop's clock, while windows are computed from ``time.time()``,
        so a single sleep may end marginally before the wall-clock boundary.
        If the boundary has already passed, this returns without sleeping.
        """
        next_window = window_start + window_size
        while True:
            remaining = next_window - time.time()
            if remaining <= 0:
                break
            await asyncio.sleep(remaining)
        return self._window_start(window_size)

    async def require_access(
        self,
        query: pipeline_query.Query,
        launcher_type: str,
        launcher_id: typing.Union[int, str],
    ) -> bool:
        """Try to record one access for the session and report whether it may proceed.

        Args:
            query: Query being processed; ``query.pipeline_config['safety']['rate-limit']``
                supplies ``window-length``, ``limitation`` and ``strategy``.
            launcher_type: Launcher type of the session.
            launcher_id: Launcher id of the session.

        Returns:
            ``False`` when the session has reached its limit in the current window
            and the strategy is ``'drop'``. Otherwise ``True``: with ``'wait'`` the
            coroutine first sleeps until the next window and counts the access
            there; with any other strategy value an over-limit access is allowed
            without being counted.
        """
        session_name = f'{launcher_type}_{launcher_id}'

        # Find (or lazily create) this session's container under the global lock.
        async with self.containers_lock:
            container = self.containers.get(session_name)
            if container is None:
                container = SessionContainer()
                self.containers[session_name] = container

        # Serialise checks for this session. A 'wait' caller keeps holding the
        # lock while sleeping, so later callers for the same session queue behind it.
        async with container.wait_lock:
            rate_limit_cfg = query.pipeline_config['safety']['rate-limit']
            window_size = rate_limit_cfg['window-length']
            limitation = rate_limit_cfg['limitation']

            # TODO revert it: restore per-session overrides, previously:
            # if session_name in self.ap.pipeline_cfg.data['rate-limit']['fixwin']:
            #     window_size = self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name]['window-size']
            #     limitation = self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name]['limit']

            window_start = self._window_start(window_size)

            if container.records.get(window_start, 0) >= limitation:
                strategy = rate_limit_cfg['strategy']
                if strategy == 'drop':
                    return False
                if strategy != 'wait':
                    # Unrecognised strategy: let the access through without counting it.
                    return True
                window_start = await self._wait_for_next_window(window_start, window_size)

            # Entering a new window: discard all older windows so that
            # ``records`` never holds more than the current window's count.
            if window_start not in container.records:
                container.records = {window_start: 0}
            container.records[window_start] += 1
            return True

    async def release_access(
        self,
        query: pipeline_query.Query,
        launcher_type: str,
        launcher_id: typing.Union[int, str],
    ):
        """No-op: fixed-window accounting is done entirely in ``require_access``."""
        pass
|