Mirror of https://github.com/Tencent/WeKnora.git

Commit: feat: support aliyun qwen3 model
go.mod (3 changed lines)

@@ -21,6 +21,7 @@ require (
 	github.com/sashabaranov/go-openai v1.40.5
 	github.com/sirupsen/logrus v1.9.3
 	github.com/spf13/viper v1.20.1
+	github.com/stretchr/testify v1.11.1
 	github.com/tencentyun/cos-go-sdk-v5 v0.7.65
 	github.com/yanyiwu/gojieba v1.4.5
 	go.opentelemetry.io/otel v1.37.0
@@ -45,6 +46,7 @@ require (
 	github.com/cespare/xxhash/v2 v2.3.0 // indirect
 	github.com/clbanning/mxj v1.8.4 // indirect
 	github.com/cloudwego/base64x v0.1.5 // indirect
+	github.com/davecgh/go-spew v1.1.1 // indirect
 	github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
 	github.com/dustin/go-humanize v1.0.1 // indirect
 	github.com/elastic/elastic-transport-go/v8 v8.7.0 // indirect
@@ -81,6 +83,7 @@ require (
 	github.com/olekukonko/tablewriter v0.0.5 // indirect
 	github.com/pelletier/go-toml/v2 v2.2.3 // indirect
 	github.com/pierrec/lz4/v4 v4.1.21 // indirect
+	github.com/pmezard/go-difflib v1.0.0 // indirect
 	github.com/rivo/uniseg v0.4.7 // indirect
 	github.com/robfig/cron/v3 v3.0.1 // indirect
 	github.com/rs/xid v1.6.0 // indirect
go.sum (3 changed lines)

@@ -192,8 +192,9 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
 github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
 github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
 github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
-github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
 github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
+github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
+github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
 github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
 github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
 github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.563/go.mod h1:7sCQWVkxcsR38nffDW057DRGk8mUjK1Ing/EFOK8s8Y=
Initialization handler (checkRemoteModelConnection):

@@ -1361,7 +1361,8 @@ func (h *InitializationHandler) checkRemoteModelConnection(ctx context.Context,
 			"content": "test",
 		},
 	},
 	"max_tokens": 1,
+	"enable_thinking": false, // for dashscope.aliyuncs qwen3-32b
 }

 jsonData, err := json.Marshal(testRequest)
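For reference, here is a minimal, self-contained sketch (Go) of the kind of connectivity-check payload this hunk produces. The map layout and the model name are illustrative assumptions, not the handler's actual variables; DashScope's OpenAI-compatible endpoint is reported to reject non-streaming qwen3 calls unless thinking is disabled, which is what the hard-coded "enable_thinking": false addresses.

package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	// Illustrative payload only; the real handler assembles its request map elsewhere.
	testRequest := map[string]interface{}{
		"model": "qwen3-32b", // hypothetical model name for the sketch
		"messages": []map[string]string{
			{"role": "user", "content": "test"},
		},
		"max_tokens": 1,
		// Force thinking off so the non-streaming probe is accepted by
		// dashscope.aliyuncs.com's compatible-mode endpoint.
		"enable_thinking": false,
	}

	jsonData, err := json.Marshal(testRequest)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(jsonData))
}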
@@ -1386,6 +1387,11 @@ func (h *InitializationHandler) checkRemoteModelConnection(ctx context.Context,
 	}
 	defer resp.Body.Close()

+	body, err := io.ReadAll(resp.Body)
+	if err == nil {
+		logger.Infof(ctx, "Response body: %s", string(body))
+	}
+
 	// Check the response status
 	if resp.StatusCode >= 200 && resp.StatusCode < 300 {
 		// Connection succeeded; the model is available
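Note that io.ReadAll consumes the body, so any downstream code that also wanted the raw bytes would see an empty reader. A small sketch of one way to log and then rewind the body, under the assumption that a later step still needs to parse it; the package and helper names are hypothetical, not part of the commit.

package handler

import (
	"bytes"
	"io"
	"log"
	"net/http"
)

// logAndRewindBody logs a response body for debugging and then restores it,
// so later readers (JSON decoding, model-list checks) still see the payload.
func logAndRewindBody(resp *http.Response) {
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return
	}
	log.Printf("Response body: %s", string(body))
	resp.Body = io.NopCloser(bytes.NewReader(body))
}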
@@ -1401,58 +1407,6 @@ func (h *InitializationHandler) checkRemoteModelConnection(ctx context.Context,
 	}
 }

-// checkModelExistence checks whether the given model exists in the model list
-func (h *InitializationHandler) checkModelExistence(ctx context.Context,
-	resp *http.Response, modelName string) (bool, string) {
-	body, err := io.ReadAll(resp.Body)
-	if err != nil {
-		return true, "连接正常,但无法验证模型列表"
-	}
-
-	var modelsResp struct {
-		Data []struct {
-			ID     string `json:"id"`
-			Object string `json:"object"`
-		} `json:"data"`
-		Object string `json:"object"`
-	}
-
-	// Try to parse the model-list response
-	if err := json.Unmarshal(body, &modelsResp); err != nil {
-		// If parsing fails this may be a non-standard API; treat a successful connection as usable
-		return true, "连接正常"
-	}
-
-	// Check whether the model is in the list
-	for _, model := range modelsResp.Data {
-		if model.ID == modelName {
-			return true, "连接正常,模型存在"
-		}
-	}
-
-	// The model is not in the list; return suggestions for available models
-	if len(modelsResp.Data) > 0 {
-		availableModels := make([]string, 0, min(3, len(modelsResp.Data)))
-		for i, model := range modelsResp.Data {
-			if i >= 3 {
-				break
-			}
-			availableModels = append(availableModels, model.ID)
-		}
-		return false, fmt.Sprintf("模型 '%s' 不存在,可用模型: %s", modelName, strings.Join(availableModels, ", "))
-	}
-
-	return false, fmt.Sprintf("模型 '%s' 不存在", modelName)
-}
-
-// min returns the minimum of two integers
-func min(a, b int) int {
-	if a < b {
-		return a
-	}
-	return b
-}
-
 // checkRerankModelConnection is the internal method that checks the rerank model's connection and functionality
 func (h *InitializationHandler) checkRerankModelConnection(ctx context.Context,
 	modelName, baseURL, apiKey string) (bool, string) {
Remote API chat (package chat):

@@ -1,8 +1,12 @@
 package chat

 import (
+	"bytes"
 	"context"
+	"encoding/json"
 	"fmt"
+	"net/http"
+	"strings"

 	"github.com/Tencent/WeKnora/internal/types"
 	"github.com/sashabaranov/go-openai"
@@ -13,6 +17,14 @@ type RemoteAPIChat struct {
 	modelName string
 	client    *openai.Client
 	modelID   string
+	baseURL   string
+	apiKey    string
+}
+
+// QwenChatCompletionRequest is a custom request struct for qwen models
+type QwenChatCompletionRequest struct {
+	openai.ChatCompletionRequest
+	EnableThinking *bool `json:"enable_thinking,omitempty"` // qwen-specific field
 }

 // NewRemoteAPIChat creates a remote API chat instance
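Because QwenChatCompletionRequest embeds openai.ChatCompletionRequest and adds a *bool with omitempty, the extra enable_thinking key is emitted only when the pointer is set. A minimal sketch of the resulting JSON, assuming only the Model and Messages fields of go-openai that the diff already relies on:

package main

import (
	"encoding/json"
	"fmt"

	"github.com/sashabaranov/go-openai"
)

// qwenRequest mirrors the QwenChatCompletionRequest shape from the diff.
type qwenRequest struct {
	openai.ChatCompletionRequest
	EnableThinking *bool `json:"enable_thinking,omitempty"`
}

func main() {
	enableThinking := false
	req := qwenRequest{
		ChatCompletionRequest: openai.ChatCompletionRequest{
			Model: "qwen3-32b",
			Messages: []openai.ChatCompletionMessage{
				{Role: openai.ChatMessageRoleUser, Content: "test"},
			},
		},
		EnableThinking: &enableThinking,
	}

	out, err := json.Marshal(req)
	if err != nil {
		panic(err)
	}
	// The embedded fields are promoted, and "enable_thinking":false appears
	// only because the pointer was set explicitly.
	fmt.Println(string(out))
}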
@@ -26,6 +38,8 @@ func NewRemoteAPIChat(chatConfig *ChatConfig) (*RemoteAPIChat, error) {
 		modelName: chatConfig.ModelName,
 		client:    openai.NewClientWithConfig(config),
 		modelID:   chatConfig.ModelID,
+		baseURL:   chatConfig.BaseURL,
+		apiKey:    apiKey,
 	}, nil
 }

@@ -41,6 +55,27 @@ func (c *RemoteAPIChat) convertMessages(messages []Message) []openai.ChatComplet
 	return openaiMessages
 }

+// isAliyunQwen3Model checks whether this is an Aliyun qwen3 model
+func (c *RemoteAPIChat) isAliyunQwen3Model() bool {
+	return strings.HasPrefix(c.modelName, "qwen3-") && c.baseURL == "https://dashscope.aliyuncs.com/compatible-mode/v1"
+}
+
+// buildQwenChatCompletionRequest builds the chat request parameters for qwen models
+func (c *RemoteAPIChat) buildQwenChatCompletionRequest(messages []Message,
+	opts *ChatOptions, isStream bool,
+) QwenChatCompletionRequest {
+	req := QwenChatCompletionRequest{
+		ChatCompletionRequest: c.buildChatCompletionRequest(messages, opts, isStream),
+	}
+
+	// For qwen models, force enable_thinking: false on non-streaming calls
+	if !isStream {
+		enableThinking := false
+		req.EnableThinking = &enableThinking
+	}
+	return req
+}
+
 // buildChatCompletionRequest builds the chat request parameters
 func (c *RemoteAPIChat) buildChatCompletionRequest(messages []Message,
 	opts *ChatOptions, isStream bool,
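A small, self-contained sketch of how narrowly the gate above matches: only qwen3-prefixed model names served from the DashScope compatible-mode base URL take the custom request path. The standalone function and the sample model names are illustrative, not part of the commit.

package main

import (
	"fmt"
	"strings"
)

// isAliyunQwen3 mirrors the check from the diff: a qwen3- prefixed model
// name on DashScope's OpenAI-compatible endpoint.
func isAliyunQwen3(modelName, baseURL string) bool {
	return strings.HasPrefix(modelName, "qwen3-") &&
		baseURL == "https://dashscope.aliyuncs.com/compatible-mode/v1"
}

func main() {
	dashscope := "https://dashscope.aliyuncs.com/compatible-mode/v1"
	cases := []struct {
		model, baseURL string
	}{
		{"qwen3-32b", dashscope},                 // true: custom request path
		{"qwen-max", dashscope},                  // false: not a qwen3- model
		{"qwen3-32b", "https://api.example.com"}, // false: different provider
	}
	for _, c := range cases {
		fmt.Printf("%-12s %-50s %v\n", c.model, c.baseURL, isAliyunQwen3(c.model, c.baseURL))
	}
}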
@@ -71,11 +106,6 @@ func (c *RemoteAPIChat) buildChatCompletionRequest(messages []Message,
 	if opts.PresencePenalty > 0 {
 		req.PresencePenalty = float32(opts.PresencePenalty)
 	}
-	if opts.Thinking != nil {
-		req.ChatTemplateKwargs = map[string]any{
-			"enable_thinking": *opts.Thinking,
-		}
-	}
 	}

 	return req
@@ -83,6 +113,11 @@ func (c *RemoteAPIChat) buildChatCompletionRequest(messages []Message,

 // Chat performs a non-streaming chat
 func (c *RemoteAPIChat) Chat(ctx context.Context, messages []Message, opts *ChatOptions) (*types.ChatResponse, error) {
+	// Use the custom request path for Aliyun qwen3 models
+	if c.isAliyunQwen3Model() {
+		return c.chatWithQwen(ctx, messages, opts)
+	}
+
 	// Build the request parameters
 	req := c.buildChatCompletionRequest(messages, opts, false)

@@ -111,6 +146,68 @@ func (c *RemoteAPIChat) Chat(ctx context.Context, messages []Message, opts *Chat
 	}, nil
 }

+// chatWithQwen handles qwen models with a custom request
+func (c *RemoteAPIChat) chatWithQwen(ctx context.Context, messages []Message, opts *ChatOptions) (*types.ChatResponse, error) {
+	// Build the qwen request parameters
+	req := c.buildQwenChatCompletionRequest(messages, opts, false)
+
+	// Serialize the request
+	jsonData, err := json.Marshal(req)
+	if err != nil {
+		return nil, fmt.Errorf("marshal request: %w", err)
+	}
+
+	// Build the URL
+	endpoint := c.baseURL + "/chat/completions"
+
+	// Create the HTTP request
+	httpReq, err := http.NewRequestWithContext(ctx, "POST", endpoint, bytes.NewBuffer(jsonData))
+	if err != nil {
+		return nil, fmt.Errorf("create request: %w", err)
+	}
+
+	// Set the request headers
+	httpReq.Header.Set("Content-Type", "application/json")
+	httpReq.Header.Set("Authorization", "Bearer "+c.apiKey)
+
+	// Send the request
+	client := &http.Client{}
+	resp, err := client.Do(httpReq)
+	if err != nil {
+		return nil, fmt.Errorf("send request: %w", err)
+	}
+	defer resp.Body.Close()
+
+	// Check the response status
+	if resp.StatusCode != http.StatusOK {
+		return nil, fmt.Errorf("API request failed with status: %d", resp.StatusCode)
+	}
+
+	// Parse the response
+	var chatResp openai.ChatCompletionResponse
+	if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
+		return nil, fmt.Errorf("decode response: %w", err)
+	}
+
+	if len(chatResp.Choices) == 0 {
+		return nil, fmt.Errorf("no response from API")
+	}
+
+	// Convert the response format
+	return &types.ChatResponse{
+		Content: chatResp.Choices[0].Message.Content,
+		Usage: struct {
+			PromptTokens     int `json:"prompt_tokens"`
+			CompletionTokens int `json:"completion_tokens"`
+			TotalTokens      int `json:"total_tokens"`
+		}{
+			PromptTokens:     chatResp.Usage.PromptTokens,
+			CompletionTokens: chatResp.Usage.CompletionTokens,
+			TotalTokens:      chatResp.Usage.TotalTokens,
+		},
+	}, nil
+}
+
 // ChatStream performs streaming chat
 func (c *RemoteAPIChat) ChatStream(ctx context.Context,
 	messages []Message, opts *ChatOptions,
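To illustrate the decode-and-map step at the end of chatWithQwen, here is a self-contained sketch (Go) that feeds a hand-written, OpenAI-style response body into openai.ChatCompletionResponse. The JSON sample is illustrative only, not a captured DashScope reply.

package main

import (
	"encoding/json"
	"fmt"
	"strings"

	"github.com/sashabaranov/go-openai"
)

func main() {
	// Hand-written sample in the OpenAI chat-completions response shape.
	sample := `{
		"choices": [{"index": 0, "message": {"role": "assistant", "content": "hello"}, "finish_reason": "stop"}],
		"usage": {"prompt_tokens": 5, "completion_tokens": 2, "total_tokens": 7}
	}`

	var chatResp openai.ChatCompletionResponse
	if err := json.NewDecoder(strings.NewReader(sample)).Decode(&chatResp); err != nil {
		panic(err)
	}
	if len(chatResp.Choices) == 0 {
		panic("no choices in response")
	}
	fmt.Println(chatResp.Choices[0].Message.Content)                      // hello
	fmt.Println(chatResp.Usage.PromptTokens, chatResp.Usage.TotalTokens)  // 5 7
}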
internal/models/chat/remote_api_test.go (new file, 127 lines)

@@ -0,0 +1,127 @@
+package chat
+
+import (
+	"context"
+	"os"
+	"testing"
+	"time"
+
+	"github.com/Tencent/WeKnora/internal/types"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+// TestRemoteAPIChat exercises the Remote API Chat functionality end to end
+func TestRemoteAPIChat(t *testing.T) {
+	// Read the environment variables
+	deepseekAPIKey := os.Getenv("DEEPSEEK_API_KEY")
+	aliyunAPIKey := os.Getenv("ALIYUN_API_KEY")
+
+	// Define the test configurations
+	testConfigs := []struct {
+		name    string
+		apiKey  string
+		config  *ChatConfig
+		skipMsg string
+	}{
+		{
+			name:   "DeepSeek API",
+			apiKey: deepseekAPIKey,
+			config: &ChatConfig{
+				Source:    types.ModelSourceRemote,
+				BaseURL:   "https://api.deepseek.com/v1",
+				ModelName: "deepseek-chat",
+				APIKey:    deepseekAPIKey,
+				ModelID:   "deepseek-chat",
+			},
+			skipMsg: "DEEPSEEK_API_KEY environment variable not set",
+		},
+		{
+			name:   "Aliyun DeepSeek",
+			apiKey: aliyunAPIKey,
+			config: &ChatConfig{
+				Source:    types.ModelSourceRemote,
+				BaseURL:   "https://dashscope.aliyuncs.com/compatible-mode/v1",
+				ModelName: "deepseek-v3.1",
+				APIKey:    aliyunAPIKey,
+				ModelID:   "deepseek-v3.1",
+			},
+			skipMsg: "ALIYUN_API_KEY environment variable not set",
+		},
+		{
+			name:   "Aliyun Qwen3-32b",
+			apiKey: aliyunAPIKey,
+			config: &ChatConfig{
+				Source:    types.ModelSourceRemote,
+				BaseURL:   "https://dashscope.aliyuncs.com/compatible-mode/v1",
+				ModelName: "qwen3-32b",
+				APIKey:    aliyunAPIKey,
+				ModelID:   "qwen3-32b",
+			},
+			skipMsg: "ALIYUN_API_KEY environment variable not set",
+		},
+		{
+			name:   "Aliyun Qwen-max",
+			apiKey: aliyunAPIKey,
+			config: &ChatConfig{
+				Source:    types.ModelSourceRemote,
+				BaseURL:   "https://dashscope.aliyuncs.com/compatible-mode/v1",
+				ModelName: "qwen-max",
+				APIKey:    aliyunAPIKey,
+				ModelID:   "qwen-max",
+			},
+			skipMsg: "ALIYUN_API_KEY environment variable not set",
+		},
+	}
+
+	// Test messages
+	testMessages := []Message{
+		{
+			Role:    "user",
+			Content: "test",
+		},
+	}
+
+	// Test options
+	testOptions := &ChatOptions{
+		Temperature: 0.7,
+		MaxTokens:   100,
+	}
+
+	// Create the context
+	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+	defer cancel()
+
+	// Run the test for every configuration
+	for _, tc := range testConfigs {
+		t.Run(tc.name, func(t *testing.T) {
+			// Check the API key
+			if tc.apiKey == "" {
+				t.Skip(tc.skipMsg)
+			}
+
+			// Create the chat instance
+			chat, err := NewRemoteAPIChat(tc.config)
+			require.NoError(t, err)
+			assert.Equal(t, tc.config.ModelName, chat.GetModelName())
+			assert.Equal(t, tc.config.ModelID, chat.GetModelID())
+
+			// Test basic chat functionality
+			t.Run("Basic Chat", func(t *testing.T) {
+				response, err := chat.Chat(ctx, testMessages, testOptions)
+				require.NoError(t, err)
+				assert.NotEmpty(t, response.Content)
+				assert.Greater(t, response.Usage.TotalTokens, 0)
+				assert.Greater(t, response.Usage.PromptTokens, 0)
+				assert.Greater(t, response.Usage.CompletionTokens, 0)
+
+				t.Logf("%s Response: %s", tc.name, response.Content)
+				t.Logf("Usage: Prompt=%d, Completion=%d, Total=%d",
+					response.Usage.PromptTokens,
+					response.Usage.CompletionTokens,
+					response.Usage.TotalTokens)
+			})
+
+		})
+	}
+}