mirror of
https://github.com/OpenListTeam/OpenList.git
synced 2025-11-25 03:15:19 +08:00
fix(baidu_netdisk): Fix Baidu Netdisk resume uploads sticking to the same upload host (#1609)
Fix Baidu Netdisk resume uploads sticking to the same upload host
This commit is contained in:
@@ -7,6 +7,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
stdpath "path"
|
||||
@@ -35,11 +36,15 @@ type BaiduNetdisk struct {
|
||||
uploadThread int
|
||||
vipType int // 会员类型,0普通用户(4G/4M)、1普通会员(10G/16M)、2超级会员(20G/32M)
|
||||
|
||||
upClient *resty.Client // 上传文件使用的http客户端
|
||||
uploadUrlG singleflight.Group[string]
|
||||
uploadUrlMu sync.RWMutex
|
||||
uploadUrl string // 上传域名
|
||||
uploadUrlUpdateTime time.Time // 上传域名上次更新时间
|
||||
upClient *resty.Client // 上传文件使用的http客户端
|
||||
uploadUrlG singleflight.Group[string]
|
||||
uploadUrlMu sync.RWMutex
|
||||
uploadUrlCache map[string]uploadURLCacheEntry
|
||||
}
|
||||
|
||||
type uploadURLCacheEntry struct {
|
||||
url string
|
||||
updateTime time.Time
|
||||
}
|
||||
|
||||
var ErrUploadIDExpired = errors.New("uploadid expired")
|
||||
@@ -58,6 +63,7 @@ func (d *BaiduNetdisk) Init(ctx context.Context) error {
|
||||
SetRetryCount(UPLOAD_RETRY_COUNT).
|
||||
SetRetryWaitTime(UPLOAD_RETRY_WAIT_TIME).
|
||||
SetRetryMaxWaitTime(UPLOAD_RETRY_MAX_WAIT_TIME)
|
||||
d.uploadUrlCache = make(map[string]uploadURLCacheEntry)
|
||||
d.uploadThread, _ = strconv.Atoi(d.UploadThread)
|
||||
if d.uploadThread < 1 {
|
||||
d.uploadThread, d.UploadThread = 1, "1"
|
||||
@@ -298,12 +304,22 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
|
||||
return fileToObj(precreateResp.File), nil
|
||||
}
|
||||
}
|
||||
ensureUploadURL := func() {
|
||||
if precreateResp.UploadURL != "" {
|
||||
return
|
||||
}
|
||||
precreateResp.UploadURL = d.getUploadUrl(path, precreateResp.Uploadid)
|
||||
}
|
||||
ensureUploadURL()
|
||||
|
||||
// step.2 上传分片
|
||||
uploadLoop:
|
||||
for attempt := 0; attempt < 2; attempt++ {
|
||||
// 获取上传域名
|
||||
uploadUrl := d.getUploadUrl(path, precreateResp.Uploadid)
|
||||
if precreateResp.UploadURL == "" {
|
||||
ensureUploadURL()
|
||||
}
|
||||
uploadUrl := precreateResp.UploadURL
|
||||
// 并发上传
|
||||
threadG, upCtx := errgroup.NewGroupWithContext(ctx, d.uploadThread,
|
||||
retry.Attempts(1),
|
||||
@@ -363,6 +379,7 @@ uploadLoop:
|
||||
}
|
||||
if errors.Is(err, ErrUploadIDExpired) {
|
||||
log.Warn("[baidu_netdisk] uploadid expired, will restart from scratch")
|
||||
d.clearUploadUrlCache(precreateResp.Uploadid)
|
||||
// 重新 precreate(所有分片都要重传)
|
||||
newPre, err2 := d.precreate(ctx, path, streamSize, blockListStr, "", "", ctime, mtime)
|
||||
if err2 != nil {
|
||||
@@ -372,6 +389,8 @@ uploadLoop:
|
||||
return fileToObj(newPre.File), nil
|
||||
}
|
||||
precreateResp = newPre
|
||||
precreateResp.UploadURL = ""
|
||||
ensureUploadURL()
|
||||
// 覆盖掉旧的进度
|
||||
base.SaveUploadProgress(d, precreateResp, d.AccessToken, contentMd5)
|
||||
continue uploadLoop
|
||||
@@ -390,6 +409,7 @@ uploadLoop:
|
||||
newFile.Mtime = mtime
|
||||
// 上传成功清理进度
|
||||
base.SaveUploadProgress(d, nil, d.AccessToken, contentMd5)
|
||||
d.clearUploadUrlCache(precreateResp.Uploadid)
|
||||
return fileToObj(newFile), nil
|
||||
}
|
||||
|
||||
@@ -438,6 +458,9 @@ func (d *BaiduNetdisk) uploadSlice(ctx context.Context, uploadUrl string, params
|
||||
return err
|
||||
}
|
||||
log.Debugln(res.RawResponse.Status + res.String())
|
||||
if res.StatusCode() != http.StatusOK {
|
||||
return errs.NewErr(errs.StreamIncomplete, "baidu upload failed, status=%d, body=%s", res.StatusCode(), res.String())
|
||||
}
|
||||
errCode := utils.Json.Get(res.Body(), "error_code").ToInt()
|
||||
errNo := utils.Json.Get(res.Body(), "errno").ToInt()
|
||||
respStr := res.String()
|
||||
|
||||
@@ -193,6 +193,8 @@ type PrecreateResp struct {
|
||||
|
||||
// return_type=2
|
||||
File File `json:"info"`
|
||||
|
||||
UploadURL string `json:"-"` // 保存断点续传对应的上传域名
|
||||
}
|
||||
|
||||
type UploadServerResp struct {
|
||||
|
||||
@@ -394,29 +394,28 @@ func (d *BaiduNetdisk) quota(ctx context.Context) (model.DiskUsage, error) {
|
||||
return driver.DiskUsageFromUsedAndTotal(resp.Used, resp.Total), nil
|
||||
}
|
||||
|
||||
// getUploadUrl 从开放平台获取上传域名/地址,并发请求会被合并,结果会被缓存1h。
|
||||
// getUploadUrl 从开放平台获取上传域名/地址,并发请求会被合并,结果会在 uploadid 生命周期内复用。
|
||||
// 如果获取失败,则返回 Upload API设置项。
|
||||
func (d *BaiduNetdisk) getUploadUrl(path, uploadId string) string {
|
||||
if !d.UseDynamicUploadAPI {
|
||||
if !d.UseDynamicUploadAPI || uploadId == "" {
|
||||
return d.UploadAPI
|
||||
}
|
||||
getCachedUrlFunc := func() string {
|
||||
getCachedUrlFunc := func() (string, bool) {
|
||||
d.uploadUrlMu.RLock()
|
||||
defer d.uploadUrlMu.RUnlock()
|
||||
if d.uploadUrl != "" && time.Since(d.uploadUrlUpdateTime) < UPLOAD_URL_EXPIRE_TIME {
|
||||
uploadUrl := d.uploadUrl
|
||||
return uploadUrl
|
||||
if entry, ok := d.uploadUrlCache[uploadId]; ok {
|
||||
return entry.url, true
|
||||
}
|
||||
return ""
|
||||
return "", false
|
||||
}
|
||||
// 检查地址缓存
|
||||
if uploadUrl := getCachedUrlFunc(); uploadUrl != "" {
|
||||
if uploadUrl, ok := getCachedUrlFunc(); ok {
|
||||
return uploadUrl
|
||||
}
|
||||
|
||||
uploadUrlGetFunc := func() (string, error) {
|
||||
// 双重检查缓存
|
||||
if uploadUrl := getCachedUrlFunc(); uploadUrl != "" {
|
||||
if uploadUrl, ok := getCachedUrlFunc(); ok {
|
||||
return uploadUrl, nil
|
||||
}
|
||||
|
||||
@@ -426,13 +425,15 @@ func (d *BaiduNetdisk) getUploadUrl(path, uploadId string) string {
|
||||
}
|
||||
|
||||
d.uploadUrlMu.Lock()
|
||||
defer d.uploadUrlMu.Unlock()
|
||||
d.uploadUrl = uploadUrl
|
||||
d.uploadUrlUpdateTime = time.Now()
|
||||
d.uploadUrlCache[uploadId] = uploadURLCacheEntry{
|
||||
url: uploadUrl,
|
||||
updateTime: time.Now(),
|
||||
}
|
||||
d.uploadUrlMu.Unlock()
|
||||
return uploadUrl, nil
|
||||
}
|
||||
|
||||
uploadUrl, err, _ := d.uploadUrlG.Do("", uploadUrlGetFunc)
|
||||
uploadUrl, err, _ := d.uploadUrlG.Do(uploadId, uploadUrlGetFunc)
|
||||
if err != nil {
|
||||
fallback := d.UploadAPI
|
||||
log.Warnf("[baidu_netdisk] get upload URL failed (%v), will use fallback URL: %s", err, fallback)
|
||||
@@ -441,6 +442,17 @@ func (d *BaiduNetdisk) getUploadUrl(path, uploadId string) string {
|
||||
return uploadUrl
|
||||
}
|
||||
|
||||
func (d *BaiduNetdisk) clearUploadUrlCache(uploadId string) {
|
||||
if uploadId == "" {
|
||||
return
|
||||
}
|
||||
d.uploadUrlMu.Lock()
|
||||
if _, ok := d.uploadUrlCache[uploadId]; ok {
|
||||
delete(d.uploadUrlCache, uploadId)
|
||||
}
|
||||
d.uploadUrlMu.Unlock()
|
||||
}
|
||||
|
||||
// requestForUploadUrl 请求获取上传地址。
|
||||
// 实测此接口不需要认证,传method和upload_version就行,不过还是按文档规范调用。
|
||||
// https://pan.baidu.com/union/doc/Mlvw5hfnr
|
||||
|
||||
Reference in New Issue
Block a user