mirror of
https://github.com/Tencent/WeKnora.git
synced 2025-11-25 03:15:00 +08:00
feat: support aliyun qwen3 model
This commit is contained in:
3
go.mod
3
go.mod
@@ -21,6 +21,7 @@ require (
|
||||
github.com/sashabaranov/go-openai v1.40.5
|
||||
github.com/sirupsen/logrus v1.9.3
|
||||
github.com/spf13/viper v1.20.1
|
||||
github.com/stretchr/testify v1.11.1
|
||||
github.com/tencentyun/cos-go-sdk-v5 v0.7.65
|
||||
github.com/yanyiwu/gojieba v1.4.5
|
||||
go.opentelemetry.io/otel v1.37.0
|
||||
@@ -45,6 +46,7 @@ require (
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/clbanning/mxj v1.8.4 // indirect
|
||||
github.com/cloudwego/base64x v0.1.5 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/elastic/elastic-transport-go/v8 v8.7.0 // indirect
|
||||
@@ -81,6 +83,7 @@ require (
|
||||
github.com/olekukonko/tablewriter v0.0.5 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.3 // indirect
|
||||
github.com/pierrec/lz4/v4 v4.1.21 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/robfig/cron/v3 v3.0.1 // indirect
|
||||
github.com/rs/xid v1.6.0 // indirect
|
||||
|
||||
3
go.sum
3
go.sum
@@ -192,8 +192,9 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
|
||||
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
|
||||
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.563/go.mod h1:7sCQWVkxcsR38nffDW057DRGk8mUjK1Ing/EFOK8s8Y=
|
||||
|
||||
@@ -1361,7 +1361,8 @@ func (h *InitializationHandler) checkRemoteModelConnection(ctx context.Context,
|
||||
"content": "test",
|
||||
},
|
||||
},
|
||||
"max_tokens": 1,
|
||||
"max_tokens": 1,
|
||||
"enable_thinking": false, // for dashscope.aliyuncs qwen3-32b
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(testRequest)
|
||||
@@ -1386,6 +1387,11 @@ func (h *InitializationHandler) checkRemoteModelConnection(ctx context.Context,
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err == nil {
|
||||
logger.Infof(ctx, "Response body: %s", string(body))
|
||||
}
|
||||
|
||||
// 检查响应状态
|
||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||
// 连接成功,模型可用
|
||||
@@ -1401,58 +1407,6 @@ func (h *InitializationHandler) checkRemoteModelConnection(ctx context.Context,
|
||||
}
|
||||
}
|
||||
|
||||
// checkModelExistence 检查指定模型是否在模型列表中存在
|
||||
func (h *InitializationHandler) checkModelExistence(ctx context.Context,
|
||||
resp *http.Response, modelName string) (bool, string) {
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return true, "连接正常,但无法验证模型列表"
|
||||
}
|
||||
|
||||
var modelsResp struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
} `json:"data"`
|
||||
Object string `json:"object"`
|
||||
}
|
||||
|
||||
// 尝试解析模型列表响应
|
||||
if err := json.Unmarshal(body, &modelsResp); err != nil {
|
||||
// 如果无法解析,可能是非标准API,只要连接成功就认为可用
|
||||
return true, "连接正常"
|
||||
}
|
||||
|
||||
// 检查模型是否在列表中
|
||||
for _, model := range modelsResp.Data {
|
||||
if model.ID == modelName {
|
||||
return true, "连接正常,模型存在"
|
||||
}
|
||||
}
|
||||
|
||||
// 模型不在列表中,返回可用的模型建议
|
||||
if len(modelsResp.Data) > 0 {
|
||||
availableModels := make([]string, 0, min(3, len(modelsResp.Data)))
|
||||
for i, model := range modelsResp.Data {
|
||||
if i >= 3 {
|
||||
break
|
||||
}
|
||||
availableModels = append(availableModels, model.ID)
|
||||
}
|
||||
return false, fmt.Sprintf("模型 '%s' 不存在,可用模型: %s", modelName, strings.Join(availableModels, ", "))
|
||||
}
|
||||
|
||||
return false, fmt.Sprintf("模型 '%s' 不存在", modelName)
|
||||
}
|
||||
|
||||
// min reports the smaller of two integers.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}
|
||||
|
||||
// checkRerankModelConnection 检查Rerank模型连接和功能的内部方法
|
||||
func (h *InitializationHandler) checkRerankModelConnection(ctx context.Context,
|
||||
modelName, baseURL, apiKey string) (bool, string) {
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
package chat
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
@@ -13,6 +17,14 @@ type RemoteAPIChat struct {
|
||||
modelName string
|
||||
client *openai.Client
|
||||
modelID string
|
||||
baseURL string
|
||||
apiKey string
|
||||
}
|
||||
|
||||
// QwenChatCompletionRequest is a custom request payload for Qwen models.
// It embeds the standard OpenAI chat-completion request and adds the
// DashScope-specific enable_thinking flag.
type QwenChatCompletionRequest struct {
	openai.ChatCompletionRequest
	// EnableThinking is a Qwen-only field; setting it to false disables the
	// model's "thinking" mode (required for non-streaming qwen3 calls).
	EnableThinking *bool `json:"enable_thinking,omitempty"`
}
|
||||
|
||||
// NewRemoteAPIChat 调用远程API 聊天实例
|
||||
@@ -26,6 +38,8 @@ func NewRemoteAPIChat(chatConfig *ChatConfig) (*RemoteAPIChat, error) {
|
||||
modelName: chatConfig.ModelName,
|
||||
client: openai.NewClientWithConfig(config),
|
||||
modelID: chatConfig.ModelID,
|
||||
baseURL: chatConfig.BaseURL,
|
||||
apiKey: apiKey,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -41,6 +55,27 @@ func (c *RemoteAPIChat) convertMessages(messages []Message) []openai.ChatComplet
|
||||
return openaiMessages
|
||||
}
|
||||
|
||||
// isQwenModel 检查是否为 qwen 模型
|
||||
func (c *RemoteAPIChat) isAliyunQwen3Model() bool {
|
||||
return strings.HasPrefix(c.modelName, "qwen3-") && c.baseURL == "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
}
|
||||
|
||||
// buildQwenChatCompletionRequest 构建 qwen 模型的聊天请求参数
|
||||
func (c *RemoteAPIChat) buildQwenChatCompletionRequest(messages []Message,
|
||||
opts *ChatOptions, isStream bool,
|
||||
) QwenChatCompletionRequest {
|
||||
req := QwenChatCompletionRequest{
|
||||
ChatCompletionRequest: c.buildChatCompletionRequest(messages, opts, isStream),
|
||||
}
|
||||
|
||||
// 对于 qwen 模型,在非流式调用中强制设置 enable_thinking: false
|
||||
if !isStream {
|
||||
enableThinking := false
|
||||
req.EnableThinking = &enableThinking
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
// buildChatCompletionRequest 构建聊天请求参数
|
||||
func (c *RemoteAPIChat) buildChatCompletionRequest(messages []Message,
|
||||
opts *ChatOptions, isStream bool,
|
||||
@@ -71,11 +106,6 @@ func (c *RemoteAPIChat) buildChatCompletionRequest(messages []Message,
|
||||
if opts.PresencePenalty > 0 {
|
||||
req.PresencePenalty = float32(opts.PresencePenalty)
|
||||
}
|
||||
if opts.Thinking != nil {
|
||||
req.ChatTemplateKwargs = map[string]any{
|
||||
"enable_thinking": *opts.Thinking,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return req
|
||||
@@ -83,6 +113,11 @@ func (c *RemoteAPIChat) buildChatCompletionRequest(messages []Message,
|
||||
|
||||
// Chat 进行非流式聊天
|
||||
func (c *RemoteAPIChat) Chat(ctx context.Context, messages []Message, opts *ChatOptions) (*types.ChatResponse, error) {
|
||||
// 如果是 qwen 模型,使用自定义请求
|
||||
if c.isAliyunQwen3Model() {
|
||||
return c.chatWithQwen(ctx, messages, opts)
|
||||
}
|
||||
|
||||
// 构建请求参数
|
||||
req := c.buildChatCompletionRequest(messages, opts, false)
|
||||
|
||||
@@ -111,6 +146,68 @@ func (c *RemoteAPIChat) Chat(ctx context.Context, messages []Message, opts *Chat
|
||||
}, nil
|
||||
}
|
||||
|
||||
// chatWithQwen 使用自定义请求处理 qwen 模型
|
||||
func (c *RemoteAPIChat) chatWithQwen(ctx context.Context, messages []Message, opts *ChatOptions) (*types.ChatResponse, error) {
|
||||
// 构建 qwen 请求参数
|
||||
req := c.buildQwenChatCompletionRequest(messages, opts, false)
|
||||
|
||||
// 序列化请求
|
||||
jsonData, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal request: %w", err)
|
||||
}
|
||||
|
||||
// 构建 URL
|
||||
endpoint := c.baseURL + "/chat/completions"
|
||||
|
||||
// 创建 HTTP 请求
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", endpoint, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
// 设置请求头
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+c.apiKey)
|
||||
|
||||
// 发送请求
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 检查响应状态
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API request failed with status: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// 解析响应
|
||||
var chatResp openai.ChatCompletionResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
|
||||
return nil, fmt.Errorf("decode response: %w", err)
|
||||
}
|
||||
|
||||
if len(chatResp.Choices) == 0 {
|
||||
return nil, fmt.Errorf("no response from API")
|
||||
}
|
||||
|
||||
// 转换响应格式
|
||||
return &types.ChatResponse{
|
||||
Content: chatResp.Choices[0].Message.Content,
|
||||
Usage: struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}{
|
||||
PromptTokens: chatResp.Usage.PromptTokens,
|
||||
CompletionTokens: chatResp.Usage.CompletionTokens,
|
||||
TotalTokens: chatResp.Usage.TotalTokens,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ChatStream 进行流式聊天
|
||||
func (c *RemoteAPIChat) ChatStream(ctx context.Context,
|
||||
messages []Message, opts *ChatOptions,
|
||||
|
||||
127
internal/models/chat/remote_api_test.go
Normal file
127
internal/models/chat/remote_api_test.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package chat
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestRemoteAPIChat 综合测试 Remote API Chat 的所有功能
|
||||
func TestRemoteAPIChat(t *testing.T) {
|
||||
// 获取环境变量
|
||||
deepseekAPIKey := os.Getenv("DEEPSEEK_API_KEY")
|
||||
aliyunAPIKey := os.Getenv("ALIYUN_API_KEY")
|
||||
|
||||
// 定义测试配置
|
||||
testConfigs := []struct {
|
||||
name string
|
||||
apiKey string
|
||||
config *ChatConfig
|
||||
skipMsg string
|
||||
}{
|
||||
{
|
||||
name: "DeepSeek API",
|
||||
apiKey: deepseekAPIKey,
|
||||
config: &ChatConfig{
|
||||
Source: types.ModelSourceRemote,
|
||||
BaseURL: "https://api.deepseek.com/v1",
|
||||
ModelName: "deepseek-chat",
|
||||
APIKey: deepseekAPIKey,
|
||||
ModelID: "deepseek-chat",
|
||||
},
|
||||
skipMsg: "DEEPSEEK_API_KEY environment variable not set",
|
||||
},
|
||||
{
|
||||
name: "Aliyun DeepSeek",
|
||||
apiKey: aliyunAPIKey,
|
||||
config: &ChatConfig{
|
||||
Source: types.ModelSourceRemote,
|
||||
BaseURL: "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
ModelName: "deepseek-v3.1",
|
||||
APIKey: aliyunAPIKey,
|
||||
ModelID: "deepseek-v3.1",
|
||||
},
|
||||
skipMsg: "ALIYUN_API_KEY environment variable not set",
|
||||
},
|
||||
{
|
||||
name: "Aliyun Qwen3-32b",
|
||||
apiKey: aliyunAPIKey,
|
||||
config: &ChatConfig{
|
||||
Source: types.ModelSourceRemote,
|
||||
BaseURL: "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
ModelName: "qwen3-32b",
|
||||
APIKey: aliyunAPIKey,
|
||||
ModelID: "qwen3-32b",
|
||||
},
|
||||
skipMsg: "ALIYUN_API_KEY environment variable not set",
|
||||
},
|
||||
{
|
||||
name: "Aliyun Qwen-max",
|
||||
apiKey: aliyunAPIKey,
|
||||
config: &ChatConfig{
|
||||
Source: types.ModelSourceRemote,
|
||||
BaseURL: "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
ModelName: "qwen-max",
|
||||
APIKey: aliyunAPIKey,
|
||||
ModelID: "qwen-max",
|
||||
},
|
||||
skipMsg: "ALIYUN_API_KEY environment variable not set",
|
||||
},
|
||||
}
|
||||
|
||||
// 测试消息
|
||||
testMessages := []Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "test",
|
||||
},
|
||||
}
|
||||
|
||||
// 测试选项
|
||||
testOptions := &ChatOptions{
|
||||
Temperature: 0.7,
|
||||
MaxTokens: 100,
|
||||
}
|
||||
|
||||
// 创建上下文
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// 遍历所有配置进行测试
|
||||
for _, tc := range testConfigs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// 检查 API Key
|
||||
if tc.apiKey == "" {
|
||||
t.Skip(tc.skipMsg)
|
||||
}
|
||||
|
||||
// 创建聊天实例
|
||||
chat, err := NewRemoteAPIChat(tc.config)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tc.config.ModelName, chat.GetModelName())
|
||||
assert.Equal(t, tc.config.ModelID, chat.GetModelID())
|
||||
|
||||
// 测试基本聊天功能
|
||||
t.Run("Basic Chat", func(t *testing.T) {
|
||||
response, err := chat.Chat(ctx, testMessages, testOptions)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, response.Content)
|
||||
assert.Greater(t, response.Usage.TotalTokens, 0)
|
||||
assert.Greater(t, response.Usage.PromptTokens, 0)
|
||||
assert.Greater(t, response.Usage.CompletionTokens, 0)
|
||||
|
||||
t.Logf("%s Response: %s", tc.name, response.Content)
|
||||
t.Logf("Usage: Prompt=%d, Completion=%d, Total=%d",
|
||||
response.Usage.PromptTokens,
|
||||
response.Usage.CompletionTokens,
|
||||
response.Usage.TotalTokens)
|
||||
})
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user