feat: 完善会话处理的逻辑

This commit is contained in:
Rock Chin
2023-03-03 00:07:53 +08:00
parent 2234e9db0e
commit 8713fd8130
3 changed files with 56 additions and 33 deletions

View File

@@ -32,12 +32,13 @@ class OpenAIInteract:
pkg.utils.context.set_openai_manager(self)
# 请求OpenAI Completion
def request_completion(self, messages):
def request_completion(self, prompts):
config = pkg.utils.context.get_config()
ai:ModelRequest = create_openai_model_request(config.completion_api_params['model'], 'user')
# 根据模型选择使用的接口
ai: ModelRequest = create_openai_model_request(config.completion_api_params['model'], 'user')
ai.request(
messages,
prompts,
**config.completion_api_params
)
response = ai.get_response()

View File

@@ -26,7 +26,7 @@ IMAGE_MODELS = {
class ModelRequest():
"""模型请求抽象类"""
can_chat = False
def __init__(self, model_name, user_name, request_fun):
@@ -39,7 +39,8 @@ class ModelRequest():
self.ret = self.ret_handle(ret)
self.message = self.ret["choices"][0]["message"]
def msg_handle(self, msg):
def __msg_handle__(self, msg):
"""将prompt dict转换成接口需要的格式"""
return msg
def ret_handle(self, ret):
@@ -63,17 +64,16 @@ class ChatCompletionModel(ModelRequest):
self.can_chat = True
super().__init__(model_name, user_name, request_fun)
def request(self, messages, **kwargs):
ret = self.request_fun(messages = self.msg_handle(messages), **kwargs, user=self.user_name)
def request(self, prompts, **kwargs):
ret = self.request_fun(messages = self.__msg_handle__(prompts), **kwargs, user=self.user_name)
self.ret = self.ret_handle(ret)
self.message = self.ret["choices"][0]["message"]['content']
def msg_handle(self, msgs):
def __msg_handle__(self, msgs):
temp_msgs = []
# 把msgs拷贝进temp_msgs
for msg in msgs:
if msg['role'] not in self.Chat_role:
msg['role'] = 'user'
temp_msgs.append(msg)
temp_msgs.append(msg.copy())
return temp_msgs
def get_content(self):
@@ -86,18 +86,21 @@ class CompletionModel(ModelRequest):
request_fun = openai.Completion.create
super().__init__(model_name, user_name, request_fun)
def request(self, prompt, **kwargs):
ret = self.request_fun(prompt = self.msg_handle(prompt), **kwargs)
def request(self, prompts, **kwargs):
ret = self.request_fun(prompt = self.__msg_handle__(prompts), **kwargs)
self.ret = self.ret_handle(ret)
self.message = self.ret["choices"][0]["text"]
def msg_handle(self, msgs):
def __msg_handle__(self, msgs):
prompt = ''
for msg in msgs:
if msg['role'] == 'assistant':
prompt = prompt + "{}\n".format(msg['content'])
else:
prompt = prompt + "{}:{}\n".format(msg['role'] , msg['content'])
prompt = prompt + "{}: {}\n".format(msg['role'], msg['content'])
# for msg in msgs:
# if msg['role'] == 'assistant':
# prompt = prompt + "{}\n".format(msg['content'])
# else:
# prompt = prompt + "{}:{}\n".format(msg['role'] , msg['content'])
prompt = prompt + "assistant: "
return prompt
def get_text(self):

View File

@@ -126,10 +126,15 @@ class Session:
else:
current_default_prompt = dprompt.get_prompt(use_default)
return [{
'role': 'system',
'content': current_default_prompt
}]
return [
{
'role': 'user',
'content': current_default_prompt
},{
'role': 'assistant',
'content': 'ok'
}
]
def __init__(self, name: str):
self.name = name
@@ -195,12 +200,11 @@ class Session:
max_length = config.prompt_submit_length if hasattr(config, "prompt_submit_length") else 1024
# 向API请求补全
self.cut_out(text, max_length)
message = pkg.utils.context.get_openai_manager().request_completion(
self.prompt
self.cut_out(text, max_length),
)
# 处理回复
# 成功获取,处理回复
res_test = message
res_ans = res_test
@@ -210,6 +214,8 @@ class Session:
del (res_ans_spt[0])
res_ans = '\n\n'.join(res_ans_spt)
# 将此次对话的双方内容加入到prompt中
self.prompt.append({'role':'user', 'content':text})
self.prompt.append({'role':'assistant', 'content':res_ans})
if self.just_switched_to_exist_session:
@@ -234,17 +240,30 @@ class Session:
return res
# 构建对话体
def cut_out(self, msg: str, max_tokens: int) -> str:
def cut_out(self, msg: str, max_tokens: int) -> list:
"""将现有prompt进行切割处理使得新的prompt长度不超过max_tokens"""
# 如果用户消息长度超过max_tokens直接返回
if len(msg) > max_tokens:
msg = msg[:max_tokens]
temp_prompt = [
{
'role': 'user',
'content': msg
}
]
self.prompt.append({
'role': 'user',
'content': msg
})
token_count = len(msg)
# 倒序遍历prompt
for i in range(len(self.prompt) - 1, -1, -1):
if token_count >= max_tokens:
break
logging.debug('cut_out: {}'.format(msg))
# 将prompt加到temp_prompt头部
temp_prompt.insert(0, self.prompt[i])
token_count += len(self.prompt[i]['content'])
logging.debug('cut_out: {}'.format(str(temp_prompt)))
return temp_prompt
# 持久化session
def persistence(self):