diff --git a/internal/bootstrap/storage.go b/internal/bootstrap/storage.go index d111c7d3..1389095b 100644 --- a/internal/bootstrap/storage.go +++ b/internal/bootstrap/storage.go @@ -25,6 +25,6 @@ func LoadStorages() { storages[i].MountPath, storages[i].Driver, storages[i].Order) } } - conf.StoragesLoaded = true + conf.SendStoragesLoadedSignal() }(storages) } diff --git a/internal/conf/var.go b/internal/conf/var.go index de23b5c6..9a02eca2 100644 --- a/internal/conf/var.go +++ b/internal/conf/var.go @@ -3,6 +3,7 @@ package conf import ( "net/url" "regexp" + "sync" ) var ( @@ -23,8 +24,6 @@ var FilenameCharMap = make(map[string]string) var PrivacyReg []*regexp.Regexp var ( - // StoragesLoaded loaded success if empty - StoragesLoaded = false // 单个Buffer最大限制 MaxBufferLimit = 16 * 1024 * 1024 // 超过该阈值的Buffer将使用 mmap 分配,可主动释放内存 @@ -35,3 +34,39 @@ var ( ManageHtml string IndexHtml string ) + +var ( + // StoragesLoaded loaded success if empty + StoragesLoaded = false + storagesLoadMu sync.RWMutex + storagesLoadSignal chan struct{} = make(chan struct{}) +) + +func StoragesLoadSignal() <-chan struct{} { + storagesLoadMu.RLock() + ch := storagesLoadSignal + storagesLoadMu.RUnlock() + return ch +} +func SendStoragesLoadedSignal() { + storagesLoadMu.Lock() + select { + case <-storagesLoadSignal: + // already closed + default: + StoragesLoaded = true + close(storagesLoadSignal) + } + storagesLoadMu.Unlock() +} +func ResetStoragesLoadSignal() { + storagesLoadMu.Lock() + select { + case <-storagesLoadSignal: + StoragesLoaded = false + storagesLoadSignal = make(chan struct{}) + default: + // not closed -> nothing to do + } + storagesLoadMu.Unlock() +} diff --git a/internal/db/tasks.go b/internal/db/tasks.go index dcb9dfea..36054898 100644 --- a/internal/db/tasks.go +++ b/internal/db/tasks.go @@ -1,6 +1,7 @@ package db import ( + "github.com/OpenListTeam/OpenList/v4/internal/conf" "github.com/OpenListTeam/OpenList/v4/internal/model" "github.com/pkg/errors" ) @@ -30,6 +31,7 @@ func GetTaskDataFunc(type_s string, enabled bool) func() ([]byte, error) { return nil } return func() ([]byte, error) { + <-conf.StoragesLoadSignal() return []byte(task.PersistData), nil } } diff --git a/internal/fs/archive.go b/internal/fs/archive.go index e1e4c448..b2885d2b 100644 --- a/internal/fs/archive.go +++ b/internal/fs/archive.go @@ -41,6 +41,18 @@ func (t *ArchiveDownloadTask) Run() error { if err := t.ReinitCtx(); err != nil { return err } + if t.SrcStorage == nil { + if srcStorage, _, err := op.GetStorageAndActualPath(t.SrcStorageMp); err == nil { + t.SrcStorage = srcStorage + } else { + return err + } + if dstStorage, _, err := op.GetStorageAndActualPath(t.DstStorageMp); err == nil { + t.DstStorage = dstStorage + } else { + return err + } + } t.ClearEndTime() t.SetStartTime(time.Now()) defer func() { t.SetEndTime(time.Now()) }() diff --git a/internal/fs/copy_move.go b/internal/fs/copy_move.go index dbdd2835..d8ecf98c 100644 --- a/internal/fs/copy_move.go +++ b/internal/fs/copy_move.go @@ -48,6 +48,19 @@ func (t *FileTransferTask) Run() error { if err := t.ReinitCtx(); err != nil { return err } + if t.SrcStorage == nil { + if srcStorage, _, err := op.GetStorageAndActualPath(t.SrcStorageMp); err == nil { + t.SrcStorage = srcStorage + } else { + return err + } + if dstStorage, _, err := op.GetStorageAndActualPath(t.DstStorageMp); err == nil { + t.DstStorage = dstStorage + } else { + return err + } + } + t.ClearEndTime() t.SetStartTime(time.Now()) defer func() { t.SetEndTime(time.Now()) }() diff --git a/internal/offline_download/tool/transfer.go b/internal/offline_download/tool/transfer.go index 1c1284a0..7daf0b17 100644 --- a/internal/offline_download/tool/transfer.go +++ b/internal/offline_download/tool/transfer.go @@ -34,6 +34,20 @@ func (t *TransferTask) Run() error { if err := t.ReinitCtx(); err != nil { return err } + if t.SrcStorage == nil && t.SrcStorageMp != "" { + if srcStorage, _, err := op.GetStorageAndActualPath(t.SrcStorageMp); err == nil { + t.SrcStorage = srcStorage + } else { + return err + } + if t.DstStorage == nil { + if dstStorage, _, err := op.GetStorageAndActualPath(t.DstStorageMp); err == nil { + t.DstStorage = dstStorage + } else { + return err + } + } + } t.ClearEndTime() t.SetStartTime(time.Now()) defer func() { t.SetEndTime(time.Now()) }() @@ -64,9 +78,8 @@ func (t *TransferTask) Run() error { return op.Put(t.Ctx(), t.DstStorage, t.DstActualPath, s, t.SetProgress) } return transferStdPath(t) - } else { - return transferObjPath(t) } + return transferObjPath(t) } func (t *TransferTask) GetName() string { diff --git a/internal/task/base.go b/internal/task/base.go index 8976ed90..5bfa03a8 100644 --- a/internal/task/base.go +++ b/internal/task/base.go @@ -14,7 +14,7 @@ type TaskExtension struct { Creator *model.User startTime *time.Time endTime *time.Time - totalBytes int64 + TotalBytes int64 ApiUrl string } @@ -58,11 +58,11 @@ func (t *TaskExtension) ClearEndTime() { } func (t *TaskExtension) SetTotalBytes(totalBytes int64) { - t.totalBytes = totalBytes + t.TotalBytes = totalBytes } func (t *TaskExtension) GetTotalBytes() int64 { - return t.totalBytes + return t.TotalBytes } func (t *TaskExtension) ReinitCtx() error { diff --git a/server/handles/storage.go b/server/handles/storage.go index d0bc062f..f648fc36 100644 --- a/server/handles/storage.go +++ b/server/handles/storage.go @@ -175,7 +175,7 @@ func LoadAllStorages(c *gin.Context) { common.ErrorResp(c, err, 500, true) return } - conf.StoragesLoaded = false + conf.ResetStoragesLoadSignal() go func(storages []model.Storage) { for _, storage := range storages { storageDriver, err := op.GetStorageByMountPath(storage.MountPath) @@ -195,7 +195,7 @@ func LoadAllStorages(c *gin.Context) { log.Infof("success load storage: [%s], driver: [%s]", storage.MountPath, storage.Driver) } - conf.StoragesLoaded = true + conf.SendStoragesLoadedSignal() }(storages) common.SuccessResp(c) } diff --git a/server/middlewares/check.go b/server/middlewares/check.go index a1011de3..c7203a49 100644 --- a/server/middlewares/check.go +++ b/server/middlewares/check.go @@ -22,9 +22,12 @@ func StoragesLoaded(c *gin.Context) { return } } - common.ErrorStrResp(c, "Loading storage, please wait", 500) - c.Abort() - return + select { + case <-conf.StoragesLoadSignal(): + case <-c.Request.Context().Done(): + c.Abort() + return + } } common.GinWithValue(c, conf.ApiUrlKey, common.GetApiUrlFromRequest(c.Request),