Merge pull request #262 from chordfish-k/json_scenario

[Feat] 情景预设(人格)完善
This commit is contained in:
Rock Chin
2023-03-12 20:40:20 +08:00
committed by GitHub
9 changed files with 140 additions and 49 deletions

2
.gitignore vendored
View File

@@ -12,3 +12,5 @@ logs/
sensitive.json sensitive.json
temp/ temp/
current_tag current_tag
scenario/
!scenario/default-template.json

View File

@@ -79,6 +79,11 @@ default_prompt = {
"default": "如果我之后想获取帮助,请你说“输入!help获取帮助”", "default": "如果我之后想获取帮助,请你说“输入!help获取帮助”",
} }
# 实验性设置项: JSON完整情景导入
# 预设prompt模式
# 参考值旧版本方式default | 完整情景full_scenario
preset_mode = "default"
# 群内响应规则 # 群内响应规则
# 符合此消息的群内消息即使不包含at机器人也会响应 # 符合此消息的群内消息即使不包含at机器人也会响应
# 支持消息前缀匹配及正则表达式匹配 # 支持消息前缀匹配及正则表达式匹配

View File

@@ -182,6 +182,7 @@ def main(first_time_init=False):
import pkg.openai.dprompt import pkg.openai.dprompt
pkg.openai.dprompt.read_prompt_from_file() pkg.openai.dprompt.read_prompt_from_file()
pkg.openai.dprompt.read_scenario_from_file()
pkg.utils.context.context['logger_handler'] = sh pkg.utils.context.context['logger_handler'] = sh
# 主启动流程 # 主启动流程
@@ -337,6 +338,10 @@ if __name__ == '__main__':
if not os.path.exists("sensitive.json"): if not os.path.exists("sensitive.json"):
shutil.copy("sensitive-template.json", "sensitive.json") shutil.copy("sensitive-template.json", "sensitive.json")
# 检查是否有scenario/default.json
if not os.path.exists("scenario/default.json"):
shutil.copy("scenario/default-template.json", "scenario/default.json")
# 检查temp目录 # 检查temp目录
if not os.path.exists("temp/"): if not os.path.exists("temp/"):
os.mkdir("temp/") os.mkdir("temp/")

View File

