Files
LangBot/pkg/provider/runners/localagent.py

157 lines
6.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import json
from ssl import ALERT_DESCRIPTION_BAD_CERTIFICATE_HASH_VALUE
import typing
from .. import runner
from ...core import entities as core_entities
from .. import entities as llm_entities
@runner.runner_class('local-agent')
class LocalAgentRunner(runner.RequestRunner):
    """Local agent request runner.

    Sends the conversation to the configured LLM and keeps executing any tool
    calls the model requests, feeding each tool result back to the model,
    until the model answers without requesting further tools.
    """

    class ToolCallTracker:
        """Tool call tracker.

        NOTE(review): not referenced by any code visible in this file; kept
        for compatibility with external users.
        """

        def __init__(self):
            self.active_calls: dict[str, dict] = {}
            self.completed_calls: list[llm_entities.ToolCall] = []

    @staticmethod
    def _merge_tool_call_chunks(
        tool_calls_map: dict[str, llm_entities.ToolCall],
        chunk: llm_entities.MessageChunk,
        last_call_id: str | None,
    ) -> str | None:
        """Fold the tool-call deltas carried by one stream chunk into ``tool_calls_map``.

        Args:
            tool_calls_map: Accumulated tool calls of the current round, keyed by id.
            chunk: One streamed message chunk; its ``tool_calls`` may be empty/None.
            last_call_id: Id of the tool call most recently seen in this round.

        Returns:
            The id of the most recently seen tool call, so that later deltas
            arriving without an id can be attributed to it.
        """
        for tool_call in chunk.tool_calls or []:
            # Some providers only send the id on the first delta of a call.
            call_id = tool_call.id or last_call_id
            if not call_id:
                # A delta that cannot be attributed to any call is dropped.
                continue
            if call_id not in tool_calls_map:
                tool_calls_map[call_id] = llm_entities.ToolCall(
                    id=call_id,
                    type=tool_call.type or 'function',
                    function=llm_entities.FunctionCall(
                        name=tool_call.function.name if tool_call.function else '',
                        arguments='',
                    ),
                )
            entry = tool_calls_map[call_id]
            if tool_call.function:
                if tool_call.function.name and not entry.function.name:
                    entry.function.name = tool_call.function.name
                if tool_call.function.arguments:
                    # Arguments may be split across several chunks: append, never overwrite.
                    entry.function.arguments += tool_call.function.arguments
            last_call_id = call_id
        return last_call_id

    async def _execute_tool_calls(
        self,
        query: core_entities.Query,
        tool_calls: list[llm_entities.ToolCall],
    ) -> typing.AsyncGenerator[llm_entities.Message, None]:
        """Execute each tool call and yield one ``role='tool'`` message per call.

        A failing call does not abort the run: its error text is yielded as the
        tool result so the model can react to it.
        """
        for tool_call in tool_calls:
            try:
                func = tool_call.function
                # Tools without parameters may stream an empty arguments string.
                parameters = json.loads(func.arguments) if func.arguments else {}
                func_ret = await self.ap.tool_mgr.execute_func_call(query, func.name, parameters)
                tool_msg = llm_entities.Message(
                    role='tool',
                    content=json.dumps(func_ret, ensure_ascii=False),
                    tool_call_id=tool_call.id,
                )
            except Exception as e:
                # Tool call failed: report the error back to the model instead of raising.
                tool_msg = llm_entities.Message(role='tool', content=f'err: {e}', tool_call_id=tool_call.id)
            # Yield outside the try so exceptions thrown into the generator are not
            # misreported as tool errors.
            yield tool_msg

    async def run(
        self, query: core_entities.Query
    ) -> typing.AsyncGenerator[llm_entities.Message | llm_entities.MessageChunk, None]:
        """Run the request.

        Yields stream chunks (streaming mode) or complete messages (non-streaming
        mode) from the model, plus one tool message per executed tool call.
        """
        req_messages = query.prompt.messages.copy() + query.messages.copy() + [query.user_message]

        try:
            is_stream = query.adapter.is_stream
        except AttributeError:
            is_stream = False

        # Keep requesting for as long as the model asks for tool calls.
        while True:
            if is_stream:
                tool_calls_map: dict[str, llm_entities.ToolCall] = {}
                last_call_id: str | None = None
                last_chunk = None
                role = 'assistant'
                async for chunk in query.use_llm_model.requester.invoke_llm_stream(
                    query,
                    query.use_llm_model,
                    req_messages,
                    query.use_funcs,
                    stream=is_stream,
                    extra_args=query.use_llm_model.model_entity.extra_args,
                ):
                    yield chunk
                    last_chunk = chunk
                    # Role may only be present on some chunks; keep the latest one seen.
                    role = chunk.role or role
                    last_call_id = self._merge_tool_call_chunks(tool_calls_map, chunk, last_call_id)

                if last_chunk is None:
                    # Empty stream: there is no assistant message to record or continue from.
                    return

                final_msg = llm_entities.Message(
                    role=role,
                    content=last_chunk.all_content,
                    # Omit the field rather than sending an empty tool_calls list.
                    tool_calls=list(tool_calls_map.values()) or None,
                )
            else:
                final_msg = await query.use_llm_model.requester.invoke_llm(
                    query,
                    query.use_llm_model,
                    req_messages,
                    query.use_funcs,
                    extra_args=query.use_llm_model.model_entity.extra_args,
                )
                yield final_msg

            req_messages.append(final_msg)

            if not final_msg.tool_calls:
                return

            async for tool_msg in self._execute_tool_calls(query, final_msg.tool_calls):
                yield tool_msg
                req_messages.append(tool_msg)