mirror of https://github.com/OpenListTeam/OpenList.git
feat(stream): enhance GetRangeReaderFromLink rate limiting (#1528)
* feat(stream): enhance GetRangeReaderFromLink rate limiting
* refactor(stream): update GetRangeReaderFromMFile to return *model.FileRangeReader
* refactor(stream): simplify context error handling in RateLimitReader, RateLimitWriter, and RateLimitFile
* refactor(net): replace custom LimitedReadCloser with readers.NewLimitedReadCloser
* fix(model): update Link.ContentLength JSON tag for correct serialization
* docs(model): add clarification to FileRangeReader usage comment
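The rate limiting referenced throughout this commit is token-bucket based: stream.RateLimitReader waits on a rate.Limiter for however many bytes each Read returned. A minimal, self-contained sketch of that mechanism using golang.org/x/time/rate (throttledReader is a stand-in, not the actual OpenList type):

    package main

    import (
        "context"
        "fmt"
        "io"
        "strings"
        "time"

        "golang.org/x/time/rate"
    )

    // throttledReader is a minimal stand-in for stream.RateLimitReader: after
    // every successful Read it waits on the limiter for the bytes just read,
    // so sustained throughput converges on the configured rate.
    type throttledReader struct {
        ctx     context.Context
        r       io.Reader
        limiter *rate.Limiter
    }

    func (t *throttledReader) Read(p []byte) (n int, err error) {
        if err = t.ctx.Err(); err != nil { // same early-exit pattern this commit adopts
            return 0, err
        }
        n, err = t.r.Read(p)
        if err != nil {
            return n, err
        }
        return n, t.limiter.WaitN(t.ctx, n)
    }

    func main() {
        src := strings.NewReader(strings.Repeat("x", 128))
        // 64 bytes per second with a 32-byte burst.
        tr := &throttledReader{ctx: context.Background(), r: src, limiter: rate.NewLimiter(64, 32)}

        buf := make([]byte, 32) // keep each read within the burst size
        start, total := time.Now(), 0
        for {
            n, err := tr.Read(buf)
            total += n
            if err == io.EOF {
                break
            }
            if err != nil {
                fmt.Println("read error:", err)
                return
            }
        }
        fmt.Printf("copied %d bytes in %v\n", total, time.Since(start))
    }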
@@ -113,9 +113,7 @@ func (d *FTP) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) {
 	}

 	return &model.Link{
-		RangeReader: &model.FileRangeReader{
-			RangeReaderIF: stream.RateLimitRangeReaderFunc(resultRangeReader),
-		},
+		RangeReader: stream.RateLimitRangeReaderFunc(resultRangeReader),
 		SyncClosers: utils.NewSyncClosers(utils.CloseFunc(conn.Quit)),
 	}, nil
 }
@@ -190,9 +190,7 @@ func (d *ProtonDrive) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) {

 	expiration := time.Minute
 	return &model.Link{
-		RangeReader: &model.FileRangeReader{
-			RangeReaderIF: stream.RateLimitRangeReaderFunc(rangeReaderFunc),
-		},
+		RangeReader:   stream.RateLimitRangeReaderFunc(rangeReaderFunc),
 		ContentLength: size,
 		Expiration:    &expiration,
 	}, nil
@@ -34,7 +34,7 @@ type Link struct {
 	//for accelerating request, use multi-thread downloading
 	Concurrency int `json:"concurrency"`
 	PartSize    int `json:"part_size"`
-	ContentLength int64 `json:"-"`              // transcoded video, thumbnail
+	ContentLength int64 `json:"content_length"` // transcoded video, thumbnail

 	utils.SyncClosers `json:"-"`
 	// Should be true if the Link becomes unusable once the resources in SyncClosers are closed
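For reference, the effect of this tag change on encoding/json output, shown with hypothetical stand-in structs rather than the real model.Link:

    package main

    import (
        "encoding/json"
        "fmt"
    )

    // With `json:"-"` the field is dropped from the encoded output;
    // with `json:"content_length"` it is emitted under that key.
    type linkHidden struct {
        URL           string `json:"url"`
        ContentLength int64  `json:"-"`
    }

    type linkExposed struct {
        URL           string `json:"url"`
        ContentLength int64  `json:"content_length"`
    }

    func main() {
        a, _ := json.Marshal(linkHidden{URL: "http://example.com/f", ContentLength: 1024})
        b, _ := json.Marshal(linkExposed{URL: "http://example.com/f", ContentLength: 1024})
        fmt.Println(string(a)) // {"url":"http://example.com/f"}
        fmt.Println(string(b)) // {"url":"http://example.com/f","content_length":1024}
    }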
@@ -27,6 +27,9 @@ func (f *FileCloser) Close() error {
 	return errors.Join(errs...)
 }

+// FileRangeReader is a thin wrapper around RangeReaderIF, indicating that the io.ReadCloser
+// returned by RangeReaderIF.RangeRead also implements model.File (i.e. it supports Read/ReadAt/Seek).
+// Use FileRangeReader only when that holds; otherwise use RangeReaderIF directly.
 type FileRangeReader struct {
 	RangeReaderIF
 }
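A rough illustration of the distinction the new comment draws, using simplified stand-in interfaces rather than the real model.File / model.RangeReaderIF: a file-backed range read can hand back a reader that also supports ReadAt/Seek, while a stream-backed one cannot, so only the former would justify a FileRangeReader-style wrapper.

    package main

    import (
        "context"
        "fmt"
        "io"
        "os"
        "strings"
    )

    // Simplified stand-ins for the interfaces referenced by the comment.
    type File interface {
        io.Reader
        io.ReaderAt
        io.Seeker
    }

    type RangeReaderIF interface {
        RangeRead(ctx context.Context, off, length int64) (io.ReadCloser, error)
    }

    // fileBacked returns an *os.File, which supports Read/ReadAt/Seek.
    // (Limiting the read to `length` bytes is omitted for brevity.)
    type fileBacked struct{ path string }

    func (f fileBacked) RangeRead(_ context.Context, off, _ int64) (io.ReadCloser, error) {
        fd, err := os.Open(f.path)
        if err != nil {
            return nil, err
        }
        if _, err := fd.Seek(off, io.SeekStart); err != nil {
            fd.Close()
            return nil, err
        }
        return fd, nil
    }

    // streamBacked returns a plain, non-seekable stream; callers must not
    // assume it is a File, so RangeReaderIF alone is the right shape here.
    type streamBacked struct{ data string }

    func (s streamBacked) RangeRead(_ context.Context, off, length int64) (io.ReadCloser, error) {
        r := strings.NewReader(s.data)
        if _, err := r.Seek(off, io.SeekStart); err != nil {
            return nil, err
        }
        return io.NopCloser(io.LimitReader(r, length)), nil
    }

    func main() {
        tmp, _ := os.CreateTemp("", "range-demo-*")
        tmp.WriteString("hello, range readers")
        tmp.Close()
        defer os.Remove(tmp.Name())

        for _, rr := range []RangeReaderIF{fileBacked{path: tmp.Name()}, streamBacked{data: "hello, range readers"}} {
            rc, err := rr.RangeRead(context.Background(), 7, 5)
            if err != nil {
                fmt.Println("error:", err)
                continue
            }
            _, isFile := rc.(File) // only the file-backed reader satisfies File
            buf := make([]byte, 5)
            n, _ := io.ReadFull(rc, buf)
            fmt.Printf("read %q, usable as File: %v\n", buf[:n], isFile)
            rc.Close()
        }
    }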
@@ -1,9 +1,7 @@
 package net

 import (
-	"fmt"
 	"io"
-	"math"
 	"mime/multipart"
 	"net/http"
 	"net/textproto"
@@ -13,6 +11,7 @@ import (

 	"github.com/OpenListTeam/OpenList/v4/internal/conf"
 	"github.com/OpenListTeam/OpenList/v4/pkg/utils"
+	"github.com/rclone/rclone/lib/readers"

 	"github.com/OpenListTeam/OpenList/v4/pkg/http_range"
 	"github.com/go-resty/resty/v2"
@@ -308,39 +307,9 @@ func rangesMIMESize(ranges []http_range.Range, contentType string, contentSize int64) (int64, error) {
 	return encSize, nil
 }

-// LimitedReadCloser wraps a io.ReadCloser and limits the number of bytes that can be read from it.
-type LimitedReadCloser struct {
-	rc        io.ReadCloser
-	remaining int
-}
-
-func (l *LimitedReadCloser) Read(buf []byte) (int, error) {
-	if l.remaining <= 0 {
-		return 0, io.EOF
-	}
-
-	if len(buf) > l.remaining {
-		buf = buf[0:l.remaining]
-	}
-
-	n, err := l.rc.Read(buf)
-	l.remaining -= n
-
-	return n, err
-}
-
-func (l *LimitedReadCloser) Close() error {
-	return l.rc.Close()
-}
-
 // GetRangedHttpReader some http server doesn't support "Range" header,
 // so this function read readCloser with whole data, skip offset, then return ReaderCloser.
 func GetRangedHttpReader(readCloser io.ReadCloser, offset, length int64) (io.ReadCloser, error) {
-	var length_int int
-	if length > math.MaxInt {
-		return nil, fmt.Errorf("doesnot support length bigger than int32 max ")
-	}
-	length_int = int(length)
-
 	if offset > 100*1024*1024 {
 		log.Warnf("offset is more than 100MB, if loading data from internet, high-latency and wasting of bandwidth is expected")
@@ -351,7 +320,7 @@ func GetRangedHttpReader(readCloser io.ReadCloser, offset, length int64) (io.ReadCloser, error) {
 	}

 	// return an io.ReadCloser that is limited to `length` bytes.
-	return &LimitedReadCloser{readCloser, length_int}, nil
+	return readers.NewLimitedReadCloser(readCloser, length), nil
 }

 // SetProxyIfConfigured sets proxy for HTTP Transport if configured
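The removed type can be approximated with the standard library in a few lines. The sketch below mirrors what the rclone helper is assumed to provide (an io.LimitedReader paired with the original Close, pass-through for negative limits); it is an illustration under that assumption, not the helper's actual source.

    package main

    import (
        "fmt"
        "io"
        "os"
        "strings"
    )

    // limitedReadCloser pairs io.LimitedReader with the original Closer,
    // which is essentially the shape of the helper the commit switches to.
    type limitedReadCloser struct {
        *io.LimitedReader
        io.Closer
    }

    func newLimitedReadCloser(rc io.ReadCloser, limit int64) io.ReadCloser {
        if limit < 0 {
            return rc // negative limit: pass through unchanged
        }
        return &limitedReadCloser{
            LimitedReader: &io.LimitedReader{R: rc, N: limit},
            Closer:        rc,
        }
    }

    func main() {
        src := io.NopCloser(strings.NewReader("0123456789"))
        lrc := newLimitedReadCloser(src, 4)
        defer lrc.Close()
        io.Copy(os.Stdout, lrc) // prints "0123"
        fmt.Println()
    }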
@@ -7,7 +7,6 @@ import (

 	"github.com/OpenListTeam/OpenList/v4/internal/model"
 	"github.com/OpenListTeam/OpenList/v4/pkg/http_range"
-	"github.com/OpenListTeam/OpenList/v4/pkg/utils"
 	"golang.org/x/time/rate"
 )

@@ -42,17 +41,14 @@ type RateLimitReader struct {
 }

 func (r *RateLimitReader) Read(p []byte) (n int, err error) {
-	if r.Ctx != nil && utils.IsCanceled(r.Ctx) {
-		return 0, r.Ctx.Err()
+	if err = r.Ctx.Err(); err != nil {
+		return 0, err
 	}
 	n, err = r.Reader.Read(p)
 	if err != nil {
 		return
 	}
 	if r.Limiter != nil {
-		if r.Ctx == nil {
-			r.Ctx = context.Background()
-		}
 		err = r.Limiter.WaitN(r.Ctx, n)
 	}
 	return
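The simplified check leans on the fact that context.Context.Err() is nil until the context is cancelled or its deadline passes, so a single error check replaces the previous nil-guard plus utils.IsCanceled combination; it also implies the refactored types now expect Ctx to always be set. A small stand-alone illustration of the pattern:

    package main

    import (
        "context"
        "errors"
        "fmt"
        "time"
    )

    // doWork returns immediately with the context's error once it is
    // cancelled, using the same pattern as the updated Read/Write methods.
    func doWork(ctx context.Context) error {
        for i := 0; i < 5; i++ {
            if err := ctx.Err(); err != nil {
                return err // nil until cancellation or deadline
            }
            time.Sleep(20 * time.Millisecond)
        }
        return nil
    }

    func main() {
        ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
        defer cancel()
        err := doWork(ctx)
        fmt.Println("done:", err, "deadline exceeded:", errors.Is(err, context.DeadlineExceeded))
    }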
@@ -72,17 +68,14 @@ type RateLimitWriter struct {
 }

 func (w *RateLimitWriter) Write(p []byte) (n int, err error) {
-	if w.Ctx != nil && utils.IsCanceled(w.Ctx) {
-		return 0, w.Ctx.Err()
+	if err = w.Ctx.Err(); err != nil {
+		return 0, err
 	}
 	n, err = w.Writer.Write(p)
 	if err != nil {
 		return
 	}
 	if w.Limiter != nil {
-		if w.Ctx == nil {
-			w.Ctx = context.Background()
-		}
 		err = w.Limiter.WaitN(w.Ctx, n)
 	}
 	return
@@ -102,34 +95,28 @@ type RateLimitFile struct {
 }

 func (r *RateLimitFile) Read(p []byte) (n int, err error) {
-	if r.Ctx != nil && utils.IsCanceled(r.Ctx) {
-		return 0, r.Ctx.Err()
+	if err = r.Ctx.Err(); err != nil {
+		return 0, err
 	}
 	n, err = r.File.Read(p)
 	if err != nil {
 		return
 	}
 	if r.Limiter != nil {
-		if r.Ctx == nil {
-			r.Ctx = context.Background()
-		}
 		err = r.Limiter.WaitN(r.Ctx, n)
 	}
 	return
 }

 func (r *RateLimitFile) ReadAt(p []byte, off int64) (n int, err error) {
-	if r.Ctx != nil && utils.IsCanceled(r.Ctx) {
-		return 0, r.Ctx.Err()
+	if err = r.Ctx.Err(); err != nil {
+		return 0, err
 	}
 	n, err = r.File.ReadAt(p, off)
 	if err != nil {
 		return
 	}
 	if r.Limiter != nil {
-		if r.Ctx == nil {
-			r.Ctx = context.Background()
-		}
 		err = r.Limiter.WaitN(r.Ctx, n)
 	}
 	return
@@ -145,16 +132,16 @@ func (r *RateLimitFile) Close() error {
 type RateLimitRangeReaderFunc RangeReaderFunc

 func (f RateLimitRangeReaderFunc) RangeRead(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) {
+	if ServerDownloadLimit == nil {
+		return f(ctx, httpRange)
+	}
 	rc, err := f(ctx, httpRange)
 	if err != nil {
 		return nil, err
 	}
-	if ServerDownloadLimit != nil {
-		rc = &RateLimitReader{
-			Ctx:     ctx,
-			Reader:  rc,
-			Limiter: ServerDownloadLimit,
-		}
-	}
-	return rc, nil
+	return &RateLimitReader{
+		Ctx:     ctx,
+		Reader:  rc,
+		Limiter: ServerDownloadLimit,
+	}, nil
 }
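The wrapper above is a decorator over a function type: with no server-wide limiter configured it hands the body back untouched, otherwise it wraps it. A compact sketch of the same shape with stand-in names (rangeReadFunc, serverLimit are illustrative only); leaving serverLimit nil exercises the new fast path:

    package main

    import (
        "context"
        "fmt"
        "io"
        "strings"

        "golang.org/x/time/rate"
    )

    // rangeReadFunc stands in for RangeReaderFunc: it produces a body for a byte range.
    type rangeReadFunc func(ctx context.Context, off, length int64) (io.ReadCloser, error)

    // serverLimit plays the role of ServerDownloadLimit; nil means "no server-wide limit".
    var serverLimit *rate.Limiter

    // throttled mirrors the decorator pattern of RateLimitRangeReaderFunc.
    type throttled rangeReadFunc

    func (f throttled) RangeRead(ctx context.Context, off, length int64) (io.ReadCloser, error) {
        if serverLimit == nil {
            return f(ctx, off, length) // fast path: nothing to limit
        }
        rc, err := f(ctx, off, length)
        if err != nil {
            return nil, err
        }
        return &limitedBody{ctx: ctx, rc: rc, lim: serverLimit}, nil
    }

    type limitedBody struct {
        ctx context.Context
        rc  io.ReadCloser
        lim *rate.Limiter
    }

    func (b *limitedBody) Read(p []byte) (int, error) {
        n, err := b.rc.Read(p)
        if n > 0 {
            if werr := b.lim.WaitN(b.ctx, n); werr != nil && err == nil {
                err = werr
            }
        }
        return n, err
    }

    func (b *limitedBody) Close() error { return b.rc.Close() }

    func main() {
        base := rangeReadFunc(func(_ context.Context, off, length int64) (io.ReadCloser, error) {
            data := "0123456789"
            return io.NopCloser(strings.NewReader(data[off : off+length])), nil
        })
        rc, _ := throttled(base).RangeRead(context.Background(), 2, 4)
        out, _ := io.ReadAll(rc)
        fmt.Println(string(out)) // "2345" (no limiter configured, body passed through)
    }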
@@ -28,44 +28,61 @@ func (f RangeReaderFunc) RangeRead(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) {
 }

 func GetRangeReaderFromLink(size int64, link *model.Link) (model.RangeReaderIF, error) {
-	if link.Concurrency > 0 || link.PartSize > 0 {
+	if link.RangeReader != nil {
+		if link.Concurrency < 1 && link.PartSize < 1 {
+			return link.RangeReader, nil
+		}
 		down := net.NewDownloader(func(d *net.Downloader) {
 			d.Concurrency = link.Concurrency
 			d.PartSize = link.PartSize
+			d.HttpClient = net.GetRangeReaderHttpRequestFunc(link.RangeReader)
 		})
-		var rangeReader RangeReaderFunc = func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) {
-			var req *net.HttpRequestParams
-			if link.RangeReader != nil {
-				req = &net.HttpRequestParams{
-					Range: httpRange,
-					Size:  size,
-				}
-			} else {
-				requestHeader, _ := ctx.Value(conf.RequestHeaderKey).(http.Header)
-				header := net.ProcessHeader(requestHeader, link.Header)
-				req = &net.HttpRequestParams{
-					Range:     httpRange,
-					Size:      size,
-					URL:       link.URL,
-					HeaderRef: header,
-				}
-			}
-			return down.Download(ctx, req)
+		rangeReader := func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) {
+			return down.Download(ctx, &net.HttpRequestParams{
+				Range: httpRange,
+				Size:  size,
+			})
 		}
-		if link.RangeReader != nil {
-			down.HttpClient = net.GetRangeReaderHttpRequestFunc(link.RangeReader)
-			return rangeReader, nil
-		}
-		return RateLimitRangeReaderFunc(rangeReader), nil
-	}
-	if link.RangeReader != nil {
-		return link.RangeReader, nil
+		// RangeReader can only be rate-limited in the driver
+		return RangeReaderFunc(rangeReader), nil
 	}

 	if len(link.URL) == 0 {
 		return nil, errors.New("invalid link: must have at least one of URL or RangeReader")
 	}

+	if link.Concurrency > 0 || link.PartSize > 0 {
+		down := net.NewDownloader(func(d *net.Downloader) {
+			d.Concurrency = link.Concurrency
+			d.PartSize = link.PartSize
+			d.HttpClient = func(ctx context.Context, params *net.HttpRequestParams) (*http.Response, error) {
+				if ServerDownloadLimit == nil {
+					return net.DefaultHttpRequestFunc(ctx, params)
+				}
+				resp, err := net.DefaultHttpRequestFunc(ctx, params)
+				if err == nil && resp.Body != nil {
+					resp.Body = &RateLimitReader{
+						Ctx:     ctx,
+						Reader:  resp.Body,
+						Limiter: ServerDownloadLimit,
+					}
+				}
+				return resp, err
+			}
+		})
+		rangeReader := func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) {
+			requestHeader, _ := ctx.Value(conf.RequestHeaderKey).(http.Header)
+			header := net.ProcessHeader(requestHeader, link.Header)
+			return down.Download(ctx, &net.HttpRequestParams{
+				Range:     httpRange,
+				Size:      size,
+				URL:       link.URL,
+				HeaderRef: header,
+			})
+		}
+		return RangeReaderFunc(rangeReader), nil
+	}
+
 	rangeReader := func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) {
 		if httpRange.Length < 0 || httpRange.Start+httpRange.Length > size {
 			httpRange.Length = size - httpRange.Start
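The injected HttpClient above wraps every response body in a rate-limited reader before the downloader sees it. The same idea expressed with only the standard library, as a hedged sketch: a custom http.RoundTripper wraps each body, here with a simple byte counter standing in for RateLimitReader.

    package main

    import (
        "fmt"
        "io"
        "net/http"
        "net/http/httptest"
    )

    // bodyWrappingTransport mirrors the injected HttpClient: perform the request
    // with an underlying transport, then replace the response body with a wrapped
    // reader (a counter here; a RateLimitReader bound to ServerDownloadLimit in the commit).
    type bodyWrappingTransport struct {
        base http.RoundTripper
    }

    type countingBody struct {
        io.ReadCloser
        n int64
    }

    func (c *countingBody) Read(p []byte) (int, error) {
        n, err := c.ReadCloser.Read(p)
        c.n += int64(n)
        return n, err
    }

    func (t *bodyWrappingTransport) RoundTrip(req *http.Request) (*http.Response, error) {
        resp, err := t.base.RoundTrip(req)
        if err == nil && resp.Body != nil {
            resp.Body = &countingBody{ReadCloser: resp.Body}
        }
        return resp, err
    }

    func main() {
        srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
            w.Write([]byte("some file content"))
        }))
        defer srv.Close()

        client := &http.Client{Transport: &bodyWrappingTransport{base: http.DefaultTransport}}
        resp, err := client.Get(srv.URL)
        if err != nil {
            panic(err)
        }
        defer resp.Body.Close()
        io.Copy(io.Discard, resp.Body)
        fmt.Printf("body bytes observed by wrapper: %d\n", resp.Body.(*countingBody).n)
    }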
@@ -81,7 +98,15 @@ func GetRangeReaderFromLink(size int64, link *model.Link) (model.RangeReaderIF, error) {
 			}
 			return nil, fmt.Errorf("http request failure, err:%w", err)
 		}
-		if httpRange.Start == 0 && (httpRange.Length == -1 || httpRange.Length == size) || response.StatusCode == http.StatusPartialContent ||
+		if ServerDownloadLimit != nil {
+			response.Body = &RateLimitReader{
+				Ctx:     ctx,
+				Reader:  response.Body,
+				Limiter: ServerDownloadLimit,
+			}
+		}
+		if httpRange.Start == 0 && httpRange.Length == size ||
+			response.StatusCode == http.StatusPartialContent ||
 			checkContentRange(&response.Header, httpRange.Start) {
 			return response.Body, nil
 		} else if response.StatusCode == http.StatusOK {
@@ -94,11 +119,10 @@ func GetRangeReaderFromLink(size int64, link *model.Link) (model.RangeReaderIF, error) {
 		}
 		return response.Body, nil
 	}
-	return RateLimitRangeReaderFunc(rangeReader), nil
+	return RangeReaderFunc(rangeReader), nil
 }

-// The io.ReadCloser returned by RangeReaderIF.RangeRead keeps the file's method signatures.
-func GetRangeReaderFromMFile(size int64, file model.File) model.RangeReaderIF {
+func GetRangeReaderFromMFile(size int64, file model.File) *model.FileRangeReader {
 	return &model.FileRangeReader{
 		RangeReaderIF: RangeReaderFunc(func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) {
 			length := httpRange.Length