mirror of
https://github.com/langbot-app/LangBot.git
synced 2025-11-26 03:44:58 +08:00
feat: bind pipeline with runtime manager
This commit is contained in:
@@ -56,7 +56,10 @@ class PipelineService:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.insert(persistence_pipeline.LegacyPipeline).values(**pipeline_data)
|
||||
)
|
||||
# TODO: 更新到pipeline manager
|
||||
|
||||
pipeline = await self.get_pipeline(pipeline_data['uuid'])
|
||||
|
||||
await self.ap.pipeline_mgr.load_pipeline(pipeline)
|
||||
|
||||
return pipeline_data['uuid']
|
||||
|
||||
@@ -67,10 +70,15 @@ class PipelineService:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(persistence_pipeline.LegacyPipeline).where(persistence_pipeline.LegacyPipeline.uuid == pipeline_uuid).values(**pipeline_data)
|
||||
)
|
||||
# TODO: 更新到pipeline manager
|
||||
|
||||
await self.ap.pipeline_mgr.remove_pipeline(pipeline_uuid)
|
||||
|
||||
pipeline = await self.get_pipeline(pipeline_uuid)
|
||||
|
||||
await self.ap.pipeline_mgr.load_pipeline(pipeline)
|
||||
|
||||
async def delete_pipeline(self, pipeline_uuid: str) -> None:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.delete(persistence_pipeline.LegacyPipeline).where(persistence_pipeline.LegacyPipeline.uuid == pipeline_uuid)
|
||||
)
|
||||
# TODO: 更新到pipeline manager
|
||||
await self.ap.pipeline_mgr.remove_pipeline(pipeline_uuid)
|
||||
|
||||
@@ -20,7 +20,7 @@ from ..audit.center import v2 as center_mgr
|
||||
from ..command import cmdmgr
|
||||
from ..plugin import manager as plugin_mgr
|
||||
from ..pipeline import pool
|
||||
from ..pipeline import controller, stagemgr
|
||||
from ..pipeline import controller, stagemgr, pipelinemgr
|
||||
from ..utils import version as version_mgr, proxy as proxy_mgr, announce as announce_mgr
|
||||
from ..persistence import mgr as persistencemgr
|
||||
from ..api.http.controller import main as http_controller
|
||||
@@ -102,6 +102,8 @@ class Application:
|
||||
|
||||
stage_mgr: stagemgr.StageManager = None
|
||||
|
||||
pipeline_mgr: pipelinemgr.PipelineManager = None
|
||||
|
||||
ver_mgr: version_mgr.VersionManager = None
|
||||
|
||||
ann_mgr: announce_mgr.AnnouncementManager = None
|
||||
|
||||
@@ -6,7 +6,7 @@ from .. import stage, app
|
||||
from ...utils import version, proxy, announce, platform
|
||||
from ...audit.center import v2 as center_v2
|
||||
from ...audit import identifier
|
||||
from ...pipeline import pool, controller, stagemgr
|
||||
from ...pipeline import pool, controller, stagemgr, pipelinemgr
|
||||
from ...plugin import manager as plugin_mgr
|
||||
from ...command import cmdmgr
|
||||
from ...provider.session import sessionmgr as llm_session_mgr
|
||||
@@ -119,6 +119,10 @@ class BuildAppStage(stage.BootingStage):
|
||||
await stage_mgr.initialize()
|
||||
ap.stage_mgr = stage_mgr
|
||||
|
||||
pipeline_mgr = pipelinemgr.PipelineManager(ap)
|
||||
await pipeline_mgr.initialize()
|
||||
ap.pipeline_mgr = pipeline_mgr
|
||||
|
||||
http_ctrl = http_controller.HTTPController(ap)
|
||||
await http_ctrl.initialize()
|
||||
ap.http_ctrl = http_ctrl
|
||||
|
||||
93
pkg/pipeline/pipelinemgr.py
Normal file
93
pkg/pipeline/pipelinemgr.py
Normal file
@@ -0,0 +1,93 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
import sqlalchemy
|
||||
|
||||
from ..core import app, entities
|
||||
from ..entity.persistence import pipeline as persistence_pipeline
|
||||
from . import stagemgr, stage
|
||||
|
||||
|
||||
class RuntimePipeline:
|
||||
"""运行时流水线"""
|
||||
|
||||
ap: app.Application
|
||||
|
||||
pipeline_entity: persistence_pipeline.LegacyPipeline
|
||||
"""流水线实体"""
|
||||
|
||||
stage_containers: list[stagemgr.StageInstContainer]
|
||||
"""阶段实例容器"""
|
||||
|
||||
def __init__(self, ap: app.Application, pipeline_entity: persistence_pipeline.LegacyPipeline, stage_containers: list[stagemgr.StageInstContainer]):
|
||||
self.ap = ap
|
||||
self.pipeline_entity = pipeline_entity
|
||||
self.stage_containers = stage_containers
|
||||
|
||||
async def run(self):
|
||||
pass
|
||||
|
||||
|
||||
class PipelineManager:
|
||||
"""流水线管理器"""
|
||||
|
||||
# ====== 4.0 ======
|
||||
|
||||
ap: app.Application
|
||||
|
||||
pipelines: list[RuntimePipeline]
|
||||
|
||||
stage_dict: dict[str, type[stage.PipelineStage]]
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
self.pipelines = []
|
||||
|
||||
async def initialize(self):
|
||||
self.stage_dict = {name: cls for name, cls in stage.preregistered_stages.items()}
|
||||
|
||||
await self.load_pipelines_from_db()
|
||||
|
||||
async def load_pipelines_from_db(self):
|
||||
self.ap.logger.info('Loading pipelines from db...')
|
||||
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_pipeline.LegacyPipeline)
|
||||
)
|
||||
|
||||
pipelines = result.all()
|
||||
|
||||
# load pipelines
|
||||
for pipeline in pipelines:
|
||||
await self.load_pipeline(pipeline)
|
||||
|
||||
async def load_pipeline(self, pipeline_entity: persistence_pipeline.LegacyPipeline | sqlalchemy.Row[persistence_pipeline.LegacyPipeline] | dict):
|
||||
|
||||
if isinstance(pipeline_entity, sqlalchemy.Row):
|
||||
pipeline_entity = persistence_pipeline.LegacyPipeline(**pipeline_entity._mapping)
|
||||
elif isinstance(pipeline_entity, dict):
|
||||
pipeline_entity = persistence_pipeline.LegacyPipeline(**pipeline_entity)
|
||||
|
||||
# initialize stage containers according to pipeline_entity.stages
|
||||
stage_containers = []
|
||||
for stage_name in pipeline_entity.stages:
|
||||
stage_containers.append(stagemgr.StageInstContainer(
|
||||
stage_name=stage_name,
|
||||
stage_class=self.stage_dict[stage_name]
|
||||
))
|
||||
|
||||
runtime_pipeline = RuntimePipeline(self.ap, pipeline_entity, stage_containers)
|
||||
self.pipelines.append(runtime_pipeline)
|
||||
|
||||
async def get_pipeline_by_uuid(self, uuid: str) -> RuntimePipeline | None:
|
||||
for pipeline in self.pipelines:
|
||||
if pipeline.pipeline_entity.uuid == uuid:
|
||||
return pipeline
|
||||
return None
|
||||
|
||||
async def remove_pipeline(self, uuid: str):
|
||||
for pipeline in self.pipelines:
|
||||
if pipeline.pipeline_entity.uuid == uuid:
|
||||
self.pipelines.remove(pipeline)
|
||||
return
|
||||
@@ -7,13 +7,13 @@ from ..core import app, entities as core_entities
|
||||
from . import entities
|
||||
|
||||
|
||||
_stage_classes: dict[str, PipelineStage] = {}
|
||||
preregistered_stages: dict[str, PipelineStage] = {}
|
||||
|
||||
|
||||
def stage_class(name: str):
|
||||
|
||||
def decorator(cls):
|
||||
_stage_classes[name] = cls
|
||||
preregistered_stages[name] = cls
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
@@ -58,7 +58,7 @@ class StageManager:
|
||||
"""初始化
|
||||
"""
|
||||
|
||||
for name, cls in stage._stage_classes.items():
|
||||
for name, cls in stage.preregistered_stages.items():
|
||||
self.stage_containers.append(StageInstContainer(
|
||||
inst_name=name,
|
||||
inst=cls(self.ap)
|
||||
|
||||
@@ -72,10 +72,12 @@ class ModelManager:
|
||||
|
||||
self.requester_dict = requester_dict
|
||||
|
||||
await self.load_model_from_db()
|
||||
await self.load_models_from_db()
|
||||
|
||||
async def load_model_from_db(self):
|
||||
async def load_models_from_db(self):
|
||||
"""从数据库加载模型"""
|
||||
self.ap.logger.info('Loading models from db...')
|
||||
|
||||
self.llm_models = []
|
||||
|
||||
# llm models
|
||||
|
||||
Reference in New Issue
Block a user