mirror of
https://github.com/langbot-app/LangBot.git
synced 2025-11-25 03:15:06 +08:00
feat: binding bots with runtime (#1238)
This commit is contained in:
committed by
GitHub
parent
5be17c55d2
commit
5379e4cf27
@@ -44,11 +44,16 @@ class BotService:
|
||||
|
||||
async def create_bot(self, bot_data: dict) -> str:
|
||||
"""创建机器人"""
|
||||
# TODO: 检查配置信息格式
|
||||
bot_data['uuid'] = str(uuid.uuid4())
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.insert(persistence_bot.Bot).values(bot_data)
|
||||
)
|
||||
# TODO: 加载机器人到机器人管理器
|
||||
|
||||
bot = await self.get_bot(bot_data['uuid'])
|
||||
|
||||
await self.ap.platform_mgr.load_bot(bot)
|
||||
|
||||
return bot_data['uuid']
|
||||
|
||||
async def update_bot(self, bot_uuid: str, bot_data: dict) -> None:
|
||||
@@ -58,13 +63,21 @@ class BotService:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(persistence_bot.Bot).values(bot_data).where(persistence_bot.Bot.uuid == bot_uuid)
|
||||
)
|
||||
# TODO: 加载机器人到机器人管理器
|
||||
await self.ap.platform_mgr.remove_bot(bot_uuid)
|
||||
|
||||
# select from db
|
||||
bot = await self.get_bot(bot_uuid)
|
||||
|
||||
runtime_bot = await self.ap.platform_mgr.load_bot(bot)
|
||||
|
||||
if runtime_bot.enable:
|
||||
await runtime_bot.run()
|
||||
|
||||
async def delete_bot(self, bot_uuid: str) -> None:
|
||||
"""删除机器人"""
|
||||
await self.ap.platform_mgr.remove_bot(bot_uuid)
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.delete(persistence_bot.Bot).where(persistence_bot.Bot.uuid == bot_uuid)
|
||||
)
|
||||
# TODO: 从机器人管理器中删除机器人
|
||||
|
||||
|
||||
|
||||
@@ -35,7 +35,10 @@ class ModelsService:
|
||||
**model_data
|
||||
)
|
||||
)
|
||||
await self.ap.model_mgr.load_llm_model(model_data)
|
||||
|
||||
llm_model = await self.get_llm_model(model_data['uuid'])
|
||||
|
||||
await self.ap.model_mgr.load_llm_model(llm_model)
|
||||
|
||||
return model_data['uuid']
|
||||
|
||||
@@ -60,7 +63,10 @@ class ModelsService:
|
||||
)
|
||||
|
||||
await self.ap.model_mgr.remove_llm_model(model_uuid)
|
||||
await self.ap.model_mgr.load_llm_model(model_data)
|
||||
|
||||
llm_model = await self.get_llm_model(model_uuid)
|
||||
|
||||
await self.ap.model_mgr.load_llm_model(llm_model)
|
||||
|
||||
async def delete_llm_model(self, model_uuid: str) -> None:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
|
||||
@@ -6,13 +6,14 @@ import sys
|
||||
import logging
|
||||
import asyncio
|
||||
import traceback
|
||||
import sqlalchemy
|
||||
|
||||
from .sources import qqofficial
|
||||
|
||||
# FriendMessage, Image, MessageChain, Plain
|
||||
from ..platform import adapter as msadapter
|
||||
|
||||
from ..core import app, entities as core_entities
|
||||
from ..core import app, entities as core_entities, taskmgr
|
||||
from ..plugin import events
|
||||
from .types import message as platform_message
|
||||
from .types import events as platform_events
|
||||
@@ -20,11 +21,64 @@ from .types import entities as platform_entities
|
||||
|
||||
from ..discover import engine
|
||||
|
||||
from ..entity.persistence import bot as persistence_bot
|
||||
|
||||
# 处理 3.4 移除了 YiriMirai 之后,插件的兼容性问题
|
||||
from . import types as mirai
|
||||
sys.modules['mirai'] = mirai
|
||||
|
||||
|
||||
class RuntimeBot:
|
||||
"""运行时机器人"""
|
||||
|
||||
ap: app.Application
|
||||
|
||||
bot_entity: persistence_bot.Bot
|
||||
|
||||
enable: bool
|
||||
|
||||
adapter: msadapter.MessagePlatformAdapter
|
||||
|
||||
task_wrapper: taskmgr.TaskWrapper
|
||||
|
||||
task_context: taskmgr.TaskContext
|
||||
|
||||
def __init__(self, ap: app.Application, bot_entity: persistence_bot.Bot, adapter: msadapter.MessagePlatformAdapter):
|
||||
self.ap = ap
|
||||
self.bot_entity = bot_entity
|
||||
self.enable = bot_entity.enable
|
||||
self.adapter = adapter
|
||||
self.task_context = taskmgr.TaskContext()
|
||||
|
||||
async def run(self):
|
||||
|
||||
async def exception_wrapper():
|
||||
try:
|
||||
self.task_context.set_current_action('Running...')
|
||||
await self.adapter.run_async()
|
||||
self.task_context.set_current_action('Exited.')
|
||||
except Exception as e:
|
||||
if isinstance(e, asyncio.CancelledError):
|
||||
self.task_context.set_current_action('Exited.')
|
||||
return
|
||||
self.task_context.set_current_action('Exited with error.')
|
||||
self.task_context.log(f'平台适配器运行出错: {e}')
|
||||
self.task_context.log(f"Traceback: {traceback.format_exc()}")
|
||||
self.ap.logger.error(f'平台适配器运行出错: {e}')
|
||||
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
|
||||
|
||||
self.task_wrapper = self.ap.task_mgr.create_task(
|
||||
exception_wrapper(),
|
||||
kind="platform-adapter",
|
||||
name=f"platform-adapter-{self.adapter.__class__.__name__}",
|
||||
context=self.task_context,
|
||||
scopes=[core_entities.LifecycleControlScope.APPLICATION, core_entities.LifecycleControlScope.PLATFORM]
|
||||
)
|
||||
|
||||
async def shutdown(self):
|
||||
await self.adapter.kill()
|
||||
|
||||
|
||||
# 控制QQ消息输入输出的类
|
||||
class PlatformManager:
|
||||
|
||||
@@ -33,22 +87,55 @@ class PlatformManager:
|
||||
|
||||
message_platform_adapter_components: list[engine.Component] = []
|
||||
|
||||
# modern
|
||||
# ====== 4.0 ======
|
||||
ap: app.Application = None
|
||||
|
||||
bots: list[RuntimeBot]
|
||||
|
||||
adapter_components: list[engine.Component]
|
||||
|
||||
adapter_dict: dict[str, type[msadapter.MessagePlatformAdapter]]
|
||||
|
||||
def __init__(self, ap: app.Application = None):
|
||||
|
||||
self.ap = ap
|
||||
self.adapters = []
|
||||
self.bots = []
|
||||
self.adapter_components = []
|
||||
self.adapter_dict = {}
|
||||
|
||||
async def initialize(self):
|
||||
|
||||
components = self.ap.discover.get_components_by_kind('MessagePlatformAdapter')
|
||||
self.adapter_components = self.ap.discover.get_components_by_kind('MessagePlatformAdapter')
|
||||
adapter_dict: dict[str, type[msadapter.MessagePlatformAdapter]] = {}
|
||||
for component in self.adapter_components:
|
||||
adapter_dict[component.metadata.name] = component.get_python_component_class()
|
||||
self.adapter_dict = adapter_dict
|
||||
|
||||
self.message_platform_adapter_components = components
|
||||
await self.load_bots_from_db()
|
||||
|
||||
# from .sources import nakuru, aiocqhttp, qqbotpy, qqofficial, wecom, lark, discord, gewechat, officialaccount, telegram, dingtalk
|
||||
async def load_bots_from_db(self):
|
||||
self.ap.logger.info('Loading bots from db...')
|
||||
|
||||
self.bots = []
|
||||
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_bot.Bot)
|
||||
)
|
||||
|
||||
bots = result.all()
|
||||
|
||||
for bot in bots:
|
||||
# load all bots here, enable or disable will be handled in runtime
|
||||
await self.load_bot(bot)
|
||||
|
||||
async def load_bot(self, bot_entity: persistence_bot.Bot | sqlalchemy.Row[persistence_bot.Bot] | dict) -> RuntimeBot:
|
||||
"""加载机器人"""
|
||||
if isinstance(bot_entity, sqlalchemy.Row):
|
||||
bot_entity = persistence_bot.Bot(**bot_entity._mapping)
|
||||
elif isinstance(bot_entity, dict):
|
||||
bot_entity = persistence_bot.Bot(**bot_entity)
|
||||
|
||||
async def on_friend_message(event: platform_events.FriendMessage, adapter: msadapter.MessagePlatformAdapter):
|
||||
|
||||
await self.ap.query_pool.add_query(
|
||||
@@ -70,45 +157,44 @@ class PlatformManager:
|
||||
message_chain=event.message_chain,
|
||||
adapter=adapter
|
||||
)
|
||||
|
||||
index = 0
|
||||
|
||||
for adap_cfg in self.ap.platform_cfg.data['platform-adapters']:
|
||||
if adap_cfg['enable']:
|
||||
self.ap.logger.info(f'初始化平台适配器 {index}: {adap_cfg["adapter"]}')
|
||||
index += 1
|
||||
cfg_copy = adap_cfg.copy()
|
||||
del cfg_copy['enable']
|
||||
adapter_name = cfg_copy['adapter']
|
||||
del cfg_copy['adapter']
|
||||
adapter_inst = self.adapter_dict[bot_entity.adapter](
|
||||
bot_entity.adapter_config,
|
||||
self.ap
|
||||
)
|
||||
|
||||
found = False
|
||||
adapter_inst.register_listener(
|
||||
platform_events.FriendMessage,
|
||||
on_friend_message
|
||||
)
|
||||
adapter_inst.register_listener(
|
||||
platform_events.GroupMessage,
|
||||
on_group_message
|
||||
)
|
||||
|
||||
for adapter in self.message_platform_adapter_components:
|
||||
if adapter.metadata.name == adapter_name:
|
||||
found = True
|
||||
adapter_cls = adapter.get_python_component_class()
|
||||
|
||||
adapter_inst = adapter_cls(
|
||||
cfg_copy,
|
||||
self.ap
|
||||
)
|
||||
self.adapters.append(adapter_inst)
|
||||
runtime_bot = RuntimeBot(
|
||||
ap=self.ap,
|
||||
bot_entity=bot_entity,
|
||||
adapter=adapter_inst
|
||||
)
|
||||
|
||||
adapter_inst.register_listener(
|
||||
platform_events.FriendMessage,
|
||||
on_friend_message
|
||||
)
|
||||
adapter_inst.register_listener(
|
||||
platform_events.GroupMessage,
|
||||
on_group_message
|
||||
)
|
||||
|
||||
if not found:
|
||||
raise Exception('platform.json 中启用了未知的平台适配器: ' + adapter_name)
|
||||
|
||||
if len(self.adapters) == 0:
|
||||
self.ap.logger.warning('未运行平台适配器,请根据文档配置并启用平台适配器。')
|
||||
self.bots.append(runtime_bot)
|
||||
|
||||
return runtime_bot
|
||||
|
||||
async def get_bot_by_uuid(self, bot_uuid: str) -> RuntimeBot | None:
|
||||
for bot in self.bots:
|
||||
if bot.bot_entity.uuid == bot_uuid:
|
||||
return bot
|
||||
return None
|
||||
|
||||
async def remove_bot(self, bot_uuid: str):
|
||||
for bot in self.bots:
|
||||
if bot.bot_entity.uuid == bot_uuid:
|
||||
if bot.enable:
|
||||
await bot.shutdown()
|
||||
self.bots.remove(bot)
|
||||
return
|
||||
|
||||
def get_available_adapters_info(self) -> list[dict]:
|
||||
return [
|
||||
@@ -168,35 +254,14 @@ class PlatformManager:
|
||||
quote_origin=True if self.ap.platform_cfg.data['quote-origin'] else False
|
||||
)
|
||||
|
||||
async def run(self):
|
||||
try:
|
||||
tasks = []
|
||||
for adapter in self.adapters:
|
||||
async def exception_wrapper(adapter: msadapter.MessagePlatformAdapter):
|
||||
try:
|
||||
await adapter.run_async()
|
||||
except Exception as e:
|
||||
if isinstance(e, asyncio.CancelledError):
|
||||
return
|
||||
self.ap.logger.error('平台适配器运行出错: ' + str(e))
|
||||
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
|
||||
async def run(self):
|
||||
# This method will only be called when the application launching
|
||||
for bot in self.bots:
|
||||
if bot.enable:
|
||||
await bot.run()
|
||||
|
||||
tasks.append(exception_wrapper(adapter))
|
||||
|
||||
|
||||
for task in tasks:
|
||||
self.ap.task_mgr.create_task(
|
||||
task,
|
||||
kind="platform-adapter",
|
||||
name=f"platform-adapter-{adapter.__class__.__name__}",
|
||||
scopes=[core_entities.LifecycleControlScope.APPLICATION, core_entities.LifecycleControlScope.PLATFORM],
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.ap.logger.error('平台适配器运行出错: ' + str(e))
|
||||
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
|
||||
|
||||
async def shutdown(self):
|
||||
for adapter in self.adapters:
|
||||
await adapter.kill()
|
||||
for bot in self.bots:
|
||||
if bot.enable:
|
||||
await bot.shutdown()
|
||||
self.ap.task_mgr.cancel_by_scope(core_entities.LifecycleControlScope.PLATFORM)
|
||||
Reference in New Issue
Block a user