mirror of
https://github.com/sun-guannan/CapCutAPI.git
synced 2025-11-24 19:13:01 +08:00
save draft local
This commit is contained in:
20
oss.py
20
oss.py
@@ -4,28 +4,28 @@ import os
|
||||
from settings.local import OSS_CONFIG, MP4_OSS_CONFIG
|
||||
|
||||
def upload_to_oss(path):
|
||||
# 创建OSS客户端
|
||||
# Create OSS client
|
||||
auth = oss2.Auth(OSS_CONFIG['access_key_id'], OSS_CONFIG['access_key_secret'])
|
||||
bucket = oss2.Bucket(auth, OSS_CONFIG['endpoint'], OSS_CONFIG['bucket_name'])
|
||||
|
||||
# 上传文件
|
||||
# Upload file
|
||||
object_name = os.path.basename(path)
|
||||
bucket.put_object_from_file(object_name, path)
|
||||
|
||||
# 生成签名URL(24小时有效)
|
||||
# Generate signed URL (valid for 24 hours)
|
||||
url = bucket.sign_url('GET', object_name, 24 * 60 * 60)
|
||||
|
||||
# 清理临时文件
|
||||
# Clean up temporary file
|
||||
os.remove(path)
|
||||
|
||||
return url
|
||||
|
||||
def upload_mp4_to_oss(path):
|
||||
"""专门用于上传MP4文件的方法,使用自定义域名和v4签名"""
|
||||
# 直接使用配置文件中的凭证
|
||||
"""Special method for uploading MP4 files, using custom domain and v4 signature"""
|
||||
# Directly use credentials from the configuration file
|
||||
auth = oss2.AuthV4(MP4_OSS_CONFIG['access_key_id'], MP4_OSS_CONFIG['access_key_secret'])
|
||||
|
||||
# 创建OSS客户端,使用自定义域名
|
||||
# Create OSS client with custom domain
|
||||
bucket = oss2.Bucket(
|
||||
auth,
|
||||
MP4_OSS_CONFIG['endpoint'],
|
||||
@@ -34,14 +34,14 @@ def upload_mp4_to_oss(path):
|
||||
is_cname=True
|
||||
)
|
||||
|
||||
# 上传文件
|
||||
# Upload file
|
||||
object_name = os.path.basename(path)
|
||||
bucket.put_object_from_file(object_name, path)
|
||||
|
||||
# 生成预签名URL(24小时有效),设置slash_safe为True避免路径转义
|
||||
# Generate pre-signed URL (valid for 24 hours), set slash_safe to True to avoid path escaping
|
||||
url = bucket.sign_url('GET', object_name, 24 * 60 * 60, slash_safe=True)
|
||||
|
||||
# 清理临时文件
|
||||
# Clean up temporary file
|
||||
os.remove(path)
|
||||
|
||||
return url
|
||||
@@ -19,109 +19,100 @@ from collections import OrderedDict
|
||||
import time
|
||||
import requests # Import requests for making HTTP calls
|
||||
import logging
|
||||
# 导入配置
|
||||
# Import configuration
|
||||
from settings import IS_CAPCUT_ENV, IS_UPLOAD_DRAFT
|
||||
|
||||
# --- 获取你的 Logger 实例 ---
|
||||
# 这里的名称必须和你在 app.py 中配置的 logger 名称一致
|
||||
# --- Get your Logger instance ---
|
||||
# The name here must match the logger name you configured in app.py
|
||||
logger = logging.getLogger('flask_video_generator')
|
||||
|
||||
# 定义任务状态枚举类型
|
||||
# Define task status enumeration type
|
||||
TaskStatus = Literal["initialized", "processing", "completed", "failed", "not_found"]
|
||||
|
||||
def build_asset_path(draft_folder: str, draft_id: str, asset_type: str, material_name: str) -> str:
|
||||
"""
|
||||
构建资源文件路径
|
||||
:param draft_folder: 草稿文件夹路径
|
||||
:param draft_id: 草稿ID
|
||||
:param asset_type: 资源类型(audio, image, video)
|
||||
:param material_name: 素材名称
|
||||
:return: 构建好的路径
|
||||
Build asset file path
|
||||
:param draft_folder: Draft folder path
|
||||
:param draft_id: Draft ID
|
||||
:param asset_type: Asset type (audio, image, video)
|
||||
:param material_name: Material name
|
||||
:return: Built path
|
||||
"""
|
||||
if is_windows_path(draft_folder):
|
||||
# Windows路径处理
|
||||
if os.name == 'nt': # 'nt' for Windows
|
||||
draft_real_path = os.path.join(draft_folder, draft_id, "assets", asset_type, material_name)
|
||||
else:
|
||||
windows_drive, windows_path = re.match(r'([a-zA-Z]:)(.*)', draft_folder).groups()
|
||||
parts = [p for p in windows_path.split('\\') if p]
|
||||
draft_real_path = os.path.join(windows_drive, *parts, draft_id, "assets", asset_type, material_name)
|
||||
# 规范化路径(确保分隔符一致)
|
||||
draft_real_path = draft_real_path.replace('/', '\\')
|
||||
else:
|
||||
# macOS/Linux路径处理
|
||||
draft_real_path = os.path.join(draft_folder, draft_id, "assets", asset_type, material_name)
|
||||
return draft_real_path
|
||||
|
||||
def save_draft_background(draft_id, draft_folder, task_id):
|
||||
"""后台保存草稿到OSS"""
|
||||
"""Background save draft to OSS"""
|
||||
try:
|
||||
# 从全局缓存中获取草稿信息
|
||||
# Get draft information from global cache
|
||||
if draft_id not in DRAFT_CACHE:
|
||||
task_status = {
|
||||
"status": "failed",
|
||||
"message": f"草稿 {draft_id} 不存在于缓存中",
|
||||
"message": f"Draft {draft_id} does not exist in cache",
|
||||
"progress": 0,
|
||||
"completed_files": 0,
|
||||
"total_files": 0,
|
||||
"draft_url": ""
|
||||
}
|
||||
update_tasks_cache(task_id, task_status) # 使用新的缓存管理函数
|
||||
logger.error(f"草稿 {draft_id} 不存在于缓存中,任务 {task_id} 失败。")
|
||||
update_tasks_cache(task_id, task_status) # Use new cache management function
|
||||
logger.error(f"Draft {draft_id} does not exist in cache, task {task_id} failed.")
|
||||
return
|
||||
|
||||
script = DRAFT_CACHE[draft_id]
|
||||
logger.info(f"成功从缓存获取草稿 {draft_id}。")
|
||||
logger.info(f"Successfully retrieved draft {draft_id} from cache.")
|
||||
|
||||
# 更新任务状态为处理中
|
||||
# Update task status to processing
|
||||
task_status = {
|
||||
"status": "processing",
|
||||
"message": "正在准备草稿文件",
|
||||
"message": "Preparing draft files",
|
||||
"progress": 0,
|
||||
"completed_files": 0,
|
||||
"total_files": 0,
|
||||
"draft_url": ""
|
||||
}
|
||||
update_tasks_cache(task_id, task_status) # 使用新的缓存管理函数
|
||||
logger.info(f"任务 {task_id} 状态更新为 'processing':正在准备草稿文件。")
|
||||
update_tasks_cache(task_id, task_status) # Use new cache management function
|
||||
logger.info(f"Task {task_id} status updated to 'processing': Preparing draft files.")
|
||||
|
||||
# 删除可能已存在的draft_id文件夹
|
||||
# Delete possibly existing draft_id folder
|
||||
if os.path.exists(draft_id):
|
||||
logger.warning(f"删除已存在的草稿文件夹 (当前工作目录): {draft_id}")
|
||||
logger.warning(f"Deleting existing draft folder (current working directory): {draft_id}")
|
||||
shutil.rmtree(draft_id)
|
||||
|
||||
logger.info(f"开始保存草稿: {draft_id}")
|
||||
# 保存草稿
|
||||
logger.info(f"Starting to save draft: {draft_id}")
|
||||
# Save draft
|
||||
draft_folder_for_duplicate = draft.Draft_folder("./")
|
||||
# 根据配置选择不同的模板目录
|
||||
# Choose different template directory based on configuration
|
||||
template_dir = "template" if IS_CAPCUT_ENV else "template_jianying"
|
||||
draft_folder_for_duplicate.duplicate_as_template(template_dir, draft_id)
|
||||
|
||||
# 更新任务状态
|
||||
update_task_field(task_id, "message", "正在更新媒体文件元数据")
|
||||
# Update task status
|
||||
update_task_field(task_id, "message", "Updating media file metadata")
|
||||
update_task_field(task_id, "progress", 5)
|
||||
logger.info(f"任务 {task_id} 进度5%:正在更新媒体文件元数据。")
|
||||
logger.info(f"Task {task_id} progress 5%: Updating media file metadata.")
|
||||
|
||||
# 调用公共方法更新媒体文件元数据
|
||||
update_media_metadata(script, task_id)
|
||||
|
||||
# 收集下载任务
|
||||
download_tasks = []
|
||||
|
||||
# 收集音频下载任务
|
||||
audios = script.materials.audios
|
||||
if audios:
|
||||
for audio in audios:
|
||||
remote_url = audio.remote_url
|
||||
material_name = audio.material_name
|
||||
# 使用辅助函数构建路径
|
||||
if draft_folder:
|
||||
audio.replace_path = build_asset_path(draft_folder, draft_id, "audio", material_name)
|
||||
if not remote_url:
|
||||
logger.warning(f"音频文件 {material_name} 没有 remote_url,跳过下载。")
|
||||
logger.warning(f"Audio file {material_name} has no remote_url, skipping download.")
|
||||
continue
|
||||
|
||||
# 添加音频下载任务
|
||||
# Add audio download task
|
||||
download_tasks.append({
|
||||
'type': 'audio',
|
||||
'func': download_file,
|
||||
@@ -129,7 +120,7 @@ def save_draft_background(draft_id, draft_folder, task_id):
|
||||
'material': audio
|
||||
})
|
||||
|
||||
# 收集视频和图片下载任务
|
||||
# Collect video and image download tasks
|
||||
videos = script.materials.videos
|
||||
if videos:
|
||||
for video in videos:
|
||||
@@ -137,14 +128,14 @@ def save_draft_background(draft_id, draft_folder, task_id):
|
||||
material_name = video.material_name
|
||||
|
||||
if video.material_type == 'photo':
|
||||
# 使用辅助函数构建路径
|
||||
# Use helper function to build path
|
||||
if draft_folder:
|
||||
video.replace_path = build_asset_path(draft_folder, draft_id, "image", material_name)
|
||||
if not remote_url:
|
||||
logger.warning(f"图片文件 {material_name} 没有 remote_url,跳过下载。")
|
||||
logger.warning(f"Image file {material_name} has no remote_url, skipping download.")
|
||||
continue
|
||||
|
||||
# 添加图片下载任务
|
||||
# Add image download task
|
||||
download_tasks.append({
|
||||
'type': 'image',
|
||||
'func': download_file,
|
||||
@@ -153,14 +144,14 @@ def save_draft_background(draft_id, draft_folder, task_id):
|
||||
})
|
||||
|
||||
elif video.material_type == 'video':
|
||||
# 使用辅助函数构建路径
|
||||
# Use helper function to build path
|
||||
if draft_folder:
|
||||
video.replace_path = build_asset_path(draft_folder, draft_id, "video", material_name)
|
||||
if not remote_url:
|
||||
logger.warning(f"视频文件 {material_name} 没有 remote_url,跳过下载。")
|
||||
logger.warning(f"Video file {material_name} has no remote_url, skipping download.")
|
||||
continue
|
||||
|
||||
# 添加视频下载任务
|
||||
# Add video download task
|
||||
download_tasks.append({
|
||||
'type': 'video',
|
||||
'func': download_file,
|
||||
@@ -168,120 +159,120 @@ def save_draft_background(draft_id, draft_folder, task_id):
|
||||
'material': video
|
||||
})
|
||||
|
||||
update_task_field(task_id, "message", f"共收集到{len(download_tasks)}个下载任务")
|
||||
update_task_field(task_id, "message", f"Collected {len(download_tasks)} download tasks in total")
|
||||
update_task_field(task_id, "progress", 10)
|
||||
logger.info(f"任务 {task_id} 进度10%:共收集到 {len(download_tasks)} 个下载任务。")
|
||||
logger.info(f"Task {task_id} progress 10%: Collected {len(download_tasks)} download tasks in total.")
|
||||
|
||||
# 并发执行所有下载任务
|
||||
# Execute all download tasks concurrently
|
||||
downloaded_paths = []
|
||||
completed_files = 0
|
||||
if download_tasks:
|
||||
logger.info(f"开始并发下载 {len(download_tasks)} 个文件...")
|
||||
logger.info(f"Starting concurrent download of {len(download_tasks)} files...")
|
||||
|
||||
# 使用线程池并发下载,最大并发数为16
|
||||
# Use thread pool for concurrent downloads, maximum concurrency of 16
|
||||
with ThreadPoolExecutor(max_workers=16) as executor:
|
||||
# 提交所有下载任务
|
||||
# Submit all download tasks
|
||||
future_to_task = {
|
||||
executor.submit(task['func'], *task['args']): task
|
||||
for task in download_tasks
|
||||
}
|
||||
|
||||
# 等待所有任务完成
|
||||
# Wait for all tasks to complete
|
||||
for future in as_completed(future_to_task):
|
||||
task = future_to_task[future]
|
||||
try:
|
||||
local_path = future.result()
|
||||
downloaded_paths.append(local_path)
|
||||
|
||||
# 更新任务状态 - 只更新已完成文件数
|
||||
# Update task status - only update completed files count
|
||||
completed_files += 1
|
||||
update_task_field(task_id, "completed_files", completed_files)
|
||||
task_status = get_task_status(task_id)
|
||||
completed = task_status["completed_files"]
|
||||
total = len(download_tasks)
|
||||
update_task_field(task_id, "total_files", total)
|
||||
# 下载部分占总进度的60%
|
||||
# Download part accounts for 60% of the total progress
|
||||
download_progress = 10 + int((completed / total) * 60)
|
||||
update_task_field(task_id, "progress", download_progress)
|
||||
update_task_field(task_id, "message", f"已下载 {completed}/{total} 个文件")
|
||||
update_task_field(task_id, "message", f"Downloaded {completed}/{total} files")
|
||||
|
||||
logger.info(f"任务 {task_id}:成功下载 {task['type']} 文件,进度 {download_progress}。")
|
||||
logger.info(f"Task {task_id}: Successfully downloaded {task['type']} file, progress {download_progress}.")
|
||||
except Exception as e:
|
||||
logger.error(f"任务 {task_id}:下载 {task['type']} 文件, 失败: {str(e)}", exc_info=True)
|
||||
# 继续处理其他文件,不中断整个流程
|
||||
logger.error(f"Task {task_id}: Download {task['type']} file failed: {str(e)}", exc_info=True)
|
||||
# Continue processing other files, don't interrupt the entire process
|
||||
|
||||
logger.info(f"任务 {task_id}:并发下载完成,共下载 {len(downloaded_paths)} 个文件。")
|
||||
logger.info(f"Task {task_id}: Concurrent download completed, downloaded {len(downloaded_paths)} files in total.")
|
||||
|
||||
# 更新任务状态 - 开始保存草稿信息
|
||||
# Update task status - Start saving draft information
|
||||
update_task_field(task_id, "progress", 70)
|
||||
update_task_field(task_id, "message", "正在保存草稿信息")
|
||||
logger.info(f"任务 {task_id} 进度70%:正在保存草稿信息。")
|
||||
update_task_field(task_id, "message", "Saving draft information")
|
||||
logger.info(f"Task {task_id} progress 70%: Saving draft information.")
|
||||
|
||||
script.dump(f"{draft_id}/draft_info.json")
|
||||
logger.info(f"草稿信息已保存到 {draft_id}/draft_info.json。")
|
||||
logger.info(f"Draft information has been saved to {draft_id}/draft_info.json.")
|
||||
|
||||
draft_url = ""
|
||||
# 仅在 IS_UPLOAD_DRAFT 为 True 时上传草稿信息
|
||||
# Only upload draft information when IS_UPLOAD_DRAFT is True
|
||||
if IS_UPLOAD_DRAFT:
|
||||
# 更新任务状态 - 开始压缩草稿
|
||||
# Update task status - Start compressing draft
|
||||
update_task_field(task_id, "progress", 80)
|
||||
update_task_field(task_id, "message", "正在压缩草稿文件")
|
||||
logger.info(f"任务 {task_id} 进度80%:正在压缩草稿文件。")
|
||||
update_task_field(task_id, "message", "Compressing draft files")
|
||||
logger.info(f"Task {task_id} progress 80%: Compressing draft files.")
|
||||
|
||||
# 压缩整个草稿目录
|
||||
# Compress the entire draft directory
|
||||
zip_path = zip_draft(draft_id)
|
||||
logger.info(f"草稿目录 {draft_id} 已压缩为 {zip_path}。")
|
||||
logger.info(f"Draft directory {draft_id} has been compressed to {zip_path}.")
|
||||
|
||||
# 更新任务状态 - 开始上传到OSS
|
||||
# Update task status - Start uploading to OSS
|
||||
update_task_field(task_id, "progress", 90)
|
||||
update_task_field(task_id, "message", "正在上传到云存储")
|
||||
logger.info(f"任务 {task_id} 进度90%:正在上传到云存储。")
|
||||
update_task_field(task_id, "message", "Uploading to cloud storage")
|
||||
logger.info(f"Task {task_id} progress 90%: Uploading to cloud storage.")
|
||||
|
||||
# 上传到OSS
|
||||
# Upload to OSS
|
||||
draft_url = upload_to_oss(zip_path)
|
||||
logger.info(f"草稿压缩包已上传到 OSS,URL: {draft_url}")
|
||||
logger.info(f"Draft archive has been uploaded to OSS, URL: {draft_url}")
|
||||
update_task_field(task_id, "draft_url", draft_url)
|
||||
|
||||
# 清理临时文件
|
||||
# Clean up temporary files
|
||||
if os.path.exists(draft_id):
|
||||
shutil.rmtree(draft_id)
|
||||
logger.info(f"已清理临时草稿文件夹: {draft_id}")
|
||||
logger.info(f"Cleaned up temporary draft folder: {draft_id}")
|
||||
|
||||
|
||||
# 更新任务状态 - 完成
|
||||
# Update task status - Completed
|
||||
update_task_field(task_id, "status", "completed")
|
||||
update_task_field(task_id, "progress", 100)
|
||||
update_task_field(task_id, "message", "草稿制作完成")
|
||||
logger.info(f"任务 {task_id} 已完成,草稿URL: {draft_url}")
|
||||
update_task_field(task_id, "message", "Draft creation completed")
|
||||
logger.info(f"Task {task_id} completed, draft URL: {draft_url}")
|
||||
return draft_url
|
||||
|
||||
except Exception as e:
|
||||
# 更新任务状态 - 失败
|
||||
# Update task status - Failed
|
||||
update_task_fields(task_id,
|
||||
status="failed",
|
||||
message=f"保存草稿失败: {str(e)}")
|
||||
logger.error(f"保存草稿 {draft_id} 任务 {task_id} 失败: {str(e)}", exc_info=True)
|
||||
message=f"Failed to save draft: {str(e)}")
|
||||
logger.error(f"Saving draft {draft_id} task {task_id} failed: {str(e)}", exc_info=True)
|
||||
return ""
|
||||
|
||||
def query_task_status(task_id: str):
|
||||
return get_task_status(task_id)
|
||||
|
||||
def save_draft_impl(draft_id: str, draft_folder: str = None) -> Dict[str, str]:
|
||||
"""启动保存草稿的后台任务"""
|
||||
logger.info(f"接收到保存草稿请求:draft_id={draft_id}, draft_folder={draft_folder}")
|
||||
"""Start a background task to save the draft"""
|
||||
logger.info(f"Received save draft request: draft_id={draft_id}, draft_folder={draft_folder}")
|
||||
try:
|
||||
# 生成唯一的任务ID
|
||||
# Generate a unique task ID
|
||||
task_id = draft_id
|
||||
create_task(task_id)
|
||||
logger.info(f"任务 {task_id} 已创建。")
|
||||
logger.info(f"Task {task_id} has been created.")
|
||||
|
||||
# 改为同步执行
|
||||
# Changed to synchronous execution
|
||||
return {
|
||||
"success": True,
|
||||
"draft_url": save_draft_background(draft_id, draft_folder, task_id)
|
||||
}
|
||||
|
||||
# # 启动后台线程执行任务
|
||||
# # Start a background thread to execute the task
|
||||
# thread = threading.Thread(
|
||||
# target=save_draft_background,
|
||||
# args=(draft_id, draft_folder, task_id)
|
||||
@@ -289,7 +280,7 @@ def save_draft_impl(draft_id: str, draft_folder: str = None) -> Dict[str, str]:
|
||||
# thread.start()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"启动保存草稿任务 {draft_id} 失败: {str(e)}", exc_info=True)
|
||||
logger.error(f"Failed to start save draft task {draft_id}: {str(e)}", exc_info=True)
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
@@ -297,22 +288,22 @@ def save_draft_impl(draft_id: str, draft_folder: str = None) -> Dict[str, str]:
|
||||
|
||||
def update_media_metadata(script, task_id=None):
|
||||
"""
|
||||
更新脚本中所有媒体文件的元数据(时长、宽高等)
|
||||
Update metadata for all media files in the script (duration, width/height, etc.)
|
||||
|
||||
:param script: 草稿脚本对象
|
||||
:param task_id: 可选的任务ID,用于更新任务状态
|
||||
:param script: Draft script object
|
||||
:param task_id: Optional task ID for updating task status
|
||||
:return: None
|
||||
"""
|
||||
# 处理音频文件元数据
|
||||
# Process audio file metadata
|
||||
audios = script.materials.audios
|
||||
if not audios:
|
||||
logger.info("草稿中没有找到音频文件。")
|
||||
logger.info("No audio files found in the draft.")
|
||||
else:
|
||||
for audio in audios:
|
||||
remote_url = audio.remote_url
|
||||
material_name = audio.material_name
|
||||
if not remote_url:
|
||||
logger.warning(f"警告:音频文件 {material_name} 没有 remote_url,已跳过。")
|
||||
logger.warning(f"Warning: Audio file {material_name} has no remote_url, skipped.")
|
||||
continue
|
||||
|
||||
try:
|
||||
@@ -326,93 +317,93 @@ def update_media_metadata(script, task_id=None):
|
||||
]
|
||||
video_result = subprocess.check_output(video_command, stderr=subprocess.STDOUT)
|
||||
video_result_str = video_result.decode('utf-8')
|
||||
# 查找JSON开始位置(第一个'{')
|
||||
# Find JSON start position (first '{')
|
||||
video_json_start = video_result_str.find('{')
|
||||
if video_json_start != -1:
|
||||
video_json_str = video_result_str[video_json_start:]
|
||||
video_info = json.loads(video_json_str)
|
||||
if 'streams' in video_info and len(video_info['streams']) > 0:
|
||||
logger.warning(f"警告:音频文件 {material_name} 包含视频轨道,已跳过其元数据更新。")
|
||||
logger.warning(f"Warning: Audio file {material_name} contains video tracks, skipped its metadata update.")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"检查音频 {material_name} 是否包含视频流时发生错误: {str(e)}", exc_info=True)
|
||||
logger.error(f"Error occurred while checking if audio {material_name} contains video streams: {str(e)}", exc_info=True)
|
||||
|
||||
# 获取音频时长并设置
|
||||
# Get audio duration and set it
|
||||
try:
|
||||
duration_result = get_video_duration(remote_url)
|
||||
if duration_result["success"]:
|
||||
if task_id:
|
||||
update_task_field(task_id, "message", f"正在处理音频元数据: {material_name}")
|
||||
# 将秒转换为微秒
|
||||
update_task_field(task_id, "message", f"Processing audio metadata: {material_name}")
|
||||
# Convert seconds to microseconds
|
||||
audio.duration = int(duration_result["output"] * 1000000)
|
||||
logger.info(f"成功获取音频 {material_name} 时长: {duration_result['output']:.2f} 秒 ({audio.duration} 微秒)。")
|
||||
logger.info(f"Successfully obtained audio {material_name} duration: {duration_result['output']:.2f} seconds ({audio.duration} microseconds).")
|
||||
|
||||
# 更新使用该音频素材的所有片段的timerange
|
||||
# Update timerange for all segments using this audio material
|
||||
for track_name, track in script.tracks.items():
|
||||
if track.track_type == draft.Track_type.audio:
|
||||
for segment in track.segments:
|
||||
if isinstance(segment, draft.Audio_segment) and segment.material_id == audio.material_id:
|
||||
# 获取当前设置
|
||||
# Get current settings
|
||||
current_target = segment.target_timerange
|
||||
current_source = segment.source_timerange
|
||||
speed = segment.speed.speed
|
||||
|
||||
# 如果source_timerange的结束时间超过了新的音频时长,则调整它
|
||||
# If the end time of source_timerange exceeds the new audio duration, adjust it
|
||||
if current_source.end > audio.duration or current_source.end <= 0:
|
||||
# 调整source_timerange以适应新的音频时长
|
||||
# Adjust source_timerange to fit the new audio duration
|
||||
new_source_duration = audio.duration - current_source.start
|
||||
if new_source_duration <= 0:
|
||||
logger.warning(f"警告:音频片段 {segment.segment_id} 的起始时间 {current_source.start} 超出了音频时长 {audio.duration},将跳过此片段。")
|
||||
logger.warning(f"Warning: Audio segment {segment.segment_id} start time {current_source.start} exceeds audio duration {audio.duration}, will skip this segment.")
|
||||
continue
|
||||
|
||||
# 更新source_timerange
|
||||
# Update source_timerange
|
||||
segment.source_timerange = draft.Timerange(current_source.start, new_source_duration)
|
||||
|
||||
# 根据新的source_timerange和speed更新target_timerange
|
||||
# Update target_timerange based on new source_timerange and speed
|
||||
new_target_duration = int(new_source_duration / speed)
|
||||
segment.target_timerange = draft.Timerange(current_target.start, new_target_duration)
|
||||
|
||||
logger.info(f"已调整音频片段 {segment.segment_id} 的timerange以适应新的音频时长。")
|
||||
logger.info(f"Adjusted audio segment {segment.segment_id} timerange to fit the new audio duration.")
|
||||
else:
|
||||
logger.warning(f"警告:无法获取音频 {material_name} 的时长: {duration_result['error']}。")
|
||||
logger.warning(f"Warning: Unable to get audio {material_name} duration: {duration_result['error']}.")
|
||||
except Exception as e:
|
||||
logger.error(f"获取音频 {material_name} 时长时发生错误: {str(e)}", exc_info=True)
|
||||
logger.error(f"Error occurred while getting audio {material_name} duration: {str(e)}", exc_info=True)
|
||||
|
||||
# 处理视频和图片文件元数据
|
||||
# Process video and image file metadata
|
||||
videos = script.materials.videos
|
||||
if not videos:
|
||||
logger.info("草稿中没有找到视频或图片文件。")
|
||||
logger.info("No video or image files found in the draft.")
|
||||
else:
|
||||
for video in videos:
|
||||
remote_url = video.remote_url
|
||||
material_name = video.material_name
|
||||
if not remote_url:
|
||||
logger.warning(f"警告:媒体文件 {material_name} 没有 remote_url,已跳过。")
|
||||
logger.warning(f"Warning: Media file {material_name} has no remote_url, skipped.")
|
||||
continue
|
||||
|
||||
if video.material_type == 'photo':
|
||||
# 使用imageio获取图片宽高并设置
|
||||
# Use imageio to get image width/height and set it
|
||||
try:
|
||||
if task_id:
|
||||
update_task_field(task_id, "message", f"正在处理图片元数据: {material_name}")
|
||||
update_task_field(task_id, "message", f"Processing image metadata: {material_name}")
|
||||
img = imageio.imread(remote_url)
|
||||
video.height, video.width = img.shape[:2]
|
||||
logger.info(f"成功设置图片 {material_name} 宽高: {video.width}x{video.height}。")
|
||||
logger.info(f"Successfully set image {material_name} dimensions: {video.width}x{video.height}.")
|
||||
except Exception as e:
|
||||
logger.error(f"设置图片 {material_name} 宽高失败: {str(e)},使用默认值 1920x1080。", exc_info=True)
|
||||
logger.error(f"Failed to set image {material_name} dimensions: {str(e)}, using default values 1920x1080.", exc_info=True)
|
||||
video.width = 1920
|
||||
video.height = 1080
|
||||
|
||||
elif video.material_type == 'video':
|
||||
# 获取视频时长和宽高信息
|
||||
# Get video duration and width/height information
|
||||
try:
|
||||
if task_id:
|
||||
update_task_field(task_id, "message", f"正在处理视频元数据: {material_name}")
|
||||
# 使用ffprobe获取视频信息
|
||||
update_task_field(task_id, "message", f"Processing video metadata: {material_name}")
|
||||
# Use ffprobe to get video information
|
||||
command = [
|
||||
'ffprobe',
|
||||
'-v', 'error',
|
||||
'-select_streams', 'v:0', # 选择第一个视频流
|
||||
'-select_streams', 'v:0', # Select the first video stream
|
||||
'-show_entries', 'stream=width,height,duration',
|
||||
'-show_entries', 'format=duration',
|
||||
'-of', 'json',
|
||||
@@ -420,7 +411,7 @@ def update_media_metadata(script, task_id=None):
|
||||
]
|
||||
result = subprocess.check_output(command, stderr=subprocess.STDOUT)
|
||||
result_str = result.decode('utf-8')
|
||||
# 查找JSON开始位置(第一个'{')
|
||||
# Find JSON start position (first '{')
|
||||
json_start = result_str.find('{')
|
||||
if json_start != -1:
|
||||
json_str = result_str[json_start:]
|
||||
@@ -428,142 +419,141 @@ def update_media_metadata(script, task_id=None):
|
||||
|
||||
if 'streams' in info and len(info['streams']) > 0:
|
||||
stream = info['streams'][0]
|
||||
# 设置宽高
|
||||
# Set width and height
|
||||
video.width = int(stream.get('width', 0))
|
||||
video.height = int(stream.get('height', 0))
|
||||
logger.info(f"成功设置视频 {material_name} 宽高: {video.width}x{video.height}。")
|
||||
logger.info(f"Successfully set video {material_name} dimensions: {video.width}x{video.height}.")
|
||||
|
||||
# 设置时长
|
||||
# 优先使用流的duration,如果没有则使用格式的duration
|
||||
# Set duration
|
||||
# Prefer stream duration, if not available use format duration
|
||||
duration = stream.get('duration') or info['format'].get('duration', '0')
|
||||
video.duration = int(float(duration) * 1000000) # 转换为微秒
|
||||
logger.info(f"成功获取视频 {material_name} 时长: {float(duration):.2f} 秒 ({video.duration} 微秒)。")
|
||||
video.duration = int(float(duration) * 1000000) # Convert to microseconds
|
||||
logger.info(f"Successfully obtained video {material_name} duration: {float(duration):.2f} seconds ({video.duration} microseconds).")
|
||||
|
||||
# 更新使用该视频素材的所有片段的timerange
|
||||
# Update timerange for all segments using this video material
|
||||
for track_name, track in script.tracks.items():
|
||||
if track.track_type == draft.Track_type.video:
|
||||
for segment in track.segments:
|
||||
if isinstance(segment, draft.Video_segment) and segment.material_id == video.material_id:
|
||||
# 获取当前设置
|
||||
# Get current settings
|
||||
current_target = segment.target_timerange
|
||||
current_source = segment.source_timerange
|
||||
speed = segment.speed.speed
|
||||
|
||||
# 如果source_timerange的结束时间超过了新的音频时长,则调整它
|
||||
# If the end time of source_timerange exceeds the new video duration, adjust it
|
||||
if current_source.end > video.duration or current_source.end <= 0:
|
||||
# 调整source_timerange以适应新的视频时长
|
||||
# Adjust source_timerange to fit the new video duration
|
||||
new_source_duration = video.duration - current_source.start
|
||||
if new_source_duration <= 0:
|
||||
logger.warning(f"警告:视频片段 {segment.segment_id} 的起始时间 {current_source.start} 超出了视频时长 {video.duration},将跳过此片段。")
|
||||
logger.warning(f"Warning: Video segment {segment.segment_id} start time {current_source.start} exceeds video duration {video.duration}, will skip this segment.")
|
||||
continue
|
||||
|
||||
# 更新source_timerange
|
||||
# Update source_timerange
|
||||
segment.source_timerange = draft.Timerange(current_source.start, new_source_duration)
|
||||
|
||||
# 根据新的source_timerange和speed更新target_timerange
|
||||
# Update target_timerange based on new source_timerange and speed
|
||||
new_target_duration = int(new_source_duration / speed)
|
||||
segment.target_timerange = draft.Timerange(current_target.start, new_target_duration)
|
||||
|
||||
logger.info(f"已调整视频片段 {segment.segment_id} 的timerange以适应新的视频时长。")
|
||||
logger.info(f"Adjusted video segment {segment.segment_id} timerange to fit the new video duration.")
|
||||
else:
|
||||
logger.warning(f"警告:无法获取视频 {material_name} 的流信息。")
|
||||
# 设置默认值
|
||||
logger.warning(f"Warning: Unable to get video {material_name} stream information.")
|
||||
# Set default values
|
||||
video.width = 1920
|
||||
video.height = 1080
|
||||
else:
|
||||
logger.warning(f"警告:无法在ffprobe输出中找到JSON数据。")
|
||||
# 设置默认值
|
||||
logger.warning(f"Warning: Could not find JSON data in ffprobe output.")
|
||||
# Set default values
|
||||
video.width = 1920
|
||||
video.height = 1080
|
||||
except Exception as e:
|
||||
logger.error(f"获取视频 {material_name} 信息时发生错误: {str(e)},使用默认值 1920x1080。", exc_info=True)
|
||||
# 设置默认值
|
||||
logger.error(f"Error occurred while getting video {material_name} information: {str(e)}, using default values 1920x1080.", exc_info=True)
|
||||
# Set default values
|
||||
video.width = 1920
|
||||
video.height = 1080
|
||||
|
||||
# 尝试单独获取时长
|
||||
# Try to get duration separately
|
||||
try:
|
||||
duration_result = get_video_duration(remote_url)
|
||||
if duration_result["success"]:
|
||||
# 将秒转换为微秒
|
||||
# Convert seconds to microseconds
|
||||
video.duration = int(duration_result["output"] * 1000000)
|
||||
logger.info(f"成功获取视频 {material_name} 时长: {duration_result['output']:.2f} 秒 ({video.duration} 微秒)。")
|
||||
logger.info(f"Successfully obtained video {material_name} duration: {duration_result['output']:.2f} seconds ({video.duration} microseconds).")
|
||||
else:
|
||||
logger.warning(f"警告:无法获取视频 {material_name} 的时长: {duration_result['error']}。")
|
||||
logger.warning(f"Warning: Unable to get video {material_name} duration: {duration_result['error']}.")
|
||||
except Exception as e2:
|
||||
logger.error(f"获取视频 {material_name} 时长时发生错误: {str(e2)}。", exc_info=True)
|
||||
logger.error(f"Error occurred while getting video {material_name} duration: {str(e2)}.", exc_info=True)
|
||||
|
||||
# 在更新完所有片段的timerange后,检查每个轨道中的片段是否有时间范围冲突,并删除冲突的后一个片段
|
||||
logger.info("检查轨道片段时间范围冲突...")
|
||||
# After updating all segments' timerange, check if there are time range conflicts in each track, and delete the later segment in case of conflict
|
||||
logger.info("Checking track segment time range conflicts...")
|
||||
for track_name, track in script.tracks.items():
|
||||
# 使用集合记录需要删除的片段索引
|
||||
# Use a set to record segment indices that need to be deleted
|
||||
to_remove = set()
|
||||
|
||||
# 检查所有片段之间的冲突
|
||||
# Check for conflicts between all segments
|
||||
for i in range(len(track.segments)):
|
||||
# 如果当前片段已经被标记为删除,则跳过
|
||||
# Skip if current segment is already marked for deletion
|
||||
if i in to_remove:
|
||||
continue
|
||||
|
||||
for j in range(len(track.segments)):
|
||||
# 跳过自身比较和已标记为删除的片段
|
||||
# Skip self-comparison and segments already marked for deletion
|
||||
if i == j or j in to_remove:
|
||||
continue
|
||||
|
||||
# 检查是否有冲突
|
||||
# Check if there is a conflict
|
||||
if track.segments[i].overlaps(track.segments[j]):
|
||||
# 总是保留索引较小的片段(先添加的片段)
|
||||
# Always keep the segment with the smaller index (added first)
|
||||
later_index = max(i, j)
|
||||
logger.warning(f"轨道 {track_name} 中的片段 {track.segments[min(i, j)].segment_id} 和 {track.segments[later_index].segment_id} 时间范围冲突,删除后一个片段")
|
||||
logger.warning(f"Time range conflict between segments {track.segments[min(i, j)].segment_id} and {track.segments[later_index].segment_id} in track {track_name}, deleting the later segment")
|
||||
to_remove.add(later_index)
|
||||
|
||||
# 从后向前删除标记的片段,避免索引变化问题
|
||||
# Delete marked segments from back to front to avoid index change issues
|
||||
for index in sorted(to_remove, reverse=True):
|
||||
track.segments.pop(index)
|
||||
|
||||
# 在更新完所有片段的timerange后,重新计算脚本的总时长
|
||||
# After updating all segments' timerange, recalculate the total duration of the script
|
||||
max_duration = 0
|
||||
for track_name, track in script.tracks.items():
|
||||
for segment in track.segments:
|
||||
max_duration = max(max_duration, segment.end)
|
||||
script.duration = max_duration
|
||||
logger.info(f"更新脚本总时长为: {script.duration} 微秒。")
|
||||
logger.info(f"Updated script total duration to: {script.duration} microseconds.")
|
||||
|
||||
# 处理所有轨道中待添加的关键帧
|
||||
logger.info("处理待添加的关键帧...")
|
||||
# Process all pending keyframes in tracks
|
||||
logger.info("Processing pending keyframes...")
|
||||
for track_name, track in script.tracks.items():
|
||||
if hasattr(track, 'pending_keyframes') and track.pending_keyframes:
|
||||
logger.info(f"处理轨道 {track_name} 中的 {len(track.pending_keyframes)} 个待添加关键帧...")
|
||||
logger.info(f"Processing {len(track.pending_keyframes)} pending keyframes in track {track_name}...")
|
||||
track.process_pending_keyframes()
|
||||
logger.info(f"轨道 {track_name} 中的待添加关键帧已处理完成。")
|
||||
logger.info(f"Pending keyframes in track {track_name} have been processed.")
|
||||
|
||||
def query_script_impl(draft_id: str, force_update: bool = True):
|
||||
"""
|
||||
查询草稿脚本对象,可选择是否强制刷新媒体元数据
|
||||
Query draft script object, with option to force refresh media metadata
|
||||
|
||||
:param draft_id: 草稿ID
|
||||
:param force_update: 是否强制刷新媒体元数据,默认为True
|
||||
:return: 脚本对象
|
||||
:param draft_id: Draft ID
|
||||
:param force_update: Whether to force refresh media metadata, default is True
|
||||
:return: Script object
|
||||
"""
|
||||
# 从全局缓存中获取草稿信息
|
||||
# Get draft information from global cache
|
||||
if draft_id not in DRAFT_CACHE:
|
||||
logger.warning(f"草稿 {draft_id} 不存在于缓存中。")
|
||||
logger.warning(f"Draft {draft_id} does not exist in cache.")
|
||||
return None
|
||||
|
||||
script = DRAFT_CACHE[draft_id]
|
||||
logger.info(f"从缓存中获取草稿 {draft_id}。")
|
||||
logger.info(f"Retrieved draft {draft_id} from cache.")
|
||||
|
||||
# 如果force_update为True,则强制刷新媒体元数据
|
||||
# If force_update is True, force refresh media metadata
|
||||
if force_update:
|
||||
logger.info(f"强制刷新草稿 {draft_id} 的媒体元数据。")
|
||||
logger.info(f"Force refreshing media metadata for draft {draft_id}.")
|
||||
update_media_metadata(script)
|
||||
|
||||
# 返回脚本对象
|
||||
# Return script object
|
||||
return script
|
||||
|
||||
def download_script(draft_id: str, draft_folder: str = None, script_data: Dict = None) -> Dict[str, str]:
|
||||
"""
|
||||
Downloads the draft script and its associated media assets.
|
||||
"""Downloads the draft script and its associated media assets.
|
||||
|
||||
This function fetches the script object from a remote API,
|
||||
then iterates through its materials (audios, videos, images)
|
||||
@@ -579,15 +569,15 @@ def download_script(draft_id: str, draft_folder: str = None, script_data: Dict =
|
||||
If failed, it returns an error message.
|
||||
"""
|
||||
|
||||
logger.info(f"开始下载草稿: {draft_id} 到文件夹: {draft_folder}")
|
||||
# 把模版复制到目标目录下
|
||||
logger.info(f"Starting to download draft: {draft_id} to folder: {draft_folder}")
|
||||
# Copy template to target directory
|
||||
template_path = os.path.join("./", 'template') if IS_CAPCUT_ENV else os.path.join("./", 'template_jianying')
|
||||
new_draft_path = os.path.join(draft_folder, draft_id)
|
||||
if os.path.exists(new_draft_path):
|
||||
logger.warning(f"删除已存在的草稿目标文件夹: {new_draft_path}")
|
||||
logger.warning(f"Deleting existing draft target folder: {new_draft_path}")
|
||||
shutil.rmtree(new_draft_path)
|
||||
|
||||
# 复制草稿文件夹
|
||||
# Copy draft folder
|
||||
shutil.copytree(template_path, new_draft_path)
|
||||
|
||||
try:
|
||||
@@ -597,33 +587,33 @@ def download_script(draft_id: str, draft_folder: str = None, script_data: Dict =
|
||||
headers = {"Content-Type": "application/json"}
|
||||
payload = {"draft_id": draft_id}
|
||||
|
||||
logger.info(f"尝试从 {query_url} 获取草稿 ID: {draft_id} 的脚本。")
|
||||
logger.info(f"Attempting to get script for draft ID: {draft_id} from {query_url}.")
|
||||
response = requests.post(query_url, headers=headers, json=payload)
|
||||
response.raise_for_status() # Raise an exception for HTTP errors (4xx or 5xx)
|
||||
|
||||
script_data = json.loads(response.json().get('output'))
|
||||
logger.info(f"成功获取草稿 {draft_id} 的脚本数据。")
|
||||
logger.info(f"Successfully retrieved script data for draft {draft_id}.")
|
||||
else:
|
||||
logger.info(f"使用传入的 script_data,跳过远程获取。")
|
||||
logger.info(f"Using provided script_data, skipping remote retrieval.")
|
||||
|
||||
# 收集下载任务
|
||||
# Collect download tasks
|
||||
download_tasks = []
|
||||
|
||||
# 收集音频下载任务
|
||||
# Collect audio download tasks
|
||||
audios = script_data.get('materials',{}).get('audios',[])
|
||||
if audios:
|
||||
for audio in audios:
|
||||
remote_url = audio['remote_url']
|
||||
material_name = audio['name']
|
||||
# 使用辅助函数构建路径
|
||||
# Use helper function to build path
|
||||
if draft_folder:
|
||||
audio['path']=build_asset_path(draft_folder, draft_id, "audio", material_name)
|
||||
logger.debug(f"音频 {material_name} 的本地路径: {audio['path']}")
|
||||
logger.debug(f"Local path for audio {material_name}: {audio['path']}")
|
||||
if not remote_url:
|
||||
logger.warning(f"音频文件 {material_name} 没有 remote_url,跳过下载。")
|
||||
logger.warning(f"Audio file {material_name} has no remote_url, skipping download.")
|
||||
continue
|
||||
|
||||
# 添加音频下载任务
|
||||
# Add audio download task
|
||||
download_tasks.append({
|
||||
'type': 'audio',
|
||||
'func': download_file,
|
||||
@@ -631,7 +621,7 @@ def download_script(draft_id: str, draft_folder: str = None, script_data: Dict =
|
||||
'material': audio
|
||||
})
|
||||
|
||||
# 收集视频和图片下载任务
|
||||
# Collect video and image download tasks
|
||||
videos = script_data['materials']['videos']
|
||||
if videos:
|
||||
for video in videos:
|
||||
@@ -639,14 +629,14 @@ def download_script(draft_id: str, draft_folder: str = None, script_data: Dict =
|
||||
material_name = video['material_name']
|
||||
|
||||
if video['type'] == 'photo':
|
||||
# 使用辅助函数构建路径
|
||||
# Use helper function to build path
|
||||
if draft_folder:
|
||||
video['path'] = build_asset_path(draft_folder, draft_id, "image", material_name)
|
||||
if not remote_url:
|
||||
logger.warning(f"图片文件 {material_name} 没有 remote_url,跳过下载。")
|
||||
logger.warning(f"Image file {material_name} has no remote_url, skipping download.")
|
||||
continue
|
||||
|
||||
# 添加图片下载任务
|
||||
# Add image download task
|
||||
download_tasks.append({
|
||||
'type': 'image',
|
||||
'func': download_file,
|
||||
@@ -655,14 +645,14 @@ def download_script(draft_id: str, draft_folder: str = None, script_data: Dict =
|
||||
})
|
||||
|
||||
elif video['type'] == 'video':
|
||||
# 使用辅助函数构建路径
|
||||
# Use helper function to build path
|
||||
if draft_folder:
|
||||
video['path'] = build_asset_path(draft_folder, draft_id, "video", material_name)
|
||||
if not remote_url:
|
||||
logger.warning(f"视频文件 {material_name} 没有 remote_url,跳过下载。")
|
||||
logger.warning(f"Video file {material_name} has no remote_url, skipping download.")
|
||||
continue
|
||||
|
||||
# 添加视频下载任务
|
||||
# Add video download task
|
||||
download_tasks.append({
|
||||
'type': 'video',
|
||||
'func': download_file,
|
||||
@@ -670,50 +660,50 @@ def download_script(draft_id: str, draft_folder: str = None, script_data: Dict =
|
||||
'material': video
|
||||
})
|
||||
|
||||
# 并发执行所有下载任务
|
||||
# Execute all download tasks concurrently
|
||||
downloaded_paths = []
|
||||
completed_files = 0
|
||||
if download_tasks:
|
||||
logger.info(f"开始并发下载 {len(download_tasks)} 个文件...")
|
||||
logger.info(f"Starting concurrent download of {len(download_tasks)} files...")
|
||||
|
||||
# 使用线程池并发下载,最大并发数为16
|
||||
# Use thread pool for concurrent downloads, maximum concurrency of 16
|
||||
with ThreadPoolExecutor(max_workers=16) as executor:
|
||||
# 提交所有下载任务
|
||||
# Submit all download tasks
|
||||
future_to_task = {
|
||||
executor.submit(task['func'], *task['args']): task
|
||||
for task in download_tasks
|
||||
}
|
||||
|
||||
# 等待所有任务完成
|
||||
# Wait for all tasks to complete
|
||||
for future in as_completed(future_to_task):
|
||||
task = future_to_task[future]
|
||||
try:
|
||||
local_path = future.result()
|
||||
downloaded_paths.append(local_path)
|
||||
|
||||
# 更新任务状态 - 只更新已完成文件数
|
||||
# Update task status - only update completed files count
|
||||
completed_files += 1
|
||||
logger.info(f"已下载 {completed_files}/{len(download_tasks)} 个文件。")
|
||||
logger.info(f"Downloaded {completed_files}/{len(download_tasks)} files.")
|
||||
except Exception as e:
|
||||
logger.error(f"下载 {task['type']} 文件 {task['args'][0]} 失败: {str(e)}", exc_info=True)
|
||||
logger.error("下载失败。")
|
||||
# 继续处理其他文件,不中断整个流程
|
||||
logger.error(f"Failed to download {task['type']} file {task['args'][0]}: {str(e)}", exc_info=True)
|
||||
logger.error("Download failed.")
|
||||
# Continue processing other files, don't interrupt the entire process
|
||||
|
||||
logger.info(f"并发下载完成,共下载 {len(downloaded_paths)} 个文件。")
|
||||
logger.info(f"Concurrent download completed, downloaded {len(downloaded_paths)} files in total.")
|
||||
|
||||
"""将草稿文件内容写入文件"""
|
||||
"""Write draft file content to file"""
|
||||
with open(f"{draft_folder}/{draft_id}/draft_info.json", "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps(script_data))
|
||||
logger.info(f"草稿已保存。")
|
||||
logger.info(f"Draft has been saved.")
|
||||
|
||||
# No draft_url for download, but return success
|
||||
return {"success": True, "message": f"Draft {draft_id} and its assets downloaded successfully"}
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(f"API 请求失败: {e}", exc_info=True)
|
||||
logger.error(f"API request failed: {e}", exc_info=True)
|
||||
return {"success": False, "error": f"Failed to fetch script from API: {str(e)}"}
|
||||
except Exception as e:
|
||||
logger.error(f"下载过程中发生意外错误: {e}", exc_info=True)
|
||||
logger.error(f"Unexpected error during download: {e}", exc_info=True)
|
||||
return {"success": False, "error": f"An unexpected error occurred: {str(e)}"}
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -2,77 +2,77 @@ from collections import OrderedDict
|
||||
import threading
|
||||
from typing import Dict, Any
|
||||
|
||||
# 使用OrderedDict实现LRU缓存,限制最大数量为1000
|
||||
DRAFT_TASKS: Dict[str, dict] = OrderedDict() # 使用 Dict 进行类型提示
|
||||
# Using OrderedDict to implement LRU cache, limiting the maximum number to 1000
|
||||
DRAFT_TASKS: Dict[str, dict] = OrderedDict() # Using Dict for type hinting
|
||||
MAX_TASKS_CACHE_SIZE = 1000
|
||||
|
||||
|
||||
def update_tasks_cache(task_id: str, task_status: dict) -> None:
|
||||
"""更新任务状态LRU缓存
|
||||
"""Update task status LRU cache
|
||||
|
||||
:param task_id: 任务ID
|
||||
:param task_status: 任务状态信息字典
|
||||
:param task_id: Task ID
|
||||
:param task_status: Task status information dictionary
|
||||
"""
|
||||
|
||||
if task_id in DRAFT_TASKS:
|
||||
# 如果键存在,删除旧的项
|
||||
# If the key exists, delete the old item
|
||||
DRAFT_TASKS.pop(task_id)
|
||||
elif len(DRAFT_TASKS) >= MAX_TASKS_CACHE_SIZE:
|
||||
# 如果缓存已满,删除最久未使用的项(第一个项)
|
||||
# If the cache is full, delete the least recently used item (the first item)
|
||||
DRAFT_TASKS.popitem(last=False)
|
||||
# 添加新项到末尾(最近使用)
|
||||
# Add new item to the end (most recently used)
|
||||
DRAFT_TASKS[task_id] = task_status
|
||||
|
||||
def update_task_field(task_id: str, field: str, value: Any) -> None:
|
||||
"""更新任务状态中的单个字段
|
||||
"""Update a single field in the task status
|
||||
|
||||
:param task_id: 任务ID
|
||||
:param field: 要更新的字段名
|
||||
:param value: 字段的新值
|
||||
:param task_id: Task ID
|
||||
:param field: Field name to update
|
||||
:param value: New value for the field
|
||||
"""
|
||||
if task_id in DRAFT_TASKS:
|
||||
# 复制当前状态,修改指定字段,然后更新缓存
|
||||
# Copy the current status, modify the specified field, then update the cache
|
||||
task_status = DRAFT_TASKS[task_id].copy()
|
||||
task_status[field] = value
|
||||
# 删除旧项并添加更新后的项
|
||||
# Delete the old item and add the updated item
|
||||
DRAFT_TASKS.pop(task_id)
|
||||
DRAFT_TASKS[task_id] = task_status
|
||||
else:
|
||||
# 如果任务不存在,创建一个默认状态并设置指定字段
|
||||
# If the task doesn't exist, create a default status and set the specified field
|
||||
task_status = {
|
||||
"status": "initialized",
|
||||
"message": "任务已初始化",
|
||||
"message": "Task initialized",
|
||||
"progress": 0,
|
||||
"completed_files": 0,
|
||||
"total_files": 0,
|
||||
"draft_url": ""
|
||||
}
|
||||
task_status[field] = value
|
||||
# 如果缓存已满,删除最久未使用的项
|
||||
# If the cache is full, delete the least recently used item
|
||||
if len(DRAFT_TASKS) >= MAX_TASKS_CACHE_SIZE:
|
||||
DRAFT_TASKS.popitem(last=False)
|
||||
# 添加新项
|
||||
# Add new item
|
||||
DRAFT_TASKS[task_id] = task_status
|
||||
|
||||
def update_task_fields(task_id: str, **fields) -> None:
|
||||
"""更新任务状态中的多个字段
|
||||
"""Update multiple fields in the task status
|
||||
|
||||
:param task_id: 任务ID
|
||||
:param fields: 要更新的字段及其值,以关键字参数形式提供
|
||||
:param task_id: Task ID
|
||||
:param fields: Fields to update and their values, provided as keyword arguments
|
||||
"""
|
||||
if task_id in DRAFT_TASKS:
|
||||
# 复制当前状态,修改指定字段,然后更新缓存
|
||||
# Copy the current status, modify the specified fields, then update the cache
|
||||
task_status = DRAFT_TASKS[task_id].copy()
|
||||
for field, value in fields.items():
|
||||
task_status[field] = value
|
||||
# 删除旧项并添加更新后的项
|
||||
# Delete the old item and add the updated item
|
||||
DRAFT_TASKS.pop(task_id)
|
||||
DRAFT_TASKS[task_id] = task_status
|
||||
else:
|
||||
# 如果任务不存在,创建一个默认状态并设置指定字段
|
||||
# If the task doesn't exist, create a default status and set the specified fields
|
||||
task_status = {
|
||||
"status": "initialized",
|
||||
"message": "任务已初始化",
|
||||
"message": "Task initialized",
|
||||
"progress": 0,
|
||||
"completed_files": 0,
|
||||
"total_files": 0,
|
||||
@@ -80,60 +80,60 @@ def update_task_fields(task_id: str, **fields) -> None:
|
||||
}
|
||||
for field, value in fields.items():
|
||||
task_status[field] = value
|
||||
# 如果缓存已满,删除最久未使用的项
|
||||
# If the cache is full, delete the least recently used item
|
||||
if len(DRAFT_TASKS) >= MAX_TASKS_CACHE_SIZE:
|
||||
DRAFT_TASKS.popitem(last=False)
|
||||
# 添加新项
|
||||
# Add new item
|
||||
DRAFT_TASKS[task_id] = task_status
|
||||
|
||||
def increment_task_field(task_id: str, field: str, increment: int = 1) -> None:
|
||||
"""增加任务状态中的数值字段
|
||||
"""Increment a numeric field in the task status
|
||||
|
||||
:param task_id: 任务ID
|
||||
:param field: 要增加的字段名
|
||||
:param increment: 增加的值,默认为1
|
||||
:param task_id: Task ID
|
||||
:param field: Field name to increment
|
||||
:param increment: Value to increment by, default is 1
|
||||
"""
|
||||
if task_id in DRAFT_TASKS:
|
||||
# 复制当前状态,增加指定字段,然后更新缓存
|
||||
# Copy the current status, increment the specified field, then update the cache
|
||||
task_status = DRAFT_TASKS[task_id].copy()
|
||||
if field in task_status and isinstance(task_status[field], (int, float)):
|
||||
task_status[field] += increment
|
||||
else:
|
||||
task_status[field] = increment
|
||||
# 删除旧项并添加更新后的项
|
||||
# Delete the old item and add the updated item
|
||||
DRAFT_TASKS.pop(task_id)
|
||||
DRAFT_TASKS[task_id] = task_status
|
||||
|
||||
def get_task_status(task_id: str) -> dict:
|
||||
"""获取任务状态
|
||||
"""Get task status
|
||||
|
||||
:param task_id: 任务ID
|
||||
:return: 任务状态信息字典
|
||||
:param task_id: Task ID
|
||||
:return: Task status information dictionary
|
||||
"""
|
||||
task_status = DRAFT_TASKS.get(task_id, {
|
||||
"status": "not_found",
|
||||
"message": "任务不存在",
|
||||
"message": "Task does not exist",
|
||||
"progress": 0,
|
||||
"completed_files": 0,
|
||||
"total_files": 0,
|
||||
"draft_url": ""
|
||||
})
|
||||
|
||||
# 如果找到了任务,更新其在LRU缓存中的位置
|
||||
# If the task is found, update its position in the LRU cache
|
||||
if task_id in DRAFT_TASKS:
|
||||
# 先删除,再添加到末尾,实现LRU更新
|
||||
# First delete, then add to the end, implementing LRU update
|
||||
update_tasks_cache(task_id, task_status)
|
||||
|
||||
return task_status
|
||||
|
||||
def create_task(task_id: str) -> None:
|
||||
"""创建新任务并初始化状态
|
||||
"""Create a new task and initialize its status
|
||||
|
||||
:param task_id: 任务ID
|
||||
:param task_id: Task ID
|
||||
"""
|
||||
task_status = {
|
||||
"status": "initialized",
|
||||
"message": "任务已初始化",
|
||||
"message": "Task initialized",
|
||||
"progress": 0,
|
||||
"completed_files": 0,
|
||||
"total_files": 0,
|
||||
|
||||
Reference in New Issue
Block a user