feat: 新增PromptPreprocessing事件

This commit is contained in:
RockChinQ
2023-07-31 17:21:09 +08:00
parent 2b9612e933
commit 6b8fa664f1
3 changed files with 43 additions and 7 deletions

View File

@@ -214,7 +214,29 @@ class Session:
config = pkg.utils.context.get_config()
max_length = config.prompt_submit_length
prompts, _ = self.cut_out(text, max_length)
local_default_prompt = self.default_prompt.copy()
local_prompt = self.prompt.copy()
# 触发PromptPreProcessing事件
args = {
'session_name': self.name,
'default_prompt': self.default_prompt,
'prompt': self.prompt,
'text_message': text,
}
event = pkg.plugin.host.emit(plugin_models.PromptPreProcessing, **args)
if event.get_return_value('default_prompt') is not None:
local_default_prompt = event.get_return_value('default_prompt')
if event.get_return_value('prompt') is not None:
local_prompt = event.get_return_value('prompt')
if event.get_return_value('text_message') is not None:
text = event.get_return_value('text_message')
prompts, _ = self.cut_out(text, max_length, local_default_prompt, local_prompt)
res_text = ""
@@ -301,7 +323,7 @@ class Session:
return question
# 构建对话体
def cut_out(self, msg: str, max_tokens: int) -> tuple[list, list]:
def cut_out(self, msg: str, max_tokens: int, default_prompt: list, prompt: list) -> tuple[list, list]:
"""将现有prompt进行切割处理使得新的prompt长度不超过max_tokens
:return: (新的prompt, 新的token_counts)
@@ -317,19 +339,19 @@ class Session:
use_model = pkg.utils.context.get_config().completion_api_params['model']
ptr = len(self.prompt) - 1
ptr = len(prompt) - 1
# 直接从后向前扫描拼接,不管是否是整回合
while ptr >= 0:
if count_tokens(self.prompt[ptr:ptr+1]+changable_prompts, use_model) > max_tokens:
if count_tokens(prompt[ptr:ptr+1]+changable_prompts, use_model) > max_tokens:
break
changable_prompts.insert(0, self.prompt[ptr])
changable_prompts.insert(0, prompt[ptr])
ptr -= 1
# 将default_prompt和changable_prompts合并
result_prompt = self.default_prompt + changable_prompts
result_prompt = default_prompt + changable_prompts
# 添加当前问题
if msg:

View File

@@ -266,7 +266,7 @@ class EventContext:
self.__return_value__[key] = []
self.__return_value__[key].append(ret)
def get_return(self, key: str):
def get_return(self, key: str) -> list:
"""获取key的所有返回值"""
if key in self.__return_value__:
return self.__return_value__[key]

View File

@@ -132,6 +132,20 @@ KeySwitched = "key_switched"
key_list: list[str] api-key列表
"""
PromptPreProcessing = "prompt_pre_processing"
"""每回合调用接口前对prompt进行预处理时触发此事件不支持阻止默认行为
kwargs:
session_name: str 会话名称(<launcher_type>_<launcher_id>)
default_prompt: list 此session使用的情景预设内容
prompt: list 此session现有的prompt内容
text_message: str 用户发送的消息文本
returns (optional):
default_prompt: list 修改后的情景预设内容
prompt: list 修改后的prompt内容
text_message: str 修改后的消息文本
"""
def on(*args, **kwargs):
"""注册事件监听器