Mirror of https://github.com/langbot-app/LangBot.git (synced 2025-11-26 03:44:58 +08:00)

Compare commits: 71 commits
Commits included in this comparison (SHA1 only):

28ce986a8c, 489b145606, 5e92bffaa6, 277d1b0e30, 13f4ed8d2c, 91cb5ca36c, c34d54a6cb, 2d1737da1f,
a1b8b9d47b, 8df14bf9d9, c98d265a1e, 4e6782a6b7, 5541e9e6d0, 878ab0ef6b, b61bd36b14, bb672d8f46,
ba1a26543b, cb868ee7b2, 5dd5cb12ad, 2dfa83ff22, 27bb4e1253, 45afdbdfbb, 4cbbe9e000, 333ec346ef,
2f2db4d445, f731115805, 67bc065ccd, 199164fc4b, c9c26213df, b7c57104c4, cbe297dc59, de76fed25a,
a10e61735d, 1ef0193028, 1e85d02ae4, d78a329aa9, 234b61e2f8, 9f43097361, f395cac893, fe122281fd,
6d788cadbc, a79a22a74d, 2ed3b68790, bd9331ce62, 14c161b733, 815cdf8b4a, 7d5503dab2, 9ba1ad5bd3,
367d04d0f0, 75c3ddde19, ac03a2dceb, cd25340826, ebd8e014c6, bef0d73e83, 8d28ace252, 39c062f73e,
0e5c9e19e1, c5b62b6ba3, bbf583ddb5, 22ef1a399e, 0733f8878f, f36a61dbb2, 6d8936bd74, d2b93b3296,
552fee9bac, 34fe8b324d, c4671fbf1c, 4bcc06c955, 348f6d9eaa, 157ffdc34c, c81d5a1a49
.gitignore (vendored), 3 lines changed

@@ -42,4 +42,5 @@ botpy.log*
test.py
/web_ui
.venv/
uv.lock
uv.lock
/test
@@ -13,7 +13,7 @@
[](https://deepwiki.com/langbot-app/LangBot)
[](https://github.com/langbot-app/LangBot/releases/latest)
<img src="https://img.shields.io/badge/python-3.10 ~ 3.13 -blue.svg" alt="python">
[](https://gitcode.com/langbot-app/LangBot)
[](https://gitcode.com/RockChinQ/LangBot)

<a href="https://langbot.app">项目主页</a> |
<a href="https://docs.langbot.app/zh/insight/guide.html">部署文档</a> |

@@ -27,7 +27,8 @@

## ✨ 特性

- 💬 大模型对话、Agent:支持多种大模型,适配群聊和私聊;具有多轮对话、工具调用、多模态能力,并深度适配 [Dify](https://dify.ai)。目前支持 QQ、QQ频道、企业微信、个人微信、飞书、Discord、Telegram 等平台。
- 💬 大模型对话、Agent:支持多种大模型,适配群聊和私聊;具有多轮对话、工具调用、多模态能力,自带 RAG(知识库)实现,并深度适配 [Dify](https://dify.ai)。
- 🤖 多平台支持:目前支持 QQ、QQ频道、企业微信、个人微信、飞书、Discord、Telegram 等平台。
- 🛠️ 高稳定性、功能完备:原生支持访问控制、限速、敏感词过滤等机制;配置简单,支持多种部署方式。支持多流水线配置,不同机器人用于不同应用场景。
- 🧩 插件扩展、活跃社区:支持事件驱动、组件扩展等插件机制;适配 Anthropic [MCP 协议](https://modelcontextprotocol.io/);目前已有数百个插件。
- 😻 Web 管理面板:支持通过浏览器管理 LangBot 实例,不再需要手动编写配置文件。

@@ -23,7 +23,8 @@

## ✨ Features

- 💬 Chat with LLM / Agent: Supports multiple LLMs, adapt to group chats and private chats; Supports multi-round conversations, tool calls, and multi-modal capabilities. Deeply integrates with [Dify](https://dify.ai). Currently supports QQ, QQ Channel, WeCom, personal WeChat, Lark, DingTalk, Discord, Telegram, etc.
- 💬 Chat with LLM / Agent: Supports multiple LLMs, adapt to group chats and private chats; Supports multi-round conversations, tool calls, and multi-modal capabilities. Built-in RAG (knowledge base) implementation, and deeply integrates with [Dify](https://dify.ai).
- 🤖 Multi-platform Support: Currently supports QQ, QQ Channel, WeCom, personal WeChat, Lark, DingTalk, Discord, Telegram, etc.
- 🛠️ High Stability, Feature-rich: Native access control, rate limiting, sensitive word filtering, etc. mechanisms; Easy to use, supports multiple deployment methods. Supports multiple pipeline configurations, different bots can be used for different scenarios.
- 🧩 Plugin Extension, Active Community: Support event-driven, component extension, etc. plugin mechanisms; Integrate Anthropic [MCP protocol](https://modelcontextprotocol.io/); Currently has hundreds of plugins.
- 😻 [New] Web UI: Support management LangBot instance through the browser. No need to manually write configuration files.

@@ -23,7 +23,8 @@

## ✨ 機能

- 💬 LLM / エージェントとのチャット: 複数のLLMをサポートし、グループチャットとプライベートチャットに対応。マルチラウンドの会話、ツールの呼び出し、マルチモーダル機能をサポート。 [Dify](https://dify.ai) と深く統合。現在、QQ、QQ チャンネル、WeChat、個人 WeChat、Lark、DingTalk、Discord、Telegram など、複数のプラットフォームをサポートしています。
- 💬 LLM / エージェントとのチャット: 複数のLLMをサポートし、グループチャットとプライベートチャットに対応。マルチラウンドの会話、ツールの呼び出し、マルチモーダル機能をサポート、RAG(知識ベース)を組み込み、[Dify](https://dify.ai) と深く統合。
- 🤖 多プラットフォーム対応: 現在、QQ、QQ チャンネル、WeChat、個人 WeChat、Lark、DingTalk、Discord、Telegram など、複数のプラットフォームをサポートしています。
- 🛠️ 高い安定性、豊富な機能: ネイティブのアクセス制御、レート制限、敏感な単語のフィルタリングなどのメカニズムをサポート。使いやすく、複数のデプロイ方法をサポート。複数のパイプライン設定をサポートし、異なるボットを異なる用途に使用できます。
- 🧩 プラグイン拡張、活発なコミュニティ: イベント駆動、コンポーネント拡張などのプラグインメカニズムをサポート。適配 Anthropic [MCP プロトコル](https://modelcontextprotocol.io/);豊富なエコシステム、現在数百のプラグインが存在。
- 😻 Web UI: ブラウザを通じてLangBotインスタンスを管理することをサポート。
@@ -1 +1 @@
from .client import WeChatPadClient
from .client import WeChatPadClient as WeChatPadClient
@@ -1,4 +1,4 @@
from libs.wechatpad_api.util.http_util import async_request, post_json
from libs.wechatpad_api.util.http_util import post_json


class ChatRoomApi:
@@ -7,8 +7,6 @@ class ChatRoomApi:
        self.token = token

    def get_chatroom_member_detail(self, chatroom_name):
        params = {
            "ChatRoomName": chatroom_name
        }
        params = {'ChatRoomName': chatroom_name}
        url = self.base_url + '/group/GetChatroomMemberDetail'
        return post_json(url, token=self.token, data=params)
@@ -1,32 +1,23 @@
from libs.wechatpad_api.util.http_util import async_request, post_json
from libs.wechatpad_api.util.http_util import post_json
import httpx
import base64


class DownloadApi:
    def __init__(self, base_url, token):
        self.base_url = base_url
        self.token = token

    def send_download(self, aeskey, file_type, file_url):
        json_data = {
            "AesKey": aeskey,
            "FileType": file_type,
            "FileURL": file_url
        }
        url = self.base_url + "/message/SendCdnDownload"
        json_data = {'AesKey': aeskey, 'FileType': file_type, 'FileURL': file_url}
        url = self.base_url + '/message/SendCdnDownload'
        return post_json(url, token=self.token, data=json_data)

    def get_msg_voice(self,buf_id, length, new_msgid):
        json_data = {
            "Bufid": buf_id,
            "Length": length,
            "NewMsgId": new_msgid,
            "ToUserName": ""
        }
        url = self.base_url + "/message/GetMsgVoice"
    def get_msg_voice(self, buf_id, length, new_msgid):
        json_data = {'Bufid': buf_id, 'Length': length, 'NewMsgId': new_msgid, 'ToUserName': ''}
        url = self.base_url + '/message/GetMsgVoice'
        return post_json(url, token=self.token, data=json_data)


    async def download_url_to_base64(self, download_url):
        async with httpx.AsyncClient() as client:
            response = await client.get(download_url)
@@ -36,4 +27,4 @@ class DownloadApi:
                base64_str = base64.b64encode(file_bytes).decode('utf-8')  # 返回字符串格式
                return base64_str
            else:
                raise Exception('获取文件失败')
                raise Exception('获取文件失败')
@@ -1,11 +1,6 @@
from libs.wechatpad_api.util.http_util import post_json,async_request
from typing import List, Dict, Any, Optional


class FriendApi:
    """联系人API类,处理所有与联系人相关的操作"""

    def __init__(self, base_url: str, token: str):
        self.base_url = base_url
        self.token = token
@@ -1,37 +1,34 @@
from libs.wechatpad_api.util.http_util import async_request,post_json,get_json
from libs.wechatpad_api.util.http_util import post_json, get_json


class LoginApi:
    def __init__(self, base_url: str, token: str = None, admin_key: str = None):
        '''
        """

        Args:
            base_url: 原始路径
            token: token
            admin_key: 管理员key
        '''
        """
        self.base_url = base_url
        self.token = token
        # self.admin_key = admin_key

    def get_token(self, admin_key, day: int=365):
    def get_token(self, admin_key, day: int = 365):
        # 获取普通token
        url = f"{self.base_url}/admin/GenAuthKey1"
        json_data = {
            "Count": 1,
            "Days": day
        }
        url = f'{self.base_url}/admin/GenAuthKey1'
        json_data = {'Count': 1, 'Days': day}
        return post_json(base_url=url, token=admin_key, data=json_data)

    def get_login_qr(self, Proxy: str = ""):
        '''
    def get_login_qr(self, Proxy: str = ''):
        """

        Args:
            Proxy:异地使用时代理

        Returns:json数据

        '''
        """
        """

        {
@@ -49,54 +46,37 @@ class LoginApi:
        }

        """
        #获取登录二维码
        url = f"{self.base_url}/login/GetLoginQrCodeNew"
        # 获取登录二维码
        url = f'{self.base_url}/login/GetLoginQrCodeNew'
        check = False
        if Proxy != "":
        if Proxy != '':
            check = True
        json_data = {
            "Check": check,
            "Proxy": Proxy
        }
        json_data = {'Check': check, 'Proxy': Proxy}
        return post_json(base_url=url, token=self.token, data=json_data)

    def get_login_status(self):
        # 获取登录状态
        url = f'{self.base_url}/login/GetLoginStatus'
        return get_json(base_url=url, token=self.token)

    def logout(self):
        # 退出登录
        url = f'{self.base_url}/login/LogOut'
        return post_json(base_url=url, token=self.token)

    def wake_up_login(self, Proxy: str = ""):
    def wake_up_login(self, Proxy: str = ''):
        # 唤醒登录
        url = f'{self.base_url}/login/WakeUpLogin'
        check = False
        if Proxy != "":
        if Proxy != '':
            check = True
        json_data = {
            "Check": check,
            "Proxy": ""
        }
        json_data = {'Check': check, 'Proxy': ''}

        return post_json(base_url=url, token=self.token, data=json_data)

    def login(self,admin_key):
    def login(self, admin_key):
        login_status = self.get_login_status()
        if login_status["Code"] == 300 and login_status["Text"] == "你已退出微信":
            print("token已经失效,重新获取")
        if login_status['Code'] == 300 and login_status['Text'] == '你已退出微信':
            print('token已经失效,重新获取')
            token_data = self.get_token(admin_key)
            self.token = token_data["Data"][0]

            self.token = token_data['Data'][0]
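A minimal usage sketch for the login flow above, assuming a reachable WeChatPad gateway at a placeholder address and an illustrative admin key; the import path is an assumption, while the method names and the 'Data'/'Code'/'Text' response fields come from the diff itself:

```python
from libs.wechatpad_api.api.login import LoginApi  # module path is an assumption

# Placeholder values; substitute your own gateway address and admin key.
login_api = LoginApi(base_url='http://127.0.0.1:8080', admin_key='ADMIN_KEY')

# Obtain a normal token valid for 365 days and store it on the client.
token_resp = login_api.get_token(admin_key='ADMIN_KEY', day=365)
login_api.token = token_resp['Data'][0]

# Fetch a login QR code, then poll the login status.
qr_resp = login_api.get_login_qr()
status = login_api.get_login_status()
print(qr_resp, status)
```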
@@ -1,5 +1,4 @@

from libs.wechatpad_api.util.http_util import async_request, post_json
from libs.wechatpad_api.util.http_util import post_json


class MessageApi:
@@ -7,8 +6,8 @@ class MessageApi:
        self.base_url = base_url
        self.token = token

    def post_text(self, to_wxid, content, ats: list= []):
        '''
    def post_text(self, to_wxid, content, ats: list = []):
        """

        Args:
            app_id: 微信id
@@ -18,106 +17,64 @@ class MessageApi:

        Returns:

        '''
        url = self.base_url + "/message/SendTextMessage"
        """
        url = self.base_url + '/message/SendTextMessage'
        """发送文字消息"""
        json_data = {
            "MsgItem": [
                {
                    "AtWxIDList": ats,
                    "ImageContent": "",
                    "MsgType": 0,
                    "TextContent": content,
                    "ToUserName": to_wxid
                }
            ]
        }
        return post_json(base_url=url, token=self.token, data=json_data)
            'MsgItem': [
                {'AtWxIDList': ats, 'ImageContent': '', 'MsgType': 0, 'TextContent': content, 'ToUserName': to_wxid}
            ]
        }
        return post_json(base_url=url, token=self.token, data=json_data)


    def post_image(self, to_wxid, img_url, ats: list= []):
    def post_image(self, to_wxid, img_url, ats: list = []):
        """发送图片消息"""
        # 这里好像可以尝试发送多个暂时未测试
        json_data = {
            "MsgItem": [
                {
                    "AtWxIDList": ats,
                    "ImageContent": img_url,
                    "MsgType": 0,
                    "TextContent": '',
                    "ToUserName": to_wxid
                }
            'MsgItem': [
                {'AtWxIDList': ats, 'ImageContent': img_url, 'MsgType': 0, 'TextContent': '', 'ToUserName': to_wxid}
            ]
        }
        url = self.base_url + "/message/SendImageMessage"
        url = self.base_url + '/message/SendImageMessage'
        return post_json(base_url=url, token=self.token, data=json_data)

    def post_voice(self, to_wxid, voice_data, voice_forma, voice_duration):
        """发送语音消息"""
        json_data = {
            "ToUserName": to_wxid,
            "VoiceData": voice_data,
            "VoiceFormat": voice_forma,
            "VoiceSecond": voice_duration
            'ToUserName': to_wxid,
            'VoiceData': voice_data,
            'VoiceFormat': voice_forma,
            'VoiceSecond': voice_duration,
        }
        url = self.base_url + "/message/SendVoice"
        url = self.base_url + '/message/SendVoice'
        return post_json(base_url=url, token=self.token, data=json_data)

    def post_name_card(self, alias, to_wxid, nick_name, name_card_wxid, flag):
        """发送名片消息"""
        param = {
            "CardAlias": alias,
            "CardFlag": flag,
            "CardNickName": nick_name,
            "CardWxId": name_card_wxid,
            "ToUserName": to_wxid
            'CardAlias': alias,
            'CardFlag': flag,
            'CardNickName': nick_name,
            'CardWxId': name_card_wxid,
            'ToUserName': to_wxid,
        }
        url = f"{self.base_url}/message/ShareCardMessage"
        url = f'{self.base_url}/message/ShareCardMessage'
        return post_json(base_url=url, token=self.token, data=param)

    def post_emoji(self, to_wxid, emoji_md5, emoji_size:int=0):
    def post_emoji(self, to_wxid, emoji_md5, emoji_size: int = 0):
        """发送emoji消息"""
        json_data = {
            "EmojiList": [
                {
                    "EmojiMd5": emoji_md5,
                    "EmojiSize": emoji_size,
                    "ToUserName": to_wxid
                }
            ]
        }
        url = f"{self.base_url}/message/SendEmojiMessage"
        json_data = {'EmojiList': [{'EmojiMd5': emoji_md5, 'EmojiSize': emoji_size, 'ToUserName': to_wxid}]}
        url = f'{self.base_url}/message/SendEmojiMessage'
        return post_json(base_url=url, token=self.token, data=json_data)

    def post_app_msg(self, to_wxid,xml_data, contenttype:int=0):
    def post_app_msg(self, to_wxid, xml_data, contenttype: int = 0):
        """发送appmsg消息"""
        json_data = {
            "AppList": [
                {
                    "ContentType": contenttype,
                    "ContentXML": xml_data,
                    "ToUserName": to_wxid
                }
            ]
        }
        url = f"{self.base_url}/message/SendAppMessage"
        json_data = {'AppList': [{'ContentType': contenttype, 'ContentXML': xml_data, 'ToUserName': to_wxid}]}
        url = f'{self.base_url}/message/SendAppMessage'
        return post_json(base_url=url, token=self.token, data=json_data)


    def revoke_msg(self, to_wxid, msg_id, new_msg_id, create_time):
        """撤回消息"""
        param = {
            "ClientMsgId": msg_id,
            "CreateTime": create_time,
            "NewMsgId": new_msg_id,
            "ToUserName": to_wxid
        }
        url = f"{self.base_url}/message/RevokeMsg"
        return post_json(base_url=url, token=self.token, data=param)
        param = {'ClientMsgId': msg_id, 'CreateTime': create_time, 'NewMsgId': new_msg_id, 'ToUserName': to_wxid}
        url = f'{self.base_url}/message/RevokeMsg'
        return post_json(base_url=url, token=self.token, data=param)
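A short hedged sketch of sending a text message with an @ mention through MessageApi; the import path and constructor signature are assumptions inferred from the attributes it sets, and the gateway address, token, and wxids are placeholders:

```python
from libs.wechatpad_api.api.message import MessageApi  # module path is an assumption

# Placeholder gateway address and token.
msg_api = MessageApi(base_url='http://127.0.0.1:8080', token='TOKEN')

# Send a plain text message to a group, @-mentioning one member.
resp = msg_api.post_text(
    to_wxid='12345678@chatroom',        # illustrative chatroom id
    content='@Alice hello from LangBot',
    ats=['wxid_alice'],                 # illustrative member wxid
)
print(resp)
```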
@@ -1,10 +1,9 @@
import requests
import aiohttp


def post_json(base_url, token, data=None):
    headers = {
        'Content-Type': 'application/json'
    }

    headers = {'Content-Type': 'application/json'}

    url = base_url + f'?key={token}'

@@ -18,14 +17,12 @@ def post_json(base_url, token, data=None):
        else:
            raise RuntimeError(response.text)
    except Exception as e:
        print(f"http请求失败, url={url}, exception={e}")
        print(f'http请求失败, url={url}, exception={e}')
        raise RuntimeError(str(e))

def get_json(base_url, token):
    headers = {
        'Content-Type': 'application/json'
    }

def get_json(base_url, token):
    headers = {'Content-Type': 'application/json'}

    url = base_url + f'?key={token}'

@@ -39,21 +36,18 @@ def get_json(base_url, token):
        else:
            raise RuntimeError(response.text)
    except Exception as e:
        print(f"http请求失败, url={url}, exception={e}")
        print(f'http请求失败, url={url}, exception={e}')
        raise RuntimeError(str(e))

import aiohttp
import asyncio


async def async_request(
    base_url: str,
    token_key: str,
    method: str = 'POST',
    params: dict = None,
    # headers: dict = None,
    data: dict = None,
    json: dict = None
    base_url: str,
    token_key: str,
    method: str = 'POST',
    params: dict = None,
    # headers: dict = None,
    data: dict = None,
    json: dict = None,
):
    """
    通用异步请求函数
@@ -67,18 +61,11 @@ async def async_request(
    :param json: JSON数据
    :return: 响应文本
    """
    headers = {
        'Content-Type': 'application/json'
    }
    url = f"{base_url}?key={token_key}"
    headers = {'Content-Type': 'application/json'}
    url = f'{base_url}?key={token_key}'
    async with aiohttp.ClientSession() as session:
        async with session.request(
            method=method,
            url=url,
            params=params,
            headers=headers,
            data=data,
            json=json
            method=method, url=url, params=params, headers=headers, data=data, json=json
        ) as response:
            response.raise_for_status()  # 如果状态码不是200,抛出异常
            result = await response.json()
@@ -89,4 +76,3 @@ async def async_request(
            # return await result
            # else:
            #     raise RuntimeError("请求失败",response.text)
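A minimal sketch of how these helpers are called; the notable detail is that the token is passed as a `?key=<token>` query parameter rather than a header. The gateway address and token below are placeholders:

```python
import asyncio

from libs.wechatpad_api.util.http_util import post_json, get_json, async_request

BASE = 'http://127.0.0.1:8080'   # placeholder gateway address
TOKEN = 'TOKEN'                  # placeholder auth key, appended as ?key=<token>

# Synchronous helpers: GET a JSON response, or POST a JSON body.
status = get_json(base_url=f'{BASE}/login/GetLoginStatus', token=TOKEN)
sent = post_json(base_url=f'{BASE}/message/SendTextMessage', token=TOKEN, data={'MsgItem': []})

# The async variant wraps aiohttp and parses the JSON response internally.
asyncio.run(async_request(base_url=f'{BASE}/login/GetLoginStatus', token_key=TOKEN, method='GET'))
print(status, sent)
```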
@@ -14,8 +14,8 @@ preregistered_groups: list[type[RouterGroup]] = []
"""Pre-registered list of RouterGroup"""


def group_class(name: str, path: str) -> None:
    """Register a RouterGroup"""
def group_class(name: str, path: str) -> typing.Callable[[typing.Type[RouterGroup]], typing.Type[RouterGroup]]:
    """注册一个 RouterGroup"""

    def decorator(cls: typing.Type[RouterGroup]) -> typing.Type[RouterGroup]:
        cls.name = name
@@ -86,10 +86,11 @@ class RouterGroup(abc.ABC):

            try:
                return await f(*args, **kwargs)
            except Exception: # auto 500

            except Exception as e: # 自动 500
                traceback.print_exc()
                # return self.http_status(500, -2, str(e))
                return self.http_status(500, -2, 'internal server error')
                return self.http_status(500, -2, str(e))

            new_f = handler_error
            new_f.__name__ = (self.name + rule).replace('/', '__')
@@ -120,6 +121,6 @@
                }
            )

    def http_status(self, status: int, code: int, msg: str) -> quart.Response:
        """Return a response with a specified status code"""
        return self.fail(code, msg), status
    def http_status(self, status: int, code: int, msg: str) -> typing.Tuple[quart.Response, int]:
        """返回一个指定状态码的响应"""
        return (self.fail(code, msg), status)
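A hedged sketch of how a router group registers itself with the `group_class` decorator, mirroring the pattern used by the knowledge-base group later in this diff. The group name, path, and route body are illustrative only, and the relative import only resolves when the module lives inside the controller `groups` package:

```python
import quart

from ... import group  # relative import, as used by the groups in this package


@group.group_class('example', '/api/v1/example')   # hypothetical group, for illustration
class ExampleRouterGroup(group.RouterGroup):
    async def initialize(self) -> None:
        @self.route('', methods=['GET'])
        async def _() -> quart.Response:
            # Uncaught exceptions raised here are converted into a 500 response
            # by the handler_error wrapper shown in the diff above.
            return self.success(data={'hello': 'world'})
```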
@@ -2,6 +2,10 @@ from __future__ import annotations

import quart
import mimetypes
import uuid
import asyncio

import quart.datastructures

from .. import group

@@ -20,3 +24,23 @@ class FilesRouterGroup(group.RouterGroup):
                mime_type = 'image/jpeg'

            return quart.Response(image_bytes, mimetype=mime_type)

        @self.route('/documents', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
        async def _() -> quart.Response:
            request = quart.request
            # get file bytes from 'file'
            file = (await request.files)['file']
            assert isinstance(file, quart.datastructures.FileStorage)

            file_bytes = await asyncio.to_thread(file.stream.read)
            extension = file.filename.split('.')[-1]
            file_name = file.filename.split('.')[0]

            file_key = file_name + '_' + str(uuid.uuid4())[:8] + '.' + extension
            # save file to storage
            await self.ap.storage_mgr.storage_provider.save(file_key, file_bytes)
            return self.success(
                data={
                    'file_id': file_key,
                }
            )
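A hedged sketch of uploading a document to the new `/documents` route with httpx. The multipart field name `file` comes from the diff; the full URL prefix, host/port, and the Authorization header format are assumptions (the diff only shows that the route requires a USER_TOKEN):

```python
import httpx

# Placeholder LangBot address and user token; URL prefix and header format are assumptions.
API = 'http://127.0.0.1:5300/api/v1/files/documents'
HEADERS = {'Authorization': 'Bearer USER_TOKEN'}

with open('manual.pdf', 'rb') as f:
    resp = httpx.post(API, headers=HEADERS, files={'file': ('manual.pdf', f)})

# The handler responds with a file_id such as 'manual_1a2b3c4d.pdf'.
print(resp.json())
```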
pkg/api/http/controller/groups/knowledge/base.py (new file, 90 lines)

@@ -0,0 +1,90 @@
import quart
from ... import group


@group.group_class('knowledge_base', '/api/v1/knowledge/bases')
class KnowledgeBaseRouterGroup(group.RouterGroup):
    async def initialize(self) -> None:
        @self.route('', methods=['POST', 'GET'])
        async def handle_knowledge_bases() -> quart.Response:
            if quart.request.method == 'GET':
                knowledge_bases = await self.ap.knowledge_service.get_knowledge_bases()
                return self.success(data={'bases': knowledge_bases})

            elif quart.request.method == 'POST':
                json_data = await quart.request.json
                knowledge_base_uuid = await self.ap.knowledge_service.create_knowledge_base(json_data)
                return self.success(data={'uuid': knowledge_base_uuid})

            return self.http_status(405, -1, 'Method not allowed')

        @self.route(
            '/<knowledge_base_uuid>',
            methods=['GET', 'DELETE', 'PUT'],
        )
        async def handle_specific_knowledge_base(knowledge_base_uuid: str) -> quart.Response:
            if quart.request.method == 'GET':
                knowledge_base = await self.ap.knowledge_service.get_knowledge_base(knowledge_base_uuid)

                if knowledge_base is None:
                    return self.http_status(404, -1, 'knowledge base not found')

                return self.success(
                    data={
                        'base': knowledge_base,
                    }
                )

            elif quart.request.method == 'PUT':
                json_data = await quart.request.json
                await self.ap.knowledge_service.update_knowledge_base(knowledge_base_uuid, json_data)
                return self.success({})

            elif quart.request.method == 'DELETE':
                await self.ap.knowledge_service.delete_knowledge_base(knowledge_base_uuid)
                return self.success({})

        @self.route(
            '/<knowledge_base_uuid>/files',
            methods=['GET', 'POST'],
        )
        async def get_knowledge_base_files(knowledge_base_uuid: str) -> str:
            if quart.request.method == 'GET':
                files = await self.ap.knowledge_service.get_files_by_knowledge_base(knowledge_base_uuid)
                return self.success(
                    data={
                        'files': files,
                    }
                )

            elif quart.request.method == 'POST':
                json_data = await quart.request.json
                file_id = json_data.get('file_id')
                if not file_id:
                    return self.http_status(400, -1, 'File ID is required')

                # 调用服务层方法将文件与知识库关联
                task_id = await self.ap.knowledge_service.store_file(knowledge_base_uuid, file_id)
                return self.success(
                    {
                        'task_id': task_id,
                    }
                )

        @self.route(
            '/<knowledge_base_uuid>/files/<file_id>',
            methods=['DELETE'],
        )
        async def delete_specific_file_in_kb(file_id: str, knowledge_base_uuid: str) -> str:
            await self.ap.knowledge_service.delete_file(knowledge_base_uuid, file_id)
            return self.success({})

        @self.route(
            '/<knowledge_base_uuid>/retrieve',
            methods=['POST'],
        )
        async def retrieve_knowledge_base(knowledge_base_uuid: str) -> str:
            json_data = await quart.request.json
            query = json_data.get('query')
            results = await self.ap.knowledge_service.retrieve_knowledge_base(knowledge_base_uuid, query)
            return self.success(data={'results': results})
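A hedged sketch of driving the new knowledge-base endpoints end to end with httpx. The route paths come from the group above; the host/port, Authorization header format, response envelope, and request body fields (taken from the KnowledgeBase columns added later in this diff) are assumptions:

```python
import httpx

BASE = 'http://127.0.0.1:5300/api/v1/knowledge/bases'   # placeholder host/port
HEADERS = {'Authorization': 'Bearer USER_TOKEN'}         # header format is an assumption

# Create a knowledge base; field names mirror the KnowledgeBase persistence columns.
kb = httpx.post(BASE, headers=HEADERS, json={
    'name': 'docs',
    'description': 'product manuals',
    'embedding_model_uuid': 'EMBEDDING_MODEL_UUID',      # placeholder
}).json()
kb_uuid = kb['data']['uuid']                             # assuming success() nests under 'data'

# Attach a previously uploaded file, then run a retrieval query.
httpx.post(f'{BASE}/{kb_uuid}/files', headers=HEADERS, json={'file_id': 'manual_1a2b3c4d.pdf'})
hits = httpx.post(f'{BASE}/{kb_uuid}/retrieve', headers=HEADERS, json={'query': 'how do I reset the device?'})
print(hits.json())
```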
@@ -9,18 +9,18 @@ class LLMModelsRouterGroup(group.RouterGroup):
        @self.route('', methods=['GET', 'POST'])
        async def _() -> str:
            if quart.request.method == 'GET':
                return self.success(data={'models': await self.ap.model_service.get_llm_models()})
                return self.success(data={'models': await self.ap.llm_model_service.get_llm_models()})
            elif quart.request.method == 'POST':
                json_data = await quart.request.json

                model_uuid = await self.ap.model_service.create_llm_model(json_data)
                model_uuid = await self.ap.llm_model_service.create_llm_model(json_data)

                return self.success(data={'uuid': model_uuid})

        @self.route('/<model_uuid>', methods=['GET', 'PUT', 'DELETE'])
        async def _(model_uuid: str) -> str:
            if quart.request.method == 'GET':
                model = await self.ap.model_service.get_llm_model(model_uuid)
                model = await self.ap.llm_model_service.get_llm_model(model_uuid)

                if model is None:
                    return self.http_status(404, -1, 'model not found')
@@ -29,11 +29,11 @@ class LLMModelsRouterGroup(group.RouterGroup):
            elif quart.request.method == 'PUT':
                json_data = await quart.request.json

                await self.ap.model_service.update_llm_model(model_uuid, json_data)
                await self.ap.llm_model_service.update_llm_model(model_uuid, json_data)

                return self.success()
            elif quart.request.method == 'DELETE':
                await self.ap.model_service.delete_llm_model(model_uuid)
                await self.ap.llm_model_service.delete_llm_model(model_uuid)

                return self.success()

@@ -41,6 +41,49 @@ class LLMModelsRouterGroup(group.RouterGroup):
        async def _(model_uuid: str) -> str:
            json_data = await quart.request.json

            await self.ap.model_service.test_llm_model(model_uuid, json_data)
            await self.ap.llm_model_service.test_llm_model(model_uuid, json_data)

            return self.success()


@group.group_class('models/embedding', '/api/v1/provider/models/embedding')
class EmbeddingModelsRouterGroup(group.RouterGroup):
    async def initialize(self) -> None:
        @self.route('', methods=['GET', 'POST'])
        async def _() -> str:
            if quart.request.method == 'GET':
                return self.success(data={'models': await self.ap.embedding_models_service.get_embedding_models()})
            elif quart.request.method == 'POST':
                json_data = await quart.request.json

                model_uuid = await self.ap.embedding_models_service.create_embedding_model(json_data)

                return self.success(data={'uuid': model_uuid})

        @self.route('/<model_uuid>', methods=['GET', 'PUT', 'DELETE'])
        async def _(model_uuid: str) -> str:
            if quart.request.method == 'GET':
                model = await self.ap.embedding_models_service.get_embedding_model(model_uuid)

                if model is None:
                    return self.http_status(404, -1, 'model not found')

                return self.success(data={'model': model})
            elif quart.request.method == 'PUT':
                json_data = await quart.request.json

                await self.ap.embedding_models_service.update_embedding_model(model_uuid, json_data)

                return self.success()
            elif quart.request.method == 'DELETE':
                await self.ap.embedding_models_service.delete_embedding_model(model_uuid)

                return self.success()

        @self.route('/<model_uuid>/test', methods=['POST'])
        async def _(model_uuid: str) -> str:
            json_data = await quart.request.json

            await self.ap.embedding_models_service.test_embedding_model(model_uuid, json_data)

            return self.success()
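A hedged sketch of registering and testing an embedding model through the new `/api/v1/provider/models/embedding` group. The payload fields follow the EmbeddingModel columns added later in this diff; the host/port, Authorization header, requester name, and config values are illustrative assumptions:

```python
import httpx

BASE = 'http://127.0.0.1:5300/api/v1/provider/models/embedding'  # placeholder host/port
HEADERS = {'Authorization': 'Bearer USER_TOKEN'}                  # header format is an assumption

payload = {
    'name': 'text-embedding-3-small',                 # placeholder model name
    'description': 'example embedding model',
    'requester': 'openai-embedding',                  # placeholder requester name
    'requester_config': {'base_url': 'https://api.openai.com/v1'},
    'api_keys': ['sk-...'],
    'extra_args': {},
}

created = httpx.post(BASE, headers=HEADERS, json=payload).json()
model_uuid = created['data']['uuid']                  # assuming success() nests under 'data'

# POST /<uuid>/test makes the server call invoke_embedding with a short probe text.
httpx.post(f'{BASE}/{model_uuid}/test', headers=HEADERS, json=payload)
```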
@@ -8,7 +8,8 @@ class RequestersRouterGroup(group.RouterGroup):
    async def initialize(self) -> None:
        @self.route('', methods=['GET'])
        async def _() -> quart.Response:
            return self.success(data={'requesters': self.ap.model_mgr.get_available_requesters_info()})
            model_type = quart.request.args.get('type', '')
            return self.success(data={'requesters': self.ap.model_mgr.get_available_requesters_info(model_type)})

        @self.route('/<requester_name>', methods=['GET'])
        async def _(requester_name: str) -> quart.Response:
@@ -14,11 +14,13 @@ from . import group
from .groups import provider as groups_provider
from .groups import platform as groups_platform
from .groups import pipelines as groups_pipelines
from .groups import knowledge as groups_knowledge

importutil.import_modules_in_pkg(groups)
importutil.import_modules_in_pkg(groups_provider)
importutil.import_modules_in_pkg(groups_platform)
importutil.import_modules_in_pkg(groups_pipelines)
importutil.import_modules_in_pkg(groups_knowledge)


class HTTPController:
pkg/api/http/service/knowledge.py (new file, 118 lines)

@@ -0,0 +1,118 @@
from __future__ import annotations

import uuid
import sqlalchemy

from ....core import app
from ....entity.persistence import rag as persistence_rag


class KnowledgeService:
    """知识库服务"""

    ap: app.Application

    def __init__(self, ap: app.Application) -> None:
        self.ap = ap

    async def get_knowledge_bases(self) -> list[dict]:
        """获取所有知识库"""
        result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_rag.KnowledgeBase))
        knowledge_bases = result.all()
        return [
            self.ap.persistence_mgr.serialize_model(persistence_rag.KnowledgeBase, knowledge_base)
            for knowledge_base in knowledge_bases
        ]

    async def get_knowledge_base(self, kb_uuid: str) -> dict | None:
        """获取知识库"""
        result = await self.ap.persistence_mgr.execute_async(
            sqlalchemy.select(persistence_rag.KnowledgeBase).where(persistence_rag.KnowledgeBase.uuid == kb_uuid)
        )
        knowledge_base = result.first()
        if knowledge_base is None:
            return None
        return self.ap.persistence_mgr.serialize_model(persistence_rag.KnowledgeBase, knowledge_base)

    async def create_knowledge_base(self, kb_data: dict) -> str:
        """创建知识库"""
        kb_data['uuid'] = str(uuid.uuid4())
        await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_rag.KnowledgeBase).values(kb_data))

        kb = await self.get_knowledge_base(kb_data['uuid'])

        await self.ap.rag_mgr.load_knowledge_base(kb)

        return kb_data['uuid']

    async def update_knowledge_base(self, kb_uuid: str, kb_data: dict) -> None:
        """更新知识库"""
        if 'uuid' in kb_data:
            del kb_data['uuid']

        if 'embedding_model_uuid' in kb_data:
            del kb_data['embedding_model_uuid']

        await self.ap.persistence_mgr.execute_async(
            sqlalchemy.update(persistence_rag.KnowledgeBase)
            .values(kb_data)
            .where(persistence_rag.KnowledgeBase.uuid == kb_uuid)
        )
        await self.ap.rag_mgr.remove_knowledge_base_from_runtime(kb_uuid)

        kb = await self.get_knowledge_base(kb_uuid)

        await self.ap.rag_mgr.load_knowledge_base(kb)

    async def store_file(self, kb_uuid: str, file_id: str) -> int:
        """存储文件"""
        # await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_rag.File).values(kb_id=kb_uuid, file_id=file_id))
        # await self.ap.rag_mgr.store_file(file_id)
        runtime_kb = await self.ap.rag_mgr.get_knowledge_base_by_uuid(kb_uuid)
        if runtime_kb is None:
            raise Exception('Knowledge base not found')
        return await runtime_kb.store_file(file_id)

    async def retrieve_knowledge_base(self, kb_uuid: str, query: str) -> list[dict]:
        """检索知识库"""
        runtime_kb = await self.ap.rag_mgr.get_knowledge_base_by_uuid(kb_uuid)
        if runtime_kb is None:
            raise Exception('Knowledge base not found')
        return [result.model_dump() for result in await runtime_kb.retrieve(query)]

    async def get_files_by_knowledge_base(self, kb_uuid: str) -> list[dict]:
        """获取知识库文件"""
        result = await self.ap.persistence_mgr.execute_async(
            sqlalchemy.select(persistence_rag.File).where(persistence_rag.File.kb_id == kb_uuid)
        )
        files = result.all()
        return [self.ap.persistence_mgr.serialize_model(persistence_rag.File, file) for file in files]

    async def delete_file(self, kb_uuid: str, file_id: str) -> None:
        """删除文件"""
        runtime_kb = await self.ap.rag_mgr.get_knowledge_base_by_uuid(kb_uuid)
        if runtime_kb is None:
            raise Exception('Knowledge base not found')
        await runtime_kb.delete_file(file_id)

    async def delete_knowledge_base(self, kb_uuid: str) -> None:
        """删除知识库"""
        await self.ap.rag_mgr.delete_knowledge_base(kb_uuid)

        await self.ap.persistence_mgr.execute_async(
            sqlalchemy.delete(persistence_rag.KnowledgeBase).where(persistence_rag.KnowledgeBase.uuid == kb_uuid)
        )

        # delete files
        files = await self.ap.persistence_mgr.execute_async(
            sqlalchemy.select(persistence_rag.File).where(persistence_rag.File.kb_id == kb_uuid)
        )
        for file in files:
            # delete chunks
            await self.ap.persistence_mgr.execute_async(
                sqlalchemy.delete(persistence_rag.Chunk).where(persistence_rag.Chunk.file_id == file.uuid)
            )
            # delete file
            await self.ap.persistence_mgr.execute_async(
                sqlalchemy.delete(persistence_rag.File).where(persistence_rag.File.uuid == file.uuid)
            )
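A minimal sketch of the service-level call sequence (create, attach a file, retrieve), assuming `ap` is a fully booted `app.Application` with `knowledge_service`, `rag_mgr`, and `persistence_mgr` already initialized, as wired up in the boot stage later in this diff; the field values and file id are placeholders:

```python
# Sketch only: would be awaited from inside the running application.
async def index_and_query(ap) -> None:
    kb_uuid = await ap.knowledge_service.create_knowledge_base({
        'name': 'docs',
        'description': 'product manuals',
        'embedding_model_uuid': 'EMBEDDING_MODEL_UUID',   # placeholder
    })

    # file_id comes from a prior upload to the files /documents endpoint.
    await ap.knowledge_service.store_file(kb_uuid, 'manual_1a2b3c4d.pdf')

    # Each hit is a RetrieveResultEntry serialized via model_dump().
    for hit in await ap.knowledge_service.retrieve_knowledge_base(kb_uuid, 'how do I reset the device?'):
        print(hit['id'], hit['distance'])
```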
@@ -10,7 +10,7 @@ from ....provider.modelmgr import requester as model_requester
from ....provider import entities as llm_entities


class ModelsService:
class LLMModelsService:
    ap: app.Application

    def __init__(self, ap: app.Application) -> None:
@@ -103,3 +103,89 @@ class ModelsService:
            funcs=[],
            extra_args={},
        )


class EmbeddingModelsService:
    ap: app.Application

    def __init__(self, ap: app.Application) -> None:
        self.ap = ap

    async def get_embedding_models(self) -> list[dict]:
        result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.EmbeddingModel))

        models = result.all()
        return [self.ap.persistence_mgr.serialize_model(persistence_model.EmbeddingModel, model) for model in models]

    async def create_embedding_model(self, model_data: dict) -> str:
        model_data['uuid'] = str(uuid.uuid4())

        await self.ap.persistence_mgr.execute_async(
            sqlalchemy.insert(persistence_model.EmbeddingModel).values(**model_data)
        )

        embedding_model = await self.get_embedding_model(model_data['uuid'])

        await self.ap.model_mgr.load_embedding_model(embedding_model)

        return model_data['uuid']

    async def get_embedding_model(self, model_uuid: str) -> dict | None:
        result = await self.ap.persistence_mgr.execute_async(
            sqlalchemy.select(persistence_model.EmbeddingModel).where(
                persistence_model.EmbeddingModel.uuid == model_uuid
            )
        )

        model = result.first()

        if model is None:
            return None

        return self.ap.persistence_mgr.serialize_model(persistence_model.EmbeddingModel, model)

    async def update_embedding_model(self, model_uuid: str, model_data: dict) -> None:
        if 'uuid' in model_data:
            del model_data['uuid']

        await self.ap.persistence_mgr.execute_async(
            sqlalchemy.update(persistence_model.EmbeddingModel)
            .where(persistence_model.EmbeddingModel.uuid == model_uuid)
            .values(**model_data)
        )

        await self.ap.model_mgr.remove_embedding_model(model_uuid)

        embedding_model = await self.get_embedding_model(model_uuid)

        await self.ap.model_mgr.load_embedding_model(embedding_model)

    async def delete_embedding_model(self, model_uuid: str) -> None:
        await self.ap.persistence_mgr.execute_async(
            sqlalchemy.delete(persistence_model.EmbeddingModel).where(
                persistence_model.EmbeddingModel.uuid == model_uuid
            )
        )

        await self.ap.model_mgr.remove_embedding_model(model_uuid)

    async def test_embedding_model(self, model_uuid: str, model_data: dict) -> None:
        runtime_embedding_model: model_requester.RuntimeEmbeddingModel | None = None

        if model_uuid != '_':
            for model in self.ap.model_mgr.embedding_models:
                if model.model_entity.uuid == model_uuid:
                    runtime_embedding_model = model
                    break

            if runtime_embedding_model is None:
                raise Exception('model not found')

        else:
            runtime_embedding_model = await self.ap.model_mgr.init_runtime_embedding_model(model_data)

        await runtime_embedding_model.requester.invoke_embedding(
            model=runtime_embedding_model,
            input_text=['Hello, world!'],
            extra_args={},
        )
@@ -22,11 +22,14 @@ from ..api.http.service import user as user_service
from ..api.http.service import model as model_service
from ..api.http.service import pipeline as pipeline_service
from ..api.http.service import bot as bot_service
from ..api.http.service import knowledge as knowledge_service
from ..discover import engine as discover_engine
from ..storage import mgr as storagemgr
from ..utils import logcache
from . import taskmgr
from . import entities as core_entities
from ..rag.knowledge import kbmgr as rag_mgr
from ..vector import mgr as vectordb_mgr


class Application:
@@ -47,6 +50,8 @@ class Application:

    model_mgr: llm_model_mgr.ModelManager = None

    rag_mgr: rag_mgr.RAGManager = None

    # TODO move to pipeline
    tool_mgr: llm_tool_mgr.ToolManager = None

@@ -93,6 +98,8 @@ class Application:

    persistence_mgr: persistencemgr.PersistenceManager = None

    vector_db_mgr: vectordb_mgr.VectorDBManager = None

    http_ctrl: http_controller.HTTPController = None

    log_cache: logcache.LogCache = None
@@ -103,12 +110,16 @@ class Application:

    user_service: user_service.UserService = None

    model_service: model_service.ModelsService = None
    llm_model_service: model_service.LLMModelsService = None

    embedding_models_service: model_service.EmbeddingModelsService = None

    pipeline_service: pipeline_service.PipelineService = None

    bot_service: bot_service.BotService = None

    knowledge_service: knowledge_service.KnowledgeService = None

    def __init__(self):
        pass

@@ -143,6 +154,7 @@ class Application:
                name='http-api-controller',
                scopes=[core_entities.LifecycleControlScope.APPLICATION],
            )

            self.task_mgr.create_task(
                never_ending(),
                name='never-ending-task',
@@ -19,7 +19,7 @@ class LifecycleControlScope(enum.Enum):
    APPLICATION = 'application'
    PLATFORM = 'platform'
    PLUGIN = 'plugin'
    PROVIDER = 'provider'
    PROVIDER = 'provider'


class LauncherTypes(enum.Enum):
@@ -9,6 +9,7 @@ from ...command import cmdmgr
from ...provider.session import sessionmgr as llm_session_mgr
from ...provider.modelmgr import modelmgr as llm_model_mgr
from ...provider.tools import toolmgr as llm_tool_mgr
from ...rag.knowledge import kbmgr as rag_mgr
from ...platform import botmgr as im_mgr
from ...persistence import mgr as persistencemgr
from ...api.http.controller import main as http_controller
@@ -16,9 +17,11 @@ from ...api.http.service import user as user_service
from ...api.http.service import model as model_service
from ...api.http.service import pipeline as pipeline_service
from ...api.http.service import bot as bot_service
from ...api.http.service import knowledge as knowledge_service
from ...discover import engine as discover_engine
from ...storage import mgr as storagemgr
from ...utils import logcache
from ...vector import mgr as vectordb_mgr
from .. import taskmgr


@@ -88,6 +91,15 @@ class BuildAppStage(stage.BootingStage):
        await pipeline_mgr.initialize()
        ap.pipeline_mgr = pipeline_mgr

        rag_mgr_inst = rag_mgr.RAGManager(ap)
        await rag_mgr_inst.initialize()
        ap.rag_mgr = rag_mgr_inst

        # 初始化向量数据库管理器
        vectordb_mgr_inst = vectordb_mgr.VectorDBManager(ap)
        await vectordb_mgr_inst.initialize()
        ap.vector_db_mgr = vectordb_mgr_inst

        http_ctrl = http_controller.HTTPController(ap)
        await http_ctrl.initialize()
        ap.http_ctrl = http_ctrl
@@ -95,8 +107,11 @@ class BuildAppStage(stage.BootingStage):
        user_service_inst = user_service.UserService(ap)
        ap.user_service = user_service_inst

        model_service_inst = model_service.ModelsService(ap)
        ap.model_service = model_service_inst
        llm_model_service_inst = model_service.LLMModelsService(ap)
        ap.llm_model_service = llm_model_service_inst

        embedding_models_service_inst = model_service.EmbeddingModelsService(ap)
        ap.embedding_models_service = embedding_models_service_inst

        pipeline_service_inst = pipeline_service.PipelineService(ap)
        ap.pipeline_service = pipeline_service_inst
@@ -104,5 +119,8 @@ class BuildAppStage(stage.BootingStage):
        bot_service_inst = bot_service.BotService(ap)
        ap.bot_service = bot_service_inst

        knowledge_service_inst = knowledge_service.KnowledgeService(ap)
        ap.knowledge_service = knowledge_service_inst

        ctrl = controller.Controller(ap)
        ap.ctrl = ctrl
@@ -23,3 +23,24 @@ class LLMModel(Base):
        server_default=sqlalchemy.func.now(),
        onupdate=sqlalchemy.func.now(),
    )


class EmbeddingModel(Base):
    """Embedding 模型"""

    __tablename__ = 'embedding_models'

    uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True)
    name = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
    description = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
    requester = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
    requester_config = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default={})
    api_keys = sqlalchemy.Column(sqlalchemy.JSON, nullable=False)
    extra_args = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default={})
    created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
    updated_at = sqlalchemy.Column(
        sqlalchemy.DateTime,
        nullable=False,
        server_default=sqlalchemy.func.now(),
        onupdate=sqlalchemy.func.now(),
    )
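A hedged sketch of inserting an EmbeddingModel row with plain SQLAlchemy, outside the application's persistence manager. The column names come from the model above; the module paths assume the repository root is on sys.path and that the models share the `Base` imported from the neighboring `base` module (as `rag.py` does), and the SQLite URL and field values are illustrative only:

```python
import uuid

import sqlalchemy
from sqlalchemy.orm import Session

from pkg.entity.persistence.base import Base           # import path is an assumption
from pkg.entity.persistence.model import EmbeddingModel  # import path is an assumption

# Illustrative standalone engine; LangBot normally goes through persistence_mgr.
engine = sqlalchemy.create_engine('sqlite:///./example.db')
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.execute(
        sqlalchemy.insert(EmbeddingModel).values(
            uuid=str(uuid.uuid4()),
            name='text-embedding-3-small',              # placeholder
            description='example embedding model',
            requester='openai-embedding',               # placeholder requester name
            requester_config={'base_url': 'https://api.openai.com/v1'},
            api_keys=['sk-...'],
            extra_args={},
        )
    )
    session.commit()
```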
@@ -20,7 +20,6 @@ class LegacyPipeline(Base):
    )
    for_version = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
    is_default = sqlalchemy.Column(sqlalchemy.Boolean, nullable=False, default=False)

    stages = sqlalchemy.Column(sqlalchemy.JSON, nullable=False)
    config = sqlalchemy.Column(sqlalchemy.JSON, nullable=False)

@@ -43,3 +42,4 @@ class PipelineRunRecord(Base):
    started_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False)
    finished_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False)
    result = sqlalchemy.Column(sqlalchemy.JSON, nullable=False)
    knowledge_base_uuid = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
pkg/entity/persistence/rag.py (new file, 50 lines)

@@ -0,0 +1,50 @@
import sqlalchemy
from .base import Base

# Base = declarative_base()
# DATABASE_URL = os.getenv('DATABASE_URL', 'sqlite:///./rag_knowledge.db')
# print("Using database URL:", DATABASE_URL)


# engine = create_engine(DATABASE_URL, connect_args={'check_same_thread': False})

# SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

# def create_db_and_tables():
#     """Creates all database tables defined in the Base."""
#     Base.metadata.create_all(bind=engine)
#     print('Database tables created or already exist.')


class KnowledgeBase(Base):
    __tablename__ = 'knowledge_bases'
    uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True)
    name = sqlalchemy.Column(sqlalchemy.String, index=True)
    description = sqlalchemy.Column(sqlalchemy.Text)
    created_at = sqlalchemy.Column(sqlalchemy.DateTime, default=sqlalchemy.func.now())
    embedding_model_uuid = sqlalchemy.Column(sqlalchemy.String, default='')
    top_k = sqlalchemy.Column(sqlalchemy.Integer, default=5)


class File(Base):
    __tablename__ = 'knowledge_base_files'
    uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True)
    kb_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
    file_name = sqlalchemy.Column(sqlalchemy.String)
    extension = sqlalchemy.Column(sqlalchemy.String)
    created_at = sqlalchemy.Column(sqlalchemy.DateTime, default=sqlalchemy.func.now())
    status = sqlalchemy.Column(sqlalchemy.String, default='pending')  # pending, processing, completed, failed


class Chunk(Base):
    __tablename__ = 'knowledge_base_chunks'
    uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True)
    file_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
    text = sqlalchemy.Column(sqlalchemy.Text)


# class Vector(Base):
#     __tablename__ = 'knowledge_base_vectors'
#     uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True)
#     chunk_id = sqlalchemy.Column(sqlalchemy.String, nullable=True)
#     embedding = sqlalchemy.Column(sqlalchemy.LargeBinary)
pkg/entity/persistence/vector.py (new file, 13 lines)

@@ -0,0 +1,13 @@
from sqlalchemy import Column, Integer, ForeignKey, LargeBinary
from sqlalchemy.orm import declarative_base, relationship

Base = declarative_base()


class Vector(Base):
    __tablename__ = 'vectors'
    id = Column(Integer, primary_key=True, index=True)
    chunk_id = Column(Integer, ForeignKey('chunks.id'), unique=True)
    embedding = Column(LargeBinary)  # Store embeddings as binary

    chunk = relationship('Chunk', back_populates='vector')
pkg/entity/rag/__init__.py (new empty file)
pkg/entity/rag/retriever.py (new file, 13 lines)

@@ -0,0 +1,13 @@
from __future__ import annotations

import pydantic

from typing import Any


class RetrieveResultEntry(pydantic.BaseModel):
    id: str

    metadata: dict[str, Any]

    distance: float
@@ -79,7 +79,7 @@ class PersistenceManager:
                'stages': pipeline_service.default_stage_order,
                'is_default': True,
                'name': 'ChatPipeline',
                'description': 'Default pipeline provided, your new bots will be automatically bound to this pipeline | 默认提供的流水线,您配置的机器人将自动绑定到此流水线',
                'description': 'Default pipeline, new bots will be bound to this pipeline | 默认提供的流水线,您配置的机器人将自动绑定到此流水线',
                'config': pipeline_config,
            }
pkg/persistence/migrations/dbm004_rag_kb_uuid.py (new file, 38 lines)

@@ -0,0 +1,38 @@
from .. import migration

import sqlalchemy

from ...entity.persistence import pipeline as persistence_pipeline


@migration.migration_class(4)
class DBMigrateRAGKBUUID(migration.DBMigration):
    """RAG知识库UUID"""

    async def upgrade(self):
        """升级"""
        # read all pipelines
        pipelines = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_pipeline.LegacyPipeline))

        for pipeline in pipelines:
            serialized_pipeline = self.ap.persistence_mgr.serialize_model(persistence_pipeline.LegacyPipeline, pipeline)

            config = serialized_pipeline['config']

            if 'knowledge-base' not in config['ai']['local-agent']:
                config['ai']['local-agent']['knowledge-base'] = ''

            await self.ap.persistence_mgr.execute_async(
                sqlalchemy.update(persistence_pipeline.LegacyPipeline)
                .where(persistence_pipeline.LegacyPipeline.uuid == serialized_pipeline['uuid'])
                .values(
                    {
                        'config': config,
                        'for_version': self.ap.ver_mgr.get_current_version(),
                    }
                )
            )

    async def downgrade(self):
        """降级"""
        pass
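A small sketch of the migration's effect on a pipeline config: it adds an empty `knowledge-base` key under `ai.local-agent` when the key is missing. The surrounding keys are illustrative only:

```python
# Before the migration (other keys are illustrative):
config = {'ai': {'local-agent': {'model': 'LLM_MODEL_UUID'}}}

# The upgrade step is effectively:
if 'knowledge-base' not in config['ai']['local-agent']:
    config['ai']['local-agent']['knowledge-base'] = ''

# After: {'ai': {'local-agent': {'model': 'LLM_MODEL_UUID', 'knowledge-base': ''}}}
print(config)
```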
@@ -144,23 +144,27 @@ class RuntimePipeline:
            result = await result

            if isinstance(result, pipeline_entities.StageProcessResult):  # 直接返回结果
                self.ap.logger.debug(f'Stage {stage_container.inst_name} processed query {query} res {result}')
                self.ap.logger.debug(
                    f'Stage {stage_container.inst_name} processed query {query.query_id} res {result.result_type}'
                )
                await self._check_output(query, result)

                if result.result_type == pipeline_entities.ResultType.INTERRUPT:
                    self.ap.logger.debug(f'Stage {stage_container.inst_name} interrupted query {query}')
                    self.ap.logger.debug(f'Stage {stage_container.inst_name} interrupted query {query.query_id}')
                    break
                elif result.result_type == pipeline_entities.ResultType.CONTINUE:
                    query = result.new_query
            elif isinstance(result, typing.AsyncGenerator):  # 生成器
                self.ap.logger.debug(f'Stage {stage_container.inst_name} processed query {query} gen')
                self.ap.logger.debug(f'Stage {stage_container.inst_name} processed query {query.query_id} gen')

                async for sub_result in result:
                    self.ap.logger.debug(f'Stage {stage_container.inst_name} processed query {query} res {sub_result}')
                    self.ap.logger.debug(
                        f'Stage {stage_container.inst_name} processed query {query.query_id} res {sub_result.result_type}'
                    )
                    await self._check_output(query, sub_result)

                    if sub_result.result_type == pipeline_entities.ResultType.INTERRUPT:
                        self.ap.logger.debug(f'Stage {stage_container.inst_name} interrupted query {query}')
                        self.ap.logger.debug(f'Stage {stage_container.inst_name} interrupted query {query.query_id}')
                        break
                    elif sub_result.result_type == pipeline_entities.ResultType.CONTINUE:
                        query = sub_result.new_query
@@ -192,7 +196,7 @@ class RuntimePipeline:
            if event_ctx.is_prevented_default():
                return

            self.ap.logger.debug(f'Processing query {query}')
            self.ap.logger.debug(f'Processing query {query.query_id}')

            await self._execute_from_stage(0, query)
        except Exception as e:
@@ -200,7 +204,7 @@ class RuntimePipeline:
            self.ap.logger.error(f'处理请求时出错 query_id={query.query_id} stage={inst_name} : {e}')
            self.ap.logger.error(f'Traceback: {traceback.format_exc()}')
        finally:
            self.ap.logger.debug(f'Query {query} processed')
            self.ap.logger.debug(f'Query {query.query_id} processed')


class PipelineManager:
@@ -80,14 +80,15 @@ class PreProcessor(stage.PipelineStage):
                    if me.type == 'image_url':
                        msg.content.remove(me)

        content_list = []
        content_list: list[llm_entities.ContentElement] = []

        plain_text = ''
        qoute_msg = query.pipeline_config['trigger'].get('misc', '').get('combine-quote-message')

        # tidy the content_list
        # combine all text content into one, and put it in the first position
        for me in query.message_chain:
            if isinstance(me, platform_message.Plain):
                content_list.append(llm_entities.ContentElement.from_text(me.text))
                plain_text += me.text
            elif isinstance(me, platform_message.Image):
                if selected_runner != 'local-agent' or query.use_llm_model.model_entity.abilities.__contains__(
@@ -106,6 +107,8 @@ class PreProcessor(stage.PipelineStage):
                    if msg.base64 is not None:
                        content_list.append(llm_entities.ContentElement.from_image_base64(msg.base64))

        content_list.insert(0, llm_entities.ContentElement.from_text(plain_text))

        query.variables['user_message_text'] = plain_text

        query.user_message = llm_entities.Message(role='user', content=content_list)
@@ -119,7 +119,7 @@ class EventLogger:
    async def _truncate_logs(self):
        if len(self.logs) > MAX_LOG_COUNT:
            for i in range(DELETE_COUNT_PER_TIME):
                for image_key in self.logs[i].images:
                for image_key in self.logs[i].images:  # type: ignore
                    await self.ap.storage_mgr.storage_provider.delete(image_key)
            self.logs = self.logs[DELETE_COUNT_PER_TIME:]
@@ -654,10 +654,10 @@ class DiscordMessageConverter(adapter.MessageConverter):
                # 确保路径没有空字节
                clean_path = ele.path.replace('\x00', '')
                clean_path = os.path.abspath(clean_path)

                if not os.path.exists(clean_path):
                    continue  # 跳过不存在的文件

                try:
                    with open(clean_path, 'rb') as f:
                        image_bytes = f.read()
@@ -677,12 +677,13 @@ class DiscordMessageConverter(adapter.MessageConverter):
                            filename = f'{uuid.uuid4()}.webp'
                        # 默认保持PNG
                except Exception as e:
                    print(f"Error reading image file {clean_path}: {e}")
                    print(f'Error reading image file {clean_path}: {e}')
                    continue  # 跳过读取失败的文件

                if image_bytes:
                    # 使用BytesIO创建文件对象,避免路径问题
                    import io

                    image_files.append(discord.File(fp=io.BytesIO(image_bytes), filename=filename))
            elif isinstance(ele, platform_message.Plain):
                text_string += ele.text
@@ -1003,25 +1004,25 @@ class DiscordAdapter(adapter.MessagePlatformAdapter):

    async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
        msg_to_send, image_files = await self.message_converter.yiri2target(message)

        try:
            # 获取频道对象
            channel = self.bot.get_channel(int(target_id))
            if channel is None:
                # 如果本地缓存中没有,尝试从API获取
                channel = await self.bot.fetch_channel(int(target_id))

            args = {
                'content': msg_to_send,
            }

            if len(image_files) > 0:
                args['files'] = image_files

            await channel.send(**args)

        except Exception as e:
            await self.logger.error(f"Discord send_message failed: {e}")
            await self.logger.error(f'Discord send_message failed: {e}')
            raise e

    async def reply_message(
@@ -378,15 +378,15 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
            if 'im.message.receive_v1' == type:
                try:
                    event = await self.event_converter.target2yiri(p2v1, self.api_client)
                except Exception as e:
                    await self.logger.error(f"Error in lark callback: {traceback.format_exc()}")
                except Exception:
                    await self.logger.error(f'Error in lark callback: {traceback.format_exc()}')

                if event.__class__ in self.listeners:
                    await self.listeners[event.__class__](event, self)

            return {'code': 200, 'message': 'ok'}
        except Exception as e:
            await self.logger.error(f"Error in lark callback: {traceback.format_exc()}")
        except Exception:
            await self.logger.error(f'Error in lark callback: {traceback.format_exc()}')
            return {'code': 500, 'message': 'error'}

        async def on_message(event: lark_oapi.im.v1.P2ImMessageReceiveV1):
@@ -72,8 +72,9 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter):
|
||||
content=content_list,
|
||||
)
|
||||
nakuru_forward_node_list.append(nakuru_forward_node)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
nakuru_msg_list.append(nakuru_forward_node_list)
|
||||
@@ -276,7 +277,7 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter):
|
||||
# 注册监听器
|
||||
self.bot.receiver(source_cls.__name__)(listener_wrapper)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in nakuru register_listener: {traceback.format_exc()}")
|
||||
self.logger.error(f'Error in nakuru register_listener: {traceback.format_exc()}')
|
||||
raise e
|
||||
|
||||
def unregister_listener(
|
||||
|
||||
@@ -125,8 +125,8 @@ class OfficialAccountAdapter(adapter.MessagePlatformAdapter):
|
||||
self.bot_account_id = event.receiver_id
|
||||
try:
|
||||
return await callback(await self.event_converter.target2yiri(event), self)
|
||||
except Exception as e:
|
||||
await self.logger.error(f"Error in officialaccount callback: {traceback.format_exc()}")
|
||||
except Exception:
|
||||
await self.logger.error(f'Error in officialaccount callback: {traceback.format_exc()}')
|
||||
|
||||
if event_type == platform_events.FriendMessage:
|
||||
self.bot.on_message('text')(on_message)
|
||||
|
||||
@@ -154,10 +154,7 @@ class QQOfficialAdapter(adapter.MessagePlatformAdapter):
|
||||
raise ParamNotEnoughError('QQ官方机器人缺少相关配置项,请查看文档或联系管理员')
|
||||
|
||||
self.bot = QQOfficialClient(
|
||||
app_id=config['appid'],
|
||||
secret=config['secret'],
|
||||
token=config['token'],
|
||||
logger=self.logger
|
||||
app_id=config['appid'], secret=config['secret'], token=config['token'], logger=self.logger
|
||||
)
|
||||
|
||||
async def reply_message(
|
||||
@@ -224,8 +221,8 @@ class QQOfficialAdapter(adapter.MessagePlatformAdapter):
|
||||
self.bot_account_id = 'justbot'
|
||||
try:
|
||||
return await callback(await self.event_converter.target2yiri(event), self)
|
||||
except Exception as e:
|
||||
await self.logger.error(f"Error in qqofficial callback: {traceback.format_exc()}")
|
||||
except Exception:
|
||||
await self.logger.error(f'Error in qqofficial callback: {traceback.format_exc()}')
|
||||
|
||||
if event_type == platform_events.FriendMessage:
|
||||
self.bot.on_message('DIRECT_MESSAGE_CREATE')(on_message)
|
||||
|
||||
@@ -104,7 +104,9 @@ class SlackAdapter(adapter.MessagePlatformAdapter):
|
||||
if missing_keys:
|
||||
raise ParamNotEnoughError('Slack机器人缺少相关配置项,请查看文档或联系管理员')
|
||||
|
||||
self.bot = SlackClient(bot_token=self.config['bot_token'], signing_secret=self.config['signing_secret'], logger=self.logger)
|
||||
self.bot = SlackClient(
|
||||
bot_token=self.config['bot_token'], signing_secret=self.config['signing_secret'], logger=self.logger
|
||||
)
|
||||
|
||||
async def reply_message(
|
||||
self,
|
||||
@@ -139,8 +141,8 @@ class SlackAdapter(adapter.MessagePlatformAdapter):
|
||||
self.bot_account_id = 'SlackBot'
|
||||
try:
|
||||
return await callback(await self.event_converter.target2yiri(event, self.bot), self)
|
||||
except Exception as e:
|
||||
await self.logger.error(f"Error in slack callback: {traceback.format_exc()}")
|
||||
except Exception:
|
||||
await self.logger.error(f'Error in slack callback: {traceback.format_exc()}')
|
||||
|
||||
if event_type == platform_events.FriendMessage:
|
||||
self.bot.on_message('im')(on_message)
|
||||
|
||||
@@ -160,8 +160,8 @@ class TelegramAdapter(adapter.MessagePlatformAdapter):
|
||||
try:
|
||||
lb_event = await self.event_converter.target2yiri(update, self.bot, self.bot_account_id)
|
||||
await self.listeners[type(lb_event)](lb_event, self)
|
||||
except Exception as e:
|
||||
await self.logger.error(f"Error in telegram callback: {traceback.format_exc()}")
|
||||
except Exception:
|
||||
await self.logger.error(f'Error in telegram callback: {traceback.format_exc()}')
|
||||
|
||||
self.application = ApplicationBuilder().token(self.config['token']).build()
|
||||
self.bot = self.application.bot
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import requests
|
||||
import websockets
|
||||
import websocket
|
||||
import json
|
||||
import time
|
||||
@@ -10,32 +9,25 @@ from libs.wechatpad_api.client import WeChatPadClient
|
||||
import typing
|
||||
import asyncio
|
||||
import traceback
|
||||
import time
|
||||
import re
|
||||
import base64
|
||||
import uuid
|
||||
import json
|
||||
import os
|
||||
import copy
|
||||
import datetime
|
||||
import threading
|
||||
|
||||
import quart
|
||||
import aiohttp
|
||||
|
||||
from .. import adapter
|
||||
from ...pipeline.longtext.strategies import forward
|
||||
from ...core import app
|
||||
from ..types import message as platform_message
|
||||
from ..types import events as platform_events
|
||||
from ..types import entities as platform_entities
|
||||
from ...utils import image
|
||||
from ..logger import EventLogger
|
||||
import xml.etree.ElementTree as ET
|
||||
from typing import Optional, List, Tuple
|
||||
from typing import Optional, Tuple
|
||||
from functools import partial
|
||||
import logging
|
||||
|
||||
|
||||
class WeChatPadMessageConverter(adapter.MessageConverter):
|
||||
|
||||
def __init__(self, config: dict, logger: logging.Logger):
|
||||
@@ -44,19 +36,14 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
|
||||
self.logger = logger
|
||||
|
||||
@staticmethod
|
||||
async def yiri2target(
|
||||
message_chain: platform_message.MessageChain
|
||||
) -> list[dict]:
|
||||
async def yiri2target(message_chain: platform_message.MessageChain) -> list[dict]:
|
||||
content_list = []
|
||||
current_file_path = os.path.abspath(__file__)
|
||||
|
||||
|
||||
|
||||
for component in message_chain:
|
||||
if isinstance(component, platform_message.At):
|
||||
content_list.append({"type": "at", "target": component.target})
|
||||
content_list.append({'type': 'at', 'target': component.target})
|
||||
elif isinstance(component, platform_message.Plain):
|
||||
content_list.append({"type": "text", "content": component.text})
|
||||
content_list.append({'type': 'text', 'content': component.text})
|
||||
elif isinstance(component, platform_message.Image):
|
||||
if component.url:
|
||||
async with httpx.AsyncClient() as client:
|
||||
@@ -68,15 +55,16 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
|
||||
else:
|
||||
raise Exception('获取文件失败')
|
||||
# pass
|
||||
content_list.append({"type": "image", "image": base64_str})
|
||||
content_list.append({'type': 'image', 'image': base64_str})
|
||||
elif component.base64:
|
||||
content_list.append({"type": "image", "image": component.base64})
|
||||
content_list.append({'type': 'image', 'image': component.base64})
|
||||
|
||||
elif isinstance(component, platform_message.WeChatEmoji):
|
||||
content_list.append(
|
||||
{'type': 'WeChatEmoji', 'emoji_md5': component.emoji_md5, 'emoji_size': component.emoji_size})
|
||||
{'type': 'WeChatEmoji', 'emoji_md5': component.emoji_md5, 'emoji_size': component.emoji_size}
|
||||
)
|
||||
elif isinstance(component, platform_message.Voice):
|
||||
content_list.append({"type": "voice", "data": component.url, "duration": component.length, "forma": 0})
|
||||
content_list.append({'type': 'voice', 'data': component.url, 'duration': component.length, 'forma': 0})
|
||||
elif isinstance(component, platform_message.WeChatAppMsg):
|
||||
content_list.append({'type': 'WeChatAppMsg', 'app_msg': component.app_msg})
|
||||
elif isinstance(component, platform_message.Forward):
|
||||
@@ -86,7 +74,6 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
|
||||
|
||||
return content_list
|
||||
|
||||
|
||||
async def target2yiri(
|
||||
self,
|
||||
message: dict,
|
||||
@@ -97,15 +84,16 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
|
||||
message_list = []
|
||||
bot_wxid = self.config['wxid']
|
||||
ats_bot = False # 是否被@
|
||||
content = message["content"]["str"]
|
||||
content = message['content']['str']
|
||||
content_no_preifx = content # 群消息则去掉前缀
|
||||
is_group_message = self._is_group_message(message)
|
||||
if is_group_message:
|
||||
ats_bot = self._ats_bot(message, bot_account_id)
|
||||
|
||||
self.logger.info(f"ats_bot: {ats_bot}; bot_account_id: {bot_account_id}; bot_wxid: {bot_wxid}")
|
||||
if "@所有人" in content:
|
||||
message_list.append(platform_message.AtAll())
|
||||
elif ats_bot:
|
||||
if ats_bot:
|
||||
message_list.append(platform_message.At(target=bot_account_id))
|
||||
|
||||
# 解析@信息并生成At组件
|
||||
@@ -116,7 +104,7 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
|
||||
|
||||
content_no_preifx, _ = self._extract_content_and_sender(content)
|
||||
|
||||
msg_type = message["msg_type"]
|
||||
msg_type = message['msg_type']
|
||||
|
||||
# 映射消息类型到处理器方法
|
||||
handler_map = {
|
||||
@@ -138,11 +126,7 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
|
||||
|
||||
return platform_message.MessageChain(message_list)
|
||||
|
||||
async def _handler_text(
|
||||
self,
|
||||
message: Optional[dict],
|
||||
content_no_preifx: str
|
||||
) -> platform_message.MessageChain:
|
||||
async def _handler_text(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain:
|
||||
"""处理文本消息 (msg_type=1)"""
|
||||
if message and self._is_group_message(message):
|
||||
pattern = r'@\S{1,20}'
|
||||
@@ -150,16 +134,12 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
|
||||
|
||||
return platform_message.MessageChain([platform_message.Plain(content_no_preifx)])
|
||||
|
||||
async def _handler_image(
|
||||
self,
|
||||
message: Optional[dict],
|
||||
content_no_preifx: str
|
||||
) -> platform_message.MessageChain:
|
||||
async def _handler_image(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain:
|
||||
"""处理图像消息 (msg_type=3)"""
|
||||
try:
|
||||
image_xml = content_no_preifx
|
||||
if not image_xml:
|
||||
return platform_message.MessageChain([platform_message.Unknown("[图片内容为空]")])
|
||||
return platform_message.MessageChain([platform_message.Unknown('[图片内容为空]')])
|
||||
root = ET.fromstring(image_xml)
|
||||
|
||||
# 提取img标签的属性
|
||||
@@ -169,28 +149,22 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
|
||||
cdnthumburl = img_tag.get('cdnthumburl')
|
||||
# cdnmidimgurl = img_tag.get('cdnmidimgurl')
|
||||
|
||||
|
||||
image_data = self.bot.cdn_download(aeskey=aeskey, file_type=1, file_url=cdnthumburl)
|
||||
if image_data["Data"]['FileData'] == '':
|
||||
if image_data['Data']['FileData'] == '':
|
||||
image_data = self.bot.cdn_download(aeskey=aeskey, file_type=2, file_url=cdnthumburl)
|
||||
base64_str = image_data["Data"]['FileData']
|
||||
base64_str = image_data['Data']['FileData']
|
||||
# self.logger.info(f"data:image/png;base64,{base64_str}")
|
||||
|
||||
|
||||
elements = [
|
||||
platform_message.Image(base64=f"data:image/png;base64,{base64_str}"),
|
||||
platform_message.Image(base64=f'data:image/png;base64,{base64_str}'),
|
||||
# platform_message.WeChatForwardImage(xml_data=image_xml) # 微信消息转发
|
||||
]
|
||||
return platform_message.MessageChain(elements)
|
||||
except Exception as e:
|
||||
self.logger.error(f"处理图片失败: {str(e)}")
|
||||
return platform_message.MessageChain([platform_message.Unknown("[图片处理失败]")])
|
||||
self.logger.error(f'处理图片失败: {str(e)}')
|
||||
return platform_message.MessageChain([platform_message.Unknown('[图片处理失败]')])
|
||||
|
||||
async def _handler_voice(
|
||||
self,
|
||||
message: Optional[dict],
|
||||
content_no_preifx: str
|
||||
) -> platform_message.MessageChain:
|
||||
async def _handler_voice(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain:
|
||||
"""处理语音消息 (msg_type=34)"""
|
||||
message_List = []
|
||||
try:
|
||||
@@ -206,39 +180,33 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
|
||||
bufid = voicemsg.get('bufid')
|
||||
length = voicemsg.get('voicelength')
|
||||
voice_data = self.bot.get_msg_voice(buf_id=str(bufid), length=int(length), msgid=str(new_msg_id))
|
||||
audio_base64 = voice_data["Data"]['Base64']
|
||||
audio_base64 = voice_data['Data']['Base64']
|
||||
|
||||
# 验证语音数据有效性
|
||||
if not audio_base64:
|
||||
message_List.append(platform_message.Unknown(text="[语音内容为空]"))
|
||||
message_List.append(platform_message.Unknown(text='[语音内容为空]'))
|
||||
return platform_message.MessageChain(message_List)
|
||||
|
||||
# 转换为平台支持的语音格式(如 Silk 格式)
|
||||
voice_element = platform_message.Voice(
|
||||
base64=f"data:audio/silk;base64,{audio_base64}"
|
||||
)
|
||||
voice_element = platform_message.Voice(base64=f'data:audio/silk;base64,{audio_base64}')
|
||||
message_List.append(voice_element)
|
||||
|
||||
except KeyError as e:
|
||||
self.logger.error(f"语音数据字段缺失: {str(e)}")
|
||||
message_List.append(platform_message.Unknown(text="[语音数据解析失败]"))
|
||||
self.logger.error(f'语音数据字段缺失: {str(e)}')
|
||||
message_List.append(platform_message.Unknown(text='[语音数据解析失败]'))
|
||||
except Exception as e:
|
||||
self.logger.error(f"处理语音消息异常: {str(e)}")
|
||||
message_List.append(platform_message.Unknown(text="[语音处理失败]"))
|
||||
self.logger.error(f'处理语音消息异常: {str(e)}')
|
||||
message_List.append(platform_message.Unknown(text='[语音处理失败]'))
|
||||
|
||||
return platform_message.MessageChain(message_List)
|
||||
|
||||
async def _handler_compound(
|
||||
self,
|
||||
message: Optional[dict],
|
||||
content_no_preifx: str
|
||||
) -> platform_message.MessageChain:
|
||||
async def _handler_compound(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain:
|
||||
"""处理复合消息 (msg_type=49),根据子类型分派"""
|
||||
try:
|
||||
xml_data = ET.fromstring(content_no_preifx)
|
||||
appmsg_data = xml_data.find('.//appmsg')
|
||||
if appmsg_data:
|
||||
data_type = appmsg_data.findtext('.//type', "")
|
||||
data_type = appmsg_data.findtext('.//type', '')
|
||||
# 二次分派处理器
|
||||
sub_handler_map = {
|
||||
'57': self._handler_compound_quote,
|
||||
@@ -247,9 +215,9 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
|
||||
'74': self._handler_compound_file,
|
||||
'33': self._handler_compound_mini_program,
|
||||
'36': self._handler_compound_mini_program,
|
||||
'2000': partial(self._handler_compound_unsupported, text="[转账消息]"),
|
||||
'2001': partial(self._handler_compound_unsupported, text="[红包消息]"),
|
||||
'51': partial(self._handler_compound_unsupported, text="[视频号消息]"),
|
||||
'2000': partial(self._handler_compound_unsupported, text='[转账消息]'),
|
||||
'2001': partial(self._handler_compound_unsupported, text='[红包消息]'),
|
||||
'51': partial(self._handler_compound_unsupported, text='[视频号消息]'),
|
||||
}
|
||||
|
||||
handler = sub_handler_map.get(data_type, self._handler_compound_unsupported)
|
||||
@@ -260,56 +228,54 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
|
||||
else:
|
||||
return platform_message.MessageChain([platform_message.Unknown(text=content_no_preifx)])
|
||||
except Exception as e:
|
||||
self.logger.error(f"解析复合消息失败: {str(e)}")
|
||||
self.logger.error(f'解析复合消息失败: {str(e)}')
|
||||
return platform_message.MessageChain([platform_message.Unknown(text=content_no_preifx)])
|
||||
|
||||
async def _handler_compound_quote(
|
||||
self,
|
||||
message: Optional[dict],
|
||||
xml_data: ET.Element
|
||||
self, message: Optional[dict], xml_data: ET.Element
|
||||
) -> platform_message.MessageChain:
|
||||
"""处理引用消息 (data_type=57)"""
|
||||
message_list = []
|
||||
# self.logger.info("_handler_compound_quote", ET.tostring(xml_data, encoding='unicode'))
|
||||
# self.logger.info("_handler_compound_quote", ET.tostring(xml_data, encoding='unicode'))
|
||||
appmsg_data = xml_data.find('.//appmsg')
|
||||
quote_data = "" # 引用原文
|
||||
quote_data = '' # 引用原文
|
||||
quote_id = None # 引用消息的原发送者
|
||||
tousername = None # 接收方: 所属微信的wxid
|
||||
user_data = "" # 用户消息
|
||||
user_data = '' # 用户消息
|
||||
sender_id = xml_data.findtext('.//fromusername') # 发送方:单聊用户/群member
|
||||
|
||||
# 引用消息转发
|
||||
if appmsg_data:
|
||||
user_data = appmsg_data.findtext('.//title') or ""
|
||||
user_data = appmsg_data.findtext('.//title') or ''
|
||||
quote_data = appmsg_data.find('.//refermsg').findtext('.//content')
|
||||
quote_id = appmsg_data.find('.//refermsg').findtext('.//chatusr')
|
||||
message_list.append(
|
||||
platform_message.WeChatAppMsg(
|
||||
app_msg=ET.tostring(appmsg_data, encoding='unicode'))
|
||||
)
|
||||
message_list.append(platform_message.WeChatAppMsg(app_msg=ET.tostring(appmsg_data, encoding='unicode')))
|
||||
if message:
|
||||
tousername = message['to_user_name']["str"]
|
||||
|
||||
tousername = message['to_user_name']['str']
|
||||
|
||||
_ = quote_id
|
||||
_ = tousername
|
||||
|
||||
if quote_data:
|
||||
quote_data_message_list = platform_message.MessageChain()
|
||||
# 文本消息
|
||||
try:
|
||||
if "<msg>" not in quote_data:
|
||||
if '<msg>' not in quote_data:
|
||||
quote_data_message_list.append(platform_message.Plain(quote_data))
|
||||
else:
|
||||
# 引用消息展开
|
||||
quote_data_xml = ET.fromstring(quote_data)
|
||||
if quote_data_xml.find("img"):
|
||||
if quote_data_xml.find('img'):
|
||||
quote_data_message_list.extend(await self._handler_image(None, quote_data))
|
||||
elif quote_data_xml.find("voicemsg"):
|
||||
elif quote_data_xml.find('voicemsg'):
|
||||
quote_data_message_list.extend(await self._handler_voice(None, quote_data))
|
||||
elif quote_data_xml.find("videomsg"):
|
||||
elif quote_data_xml.find('videomsg'):
|
||||
quote_data_message_list.extend(await self._handler_default(None, quote_data)) # 先不处理
|
||||
else:
|
||||
# appmsg
|
||||
quote_data_message_list.extend(await self._handler_compound(None, quote_data))
|
||||
except Exception as e:
|
||||
self.logger.error(f"处理引用消息异常 expcetion:{e}")
|
||||
self.logger.error(f'处理引用消息异常 expcetion:{e}')
|
||||
quote_data_message_list.append(platform_message.Plain(quote_data))
|
||||
message_list.append(
|
||||
platform_message.Quote(
|
||||
@@ -324,15 +290,11 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
|
||||
|
||||
return platform_message.MessageChain(message_list)
|
||||
|
||||
async def _handler_compound_file(
|
||||
self,
|
||||
message: dict,
|
||||
xml_data: ET.Element
|
||||
) -> platform_message.MessageChain:
|
||||
async def _handler_compound_file(self, message: dict, xml_data: ET.Element) -> platform_message.MessageChain:
|
||||
"""处理文件消息 (data_type=6)"""
|
||||
file_data = xml_data.find('.//appmsg')
|
||||
|
||||
if file_data.findtext('.//type', "") == "74":
|
||||
if file_data.findtext('.//type', '') == '74':
|
||||
return None
|
||||
|
||||
else:
|
||||
@@ -355,22 +317,21 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
|
||||
|
||||
file_data = self.bot.cdn_download(aeskey=aeskey, file_type=5, file_url=cdnthumburl)
|
||||
|
||||
file_base64 = file_data["Data"]['FileData']
|
||||
file_base64 = file_data['Data']['FileData']
|
||||
# print(file_data)
|
||||
file_size = file_data["Data"]['TotalSize']
|
||||
file_size = file_data['Data']['TotalSize']
|
||||
|
||||
# print(file_base64)
|
||||
return platform_message.MessageChain([
|
||||
platform_message.WeChatFile(file_id=file_id, file_name=file_name, file_size=file_size,
|
||||
file_base64=file_base64),
|
||||
platform_message.WeChatForwardFile(xml_data=xml_data_str)
|
||||
])
|
||||
return platform_message.MessageChain(
|
||||
[
|
||||
platform_message.WeChatFile(
|
||||
file_id=file_id, file_name=file_name, file_size=file_size, file_base64=file_base64
|
||||
),
|
||||
platform_message.WeChatForwardFile(xml_data=xml_data_str),
|
||||
]
|
||||
)
|
||||
|
||||
async def _handler_compound_link(
|
||||
self,
|
||||
message: dict,
|
||||
xml_data: ET.Element
|
||||
) -> platform_message.MessageChain:
|
||||
async def _handler_compound_link(self, message: dict, xml_data: ET.Element) -> platform_message.MessageChain:
|
||||
"""处理链接消息(如公众号文章、外部网页)"""
|
||||
message_list = []
|
||||
try:
|
||||
@@ -383,56 +344,38 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
|
||||
link_title=appmsg.findtext('title', ''),
|
||||
link_desc=appmsg.findtext('des', ''),
|
||||
link_url=appmsg.findtext('url', ''),
|
||||
link_thumb_url=appmsg.findtext("thumburl", '') # 这个字段拿不到
|
||||
link_thumb_url=appmsg.findtext('thumburl', ''), # 这个字段拿不到
|
||||
)
|
||||
)
|
||||
# 还没有发链接的接口, 暂时还需要自己构造appmsg, 先用WeChatAppMsg。
|
||||
message_list.append(
|
||||
platform_message.WeChatAppMsg(
|
||||
app_msg=ET.tostring(appmsg, encoding='unicode')
|
||||
)
|
||||
)
|
||||
message_list.append(platform_message.WeChatAppMsg(app_msg=ET.tostring(appmsg, encoding='unicode')))
|
||||
except Exception as e:
|
||||
self.logger.error(f"解析链接消息失败: {str(e)}")
|
||||
self.logger.error(f'解析链接消息失败: {str(e)}')
|
||||
return platform_message.MessageChain(message_list)
|
||||
|
||||
async def _handler_compound_mini_program(
|
||||
self,
|
||||
message: dict,
|
||||
xml_data: ET.Element
|
||||
self, message: dict, xml_data: ET.Element
|
||||
) -> platform_message.MessageChain:
|
||||
"""处理小程序消息(如小程序卡片、服务通知)"""
|
||||
xml_data_str = ET.tostring(xml_data, encoding='unicode')
|
||||
return platform_message.MessageChain([
|
||||
platform_message.WeChatForwardMiniPrograms(xml_data=xml_data_str)
|
||||
])
|
||||
return platform_message.MessageChain([platform_message.WeChatForwardMiniPrograms(xml_data=xml_data_str)])
|
||||
|
||||
async def _handler_default(
|
||||
self,
|
||||
message: Optional[dict],
|
||||
content_no_preifx: str
|
||||
) -> platform_message.MessageChain:
|
||||
async def _handler_default(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain:
|
||||
"""处理未知消息类型"""
|
||||
if message:
|
||||
msg_type = message["msg_type"]
|
||||
msg_type = message['msg_type']
|
||||
else:
|
||||
msg_type = ""
|
||||
return platform_message.MessageChain([
|
||||
platform_message.Unknown(text=f"[未知消息类型 msg_type:{msg_type}]")
|
||||
])
|
||||
msg_type = ''
|
||||
return platform_message.MessageChain([platform_message.Unknown(text=f'[未知消息类型 msg_type:{msg_type}]')])
|
||||
|
||||
def _handler_compound_unsupported(
|
||||
self,
|
||||
message: dict,
|
||||
xml_data: str,
|
||||
text: Optional[str] = None
|
||||
self, message: dict, xml_data: str, text: Optional[str] = None
|
||||
) -> platform_message.MessageChain:
|
||||
"""处理未支持复合消息类型(msg_type=49)子类型"""
|
||||
if not text:
|
||||
text = f"[xml_data={xml_data}]"
|
||||
text = f'[xml_data={xml_data}]'
|
||||
content_list = []
|
||||
content_list.append(
|
||||
platform_message.Unknown(text=f"[处理未支持复合消息类型[msg_type=49]|{text}"))
|
||||
content_list.append(platform_message.Unknown(text=f'[处理未支持复合消息类型[msg_type=49]|{text}'))
|
||||
|
||||
return platform_message.MessageChain(content_list)
|
||||
|
||||
@@ -441,7 +384,7 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
|
||||
ats_bot = False
|
||||
try:
|
||||
to_user_name = message['to_user_name']['str'] # 接收方: 所属微信的wxid
|
||||
raw_content = message["content"]["str"] # 原始消息内容
|
||||
raw_content = message['content']['str'] # 原始消息内容
|
||||
content_no_prefix, _ = self._extract_content_and_sender(raw_content)
|
||||
# 直接艾特机器人(这个有bug,当被引用的消息里面有@bot,会套娃
|
||||
# ats_bot = ats_bot or (f"@{bot_account_id}" in content_no_prefix)
|
||||
@@ -452,7 +395,7 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
|
||||
msg_source = message.get('msg_source', '') or ''
|
||||
if len(msg_source) > 0:
|
||||
msg_source_data = ET.fromstring(msg_source)
|
||||
at_user_list = msg_source_data.findtext("atuserlist") or ""
|
||||
at_user_list = msg_source_data.findtext('atuserlist') or ''
|
||||
ats_bot = ats_bot or (to_user_name in at_user_list)
|
||||
# 引用bot
|
||||
if message.get('msg_type', 0) == 49:
|
||||
@@ -463,7 +406,7 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
|
||||
quote_id = appmsg_data.find('.//refermsg').findtext('.//chatusr') # 引用消息的原发送者
|
||||
ats_bot = ats_bot or (quote_id == tousername)
|
||||
except Exception as e:
|
||||
self.logger.error(f"_ats_bot got except: {e}")
|
||||
self.logger.error(f'_ats_bot got except: {e}')
|
||||
finally:
|
||||
return ats_bot
|
||||
|
||||
@@ -489,21 +432,21 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
|
||||
try:
|
||||
# 检查消息开头,如果有 wxid_sbitaz0mt65n22:\n 则删掉
|
||||
# add: 有些用户的wxid不是上述格式。换成user_name:
|
||||
regex = re.compile(r"^[a-zA-Z0-9_\-]{5,20}:")
|
||||
line_split = raw_content.split("\n")
|
||||
regex = re.compile(r'^[a-zA-Z0-9_\-]{5,20}:')
|
||||
line_split = raw_content.split('\n')
|
||||
if len(line_split) > 0 and regex.match(line_split[0]):
|
||||
raw_content = "\n".join(line_split[1:])
|
||||
sender_id = line_split[0].strip(":")
|
||||
raw_content = '\n'.join(line_split[1:])
|
||||
sender_id = line_split[0].strip(':')
|
||||
return raw_content, sender_id
|
||||
except Exception as e:
|
||||
self.logger.error(f"_extract_content_and_sender got except: {e}")
|
||||
self.logger.error(f'_extract_content_and_sender got except: {e}')
|
||||
finally:
|
||||
return raw_content, None
|
||||
|
||||
# 是否是群消息
|
||||
def _is_group_message(self, message: dict) -> bool:
|
||||
from_user_name = message['from_user_name']['str']
|
||||
return from_user_name.endswith("@chatroom")
|
||||
return from_user_name.endswith('@chatroom')
|
||||
|
||||
|
||||
class WeChatPadEventConverter(adapter.EventConverter):
|
||||
@@ -514,9 +457,7 @@ class WeChatPadEventConverter(adapter.EventConverter):
|
||||
self.logger = logger
|
||||
|
||||
@staticmethod
|
||||
async def yiri2target(
|
||||
event: platform_events.MessageEvent
|
||||
) -> dict:
|
||||
async def yiri2target(event: platform_events.MessageEvent) -> dict:
|
||||
pass
|
||||
|
||||
async def target2yiri(
|
||||
@@ -526,10 +467,12 @@ class WeChatPadEventConverter(adapter.EventConverter):
|
||||
) -> platform_events.MessageEvent:
|
||||
|
||||
# 排除公众号以及微信团队消息
|
||||
if event['from_user_name']['str'].startswith('gh_') \
|
||||
or event['from_user_name']['str']=='weixin'\
|
||||
or event['from_user_name']['str'] == "newsapp"\
|
||||
or event['from_user_name']['str'] == self.config["wxid"]:
|
||||
if (
|
||||
event['from_user_name']['str'].startswith('gh_')
|
||||
or event['from_user_name']['str'] == 'weixin'
|
||||
or event['from_user_name']['str'] == 'newsapp'
|
||||
or event['from_user_name']['str'] == self.config['wxid']
|
||||
):
|
||||
return None
|
||||
message_chain = await self.message_converter.target2yiri(copy.deepcopy(event), bot_account_id)
|
||||
|
||||
@@ -538,7 +481,7 @@ class WeChatPadEventConverter(adapter.EventConverter):
|
||||
|
||||
if '@chatroom' in event['from_user_name']['str']:
|
||||
# 找出开头的 wxid_ 字符串,以:结尾
|
||||
sender_wxid = event['content']['str'].split(":")[0]
|
||||
sender_wxid = event['content']['str'].split(':')[0]
|
||||
|
||||
return platform_events.GroupMessage(
|
||||
sender=platform_entities.GroupMember(
|
||||
@@ -550,13 +493,13 @@ class WeChatPadEventConverter(adapter.EventConverter):
|
||||
name=event['from_user_name']['str'],
|
||||
permission=platform_entities.Permission.Member,
|
||||
),
|
||||
special_title="",
|
||||
special_title='',
|
||||
join_timestamp=0,
|
||||
last_speak_timestamp=0,
|
||||
mute_time_remaining=0,
|
||||
),
|
||||
message_chain=message_chain,
|
||||
time=event["create_time"],
|
||||
time=event['create_time'],
|
||||
source_platform_object=event,
|
||||
)
|
||||
else:
|
||||
@@ -567,13 +510,13 @@ class WeChatPadEventConverter(adapter.EventConverter):
|
||||
remark='',
|
||||
),
|
||||
message_chain=message_chain,
|
||||
time=event["create_time"],
|
||||
time=event['create_time'],
|
||||
source_platform_object=event,
|
||||
)
|
||||
|
||||
|
||||
class WeChatPadAdapter(adapter.MessagePlatformAdapter):
|
||||
name: str = "WeChatPad" # 定义适配器名称
|
||||
name: str = 'WeChatPad' # 定义适配器名称
|
||||
|
||||
bot: WeChatPadClient
|
||||
quart_app: quart.Quart
|
||||
@@ -606,27 +549,21 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
|
||||
# self.ap.logger.debug(f"Gewechat callback event: {data}")
|
||||
# print(data)
|
||||
|
||||
|
||||
try:
|
||||
event = await self.event_converter.target2yiri(data.copy(), self.bot_account_id)
|
||||
except Exception as e:
|
||||
await self.logger.error(f"Error in wechatpad callback: {traceback.format_exc()}")
|
||||
except Exception:
|
||||
await self.logger.error(f'Error in wechatpad callback: {traceback.format_exc()}')
|
||||
|
||||
if event.__class__ in self.listeners:
|
||||
await self.listeners[event.__class__](event, self)
|
||||
|
||||
return 'ok'
|
||||
|
||||
|
||||
async def _handle_message(
|
||||
self,
|
||||
message: platform_message.MessageChain,
|
||||
target_id: str
|
||||
):
|
||||
async def _handle_message(self, message: platform_message.MessageChain, target_id: str):
|
||||
"""统一消息处理核心逻辑"""
|
||||
content_list = await self.message_converter.yiri2target(message)
|
||||
# print(content_list)
|
||||
at_targets = [item["target"] for item in content_list if item["type"] == "at"]
|
||||
at_targets = [item['target'] for item in content_list if item['type'] == 'at']
|
||||
# print(at_targets)
|
||||
# 处理@逻辑
|
||||
at_targets = at_targets or []
|
||||
@@ -634,7 +571,7 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
|
||||
if at_targets:
|
||||
member_info = self.bot.get_chatroom_member_detail(
|
||||
target_id,
|
||||
)["Data"]["member_data"]["chatroom_member_list"]
|
||||
)['Data']['member_data']['chatroom_member_list']
|
||||
|
||||
# 处理消息组件
|
||||
for msg in content_list:
|
||||
@@ -642,63 +579,51 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
|
||||
if msg['type'] == 'text' and at_targets:
|
||||
at_nick_name_list = []
|
||||
for member in member_info:
|
||||
if member["user_name"] in at_targets:
|
||||
if member['user_name'] in at_targets:
|
||||
at_nick_name_list.append(f'@{member["nick_name"]}')
|
||||
msg['content'] = f'{" ".join(at_nick_name_list)} {msg["content"]}'
|
||||
|
||||
# 统一消息派发
|
||||
handler_map = {
|
||||
'text': lambda msg: self.bot.send_text_message(
|
||||
to_wxid=target_id,
|
||||
message=msg['content'],
|
||||
ats=at_targets
|
||||
to_wxid=target_id, message=msg['content'], ats=at_targets
|
||||
),
|
||||
'image': lambda msg: self.bot.send_image_message(
|
||||
to_wxid=target_id,
|
||||
img_url=msg["image"],
|
||||
ats = at_targets
|
||||
to_wxid=target_id, img_url=msg['image'], ats=at_targets
|
||||
),
|
||||
'WeChatEmoji': lambda msg: self.bot.send_emoji_message(
|
||||
to_wxid=target_id,
|
||||
emoji_md5=msg['emoji_md5'],
|
||||
emoji_size=msg['emoji_size']
|
||||
to_wxid=target_id, emoji_md5=msg['emoji_md5'], emoji_size=msg['emoji_size']
|
||||
),
|
||||
|
||||
'voice': lambda msg: self.bot.send_voice_message(
|
||||
to_wxid=target_id,
|
||||
voice_data=msg['data'],
|
||||
voice_duration=msg["duration"],
|
||||
voice_forma=msg["forma"],
|
||||
voice_duration=msg['duration'],
|
||||
voice_forma=msg['forma'],
|
||||
),
|
||||
'WeChatAppMsg': lambda msg: self.bot.send_app_message(
|
||||
to_wxid=target_id,
|
||||
app_message=msg['app_msg'],
|
||||
type=0,
|
||||
),
|
||||
'at': lambda msg: None
|
||||
'at': lambda msg: None,
|
||||
}
|
||||
|
||||
if handler := handler_map.get(msg['type']):
|
||||
handler(msg)
|
||||
# self.ap.logger.warning(f"未处理的消息类型: {ret}")
|
||||
else:
|
||||
self.ap.logger.warning(f"未处理的消息类型: {msg['type']}")
|
||||
self.ap.logger.warning(f'未处理的消息类型: {msg["type"]}')
|
||||
continue
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
target_type: str,
|
||||
target_id: str,
|
||||
message: platform_message.MessageChain
|
||||
):
|
||||
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
|
||||
"""主动发送消息"""
|
||||
return await self._handle_message(message, target_id)
|
||||
|
||||
async def reply_message(
|
||||
self,
|
||||
message_source: platform_events.MessageEvent,
|
||||
message: platform_message.MessageChain,
|
||||
quote_origin: bool = False
|
||||
self,
|
||||
message_source: platform_events.MessageEvent,
|
||||
message: platform_message.MessageChain,
|
||||
quote_origin: bool = False,
|
||||
):
|
||||
"""回复消息"""
|
||||
if message_source.source_platform_object:
|
||||
@@ -709,58 +634,49 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
|
||||
pass
|
||||
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None]
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
||||
):
|
||||
self.listeners[event_type] = callback
|
||||
|
||||
def unregister_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None]
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
||||
):
|
||||
pass
|
||||
|
||||
async def run_async(self):
|
||||
|
||||
if not self.config["admin_key"] and not self.config["token"]:
|
||||
raise RuntimeError("无wechatpad管理密匙,请填入配置文件后重启")
|
||||
if not self.config['admin_key'] and not self.config['token']:
|
||||
raise RuntimeError('无wechatpad管理密匙,请填入配置文件后重启')
|
||||
else:
|
||||
if self.config["token"]:
|
||||
self.bot = WeChatPadClient(
|
||||
self.config['wechatpad_url'],
|
||||
self.config["token"]
|
||||
)
|
||||
if self.config['token']:
|
||||
self.bot = WeChatPadClient(self.config['wechatpad_url'], self.config['token'])
|
||||
data = self.bot.get_login_status()
|
||||
self.ap.logger.info(data)
|
||||
if data["Code"] == 300 and data["Text"] == "你已退出微信":
|
||||
if data['Code'] == 300 and data['Text'] == '你已退出微信':
|
||||
response = requests.post(
|
||||
f"{self.config['wechatpad_url']}/admin/GenAuthKey1?key={self.config['admin_key']}",
|
||||
json={"Count": 1, "Days": 365}
|
||||
f'{self.config["wechatpad_url"]}/admin/GenAuthKey1?key={self.config["admin_key"]}',
|
||||
json={'Count': 1, 'Days': 365},
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"获取token失败: {response.text}")
|
||||
self.config["token"] = response.json()["Data"][0]
|
||||
raise Exception(f'获取token失败: {response.text}')
|
||||
self.config['token'] = response.json()['Data'][0]
|
||||
|
||||
elif not self.config["token"]:
|
||||
elif not self.config['token']:
|
||||
response = requests.post(
|
||||
f"{self.config['wechatpad_url']}/admin/GenAuthKey1?key={self.config['admin_key']}",
|
||||
json={"Count": 1, "Days": 365}
|
||||
f'{self.config["wechatpad_url"]}/admin/GenAuthKey1?key={self.config["admin_key"]}',
|
||||
json={'Count': 1, 'Days': 365},
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"获取token失败: {response.text}")
|
||||
self.config["token"] = response.json()["Data"][0]
|
||||
raise Exception(f'获取token失败: {response.text}')
|
||||
self.config['token'] = response.json()['Data'][0]
|
||||
|
||||
self.bot = WeChatPadClient(
|
||||
self.config['wechatpad_url'],
|
||||
self.config["token"],
|
||||
logger=self.logger
|
||||
)
|
||||
self.ap.logger.info(self.config["token"])
|
||||
self.bot = WeChatPadClient(self.config['wechatpad_url'], self.config['token'], logger=self.logger)
|
||||
self.ap.logger.info(self.config['token'])
|
||||
thread_1 = threading.Event()
|
||||
|
||||
|
||||
def wechat_login_process():
|
||||
# 不登录,这些先注释掉,避免登陆态尝试拉qrcode。
|
||||
# login_data =self.bot.get_login_qr()
|
||||
@@ -768,67 +684,54 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
|
||||
# url = login_data['Data']["QrCodeUrl"]
|
||||
# self.ap.logger.info(login_data)
|
||||
|
||||
|
||||
profile =self.bot.get_profile()
|
||||
profile = self.bot.get_profile()
|
||||
self.ap.logger.info(profile)
|
||||
|
||||
self.bot_account_id = profile["Data"]["userInfo"]["nickName"]["str"]
|
||||
self.config["wxid"] = profile["Data"]["userInfo"]["userName"]["str"]
|
||||
self.bot_account_id = profile['Data']['userInfo']['nickName']['str']
|
||||
self.config['wxid'] = profile['Data']['userInfo']['userName']['str']
|
||||
thread_1.set()
|
||||
|
||||
|
||||
# asyncio.create_task(wechat_login_process)
|
||||
threading.Thread(target=wechat_login_process).start()
|
||||
|
||||
def connect_websocket_sync() -> None:
|
||||
|
||||
thread_1.wait()
|
||||
uri = f"{self.config['wechatpad_ws']}/GetSyncMsg?key={self.config['token']}"
|
||||
self.ap.logger.info(f"Connecting to WebSocket: {uri}")
|
||||
uri = f'{self.config["wechatpad_ws"]}/GetSyncMsg?key={self.config["token"]}'
|
||||
self.ap.logger.info(f'Connecting to WebSocket: {uri}')
|
||||
|
||||
def on_message(ws, message):
|
||||
try:
|
||||
data = json.loads(message)
|
||||
self.ap.logger.debug(f"Received message: {data}")
|
||||
self.ap.logger.debug(f'Received message: {data}')
|
||||
# 这里需要确保ws_message是同步的,或者使用asyncio.run调用异步方法
|
||||
asyncio.run(self.ws_message(data))
|
||||
except json.JSONDecodeError:
|
||||
self.ap.logger.error(f"Non-JSON message: {message[:100]}...")
|
||||
self.ap.logger.error(f'Non-JSON message: {message[:100]}...')
|
||||
|
||||
def on_error(ws, error):
|
||||
self.ap.logger.error(f"WebSocket error: {str(error)[:200]}")
|
||||
self.ap.logger.error(f'WebSocket error: {str(error)[:200]}')
|
||||
|
||||
def on_close(ws, close_status_code, close_msg):
|
||||
self.ap.logger.info("WebSocket closed, reconnecting...")
|
||||
self.ap.logger.info('WebSocket closed, reconnecting...')
|
||||
time.sleep(5)
|
||||
connect_websocket_sync() # 自动重连
|
||||
|
||||
def on_open(ws):
|
||||
self.ap.logger.info("WebSocket connected successfully!")
|
||||
self.ap.logger.info('WebSocket connected successfully!')
|
||||
|
||||
ws = websocket.WebSocketApp(
|
||||
uri,
|
||||
on_message=on_message,
|
||||
on_error=on_error,
|
||||
on_close=on_close,
|
||||
on_open=on_open
|
||||
)
|
||||
ws.run_forever(
|
||||
ping_interval=60,
|
||||
ping_timeout=20
|
||||
uri, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open
|
||||
)
|
||||
ws.run_forever(ping_interval=60, ping_timeout=20)
|
||||
|
||||
# 直接调用同步版本(会阻塞)
|
||||
# connect_websocket_sync()
|
||||
|
||||
# 这行代码会在WebSocket连接断开后才会执行
|
||||
# self.ap.logger.info("WebSocket client thread started")
|
||||
thread = threading.Thread(
|
||||
target=connect_websocket_sync,
|
||||
name="WebSocketClientThread",
|
||||
daemon=True
|
||||
)
|
||||
thread = threading.Thread(target=connect_websocket_sync, name='WebSocketClientThread', daemon=True)
|
||||
thread.start()
|
||||
self.ap.logger.info("WebSocket client thread started")
|
||||
self.ap.logger.info('WebSocket client thread started')
|
||||
|
||||
async def kill(self) -> bool:
|
||||
pass
|
||||
|
||||
@@ -157,7 +157,7 @@ class WecomAdapter(adapter.MessagePlatformAdapter):
|
||||
token=config['token'],
|
||||
EncodingAESKey=config['EncodingAESKey'],
|
||||
contacts_secret=config['contacts_secret'],
|
||||
logger=self.logger
|
||||
logger=self.logger,
|
||||
)
|
||||
|
||||
async def reply_message(
|
||||
@@ -201,8 +201,8 @@ class WecomAdapter(adapter.MessagePlatformAdapter):
|
||||
self.bot_account_id = event.receiver_id
|
||||
try:
|
||||
return await callback(await self.event_converter.target2yiri(event), self)
|
||||
except Exception as e:
|
||||
await self.logger.error(f"Error in wecom callback: {traceback.format_exc()}")
|
||||
except Exception:
|
||||
await self.logger.error(f'Error in wecom callback: {traceback.format_exc()}')
|
||||
|
||||
if event_type == platform_events.FriendMessage:
|
||||
self.bot.on_message('text')(on_message)
|
||||
|
||||
@@ -145,7 +145,7 @@ class WecomCSAdapter(adapter.MessagePlatformAdapter):
|
||||
secret=config['secret'],
|
||||
token=config['token'],
|
||||
EncodingAESKey=config['EncodingAESKey'],
|
||||
logger=self.logger
|
||||
logger=self.logger,
|
||||
)
|
||||
|
||||
async def reply_message(
|
||||
@@ -178,8 +178,8 @@ class WecomCSAdapter(adapter.MessagePlatformAdapter):
|
||||
self.bot_account_id = event.receiver_id
|
||||
try:
|
||||
return await callback(await self.event_converter.target2yiri(event), self)
|
||||
except Exception as e:
|
||||
await self.logger.error(f"Error in wecomcs callback: {traceback.format_exc()}")
|
||||
except Exception:
|
||||
await self.logger.error(f'Error in wecomcs callback: {traceback.format_exc()}')
|
||||
|
||||
if event_type == platform_events.FriendMessage:
|
||||
self.bot.on_message('text')(on_message)
|
||||
|
||||
@@ -17,7 +17,7 @@ class LLMModelInfo(pydantic.BaseModel):
|
||||
|
||||
token_mgr: token.TokenManager
|
||||
|
||||
requester: requester.LLMAPIRequester
|
||||
requester: requester.ProviderAPIRequester
|
||||
|
||||
tool_call_supported: typing.Optional[bool] = False
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ class ModelManager:
|
||||
|
||||
model_list: list[entities.LLMModelInfo] # deprecated
|
||||
|
||||
requesters: dict[str, requester.LLMAPIRequester] # deprecated
|
||||
requesters: dict[str, requester.ProviderAPIRequester] # deprecated
|
||||
|
||||
token_mgrs: dict[str, token.TokenManager] # deprecated
|
||||
|
||||
@@ -28,9 +28,11 @@ class ModelManager:
|
||||
|
||||
llm_models: list[requester.RuntimeLLMModel]
|
||||
|
||||
embedding_models: list[requester.RuntimeEmbeddingModel]
|
||||
|
||||
requester_components: list[engine.Component]
|
||||
|
||||
requester_dict: dict[str, type[requester.LLMAPIRequester]] # cache
|
||||
requester_dict: dict[str, type[requester.ProviderAPIRequester]] # cache
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
@@ -38,6 +40,7 @@ class ModelManager:
|
||||
self.requesters = {}
|
||||
self.token_mgrs = {}
|
||||
self.llm_models = []
|
||||
self.embedding_models = []
|
||||
self.requester_components = []
|
||||
self.requester_dict = {}
|
||||
|
||||
@@ -45,7 +48,7 @@ class ModelManager:
|
||||
self.requester_components = self.ap.discover.get_components_by_kind('LLMAPIRequester')
|
||||
|
||||
# forge requester class dict
|
||||
requester_dict: dict[str, type[requester.LLMAPIRequester]] = {}
|
||||
requester_dict: dict[str, type[requester.ProviderAPIRequester]] = {}
|
||||
for component in self.requester_components:
|
||||
requester_dict[component.metadata.name] = component.get_python_component_class()
|
||||
|
||||
@@ -58,13 +61,11 @@ class ModelManager:
|
||||
self.ap.logger.info('Loading models from db...')
|
||||
|
||||
self.llm_models = []
|
||||
self.embedding_models = []
|
||||
|
||||
# llm models
|
||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.LLMModel))
|
||||
|
||||
llm_models = result.all()
|
||||
|
||||
# load models
|
||||
for llm_model in llm_models:
|
||||
try:
|
||||
await self.load_llm_model(llm_model)
|
||||
@@ -73,11 +74,17 @@ class ModelManager:
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'Failed to load model {llm_model.uuid}: {e}\n{traceback.format_exc()}')
|
||||
|
||||
# embedding models
|
||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.EmbeddingModel))
|
||||
embedding_models = result.all()
|
||||
for embedding_model in embedding_models:
|
||||
await self.load_embedding_model(embedding_model)
|
||||
|
||||
async def init_runtime_llm_model(
|
||||
self,
|
||||
model_info: persistence_model.LLMModel | sqlalchemy.Row[persistence_model.LLMModel] | dict,
|
||||
):
|
||||
"""初始化运行时模型"""
|
||||
"""初始化运行时 LLM 模型"""
|
||||
if isinstance(model_info, sqlalchemy.Row):
|
||||
model_info = persistence_model.LLMModel(**model_info._mapping)
|
||||
elif isinstance(model_info, dict):
|
||||
@@ -101,14 +108,47 @@ class ModelManager:
|
||||
|
||||
return runtime_llm_model
|
||||
|
||||
async def init_runtime_embedding_model(
|
||||
self,
|
||||
model_info: persistence_model.EmbeddingModel | sqlalchemy.Row[persistence_model.EmbeddingModel] | dict,
|
||||
):
|
||||
"""初始化运行时 Embedding 模型"""
|
||||
if isinstance(model_info, sqlalchemy.Row):
|
||||
model_info = persistence_model.EmbeddingModel(**model_info._mapping)
|
||||
elif isinstance(model_info, dict):
|
||||
model_info = persistence_model.EmbeddingModel(**model_info)
|
||||
|
||||
requester_inst = self.requester_dict[model_info.requester](ap=self.ap, config=model_info.requester_config)
|
||||
|
||||
await requester_inst.initialize()
|
||||
|
||||
runtime_embedding_model = requester.RuntimeEmbeddingModel(
|
||||
model_entity=model_info,
|
||||
token_mgr=token.TokenManager(
|
||||
name=model_info.uuid,
|
||||
tokens=model_info.api_keys,
|
||||
),
|
||||
requester=requester_inst,
|
||||
)
|
||||
|
||||
return runtime_embedding_model
|
||||
|
||||
async def load_llm_model(
|
||||
self,
|
||||
model_info: persistence_model.LLMModel | sqlalchemy.Row[persistence_model.LLMModel] | dict,
|
||||
):
|
||||
"""加载模型"""
|
||||
"""加载 LLM 模型"""
|
||||
runtime_llm_model = await self.init_runtime_llm_model(model_info)
|
||||
self.llm_models.append(runtime_llm_model)
|
||||
|
||||
async def load_embedding_model(
|
||||
self,
|
||||
model_info: persistence_model.EmbeddingModel | sqlalchemy.Row[persistence_model.EmbeddingModel] | dict,
|
||||
):
|
||||
"""加载 Embedding 模型"""
|
||||
runtime_embedding_model = await self.init_runtime_embedding_model(model_info)
|
||||
self.embedding_models.append(runtime_embedding_model)
|
||||
|
||||
async def get_model_by_name(self, name: str) -> entities.LLMModelInfo: # deprecated
|
||||
"""通过名称获取模型"""
|
||||
for model in self.model_list:
|
||||
@@ -116,23 +156,44 @@ class ModelManager:
return model
raise ValueError(f'无法确定模型 {name} 的信息')

async def get_model_by_uuid(self, uuid: str) -> entities.LLMModelInfo:
"""通过uuid获取模型"""
async def get_model_by_uuid(self, uuid: str) -> requester.RuntimeLLMModel:
"""通过uuid获取 LLM 模型"""
for model in self.llm_models:
if model.model_entity.uuid == uuid:
return model
raise ValueError(f'model {uuid} not found')
raise ValueError(f'LLM model {uuid} not found')

async def get_embedding_model_by_uuid(self, uuid: str) -> requester.RuntimeEmbeddingModel:
"""通过uuid获取 Embedding 模型"""
for model in self.embedding_models:
if model.model_entity.uuid == uuid:
return model
raise ValueError(f'Embedding model {uuid} not found')

async def remove_llm_model(self, model_uuid: str):
"""移除模型"""
"""移除 LLM 模型"""
for model in self.llm_models:
if model.model_entity.uuid == model_uuid:
self.llm_models.remove(model)
return

def get_available_requesters_info(self) -> list[dict]:
async def remove_embedding_model(self, model_uuid: str):
"""移除 Embedding 模型"""
for model in self.embedding_models:
if model.model_entity.uuid == model_uuid:
self.embedding_models.remove(model)
return

def get_available_requesters_info(self, model_type: str) -> list[dict]:
"""获取所有可用的请求器"""
return [component.to_plain_dict() for component in self.requester_components]
if model_type != '':
return [
component.to_plain_dict()
for component in self.requester_components
if model_type in component.spec['support_type']
]
else:
return [component.to_plain_dict() for component in self.requester_components]

def get_available_requester_info_by_name(self, name: str) -> dict | None:
"""通过名称获取请求器信息"""

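A minimal usage sketch of the widened ModelManager API above. It assumes the manager instance is reachable as `ap.model_mgr` and that a model UUID is already known; both are assumptions, since neither appears in this hunk.

# Hedged sketch: the new model_type filter and the embedding-model accessors.
async def pick_embedding_requester(ap, model_uuid: str):
    # Only components whose spec lists the type in `support_type` are returned;
    # per the YAML specs later in this diff, the OpenAI-compatible `chatcmpl`
    # requester is the one declaring both 'llm' and 'text-embedding'.
    embedding_requesters = ap.model_mgr.get_available_requesters_info('text-embedding')
    # An empty string skips the filter and returns every requester component.
    all_requesters = ap.model_mgr.get_available_requesters_info('')

    # The new accessor mirrors the existing LLM one.
    runtime_model = await ap.model_mgr.get_embedding_model_by_uuid(model_uuid)
    return embedding_requesters, all_requesters, runtime_model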
@@ -20,22 +20,45 @@ class RuntimeLLMModel:
token_mgr: token.TokenManager
"""api key管理器"""

requester: LLMAPIRequester
requester: ProviderAPIRequester
"""请求器实例"""

def __init__(
self,
model_entity: persistence_model.LLMModel,
token_mgr: token.TokenManager,
requester: LLMAPIRequester,
requester: ProviderAPIRequester,
):
self.model_entity = model_entity
self.token_mgr = token_mgr
self.requester = requester


class LLMAPIRequester(metaclass=abc.ABCMeta):
"""LLM API请求器"""
class RuntimeEmbeddingModel:
"""运行时 Embedding 模型"""

model_entity: persistence_model.EmbeddingModel
"""模型数据"""

token_mgr: token.TokenManager
"""api key管理器"""

requester: ProviderAPIRequester
"""请求器实例"""

def __init__(
self,
model_entity: persistence_model.EmbeddingModel,
token_mgr: token.TokenManager,
requester: ProviderAPIRequester,
):
self.model_entity = model_entity
self.token_mgr = token_mgr
self.requester = requester


class ProviderAPIRequester(metaclass=abc.ABCMeta):
"""Provider API请求器"""

name: str = None

@@ -74,3 +97,22 @@ class LLMAPIRequester(metaclass=abc.ABCMeta):
llm_entities.Message: 返回消息对象
"""
pass

async def invoke_embedding(
self,
model: RuntimeEmbeddingModel,
input_text: list[str],
extra_args: dict[str, typing.Any] = {},
) -> list[list[float]]:
"""调用 Embedding API

Args:
query (core_entities.Query): 请求上下文
model (RuntimeEmbeddingModel): 使用的模型信息
input_text (list[str]): 输入文本
extra_args (dict[str, typing.Any], optional): 额外的参数. Defaults to {}.

Returns:
list[list[float]]: 返回的 embedding 向量
"""
pass

@@ -22,6 +22,8 @@ spec:
|
||||
type: integer
|
||||
required: true
|
||||
default: 120
|
||||
support_type:
|
||||
- llm
|
||||
execution:
|
||||
python:
|
||||
path: ./302aichatcmpl.py
|
||||
|
||||
@@ -15,7 +15,7 @@ from ...tools import entities as tools_entities
|
||||
from ....utils import image
|
||||
|
||||
|
||||
class AnthropicMessages(requester.LLMAPIRequester):
|
||||
class AnthropicMessages(requester.ProviderAPIRequester):
|
||||
"""Anthropic Messages API 请求器"""
|
||||
|
||||
client: anthropic.AsyncAnthropic
|
||||
|
||||
@@ -22,6 +22,8 @@ spec:
|
||||
type: integer
|
||||
required: true
|
||||
default: 120
|
||||
support_type:
|
||||
- llm
|
||||
execution:
|
||||
python:
|
||||
path: ./anthropicmsgs.py
|
||||
|
||||
@@ -22,6 +22,8 @@ spec:
|
||||
type: integer
|
||||
required: true
|
||||
default: 120
|
||||
support_type:
|
||||
- llm
|
||||
execution:
|
||||
python:
|
||||
path: ./bailianchatcmpl.py
|
||||
|
||||
@@ -13,7 +13,7 @@ from ... import entities as llm_entities
|
||||
from ...tools import entities as tools_entities
|
||||
|
||||
|
||||
class OpenAIChatCompletions(requester.LLMAPIRequester):
|
||||
class OpenAIChatCompletions(requester.ProviderAPIRequester):
|
||||
"""OpenAI ChatCompletion API 请求器"""
|
||||
|
||||
client: openai.AsyncClient
|
||||
@@ -141,3 +141,39 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}')
except openai.APIError as e:
raise errors.RequesterError(f'请求错误: {e.message}')

async def invoke_embedding(
self,
model: requester.RuntimeEmbeddingModel,
input_text: list[str],
extra_args: dict[str, typing.Any] = {},
) -> list[list[float]]:
"""调用 Embedding API"""
self.client.api_key = model.token_mgr.get_token()

args = {
'model': model.model_entity.name,
'input': input_text,
}

if model.model_entity.extra_args:
args.update(model.model_entity.extra_args)

args.update(extra_args)

try:
resp = await self.client.embeddings.create(**args)

return [d.embedding for d in resp.data]
except asyncio.TimeoutError:
raise errors.RequesterError('请求超时')
except openai.BadRequestError as e:
raise errors.RequesterError(f'请求参数错误: {e.message}')
except openai.AuthenticationError as e:
raise errors.RequesterError(f'无效的 api-key: {e.message}')
except openai.NotFoundError as e:
raise errors.RequesterError(f'请求路径错误: {e.message}')
except openai.RateLimitError as e:
raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}')
except openai.APIError as e:
raise errors.RequesterError(f'请求错误: {e.message}')

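A minimal call-site sketch for the embedding entry point added above, assuming a RuntimeEmbeddingModel has already been loaded (for example via ModelManager.get_embedding_model_by_uuid); the function name and texts are placeholders, only the `invoke_embedding` signature comes from the diff.

# Hedged sketch: invoking the embedding API through a loaded runtime model.
async def embed_texts(runtime_model, texts: list[str]) -> list[list[float]]:
    # invoke_embedding takes an api key from the model's TokenManager and
    # returns the embedding vectors for the given input strings.
    return await runtime_model.requester.invoke_embedding(
        model=runtime_model,
        input_text=texts,
    )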
@@ -22,6 +22,9 @@ spec:
|
||||
type: integer
|
||||
required: true
|
||||
default: 120
|
||||
support_type:
|
||||
- llm
|
||||
- text-embedding
|
||||
execution:
|
||||
python:
|
||||
path: ./chatcmpl.py
|
||||
|
||||
@@ -22,6 +22,8 @@ spec:
|
||||
type: integer
|
||||
required: true
|
||||
default: 120
|
||||
support_type:
|
||||
- llm
|
||||
execution:
|
||||
python:
|
||||
path: ./compsharechatcmpl.py
|
||||
|
||||
@@ -22,6 +22,8 @@ spec:
|
||||
type: integer
|
||||
required: true
|
||||
default: 120
|
||||
support_type:
|
||||
- llm
|
||||
execution:
|
||||
python:
|
||||
path: ./deepseekchatcmpl.py
|
||||
|
||||
@@ -22,6 +22,8 @@ spec:
|
||||
type: integer
|
||||
required: true
|
||||
default: 120
|
||||
support_type:
|
||||
- llm
|
||||
execution:
|
||||
python:
|
||||
path: ./geminichatcmpl.py
|
||||
|
||||
@@ -22,6 +22,8 @@ spec:
|
||||
type: integer
|
||||
required: true
|
||||
default: 120
|
||||
support_type:
|
||||
- llm
|
||||
execution:
|
||||
python:
|
||||
path: ./giteeaichatcmpl.py
|
||||
|
||||
@@ -22,6 +22,8 @@ spec:
|
||||
type: integer
|
||||
required: true
|
||||
default: 120
|
||||
support_type:
|
||||
- llm
|
||||
execution:
|
||||
python:
|
||||
path: ./lmstudiochatcmpl.py
|
||||
|
||||
@@ -14,7 +14,7 @@ from ... import entities as llm_entities
|
||||
from ...tools import entities as tools_entities
|
||||
|
||||
|
||||
class ModelScopeChatCompletions(requester.LLMAPIRequester):
|
||||
class ModelScopeChatCompletions(requester.ProviderAPIRequester):
|
||||
"""ModelScope ChatCompletion API 请求器"""
|
||||
|
||||
client: openai.AsyncClient
|
||||
|
||||
@@ -29,6 +29,8 @@ spec:
|
||||
type: int
|
||||
required: true
|
||||
default: 120
|
||||
support_type:
|
||||
- llm
|
||||
execution:
|
||||
python:
|
||||
path: ./modelscopechatcmpl.py
|
||||
|
||||
@@ -22,6 +22,8 @@ spec:
|
||||
type: integer
|
||||
required: true
|
||||
default: 120
|
||||
support_type:
|
||||
- llm
|
||||
execution:
|
||||
python:
|
||||
path: ./moonshotchatcmpl.py
|
||||
|
||||
@@ -17,7 +17,7 @@ from ....core import entities as core_entities
|
||||
REQUESTER_NAME: str = 'ollama-chat'
|
||||
|
||||
|
||||
class OllamaChatCompletions(requester.LLMAPIRequester):
|
||||
class OllamaChatCompletions(requester.ProviderAPIRequester):
|
||||
"""Ollama平台 ChatCompletion API请求器"""
|
||||
|
||||
client: ollama.AsyncClient
|
||||
|
||||
@@ -22,6 +22,8 @@ spec:
|
||||
type: integer
|
||||
required: true
|
||||
default: 120
|
||||
support_type:
|
||||
- llm
|
||||
execution:
|
||||
python:
|
||||
path: ./ollamachat.py
|
||||
|
||||
@@ -22,6 +22,8 @@ spec:
|
||||
type: integer
|
||||
required: true
|
||||
default: 120
|
||||
support_type:
|
||||
- llm
|
||||
execution:
|
||||
python:
|
||||
path: ./openrouterchatcmpl.py
|
||||
|
||||
@@ -29,6 +29,8 @@ spec:
|
||||
type: int
|
||||
required: true
|
||||
default: 120
|
||||
support_type:
|
||||
- llm
|
||||
execution:
|
||||
python:
|
||||
path: ./ppiochatcmpl.py
|
||||
|
||||
@@ -22,6 +22,8 @@ spec:
|
||||
type: integer
|
||||
required: true
|
||||
default: 120
|
||||
support_type:
|
||||
- llm
|
||||
execution:
|
||||
python:
|
||||
path: ./siliconflowchatcmpl.py
|
||||
|
||||
@@ -22,6 +22,8 @@ spec:
|
||||
type: integer
|
||||
required: true
|
||||
default: 120
|
||||
support_type:
|
||||
- llm
|
||||
execution:
|
||||
python:
|
||||
path: ./volcarkchatcmpl.py
|
||||
|
||||
@@ -22,6 +22,8 @@ spec:
|
||||
type: integer
|
||||
required: true
|
||||
default: 120
|
||||
support_type:
|
||||
- llm
|
||||
execution:
|
||||
python:
|
||||
path: ./xaichatcmpl.py
|
||||
|
||||
@@ -22,6 +22,8 @@ spec:
|
||||
type: integer
|
||||
required: true
|
||||
default: 120
|
||||
support_type:
|
||||
- llm
|
||||
execution:
|
||||
python:
|
||||
path: ./zhipuaichatcmpl.py
|
||||
|
||||
@@ -1,13 +1,28 @@
from __future__ import annotations

import json
import copy
import typing

from .. import runner
from ...core import entities as core_entities
from .. import entities as llm_entities


rag_combined_prompt_template = """
The following are relevant context entries retrieved from the knowledge base.
Please use them to answer the user's message.
Respond in the same language as the user's input.

<context>
{rag_context}
</context>

<user_message>
{user_message}
</user_message>
"""


@runner.runner_class('local-agent')
class LocalAgentRunner(runner.RequestRunner):
"""本地Agent请求运行器"""

@@ -16,7 +31,54 @@ class LocalAgentRunner(runner.RequestRunner):
|
||||
"""运行请求"""
|
||||
pending_tool_calls = []
|
||||
|
||||
req_messages = query.prompt.messages.copy() + query.messages.copy() + [query.user_message]
|
||||
kb_uuid = query.pipeline_config['ai']['local-agent']['knowledge-base']
|
||||
|
||||
if kb_uuid == '__none__':
|
||||
kb_uuid = None
|
||||
|
||||
user_message = copy.deepcopy(query.user_message)
|
||||
|
||||
user_message_text = ''
|
||||
|
||||
if isinstance(user_message.content, str):
|
||||
user_message_text = user_message.content
|
||||
elif isinstance(user_message.content, list):
|
||||
for ce in user_message.content:
|
||||
if ce.type == 'text':
|
||||
user_message_text += ce.text
|
||||
break
|
||||
|
||||
if kb_uuid and user_message_text:
|
||||
# only support text for now
|
||||
kb = await self.ap.rag_mgr.get_knowledge_base_by_uuid(kb_uuid)
|
||||
|
||||
if not kb:
|
||||
self.ap.logger.warning(f'Knowledge base {kb_uuid} not found')
|
||||
raise ValueError(f'Knowledge base {kb_uuid} not found')
|
||||
|
||||
result = await kb.retrieve(user_message_text)
|
||||
|
||||
final_user_message_text = ''
|
||||
|
||||
if result:
|
||||
rag_context = '\n\n'.join(
|
||||
f'[{i + 1}] {entry.metadata.get("text", "")}' for i, entry in enumerate(result)
|
||||
)
|
||||
final_user_message_text = rag_combined_prompt_template.format(
|
||||
rag_context=rag_context, user_message=user_message_text
|
||||
)
|
||||
|
||||
else:
|
||||
final_user_message_text = user_message_text
|
||||
|
||||
self.ap.logger.debug(f'Final user message text: {final_user_message_text}')
|
||||
|
||||
for ce in user_message.content:
|
||||
if ce.type == 'text':
|
||||
ce.text = final_user_message_text
|
||||
break
|
||||
|
||||
req_messages = query.prompt.messages.copy() + query.messages.copy() + [user_message]
|
||||
|
||||
# 首次请求
|
||||
msg = await query.use_llm_model.requester.invoke_llm(
|
||||
|
||||
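The hunk above builds the final user message by numbering each retrieved chunk and filling rag_combined_prompt_template. A minimal, self-contained sketch of that assembly step (FakeEntry is an illustrative stand-in for the RetrieveResultEntry objects returned by kb.retrieve(), not a repository class):

# Sketch of the RAG prompt assembly performed by LocalAgentRunner.run().
from dataclasses import dataclass, field

# Template copied from the diff above.
rag_combined_prompt_template = """
The following are relevant context entries retrieved from the knowledge base.
Please use them to answer the user's message.
Respond in the same language as the user's input.

<context>
{rag_context}
</context>

<user_message>
{user_message}
</user_message>
"""


@dataclass
class FakeEntry:
    metadata: dict = field(default_factory=dict)


def combine(entries: list[FakeEntry], user_message_text: str) -> str:
    """Mirror the runner: number each chunk, join with blank lines, fill the template."""
    if not entries:
        return user_message_text
    rag_context = '\n\n'.join(
        f'[{i + 1}] {entry.metadata.get("text", "")}' for i, entry in enumerate(entries)
    )
    return rag_combined_prompt_template.format(rag_context=rag_context, user_message=user_message_text)


if __name__ == '__main__':
    entries = [FakeEntry(metadata={'text': 'LangBot supports multiple IM platforms.'})]
    print(combine(entries, 'Which platforms does LangBot support?'))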
212 pkg/rag/knowledge/kbmgr.py Normal file
@@ -0,0 +1,212 @@
from __future__ import annotations

import traceback
import uuid

from .services import parser, chunker
from pkg.core import app
from pkg.rag.knowledge.services.embedder import Embedder
from pkg.rag.knowledge.services.retriever import Retriever
import sqlalchemy
from ...entity.persistence import rag as persistence_rag
from pkg.core import taskmgr
from ...entity.rag import retriever as retriever_entities


class RuntimeKnowledgeBase:
    ap: app.Application

    knowledge_base_entity: persistence_rag.KnowledgeBase

    parser: parser.FileParser

    chunker: chunker.Chunker

    embedder: Embedder

    retriever: Retriever

    def __init__(self, ap: app.Application, knowledge_base_entity: persistence_rag.KnowledgeBase):
        self.ap = ap
        self.knowledge_base_entity = knowledge_base_entity
        self.parser = parser.FileParser(ap=self.ap)
        self.chunker = chunker.Chunker(ap=self.ap)
        self.embedder = Embedder(ap=self.ap)
        self.retriever = Retriever(ap=self.ap)
        # 传递kb_id给retriever
        self.retriever.kb_id = knowledge_base_entity.uuid

    async def initialize(self):
        pass

    async def _store_file_task(self, file: persistence_rag.File, task_context: taskmgr.TaskContext):
        try:
            # set file status to processing
            await self.ap.persistence_mgr.execute_async(
                sqlalchemy.update(persistence_rag.File)
                .where(persistence_rag.File.uuid == file.uuid)
                .values(status='processing')
            )

            task_context.set_current_action('Parsing file')
            # parse file
            text = await self.parser.parse(file.file_name, file.extension)
            if not text:
                raise Exception(f'No text extracted from file {file.file_name}')

            task_context.set_current_action('Chunking file')
            # chunk file
            chunks_texts = await self.chunker.chunk(text)
            if not chunks_texts:
                raise Exception(f'No chunks extracted from file {file.file_name}')

            task_context.set_current_action('Embedding chunks')

            embedding_model = await self.ap.model_mgr.get_embedding_model_by_uuid(
                self.knowledge_base_entity.embedding_model_uuid
            )
            # embed chunks
            await self.embedder.embed_and_store(
                kb_id=self.knowledge_base_entity.uuid,
                file_id=file.uuid,
                chunks=chunks_texts,
                embedding_model=embedding_model,
            )

            # set file status to completed
            await self.ap.persistence_mgr.execute_async(
                sqlalchemy.update(persistence_rag.File)
                .where(persistence_rag.File.uuid == file.uuid)
                .values(status='completed')
            )

        except Exception as e:
            self.ap.logger.error(f'Error storing file {file.uuid}: {e}')
            traceback.print_exc()
            # set file status to failed
            await self.ap.persistence_mgr.execute_async(
                sqlalchemy.update(persistence_rag.File)
                .where(persistence_rag.File.uuid == file.uuid)
                .values(status='failed')
            )

            raise

    async def store_file(self, file_id: str) -> str:
        # pre checking
        if not await self.ap.storage_mgr.storage_provider.exists(file_id):
            raise Exception(f'File {file_id} not found')

        file_uuid = str(uuid.uuid4())
        kb_id = self.knowledge_base_entity.uuid
        file_name = file_id
        extension = file_name.split('.')[-1]

        file_obj_data = {
            'uuid': file_uuid,
            'kb_id': kb_id,
            'file_name': file_name,
            'extension': extension,
            'status': 'pending',
        }

        file_obj = persistence_rag.File(**file_obj_data)

        await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_rag.File).values(file_obj_data))

        # run background task asynchronously
        ctx = taskmgr.TaskContext.new()
        wrapper = self.ap.task_mgr.create_user_task(
            self._store_file_task(file_obj, task_context=ctx),
            kind='knowledge-operation',
            name=f'knowledge-store-file-{file_id}',
            label=f'Store file {file_id}',
            context=ctx,
        )
        return wrapper.id

    async def retrieve(self, query: str) -> list[retriever_entities.RetrieveResultEntry]:
        embedding_model = await self.ap.model_mgr.get_embedding_model_by_uuid(
            self.knowledge_base_entity.embedding_model_uuid
        )
        return await self.retriever.retrieve(self.knowledge_base_entity.uuid, query, embedding_model)

    async def delete_file(self, file_id: str):
        # delete vector
        await self.ap.vector_db_mgr.vector_db.delete_by_file_id(self.knowledge_base_entity.uuid, file_id)

        # delete chunk
        await self.ap.persistence_mgr.execute_async(
            sqlalchemy.delete(persistence_rag.Chunk).where(persistence_rag.Chunk.file_id == file_id)
        )

        await self.ap.persistence_mgr.execute_async(
            sqlalchemy.delete(persistence_rag.File).where(persistence_rag.File.uuid == file_id)
        )

    async def dispose(self):
        await self.ap.vector_db_mgr.vector_db.delete_collection(self.knowledge_base_entity.uuid)


class RAGManager:
    ap: app.Application

    knowledge_bases: list[RuntimeKnowledgeBase]

    def __init__(self, ap: app.Application):
        self.ap = ap
        self.knowledge_bases = []

    async def initialize(self):
        await self.load_knowledge_bases_from_db()

    async def load_knowledge_bases_from_db(self):
        self.ap.logger.info('Loading knowledge bases from db...')

        self.knowledge_bases = []

        result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_rag.KnowledgeBase))

        knowledge_bases = result.all()

        for knowledge_base in knowledge_bases:
            try:
                await self.load_knowledge_base(knowledge_base)
            except Exception as e:
                self.ap.logger.error(
                    f'Error loading knowledge base {knowledge_base.uuid}: {e}\n{traceback.format_exc()}'
                )

    async def load_knowledge_base(
        self,
        knowledge_base_entity: persistence_rag.KnowledgeBase | sqlalchemy.Row | dict,
    ) -> RuntimeKnowledgeBase:
        if isinstance(knowledge_base_entity, sqlalchemy.Row):
            knowledge_base_entity = persistence_rag.KnowledgeBase(**knowledge_base_entity._mapping)
        elif isinstance(knowledge_base_entity, dict):
            knowledge_base_entity = persistence_rag.KnowledgeBase(**knowledge_base_entity)

        runtime_knowledge_base = RuntimeKnowledgeBase(ap=self.ap, knowledge_base_entity=knowledge_base_entity)

        await runtime_knowledge_base.initialize()

        self.knowledge_bases.append(runtime_knowledge_base)

        return runtime_knowledge_base

    async def get_knowledge_base_by_uuid(self, kb_uuid: str) -> RuntimeKnowledgeBase | None:
        for kb in self.knowledge_bases:
            if kb.knowledge_base_entity.uuid == kb_uuid:
                return kb
        return None

    async def remove_knowledge_base_from_runtime(self, kb_uuid: str):
        for kb in self.knowledge_bases:
            if kb.knowledge_base_entity.uuid == kb_uuid:
                self.knowledge_bases.remove(kb)
                return

    async def delete_knowledge_base(self, kb_uuid: str):
        for kb in self.knowledge_bases:
            if kb.knowledge_base_entity.uuid == kb_uuid:
                await kb.dispose()
                self.knowledge_bases.remove(kb)
                return
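A hedged usage sketch of the manager above, assuming an already-initialized LangBot Application instance `ap` with rag_mgr, storage_mgr and model_mgr wired up; the knowledge-base UUID, file id and question below are placeholders, not values from the repository:

# Illustrative only: `ap` is assumed to be an initialized pkg.core.app.Application.
from pkg.rag.knowledge.kbmgr import RAGManager


async def ingest_and_query(ap, kb_uuid: str, file_id: str, question: str):
    rag_mgr: RAGManager = ap.rag_mgr

    kb = await rag_mgr.get_knowledge_base_by_uuid(kb_uuid)
    if kb is None:
        raise ValueError(f'Knowledge base {kb_uuid} not found')

    # Kicks off the parse -> chunk -> embed -> store pipeline as a background task
    # and returns the task id immediately; the File row moves through
    # pending -> processing -> completed/failed.
    task_id = await kb.store_file(file_id)
    ap.logger.info(f'Ingestion task started: {task_id}')

    # Later, fetch the top chunks for a question.
    entries = await kb.retrieve(question)
    return [entry.metadata.get('text', '') for entry in entries]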
0 pkg/rag/knowledge/services/__init__.py Normal file
15 pkg/rag/knowledge/services/base_service.py Normal file
@@ -0,0 +1,15 @@
# 封装异步操作
import asyncio


class BaseService:
    def __init__(self):
        pass

    async def _run_sync(self, func, *args, **kwargs):
        """
        在单独的线程中运行同步函数。
        如果第一个参数是 session,则在 to_thread 中获取新的 session。
        """

        return await asyncio.to_thread(func, *args, **kwargs)
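BaseService._run_sync is a thin wrapper over asyncio.to_thread that keeps blocking work (file parsing, text splitting, Chroma calls) off the event loop. A standalone sketch of the same pattern:

# Standalone illustration of the _run_sync pattern used throughout the RAG services.
import asyncio
import time


def blocking_parse(path: str) -> str:
    time.sleep(0.5)  # stands in for CPU/IO-bound parsing work
    return f'parsed contents of {path}'


async def main():
    # Offload the blocking call so other coroutines keep running meanwhile.
    text = await asyncio.to_thread(blocking_parse, 'example.txt')
    print(text)


if __name__ == '__main__':
    asyncio.run(main())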
63 pkg/rag/knowledge/services/chunker.py Normal file
@@ -0,0 +1,63 @@
from __future__ import annotations

import json
from typing import List
from pkg.rag.knowledge.services import base_service
from pkg.core import app


class Chunker(base_service.BaseService):
    """
    A class for splitting long texts into smaller, overlapping chunks.
    """

    def __init__(self, ap: app.Application, chunk_size: int = 500, chunk_overlap: int = 50):
        self.ap = ap
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        if self.chunk_overlap >= self.chunk_size:
            self.ap.logger.warning(
                'Chunk overlap is greater than or equal to chunk size. This may lead to empty or malformed chunks.'
            )

    def _split_text_sync(self, text: str) -> List[str]:
        """
        Synchronously splits a long text into chunks with specified overlap.
        This is a CPU-bound operation, intended to be run in a separate thread.
        """
        if not text:
            return []
        # words = text.split()
        # chunks = []
        # current_chunk = []

        # for word in words:
        #     current_chunk.append(word)
        #     if len(current_chunk) > self.chunk_size:
        #         chunks.append(" ".join(current_chunk[:self.chunk_size]))
        #         current_chunk = current_chunk[self.chunk_size - self.chunk_overlap:]

        # if current_chunk:
        #     chunks.append(" ".join(current_chunk))

        # A more robust chunking strategy (e.g., using recursive character text splitter)
        from langchain.text_splitter import RecursiveCharacterTextSplitter

        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
            length_function=len,
            is_separator_regex=False,
        )
        return text_splitter.split_text(text)

    async def chunk(self, text: str) -> List[str]:
        """
        Asynchronously chunks a given text into smaller pieces.
        """
        self.ap.logger.info(f'Chunking text (length: {len(text)})...')
        # Run the synchronous splitting logic in a separate thread
        chunks = await self._run_sync(self._split_text_sync, text)
        self.ap.logger.info(f'Text chunked into {len(chunks)} pieces.')
        self.ap.logger.debug(f'Chunks: {json.dumps(chunks, indent=4, ensure_ascii=False)}')
        return chunks
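The chunker delegates splitting to LangChain's RecursiveCharacterTextSplitter with character-based lengths. A standalone check of that splitter with the same 500/50 defaults, assuming langchain is installed as declared in pyproject.toml:

# Standalone demonstration of the splitter configuration used by Chunker.
from langchain.text_splitter import RecursiveCharacterTextSplitter

splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,      # same default as Chunker.chunk_size
    chunk_overlap=50,    # same default as Chunker.chunk_overlap
    length_function=len,
    is_separator_regex=False,
)

text = 'LangBot knowledge base demo. ' * 100  # roughly 2900 characters
chunks = splitter.split_text(text)

print(f'{len(chunks)} chunks')
print(chunks[0][:80])
# Consecutive chunks share up to 50 characters of overlap, which helps
# retrieval when an answer straddles a chunk boundary.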
47 pkg/rag/knowledge/services/embedder.py Normal file
@@ -0,0 +1,47 @@
from __future__ import annotations

import uuid
from typing import List
from pkg.rag.knowledge.services.base_service import BaseService
from ....entity.persistence import rag as persistence_rag
from ....core import app
from ....provider.modelmgr.requester import RuntimeEmbeddingModel
import sqlalchemy


class Embedder(BaseService):
    def __init__(self, ap: app.Application) -> None:
        super().__init__()
        self.ap = ap

    async def embed_and_store(
        self, kb_id: str, file_id: str, chunks: List[str], embedding_model: RuntimeEmbeddingModel
    ) -> list[persistence_rag.Chunk]:
        # save chunk to db
        chunk_entities: list[persistence_rag.Chunk] = []
        chunk_ids: list[str] = []

        for chunk_text in chunks:
            chunk_uuid = str(uuid.uuid4())
            chunk_ids.append(chunk_uuid)
            chunk_entity = persistence_rag.Chunk(uuid=chunk_uuid, file_id=file_id, text=chunk_text)
            chunk_entities.append(chunk_entity)

        chunk_dicts = [
            self.ap.persistence_mgr.serialize_model(persistence_rag.Chunk, chunk) for chunk in chunk_entities
        ]

        await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_rag.Chunk).values(chunk_dicts))

        # get embeddings
        embeddings_list: list[list[float]] = await embedding_model.requester.invoke_embedding(
            model=embedding_model,
            input_text=chunks,
            extra_args={},  # TODO: add extra args
        )

        # save embeddings to vdb
        await self.ap.vector_db_mgr.vector_db.add_embeddings(kb_id, chunk_ids, embeddings_list, chunk_dicts)

        self.ap.logger.info(f'Successfully saved {len(chunk_entities)} embeddings to Knowledge Base.')

        return chunk_entities
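embed_and_store keeps chunk ids, chunk payload dicts and embedding vectors as three index-aligned lists before handing them to the vector database. A toy sketch of that alignment with a stubbed embedding function (the hash-based vectors are illustrative; real vectors come from the configured embedding model requester):

# Toy illustration of the id/metadata/embedding alignment Embedder.embed_and_store() relies on.
import hashlib
import uuid


def fake_embed(texts: list[str], dim: int = 8) -> list[list[float]]:
    vectors = []
    for text in texts:
        digest = hashlib.sha256(text.encode()).digest()
        vectors.append([b / 255 for b in digest[:dim]])
    return vectors


chunks = ['LangBot supports QQ and Discord.', 'The knowledge base uses ChromaDB.']
chunk_ids = [str(uuid.uuid4()) for _ in chunks]
chunk_dicts = [{'uuid': cid, 'file_id': 'demo-file', 'text': text} for cid, text in zip(chunk_ids, chunks)]
embeddings_list = fake_embed(chunks)

# The three lists stay index-aligned; add_embeddings(kb_id, ids, embeddings, metadatas)
# depends on that ordering to associate each vector with its chunk metadata.
for cid, vec, meta in zip(chunk_ids, embeddings_list, chunk_dicts):
    print(cid[:8], len(vec), meta['text'][:30])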
291 pkg/rag/knowledge/services/parser.py Normal file
@@ -0,0 +1,291 @@
from __future__ import annotations

import PyPDF2
import io
from docx import Document
import chardet
from typing import Union, Callable, Any
import markdown
from bs4 import BeautifulSoup
import re
import asyncio  # Import asyncio for async operations
from pkg.core import app


class FileParser:
    """
    A robust file parser class to extract text content from various document formats.
    It supports TXT, PDF, DOCX, XLSX, CSV, Markdown, HTML, and EPUB files.
    All core file reading operations are designed to be run synchronously in a thread pool
    to avoid blocking the asyncio event loop.
    """

    def __init__(self, ap: app.Application):
        self.ap = ap

    async def _run_sync(self, sync_func: Callable, *args: Any, **kwargs: Any) -> Any:
        """
        Runs a synchronous function in a separate thread to prevent blocking the event loop.
        This is a general utility method for wrapping blocking I/O operations.
        """
        try:
            return await asyncio.to_thread(sync_func, *args, **kwargs)
        except Exception as e:
            self.ap.logger.error(f'Error running synchronous function {sync_func.__name__}: {e}')
            raise

    async def parse(self, file_name: str, extension: str) -> Union[str, None]:
        """
        Parses the file based on its extension and returns the extracted text content.
        This is the main asynchronous entry point for parsing.

        Args:
            file_name (str): The name of the file to be parsed, get from ap.storage_mgr

        Returns:
            Union[str, None]: The extracted text content as a single string, or None if parsing fails.
        """

        file_extension = extension.lower()
        parser_method = getattr(self, f'_parse_{file_extension}', None)

        if parser_method is None:
            self.ap.logger.error(f'Unsupported file format: {file_extension} for file {file_name}')
            return None

        try:
            # Pass file_path to the specific parser methods
            return await parser_method(file_name)
        except Exception as e:
            self.ap.logger.error(f'Failed to parse {file_extension} file {file_name}: {e}')
            return None

    # --- Helper for reading files with encoding detection ---
    async def _read_file_content(self, file_name: str) -> Union[str, bytes]:
        """
        Reads a file with automatic encoding detection, ensuring the synchronous
        file read operation runs in a separate thread.
        """

        # def _read_sync():
        #     with open(file_path, 'rb') as file:
        #         raw_data = file.read()
        #     detected = chardet.detect(raw_data)
        #     encoding = detected['encoding'] or 'utf-8'

        #     if mode == 'r':
        #         return raw_data.decode(encoding, errors='ignore')
        #     return raw_data  # For binary mode

        # return await self._run_sync(_read_sync)
        file_bytes = await self.ap.storage_mgr.storage_provider.load(file_name)

        detected = chardet.detect(file_bytes)
        encoding = detected['encoding'] or 'utf-8'

        return file_bytes.decode(encoding, errors='ignore')

    # --- Specific Parser Methods ---

    async def _parse_txt(self, file_name: str) -> str:
        """Parses a TXT file and returns its content."""
        self.ap.logger.info(f'Parsing TXT file: {file_name}')
        return await self._read_file_content(file_name)

    async def _parse_pdf(self, file_name: str) -> str:
        """Parses a PDF file and returns its text content."""
        self.ap.logger.info(f'Parsing PDF file: {file_name}')

        # def _parse_pdf_sync():
        #     text_content = []
        #     with open(file_name, 'rb') as file:
        #         pdf_reader = PyPDF2.PdfReader(file)
        #         for page in pdf_reader.pages:
        #             text = page.extract_text()
        #             if text:
        #                 text_content.append(text)
        #     return '\n'.join(text_content)

        # return await self._run_sync(_parse_pdf_sync)

        pdf_bytes = await self.ap.storage_mgr.storage_provider.load(file_name)

        def _parse_pdf_sync():
            pdf_reader = PyPDF2.PdfReader(io.BytesIO(pdf_bytes))
            text_content = []
            for page in pdf_reader.pages:
                text = page.extract_text()
                if text:
                    text_content.append(text)
            return '\n'.join(text_content)

        return await self._run_sync(_parse_pdf_sync)

    async def _parse_docx(self, file_name: str) -> str:
        """Parses a DOCX file and returns its text content."""
        self.ap.logger.info(f'Parsing DOCX file: {file_name}')

        docx_bytes = await self.ap.storage_mgr.storage_provider.load(file_name)

        def _parse_docx_sync():
            doc = Document(io.BytesIO(docx_bytes))
            text_content = [paragraph.text for paragraph in doc.paragraphs if paragraph.text.strip()]
            return '\n'.join(text_content)

        return await self._run_sync(_parse_docx_sync)

    async def _parse_doc(self, file_name: str) -> str:
        """Handles .doc files, explicitly stating lack of direct support."""
        self.ap.logger.warning(f'Direct .doc parsing is not supported for {file_name}. Please convert to .docx first.')
        raise NotImplementedError('Direct .doc parsing not supported. Please convert to .docx first.')

    # async def _parse_xlsx(self, file_name: str) -> str:
    #     """Parses an XLSX file, returning text from all sheets."""
    #     self.ap.logger.info(f'Parsing XLSX file: {file_name}')

    #     xlsx_bytes = await self.ap.storage_mgr.storage_provider.load(file_name)

    #     def _parse_xlsx_sync():
    #         excel_file = pd.ExcelFile(io.BytesIO(xlsx_bytes))
    #         all_sheet_content = []
    #         for sheet_name in excel_file.sheet_names:
    #             df = pd.read_excel(io.BytesIO(xlsx_bytes), sheet_name=sheet_name)
    #             sheet_text = f'--- Sheet: {sheet_name} ---\n{df.to_string(index=False)}\n'
    #             all_sheet_content.append(sheet_text)
    #         return '\n'.join(all_sheet_content)

    #     return await self._run_sync(_parse_xlsx_sync)

    # async def _parse_csv(self, file_name: str) -> str:
    #     """Parses a CSV file and returns its content as a string."""
    #     self.ap.logger.info(f'Parsing CSV file: {file_name}')

    #     csv_bytes = await self.ap.storage_mgr.storage_provider.load(file_name)

    #     def _parse_csv_sync():
    #         # pd.read_csv can often detect encoding, but explicit detection is safer
    #         # raw_data = self._read_file_content(
    #         #     file_name, mode='rb'
    #         # )  # Note: this will need to be await outside this sync function
    #         # _ = raw_data
    #         # For simplicity, we'll let pandas handle encoding internally after a raw read.
    #         # A more robust solution might pass encoding directly to pd.read_csv after detection.
    #         detected = chardet.detect(io.BytesIO(csv_bytes))
    #         encoding = detected['encoding'] or 'utf-8'
    #         df = pd.read_csv(io.BytesIO(csv_bytes), encoding=encoding)
    #         return df.to_string(index=False)

    #     return await self._run_sync(_parse_csv_sync)

    async def _parse_md(self, file_name: str) -> str:
        """Parses a Markdown file, converting it to structured plain text."""
        self.ap.logger.info(f'Parsing Markdown file: {file_name}')

        md_bytes = await self.ap.storage_mgr.storage_provider.load(file_name)

        def _parse_markdown_sync():
            md_content = io.BytesIO(md_bytes).read().decode('utf-8', errors='ignore')
            html_content = markdown.markdown(
                md_content, extensions=['extra', 'codehilite', 'tables', 'toc', 'fenced_code']
            )
            soup = BeautifulSoup(html_content, 'html.parser')
            text_parts = []
            for element in soup.children:
                if element.name in ['h1', 'h2', 'h3', 'h4', 'h5', 'h6']:
                    level = int(element.name[1])
                    text_parts.append('#' * level + ' ' + element.get_text().strip())
                elif element.name == 'p':
                    text = element.get_text().strip()
                    if text:
                        text_parts.append(text)
                elif element.name in ['ul', 'ol']:
                    for li in element.find_all('li'):
                        text_parts.append(f'* {li.get_text().strip()}')
                elif element.name == 'pre':
                    code_block = element.get_text().strip()
                    if code_block:
                        text_parts.append(f'```\n{code_block}\n```')
                elif element.name == 'table':
                    table_str = self._extract_table_to_markdown_sync(element)  # Call sync helper
                    if table_str:
                        text_parts.append(table_str)
                elif element.name:
                    text = element.get_text(separator=' ', strip=True)
                    if text:
                        text_parts.append(text)
            cleaned_text = re.sub(r'\n\s*\n', '\n\n', '\n'.join(text_parts))
            return cleaned_text.strip()

        return await self._run_sync(_parse_markdown_sync)

    async def _parse_html(self, file_name: str) -> str:
        """Parses an HTML file, extracting structured plain text."""
        self.ap.logger.info(f'Parsing HTML file: {file_name}')

        html_bytes = await self.ap.storage_mgr.storage_provider.load(file_name)

        def _parse_html_sync():
            html_content = io.BytesIO(html_bytes).read().decode('utf-8', errors='ignore')
            soup = BeautifulSoup(html_content, 'html.parser')
            for script_or_style in soup(['script', 'style']):
                script_or_style.decompose()
            text_parts = []
            for element in soup.body.children if soup.body else soup.children:
                if element.name in ['h1', 'h2', 'h3', 'h4', 'h5', 'h6']:
                    level = int(element.name[1])
                    text_parts.append('#' * level + ' ' + element.get_text().strip())
                elif element.name == 'p':
                    text = element.get_text().strip()
                    if text:
                        text_parts.append(text)
                elif element.name in ['ul', 'ol']:
                    for li in element.find_all('li'):
                        text = li.get_text().strip()
                        if text:
                            text_parts.append(f'* {text}')
                elif element.name == 'table':
                    table_str = self._extract_table_to_markdown_sync(element)  # Call sync helper
                    if table_str:
                        text_parts.append(table_str)
                elif element.name:
                    text = element.get_text(separator=' ', strip=True)
                    if text:
                        text_parts.append(text)
            cleaned_text = re.sub(r'\n\s*\n', '\n\n', '\n'.join(text_parts))
            return cleaned_text.strip()

        return await self._run_sync(_parse_html_sync)

    def _add_toc_items_sync(self, toc_list: list, text_content: list, level: int):
        """Recursively adds TOC items to text_content (synchronous helper)."""
        indent = ' ' * level
        for item in toc_list:
            if isinstance(item, tuple):
                chapter, subchapters = item
                text_content.append(f'{indent}- {chapter.title}')
                self._add_toc_items_sync(subchapters, text_content, level + 1)
            else:
                text_content.append(f'{indent}- {item.title}')

    def _extract_table_to_markdown_sync(self, table_element: BeautifulSoup) -> str:
        """Helper to convert a BeautifulSoup table element into a Markdown table string (synchronous)."""
        headers = [th.get_text().strip() for th in table_element.find_all('th')]
        rows = []
        for tr in table_element.find_all('tr'):
            cells = [td.get_text().strip() for td in tr.find_all('td')]
            if cells:
                rows.append(cells)

        if not headers and not rows:
            return ''

        table_lines = []
        if headers:
            table_lines.append(' | '.join(headers))
            table_lines.append(' | '.join(['---'] * len(headers)))

        for row_cells in rows:
            padded_cells = row_cells + [''] * (len(headers) - len(row_cells)) if headers else row_cells
            table_lines.append(' | '.join(padded_cells))

        return '\n'.join(table_lines)
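The Markdown and HTML parsers share one approach: render to HTML, then walk the top-level elements with BeautifulSoup and emit headings, paragraphs, lists, code blocks and tables as plain text. A trimmed standalone sketch of that walk (headings, paragraphs and lists only, unlike the full parser above):

# Standalone, simplified version of FileParser._parse_md's HTML walk.
import markdown
from bs4 import BeautifulSoup

md_content = """# LangBot KB

A short *demo* document.

- supports QQ
- supports Discord
"""

html_content = markdown.markdown(md_content, extensions=['extra', 'tables', 'fenced_code'])
soup = BeautifulSoup(html_content, 'html.parser')

text_parts = []
for element in soup.children:
    if element.name in ['h1', 'h2', 'h3', 'h4', 'h5', 'h6']:
        text_parts.append('#' * int(element.name[1]) + ' ' + element.get_text().strip())
    elif element.name == 'p':
        text_parts.append(element.get_text().strip())
    elif element.name in ['ul', 'ol']:
        text_parts.extend(f'* {li.get_text().strip()}' for li in element.find_all('li'))

print('\n'.join(part for part in text_parts if part))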
48 pkg/rag/knowledge/services/retriever.py Normal file
@@ -0,0 +1,48 @@
from __future__ import annotations

from . import base_service
from ....core import app
from ....provider.modelmgr.requester import RuntimeEmbeddingModel
from ....entity.rag import retriever as retriever_entities


class Retriever(base_service.BaseService):
    def __init__(self, ap: app.Application):
        super().__init__()
        self.ap = ap

    async def retrieve(
        self, kb_id: str, query: str, embedding_model: RuntimeEmbeddingModel, k: int = 5
    ) -> list[retriever_entities.RetrieveResultEntry]:
        self.ap.logger.info(
            f"Retrieving for query: '{query[:10]}' with k={k} using {embedding_model.model_entity.uuid}"
        )

        query_embedding: list[float] = await embedding_model.requester.invoke_embedding(
            model=embedding_model,
            input_text=[query],
            extra_args={},  # TODO: add extra args
        )

        chroma_results = await self.ap.vector_db_mgr.vector_db.search(kb_id, query_embedding[0], k)

        # 'ids' is always returned by ChromaDB, even if not explicitly in 'include'
        matched_chroma_ids = chroma_results.get('ids', [[]])[0]
        distances = chroma_results.get('distances', [[]])[0]
        chroma_metadatas = chroma_results.get('metadatas', [[]])[0]

        if not matched_chroma_ids:
            self.ap.logger.info('No relevant chunks found in Chroma.')
            return []

        result: list[retriever_entities.RetrieveResultEntry] = []

        for i, id in enumerate(matched_chroma_ids):
            entry = retriever_entities.RetrieveResultEntry(
                id=id,
                metadata=chroma_metadatas[i],
                distance=distances[i],
            )
            result.append(entry)

        return result
@@ -1,7 +1,7 @@
semantic_version = 'v4.0.9'
semantic_version = 'v4.1.0'

required_database_version = 3
"""标记本版本所需要的数据库结构版本,用于判断数据库迁移"""
required_database_version = 4
"""Tag the version of the database schema, used to check if the database needs to be migrated"""

debug_mode = False
0 pkg/vector/__init__.py Normal file
18 pkg/vector/mgr.py Normal file
@@ -0,0 +1,18 @@
from __future__ import annotations

from ..core import app
from .vdb import VectorDatabase
from .vdbs.chroma import ChromaVectorDatabase


class VectorDBManager:
    ap: app.Application
    vector_db: VectorDatabase = None

    def __init__(self, ap: app.Application):
        self.ap = ap

    async def initialize(self):
        # 初始化 Chroma 向量数据库(可扩展为多种实现)
        if self.vector_db is None:
            self.vector_db = ChromaVectorDatabase(self.ap)
37 pkg/vector/vdb.py Normal file
@@ -0,0 +1,37 @@
from __future__ import annotations

import abc
from typing import Any, Dict
import numpy as np


class VectorDatabase(abc.ABC):
    @abc.abstractmethod
    async def add_embeddings(
        self,
        collection: str,
        ids: list[str],
        embeddings_list: list[list[float]],
        metadatas: list[dict[str, Any]],
        documents: list[str],
    ) -> None:
        """向指定 collection 添加向量数据。"""
        pass

    @abc.abstractmethod
    async def search(self, collection: str, query_embedding: np.ndarray, k: int = 5) -> Dict[str, Any]:
        """在指定 collection 中检索最相似的向量。"""
        pass

    @abc.abstractmethod
    async def delete_by_file_id(self, collection: str, file_id: str) -> None:
        """根据 file_id 删除指定 collection 中的向量。"""
        pass

    @abc.abstractmethod
    async def get_or_create_collection(self, collection: str):
        """获取或创建 collection。"""
        pass

    @abc.abstractmethod
    async def delete_collection(self, collection: str):
        pass
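VectorDatabase is the contract an alternative vector store would implement. A purely illustrative in-memory backend, sketched on the assumption that cosine distance over numpy arrays and Chroma's nested result layout are what the Retriever expects; it is not part of the repository:

# Illustrative in-memory backend implementing the VectorDatabase contract.
from __future__ import annotations

from typing import Any, Dict

import numpy as np

from pkg.vector.vdb import VectorDatabase


class InMemoryVectorDatabase(VectorDatabase):
    def __init__(self):
        # collection name -> list of {'id', 'embedding', 'metadata'} records
        self._store: dict[str, list[dict[str, Any]]] = {}

    async def get_or_create_collection(self, collection: str):
        return self._store.setdefault(collection, [])

    async def add_embeddings(
        self,
        collection: str,
        ids: list[str],
        embeddings_list: list[list[float]],
        metadatas: list[dict[str, Any]],
        documents: list[str] | None = None,
    ) -> None:
        col = await self.get_or_create_collection(collection)
        for idx, chunk_id in enumerate(ids):
            col.append(
                {'id': chunk_id, 'embedding': np.asarray(embeddings_list[idx], dtype=float), 'metadata': metadatas[idx]}
            )

    async def search(self, collection: str, query_embedding, k: int = 5) -> Dict[str, Any]:
        col = await self.get_or_create_collection(collection)
        q = np.asarray(query_embedding, dtype=float)
        scored = []
        for item in col:
            v = item['embedding']
            sim = float(np.dot(q, v) / (np.linalg.norm(q) * np.linalg.norm(v) + 1e-12))
            scored.append((1.0 - sim, item))  # smaller distance = more similar
        scored.sort(key=lambda pair: pair[0])
        top = scored[:k]
        # Mirror Chroma's nested result shape so Retriever can consume it unchanged.
        return {
            'ids': [[item['id'] for _, item in top]],
            'distances': [[dist for dist, _ in top]],
            'metadatas': [[item['metadata'] for _, item in top]],
        }

    async def delete_by_file_id(self, collection: str, file_id: str) -> None:
        col = await self.get_or_create_collection(collection)
        col[:] = [item for item in col if item['metadata'].get('file_id') != file_id]

    async def delete_collection(self, collection: str):
        self._store.pop(collection, None)

Swapping VectorDBManager.initialize() to construct such a class instead of ChromaVectorDatabase would be the natural extension point, assuming downstream code only depends on the shape of the search result.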
0 pkg/vector/vdbs/__init__.py Normal file
61 pkg/vector/vdbs/chroma.py Normal file
@@ -0,0 +1,61 @@
from __future__ import annotations

import asyncio
from typing import Any
from chromadb import PersistentClient
from pkg.vector.vdb import VectorDatabase
from pkg.core import app
import chromadb
import chromadb.errors


class ChromaVectorDatabase(VectorDatabase):
    def __init__(self, ap: app.Application, base_path: str = './data/chroma'):
        self.ap = ap
        self.client = PersistentClient(path=base_path)
        self._collections = {}

    async def get_or_create_collection(self, collection: str) -> chromadb.Collection:
        if collection not in self._collections:
            self._collections[collection] = await asyncio.to_thread(
                self.client.get_or_create_collection, name=collection
            )
            self.ap.logger.info(f"Chroma collection '{collection}' accessed/created.")
        return self._collections[collection]

    async def add_embeddings(
        self,
        collection: str,
        ids: list[str],
        embeddings_list: list[list[float]],
        metadatas: list[dict[str, Any]],
    ) -> None:
        col = await self.get_or_create_collection(collection)
        await asyncio.to_thread(col.add, embeddings=embeddings_list, ids=ids, metadatas=metadatas)
        self.ap.logger.info(f"Added {len(ids)} embeddings to Chroma collection '{collection}'.")

    async def search(self, collection: str, query_embedding: list[float], k: int = 5) -> dict[str, Any]:
        col = await self.get_or_create_collection(collection)
        results = await asyncio.to_thread(
            col.query,
            query_embeddings=query_embedding,
            n_results=k,
            include=['metadatas', 'distances', 'documents'],
        )
        self.ap.logger.info(f"Chroma search in '{collection}' returned {len(results.get('ids', [[]])[0])} results.")
        return results

    async def delete_by_file_id(self, collection: str, file_id: str) -> None:
        col = await self.get_or_create_collection(collection)
        await asyncio.to_thread(col.delete, where={'file_id': file_id})
        self.ap.logger.info(f"Deleted embeddings from Chroma collection '{collection}' with file_id: {file_id}")

    async def delete_collection(self, collection: str):
        if collection in self._collections:
            del self._collections[collection]

        try:
            await asyncio.to_thread(self.client.delete_collection, name=collection)
        except chromadb.errors.NotFoundError:
            self.ap.logger.warning(f"Chroma collection '{collection}' not found.")
            return
        self.ap.logger.info(f"Chroma collection '{collection}' deleted.")
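ChromaVectorDatabase wraps a persistent on-disk client and pushes every blocking client call through asyncio.to_thread. A standalone sketch of the same Chroma calls, assuming chromadb is installed; ./data/chroma-demo is a throwaway path:

# Standalone sketch of the Chroma operations wrapped by ChromaVectorDatabase.
import asyncio

from chromadb import PersistentClient


async def main():
    client = PersistentClient(path='./data/chroma-demo')

    col = await asyncio.to_thread(client.get_or_create_collection, name='demo-kb')

    await asyncio.to_thread(
        col.add,
        ids=['chunk-1', 'chunk-2'],
        embeddings=[[0.1, 0.2, 0.3], [0.9, 0.8, 0.7]],
        metadatas=[{'file_id': 'f1', 'text': 'first chunk'}, {'file_id': 'f1', 'text': 'second chunk'}],
    )

    results = await asyncio.to_thread(
        col.query,
        query_embeddings=[[0.1, 0.2, 0.25]],
        n_results=1,
        include=['metadatas', 'distances'],
    )
    print(results['ids'][0], results['distances'][0])

    # Remove everything belonging to one source file, as delete_by_file_id does.
    await asyncio.to_thread(col.delete, where={'file_id': 'f1'})


if __name__ == '__main__':
    asyncio.run(main())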
@@ -1,6 +1,6 @@
[project]
name = "langbot"
version = "4.0.9"
version = "4.1.0"
description = "高稳定、支持扩展、多模态 - 大模型原生即时通信机器人平台"
readme = "README.md"
requires-python = ">=3.10.1"
@@ -50,6 +50,16 @@ dependencies = [
    "ruff>=0.11.9",
    "pre-commit>=4.2.0",
    "uv>=0.7.11",
    "PyPDF2>=3.0.1",
    "python-docx>=1.1.0",
    "pandas>=2.2.2",
    "chardet>=5.2.0",
    "markdown>=3.6",
    "beautifulsoup4>=4.12.3",
    "ebooklib>=0.18",
    "html2text>=2024.2.26",
    "langchain>=0.2.0",
    "chromadb>=0.4.24",
]
keywords = [
    "bot",
@@ -44,7 +44,8 @@
        "role": "system",
        "content": "You are a helpful assistant."
      }
    ]
    ],
    "knowledge-base": ""
  },
  "dify-service-api": {
    "base-url": "https://api.dify.ai/v1",
@@ -68,6 +68,16 @@ stages:
          zh_Hans: 除非您了解消息结构,否则请只使用 system 单提示词
        type: prompt-editor
        required: true
      - name: knowledge-base
        label:
          en_US: Knowledge Base
          zh_Hans: 知识库
        description:
          en_US: Configure the knowledge base to use for the agent, if not selected, the agent will directly use the LLM to reply
          zh_Hans: 配置用于提升回复质量的知识库,若不选择,则直接使用大模型回复
        type: knowledge-base-selector
        required: false
        default: ''
      - name: dify-service-api
        label:
          en_US: Dify Service API
@@ -298,3 +308,4 @@ stages:
        type: string
        required: false
        default: 'response'
490 web/package-lock.json generated
(Generated lockfile update matching the web/package.json dependency changes below, e.g. @radix-ui/react-dialog 1.1.14, @radix-ui/react-slot 1.2.3, @radix-ui/react-separator 1.1.7, @tanstack/react-table 8.21.3, input-otp 1.4.2, and related transitive packages.)
@@ -5,7 +5,7 @@
"scripts": {
"dev": "next dev --turbopack",
"dev:local": "NEXT_PUBLIC_API_BASE_URL=http://localhost:5300 next dev --turbopack",
"dev:local:win": "set NEXT_PUBLIC_API_BASE_URL=http://localhost:5300 && next dev --turbopack",
"dev:local:win": "set NEXT_PUBLIC_API_BASE_URL=http://localhost:5300&&next dev --turbopack",
"build": "next build",
"start": "next start",
"lint": "next lint",
@@ -23,6 +23,7 @@
"@hookform/resolvers": "^5.0.1",
"@radix-ui/react-checkbox": "^1.3.1",
"@radix-ui/react-dialog": "^1.1.14",
"@radix-ui/react-dropdown-menu": "^2.1.15",
"@radix-ui/react-hover-card": "^1.1.13",
"@radix-ui/react-label": "^2.1.6",
"@radix-ui/react-popover": "^1.1.14",
@@ -36,6 +37,7 @@
"@radix-ui/react-toggle-group": "^1.1.9",
"@radix-ui/react-tooltip": "^1.2.7",
"@tailwindcss/postcss": "^4.1.5",
"@tanstack/react-table": "^8.21.3",
"axios": "^1.8.4",
"class-variance-authority": "^0.7.1",
"clsx": "^2.1.1",

@@ -127,10 +127,8 @@ export default function BotDetailDialog({
<BotForm
initBotId={undefined}
onFormSubmit={handleFormSubmit}
onFormCancel={handleFormCancel}
onBotDeleted={handleBotDeleted}
onNewBotCreated={handleNewBotCreated}
hideButtons={true}
/>
</div>
<DialogFooter className="px-6 py-4 border-t shrink-0">
@@ -199,10 +197,8 @@ export default function BotDetailDialog({
<BotForm
initBotId={botId}
onFormSubmit={handleFormSubmit}
onFormCancel={handleFormCancel}
onBotDeleted={handleBotDeleted}
onNewBotCreated={handleNewBotCreated}
hideButtons={true}
/>
)}
{activeMenu === 'logs' && botId && (

@@ -64,17 +64,13 @@ const getFormSchema = (t: (key: string) => string) =>
export default function BotForm({
initBotId,
onFormSubmit,
onFormCancel,
onBotDeleted,
onNewBotCreated,
hideButtons = false,
}: {
initBotId?: string;
onFormSubmit: (value: z.infer<ReturnType<typeof getFormSchema>>) => void;
onFormCancel: () => void;
onBotDeleted: () => void;
onNewBotCreated: (botId: string) => void;
hideButtons?: boolean;
}) {
const { t } = useTranslation();
const formSchema = getFormSchema(t);
@@ -214,6 +210,7 @@ export default function BotForm({
});
setAdapterNameToDynamicConfigMap(adapterNameToDynamicConfigMap);
}

async function getBotConfig(
botId: string,
): Promise<z.infer<typeof formSchema>> {
@@ -527,45 +524,6 @@ export default function BotForm({
</div>
)}
</div>

{!hideButtons && (
<div className="sticky bottom-0 left-0 right-0 bg-background border-t p-4 mt-4">
<div className="flex justify-end gap-2">
{!initBotId && (
<Button
type="submit"
onClick={form.handleSubmit(onDynamicFormSubmit)}
>
{t('common.submit')}
</Button>
)}
{initBotId && (
<>
<Button
type="button"
variant="destructive"
onClick={() => setShowDeleteConfirmModal(true)}
>
{t('common.delete')}
</Button>
<Button
type="button"
onClick={form.handleSubmit(onDynamicFormSubmit)}
>
{t('common.save')}
</Button>
</>
)}
<Button
type="button"
variant="outline"
onClick={() => onFormCancel()}
>
{t('common.cancel')}
</Button>
</div>
</div>
)}
</form>
</Form>
</div>

@@ -92,7 +92,7 @@ export default function BotConfigPage() {
}

return (
<div className={styles.configPageContainer}>
<div>
<BotDetailDialog
open={detailDialogOpen}
onOpenChange={setDetailDialogOpen}

@@ -50,6 +50,9 @@ export default function DynamicFormComponent({
case 'llm-model-selector':
fieldSchema = z.string();
break;
case 'knowledge-base-selector':
fieldSchema = z.string();
break;
case 'prompt-editor':
fieldSchema = z.array(
z.object({

@@ -17,6 +17,7 @@ import { Button } from '@/components/ui/button';
import { useEffect, useState } from 'react';
import { httpClient } from '@/app/infra/http/HttpClient';
import { LLMModel } from '@/app/infra/entities/api';
import { KnowledgeBase } from '@/app/infra/entities/api';
import { toast } from 'sonner';
import {
HoverCard,
@@ -35,6 +36,7 @@ export default function DynamicFormItemComponent({
field: ControllerRenderProps<any, any>;
}) {
const [llmModels, setLlmModels] = useState<LLMModel[]>([]);
const [knowledgeBases, setKnowledgeBases] = useState<KnowledgeBase[]>([]);
const { t } = useTranslation();

useEffect(() => {
@@ -50,6 +52,19 @@ export default function DynamicFormItemComponent({
}
}, [config.type]);

useEffect(() => {
if (config.type === DynamicFormItemType.KNOWLEDGE_BASE_SELECTOR) {
httpClient
.getKnowledgeBases()
.then((resp) => {
setKnowledgeBases(resp.bases);
})
.catch((err) => {
toast.error('Failed to fetch knowledge base list: ' + err.message);
});
}
}, [config.type]);

switch (config.type) {
case DynamicFormItemType.INT:
case DynamicFormItemType.FLOAT:
@@ -249,6 +264,25 @@ export default function DynamicFormItemComponent({
</Select>
);

case DynamicFormItemType.KNOWLEDGE_BASE_SELECTOR:
return (
<Select value={field.value} onValueChange={field.onChange}>
<SelectTrigger>
<SelectValue placeholder={t('knowledge.selectKnowledgeBase')} />
</SelectTrigger>
<SelectContent>
<SelectGroup>
<SelectItem value="__none__">{t('knowledge.empty')}</SelectItem>
{knowledgeBases.map((base) => (
<SelectItem key={base.uuid} value={base.uuid ?? ''}>
{base.name}
</SelectItem>
))}
</SelectGroup>
</SelectContent>
</Select>
);

case DynamicFormItemType.PROMPT_EDITOR:
return (
<div className="space-y-2">

@@ -47,6 +47,7 @@ export const sidebarConfigList = [
zh_Hans: 'https://docs.langbot.app/zh/deploy/models/readme.html',
},
}),

new SidebarChildVO({
id: 'pipelines',
name: t('pipelines.title'),
@@ -67,6 +68,25 @@ export const sidebarConfigList = [
zh_Hans: 'https://docs.langbot.app/zh/deploy/pipelines/readme.html',
},
}),
new SidebarChildVO({
id: 'knowledge',
name: t('knowledge.title'),
icon: (
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 24 24"
fill="currentColor"
>
<path d="M3 18.5V5C3 3.34315 4.34315 2 6 2H20C20.5523 2 21 2.44772 21 3V21C21 21.5523 20.5523 22 20 22H6.5C4.567 22 3 20.433 3 18.5ZM19 20V17H6.5C5.67157 17 5 17.6716 5 18.5C5 19.3284 5.67157 20 6.5 20H19ZM10 4H6C5.44772 4 5 4.44772 5 5V15.3368C5.45463 15.1208 5.9632 15 6.5 15H19V4H17V12L13.5 10L10 12V4Z"></path>
</svg>
),
route: '/home/knowledge',
description: t('knowledge.description'),
helpLink: {
en_US: 'https://docs.langbot.app/en/deploy/knowledge/readme.html',
zh_Hans: 'https://docs.langbot.app/zh/deploy/knowledge/readme.html',
},
}),
new SidebarChildVO({
id: 'plugins',
name: t('plugins.title'),

web/src/app/home/knowledge/KBDetailDialog.tsx (new file, 236 lines)
@@ -0,0 +1,236 @@
'use client';

import { useEffect, useState } from 'react';
import {
Dialog,
DialogContent,
DialogHeader,
DialogTitle,
DialogFooter,
} from '@/components/ui/dialog';
import {
Sidebar,
SidebarContent,
SidebarGroup,
SidebarGroupContent,
SidebarMenu,
SidebarMenuButton,
SidebarMenuItem,
SidebarProvider,
} from '@/components/ui/sidebar';
import { Button } from '@/components/ui/button';
import { useTranslation } from 'react-i18next';
import { httpClient } from '@/app/infra/http/HttpClient';
// import { KnowledgeBase } from '@/app/infra/entities/api';
import KBForm from '@/app/home/knowledge/components/kb-form/KBForm';
import KBDoc from '@/app/home/knowledge/components/kb-docs/KBDoc';

interface KBDetailDialogProps {
open: boolean;
onOpenChange: (open: boolean) => void;
kbId?: string;
onFormCancel: () => void;
onKbDeleted: () => void;
onNewKbCreated: (kbId: string) => void;
onKbUpdated: (kbId: string) => void;
}

export default function KBDetailDialog({
open,
onOpenChange,
kbId: propKbId,
onFormCancel,
onKbDeleted,
onNewKbCreated,
onKbUpdated,
}: KBDetailDialogProps) {
const { t } = useTranslation();
const [kbId, setKbId] = useState<string | undefined>(propKbId);
const [activeMenu, setActiveMenu] = useState('metadata');
const [showDeleteConfirm, setShowDeleteConfirm] = useState(false);

useEffect(() => {
setKbId(propKbId);
setActiveMenu('metadata');
}, [propKbId, open]);

const menu = [
{
key: 'metadata',
label: t('knowledge.metadata'),
icon: (
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 24 24"
fill="currentColor"
>
<path d="M5 7C5 6.17157 5.67157 5.5 6.5 5.5C7.32843 5.5 8 6.17157 8 7C8 7.82843 7.32843 8.5 6.5 8.5C5.67157 8.5 5 7.82843 5 7ZM6.5 3.5C4.567 3.5 3 5.067 3 7C3 8.933 4.567 10.5 6.5 10.5C8.433 10.5 10 8.933 10 7C10 5.067 8.433 3.5 6.5 3.5ZM12 8H20V6H12V8ZM16 17C16 16.1716 16.6716 15.5 17.5 15.5C18.3284 15.5 19 16.1716 19 17C19 17.8284 18.3284 18.5 17.5 18.5C16.6716 18.5 16 17.8284 16 17ZM17.5 13.5C15.567 13.5 14 15.067 14 17C14 18.933 15.567 20.5 17.5 20.5C19.433 20.5 21 18.933 21 17C21 15.067 19.433 13.5 17.5 13.5ZM4 16V18H12V16H4Z"></path>
</svg>
),
},
{
key: 'documents',
label: t('knowledge.documents'),
icon: (
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 24 24"
fill="currentColor"
>
<path d="M21 8V20.9932C21 21.5501 20.5552 22 20.0066 22H3.9934C3.44495 22 3 21.556 3 21.0082V2.9918C3 2.45531 3.4487 2 4.00221 2H14.9968L21 8ZM19 9H14V4H5V20H19V9ZM8 7H11V9H8V7ZM8 11H16V13H8V11ZM8 15H16V17H8V15Z"></path>
</svg>
),
},
];

const confirmDelete = () => {
httpClient.deleteKnowledgeBase(kbId ?? '').then(() => {
onKbDeleted();
});
setShowDeleteConfirm(false);
};

if (!kbId) {
// new kb
return (
<Dialog open={open} onOpenChange={onOpenChange}>
<DialogContent className="overflow-hidden p-0 !max-w-[40vw] max-h-[70vh] flex">
<main className="flex flex-1 flex-col h-[70vh]">
<DialogHeader className="px-6 pt-6 pb-4 shrink-0">
<DialogTitle>{t('knowledge.createKnowledgeBase')}</DialogTitle>
</DialogHeader>
<div className="flex-1 overflow-y-auto px-6 pb-6">
{activeMenu === 'metadata' && (
<KBForm
initKbId={undefined}
onNewKbCreated={onNewKbCreated}
onKbUpdated={onKbUpdated}
/>
)}
{activeMenu === 'documents' && <div>documents</div>}
</div>
{activeMenu === 'metadata' && (
<DialogFooter className="px-6 py-4 border-t shrink-0">
<div className="flex justify-end gap-2">
<Button type="submit" form="kb-form">
{t('common.save')}
</Button>
<Button
type="button"
variant="outline"
onClick={onFormCancel}
>
{t('common.cancel')}
</Button>
</div>
</DialogFooter>
)}
</main>
</DialogContent>
</Dialog>
);
}

return (
<>
<Dialog open={open} onOpenChange={onOpenChange}>
<DialogContent className="overflow-hidden p-0 !max-w-[50rem] max-h-[75vh] flex">
<SidebarProvider className="items-start w-full flex">
<Sidebar
collapsible="none"
className="hidden md:flex h-[80vh] w-40 min-w-[120px] border-r bg-white"
>
<SidebarContent>
<SidebarGroup>
<SidebarGroupContent>
<SidebarMenu>
{menu.map((item) => (
<SidebarMenuItem key={item.key}>
<SidebarMenuButton
asChild
isActive={activeMenu === item.key}
onClick={() => setActiveMenu(item.key)}
>
<a href="#">
{item.icon}
<span>{item.label}</span>
</a>
</SidebarMenuButton>
</SidebarMenuItem>
))}
</SidebarMenu>
</SidebarGroupContent>
</SidebarGroup>
</SidebarContent>
</Sidebar>
<main className="flex flex-1 flex-col h-[75vh]">
<DialogHeader className="px-6 pt-6 pb-4 shrink-0">
<DialogTitle>
{activeMenu === 'metadata'
? t('knowledge.editKnowledgeBase')
: t('knowledge.editDocument')}
</DialogTitle>
</DialogHeader>
<div className="flex-1 overflow-y-auto px-6 pb-6">
{activeMenu === 'metadata' && (
<KBForm
initKbId={kbId}
onNewKbCreated={onNewKbCreated}
onKbUpdated={onKbUpdated}
/>
)}
{activeMenu === 'documents' && <KBDoc kbId={kbId} />}
</div>
{activeMenu === 'metadata' && (
<DialogFooter className="px-6 py-4 border-t shrink-0">
<div className="flex justify-end gap-2">
<Button
type="button"
variant="destructive"
onClick={() => setShowDeleteConfirm(true)}
>
{t('common.delete')}
</Button>
<Button type="submit" form="kb-form">
{t('common.save')}
</Button>
<Button
type="button"
variant="outline"
onClick={onFormCancel}
>
{t('common.cancel')}
</Button>
</div>
</DialogFooter>
)}
</main>
</SidebarProvider>
</DialogContent>
</Dialog>

{/* Delete confirmation dialog */}
<Dialog open={showDeleteConfirm} onOpenChange={setShowDeleteConfirm}>
<DialogContent>
<DialogHeader>
<DialogTitle>{t('common.confirmDelete')}</DialogTitle>
</DialogHeader>
<div className="py-4">
{t('knowledge.deleteKnowledgeBaseConfirmation')}
</div>
<DialogFooter>
<Button
variant="outline"
onClick={() => setShowDeleteConfirm(false)}
>
{t('common.cancel')}
</Button>
<Button variant="destructive" onClick={confirmDelete}>
{t('common.confirmDelete')}
</Button>
</DialogFooter>
</DialogContent>
</Dialog>
</>
);
}

web/src/app/home/knowledge/components/kb-card/KBCard.module.css (new file, 107 lines)
@@ -0,0 +1,107 @@
.cardContainer {
width: 100%;
height: 10rem;
background-color: #fff;
border-radius: 10px;
box-shadow: 0px 2px 2px 0 rgba(0, 0, 0, 0.2);
padding: 1.2rem;
cursor: pointer;
display: flex;
flex-direction: row;
justify-content: space-between;
gap: 0.5rem;
}

.cardContainer:hover {
box-shadow: 0px 2px 8px 0 rgba(0, 0, 0, 0.1);
}

.basicInfoContainer {
width: 100%;
height: 100%;
display: flex;
flex-direction: column;
justify-content: space-between;
gap: 0.4rem;
min-width: 0;
}

.basicInfoNameContainer {
display: flex;
flex-direction: column;
gap: 0.2rem;
}

.basicInfoNameText {
font-size: 1.4rem;
font-weight: 500;
}

.basicInfoDescriptionText {
font-size: 0.9rem;
font-weight: 400;
display: -webkit-box;
-webkit-line-clamp: 3;
-webkit-box-orient: vertical;
overflow: hidden;
text-overflow: ellipsis;
color: #b1b1b1;
}

.basicInfoLastUpdatedTimeContainer {
display: flex;
flex-direction: row;
align-items: center;
gap: 0.5rem;
}

.basicInfoUpdateTimeIcon {
width: 1.2rem;
height: 1.2rem;
}

.basicInfoUpdateTimeText {
font-size: 1rem;
font-weight: 400;
}

.operationContainer {
display: flex;
flex-direction: column;
align-items: flex-end;
justify-content: space-between;
gap: 0.5rem;
width: 8rem;
}

.operationDefaultBadge {
display: flex;
flex-direction: row;
gap: 0.5rem;
}

.operationDefaultBadgeIcon {
width: 1.2rem;
height: 1.2rem;
color: #ffcd27;
}

.operationDefaultBadgeText {
font-size: 1rem;
font-weight: 400;
color: #ffcd27;
}

.bigText {
white-space: nowrap;
overflow: hidden;
text-overflow: ellipsis;
font-size: 1.4rem;
font-weight: bold;
max-width: 100%;
}

.debugButtonIcon {
width: 1.2rem;
height: 1.2rem;
}

web/src/app/home/knowledge/components/kb-card/KBCard.tsx (new file, 36 lines)
@@ -0,0 +1,36 @@
import { KnowledgeBaseVO } from '@/app/home/knowledge/components/kb-card/KBCardVO';
import { useTranslation } from 'react-i18next';
import styles from './KBCard.module.css';

export default function KBCard({ kbCardVO }: { kbCardVO: KnowledgeBaseVO }) {
const { t } = useTranslation();
return (
<div className={`${styles.cardContainer}`}>
<div className={`${styles.basicInfoContainer}`}>
<div className={`${styles.basicInfoNameContainer}`}>
<div className={`${styles.basicInfoNameText} ${styles.bigText}`}>
{kbCardVO.name}
</div>
<div className={`${styles.basicInfoDescriptionText}`}>
{kbCardVO.description}
</div>
</div>

<div className={`${styles.basicInfoLastUpdatedTimeContainer}`}>
<svg
className={`${styles.basicInfoUpdateTimeIcon}`}
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 24 24"
fill="currentColor"
>
<path d="M12 22C6.47715 22 2 17.5228 2 12C2 6.47715 6.47715 2 12 2C17.5228 2 22 6.47715 22 12C22 17.5228 17.5228 22 12 22ZM12 20C16.4183 20 20 16.4183 20 12C20 7.58172 16.4183 4 12 4C7.58172 4 4 7.58172 4 12C4 16.4183 7.58172 20 12 20ZM13 12H17V14H11V7H13V12Z"></path>
</svg>
<div className={`${styles.basicInfoUpdateTimeText}`}>
{t('knowledge.updateTime')}
{kbCardVO.lastUpdatedTimeAgo}
</div>
</div>
</div>
</div>
);
}

web/src/app/home/knowledge/components/kb-card/KBCardVO.ts (new file, 23 lines)
@@ -0,0 +1,23 @@
export interface IKnowledgeBaseVO {
id: string;
name: string;
description: string;
embeddingModelUUID: string;
lastUpdatedTimeAgo: string;
}

export class KnowledgeBaseVO implements IKnowledgeBaseVO {
id: string;
name: string;
description: string;
embeddingModelUUID: string;
lastUpdatedTimeAgo: string;

constructor(props: IKnowledgeBaseVO) {
this.id = props.id;
this.name = props.name;
this.description = props.description;
this.embeddingModelUUID = props.embeddingModelUUID;
this.lastUpdatedTimeAgo = props.lastUpdatedTimeAgo;
}
}

web/src/app/home/knowledge/components/kb-docs/FileUploadZone.tsx (new file, 145 lines)
@@ -0,0 +1,145 @@
import React, { useCallback, useState } from 'react';
import { Card, CardContent } from '@/components/ui/card';
import { httpClient } from '@/app/infra/http/HttpClient';
import { toast } from 'sonner';
import { useTranslation } from 'react-i18next';

interface FileUploadZoneProps {
kbId: string;
onUploadSuccess: () => void;
onUploadError: (error: string) => void;
}

export default function FileUploadZone({
kbId,
onUploadSuccess,
onUploadError,
}: FileUploadZoneProps) {
const { t } = useTranslation();
const [isDragOver, setIsDragOver] = useState(false);
const [isUploading, setIsUploading] = useState(false);

const handleUpload = useCallback(
async (file: File) => {
if (isUploading) return;

setIsUploading(true);
const toastId = toast.loading(t('knowledge.documentsTab.uploadingFile'));

try {
// Step 1: Upload file to server
const uploadResult = await httpClient.uploadDocumentFile(file);

// Step 2: Associate file with knowledge base
await httpClient.uploadKnowledgeBaseFile(kbId, uploadResult.file_id);

toast.success(t('knowledge.documentsTab.uploadSuccess'), {
id: toastId,
});
onUploadSuccess();
} catch (error) {
console.error('File upload failed:', error);
const errorMessage = t('knowledge.documentsTab.uploadError');
toast.error(errorMessage, { id: toastId });
onUploadError(errorMessage);
} finally {
setIsUploading(false);
}
},
[kbId, isUploading, onUploadSuccess, onUploadError],
);

const handleDragOver = useCallback((e: React.DragEvent) => {
e.preventDefault();
setIsDragOver(true);
}, []);

const handleDragLeave = useCallback((e: React.DragEvent) => {
e.preventDefault();
setIsDragOver(false);
}, []);

const handleDrop = useCallback(
(e: React.DragEvent) => {
e.preventDefault();
setIsDragOver(false);

const files = Array.from(e.dataTransfer.files);
if (files.length > 0) {
handleUpload(files[0]);
}
},
[handleUpload],
);

const handleFileSelect = useCallback(
(e: React.ChangeEvent<HTMLInputElement>) => {
const files = e.target.files;
if (files && files.length > 0) {
handleUpload(files[0]);
}
},
[handleUpload],
);

return (
<Card className="mb-4">
<CardContent className="p-4">
<div
className={`
relative border-2 border-dashed rounded-lg p-4 text-center transition-colors
${
isDragOver
? 'border-blue-500 bg-blue-50'
: 'border-gray-300 hover:border-gray-400'
}
${isUploading ? 'opacity-50 pointer-events-none' : ''}
`}
onDragOver={handleDragOver}
onDragLeave={handleDragLeave}
onDrop={handleDrop}
>
<input
type="file"
id="file-upload"
className="hidden"
onChange={handleFileSelect}
accept=".pdf,.doc,.docx,.txt,.md,.html"
disabled={isUploading}
/>

<label htmlFor="file-upload" className="cursor-pointer block">
<div className="space-y-2">
<div className="mx-auto w-10 h-10 bg-gray-100 rounded-full flex items-center justify-center">
<svg
className="w-5 h-5 text-gray-400"
fill="none"
stroke="currentColor"
viewBox="0 0 24 24"
>
<path
strokeLinecap="round"
strokeLinejoin="round"
strokeWidth={2}
d="M7 16a4 4 0 01-.88-7.903A5 5 0 1115.9 6L16 6a5 5 0 011 9.9M15 13l-3-3m0 0l-3 3m3-3v12"
/>
</svg>
</div>

<div>
<p className="text-base font-medium text-gray-900">
{isUploading
? t('knowledge.documentsTab.uploading')
: t('knowledge.documentsTab.dragAndDrop')}
</p>
<p className="text-xs text-gray-500 mt-1">
{t('knowledge.documentsTab.supportedFormats')}
</p>
</div>
</div>
</label>
</div>
</CardContent>
</Card>
);
}

web/src/app/home/knowledge/components/kb-docs/KBDoc.tsx (new file, 72 lines)
@@ -0,0 +1,72 @@
import { useEffect, useState } from 'react';
import { httpClient } from '@/app/infra/http/HttpClient';
import { KnowledgeBaseFile } from '@/app/infra/entities/api';
import { columns, DocumentFile } from './documents/columns';
import { DataTable } from './documents/data-table';
import FileUploadZone from './FileUploadZone';
import { toast } from 'sonner';
import { useTranslation } from 'react-i18next';

export default function KBDoc({ kbId }: { kbId: string }) {
const [documentsList, setDocumentsList] = useState<DocumentFile[]>([]);
const { t } = useTranslation();

useEffect(() => {
getDocumentsList();

const intervalId = setInterval(() => {
getDocumentsList();
}, 5000);

return () => {
clearInterval(intervalId);
};
}, [kbId]);

async function getDocumentsList() {
const resp = await httpClient.getKnowledgeBaseFiles(kbId);
setDocumentsList(
resp.files.map((file: KnowledgeBaseFile) => {
return {
uuid: file.uuid,
name: file.file_name,
status: file.status,
};
}),
);
}

const handleUploadSuccess = () => {
// Refresh document list after successful upload
getDocumentsList();
};

const handleUploadError = (error: string) => {
// Error messages are already handled by toast in FileUploadZone component
console.error('Upload failed:', error);
};

const handleDelete = (id: string) => {
httpClient
.deleteKnowledgeBaseFile(kbId, id)
.then(() => {
getDocumentsList();
toast.success(t('knowledge.documentsTab.fileDeleteSuccess'));
})
.catch((error) => {
console.error('Delete failed:', error);
toast.error(t('knowledge.documentsTab.fileDeleteFailed'));
});
};

return (
<div className="container mx-auto py-2">
<FileUploadZone
kbId={kbId}
onUploadSuccess={handleUploadSuccess}
onUploadError={handleUploadError}
/>
<DataTable columns={columns(handleDelete, t)} data={documentsList} />
</div>
);
}

Some files were not shown because too many files have changed in this diff.