Merge pull request #963 from RockChinQ/feat/dl-image-by-adapters

fix: 下载 QQ 图片时的400问题
This commit is contained in:
Junyan Qin
2024-12-24 11:28:31 +08:00
committed by GitHub
10 changed files with 76 additions and 61 deletions

View File

@@ -61,9 +61,9 @@ class PreProcessor(stage.PipelineStage):
)
elif isinstance(me, platform_message.Image):
if self.ap.provider_cfg.data['enable-vision'] and (self.ap.provider_cfg.data['runner'] != 'local-agent' or query.use_model.vision_supported):
if me.url is not None:
if me.base64 is not None:
content_list.append(
llm_entities.ContentElement.from_image_url(str(me.url))
llm_entities.ContentElement.from_image_base64(me.base64)
)
query.user_message = llm_entities.Message(

View File

@@ -6,6 +6,7 @@ import time
import datetime
import aiocqhttp
import aiohttp
from .. import adapter
from ...pipeline.longtext.strategies import forward
@@ -13,12 +14,12 @@ from ...core import app
from ..types import message as platform_message
from ..types import events as platform_events
from ..types import entities as platform_entities
from ...utils import image
class AiocqhttpMessageConverter(adapter.MessageConverter):
@staticmethod
def yiri2target(message_chain: platform_message.MessageChain) -> typing.Tuple[list, int, datetime.datetime]:
async def yiri2target(message_chain: platform_message.MessageChain) -> typing.Tuple[list, int, datetime.datetime]:
msg_list = aiocqhttp.Message()
msg_id = 0
@@ -59,7 +60,7 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
elif type(msg) is forward.Forward:
for node in msg.node_list:
msg_list.extend(AiocqhttpMessageConverter.yiri2target(node.message_chain)[0])
msg_list.extend(await AiocqhttpMessageConverter.yiri2target(node.message_chain)[0])
else:
msg_list.append(aiocqhttp.MessageSegment.text(str(msg)))
@@ -67,7 +68,7 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
return msg_list, msg_id, msg_time
@staticmethod
def target2yiri(message: str, message_id: int = -1):
async def target2yiri(message: str, message_id: int = -1):
message = aiocqhttp.Message(message)
yiri_msg_list = []
@@ -89,7 +90,8 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
elif msg.type == "text":
yiri_msg_list.append(platform_message.Plain(text=msg.data["text"]))
elif msg.type == "image":
yiri_msg_list.append(platform_message.Image(url=msg.data["url"]))
image_base64, image_format = await image.qq_image_url_to_base64(msg.data['url'])
yiri_msg_list.append(platform_message.Image(base64=f"data:image/{image_format};base64,{image_base64}"))
chain = platform_message.MessageChain(yiri_msg_list)
@@ -99,9 +101,9 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
class AiocqhttpEventConverter(adapter.EventConverter):
@staticmethod
def yiri2target(event: platform_events.Event, bot_account_id: int):
async def yiri2target(event: platform_events.Event, bot_account_id: int):
msg, msg_id, msg_time = AiocqhttpMessageConverter.yiri2target(event.message_chain)
msg, msg_id, msg_time = await AiocqhttpMessageConverter.yiri2target(event.message_chain)
if type(event) is platform_events.GroupMessage:
role = "member"
@@ -164,8 +166,8 @@ class AiocqhttpEventConverter(adapter.EventConverter):
return aiocqhttp.Event.from_payload(payload)
@staticmethod
def target2yiri(event: aiocqhttp.Event):
yiri_chain = AiocqhttpMessageConverter.target2yiri(
async def target2yiri(event: aiocqhttp.Event):
yiri_chain = await AiocqhttpMessageConverter.target2yiri(
event.message, event.message_id
)
@@ -242,7 +244,7 @@ class AiocqhttpAdapter(adapter.MessageSourceAdapter):
async def send_message(
self, target_type: str, target_id: str, message: platform_message.MessageChain
):
aiocq_msg = AiocqhttpMessageConverter.yiri2target(message)[0]
aiocq_msg = await AiocqhttpMessageConverter.yiri2target(message)[0]
if target_type == "group":
await self.bot.send_group_msg(group_id=int(target_id), message=aiocq_msg)
@@ -255,8 +257,8 @@ class AiocqhttpAdapter(adapter.MessageSourceAdapter):
message: platform_message.MessageChain,
quote_origin: bool = False,
):
aiocq_event = AiocqhttpEventConverter.yiri2target(message_source, self.bot_account_id)
aiocq_msg = AiocqhttpMessageConverter.yiri2target(message)[0]
aiocq_event = await AiocqhttpEventConverter.yiri2target(message_source, self.bot_account_id)
aiocq_msg = (await AiocqhttpMessageConverter.yiri2target(message))[0]
if quote_origin:
aiocq_msg = aiocqhttp.MessageSegment.reply(aiocq_event.message_id) + aiocq_msg
@@ -276,7 +278,7 @@ class AiocqhttpAdapter(adapter.MessageSourceAdapter):
async def on_message(event: aiocqhttp.Event):
self.bot_account_id = event.self_id
try:
return await callback(self.event_converter.target2yiri(event), self)
return await callback(await self.event_converter.target2yiri(event), self)
except:
traceback.print_exc()

View File

@@ -38,6 +38,8 @@ class ContentElement(pydantic.BaseModel):
image_url: typing.Optional[ImageURLContentObject] = None
image_base64: typing.Optional[str] = None
def __str__(self):
if self.type == 'text':
return self.text
@@ -53,6 +55,10 @@ class ContentElement(pydantic.BaseModel):
@classmethod
def from_image_url(cls, image_url: str):
return cls(type='image_url', image_url=ImageURLContentObject(url=image_url))
@classmethod
def from_image_base64(cls, image_base64: str):
return cls(type='image_base64', image_base64=image_base64)
class Message(pydantic.BaseModel):

View File

@@ -48,6 +48,7 @@ class LLMAPIRequester(metaclass=abc.ABCMeta):
@abc.abstractmethod
async def call(
self,
query: core_entities.Query,
model: modelmgr_entities.LLMModelInfo,
messages: typing.List[llm_entities.Message],
funcs: typing.List[tools_entities.LLMFunction] = None,

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
import typing
import traceback
import base64
import anthropic
import httpx
@@ -39,6 +40,7 @@ class AnthropicMessages(requester.LLMAPIRequester):
async def call(
self,
query: core_entities.Query,
model: entities.LLMModelInfo,
messages: typing.List[llm_entities.Message],
funcs: typing.List[tools_entities.LLMFunction] = None,
@@ -70,24 +72,20 @@ class AnthropicMessages(requester.LLMAPIRequester):
if isinstance(m.content, str) and m.content.strip() != "":
req_messages.append(m.dict(exclude_none=True))
elif isinstance(m.content, list):
# m.content = [
# c for c in m.content if c.type == "text"
# ]
# if len(m.content) > 0:
# req_messages.append(m.dict(exclude_none=True))
msg_dict = m.dict(exclude_none=True)
for i, ce in enumerate(m.content):
if ce.type == "image_url":
base64_image, image_format = await image.qq_image_url_to_base64(ce.image_url.url)
if ce.type == "image_base64":
image_b64, image_format = await image.extract_b64_and_format(ce.image_base64)
alter_image_ele = {
"type": "image",
"source": {
"type": "base64",
"media_type": f"image/{image_format}",
"data": base64_image
"data": image_b64
}
}
msg_dict["content"][i] = alter_image_ele

View File

@@ -65,6 +65,7 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
async def _closure(
self,
query: core_entities.Query,
req_messages: list[dict],
use_model: entities.LLMModelInfo,
use_funcs: list[tools_entities.LLMFunction] = None,
@@ -87,8 +88,12 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
for msg in messages:
if 'content' in msg and isinstance(msg["content"], list):
for me in msg["content"]:
if me["type"] == "image_url":
me["image_url"]['url'] = await self.get_base64_str(me["image_url"]['url'])
if me["type"] == "image_base64":
me["image_url"] = {
"url": me["image_base64"]
}
me["type"] = "image_url"
del me["image_base64"]
args["messages"] = messages
@@ -102,6 +107,7 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
async def call(
self,
query: core_entities.Query,
model: entities.LLMModelInfo,
messages: typing.List[llm_entities.Message],
funcs: typing.List[tools_entities.LLMFunction] = None,
@@ -118,7 +124,7 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
req_messages.append(msg_dict)
try:
return await self._closure(req_messages, model, funcs)
return await self._closure(query, req_messages, model, funcs)
except asyncio.TimeoutError:
raise errors.RequesterError('请求超时')
except openai.BadRequestError as e:
@@ -134,11 +140,3 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}')
except openai.APIError as e:
raise errors.RequesterError(f'请求错误: {e.message}')
@async_lru.alru_cache(maxsize=128)
async def get_base64_str(
self,
original_url: str,
) -> str:
base64_image, image_format = await image.qq_image_url_to_base64(original_url)
return f"data:image/{image_format};base64,{base64_image}"

View File

@@ -6,6 +6,7 @@ import typing
from typing import Union, Mapping, Any, AsyncIterator
import uuid
import json
import base64
import async_lru
import ollama
@@ -13,7 +14,7 @@ import ollama
from .. import entities, errors, requester
from ... import entities as llm_entities
from ...tools import entities as tools_entities
from ....core import app
from ....core import app, entities as core_entities
from ....utils import image
REQUESTER_NAME: str = "ollama-chat"
@@ -43,7 +44,7 @@ class OllamaChatCompletions(requester.LLMAPIRequester):
**args
)
async def _closure(self, req_messages: list[dict], use_model: entities.LLMModelInfo,
async def _closure(self, query: core_entities.Query, req_messages: list[dict], use_model: entities.LLMModelInfo,
user_funcs: list[tools_entities.LLMFunction] = None) -> (
llm_entities.Message):
args: Any = self.request_cfg['args'].copy()
@@ -57,9 +58,9 @@ class OllamaChatCompletions(requester.LLMAPIRequester):
for me in msg["content"]:
if me["type"] == "text":
text_content.append(me["text"])
elif me["type"] == "image_url":
image_url = await self.get_base64_str(me["image_url"]['url'])
image_urls.append(image_url)
elif me["type"] == "image_base64":
image_urls.append(me["image_base64"])
msg["content"] = "\n".join(text_content)
msg["images"] = [url.split(',')[1] for url in image_urls]
if 'tool_calls' in msg: # LangBot 内部以 str 存储 tool_calls 的参数,这里需要转换为 dict
@@ -109,6 +110,7 @@ class OllamaChatCompletions(requester.LLMAPIRequester):
async def call(
self,
query: core_entities.Query,
model: entities.LLMModelInfo,
messages: typing.List[llm_entities.Message],
funcs: typing.List[tools_entities.LLMFunction] = None,
@@ -122,14 +124,6 @@ class OllamaChatCompletions(requester.LLMAPIRequester):
msg_dict["content"] = "\n".join(part["text"] for part in content)
req_messages.append(msg_dict)
try:
return await self._closure(req_messages, model, funcs)
return await self._closure(query, req_messages, model, funcs)
except asyncio.TimeoutError:
raise errors.RequesterError('请求超时')
@async_lru.alru_cache(maxsize=128)
async def get_base64_str(
self,
original_url: str,
) -> str:
base64_image, image_format = await image.qq_image_url_to_base64(original_url)
return f"data:image/{image_format};base64,{base64_image}"

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
import typing
import json
import uuid
import base64
from .. import runner
from ...core import entities as core_entities
@@ -52,10 +53,9 @@ class DifyServiceAPIRunner(runner.RequestRunner):
for ce in query.user_message.content:
if ce.type == "text":
plain_text += ce.text
elif ce.type == "image_url":
file_bytes, image_format = await image.get_qq_image_bytes(
ce.image_url.url
)
elif ce.type == "image_base64":
image_b64, image_format = await image.extract_b64_and_format(ce.image_base64)
file_bytes = base64.b64decode(image_b64)
file = ("img.png", file_bytes, f"image/{image_format}")
file_upload_resp = await self.dify_client.upload_file(
file,

View File

@@ -23,7 +23,7 @@ class LocalAgentRunner(runner.RequestRunner):
req_messages = query.prompt.messages.copy() + query.messages.copy() + [query.user_message]
# 首次请求
msg = await query.use_model.requester.call(query.use_model, req_messages, query.use_funcs)
msg = await query.use_model.requester.call(query, query.use_model, req_messages, query.use_funcs)
yield msg
@@ -61,7 +61,7 @@ class LocalAgentRunner(runner.RequestRunner):
req_messages.append(err_msg)
# 处理完所有调用,再次请求
msg = await query.use_model.requester.call(query.use_model, req_messages, query.use_funcs)
msg = await query.use_model.requester.call(query, query.use_model, req_messages, query.use_funcs)
yield msg

View File

@@ -1,9 +1,11 @@
import base64
import typing
import io
from urllib.parse import urlparse, parse_qs
import ssl
import aiohttp
import PIL.Image
def get_qq_image_downloadable_url(image_url: str) -> tuple[str, dict]:
@@ -13,9 +15,10 @@ def get_qq_image_downloadable_url(image_url: str) -> tuple[str, dict]:
return f"http://{parsed.netloc}{parsed.path}", query
async def get_qq_image_bytes(image_url: str) -> tuple[bytes, str]:
"""获取QQ图片的bytes"""
image_url, query = get_qq_image_downloadable_url(image_url)
async def get_qq_image_bytes(image_url: str, query: dict={}) -> tuple[bytes, str]:
"""[弃用]获取QQ图片的bytes"""
image_url, query_in_url = get_qq_image_downloadable_url(image_url)
query = {**query, **query_in_url}
ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
@@ -24,8 +27,11 @@ async def get_qq_image_bytes(image_url: str) -> tuple[bytes, str]:
resp.raise_for_status()
file_bytes = await resp.read()
content_type = resp.headers.get('Content-Type')
if not content_type or not content_type.startswith('image/'):
if not content_type:
image_format = 'jpeg'
elif not content_type.startswith('image/'):
pil_img = PIL.Image.open(io.BytesIO(file_bytes))
image_format = pil_img.format.lower()
else:
image_format = content_type.split('/')[-1]
return file_bytes, image_format
@@ -34,7 +40,7 @@ async def get_qq_image_bytes(image_url: str) -> tuple[bytes, str]:
async def qq_image_url_to_base64(
image_url: str
) -> typing.Tuple[str, str]:
"""将QQ图片URL转为base64并返回图片格式
"""[弃用]将QQ图片URL转为base64并返回图片格式
Args:
image_url (str): QQ图片URL
@@ -47,8 +53,18 @@ async def qq_image_url_to_base64(
# Flatten the query dictionary
query = {k: v[0] for k, v in query.items()}
file_bytes, image_format = await get_qq_image_bytes(image_url)
file_bytes, image_format = await get_qq_image_bytes(image_url, query)
base64_str = base64.b64encode(file_bytes).decode()
return base64_str, image_format
async def extract_b64_and_format(image_base64_data: str) -> typing.Tuple[str, str]:
"""提取base64编码和图片格式

提取出base64编码和图片格式
"""
base64_str = image_base64_data.split(',')[-1]
image_format = image_base64_data.split(':')[-1].split(';')[0].split('/')[-1]
return base64_str, image_format