fix(baidu_netdisk): Fix Baidu Netdisk resume uploads sticking to the same upload host (#1609)

jenfonro
2025-11-09 20:43:02 +08:00
committed by GitHub
parent f2e0fe8589
commit 7d78944d14
3 changed files with 56 additions and 19 deletions
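
At a glance, the fix replaces the driver-wide cached upload URL (one host shared by every transfer and refreshed on a timer) with a cache keyed by uploadid, so a resumed transfer keeps the host it started with and a new session can resolve a fresh one. A minimal sketch of that idea, detached from the driver and using illustrative names that are not the repository's:

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// entry remembers the upload host chosen for one uploadid.
type entry struct {
	url       string
	updatedAt time.Time
}

// urlCache is an illustrative per-uploadid cache: every resumable upload
// session keeps its own host, and the entry is dropped once the session
// finishes or its uploadid expires.
type urlCache struct {
	mu sync.RWMutex
	m  map[string]entry
}

func newURLCache() *urlCache { return &urlCache{m: make(map[string]entry)} }

func (c *urlCache) get(uploadID string) (string, bool) {
	c.mu.RLock()
	defer c.mu.RUnlock()
	e, ok := c.m[uploadID]
	return e.url, ok
}

func (c *urlCache) set(uploadID, url string) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.m[uploadID] = entry{url: url, updatedAt: time.Now()}
}

func (c *urlCache) clear(uploadID string) {
	c.mu.Lock()
	defer c.mu.Unlock()
	delete(c.m, uploadID)
}

func main() {
	c := newURLCache()
	c.set("upload-A", "https://host1.example.com")
	c.set("upload-B", "https://host2.example.com")
	if u, ok := c.get("upload-A"); ok {
		fmt.Println("upload-A keeps its own host:", u)
	}
	c.clear("upload-A") // upload finished or uploadid expired
	_, ok := c.get("upload-A")
	fmt.Println("upload-A entry gone after clear:", !ok)
}
```

In the diff below the map lives on the BaiduNetdisk struct, is guarded by uploadUrlMu, and entries are removed via clearUploadUrlCache when an upload succeeds or its uploadid expires.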


@@ -7,6 +7,7 @@ import (
     "errors"
     "fmt"
     "io"
+    "net/http"
     "net/url"
     "os"
     stdpath "path"
@@ -38,8 +39,12 @@ type BaiduNetdisk struct {
     upClient            *resty.Client // HTTP client used for uploads
     uploadUrlG          singleflight.Group[string]
     uploadUrlMu         sync.RWMutex
-    uploadUrl           string    // upload domain
-    uploadUrlUpdateTime time.Time // time the upload domain was last refreshed
+    uploadUrlCache      map[string]uploadURLCacheEntry
+}
+
+type uploadURLCacheEntry struct {
+    url        string
+    updateTime time.Time
 }

 var ErrUploadIDExpired = errors.New("uploadid expired")
@@ -58,6 +63,7 @@ func (d *BaiduNetdisk) Init(ctx context.Context) error {
         SetRetryCount(UPLOAD_RETRY_COUNT).
         SetRetryWaitTime(UPLOAD_RETRY_WAIT_TIME).
         SetRetryMaxWaitTime(UPLOAD_RETRY_MAX_WAIT_TIME)
+    d.uploadUrlCache = make(map[string]uploadURLCacheEntry)
     d.uploadThread, _ = strconv.Atoi(d.UploadThread)
     if d.uploadThread < 1 {
         d.uploadThread, d.UploadThread = 1, "1"
@@ -298,12 +304,22 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
             return fileToObj(precreateResp.File), nil
         }
     }
+    ensureUploadURL := func() {
+        if precreateResp.UploadURL != "" {
+            return
+        }
+        precreateResp.UploadURL = d.getUploadUrl(path, precreateResp.Uploadid)
+    }
+    ensureUploadURL()
     // step.2 upload the slices
 uploadLoop:
     for attempt := 0; attempt < 2; attempt++ {
         // get the upload domain
-        uploadUrl := d.getUploadUrl(path, precreateResp.Uploadid)
+        if precreateResp.UploadURL == "" {
+            ensureUploadURL()
+        }
+        uploadUrl := precreateResp.UploadURL
         // upload the slices concurrently
         threadG, upCtx := errgroup.NewGroupWithContext(ctx, d.uploadThread,
             retry.Attempts(1),
@@ -363,6 +379,7 @@ uploadLoop:
         }
         if errors.Is(err, ErrUploadIDExpired) {
             log.Warn("[baidu_netdisk] uploadid expired, will restart from scratch")
+            d.clearUploadUrlCache(precreateResp.Uploadid)
             // re-precreate: every slice has to be uploaded again
             newPre, err2 := d.precreate(ctx, path, streamSize, blockListStr, "", "", ctime, mtime)
             if err2 != nil {
@@ -372,6 +389,8 @@ uploadLoop:
                 return fileToObj(newPre.File), nil
             }
             precreateResp = newPre
+            precreateResp.UploadURL = ""
+            ensureUploadURL()
             // overwrite the old progress
             base.SaveUploadProgress(d, precreateResp, d.AccessToken, contentMd5)
             continue uploadLoop
@@ -390,6 +409,7 @@ uploadLoop:
     newFile.Mtime = mtime
     // upload succeeded, clear the saved progress
     base.SaveUploadProgress(d, nil, d.AccessToken, contentMd5)
+    d.clearUploadUrlCache(precreateResp.Uploadid)
     return fileToObj(newFile), nil
 }
@@ -438,6 +458,9 @@ func (d *BaiduNetdisk) uploadSlice(ctx context.Context, uploadUrl string, params
         return err
     }
     log.Debugln(res.RawResponse.Status + res.String())
+    if res.StatusCode() != http.StatusOK {
+        return errs.NewErr(errs.StreamIncomplete, "baidu upload failed, status=%d, body=%s", res.StatusCode(), res.String())
+    }
     errCode := utils.Json.Get(res.Body(), "error_code").ToInt()
     errNo := utils.Json.Get(res.Body(), "errno").ToInt()
     respStr := res.String()
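
The new guard in uploadSlice surfaces any non-200 response as an error through the project's errs helper, so the outer retry logic reacts instead of trying to parse an error body that may not be JSON. Outside this codebase, a rough standard-library analogue of the same check might look like this (the server, path, and helper name here are made up for the demo):

```go
package main

import (
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
)

// checkUploadStatus is an illustrative stand-in for the guard added to
// uploadSlice: any non-200 response becomes an error carrying the status
// and a bounded copy of the body, so the caller's retry logic can act on it.
func checkUploadStatus(resp *http.Response) error {
	if resp.StatusCode == http.StatusOK {
		return nil
	}
	body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
	return fmt.Errorf("upload failed, status=%d, body=%s", resp.StatusCode, body)
}

func main() {
	// A hypothetical upload host that rejects the slice with a 502.
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		http.Error(w, "bad gateway from upstream", http.StatusBadGateway)
	}))
	defer srv.Close()

	resp, err := http.Post(srv.URL+"/upload", "application/octet-stream", nil)
	if err != nil {
		fmt.Println("request error:", err)
		return
	}
	defer resp.Body.Close()
	fmt.Println(checkUploadStatus(resp)) // upload failed, status=502, body=...
}
```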


@@ -193,6 +193,8 @@ type PrecreateResp struct {
     // return_type=2
     File File `json:"info"`
+
+    UploadURL string `json:"-"` // upload domain saved for resuming this upload
 }

 type UploadServerResp struct {


@@ -394,29 +394,28 @@ func (d *BaiduNetdisk) quota(ctx context.Context) (model.DiskUsage, error) {
     return driver.DiskUsageFromUsedAndTotal(resp.Used, resp.Total), nil
 }

-// getUploadUrl fetches the upload domain/address from the open platform; concurrent requests are merged and the result is cached for 1h.
+// getUploadUrl fetches the upload domain/address from the open platform; concurrent requests are merged and the result is reused for the lifetime of the uploadid.
 // If it fails, the Upload API setting is returned instead.
 func (d *BaiduNetdisk) getUploadUrl(path, uploadId string) string {
-    if !d.UseDynamicUploadAPI {
+    if !d.UseDynamicUploadAPI || uploadId == "" {
         return d.UploadAPI
     }
-    getCachedUrlFunc := func() string {
+    getCachedUrlFunc := func() (string, bool) {
         d.uploadUrlMu.RLock()
         defer d.uploadUrlMu.RUnlock()
-        if d.uploadUrl != "" && time.Since(d.uploadUrlUpdateTime) < UPLOAD_URL_EXPIRE_TIME {
-            uploadUrl := d.uploadUrl
-            return uploadUrl
+        if entry, ok := d.uploadUrlCache[uploadId]; ok {
+            return entry.url, true
         }
-        return ""
+        return "", false
     }
     // check the URL cache
-    if uploadUrl := getCachedUrlFunc(); uploadUrl != "" {
+    if uploadUrl, ok := getCachedUrlFunc(); ok {
         return uploadUrl
     }
     uploadUrlGetFunc := func() (string, error) {
         // double-check the cache
-        if uploadUrl := getCachedUrlFunc(); uploadUrl != "" {
+        if uploadUrl, ok := getCachedUrlFunc(); ok {
             return uploadUrl, nil
         }
@@ -426,13 +425,15 @@ func (d *BaiduNetdisk) getUploadUrl(path, uploadId string) string {
         }
         d.uploadUrlMu.Lock()
-        defer d.uploadUrlMu.Unlock()
-        d.uploadUrl = uploadUrl
-        d.uploadUrlUpdateTime = time.Now()
+        d.uploadUrlCache[uploadId] = uploadURLCacheEntry{
+            url:        uploadUrl,
+            updateTime: time.Now(),
+        }
+        d.uploadUrlMu.Unlock()
         return uploadUrl, nil
     }
-    uploadUrl, err, _ := d.uploadUrlG.Do("", uploadUrlGetFunc)
+    uploadUrl, err, _ := d.uploadUrlG.Do(uploadId, uploadUrlGetFunc)
     if err != nil {
         fallback := d.UploadAPI
         log.Warnf("[baidu_netdisk] get upload URL failed (%v), will use fallback URL: %s", err, fallback)
@@ -441,6 +442,17 @@ func (d *BaiduNetdisk) getUploadUrl(path, uploadId string) string {
     return uploadUrl
 }

+func (d *BaiduNetdisk) clearUploadUrlCache(uploadId string) {
+    if uploadId == "" {
+        return
+    }
+    d.uploadUrlMu.Lock()
+    if _, ok := d.uploadUrlCache[uploadId]; ok {
+        delete(d.uploadUrlCache, uploadId)
+    }
+    d.uploadUrlMu.Unlock()
+}
+
 // requestForUploadUrl requests the upload address.
 // In practice this endpoint needs no auth; passing method and upload_version is enough, but we still call it the way the docs describe.
 // https://pan.baidu.com/union/doc/Mlvw5hfnr
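
Keying the singleflight group by uploadId (instead of the previous fixed "" key) keeps the merging behaviour within one upload session while letting different sessions resolve their own hosts. The repository uses its own generic singleflight wrapper; the effect can be sketched with golang.org/x/sync/singleflight and a made-up fetchHost standing in for the open-platform request:

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"

	"golang.org/x/sync/singleflight"
)

var fetchCount int64 // how many times the simulated host lookup actually ran

// fetchHost stands in for the open-platform request: it pretends to resolve
// an upload host for one uploadid.
func fetchHost(uploadID string) (string, error) {
	atomic.AddInt64(&fetchCount, 1)
	return "https://host-for-" + uploadID + ".example.com", nil // made-up host
}

func main() {
	var g singleflight.Group
	var wg sync.WaitGroup

	// Slices of the same upload session use the same key and therefore the
	// same host; a different uploadid resolves its own host. Overlapping
	// calls for one key are merged into a single fetch.
	for _, id := range []string{"upload-A", "upload-A", "upload-A", "upload-B"} {
		wg.Add(1)
		go func(id string) {
			defer wg.Done()
			v, _, _ := g.Do(id, func() (interface{}, error) {
				return fetchHost(id)
			})
			fmt.Println(id, "->", v)
		}(id)
	}
	wg.Wait()
	fmt.Println("host lookups performed:", atomic.LoadInt64(&fetchCount))
}
```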