feat: support aliyun qwen3 model

wizardchen
2025-09-08 14:52:07 +08:00
committed by lyingbug
parent 4489a4da7f
commit 65b2d9eb84
5 changed files with 241 additions and 59 deletions
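DashScope's OpenAI-compatible endpoint accepts qwen3 thinking mode only on streaming calls, so non-streaming requests to those models must send enable_thinking: false. The changes below thread that flag through the remote-model connection health check, add a custom non-streaming request path to RemoteAPIChat for Aliyun-hosted qwen3 models, and cover the new path with environment-gated tests.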

go.mod

@@ -21,6 +21,7 @@ require (
github.com/sashabaranov/go-openai v1.40.5
github.com/sirupsen/logrus v1.9.3
github.com/spf13/viper v1.20.1
github.com/stretchr/testify v1.11.1
github.com/tencentyun/cos-go-sdk-v5 v0.7.65
github.com/yanyiwu/gojieba v1.4.5
go.opentelemetry.io/otel v1.37.0
@@ -45,6 +46,7 @@ require (
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/clbanning/mxj v1.8.4 // indirect
github.com/cloudwego/base64x v0.1.5 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/elastic/elastic-transport-go/v8 v8.7.0 // indirect
@@ -81,6 +83,7 @@ require (
github.com/olekukonko/tablewriter v0.0.5 // indirect
github.com/pelletier/go-toml/v2 v2.2.3 // indirect
github.com/pierrec/lz4/v4 v4.1.21 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/robfig/cron/v3 v3.0.1 // indirect
github.com/rs/xid v1.6.0 // indirect
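(stretchr/testify becomes a direct dependency for the new tests below; go-spew and go-difflib are its transitive requirements.)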

go.sum

@@ -192,8 +192,9 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.563/go.mod h1:7sCQWVkxcsR38nffDW057DRGk8mUjK1Ing/EFOK8s8Y=


@@ -1361,7 +1361,8 @@ func (h *InitializationHandler) checkRemoteModelConnection(ctx context.Context,
"content": "test",
},
},
"max_tokens": 1,
"max_tokens": 1,
"enable_thinking": false, // for dashscope.aliyuncs qwen3-32b
}
jsonData, err := json.Marshal(testRequest)
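For reference, a minimal standalone sketch of the JSON this probe produces (the model name is illustrative; Go's encoder sorts map keys alphabetically):

package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	// Same shape as the testRequest map above; "qwen3-32b" stands in for the configured model.
	testRequest := map[string]interface{}{
		"model": "qwen3-32b",
		"messages": []map[string]string{
			{"role": "user", "content": "test"},
		},
		"max_tokens":      1,
		"enable_thinking": false,
	}
	jsonData, _ := json.Marshal(testRequest)
	fmt.Println(string(jsonData))
	// {"enable_thinking":false,"max_tokens":1,"messages":[{"content":"test","role":"user"}],"model":"qwen3-32b"}
}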
@@ -1386,6 +1387,11 @@ func (h *InitializationHandler) checkRemoteModelConnection(ctx context.Context,
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err == nil {
logger.Infof(ctx, "Response body: %s", string(body))
}
// Check the response status
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
// The connection succeeded and the model is available
@@ -1401,58 +1407,6 @@ func (h *InitializationHandler) checkRemoteModelConnection(ctx context.Context,
}
}
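(The remainder of this hunk is deleted code: per the header above it replaces 58 lines with 6, removing the checkModelExistence helper and the local min function shown below.)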
// checkModelExistence checks whether the given model exists in the model list
func (h *InitializationHandler) checkModelExistence(ctx context.Context,
resp *http.Response, modelName string) (bool, string) {
body, err := io.ReadAll(resp.Body)
if err != nil {
return true, "连接正常,但无法验证模型列表"
}
var modelsResp struct {
Data []struct {
ID string `json:"id"`
Object string `json:"object"`
} `json:"data"`
Object string `json:"object"`
}
// Try to parse the model list response
if err := json.Unmarshal(body, &modelsResp); err != nil {
// If parsing fails this may be a non-standard API; treat a successful connection as usable
return true, "Connected"
}
// Check whether the model is in the list
for _, model := range modelsResp.Data {
if model.ID == modelName {
return true, "连接正常,模型存在"
}
}
// Model not in the list; suggest available models
if len(modelsResp.Data) > 0 {
availableModels := make([]string, 0, min(3, len(modelsResp.Data)))
for i, model := range modelsResp.Data {
if i >= 3 {
break
}
availableModels = append(availableModels, model.ID)
}
return false, fmt.Sprintf("模型 '%s' 不存在,可用模型: %s", modelName, strings.Join(availableModels, ", "))
}
return false, fmt.Sprintf("模型 '%s' 不存在", modelName)
}
// min returns the minimum of two integers
func min(a, b int) int {
if a < b {
return a
}
return b
}
// checkRerankModelConnection checks the rerank model connection and functionality
func (h *InitializationHandler) checkRerankModelConnection(ctx context.Context,
modelName, baseURL, apiKey string) (bool, string) {


@@ -1,8 +1,12 @@
package chat
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
"github.com/Tencent/WeKnora/internal/types"
"github.com/sashabaranov/go-openai"
@@ -13,6 +17,14 @@ type RemoteAPIChat struct {
modelName string
client *openai.Client
modelID string
baseURL string
apiKey string
}
// QwenChatCompletionRequest is a custom request struct for qwen models
type QwenChatCompletionRequest struct {
openai.ChatCompletionRequest
EnableThinking *bool `json:"enable_thinking,omitempty"` // qwen-specific field
}
// NewRemoteAPIChat creates a remote API chat instance
@@ -26,6 +38,8 @@ func NewRemoteAPIChat(chatConfig *ChatConfig) (*RemoteAPIChat, error) {
modelName: chatConfig.ModelName,
client: openai.NewClientWithConfig(config),
modelID: chatConfig.ModelID,
baseURL: chatConfig.BaseURL,
apiKey: apiKey,
}, nil
}
@@ -41,6 +55,27 @@ func (c *RemoteAPIChat) convertMessages(messages []Message) []openai.ChatComplet
return openaiMessages
}
// isAliyunQwen3Model reports whether the model is a qwen3 model served from Aliyun DashScope
func (c *RemoteAPIChat) isAliyunQwen3Model() bool {
return strings.HasPrefix(c.modelName, "qwen3-") && c.baseURL == "https://dashscope.aliyuncs.com/compatible-mode/v1"
}
// buildQwenChatCompletionRequest builds the chat request payload for qwen models
func (c *RemoteAPIChat) buildQwenChatCompletionRequest(messages []Message,
opts *ChatOptions, isStream bool,
) QwenChatCompletionRequest {
req := QwenChatCompletionRequest{
ChatCompletionRequest: c.buildChatCompletionRequest(messages, opts, isStream),
}
// For qwen models, force enable_thinking: false on non-streaming calls
if !isStream {
enableThinking := false
req.EnableThinking = &enableThinking
}
return req
}
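// Design note: EnableThinking is a *bool tagged omitempty, so the key is
// serialized only when explicitly set; the non-streaming path above pins it
// to false, while requests that never set it omit the key entirely and keep
// DashScope's default behavior.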
// buildChatCompletionRequest builds the chat request payload
func (c *RemoteAPIChat) buildChatCompletionRequest(messages []Message,
opts *ChatOptions, isStream bool,
@@ -71,11 +106,6 @@ func (c *RemoteAPIChat) buildChatCompletionRequest(messages []Message,
if opts.PresencePenalty > 0 {
req.PresencePenalty = float32(opts.PresencePenalty)
}
if opts.Thinking != nil {
req.ChatTemplateKwargs = map[string]any{
"enable_thinking": *opts.Thinking,
}
}
}
return req
@@ -83,6 +113,11 @@ func (c *RemoteAPIChat) buildChatCompletionRequest(messages []Message,
// Chat performs a non-streaming chat completion
func (c *RemoteAPIChat) Chat(ctx context.Context, messages []Message, opts *ChatOptions) (*types.ChatResponse, error) {
// For Aliyun qwen3 models, use the custom request path
if c.isAliyunQwen3Model() {
return c.chatWithQwen(ctx, messages, opts)
}
// Build the request payload
req := c.buildChatCompletionRequest(messages, opts, false)
@@ -111,6 +146,68 @@ func (c *RemoteAPIChat) Chat(ctx context.Context, messages []Message, opts *Chat
}, nil
}
// chatWithQwen handles qwen models via a hand-built HTTP request
func (c *RemoteAPIChat) chatWithQwen(ctx context.Context, messages []Message, opts *ChatOptions) (*types.ChatResponse, error) {
// Build the qwen request payload
req := c.buildQwenChatCompletionRequest(messages, opts, false)
// Serialize the request
jsonData, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("marshal request: %w", err)
}
// Build the endpoint URL
endpoint := c.baseURL + "/chat/completions"
// Create the HTTP request
httpReq, err := http.NewRequestWithContext(ctx, "POST", endpoint, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
// Set the request headers
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Authorization", "Bearer "+c.apiKey)
// Send the request
client := &http.Client{}
resp, err := client.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("send request: %w", err)
}
defer resp.Body.Close()
// Check the response status
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API request failed with status: %d", resp.StatusCode)
}
// Parse the response
var chatResp openai.ChatCompletionResponse
if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
return nil, fmt.Errorf("decode response: %w", err)
}
if len(chatResp.Choices) == 0 {
return nil, fmt.Errorf("no response from API")
}
// Convert to the internal response format
return &types.ChatResponse{
Content: chatResp.Choices[0].Message.Content,
Usage: struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}{
PromptTokens: chatResp.Usage.PromptTokens,
CompletionTokens: chatResp.Usage.CompletionTokens,
TotalTokens: chatResp.Usage.TotalTokens,
},
}, nil
}
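// Usage sketch (illustrative values mirroring the tests below): a qwen3 model
// pointed at the DashScope compatible-mode endpoint is routed through
// chatWithQwen automatically by Chat:
//
//	chat, _ := NewRemoteAPIChat(&ChatConfig{
//	    Source:    types.ModelSourceRemote,
//	    BaseURL:   "https://dashscope.aliyuncs.com/compatible-mode/v1",
//	    ModelName: "qwen3-32b",
//	    APIKey:    os.Getenv("ALIYUN_API_KEY"),
//	    ModelID:   "qwen3-32b",
//	})
//	resp, err := chat.Chat(ctx, messages, opts)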
// ChatStream performs a streaming chat completion
func (c *RemoteAPIChat) ChatStream(ctx context.Context,
messages []Message, opts *ChatOptions,


@@ -0,0 +1,127 @@
package chat
import (
"context"
"os"
"testing"
"time"
"github.com/Tencent/WeKnora/internal/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestRemoteAPIChat exercises Remote API Chat functionality across providers
func TestRemoteAPIChat(t *testing.T) {
// Read API keys from the environment
deepseekAPIKey := os.Getenv("DEEPSEEK_API_KEY")
aliyunAPIKey := os.Getenv("ALIYUN_API_KEY")
// Define the test configurations
testConfigs := []struct {
name string
apiKey string
config *ChatConfig
skipMsg string
}{
{
name: "DeepSeek API",
apiKey: deepseekAPIKey,
config: &ChatConfig{
Source: types.ModelSourceRemote,
BaseURL: "https://api.deepseek.com/v1",
ModelName: "deepseek-chat",
APIKey: deepseekAPIKey,
ModelID: "deepseek-chat",
},
skipMsg: "DEEPSEEK_API_KEY environment variable not set",
},
{
name: "Aliyun DeepSeek",
apiKey: aliyunAPIKey,
config: &ChatConfig{
Source: types.ModelSourceRemote,
BaseURL: "https://dashscope.aliyuncs.com/compatible-mode/v1",
ModelName: "deepseek-v3.1",
APIKey: aliyunAPIKey,
ModelID: "deepseek-v3.1",
},
skipMsg: "ALIYUN_API_KEY environment variable not set",
},
{
name: "Aliyun Qwen3-32b",
apiKey: aliyunAPIKey,
config: &ChatConfig{
Source: types.ModelSourceRemote,
BaseURL: "https://dashscope.aliyuncs.com/compatible-mode/v1",
ModelName: "qwen3-32b",
APIKey: aliyunAPIKey,
ModelID: "qwen3-32b",
},
skipMsg: "ALIYUN_API_KEY environment variable not set",
},
{
name: "Aliyun Qwen-max",
apiKey: aliyunAPIKey,
config: &ChatConfig{
Source: types.ModelSourceRemote,
BaseURL: "https://dashscope.aliyuncs.com/compatible-mode/v1",
ModelName: "qwen-max",
APIKey: aliyunAPIKey,
ModelID: "qwen-max",
},
skipMsg: "ALIYUN_API_KEY environment variable not set",
},
}
// Test messages
testMessages := []Message{
{
Role: "user",
Content: "test",
},
}
// Test options
testOptions := &ChatOptions{
Temperature: 0.7,
MaxTokens: 100,
}
// Create a context with a timeout
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
// Run the tests for every configuration
for _, tc := range testConfigs {
t.Run(tc.name, func(t *testing.T) {
// Skip when the API key is not set
if tc.apiKey == "" {
t.Skip(tc.skipMsg)
}
// Create the chat instance
chat, err := NewRemoteAPIChat(tc.config)
require.NoError(t, err)
assert.Equal(t, tc.config.ModelName, chat.GetModelName())
assert.Equal(t, tc.config.ModelID, chat.GetModelID())
// Test basic chat functionality
t.Run("Basic Chat", func(t *testing.T) {
response, err := chat.Chat(ctx, testMessages, testOptions)
require.NoError(t, err)
assert.NotEmpty(t, response.Content)
assert.Greater(t, response.Usage.TotalTokens, 0)
assert.Greater(t, response.Usage.PromptTokens, 0)
assert.Greater(t, response.Usage.CompletionTokens, 0)
t.Logf("%s Response: %s", tc.name, response.Content)
t.Logf("Usage: Prompt=%d, Completion=%d, Total=%d",
response.Usage.PromptTokens,
response.Usage.CompletionTokens,
response.Usage.TotalTokens)
})
})
}
}
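These tests are skipped unless DEEPSEEK_API_KEY or ALIYUN_API_KEY is set, so a plain go test stays green without credentials; with a key exported, go test -v -run TestRemoteAPIChat in this package exercises the corresponding providers, including the qwen3 non-streaming path.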