feat: 添加对 Gitee AI 的支持

This commit is contained in:
Junyan Qin
2024-11-21 23:28:19 +08:00
parent 753066ccb9
commit 875adfcbaa
13 changed files with 112 additions and 23 deletions

View File

@@ -0,0 +1,26 @@
from __future__ import annotations
from .. import migration
@migration.migration_class("gitee-ai-config", 15)
class GiteeAIConfigMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return 'gitee-ai-chat-completions' not in self.ap.provider_cfg.data['requester'] or 'gitee-ai' not in self.ap.provider_cfg.data['keys']
async def run(self):
"""执行迁移"""
self.ap.provider_cfg.data['requester']['gitee-ai-chat-completions'] = {
"base-url": "https://ai.gitee.com/v1",
"args": {},
"timeout": 120
}
self.ap.provider_cfg.data['keys']['gitee-ai'] = [
"XXXXX"
]
await self.ap.provider_cfg.dump_config()

View File

@@ -7,6 +7,7 @@ from .. import migration
from ..migrations import m001_sensitive_word_migration, m002_openai_config_migration, m003_anthropic_requester_cfg_completion, m004_moonshot_cfg_completion
from ..migrations import m005_deepseek_cfg_completion, m006_vision_config, m007_qcg_center_url, m008_ad_fixwin_config_migrate, m009_msg_truncator_cfg
from ..migrations import m010_ollama_requester_config, m011_command_prefix_config, m012_runner_config, m013_http_api_config, m014_force_delay_config
from ..migrations import m015_gitee_ai_config
@stage.stage_class("MigrationStage")
@@ -28,3 +29,4 @@ class MigrationStage(stage.BootingStage):
if await migration_instance.need_migrate():
await migration_instance.run()
print(f'已执行迁移 {migration_instance.name}')

View File

@@ -4,7 +4,7 @@ import typing
import pydantic
from . import api
from . import requester
from . import token
@@ -17,7 +17,7 @@ class LLMModelInfo(pydantic.BaseModel):
token_mgr: token.TokenManager
requester: api.LLMAPIRequester
requester: requester.LLMAPIRequester
tool_call_supported: typing.Optional[bool] = False

View File

@@ -2,11 +2,11 @@ from __future__ import annotations
import aiohttp
from . import entities
from . import entities, requester
from ...core import app
from . import token, api
from .apis import chatcmpl, anthropicmsgs, moonshotchatcmpl, deepseekchatcmpl, ollamachat
from . import token
from .requesters import chatcmpl, anthropicmsgs, moonshotchatcmpl, deepseekchatcmpl, ollamachat, giteeaichatcmpl
FETCH_MODEL_LIST_URL = "https://api.qchatgpt.rockchin.top/api/v2/fetch/model_list"
@@ -18,7 +18,7 @@ class ModelManager:
model_list: list[entities.LLMModelInfo]
requesters: dict[str, api.LLMAPIRequester]
requesters: dict[str, requester.LLMAPIRequester]
token_mgrs: dict[str, token.TokenManager]
@@ -42,7 +42,7 @@ class ModelManager:
for k, v in self.ap.provider_cfg.data['keys'].items():
self.token_mgrs[k] = token.TokenManager(k, v)
for api_cls in api.preregistered_requesters:
for api_cls in requester.preregistered_requesters:
api_inst = api_cls(self.ap)
await api_inst.initialize()
self.requesters[api_inst.name] = api_inst
@@ -94,7 +94,7 @@ class ModelManager:
model_name = model.get('model_name', default_model_info.model_name)
token_mgr = self.token_mgrs[model['token_mgr']] if 'token_mgr' in model else default_model_info.token_mgr
requester = self.requesters[model['requester']] if 'requester' in model else default_model_info.requester
req = self.requesters[model['requester']] if 'requester' in model else default_model_info.requester
tool_call_supported = model.get('tool_call_supported', default_model_info.tool_call_supported)
vision_supported = model.get('vision_supported', default_model_info.vision_supported)
@@ -102,7 +102,7 @@ class ModelManager:
name=model['name'],
model_name=model_name,
token_mgr=token_mgr,
requester=requester,
requester=req,
tool_call_supported=tool_call_supported,
vision_supported=vision_supported
)

View File

