perf(stream): optimize CacheFullAndWriter for better memory management (#1584)

* perf(stream): optimize CacheFullAndWriter for better memory management

* fix(stream): ensure proper seek handling in CacheFullAndWriter for improved data integrity
j2rong4cn, 2025-11-05 12:16:09 +08:00, committed by GitHub
parent b9f058fcc9, commit 174eae802a
3 changed files with 108 additions and 68 deletions

File 1 of 3

@@ -48,7 +48,6 @@ type FileStreamer interface {
     // for a non-seekable Stream, if Read is called, this function won't work.
     // caches the full Stream and writes it to writer (if provided, even if the stream is already cached).
     CacheFullAndWriter(up *UpdateProgress, writer io.Writer) (File, error)
-    SetTmpFile(file File)
     // if the Stream is not a File and is not cached, returns nil.
     GetFile() File
 }
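
The removal of SetTmpFile makes CacheFullAndWriter the single entry point for materializing a stream: callers receive a seekable File back instead of installing a temp file themselves. A standalone sketch of that contract, reduced to stdlib types (cacheFull is illustrative, not an OpenList API):

package main

import (
	"bytes"
	"fmt"
	"io"
	"strings"
)

// cacheFull drains src into memory and returns a seekable handle,
// mirroring every byte into w when w is non-nil (the same contract
// CacheFullAndWriter exposes, reduced to stdlib types).
func cacheFull(src io.Reader, w io.Writer) (io.ReadSeeker, error) {
	if w != nil {
		src = io.TeeReader(src, w)
	}
	data, err := io.ReadAll(src)
	if err != nil {
		return nil, err
	}
	return bytes.NewReader(data), nil
}

func main() {
	var mirror bytes.Buffer
	cache, err := cacheFull(strings.NewReader("payload"), &mirror)
	if err != nil {
		panic(err)
	}
	cache.Seek(0, io.SeekStart) // rewind and reuse, e.g. for a retried upload
	b, _ := io.ReadAll(cache)
	fmt.Println(string(b), mirror.String()) // payload payload
}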

File 2 of 3

@@ -28,8 +28,6 @@ type FileStream struct {
     ForceStreamUpload bool
     Exist             model.Obj // the file existing at the destination; we can reuse some info since we will overwrite it
     utils.Closers
-    tmpFile   model.File // if present, tmpFile has the full content; it will be deleted at the end
     peekBuff  *buffer.Reader
     size      int64
     oriReader io.Reader // the original reader, used for caching
@@ -39,12 +37,6 @@ func (f *FileStream) GetSize() int64 {
     if f.size > 0 {
         return f.size
     }
-    if file, ok := f.tmpFile.(*os.File); ok {
-        info, err := file.Stat()
-        if err == nil {
-            return info.Size()
-        }
-    }
     return f.Obj.GetSize()
 }
@@ -71,14 +63,13 @@ func (f *FileStream) Close() error {
     if errors.Is(err1, os.ErrClosed) {
         err1 = nil
     }
-    if file, ok := f.tmpFile.(*os.File); ok {
+    if file, ok := f.Reader.(*os.File); ok {
         err2 = os.RemoveAll(file.Name())
         if err2 != nil {
             err2 = errs.NewErr(err2, "failed to remove tmpFile [%s]", file.Name())
-        } else {
-            f.tmpFile = nil
         }
     }
+    f.Reader = nil
     return errors.Join(err1, err2)
 }
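
Close now always drops the Reader reference and reports both failure modes through errors.Join (Go 1.20+), so the close error and the remove error never shadow each other. The aggregation behavior in isolation:

package main

import (
	"errors"
	"fmt"
)

func main() {
	errClose := errors.New("close failed")
	errRemove := errors.New("remove tmp failed")

	// errors.Join keeps both errors inspectable instead of discarding one.
	err := errors.Join(errClose, errRemove)
	fmt.Println(errors.Is(err, errClose), errors.Is(err, errRemove)) // true true
	fmt.Println(err)                                                 // both messages, newline-separated
}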
@@ -94,27 +85,28 @@ func (f *FileStream) SetExist(obj model.Obj) {
 // It's not thread-safe!
 func (f *FileStream) CacheFullAndWriter(up *model.UpdateProgress, writer io.Writer) (model.File, error) {
     if cache := f.GetFile(); cache != nil {
-        _, err := cache.Seek(0, io.SeekStart)
-        if err != nil {
-            return nil, err
-        }
         if writer == nil {
             return cache, nil
         }
-        reader := f.Reader
-        if up != nil {
-            cacheProgress := model.UpdateProgressWithRange(*up, 0, 50)
-            *up = model.UpdateProgressWithRange(*up, 50, 100)
-            reader = &ReaderUpdatingProgress{
-                Reader: &SimpleReaderWithSize{
-                    Reader: reader,
-                    Size:   f.GetSize(),
-                },
-                UpdateProgress: cacheProgress,
-            }
-        }
-        _, err = utils.CopyWithBuffer(writer, reader)
-        if err == nil {
-            _, err = cache.Seek(0, io.SeekStart)
+        _, err := cache.Seek(0, io.SeekStart)
+        if err == nil {
+            reader := f.Reader
+            if up != nil {
+                cacheProgress := model.UpdateProgressWithRange(*up, 0, 50)
+                *up = model.UpdateProgressWithRange(*up, 50, 100)
+                reader = &ReaderUpdatingProgress{
+                    Reader: &SimpleReaderWithSize{
+                        Reader: reader,
+                        Size:   f.GetSize(),
+                    },
+                    UpdateProgress: cacheProgress,
+                }
+            }
+            _, err = utils.CopyWithBuffer(writer, reader)
+            if err == nil {
+                _, err = cache.Seek(0, io.SeekStart)
+            }
         }
         if err != nil {
             return nil, err
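
The cached branch splits the caller's progress callback: the cache-to-writer copy reports as 0-50% and whatever follows (typically the upload) as 50-100%. A self-contained sketch of what a remapping helper like model.UpdateProgressWithRange plausibly does, reimplemented here under an assumed signature:

package main

import "fmt"

// UpdateProgress mirrors the assumed shape of model.UpdateProgress:
// a callback taking a completion percentage in [0, 100].
type UpdateProgress func(percentage float64)

// withRange remaps a child's 0-100 progress onto [start, end] of the parent,
// which is what the UpdateProgressWithRange calls above achieve.
func withRange(parent UpdateProgress, start, end float64) UpdateProgress {
	return func(p float64) {
		parent(start + p*(end-start)/100)
	}
}

func main() {
	overall := UpdateProgress(func(p float64) { fmt.Printf("overall: %.0f%%\n", p) })
	caching := withRange(overall, 0, 50)     // caching phase reports as 0-50
	uploading := withRange(overall, 50, 100) // upload phase reports as 50-100
	caching(100)   // overall: 50%
	uploading(100) // overall: 100%
}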
@@ -123,21 +115,20 @@ func (f *FileStream) CacheFullAndWriter(up *model.UpdateProgress, writer io.Writer) (model.File, error) {
     }
     reader := f.Reader
-    if up != nil {
-        cacheProgress := model.UpdateProgressWithRange(*up, 0, 50)
-        *up = model.UpdateProgressWithRange(*up, 50, 100)
-        reader = &ReaderUpdatingProgress{
-            Reader: &SimpleReaderWithSize{
-                Reader: reader,
-                Size:   f.GetSize(),
-            },
-            UpdateProgress: cacheProgress,
-        }
-    }
+    if f.peekBuff != nil {
+        f.peekBuff.Seek(0, io.SeekStart)
+        if writer != nil {
+            _, err := utils.CopyWithBuffer(writer, f.peekBuff)
+            if err != nil {
+                return nil, err
+            }
+            f.peekBuff.Seek(0, io.SeekStart)
+        }
+        reader = f.oriReader
+    }
     if writer != nil {
         reader = io.TeeReader(reader, writer)
     }
     if f.GetSize() < 0 {
         if f.peekBuff == nil {
             f.peekBuff = &buffer.Reader{}
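
Two mechanisms cooperate here: an already-peeked prefix is replayed to the writer straight from memory, and the unread remainder goes through io.TeeReader, so spooling it into the cache feeds the writer in the same pass. The TeeReader half in isolation:

package main

import (
	"bytes"
	"fmt"
	"io"
	"strings"
)

func main() {
	src := strings.NewReader("remainder of the stream")
	var sink bytes.Buffer // stands in for the caller-supplied writer

	// Every byte read from tee (e.g. while spooling to a cache file)
	// is also written to sink; no second pass over src is needed.
	tee := io.TeeReader(src, &sink)
	cached, _ := io.ReadAll(tee)

	fmt.Println(bytes.Equal(cached, sink.Bytes())) // true
}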
@@ -174,7 +165,6 @@ func (f *FileStream) CacheFullAndWriter(up *model.UpdateProgress, writer io.Writ
             }
         }
     }
-
     tmpF, err := utils.CreateTempFile(reader, 0)
     if err != nil {
         return nil, err
@@ -191,14 +181,33 @@ func (f *FileStream) CacheFullAndWriter(up *model.UpdateProgress, writer io.Writer) (model.File, error) {
         return peekF, nil
     }
-    f.Reader = reader
+    if up != nil {
+        cacheProgress := model.UpdateProgressWithRange(*up, 0, 50)
+        *up = model.UpdateProgressWithRange(*up, 50, 100)
+        size := f.GetSize()
+        if f.peekBuff != nil {
+            peekSize := f.peekBuff.Size()
+            cacheProgress(float64(peekSize) / float64(size) * 100)
+            size -= peekSize
+        }
+        reader = &ReaderUpdatingProgress{
+            Reader: &SimpleReaderWithSize{
+                Reader: reader,
+                Size:   size,
+            },
+            UpdateProgress: cacheProgress,
+        }
+    }
+    if f.peekBuff != nil {
+        f.oriReader = reader
+    } else {
+        f.Reader = reader
+    }
     return f.cache(f.GetSize())
 }

 func (f *FileStream) GetFile() model.File {
-    if f.tmpFile != nil {
-        return f.tmpFile
-    }
     if file, ok := f.Reader.(model.File); ok {
         return file
     }
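
With tmpFile gone, GetFile reduces to a type assertion: the stream counts as cached exactly when its Reader already implements model.File. The same idiom in standalone form, using io.Seeker as the stand-in capability:

package main

import (
	"bytes"
	"fmt"
	"io"
	"strings"
)

// asSeeker returns r as an io.Seeker when the concrete type supports it.
func asSeeker(r io.Reader) io.Seeker {
	if s, ok := r.(io.Seeker); ok {
		return s
	}
	return nil
}

func main() {
	fmt.Println(asSeeker(bytes.NewReader(nil)) != nil)   // true: an in-memory cache can seek
	fmt.Println(asSeeker(strings.NewReader("x")) != nil) // true
	fmt.Println(asSeeker(io.MultiReader()) != nil)       // false: a plain stream cannot
}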
@@ -234,12 +243,29 @@ func (f *FileStream) RangeRead(httpRange http_range.Range) (io.Reader, error) {
 func (f *FileStream) cache(maxCacheSize int64) (model.File, error) {
     if maxCacheSize > int64(conf.MaxBufferLimit) {
-        tmpF, err := utils.CreateTempFile(f.Reader, f.GetSize())
+        size := f.GetSize()
+        reader := f.Reader
+        if f.peekBuff != nil {
+            size -= f.peekBuff.Size()
+            reader = f.oriReader
+        }
+        tmpF, err := utils.CreateTempFile(reader, size)
         if err != nil {
             return nil, err
         }
+        if f.peekBuff != nil {
+            f.Add(utils.CloseFunc(func() error {
+                return errors.Join(tmpF.Close(), os.RemoveAll(tmpF.Name()))
+            }))
+            peekF, err := buffer.NewPeekFile(f.peekBuff, tmpF)
+            if err != nil {
+                return nil, err
+            }
+            f.Reader = peekF
+            return peekF, nil
+        }
         f.Add(tmpF)
-        f.tmpFile = tmpF
         f.Reader = tmpF
         return tmpF, nil
     }
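
buffer.NewPeekFile stitches the in-memory prefix and the on-disk remainder into one logical file, which is the memory win this commit is after: bytes already peeked into the buffer are never written to disk a second time. A rough stdlib approximation for the sequential-read case (the real PeekFile is presumably also seekable, which this sketch does not attempt):

package main

import (
	"bytes"
	"fmt"
	"io"
	"os"
)

func main() {
	head := []byte("peeked-head:") // bytes already held in the in-memory peek buffer
	tmp, err := os.CreateTemp("", "tail-*")
	if err != nil {
		panic(err)
	}
	defer os.Remove(tmp.Name())
	defer tmp.Close()
	tmp.WriteString("rest-of-stream") // only the remainder is spooled to disk
	tmp.Seek(0, io.SeekStart)

	// One logical stream over both parts; the head never touches disk.
	full := io.MultiReader(bytes.NewReader(head), tmp)
	b, _ := io.ReadAll(full)
	fmt.Println(string(b)) // peeked-head:rest-of-stream
}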
@@ -248,7 +274,7 @@ func (f *FileStream) cache(maxCacheSize int64) (model.File, error) {
         f.peekBuff = &buffer.Reader{}
         f.oriReader = f.Reader
     }
-    bufSize := maxCacheSize - int64(f.peekBuff.Size())
+    bufSize := maxCacheSize - f.peekBuff.Size()
     var buf []byte
     if conf.MmapThreshold > 0 && bufSize >= int64(conf.MmapThreshold) {
         m, err := mmap.Alloc(int(bufSize))
@@ -267,7 +293,7 @@ func (f *FileStream) cache(maxCacheSize int64) (model.File, error) {
         return nil, fmt.Errorf("failed to read all data: (expect =%d, actual =%d) %w", bufSize, n, err)
     }
     f.peekBuff.Append(buf)
-    if int64(f.peekBuff.Size()) >= f.GetSize() {
+    if f.peekBuff.Size() >= f.GetSize() {
         f.Reader = f.peekBuff
         f.oriReader = nil
     } else {
@@ -276,12 +302,6 @@ func (f *FileStream) cache(maxCacheSize int64) (model.File, error) {
     return f.peekBuff, nil
 }
-
-func (f *FileStream) SetTmpFile(file model.File) {
-    f.AddIfCloser(file)
-    f.tmpFile = file
-    f.Reader = file
-}

 var _ model.FileStreamer = (*SeekableStream)(nil)
 var _ model.FileStreamer = (*FileStream)(nil)
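
The trailing var _ lines are compile-time assertions; after dropping SetTmpFile they are what verifies both stream types still satisfy the slimmed-down model.FileStreamer. The idiom in miniature:

package main

import "io"

type nopWriter struct{}

func (nopWriter) Write(p []byte) (int, error) { return len(p), nil }

// Compilation fails here if nopWriter ever stops implementing io.Writer,
// catching interface drift at build time rather than at a call site.
var _ io.Writer = nopWriter{}

func main() {}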

File 3 of 3

@@ -7,13 +7,12 @@ import (
"io"
"testing"
"github.com/OpenListTeam/OpenList/v4/internal/conf"
"github.com/OpenListTeam/OpenList/v4/internal/model"
"github.com/OpenListTeam/OpenList/v4/pkg/http_range"
"github.com/OpenListTeam/OpenList/v4/pkg/utils"
)
func TestFileStream_RangeRead(t *testing.T) {
conf.MaxBufferLimit = 16 * 1024 * 1024
type args struct {
httpRange http_range.Range
}
@@ -73,16 +72,38 @@ func TestFileStream_RangeRead(t *testing.T) {
             }
         })
     }
-    t.Run("after", func(t *testing.T) {
-        if f.GetFile() == nil {
-            t.Error("not cached")
-        }
-        buf2 := make([]byte, len(buf))
-        if _, err := io.ReadFull(f, buf2); err != nil {
-            t.Errorf("FileStream.Read() error = %v", err)
-        }
-        if !bytes.Equal(buf, buf2) {
-            t.Errorf("FileStream.Read() = %s, want %s", buf2, buf)
-        }
-    })
+    if f.GetFile() == nil {
+        t.Error("not cached")
+    }
+    buf2 := make([]byte, len(buf))
+    if _, err := io.ReadFull(f, buf2); err != nil {
+        t.Errorf("FileStream.Read() error = %v", err)
+    }
+    if !bytes.Equal(buf, buf2) {
+        t.Errorf("FileStream.Read() = %s, want %s", buf2, buf)
+    }
 }
+
+func TestFileStream_With_PreHash(t *testing.T) {
+    buf := []byte("github.com/OpenListTeam/OpenList")
+    f := &FileStream{
+        Obj: &model.Object{
+            Size: int64(len(buf)),
+        },
+        Reader: io.NopCloser(bytes.NewReader(buf)),
+    }
+
+    const hashSize int64 = 20
+    reader, _ := f.RangeRead(http_range.Range{Start: 0, Length: hashSize})
+    preHash, _ := utils.HashReader(utils.SHA1, reader)
+    if preHash == "" {
+        t.Error("preHash is empty")
+    }
+
+    tmpF, fullHash, _ := CacheFullAndHash(f, nil, utils.SHA1)
+    fmt.Println(fullHash)
+    fileFullHash, _ := utils.HashFile(utils.SHA1, tmpF)
+    fmt.Println(fileFullHash)
+    if fullHash != fileFullHash {
+        t.Errorf("fullHash and fileFullHash should match: fullHash=%s fileFullHash=%s", fullHash, fileFullHash)
+    }
+}