feat: task scheduler (#216)

* feat: update progress in task message

* feat: task scheduler for queued tasks
This commit is contained in:
Xinrea
2025-11-02 11:20:27 +08:00
committed by GitHub
parent 8bea9336ae
commit 83c6979973
13 changed files with 683 additions and 170 deletions

View File

@@ -43,7 +43,7 @@ pub async fn handle_ffmpeg_process(
if content.starts_with("out_time_ms") {
let time_str = content.strip_prefix("out_time_ms=").unwrap_or_default();
if let Some(reporter) = reporter {
reporter.update(time_str);
reporter.update(time_str).await;
}
}
}

View File

@@ -104,7 +104,8 @@ pub async fn transcode(
}
reporter
.unwrap()
.update(format!("压制中:{}", p.time).as_str());
.update(format!("压制中:{}", p.time).as_str())
.await;
}
FfmpegEvent::LogEOF => break,
FfmpegEvent::Error(e) => {
@@ -161,7 +162,8 @@ pub async fn trim_video(
}
reporter
.unwrap()
.update(format!("切片中:{}", p.time).as_str());
.update(format!("切片中:{}", p.time).as_str())
.await;
}
FfmpegEvent::LogEOF => break,
FfmpegEvent::Error(e) => {
@@ -429,7 +431,9 @@ pub async fn encode_video_subtitle(
}
FfmpegEvent::Progress(p) => {
log::info!("Encode video subtitle progress: {}", p.time);
reporter.update(format!("压制中:{}", p.time).as_str());
reporter
.update(format!("压制中:{}", p.time).as_str())
.await;
}
FfmpegEvent::LogEOF => break,
FfmpegEvent::Log(_level, _content) => {}
@@ -528,7 +532,8 @@ pub async fn encode_video_danmu(
}
reporter
.unwrap()
.update(format!("压制中:{}", p.time).as_str());
.update(format!("压制中:{}", p.time).as_str())
.await;
}
FfmpegEvent::Log(_level, _content) => {}
FfmpegEvent::LogEOF => break,
@@ -813,7 +818,7 @@ pub async fn clip_from_video_file(
match event {
FfmpegEvent::Progress(p) => {
if let Some(reporter) = reporter {
reporter.update(&format!("切片进度: {}", p.time));
reporter.update(&format!("切片进度: {}", p.time)).await;
}
}
FfmpegEvent::LogEOF => break,
@@ -980,7 +985,9 @@ pub async fn execute_ffmpeg_conversion(
while let Ok(event) = parser.parse_next_event().await {
match event {
FfmpegEvent::Progress(p) => {
reporter.update(&format!("正在转换视频格式... {} ({})", p.time, mode_name));
reporter
.update(&format!("正在转换视频格式... {} ({})", p.time, mode_name))
.await;
}
FfmpegEvent::LogEOF => break,
FfmpegEvent::Log(level, content) => {
@@ -1008,7 +1015,9 @@ pub async fn execute_ffmpeg_conversion(
return Err(format!("视频格式转换失败 ({mode_name}): {error_msg}"));
}
reporter.update(&format!("视频格式转换完成 100% ({mode_name})"));
reporter
.update(&format!("视频格式转换完成 100% ({mode_name})"))
.await;
Ok(())
}
@@ -1018,7 +1027,7 @@ pub async fn try_stream_copy_conversion(
dest: &Path,
reporter: &ProgressReporter,
) -> Result<(), String> {
reporter.update("正在转换视频格式... 0% (无损模式)");
reporter.update("正在转换视频格式... 0% (无损模式)").await;
// 构建ffmpeg命令 - 流复制模式
let mut cmd = tokio::process::Command::new(ffmpeg_path());
@@ -1051,7 +1060,7 @@ pub async fn try_high_quality_conversion(
dest: &Path,
reporter: &ProgressReporter,
) -> Result<(), String> {
reporter.update("正在转换视频格式... 0% (高质量模式)");
reporter.update("正在转换视频格式... 0% (高质量模式)").await;
// 构建ffmpeg命令 - 高质量重编码
let mut cmd = tokio::process::Command::new(ffmpeg_path());
@@ -1094,7 +1103,7 @@ pub async fn convert_video_format(
match try_stream_copy_conversion(source, dest, reporter).await {
Ok(()) => Ok(()),
Err(stream_copy_error) => {
reporter.update("流复制失败,使用高质量重编码模式...");
reporter.update("流复制失败,使用高质量重编码模式...").await;
log::warn!("Stream copy failed: {stream_copy_error}, falling back to re-encoding");
try_high_quality_conversion(source, dest, reporter).await
}

View File

@@ -10,6 +10,8 @@ use crate::progress::progress_reporter::ProgressReporterTrait;
use crate::recorder_manager::RecorderList;
use crate::state::State;
use crate::state_type;
use crate::task::Task;
use crate::task::TaskPriority;
use crate::webhook::events;
use recorder::account::Account;
use recorder::danmu::DanmuEntry;
@@ -457,7 +459,7 @@ pub async fn generate_whole_clip(
let emitter = EventEmitter::new(state.app_handle.clone());
#[cfg(feature = "headless")]
let emitter = EventEmitter::new(state.progress_manager.get_event_sender());
let reporter = ProgressReporter::new(&emitter, &task.id).await?;
let reporter = ProgressReporter::new(state.db.clone(), &emitter, &task.id).await?;
log::info!("Create task: {} {}", task.id, task.task_type);
// create a tokio task to run in background
@@ -467,27 +469,42 @@ pub async fn generate_whole_clip(
let state_clone = state.clone();
let task_id = task.id.clone();
tokio::spawn(async move {
match state_clone
.recorder_manager
.generate_whole_clip(Some(&reporter), encode_danmu, platform, &room_id, parent_id)
.await
{
Ok(()) => {
reporter.finish(true, "切片生成完成").await;
let _ = state_clone
.db
.update_task(&task_id, "success", "切片生成完成", None)
.await;
}
Err(e) => {
reporter.finish(false, &format!("切片生成失败: {e}")).await;
let _ = state_clone
.db
.update_task(&task_id, "failed", &format!("切片生成失败: {e}"), None)
.await;
}
}
});
state
.task_manager
.add_task(Task::new(
task_id.clone(),
TaskPriority::Normal,
async move {
match state_clone
.recorder_manager
.generate_whole_clip(
Some(&reporter),
encode_danmu,
platform,
&room_id,
parent_id,
)
.await
{
Ok(()) => {
reporter.finish(true, "切片生成完成").await;
let _ = state_clone
.db
.update_task(&task_id, "success", "切片生成完成", None)
.await;
Ok(())
}
Err(e) => {
reporter.finish(false, &format!("切片生成失败: {e}")).await;
let _ = state_clone
.db
.update_task(&task_id, "failed", &format!("切片生成失败: {e}"), None)
.await;
Err(format!("切片生成失败: {e}"))
}
}
},
))
.await?;
Ok(task)
}

View File

@@ -2,9 +2,7 @@ use crate::database::task::TaskRow;
use crate::database::video::VideoRow;
use crate::ffmpeg;
use crate::handlers::utils::get_disk_info_inner;
use crate::progress::progress_reporter::{
cancel_progress, EventEmitter, ProgressReporter, ProgressReporterTrait,
};
use crate::progress::progress_reporter::{EventEmitter, ProgressReporter, ProgressReporterTrait};
use crate::recorder_manager::ClipRangeParams;
use crate::subtitle_generator::item_to_srt;
use crate::webhook::events;
@@ -125,7 +123,9 @@ async fn copy_file_with_progress(
let report_threshold = 1; // 每1%报告一次
if percent != last_reported_percent && (percent % report_threshold == 0 || percent == 100) {
reporter.update(&format!("正在复制视频文件... {percent}%"));
reporter
.update(&format!("正在复制视频文件... {percent}%"))
.await;
last_reported_percent = percent;
}
}
@@ -155,11 +155,13 @@ async fn copy_and_convert_with_progress(
if is_network_source {
// 网络文件:先复制到本地临时位置,再转换
reporter.update("检测到网络文件,使用先复制后转换策略...");
reporter
.update("检测到网络文件,使用先复制后转换策略...")
.await;
copy_then_convert_strategy(source, dest, reporter).await
} else {
// 本地文件:直接转换(更高效)
reporter.update("检测到本地文件,使用直接转换策略...");
reporter.update("检测到本地文件,使用直接转换策略...").await;
ffmpeg::convert_video_format(source, dest, reporter).await
}
}
@@ -185,11 +187,13 @@ async fn copy_then_convert_strategy(
}
// 第一步:将网络文件复制到本地临时位置(使用优化的缓冲区)
reporter.update("第1步从网络复制文件到本地临时位置...");
reporter
.update("第1步从网络复制文件到本地临时位置...")
.await;
copy_file_with_network_optimization(source, &temp_path, reporter).await?;
// 第二步:从本地临时文件转换到目标位置
reporter.update("第2步从临时文件转换到目标格式...");
reporter.update("第2步从临时文件转换到目标格式...").await;
let convert_result = ffmpeg::convert_video_format(&temp_path, dest, reporter).await;
// 清理临时文件
@@ -251,12 +255,14 @@ async fn copy_file_with_network_optimization(
// 网络文件更频繁地报告进度
if percent != last_reported_percent {
reporter.update(&format!(
"正在从网络复制文件... {}% ({:.1}MB/{:.1}MB)",
percent,
copied as f64 / (1024.0 * 1024.0),
total_size as f64 / (1024.0 * 1024.0)
));
reporter
.update(&format!(
"正在从网络复制文件... {}% ({:.1}MB/{:.1}MB)",
percent,
copied as f64 / (1024.0 * 1024.0),
total_size as f64 / (1024.0 * 1024.0)
))
.await;
last_reported_percent = percent;
}
}
@@ -270,9 +276,11 @@ async fn copy_file_with_network_optimization(
// 等待一小段时间后重试
tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await;
reporter.update(&format!(
"网络连接中断,正在重试... ({consecutive_errors}/{MAX_RETRIES})"
));
reporter
.update(&format!(
"网络连接中断,正在重试... ({consecutive_errors}/{MAX_RETRIES})"
))
.await;
}
}
}
@@ -280,7 +288,7 @@ async fn copy_file_with_network_optimization(
dest_file
.flush()
.map_err(|e| format!("刷新临时文件缓冲区失败: {e}"))?;
reporter.update("网络文件复制完成");
reporter.update("网络文件复制完成").await;
Ok(())
}
@@ -313,7 +321,7 @@ pub async fn clip_range(
let emitter = EventEmitter::new(state.app_handle.clone());
#[cfg(feature = "headless")]
let emitter = EventEmitter::new(state.progress_manager.get_event_sender());
let reporter = ProgressReporter::new(&emitter, &event_id).await?;
let reporter = ProgressReporter::new(state.db.clone(), &emitter, &event_id).await?;
let mut params_without_cover = params.clone();
params_without_cover.cover = String::new();
let task = TaskRow {
@@ -507,7 +515,7 @@ pub async fn upload_procedure(
let emitter = EventEmitter::new(state.app_handle.clone());
#[cfg(feature = "headless")]
let emitter = EventEmitter::new(state.progress_manager.get_event_sender());
let reporter = ProgressReporter::new(&emitter, &event_id).await?;
let reporter = ProgressReporter::new(state.db.clone(), &emitter, &event_id).await?;
let task = TaskRow {
id: event_id.clone(),
task_type: "upload_procedure".to_string(),
@@ -562,7 +570,7 @@ async fn upload_procedure_inner(
let path = Path::new(&file);
let client = reqwest::Client::new();
let cover_url = bilibili::api::upload_cover(&client, &account.to_account(), &cover).await;
reporter.update("投稿预处理中");
reporter.update("投稿预处理中").await;
match bilibili::api::prepare_video(&client, &account.to_account(), path).await {
Ok(video) => {
@@ -614,8 +622,13 @@ async fn upload_procedure_inner(
}
#[cfg_attr(feature = "gui", tauri::command)]
pub async fn cancel(_state: state_type!(), event_id: String) -> Result<(), String> {
cancel_progress(&event_id).await;
pub async fn cancel(state: state_type!(), event_id: String) -> Result<(), String> {
log::info!("Cancel task: {event_id}");
state.task_manager.cancel_task(&event_id).await?;
state
.db
.update_task(&event_id, "cancelled", "任务取消", None)
.await?;
Ok(())
}
@@ -743,7 +756,7 @@ pub async fn generate_video_subtitle(
let emitter = EventEmitter::new(state.app_handle.clone());
#[cfg(feature = "headless")]
let emitter = EventEmitter::new(state.progress_manager.get_event_sender());
let reporter = ProgressReporter::new(&emitter, &event_id).await?;
let reporter = ProgressReporter::new(state.db.clone(), &emitter, &event_id).await?;
let task = TaskRow {
id: event_id.clone(),
task_type: "generate_video_subtitle".to_string(),
@@ -861,7 +874,7 @@ pub async fn encode_video_subtitle(
let emitter = EventEmitter::new(state.app_handle.clone());
#[cfg(feature = "headless")]
let emitter = EventEmitter::new(state.progress_manager.get_event_sender());
let reporter = ProgressReporter::new(&emitter, &event_id).await?;
let reporter = ProgressReporter::new(state.db.clone(), &emitter, &event_id).await?;
let task = TaskRow {
id: event_id.clone(),
task_type: "encode_video_subtitle".to_string(),
@@ -956,14 +969,14 @@ pub async fn import_external_video(
#[cfg(feature = "headless")]
let emitter = EventEmitter::new(state.progress_manager.get_event_sender());
let reporter = ProgressReporter::new(&emitter, &event_id).await?;
let reporter = ProgressReporter::new(state.db.clone(), &emitter, &event_id).await?;
let source_path = Path::new(&file_path);
if !source_path.exists() {
return Err("文件不存在".to_string());
}
reporter.update("正在提取视频元数据...");
reporter.update("正在提取视频元数据...").await;
let metadata = ffmpeg::extract_video_metadata(source_path).await?;
let output_str = state.config.read().await.output.clone();
let output_dir = Path::new(&output_str);
@@ -989,7 +1002,7 @@ pub async fn import_external_video(
let final_target_full_path = if need_conversion {
let mp4_target_full_path = target_full_path.with_extension("mp4");
reporter.update("准备转换视频格式 (FLV → MP4)...");
reporter.update("准备转换视频格式 (FLV → MP4)...").await;
copy_and_convert_with_progress(source_path, &mp4_target_full_path, true, &reporter).await?;
@@ -1008,7 +1021,7 @@ pub async fn import_external_video(
};
// 步骤3: 生成缩略图
reporter.update("正在生成视频缩略图...");
reporter.update("正在生成视频缩略图...").await;
// 生成缩略图,使用智能时间点选择
let thumbnail_timestamp = get_optimal_thumbnail_timestamp(metadata.duration);
@@ -1022,7 +1035,7 @@ pub async fn import_external_video(
};
// 步骤4: 保存到数据库
reporter.update("正在保存视频信息...");
reporter.update("正在保存视频信息...").await;
let Ok(size) = i64::try_from(
final_target_full_path
@@ -1091,7 +1104,7 @@ pub async fn clip_video(
let emitter = EventEmitter::new(state.app_handle.clone());
#[cfg(feature = "headless")]
let emitter = EventEmitter::new(state.progress_manager.get_event_sender());
let reporter = ProgressReporter::new(&emitter, &event_id).await?;
let reporter = ProgressReporter::new(state.db.clone(), &emitter, &event_id).await?;
// 创建任务记录
let task = TaskRow {
@@ -1185,7 +1198,7 @@ async fn clip_video_inner(
let output_full_path = output_dir.join(&output_filename);
// 执行切片
reporter.update("开始切片处理");
reporter.update("开始切片处理").await;
ffmpeg::clip_from_video_file(
Some(reporter),
&input_path,
@@ -1314,7 +1327,7 @@ pub async fn batch_import_external_videos(
let emitter = EventEmitter::new(state.app_handle.clone());
#[cfg(feature = "headless")]
let emitter = EventEmitter::new(state.progress_manager.get_event_sender());
let batch_reporter = ProgressReporter::new(&emitter, &event_id).await?;
let batch_reporter = ProgressReporter::new(state.db.clone(), &emitter, &event_id).await?;
let total_files = file_paths.len();
@@ -1327,9 +1340,11 @@ pub async fn batch_import_external_videos(
.to_string();
// 更新批量进度,只显示进度信息
batch_reporter.update(&format!(
"正在导入第{current_index}个,共{total_files}个文件"
));
batch_reporter
.update(&format!(
"正在导入第{current_index}个,共{total_files}个文件"
))
.await;
// 为每个文件创建独立的事件ID
let file_event_id = format!("{event_id}_file_{index}");

View File

@@ -14,6 +14,7 @@ mod progress;
mod recorder_manager;
mod state;
mod subtitle_generator;
mod task;
#[cfg(feature = "gui")]
mod tray;
mod webhook;
@@ -421,6 +422,7 @@ impl MigrationSource<'static> for MigrationList {
async fn setup_server_state(args: Args) -> Result<State, Box<dyn std::error::Error>> {
use std::path::PathBuf;
use crate::task::TaskManager;
use progress::progress_manager::ProgressManager;
use progress::progress_reporter::EventEmitter;
@@ -467,10 +469,14 @@ async fn setup_server_state(args: Args) -> Result<State, Box<dyn std::error::Err
let emitter = EventEmitter::new(progress_manager.get_event_sender());
let webhook_poster =
webhook::poster::create_webhook_poster(&config.read().await.webhook_url, None).unwrap();
let mut task_manager = TaskManager::new();
task_manager.start();
let task_manager = Arc::new(task_manager);
let recorder_manager = Arc::new(RecorderManager::new(
emitter,
db.clone(),
config.clone(),
task_manager.clone(),
webhook_poster.clone(),
));
@@ -485,6 +491,7 @@ async fn setup_server_state(args: Args) -> Result<State, Box<dyn std::error::Err
config,
webhook_poster,
recorder_manager,
task_manager,
progress_manager,
readonly: args.readonly,
})
@@ -495,6 +502,8 @@ async fn setup_app_state(app: &tauri::App) -> Result<State, Box<dyn std::error::
use platform_dirs::AppDirs;
use progress::progress_reporter::EventEmitter;
use crate::task::TaskManager;
let log_dir = app.path().app_log_dir()?;
setup_logging(&log_dir).await?;
@@ -527,12 +536,17 @@ async fn setup_app_state(app: &tauri::App) -> Result<State, Box<dyn std::error::
db_clone.finish_pending_tasks().await?;
let webhook_poster =
webhook::poster::create_webhook_poster(&config.read().await.webhook_url, None).unwrap();
let mut task_manager = TaskManager::new();
task_manager.start();
let task_manager = Arc::new(task_manager);
let recorder_manager = Arc::new(RecorderManager::new(
app.app_handle().clone(),
emitter,
db.clone(),
config.clone(),
task_manager.clone(),
webhook_poster.clone(),
));
@@ -551,6 +565,7 @@ async fn setup_app_state(app: &tauri::App) -> Result<State, Box<dyn std::error::
db,
config,
recorder_manager,
task_manager,
app_handle: app.handle().clone(),
webhook_poster,
})

View File

@@ -1,11 +1,10 @@
use async_trait::async_trait;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use std::sync::LazyLock;
use tokio::sync::RwLock;
use recorder::events::RecorderEvent;
use crate::database::Database;
#[cfg(feature = "gui")]
use {
recorder::danmu::DanmuEntry,
@@ -16,22 +15,16 @@ use {
#[cfg(feature = "headless")]
use tokio::sync::broadcast;
type CancelFlagMap = std::collections::HashMap<String, Arc<AtomicBool>>;
static CANCEL_FLAG_MAP: LazyLock<Arc<RwLock<CancelFlagMap>>> =
LazyLock::new(|| Arc::new(RwLock::new(CancelFlagMap::new())));
#[derive(Clone)]
pub struct ProgressReporter {
emitter: EventEmitter,
pub event_id: String,
#[allow(unused)]
pub cancel: Arc<AtomicBool>,
db: Arc<Database>,
}
#[async_trait]
pub trait ProgressReporterTrait: Send + Sync + Clone {
fn update(&self, content: &str);
async fn update(&self, content: &str);
async fn finish(&self, success: bool, message: &str);
}
@@ -119,39 +112,30 @@ impl EventEmitter {
}
}
impl ProgressReporter {
pub async fn new(emitter: &EventEmitter, event_id: &str) -> Result<Self, String> {
// if already exists, return
if CANCEL_FLAG_MAP.read().await.get(event_id).is_some() {
log::error!("Task already exists: {event_id}");
emitter.emit(&RecorderEvent::ProgressFinished {
id: event_id.to_string(),
success: false,
message: "任务已经存在".to_string(),
});
return Err("任务已经存在".to_string());
}
let cancel = Arc::new(AtomicBool::new(false));
CANCEL_FLAG_MAP
.write()
.await
.insert(event_id.to_string(), cancel.clone());
pub async fn new(
db: Arc<Database>,
emitter: &EventEmitter,
event_id: &str,
) -> Result<Self, String> {
Ok(Self {
db,
emitter: emitter.clone(),
event_id: event_id.to_string(),
cancel,
})
}
}
#[async_trait]
impl ProgressReporterTrait for ProgressReporter {
fn update(&self, content: &str) {
async fn update(&self, content: &str) {
self.emitter.emit(&RecorderEvent::ProgressUpdate {
id: self.event_id.clone(),
content: content.to_string(),
});
let _ = self
.db
.update_task(&self.event_id, "processing", content, None)
.await;
}
async fn finish(&self, success: bool, message: &str) {
@@ -160,13 +144,5 @@ impl ProgressReporterTrait for ProgressReporter {
success,
message: message.to_string(),
});
CANCEL_FLAG_MAP.write().await.remove(&self.event_id);
}
}
pub async fn cancel_progress(event_id: &str) {
let mut cancel_flag_map = CANCEL_FLAG_MAP.write().await;
if let Some(cancel_flag) = cancel_flag_map.get_mut(event_id) {
cancel_flag.store(true, std::sync::atomic::Ordering::Relaxed);
}
}

View File

@@ -7,6 +7,7 @@ use crate::database::{Database, DatabaseError};
use crate::ffmpeg::{encode_video_danmu, transcode, Range};
use crate::progress::progress_reporter::{EventEmitter, ProgressReporter, ProgressReporterTrait};
use crate::subtitle_generator::item_to_srt;
use crate::task::{Task, TaskManager, TaskPriority};
use crate::webhook::events::{self, Payload};
use crate::webhook::poster::WebhookPoster;
use chrono::DateTime;
@@ -123,6 +124,7 @@ pub struct RecorderManager {
emitter: EventEmitter,
db: Arc<Database>,
config: Arc<RwLock<Config>>,
task_manager: Arc<TaskManager>,
recorders: Arc<RwLock<HashMap<String, RecorderType>>>,
to_remove: Arc<RwLock<HashSet<String>>>,
event_tx: broadcast::Sender<RecorderEvent>,
@@ -176,6 +178,7 @@ impl RecorderManager {
emitter: EventEmitter,
db: Arc<Database>,
config: Arc<RwLock<Config>>,
task_manager: Arc<TaskManager>,
webhook_poster: WebhookPoster,
) -> RecorderManager {
let (event_tx, _) = broadcast::channel(100);
@@ -185,6 +188,7 @@ impl RecorderManager {
emitter,
db,
config,
task_manager,
recorders: Arc::new(RwLock::new(HashMap::new())),
to_remove: Arc::new(RwLock::new(HashSet::new())),
event_tx,
@@ -360,7 +364,8 @@ impl RecorderManager {
return;
};
let Ok(reporter) = ProgressReporter::new(&self.emitter, &task.id).await else {
let Ok(reporter) = ProgressReporter::new(self.db.clone(), &self.emitter, &task.id).await
else {
log::error!("Failed to create reporter");
let _ = self
.db
@@ -371,43 +376,56 @@ impl RecorderManager {
log::info!("Create task: {} {}", task.id, task.task_type);
if let Err(e) = self
.generate_whole_clip(
Some(&reporter),
self.config.read().await.auto_generate.encode_danmu,
platform.as_str().to_string(),
room_id,
live_record.parent_id,
)
.await
{
log::error!("Failed to generate whole clip: {e}");
let _ = reporter
.finish(false, &format!("Failed to generate whole clip: {e}"))
.await;
let _ = self
.db
.update_task(
&task.id,
"failed",
&format!("Failed to generate whole clip: {e}"),
None,
)
.await;
return;
}
let _ = reporter
.finish(true, "Whole clip generated successfully")
.await;
let self_clone = self.clone();
let task_id = task.id.clone();
let room_id = room_id.to_string();
let _ = self
.db
.update_task(
&task.id,
"success",
"Whole clip generated successfully",
None,
)
.task_manager
.add_task(Task::new(
task_id.clone(),
TaskPriority::Normal,
async move {
if let Err(e) = self_clone
.generate_whole_clip(
Some(&reporter),
self_clone.config.read().await.auto_generate.encode_danmu,
platform.as_str().to_string(),
&room_id,
live_record.parent_id,
)
.await
{
log::error!("Failed to generate whole clip: {e}");
let _ = reporter
.finish(false, &format!("Failed to generate whole clip: {e}"))
.await;
let _ = self_clone
.db
.update_task(
&task_id,
"failed",
&format!("Failed to generate whole clip: {e}"),
None,
)
.await;
return Err(format!("Failed to generate whole clip: {e}"));
}
let _ = reporter
.finish(true, "Whole clip generated successfully")
.await;
let _ = self_clone
.db
.update_task(
&task_id,
"success",
"Whole clip generated successfully",
None,
)
.await;
Ok(())
},
))
.await;
}

View File

@@ -4,6 +4,7 @@ use tokio::sync::RwLock;
use crate::config::Config;
use crate::database::Database;
use crate::recorder_manager::RecorderManager;
use crate::task::TaskManager;
use crate::webhook::poster::WebhookPoster;
#[cfg(feature = "headless")]
@@ -15,6 +16,7 @@ pub struct State {
pub config: Arc<RwLock<Config>>,
pub webhook_poster: WebhookPoster,
pub recorder_manager: Arc<RecorderManager>,
pub task_manager: Arc<TaskManager>,
#[cfg(not(feature = "headless"))]
pub app_handle: tauri::AppHandle,
#[cfg(feature = "headless")]

View File

@@ -106,7 +106,7 @@ impl SubtitleGeneratorType {
pub trait SubtitleGenerator {
async fn generate_subtitle(
&self,
reporter: Option<&impl ProgressReporterTrait>,
reporter: Option<&(impl ProgressReporterTrait + 'static)>,
audio_path: &Path,
language_hint: &str,
) -> Result<GenerateResult, String>;

View File

@@ -36,7 +36,7 @@ pub async fn new(model: &Path, prompt: &str) -> Result<WhisperCPP, String> {
impl SubtitleGenerator for WhisperCPP {
async fn generate_subtitle(
&self,
reporter: Option<&impl ProgressReporterTrait>,
reporter: Option<&(impl ProgressReporterTrait + 'static)>,
audio_path: &Path,
language_hint: &str,
) -> Result<GenerateResult, String> {
@@ -71,7 +71,7 @@ impl SubtitleGenerator for WhisperCPP {
let mut inter_samples = vec![Default::default(); samples.len()];
if let Some(reporter) = reporter {
reporter.update("处理音频中");
reporter.update("处理音频中").await;
}
if let Err(e) = whisper_rs::convert_integer_to_float_audio(&samples, &mut inter_samples) {
return Err(e.to_string());
@@ -85,7 +85,7 @@ impl SubtitleGenerator for WhisperCPP {
let samples = samples.unwrap();
if let Some(reporter) = reporter {
reporter.update("生成字幕中");
reporter.update("生成字幕中").await;
}
if let Err(e) = state.full(params, &samples[..]) {
log::error!("failed to run model: {e}");
@@ -143,14 +143,14 @@ mod tests {
struct MockReporter {}
impl MockReporter {
#[allow(dead_code)]
fn update(&self, _message: &str) {
async fn update(&self, _message: &str) {
// mock implementation
}
}
#[async_trait]
impl ProgressReporterTrait for MockReporter {
fn update(&self, message: &str) {
async fn update(&self, message: &str) {
println!("Mock update: {message}");
}

View File

@@ -54,7 +54,7 @@ pub async fn new(
impl SubtitleGenerator for WhisperOnline {
async fn generate_subtitle(
&self,
reporter: Option<&impl ProgressReporterTrait>,
reporter: Option<&(impl ProgressReporterTrait + 'static)>,
audio_path: &Path,
language_hint: &str,
) -> Result<GenerateResult, String> {
@@ -63,7 +63,7 @@ impl SubtitleGenerator for WhisperOnline {
// Read audio file
if let Some(reporter) = reporter {
reporter.update("读取音频文件中");
reporter.update("读取音频文件中").await;
}
let audio_data = fs::read(audio_path)
.await
@@ -115,7 +115,7 @@ impl SubtitleGenerator for WhisperOnline {
}
if let Some(reporter) = reporter {
reporter.update("上传音频中");
reporter.update("上传音频中").await;
}
let response = req_builder
.timeout(std::time::Duration::from_secs(3 * 60)) // 3 minutes timeout
@@ -195,7 +195,7 @@ mod tests {
#[async_trait]
impl ProgressReporterTrait for MockReporter {
fn update(&self, message: &str) {
async fn update(&self, message: &str) {
println!("Mock update: {message}");
}

437
src-tauri/src/task/mod.rs Normal file
View File

@@ -0,0 +1,437 @@
use std::collections::{HashMap, VecDeque};
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use tokio::sync::{Mutex, RwLock};
use tokio::task::JoinHandle;
use tokio::time::{interval, Duration};
/// Type-erased task body: a pinned, boxed future that resolves to `Ok(())`
/// on success or an error message on failure.
pub type TaskFn = Pin<Box<dyn Future<Output = Result<(), String>> + Send + 'static>>;
/// Lifecycle state of a task as tracked by the `TaskManager`.
#[derive(Debug, Clone, PartialEq)]
pub enum TaskStatus {
    /// Queued, not yet started.
    Pending,
    /// Currently executing.
    Running,
    /// Finished with `Ok(())`.
    Completed,
    /// Finished with an error; the payload is the error message.
    Failed(String),
    /// Removed from the queue, or aborted while running.
    Cancelled,
}
/// Task priority levels; larger discriminants are dequeued first
/// (the derived `Ord` follows the explicit discriminant values).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[allow(dead_code)]
pub enum TaskPriority {
    Low = 0,
    Normal = 1,
    High = 2,
    Critical = 3,
}
/// Represents a single schedulable unit of work.
#[allow(dead_code)]
pub struct Task {
    /// Unique identifier; also the key for cancellation and status lookup.
    pub task_id: String,
    /// Scheduling priority; higher-priority tasks are dequeued first.
    pub priority: TaskPriority,
    /// Initial status snapshot; the manager's `task_statuses` map is the
    /// authoritative source once the task is queued.
    pub status: TaskStatus,
    /// When this `Task` value was constructed.
    pub created_at: std::time::Instant,
    /// The work to run; `Option::take`n by the scheduler when it starts.
    task_fn: Option<TaskFn>,
    /// NOTE(review): never populated anywhere in this file — cancellation
    /// goes through `JoinHandle::abort` instead. Candidate for removal.
    cancel_handle: Option<tokio::sync::oneshot::Sender<()>>,
}
impl Task {
pub fn new<F>(task_id: String, priority: TaskPriority, task_fn: F) -> Self
where
F: Future<Output = Result<(), String>> + Send + 'static,
{
Task {
task_id,
priority,
status: TaskStatus::Pending,
created_at: std::time::Instant::now(),
task_fn: Some(Box::pin(task_fn)),
cancel_handle: None,
}
}
#[allow(dead_code)]
pub fn new_with_default_priority<F>(task_id: String, task_fn: F) -> Self
where
F: Future<Output = Result<(), String>> + Send + 'static,
{
Self::new(task_id, TaskPriority::Normal, task_fn)
}
}
/// Tuning knobs for the `TaskManager` scheduler.
#[derive(Debug, Clone)]
pub struct TaskManagerConfig {
    /// Maximum number of tasks that may run concurrently.
    pub max_concurrent_tasks: usize,
    /// Scheduler wake-up interval for checking the queue, in milliseconds.
    pub check_interval_ms: u64,
    /// Maximum number of queued (pending) tasks; 0 means unlimited.
    pub max_queue_size: usize,
}
impl Default for TaskManagerConfig {
    /// Defaults: one task at a time, a 1-second scheduler tick, and the
    /// queue capped at 20 pending tasks.
    fn default() -> Self {
        TaskManagerConfig {
            max_concurrent_tasks: 1,
            check_interval_ms: 1000,
            max_queue_size: 20,
        }
    }
}
/// Priority task scheduler: holds a pending queue, tracks running tasks
/// and their statuses, and drives execution from a background loop
/// spawned by `start`.
pub struct TaskManager {
    config: TaskManagerConfig,
    /// Pending tasks, kept ordered by priority (FIFO within a priority).
    queue: Arc<Mutex<VecDeque<Task>>>,
    /// Join handles of tasks currently executing, keyed by task ID.
    running_tasks: Arc<RwLock<HashMap<String, JoinHandle<()>>>>,
    /// Last known status for every task this manager has seen.
    task_statuses: Arc<RwLock<HashMap<String, TaskStatus>>>,
    /// Handle of the scheduler loop spawned by `start`.
    scheduler_handle: Option<JoinHandle<()>>,
    /// Broadcast channel used to tell the scheduler loop to exit.
    shutdown_tx: Option<tokio::sync::broadcast::Sender<()>>,
}
impl TaskManager {
    /// Create a task manager with the default configuration
    /// (one concurrent task, 1-second tick, queue capped at 20).
    pub fn new() -> Self {
        Self::with_config(TaskManagerConfig::default())
    }

    /// Create a task manager with a custom configuration.
    pub fn with_config(config: TaskManagerConfig) -> Self {
        TaskManager {
            config,
            queue: Arc::new(Mutex::new(VecDeque::new())),
            running_tasks: Arc::new(RwLock::new(HashMap::new())),
            task_statuses: Arc::new(RwLock::new(HashMap::new())),
            scheduler_handle: None,
            shutdown_tx: None,
        }
    }

    /// Start the background scheduler loop.
    ///
    /// The loop wakes every `check_interval_ms` and promotes queued tasks
    /// to running while free slots exist. Expected to be called exactly
    /// once; calling it again would spawn a second loop and drop the
    /// handle of the first.
    pub fn start(&mut self) {
        let (shutdown_tx, _) = tokio::sync::broadcast::channel(1);
        self.shutdown_tx = Some(shutdown_tx.clone());
        let queue = Arc::clone(&self.queue);
        let running_tasks = Arc::clone(&self.running_tasks);
        let task_statuses = Arc::clone(&self.task_statuses);
        let config = self.config.clone();
        let mut shutdown_rx = shutdown_tx.subscribe();
        let handle = tokio::spawn(async move {
            let mut check_interval = interval(Duration::from_millis(config.check_interval_ms));
            loop {
                tokio::select! {
                    _ = check_interval.tick() => {
                        Self::process_queue(
                            &queue,
                            &running_tasks,
                            &task_statuses,
                            config.max_concurrent_tasks,
                        )
                        .await;
                    }
                    _ = shutdown_rx.recv() => {
                        break;
                    }
                }
            }
        });
        self.scheduler_handle = Some(handle);
    }

    /// Stop the scheduler and abort any still-running tasks.
    #[allow(dead_code)]
    pub async fn stop(&mut self) {
        if let Some(tx) = self.shutdown_tx.take() {
            let _ = tx.send(());
        }
        if let Some(handle) = self.scheduler_handle.take() {
            let _ = handle.await;
        }
        // Abort running tasks AND record the cancellation; previously the
        // handles were aborted but their statuses stayed stuck at `Running`.
        let mut running_tasks = self.running_tasks.write().await;
        let mut statuses = self.task_statuses.write().await;
        for (task_id, handle) in running_tasks.drain() {
            handle.abort();
            statuses.insert(task_id, TaskStatus::Cancelled);
        }
    }

    /// Add a task to the queue, keeping it ordered by priority
    /// (FIFO within the same priority).
    ///
    /// # Errors
    /// Returns an error when the queue is bounded and already full.
    pub async fn add_task(&self, task: Task) -> Result<(), String> {
        let mut queue = self.queue.lock().await;
        if self.config.max_queue_size > 0 && queue.len() >= self.config.max_queue_size {
            return Err("Queue is full".to_string());
        }
        let task_id = task.task_id.clone();
        // Insert before the first strictly-lower-priority task so equal
        // priorities preserve submission order.
        let insert_pos = queue
            .iter()
            .position(|t| t.priority < task.priority)
            .unwrap_or(queue.len());
        queue.insert(insert_pos, task);
        let mut statuses = self.task_statuses.write().await;
        statuses.insert(task_id, TaskStatus::Pending);
        Ok(())
    }

    /// Cancel a task by ID, whether it is running or still queued.
    ///
    /// # Errors
    /// Returns an error when no running or queued task has this ID.
    pub async fn cancel_task(&self, task_id: &str) -> Result<(), String> {
        // Running task: abort its JoinHandle.
        let handle = self.running_tasks.write().await.remove(task_id);
        if let Some(handle) = handle {
            handle.abort();
            let mut statuses = self.task_statuses.write().await;
            statuses.insert(task_id.to_string(), TaskStatus::Cancelled);
            return Ok(());
        }
        // Queued task: locate and remove it under a SINGLE lock. The
        // previous version released the lock between `position` and
        // `remove`, so the scheduler could pop a task in between and the
        // stale index would remove the wrong task.
        let removed = {
            let mut queue = self.queue.lock().await;
            match queue.iter().position(|t| t.task_id == task_id) {
                Some(pos) => queue.remove(pos),
                None => None,
            }
        };
        if removed.is_some() {
            let mut statuses = self.task_statuses.write().await;
            statuses.insert(task_id.to_string(), TaskStatus::Cancelled);
            return Ok(());
        }
        Err("Task not found".to_string())
    }

    /// Get the status of a task, if known.
    #[allow(dead_code)]
    pub async fn get_task_status(&self, task_id: &str) -> Option<TaskStatus> {
        self.task_statuses.read().await.get(task_id).cloned()
    }

    /// Snapshot all known task statuses.
    #[allow(dead_code)]
    pub async fn get_all_task_statuses(&self) -> HashMap<String, TaskStatus> {
        self.task_statuses.read().await.clone()
    }

    /// Number of tasks waiting in the queue.
    #[allow(dead_code)]
    pub async fn queue_size(&self) -> usize {
        self.queue.lock().await.len()
    }

    /// Number of currently running tasks.
    #[allow(dead_code)]
    pub async fn running_count(&self) -> usize {
        self.running_tasks.read().await.len()
    }

    /// Drop completed / failed / cancelled entries from status tracking.
    #[allow(dead_code)]
    pub async fn clear_finished_tasks(&self) {
        let mut statuses = self.task_statuses.write().await;
        statuses.retain(|_, status| {
            !matches!(
                status,
                TaskStatus::Completed | TaskStatus::Failed(_) | TaskStatus::Cancelled
            )
        });
    }

    /// Promote queued tasks to running until every slot is used or the
    /// queue is empty.
    async fn process_queue(
        queue: &Arc<Mutex<VecDeque<Task>>>,
        running_tasks: &Arc<RwLock<HashMap<String, JoinHandle<()>>>>,
        task_statuses: &Arc<RwLock<HashMap<String, TaskStatus>>>,
        max_concurrent: usize,
    ) {
        // Reap finished handles first so their slots are reusable within
        // this same tick.
        Self::cleanup_finished_tasks(running_tasks).await;
        // The previous version started at most ONE task per tick, which
        // throttled throughput to one task per `check_interval_ms` even
        // when multiple slots were free; loop until capacity is reached.
        loop {
            let running_count = running_tasks.read().await.len();
            if running_count >= max_concurrent {
                return;
            }
            let next = queue.lock().await.pop_front();
            let Some(mut task) = next else {
                return;
            };
            let task_id = task.task_id.clone();
            let Some(task_fn) = task.task_fn.take() else {
                // A task without a body can never run; record a failure
                // instead of silently dropping it (previously it was lost
                // with its status left at `Running`).
                task_statuses
                    .write()
                    .await
                    .insert(task_id, TaskStatus::Failed("task has no body".to_string()));
                continue;
            };
            task_statuses
                .write()
                .await
                .insert(task_id.clone(), TaskStatus::Running);
            let task_id_clone = task_id.clone();
            let running_tasks_clone = Arc::clone(running_tasks);
            let task_statuses_clone = Arc::clone(task_statuses);
            let handle = tokio::spawn(async move {
                let result = task_fn.await;
                {
                    let mut statuses = task_statuses_clone.write().await;
                    match result {
                        Ok(()) => statuses.insert(task_id_clone.clone(), TaskStatus::Completed),
                        Err(e) => statuses.insert(task_id_clone.clone(), TaskStatus::Failed(e)),
                    };
                }
                // Free the slot. If this runs before the handle is inserted
                // below, `cleanup_finished_tasks` reaps it on the next tick.
                running_tasks_clone.write().await.remove(&task_id_clone);
            });
            running_tasks.write().await.insert(task_id, handle);
        }
    }

    /// Drop join handles of tasks that have already finished.
    async fn cleanup_finished_tasks(running_tasks: &Arc<RwLock<HashMap<String, JoinHandle<()>>>>) {
        let mut running = running_tasks.write().await;
        // In-place O(n) removal instead of collecting IDs and removing
        // them one by one.
        running.retain(|_, handle| !handle.is_finished());
    }
}
impl Default for TaskManager {
    /// Equivalent to [`TaskManager::new`]; the scheduler is NOT started —
    /// callers must still invoke `start`.
    fn default() -> Self {
        Self::new()
    }
}
impl Drop for TaskManager {
    /// Best-effort shutdown: signal the scheduler loop to exit. Running
    /// tasks are NOT aborted here (`drop` cannot await); use `stop().await`
    /// for a clean shutdown.
    fn drop(&mut self) {
        if let Some(tx) = self.shutdown_tx.take() {
            let _ = tx.send(());
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::sleep;

    /// A queued task is picked up by the scheduler and completes.
    #[tokio::test]
    async fn test_task_manager_basic() {
        let mut manager = TaskManager::new();
        manager.start();
        let task = Task::new("test-1".to_string(), TaskPriority::Normal, async {
            sleep(Duration::from_millis(100)).await;
            Ok(())
        });
        manager.add_task(task).await.unwrap();
        sleep(Duration::from_millis(200)).await;
        let status = manager.get_task_status("test-1").await;
        assert_eq!(status, Some(TaskStatus::Completed));
        manager.stop().await;
    }

    /// Cancelling a running task aborts it and records `Cancelled`.
    #[tokio::test]
    async fn test_task_cancellation() {
        let mut manager = TaskManager::new();
        manager.start();
        let task = Task::new("test-cancel".to_string(), TaskPriority::Normal, async {
            sleep(Duration::from_secs(10)).await;
            Ok(())
        });
        manager.add_task(task).await.unwrap();
        sleep(Duration::from_millis(100)).await;
        manager.cancel_task("test-cancel").await.unwrap();
        let status = manager.get_task_status("test-cancel").await;
        assert_eq!(status, Some(TaskStatus::Cancelled));
        manager.stop().await;
    }

    /// Tasks of different priorities both run to completion.
    /// The original test asserted nothing; verify final statuses so a
    /// scheduling regression actually fails the test.
    #[tokio::test]
    async fn test_task_priority() {
        let mut manager = TaskManager::with_config(TaskManagerConfig {
            max_concurrent_tasks: 1,
            check_interval_ms: 100,
            max_queue_size: 10,
        });
        manager.start();
        // Add low priority task
        let task1 = Task::new("low".to_string(), TaskPriority::Low, async {
            sleep(Duration::from_millis(100)).await;
            Ok(())
        });
        // Add high priority task
        let task2 = Task::new("high".to_string(), TaskPriority::High, async {
            sleep(Duration::from_millis(100)).await;
            Ok(())
        });
        manager.add_task(task1).await.unwrap();
        manager.add_task(task2).await.unwrap();
        sleep(Duration::from_millis(500)).await;
        // Both must have completed; the high-priority task is dequeued
        // first, but only the final statuses are observable here.
        assert_eq!(
            manager.get_task_status("high").await,
            Some(TaskStatus::Completed)
        );
        assert_eq!(
            manager.get_task_status("low").await,
            Some(TaskStatus::Completed)
        );
        manager.stop().await;
    }
}