mirror of
https://github.com/langbot-app/LangBot.git
synced 2025-11-25 19:37:36 +08:00
feat: 初步支持绘图api
This commit is contained in:
@@ -77,6 +77,12 @@ completion_api_params = {
|
||||
"presence_penalty": 1.0,
|
||||
}
|
||||
|
||||
# OpenAI的Image API的参数
|
||||
# 具体请查看OpenAI的文档: https://beta.openai.com/docs/api-reference/images/create
|
||||
image_api_params = {
|
||||
"size": "256x256",
|
||||
}
|
||||
|
||||
# 消息处理的超时时间,单位为秒
|
||||
process_message_timeout = 15
|
||||
|
||||
|
||||
2
main.py
2
main.py
@@ -59,7 +59,7 @@ def main():
|
||||
|
||||
database.initialize_database()
|
||||
|
||||
openai_interact = pkg.openai.manager.OpenAIInteract(config.openai_config['api_key'], config.completion_api_params)
|
||||
openai_interact = pkg.openai.manager.OpenAIInteract(config.openai_config['api_key'])
|
||||
|
||||
# 加载所有未超时的session
|
||||
pkg.openai.session.load_sessions()
|
||||
|
||||
@@ -15,9 +15,12 @@ class OpenAIInteract:
|
||||
|
||||
key_mgr = None
|
||||
|
||||
def __init__(self, api_key: str, api_params: dict):
|
||||
default_image_api_params = {
|
||||
"size": "256x256",
|
||||
}
|
||||
|
||||
def __init__(self, api_key: str):
|
||||
# self.api_key = api_key
|
||||
self.api_params = api_params
|
||||
|
||||
self.key_mgr = pkg.openai.keymgr.KeysManager(api_key)
|
||||
|
||||
@@ -28,12 +31,11 @@ class OpenAIInteract:
|
||||
|
||||
# 请求OpenAI Completion
|
||||
def request_completion(self, prompt, stop):
|
||||
logging.debug("请求OpenAI Completion, key:"+openai.api_key)
|
||||
response = openai.Completion.create(
|
||||
prompt=prompt,
|
||||
stop=stop,
|
||||
timeout=config.process_message_timeout,
|
||||
**self.api_params
|
||||
**config.completion_api_params
|
||||
)
|
||||
switched = self.key_mgr.report_usage(prompt + response['choices'][0]['text'])
|
||||
if switched:
|
||||
@@ -41,6 +43,15 @@ class OpenAIInteract:
|
||||
|
||||
return response
|
||||
|
||||
def request_image(self, prompt):
|
||||
response = openai.Image.create(
|
||||
prompt=prompt,
|
||||
n=1,
|
||||
**config.image_api_params if hasattr(config, "image_api_params") else self.default_image_api_params
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
def get_inst() -> OpenAIInteract:
|
||||
global inst
|
||||
|
||||
@@ -268,3 +268,6 @@ class Session:
|
||||
def list_history(self, capacity: int = 10, page: int = 0):
|
||||
return pkg.database.manager.get_inst().list_history(self.name, capacity, page,
|
||||
get_default_prompt())
|
||||
|
||||
def draw_image(self, prompt: str):
|
||||
return pkg.openai.manager.get_inst().request_image(prompt)
|
||||
|
||||
@@ -6,6 +6,8 @@ from func_timeout import func_set_timeout
|
||||
import logging
|
||||
import openai
|
||||
|
||||
from mirai import Image
|
||||
|
||||
import config
|
||||
|
||||
import pkg.openai.session
|
||||
@@ -128,6 +130,17 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str) ->
|
||||
reply_str += "\n当前使用:{}".format(using_key_name)
|
||||
|
||||
reply = [reply_str]
|
||||
|
||||
elif cmd == 'draw':
|
||||
if len(params) == 0:
|
||||
reply = ["[bot]err:请输入图片描述文字"]
|
||||
else:
|
||||
session = pkg.openai.session.get_session(session_name)
|
||||
|
||||
res = session.draw_image(" ".join(params))
|
||||
|
||||
logging.debug("draw_image result:{}".format(res))
|
||||
reply = [Image(url=res['data'][0]['url'])]
|
||||
except Exception as e:
|
||||
mgr.notify_admin("{}指令执行失败:{}".format(session_name, e))
|
||||
logging.exception(e)
|
||||
|
||||
Reference in New Issue
Block a user