From 3936e736e6075538b9d6b6d622c04e06bbec920e Mon Sep 17 00:00:00 2001 From: KirCute <951206789@qq.com> Date: Fri, 19 Sep 2025 19:27:35 +0800 Subject: [PATCH] feat(drivers): add a driver that divides large files into multiple chunks (#1153) --- drivers/all.go | 1 + drivers/chunk/driver.go | 488 ++++++++++++++++++++++++++++++++++++++++ drivers/chunk/meta.go | 31 +++ drivers/chunk/obj.go | 8 + internal/net/request.go | 4 +- internal/net/serve.go | 12 +- internal/op/archive.go | 5 +- internal/op/fs.go | 8 +- internal/stream/util.go | 2 +- pkg/utils/hash.go | 5 + pkg/utils/io.go | 41 ++-- server/handles/down.go | 2 +- server/webdav/webdav.go | 2 +- 13 files changed, 575 insertions(+), 34 deletions(-) create mode 100644 drivers/chunk/driver.go create mode 100644 drivers/chunk/meta.go create mode 100644 drivers/chunk/obj.go diff --git a/drivers/all.go b/drivers/all.go index ce614735..197a936d 100644 --- a/drivers/all.go +++ b/drivers/all.go @@ -20,6 +20,7 @@ import ( _ "github.com/OpenListTeam/OpenList/v4/drivers/baidu_netdisk" _ "github.com/OpenListTeam/OpenList/v4/drivers/baidu_photo" _ "github.com/OpenListTeam/OpenList/v4/drivers/chaoxing" + _ "github.com/OpenListTeam/OpenList/v4/drivers/chunk" _ "github.com/OpenListTeam/OpenList/v4/drivers/cloudreve" _ "github.com/OpenListTeam/OpenList/v4/drivers/cloudreve_v4" _ "github.com/OpenListTeam/OpenList/v4/drivers/cnb_releases" diff --git a/drivers/chunk/driver.go b/drivers/chunk/driver.go new file mode 100644 index 00000000..76346974 --- /dev/null +++ b/drivers/chunk/driver.go @@ -0,0 +1,488 @@ +package chunk + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + stdpath "path" + "strconv" + "strings" + + "github.com/OpenListTeam/OpenList/v4/internal/driver" + "github.com/OpenListTeam/OpenList/v4/internal/errs" + "github.com/OpenListTeam/OpenList/v4/internal/fs" + "github.com/OpenListTeam/OpenList/v4/internal/model" + "github.com/OpenListTeam/OpenList/v4/internal/op" + "github.com/OpenListTeam/OpenList/v4/internal/sign" + "github.com/OpenListTeam/OpenList/v4/internal/stream" + "github.com/OpenListTeam/OpenList/v4/pkg/http_range" + "github.com/OpenListTeam/OpenList/v4/pkg/utils" + "github.com/OpenListTeam/OpenList/v4/server/common" +) + +type Chunk struct { + model.Storage + Addition +} + +func (d *Chunk) Config() driver.Config { + return config +} + +func (d *Chunk) GetAddition() driver.Additional { + return &d.Addition +} + +func (d *Chunk) Init(ctx context.Context) error { + if d.PartSize <= 0 { + return errors.New("part size must be positive") + } + d.RemotePath = utils.FixAndCleanPath(d.RemotePath) + return nil +} + +func (d *Chunk) Drop(ctx context.Context) error { + return nil +} + +func (d *Chunk) Get(ctx context.Context, path string) (model.Obj, error) { + if utils.PathEqual(path, "/") { + return &model.Object{ + Name: "Root", + IsFolder: true, + Path: "/", + }, nil + } + remoteStorage, remoteActualPath, err := op.GetStorageAndActualPath(d.RemotePath) + if err != nil { + return nil, err + } + remoteActualPath = stdpath.Join(remoteActualPath, path) + if remoteObj, err := op.Get(ctx, remoteStorage, remoteActualPath); err == nil { + return &model.Object{ + Path: path, + Name: remoteObj.GetName(), + Size: remoteObj.GetSize(), + Modified: remoteObj.ModTime(), + IsFolder: remoteObj.IsDir(), + HashInfo: remoteObj.GetHash(), + }, nil + } + + remoteActualDir, name := stdpath.Split(remoteActualPath) + chunkName := "[openlist_chunk]" + name + chunkObjs, err := op.List(ctx, remoteStorage, stdpath.Join(remoteActualDir, chunkName), model.ListArgs{}) + if err != nil { + return nil, err + } + var totalSize int64 = 0 + // 0号块必须存在 + chunkSizes := []int64{-1} + h := make(map[*utils.HashType]string) + var first model.Obj + for _, o := range chunkObjs { + if o.IsDir() { + continue + } + if after, ok := strings.CutPrefix(o.GetName(), "hash_"); ok { + hn, value, ok := strings.Cut(strings.TrimSuffix(after, d.CustomExt), "_") + if ok { + ht, ok := utils.GetHashByName(hn) + if ok { + h[ht] = value + } + } + continue + } + idx, err := strconv.Atoi(strings.TrimSuffix(o.GetName(), d.CustomExt)) + if err != nil { + continue + } + totalSize += o.GetSize() + if len(chunkSizes) > idx { + if idx == 0 { + first = o + } + chunkSizes[idx] = o.GetSize() + } else if len(chunkSizes) == idx { + chunkSizes = append(chunkSizes, o.GetSize()) + } else { + newChunkSizes := make([]int64, idx+1) + copy(newChunkSizes, chunkSizes) + chunkSizes = newChunkSizes + chunkSizes[idx] = o.GetSize() + } + } + // 检查0号块不等于-1 以支持空文件 + // 如果块数量大于1 最后一块不可能为0 + // 只检查中间块是否有0 + for i, l := 0, len(chunkSizes)-2; ; i++ { + if i == 0 { + if chunkSizes[i] == -1 { + return nil, fmt.Errorf("chunk part[%d] are missing", i) + } + } else if chunkSizes[i] == 0 { + return nil, fmt.Errorf("chunk part[%d] are missing", i) + } + if i >= l { + break + } + } + reqDir, _ := stdpath.Split(path) + objRes := chunkObject{ + Object: model.Object{ + Path: stdpath.Join(reqDir, chunkName), + Name: name, + Size: totalSize, + Modified: first.ModTime(), + Ctime: first.CreateTime(), + }, + chunkSizes: chunkSizes, + } + if len(h) > 0 { + objRes.HashInfo = utils.NewHashInfoByMap(h) + } + return &objRes, nil +} + +func (d *Chunk) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + remoteStorage, remoteActualPath, err := op.GetStorageAndActualPath(d.RemotePath) + if err != nil { + return nil, err + } + remoteActualDir := stdpath.Join(remoteActualPath, dir.GetPath()) + remoteObjs, err := op.List(ctx, remoteStorage, remoteActualDir, model.ListArgs{ + ReqPath: args.ReqPath, + Refresh: args.Refresh, + }) + if err != nil { + return nil, err + } + result := make([]model.Obj, 0, len(remoteObjs)) + for _, obj := range remoteObjs { + rawName := obj.GetName() + if obj.IsDir() { + if name, ok := strings.CutPrefix(rawName, "[openlist_chunk]"); ok { + chunkObjs, err := op.List(ctx, remoteStorage, stdpath.Join(remoteActualDir, rawName), model.ListArgs{ + ReqPath: stdpath.Join(args.ReqPath, rawName), + Refresh: args.Refresh, + }) + if err != nil { + return nil, err + } + totalSize := int64(0) + h := make(map[*utils.HashType]string) + first := obj + for _, o := range chunkObjs { + if o.IsDir() { + continue + } + if after, ok := strings.CutPrefix(strings.TrimSuffix(o.GetName(), d.CustomExt), "hash_"); ok { + hn, value, ok := strings.Cut(after, "_") + if ok { + ht, ok := utils.GetHashByName(hn) + if ok { + h[ht] = value + } + continue + } + } + idx, err := strconv.Atoi(strings.TrimSuffix(o.GetName(), d.CustomExt)) + if err != nil { + continue + } + if idx == 0 { + first = o + } + totalSize += o.GetSize() + } + objRes := model.Object{ + Name: name, + Size: totalSize, + Modified: first.ModTime(), + Ctime: first.CreateTime(), + } + if len(h) > 0 { + objRes.HashInfo = utils.NewHashInfoByMap(h) + } + if !d.Thumbnail { + result = append(result, &objRes) + } else { + thumbPath := stdpath.Join(args.ReqPath, ".thumbnails", name+".webp") + thumb := fmt.Sprintf("%s/d%s?sign=%s", + common.GetApiUrl(ctx), + utils.EncodePath(thumbPath, true), + sign.Sign(thumbPath)) + result = append(result, &model.ObjThumb{ + Object: objRes, + Thumbnail: model.Thumbnail{ + Thumbnail: thumb, + }, + }) + } + continue + } + } + + if !d.ShowHidden && strings.HasPrefix(rawName, ".") { + continue + } + thumb, ok := model.GetThumb(obj) + objRes := model.Object{ + Name: rawName, + Size: obj.GetSize(), + Modified: obj.ModTime(), + IsFolder: obj.IsDir(), + HashInfo: obj.GetHash(), + } + if !ok { + result = append(result, &objRes) + } else { + result = append(result, &model.ObjThumb{ + Object: objRes, + Thumbnail: model.Thumbnail{ + Thumbnail: thumb, + }, + }) + } + } + return result, nil +} + +func (d *Chunk) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + remoteStorage, remoteActualPath, err := op.GetStorageAndActualPath(d.RemotePath) + if err != nil { + return nil, err + } + chunkFile, ok := file.(*chunkObject) + remoteActualPath = stdpath.Join(remoteActualPath, file.GetPath()) + if !ok { + l, _, err := op.Link(ctx, remoteStorage, remoteActualPath, args) + if err != nil { + return nil, err + } + resultLink := *l + resultLink.SyncClosers = utils.NewSyncClosers(l) + return &resultLink, nil + } + fileSize := chunkFile.GetSize() + mergedRrf := func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) { + start := httpRange.Start + length := httpRange.Length + if length < 0 || start+length > fileSize { + length = fileSize - start + } + if length == 0 { + return io.NopCloser(strings.NewReader("")), nil + } + rs := make([]io.Reader, 0) + cs := make(utils.Closers, 0) + var ( + rc io.ReadCloser + readFrom bool + ) + for idx, chunkSize := range chunkFile.chunkSizes { + if readFrom { + l, o, err := op.Link(ctx, remoteStorage, stdpath.Join(remoteActualPath, d.getPartName(idx)), args) + if err != nil { + _ = cs.Close() + return nil, err + } + cs = append(cs, l) + chunkSize2 := l.ContentLength + if chunkSize2 <= 0 { + chunkSize2 = o.GetSize() + } + if chunkSize2 != chunkSize { + _ = cs.Close() + return nil, fmt.Errorf("chunk part[%d] size not match", idx) + } + rrf, err := stream.GetRangeReaderFromLink(chunkSize2, l) + if err != nil { + _ = cs.Close() + return nil, err + } + newLength := length - chunkSize2 + if newLength >= 0 { + length = newLength + rc, err = rrf.RangeRead(ctx, http_range.Range{Length: -1}) + } else { + rc, err = rrf.RangeRead(ctx, http_range.Range{Length: length}) + } + if err != nil { + _ = cs.Close() + return nil, err + } + rs = append(rs, rc) + cs = append(cs, rc) + if newLength <= 0 { + return utils.ReadCloser{ + Reader: io.MultiReader(rs...), + Closer: &cs, + }, nil + } + } else if newStart := start - chunkSize; newStart >= 0 { + start = newStart + } else { + l, o, err := op.Link(ctx, remoteStorage, stdpath.Join(remoteActualPath, d.getPartName(idx)), args) + if err != nil { + _ = cs.Close() + return nil, err + } + cs = append(cs, l) + chunkSize2 := l.ContentLength + if chunkSize2 <= 0 { + chunkSize2 = o.GetSize() + } + if chunkSize2 != chunkSize { + _ = cs.Close() + return nil, fmt.Errorf("chunk part[%d] size not match", idx) + } + rrf, err := stream.GetRangeReaderFromLink(chunkSize2, l) + if err != nil { + _ = cs.Close() + return nil, err + } + rc, err = rrf.RangeRead(ctx, http_range.Range{Start: start, Length: -1}) + if err != nil { + _ = cs.Close() + return nil, err + } + length -= chunkSize2 - start + cs = append(cs, rc) + if length <= 0 { + return utils.ReadCloser{ + Reader: rc, + Closer: &cs, + }, nil + } + rs = append(rs, rc) + readFrom = true + } + } + return nil, fmt.Errorf("invalid range: start=%d,length=%d,fileSize=%d", httpRange.Start, httpRange.Length, fileSize) + } + return &model.Link{ + RangeReader: stream.RangeReaderFunc(mergedRrf), + }, nil +} + +func (d *Chunk) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { + path := stdpath.Join(d.RemotePath, parentDir.GetPath(), dirName) + return fs.MakeDir(ctx, path) +} + +func (d *Chunk) Move(ctx context.Context, srcObj, dstDir model.Obj) error { + src := stdpath.Join(d.RemotePath, srcObj.GetPath()) + dst := stdpath.Join(d.RemotePath, dstDir.GetPath()) + _, err := fs.Move(ctx, src, dst) + return err +} + +func (d *Chunk) Rename(ctx context.Context, srcObj model.Obj, newName string) error { + if _, ok := srcObj.(*chunkObject); ok { + newName = "[openlist_chunk]" + newName + } + return fs.Rename(ctx, stdpath.Join(d.RemotePath, srcObj.GetPath()), newName) +} + +func (d *Chunk) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { + dst := stdpath.Join(d.RemotePath, dstDir.GetPath()) + src := stdpath.Join(d.RemotePath, srcObj.GetPath()) + _, err := fs.Copy(ctx, src, dst) + return err +} + +func (d *Chunk) Remove(ctx context.Context, obj model.Obj) error { + return fs.Remove(ctx, stdpath.Join(d.RemotePath, obj.GetPath())) +} + +func (d *Chunk) Put(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress) error { + remoteStorage, remoteActualPath, err := op.GetStorageAndActualPath(d.RemotePath) + if err != nil { + return err + } + if d.Thumbnail && dstDir.GetName() == ".thumbnails" { + return op.Put(ctx, remoteStorage, stdpath.Join(remoteActualPath, dstDir.GetPath()), file, up) + } + upReader := &driver.ReaderUpdatingProgress{ + Reader: file, + UpdateProgress: up, + } + dst := stdpath.Join(remoteActualPath, dstDir.GetPath(), "[openlist_chunk]"+file.GetName()) + if d.StoreHash { + for ht, value := range file.GetHash().All() { + _ = op.Put(ctx, remoteStorage, dst, &stream.FileStream{ + Obj: &model.Object{ + Name: fmt.Sprintf("hash_%s_%s%s", ht.Name, value, d.CustomExt), + Size: 1, + Modified: file.ModTime(), + }, + Mimetype: "application/octet-stream", + Reader: bytes.NewReader([]byte{0}), // 兼容不支持空文件的驱动 + }, nil, true) + } + } + fullPartCount := int(file.GetSize() / d.PartSize) + tailSize := file.GetSize() % d.PartSize + if tailSize == 0 && fullPartCount > 0 { + fullPartCount-- + tailSize = d.PartSize + } + partIndex := 0 + for partIndex < fullPartCount { + err = op.Put(ctx, remoteStorage, dst, &stream.FileStream{ + Obj: &model.Object{ + Name: d.getPartName(partIndex), + Size: d.PartSize, + Modified: file.ModTime(), + }, + Mimetype: file.GetMimetype(), + Reader: io.LimitReader(upReader, d.PartSize), + }, nil, true) + if err != nil { + _ = op.Remove(ctx, remoteStorage, dst) + return err + } + partIndex++ + } + err = op.Put(ctx, remoteStorage, dst, &stream.FileStream{ + Obj: &model.Object{ + Name: d.getPartName(fullPartCount), + Size: tailSize, + Modified: file.ModTime(), + }, + Mimetype: file.GetMimetype(), + Reader: upReader, + }, nil) + if err != nil { + _ = op.Remove(ctx, remoteStorage, dst) + } + return err +} + +func (d *Chunk) getPartName(part int) string { + return fmt.Sprintf("%d%s", part, d.CustomExt) +} + +func (d *Chunk) GetDetails(ctx context.Context) (*model.StorageDetails, error) { + remoteStorage, err := fs.GetStorage(d.RemotePath, &fs.GetStoragesArgs{}) + if err != nil { + return nil, errs.NotImplement + } + wd, ok := remoteStorage.(driver.WithDetails) + if !ok { + return nil, errs.NotImplement + } + remoteDetails, err := wd.GetDetails(ctx) + if err != nil { + return nil, err + } + return &model.StorageDetails{ + DiskUsage: remoteDetails.DiskUsage, + }, nil +} + +var _ driver.Driver = (*Chunk)(nil) diff --git a/drivers/chunk/meta.go b/drivers/chunk/meta.go new file mode 100644 index 00000000..45429231 --- /dev/null +++ b/drivers/chunk/meta.go @@ -0,0 +1,31 @@ +package chunk + +import ( + "github.com/OpenListTeam/OpenList/v4/internal/driver" + "github.com/OpenListTeam/OpenList/v4/internal/op" +) + +type Addition struct { + RemotePath string `json:"remote_path" required:"true"` + PartSize int64 `json:"part_size" required:"true" type:"number" help:"bytes"` + CustomExt string `json:"custom_ext" type:"string"` + StoreHash bool `json:"store_hash" type:"bool" default:"true"` + + Thumbnail bool `json:"thumbnail" required:"true" default:"false" help:"enable thumbnail which pre-generated under .thumbnails folder"` + ShowHidden bool `json:"show_hidden" default:"true" required:"false" help:"show hidden directories and files"` +} + +var config = driver.Config{ + Name: "Chunk", + LocalSort: true, + OnlyProxy: true, + NoCache: true, + DefaultRoot: "/", + NoLinkURL: true, +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &Chunk{} + }) +} diff --git a/drivers/chunk/obj.go b/drivers/chunk/obj.go new file mode 100644 index 00000000..1885a925 --- /dev/null +++ b/drivers/chunk/obj.go @@ -0,0 +1,8 @@ +package chunk + +import "github.com/OpenListTeam/OpenList/v4/internal/model" + +type chunkObject struct { + model.Object + chunkSizes []int64 +} diff --git a/internal/net/request.go b/internal/net/request.go index 399e01f3..1306bc54 100644 --- a/internal/net/request.go +++ b/internal/net/request.go @@ -125,7 +125,7 @@ type ConcurrencyLimit struct { Limit int // 需要大于0 } -var ErrExceedMaxConcurrency = ErrorHttpStatusCode(http.StatusTooManyRequests) +var ErrExceedMaxConcurrency = HttpStatusCodeError(http.StatusTooManyRequests) func (l *ConcurrencyLimit) sub() error { l._m.Lock() @@ -403,7 +403,7 @@ var errInfiniteRetry = errors.New("infinite retry") func (d *downloader) tryDownloadChunk(params *HttpRequestParams, ch *chunk) (int64, error) { resp, err := d.cfg.HttpClient(d.ctx, params) if err != nil { - statusCode, ok := errors.Unwrap(err).(ErrorHttpStatusCode) + statusCode, ok := errors.Unwrap(err).(HttpStatusCodeError) if !ok { return 0, err } diff --git a/internal/net/serve.go b/internal/net/serve.go index 1fd40b1c..6ffe4120 100644 --- a/internal/net/serve.go +++ b/internal/net/serve.go @@ -114,7 +114,7 @@ func ServeHTTP(w http.ResponseWriter, r *http.Request, name string, modTime time reader, err := RangeReadCloser.RangeRead(ctx, http_range.Range{Length: -1}) if err != nil { code = http.StatusRequestedRangeNotSatisfiable - if statusCode, ok := errors.Unwrap(err).(ErrorHttpStatusCode); ok { + if statusCode, ok := errors.Unwrap(err).(HttpStatusCodeError); ok { code = int(statusCode) } http.Error(w, err.Error(), code) @@ -137,7 +137,7 @@ func ServeHTTP(w http.ResponseWriter, r *http.Request, name string, modTime time sendContent, err = RangeReadCloser.RangeRead(ctx, ra) if err != nil { code = http.StatusRequestedRangeNotSatisfiable - if statusCode, ok := errors.Unwrap(err).(ErrorHttpStatusCode); ok { + if statusCode, ok := errors.Unwrap(err).(HttpStatusCodeError); ok { code = int(statusCode) } http.Error(w, err.Error(), code) @@ -199,7 +199,7 @@ func ServeHTTP(w http.ResponseWriter, r *http.Request, name string, modTime time log.Warnf("Maybe size incorrect or reader not giving correct/full data, or connection closed before finish. written bytes: %d ,sendSize:%d, ", written, sendSize) } code = http.StatusInternalServerError - if statusCode, ok := errors.Unwrap(err).(ErrorHttpStatusCode); ok { + if statusCode, ok := errors.Unwrap(err).(HttpStatusCodeError); ok { code = int(statusCode) } w.WriteHeader(code) @@ -253,14 +253,14 @@ func RequestHttp(ctx context.Context, httpMethod string, headerOverride http.Hea _ = res.Body.Close() msg := string(all) log.Debugln(msg) - return nil, fmt.Errorf("http request [%s] failure,status: %w response:%s", URL, ErrorHttpStatusCode(res.StatusCode), msg) + return nil, fmt.Errorf("http request [%s] failure,status: %w response:%s", URL, HttpStatusCodeError(res.StatusCode), msg) } return res, nil } -type ErrorHttpStatusCode int +type HttpStatusCodeError int -func (e ErrorHttpStatusCode) Error() string { +func (e HttpStatusCodeError) Error() string { return fmt.Sprintf("%d|%s", e, http.StatusText(int(e))) } diff --git a/internal/op/archive.go b/internal/op/archive.go index 964e9397..4d85d206 100644 --- a/internal/op/archive.go +++ b/internal/op/archive.go @@ -405,11 +405,8 @@ func DriverExtract(ctx context.Context, storage driver.Driver, path string, args return nil }) link, err, _ := extractG.Do(key, fn) - if err == nil && !link.AcquireReference() { + for err == nil && !link.AcquireReference() { link, err, _ = extractG.Do(key, fn) - if err == nil { - link.AcquireReference() - } } if err == errLinkMFileCache { if linkM != nil { diff --git a/internal/op/fs.go b/internal/op/fs.go index c5a5b52d..2f3be94b 100644 --- a/internal/op/fs.go +++ b/internal/op/fs.go @@ -184,6 +184,9 @@ func Get(ctx context.Context, storage driver.Driver, path string) (model.Obj, er if err == nil { return model.WrapObjName(obj), nil } + if !errs.IsNotImplement(err) { + return nil, errors.WithMessage(err, "failed to get obj") + } } // is root folder @@ -327,11 +330,8 @@ func Link(ctx context.Context, storage driver.Driver, path string, args model.Li return nil }) link, err, _ := linkG.Do(key, fn) - if err == nil && !link.AcquireReference() { + for err == nil && !link.AcquireReference() { link, err, _ = linkG.Do(key, fn) - if err == nil { - link.AcquireReference() - } } if err == errLinkMFileCache { diff --git a/internal/stream/util.go b/internal/stream/util.go index 20cb4be0..4f51a46d 100644 --- a/internal/stream/util.go +++ b/internal/stream/util.go @@ -77,7 +77,7 @@ func GetRangeReaderFromLink(size int64, link *model.Link) (model.RangeReaderIF, response, err := net.RequestHttp(ctx, "GET", header, link.URL) if err != nil { - if _, ok := errors.Unwrap(err).(net.ErrorHttpStatusCode); ok { + if _, ok := errors.Unwrap(err).(net.HttpStatusCodeError); ok { return nil, err } return nil, fmt.Errorf("http request failure, err:%w", err) diff --git a/pkg/utils/hash.go b/pkg/utils/hash.go index 0b70e4e1..596e61e5 100644 --- a/pkg/utils/hash.go +++ b/pkg/utils/hash.go @@ -57,6 +57,11 @@ var ( Supported []*HashType ) +func GetHashByName(name string) (ht *HashType, ok bool) { + ht, ok = name2hash[name] + return +} + // RegisterHash adds a new Hash to the list and returns its Type func RegisterHash(name, alias string, width int, newFunc func() hash.Hash) *HashType { return RegisterHashWithParam(name, alias, width, func(a ...any) hash.Hash { return newFunc() }) diff --git a/pkg/utils/io.go b/pkg/utils/io.go index 172dc41c..7ce6a912 100644 --- a/pkg/utils/io.go +++ b/pkg/utils/io.go @@ -200,26 +200,37 @@ type SyncClosers struct { var _ SyncClosersIF = (*SyncClosers)(nil) func (c *SyncClosers) AcquireReference() bool { - ref := atomic.AddInt32(&c.ref, 1) - if ref > 0 { - // log.Debugf("SyncClosers.AcquireReference %p,ref=%d\n", c, ref) - return true + for { + ref := atomic.LoadInt32(&c.ref) + if ref < 0 { + return false + } + newRef := ref + 1 + if atomic.CompareAndSwapInt32(&c.ref, ref, newRef) { + log.Debugf("AcquireReference %p: %d", c, newRef) + return true + } } - atomic.StoreInt32(&c.ref, math.MinInt16) - return false } func (c *SyncClosers) Close() error { - ref := atomic.AddInt32(&c.ref, -1) - if ref < -1 { - atomic.StoreInt32(&c.ref, math.MinInt16) - return nil + for { + ref := atomic.LoadInt32(&c.ref) + if ref < 0 { + return nil + } + newRef := ref - 1 + if newRef <= 0 { + newRef = math.MinInt16 + } + if atomic.CompareAndSwapInt32(&c.ref, ref, newRef) { + log.Debugf("Close %p: %d", c, ref) + if newRef > 0 { + return nil + } + break + } } - // log.Debugf("SyncClosers.Close %p,ref=%d\n", c, ref+1) - if ref > 0 { - return nil - } - atomic.StoreInt32(&c.ref, math.MinInt16) var errs []error for _, closer := range c.closers { diff --git a/server/handles/down.go b/server/handles/down.go index 62008c06..84ebdc44 100644 --- a/server/handles/down.go +++ b/server/handles/down.go @@ -147,7 +147,7 @@ func proxy(c *gin.Context, link *model.Link, file model.Obj, proxyRange bool) { if Writer.IsWritten() { log.Errorf("%s %s local proxy error: %+v", c.Request.Method, c.Request.URL.Path, err) } else { - if statusCode, ok := errors.Unwrap(err).(net.ErrorHttpStatusCode); ok { + if statusCode, ok := errors.Unwrap(err).(net.HttpStatusCodeError); ok { common.ErrorPage(c, err, int(statusCode), true) } else { common.ErrorPage(c, err, 500, true) diff --git a/server/webdav/webdav.go b/server/webdav/webdav.go index 802947eb..0c4f0922 100644 --- a/server/webdav/webdav.go +++ b/server/webdav/webdav.go @@ -272,7 +272,7 @@ func (h *Handler) handleGetHeadPost(w http.ResponseWriter, r *http.Request) (sta } err = common.Proxy(w, r, link, fi) if err != nil { - if statusCode, ok := errors.Unwrap(err).(net.ErrorHttpStatusCode); ok { + if statusCode, ok := errors.Unwrap(err).(net.HttpStatusCodeError); ok { return int(statusCode), err } return http.StatusInternalServerError, fmt.Errorf("webdav proxy error: %+v", err)