Files
urldb/task/task_processor.go

279 lines
7.6 KiB
Go
Raw Normal View History

2025-08-09 23:47:30 +08:00
package task
import (
"context"
"encoding/json"
"fmt"
"sync"
"time"
"github.com/ctwj/urldb/db/entity"
"github.com/ctwj/urldb/db/repo"
"github.com/ctwj/urldb/utils"
)
// TaskProcessor 任务处理器接口
type TaskProcessor interface {
Process(ctx context.Context, taskID uint, item *entity.TaskItem) error
GetTaskType() string
}
// TaskManager 任务管理器
type TaskManager struct {
processors map[string]TaskProcessor
repoMgr *repo.RepositoryManager
mu sync.RWMutex
running map[uint]context.CancelFunc // 正在运行的任务
}
// NewTaskManager 创建任务管理器
func NewTaskManager(repoMgr *repo.RepositoryManager) *TaskManager {
return &TaskManager{
processors: make(map[string]TaskProcessor),
repoMgr: repoMgr,
running: make(map[uint]context.CancelFunc),
}
}
// RegisterProcessor 注册任务处理器
func (tm *TaskManager) RegisterProcessor(processor TaskProcessor) {
tm.mu.Lock()
defer tm.mu.Unlock()
tm.processors[processor.GetTaskType()] = processor
utils.Info("注册任务处理器: %s", processor.GetTaskType())
}
// getRegisteredProcessors 获取已注册的处理器列表(用于调试)
func (tm *TaskManager) getRegisteredProcessors() []string {
var types []string
for taskType := range tm.processors {
types = append(types, taskType)
}
return types
}
// StartTask 启动任务
func (tm *TaskManager) StartTask(taskID uint) error {
tm.mu.Lock()
defer tm.mu.Unlock()
utils.Info("StartTask: 尝试启动任务 %d", taskID)
// 检查任务是否已在运行
if _, exists := tm.running[taskID]; exists {
utils.Info("任务 %d 已在运行中", taskID)
return fmt.Errorf("任务 %d 已在运行中", taskID)
}
// 获取任务信息
task, err := tm.repoMgr.TaskRepository.GetByID(taskID)
if err != nil {
utils.Error("获取任务失败: %v", err)
return fmt.Errorf("获取任务失败: %v", err)
}
utils.Info("StartTask: 获取到任务 %d, 类型: %s, 状态: %s", task.ID, task.Type, task.Status)
// 获取处理器
processor, exists := tm.processors[string(task.Type)]
if !exists {
utils.Error("未找到任务类型 %s 的处理器, 已注册的处理器: %v", task.Type, tm.getRegisteredProcessors())
return fmt.Errorf("未找到任务类型 %s 的处理器", task.Type)
}
utils.Info("StartTask: 找到处理器 %s", task.Type)
// 创建上下文
ctx, cancel := context.WithCancel(context.Background())
tm.running[taskID] = cancel
utils.Info("StartTask: 启动后台任务协程")
// 启动后台任务
go tm.processTask(ctx, task, processor)
utils.Info("StartTask: 任务 %d 启动成功", taskID)
return nil
}
// StopTask 停止任务
func (tm *TaskManager) StopTask(taskID uint) error {
tm.mu.Lock()
defer tm.mu.Unlock()
cancel, exists := tm.running[taskID]
if !exists {
return fmt.Errorf("任务 %d 未在运行", taskID)
}
cancel()
delete(tm.running, taskID)
// 更新任务状态为暂停
err := tm.repoMgr.TaskRepository.UpdateStatus(taskID, "paused")
if err != nil {
utils.Error("更新任务状态失败: %v", err)
}
return nil
}
// processTask 处理任务
func (tm *TaskManager) processTask(ctx context.Context, task *entity.Task, processor TaskProcessor) {
defer func() {
tm.mu.Lock()
delete(tm.running, task.ID)
tm.mu.Unlock()
utils.Info("processTask: 任务 %d 处理完成,清理资源", task.ID)
}()
utils.Info("processTask: 开始处理任务: %d, 类型: %s", task.ID, task.Type)
// 更新任务状态为运行中
err := tm.repoMgr.TaskRepository.UpdateStatus(task.ID, "running")
if err != nil {
utils.Error("更新任务状态失败: %v", err)
return
}
// 获取待处理的任务项
items, err := tm.repoMgr.TaskItemRepository.GetByTaskIDAndStatus(task.ID, "pending")
if err != nil {
utils.Error("获取任务项失败: %v", err)
tm.markTaskFailed(task.ID, fmt.Sprintf("获取任务项失败: %v", err))
return
}
totalItems := len(items)
processedItems := 0
successItems := 0
failedItems := 0
utils.Info("任务 %d 共有 %d 个待处理项", task.ID, totalItems)
for _, item := range items {
select {
case <-ctx.Done():
utils.Info("任务 %d 被取消", task.ID)
return
default:
// 处理单个任务项
err := tm.processTaskItem(ctx, task.ID, item, processor)
processedItems++
if err != nil {
failedItems++
utils.Error("处理任务项 %d 失败: %v", item.ID, err)
} else {
successItems++
}
// 更新任务进度
progress := float64(processedItems) / float64(totalItems) * 100
tm.updateTaskProgress(task.ID, progress, processedItems, successItems, failedItems)
}
}
// 任务完成
status := "completed"
message := fmt.Sprintf("任务完成,共处理 %d 项,成功 %d 项,失败 %d 项", processedItems, successItems, failedItems)
if failedItems > 0 && successItems == 0 {
status = "failed"
message = fmt.Sprintf("任务失败,共处理 %d 项,全部失败", processedItems)
} else if failedItems > 0 {
status = "partial_success"
message = fmt.Sprintf("任务部分成功,共处理 %d 项,成功 %d 项,失败 %d 项", processedItems, successItems, failedItems)
}
err = tm.repoMgr.TaskRepository.UpdateStatusAndMessage(task.ID, status, message)
if err != nil {
utils.Error("更新任务状态失败: %v", err)
}
utils.Info("任务 %d 处理完成: %s", task.ID, message)
}
// processTaskItem 处理单个任务项
func (tm *TaskManager) processTaskItem(ctx context.Context, taskID uint, item *entity.TaskItem, processor TaskProcessor) error {
// 更新任务项状态为处理中
err := tm.repoMgr.TaskItemRepository.UpdateStatus(item.ID, "processing")
if err != nil {
return fmt.Errorf("更新任务项状态失败: %v", err)
}
// 处理任务项
err = processor.Process(ctx, taskID, item)
if err != nil {
// 处理失败
outputData := map[string]interface{}{
"error": err.Error(),
"time": time.Now(),
}
outputJSON, _ := json.Marshal(outputData)
updateErr := tm.repoMgr.TaskItemRepository.UpdateStatusAndOutput(item.ID, "failed", string(outputJSON))
if updateErr != nil {
utils.Error("更新失败任务项状态失败: %v", updateErr)
}
return err
}
// 处理成功
outputData := map[string]interface{}{
"success": true,
"time": time.Now(),
}
outputJSON, _ := json.Marshal(outputData)
err = tm.repoMgr.TaskItemRepository.UpdateStatusAndOutput(item.ID, "completed", string(outputJSON))
if err != nil {
utils.Error("更新成功任务项状态失败: %v", err)
}
return nil
}
// updateTaskProgress 更新任务进度
func (tm *TaskManager) updateTaskProgress(taskID uint, progress float64, processed, success, failed int) {
progressData := map[string]interface{}{
"progress": progress,
"processed": processed,
"success": success,
"failed": failed,
"time": time.Now(),
}
progressJSON, _ := json.Marshal(progressData)
err := tm.repoMgr.TaskRepository.UpdateProgress(taskID, progress, string(progressJSON))
if err != nil {
utils.Error("更新任务进度失败: %v", err)
}
}
// markTaskFailed 标记任务失败
func (tm *TaskManager) markTaskFailed(taskID uint, message string) {
err := tm.repoMgr.TaskRepository.UpdateStatusAndMessage(taskID, "failed", message)
if err != nil {
utils.Error("标记任务失败状态失败: %v", err)
}
}
// GetTaskStatus 获取任务状态
func (tm *TaskManager) GetTaskStatus(taskID uint) (string, error) {
task, err := tm.repoMgr.TaskRepository.GetByID(taskID)
if err != nil {
return "", err
}
return string(task.Status), nil
}
// IsTaskRunning 检查任务是否在运行
func (tm *TaskManager) IsTaskRunning(taskID uint) bool {
tm.mu.RLock()
defer tm.mu.RUnlock()
_, exists := tm.running[taskID]
return exists
}