From 103abc942e06e09ed4dc1f87a73c7b15fb2d486c Mon Sep 17 00:00:00 2001 From: j2rong4cn <36783515+j2rong4cn@users.noreply.github.com> Date: Mon, 30 Jun 2025 15:48:05 +0800 Subject: [PATCH] refactor: pass `api_url` through context (#457) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * refactor: pass `api_url` through context * 移除 LinkArgs.HttpReq * pref(alias): 减少不必要下载代理 * 修复bug * net: 支持1并发 分片下载 --- drivers/alias/util.go | 47 +++---- drivers/crypt/driver.go | 2 +- drivers/local/driver.go | 2 +- drivers/netease_music/driver.go | 2 +- drivers/netease_music/types.go | 4 +- internal/authn/authn.go | 6 +- internal/conf/const.go | 1 + internal/fs/archive.go | 8 +- internal/fs/copy.go | 10 +- internal/fs/link.go | 5 +- internal/fs/move.go | 152 +++++++++++---------- internal/model/args.go | 1 - internal/net/request.go | 2 +- internal/offline_download/tool/download.go | 4 +- internal/offline_download/tool/transfer.go | 4 +- internal/stream/util.go | 2 +- internal/task/base.go | 45 +++--- internal/task/manager.go | 4 +- server/common/base.go | 11 +- server/common/common.go | 9 -- server/common/proxy.go | 12 +- server/handles/archive.go | 28 ++-- server/handles/down.go | 8 +- server/handles/fsmanage.go | 16 +-- server/handles/fsread.go | 3 +- server/handles/ssologin.go | 16 +-- server/handles/webauthn.go | 8 +- server/middlewares/check.go | 7 +- server/webdav.go | 6 +- server/webdav/webdav.go | 6 +- 30 files changed, 209 insertions(+), 222 deletions(-) diff --git a/drivers/alias/util.go b/drivers/alias/util.go index 2c609ddb..631feaab 100644 --- a/drivers/alias/util.go +++ b/drivers/alias/util.go @@ -103,7 +103,12 @@ func (d *Alias) link(ctx context.Context, dst, sub string, args model.LinkArgs) if err != nil { return nil, err } - if _, ok := storage.(*Alias); !ok && !args.Redirect { + useRawLink := len(common.GetApiUrl(ctx)) == 0 // ftp、s3 + if !useRawLink { + _, ok := storage.(*Alias) + useRawLink = !ok && !args.Redirect + } + if useRawLink { link, _, err := op.Link(ctx, storage, reqActualPath, args) return link, err } @@ -114,13 +119,10 @@ func (d *Alias) link(ctx context.Context, dst, sub string, args model.LinkArgs) if common.ShouldProxy(storage, stdpath.Base(sub)) { link := &model.Link{ URL: fmt.Sprintf("%s/p%s?sign=%s", - common.GetApiUrl(args.HttpReq), + common.GetApiUrl(ctx), utils.EncodePath(reqPath, true), sign.Sign(reqPath)), } - if args.HttpReq != nil && d.ProxyRange { - link.RangeReadCloser = common.NoProxyRange - } return link, nil } link, _, err := op.Link(ctx, storage, reqActualPath, args) @@ -201,31 +203,24 @@ func (d *Alias) extract(ctx context.Context, dst, sub string, args model.Archive if err != nil { return nil, err } - if _, ok := storage.(driver.ArchiveReader); ok { - if _, ok := storage.(*Alias); !ok && !args.Redirect { - link, _, err := op.DriverExtract(ctx, storage, reqActualPath, args) - return link, err - } + if _, ok := storage.(driver.ArchiveReader); !ok { + return nil, errs.NotImplement + } + if args.Redirect && common.ShouldProxy(storage, stdpath.Base(sub)) { _, err = fs.Get(ctx, reqPath, &fs.GetArgs{NoLog: true}) if err != nil { return nil, err } - if common.ShouldProxy(storage, stdpath.Base(sub)) { - link := &model.Link{ - URL: fmt.Sprintf("%s/ap%s?inner=%s&pass=%s&sign=%s", - common.GetApiUrl(args.HttpReq), - utils.EncodePath(reqPath, true), - utils.EncodePath(args.InnerPath, true), - url.QueryEscape(args.Password), - sign.SignArchive(reqPath)), - } - if args.HttpReq != nil && d.ProxyRange { - link.RangeReadCloser = common.NoProxyRange - } - return link, nil + link := &model.Link{ + URL: fmt.Sprintf("%s/ap%s?inner=%s&pass=%s&sign=%s", + common.GetApiUrl(ctx), + utils.EncodePath(reqPath, true), + utils.EncodePath(args.InnerPath, true), + url.QueryEscape(args.Password), + sign.SignArchive(reqPath)), } - link, _, err := op.DriverExtract(ctx, storage, reqActualPath, args) - return link, err + return link, nil } - return nil, errs.NotImplement + link, _, err := op.DriverExtract(ctx, storage, reqActualPath, args) + return link, err } diff --git a/drivers/crypt/driver.go b/drivers/crypt/driver.go index 250e8ea8..cc76e3d8 100644 --- a/drivers/crypt/driver.go +++ b/drivers/crypt/driver.go @@ -163,7 +163,7 @@ func (d *Crypt) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([ if d.Thumbnail && thumb == "" { thumbPath := stdpath.Join(args.ReqPath, ".thumbnails", name+".webp") thumb = fmt.Sprintf("%s/d%s?sign=%s", - common.GetApiUrl(common.GetHttpReq(ctx)), + common.GetApiUrl(ctx), utils.EncodePath(thumbPath, true), sign.Sign(thumbPath)) } diff --git a/drivers/local/driver.go b/drivers/local/driver.go index 956f3cb3..b48a4e66 100644 --- a/drivers/local/driver.go +++ b/drivers/local/driver.go @@ -139,7 +139,7 @@ func (d *Local) FileInfoToObj(ctx context.Context, f fs.FileInfo, reqPath string if d.Thumbnail { typeName := utils.GetFileType(f.Name()) if typeName == conf.IMAGE || typeName == conf.VIDEO { - thumb = common.GetApiUrl(common.GetHttpReq(ctx)) + stdpath.Join("/d", reqPath, f.Name()) + thumb = common.GetApiUrl(ctx) + stdpath.Join("/d", reqPath, f.Name()) thumb = utils.EncodePath(thumb, true) thumb += "?type=thumb&sign=" + sign.Sign(stdpath.Join(reqPath, f.Name())) } diff --git a/drivers/netease_music/driver.go b/drivers/netease_music/driver.go index 54169d6f..01f5cbaa 100644 --- a/drivers/netease_music/driver.go +++ b/drivers/netease_music/driver.go @@ -76,7 +76,7 @@ func (d *NeteaseMusic) Link(ctx context.Context, file model.Obj, args model.Link if args.Type == "parsed" { return lrc.getLyricLink(), nil } else { - return lrc.getProxyLink(args), nil + return lrc.getProxyLink(ctx), nil } } diff --git a/drivers/netease_music/types.go b/drivers/netease_music/types.go index b9b5307c..bbfb2717 100644 --- a/drivers/netease_music/types.go +++ b/drivers/netease_music/types.go @@ -48,8 +48,8 @@ type LyricObj struct { lyric string } -func (lrc *LyricObj) getProxyLink(args model.LinkArgs) *model.Link { - rawURL := common.GetApiUrl(args.HttpReq) + "/p" + lrc.Path +func (lrc *LyricObj) getProxyLink(ctx context.Context) *model.Link { + rawURL := common.GetApiUrl(ctx) + "/p" + lrc.Path rawURL = utils.EncodePath(rawURL, true) + "?type=parsed&sign=" + sign.Sign(lrc.Path) return &model.Link{URL: rawURL} } diff --git a/internal/authn/authn.go b/internal/authn/authn.go index d3bc4e62..7c317ca8 100644 --- a/internal/authn/authn.go +++ b/internal/authn/authn.go @@ -2,17 +2,17 @@ package authn import ( "fmt" - "net/http" "net/url" "github.com/OpenListTeam/OpenList/internal/conf" "github.com/OpenListTeam/OpenList/internal/setting" "github.com/OpenListTeam/OpenList/server/common" + "github.com/gin-gonic/gin" "github.com/go-webauthn/webauthn/webauthn" ) -func NewAuthnInstance(r *http.Request) (*webauthn.WebAuthn, error) { - siteUrl, err := url.Parse(common.GetApiUrl(r)) +func NewAuthnInstance(c *gin.Context) (*webauthn.WebAuthn, error) { + siteUrl, err := url.Parse(common.GetApiUrl(c)) if err != nil { return nil, err } diff --git a/internal/conf/const.go b/internal/conf/const.go index 0e51455b..e36ba16e 100644 --- a/internal/conf/const.go +++ b/internal/conf/const.go @@ -148,4 +148,5 @@ const ( // ContextKey is the type of context keys. const ( NoTaskKey = "no_task" + ApiUrlKey = "api_url" ) diff --git a/internal/fs/archive.go b/internal/fs/archive.go index 24651ee7..18b1f1c0 100644 --- a/internal/fs/archive.go +++ b/internal/fs/archive.go @@ -49,7 +49,9 @@ func (t *ArchiveDownloadTask) GetStatus() string { } func (t *ArchiveDownloadTask) Run() error { - t.ReinitCtx() + if err := t.ReinitCtx(); err != nil { + return err + } t.ClearEndTime() t.SetStartTime(time.Now()) defer func() { t.SetEndTime(time.Now()) }() @@ -152,7 +154,9 @@ func (t *ArchiveContentUploadTask) GetStatus() string { } func (t *ArchiveContentUploadTask) Run() error { - t.ReinitCtx() + if err := t.ReinitCtx(); err != nil { + return err + } t.ClearEndTime() t.SetStartTime(time.Now()) defer func() { t.SetEndTime(time.Now()) }() diff --git a/internal/fs/copy.go b/internal/fs/copy.go index 3cc3c186..b1e07797 100644 --- a/internal/fs/copy.go +++ b/internal/fs/copy.go @@ -7,15 +7,15 @@ import ( stdpath "path" "time" - "github.com/OpenListTeam/OpenList/internal/errs" - "github.com/OpenListTeam/OpenList/internal/conf" "github.com/OpenListTeam/OpenList/internal/driver" + "github.com/OpenListTeam/OpenList/internal/errs" "github.com/OpenListTeam/OpenList/internal/model" "github.com/OpenListTeam/OpenList/internal/op" "github.com/OpenListTeam/OpenList/internal/stream" "github.com/OpenListTeam/OpenList/internal/task" "github.com/OpenListTeam/OpenList/pkg/utils" + "github.com/OpenListTeam/OpenList/server/common" "github.com/pkg/errors" "github.com/xhofe/tache" ) @@ -40,7 +40,9 @@ func (t *CopyTask) GetStatus() string { } func (t *CopyTask) Run() error { - t.ReinitCtx() + if err := t.ReinitCtx(); err != nil { + return err + } t.ClearEndTime() t.SetStartTime(time.Now()) defer func() { t.SetEndTime(time.Now()) }() @@ -107,6 +109,7 @@ func _copy(ctx context.Context, srcObjPath, dstDirPath string, lazyCache ...bool t := &CopyTask{ TaskExtension: task.TaskExtension{ Creator: taskCreator, + ApiUrl: common.GetApiUrl(ctx), }, srcStorage: srcStorage, dstStorage: dstStorage, @@ -140,6 +143,7 @@ func copyBetween2Storages(t *CopyTask, srcStorage, dstStorage driver.Driver, src CopyTaskManager.Add(&CopyTask{ TaskExtension: task.TaskExtension{ Creator: t.GetCreator(), + ApiUrl: t.ApiUrl, }, srcStorage: srcStorage, dstStorage: dstStorage, diff --git a/internal/fs/link.go b/internal/fs/link.go index ea9de7ff..e93f5507 100644 --- a/internal/fs/link.go +++ b/internal/fs/link.go @@ -7,7 +7,6 @@ import ( "github.com/OpenListTeam/OpenList/internal/model" "github.com/OpenListTeam/OpenList/internal/op" "github.com/OpenListTeam/OpenList/server/common" - "github.com/gin-gonic/gin" "github.com/pkg/errors" ) @@ -21,9 +20,7 @@ func link(ctx context.Context, path string, args model.LinkArgs) (*model.Link, m return nil, nil, errors.WithMessage(err, "failed link") } if l.URL != "" && !strings.HasPrefix(l.URL, "http://") && !strings.HasPrefix(l.URL, "https://") { - if c, ok := ctx.(*gin.Context); ok { - l.URL = common.GetApiUrl(c.Request) + l.URL - } + l.URL = common.GetApiUrl(ctx) + l.URL } return l, obj, nil } diff --git a/internal/fs/move.go b/internal/fs/move.go index 15e04f5c..bd06ab0a 100644 --- a/internal/fs/move.go +++ b/internal/fs/move.go @@ -15,26 +15,27 @@ import ( "github.com/OpenListTeam/OpenList/internal/stream" "github.com/OpenListTeam/OpenList/internal/task" "github.com/OpenListTeam/OpenList/pkg/utils" + "github.com/OpenListTeam/OpenList/server/common" "github.com/pkg/errors" "github.com/xhofe/tache" ) type MoveTask struct { task.TaskExtension - Status string `json:"-"` - SrcObjPath string `json:"src_path"` - DstDirPath string `json:"dst_path"` - srcStorage driver.Driver `json:"-"` - dstStorage driver.Driver `json:"-"` - SrcStorageMp string `json:"src_storage_mp"` - DstStorageMp string `json:"dst_storage_mp"` - IsRootTask bool `json:"is_root_task"` - RootTaskID string `json:"root_task_id"` - TotalFiles int `json:"total_files"` - CompletedFiles int `json:"completed_files"` - Phase string `json:"phase"` // "copying", "verifying", "deleting", "completed" - ValidateExistence bool `json:"validate_existence"` - mu sync.RWMutex `json:"-"` + Status string `json:"-"` + SrcObjPath string `json:"src_path"` + DstDirPath string `json:"dst_path"` + srcStorage driver.Driver `json:"-"` + dstStorage driver.Driver `json:"-"` + SrcStorageMp string `json:"src_storage_mp"` + DstStorageMp string `json:"dst_storage_mp"` + IsRootTask bool `json:"is_root_task"` + RootTaskID string `json:"root_task_id"` + TotalFiles int `json:"total_files"` + CompletedFiles int `json:"completed_files"` + Phase string `json:"phase"` // "copying", "verifying", "deleting", "completed" + ValidateExistence bool `json:"validate_existence"` + mu sync.RWMutex `json:"-"` } type MoveProgress struct { @@ -62,11 +63,11 @@ func (t *MoveTask) GetStatus() string { func (t *MoveTask) GetProgress() float64 { t.mu.RLock() defer t.mu.RUnlock() - + if t.TotalFiles == 0 { return 0 } - + switch t.Phase { case "copying": return float64(t.CompletedFiles*60) / float64(t.TotalFiles) @@ -84,9 +85,9 @@ func (t *MoveTask) GetProgress() float64 { func (t *MoveTask) GetMoveProgress() *MoveProgress { t.mu.RLock() defer t.mu.RUnlock() - + progress := int(t.GetProgress()) - + return &MoveProgress{ TaskID: t.GetID(), Phase: t.Phase, @@ -106,16 +107,18 @@ func (t *MoveTask) updateProgress() { } func (t *MoveTask) Run() error { - t.ReinitCtx() + if err := t.ReinitCtx(); err != nil { + return err + } t.ClearEndTime() t.SetStartTime(time.Now()) - defer func() { + defer func() { t.SetEndTime(time.Now()) if t.IsRootTask { moveProgressMap.Delete(t.GetID()) } }() - + var err error if t.srcStorage == nil { t.srcStorage, err = op.GetStorageByMountPath(t.SrcStorageMp) @@ -131,13 +134,13 @@ func (t *MoveTask) Run() error { t.mu.Lock() t.Status = "validating source and destination" t.mu.Unlock() - + // Check if source exists srcObj, err := op.Get(t.Ctx(), t.srcStorage, t.SrcObjPath) if err != nil { return errors.WithMessagef(err, "source file [%s] not found", stdpath.Base(t.SrcObjPath)) } - + // Check if destination already exists (if validation is required) if t.ValidateExistence { dstFilePath := stdpath.Join(t.DstDirPath, srcObj.GetName()) @@ -155,7 +158,7 @@ func (t *MoveTask) Run() error { t.mu.Unlock() return t.runRootMoveTask() } - + // Use safe move logic for files return t.safeMoveOperation(srcObj) } @@ -167,7 +170,7 @@ func (t *MoveTask) runRootMoveTask() error { if err != nil { return errors.WithMessagef(err, "failed get src [%s] object", t.SrcObjPath) } - + if !srcObj.IsDir() { // Source is not a directory, use regular move logic t.mu.Lock() @@ -175,32 +178,32 @@ func (t *MoveTask) runRootMoveTask() error { t.mu.Unlock() return t.safeMoveOperation(srcObj) } - + // Phase 1: Count total files and create directory structure t.mu.Lock() t.Phase = "preparing" t.Status = "counting files and preparing directory structure" t.mu.Unlock() t.updateProgress() - + totalFiles, err := t.countFilesAndCreateDirs(t.srcStorage, t.dstStorage, t.SrcObjPath, t.DstDirPath) if err != nil { return errors.WithMessage(err, "failed to prepare directory structure") } - + t.mu.Lock() t.TotalFiles = totalFiles t.Phase = "copying" t.Status = "copying files" t.mu.Unlock() t.updateProgress() - + // Phase 2: Copy all files err = t.copyAllFiles(t.srcStorage, t.dstStorage, t.SrcObjPath, t.DstDirPath) if err != nil { return errors.WithMessage(err, "failed to copy files") } - + // Phase 3: Verify directory structure t.mu.Lock() t.Phase = "verifying" @@ -208,12 +211,12 @@ func (t *MoveTask) runRootMoveTask() error { t.CompletedFiles = 0 t.mu.Unlock() t.updateProgress() - + err = t.verifyDirectoryStructure(t.srcStorage, t.dstStorage, t.SrcObjPath, t.DstDirPath) if err != nil { return errors.WithMessage(err, "verification failed") } - + // Phase 4: Delete source files and directories t.mu.Lock() t.Phase = "deleting" @@ -221,18 +224,18 @@ func (t *MoveTask) runRootMoveTask() error { t.CompletedFiles = 0 t.mu.Unlock() t.updateProgress() - + err = t.deleteSourceRecursively(t.srcStorage, t.SrcObjPath) if err != nil { return errors.WithMessage(err, "failed to delete source files") } - + t.mu.Lock() t.Phase = "completed" t.Status = "completed" t.mu.Unlock() t.updateProgress() - + return nil } @@ -257,11 +260,11 @@ func (t *MoveTask) countFilesAndCreateDirs(srcStorage, dstStorage driver.Driver, if err != nil { return 0, errors.WithMessagef(err, "failed get src [%s] object", srcPath) } - + if !srcObj.IsDir() { return 1, nil } - + // Create destination directory dstObjPath := stdpath.Join(dstPath, srcObj.GetName()) err = op.MakeDir(t.Ctx(), dstStorage, dstObjPath) @@ -271,13 +274,13 @@ func (t *MoveTask) countFilesAndCreateDirs(srcStorage, dstStorage driver.Driver, } return 0, errors.WithMessagef(err, "failed to create destination directory [%s] in storage [%s]", dstObjPath, dstStorage.GetStorage().MountPath) } - + // List and count files recursively objs, err := op.List(t.Ctx(), srcStorage, srcPath, model.ListArgs{}) if err != nil { return 0, errors.WithMessagef(err, "failed list src [%s] objs", srcPath) } - + totalFiles := 0 for _, obj := range objs { if utils.IsCanceled(t.Ctx()) { @@ -290,7 +293,7 @@ func (t *MoveTask) countFilesAndCreateDirs(srcStorage, dstStorage driver.Driver, } totalFiles += subCount } - + return totalFiles, nil } @@ -300,27 +303,27 @@ func (t *MoveTask) copyAllFiles(srcStorage, dstStorage driver.Driver, srcPath, d if err != nil { return errors.WithMessagef(err, "failed get src [%s] object", srcPath) } - + if !srcObj.IsDir() { // Copy single file err := t.copyFile(srcStorage, dstStorage, srcPath, dstPath) if err != nil { return err } - + t.mu.Lock() t.CompletedFiles++ t.mu.Unlock() t.updateProgress() return nil } - + // Copy directory contents objs, err := op.List(t.Ctx(), srcStorage, srcPath, model.ListArgs{}) if err != nil { return errors.WithMessagef(err, "failed list src [%s] objs", srcPath) } - + dstObjPath := stdpath.Join(dstPath, srcObj.GetName()) for _, obj := range objs { if utils.IsCanceled(t.Ctx()) { @@ -332,7 +335,7 @@ func (t *MoveTask) copyAllFiles(srcStorage, dstStorage driver.Driver, srcPath, d return err } } - + return nil } @@ -342,24 +345,24 @@ func (t *MoveTask) copyFile(srcStorage, dstStorage driver.Driver, srcFilePath, d if err != nil { return errors.WithMessagef(err, "failed get src [%s] file", srcFilePath) } - + link, _, err := op.Link(t.Ctx(), srcStorage, srcFilePath, model.LinkArgs{ Header: http.Header{}, }) if err != nil { return errors.WithMessagef(err, "failed get [%s] link", srcFilePath) } - + fs := stream.FileStream{ Obj: srcFile, Ctx: t.Ctx(), } - + ss, err := stream.NewSeekableStream(fs, link) if err != nil { return errors.WithMessagef(err, "failed get [%s] stream", srcFilePath) } - + return op.Put(t.Ctx(), dstStorage, dstDirPath, ss, nil, true) } @@ -369,7 +372,7 @@ func (t *MoveTask) verifyDirectoryStructure(srcStorage, dstStorage driver.Driver if err != nil { return errors.WithMessagef(err, "failed get src [%s] object", srcPath) } - + if !srcObj.IsDir() { // Verify single file dstFilePath := stdpath.Join(dstPath, srcObj.GetName()) @@ -377,27 +380,27 @@ func (t *MoveTask) verifyDirectoryStructure(srcStorage, dstStorage driver.Driver if err != nil { return errors.WithMessagef(err, "verification failed: destination file [%s] not found", dstFilePath) } - + t.mu.Lock() t.CompletedFiles++ t.mu.Unlock() t.updateProgress() return nil } - + // Verify directory dstObjPath := stdpath.Join(dstPath, srcObj.GetName()) _, err = op.Get(t.Ctx(), dstStorage, dstObjPath) if err != nil { return errors.WithMessagef(err, "verification failed: destination directory [%s] not found", dstObjPath) } - + // Verify directory contents srcObjs, err := op.List(t.Ctx(), srcStorage, srcPath, model.ListArgs{}) if err != nil { return errors.WithMessagef(err, "failed list src [%s] objs for verification", srcPath) } - + for _, obj := range srcObjs { if utils.IsCanceled(t.Ctx()) { return nil @@ -408,7 +411,7 @@ func (t *MoveTask) verifyDirectoryStructure(srcStorage, dstStorage driver.Driver return err } } - + return nil } @@ -418,27 +421,27 @@ func (t *MoveTask) deleteSourceRecursively(srcStorage driver.Driver, srcPath str if err != nil { return errors.WithMessagef(err, "failed get src [%s] object for deletion", srcPath) } - + if !srcObj.IsDir() { // Delete single file err := op.Remove(t.Ctx(), srcStorage, srcPath) if err != nil { return errors.WithMessagef(err, "failed to delete src [%s] file", srcPath) } - + t.mu.Lock() t.CompletedFiles++ t.mu.Unlock() t.updateProgress() return nil } - + // Delete directory contents first objs, err := op.List(t.Ctx(), srcStorage, srcPath, model.ListArgs{}) if err != nil { return errors.WithMessagef(err, "failed list src [%s] objs for deletion", srcPath) } - + for _, obj := range objs { if utils.IsCanceled(t.Ctx()) { return nil @@ -449,13 +452,13 @@ func (t *MoveTask) deleteSourceRecursively(srcStorage driver.Driver, srcPath str return err } } - + // Delete the directory itself err = op.Remove(t.Ctx(), srcStorage, srcPath) if err != nil { return errors.WithMessagef(err, "failed to delete src [%s] directory", srcPath) } - + return nil } @@ -465,14 +468,14 @@ func moveBetween2Storages(t *MoveTask, srcStorage, dstStorage driver.Driver, src if err != nil { return errors.WithMessagef(err, "failed get src [%s] file", srcObjPath) } - + if srcObj.IsDir() { t.Status = "src object is dir, listing objs" objs, err := op.List(t.Ctx(), srcStorage, srcObjPath, model.ListArgs{}) if err != nil { return errors.WithMessagef(err, "failed list src [%s] objs", srcObjPath) } - + dstObjPath := stdpath.Join(dstDirPath, srcObj.GetName()) t.Status = "creating destination directory" err = op.MakeDir(t.Ctx(), dstStorage, dstObjPath) @@ -483,7 +486,7 @@ func moveBetween2Storages(t *MoveTask, srcStorage, dstStorage driver.Driver, src } return errors.WithMessagef(err, "failed to create destination directory [%s] in storage [%s]", dstObjPath, dstStorage.GetStorage().MountPath) } - + for _, obj := range objs { if utils.IsCanceled(t.Ctx()) { return nil @@ -492,6 +495,7 @@ func moveBetween2Storages(t *MoveTask, srcStorage, dstStorage driver.Driver, src MoveTaskManager.Add(&MoveTask{ TaskExtension: task.TaskExtension{ Creator: t.GetCreator(), + ApiUrl: t.ApiUrl, }, srcStorage: srcStorage, dstStorage: dstStorage, @@ -515,13 +519,13 @@ func moveBetween2Storages(t *MoveTask, srcStorage, dstStorage driver.Driver, src } } - func moveFileBetween2Storages(tsk *MoveTask, srcStorage, dstStorage driver.Driver, srcFilePath, dstDirPath string) error { tsk.Status = "copying file to destination" - + copyTask := &CopyTask{ TaskExtension: task.TaskExtension{ Creator: tsk.GetCreator(), + ApiUrl: tsk.ApiUrl, }, srcStorage: srcStorage, dstStorage: dstStorage, @@ -530,10 +534,8 @@ func moveFileBetween2Storages(tsk *MoveTask, srcStorage, dstStorage driver.Drive SrcStorageMp: srcStorage.GetStorage().MountPath, DstStorageMp: dstStorage.GetStorage().MountPath, } - copyTask.SetCtx(tsk.Ctx()) - err := copyBetween2Storages(copyTask, srcStorage, dstStorage, srcFilePath, dstDirPath) if err != nil { @@ -543,21 +545,20 @@ func moveFileBetween2Storages(tsk *MoveTask, srcStorage, dstStorage driver.Drive } return errors.WithMessagef(err, "failed to copy [%s] to destination storage [%s]", srcFilePath, dstStorage.GetStorage().MountPath) } - + tsk.SetProgress(50) - + tsk.Status = "deleting source file" err = op.Remove(tsk.Ctx(), srcStorage, srcFilePath) if err != nil { return errors.WithMessagef(err, "failed to delete src [%s] file from storage [%s] after successful copy", srcFilePath, srcStorage.GetStorage().MountPath) } - + tsk.SetProgress(100) tsk.Status = "completed" return nil } - // safeMoveOperation ensures copy-then-delete sequence for safe move operations func (t *MoveTask) safeMoveOperation(srcObj model.Obj) error { if srcObj.IsDir() { @@ -592,12 +593,13 @@ func _moveWithValidation(ctx context.Context, srcObjPath, dstDirPath string, val } taskCreator, _ := ctx.Value("user").(*model.User) - + // Create task immediately without any synchronous checks to avoid blocking frontend // All validation and type checking will be done asynchronously in the Run method t := &MoveTask{ TaskExtension: task.TaskExtension{ Creator: taskCreator, + ApiUrl: common.GetApiUrl(ctx), }, srcStorage: srcStorage, dstStorage: dstStorage, @@ -608,7 +610,7 @@ func _moveWithValidation(ctx context.Context, srcObjPath, dstDirPath string, val ValidateExistence: validateExistence, Phase: "initializing", } - + MoveTaskManager.Add(t) return t, nil -} \ No newline at end of file +} diff --git a/internal/model/args.go b/internal/model/args.go index 343b8a21..6fb3b8a6 100644 --- a/internal/model/args.go +++ b/internal/model/args.go @@ -20,7 +20,6 @@ type LinkArgs struct { IP string Header http.Header Type string - HttpReq *http.Request Redirect bool } diff --git a/internal/net/request.go b/internal/net/request.go index 59778593..788bf172 100644 --- a/internal/net/request.go +++ b/internal/net/request.go @@ -171,7 +171,7 @@ func (d *downloader) download() (io.ReadCloser, error) { log.Debugf("cfgConcurrency:%d", d.cfg.Concurrency) - if d.cfg.Concurrency == 1 { + if maxPart == 1 { if d.cfg.ConcurrencyLimit != nil { go func() { <-d.ctx.Done() diff --git a/internal/offline_download/tool/download.go b/internal/offline_download/tool/download.go index ff418a2a..bee0c5ca 100644 --- a/internal/offline_download/tool/download.go +++ b/internal/offline_download/tool/download.go @@ -28,7 +28,9 @@ type DownloadTask struct { } func (t *DownloadTask) Run() error { - t.ReinitCtx() + if err := t.ReinitCtx(); err != nil { + return err + } t.ClearEndTime() t.SetStartTime(time.Now()) defer func() { t.SetEndTime(time.Now()) }() diff --git a/internal/offline_download/tool/transfer.go b/internal/offline_download/tool/transfer.go index 6cc0cbb7..3f389953 100644 --- a/internal/offline_download/tool/transfer.go +++ b/internal/offline_download/tool/transfer.go @@ -33,7 +33,9 @@ type TransferTask struct { } func (t *TransferTask) Run() error { - t.ReinitCtx() + if err := t.ReinitCtx(); err != nil { + return err + } t.ClearEndTime() t.SetStartTime(time.Now()) defer func() { t.SetEndTime(time.Now()) }() diff --git a/internal/stream/util.go b/internal/stream/util.go index 1360b688..ee7d291f 100644 --- a/internal/stream/util.go +++ b/internal/stream/util.go @@ -19,7 +19,7 @@ func GetRangeReadCloserFromLink(size int64, link *model.Link) (model.RangeReadCl return nil, fmt.Errorf("can't create RangeReadCloser since URL is empty in link") } rangeReaderFunc := func(ctx context.Context, r http_range.Range) (io.ReadCloser, error) { - if link.Concurrency != 0 || link.PartSize != 0 { + if link.Concurrency > 0 || link.PartSize > 0 { header := net.ProcessHeader(nil, link.Header) down := net.NewDownloader(func(d *net.Downloader) { d.Concurrency = link.Concurrency diff --git a/internal/task/base.go b/internal/task/base.go index 42bb2635..5e09be2b 100644 --- a/internal/task/base.go +++ b/internal/task/base.go @@ -2,7 +2,6 @@ package task import ( "context" - "sync" "time" "github.com/OpenListTeam/OpenList/internal/conf" @@ -12,12 +11,21 @@ import ( type TaskExtension struct { tache.Base - ctx context.Context - ctxInitMutex sync.Mutex - Creator *model.User - startTime *time.Time - endTime *time.Time - totalBytes int64 + Creator *model.User + startTime *time.Time + endTime *time.Time + totalBytes int64 + ApiUrl string +} + +func (t *TaskExtension) SetCtx(ctx context.Context) { + if t.Creator != nil { + ctx = context.WithValue(ctx, "user", t.Creator) + } + if len(t.ApiUrl) > 0 { + ctx = context.WithValue(ctx, conf.ApiUrlKey, t.ApiUrl) + } + t.Base.SetCtx(ctx) } func (t *TaskExtension) SetCreator(creator *model.User) { @@ -57,29 +65,18 @@ func (t *TaskExtension) GetTotalBytes() int64 { return t.totalBytes } -func (t *TaskExtension) Ctx() context.Context { - if t.ctx == nil { - t.ctxInitMutex.Lock() - if t.ctx == nil { - t.ctx = context.WithValue(t.Base.Ctx(), "user", t.Creator) - } - t.ctxInitMutex.Unlock() - } - return t.ctx -} - -func (t *TaskExtension) ReinitCtx() { - if !conf.Conf.Tasks.AllowRetryCanceled { - return - } +func (t *TaskExtension) ReinitCtx() error { select { - case <-t.Base.Ctx().Done(): + case <-t.Ctx().Done(): + if !conf.Conf.Tasks.AllowRetryCanceled { + return t.Ctx().Err() + } ctx, cancel := context.WithCancel(context.Background()) t.SetCtx(ctx) t.SetCancelFunc(cancel) - t.ctx = nil default: } + return nil } type TaskExtensionInfo interface { diff --git a/internal/task/manager.go b/internal/task/manager.go index 3caa685a..09867b98 100644 --- a/internal/task/manager.go +++ b/internal/task/manager.go @@ -1,6 +1,8 @@ package task -import "github.com/xhofe/tache" +import ( + "github.com/xhofe/tache" +) type Manager[T tache.Task] interface { Add(task T) diff --git a/server/common/base.go b/server/common/base.go index 9317bf6d..3d8b22a5 100644 --- a/server/common/base.go +++ b/server/common/base.go @@ -1,6 +1,7 @@ package common import ( + "context" "fmt" "net/http" stdpath "path" @@ -9,7 +10,7 @@ import ( "github.com/OpenListTeam/OpenList/internal/conf" ) -func GetApiUrl(r *http.Request) string { +func GetApiUrlFormRequest(r *http.Request) string { api := conf.Conf.SiteURL if strings.HasPrefix(api, "http") { return strings.TrimSuffix(api, "/") @@ -28,3 +29,11 @@ func GetApiUrl(r *http.Request) string { api = strings.TrimSuffix(api, "/") return api } + +func GetApiUrl(ctx context.Context) string { + val := ctx.Value(conf.ApiUrlKey) + if api, ok := val.(string); ok { + return api + } + return "" +} diff --git a/server/common/common.go b/server/common/common.go index 098ad3c4..f047cc93 100644 --- a/server/common/common.go +++ b/server/common/common.go @@ -1,8 +1,6 @@ package common import ( - "context" - "net/http" "strings" "github.com/OpenListTeam/OpenList/cmd/flags" @@ -90,10 +88,3 @@ func Pluralize(count int, singular, plural string) string { } return plural } - -func GetHttpReq(ctx context.Context) *http.Request { - if c, ok := ctx.(*gin.Context); ok { - return c.Request - } - return nil -} diff --git a/server/common/proxy.go b/server/common/proxy.go index e72df4ed..cf6cc7d4 100644 --- a/server/common/proxy.go +++ b/server/common/proxy.go @@ -15,7 +15,6 @@ import ( "github.com/OpenListTeam/OpenList/internal/stream" "github.com/OpenListTeam/OpenList/pkg/http_range" "github.com/OpenListTeam/OpenList/pkg/utils" - log "github.com/sirupsen/logrus" ) func Proxy(w http.ResponseWriter, r *http.Request, link *model.Link, file model.Obj) error { @@ -42,7 +41,7 @@ func Proxy(w http.ResponseWriter, r *http.Request, link *model.Link, file model. RangeReadCloserIF: link.RangeReadCloser, Limiter: stream.ServerDownloadLimit, }) - } else if link.Concurrency != 0 || link.PartSize != 0 { + } else if link.Concurrency > 0 || link.PartSize > 0 { attachHeader(w, file) size := file.GetSize() rangeReader := func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) { @@ -110,21 +109,16 @@ func GetEtag(file model.Obj) string { return fmt.Sprintf(`"%x-%x"`, file.ModTime().Unix(), file.GetSize()) } -var NoProxyRange = &model.RangeReadCloser{} - -func ProxyRange(link *model.Link, size int64) { +func ProxyRange(ctx context.Context, link *model.Link, size int64) { if link.MFile != nil { return } - if link.RangeReadCloser == nil { + if link.RangeReadCloser == nil && !strings.HasPrefix(link.URL, GetApiUrl(ctx)+"/") { var rrc, err = stream.GetRangeReadCloserFromLink(size, link) if err != nil { - log.Warnf("ProxyRange error: %s", err) return } link.RangeReadCloser = rrc - } else if link.RangeReadCloser == NoProxyRange { - link.RangeReadCloser = nil } } diff --git a/server/handles/archive.go b/server/handles/archive.go index 3876791a..316ab9dd 100644 --- a/server/handles/archive.go +++ b/server/handles/archive.go @@ -101,9 +101,8 @@ func FsArchiveMeta(c *gin.Context) { } archiveArgs := model.ArchiveArgs{ LinkArgs: model.LinkArgs{ - Header: c.Request.Header, - Type: c.Query("type"), - HttpReq: c.Request, + Header: c.Request.Header, + Type: c.Query("type"), }, Password: req.ArchivePass, } @@ -132,7 +131,7 @@ func FsArchiveMeta(c *gin.Context) { IsEncrypted: ret.IsEncrypted(), Content: toContentResp(ret.GetTree()), Sort: ret.Sort, - RawURL: fmt.Sprintf("%s%s%s", common.GetApiUrl(c.Request), api, utils.EncodePath(reqPath, true)), + RawURL: fmt.Sprintf("%s%s%s", common.GetApiUrl(c), api, utils.EncodePath(reqPath, true)), Sign: s, }) } @@ -181,9 +180,8 @@ func FsArchiveList(c *gin.Context) { ArchiveInnerArgs: model.ArchiveInnerArgs{ ArchiveArgs: model.ArchiveArgs{ LinkArgs: model.LinkArgs{ - Header: c.Request.Header, - Type: c.Query("type"), - HttpReq: c.Request, + Header: c.Request.Header, + Type: c.Query("type"), }, Password: req.ArchivePass, }, @@ -266,9 +264,8 @@ func FsArchiveDecompress(c *gin.Context) { ArchiveInnerArgs: model.ArchiveInnerArgs{ ArchiveArgs: model.ArchiveArgs{ LinkArgs: model.LinkArgs{ - Header: c.Request.Header, - Type: c.Query("type"), - HttpReq: c.Request, + Header: c.Request.Header, + Type: c.Query("type"), }, Password: req.ArchivePass, }, @@ -314,7 +311,6 @@ func ArchiveDown(c *gin.Context) { IP: c.ClientIP(), Header: c.Request.Header, Type: c.Query("type"), - HttpReq: c.Request, Redirect: true, }, Password: password, @@ -344,9 +340,8 @@ func ArchiveProxy(c *gin.Context) { link, file, err := fs.ArchiveDriverExtract(c, archiveRawPath, model.ArchiveInnerArgs{ ArchiveArgs: model.ArchiveArgs{ LinkArgs: model.LinkArgs{ - Header: c.Request.Header, - Type: c.Query("type"), - HttpReq: c.Request, + Header: c.Request.Header, + Type: c.Query("type"), }, Password: password, }, @@ -370,9 +365,8 @@ func ArchiveInternalExtract(c *gin.Context) { rc, size, err := fs.ArchiveInternalExtract(c, archiveRawPath, model.ArchiveInnerArgs{ ArchiveArgs: model.ArchiveArgs{ LinkArgs: model.LinkArgs{ - Header: c.Request.Header, - Type: c.Query("type"), - HttpReq: c.Request, + Header: c.Request.Header, + Type: c.Query("type"), }, Password: password, }, diff --git a/server/handles/down.go b/server/handles/down.go index f1851fe4..027c14e3 100644 --- a/server/handles/down.go +++ b/server/handles/down.go @@ -38,7 +38,6 @@ func Down(c *gin.Context) { IP: c.ClientIP(), Header: c.Request.Header, Type: c.Query("type"), - HttpReq: c.Request, Redirect: true, }) if err != nil { @@ -71,9 +70,8 @@ func Proxy(c *gin.Context) { } } link, file, err := fs.Link(c, rawPath, model.LinkArgs{ - Header: c.Request.Header, - Type: c.Query("type"), - HttpReq: c.Request, + Header: c.Request.Header, + Type: c.Query("type"), }) if err != nil { common.ErrorResp(c, err, 500) @@ -126,7 +124,7 @@ func localProxy(c *gin.Context, link *model.Link, file model.Obj, proxyRange boo } } if proxyRange { - common.ProxyRange(link, file.GetSize()) + common.ProxyRange(c, link, file.GetSize()) } Writer := &common.WrittenResponseWriter{ResponseWriter: c.Writer} diff --git a/server/handles/fsmanage.go b/server/handles/fsmanage.go index ac1cb9e2..ed0a27ba 100644 --- a/server/handles/fsmanage.go +++ b/server/handles/fsmanage.go @@ -97,7 +97,7 @@ func FsMove(c *gin.Context) { } } } - + // Create all tasks immediately without any synchronous validation // All validation will be done asynchronously in the background var addedTasks []task.TaskExtensionInfo @@ -111,12 +111,12 @@ func FsMove(c *gin.Context) { return } } - + // Return immediately with task information if len(addedTasks) > 0 { common.SuccessResp(c, gin.H{ "message": fmt.Sprintf("Successfully created %d move task(s)", len(addedTasks)), - "tasks": getTaskInfos(addedTasks), + "tasks": getTaskInfos(addedTasks), }) } else { common.SuccessResp(c, gin.H{ @@ -159,7 +159,7 @@ func FsCopy(c *gin.Context) { } } } - + // Create all tasks immediately without any synchronous validation // All validation will be done asynchronously in the background var addedTasks []task.TaskExtensionInfo @@ -173,12 +173,12 @@ func FsCopy(c *gin.Context) { return } } - + // Return immediately with task information if len(addedTasks) > 0 { common.SuccessResp(c, gin.H{ "message": fmt.Sprintf("Successfully created %d copy task(s)", len(addedTasks)), - "tasks": getTaskInfos(addedTasks), + "tasks": getTaskInfos(addedTasks), }) } else { common.SuccessResp(c, gin.H{ @@ -379,13 +379,13 @@ func Link(c *gin.Context) { if storage.Config().OnlyLocal { common.SuccessResp(c, model.Link{ URL: fmt.Sprintf("%s/p%s?d&sign=%s", - common.GetApiUrl(c.Request), + common.GetApiUrl(c), utils.EncodePath(rawPath, true), sign.Sign(rawPath)), }) return } - link, _, err := fs.Link(c, rawPath, model.LinkArgs{IP: c.ClientIP(), Header: c.Request.Header, HttpReq: c.Request}) + link, _, err := fs.Link(c, rawPath, model.LinkArgs{IP: c.ClientIP(), Header: c.Request.Header}) if err != nil { common.ErrorResp(c, err, 500) return diff --git a/server/handles/fsread.go b/server/handles/fsread.go index 5ce06511..42143845 100644 --- a/server/handles/fsread.go +++ b/server/handles/fsread.go @@ -296,7 +296,7 @@ func FsGet(c *gin.Context) { sign.Sign(reqPath)) } else { rawURL = fmt.Sprintf("%s/p%s%s", - common.GetApiUrl(c.Request), + common.GetApiUrl(c), utils.EncodePath(reqPath, true), query) } @@ -309,7 +309,6 @@ func FsGet(c *gin.Context) { link, _, err := fs.Link(c, reqPath, model.LinkArgs{ IP: c.ClientIP(), Header: c.Request.Header, - HttpReq: c.Request, Redirect: true, }) if err != nil { diff --git a/server/handles/ssologin.go b/server/handles/ssologin.go index 428f1291..f585719b 100644 --- a/server/handles/ssologin.go +++ b/server/handles/ssologin.go @@ -48,9 +48,9 @@ func verifyState(clientID, ip, state string) bool { func ssoRedirectUri(c *gin.Context, useCompatibility bool, method string) string { if useCompatibility { - return common.GetApiUrl(c.Request) + "/api/auth/" + method + return common.GetApiUrl(c) + "/api/auth/" + method } else { - return common.GetApiUrl(c.Request) + "/api/auth/sso_callback" + "?method=" + method + return common.GetApiUrl(c) + "/api/auth/sso_callback" + "?method=" + method } } @@ -236,7 +236,7 @@ func OIDCLoginCallback(c *gin.Context) { } if method == "get_sso_id" { if useCompatibility { - c.Redirect(302, common.GetApiUrl(c.Request)+"/@manage?sso_id="+userID) + c.Redirect(302, common.GetApiUrl(c)+"/@manage?sso_id="+userID) return } html := fmt.Sprintf(` @@ -263,7 +263,7 @@ func OIDCLoginCallback(c *gin.Context) { common.ErrorResp(c, err, 400) } if useCompatibility { - c.Redirect(302, common.GetApiUrl(c.Request)+"/@login?token="+token) + c.Redirect(302, common.GetApiUrl(c)+"/@login?token="+token) return } html := fmt.Sprintf(` @@ -364,9 +364,9 @@ func SSOLoginCallback(c *gin.Context) { } else { var redirect_uri string if usecompatibility { - redirect_uri = common.GetApiUrl(c.Request) + "/api/auth/" + argument + redirect_uri = common.GetApiUrl(c) + "/api/auth/" + argument } else { - redirect_uri = common.GetApiUrl(c.Request) + "/api/auth/sso_callback" + "?method=" + argument + redirect_uri = common.GetApiUrl(c) + "/api/auth/sso_callback" + "?method=" + argument } resp, err = ssoClient.R().SetHeader("Accept", "application/json"). SetFormData(map[string]string{ @@ -401,7 +401,7 @@ func SSOLoginCallback(c *gin.Context) { } if argument == "get_sso_id" { if usecompatibility { - c.Redirect(302, common.GetApiUrl(c.Request)+"/@manage?sso_id="+userID) + c.Redirect(302, common.GetApiUrl(c)+"/@manage?sso_id="+userID) return } html := fmt.Sprintf(` @@ -429,7 +429,7 @@ func SSOLoginCallback(c *gin.Context) { common.ErrorResp(c, err, 400) } if usecompatibility { - c.Redirect(302, common.GetApiUrl(c.Request)+"/@login?token="+token) + c.Redirect(302, common.GetApiUrl(c)+"/@login?token="+token) return } html := fmt.Sprintf(` diff --git a/server/handles/webauthn.go b/server/handles/webauthn.go index 0afa2803..1d9374bb 100644 --- a/server/handles/webauthn.go +++ b/server/handles/webauthn.go @@ -24,7 +24,7 @@ func BeginAuthnLogin(c *gin.Context) { common.ErrorStrResp(c, "WebAuthn is not enabled", 403) return } - authnInstance, err := authn.NewAuthnInstance(c.Request) + authnInstance, err := authn.NewAuthnInstance(c) if err != nil { common.ErrorResp(c, err, 400) return @@ -65,7 +65,7 @@ func FinishAuthnLogin(c *gin.Context) { common.ErrorStrResp(c, "WebAuthn is not enabled", 403) return } - authnInstance, err := authn.NewAuthnInstance(c.Request) + authnInstance, err := authn.NewAuthnInstance(c) if err != nil { common.ErrorResp(c, err, 400) return @@ -127,7 +127,7 @@ func BeginAuthnRegistration(c *gin.Context) { } user := c.MustGet("user").(*model.User) - authnInstance, err := authn.NewAuthnInstance(c.Request) + authnInstance, err := authn.NewAuthnInstance(c) if err != nil { common.ErrorResp(c, err, 400) } @@ -158,7 +158,7 @@ func FinishAuthnRegistration(c *gin.Context) { user := c.MustGet("user").(*model.User) sessionDataString := c.GetHeader("Session") - authnInstance, err := authn.NewAuthnInstance(c.Request) + authnInstance, err := authn.NewAuthnInstance(c) if err != nil { common.ErrorResp(c, err, 400) return diff --git a/server/middlewares/check.go b/server/middlewares/check.go index fad75eff..ee396dfb 100644 --- a/server/middlewares/check.go +++ b/server/middlewares/check.go @@ -10,9 +10,7 @@ import ( ) func StoragesLoaded(c *gin.Context) { - if conf.StoragesLoaded { - c.Next() - } else { + if !conf.StoragesLoaded { if utils.SliceContains([]string{"", "/", "/favicon.ico"}, c.Request.URL.Path) { c.Next() return @@ -26,5 +24,8 @@ func StoragesLoaded(c *gin.Context) { } common.ErrorStrResp(c, "Loading storage, please wait", 500) c.Abort() + return } + c.Set(conf.ApiUrlKey, common.GetApiUrlFormRequest(c.Request)) + c.Next() } diff --git a/server/webdav.go b/server/webdav.go index 5dd3e86e..8e373edf 100644 --- a/server/webdav.go +++ b/server/webdav.go @@ -1,7 +1,6 @@ package server import ( - "context" "crypto/subtle" "net/http" "path" @@ -11,7 +10,6 @@ import ( "github.com/OpenListTeam/OpenList/server/middlewares" "github.com/OpenListTeam/OpenList/internal/conf" - "github.com/OpenListTeam/OpenList/internal/model" "github.com/OpenListTeam/OpenList/internal/op" "github.com/OpenListTeam/OpenList/internal/setting" "github.com/OpenListTeam/OpenList/server/webdav" @@ -45,9 +43,7 @@ func WebDav(dav *gin.RouterGroup) { } func ServeWebDAV(c *gin.Context) { - user := c.MustGet("user").(*model.User) - ctx := context.WithValue(c.Request.Context(), "user", user) - handler.ServeHTTP(c.Writer, c.Request.WithContext(ctx)) + handler.ServeHTTP(c.Writer, c.Request.WithContext(c)) } func WebDAVAuth(c *gin.Context) { diff --git a/server/webdav/webdav.go b/server/webdav/webdav.go index 81af6c43..922d0428 100644 --- a/server/webdav/webdav.go +++ b/server/webdav/webdav.go @@ -241,12 +241,12 @@ func (h *Handler) handleGetHeadPost(w http.ResponseWriter, r *http.Request) (sta storage, _ := fs.GetStorage(reqPath, &fs.GetStoragesArgs{}) downProxyUrl := storage.GetStorage().DownProxyUrl if storage.GetStorage().WebdavNative() || (storage.GetStorage().WebdavProxy() && downProxyUrl == "") { - link, _, err := fs.Link(ctx, reqPath, model.LinkArgs{Header: r.Header, HttpReq: r}) + link, _, err := fs.Link(ctx, reqPath, model.LinkArgs{Header: r.Header}) if err != nil { return http.StatusInternalServerError, err } if storage.GetStorage().ProxyRange { - common.ProxyRange(link, fi.GetSize()) + common.ProxyRange(ctx, link, fi.GetSize()) } err = common.Proxy(w, r, link, fi) if err != nil { @@ -260,7 +260,7 @@ func (h *Handler) handleGetHeadPost(w http.ResponseWriter, r *http.Request) (sta w.Header().Set("Cache-Control", "max-age=0, no-cache, no-store, must-revalidate") http.Redirect(w, r, u, http.StatusFound) } else { - link, _, err := fs.Link(ctx, reqPath, model.LinkArgs{IP: utils.ClientIP(r), Header: r.Header, HttpReq: r, Redirect: true}) + link, _, err := fs.Link(ctx, reqPath, model.LinkArgs{IP: utils.ClientIP(r), Header: r.Header, Redirect: true}) if err != nil { return http.StatusInternalServerError, err }