@@ -53,10 +53,23 @@ class DatabaseManager:
`create_timestamp` bigint not null, `create_timestamp` bigint not null,
`last_interact_timestamp` bigint not null, `last_interact_timestamp` bigint not null,
`status` varchar(255) not null default 'on_going', `status` varchar(255) not null default 'on_going',
`default_prompt` text not null default '',
`prompt` text not null `prompt` text not null
) )
""") """)
# 检查sessions表是否存在`default_prompt`字段
self.__execute__("PRAGMA table_info('sessions')")
columns = self.cursor.fetchall()
has_default_prompt = False
for field in columns:
if field[1] == 'default_prompt':
has_default_prompt = True
break
if not has_default_prompt:
self.__execute__("alter table `sessions` add column `default_prompt` text not null default ''")
self.__execute__(""" self.__execute__("""
create table if not exists `account_fee`( create table if not exists `account_fee`(
`id` INTEGER PRIMARY KEY AUTOINCREMENT, `id` INTEGER PRIMARY KEY AUTOINCREMENT,
@@ -76,7 +89,7 @@ class DatabaseManager:
# session持久化 # session持久化
def persistence_session(self, subject_type: str, subject_number: int, create_timestamp: int, def persistence_session(self, subject_type: str, subject_number: int, create_timestamp: int,
last_interact_timestamp: int, prompt: str): last_interact_timestamp: int, prompt: str, default_prompt: str = ''):
"""持久化指定session""" """持久化指定session"""
# 检查是否已经有了此name和create_timestamp的session # 检查是否已经有了此name和create_timestamp的session
@@ -89,13 +102,13 @@ class DatabaseManager:
if count == 0: if count == 0:
sql = """ sql = """
insert into `sessions` (`name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`) insert into `sessions` (`name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `default_prompt`)
values (?, ?, ?, ?, ?, ?) values (?, ?, ?, ?, ?, ?, ?)
""" """
self.__execute__(sql, self.__execute__(sql,
("{}_{}".format(subject_type, subject_number), subject_type, subject_number, create_timestamp, ("{}_{}".format(subject_type, subject_number), subject_type, subject_number, create_timestamp,
last_interact_timestamp, prompt)) last_interact_timestamp, prompt, default_prompt))
else: else:
sql = """ sql = """
update `sessions` set `last_interact_timestamp` = ?, `prompt` = ? update `sessions` set `last_interact_timestamp` = ?, `prompt` = ?
@@ -127,7 +140,7 @@ class DatabaseManager:
# 从数据库中加载所有还没过期的session # 从数据库中加载所有还没过期的session
config = pkg.utils.context.get_config() config = pkg.utils.context.get_config()
self.__execute__(""" self.__execute__("""
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status` select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`
from `sessions` where `last_interact_timestamp` > {} from `sessions` where `last_interact_timestamp` > {}
""".format(int(time.time()) - config.session_expire_time)) """.format(int(time.time()) - config.session_expire_time))
results = self.cursor.fetchall() results = self.cursor.fetchall()
@@ -140,6 +153,7 @@ class DatabaseManager:
last_interact_timestamp = result[4] last_interact_timestamp = result[4]
prompt = result[5] prompt = result[5]
status = result[6] status = result[6]
default_prompt = result[7]
# 当且仅当最后一个该对象的会话是on_going状态时才会被加载 # 当且仅当最后一个该对象的会话是on_going状态时才会被加载
if status == 'on_going': if status == 'on_going':
@@ -148,7 +162,8 @@ class DatabaseManager:
'subject_number': subject_number, 'subject_number': subject_number,
'create_timestamp': create_timestamp, 'create_timestamp': create_timestamp,
'last_interact_timestamp': last_interact_timestamp, 'last_interact_timestamp': last_interact_timestamp,
'prompt': prompt 'prompt': prompt,
'default_prompt': default_prompt
} }
else: else:
if session_name in sessions: if session_name in sessions:
@@ -160,7 +175,7 @@ class DatabaseManager:
def last_session(self, session_name: str, cursor_timestamp: int): def last_session(self, session_name: str, cursor_timestamp: int):
self.__execute__(""" self.__execute__("""
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status` select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`
from `sessions` where `name` = '{}' and `last_interact_timestamp` < {} order by `last_interact_timestamp` desc from `sessions` where `name` = '{}' and `last_interact_timestamp` < {} order by `last_interact_timestamp` desc
limit 1 limit 1
""".format(session_name, cursor_timestamp)) """.format(session_name, cursor_timestamp))
@@ -176,20 +191,22 @@ class DatabaseManager:
last_interact_timestamp = result[4] last_interact_timestamp = result[4]
prompt = result[5] prompt = result[5]
status = result[6] status = result[6]
default_prompt = result[7]
return { return {
'subject_type': subject_type, 'subject_type': subject_type,
'subject_number': subject_number, 'subject_number': subject_number,
'create_timestamp': create_timestamp, 'create_timestamp': create_timestamp,
'last_interact_timestamp': last_interact_timestamp, 'last_interact_timestamp': last_interact_timestamp,
'prompt': prompt 'prompt': prompt,
'default_prompt': default_prompt
} }
# 获取此session_name后一个session的数据 # 获取此session_name后一个session的数据
def next_session(self, session_name: str, cursor_timestamp: int): def next_session(self, session_name: str, cursor_timestamp: int):
self.__execute__(""" self.__execute__("""
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status` select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`
from `sessions` where `name` = '{}' and `last_interact_timestamp` > {} order by `last_interact_timestamp` asc from `sessions` where `name` = '{}' and `last_interact_timestamp` > {} order by `last_interact_timestamp` asc
limit 1 limit 1
""".format(session_name, cursor_timestamp)) """.format(session_name, cursor_timestamp))
@@ -205,19 +222,21 @@ class DatabaseManager:
last_interact_timestamp = result[4] last_interact_timestamp = result[4]
prompt = result[5] prompt = result[5]
status = result[6] status = result[6]
default_prompt = result[7]
return { return {
'subject_type': subject_type, 'subject_type': subject_type,
'subject_number': subject_number, 'subject_number': subject_number,
'create_timestamp': create_timestamp, 'create_timestamp': create_timestamp,
'last_interact_timestamp': last_interact_timestamp, 'last_interact_timestamp': last_interact_timestamp,
'prompt': prompt 'prompt': prompt,
'default_prompt': default_prompt
} }
# 列出与某个对象的所有对话session # 列出与某个对象的所有对话session
def list_history(self, session_name: str, capacity: int, page: int): def list_history(self, session_name: str, capacity: int, page: int):
self.__execute__(""" self.__execute__("""
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status` select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`
from `sessions` where `name` = '{}' order by `last_interact_timestamp` desc limit {} offset {} from `sessions` where `name` = '{}' order by `last_interact_timestamp` desc limit {} offset {}
""".format(session_name, capacity, capacity * page)) """.format(session_name, capacity, capacity * page))
results = self.cursor.fetchall() results = self.cursor.fetchall()
@@ -230,13 +249,15 @@ class DatabaseManager:
last_interact_timestamp = result[4] last_interact_timestamp = result[4]
prompt = result[5] prompt = result[5]
status = result[6] status = result[6]
default_prompt = result[7]
sessions.append({ sessions.append({
'subject_type': subject_type, 'subject_type': subject_type,
'subject_number': subject_number, 'subject_number': subject_number,
'create_timestamp': create_timestamp, 'create_timestamp': create_timestamp,
'last_interact_timestamp': last_interact_timestamp, 'last_interact_timestamp': last_interact_timestamp,
'prompt': prompt 'prompt': prompt,
'default_prompt': default_prompt
}) })
return sessions return sessions

View File

@@ -1,4 +1,6 @@
# 多情景预设值管理 # 多情景预设值管理
import json
import logging
__current__ = "default" __current__ = "default"
"""当前默认使用的情景预设的名称 """当前默认使用的情景预设的名称
@@ -9,8 +11,10 @@ __current__ = "default"
__prompts_from_files__ = {} __prompts_from_files__ = {}
"""从文件中读取的情景预设值""" """从文件中读取的情景预设值"""
__scenario_from_files__ = {}
def read_prompt_from_file() -> str:
def read_prompt_from_file():
"""从文件读取预设值""" """从文件读取预设值"""
# 读取prompts/目录下的所有文件,以文件名为键,文件内容为值 # 读取prompts/目录下的所有文件,以文件名为键,文件内容为值
# 保存在__prompts_from_files__中 # 保存在__prompts_from_files__中
@@ -23,6 +27,19 @@ def read_prompt_from_file() -> str:
__prompts_from_files__[file] = f.read() __prompts_from_files__[file] = f.read()
def read_scenario_from_file():
"""从JSON文件读取情景预设"""
global __scenario_from_files__
import os
__scenario_from_files__ = {}
for file in os.listdir("scenario"):
if file == "default-template.json":
continue
with open(os.path.join("scenario", file), encoding="utf-8") as f:
__scenario_from_files__[file] = json.load(f)
def get_prompt_dict() -> dict: def get_prompt_dict() -> dict:
"""获取预设值字典""" """获取预设值字典"""
import config import config
@@ -65,15 +82,40 @@ def set_to_default():
__current__ = list(default_dict.keys())[0] __current__ = list(default_dict.keys())[0]
def get_prompt(name: str = None) -> str: def get_prompt(name: str = None) -> list:
global __scenario_from_files__
import config
preset_mode = config.preset_mode
"""获取预设值""" """获取预设值"""
if name is None: if name is None:
name = get_current() name = get_current()
# JSON预设方式
if preset_mode == 'full_scenario':
import os
for key in __scenario_from_files__:
if key.lower().startswith(name.lower()):
logging.debug('成功加载情景预设从JSON文件: {}'.format(key))
return __scenario_from_files__[key]['prompt']
# 默认预设方式
elif preset_mode == 'default':
default_dict = get_prompt_dict() default_dict = get_prompt_dict()
for key in default_dict: for key in default_dict:
if key.lower().startswith(name.lower()): if key.lower().startswith(name.lower()):
return default_dict[key] return [
{
"role": "user",
"content": default_dict[key]
},
{
"role": "assistant",
"content": "好的。"
}
]
raise KeyError("未找到情景预设: " + name) raise KeyError("未找到默认情景预设: " + name)

View File

@@ -75,6 +75,8 @@ def load_sessions():
except Exception: except Exception:
temp_session.prompt = reset_session_prompt(session_name, session_data[session_name]['prompt']) temp_session.prompt = reset_session_prompt(session_name, session_data[session_name]['prompt'])
temp_session.persistence() temp_session.persistence()
temp_session.default_prompt = json.loads(session_data[session_name]['default_prompt']) if \
session_data[session_name]['default_prompt'] else []
sessions[session_name] = temp_session sessions[session_name] = temp_session
@@ -104,6 +106,9 @@ class Session:
prompt = [] prompt = []
"""使用list来保存会话中的回合""" """使用list来保存会话中的回合"""
default_prompt = []
"""本session的默认prompt"""
create_timestamp = 0 create_timestamp = 0
"""会话创建时间""" """会话创建时间"""
@@ -129,24 +134,13 @@ class Session:
# 从配置文件获取会话预设信息 # 从配置文件获取会话预设信息
def get_default_prompt(self, use_default: str = None): def get_default_prompt(self, use_default: str = None):
config = pkg.utils.context.get_config()
import pkg.openai.dprompt as dprompt import pkg.openai.dprompt as dprompt
if use_default is None: if use_default is None:
current_default_prompt = dprompt.get_prompt(dprompt.get_current()) use_default = dprompt.get_current()
else:
current_default_prompt = dprompt.get_prompt(use_default)
return [ current_default_prompt = dprompt.get_prompt(use_default)
{ return current_default_prompt
'role': 'user',
'content': current_default_prompt
}, {
'role': 'assistant',
'content': 'ok'
}
]
def __init__(self, name: str): def __init__(self, name: str):
self.name = name self.name = name
@@ -155,7 +149,9 @@ class Session:
self.schedule() self.schedule()
self.response_lock = threading.Lock() self.response_lock = threading.Lock()
self.prompt = self.get_default_prompt()
self.default_prompt = self.get_default_prompt()
logging.debug("prompt is: {}".format(self.default_prompt))
# 设定检查session最后一次对话是否超过过期时间的计时器 # 设定检查session最后一次对话是否超过过期时间的计时器
def schedule(self): def schedule(self):
@@ -199,11 +195,11 @@ class Session:
self.last_interact_timestamp = int(time.time()) self.last_interact_timestamp = int(time.time())
# 触发插件事件 # 触发插件事件
if self.prompt == self.get_default_prompt(): if not self.prompt:
args = { args = {
'session_name': self.name, 'session_name': self.name,
'session': self, 'session': self,
'default_prompt': self.prompt, 'default_prompt': self.default_prompt,
} }
event = pkg.plugin.host.emit(plugin_models.SessionFirstMessageReceived, **args) event = pkg.plugin.host.emit(plugin_models.SessionFirstMessageReceived, **args)
@@ -256,25 +252,29 @@ class Session:
def cut_out(self, msg: str, max_tokens: int) -> list: def cut_out(self, msg: str, max_tokens: int) -> list:
"""将现有prompt进行切割处理使得新的prompt长度不超过max_tokens""" """将现有prompt进行切割处理使得新的prompt长度不超过max_tokens"""
# 如果用户消息长度超过max_tokens直接返回 # 如果用户消息长度超过max_tokens直接返回
temp_prompt: list = []
temp_prompt = [ temp_prompt += self.default_prompt
temp_prompt.append(
{ {
'role': 'user', 'role': 'user',
'content': msg 'content': msg
} }
] )
token_count = 0
for item in temp_prompt:
token_count += len(item['content'])
token_count = len(msg)
# 倒序遍历prompt # 倒序遍历prompt
for i in range(len(self.prompt) - 1, -1, -1): for i in range(len(self.prompt) - 1, -1, -1):
if token_count >= max_tokens: if token_count >= max_tokens:
break break
# 将prompt加到temp_prompt头部 # 将prompt加到temp_prompt倒数第二个位置
temp_prompt.insert(0, self.prompt[i]) temp_prompt.insert(len(self.default_prompt), self.prompt[i])
token_count += len(self.prompt[i]['content']) token_count += len(self.prompt[i]['content'])
logging.debug('cut_out: {}'.format(str(temp_prompt))) logging.debug('cut_out: {}'.format(json.dumps(temp_prompt, ensure_ascii=False, indent=4)))
return temp_prompt return temp_prompt
@@ -291,11 +291,11 @@ class Session:
subject_number = int(name_spt[1]) subject_number = int(name_spt[1])
db_inst.persistence_session(subject_type, subject_number, self.create_timestamp, self.last_interact_timestamp, db_inst.persistence_session(subject_type, subject_number, self.create_timestamp, self.last_interact_timestamp,
json.dumps(self.prompt)) json.dumps(self.prompt), json.dumps(self.default_prompt))
# 重置session # 重置session
def reset(self, explicit: bool = False, expired: bool = False, schedule_new: bool = True, use_prompt: str = None): def reset(self, explicit: bool = False, expired: bool = False, schedule_new: bool = True, use_prompt: str = None):
if self.prompt[-1]['role'] != "system": if self.prompt:
self.persistence() self.persistence()
if explicit: if explicit:
# 触发插件事件 # 触发插件事件
@@ -311,7 +311,9 @@ class Session:
if expired: if expired:
pkg.utils.context.get_database_manager().set_session_expired(self.name, self.create_timestamp) pkg.utils.context.get_database_manager().set_session_expired(self.name, self.create_timestamp)
self.prompt = self.get_default_prompt(use_prompt)
self.default_prompt = self.get_default_prompt(use_prompt)
self.prompt = []
self.create_timestamp = int(time.time()) self.create_timestamp = int(time.time())
self.last_interact_timestamp = int(time.time()) self.last_interact_timestamp = int(time.time())
self.just_switched_to_exist_session = False self.just_switched_to_exist_session = False
@@ -340,6 +342,7 @@ class Session:
except json.decoder.JSONDecodeError: except json.decoder.JSONDecodeError:
self.prompt = reset_session_prompt(self.name, last_one['prompt']) self.prompt = reset_session_prompt(self.name, last_one['prompt'])
self.persistence() self.persistence()
self.default_prompt = json.loads(last_one['default_prompt']) if last_one['default_prompt'] else []
self.just_switched_to_exist_session = True self.just_switched_to_exist_session = True
return self return self
@@ -359,6 +362,7 @@ class Session:
except json.decoder.JSONDecodeError: except json.decoder.JSONDecodeError:
self.prompt = reset_session_prompt(self.name, next_one['prompt']) self.prompt = reset_session_prompt(self.name, next_one['prompt'])
self.persistence() self.persistence()
self.default_prompt = json.loads(next_one['default_prompt']) if next_one['default_prompt'] else []
self.just_switched_to_exist_session = True self.just_switched_to_exist_session = True
return self return self

View File

@@ -234,7 +234,7 @@ def process_command(session_name: str, text_message: str, mgr, config,
if len(msg) >= 2: if len(msg) >= 2:
reply_str += "#{} 创建:{} {}\n".format(i + page * 10, reply_str += "#{} 创建:{} {}\n".format(i + page * 10,
datetime_obj.strftime("%Y-%m-%d %H:%M:%S"), datetime_obj.strftime("%Y-%m-%d %H:%M:%S"),
msg[1]['content']) msg[0]['content'])
else: else:
reply_str += "#{} 创建:{} {}\n".format(i + page * 10, reply_str += "#{} 创建:{} {}\n".format(i + page * 10,
datetime_obj.strftime("%Y-%m-%d %H:%M:%S"), datetime_obj.strftime("%Y-%m-%d %H:%M:%S"),

View File

@@ -0,0 +1,12 @@
{
"prompt": [
{
"role": "system",
"content": "You are a helpful assistant. 如果我需要帮助,你要说“输入!help获得帮助”"
},
{
"role": "assistant",
"content": "好的我是一个能干的AI助手。 如果你需要帮助,我会说“输入!help获得帮助”"
}
]
}