mirror of
https://github.com/Xinrea/bili-shadowreplay.git
synced 2025-11-25 04:22:24 +08:00
feat: task scheduler (#216)
* feat: update progress in task message * feat: task scheduler for queued tasks
This commit is contained in:
@@ -43,7 +43,7 @@ pub async fn handle_ffmpeg_process(
|
||||
if content.starts_with("out_time_ms") {
|
||||
let time_str = content.strip_prefix("out_time_ms=").unwrap_or_default();
|
||||
if let Some(reporter) = reporter {
|
||||
reporter.update(time_str);
|
||||
reporter.update(time_str).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -104,7 +104,8 @@ pub async fn transcode(
|
||||
}
|
||||
reporter
|
||||
.unwrap()
|
||||
.update(format!("压制中:{}", p.time).as_str());
|
||||
.update(format!("压制中:{}", p.time).as_str())
|
||||
.await;
|
||||
}
|
||||
FfmpegEvent::LogEOF => break,
|
||||
FfmpegEvent::Error(e) => {
|
||||
@@ -161,7 +162,8 @@ pub async fn trim_video(
|
||||
}
|
||||
reporter
|
||||
.unwrap()
|
||||
.update(format!("切片中:{}", p.time).as_str());
|
||||
.update(format!("切片中:{}", p.time).as_str())
|
||||
.await;
|
||||
}
|
||||
FfmpegEvent::LogEOF => break,
|
||||
FfmpegEvent::Error(e) => {
|
||||
@@ -429,7 +431,9 @@ pub async fn encode_video_subtitle(
|
||||
}
|
||||
FfmpegEvent::Progress(p) => {
|
||||
log::info!("Encode video subtitle progress: {}", p.time);
|
||||
reporter.update(format!("压制中:{}", p.time).as_str());
|
||||
reporter
|
||||
.update(format!("压制中:{}", p.time).as_str())
|
||||
.await;
|
||||
}
|
||||
FfmpegEvent::LogEOF => break,
|
||||
FfmpegEvent::Log(_level, _content) => {}
|
||||
@@ -528,7 +532,8 @@ pub async fn encode_video_danmu(
|
||||
}
|
||||
reporter
|
||||
.unwrap()
|
||||
.update(format!("压制中:{}", p.time).as_str());
|
||||
.update(format!("压制中:{}", p.time).as_str())
|
||||
.await;
|
||||
}
|
||||
FfmpegEvent::Log(_level, _content) => {}
|
||||
FfmpegEvent::LogEOF => break,
|
||||
@@ -813,7 +818,7 @@ pub async fn clip_from_video_file(
|
||||
match event {
|
||||
FfmpegEvent::Progress(p) => {
|
||||
if let Some(reporter) = reporter {
|
||||
reporter.update(&format!("切片进度: {}", p.time));
|
||||
reporter.update(&format!("切片进度: {}", p.time)).await;
|
||||
}
|
||||
}
|
||||
FfmpegEvent::LogEOF => break,
|
||||
@@ -980,7 +985,9 @@ pub async fn execute_ffmpeg_conversion(
|
||||
while let Ok(event) = parser.parse_next_event().await {
|
||||
match event {
|
||||
FfmpegEvent::Progress(p) => {
|
||||
reporter.update(&format!("正在转换视频格式... {} ({})", p.time, mode_name));
|
||||
reporter
|
||||
.update(&format!("正在转换视频格式... {} ({})", p.time, mode_name))
|
||||
.await;
|
||||
}
|
||||
FfmpegEvent::LogEOF => break,
|
||||
FfmpegEvent::Log(level, content) => {
|
||||
@@ -1008,7 +1015,9 @@ pub async fn execute_ffmpeg_conversion(
|
||||
return Err(format!("视频格式转换失败 ({mode_name}): {error_msg}"));
|
||||
}
|
||||
|
||||
reporter.update(&format!("视频格式转换完成 100% ({mode_name})"));
|
||||
reporter
|
||||
.update(&format!("视频格式转换完成 100% ({mode_name})"))
|
||||
.await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1018,7 +1027,7 @@ pub async fn try_stream_copy_conversion(
|
||||
dest: &Path,
|
||||
reporter: &ProgressReporter,
|
||||
) -> Result<(), String> {
|
||||
reporter.update("正在转换视频格式... 0% (无损模式)");
|
||||
reporter.update("正在转换视频格式... 0% (无损模式)").await;
|
||||
|
||||
// 构建ffmpeg命令 - 流复制模式
|
||||
let mut cmd = tokio::process::Command::new(ffmpeg_path());
|
||||
@@ -1051,7 +1060,7 @@ pub async fn try_high_quality_conversion(
|
||||
dest: &Path,
|
||||
reporter: &ProgressReporter,
|
||||
) -> Result<(), String> {
|
||||
reporter.update("正在转换视频格式... 0% (高质量模式)");
|
||||
reporter.update("正在转换视频格式... 0% (高质量模式)").await;
|
||||
|
||||
// 构建ffmpeg命令 - 高质量重编码
|
||||
let mut cmd = tokio::process::Command::new(ffmpeg_path());
|
||||
@@ -1094,7 +1103,7 @@ pub async fn convert_video_format(
|
||||
match try_stream_copy_conversion(source, dest, reporter).await {
|
||||
Ok(()) => Ok(()),
|
||||
Err(stream_copy_error) => {
|
||||
reporter.update("流复制失败,使用高质量重编码模式...");
|
||||
reporter.update("流复制失败,使用高质量重编码模式...").await;
|
||||
log::warn!("Stream copy failed: {stream_copy_error}, falling back to re-encoding");
|
||||
try_high_quality_conversion(source, dest, reporter).await
|
||||
}
|
||||
|
||||
@@ -10,6 +10,8 @@ use crate::progress::progress_reporter::ProgressReporterTrait;
|
||||
use crate::recorder_manager::RecorderList;
|
||||
use crate::state::State;
|
||||
use crate::state_type;
|
||||
use crate::task::Task;
|
||||
use crate::task::TaskPriority;
|
||||
use crate::webhook::events;
|
||||
use recorder::account::Account;
|
||||
use recorder::danmu::DanmuEntry;
|
||||
@@ -457,7 +459,7 @@ pub async fn generate_whole_clip(
|
||||
let emitter = EventEmitter::new(state.app_handle.clone());
|
||||
#[cfg(feature = "headless")]
|
||||
let emitter = EventEmitter::new(state.progress_manager.get_event_sender());
|
||||
let reporter = ProgressReporter::new(&emitter, &task.id).await?;
|
||||
let reporter = ProgressReporter::new(state.db.clone(), &emitter, &task.id).await?;
|
||||
|
||||
log::info!("Create task: {} {}", task.id, task.task_type);
|
||||
// create a tokio task to run in background
|
||||
@@ -467,27 +469,42 @@ pub async fn generate_whole_clip(
|
||||
let state_clone = state.clone();
|
||||
|
||||
let task_id = task.id.clone();
|
||||
tokio::spawn(async move {
|
||||
match state_clone
|
||||
.recorder_manager
|
||||
.generate_whole_clip(Some(&reporter), encode_danmu, platform, &room_id, parent_id)
|
||||
.await
|
||||
{
|
||||
Ok(()) => {
|
||||
reporter.finish(true, "切片生成完成").await;
|
||||
let _ = state_clone
|
||||
.db
|
||||
.update_task(&task_id, "success", "切片生成完成", None)
|
||||
.await;
|
||||
}
|
||||
Err(e) => {
|
||||
reporter.finish(false, &format!("切片生成失败: {e}")).await;
|
||||
let _ = state_clone
|
||||
.db
|
||||
.update_task(&task_id, "failed", &format!("切片生成失败: {e}"), None)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
});
|
||||
state
|
||||
.task_manager
|
||||
.add_task(Task::new(
|
||||
task_id.clone(),
|
||||
TaskPriority::Normal,
|
||||
async move {
|
||||
match state_clone
|
||||
.recorder_manager
|
||||
.generate_whole_clip(
|
||||
Some(&reporter),
|
||||
encode_danmu,
|
||||
platform,
|
||||
&room_id,
|
||||
parent_id,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(()) => {
|
||||
reporter.finish(true, "切片生成完成").await;
|
||||
let _ = state_clone
|
||||
.db
|
||||
.update_task(&task_id, "success", "切片生成完成", None)
|
||||
.await;
|
||||
Ok(())
|
||||
}
|
||||
Err(e) => {
|
||||
reporter.finish(false, &format!("切片生成失败: {e}")).await;
|
||||
let _ = state_clone
|
||||
.db
|
||||
.update_task(&task_id, "failed", &format!("切片生成失败: {e}"), None)
|
||||
.await;
|
||||
Err(format!("切片生成失败: {e}"))
|
||||
}
|
||||
}
|
||||
},
|
||||
))
|
||||
.await?;
|
||||
Ok(task)
|
||||
}
|
||||
|
||||
@@ -2,9 +2,7 @@ use crate::database::task::TaskRow;
|
||||
use crate::database::video::VideoRow;
|
||||
use crate::ffmpeg;
|
||||
use crate::handlers::utils::get_disk_info_inner;
|
||||
use crate::progress::progress_reporter::{
|
||||
cancel_progress, EventEmitter, ProgressReporter, ProgressReporterTrait,
|
||||
};
|
||||
use crate::progress::progress_reporter::{EventEmitter, ProgressReporter, ProgressReporterTrait};
|
||||
use crate::recorder_manager::ClipRangeParams;
|
||||
use crate::subtitle_generator::item_to_srt;
|
||||
use crate::webhook::events;
|
||||
@@ -125,7 +123,9 @@ async fn copy_file_with_progress(
|
||||
let report_threshold = 1; // 每1%报告一次
|
||||
|
||||
if percent != last_reported_percent && (percent % report_threshold == 0 || percent == 100) {
|
||||
reporter.update(&format!("正在复制视频文件... {percent}%"));
|
||||
reporter
|
||||
.update(&format!("正在复制视频文件... {percent}%"))
|
||||
.await;
|
||||
last_reported_percent = percent;
|
||||
}
|
||||
}
|
||||
@@ -155,11 +155,13 @@ async fn copy_and_convert_with_progress(
|
||||
|
||||
if is_network_source {
|
||||
// 网络文件:先复制到本地临时位置,再转换
|
||||
reporter.update("检测到网络文件,使用先复制后转换策略...");
|
||||
reporter
|
||||
.update("检测到网络文件,使用先复制后转换策略...")
|
||||
.await;
|
||||
copy_then_convert_strategy(source, dest, reporter).await
|
||||
} else {
|
||||
// 本地文件:直接转换(更高效)
|
||||
reporter.update("检测到本地文件,使用直接转换策略...");
|
||||
reporter.update("检测到本地文件,使用直接转换策略...").await;
|
||||
ffmpeg::convert_video_format(source, dest, reporter).await
|
||||
}
|
||||
}
|
||||
@@ -185,11 +187,13 @@ async fn copy_then_convert_strategy(
|
||||
}
|
||||
|
||||
// 第一步:将网络文件复制到本地临时位置(使用优化的缓冲区)
|
||||
reporter.update("第1步:从网络复制文件到本地临时位置...");
|
||||
reporter
|
||||
.update("第1步:从网络复制文件到本地临时位置...")
|
||||
.await;
|
||||
copy_file_with_network_optimization(source, &temp_path, reporter).await?;
|
||||
|
||||
// 第二步:从本地临时文件转换到目标位置
|
||||
reporter.update("第2步:从临时文件转换到目标格式...");
|
||||
reporter.update("第2步:从临时文件转换到目标格式...").await;
|
||||
let convert_result = ffmpeg::convert_video_format(&temp_path, dest, reporter).await;
|
||||
|
||||
// 清理临时文件
|
||||
@@ -251,12 +255,14 @@ async fn copy_file_with_network_optimization(
|
||||
|
||||
// 网络文件更频繁地报告进度
|
||||
if percent != last_reported_percent {
|
||||
reporter.update(&format!(
|
||||
"正在从网络复制文件... {}% ({:.1}MB/{:.1}MB)",
|
||||
percent,
|
||||
copied as f64 / (1024.0 * 1024.0),
|
||||
total_size as f64 / (1024.0 * 1024.0)
|
||||
));
|
||||
reporter
|
||||
.update(&format!(
|
||||
"正在从网络复制文件... {}% ({:.1}MB/{:.1}MB)",
|
||||
percent,
|
||||
copied as f64 / (1024.0 * 1024.0),
|
||||
total_size as f64 / (1024.0 * 1024.0)
|
||||
))
|
||||
.await;
|
||||
last_reported_percent = percent;
|
||||
}
|
||||
}
|
||||
@@ -270,9 +276,11 @@ async fn copy_file_with_network_optimization(
|
||||
|
||||
// 等待一小段时间后重试
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await;
|
||||
reporter.update(&format!(
|
||||
"网络连接中断,正在重试... ({consecutive_errors}/{MAX_RETRIES})"
|
||||
));
|
||||
reporter
|
||||
.update(&format!(
|
||||
"网络连接中断,正在重试... ({consecutive_errors}/{MAX_RETRIES})"
|
||||
))
|
||||
.await;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -280,7 +288,7 @@ async fn copy_file_with_network_optimization(
|
||||
dest_file
|
||||
.flush()
|
||||
.map_err(|e| format!("刷新临时文件缓冲区失败: {e}"))?;
|
||||
reporter.update("网络文件复制完成");
|
||||
reporter.update("网络文件复制完成").await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -313,7 +321,7 @@ pub async fn clip_range(
|
||||
let emitter = EventEmitter::new(state.app_handle.clone());
|
||||
#[cfg(feature = "headless")]
|
||||
let emitter = EventEmitter::new(state.progress_manager.get_event_sender());
|
||||
let reporter = ProgressReporter::new(&emitter, &event_id).await?;
|
||||
let reporter = ProgressReporter::new(state.db.clone(), &emitter, &event_id).await?;
|
||||
let mut params_without_cover = params.clone();
|
||||
params_without_cover.cover = String::new();
|
||||
let task = TaskRow {
|
||||
@@ -507,7 +515,7 @@ pub async fn upload_procedure(
|
||||
let emitter = EventEmitter::new(state.app_handle.clone());
|
||||
#[cfg(feature = "headless")]
|
||||
let emitter = EventEmitter::new(state.progress_manager.get_event_sender());
|
||||
let reporter = ProgressReporter::new(&emitter, &event_id).await?;
|
||||
let reporter = ProgressReporter::new(state.db.clone(), &emitter, &event_id).await?;
|
||||
let task = TaskRow {
|
||||
id: event_id.clone(),
|
||||
task_type: "upload_procedure".to_string(),
|
||||
@@ -562,7 +570,7 @@ async fn upload_procedure_inner(
|
||||
let path = Path::new(&file);
|
||||
let client = reqwest::Client::new();
|
||||
let cover_url = bilibili::api::upload_cover(&client, &account.to_account(), &cover).await;
|
||||
reporter.update("投稿预处理中");
|
||||
reporter.update("投稿预处理中").await;
|
||||
|
||||
match bilibili::api::prepare_video(&client, &account.to_account(), path).await {
|
||||
Ok(video) => {
|
||||
@@ -614,8 +622,13 @@ async fn upload_procedure_inner(
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "gui", tauri::command)]
|
||||
pub async fn cancel(_state: state_type!(), event_id: String) -> Result<(), String> {
|
||||
cancel_progress(&event_id).await;
|
||||
pub async fn cancel(state: state_type!(), event_id: String) -> Result<(), String> {
|
||||
log::info!("Cancel task: {event_id}");
|
||||
state.task_manager.cancel_task(&event_id).await?;
|
||||
state
|
||||
.db
|
||||
.update_task(&event_id, "cancelled", "任务取消", None)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -743,7 +756,7 @@ pub async fn generate_video_subtitle(
|
||||
let emitter = EventEmitter::new(state.app_handle.clone());
|
||||
#[cfg(feature = "headless")]
|
||||
let emitter = EventEmitter::new(state.progress_manager.get_event_sender());
|
||||
let reporter = ProgressReporter::new(&emitter, &event_id).await?;
|
||||
let reporter = ProgressReporter::new(state.db.clone(), &emitter, &event_id).await?;
|
||||
let task = TaskRow {
|
||||
id: event_id.clone(),
|
||||
task_type: "generate_video_subtitle".to_string(),
|
||||
@@ -861,7 +874,7 @@ pub async fn encode_video_subtitle(
|
||||
let emitter = EventEmitter::new(state.app_handle.clone());
|
||||
#[cfg(feature = "headless")]
|
||||
let emitter = EventEmitter::new(state.progress_manager.get_event_sender());
|
||||
let reporter = ProgressReporter::new(&emitter, &event_id).await?;
|
||||
let reporter = ProgressReporter::new(state.db.clone(), &emitter, &event_id).await?;
|
||||
let task = TaskRow {
|
||||
id: event_id.clone(),
|
||||
task_type: "encode_video_subtitle".to_string(),
|
||||
@@ -956,14 +969,14 @@ pub async fn import_external_video(
|
||||
#[cfg(feature = "headless")]
|
||||
let emitter = EventEmitter::new(state.progress_manager.get_event_sender());
|
||||
|
||||
let reporter = ProgressReporter::new(&emitter, &event_id).await?;
|
||||
let reporter = ProgressReporter::new(state.db.clone(), &emitter, &event_id).await?;
|
||||
|
||||
let source_path = Path::new(&file_path);
|
||||
if !source_path.exists() {
|
||||
return Err("文件不存在".to_string());
|
||||
}
|
||||
|
||||
reporter.update("正在提取视频元数据...");
|
||||
reporter.update("正在提取视频元数据...").await;
|
||||
let metadata = ffmpeg::extract_video_metadata(source_path).await?;
|
||||
let output_str = state.config.read().await.output.clone();
|
||||
let output_dir = Path::new(&output_str);
|
||||
@@ -989,7 +1002,7 @@ pub async fn import_external_video(
|
||||
let final_target_full_path = if need_conversion {
|
||||
let mp4_target_full_path = target_full_path.with_extension("mp4");
|
||||
|
||||
reporter.update("准备转换视频格式 (FLV → MP4)...");
|
||||
reporter.update("准备转换视频格式 (FLV → MP4)...").await;
|
||||
|
||||
copy_and_convert_with_progress(source_path, &mp4_target_full_path, true, &reporter).await?;
|
||||
|
||||
@@ -1008,7 +1021,7 @@ pub async fn import_external_video(
|
||||
};
|
||||
|
||||
// 步骤3: 生成缩略图
|
||||
reporter.update("正在生成视频缩略图...");
|
||||
reporter.update("正在生成视频缩略图...").await;
|
||||
|
||||
// 生成缩略图,使用智能时间点选择
|
||||
let thumbnail_timestamp = get_optimal_thumbnail_timestamp(metadata.duration);
|
||||
@@ -1022,7 +1035,7 @@ pub async fn import_external_video(
|
||||
};
|
||||
|
||||
// 步骤4: 保存到数据库
|
||||
reporter.update("正在保存视频信息...");
|
||||
reporter.update("正在保存视频信息...").await;
|
||||
|
||||
let Ok(size) = i64::try_from(
|
||||
final_target_full_path
|
||||
@@ -1091,7 +1104,7 @@ pub async fn clip_video(
|
||||
let emitter = EventEmitter::new(state.app_handle.clone());
|
||||
#[cfg(feature = "headless")]
|
||||
let emitter = EventEmitter::new(state.progress_manager.get_event_sender());
|
||||
let reporter = ProgressReporter::new(&emitter, &event_id).await?;
|
||||
let reporter = ProgressReporter::new(state.db.clone(), &emitter, &event_id).await?;
|
||||
|
||||
// 创建任务记录
|
||||
let task = TaskRow {
|
||||
@@ -1185,7 +1198,7 @@ async fn clip_video_inner(
|
||||
let output_full_path = output_dir.join(&output_filename);
|
||||
|
||||
// 执行切片
|
||||
reporter.update("开始切片处理");
|
||||
reporter.update("开始切片处理").await;
|
||||
ffmpeg::clip_from_video_file(
|
||||
Some(reporter),
|
||||
&input_path,
|
||||
@@ -1314,7 +1327,7 @@ pub async fn batch_import_external_videos(
|
||||
let emitter = EventEmitter::new(state.app_handle.clone());
|
||||
#[cfg(feature = "headless")]
|
||||
let emitter = EventEmitter::new(state.progress_manager.get_event_sender());
|
||||
let batch_reporter = ProgressReporter::new(&emitter, &event_id).await?;
|
||||
let batch_reporter = ProgressReporter::new(state.db.clone(), &emitter, &event_id).await?;
|
||||
|
||||
let total_files = file_paths.len();
|
||||
|
||||
@@ -1327,9 +1340,11 @@ pub async fn batch_import_external_videos(
|
||||
.to_string();
|
||||
|
||||
// 更新批量进度,只显示进度信息
|
||||
batch_reporter.update(&format!(
|
||||
"正在导入第{current_index}个,共{total_files}个文件"
|
||||
));
|
||||
batch_reporter
|
||||
.update(&format!(
|
||||
"正在导入第{current_index}个,共{total_files}个文件"
|
||||
))
|
||||
.await;
|
||||
|
||||
// 为每个文件创建独立的事件ID
|
||||
let file_event_id = format!("{event_id}_file_{index}");
|
||||
|
||||
@@ -14,6 +14,7 @@ mod progress;
|
||||
mod recorder_manager;
|
||||
mod state;
|
||||
mod subtitle_generator;
|
||||
mod task;
|
||||
#[cfg(feature = "gui")]
|
||||
mod tray;
|
||||
mod webhook;
|
||||
@@ -421,6 +422,7 @@ impl MigrationSource<'static> for MigrationList {
|
||||
async fn setup_server_state(args: Args) -> Result<State, Box<dyn std::error::Error>> {
|
||||
use std::path::PathBuf;
|
||||
|
||||
use crate::task::TaskManager;
|
||||
use progress::progress_manager::ProgressManager;
|
||||
use progress::progress_reporter::EventEmitter;
|
||||
|
||||
@@ -467,10 +469,14 @@ async fn setup_server_state(args: Args) -> Result<State, Box<dyn std::error::Err
|
||||
let emitter = EventEmitter::new(progress_manager.get_event_sender());
|
||||
let webhook_poster =
|
||||
webhook::poster::create_webhook_poster(&config.read().await.webhook_url, None).unwrap();
|
||||
let mut task_manager = TaskManager::new();
|
||||
task_manager.start();
|
||||
let task_manager = Arc::new(task_manager);
|
||||
let recorder_manager = Arc::new(RecorderManager::new(
|
||||
emitter,
|
||||
db.clone(),
|
||||
config.clone(),
|
||||
task_manager.clone(),
|
||||
webhook_poster.clone(),
|
||||
));
|
||||
|
||||
@@ -485,6 +491,7 @@ async fn setup_server_state(args: Args) -> Result<State, Box<dyn std::error::Err
|
||||
config,
|
||||
webhook_poster,
|
||||
recorder_manager,
|
||||
task_manager,
|
||||
progress_manager,
|
||||
readonly: args.readonly,
|
||||
})
|
||||
@@ -495,6 +502,8 @@ async fn setup_app_state(app: &tauri::App) -> Result<State, Box<dyn std::error::
|
||||
use platform_dirs::AppDirs;
|
||||
use progress::progress_reporter::EventEmitter;
|
||||
|
||||
use crate::task::TaskManager;
|
||||
|
||||
let log_dir = app.path().app_log_dir()?;
|
||||
setup_logging(&log_dir).await?;
|
||||
|
||||
@@ -527,12 +536,17 @@ async fn setup_app_state(app: &tauri::App) -> Result<State, Box<dyn std::error::
|
||||
db_clone.finish_pending_tasks().await?;
|
||||
let webhook_poster =
|
||||
webhook::poster::create_webhook_poster(&config.read().await.webhook_url, None).unwrap();
|
||||
let mut task_manager = TaskManager::new();
|
||||
task_manager.start();
|
||||
|
||||
let task_manager = Arc::new(task_manager);
|
||||
|
||||
let recorder_manager = Arc::new(RecorderManager::new(
|
||||
app.app_handle().clone(),
|
||||
emitter,
|
||||
db.clone(),
|
||||
config.clone(),
|
||||
task_manager.clone(),
|
||||
webhook_poster.clone(),
|
||||
));
|
||||
|
||||
@@ -551,6 +565,7 @@ async fn setup_app_state(app: &tauri::App) -> Result<State, Box<dyn std::error::
|
||||
db,
|
||||
config,
|
||||
recorder_manager,
|
||||
task_manager,
|
||||
app_handle: app.handle().clone(),
|
||||
webhook_poster,
|
||||
})
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
use async_trait::async_trait;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::Arc;
|
||||
use std::sync::LazyLock;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
use recorder::events::RecorderEvent;
|
||||
|
||||
use crate::database::Database;
|
||||
|
||||
#[cfg(feature = "gui")]
|
||||
use {
|
||||
recorder::danmu::DanmuEntry,
|
||||
@@ -16,22 +15,16 @@ use {
|
||||
#[cfg(feature = "headless")]
|
||||
use tokio::sync::broadcast;
|
||||
|
||||
type CancelFlagMap = std::collections::HashMap<String, Arc<AtomicBool>>;
|
||||
|
||||
static CANCEL_FLAG_MAP: LazyLock<Arc<RwLock<CancelFlagMap>>> =
|
||||
LazyLock::new(|| Arc::new(RwLock::new(CancelFlagMap::new())));
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ProgressReporter {
|
||||
emitter: EventEmitter,
|
||||
pub event_id: String,
|
||||
#[allow(unused)]
|
||||
pub cancel: Arc<AtomicBool>,
|
||||
db: Arc<Database>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait ProgressReporterTrait: Send + Sync + Clone {
|
||||
fn update(&self, content: &str);
|
||||
async fn update(&self, content: &str);
|
||||
async fn finish(&self, success: bool, message: &str);
|
||||
}
|
||||
|
||||
@@ -119,39 +112,30 @@ impl EventEmitter {
|
||||
}
|
||||
}
|
||||
impl ProgressReporter {
|
||||
pub async fn new(emitter: &EventEmitter, event_id: &str) -> Result<Self, String> {
|
||||
// if already exists, return
|
||||
if CANCEL_FLAG_MAP.read().await.get(event_id).is_some() {
|
||||
log::error!("Task already exists: {event_id}");
|
||||
emitter.emit(&RecorderEvent::ProgressFinished {
|
||||
id: event_id.to_string(),
|
||||
success: false,
|
||||
message: "任务已经存在".to_string(),
|
||||
});
|
||||
return Err("任务已经存在".to_string());
|
||||
}
|
||||
|
||||
let cancel = Arc::new(AtomicBool::new(false));
|
||||
CANCEL_FLAG_MAP
|
||||
.write()
|
||||
.await
|
||||
.insert(event_id.to_string(), cancel.clone());
|
||||
|
||||
pub async fn new(
|
||||
db: Arc<Database>,
|
||||
emitter: &EventEmitter,
|
||||
event_id: &str,
|
||||
) -> Result<Self, String> {
|
||||
Ok(Self {
|
||||
db,
|
||||
emitter: emitter.clone(),
|
||||
event_id: event_id.to_string(),
|
||||
cancel,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ProgressReporterTrait for ProgressReporter {
|
||||
fn update(&self, content: &str) {
|
||||
async fn update(&self, content: &str) {
|
||||
self.emitter.emit(&RecorderEvent::ProgressUpdate {
|
||||
id: self.event_id.clone(),
|
||||
content: content.to_string(),
|
||||
});
|
||||
let _ = self
|
||||
.db
|
||||
.update_task(&self.event_id, "processing", content, None)
|
||||
.await;
|
||||
}
|
||||
|
||||
async fn finish(&self, success: bool, message: &str) {
|
||||
@@ -160,13 +144,5 @@ impl ProgressReporterTrait for ProgressReporter {
|
||||
success,
|
||||
message: message.to_string(),
|
||||
});
|
||||
CANCEL_FLAG_MAP.write().await.remove(&self.event_id);
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn cancel_progress(event_id: &str) {
|
||||
let mut cancel_flag_map = CANCEL_FLAG_MAP.write().await;
|
||||
if let Some(cancel_flag) = cancel_flag_map.get_mut(event_id) {
|
||||
cancel_flag.store(true, std::sync::atomic::Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ use crate::database::{Database, DatabaseError};
|
||||
use crate::ffmpeg::{encode_video_danmu, transcode, Range};
|
||||
use crate::progress::progress_reporter::{EventEmitter, ProgressReporter, ProgressReporterTrait};
|
||||
use crate::subtitle_generator::item_to_srt;
|
||||
use crate::task::{Task, TaskManager, TaskPriority};
|
||||
use crate::webhook::events::{self, Payload};
|
||||
use crate::webhook::poster::WebhookPoster;
|
||||
use chrono::DateTime;
|
||||
@@ -123,6 +124,7 @@ pub struct RecorderManager {
|
||||
emitter: EventEmitter,
|
||||
db: Arc<Database>,
|
||||
config: Arc<RwLock<Config>>,
|
||||
task_manager: Arc<TaskManager>,
|
||||
recorders: Arc<RwLock<HashMap<String, RecorderType>>>,
|
||||
to_remove: Arc<RwLock<HashSet<String>>>,
|
||||
event_tx: broadcast::Sender<RecorderEvent>,
|
||||
@@ -176,6 +178,7 @@ impl RecorderManager {
|
||||
emitter: EventEmitter,
|
||||
db: Arc<Database>,
|
||||
config: Arc<RwLock<Config>>,
|
||||
task_manager: Arc<TaskManager>,
|
||||
webhook_poster: WebhookPoster,
|
||||
) -> RecorderManager {
|
||||
let (event_tx, _) = broadcast::channel(100);
|
||||
@@ -185,6 +188,7 @@ impl RecorderManager {
|
||||
emitter,
|
||||
db,
|
||||
config,
|
||||
task_manager,
|
||||
recorders: Arc::new(RwLock::new(HashMap::new())),
|
||||
to_remove: Arc::new(RwLock::new(HashSet::new())),
|
||||
event_tx,
|
||||
@@ -360,7 +364,8 @@ impl RecorderManager {
|
||||
return;
|
||||
};
|
||||
|
||||
let Ok(reporter) = ProgressReporter::new(&self.emitter, &task.id).await else {
|
||||
let Ok(reporter) = ProgressReporter::new(self.db.clone(), &self.emitter, &task.id).await
|
||||
else {
|
||||
log::error!("Failed to create reporter");
|
||||
let _ = self
|
||||
.db
|
||||
@@ -371,43 +376,56 @@ impl RecorderManager {
|
||||
|
||||
log::info!("Create task: {} {}", task.id, task.task_type);
|
||||
|
||||
if let Err(e) = self
|
||||
.generate_whole_clip(
|
||||
Some(&reporter),
|
||||
self.config.read().await.auto_generate.encode_danmu,
|
||||
platform.as_str().to_string(),
|
||||
room_id,
|
||||
live_record.parent_id,
|
||||
)
|
||||
.await
|
||||
{
|
||||
log::error!("Failed to generate whole clip: {e}");
|
||||
let _ = reporter
|
||||
.finish(false, &format!("Failed to generate whole clip: {e}"))
|
||||
.await;
|
||||
let _ = self
|
||||
.db
|
||||
.update_task(
|
||||
&task.id,
|
||||
"failed",
|
||||
&format!("Failed to generate whole clip: {e}"),
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
|
||||
let _ = reporter
|
||||
.finish(true, "Whole clip generated successfully")
|
||||
.await;
|
||||
let self_clone = self.clone();
|
||||
let task_id = task.id.clone();
|
||||
let room_id = room_id.to_string();
|
||||
let _ = self
|
||||
.db
|
||||
.update_task(
|
||||
&task.id,
|
||||
"success",
|
||||
"Whole clip generated successfully",
|
||||
None,
|
||||
)
|
||||
.task_manager
|
||||
.add_task(Task::new(
|
||||
task_id.clone(),
|
||||
TaskPriority::Normal,
|
||||
async move {
|
||||
if let Err(e) = self_clone
|
||||
.generate_whole_clip(
|
||||
Some(&reporter),
|
||||
self_clone.config.read().await.auto_generate.encode_danmu,
|
||||
platform.as_str().to_string(),
|
||||
&room_id,
|
||||
live_record.parent_id,
|
||||
)
|
||||
.await
|
||||
{
|
||||
log::error!("Failed to generate whole clip: {e}");
|
||||
let _ = reporter
|
||||
.finish(false, &format!("Failed to generate whole clip: {e}"))
|
||||
.await;
|
||||
let _ = self_clone
|
||||
.db
|
||||
.update_task(
|
||||
&task_id,
|
||||
"failed",
|
||||
&format!("Failed to generate whole clip: {e}"),
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
return Err(format!("Failed to generate whole clip: {e}"));
|
||||
}
|
||||
|
||||
let _ = reporter
|
||||
.finish(true, "Whole clip generated successfully")
|
||||
.await;
|
||||
let _ = self_clone
|
||||
.db
|
||||
.update_task(
|
||||
&task_id,
|
||||
"success",
|
||||
"Whole clip generated successfully",
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
Ok(())
|
||||
},
|
||||
))
|
||||
.await;
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ use tokio::sync::RwLock;
|
||||
use crate::config::Config;
|
||||
use crate::database::Database;
|
||||
use crate::recorder_manager::RecorderManager;
|
||||
use crate::task::TaskManager;
|
||||
use crate::webhook::poster::WebhookPoster;
|
||||
|
||||
#[cfg(feature = "headless")]
|
||||
@@ -15,6 +16,7 @@ pub struct State {
|
||||
pub config: Arc<RwLock<Config>>,
|
||||
pub webhook_poster: WebhookPoster,
|
||||
pub recorder_manager: Arc<RecorderManager>,
|
||||
pub task_manager: Arc<TaskManager>,
|
||||
#[cfg(not(feature = "headless"))]
|
||||
pub app_handle: tauri::AppHandle,
|
||||
#[cfg(feature = "headless")]
|
||||
|
||||
@@ -106,7 +106,7 @@ impl SubtitleGeneratorType {
|
||||
pub trait SubtitleGenerator {
|
||||
async fn generate_subtitle(
|
||||
&self,
|
||||
reporter: Option<&impl ProgressReporterTrait>,
|
||||
reporter: Option<&(impl ProgressReporterTrait + 'static)>,
|
||||
audio_path: &Path,
|
||||
language_hint: &str,
|
||||
) -> Result<GenerateResult, String>;
|
||||
|
||||
@@ -36,7 +36,7 @@ pub async fn new(model: &Path, prompt: &str) -> Result<WhisperCPP, String> {
|
||||
impl SubtitleGenerator for WhisperCPP {
|
||||
async fn generate_subtitle(
|
||||
&self,
|
||||
reporter: Option<&impl ProgressReporterTrait>,
|
||||
reporter: Option<&(impl ProgressReporterTrait + 'static)>,
|
||||
audio_path: &Path,
|
||||
language_hint: &str,
|
||||
) -> Result<GenerateResult, String> {
|
||||
@@ -71,7 +71,7 @@ impl SubtitleGenerator for WhisperCPP {
|
||||
let mut inter_samples = vec![Default::default(); samples.len()];
|
||||
|
||||
if let Some(reporter) = reporter {
|
||||
reporter.update("处理音频中");
|
||||
reporter.update("处理音频中").await;
|
||||
}
|
||||
if let Err(e) = whisper_rs::convert_integer_to_float_audio(&samples, &mut inter_samples) {
|
||||
return Err(e.to_string());
|
||||
@@ -85,7 +85,7 @@ impl SubtitleGenerator for WhisperCPP {
|
||||
let samples = samples.unwrap();
|
||||
|
||||
if let Some(reporter) = reporter {
|
||||
reporter.update("生成字幕中");
|
||||
reporter.update("生成字幕中").await;
|
||||
}
|
||||
if let Err(e) = state.full(params, &samples[..]) {
|
||||
log::error!("failed to run model: {e}");
|
||||
@@ -143,14 +143,14 @@ mod tests {
|
||||
struct MockReporter {}
|
||||
impl MockReporter {
|
||||
#[allow(dead_code)]
|
||||
fn update(&self, _message: &str) {
|
||||
async fn update(&self, _message: &str) {
|
||||
// mock implementation
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ProgressReporterTrait for MockReporter {
|
||||
fn update(&self, message: &str) {
|
||||
async fn update(&self, message: &str) {
|
||||
println!("Mock update: {message}");
|
||||
}
|
||||
|
||||
|
||||
@@ -54,7 +54,7 @@ pub async fn new(
|
||||
impl SubtitleGenerator for WhisperOnline {
|
||||
async fn generate_subtitle(
|
||||
&self,
|
||||
reporter: Option<&impl ProgressReporterTrait>,
|
||||
reporter: Option<&(impl ProgressReporterTrait + 'static)>,
|
||||
audio_path: &Path,
|
||||
language_hint: &str,
|
||||
) -> Result<GenerateResult, String> {
|
||||
@@ -63,7 +63,7 @@ impl SubtitleGenerator for WhisperOnline {
|
||||
|
||||
// Read audio file
|
||||
if let Some(reporter) = reporter {
|
||||
reporter.update("读取音频文件中");
|
||||
reporter.update("读取音频文件中").await;
|
||||
}
|
||||
let audio_data = fs::read(audio_path)
|
||||
.await
|
||||
@@ -115,7 +115,7 @@ impl SubtitleGenerator for WhisperOnline {
|
||||
}
|
||||
|
||||
if let Some(reporter) = reporter {
|
||||
reporter.update("上传音频中");
|
||||
reporter.update("上传音频中").await;
|
||||
}
|
||||
let response = req_builder
|
||||
.timeout(std::time::Duration::from_secs(3 * 60)) // 3 minutes timeout
|
||||
@@ -195,7 +195,7 @@ mod tests {
|
||||
|
||||
#[async_trait]
|
||||
impl ProgressReporterTrait for MockReporter {
|
||||
fn update(&self, message: &str) {
|
||||
async fn update(&self, message: &str) {
|
||||
println!("Mock update: {message}");
|
||||
}
|
||||
|
||||
|
||||
437
src-tauri/src/task/mod.rs
Normal file
437
src-tauri/src/task/mod.rs
Normal file
@@ -0,0 +1,437 @@
|
||||
use std::collections::{HashMap, VecDeque};
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{Mutex, RwLock};
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio::time::{interval, Duration};
|
||||
|
||||
/// Task execution function type
|
||||
pub type TaskFn = Pin<Box<dyn Future<Output = Result<(), String>> + Send + 'static>>;
|
||||
|
||||
/// Task status enumeration
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum TaskStatus {
|
||||
Pending,
|
||||
Running,
|
||||
Completed,
|
||||
Failed(String),
|
||||
Cancelled,
|
||||
}
|
||||
|
||||
/// Task priority levels
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
|
||||
#[allow(dead_code)]
|
||||
pub enum TaskPriority {
|
||||
Low = 0,
|
||||
Normal = 1,
|
||||
High = 2,
|
||||
Critical = 3,
|
||||
}
|
||||
|
||||
/// Represents a single task
|
||||
#[allow(dead_code)]
|
||||
pub struct Task {
|
||||
pub task_id: String,
|
||||
pub priority: TaskPriority,
|
||||
pub status: TaskStatus,
|
||||
pub created_at: std::time::Instant,
|
||||
task_fn: Option<TaskFn>,
|
||||
cancel_handle: Option<tokio::sync::oneshot::Sender<()>>,
|
||||
}
|
||||
|
||||
impl Task {
|
||||
pub fn new<F>(task_id: String, priority: TaskPriority, task_fn: F) -> Self
|
||||
where
|
||||
F: Future<Output = Result<(), String>> + Send + 'static,
|
||||
{
|
||||
Task {
|
||||
task_id,
|
||||
priority,
|
||||
status: TaskStatus::Pending,
|
||||
created_at: std::time::Instant::now(),
|
||||
task_fn: Some(Box::pin(task_fn)),
|
||||
cancel_handle: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn new_with_default_priority<F>(task_id: String, task_fn: F) -> Self
|
||||
where
|
||||
F: Future<Output = Result<(), String>> + Send + 'static,
|
||||
{
|
||||
Self::new(task_id, TaskPriority::Normal, task_fn)
|
||||
}
|
||||
}
|
||||
|
||||
/// Task manager configuration
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TaskManagerConfig {
|
||||
/// Maximum number of tasks that can run concurrently
|
||||
pub max_concurrent_tasks: usize,
|
||||
/// Interval for checking the queue (in milliseconds)
|
||||
pub check_interval_ms: u64,
|
||||
/// Maximum queue size (0 = unlimited)
|
||||
pub max_queue_size: usize,
|
||||
}
|
||||
|
||||
impl Default for TaskManagerConfig {
|
||||
fn default() -> Self {
|
||||
TaskManagerConfig {
|
||||
max_concurrent_tasks: 1,
|
||||
check_interval_ms: 1000,
|
||||
max_queue_size: 20,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Main task manager structure
|
||||
pub struct TaskManager {
|
||||
config: TaskManagerConfig,
|
||||
queue: Arc<Mutex<VecDeque<Task>>>,
|
||||
running_tasks: Arc<RwLock<HashMap<String, JoinHandle<()>>>>,
|
||||
task_statuses: Arc<RwLock<HashMap<String, TaskStatus>>>,
|
||||
scheduler_handle: Option<JoinHandle<()>>,
|
||||
shutdown_tx: Option<tokio::sync::broadcast::Sender<()>>,
|
||||
}
|
||||
|
||||
impl TaskManager {
|
||||
/// Create a new task manager with default configuration
|
||||
pub fn new() -> Self {
|
||||
Self::with_config(TaskManagerConfig::default())
|
||||
}
|
||||
|
||||
/// Create a new task manager with custom configuration
|
||||
pub fn with_config(config: TaskManagerConfig) -> Self {
|
||||
TaskManager {
|
||||
config,
|
||||
queue: Arc::new(Mutex::new(VecDeque::new())),
|
||||
running_tasks: Arc::new(RwLock::new(HashMap::new())),
|
||||
task_statuses: Arc::new(RwLock::new(HashMap::new())),
|
||||
scheduler_handle: None,
|
||||
shutdown_tx: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Start the task manager's scheduler
|
||||
pub fn start(&mut self) {
|
||||
let (shutdown_tx, _) = tokio::sync::broadcast::channel(1);
|
||||
self.shutdown_tx = Some(shutdown_tx.clone());
|
||||
|
||||
let queue = Arc::clone(&self.queue);
|
||||
let running_tasks = Arc::clone(&self.running_tasks);
|
||||
let task_statuses = Arc::clone(&self.task_statuses);
|
||||
let config = self.config.clone();
|
||||
let mut shutdown_rx = shutdown_tx.subscribe();
|
||||
|
||||
let handle = tokio::spawn(async move {
|
||||
let mut check_interval = interval(Duration::from_millis(config.check_interval_ms));
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = check_interval.tick() => {
|
||||
Self::process_queue(
|
||||
&queue,
|
||||
&running_tasks,
|
||||
&task_statuses,
|
||||
config.max_concurrent_tasks,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
_ = shutdown_rx.recv() => {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
self.scheduler_handle = Some(handle);
|
||||
}
|
||||
|
||||
/// Stop the task manager's scheduler
|
||||
#[allow(dead_code)]
|
||||
pub async fn stop(&mut self) {
|
||||
if let Some(tx) = self.shutdown_tx.take() {
|
||||
let _ = tx.send(());
|
||||
}
|
||||
|
||||
if let Some(handle) = self.scheduler_handle.take() {
|
||||
let _ = handle.await;
|
||||
}
|
||||
|
||||
// Cancel all running tasks
|
||||
let running_tasks = self.running_tasks.read().await;
|
||||
for handle in running_tasks.values() {
|
||||
handle.abort();
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a task to the queue
|
||||
pub async fn add_task(&self, task: Task) -> Result<(), String> {
|
||||
let mut queue = self.queue.lock().await;
|
||||
|
||||
if self.config.max_queue_size > 0 && queue.len() >= self.config.max_queue_size {
|
||||
return Err("Queue is full".to_string());
|
||||
}
|
||||
|
||||
let task_id = task.task_id.clone();
|
||||
|
||||
// Insert task based on priority
|
||||
let insert_pos = queue
|
||||
.iter()
|
||||
.position(|t| t.priority < task.priority)
|
||||
.unwrap_or(queue.len());
|
||||
|
||||
queue.insert(insert_pos, task);
|
||||
|
||||
// Update task status
|
||||
let mut statuses = self.task_statuses.write().await;
|
||||
statuses.insert(task_id, TaskStatus::Pending);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Cancel a task by ID
|
||||
pub async fn cancel_task(&self, task_id: &str) -> Result<(), String> {
|
||||
// Check if task is running
|
||||
let mut running_tasks = self.running_tasks.write().await;
|
||||
if let Some(handle) = running_tasks.remove(task_id) {
|
||||
handle.abort();
|
||||
drop(running_tasks); // Release lock before await
|
||||
let mut statuses = self.task_statuses.write().await;
|
||||
statuses.insert(task_id.to_string(), TaskStatus::Cancelled);
|
||||
return Ok(());
|
||||
}
|
||||
drop(running_tasks); // Release lock before await
|
||||
|
||||
// Check if task is in queue
|
||||
let pos = {
|
||||
let queue = self.queue.lock().await;
|
||||
queue.iter().position(|t| t.task_id == task_id)
|
||||
};
|
||||
|
||||
if let Some(pos) = pos {
|
||||
{
|
||||
let mut queue = self.queue.lock().await;
|
||||
queue.remove(pos);
|
||||
} // Release queue lock before await
|
||||
let mut statuses = self.task_statuses.write().await;
|
||||
statuses.insert(task_id.to_string(), TaskStatus::Cancelled);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
Err("Task not found".to_string())
|
||||
}
|
||||
|
||||
/// Get the status of a task
|
||||
#[allow(dead_code)]
|
||||
pub async fn get_task_status(&self, task_id: &str) -> Option<TaskStatus> {
|
||||
let statuses = self.task_statuses.read().await;
|
||||
statuses.get(task_id).cloned()
|
||||
}
|
||||
|
||||
/// Get all task statuses
|
||||
#[allow(dead_code)]
|
||||
pub async fn get_all_task_statuses(&self) -> HashMap<String, TaskStatus> {
|
||||
let statuses = self.task_statuses.read().await;
|
||||
statuses.clone()
|
||||
}
|
||||
|
||||
/// Get the number of tasks in queue
|
||||
#[allow(dead_code)]
|
||||
pub async fn queue_size(&self) -> usize {
|
||||
let queue = self.queue.lock().await;
|
||||
queue.len()
|
||||
}
|
||||
|
||||
/// Get the number of running tasks
|
||||
#[allow(dead_code)]
|
||||
pub async fn running_count(&self) -> usize {
|
||||
let running = self.running_tasks.read().await;
|
||||
running.len()
|
||||
}
|
||||
|
||||
/// Clear all completed and failed tasks from status tracking
|
||||
#[allow(dead_code)]
|
||||
pub async fn clear_finished_tasks(&self) {
|
||||
let mut statuses = self.task_statuses.write().await;
|
||||
statuses.retain(|_, status| {
|
||||
!matches!(
|
||||
status,
|
||||
TaskStatus::Completed | TaskStatus::Failed(_) | TaskStatus::Cancelled
|
||||
)
|
||||
});
|
||||
}
|
||||
|
||||
/// Process the queue and start tasks if slots are available
|
||||
async fn process_queue(
|
||||
queue: &Arc<Mutex<VecDeque<Task>>>,
|
||||
running_tasks: &Arc<RwLock<HashMap<String, JoinHandle<()>>>>,
|
||||
task_statuses: &Arc<RwLock<HashMap<String, TaskStatus>>>,
|
||||
max_concurrent: usize,
|
||||
) {
|
||||
// Check if we have available slots
|
||||
let running_count = {
|
||||
let running = running_tasks.read().await;
|
||||
running.len()
|
||||
};
|
||||
|
||||
if running_count >= max_concurrent {
|
||||
return;
|
||||
}
|
||||
|
||||
// Get the next task from queue
|
||||
let mut task_opt = {
|
||||
let mut queue = queue.lock().await;
|
||||
queue.pop_front()
|
||||
};
|
||||
|
||||
if let Some(mut task) = task_opt.take() {
|
||||
let task_id = task.task_id.clone();
|
||||
|
||||
// Update status to running
|
||||
{
|
||||
let mut statuses = task_statuses.write().await;
|
||||
statuses.insert(task_id.clone(), TaskStatus::Running);
|
||||
}
|
||||
|
||||
// Take the task function
|
||||
if let Some(task_fn) = task.task_fn.take() {
|
||||
let task_id_clone = task_id.clone();
|
||||
let running_tasks_clone = Arc::clone(running_tasks);
|
||||
let task_statuses_clone = Arc::clone(task_statuses);
|
||||
|
||||
// Spawn the task
|
||||
let handle = tokio::spawn(async move {
|
||||
let result = task_fn.await;
|
||||
|
||||
// Update status based on result
|
||||
let mut statuses = task_statuses_clone.write().await;
|
||||
match result {
|
||||
Ok(_) => {
|
||||
statuses.insert(task_id_clone.clone(), TaskStatus::Completed);
|
||||
}
|
||||
Err(e) => {
|
||||
statuses.insert(task_id_clone.clone(), TaskStatus::Failed(e));
|
||||
}
|
||||
}
|
||||
|
||||
// Remove from running tasks
|
||||
let mut running = running_tasks_clone.write().await;
|
||||
running.remove(&task_id_clone);
|
||||
});
|
||||
|
||||
// Add to running tasks
|
||||
let mut running = running_tasks.write().await;
|
||||
running.insert(task_id, handle);
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up completed tasks from running list
|
||||
Self::cleanup_finished_tasks(running_tasks).await;
|
||||
}
|
||||
|
||||
/// Clean up finished task handles
|
||||
async fn cleanup_finished_tasks(running_tasks: &Arc<RwLock<HashMap<String, JoinHandle<()>>>>) {
|
||||
let mut running = running_tasks.write().await;
|
||||
let mut finished_ids = Vec::new();
|
||||
|
||||
for (task_id, handle) in running.iter() {
|
||||
if handle.is_finished() {
|
||||
finished_ids.push(task_id.clone());
|
||||
}
|
||||
}
|
||||
|
||||
for task_id in finished_ids {
|
||||
running.remove(&task_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for TaskManager {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for TaskManager {
|
||||
fn drop(&mut self) {
|
||||
if let Some(tx) = self.shutdown_tx.take() {
|
||||
let _ = tx.send(());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio::time::sleep;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_task_manager_basic() {
|
||||
let mut manager = TaskManager::new();
|
||||
manager.start();
|
||||
|
||||
let task = Task::new("test-1".to_string(), TaskPriority::Normal, async {
|
||||
sleep(Duration::from_millis(100)).await;
|
||||
Ok(())
|
||||
});
|
||||
|
||||
manager.add_task(task).await.unwrap();
|
||||
sleep(Duration::from_millis(200)).await;
|
||||
|
||||
let status = manager.get_task_status("test-1").await;
|
||||
assert_eq!(status, Some(TaskStatus::Completed));
|
||||
|
||||
manager.stop().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_task_cancellation() {
|
||||
let mut manager = TaskManager::new();
|
||||
manager.start();
|
||||
|
||||
let task = Task::new("test-cancel".to_string(), TaskPriority::Normal, async {
|
||||
sleep(Duration::from_secs(10)).await;
|
||||
Ok(())
|
||||
});
|
||||
|
||||
manager.add_task(task).await.unwrap();
|
||||
sleep(Duration::from_millis(100)).await;
|
||||
|
||||
manager.cancel_task("test-cancel").await.unwrap();
|
||||
let status = manager.get_task_status("test-cancel").await;
|
||||
assert_eq!(status, Some(TaskStatus::Cancelled));
|
||||
|
||||
manager.stop().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_task_priority() {
|
||||
let mut manager = TaskManager::with_config(TaskManagerConfig {
|
||||
max_concurrent_tasks: 1,
|
||||
check_interval_ms: 100,
|
||||
max_queue_size: 10,
|
||||
});
|
||||
manager.start();
|
||||
|
||||
// Add low priority task
|
||||
let task1 = Task::new("low".to_string(), TaskPriority::Low, async {
|
||||
sleep(Duration::from_millis(100)).await;
|
||||
Ok(())
|
||||
});
|
||||
|
||||
// Add high priority task
|
||||
let task2 = Task::new("high".to_string(), TaskPriority::High, async {
|
||||
sleep(Duration::from_millis(100)).await;
|
||||
Ok(())
|
||||
});
|
||||
|
||||
manager.add_task(task1).await.unwrap();
|
||||
manager.add_task(task2).await.unwrap();
|
||||
|
||||
sleep(Duration::from_millis(500)).await;
|
||||
|
||||
manager.stop().await;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user