mirror of
https://github.com/langbot-app/LangBot.git
synced 2025-11-25 11:29:39 +08:00
chore: stash
This commit is contained in:
@@ -59,7 +59,7 @@ class ModelManager:
|
||||
try:
|
||||
await self.load_llm_model(llm_model)
|
||||
except provider_errors.RequesterNotFoundError as e:
|
||||
self.ap.logger.warning(f'Requester {e.requester_name} not found, skipping model {llm_model.uuid}')
|
||||
self.ap.logger.warning(f'Requester {e.requester_name} not found, skipping llm model {llm_model.uuid}')
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'Failed to load model {llm_model.uuid}: {e}\n{traceback.format_exc()}')
|
||||
|
||||
@@ -67,7 +67,14 @@ class ModelManager:
|
||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.EmbeddingModel))
|
||||
embedding_models = result.all()
|
||||
for embedding_model in embedding_models:
|
||||
await self.load_embedding_model(embedding_model)
|
||||
try:
|
||||
await self.load_embedding_model(embedding_model)
|
||||
except provider_errors.RequesterNotFoundError as e:
|
||||
self.ap.logger.warning(
|
||||
f'Requester {e.requester_name} not found, skipping embedding model {embedding_model.uuid}'
|
||||
)
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'Failed to load model {embedding_model.uuid}: {e}\n{traceback.format_exc()}')
|
||||
|
||||
async def init_runtime_llm_model(
|
||||
self,
|
||||
@@ -107,6 +114,9 @@ class ModelManager:
|
||||
elif isinstance(model_info, dict):
|
||||
model_info = persistence_model.EmbeddingModel(**model_info)
|
||||
|
||||
if model_info.requester not in self.requester_dict:
|
||||
raise provider_errors.RequesterNotFoundError(model_info.requester)
|
||||
|
||||
requester_inst = self.requester_dict[model_info.requester](ap=self.ap, config=model_info.requester_config)
|
||||
|
||||
await requester_inst.initialize()
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import typing
|
||||
from contextlib import AsyncExitStack
|
||||
import traceback
|
||||
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.stdio import stdio_client
|
||||
@@ -9,7 +10,9 @@ from mcp.client.sse import sse_client
|
||||
|
||||
from .. import loader
|
||||
from ....core import app
|
||||
from ....entity.persistence import mcp as persistence_mcp
|
||||
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
||||
import sqlalchemy
|
||||
|
||||
|
||||
class RuntimeMCPSession:
|
||||
@@ -27,11 +30,13 @@ class RuntimeMCPSession:
|
||||
|
||||
functions: list[resource_tool.LLMTool] = []
|
||||
|
||||
def __init__(self, server_name: str, server_config: dict, ap: app.Application):
|
||||
enable: bool
|
||||
|
||||
def __init__(self, server_name: str, server_config: dict, enable: bool, ap: app.Application):
|
||||
self.server_name = server_name
|
||||
self.server_config = server_config
|
||||
self.ap = ap
|
||||
|
||||
self.enable = enable
|
||||
self.session = None
|
||||
|
||||
self.exit_stack = AsyncExitStack()
|
||||
@@ -68,6 +73,12 @@ class RuntimeMCPSession:
|
||||
await self.session.initialize()
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def start(self):
|
||||
if not self.enable:
|
||||
return
|
||||
|
||||
self.ap.logger.debug(f'初始化 MCP 会话: {self.server_name} {self.server_config}')
|
||||
|
||||
if self.server_config['mode'] == 'stdio':
|
||||
@@ -123,13 +134,45 @@ class MCPLoader(loader.ToolLoader):
|
||||
self._last_listed_functions = []
|
||||
|
||||
async def initialize(self):
|
||||
for server_config in self.ap.instance_config.data.get('mcp', {}).get('servers', []):
|
||||
if not server_config['enable']:
|
||||
continue
|
||||
session = RuntimeMCPSession(server_config['name'], server_config, self.ap)
|
||||
await session.initialize()
|
||||
# self.ap.event_loop.create_task(session.initialize())
|
||||
self.sessions[server_config['name']] = session
|
||||
await self.load_mcp_servers_from_db()
|
||||
|
||||
async def load_mcp_servers_from_db(self):
|
||||
self.ap.logger.info('Loading MCP servers from db...')
|
||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_mcp.MCPServer))
|
||||
servers = result.all()
|
||||
for server in servers:
|
||||
try:
|
||||
await self.load_mcp_server(server)
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'Failed to load MCP server {server.name}: {e}\n{traceback.format_exc()}')
|
||||
|
||||
async def init_runtime_mcp_session(
|
||||
self,
|
||||
server_entity: persistence_mcp.MCPServer | sqlalchemy.Row[persistence_mcp.MCPServer] | dict,
|
||||
):
|
||||
if isinstance(server_entity, sqlalchemy.Row):
|
||||
server_entity = persistence_mcp.MCPServer(**server_entity._mapping)
|
||||
elif isinstance(server_entity, dict):
|
||||
server_entity = persistence_mcp.MCPServer(**server_entity)
|
||||
|
||||
mixed_config = {
|
||||
'name': server_entity.name,
|
||||
'mode': server_entity.mode,
|
||||
'enable': server_entity.enable,
|
||||
**server_entity.extra_args,
|
||||
}
|
||||
|
||||
session = RuntimeMCPSession(server_entity.name, mixed_config, server_entity.enable, self.ap)
|
||||
await session.initialize()
|
||||
|
||||
return session
|
||||
|
||||
async def load_mcp_server(
|
||||
self,
|
||||
server_entity: persistence_mcp.MCPServer | sqlalchemy.Row[persistence_mcp.MCPServer] | dict,
|
||||
):
|
||||
session = await self.init_runtime_mcp_session(server_entity)
|
||||
self.sessions[server_entity.name] = session
|
||||
|
||||
async def get_tools(self) -> list[resource_tool.LLMTool]:
|
||||
all_functions = []
|
||||
@@ -150,7 +193,14 @@ class MCPLoader(loader.ToolLoader):
|
||||
if function.name == name:
|
||||
return await function.func(**parameters)
|
||||
|
||||
raise ValueError(f'未找到工具: {name}')
|
||||
raise ValueError(f'Tool not found: {name}')
|
||||
|
||||
async def remove_mcp_server(self, server_name: str):
|
||||
if server_name not in self.sessions:
|
||||
raise ValueError(f'MCP server {server_name} not found')
|
||||
|
||||
session = self.sessions.pop(server_name)
|
||||
await session.shutdown()
|
||||
|
||||
async def shutdown(self):
|
||||
"""关闭工具"""
|
||||
|
||||
Reference in New Issue
Block a user