mirror of
https://github.com/OpenListTeam/OpenList.git
synced 2025-11-25 03:15:19 +08:00
fix(baidu_netdisk): support resuming uploads when an error occurs (#1279)
support resuming uploads when an error occurs
This commit is contained in:
@@ -5,11 +5,13 @@ import (
|
||||
"crypto/md5"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
"os"
|
||||
stdpath "path"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/OpenListTeam/OpenList/v4/drivers/base"
|
||||
@@ -31,6 +33,8 @@ type BaiduNetdisk struct {
|
||||
vipType int // 会员类型,0普通用户(4G/4M)、1普通会员(10G/16M)、2超级会员(20G/32M)
|
||||
}
|
||||
|
||||
var ErrUploadIDExpired = errors.New("uploadid expired")
|
||||
|
||||
func (d *BaiduNetdisk) Config() driver.Config {
|
||||
return config
|
||||
}
|
||||
@@ -214,7 +218,6 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
|
||||
|
||||
// cal md5 for first 256k data
|
||||
const SliceSize int64 = 256 * utils.KB
|
||||
// cal md5
|
||||
blockList := make([]string, 0, count)
|
||||
byteSize := sliceSize
|
||||
fileMd5H := md5.New()
|
||||
@@ -244,7 +247,7 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
|
||||
}
|
||||
if tmpF != nil {
|
||||
if written != streamSize {
|
||||
return nil, errs.NewErr(err, "CreateTempFile failed, incoming stream actual size= %d, expect = %d ", written, streamSize)
|
||||
return nil, errs.NewErr(err, "CreateTempFile failed, size mismatch: %d != %d ", written, streamSize)
|
||||
}
|
||||
_, err = tmpF.Seek(0, io.SeekStart)
|
||||
if err != nil {
|
||||
@@ -258,31 +261,14 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
|
||||
mtime := stream.ModTime().Unix()
|
||||
ctime := stream.CreateTime().Unix()
|
||||
|
||||
// step.1 预上传
|
||||
// 尝试获取之前的进度
|
||||
// step.1 尝试读取已保存进度
|
||||
precreateResp, ok := base.GetUploadProgress[*PrecreateResp](d, d.AccessToken, contentMd5)
|
||||
if !ok {
|
||||
params := map[string]string{
|
||||
"method": "precreate",
|
||||
}
|
||||
form := map[string]string{
|
||||
"path": path,
|
||||
"size": strconv.FormatInt(streamSize, 10),
|
||||
"isdir": "0",
|
||||
"autoinit": "1",
|
||||
"rtype": "3",
|
||||
"block_list": blockListStr,
|
||||
"content-md5": contentMd5,
|
||||
"slice-md5": sliceMd5,
|
||||
}
|
||||
joinTime(form, ctime, mtime)
|
||||
|
||||
log.Debugf("[baidu_netdisk] precreate data: %s", form)
|
||||
_, err = d.postForm("/xpan/file", params, form, &precreateResp)
|
||||
// 没有进度,走预上传
|
||||
precreateResp, err = d.precreate(ctx, path, streamSize, blockListStr, contentMd5, sliceMd5, ctime, mtime)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
log.Debugf("%+v", precreateResp)
|
||||
if precreateResp.ReturnType == 2 {
|
||||
// rapid upload, since got md5 match from baidu server
|
||||
// 修复时间,具体原因见 Put 方法注释的 **注意**
|
||||
@@ -291,45 +277,81 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
|
||||
return fileToObj(precreateResp.File), nil
|
||||
}
|
||||
}
|
||||
|
||||
// step.2 上传分片
|
||||
threadG, upCtx := errgroup.NewGroupWithContext(ctx, d.uploadThread,
|
||||
retry.Attempts(1),
|
||||
retry.Delay(time.Second),
|
||||
retry.DelayType(retry.BackOffDelay))
|
||||
uploadLoop:
|
||||
for attempt := 0; attempt < 2; attempt++ {
|
||||
threadG, upCtx := errgroup.NewGroupWithContext(ctx, d.uploadThread,
|
||||
retry.Attempts(1),
|
||||
retry.Delay(time.Second),
|
||||
retry.DelayType(retry.BackOffDelay))
|
||||
|
||||
for i, partseq := range precreateResp.BlockList {
|
||||
if utils.IsCanceled(upCtx) {
|
||||
break
|
||||
cacheReaderAt, okReaderAt := cache.(io.ReaderAt)
|
||||
if !okReaderAt {
|
||||
return nil, fmt.Errorf("cache object must implement io.ReaderAt interface for upload operations")
|
||||
}
|
||||
|
||||
i, partseq, offset, byteSize := i, partseq, int64(partseq)*sliceSize, sliceSize
|
||||
if partseq+1 == count {
|
||||
byteSize = lastBlockSize
|
||||
totalParts := len(precreateResp.BlockList)
|
||||
doneParts := 0 // 新增计数器
|
||||
|
||||
for i, partseq := range precreateResp.BlockList {
|
||||
if utils.IsCanceled(upCtx) || partseq < 0 {
|
||||
continue
|
||||
}
|
||||
i, partseq := i, partseq
|
||||
offset, size := int64(partseq)*sliceSize, sliceSize
|
||||
if partseq+1 == count {
|
||||
size = lastBlockSize
|
||||
}
|
||||
threadG.Go(func(ctx context.Context) error {
|
||||
params := map[string]string{
|
||||
"method": "upload",
|
||||
"access_token": d.AccessToken,
|
||||
"type": "tmpfile",
|
||||
"path": path,
|
||||
"uploadid": precreateResp.Uploadid,
|
||||
"partseq": strconv.Itoa(partseq),
|
||||
}
|
||||
section := io.NewSectionReader(cacheReaderAt, offset, size)
|
||||
err := d.uploadSlice(ctx, params, stream.GetName(), driver.NewLimitedUploadStream(ctx, section))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
precreateResp.BlockList[i] = -1
|
||||
doneParts++ // 直接递增计数器
|
||||
if totalParts > 0 {
|
||||
up(float64(doneParts) * 100.0 / float64(totalParts))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
threadG.Go(func(ctx context.Context) error {
|
||||
params := map[string]string{
|
||||
"method": "upload",
|
||||
"access_token": d.AccessToken,
|
||||
"type": "tmpfile",
|
||||
"path": path,
|
||||
"uploadid": precreateResp.Uploadid,
|
||||
"partseq": strconv.Itoa(partseq),
|
||||
}
|
||||
err := d.uploadSlice(ctx, params, stream.GetName(),
|
||||
driver.NewLimitedUploadStream(ctx, io.NewSectionReader(cache, offset, byteSize)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
up(float64(threadG.Success()) * 100 / float64(len(precreateResp.BlockList)))
|
||||
precreateResp.BlockList[i] = -1
|
||||
return nil
|
||||
})
|
||||
}
|
||||
if err = threadG.Wait(); err != nil {
|
||||
// 如果属于用户主动取消,则保存上传进度
|
||||
|
||||
err = threadG.Wait()
|
||||
if err == nil {
|
||||
break uploadLoop
|
||||
}
|
||||
|
||||
// 保存进度(所有错误都会保存)
|
||||
precreateResp.BlockList = utils.SliceFilter(precreateResp.BlockList, func(s int) bool { return s >= 0 })
|
||||
base.SaveUploadProgress(d, precreateResp, d.AccessToken, contentMd5)
|
||||
|
||||
if errors.Is(err, context.Canceled) {
|
||||
precreateResp.BlockList = utils.SliceFilter(precreateResp.BlockList, func(s int) bool { return s >= 0 })
|
||||
return nil, err
|
||||
}
|
||||
if errors.Is(err, ErrUploadIDExpired) {
|
||||
log.Warn("[baidu_netdisk] uploadid expired, will restart from scratch")
|
||||
// 重新 precreate(所有分片都要重传)
|
||||
newPre, err2 := d.precreate(ctx, path, streamSize, blockListStr, "", "", ctime, mtime)
|
||||
if err2 != nil {
|
||||
return nil, err2
|
||||
}
|
||||
if newPre.ReturnType == 2 {
|
||||
return fileToObj(newPre.File), nil
|
||||
}
|
||||
precreateResp = newPre
|
||||
// 覆盖掉旧的进度
|
||||
base.SaveUploadProgress(d, precreateResp, d.AccessToken, contentMd5)
|
||||
continue uploadLoop
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
@@ -343,9 +365,46 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
|
||||
// 修复时间,具体原因见 Put 方法注释的 **注意**
|
||||
newFile.Ctime = ctime
|
||||
newFile.Mtime = mtime
|
||||
// 上传成功清理进度
|
||||
base.SaveUploadProgress(d, nil, d.AccessToken, contentMd5)
|
||||
return fileToObj(newFile), nil
|
||||
}
|
||||
|
||||
// precreate 执行预上传操作,支持首次上传和 uploadid 过期重试
|
||||
func (d *BaiduNetdisk) precreate(ctx context.Context, path string, streamSize int64, blockListStr, contentMd5, sliceMd5 string, ctime, mtime int64) (*PrecreateResp, error) {
|
||||
params := map[string]string{"method": "precreate"}
|
||||
form := map[string]string{
|
||||
"path": path,
|
||||
"size": strconv.FormatInt(streamSize, 10),
|
||||
"isdir": "0",
|
||||
"autoinit": "1",
|
||||
"rtype": "3",
|
||||
"block_list": blockListStr,
|
||||
}
|
||||
|
||||
// 只有在首次上传时才包含 content-md5 和 slice-md5
|
||||
if contentMd5 != "" && sliceMd5 != "" {
|
||||
form["content-md5"] = contentMd5
|
||||
form["slice-md5"] = sliceMd5
|
||||
}
|
||||
|
||||
joinTime(form, ctime, mtime)
|
||||
|
||||
var precreateResp PrecreateResp
|
||||
_, err := d.postForm("/xpan/file", params, form, &precreateResp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 修复时间,具体原因见 Put 方法注释的 **注意**
|
||||
if precreateResp.ReturnType == 2 {
|
||||
precreateResp.File.Ctime = ctime
|
||||
precreateResp.File.Mtime = mtime
|
||||
}
|
||||
|
||||
return &precreateResp, nil
|
||||
}
|
||||
|
||||
func (d *BaiduNetdisk) uploadSlice(ctx context.Context, params map[string]string, fileName string, file io.Reader) error {
|
||||
res, err := base.RestyClient.R().
|
||||
SetContext(ctx).
|
||||
@@ -358,8 +417,16 @@ func (d *BaiduNetdisk) uploadSlice(ctx context.Context, params map[string]string
|
||||
log.Debugln(res.RawResponse.Status + res.String())
|
||||
errCode := utils.Json.Get(res.Body(), "error_code").ToInt()
|
||||
errNo := utils.Json.Get(res.Body(), "errno").ToInt()
|
||||
respStr := res.String()
|
||||
lower := strings.ToLower(respStr)
|
||||
// 合并 uploadid 过期检测逻辑
|
||||
if strings.Contains(lower, "uploadid") &&
|
||||
(strings.Contains(lower, "invalid") || strings.Contains(lower, "expired") || strings.Contains(lower, "not found")) {
|
||||
return ErrUploadIDExpired
|
||||
}
|
||||
|
||||
if errCode != 0 || errNo != 0 {
|
||||
return errs.NewErr(errs.StreamIncomplete, "error in uploading to baidu, will retry. response=%s", res.String())
|
||||
return errs.NewErr(errs.StreamIncomplete, "error uploading to baidu, response=%s", res.String())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user