From 8713fd813089c2447495ac51c1fef693ffc25fe8 Mon Sep 17 00:00:00 2001 From: Rock Chin <1010553892@qq.com> Date: Fri, 3 Mar 2023 00:07:53 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=AE=8C=E5=96=84=E4=BC=9A=E8=AF=9D?= =?UTF-8?q?=E5=A4=84=E7=90=86=E7=9A=84=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/openai/manager.py | 7 +++--- pkg/openai/modelmgr.py | 33 +++++++++++++++------------- pkg/openai/session.py | 49 +++++++++++++++++++++++++++++------------- 3 files changed, 56 insertions(+), 33 deletions(-) diff --git a/pkg/openai/manager.py b/pkg/openai/manager.py index 563e14ad..3bd0c275 100644 --- a/pkg/openai/manager.py +++ b/pkg/openai/manager.py @@ -32,12 +32,13 @@ class OpenAIInteract: pkg.utils.context.set_openai_manager(self) # 请求OpenAI Completion - def request_completion(self, messages): + def request_completion(self, prompts): config = pkg.utils.context.get_config() - ai:ModelRequest = create_openai_model_request(config.completion_api_params['model'], 'user') + # 根据模型选择使用的接口 + ai: ModelRequest = create_openai_model_request(config.completion_api_params['model'], 'user') ai.request( - messages, + prompts, **config.completion_api_params ) response = ai.get_response() diff --git a/pkg/openai/modelmgr.py b/pkg/openai/modelmgr.py index 66b3c080..cdaa45cd 100644 --- a/pkg/openai/modelmgr.py +++ b/pkg/openai/modelmgr.py @@ -26,7 +26,7 @@ IMAGE_MODELS = { class ModelRequest(): - + """模型请求抽象类""" can_chat = False def __init__(self, model_name, user_name, request_fun): @@ -39,7 +39,8 @@ class ModelRequest(): self.ret = self.ret_handle(ret) self.message = self.ret["choices"][0]["message"] - def msg_handle(self, msg): + def __msg_handle__(self, msg): + """将prompt dict转换成接口需要的格式""" return msg def ret_handle(self, ret): @@ -63,17 +64,16 @@ class ChatCompletionModel(ModelRequest): self.can_chat = True super().__init__(model_name, user_name, request_fun) - def request(self, messages, **kwargs): - ret = self.request_fun(messages = self.msg_handle(messages), **kwargs, user=self.user_name) + def request(self, prompts, **kwargs): + ret = self.request_fun(messages = self.__msg_handle__(prompts), **kwargs, user=self.user_name) self.ret = self.ret_handle(ret) self.message = self.ret["choices"][0]["message"]['content'] - def msg_handle(self, msgs): + def __msg_handle__(self, msgs): temp_msgs = [] + # 把msgs拷贝进temp_msgs for msg in msgs: - if msg['role'] not in self.Chat_role: - msg['role'] = 'user' - temp_msgs.append(msg) + temp_msgs.append(msg.copy()) return temp_msgs def get_content(self): @@ -86,18 +86,21 @@ class CompletionModel(ModelRequest): request_fun = openai.Completion.create super().__init__(model_name, user_name, request_fun) - def request(self, prompt, **kwargs): - ret = self.request_fun(prompt = self.msg_handle(prompt), **kwargs) + def request(self, prompts, **kwargs): + ret = self.request_fun(prompt = self.__msg_handle__(prompts), **kwargs) self.ret = self.ret_handle(ret) self.message = self.ret["choices"][0]["text"] - def msg_handle(self, msgs): + def __msg_handle__(self, msgs): prompt = '' for msg in msgs: - if msg['role'] == 'assistant': - prompt = prompt + "{}\n".format(msg['content']) - else: - prompt = prompt + "{}:{}\n".format(msg['role'] , msg['content']) + prompt = prompt + "{}: {}\n".format(msg['role'], msg['content']) + # for msg in msgs: + # if msg['role'] == 'assistant': + # prompt = prompt + "{}\n".format(msg['content']) + # else: + # prompt = prompt + "{}:{}\n".format(msg['role'] , msg['content']) + prompt = prompt + "assistant: " return prompt def get_text(self): diff --git a/pkg/openai/session.py b/pkg/openai/session.py index 81c8cbe4..9bd9e72d 100644 --- a/pkg/openai/session.py +++ b/pkg/openai/session.py @@ -126,10 +126,15 @@ class Session: else: current_default_prompt = dprompt.get_prompt(use_default) - return [{ - 'role': 'system', - 'content': current_default_prompt - }] + return [ + { + 'role': 'user', + 'content': current_default_prompt + },{ + 'role': 'assistant', + 'content': 'ok' + } + ] def __init__(self, name: str): self.name = name @@ -195,12 +200,11 @@ class Session: max_length = config.prompt_submit_length if hasattr(config, "prompt_submit_length") else 1024 # 向API请求补全 - self.cut_out(text, max_length) message = pkg.utils.context.get_openai_manager().request_completion( - self.prompt + self.cut_out(text, max_length), ) - # 处理回复 + # 成功获取,处理回复 res_test = message res_ans = res_test @@ -210,6 +214,8 @@ class Session: del (res_ans_spt[0]) res_ans = '\n\n'.join(res_ans_spt) + # 将此次对话的双方内容加入到prompt中 + self.prompt.append({'role':'user', 'content':text}) self.prompt.append({'role':'assistant', 'content':res_ans}) if self.just_switched_to_exist_session: @@ -234,17 +240,30 @@ class Session: return res # 构建对话体 - def cut_out(self, msg: str, max_tokens: int) -> str: + def cut_out(self, msg: str, max_tokens: int) -> list: + """将现有prompt进行切割处理,使得新的prompt长度不超过max_tokens""" + # 如果用户消息长度超过max_tokens,直接返回 + + temp_prompt = [ + { + 'role': 'user', + 'content': msg + } + ] - if len(msg) > max_tokens: - msg = msg[:max_tokens] + token_count = len(msg) + # 倒序遍历prompt + for i in range(len(self.prompt) - 1, -1, -1): + if token_count >= max_tokens: + break - self.prompt.append({ - 'role': 'user', - 'content': msg - }) + # 将prompt加到temp_prompt头部 + temp_prompt.insert(0, self.prompt[i]) + token_count += len(self.prompt[i]['content']) - logging.debug('cut_out: {}'.format(msg)) + logging.debug('cut_out: {}'.format(str(temp_prompt))) + + return temp_prompt # 持久化session def persistence(self):