Files
LangBot/pkg/pipeline/ratelimit/algos/fixedwin.py

99 lines
3.1 KiB
Python
Raw Normal View History

2024-01-25 22:35:15 +08:00
from __future__ import annotations
import asyncio
import time
2025-01-12 05:09:53 -05:00
import typing
2024-01-25 22:35:15 +08:00
from .. import algo
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
2024-01-25 22:35:15 +08:00
# Fixed-window rate-limiting algorithm
2024-01-25 22:35:15 +08:00
class SessionContainer:
    """Per-session state for the fixed-window rate limiter.

    Bundles an asyncio lock used to serialise access checks for one session
    with the access counters recorded for that session.
    """

    # Lock held while an access check for this session is in progress.
    wait_lock: asyncio.Lock
    # Access records: key is the start timestamp of a window,
    # value is the number of accesses counted in that window.
    records: dict[int, int]

    def __init__(self):
        # Fresh, unshared state for every session.
        self.records = dict()
        self.wait_lock = asyncio.Lock()
@algo.algo_class('fixwin')
class FixedWindowAlgo(algo.ReteLimitAlgo):
    """Fixed-window rate limiting (decorated with ``algo.algo_class('fixwin')``).

    Time is split into consecutive windows of ``window-length`` seconds,
    aligned to multiples of the window length since the Unix epoch. Each
    session (``launcher_type`` + ``launcher_id``) may make at most
    ``limitation`` accesses per window. An access over the limit is handled
    according to the pipeline's ``strategy`` setting:

    * ``'drop'`` - the access is rejected (``require_access`` returns False).
    * ``'wait'`` - the call blocks until the next window begins, then the
      access is admitted and counted against that new window.
    """

    # Lock guarding lookup/creation of per-session containers.
    containers_lock: asyncio.Lock
    # Per-session access records; key is f'{launcher_type}_{launcher_id}'.
    containers: dict[str, SessionContainer]

    async def initialize(self):
        """Create an empty per-session container registry and its lock."""
        self.containers_lock = asyncio.Lock()
        self.containers = {}

    @staticmethod
    def _window_start(window_size):
        """Return the start timestamp (whole seconds) of the window containing now."""
        now = int(time.time())
        return now - now % window_size

    async def require_access(
        self,
        query: pipeline_query.Query,
        launcher_type: str,
        launcher_id: typing.Union[int, str],
    ) -> bool:
        """Try to admit one access for the given session.

        Args:
            query: The query being processed. Its
                ``pipeline_config['safety']['rate-limit']`` section supplies
                ``window-length`` and ``limitation``, and ``strategy`` is read
                only when the session is over its limit.
            launcher_type: Launcher type of the session.
            launcher_id: Launcher id of the session.

        Returns:
            False if the access was dropped (strategy ``'drop'``), otherwise
            True (possibly after waiting for the next window).
        """
        session_name = f'{launcher_type}_{launcher_id}'

        # Find or create this session's container under the registry lock.
        async with self.containers_lock:
            container = self.containers.get(session_name)
            if container is None:
                container = SessionContainer()
                self.containers[session_name] = container

        # Serialise checks for this session so that read-count-then-increment
        # is atomic and 'wait' callers queue behind one another.
        async with container.wait_lock:
            rate_limit_cfg = query.pipeline_config['safety']['rate-limit']
            window_size = rate_limit_cfg['window-length']
            limitation = rate_limit_cfg['limitation']
            # TODO revert it
            # if session_name in self.ap.pipeline_cfg.data['rate-limit']['fixwin']:
            #     window_size = self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name]['window-size']
            #     limitation = self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name]['limit']

            window_start = self._window_start(window_size)
            count = container.records.get(window_start, 0)

            if count >= limitation:
                strategy = rate_limit_cfg['strategy']
                if strategy == 'drop':
                    return False
                elif strategy == 'wait':
                    blocked_window = window_start
                    # Sleep until a new window has actually begun. The sleep is
                    # timed by the event loop's clock while windows come from
                    # wall-clock time, so a single sleep may wake up still
                    # inside the blocked window; loop rather than admit the
                    # access uncounted.
                    while window_start == blocked_window:
                        await asyncio.sleep(window_size - time.time() % window_size)
                        window_start = self._window_start(window_size)
                    count = container.records.get(window_start, 0)
                else:
                    # NOTE(review): an unrecognised strategy admits the access
                    # without counting it (preserved from the original code);
                    # confirm whether this should be rejected instead.
                    return True

            # Count this access. Only the current window is ever consulted, so
            # discarding older window keys keeps `records` from growing forever.
            container.records = {window_start: count + 1}
            return True

    async def release_access(
        self,
        query: pipeline_query.Query,
        launcher_type: str,
        launcher_id: typing.Union[int, str],
    ):
        """No-op: this algorithm does nothing when an access is released.

        Accesses are counted on admission in ``require_access`` and are
        never decremented here.
        """
        pass