mirror of
https://github.com/ctwj/urldb.git
synced 2025-11-25 11:29:37 +08:00
153 lines
4.6 KiB
Go
153 lines
4.6 KiB
Go
package repo
|
||
|
||
import (
|
||
"time"
|
||
|
||
"github.com/ctwj/urldb/db/entity"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
// TaskRepository 任务仓库接口
|
||
type TaskRepository interface {
|
||
GetByID(id uint) (*entity.Task, error)
|
||
Create(task *entity.Task) error
|
||
Delete(id uint) error
|
||
GetList(page, pageSize int, taskType, status string) ([]*entity.Task, int64, error)
|
||
UpdateStatus(id uint, status string) error
|
||
UpdateProgress(id uint, progress float64, progressData string) error
|
||
UpdateStatusAndMessage(id uint, status, message string) error
|
||
UpdateTaskStats(id uint, processed, success, failed int) error
|
||
UpdateStartedAt(id uint) error
|
||
UpdateCompletedAt(id uint) error
|
||
}
|
||
|
||
// TaskRepositoryImpl 任务仓库实现
|
||
type TaskRepositoryImpl struct {
|
||
db *gorm.DB
|
||
}
|
||
|
||
// NewTaskRepository 创建任务仓库
|
||
func NewTaskRepository(db *gorm.DB) TaskRepository {
|
||
return &TaskRepositoryImpl{
|
||
db: db,
|
||
}
|
||
}
|
||
|
||
// GetByID 根据ID获取任务
|
||
func (r *TaskRepositoryImpl) GetByID(id uint) (*entity.Task, error) {
|
||
var task entity.Task
|
||
err := r.db.First(&task, id).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &task, nil
|
||
}
|
||
|
||
// Create 创建任务
|
||
func (r *TaskRepositoryImpl) Create(task *entity.Task) error {
|
||
return r.db.Create(task).Error
|
||
}
|
||
|
||
// Delete 删除任务
|
||
func (r *TaskRepositoryImpl) Delete(id uint) error {
|
||
return r.db.Delete(&entity.Task{}, id).Error
|
||
}
|
||
|
||
// GetList 获取任务列表
|
||
func (r *TaskRepositoryImpl) GetList(page, pageSize int, taskType, status string) ([]*entity.Task, int64, error) {
|
||
var tasks []*entity.Task
|
||
var total int64
|
||
|
||
query := r.db.Model(&entity.Task{})
|
||
|
||
// 添加过滤条件
|
||
if taskType != "" {
|
||
query = query.Where("type = ?", taskType)
|
||
}
|
||
if status != "" {
|
||
query = query.Where("status = ?", status)
|
||
}
|
||
|
||
// 获取总数
|
||
err := query.Count(&total).Error
|
||
if err != nil {
|
||
return nil, 0, err
|
||
}
|
||
|
||
// 分页查询
|
||
offset := (page - 1) * pageSize
|
||
err = query.Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&tasks).Error
|
||
if err != nil {
|
||
return nil, 0, err
|
||
}
|
||
|
||
return tasks, total, nil
|
||
}
|
||
|
||
// UpdateStatus 更新任务状态
|
||
func (r *TaskRepositoryImpl) UpdateStatus(id uint, status string) error {
|
||
return r.db.Model(&entity.Task{}).Where("id = ?", id).Update("status", status).Error
|
||
}
|
||
|
||
// UpdateProgress 更新任务进度
|
||
func (r *TaskRepositoryImpl) UpdateProgress(id uint, progress float64, progressData string) error {
|
||
// 检查progress和progress_data字段是否存在
|
||
var count int64
|
||
err := r.db.Raw("SELECT COUNT(*) FROM information_schema.columns WHERE table_name = 'tasks' AND column_name = 'progress'").Count(&count).Error
|
||
if err != nil || count == 0 {
|
||
// 如果检查失败或字段不存在,只更新processed_items等现有字段
|
||
return r.db.Model(&entity.Task{}).Where("id = ?", id).Updates(map[string]interface{}{
|
||
"processed_items": progress, // 使用progress作为processed_items的近似值
|
||
}).Error
|
||
}
|
||
|
||
// 字段存在,正常更新
|
||
return r.db.Model(&entity.Task{}).Where("id = ?", id).Updates(map[string]interface{}{
|
||
"progress": progress,
|
||
"progress_data": progressData,
|
||
}).Error
|
||
}
|
||
|
||
// UpdateStatusAndMessage 更新任务状态和消息
|
||
func (r *TaskRepositoryImpl) UpdateStatusAndMessage(id uint, status, message string) error {
|
||
// 检查message字段是否存在
|
||
var count int64
|
||
err := r.db.Raw("SELECT COUNT(*) FROM information_schema.columns WHERE table_name = 'tasks' AND column_name = 'message'").Count(&count).Error
|
||
if err != nil {
|
||
// 如果检查失败,只更新状态
|
||
return r.db.Model(&entity.Task{}).Where("id = ?", id).Update("status", status).Error
|
||
}
|
||
|
||
if count > 0 {
|
||
// message字段存在,更新状态和消息
|
||
return r.db.Model(&entity.Task{}).Where("id = ?", id).Updates(map[string]interface{}{
|
||
"status": status,
|
||
"message": message,
|
||
}).Error
|
||
} else {
|
||
// message字段不存在,只更新状态
|
||
return r.db.Model(&entity.Task{}).Where("id = ?", id).Update("status", status).Error
|
||
}
|
||
}
|
||
|
||
// UpdateTaskStats 更新任务统计信息
|
||
func (r *TaskRepositoryImpl) UpdateTaskStats(id uint, processed, success, failed int) error {
|
||
return r.db.Model(&entity.Task{}).Where("id = ?", id).Updates(map[string]interface{}{
|
||
"processed_items": processed,
|
||
"success_items": success,
|
||
"failed_items": failed,
|
||
}).Error
|
||
}
|
||
|
||
// UpdateStartedAt 更新任务开始时间
|
||
func (r *TaskRepositoryImpl) UpdateStartedAt(id uint) error {
|
||
now := time.Now()
|
||
return r.db.Model(&entity.Task{}).Where("id = ?", id).Update("started_at", now).Error
|
||
}
|
||
|
||
// UpdateCompletedAt 更新任务完成时间
|
||
func (r *TaskRepositoryImpl) UpdateCompletedAt(id uint) error {
|
||
now := time.Now()
|
||
return r.db.Model(&entity.Task{}).Where("id = ?", id).Update("completed_at", now).Error
|
||
}
|