mirror of
https://github.com/OpenListTeam/OpenList.git
synced 2025-11-25 03:15:19 +08:00
489 lines
14 KiB
Go
489 lines
14 KiB
Go
package baidu_netdisk
|
||
|
||
import (
|
||
"context"
|
||
"crypto/md5"
|
||
"encoding/hex"
|
||
"errors"
|
||
"fmt"
|
||
"io"
|
||
"net/http"
|
||
"net/url"
|
||
"os"
|
||
stdpath "path"
|
||
"strconv"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
|
||
"github.com/OpenListTeam/OpenList/v4/drivers/base"
|
||
"github.com/OpenListTeam/OpenList/v4/internal/conf"
|
||
"github.com/OpenListTeam/OpenList/v4/internal/driver"
|
||
"github.com/OpenListTeam/OpenList/v4/internal/errs"
|
||
"github.com/OpenListTeam/OpenList/v4/internal/model"
|
||
"github.com/OpenListTeam/OpenList/v4/pkg/errgroup"
|
||
"github.com/OpenListTeam/OpenList/v4/pkg/singleflight"
|
||
"github.com/OpenListTeam/OpenList/v4/pkg/utils"
|
||
"github.com/avast/retry-go"
|
||
"github.com/go-resty/resty/v2"
|
||
log "github.com/sirupsen/logrus"
|
||
)
|
||
|
||
type BaiduNetdisk struct {
|
||
model.Storage
|
||
Addition
|
||
|
||
uploadThread int
|
||
vipType int // 会员类型,0普通用户(4G/4M)、1普通会员(10G/16M)、2超级会员(20G/32M)
|
||
|
||
upClient *resty.Client // 上传文件使用的http客户端
|
||
uploadUrlG singleflight.Group[string]
|
||
uploadUrlMu sync.RWMutex
|
||
uploadUrlCache map[string]uploadURLCacheEntry
|
||
}
|
||
|
||
type uploadURLCacheEntry struct {
|
||
url string
|
||
updateTime time.Time
|
||
}
|
||
|
||
var ErrUploadIDExpired = errors.New("uploadid expired")
|
||
|
||
func (d *BaiduNetdisk) Config() driver.Config {
|
||
return config
|
||
}
|
||
|
||
func (d *BaiduNetdisk) GetAddition() driver.Additional {
|
||
return &d.Addition
|
||
}
|
||
|
||
func (d *BaiduNetdisk) Init(ctx context.Context) error {
|
||
d.upClient = base.NewRestyClient().
|
||
SetTimeout(UPLOAD_TIMEOUT).
|
||
SetRetryCount(UPLOAD_RETRY_COUNT).
|
||
SetRetryWaitTime(UPLOAD_RETRY_WAIT_TIME).
|
||
SetRetryMaxWaitTime(UPLOAD_RETRY_MAX_WAIT_TIME)
|
||
d.uploadUrlCache = make(map[string]uploadURLCacheEntry)
|
||
d.uploadThread, _ = strconv.Atoi(d.UploadThread)
|
||
if d.uploadThread < 1 {
|
||
d.uploadThread, d.UploadThread = 1, "1"
|
||
} else if d.uploadThread > 32 {
|
||
d.uploadThread, d.UploadThread = 32, "32"
|
||
}
|
||
|
||
if _, err := url.Parse(d.UploadAPI); d.UploadAPI == "" || err != nil {
|
||
d.UploadAPI = UPLOAD_FALLBACK_API
|
||
}
|
||
|
||
res, err := d.get("/xpan/nas", map[string]string{
|
||
"method": "uinfo",
|
||
}, nil)
|
||
log.Debugf("[baidu_netdisk] get uinfo: %s", string(res))
|
||
if err != nil {
|
||
return err
|
||
}
|
||
d.vipType = utils.Json.Get(res, "vip_type").ToInt()
|
||
return nil
|
||
}
|
||
|
||
func (d *BaiduNetdisk) Drop(ctx context.Context) error {
|
||
return nil
|
||
}
|
||
|
||
func (d *BaiduNetdisk) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) {
|
||
files, err := d.getFiles(dir.GetPath())
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return utils.SliceConvert(files, func(src File) (model.Obj, error) {
|
||
return fileToObj(src), nil
|
||
})
|
||
}
|
||
|
||
func (d *BaiduNetdisk) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) {
|
||
if d.DownloadAPI == "crack" {
|
||
return d.linkCrack(file, args)
|
||
} else if d.DownloadAPI == "crack_video" {
|
||
return d.linkCrackVideo(file, args)
|
||
}
|
||
return d.linkOfficial(file, args)
|
||
}
|
||
|
||
func (d *BaiduNetdisk) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) (model.Obj, error) {
|
||
var newDir File
|
||
_, err := d.create(stdpath.Join(parentDir.GetPath(), dirName), 0, 1, "", "", &newDir, 0, 0)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return fileToObj(newDir), nil
|
||
}
|
||
|
||
func (d *BaiduNetdisk) Move(ctx context.Context, srcObj, dstDir model.Obj) (model.Obj, error) {
|
||
data := []base.Json{
|
||
{
|
||
"path": srcObj.GetPath(),
|
||
"dest": dstDir.GetPath(),
|
||
"newname": srcObj.GetName(),
|
||
},
|
||
}
|
||
_, err := d.manage("move", data)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if srcObj, ok := srcObj.(*model.ObjThumb); ok {
|
||
srcObj.SetPath(stdpath.Join(dstDir.GetPath(), srcObj.GetName()))
|
||
srcObj.Modified = time.Now()
|
||
return srcObj, nil
|
||
}
|
||
return nil, nil
|
||
}
|
||
|
||
func (d *BaiduNetdisk) Rename(ctx context.Context, srcObj model.Obj, newName string) (model.Obj, error) {
|
||
data := []base.Json{
|
||
{
|
||
"path": srcObj.GetPath(),
|
||
"newname": newName,
|
||
},
|
||
}
|
||
_, err := d.manage("rename", data)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
if srcObj, ok := srcObj.(*model.ObjThumb); ok {
|
||
srcObj.SetPath(stdpath.Join(stdpath.Dir(srcObj.GetPath()), newName))
|
||
srcObj.Name = newName
|
||
srcObj.Modified = time.Now()
|
||
return srcObj, nil
|
||
}
|
||
return nil, nil
|
||
}
|
||
|
||
func (d *BaiduNetdisk) Copy(ctx context.Context, srcObj, dstDir model.Obj) error {
|
||
data := []base.Json{
|
||
{
|
||
"path": srcObj.GetPath(),
|
||
"dest": dstDir.GetPath(),
|
||
"newname": srcObj.GetName(),
|
||
},
|
||
}
|
||
_, err := d.manage("copy", data)
|
||
return err
|
||
}
|
||
|
||
func (d *BaiduNetdisk) Remove(ctx context.Context, obj model.Obj) error {
|
||
data := []string{obj.GetPath()}
|
||
_, err := d.manage("delete", data)
|
||
return err
|
||
}
|
||
|
||
func (d *BaiduNetdisk) PutRapid(ctx context.Context, dstDir model.Obj, stream model.FileStreamer) (model.Obj, error) {
|
||
contentMd5 := stream.GetHash().GetHash(utils.MD5)
|
||
if len(contentMd5) < utils.MD5.Width {
|
||
return nil, errors.New("invalid hash")
|
||
}
|
||
|
||
streamSize := stream.GetSize()
|
||
path := stdpath.Join(dstDir.GetPath(), stream.GetName())
|
||
mtime := stream.ModTime().Unix()
|
||
ctime := stream.CreateTime().Unix()
|
||
blockList, _ := utils.Json.MarshalToString([]string{contentMd5})
|
||
|
||
var newFile File
|
||
_, err := d.create(path, streamSize, 0, "", blockList, &newFile, mtime, ctime)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
// 修复时间,具体原因见 Put 方法注释的 **注意**
|
||
newFile.Ctime = stream.CreateTime().Unix()
|
||
newFile.Mtime = stream.ModTime().Unix()
|
||
return fileToObj(newFile), nil
|
||
}
|
||
|
||
// Put
|
||
//
|
||
// **注意**: 截至 2024/04/20 百度云盘 api 接口返回的时间永远是当前时间,而不是文件时间。
|
||
// 而实际上云盘存储的时间是文件时间,所以此处需要覆盖时间,保证缓存与云盘的数据一致
|
||
func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) {
|
||
// 百度网盘不允许上传空文件
|
||
if stream.GetSize() < 1 {
|
||
return nil, ErrBaiduEmptyFilesNotAllowed
|
||
}
|
||
|
||
// rapid upload
|
||
if newObj, err := d.PutRapid(ctx, dstDir, stream); err == nil {
|
||
return newObj, nil
|
||
}
|
||
|
||
var (
|
||
cache = stream.GetFile()
|
||
tmpF *os.File
|
||
err error
|
||
)
|
||
if _, ok := cache.(io.ReaderAt); !ok {
|
||
tmpF, err = os.CreateTemp(conf.Conf.TempDir, "file-*")
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer func() {
|
||
_ = tmpF.Close()
|
||
_ = os.Remove(tmpF.Name())
|
||
}()
|
||
cache = tmpF
|
||
}
|
||
|
||
streamSize := stream.GetSize()
|
||
sliceSize := d.getSliceSize(streamSize)
|
||
count := 1
|
||
if streamSize > sliceSize {
|
||
count = int((streamSize + sliceSize - 1) / sliceSize)
|
||
}
|
||
lastBlockSize := streamSize % sliceSize
|
||
if lastBlockSize == 0 {
|
||
lastBlockSize = sliceSize
|
||
}
|
||
|
||
// cal md5 for first 256k data
|
||
const SliceSize int64 = 256 * utils.KB
|
||
blockList := make([]string, 0, count)
|
||
byteSize := sliceSize
|
||
fileMd5H := md5.New()
|
||
sliceMd5H := md5.New()
|
||
sliceMd5H2 := md5.New()
|
||
slicemd5H2Write := utils.LimitWriter(sliceMd5H2, SliceSize)
|
||
writers := []io.Writer{fileMd5H, sliceMd5H, slicemd5H2Write}
|
||
if tmpF != nil {
|
||
writers = append(writers, tmpF)
|
||
}
|
||
written := int64(0)
|
||
|
||
for i := 1; i <= count; i++ {
|
||
if utils.IsCanceled(ctx) {
|
||
return nil, ctx.Err()
|
||
}
|
||
if i == count {
|
||
byteSize = lastBlockSize
|
||
}
|
||
n, err := utils.CopyWithBufferN(io.MultiWriter(writers...), stream, byteSize)
|
||
written += n
|
||
if err != nil && err != io.EOF {
|
||
return nil, err
|
||
}
|
||
blockList = append(blockList, hex.EncodeToString(sliceMd5H.Sum(nil)))
|
||
sliceMd5H.Reset()
|
||
}
|
||
if tmpF != nil {
|
||
if written != streamSize {
|
||
return nil, errs.NewErr(err, "CreateTempFile failed, size mismatch: %d != %d ", written, streamSize)
|
||
}
|
||
_, err = tmpF.Seek(0, io.SeekStart)
|
||
if err != nil {
|
||
return nil, errs.NewErr(err, "CreateTempFile failed, can't seek to 0 ")
|
||
}
|
||
}
|
||
contentMd5 := hex.EncodeToString(fileMd5H.Sum(nil))
|
||
sliceMd5 := hex.EncodeToString(sliceMd5H2.Sum(nil))
|
||
blockListStr, _ := utils.Json.MarshalToString(blockList)
|
||
path := stdpath.Join(dstDir.GetPath(), stream.GetName())
|
||
mtime := stream.ModTime().Unix()
|
||
ctime := stream.CreateTime().Unix()
|
||
|
||
// step.1 尝试读取已保存进度
|
||
precreateResp, ok := base.GetUploadProgress[*PrecreateResp](d, d.AccessToken, contentMd5)
|
||
if !ok {
|
||
// 没有进度,走预上传
|
||
precreateResp, err = d.precreate(ctx, path, streamSize, blockListStr, contentMd5, sliceMd5, ctime, mtime)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if precreateResp.ReturnType == 2 {
|
||
// rapid upload, since got md5 match from baidu server
|
||
// 修复时间,具体原因见 Put 方法注释的 **注意**
|
||
precreateResp.File.Ctime = ctime
|
||
precreateResp.File.Mtime = mtime
|
||
return fileToObj(precreateResp.File), nil
|
||
}
|
||
}
|
||
ensureUploadURL := func() {
|
||
if precreateResp.UploadURL != "" {
|
||
return
|
||
}
|
||
precreateResp.UploadURL = d.getUploadUrl(path, precreateResp.Uploadid)
|
||
}
|
||
ensureUploadURL()
|
||
|
||
// step.2 上传分片
|
||
uploadLoop:
|
||
for attempt := 0; attempt < 2; attempt++ {
|
||
// 获取上传域名
|
||
if precreateResp.UploadURL == "" {
|
||
ensureUploadURL()
|
||
}
|
||
uploadUrl := precreateResp.UploadURL
|
||
// 并发上传
|
||
threadG, upCtx := errgroup.NewGroupWithContext(ctx, d.uploadThread,
|
||
retry.Attempts(1),
|
||
retry.Delay(time.Second),
|
||
retry.DelayType(retry.BackOffDelay))
|
||
|
||
cacheReaderAt, okReaderAt := cache.(io.ReaderAt)
|
||
if !okReaderAt {
|
||
return nil, fmt.Errorf("cache object must implement io.ReaderAt interface for upload operations")
|
||
}
|
||
|
||
totalParts := len(precreateResp.BlockList)
|
||
|
||
for i, partseq := range precreateResp.BlockList {
|
||
if utils.IsCanceled(upCtx) || partseq < 0 {
|
||
continue
|
||
}
|
||
i, partseq := i, partseq
|
||
offset, size := int64(partseq)*sliceSize, sliceSize
|
||
if partseq+1 == count {
|
||
size = lastBlockSize
|
||
}
|
||
threadG.Go(func(ctx context.Context) error {
|
||
params := map[string]string{
|
||
"method": "upload",
|
||
"access_token": d.AccessToken,
|
||
"type": "tmpfile",
|
||
"path": path,
|
||
"uploadid": precreateResp.Uploadid,
|
||
"partseq": strconv.Itoa(partseq),
|
||
}
|
||
section := io.NewSectionReader(cacheReaderAt, offset, size)
|
||
err := d.uploadSlice(ctx, uploadUrl, params, stream.GetName(), driver.NewLimitedUploadStream(ctx, section))
|
||
if err != nil {
|
||
return err
|
||
}
|
||
precreateResp.BlockList[i] = -1
|
||
// 当前goroutine还没退出,+1才是真正成功的数量
|
||
success := threadG.Success() + 1
|
||
progress := float64(success) * 100 / float64(totalParts)
|
||
up(progress)
|
||
return nil
|
||
})
|
||
}
|
||
|
||
err = threadG.Wait()
|
||
if err == nil {
|
||
break uploadLoop
|
||
}
|
||
|
||
// 保存进度(所有错误都会保存)
|
||
precreateResp.BlockList = utils.SliceFilter(precreateResp.BlockList, func(s int) bool { return s >= 0 })
|
||
base.SaveUploadProgress(d, precreateResp, d.AccessToken, contentMd5)
|
||
|
||
if errors.Is(err, context.Canceled) {
|
||
return nil, err
|
||
}
|
||
if errors.Is(err, ErrUploadIDExpired) {
|
||
log.Warn("[baidu_netdisk] uploadid expired, will restart from scratch")
|
||
d.clearUploadUrlCache(precreateResp.Uploadid)
|
||
// 重新 precreate(所有分片都要重传)
|
||
newPre, err2 := d.precreate(ctx, path, streamSize, blockListStr, "", "", ctime, mtime)
|
||
if err2 != nil {
|
||
return nil, err2
|
||
}
|
||
if newPre.ReturnType == 2 {
|
||
return fileToObj(newPre.File), nil
|
||
}
|
||
precreateResp = newPre
|
||
precreateResp.UploadURL = ""
|
||
ensureUploadURL()
|
||
// 覆盖掉旧的进度
|
||
base.SaveUploadProgress(d, precreateResp, d.AccessToken, contentMd5)
|
||
continue uploadLoop
|
||
}
|
||
return nil, err
|
||
}
|
||
|
||
// step.3 创建文件
|
||
var newFile File
|
||
_, err = d.create(path, streamSize, 0, precreateResp.Uploadid, blockListStr, &newFile, mtime, ctime)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
// 修复时间,具体原因见 Put 方法注释的 **注意**
|
||
newFile.Ctime = ctime
|
||
newFile.Mtime = mtime
|
||
// 上传成功清理进度
|
||
base.SaveUploadProgress(d, nil, d.AccessToken, contentMd5)
|
||
d.clearUploadUrlCache(precreateResp.Uploadid)
|
||
return fileToObj(newFile), nil
|
||
}
|
||
|
||
// precreate 执行预上传操作,支持首次上传和 uploadid 过期重试
|
||
func (d *BaiduNetdisk) precreate(ctx context.Context, path string, streamSize int64, blockListStr, contentMd5, sliceMd5 string, ctime, mtime int64) (*PrecreateResp, error) {
|
||
params := map[string]string{"method": "precreate"}
|
||
form := map[string]string{
|
||
"path": path,
|
||
"size": strconv.FormatInt(streamSize, 10),
|
||
"isdir": "0",
|
||
"autoinit": "1",
|
||
"rtype": "3",
|
||
"block_list": blockListStr,
|
||
}
|
||
|
||
// 只有在首次上传时才包含 content-md5 和 slice-md5
|
||
if contentMd5 != "" && sliceMd5 != "" {
|
||
form["content-md5"] = contentMd5
|
||
form["slice-md5"] = sliceMd5
|
||
}
|
||
|
||
joinTime(form, ctime, mtime)
|
||
|
||
var precreateResp PrecreateResp
|
||
_, err := d.postForm("/xpan/file", params, form, &precreateResp)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 修复时间,具体原因见 Put 方法注释的 **注意**
|
||
if precreateResp.ReturnType == 2 {
|
||
precreateResp.File.Ctime = ctime
|
||
precreateResp.File.Mtime = mtime
|
||
}
|
||
|
||
return &precreateResp, nil
|
||
}
|
||
|
||
func (d *BaiduNetdisk) uploadSlice(ctx context.Context, uploadUrl string, params map[string]string, fileName string, file io.Reader) error {
|
||
res, err := d.upClient.R().
|
||
SetContext(ctx).
|
||
SetQueryParams(params).
|
||
SetFileReader("file", fileName, file).
|
||
Post(uploadUrl + "/rest/2.0/pcs/superfile2")
|
||
if err != nil {
|
||
return err
|
||
}
|
||
log.Debugln(res.RawResponse.Status + res.String())
|
||
if res.StatusCode() != http.StatusOK {
|
||
return errs.NewErr(errs.StreamIncomplete, "baidu upload failed, status=%d, body=%s", res.StatusCode(), res.String())
|
||
}
|
||
errCode := utils.Json.Get(res.Body(), "error_code").ToInt()
|
||
errNo := utils.Json.Get(res.Body(), "errno").ToInt()
|
||
respStr := res.String()
|
||
lower := strings.ToLower(respStr)
|
||
// 合并 uploadid 过期检测逻辑
|
||
if strings.Contains(lower, "uploadid") &&
|
||
(strings.Contains(lower, "invalid") || strings.Contains(lower, "expired") || strings.Contains(lower, "not found")) {
|
||
return ErrUploadIDExpired
|
||
}
|
||
|
||
if errCode != 0 || errNo != 0 {
|
||
return errs.NewErr(errs.StreamIncomplete, "error uploading to baidu, response=%s", res.String())
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (d *BaiduNetdisk) GetDetails(ctx context.Context) (*model.StorageDetails, error) {
|
||
du, err := d.quota(ctx)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &model.StorageDetails{DiskUsage: du}, nil
|
||
}
|
||
|
||
var _ driver.Driver = (*BaiduNetdisk)(nil)
|