From 016ed90efa8c5438fdf191be9774b72504eef3b4 Mon Sep 17 00:00:00 2001 From: j2rong4cn <36783515+j2rong4cn@users.noreply.github.com> Date: Sat, 16 Aug 2025 17:19:52 +0800 Subject: [PATCH] feat(stream): fast buffer freeing for large cache (#1053) Signed-off-by: j2rong4cn <36783515+j2rong4cn@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- internal/bootstrap/config.go | 26 +++++-- internal/conf/config.go | 2 + internal/conf/var.go | 3 + internal/net/request.go | 136 +++++++++++++++++++++------------ internal/stream/stream.go | 25 ++++-- internal/stream/stream_test.go | 4 +- internal/stream/util.go | 56 ++++++++++---- pkg/pool/pool.go | 37 +++++++++ 8 files changed, 209 insertions(+), 80 deletions(-) create mode 100644 pkg/pool/pool.go diff --git a/internal/bootstrap/config.go b/internal/bootstrap/config.go index bff899e1..2209c64f 100644 --- a/internal/bootstrap/config.go +++ b/internal/bootstrap/config.go @@ -77,6 +77,10 @@ func InitConfig() { log.Fatalf("update config struct error: %+v", err) } } + if !conf.Conf.Force { + confFromEnv() + } + if conf.Conf.MaxConcurrency > 0 { net.DefaultConcurrencyLimit = &net.ConcurrencyLimit{Limit: conf.Conf.MaxConcurrency} } @@ -92,25 +96,31 @@ func InitConfig() { conf.MaxBufferLimit = conf.Conf.MaxBufferLimit * utils.MB } log.Infof("max buffer limit: %dMB", conf.MaxBufferLimit/utils.MB) - if !conf.Conf.Force { - confFromEnv() + if conf.Conf.MmapThreshold > 0 { + conf.MmapThreshold = conf.Conf.MmapThreshold * utils.MB + } else { + conf.MmapThreshold = 0 } + log.Infof("mmap threshold: %dMB", conf.Conf.MmapThreshold) + if len(conf.Conf.Log.Filter.Filters) == 0 { conf.Conf.Log.Filter.Enable = false } // convert abs path convertAbsPath := func(path *string) { - if !filepath.IsAbs(*path) { + if *path != "" && !filepath.IsAbs(*path) { *path = filepath.Join(pwd, *path) } } + convertAbsPath(&conf.Conf.Database.DBFile) + convertAbsPath(&conf.Conf.Scheme.CertFile) + convertAbsPath(&conf.Conf.Scheme.KeyFile) + convertAbsPath(&conf.Conf.Scheme.UnixFile) + convertAbsPath(&conf.Conf.Log.Name) convertAbsPath(&conf.Conf.TempDir) convertAbsPath(&conf.Conf.BleveDir) - convertAbsPath(&conf.Conf.Log.Name) - convertAbsPath(&conf.Conf.Database.DBFile) - if conf.Conf.DistDir != "" { - convertAbsPath(&conf.Conf.DistDir) - } + convertAbsPath(&conf.Conf.DistDir) + err := os.MkdirAll(conf.Conf.TempDir, 0o777) if err != nil { log.Fatalf("create temp dir error: %+v", err) diff --git a/internal/conf/config.go b/internal/conf/config.go index 72a8ee72..af198e91 100644 --- a/internal/conf/config.go +++ b/internal/conf/config.go @@ -120,6 +120,7 @@ type Config struct { Log LogConfig `json:"log" envPrefix:"LOG_"` DelayedStart int `json:"delayed_start" env:"DELAYED_START"` MaxBufferLimit int `json:"max_buffer_limitMB" env:"MAX_BUFFER_LIMIT_MB"` + MmapThreshold int `json:"mmap_thresholdMB" env:"MMAP_THRESHOLD_MB"` MaxConnections int `json:"max_connections" env:"MAX_CONNECTIONS"` MaxConcurrency int `json:"max_concurrency" env:"MAX_CONCURRENCY"` TlsInsecureSkipVerify bool `json:"tls_insecure_skip_verify" env:"TLS_INSECURE_SKIP_VERIFY"` @@ -176,6 +177,7 @@ func DefaultConfig(dataDir string) *Config { }, }, MaxBufferLimit: -1, + MmapThreshold: 4, MaxConnections: 0, MaxConcurrency: 64, TlsInsecureSkipVerify: true, diff --git a/internal/conf/var.go b/internal/conf/var.go index 83fb87e9..de23b5c6 100644 --- a/internal/conf/var.go +++ b/internal/conf/var.go @@ -25,7 +25,10 @@ var PrivacyReg []*regexp.Regexp var ( // StoragesLoaded loaded success if empty StoragesLoaded = false + // 单个Buffer最大限制 MaxBufferLimit = 16 * 1024 * 1024 + // 超过该阈值的Buffer将使用 mmap 分配,可主动释放内存 + MmapThreshold = 4 * 1024 * 1024 ) var ( RawIndexHtml string diff --git a/internal/net/request.go b/internal/net/request.go index 5e98f490..399e01f3 100644 --- a/internal/net/request.go +++ b/internal/net/request.go @@ -1,7 +1,6 @@ package net import ( - "bytes" "context" "errors" "fmt" @@ -15,6 +14,7 @@ import ( "github.com/OpenListTeam/OpenList/v4/internal/conf" "github.com/OpenListTeam/OpenList/v4/internal/model" "github.com/OpenListTeam/OpenList/v4/pkg/utils" + "github.com/rclone/rclone/lib/mmap" "github.com/OpenListTeam/OpenList/v4/pkg/http_range" "github.com/aws/aws-sdk-go/aws/awsutil" @@ -255,7 +255,10 @@ func (d *downloader) sendChunkTask(newConcurrency bool) error { finalSize += firstSize - minSize } } - buf.Reset(int(finalSize)) + err := buf.Reset(int(finalSize)) + if err != nil { + return err + } ch := chunk{ start: d.pos, size: finalSize, @@ -645,11 +648,13 @@ func (mr MultiReadCloser) Close() error { } type Buf struct { - buffer *bytes.Buffer - size int //expected size - ctx context.Context - off int - rw sync.Mutex + size int //expected size + ctx context.Context + offR int + offW int + rw sync.Mutex + buf []byte + mmap bool readSignal chan struct{} readPending bool @@ -658,76 +663,100 @@ type Buf struct { // NewBuf is a buffer that can have 1 read & 1 write at the same time. // when read is faster write, immediately feed data to read after written func NewBuf(ctx context.Context, maxSize int) *Buf { - return &Buf{ - ctx: ctx, - buffer: bytes.NewBuffer(make([]byte, 0, maxSize)), - size: maxSize, - + br := &Buf{ + ctx: ctx, + size: maxSize, readSignal: make(chan struct{}, 1), } -} -func (br *Buf) Reset(size int) { - br.rw.Lock() - defer br.rw.Unlock() - if br.buffer == nil { - return + if conf.MmapThreshold > 0 && maxSize >= conf.MmapThreshold { + m, err := mmap.Alloc(maxSize) + if err == nil { + br.buf = m + br.mmap = true + return br + } } - br.buffer.Reset() - br.size = size - br.off = 0 + br.buf = make([]byte, maxSize) + return br } -func (br *Buf) Read(p []byte) (n int, err error) { +func (br *Buf) Reset(size int) error { + br.rw.Lock() + defer br.rw.Unlock() + if br.buf == nil { + return io.ErrClosedPipe + } + if size > cap(br.buf) { + return fmt.Errorf("reset size %d exceeds max size %d", size, cap(br.buf)) + } + br.size = size + br.offR = 0 + br.offW = 0 + return nil +} + +func (br *Buf) Read(p []byte) (int, error) { if err := br.ctx.Err(); err != nil { return 0, err } if len(p) == 0 { return 0, nil } - if br.off >= br.size { + if br.offR >= br.size { return 0, io.EOF } for { br.rw.Lock() - if br.buffer != nil { - n, err = br.buffer.Read(p) - } else { - err = io.ErrClosedPipe - } - if err != nil && err != io.EOF { + if br.buf == nil { br.rw.Unlock() - return + return 0, io.ErrClosedPipe } - if n > 0 { - br.off += n + + if br.offW < br.offR { br.rw.Unlock() - return n, nil + return 0, io.ErrUnexpectedEOF } - br.readPending = true - br.rw.Unlock() - // n==0, err==io.EOF - select { - case <-br.ctx.Done(): - return 0, br.ctx.Err() - case _, ok := <-br.readSignal: - if !ok { - return 0, io.ErrClosedPipe + if br.offW == br.offR { + br.readPending = true + br.rw.Unlock() + select { + case <-br.ctx.Done(): + return 0, br.ctx.Err() + case _, ok := <-br.readSignal: + if !ok { + return 0, io.ErrClosedPipe + } + continue } - continue } + + n := copy(p, br.buf[br.offR:br.offW]) + br.offR += n + br.rw.Unlock() + if n < len(p) && br.offR >= br.size { + return n, io.EOF + } + return n, nil } } -func (br *Buf) Write(p []byte) (n int, err error) { +func (br *Buf) Write(p []byte) (int, error) { if err := br.ctx.Err(); err != nil { return 0, err } + if len(p) == 0 { + return 0, nil + } br.rw.Lock() defer br.rw.Unlock() - if br.buffer == nil { + if br.buf == nil { return 0, io.ErrClosedPipe } - n, err = br.buffer.Write(p) + if br.offW >= br.size { + return 0, io.ErrShortWrite + } + n := copy(br.buf[br.offW:], p[:min(br.size-br.offW, len(p))]) + br.offW += n if br.readPending { br.readPending = false select { @@ -735,12 +764,21 @@ func (br *Buf) Write(p []byte) (n int, err error) { default: } } - return + if n < len(p) { + return n, io.ErrShortWrite + } + return n, nil } -func (br *Buf) Close() { +func (br *Buf) Close() error { br.rw.Lock() defer br.rw.Unlock() - br.buffer = nil + var err error + if br.mmap { + err = mmap.Free(br.buf) + br.mmap = false + } + br.buf = nil close(br.readSignal) + return err } diff --git a/internal/stream/stream.go b/internal/stream/stream.go index 344a7759..94772761 100644 --- a/internal/stream/stream.go +++ b/internal/stream/stream.go @@ -15,6 +15,7 @@ import ( "github.com/OpenListTeam/OpenList/v4/pkg/buffer" "github.com/OpenListTeam/OpenList/v4/pkg/http_range" "github.com/OpenListTeam/OpenList/v4/pkg/utils" + "github.com/rclone/rclone/lib/mmap" "go4.org/readerutil" ) @@ -60,8 +61,12 @@ func (f *FileStream) IsForceStreamUpload() bool { } func (f *FileStream) Close() error { - var err1, err2 error + if f.peekBuff != nil { + f.peekBuff.Reset() + f.peekBuff = nil + } + var err1, err2 error err1 = f.Closers.Close() if errors.Is(err1, os.ErrClosed) { err1 = nil @@ -74,10 +79,6 @@ func (f *FileStream) Close() error { f.tmpFile = nil } } - if f.peekBuff != nil { - f.peekBuff.Reset() - f.peekBuff = nil - } return errors.Join(err1, err2) } @@ -194,7 +195,19 @@ func (f *FileStream) cache(maxCacheSize int64) (model.File, error) { f.oriReader = f.Reader } bufSize := maxCacheSize - int64(f.peekBuff.Len()) - buf := make([]byte, bufSize) + var buf []byte + if conf.MmapThreshold > 0 && bufSize >= int64(conf.MmapThreshold) { + m, err := mmap.Alloc(int(bufSize)) + if err == nil { + f.Add(utils.CloseFunc(func() error { + return mmap.Free(m) + })) + buf = m + } + } + if buf == nil { + buf = make([]byte, bufSize) + } n, err := io.ReadFull(f.oriReader, buf) if bufSize != int64(n) { return nil, fmt.Errorf("failed to read all data: (expect =%d, actual =%d) %w", bufSize, n, err) diff --git a/internal/stream/stream_test.go b/internal/stream/stream_test.go index 0c7412ff..52c2abee 100644 --- a/internal/stream/stream_test.go +++ b/internal/stream/stream_test.go @@ -7,11 +7,13 @@ import ( "io" "testing" + "github.com/OpenListTeam/OpenList/v4/internal/conf" "github.com/OpenListTeam/OpenList/v4/internal/model" "github.com/OpenListTeam/OpenList/v4/pkg/http_range" ) func TestFileStream_RangeRead(t *testing.T) { + conf.MaxBufferLimit = 16 * 1024 * 1024 type args struct { httpRange http_range.Range } @@ -71,7 +73,7 @@ func TestFileStream_RangeRead(t *testing.T) { } }) } - t.Run("after check", func(t *testing.T) { + t.Run("after", func(t *testing.T) { if f.GetFile() == nil { t.Error("not cached") } diff --git a/internal/stream/util.go b/internal/stream/util.go index 2df1963a..20cb4be0 100644 --- a/internal/stream/util.go +++ b/internal/stream/util.go @@ -8,13 +8,14 @@ import ( "fmt" "io" "net/http" - "sync" "github.com/OpenListTeam/OpenList/v4/internal/conf" "github.com/OpenListTeam/OpenList/v4/internal/model" "github.com/OpenListTeam/OpenList/v4/internal/net" "github.com/OpenListTeam/OpenList/v4/pkg/http_range" + "github.com/OpenListTeam/OpenList/v4/pkg/pool" "github.com/OpenListTeam/OpenList/v4/pkg/utils" + "github.com/rclone/rclone/lib/mmap" log "github.com/sirupsen/logrus" ) @@ -153,26 +154,49 @@ func CacheFullAndHash(stream model.FileStreamer, up *model.UpdateProgress, hashT type StreamSectionReader struct { file model.FileStreamer off int64 - bufPool *sync.Pool + bufPool *pool.Pool[[]byte] } func NewStreamSectionReader(file model.FileStreamer, maxBufferSize int, up *model.UpdateProgress) (*StreamSectionReader, error) { ss := &StreamSectionReader{file: file} - if file.GetFile() == nil { - maxBufferSize = min(maxBufferSize, int(file.GetSize())) - if maxBufferSize > conf.MaxBufferLimit { - _, err := file.CacheFullAndWriter(up, nil) - if err != nil { - return nil, err - } - } else { - ss.bufPool = &sync.Pool{ - New: func() any { - return make([]byte, maxBufferSize) - }, - } + if file.GetFile() != nil { + return ss, nil + } + + maxBufferSize = min(maxBufferSize, int(file.GetSize())) + if maxBufferSize > conf.MaxBufferLimit { + _, err := file.CacheFullAndWriter(up, nil) + if err != nil { + return nil, err + } + return ss, nil + } + if conf.MmapThreshold > 0 && maxBufferSize >= conf.MmapThreshold { + ss.bufPool = &pool.Pool[[]byte]{ + New: func() []byte { + buf, err := mmap.Alloc(maxBufferSize) + if err == nil { + file.Add(utils.CloseFunc(func() error { + return mmap.Free(buf) + })) + } else { + buf = make([]byte, maxBufferSize) + } + return buf + }, + } + } else { + ss.bufPool = &pool.Pool[[]byte]{ + New: func() []byte { + return make([]byte, maxBufferSize) + }, } } + + file.Add(utils.CloseFunc(func() error { + ss.bufPool.Reset() + return nil + })) return ss, nil } @@ -184,7 +208,7 @@ func (ss *StreamSectionReader) GetSectionReader(off, length int64) (*SectionRead if off != ss.off { return nil, fmt.Errorf("stream not cached: request offset %d != current offset %d", off, ss.off) } - tempBuf := ss.bufPool.Get().([]byte) + tempBuf := ss.bufPool.Get() buf = tempBuf[:length] n, err := io.ReadFull(ss.file, buf) if int64(n) != length { diff --git a/pkg/pool/pool.go b/pkg/pool/pool.go new file mode 100644 index 00000000..ce92cd1f --- /dev/null +++ b/pkg/pool/pool.go @@ -0,0 +1,37 @@ +package pool + +import "sync" + +type Pool[T any] struct { + New func() T + MaxCap int + + cache []T + mu sync.Mutex +} + +func (p *Pool[T]) Get() T { + p.mu.Lock() + defer p.mu.Unlock() + if len(p.cache) == 0 { + return p.New() + } + item := p.cache[len(p.cache)-1] + p.cache = p.cache[:len(p.cache)-1] + return item +} + +func (p *Pool[T]) Put(item T) { + p.mu.Lock() + defer p.mu.Unlock() + if p.MaxCap == 0 || len(p.cache) < int(p.MaxCap) { + p.cache = append(p.cache, item) + } +} + +func (p *Pool[T]) Reset() { + p.mu.Lock() + defer p.mu.Unlock() + clear(p.cache) + p.cache = nil +}