chore: stash

This commit is contained in:
Junyan Qin
2025-10-11 19:10:56 +08:00
parent a3552893aa
commit 0f39a31648
4 changed files with 447 additions and 12 deletions

View File

@@ -59,7 +59,7 @@ class ModelManager:
try:
await self.load_llm_model(llm_model)
except provider_errors.RequesterNotFoundError as e:
self.ap.logger.warning(f'Requester {e.requester_name} not found, skipping model {llm_model.uuid}')
self.ap.logger.warning(f'Requester {e.requester_name} not found, skipping llm model {llm_model.uuid}')
except Exception as e:
self.ap.logger.error(f'Failed to load model {llm_model.uuid}: {e}\n{traceback.format_exc()}')
@@ -67,7 +67,14 @@ class ModelManager:
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.EmbeddingModel))
embedding_models = result.all()
for embedding_model in embedding_models:
await self.load_embedding_model(embedding_model)
try:
await self.load_embedding_model(embedding_model)
except provider_errors.RequesterNotFoundError as e:
self.ap.logger.warning(
f'Requester {e.requester_name} not found, skipping embedding model {embedding_model.uuid}'
)
except Exception as e:
self.ap.logger.error(f'Failed to load model {embedding_model.uuid}: {e}\n{traceback.format_exc()}')
async def init_runtime_llm_model(
self,
@@ -107,6 +114,9 @@ class ModelManager:
elif isinstance(model_info, dict):
model_info = persistence_model.EmbeddingModel(**model_info)
if model_info.requester not in self.requester_dict:
raise provider_errors.RequesterNotFoundError(model_info.requester)
requester_inst = self.requester_dict[model_info.requester](ap=self.ap, config=model_info.requester_config)
await requester_inst.initialize()

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
import typing
from contextlib import AsyncExitStack
import traceback
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
@@ -9,7 +10,9 @@ from mcp.client.sse import sse_client
from .. import loader
from ....core import app
from ....entity.persistence import mcp as persistence_mcp
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
import sqlalchemy
class RuntimeMCPSession:
@@ -27,11 +30,13 @@ class RuntimeMCPSession:
functions: list[resource_tool.LLMTool] = []
def __init__(self, server_name: str, server_config: dict, ap: app.Application):
enable: bool
def __init__(self, server_name: str, server_config: dict, enable: bool, ap: app.Application):
self.server_name = server_name
self.server_config = server_config
self.ap = ap
self.enable = enable
self.session = None
self.exit_stack = AsyncExitStack()
@@ -68,6 +73,12 @@ class RuntimeMCPSession:
await self.session.initialize()
async def initialize(self):
    """Do nothing; kept so callers can still ``await session.initialize()``.

    NOTE(review): the connection work appears to live in ``start()``, which
    is gated on ``self.enable`` -- confirm against the full file.
    """
async def start(self):
if not self.enable:
return
self.ap.logger.debug(f'初始化 MCP 会话: {self.server_name} {self.server_config}')
if self.server_config['mode'] == 'stdio':
@@ -123,13 +134,45 @@ class MCPLoader(loader.ToolLoader):
self._last_listed_functions = []
async def initialize(self):
for server_config in self.ap.instance_config.data.get('mcp', {}).get('servers', []):
if not server_config['enable']:
continue
session = RuntimeMCPSession(server_config['name'], server_config, self.ap)
await session.initialize()
# self.ap.event_loop.create_task(session.initialize())
self.sessions[server_config['name']] = session
await self.load_mcp_servers_from_db()
async def load_mcp_servers_from_db(self):
    """Load every MCP server stored in the database into ``self.sessions``.

    Selects all ``MCPServer`` rows through the persistence manager and hands
    each one to ``load_mcp_server``. A failure while loading one server is
    logged (message plus traceback) and does not stop the remaining servers
    from being loaded.
    """
    self.ap.logger.info('Loading MCP servers from db...')

    query = sqlalchemy.select(persistence_mcp.MCPServer)
    query_result = await self.ap.persistence_mgr.execute_async(query)

    for server in query_result.all():
        try:
            await self.load_mcp_server(server)
        except Exception as e:
            # Best-effort: one broken server must not block the others.
            self.ap.logger.error(f'Failed to load MCP server {server.name}: {e}\n{traceback.format_exc()}')
async def init_runtime_mcp_session(
    self,
    server_entity: persistence_mcp.MCPServer | sqlalchemy.Row[persistence_mcp.MCPServer] | dict,
):
    """Build and initialize a ``RuntimeMCPSession`` for one server record.

    Args:
        server_entity: The server description, given as an ``MCPServer``
            entity, a SQLAlchemy ``Row`` (converted via ``_mapping``), or a
            plain ``dict`` of ``MCPServer`` fields.

    Returns:
        The session after ``initialize()`` has been awaited on it. The
        session is not registered in ``self.sessions`` here.

    NOTE(review): ``extra_args`` is spread last, so a key named ``name``,
    ``mode`` or ``enable`` inside it overrides the column value in the
    session config -- confirm this is intended.
    """
    # Normalise the accepted input shapes into an MCPServer entity.
    if isinstance(server_entity, dict):
        server_entity = persistence_mcp.MCPServer(**server_entity)
    elif isinstance(server_entity, sqlalchemy.Row):
        server_entity = persistence_mcp.MCPServer(**server_entity._mapping)

    fixed_fields = {
        'name': server_entity.name,
        'mode': server_entity.mode,
        'enable': server_entity.enable,
    }
    mixed_config = {**fixed_fields, **server_entity.extra_args}

    session = RuntimeMCPSession(
        server_name=server_entity.name,
        server_config=mixed_config,
        enable=server_entity.enable,
        ap=self.ap,
    )
    await session.initialize()
    return session
async def load_mcp_server(
    self,
    server_entity: persistence_mcp.MCPServer | sqlalchemy.Row[persistence_mcp.MCPServer] | dict,
):
    """Create a runtime session for one MCP server and register it by name.

    Args:
        server_entity: The server description, given as an ``MCPServer``
            entity, a SQLAlchemy ``Row``, or a plain ``dict``; it is passed
            unchanged to ``init_runtime_mcp_session``.

    Any exception raised by ``init_runtime_mcp_session`` propagates to the
    caller, and nothing is registered in that case.
    """
    session = await self.init_runtime_mcp_session(server_entity)
    # Key by the session's own name rather than `server_entity.name`: the
    # argument may be a dict, which has no `.name` attribute, whereas the
    # session always carries the name it was constructed with.
    self.sessions[session.server_name] = session
async def get_tools(self) -> list[resource_tool.LLMTool]:
all_functions = []
@@ -150,7 +193,14 @@ class MCPLoader(loader.ToolLoader):
if function.name == name:
return await function.func(**parameters)
raise ValueError(f'未找到工具: {name}')
raise ValueError(f'Tool not found: {name}')
async def remove_mcp_server(self, server_name: str):
    """Unregister the named MCP server and shut its session down.

    Args:
        server_name: Key of the session in ``self.sessions``.

    Raises:
        ValueError: If no session is registered under ``server_name``.
    """
    registry = self.sessions
    if server_name in registry:
        # The entry is removed from the registry before shutdown is awaited.
        await registry.pop(server_name).shutdown()
        return
    raise ValueError(f'MCP server {server_name} not found')
async def shutdown(self):
"""关闭工具"""