feat(stream): enhance GetRangeReaderFromLink rate limiting (#1528)

* feat(stream): enhance GetRangeReaderFromLink rate limiting

* refactor(stream): update GetRangeReaderFromMFile to return *model.FileRangeReader

* refactor(stream): simplify context error handling in RateLimitReader, RateLimitWriter, and RateLimitFile

* refactor(net): replace custom LimitedReadCloser with readers.NewLimitedReadCloser

* fix(model): update Link.ContentLength JSON tag for correct serialization

* docs(model): add clarification to FileRangeReader usage comment
This commit is contained in:
j2rong4cn
2025-11-04 23:56:09 +08:00
committed by GitHub
parent 2844797684
commit 6de15b6310
7 changed files with 80 additions and 101 deletions

View File

@@ -113,9 +113,7 @@ func (d *FTP) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*m
}
return &model.Link{
RangeReader: &model.FileRangeReader{
RangeReaderIF: stream.RateLimitRangeReaderFunc(resultRangeReader),
},
RangeReader: stream.RateLimitRangeReaderFunc(resultRangeReader),
SyncClosers: utils.NewSyncClosers(utils.CloseFunc(conn.Quit)),
}, nil
}

View File

@@ -190,9 +190,7 @@ func (d *ProtonDrive) Link(ctx context.Context, file model.Obj, args model.LinkA
expiration := time.Minute
return &model.Link{
RangeReader: &model.FileRangeReader{
RangeReaderIF: stream.RateLimitRangeReaderFunc(rangeReaderFunc),
},
RangeReader: stream.RateLimitRangeReaderFunc(rangeReaderFunc),
ContentLength: size,
Expiration: &expiration,
}, nil

View File

@@ -34,7 +34,7 @@ type Link struct {
//for accelerating request, use multi-thread downloading
Concurrency int `json:"concurrency"`
PartSize int `json:"part_size"`
ContentLength int64 `json:"-"` // 转码视频、缩略图
ContentLength int64 `json:"content_length"` // 转码视频、缩略图
utils.SyncClosers `json:"-"`
// 如果SyncClosers中的资源被关闭后Link将不可用则此值应为 true

View File

@@ -27,6 +27,9 @@ func (f *FileCloser) Close() error {
return errors.Join(errs...)
}
// FileRangeReader 是对 RangeReaderIF 的轻量包装,表明由 RangeReaderIF.RangeRead
// 返回的 io.ReadCloser 同时实现了 model.File即支持 Read/ReadAt/Seek
// 只有满足这些才需要使用 FileRangeReader否则直接使用 RangeReaderIF 即可。
type FileRangeReader struct {
RangeReaderIF
}

View File

@@ -1,9 +1,7 @@
package net
import (
"fmt"
"io"
"math"
"mime/multipart"
"net/http"
"net/textproto"
@@ -13,6 +11,7 @@ import (
"github.com/OpenListTeam/OpenList/v4/internal/conf"
"github.com/OpenListTeam/OpenList/v4/pkg/utils"
"github.com/rclone/rclone/lib/readers"
"github.com/OpenListTeam/OpenList/v4/pkg/http_range"
"github.com/go-resty/resty/v2"
@@ -308,39 +307,9 @@ func rangesMIMESize(ranges []http_range.Range, contentType string, contentSize i
return encSize, nil
}
// LimitedReadCloser wraps a io.ReadCloser and limits the number of bytes that can be read from it.
type LimitedReadCloser struct {
rc io.ReadCloser
remaining int
}
func (l *LimitedReadCloser) Read(buf []byte) (int, error) {
if l.remaining <= 0 {
return 0, io.EOF
}
if len(buf) > l.remaining {
buf = buf[0:l.remaining]
}
n, err := l.rc.Read(buf)
l.remaining -= n
return n, err
}
func (l *LimitedReadCloser) Close() error {
return l.rc.Close()
}
// GetRangedHttpReader some http server doesn't support "Range" header,
// so this function read readCloser with whole data, skip offset, then return ReaderCloser.
func GetRangedHttpReader(readCloser io.ReadCloser, offset, length int64) (io.ReadCloser, error) {
var length_int int
if length > math.MaxInt {
return nil, fmt.Errorf("doesnot support length bigger than int32 max ")
}
length_int = int(length)
if offset > 100*1024*1024 {
log.Warnf("offset is more than 100MB, if loading data from internet, high-latency and wasting of bandwidth is expected")
@@ -351,7 +320,7 @@ func GetRangedHttpReader(readCloser io.ReadCloser, offset, length int64) (io.Rea
}
// return an io.ReadCloser that is limited to `length` bytes.
return &LimitedReadCloser{readCloser, length_int}, nil
return readers.NewLimitedReadCloser(readCloser, length), nil
}
// SetProxyIfConfigured sets proxy for HTTP Transport if configured

View File

@@ -7,7 +7,6 @@ import (
"github.com/OpenListTeam/OpenList/v4/internal/model"
"github.com/OpenListTeam/OpenList/v4/pkg/http_range"
"github.com/OpenListTeam/OpenList/v4/pkg/utils"
"golang.org/x/time/rate"
)
@@ -42,17 +41,14 @@ type RateLimitReader struct {
}
func (r *RateLimitReader) Read(p []byte) (n int, err error) {
if r.Ctx != nil && utils.IsCanceled(r.Ctx) {
return 0, r.Ctx.Err()
if err = r.Ctx.Err(); err != nil {
return 0, err
}
n, err = r.Reader.Read(p)
if err != nil {
return
}
if r.Limiter != nil {
if r.Ctx == nil {
r.Ctx = context.Background()
}
err = r.Limiter.WaitN(r.Ctx, n)
}
return
@@ -72,17 +68,14 @@ type RateLimitWriter struct {
}
func (w *RateLimitWriter) Write(p []byte) (n int, err error) {
if w.Ctx != nil && utils.IsCanceled(w.Ctx) {
return 0, w.Ctx.Err()
if err = w.Ctx.Err(); err != nil {
return 0, err
}
n, err = w.Writer.Write(p)
if err != nil {
return
}
if w.Limiter != nil {
if w.Ctx == nil {
w.Ctx = context.Background()
}
err = w.Limiter.WaitN(w.Ctx, n)
}
return
@@ -102,34 +95,28 @@ type RateLimitFile struct {
}
func (r *RateLimitFile) Read(p []byte) (n int, err error) {
if r.Ctx != nil && utils.IsCanceled(r.Ctx) {
return 0, r.Ctx.Err()
if err = r.Ctx.Err(); err != nil {
return 0, err
}
n, err = r.File.Read(p)
if err != nil {
return
}
if r.Limiter != nil {
if r.Ctx == nil {
r.Ctx = context.Background()
}
err = r.Limiter.WaitN(r.Ctx, n)
}
return
}
func (r *RateLimitFile) ReadAt(p []byte, off int64) (n int, err error) {
if r.Ctx != nil && utils.IsCanceled(r.Ctx) {
return 0, r.Ctx.Err()
if err = r.Ctx.Err(); err != nil {
return 0, err
}
n, err = r.File.ReadAt(p, off)
if err != nil {
return
}
if r.Limiter != nil {
if r.Ctx == nil {
r.Ctx = context.Background()
}
err = r.Limiter.WaitN(r.Ctx, n)
}
return
@@ -145,16 +132,16 @@ func (r *RateLimitFile) Close() error {
type RateLimitRangeReaderFunc RangeReaderFunc
func (f RateLimitRangeReaderFunc) RangeRead(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) {
if ServerDownloadLimit == nil {
return f(ctx, httpRange)
}
rc, err := f(ctx, httpRange)
if err != nil {
return nil, err
}
if ServerDownloadLimit != nil {
rc = &RateLimitReader{
Ctx: ctx,
Reader: rc,
Limiter: ServerDownloadLimit,
}
}
return rc, nil
return &RateLimitReader{
Ctx: ctx,
Reader: rc,
Limiter: ServerDownloadLimit,
}, nil
}

View File

@@ -28,44 +28,61 @@ func (f RangeReaderFunc) RangeRead(ctx context.Context, httpRange http_range.Ran
}
func GetRangeReaderFromLink(size int64, link *model.Link) (model.RangeReaderIF, error) {
if link.Concurrency > 0 || link.PartSize > 0 {
if link.RangeReader != nil {
if link.Concurrency < 1 && link.PartSize < 1 {
return link.RangeReader, nil
}
down := net.NewDownloader(func(d *net.Downloader) {
d.Concurrency = link.Concurrency
d.PartSize = link.PartSize
d.HttpClient = net.GetRangeReaderHttpRequestFunc(link.RangeReader)
})
var rangeReader RangeReaderFunc = func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) {
var req *net.HttpRequestParams
if link.RangeReader != nil {
req = &net.HttpRequestParams{
Range: httpRange,
Size: size,
}
} else {
requestHeader, _ := ctx.Value(conf.RequestHeaderKey).(http.Header)
header := net.ProcessHeader(requestHeader, link.Header)
req = &net.HttpRequestParams{
Range: httpRange,
Size: size,
URL: link.URL,
HeaderRef: header,
}
}
return down.Download(ctx, req)
rangeReader := func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) {
return down.Download(ctx, &net.HttpRequestParams{
Range: httpRange,
Size: size,
})
}
if link.RangeReader != nil {
down.HttpClient = net.GetRangeReaderHttpRequestFunc(link.RangeReader)
return rangeReader, nil
}
return RateLimitRangeReaderFunc(rangeReader), nil
}
if link.RangeReader != nil {
return link.RangeReader, nil
// RangeReader只能在驱动限速
return RangeReaderFunc(rangeReader), nil
}
if len(link.URL) == 0 {
return nil, errors.New("invalid link: must have at least one of URL or RangeReader")
}
if link.Concurrency > 0 || link.PartSize > 0 {
down := net.NewDownloader(func(d *net.Downloader) {
d.Concurrency = link.Concurrency
d.PartSize = link.PartSize
d.HttpClient = func(ctx context.Context, params *net.HttpRequestParams) (*http.Response, error) {
if ServerDownloadLimit == nil {
return net.DefaultHttpRequestFunc(ctx, params)
}
resp, err := net.DefaultHttpRequestFunc(ctx, params)
if err == nil && resp.Body != nil {
resp.Body = &RateLimitReader{
Ctx: ctx,
Reader: resp.Body,
Limiter: ServerDownloadLimit,
}
}
return resp, err
}
})
rangeReader := func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) {
requestHeader, _ := ctx.Value(conf.RequestHeaderKey).(http.Header)
header := net.ProcessHeader(requestHeader, link.Header)
return down.Download(ctx, &net.HttpRequestParams{
Range: httpRange,
Size: size,
URL: link.URL,
HeaderRef: header,
})
}
return RangeReaderFunc(rangeReader), nil
}
rangeReader := func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) {
if httpRange.Length < 0 || httpRange.Start+httpRange.Length > size {
httpRange.Length = size - httpRange.Start
@@ -81,7 +98,15 @@ func GetRangeReaderFromLink(size int64, link *model.Link) (model.RangeReaderIF,
}
return nil, fmt.Errorf("http request failure, err:%w", err)
}
if httpRange.Start == 0 && (httpRange.Length == -1 || httpRange.Length == size) || response.StatusCode == http.StatusPartialContent ||
if ServerDownloadLimit != nil {
response.Body = &RateLimitReader{
Ctx: ctx,
Reader: response.Body,
Limiter: ServerDownloadLimit,
}
}
if httpRange.Start == 0 && httpRange.Length == size ||
response.StatusCode == http.StatusPartialContent ||
checkContentRange(&response.Header, httpRange.Start) {
return response.Body, nil
} else if response.StatusCode == http.StatusOK {
@@ -94,11 +119,10 @@ func GetRangeReaderFromLink(size int64, link *model.Link) (model.RangeReaderIF,
}
return response.Body, nil
}
return RateLimitRangeReaderFunc(rangeReader), nil
return RangeReaderFunc(rangeReader), nil
}
// RangeReaderIF.RangeRead返回的io.ReadCloser保留file的签名。
func GetRangeReaderFromMFile(size int64, file model.File) model.RangeReaderIF {
func GetRangeReaderFromMFile(size int64, file model.File) *model.FileRangeReader {
return &model.FileRangeReader{
RangeReaderIF: RangeReaderFunc(func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) {
length := httpRange.Length