From 57fceabcf4c172290866942a73f544c626702548 Mon Sep 17 00:00:00 2001 From: j2rong4cn <36783515+j2rong4cn@users.noreply.github.com> Date: Mon, 11 Aug 2025 23:41:22 +0800 Subject: [PATCH] perf(stream): improve file stream range reading and caching mechanism (#1001) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * perf(stream): improve file stream range reading and caching mechanism * 。 * add bytes_test.go * fix(stream): handle EOF and buffer reading more gracefully * add comments * refactor: update CacheFullAndWriter to accept pointer for UpdateProgress * update tests * Update drivers/google_drive/util.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: j2rong4cn <36783515+j2rong4cn@users.noreply.github.com> * clone Link more elegantly * fix: stream already cached but could not be re-read * rename the Bytes type to Reader * fix stack overflow * update tests --------- Signed-off-by: j2rong4cn <36783515+j2rong4cn@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- drivers/115/driver.go | 4 +- drivers/115/util.go | 2 +- drivers/115_open/driver.go | 4 +- drivers/115_open/upload.go | 11 +- drivers/123/driver.go | 4 +- drivers/123/upload.go | 12 +- drivers/123_open/driver.go | 4 +- drivers/123_open/upload.go | 12 +- drivers/139/driver.go | 4 +- drivers/189_tv/utils.go | 17 +- drivers/189pc/utils.go | 9 +- drivers/alias/driver.go | 72 +++++--- drivers/alias/util.go | 23 +-- drivers/aliyundrive_open/upload.go | 15 +- drivers/cloudreve/util.go | 33 ++-- drivers/cloudreve_v4/util.go | 33 ++-- drivers/crypt/driver.go | 1 - drivers/doubao/util.go | 17 +- drivers/google_drive/driver.go | 2 +- drivers/google_drive/util.go | 10 +- drivers/ilanzou/driver.go | 4 +- drivers/mediatrack/driver.go | 2 +- drivers/mopan/driver.go | 2 +- drivers/netease_music/util.go | 2 +- drivers/onedrive/util.go | 9 +- drivers/onedrive_app/util.go | 9 +- drivers/pikpak/driver.go | 13 +- drivers/pikpak/util.go | 11 +- drivers/quark_open/driver.go | 5 +- drivers/quark_uc/driver.go | 5 +- drivers/strm/driver.go | 5 +- drivers/terabox/driver.go | 2 +- drivers/thunder/driver.go | 4 +- drivers/thunder_browser/driver.go | 4 +- drivers/thunderx/driver.go | 4 +- drivers/weiyun/driver.go | 2 +- internal/bootstrap/config.go | 2 +- internal/conf/var.go | 2 +- internal/fs/archive.go | 18 +- internal/fs/put.go | 2 +- internal/model/obj.go | 12 +- internal/stream/stream.go | 280 +++++++++++++++++------------ internal/stream/stream_test.go | 86 +++++++++ internal/stream/util.go | 55 +----- pkg/buffer/bytes.go | 92 ++++++++++ pkg/buffer/bytes_test.go | 95 ++++++++++ pkg/errgroup/errgroup.go | 5 +- pkg/utils/io.go | 16 +- 48 files changed, 657 insertions(+), 380 deletions(-) create mode 100644 internal/stream/stream_test.go create mode 100644 pkg/buffer/bytes.go create mode 100644 pkg/buffer/bytes_test.go diff --git a/drivers/115/driver.go b/drivers/115/driver.go index 49570984..62ff1de8 100644 --- a/drivers/115/driver.go +++ b/drivers/115/driver.go @@ -186,9 +186,7 @@ func (d *Pan115) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr preHash = strings.ToUpper(preHash) fullHash := stream.GetHash().GetHash(utils.SHA1) if len(fullHash) != utils.SHA1.Width { - cacheFileProgress := model.UpdateProgressWithRange(up, 0, 50) - up = model.UpdateProgressWithRange(up, 50, 100) - _, fullHash, err = streamPkg.CacheFullInTempFileAndHash(stream, cacheFileProgress, utils.SHA1) + _, fullHash, err = streamPkg.CacheFullAndHash(stream, &up, utils.SHA1) if err != nil { return nil, err } diff --git
a/drivers/115/util.go b/drivers/115/util.go index f947365c..b000436b 100644 --- a/drivers/115/util.go +++ b/drivers/115/util.go @@ -321,7 +321,7 @@ func (d *Pan115) UploadByMultipart(ctx context.Context, params *driver115.Upload err error ) - tmpF, err := s.CacheFullInTempFile() + tmpF, err := s.CacheFullAndWriter(&up, nil) if err != nil { return nil, err } diff --git a/drivers/115_open/driver.go b/drivers/115_open/driver.go index 0045ab34..1ded971e 100644 --- a/drivers/115_open/driver.go +++ b/drivers/115_open/driver.go @@ -239,9 +239,7 @@ func (d *Open115) Put(ctx context.Context, dstDir model.Obj, file model.FileStre } sha1 := file.GetHash().GetHash(utils.SHA1) if len(sha1) != utils.SHA1.Width { - cacheFileProgress := model.UpdateProgressWithRange(up, 0, 50) - up = model.UpdateProgressWithRange(up, 50, 100) - _, sha1, err = stream.CacheFullInTempFileAndHash(file, cacheFileProgress, utils.SHA1) + _, sha1, err = stream.CacheFullAndHash(file, &up, utils.SHA1) if err != nil { return err } diff --git a/drivers/115_open/upload.go b/drivers/115_open/upload.go index 9bd1f920..3c847e05 100644 --- a/drivers/115_open/upload.go +++ b/drivers/115_open/upload.go @@ -86,13 +86,14 @@ func (d *Open115) multpartUpload(ctx context.Context, stream model.FileStreamer, fileSize := stream.GetSize() chunkSize := calPartSize(fileSize) - partNum := (stream.GetSize() + chunkSize - 1) / chunkSize - parts := make([]oss.UploadPart, partNum) - offset := int64(0) - ss, err := streamPkg.NewStreamSectionReader(stream, int(chunkSize)) + ss, err := streamPkg.NewStreamSectionReader(stream, int(chunkSize), &up) if err != nil { return err } + + partNum := (stream.GetSize() + chunkSize - 1) / chunkSize + parts := make([]oss.UploadPart, partNum) + offset := int64(0) for i := int64(1); i <= partNum; i++ { if utils.IsCanceled(ctx) { return ctx.Err() @@ -119,7 +120,7 @@ func (d *Open115) multpartUpload(ctx context.Context, stream model.FileStreamer, retry.Attempts(3), retry.DelayType(retry.BackOffDelay), retry.Delay(time.Second)) - ss.RecycleSectionReader(rd) + ss.FreeSectionReader(rd) if err != nil { return err } diff --git a/drivers/123/driver.go b/drivers/123/driver.go index 0c7078a7..6e172f67 100644 --- a/drivers/123/driver.go +++ b/drivers/123/driver.go @@ -182,9 +182,7 @@ func (d *Pan123) Put(ctx context.Context, dstDir model.Obj, file model.FileStrea etag := file.GetHash().GetHash(utils.MD5) var err error if len(etag) < utils.MD5.Width { - cacheFileProgress := model.UpdateProgressWithRange(up, 0, 50) - up = model.UpdateProgressWithRange(up, 50, 100) - _, etag, err = stream.CacheFullInTempFileAndHash(file, cacheFileProgress, utils.MD5) + _, etag, err = stream.CacheFullAndHash(file, &up, utils.MD5) if err != nil { return err } diff --git a/drivers/123/upload.go b/drivers/123/upload.go index e44ce2ee..eb05e094 100644 --- a/drivers/123/upload.go +++ b/drivers/123/upload.go @@ -81,6 +81,12 @@ func (d *Pan123) newUpload(ctx context.Context, upReq *UploadResp, file model.Fi if size > chunkSize { chunkCount = int((size + chunkSize - 1) / chunkSize) } + + ss, err := stream.NewStreamSectionReader(file, int(chunkSize), &up) + if err != nil { + return err + } + lastChunkSize := size % chunkSize if lastChunkSize == 0 { lastChunkSize = chunkSize @@ -92,10 +98,6 @@ func (d *Pan123) newUpload(ctx context.Context, upReq *UploadResp, file model.Fi batchSize = 10 getS3UploadUrl = d.getS3PreSignedUrls } - ss, err := stream.NewStreamSectionReader(file, int(chunkSize)) - if err != nil { - return err - } thread := min(int(chunkCount), 
d.UploadThread) threadG, uploadCtx := errgroup.NewOrderedGroupWithContext(ctx, thread, @@ -180,7 +182,7 @@ func (d *Pan123) newUpload(ctx context.Context, upReq *UploadResp, file model.Fi return nil }, After: func(err error) { - ss.RecycleSectionReader(reader) + ss.FreeSectionReader(reader) }, }) } diff --git a/drivers/123_open/driver.go b/drivers/123_open/driver.go index 547ee810..a71a1aa2 100644 --- a/drivers/123_open/driver.go +++ b/drivers/123_open/driver.go @@ -132,9 +132,7 @@ func (d *Open123) Put(ctx context.Context, dstDir model.Obj, file model.FileStre // etag 文件md5 etag := file.GetHash().GetHash(utils.MD5) if len(etag) < utils.MD5.Width { - cacheFileProgress := model.UpdateProgressWithRange(up, 0, 50) - up = model.UpdateProgressWithRange(up, 50, 100) - _, etag, err = stream.CacheFullInTempFileAndHash(file, cacheFileProgress, utils.MD5) + _, etag, err = stream.CacheFullAndHash(file, &up, utils.MD5) if err != nil { return nil, err } diff --git a/drivers/123_open/upload.go b/drivers/123_open/upload.go index a9966d7b..8cc42012 100644 --- a/drivers/123_open/upload.go +++ b/drivers/123_open/upload.go @@ -46,6 +46,12 @@ func (d *Open123) Upload(ctx context.Context, file model.FileStreamer, createRes uploadDomain := createResp.Data.Servers[0] size := file.GetSize() chunkSize := createResp.Data.SliceSize + + ss, err := stream.NewStreamSectionReader(file, int(chunkSize), &up) + if err != nil { + return err + } + uploadNums := (size + chunkSize - 1) / chunkSize thread := min(int(uploadNums), d.UploadThread) threadG, uploadCtx := errgroup.NewOrderedGroupWithContext(ctx, thread, @@ -53,10 +59,6 @@ func (d *Open123) Upload(ctx context.Context, file model.FileStreamer, createRes retry.Delay(time.Second), retry.DelayType(retry.BackOffDelay)) - ss, err := stream.NewStreamSectionReader(file, int(chunkSize)) - if err != nil { - return err - } for partIndex := range uploadNums { if utils.IsCanceled(uploadCtx) { break @@ -157,7 +159,7 @@ func (d *Open123) Upload(ctx context.Context, file model.FileStreamer, createRes return nil }, After: func(err error) { - ss.RecycleSectionReader(reader) + ss.FreeSectionReader(reader) }, }) } diff --git a/drivers/139/driver.go b/drivers/139/driver.go index 8033eefc..7f3b5c60 100644 --- a/drivers/139/driver.go +++ b/drivers/139/driver.go @@ -522,9 +522,7 @@ func (d *Yun139) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr var err error fullHash := stream.GetHash().GetHash(utils.SHA256) if len(fullHash) != utils.SHA256.Width { - cacheFileProgress := model.UpdateProgressWithRange(up, 0, 50) - up = model.UpdateProgressWithRange(up, 50, 100) - _, fullHash, err = streamPkg.CacheFullInTempFileAndHash(stream, cacheFileProgress, utils.SHA256) + _, fullHash, err = streamPkg.CacheFullAndHash(stream, &up, utils.SHA256) if err != nil { return err } diff --git a/drivers/189_tv/utils.go b/drivers/189_tv/utils.go index 2e11c829..fd4d74df 100644 --- a/drivers/189_tv/utils.go +++ b/drivers/189_tv/utils.go @@ -5,17 +5,19 @@ import ( "encoding/base64" "encoding/xml" "fmt" - "github.com/skip2/go-qrcode" "io" "net/http" "strconv" "strings" "time" + "github.com/skip2/go-qrcode" + "github.com/OpenListTeam/OpenList/v4/drivers/base" "github.com/OpenListTeam/OpenList/v4/internal/driver" "github.com/OpenListTeam/OpenList/v4/internal/model" "github.com/OpenListTeam/OpenList/v4/internal/op" + "github.com/OpenListTeam/OpenList/v4/internal/stream" "github.com/OpenListTeam/OpenList/v4/pkg/utils" "github.com/go-resty/resty/v2" @@ -311,11 +313,14 @@ func (y *Cloud189TV) RapidUpload(ctx 
context.Context, dstDir model.Obj, stream m // 旧版本上传,家庭云不支持覆盖 func (y *Cloud189TV) OldUpload(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress, isFamily bool, overwrite bool) (model.Obj, error) { - tempFile, err := file.CacheFullInTempFile() - if err != nil { - return nil, err + fileMd5 := file.GetHash().GetHash(utils.MD5) + var tempFile = file.GetFile() + var err error + if len(fileMd5) != utils.MD5.Width { + tempFile, fileMd5, err = stream.CacheFullAndHash(file, &up, utils.MD5) + } else if tempFile == nil { + tempFile, err = file.CacheFullAndWriter(&up, nil) } - fileMd5, err := utils.HashFile(utils.MD5, tempFile) if err != nil { return nil, err } @@ -345,7 +350,7 @@ func (y *Cloud189TV) OldUpload(ctx context.Context, dstDir model.Obj, file model header["Edrive-UploadFileId"] = fmt.Sprint(status.UploadFileId) } - _, err := y.put(ctx, status.FileUploadUrl, header, true, io.NopCloser(tempFile), isFamily) + _, err := y.put(ctx, status.FileUploadUrl, header, true, tempFile, isFamily) if err, ok := err.(*RespErr); ok && err.Code != "InputStreamReadError" { return nil, err } diff --git a/drivers/189pc/utils.go b/drivers/189pc/utils.go index e38f636a..c791e755 100644 --- a/drivers/189pc/utils.go +++ b/drivers/189pc/utils.go @@ -500,7 +500,8 @@ func (y *Cloud189PC) StreamUpload(ctx context.Context, dstDir model.Obj, file mo if err != nil { return nil, err } - ss, err := stream.NewStreamSectionReader(file, int(sliceSize)) + + ss, err := stream.NewStreamSectionReader(file, int(sliceSize), &up) if err != nil { return nil, err } @@ -581,7 +582,7 @@ func (y *Cloud189PC) StreamUpload(ctx context.Context, dstDir model.Obj, file mo return nil }, After: func(err error) { - ss.RecycleSectionReader(reader) + ss.FreeSectionReader(reader) }, }, ) @@ -857,9 +858,7 @@ func (y *Cloud189PC) GetMultiUploadUrls(ctx context.Context, isFamily bool, uplo // 旧版本上传,家庭云不支持覆盖 func (y *Cloud189PC) OldUpload(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress, isFamily bool, overwrite bool) (model.Obj, error) { - cacheFileProgress := model.UpdateProgressWithRange(up, 0, 50) - up = model.UpdateProgressWithRange(up, 50, 100) - tempFile, fileMd5, err := stream.CacheFullInTempFileAndHash(file, cacheFileProgress, utils.MD5) + tempFile, fileMd5, err := stream.CacheFullAndHash(file, &up, utils.MD5) if err != nil { return nil, err } diff --git a/drivers/alias/driver.go b/drivers/alias/driver.go index 284cdc40..6954f2b5 100644 --- a/drivers/alias/driver.go +++ b/drivers/alias/driver.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "io" + "net/url" stdpath "path" "strings" @@ -12,6 +13,7 @@ import ( "github.com/OpenListTeam/OpenList/v4/internal/errs" "github.com/OpenListTeam/OpenList/v4/internal/fs" "github.com/OpenListTeam/OpenList/v4/internal/model" + "github.com/OpenListTeam/OpenList/v4/internal/op" "github.com/OpenListTeam/OpenList/v4/internal/sign" "github.com/OpenListTeam/OpenList/v4/internal/stream" "github.com/OpenListTeam/OpenList/v4/pkg/utils" @@ -160,25 +162,18 @@ func (d *Alias) Link(ctx context.Context, file model.Obj, args model.LinkArgs) ( sign.Sign(reqPath)), }, nil } + + resultLink := *link + resultLink.SyncClosers = utils.NewSyncClosers(link) if args.Redirect { - return link, nil + return &resultLink, nil } - resultLink := &model.Link{ - URL: link.URL, - Header: link.Header, - RangeReader: link.RangeReader, - MFile: link.MFile, - Concurrency: link.Concurrency, - PartSize: link.PartSize, - ContentLength: link.ContentLength, - SyncClosers: 
utils.NewSyncClosers(link), - } if resultLink.ContentLength == 0 { resultLink.ContentLength = fi.GetSize() } if resultLink.MFile != nil { - return resultLink, nil + return &resultLink, nil } if d.DownloadConcurrency > 0 { resultLink.Concurrency = d.DownloadConcurrency @@ -186,7 +181,7 @@ func (d *Alias) Link(ctx context.Context, file model.Obj, args model.LinkArgs) ( if d.DownloadPartSize > 0 { resultLink.PartSize = d.DownloadPartSize * utils.KB } - return resultLink, nil + return &resultLink, nil } return nil, errs.ObjectNotFound } @@ -313,24 +308,29 @@ func (d *Alias) Put(ctx context.Context, dstDir model.Obj, s model.FileStreamer, reqPath, err := d.getReqPath(ctx, dstDir, true) if err == nil { if len(reqPath) == 1 { - return fs.PutDirectly(ctx, *reqPath[0], &stream.FileStream{ - Obj: s, - Mimetype: s.GetMimetype(), - WebPutAsTask: s.NeedStore(), - Reader: s, - }) - } else { - file, err := s.CacheFullInTempFile() + storage, reqActualPath, err := op.GetStorageAndActualPath(*reqPath[0]) if err != nil { return err } - for _, path := range reqPath { + return op.Put(ctx, storage, reqActualPath, &stream.FileStream{ + Obj: s, + Mimetype: s.GetMimetype(), + Reader: s, + }, up) + } else { + file, err := s.CacheFullAndWriter(nil, nil) + if err != nil { + return err + } + count := float64(len(reqPath) + 1) + up(100 / count) + for i, path := range reqPath { err = errors.Join(err, fs.PutDirectly(ctx, *path, &stream.FileStream{ - Obj: s, - Mimetype: s.GetMimetype(), - WebPutAsTask: s.NeedStore(), - Reader: file, + Obj: s, + Mimetype: s.GetMimetype(), + Reader: file, })) + up(float64(i+2) / float64(count) * 100) _, e := file.Seek(0, io.SeekStart) if e != nil { return errors.Join(err, e) @@ -402,10 +402,24 @@ func (d *Alias) Extract(ctx context.Context, obj model.Obj, args model.ArchiveIn return nil, errs.ObjectNotFound } for _, dst := range dsts { - link, err := d.extract(ctx, dst, sub, args) - if err == nil { - return link, nil + reqPath := stdpath.Join(dst, sub) + link, err := d.extract(ctx, reqPath, args) + if err != nil { + continue } + if link == nil { + return &model.Link{ + URL: fmt.Sprintf("%s/ap%s?inner=%s&pass=%s&sign=%s", + common.GetApiUrl(ctx), + utils.EncodePath(reqPath, true), + utils.EncodePath(args.InnerPath, true), + url.QueryEscape(args.Password), + sign.SignArchive(reqPath)), + }, nil + } + resultLink := *link + resultLink.SyncClosers = utils.NewSyncClosers(link) + return &resultLink, nil } return nil, errs.NotImplement } diff --git a/drivers/alias/util.go b/drivers/alias/util.go index a31ec1c5..11c299e9 100644 --- a/drivers/alias/util.go +++ b/drivers/alias/util.go @@ -2,8 +2,6 @@ package alias import ( "context" - "fmt" - "net/url" stdpath "path" "strings" @@ -12,8 +10,6 @@ import ( "github.com/OpenListTeam/OpenList/v4/internal/fs" "github.com/OpenListTeam/OpenList/v4/internal/model" "github.com/OpenListTeam/OpenList/v4/internal/op" - "github.com/OpenListTeam/OpenList/v4/internal/sign" - "github.com/OpenListTeam/OpenList/v4/pkg/utils" "github.com/OpenListTeam/OpenList/v4/server/common" ) @@ -140,8 +136,7 @@ func (d *Alias) listArchive(ctx context.Context, dst, sub string, args model.Arc return nil, errs.NotImplement } -func (d *Alias) extract(ctx context.Context, dst, sub string, args model.ArchiveInnerArgs) (*model.Link, error) { - reqPath := stdpath.Join(dst, sub) +func (d *Alias) extract(ctx context.Context, reqPath string, args model.ArchiveInnerArgs) (*model.Link, error) { storage, reqActualPath, err := op.GetStorageAndActualPath(reqPath) if err != nil { return nil, err @@ 
-149,20 +144,12 @@ func (d *Alias) extract(ctx context.Context, dst, sub string, args model.Archive if _, ok := storage.(driver.ArchiveReader); !ok { return nil, errs.NotImplement } - if args.Redirect && common.ShouldProxy(storage, stdpath.Base(sub)) { - _, err = fs.Get(ctx, reqPath, &fs.GetArgs{NoLog: true}) - if err != nil { + if args.Redirect && common.ShouldProxy(storage, stdpath.Base(reqPath)) { + _, err := fs.Get(ctx, reqPath, &fs.GetArgs{NoLog: true}) + if err != nil { return nil, err } - link := &model.Link{ - URL: fmt.Sprintf("%s/ap%s?inner=%s&pass=%s&sign=%s", - common.GetApiUrl(ctx), - utils.EncodePath(reqPath, true), - utils.EncodePath(args.InnerPath, true), - url.QueryEscape(args.Password), - sign.SignArchive(reqPath)), - } - return link, nil + return nil, nil } link, _, err := op.DriverExtract(ctx, storage, reqActualPath, args) return link, err diff --git a/drivers/aliyundrive_open/upload.go b/drivers/aliyundrive_open/upload.go index 93ad3182..9114997c 100644 --- a/drivers/aliyundrive_open/upload.go +++ b/drivers/aliyundrive_open/upload.go @@ -191,9 +191,7 @@ func (d *AliyundriveOpen) upload(ctx context.Context, dstDir model.Obj, stream m hash := stream.GetHash().GetHash(utils.SHA1) if len(hash) != utils.SHA1.Width { - cacheFileProgress := model.UpdateProgressWithRange(up, 0, 50) - up = model.UpdateProgressWithRange(up, 50, 100) - _, hash, err = streamPkg.CacheFullInTempFileAndHash(stream, cacheFileProgress, utils.SHA1) + _, hash, err = streamPkg.CacheFullAndHash(stream, &up, utils.SHA1) if err != nil { return nil, err } @@ -218,14 +216,13 @@ func (d *AliyundriveOpen) upload(ctx context.Context, dstDir model.Obj, stream m if !createResp.RapidUpload { // 2. normal upload log.Debugf("[aliyundive_open] normal upload") - - preTime := time.Now() - var offset, length int64 = 0, partSize - //var length - ss, err := streamPkg.NewStreamSectionReader(stream, int(partSize)) + ss, err := streamPkg.NewStreamSectionReader(stream, int(partSize), &up) if err != nil { return nil, err } + + preTime := time.Now() + var offset, length int64 = 0, partSize for i := 0; i < len(createResp.PartInfoList); i++ { if utils.IsCanceled(ctx) { return nil, ctx.Err() @@ -253,7 +250,7 @@ func (d *AliyundriveOpen) upload(ctx context.Context, dstDir model.Obj, stream m retry.Attempts(3), retry.DelayType(retry.BackOffDelay), retry.Delay(time.Second)) - ss.RecycleSectionReader(rd) + ss.FreeSectionReader(rd) if err != nil { return nil, err } diff --git a/drivers/cloudreve/util.go b/drivers/cloudreve/util.go index 88ff67cc..c9894b3a 100644 --- a/drivers/cloudreve/util.go +++ b/drivers/cloudreve/util.go @@ -237,15 +237,16 @@ func (d *Cloudreve) upLocal(ctx context.Context, stream model.FileStreamer, u Up } func (d *Cloudreve) upRemote(ctx context.Context, stream model.FileStreamer, u UploadInfo, up driver.UpdateProgress) error { + DEFAULT := int64(u.ChunkSize) + ss, err := streamPkg.NewStreamSectionReader(stream, int(DEFAULT), &up) + if err != nil { + return err + } + uploadUrl := u.UploadURLs[0] credential := u.Credential var finish int64 = 0 var chunk int = 0 - DEFAULT := int64(u.ChunkSize) - ss, err := streamPkg.NewStreamSectionReader(stream, int(DEFAULT)) - if err != nil { - return err - } for finish < stream.GetSize() { if utils.IsCanceled(ctx) { return ctx.Err() } @@ -294,7 +295,7 @@ func (d *Cloudreve) upRemote(ctx context.Context, stream model.FileStreamer, u U retry.DelayType(retry.BackOffDelay), retry.Delay(time.Second), ) - ss.RecycleSectionReader(rd) + ss.FreeSectionReader(rd) if err != nil { return err }
@@ -306,13 +307,14 @@ func (d *Cloudreve) upRemote(ctx context.Context, stream model.FileStreamer, u U } func (d *Cloudreve) upOneDrive(ctx context.Context, stream model.FileStreamer, u UploadInfo, up driver.UpdateProgress) error { - uploadUrl := u.UploadURLs[0] - var finish int64 = 0 DEFAULT := int64(u.ChunkSize) - ss, err := streamPkg.NewStreamSectionReader(stream, int(DEFAULT)) + ss, err := streamPkg.NewStreamSectionReader(stream, int(DEFAULT), &up) if err != nil { return err } + + uploadUrl := u.UploadURLs[0] + var finish int64 = 0 for finish < stream.GetSize() { if utils.IsCanceled(ctx) { return ctx.Err() @@ -353,7 +355,7 @@ func (d *Cloudreve) upOneDrive(ctx context.Context, stream model.FileStreamer, u retry.DelayType(retry.BackOffDelay), retry.Delay(time.Second), ) - ss.RecycleSectionReader(rd) + ss.FreeSectionReader(rd) if err != nil { return err } @@ -367,14 +369,15 @@ func (d *Cloudreve) upOneDrive(ctx context.Context, stream model.FileStreamer, u } func (d *Cloudreve) upS3(ctx context.Context, stream model.FileStreamer, u UploadInfo, up driver.UpdateProgress) error { - var finish int64 = 0 - var chunk int = 0 - var etags []string DEFAULT := int64(u.ChunkSize) - ss, err := streamPkg.NewStreamSectionReader(stream, int(DEFAULT)) + ss, err := streamPkg.NewStreamSectionReader(stream, int(DEFAULT), &up) if err != nil { return err } + + var finish int64 = 0 + var chunk int = 0 + var etags []string for finish < stream.GetSize() { if utils.IsCanceled(ctx) { return ctx.Err() @@ -415,7 +418,7 @@ func (d *Cloudreve) upS3(ctx context.Context, stream model.FileStreamer, u Uploa retry.DelayType(retry.BackOffDelay), retry.Delay(time.Second), ) - ss.RecycleSectionReader(rd) + ss.FreeSectionReader(rd) if err != nil { return err } diff --git a/drivers/cloudreve_v4/util.go b/drivers/cloudreve_v4/util.go index 215b4c53..fc03ee7a 100644 --- a/drivers/cloudreve_v4/util.go +++ b/drivers/cloudreve_v4/util.go @@ -252,15 +252,16 @@ func (d *CloudreveV4) upLocal(ctx context.Context, file model.FileStreamer, u Fi } func (d *CloudreveV4) upRemote(ctx context.Context, file model.FileStreamer, u FileUploadResp, up driver.UpdateProgress) error { + DEFAULT := int64(u.ChunkSize) + ss, err := stream.NewStreamSectionReader(file, int(DEFAULT), &up) + if err != nil { + return err + } + uploadUrl := u.UploadUrls[0] credential := u.Credential var finish int64 = 0 var chunk int = 0 - DEFAULT := int64(u.ChunkSize) - ss, err := stream.NewStreamSectionReader(file, int(DEFAULT)) - if err != nil { - return err - } for finish < file.GetSize() { if utils.IsCanceled(ctx) { return ctx.Err() @@ -309,7 +310,7 @@ func (d *CloudreveV4) upRemote(ctx context.Context, file model.FileStreamer, u F retry.DelayType(retry.BackOffDelay), retry.Delay(time.Second), ) - ss.RecycleSectionReader(rd) + ss.FreeSectionReader(rd) if err != nil { return err } @@ -321,13 +322,14 @@ func (d *CloudreveV4) upRemote(ctx context.Context, file model.FileStreamer, u F } func (d *CloudreveV4) upOneDrive(ctx context.Context, file model.FileStreamer, u FileUploadResp, up driver.UpdateProgress) error { - uploadUrl := u.UploadUrls[0] - var finish int64 = 0 DEFAULT := int64(u.ChunkSize) - ss, err := stream.NewStreamSectionReader(file, int(DEFAULT)) + ss, err := stream.NewStreamSectionReader(file, int(DEFAULT), &up) if err != nil { return err } + + uploadUrl := u.UploadUrls[0] + var finish int64 = 0 for finish < file.GetSize() { if utils.IsCanceled(ctx) { return ctx.Err() @@ -369,7 +371,7 @@ func (d *CloudreveV4) upOneDrive(ctx context.Context, file 
model.FileStreamer, u retry.DelayType(retry.BackOffDelay), retry.Delay(time.Second), ) - ss.RecycleSectionReader(rd) + ss.FreeSectionReader(rd) if err != nil { return err } @@ -383,14 +385,15 @@ func (d *CloudreveV4) upOneDrive(ctx context.Context, file model.FileStreamer, u } func (d *CloudreveV4) upS3(ctx context.Context, file model.FileStreamer, u FileUploadResp, up driver.UpdateProgress) error { - var finish int64 = 0 - var chunk int = 0 - var etags []string DEFAULT := int64(u.ChunkSize) - ss, err := stream.NewStreamSectionReader(file, int(DEFAULT)) + ss, err := stream.NewStreamSectionReader(file, int(DEFAULT), &up) if err != nil { return err } + + var finish int64 = 0 + var chunk int = 0 + var etags []string for finish < file.GetSize() { if utils.IsCanceled(ctx) { return ctx.Err() @@ -432,7 +435,7 @@ func (d *CloudreveV4) upS3(ctx context.Context, file model.FileStreamer, u FileU retry.DelayType(retry.BackOffDelay), retry.Delay(time.Second), ) - ss.RecycleSectionReader(rd) + ss.FreeSectionReader(rd) if err != nil { return err } diff --git a/drivers/crypt/driver.go b/drivers/crypt/driver.go index 4cd64348..704c70cb 100644 --- a/drivers/crypt/driver.go +++ b/drivers/crypt/driver.go @@ -401,7 +401,6 @@ func (d *Crypt) Put(ctx context.Context, dstDir model.Obj, streamer model.FileSt }, Reader: wrappedIn, Mimetype: "application/octet-stream", - WebPutAsTask: streamer.NeedStore(), ForceStreamUpload: true, Exist: streamer.GetExist(), } diff --git a/drivers/doubao/util.go b/drivers/doubao/util.go index 7dd1da2c..325c16c5 100644 --- a/drivers/doubao/util.go +++ b/drivers/doubao/util.go @@ -449,10 +449,11 @@ func (d *Doubao) uploadNode(uploadConfig *UploadConfig, dir model.Obj, file mode // Upload 普通上传实现 func (d *Doubao) Upload(ctx context.Context, config *UploadConfig, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress, dataType string) (model.Obj, error) { - ss, err := stream.NewStreamSectionReader(file, int(file.GetSize())) + ss, err := stream.NewStreamSectionReader(file, int(file.GetSize()), &up) if err != nil { return nil, err } + reader, err := ss.GetSectionReader(0, file.GetSize()) if err != nil { return nil, err @@ -503,7 +504,7 @@ func (d *Doubao) Upload(ctx context.Context, config *UploadConfig, dstDir model. 
} return nil }) - ss.RecycleSectionReader(reader) + ss.FreeSectionReader(reader) if err != nil { return nil, err } @@ -542,15 +543,15 @@ func (d *Doubao) UploadByMultipart(ctx context.Context, config *UploadConfig, fi if config.InnerUploadAddress.AdvanceOption.SliceSize > 0 { chunkSize = int64(config.InnerUploadAddress.AdvanceOption.SliceSize) } + ss, err := stream.NewStreamSectionReader(file, int(chunkSize), &up) + if err != nil { + return nil, err + } + totalParts := (fileSize + chunkSize - 1) / chunkSize // 创建分片信息组 parts := make([]UploadPart, totalParts) - // 用 stream.NewStreamSectionReader 替代缓存临时文件 - ss, err := stream.NewStreamSectionReader(file, int(chunkSize)) - if err != nil { - return nil, err - } up(10.0) // 更新进度 // 设置并行上传 thread := min(int(totalParts), d.uploadThread) @@ -641,7 +642,7 @@ func (d *Doubao) UploadByMultipart(ctx context.Context, config *UploadConfig, fi return nil }, After: func(err error) { - ss.RecycleSectionReader(reader) + ss.FreeSectionReader(reader) }, }) } diff --git a/drivers/google_drive/driver.go b/drivers/google_drive/driver.go index 203deada..c4dd01af 100644 --- a/drivers/google_drive/driver.go +++ b/drivers/google_drive/driver.go @@ -162,7 +162,7 @@ func (d *GoogleDrive) Put(ctx context.Context, dstDir model.Obj, stream model.Fi SetBody(driver.NewLimitedUploadStream(ctx, stream)) }, nil) } else { - err = d.chunkUpload(ctx, stream, putUrl) + err = d.chunkUpload(ctx, stream, putUrl, up) } return err } diff --git a/drivers/google_drive/util.go b/drivers/google_drive/util.go index ff219136..ff4bb7b9 100644 --- a/drivers/google_drive/util.go +++ b/drivers/google_drive/util.go @@ -254,13 +254,14 @@ func (d *GoogleDrive) getFiles(id string) ([]File, error) { return res, nil } -func (d *GoogleDrive) chunkUpload(ctx context.Context, file model.FileStreamer, url string) error { +func (d *GoogleDrive) chunkUpload(ctx context.Context, file model.FileStreamer, url string, up driver.UpdateProgress) error { var defaultChunkSize = d.ChunkSize * 1024 * 1024 - var offset int64 = 0 - ss, err := stream.NewStreamSectionReader(file, int(defaultChunkSize)) + ss, err := stream.NewStreamSectionReader(file, int(defaultChunkSize), &up) if err != nil { return err } + + var offset int64 = 0 url += "?includeItemsFromAllDrives=true&supportsAllDrives=true" for offset < file.GetSize() { if utils.IsCanceled(ctx) { @@ -300,12 +301,13 @@ func (d *GoogleDrive) chunkUpload(ctx context.Context, file model.FileStreamer, } return fmt.Errorf("%s: %v", e.Error.Message, e.Error.Errors) } + up(float64(offset+chunkSize) / float64(file.GetSize()) * 100) return nil }, retry.Attempts(3), retry.DelayType(retry.BackOffDelay), retry.Delay(time.Second)) - ss.RecycleSectionReader(reader) + ss.FreeSectionReader(reader) if err != nil { return err } diff --git a/drivers/ilanzou/driver.go b/drivers/ilanzou/driver.go index 59b24b53..0e4f9be2 100644 --- a/drivers/ilanzou/driver.go +++ b/drivers/ilanzou/driver.go @@ -276,9 +276,7 @@ func (d *ILanZou) Put(ctx context.Context, dstDir model.Obj, s model.FileStreame etag := s.GetHash().GetHash(utils.MD5) var err error if len(etag) != utils.MD5.Width { - cacheFileProgress := model.UpdateProgressWithRange(up, 0, 50) - up = model.UpdateProgressWithRange(up, 50, 100) - _, etag, err = stream.CacheFullInTempFileAndHash(s, cacheFileProgress, utils.MD5) + _, etag, err = stream.CacheFullAndHash(s, &up, utils.MD5) if err != nil { return nil, err } diff --git a/drivers/mediatrack/driver.go b/drivers/mediatrack/driver.go index 15d84f31..85c8fb6d 100644 --- 
a/drivers/mediatrack/driver.go +++ b/drivers/mediatrack/driver.go @@ -180,7 +180,7 @@ func (d *MediaTrack) Put(ctx context.Context, dstDir model.Obj, file model.FileS if err != nil { return err } - tempFile, err := file.CacheFullInTempFile() + tempFile, err := file.CacheFullAndWriter(&up, nil) if err != nil { return err } diff --git a/drivers/mopan/driver.go b/drivers/mopan/driver.go index e6d1bf8f..c611cf02 100644 --- a/drivers/mopan/driver.go +++ b/drivers/mopan/driver.go @@ -263,7 +263,7 @@ func (d *MoPan) Remove(ctx context.Context, obj model.Obj) error { } func (d *MoPan) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) { - file, err := stream.CacheFullInTempFile() + file, err := stream.CacheFullAndWriter(&up, nil) if err != nil { return nil, err } diff --git a/drivers/netease_music/util.go b/drivers/netease_music/util.go index daa91930..d090b5f8 100644 --- a/drivers/netease_music/util.go +++ b/drivers/netease_music/util.go @@ -223,7 +223,7 @@ func (d *NeteaseMusic) removeSongObj(file model.Obj) error { } func (d *NeteaseMusic) putSongStream(ctx context.Context, stream model.FileStreamer, up driver.UpdateProgress) error { - tmp, err := stream.CacheFullInTempFile() + tmp, err := stream.CacheFullAndWriter(&up, nil) if err != nil { return err } diff --git a/drivers/onedrive/util.go b/drivers/onedrive/util.go index 672d3c51..3e853cd1 100644 --- a/drivers/onedrive/util.go +++ b/drivers/onedrive/util.go @@ -238,13 +238,14 @@ func (d *Onedrive) upBig(ctx context.Context, dstDir model.Obj, stream model.Fil if err != nil { return err } - uploadUrl := jsoniter.Get(res, "uploadUrl").ToString() - var finish int64 = 0 DEFAULT := d.ChunkSize * 1024 * 1024 - ss, err := streamPkg.NewStreamSectionReader(stream, int(DEFAULT)) + ss, err := streamPkg.NewStreamSectionReader(stream, int(DEFAULT), &up) if err != nil { return err } + + uploadUrl := jsoniter.Get(res, "uploadUrl").ToString() + var finish int64 = 0 for finish < stream.GetSize() { if utils.IsCanceled(ctx) { return ctx.Err() @@ -285,7 +286,7 @@ func (d *Onedrive) upBig(ctx context.Context, dstDir model.Obj, stream model.Fil retry.DelayType(retry.BackOffDelay), retry.Delay(time.Second), ) - ss.RecycleSectionReader(rd) + ss.FreeSectionReader(rd) if err != nil { return err } diff --git a/drivers/onedrive_app/util.go b/drivers/onedrive_app/util.go index 2aca3688..783760ff 100644 --- a/drivers/onedrive_app/util.go +++ b/drivers/onedrive_app/util.go @@ -152,13 +152,14 @@ func (d *OnedriveAPP) upBig(ctx context.Context, dstDir model.Obj, stream model. if err != nil { return err } - uploadUrl := jsoniter.Get(res, "uploadUrl").ToString() - var finish int64 = 0 DEFAULT := d.ChunkSize * 1024 * 1024 - ss, err := streamPkg.NewStreamSectionReader(stream, int(DEFAULT)) + ss, err := streamPkg.NewStreamSectionReader(stream, int(DEFAULT), &up) if err != nil { return err } + + uploadUrl := jsoniter.Get(res, "uploadUrl").ToString() + var finish int64 = 0 for finish < stream.GetSize() { if utils.IsCanceled(ctx) { return ctx.Err() @@ -199,7 +200,7 @@ func (d *OnedriveAPP) upBig(ctx context.Context, dstDir model.Obj, stream model. 
retry.DelayType(retry.BackOffDelay), retry.Delay(time.Second), ) - ss.RecycleSectionReader(rd) + ss.FreeSectionReader(rd) if err != nil { return err } diff --git a/drivers/pikpak/driver.go b/drivers/pikpak/driver.go index 74399451..c728123c 100644 --- a/drivers/pikpak/driver.go +++ b/drivers/pikpak/driver.go @@ -12,6 +12,7 @@ import ( "github.com/OpenListTeam/OpenList/v4/internal/driver" "github.com/OpenListTeam/OpenList/v4/internal/model" "github.com/OpenListTeam/OpenList/v4/internal/op" + streamPkg "github.com/OpenListTeam/OpenList/v4/internal/stream" "github.com/OpenListTeam/OpenList/v4/pkg/utils" hash_extend "github.com/OpenListTeam/OpenList/v4/pkg/utils/hash" "github.com/go-resty/resty/v2" @@ -212,15 +213,11 @@ func (d *PikPak) Remove(ctx context.Context, obj model.Obj) error { } func (d *PikPak) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { - hi := stream.GetHash() - sha1Str := hi.GetHash(hash_extend.GCID) - if len(sha1Str) < hash_extend.GCID.Width { - tFile, err := stream.CacheFullInTempFile() - if err != nil { - return err - } + sha1Str := stream.GetHash().GetHash(hash_extend.GCID) - sha1Str, err = utils.HashFile(hash_extend.GCID, tFile, stream.GetSize()) + if len(sha1Str) < hash_extend.GCID.Width { + var err error + _, sha1Str, err = streamPkg.CacheFullAndHash(stream, &up, hash_extend.GCID, stream.GetSize()) if err != nil { return err } diff --git a/drivers/pikpak/util.go b/drivers/pikpak/util.go index 8b499c1b..9b7207fa 100644 --- a/drivers/pikpak/util.go +++ b/drivers/pikpak/util.go @@ -438,20 +438,19 @@ func (d *PikPak) UploadByOSS(ctx context.Context, params *S3Params, s model.File } func (d *PikPak) UploadByMultipart(ctx context.Context, params *S3Params, fileSize int64, s model.FileStreamer, up driver.UpdateProgress) error { + tmpF, err := s.CacheFullAndWriter(&up, nil) + if err != nil { + return err + } + var ( chunks []oss.FileChunk parts []oss.UploadPart imur oss.InitiateMultipartUploadResult ossClient *oss.Client bucket *oss.Bucket - err error ) - tmpF, err := s.CacheFullInTempFile() - if err != nil { - return err - } - if ossClient, err = oss.New(params.Endpoint, params.AccessKeyID, params.AccessKeySecret); err != nil { return err } diff --git a/drivers/quark_open/driver.go b/drivers/quark_open/driver.go index cc757580..7b6b3133 100644 --- a/drivers/quark_open/driver.go +++ b/drivers/quark_open/driver.go @@ -14,7 +14,6 @@ import ( "github.com/OpenListTeam/OpenList/v4/internal/driver" "github.com/OpenListTeam/OpenList/v4/internal/errs" "github.com/OpenListTeam/OpenList/v4/internal/model" - streamPkg "github.com/OpenListTeam/OpenList/v4/internal/stream" "github.com/OpenListTeam/OpenList/v4/pkg/utils" "github.com/go-resty/resty/v2" ) @@ -158,9 +157,7 @@ func (d *QuarkOpen) Put(ctx context.Context, dstDir model.Obj, stream model.File } if len(writers) > 0 { - cacheFileProgress := model.UpdateProgressWithRange(up, 0, 50) - up = model.UpdateProgressWithRange(up, 50, 100) - _, err := streamPkg.CacheFullInTempFileAndWriter(stream, cacheFileProgress, io.MultiWriter(writers...)) + _, err := stream.CacheFullAndWriter(&up, io.MultiWriter(writers...)) if err != nil { return err } diff --git a/drivers/quark_uc/driver.go b/drivers/quark_uc/driver.go index c9944345..f5b30d8b 100644 --- a/drivers/quark_uc/driver.go +++ b/drivers/quark_uc/driver.go @@ -13,7 +13,6 @@ import ( "github.com/OpenListTeam/OpenList/v4/internal/driver" "github.com/OpenListTeam/OpenList/v4/internal/errs" "github.com/OpenListTeam/OpenList/v4/internal/model" 
- streamPkg "github.com/OpenListTeam/OpenList/v4/internal/stream" "github.com/OpenListTeam/OpenList/v4/pkg/utils" "github.com/go-resty/resty/v2" log "github.com/sirupsen/logrus" @@ -144,9 +143,7 @@ func (d *QuarkOrUC) Put(ctx context.Context, dstDir model.Obj, stream model.File } if len(writers) > 0 { - cacheFileProgress := model.UpdateProgressWithRange(up, 0, 50) - up = model.UpdateProgressWithRange(up, 50, 100) - _, err := streamPkg.CacheFullInTempFileAndWriter(stream, cacheFileProgress, io.MultiWriter(writers...)) + _, err := stream.CacheFullAndWriter(&up, io.MultiWriter(writers...)) if err != nil { return err } diff --git a/drivers/strm/driver.go b/drivers/strm/driver.go index 69660358..80af52d4 100644 --- a/drivers/strm/driver.go +++ b/drivers/strm/driver.go @@ -173,8 +173,9 @@ func (d *Strm) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (* }, nil } - // 没有修改link的字段,可直接返回 - return link, nil + resultLink := *link + resultLink.SyncClosers = utils.NewSyncClosers(link) + return &resultLink, nil } var _ driver.Driver = (*Strm)(nil) diff --git a/drivers/terabox/driver.go b/drivers/terabox/driver.go index bf28d26e..a5b94577 100644 --- a/drivers/terabox/driver.go +++ b/drivers/terabox/driver.go @@ -179,7 +179,7 @@ func (d *Terabox) Put(ctx context.Context, dstDir model.Obj, stream model.FileSt } // upload chunks - tempFile, err := stream.CacheFullInTempFile() + tempFile, err := stream.CacheFullAndWriter(&up, nil) if err != nil { return err } diff --git a/drivers/thunder/driver.go b/drivers/thunder/driver.go index 83cb4b3f..7f537ac6 100644 --- a/drivers/thunder/driver.go +++ b/drivers/thunder/driver.go @@ -371,9 +371,7 @@ func (xc *XunLeiCommon) Put(ctx context.Context, dstDir model.Obj, file model.Fi gcid := file.GetHash().GetHash(hash_extend.GCID) var err error if len(gcid) < hash_extend.GCID.Width { - cacheFileProgress := model.UpdateProgressWithRange(up, 0, 50) - up = model.UpdateProgressWithRange(up, 50, 100) - _, gcid, err = stream.CacheFullInTempFileAndHash(file, cacheFileProgress, hash_extend.GCID, file.GetSize()) + _, gcid, err = stream.CacheFullAndHash(file, &up, hash_extend.GCID, file.GetSize()) if err != nil { return err } diff --git a/drivers/thunder_browser/driver.go b/drivers/thunder_browser/driver.go index 89d68ea1..bf1843a4 100644 --- a/drivers/thunder_browser/driver.go +++ b/drivers/thunder_browser/driver.go @@ -491,9 +491,7 @@ func (xc *XunLeiBrowserCommon) Put(ctx context.Context, dstDir model.Obj, stream gcid := stream.GetHash().GetHash(hash_extend.GCID) var err error if len(gcid) < hash_extend.GCID.Width { - cacheFileProgress := model.UpdateProgressWithRange(up, 0, 50) - up = model.UpdateProgressWithRange(up, 50, 100) - _, gcid, err = streamPkg.CacheFullInTempFileAndHash(stream, cacheFileProgress, hash_extend.GCID, stream.GetSize()) + _, gcid, err = streamPkg.CacheFullAndHash(stream, &up, hash_extend.GCID, stream.GetSize()) if err != nil { return err } diff --git a/drivers/thunderx/driver.go b/drivers/thunderx/driver.go index 0ee65378..86ff22bd 100644 --- a/drivers/thunderx/driver.go +++ b/drivers/thunderx/driver.go @@ -372,9 +372,7 @@ func (xc *XunLeiXCommon) Put(ctx context.Context, dstDir model.Obj, file model.F gcid := file.GetHash().GetHash(hash_extend.GCID) var err error if len(gcid) < hash_extend.GCID.Width { - cacheFileProgress := model.UpdateProgressWithRange(up, 0, 50) - up = model.UpdateProgressWithRange(up, 50, 100) - _, gcid, err = stream.CacheFullInTempFileAndHash(file, cacheFileProgress, hash_extend.GCID, file.GetSize()) + _, gcid, err = 
stream.CacheFullAndHash(file, &up, hash_extend.GCID, file.GetSize()) if err != nil { return err } diff --git a/drivers/weiyun/driver.go b/drivers/weiyun/driver.go index fe9b70a4..ef203cfa 100644 --- a/drivers/weiyun/driver.go +++ b/drivers/weiyun/driver.go @@ -317,7 +317,7 @@ func (d *WeiYun) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr if folder, ok = dstDir.(*Folder); !ok { return nil, errs.NotSupport } - file, err := stream.CacheFullInTempFile() + file, err := stream.CacheFullAndWriter(&up, nil) if err != nil { return nil, err } diff --git a/internal/bootstrap/config.go b/internal/bootstrap/config.go index 3980deb8..bff899e1 100644 --- a/internal/bootstrap/config.go +++ b/internal/bootstrap/config.go @@ -91,7 +91,7 @@ func InitConfig() { } else { conf.MaxBufferLimit = conf.Conf.MaxBufferLimit * utils.MB } - log.Infof("max buffer limit: %d", conf.MaxBufferLimit) + log.Infof("max buffer limit: %dMB", conf.MaxBufferLimit/utils.MB) if !conf.Conf.Force { confFromEnv() } diff --git a/internal/conf/var.go b/internal/conf/var.go index 50a7f33d..83fb87e9 100644 --- a/internal/conf/var.go +++ b/internal/conf/var.go @@ -25,7 +25,7 @@ var PrivacyReg []*regexp.Regexp var ( // StoragesLoaded loaded success if empty StoragesLoaded = false - MaxBufferLimit int + MaxBufferLimit = 16 * 1024 * 1024 ) var ( RawIndexHtml string diff --git a/internal/fs/archive.go b/internal/fs/archive.go index 40cc3981..e1e4c448 100644 --- a/internal/fs/archive.go +++ b/internal/fs/archive.go @@ -70,25 +70,25 @@ func (t *ArchiveDownloadTask) RunWithoutPushUploadTask() (*ArchiveContentUploadT }() var decompressUp model.UpdateProgress if t.CacheFull { - var total, cur int64 = 0, 0 + total := int64(0) for _, s := range ss { total += s.GetSize() } t.SetTotalBytes(total) t.Status = "getting src object" - for _, s := range ss { - if s.GetFile() == nil { - _, err = stream.CacheFullInTempFileAndWriter(s, func(p float64) { - t.SetProgress((float64(cur) + float64(s.GetSize())*p/100.0) / float64(total)) - }, nil) + part := 100 / float64(len(ss)+1) + for i, s := range ss { + if s.GetFile() != nil { + continue } - cur += s.GetSize() + _, err = s.CacheFullAndWriter(nil, nil) if err != nil { return nil, err + } else { + t.SetProgress(float64(i+1) * part) } } - t.SetProgress(100.0) - decompressUp = func(_ float64) {} + decompressUp = model.UpdateProgressWithRange(t.SetProgress, 100-part, 100) } else { decompressUp = t.SetProgress } diff --git a/internal/fs/put.go b/internal/fs/put.go index 887c8d63..881330b0 100644 --- a/internal/fs/put.go +++ b/internal/fs/put.go @@ -69,7 +69,7 @@ func putAsTask(ctx context.Context, dstDirPath string, file model.FileStreamer) return nil, errors.WithStack(errs.UploadNotSupported) } if file.NeedStore() { - _, err := file.CacheFullInTempFile() + _, err := file.CacheFullAndWriter(nil, nil) if err != nil { return nil, errors.Wrapf(err, "failed to create temp file") } diff --git a/internal/model/obj.go b/internal/model/obj.go index 33ae6e3d..836904fc 100644 --- a/internal/model/obj.go +++ b/internal/model/obj.go @@ -2,7 +2,6 @@ package model import ( "io" - "os" "sort" "strings" "time" @@ -40,16 +39,17 @@ type FileStreamer interface { utils.ClosersIF Obj GetMimetype() string - //SetReader(io.Reader) NeedStore() bool IsForceStreamUpload() bool GetExist() Obj SetExist(Obj) - //for a non-seekable Stream, RangeRead supports peeking some data, and CacheFullInTempFile still works + // for a non-seekable Stream, RangeRead supports peeking some data, and CacheFullAndWriter still works 
RangeRead(http_range.Range) (io.Reader, error) - //for a non-seekable Stream, if Read is called, this function won't work - CacheFullInTempFile() (File, error) - SetTmpFile(r *os.File) + // for a non-seekable Stream, if Read is called, this function won't work. + // caches the full Stream and writes it to writer (if provided, even if the stream is already cached). + CacheFullAndWriter(up *UpdateProgress, writer io.Writer) (File, error) + SetTmpFile(file File) + // if the Stream is not a File and is not cached, returns nil. GetFile() File } diff --git a/internal/stream/stream.go b/internal/stream/stream.go index 932975a4..344a7759 100644 --- a/internal/stream/stream.go +++ b/internal/stream/stream.go @@ -1,7 +1,6 @@ package stream import ( - "bytes" "context" "errors" "fmt" @@ -13,6 +12,7 @@ import ( "github.com/OpenListTeam/OpenList/v4/internal/conf" "github.com/OpenListTeam/OpenList/v4/internal/errs" "github.com/OpenListTeam/OpenList/v4/internal/model" + "github.com/OpenListTeam/OpenList/v4/pkg/buffer" "github.com/OpenListTeam/OpenList/v4/pkg/http_range" "github.com/OpenListTeam/OpenList/v4/pkg/utils" "go4.org/readerutil" @@ -27,13 +27,19 @@ type FileStream struct { ForceStreamUpload bool Exist model.Obj //the file existed in the destination, we can reuse some info since we wil overwrite it utils.Closers - tmpFile *os.File //if present, tmpFile has full content, it will be deleted at last - peekBuff *bytes.Reader + + tmpFile model.File //if present, tmpFile has full content, it will be deleted at last + peekBuff *buffer.Reader + size int64 + oriReader io.Reader // the original reader, used for caching } func (f *FileStream) GetSize() int64 { - if f.tmpFile != nil { - info, err := f.tmpFile.Stat() + if f.size > 0 { + return f.size + } + if file, ok := f.tmpFile.(*os.File); ok { + info, err := file.Stat() if err == nil { return info.Size() } @@ -60,14 +66,18 @@ func (f *FileStream) Close() error { if errors.Is(err1, os.ErrClosed) { err1 = nil } - if f.tmpFile != nil { - err2 = os.RemoveAll(f.tmpFile.Name()) + if file, ok := f.tmpFile.(*os.File); ok { + err2 = os.RemoveAll(file.Name()) if err2 != nil { - err2 = errs.NewErr(err2, "failed to remove tmpFile [%s]", f.tmpFile.Name()) + err2 = errs.NewErr(err2, "failed to remove tmpFile [%s]", file.Name()) } else { f.tmpFile = nil } } + if f.peekBuff != nil { + f.peekBuff.Reset() + f.peekBuff = nil + } return errors.Join(err1, err2) } @@ -79,20 +89,55 @@ func (f *FileStream) SetExist(obj model.Obj) { f.Exist = obj } -// CacheFullInTempFile save all data into tmpFile. Not recommended since it wears disk, -// and can't start upload until the file is written. It's not thread-safe! -func (f *FileStream) CacheFullInTempFile() (model.File, error) { - if file := f.GetFile(); file != nil { - return file, nil +// CacheFullAndWriter save all data into tmpFile or memory. +// It's not thread-safe! 
+func (f *FileStream) CacheFullAndWriter(up *model.UpdateProgress, writer io.Writer) (model.File, error) { + if cache := f.GetFile(); cache != nil { + if writer == nil { + return cache, nil + } + _, err := cache.Seek(0, io.SeekStart) + if err == nil { + reader := f.Reader + if up != nil { + cacheProgress := model.UpdateProgressWithRange(*up, 0, 50) + *up = model.UpdateProgressWithRange(*up, 50, 100) + reader = &ReaderUpdatingProgress{ + Reader: &SimpleReaderWithSize{ + Reader: reader, + Size: f.GetSize(), + }, + UpdateProgress: cacheProgress, + } + } + _, err = utils.CopyWithBuffer(writer, reader) + if err == nil { + _, err = cache.Seek(0, io.SeekStart) + } + } + if err != nil { + return nil, err + } + return cache, nil } - tmpF, err := utils.CreateTempFile(f.Reader, f.GetSize()) - if err != nil { - return nil, err + + reader := f.Reader + if up != nil { + cacheProgress := model.UpdateProgressWithRange(*up, 0, 50) + *up = model.UpdateProgressWithRange(*up, 50, 100) + reader = &ReaderUpdatingProgress{ + Reader: &SimpleReaderWithSize{ + Reader: reader, + Size: f.GetSize(), + }, + UpdateProgress: cacheProgress, + } } - f.Add(tmpF) - f.tmpFile = tmpF - f.Reader = tmpF - return tmpF, nil + if writer != nil { + reader = io.TeeReader(reader, writer) + } + f.Reader = reader + return f.cache(f.GetSize()) } func (f *FileStream) GetFile() model.File { @@ -106,40 +151,68 @@ func (f *FileStream) GetFile() model.File { } // RangeRead have to cache all data first since only Reader is provided. -// also support a peeking RangeRead at very start, but won't buffer more than conf.MaxBufferLimit data in memory +// It's not thread-safe! func (f *FileStream) RangeRead(httpRange http_range.Range) (io.Reader, error) { if httpRange.Length < 0 || httpRange.Start+httpRange.Length > f.GetSize() { httpRange.Length = f.GetSize() - httpRange.Start } - var cache io.ReaderAt = f.GetFile() - if cache != nil { - return io.NewSectionReader(cache, httpRange.Start, httpRange.Length), nil + if f.GetFile() != nil { + return io.NewSectionReader(f.GetFile(), httpRange.Start, httpRange.Length), nil } size := httpRange.Start + httpRange.Length if f.peekBuff != nil && size <= int64(f.peekBuff.Len()) { return io.NewSectionReader(f.peekBuff, httpRange.Start, httpRange.Length), nil } - if size <= int64(conf.MaxBufferLimit) { - bufSize := min(size, f.GetSize()) - // 使用bytes.Buffer作为io.CopyBuffer的写入对象,CopyBuffer会调用Buffer.ReadFrom - // 即使被写入的数据量与Buffer.Cap一致,Buffer也会扩大 - buf := make([]byte, bufSize) - n, err := io.ReadFull(f.Reader, buf) - if err != nil { - return nil, fmt.Errorf("failed to read all data: (expect =%d, actual =%d) %w", bufSize, n, err) - } - f.peekBuff = bytes.NewReader(buf) - f.Reader = io.MultiReader(f.peekBuff, f.Reader) - cache = f.peekBuff - } else { - var err error - cache, err = f.CacheFullInTempFile() + + cache, err := f.cache(size) + if err != nil { + return nil, err + } + + return io.NewSectionReader(cache, httpRange.Start, httpRange.Length), nil +} + +// *旧笔记 +// 使用bytes.Buffer作为io.CopyBuffer的写入对象,CopyBuffer会调用Buffer.ReadFrom +// 即使被写入的数据量与Buffer.Cap一致,Buffer也会扩大 + +func (f *FileStream) cache(maxCacheSize int64) (model.File, error) { + if maxCacheSize > int64(conf.MaxBufferLimit) { + tmpF, err := utils.CreateTempFile(f.Reader, f.GetSize()) if err != nil { return nil, err } + f.Add(tmpF) + f.tmpFile = tmpF + f.Reader = tmpF + return tmpF, nil } - return io.NewSectionReader(cache, httpRange.Start, httpRange.Length), nil + + if f.peekBuff == nil { + f.peekBuff = &buffer.Reader{} + f.oriReader = f.Reader + } + 
bufSize := maxCacheSize - int64(f.peekBuff.Len()) + buf := make([]byte, bufSize) + n, err := io.ReadFull(f.oriReader, buf) + if bufSize != int64(n) { + return nil, fmt.Errorf("failed to read all data: (expect =%d, actual =%d) %w", bufSize, n, err) + } + f.peekBuff.Append(buf) + if int64(f.peekBuff.Len()) >= f.GetSize() { + f.Reader = f.peekBuff + f.oriReader = nil + } else { + f.Reader = io.MultiReader(f.peekBuff, f.oriReader) + } + return f.peekBuff, nil +} + +func (f *FileStream) SetTmpFile(file model.File) { + f.AddIfCloser(file) + f.tmpFile = file + f.Reader = file } var _ model.FileStreamer = (*SeekableStream)(nil) @@ -156,7 +229,6 @@ type SeekableStream struct { *FileStream // should have one of belows to support rangeRead rangeReadCloser model.RangeReadCloserIF - size int64 } func NewSeekableStream(fs *FileStream, link *model.Link) (*SeekableStream, error) { @@ -178,38 +250,26 @@ func NewSeekableStream(fs *FileStream, link *model.Link) (*SeekableStream, error if err != nil { return nil, err } - if _, ok := rr.(*model.FileRangeReader); ok { - fs.Reader, err = rr.RangeRead(fs.Ctx, http_range.Range{Length: -1}) - if err != nil { - return nil, err - } - fs.Add(link) - return &SeekableStream{FileStream: fs, size: size}, nil - } rrc := &model.RangeReadCloser{ RangeReader: rr, } + if _, ok := rr.(*model.FileRangeReader); ok { + fs.Reader, err = rrc.RangeRead(fs.Ctx, http_range.Range{Length: -1}) + if err != nil { + return nil, err + } + } + fs.size = size fs.Add(link) fs.Add(rrc) - return &SeekableStream{FileStream: fs, rangeReadCloser: rrc, size: size}, nil + return &SeekableStream{FileStream: fs, rangeReadCloser: rrc}, nil } return nil, fmt.Errorf("illegal seekableStream") } -func (ss *SeekableStream) GetSize() int64 { - if ss.size > 0 { - return ss.size - } - return ss.FileStream.GetSize() -} - -//func (ss *SeekableStream) Peek(length int) { -// -//} - // RangeRead is not thread-safe, pls use it in single thread only. func (ss *SeekableStream) RangeRead(httpRange http_range.Range) (io.Reader, error) { - if ss.tmpFile == nil && ss.rangeReadCloser != nil { + if ss.GetFile() == nil && ss.rangeReadCloser != nil { rc, err := ss.rangeReadCloser.RangeRead(ss.Ctx, httpRange) if err != nil { return nil, err @@ -219,47 +279,37 @@ func (ss *SeekableStream) RangeRead(httpRange http_range.Range) (io.Reader, erro return ss.FileStream.RangeRead(httpRange) } -//func (f *FileStream) GetReader() io.Reader { -// return f.Reader -//} - // only provide Reader as full stream when it's demanded. 
in rapid-upload, we can skip this to save memory func (ss *SeekableStream) Read(p []byte) (n int, err error) { + if err := ss.generateReader(); err != nil { + return 0, err + } + return ss.FileStream.Read(p) +} + +func (ss *SeekableStream) generateReader() error { if ss.Reader == nil { if ss.rangeReadCloser == nil { - return 0, fmt.Errorf("illegal seekableStream") + return fmt.Errorf("illegal seekableStream") } rc, err := ss.rangeReadCloser.RangeRead(ss.Ctx, http_range.Range{Length: -1}) if err != nil { - return 0, err + return err } ss.Reader = rc } - return ss.Reader.Read(p) + return nil } -func (ss *SeekableStream) CacheFullInTempFile() (model.File, error) { - if file := ss.GetFile(); file != nil { - return file, nil - } - tmpF, err := utils.CreateTempFile(ss, ss.GetSize()) - if err != nil { +func (ss *SeekableStream) CacheFullAndWriter(up *model.UpdateProgress, writer io.Writer) (model.File, error) { + if err := ss.generateReader(); err != nil { return nil, err } - ss.Add(tmpF) - ss.tmpFile = tmpF - ss.Reader = tmpF - return tmpF, nil -} - -func (f *FileStream) SetTmpFile(r *os.File) { - f.Add(r) - f.tmpFile = r - f.Reader = r + return ss.FileStream.CacheFullAndWriter(up, writer) } type ReaderWithSize interface { - io.ReadCloser + io.Reader GetSize() int64 } @@ -293,7 +343,10 @@ func (r *ReaderUpdatingProgress) Read(p []byte) (n int, err error) { } func (r *ReaderUpdatingProgress) Close() error { - return r.Reader.Close() + if c, ok := r.Reader.(io.Closer); ok { + return c.Close() + } + return nil } type RangeReadReadAtSeeker struct { @@ -311,19 +364,20 @@ type headCache struct { func (c *headCache) head(p []byte) (int, error) { n := 0 for _, buf := range c.bufs { - if len(buf)+n >= len(p) { - n += copy(p[n:], buf[:len(p)-n]) + n += copy(p[n:], buf) + if n == len(p) { return n, nil - } else { - n += copy(p[n:], buf) } } - w, err := io.ReadAtLeast(c.reader, p[n:], 1) - if w > 0 { - buf := make([]byte, w) - copy(buf, p[n:n+w]) + nn, err := io.ReadFull(c.reader, p[n:]) + if nn > 0 { + buf := make([]byte, nn) + copy(buf, p[n:]) c.bufs = append(c.bufs, buf) - n += w + n += nn + if err == io.ErrUnexpectedEOF { + err = io.EOF + } } return n, err } @@ -422,6 +476,9 @@ func (r *RangeReadReadAtSeeker) getReaderAtOffset(off int64) (io.Reader, error) } func (r *RangeReadReadAtSeeker) ReadAt(p []byte, off int64) (n int, err error) { + if off < 0 || off >= r.ss.GetSize() { + return 0, io.EOF + } if off == 0 && r.headCache != nil { return r.headCache.head(p) } @@ -430,12 +487,15 @@ func (r *RangeReadReadAtSeeker) ReadAt(p []byte, off int64) (n int, err error) { if err != nil { return 0, err } - n, err = io.ReadAtLeast(rr, p, 1) - off += int64(n) - if err == nil { - r.readerMap.Store(int64(off), rr) - } else { - rr = nil + n, err = io.ReadFull(rr, p) + if n > 0 { + off += int64(n) + switch err { + case nil: + r.readerMap.Store(int64(off), rr) + case io.ErrUnexpectedEOF: + err = io.EOF + } } return n, err } @@ -444,20 +504,14 @@ func (r *RangeReadReadAtSeeker) Seek(offset int64, whence int) (int64, error) { switch whence { case io.SeekStart: case io.SeekCurrent: - if offset == 0 { - return r.masterOff, nil - } offset += r.masterOff case io.SeekEnd: offset += r.ss.GetSize() default: - return 0, errs.NotSupport + return 0, errors.New("Seek: invalid whence") } - if offset < 0 { - return r.masterOff, errors.New("invalid seek: negative position") - } - if offset > r.ss.GetSize() { - offset = r.ss.GetSize() + if offset < 0 || offset > r.ss.GetSize() { + return 0, errors.New("Seek: invalid offset") } 
@@ -465,6 +519,8 @@ func (r *RangeReadReadAtSeeker) Seek(offset int64, whence int) (int64, error) {
 
 func (r *RangeReadReadAtSeeker) Read(p []byte) (n int, err error) {
 	n, err = r.ReadAt(p, r.masterOff)
-	r.masterOff += int64(n)
+	if n > 0 {
+		r.masterOff += int64(n)
+	}
 	return n, err
 }
diff --git a/internal/stream/stream_test.go b/internal/stream/stream_test.go
new file mode 100644
index 00000000..0c7412ff
--- /dev/null
+++ b/internal/stream/stream_test.go
@@ -0,0 +1,86 @@
+package stream
+
+import (
+	"bytes"
+	"errors"
+	"fmt"
+	"io"
+	"testing"
+
+	"github.com/OpenListTeam/OpenList/v4/internal/model"
+	"github.com/OpenListTeam/OpenList/v4/pkg/http_range"
+)
+
+func TestFileStream_RangeRead(t *testing.T) {
+	type args struct {
+		httpRange http_range.Range
+	}
+	buf := []byte("github.com/OpenListTeam/OpenList")
+	f := &FileStream{
+		Obj: &model.Object{
+			Size: int64(len(buf)),
+		},
+		Reader: io.NopCloser(bytes.NewReader(buf)),
+	}
+	tests := []struct {
+		name string
+		f    *FileStream
+		args args
+		want func(f *FileStream, got io.Reader, err error) error
+	}{
+		{
+			name: "range 11-12",
+			f:    f,
+			args: args{
+				httpRange: http_range.Range{Start: 11, Length: 12},
+			},
+			want: func(f *FileStream, got io.Reader, err error) error {
+				if f.GetFile() != nil {
+					return errors.New("cached")
+				}
+				b, _ := io.ReadAll(got)
+				if !bytes.Equal(buf[11:11+12], b) {
+					return fmt.Errorf("got=%s, want=%s", b, buf[11:11+12])
+				}
+				return nil
+			},
+		},
+		{
+			name: "range 11-21",
+			f:    f,
+			args: args{
+				httpRange: http_range.Range{Start: 11, Length: 21},
+			},
+			want: func(f *FileStream, got io.Reader, err error) error {
+				if f.GetFile() == nil {
+					return errors.New("not cached")
+				}
+				b, _ := io.ReadAll(got)
+				if !bytes.Equal(buf[11:11+21], b) {
+					return fmt.Errorf("got=%s, want=%s", b, buf[11:11+21])
+				}
+				return nil
+			},
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got, err := tt.f.RangeRead(tt.args.httpRange)
+			if err := tt.want(tt.f, got, err); err != nil {
+				t.Errorf("FileStream.RangeRead() %v", err)
+			}
+		})
+	}
+	t.Run("after check", func(t *testing.T) {
+		if f.GetFile() == nil {
+			t.Error("not cached")
+		}
+		buf2 := make([]byte, len(buf))
+		if _, err := io.ReadFull(f, buf2); err != nil {
+			t.Errorf("FileStream.Read() error = %v", err)
+		}
+		if !bytes.Equal(buf, buf2) {
+			t.Errorf("FileStream.Read() = %s, want %s", buf2, buf)
+		}
+	})
+}
diff --git a/internal/stream/util.go b/internal/stream/util.go
index d2de46ac..2df1963a 100644
--- a/internal/stream/util.go
+++ b/internal/stream/util.go
@@ -141,56 +141,13 @@ func (r *ReaderWithCtx) Close() error {
 	return nil
 }
 
-func CacheFullInTempFileAndWriter(stream model.FileStreamer, up model.UpdateProgress, w io.Writer) (model.File, error) {
-	if cache := stream.GetFile(); cache != nil {
-		if w != nil {
-			_, err := cache.Seek(0, io.SeekStart)
-			if err == nil {
-				var reader io.Reader = stream
-				if up != nil {
-					reader = &ReaderUpdatingProgress{
-						Reader:         stream,
-						UpdateProgress: up,
-					}
-				}
-				_, err = utils.CopyWithBuffer(w, reader)
-				if err == nil {
-					_, err = cache.Seek(0, io.SeekStart)
-				}
-			}
-			return cache, err
-		}
-		if up != nil {
-			up(100)
-		}
-		return cache, nil
-	}
-
-	var reader io.Reader = stream
-	if up != nil {
-		reader = &ReaderUpdatingProgress{
-			Reader:         stream,
-			UpdateProgress: up,
-		}
-	}
-
-	if w != nil {
-		reader = io.TeeReader(reader, w)
-	}
-	tmpF, err := utils.CreateTempFile(reader, stream.GetSize())
-	if err == nil {
-		stream.SetTmpFile(tmpF)
-	}
-	return tmpF, err
-}
-
-func CacheFullInTempFileAndHash(stream model.FileStreamer, up model.UpdateProgress, hashType *utils.HashType, hashParams ...any) (model.File, string, error) {
+func CacheFullAndHash(stream model.FileStreamer, up *model.UpdateProgress, hashType *utils.HashType, hashParams ...any) (model.File, string, error) {
 	h := hashType.NewFunc(hashParams...)
-	tmpF, err := CacheFullInTempFileAndWriter(stream, up, h)
+	tmpF, err := stream.CacheFullAndWriter(up, h)
 	if err != nil {
 		return nil, "", err
 	}
-	return tmpF, hex.EncodeToString(h.Sum(nil)), err
+	return tmpF, hex.EncodeToString(h.Sum(nil)), nil
 }
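
// A standalone sketch (illustrative names, not part of this patch) of the
// single-pass idea behind CacheFullAndHash above: while the stream is being
// cached, the same bytes are teed into a hash.Hash, so no second read of the
// source is needed.
package main

import (
	"bytes"
	"crypto/sha1"
	"encoding/hex"
	"fmt"
	"io"
	"strings"
)

func cacheAndHash(src io.Reader, h io.Writer) (*bytes.Reader, error) {
	var cache bytes.Buffer
	// Every byte copied into the cache also feeds the hasher.
	if _, err := io.Copy(&cache, io.TeeReader(src, h)); err != nil {
		return nil, err
	}
	return bytes.NewReader(cache.Bytes()), nil
}

func main() {
	h := sha1.New()
	cache, err := cacheAndHash(strings.NewReader("OpenList"), h)
	if err != nil {
		panic(err)
	}
	fmt.Println(hex.EncodeToString(h.Sum(nil)), cache.Size())
}
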
 
 type StreamSectionReader struct {
@@ -199,12 +156,12 @@ type StreamSectionReader struct {
 	bufPool *sync.Pool
 }
 
-func NewStreamSectionReader(file model.FileStreamer, maxBufferSize int) (*StreamSectionReader, error) {
+func NewStreamSectionReader(file model.FileStreamer, maxBufferSize int, up *model.UpdateProgress) (*StreamSectionReader, error) {
 	ss := &StreamSectionReader{file: file}
 	if file.GetFile() == nil {
 		maxBufferSize = min(maxBufferSize, int(file.GetSize()))
 		if maxBufferSize > conf.MaxBufferLimit {
-			_, err := file.CacheFullInTempFile()
+			_, err := file.CacheFullAndWriter(up, nil)
 			if err != nil {
 				return nil, err
 			}
@@ -240,7 +197,7 @@ func (ss *StreamSectionReader) GetSectionReader(off, length int64) (*SectionRead
 	return &SectionReader{io.NewSectionReader(cache, off, length), buf}, nil
 }
 
-func (ss *StreamSectionReader) RecycleSectionReader(sr *SectionReader) {
+func (ss *StreamSectionReader) FreeSectionReader(sr *SectionReader) {
 	if sr != nil {
 		if sr.buf != nil {
 			ss.bufPool.Put(sr.buf[0:cap(sr.buf)])
diff --git a/pkg/buffer/bytes.go b/pkg/buffer/bytes.go
new file mode 100644
index 00000000..3ee10747
--- /dev/null
+++ b/pkg/buffer/bytes.go
@@ -0,0 +1,92 @@
+package buffer
+
+import (
+	"errors"
+	"io"
+)
+
+// Reader stores []byte buffers that are appended once and never reused.
+type Reader struct {
+	bufs   [][]byte
+	length int
+	offset int
+}
+
+func (r *Reader) Len() int {
+	return r.length
+}
+
+func (r *Reader) Append(buf []byte) {
+	r.length += len(buf)
+	r.bufs = append(r.bufs, buf)
+}
+
+func (r *Reader) Read(p []byte) (int, error) {
+	n, err := r.ReadAt(p, int64(r.offset))
+	if n > 0 {
+		r.offset += n
+	}
+	return n, err
+}
+
+func (r *Reader) ReadAt(p []byte, off int64) (int, error) {
+	if off < 0 || off >= int64(r.length) {
+		return 0, io.EOF
+	}
+
+	n, length := 0, int64(0)
+	readFrom := false
+	for _, buf := range r.bufs {
+		newLength := length + int64(len(buf))
+		if readFrom {
+			w := copy(p[n:], buf)
+			n += w
+		} else if off < newLength {
+			readFrom = true
+			w := copy(p[n:], buf[int(off-length):])
+			n += w
+		}
+		if n == len(p) {
+			return n, nil
+		}
+		length = newLength
+	}
+
+	return n, io.EOF
+}
+
+func (r *Reader) Seek(offset int64, whence int) (int64, error) {
+	var abs int
+	switch whence {
+	case io.SeekStart:
+		abs = int(offset)
+	case io.SeekCurrent:
+		abs = r.offset + int(offset)
+	case io.SeekEnd:
+		abs = r.length + int(offset)
+	default:
+		return 0, errors.New("Seek: invalid whence")
+	}
+
+	if abs < 0 || abs > r.length {
+		return 0, errors.New("Seek: invalid offset")
+	}
+
+	r.offset = abs
+	return int64(abs), nil
+}
+
+func (r *Reader) Reset() {
+	clear(r.bufs)
+	r.bufs = nil
+	r.length = 0
+	r.offset = 0
+}
+
+func NewReader(buf ...[]byte) *Reader {
+	b := &Reader{}
+	for _, b1 := range buf {
+		b.Append(b1)
+	}
+	return b
+}
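
// A quick usage sketch of the buffer.Reader added above: scattered []byte
// chunks are exposed as one io.Reader/io.ReaderAt/io.Seeker without copying
// them into a contiguous buffer. Intended to run inside this repository; the
// printed values are illustrative.
package main

import (
	"fmt"
	"io"

	"github.com/OpenListTeam/OpenList/v4/pkg/buffer"
)

func main() {
	r := buffer.NewReader([]byte("github.com/"), []byte("OpenListTeam"))
	r.Append([]byte("/OpenList"))

	// ReadAt never moves the offset used by sequential Read.
	head := make([]byte, 10)
	if _, err := r.ReadAt(head, 0); err != nil {
		panic(err)
	}
	fmt.Println(string(head)) // github.com

	// Seek and Read work transparently across chunk boundaries.
	if _, err := r.Seek(11, io.SeekStart); err != nil {
		panic(err)
	}
	rest, _ := io.ReadAll(r)
	fmt.Println(string(rest)) // OpenListTeam/OpenList
}
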
diff --git a/pkg/buffer/bytes_test.go b/pkg/buffer/bytes_test.go
new file mode 100644
index 00000000..b66af229
--- /dev/null
+++ b/pkg/buffer/bytes_test.go
@@ -0,0 +1,95 @@
+package buffer
+
+import (
+	"errors"
+	"io"
+	"testing"
+)
+
+func TestReader_ReadAt(t *testing.T) {
+	type args struct {
+		p   []byte
+		off int64
+	}
+	bs := &Reader{}
+	bs.Append([]byte("github.com"))
+	bs.Append([]byte("/"))
+	bs.Append([]byte("OpenList"))
+	bs.Append([]byte("Team/"))
+	bs.Append([]byte("OpenList"))
+	tests := []struct {
+		name string
+		b    *Reader
+		args args
+		want func(a args, n int, err error) error
+	}{
+		{
+			name: "readAt len 10 offset 0",
+			b:    bs,
+			args: args{
+				p:   make([]byte, 10),
+				off: 0,
+			},
+			want: func(a args, n int, err error) error {
+				if n != len(a.p) {
+					return errors.New("read length mismatch")
+				}
+				if string(a.p) != "github.com" {
+					return errors.New("read content mismatch")
+				}
+				if err != nil {
+					return err
+				}
+				return nil
+			},
+		},
+		{
+			name: "readAt len 12 offset 11",
+			b:    bs,
+			args: args{
+				p:   make([]byte, 12),
+				off: 11,
+			},
+			want: func(a args, n int, err error) error {
+				if n != len(a.p) {
+					return errors.New("read length mismatch")
+				}
+				if string(a.p) != "OpenListTeam" {
+					return errors.New("read content mismatch")
+				}
+				if err != nil {
+					return err
+				}
+				return nil
+			},
+		},
+		{
+			name: "readAt len 50 offset 24",
+			b:    bs,
+			args: args{
+				p:   make([]byte, 50),
+				off: 24,
+			},
+			want: func(a args, n int, err error) error {
+				if n != bs.Len()-int(a.off) {
+					return errors.New("read length mismatch")
+				}
+				if string(a.p[:n]) != "OpenList" {
+					return errors.New("read content mismatch")
+				}
+				if err != io.EOF {
+					return errors.New("expected io.EOF")
+				}
+				return nil
+			},
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got, err := tt.b.ReadAt(tt.args.p, tt.args.off)
+			if err := tt.want(tt.args, got, err); err != nil {
+				t.Errorf("Reader.ReadAt() error = %v", err)
+			}
+		})
+	}
+}
diff --git a/pkg/errgroup/errgroup.go b/pkg/errgroup/errgroup.go
index daf1b315..d3c4feaf 100644
--- a/pkg/errgroup/errgroup.go
+++ b/pkg/errgroup/errgroup.go
@@ -53,11 +53,12 @@ func (g *Group) Go(do func(ctx context.Context) error) {
 }
 
 type Lifecycle struct {
-	// Before is thread-safe in OrderedGroup
+	// Before is thread-safe in OrderedGroup.
+	// It will only be called once.
 	Before func(ctx context.Context) error
 	// Do is not called if Before returns an error
 	Do func(ctx context.Context) error
-	// After is called last
+	// After is called once at the end
 	After func(err error)
 }
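
// A standalone sketch of the Lifecycle contract documented above: Before runs
// once, Do is skipped when Before fails, and After always runs once with the
// final error. The runner below is a hypothetical re-implementation for
// illustration, not the errgroup code.
package main

import (
	"context"
	"errors"
	"fmt"
)

type Lifecycle struct {
	Before func(ctx context.Context) error
	Do     func(ctx context.Context) error
	After  func(err error)
}

func run(ctx context.Context, l Lifecycle) {
	err := l.Before(ctx)
	if err == nil {
		err = l.Do(ctx)
	}
	// After sees the first error, whether it came from Before or Do.
	l.After(err)
}

func main() {
	run(context.Background(), Lifecycle{
		Before: func(ctx context.Context) error { fmt.Println("before: once"); return nil },
		Do:     func(ctx context.Context) error { fmt.Println("do"); return errors.New("boom") },
		After:  func(err error) { fmt.Println("after: once,", err) },
	})
}
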
diff --git a/pkg/utils/io.go b/pkg/utils/io.go
index dd0e3fac..172dc41c 100644
--- a/pkg/utils/io.go
+++ b/pkg/utils/io.go
@@ -194,32 +194,32 @@ type SyncClosersIF interface {
 
 type SyncClosers struct {
 	closers []io.Closer
-	ref     atomic.Int32
+	ref     int32
 }
 
 var _ SyncClosersIF = (*SyncClosers)(nil)
 
 func (c *SyncClosers) AcquireReference() bool {
-	ref := c.ref.Add(1)
+	ref := atomic.AddInt32(&c.ref, 1)
 	if ref > 0 {
 		// log.Debugf("SyncClosers.AcquireReference %p,ref=%d\n", c, ref)
 		return true
 	}
-	c.ref.Store(math.MinInt16)
+	atomic.StoreInt32(&c.ref, math.MinInt16)
 	return false
 }
 
 func (c *SyncClosers) Close() error {
-	ref := c.ref.Add(-1)
+	ref := atomic.AddInt32(&c.ref, -1)
 	if ref < -1 {
-		c.ref.Store(math.MinInt16)
+		atomic.StoreInt32(&c.ref, math.MinInt16)
 		return nil
 	}
 	// log.Debugf("SyncClosers.Close %p,ref=%d\n", c, ref+1)
 	if ref > 0 {
 		return nil
 	}
-	c.ref.Store(math.MinInt16)
+	atomic.StoreInt32(&c.ref, math.MinInt16)
 
 	var errs []error
 	for _, closer := range c.closers {
@@ -234,7 +234,7 @@ func (c *SyncClosers) Close() error {
 
 func (c *SyncClosers) Add(closer io.Closer) {
 	if closer != nil {
-		if c.ref.Load() < 0 {
+		if atomic.LoadInt32(&c.ref) < 0 {
 			panic("Not reusable")
 		}
 		c.closers = append(c.closers, closer)
@@ -243,7 +243,7 @@ func (c *SyncClosers) AddIfCloser(a any) {
 	if closer, ok := a.(io.Closer); ok {
-		if c.ref.Load() < 0 {
+		if atomic.LoadInt32(&c.ref) < 0 {
 			panic("Not reusable")
 		}
 		c.closers = append(c.closers, closer)