mirror of
https://github.com/langbot-app/LangBot.git
synced 2025-11-25 19:37:36 +08:00
feat: 实现新的前文剪切模式
This commit is contained in:
@@ -55,7 +55,7 @@ class DatabaseManager:
|
||||
`status` varchar(255) not null default 'on_going',
|
||||
`default_prompt` text not null default '',
|
||||
`prompt` text not null,
|
||||
`token_counts` text not null default '[]',
|
||||
`token_counts` text not null default '[]'
|
||||
)
|
||||
""")
|
||||
|
||||
@@ -96,7 +96,7 @@ class DatabaseManager:
|
||||
|
||||
# session持久化
|
||||
def persistence_session(self, subject_type: str, subject_number: int, create_timestamp: int,
|
||||
last_interact_timestamp: int, prompt: str, default_prompt: str = '', token_counts: list = []):
|
||||
last_interact_timestamp: int, prompt: str, default_prompt: str = '', token_counts: str = ''):
|
||||
"""持久化指定session"""
|
||||
|
||||
# 检查是否已经有了此name和create_timestamp的session
|
||||
@@ -115,14 +115,14 @@ class DatabaseManager:
|
||||
|
||||
self.__execute__(sql,
|
||||
("{}_{}".format(subject_type, subject_number), subject_type, subject_number, create_timestamp,
|
||||
last_interact_timestamp, prompt, default_prompt, json.dumps(token_counts)))
|
||||
last_interact_timestamp, prompt, default_prompt, token_counts))
|
||||
else:
|
||||
sql = """
|
||||
update `sessions` set `last_interact_timestamp` = ?, `prompt` = ?, `token_counts` = ?
|
||||
where `type` = ? and `number` = ? and `create_timestamp` = ?
|
||||
"""
|
||||
|
||||
self.__execute__(sql, (last_interact_timestamp, prompt, json.dumps(token_counts), subject_type,
|
||||
self.__execute__(sql, (last_interact_timestamp, prompt, token_counts, subject_type,
|
||||
subject_number, create_timestamp))
|
||||
|
||||
# 显式关闭一个session
|
||||
@@ -172,7 +172,7 @@ class DatabaseManager:
|
||||
'last_interact_timestamp': last_interact_timestamp,
|
||||
'prompt': prompt,
|
||||
'default_prompt': default_prompt,
|
||||
'token_counts': json.loads(token_counts)
|
||||
'token_counts': token_counts
|
||||
}
|
||||
else:
|
||||
if session_name in sessions:
|
||||
@@ -210,7 +210,7 @@ class DatabaseManager:
|
||||
'last_interact_timestamp': last_interact_timestamp,
|
||||
'prompt': prompt,
|
||||
'default_prompt': default_prompt,
|
||||
'token_counts': json.loads(token_counts)
|
||||
'token_counts': token_counts
|
||||
}
|
||||
|
||||
# 获取此session_name后一个session的数据
|
||||
@@ -243,7 +243,7 @@ class DatabaseManager:
|
||||
'last_interact_timestamp': last_interact_timestamp,
|
||||
'prompt': prompt,
|
||||
'default_prompt': default_prompt,
|
||||
'token_counts': json.loads(token_counts)
|
||||
'token_counts': token_counts
|
||||
}
|
||||
|
||||
# 列出与某个对象的所有对话session
|
||||
@@ -272,7 +272,7 @@ class DatabaseManager:
|
||||
'last_interact_timestamp': last_interact_timestamp,
|
||||
'prompt': prompt,
|
||||
'default_prompt': default_prompt,
|
||||
'token_counts': json.loads(token_counts)
|
||||
'token_counts': token_counts
|
||||
})
|
||||
|
||||
return sessions
|
||||
|
||||
@@ -34,7 +34,7 @@ class OpenAIInteract:
|
||||
pkg.utils.context.set_openai_manager(self)
|
||||
|
||||
# 请求OpenAI Completion
|
||||
def request_completion(self, prompts) -> str:
|
||||
def request_completion(self, prompts) -> tuple[str, int]:
|
||||
"""请求补全接口回复
|
||||
|
||||
Parameters:
|
||||
@@ -60,14 +60,18 @@ class OpenAIInteract:
|
||||
|
||||
logging.debug("OpenAI response: %s", response)
|
||||
|
||||
# 记录使用量
|
||||
current_round_token = 0
|
||||
if 'model' in config.completion_api_params:
|
||||
self.audit_mgr.report_text_model_usage(config.completion_api_params['model'],
|
||||
ai.get_total_tokens())
|
||||
current_round_token = ai.get_total_tokens()
|
||||
elif 'engine' in config.completion_api_params:
|
||||
self.audit_mgr.report_text_model_usage(config.completion_api_params['engine'],
|
||||
response['usage']['total_tokens'])
|
||||
current_round_token = response['usage']['total_tokens']
|
||||
|
||||
return ai.get_message()
|
||||
return ai.get_message(), current_round_token
|
||||
|
||||
def request_image(self, prompt) -> dict:
|
||||
"""请求图片接口回复
|
||||
|
||||
@@ -72,6 +72,7 @@ def load_sessions():
|
||||
temp_session.last_interact_timestamp = session_data[session_name]['last_interact_timestamp']
|
||||
try:
|
||||
temp_session.prompt = json.loads(session_data[session_name]['prompt'])
|
||||
temp_session.token_counts = json.loads(session_data[session_name]['token_counts'])
|
||||
except Exception:
|
||||
temp_session.prompt = reset_session_prompt(session_name, session_data[session_name]['prompt'])
|
||||
temp_session.persistence()
|
||||
@@ -106,6 +107,9 @@ class Session:
|
||||
prompt = []
|
||||
"""使用list来保存会话中的回合"""
|
||||
|
||||
token_counts = []
|
||||
"""每个回合的token数量"""
|
||||
|
||||
default_prompt = []
|
||||
"""本session的默认prompt"""
|
||||
|
||||
@@ -146,6 +150,8 @@ class Session:
|
||||
self.name = name
|
||||
self.create_timestamp = int(time.time())
|
||||
self.last_interact_timestamp = int(time.time())
|
||||
self.prompt = []
|
||||
self.token_counts = []
|
||||
self.schedule()
|
||||
|
||||
self.response_lock = threading.Lock()
|
||||
@@ -209,9 +215,16 @@ class Session:
|
||||
config = pkg.utils.context.get_config()
|
||||
max_length = config.prompt_submit_length if hasattr(config, "prompt_submit_length") else 1024
|
||||
|
||||
prompts, counts = self.cut_out(text, max_length)
|
||||
|
||||
# 计算请求前的prompt数量
|
||||
total_token_before_query = 0
|
||||
for token_count in counts:
|
||||
total_token_before_query += token_count
|
||||
|
||||
# 向API请求补全
|
||||
message = pkg.utils.context.get_openai_manager().request_completion(
|
||||
self.cut_out(text, max_length),
|
||||
message, total_token = pkg.utils.context.get_openai_manager().request_completion(
|
||||
prompts,
|
||||
)
|
||||
|
||||
# 成功获取,处理回复
|
||||
@@ -228,6 +241,10 @@ class Session:
|
||||
self.prompt.append({'role': 'user', 'content': text})
|
||||
self.prompt.append({'role': 'assistant', 'content': res_ans})
|
||||
|
||||
# 向token_counts中添加本回合的token数量
|
||||
self.token_counts.append(total_token-total_token_before_query)
|
||||
logging.debug("本回合使用token: {}, session counts: {}".format(total_token-total_token_before_query, self.token_counts))
|
||||
|
||||
if self.just_switched_to_exist_session:
|
||||
self.just_switched_to_exist_session = False
|
||||
self.set_ongoing()
|
||||
@@ -244,39 +261,65 @@ class Session:
|
||||
|
||||
question = self.prompt[-2]['content']
|
||||
self.prompt = self.prompt[:-2]
|
||||
self.token_counts = self.token_counts[:-1]
|
||||
|
||||
# 返回上一回合的问题
|
||||
return question
|
||||
|
||||
# 构建对话体
|
||||
def cut_out(self, msg: str, max_tokens: int) -> list:
|
||||
"""将现有prompt进行切割处理,使得新的prompt长度不超过max_tokens"""
|
||||
# 如果用户消息长度超过max_tokens,直接返回
|
||||
temp_prompt: list = []
|
||||
temp_prompt += self.default_prompt
|
||||
temp_prompt.append(
|
||||
def cut_out(self, msg: str, max_tokens: int) -> tuple[list, list]:
|
||||
"""将现有prompt进行切割处理,使得新的prompt长度不超过max_tokens
|
||||
|
||||
:return: (新的prompt, 新的token_counts)
|
||||
"""
|
||||
|
||||
# 最终由三个部分组成
|
||||
# - default_prompt 情景预设固定值
|
||||
# - changable_prompts 可变部分, 此会话中的历史对话回合
|
||||
# - current_question 当前问题
|
||||
|
||||
# 包装目前的对话回合内容
|
||||
changable_prompts = []
|
||||
changable_counts = []
|
||||
# 倒着来, 遍历prompt的步长为2, 遍历tokens_counts的步长为1
|
||||
changable_index = len(self.prompt) - 1
|
||||
token_count_index = len(self.token_counts) - 1
|
||||
|
||||
packed_tokens = 0
|
||||
|
||||
print(self.prompt)
|
||||
|
||||
while changable_index >= 0 and token_count_index >= 0:
|
||||
if packed_tokens + self.token_counts[token_count_index] > max_tokens:
|
||||
break
|
||||
|
||||
changable_prompts.insert(0, self.prompt[changable_index])
|
||||
changable_prompts.insert(0, self.prompt[changable_index - 1])
|
||||
changable_counts.insert(0, self.token_counts[token_count_index])
|
||||
packed_tokens += self.token_counts[token_count_index]
|
||||
|
||||
changable_index -= 2
|
||||
token_count_index -= 1
|
||||
|
||||
# 将default_prompt和changable_prompts合并
|
||||
result_prompt = self.default_prompt + changable_prompts
|
||||
|
||||
print(changable_prompts)
|
||||
|
||||
# 添加当前问题
|
||||
result_prompt.append(
|
||||
{
|
||||
'role': 'user',
|
||||
'content': msg
|
||||
}
|
||||
)
|
||||
|
||||
token_count = 0
|
||||
for item in temp_prompt:
|
||||
token_count += len(item['content'])
|
||||
logging.debug('cut_out: {}\nchangable section tokens: {}\npacked counts: {}\nsession counts: {}'.format(json.dumps(result_prompt, ensure_ascii=False, indent=4),
|
||||
packed_tokens,
|
||||
changable_counts,
|
||||
self.token_counts))
|
||||
|
||||
# 倒序遍历prompt
|
||||
for i in range(len(self.prompt) - 1, -1, -1):
|
||||
if token_count >= max_tokens:
|
||||
break
|
||||
|
||||
# 将prompt加到temp_prompt倒数第二个位置
|
||||
temp_prompt.insert(len(self.default_prompt), self.prompt[i])
|
||||
token_count += len(self.prompt[i]['content'])
|
||||
|
||||
logging.debug('cut_out: {}'.format(json.dumps(temp_prompt, ensure_ascii=False, indent=4)))
|
||||
|
||||
return temp_prompt
|
||||
return result_prompt, changable_counts
|
||||
|
||||
# 持久化session
|
||||
def persistence(self):
|
||||
@@ -291,7 +334,7 @@ class Session:
|
||||
subject_number = int(name_spt[1])
|
||||
|
||||
db_inst.persistence_session(subject_type, subject_number, self.create_timestamp, self.last_interact_timestamp,
|
||||
json.dumps(self.prompt), json.dumps(self.default_prompt))
|
||||
json.dumps(self.prompt), json.dumps(self.default_prompt), json.dumps(self.token_counts))
|
||||
|
||||
# 重置session
|
||||
def reset(self, explicit: bool = False, expired: bool = False, schedule_new: bool = True, use_prompt: str = None):
|
||||
@@ -314,6 +357,7 @@ class Session:
|
||||
|
||||
self.default_prompt = self.get_default_prompt(use_prompt)
|
||||
self.prompt = []
|
||||
self.token_counts = []
|
||||
self.create_timestamp = int(time.time())
|
||||
self.last_interact_timestamp = int(time.time())
|
||||
self.just_switched_to_exist_session = False
|
||||
@@ -339,6 +383,7 @@ class Session:
|
||||
self.last_interact_timestamp = last_one['last_interact_timestamp']
|
||||
try:
|
||||
self.prompt = json.loads(last_one['prompt'])
|
||||
self.token_counts = json.loads(last_one['token_counts'])
|
||||
except json.decoder.JSONDecodeError:
|
||||
self.prompt = reset_session_prompt(self.name, last_one['prompt'])
|
||||
self.persistence()
|
||||
@@ -359,6 +404,7 @@ class Session:
|
||||
self.last_interact_timestamp = next_one['last_interact_timestamp']
|
||||
try:
|
||||
self.prompt = json.loads(next_one['prompt'])
|
||||
self.token_counts = json.loads(next_one['token_counts'])
|
||||
except json.decoder.JSONDecodeError:
|
||||
self.prompt = reset_session_prompt(self.name, next_one['prompt'])
|
||||
self.persistence()
|
||||
|
||||
Reference in New Issue
Block a user