From 19c5e242bf700429286c5e46650df8ece92143da Mon Sep 17 00:00:00 2001
From: Liwx1014
Date: Thu, 28 Aug 2025 10:54:01 +0800
Subject: [PATCH] =?UTF-8?q?=E5=85=BC=E5=AE=B9=E8=A7=A3=E6=9E=90=E9=87=8D?=
 =?UTF-8?q?=E6=8E=92=E5=BA=8F=E6=9C=8D=E5=8A=A1=E8=BF=94=E5=9B=9Escore?=
 =?UTF-8?q?=E5=AD=97=E6=AE=B5?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 internal/models/rerank/reranker.go |  32 ++++++++
 rerank_server_demo.py              | 118 +++++++++++++++++++++++++++++
 2 files changed, 150 insertions(+)
 create mode 100644 rerank_server_demo.py

diff --git a/internal/models/rerank/reranker.go b/internal/models/rerank/reranker.go
index f62c66b..4a54211 100644
--- a/internal/models/rerank/reranker.go
+++ b/internal/models/rerank/reranker.go
@@ -2,9 +2,10 @@ package rerank
 
 import (
 	"context"
+	"encoding/json"
 	"fmt"
 	"strings"
 
 	"github.com/Tencent/WeKnora/internal/types"
 )
 
@@ -25,6 +26,37 @@ type RankResult struct {
 	Document       DocumentInfo `json:"document"`
 	RelevanceScore float64      `json:"relevance_score"`
 }
+
+// UnmarshalJSON implements json.Unmarshaler. Rerank services disagree on the
+// name of the score field, so the score is taken from "relevance_score" when
+// that field is present and from "score" otherwise.
+func (r *RankResult) UnmarshalJSON(data []byte) error {
+	// Pointers distinguish an absent (or null) field from an explicit 0 score.
+	var temp struct {
+		Index          int          `json:"index"`
+		Document       DocumentInfo `json:"document"`
+		RelevanceScore *float64     `json:"relevance_score"`
+		Score          *float64     `json:"score"`
+	}
+	if err := json.Unmarshal(data, &temp); err != nil {
+		return fmt.Errorf("failed to unmarshal rank result: %w", err)
+	}
+
+	r.Index = temp.Index
+	r.Document = temp.Document
+
+	switch {
+	case temp.RelevanceScore != nil:
+		r.RelevanceScore = *temp.RelevanceScore
+	case temp.Score != nil:
+		r.RelevanceScore = *temp.Score
+	default:
+		// Neither field was sent: reset the score, as Index and Document are
+		// reset above, so a reused receiver cannot keep a stale value.
+		r.RelevanceScore = 0
+	}
+	return nil
+}
 
 type DocumentInfo struct {
 	Text string `json:"text"`
diff --git a/rerank_server_demo.py b/rerank_server_demo.py
new file mode 100644
index 0000000..866acf9
--- /dev/null
+++ b/rerank_server_demo.py
@@ -0,0 +1,118 @@
+"""Demo rerank server whose results carry "score", not "relevance_score".
+
+Exists to test that the Go rerank client accepts either field name.
+"""
+
+import os
+from typing import List
+
+import torch
+import uvicorn
+from fastapi import FastAPI
+from pydantic import BaseModel, Field
+from transformers import AutoModelForSequenceClassification, AutoTokenizer
+
+# Default location of the reranker weights; override with RERANK_MODEL_PATH.
+DEFAULT_MODEL_PATH = "/data1/home/lwx/work/Download/rerank_model_weight"
+
+
+# --- 1. Request and response schemas ---
+
+
+class RerankRequest(BaseModel):
+    """Request body: a query and the candidate documents to score."""
+
+    query: str
+    documents: List[str]
+
+
+class DocumentInfo(BaseModel):
+    """Document text, nested under "document" in each result."""
+
+    text: str
+
+
+class TestRankResult(BaseModel):
+    """One scored document."""
+
+    index: int  # position of the document in the request's list
+    document: DocumentInfo
+    score: float  # deliberately "score" rather than "relevance_score"
+
+
+class TestFinalResponse(BaseModel):
+    """Response body wrapping the list of scored documents."""
+
+    results: List[TestRankResult]
+
+
+# --- 2. Load the model once, at startup ---
+print("正在加载模型,请稍候...")
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+print(f"使用的设备: {device}")
+try:
+    model_path = os.environ.get("RERANK_MODEL_PATH", DEFAULT_MODEL_PATH)
+    tokenizer = AutoTokenizer.from_pretrained(model_path)
+    model = AutoModelForSequenceClassification.from_pretrained(model_path)
+    model.to(device)
+    model.eval()
+    print("模型加载成功!")
+except Exception as e:
+    print(f"模型加载失败: {e}")
+    # A server without a model is useless, so stop here. Exit with a non-zero
+    # status: the bare exit() used before reported success (status 0).
+    raise SystemExit(1) from e
+
+# --- 3. FastAPI application ---
+app = FastAPI(
+    title="Reranker API (Test Version)",
+    description="一个返回 'score' 字段以测试Go客户端兼容性的API服务",
+    version="1.0.1",
+)
+
+
+# --- 4. Endpoints ---
+@app.post("/rerank", response_model=TestFinalResponse)
+def rerank_endpoint(request: RerankRequest):
+    """Score each document against the query and return them best-first.
+
+    The returned dict is validated and serialized by FastAPI according to
+    TestFinalResponse, giving
+    {"results": [{"index": ..., "document": {"text": ...}, "score": ...}]}.
+    """
+    # Nothing to score: answer directly instead of sending an empty batch
+    # through the tokenizer and model.
+    if not request.documents:
+        return {"results": []}
+
+    pairs = [[request.query, doc] for doc in request.documents]
+    with torch.no_grad():
+        inputs = tokenizer(
+            pairs,
+            padding=True,
+            truncation=True,
+            return_tensors="pt",
+            max_length=1024,
+        ).to(device)
+        logits = model(**inputs, return_dict=True).logits
+        # Convert all scores in one call rather than .item() per document.
+        scores = logits.view(-1).float().tolist()
+
+    results = [
+        TestRankResult(index=i, document=DocumentInfo(text=text), score=score)
+        for i, (text, score) in enumerate(zip(request.documents, scores))
+    ]
+    # Each result keeps its request position in "index", so reordering is safe.
+    results.sort(key=lambda result: result.score, reverse=True)
+    return {"results": results}
+
+
+@app.get("/")
+def read_root():
+    """Report that the service is up."""
+    return {"status": "Reranker API (Test Version) is running"}
+
+
+# --- 5. Start the server when run as a script ---
+if __name__ == "__main__":
+    uvicorn.run(app, host="0.0.0.0", port=8000)