mirror of
https://github.com/langbot-app/LangBot.git
synced 2025-11-26 03:44:58 +08:00
feat: adapt more events
This commit is contained in:
@@ -66,7 +66,7 @@ class ContentFilterStage(stage.PipelineStage):
|
||||
if query.pipeline_config['safety']['content-filter']['scope'] == 'output-msg':
|
||||
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
if not message.strip():
|
||||
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
else:
|
||||
for filter in self.filter_chain:
|
||||
if filter_entities.EnableStage.PRE in filter.enable_stages:
|
||||
@@ -85,7 +85,7 @@ class ContentFilterStage(stage.PipelineStage):
|
||||
elif result.level == filter_entities.ResultLevel.PASS: # 传到下一个
|
||||
message = result.replacement
|
||||
|
||||
query.message_chain = platform_message.MessageChain(platform_message.Plain(text=message))
|
||||
query.message_chain = platform_message.MessageChain([platform_message.Plain(text=message)])
|
||||
|
||||
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
|
||||
|
||||
@@ -35,9 +35,9 @@ class Controller:
|
||||
session = await self.ap.sess_mgr.get_session(query)
|
||||
self.ap.logger.debug(f'Checking query {query} session {session}')
|
||||
|
||||
if not session.semaphore.locked():
|
||||
if not session._semaphore.locked():
|
||||
selected_query = query
|
||||
await session.semaphore.acquire()
|
||||
await session._semaphore.acquire()
|
||||
|
||||
break
|
||||
|
||||
@@ -62,7 +62,7 @@ class Controller:
|
||||
await pipeline.run(selected_query)
|
||||
|
||||
async with self.ap.query_pool:
|
||||
(await self.ap.sess_mgr.get_session(selected_query)).semaphore.release()
|
||||
(await self.ap.sess_mgr.get_session(selected_query))._semaphore.release()
|
||||
# 通知其他协程,有新的请求可以处理了
|
||||
self.ap.query_pool.condition.notify_all()
|
||||
|
||||
|
||||
@@ -90,7 +90,7 @@ class RuntimePipeline:
|
||||
# 处理str类型
|
||||
|
||||
if isinstance(result.user_notice, str):
|
||||
result.user_notice = platform_message.MessageChain(platform_message.Plain(result.user_notice))
|
||||
result.user_notice = platform_message.MessageChain([platform_message.Plain(text=result.user_notice)])
|
||||
elif isinstance(result.user_notice, list):
|
||||
result.user_notice = platform_message.MessageChain(*result.user_notice)
|
||||
|
||||
|
||||
@@ -4,9 +4,10 @@ import datetime
|
||||
|
||||
from .. import stage, entities
|
||||
from langbot_plugin.api.entities.builtin.provider import message as provider_message
|
||||
from ...plugin import events
|
||||
import langbot_plugin.api.entities.events as events
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
import langbot_plugin.api.entities.context as event_context
|
||||
|
||||
|
||||
@stage.stage_class('PreProcessor')
|
||||
@@ -108,7 +109,7 @@ class PreProcessor(stage.PipelineStage):
|
||||
query.user_message = provider_message.Message(role='user', content=content_list)
|
||||
# =========== 触发事件 PromptPreProcessing
|
||||
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event_ctx = event_context.EventContext(
|
||||
event=events.PromptPreProcessing(
|
||||
session_name=f'{query.session.launcher_type.value}_{query.session.launcher_id}',
|
||||
default_prompt=query.prompt.messages,
|
||||
@@ -117,6 +118,12 @@ class PreProcessor(stage.PipelineStage):
|
||||
)
|
||||
)
|
||||
|
||||
event_ctx_result = await self.ap.plugin_connector.handler.emit_event(
|
||||
event_ctx.model_dump(serialize_as_any=True)
|
||||
)
|
||||
|
||||
event_ctx = event_context.EventContext.parse_from_dict(event_ctx_result['event_context'])
|
||||
|
||||
query.prompt.messages = event_ctx.event.default_prompt
|
||||
query.messages = event_ctx.event.prompt
|
||||
|
||||
|
||||
@@ -7,13 +7,14 @@ import traceback
|
||||
from .. import handler
|
||||
from ... import entities
|
||||
from ....provider import runner as runner_module
|
||||
from ....plugin import events
|
||||
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.events as events
|
||||
from ....utils import importutil
|
||||
from ....provider import runners
|
||||
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
import langbot_plugin.api.entities.context as event_context
|
||||
|
||||
|
||||
importutil.import_modules_in_pkg(runners)
|
||||
@@ -35,7 +36,7 @@ class ChatMessageHandler(handler.MessageHandler):
|
||||
else events.GroupNormalMessageReceived
|
||||
)
|
||||
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event_ctx = event_context.EventContext(
|
||||
event=event_class(
|
||||
launcher_type=query.launcher_type.value,
|
||||
launcher_id=query.launcher_id,
|
||||
@@ -45,6 +46,12 @@ class ChatMessageHandler(handler.MessageHandler):
|
||||
)
|
||||
)
|
||||
|
||||
event_ctx_result = await self.ap.plugin_connector.handler.emit_event(
|
||||
event_ctx.model_dump(serialize_as_any=True)
|
||||
)
|
||||
|
||||
event_ctx = event_context.EventContext.parse_from_dict(event_ctx_result['event_context'])
|
||||
|
||||
if event_ctx.is_prevented_default():
|
||||
if event_ctx.event.reply is not None:
|
||||
mc = platform_message.MessageChain(event_ctx.event.reply)
|
||||
|
||||
@@ -4,9 +4,11 @@ import typing
|
||||
|
||||
from .. import entities
|
||||
from .. import stage
|
||||
from ...plugin import events
|
||||
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
import langbot_plugin.api.entities.context as event_context
|
||||
import langbot_plugin.api.entities.events as events
|
||||
|
||||
|
||||
@stage.stage_class('ResponseWrapper')
|
||||
@@ -57,7 +59,7 @@ class ResponseWrapper(stage.PipelineStage):
|
||||
reply_text = str(result.get_content_platform_message_chain())
|
||||
|
||||
# ============= 触发插件事件 ===============
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event_ctx = event_context.EventContext(
|
||||
event=events.NormalMessageResponded(
|
||||
launcher_type=query.launcher_type.value,
|
||||
launcher_id=query.launcher_id,
|
||||
@@ -72,6 +74,13 @@ class ResponseWrapper(stage.PipelineStage):
|
||||
query=query,
|
||||
)
|
||||
)
|
||||
|
||||
serialized_event_ctx = event_ctx.model_dump(serialize_as_any=True)
|
||||
|
||||
event_ctx_result = await self.ap.plugin_connector.handler.emit_event(serialized_event_ctx)
|
||||
|
||||
event_ctx = event_context.EventContext.parse_from_dict(event_ctx_result['event_context'])
|
||||
|
||||
if event_ctx.is_prevented_default():
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT,
|
||||
|
||||
@@ -96,7 +96,7 @@ class Message(pydantic.BaseModel):
|
||||
if self.content is None:
|
||||
return None
|
||||
elif isinstance(self.content, str):
|
||||
return platform_message.MessageChain([platform_message.Plain(prefix_text + self.content)])
|
||||
return platform_message.MessageChain([platform_message.Plain(text=(prefix_text + self.content))])
|
||||
elif isinstance(self.content, list):
|
||||
mc = []
|
||||
for ce in self.content:
|
||||
|
||||
@@ -33,8 +33,8 @@ class SessionManager:
|
||||
session = provider_session.Session(
|
||||
launcher_type=query.launcher_type,
|
||||
launcher_id=query.launcher_id,
|
||||
semaphore=asyncio.Semaphore(session_concurrency),
|
||||
)
|
||||
session._semaphore = asyncio.Semaphore(session_concurrency)
|
||||
self.session_list.append(session)
|
||||
return session
|
||||
|
||||
|
||||
Reference in New Issue
Block a user