2022-12-19 22:51:56 +08:00
|
|
|
import logging
|
|
|
|
|
|
2022-12-07 22:27:05 +08:00
|
|
|
import openai
|
2023-12-09 22:17:26 +08:00
|
|
|
from openai.types import images_response
|
2022-12-07 22:27:05 +08:00
|
|
|
|
2023-11-13 21:59:23 +08:00
|
|
|
from ..openai import keymgr
|
|
|
|
|
from ..utils import context
|
|
|
|
|
from ..audit import gatherer
|
|
|
|
|
from ..openai import modelmgr
|
|
|
|
|
from ..openai.api import model as api_model
|
2022-12-07 22:27:05 +08:00
|
|
|
|
2023-03-05 15:39:13 +08:00
|
|
|
|
2022-12-07 22:27:05 +08:00
|
|
|
class OpenAIInteract:
    """OpenAI API wrapper.

    Wraps the text-completion and image-generation endpoints for callers.
    On construction it creates a key manager, an audit (usage) gatherer and
    an ``openai.Client`` bound to the key currently in use, then registers
    itself with the global context via ``context.set_openai_manager``.
    """

    # Key manager; assigned in __init__.
    key_mgr: keymgr.KeysManager = None

    # Usage-audit gatherer; assigned in __init__.
    audit_mgr: gatherer.DataGatherer = None

    # Fallback parameters for image requests. Values from the config's
    # ``image_api_params`` take precedence over these.
    default_image_api_params = {
        "size": "256x256",
    }

    # OpenAI SDK client; assigned in __init__.
    client: openai.Client = None

    def __init__(self, api_key: str):
        """Create the key/audit managers and the OpenAI client.

        Parameters:
            api_key (str): Value handed to ``keymgr.KeysManager``; the client
                is built with whatever ``get_using_key()`` then returns.
        """
        self.key_mgr = keymgr.KeysManager(api_key)
        self.audit_mgr = gatherer.DataGatherer()

        self.client = openai.Client(
            api_key=self.key_mgr.get_using_key(),
            base_url=openai.base_url,
        )

        context.set_openai_manager(self)

    def request_completion(self, messages: list):
        """Request a reply from the completion API.

        This is a generator: the config is read and the request is made only
        once the caller starts iterating.

        Parameters:
            messages (list): Conversation messages, forwarded unchanged to the
                request class chosen by ``modelmgr.select_request_cls``.

        Yields:
            Each response produced by the selected request object. Responses
            whose ``['usage']['total_tokens']`` is positive are reported to the
            audit manager before being yielded.
        """
        config = context.get_config_manager().data
        completion_params = config['completion_api_params']

        model: str = completion_params['model']

        # The model name selects the request class; every other entry is
        # passed through as an extra request argument.
        extra_params = {
            key: value
            for key, value in completion_params.items()
            if key != 'model'
        }

        # Pick the request class that matches the configured model.
        request: api_model.RequestBase = modelmgr.select_request_cls(
            self.client, model, messages, extra_params
        )

        for resp in request:
            total_tokens = resp['usage']['total_tokens']
            if total_tokens > 0:
                self.audit_mgr.report_text_model_usage(model, total_tokens)

            yield resp

    def request_image(self, prompt) -> images_response.ImagesResponse:
        """Request an image from the image-generation API.

        Parameters:
            prompt (str): Prompt text.

        Returns:
            images_response.ImagesResponse: The response returned by
            ``client.images.generate``.
        """
        config = context.get_config_manager().data

        # Config values override the class defaults. Merging guarantees a
        # "size" entry for the usage report below, which runs only after the
        # (billed) generation call has already succeeded.
        params = {**self.default_image_api_params, **config['image_api_params']}

        response = self.client.images.generate(
            prompt=prompt,
            n=1,
            **params
        )

        self.audit_mgr.report_image_model_usage(params['size'])

        return response
|
|
|
|
|
|