perf: ruff check --fix

Author: Junyan Qin
Date:   2025-07-05 21:56:54 +08:00
parent 39c062f73e
commit 8d28ace252
23 changed files with 647 additions and 737 deletions
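
Every hunk below falls into a handful of mechanical rewrites: unused imports are removed (ruff rule F401), `__init__.py` re-exports gain a redundant `X as X` alias so F401 stops flagging them, double-quoted strings become single-quoted, dict and call layouts are rewrapped to the configured line length, and unused `except ... as e` bindings become bare `except Exception:`. A toy before/after sketch of the same rewrites (illustrative only, not taken from the repo; the "before" form is kept in comments so the snippet stays valid Python):

# --- before ---
# import asyncio                # unused -> removed by the F401 auto-fix
# def build(name):
#     return {
#         "Name": name,         # double quotes, dict spread over several lines
#     }

# --- after `ruff check --fix` plus the formatting pass ---
def build(name):
    return {'Name': name}       # single quotes, collapsed onto one line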

View File

@@ -1 +1 @@
from .client import WeChatPadClient
from .client import WeChatPadClient as WeChatPadClient
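
The `as WeChatPadClient` alias looks redundant, but it is the convention ruff's F401 fix applies inside `__init__.py`: importing a name and re-binding it to itself marks the import as an intentional public re-export rather than an unused import. Declaring `__all__` states the same intent; a sketch of that alternative (not what the commit does):

# Equivalent way to mark the re-export in libs/wechatpad_api/__init__.py:
from .client import WeChatPadClient

__all__ = ['WeChatPadClient']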

View File

@@ -1,4 +1,4 @@
from libs.wechatpad_api.util.http_util import async_request, post_json
from libs.wechatpad_api.util.http_util import post_json
class ChatRoomApi:
@@ -7,8 +7,6 @@ class ChatRoomApi:
self.token = token
def get_chatroom_member_detail(self, chatroom_name):
params = {
"ChatRoomName": chatroom_name
}
params = {'ChatRoomName': chatroom_name}
url = self.base_url + '/group/GetChatroomMemberDetail'
return post_json(url, token=self.token, data=params)

View File

@@ -1,32 +1,23 @@
from libs.wechatpad_api.util.http_util import async_request, post_json
from libs.wechatpad_api.util.http_util import post_json
import httpx
import base64
class DownloadApi:
def __init__(self, base_url, token):
self.base_url = base_url
self.token = token
def send_download(self, aeskey, file_type, file_url):
json_data = {
"AesKey": aeskey,
"FileType": file_type,
"FileURL": file_url
}
url = self.base_url + "/message/SendCdnDownload"
json_data = {'AesKey': aeskey, 'FileType': file_type, 'FileURL': file_url}
url = self.base_url + '/message/SendCdnDownload'
return post_json(url, token=self.token, data=json_data)
def get_msg_voice(self, buf_id, length, new_msgid):
json_data = {
"Bufid": buf_id,
"Length": length,
"NewMsgId": new_msgid,
"ToUserName": ""
}
url = self.base_url + "/message/GetMsgVoice"
json_data = {'Bufid': buf_id, 'Length': length, 'NewMsgId': new_msgid, 'ToUserName': ''}
url = self.base_url + '/message/GetMsgVoice'
return post_json(url, token=self.token, data=json_data)
async def download_url_to_base64(self, download_url):
async with httpx.AsyncClient() as client:
response = await client.get(download_url)
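
The hunk is cut off inside `download_url_to_base64`. Given the `httpx` and `base64` imports at the top of this file, a typical completion might look like the following; this is an assumption, not part of the commit:

import base64
import httpx

# Hypothetical body for download_url_to_base64 (the diff truncates it here):
async def download_url_to_base64(download_url):
    async with httpx.AsyncClient() as client:
        response = await client.get(download_url)
        response.raise_for_status()                           # surface HTTP errors early
        return base64.b64encode(response.content).decode()    # raw bytes -> base64 text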

View File

@@ -1,11 +1,6 @@
from libs.wechatpad_api.util.http_util import post_json,async_request
from typing import List, Dict, Any, Optional
class FriendApi:
"""联系人API类处理所有与联系人相关的操作"""
def __init__(self, base_url: str, token: str):
self.base_url = base_url
self.token = token

View File

@@ -1,37 +1,34 @@
from libs.wechatpad_api.util.http_util import async_request,post_json,get_json
from libs.wechatpad_api.util.http_util import post_json, get_json
class LoginApi:
def __init__(self, base_url: str, token: str = None, admin_key: str = None):
'''
"""
Args:
base_url: 原始路径
token: token
admin_key: 管理员key
'''
"""
self.base_url = base_url
self.token = token
# self.admin_key = admin_key
def get_token(self, admin_key, day: int = 365):
# 获取普通token
url = f"{self.base_url}/admin/GenAuthKey1"
json_data = {
"Count": 1,
"Days": day
}
url = f'{self.base_url}/admin/GenAuthKey1'
json_data = {'Count': 1, 'Days': day}
return post_json(base_url=url, token=admin_key, data=json_data)
def get_login_qr(self, Proxy: str = ""):
'''
def get_login_qr(self, Proxy: str = ''):
"""
Args:
Proxy:异地使用时代理
Returns:json数据
'''
"""
"""
{
@@ -50,53 +47,36 @@ class LoginApi:
"""
# 获取登录二维码
url = f"{self.base_url}/login/GetLoginQrCodeNew"
url = f'{self.base_url}/login/GetLoginQrCodeNew'
check = False
if Proxy != "":
if Proxy != '':
check = True
json_data = {
"Check": check,
"Proxy": Proxy
}
json_data = {'Check': check, 'Proxy': Proxy}
return post_json(base_url=url, token=self.token, data=json_data)
def get_login_status(self):
# 获取登录状态
url = f'{self.base_url}/login/GetLoginStatus'
return get_json(base_url=url, token=self.token)
def logout(self):
# 退出登录
url = f'{self.base_url}/login/LogOut'
return post_json(base_url=url, token=self.token)
def wake_up_login(self, Proxy: str = ""):
def wake_up_login(self, Proxy: str = ''):
# 唤醒登录
url = f'{self.base_url}/login/WakeUpLogin'
check = False
if Proxy != "":
if Proxy != '':
check = True
json_data = {
"Check": check,
"Proxy": ""
}
json_data = {'Check': check, 'Proxy': ''}
return post_json(base_url=url, token=self.token, data=json_data)
def login(self, admin_key):
login_status = self.get_login_status()
if login_status["Code"] == 300 and login_status["Text"] == "你已退出微信":
print("token已经失效重新获取")
if login_status['Code'] == 300 and login_status['Text'] == '你已退出微信':
print('token已经失效重新获取')
token_data = self.get_token(admin_key)
self.token = token_data["Data"][0]
self.token = token_data['Data'][0]
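
For orientation, a hypothetical call sequence against the LoginApi above; the base URL and admin key are placeholders:

# Hypothetical usage; method names come from the class above, values are placeholders.
api = LoginApi(base_url='http://127.0.0.1:9000', admin_key='ADMIN_KEY')
api.login('ADMIN_KEY')           # refreshes api.token when GetLoginStatus reports code 300
qr = api.get_login_qr()          # fetch a login QR code with the current token
status = api.get_login_status()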

View File

@@ -1,5 +1,4 @@
from libs.wechatpad_api.util.http_util import async_request, post_json
from libs.wechatpad_api.util.http_util import post_json
class MessageApi:
@@ -8,7 +7,7 @@ class MessageApi:
self.token = token
def post_text(self, to_wxid, content, ats: list = []):
'''
"""
Args:
app_id: 微信id
@@ -18,106 +17,64 @@ class MessageApi:
Returns:
'''
url = self.base_url + "/message/SendTextMessage"
"""
url = self.base_url + '/message/SendTextMessage'
"""发送文字消息"""
json_data = {
"MsgItem": [
{
"AtWxIDList": ats,
"ImageContent": "",
"MsgType": 0,
"TextContent": content,
"ToUserName": to_wxid
}
'MsgItem': [
{'AtWxIDList': ats, 'ImageContent': '', 'MsgType': 0, 'TextContent': content, 'ToUserName': to_wxid}
]
}
return post_json(base_url=url, token=self.token, data=json_data)
def post_image(self, to_wxid, img_url, ats: list = []):
"""发送图片消息"""
# 这里好像可以尝试发送多个暂时未测试
json_data = {
"MsgItem": [
{
"AtWxIDList": ats,
"ImageContent": img_url,
"MsgType": 0,
"TextContent": '',
"ToUserName": to_wxid
}
'MsgItem': [
{'AtWxIDList': ats, 'ImageContent': img_url, 'MsgType': 0, 'TextContent': '', 'ToUserName': to_wxid}
]
}
url = self.base_url + "/message/SendImageMessage"
url = self.base_url + '/message/SendImageMessage'
return post_json(base_url=url, token=self.token, data=json_data)
def post_voice(self, to_wxid, voice_data, voice_forma, voice_duration):
"""发送语音消息"""
json_data = {
"ToUserName": to_wxid,
"VoiceData": voice_data,
"VoiceFormat": voice_forma,
"VoiceSecond": voice_duration
'ToUserName': to_wxid,
'VoiceData': voice_data,
'VoiceFormat': voice_forma,
'VoiceSecond': voice_duration,
}
url = self.base_url + "/message/SendVoice"
url = self.base_url + '/message/SendVoice'
return post_json(base_url=url, token=self.token, data=json_data)
def post_name_card(self, alias, to_wxid, nick_name, name_card_wxid, flag):
"""发送名片消息"""
param = {
"CardAlias": alias,
"CardFlag": flag,
"CardNickName": nick_name,
"CardWxId": name_card_wxid,
"ToUserName": to_wxid
'CardAlias': alias,
'CardFlag': flag,
'CardNickName': nick_name,
'CardWxId': name_card_wxid,
'ToUserName': to_wxid,
}
url = f"{self.base_url}/message/ShareCardMessage"
url = f'{self.base_url}/message/ShareCardMessage'
return post_json(base_url=url, token=self.token, data=param)
def post_emoji(self, to_wxid, emoji_md5, emoji_size: int = 0):
"""发送emoji消息"""
json_data = {
"EmojiList": [
{
"EmojiMd5": emoji_md5,
"EmojiSize": emoji_size,
"ToUserName": to_wxid
}
]
}
url = f"{self.base_url}/message/SendEmojiMessage"
json_data = {'EmojiList': [{'EmojiMd5': emoji_md5, 'EmojiSize': emoji_size, 'ToUserName': to_wxid}]}
url = f'{self.base_url}/message/SendEmojiMessage'
return post_json(base_url=url, token=self.token, data=json_data)
def post_app_msg(self, to_wxid, xml_data, contenttype: int = 0):
"""发送appmsg消息"""
json_data = {
"AppList": [
{
"ContentType": contenttype,
"ContentXML": xml_data,
"ToUserName": to_wxid
}
]
}
url = f"{self.base_url}/message/SendAppMessage"
json_data = {'AppList': [{'ContentType': contenttype, 'ContentXML': xml_data, 'ToUserName': to_wxid}]}
url = f'{self.base_url}/message/SendAppMessage'
return post_json(base_url=url, token=self.token, data=json_data)
def revoke_msg(self, to_wxid, msg_id, new_msg_id, create_time):
"""撤回消息"""
param = {
"ClientMsgId": msg_id,
"CreateTime": create_time,
"NewMsgId": new_msg_id,
"ToUserName": to_wxid
}
url = f"{self.base_url}/message/RevokeMsg"
param = {'ClientMsgId': msg_id, 'CreateTime': create_time, 'NewMsgId': new_msg_id, 'ToUserName': to_wxid}
url = f'{self.base_url}/message/RevokeMsg'
return post_json(base_url=url, token=self.token, data=param)
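
A hypothetical call against the MessageApi above (the constructor arguments are assumed from the `self.token = token` line in the hunk; the wxid and URL are placeholders). Note that the mutable default `ats: list = []` survives this pass: ruff's default rule set does not include flake8-bugbear's B006, the rule that would flag it.

# Hypothetical usage of MessageApi; values are placeholders.
api = MessageApi(base_url='http://127.0.0.1:9000', token='TOKEN')
api.post_text(to_wxid='wxid_example', content='hello', ats=[])
api.post_emoji(to_wxid='wxid_example', emoji_md5='0' * 32, emoji_size=10240)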

View File

@@ -1,10 +1,9 @@
import requests
import aiohttp
def post_json(base_url, token, data=None):
headers = {
'Content-Type': 'application/json'
}
headers = {'Content-Type': 'application/json'}
url = base_url + f'?key={token}'
@@ -18,14 +17,12 @@ def post_json(base_url, token, data=None):
else:
raise RuntimeError(response.text)
except Exception as e:
print(f"http请求失败, url={url}, exception={e}")
print(f'http请求失败, url={url}, exception={e}')
raise RuntimeError(str(e))
def get_json(base_url, token):
headers = {
'Content-Type': 'application/json'
}
def get_json(base_url, token):
headers = {'Content-Type': 'application/json'}
url = base_url + f'?key={token}'
@@ -39,12 +36,9 @@ def get_json(base_url, token):
else:
raise RuntimeError(response.text)
except Exception as e:
print(f"http请求失败, url={url}, exception={e}")
print(f'http请求失败, url={url}, exception={e}')
raise RuntimeError(str(e))
import aiohttp
import asyncio
async def async_request(
base_url: str,
@@ -53,7 +47,7 @@ async def async_request(
params: dict = None,
# headers: dict = None,
data: dict = None,
json: dict = None
json: dict = None,
):
"""
通用异步请求函数
@@ -67,18 +61,11 @@ async def async_request(
:param json: JSON数据
:return: 响应文本
"""
headers = {
'Content-Type': 'application/json'
}
url = f"{base_url}?key={token_key}"
headers = {'Content-Type': 'application/json'}
url = f'{base_url}?key={token_key}'
async with aiohttp.ClientSession() as session:
async with session.request(
method=method,
url=url,
params=params,
headers=headers,
data=data,
json=json
method=method, url=url, params=params, headers=headers, data=data, json=json
) as response:
response.raise_for_status() # 如果状态码不是200抛出异常
result = await response.json()
@@ -89,4 +76,3 @@ async def async_request(
# return await result
# else:
# raise RuntimeError("请求失败",response.text)
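
A hypothetical caller of the two helpers above. `post_json` is synchronous while `async_request` needs an event loop; the URL and token are placeholders, and the keyword names for `async_request` are read off the hunk, so treat them as assumptions:

import asyncio

# Hypothetical calls; endpoint paths mirror the APIs in this commit, values are placeholders.
resp = post_json(
    'http://127.0.0.1:9000/group/GetChatroomMemberDetail',
    token='TOKEN',
    data={'ChatRoomName': 'xxx@chatroom'},
)
result = asyncio.run(
    async_request(base_url='http://127.0.0.1:9000/login/GetLoginStatus', token_key='TOKEN', method='GET')
)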

View File

@@ -1,14 +1,13 @@
from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, ForeignKey, LargeBinary
from sqlalchemy.orm import declarative_base, sessionmaker, relationship
from datetime import datetime
import numpy as np # 用于处理从LargeBinary转换回来的embedding
from sqlalchemy import Column, Integer, ForeignKey, LargeBinary
from sqlalchemy.orm import declarative_base, relationship
Base = declarative_base()
class Vector(Base):
__tablename__ = 'vectors'
id = Column(Integer, primary_key=True, index=True)
chunk_id = Column(Integer, ForeignKey('chunks.id'), unique=True)
embedding = Column(LargeBinary) # Store embeddings as binary
chunk = relationship("Chunk", back_populates="vector")
chunk = relationship('Chunk', back_populates='vector')
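
The removed `numpy` import was commented as the helper for turning the `LargeBinary` column back into an embedding. A minimal round-trip sketch; the float32 dtype and the 1024 dimension (matching `EMBEDDING_DIM` later in this commit) are assumptions:

import numpy as np

# Hypothetical round-trip for Vector.embedding; dtype and dimension are assumptions.
vec = np.random.rand(1024).astype(np.float32)
blob = vec.tobytes()                               # what would be stored in the LargeBinary column
restored = np.frombuffer(blob, dtype=np.float32)   # back to a float32 vector
assert restored.shape == (1024,)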

View File

@@ -16,7 +16,6 @@ from ..logger import EventLogger
class AiocqhttpMessageConverter(adapter.MessageConverter):
@staticmethod
async def yiri2target(
message_chain: platform_message.MessageChain,
@@ -62,7 +61,7 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
for node in msg.node_list:
msg_list.extend((await AiocqhttpMessageConverter.yiri2target(node.message_chain))[0])
elif isinstance(msg, platform_message.File):
msg_list.append({"type":"file", "data":{'file': msg.url, "name": msg.name}})
msg_list.append({'type': 'file', 'data': {'file': msg.url, 'name': msg.name}})
elif isinstance(msg, platform_message.Face):
if msg.face_type == 'face':
msg_list.append(aiocqhttp.MessageSegment.face(msg.face_id))
@@ -71,7 +70,6 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
elif msg.face_type == 'dice':
msg_list.append(aiocqhttp.MessageSegment.dice())
else:
msg_list.append(aiocqhttp.MessageSegment.text(str(msg)))
@@ -84,65 +82,149 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
def get_face_name(face_id):
face_code_dict = {
"2": '好色',
"4": "得意", "5": "流泪", "8": "", "9": "大哭", "10": "尴尬", "12": "调皮", "14": "微笑", "16": "",
"21": "可爱",
"23": "傲慢", "24": "饥饿", "25": "", "26": "惊恐", "27": "流汗", "28": "憨笑", "29": "悠闲",
"30": "奋斗",
"32": "疑问", "33": "", "34": "", "38": "敲打", "39": "再见", "41": "发抖", "42": "爱情",
"43": "跳跳",
"49": "拥抱", "53": "蛋糕", "60": "咖啡", "63": "玫瑰", "66": "爱心", "74": "太阳", "75": "月亮",
"76": "",
"78": "握手", "79": "胜利", "85": "飞吻", "89": "西瓜", "96": "冷汗", "97": "擦汗", "98": "抠鼻",
"99": "鼓掌",
"100": "糗大了", "101": "坏笑", "102": "左哼哼", "103": "右哼哼", "104": "哈欠", "106": "委屈",
"109": "左亲亲",
"111": "可怜", "116": "示爱", "118": "抱拳", "120": "拳头", "122": "爱你", "123": "NO", "124": "OK",
"125": "转圈",
"129": "挥手", "144": "喝彩", "147": "棒棒糖", "171": "", "173": "泪奔", "174": "无奈", "175": "卖萌",
"176": "小纠结", "179": "doge", "180": "惊喜", "181": "骚扰", "182": "笑哭", "183": "我最美",
"201": "点赞",
"203": "托脸", "212": "托腮", "214": "啵啵", "219": "蹭一蹭", "222": "抱抱", "227": "拍手",
"232": "佛系",
"240": "喷脸", "243": "甩头", "246": "加油抱抱", "262": "脑阔疼", "264": "捂脸", "265": "辣眼睛",
"266": "哦哟",
"267": "头秃", "268": "问号脸", "269": "暗中观察", "270": "emm", "271": "吃瓜", "272": "呵呵哒",
"273": "我酸了",
"277": "汪汪", "278": "", "281": "无眼笑", "282": "敬礼", "284": "面无表情", "285": "摸鱼",
"287": "",
"289": "睁眼", "290": "敲开心", "293": "摸锦鲤", "294": "期待", "297": "拜谢", "298": "元宝",
"299": "牛啊",
"305": "右亲亲", "306": "牛气冲天", "307": "喵喵", "314": "仔细分析", "315": "加油", "318": "崇拜",
"319": "比心",
"320": "庆祝", "322": "拒绝", "324": "吃糖", "326": "生气"
'2': '好色',
'4': '得意',
'5': '流泪',
'8': '',
'9': '大哭',
'10': '尴尬',
'12': '调皮',
'14': '微笑',
'16': '',
'21': '可爱',
'23': '傲慢',
'24': '饥饿',
'25': '',
'26': '惊恐',
'27': '流汗',
'28': '憨笑',
'29': '悠闲',
'30': '奋斗',
'32': '疑问',
'33': '',
'34': '',
'38': '敲打',
'39': '再见',
'41': '发抖',
'42': '爱情',
'43': '跳跳',
'49': '拥抱',
'53': '蛋糕',
'60': '咖啡',
'63': '玫瑰',
'66': '爱心',
'74': '太阳',
'75': '月亮',
'76': '',
'78': '握手',
'79': '胜利',
'85': '飞吻',
'89': '西瓜',
'96': '冷汗',
'97': '擦汗',
'98': '抠鼻',
'99': '鼓掌',
'100': '糗大了',
'101': '坏笑',
'102': '左哼哼',
'103': '右哼哼',
'104': '哈欠',
'106': '委屈',
'109': '左亲亲',
'111': '可怜',
'116': '示爱',
'118': '抱拳',
'120': '拳头',
'122': '爱你',
'123': 'NO',
'124': 'OK',
'125': '转圈',
'129': '挥手',
'144': '喝彩',
'147': '棒棒糖',
'171': '',
'173': '泪奔',
'174': '无奈',
'175': '卖萌',
'176': '小纠结',
'179': 'doge',
'180': '惊喜',
'181': '骚扰',
'182': '笑哭',
'183': '我最美',
'201': '点赞',
'203': '托脸',
'212': '托腮',
'214': '啵啵',
'219': '蹭一蹭',
'222': '抱抱',
'227': '拍手',
'232': '佛系',
'240': '喷脸',
'243': '甩头',
'246': '加油抱抱',
'262': '脑阔疼',
'264': '捂脸',
'265': '辣眼睛',
'266': '哦哟',
'267': '头秃',
'268': '问号脸',
'269': '暗中观察',
'270': 'emm',
'271': '吃瓜',
'272': '呵呵哒',
'273': '我酸了',
'277': '汪汪',
'278': '',
'281': '无眼笑',
'282': '敬礼',
'284': '面无表情',
'285': '摸鱼',
'287': '',
'289': '睁眼',
'290': '敲开心',
'293': '摸锦鲤',
'294': '期待',
'297': '拜谢',
'298': '元宝',
'299': '牛啊',
'305': '右亲亲',
'306': '牛气冲天',
'307': '喵喵',
'314': '仔细分析',
'315': '加油',
'318': '崇拜',
'319': '比心',
'320': '庆祝',
'322': '拒绝',
'324': '吃糖',
'326': '生气',
}
return face_code_dict.get(face_id, '')
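# Behaviour of the lookup above (a note, not repo code): known ids map to names,
# unknown ids fall back to an empty string.
#   get_face_name('179')  -> 'doge'
#   get_face_name('9999') -> ''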
async def process_message_data(msg_data, reply_list):
if msg_data["type"] == "image":
image_base64, image_format = await image.qq_image_url_to_base64(msg_data["data"]['url'])
reply_list.append(
platform_message.Image(base64=f'data:image/{image_format};base64,{image_base64}'))
if msg_data['type'] == 'image':
image_base64, image_format = await image.qq_image_url_to_base64(msg_data['data']['url'])
reply_list.append(platform_message.Image(base64=f'data:image/{image_format};base64,{image_base64}'))
elif msg_data["type"] == "text":
reply_list.append(platform_message.Plain(text=msg_data["data"]["text"]))
elif msg_data['type'] == 'text':
reply_list.append(platform_message.Plain(text=msg_data['data']['text']))
elif msg_data["type"] == "forward": # 这里来应该传入转发消息组暂时传入qoute
for forward_msg_datas in msg_data["data"]["content"]:
for forward_msg_data in forward_msg_datas["message"]:
elif msg_data['type'] == 'forward': # 这里来应该传入转发消息组暂时传入qoute
for forward_msg_datas in msg_data['data']['content']:
for forward_msg_data in forward_msg_datas['message']:
await process_message_data(forward_msg_data, reply_list)
elif msg_data["type"] == "at":
if msg_data["data"]['qq'] == 'all':
elif msg_data['type'] == 'at':
if msg_data['data']['qq'] == 'all':
reply_list.append(platform_message.AtAll())
else:
reply_list.append(
platform_message.At(
target=msg_data["data"]['qq'],
target=msg_data['data']['qq'],
)
)
yiri_msg_list = []
yiri_msg_list.append(platform_message.Source(id=message_id, time=datetime.datetime.now()))
@@ -161,10 +243,10 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
elif msg.type == 'text':
yiri_msg_list.append(platform_message.Plain(text=msg.data['text']))
elif msg.type == 'image':
emoji_id = msg.data.get("emoji_package_id", None)
emoji_id = msg.data.get('emoji_package_id', None)
if emoji_id:
face_id = emoji_id
face_name = msg.data.get("summary", '')
face_name = msg.data.get('summary', '')
image_msg = platform_message.Face(face_id=face_id, face_name=face_name)
else:
image_base64, image_format = await image.qq_image_url_to_base64(msg.data['url'])
@@ -178,14 +260,15 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
# await process_message_data(msg_data, yiri_msg_list)
pass
elif msg.type == 'reply': # 此处处理引用消息传入Qoute
msg_datas = await bot.get_msg(message_id=msg.data["id"])
msg_datas = await bot.get_msg(message_id=msg.data['id'])
for msg_data in msg_datas["message"]:
for msg_data in msg_datas['message']:
await process_message_data(msg_data, reply_list)
reply_msg = platform_message.Quote(message_id=msg.data["id"],sender_id=msg_datas["user_id"],origin=reply_list)
reply_msg = platform_message.Quote(
message_id=msg.data['id'], sender_id=msg_datas['user_id'], origin=reply_list
)
yiri_msg_list.append(reply_msg)
elif msg.type == 'file':
@@ -194,6 +277,7 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
file_data = await bot.get_file(file_id=file_id)
file_name = file_data.get('file_name')
file_path = file_data.get('file')
_ = file_path
file_url = file_data.get('file_url')
file_size = file_data.get('file_size')
yiri_msg_list.append(platform_message.File(id=file_id, name=file_name, url=file_url, size=file_size))
@@ -205,28 +289,16 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
yiri_msg_list.append(platform_message.Face(face_id=int(face_id), face_name=face_name.replace('/', '')))
elif msg.type == 'rps':
face_id = msg.data['result']
yiri_msg_list.append(platform_message.Face(face_type="rps",face_id=int(face_id),face_name='猜拳'))
yiri_msg_list.append(platform_message.Face(face_type='rps', face_id=int(face_id), face_name='猜拳'))
elif msg.type == 'dice':
face_id = msg.data['result']
yiri_msg_list.append(platform_message.Face(face_type='dice', face_id=int(face_id), face_name='骰子'))
chain = platform_message.MessageChain(yiri_msg_list)
return chain
class AiocqhttpEventConverter(adapter.EventConverter):
@staticmethod
async def yiri2target(event: platform_events.MessageEvent, bot_account_id: int):
@@ -236,8 +308,6 @@ class AiocqhttpEventConverter(adapter.EventConverter):
async def target2yiri(event: aiocqhttp.Event, bot=None):
yiri_chain = await AiocqhttpMessageConverter.target2yiri(event.message, event.message_id, bot)
if event.message_type == 'group':
permission = 'MEMBER'
@@ -316,7 +386,6 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter):
aiocq_msg = (await AiocqhttpMessageConverter.yiri2target(message))[0]
if target_type == 'group':
await self.bot.send_group_msg(group_id=int(target_id), message=aiocq_msg)
elif target_type == 'person':
await self.bot.send_private_msg(user_id=int(target_id), message=aiocq_msg)

View File

@@ -8,7 +8,6 @@ import base64
import uuid
import os
import datetime
import io
import aiohttp
@@ -101,12 +100,13 @@ class DiscordMessageConverter(adapter.MessageConverter):
filename = f'{uuid.uuid4()}.webp'
# 默认保持PNG
except Exception as e:
print(f"Error reading image file {clean_path}: {e}")
print(f'Error reading image file {clean_path}: {e}')
continue # 跳过读取失败的文件
if image_bytes:
# 使用BytesIO创建文件对象避免路径问题
import io
image_files.append(discord.File(fp=io.BytesIO(image_bytes), filename=filename))
elif isinstance(ele, platform_message.Plain):
text_string += ele.text
@@ -279,7 +279,7 @@ class DiscordAdapter(adapter.MessagePlatformAdapter):
await channel.send(**args)
except Exception as e:
await self.logger.error(f"Discord send_message failed: {e}")
await self.logger.error(f'Discord send_message failed: {e}')
raise e
async def reply_message(

View File

@@ -378,15 +378,15 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
if 'im.message.receive_v1' == type:
try:
event = await self.event_converter.target2yiri(p2v1, self.api_client)
except Exception as e:
await self.logger.error(f"Error in lark callback: {traceback.format_exc()}")
except Exception:
await self.logger.error(f'Error in lark callback: {traceback.format_exc()}')
if event.__class__ in self.listeners:
await self.listeners[event.__class__](event, self)
return {'code': 200, 'message': 'ok'}
except Exception as e:
await self.logger.error(f"Error in lark callback: {traceback.format_exc()}")
except Exception:
await self.logger.error(f'Error in lark callback: {traceback.format_exc()}')
return {'code': 500, 'message': 'error'}
async def on_message(event: lark_oapi.im.v1.P2ImMessageReceiveV1):

View File

@@ -72,8 +72,9 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter):
content=content_list,
)
nakuru_forward_node_list.append(nakuru_forward_node)
except Exception as e:
except Exception:
import traceback
traceback.print_exc()
nakuru_msg_list.append(nakuru_forward_node_list)
@@ -276,7 +277,7 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter):
# 注册监听器
self.bot.receiver(source_cls.__name__)(listener_wrapper)
except Exception as e:
self.logger.error(f"Error in nakuru register_listener: {traceback.format_exc()}")
self.logger.error(f'Error in nakuru register_listener: {traceback.format_exc()}')
raise e
def unregister_listener(

View File

@@ -125,8 +125,8 @@ class OfficialAccountAdapter(adapter.MessagePlatformAdapter):
self.bot_account_id = event.receiver_id
try:
return await callback(await self.event_converter.target2yiri(event), self)
except Exception as e:
await self.logger.error(f"Error in officialaccount callback: {traceback.format_exc()}")
except Exception:
await self.logger.error(f'Error in officialaccount callback: {traceback.format_exc()}')
if event_type == platform_events.FriendMessage:
self.bot.on_message('text')(on_message)

View File

@@ -154,10 +154,7 @@ class QQOfficialAdapter(adapter.MessagePlatformAdapter):
raise ParamNotEnoughError('QQ官方机器人缺少相关配置项请查看文档或联系管理员')
self.bot = QQOfficialClient(
app_id=config['appid'],
secret=config['secret'],
token=config['token'],
logger=self.logger
app_id=config['appid'], secret=config['secret'], token=config['token'], logger=self.logger
)
async def reply_message(
@@ -224,8 +221,8 @@ class QQOfficialAdapter(adapter.MessagePlatformAdapter):
self.bot_account_id = 'justbot'
try:
return await callback(await self.event_converter.target2yiri(event), self)
except Exception as e:
await self.logger.error(f"Error in qqofficial callback: {traceback.format_exc()}")
except Exception:
await self.logger.error(f'Error in qqofficial callback: {traceback.format_exc()}')
if event_type == platform_events.FriendMessage:
self.bot.on_message('DIRECT_MESSAGE_CREATE')(on_message)

View File

@@ -104,7 +104,9 @@ class SlackAdapter(adapter.MessagePlatformAdapter):
if missing_keys:
raise ParamNotEnoughError('Slack机器人缺少相关配置项请查看文档或联系管理员')
self.bot = SlackClient(bot_token=self.config['bot_token'], signing_secret=self.config['signing_secret'], logger=self.logger)
self.bot = SlackClient(
bot_token=self.config['bot_token'], signing_secret=self.config['signing_secret'], logger=self.logger
)
async def reply_message(
self,
@@ -139,8 +141,8 @@ class SlackAdapter(adapter.MessagePlatformAdapter):
self.bot_account_id = 'SlackBot'
try:
return await callback(await self.event_converter.target2yiri(event, self.bot), self)
except Exception as e:
await self.logger.error(f"Error in slack callback: {traceback.format_exc()}")
except Exception:
await self.logger.error(f'Error in slack callback: {traceback.format_exc()}')
if event_type == platform_events.FriendMessage:
self.bot.on_message('im')(on_message)

View File

@@ -160,8 +160,8 @@ class TelegramAdapter(adapter.MessagePlatformAdapter):
try:
lb_event = await self.event_converter.target2yiri(update, self.bot, self.bot_account_id)
await self.listeners[type(lb_event)](lb_event, self)
except Exception as e:
await self.logger.error(f"Error in telegram callback: {traceback.format_exc()}")
except Exception:
await self.logger.error(f'Error in telegram callback: {traceback.format_exc()}')
self.application = ApplicationBuilder().token(self.config['token']).build()
self.bot = self.application.bot

View File

@@ -1,5 +1,4 @@
import requests
import websockets
import websocket
import json
import time
@@ -10,53 +9,40 @@ from libs.wechatpad_api.client import WeChatPadClient
import typing
import asyncio
import traceback
import time
import re
import base64
import uuid
import json
import os
import copy
import datetime
import threading
import quart
import aiohttp
from .. import adapter
from ...pipeline.longtext.strategies import forward
from ...core import app
from ..types import message as platform_message
from ..types import events as platform_events
from ..types import entities as platform_entities
from ...utils import image
from ..logger import EventLogger
import xml.etree.ElementTree as ET
from typing import Optional, List, Tuple
from typing import Optional, Tuple
from functools import partial
import logging
class WeChatPadMessageConverter(adapter.MessageConverter):
class WeChatPadMessageConverter(adapter.MessageConverter):
def __init__(self, config: dict):
self.config = config
self.bot = WeChatPadClient(self.config["wechatpad_url"],self.config["token"])
self.logger = logging.getLogger("WeChatPadMessageConverter")
self.bot = WeChatPadClient(self.config['wechatpad_url'], self.config['token'])
self.logger = logging.getLogger('WeChatPadMessageConverter')
@staticmethod
async def yiri2target(
message_chain: platform_message.MessageChain
) -> list[dict]:
async def yiri2target(message_chain: platform_message.MessageChain) -> list[dict]:
content_list = []
current_file_path = os.path.abspath(__file__)
for component in message_chain:
if isinstance(component, platform_message.At):
content_list.append({"type": "at", "target": component.target})
content_list.append({'type': 'at', 'target': component.target})
elif isinstance(component, platform_message.Plain):
content_list.append({"type": "text", "content": component.text})
content_list.append({'type': 'text', 'content': component.text})
elif isinstance(component, platform_message.Image):
if component.url:
async with httpx.AsyncClient() as client:
@@ -68,15 +54,16 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
else:
raise Exception('获取文件失败')
# pass
content_list.append({"type": "image", "image": base64_str})
content_list.append({'type': 'image', 'image': base64_str})
elif component.base64:
content_list.append({"type": "image", "image": component.base64})
content_list.append({'type': 'image', 'image': component.base64})
elif isinstance(component, platform_message.WeChatEmoji):
content_list.append(
{'type': 'WeChatEmoji', 'emoji_md5': component.emoji_md5, 'emoji_size': component.emoji_size})
{'type': 'WeChatEmoji', 'emoji_md5': component.emoji_md5, 'emoji_size': component.emoji_size}
)
elif isinstance(component, platform_message.Voice):
content_list.append({"type": "voice", "data": component.url, "duration": component.length, "forma": 0})
content_list.append({'type': 'voice', 'data': component.url, 'duration': component.length, 'forma': 0})
elif isinstance(component, platform_message.WeChatAppMsg):
content_list.append({'type': 'WeChatAppMsg', 'app_msg': component.app_msg})
elif isinstance(component, platform_message.Forward):
@@ -86,28 +73,23 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
return content_list
async def target2yiri(
self,
message: dict,
bot_account_id: str
) -> platform_message.MessageChain:
async def target2yiri(self, message: dict, bot_account_id: str) -> platform_message.MessageChain:
"""外部消息转平台消息"""
# 数据预处理
message_list = []
ats_bot = False # 是否被@
content = message["content"]["str"]
content = message['content']['str']
content_no_preifx = content # 群消息则去掉前缀
is_group_message = self._is_group_message(message)
if is_group_message:
ats_bot = self._ats_bot(message, bot_account_id)
if "@所有人" in content:
if '@所有人' in content:
message_list.append(platform_message.AtAll())
elif ats_bot:
message_list.append(platform_message.At(target=bot_account_id))
content_no_preifx, _ = self._extract_content_and_sender(content)
msg_type = message["msg_type"]
msg_type = message['msg_type']
# 映射消息类型到处理器方法
handler_map = {
@@ -129,11 +111,7 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
return platform_message.MessageChain(message_list)
async def _handler_text(
self,
message: Optional[dict],
content_no_preifx: str
) -> platform_message.MessageChain:
async def _handler_text(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain:
"""处理文本消息 (msg_type=1)"""
if message and self._is_group_message(message):
pattern = r'@\S{1,20}'
@@ -141,16 +119,12 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
return platform_message.MessageChain([platform_message.Plain(content_no_preifx)])
async def _handler_image(
self,
message: Optional[dict],
content_no_preifx: str
) -> platform_message.MessageChain:
async def _handler_image(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain:
"""处理图像消息 (msg_type=3)"""
try:
image_xml = content_no_preifx
if not image_xml:
return platform_message.MessageChain([platform_message.Unknown("[图片内容为空]")])
return platform_message.MessageChain([platform_message.Unknown('[图片内容为空]')])
root = ET.fromstring(image_xml)
# 提取img标签的属性
@@ -160,28 +134,22 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
cdnthumburl = img_tag.get('cdnthumburl')
# cdnmidimgurl = img_tag.get('cdnmidimgurl')
image_data = self.bot.cdn_download(aeskey=aeskey, file_type=1, file_url=cdnthumburl)
if image_data["Data"]['FileData'] == '':
if image_data['Data']['FileData'] == '':
image_data = self.bot.cdn_download(aeskey=aeskey, file_type=2, file_url=cdnthumburl)
base64_str = image_data["Data"]['FileData']
base64_str = image_data['Data']['FileData']
# self.logger.info(f"data:image/png;base64,{base64_str}")
elements = [
platform_message.Image(base64=f"data:image/png;base64,{base64_str}"),
platform_message.Image(base64=f'data:image/png;base64,{base64_str}'),
# platform_message.WeChatForwardImage(xml_data=image_xml) # 微信消息转发
]
return platform_message.MessageChain(elements)
except Exception as e:
self.logger.error(f"处理图片失败: {str(e)}")
return platform_message.MessageChain([platform_message.Unknown("[图片处理失败]")])
self.logger.error(f'处理图片失败: {str(e)}')
return platform_message.MessageChain([platform_message.Unknown('[图片处理失败]')])
async def _handler_voice(
self,
message: Optional[dict],
content_no_preifx: str
) -> platform_message.MessageChain:
async def _handler_voice(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain:
"""处理语音消息 (msg_type=34)"""
message_List = []
try:
@@ -197,39 +165,33 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
bufid = voicemsg.get('bufid')
length = voicemsg.get('voicelength')
voice_data = self.bot.get_msg_voice(buf_id=str(bufid), length=int(length), msgid=str(new_msg_id))
audio_base64 = voice_data["Data"]['Base64']
audio_base64 = voice_data['Data']['Base64']
# 验证语音数据有效性
if not audio_base64:
message_List.append(platform_message.Unknown(text="[语音内容为空]"))
message_List.append(platform_message.Unknown(text='[语音内容为空]'))
return platform_message.MessageChain(message_List)
# 转换为平台支持的语音格式(如 Silk 格式)
voice_element = platform_message.Voice(
base64=f"data:audio/silk;base64,{audio_base64}"
)
voice_element = platform_message.Voice(base64=f'data:audio/silk;base64,{audio_base64}')
message_List.append(voice_element)
except KeyError as e:
self.logger.error(f"语音数据字段缺失: {str(e)}")
message_List.append(platform_message.Unknown(text="[语音数据解析失败]"))
self.logger.error(f'语音数据字段缺失: {str(e)}')
message_List.append(platform_message.Unknown(text='[语音数据解析失败]'))
except Exception as e:
self.logger.error(f"处理语音消息异常: {str(e)}")
message_List.append(platform_message.Unknown(text="[语音处理失败]"))
self.logger.error(f'处理语音消息异常: {str(e)}')
message_List.append(platform_message.Unknown(text='[语音处理失败]'))
return platform_message.MessageChain(message_List)
async def _handler_compound(
self,
message: Optional[dict],
content_no_preifx: str
) -> platform_message.MessageChain:
async def _handler_compound(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain:
"""处理复合消息 (msg_type=49),根据子类型分派"""
try:
xml_data = ET.fromstring(content_no_preifx)
appmsg_data = xml_data.find('.//appmsg')
if appmsg_data:
data_type = appmsg_data.findtext('.//type', "")
data_type = appmsg_data.findtext('.//type', '')
# 二次分派处理器
sub_handler_map = {
'57': self._handler_compound_quote,
@@ -238,9 +200,9 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
'74': self._handler_compound_file,
'33': self._handler_compound_mini_program,
'36': self._handler_compound_mini_program,
'2000': partial(self._handler_compound_unsupported, text="[转账消息]"),
'2001': partial(self._handler_compound_unsupported, text="[红包消息]"),
'51': partial(self._handler_compound_unsupported, text="[视频号消息]"),
'2000': partial(self._handler_compound_unsupported, text='[转账消息]'),
'2001': partial(self._handler_compound_unsupported, text='[红包消息]'),
'51': partial(self._handler_compound_unsupported, text='[视频号消息]'),
}
handler = sub_handler_map.get(data_type, self._handler_compound_unsupported)
@@ -251,56 +213,54 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
else:
return platform_message.MessageChain([platform_message.Unknown(text=content_no_preifx)])
except Exception as e:
self.logger.error(f"解析复合消息失败: {str(e)}")
self.logger.error(f'解析复合消息失败: {str(e)}')
return platform_message.MessageChain([platform_message.Unknown(text=content_no_preifx)])
async def _handler_compound_quote(
self,
message: Optional[dict],
xml_data: ET.Element
self, message: Optional[dict], xml_data: ET.Element
) -> platform_message.MessageChain:
"""处理引用消息 (data_type=57)"""
message_list = []
# self.logger.info("_handler_compound_quote", ET.tostring(xml_data, encoding='unicode'))
appmsg_data = xml_data.find('.//appmsg')
quote_data = "" # 引用原文
quote_data = '' # 引用原文
quote_id = None # 引用消息的原发送者
tousername = None # 接收方: 所属微信的wxid
user_data = "" # 用户消息
user_data = '' # 用户消息
sender_id = xml_data.findtext('.//fromusername') # 发送方:单聊用户/群member
# 引用消息转发
if appmsg_data:
user_data = appmsg_data.findtext('.//title') or ""
user_data = appmsg_data.findtext('.//title') or ''
quote_data = appmsg_data.find('.//refermsg').findtext('.//content')
quote_id = appmsg_data.find('.//refermsg').findtext('.//chatusr')
message_list.append(
platform_message.WeChatAppMsg(
app_msg=ET.tostring(appmsg_data, encoding='unicode'))
)
message_list.append(platform_message.WeChatAppMsg(app_msg=ET.tostring(appmsg_data, encoding='unicode')))
if message:
tousername = message['to_user_name']["str"]
tousername = message['to_user_name']['str']
_ = quote_id
_ = tousername
if quote_data:
quote_data_message_list = platform_message.MessageChain()
# 文本消息
try:
if "<msg>" not in quote_data:
if '<msg>' not in quote_data:
quote_data_message_list.append(platform_message.Plain(quote_data))
else:
# 引用消息展开
quote_data_xml = ET.fromstring(quote_data)
if quote_data_xml.find("img"):
if quote_data_xml.find('img'):
quote_data_message_list.extend(await self._handler_image(None, quote_data))
elif quote_data_xml.find("voicemsg"):
elif quote_data_xml.find('voicemsg'):
quote_data_message_list.extend(await self._handler_voice(None, quote_data))
elif quote_data_xml.find("videomsg"):
elif quote_data_xml.find('videomsg'):
quote_data_message_list.extend(await self._handler_default(None, quote_data)) # 先不处理
else:
# appmsg
quote_data_message_list.extend(await self._handler_compound(None, quote_data))
except Exception as e:
self.logger.error(f"处理引用消息异常 expcetion:{e}")
self.logger.error(f'处理引用消息异常 expcetion:{e}')
quote_data_message_list.append(platform_message.Plain(quote_data))
message_list.append(
platform_message.Quote(
@@ -315,15 +275,11 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
return platform_message.MessageChain(message_list)
async def _handler_compound_file(
self,
message: dict,
xml_data: ET.Element
) -> platform_message.MessageChain:
async def _handler_compound_file(self, message: dict, xml_data: ET.Element) -> platform_message.MessageChain:
"""处理文件消息 (data_type=6)"""
file_data = xml_data.find('.//appmsg')
if file_data.findtext('.//type', "") == "74":
if file_data.findtext('.//type', '') == '74':
return None
else:
@@ -346,22 +302,21 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
file_data = self.bot.cdn_download(aeskey=aeskey, file_type=5, file_url=cdnthumburl)
file_base64 = file_data["Data"]['FileData']
file_base64 = file_data['Data']['FileData']
# print(file_data)
file_size = file_data["Data"]['TotalSize']
file_size = file_data['Data']['TotalSize']
# print(file_base64)
return platform_message.MessageChain([
platform_message.WeChatFile(file_id=file_id, file_name=file_name, file_size=file_size,
file_base64=file_base64),
platform_message.WeChatForwardFile(xml_data=xml_data_str)
])
return platform_message.MessageChain(
[
platform_message.WeChatFile(
file_id=file_id, file_name=file_name, file_size=file_size, file_base64=file_base64
),
platform_message.WeChatForwardFile(xml_data=xml_data_str),
]
)
async def _handler_compound_link(
self,
message: dict,
xml_data: ET.Element
) -> platform_message.MessageChain:
async def _handler_compound_link(self, message: dict, xml_data: ET.Element) -> platform_message.MessageChain:
"""处理链接消息(如公众号文章、外部网页)"""
message_list = []
try:
@@ -374,56 +329,38 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
link_title=appmsg.findtext('title', ''),
link_desc=appmsg.findtext('des', ''),
link_url=appmsg.findtext('url', ''),
link_thumb_url=appmsg.findtext("thumburl", '') # 这个字段拿不到
link_thumb_url=appmsg.findtext('thumburl', ''), # 这个字段拿不到
)
)
# 还没有发链接的接口, 暂时还需要自己构造appmsg, 先用WeChatAppMsg。
message_list.append(
platform_message.WeChatAppMsg(
app_msg=ET.tostring(appmsg, encoding='unicode')
)
)
message_list.append(platform_message.WeChatAppMsg(app_msg=ET.tostring(appmsg, encoding='unicode')))
except Exception as e:
self.logger.error(f"解析链接消息失败: {str(e)}")
self.logger.error(f'解析链接消息失败: {str(e)}')
return platform_message.MessageChain(message_list)
async def _handler_compound_mini_program(
self,
message: dict,
xml_data: ET.Element
self, message: dict, xml_data: ET.Element
) -> platform_message.MessageChain:
"""处理小程序消息(如小程序卡片、服务通知)"""
xml_data_str = ET.tostring(xml_data, encoding='unicode')
return platform_message.MessageChain([
platform_message.WeChatForwardMiniPrograms(xml_data=xml_data_str)
])
return platform_message.MessageChain([platform_message.WeChatForwardMiniPrograms(xml_data=xml_data_str)])
async def _handler_default(
self,
message: Optional[dict],
content_no_preifx: str
) -> platform_message.MessageChain:
async def _handler_default(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain:
"""处理未知消息类型"""
if message:
msg_type = message["msg_type"]
msg_type = message['msg_type']
else:
msg_type = ""
return platform_message.MessageChain([
platform_message.Unknown(text=f"[未知消息类型 msg_type:{msg_type}]")
])
msg_type = ''
return platform_message.MessageChain([platform_message.Unknown(text=f'[未知消息类型 msg_type:{msg_type}]')])
def _handler_compound_unsupported(
self,
message: dict,
xml_data: str,
text: Optional[str] = None
self, message: dict, xml_data: str, text: Optional[str] = None
) -> platform_message.MessageChain:
"""处理未支持复合消息类型(msg_type=49)子类型"""
if not text:
text = f"[xml_data={xml_data}]"
text = f'[xml_data={xml_data}]'
content_list = []
content_list.append(
platform_message.Unknown(text=f"[处理未支持复合消息类型[msg_type=49]|{text}"))
content_list.append(platform_message.Unknown(text=f'[处理未支持复合消息类型[msg_type=49]|{text}'))
return platform_message.MessageChain(content_list)
@@ -432,7 +369,7 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
ats_bot = False
try:
to_user_name = message['to_user_name']['str'] # 接收方: 所属微信的wxid
raw_content = message["content"]["str"] # 原始消息内容
raw_content = message['content']['str'] # 原始消息内容
content_no_prefix, _ = self._extract_content_and_sender(raw_content)
# 直接艾特机器人这个有bug当被引用的消息里面有@bot,会套娃
# ats_bot = ats_bot or (f"@{bot_account_id}" in content_no_prefix)
@@ -443,7 +380,7 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
msg_source = message.get('msg_source', '') or ''
if len(msg_source) > 0:
msg_source_data = ET.fromstring(msg_source)
at_user_list = msg_source_data.findtext("atuserlist") or ""
at_user_list = msg_source_data.findtext('atuserlist') or ''
ats_bot = ats_bot or (to_user_name in at_user_list)
# 引用bot
if message.get('msg_type', 0) == 49:
@@ -454,7 +391,7 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
quote_id = appmsg_data.find('.//refermsg').findtext('.//chatusr') # 引用消息的原发送者
ats_bot = ats_bot or (quote_id == tousername)
except Exception as e:
self.logger.error(f"_ats_bot got except: {e}")
self.logger.error(f'_ats_bot got except: {e}')
finally:
return ats_bot
@@ -463,47 +400,41 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
try:
# 检查消息开头,如果有 wxid_sbitaz0mt65n22:\n 则删掉
# add: 有些用户的wxid不是上述格式。换成user_name:
regex = re.compile(r"^[a-zA-Z0-9_\-]{5,20}:")
line_split = raw_content.split("\n")
regex = re.compile(r'^[a-zA-Z0-9_\-]{5,20}:')
line_split = raw_content.split('\n')
if len(line_split) > 0 and regex.match(line_split[0]):
raw_content = "\n".join(line_split[1:])
sender_id = line_split[0].strip(":")
raw_content = '\n'.join(line_split[1:])
sender_id = line_split[0].strip(':')
return raw_content, sender_id
except Exception as e:
self.logger.error(f"_extract_content_and_sender got except: {e}")
self.logger.error(f'_extract_content_and_sender got except: {e}')
finally:
return raw_content, None
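
A worked example of the prefix stripping above, re-created standalone with illustrative values (not repo code):

import re

raw = 'wxid_sbitaz0mt65n22:\n@bot hello'
regex = re.compile(r'^[a-zA-Z0-9_\-]{5,20}:')
first, *rest = raw.split('\n')
content, sender = ('\n'.join(rest), first.strip(':')) if regex.match(first) else (raw, None)
# content == '@bot hello', sender == 'wxid_sbitaz0mt65n22'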
# 是否是群消息
def _is_group_message(self, message: dict) -> bool:
from_user_name = message['from_user_name']['str']
return from_user_name.endswith("@chatroom")
return from_user_name.endswith('@chatroom')
class WeChatPadEventConverter(adapter.EventConverter):
def __init__(self, config: dict):
self.config = config
self.message_converter = WeChatPadMessageConverter(config)
self.logger = logging.getLogger("WeChatPadEventConverter")
self.logger = logging.getLogger('WeChatPadEventConverter')
@staticmethod
async def yiri2target(
event: platform_events.MessageEvent
) -> dict:
async def yiri2target(event: platform_events.MessageEvent) -> dict:
pass
async def target2yiri(
self,
event: dict,
bot_account_id: str
) -> platform_events.MessageEvent:
async def target2yiri(self, event: dict, bot_account_id: str) -> platform_events.MessageEvent:
# 排除公众号以及微信团队消息
if event['from_user_name']['str'].startswith('gh_') \
or event['from_user_name']['str']=='weixin'\
or event['from_user_name']['str'] == "newsapp"\
or event['from_user_name']['str'] == self.config["wxid"]:
if (
event['from_user_name']['str'].startswith('gh_')
or event['from_user_name']['str'] == 'weixin'
or event['from_user_name']['str'] == 'newsapp'
or event['from_user_name']['str'] == self.config['wxid']
):
return None
message_chain = await self.message_converter.target2yiri(copy.deepcopy(event), bot_account_id)
@@ -512,7 +443,7 @@ class WeChatPadEventConverter(adapter.EventConverter):
if '@chatroom' in event['from_user_name']['str']:
# 找出开头的 wxid_ 字符串,以:结尾
sender_wxid = event['content']['str'].split(":")[0]
sender_wxid = event['content']['str'].split(':')[0]
return platform_events.GroupMessage(
sender=platform_entities.GroupMember(
@@ -524,13 +455,13 @@ class WeChatPadEventConverter(adapter.EventConverter):
name=event['from_user_name']['str'],
permission=platform_entities.Permission.Member,
),
special_title="",
special_title='',
join_timestamp=0,
last_speak_timestamp=0,
mute_time_remaining=0,
),
message_chain=message_chain,
time=event["create_time"],
time=event['create_time'],
source_platform_object=event,
)
else:
@@ -541,13 +472,13 @@ class WeChatPadEventConverter(adapter.EventConverter):
remark='',
),
message_chain=message_chain,
time=event["create_time"],
time=event['create_time'],
source_platform_object=event,
)
class WeChatPadAdapter(adapter.MessagePlatformAdapter):
name: str = "WeChatPad" # 定义适配器名称
name: str = 'WeChatPad' # 定义适配器名称
bot: WeChatPadClient
quart_app: quart.Quart
@@ -580,27 +511,21 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
# self.ap.logger.debug(f"Gewechat callback event: {data}")
# print(data)
try:
event = await self.event_converter.target2yiri(data.copy(), self.bot_account_id)
except Exception as e:
await self.logger.error(f"Error in wechatpad callback: {traceback.format_exc()}")
except Exception:
await self.logger.error(f'Error in wechatpad callback: {traceback.format_exc()}')
if event.__class__ in self.listeners:
await self.listeners[event.__class__](event, self)
return 'ok'
async def _handle_message(
self,
message: platform_message.MessageChain,
target_id: str
):
async def _handle_message(self, message: platform_message.MessageChain, target_id: str):
"""统一消息处理核心逻辑"""
content_list = await self.message_converter.yiri2target(message)
# print(content_list)
at_targets = [item["target"] for item in content_list if item["type"] == "at"]
at_targets = [item['target'] for item in content_list if item['type'] == 'at']
# print(at_targets)
# 处理@逻辑
at_targets = at_targets or []
@@ -608,7 +533,7 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
if at_targets:
member_info = self.bot.get_chatroom_member_detail(
target_id,
)["Data"]["member_data"]["chatroom_member_list"]
)['Data']['member_data']['chatroom_member_list']
# 处理消息组件
for msg in content_list:
@@ -616,55 +541,43 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
if msg['type'] == 'text' and at_targets:
at_nick_name_list = []
for member in member_info:
if member["user_name"] in at_targets:
if member['user_name'] in at_targets:
at_nick_name_list.append(f'@{member["nick_name"]}')
msg['content'] = f'{" ".join(at_nick_name_list)} {msg["content"]}'
# 统一消息派发
handler_map = {
'text': lambda msg: self.bot.send_text_message(
to_wxid=target_id,
message=msg['content'],
ats=at_targets
to_wxid=target_id, message=msg['content'], ats=at_targets
),
'image': lambda msg: self.bot.send_image_message(
to_wxid=target_id,
img_url=msg["image"],
ats = at_targets
to_wxid=target_id, img_url=msg['image'], ats=at_targets
),
'WeChatEmoji': lambda msg: self.bot.send_emoji_message(
to_wxid=target_id,
emoji_md5=msg['emoji_md5'],
emoji_size=msg['emoji_size']
to_wxid=target_id, emoji_md5=msg['emoji_md5'], emoji_size=msg['emoji_size']
),
'voice': lambda msg: self.bot.send_voice_message(
to_wxid=target_id,
voice_data=msg['data'],
voice_duration=msg["duration"],
voice_forma=msg["forma"],
voice_duration=msg['duration'],
voice_forma=msg['forma'],
),
'WeChatAppMsg': lambda msg: self.bot.send_app_message(
to_wxid=target_id,
app_message=msg['app_msg'],
type=0,
),
'at': lambda msg: None
'at': lambda msg: None,
}
if handler := handler_map.get(msg['type']):
handler(msg)
# self.ap.logger.warning(f"未处理的消息类型: {ret}")
else:
self.ap.logger.warning(f"未处理的消息类型: {msg['type']}")
self.ap.logger.warning(f'未处理的消息类型: {msg["type"]}')
continue
async def send_message(
self,
target_type: str,
target_id: str,
message: platform_message.MessageChain
):
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
"""主动发送消息"""
return await self._handle_message(message, target_id)
@@ -672,7 +585,7 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
self,
message_source: platform_events.MessageEvent,
message: platform_message.MessageChain,
quote_origin: bool = False
quote_origin: bool = False,
):
"""回复消息"""
if message_source.source_platform_object:
@@ -685,56 +598,47 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
def register_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None]
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
):
self.listeners[event_type] = callback
def unregister_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None]
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
):
pass
async def run_async(self):
if not self.config["admin_key"] and not self.config["token"]:
raise RuntimeError("无wechatpad管理密匙请填入配置文件后重启")
if not self.config['admin_key'] and not self.config['token']:
raise RuntimeError('无wechatpad管理密匙请填入配置文件后重启')
else:
if self.config["token"]:
self.bot = WeChatPadClient(
self.config['wechatpad_url'],
self.config["token"]
)
if self.config['token']:
self.bot = WeChatPadClient(self.config['wechatpad_url'], self.config['token'])
data = self.bot.get_login_status()
self.ap.logger.info(data)
if data["Code"] == 300 and data["Text"] == "你已退出微信":
if data['Code'] == 300 and data['Text'] == '你已退出微信':
response = requests.post(
f"{self.config['wechatpad_url']}/admin/GenAuthKey1?key={self.config['admin_key']}",
json={"Count": 1, "Days": 365}
f'{self.config["wechatpad_url"]}/admin/GenAuthKey1?key={self.config["admin_key"]}',
json={'Count': 1, 'Days': 365},
)
if response.status_code != 200:
raise Exception(f"获取token失败: {response.text}")
self.config["token"] = response.json()["Data"][0]
raise Exception(f'获取token失败: {response.text}')
self.config['token'] = response.json()['Data'][0]
elif not self.config["token"]:
elif not self.config['token']:
response = requests.post(
f"{self.config['wechatpad_url']}/admin/GenAuthKey1?key={self.config['admin_key']}",
json={"Count": 1, "Days": 365}
f'{self.config["wechatpad_url"]}/admin/GenAuthKey1?key={self.config["admin_key"]}',
json={'Count': 1, 'Days': 365},
)
if response.status_code != 200:
raise Exception(f"获取token失败: {response.text}")
self.config["token"] = response.json()["Data"][0]
raise Exception(f'获取token失败: {response.text}')
self.config['token'] = response.json()['Data'][0]
self.bot = WeChatPadClient(
self.config['wechatpad_url'],
self.config["token"],
logger=self.logger
)
self.ap.logger.info(self.config["token"])
self.bot = WeChatPadClient(self.config['wechatpad_url'], self.config['token'], logger=self.logger)
self.ap.logger.info(self.config['token'])
thread_1 = threading.Event()
def wechat_login_process():
# 不登录这些先注释掉避免登陆态尝试拉qrcode。
# login_data =self.bot.get_login_qr()
@@ -742,67 +646,54 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
# url = login_data['Data']["QrCodeUrl"]
# self.ap.logger.info(login_data)
profile = self.bot.get_profile()
self.ap.logger.info(profile)
self.bot_account_id = profile["Data"]["userInfo"]["nickName"]["str"]
self.config["wxid"] = profile["Data"]["userInfo"]["userName"]["str"]
self.bot_account_id = profile['Data']['userInfo']['nickName']['str']
self.config['wxid'] = profile['Data']['userInfo']['userName']['str']
thread_1.set()
# asyncio.create_task(wechat_login_process)
threading.Thread(target=wechat_login_process).start()
def connect_websocket_sync() -> None:
thread_1.wait()
uri = f"{self.config['wechatpad_ws']}/GetSyncMsg?key={self.config['token']}"
self.ap.logger.info(f"Connecting to WebSocket: {uri}")
uri = f'{self.config["wechatpad_ws"]}/GetSyncMsg?key={self.config["token"]}'
self.ap.logger.info(f'Connecting to WebSocket: {uri}')
def on_message(ws, message):
try:
data = json.loads(message)
self.ap.logger.debug(f"Received message: {data}")
self.ap.logger.debug(f'Received message: {data}')
# 这里需要确保ws_message是同步的或者使用asyncio.run调用异步方法
asyncio.run(self.ws_message(data))
except json.JSONDecodeError:
self.ap.logger.error(f"Non-JSON message: {message[:100]}...")
self.ap.logger.error(f'Non-JSON message: {message[:100]}...')
def on_error(ws, error):
self.ap.logger.error(f"WebSocket error: {str(error)[:200]}")
self.ap.logger.error(f'WebSocket error: {str(error)[:200]}')
def on_close(ws, close_status_code, close_msg):
self.ap.logger.info("WebSocket closed, reconnecting...")
self.ap.logger.info('WebSocket closed, reconnecting...')
time.sleep(5)
connect_websocket_sync() # 自动重连
def on_open(ws):
self.ap.logger.info("WebSocket connected successfully!")
self.ap.logger.info('WebSocket connected successfully!')
ws = websocket.WebSocketApp(
uri,
on_message=on_message,
on_error=on_error,
on_close=on_close,
on_open=on_open
)
ws.run_forever(
ping_interval=60,
ping_timeout=20
uri, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open
)
ws.run_forever(ping_interval=60, ping_timeout=20)
# 直接调用同步版本(会阻塞)
# connect_websocket_sync()
# 这行代码会在WebSocket连接断开后才会执行
# self.ap.logger.info("WebSocket client thread started")
thread = threading.Thread(
target=connect_websocket_sync,
name="WebSocketClientThread",
daemon=True
)
thread = threading.Thread(target=connect_websocket_sync, name='WebSocketClientThread', daemon=True)
thread.start()
self.ap.logger.info("WebSocket client thread started")
self.ap.logger.info('WebSocket client thread started')
async def kill(self) -> bool:
pass

View File

@@ -157,7 +157,7 @@ class WecomAdapter(adapter.MessagePlatformAdapter):
token=config['token'],
EncodingAESKey=config['EncodingAESKey'],
contacts_secret=config['contacts_secret'],
logger=self.logger
logger=self.logger,
)
async def reply_message(
@@ -201,8 +201,8 @@ class WecomAdapter(adapter.MessagePlatformAdapter):
self.bot_account_id = event.receiver_id
try:
return await callback(await self.event_converter.target2yiri(event), self)
except Exception as e:
await self.logger.error(f"Error in wecom callback: {traceback.format_exc()}")
except Exception:
await self.logger.error(f'Error in wecom callback: {traceback.format_exc()}')
if event_type == platform_events.FriendMessage:
self.bot.on_message('text')(on_message)

View File

@@ -145,7 +145,7 @@ class WecomCSAdapter(adapter.MessagePlatformAdapter):
secret=config['secret'],
token=config['token'],
EncodingAESKey=config['EncodingAESKey'],
logger=self.logger
logger=self.logger,
)
async def reply_message(
@@ -178,8 +178,8 @@ class WecomCSAdapter(adapter.MessagePlatformAdapter):
self.bot_account_id = event.receiver_id
try:
return await callback(await self.event_converter.target2yiri(event), self)
except Exception as e:
await self.logger.error(f"Error in wecomcs callback: {traceback.format_exc()}")
except Exception:
await self.logger.error(f'Error in wecomcs callback: {traceback.format_exc()}')
if event_type == platform_events.FriendMessage:
self.bot.on_message('text')(on_message)

View File

@@ -1,19 +1,20 @@
from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, ForeignKey, LargeBinary
from sqlalchemy.orm import declarative_base, sessionmaker, relationship
from datetime import datetime
import numpy as np # 用于处理从LargeBinary转换回来的embedding
Base = declarative_base()
class KnowledgeBase(Base):
__tablename__ = 'kb'
id = Column(Integer, primary_key=True, index=True)
name = Column(String, index=True)
description = Column(Text)
created_at = Column(DateTime, default=datetime.utcnow)
embedding_model = Column(String, default="") # 默认嵌入模型
embedding_model = Column(String, default='') # 默认嵌入模型
top_k = Column(Integer, default=5) # 默认返回的top_k数量
files = relationship("File", back_populates="knowledge_base")
files = relationship('File', back_populates='knowledge_base')
class File(Base):
__tablename__ = 'file'
@@ -24,8 +25,9 @@ class File(Base):
created_at = Column(DateTime, default=datetime.utcnow)
file_type = Column(String)
status = Column(Integer, default=0) # 0: 未处理, 1: 处理中, 2: 已处理, 3: 错误
knowledge_base = relationship("KnowledgeBase", back_populates="files")
chunks = relationship("Chunk", back_populates="file")
knowledge_base = relationship('KnowledgeBase', back_populates='files')
chunks = relationship('Chunk', back_populates='file')
class Chunk(Base):
__tablename__ = 'chunks'
@@ -33,8 +35,9 @@ class Chunk(Base):
file_id = Column(Integer, ForeignKey('file.id'))
text = Column(Text)
file = relationship("File", back_populates="chunks")
vector = relationship("Vector", uselist=False, back_populates="chunk") # One-to-one
file = relationship('File', back_populates='chunks')
vector = relationship('Vector', uselist=False, back_populates='chunk') # One-to-one
class Vector(Base):
__tablename__ = 'vectors'
@@ -42,17 +45,20 @@ class Vector(Base):
chunk_id = Column(Integer, ForeignKey('chunks.id'), unique=True)
embedding = Column(LargeBinary) # Store embeddings as binary
chunk = relationship("Chunk", back_populates="vector")
chunk = relationship('Chunk', back_populates='vector')
# 数据库连接
DATABASE_URL = "sqlite:///./knowledge_base.db" # 生产环境请更换为 PostgreSQL/MySQL
engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False} if "sqlite" in DATABASE_URL else {})
DATABASE_URL = 'sqlite:///./knowledge_base.db' # 生产环境请更换为 PostgreSQL/MySQL
engine = create_engine(DATABASE_URL, connect_args={'check_same_thread': False} if 'sqlite' in DATABASE_URL else {})
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# Create all tables (can be executed once at application startup)
def create_db_and_tables():
Base.metadata.create_all(bind=engine)
print("Database tables created/checked.")
print('Database tables created/checked.')
# Define the embedding dimension (adjust according to the model you actually use)
EMBEDDING_DIM = 1024
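The file above defines the knowledge-base schema (KnowledgeBase → File → Chunk → Vector) plus the engine and session setup. A minimal usage sketch, not part of this commit; the import path is hypothetical and only columns shown in the diff are used:

from pkg.rag.knowledge.services.database import (  # hypothetical module path
    KnowledgeBase,
    SessionLocal,
    create_db_and_tables,
)

create_db_and_tables()  # create_all is idempotent: it only adds missing tables

session = SessionLocal()
try:
    kb = KnowledgeBase(
        name='demo-kb',
        description='sketch knowledge base',
        embedding_model='bge-m3',  # matches a key in EMBEDDING_MODEL_CONFIGS
        top_k=5,
    )
    session.add(kb)
    session.commit()
    print(kb.id, kb.created_at)
finally:
    session.close()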

View File

@@ -1,7 +1,7 @@
# services/embedding_models.py
import os
from typing import Dict, Any, List, Type, Optional
from typing import Dict, Any, List
import logging
import aiohttp # Import aiohttp for asynchronous requests
import asyncio
@@ -9,6 +9,7 @@ from sentence_transformers import SentenceTransformer
logger = logging.getLogger(__name__)
# Base class for all embedding models
class BaseEmbeddingModel:
def __init__(self, model_name: str):
@@ -27,9 +28,10 @@ class BaseEmbeddingModel:
def embedding_dimension(self) -> int:
"""Returns the embedding dimension of the model."""
if self._embedding_dimension is None:
raise NotImplementedError("Embedding dimension not set for this model.")
raise NotImplementedError('Embedding dimension not set for this model.')
return self._embedding_dimension
class EmbeddingModelFactory:
@staticmethod
def create_model(model_type: str, model_name_key: str) -> BaseEmbeddingModel:
@@ -42,10 +44,12 @@ class EmbeddingModelFactory:
config = EMBEDDING_MODEL_CONFIGS[model_name_key]
if config['type'] == "third_party_api":
if config['type'] == 'third_party_api':
required_keys = ['api_endpoint', 'headers', 'payload_template', 'embedding_dimension']
if not all(key in config for key in required_keys):
raise ValueError(f"Missing configuration keys for third_party_api model '{model_name_key}'. Required: {required_keys}")
raise ValueError(
f"Missing configuration keys for third_party_api model '{model_name_key}'. Required: {required_keys}"
)
# Retrieve model_name from config if it differs from model_name_key
# Some APIs expect a specific 'model' value in the payload that might be different from the key
@@ -56,9 +60,10 @@ class EmbeddingModelFactory:
api_endpoint=config['api_endpoint'],
headers=config['headers'],
payload_template=config['payload_template'],
embedding_dimension=config['embedding_dimension']
embedding_dimension=config['embedding_dimension'],
)
class SentenceTransformerEmbeddingModel(BaseEmbeddingModel):
def __init__(self, model_name: str):
super().__init__(model_name)
@@ -68,9 +73,11 @@ class SentenceTransformerEmbeddingModel(BaseEmbeddingModel):
# if not run in a separate thread/process, but this keeps the API consistent.
self.model = SentenceTransformer(model_name)
self._embedding_dimension = self.model.get_sentence_embedding_dimension()
logger.info(f"Initialized SentenceTransformer model '{model_name}' with dimension {self._embedding_dimension}")
logger.info(
f"Initialized SentenceTransformer model '{model_name}' with dimension {self._embedding_dimension}"
)
except Exception as e:
logger.error(f"Failed to load SentenceTransformer model {model_name}: {e}")
logger.error(f'Failed to load SentenceTransformer model {model_name}: {e}')
raise
async def embed_documents(self, texts: List[str]) -> List[List[float]]:
@@ -84,14 +91,23 @@ class SentenceTransformerEmbeddingModel(BaseEmbeddingModel):
class ThirdPartyAPIEmbeddingModel(BaseEmbeddingModel):
def __init__(self, model_name: str, api_endpoint: str, headers: Dict[str, str], payload_template: Dict[str, Any], embedding_dimension: int):
def __init__(
self,
model_name: str,
api_endpoint: str,
headers: Dict[str, str],
payload_template: Dict[str, Any],
embedding_dimension: int,
):
super().__init__(model_name)
self.api_endpoint = api_endpoint
self.headers = headers
self.payload_template = payload_template
self._embedding_dimension = embedding_dimension
self.session = None # aiohttp client session will be initialized on first use or in a context manager
logger.info(f"Initialized ThirdPartyAPIEmbeddingModel '{model_name}' for async calls to {api_endpoint} with dimension {embedding_dimension}")
logger.info(
f"Initialized ThirdPartyAPIEmbeddingModel '{model_name}' for async calls to {api_endpoint} with dimension {embedding_dimension}"
)
async def _get_session(self):
"""Lazily create or return the aiohttp client session."""
@@ -104,7 +120,7 @@ class ThirdPartyAPIEmbeddingModel(BaseEmbeddingModel):
if self.session and not self.session.closed:
await self.session.close()
self.session = None
logger.info(f"Closed aiohttp session for model {self.model_name}")
logger.info(f'Closed aiohttp session for model {self.model_name}')
async def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Asynchronously embeds a list of texts using the third-party API."""
@@ -118,7 +134,7 @@ class ThirdPartyAPIEmbeddingModel(BaseEmbeddingModel):
elif 'texts' in payload:
payload['texts'] = [text]
else:
raise ValueError("Payload template does not contain expected text input key.")
raise ValueError('Payload template does not contain expected text input key.')
tasks.append(self._make_api_request(session, payload))
@@ -145,24 +161,31 @@ class ThirdPartyAPIEmbeddingModel(BaseEmbeddingModel):
api_response = await response.json()
# Adjust this based on your API's actual response structure
if "data" in api_response and len(api_response["data"]) > 0 and "embedding" in api_response["data"][0]:
embedding = api_response["data"][0]["embedding"]
if 'data' in api_response and len(api_response['data']) > 0 and 'embedding' in api_response['data'][0]:
embedding = api_response['data'][0]['embedding']
if len(embedding) != self.embedding_dimension:
logger.warning(f"API returned embedding of dimension {len(embedding)}, but expected {self.embedding_dimension} for model {self.model_name}. Adjusting config might be needed.")
logger.warning(
f'API returned embedding of dimension {len(embedding)}, but expected {self.embedding_dimension} for model {self.model_name}. Adjusting config might be needed.'
)
return embedding
elif "embeddings" in api_response and isinstance(api_response["embeddings"], list) and api_response["embeddings"]:
embedding = api_response["embeddings"][0]
elif (
'embeddings' in api_response
and isinstance(api_response['embeddings'], list)
and api_response['embeddings']
):
embedding = api_response['embeddings'][0]
if len(embedding) != self.embedding_dimension:
logger.warning(f"API returned embedding of dimension {len(embedding)}, but expected {self.embedding_dimension} for model {self.model_name}. Adjusting config might be needed.")
logger.warning(
f'API returned embedding of dimension {len(embedding)}, but expected {self.embedding_dimension} for model {self.model_name}. Adjusting config might be needed.'
)
return embedding
else:
raise ValueError(f"Unexpected API response structure: {api_response}")
raise ValueError(f'Unexpected API response structure: {api_response}')
except aiohttp.ClientError as e:
raise ConnectionError(f"API request failed: {e}") from e
raise ConnectionError(f'API request failed: {e}') from e
except ValueError as e:
raise ValueError(f"Error processing API response: {e}") from e
raise ValueError(f'Error processing API response: {e}') from e
async def embed_query(self, text: str) -> List[float]:
"""Asynchronously embeds a single query text."""
@@ -171,53 +194,45 @@ class ThirdPartyAPIEmbeddingModel(BaseEmbeddingModel):
return results[0]
return [] # Or raise an error if embedding a query must always succeed
# --- Embedding Model Configuration ---
EMBEDDING_MODEL_CONFIGS: Dict[str, Dict[str, Any]] = {
"MiniLM": { # Example for a local Sentence Transformer model
"type": "sentence_transformer",
"model_name": "sentence-transformers/all-MiniLM-L6-v2"
'MiniLM': { # Example for a local Sentence Transformer model
'type': 'sentence_transformer',
'model_name': 'sentence-transformers/all-MiniLM-L6-v2',
},
"bge-m3": { # Example for a third-party API model
"type": "third_party_api",
"model_name": "bge-m3",
"api_endpoint": "https://api.qhaigc.net/v1/embeddings",
"headers": {
"Content-Type": "application/json",
"Authorization": f"Bearer {os.getenv('rag_api_key')}"
'bge-m3': { # Example for a third-party API model
'type': 'third_party_api',
'model_name': 'bge-m3',
'api_endpoint': 'https://api.qhaigc.net/v1/embeddings',
'headers': {'Content-Type': 'application/json', 'Authorization': f'Bearer {os.getenv("rag_api_key")}'},
'payload_template': {'model': 'bge-m3', 'input': ''},
'embedding_dimension': 1024,
},
"payload_template": {
"model": "bge-m3",
"input": ""
'OpenAI-Ada-002': {
'type': 'third_party_api',
'model_name': 'text-embedding-ada-002',
'api_endpoint': 'https://api.openai.com/v1/embeddings',
'headers': {
'Content-Type': 'application/json',
'Authorization': f'Bearer {os.getenv("OPENAI_API_KEY")}', # Ensure OPENAI_API_KEY is set
},
"embedding_dimension": 1024
'payload_template': {
'model': 'text-embedding-ada-002',
'input': '', # Text will be injected here
},
"OpenAI-Ada-002": {
"type": "third_party_api",
"model_name": "text-embedding-ada-002",
"api_endpoint": "https://api.openai.com/v1/embeddings",
"headers": {
"Content-Type": "application/json",
"Authorization": f"Bearer {os.getenv('OPENAI_API_KEY')}" # Ensure OPENAI_API_KEY is set
'embedding_dimension': 1536,
},
"payload_template": {
"model": "text-embedding-ada-002",
"input": "" # Text will be injected here
},
"embedding_dimension": 1536
},
"OpenAI-Embedding-3-Small": {
"type": "third_party_api",
"model_name": "text-embedding-3-small",
"api_endpoint": "https://api.openai.com/v1/embeddings",
"headers": {
"Content-Type": "application/json",
"Authorization": f"Bearer {os.getenv('OPENAI_API_KEY')}"
},
"payload_template": {
"model": "text-embedding-3-small",
"input": "",
'OpenAI-Embedding-3-Small': {
'type': 'third_party_api',
'model_name': 'text-embedding-3-small',
'api_endpoint': 'https://api.openai.com/v1/embeddings',
'headers': {'Content-Type': 'application/json', 'Authorization': f'Bearer {os.getenv("OPENAI_API_KEY")}'},
'payload_template': {
'model': 'text-embedding-3-small',
'input': '',
# "dimensions": 512 # Optional: uncomment if you want a specific output dimension
},
"embedding_dimension": 1536 # Default max dimension for text-embedding-3-small
'embedding_dimension': 1536, # Default max dimension for text-embedding-3-small
},
}
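EMBEDDING_MODEL_CONFIGS maps a key either to a local SentenceTransformer model or to a third-party HTTP embeddings API, and EmbeddingModelFactory.create_model builds the matching wrapper. A minimal async usage sketch, not part of this commit; the import path is hypothetical, and the bge-m3 call needs network access plus the rag_api_key environment variable:

import asyncio

from services.embedding_models import EmbeddingModelFactory  # hypothetical path


async def main():
    model = EmbeddingModelFactory.create_model(model_type='third_party_api', model_name_key='bge-m3')
    docs = ['retrieval augmented generation', 'embedding model configuration']
    vectors = await model.embed_documents(docs)
    print(len(vectors), len(vectors[0]))  # expected: 2 vectors of dimension 1024
    query_vector = await model.embed_query('how are embeddings configured?')
    print(len(query_vector))
    # The wrapper keeps an aiohttp session open; the hunk above shows a cleanup
    # coroutine that closes it, though its name is not visible in this diff.


if __name__ == '__main__':
    asyncio.run(main())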

View File

@@ -1,10 +1,8 @@
import PyPDF2
from docx import Document
import pandas as pd
import csv
import chardet
from typing import Union, List, Callable, Any
from typing import Union, Callable, Any
import logging
import markdown
from bs4 import BeautifulSoup
@@ -17,6 +15,7 @@ import os
# Configure logging
logger = logging.getLogger(__name__)
class FileParser:
"""
A robust file parser class to extract text content from various document formats.
@@ -24,8 +23,8 @@ class FileParser:
All core file reading operations are designed to be run synchronously in a thread pool
to avoid blocking the asyncio event loop.
"""
def __init__(self):
def __init__(self):
self.logger = logging.getLogger(self.__class__.__name__)
async def _run_sync(self, sync_func: Callable, *args: Any, **kwargs: Any) -> Any:
@@ -36,7 +35,7 @@ class FileParser:
try:
return await asyncio.to_thread(sync_func, *args, **kwargs)
except Exception as e:
self.logger.error(f"Error running synchronous function {sync_func.__name__}: {e}")
self.logger.error(f'Error running synchronous function {sync_func.__name__}: {e}')
raise
async def parse(self, file_path: str) -> Union[str, None]:
@@ -51,21 +50,21 @@ class FileParser:
Union[str, None]: The extracted text content as a single string, or None if parsing fails.
"""
if not file_path or not os.path.exists(file_path):
self.logger.error(f"Invalid file path provided: {file_path}")
self.logger.error(f'Invalid file path provided: {file_path}')
return None
file_extension = file_path.split('.')[-1].lower()
parser_method = getattr(self, f'_parse_{file_extension}', None)
if parser_method is None:
self.logger.error(f"Unsupported file format: {file_extension} for file {file_path}")
self.logger.error(f'Unsupported file format: {file_extension} for file {file_path}')
return None
try:
# Pass file_path to the specific parser methods
return await parser_method(file_path)
except Exception as e:
self.logger.error(f"Failed to parse {file_extension} file {file_path}: {e}")
self.logger.error(f'Failed to parse {file_extension} file {file_path}: {e}')
return None
# --- Helper for reading files with encoding detection ---
@@ -74,6 +73,7 @@ class FileParser:
Reads a file with automatic encoding detection, ensuring the synchronous
file read operation runs in a separate thread.
"""
def _read_sync():
with open(file_path, 'rb') as file:
raw_data = file.read()
@@ -90,12 +90,13 @@ class FileParser:
async def _parse_txt(self, file_path: str) -> str:
"""Parses a TXT file and returns its content."""
self.logger.info(f"Parsing TXT file: {file_path}")
self.logger.info(f'Parsing TXT file: {file_path}')
return await self._read_file_content(file_path, mode='r')
async def _parse_pdf(self, file_path: str) -> str:
"""Parses a PDF file and returns its text content."""
self.logger.info(f"Parsing PDF file: {file_path}")
self.logger.info(f'Parsing PDF file: {file_path}')
def _parse_pdf_sync():
text_content = []
with open(file_path, 'rb') as file:
@@ -105,57 +106,69 @@ class FileParser:
if text:
text_content.append(text)
return '\n'.join(text_content)
return await self._run_sync(_parse_pdf_sync)
async def _parse_docx(self, file_path: str) -> str:
"""Parses a DOCX file and returns its text content."""
self.logger.info(f"Parsing DOCX file: {file_path}")
self.logger.info(f'Parsing DOCX file: {file_path}')
def _parse_docx_sync():
doc = Document(file_path)
text_content = [paragraph.text for paragraph in doc.paragraphs if paragraph.text.strip()]
return '\n'.join(text_content)
return await self._run_sync(_parse_docx_sync)
async def _parse_doc(self, file_path: str) -> str:
"""Handles .doc files, explicitly stating lack of direct support."""
self.logger.warning(f"Direct .doc parsing is not supported for {file_path}. Please convert to .docx first.")
raise NotImplementedError("Direct .doc parsing not supported. Please convert to .docx first.")
self.logger.warning(f'Direct .doc parsing is not supported for {file_path}. Please convert to .docx first.')
raise NotImplementedError('Direct .doc parsing not supported. Please convert to .docx first.')
async def _parse_xlsx(self, file_path: str) -> str:
"""Parses an XLSX file, returning text from all sheets."""
self.logger.info(f"Parsing XLSX file: {file_path}")
self.logger.info(f'Parsing XLSX file: {file_path}')
def _parse_xlsx_sync():
excel_file = pd.ExcelFile(file_path)
all_sheet_content = []
for sheet_name in excel_file.sheet_names:
df = pd.read_excel(file_path, sheet_name=sheet_name)
sheet_text = f"--- Sheet: {sheet_name} ---\n{df.to_string(index=False)}\n"
sheet_text = f'--- Sheet: {sheet_name} ---\n{df.to_string(index=False)}\n'
all_sheet_content.append(sheet_text)
return '\n'.join(all_sheet_content)
return await self._run_sync(_parse_xlsx_sync)
async def _parse_csv(self, file_path: str) -> str:
"""Parses a CSV file and returns its content as a string."""
self.logger.info(f"Parsing CSV file: {file_path}")
self.logger.info(f'Parsing CSV file: {file_path}')
def _parse_csv_sync():
# pd.read_csv can often detect encoding, but explicit detection is safer
raw_data = self._read_file_content(file_path, mode='rb') # Note: this would need to be awaited outside this sync function
raw_data = self._read_file_content(
file_path, mode='rb'
) # Note: this would need to be awaited outside this sync function
_ = raw_data
# For simplicity, we'll let pandas handle encoding internally after a raw read.
# A more robust solution might pass encoding directly to pd.read_csv after detection.
detected = chardet.detect(open(file_path, 'rb').read())
encoding = detected['encoding'] or 'utf-8'
df = pd.read_csv(file_path, encoding=encoding)
return df.to_string(index=False)
return await self._run_sync(_parse_csv_sync)
async def _parse_markdown(self, file_path: str) -> str:
"""Parses a Markdown file, converting it to structured plain text."""
self.logger.info(f"Parsing Markdown file: {file_path}")
self.logger.info(f'Parsing Markdown file: {file_path}')
def _parse_markdown_sync():
md_content = self._read_file_content(file_path, mode='r') # This is a synchronous call within a sync function
md_content = self._read_file_content(
file_path, mode='r'
) # This is a synchronous call within a sync function
html_content = markdown.markdown(
md_content,
extensions=['extra', 'codehilite', 'tables', 'toc', 'fenced_code']
md_content, extensions=['extra', 'codehilite', 'tables', 'toc', 'fenced_code']
)
soup = BeautifulSoup(html_content, 'html.parser')
text_parts = []
@@ -169,11 +182,11 @@ class FileParser:
text_parts.append(text)
elif element.name in ['ul', 'ol']:
for li in element.find_all('li'):
text_parts.append(f"* {li.get_text().strip()}")
text_parts.append(f'* {li.get_text().strip()}')
elif element.name == 'pre':
code_block = element.get_text().strip()
if code_block:
text_parts.append(f"```\n{code_block}\n```")
text_parts.append(f'```\n{code_block}\n```')
elif element.name == 'table':
table_str = self._extract_table_to_markdown_sync(element) # Call sync helper
if table_str:
@@ -184,15 +197,17 @@ class FileParser:
text_parts.append(text)
cleaned_text = re.sub(r'\n\s*\n', '\n\n', '\n'.join(text_parts))
return cleaned_text.strip()
return await self._run_sync(_parse_markdown_sync)
async def _parse_html(self, file_path: str) -> str:
"""Parses an HTML file, extracting structured plain text."""
self.logger.info(f"Parsing HTML file: {file_path}")
self.logger.info(f'Parsing HTML file: {file_path}')
def _parse_html_sync():
html_content = self._read_file_content(file_path, mode='r') # Sync call within sync function
soup = BeautifulSoup(html_content, 'html.parser')
for script_or_style in soup(["script", "style"]):
for script_or_style in soup(['script', 'style']):
script_or_style.decompose()
text_parts = []
for element in soup.body.children if soup.body else soup.children:
@@ -207,7 +222,7 @@ class FileParser:
for li in element.find_all('li'):
text = li.get_text().strip()
if text:
text_parts.append(f"* {text}")
text_parts.append(f'* {text}')
elif element.name == 'table':
table_str = self._extract_table_to_markdown_sync(element) # Call sync helper
if table_str:
@@ -218,39 +233,42 @@ class FileParser:
text_parts.append(text)
cleaned_text = re.sub(r'\n\s*\n', '\n\n', '\n'.join(text_parts))
return cleaned_text.strip()
return await self._run_sync(_parse_html_sync)
async def _parse_epub(self, file_path: str) -> str:
"""Parses an EPUB file, extracting metadata and content."""
self.logger.info(f"Parsing EPUB file: {file_path}")
self.logger.info(f'Parsing EPUB file: {file_path}')
def _parse_epub_sync():
book = epub.read_epub(file_path)
text_content = []
title_meta = book.get_metadata('DC', 'title')
if title_meta:
text_content.append(f"Title: {title_meta[0][0]}")
text_content.append(f'Title: {title_meta[0][0]}')
creator_meta = book.get_metadata('DC', 'creator')
if creator_meta:
text_content.append(f"Author: {creator_meta[0][0]}")
text_content.append(f'Author: {creator_meta[0][0]}')
date_meta = book.get_metadata('DC', 'date')
if date_meta:
text_content.append(f"Publish Date: {date_meta[0][0]}")
text_content.append(f'Publish Date: {date_meta[0][0]}')
toc = book.get_toc()
if toc:
text_content.append("\n--- Table of Contents ---")
text_content.append('\n--- Table of Contents ---')
self._add_toc_items_sync(toc, text_content, level=0) # Call sync helper
text_content.append("--- End of Table of Contents ---\n")
text_content.append('--- End of Table of Contents ---\n')
for item in book.get_items():
if item.get_type() == ebooklib.ITEM_DOCUMENT:
html_content = item.get_content().decode('utf-8', errors='ignore')
soup = BeautifulSoup(html_content, 'html.parser')
for junk in soup(["script", "style", "nav", "header", "footer"]):
for junk in soup(['script', 'style', 'nav', 'header', 'footer']):
junk.decompose()
text = soup.get_text(separator='\n', strip=True)
text = re.sub(r'\n\s*\n', '\n\n', text)
if text:
text_content.append(text)
return re.sub(r'\n\s*\n', '\n\n', '\n'.join(text_content)).strip()
return await self._run_sync(_parse_epub_sync)
def _add_toc_items_sync(self, toc_list: list, text_content: list, level: int):
@@ -259,10 +277,10 @@ class FileParser:
for item in toc_list:
if isinstance(item, tuple):
chapter, subchapters = item
text_content.append(f"{indent}- {chapter.title}")
text_content.append(f'{indent}- {chapter.title}')
self._add_toc_items_sync(subchapters, text_content, level + 1)
else:
text_content.append(f"{indent}- {item.title}")
text_content.append(f'{indent}- {item.title}')
def _extract_table_to_markdown_sync(self, table_element: BeautifulSoup) -> str:
"""Helper to convert a BeautifulSoup table element into a Markdown table string (synchronous)."""
@@ -274,7 +292,7 @@ class FileParser:
rows.append(cells)
if not headers and not rows:
return ""
return ''
table_lines = []
if headers:
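FileParser.parse dispatches on the file extension to a _parse_<ext> coroutine and pushes the blocking readers into a thread via asyncio.to_thread, returning None for unsupported or unreadable input. A minimal usage sketch, not part of this commit, with a hypothetical import path and file names:

import asyncio

from services.parser import FileParser  # hypothetical path


async def main():
    parser = FileParser()
    for path in ['notes.md', 'report.pdf', 'legacy.doc']:  # hypothetical files
        text = await parser.parse(path)
        if text is None:
            # Missing file, unsupported extension, or a parser error; .doc, for
            # example, raises NotImplementedError, which parse() logs and maps to None.
            print(f'skipped {path}')
        else:
            print(f'{path}: {len(text)} characters extracted')


if __name__ == '__main__':
    asyncio.run(main())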

View File

@@ -1,5 +1,4 @@
# services/retriever.py
import asyncio
import logging
import numpy as np # Make sure numpy is imported
from typing import List, Dict, Any
@@ -11,6 +10,7 @@ from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager
logger = logging.getLogger(__name__)
class Retriever(BaseService):
def __init__(self, model_type: str, model_name_key: str, chroma_manager: ChromaIndexManager):
super().__init__()
@@ -22,10 +22,14 @@ class Retriever(BaseService):
self.embedding_model: BaseEmbeddingModel = self._load_embedding_model()
def _load_embedding_model(self) -> BaseEmbeddingModel:
self.logger.info(f"Loading retriever embedding model: type={self.model_type}, name_key={self.model_name_key}...")
self.logger.info(
f'Loading retriever embedding model: type={self.model_type}, name_key={self.model_name_key}...'
)
try:
model = EmbeddingModelFactory.create_model(self.model_type, self.model_name_key)
self.logger.info(f"Retriever embedding model '{self.model_name_key}' loaded. Output dimension: {model.embedding_dimension}")
self.logger.info(
f"Retriever embedding model '{self.model_name_key}' loaded. Output dimension: {model.embedding_dimension}"
)
return model
except Exception as e:
self.logger.error(f"Failed to load retriever embedding model '{self.model_name_key}': {e}")
@@ -33,43 +37,42 @@ class Retriever(BaseService):
async def retrieve(self, query: str, k: int = 5) -> List[Dict[str, Any]]:
if not self.embedding_model:
raise RuntimeError("Retriever embedding model not loaded. Please check Retriever initialization.")
raise RuntimeError('Retriever embedding model not loaded. Please check Retriever initialization.')
self.logger.info(f"Retrieving for query: '{query}' with k={k} using {self.model_name_key}")
query_embedding: List[float] = await self.embedding_model.embed_query(query)
query_embedding_np = np.array([query_embedding], dtype=np.float32)
chroma_results = await self._run_sync(
self.chroma_manager.search_sync,
query_embedding_np, k
)
chroma_results = await self._run_sync(self.chroma_manager.search_sync, query_embedding_np, k)
# 'ids' is always returned by ChromaDB, even if not explicitly in 'include'
matched_chroma_ids = chroma_results.get("ids", [[]])[0]
distances = chroma_results.get("distances", [[]])[0]
chroma_metadatas = chroma_results.get("metadatas", [[]])[0]
chroma_documents = chroma_results.get("documents", [[]])[0]
matched_chroma_ids = chroma_results.get('ids', [[]])[0]
distances = chroma_results.get('distances', [[]])[0]
chroma_metadatas = chroma_results.get('metadatas', [[]])[0]
chroma_documents = chroma_results.get('documents', [[]])[0]
if not matched_chroma_ids:
self.logger.info("No relevant chunks found in Chroma.")
self.logger.info('No relevant chunks found in Chroma.')
return []
db_chunk_ids = []
for metadata in chroma_metadatas:
if "chunk_id" in metadata:
db_chunk_ids.append(metadata["chunk_id"])
if 'chunk_id' in metadata:
db_chunk_ids.append(metadata['chunk_id'])
else:
self.logger.warning(f"Metadata missing 'chunk_id': {metadata}. Skipping this entry.")
if not db_chunk_ids:
self.logger.warning("No valid chunk_ids extracted from Chroma results metadata.")
self.logger.warning('No valid chunk_ids extracted from Chroma results metadata.')
return []
self.logger.info(f"Fetching {len(db_chunk_ids)} chunk details from relational database...")
self.logger.info(f'Fetching {len(db_chunk_ids)} chunk details from relational database...')
chunks_from_db = await self._run_sync(
lambda cids: self._db_get_chunks_sync(SessionLocal(), cids), # Ensure SessionLocal is passed correctly for _db_get_chunks_sync
db_chunk_ids
lambda cids: self._db_get_chunks_sync(
SessionLocal(), cids
), # Ensure SessionLocal is passed correctly for _db_get_chunks_sync
db_chunk_ids,
)
chunk_map = {chunk.id: chunk for chunk in chunks_from_db}
@@ -80,27 +83,29 @@ class Retriever(BaseService):
# Ensure original_chunk_id is int for DB lookup
original_chunk_id = int(chroma_id.split('_')[-1])
except (ValueError, IndexError):
self.logger.warning(f"Could not parse chunk_id from Chroma ID: {chroma_id}. Skipping.")
self.logger.warning(f'Could not parse chunk_id from Chroma ID: {chroma_id}. Skipping.')
continue
chunk_text_from_chroma = chroma_documents[i]
distance = float(distances[i])
file_id_from_chroma = chroma_metadatas[i].get("file_id")
file_id_from_chroma = chroma_metadatas[i].get('file_id')
chunk_from_db = chunk_map.get(original_chunk_id)
results_list.append({
"chunk_id": original_chunk_id,
"text": chunk_from_db.text if chunk_from_db else chunk_text_from_chroma,
"distance": distance,
"file_id": file_id_from_chroma
})
results_list.append(
{
'chunk_id': original_chunk_id,
'text': chunk_from_db.text if chunk_from_db else chunk_text_from_chroma,
'distance': distance,
'file_id': file_id_from_chroma,
}
)
self.logger.info(f"Retrieved {len(results_list)} chunks.")
self.logger.info(f'Retrieved {len(results_list)} chunks.')
return results_list
def _db_get_chunks_sync(self, session: Session, chunk_ids: List[int]) -> List[Chunk]:
self.logger.debug(f"Fetching {len(chunk_ids)} chunk details from database (sync).")
self.logger.debug(f'Fetching {len(chunk_ids)} chunk details from database (sync).')
chunks = session.query(Chunk).filter(Chunk.id.in_(chunk_ids)).all()
session.close()
return chunks
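Retriever.retrieve embeds the query, searches Chroma, maps the chunk_id metadata back to rows in the relational database, and returns dicts with chunk_id, text, distance, and file_id. A minimal usage sketch, not part of this commit; the Retriever import path and the ChromaIndexManager constructor arguments are assumptions:

import asyncio

from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager  # path as imported above
from pkg.rag.knowledge.services.retriever import Retriever  # hypothetical path


async def main():
    chroma = ChromaIndexManager()  # constructor arguments are not shown in this diff
    retriever = Retriever(model_type='third_party_api', model_name_key='bge-m3', chroma_manager=chroma)
    hits = await retriever.retrieve('how is the embedding model configured?', k=5)
    for hit in hits:
        print(hit['chunk_id'], round(hit['distance'], 4), hit['text'][:80])


if __name__ == '__main__':
    asyncio.run(main())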