mirror of
https://github.com/langbot-app/LangBot.git
synced 2025-11-25 11:29:39 +08:00
feat: 完善会话处理的逻辑
This commit is contained in:
@@ -32,12 +32,13 @@ class OpenAIInteract:
|
||||
pkg.utils.context.set_openai_manager(self)
|
||||
|
||||
# 请求OpenAI Completion
|
||||
def request_completion(self, messages):
|
||||
def request_completion(self, prompts):
|
||||
config = pkg.utils.context.get_config()
|
||||
|
||||
ai:ModelRequest = create_openai_model_request(config.completion_api_params['model'], 'user')
|
||||
# 根据模型选择使用的接口
|
||||
ai: ModelRequest = create_openai_model_request(config.completion_api_params['model'], 'user')
|
||||
ai.request(
|
||||
messages,
|
||||
prompts,
|
||||
**config.completion_api_params
|
||||
)
|
||||
response = ai.get_response()
|
||||
|
||||
@@ -26,7 +26,7 @@ IMAGE_MODELS = {
|
||||
|
||||
|
||||
class ModelRequest():
|
||||
|
||||
"""模型请求抽象类"""
|
||||
can_chat = False
|
||||
|
||||
def __init__(self, model_name, user_name, request_fun):
|
||||
@@ -39,7 +39,8 @@ class ModelRequest():
|
||||
self.ret = self.ret_handle(ret)
|
||||
self.message = self.ret["choices"][0]["message"]
|
||||
|
||||
def msg_handle(self, msg):
|
||||
def __msg_handle__(self, msg):
|
||||
"""将prompt dict转换成接口需要的格式"""
|
||||
return msg
|
||||
|
||||
def ret_handle(self, ret):
|
||||
@@ -63,17 +64,16 @@ class ChatCompletionModel(ModelRequest):
|
||||
self.can_chat = True
|
||||
super().__init__(model_name, user_name, request_fun)
|
||||
|
||||
def request(self, messages, **kwargs):
|
||||
ret = self.request_fun(messages = self.msg_handle(messages), **kwargs, user=self.user_name)
|
||||
def request(self, prompts, **kwargs):
|
||||
ret = self.request_fun(messages = self.__msg_handle__(prompts), **kwargs, user=self.user_name)
|
||||
self.ret = self.ret_handle(ret)
|
||||
self.message = self.ret["choices"][0]["message"]['content']
|
||||
|
||||
def msg_handle(self, msgs):
|
||||
def __msg_handle__(self, msgs):
|
||||
temp_msgs = []
|
||||
# 把msgs拷贝进temp_msgs
|
||||
for msg in msgs:
|
||||
if msg['role'] not in self.Chat_role:
|
||||
msg['role'] = 'user'
|
||||
temp_msgs.append(msg)
|
||||
temp_msgs.append(msg.copy())
|
||||
return temp_msgs
|
||||
|
||||
def get_content(self):
|
||||
@@ -86,18 +86,21 @@ class CompletionModel(ModelRequest):
|
||||
request_fun = openai.Completion.create
|
||||
super().__init__(model_name, user_name, request_fun)
|
||||
|
||||
def request(self, prompt, **kwargs):
|
||||
ret = self.request_fun(prompt = self.msg_handle(prompt), **kwargs)
|
||||
def request(self, prompts, **kwargs):
|
||||
ret = self.request_fun(prompt = self.__msg_handle__(prompts), **kwargs)
|
||||
self.ret = self.ret_handle(ret)
|
||||
self.message = self.ret["choices"][0]["text"]
|
||||
|
||||
def msg_handle(self, msgs):
|
||||
def __msg_handle__(self, msgs):
|
||||
prompt = ''
|
||||
for msg in msgs:
|
||||
if msg['role'] == 'assistant':
|
||||
prompt = prompt + "{}\n".format(msg['content'])
|
||||
else:
|
||||
prompt = prompt + "{}:{}\n".format(msg['role'] , msg['content'])
|
||||
prompt = prompt + "{}: {}\n".format(msg['role'], msg['content'])
|
||||
# for msg in msgs:
|
||||
# if msg['role'] == 'assistant':
|
||||
# prompt = prompt + "{}\n".format(msg['content'])
|
||||
# else:
|
||||
# prompt = prompt + "{}:{}\n".format(msg['role'] , msg['content'])
|
||||
prompt = prompt + "assistant: "
|
||||
return prompt
|
||||
|
||||
def get_text(self):
|
||||
|
||||
@@ -126,10 +126,15 @@ class Session:
|
||||
else:
|
||||
current_default_prompt = dprompt.get_prompt(use_default)
|
||||
|
||||
return [{
|
||||
'role': 'system',
|
||||
'content': current_default_prompt
|
||||
}]
|
||||
return [
|
||||
{
|
||||
'role': 'user',
|
||||
'content': current_default_prompt
|
||||
},{
|
||||
'role': 'assistant',
|
||||
'content': 'ok'
|
||||
}
|
||||
]
|
||||
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
@@ -195,12 +200,11 @@ class Session:
|
||||
max_length = config.prompt_submit_length if hasattr(config, "prompt_submit_length") else 1024
|
||||
|
||||
# 向API请求补全
|
||||
self.cut_out(text, max_length)
|
||||
message = pkg.utils.context.get_openai_manager().request_completion(
|
||||
self.prompt
|
||||
self.cut_out(text, max_length),
|
||||
)
|
||||
|
||||
# 处理回复
|
||||
# 成功获取,处理回复
|
||||
res_test = message
|
||||
res_ans = res_test
|
||||
|
||||
@@ -210,6 +214,8 @@ class Session:
|
||||
del (res_ans_spt[0])
|
||||
res_ans = '\n\n'.join(res_ans_spt)
|
||||
|
||||
# 将此次对话的双方内容加入到prompt中
|
||||
self.prompt.append({'role':'user', 'content':text})
|
||||
self.prompt.append({'role':'assistant', 'content':res_ans})
|
||||
|
||||
if self.just_switched_to_exist_session:
|
||||
@@ -234,17 +240,30 @@ class Session:
|
||||
return res
|
||||
|
||||
# 构建对话体
|
||||
def cut_out(self, msg: str, max_tokens: int) -> str:
|
||||
def cut_out(self, msg: str, max_tokens: int) -> list:
|
||||
"""将现有prompt进行切割处理,使得新的prompt长度不超过max_tokens"""
|
||||
# 如果用户消息长度超过max_tokens,直接返回
|
||||
|
||||
if len(msg) > max_tokens:
|
||||
msg = msg[:max_tokens]
|
||||
temp_prompt = [
|
||||
{
|
||||
'role': 'user',
|
||||
'content': msg
|
||||
}
|
||||
]
|
||||
|
||||
self.prompt.append({
|
||||
'role': 'user',
|
||||
'content': msg
|
||||
})
|
||||
token_count = len(msg)
|
||||
# 倒序遍历prompt
|
||||
for i in range(len(self.prompt) - 1, -1, -1):
|
||||
if token_count >= max_tokens:
|
||||
break
|
||||
|
||||
logging.debug('cut_out: {}'.format(msg))
|
||||
# 将prompt加到temp_prompt头部
|
||||
temp_prompt.insert(0, self.prompt[i])
|
||||
token_count += len(self.prompt[i]['content'])
|
||||
|
||||
logging.debug('cut_out: {}'.format(str(temp_prompt)))
|
||||
|
||||
return temp_prompt
|
||||
|
||||
# 持久化session
|
||||
def persistence(self):
|
||||
|
||||
Reference in New Issue
Block a user