refactor(stream): simplify code (#1590)

* refactor(stream): simplify Close method and update SeekableStream to use RangeReader interface

* refactor(stream):  improve RangeRead comments for clarity
This commit is contained in:
j2rong4cn
2025-11-06 20:06:48 +08:00
committed by GitHub
parent affc499913
commit a1f1f98f94

View File

@@ -10,7 +10,6 @@ import (
"sync"
"github.com/OpenListTeam/OpenList/v4/internal/conf"
"github.com/OpenListTeam/OpenList/v4/internal/errs"
"github.com/OpenListTeam/OpenList/v4/internal/model"
"github.com/OpenListTeam/OpenList/v4/pkg/buffer"
"github.com/OpenListTeam/OpenList/v4/pkg/http_range"
@@ -28,8 +27,8 @@ type FileStream struct {
ForceStreamUpload bool
Exist model.Obj //the file existed in the destination, we can reuse some info since we wil overwrite it
utils.Closers
peekBuff *buffer.Reader
size int64
peekBuff *buffer.Reader
oriReader io.Reader // the original reader, used for caching
}
@@ -55,23 +54,10 @@ func (f *FileStream) IsForceStreamUpload() bool {
func (f *FileStream) Close() error {
if f.peekBuff != nil {
f.peekBuff.Reset()
f.oriReader = nil
f.peekBuff = nil
}
var err1, err2 error
err1 = f.Closers.Close()
if errors.Is(err1, os.ErrClosed) {
err1 = nil
}
if file, ok := f.Reader.(*os.File); ok {
err2 = os.RemoveAll(file.Name())
if err2 != nil {
err2 = errs.NewErr(err2, "failed to remove tmpFile [%s]", file.Name())
}
}
f.Reader = nil
return errors.Join(err1, err2)
return f.Closers.Close()
}
func (f *FileStream) GetExist() model.Obj {
@@ -214,8 +200,9 @@ func (f *FileStream) GetFile() model.File {
return nil
}
// RangeRead have to cache all data first since only Reader is provided.
// It's not thread-safe!
// 从流读取指定范围的一块数据,并且不消耗流。
// 当读取的边界超过内部设置大小后会缓存整个流。
// 流未缓存时线程不完全
func (f *FileStream) RangeRead(httpRange http_range.Range) (io.Reader, error) {
if httpRange.Length < 0 || httpRange.Start+httpRange.Length > f.GetSize() {
httpRange.Length = f.GetSize() - httpRange.Start
@@ -224,12 +211,7 @@ func (f *FileStream) RangeRead(httpRange http_range.Range) (io.Reader, error) {
return io.NewSectionReader(f.GetFile(), httpRange.Start, httpRange.Length), nil
}
size := httpRange.Start + httpRange.Length
if f.peekBuff != nil && size <= int64(f.peekBuff.Size()) {
return io.NewSectionReader(f.peekBuff, httpRange.Start, httpRange.Length), nil
}
cache, err := f.cache(size)
cache, err := f.cache(httpRange.Start + httpRange.Length)
if err != nil {
return nil, err
}
@@ -241,6 +223,7 @@ func (f *FileStream) RangeRead(httpRange http_range.Range) (io.Reader, error) {
// 使用bytes.Buffer作为io.CopyBuffer的写入对象CopyBuffer会调用Buffer.ReadFrom
// 即使被写入的数据量与Buffer.Cap一致Buffer也会扩大
// 确保指定大小的数据被缓存
func (f *FileStream) cache(maxCacheSize int64) (model.File, error) {
if maxCacheSize > int64(conf.MaxBufferLimit) {
size := f.GetSize()
@@ -253,10 +236,10 @@ func (f *FileStream) cache(maxCacheSize int64) (model.File, error) {
if err != nil {
return nil, err
}
f.Add(utils.CloseFunc(func() error {
return errors.Join(tmpF.Close(), os.RemoveAll(tmpF.Name()))
}))
if f.peekBuff != nil {
f.Add(utils.CloseFunc(func() error {
return errors.Join(tmpF.Close(), os.RemoveAll(tmpF.Name()))
}))
peekF, err := buffer.NewPeekFile(f.peekBuff, tmpF)
if err != nil {
return nil, err
@@ -264,8 +247,6 @@ func (f *FileStream) cache(maxCacheSize int64) (model.File, error) {
f.Reader = peekF
return peekF, nil
}
f.Add(tmpF)
f.Reader = tmpF
return tmpF, nil
}
@@ -273,8 +254,12 @@ func (f *FileStream) cache(maxCacheSize int64) (model.File, error) {
if f.peekBuff == nil {
f.peekBuff = &buffer.Reader{}
f.oriReader = f.Reader
f.Reader = io.MultiReader(f.peekBuff, f.oriReader)
}
bufSize := maxCacheSize - f.peekBuff.Size()
if bufSize <= 0 {
return f.peekBuff, nil
}
var buf []byte
if conf.MmapThreshold > 0 && bufSize >= int64(conf.MmapThreshold) {
m, err := mmap.Alloc(int(bufSize))
@@ -295,9 +280,6 @@ func (f *FileStream) cache(maxCacheSize int64) (model.File, error) {
f.peekBuff.Append(buf)
if f.peekBuff.Size() >= f.GetSize() {
f.Reader = f.peekBuff
f.oriReader = nil
} else {
f.Reader = io.MultiReader(f.peekBuff, f.oriReader)
}
return f.peekBuff, nil
}
@@ -305,19 +287,15 @@ func (f *FileStream) cache(maxCacheSize int64) (model.File, error) {
var _ model.FileStreamer = (*SeekableStream)(nil)
var _ model.FileStreamer = (*FileStream)(nil)
//var _ seekableStream = (*FileStream)(nil)
// for most internal stream, which is either RangeReadCloser or MFile
// Any functionality implemented based on SeekableStream should implement a Close method,
// whose only purpose is to close the SeekableStream object. If such functionality has
// additional resources that need to be closed, they should be added to the Closer property of
// the SeekableStream object and be closed together when the SeekableStream object is closed.
type SeekableStream struct {
*FileStream
// should have one of belows to support rangeRead
rangeReadCloser model.RangeReadCloserIF
rangeReader model.RangeReaderIF
}
// NewSeekableStream create a SeekableStream from FileStream and Link
// if FileStream.Reader is not nil, use it directly
// else create RangeReader from Link
func NewSeekableStream(fs *FileStream, link *model.Link) (*SeekableStream, error) {
if len(fs.Mimetype) == 0 {
fs.Mimetype = utils.GetMimeType(fs.Obj.GetName())
@@ -337,30 +315,31 @@ func NewSeekableStream(fs *FileStream, link *model.Link) (*SeekableStream, error
if err != nil {
return nil, err
}
rrc := &model.RangeReadCloser{
RangeReader: rr,
}
if _, ok := rr.(*model.FileRangeReader); ok {
fs.Reader, err = rrc.RangeRead(fs.Ctx, http_range.Range{Length: -1})
var rc io.ReadCloser
rc, err = rr.RangeRead(fs.Ctx, http_range.Range{Length: -1})
if err != nil {
return nil, err
}
fs.Reader = rc
fs.Add(rc)
}
fs.size = size
fs.Add(link)
fs.Add(rrc)
return &SeekableStream{FileStream: fs, rangeReadCloser: rrc}, nil
return &SeekableStream{FileStream: fs, rangeReader: rr}, nil
}
return nil, fmt.Errorf("illegal seekableStream")
}
// RangeRead is not thread-safe, pls use it in single thread only.
// 如果使用缓存或者rangeReader读取指定范围的数据是线程安全的
// 其他特性继承自FileStream.RangeRead
func (ss *SeekableStream) RangeRead(httpRange http_range.Range) (io.Reader, error) {
if ss.GetFile() == nil && ss.rangeReadCloser != nil {
rc, err := ss.rangeReadCloser.RangeRead(ss.Ctx, httpRange)
if ss.GetFile() == nil && ss.rangeReader != nil {
rc, err := ss.rangeReader.RangeRead(ss.Ctx, httpRange)
if err != nil {
return nil, err
}
ss.Add(rc)
return rc, nil
}
return ss.FileStream.RangeRead(httpRange)
@@ -376,13 +355,14 @@ func (ss *SeekableStream) Read(p []byte) (n int, err error) {
func (ss *SeekableStream) generateReader() error {
if ss.Reader == nil {
if ss.rangeReadCloser == nil {
if ss.rangeReader == nil {
return fmt.Errorf("illegal seekableStream")
}
rc, err := ss.rangeReadCloser.RangeRead(ss.Ctx, http_range.Range{Length: -1})
rc, err := ss.rangeReader.RangeRead(ss.Ctx, http_range.Range{Length: -1})
if err != nil {
return err
}
ss.Add(rc)
ss.Reader = rc
}
return nil