mirror of
https://github.com/ctwj/urldb.git
synced 2025-11-25 03:15:04 +08:00
279 lines
7.6 KiB
Go
279 lines
7.6 KiB
Go
|
|
package task
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"encoding/json"
|
||
|
|
"fmt"
|
||
|
|
"sync"
|
||
|
|
"time"
|
||
|
|
|
||
|
|
"github.com/ctwj/urldb/db/entity"
|
||
|
|
"github.com/ctwj/urldb/db/repo"
|
||
|
|
"github.com/ctwj/urldb/utils"
|
||
|
|
)
|
||
|
|
|
||
|
|
// TaskProcessor 任务处理器接口
|
||
|
|
type TaskProcessor interface {
|
||
|
|
Process(ctx context.Context, taskID uint, item *entity.TaskItem) error
|
||
|
|
GetTaskType() string
|
||
|
|
}
|
||
|
|
|
||
|
|
// TaskManager 任务管理器
|
||
|
|
type TaskManager struct {
|
||
|
|
processors map[string]TaskProcessor
|
||
|
|
repoMgr *repo.RepositoryManager
|
||
|
|
mu sync.RWMutex
|
||
|
|
running map[uint]context.CancelFunc // 正在运行的任务
|
||
|
|
}
|
||
|
|
|
||
|
|
// NewTaskManager returns a TaskManager backed by repoMgr, with no processors
// registered and no tasks running.
func NewTaskManager(repoMgr *repo.RepositoryManager) *TaskManager {
	manager := &TaskManager{repoMgr: repoMgr}
	manager.processors = make(map[string]TaskProcessor)
	manager.running = make(map[uint]context.CancelFunc)
	return manager
}
|
||
|
|
|
||
|
|
// RegisterProcessor 注册任务处理器
|
||
|
|
func (tm *TaskManager) RegisterProcessor(processor TaskProcessor) {
|
||
|
|
tm.mu.Lock()
|
||
|
|
defer tm.mu.Unlock()
|
||
|
|
tm.processors[processor.GetTaskType()] = processor
|
||
|
|
utils.Info("注册任务处理器: %s", processor.GetTaskType())
|
||
|
|
}
|
||
|
|
|
||
|
|
// getRegisteredProcessors 获取已注册的处理器列表(用于调试)
|
||
|
|
func (tm *TaskManager) getRegisteredProcessors() []string {
|
||
|
|
var types []string
|
||
|
|
for taskType := range tm.processors {
|
||
|
|
types = append(types, taskType)
|
||
|
|
}
|
||
|
|
return types
|
||
|
|
}
|
||
|
|
|
||
|
|
// StartTask launches task taskID in a background goroutine.
//
// It returns an error if the task is already present in tm.running, if the
// task cannot be loaded from the repository, or if no processor is
// registered for the task's Type. The repository error is wrapped with %w,
// so callers can inspect it with errors.Is / errors.As.
//
// On success, a cancellable context is created and its cancel function is
// stored in tm.running, so StopTask can interrupt the run. processTask is
// then started on that context. The whole check-and-register sequence runs
// under tm.mu, so two concurrent calls cannot both launch the same task.
func (tm *TaskManager) StartTask(taskID uint) error {
	tm.mu.Lock()
	defer tm.mu.Unlock()

	utils.Info("StartTask: 尝试启动任务 %d", taskID)

	// Refuse to start a second run of a task that is already running.
	if _, exists := tm.running[taskID]; exists {
		utils.Info("任务 %d 已在运行中", taskID)
		return fmt.Errorf("任务 %d 已在运行中", taskID)
	}

	task, err := tm.repoMgr.TaskRepository.GetByID(taskID)
	if err != nil {
		utils.Error("获取任务失败: %v", err)
		return fmt.Errorf("获取任务失败: %w", err)
	}

	utils.Info("StartTask: 获取到任务 %d, 类型: %s, 状态: %s", task.ID, task.Type, task.Status)

	// Processors are keyed by the string their GetTaskType returns.
	processor, exists := tm.processors[string(task.Type)]
	if !exists {
		utils.Error("未找到任务类型 %s 的处理器, 已注册的处理器: %v", task.Type, tm.getRegisteredProcessors())
		return fmt.Errorf("未找到任务类型 %s 的处理器", task.Type)
	}

	utils.Info("StartTask: 找到处理器 %s", task.Type)

	// The run gets its own context derived from Background, so it outlives
	// this call. StopTask cancels it through the entry in tm.running.
	ctx, cancel := context.WithCancel(context.Background())
	tm.running[taskID] = cancel

	utils.Info("StartTask: 启动后台任务协程")
	go tm.processTask(ctx, task, processor)

	utils.Info("StartTask: 任务 %d 启动成功", taskID)
	return nil
}
|
||
|
|
|
||
|
|
// StopTask cancels the context of running task taskID and records the task
// as "paused". It returns an error if the task is not currently tracked in
// tm.running. A failure to persist the "paused" status is only logged.
//
// The cancel, the map removal and the status write all happen while tm.mu
// is held.
func (tm *TaskManager) StopTask(taskID uint) error {
	tm.mu.Lock()
	defer tm.mu.Unlock()

	cancel, ok := tm.running[taskID]
	if !ok {
		return fmt.Errorf("任务 %d 未在运行", taskID)
	}

	// Both steps run under tm.mu, so no other goroutine can observe them in
	// between.
	delete(tm.running, taskID)
	cancel()

	if err := tm.repoMgr.TaskRepository.UpdateStatus(taskID, "paused"); err != nil {
		utils.Error("更新任务状态失败: %v", err)
	}
	return nil
}
|
||
|
|
|
||
|
|
// processTask executes every pending item of task. StartTask runs it in its
// own goroutine.
//
// Flow:
//   - Mark the task "running" and load its "pending" items.
//   - Pass each item to processor via processTaskItem, recording progress
//     after every item.
//   - Write an aggregate status with a summary message: "completed",
//     "partial_success" (some items failed) or "failed" (every processed
//     item failed).
//
// ctx is checked before each item, so StopTask can interrupt the run between
// items.
//
// Coordination with StopTask: StopTask cancels ctx, removes the tm.running
// entry and writes "paused", all while holding tm.mu. This function
// therefore:
//   - skips the "running" write when ctx is already cancelled;
//   - detaches from tm.running before any terminal status write, and skips
//     that write when StopTask got there first;
//   - never removes a tm.running entry once ctx is cancelled, because by
//     then the ID may belong to a newer run launched by StartTask.
func (tm *TaskManager) processTask(ctx context.Context, task *entity.Task, processor TaskProcessor) {
	// release removes this run's entry from tm.running and cancels ctx to
	// free its resources. It reports false when ctx was already cancelled,
	// meaning StopTask has taken over the cleanup and the status.
	//
	// An un-cancelled ctx observed under tm.mu proves the entry is still
	// ours, because StopTask cancels and deletes atomically under the same
	// lock. After release has removed the entry, later calls return false
	// because ctx is then cancelled.
	release := func() bool {
		tm.mu.Lock()
		defer tm.mu.Unlock()
		if ctx.Err() != nil {
			return false
		}
		if cancel, ok := tm.running[task.ID]; ok {
			delete(tm.running, task.ID)
			cancel()
		}
		return true
	}
	defer func() {
		release()
		utils.Info("processTask: 任务 %d 处理完成,清理资源", task.ID)
	}()

	utils.Info("processTask: 开始处理任务: %d, 类型: %s", task.ID, task.Type)

	// markRunning writes "running" under tm.mu so it cannot interleave with
	// StopTask. If StopTask already ran (e.g. before this goroutine was
	// scheduled), it reports false and leaves StopTask's "paused" status in
	// place instead of resurrecting the task.
	markRunning := func() (bool, error) {
		tm.mu.Lock()
		defer tm.mu.Unlock()
		if ctx.Err() != nil {
			return false, nil
		}
		return true, tm.repoMgr.TaskRepository.UpdateStatus(task.ID, "running")
	}

	started, err := markRunning()
	if err != nil {
		utils.Error("更新任务状态失败: %v", err)
		return
	}
	if !started {
		utils.Info("任务 %d 被取消", task.ID)
		return
	}

	items, err := tm.repoMgr.TaskItemRepository.GetByTaskIDAndStatus(task.ID, "pending")
	if err != nil {
		utils.Error("获取任务项失败: %v", err)
		// Record the failure only if StopTask has not already taken over.
		if release() {
			tm.markTaskFailed(task.ID, fmt.Sprintf("获取任务项失败: %v", err))
		}
		return
	}

	totalItems := len(items)
	processedItems := 0
	successItems := 0
	failedItems := 0

	utils.Info("任务 %d 共有 %d 个待处理项", task.ID, totalItems)

	for _, item := range items {
		select {
		case <-ctx.Done():
			utils.Info("任务 %d 被取消", task.ID)
			return
		default:
			err := tm.processTaskItem(ctx, task.ID, item, processor)
			processedItems++

			if err != nil {
				failedItems++
				utils.Error("处理任务项 %d 失败: %v", item.ID, err)
			} else {
				successItems++
			}

			// totalItems > 0 inside the loop, so the division is safe.
			progress := float64(processedItems) / float64(totalItems) * 100
			tm.updateTaskProgress(task.ID, progress, processedItems, successItems, failedItems)
		}
	}

	// Detach from tm.running before the terminal write. If StopTask already
	// recorded "paused", leave it untouched.
	if !release() {
		utils.Info("任务 %d 被取消", task.ID)
		return
	}

	status := "completed"
	message := fmt.Sprintf("任务完成,共处理 %d 项,成功 %d 项,失败 %d 项", processedItems, successItems, failedItems)
	switch {
	case failedItems > 0 && successItems == 0:
		status = "failed"
		message = fmt.Sprintf("任务失败,共处理 %d 项,全部失败", processedItems)
	case failedItems > 0:
		status = "partial_success"
		message = fmt.Sprintf("任务部分成功,共处理 %d 项,成功 %d 项,失败 %d 项", processedItems, successItems, failedItems)
	}

	if err := tm.repoMgr.TaskRepository.UpdateStatusAndMessage(task.ID, status, message); err != nil {
		utils.Error("更新任务状态失败: %v", err)
	}

	utils.Info("任务 %d 处理完成: %s", task.ID, message)
}
|
||
|
|
|
||
|
|
// processTaskItem runs processor on a single item and records the outcome.
//
// The item is first marked "processing"; if that write fails, the wrapped
// repository error is returned and processor is not called.
//
// If processor.Process returns an error, the item is marked "failed" with a
// JSON output of the form {"error": ..., "time": ...}, and that same error
// is returned. Otherwise the item is marked "completed" with
// {"success": true, "time": ...}, and nil is returned.
//
// A failure to persist the final item status is logged, not returned, so it
// does not change how the caller counts the item.
func (tm *TaskManager) processTaskItem(ctx context.Context, taskID uint, item *entity.TaskItem, processor TaskProcessor) error {
	if err := tm.repoMgr.TaskItemRepository.UpdateStatus(item.ID, "processing"); err != nil {
		return fmt.Errorf("更新任务项状态失败: %w", err)
	}

	if err := processor.Process(ctx, taskID, item); err != nil {
		// The map holds only a string and the current time, which
		// json.Marshal always encodes, so its error is deliberately ignored.
		outputJSON, _ := json.Marshal(map[string]interface{}{
			"error": err.Error(),
			"time":  time.Now(),
		})

		if updateErr := tm.repoMgr.TaskItemRepository.UpdateStatusAndOutput(item.ID, "failed", string(outputJSON)); updateErr != nil {
			utils.Error("更新失败任务项状态失败: %v", updateErr)
		}
		return err
	}

	// As above: a bool and the current time always marshal successfully.
	outputJSON, _ := json.Marshal(map[string]interface{}{
		"success": true,
		"time":    time.Now(),
	})

	if err := tm.repoMgr.TaskItemRepository.UpdateStatusAndOutput(item.ID, "completed", string(outputJSON)); err != nil {
		utils.Error("更新成功任务项状态失败: %v", err)
	}
	return nil
}
|
||
|
|
|
||
|
|
// updateTaskProgress stores the task's progress percentage together with a
// JSON snapshot of the processed/success/failed counters and the current
// time. A repository failure is logged, not returned.
func (tm *TaskManager) updateTaskProgress(taskID uint, progress float64, processed, success, failed int) {
	// The snapshot holds only numbers and a timestamp. json.Marshal would
	// fail only for a NaN/Inf progress, which processTask never produces, so
	// the error is ignored.
	snapshot, _ := json.Marshal(map[string]interface{}{
		"progress":  progress,
		"processed": processed,
		"success":   success,
		"failed":    failed,
		"time":      time.Now(),
	})

	if err := tm.repoMgr.TaskRepository.UpdateProgress(taskID, progress, string(snapshot)); err != nil {
		utils.Error("更新任务进度失败: %v", err)
	}
}
|
||
|
|
|
||
|
|
// markTaskFailed 标记任务失败
|
||
|
|
func (tm *TaskManager) markTaskFailed(taskID uint, message string) {
|
||
|
|
err := tm.repoMgr.TaskRepository.UpdateStatusAndMessage(taskID, "failed", message)
|
||
|
|
if err != nil {
|
||
|
|
utils.Error("标记任务失败状态失败: %v", err)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// GetTaskStatus 获取任务状态
|
||
|
|
func (tm *TaskManager) GetTaskStatus(taskID uint) (string, error) {
|
||
|
|
task, err := tm.repoMgr.TaskRepository.GetByID(taskID)
|
||
|
|
if err != nil {
|
||
|
|
return "", err
|
||
|
|
}
|
||
|
|
return string(task.Status), nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// IsTaskRunning 检查任务是否在运行
|
||
|
|
func (tm *TaskManager) IsTaskRunning(taskID uint) bool {
|
||
|
|
tm.mu.RLock()
|
||
|
|
defer tm.mu.RUnlock()
|
||
|
|
_, exists := tm.running[taskID]
|
||
|
|
return exists
|
||
|
|
}
|