diff --git a/drivers/115_open/upload.go b/drivers/115_open/upload.go
index b718772b..9bd1f920 100644
--- a/drivers/115_open/upload.go
+++ b/drivers/115_open/upload.go
@@ -9,6 +9,7 @@ import (
 	sdk "github.com/OpenListTeam/115-sdk-go"
 	"github.com/OpenListTeam/OpenList/v4/internal/driver"
 	"github.com/OpenListTeam/OpenList/v4/internal/model"
+	streamPkg "github.com/OpenListTeam/OpenList/v4/internal/stream"
 	"github.com/OpenListTeam/OpenList/v4/pkg/utils"
 	"github.com/aliyun/aliyun-oss-go-sdk/oss"
 	"github.com/avast/retry-go"
@@ -69,9 +70,6 @@ func (d *Open115) singleUpload(ctx context.Context, tempF model.File, tokenResp
 // }
 
 func (d *Open115) multpartUpload(ctx context.Context, stream model.FileStreamer, up driver.UpdateProgress, tokenResp *sdk.UploadGetTokenResp, initResp *sdk.UploadInitResp) error {
-	fileSize := stream.GetSize()
-	chunkSize := calPartSize(fileSize)
-
 	ossClient, err := oss.New(tokenResp.Endpoint, tokenResp.AccessKeyId, tokenResp.AccessKeySecret, oss.SecurityToken(tokenResp.SecurityToken))
 	if err != nil {
 		return err
@@ -86,9 +84,15 @@ func (d *Open115) multpartUpload(ctx context.Context, stream model.FileStreamer,
 		return err
 	}
 
+	fileSize := stream.GetSize()
+	chunkSize := calPartSize(fileSize)
 	partNum := (stream.GetSize() + chunkSize - 1) / chunkSize
 	parts := make([]oss.UploadPart, partNum)
 	offset := int64(0)
+	ss, err := streamPkg.NewStreamSectionReader(stream, int(chunkSize))
+	if err != nil {
+		return err
+	}
 	for i := int64(1); i <= partNum; i++ {
 		if utils.IsCanceled(ctx) {
 			return ctx.Err()
@@ -98,10 +102,13 @@ func (d *Open115) multpartUpload(ctx context.Context, stream model.FileStreamer,
 		if i == partNum {
 			partSize = fileSize - (i-1)*chunkSize
 		}
-		rd := utils.NewMultiReadable(io.LimitReader(stream, partSize))
+		rd, err := ss.GetSectionReader(offset, partSize)
+		if err != nil {
+			return err
+		}
+		rateLimitedRd := driver.NewLimitedUploadStream(ctx, rd)
 		err = retry.Do(func() error {
-			_ = rd.Reset()
-			rateLimitedRd := driver.NewLimitedUploadStream(ctx, rd)
+			rd.Seek(0, io.SeekStart)
 			part, err := bucket.UploadPart(imur, rateLimitedRd, partSize, int(i))
 			if err != nil {
 				return err
@@ -112,6 +119,7 @@ func (d *Open115) multpartUpload(ctx context.Context, stream model.FileStreamer,
 			retry.Attempts(3),
 			retry.DelayType(retry.BackOffDelay),
 			retry.Delay(time.Second))
+		ss.RecycleSectionReader(rd)
 		if err != nil {
 			return err
 		}
diff --git a/drivers/123/meta.go b/drivers/123/meta.go
index 505c55c0..613ba3ad 100644
--- a/drivers/123/meta.go
+++ b/drivers/123/meta.go
@@ -11,7 +11,8 @@ type Addition struct {
 	driver.RootID
 	//OrderBy        string `json:"order_by" type:"select" options:"file_id,file_name,size,update_at" default:"file_name"`
 	//OrderDirection string `json:"order_direction" type:"select" options:"asc,desc" default:"asc"`
-	AccessToken string
+	AccessToken  string
+	UploadThread int `json:"UploadThread" type:"number" default:"3" help:"the threads of upload"`
 }
 
 var config = driver.Config{
@@ -22,6 +23,11 @@ var config = driver.Config{
 
 func init() {
 	op.RegisterDriver(func() driver.Driver {
-		return &Pan123{}
+		// New defaults must be set when the driver is registered so they also take effect for existing users
+		return &Pan123{
+			Addition: Addition{
+				UploadThread: 3,
+			},
+		}
 	})
 }
diff --git a/drivers/123/upload.go b/drivers/123/upload.go
index 1dc79e2f..e44ce2ee 100644
--- a/drivers/123/upload.go
+++ b/drivers/123/upload.go
@@ -6,11 +6,16 @@ import (
 	"io"
 	"net/http"
 	"strconv"
+	"time"
 
 	"github.com/OpenListTeam/OpenList/v4/drivers/base"
 	"github.com/OpenListTeam/OpenList/v4/internal/driver"
 	"github.com/OpenListTeam/OpenList/v4/internal/model"
"github.com/OpenListTeam/OpenList/v4/internal/stream" + "github.com/OpenListTeam/OpenList/v4/pkg/errgroup" + "github.com/OpenListTeam/OpenList/v4/pkg/singleflight" "github.com/OpenListTeam/OpenList/v4/pkg/utils" + "github.com/avast/retry-go" "github.com/go-resty/resty/v2" ) @@ -69,18 +74,15 @@ func (d *Pan123) completeS3(ctx context.Context, upReq *UploadResp, file model.F } func (d *Pan123) newUpload(ctx context.Context, upReq *UploadResp, file model.FileStreamer, up driver.UpdateProgress) error { - tmpF, err := file.CacheFullInTempFile() - if err != nil { - return err - } // fetch s3 pre signed urls size := file.GetSize() - chunkSize := min(size, 16*utils.MB) - chunkCount := int(size / chunkSize) + chunkSize := int64(16 * utils.MB) + chunkCount := 1 + if size > chunkSize { + chunkCount = int((size + chunkSize - 1) / chunkSize) + } lastChunkSize := size % chunkSize - if lastChunkSize > 0 { - chunkCount++ - } else { + if lastChunkSize == 0 { lastChunkSize = chunkSize } // only 1 batch is allowed @@ -90,73 +92,103 @@ func (d *Pan123) newUpload(ctx context.Context, upReq *UploadResp, file model.Fi batchSize = 10 getS3UploadUrl = d.getS3PreSignedUrls } + ss, err := stream.NewStreamSectionReader(file, int(chunkSize)) + if err != nil { + return err + } + + thread := min(int(chunkCount), d.UploadThread) + threadG, uploadCtx := errgroup.NewOrderedGroupWithContext(ctx, thread, + retry.Attempts(3), + retry.Delay(time.Second), + retry.DelayType(retry.BackOffDelay)) for i := 1; i <= chunkCount; i += batchSize { - if utils.IsCanceled(ctx) { - return ctx.Err() + if utils.IsCanceled(uploadCtx) { + break } start := i end := min(i+batchSize, chunkCount+1) - s3PreSignedUrls, err := getS3UploadUrl(ctx, upReq, start, end) + s3PreSignedUrls, err := getS3UploadUrl(uploadCtx, upReq, start, end) if err != nil { return err } // upload each chunk - for j := start; j < end; j++ { - if utils.IsCanceled(ctx) { - return ctx.Err() + for cur := start; cur < end; cur++ { + if utils.IsCanceled(uploadCtx) { + break } + offset := int64(cur-1) * chunkSize curSize := chunkSize - if j == chunkCount { + if cur == chunkCount { curSize = lastChunkSize } - err = d.uploadS3Chunk(ctx, upReq, s3PreSignedUrls, j, end, io.NewSectionReader(tmpF, chunkSize*int64(j-1), curSize), curSize, false, getS3UploadUrl) - if err != nil { - return err - } - up(float64(j) * 100 / float64(chunkCount)) + var reader *stream.SectionReader + var rateLimitedRd io.Reader + threadG.GoWithLifecycle(errgroup.Lifecycle{ + Before: func(ctx context.Context) error { + if reader == nil { + var err error + reader, err = ss.GetSectionReader(offset, curSize) + if err != nil { + return err + } + rateLimitedRd = driver.NewLimitedUploadStream(ctx, reader) + } + return nil + }, + Do: func(ctx context.Context) error { + reader.Seek(0, io.SeekStart) + uploadUrl := s3PreSignedUrls.Data.PreSignedUrls[strconv.Itoa(cur)] + if uploadUrl == "" { + return fmt.Errorf("upload url is empty, s3PreSignedUrls: %+v", s3PreSignedUrls) + } + reader.Seek(0, io.SeekStart) + req, err := http.NewRequestWithContext(ctx, http.MethodPut, uploadUrl, rateLimitedRd) + if err != nil { + return err + } + req.ContentLength = curSize + //req.Header.Set("Content-Length", strconv.FormatInt(curSize, 10)) + res, err := base.HttpClient.Do(req) + if err != nil { + return err + } + defer res.Body.Close() + if res.StatusCode == http.StatusForbidden { + singleflight.AnyGroup.Do(fmt.Sprintf("Pan123.newUpload_%p", threadG), func() (any, error) { + newS3PreSignedUrls, err := getS3UploadUrl(ctx, upReq, cur, end) + 
+							if err != nil {
+								return nil, err
+							}
+							s3PreSignedUrls.Data.PreSignedUrls = newS3PreSignedUrls.Data.PreSignedUrls
+							return nil, nil
+						})
+						if err != nil {
+							return err
+						}
+						return fmt.Errorf("upload s3 chunk %d failed, status code: %d", cur, res.StatusCode)
+					}
+					if res.StatusCode != http.StatusOK {
+						body, err := io.ReadAll(res.Body)
+						if err != nil {
+							return err
+						}
+						return fmt.Errorf("upload s3 chunk %d failed, status code: %d, body: %s", cur, res.StatusCode, body)
+					}
+					progress := 10.0 + 85.0*float64(threadG.Success())/float64(chunkCount)
+					up(progress)
+					return nil
+				},
+				After: func(err error) {
+					ss.RecycleSectionReader(reader)
+				},
+			})
 		}
 	}
+	if err := threadG.Wait(); err != nil {
+		return err
+	}
+	defer up(100)
 	// complete s3 upload
 	return d.completeS3(ctx, upReq, file, chunkCount > 1)
 }
-
-func (d *Pan123) uploadS3Chunk(ctx context.Context, upReq *UploadResp, s3PreSignedUrls *S3PreSignedURLs, cur, end int, reader *io.SectionReader, curSize int64, retry bool, getS3UploadUrl func(ctx context.Context, upReq *UploadResp, start int, end int) (*S3PreSignedURLs, error)) error {
-	uploadUrl := s3PreSignedUrls.Data.PreSignedUrls[strconv.Itoa(cur)]
-	if uploadUrl == "" {
-		return fmt.Errorf("upload url is empty, s3PreSignedUrls: %+v", s3PreSignedUrls)
-	}
-	req, err := http.NewRequest("PUT", uploadUrl, driver.NewLimitedUploadStream(ctx, reader))
-	if err != nil {
-		return err
-	}
-	req = req.WithContext(ctx)
-	req.ContentLength = curSize
-	//req.Header.Set("Content-Length", strconv.FormatInt(curSize, 10))
-	res, err := base.HttpClient.Do(req)
-	if err != nil {
-		return err
-	}
-	defer res.Body.Close()
-	if res.StatusCode == http.StatusForbidden {
-		if retry {
-			return fmt.Errorf("upload s3 chunk %d failed, status code: %d", cur, res.StatusCode)
-		}
-		// refresh s3 pre signed urls
-		newS3PreSignedUrls, err := getS3UploadUrl(ctx, upReq, cur, end)
-		if err != nil {
-			return err
-		}
-		s3PreSignedUrls.Data.PreSignedUrls = newS3PreSignedUrls.Data.PreSignedUrls
-		// retry
-		reader.Seek(0, io.SeekStart)
-		return d.uploadS3Chunk(ctx, upReq, s3PreSignedUrls, cur, end, reader, curSize, true, getS3UploadUrl)
-	}
-	if res.StatusCode != http.StatusOK {
-		body, err := io.ReadAll(res.Body)
-		if err != nil {
-			return err
-		}
-		return fmt.Errorf("upload s3 chunk %d failed, status code: %d, body: %s", cur, res.StatusCode, body)
-	}
-	return nil
-}
diff --git a/drivers/123_open/upload.go b/drivers/123_open/upload.go
index cc769509..3f2ec70c 100644
--- a/drivers/123_open/upload.go
+++ b/drivers/123_open/upload.go
@@ -2,6 +2,7 @@ package _123_open
 
 import (
 	"context"
+	"io"
 	"net/http"
 	"strings"
 	"time"
@@ -9,8 +10,8 @@ import (
 	"github.com/OpenListTeam/OpenList/v4/drivers/base"
 	"github.com/OpenListTeam/OpenList/v4/internal/driver"
 	"github.com/OpenListTeam/OpenList/v4/internal/model"
+	"github.com/OpenListTeam/OpenList/v4/internal/stream"
 	"github.com/OpenListTeam/OpenList/v4/pkg/errgroup"
-	"github.com/OpenListTeam/OpenList/v4/pkg/http_range"
 	"github.com/OpenListTeam/OpenList/v4/pkg/utils"
 	"github.com/avast/retry-go"
 	"github.com/go-resty/resty/v2"
@@ -79,49 +80,64 @@ func (d *Open123) Upload(ctx context.Context, file model.FileStreamer, createRes
 	size := file.GetSize()
 	chunkSize := createResp.Data.SliceSize
 	uploadNums := (size + chunkSize - 1) / chunkSize
-	threadG, uploadCtx := errgroup.NewGroupWithContext(ctx, d.UploadThread,
+	thread := min(int(uploadNums), d.UploadThread)
+	threadG, uploadCtx := errgroup.NewOrderedGroupWithContext(ctx, thread,
 		retry.Attempts(3),
 		retry.Delay(time.Second),
 		retry.DelayType(retry.BackOffDelay))
+	ss, err := stream.NewStreamSectionReader(file, int(chunkSize))
+	if err != nil {
+		return err
+	}
 	for partIndex := int64(0); partIndex < uploadNums; partIndex++ {
 		if utils.IsCanceled(uploadCtx) {
-			return ctx.Err()
+			break
 		}
 		partIndex := partIndex
 		partNumber := partIndex + 1 // 分片号从1开始
 		offset := partIndex * chunkSize
 		size := min(chunkSize, size-offset)
-		limitedReader, err := file.RangeRead(http_range.Range{
-			Start:  offset,
-			Length: size})
-		if err != nil {
-			return err
-		}
-		limitedReader = driver.NewLimitedUploadStream(ctx, limitedReader)
+		var reader *stream.SectionReader
+		var rateLimitedRd io.Reader
+		threadG.GoWithLifecycle(errgroup.Lifecycle{
+			Before: func(ctx context.Context) error {
+				if reader == nil {
+					var err error
+					reader, err = ss.GetSectionReader(offset, size)
+					if err != nil {
+						return err
+					}
+					rateLimitedRd = driver.NewLimitedUploadStream(ctx, reader)
+				}
+				return nil
+			},
+			Do: func(ctx context.Context) error {
+				reader.Seek(0, io.SeekStart)
+				uploadPartUrl, err := d.url(createResp.Data.PreuploadID, partNumber)
+				if err != nil {
+					return err
+				}
-
-		threadG.Go(func(ctx context.Context) error {
-			uploadPartUrl, err := d.url(createResp.Data.PreuploadID, partNumber)
-			if err != nil {
-				return err
-			}
+				req, err := http.NewRequestWithContext(ctx, http.MethodPut, uploadPartUrl, rateLimitedRd)
+				if err != nil {
+					return err
+				}
+				req.ContentLength = size
-			req, err := http.NewRequestWithContext(ctx, "PUT", uploadPartUrl, limitedReader)
-			if err != nil {
-				return err
-			}
-			req = req.WithContext(ctx)
-			req.ContentLength = size
+				res, err := base.HttpClient.Do(req)
+				if err != nil {
+					return err
+				}
+				_ = res.Body.Close()
-			res, err := base.HttpClient.Do(req)
-			if err != nil {
-				return err
-			}
-			_ = res.Body.Close()
-
-			progress := 10.0 + 85.0*float64(threadG.Success())/float64(uploadNums)
-			up(progress)
-			return nil
+				progress := 10.0 + 85.0*float64(threadG.Success())/float64(uploadNums)
+				up(progress)
+				return nil
+			},
+			After: func(err error) {
+				ss.RecycleSectionReader(reader)
+			},
 		})
 	}
diff --git a/drivers/139/driver.go b/drivers/139/driver.go
index b3e5e2a6..8033eefc 100644
--- a/drivers/139/driver.go
+++ b/drivers/139/driver.go
@@ -531,12 +531,10 @@ func (d *Yun139) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr
 	}
 
 	size := stream.GetSize()
-	var partSize = d.getPartSize(size)
-	part := size / partSize
-	if size%partSize > 0 {
-		part++
-	} else if part == 0 {
-		part = 1
+	partSize := d.getPartSize(size)
+	part := int64(1)
+	if size > partSize {
+		part = (size + partSize - 1) / partSize
 	}
 	partInfos := make([]PartInfo, 0, part)
 	for i := int64(0); i < part; i++ {
@@ -638,11 +636,10 @@ func (d *Yun139) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr
 
 		// Update Progress
 		r := io.TeeReader(limitReader, p)
-		req, err := http.NewRequest("PUT", uploadPartInfo.UploadUrl, r)
+		req, err := http.NewRequestWithContext(ctx, http.MethodPut, uploadPartInfo.UploadUrl, r)
 		if err != nil {
 			return err
 		}
-		req = req.WithContext(ctx)
 		req.Header.Set("Content-Type", "application/octet-stream")
 		req.Header.Set("Content-Length", fmt.Sprint(partSize))
 		req.Header.Set("Origin", "https://yun.139.com")
@@ -788,12 +785,10 @@ func (d *Yun139) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr
 		size := stream.GetSize()
 		// Progress
 		p := driver.NewProgress(size, up)
-		var partSize = d.getPartSize(size)
-		part := size / partSize
-		if size%partSize > 0 {
-			part++
-		} else if part == 0 {
-			part = 1
+		partSize := d.getPartSize(size)
+		part := int64(1)
+		if size > partSize {
+			part = (size + partSize - 1) / partSize
 		}
 		rateLimited := driver.NewLimitedUploadStream(ctx, stream)
 		for i := int64(0); i < part; i++ {
@@ -807,12 +802,10 @@ func (d *Yun139) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr
 			limitReader := io.LimitReader(rateLimited, byteSize)
 			// Update Progress
 			r := io.TeeReader(limitReader, p)
-			req, err := http.NewRequest("POST", resp.Data.UploadResult.RedirectionURL, r)
+			req, err := http.NewRequestWithContext(ctx, http.MethodPost, resp.Data.UploadResult.RedirectionURL, r)
 			if err != nil {
 				return err
 			}
-
-			req = req.WithContext(ctx)
 			req.Header.Set("Content-Type", "text/plain;name="+unicode(stream.GetName()))
 			req.Header.Set("contentSize", strconv.FormatInt(size, 10))
 			req.Header.Set("range", fmt.Sprintf("bytes=%d-%d", start, start+byteSize-1))
diff --git a/drivers/189/util.go b/drivers/189/util.go
index 8b48fcad..d10c7e8b 100644
--- a/drivers/189/util.go
+++ b/drivers/189/util.go
@@ -365,11 +365,10 @@ func (d *Cloud189) newUpload(ctx context.Context, dstDir model.Obj, file model.F
 		log.Debugf("uploadData: %+v", uploadData)
 		requestURL := uploadData.RequestURL
 		uploadHeaders := strings.Split(decodeURIComponent(uploadData.RequestHeader), "&")
-		req, err := http.NewRequest(http.MethodPut, requestURL, driver.NewLimitedUploadStream(ctx, bytes.NewReader(byteData)))
+		req, err := http.NewRequestWithContext(ctx, http.MethodPut, requestURL, driver.NewLimitedUploadStream(ctx, bytes.NewReader(byteData)))
 		if err != nil {
 			return err
 		}
-		req = req.WithContext(ctx)
 		for _, v := range uploadHeaders {
 			i := strings.Index(v, "=")
 			req.Header.Set(v[0:i], v[i+1:])
diff --git a/drivers/189pc/utils.go b/drivers/189pc/utils.go
index fc7cb98a..00fbe297 100644
--- a/drivers/189pc/utils.go
+++ b/drivers/189pc/utils.go
@@ -7,6 +7,7 @@ import (
 	"encoding/hex"
 	"encoding/xml"
 	"fmt"
+	"hash"
 	"io"
 	"net/http"
 	"net/http/cookiejar"
@@ -472,7 +473,7 @@ func (y *Cloud189PC) refreshSession() (err error) {
 
 // 无法上传大小为0的文件
 func (y *Cloud189PC) StreamUpload(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress, isFamily bool, overwrite bool) (model.Obj, error) {
 	size := file.GetSize()
-	sliceSize := partSize(size)
+	sliceSize := min(size, partSize(size))
 
 	params := Params{
 		"parentFolderId": dstDir.GetID(),
@@ -499,65 +500,99 @@ func (y *Cloud189PC) StreamUpload(ctx context.Context, dstDir model.Obj, file mo
 	if err != nil {
 		return nil, err
 	}
+	ss, err := stream.NewStreamSectionReader(file, int(sliceSize))
+	if err != nil {
+		return nil, err
+	}
 
-	threadG, upCtx := errgroup.NewGroupWithContext(ctx, y.uploadThread,
+	threadG, upCtx := errgroup.NewOrderedGroupWithContext(ctx, y.uploadThread,
 		retry.Attempts(3),
 		retry.Delay(time.Second),
 		retry.DelayType(retry.BackOffDelay))
 
-	count := int(size / sliceSize)
+	count := 1
+	if size > sliceSize {
+		count = int((size + sliceSize - 1) / sliceSize)
+	}
 	lastPartSize := size % sliceSize
-	if lastPartSize > 0 {
-		count++
-	} else {
+	if lastPartSize == 0 {
 		lastPartSize = sliceSize
 	}
 
-	fileMd5 := utils.MD5.NewFunc()
-	silceMd5 := utils.MD5.NewFunc()
+	silceMd5Hexs := make([]string, 0, count)
-	teeReader := io.TeeReader(file, io.MultiWriter(fileMd5, silceMd5))
-	byteSize := sliceSize
+	silceMd5 := utils.MD5.NewFunc()
+	var writers io.Writer = silceMd5
+
+	fileMd5Hex := file.GetHash().GetHash(utils.MD5)
+	var fileMd5 hash.Hash
+	if len(fileMd5Hex) != utils.MD5.Width {
+		fileMd5 = utils.MD5.NewFunc()
+		writers = io.MultiWriter(silceMd5, fileMd5)
+	}
 	for i := 1; i <= count; i++ {
 		if utils.IsCanceled(upCtx) {
 			break
 		}
+		offset := int64(i-1) * sliceSize
+		size := sliceSize
 		if i == count {
-			byteSize = lastPartSize
-		}
-		byteData := make([]byte, byteSize)
-		// 读取块
-		silceMd5.Reset()
-		if _, err := io.ReadFull(teeReader, byteData); err != io.EOF && err != nil {
-			return nil, err
+			size = lastPartSize
 		}
+		partInfo := ""
+		var reader *stream.SectionReader
+		var rateLimitedRd io.Reader
+		threadG.GoWithLifecycle(errgroup.Lifecycle{
+			Before: func(ctx context.Context) error {
+				if reader == nil {
+					var err error
+					reader, err = ss.GetSectionReader(offset, size)
+					if err != nil {
+						return err
+					}
+					silceMd5.Reset()
+					w, _ := utils.CopyWithBuffer(writers, reader)
+					if w != size {
+						return fmt.Errorf("can't read data, expected=%d, got=%d", size, w)
+					}
+					// compute the slice MD5 and encode it as hex and base64
+					md5Bytes := silceMd5.Sum(nil)
+					silceMd5Hexs = append(silceMd5Hexs, strings.ToUpper(hex.EncodeToString(md5Bytes)))
+					partInfo = fmt.Sprintf("%d-%s", i, base64.StdEncoding.EncodeToString(md5Bytes))
-
-		// 计算块md5并进行hex和base64编码
-		md5Bytes := silceMd5.Sum(nil)
-		silceMd5Hexs = append(silceMd5Hexs, strings.ToUpper(hex.EncodeToString(md5Bytes)))
-		partInfo := fmt.Sprintf("%d-%s", i, base64.StdEncoding.EncodeToString(md5Bytes))
+					rateLimitedRd = driver.NewLimitedUploadStream(ctx, reader)
+				}
+				return nil
+			},
+			Do: func(ctx context.Context) error {
+				reader.Seek(0, io.SeekStart)
+				uploadUrls, err := y.GetMultiUploadUrls(ctx, isFamily, initMultiUpload.Data.UploadFileID, partInfo)
+				if err != nil {
+					return err
+				}
-		threadG.Go(func(ctx context.Context) error {
-			uploadUrls, err := y.GetMultiUploadUrls(ctx, isFamily, initMultiUpload.Data.UploadFileID, partInfo)
-			if err != nil {
-				return err
-			}
-
-			// step.4 上传切片
-			uploadUrl := uploadUrls[0]
-			_, err = y.put(ctx, uploadUrl.RequestURL, uploadUrl.Headers, false,
-				driver.NewLimitedUploadStream(ctx, bytes.NewReader(byteData)), isFamily)
-			if err != nil {
-				return err
-			}
-			up(float64(threadG.Success()) * 100 / float64(count))
-			return nil
-		})
+				// step.4 upload the slice
+				uploadUrl := uploadUrls[0]
+				_, err = y.put(ctx, uploadUrl.RequestURL, uploadUrl.Headers, false,
+					rateLimitedRd, isFamily)
+				if err != nil {
+					return err
+				}
+				up(float64(threadG.Success()) * 100 / float64(count))
+				return nil
+			},
+			After: func(err error) {
+				ss.RecycleSectionReader(reader)
+			},
+		},
+		)
 	}
 	if err = threadG.Wait(); err != nil {
 		return nil, err
 	}
 
-	fileMd5Hex := strings.ToUpper(hex.EncodeToString(fileMd5.Sum(nil)))
+	if fileMd5 != nil {
+		fileMd5Hex = strings.ToUpper(hex.EncodeToString(fileMd5.Sum(nil)))
+	}
 	sliceMd5Hex := fileMd5Hex
 	if file.GetSize() > sliceSize {
 		sliceMd5Hex = strings.ToUpper(utils.GetMD5EncodeStr(strings.Join(silceMd5Hexs, "\n")))
@@ -620,11 +655,12 @@ func (y *Cloud189PC) FastUpload(ctx context.Context, dstDir model.Obj, file mode
 		cache = tmpF
 	}
 	sliceSize := partSize(size)
-	count := int(size / sliceSize)
+	count := 1
+	if size > sliceSize {
+		count = int((size + sliceSize - 1) / sliceSize)
+	}
 	lastSliceSize := size % sliceSize
-	if lastSliceSize > 0 {
-		count++
-	} else {
+	if lastSliceSize == 0 {
 		lastSliceSize = sliceSize
 	}
 
@@ -738,7 +774,8 @@ func (y *Cloud189PC) FastUpload(ctx context.Context, dstDir model.Obj, file mode
 	}
 
 	// step.4 上传切片
-	_, err = y.put(ctx, uploadUrl.RequestURL, uploadUrl.Headers, false, io.NewSectionReader(cache, offset, byteSize), isFamily)
+	rateLimitedRd := driver.NewLimitedUploadStream(ctx, io.NewSectionReader(cache, offset, byteSize))
+	_, err = y.put(ctx, uploadUrl.RequestURL, uploadUrl.Headers, false, rateLimitedRd, isFamily)
 	if err != nil {
 		return err
 	}
diff --git a/drivers/aliyundrive/driver.go b/drivers/aliyundrive/driver.go
index ae2abb3a..92df0319 100644
--- a/drivers/aliyundrive/driver.go
+++ b/drivers/aliyundrive/driver.go
@@ -297,11 +297,10 @@ func (d *AliDrive) Put(ctx context.Context, dstDir model.Obj, streamer model.Fil
 		if d.InternalUpload {
 			url = partInfo.InternalUploadUrl
 		}
-		req, err := http.NewRequest("PUT", url, io.LimitReader(rateLimited, DEFAULT))
+		req, err := http.NewRequestWithContext(ctx, http.MethodPut, url, io.LimitReader(rateLimited, DEFAULT))
 		if err != nil {
 			return err
 		}
-		req = req.WithContext(ctx)
 		res, err := base.HttpClient.Do(req)
 		if err != nil {
 			return err
diff --git a/drivers/aliyundrive_open/upload.go b/drivers/aliyundrive_open/upload.go
index 51a63017..98852706 100644
--- a/drivers/aliyundrive_open/upload.go
+++ b/drivers/aliyundrive_open/upload.go
@@ -69,7 +69,7 @@ func (d *AliyundriveOpen) uploadPart(ctx context.Context, r io.Reader, partInfo
 	if d.InternalUpload {
 		uploadUrl = strings.ReplaceAll(uploadUrl, "https://cn-beijing-data.aliyundrive.net/", "http://ccp-bj29-bj-1592982087.oss-cn-beijing-internal.aliyuncs.com/")
 	}
-	req, err := http.NewRequestWithContext(ctx, "PUT", uploadUrl, r)
+	req, err := http.NewRequestWithContext(ctx, http.MethodPut, uploadUrl, r)
 	if err != nil {
 		return err
 	}
@@ -225,6 +225,10 @@ func (d *AliyundriveOpen) upload(ctx context.Context, dstDir model.Obj, stream m
 	preTime := time.Now()
 	var offset, length int64 = 0, partSize
 	//var length
+	ss, err := streamPkg.NewStreamSectionReader(stream, int(partSize))
+	if err != nil {
+		return nil, err
+	}
 	for i := 0; i < len(createResp.PartInfoList); i++ {
 		if utils.IsCanceled(ctx) {
 			return nil, ctx.Err()
@@ -240,22 +244,19 @@ func (d *AliyundriveOpen) upload(ctx context.Context, dstDir model.Obj, stream m
 		if remain := stream.GetSize() - offset; length > remain {
 			length = remain
 		}
-		rd := utils.NewMultiReadable(io.LimitReader(stream, partSize))
-		if rapidUpload {
-			srd, err := stream.RangeRead(http_range.Range{Start: offset, Length: length})
-			if err != nil {
-				return nil, err
-			}
-			rd = utils.NewMultiReadable(srd)
+		rd, err := ss.GetSectionReader(offset, length)
+		if err != nil {
+			return nil, err
 		}
+		rateLimitedRd := driver.NewLimitedUploadStream(ctx, rd)
 		err = retry.Do(func() error {
-			_ = rd.Reset()
-			rateLimitedRd := driver.NewLimitedUploadStream(ctx, rd)
+			rd.Seek(0, io.SeekStart)
 			return d.uploadPart(ctx, rateLimitedRd, createResp.PartInfoList[i])
 		},
 			retry.Attempts(3),
 			retry.DelayType(retry.BackOffDelay),
 			retry.Delay(time.Second))
+		ss.RecycleSectionReader(rd)
 		if err != nil {
 			return nil, err
 		}
diff --git a/drivers/baidu_netdisk/driver.go b/drivers/baidu_netdisk/driver.go
index 9bbaa3ae..0fa94e88 100644
--- a/drivers/baidu_netdisk/driver.go
+++ b/drivers/baidu_netdisk/driver.go
@@ -203,11 +203,12 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
 
 	streamSize := stream.GetSize()
 	sliceSize := d.getSliceSize(streamSize)
-	count := int(streamSize / sliceSize)
+	count := 1
+	if streamSize > sliceSize {
+		count = int((streamSize + sliceSize - 1) / sliceSize)
+	}
 	lastBlockSize := streamSize % sliceSize
-	if lastBlockSize > 0 {
-		count++
-	} else {
+	if lastBlockSize == 0 {
 		lastBlockSize = sliceSize
 	}
diff --git a/drivers/baidu_photo/driver.go b/drivers/baidu_photo/driver.go
index bbd6e74e..00e36ee6 100644
--- a/drivers/baidu_photo/driver.go
+++ b/drivers/baidu_photo/driver.go
@@ -262,11 +262,12 @@ func (d *BaiduPhoto) Put(ctx context.Context, dstDir model.Obj, stream model.Fil
 	// 计算需要的数据
 	streamSize := stream.GetSize()
-	count := int(streamSize / DEFAULT)
+	count := 1
+	if streamSize > DEFAULT {
+		count = int((streamSize + DEFAULT - 1) / DEFAULT)
+	}
 	lastBlockSize := streamSize % DEFAULT
-	if lastBlockSize > 0 {
-		count++
-	} else {
+	if lastBlockSize == 0 {
 		lastBlockSize = DEFAULT
 	}
diff --git a/drivers/chaoxing/driver.go b/drivers/chaoxing/driver.go
index 79835995..cb12b29f 100644
--- a/drivers/chaoxing/driver.go
+++ b/drivers/chaoxing/driver.go
@@ -255,7 +255,7 @@ func (d *ChaoXing) Put(ctx context.Context, dstDir model.Obj, file model.FileStr
 		},
 		UpdateProgress: up,
 	})
-	req, err := http.NewRequestWithContext(ctx, "POST", "https://pan-yz.chaoxing.com/upload", r)
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://pan-yz.chaoxing.com/upload", r)
 	if err != nil {
 		return err
 	}
diff --git a/drivers/chaoxing/util.go b/drivers/chaoxing/util.go
index 03caa1ee..715c248a 100644
--- a/drivers/chaoxing/util.go
+++ b/drivers/chaoxing/util.go
@@ -167,7 +167,7 @@ func (d *ChaoXing) Login() (string, error) {
 		return "", err
 	}
 	// Create the request
-	req, err := http.NewRequest("POST", "https://passport2.chaoxing.com/fanyalogin", body)
+	req, err := http.NewRequest(http.MethodPost, "https://passport2.chaoxing.com/fanyalogin", body)
 	if err != nil {
 		return "", err
 	}
diff --git a/drivers/cloudreve/util.go b/drivers/cloudreve/util.go
index 06b51319..88ff67cc 100644
--- a/drivers/cloudreve/util.go
+++ b/drivers/cloudreve/util.go
@@ -18,6 +18,7 @@ import (
 	"github.com/OpenListTeam/OpenList/v4/internal/driver"
 	"github.com/OpenListTeam/OpenList/v4/internal/model"
 	"github.com/OpenListTeam/OpenList/v4/internal/setting"
+	streamPkg "github.com/OpenListTeam/OpenList/v4/internal/stream"
 	"github.com/OpenListTeam/OpenList/v4/pkg/cookie"
 	"github.com/OpenListTeam/OpenList/v4/pkg/utils"
 	"github.com/avast/retry-go"
@@ -241,23 +242,26 @@ func (d *Cloudreve) upRemote(ctx context.Context, stream model.FileStreamer, u U
 	var finish int64 = 0
 	var chunk int = 0
 	DEFAULT := int64(u.ChunkSize)
+	ss, err := streamPkg.NewStreamSectionReader(stream, int(DEFAULT))
+	if err != nil {
+		return err
+	}
 	for finish < stream.GetSize() {
 		if utils.IsCanceled(ctx) {
 			return ctx.Err()
 		}
 		left := stream.GetSize() - finish
 		byteSize := min(left, DEFAULT)
-		err := retry.Do(
+		utils.Log.Debugf("[Cloudreve-Remote] upload range: %d-%d/%d", finish, finish+byteSize-1, stream.GetSize())
+		rd, err := ss.GetSectionReader(finish, byteSize)
+		if err != nil {
+			return err
+		}
+		err = retry.Do(
 			func() error {
-				utils.Log.Debugf("[Cloudreve-Remote] upload range: %d-%d/%d", finish, finish+byteSize-1, stream.GetSize())
-				byteData := make([]byte, byteSize)
-				n, err := io.ReadFull(stream, byteData)
-				utils.Log.Debug(err, n)
-				if err != nil {
-					return err
-				}
+				rd.Seek(0, io.SeekStart)
 				req, err := http.NewRequestWithContext(ctx, http.MethodPost, uploadUrl+"?chunk="+strconv.Itoa(chunk),
-					driver.NewLimitedUploadStream(ctx, bytes.NewReader(byteData)))
+					driver.NewLimitedUploadStream(ctx, rd))
 				if err != nil {
 					return err
 				}
@@ -290,6 +294,7 @@ func (d *Cloudreve) upRemote(ctx context.Context, stream model.FileStreamer, u U
 			retry.DelayType(retry.BackOffDelay),
 			retry.Delay(time.Second),
 		)
+		ss.RecycleSectionReader(rd)
 		if err != nil {
 			return err
 		}
@@ -304,23 +309,25 @@ func (d *Cloudreve) upOneDrive(ctx context.Context, stream model.FileStreamer, u
 	uploadUrl := u.UploadURLs[0]
 	var finish int64 = 0
 	DEFAULT := int64(u.ChunkSize)
+	ss, err := streamPkg.NewStreamSectionReader(stream, int(DEFAULT))
+	if err != nil {
+		return err
+	}
 	for finish < stream.GetSize() {
 		if utils.IsCanceled(ctx) {
 			return ctx.Err()
 		}
 		left := stream.GetSize() - finish
 		byteSize := min(left, DEFAULT)
-		err := retry.Do(
+		utils.Log.Debugf("[Cloudreve-OneDrive] upload range: %d-%d/%d", finish, finish+byteSize-1, stream.GetSize())
+		rd, err := ss.GetSectionReader(finish, byteSize)
+		if err != nil {
+			return err
+		}
+		err = retry.Do(
 			func() error {
-				utils.Log.Debugf("[Cloudreve-OneDrive] upload range: %d-%d/%d", finish, finish+byteSize-1, stream.GetSize())
-				byteData := make([]byte, byteSize)
-				n, err := io.ReadFull(stream, byteData)
-				utils.Log.Debug(err, n)
-				if err != nil {
-					return err
-				}
-				req, err := http.NewRequestWithContext(ctx, http.MethodPut, uploadUrl,
-					driver.NewLimitedUploadStream(ctx, bytes.NewReader(byteData)))
+				rd.Seek(0, io.SeekStart)
+				req, err := http.NewRequestWithContext(ctx, http.MethodPut, uploadUrl, driver.NewLimitedUploadStream(ctx, rd))
 				if err != nil {
 					return err
 				}
@@ -346,6 +353,7 @@ func (d *Cloudreve) upOneDrive(ctx context.Context, stream model.FileStreamer, u
 			retry.DelayType(retry.BackOffDelay),
 			retry.Delay(time.Second),
 		)
+		ss.RecycleSectionReader(rd)
 		if err != nil {
 			return err
 		}
@@ -363,23 +371,26 @@ func (d *Cloudreve) upS3(ctx context.Context, stream model.FileStreamer, u Uploa
 	var chunk int = 0
 	var etags []string
 	DEFAULT := int64(u.ChunkSize)
+	ss, err := streamPkg.NewStreamSectionReader(stream, int(DEFAULT))
+	if err != nil {
+		return err
+	}
 	for finish < stream.GetSize() {
 		if utils.IsCanceled(ctx) {
 			return ctx.Err()
 		}
 		left := stream.GetSize() - finish
 		byteSize := min(left, DEFAULT)
-		err := retry.Do(
+		utils.Log.Debugf("[Cloudreve-S3] upload range: %d-%d/%d", finish, finish+byteSize-1, stream.GetSize())
+		rd, err := ss.GetSectionReader(finish, byteSize)
+		if err != nil {
+			return err
+		}
+		err = retry.Do(
 			func() error {
-				utils.Log.Debugf("[Cloudreve-S3] upload range: %d-%d/%d", finish, finish+byteSize-1, stream.GetSize())
-				byteData := make([]byte, byteSize)
-				n, err := io.ReadFull(stream, byteData)
-				utils.Log.Debug(err, n)
-				if err != nil {
-					return err
-				}
+				rd.Seek(0, io.SeekStart)
 				req, err := http.NewRequestWithContext(ctx, http.MethodPut, u.UploadURLs[chunk],
-					driver.NewLimitedUploadStream(ctx, bytes.NewBuffer(byteData)))
+					driver.NewLimitedUploadStream(ctx, rd))
 				if err != nil {
 					return err
 				}
@@ -404,6 +415,7 @@ func (d *Cloudreve) upS3(ctx context.Context, stream model.FileStreamer, u Uploa
 			retry.DelayType(retry.BackOffDelay),
 			retry.Delay(time.Second),
 		)
+		ss.RecycleSectionReader(rd)
 		if err != nil {
 			return err
 		}
diff --git a/drivers/cloudreve_v4/util.go b/drivers/cloudreve_v4/util.go
index 5e7559e3..215b4c53 100644
--- a/drivers/cloudreve_v4/util.go
+++ b/drivers/cloudreve_v4/util.go
@@ -19,6 +19,7 @@ import (
 	"github.com/OpenListTeam/OpenList/v4/internal/model"
 	"github.com/OpenListTeam/OpenList/v4/internal/op"
 	"github.com/OpenListTeam/OpenList/v4/internal/setting"
+	"github.com/OpenListTeam/OpenList/v4/internal/stream"
 	"github.com/OpenListTeam/OpenList/v4/pkg/utils"
 	"github.com/avast/retry-go"
 	"github.com/go-resty/resty/v2"
@@ -256,23 +257,26 @@ func (d *CloudreveV4) upRemote(ctx context.Context, file model.FileStreamer, u F
 	var finish int64 = 0
 	var chunk int = 0
 	DEFAULT := int64(u.ChunkSize)
+	ss, err := stream.NewStreamSectionReader(file, int(DEFAULT))
+	if err != nil {
+		return err
+	}
 	for finish < file.GetSize() {
 		if utils.IsCanceled(ctx) {
 			return ctx.Err()
 		}
 		left := file.GetSize() - finish
 		byteSize := min(left, DEFAULT)
-		err := retry.Do(
upload range: %d-%d/%d", finish, finish+byteSize-1, file.GetSize()) + rd, err := ss.GetSectionReader(finish, byteSize) + if err != nil { + return err + } + err = retry.Do( func() error { - utils.Log.Debugf("[CloudreveV4-Remote] upload range: %d-%d/%d", finish, finish+byteSize-1, file.GetSize()) - byteData := make([]byte, byteSize) - n, err := io.ReadFull(file, byteData) - utils.Log.Debug(err, n) - if err != nil { - return err - } + rd.Seek(0, io.SeekStart) req, err := http.NewRequestWithContext(ctx, http.MethodPost, uploadUrl+"?chunk="+strconv.Itoa(chunk), - driver.NewLimitedUploadStream(ctx, bytes.NewReader(byteData))) + driver.NewLimitedUploadStream(ctx, rd)) if err != nil { return err } @@ -305,6 +309,7 @@ func (d *CloudreveV4) upRemote(ctx context.Context, file model.FileStreamer, u F retry.DelayType(retry.BackOffDelay), retry.Delay(time.Second), ) + ss.RecycleSectionReader(rd) if err != nil { return err } @@ -319,23 +324,25 @@ func (d *CloudreveV4) upOneDrive(ctx context.Context, file model.FileStreamer, u uploadUrl := u.UploadUrls[0] var finish int64 = 0 DEFAULT := int64(u.ChunkSize) + ss, err := stream.NewStreamSectionReader(file, int(DEFAULT)) + if err != nil { + return err + } for finish < file.GetSize() { if utils.IsCanceled(ctx) { return ctx.Err() } left := file.GetSize() - finish byteSize := min(left, DEFAULT) - err := retry.Do( + utils.Log.Debugf("[CloudreveV4-OneDrive] upload range: %d-%d/%d", finish, finish+byteSize-1, file.GetSize()) + rd, err := ss.GetSectionReader(finish, byteSize) + if err != nil { + return err + } + err = retry.Do( func() error { - utils.Log.Debugf("[CloudreveV4-OneDrive] upload range: %d-%d/%d", finish, finish+byteSize-1, file.GetSize()) - byteData := make([]byte, byteSize) - n, err := io.ReadFull(file, byteData) - utils.Log.Debug(err, n) - if err != nil { - return err - } - req, err := http.NewRequestWithContext(ctx, http.MethodPut, uploadUrl, - driver.NewLimitedUploadStream(ctx, bytes.NewReader(byteData))) + rd.Seek(0, io.SeekStart) + req, err := http.NewRequestWithContext(ctx, http.MethodPut, uploadUrl, driver.NewLimitedUploadStream(ctx, rd)) if err != nil { return err } @@ -362,6 +369,7 @@ func (d *CloudreveV4) upOneDrive(ctx context.Context, file model.FileStreamer, u retry.DelayType(retry.BackOffDelay), retry.Delay(time.Second), ) + ss.RecycleSectionReader(rd) if err != nil { return err } @@ -379,23 +387,26 @@ func (d *CloudreveV4) upS3(ctx context.Context, file model.FileStreamer, u FileU var chunk int = 0 var etags []string DEFAULT := int64(u.ChunkSize) + ss, err := stream.NewStreamSectionReader(file, int(DEFAULT)) + if err != nil { + return err + } for finish < file.GetSize() { if utils.IsCanceled(ctx) { return ctx.Err() } left := file.GetSize() - finish byteSize := min(left, DEFAULT) - err := retry.Do( + utils.Log.Debugf("[CloudreveV4-S3] upload range: %d-%d/%d", finish, finish+byteSize-1, file.GetSize()) + rd, err := ss.GetSectionReader(finish, byteSize) + if err != nil { + return err + } + err = retry.Do( func() error { - utils.Log.Debugf("[CloudreveV4-S3] upload range: %d-%d/%d", finish, finish+byteSize-1, file.GetSize()) - byteData := make([]byte, byteSize) - n, err := io.ReadFull(file, byteData) - utils.Log.Debug(err, n) - if err != nil { - return err - } + rd.Seek(0, io.SeekStart) req, err := http.NewRequestWithContext(ctx, http.MethodPut, u.UploadUrls[chunk], - driver.NewLimitedUploadStream(ctx, bytes.NewBuffer(byteData))) + driver.NewLimitedUploadStream(ctx, rd)) if err != nil { return err } @@ -421,6 +432,7 @@ func (d *CloudreveV4) 
@@ -421,6 +432,7 @@ func (d *CloudreveV4) upS3(ctx context.Context, file model.FileStreamer, u FileU
 			retry.DelayType(retry.BackOffDelay),
 			retry.Delay(time.Second),
 		)
+		ss.RecycleSectionReader(rd)
 		if err != nil {
 			return err
 		}
diff --git a/drivers/doubao/driver.go b/drivers/doubao/driver.go
index 1819c686..d2ba04ea 100644
--- a/drivers/doubao/driver.go
+++ b/drivers/doubao/driver.go
@@ -236,7 +236,7 @@ func (d *Doubao) Put(ctx context.Context, dstDir model.Obj, file model.FileStrea
 	// 根据文件大小选择上传方式
 	if file.GetSize() <= 1*utils.MB { // 小于1MB,使用普通模式上传
-		return d.Upload(&uploadConfig, dstDir, file, up, dataType)
+		return d.Upload(ctx, &uploadConfig, dstDir, file, up, dataType)
 	}
 	// 大文件使用分片上传
 	return d.UploadByMultipart(ctx, &uploadConfig, file.GetSize(), dstDir, file, up, dataType)
diff --git a/drivers/doubao/util.go b/drivers/doubao/util.go
index bc633baf..39d55134 100644
--- a/drivers/doubao/util.go
+++ b/drivers/doubao/util.go
@@ -24,6 +24,7 @@ import (
 	"github.com/OpenListTeam/OpenList/v4/drivers/base"
 	"github.com/OpenListTeam/OpenList/v4/internal/driver"
 	"github.com/OpenListTeam/OpenList/v4/internal/model"
+	"github.com/OpenListTeam/OpenList/v4/internal/stream"
 	"github.com/OpenListTeam/OpenList/v4/pkg/errgroup"
 	"github.com/OpenListTeam/OpenList/v4/pkg/utils"
 	"github.com/avast/retry-go"
@@ -447,41 +448,66 @@ func (d *Doubao) uploadNode(uploadConfig *UploadConfig, dir model.Obj, file mode
 }
 
 // Upload 普通上传实现
-func (d *Doubao) Upload(config *UploadConfig, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress, dataType string) (model.Obj, error) {
-	data, err := io.ReadAll(file)
+func (d *Doubao) Upload(ctx context.Context, config *UploadConfig, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress, dataType string) (model.Obj, error) {
+	ss, err := stream.NewStreamSectionReader(file, int(file.GetSize()))
+	if err != nil {
+		return nil, err
+	}
+	reader, err := ss.GetSectionReader(0, file.GetSize())
 	if err != nil {
 		return nil, err
 	}
 	// 计算CRC32
 	crc32Hash := crc32.NewIEEE()
-	crc32Hash.Write(data)
+	w, _ := utils.CopyWithBuffer(crc32Hash, reader)
+	if w != file.GetSize() {
+		return nil, fmt.Errorf("can't read data, expected=%d, got=%d", file.GetSize(), w)
+	}
 	crc32Value := hex.EncodeToString(crc32Hash.Sum(nil))
 
 	// 构建请求路径
 	uploadNode := config.InnerUploadAddress.UploadNodes[0]
 	storeInfo := uploadNode.StoreInfos[0]
 	uploadUrl := fmt.Sprintf("https://%s/upload/v1/%s", uploadNode.UploadHost, storeInfo.StoreURI)
-
-	uploadResp := UploadResp{}
-
-	if _, err = d.uploadRequest(uploadUrl, http.MethodPost, storeInfo, func(req *resty.Request) {
-		req.SetHeaders(map[string]string{
-			"Content-Type":        "application/octet-stream",
-			"Content-Crc32":       crc32Value,
-			"Content-Length":      fmt.Sprintf("%d", len(data)),
-			"Content-Disposition": fmt.Sprintf("attachment; filename=%s", url.QueryEscape(storeInfo.StoreURI)),
-		})
-
-		req.SetBody(data)
-	}, &uploadResp); err != nil {
+	rateLimitedRd := driver.NewLimitedUploadStream(ctx, reader)
+	err = d._retryOperation("Upload", func() error {
+		reader.Seek(0, io.SeekStart)
+		req, err := http.NewRequestWithContext(ctx, http.MethodPost, uploadUrl, rateLimitedRd)
+		if err != nil {
+			return err
+		}
+		req.Header = map[string][]string{
+			"Referer":             {BaseURL + "/"},
+			"Origin":              {BaseURL},
+			"User-Agent":          {UserAgent},
+			"X-Storage-U":         {d.UserId},
+			"Authorization":       {storeInfo.Auth},
+			"Content-Type":        {"application/octet-stream"},
+			"Content-Crc32":       {crc32Value},
+			"Content-Length":      {fmt.Sprintf("%d", file.GetSize())},
+			"Content-Disposition": {fmt.Sprintf("attachment; filename=%s", url.QueryEscape(storeInfo.StoreURI))},
+		}
+		res, err := base.HttpClient.Do(req)
+		if err != nil {
+			return err
+		}
+		defer res.Body.Close()
+		bytes, _ := io.ReadAll(res.Body)
+		resp := UploadResp{}
+		utils.Json.Unmarshal(bytes, &resp)
+		if resp.Code != 2000 {
+			return fmt.Errorf("upload part failed: %s", resp.Message)
+		} else if resp.Data.Crc32 != crc32Value {
+			return fmt.Errorf("upload part failed: crc32 mismatch, expected %s, got %s", crc32Value, resp.Data.Crc32)
+		}
+		return nil
+	})
+	ss.RecycleSectionReader(reader)
+	if err != nil {
 		return nil, err
 	}
 
-	if uploadResp.Code != 2000 {
-		return nil, fmt.Errorf("upload failed: %s", uploadResp.Message)
-	}
-
 	uploadNodeResp, err := d.uploadNode(config, dstDir, file, dataType)
 	if err != nil {
 		return nil, err
@@ -519,65 +545,104 @@ func (d *Doubao) UploadByMultipart(ctx context.Context, config *UploadConfig, fi
 	totalParts := (fileSize + chunkSize - 1) / chunkSize
 	// 创建分片信息组
 	parts := make([]UploadPart, totalParts)
-	// 缓存文件
-	tempFile, err := file.CacheFullInTempFile()
+
+	// use stream.NewStreamSectionReader instead of caching the file to a temp file
+	ss, err := stream.NewStreamSectionReader(file, int(chunkSize))
 	if err != nil {
-		return nil, fmt.Errorf("failed to cache file: %w", err)
+		return nil, err
 	}
 	up(10.0) // 更新进度
 	// 设置并行上传
-	threadG, uploadCtx := errgroup.NewGroupWithContext(ctx, d.uploadThread,
-		retry.Attempts(1),
+	thread := min(int(totalParts), d.uploadThread)
+	threadG, uploadCtx := errgroup.NewOrderedGroupWithContext(ctx, thread,
+		retry.Attempts(MaxRetryAttempts),
 		retry.Delay(time.Second),
-		retry.DelayType(retry.BackOffDelay))
+		retry.DelayType(retry.BackOffDelay),
+		retry.MaxJitter(200*time.Millisecond),
+	)
 	var partsMutex sync.Mutex
 	// 并行上传所有分片
-	for partIndex := int64(0); partIndex < totalParts; partIndex++ {
+	hash := crc32.NewIEEE()
+	for partIndex := range totalParts {
 		if utils.IsCanceled(uploadCtx) {
 			break
 		}
-		partIndex := partIndex
 		partNumber := partIndex + 1 // 分片编号从1开始
-		threadG.Go(func(ctx context.Context) error {
-			// 计算此分片的大小和偏移
-			offset := partIndex * chunkSize
-			size := chunkSize
-			if partIndex == totalParts-1 {
-				size = fileSize - offset
-			}
-
-			limitedReader := driver.NewLimitedUploadStream(ctx, io.NewSectionReader(tempFile, offset, size))
-			// 读取数据到内存
-			data, err := io.ReadAll(limitedReader)
-			if err != nil {
-				return fmt.Errorf("failed to read part %d: %w", partNumber, err)
-			}
-			// 计算CRC32
-			crc32Value := calculateCRC32(data)
-			// 使用_retryOperation上传分片
-			var uploadPart UploadPart
-			if err = d._retryOperation(fmt.Sprintf("Upload part %d", partNumber), func() error {
-				var err error
-				uploadPart, err = d.uploadPart(config, uploadUrl, uploadID, partNumber, data, crc32Value)
-				return err
-			}); err != nil {
-				return fmt.Errorf("part %d upload failed: %w", partNumber, err)
-			}
-			// 记录成功上传的分片
-			partsMutex.Lock()
-			parts[partIndex] = UploadPart{
-				PartNumber: strconv.FormatInt(partNumber, 10),
-				Etag:       uploadPart.Etag,
-				Crc32:      crc32Value,
-			}
-			partsMutex.Unlock()
-			// 更新进度
-			progress := 10.0 + 90.0*float64(threadG.Success()+1)/float64(totalParts)
-			up(math.Min(progress, 95.0))
-
-			return nil
+		// size and offset of this part
+		offset := partIndex * chunkSize
+		size := chunkSize
+		if partIndex == totalParts-1 {
+			size = fileSize - offset
+		}
+		var reader *stream.SectionReader
+		var rateLimitedRd io.Reader
+		crc32Value := ""
+		threadG.GoWithLifecycle(errgroup.Lifecycle{
+			Before: func(ctx context.Context) error {
+				if reader == nil {
+					var err error
+					reader, err = ss.GetSectionReader(offset, size)
+					if err != nil {
+						return err
+					}
+					hash.Reset()
+					w, _ := utils.CopyWithBuffer(hash, reader)
+					if w != size {
+						return fmt.Errorf("can't read data, expected=%d, got=%d", size, w)
+					}
+					crc32Value = hex.EncodeToString(hash.Sum(nil))
+					rateLimitedRd = driver.NewLimitedUploadStream(ctx, reader)
+				}
+				return nil
+			},
+			Do: func(ctx context.Context) error {
+				reader.Seek(0, io.SeekStart)
+				req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s?uploadid=%s&part_number=%d&phase=transfer", uploadUrl, uploadID, partNumber), rateLimitedRd)
+				if err != nil {
+					return err
+				}
+				req.Header = map[string][]string{
+					"Referer":             {BaseURL + "/"},
+					"Origin":              {BaseURL},
+					"User-Agent":          {UserAgent},
+					"X-Storage-U":         {d.UserId},
+					"Authorization":       {storeInfo.Auth},
+					"Content-Type":        {"application/octet-stream"},
+					"Content-Crc32":       {crc32Value},
+					"Content-Length":      {fmt.Sprintf("%d", size)},
+					"Content-Disposition": {fmt.Sprintf("attachment; filename=%s", url.QueryEscape(storeInfo.StoreURI))},
+				}
+				res, err := base.HttpClient.Do(req)
+				if err != nil {
+					return err
+				}
+				defer res.Body.Close()
+				bytes, _ := io.ReadAll(res.Body)
+				uploadResp := UploadResp{}
+				utils.Json.Unmarshal(bytes, &uploadResp)
+				if uploadResp.Code != 2000 {
+					return fmt.Errorf("upload part failed: %s", uploadResp.Message)
+				} else if uploadResp.Data.Crc32 != crc32Value {
+					return fmt.Errorf("upload part failed: crc32 mismatch, expected %s, got %s", crc32Value, uploadResp.Data.Crc32)
+				}
+				// record the successfully uploaded part
+				partsMutex.Lock()
+				parts[partIndex] = UploadPart{
+					PartNumber: strconv.FormatInt(partNumber, 10),
+					Etag:       uploadResp.Data.Etag,
+					Crc32:      crc32Value,
+				}
+				partsMutex.Unlock()
+				// update progress
+				progress := 10.0 + 90.0*float64(threadG.Success()+1)/float64(totalParts)
+				up(math.Min(progress, 95.0))
+				return nil
+			},
+			After: func(err error) {
+				ss.RecycleSectionReader(reader)
+			},
 		})
 	}
@@ -680,42 +745,6 @@ func (d *Doubao) initMultipartUpload(config *UploadConfig, uploadUrl string, sto
 	return uploadResp.Data.UploadId, nil
 }
 
-// 分片上传实现
-func (d *Doubao) uploadPart(config *UploadConfig, uploadUrl, uploadID string, partNumber int64, data []byte, crc32Value string) (resp UploadPart, err error) {
-	uploadResp := UploadResp{}
-	storeInfo := config.InnerUploadAddress.UploadNodes[0].StoreInfos[0]
-
-	_, err = d.uploadRequest(uploadUrl, http.MethodPost, storeInfo, func(req *resty.Request) {
-		req.SetHeaders(map[string]string{
-			"Content-Type":        "application/octet-stream",
-			"Content-Crc32":       crc32Value,
-			"Content-Length":      fmt.Sprintf("%d", len(data)),
-			"Content-Disposition": fmt.Sprintf("attachment; filename=%s", url.QueryEscape(storeInfo.StoreURI)),
-		})
-
-		req.SetQueryParams(map[string]string{
-			"uploadid":    uploadID,
-			"part_number": strconv.FormatInt(partNumber, 10),
-			"phase":       "transfer",
-		})
-
-		req.SetBody(data)
-		req.SetContentLength(true)
-	}, &uploadResp)
-
-	if err != nil {
-		return resp, err
-	}
-
-	if uploadResp.Code != 2000 {
-		return resp, fmt.Errorf("upload part failed: %s", uploadResp.Message)
-	} else if uploadResp.Data.Crc32 != crc32Value {
-		return resp, fmt.Errorf("upload part failed: crc32 mismatch, expected %s, got %s", crc32Value, uploadResp.Data.Crc32)
-	}
-
-	return uploadResp.Data, nil
-}
-
 // 完成分片上传
 func (d *Doubao) completeMultipartUpload(config *UploadConfig, uploadUrl, uploadID string, parts []UploadPart) error {
 	uploadResp := UploadResp{}
@@ -784,13 +813,6 @@ func (d *Doubao) commitMultipartUpload(uploadConfig *UploadConfig) error {
 	return nil
 }
 
-// 计算CRC32
-func calculateCRC32(data []byte) string {
-	hash := crc32.NewIEEE()
-	hash.Write(data)
-	return hex.EncodeToString(hash.Sum(nil))
-}
-
 // _retryOperation 操作重试
 func (d *Doubao) _retryOperation(operation string, fn func() error) error {
 	return retry.Do(
diff --git a/drivers/dropbox/driver.go b/drivers/dropbox/driver.go
index 4ae3ddf8..000da086 100644
--- a/drivers/dropbox/driver.go
+++ b/drivers/dropbox/driver.go
@@ -192,12 +192,11 @@ func (d *Dropbox) Put(ctx context.Context, dstDir model.Obj, stream model.FileSt
 
 		url := d.contentBase + "/2/files/upload_session/append_v2"
 		reader := driver.NewLimitedUploadStream(ctx, io.LimitReader(stream, PartSize))
-		req, err := http.NewRequest(http.MethodPost, url, reader)
+		req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, reader)
 		if err != nil {
 			log.Errorf("failed to update file when append to upload session, err: %+v", err)
 			return err
 		}
-		req = req.WithContext(ctx)
 		req.Header.Set("Content-Type", "application/octet-stream")
 		req.Header.Set("Authorization", "Bearer "+d.AccessToken)
diff --git a/drivers/dropbox/util.go b/drivers/dropbox/util.go
index bb71118d..73cb4c8c 100644
--- a/drivers/dropbox/util.go
+++ b/drivers/dropbox/util.go
@@ -169,11 +169,10 @@ func (d *Dropbox) getFiles(ctx context.Context, path string) ([]File, error) {
 
 func (d *Dropbox) finishUploadSession(ctx context.Context, toPath string, offset int64, sessionId string) error {
 	url := d.contentBase + "/2/files/upload_session/finish"
-	req, err := http.NewRequest(http.MethodPost, url, nil)
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil)
 	if err != nil {
 		return err
 	}
-	req = req.WithContext(ctx)
 	req.Header.Set("Content-Type", "application/octet-stream")
 	req.Header.Set("Authorization", "Bearer "+d.AccessToken)
@@ -214,11 +213,10 @@ func (d *Dropbox) finishUploadSession(ctx context.Context, toPath string, offset
 
 func (d *Dropbox) startUploadSession(ctx context.Context) (string, error) {
 	url := d.contentBase + "/2/files/upload_session/start"
-	req, err := http.NewRequest(http.MethodPost, url, nil)
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil)
 	if err != nil {
 		return "", err
 	}
-	req = req.WithContext(ctx)
 	req.Header.Set("Content-Type", "application/octet-stream")
 	req.Header.Set("Authorization", "Bearer "+d.AccessToken)
 	req.Header.Set("Dropbox-API-Arg", "{\"close\":false}")
diff --git a/drivers/febbox/oauth2.go b/drivers/febbox/oauth2.go
index 6345d1a7..e9029168 100644
--- a/drivers/febbox/oauth2.go
+++ b/drivers/febbox/oauth2.go
@@ -31,13 +31,13 @@ func (c *customTokenSource) Token() (*oauth2.Token, error) {
 	v.Set("client_id", c.config.ClientID)
 	v.Set("client_secret", c.config.ClientSecret)
 
-	req, err := http.NewRequest("POST", c.config.TokenURL, strings.NewReader(v.Encode()))
+	req, err := http.NewRequestWithContext(c.ctx, http.MethodPost, c.config.TokenURL, strings.NewReader(v.Encode()))
 	if err != nil {
 		return nil, err
 	}
 	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
 
-	resp, err := http.DefaultClient.Do(req.WithContext(c.ctx))
+	resp, err := http.DefaultClient.Do(req)
 	if err != nil {
 		return nil, err
 	}
diff --git a/drivers/ftp/util.go b/drivers/ftp/util.go
index 9e050b4b..c81803d6 100644
--- a/drivers/ftp/util.go
+++ b/drivers/ftp/util.go
@@ -15,8 +15,8 @@ import (
 // do others that not defined in Driver interface
 
 func (d *FTP) login() error {
-	err, _, _ := singleflight.ErrorGroup.Do(fmt.Sprintf("FTP.login:%p", d), func() (error, error) {
-		return d._login(), nil
+	_, err, _ := singleflight.AnyGroup.Do(fmt.Sprintf("FTP.login:%p", d), func() (any, error) {
+		return nil, d._login()
 	})
 	return err
 }
diff --git a/drivers/google_drive/util.go b/drivers/google_drive/util.go
index 97e04f4d..ff219136 100644
--- a/drivers/google_drive/util.go
+++ b/drivers/google_drive/util.go
@@ -5,17 +5,20 @@ import (
 	"crypto/x509"
 	"encoding/pem"
 	"fmt"
-	"github.com/OpenListTeam/OpenList/v4/internal/op"
+	"io"
 	"net/http"
 	"os"
 	"regexp"
 	"strconv"
 	"time"
 
+	"github.com/OpenListTeam/OpenList/v4/internal/op"
+	"github.com/OpenListTeam/OpenList/v4/internal/stream"
+	"github.com/avast/retry-go"
+
 	"github.com/OpenListTeam/OpenList/v4/drivers/base"
 	"github.com/OpenListTeam/OpenList/v4/internal/driver"
 	"github.com/OpenListTeam/OpenList/v4/internal/model"
-	"github.com/OpenListTeam/OpenList/v4/pkg/http_range"
 	"github.com/OpenListTeam/OpenList/v4/pkg/utils"
 	"github.com/go-resty/resty/v2"
 	"github.com/golang-jwt/jwt/v4"
@@ -251,28 +254,58 @@ func (d *GoogleDrive) getFiles(id string) ([]File, error) {
 	return res, nil
 }
 
-func (d *GoogleDrive) chunkUpload(ctx context.Context, stream model.FileStreamer, url string) error {
+func (d *GoogleDrive) chunkUpload(ctx context.Context, file model.FileStreamer, url string) error {
 	var defaultChunkSize = d.ChunkSize * 1024 * 1024
 	var offset int64 = 0
-	for offset < stream.GetSize() {
+	ss, err := stream.NewStreamSectionReader(file, int(defaultChunkSize))
+	if err != nil {
+		return err
+	}
+	url += "?includeItemsFromAllDrives=true&supportsAllDrives=true"
+	for offset < file.GetSize() {
 		if utils.IsCanceled(ctx) {
 			return ctx.Err()
 		}
-		chunkSize := stream.GetSize() - offset
-		if chunkSize > defaultChunkSize {
-			chunkSize = defaultChunkSize
-		}
-		reader, err := stream.RangeRead(http_range.Range{Start: offset, Length: chunkSize})
+		chunkSize := min(file.GetSize()-offset, defaultChunkSize)
+		reader, err := ss.GetSectionReader(offset, chunkSize)
 		if err != nil {
 			return err
 		}
-		reader = driver.NewLimitedUploadStream(ctx, reader)
-		_, err = d.request(url, http.MethodPut, func(req *resty.Request) {
-			req.SetHeaders(map[string]string{
-				"Content-Length": strconv.FormatInt(chunkSize, 10),
-				"Content-Range":  fmt.Sprintf("bytes %d-%d/%d", offset, offset+chunkSize-1, stream.GetSize()),
-			}).SetBody(reader).SetContext(ctx)
-		}, nil)
+		limitedReader := driver.NewLimitedUploadStream(ctx, reader)
+		err = retry.Do(func() error {
+			reader.Seek(0, io.SeekStart)
+			req, err := http.NewRequestWithContext(ctx, http.MethodPut, url, limitedReader)
+			if err != nil {
+				return err
+			}
+			req.Header = map[string][]string{
+				"Authorization":  {"Bearer " + d.AccessToken},
+				"Content-Length": {strconv.FormatInt(chunkSize, 10)},
+				"Content-Range":  {fmt.Sprintf("bytes %d-%d/%d", offset, offset+chunkSize-1, file.GetSize())},
+			}
+			res, err := base.HttpClient.Do(req)
+			if err != nil {
+				return err
+			}
+			defer res.Body.Close()
+			bytes, _ := io.ReadAll(res.Body)
+			var e Error
+			utils.Json.Unmarshal(bytes, &e)
+			if e.Error.Code != 0 {
+				if e.Error.Code == 401 {
+					err = d.refreshToken()
+					if err != nil {
+						return err
+					}
+				}
+				return fmt.Errorf("%s: %v", e.Error.Message, e.Error.Errors)
+			}
+			return nil
+		},
+			retry.Attempts(3),
+			retry.DelayType(retry.BackOffDelay),
+			retry.Delay(time.Second))
+		ss.RecycleSectionReader(reader)
 		if err != nil {
 			return err
 		}
diff --git a/drivers/misskey/util.go b/drivers/misskey/util.go
index 65764f6f..5e7a0d8d 100644
--- a/drivers/misskey/util.go
+++ b/drivers/misskey/util.go
@@ -4,6 +4,7 @@ import (
 	"context"
 	"errors"
 	"io"
+	"net/http"
 	"time"
 
 	"github.com/go-resty/resty/v2"
@@ -72,7 +73,7 @@ func (d *Misskey) getFiles(dir model.Obj) ([]model.Obj, error) {
 	} else {
 		body = map[string]string{}
 	}
d.request("/files", "POST", setBody(body), &files) + err := d.request("/files", http.MethodPost, setBody(body), &files) if err != nil { return []model.Obj{}, err } @@ -89,7 +90,7 @@ func (d *Misskey) getFolders(dir model.Obj) ([]model.Obj, error) { } else { body = map[string]string{} } - err := d.request("/folders", "POST", setBody(body), &folders) + err := d.request("/folders", http.MethodPost, setBody(body), &folders) if err != nil { return []model.Obj{}, err } @@ -106,7 +107,7 @@ func (d *Misskey) list(dir model.Obj) ([]model.Obj, error) { func (d *Misskey) link(file model.Obj) (*model.Link, error) { var mFile MFile - err := d.request("/files/show", "POST", setBody(map[string]string{"fileId": file.GetID()}), &mFile) + err := d.request("/files/show", http.MethodPost, setBody(map[string]string{"fileId": file.GetID()}), &mFile) if err != nil { return nil, err } @@ -117,7 +118,7 @@ func (d *Misskey) link(file model.Obj) (*model.Link, error) { func (d *Misskey) makeDir(parentDir model.Obj, dirName string) (model.Obj, error) { var folder MFolder - err := d.request("/folders/create", "POST", setBody(map[string]interface{}{"parentId": handleFolderId(parentDir), "name": dirName}), &folder) + err := d.request("/folders/create", http.MethodPost, setBody(map[string]interface{}{"parentId": handleFolderId(parentDir), "name": dirName}), &folder) if err != nil { return nil, err } @@ -127,11 +128,11 @@ func (d *Misskey) makeDir(parentDir model.Obj, dirName string) (model.Obj, error func (d *Misskey) move(srcObj, dstDir model.Obj) (model.Obj, error) { if srcObj.IsDir() { var folder MFolder - err := d.request("/folders/update", "POST", setBody(map[string]interface{}{"folderId": srcObj.GetID(), "parentId": handleFolderId(dstDir)}), &folder) + err := d.request("/folders/update", http.MethodPost, setBody(map[string]interface{}{"folderId": srcObj.GetID(), "parentId": handleFolderId(dstDir)}), &folder) return mFolder2Object(folder), err } else { var file MFile - err := d.request("/files/update", "POST", setBody(map[string]interface{}{"fileId": srcObj.GetID(), "folderId": handleFolderId(dstDir)}), &file) + err := d.request("/files/update", http.MethodPost, setBody(map[string]interface{}{"fileId": srcObj.GetID(), "folderId": handleFolderId(dstDir)}), &file) return mFile2Object(file), err } } @@ -139,11 +140,11 @@ func (d *Misskey) move(srcObj, dstDir model.Obj) (model.Obj, error) { func (d *Misskey) rename(srcObj model.Obj, newName string) (model.Obj, error) { if srcObj.IsDir() { var folder MFolder - err := d.request("/folders/update", "POST", setBody(map[string]string{"folderId": srcObj.GetID(), "name": newName}), &folder) + err := d.request("/folders/update", http.MethodPost, setBody(map[string]string{"folderId": srcObj.GetID(), "name": newName}), &folder) return mFolder2Object(folder), err } else { var file MFile - err := d.request("/files/update", "POST", setBody(map[string]string{"fileId": srcObj.GetID(), "name": newName}), &file) + err := d.request("/files/update", http.MethodPost, setBody(map[string]string{"fileId": srcObj.GetID(), "name": newName}), &file) return mFile2Object(file), err } } @@ -171,7 +172,7 @@ func (d *Misskey) copy(srcObj, dstDir model.Obj) (model.Obj, error) { if err != nil { return nil, err } - err = d.request("/files/upload-from-url", "POST", setBody(map[string]interface{}{"url": url.URL, "folderId": handleFolderId(dstDir)}), &file) + err = d.request("/files/upload-from-url", http.MethodPost, setBody(map[string]interface{}{"url": url.URL, "folderId": handleFolderId(dstDir)}), &file) 
 	if err != nil {
 		return nil, err
 	}
@@ -181,10 +182,10 @@ func (d *Misskey) copy(srcObj, dstDir model.Obj) (model.Obj, error) {
 
 func (d *Misskey) remove(obj model.Obj) error {
 	if obj.IsDir() {
-		err := d.request("/folders/delete", "POST", setBody(map[string]string{"folderId": obj.GetID()}), nil)
+		err := d.request("/folders/delete", http.MethodPost, setBody(map[string]string{"folderId": obj.GetID()}), nil)
 		return err
 	} else {
-		err := d.request("/files/delete", "POST", setBody(map[string]string{"fileId": obj.GetID()}), nil)
+		err := d.request("/files/delete", http.MethodPost, setBody(map[string]string{"fileId": obj.GetID()}), nil)
 		return err
 	}
 }
diff --git a/drivers/onedrive/util.go b/drivers/onedrive/util.go
index be86d50b..672d3c51 100644
--- a/drivers/onedrive/util.go
+++ b/drivers/onedrive/util.go
@@ -1,7 +1,6 @@
 package onedrive
 
 import (
-	"bytes"
 	"context"
 	"errors"
 	"fmt"
@@ -15,6 +14,7 @@ import (
 	"github.com/OpenListTeam/OpenList/v4/internal/errs"
 	"github.com/OpenListTeam/OpenList/v4/internal/model"
 	"github.com/OpenListTeam/OpenList/v4/internal/op"
+	streamPkg "github.com/OpenListTeam/OpenList/v4/internal/stream"
 	"github.com/OpenListTeam/OpenList/v4/pkg/utils"
 	"github.com/avast/retry-go"
 	"github.com/go-resty/resty/v2"
@@ -241,23 +241,25 @@ func (d *Onedrive) upBig(ctx context.Context, dstDir model.Obj, stream model.Fil
 	uploadUrl := jsoniter.Get(res, "uploadUrl").ToString()
 	var finish int64 = 0
 	DEFAULT := d.ChunkSize * 1024 * 1024
+	ss, err := streamPkg.NewStreamSectionReader(stream, int(DEFAULT))
+	if err != nil {
+		return err
+	}
 	for finish < stream.GetSize() {
 		if utils.IsCanceled(ctx) {
 			return ctx.Err()
 		}
 		left := stream.GetSize() - finish
 		byteSize := min(left, DEFAULT)
+		utils.Log.Debugf("[Onedrive] upload range: %d-%d/%d", finish, finish+byteSize-1, stream.GetSize())
+		rd, err := ss.GetSectionReader(finish, byteSize)
+		if err != nil {
+			return err
+		}
 		err = retry.Do(
 			func() error {
-				utils.Log.Debugf("[Onedrive] upload range: %d-%d/%d", finish, finish+byteSize-1, stream.GetSize())
-				byteData := make([]byte, byteSize)
-				n, err := io.ReadFull(stream, byteData)
-				utils.Log.Debug(err, n)
-				if err != nil {
-					return err
-				}
-				req, err := http.NewRequestWithContext(ctx, http.MethodPut, uploadUrl,
-					driver.NewLimitedUploadStream(ctx, bytes.NewReader(byteData)))
+				rd.Seek(0, io.SeekStart)
+				req, err := http.NewRequestWithContext(ctx, http.MethodPut, uploadUrl, driver.NewLimitedUploadStream(ctx, rd))
 				if err != nil {
 					return err
 				}
@@ -283,6 +285,7 @@ func (d *Onedrive) upBig(ctx context.Context, dstDir model.Obj, stream model.Fil
 			retry.DelayType(retry.BackOffDelay),
 			retry.Delay(time.Second),
 		)
+		ss.RecycleSectionReader(rd)
 		if err != nil {
 			return err
 		}
diff --git a/drivers/onedrive_app/util.go b/drivers/onedrive_app/util.go
index 3be82b5a..2aca3688 100644
--- a/drivers/onedrive_app/util.go
+++ b/drivers/onedrive_app/util.go
@@ -1,7 +1,6 @@
 package onedrive_app
 
 import (
-	"bytes"
 	"context"
 	"errors"
 	"fmt"
@@ -15,6 +14,7 @@ import (
 	"github.com/OpenListTeam/OpenList/v4/internal/errs"
 	"github.com/OpenListTeam/OpenList/v4/internal/model"
 	"github.com/OpenListTeam/OpenList/v4/internal/op"
+	streamPkg "github.com/OpenListTeam/OpenList/v4/internal/stream"
 	"github.com/OpenListTeam/OpenList/v4/pkg/utils"
 	"github.com/avast/retry-go"
 	"github.com/go-resty/resty/v2"
@@ -155,23 +155,25 @@ func (d *OnedriveAPP) upBig(ctx context.Context, dstDir model.Obj, stream model.
uploadUrl := jsoniter.Get(res, "uploadUrl").ToString() var finish int64 = 0 DEFAULT := d.ChunkSize * 1024 * 1024 + ss, err := streamPkg.NewStreamSectionReader(stream, int(DEFAULT)) + if err != nil { + return err + } for finish < stream.GetSize() { if utils.IsCanceled(ctx) { return ctx.Err() } left := stream.GetSize() - finish byteSize := min(left, DEFAULT) + utils.Log.Debugf("[OnedriveAPP] upload range: %d-%d/%d", finish, finish+byteSize-1, stream.GetSize()) + rd, err := ss.GetSectionReader(finish, byteSize) + if err != nil { + return err + } err = retry.Do( func() error { - utils.Log.Debugf("[OnedriveAPP] upload range: %d-%d/%d", finish, finish+byteSize-1, stream.GetSize()) - byteData := make([]byte, byteSize) - n, err := io.ReadFull(stream, byteData) - utils.Log.Debug(err, n) - if err != nil { - return err - } - req, err := http.NewRequestWithContext(ctx, http.MethodPut, uploadUrl, - driver.NewLimitedUploadStream(ctx, bytes.NewReader(byteData))) + rd.Seek(0, io.SeekStart) + req, err := http.NewRequestWithContext(ctx, http.MethodPut, uploadUrl, driver.NewLimitedUploadStream(ctx, rd)) if err != nil { return err } @@ -197,6 +199,7 @@ func (d *OnedriveAPP) upBig(ctx context.Context, dstDir model.Obj, stream model. retry.DelayType(retry.BackOffDelay), retry.Delay(time.Second), ) + ss.RecycleSectionReader(rd) if err != nil { return err } diff --git a/drivers/onedrive_sharelink/driver.go b/drivers/onedrive_sharelink/driver.go index 08ed07e0..42d0f190 100644 --- a/drivers/onedrive_sharelink/driver.go +++ b/drivers/onedrive_sharelink/driver.go @@ -38,14 +38,14 @@ func (d *OnedriveSharelink) Init(ctx context.Context) error { d.cron = cron.NewCron(time.Hour * 1) d.cron.Do(func() { var err error - d.Headers, err = d.getHeaders() + d.Headers, err = d.getHeaders(ctx) if err != nil { log.Errorf("%+v", err) } }) // Get initial headers - d.Headers, err = d.getHeaders() + d.Headers, err = d.getHeaders(ctx) if err != nil { return err } @@ -59,7 +59,7 @@ func (d *OnedriveSharelink) Drop(ctx context.Context) error { func (d *OnedriveSharelink) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { path := dir.GetPath() - files, err := d.getFiles(path) + files, err := d.getFiles(ctx, path) if err != nil { return nil, err } @@ -82,7 +82,7 @@ func (d *OnedriveSharelink) Link(ctx context.Context, file model.Obj, args model if d.HeaderTime < time.Now().Unix()-1800 { var err error log.Debug("headers are older than 30 minutes, get new headers") - header, err = d.getHeaders() + header, err = d.getHeaders(ctx) if err != nil { return nil, err } diff --git a/drivers/onedrive_sharelink/util.go b/drivers/onedrive_sharelink/util.go index 9f3480fb..d4cd4229 100644 --- a/drivers/onedrive_sharelink/util.go +++ b/drivers/onedrive_sharelink/util.go @@ -1,6 +1,7 @@ package onedrive_sharelink import ( + "context" "crypto/tls" "encoding/json" "fmt" @@ -131,7 +132,7 @@ func getAttrValue(n *html.Node, key string) string { } // getHeaders constructs and returns the necessary HTTP headers for accessing the OneDrive share link -func (d *OnedriveSharelink) getHeaders() (http.Header, error) { +func (d *OnedriveSharelink) getHeaders(ctx context.Context) (http.Header, error) { header := http.Header{} header.Set("User-Agent", base.UserAgent) header.Set("accept-language", "zh-CN,zh;q=0.9,en;q=0.8,en-GB;q=0.7,en-US;q=0.6") @@ -142,7 +143,7 @@ func (d *OnedriveSharelink) getHeaders() (http.Header, error) { if d.ShareLinkPassword == "" { // Create a no-redirect client clientNoDirect := NewNoRedirectCLient() - req, 
err := http.NewRequest("GET", d.ShareLinkURL, nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, d.ShareLinkURL, nil) if err != nil { return nil, err } @@ -180,9 +181,9 @@ func (d *OnedriveSharelink) getHeaders() (http.Header, error) { } // getFiles retrieves the files from the OneDrive share link at the specified path -func (d *OnedriveSharelink) getFiles(path string) ([]Item, error) { +func (d *OnedriveSharelink) getFiles(ctx context.Context, path string) ([]Item, error) { clientNoDirect := NewNoRedirectCLient() - req, err := http.NewRequest("GET", d.ShareLinkURL, nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, d.ShareLinkURL, nil) if err != nil { return nil, err } @@ -221,11 +222,11 @@ func (d *OnedriveSharelink) getFiles(path string) ([]Item, error) { // Get redirectUrl answer, err := clientNoDirect.Do(req) if err != nil { - d.Headers, err = d.getHeaders() + d.Headers, err = d.getHeaders(ctx) if err != nil { return nil, err } - return d.getFiles(path) + return d.getFiles(ctx, path) } defer answer.Body.Close() re := regexp.MustCompile(`templateUrl":"(.*?)"`) @@ -290,7 +291,7 @@ func (d *OnedriveSharelink) getFiles(path string) ([]Item, error) { client := &http.Client{} postUrl := strings.Join(redirectSplitURL[:len(redirectSplitURL)-3], "/") + "/_api/v2.1/graphql" - req, err = http.NewRequest("POST", postUrl, strings.NewReader(graphqlVar)) + req, err = http.NewRequest(http.MethodPost, postUrl, strings.NewReader(graphqlVar)) if err != nil { return nil, err } @@ -298,11 +299,11 @@ func (d *OnedriveSharelink) getFiles(path string) ([]Item, error) { resp, err := client.Do(req) if err != nil { - d.Headers, err = d.getHeaders() + d.Headers, err = d.getHeaders(ctx) if err != nil { return nil, err } - return d.getFiles(path) + return d.getFiles(ctx, path) } defer resp.Body.Close() var graphqlReq GraphQLRequest @@ -323,31 +324,31 @@ func (d *OnedriveSharelink) getFiles(path string) ([]Item, error) { graphqlReqNEW := GraphQLNEWRequest{} postUrl = strings.Join(redirectSplitURL[:len(redirectSplitURL)-3], "/") + "/_api/web/GetListUsingPath(DecodedUrl=@a1)/RenderListDataAsStream" + nextHref - req, _ = http.NewRequest("POST", postUrl, strings.NewReader(renderListDataAsStreamVar)) + req, _ = http.NewRequest(http.MethodPost, postUrl, strings.NewReader(renderListDataAsStreamVar)) req.Header = tempHeader resp, err := client.Do(req) if err != nil { - d.Headers, err = d.getHeaders() + d.Headers, err = d.getHeaders(ctx) if err != nil { return nil, err } - return d.getFiles(path) + return d.getFiles(ctx, path) } defer resp.Body.Close() json.NewDecoder(resp.Body).Decode(&graphqlReqNEW) for graphqlReqNEW.ListData.NextHref != "" { graphqlReqNEW = GraphQLNEWRequest{} postUrl = strings.Join(redirectSplitURL[:len(redirectSplitURL)-3], "/") + "/_api/web/GetListUsingPath(DecodedUrl=@a1)/RenderListDataAsStream" + nextHref - req, _ = http.NewRequest("POST", postUrl, strings.NewReader(renderListDataAsStreamVar)) + req, _ = http.NewRequest(http.MethodPost, postUrl, strings.NewReader(renderListDataAsStreamVar)) req.Header = tempHeader resp, err := client.Do(req) if err != nil { - d.Headers, err = d.getHeaders() + d.Headers, err = d.getHeaders(ctx) if err != nil { return nil, err } - return d.getFiles(path) + return d.getFiles(ctx, path) } defer resp.Body.Close() json.NewDecoder(resp.Body).Decode(&graphqlReqNEW) diff --git a/drivers/quark_uc_tv/driver.go b/drivers/quark_uc_tv/driver.go index f611fc98..038f768f 100644 --- a/drivers/quark_uc_tv/driver.go +++ b/drivers/quark_uc_tv/driver.go 
@@ -3,6 +3,7 @@ package quark_uc_tv import ( "context" "fmt" + "net/http" "strconv" "time" @@ -96,7 +97,7 @@ func (d *QuarkUCTV) List(ctx context.Context, dir model.Obj, args model.ListArgs pageSize := int64(100) for { var filesData FilesData - _, err := d.request(ctx, "/file", "GET", func(req *resty.Request) { + _, err := d.request(ctx, "/file", http.MethodGet, func(req *resty.Request) { req.SetQueryParams(map[string]string{ "method": "list", "parent_fid": dir.GetID(), diff --git a/drivers/quark_uc_tv/util.go b/drivers/quark_uc_tv/util.go index 92e28ba2..f513ef11 100644 --- a/drivers/quark_uc_tv/util.go +++ b/drivers/quark_uc_tv/util.go @@ -95,7 +95,7 @@ func (d *QuarkUCTV) getLoginCode(ctx context.Context) (string, error) { QrData string `json:"qr_data"` QueryToken string `json:"query_token"` } - _, err := d.request(ctx, pathname, "GET", func(req *resty.Request) { + _, err := d.request(ctx, pathname, http.MethodGet, func(req *resty.Request) { req.SetQueryParams(map[string]string{ "auth_type": "code", "client_id": d.conf.clientID, @@ -123,7 +123,7 @@ func (d *QuarkUCTV) getCode(ctx context.Context) (string, error) { CommonRsp Code string `json:"code"` } - _, err := d.request(ctx, pathname, "GET", func(req *resty.Request) { + _, err := d.request(ctx, pathname, http.MethodGet, func(req *resty.Request) { req.SetQueryParams(map[string]string{ "client_id": d.conf.clientID, "scope": "netdisk", @@ -138,7 +138,7 @@ func (d *QuarkUCTV) getCode(ctx context.Context) (string, error) { func (d *QuarkUCTV) getRefreshTokenByTV(ctx context.Context, code string, isRefresh bool) error { pathname := "/token" - _, _, reqID := d.generateReqSign("POST", pathname, d.conf.signKey) + _, _, reqID := d.generateReqSign(http.MethodPost, pathname, d.conf.signKey) u := d.conf.codeApi + pathname var resp RefreshTokenAuthResp body := map[string]string{ diff --git a/drivers/s3/doge.go b/drivers/s3/doge.go index 12a584ca..625c2f27 100644 --- a/drivers/s3/doge.go +++ b/drivers/s3/doge.go @@ -38,7 +38,7 @@ func getCredentials(AccessKey, SecretKey string) (rst Credentials, err error) { sign := hex.EncodeToString(hmacObj.Sum(nil)) Authorization := "TOKEN " + AccessKey + ":" + sign - req, err := http.NewRequest("POST", "https://api.dogecloud.com"+apiPath, strings.NewReader(string(reqBody))) + req, err := http.NewRequest(http.MethodPost, "https://api.dogecloud.com"+apiPath, strings.NewReader(string(reqBody))) if err != nil { return rst, err } diff --git a/drivers/sftp/util.go b/drivers/sftp/util.go index 5c47c532..293df8fa 100644 --- a/drivers/sftp/util.go +++ b/drivers/sftp/util.go @@ -13,8 +13,8 @@ import ( // do others that not defined in Driver interface func (d *SFTP) initClient() error { - err, _, _ := singleflight.ErrorGroup.Do(fmt.Sprintf("SFTP.initClient:%p", d), func() (error, error) { - return d._initClient(), nil + _, err, _ := singleflight.AnyGroup.Do(fmt.Sprintf("SFTP.initClient:%p", d), func() (any, error) { + return nil, d._initClient() }) return err } diff --git a/drivers/smb/util.go b/drivers/smb/util.go index 166e2ae3..3e40f813 100644 --- a/drivers/smb/util.go +++ b/drivers/smb/util.go @@ -28,8 +28,8 @@ func (d *SMB) getLastConnTime() time.Time { } func (d *SMB) initFS() error { - err, _, _ := singleflight.ErrorGroup.Do(fmt.Sprintf("SMB.initFS:%p", d), func() (error, error) { - return d._initFS(), nil + _, err, _ := singleflight.AnyGroup.Do(fmt.Sprintf("SMB.initFS:%p", d), func() (any, error) { + return nil, d._initFS() }) return err } diff --git a/drivers/webdav/odrvcookie/fetch.go 
b/drivers/webdav/odrvcookie/fetch.go index a6e71a56..b4eca077 100644 --- a/drivers/webdav/odrvcookie/fetch.go +++ b/drivers/webdav/odrvcookie/fetch.go @@ -181,7 +181,7 @@ func (ca *CookieAuth) getSPToken() (*SuccessResponse, error) { // Execute the first request which gives us an auth token for the sharepoint service // With this token we can authenticate on the login page and save the returned cookies - req, err := http.NewRequest("POST", loginUrl, buf) + req, err := http.NewRequest(http.MethodPost, loginUrl, buf) if err != nil { return nil, err } diff --git a/internal/bootstrap/config.go b/internal/bootstrap/config.go index db8dcf28..3980deb8 100644 --- a/internal/bootstrap/config.go +++ b/internal/bootstrap/config.go @@ -12,6 +12,7 @@ import ( "github.com/OpenListTeam/OpenList/v4/internal/net" "github.com/OpenListTeam/OpenList/v4/pkg/utils" "github.com/caarlos0/env/v9" + "github.com/shirou/gopsutil/v4/mem" log "github.com/sirupsen/logrus" ) @@ -79,6 +80,18 @@ func InitConfig() { if conf.Conf.MaxConcurrency > 0 { net.DefaultConcurrencyLimit = &net.ConcurrencyLimit{Limit: conf.Conf.MaxConcurrency} } + if conf.Conf.MaxBufferLimit < 0 { + m, _ := mem.VirtualMemory() + if m != nil { + conf.MaxBufferLimit = max(int(float64(m.Total)*0.05), 4*utils.MB) + conf.MaxBufferLimit -= conf.MaxBufferLimit % utils.MB + } else { + conf.MaxBufferLimit = 16 * utils.MB + } + } else { + conf.MaxBufferLimit = conf.Conf.MaxBufferLimit * utils.MB + } + log.Infof("max buffer limit: %d", conf.MaxBufferLimit) if !conf.Conf.Force { confFromEnv() } diff --git a/internal/conf/config.go b/internal/conf/config.go index 9c1cd52a..72a8ee72 100644 --- a/internal/conf/config.go +++ b/internal/conf/config.go @@ -119,6 +119,7 @@ type Config struct { DistDir string `json:"dist_dir"` Log LogConfig `json:"log" envPrefix:"LOG_"` DelayedStart int `json:"delayed_start" env:"DELAYED_START"` + MaxBufferLimit int `json:"max_buffer_limitMB" env:"MAX_BUFFER_LIMIT_MB"` MaxConnections int `json:"max_connections" env:"MAX_CONNECTIONS"` MaxConcurrency int `json:"max_concurrency" env:"MAX_CONCURRENCY"` TlsInsecureSkipVerify bool `json:"tls_insecure_skip_verify" env:"TLS_INSECURE_SKIP_VERIFY"` @@ -174,6 +175,7 @@ func DefaultConfig(dataDir string) *Config { }, }, }, + MaxBufferLimit: -1, MaxConnections: 0, MaxConcurrency: 64, TlsInsecureSkipVerify: true, diff --git a/internal/conf/var.go b/internal/conf/var.go index 8af45ca3..50a7f33d 100644 --- a/internal/conf/var.go +++ b/internal/conf/var.go @@ -25,6 +25,7 @@ var PrivacyReg []*regexp.Regexp var ( // StoragesLoaded loaded success if empty StoragesLoaded = false + MaxBufferLimit int ) var ( RawIndexHtml string diff --git a/internal/net/request.go b/internal/net/request.go index 9b634ee5..5e98f490 100644 --- a/internal/net/request.go +++ b/internal/net/request.go @@ -12,6 +12,7 @@ import ( "sync" "time" + "github.com/OpenListTeam/OpenList/v4/internal/conf" "github.com/OpenListTeam/OpenList/v4/internal/model" "github.com/OpenListTeam/OpenList/v4/pkg/utils" @@ -22,7 +23,7 @@ import ( // DefaultDownloadPartSize is the default range of bytes to get at a time when // using Download(). -const DefaultDownloadPartSize = utils.MB * 10 +const DefaultDownloadPartSize = utils.MB * 8 // DefaultDownloadConcurrency is the default number of goroutines to spin up // when using Download(). 
@@ -84,6 +85,9 @@ func (d Downloader) Download(ctx context.Context, p *HttpRequestParams) (readClo
 	if impl.cfg.PartSize == 0 {
 		impl.cfg.PartSize = DefaultDownloadPartSize
 	}
+	if conf.MaxBufferLimit > 0 && impl.cfg.PartSize > conf.MaxBufferLimit {
+		impl.cfg.PartSize = conf.MaxBufferLimit
+	}
 	if impl.cfg.HttpClient == nil {
 		impl.cfg.HttpClient = DefaultHttpRequestFunc
 	}
@@ -159,17 +163,13 @@ func (d *downloader) download() (io.ReadCloser, error) {
 		return nil, err
 	}
 
-	maxPart := int(d.params.Range.Length / int64(d.cfg.PartSize))
-	if d.params.Range.Length%int64(d.cfg.PartSize) > 0 {
-		maxPart++
+	maxPart := 1
+	if d.params.Range.Length > int64(d.cfg.PartSize) {
+		maxPart = int((d.params.Range.Length + int64(d.cfg.PartSize) - 1) / int64(d.cfg.PartSize))
 	}
 	if maxPart < d.cfg.Concurrency {
 		d.cfg.Concurrency = maxPart
 	}
-	if d.params.Range.Length == 0 {
-		d.cfg.Concurrency = 1
-	}
-
 	log.Debugf("cfgConcurrency:%d", d.cfg.Concurrency)
 
 	if maxPart == 1 {
diff --git a/internal/stream/stream.go b/internal/stream/stream.go
index 7b34d18c..387bf036 100644
--- a/internal/stream/stream.go
+++ b/internal/stream/stream.go
@@ -9,6 +9,7 @@ import (
 	"math"
 	"os"
 
+	"github.com/OpenListTeam/OpenList/v4/internal/conf"
 	"github.com/OpenListTeam/OpenList/v4/internal/errs"
 	"github.com/OpenListTeam/OpenList/v4/internal/model"
 	"github.com/OpenListTeam/OpenList/v4/pkg/http_range"
@@ -104,11 +105,8 @@ func (f *FileStream) GetFile() model.File {
 	return nil
 }
 
-const InMemoryBufMaxSize = 10 // Megabytes
-const InMemoryBufMaxSizeBytes = InMemoryBufMaxSize * 1024 * 1024
-
 // RangeRead have to cache all data first since only Reader is provided.
-// also support a peeking RangeRead at very start, but won't buffer more than 10MB data in memory
+// also supports a peeking RangeRead at the very start, but won't buffer more than conf.MaxBufferLimit bytes in memory
 func (f *FileStream) RangeRead(httpRange http_range.Range) (io.Reader, error) {
 	if httpRange.Length < 0 || httpRange.Start+httpRange.Length > f.GetSize() {
 		httpRange.Length = f.GetSize() - httpRange.Start
@@ -122,7 +120,7 @@ func (f *FileStream) RangeRead(httpRange http_range.Range) (io.Reader, error) {
 	if f.peekBuff != nil && size <= int64(f.peekBuff.Len()) {
 		return io.NewSectionReader(f.peekBuff, httpRange.Start, httpRange.Length), nil
 	}
-	if size <= InMemoryBufMaxSizeBytes {
+	if size <= int64(conf.MaxBufferLimit) {
 		bufSize := min(size, f.GetSize())
 		// 使用bytes.Buffer作为io.CopyBuffer的写入对象,CopyBuffer会调用Buffer.ReadFrom
 		// 即使被写入的数据量与Buffer.Cap一致,Buffer也会扩大
diff --git a/internal/stream/util.go b/internal/stream/util.go
index aee5c603..77b23802 100644
--- a/internal/stream/util.go
+++ b/internal/stream/util.go
@@ -1,12 +1,14 @@
 package stream
 
 import (
+	"bytes"
 	"context"
 	"encoding/hex"
 	"errors"
 	"fmt"
 	"io"
 	"net/http"
+	"sync"
 
 	"github.com/OpenListTeam/OpenList/v4/internal/conf"
 	"github.com/OpenListTeam/OpenList/v4/internal/model"
@@ -187,3 +189,68 @@ func CacheFullInTempFileAndHash(stream model.FileStreamer, up model.UpdateProgre
 	}
 	return tmpF, hex.EncodeToString(h.Sum(nil)), err
 }
+
+type StreamSectionReader struct {
+	file    model.FileStreamer
+	off     int64
+	bufPool *sync.Pool
+}
+
+func NewStreamSectionReader(file model.FileStreamer, maxBufferSize int) (*StreamSectionReader, error) {
+	ss := &StreamSectionReader{file: file}
+	if file.GetFile() == nil {
+		maxBufferSize = min(maxBufferSize, int(file.GetSize()))
+		if maxBufferSize > conf.MaxBufferLimit {
+			_, err := file.CacheFullInTempFile()
+			if err != nil {
+				return nil, err
+			}
+		} else {
+			ss.bufPool = &sync.Pool{
+				New: func() any {
+					return make([]byte, maxBufferSize)
+				},
+			}
+		}
+	}
+	return ss, nil
+}
+
+// Not thread-safe: sections of an uncached stream must be requested in sequential offset order.
+func (ss *StreamSectionReader) GetSectionReader(off, length int64) (*SectionReader, error) {
+	var cache io.ReaderAt = ss.file.GetFile()
+	var buf []byte
+	if cache == nil {
+		if off != ss.off {
+			return nil, fmt.Errorf("stream not cached: request offset %d != current offset %d", off, ss.off)
+		}
+		tempBuf := ss.bufPool.Get().([]byte)
+		buf = tempBuf[:length]
+		n, err := io.ReadFull(ss.file, buf)
+		if err != nil {
+			return nil, err
+		}
+		if int64(n) != length {
+			return nil, fmt.Errorf("can't read data, expected=%d, got=%d", length, n)
+		}
+		ss.off += int64(n)
+		off = 0
+		cache = bytes.NewReader(buf)
+	}
+	return &SectionReader{io.NewSectionReader(cache, off, length), buf}, nil
+}
+
+func (ss *StreamSectionReader) RecycleSectionReader(sr *SectionReader) {
+	if sr != nil {
+		if sr.buf != nil {
+			ss.bufPool.Put(sr.buf[0:cap(sr.buf)])
+			sr.buf = nil
+		}
+		sr.ReadSeeker = nil
+	}
+}
+
+type SectionReader struct {
+	io.ReadSeeker
+	buf []byte
+}
diff --git a/pkg/errgroup/errgroup.go b/pkg/errgroup/errgroup.go
index 858df044..daf1b315 100644
--- a/pkg/errgroup/errgroup.go
+++ b/pkg/errgroup/errgroup.go
@@ -19,6 +19,8 @@ type Group struct {
 
 	wg  sync.WaitGroup
 	sem chan token
+
+	startChan chan token
 }
 
 func NewGroupWithContext(ctx context.Context, limit int, retryOpts ...retry.Option) (*Group, context.Context) {
@@ -26,6 +28,13 @@ func NewGroupWithContext(ctx context.Context, limit int, retryOpts ...retry.Opti
 	return (&Group{cancel: cancel, ctx: ctx, opts: append(retryOpts, retry.Context(ctx))}).SetLimit(limit), ctx
 }
 
+// NewOrderedGroupWithContext returns a Group whose Lifecycle.Before callbacks run one at a time in submission order.
+func NewOrderedGroupWithContext(ctx context.Context, limit int, retryOpts ...retry.Option) (*Group, context.Context) {
+	group, ctx := NewGroupWithContext(ctx, limit, retryOpts...)
+	group.startChan = make(chan token, 1)
+	return group, ctx
+}
+
 func (g *Group) done() {
 	if g.sem != nil {
 		<-g.sem
@@ -39,18 +48,62 @@ func (g *Group) Wait() error {
 	return context.Cause(g.ctx)
 }
 
-func (g *Group) Go(f func(ctx context.Context) error) {
+func (g *Group) Go(do func(ctx context.Context) error) {
+	g.GoWithLifecycle(Lifecycle{Do: do})
+}
+
+type Lifecycle struct {
+	// In an OrderedGroup, Before is thread-safe: calls run one at a time in submission order.
+	Before func(ctx context.Context) error
+	// Do is not called if Before returns an error.
+	Do func(ctx context.Context) error
+	// After is called last, with the resulting error.
+	After func(err error)
+}
+
+func (g *Group) GoWithLifecycle(lifecycle Lifecycle) {
+	if g.startChan != nil {
+		select {
+		case <-g.ctx.Done():
+			return
+		case g.startChan <- token{}:
+		}
+	}
+
 	if g.sem != nil {
-		g.sem <- token{}
+		select {
+		case <-g.ctx.Done():
+			return
+		case g.sem <- token{}:
+		}
 	}
 	g.wg.Add(1)
 	go func() {
 		defer g.done()
-		if err := retry.Do(func() error { return f(g.ctx) }, g.opts...); err != nil {
-			g.cancel(err)
+		var err error
+		if lifecycle.Before != nil {
+			err = lifecycle.Before(g.ctx)
+		}
+		if err == nil {
+			if g.startChan != nil {
+				<-g.startChan
+			}
+			err = retry.Do(func() error { return lifecycle.Do(g.ctx) }, g.opts...)
+ } + if lifecycle.After != nil { + lifecycle.After(err) + } + if err != nil { + select { + case <-g.ctx.Done(): + return + default: + g.cancel(err) + } } }() + } func (g *Group) TryGo(f func(ctx context.Context) error) bool { diff --git a/pkg/gowebdav/client.go b/pkg/gowebdav/client.go index e23a5a25..7251e084 100644 --- a/pkg/gowebdav/client.go +++ b/pkg/gowebdav/client.go @@ -383,7 +383,7 @@ func (c *Client) Link(path string) (string, http.Header, error) { // ReadStream reads the stream for a given path func (c *Client) ReadStream(path string, callback func(rq *http.Request)) (io.ReadCloser, http.Header, error) { - rs, err := c.req("GET", path, nil, callback) + rs, err := c.req(http.MethodGet, path, nil, callback) if err != nil { return nil, nil, newPathErrorErr("ReadStream", path, err) } @@ -405,7 +405,7 @@ func (c *Client) ReadStream(path string, callback func(rq *http.Request)) (io.Re // this function will emulate the behavior by skipping `offset` bytes and limiting the result // to `length`. func (c *Client) ReadStreamRange(path string, offset, length int64) (io.ReadCloser, error) { - rs, err := c.req("GET", path, nil, func(r *http.Request) { + rs, err := c.req(http.MethodGet, path, nil, func(r *http.Request) { r.Header.Add("Range", fmt.Sprintf("bytes=%v-%v", offset, offset+length-1)) }) if err != nil { diff --git a/pkg/qbittorrent/client.go b/pkg/qbittorrent/client.go index f9047ea8..102b445d 100644 --- a/pkg/qbittorrent/client.go +++ b/pkg/qbittorrent/client.go @@ -114,7 +114,7 @@ func (c *client) post(path string, data url.Values) (*http.Response, error) { u := c.url.JoinPath(path) u.User = nil // remove userinfo for requests - req, err := http.NewRequest("POST", u.String(), bytes.NewReader([]byte(data.Encode()))) + req, err := http.NewRequest(http.MethodPost, u.String(), bytes.NewReader([]byte(data.Encode()))) if err != nil { return nil, err } @@ -162,7 +162,7 @@ func (c *client) AddFromLink(link string, savePath string, id string) error { u := c.url.JoinPath("/api/v2/torrents/add") u.User = nil // remove userinfo for requests - req, err := http.NewRequest("POST", u.String(), buf) + req, err := http.NewRequest(http.MethodPost, u.String(), buf) if err != nil { return err } diff --git a/pkg/singleflight/var.go b/pkg/singleflight/var.go index 41c97a2e..a92288d1 100644 --- a/pkg/singleflight/var.go +++ b/pkg/singleflight/var.go @@ -1,3 +1,3 @@ package singleflight -var ErrorGroup Group[error] +var AnyGroup Group[any]
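
How the new pieces fit together (a minimal sketch, not a definitive implementation): stream.StreamSectionReader hands out sequential sections of an upload stream, reusing one pooled buffer per in-flight chunk, while errgroup.NewOrderedGroupWithContext guarantees that the Lifecycle.Before callbacks that fetch those sections run in submission order. The names uploadChunked, uploadPart, threads, and the example package below are hypothetical illustrations; everything else is the API added by this patch.

// Package example is illustrative only; uploadChunked and uploadPart are
// hypothetical stand-ins for a driver-specific chunked upload.
package example

import (
	"context"
	"io"
	"time"

	"github.com/OpenListTeam/OpenList/v4/internal/driver"
	"github.com/OpenListTeam/OpenList/v4/internal/model"
	"github.com/OpenListTeam/OpenList/v4/internal/stream"
	"github.com/OpenListTeam/OpenList/v4/pkg/errgroup"
	"github.com/OpenListTeam/OpenList/v4/pkg/utils"
	"github.com/avast/retry-go"
)

func uploadChunked(ctx context.Context, file model.FileStreamer, chunkSize int64, threads int,
	uploadPart func(ctx context.Context, rd io.Reader, size int64, part int) error) error {
	size := file.GetSize()
	partNum := (size + chunkSize - 1) / chunkSize
	// For a stream without a backing file this pools one chunk-sized buffer
	// per in-flight section, or falls back to a temp file when chunkSize
	// exceeds conf.MaxBufferLimit.
	ss, err := stream.NewStreamSectionReader(file, int(chunkSize))
	if err != nil {
		return err
	}
	threadG, uploadCtx := errgroup.NewOrderedGroupWithContext(ctx, min(int(partNum), threads),
		retry.Attempts(3), retry.Delay(time.Second), retry.DelayType(retry.BackOffDelay))
	for i := int64(1); i <= partNum; i++ {
		if utils.IsCanceled(uploadCtx) {
			break
		}
		part := int(i)
		offset := (i - 1) * chunkSize
		length := min(chunkSize, size-offset)
		var rd *stream.SectionReader
		var rateLimitedRd io.Reader
		threadG.GoWithLifecycle(errgroup.Lifecycle{
			// Before callbacks run serially in submission order, so the
			// sequential reads GetSectionReader performs on an uncached
			// stream stay aligned with its internal offset.
			Before: func(ctx context.Context) error {
				var err error
				rd, err = ss.GetSectionReader(offset, length)
				if err != nil {
					return err
				}
				rateLimitedRd = driver.NewLimitedUploadStream(ctx, rd)
				return nil
			},
			// Do may be retried; rewind the section before every attempt.
			Do: func(ctx context.Context) error {
				if _, err := rd.Seek(0, io.SeekStart); err != nil {
					return err
				}
				return uploadPart(ctx, rateLimitedRd, length, part)
			},
			// After runs last either way; RecycleSectionReader tolerates a
			// nil reader and returns the pooled buffer for reuse.
			After: func(err error) {
				ss.RecycleSectionReader(rd)
			},
		})
	}
	return threadG.Wait()
}

Recycling in After rather than in Do keeps the pooled buffer alive across retries, which is why the driver hunks above can simply seek back to the start of the section before each attempt.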