@@ -5,17 +5,17 @@ import traceback
import anthropic
from .. import api, entities, errors
from .. import entities, errors, requester
from .. import api, entities, errors
from .. import entities, errors
from ....core import entities as core_entities
from ... import entities as llm_entities
from ...tools import entities as tools_entities
from ....utils import image
@api.requester_class("anthropic-messages")
class AnthropicMessages(api.LLMAPIRequester):
@requester.requester_class("anthropic-messages")
class AnthropicMessages(requester.LLMAPIRequester):
"""Anthropic Messages API 请求器"""
client: anthropic.AsyncAnthropic

View File

@@ -12,15 +12,15 @@ import httpx
import aiohttp
import async_lru
from .. import api, entities, errors
from .. import entities, errors, requester
from ....core import entities as core_entities, app
from ... import entities as llm_entities
from ...tools import entities as tools_entities
from ....utils import image
@api.requester_class("openai-chat-completions")
class OpenAIChatCompletions(api.LLMAPIRequester):
@requester.requester_class("openai-chat-completions")
class OpenAIChatCompletions(requester.LLMAPIRequester):
"""OpenAI ChatCompletion API 请求器"""
client: openai.AsyncClient

View File

@@ -3,13 +3,13 @@ from __future__ import annotations
from ....core import app
from . import chatcmpl
from .. import api, entities, errors
from .. import entities, errors, requester
from ....core import entities as core_entities, app
from ... import entities as llm_entities
from ...tools import entities as tools_entities
@api.requester_class("deepseek-chat-completions")
@requester.requester_class("deepseek-chat-completions")
class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions):
"""Deepseek ChatCompletion API 请求器"""

View File

@@ -0,0 +1,53 @@
from __future__ import annotations
import json
import asyncio
import aiohttp
import typing
from . import chatcmpl
from .. import entities, errors, requester
from ....core import app
from ... import entities as llm_entities
from ...tools import entities as tools_entities
from .. import entities as modelmgr_entities
@requester.requester_class("gitee-ai-chat-completions")
class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions):
"""Gitee AI ChatCompletions API 请求器"""
def __init__(self, ap: app.Application):
self.ap = ap
self.requester_cfg = ap.provider_cfg.data['requester']['gitee-ai-chat-completions'].copy()
async def _closure(
self,
req_messages: list[dict],
use_model: entities.LLMModelInfo,
use_funcs: list[tools_entities.LLMFunction] = None,
) -> llm_entities.Message:
self.client.api_key = use_model.token_mgr.get_token()
args = self.requester_cfg['args'].copy()
args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
if use_funcs:
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
if tools:
args["tools"] = tools
# gitee 不支持多模态把content都转换成纯文字
for m in req_messages:
if 'content' in m and isinstance(m["content"], list):
m["content"] = " ".join([c["text"] for c in m["content"]])
args["messages"] = req_messages
resp = await self._req(args)
message = await self._make_msg(resp)
return message

View File

@@ -3,13 +3,13 @@ from __future__ import annotations
from ....core import app
from . import chatcmpl
from .. import api, entities, errors
from .. import entities, errors, requester
from ....core import entities as core_entities, app
from ... import entities as llm_entities
from ...tools import entities as tools_entities
@api.requester_class("moonshot-chat-completions")
@requester.requester_class("moonshot-chat-completions")
class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions):
"""Moonshot ChatCompletion API 请求器"""

View File

@@ -8,7 +8,7 @@ from typing import Union, Mapping, Any, AsyncIterator
import async_lru
import ollama
from .. import api, entities, errors
from .. import entities, errors, requester
from ... import entities as llm_entities
from ...tools import entities as tools_entities
from ....core import app
@@ -17,8 +17,8 @@ from ....utils import image
REQUESTER_NAME: str = "ollama-chat"
@api.requester_class(REQUESTER_NAME)
class OllamaChatCompletions(api.LLMAPIRequester):
@requester.requester_class(REQUESTER_NAME)
class OllamaChatCompletions(requester.LLMAPIRequester):
"""Ollama平台 ChatCompletion API请求器"""
client: ollama.AsyncClient
request_cfg: dict

View File

@@ -13,6 +13,9 @@
],
"deepseek": [
"sk-1234567890"
],
"gitee-ai": [
"XXXXX"
]
},
"requester": {
@@ -42,6 +45,11 @@
"base-url": "http://127.0.0.1:11434",
"args": {},
"timeout": 600
},
"gitee-ai-chat-completions": {
"base-url": "https://ai.gitee.com/v1",
"args": {},
"timeout": 120
}
},
"model": "gpt-3.5-turbo",