mirror of
https://github.com/Tencent/WeKnora.git
synced 2025-11-25 03:15:00 +08:00
feat: Get Embedding dimension from backend
This commit is contained in:
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/Tencent/WeKnora/internal/config"
|
||||
"github.com/Tencent/WeKnora/internal/errors"
|
||||
"github.com/Tencent/WeKnora/internal/logger"
|
||||
"github.com/Tencent/WeKnora/internal/models/embedding"
|
||||
"github.com/Tencent/WeKnora/internal/models/utils/ollama"
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
||||
@@ -1273,6 +1274,66 @@ func (h *InitializationHandler) CheckRemoteModel(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// TestEmbeddingModel 测试 Embedding 接口(本地或远程)是否可用
|
||||
func (h *InitializationHandler) TestEmbeddingModel(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
logger.Info(ctx, "Testing embedding model connectivity and functionality")
|
||||
|
||||
var req struct {
|
||||
Source string `json:"source" binding:"required"`
|
||||
ModelName string `json:"modelName" binding:"required"`
|
||||
BaseURL string `json:"baseUrl"`
|
||||
APIKey string `json:"apiKey"`
|
||||
Dimension int `json:"dimension"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
logger.Error(ctx, "Failed to parse embedding test request", err)
|
||||
c.Error(errors.NewBadRequestError(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// 构造 embedder 配置
|
||||
cfg := embedding.Config{
|
||||
Source: types.ModelSource(strings.ToLower(req.Source)),
|
||||
BaseURL: req.BaseURL,
|
||||
ModelName: req.ModelName,
|
||||
APIKey: req.APIKey,
|
||||
TruncatePromptTokens: 256,
|
||||
Dimensions: req.Dimension,
|
||||
ModelID: "",
|
||||
}
|
||||
|
||||
emb, err := embedding.NewEmbedder(cfg)
|
||||
if err != nil {
|
||||
logger.ErrorWithFields(ctx, err, map[string]interface{}{"model": req.ModelName})
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"data": gin.H{`available`: false, `message`: fmt.Sprintf("创建Embedder失败: %v", err), `dimension`: 0},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 执行一次最小化 embedding 调用
|
||||
sample := "hello"
|
||||
vec, err := emb.Embed(ctx, sample)
|
||||
if err != nil {
|
||||
logger.ErrorWithFields(ctx, err, map[string]interface{}{"model": req.ModelName})
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"data": gin.H{`available`: false, `message`: fmt.Sprintf("调用Embedding失败: %v", err), `dimension`: 0},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
logger.Infof(ctx, "Embedding test succeeded, dim=%d", len(vec))
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"data": gin.H{`available`: true, `message`: fmt.Sprintf("测试成功,向量维度=%d", len(vec)), `dimension`: len(vec)},
|
||||
})
|
||||
}
|
||||
|
||||
// checkRemoteModelConnection 检查远程模型连接的内部方法
|
||||
func (h *InitializationHandler) checkRemoteModelConnection(ctx context.Context,
|
||||
model *types.Model,
|
||||
@@ -1318,7 +1379,7 @@ func (h *InitializationHandler) checkRemoteModelConnection(ctx context.Context,
|
||||
// 检查响应状态
|
||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||
// 连接成功,现在检查模型是否存在
|
||||
return h.checkModelExistence(ctx, resp, model.Name)
|
||||
return true, "连接正常,请自行确保模型存在"
|
||||
} else if resp.StatusCode == 401 {
|
||||
return false, "认证失败,请检查API Key"
|
||||
} else if resp.StatusCode == 403 {
|
||||
|
||||
@@ -78,6 +78,7 @@ func NewRouter(params RouterParams) *gin.Engine {
|
||||
|
||||
// 远程API相关接口(不需要认证)
|
||||
r.POST("/api/v1/initialization/remote/check", params.InitializationHandler.CheckRemoteModel)
|
||||
r.POST("/api/v1/initialization/embedding/test", params.InitializationHandler.TestEmbeddingModel)
|
||||
r.POST("/api/v1/initialization/rerank/check", params.InitializationHandler.CheckRerankModel)
|
||||
r.POST("/api/v1/initialization/multimodal/test", params.InitializationHandler.TestMultimodalFunction)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user