feat: bind pipeline with runtime manager

This commit is contained in:
Junyan Qin
2025-03-28 15:55:03 +08:00
parent 5379e4cf27
commit 7cd03b0243
7 changed files with 119 additions and 10 deletions

View File

@@ -56,7 +56,10 @@ class PipelineService:
await self.ap.persistence_mgr.execute_async(
sqlalchemy.insert(persistence_pipeline.LegacyPipeline).values(**pipeline_data)
)
# TODO: 更新到pipeline manager
pipeline = await self.get_pipeline(pipeline_data['uuid'])
await self.ap.pipeline_mgr.load_pipeline(pipeline)
return pipeline_data['uuid']
@@ -67,10 +70,15 @@ class PipelineService:
await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_pipeline.LegacyPipeline).where(persistence_pipeline.LegacyPipeline.uuid == pipeline_uuid).values(**pipeline_data)
)
# TODO: 更新到pipeline manager
await self.ap.pipeline_mgr.remove_pipeline(pipeline_uuid)
pipeline = await self.get_pipeline(pipeline_uuid)
await self.ap.pipeline_mgr.load_pipeline(pipeline)
async def delete_pipeline(self, pipeline_uuid: str) -> None:
await self.ap.persistence_mgr.execute_async(
sqlalchemy.delete(persistence_pipeline.LegacyPipeline).where(persistence_pipeline.LegacyPipeline.uuid == pipeline_uuid)
)
# TODO: 更新到pipeline manager
await self.ap.pipeline_mgr.remove_pipeline(pipeline_uuid)

View File

@@ -20,7 +20,7 @@ from ..audit.center import v2 as center_mgr
from ..command import cmdmgr
from ..plugin import manager as plugin_mgr
from ..pipeline import pool
from ..pipeline import controller, stagemgr
from ..pipeline import controller, stagemgr, pipelinemgr
from ..utils import version as version_mgr, proxy as proxy_mgr, announce as announce_mgr
from ..persistence import mgr as persistencemgr
from ..api.http.controller import main as http_controller
@@ -102,6 +102,8 @@ class Application:
stage_mgr: stagemgr.StageManager = None
pipeline_mgr: pipelinemgr.PipelineManager = None
ver_mgr: version_mgr.VersionManager = None
ann_mgr: announce_mgr.AnnouncementManager = None

View File

@@ -6,7 +6,7 @@ from .. import stage, app
from ...utils import version, proxy, announce, platform
from ...audit.center import v2 as center_v2
from ...audit import identifier
from ...pipeline import pool, controller, stagemgr
from ...pipeline import pool, controller, stagemgr, pipelinemgr
from ...plugin import manager as plugin_mgr
from ...command import cmdmgr
from ...provider.session import sessionmgr as llm_session_mgr
@@ -119,6 +119,10 @@ class BuildAppStage(stage.BootingStage):
await stage_mgr.initialize()
ap.stage_mgr = stage_mgr
pipeline_mgr = pipelinemgr.PipelineManager(ap)
await pipeline_mgr.initialize()
ap.pipeline_mgr = pipeline_mgr
http_ctrl = http_controller.HTTPController(ap)
await http_ctrl.initialize()
ap.http_ctrl = http_ctrl

View File

@@ -0,0 +1,93 @@
from __future__ import annotations
import typing
import sqlalchemy
from ..core import app, entities
from ..entity.persistence import pipeline as persistence_pipeline
from . import stagemgr, stage
class RuntimePipeline:
"""运行时流水线"""
ap: app.Application
pipeline_entity: persistence_pipeline.LegacyPipeline
"""流水线实体"""
stage_containers: list[stagemgr.StageInstContainer]
"""阶段实例容器"""
def __init__(self, ap: app.Application, pipeline_entity: persistence_pipeline.LegacyPipeline, stage_containers: list[stagemgr.StageInstContainer]):
self.ap = ap
self.pipeline_entity = pipeline_entity
self.stage_containers = stage_containers
async def run(self):
pass
class PipelineManager:
"""流水线管理器"""
# ====== 4.0 ======
ap: app.Application
pipelines: list[RuntimePipeline]
stage_dict: dict[str, type[stage.PipelineStage]]
def __init__(self, ap: app.Application):
self.ap = ap
self.pipelines = []
async def initialize(self):
self.stage_dict = {name: cls for name, cls in stage.preregistered_stages.items()}
await self.load_pipelines_from_db()
async def load_pipelines_from_db(self):
self.ap.logger.info('Loading pipelines from db...')
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_pipeline.LegacyPipeline)
)
pipelines = result.all()
# load pipelines
for pipeline in pipelines:
await self.load_pipeline(pipeline)
async def load_pipeline(self, pipeline_entity: persistence_pipeline.LegacyPipeline | sqlalchemy.Row[persistence_pipeline.LegacyPipeline] | dict):
if isinstance(pipeline_entity, sqlalchemy.Row):
pipeline_entity = persistence_pipeline.LegacyPipeline(**pipeline_entity._mapping)
elif isinstance(pipeline_entity, dict):
pipeline_entity = persistence_pipeline.LegacyPipeline(**pipeline_entity)
# initialize stage containers according to pipeline_entity.stages
stage_containers = []
for stage_name in pipeline_entity.stages:
stage_containers.append(stagemgr.StageInstContainer(
stage_name=stage_name,
stage_class=self.stage_dict[stage_name]
))
runtime_pipeline = RuntimePipeline(self.ap, pipeline_entity, stage_containers)
self.pipelines.append(runtime_pipeline)
async def get_pipeline_by_uuid(self, uuid: str) -> RuntimePipeline | None:
for pipeline in self.pipelines:
if pipeline.pipeline_entity.uuid == uuid:
return pipeline
return None
async def remove_pipeline(self, uuid: str):
for pipeline in self.pipelines:
if pipeline.pipeline_entity.uuid == uuid:
self.pipelines.remove(pipeline)
return

View File

@@ -7,13 +7,13 @@ from ..core import app, entities as core_entities
from . import entities
_stage_classes: dict[str, PipelineStage] = {}
preregistered_stages: dict[str, PipelineStage] = {}
def stage_class(name: str):
def decorator(cls):
_stage_classes[name] = cls
preregistered_stages[name] = cls
return cls
return decorator

View File

@@ -58,7 +58,7 @@ class StageManager:
"""初始化
"""
for name, cls in stage._stage_classes.items():
for name, cls in stage.preregistered_stages.items():
self.stage_containers.append(StageInstContainer(
inst_name=name,
inst=cls(self.ap)

View File

@@ -72,10 +72,12 @@ class ModelManager:
self.requester_dict = requester_dict
await self.load_model_from_db()
await self.load_models_from_db()
async def load_model_from_db(self):
async def load_models_from_db(self):
"""从数据库加载模型"""
self.ap.logger.info('Loading models from db...')
self.llm_models = []
# llm models