feat: adapt more events

This commit is contained in:
Junyan Qin
2025-07-02 11:04:03 +08:00
parent c246470b37
commit ee3da8aa17
8 changed files with 37 additions and 14 deletions

View File

@@ -66,7 +66,7 @@ class ContentFilterStage(stage.PipelineStage):
if query.pipeline_config['safety']['content-filter']['scope'] == 'output-msg':
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
if not message.strip():
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
else:
for filter in self.filter_chain:
if filter_entities.EnableStage.PRE in filter.enable_stages:
@@ -85,7 +85,7 @@ class ContentFilterStage(stage.PipelineStage):
elif result.level == filter_entities.ResultLevel.PASS: # 传到下一个
message = result.replacement
query.message_chain = platform_message.MessageChain(platform_message.Plain(text=message))
query.message_chain = platform_message.MessageChain([platform_message.Plain(text=message)])
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)

View File

@@ -35,9 +35,9 @@ class Controller:
session = await self.ap.sess_mgr.get_session(query)
self.ap.logger.debug(f'Checking query {query} session {session}')
if not session.semaphore.locked():
if not session._semaphore.locked():
selected_query = query
await session.semaphore.acquire()
await session._semaphore.acquire()
break
@@ -62,7 +62,7 @@ class Controller:
await pipeline.run(selected_query)
async with self.ap.query_pool:
(await self.ap.sess_mgr.get_session(selected_query)).semaphore.release()
(await self.ap.sess_mgr.get_session(selected_query))._semaphore.release()
# 通知其他协程,有新的请求可以处理了
self.ap.query_pool.condition.notify_all()

View File

@@ -90,7 +90,7 @@ class RuntimePipeline:
# 处理str类型
if isinstance(result.user_notice, str):
result.user_notice = platform_message.MessageChain(platform_message.Plain(result.user_notice))
result.user_notice = platform_message.MessageChain([platform_message.Plain(text=result.user_notice)])
elif isinstance(result.user_notice, list):
result.user_notice = platform_message.MessageChain(*result.user_notice)

View File

@@ -4,9 +4,10 @@ import datetime
from .. import stage, entities
from langbot_plugin.api.entities.builtin.provider import message as provider_message
from ...plugin import events
import langbot_plugin.api.entities.events as events
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.context as event_context
@stage.stage_class('PreProcessor')
@@ -108,7 +109,7 @@ class PreProcessor(stage.PipelineStage):
query.user_message = provider_message.Message(role='user', content=content_list)
# =========== 触发事件 PromptPreProcessing
event_ctx = await self.ap.plugin_mgr.emit_event(
event_ctx = event_context.EventContext(
event=events.PromptPreProcessing(
session_name=f'{query.session.launcher_type.value}_{query.session.launcher_id}',
default_prompt=query.prompt.messages,
@@ -117,6 +118,12 @@ class PreProcessor(stage.PipelineStage):
)
)
event_ctx_result = await self.ap.plugin_connector.handler.emit_event(
event_ctx.model_dump(serialize_as_any=True)
)
event_ctx = event_context.EventContext.parse_from_dict(event_ctx_result['event_context'])
query.prompt.messages = event_ctx.event.default_prompt
query.messages = event_ctx.event.prompt

View File

@@ -7,13 +7,14 @@ import traceback
from .. import handler
from ... import entities
from ....provider import runner as runner_module
from ....plugin import events
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.events as events
from ....utils import importutil
from ....provider import runners
import langbot_plugin.api.entities.builtin.provider.session as provider_session
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.context as event_context
importutil.import_modules_in_pkg(runners)
@@ -35,7 +36,7 @@ class ChatMessageHandler(handler.MessageHandler):
else events.GroupNormalMessageReceived
)
event_ctx = await self.ap.plugin_mgr.emit_event(
event_ctx = event_context.EventContext(
event=event_class(
launcher_type=query.launcher_type.value,
launcher_id=query.launcher_id,
@@ -45,6 +46,12 @@ class ChatMessageHandler(handler.MessageHandler):
)
)
event_ctx_result = await self.ap.plugin_connector.handler.emit_event(
event_ctx.model_dump(serialize_as_any=True)
)
event_ctx = event_context.EventContext.parse_from_dict(event_ctx_result['event_context'])
if event_ctx.is_prevented_default():
if event_ctx.event.reply is not None:
mc = platform_message.MessageChain(event_ctx.event.reply)

View File

@@ -4,9 +4,11 @@ import typing
from .. import entities
from .. import stage
from ...plugin import events
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.context as event_context
import langbot_plugin.api.entities.events as events
@stage.stage_class('ResponseWrapper')
@@ -57,7 +59,7 @@ class ResponseWrapper(stage.PipelineStage):
reply_text = str(result.get_content_platform_message_chain())
# ============= 触发插件事件 ===============
event_ctx = await self.ap.plugin_mgr.emit_event(
event_ctx = event_context.EventContext(
event=events.NormalMessageResponded(
launcher_type=query.launcher_type.value,
launcher_id=query.launcher_id,
@@ -72,6 +74,13 @@ class ResponseWrapper(stage.PipelineStage):
query=query,
)
)
serialized_event_ctx = event_ctx.model_dump(serialize_as_any=True)
event_ctx_result = await self.ap.plugin_connector.handler.emit_event(serialized_event_ctx)
event_ctx = event_context.EventContext.parse_from_dict(event_ctx_result['event_context'])
if event_ctx.is_prevented_default():
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,

View File

@@ -96,7 +96,7 @@ class Message(pydantic.BaseModel):
if self.content is None:
return None
elif isinstance(self.content, str):
return platform_message.MessageChain([platform_message.Plain(prefix_text + self.content)])
return platform_message.MessageChain([platform_message.Plain(text=(prefix_text + self.content))])
elif isinstance(self.content, list):
mc = []
for ce in self.content:

View File

@@ -33,8 +33,8 @@ class SessionManager:
session = provider_session.Session(
launcher_type=query.launcher_type,
launcher_id=query.launcher_id,
semaphore=asyncio.Semaphore(session_concurrency),
)
session._semaphore = asyncio.Semaphore(session_concurrency)
self.session_list.append(session)
return session