From 8d28ace25276820714d33c5aedf359b48d0faf3e Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Sat, 5 Jul 2025 21:56:54 +0800 Subject: [PATCH] perf: ruff check --fix --- libs/wechatpad_api/__init__.py | 2 +- libs/wechatpad_api/api/chatroom.py | 6 +- libs/wechatpad_api/api/downloadpai.py | 25 +- libs/wechatpad_api/api/friend.py | 5 - libs/wechatpad_api/api/login.py | 60 +-- libs/wechatpad_api/api/message.py | 111 ++--- libs/wechatpad_api/util/http_util.py | 48 +- pkg/entity/persistence/vector.py | 11 +- pkg/platform/sources/aiocqhttp.py | 229 ++++++---- pkg/platform/sources/discord.py | 20 +- pkg/platform/sources/lark.py | 8 +- pkg/platform/sources/nakuru.py | 5 +- pkg/platform/sources/officialaccount.py | 4 +- pkg/platform/sources/qqofficial.py | 9 +- pkg/platform/sources/slack.py | 8 +- pkg/platform/sources/telegram.py | 4 +- pkg/platform/sources/wechatpad.py | 425 +++++++----------- pkg/platform/sources/wecom.py | 6 +- pkg/platform/sources/wecomcs.py | 6 +- pkg/rag/knowledge/services/database.py | 32 +- .../knowledge/services/embedding_models.py | 165 +++---- pkg/rag/knowledge/services/parser.py | 128 +++--- pkg/rag/knowledge/services/retriever.py | 67 +-- 23 files changed, 647 insertions(+), 737 deletions(-) diff --git a/libs/wechatpad_api/__init__.py b/libs/wechatpad_api/__init__.py index 23c23fb2..9ac533f7 100644 --- a/libs/wechatpad_api/__init__.py +++ b/libs/wechatpad_api/__init__.py @@ -1 +1 @@ -from .client import WeChatPadClient \ No newline at end of file +from .client import WeChatPadClient as WeChatPadClient diff --git a/libs/wechatpad_api/api/chatroom.py b/libs/wechatpad_api/api/chatroom.py index a7af207c..2d9281a2 100644 --- a/libs/wechatpad_api/api/chatroom.py +++ b/libs/wechatpad_api/api/chatroom.py @@ -1,4 +1,4 @@ -from libs.wechatpad_api.util.http_util import async_request, post_json +from libs.wechatpad_api.util.http_util import post_json class ChatRoomApi: @@ -7,8 +7,6 @@ class ChatRoomApi: self.token = token def get_chatroom_member_detail(self, chatroom_name): - params = { - "ChatRoomName": chatroom_name - } + params = {'ChatRoomName': chatroom_name} url = self.base_url + '/group/GetChatroomMemberDetail' return post_json(url, token=self.token, data=params) diff --git a/libs/wechatpad_api/api/downloadpai.py b/libs/wechatpad_api/api/downloadpai.py index a82a5674..2d45fac6 100644 --- a/libs/wechatpad_api/api/downloadpai.py +++ b/libs/wechatpad_api/api/downloadpai.py @@ -1,32 +1,23 @@ -from libs.wechatpad_api.util.http_util import async_request, post_json +from libs.wechatpad_api.util.http_util import post_json import httpx import base64 + class DownloadApi: def __init__(self, base_url, token): self.base_url = base_url self.token = token def send_download(self, aeskey, file_type, file_url): - json_data = { - "AesKey": aeskey, - "FileType": file_type, - "FileURL": file_url - } - url = self.base_url + "/message/SendCdnDownload" + json_data = {'AesKey': aeskey, 'FileType': file_type, 'FileURL': file_url} + url = self.base_url + '/message/SendCdnDownload' return post_json(url, token=self.token, data=json_data) - def get_msg_voice(self,buf_id, length, new_msgid): - json_data = { - "Bufid": buf_id, - "Length": length, - "NewMsgId": new_msgid, - "ToUserName": "" - } - url = self.base_url + "/message/GetMsgVoice" + def get_msg_voice(self, buf_id, length, new_msgid): + json_data = {'Bufid': buf_id, 'Length': length, 'NewMsgId': new_msgid, 'ToUserName': ''} + url = self.base_url + '/message/GetMsgVoice' return post_json(url, token=self.token, data=json_data) - async def download_url_to_base64(self, download_url): async with httpx.AsyncClient() as client: response = await client.get(download_url) @@ -36,4 +27,4 @@ class DownloadApi: base64_str = base64.b64encode(file_bytes).decode('utf-8') # 返回字符串格式 return base64_str else: - raise Exception('获取文件失败') \ No newline at end of file + raise Exception('获取文件失败') diff --git a/libs/wechatpad_api/api/friend.py b/libs/wechatpad_api/api/friend.py index 00701a5d..a7a448aa 100644 --- a/libs/wechatpad_api/api/friend.py +++ b/libs/wechatpad_api/api/friend.py @@ -1,11 +1,6 @@ -from libs.wechatpad_api.util.http_util import post_json,async_request -from typing import List, Dict, Any, Optional - - class FriendApi: """联系人API类,处理所有与联系人相关的操作""" def __init__(self, base_url: str, token: str): self.base_url = base_url self.token = token - diff --git a/libs/wechatpad_api/api/login.py b/libs/wechatpad_api/api/login.py index 142a3c85..4aa4ae8d 100644 --- a/libs/wechatpad_api/api/login.py +++ b/libs/wechatpad_api/api/login.py @@ -1,37 +1,34 @@ -from libs.wechatpad_api.util.http_util import async_request,post_json,get_json +from libs.wechatpad_api.util.http_util import post_json, get_json class LoginApi: def __init__(self, base_url: str, token: str = None, admin_key: str = None): - ''' + """ Args: base_url: 原始路径 token: token admin_key: 管理员key - ''' + """ self.base_url = base_url self.token = token # self.admin_key = admin_key - def get_token(self, admin_key, day: int=365): + def get_token(self, admin_key, day: int = 365): # 获取普通token - url = f"{self.base_url}/admin/GenAuthKey1" - json_data = { - "Count": 1, - "Days": day - } + url = f'{self.base_url}/admin/GenAuthKey1' + json_data = {'Count': 1, 'Days': day} return post_json(base_url=url, token=admin_key, data=json_data) - def get_login_qr(self, Proxy: str = ""): - ''' + def get_login_qr(self, Proxy: str = ''): + """ Args: Proxy:异地使用时代理 Returns:json数据 - ''' + """ """ { @@ -49,54 +46,37 @@ class LoginApi: } """ - #获取登录二维码 - url = f"{self.base_url}/login/GetLoginQrCodeNew" + # 获取登录二维码 + url = f'{self.base_url}/login/GetLoginQrCodeNew' check = False - if Proxy != "": + if Proxy != '': check = True - json_data = { - "Check": check, - "Proxy": Proxy - } + json_data = {'Check': check, 'Proxy': Proxy} return post_json(base_url=url, token=self.token, data=json_data) - def get_login_status(self): # 获取登录状态 url = f'{self.base_url}/login/GetLoginStatus' return get_json(base_url=url, token=self.token) - - def logout(self): # 退出登录 url = f'{self.base_url}/login/LogOut' return post_json(base_url=url, token=self.token) - - - - def wake_up_login(self, Proxy: str = ""): + def wake_up_login(self, Proxy: str = ''): # 唤醒登录 url = f'{self.base_url}/login/WakeUpLogin' check = False - if Proxy != "": + if Proxy != '': check = True - json_data = { - "Check": check, - "Proxy": "" - } + json_data = {'Check': check, 'Proxy': ''} return post_json(base_url=url, token=self.token, data=json_data) - - - def login(self,admin_key): + def login(self, admin_key): login_status = self.get_login_status() - if login_status["Code"] == 300 and login_status["Text"] == "你已退出微信": - print("token已经失效,重新获取") + if login_status['Code'] == 300 and login_status['Text'] == '你已退出微信': + print('token已经失效,重新获取') token_data = self.get_token(admin_key) - self.token = token_data["Data"][0] - - - + self.token = token_data['Data'][0] diff --git a/libs/wechatpad_api/api/message.py b/libs/wechatpad_api/api/message.py index 2089ce96..cca76313 100644 --- a/libs/wechatpad_api/api/message.py +++ b/libs/wechatpad_api/api/message.py @@ -1,5 +1,4 @@ - -from libs.wechatpad_api.util.http_util import async_request, post_json +from libs.wechatpad_api.util.http_util import post_json class MessageApi: @@ -7,8 +6,8 @@ class MessageApi: self.base_url = base_url self.token = token - def post_text(self, to_wxid, content, ats: list= []): - ''' + def post_text(self, to_wxid, content, ats: list = []): + """ Args: app_id: 微信id @@ -18,106 +17,64 @@ class MessageApi: Returns: - ''' - url = self.base_url + "/message/SendTextMessage" + """ + url = self.base_url + '/message/SendTextMessage' """发送文字消息""" json_data = { - "MsgItem": [ - { - "AtWxIDList": ats, - "ImageContent": "", - "MsgType": 0, - "TextContent": content, - "ToUserName": to_wxid - } - ] - } - return post_json(base_url=url, token=self.token, data=json_data) + 'MsgItem': [ + {'AtWxIDList': ats, 'ImageContent': '', 'MsgType': 0, 'TextContent': content, 'ToUserName': to_wxid} + ] + } + return post_json(base_url=url, token=self.token, data=json_data) - - - - def post_image(self, to_wxid, img_url, ats: list= []): + def post_image(self, to_wxid, img_url, ats: list = []): """发送图片消息""" # 这里好像可以尝试发送多个暂时未测试 json_data = { - "MsgItem": [ - { - "AtWxIDList": ats, - "ImageContent": img_url, - "MsgType": 0, - "TextContent": '', - "ToUserName": to_wxid - } + 'MsgItem': [ + {'AtWxIDList': ats, 'ImageContent': img_url, 'MsgType': 0, 'TextContent': '', 'ToUserName': to_wxid} ] } - url = self.base_url + "/message/SendImageMessage" + url = self.base_url + '/message/SendImageMessage' return post_json(base_url=url, token=self.token, data=json_data) def post_voice(self, to_wxid, voice_data, voice_forma, voice_duration): """发送语音消息""" json_data = { - "ToUserName": to_wxid, - "VoiceData": voice_data, - "VoiceFormat": voice_forma, - "VoiceSecond": voice_duration + 'ToUserName': to_wxid, + 'VoiceData': voice_data, + 'VoiceFormat': voice_forma, + 'VoiceSecond': voice_duration, } - url = self.base_url + "/message/SendVoice" + url = self.base_url + '/message/SendVoice' return post_json(base_url=url, token=self.token, data=json_data) - - - - def post_name_card(self, alias, to_wxid, nick_name, name_card_wxid, flag): """发送名片消息""" param = { - "CardAlias": alias, - "CardFlag": flag, - "CardNickName": nick_name, - "CardWxId": name_card_wxid, - "ToUserName": to_wxid + 'CardAlias': alias, + 'CardFlag': flag, + 'CardNickName': nick_name, + 'CardWxId': name_card_wxid, + 'ToUserName': to_wxid, } - url = f"{self.base_url}/message/ShareCardMessage" + url = f'{self.base_url}/message/ShareCardMessage' return post_json(base_url=url, token=self.token, data=param) - def post_emoji(self, to_wxid, emoji_md5, emoji_size:int=0): + def post_emoji(self, to_wxid, emoji_md5, emoji_size: int = 0): """发送emoji消息""" - json_data = { - "EmojiList": [ - { - "EmojiMd5": emoji_md5, - "EmojiSize": emoji_size, - "ToUserName": to_wxid - } - ] - } - url = f"{self.base_url}/message/SendEmojiMessage" + json_data = {'EmojiList': [{'EmojiMd5': emoji_md5, 'EmojiSize': emoji_size, 'ToUserName': to_wxid}]} + url = f'{self.base_url}/message/SendEmojiMessage' return post_json(base_url=url, token=self.token, data=json_data) - def post_app_msg(self, to_wxid,xml_data, contenttype:int=0): + def post_app_msg(self, to_wxid, xml_data, contenttype: int = 0): """发送appmsg消息""" - json_data = { - "AppList": [ - { - "ContentType": contenttype, - "ContentXML": xml_data, - "ToUserName": to_wxid - } - ] - } - url = f"{self.base_url}/message/SendAppMessage" + json_data = {'AppList': [{'ContentType': contenttype, 'ContentXML': xml_data, 'ToUserName': to_wxid}]} + url = f'{self.base_url}/message/SendAppMessage' return post_json(base_url=url, token=self.token, data=json_data) - - def revoke_msg(self, to_wxid, msg_id, new_msg_id, create_time): """撤回消息""" - param = { - "ClientMsgId": msg_id, - "CreateTime": create_time, - "NewMsgId": new_msg_id, - "ToUserName": to_wxid - } - url = f"{self.base_url}/message/RevokeMsg" - return post_json(base_url=url, token=self.token, data=param) \ No newline at end of file + param = {'ClientMsgId': msg_id, 'CreateTime': create_time, 'NewMsgId': new_msg_id, 'ToUserName': to_wxid} + url = f'{self.base_url}/message/RevokeMsg' + return post_json(base_url=url, token=self.token, data=param) diff --git a/libs/wechatpad_api/util/http_util.py b/libs/wechatpad_api/util/http_util.py index 754003e9..447c29df 100644 --- a/libs/wechatpad_api/util/http_util.py +++ b/libs/wechatpad_api/util/http_util.py @@ -1,10 +1,9 @@ import requests +import aiohttp + def post_json(base_url, token, data=None): - headers = { - 'Content-Type': 'application/json' - } - + headers = {'Content-Type': 'application/json'} url = base_url + f'?key={token}' @@ -18,14 +17,12 @@ def post_json(base_url, token, data=None): else: raise RuntimeError(response.text) except Exception as e: - print(f"http请求失败, url={url}, exception={e}") + print(f'http请求失败, url={url}, exception={e}') raise RuntimeError(str(e)) -def get_json(base_url, token): - headers = { - 'Content-Type': 'application/json' - } +def get_json(base_url, token): + headers = {'Content-Type': 'application/json'} url = base_url + f'?key={token}' @@ -39,21 +36,18 @@ def get_json(base_url, token): else: raise RuntimeError(response.text) except Exception as e: - print(f"http请求失败, url={url}, exception={e}") + print(f'http请求失败, url={url}, exception={e}') raise RuntimeError(str(e)) -import aiohttp -import asyncio - async def async_request( - base_url: str, - token_key: str, - method: str = 'POST', - params: dict = None, - # headers: dict = None, - data: dict = None, - json: dict = None + base_url: str, + token_key: str, + method: str = 'POST', + params: dict = None, + # headers: dict = None, + data: dict = None, + json: dict = None, ): """ 通用异步请求函数 @@ -67,18 +61,11 @@ async def async_request( :param json: JSON数据 :return: 响应文本 """ - headers = { - 'Content-Type': 'application/json' - } - url = f"{base_url}?key={token_key}" + headers = {'Content-Type': 'application/json'} + url = f'{base_url}?key={token_key}' async with aiohttp.ClientSession() as session: async with session.request( - method=method, - url=url, - params=params, - headers=headers, - data=data, - json=json + method=method, url=url, params=params, headers=headers, data=data, json=json ) as response: response.raise_for_status() # 如果状态码不是200,抛出异常 result = await response.json() @@ -89,4 +76,3 @@ async def async_request( # return await result # else: # raise RuntimeError("请求失败",response.text) - diff --git a/pkg/entity/persistence/vector.py b/pkg/entity/persistence/vector.py index 84d1dfb1..465125f5 100644 --- a/pkg/entity/persistence/vector.py +++ b/pkg/entity/persistence/vector.py @@ -1,14 +1,13 @@ -from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, ForeignKey, LargeBinary -from sqlalchemy.orm import declarative_base, sessionmaker, relationship -from datetime import datetime -import numpy as np # 用于处理从LargeBinary转换回来的embedding +from sqlalchemy import Column, Integer, ForeignKey, LargeBinary +from sqlalchemy.orm import declarative_base, relationship Base = declarative_base() + class Vector(Base): __tablename__ = 'vectors' id = Column(Integer, primary_key=True, index=True) chunk_id = Column(Integer, ForeignKey('chunks.id'), unique=True) - embedding = Column(LargeBinary) # Store embeddings as binary + embedding = Column(LargeBinary) # Store embeddings as binary - chunk = relationship("Chunk", back_populates="vector") \ No newline at end of file + chunk = relationship('Chunk', back_populates='vector') diff --git a/pkg/platform/sources/aiocqhttp.py b/pkg/platform/sources/aiocqhttp.py index 3f3ef512..2730874f 100644 --- a/pkg/platform/sources/aiocqhttp.py +++ b/pkg/platform/sources/aiocqhttp.py @@ -16,7 +16,6 @@ from ..logger import EventLogger class AiocqhttpMessageConverter(adapter.MessageConverter): - @staticmethod async def yiri2target( message_chain: platform_message.MessageChain, @@ -62,87 +61,170 @@ class AiocqhttpMessageConverter(adapter.MessageConverter): for node in msg.node_list: msg_list.extend((await AiocqhttpMessageConverter.yiri2target(node.message_chain))[0]) elif isinstance(msg, platform_message.File): - msg_list.append({"type":"file", "data":{'file': msg.url, "name": msg.name}}) + msg_list.append({'type': 'file', 'data': {'file': msg.url, 'name': msg.name}}) elif isinstance(msg, platform_message.Face): - if msg.face_type=='face': + if msg.face_type == 'face': msg_list.append(aiocqhttp.MessageSegment.face(msg.face_id)) - elif msg.face_type=='rps': + elif msg.face_type == 'rps': msg_list.append(aiocqhttp.MessageSegment.rps()) - elif msg.face_type=='dice': + elif msg.face_type == 'dice': msg_list.append(aiocqhttp.MessageSegment.dice()) - else: msg_list.append(aiocqhttp.MessageSegment.text(str(msg))) return msg_list, msg_id, msg_time @staticmethod - async def target2yiri(message: str, message_id: int = -1,bot=None): + async def target2yiri(message: str, message_id: int = -1, bot=None): print(message) message = aiocqhttp.Message(message) def get_face_name(face_id): face_code_dict = { - "2": '好色', - "4": "得意", "5": "流泪", "8": "睡", "9": "大哭", "10": "尴尬", "12": "调皮", "14": "微笑", "16": "酷", - "21": "可爱", - "23": "傲慢", "24": "饥饿", "25": "困", "26": "惊恐", "27": "流汗", "28": "憨笑", "29": "悠闲", - "30": "奋斗", - "32": "疑问", "33": "嘘", "34": "晕", "38": "敲打", "39": "再见", "41": "发抖", "42": "爱情", - "43": "跳跳", - "49": "拥抱", "53": "蛋糕", "60": "咖啡", "63": "玫瑰", "66": "爱心", "74": "太阳", "75": "月亮", - "76": "赞", - "78": "握手", "79": "胜利", "85": "飞吻", "89": "西瓜", "96": "冷汗", "97": "擦汗", "98": "抠鼻", - "99": "鼓掌", - "100": "糗大了", "101": "坏笑", "102": "左哼哼", "103": "右哼哼", "104": "哈欠", "106": "委屈", - "109": "左亲亲", - "111": "可怜", "116": "示爱", "118": "抱拳", "120": "拳头", "122": "爱你", "123": "NO", "124": "OK", - "125": "转圈", - "129": "挥手", "144": "喝彩", "147": "棒棒糖", "171": "茶", "173": "泪奔", "174": "无奈", "175": "卖萌", - "176": "小纠结", "179": "doge", "180": "惊喜", "181": "骚扰", "182": "笑哭", "183": "我最美", - "201": "点赞", - "203": "托脸", "212": "托腮", "214": "啵啵", "219": "蹭一蹭", "222": "抱抱", "227": "拍手", - "232": "佛系", - "240": "喷脸", "243": "甩头", "246": "加油抱抱", "262": "脑阔疼", "264": "捂脸", "265": "辣眼睛", - "266": "哦哟", - "267": "头秃", "268": "问号脸", "269": "暗中观察", "270": "emm", "271": "吃瓜", "272": "呵呵哒", - "273": "我酸了", - "277": "汪汪", "278": "汗", "281": "无眼笑", "282": "敬礼", "284": "面无表情", "285": "摸鱼", - "287": "哦", - "289": "睁眼", "290": "敲开心", "293": "摸锦鲤", "294": "期待", "297": "拜谢", "298": "元宝", - "299": "牛啊", - "305": "右亲亲", "306": "牛气冲天", "307": "喵喵", "314": "仔细分析", "315": "加油", "318": "崇拜", - "319": "比心", - "320": "庆祝", "322": "拒绝", "324": "吃糖", "326": "生气" + '2': '好色', + '4': '得意', + '5': '流泪', + '8': '睡', + '9': '大哭', + '10': '尴尬', + '12': '调皮', + '14': '微笑', + '16': '酷', + '21': '可爱', + '23': '傲慢', + '24': '饥饿', + '25': '困', + '26': '惊恐', + '27': '流汗', + '28': '憨笑', + '29': '悠闲', + '30': '奋斗', + '32': '疑问', + '33': '嘘', + '34': '晕', + '38': '敲打', + '39': '再见', + '41': '发抖', + '42': '爱情', + '43': '跳跳', + '49': '拥抱', + '53': '蛋糕', + '60': '咖啡', + '63': '玫瑰', + '66': '爱心', + '74': '太阳', + '75': '月亮', + '76': '赞', + '78': '握手', + '79': '胜利', + '85': '飞吻', + '89': '西瓜', + '96': '冷汗', + '97': '擦汗', + '98': '抠鼻', + '99': '鼓掌', + '100': '糗大了', + '101': '坏笑', + '102': '左哼哼', + '103': '右哼哼', + '104': '哈欠', + '106': '委屈', + '109': '左亲亲', + '111': '可怜', + '116': '示爱', + '118': '抱拳', + '120': '拳头', + '122': '爱你', + '123': 'NO', + '124': 'OK', + '125': '转圈', + '129': '挥手', + '144': '喝彩', + '147': '棒棒糖', + '171': '茶', + '173': '泪奔', + '174': '无奈', + '175': '卖萌', + '176': '小纠结', + '179': 'doge', + '180': '惊喜', + '181': '骚扰', + '182': '笑哭', + '183': '我最美', + '201': '点赞', + '203': '托脸', + '212': '托腮', + '214': '啵啵', + '219': '蹭一蹭', + '222': '抱抱', + '227': '拍手', + '232': '佛系', + '240': '喷脸', + '243': '甩头', + '246': '加油抱抱', + '262': '脑阔疼', + '264': '捂脸', + '265': '辣眼睛', + '266': '哦哟', + '267': '头秃', + '268': '问号脸', + '269': '暗中观察', + '270': 'emm', + '271': '吃瓜', + '272': '呵呵哒', + '273': '我酸了', + '277': '汪汪', + '278': '汗', + '281': '无眼笑', + '282': '敬礼', + '284': '面无表情', + '285': '摸鱼', + '287': '哦', + '289': '睁眼', + '290': '敲开心', + '293': '摸锦鲤', + '294': '期待', + '297': '拜谢', + '298': '元宝', + '299': '牛啊', + '305': '右亲亲', + '306': '牛气冲天', + '307': '喵喵', + '314': '仔细分析', + '315': '加油', + '318': '崇拜', + '319': '比心', + '320': '庆祝', + '322': '拒绝', + '324': '吃糖', + '326': '生气', } - return face_code_dict.get(face_id,'') + return face_code_dict.get(face_id, '') async def process_message_data(msg_data, reply_list): - if msg_data["type"] == "image": - image_base64, image_format = await image.qq_image_url_to_base64(msg_data["data"]['url']) - reply_list.append( - platform_message.Image(base64=f'data:image/{image_format};base64,{image_base64}')) + if msg_data['type'] == 'image': + image_base64, image_format = await image.qq_image_url_to_base64(msg_data['data']['url']) + reply_list.append(platform_message.Image(base64=f'data:image/{image_format};base64,{image_base64}')) - elif msg_data["type"] == "text": - reply_list.append(platform_message.Plain(text=msg_data["data"]["text"])) + elif msg_data['type'] == 'text': + reply_list.append(platform_message.Plain(text=msg_data['data']['text'])) - elif msg_data["type"] == "forward": # 这里来应该传入转发消息组,暂时传入qoute - for forward_msg_datas in msg_data["data"]["content"]: - for forward_msg_data in forward_msg_datas["message"]: + elif msg_data['type'] == 'forward': # 这里来应该传入转发消息组,暂时传入qoute + for forward_msg_datas in msg_data['data']['content']: + for forward_msg_data in forward_msg_datas['message']: await process_message_data(forward_msg_data, reply_list) - elif msg_data["type"] == "at": - if msg_data["data"]['qq'] == 'all': + elif msg_data['type'] == 'at': + if msg_data['data']['qq'] == 'all': reply_list.append(platform_message.AtAll()) else: reply_list.append( platform_message.At( - target=msg_data["data"]['qq'], + target=msg_data['data']['qq'], ) ) - yiri_msg_list = [] yiri_msg_list.append(platform_message.Source(id=message_id, time=datetime.datetime.now())) @@ -161,10 +243,10 @@ class AiocqhttpMessageConverter(adapter.MessageConverter): elif msg.type == 'text': yiri_msg_list.append(platform_message.Plain(text=msg.data['text'])) elif msg.type == 'image': - emoji_id = msg.data.get("emoji_package_id", None) + emoji_id = msg.data.get('emoji_package_id', None) if emoji_id: face_id = emoji_id - face_name = msg.data.get("summary", '') + face_name = msg.data.get('summary', '') image_msg = platform_message.Face(face_id=face_id, face_name=face_name) else: image_base64, image_format = await image.qq_image_url_to_base64(msg.data['url']) @@ -178,14 +260,15 @@ class AiocqhttpMessageConverter(adapter.MessageConverter): # await process_message_data(msg_data, yiri_msg_list) pass - elif msg.type == 'reply': # 此处处理引用消息传入Qoute - msg_datas = await bot.get_msg(message_id=msg.data["id"]) + msg_datas = await bot.get_msg(message_id=msg.data['id']) - for msg_data in msg_datas["message"]: + for msg_data in msg_datas['message']: await process_message_data(msg_data, reply_list) - reply_msg = platform_message.Quote(message_id=msg.data["id"],sender_id=msg_datas["user_id"],origin=reply_list) + reply_msg = platform_message.Quote( + message_id=msg.data['id'], sender_id=msg_datas['user_id'], origin=reply_list + ) yiri_msg_list.append(reply_msg) elif msg.type == 'file': @@ -194,49 +277,36 @@ class AiocqhttpMessageConverter(adapter.MessageConverter): file_data = await bot.get_file(file_id=file_id) file_name = file_data.get('file_name') file_path = file_data.get('file') + _ = file_path file_url = file_data.get('file_url') file_size = file_data.get('file_size') - yiri_msg_list.append(platform_message.File(id=file_id, name=file_name,url=file_url,size=file_size)) + yiri_msg_list.append(platform_message.File(id=file_id, name=file_name, url=file_url, size=file_size)) elif msg.type == 'face': face_id = msg.data['id'] face_name = msg.data['raw']['faceText'] if not face_name: face_name = get_face_name(face_id) - yiri_msg_list.append(platform_message.Face(face_id=int(face_id),face_name=face_name.replace('/',''))) + yiri_msg_list.append(platform_message.Face(face_id=int(face_id), face_name=face_name.replace('/', ''))) elif msg.type == 'rps': face_id = msg.data['result'] - yiri_msg_list.append(platform_message.Face(face_type="rps",face_id=int(face_id),face_name='猜拳')) + yiri_msg_list.append(platform_message.Face(face_type='rps', face_id=int(face_id), face_name='猜拳')) elif msg.type == 'dice': face_id = msg.data['result'] - yiri_msg_list.append(platform_message.Face(face_type='dice',face_id=int(face_id),face_name='骰子')) - - - - - - - - + yiri_msg_list.append(platform_message.Face(face_type='dice', face_id=int(face_id), face_name='骰子')) chain = platform_message.MessageChain(yiri_msg_list) return chain - - - - class AiocqhttpEventConverter(adapter.EventConverter): @staticmethod async def yiri2target(event: platform_events.MessageEvent, bot_account_id: int): return event.source_platform_object @staticmethod - async def target2yiri(event: aiocqhttp.Event,bot=None): - yiri_chain = await AiocqhttpMessageConverter.target2yiri(event.message, event.message_id,bot) - - + async def target2yiri(event: aiocqhttp.Event, bot=None): + yiri_chain = await AiocqhttpMessageConverter.target2yiri(event.message, event.message_id, bot) if event.message_type == 'group': permission = 'MEMBER' @@ -316,7 +386,6 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter): aiocq_msg = (await AiocqhttpMessageConverter.yiri2target(message))[0] if target_type == 'group': - await self.bot.send_group_msg(group_id=int(target_id), message=aiocq_msg) elif target_type == 'person': await self.bot.send_private_msg(user_id=int(target_id), message=aiocq_msg) @@ -345,7 +414,7 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter): async def on_message(event: aiocqhttp.Event): self.bot_account_id = event.self_id try: - return await callback(await self.event_converter.target2yiri(event,self.bot), self) + return await callback(await self.event_converter.target2yiri(event, self.bot), self) except Exception: await self.logger.error(f'Error in on_message: {traceback.format_exc()}') traceback.print_exc() diff --git a/pkg/platform/sources/discord.py b/pkg/platform/sources/discord.py index 4f5cac28..6cc09a72 100644 --- a/pkg/platform/sources/discord.py +++ b/pkg/platform/sources/discord.py @@ -8,7 +8,6 @@ import base64 import uuid import os import datetime -import io import aiohttp @@ -78,10 +77,10 @@ class DiscordMessageConverter(adapter.MessageConverter): # 确保路径没有空字节 clean_path = ele.path.replace('\x00', '') clean_path = os.path.abspath(clean_path) - + if not os.path.exists(clean_path): continue # 跳过不存在的文件 - + try: with open(clean_path, 'rb') as f: image_bytes = f.read() @@ -101,12 +100,13 @@ class DiscordMessageConverter(adapter.MessageConverter): filename = f'{uuid.uuid4()}.webp' # 默认保持PNG except Exception as e: - print(f"Error reading image file {clean_path}: {e}") + print(f'Error reading image file {clean_path}: {e}') continue # 跳过读取失败的文件 if image_bytes: # 使用BytesIO创建文件对象,避免路径问题 import io + image_files.append(discord.File(fp=io.BytesIO(image_bytes), filename=filename)) elif isinstance(ele, platform_message.Plain): text_string += ele.text @@ -261,25 +261,25 @@ class DiscordAdapter(adapter.MessagePlatformAdapter): async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain): msg_to_send, image_files = await self.message_converter.yiri2target(message) - + try: # 获取频道对象 channel = self.bot.get_channel(int(target_id)) if channel is None: # 如果本地缓存中没有,尝试从API获取 channel = await self.bot.fetch_channel(int(target_id)) - + args = { 'content': msg_to_send, } - + if len(image_files) > 0: args['files'] = image_files - + await channel.send(**args) - + except Exception as e: - await self.logger.error(f"Discord send_message failed: {e}") + await self.logger.error(f'Discord send_message failed: {e}') raise e async def reply_message( diff --git a/pkg/platform/sources/lark.py b/pkg/platform/sources/lark.py index d1116362..f8faf522 100644 --- a/pkg/platform/sources/lark.py +++ b/pkg/platform/sources/lark.py @@ -378,15 +378,15 @@ class LarkAdapter(adapter.MessagePlatformAdapter): if 'im.message.receive_v1' == type: try: event = await self.event_converter.target2yiri(p2v1, self.api_client) - except Exception as e: - await self.logger.error(f"Error in lark callback: {traceback.format_exc()}") + except Exception: + await self.logger.error(f'Error in lark callback: {traceback.format_exc()}') if event.__class__ in self.listeners: await self.listeners[event.__class__](event, self) return {'code': 200, 'message': 'ok'} - except Exception as e: - await self.logger.error(f"Error in lark callback: {traceback.format_exc()}") + except Exception: + await self.logger.error(f'Error in lark callback: {traceback.format_exc()}') return {'code': 500, 'message': 'error'} async def on_message(event: lark_oapi.im.v1.P2ImMessageReceiveV1): diff --git a/pkg/platform/sources/nakuru.py b/pkg/platform/sources/nakuru.py index 389a2db1..16ad54db 100644 --- a/pkg/platform/sources/nakuru.py +++ b/pkg/platform/sources/nakuru.py @@ -72,8 +72,9 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter): content=content_list, ) nakuru_forward_node_list.append(nakuru_forward_node) - except Exception as e: + except Exception: import traceback + traceback.print_exc() nakuru_msg_list.append(nakuru_forward_node_list) @@ -276,7 +277,7 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter): # 注册监听器 self.bot.receiver(source_cls.__name__)(listener_wrapper) except Exception as e: - self.logger.error(f"Error in nakuru register_listener: {traceback.format_exc()}") + self.logger.error(f'Error in nakuru register_listener: {traceback.format_exc()}') raise e def unregister_listener( diff --git a/pkg/platform/sources/officialaccount.py b/pkg/platform/sources/officialaccount.py index 030db56d..3fc1e393 100644 --- a/pkg/platform/sources/officialaccount.py +++ b/pkg/platform/sources/officialaccount.py @@ -125,8 +125,8 @@ class OfficialAccountAdapter(adapter.MessagePlatformAdapter): self.bot_account_id = event.receiver_id try: return await callback(await self.event_converter.target2yiri(event), self) - except Exception as e: - await self.logger.error(f"Error in officialaccount callback: {traceback.format_exc()}") + except Exception: + await self.logger.error(f'Error in officialaccount callback: {traceback.format_exc()}') if event_type == platform_events.FriendMessage: self.bot.on_message('text')(on_message) diff --git a/pkg/platform/sources/qqofficial.py b/pkg/platform/sources/qqofficial.py index c61afea4..63ab531f 100644 --- a/pkg/platform/sources/qqofficial.py +++ b/pkg/platform/sources/qqofficial.py @@ -154,10 +154,7 @@ class QQOfficialAdapter(adapter.MessagePlatformAdapter): raise ParamNotEnoughError('QQ官方机器人缺少相关配置项,请查看文档或联系管理员') self.bot = QQOfficialClient( - app_id=config['appid'], - secret=config['secret'], - token=config['token'], - logger=self.logger + app_id=config['appid'], secret=config['secret'], token=config['token'], logger=self.logger ) async def reply_message( @@ -224,8 +221,8 @@ class QQOfficialAdapter(adapter.MessagePlatformAdapter): self.bot_account_id = 'justbot' try: return await callback(await self.event_converter.target2yiri(event), self) - except Exception as e: - await self.logger.error(f"Error in qqofficial callback: {traceback.format_exc()}") + except Exception: + await self.logger.error(f'Error in qqofficial callback: {traceback.format_exc()}') if event_type == platform_events.FriendMessage: self.bot.on_message('DIRECT_MESSAGE_CREATE')(on_message) diff --git a/pkg/platform/sources/slack.py b/pkg/platform/sources/slack.py index 6dfcff59..1bd5aa2d 100644 --- a/pkg/platform/sources/slack.py +++ b/pkg/platform/sources/slack.py @@ -104,7 +104,9 @@ class SlackAdapter(adapter.MessagePlatformAdapter): if missing_keys: raise ParamNotEnoughError('Slack机器人缺少相关配置项,请查看文档或联系管理员') - self.bot = SlackClient(bot_token=self.config['bot_token'], signing_secret=self.config['signing_secret'], logger=self.logger) + self.bot = SlackClient( + bot_token=self.config['bot_token'], signing_secret=self.config['signing_secret'], logger=self.logger + ) async def reply_message( self, @@ -139,8 +141,8 @@ class SlackAdapter(adapter.MessagePlatformAdapter): self.bot_account_id = 'SlackBot' try: return await callback(await self.event_converter.target2yiri(event, self.bot), self) - except Exception as e: - await self.logger.error(f"Error in slack callback: {traceback.format_exc()}") + except Exception: + await self.logger.error(f'Error in slack callback: {traceback.format_exc()}') if event_type == platform_events.FriendMessage: self.bot.on_message('im')(on_message) diff --git a/pkg/platform/sources/telegram.py b/pkg/platform/sources/telegram.py index 266d994e..c2fcc22e 100644 --- a/pkg/platform/sources/telegram.py +++ b/pkg/platform/sources/telegram.py @@ -160,8 +160,8 @@ class TelegramAdapter(adapter.MessagePlatformAdapter): try: lb_event = await self.event_converter.target2yiri(update, self.bot, self.bot_account_id) await self.listeners[type(lb_event)](lb_event, self) - except Exception as e: - await self.logger.error(f"Error in telegram callback: {traceback.format_exc()}") + except Exception: + await self.logger.error(f'Error in telegram callback: {traceback.format_exc()}') self.application = ApplicationBuilder().token(self.config['token']).build() self.bot = self.application.bot diff --git a/pkg/platform/sources/wechatpad.py b/pkg/platform/sources/wechatpad.py index fdd4a69b..75cad727 100644 --- a/pkg/platform/sources/wechatpad.py +++ b/pkg/platform/sources/wechatpad.py @@ -1,5 +1,4 @@ import requests -import websockets import websocket import json import time @@ -10,53 +9,40 @@ from libs.wechatpad_api.client import WeChatPadClient import typing import asyncio import traceback -import time import re import base64 -import uuid -import json -import os import copy -import datetime import threading import quart -import aiohttp from .. import adapter -from ...pipeline.longtext.strategies import forward from ...core import app from ..types import message as platform_message from ..types import events as platform_events from ..types import entities as platform_entities -from ...utils import image from ..logger import EventLogger import xml.etree.ElementTree as ET -from typing import Optional, List, Tuple +from typing import Optional, Tuple from functools import partial import logging -class WeChatPadMessageConverter(adapter.MessageConverter): +class WeChatPadMessageConverter(adapter.MessageConverter): def __init__(self, config: dict): self.config = config - self.bot = WeChatPadClient(self.config["wechatpad_url"],self.config["token"]) - self.logger = logging.getLogger("WeChatPadMessageConverter") + self.bot = WeChatPadClient(self.config['wechatpad_url'], self.config['token']) + self.logger = logging.getLogger('WeChatPadMessageConverter') @staticmethod - async def yiri2target( - message_chain: platform_message.MessageChain - ) -> list[dict]: + async def yiri2target(message_chain: platform_message.MessageChain) -> list[dict]: content_list = [] - current_file_path = os.path.abspath(__file__) - - for component in message_chain: if isinstance(component, platform_message.At): - content_list.append({"type": "at", "target": component.target}) + content_list.append({'type': 'at', 'target': component.target}) elif isinstance(component, platform_message.Plain): - content_list.append({"type": "text", "content": component.text}) + content_list.append({'type': 'text', 'content': component.text}) elif isinstance(component, platform_message.Image): if component.url: async with httpx.AsyncClient() as client: @@ -68,15 +54,16 @@ class WeChatPadMessageConverter(adapter.MessageConverter): else: raise Exception('获取文件失败') # pass - content_list.append({"type": "image", "image": base64_str}) + content_list.append({'type': 'image', 'image': base64_str}) elif component.base64: - content_list.append({"type": "image", "image": component.base64}) + content_list.append({'type': 'image', 'image': component.base64}) elif isinstance(component, platform_message.WeChatEmoji): content_list.append( - {'type': 'WeChatEmoji', 'emoji_md5': component.emoji_md5, 'emoji_size': component.emoji_size}) + {'type': 'WeChatEmoji', 'emoji_md5': component.emoji_md5, 'emoji_size': component.emoji_size} + ) elif isinstance(component, platform_message.Voice): - content_list.append({"type": "voice", "data": component.url, "duration": component.length, "forma": 0}) + content_list.append({'type': 'voice', 'data': component.url, 'duration': component.length, 'forma': 0}) elif isinstance(component, platform_message.WeChatAppMsg): content_list.append({'type': 'WeChatAppMsg', 'app_msg': component.app_msg}) elif isinstance(component, platform_message.Forward): @@ -86,28 +73,23 @@ class WeChatPadMessageConverter(adapter.MessageConverter): return content_list - - async def target2yiri( - self, - message: dict, - bot_account_id: str - ) -> platform_message.MessageChain: + async def target2yiri(self, message: dict, bot_account_id: str) -> platform_message.MessageChain: """外部消息转平台消息""" # 数据预处理 message_list = [] ats_bot = False # 是否被@ - content = message["content"]["str"] + content = message['content']['str'] content_no_preifx = content # 群消息则去掉前缀 is_group_message = self._is_group_message(message) if is_group_message: ats_bot = self._ats_bot(message, bot_account_id) - if "@所有人" in content: + if '@所有人' in content: message_list.append(platform_message.AtAll()) elif ats_bot: message_list.append(platform_message.At(target=bot_account_id)) content_no_preifx, _ = self._extract_content_and_sender(content) - msg_type = message["msg_type"] + msg_type = message['msg_type'] # 映射消息类型到处理器方法 handler_map = { @@ -129,11 +111,7 @@ class WeChatPadMessageConverter(adapter.MessageConverter): return platform_message.MessageChain(message_list) - async def _handler_text( - self, - message: Optional[dict], - content_no_preifx: str - ) -> platform_message.MessageChain: + async def _handler_text(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain: """处理文本消息 (msg_type=1)""" if message and self._is_group_message(message): pattern = r'@\S{1,20}' @@ -141,16 +119,12 @@ class WeChatPadMessageConverter(adapter.MessageConverter): return platform_message.MessageChain([platform_message.Plain(content_no_preifx)]) - async def _handler_image( - self, - message: Optional[dict], - content_no_preifx: str - ) -> platform_message.MessageChain: + async def _handler_image(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain: """处理图像消息 (msg_type=3)""" try: image_xml = content_no_preifx if not image_xml: - return platform_message.MessageChain([platform_message.Unknown("[图片内容为空]")]) + return platform_message.MessageChain([platform_message.Unknown('[图片内容为空]')]) root = ET.fromstring(image_xml) # 提取img标签的属性 @@ -160,28 +134,22 @@ class WeChatPadMessageConverter(adapter.MessageConverter): cdnthumburl = img_tag.get('cdnthumburl') # cdnmidimgurl = img_tag.get('cdnmidimgurl') - image_data = self.bot.cdn_download(aeskey=aeskey, file_type=1, file_url=cdnthumburl) - if image_data["Data"]['FileData'] == '': + if image_data['Data']['FileData'] == '': image_data = self.bot.cdn_download(aeskey=aeskey, file_type=2, file_url=cdnthumburl) - base64_str = image_data["Data"]['FileData'] + base64_str = image_data['Data']['FileData'] # self.logger.info(f"data:image/png;base64,{base64_str}") - elements = [ - platform_message.Image(base64=f"data:image/png;base64,{base64_str}"), + platform_message.Image(base64=f'data:image/png;base64,{base64_str}'), # platform_message.WeChatForwardImage(xml_data=image_xml) # 微信消息转发 ] return platform_message.MessageChain(elements) except Exception as e: - self.logger.error(f"处理图片失败: {str(e)}") - return platform_message.MessageChain([platform_message.Unknown("[图片处理失败]")]) + self.logger.error(f'处理图片失败: {str(e)}') + return platform_message.MessageChain([platform_message.Unknown('[图片处理失败]')]) - async def _handler_voice( - self, - message: Optional[dict], - content_no_preifx: str - ) -> platform_message.MessageChain: + async def _handler_voice(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain: """处理语音消息 (msg_type=34)""" message_List = [] try: @@ -197,39 +165,33 @@ class WeChatPadMessageConverter(adapter.MessageConverter): bufid = voicemsg.get('bufid') length = voicemsg.get('voicelength') voice_data = self.bot.get_msg_voice(buf_id=str(bufid), length=int(length), msgid=str(new_msg_id)) - audio_base64 = voice_data["Data"]['Base64'] + audio_base64 = voice_data['Data']['Base64'] # 验证语音数据有效性 if not audio_base64: - message_List.append(platform_message.Unknown(text="[语音内容为空]")) + message_List.append(platform_message.Unknown(text='[语音内容为空]')) return platform_message.MessageChain(message_List) # 转换为平台支持的语音格式(如 Silk 格式) - voice_element = platform_message.Voice( - base64=f"data:audio/silk;base64,{audio_base64}" - ) + voice_element = platform_message.Voice(base64=f'data:audio/silk;base64,{audio_base64}') message_List.append(voice_element) except KeyError as e: - self.logger.error(f"语音数据字段缺失: {str(e)}") - message_List.append(platform_message.Unknown(text="[语音数据解析失败]")) + self.logger.error(f'语音数据字段缺失: {str(e)}') + message_List.append(platform_message.Unknown(text='[语音数据解析失败]')) except Exception as e: - self.logger.error(f"处理语音消息异常: {str(e)}") - message_List.append(platform_message.Unknown(text="[语音处理失败]")) + self.logger.error(f'处理语音消息异常: {str(e)}') + message_List.append(platform_message.Unknown(text='[语音处理失败]')) return platform_message.MessageChain(message_List) - async def _handler_compound( - self, - message: Optional[dict], - content_no_preifx: str - ) -> platform_message.MessageChain: + async def _handler_compound(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain: """处理复合消息 (msg_type=49),根据子类型分派""" try: xml_data = ET.fromstring(content_no_preifx) appmsg_data = xml_data.find('.//appmsg') if appmsg_data: - data_type = appmsg_data.findtext('.//type', "") + data_type = appmsg_data.findtext('.//type', '') # 二次分派处理器 sub_handler_map = { '57': self._handler_compound_quote, @@ -238,9 +200,9 @@ class WeChatPadMessageConverter(adapter.MessageConverter): '74': self._handler_compound_file, '33': self._handler_compound_mini_program, '36': self._handler_compound_mini_program, - '2000': partial(self._handler_compound_unsupported, text="[转账消息]"), - '2001': partial(self._handler_compound_unsupported, text="[红包消息]"), - '51': partial(self._handler_compound_unsupported, text="[视频号消息]"), + '2000': partial(self._handler_compound_unsupported, text='[转账消息]'), + '2001': partial(self._handler_compound_unsupported, text='[红包消息]'), + '51': partial(self._handler_compound_unsupported, text='[视频号消息]'), } handler = sub_handler_map.get(data_type, self._handler_compound_unsupported) @@ -251,56 +213,54 @@ class WeChatPadMessageConverter(adapter.MessageConverter): else: return platform_message.MessageChain([platform_message.Unknown(text=content_no_preifx)]) except Exception as e: - self.logger.error(f"解析复合消息失败: {str(e)}") + self.logger.error(f'解析复合消息失败: {str(e)}') return platform_message.MessageChain([platform_message.Unknown(text=content_no_preifx)]) async def _handler_compound_quote( - self, - message: Optional[dict], - xml_data: ET.Element + self, message: Optional[dict], xml_data: ET.Element ) -> platform_message.MessageChain: """处理引用消息 (data_type=57)""" message_list = [] -# self.logger.info("_handler_compound_quote", ET.tostring(xml_data, encoding='unicode')) + # self.logger.info("_handler_compound_quote", ET.tostring(xml_data, encoding='unicode')) appmsg_data = xml_data.find('.//appmsg') - quote_data = "" # 引用原文 + quote_data = '' # 引用原文 quote_id = None # 引用消息的原发送者 tousername = None # 接收方: 所属微信的wxid - user_data = "" # 用户消息 + user_data = '' # 用户消息 sender_id = xml_data.findtext('.//fromusername') # 发送方:单聊用户/群member # 引用消息转发 if appmsg_data: - user_data = appmsg_data.findtext('.//title') or "" + user_data = appmsg_data.findtext('.//title') or '' quote_data = appmsg_data.find('.//refermsg').findtext('.//content') quote_id = appmsg_data.find('.//refermsg').findtext('.//chatusr') - message_list.append( - platform_message.WeChatAppMsg( - app_msg=ET.tostring(appmsg_data, encoding='unicode')) - ) + message_list.append(platform_message.WeChatAppMsg(app_msg=ET.tostring(appmsg_data, encoding='unicode'))) if message: - tousername = message['to_user_name']["str"] - + tousername = message['to_user_name']['str'] + + _ = quote_id + _ = tousername + if quote_data: quote_data_message_list = platform_message.MessageChain() # 文本消息 try: - if "" not in quote_data: + if '' not in quote_data: quote_data_message_list.append(platform_message.Plain(quote_data)) else: # 引用消息展开 quote_data_xml = ET.fromstring(quote_data) - if quote_data_xml.find("img"): + if quote_data_xml.find('img'): quote_data_message_list.extend(await self._handler_image(None, quote_data)) - elif quote_data_xml.find("voicemsg"): + elif quote_data_xml.find('voicemsg'): quote_data_message_list.extend(await self._handler_voice(None, quote_data)) - elif quote_data_xml.find("videomsg"): + elif quote_data_xml.find('videomsg'): quote_data_message_list.extend(await self._handler_default(None, quote_data)) # 先不处理 else: # appmsg quote_data_message_list.extend(await self._handler_compound(None, quote_data)) except Exception as e: - self.logger.error(f"处理引用消息异常 expcetion:{e}") + self.logger.error(f'处理引用消息异常 expcetion:{e}') quote_data_message_list.append(platform_message.Plain(quote_data)) message_list.append( platform_message.Quote( @@ -315,15 +275,11 @@ class WeChatPadMessageConverter(adapter.MessageConverter): return platform_message.MessageChain(message_list) - async def _handler_compound_file( - self, - message: dict, - xml_data: ET.Element - ) -> platform_message.MessageChain: + async def _handler_compound_file(self, message: dict, xml_data: ET.Element) -> platform_message.MessageChain: """处理文件消息 (data_type=6)""" file_data = xml_data.find('.//appmsg') - if file_data.findtext('.//type', "") == "74": + if file_data.findtext('.//type', '') == '74': return None else: @@ -346,22 +302,21 @@ class WeChatPadMessageConverter(adapter.MessageConverter): file_data = self.bot.cdn_download(aeskey=aeskey, file_type=5, file_url=cdnthumburl) - file_base64 = file_data["Data"]['FileData'] + file_base64 = file_data['Data']['FileData'] # print(file_data) - file_size = file_data["Data"]['TotalSize'] + file_size = file_data['Data']['TotalSize'] # print(file_base64) - return platform_message.MessageChain([ - platform_message.WeChatFile(file_id=file_id, file_name=file_name, file_size=file_size, - file_base64=file_base64), - platform_message.WeChatForwardFile(xml_data=xml_data_str) - ]) + return platform_message.MessageChain( + [ + platform_message.WeChatFile( + file_id=file_id, file_name=file_name, file_size=file_size, file_base64=file_base64 + ), + platform_message.WeChatForwardFile(xml_data=xml_data_str), + ] + ) - async def _handler_compound_link( - self, - message: dict, - xml_data: ET.Element - ) -> platform_message.MessageChain: + async def _handler_compound_link(self, message: dict, xml_data: ET.Element) -> platform_message.MessageChain: """处理链接消息(如公众号文章、外部网页)""" message_list = [] try: @@ -374,56 +329,38 @@ class WeChatPadMessageConverter(adapter.MessageConverter): link_title=appmsg.findtext('title', ''), link_desc=appmsg.findtext('des', ''), link_url=appmsg.findtext('url', ''), - link_thumb_url=appmsg.findtext("thumburl", '') # 这个字段拿不到 + link_thumb_url=appmsg.findtext('thumburl', ''), # 这个字段拿不到 ) ) # 还没有发链接的接口, 暂时还需要自己构造appmsg, 先用WeChatAppMsg。 - message_list.append( - platform_message.WeChatAppMsg( - app_msg=ET.tostring(appmsg, encoding='unicode') - ) - ) + message_list.append(platform_message.WeChatAppMsg(app_msg=ET.tostring(appmsg, encoding='unicode'))) except Exception as e: - self.logger.error(f"解析链接消息失败: {str(e)}") + self.logger.error(f'解析链接消息失败: {str(e)}') return platform_message.MessageChain(message_list) async def _handler_compound_mini_program( - self, - message: dict, - xml_data: ET.Element + self, message: dict, xml_data: ET.Element ) -> platform_message.MessageChain: """处理小程序消息(如小程序卡片、服务通知)""" xml_data_str = ET.tostring(xml_data, encoding='unicode') - return platform_message.MessageChain([ - platform_message.WeChatForwardMiniPrograms(xml_data=xml_data_str) - ]) + return platform_message.MessageChain([platform_message.WeChatForwardMiniPrograms(xml_data=xml_data_str)]) - async def _handler_default( - self, - message: Optional[dict], - content_no_preifx: str - ) -> platform_message.MessageChain: + async def _handler_default(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain: """处理未知消息类型""" if message: - msg_type = message["msg_type"] + msg_type = message['msg_type'] else: - msg_type = "" - return platform_message.MessageChain([ - platform_message.Unknown(text=f"[未知消息类型 msg_type:{msg_type}]") - ]) + msg_type = '' + return platform_message.MessageChain([platform_message.Unknown(text=f'[未知消息类型 msg_type:{msg_type}]')]) def _handler_compound_unsupported( - self, - message: dict, - xml_data: str, - text: Optional[str] = None + self, message: dict, xml_data: str, text: Optional[str] = None ) -> platform_message.MessageChain: """处理未支持复合消息类型(msg_type=49)子类型""" if not text: - text = f"[xml_data={xml_data}]" + text = f'[xml_data={xml_data}]' content_list = [] - content_list.append( - platform_message.Unknown(text=f"[处理未支持复合消息类型[msg_type=49]|{text}")) + content_list.append(platform_message.Unknown(text=f'[处理未支持复合消息类型[msg_type=49]|{text}')) return platform_message.MessageChain(content_list) @@ -432,7 +369,7 @@ class WeChatPadMessageConverter(adapter.MessageConverter): ats_bot = False try: to_user_name = message['to_user_name']['str'] # 接收方: 所属微信的wxid - raw_content = message["content"]["str"] # 原始消息内容 + raw_content = message['content']['str'] # 原始消息内容 content_no_prefix, _ = self._extract_content_and_sender(raw_content) # 直接艾特机器人(这个有bug,当被引用的消息里面有@bot,会套娃 # ats_bot = ats_bot or (f"@{bot_account_id}" in content_no_prefix) @@ -443,7 +380,7 @@ class WeChatPadMessageConverter(adapter.MessageConverter): msg_source = message.get('msg_source', '') or '' if len(msg_source) > 0: msg_source_data = ET.fromstring(msg_source) - at_user_list = msg_source_data.findtext("atuserlist") or "" + at_user_list = msg_source_data.findtext('atuserlist') or '' ats_bot = ats_bot or (to_user_name in at_user_list) # 引用bot if message.get('msg_type', 0) == 49: @@ -454,7 +391,7 @@ class WeChatPadMessageConverter(adapter.MessageConverter): quote_id = appmsg_data.find('.//refermsg').findtext('.//chatusr') # 引用消息的原发送者 ats_bot = ats_bot or (quote_id == tousername) except Exception as e: - self.logger.error(f"_ats_bot got except: {e}") + self.logger.error(f'_ats_bot got except: {e}') finally: return ats_bot @@ -463,47 +400,41 @@ class WeChatPadMessageConverter(adapter.MessageConverter): try: # 检查消息开头,如果有 wxid_sbitaz0mt65n22:\n 则删掉 # add: 有些用户的wxid不是上述格式。换成user_name: - regex = re.compile(r"^[a-zA-Z0-9_\-]{5,20}:") - line_split = raw_content.split("\n") + regex = re.compile(r'^[a-zA-Z0-9_\-]{5,20}:') + line_split = raw_content.split('\n') if len(line_split) > 0 and regex.match(line_split[0]): - raw_content = "\n".join(line_split[1:]) - sender_id = line_split[0].strip(":") + raw_content = '\n'.join(line_split[1:]) + sender_id = line_split[0].strip(':') return raw_content, sender_id except Exception as e: - self.logger.error(f"_extract_content_and_sender got except: {e}") + self.logger.error(f'_extract_content_and_sender got except: {e}') finally: return raw_content, None # 是否是群消息 def _is_group_message(self, message: dict) -> bool: from_user_name = message['from_user_name']['str'] - return from_user_name.endswith("@chatroom") + return from_user_name.endswith('@chatroom') class WeChatPadEventConverter(adapter.EventConverter): - def __init__(self, config: dict): self.config = config self.message_converter = WeChatPadMessageConverter(config) - self.logger = logging.getLogger("WeChatPadEventConverter") - + self.logger = logging.getLogger('WeChatPadEventConverter') + @staticmethod - async def yiri2target( - event: platform_events.MessageEvent - ) -> dict: + async def yiri2target(event: platform_events.MessageEvent) -> dict: pass - async def target2yiri( - self, - event: dict, - bot_account_id: str - ) -> platform_events.MessageEvent: - + async def target2yiri(self, event: dict, bot_account_id: str) -> platform_events.MessageEvent: # 排除公众号以及微信团队消息 - if event['from_user_name']['str'].startswith('gh_') \ - or event['from_user_name']['str']=='weixin'\ - or event['from_user_name']['str'] == "newsapp"\ - or event['from_user_name']['str'] == self.config["wxid"]: + if ( + event['from_user_name']['str'].startswith('gh_') + or event['from_user_name']['str'] == 'weixin' + or event['from_user_name']['str'] == 'newsapp' + or event['from_user_name']['str'] == self.config['wxid'] + ): return None message_chain = await self.message_converter.target2yiri(copy.deepcopy(event), bot_account_id) @@ -512,7 +443,7 @@ class WeChatPadEventConverter(adapter.EventConverter): if '@chatroom' in event['from_user_name']['str']: # 找出开头的 wxid_ 字符串,以:结尾 - sender_wxid = event['content']['str'].split(":")[0] + sender_wxid = event['content']['str'].split(':')[0] return platform_events.GroupMessage( sender=platform_entities.GroupMember( @@ -524,13 +455,13 @@ class WeChatPadEventConverter(adapter.EventConverter): name=event['from_user_name']['str'], permission=platform_entities.Permission.Member, ), - special_title="", + special_title='', join_timestamp=0, last_speak_timestamp=0, mute_time_remaining=0, ), message_chain=message_chain, - time=event["create_time"], + time=event['create_time'], source_platform_object=event, ) else: @@ -541,13 +472,13 @@ class WeChatPadEventConverter(adapter.EventConverter): remark='', ), message_chain=message_chain, - time=event["create_time"], + time=event['create_time'], source_platform_object=event, ) class WeChatPadAdapter(adapter.MessagePlatformAdapter): - name: str = "WeChatPad" # 定义适配器名称 + name: str = 'WeChatPad' # 定义适配器名称 bot: WeChatPadClient quart_app: quart.Quart @@ -580,27 +511,21 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter): # self.ap.logger.debug(f"Gewechat callback event: {data}") # print(data) - try: event = await self.event_converter.target2yiri(data.copy(), self.bot_account_id) - except Exception as e: - await self.logger.error(f"Error in wechatpad callback: {traceback.format_exc()}") + except Exception: + await self.logger.error(f'Error in wechatpad callback: {traceback.format_exc()}') if event.__class__ in self.listeners: await self.listeners[event.__class__](event, self) return 'ok' - - async def _handle_message( - self, - message: platform_message.MessageChain, - target_id: str - ): + async def _handle_message(self, message: platform_message.MessageChain, target_id: str): """统一消息处理核心逻辑""" content_list = await self.message_converter.yiri2target(message) # print(content_list) - at_targets = [item["target"] for item in content_list if item["type"] == "at"] + at_targets = [item['target'] for item in content_list if item['type'] == 'at'] # print(at_targets) # 处理@逻辑 at_targets = at_targets or [] @@ -608,7 +533,7 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter): if at_targets: member_info = self.bot.get_chatroom_member_detail( target_id, - )["Data"]["member_data"]["chatroom_member_list"] + )['Data']['member_data']['chatroom_member_list'] # 处理消息组件 for msg in content_list: @@ -616,63 +541,51 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter): if msg['type'] == 'text' and at_targets: at_nick_name_list = [] for member in member_info: - if member["user_name"] in at_targets: + if member['user_name'] in at_targets: at_nick_name_list.append(f'@{member["nick_name"]}') msg['content'] = f'{" ".join(at_nick_name_list)} {msg["content"]}' # 统一消息派发 handler_map = { 'text': lambda msg: self.bot.send_text_message( - to_wxid=target_id, - message=msg['content'], - ats=at_targets + to_wxid=target_id, message=msg['content'], ats=at_targets ), 'image': lambda msg: self.bot.send_image_message( - to_wxid=target_id, - img_url=msg["image"], - ats = at_targets + to_wxid=target_id, img_url=msg['image'], ats=at_targets ), 'WeChatEmoji': lambda msg: self.bot.send_emoji_message( - to_wxid=target_id, - emoji_md5=msg['emoji_md5'], - emoji_size=msg['emoji_size'] + to_wxid=target_id, emoji_md5=msg['emoji_md5'], emoji_size=msg['emoji_size'] ), - 'voice': lambda msg: self.bot.send_voice_message( to_wxid=target_id, voice_data=msg['data'], - voice_duration=msg["duration"], - voice_forma=msg["forma"], + voice_duration=msg['duration'], + voice_forma=msg['forma'], ), 'WeChatAppMsg': lambda msg: self.bot.send_app_message( to_wxid=target_id, app_message=msg['app_msg'], type=0, ), - 'at': lambda msg: None + 'at': lambda msg: None, } if handler := handler_map.get(msg['type']): handler(msg) # self.ap.logger.warning(f"未处理的消息类型: {ret}") else: - self.ap.logger.warning(f"未处理的消息类型: {msg['type']}") + self.ap.logger.warning(f'未处理的消息类型: {msg["type"]}') continue - async def send_message( - self, - target_type: str, - target_id: str, - message: platform_message.MessageChain - ): + async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain): """主动发送消息""" return await self._handle_message(message, target_id) async def reply_message( - self, - message_source: platform_events.MessageEvent, - message: platform_message.MessageChain, - quote_origin: bool = False + self, + message_source: platform_events.MessageEvent, + message: platform_message.MessageChain, + quote_origin: bool = False, ): """回复消息""" if message_source.source_platform_object: @@ -683,58 +596,49 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter): pass def register_listener( - self, - event_type: typing.Type[platform_events.Event], - callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None] + self, + event_type: typing.Type[platform_events.Event], + callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], ): self.listeners[event_type] = callback def unregister_listener( - self, - event_type: typing.Type[platform_events.Event], - callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None] + self, + event_type: typing.Type[platform_events.Event], + callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], ): pass async def run_async(self): - - if not self.config["admin_key"] and not self.config["token"]: - raise RuntimeError("无wechatpad管理密匙,请填入配置文件后重启") + if not self.config['admin_key'] and not self.config['token']: + raise RuntimeError('无wechatpad管理密匙,请填入配置文件后重启') else: - if self.config["token"]: - self.bot = WeChatPadClient( - self.config['wechatpad_url'], - self.config["token"] - ) + if self.config['token']: + self.bot = WeChatPadClient(self.config['wechatpad_url'], self.config['token']) data = self.bot.get_login_status() self.ap.logger.info(data) - if data["Code"] == 300 and data["Text"] == "你已退出微信": + if data['Code'] == 300 and data['Text'] == '你已退出微信': response = requests.post( - f"{self.config['wechatpad_url']}/admin/GenAuthKey1?key={self.config['admin_key']}", - json={"Count": 1, "Days": 365} + f'{self.config["wechatpad_url"]}/admin/GenAuthKey1?key={self.config["admin_key"]}', + json={'Count': 1, 'Days': 365}, ) if response.status_code != 200: - raise Exception(f"获取token失败: {response.text}") - self.config["token"] = response.json()["Data"][0] + raise Exception(f'获取token失败: {response.text}') + self.config['token'] = response.json()['Data'][0] - elif not self.config["token"]: + elif not self.config['token']: response = requests.post( - f"{self.config['wechatpad_url']}/admin/GenAuthKey1?key={self.config['admin_key']}", - json={"Count": 1, "Days": 365} + f'{self.config["wechatpad_url"]}/admin/GenAuthKey1?key={self.config["admin_key"]}', + json={'Count': 1, 'Days': 365}, ) if response.status_code != 200: - raise Exception(f"获取token失败: {response.text}") - self.config["token"] = response.json()["Data"][0] + raise Exception(f'获取token失败: {response.text}') + self.config['token'] = response.json()['Data'][0] - self.bot = WeChatPadClient( - self.config['wechatpad_url'], - self.config["token"], - logger=self.logger - ) - self.ap.logger.info(self.config["token"]) + self.bot = WeChatPadClient(self.config['wechatpad_url'], self.config['token'], logger=self.logger) + self.ap.logger.info(self.config['token']) thread_1 = threading.Event() - def wechat_login_process(): # 不登录,这些先注释掉,避免登陆态尝试拉qrcode。 # login_data =self.bot.get_login_qr() @@ -742,67 +646,54 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter): # url = login_data['Data']["QrCodeUrl"] # self.ap.logger.info(login_data) - - profile =self.bot.get_profile() + profile = self.bot.get_profile() self.ap.logger.info(profile) - self.bot_account_id = profile["Data"]["userInfo"]["nickName"]["str"] - self.config["wxid"] = profile["Data"]["userInfo"]["userName"]["str"] + self.bot_account_id = profile['Data']['userInfo']['nickName']['str'] + self.config['wxid'] = profile['Data']['userInfo']['userName']['str'] thread_1.set() - # asyncio.create_task(wechat_login_process) threading.Thread(target=wechat_login_process).start() def connect_websocket_sync() -> None: - thread_1.wait() - uri = f"{self.config['wechatpad_ws']}/GetSyncMsg?key={self.config['token']}" - self.ap.logger.info(f"Connecting to WebSocket: {uri}") + uri = f'{self.config["wechatpad_ws"]}/GetSyncMsg?key={self.config["token"]}' + self.ap.logger.info(f'Connecting to WebSocket: {uri}') + def on_message(ws, message): try: data = json.loads(message) - self.ap.logger.debug(f"Received message: {data}") + self.ap.logger.debug(f'Received message: {data}') # 这里需要确保ws_message是同步的,或者使用asyncio.run调用异步方法 asyncio.run(self.ws_message(data)) except json.JSONDecodeError: - self.ap.logger.error(f"Non-JSON message: {message[:100]}...") + self.ap.logger.error(f'Non-JSON message: {message[:100]}...') def on_error(ws, error): - self.ap.logger.error(f"WebSocket error: {str(error)[:200]}") + self.ap.logger.error(f'WebSocket error: {str(error)[:200]}') def on_close(ws, close_status_code, close_msg): - self.ap.logger.info("WebSocket closed, reconnecting...") + self.ap.logger.info('WebSocket closed, reconnecting...') time.sleep(5) connect_websocket_sync() # 自动重连 def on_open(ws): - self.ap.logger.info("WebSocket connected successfully!") + self.ap.logger.info('WebSocket connected successfully!') ws = websocket.WebSocketApp( - uri, - on_message=on_message, - on_error=on_error, - on_close=on_close, - on_open=on_open - ) - ws.run_forever( - ping_interval=60, - ping_timeout=20 + uri, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open ) + ws.run_forever(ping_interval=60, ping_timeout=20) # 直接调用同步版本(会阻塞) # connect_websocket_sync() # 这行代码会在WebSocket连接断开后才会执行 # self.ap.logger.info("WebSocket client thread started") - thread = threading.Thread( - target=connect_websocket_sync, - name="WebSocketClientThread", - daemon=True - ) + thread = threading.Thread(target=connect_websocket_sync, name='WebSocketClientThread', daemon=True) thread.start() - self.ap.logger.info("WebSocket client thread started") + self.ap.logger.info('WebSocket client thread started') async def kill(self) -> bool: pass diff --git a/pkg/platform/sources/wecom.py b/pkg/platform/sources/wecom.py index f1cc677e..7be05a85 100644 --- a/pkg/platform/sources/wecom.py +++ b/pkg/platform/sources/wecom.py @@ -157,7 +157,7 @@ class WecomAdapter(adapter.MessagePlatformAdapter): token=config['token'], EncodingAESKey=config['EncodingAESKey'], contacts_secret=config['contacts_secret'], - logger=self.logger + logger=self.logger, ) async def reply_message( @@ -201,8 +201,8 @@ class WecomAdapter(adapter.MessagePlatformAdapter): self.bot_account_id = event.receiver_id try: return await callback(await self.event_converter.target2yiri(event), self) - except Exception as e: - await self.logger.error(f"Error in wecom callback: {traceback.format_exc()}") + except Exception: + await self.logger.error(f'Error in wecom callback: {traceback.format_exc()}') if event_type == platform_events.FriendMessage: self.bot.on_message('text')(on_message) diff --git a/pkg/platform/sources/wecomcs.py b/pkg/platform/sources/wecomcs.py index aab8d394..da84ac6d 100644 --- a/pkg/platform/sources/wecomcs.py +++ b/pkg/platform/sources/wecomcs.py @@ -145,7 +145,7 @@ class WecomCSAdapter(adapter.MessagePlatformAdapter): secret=config['secret'], token=config['token'], EncodingAESKey=config['EncodingAESKey'], - logger=self.logger + logger=self.logger, ) async def reply_message( @@ -178,8 +178,8 @@ class WecomCSAdapter(adapter.MessagePlatformAdapter): self.bot_account_id = event.receiver_id try: return await callback(await self.event_converter.target2yiri(event), self) - except Exception as e: - await self.logger.error(f"Error in wecomcs callback: {traceback.format_exc()}") + except Exception: + await self.logger.error(f'Error in wecomcs callback: {traceback.format_exc()}') if event_type == platform_events.FriendMessage: self.bot.on_message('text')(on_message) diff --git a/pkg/rag/knowledge/services/database.py b/pkg/rag/knowledge/services/database.py index a8c35883..35a52453 100644 --- a/pkg/rag/knowledge/services/database.py +++ b/pkg/rag/knowledge/services/database.py @@ -1,19 +1,20 @@ from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, ForeignKey, LargeBinary from sqlalchemy.orm import declarative_base, sessionmaker, relationship from datetime import datetime -import numpy as np # 用于处理从LargeBinary转换回来的embedding Base = declarative_base() + class KnowledgeBase(Base): __tablename__ = 'kb' id = Column(Integer, primary_key=True, index=True) name = Column(String, index=True) description = Column(Text) created_at = Column(DateTime, default=datetime.utcnow) - embedding_model = Column(String, default="") # 默认嵌入模型 + embedding_model = Column(String, default='') # 默认嵌入模型 top_k = Column(Integer, default=5) # 默认返回的top_k数量 - files = relationship("File", back_populates="knowledge_base") + files = relationship('File', back_populates='knowledge_base') + class File(Base): __tablename__ = 'file' @@ -24,8 +25,9 @@ class File(Base): created_at = Column(DateTime, default=datetime.utcnow) file_type = Column(String) status = Column(Integer, default=0) # 0: 未处理, 1: 处理中, 2: 已处理, 3: 错误 - knowledge_base = relationship("KnowledgeBase", back_populates="files") - chunks = relationship("Chunk", back_populates="file") + knowledge_base = relationship('KnowledgeBase', back_populates='files') + chunks = relationship('Chunk', back_populates='file') + class Chunk(Base): __tablename__ = 'chunks' @@ -33,26 +35,30 @@ class Chunk(Base): file_id = Column(Integer, ForeignKey('file.id')) text = Column(Text) - file = relationship("File", back_populates="chunks") - vector = relationship("Vector", uselist=False, back_populates="chunk") # One-to-one + file = relationship('File', back_populates='chunks') + vector = relationship('Vector', uselist=False, back_populates='chunk') # One-to-one + class Vector(Base): __tablename__ = 'vectors' id = Column(Integer, primary_key=True, index=True) chunk_id = Column(Integer, ForeignKey('chunks.id'), unique=True) - embedding = Column(LargeBinary) # Store embeddings as binary + embedding = Column(LargeBinary) # Store embeddings as binary + + chunk = relationship('Chunk', back_populates='vector') - chunk = relationship("Chunk", back_populates="vector") # 数据库连接 -DATABASE_URL = "sqlite:///./knowledge_base.db" # 生产环境请更换为 PostgreSQL/MySQL -engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False} if "sqlite" in DATABASE_URL else {}) +DATABASE_URL = 'sqlite:///./knowledge_base.db' # 生产环境请更换为 PostgreSQL/MySQL +engine = create_engine(DATABASE_URL, connect_args={'check_same_thread': False} if 'sqlite' in DATABASE_URL else {}) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + # 创建所有表 (可以在应用启动时执行一次) def create_db_and_tables(): Base.metadata.create_all(bind=engine) - print("Database tables created/checked.") + print('Database tables created/checked.') + # 定义嵌入维度(请根据你实际使用的模型调整) -EMBEDDING_DIM = 1024 \ No newline at end of file +EMBEDDING_DIM = 1024 diff --git a/pkg/rag/knowledge/services/embedding_models.py b/pkg/rag/knowledge/services/embedding_models.py index a6ce73ae..7301d640 100644 --- a/pkg/rag/knowledge/services/embedding_models.py +++ b/pkg/rag/knowledge/services/embedding_models.py @@ -1,14 +1,15 @@ # services/embedding_models.py import os -from typing import Dict, Any, List, Type, Optional +from typing import Dict, Any, List import logging -import aiohttp # Import aiohttp for asynchronous requests +import aiohttp # Import aiohttp for asynchronous requests import asyncio from sentence_transformers import SentenceTransformer logger = logging.getLogger(__name__) + # Base class for all embedding models class BaseEmbeddingModel: def __init__(self, model_name: str): @@ -27,9 +28,10 @@ class BaseEmbeddingModel: def embedding_dimension(self) -> int: """Returns the embedding dimension of the model.""" if self._embedding_dimension is None: - raise NotImplementedError("Embedding dimension not set for this model.") + raise NotImplementedError('Embedding dimension not set for this model.') return self._embedding_dimension - + + class EmbeddingModelFactory: @staticmethod def create_model(model_type: str, model_name_key: str) -> BaseEmbeddingModel: @@ -39,26 +41,29 @@ class EmbeddingModelFactory: """ if model_name_key not in EMBEDDING_MODEL_CONFIGS: raise ValueError(f"Embedding model configuration '{model_name_key}' not found in EMBEDDING_MODEL_CONFIGS.") - + config = EMBEDDING_MODEL_CONFIGS[model_name_key] - - if config['type'] == "third_party_api": + + if config['type'] == 'third_party_api': required_keys = ['api_endpoint', 'headers', 'payload_template', 'embedding_dimension'] if not all(key in config for key in required_keys): - raise ValueError(f"Missing configuration keys for third_party_api model '{model_name_key}'. Required: {required_keys}") - + raise ValueError( + f"Missing configuration keys for third_party_api model '{model_name_key}'. Required: {required_keys}" + ) + # Retrieve model_name from config if it differs from model_name_key # Some APIs expect a specific 'model' value in the payload that might be different from the key - api_model_name = config.get('model_name', model_name_key) + api_model_name = config.get('model_name', model_name_key) return ThirdPartyAPIEmbeddingModel( - model_name=api_model_name, # Use the model_name from config or the key + model_name=api_model_name, # Use the model_name from config or the key api_endpoint=config['api_endpoint'], headers=config['headers'], payload_template=config['payload_template'], - embedding_dimension=config['embedding_dimension'] + embedding_dimension=config['embedding_dimension'], ) + class SentenceTransformerEmbeddingModel(BaseEmbeddingModel): def __init__(self, model_name: str): super().__init__(model_name) @@ -68,9 +73,11 @@ class SentenceTransformerEmbeddingModel(BaseEmbeddingModel): # if not run in a separate thread/process, but this keeps the API consistent. self.model = SentenceTransformer(model_name) self._embedding_dimension = self.model.get_sentence_embedding_dimension() - logger.info(f"Initialized SentenceTransformer model '{model_name}' with dimension {self._embedding_dimension}") + logger.info( + f"Initialized SentenceTransformer model '{model_name}' with dimension {self._embedding_dimension}" + ) except Exception as e: - logger.error(f"Failed to load SentenceTransformer model {model_name}: {e}") + logger.error(f'Failed to load SentenceTransformer model {model_name}: {e}') raise async def embed_documents(self, texts: List[str]) -> List[List[float]]: @@ -84,14 +91,23 @@ class SentenceTransformerEmbeddingModel(BaseEmbeddingModel): class ThirdPartyAPIEmbeddingModel(BaseEmbeddingModel): - def __init__(self, model_name: str, api_endpoint: str, headers: Dict[str, str], payload_template: Dict[str, Any], embedding_dimension: int): + def __init__( + self, + model_name: str, + api_endpoint: str, + headers: Dict[str, str], + payload_template: Dict[str, Any], + embedding_dimension: int, + ): super().__init__(model_name) self.api_endpoint = api_endpoint self.headers = headers self.payload_template = payload_template self._embedding_dimension = embedding_dimension - self.session = None # aiohttp client session will be initialized on first use or in a context manager - logger.info(f"Initialized ThirdPartyAPIEmbeddingModel '{model_name}' for async calls to {api_endpoint} with dimension {embedding_dimension}") + self.session = None # aiohttp client session will be initialized on first use or in a context manager + logger.info( + f"Initialized ThirdPartyAPIEmbeddingModel '{model_name}' for async calls to {api_endpoint} with dimension {embedding_dimension}" + ) async def _get_session(self): """Lazily create or return the aiohttp client session.""" @@ -104,7 +120,7 @@ class ThirdPartyAPIEmbeddingModel(BaseEmbeddingModel): if self.session and not self.session.closed: await self.session.close() self.session = None - logger.info(f"Closed aiohttp session for model {self.model_name}") + logger.info(f'Closed aiohttp session for model {self.model_name}') async def embed_documents(self, texts: List[str]) -> List[List[float]]: """Asynchronously embeds a list of texts using the third-party API.""" @@ -118,10 +134,10 @@ class ThirdPartyAPIEmbeddingModel(BaseEmbeddingModel): elif 'texts' in payload: payload['texts'] = [text] else: - raise ValueError("Payload template does not contain expected text input key.") + raise ValueError('Payload template does not contain expected text input key.') tasks.append(self._make_api_request(session, payload)) - + results = await asyncio.gather(*tasks, return_exceptions=True) for i, res in enumerate(results): @@ -131,93 +147,92 @@ class ThirdPartyAPIEmbeddingModel(BaseEmbeddingModel): # - Append None or an empty list # - Re-raise the exception to stop processing # - Log and skip, then continue - embeddings.append([0.0] * self.embedding_dimension) # Append dummy embedding or handle failure + embeddings.append([0.0] * self.embedding_dimension) # Append dummy embedding or handle failure else: embeddings.append(res) - + return embeddings async def _make_api_request(self, session: aiohttp.ClientSession, payload: Dict[str, Any]) -> List[float]: """Helper to make an asynchronous API request and extract embedding.""" try: async with session.post(self.api_endpoint, headers=self.headers, json=payload) as response: - response.raise_for_status() # Raise an exception for HTTP errors (4xx, 5xx) + response.raise_for_status() # Raise an exception for HTTP errors (4xx, 5xx) api_response = await response.json() - + # Adjust this based on your API's actual response structure - if "data" in api_response and len(api_response["data"]) > 0 and "embedding" in api_response["data"][0]: - embedding = api_response["data"][0]["embedding"] + if 'data' in api_response and len(api_response['data']) > 0 and 'embedding' in api_response['data'][0]: + embedding = api_response['data'][0]['embedding'] if len(embedding) != self.embedding_dimension: - logger.warning(f"API returned embedding of dimension {len(embedding)}, but expected {self.embedding_dimension} for model {self.model_name}. Adjusting config might be needed.") + logger.warning( + f'API returned embedding of dimension {len(embedding)}, but expected {self.embedding_dimension} for model {self.model_name}. Adjusting config might be needed.' + ) return embedding - elif "embeddings" in api_response and isinstance(api_response["embeddings"], list) and api_response["embeddings"]: - embedding = api_response["embeddings"][0] + elif ( + 'embeddings' in api_response + and isinstance(api_response['embeddings'], list) + and api_response['embeddings'] + ): + embedding = api_response['embeddings'][0] if len(embedding) != self.embedding_dimension: - logger.warning(f"API returned embedding of dimension {len(embedding)}, but expected {self.embedding_dimension} for model {self.model_name}. Adjusting config might be needed.") + logger.warning( + f'API returned embedding of dimension {len(embedding)}, but expected {self.embedding_dimension} for model {self.model_name}. Adjusting config might be needed.' + ) return embedding else: - raise ValueError(f"Unexpected API response structure: {api_response}") + raise ValueError(f'Unexpected API response structure: {api_response}') except aiohttp.ClientError as e: - raise ConnectionError(f"API request failed: {e}") from e + raise ConnectionError(f'API request failed: {e}') from e except ValueError as e: - raise ValueError(f"Error processing API response: {e}") from e - + raise ValueError(f'Error processing API response: {e}') from e async def embed_query(self, text: str) -> List[float]: """Asynchronously embeds a single query text.""" results = await self.embed_documents([text]) if results: return results[0] - return [] # Or raise an error if embedding a query must always succeed + return [] # Or raise an error if embedding a query must always succeed + # --- Embedding Model Configuration --- EMBEDDING_MODEL_CONFIGS: Dict[str, Dict[str, Any]] = { - "MiniLM": { # Example for a local Sentence Transformer model - "type": "sentence_transformer", - "model_name": "sentence-transformers/all-MiniLM-L6-v2" + 'MiniLM': { # Example for a local Sentence Transformer model + 'type': 'sentence_transformer', + 'model_name': 'sentence-transformers/all-MiniLM-L6-v2', }, - "bge-m3": { # Example for a third-party API model - "type": "third_party_api", - "model_name": "bge-m3", - "api_endpoint": "https://api.qhaigc.net/v1/embeddings", - "headers": { - "Content-Type": "application/json", - "Authorization": f"Bearer {os.getenv('rag_api_key')}" - }, - "payload_template": { - "model": "bge-m3", - "input": "" - }, - "embedding_dimension": 1024 + 'bge-m3': { # Example for a third-party API model + 'type': 'third_party_api', + 'model_name': 'bge-m3', + 'api_endpoint': 'https://api.qhaigc.net/v1/embeddings', + 'headers': {'Content-Type': 'application/json', 'Authorization': f'Bearer {os.getenv("rag_api_key")}'}, + 'payload_template': {'model': 'bge-m3', 'input': ''}, + 'embedding_dimension': 1024, }, - "OpenAI-Ada-002": { - "type": "third_party_api", - "model_name": "text-embedding-ada-002", - "api_endpoint": "https://api.openai.com/v1/embeddings", - "headers": { - "Content-Type": "application/json", - "Authorization": f"Bearer {os.getenv('OPENAI_API_KEY')}" # Ensure OPENAI_API_KEY is set + 'OpenAI-Ada-002': { + 'type': 'third_party_api', + 'model_name': 'text-embedding-ada-002', + 'api_endpoint': 'https://api.openai.com/v1/embeddings', + 'headers': { + 'Content-Type': 'application/json', + 'Authorization': f'Bearer {os.getenv("OPENAI_API_KEY")}', # Ensure OPENAI_API_KEY is set }, - "payload_template": { - "model": "text-embedding-ada-002", - "input": "" # Text will be injected here + 'payload_template': { + 'model': 'text-embedding-ada-002', + 'input': '', # Text will be injected here }, - "embedding_dimension": 1536 + 'embedding_dimension': 1536, }, - "OpenAI-Embedding-3-Small": { - "type": "third_party_api", - "model_name": "text-embedding-3-small", - "api_endpoint": "https://api.openai.com/v1/embeddings", - "headers": { - "Content-Type": "application/json", - "Authorization": f"Bearer {os.getenv('OPENAI_API_KEY')}" - }, - "payload_template": { - "model": "text-embedding-3-small", - "input": "", + 'OpenAI-Embedding-3-Small': { + 'type': 'third_party_api', + 'model_name': 'text-embedding-3-small', + 'api_endpoint': 'https://api.openai.com/v1/embeddings', + 'headers': {'Content-Type': 'application/json', 'Authorization': f'Bearer {os.getenv("OPENAI_API_KEY")}'}, + 'payload_template': { + 'model': 'text-embedding-3-small', + 'input': '', # "dimensions": 512 # Optional: uncomment if you want a specific output dimension }, - "embedding_dimension": 1536 # Default max dimension for text-embedding-3-small + 'embedding_dimension': 1536, # Default max dimension for text-embedding-3-small }, -} \ No newline at end of file +} diff --git a/pkg/rag/knowledge/services/parser.py b/pkg/rag/knowledge/services/parser.py index 5fa7d589..bea49721 100644 --- a/pkg/rag/knowledge/services/parser.py +++ b/pkg/rag/knowledge/services/parser.py @@ -1,22 +1,21 @@ - import PyPDF2 from docx import Document import pandas as pd -import csv import chardet -from typing import Union, List, Callable, Any +from typing import Union, Callable, Any import logging import markdown from bs4 import BeautifulSoup import ebooklib from ebooklib import epub import re -import asyncio # Import asyncio for async operations +import asyncio # Import asyncio for async operations import os # Configure logging logger = logging.getLogger(__name__) + class FileParser: """ A robust file parser class to extract text content from various document formats. @@ -24,8 +23,8 @@ class FileParser: All core file reading operations are designed to be run synchronously in a thread pool to avoid blocking the asyncio event loop. """ + def __init__(self): - self.logger = logging.getLogger(self.__class__.__name__) async def _run_sync(self, sync_func: Callable, *args: Any, **kwargs: Any) -> Any: @@ -36,14 +35,14 @@ class FileParser: try: return await asyncio.to_thread(sync_func, *args, **kwargs) except Exception as e: - self.logger.error(f"Error running synchronous function {sync_func.__name__}: {e}") + self.logger.error(f'Error running synchronous function {sync_func.__name__}: {e}') raise async def parse(self, file_path: str) -> Union[str, None]: """ Parses the file based on its extension and returns the extracted text content. This is the main asynchronous entry point for parsing. - + Args: file_path (str): The path to the file to be parsed. @@ -51,21 +50,21 @@ class FileParser: Union[str, None]: The extracted text content as a single string, or None if parsing fails. """ if not file_path or not os.path.exists(file_path): - self.logger.error(f"Invalid file path provided: {file_path}") + self.logger.error(f'Invalid file path provided: {file_path}') return None file_extension = file_path.split('.')[-1].lower() parser_method = getattr(self, f'_parse_{file_extension}', None) - + if parser_method is None: - self.logger.error(f"Unsupported file format: {file_extension} for file {file_path}") + self.logger.error(f'Unsupported file format: {file_extension} for file {file_path}') return None - + try: # Pass file_path to the specific parser methods return await parser_method(file_path) except Exception as e: - self.logger.error(f"Failed to parse {file_extension} file {file_path}: {e}") + self.logger.error(f'Failed to parse {file_extension} file {file_path}: {e}') return None # --- Helper for reading files with encoding detection --- @@ -74,15 +73,16 @@ class FileParser: Reads a file with automatic encoding detection, ensuring the synchronous file read operation runs in a separate thread. """ + def _read_sync(): with open(file_path, 'rb') as file: raw_data = file.read() detected = chardet.detect(raw_data) encoding = detected['encoding'] or 'utf-8' - + if mode == 'r': return raw_data.decode(encoding, errors='ignore') - return raw_data # For binary mode + return raw_data # For binary mode return await self._run_sync(_read_sync) @@ -90,12 +90,13 @@ class FileParser: async def _parse_txt(self, file_path: str) -> str: """Parses a TXT file and returns its content.""" - self.logger.info(f"Parsing TXT file: {file_path}") + self.logger.info(f'Parsing TXT file: {file_path}') return await self._read_file_content(file_path, mode='r') async def _parse_pdf(self, file_path: str) -> str: """Parses a PDF file and returns its text content.""" - self.logger.info(f"Parsing PDF file: {file_path}") + self.logger.info(f'Parsing PDF file: {file_path}') + def _parse_pdf_sync(): text_content = [] with open(file_path, 'rb') as file: @@ -105,57 +106,69 @@ class FileParser: if text: text_content.append(text) return '\n'.join(text_content) + return await self._run_sync(_parse_pdf_sync) async def _parse_docx(self, file_path: str) -> str: """Parses a DOCX file and returns its text content.""" - self.logger.info(f"Parsing DOCX file: {file_path}") + self.logger.info(f'Parsing DOCX file: {file_path}') + def _parse_docx_sync(): doc = Document(file_path) text_content = [paragraph.text for paragraph in doc.paragraphs if paragraph.text.strip()] return '\n'.join(text_content) + return await self._run_sync(_parse_docx_sync) - + async def _parse_doc(self, file_path: str) -> str: """Handles .doc files, explicitly stating lack of direct support.""" - self.logger.warning(f"Direct .doc parsing is not supported for {file_path}. Please convert to .docx first.") - raise NotImplementedError("Direct .doc parsing not supported. Please convert to .docx first.") - + self.logger.warning(f'Direct .doc parsing is not supported for {file_path}. Please convert to .docx first.') + raise NotImplementedError('Direct .doc parsing not supported. Please convert to .docx first.') + async def _parse_xlsx(self, file_path: str) -> str: """Parses an XLSX file, returning text from all sheets.""" - self.logger.info(f"Parsing XLSX file: {file_path}") + self.logger.info(f'Parsing XLSX file: {file_path}') + def _parse_xlsx_sync(): excel_file = pd.ExcelFile(file_path) all_sheet_content = [] for sheet_name in excel_file.sheet_names: df = pd.read_excel(file_path, sheet_name=sheet_name) - sheet_text = f"--- Sheet: {sheet_name} ---\n{df.to_string(index=False)}\n" + sheet_text = f'--- Sheet: {sheet_name} ---\n{df.to_string(index=False)}\n' all_sheet_content.append(sheet_text) return '\n'.join(all_sheet_content) + return await self._run_sync(_parse_xlsx_sync) - + async def _parse_csv(self, file_path: str) -> str: """Parses a CSV file and returns its content as a string.""" - self.logger.info(f"Parsing CSV file: {file_path}") + self.logger.info(f'Parsing CSV file: {file_path}') + def _parse_csv_sync(): # pd.read_csv can often detect encoding, but explicit detection is safer - raw_data = self._read_file_content(file_path, mode='rb') # Note: this will need to be await outside this sync function + raw_data = self._read_file_content( + file_path, mode='rb' + ) # Note: this will need to be await outside this sync function + _ = raw_data # For simplicity, we'll let pandas handle encoding internally after a raw read. # A more robust solution might pass encoding directly to pd.read_csv after detection. detected = chardet.detect(open(file_path, 'rb').read()) encoding = detected['encoding'] or 'utf-8' df = pd.read_csv(file_path, encoding=encoding) return df.to_string(index=False) + return await self._run_sync(_parse_csv_sync) - + async def _parse_markdown(self, file_path: str) -> str: """Parses a Markdown file, converting it to structured plain text.""" - self.logger.info(f"Parsing Markdown file: {file_path}") + self.logger.info(f'Parsing Markdown file: {file_path}') + def _parse_markdown_sync(): - md_content = self._read_file_content(file_path, mode='r') # This is a synchronous call within a sync function + md_content = self._read_file_content( + file_path, mode='r' + ) # This is a synchronous call within a sync function html_content = markdown.markdown( - md_content, - extensions=['extra', 'codehilite', 'tables', 'toc', 'fenced_code'] + md_content, extensions=['extra', 'codehilite', 'tables', 'toc', 'fenced_code'] ) soup = BeautifulSoup(html_content, 'html.parser') text_parts = [] @@ -169,13 +182,13 @@ class FileParser: text_parts.append(text) elif element.name in ['ul', 'ol']: for li in element.find_all('li'): - text_parts.append(f"* {li.get_text().strip()}") + text_parts.append(f'* {li.get_text().strip()}') elif element.name == 'pre': code_block = element.get_text().strip() if code_block: - text_parts.append(f"```\n{code_block}\n```") + text_parts.append(f'```\n{code_block}\n```') elif element.name == 'table': - table_str = self._extract_table_to_markdown_sync(element) # Call sync helper + table_str = self._extract_table_to_markdown_sync(element) # Call sync helper if table_str: text_parts.append(table_str) elif element.name: @@ -184,15 +197,17 @@ class FileParser: text_parts.append(text) cleaned_text = re.sub(r'\n\s*\n', '\n\n', '\n'.join(text_parts)) return cleaned_text.strip() + return await self._run_sync(_parse_markdown_sync) async def _parse_html(self, file_path: str) -> str: """Parses an HTML file, extracting structured plain text.""" - self.logger.info(f"Parsing HTML file: {file_path}") + self.logger.info(f'Parsing HTML file: {file_path}') + def _parse_html_sync(): - html_content = self._read_file_content(file_path, mode='r') # Sync call within sync function + html_content = self._read_file_content(file_path, mode='r') # Sync call within sync function soup = BeautifulSoup(html_content, 'html.parser') - for script_or_style in soup(["script", "style"]): + for script_or_style in soup(['script', 'style']): script_or_style.decompose() text_parts = [] for element in soup.body.children if soup.body else soup.children: @@ -207,9 +222,9 @@ class FileParser: for li in element.find_all('li'): text = li.get_text().strip() if text: - text_parts.append(f"* {text}") + text_parts.append(f'* {text}') elif element.name == 'table': - table_str = self._extract_table_to_markdown_sync(element) # Call sync helper + table_str = self._extract_table_to_markdown_sync(element) # Call sync helper if table_str: text_parts.append(table_str) elif element.name: @@ -218,39 +233,42 @@ class FileParser: text_parts.append(text) cleaned_text = re.sub(r'\n\s*\n', '\n\n', '\n'.join(text_parts)) return cleaned_text.strip() + return await self._run_sync(_parse_html_sync) - + async def _parse_epub(self, file_path: str) -> str: """Parses an EPUB file, extracting metadata and content.""" - self.logger.info(f"Parsing EPUB file: {file_path}") + self.logger.info(f'Parsing EPUB file: {file_path}') + def _parse_epub_sync(): book = epub.read_epub(file_path) text_content = [] title_meta = book.get_metadata('DC', 'title') if title_meta: - text_content.append(f"Title: {title_meta[0][0]}") + text_content.append(f'Title: {title_meta[0][0]}') creator_meta = book.get_metadata('DC', 'creator') if creator_meta: - text_content.append(f"Author: {creator_meta[0][0]}") + text_content.append(f'Author: {creator_meta[0][0]}') date_meta = book.get_metadata('DC', 'date') if date_meta: - text_content.append(f"Publish Date: {date_meta[0][0]}") + text_content.append(f'Publish Date: {date_meta[0][0]}') toc = book.get_toc() if toc: - text_content.append("\n--- Table of Contents ---") - self._add_toc_items_sync(toc, text_content, level=0) # Call sync helper - text_content.append("--- End of Table of Contents ---\n") + text_content.append('\n--- Table of Contents ---') + self._add_toc_items_sync(toc, text_content, level=0) # Call sync helper + text_content.append('--- End of Table of Contents ---\n') for item in book.get_items(): if item.get_type() == ebooklib.ITEM_DOCUMENT: html_content = item.get_content().decode('utf-8', errors='ignore') soup = BeautifulSoup(html_content, 'html.parser') - for junk in soup(["script", "style", "nav", "header", "footer"]): + for junk in soup(['script', 'style', 'nav', 'header', 'footer']): junk.decompose() text = soup.get_text(separator='\n', strip=True) text = re.sub(r'\n\s*\n', '\n\n', text) if text: text_content.append(text) return re.sub(r'\n\s*\n', '\n\n', '\n'.join(text_content)).strip() + return await self._run_sync(_parse_epub_sync) def _add_toc_items_sync(self, toc_list: list, text_content: list, level: int): @@ -259,10 +277,10 @@ class FileParser: for item in toc_list: if isinstance(item, tuple): chapter, subchapters = item - text_content.append(f"{indent}- {chapter.title}") + text_content.append(f'{indent}- {chapter.title}') self._add_toc_items_sync(subchapters, text_content, level + 1) else: - text_content.append(f"{indent}- {item.title}") + text_content.append(f'{indent}- {item.title}') def _extract_table_to_markdown_sync(self, table_element: BeautifulSoup) -> str: """Helper to convert a BeautifulSoup table element into a Markdown table string (synchronous).""" @@ -272,17 +290,17 @@ class FileParser: cells = [td.get_text().strip() for td in tr.find_all('td')] if cells: rows.append(cells) - + if not headers and not rows: - return "" + return '' table_lines = [] if headers: table_lines.append(' | '.join(headers)) table_lines.append(' | '.join(['---'] * len(headers))) - + for row_cells in rows: padded_cells = row_cells + [''] * (len(headers) - len(row_cells)) if headers else row_cells table_lines.append(' | '.join(padded_cells)) - - return '\n'.join(table_lines) \ No newline at end of file + + return '\n'.join(table_lines) diff --git a/pkg/rag/knowledge/services/retriever.py b/pkg/rag/knowledge/services/retriever.py index 4da81eb1..f563f9b3 100644 --- a/pkg/rag/knowledge/services/retriever.py +++ b/pkg/rag/knowledge/services/retriever.py @@ -1,7 +1,6 @@ # services/retriever.py -import asyncio import logging -import numpy as np # Make sure numpy is imported +import numpy as np # Make sure numpy is imported from typing import List, Dict, Any from sqlalchemy.orm import Session from pkg.rag.knowledge.services.base_service import BaseService @@ -11,6 +10,7 @@ from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager logger = logging.getLogger(__name__) + class Retriever(BaseService): def __init__(self, model_type: str, model_name_key: str, chroma_manager: ChromaIndexManager): super().__init__() @@ -22,10 +22,14 @@ class Retriever(BaseService): self.embedding_model: BaseEmbeddingModel = self._load_embedding_model() def _load_embedding_model(self) -> BaseEmbeddingModel: - self.logger.info(f"Loading retriever embedding model: type={self.model_type}, name_key={self.model_name_key}...") + self.logger.info( + f'Loading retriever embedding model: type={self.model_type}, name_key={self.model_name_key}...' + ) try: model = EmbeddingModelFactory.create_model(self.model_type, self.model_name_key) - self.logger.info(f"Retriever embedding model '{self.model_name_key}' loaded. Output dimension: {model.embedding_dimension}") + self.logger.info( + f"Retriever embedding model '{self.model_name_key}' loaded. Output dimension: {model.embedding_dimension}" + ) return model except Exception as e: self.logger.error(f"Failed to load retriever embedding model '{self.model_name_key}': {e}") @@ -33,43 +37,42 @@ class Retriever(BaseService): async def retrieve(self, query: str, k: int = 5) -> List[Dict[str, Any]]: if not self.embedding_model: - raise RuntimeError("Retriever embedding model not loaded. Please check Retriever initialization.") + raise RuntimeError('Retriever embedding model not loaded. Please check Retriever initialization.') self.logger.info(f"Retrieving for query: '{query}' with k={k} using {self.model_name_key}") query_embedding: List[float] = await self.embedding_model.embed_query(query) query_embedding_np = np.array([query_embedding], dtype=np.float32) - chroma_results = await self._run_sync( - self.chroma_manager.search_sync, - query_embedding_np, k - ) + chroma_results = await self._run_sync(self.chroma_manager.search_sync, query_embedding_np, k) # 'ids' is always returned by ChromaDB, even if not explicitly in 'include' - matched_chroma_ids = chroma_results.get("ids", [[]])[0] - distances = chroma_results.get("distances", [[]])[0] - chroma_metadatas = chroma_results.get("metadatas", [[]])[0] - chroma_documents = chroma_results.get("documents", [[]])[0] + matched_chroma_ids = chroma_results.get('ids', [[]])[0] + distances = chroma_results.get('distances', [[]])[0] + chroma_metadatas = chroma_results.get('metadatas', [[]])[0] + chroma_documents = chroma_results.get('documents', [[]])[0] if not matched_chroma_ids: - self.logger.info("No relevant chunks found in Chroma.") + self.logger.info('No relevant chunks found in Chroma.') return [] db_chunk_ids = [] for metadata in chroma_metadatas: - if "chunk_id" in metadata: - db_chunk_ids.append(metadata["chunk_id"]) + if 'chunk_id' in metadata: + db_chunk_ids.append(metadata['chunk_id']) else: self.logger.warning(f"Metadata missing 'chunk_id': {metadata}. Skipping this entry.") if not db_chunk_ids: - self.logger.warning("No valid chunk_ids extracted from Chroma results metadata.") + self.logger.warning('No valid chunk_ids extracted from Chroma results metadata.') return [] - self.logger.info(f"Fetching {len(db_chunk_ids)} chunk details from relational database...") + self.logger.info(f'Fetching {len(db_chunk_ids)} chunk details from relational database...') chunks_from_db = await self._run_sync( - lambda cids: self._db_get_chunks_sync(SessionLocal(), cids), # Ensure SessionLocal is passed correctly for _db_get_chunks_sync - db_chunk_ids + lambda cids: self._db_get_chunks_sync( + SessionLocal(), cids + ), # Ensure SessionLocal is passed correctly for _db_get_chunks_sync + db_chunk_ids, ) chunk_map = {chunk.id: chunk for chunk in chunks_from_db} @@ -80,27 +83,29 @@ class Retriever(BaseService): # Ensure original_chunk_id is int for DB lookup original_chunk_id = int(chroma_id.split('_')[-1]) except (ValueError, IndexError): - self.logger.warning(f"Could not parse chunk_id from Chroma ID: {chroma_id}. Skipping.") + self.logger.warning(f'Could not parse chunk_id from Chroma ID: {chroma_id}. Skipping.') continue chunk_text_from_chroma = chroma_documents[i] distance = float(distances[i]) - file_id_from_chroma = chroma_metadatas[i].get("file_id") + file_id_from_chroma = chroma_metadatas[i].get('file_id') chunk_from_db = chunk_map.get(original_chunk_id) - results_list.append({ - "chunk_id": original_chunk_id, - "text": chunk_from_db.text if chunk_from_db else chunk_text_from_chroma, - "distance": distance, - "file_id": file_id_from_chroma - }) + results_list.append( + { + 'chunk_id': original_chunk_id, + 'text': chunk_from_db.text if chunk_from_db else chunk_text_from_chroma, + 'distance': distance, + 'file_id': file_id_from_chroma, + } + ) - self.logger.info(f"Retrieved {len(results_list)} chunks.") + self.logger.info(f'Retrieved {len(results_list)} chunks.') return results_list def _db_get_chunks_sync(self, session: Session, chunk_ids: List[int]) -> List[Chunk]: - self.logger.debug(f"Fetching {len(chunk_ids)} chunk details from database (sync).") + self.logger.debug(f'Fetching {len(chunk_ids)} chunk details from database (sync).') chunks = session.query(Chunk).filter(Chunk.id.in_(chunk_ids)).all() session.close() - return chunks \ No newline at end of file + return chunks