feat: Get Embedding dimension from backend

This commit is contained in:
wizardchen
2025-08-15 10:34:24 +08:00
committed by lyingbug
parent 09d038eeb7
commit dacdad4dac
4 changed files with 207 additions and 21 deletions

View File

@@ -16,6 +16,7 @@ import (
"github.com/Tencent/WeKnora/internal/config"
"github.com/Tencent/WeKnora/internal/errors"
"github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/models/embedding"
"github.com/Tencent/WeKnora/internal/models/utils/ollama"
"github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces"
@@ -1273,6 +1274,66 @@ func (h *InitializationHandler) CheckRemoteModel(c *gin.Context) {
})
}
// TestEmbeddingModel checks whether an embedding endpoint (local or remote)
// is usable by embedding a short sample text, and reports the dimension of
// the vector that comes back.
//
// The JSON request body requires "source" and "modelName"; "baseUrl",
// "apiKey" and "dimension" are optional. A body that fails binding is
// registered on the gin context as a bad-request error. Every other outcome
// is answered with HTTP 200 and "success": true; the actual probe result is
// carried in data.available, data.message and data.dimension.
func (h *InitializationHandler) TestEmbeddingModel(c *gin.Context) {
	ctx := c.Request.Context()
	logger.Info(ctx, "Testing embedding model connectivity and functionality")

	var req struct {
		Source    string `json:"source" binding:"required"`
		ModelName string `json:"modelName" binding:"required"`
		BaseURL   string `json:"baseUrl"`
		APIKey    string `json:"apiKey"`
		Dimension int    `json:"dimension"`
	}
	if err := c.ShouldBindJSON(&req); err != nil {
		logger.Error(ctx, "Failed to parse embedding test request", err)
		c.Error(errors.NewBadRequestError(err.Error()))
		return
	}

	// Build the embedder configuration from the request; the source is
	// lower-cased before being converted to a ModelSource.
	cfg := embedding.Config{
		Source:               types.ModelSource(strings.ToLower(req.Source)),
		BaseURL:              req.BaseURL,
		ModelName:            req.ModelName,
		APIKey:               req.APIKey,
		TruncatePromptTokens: 256,
		Dimensions:           req.Dimension,
		ModelID:              "",
	}

	emb, err := embedding.NewEmbedder(cfg)
	if err != nil {
		logger.ErrorWithFields(ctx, err, map[string]interface{}{"model": req.ModelName})
		writeEmbeddingTestResult(c, false, fmt.Sprintf("创建Embedder失败: %v", err), 0)
		return
	}

	// Run one minimal embedding call to prove the endpoint actually works.
	vec, err := emb.Embed(ctx, "hello")
	if err == nil && len(vec) == 0 {
		// An empty vector carries no dimension information; reporting it as
		// "available" with dimension 0 would mislead the caller, so treat it
		// as a failed call.
		err = fmt.Errorf("embedding backend returned an empty vector")
	}
	if err != nil {
		logger.ErrorWithFields(ctx, err, map[string]interface{}{"model": req.ModelName})
		writeEmbeddingTestResult(c, false, fmt.Sprintf("调用Embedding失败: %v", err), 0)
		return
	}

	dim := len(vec)
	logger.Infof(ctx, "Embedding test succeeded, dim=%d", dim)
	writeEmbeddingTestResult(c, true, fmt.Sprintf("测试成功,向量维度=%d", dim), dim)
}

// writeEmbeddingTestResult sends the probe outcome using the response
// envelope of TestEmbeddingModel: HTTP 200, "success": true, and the result
// fields nested under "data".
func writeEmbeddingTestResult(c *gin.Context, available bool, message string, dimension int) {
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"data": gin.H{
			"available": available,
			"message":   message,
			"dimension": dimension,
		},
	})
}
// checkRemoteModelConnection 检查远程模型连接的内部方法
func (h *InitializationHandler) checkRemoteModelConnection(ctx context.Context,
model *types.Model,
@@ -1318,7 +1379,7 @@ func (h *InitializationHandler) checkRemoteModelConnection(ctx context.Context,
// 检查响应状态
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
// 连接成功,现在检查模型是否存在
return h.checkModelExistence(ctx, resp, model.Name)
return true, "连接正常,请自行确保模型存在"
} else if resp.StatusCode == 401 {
return false, "认证失败请检查API Key"
} else if resp.StatusCode == 403 {

View File

@@ -78,6 +78,7 @@ func NewRouter(params RouterParams) *gin.Engine {
// 远程API相关接口不需要认证
r.POST("/api/v1/initialization/remote/check", params.InitializationHandler.CheckRemoteModel)
r.POST("/api/v1/initialization/embedding/test", params.InitializationHandler.TestEmbeddingModel)
r.POST("/api/v1/initialization/rerank/check", params.InitializationHandler.CheckRerankModel)
r.POST("/api/v1/initialization/multimodal/test", params.InitializationHandler.TestMultimodalFunction)