// Mirror of https://github.com/Tencent/WeKnora.git
// (synced 2025-11-25 03:15:00 +08:00; 78 lines, 1.8 KiB, Go).

package rerank
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"github.com/Tencent/WeKnora/internal/types"
|
|
)
|
|
|
|
// Reranker defines the interface for document reranking
|
|
type Reranker interface {
|
|
// Rerank reranks documents based on relevance to the query
|
|
Rerank(ctx context.Context, query string, documents []string) ([]RankResult, error)
|
|
|
|
// GetModelName returns the model name
|
|
GetModelName() string
|
|
|
|
// GetModelID returns the model ID
|
|
GetModelID() string
|
|
}
|
|
|
|
type RankResult struct {
|
|
Index int `json:"index"`
|
|
Document DocumentInfo `json:"document"`
|
|
RelevanceScore float64 `json:"relevance_score"`
|
|
}
|
|
//Handles the RelevanceScore field by checking if RelevanceScore exists first, otherwise falls back to Score field
|
|
func (r *RankResult) UnmarshalJSON(data []byte) error {
|
|
|
|
var temp struct {
|
|
Index int `json:"index"`
|
|
Document DocumentInfo `json:"document"`
|
|
RelevanceScore *float64 `json:"relevance_score"`
|
|
Score *float64 `json:"score"`
|
|
}
|
|
|
|
|
|
if err := json.Unmarshal(data, &temp); err != nil {
|
|
return fmt.Errorf("failed to unmarshal rank result: %w", err)
|
|
}
|
|
|
|
r.Index = temp.Index
|
|
r.Document = temp.Document
|
|
|
|
if temp.RelevanceScore != nil {
|
|
r.RelevanceScore = *temp.RelevanceScore
|
|
} else if temp.Score != nil {
|
|
r.RelevanceScore = *temp.Score
|
|
}
|
|
|
|
|
|
return nil
|
|
}
|
|
|
|
// DocumentInfo is the document payload attached to a RankResult.
type DocumentInfo struct {
	// Text is the document's textual content.
	Text string `json:"text"`
}
|
|
|
|
type RerankerConfig struct {
|
|
APIKey string
|
|
BaseURL string
|
|
ModelName string
|
|
Source types.ModelSource
|
|
ModelID string
|
|
}
|
|
|
|
// NewReranker creates a reranker
|
|
func NewReranker(config *RerankerConfig) (Reranker, error) {
|
|
switch strings.ToLower(string(config.Source)) {
|
|
case string(types.ModelSourceRemote):
|
|
return NewOpenAIReranker(config)
|
|
default:
|
|
return nil, fmt.Errorf("unsupported rerank model source: %s", config.Source)
|
|
}
|
|
}
|