diff --git a/pkg/persistence/migrations/dbm002_combine_quote_msg_config.py b/pkg/persistence/migrations/dbm002_combine_quote_msg_config.py new file mode 100644 index 00000000..2bb665ea --- /dev/null +++ b/pkg/persistence/migrations/dbm002_combine_quote_msg_config.py @@ -0,0 +1,36 @@ +from .. import migration + +import sqlalchemy + +from ...entity.persistence import pipeline as persistence_pipeline + + +@migration.migration_class(2) +class DBMigrateCombineQuoteMsgConfig(migration.DBMigration): + """引用消息合并配置""" + + async def upgrade(self): + """升级""" + # read all pipelines + pipelines = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_pipeline.LegacyPipeline)) + + for pipeline in pipelines: + serialized_pipeline = self.ap.persistence_mgr.serialize_model(persistence_pipeline.LegacyPipeline, pipeline) + + config = serialized_pipeline['config'] + + if 'misc' not in config['trigger']: + config['trigger']['misc'] = {} + + if 'combine-quote-message' not in config['trigger']['misc']: + config['trigger']['misc']['combine-quote-message'] = False + + await self.ap.persistence_mgr.execute_async( + sqlalchemy.update(persistence_pipeline.LegacyPipeline) + .where(persistence_pipeline.LegacyPipeline.uuid == serialized_pipeline['uuid']) + .values({'config': config, 'for_version': self.ap.ver_mgr.get_current_version()}) + ) + + async def downgrade(self): + """降级""" + pass diff --git a/pkg/pipeline/preproc/preproc.py b/pkg/pipeline/preproc/preproc.py index 6dd909be..c71fad78 100644 --- a/pkg/pipeline/preproc/preproc.py +++ b/pkg/pipeline/preproc/preproc.py @@ -34,6 +34,7 @@ class PreProcessor(stage.PipelineStage): session = await self.ap.sess_mgr.get_session(query) + # 非 local-agent 时,llm_model 为 None llm_model = ( await self.ap.model_mgr.get_model_by_uuid(query.pipeline_config['ai']['local-agent']['model']) @@ -81,6 +82,7 @@ class PreProcessor(stage.PipelineStage): content_list = [] plain_text = '' + qoute_msg = query.pipeline_config["trigger"].get("misc",'').get("combine-quote-message") for me in query.message_chain: if isinstance(me, platform_message.Plain): @@ -92,6 +94,18 @@ class PreProcessor(stage.PipelineStage): ): if me.base64 is not None: content_list.append(llm_entities.ContentElement.from_image_base64(me.base64)) + elif isinstance(me, platform_message.Quote) and qoute_msg: + for msg in me.origin: + if isinstance(msg, platform_message.Plain): + content_list.append(llm_entities.ContentElement.from_text(msg.text)) + elif isinstance(msg, platform_message.Image): + if selected_runner != 'local-agent' or query.use_llm_model.model_entity.abilities.__contains__( + 'vision' + ): + if msg.base64 is not None: + content_list.append(llm_entities.ContentElement.from_image_base64(msg.base64)) + + query.variables['user_message_text'] = plain_text diff --git a/pkg/platform/sources/aiocqhttp.py b/pkg/platform/sources/aiocqhttp.py index bee97f57..b942b8e9 100644 --- a/pkg/platform/sources/aiocqhttp.py +++ b/pkg/platform/sources/aiocqhttp.py @@ -15,6 +15,7 @@ from ...utils import image class AiocqhttpMessageConverter(adapter.MessageConverter): + @staticmethod async def yiri2target( message_chain: platform_message.MessageChain, @@ -66,14 +67,40 @@ class AiocqhttpMessageConverter(adapter.MessageConverter): return msg_list, msg_id, msg_time @staticmethod - async def target2yiri(message: str, message_id: int = -1): + async def target2yiri(message: str, message_id: int = -1,bot=None): message = aiocqhttp.Message(message) + async def process_message_data(msg_data, reply_list): + if msg_data["type"] == "image": + image_base64, image_format = await image.qq_image_url_to_base64(msg_data["data"]['url']) + reply_list.append( + platform_message.Image(base64=f'data:image/{image_format};base64,{image_base64}')) + + elif msg_data["type"] == "text": + reply_list.append(platform_message.Plain(text=msg_data["data"]["text"])) + + elif msg_data["type"] == "forward": # 这里来应该传入转发消息组,暂时传入qoute + for forward_msg_datas in msg_data["data"]["content"]: + for forward_msg_data in forward_msg_datas["message"]: + await process_message_data(forward_msg_data, reply_list) + + elif msg_data["type"] == "at": + if msg_data["data"]['qq'] == 'all': + reply_list.append(platform_message.AtAll()) + else: + reply_list.append( + platform_message.At( + target=msg_data["data"]['qq'], + ) + ) + + yiri_msg_list = [] yiri_msg_list.append(platform_message.Source(id=message_id, time=datetime.datetime.now())) for msg in message: + reply_list = [] if msg.type == 'at': if msg.data['qq'] == 'all': yiri_msg_list.append(platform_message.AtAll()) @@ -88,20 +115,46 @@ class AiocqhttpMessageConverter(adapter.MessageConverter): elif msg.type == 'image': image_base64, image_format = await image.qq_image_url_to_base64(msg.data['url']) yiri_msg_list.append(platform_message.Image(base64=f'data:image/{image_format};base64,{image_base64}')) + elif msg.type == 'forward': + # 暂时不太合理 + # msg_datas = await bot.get_msg(message_id=message_id) + # print(msg_datas) + # for msg_data in msg_datas["message"]: + # await process_message_data(msg_data, yiri_msg_list) + pass + + + elif msg.type == 'reply': # 此处处理引用消息传入Qoute + msg_datas = await bot.get_msg(message_id=msg.data["id"]) + + for msg_data in msg_datas["message"]: + await process_message_data(msg_data, reply_list) + + reply_msg = platform_message.Quote(message_id=msg.data["id"],sender_id=msg_datas["user_id"],origin=reply_list) + yiri_msg_list.append(reply_msg) + + + + + + chain = platform_message.MessageChain(yiri_msg_list) return chain + + class AiocqhttpEventConverter(adapter.EventConverter): @staticmethod async def yiri2target(event: platform_events.MessageEvent, bot_account_id: int): return event.source_platform_object @staticmethod - async def target2yiri(event: aiocqhttp.Event): - yiri_chain = await AiocqhttpMessageConverter.target2yiri(event.message, event.message_id) + async def target2yiri(event: aiocqhttp.Event,bot=None): + yiri_chain = await AiocqhttpMessageConverter.target2yiri(event.message, event.message_id,bot) + if event.message_type == 'group': permission = 'MEMBER' @@ -205,7 +258,7 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter): async def on_message(event: aiocqhttp.Event): self.bot_account_id = event.self_id try: - return await callback(await self.event_converter.target2yiri(event), self) + return await callback(await self.event_converter.target2yiri(event,self.bot), self) except Exception: traceback.print_exc() diff --git a/pkg/utils/constants.py b/pkg/utils/constants.py index 89eb4ef3..ccf2f8c5 100644 --- a/pkg/utils/constants.py +++ b/pkg/utils/constants.py @@ -1,6 +1,6 @@ semantic_version = 'v4.0.3.1' -required_database_version = 1 +required_database_version = 2 """标记本版本所需要的数据库结构版本,用于判断数据库迁移""" debug_mode = False diff --git a/templates/metadata/pipeline/trigger.yaml b/templates/metadata/pipeline/trigger.yaml index cb60f448..57b7a6a1 100644 --- a/templates/metadata/pipeline/trigger.yaml +++ b/templates/metadata/pipeline/trigger.yaml @@ -117,3 +117,18 @@ stages: type: array[string] required: true default: [] + - name: misc + label: + en_US: Misc + zh_Hans: 杂项 + config: + - name: combine-quote-message + label: + en_US: Combine Quote Message + zh_Hans: 合并引用消息 + description: + en_US: If enabled, the bot will combine the quote message with the user's message + zh_Hans: 如果启用,将合并引用消息与用户发送的消息 + type: boolean + required: true + default: true