Mirror of https://github.com/langbot-app/LangBot.git
feat: support vision and function calling for ollama (#950)
@@ -4,6 +4,8 @@ import asyncio
 import os
 import typing
 from typing import Union, Mapping, Any, AsyncIterator
+import uuid
+import json
 
 import async_lru
 import ollama
@@ -60,21 +62,49 @@ class OllamaChatCompletions(requester.LLMAPIRequester):
                         image_urls.append(image_url)
                 msg["content"] = "\n".join(text_content)
+                msg["images"] = [url.split(',')[1] for url in image_urls]
+            if 'tool_calls' in msg:  # LangBot stores tool_calls arguments as str internally; convert them to dict here
+                for tool_call in msg['tool_calls']:
+                    tool_call['function']['arguments'] = json.loads(tool_call['function']['arguments'])
         args["messages"] = messages
 
-        resp: Mapping[str, Any] | AsyncIterator[Mapping[str, Any]] = await self._req(args)
+        args["tools"] = []
+        if user_funcs:
+            tools = await self.ap.tool_mgr.generate_tools_for_openai(user_funcs)
+            if tools:
+                args["tools"] = tools
+
+        resp = await self._req(args)
         message: llm_entities.Message = await self._make_msg(resp)
         return message
 
     async def _make_msg(
         self,
-        chat_completions: Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]) -> llm_entities.Message:
-        message: Any = chat_completions.pop('message', None)
+        chat_completions: ollama.ChatResponse) -> llm_entities.Message:
+        message: ollama.Message = chat_completions.message
         if message is None:
             raise ValueError("chat_completions must contain a 'message' field")
 
-        message.update(chat_completions)
-        ret_msg: llm_entities.Message = llm_entities.Message(**message)
+        ret_msg: llm_entities.Message = None
+
+        if message.content is not None:
+            ret_msg = llm_entities.Message(
+                role="assistant",
+                content=message.content
+            )
+        if message.tool_calls is not None and len(message.tool_calls) > 0:
+            tool_calls: list[llm_entities.ToolCall] = []
+
+            for tool_call in message.tool_calls:
+                tool_calls.append(llm_entities.ToolCall(
+                    id=uuid.uuid4().hex,
+                    type="function",
+                    function=llm_entities.FunctionCall(
+                        name=tool_call.function.name,
+                        arguments=json.dumps(tool_call.function.arguments)
+                    )
+                ))
+            ret_msg.tool_calls = tool_calls
 
         return ret_msg
 
     async def call(
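
Note on the request-side changes in this hunk: LangBot hands the requester OpenAI-style multimodal parts (text plus base64 data URLs) and stores tool-call arguments as JSON strings, while the ollama client expects plain text in content, bare base64 payloads in images, and dict arguments. A minimal standalone sketch of that conversion; the helper name and message dict below are illustrative, not LangBot's actual structures:

import json

def to_ollama_message(msg: dict) -> dict:
    # Illustrative: flatten an OpenAI-style message into the shape ollama expects.
    out = dict(msg)
    content = msg.get("content")
    if isinstance(content, list):
        texts = [part["text"] for part in content if part["type"] == "text"]
        data_urls = [part["image_url"]["url"] for part in content if part["type"] == "image_url"]
        out["content"] = "\n".join(texts)
        # "data:image/png;base64,AAAA..." -> keep only the base64 payload after the comma
        out["images"] = [url.split(",")[1] for url in data_urls]
    for tool_call in out.get("tool_calls", []):
        # arguments arrive as a JSON string; ollama wants a dict
        tool_call["function"]["arguments"] = json.loads(tool_call["function"]["arguments"])
    return out

print(to_ollama_message({
    "role": "user",
    "content": [
        {"type": "text", "text": "What is in this picture?"},
        {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBORw0KGgo="}},
    ],
}))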
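
On the response side, the reworked _make_msg maps ollama.ChatResponse back into LangBot's internal Message: text content becomes an assistant message, and each tool call gets a freshly generated id (ollama does not return one) with its dict arguments serialized back to a JSON string. A rough sketch of the same mapping over plain dicts; the input is a stand-in for the real response object, not the actual ollama or LangBot types:

import json
import uuid

def to_internal_message(resp_message: dict) -> dict:
    # Illustrative: convert an ollama-style response message into an OpenAI-style dict.
    ret = {"role": "assistant", "content": resp_message.get("content")}
    tool_calls = resp_message.get("tool_calls") or []
    if tool_calls:
        ret["tool_calls"] = [
            {
                "id": uuid.uuid4().hex,  # ollama returns no id, so generate one
                "type": "function",
                "function": {
                    "name": tc["function"]["name"],
                    # back to a JSON string, which is how LangBot stores arguments internally
                    "arguments": json.dumps(tc["function"]["arguments"]),
                },
            }
            for tc in tool_calls
        ]
    return ret

print(to_internal_message({
    "content": "",
    "tool_calls": [{"function": {"name": "get_weather", "arguments": {"city": "Beijing"}}}],
}))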
@@ -92,7 +122,7 @@ class OllamaChatCompletions(requester.LLMAPIRequester):
                 msg_dict["content"] = "\n".join(part["text"] for part in content)
             req_messages.append(msg_dict)
         try:
-            return await self._closure(req_messages, model)
+            return await self._closure(req_messages, model, funcs)
         except asyncio.TimeoutError:
             raise errors.RequesterError('请求超时')
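
With funcs now threaded from call() into _closure(), the requester ends up issuing an ollama chat request that carries both images and tools. A hypothetical direct call against the ollama Python client showing that request shape; the model name, tool schema, and typed response access are assumptions for illustration and require a running ollama server, a tool-capable model, and a recent ollama-python:

import asyncio
import ollama

# A hypothetical tool schema in the OpenAI function-calling format,
# which is the format ollama's tools parameter also accepts.
WEATHER_TOOL = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}

async def main():
    client = ollama.AsyncClient()  # assumes a local ollama server on the default port
    resp = await client.chat(
        model="llama3.2",  # placeholder; any tool-capable model works
        messages=[{
            "role": "user",
            "content": "Describe this image, then check the weather in Beijing.",
            "images": ["iVBORw0KGgo="],  # bare base64 payload, no data: prefix
        }],
        tools=[WEATHER_TOOL],
    )
    # Recent ollama-python returns a typed ChatResponse, matching the annotation in the diff.
    print(resp.message.content)
    print(resp.message.tool_calls)

asyncio.run(main())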