Merge pull request #712 from RockChinQ/feat/component-extensibility

Feat: extensibility for more components
Junyan Qin
2024-03-13 00:32:26 +08:00
committed by GitHub
43 changed files with 208 additions and 192 deletions

View File

@@ -8,18 +8,34 @@ from . import entities
preregistered_operators: list[typing.Type[CommandOperator]] = []
"""预注册算子列表。在初始化时,所有算子类会被注册到此列表中。"""
"""预注册命令算子列表。在初始化时,所有算子类会被注册到此列表中。"""
def operator_class(
name: str,
help: str,
help: str = "",
usage: str = None,
alias: list[str] = [],
privilege: int=1, # 1 = regular user, 2 = admin
parent_class: typing.Type[CommandOperator] = None
) -> typing.Callable[[typing.Type[CommandOperator]], typing.Type[CommandOperator]]:
"""命令类装饰器
Args:
name (str): 名称
help (str, optional): 帮助信息. Defaults to "".
usage (str, optional): 使用说明. Defaults to None.
alias (list[str], optional): 别名. Defaults to [].
privilege (int, optional): 权限1为普通用户可用2为仅管理员可用. Defaults to 1.
parent_class (typing.Type[CommandOperator], optional): 父节点若为None则为顶级命令. Defaults to None.
Returns:
typing.Callable[[typing.Type[CommandOperator]], typing.Type[CommandOperator]]: 装饰器
"""
def decorator(cls: typing.Type[CommandOperator]) -> typing.Type[CommandOperator]:
assert issubclass(cls, CommandOperator)
cls.name = name
cls.alias = alias
cls.help = help
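Taken together, this hunk turns command registration into a decorator-driven registry: the decorator stamps the metadata onto the class, and registered operator classes are collected in preregistered_operators. A minimal sketch of adding a new command — the relative import, the command name, and the execute hook are illustrative assumptions, not taken from this diff:

from . import operator, entities

@operator.operator_class(name="ping", help="Replies with pong")
class PingOperator(operator.CommandOperator):
    async def execute(self, context):
        # Hypothetical reply hook: yield one result per output message
        yield entities.CommandReturn(text="pong")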

View File

@@ -6,7 +6,7 @@ import traceback
from ..platform import manager as im_mgr
from ..provider.session import sessionmgr as llm_session_mgr
from ..provider.requester import modelmgr as llm_model_mgr
from ..provider.modelmgr import modelmgr as llm_model_mgr
from ..provider.sysprompt import sysprompt as llm_prompt_mgr
from ..provider.tools import toolmgr as llm_tool_mgr
from ..config import manager as config_mgr

View File

@@ -9,7 +9,7 @@ import pydantic
import mirai
from ..provider import entities as llm_entities
from ..provider.requester import entities
from ..provider.modelmgr import entities
from ..provider.sysprompt import entities as sysprompt_entities
from ..provider.tools import entities as tools_entities
from ..platform import adapter as msadapter

View File

@@ -10,7 +10,7 @@ from ...pipeline import pool, controller, stagemgr
from ...plugin import manager as plugin_mgr
from ...command import cmdmgr
from ...provider.session import sessionmgr as llm_session_mgr
from ...provider.requester import modelmgr as llm_model_mgr
from ...provider.modelmgr import modelmgr as llm_model_mgr
from ...provider.sysprompt import sysprompt as llm_prompt_mgr
from ...provider.tools import toolmgr as llm_tool_mgr
from ...platform import manager as im_mgr

View File

@@ -7,7 +7,7 @@ from ...core import app
from .. import stage, entities, stagemgr
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
from . import filter, entities as filter_entities
from . import filter as filter_model, entities as filter_entities
from .filters import cntignore, banwords, baiduexamine
@@ -16,20 +16,29 @@ from .filters import cntignore, banwords, baiduexamine
class ContentFilterStage(stage.PipelineStage):
"""内容过滤阶段"""
filter_chain: list[filter.ContentFilter]
filter_chain: list[filter_model.ContentFilter]
def __init__(self, ap: app.Application):
self.filter_chain = []
super().__init__(ap)
async def initialize(self):
self.filter_chain.append(cntignore.ContentIgnore(self.ap))
filters_required = [
"content-filter"
]
if self.ap.pipeline_cfg.data['check-sensitive-words']:
self.filter_chain.append(banwords.BanWordFilter(self.ap))
filters_required.append("ban-word-filter")
if self.ap.pipeline_cfg.data['baidu-cloud-examine']['enable']:
self.filter_chain.append(baiduexamine.BaiduCloudExamine(self.ap))
filters_required.append("baidu-cloud-examine")
for filter in filter_model.preregistered_filters:
if filter.name in filters_required:
self.filter_chain.append(
filter(self.ap)
)
for filter in self.filter_chain:
await filter.initialize()

View File

@@ -1,12 +1,42 @@
# Abstract base class for content filters
from __future__ import annotations
import abc
import typing
from ...core import app
from . import entities
preregistered_filters: list[typing.Type[ContentFilter]] = []
def filter_class(
name: str
) -> typing.Callable[[typing.Type[ContentFilter]], typing.Type[ContentFilter]]:
"""内容过滤器类装饰器
Args:
name (str): 过滤器名称
Returns:
typing.Callable[[typing.Type[ContentFilter]], typing.Type[ContentFilter]]: 装饰器
"""
def decorator(cls: typing.Type[ContentFilter]) -> typing.Type[ContentFilter]:
assert issubclass(cls, ContentFilter)
cls.name = name
preregistered_filters.append(cls)
return cls
return decorator
class ContentFilter(metaclass=abc.ABCMeta):
"""内容过滤器抽象类"""
name: str
ap: app.Application
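The decorator both sets the name on the class and appends it to preregistered_filters, so a third-party filter registers itself the same way the built-ins in the following files do. A sketch — the relative import and the initialize hook are assumptions:

from . import filter as filter_model

@filter_model.filter_class("my-custom-filter")
class MyCustomFilter(filter_model.ContentFilter):
    async def initialize(self):
        pass  # e.g. load word lists or API credentials here

Note that ContentFilterStage only instantiates filters whose names end up in its filters_required list, so registration makes a filter discoverable, not automatically active.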

View File

@@ -10,6 +10,7 @@ BAIDU_EXAMINE_URL = "https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v
BAIDU_EXAMINE_TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token"
@filter_model.filter_class("baidu-cloud-examine")
class BaiduCloudExamine(filter_model.ContentFilter):
"""百度云内容审核"""

View File

@@ -6,6 +6,7 @@ from .. import entities
from ....config import manager as cfg_mgr
@filter_model.filter_class("ban-word-filter")
class BanWordFilter(filter_model.ContentFilter):
"""根据内容禁言"""

View File

@@ -5,6 +5,7 @@ from .. import entities
from .. import filter as filter_model
@filter_model.filter_class("content-ignore")
class ContentIgnore(filter_model.ContentFilter):
"""根据内容忽略消息"""

View File

@@ -45,11 +45,14 @@ class LongTextProcessStage(stage.PipelineStage):
self.ap.logger.error("加载字体文件失败({})更换为转发消息组件以发送长消息您可以在config.py中调整相关设置。".format(use_font))
self.ap.platform_cfg.data['long-text-process']['strategy'] = "forward"
if config['strategy'] == 'image':
self.strategy_impl = image.Text2ImageStrategy(self.ap)
elif config['strategy'] == 'forward':
self.strategy_impl = forward.ForwardComponentStrategy(self.ap)
for strategy_cls in strategy.preregistered_strategies:
if strategy_cls.name == config['strategy']:
self.strategy_impl = strategy_cls(self.ap)
break
else:
raise ValueError(f"未找到名为 {config['strategy']} 的长消息处理策略")
await self.strategy_impl.initialize()
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
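This lookup (like the rate-limit and prompt-loader lookups later in this PR) leans on Python's for...else: the else clause runs only when the loop finishes without break, which is exactly the "no registered component matched" case. A standalone illustration:

registered = ["image", "forward"]
wanted = "quote"
for name in registered:
    if name == wanted:
        print("using", name)
        break
else:
    # Reached only when the loop never hit `break`, i.e. nothing matched
    raise ValueError(f"no strategy named {wanted}")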

View File

@@ -36,6 +36,7 @@ class Forward(MessageComponent):
return '[聊天记录]'
@strategy_model.strategy_class("forward")
class ForwardComponentStrategy(strategy_model.LongTextStrategy):
async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]:

View File

@@ -15,6 +15,7 @@ from .. import strategy as strategy_model
from ....core import entities as core_entities
@strategy_model.strategy_class("image")
class Text2ImageStrategy(strategy_model.LongTextStrategy):
text_render_font: ImageFont.FreeTypeFont

View File

@@ -9,7 +9,30 @@ from ...core import app
from ...core import entities as core_entities
preregistered_strategies: list[typing.Type[LongTextStrategy]] = []
def strategy_class(
name: str
) -> typing.Callable[[typing.Type[LongTextStrategy]], typing.Type[LongTextStrategy]]:
def decorator(cls: typing.Type[LongTextStrategy]) -> typing.Type[LongTextStrategy]:
assert issubclass(cls, LongTextStrategy)
cls.name = name
preregistered_strategies.append(cls)
return cls
return decorator
class LongTextStrategy(metaclass=abc.ABCMeta):
"""长文本处理策略抽象类
"""
name: str
ap: app.Application
def __init__(self, ap: app.Application):
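A custom long-text strategy needs only the decorator and a process implementation matching the signature visible in the forward/image strategies above. A hedged sketch; the "truncate" behavior is invented for illustration:

import mirai

from . import strategy as strategy_model

@strategy_model.strategy_class("truncate")
class TruncateStrategy(strategy_model.LongTextStrategy):
    async def process(self, message: str, query) -> list:
        # Naive fallback: send only the first 500 characters as plain text
        return [mirai.Plain(message[:500])]

Setting the long-text-process strategy config key to "truncate" would then make the selection loop above pick it up.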

View File

@@ -51,28 +51,6 @@ class PreProcessor(stage.PipelineStage):
query.prompt.messages = event_ctx.event.default_prompt
query.messages = event_ctx.event.prompt
# Trim according to the model's max_tokens
max_tokens = min(query.use_model.max_tokens, self.ap.pipeline_cfg.data['submit-messages-tokens'])
test_messages = query.prompt.messages + query.messages + [query.user_message]
while await query.use_model.tokenizer.count_token(test_messages, query.use_model) > max_tokens:
# Even after popping all prior context we still exceed max_tokens; since prompt and user_message cannot be trimmed, report an error
if len(query.prompt.messages) == 0:
return entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query,
user_notice='Input too long; please shorten the scenario preset or your message',
console_notice='Input too long; shorten the scenario preset or the message, or increase submit-messages-tokens in the config file, which must not exceed the maximum tokens of the model in use'
)
query.messages.pop(0) # Pop the first message; it is guaranteed to be role=user
# Keep popping up to (but not including) the next role=user message
while len(query.messages) > 0 and query.messages[0].role != 'user':
query.messages.pop(0)
test_messages = query.prompt.messages + query.messages + [query.user_message]
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query

View File

@@ -21,8 +21,6 @@ class ChatMessageHandler(handler.MessageHandler):
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
"""处理
"""
# Get the session
# Get the conversation
# Call the API
# Generator

View File

@@ -1,11 +1,26 @@
from __future__ import annotations
import abc
import typing
from ...core import app
preregistered_algos: list[typing.Type[ReteLimitAlgo]] = []
def algo_class(name: str):
def decorator(cls: typing.Type[ReteLimitAlgo]) -> typing.Type[ReteLimitAlgo]:
cls.name = name
preregistered_algos.append(cls)
return cls
return decorator
class ReteLimitAlgo(metaclass=abc.ABCMeta):
name: str = None
ap: app.Application
def __init__(self, ap: app.Application):
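The decorator pattern repeats for rate limiting: fixwin (registered below) is the only built-in, and an alternative algorithm plugs in the same way. The abstract interface is not shown in this hunk, so the body here is a placeholder:

from . import algo

@algo.algo_class("token-bucket")
class TokenBucketAlgo(algo.ReteLimitAlgo):
    async def initialize(self):
        self.buckets = {}  # per-session bucket state (sketch)
    # ...implement the abstract access-check hooks of ReteLimitAlgo here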

View File

@@ -19,6 +19,7 @@ class SessionContainer:
self.records = {}
@algo.algo_class("fixwin")
class FixedWindowAlgo(algo.ReteLimitAlgo):
containers_lock: asyncio.Lock

View File

@@ -16,7 +16,19 @@ class RateLimit(stage.PipelineStage):
algo: algo.ReteLimitAlgo
async def initialize(self):
self.algo = fixedwin.FixedWindowAlgo(self.ap)
algo_name = self.ap.pipeline_cfg.data['rate-limit']['algo']
algo_class = None
for algo_cls in algo.preregistered_algos:
if algo_cls.name == algo_name:
algo_class = algo_cls
break
else:
raise ValueError(f'Unknown rate-limit algorithm: {algo_name}')
self.algo = algo_class(self.ap)
await self.algo.initialize()
async def process(
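The algorithm is now selected by name from the pipeline config rather than hardcoded. With the built-in registration, the relevant config fragment would look roughly like this (surrounding keys elided, exact file layout assumed):

"rate-limit": {
    "algo": "fixwin"
}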

View File

@@ -21,15 +21,13 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
async def initialize(self):
"""初始化检查器
"""
self.rule_matchers = [
atbot.AtBotRule(self.ap),
prefix.PrefixRule(self.ap),
regexp.RegExpRule(self.ap),
random.RandomRespRule(self.ap),
]
for rule_matcher in self.rule_matchers:
await rule_matcher.initialize()
self.rule_matchers = []
for rule_matcher in rule.preregisetered_rules:
rule_inst = rule_matcher(self.ap)
await rule_inst.initialize()
self.rule_matchers.append(rule_inst)
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import abc
import typing
import mirai
@@ -7,9 +8,20 @@ from ...core import app, entities as core_entities
from . import entities
preregisetered_rules: list[typing.Type[GroupRespondRule]] = []
def rule_class(name: str):
def decorator(cls: typing.Type[GroupRespondRule]) -> typing.Type[GroupRespondRule]:
cls.name = name
preregisetered_rules.append(cls)
return cls
return decorator
class GroupRespondRule(metaclass=abc.ABCMeta):
"""群组响应规则的抽象类
"""
name: str
ap: app.Application
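A fifth rule could now be added without touching the stage. Since this hunk abbreviates the match signature, the parameters below are deliberately generic:

from . import rule as rule_model

@rule_model.rule_class("keyword")
class KeywordRule(rule_model.GroupRespondRule):
    async def match(self, *args, **kwargs):
        # Decide whether the bot should respond to this group message
        ...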

View File

@@ -7,6 +7,7 @@ from .. import entities
from ....core import entities as core_entities
@rule_model.rule_class("at-bot")
class AtBotRule(rule_model.GroupRespondRule):
async def match(

View File

@@ -5,6 +5,7 @@ from .. import entities
from ....core import entities as core_entities
@rule_model.rule_class("prefix")
class PrefixRule(rule_model.GroupRespondRule):
async def match(

View File

@@ -7,6 +7,7 @@ from .. import entities
from ....core import entities as core_entities
@rule_model.rule_class("random")
class RandomRespRule(rule_model.GroupRespondRule):
async def match(

View File

@@ -7,6 +7,7 @@ from .. import entities
from ....core import entities as core_entities
@rule_model.rule_class("regexp")
class RegExpRule(rule_model.GroupRespondRule):
async def match(

View File

@@ -22,6 +22,8 @@ def adapter_class(
class MessageSourceAdapter(metaclass=abc.ABCMeta):
"""消息平台适配器基类"""
name: str
bot_account_id: int
@@ -40,7 +42,7 @@ class MessageSourceAdapter(metaclass=abc.ABCMeta):
target_id: str,
message: mirai.MessageChain
):
"""发送消息
"""主动发送消息
Args:
target_type (str): target type, `person` or `group`

View File

@@ -163,25 +163,6 @@ class PlatformManager:
quote_origin=True if self.ap.platform_cfg.data['quote-origin'] and check_quote else False
)
# Notify the system admins
# TODO delete
# async def notify_admin(self, message: str):
# await self.notify_admin_message_chain(MessageChain([Plain("[bot]{}".format(message))]))
# async def notify_admin_message_chain(self, message: mirai.MessageChain):
# if self.ap.system_cfg.data['admin-sessions'] != []:
# admin_list = []
# for admin in self.ap.system_cfg.data['admin-sessions']:
# admin_list.append(admin)
# for adm in admin_list:
# self.adapter.send_message(
# adm.split("_")[0],
# adm.split("_")[1],
# message
# )
async def run(self):
try:
tasks = []

View File

@@ -24,6 +24,8 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter):
msg_list = message_chain.__root__
elif type(message_chain) is list:
msg_list = message_chain
elif type(message_chain) is str:
msg_list = [mirai.Plain(message_chain)]
else:
raise Exception("Unknown message type: " + str(message_chain) + str(type(message_chain)))

View File

@@ -89,6 +89,8 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
msg_list = message_chain.__root__
elif type(message_chain) is list:
msg_list = message_chain
elif type(message_chain) is str:
msg_list = [mirai.Plain(text=message_chain)]
else:
raise Exception("Unknown message type: " + str(message_chain) + str(type(message_chain)))

View File

@@ -7,9 +7,23 @@ from ...core import app
from ...core import entities as core_entities
from .. import entities as llm_entities
preregistered_requesters: list[typing.Type[LLMAPIRequester]] = []
def requester_class(name: str):
def decorator(cls: typing.Type[LLMAPIRequester]) -> typing.Type[LLMAPIRequester]:
cls.name = name
preregistered_requesters.append(cls)
return cls
return decorator
class LLMAPIRequester(metaclass=abc.ABCMeta):
"""LLM API请求器
"""
name: str = None
ap: app.Application
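Registering an alternative LLM backend now takes one decorator instead of edits to the model manager. The request hooks are not visible in this diff, so the body is only a placeholder:

from . import api

@api.requester_class("my-backend")
class MyBackendRequester(api.LLMAPIRequester):
    async def initialize(self):
        pass  # e.g. build the HTTP client and read credentials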

View File

@@ -17,6 +17,7 @@ from ... import entities as llm_entities
from ...tools import entities as tools_entities
@api.requester_class("openai-chat-completion")
class OpenAIChatCompletion(api.LLMAPIRequester):
"""OpenAI ChatCompletion API 请求器"""
@@ -133,7 +134,10 @@ class OpenAIChatCompletion(api.LLMAPIRequester):
except asyncio.TimeoutError:
raise errors.RequesterError('Request timed out')
except openai.BadRequestError as e:
raise errors.RequesterError(f'Request error: {e.message}')
if 'context_length_exceeded' in e.message:
raise errors.RequesterError(f'Context too long; please reset the session: {e.message}')
else:
raise errors.RequesterError(f'Invalid request parameters: {e.message}')
except openai.AuthenticationError as e:
raise errors.RequesterError(f'Invalid api-key: {e.message}')
except openai.NotFoundError as e:

View File

@@ -5,7 +5,7 @@ import typing
import pydantic
from . import api
from . import token, tokenizer
from . import token
class LLMModelInfo(pydantic.BaseModel):
@@ -19,11 +19,7 @@ class LLMModelInfo(pydantic.BaseModel):
requester: api.LLMAPIRequester
tokenizer: 'tokenizer.LLMTokenizer'
tool_call_supported: typing.Optional[bool] = False
max_tokens: typing.Optional[int] = 2048
class Config:
arbitrary_types_allowed = True

View File

@@ -3,9 +3,8 @@ from __future__ import annotations
from . import entities
from ...core import app
from .apis import chatcmpl
from . import token
from .tokenizers import tiktoken
from .apis import chatcmpl
class ModelManager:
@@ -30,9 +29,7 @@ class ModelManager:
async def initialize(self):
openai_chat_completion = chatcmpl.OpenAIChatCompletion(self.ap)
await openai_chat_completion.initialize()
openai_token_mgr = token.TokenManager(self.ap, list(self.ap.provider_cfg.data['openai-config']['api-keys']))
tiktoken_tokenizer = tiktoken.Tiktoken(self.ap)
openai_token_mgr = token.TokenManager("openai", list(self.ap.provider_cfg.data['openai-config']['api-keys']))
model_list = [
entities.LLMModelInfo(
@@ -40,48 +37,36 @@ class ModelManager:
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=True,
tokenizer=tiktoken_tokenizer,
max_tokens=4096
),
entities.LLMModelInfo(
name="gpt-3.5-turbo-1106",
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=True,
tokenizer=tiktoken_tokenizer,
max_tokens=16385
),
entities.LLMModelInfo(
name="gpt-3.5-turbo-16k",
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=True,
tokenizer=tiktoken_tokenizer,
max_tokens=16385
),
entities.LLMModelInfo(
name="gpt-3.5-turbo-0613",
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=True,
tokenizer=tiktoken_tokenizer,
max_tokens=4096
),
entities.LLMModelInfo(
name="gpt-3.5-turbo-16k-0613",
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=True,
tokenizer=tiktoken_tokenizer,
max_tokens=16385
),
entities.LLMModelInfo(
name="gpt-3.5-turbo-0301",
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=True,
tokenizer=tiktoken_tokenizer,
max_tokens=4096
)
]
@@ -93,64 +78,48 @@ class ModelManager:
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=True,
tokenizer=tiktoken_tokenizer,
max_tokens=128000
),
entities.LLMModelInfo(
name="gpt-4-turbo-preview",
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=True,
tokenizer=tiktoken_tokenizer,
max_tokens=128000
),
entities.LLMModelInfo(
name="gpt-4-1106-preview",
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=True,
tokenizer=tiktoken_tokenizer,
max_tokens=128000
),
entities.LLMModelInfo(
name="gpt-4-vision-preview",
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=True,
tokenizer=tiktoken_tokenizer,
max_tokens=128000
),
entities.LLMModelInfo(
name="gpt-4",
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=True,
tokenizer=tiktoken_tokenizer,
max_tokens=8192
),
entities.LLMModelInfo(
name="gpt-4-0613",
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=True,
tokenizer=tiktoken_tokenizer,
max_tokens=8192
),
entities.LLMModelInfo(
name="gpt-4-32k",
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=True,
tokenizer=tiktoken_tokenizer,
max_tokens=32768
),
entities.LLMModelInfo(
name="gpt-4-32k-0613",
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=True,
tokenizer=tiktoken_tokenizer,
max_tokens=32768
)
]
@@ -163,8 +132,6 @@ class ModelManager:
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=False,
tokenizer=tiktoken_tokenizer,
max_tokens=8192
),
entities.LLMModelInfo(
name="OneAPI/chatglm_pro",
@@ -172,8 +139,6 @@ class ModelManager:
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=False,
tokenizer=tiktoken_tokenizer,
max_tokens=128000
),
entities.LLMModelInfo(
name="OneAPI/chatglm_std",
@@ -181,8 +146,6 @@ class ModelManager:
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=False,
tokenizer=tiktoken_tokenizer,
max_tokens=128000
),
entities.LLMModelInfo(
name="OneAPI/chatglm_lite",
@@ -190,8 +153,6 @@ class ModelManager:
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=False,
tokenizer=tiktoken_tokenizer,
max_tokens=128000
),
entities.LLMModelInfo(
name="OneAPI/qwen-v1",
@@ -199,8 +160,6 @@ class ModelManager:
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=False,
tokenizer=tiktoken_tokenizer,
max_tokens=6000
),
entities.LLMModelInfo(
name="OneAPI/qwen-plus-v1",
@@ -208,8 +167,6 @@ class ModelManager:
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=False,
tokenizer=tiktoken_tokenizer,
max_tokens=30000
),
entities.LLMModelInfo(
name="OneAPI/ERNIE-Bot",
@@ -217,8 +174,6 @@ class ModelManager:
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=False,
tokenizer=tiktoken_tokenizer,
max_tokens=2000
),
entities.LLMModelInfo(
name="OneAPI/ERNIE-Bot-turbo",
@@ -226,8 +181,6 @@ class ModelManager:
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=False,
tokenizer=tiktoken_tokenizer,
max_tokens=7000
),
entities.LLMModelInfo(
name="OneAPI/gemini-pro",
@@ -235,8 +188,6 @@ class ModelManager:
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=False,
tokenizer=tiktoken_tokenizer,
max_tokens=30720
),
]

View File

@@ -1,30 +0,0 @@
from __future__ import annotations
import abc
import typing
from ...core import app
from .. import entities as llm_entities
from . import entities
class LLMTokenizer(metaclass=abc.ABCMeta):
"""LLM分词器抽象类"""
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
"""初始化分词器
"""
pass
@abc.abstractmethod
async def count_token(
self,
messages: list[llm_entities.Message],
model: entities.LLMModelInfo
) -> int:
pass

View File

@@ -1,30 +0,0 @@
from __future__ import annotations
import tiktoken
from .. import tokenizer
from ... import entities as llm_entities
from .. import entities
class Tiktoken(tokenizer.LLMTokenizer):
"""TikToken分词器
"""
async def count_token(
self,
messages: list[llm_entities.Message],
model: entities.LLMModelInfo
) -> int:
try:
encoding = tiktoken.encoding_for_model(model.name)
except KeyError:
# print("Warning: model not found. Using cl100k_base encoding.")
encoding = tiktoken.get_encoding("cl100k_base")
num_tokens = 0
for message in messages:
num_tokens += len(encoding.encode(message.role))
num_tokens += len(encoding.encode(message.content if message.content is not None else ''))
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
return num_tokens

View File

@@ -1,13 +1,27 @@
from __future__ import annotations
import abc
import typing
from ...core import app
from . import entities
preregistered_loaders: list[typing.Type[PromptLoader]] = []
def loader_class(name: str):
def decorator(cls: typing.Type[PromptLoader]) -> typing.Type[PromptLoader]:
cls.name = name
preregistered_loaders.append(cls)
return cls
return decorator
class PromptLoader(metaclass=abc.ABCMeta):
"""Prompt加载器抽象类
"""
name: str
ap: app.Application
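The normal and full_scenario loaders below register themselves through this decorator, and PromptManager (refactored at the end of this PR) only ever calls initialize() and load() on whichever loader matches prompt-mode. A custom loader sketch, with the relative import assumed:

from . import loader

@loader.loader_class("from-env")
class EnvPromptLoader(loader.PromptLoader):
    async def load(self):
        # Populate prompts from an environment variable (illustrative only)
        ...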

View File

@@ -8,6 +8,7 @@ from .. import entities
from ....provider import entities as llm_entities
@loader.loader_class("full_scenario")
class ScenarioPromptLoader(loader.PromptLoader):
"""加载scenario目录下的json"""

View File

@@ -6,6 +6,7 @@ from .. import entities
from ....provider import entities as llm_entities
@loader.loader_class("normal")
class SingleSystemPromptLoader(loader.PromptLoader):
"""配置文件中的单条system prompt的prompt加载器
"""

View File

@@ -20,14 +20,18 @@ class PromptManager:
async def initialize(self):
loader_map = {
"normal": single.SingleSystemPromptLoader,
"full_scenario": scenario.ScenarioPromptLoader
}
mode_name = self.ap.provider_cfg.data['prompt-mode']
loader_cls = loader_map[self.ap.provider_cfg.data['prompt-mode']]
loader_class = None
self.loader_inst: loader.PromptLoader = loader_cls(self.ap)
for loader_cls in loader.preregistered_loaders:
if loader_cls.name == mode_name:
loader_class = loader_cls
break
else:
raise ValueError(f'Unknown prompt loader: {mode_name}')
self.loader_inst: loader.PromptLoader = loader_class(self.ap)
await self.loader_inst.initialize()
await self.loader_inst.load()
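As with the other registries, selection is config-driven: the provider config's prompt-mode key names the loader, so either built-in (or any newly registered name) works without code changes, e.g. (file layout assumed):

"prompt-mode": "full_scenario"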