style: introduce ruff as linter and formatter (#1356)

* style: remove necessary imports

* style: fix F841

* style: fix F401

* style: fix F811

* style: fix E402

* style: fix E721

* style: fix E722

* style: fix E722

* style: fix F541

* style: ruff format

* style: all passed

* style: add ruff in deps

* style: more ignores in ruff.toml

* style: add pre-commit
This commit is contained in:
Junyan Qin (Chin)
2025-04-29 17:24:07 +08:00
committed by GitHub
parent 09e70d70e9
commit 209f16af76
240 changed files with 5307 additions and 4689 deletions

1
.gitignore vendored
View File

@@ -29,6 +29,7 @@ qcapi
claude.json
bard.json
/*yaml
!.pre-commit-config.yaml
!components.yaml
!/docker-compose.yaml
data/labels/instance_id.json

9
.pre-commit-config.yaml Normal file
View File

@@ -0,0 +1,9 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.11.7
hooks:
# Run the linter.
- id: ruff
# Run the formatter.
- id: ruff-format

View File

@@ -1,2 +1,4 @@
from .v1 import client
from .v1 import errors
from .v1 import client as client
from .v1 import errors as errors
__all__ = ['client', 'errors']

View File

@@ -8,25 +8,33 @@ import json
class TestDifyClient:
async def test_chat_messages(self):
cln = client.AsyncDifyServiceClient(api_key=os.getenv("DIFY_API_KEY"), base_url=os.getenv("DIFY_BASE_URL"))
cln = client.AsyncDifyServiceClient(
api_key=os.getenv('DIFY_API_KEY'), base_url=os.getenv('DIFY_BASE_URL')
)
async for chunk in cln.chat_messages(inputs={}, query="调用工具查看现在几点?", user="test"):
async for chunk in cln.chat_messages(
inputs={}, query='调用工具查看现在几点?', user='test'
):
print(json.dumps(chunk, ensure_ascii=False, indent=4))
async def test_upload_file(self):
cln = client.AsyncDifyServiceClient(api_key=os.getenv("DIFY_API_KEY"), base_url=os.getenv("DIFY_BASE_URL"))
cln = client.AsyncDifyServiceClient(
api_key=os.getenv('DIFY_API_KEY'), base_url=os.getenv('DIFY_BASE_URL')
)
file_bytes = open("img.png", "rb").read()
file_bytes = open('img.png', 'rb').read()
print(type(file_bytes))
file = ("img2.png", file_bytes, "image/png")
file = ('img2.png', file_bytes, 'image/png')
resp = await cln.upload_file(file=file, user="test")
resp = await cln.upload_file(file=file, user='test')
print(json.dumps(resp, ensure_ascii=False, indent=4))
async def test_workflow_run(self):
cln = client.AsyncDifyServiceClient(api_key=os.getenv("DIFY_API_KEY"), base_url=os.getenv("DIFY_BASE_URL"))
cln = client.AsyncDifyServiceClient(
api_key=os.getenv('DIFY_API_KEY'), base_url=os.getenv('DIFY_BASE_URL')
)
# resp = await cln.workflow_run(inputs={}, user="test")
# # print(json.dumps(resp, ensure_ascii=False, indent=4))
@@ -34,11 +42,12 @@ class TestDifyClient:
chunks = []
ignored_events = ['text_chunk']
async for chunk in cln.workflow_run(inputs={}, user="test"):
async for chunk in cln.workflow_run(inputs={}, user='test'):
if chunk['event'] in ignored_events:
continue
chunks.append(chunk)
print(json.dumps(chunks, ensure_ascii=False, indent=4))
if __name__ == "__main__":
if __name__ == '__main__':
asyncio.run(TestDifyClient().test_chat_messages())

View File

@@ -12,11 +12,11 @@ class AsyncDifyServiceClient:
api_key: str
base_url: str
def __init__(
self,
api_key: str,
base_url: str = "https://api.dify.ai/v1",
base_url: str = 'https://api.dify.ai/v1',
) -> None:
self.api_key = api_key
self.base_url = base_url
@@ -26,76 +26,81 @@ class AsyncDifyServiceClient:
inputs: dict[str, typing.Any],
query: str,
user: str,
response_mode: str = "streaming", # 当前不支持 blocking
conversation_id: str = "",
response_mode: str = 'streaming', # 当前不支持 blocking
conversation_id: str = '',
files: list[dict[str, typing.Any]] = [],
timeout: float = 30.0,
) -> typing.AsyncGenerator[dict[str, typing.Any], None]:
"""发送消息"""
if response_mode != "streaming":
raise DifyAPIError("当前仅支持 streaming 模式")
if response_mode != 'streaming':
raise DifyAPIError('当前仅支持 streaming 模式')
async with httpx.AsyncClient(
base_url=self.base_url,
trust_env=True,
timeout=timeout,
) as client:
async with client.stream(
"POST",
"/chat-messages",
headers={"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"},
'POST',
'/chat-messages',
headers={
'Authorization': f'Bearer {self.api_key}',
'Content-Type': 'application/json',
},
json={
"inputs": inputs,
"query": query,
"user": user,
"response_mode": response_mode,
"conversation_id": conversation_id,
"files": files,
'inputs': inputs,
'query': query,
'user': user,
'response_mode': response_mode,
'conversation_id': conversation_id,
'files': files,
},
) as r:
async for chunk in r.aiter_lines():
if r.status_code != 200:
raise DifyAPIError(f"{r.status_code} {chunk}")
if chunk.strip() == "":
raise DifyAPIError(f'{r.status_code} {chunk}')
if chunk.strip() == '':
continue
if chunk.startswith("data:"):
if chunk.startswith('data:'):
yield json.loads(chunk[5:])
async def workflow_run(
self,
inputs: dict[str, typing.Any],
user: str,
response_mode: str = "streaming", # 当前不支持 blocking
response_mode: str = 'streaming', # 当前不支持 blocking
files: list[dict[str, typing.Any]] = [],
timeout: float = 30.0,
) -> typing.AsyncGenerator[dict[str, typing.Any], None]:
"""运行工作流"""
if response_mode != "streaming":
raise DifyAPIError("当前仅支持 streaming 模式")
if response_mode != 'streaming':
raise DifyAPIError('当前仅支持 streaming 模式')
async with httpx.AsyncClient(
base_url=self.base_url,
trust_env=True,
timeout=timeout,
) as client:
async with client.stream(
"POST",
"/workflows/run",
headers={"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"},
'POST',
'/workflows/run',
headers={
'Authorization': f'Bearer {self.api_key}',
'Content-Type': 'application/json',
},
json={
"inputs": inputs,
"user": user,
"response_mode": response_mode,
"files": files,
'inputs': inputs,
'user': user,
'response_mode': response_mode,
'files': files,
},
) as r:
async for chunk in r.aiter_lines():
if r.status_code != 200:
raise DifyAPIError(f"{r.status_code} {chunk}")
if chunk.strip() == "":
raise DifyAPIError(f'{r.status_code} {chunk}')
if chunk.strip() == '':
continue
if chunk.startswith("data:"):
if chunk.startswith('data:'):
yield json.loads(chunk[5:])
async def upload_file(
@@ -112,15 +117,15 @@ class AsyncDifyServiceClient:
) as client:
# multipart/form-data
response = await client.post(
"/files/upload",
headers={"Authorization": f"Bearer {self.api_key}"},
'/files/upload',
headers={'Authorization': f'Bearer {self.api_key}'},
files={
"file": file,
"user": (None, user),
'file': file,
'user': (None, user),
},
)
if response.status_code != 201:
raise DifyAPIError(f"{response.status_code} {response.text}")
raise DifyAPIError(f'{response.status_code} {response.text}')
return response.json()

View File

@@ -7,11 +7,11 @@ import os
class TestDifyClient:
async def test_chat_messages(self):
cln = client.DifyClient(api_key=os.getenv("DIFY_API_KEY"))
cln = client.DifyClient(api_key=os.getenv('DIFY_API_KEY'))
resp = await cln.chat_messages(inputs={}, query="Who are you?", user_id="test")
resp = await cln.chat_messages(inputs={}, query='Who are you?', user_id='test')
print(resp)
if __name__ == "__main__":
if __name__ == '__main__':
asyncio.run(TestDifyClient().test_chat_messages())

View File

@@ -1,8 +1,8 @@
import asyncio
import json
import dingtalk_stream
from dingtalk_stream import AckMessage
class EchoTextHandler(dingtalk_stream.ChatbotHandler):
def __init__(self, client):
self.msg_id = ''
@@ -10,6 +10,7 @@ class EchoTextHandler(dingtalk_stream.ChatbotHandler):
self.client = client # 用于更新 DingTalkClient 中的 incoming_message
"""处理钉钉消息"""
async def process(self, callback: dingtalk_stream.CallbackMessage):
incoming_message = dingtalk_stream.ChatbotMessage.from_dict(callback.data)
if incoming_message.message_id != self.msg_id:
@@ -26,6 +27,8 @@ class EchoTextHandler(dingtalk_stream.ChatbotHandler):
return self.incoming_message
async def get_dingtalk_client(client_id, client_secret):
from api import DingTalkClient # 延迟导入,避免循环导入
return DingTalkClient(client_id, client_secret)

View File

@@ -10,7 +10,9 @@ import traceback
class DingTalkClient:
def __init__(self, client_id: str, client_secret: str,robot_name:str,robot_code:str):
def __init__(
self, client_id: str, client_secret: str, robot_name: str, robot_code: str
):
"""初始化 WebSocket 连接并自动启动"""
self.credential = dingtalk_stream.Credential(client_id, client_secret)
self.client = dingtalk_stream.DingTalkStreamClient(self.credential)
@@ -18,106 +20,91 @@ class DingTalkClient:
self.secret = client_secret
# 在 DingTalkClient 中传入自己作为参数,避免循环导入
self.EchoTextHandler = EchoTextHandler(self)
self.client.register_callback_handler(dingtalk_stream.chatbot.ChatbotMessage.TOPIC, self.EchoTextHandler)
self.client.register_callback_handler(
dingtalk_stream.chatbot.ChatbotMessage.TOPIC, self.EchoTextHandler
)
self._message_handlers = {
"example":[],
'example': [],
}
self.access_token = ''
self.robot_name = robot_name
self.robot_code = robot_code
self.access_token_expiry_time = ''
async def get_access_token(self):
url = "https://api.dingtalk.com/v1.0/oauth2/accessToken"
headers = {
"Content-Type": "application/json"
}
data = {
"appKey": self.key,
"appSecret": self.secret
}
url = 'https://api.dingtalk.com/v1.0/oauth2/accessToken'
headers = {'Content-Type': 'application/json'}
data = {'appKey': self.key, 'appSecret': self.secret}
async with httpx.AsyncClient() as client:
try:
response = await client.post(url,json=data,headers=headers)
response = await client.post(url, json=data, headers=headers)
if response.status_code == 200:
response_data = response.json()
self.access_token = response_data.get("accessToken")
expires_in = int(response_data.get("expireIn",7200))
self.access_token = response_data.get('accessToken')
expires_in = int(response_data.get('expireIn', 7200))
self.access_token_expiry_time = time.time() + expires_in - 60
except Exception as e:
raise Exception(e)
async def is_token_expired(self):
"""检查token是否过期"""
if self.access_token_expiry_time is None:
return True
return time.time() > self.access_token_expiry_time
async def check_access_token(self):
if not self.access_token or await self.is_token_expired():
return False
return bool(self.access_token and self.access_token.strip())
async def download_image(self,download_code:str):
async def download_image(self, download_code: str):
if not await self.check_access_token():
await self.get_access_token()
url = 'https://api.dingtalk.com/v1.0/robot/messageFiles/download'
params = {
"downloadCode":download_code,
"robotCode":self.robot_code
}
headers ={
"x-acs-dingtalk-access-token": self.access_token
}
params = {'downloadCode': download_code, 'robotCode': self.robot_code}
headers = {'x-acs-dingtalk-access-token': self.access_token}
async with httpx.AsyncClient() as client:
response = await client.post(url, headers=headers, json=params)
if response.status_code == 200:
result = response.json()
download_url = result.get("downloadUrl")
download_url = result.get('downloadUrl')
else:
raise Exception(f"Error: {response.status_code}, {response.text}")
raise Exception(f'Error: {response.status_code}, {response.text}')
if download_url:
return await self.download_url_to_base64(download_url)
async def download_url_to_base64(self,download_url):
async def download_url_to_base64(self, download_url):
async with httpx.AsyncClient() as client:
response = await client.get(download_url)
if response.status_code == 200:
file_bytes = response.content
base64_str = base64.b64encode(file_bytes).decode('utf-8') # 返回字符串格式
base64_str = base64.b64encode(file_bytes).decode(
'utf-8'
) # 返回字符串格式
return base64_str
else:
raise Exception("获取文件失败")
async def get_audio_url(self,download_code:str):
raise Exception('获取文件失败')
async def get_audio_url(self, download_code: str):
if not await self.check_access_token():
await self.get_access_token()
url = 'https://api.dingtalk.com/v1.0/robot/messageFiles/download'
params = {
"downloadCode":download_code,
"robotCode":self.robot_code
}
headers ={
"x-acs-dingtalk-access-token": self.access_token
}
params = {'downloadCode': download_code, 'robotCode': self.robot_code}
headers = {'x-acs-dingtalk-access-token': self.access_token}
async with httpx.AsyncClient() as client:
response = await client.post(url, headers=headers, json=params)
if response.status_code == 200:
result = response.json()
download_url = result.get("downloadUrl")
download_url = result.get('downloadUrl')
if download_url:
return await self.download_url_to_base64(download_url)
else:
raise Exception("获取音频失败")
raise Exception('获取音频失败')
else:
raise Exception(f"Error: {response.status_code}, {response.text}")
raise Exception(f'Error: {response.status_code}, {response.text}')
async def update_incoming_message(self, message):
"""异步更新 DingTalkClient 中的 incoming_message"""
message_data = await self.get_message(message)
@@ -125,24 +112,21 @@ class DingTalkClient:
event = DingTalkEvent.from_payload(message_data)
if event:
await self._handle_message(event)
async def send_message(self,content:str,incoming_message):
self.EchoTextHandler.reply_text(content,incoming_message)
async def send_message(self, content: str, incoming_message):
self.EchoTextHandler.reply_text(content, incoming_message)
async def get_incoming_message(self):
"""获取收到的消息"""
return await self.EchoTextHandler.get_incoming_message()
def on_message(self, msg_type: str):
def decorator(func: Callable[[DingTalkEvent], None]):
if msg_type not in self._message_handlers:
self._message_handlers[msg_type] = []
self._message_handlers[msg_type].append(func)
return func
return decorator
async def _handle_message(self, event: DingTalkEvent):
@@ -154,40 +138,44 @@ class DingTalkClient:
for handler in self._message_handlers[msg_type]:
await handler(event)
async def get_message(self,incoming_message:dingtalk_stream.chatbot.ChatbotMessage):
async def get_message(
self, incoming_message: dingtalk_stream.chatbot.ChatbotMessage
):
try:
# print(json.dumps(incoming_message.to_dict(), indent=4, ensure_ascii=False))
message_data = {
"IncomingMessage":incoming_message,
'IncomingMessage': incoming_message,
}
if str(incoming_message.conversation_type) == '1':
message_data["conversation_type"] = 'FriendMessage'
message_data['conversation_type'] = 'FriendMessage'
elif str(incoming_message.conversation_type) == '2':
message_data["conversation_type"] = 'GroupMessage'
message_data['conversation_type'] = 'GroupMessage'
if incoming_message.message_type == 'richText':
data = incoming_message.rich_text_content.to_dict()
for item in data['richText']:
if 'text' in item:
message_data["Content"] = item['text']
message_data['Content'] = item['text']
if incoming_message.get_image_list()[0]:
message_data["Picture"] = await self.download_image(incoming_message.get_image_list()[0])
message_data["Type"] = 'text'
message_data['Picture'] = await self.download_image(
incoming_message.get_image_list()[0]
)
message_data['Type'] = 'text'
elif incoming_message.message_type == 'text':
message_data['Content'] = incoming_message.get_text_list()[0]
message_data["Type"] = 'text'
message_data['Type'] = 'text'
elif incoming_message.message_type == 'picture':
message_data['Picture'] = await self.download_image(incoming_message.get_image_list()[0])
message_data['Picture'] = await self.download_image(
incoming_message.get_image_list()[0]
)
message_data['Type'] = 'image'
elif incoming_message.message_type == 'audio':
message_data['Audio'] = await self.get_audio_url(incoming_message.to_dict()['content']['downloadCode'])
message_data['Audio'] = await self.get_audio_url(
incoming_message.to_dict()['content']['downloadCode']
)
message_data['Type'] = 'audio'
@@ -196,56 +184,55 @@ class DingTalkClient:
# print("message_data:", json.dumps(copy_message_data, indent=4, ensure_ascii=False))
except Exception:
traceback.print_exc()
return message_data
async def send_proactive_message_to_one(self,target_id:str,content:str):
async def send_proactive_message_to_one(self, target_id: str, content: str):
if not await self.check_access_token():
await self.get_access_token()
url = 'https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend'
headers ={
"x-acs-dingtalk-access-token":self.access_token,
"Content-Type":"application/json",
headers = {
'x-acs-dingtalk-access-token': self.access_token,
'Content-Type': 'application/json',
}
data ={
"robotCode":self.robot_code,
"userIds":[target_id],
"msgKey": "sampleText",
"msgParam": json.dumps({"content":content}),
data = {
'robotCode': self.robot_code,
'userIds': [target_id],
'msgKey': 'sampleText',
'msgParam': json.dumps({'content': content}),
}
try:
async with httpx.AsyncClient() as client:
response = await client.post(url,headers=headers,json=data)
await client.post(url, headers=headers, json=data)
except Exception:
traceback.print_exc()
async def send_proactive_message_to_group(self,target_id:str,content:str):
async def send_proactive_message_to_group(self, target_id: str, content: str):
if not await self.check_access_token():
await self.get_access_token()
url = 'https://api.dingtalk.com/v1.0/robot/groupMessages/send'
headers ={
"x-acs-dingtalk-access-token":self.access_token,
"Content-Type":"application/json",
headers = {
'x-acs-dingtalk-access-token': self.access_token,
'Content-Type': 'application/json',
}
data ={
"robotCode":self.robot_code,
"openConversationId":target_id,
"msgKey": "sampleText",
"msgParam": json.dumps({"content":content}),
data = {
'robotCode': self.robot_code,
'openConversationId': target_id,
'msgKey': 'sampleText',
'msgParam': json.dumps({'content': content}),
}
try:
async with httpx.AsyncClient() as client:
response = await client.post(url,headers=headers,json=data)
await client.post(url, headers=headers, json=data)
except Exception:
traceback.print_exc()
async def start(self):
"""启动 WebSocket 连接,监听消息"""
await self.client.start()
await self.client.start()

View File

@@ -1,41 +1,39 @@
from typing import Dict, Any, Optional
import dingtalk_stream
class DingTalkEvent(dict):
@staticmethod
def from_payload(payload: Dict[str, Any]) -> Optional["DingTalkEvent"]:
def from_payload(payload: Dict[str, Any]) -> Optional['DingTalkEvent']:
try:
event = DingTalkEvent(payload)
return event
except KeyError:
return None
@property
def content(self):
return self.get("Content","")
@property
def incoming_message(self) -> Optional["dingtalk_stream.chatbot.ChatbotMessage"]:
return self.get("IncomingMessage")
def content(self):
return self.get('Content', '')
@property
def incoming_message(self) -> Optional['dingtalk_stream.chatbot.ChatbotMessage']:
return self.get('IncomingMessage')
@property
def type(self):
return self.get("Type","")
return self.get('Type', '')
@property
def picture(self):
return self.get("Picture","")
return self.get('Picture', '')
@property
def audio(self):
return self.get("Audio","")
return self.get('Audio', '')
@property
def conversation(self):
return self.get("conversation_type","")
return self.get('conversation_type', '')
def __getattr__(self, key: str) -> Optional[Any]:
"""
@@ -66,4 +64,4 @@ class DingTalkEvent(dict):
Returns:
str: 字符串表示。
"""
return f"<DingTalkEvent {super().__repr__()}>"
return f'<DingTalkEvent {super().__repr__()}>'

View File

@@ -1,20 +1,14 @@
# 微信公众号的加解密算法与企业微信一样,所以直接使用企业微信的加解密算法文件
from collections import deque
import time
import traceback
from ..wecom_api.WXBizMsgCrypt3 import WXBizMsgCrypt
import xml.etree.ElementTree as ET
from quart import Quart,request
from quart import Quart, request
import hashlib
from typing import Callable, Dict, Any
from typing import Callable
from .oaevent import OAEvent
import httpx
import asyncio
import time
import xml.etree.ElementTree as ET
from pkg.platform.sources import officialaccount as oa
xml_template = """
@@ -28,9 +22,8 @@ xml_template = """
"""
class OAClient():
def __init__(self,token:str,EncodingAESKey:str,AppID:str,Appsecret:str):
class OAClient:
def __init__(self, token: str, EncodingAESKey: str, AppID: str, Appsecret: str):
self.token = token
self.aes = EncodingAESKey
self.appid = AppID
@@ -38,121 +31,130 @@ class OAClient():
self.base_url = 'https://api.weixin.qq.com'
self.access_token = ''
self.app = Quart(__name__)
self.app.add_url_rule('/callback/command', 'handle_callback', self.handle_callback_request, methods=['GET', 'POST'])
self.app.add_url_rule(
'/callback/command',
'handle_callback',
self.handle_callback_request,
methods=['GET', 'POST'],
)
self._message_handlers = {
"example":[],
'example': [],
}
self.access_token_expiry_time = None
self.msg_id_map = {}
self.generated_content = {}
async def handle_callback_request(self):
try:
# 每隔100毫秒查询是否生成ai回答
start_time = time.time()
signature = request.args.get("signature", "")
timestamp = request.args.get("timestamp", "")
nonce = request.args.get("nonce", "")
echostr = request.args.get("echostr", "")
msg_signature = request.args.get("msg_signature","")
signature = request.args.get('signature', '')
timestamp = request.args.get('timestamp', '')
nonce = request.args.get('nonce', '')
echostr = request.args.get('echostr', '')
msg_signature = request.args.get('msg_signature', '')
if msg_signature is None:
raise Exception("msg_signature不在请求体中")
raise Exception('msg_signature不在请求体中')
if request.method == 'GET':
# 校验签名
check_str = "".join(sorted([self.token, timestamp, nonce]))
check_signature = hashlib.sha1(check_str.encode("utf-8")).hexdigest()
check_str = ''.join(sorted([self.token, timestamp, nonce]))
check_signature = hashlib.sha1(check_str.encode('utf-8')).hexdigest()
if check_signature == signature:
return echostr # 验证成功返回echostr
else:
raise Exception("拒绝请求")
elif request.method == "POST":
raise Exception('拒绝请求')
elif request.method == 'POST':
encryt_msg = await request.data
wxcpt = WXBizMsgCrypt(self.token,self.aes,self.appid)
ret,xml_msg = wxcpt.DecryptMsg(encryt_msg,msg_signature,timestamp,nonce)
wxcpt = WXBizMsgCrypt(self.token, self.aes, self.appid)
ret, xml_msg = wxcpt.DecryptMsg(
encryt_msg, msg_signature, timestamp, nonce
)
xml_msg = xml_msg.decode('utf-8')
if ret != 0:
raise Exception("消息解密失败")
raise Exception('消息解密失败')
message_data = await self.get_message(xml_msg)
if message_data :
if message_data:
event = OAEvent.from_payload(message_data)
if event:
await self._handle_message(event)
root = ET.fromstring(xml_msg)
from_user = root.find("FromUserName").text # 发送者
to_user = root.find("ToUserName").text # 机器人
from_user = root.find('FromUserName').text # 发送者
to_user = root.find('ToUserName').text # 机器人
timeout = 4.80
interval = 0.1
while True:
content = self.generated_content.pop(message_data["MsgId"], None)
content = self.generated_content.pop(message_data['MsgId'], None)
if content:
response_xml = xml_template.format(
to_user=from_user,
from_user=to_user,
create_time=int(time.time()),
content = content
content=content,
)
return response_xml
if time.time() - start_time >= timeout:
break
await asyncio.sleep(interval)
if self.msg_id_map.get(message_data["MsgId"], 1) == 3:
if self.msg_id_map.get(message_data['MsgId'], 1) == 3:
# response_xml = xml_template.format(
# to_user=from_user,
# from_user=to_user,
# create_time=int(time.time()),
# content = "请求失效暂不支持公众号超过15秒的请求如有需求请联系 LangBot 团队。"
# )
print("请求失效暂不支持公众号超过15秒的请求如有需求请联系 LangBot 团队。")
print(
'请求失效暂不支持公众号超过15秒的请求如有需求请联系 LangBot 团队。'
)
return ''
except Exception as e:
except Exception:
traceback.print_exc()
async def get_message(self, xml_msg: str):
root = ET.fromstring(xml_msg)
message_data = {
"ToUserName": root.find("ToUserName").text,
"FromUserName": root.find("FromUserName").text,
"CreateTime": int(root.find("CreateTime").text),
"MsgType": root.find("MsgType").text,
"Content": root.find("Content").text if root.find("Content") is not None else None,
"MsgId": int(root.find("MsgId").text) if root.find("MsgId") is not None else None,
'ToUserName': root.find('ToUserName').text,
'FromUserName': root.find('FromUserName').text,
'CreateTime': int(root.find('CreateTime').text),
'MsgType': root.find('MsgType').text,
'Content': root.find('Content').text
if root.find('Content') is not None
else None,
'MsgId': int(root.find('MsgId').text)
if root.find('MsgId') is not None
else None,
}
return message_data
async def run_task(self, host: str, port: int, *args, **kwargs):
"""
启动 Quart 应用。
"""
await self.app.run_task(host=host, port=port, *args, **kwargs)
def on_message(self, msg_type: str):
"""
注册消息类型处理器。
"""
def decorator(func: Callable[[OAEvent], None]):
if msg_type not in self._message_handlers:
self._message_handlers[msg_type] = []
self._message_handlers[msg_type].append(func)
return func
return decorator
async def _handle_message(self, event: OAEvent):
@@ -170,14 +172,19 @@ class OAClient():
for handler in self._message_handlers[msg_type]:
await handler(event)
async def set_message(self,msg_id:int,content:str):
async def set_message(self, msg_id: int, content: str):
self.generated_content[msg_id] = content
class OAClientForLongerResponse():
def __init__(self,token:str,EncodingAESKey:str,AppID:str,Appsecret:str,LoadingMessage:str):
class OAClientForLongerResponse:
def __init__(
self,
token: str,
EncodingAESKey: str,
AppID: str,
Appsecret: str,
LoadingMessage: str,
):
self.token = token
self.aes = EncodingAESKey
self.appid = AppID
@@ -185,9 +192,14 @@ class OAClientForLongerResponse():
self.base_url = 'https://api.weixin.qq.com'
self.access_token = ''
self.app = Quart(__name__)
self.app.add_url_rule('/callback/command', 'handle_callback', self.handle_callback_request, methods=['GET', 'POST'])
self.app.add_url_rule(
'/callback/command',
'handle_callback',
self.handle_callback_request,
methods=['GET', 'POST'],
)
self._message_handlers = {
"example":[],
'example': [],
}
self.access_token_expiry_time = None
self.loading_message = LoadingMessage
@@ -196,50 +208,55 @@ class OAClientForLongerResponse():
async def handle_callback_request(self):
try:
start_time = time.time()
signature = request.args.get("signature", "")
timestamp = request.args.get("timestamp", "")
nonce = request.args.get("nonce", "")
echostr = request.args.get("echostr", "")
msg_signature = request.args.get("msg_signature", "")
signature = request.args.get('signature', '')
timestamp = request.args.get('timestamp', '')
nonce = request.args.get('nonce', '')
echostr = request.args.get('echostr', '')
msg_signature = request.args.get('msg_signature', '')
if msg_signature is None:
raise Exception("msg_signature不在请求体中")
raise Exception('msg_signature不在请求体中')
if request.method == 'GET':
check_str = "".join(sorted([self.token, timestamp, nonce]))
check_signature = hashlib.sha1(check_str.encode("utf-8")).hexdigest()
return echostr if check_signature == signature else "拒绝请求"
check_str = ''.join(sorted([self.token, timestamp, nonce]))
check_signature = hashlib.sha1(check_str.encode('utf-8')).hexdigest()
return echostr if check_signature == signature else '拒绝请求'
elif request.method == "POST":
elif request.method == 'POST':
encryt_msg = await request.data
wxcpt = WXBizMsgCrypt(self.token, self.aes, self.appid)
ret, xml_msg = wxcpt.DecryptMsg(encryt_msg, msg_signature, timestamp, nonce)
ret, xml_msg = wxcpt.DecryptMsg(
encryt_msg, msg_signature, timestamp, nonce
)
xml_msg = xml_msg.decode('utf-8')
if ret != 0:
raise Exception("消息解密失败")
raise Exception('消息解密失败')
# 解析 XML
root = ET.fromstring(xml_msg)
from_user = root.find("FromUserName").text
to_user = root.find("ToUserName").text
if self.msg_queue.get(from_user) and self.msg_queue[from_user][0]["content"]:
from_user = root.find('FromUserName').text
to_user = root.find('ToUserName').text
if (
self.msg_queue.get(from_user)
and self.msg_queue[from_user][0]['content']
):
queue_top = self.msg_queue[from_user].pop(0)
queue_content = queue_top["content"]
queue_content = queue_top['content']
# 弹出用户消息
if self.user_msg_queue.get(from_user) and self.user_msg_queue[from_user]:
if (
self.user_msg_queue.get(from_user)
and self.user_msg_queue[from_user]
):
self.user_msg_queue[from_user].pop(0)
response_xml = xml_template.format(
to_user=from_user,
from_user=to_user,
create_time=int(time.time()),
content=queue_content
content=queue_content,
)
return response_xml
@@ -248,65 +265,67 @@ class OAClientForLongerResponse():
to_user=from_user,
from_user=to_user,
create_time=int(time.time()),
content=self.loading_message
content=self.loading_message,
)
if self.user_msg_queue.get(from_user) and self.user_msg_queue[from_user][0]["content"]:
if (
self.user_msg_queue.get(from_user)
and self.user_msg_queue[from_user][0]['content']
):
return response_xml
else:
message_data = await self.get_message(xml_msg)
if message_data:
event = OAEvent.from_payload(message_data)
if event:
self.user_msg_queue.setdefault(from_user,[]).append(
self.user_msg_queue.setdefault(from_user, []).append(
{
"content":event.message,
'content': event.message,
}
)
await self._handle_message(event)
return response_xml
except Exception as e:
except Exception:
traceback.print_exc()
async def get_message(self, xml_msg: str):
root = ET.fromstring(xml_msg)
message_data = {
"ToUserName": root.find("ToUserName").text,
"FromUserName": root.find("FromUserName").text,
"CreateTime": int(root.find("CreateTime").text),
"MsgType": root.find("MsgType").text,
"Content": root.find("Content").text if root.find("Content") is not None else None,
"MsgId": int(root.find("MsgId").text) if root.find("MsgId") is not None else None,
'ToUserName': root.find('ToUserName').text,
'FromUserName': root.find('FromUserName').text,
'CreateTime': int(root.find('CreateTime').text),
'MsgType': root.find('MsgType').text,
'Content': root.find('Content').text
if root.find('Content') is not None
else None,
'MsgId': int(root.find('MsgId').text)
if root.find('MsgId') is not None
else None,
}
return message_data
async def run_task(self, host: str, port: int, *args, **kwargs):
"""
启动 Quart 应用。
"""
await self.app.run_task(host=host, port=port, *args, **kwargs)
def on_message(self, msg_type: str):
"""
注册消息类型处理器。
"""
def decorator(func: Callable[[OAEvent], None]):
if msg_type not in self._message_handlers:
self._message_handlers[msg_type] = []
self._message_handlers[msg_type].append(func)
return func
return decorator
async def _handle_message(self, event: OAEvent):
@@ -319,22 +338,13 @@ class OAClientForLongerResponse():
for handler in self._message_handlers[msg_type]:
await handler(event)
async def set_message(self,from_user:int,message_id:int,content:str):
if from_user not in self.msg_queue:
async def set_message(self, from_user: int, message_id: int, content: str):
if from_user not in self.msg_queue:
self.msg_queue[from_user] = []
self.msg_queue[from_user].append(
{
"msg_id":message_id,
"content":content,
'msg_id': message_id,
'content': content,
}
)

View File

@@ -9,7 +9,7 @@ class OAEvent(dict):
"""
@staticmethod
def from_payload(payload: Dict[str, Any]) -> Optional["OAEvent"]:
def from_payload(payload: Dict[str, Any]) -> Optional['OAEvent']:
"""
从微信公众号事件数据构造 `WecomEvent` 对象。
@@ -34,14 +34,14 @@ class OAEvent(dict):
Returns:
str: 事件类型。
"""
return self.get("MsgType", "")
return self.get('MsgType', '')
@property
def picurl(self) -> str:
"""
图片链接
"""
return self.get("PicUrl","")
return self.get('PicUrl', '')
@property
def detail_type(self) -> str:
@@ -53,8 +53,8 @@ class OAEvent(dict):
Returns:
str: 事件详细类型。
"""
if self.type == "event":
return self.get("Event", "")
if self.type == 'event':
return self.get('Event', '')
return self.type
@property
@@ -65,15 +65,14 @@ class OAEvent(dict):
Returns:
str: 事件名。
"""
return f"{self.type}.{self.detail_type}"
return f'{self.type}.{self.detail_type}'
@property
def user_id(self) -> Optional[str]:
"""
发送方账号
"""
return self.get("FromUserName")
return self.get('FromUserName')
@property
def receiver_id(self) -> Optional[str]:
@@ -83,7 +82,7 @@ class OAEvent(dict):
Returns:
Optional[str]: 接收者 ID。
"""
return self.get("ToUserName")
return self.get('ToUserName')
@property
def message_id(self) -> Optional[str]:
@@ -93,7 +92,7 @@ class OAEvent(dict):
Returns:
Optional[str]: 消息 ID。
"""
return self.get("MsgId")
return self.get('MsgId')
@property
def message(self) -> Optional[str]:
@@ -103,7 +102,7 @@ class OAEvent(dict):
Returns:
Optional[str]: 消息内容。
"""
return self.get("Content")
return self.get('Content')
@property
def media_id(self) -> Optional[str]:
@@ -113,7 +112,7 @@ class OAEvent(dict):
Returns:
Optional[str]: 媒体文件 ID。
"""
return self.get("MediaId")
return self.get('MediaId')
@property
def timestamp(self) -> Optional[int]:
@@ -123,7 +122,7 @@ class OAEvent(dict):
Returns:
Optional[int]: 时间戳。
"""
return self.get("CreateTime")
return self.get('CreateTime')
@property
def event_key(self) -> Optional[str]:
@@ -133,7 +132,7 @@ class OAEvent(dict):
Returns:
Optional[str]: 事件 Key。
"""
return self.get("EventKey")
return self.get('EventKey')
def __getattr__(self, key: str) -> Optional[Any]:
"""
@@ -164,4 +163,4 @@ class OAEvent(dict):
Returns:
str: 字符串表示。
"""
return f"<WecomEvent {super().__repr__()}>"
return f'<WecomEvent {super().__repr__()}>'

View File

@@ -1,24 +1,16 @@
import time
from quart import request
import base64
import binascii
import httpx
from quart import Quart
import xml.etree.ElementTree as ET
from typing import Callable, Dict, Any
from pkg.platform.types import events as platform_events, message as platform_message
import aiofiles
from pkg.platform.types import events as platform_events
from .qqofficialevent import QQOfficialEvent
import json
import hmac
import base64
import hashlib
import traceback
from cryptography.hazmat.primitives.asymmetric import ed25519
from .qqofficialevent import QQOfficialEvent
def handle_validation(body: dict, bot_secret: str):
# bot正确的secert是32位的此处仅为了适配演示demo
while len(bot_secret) < 32:
bot_secret = bot_secret * 2
@@ -36,29 +28,26 @@ def handle_validation(body: dict, bot_secret: str):
signature_hex = signature.hex()
response = {
"plain_token": body['d']['plain_token'],
"signature": signature_hex
}
response = {'plain_token': body['d']['plain_token'], 'signature': signature_hex}
return response
class QQOfficialClient:
def __init__(self, secret: str, token: str, app_id: str):
self.app = Quart(__name__)
self.app.add_url_rule(
"/callback/command",
"handle_callback",
'/callback/command',
'handle_callback',
self.handle_callback_request,
methods=["GET", "POST"],
methods=['GET', 'POST'],
)
self.secret = secret
self.token = token
self.app_id = app_id
self._message_handlers = {
}
self.base_url = "https://api.sgroup.qq.com"
self.access_token = ""
self._message_handlers = {}
self.base_url = 'https://api.sgroup.qq.com'
self.access_token = ''
self.access_token_expiry_time = None
async def check_access_token(self):
@@ -66,30 +55,29 @@ class QQOfficialClient:
if not self.access_token or await self.is_token_expired():
return False
return bool(self.access_token and self.access_token.strip())
async def get_access_token(self):
"""获取access_token"""
url = "https://bots.qq.com/app/getAppAccessToken"
url = 'https://bots.qq.com/app/getAppAccessToken'
async with httpx.AsyncClient() as client:
params = {
"appId":self.app_id,
"clientSecret":self.secret,
'appId': self.app_id,
'clientSecret': self.secret,
}
headers = {
"content-type":"application/json",
'content-type': 'application/json',
}
try:
response = await client.post(url,json=params,headers=headers)
response = await client.post(url, json=params, headers=headers)
if response.status_code == 200:
response_data = response.json()
access_token = response_data.get("access_token")
expires_in = int(response_data.get("expires_in",7200))
access_token = response_data.get('access_token')
expires_in = int(response_data.get('expires_in', 7200))
self.access_token_expiry_time = time.time() + expires_in - 60
if access_token:
self.access_token = access_token
except Exception as e:
raise Exception(f"获取access_token失败: {e}")
raise Exception(f'获取access_token失败: {e}')
async def handle_callback_request(self):
"""处理回调请求"""
@@ -98,27 +86,24 @@ class QQOfficialClient:
body = await request.get_data()
payload = json.loads(body)
# 验证是否为回调验证请求
if payload.get("op") == 13:
if payload.get('op') == 13:
# 生成签名
response = handle_validation(payload, self.secret)
return response
if payload.get("op") == 0:
message_data = await self.get_message(payload)
if message_data:
event = QQOfficialEvent.from_payload(message_data)
await self._handle_message(event)
return {"code": 0, "message": "success"}
if payload.get('op') == 0:
message_data = await self.get_message(payload)
if message_data:
event = QQOfficialEvent.from_payload(message_data)
await self._handle_message(event)
return {'code': 0, 'message': 'success'}
except Exception as e:
traceback.print_exc()
return {"error": str(e)}, 400
return {'error': str(e)}, 400
async def run_task(self, host: str, port: int, *args, **kwargs):
"""启动 Quart 应用"""
@@ -135,133 +120,140 @@ class QQOfficialClient:
return decorator
async def _handle_message(self, event:QQOfficialEvent):
async def _handle_message(self, event: QQOfficialEvent):
"""处理消息事件"""
msg_type = event.t
if msg_type in self._message_handlers:
for handler in self._message_handlers[msg_type]:
await handler(event)
async def get_message(self,msg:dict) -> Dict[str,Any]:
async def get_message(self, msg: dict) -> Dict[str, Any]:
"""获取消息"""
message_data = {
"t": msg.get("t",{}),
"user_openid": msg.get("d",{}).get("author",{}).get("user_openid",{}),
"timestamp": msg.get("d",{}).get("timestamp",{}),
"d_author_id": msg.get("d",{}).get("author",{}).get("id",{}),
"content": msg.get("d",{}).get("content",{}),
"d_id": msg.get("d",{}).get("id",{}),
"id": msg.get("id",{}),
"channel_id": msg.get("d",{}).get("channel_id",{}),
"username": msg.get("d",{}).get("author",{}).get("username",{}),
"guild_id": msg.get("d",{}).get("guild_id",{}),
"member_openid": msg.get("d",{}).get("author",{}).get("openid",{}),
"group_openid": msg.get("d",{}).get("group_openid",{})
't': msg.get('t', {}),
'user_openid': msg.get('d', {}).get('author', {}).get('user_openid', {}),
'timestamp': msg.get('d', {}).get('timestamp', {}),
'd_author_id': msg.get('d', {}).get('author', {}).get('id', {}),
'content': msg.get('d', {}).get('content', {}),
'd_id': msg.get('d', {}).get('id', {}),
'id': msg.get('id', {}),
'channel_id': msg.get('d', {}).get('channel_id', {}),
'username': msg.get('d', {}).get('author', {}).get('username', {}),
'guild_id': msg.get('d', {}).get('guild_id', {}),
'member_openid': msg.get('d', {}).get('author', {}).get('openid', {}),
'group_openid': msg.get('d', {}).get('group_openid', {}),
}
attachments = msg.get("d", {}).get("attachments", [])
image_attachments = [attachment['url'] for attachment in attachments if await self.is_image(attachment)]
image_attachments_type = [attachment['content_type'] for attachment in attachments if await self.is_image(attachment)]
attachments = msg.get('d', {}).get('attachments', [])
image_attachments = [
attachment['url']
for attachment in attachments
if await self.is_image(attachment)
]
image_attachments_type = [
attachment['content_type']
for attachment in attachments
if await self.is_image(attachment)
]
if image_attachments:
message_data["image_attachments"] = image_attachments[0]
message_data["content_type"] = image_attachments_type[0]
message_data['image_attachments'] = image_attachments[0]
message_data['content_type'] = image_attachments_type[0]
else:
message_data["image_attachments"] = None
return message_data
message_data['image_attachments'] = None
async def is_image(self,attachment:dict) -> bool:
return message_data
async def is_image(self, attachment: dict) -> bool:
"""判断是否为图片附件"""
content_type = attachment.get("content_type","")
return content_type.startswith("image/")
async def send_private_text_msg(self,user_openid:str,content:str,msg_id:str):
content_type = attachment.get('content_type', '')
return content_type.startswith('image/')
async def send_private_text_msg(self, user_openid: str, content: str, msg_id: str):
"""发送私聊消息"""
if not await self.check_access_token():
await self.get_access_token()
await self.get_access_token()
url = self.base_url + "/v2/users/" + user_openid + "/messages"
url = self.base_url + '/v2/users/' + user_openid + '/messages'
async with httpx.AsyncClient() as client:
headers = {
"Authorization": f"QQBot {self.access_token}",
"Content-Type": "application/json",
'Authorization': f'QQBot {self.access_token}',
'Content-Type': 'application/json',
}
data = {
"content": content,
"msg_type": 0,
"msg_id": msg_id,
'content': content,
'msg_type': 0,
'msg_id': msg_id,
}
response = await client.post(url,headers=headers,json=data)
response = await client.post(url, headers=headers, json=data)
if response.status_code == 200:
return
else:
raise ValueError(response)
async def send_group_text_msg(self,group_openid:str,content:str,msg_id:str):
async def send_group_text_msg(self, group_openid: str, content: str, msg_id: str):
"""发送群聊消息"""
if not await self.check_access_token():
await self.get_access_token()
url = self.base_url + "/v2/groups/" + group_openid + "/messages"
url = self.base_url + '/v2/groups/' + group_openid + '/messages'
async with httpx.AsyncClient() as client:
headers = {
"Authorization": f"QQBot {self.access_token}",
"Content-Type": "application/json",
'Authorization': f'QQBot {self.access_token}',
'Content-Type': 'application/json',
}
data = {
"content": content,
"msg_type": 0,
"msg_id": msg_id,
'content': content,
'msg_type': 0,
'msg_id': msg_id,
}
response = await client.post(url,headers=headers,json=data)
response = await client.post(url, headers=headers, json=data)
if response.status_code == 200:
return
else:
raise Exception(response.read().decode())
async def send_channle_group_text_msg(self,channel_id:str,content:str,msg_id:str):
async def send_channle_group_text_msg(
self, channel_id: str, content: str, msg_id: str
):
"""发送频道群聊消息"""
if not await self.check_access_token():
await self.get_access_token()
await self.get_access_token()
url = self.base_url + "/channels/" + channel_id + "/messages"
url = self.base_url + '/channels/' + channel_id + '/messages'
async with httpx.AsyncClient() as client:
headers = {
"Authorization": f"QQBot {self.access_token}",
"Content-Type": "application/json",
'Authorization': f'QQBot {self.access_token}',
'Content-Type': 'application/json',
}
params = {
"content": content,
"msg_type": 0,
"msg_id": msg_id,
'content': content,
'msg_type': 0,
'msg_id': msg_id,
}
response = await client.post(url,headers=headers,json=params)
response = await client.post(url, headers=headers, json=params)
if response.status_code == 200:
return True
else:
raise Exception(response)
async def send_channle_private_text_msg(self,guild_id:str,content:str,msg_id:str):
async def send_channle_private_text_msg(
self, guild_id: str, content: str, msg_id: str
):
"""发送频道私聊消息"""
if not await self.check_access_token():
await self.get_access_token()
await self.get_access_token()
url = self.base_url + "/dms/" + guild_id + "/messages"
url = self.base_url + '/dms/' + guild_id + '/messages'
async with httpx.AsyncClient() as client:
headers = {
"Authorization": f"QQBot {self.access_token}",
"Content-Type": "application/json",
'Authorization': f'QQBot {self.access_token}',
'Content-Type': 'application/json',
}
params = {
"content": content,
"msg_type": 0,
"msg_id": msg_id,
'content': content,
'msg_type': 0,
'msg_id': msg_id,
}
response = await client.post(url,headers=headers,json=params)
response = await client.post(url, headers=headers, json=params)
if response.status_code == 200:
return True
else:

View File

@@ -1,114 +1,112 @@
from typing import Dict, Any, Optional
class QQOfficialEvent(dict):
@staticmethod
def from_payload(payload: Dict[str, Any]) -> Optional["QQOfficialEvent"]:
def from_payload(payload: Dict[str, Any]) -> Optional['QQOfficialEvent']:
try:
event = QQOfficialEvent(payload)
return event
except KeyError:
return None
@property
def t(self) -> str:
"""
事件类型
"""
return self.get("t", "")
return self.get('t', '')
@property
def user_openid(self) -> str:
"""
用户openid
"""
return self.get("user_openid",{})
return self.get('user_openid', {})
@property
def timestamp(self) -> str:
"""
时间戳
"""
return self.get("timestamp",{})
return self.get('timestamp', {})
@property
def d_author_id(self) -> str:
"""
作者id
"""
return self.get("id",{})
return self.get('id', {})
@property
def content(self) -> str:
"""
内容
"""
return self.get("content",'')
return self.get('content', '')
@property
def d_id(self) -> str:
"""
d_id
"""
return self.get("d_id",{})
return self.get('d_id', {})
@property
def id(self) -> str:
"""
消息idmsg_id
"""
return self.get("id",{})
return self.get('id', {})
@property
def channel_id(self) -> str:
"""
频道id
"""
return self.get("channel_id",{})
return self.get('channel_id', {})
@property
def username(self) -> str:
"""
用户名
"""
return self.get("username",{})
return self.get('username', {})
@property
def guild_id(self) -> str:
"""
频道id
"""
return self.get("guild_id",{})
return self.get('guild_id', {})
@property
def member_openid(self) -> str:
"""
成员openid
"""
return self.get("openid",{})
return self.get('openid', {})
@property
def attachments(self) -> str:
"""
附件url
"""
url = self.get("image_attachments", "")
if url and not url.startswith("https://"):
url = "https://" + url
url = self.get('image_attachments', '')
if url and not url.startswith('https://'):
url = 'https://' + url
return url
@property
def group_openid(self) -> str:
"""
群组id
"""
return self.get("group_openid",{})
return self.get('group_openid', {})
@property
def content_type(self) -> str:
"""
文件类型
"""
return self.get("content_type","")
return self.get('content_type', '')

View File

@@ -1,10 +1,11 @@
#!/usr/bin/env python
# -*- encoding:utf-8 -*-
""" 对企业微信发送给企业后台的消息加解密示例代码.
"""对企业微信发送给企业后台的消息加解密示例代码.
@copyright: Copyright (c) 1998-2014 Tencent Inc.
"""
# ------------------------------------------------------------------------
import logging
import base64
@@ -49,7 +50,7 @@ class SHA1:
sortlist = [token, timestamp, nonce, encrypt]
sortlist.sort()
sha = hashlib.sha1()
sha.update("".join(sortlist).encode())
sha.update(''.join(sortlist).encode())
return ierror.WXBizMsgCrypt_OK, sha.hexdigest()
except Exception as e:
logger = logging.getLogger()
@@ -75,7 +76,7 @@ class XMLParse:
"""
try:
xml_tree = ET.fromstring(xmltext)
encrypt = xml_tree.find("Encrypt")
encrypt = xml_tree.find('Encrypt')
return ierror.WXBizMsgCrypt_OK, encrypt.text
except Exception as e:
logger = logging.getLogger()
@@ -100,13 +101,13 @@ class XMLParse:
return resp_xml
class PKCS7Encoder():
class PKCS7Encoder:
"""提供基于PKCS7算法的加解密接口"""
block_size = 32
def encode(self, text):
""" 对需要加密的明文进行填充补位
"""对需要加密的明文进行填充补位
@param text: 需要进行填充补位操作的明文
@return: 补齐明文字符串
"""
@@ -134,7 +135,6 @@ class Prpcrypt(object):
"""提供接收和推送给企业微信消息的加解密接口"""
def __init__(self, key):
# self.key = base64.b64decode(key+"=")
self.key = key
# 设置加解密模式为AES的CBC模式
@@ -147,7 +147,12 @@ class Prpcrypt(object):
"""
# 16位随机字符串添加到明文开头
text = text.encode()
text = self.get_random_str() + struct.pack("I", socket.htonl(len(text))) + text + receiveid.encode()
text = (
self.get_random_str()
+ struct.pack('I', socket.htonl(len(text)))
+ text
+ receiveid.encode()
)
# 使用自定义的填充方式对明文进行补位填充
pkcs7 = PKCS7Encoder()
@@ -183,9 +188,9 @@ class Prpcrypt(object):
# plain_text = pkcs7.encode(plain_text)
# 去除16位随机字符串
content = plain_text[16:-pad]
xml_len = socket.ntohl(struct.unpack("I", content[: 4])[0])
xml_content = content[4: xml_len + 4]
from_receiveid = content[xml_len + 4:]
xml_len = socket.ntohl(struct.unpack('I', content[:4])[0])
xml_content = content[4 : xml_len + 4]
from_receiveid = content[xml_len + 4 :]
except Exception as e:
logger = logging.getLogger()
logger.error(e)
@@ -196,7 +201,7 @@ class Prpcrypt(object):
return 0, xml_content
def get_random_str(self):
""" 随机生成16位字符串
"""随机生成16位字符串
@return: 16位字符串
"""
return str(random.randint(1000000000000000, 9999999999999999)).encode()
@@ -206,10 +211,10 @@ class WXBizMsgCrypt(object):
# 构造函数
def __init__(self, sToken, sEncodingAESKey, sReceiveId):
try:
self.key = base64.b64decode(sEncodingAESKey + "=")
self.key = base64.b64decode(sEncodingAESKey + '=')
assert len(self.key) == 32
except:
throw_exception("[error]: EncodingAESKey unvalid !", FormatException)
except Exception:
throw_exception('[error]: EncodingAESKey unvalid !', FormatException)
# return ierror.WXBizMsgCrypt_IllegalAesKey,None
self.m_sToken = sToken
self.m_sReceiveId = sReceiveId

View File

@@ -7,15 +7,22 @@ from quart import Quart
import xml.etree.ElementTree as ET
from typing import Callable, Dict, Any
from .wecomevent import WecomEvent
from pkg.platform.types import events as platform_events, message as platform_message
from pkg.platform.types import message as platform_message
import aiofiles
class WecomClient():
def __init__(self,corpid:str,secret:str,token:str,EncodingAESKey:str,contacts_secret:str):
class WecomClient:
def __init__(
self,
corpid: str,
secret: str,
token: str,
EncodingAESKey: str,
contacts_secret: str,
):
self.corpid = corpid
self.secret = secret
self.access_token_for_contacts =''
self.access_token_for_contacts = ''
self.token = token
self.aes = EncodingAESKey
self.base_url = 'https://qyapi.weixin.qq.com/cgi-bin'
@@ -23,19 +30,26 @@ class WecomClient():
self.secret_for_contacts = contacts_secret
self.app = Quart(__name__)
self.wxcpt = WXBizMsgCrypt(self.token, self.aes, self.corpid)
self.app.add_url_rule('/callback/command', 'handle_callback', self.handle_callback_request, methods=['GET', 'POST'])
self.app.add_url_rule(
'/callback/command',
'handle_callback',
self.handle_callback_request,
methods=['GET', 'POST'],
)
self._message_handlers = {
"example":[],
'example': [],
}
#access——token操作
# access——token操作
async def check_access_token(self):
return bool(self.access_token and self.access_token.strip())
async def check_access_token_for_contacts(self):
return bool(self.access_token_for_contacts and self.access_token_for_contacts.strip())
return bool(
self.access_token_for_contacts and self.access_token_for_contacts.strip()
)
async def get_access_token(self,secret):
async def get_access_token(self, secret):
url = f'https://qyapi.weixin.qq.com/cgi-bin/gettoken?corpid={self.corpid}&corpsecret={secret}'
async with httpx.AsyncClient() as client:
response = await client.get(url)
@@ -43,146 +57,163 @@ class WecomClient():
if 'access_token' in data:
return data['access_token']
else:
raise Exception(f"未获取access token: {data}")
raise Exception(f'未获取access token: {data}')
async def get_users(self):
if not self.check_access_token_for_contacts():
self.access_token_for_contacts = await self.get_access_token(self.secret_for_contacts)
self.access_token_for_contacts = await self.get_access_token(
self.secret_for_contacts
)
url = self.base_url+'/user/list_id?access_token='+self.access_token_for_contacts
url = (
self.base_url
+ '/user/list_id?access_token='
+ self.access_token_for_contacts
)
async with httpx.AsyncClient() as client:
params = {
"cursor":"",
"limit":10000,
'cursor': '',
'limit': 10000,
}
response = await client.post(url,json=params)
response = await client.post(url, json=params)
data = response.json()
if data['errcode'] == 0:
dept_users = data['dept_user']
userid = []
for user in dept_users:
userid.append(user["userid"])
userid.append(user['userid'])
return userid
else:
raise Exception("未获取用户")
async def send_to_all(self,content:str,agent_id:int):
if not self.check_access_token_for_contacts():
self.access_token_for_contacts = await self.get_access_token(self.secret_for_contacts)
raise Exception('未获取用户')
url = self.base_url+'/message/send?access_token='+self.access_token_for_contacts
async def send_to_all(self, content: str, agent_id: int):
if not self.check_access_token_for_contacts():
self.access_token_for_contacts = await self.get_access_token(
self.secret_for_contacts
)
url = (
self.base_url
+ '/message/send?access_token='
+ self.access_token_for_contacts
)
user_ids = await self.get_users()
user_ids_string = "|".join(user_ids)
user_ids_string = '|'.join(user_ids)
async with httpx.AsyncClient() as client:
params = {
"touser" : user_ids_string,
"msgtype" : "text",
"agentid" : agent_id,
"text" : {
"content" : content,
},
"safe":0,
"enable_id_trans": 0,
"enable_duplicate_check": 0,
"duplicate_check_interval": 1800
'touser': user_ids_string,
'msgtype': 'text',
'agentid': agent_id,
'text': {
'content': content,
},
'safe': 0,
'enable_id_trans': 0,
'enable_duplicate_check': 0,
'duplicate_check_interval': 1800,
}
response = await client.post(url,json=params)
response = await client.post(url, json=params)
data = response.json()
if data['errcode'] != 0:
raise Exception("Failed to send message: "+str(data))
raise Exception('Failed to send message: ' + str(data))
async def send_image(self,user_id:str,agent_id:int,media_id:str):
async def send_image(self, user_id: str, agent_id: int, media_id: str):
if not await self.check_access_token():
self.access_token = await self.get_access_token(self.secret)
url = self.base_url+'/media/upload?access_token='+self.access_token
url = self.base_url + '/media/upload?access_token=' + self.access_token
async with httpx.AsyncClient() as client:
params = {
"touser" : user_id,
"toparty" : "",
"totag":"",
"agentid" : agent_id,
"msgtype" : "image",
"image" : {
"media_id" : media_id,
'touser': user_id,
'toparty': '',
'totag': '',
'agentid': agent_id,
'msgtype': 'image',
'image': {
'media_id': media_id,
},
"safe":0,
"enable_id_trans": 0,
"enable_duplicate_check": 0,
"duplicate_check_interval": 1800
'safe': 0,
'enable_id_trans': 0,
'enable_duplicate_check': 0,
'duplicate_check_interval': 1800,
}
try:
response = await client.post(url,json=params)
response = await client.post(url, json=params)
data = response.json()
except Exception as e:
raise Exception("Failed to send image: "+str(e))
raise Exception('Failed to send image: ' + str(e))
# 企业微信错误码40014和42001代表accesstoken问题
if data['errcode'] == 40014 or data['errcode'] == 42001:
self.access_token = await self.get_access_token(self.secret)
return await self.send_image(user_id,agent_id,media_id)
return await self.send_image(user_id, agent_id, media_id)
if data['errcode'] != 0:
raise Exception("Failed to send image: "+str(data))
async def send_private_msg(self,user_id:str, agent_id:int,content:str):
raise Exception('Failed to send image: ' + str(data))
async def send_private_msg(self, user_id: str, agent_id: int, content: str):
if not await self.check_access_token():
self.access_token = await self.get_access_token(self.secret)
url = self.base_url+'/message/send?access_token='+self.access_token
url = self.base_url + '/message/send?access_token=' + self.access_token
async with httpx.AsyncClient() as client:
params={
"touser" : user_id,
"msgtype" : "text",
"agentid" : agent_id,
"text" : {
"content" : content,
params = {
'touser': user_id,
'msgtype': 'text',
'agentid': agent_id,
'text': {
'content': content,
},
"safe":0,
"enable_id_trans": 0,
"enable_duplicate_check": 0,
"duplicate_check_interval": 1800
'safe': 0,
'enable_id_trans': 0,
'enable_duplicate_check': 0,
'duplicate_check_interval': 1800,
}
response = await client.post(url,json=params)
response = await client.post(url, json=params)
data = response.json()
if data['errcode'] == 40014 or data['errcode'] == 42001:
self.access_token = await self.get_access_token(self.secret)
return await self.send_private_msg(user_id,agent_id,content)
return await self.send_private_msg(user_id, agent_id, content)
if data['errcode'] != 0:
raise Exception("Failed to send message: "+str(data))
raise Exception('Failed to send message: ' + str(data))
async def handle_callback_request(self):
"""
处理回调请求,包括 GET 验证和 POST 消息接收。
"""
try:
msg_signature = request.args.get('msg_signature')
timestamp = request.args.get('timestamp')
nonce = request.args.get('nonce')
msg_signature = request.args.get("msg_signature")
timestamp = request.args.get("timestamp")
nonce = request.args.get("nonce")
if request.method == "GET":
echostr = request.args.get("echostr")
ret, reply_echo_str = self.wxcpt.VerifyURL(msg_signature, timestamp, nonce, echostr)
if request.method == 'GET':
echostr = request.args.get('echostr')
ret, reply_echo_str = self.wxcpt.VerifyURL(
msg_signature, timestamp, nonce, echostr
)
if ret != 0:
raise Exception(f"验证失败,错误码: {ret}")
raise Exception(f'验证失败,错误码: {ret}')
return reply_echo_str
elif request.method == "POST":
elif request.method == 'POST':
encrypt_msg = await request.data
ret, xml_msg = self.wxcpt.DecryptMsg(encrypt_msg, msg_signature, timestamp, nonce)
ret, xml_msg = self.wxcpt.DecryptMsg(
encrypt_msg, msg_signature, timestamp, nonce
)
if ret != 0:
raise Exception(f"消息解密失败,错误码: {ret}")
raise Exception(f'消息解密失败,错误码: {ret}')
# 解析消息并处理
message_data = await self.get_message(xml_msg)
if message_data:
event = WecomEvent.from_payload(message_data) # 转换为 WecomEvent 对象
event = WecomEvent.from_payload(
message_data
) # 转换为 WecomEvent 对象
if event:
await self._handle_message(event)
return "success"
return 'success'
except Exception as e:
return f"Error processing request: {str(e)}", 400
return f'Error processing request: {str(e)}', 400
async def run_task(self, host: str, port: int, *args, **kwargs):
"""
@@ -194,11 +225,13 @@ class WecomClient():
"""
注册消息类型处理器。
"""
def decorator(func: Callable[[WecomEvent], None]):
if msg_type not in self._message_handlers:
self._message_handlers[msg_type] = []
self._message_handlers[msg_type].append(func)
return func
return decorator
async def _handle_message(self, event: WecomEvent):
@@ -216,38 +249,47 @@ class WecomClient():
"""
root = ET.fromstring(xml_msg)
message_data = {
"ToUserName": root.find("ToUserName").text,
"FromUserName": root.find("FromUserName").text,
"CreateTime": int(root.find("CreateTime").text),
"MsgType": root.find("MsgType").text,
"Content": root.find("Content").text if root.find("Content") is not None else None,
"MsgId": int(root.find("MsgId").text) if root.find("MsgId") is not None else None,
"AgentID": int(root.find("AgentID").text) if root.find("AgentID") is not None else None,
'ToUserName': root.find('ToUserName').text,
'FromUserName': root.find('FromUserName').text,
'CreateTime': int(root.find('CreateTime').text),
'MsgType': root.find('MsgType').text,
'Content': root.find('Content').text
if root.find('Content') is not None
else None,
'MsgId': int(root.find('MsgId').text)
if root.find('MsgId') is not None
else None,
'AgentID': int(root.find('AgentID').text)
if root.find('AgentID') is not None
else None,
}
if message_data["MsgType"] == "image":
message_data["MediaId"] = root.find("MediaId").text if root.find("MediaId") is not None else None
message_data["PicUrl"] = root.find("PicUrl").text if root.find("PicUrl") is not None else None
if message_data['MsgType'] == 'image':
message_data['MediaId'] = (
root.find('MediaId').text if root.find('MediaId') is not None else None
)
message_data['PicUrl'] = (
root.find('PicUrl').text if root.find('PicUrl') is not None else None
)
return message_data
@staticmethod
async def get_image_type(image_bytes: bytes) -> str:
"""
通过图片的magic numbers判断图片类型
"""
magic_numbers = {
b'\xFF\xD8\xFF': 'jpg',
b'\x89\x50\x4E\x47': 'png',
b'\xff\xd8\xff': 'jpg',
b'\x89\x50\x4e\x47': 'png',
b'\x47\x49\x46': 'gif',
b'\x42\x4D': 'bmp',
b'\x00\x00\x01\x00': 'ico'
b'\x42\x4d': 'bmp',
b'\x00\x00\x01\x00': 'ico',
}
for magic, ext in magic_numbers.items():
if image_bytes.startswith(magic):
return ext
return 'jpg' # 默认返回jpg
async def upload_to_work(self, image: platform_message.Image):
"""
@@ -256,9 +298,14 @@ class WecomClient():
if not await self.check_access_token():
self.access_token = await self.get_access_token(self.secret)
url = self.base_url + '/media/upload?access_token=' + self.access_token + '&type=file'
url = (
self.base_url
+ '/media/upload?access_token='
+ self.access_token
+ '&type=file'
)
file_bytes = None
file_name = "uploaded_file.txt"
file_name = 'uploaded_file.txt'
# 获取文件的二进制数据
if image.path:
@@ -277,20 +324,22 @@ class WecomClient():
padded_base64 = base64_data + '=' * padding
file_bytes = base64.b64decode(padded_base64)
except binascii.Error as e:
raise ValueError(f"Invalid base64 string: {str(e)}")
raise ValueError(f'Invalid base64 string: {str(e)}')
else:
raise ValueError("image对象出错")
raise ValueError('image对象出错')
# 设置 multipart/form-data 格式的文件
boundary = "-------------------------acebdf13572468"
headers = {
'Content-Type': f'multipart/form-data; boundary={boundary}'
}
boundary = '-------------------------acebdf13572468'
headers = {'Content-Type': f'multipart/form-data; boundary={boundary}'}
body = (
f"--{boundary}\r\n"
f"Content-Disposition: form-data; name=\"media\"; filename=\"{file_name}\"; filelength={len(file_bytes)}\r\n"
f"Content-Type: application/octet-stream\r\n\r\n"
).encode('utf-8') + file_bytes + f"\r\n--{boundary}--\r\n".encode('utf-8')
(
f'--{boundary}\r\n'
f'Content-Disposition: form-data; name="media"; filename="{file_name}"; filelength={len(file_bytes)}\r\n'
f'Content-Type: application/octet-stream\r\n\r\n'
).encode('utf-8')
+ file_bytes
+ f'\r\n--{boundary}--\r\n'.encode('utf-8')
)
# 上传文件
async with httpx.AsyncClient() as client:
@@ -300,19 +349,18 @@ class WecomClient():
self.access_token = await self.get_access_token(self.secret)
media_id = await self.upload_to_work(image)
if data.get('errcode', 0) != 0:
raise Exception("failed to upload file")
raise Exception('failed to upload file')
media_id = data.get('media_id')
return media_id
async def download_image_to_bytes(self,url:str) -> bytes:
async def download_image_to_bytes(self, url: str) -> bytes:
async with httpx.AsyncClient() as client:
response = await client.get(url)
response.raise_for_status()
return response.content
#进行media_id的获取
# 进行media_id的获取
async def get_media_id(self, image: platform_message.Image):
media_id = await self.upload_to_work(image=image)
return media_id

View File

@@ -4,7 +4,7 @@
# Author: jonyqin
# Created Time: Thu 11 Sep 2014 01:53:58 PM CST
# File Name: ierror.py
# Description:定义错误码含义
# Description:定义错误码含义
#########################################################################
WXBizMsgCrypt_OK = 0
WXBizMsgCrypt_ValidateSignature_Error = -40001
@@ -17,4 +17,4 @@ WXBizMsgCrypt_DecryptAES_Error = -40007
WXBizMsgCrypt_IllegalBuffer = -40008
WXBizMsgCrypt_EncodeBase64_Error = -40009
WXBizMsgCrypt_DecodeBase64_Error = -40010
WXBizMsgCrypt_GenReturnXml_Error = -40011
WXBizMsgCrypt_GenReturnXml_Error = -40011

View File

@@ -9,7 +9,7 @@ class WecomEvent(dict):
"""
@staticmethod
def from_payload(payload: Dict[str, Any]) -> Optional["WecomEvent"]:
def from_payload(payload: Dict[str, Any]) -> Optional['WecomEvent']:
"""
从企业微信事件数据构造 `WecomEvent` 对象。
@@ -34,14 +34,14 @@ class WecomEvent(dict):
Returns:
str: 事件类型。
"""
return self.get("MsgType", "")
return self.get('MsgType', '')
@property
def picurl(self) -> str:
"""
图片链接
"""
return self.get("PicUrl")
return self.get('PicUrl')
@property
def detail_type(self) -> str:
@@ -53,8 +53,8 @@ class WecomEvent(dict):
Returns:
str: 事件详细类型。
"""
if self.type == "event":
return self.get("Event", "")
if self.type == 'event':
return self.get('Event', '')
return self.type
@property
@@ -65,7 +65,7 @@ class WecomEvent(dict):
Returns:
str: 事件名。
"""
return f"{self.type}.{self.detail_type}"
return f'{self.type}.{self.detail_type}'
@property
def user_id(self) -> Optional[str]:
@@ -75,8 +75,8 @@ class WecomEvent(dict):
Returns:
Optional[str]: 用户 ID。
"""
return self.get("FromUserName")
return self.get('FromUserName')
@property
def agent_id(self) -> Optional[int]:
"""
@@ -85,7 +85,7 @@ class WecomEvent(dict):
Returns:
Optional[int]: 机器人 ID。
"""
return self.get("AgentID")
return self.get('AgentID')
@property
def receiver_id(self) -> Optional[str]:
@@ -95,7 +95,7 @@ class WecomEvent(dict):
Returns:
Optional[str]: 接收者 ID。
"""
return self.get("ToUserName")
return self.get('ToUserName')
@property
def message_id(self) -> Optional[str]:
@@ -105,7 +105,7 @@ class WecomEvent(dict):
Returns:
Optional[str]: 消息 ID。
"""
return self.get("MsgId")
return self.get('MsgId')
@property
def message(self) -> Optional[str]:
@@ -115,7 +115,7 @@ class WecomEvent(dict):
Returns:
Optional[str]: 消息内容。
"""
return self.get("Content")
return self.get('Content')
@property
def media_id(self) -> Optional[str]:
@@ -125,7 +125,7 @@ class WecomEvent(dict):
Returns:
Optional[str]: 媒体文件 ID。
"""
return self.get("MediaId")
return self.get('MediaId')
@property
def timestamp(self) -> Optional[int]:
@@ -135,7 +135,7 @@ class WecomEvent(dict):
Returns:
Optional[int]: 时间戳。
"""
return self.get("CreateTime")
return self.get('CreateTime')
@property
def event_key(self) -> Optional[str]:
@@ -145,7 +145,7 @@ class WecomEvent(dict):
Returns:
Optional[str]: 事件 Key。
"""
return self.get("EventKey")
return self.get('EventKey')
def __getattr__(self, key: str) -> Optional[Any]:
"""
@@ -176,4 +176,4 @@ class WecomEvent(dict):
Returns:
str: 字符串表示。
"""
return f"<WecomEvent {super().__repr__()}>"
return f'<WecomEvent {super().__repr__()}>'

29
main.py
View File

@@ -1,3 +1,4 @@
import asyncio
# LangBot 终端启动入口
# 在此层级解决依赖项检查。
# LangBot/main.py
@@ -14,9 +15,6 @@ asciiart = r"""
"""
import asyncio
async def main_entry(loop: asyncio.AbstractEventLoop):
print(asciiart)
@@ -29,20 +27,22 @@ async def main_entry(loop: asyncio.AbstractEventLoop):
missing_deps = await deps.check_deps()
if missing_deps:
print("以下依赖包未安装,将自动安装,请完成后重启程序:")
print('以下依赖包未安装,将自动安装,请完成后重启程序:')
for dep in missing_deps:
print("-", dep)
print('-', dep)
await deps.install_deps(missing_deps)
print("已自动安装缺失的依赖包,请重启程序。")
print('已自动安装缺失的依赖包,请重启程序。')
sys.exit(0)
# check plugin deps
await deps.precheck_plugin_deps()
# 检查pydantic版本如果没有 pydantic.v1则把 pydantic 映射为 v1
import pydantic.version
if pydantic.version.VERSION < '2.0':
import pydantic
sys.modules['pydantic.v1'] = pydantic
# 检查配置文件
@@ -52,11 +52,12 @@ async def main_entry(loop: asyncio.AbstractEventLoop):
generated_files = await files.generate_files()
if generated_files:
print("以下文件不存在,已自动生成:")
print('以下文件不存在,已自动生成:')
for file in generated_files:
print("-", file)
print('-', file)
from pkg.core import boot
await boot.main(loop)
@@ -66,8 +67,8 @@ if __name__ == '__main__':
# 必须大于 3.10.1
if sys.version_info < (3, 10, 1):
print("需要 Python 3.10.1 及以上版本,当前 Python 版本为:", sys.version)
input("按任意键退出...")
print('需要 Python 3.10.1 及以上版本,当前 Python 版本为:', sys.version)
input('按任意键退出...')
exit(1)
# 检查本目录是否有main.py且包含LangBot字符串
@@ -78,11 +79,11 @@ if __name__ == '__main__':
else:
with open('main.py', 'r', encoding='utf-8') as f:
content = f.read()
if "LangBot/main.py" not in content:
if 'LangBot/main.py' not in content:
invalid_pwd = True
if invalid_pwd:
print("请在 LangBot 项目根目录下以命令形式运行此程序。")
input("按任意键退出...")
print('请在 LangBot 项目根目录下以命令形式运行此程序。')
input('按任意键退出...')
exit(1)
loop = asyncio.new_event_loop()

View File

@@ -13,6 +13,7 @@ from ....core import app
preregistered_groups: list[type[RouterGroup]] = []
"""RouterGroup 的预注册列表"""
def group_class(name: str, path: str) -> None:
"""注册一个 RouterGroup"""
@@ -27,12 +28,12 @@ def group_class(name: str, path: str) -> None:
class AuthType(enum.Enum):
"""认证类型"""
NONE = 'none'
USER_TOKEN = 'user-token'
class RouterGroup(abc.ABC):
name: str
path: str
@@ -49,17 +50,24 @@ class RouterGroup(abc.ABC):
async def initialize(self) -> None:
pass
def route(self, rule: str, auth_type: AuthType = AuthType.USER_TOKEN, **options: typing.Any) -> typing.Callable[[RouteCallable], RouteCallable]: # decorator
def route(
self,
rule: str,
auth_type: AuthType = AuthType.USER_TOKEN,
**options: typing.Any,
) -> typing.Callable[[RouteCallable], RouteCallable]: # decorator
"""注册一个路由"""
def decorator(f: RouteCallable) -> RouteCallable:
nonlocal rule
rule = self.path + rule
async def handler_error(*args, **kwargs):
if auth_type == AuthType.USER_TOKEN:
# 从Authorization头中获取token
token = quart.request.headers.get('Authorization', '').replace('Bearer ', '')
token = quart.request.headers.get('Authorization', '').replace(
'Bearer ', ''
)
if not token:
return self.http_status(401, -1, '未提供有效的用户令牌')
@@ -75,11 +83,11 @@ class RouterGroup(abc.ABC):
try:
return await f(*args, **kwargs)
except Exception as e: # 自动 500
except Exception: # 自动 500
traceback.print_exc()
# return self.http_status(500, -2, str(e))
return self.http_status(500, -2, 'internal server error')
new_f = handler_error
new_f.__name__ = (self.name + rule).replace('/', '__')
new_f.__doc__ = f.__doc__
@@ -91,20 +99,24 @@ class RouterGroup(abc.ABC):
def success(self, data: typing.Any = None) -> quart.Response:
"""返回一个 200 响应"""
return quart.jsonify({
'code': 0,
'msg': 'ok',
'data': data,
})
return quart.jsonify(
{
'code': 0,
'msg': 'ok',
'data': data,
}
)
def fail(self, code: int, msg: str) -> quart.Response:
"""返回一个异常响应"""
return quart.jsonify({
'code': code,
'msg': msg,
})
return quart.jsonify(
{
'code': code,
'msg': msg,
}
)
def http_status(self, status: int, code: int, msg: str) -> quart.Response:
"""返回一个指定状态码的响应"""
return self.fail(code, msg), status

View File

@@ -1,32 +1,29 @@
from __future__ import annotations
import traceback
import quart
from .....core import app
from .. import group
@group.group_class('logs', '/api/v1/logs')
class LogsRouterGroup(group.RouterGroup):
async def initialize(self) -> None:
@self.route('', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str:
start_page_number = int(quart.request.args.get('start_page_number', 0))
start_offset = int(quart.request.args.get('start_offset', 0))
logs_str, end_page_number, end_offset = self.ap.log_cache.get_log_by_pointer(
start_page_number=start_page_number,
start_offset=start_offset
logs_str, end_page_number, end_offset = (
self.ap.log_cache.get_log_by_pointer(
start_page_number=start_page_number, start_offset=start_offset
)
)
return self.success(
data={
"logs": logs_str,
"end_page_number": end_page_number,
"end_offset": end_offset
'logs': logs_str,
'end_page_number': end_page_number,
'end_offset': end_offset,
}
)

View File

@@ -3,46 +3,41 @@ from __future__ import annotations
import quart
from .. import group
from .....entity.persistence import pipeline
@group.group_class('pipelines', '/api/v1/pipelines')
class PipelinesRouterGroup(group.RouterGroup):
async def initialize(self) -> None:
@self.route('', methods=['GET', 'POST'])
async def _() -> str:
if quart.request.method == 'GET':
return self.success(data={
'pipelines': await self.ap.pipeline_service.get_pipelines()
})
return self.success(
data={'pipelines': await self.ap.pipeline_service.get_pipelines()}
)
elif quart.request.method == 'POST':
json_data = await quart.request.json
pipeline_uuid = await self.ap.pipeline_service.create_pipeline(json_data)
pipeline_uuid = await self.ap.pipeline_service.create_pipeline(
json_data
)
return self.success(data={
'uuid': pipeline_uuid
})
return self.success(data={'uuid': pipeline_uuid})
@self.route('/_/metadata', methods=['GET'])
async def _() -> str:
return self.success(data={
'configs': await self.ap.pipeline_service.get_pipeline_metadata()
})
return self.success(
data={'configs': await self.ap.pipeline_service.get_pipeline_metadata()}
)
@self.route('/<pipeline_uuid>', methods=['GET', 'PUT', 'DELETE'])
async def _(pipeline_uuid: str) -> str:
if quart.request.method == 'GET':
pipeline = await self.ap.pipeline_service.get_pipeline(pipeline_uuid)
if pipeline is None:
return self.http_status(404, -1, 'pipeline not found')
return self.success(data={
'pipeline': pipeline
})
return self.success(data={'pipeline': pipeline})
elif quart.request.method == 'PUT':
json_data = await quart.request.json
@@ -53,4 +48,3 @@ class PipelinesRouterGroup(group.RouterGroup):
await self.ap.pipeline_service.delete_pipeline(pipeline_uuid)
return self.success()

View File

@@ -5,29 +5,31 @@ from ... import group
@group.group_class('adapters', '/api/v1/platform/adapters')
class AdaptersRouterGroup(group.RouterGroup):
async def initialize(self) -> None:
@self.route('', methods=['GET'])
async def _() -> str:
return self.success(data={
'adapters': self.ap.platform_mgr.get_available_adapters_info()
})
return self.success(
data={'adapters': self.ap.platform_mgr.get_available_adapters_info()}
)
@self.route('/<adapter_name>', methods=['GET'])
async def _(adapter_name: str) -> str:
adapter_info = self.ap.platform_mgr.get_available_adapter_info_by_name(adapter_name)
adapter_info = self.ap.platform_mgr.get_available_adapter_info_by_name(
adapter_name
)
if adapter_info is None:
return self.http_status(404, -1, 'adapter not found')
return self.success(data={
'adapter': adapter_info
})
return self.success(data={'adapter': adapter_info})
@self.route('/<adapter_name>/icon', methods=['GET'])
async def _(adapter_name: str) -> quart.Response:
adapter_manifest = self.ap.platform_mgr.get_available_adapter_manifest_by_name(adapter_name)
adapter_manifest = (
self.ap.platform_mgr.get_available_adapter_manifest_by_name(
adapter_name
)
)
if adapter_manifest is None:
return self.http_status(404, -1, 'adapter not found')
@@ -37,4 +39,4 @@ class AdaptersRouterGroup(group.RouterGroup):
if icon_path is None:
return self.http_status(404, -1, 'icon not found')
return await quart.send_file(icon_path)
return await quart.send_file(icon_path)

View File

@@ -5,34 +5,27 @@ from ... import group
@group.group_class('bots', '/api/v1/platform/bots')
class BotsRouterGroup(group.RouterGroup):
async def initialize(self) -> None:
@self.route('', methods=['GET', 'POST'])
async def _() -> str:
if quart.request.method == 'GET':
return self.success(data={
'bots': await self.ap.bot_service.get_bots()
})
return self.success(data={'bots': await self.ap.bot_service.get_bots()})
elif quart.request.method == 'POST':
json_data = await quart.request.json
bot_uuid = await self.ap.bot_service.create_bot(json_data)
return self.success(data={
'uuid': bot_uuid
})
return self.success(data={'uuid': bot_uuid})
@self.route('/<bot_uuid>', methods=['GET', 'PUT', 'DELETE'])
async def _(bot_uuid: str) -> str:
if quart.request.method == 'GET':
bot = await self.ap.bot_service.get_bot(bot_uuid)
if bot is None:
return self.http_status(404, -1, 'bot not found')
return self.success(data={
'bot': bot
})
return self.success(data={'bot': bot})
elif quart.request.method == 'PUT':
json_data = await quart.request.json
await self.ap.bot_service.update_bot(bot_uuid, json_data)
return self.success()
elif quart.request.method == 'DELETE':
await self.ap.bot_service.delete_bot(bot_uuid)
return self.success()
return self.success()

View File

@@ -1,17 +1,14 @@
from __future__ import annotations
import traceback
import quart
from .....core import app, taskmgr
from .....core import taskmgr
from .. import group
@group.group_class('plugins', '/api/v1/plugins')
class PluginsRouterGroup(group.RouterGroup):
async def initialize(self) -> None:
@self.route('', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str:
@@ -19,63 +16,69 @@ class PluginsRouterGroup(group.RouterGroup):
plugins_data = [plugin.model_dump() for plugin in plugins]
return self.success(data={
'plugins': plugins_data
})
@self.route('/<author>/<plugin_name>/toggle', methods=['PUT'], auth_type=group.AuthType.USER_TOKEN)
return self.success(data={'plugins': plugins_data})
@self.route(
'/<author>/<plugin_name>/toggle',
methods=['PUT'],
auth_type=group.AuthType.USER_TOKEN,
)
async def _(author: str, plugin_name: str) -> str:
data = await quart.request.json
target_enabled = data.get('target_enabled')
await self.ap.plugin_mgr.update_plugin_switch(plugin_name, target_enabled)
return self.success()
@self.route('/<author>/<plugin_name>/update', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
@self.route(
'/<author>/<plugin_name>/update',
methods=['POST'],
auth_type=group.AuthType.USER_TOKEN,
)
async def _(author: str, plugin_name: str) -> str:
ctx = taskmgr.TaskContext.new()
wrapper = self.ap.task_mgr.create_user_task(
self.ap.plugin_mgr.update_plugin(plugin_name, task_context=ctx),
kind="plugin-operation",
name=f"plugin-update-{plugin_name}",
label=f"更新插件 {plugin_name}",
context=ctx
kind='plugin-operation',
name=f'plugin-update-{plugin_name}',
label=f'更新插件 {plugin_name}',
context=ctx,
)
return self.success(data={
'task_id': wrapper.id
})
@self.route('/<author>/<plugin_name>', methods=['GET', 'DELETE'], auth_type=group.AuthType.USER_TOKEN)
return self.success(data={'task_id': wrapper.id})
@self.route(
'/<author>/<plugin_name>',
methods=['GET', 'DELETE'],
auth_type=group.AuthType.USER_TOKEN,
)
async def _(author: str, plugin_name: str) -> str:
if quart.request.method == 'GET':
plugin = self.ap.plugin_mgr.get_plugin(author, plugin_name)
if plugin is None:
return self.http_status(404, -1, 'plugin not found')
return self.success(data={
'plugin': plugin.model_dump()
})
return self.success(data={'plugin': plugin.model_dump()})
elif quart.request.method == 'DELETE':
ctx = taskmgr.TaskContext.new()
wrapper = self.ap.task_mgr.create_user_task(
self.ap.plugin_mgr.uninstall_plugin(plugin_name, task_context=ctx),
kind="plugin-operation",
kind='plugin-operation',
name=f'plugin-remove-{plugin_name}',
label=f'删除插件 {plugin_name}',
context=ctx
context=ctx,
)
return self.success(data={
'task_id': wrapper.id
})
@self.route('/<author>/<plugin_name>/config', methods=['GET', 'PUT'], auth_type=group.AuthType.USER_TOKEN)
return self.success(data={'task_id': wrapper.id})
@self.route(
'/<author>/<plugin_name>/config',
methods=['GET', 'PUT'],
auth_type=group.AuthType.USER_TOKEN,
)
async def _(author: str, plugin_name: str) -> quart.Response:
plugin = self.ap.plugin_mgr.get_plugin(author, plugin_name)
if plugin is None:
return self.http_status(404, -1, 'plugin not found')
if quart.request.method == 'GET':
return self.success(data={
'config': plugin.plugin_config
})
return self.success(data={'config': plugin.plugin_config})
elif quart.request.method == 'PUT':
data = await quart.request.json
@@ -88,21 +91,21 @@ class PluginsRouterGroup(group.RouterGroup):
data = await quart.request.json
await self.ap.plugin_mgr.reorder_plugins(data.get('plugins'))
return self.success()
@self.route('/install/github', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
@self.route(
'/install/github', methods=['POST'], auth_type=group.AuthType.USER_TOKEN
)
async def _() -> str:
data = await quart.request.json
ctx = taskmgr.TaskContext.new()
short_source_str = data['source'][-8:]
wrapper = self.ap.task_mgr.create_user_task(
self.ap.plugin_mgr.install_plugin(data['source'], task_context=ctx),
kind="plugin-operation",
name=f'plugin-install-github',
kind='plugin-operation',
name='plugin-install-github',
label=f'安装插件 ...{short_source_str}',
context=ctx
context=ctx,
)
return self.success(data={
'task_id': wrapper.id
})
return self.success(data={'task_id': wrapper.id})

View File

@@ -1,28 +1,23 @@
import quart
import uuid
from ... import group
from ......entity.persistence import model
@group.group_class('models/llm', '/api/v1/provider/models/llm')
class LLMModelsRouterGroup(group.RouterGroup):
async def initialize(self) -> None:
@self.route('', methods=['GET', 'POST'])
async def _() -> str:
if quart.request.method == 'GET':
return self.success(data={
'models': await self.ap.model_service.get_llm_models()
})
return self.success(
data={'models': await self.ap.model_service.get_llm_models()}
)
elif quart.request.method == 'POST':
json_data = await quart.request.json
model_uuid = await self.ap.model_service.create_llm_model(json_data)
return self.success(data={
'uuid': model_uuid
})
return self.success(data={'uuid': model_uuid})
@self.route('/<model_uuid>', methods=['GET', 'DELETE'])
async def _(model_uuid: str) -> str:
@@ -32,9 +27,7 @@ class LLMModelsRouterGroup(group.RouterGroup):
if model is None:
return self.http_status(404, -1, 'model not found')
return self.success(data={
'model': model
})
return self.success(data={'model': model})
# elif quart.request.method == 'PUT':
# json_data = await quart.request.json

View File

@@ -5,29 +5,31 @@ from ... import group
@group.group_class('provider/requesters', '/api/v1/provider/requesters')
class RequestersRouterGroup(group.RouterGroup):
async def initialize(self) -> None:
@self.route('', methods=['GET'])
async def _() -> quart.Response:
return self.success(data={
'requesters': self.ap.model_mgr.get_available_requesters_info()
})
return self.success(
data={'requesters': self.ap.model_mgr.get_available_requesters_info()}
)
@self.route('/<requester_name>', methods=['GET'])
async def _(requester_name: str) -> quart.Response:
requester_info = self.ap.model_mgr.get_available_requester_info_by_name(requester_name)
requester_info = self.ap.model_mgr.get_available_requester_info_by_name(
requester_name
)
if requester_info is None:
return self.http_status(404, -1, 'requester not found')
return self.success(data={
'requester': requester_info
})
return self.success(data={'requester': requester_info})
@self.route('/<requester_name>/icon', methods=['GET'])
async def _(requester_name: str) -> quart.Response:
requester_manifest = self.ap.model_mgr.get_available_requester_manifest_by_name(requester_name)
requester_manifest = (
self.ap.model_mgr.get_available_requester_manifest_by_name(
requester_name
)
)
if requester_manifest is None:
return self.http_status(404, -1, 'requester not found')

View File

@@ -1,23 +1,21 @@
import quart
import asyncio
from .....core import app, taskmgr
from .. import group
@group.group_class('stats', '/api/v1/stats')
class StatsRouterGroup(group.RouterGroup):
async def initialize(self) -> None:
@self.route('/basic', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str:
conv_count = 0
for session in self.ap.sess_mgr.session_list:
conv_count += len(session.conversations if session.conversations is not None else [])
conv_count += len(
session.conversations if session.conversations is not None else []
)
return self.success(data={
'active_session_count': len(self.ap.sess_mgr.session_list),
'conversation_count': conv_count,
'query_count': self.ap.query_pool.query_id_counter,
})
return self.success(
data={
'active_session_count': len(self.ap.sess_mgr.session_list),
'conversation_count': conv_count,
'query_count': self.ap.query_pool.query_id_counter,
}
)

View File

@@ -1,63 +1,62 @@
import quart
import asyncio
from .....core import app, taskmgr
from .. import group
from .....utils import constants
@group.group_class('system', '/api/v1/system')
class SystemRouterGroup(group.RouterGroup):
async def initialize(self) -> None:
@self.route('/info', methods=['GET'], auth_type=group.AuthType.NONE)
async def _() -> str:
return self.success(
data={
"version": constants.semantic_version,
"debug": constants.debug_mode,
"enabled_platform_count": len(self.ap.platform_mgr.get_running_adapters())
'version': constants.semantic_version,
'debug': constants.debug_mode,
'enabled_platform_count': len(
self.ap.platform_mgr.get_running_adapters()
),
}
)
@self.route('/tasks', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str:
task_type = quart.request.args.get("type")
task_type = quart.request.args.get('type')
if task_type == '':
task_type = None
return self.success(
data=self.ap.task_mgr.get_tasks_dict(task_type)
)
@self.route('/tasks/<task_id>', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
return self.success(data=self.ap.task_mgr.get_tasks_dict(task_type))
@self.route(
'/tasks/<task_id>', methods=['GET'], auth_type=group.AuthType.USER_TOKEN
)
async def _(task_id: str) -> str:
task = self.ap.task_mgr.get_task_by_id(int(task_id))
if task is None:
return self.http_status(404, 404, "Task not found")
return self.http_status(404, 404, 'Task not found')
return self.success(data=task.to_dict())
@self.route('/reload', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str:
json_data = await quart.request.json
scope = json_data.get("scope")
scope = json_data.get('scope')
await self.ap.reload(
scope=scope
)
await self.ap.reload(scope=scope)
return self.success()
@self.route('/_debug/exec', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
@self.route(
'/_debug/exec', methods=['POST'], auth_type=group.AuthType.USER_TOKEN
)
async def _() -> str:
if not constants.debug_mode:
return self.http_status(403, 403, "Forbidden")
return self.http_status(403, 403, 'Forbidden')
py_code = await quart.request.data
ap = self.ap
return self.success(data=exec(py_code, {"ap": ap}))
return self.success(data=exec(py_code, {'ap': ap}))

View File

@@ -1,22 +1,19 @@
import quart
import jwt
import argon2
from .. import group
from .....entity.persistence import user
@group.group_class('user', '/api/v1/user')
class UserRouterGroup(group.RouterGroup):
async def initialize(self) -> None:
@self.route('/init', methods=['GET', 'POST'], auth_type=group.AuthType.NONE)
async def _() -> str:
if quart.request.method == 'GET':
return self.success(data={
'initialized': await self.ap.user_service.is_initialized()
})
return self.success(
data={'initialized': await self.ap.user_service.is_initialized()}
)
if await self.ap.user_service.is_initialized():
return self.fail(1, '系统已初始化')
@@ -28,24 +25,24 @@ class UserRouterGroup(group.RouterGroup):
await self.ap.user_service.create_user(user_email, password)
return self.success()
@self.route('/auth', methods=['POST'], auth_type=group.AuthType.NONE)
async def _() -> str:
json_data = await quart.request.json
try:
token = await self.ap.user_service.authenticate(json_data['user'], json_data['password'])
token = await self.ap.user_service.authenticate(
json_data['user'], json_data['password']
)
except argon2.exceptions.VerifyMismatchError:
return self.fail(1, '用户名或密码错误')
return self.success(data={
'token': token
})
return self.success(data={'token': token})
@self.route('/check-token', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
@self.route(
'/check-token', methods=['GET'], auth_type=group.AuthType.USER_TOKEN
)
async def _(user_email: str) -> str:
token = await self.ap.user_service.generate_jwt_token(user_email)
return self.success(data={
'token': token
})
return self.success(data={'token': token})

View File

@@ -7,15 +7,19 @@ import quart
import quart_cors
from ....core import app, entities as core_entities
from ....utils import importutil
from .groups import logs, system, plugins, stats, user, pipelines
from .groups.provider import models, requesters
from .groups.platform import bots, adapters
from . import groups
from . import group
from .groups import provider as groups_provider
from .groups import platform as groups_platform
importutil.import_modules_in_pkg(groups)
importutil.import_modules_in_pkg(groups_provider)
importutil.import_modules_in_pkg(groups_platform)
class HTTPController:
ap: app.Application
quart_app: quart.Quart
@@ -23,7 +27,7 @@ class HTTPController:
def __init__(self, ap: app.Application) -> None:
self.ap = ap
self.quart_app = quart.Quart(__name__)
quart_cors.cors(self.quart_app, allow_origin="*")
quart_cors.cors(self.quart_app, allow_origin='*')
async def initialize(self) -> None:
await self.register_routes()
@@ -37,11 +41,9 @@ class HTTPController:
async def exception_handler(*args, **kwargs):
try:
await self.quart_app.run_task(
*args, **kwargs
)
await self.quart_app.run_task(*args, **kwargs)
except Exception as e:
self.ap.logger.error(f"启动 HTTP 服务失败: {e}")
self.ap.logger.error(f'启动 HTTP 服务失败: {e}')
self.ap.task_mgr.create_task(
exception_handler(
@@ -49,63 +51,62 @@ class HTTPController:
port=self.ap.instance_config.data['api']['port'],
shutdown_trigger=shutdown_trigger_placeholder,
),
name="http-api-quart",
name='http-api-quart',
scopes=[core_entities.LifecycleControlScope.APPLICATION],
)
# await asyncio.sleep(5)
async def register_routes(self) -> None:
@self.quart_app.route("/healthz")
@self.quart_app.route('/healthz')
async def healthz():
return {"code": 0, "msg": "ok"}
return {'code': 0, 'msg': 'ok'}
for g in group.preregistered_groups:
ginst = g(self.ap, self.quart_app)
await ginst.initialize()
frontend_path = "web/out"
frontend_path = 'web/out'
@self.quart_app.route("/")
@self.quart_app.route('/')
async def index():
return await quart.send_from_directory(frontend_path, "index.html", mimetype="text/html")
return await quart.send_from_directory(
frontend_path, 'index.html', mimetype='text/html'
)
@self.quart_app.route("/<path:path>")
@self.quart_app.route('/<path:path>')
async def static_file(path: str):
if not os.path.exists(os.path.join(frontend_path, path)):
if os.path.exists(os.path.join(frontend_path, path+".html")):
if os.path.exists(os.path.join(frontend_path, path + '.html')):
path += '.html'
else:
return await quart.send_from_directory(frontend_path, '404.html')
mimetype = None
if path.endswith(".html"):
mimetype = "text/html"
elif path.endswith(".js"):
mimetype = "application/javascript"
elif path.endswith(".css"):
mimetype = "text/css"
elif path.endswith(".png"):
mimetype = "image/png"
elif path.endswith(".jpg"):
mimetype = "image/jpeg"
elif path.endswith(".jpeg"):
mimetype = "image/jpeg"
elif path.endswith(".gif"):
mimetype = "image/gif"
elif path.endswith(".svg"):
mimetype = "image/svg+xml"
elif path.endswith(".ico"):
mimetype = "image/x-icon"
elif path.endswith(".json"):
mimetype = "application/json"
elif path.endswith(".txt"):
mimetype = "text/plain"
if path.endswith('.html'):
mimetype = 'text/html'
elif path.endswith('.js'):
mimetype = 'application/javascript'
elif path.endswith('.css'):
mimetype = 'text/css'
elif path.endswith('.png'):
mimetype = 'image/png'
elif path.endswith('.jpg'):
mimetype = 'image/jpeg'
elif path.endswith('.jpeg'):
mimetype = 'image/jpeg'
elif path.endswith('.gif'):
mimetype = 'image/gif'
elif path.endswith('.svg'):
mimetype = 'image/svg+xml'
elif path.endswith('.ico'):
mimetype = 'image/x-icon'
elif path.endswith('.json'):
mimetype = 'application/json'
elif path.endswith('.txt'):
mimetype = 'text/plain'
return await quart.send_from_directory(
frontend_path,
path,
mimetype=mimetype
frontend_path, path, mimetype=mimetype
)

View File

@@ -1,7 +1,6 @@
from __future__ import annotations
import uuid
import datetime
import sqlalchemy
from ....core import app
@@ -29,13 +28,15 @@ class BotService:
self.ap.persistence_mgr.serialize_model(persistence_bot.Bot, bot)
for bot in bots
]
async def get_bot(self, bot_uuid: str) -> dict | None:
"""获取机器人"""
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_bot.Bot).where(persistence_bot.Bot.uuid == bot_uuid)
sqlalchemy.select(persistence_bot.Bot).where(
persistence_bot.Bot.uuid == bot_uuid
)
)
bot = result.first()
if bot is None:
@@ -50,7 +51,9 @@ class BotService:
# checkout the default pipeline
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(persistence_pipeline.LegacyPipeline.is_default == True)
sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(
persistence_pipeline.LegacyPipeline.is_default == True
)
)
pipeline = result.first()
if pipeline is not None:
@@ -64,7 +67,7 @@ class BotService:
bot = await self.get_bot(bot_data['uuid'])
await self.ap.platform_mgr.load_bot(bot)
return bot_data['uuid']
async def update_bot(self, bot_uuid: str, bot_data: dict) -> None:
@@ -75,19 +78,24 @@ class BotService:
# set use_pipeline_name
if 'use_pipeline_uuid' in bot_data:
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(persistence_pipeline.LegacyPipeline.uuid == bot_data['use_pipeline_uuid'])
sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(
persistence_pipeline.LegacyPipeline.uuid
== bot_data['use_pipeline_uuid']
)
)
pipeline = result.first()
if pipeline is not None:
bot_data['use_pipeline_name'] = pipeline.name
else:
raise Exception("Pipeline not found")
raise Exception('Pipeline not found')
await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_bot.Bot).values(bot_data).where(persistence_bot.Bot.uuid == bot_uuid)
sqlalchemy.update(persistence_bot.Bot)
.values(bot_data)
.where(persistence_bot.Bot.uuid == bot_uuid)
)
await self.ap.platform_mgr.remove_bot(bot_uuid)
# select from db
bot = await self.get_bot(bot_uuid)
@@ -100,7 +108,7 @@ class BotService:
"""删除机器人"""
await self.ap.platform_mgr.remove_bot(bot_uuid)
await self.ap.persistence_mgr.execute_async(
sqlalchemy.delete(persistence_bot.Bot).where(persistence_bot.Bot.uuid == bot_uuid)
sqlalchemy.delete(persistence_bot.Bot).where(
persistence_bot.Bot.uuid == bot_uuid
)
)

View File

@@ -1,7 +1,6 @@
from __future__ import annotations
import uuid
import datetime
import sqlalchemy
from ....core import app
@@ -10,7 +9,6 @@ from ....entity.persistence import pipeline as persistence_pipeline
class ModelsService:
ap: app.Application
def __init__(self, ap: app.Application) -> None:
@@ -26,15 +24,12 @@ class ModelsService:
self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model)
for model in models
]
async def create_llm_model(self, model_data: dict) -> str:
async def create_llm_model(self, model_data: dict) -> str:
model_data['uuid'] = str(uuid.uuid4())
await self.ap.persistence_mgr.execute_async(
sqlalchemy.insert(persistence_model.LLMModel).values(
**model_data
)
sqlalchemy.insert(persistence_model.LLMModel).values(**model_data)
)
llm_model = await self.get_llm_model(model_data['uuid'])
@@ -43,22 +38,24 @@ class ModelsService:
# check if default pipeline has no model bound
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(persistence_pipeline.LegacyPipeline.is_default == True)
sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(
persistence_pipeline.LegacyPipeline.is_default == True
)
)
pipeline = result.first()
if pipeline is not None and pipeline.config['ai']['local-agent']['model'] == '':
pipeline_config = pipeline.config
pipeline_config['ai']['local-agent']['model'] = model_data['uuid']
pipeline_data = {
"config": pipeline_config
}
await self.ap.pipeline_service.update_pipeline(pipeline.uuid, pipeline_data)
pipeline_data = {'config': pipeline_config}
await self.ap.pipeline_service.update_pipeline(pipeline.uuid, pipeline_data)
return model_data['uuid']
async def get_llm_model(self, model_uuid: str) -> dict | None:
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_model.LLMModel).where(persistence_model.LLMModel.uuid == model_uuid)
sqlalchemy.select(persistence_model.LLMModel).where(
persistence_model.LLMModel.uuid == model_uuid
)
)
model = result.first()
@@ -66,14 +63,18 @@ class ModelsService:
if model is None:
return None
return self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model)
return self.ap.persistence_mgr.serialize_model(
persistence_model.LLMModel, model
)
async def update_llm_model(self, model_uuid: str, model_data: dict) -> None:
if 'uuid' in model_data:
del model_data['uuid']
await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_model.LLMModel).where(persistence_model.LLMModel.uuid == model_uuid).values(**model_data)
sqlalchemy.update(persistence_model.LLMModel)
.where(persistence_model.LLMModel.uuid == model_uuid)
.values(**model_data)
)
await self.ap.model_mgr.remove_llm_model(model_uuid)
@@ -84,7 +85,9 @@ class ModelsService:
async def delete_llm_model(self, model_uuid: str) -> None:
await self.ap.persistence_mgr.execute_async(
sqlalchemy.delete(persistence_model.LLMModel).where(persistence_model.LLMModel.uuid == model_uuid)
sqlalchemy.delete(persistence_model.LLMModel).where(
persistence_model.LLMModel.uuid == model_uuid
)
)
await self.ap.model_mgr.remove_llm_model(model_uuid)

View File

@@ -2,7 +2,6 @@ from __future__ import annotations
import uuid
import json
import datetime
import sqlalchemy
from ....core import app
@@ -10,69 +9,79 @@ from ....entity.persistence import pipeline as persistence_pipeline
default_stage_order = [
"GroupRespondRuleCheckStage", # 群响应规则检查
"BanSessionCheckStage", # 封禁会话检查
"PreContentFilterStage", # 内容过滤前置阶段
"PreProcessor", # 预处理器
"ConversationMessageTruncator", # 会话消息截断器
"RequireRateLimitOccupancy", # 请求速率限制占用
"MessageProcessor", # 处理器
"ReleaseRateLimitOccupancy", # 释放速率限制占用
"PostContentFilterStage", # 内容过滤后置阶段
"ResponseWrapper", # 响应包装器
"LongTextProcessStage", # 长文本处理
"SendResponseBackStage", # 发送响应
'GroupRespondRuleCheckStage', # 群响应规则检查
'BanSessionCheckStage', # 封禁会话检查
'PreContentFilterStage', # 内容过滤前置阶段
'PreProcessor', # 预处理器
'ConversationMessageTruncator', # 会话消息截断器
'RequireRateLimitOccupancy', # 请求速率限制占用
'MessageProcessor', # 处理器
'ReleaseRateLimitOccupancy', # 释放速率限制占用
'PostContentFilterStage', # 内容过滤后置阶段
'ResponseWrapper', # 响应包装器
'LongTextProcessStage', # 长文本处理
'SendResponseBackStage', # 发送响应
]
class PipelineService:
ap: app.Application
def __init__(self, ap: app.Application) -> None:
self.ap = ap
async def get_pipeline_metadata(self) -> dict:
return [
self.ap.pipeline_config_meta_trigger.data,
self.ap.pipeline_config_meta_safety.data,
self.ap.pipeline_config_meta_ai.data,
self.ap.pipeline_config_meta_output.data
self.ap.pipeline_config_meta_output.data,
]
async def get_pipelines(self) -> list[dict]:
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_pipeline.LegacyPipeline)
)
pipelines = result.all()
return [
self.ap.persistence_mgr.serialize_model(persistence_pipeline.LegacyPipeline, pipeline)
self.ap.persistence_mgr.serialize_model(
persistence_pipeline.LegacyPipeline, pipeline
)
for pipeline in pipelines
]
async def get_pipeline(self, pipeline_uuid: str) -> dict | None:
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(persistence_pipeline.LegacyPipeline.uuid == pipeline_uuid)
sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(
persistence_pipeline.LegacyPipeline.uuid == pipeline_uuid
)
)
pipeline = result.first()
if pipeline is None:
return None
return self.ap.persistence_mgr.serialize_model(persistence_pipeline.LegacyPipeline, pipeline)
return self.ap.persistence_mgr.serialize_model(
persistence_pipeline.LegacyPipeline, pipeline
)
async def create_pipeline(self, pipeline_data: dict, default: bool = False) -> str:
pipeline_data['uuid'] = str(uuid.uuid4())
pipeline_data['for_version'] = self.ap.ver_mgr.get_current_version()
pipeline_data['stages'] = default_stage_order.copy()
pipeline_data['is_default'] = default
pipeline_data['config'] = json.load(open('templates/default-pipeline-config.json', 'r', encoding='utf-8'))
pipeline_data['config'] = json.load(
open('templates/default-pipeline-config.json', 'r', encoding='utf-8')
)
await self.ap.persistence_mgr.execute_async(
sqlalchemy.insert(persistence_pipeline.LegacyPipeline).values(**pipeline_data)
sqlalchemy.insert(persistence_pipeline.LegacyPipeline).values(
**pipeline_data
)
)
pipeline = await self.get_pipeline(pipeline_data['uuid'])
await self.ap.pipeline_mgr.load_pipeline(pipeline)
@@ -90,7 +99,9 @@ class PipelineService:
del pipeline_data['is_default']
await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_pipeline.LegacyPipeline).where(persistence_pipeline.LegacyPipeline.uuid == pipeline_uuid).values(**pipeline_data)
sqlalchemy.update(persistence_pipeline.LegacyPipeline)
.where(persistence_pipeline.LegacyPipeline.uuid == pipeline_uuid)
.values(**pipeline_data)
)
await self.ap.pipeline_mgr.remove_pipeline(pipeline_uuid)
@@ -101,6 +112,8 @@ class PipelineService:
async def delete_pipeline(self, pipeline_uuid: str) -> None:
await self.ap.persistence_mgr.execute_async(
sqlalchemy.delete(persistence_pipeline.LegacyPipeline).where(persistence_pipeline.LegacyPipeline.uuid == pipeline_uuid)
sqlalchemy.delete(persistence_pipeline.LegacyPipeline).where(
persistence_pipeline.LegacyPipeline.uuid == pipeline_uuid
)
)
await self.ap.pipeline_mgr.remove_pipeline(pipeline_uuid)

View File

@@ -11,7 +11,6 @@ from ....utils import constants
class UserService:
ap: app.Application
def __init__(self, ap: app.Application) -> None:
@@ -24,7 +23,7 @@ class UserService:
result_list = result.all()
return result_list is not None and len(result_list) > 0
async def create_user(self, user_email: str, password: str) -> None:
ph = argon2.PasswordHasher()
@@ -32,8 +31,7 @@ class UserService:
await self.ap.persistence_mgr.execute_async(
sqlalchemy.insert(user.User).values(
user=user_email,
password=hashed_password
user=user_email, password=hashed_password
)
)
@@ -61,12 +59,12 @@ class UserService:
payload = {
'user': user_email,
'iss': 'LangBot-'+constants.edition,
'exp': datetime.datetime.now() + datetime.timedelta(seconds=jwt_expire)
'iss': 'LangBot-' + constants.edition,
'exp': datetime.datetime.now() + datetime.timedelta(seconds=jwt_expire),
}
return jwt.encode(payload, jwt_secret, algorithm='HS256')
async def verify_jwt_token(self, token: str) -> str:
jwt_secret = self.ap.instance_config.data['system']['jwt']['secret']

View File

@@ -1,3 +1,3 @@
"""
审计相关操作
"""
"""

View File

@@ -3,11 +3,9 @@ from __future__ import annotations
import abc
import uuid
import json
import logging
import asyncio
import aiohttp
import requests
from ...core import app, entities as core_entities
@@ -38,22 +36,22 @@ class APIGroup(metaclass=abc.ABCMeta):
"""
执行请求
"""
self._runtime_info["account_id"] = "-1"
self._runtime_info['account_id'] = '-1'
url = self.prefix + path
data = json.dumps(data)
headers["Content-Type"] = "application/json"
headers['Content-Type'] = 'application/json'
try:
async with aiohttp.ClientSession() as session:
async with session.request(
method, url, data=data, params=params, headers=headers, **kwargs
) as resp:
self.ap.logger.debug("data: %s", data)
self.ap.logger.debug("ret: %s", await resp.text())
self.ap.logger.debug('data: %s', data)
self.ap.logger.debug('ret: %s', await resp.text())
except Exception as e:
self.ap.logger.debug(f"上报失败: {e}")
self.ap.logger.debug(f'上报失败: {e}')
async def do(
self,
@@ -68,8 +66,8 @@ class APIGroup(metaclass=abc.ABCMeta):
return self.ap.task_mgr.create_task(
self._do(method, path, data, params, headers, **kwargs),
kind="telemetry-operation",
name=f"{method} {path}",
kind='telemetry-operation',
name=f'{method} {path}',
scopes=[core_entities.LifecycleControlScope.APPLICATION],
).task
@@ -80,7 +78,7 @@ class APIGroup(metaclass=abc.ABCMeta):
def basic_info(self):
"""获取基本信息"""
basic_info = APIGroup._basic_info.copy()
basic_info["rid"] = self.gen_rid()
basic_info['rid'] = self.gen_rid()
return basic_info
def runtime_info(self):

View File

@@ -9,7 +9,7 @@ class V2MainDataAPI(apigroup.APIGroup):
def __init__(self, prefix: str, ap: app.Application):
self.ap = ap
super().__init__(prefix+"/main", ap)
super().__init__(prefix + '/main', ap)
async def do(self, *args, **kwargs):
if not self.ap.instance_config.data['telemetry']['report']:
@@ -25,31 +25,31 @@ class V2MainDataAPI(apigroup.APIGroup):
):
"""提交更新记录"""
return await self.do(
"POST",
"/update",
'POST',
'/update',
data={
"basic": self.basic_info(),
"update_info": {
"spent_seconds": spent_seconds,
"infer_reason": infer_reason,
"old_version": old_version,
"new_version": new_version,
}
}
'basic': self.basic_info(),
'update_info': {
'spent_seconds': spent_seconds,
'infer_reason': infer_reason,
'old_version': old_version,
'new_version': new_version,
},
},
)
async def post_announcement_showed(
self,
ids: list[int],
):
"""提交公告已阅"""
return await self.do(
"POST",
"/announcement",
'POST',
'/announcement',
data={
"basic": self.basic_info(),
"announcement_info": {
"ids": ids,
}
}
'basic': self.basic_info(),
'announcement_info': {
'ids': ids,
},
},
)

View File

@@ -9,39 +9,33 @@ class V2PluginDataAPI(apigroup.APIGroup):
def __init__(self, prefix: str, ap: app.Application):
self.ap = ap
super().__init__(prefix+"/plugin", ap)
super().__init__(prefix + '/plugin', ap)
async def do(self, *args, **kwargs):
if not self.ap.instance_config.data['telemetry']['report']:
return None
return await super().do(*args, **kwargs)
async def post_install_record(
self,
plugin: dict
):
async def post_install_record(self, plugin: dict):
"""提交插件安装记录"""
return await self.do(
"POST",
"/install",
'POST',
'/install',
data={
"basic": self.basic_info(),
"plugin": plugin,
}
'basic': self.basic_info(),
'plugin': plugin,
},
)
async def post_remove_record(
self,
plugin: dict
):
async def post_remove_record(self, plugin: dict):
"""提交插件卸载记录"""
return await self.do(
"POST",
"/remove",
'POST',
'/remove',
data={
"basic": self.basic_info(),
"plugin": plugin,
}
'basic': self.basic_info(),
'plugin': plugin,
},
)
async def post_update_record(
@@ -52,14 +46,14 @@ class V2PluginDataAPI(apigroup.APIGroup):
):
"""提交插件更新记录"""
return await self.do(
"POST",
"/update",
'POST',
'/update',
data={
"basic": self.basic_info(),
"plugin": plugin,
"update_info": {
"old_version": old_version,
"new_version": new_version,
}
}
'basic': self.basic_info(),
'plugin': plugin,
'update_info': {
'old_version': old_version,
'new_version': new_version,
},
},
)

View File

@@ -9,7 +9,7 @@ class V2UsageDataAPI(apigroup.APIGroup):
def __init__(self, prefix: str, ap: app.Application):
self.ap = ap
super().__init__(prefix+"/usage", ap)
super().__init__(prefix + '/usage', ap)
async def do(self, *args, **kwargs):
if not self.ap.instance_config.data['telemetry']['report']:
@@ -28,25 +28,25 @@ class V2UsageDataAPI(apigroup.APIGroup):
):
"""提交请求记录"""
return await self.do(
"POST",
"/query",
'POST',
'/query',
data={
"basic": self.basic_info(),
"runtime": self.runtime_info(),
"session_info": {
"type": session_type,
"id": session_id,
'basic': self.basic_info(),
'runtime': self.runtime_info(),
'session_info': {
'type': session_type,
'id': session_id,
},
"query_info": {
"ability_provider": query_ability_provider,
"usage": usage,
"model_name": model_name,
"response_seconds": response_seconds,
"retry_times": retry_times,
}
}
'query_info': {
'ability_provider': query_ability_provider,
'usage': usage,
'model_name': model_name,
'response_seconds': response_seconds,
'retry_times': retry_times,
},
},
)
async def post_event_record(
self,
plugins: list[dict],
@@ -54,18 +54,18 @@ class V2UsageDataAPI(apigroup.APIGroup):
):
"""提交事件触发记录"""
return await self.do(
"POST",
"/event",
'POST',
'/event',
data={
"basic": self.basic_info(),
"runtime": self.runtime_info(),
"plugins": plugins,
"event_info": {
"name": event_name,
}
}
'basic': self.basic_info(),
'runtime': self.runtime_info(),
'plugins': plugins,
'event_info': {
'name': event_name,
},
},
)
async def post_function_record(
self,
plugin: dict,
@@ -74,15 +74,14 @@ class V2UsageDataAPI(apigroup.APIGroup):
):
"""提交内容函数使用记录"""
return await self.do(
"POST",
"/function",
'POST',
'/function',
data={
"basic": self.basic_info(),
"plugin": plugin,
"function_info": {
"name": function_name,
"description": function_description,
}
}
'basic': self.basic_info(),
'plugin': plugin,
'function_info': {
'name': function_name,
'description': function_description,
},
},
)

View File

@@ -11,7 +11,7 @@ from ...core import app
class V2CenterAPI:
"""中央服务器 v2 API 交互类"""
main: main.V2MainDataAPI = None
"""主 API 组"""
@@ -21,15 +21,20 @@ class V2CenterAPI:
plugin: plugin.V2PluginDataAPI = None
"""插件 API 组"""
def __init__(self, ap: app.Application, backend_url: str, basic_info: dict = None, runtime_info: dict = None):
def __init__(
self,
ap: app.Application,
backend_url: str,
basic_info: dict = None,
runtime_info: dict = None,
):
"""初始化"""
logging.debug("basic_info: %s, runtime_info: %s", basic_info, runtime_info)
logging.debug('basic_info: %s, runtime_info: %s', basic_info, runtime_info)
apigroup.APIGroup._basic_info = basic_info
apigroup.APIGroup._runtime_info = runtime_info
self.main = main.V2MainDataAPI(backend_url, ap)
self.usage = usage.V2UsageDataAPI(backend_url, ap)
self.plugin = plugin.V2PluginDataAPI(backend_url, ap)

View File

@@ -16,6 +16,7 @@ identifier = {
HOST_ID_FILE = os.path.expanduser('~/.langbot/host_id.json')
INSTANCE_ID_FILE = 'data/labels/instance_id.json'
def init():
global identifier
@@ -23,14 +24,11 @@ def init():
os.mkdir(os.path.expanduser('~/.langbot'))
if not os.path.exists(HOST_ID_FILE):
new_host_id = 'host_'+str(uuid.uuid4())
new_host_id = 'host_' + str(uuid.uuid4())
new_host_create_ts = int(time.time())
with open(HOST_ID_FILE, 'w') as f:
json.dump({
'host_id': new_host_id,
'host_create_ts': new_host_create_ts
}, f)
json.dump({'host_id': new_host_id, 'host_create_ts': new_host_create_ts}, f)
identifier['host_id'] = new_host_id
identifier['host_create_ts'] = new_host_create_ts
@@ -51,20 +49,25 @@ def init():
instance_id = {}
with open(INSTANCE_ID_FILE, 'r') as f:
instance_id = json.load(f)
if instance_id['host_id'] != identifier['host_id']: # 如果实例 id 不是当前主机的,删除
if (
instance_id['host_id'] != identifier['host_id']
): # 如果实例 id 不是当前主机的,删除
os.remove(INSTANCE_ID_FILE)
if not os.path.exists(INSTANCE_ID_FILE):
new_instance_id = 'instance_'+str(uuid.uuid4())
new_instance_id = 'instance_' + str(uuid.uuid4())
new_instance_create_ts = int(time.time())
with open(INSTANCE_ID_FILE, 'w') as f:
json.dump({
'host_id': identifier['host_id'],
'instance_id': new_instance_id,
'instance_create_ts': new_instance_create_ts
}, f)
json.dump(
{
'host_id': identifier['host_id'],
'instance_id': new_instance_id,
'instance_create_ts': new_instance_create_ts,
},
f,
)
identifier['instance_id'] = new_instance_id
identifier['instance_create_ts'] = new_instance_create_ts
@@ -80,6 +83,7 @@ def init():
identifier['instance_id'] = loaded_instance_id
identifier['instance_create_ts'] = loaded_instance_create_ts
def print_out():
global identifier
print(identifier)

View File

@@ -3,17 +3,17 @@ from __future__ import annotations
import typing
from ..core import app, entities as core_entities
from ..provider import entities as llm_entities
from . import entities, operator, errors
from ..config import manager as cfg_mgr
from ..utils import importutil
# 引入所有算子以便注册
from .operators import func, plugin, reset, list as list_cmd, last, next, delc, resend, prompt, cmd, help, version, update, ollama, model
from . import operators
importutil.import_modules_in_pkg(operators)
class CommandManager:
"""命令管理器
"""
"""命令管理器"""
ap: app.Application
@@ -26,14 +26,13 @@ class CommandManager:
self.ap = ap
async def initialize(self):
# 设置各个类的路径
def set_path(cls: operator.CommandOperator, ancestors: list[str]):
cls.path = '.'.join(ancestors + [cls.name])
for op in operator.preregistered_operators:
if op.parent_class == cls:
set_path(op, ancestors + [cls.name])
for cls in operator.preregistered_operators:
if cls.parent_class is None:
set_path(cls, [])
@@ -41,14 +40,18 @@ class CommandManager:
# 应用命令权限配置
for cls in operator.preregistered_operators:
if cls.path in self.ap.instance_config.data['command']['privilege']:
cls.lowest_privilege = self.ap.instance_config.data['command']['privilege'][cls.path]
cls.lowest_privilege = self.ap.instance_config.data['command'][
'privilege'
][cls.path]
# 实例化所有类
self.cmd_list = [cls(self.ap) for cls in operator.preregistered_operators]
# 设置所有类的子节点
for cmd in self.cmd_list:
cmd.children = [child for child in self.cmd_list if child.parent_class == cmd.__class__]
cmd.children = [
child for child in self.cmd_list if child.parent_class == cmd.__class__
]
# 初始化所有类
for cmd in self.cmd_list:
@@ -58,27 +61,25 @@ class CommandManager:
self,
context: entities.ExecuteContext,
operator_list: list[operator.CommandOperator],
operator: operator.CommandOperator = None
operator: operator.CommandOperator = None,
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""执行命令
"""
"""执行命令"""
found = False
if len(context.crt_params) > 0: # 查找下一个参数是否对应此节点的某个子节点名
for oper in operator_list:
if (context.crt_params[0] == oper.name \
or context.crt_params[0] in oper.alias) \
and (oper.parent_class is None or oper.parent_class == operator.__class__):
if (
context.crt_params[0] == oper.name
or context.crt_params[0] in oper.alias
) and (
oper.parent_class is None or oper.parent_class == operator.__class__
):
found = True
context.crt_command = context.crt_params[0]
context.crt_params = context.crt_params[1:]
async for ret in self._execute(
context,
oper.children,
oper
):
async for ret in self._execute(context, oper.children, oper):
yield ret
break
@@ -96,19 +97,20 @@ class CommandManager:
async for ret in operator.execute(context):
yield ret
async def execute(
self,
command_text: str,
query: core_entities.Query,
session: core_entities.Session
session: core_entities.Session,
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""执行命令
"""
"""执行命令"""
privilege = 1
if f'{query.launcher_type.value}_{query.launcher_id}' in self.ap.instance_config.data['admins']:
if (
f'{query.launcher_type.value}_{query.launcher_id}'
in self.ap.instance_config.data['admins']
):
privilege = 2
ctx = entities.ExecuteContext(
@@ -119,11 +121,8 @@ class CommandManager:
crt_command='',
params=command_text.split(' '),
crt_params=command_text.split(' '),
privilege=privilege
privilege=privilege,
)
async for ret in self._execute(
ctx,
self.cmd_list
):
async for ret in self._execute(ctx, self.cmd_list):
yield ret

View File

@@ -4,14 +4,13 @@ import typing
import pydantic.v1 as pydantic
from ..core import app, entities as core_entities
from . import errors, operator
from ..core import entities as core_entities
from . import errors
from ..platform.types import message as platform_message
class CommandReturn(pydantic.BaseModel):
"""命令返回值
"""
"""命令返回值"""
text: typing.Optional[str] = None
"""文本
@@ -24,7 +23,7 @@ class CommandReturn(pydantic.BaseModel):
"""图片链接
"""
error: typing.Optional[errors.CommandError]= None
error: typing.Optional[errors.CommandError] = None
"""错误
"""
@@ -33,8 +32,7 @@ class CommandReturn(pydantic.BaseModel):
class ExecuteContext(pydantic.BaseModel):
"""单次命令执行上下文
"""
"""单次命令执行上下文"""
query: core_entities.Query
"""本次消息的请求对象"""

View File

@@ -1,33 +1,26 @@
class CommandError(Exception):
def __init__(self, message: str = None):
self.message = message
def __str__(self):
return self.message
class CommandNotFoundError(CommandError):
def __init__(self, message: str = None):
super().__init__("未知命令: "+message)
super().__init__('未知命令: ' + message)
class CommandPrivilegeError(CommandError):
def __init__(self, message: str = None):
super().__init__("权限不足: "+message)
super().__init__('权限不足: ' + message)
class ParamNotEnoughError(CommandError):
def __init__(self, message: str = None):
super().__init__("参数不足: "+message)
super().__init__('参数不足: ' + message)
class CommandOperationError(CommandError):
def __init__(self, message: str = None):
super().__init__("操作失败: "+message)
super().__init__('操作失败: ' + message)

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import typing
import abc
from ..core import app, entities as core_entities
from ..core import app
from . import entities
@@ -13,14 +13,14 @@ preregistered_operators: list[typing.Type[CommandOperator]] = []
def operator_class(
name: str,
help: str = "",
help: str = '',
usage: str = None,
alias: list[str] = [],
privilege: int=1, # 1为普通用户2为管理员
parent_class: typing.Type[CommandOperator] = None
privilege: int = 1, # 1为普通用户2为管理员
parent_class: typing.Type[CommandOperator] = None,
) -> typing.Callable[[typing.Type[CommandOperator]], typing.Type[CommandOperator]]:
"""命令类装饰器
Args:
name (str): 名称
help (str, optional): 帮助信息. Defaults to "".
@@ -35,7 +35,7 @@ def operator_class(
def decorator(cls: typing.Type[CommandOperator]) -> typing.Type[CommandOperator]:
assert issubclass(cls, CommandOperator)
cls.name = name
cls.alias = alias
cls.help = help
@@ -96,14 +96,13 @@ class CommandOperator(metaclass=abc.ABCMeta):
@abc.abstractmethod
async def execute(
self,
context: entities.ExecuteContext
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""实现此方法以执行命令
支持多次yield以返回多个结果。
例如:一个安装插件的命令,可能会有下载、解压、安装等多个步骤,每个步骤都可以返回一个结果。
Args:
context (entities.ExecuteContext): 命令执行上下文

View File

@@ -2,49 +2,46 @@ from __future__ import annotations
import typing
from .. import operator, entities, cmdmgr, errors
from .. import operator, entities, errors
@operator.operator_class(
name="cmd",
help='显示命令列表',
usage='!cmd\n!cmd <命令名称>'
)
@operator.operator_class(name='cmd', help='显示命令列表', usage='!cmd\n!cmd <命令名称>')
class CmdOperator(operator.CommandOperator):
"""命令列表
"""
"""命令列表"""
async def execute(
self,
context: entities.ExecuteContext
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""执行
"""
"""执行"""
if len(context.crt_params) == 0:
reply_str = "当前所有命令: \n\n"
reply_str = '当前所有命令: \n\n'
for cmd in self.ap.cmd_mgr.cmd_list:
if cmd.parent_class is None:
reply_str += f"{cmd.name}: {cmd.help}\n"
reply_str += "\n使用 !cmd <命令名称> 查看命令的详细帮助"
reply_str += f'{cmd.name}: {cmd.help}\n'
reply_str += '\n使用 !cmd <命令名称> 查看命令的详细帮助'
yield entities.CommandReturn(text=reply_str.strip())
else:
cmd_name = context.crt_params[0]
cmd = None
for _cmd in self.ap.cmd_mgr.cmd_list:
if (cmd_name == _cmd.name or cmd_name in _cmd.alias) and (_cmd.parent_class is None):
if (cmd_name == _cmd.name or cmd_name in _cmd.alias) and (
_cmd.parent_class is None
):
cmd = _cmd
break
if cmd is None:
yield entities.CommandReturn(error=errors.CommandNotFoundError(cmd_name))
yield entities.CommandReturn(
error=errors.CommandNotFoundError(cmd_name)
)
else:
reply_str = f"{cmd.name}: {cmd.help}\n\n"
reply_str += f"使用方法: \n{cmd.usage}"
reply_str = f'{cmd.name}: {cmd.help}\n\n'
reply_str += f'使用方法: \n{cmd.usage}'
yield entities.CommandReturn(text=reply_str.strip())

View File

@@ -1,62 +1,60 @@
from __future__ import annotations
import typing
import datetime
from .. import operator, entities, cmdmgr, errors
from .. import operator, entities, errors
@operator.operator_class(
name="del",
help="删除当前会话的历史记录",
usage='!del <序号>\n!del all'
name='del', help='删除当前会话的历史记录', usage='!del <序号>\n!del all'
)
class DelOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if context.session.conversations:
delete_index = 0
if len(context.crt_params) > 0:
try:
delete_index = int(context.crt_params[0])
except:
yield entities.CommandReturn(error=errors.CommandOperationError('索引必须是整数'))
except Exception:
yield entities.CommandReturn(
error=errors.CommandOperationError('索引必须是整数')
)
return
if delete_index < 0 or delete_index >= len(context.session.conversations):
yield entities.CommandReturn(error=errors.CommandOperationError('索引超出范围'))
return
# 倒序
to_delete_index = len(context.session.conversations)-1-delete_index
if context.session.conversations[to_delete_index] == context.session.using_conversation:
if delete_index < 0 or delete_index >= len(context.session.conversations):
yield entities.CommandReturn(
error=errors.CommandOperationError('索引超出范围')
)
return
# 倒序
to_delete_index = len(context.session.conversations) - 1 - delete_index
if (
context.session.conversations[to_delete_index]
== context.session.using_conversation
):
context.session.using_conversation = None
del context.session.conversations[to_delete_index]
yield entities.CommandReturn(text=f"已删除对话: {delete_index}")
yield entities.CommandReturn(text=f'已删除对话: {delete_index}')
else:
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))
yield entities.CommandReturn(
error=errors.CommandOperationError('当前没有对话')
)
@operator.operator_class(
name="all",
help="删除此会话的所有历史记录",
parent_class=DelOperator
name='all', help='删除此会话的所有历史记录', parent_class=DelOperator
)
class DelAllOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
context.session.conversations = []
context.session.using_conversation = None
yield entities.CommandReturn(text="已删除所有对话")
yield entities.CommandReturn(text='已删除所有对话')

View File

@@ -1,16 +1,15 @@
from __future__ import annotations
from typing import AsyncGenerator
from .. import operator, entities, cmdmgr
from ...plugin import context as plugin_context
from .. import operator, entities
@operator.operator_class(name="func", help="查看所有已注册的内容函数", usage='!func')
@operator.operator_class(name='func', help='查看所有已注册的内容函数', usage='!func')
class FuncOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> AsyncGenerator[entities.CommandReturn, None]:
reply_str = "当前已启用的内容函数: \n\n"
reply_str = '当前已启用的内容函数: \n\n'
index = 1
@@ -19,7 +18,7 @@ class FuncOperator(operator.CommandOperator):
)
for func in all_functions:
reply_str += "{}. {}:\n{}\n\n".format(
reply_str += '{}. {}:\n{}\n\n'.format(
index,
func.name,
func.description,

View File

@@ -2,19 +2,13 @@ from __future__ import annotations
import typing
from .. import operator, entities, cmdmgr, errors
from .. import operator, entities
@operator.operator_class(
name='help',
help='显示帮助',
usage='!help\n!help <命令名称>'
)
@operator.operator_class(name='help', help='显示帮助', usage='!help\n!help <命令名称>')
class HelpOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
help = 'LangBot - 大语言模型原生即时通信机器人平台\n链接https://langbot.app'

View File

@@ -1,36 +1,43 @@
from __future__ import annotations
import typing
import datetime
from .. import operator, entities, cmdmgr, errors
from .. import operator, entities, errors
@operator.operator_class(
name="last",
help="切换到前一个对话",
usage='!last'
)
@operator.operator_class(name='last', help='切换到前一个对话', usage='!last')
class LastOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if context.session.conversations:
# 找到当前会话的上一个会话
for index in range(len(context.session.conversations)-1, -1, -1):
if context.session.conversations[index] == context.session.using_conversation:
for index in range(len(context.session.conversations) - 1, -1, -1):
if (
context.session.conversations[index]
== context.session.using_conversation
):
if index == 0:
yield entities.CommandReturn(error=errors.CommandOperationError('已经是第一个对话了'))
yield entities.CommandReturn(
error=errors.CommandOperationError('已经是第一个对话了')
)
return
else:
context.session.using_conversation = context.session.conversations[index-1]
time_str = context.session.using_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S")
context.session.using_conversation = (
context.session.conversations[index - 1]
)
time_str = (
context.session.using_conversation.create_time.strftime(
'%Y-%m-%d %H:%M:%S'
)
)
yield entities.CommandReturn(text=f"已切换到上一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].readable_str()}")
yield entities.CommandReturn(
text=f'已切换到上一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].readable_str()}'
)
return
else:
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))
yield entities.CommandReturn(
error=errors.CommandOperationError('当前没有对话')
)

View File

@@ -1,30 +1,26 @@
from __future__ import annotations
import typing
import datetime
from .. import operator, entities, cmdmgr, errors
from .. import operator, entities, errors
@operator.operator_class(
name="list",
help="列出此会话中的所有历史对话",
usage='!list\n!list <页码>'
name='list', help='列出此会话中的所有历史对话', usage='!list\n!list <页码>'
)
class ListOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
page = 0
if len(context.crt_params) > 0:
try:
page = int(context.crt_params[0]-1)
except:
yield entities.CommandReturn(error=errors.CommandOperationError('页码应为整数'))
page = int(context.crt_params[0] - 1)
except Exception:
yield entities.CommandReturn(
error=errors.CommandOperationError('页码应为整数')
)
return
record_per_page = 10
@@ -36,21 +32,21 @@ class ListOperator(operator.CommandOperator):
using_conv_index = 0
for conv in context.session.conversations[::-1]:
time_str = conv.create_time.strftime("%Y-%m-%d %H:%M:%S")
time_str = conv.create_time.strftime('%Y-%m-%d %H:%M:%S')
if conv == context.session.using_conversation:
using_conv_index = index
if index >= page * record_per_page and index < (page + 1) * record_per_page:
content += f"{index} {time_str}: {conv.messages[0].readable_str() if len(conv.messages) > 0 else '无内容'}\n"
content += f'{index} {time_str}: {conv.messages[0].readable_str() if len(conv.messages) > 0 else "无内容"}\n'
index += 1
if content == '':
content = ''
else:
if context.session.using_conversation is None:
content += "\n当前处于新会话"
content += '\n当前处于新会话'
else:
content += f"\n当前会话: {using_conv_index} {context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S')}: {context.session.using_conversation.messages[0].readable_str() if len(context.session.using_conversation.messages) > 0 else '无内容'}"
yield entities.CommandReturn(text=f"{page + 1} 页 (时间倒序):\n{content}")
content += f'\n当前会话: {using_conv_index} {context.session.using_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S")}: {context.session.using_conversation.messages[0].readable_str() if len(context.session.using_conversation.messages) > 0 else "无内容"}'
yield entities.CommandReturn(text=f'{page + 1} 页 (时间倒序):\n{content}')

View File

@@ -2,42 +2,44 @@ from __future__ import annotations
import typing
from .. import operator, entities, cmdmgr, errors
from .. import operator, entities, errors
@operator.operator_class(
name="model",
name='model',
help='显示和切换模型列表',
usage='!model\n!model show <模型名>\n!model set <模型名>',
privilege=2
privilege=2,
)
class ModelOperator(operator.CommandOperator):
"""Model命令"""
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
content = '模型列表:\n'
model_list = self.ap.model_mgr.model_list
for model in model_list:
content += f"\n名称: {model.name}\n"
content += f"请求器: {model.requester.name}\n"
content += f'\n名称: {model.name}\n'
content += f'请求器: {model.requester.name}\n'
content += f"\n当前对话使用模型: {context.query.use_model.name}\n"
content += f"新对话默认使用模型: {self.ap.provider_cfg.data.get('model')}\n"
content += f'\n当前对话使用模型: {context.query.use_model.name}\n'
content += f'新对话默认使用模型: {self.ap.provider_cfg.data.get("model")}\n'
yield entities.CommandReturn(text=content.strip())
@operator.operator_class(
name="show",
help='显示模型详情',
privilege=2,
parent_class=ModelOperator
name='show', help='显示模型详情', privilege=2, parent_class=ModelOperator
)
class ModelShowOperator(operator.CommandOperator):
"""Model Show命令"""
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
model_name = context.crt_params[0]
model = None
@@ -47,29 +49,31 @@ class ModelShowOperator(operator.CommandOperator):
break
if model is None:
yield entities.CommandReturn(error=errors.CommandError(f"未找到模型 {model_name}"))
yield entities.CommandReturn(
error=errors.CommandError(f'未找到模型 {model_name}')
)
else:
content = f"模型详情\n"
content += f"名称: {model.name}\n"
content = '模型详情\n'
content += f'名称: {model.name}\n'
if model.model_name is not None:
content += f"请求模型名称: {model.model_name}\n"
content += f"请求器: {model.requester.name}\n"
content += f"密钥组: {model.token_mgr.name}\n"
content += f"支持视觉: {model.vision_supported}\n"
content += f"支持工具: {model.tool_call_supported}\n"
content += f'请求模型名称: {model.model_name}\n'
content += f'请求器: {model.requester.name}\n'
content += f'密钥组: {model.token_mgr.name}\n'
content += f'支持视觉: {model.vision_supported}\n'
content += f'支持工具: {model.tool_call_supported}\n'
yield entities.CommandReturn(text=content.strip())
@operator.operator_class(
name="set",
help='设置默认使用模型',
privilege=2,
parent_class=ModelOperator
name='set', help='设置默认使用模型', privilege=2, parent_class=ModelOperator
)
class ModelSetOperator(operator.CommandOperator):
"""Model Set命令"""
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
model_name = context.crt_params[0]
model = None
@@ -79,8 +83,12 @@ class ModelSetOperator(operator.CommandOperator):
break
if model is None:
yield entities.CommandReturn(error=errors.CommandError(f"未找到模型 {model_name}"))
yield entities.CommandReturn(
error=errors.CommandError(f'未找到模型 {model_name}')
)
else:
self.ap.provider_cfg.data['model'] = model_name
await self.ap.provider_cfg.dump_config()
yield entities.CommandReturn(text=f"已设置当前使用模型为 {model_name},重置会话以生效")
yield entities.CommandReturn(
text=f'已设置当前使用模型为 {model_name},重置会话以生效'
)

View File

@@ -1,35 +1,42 @@
from __future__ import annotations
import typing
import datetime
from .. import operator, entities, cmdmgr, errors
from .. import operator, entities, errors
@operator.operator_class(
name="next",
help="切换到后一个对话",
usage='!next'
)
@operator.operator_class(name='next', help='切换到后一个对话', usage='!next')
class NextOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if context.session.conversations:
# 找到当前会话的下一个会话
for index in range(len(context.session.conversations)):
if context.session.conversations[index] == context.session.using_conversation:
if index == len(context.session.conversations)-1:
yield entities.CommandReturn(error=errors.CommandOperationError('已经是最后一个对话了'))
if (
context.session.conversations[index]
== context.session.using_conversation
):
if index == len(context.session.conversations) - 1:
yield entities.CommandReturn(
error=errors.CommandOperationError('已经是最后一个对话了')
)
return
else:
context.session.using_conversation = context.session.conversations[index+1]
time_str = context.session.using_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S")
context.session.using_conversation = (
context.session.conversations[index + 1]
)
time_str = (
context.session.using_conversation.create_time.strftime(
'%Y-%m-%d %H:%M:%S'
)
)
yield entities.CommandReturn(text=f"已切换到后一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].content}")
yield entities.CommandReturn(
text=f'已切换到后一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].content}'
)
return
else:
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))
yield entities.CommandReturn(
error=errors.CommandOperationError('当前没有对话')
)

View File

@@ -2,31 +2,32 @@ from __future__ import annotations
import json
import typing
import traceback
import ollama
from .. import operator, entities, errors
@operator.operator_class(
name="ollama",
help="ollama平台操作",
usage="!ollama\n!ollama show <模型名>\n!ollama pull <模型名>\n!ollama del <模型名>"
name='ollama',
help='ollama平台操作',
usage='!ollama\n!ollama show <模型名>\n!ollama pull <模型名>\n!ollama del <模型名>',
)
class OllamaOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
try:
content: str = '模型列表:\n'
model_list: list = ollama.list().get('models', [])
for model in model_list:
content += f"名称: {model['name']}\n"
content += f"修改时间: {model['modified_at']}\n"
content += f"大小: {bytes_to_mb(model['size'])}MB\n\n"
yield entities.CommandReturn(text=f"{content.strip()}")
except ollama.ResponseError as e:
yield entities.CommandReturn(error=errors.CommandError(f"无法获取模型列表,请确认 Ollama 服务正常"))
content += f'名称: {model["name"]}\n'
content += f'修改时间: {model["modified_at"]}\n'
content += f'大小: {bytes_to_mb(model["size"])}MB\n\n'
yield entities.CommandReturn(text=f'{content.strip()}')
except ollama.ResponseError:
yield entities.CommandReturn(
error=errors.CommandError('无法获取模型列表,请确认 Ollama 服务正常')
)
def bytes_to_mb(num_bytes):
@@ -35,14 +36,11 @@ def bytes_to_mb(num_bytes):
@operator.operator_class(
name="show",
help="ollama模型详情",
privilege=2,
parent_class=OllamaOperator
name='show', help='ollama模型详情', privilege=2, parent_class=OllamaOperator
)
class OllamaShowOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
content: str = '模型详情:\n'
try:
@@ -53,31 +51,36 @@ class OllamaShowOperator(operator.CommandOperator):
for key in ['license', 'modelfile']:
show[key] = ignore_show
for key in ['tokenizer.chat_template.rag', 'tokenizer.chat_template.tool_use']:
for key in [
'tokenizer.chat_template.rag',
'tokenizer.chat_template.tool_use',
]:
model_info[key] = ignore_show
content += json.dumps(show, indent=4)
yield entities.CommandReturn(text=content.strip())
except ollama.ResponseError as e:
yield entities.CommandReturn(error=errors.CommandError(f"无法获取模型详情,请确认 Ollama 服务正常"))
except ollama.ResponseError:
yield entities.CommandReturn(
error=errors.CommandError('无法获取模型详情,请确认 Ollama 服务正常')
)
@operator.operator_class(
name="pull",
help="ollama模型拉取",
privilege=2,
parent_class=OllamaOperator
name='pull', help='ollama模型拉取', privilege=2, parent_class=OllamaOperator
)
class OllamaPullOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
try:
model_list: list = ollama.list().get('models', [])
if context.crt_params[0] in [model['name'] for model in model_list]:
yield entities.CommandReturn(text="模型已存在")
yield entities.CommandReturn(text='模型已存在')
return
except ollama.ResponseError as e:
yield entities.CommandReturn(error=errors.CommandError(f"无法获取模型列表,请确认 Ollama 服务正常"))
except ollama.ResponseError:
yield entities.CommandReturn(
error=errors.CommandError('无法获取模型列表,请确认 Ollama 服务正常')
)
return
on_progress: bool = False
@@ -99,23 +102,21 @@ class OllamaPullOperator(operator.CommandOperator):
if percentage_completed > progress_count:
progress_count += 10
yield entities.CommandReturn(
text=f"下载进度: {completed}/{total} ({percentage_completed:.2f}%)")
text=f'下载进度: {completed}/{total} ({percentage_completed:.2f}%)'
)
except ollama.ResponseError as e:
yield entities.CommandReturn(text=f"拉取失败: {e.error}")
yield entities.CommandReturn(text=f'拉取失败: {e.error}')
@operator.operator_class(
name="del",
help="ollama模型删除",
privilege=2,
parent_class=OllamaOperator
name='del', help='ollama模型删除', privilege=2, parent_class=OllamaOperator
)
class OllamaDelOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
try:
ret: str = ollama.delete(model=context.crt_params[0])['status']
except ollama.ResponseError as e:
ret = f"{e.error}"
ret = f'{e.error}'
yield entities.CommandReturn(text=ret)

View File

@@ -2,31 +2,30 @@ from __future__ import annotations
import typing
import traceback
from .. import operator, entities, cmdmgr, errors
from ...core import app
from .. import operator, entities, errors
@operator.operator_class(
name="plugin",
help="插件操作",
usage="!plugin\n!plugin get <插件仓库地址>\n!plugin update\n!plugin del <插件名>\n!plugin on <插件名>\n!plugin off <插件名>"
name='plugin',
help='插件操作',
usage='!plugin\n!plugin get <插件仓库地址>\n!plugin update\n!plugin del <插件名>\n!plugin on <插件名>\n!plugin off <插件名>',
)
class PluginOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
plugin_list = self.ap.plugin_mgr.plugins()
reply_str = "所有插件({}):\n".format(len(plugin_list))
reply_str = '所有插件({}):\n'.format(len(plugin_list))
idx = 0
for plugin in plugin_list:
reply_str += "\n#{} {} {}\n{}\nv{}\n作者: {}\n"\
.format((idx+1), plugin.plugin_name,
"[已禁用]" if not plugin.enabled else "",
plugin.plugin_description,
plugin.plugin_version, plugin.plugin_author)
reply_str += '\n#{} {} {}\n{}\nv{}\n作者: {}\n'.format(
(idx + 1),
plugin.plugin_name,
'[已禁用]' if not plugin.enabled else '',
plugin.plugin_description,
plugin.plugin_version,
plugin.plugin_author,
)
idx += 1
@@ -34,48 +33,42 @@ class PluginOperator(operator.CommandOperator):
@operator.operator_class(
name="get",
help="安装插件",
privilege=2,
parent_class=PluginOperator
name='get', help='安装插件', privilege=2, parent_class=PluginOperator
)
class PluginGetOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if len(context.crt_params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件仓库地址'))
yield entities.CommandReturn(
error=errors.ParamNotEnoughError('请提供插件仓库地址')
)
else:
repo = context.crt_params[0]
yield entities.CommandReturn(text="正在安装插件...")
yield entities.CommandReturn(text='正在安装插件...')
try:
await self.ap.plugin_mgr.install_plugin(repo)
yield entities.CommandReturn(text="插件安装成功,请重启程序以加载插件")
yield entities.CommandReturn(text='插件安装成功,请重启程序以加载插件')
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("插件安装失败: "+str(e)))
yield entities.CommandReturn(
error=errors.CommandError('插件安装失败: ' + str(e))
)
@operator.operator_class(
name="update",
help="更新插件",
privilege=2,
parent_class=PluginOperator
name='update', help='更新插件', privilege=2, parent_class=PluginOperator
)
class PluginUpdateOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if len(context.crt_params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
yield entities.CommandReturn(
error=errors.ParamNotEnoughError('请提供插件名称')
)
else:
plugin_name = context.crt_params[0]
@@ -83,36 +76,34 @@ class PluginUpdateOperator(operator.CommandOperator):
plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name)
if plugin_container is not None:
yield entities.CommandReturn(text="正在更新插件...")
yield entities.CommandReturn(text='正在更新插件...')
await self.ap.plugin_mgr.update_plugin(plugin_name)
yield entities.CommandReturn(text="插件更新成功,请重启程序以加载插件")
yield entities.CommandReturn(
text='插件更新成功,请重启程序以加载插件'
)
else:
yield entities.CommandReturn(error=errors.CommandError("插件更新失败: 未找到插件"))
yield entities.CommandReturn(
error=errors.CommandError('插件更新失败: 未找到插件')
)
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("插件更新失败: "+str(e)))
yield entities.CommandReturn(
error=errors.CommandError('插件更新失败: ' + str(e))
)
@operator.operator_class(
name="all",
help="更新所有插件",
privilege=2,
parent_class=PluginUpdateOperator
name='all', help='更新所有插件', privilege=2, parent_class=PluginUpdateOperator
)
class PluginUpdateAllOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
try:
plugins = [
p.plugin_name
for p in self.ap.plugin_mgr.plugins()
]
plugins = [p.plugin_name for p in self.ap.plugin_mgr.plugins()]
if plugins:
yield entities.CommandReturn(text="正在更新插件...")
yield entities.CommandReturn(text='正在更新插件...')
updated = []
try:
for plugin_name in plugins:
@@ -120,30 +111,32 @@ class PluginUpdateAllOperator(operator.CommandOperator):
updated.append(plugin_name)
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("插件更新失败: "+str(e)))
yield entities.CommandReturn(text="已更新插件: {}".format(", ".join(updated)))
yield entities.CommandReturn(
error=errors.CommandError('插件更新失败: ' + str(e))
)
yield entities.CommandReturn(
text='已更新插件: {}'.format(', '.join(updated))
)
else:
yield entities.CommandReturn(text="没有可更新的插件")
yield entities.CommandReturn(text='没有可更新的插件')
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("插件更新失败: "+str(e)))
yield entities.CommandReturn(
error=errors.CommandError('插件更新失败: ' + str(e))
)
@operator.operator_class(
name="del",
help="删除插件",
privilege=2,
parent_class=PluginOperator
name='del', help='删除插件', privilege=2, parent_class=PluginOperator
)
class PluginDelOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if len(context.crt_params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
yield entities.CommandReturn(
error=errors.ParamNotEnoughError('请提供插件名称')
)
else:
plugin_name = context.crt_params[0]
@@ -151,67 +144,81 @@ class PluginDelOperator(operator.CommandOperator):
plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name)
if plugin_container is not None:
yield entities.CommandReturn(text="正在删除插件...")
yield entities.CommandReturn(text='正在删除插件...')
await self.ap.plugin_mgr.uninstall_plugin(plugin_name)
yield entities.CommandReturn(text="插件删除成功,请重启程序以加载插件")
yield entities.CommandReturn(
text='插件删除成功,请重启程序以加载插件'
)
else:
yield entities.CommandReturn(error=errors.CommandError("插件删除失败: 未找到插件"))
yield entities.CommandReturn(
error=errors.CommandError('插件删除失败: 未找到插件')
)
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("插件删除失败: "+str(e)))
yield entities.CommandReturn(
error=errors.CommandError('插件删除失败: ' + str(e))
)
@operator.operator_class(
name="on",
help="启用插件",
privilege=2,
parent_class=PluginOperator
name='on', help='启用插件', privilege=2, parent_class=PluginOperator
)
class PluginEnableOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if len(context.crt_params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
yield entities.CommandReturn(
error=errors.ParamNotEnoughError('请提供插件名称')
)
else:
plugin_name = context.crt_params[0]
try:
if await self.ap.plugin_mgr.update_plugin_switch(plugin_name, True):
yield entities.CommandReturn(text="已启用插件: {}".format(plugin_name))
yield entities.CommandReturn(
text='已启用插件: {}'.format(plugin_name)
)
else:
yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: 未找到插件 {}".format(plugin_name)))
yield entities.CommandReturn(
error=errors.CommandError(
'插件状态修改失败: 未找到插件 {}'.format(plugin_name)
)
)
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: "+str(e)))
yield entities.CommandReturn(
error=errors.CommandError('插件状态修改失败: ' + str(e))
)
@operator.operator_class(
name="off",
help="禁用插件",
privilege=2,
parent_class=PluginOperator
name='off', help='禁用插件', privilege=2, parent_class=PluginOperator
)
class PluginDisableOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if len(context.crt_params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
yield entities.CommandReturn(
error=errors.ParamNotEnoughError('请提供插件名称')
)
else:
plugin_name = context.crt_params[0]
try:
if await self.ap.plugin_mgr.update_plugin_switch(plugin_name, False):
yield entities.CommandReturn(text="已禁用插件: {}".format(plugin_name))
yield entities.CommandReturn(
text='已禁用插件: {}'.format(plugin_name)
)
else:
yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: 未找到插件 {}".format(plugin_name)))
yield entities.CommandReturn(
error=errors.CommandError(
'插件状态修改失败: 未找到插件 {}'.format(plugin_name)
)
)
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: "+str(e)))
yield entities.CommandReturn(
error=errors.CommandError('插件状态修改失败: ' + str(e))
)

View File

@@ -2,28 +2,23 @@ from __future__ import annotations
import typing
from .. import operator, entities, cmdmgr, errors
from .. import operator, entities, errors
@operator.operator_class(
name="prompt",
help="查看当前对话的前文",
usage='!prompt'
)
@operator.operator_class(name='prompt', help='查看当前对话的前文', usage='!prompt')
class PromptOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""执行
"""
"""执行"""
if context.session.using_conversation is None:
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))
yield entities.CommandReturn(
error=errors.CommandOperationError('当前没有对话')
)
else:
reply_str = '当前对话所有内容:\n\n'
for msg in context.session.using_conversation.messages:
reply_str += f"{msg.role}: {msg.content}\n"
reply_str += f'{msg.role}: {msg.content}\n'
yield entities.CommandReturn(text=reply_str)
yield entities.CommandReturn(text=reply_str)

View File

@@ -2,26 +2,22 @@ from __future__ import annotations
import typing
from .. import operator, entities, cmdmgr, errors
from .. import operator, entities, errors
@operator.operator_class(
name="resend",
help="重发当前会话的最后一条消息",
usage='!resend'
name='resend', help='重发当前会话的最后一条消息', usage='!resend'
)
class ResendOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
# 回滚到最后一条用户message前
if context.session.using_conversation is None:
yield entities.CommandReturn(error=errors.CommandError("当前没有对话"))
yield entities.CommandReturn(error=errors.CommandError('当前没有对话'))
else:
conv_msg = context.session.using_conversation.messages
# 倒序一直删到最后一条用户message
while len(conv_msg) > 0 and conv_msg[-1].role != 'user':
conv_msg.pop()
@@ -31,4 +27,4 @@ class ResendOperator(operator.CommandOperator):
conv_msg.pop()
# 不重发了,提示用户已删除就行了
yield entities.CommandReturn(text="已删除最后一次请求记录")
yield entities.CommandReturn(text='已删除最后一次请求记录')

View File

@@ -2,22 +2,15 @@ from __future__ import annotations
import typing
from .. import operator, entities, cmdmgr, errors
from .. import operator, entities
@operator.operator_class(
name="reset",
help="重置当前会话",
usage='!reset'
)
@operator.operator_class(name='reset', help='重置当前会话', usage='!reset')
class ResetOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""执行
"""
"""执行"""
context.session.using_conversation = None
yield entities.CommandReturn(text="已重置当前会话")
yield entities.CommandReturn(text='已重置当前会话')

View File

@@ -3,28 +3,22 @@ from __future__ import annotations
import typing
import traceback
from .. import operator, entities, cmdmgr, errors
from .. import operator, entities, errors
@operator.operator_class(
name="update",
help="更新程序",
usage='!update',
privilege=2
)
@operator.operator_class(name='update', help='更新程序', usage='!update', privilege=2)
class UpdateCommand(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
try:
yield entities.CommandReturn(text="正在进行更新...")
yield entities.CommandReturn(text='正在进行更新...')
if await self.ap.ver_mgr.update_all():
yield entities.CommandReturn(text="更新完成,请重启程序以应用更新")
yield entities.CommandReturn(text='更新完成,请重启程序以应用更新')
else:
yield entities.CommandReturn(text="当前已是最新版本")
yield entities.CommandReturn(text='当前已是最新版本')
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("更新失败: "+str(e)))
yield entities.CommandReturn(
error=errors.CommandError('更新失败: ' + str(e))
)

View File

@@ -2,26 +2,20 @@ from __future__ import annotations
import typing
from .. import operator, cmdmgr, entities, errors
from .. import operator, entities
@operator.operator_class(
name="version",
help="显示版本信息",
usage='!version'
)
@operator.operator_class(name='version', help='显示版本信息', usage='!version')
class VersionCommand(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
reply_str = f"当前版本: \n{self.ap.ver_mgr.get_current_version()}"
reply_str = f'当前版本: \n{self.ap.ver_mgr.get_current_version()}'
try:
if await self.ap.ver_mgr.is_new_version_available():
reply_str += "\n\n有新版本可用。"
except:
reply_str += '\n\n有新版本可用。'
except Exception:
pass
yield entities.CommandReturn(text=reply_str.strip())
yield entities.CommandReturn(text=reply_str.strip())

View File

@@ -9,7 +9,10 @@ class JSONConfigFile(file_model.ConfigFile):
"""JSON配置文件"""
def __init__(
self, config_file_name: str, template_file_name: str = None, template_data: dict = None
self,
config_file_name: str,
template_file_name: str = None,
template_data: dict = None,
) -> None:
self.config_file_name = config_file_name
self.template_file_name = template_file_name
@@ -22,28 +25,26 @@ class JSONConfigFile(file_model.ConfigFile):
if self.template_file_name is not None:
shutil.copyfile(self.template_file_name, self.config_file_name)
elif self.template_data is not None:
with open(self.config_file_name, "w", encoding="utf-8") as f:
with open(self.config_file_name, 'w', encoding='utf-8') as f:
json.dump(self.template_data, f, indent=4, ensure_ascii=False)
else:
raise ValueError("template_file_name or template_data must be provided")
async def load(self, completion: bool=True) -> dict:
raise ValueError('template_file_name or template_data must be provided')
async def load(self, completion: bool = True) -> dict:
if not self.exists():
await self.create()
if self.template_file_name is not None:
with open(self.template_file_name, "r", encoding="utf-8") as f:
with open(self.template_file_name, 'r', encoding='utf-8') as f:
self.template_data = json.load(f)
with open(self.config_file_name, "r", encoding="utf-8") as f:
with open(self.config_file_name, 'r', encoding='utf-8') as f:
try:
cfg = json.load(f)
except json.JSONDecodeError as e:
raise Exception(f"配置文件 {self.config_file_name} 语法错误: {e}")
raise Exception(f'配置文件 {self.config_file_name} 语法错误: {e}')
if completion:
for key in self.template_data:
if key not in cfg:
cfg[key] = self.template_data[key]
@@ -51,9 +52,9 @@ class JSONConfigFile(file_model.ConfigFile):
return cfg
async def save(self, cfg: dict):
with open(self.config_file_name, "w", encoding="utf-8") as f:
with open(self.config_file_name, 'w', encoding='utf-8') as f:
json.dump(cfg, f, indent=4, ensure_ascii=False)
def save_sync(self, cfg: dict):
with open(self.config_file_name, "w", encoding="utf-8") as f:
with open(self.config_file_name, 'w', encoding='utf-8') as f:
json.dump(cfg, f, indent=4, ensure_ascii=False)

View File

@@ -25,10 +25,10 @@ class PythonModuleConfigFile(file_model.ConfigFile):
async def create(self):
shutil.copyfile(self.template_file_name, self.config_file_name)
async def load(self, completion: bool=True) -> dict:
async def load(self, completion: bool = True) -> dict:
module_name = os.path.splitext(os.path.basename(self.config_file_name))[0]
module = importlib.import_module(module_name)
cfg = {}
allowed_types = (int, float, str, bool, list, dict)
@@ -63,4 +63,4 @@ class PythonModuleConfigFile(file_model.ConfigFile):
logging.warning('Python模块配置文件不支持保存')
def save_sync(self, data: dict):
logging.warning('Python模块配置文件不支持保存')
logging.warning('Python模块配置文件不支持保存')

View File

@@ -9,7 +9,10 @@ class YAMLConfigFile(file_model.ConfigFile):
"""YAML配置文件"""
def __init__(
self, config_file_name: str, template_file_name: str = None, template_data: dict = None
self,
config_file_name: str,
template_file_name: str = None,
template_data: dict = None,
) -> None:
self.config_file_name = config_file_name
self.template_file_name = template_file_name
@@ -22,28 +25,26 @@ class YAMLConfigFile(file_model.ConfigFile):
if self.template_file_name is not None:
shutil.copyfile(self.template_file_name, self.config_file_name)
elif self.template_data is not None:
with open(self.config_file_name, "w", encoding="utf-8") as f:
with open(self.config_file_name, 'w', encoding='utf-8') as f:
yaml.dump(self.template_data, f, indent=4, allow_unicode=True)
else:
raise ValueError("template_file_name or template_data must be provided")
async def load(self, completion: bool=True) -> dict:
raise ValueError('template_file_name or template_data must be provided')
async def load(self, completion: bool = True) -> dict:
if not self.exists():
await self.create()
if self.template_file_name is not None:
with open(self.template_file_name, "r", encoding="utf-8") as f:
with open(self.template_file_name, 'r', encoding='utf-8') as f:
self.template_data = yaml.load(f, Loader=yaml.FullLoader)
with open(self.config_file_name, "r", encoding="utf-8") as f:
with open(self.config_file_name, 'r', encoding='utf-8') as f:
try:
cfg = yaml.load(f, Loader=yaml.FullLoader)
except yaml.YAMLError as e:
raise Exception(f"配置文件 {self.config_file_name} 语法错误: {e}")
raise Exception(f'配置文件 {self.config_file_name} 语法错误: {e}')
if completion:
for key in self.template_data:
if key not in cfg:
cfg[key] = self.template_data[key]
@@ -51,9 +52,9 @@ class YAMLConfigFile(file_model.ConfigFile):
return cfg
async def save(self, cfg: dict):
with open(self.config_file_name, "w", encoding="utf-8") as f:
with open(self.config_file_name, 'w', encoding='utf-8') as f:
yaml.dump(cfg, f, indent=4, allow_unicode=True)
def save_sync(self, cfg: dict):
with open(self.config_file_name, "w", encoding="utf-8") as f:
yaml.dump(cfg, f, indent=4, allow_unicode=True)
with open(self.config_file_name, 'w', encoding='utf-8') as f:
yaml.dump(cfg, f, indent=4, allow_unicode=True)

View File

@@ -6,7 +6,7 @@ from .impls import pymodule, json as json_file, yaml as yaml_file
class ConfigManager:
"""配置文件管理器"""
name: str = None
"""配置管理器名"""
@@ -31,7 +31,7 @@ class ConfigManager:
self.file = cfg_file
self.data = {}
async def load_config(self, completion: bool=True):
async def load_config(self, completion: bool = True):
self.data = await self.file.load(completion=completion)
async def dump_config(self):
@@ -41,9 +41,11 @@ class ConfigManager:
self.file.save_sync(self.data)
async def load_python_module_config(config_name: str, template_name: str, completion: bool=True) -> ConfigManager:
async def load_python_module_config(
config_name: str, template_name: str, completion: bool = True
) -> ConfigManager:
"""加载Python模块配置文件
Args:
config_name (str): 配置文件名
template_name (str): 模板文件名
@@ -52,10 +54,7 @@ async def load_python_module_config(config_name: str, template_name: str, comple
Returns:
ConfigManager: 配置文件管理器
"""
cfg_inst = pymodule.PythonModuleConfigFile(
config_name,
template_name
)
cfg_inst = pymodule.PythonModuleConfigFile(config_name, template_name)
cfg_mgr = ConfigManager(cfg_inst)
await cfg_mgr.load_config(completion=completion)
@@ -63,20 +62,21 @@ async def load_python_module_config(config_name: str, template_name: str, comple
return cfg_mgr
async def load_json_config(config_name: str, template_name: str=None, template_data: dict=None, completion: bool=True) -> ConfigManager:
async def load_json_config(
config_name: str,
template_name: str = None,
template_data: dict = None,
completion: bool = True,
) -> ConfigManager:
"""加载JSON配置文件
Args:
config_name (str): 配置文件名
template_name (str): 模板文件名
template_data (dict): 模板数据
completion (bool): 是否自动补全内存中的配置文件
"""
cfg_inst = json_file.JSONConfigFile(
config_name,
template_name,
template_data
)
cfg_inst = json_file.JSONConfigFile(config_name, template_name, template_data)
cfg_mgr = ConfigManager(cfg_inst)
await cfg_mgr.load_config(completion=completion)
@@ -84,9 +84,14 @@ async def load_json_config(config_name: str, template_name: str=None, template_d
return cfg_mgr
async def load_yaml_config(config_name: str, template_name: str=None, template_data: dict=None, completion: bool=True) -> ConfigManager:
async def load_yaml_config(
config_name: str,
template_name: str = None,
template_data: dict = None,
completion: bool = True,
) -> ConfigManager:
"""加载YAML配置文件
Args:
config_name (str): 配置文件名
template_name (str): 模板文件名
@@ -96,11 +101,7 @@ async def load_yaml_config(config_name: str, template_name: str=None, template_d
Returns:
ConfigManager: 配置文件管理器
"""
cfg_inst = yaml_file.YAMLConfigFile(
config_name,
template_name,
template_data
)
cfg_inst = yaml_file.YAMLConfigFile(config_name, template_name, template_data)
cfg_mgr = ConfigManager(cfg_inst)
await cfg_mgr.load_config(completion=completion)

View File

@@ -22,7 +22,7 @@ class ConfigFile(metaclass=abc.ABCMeta):
pass
@abc.abstractmethod
async def load(self, completion: bool=True) -> dict:
async def load(self, completion: bool = True) -> dict:
pass
@abc.abstractmethod

View File

@@ -2,9 +2,7 @@ from __future__ import annotations
import logging
import asyncio
import threading
import traceback
import enum
import sys
import os
@@ -29,7 +27,6 @@ from ..discover import engine as discover_engine
from ..utils import logcache, ip
from . import taskmgr
from . import entities as core_entities
from .bootutils import config
class Application:
@@ -123,33 +120,55 @@ class Application:
async def run(self):
try:
await self.plugin_mgr.initialize_plugins()
# 后续可能会允许动态重启其他任务
# 故为了防止程序在非 Ctrl-C 情况下退出,这里创建一个不会结束的协程
async def never_ending():
while True:
await asyncio.sleep(1)
self.task_mgr.create_task(self.platform_mgr.run(), name="platform-manager", scopes=[core_entities.LifecycleControlScope.APPLICATION, core_entities.LifecycleControlScope.PLATFORM])
self.task_mgr.create_task(self.ctrl.run(), name="query-controller", scopes=[core_entities.LifecycleControlScope.APPLICATION])
self.task_mgr.create_task(self.http_ctrl.run(), name="http-api-controller", scopes=[core_entities.LifecycleControlScope.APPLICATION])
self.task_mgr.create_task(never_ending(), name="never-ending-task", scopes=[core_entities.LifecycleControlScope.APPLICATION])
self.task_mgr.create_task(
self.platform_mgr.run(),
name='platform-manager',
scopes=[
core_entities.LifecycleControlScope.APPLICATION,
core_entities.LifecycleControlScope.PLATFORM,
],
)
self.task_mgr.create_task(
self.ctrl.run(),
name='query-controller',
scopes=[core_entities.LifecycleControlScope.APPLICATION],
)
self.task_mgr.create_task(
self.http_ctrl.run(),
name='http-api-controller',
scopes=[core_entities.LifecycleControlScope.APPLICATION],
)
self.task_mgr.create_task(
never_ending(),
name='never-ending-task',
scopes=[core_entities.LifecycleControlScope.APPLICATION],
)
await self.print_web_access_info()
await self.task_mgr.wait_all()
except asyncio.CancelledError:
pass
except Exception as e:
self.logger.error(f"应用运行致命异常: {e}")
self.logger.debug(f"Traceback: {traceback.format_exc()}")
self.logger.error(f'应用运行致命异常: {e}')
self.logger.debug(f'Traceback: {traceback.format_exc()}')
async def print_web_access_info(self):
"""打印访问 webui 的提示"""
if not os.path.exists(os.path.join(".", "web/out")):
self.logger.warning("WebUI 文件缺失请根据文档获取https://docs.langbot.app/webui/intro.html")
if not os.path.exists(os.path.join('.', 'web/out')):
self.logger.warning(
'WebUI 文件缺失请根据文档获取https://docs.langbot.app/webui/intro.html'
)
return
host_ip = "127.0.0.1"
host_ip = '127.0.0.1'
public_ip = await ip.get_myip()
@@ -170,7 +189,7 @@ class Application:
🤯 WebUI 仍处于 Beta 测试阶段,如有问题或建议请反馈到 https://github.com/RockChinQ/LangBot/issues
=======================================
""".strip()
for line in tips.split("\n"):
for line in tips.split('\n'):
self.logger.info(line)
async def reload(
@@ -179,21 +198,28 @@ class Application:
):
match scope:
case core_entities.LifecycleControlScope.PLATFORM.value:
self.logger.info("执行热重载 scope="+scope)
self.logger.info('执行热重载 scope=' + scope)
await self.platform_mgr.shutdown()
self.platform_mgr = im_mgr.PlatformManager(self)
await self.platform_mgr.initialize()
self.task_mgr.create_task(self.platform_mgr.run(), name="platform-manager", scopes=[core_entities.LifecycleControlScope.APPLICATION, core_entities.LifecycleControlScope.PLATFORM])
self.task_mgr.create_task(
self.platform_mgr.run(),
name='platform-manager',
scopes=[
core_entities.LifecycleControlScope.APPLICATION,
core_entities.LifecycleControlScope.PLATFORM,
],
)
case core_entities.LifecycleControlScope.PLUGIN.value:
self.logger.info("执行热重载 scope="+scope)
self.logger.info('执行热重载 scope=' + scope)
await self.plugin_mgr.destroy_plugins()
# 删除 sys.module 中所有的 plugins/* 下的模块
for mod in list(sys.modules.keys()):
if mod.startswith("plugins."):
if mod.startswith('plugins.'):
del sys.modules[mod]
self.plugin_mgr = plugin_mgr.PluginManager(self)
@@ -204,7 +230,7 @@ class Application:
await self.plugin_mgr.load_plugins()
await self.plugin_mgr.initialize_plugins()
case core_entities.LifecycleControlScope.PROVIDER.value:
self.logger.info("执行热重载 scope="+scope)
self.logger.info('执行热重载 scope=' + scope)
await self.tool_mgr.shutdown()
@@ -220,4 +246,4 @@ class Application:
await llm_tool_mgr_inst.initialize()
self.tool_mgr = llm_tool_mgr_inst
case _:
pass
pass

View File

@@ -7,29 +7,30 @@ import os
from . import app
from ..audit import identifier
from . import stage
from ..utils import constants
from ..utils import constants, importutil
# 引入启动阶段实现以便注册
from .stages import load_config, setup_logger, build_app, migrate, show_notes, genkeys
from . import stages
importutil.import_modules_in_pkg(stages)
stage_order = [
"LoadConfigStage",
"MigrationStage",
"GenKeysStage",
"SetupLoggerStage",
"BuildAppStage",
"ShowNotesStage"
'LoadConfigStage',
'MigrationStage',
'GenKeysStage',
'SetupLoggerStage',
'BuildAppStage',
'ShowNotesStage',
]
async def make_app(loop: asyncio.AbstractEventLoop) -> app.Application:
# 生成标识符
identifier.init()
# 确定是否为调试模式
if "DEBUG" in os.environ and os.environ["DEBUG"] in ["true", "1"]:
if 'DEBUG' in os.environ and os.environ['DEBUG'] in ['true', '1']:
constants.debug_mode = True
ap = app.Application()
@@ -50,21 +51,17 @@ async def make_app(loop: asyncio.AbstractEventLoop) -> app.Application:
async def main(loop: asyncio.AbstractEventLoop):
try:
# 挂系统信号处理
import signal
ap: app.Application
def signal_handler(sig, frame):
print("[Signal] 程序退出.")
print('[Signal] 程序退出.')
# ap.shutdown()
os._exit(0)
signal.signal(signal.SIGINT, signal_handler)
app_inst = await make_app(loop)
ap = app_inst
await app_inst.run()
except Exception as e:
except Exception:
traceback.print_exc()

View File

@@ -1,11 +1,9 @@
from __future__ import annotations
import json
from ...config import manager as config_mgr
from ...config.impls import pymodule
load_python_module_config = config_mgr.load_python_module_config
load_json_config = config_mgr.load_json_config
load_yaml_config = config_mgr.load_yaml_config
load_yaml_config = config_mgr.load_yaml_config

View File

@@ -5,39 +5,39 @@ from ...utils import pkgmgr
# 检查依赖,防止用户未安装
# 左边为引入名称,右边为依赖名称
required_deps = {
"requests": "requests",
"openai": "openai",
"anthropic": "anthropic",
"colorlog": "colorlog",
"aiocqhttp": "aiocqhttp",
"botpy": "qq-botpy-rc",
"PIL": "pillow",
"nakuru": "nakuru-project-idk",
"tiktoken": "tiktoken",
"yaml": "pyyaml",
"aiohttp": "aiohttp",
"psutil": "psutil",
"async_lru": "async-lru",
"ollama": "ollama",
"quart": "quart",
"quart_cors": "quart-cors",
"sqlalchemy": "sqlalchemy[asyncio]",
"aiosqlite": "aiosqlite",
"aiofiles": "aiofiles",
"aioshutil": "aioshutil",
"argon2": "argon2-cffi",
"jwt": "pyjwt",
"Crypto": "pycryptodome",
"lark_oapi": "lark-oapi",
"discord": "discord.py",
"cryptography": "cryptography",
"gewechat_client": "gewechat-client",
"dingtalk_stream": "dingtalk_stream",
"dashscope": "dashscope",
"telegram": "python-telegram-bot",
"certifi": "certifi",
"mcp": "mcp",
"sqlmodel": "sqlmodel",
'requests': 'requests',
'openai': 'openai',
'anthropic': 'anthropic',
'colorlog': 'colorlog',
'aiocqhttp': 'aiocqhttp',
'botpy': 'qq-botpy-rc',
'PIL': 'pillow',
'nakuru': 'nakuru-project-idk',
'tiktoken': 'tiktoken',
'yaml': 'pyyaml',
'aiohttp': 'aiohttp',
'psutil': 'psutil',
'async_lru': 'async-lru',
'ollama': 'ollama',
'quart': 'quart',
'quart_cors': 'quart-cors',
'sqlalchemy': 'sqlalchemy[asyncio]',
'aiosqlite': 'aiosqlite',
'aiofiles': 'aiofiles',
'aioshutil': 'aioshutil',
'argon2': 'argon2-cffi',
'jwt': 'pyjwt',
'Crypto': 'pycryptodome',
'lark_oapi': 'lark-oapi',
'discord': 'discord.py',
'cryptography': 'cryptography',
'gewechat_client': 'gewechat-client',
'dingtalk_stream': 'dingtalk_stream',
'dashscope': 'dashscope',
'telegram': 'python-telegram-bot',
'certifi': 'certifi',
'mcp': 'mcp',
'sqlmodel': 'sqlmodel',
}
@@ -52,20 +52,25 @@ async def check_deps() -> list[str]:
missing_deps.append(dep)
return missing_deps
async def install_deps(deps: list[str]):
global required_deps
for dep in deps:
pip.main(["install", required_deps[dep]])
pip.main(['install', required_deps[dep]])
async def precheck_plugin_deps():
print('[Startup] Prechecking plugin dependencies...')
# 只有在plugins目录存在时才执行插件依赖安装
if os.path.exists("plugins"):
for dir in os.listdir("plugins"):
subdir = os.path.join("plugins", dir)
if os.path.exists('plugins'):
for dir in os.listdir('plugins'):
subdir = os.path.join('plugins', dir)
if not os.path.isdir(subdir):
continue
if 'requirements.txt' in os.listdir(subdir):
pkgmgr.install_requirements(os.path.join(subdir, 'requirements.txt'), extra_params=['-q', '-q', '-q'])
pkgmgr.install_requirements(
os.path.join(subdir, 'requirements.txt'),
extra_params=['-q', '-q', '-q'],
)

View File

@@ -2,23 +2,23 @@ from __future__ import annotations
import os
import shutil
import sys
required_files = {
"plugins/__init__.py": "templates/__init__.py",
"data/config.yaml": "templates/config.yaml",
'plugins/__init__.py': 'templates/__init__.py',
'data/config.yaml': 'templates/config.yaml',
}
required_paths = [
"temp",
"data",
"data/metadata",
"data/logs",
"data/labels",
"plugins"
'temp',
'data',
'data/metadata',
'data/logs',
'data/labels',
'plugins',
]
async def generate_files() -> list[str]:
global required_files, required_paths

View File

@@ -1,5 +1,4 @@
import logging
import os
import sys
import time
@@ -9,11 +8,11 @@ from ...utils import constants
log_colors_config = {
"DEBUG": "green", # cyan white
"INFO": "white",
"WARNING": "yellow",
"ERROR": "red",
"CRITICAL": "cyan",
'DEBUG': 'green', # cyan white
'INFO': 'white',
'WARNING': 'yellow',
'ERROR': 'red',
'CRITICAL': 'cyan',
}
@@ -27,26 +26,31 @@ async def init_logging(extra_handlers: list[logging.Handler] = None) -> logging.
if constants.debug_mode:
level = logging.DEBUG
log_file_name = "data/logs/langbot-%s.log" % time.strftime(
"%Y-%m-%d", time.localtime()
log_file_name = 'data/logs/langbot-%s.log' % time.strftime(
'%Y-%m-%d', time.localtime()
)
qcg_logger = logging.getLogger("langbot")
qcg_logger = logging.getLogger('langbot')
qcg_logger.setLevel(level)
color_formatter = colorlog.ColoredFormatter(
fmt="%(log_color)s[%(asctime)s.%(msecs)03d] %(filename)s (%(lineno)d) - [%(levelname)s] : %(message)s",
datefmt="%m-%d %H:%M:%S",
fmt='%(log_color)s[%(asctime)s.%(msecs)03d] %(filename)s (%(lineno)d) - [%(levelname)s] : %(message)s',
datefmt='%m-%d %H:%M:%S',
log_colors=log_colors_config,
)
stream_handler = logging.StreamHandler(sys.stdout)
# stream_handler.setLevel(level)
# stream_handler.setFormatter(color_formatter)
stream_handler.stream = open(sys.stdout.fileno(), mode='w', encoding='utf-8', buffering=1)
stream_handler.stream = open(
sys.stdout.fileno(), mode='w', encoding='utf-8', buffering=1
)
log_handlers: list[logging.Handler] = [stream_handler, logging.FileHandler(log_file_name, encoding='utf-8')]
log_handlers: list[logging.Handler] = [
stream_handler,
logging.FileHandler(log_file_name, encoding='utf-8'),
]
log_handlers += extra_handlers if extra_handlers is not None else []
for handler in log_handlers:
@@ -54,13 +58,13 @@ async def init_logging(extra_handlers: list[logging.Handler] = None) -> logging.
handler.setFormatter(color_formatter)
qcg_logger.addHandler(handler)
qcg_logger.debug("日志初始化完成,日志级别:%s" % level)
qcg_logger.debug('日志初始化完成,日志级别:%s' % level)
logging.basicConfig(
level=logging.CRITICAL, # 设置日志输出格式
format="[DEPR][%(asctime)s.%(msecs)03d] %(pathname)s (%(lineno)d) - [%(levelname)s] :\n%(message)s",
format='[DEPR][%(asctime)s.%(msecs)03d] %(pathname)s (%(lineno)d) - [%(levelname)s] :\n%(message)s',
# 日志输出的格式
# -8表示占位符让输出左对齐输出长度都为8位
datefmt="%Y-%m-%d %H:%M:%S", # 时间输出的格式
datefmt='%Y-%m-%d %H:%M:%S', # 时间输出的格式
handlers=[logging.NullHandler()],
)

View File

@@ -8,21 +8,18 @@ import asyncio
import pydantic.v1 as pydantic
from ..provider import entities as llm_entities
from ..provider.modelmgr import entities, modelmgr, requester
from ..provider.modelmgr import requester
from ..provider.tools import entities as tools_entities
from ..platform import adapter as msadapter
from ..platform.types import message as platform_message
from ..platform.types import events as platform_events
from ..platform.types import entities as platform_entities
class LifecycleControlScope(enum.Enum):
APPLICATION = "application"
PLATFORM = "platform"
PLUGIN = "plugin"
PROVIDER = "provider"
APPLICATION = 'application'
PLATFORM = 'platform'
PLUGIN = 'plugin'
PROVIDER = 'provider'
class LauncherTypes(enum.Enum):
@@ -89,14 +86,17 @@ class Query(pydantic.BaseModel):
use_funcs: typing.Optional[list[tools_entities.LLMFunction]] = None
"""使用的函数,由前置处理器阶段设置"""
resp_messages: typing.Optional[list[llm_entities.Message]] | typing.Optional[list[platform_message.MessageChain]] = []
resp_messages: (
typing.Optional[list[llm_entities.Message]]
| typing.Optional[list[platform_message.MessageChain]]
) = []
"""由Process阶段生成的回复消息对象列表"""
resp_message_chain: typing.Optional[list[platform_message.MessageChain]] = None
"""回复消息链从resp_messages包装而得"""
# ======= 内部保留 =======
current_stage: "pkg.pipeline.pipelinemgr.StageInstContainer" = None
current_stage = None # pkg.pipeline.pipelinemgr.StageInstContainer
"""当前所处阶段"""
class Config:
@@ -109,13 +109,13 @@ class Query(pydantic.BaseModel):
if self.variables is None:
self.variables = {}
self.variables[key] = value
def get_variable(self, key: str) -> typing.Any:
"""获取变量"""
if self.variables is None:
return None
return self.variables.get(key)
def get_variables(self) -> dict[str, typing.Any]:
"""获取所有变量"""
if self.variables is None:
@@ -130,9 +130,13 @@ class Conversation(pydantic.BaseModel):
messages: list[llm_entities.Message]
create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
create_time: typing.Optional[datetime.datetime] = pydantic.Field(
default_factory=datetime.datetime.now
)
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
update_time: typing.Optional[datetime.datetime] = pydantic.Field(
default_factory=datetime.datetime.now
)
use_llm_model: requester.RuntimeLLMModel
@@ -147,6 +151,7 @@ class Conversation(pydantic.BaseModel):
class Session(pydantic.BaseModel):
"""会话,一个 Session 对应一个 {launcher_type.value}_{launcher_id}"""
launcher_type: LauncherTypes
launcher_id: typing.Union[int, str]
@@ -157,11 +162,17 @@ class Session(pydantic.BaseModel):
using_conversation: typing.Optional[Conversation] = None
conversations: typing.Optional[list[Conversation]] = pydantic.Field(default_factory=list)
conversations: typing.Optional[list[Conversation]] = pydantic.Field(
default_factory=list
)
create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
create_time: typing.Optional[datetime.datetime] = pydantic.Field(
default_factory=datetime.datetime.now
)
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
update_time: typing.Optional[datetime.datetime] = pydantic.Field(
default_factory=datetime.datetime.now
)
semaphore: typing.Optional[asyncio.Semaphore] = None
"""当前会话的信号量,用于限制并发"""

View File

@@ -9,21 +9,21 @@ from . import app
preregistered_migrations: list[typing.Type[Migration]] = []
"""当前阶段暂不支持扩展"""
def migration_class(name: str, number: int):
"""注册一个迁移
"""
"""注册一个迁移"""
def decorator(cls: typing.Type[Migration]) -> typing.Type[Migration]:
cls.name = name
cls.number = number
preregistered_migrations.append(cls)
return cls
return decorator
class Migration(abc.ABC):
"""一个版本的迁移
"""
"""一个版本的迁移"""
name: str
@@ -33,15 +33,13 @@ class Migration(abc.ABC):
def __init__(self, ap: app.Application):
self.ap = ap
@abc.abstractmethod
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移
"""
"""判断当前环境是否需要运行此迁移"""
pass
@abc.abstractmethod
async def run(self):
"""执行迁移
"""
"""执行迁移"""
pass

View File

@@ -1,26 +1,26 @@
from __future__ import annotations
import os
import sys
from .. import migration
@migration.migration_class("sensitive-word-migration", 1)
@migration.migration_class('sensitive-word-migration', 1)
class SensitiveWordMigration(migration.Migration):
"""敏感词迁移
"""
"""敏感词迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移
"""
return os.path.exists("data/config/sensitive-words.json") and not os.path.exists("data/metadata/sensitive-words.json")
"""判断当前环境是否需要运行此迁移"""
return os.path.exists(
'data/config/sensitive-words.json'
) and not os.path.exists('data/metadata/sensitive-words.json')
async def run(self):
"""执行迁移
"""
"""执行迁移"""
# 移动文件
os.rename("data/config/sensitive-words.json", "data/metadata/sensitive-words.json")
os.rename(
'data/config/sensitive-words.json', 'data/metadata/sensitive-words.json'
)
# 重新加载配置
await self.ap.sensitive_meta.load_config()

View File

@@ -3,19 +3,16 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("openai-config-migration", 2)
@migration.migration_class('openai-config-migration', 2)
class OpenAIConfigMigration(migration.Migration):
"""OpenAI配置迁移
"""
"""OpenAI配置迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移
"""
"""判断当前环境是否需要运行此迁移"""
return 'openai-config' in self.ap.provider_cfg.data
async def run(self):
"""执行迁移
"""
"""执行迁移"""
old_openai_config = self.ap.provider_cfg.data['openai-config'].copy()
if 'keys' not in self.ap.provider_cfg.data:
@@ -26,7 +23,9 @@ class OpenAIConfigMigration(migration.Migration):
self.ap.provider_cfg.data['keys']['openai'] = old_openai_config['api-keys']
self.ap.provider_cfg.data['model'] = old_openai_config['chat-completions-params']['model']
self.ap.provider_cfg.data['model'] = old_openai_config[
'chat-completions-params'
]['model']
del old_openai_config['chat-completions-params']['model']
@@ -35,7 +34,7 @@ class OpenAIConfigMigration(migration.Migration):
if 'openai-chat-completions' not in self.ap.provider_cfg.data['requester']:
self.ap.provider_cfg.data['requester']['openai-chat-completions'] = {}
self.ap.provider_cfg.data['requester']['openai-chat-completions'] = {
'base-url': old_openai_config['base_url'],
'args': old_openai_config['chat-completions-params'],
@@ -44,4 +43,4 @@ class OpenAIConfigMigration(migration.Migration):
del self.ap.provider_cfg.data['openai-config']
await self.ap.provider_cfg.dump_config()
await self.ap.provider_cfg.dump_config()

View File

@@ -3,26 +3,23 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("anthropic-requester-config-completion", 3)
@migration.migration_class('anthropic-requester-config-completion', 3)
class AnthropicRequesterConfigCompletionMigration(migration.Migration):
"""OpenAI配置迁移
"""
"""OpenAI配置迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移
"""
return 'anthropic-messages' not in self.ap.provider_cfg.data['requester'] \
"""判断当前环境是否需要运行此迁移"""
return (
'anthropic-messages' not in self.ap.provider_cfg.data['requester']
or 'anthropic' not in self.ap.provider_cfg.data['keys']
)
async def run(self):
"""执行迁移
"""
"""执行迁移"""
if 'anthropic-messages' not in self.ap.provider_cfg.data['requester']:
self.ap.provider_cfg.data['requester']['anthropic-messages'] = {
'base-url': 'https://api.anthropic.com',
'args': {
'max_tokens': 1024
},
'args': {'max_tokens': 1024},
'timeout': 120,
}

View File

@@ -3,20 +3,19 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("moonshot-config-completion", 4)
@migration.migration_class('moonshot-config-completion', 4)
class MoonshotConfigCompletionMigration(migration.Migration):
"""OpenAI配置迁移
"""
"""OpenAI配置迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移
"""
return 'moonshot-chat-completions' not in self.ap.provider_cfg.data['requester'] \
"""判断当前环境是否需要运行此迁移"""
return (
'moonshot-chat-completions' not in self.ap.provider_cfg.data['requester']
or 'moonshot' not in self.ap.provider_cfg.data['keys']
)
async def run(self):
"""执行迁移
"""
"""执行迁移"""
if 'moonshot-chat-completions' not in self.ap.provider_cfg.data['requester']:
self.ap.provider_cfg.data['requester']['moonshot-chat-completions'] = {
'base-url': 'https://api.moonshot.cn/v1',

View File

@@ -3,20 +3,19 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("deepseek-config-completion", 5)
@migration.migration_class('deepseek-config-completion', 5)
class DeepseekConfigCompletionMigration(migration.Migration):
"""OpenAI配置迁移
"""
"""OpenAI配置迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移
"""
return 'deepseek-chat-completions' not in self.ap.provider_cfg.data['requester'] \
"""判断当前环境是否需要运行此迁移"""
return (
'deepseek-chat-completions' not in self.ap.provider_cfg.data['requester']
or 'deepseek' not in self.ap.provider_cfg.data['keys']
)
async def run(self):
"""执行迁移
"""
"""执行迁移"""
if 'deepseek-chat-completions' not in self.ap.provider_cfg.data['requester']:
self.ap.provider_cfg.data['requester']['deepseek-chat-completions'] = {
'base-url': 'https://api.deepseek.com',
@@ -27,4 +26,4 @@ class DeepseekConfigCompletionMigration(migration.Migration):
if 'deepseek' not in self.ap.provider_cfg.data['keys']:
self.ap.provider_cfg.data['keys']['deepseek'] = []
await self.ap.provider_cfg.dump_config()
await self.ap.provider_cfg.dump_config()

View File

@@ -3,17 +3,17 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("vision-config", 6)
@migration.migration_class('vision-config', 6)
class VisionConfigMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return "enable-vision" not in self.ap.provider_cfg.data
return 'enable-vision' not in self.ap.provider_cfg.data
async def run(self):
"""执行迁移"""
if "enable-vision" not in self.ap.provider_cfg.data:
self.ap.provider_cfg.data["enable-vision"] = False
if 'enable-vision' not in self.ap.provider_cfg.data:
self.ap.provider_cfg.data['enable-vision'] = False
await self.ap.provider_cfg.dump_config()

View File

@@ -3,18 +3,20 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("qcg-center-url-config", 7)
@migration.migration_class('qcg-center-url-config', 7)
class QCGCenterURLConfigMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return "qcg-center-url" not in self.ap.system_cfg.data
return 'qcg-center-url' not in self.ap.system_cfg.data
async def run(self):
"""执行迁移"""
if "qcg-center-url" not in self.ap.system_cfg.data:
self.ap.system_cfg.data["qcg-center-url"] = "https://api.qchatgpt.rockchin.top/api/v2"
if 'qcg-center-url' not in self.ap.system_cfg.data:
self.ap.system_cfg.data['qcg-center-url'] = (
'https://api.qchatgpt.rockchin.top/api/v2'
)
await self.ap.system_cfg.dump_config()

View File

@@ -3,27 +3,27 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("ad-fixwin-cfg-migration", 8)
@migration.migration_class('ad-fixwin-cfg-migration', 8)
class AdFixwinConfigMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return isinstance(
self.ap.pipeline_cfg.data["rate-limit"]["fixwin"]["default"],
int
self.ap.pipeline_cfg.data['rate-limit']['fixwin']['default'], int
)
async def run(self):
"""执行迁移"""
for session_name in self.ap.pipeline_cfg.data["rate-limit"]["fixwin"]:
for session_name in self.ap.pipeline_cfg.data['rate-limit']['fixwin']:
temp_dict = {
"window-size": 60,
"limit": self.ap.pipeline_cfg.data["rate-limit"]["fixwin"][session_name]
'window-size': 60,
'limit': self.ap.pipeline_cfg.data['rate-limit']['fixwin'][
session_name
],
}
self.ap.pipeline_cfg.data["rate-limit"]["fixwin"][session_name] = temp_dict
await self.ap.pipeline_cfg.dump_config()
self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name] = temp_dict
await self.ap.pipeline_cfg.dump_config()

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("msg-truncator-cfg-migration", 9)
@migration.migration_class('msg-truncator-cfg-migration', 9)
class MsgTruncatorConfigMigration(migration.Migration):
"""迁移"""
@@ -13,12 +13,10 @@ class MsgTruncatorConfigMigration(migration.Migration):
async def run(self):
"""执行迁移"""
self.ap.pipeline_cfg.data['msg-truncate'] = {
'method': 'round',
'round': {
'max-round': 10
}
'round': {'max-round': 10},
}
await self.ap.pipeline_cfg.dump_config()

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("ollama-requester-config", 10)
@migration.migration_class('ollama-requester-config', 10)
class MsgTruncatorConfigMigration(migration.Migration):
"""迁移"""
@@ -13,11 +13,11 @@ class MsgTruncatorConfigMigration(migration.Migration):
async def run(self):
"""执行迁移"""
self.ap.provider_cfg.data['requester']['ollama-chat'] = {
"base-url": "http://127.0.0.1:11434",
"args": {},
"timeout": 600
'base-url': 'http://127.0.0.1:11434',
'args': {},
'timeout': 600,
}
await self.ap.provider_cfg.dump_config()

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("command-prefix-config", 11)
@migration.migration_class('command-prefix-config', 11)
class CommandPrefixConfigMigration(migration.Migration):
"""迁移"""
@@ -13,9 +13,7 @@ class CommandPrefixConfigMigration(migration.Migration):
async def run(self):
"""执行迁移"""
self.ap.command_cfg.data['command-prefix'] = [
"!", ""
]
self.ap.command_cfg.data['command-prefix'] = ['!', '']
await self.ap.command_cfg.dump_config()

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("runner-config", 12)
@migration.migration_class('runner-config', 12)
class RunnerConfigMigration(migration.Migration):
"""迁移"""
@@ -13,7 +13,7 @@ class RunnerConfigMigration(migration.Migration):
async def run(self):
"""执行迁移"""
self.ap.provider_cfg.data['runner'] = 'local-agent'
await self.ap.provider_cfg.dump_config()

View File

@@ -3,29 +3,30 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("http-api-config", 13)
@migration.migration_class('http-api-config', 13)
class HttpApiConfigMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return 'http-api' not in self.ap.system_cfg.data or "persistence" not in self.ap.system_cfg.data
return (
'http-api' not in self.ap.system_cfg.data
or 'persistence' not in self.ap.system_cfg.data
)
async def run(self):
"""执行迁移"""
self.ap.system_cfg.data['http-api'] = {
"enable": True,
"host": "0.0.0.0",
"port": 5300,
"jwt-expire": 604800
'enable': True,
'host': '0.0.0.0',
'port': 5300,
'jwt-expire': 604800,
}
self.ap.system_cfg.data['persistence'] = {
"sqlite": {
"path": "data/persistence.db"
},
"use": "sqlite"
'sqlite': {'path': 'data/persistence.db'},
'use': 'sqlite',
}
await self.ap.system_cfg.dump_config()

View File

@@ -3,20 +3,20 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("force-delay-config", 14)
@migration.migration_class('force-delay-config', 14)
class ForceDelayConfigMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return type(self.ap.platform_cfg.data['force-delay']) == list
return isinstance(self.ap.platform_cfg.data['force-delay'], list)
async def run(self):
"""执行迁移"""
self.ap.platform_cfg.data['force-delay'] = {
"min": self.ap.platform_cfg.data['force-delay'][0],
"max": self.ap.platform_cfg.data['force-delay'][1]
'min': self.ap.platform_cfg.data['force-delay'][0],
'max': self.ap.platform_cfg.data['force-delay'][1],
}
await self.ap.platform_cfg.dump_config()

View File

@@ -3,24 +3,25 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("gitee-ai-config", 15)
@migration.migration_class('gitee-ai-config', 15)
class GiteeAIConfigMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return 'gitee-ai-chat-completions' not in self.ap.provider_cfg.data['requester'] or 'gitee-ai' not in self.ap.provider_cfg.data['keys']
return (
'gitee-ai-chat-completions' not in self.ap.provider_cfg.data['requester']
or 'gitee-ai' not in self.ap.provider_cfg.data['keys']
)
async def run(self):
"""执行迁移"""
self.ap.provider_cfg.data['requester']['gitee-ai-chat-completions'] = {
"base-url": "https://ai.gitee.com/v1",
"args": {},
"timeout": 120
'base-url': 'https://ai.gitee.com/v1',
'args': {},
'timeout': 120,
}
self.ap.provider_cfg.data['keys']['gitee-ai'] = [
"XXXXX"
]
self.ap.provider_cfg.data['keys']['gitee-ai'] = ['XXXXX']
await self.ap.provider_cfg.dump_config()

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("dify-service-api-config", 16)
@migration.migration_class('dify-service-api-config', 16)
class DifyServiceAPICfgMigration(migration.Migration):
"""迁移"""
@@ -14,15 +14,10 @@ class DifyServiceAPICfgMigration(migration.Migration):
async def run(self):
"""执行迁移"""
self.ap.provider_cfg.data['dify-service-api'] = {
"base-url": "https://api.dify.ai/v1",
"app-type": "chat",
"chat": {
"api-key": "app-1234567890"
},
"workflow": {
"api-key": "app-1234567890",
"output-key": "summary"
}
'base-url': 'https://api.dify.ai/v1',
'app-type': 'chat',
'chat': {'api-key': 'app-1234567890'},
'workflow': {'api-key': 'app-1234567890', 'output-key': 'summary'},
}
await self.ap.provider_cfg.dump_config()

View File

@@ -3,22 +3,26 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("dify-api-timeout-params", 17)
@migration.migration_class('dify-api-timeout-params', 17)
class DifyAPITimeoutParamsMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return 'timeout' not in self.ap.provider_cfg.data['dify-service-api']['chat'] or 'timeout' not in self.ap.provider_cfg.data['dify-service-api']['workflow'] \
return (
'timeout' not in self.ap.provider_cfg.data['dify-service-api']['chat']
or 'timeout'
not in self.ap.provider_cfg.data['dify-service-api']['workflow']
or 'agent' not in self.ap.provider_cfg.data['dify-service-api']
)
async def run(self):
"""执行迁移"""
self.ap.provider_cfg.data['dify-service-api']['chat']['timeout'] = 120
self.ap.provider_cfg.data['dify-service-api']['workflow']['timeout'] = 120
self.ap.provider_cfg.data['dify-service-api']['agent'] = {
"api-key": "app-1234567890",
"timeout": 120
'api-key': 'app-1234567890',
'timeout': 120,
}
await self.ap.provider_cfg.dump_config()

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("xai-config", 18)
@migration.migration_class('xai-config', 18)
class XaiConfigMigration(migration.Migration):
"""迁移"""
@@ -14,12 +14,10 @@ class XaiConfigMigration(migration.Migration):
async def run(self):
"""执行迁移"""
self.ap.provider_cfg.data['requester']['xai-chat-completions'] = {
"base-url": "https://api.x.ai/v1",
"args": {},
"timeout": 120
'base-url': 'https://api.x.ai/v1',
'args': {},
'timeout': 120,
}
self.ap.provider_cfg.data['keys']['xai'] = [
"xai-1234567890"
]
self.ap.provider_cfg.data['keys']['xai'] = ['xai-1234567890']
await self.ap.provider_cfg.dump_config()

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("zhipuai-config", 19)
@migration.migration_class('zhipuai-config', 19)
class ZhipuaiConfigMigration(migration.Migration):
"""迁移"""
@@ -14,12 +14,10 @@ class ZhipuaiConfigMigration(migration.Migration):
async def run(self):
"""执行迁移"""
self.ap.provider_cfg.data['requester']['zhipuai-chat-completions'] = {
"base-url": "https://open.bigmodel.cn/api/paas/v4",
"args": {},
"timeout": 120
'base-url': 'https://open.bigmodel.cn/api/paas/v4',
'args': {},
'timeout': 120,
}
self.ap.provider_cfg.data['keys']['zhipuai'] = [
"xxxxxxx"
]
self.ap.provider_cfg.data['keys']['zhipuai'] = ['xxxxxxx']
await self.ap.provider_cfg.dump_config()

View File

@@ -3,13 +3,13 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("wecom-config", 20)
@migration.migration_class('wecom-config', 20)
class WecomConfigMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
# for adapter in self.ap.platform_cfg.data['platform-adapters']:
# if adapter['adapter'] == 'wecom':
# return False
@@ -19,16 +19,18 @@ class WecomConfigMigration(migration.Migration):
async def run(self):
"""执行迁移"""
self.ap.platform_cfg.data['platform-adapters'].append({
"adapter": "wecom",
"enable": False,
"host": "0.0.0.0",
"port": 2290,
"corpid": "",
"secret": "",
"token": "",
"EncodingAESKey": "",
"contacts_secret": ""
})
self.ap.platform_cfg.data['platform-adapters'].append(
{
'adapter': 'wecom',
'enable': False,
'host': '0.0.0.0',
'port': 2290,
'corpid': '',
'secret': '',
'token': '',
'EncodingAESKey': '',
'contacts_secret': '',
}
)
await self.ap.platform_cfg.dump_config()

View File

@@ -3,13 +3,13 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("lark-config", 21)
@migration.migration_class('lark-config', 21)
class LarkConfigMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
# for adapter in self.ap.platform_cfg.data['platform-adapters']:
# if adapter['adapter'] == 'lark':
# return False
@@ -19,15 +19,17 @@ class LarkConfigMigration(migration.Migration):
async def run(self):
"""执行迁移"""
self.ap.platform_cfg.data['platform-adapters'].append({
"adapter": "lark",
"enable": False,
"app_id": "cli_abcdefgh",
"app_secret": "XXXXXXXXXX",
"bot_name": "LangBot",
"enable-webhook": False,
"port": 2285,
"encrypt-key": "xxxxxxxxx"
})
self.ap.platform_cfg.data['platform-adapters'].append(
{
'adapter': 'lark',
'enable': False,
'app_id': 'cli_abcdefgh',
'app_secret': 'XXXXXXXXXX',
'bot_name': 'LangBot',
'enable-webhook': False,
'port': 2285,
'encrypt-key': 'xxxxxxxxx',
}
)
await self.ap.platform_cfg.dump_config()

View File

@@ -3,21 +3,21 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("lmstudio-config", 22)
@migration.migration_class('lmstudio-config', 22)
class LmStudioConfigMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return 'lmstudio-chat-completions' not in self.ap.provider_cfg.data['requester']
async def run(self):
"""执行迁移"""
self.ap.provider_cfg.data['requester']['lmstudio-chat-completions'] = {
"base-url": "http://127.0.0.1:1234/v1",
"args": {},
"timeout": 120
'base-url': 'http://127.0.0.1:1234/v1',
'args': {},
'timeout': 120,
}
await self.ap.provider_cfg.dump_config()

View File

@@ -3,25 +3,25 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("siliconflow-config", 23)
@migration.migration_class('siliconflow-config', 23)
class SiliconFlowConfigMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return 'siliconflow-chat-completions' not in self.ap.provider_cfg.data['requester']
return (
'siliconflow-chat-completions' not in self.ap.provider_cfg.data['requester']
)
async def run(self):
"""执行迁移"""
self.ap.provider_cfg.data['keys']['siliconflow'] = [
"xxxxxxx"
]
self.ap.provider_cfg.data['keys']['siliconflow'] = ['xxxxxxx']
self.ap.provider_cfg.data['requester']['siliconflow-chat-completions'] = {
"base-url": "https://api.siliconflow.cn/v1",
"args": {},
"timeout": 120
'base-url': 'https://api.siliconflow.cn/v1',
'args': {},
'timeout': 120,
}
await self.ap.provider_cfg.dump_config()

View File

@@ -3,13 +3,13 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("discord-config", 24)
@migration.migration_class('discord-config', 24)
class DiscordConfigMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
# for adapter in self.ap.platform_cfg.data['platform-adapters']:
# if adapter['adapter'] == 'discord':
# return False
@@ -19,11 +19,13 @@ class DiscordConfigMigration(migration.Migration):
async def run(self):
"""执行迁移"""
self.ap.platform_cfg.data['platform-adapters'].append({
"adapter": "discord",
"enable": False,
"client_id": "1234567890",
"token": "XXXXXXXXXX"
})
self.ap.platform_cfg.data['platform-adapters'].append(
{
'adapter': 'discord',
'enable': False,
'client_id': '1234567890',
'token': 'XXXXXXXXXX',
}
)
await self.ap.platform_cfg.dump_config()

View File

@@ -3,13 +3,13 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("gewechat-config", 25)
@migration.migration_class('gewechat-config', 25)
class GewechatConfigMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
# for adapter in self.ap.platform_cfg.data['platform-adapters']:
# if adapter['adapter'] == 'gewechat':
# return False
@@ -19,15 +19,17 @@ class GewechatConfigMigration(migration.Migration):
async def run(self):
"""执行迁移"""
self.ap.platform_cfg.data['platform-adapters'].append({
"adapter": "gewechat",
"enable": False,
"gewechat_url": "http://your-gewechat-server:2531",
"gewechat_file_url": "http://your-gewechat-server:2532",
"port": 2286,
"callback_url": "http://your-callback-url:2286/gewechat/callback",
"app_id": "",
"token": ""
})
self.ap.platform_cfg.data['platform-adapters'].append(
{
'adapter': 'gewechat',
'enable': False,
'gewechat_url': 'http://your-gewechat-server:2531',
'gewechat_file_url': 'http://your-gewechat-server:2532',
'port': 2286,
'callback_url': 'http://your-callback-url:2286/gewechat/callback',
'app_id': '',
'token': '',
}
)
await self.ap.platform_cfg.dump_config()

View File

@@ -3,13 +3,13 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("qqofficial-config", 26)
@migration.migration_class('qqofficial-config', 26)
class QQOfficialConfigMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
# for adapter in self.ap.platform_cfg.data['platform-adapters']:
# if adapter['adapter'] == 'qqofficial':
# return False
@@ -19,13 +19,15 @@ class QQOfficialConfigMigration(migration.Migration):
async def run(self):
"""执行迁移"""
self.ap.platform_cfg.data['platform-adapters'].append({
"adapter": "qqofficial",
"enable": False,
"appid": "",
"secret": "",
"port": 2284,
"token": ""
})
self.ap.platform_cfg.data['platform-adapters'].append(
{
'adapter': 'qqofficial',
'enable': False,
'appid': '',
'secret': '',
'port': 2284,
'token': '',
}
)
await self.ap.platform_cfg.dump_config()

View File

@@ -3,13 +3,13 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("wx-official-account-config", 27)
@migration.migration_class('wx-official-account-config', 27)
class WXOfficialAccountConfigMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
# for adapter in self.ap.platform_cfg.data['platform-adapters']:
# if adapter['adapter'] == 'officialaccount':
# return False
@@ -19,15 +19,17 @@ class WXOfficialAccountConfigMigration(migration.Migration):
async def run(self):
"""执行迁移"""
self.ap.platform_cfg.data['platform-adapters'].append({
"adapter": "officialaccount",
"enable": False,
"token": "",
"EncodingAESKey": "",
"AppID": "",
"AppSecret": "",
"host": "0.0.0.0",
"port": 2287
})
self.ap.platform_cfg.data['platform-adapters'].append(
{
'adapter': 'officialaccount',
'enable': False,
'token': '',
'EncodingAESKey': '',
'AppID': '',
'AppSecret': '',
'host': '0.0.0.0',
'port': 2287,
}
)
await self.ap.platform_cfg.dump_config()

Some files were not shown because too many files have changed in this diff Show More