mirror of
https://github.com/OpenListTeam/OpenList.git
synced 2025-11-25 03:15:19 +08:00
refactor(stream): simplify code (#1590)
* refactor(stream): simplify Close method and update SeekableStream to use RangeReader interface * refactor(stream): improve RangeRead comments for clarity
This commit is contained in:
@@ -10,7 +10,6 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/OpenListTeam/OpenList/v4/internal/conf"
|
||||
"github.com/OpenListTeam/OpenList/v4/internal/errs"
|
||||
"github.com/OpenListTeam/OpenList/v4/internal/model"
|
||||
"github.com/OpenListTeam/OpenList/v4/pkg/buffer"
|
||||
"github.com/OpenListTeam/OpenList/v4/pkg/http_range"
|
||||
@@ -28,8 +27,8 @@ type FileStream struct {
|
||||
ForceStreamUpload bool
|
||||
Exist model.Obj //the file existed in the destination, we can reuse some info since we wil overwrite it
|
||||
utils.Closers
|
||||
peekBuff *buffer.Reader
|
||||
size int64
|
||||
peekBuff *buffer.Reader
|
||||
oriReader io.Reader // the original reader, used for caching
|
||||
}
|
||||
|
||||
@@ -55,23 +54,10 @@ func (f *FileStream) IsForceStreamUpload() bool {
|
||||
func (f *FileStream) Close() error {
|
||||
if f.peekBuff != nil {
|
||||
f.peekBuff.Reset()
|
||||
f.oriReader = nil
|
||||
f.peekBuff = nil
|
||||
}
|
||||
|
||||
var err1, err2 error
|
||||
err1 = f.Closers.Close()
|
||||
if errors.Is(err1, os.ErrClosed) {
|
||||
err1 = nil
|
||||
}
|
||||
if file, ok := f.Reader.(*os.File); ok {
|
||||
err2 = os.RemoveAll(file.Name())
|
||||
if err2 != nil {
|
||||
err2 = errs.NewErr(err2, "failed to remove tmpFile [%s]", file.Name())
|
||||
}
|
||||
}
|
||||
f.Reader = nil
|
||||
|
||||
return errors.Join(err1, err2)
|
||||
return f.Closers.Close()
|
||||
}
|
||||
|
||||
func (f *FileStream) GetExist() model.Obj {
|
||||
@@ -214,8 +200,9 @@ func (f *FileStream) GetFile() model.File {
|
||||
return nil
|
||||
}
|
||||
|
||||
// RangeRead have to cache all data first since only Reader is provided.
|
||||
// It's not thread-safe!
|
||||
// 从流读取指定范围的一块数据,并且不消耗流。
|
||||
// 当读取的边界超过内部设置大小后会缓存整个流。
|
||||
// 流未缓存时线程不完全
|
||||
func (f *FileStream) RangeRead(httpRange http_range.Range) (io.Reader, error) {
|
||||
if httpRange.Length < 0 || httpRange.Start+httpRange.Length > f.GetSize() {
|
||||
httpRange.Length = f.GetSize() - httpRange.Start
|
||||
@@ -224,12 +211,7 @@ func (f *FileStream) RangeRead(httpRange http_range.Range) (io.Reader, error) {
|
||||
return io.NewSectionReader(f.GetFile(), httpRange.Start, httpRange.Length), nil
|
||||
}
|
||||
|
||||
size := httpRange.Start + httpRange.Length
|
||||
if f.peekBuff != nil && size <= int64(f.peekBuff.Size()) {
|
||||
return io.NewSectionReader(f.peekBuff, httpRange.Start, httpRange.Length), nil
|
||||
}
|
||||
|
||||
cache, err := f.cache(size)
|
||||
cache, err := f.cache(httpRange.Start + httpRange.Length)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -241,6 +223,7 @@ func (f *FileStream) RangeRead(httpRange http_range.Range) (io.Reader, error) {
|
||||
// 使用bytes.Buffer作为io.CopyBuffer的写入对象,CopyBuffer会调用Buffer.ReadFrom
|
||||
// 即使被写入的数据量与Buffer.Cap一致,Buffer也会扩大
|
||||
|
||||
// 确保指定大小的数据被缓存
|
||||
func (f *FileStream) cache(maxCacheSize int64) (model.File, error) {
|
||||
if maxCacheSize > int64(conf.MaxBufferLimit) {
|
||||
size := f.GetSize()
|
||||
@@ -253,10 +236,10 @@ func (f *FileStream) cache(maxCacheSize int64) (model.File, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
f.Add(utils.CloseFunc(func() error {
|
||||
return errors.Join(tmpF.Close(), os.RemoveAll(tmpF.Name()))
|
||||
}))
|
||||
if f.peekBuff != nil {
|
||||
f.Add(utils.CloseFunc(func() error {
|
||||
return errors.Join(tmpF.Close(), os.RemoveAll(tmpF.Name()))
|
||||
}))
|
||||
peekF, err := buffer.NewPeekFile(f.peekBuff, tmpF)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -264,8 +247,6 @@ func (f *FileStream) cache(maxCacheSize int64) (model.File, error) {
|
||||
f.Reader = peekF
|
||||
return peekF, nil
|
||||
}
|
||||
|
||||
f.Add(tmpF)
|
||||
f.Reader = tmpF
|
||||
return tmpF, nil
|
||||
}
|
||||
@@ -273,8 +254,12 @@ func (f *FileStream) cache(maxCacheSize int64) (model.File, error) {
|
||||
if f.peekBuff == nil {
|
||||
f.peekBuff = &buffer.Reader{}
|
||||
f.oriReader = f.Reader
|
||||
f.Reader = io.MultiReader(f.peekBuff, f.oriReader)
|
||||
}
|
||||
bufSize := maxCacheSize - f.peekBuff.Size()
|
||||
if bufSize <= 0 {
|
||||
return f.peekBuff, nil
|
||||
}
|
||||
var buf []byte
|
||||
if conf.MmapThreshold > 0 && bufSize >= int64(conf.MmapThreshold) {
|
||||
m, err := mmap.Alloc(int(bufSize))
|
||||
@@ -295,9 +280,6 @@ func (f *FileStream) cache(maxCacheSize int64) (model.File, error) {
|
||||
f.peekBuff.Append(buf)
|
||||
if f.peekBuff.Size() >= f.GetSize() {
|
||||
f.Reader = f.peekBuff
|
||||
f.oriReader = nil
|
||||
} else {
|
||||
f.Reader = io.MultiReader(f.peekBuff, f.oriReader)
|
||||
}
|
||||
return f.peekBuff, nil
|
||||
}
|
||||
@@ -305,19 +287,15 @@ func (f *FileStream) cache(maxCacheSize int64) (model.File, error) {
|
||||
var _ model.FileStreamer = (*SeekableStream)(nil)
|
||||
var _ model.FileStreamer = (*FileStream)(nil)
|
||||
|
||||
//var _ seekableStream = (*FileStream)(nil)
|
||||
|
||||
// for most internal stream, which is either RangeReadCloser or MFile
|
||||
// Any functionality implemented based on SeekableStream should implement a Close method,
|
||||
// whose only purpose is to close the SeekableStream object. If such functionality has
|
||||
// additional resources that need to be closed, they should be added to the Closer property of
|
||||
// the SeekableStream object and be closed together when the SeekableStream object is closed.
|
||||
type SeekableStream struct {
|
||||
*FileStream
|
||||
// should have one of belows to support rangeRead
|
||||
rangeReadCloser model.RangeReadCloserIF
|
||||
rangeReader model.RangeReaderIF
|
||||
}
|
||||
|
||||
// NewSeekableStream create a SeekableStream from FileStream and Link
|
||||
// if FileStream.Reader is not nil, use it directly
|
||||
// else create RangeReader from Link
|
||||
func NewSeekableStream(fs *FileStream, link *model.Link) (*SeekableStream, error) {
|
||||
if len(fs.Mimetype) == 0 {
|
||||
fs.Mimetype = utils.GetMimeType(fs.Obj.GetName())
|
||||
@@ -337,30 +315,31 @@ func NewSeekableStream(fs *FileStream, link *model.Link) (*SeekableStream, error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rrc := &model.RangeReadCloser{
|
||||
RangeReader: rr,
|
||||
}
|
||||
if _, ok := rr.(*model.FileRangeReader); ok {
|
||||
fs.Reader, err = rrc.RangeRead(fs.Ctx, http_range.Range{Length: -1})
|
||||
var rc io.ReadCloser
|
||||
rc, err = rr.RangeRead(fs.Ctx, http_range.Range{Length: -1})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fs.Reader = rc
|
||||
fs.Add(rc)
|
||||
}
|
||||
fs.size = size
|
||||
fs.Add(link)
|
||||
fs.Add(rrc)
|
||||
return &SeekableStream{FileStream: fs, rangeReadCloser: rrc}, nil
|
||||
return &SeekableStream{FileStream: fs, rangeReader: rr}, nil
|
||||
}
|
||||
return nil, fmt.Errorf("illegal seekableStream")
|
||||
}
|
||||
|
||||
// RangeRead is not thread-safe, pls use it in single thread only.
|
||||
// 如果使用缓存或者rangeReader读取指定范围的数据,是线程安全的
|
||||
// 其他特性继承自FileStream.RangeRead
|
||||
func (ss *SeekableStream) RangeRead(httpRange http_range.Range) (io.Reader, error) {
|
||||
if ss.GetFile() == nil && ss.rangeReadCloser != nil {
|
||||
rc, err := ss.rangeReadCloser.RangeRead(ss.Ctx, httpRange)
|
||||
if ss.GetFile() == nil && ss.rangeReader != nil {
|
||||
rc, err := ss.rangeReader.RangeRead(ss.Ctx, httpRange)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ss.Add(rc)
|
||||
return rc, nil
|
||||
}
|
||||
return ss.FileStream.RangeRead(httpRange)
|
||||
@@ -376,13 +355,14 @@ func (ss *SeekableStream) Read(p []byte) (n int, err error) {
|
||||
|
||||
func (ss *SeekableStream) generateReader() error {
|
||||
if ss.Reader == nil {
|
||||
if ss.rangeReadCloser == nil {
|
||||
if ss.rangeReader == nil {
|
||||
return fmt.Errorf("illegal seekableStream")
|
||||
}
|
||||
rc, err := ss.rangeReadCloser.RangeRead(ss.Ctx, http_range.Range{Length: -1})
|
||||
rc, err := ss.rangeReader.RangeRead(ss.Ctx, http_range.Range{Length: -1})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ss.Add(rc)
|
||||
ss.Reader = rc
|
||||
}
|
||||
return nil
|
||||
|
||||
Reference in New Issue
Block a user