Merge pull request #458 from yyhhyyyyyy/refactor-task-add-subtitle-api

Refactor task.py and add subtitle API
This commit is contained in:
Harry
2024-07-24 11:47:27 +08:00
committed by GitHub
3 changed files with 383 additions and 194 deletions

View File

@@ -1,11 +1,12 @@
import os
import glob
import os
import pathlib
import shutil
from typing import Union
from fastapi import Request, Depends, Path, BackgroundTasks, UploadFile
from fastapi.responses import FileResponse, StreamingResponse
from fastapi import BackgroundTasks, Depends, Path, Request, UploadFile
from fastapi.params import File
from fastapi.responses import FileResponse, StreamingResponse
from loguru import logger
from app.config import config
@@ -14,10 +15,19 @@ from app.controllers.manager.memory_manager import InMemoryTaskManager
from app.controllers.manager.redis_manager import RedisTaskManager
from app.controllers.v1.base import new_router
from app.models.exception import HttpException
from app.models.schema import TaskVideoRequest, TaskQueryResponse, TaskResponse, TaskQueryRequest, \
BgmUploadResponse, BgmRetrieveResponse, TaskDeletionResponse
from app.services import task as tm
from app.models.schema import (
AudioRequest,
BgmRetrieveResponse,
BgmUploadResponse,
SubtitleRequest,
TaskDeletionResponse,
TaskQueryRequest,
TaskQueryResponse,
TaskResponse,
TaskVideoRequest,
)
from app.services import state as sm
from app.services import task as tm
from app.utils import utils
# 认证依赖项
@@ -34,48 +44,65 @@ _max_concurrent_tasks = config.app.get("max_concurrent_tasks", 5)
redis_url = f"redis://:{_redis_password}@{_redis_host}:{_redis_port}/{_redis_db}"
# 根据配置选择合适的任务管理器
if _enable_redis:
task_manager = RedisTaskManager(max_concurrent_tasks=_max_concurrent_tasks, redis_url=redis_url)
task_manager = RedisTaskManager(
max_concurrent_tasks=_max_concurrent_tasks, redis_url=redis_url
)
else:
task_manager = InMemoryTaskManager(max_concurrent_tasks=_max_concurrent_tasks)
# @router.post("/videos-test", response_model=TaskResponse, summary="Generate a short video")
# async def create_video_test(request: Request, body: TaskVideoRequest):
# task_id = utils.get_uuid()
# request_id = base.get_task_id(request)
# try:
# task = {
# "task_id": task_id,
# "request_id": request_id,
# "params": body.dict(),
# }
# task_manager.add_task(tm.start_test, task_id=task_id, params=body)
# return utils.get_response(200, task)
# except ValueError as e:
# raise HttpException(task_id=task_id, status_code=400, message=f"{request_id}: {str(e)}")
@router.post("/videos", response_model=TaskResponse, summary="Generate a short video")
def create_video(background_tasks: BackgroundTasks, request: Request, body: TaskVideoRequest):
def create_video(
background_tasks: BackgroundTasks, request: Request, body: TaskVideoRequest
):
return create_task(request, body, stop_at="video")
@router.post("/subtitle", response_model=TaskResponse, summary="Generate subtitle only")
def create_subtitle(
background_tasks: BackgroundTasks, request: Request, body: SubtitleRequest
):
return create_task(request, body, stop_at="subtitle")
@router.post("/audio", response_model=TaskResponse, summary="Generate audio only")
def create_audio(
background_tasks: BackgroundTasks, request: Request, body: AudioRequest
):
return create_task(request, body, stop_at="audio")
def create_task(
request: Request,
body: Union[TaskVideoRequest, SubtitleRequest, AudioRequest],
stop_at: str,
):
task_id = utils.get_uuid()
request_id = base.get_task_id(request)
try:
task = {
"task_id": task_id,
"request_id": request_id,
"params": body.dict(),
"params": body.model_dump(),
}
sm.state.update_task(task_id)
# background_tasks.add_task(tm.start, task_id=task_id, params=body)
task_manager.add_task(tm.start, task_id=task_id, params=body)
logger.success(f"video created: {utils.to_json(task)}")
task_manager.add_task(tm.start, task_id=task_id, params=body, stop_at=stop_at)
logger.success(f"Task created: {utils.to_json(task)}")
return utils.get_response(200, task)
except ValueError as e:
raise HttpException(task_id=task_id, status_code=400, message=f"{request_id}: {str(e)}")
raise HttpException(
task_id=task_id, status_code=400, message=f"{request_id}: {str(e)}"
)
@router.get("/tasks/{task_id}", response_model=TaskQueryResponse, summary="Query task status")
def get_task(request: Request, task_id: str = Path(..., description="Task ID"),
query: TaskQueryRequest = Depends()):
@router.get(
"/tasks/{task_id}", response_model=TaskQueryResponse, summary="Query task status"
)
def get_task(
request: Request,
task_id: str = Path(..., description="Task ID"),
query: TaskQueryRequest = Depends(),
):
endpoint = config.app.get("endpoint", "")
if not endpoint:
endpoint = str(request.base_url)
@@ -108,10 +135,16 @@ def get_task(request: Request, task_id: str = Path(..., description="Task ID"),
task["combined_videos"] = urls
return utils.get_response(200, task)
raise HttpException(task_id=task_id, status_code=404, message=f"{request_id}: task not found")
raise HttpException(
task_id=task_id, status_code=404, message=f"{request_id}: task not found"
)
@router.delete("/tasks/{task_id}", response_model=TaskDeletionResponse, summary="Delete a generated short video task")
@router.delete(
"/tasks/{task_id}",
response_model=TaskDeletionResponse,
summary="Delete a generated short video task",
)
def delete_video(request: Request, task_id: str = Path(..., description="Task ID")):
request_id = base.get_task_id(request)
task = sm.state.get_task(task_id)
@@ -125,32 +158,40 @@ def delete_video(request: Request, task_id: str = Path(..., description="Task ID
logger.success(f"video deleted: {utils.to_json(task)}")
return utils.get_response(200)
raise HttpException(task_id=task_id, status_code=404, message=f"{request_id}: task not found")
raise HttpException(
task_id=task_id, status_code=404, message=f"{request_id}: task not found"
)
@router.get("/musics", response_model=BgmRetrieveResponse, summary="Retrieve local BGM files")
@router.get(
"/musics", response_model=BgmRetrieveResponse, summary="Retrieve local BGM files"
)
def get_bgm_list(request: Request):
suffix = "*.mp3"
song_dir = utils.song_dir()
files = glob.glob(os.path.join(song_dir, suffix))
bgm_list = []
for file in files:
bgm_list.append({
"name": os.path.basename(file),
"size": os.path.getsize(file),
"file": file,
})
response = {
"files": bgm_list
}
bgm_list.append(
{
"name": os.path.basename(file),
"size": os.path.getsize(file),
"file": file,
}
)
response = {"files": bgm_list}
return utils.get_response(200, response)
@router.post("/musics", response_model=BgmUploadResponse, summary="Upload the BGM file to the songs directory")
@router.post(
"/musics",
response_model=BgmUploadResponse,
summary="Upload the BGM file to the songs directory",
)
def upload_bgm_file(request: Request, file: UploadFile = File(...)):
request_id = base.get_task_id(request)
# check file ext
if file.filename.endswith('mp3'):
if file.filename.endswith("mp3"):
song_dir = utils.song_dir()
save_path = os.path.join(song_dir, file.filename)
# save file
@@ -158,26 +199,26 @@ def upload_bgm_file(request: Request, file: UploadFile = File(...)):
# If the file already exists, it will be overwritten
file.file.seek(0)
buffer.write(file.file.read())
response = {
"file": save_path
}
response = {"file": save_path}
return utils.get_response(200, response)
raise HttpException('', status_code=400, message=f"{request_id}: Only *.mp3 files can be uploaded")
raise HttpException(
"", status_code=400, message=f"{request_id}: Only *.mp3 files can be uploaded"
)
@router.get("/stream/{file_path:path}")
async def stream_video(request: Request, file_path: str):
tasks_dir = utils.task_dir()
video_path = os.path.join(tasks_dir, file_path)
range_header = request.headers.get('Range')
range_header = request.headers.get("Range")
video_size = os.path.getsize(video_path)
start, end = 0, video_size - 1
length = video_size
if range_header:
range_ = range_header.split('bytes=')[1]
start, end = [int(part) if part else None for part in range_.split('-')]
range_ = range_header.split("bytes=")[1]
start, end = [int(part) if part else None for part in range_.split("-")]
if start is None:
start = video_size - end
end = video_size - 1
@@ -186,7 +227,7 @@ async def stream_video(request: Request, file_path: str):
length = end - start + 1
def file_iterator(file_path, offset=0, bytes_to_read=None):
with open(file_path, 'rb') as f:
with open(file_path, "rb") as f:
f.seek(offset, os.SEEK_SET)
remaining = bytes_to_read or video_size
while remaining > 0:
@@ -197,10 +238,12 @@ async def stream_video(request: Request, file_path: str):
remaining -= len(data)
yield data
response = StreamingResponse(file_iterator(video_path, start, length), media_type='video/mp4')
response.headers['Content-Range'] = f'bytes {start}-{end}/{video_size}'
response.headers['Accept-Ranges'] = 'bytes'
response.headers['Content-Length'] = str(length)
response = StreamingResponse(
file_iterator(video_path, start, length), media_type="video/mp4"
)
response.headers["Content-Range"] = f"bytes {start}-{end}/{video_size}"
response.headers["Accept-Ranges"] = "bytes"
response.headers["Content-Length"] = str(length)
response.status_code = 206 # Partial Content
return response
@@ -219,8 +262,10 @@ async def download_video(_: Request, file_path: str):
file_path = pathlib.Path(video_path)
filename = file_path.stem
extension = file_path.suffix
headers = {
"Content-Disposition": f"attachment; filename={filename}{extension}"
}
return FileResponse(path=video_path, headers=headers, filename=f"{filename}{extension}",
media_type=f'video/{extension[1:]}')
headers = {"Content-Disposition": f"attachment; filename={filename}{extension}"}
return FileResponse(
path=video_path,
headers=headers,
filename=f"{filename}{extension}",
media_type=f"video/{extension[1:]}",
)

View File

@@ -1,12 +1,16 @@
import warnings
from enum import Enum
from typing import Any, Optional, List
from typing import Any, List, Optional
import pydantic
from pydantic import BaseModel
import warnings
# 忽略 Pydantic 的特定警告
warnings.filterwarnings("ignore", category=UserWarning, message="Field name.*shadows an attribute in parent.*")
warnings.filterwarnings(
"ignore",
category=UserWarning,
message="Field name.*shadows an attribute in parent.*",
)
class VideoConcatMode(str, Enum):
@@ -61,7 +65,6 @@ class MaterialInfo:
# # "male-zh-TW-YunJheNeural",
#
# # en-US
#
# "female-en-US-AnaNeural",
# "female-en-US-AriaNeural",
# "female-en-US-AvaNeural",
@@ -93,6 +96,7 @@ class VideoParams(BaseModel):
"stroke_width": 1.5
}
"""
video_subject: str
video_script: str = "" # 用于生成视频的脚本
video_terms: Optional[str | list] = None # 用于生成视频的关键词
@@ -126,6 +130,38 @@ class VideoParams(BaseModel):
paragraph_number: Optional[int] = 1
class SubtitleRequest(BaseModel):
video_script: str
video_language: Optional[str] = ""
voice_name: Optional[str] = "zh-CN-XiaoxiaoNeural-Female"
voice_volume: Optional[float] = 1.0
voice_rate: Optional[float] = 1.2
bgm_type: Optional[str] = "random"
bgm_file: Optional[str] = ""
bgm_volume: Optional[float] = 0.2
subtitle_position: Optional[str] = "bottom"
font_name: Optional[str] = "STHeitiMedium.ttc"
text_fore_color: Optional[str] = "#FFFFFF"
text_background_color: Optional[str] = "transparent"
font_size: int = 60
stroke_color: Optional[str] = "#000000"
stroke_width: float = 1.5
video_source: Optional[str] = "local"
subtitle_enabled: Optional[str] = "true"
class AudioRequest(BaseModel):
video_script: str
video_language: Optional[str] = ""
voice_name: Optional[str] = "zh-CN-XiaoxiaoNeural-Female"
voice_volume: Optional[float] = 1.0
voice_rate: Optional[float] = 1.2
bgm_type: Optional[str] = "random"
bgm_file: Optional[str] = ""
bgm_volume: Optional[float] = 0.2
video_source: Optional[str] = "local"
class VideoScriptParams:
"""
{
@@ -134,6 +170,7 @@ class VideoScriptParams:
"paragraph_number": 1
}
"""
video_subject: Optional[str] = "春天的花海"
video_language: Optional[str] = ""
paragraph_number: Optional[int] = 1
@@ -147,14 +184,17 @@ class VideoTermsParams:
"amount": 5
}
"""
video_subject: Optional[str] = "春天的花海"
video_script: Optional[str] = "春天的花海,如诗如画般展现在眼前。万物复苏的季节里,大地披上了一袭绚丽多彩的盛装。金黄的迎春、粉嫩的樱花、洁白的梨花、艳丽的郁金香……"
video_script: Optional[str] = (
"春天的花海,如诗如画般展现在眼前。万物复苏的季节里,大地披上了一袭绚丽多彩的盛装。金黄的迎春、粉嫩的樱花、洁白的梨花、艳丽的郁金香……"
)
amount: Optional[int] = 5
class BaseResponse(BaseModel):
status: int = 200
message: Optional[str] = 'success'
message: Optional[str] = "success"
data: Any = None
@@ -189,9 +229,7 @@ class TaskResponse(BaseResponse):
"example": {
"status": 200,
"message": "success",
"data": {
"task_id": "6c85c8cc-a77a-42b9-bc30-947815aa0558"
}
"data": {"task_id": "6c85c8cc-a77a-42b9-bc30-947815aa0558"},
},
}
@@ -210,8 +248,8 @@ class TaskQueryResponse(BaseResponse):
],
"combined_videos": [
"http://127.0.0.1:8080/tasks/6c85c8cc-a77a-42b9-bc30-947815aa0558/combined-1.mp4"
]
}
],
},
},
}
@@ -230,8 +268,8 @@ class TaskDeletionResponse(BaseResponse):
],
"combined_videos": [
"http://127.0.0.1:8080/tasks/6c85c8cc-a77a-42b9-bc30-947815aa0558/combined-1.mp4"
]
}
],
},
},
}
@@ -244,7 +282,7 @@ class VideoScriptResponse(BaseResponse):
"message": "success",
"data": {
"video_script": "春天的花海,是大自然的一幅美丽画卷。在这个季节里,大地复苏,万物生长,花朵争相绽放,形成了一片五彩斑斓的花海..."
}
},
},
}
@@ -255,9 +293,7 @@ class VideoTermsResponse(BaseResponse):
"example": {
"status": 200,
"message": "success",
"data": {
"video_terms": ["sky", "tree"]
}
"data": {"video_terms": ["sky", "tree"]},
},
}
@@ -273,10 +309,10 @@ class BgmRetrieveResponse(BaseResponse):
{
"name": "output013.mp3",
"size": 1891269,
"file": "/MoneyPrinterTurbo/resource/songs/output013.mp3"
"file": "/MoneyPrinterTurbo/resource/songs/output013.mp3",
}
]
}
},
},
}
@@ -287,8 +323,6 @@ class BgmUploadResponse(BaseResponse):
"example": {
"status": 200,
"message": "success",
"data": {
"file": "/MoneyPrinterTurbo/resource/songs/example.mp3"
}
"data": {"file": "/MoneyPrinterTurbo/resource/songs/example.mp3"},
},
}

View File

@@ -7,58 +7,42 @@ from loguru import logger
from app.config import config
from app.models import const
from app.models.schema import VideoParams, VideoConcatMode
from app.services import llm, material, voice, video, subtitle
from app.models.schema import VideoConcatMode, VideoParams
from app.services import llm, material, subtitle, video, voice
from app.services import state as sm
from app.utils import utils
def start(task_id, params: VideoParams):
"""
{
"video_subject": "",
"video_aspect": "横屏 16:9西瓜视频",
"voice_name": "女生-晓晓",
"enable_bgm": false,
"font_name": "STHeitiMedium 黑体-中",
"text_color": "#FFFFFF",
"font_size": 60,
"stroke_color": "#000000",
"stroke_width": 1.5
}
"""
logger.info(f"start task: {task_id}")
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=5)
video_subject = params.video_subject
voice_name = voice.parse_voice_name(params.voice_name)
voice_rate = params.voice_rate
paragraph_number = params.paragraph_number
n_threads = params.n_threads
max_clip_duration = params.video_clip_duration
def generate_script(task_id, params):
logger.info("\n\n## generating video script")
video_script = params.video_script.strip()
if not video_script:
video_script = llm.generate_script(video_subject=video_subject, language=params.video_language,
paragraph_number=paragraph_number)
video_script = llm.generate_script(
video_subject=params.video_subject,
language=params.video_language,
paragraph_number=params.paragraph_number,
)
else:
logger.debug(f"video script: \n{video_script}")
if not video_script:
sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
logger.error("failed to generate video script.")
return
return None
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=10)
return video_script
def generate_terms(task_id, params, video_script):
logger.info("\n\n## generating video terms")
video_terms = params.video_terms
if not video_terms:
video_terms = llm.generate_terms(video_subject=video_subject, video_script=video_script, amount=5)
video_terms = llm.generate_terms(
video_subject=params.video_subject, video_script=video_script, amount=5
)
else:
if isinstance(video_terms, str):
video_terms = [term.strip() for term in re.split(r'[,]', video_terms)]
video_terms = [term.strip() for term in re.split(r"[,]", video_terms)]
elif isinstance(video_terms, list):
video_terms = [term.strip() for term in video_terms]
else:
@@ -69,9 +53,13 @@ def start(task_id, params: VideoParams):
if not video_terms:
sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
logger.error("failed to generate video terms.")
return
return None
script_file = path.join(utils.task_dir(task_id), f"script.json")
return video_terms
def save_script_data(task_id, video_script, video_terms, params):
script_file = path.join(utils.task_dir(task_id), "script.json")
script_data = {
"script": video_script,
"search_terms": video_terms,
@@ -81,11 +69,16 @@ def start(task_id, params: VideoParams):
with open(script_file, "w", encoding="utf-8") as f:
f.write(utils.to_json(script_data))
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=20)
def generate_audio(task_id, params, video_script):
logger.info("\n\n## generating audio")
audio_file = path.join(utils.task_dir(task_id), f"audio.mp3")
sub_maker = voice.tts(text=video_script, voice_name=voice_name, voice_rate=voice_rate, voice_file=audio_file)
audio_file = path.join(utils.task_dir(task_id), "audio.mp3")
sub_maker = voice.tts(
text=video_script,
voice_name=voice.parse_voice_name(params.voice_name),
voice_rate=params.voice_rate,
voice_file=audio_file,
)
if sub_maker is None:
sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
logger.error(
@@ -94,86 +87,100 @@ def start(task_id, params: VideoParams):
2. check if the network is available. If you are in China, it is recommended to use a VPN and enable the global traffic mode.
""".strip()
)
return
return None, None
audio_duration = voice.get_audio_duration(sub_maker)
audio_duration = math.ceil(audio_duration)
audio_duration = math.ceil(voice.get_audio_duration(sub_maker))
return audio_file, audio_duration
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=30)
subtitle_path = ""
if params.subtitle_enabled:
subtitle_path = path.join(utils.task_dir(task_id), f"subtitle.srt")
subtitle_provider = config.app.get("subtitle_provider", "").strip().lower()
logger.info(f"\n\n## generating subtitle, provider: {subtitle_provider}")
subtitle_fallback = False
if subtitle_provider == "edge":
voice.create_subtitle(text=video_script, sub_maker=sub_maker, subtitle_file=subtitle_path)
if not os.path.exists(subtitle_path):
subtitle_fallback = True
logger.warning("subtitle file not found, fallback to whisper")
def generate_subtitle(task_id, params, video_script, sub_maker, audio_file):
if not params.subtitle_enabled:
return ""
if subtitle_provider == "whisper" or subtitle_fallback:
subtitle.create(audio_file=audio_file, subtitle_file=subtitle_path)
logger.info("\n\n## correcting subtitle")
subtitle.correct(subtitle_file=subtitle_path, video_script=video_script)
subtitle_path = path.join(utils.task_dir(task_id), "subtitle.srt")
subtitle_provider = config.app.get("subtitle_provider", "").strip().lower()
logger.info(f"\n\n## generating subtitle, provider: {subtitle_provider}")
subtitle_lines = subtitle.file_to_subtitles(subtitle_path)
if not subtitle_lines:
logger.warning(f"subtitle file is invalid: {subtitle_path}")
subtitle_path = ""
subtitle_fallback = False
if subtitle_provider == "edge":
voice.create_subtitle(
text=video_script, sub_maker=sub_maker, subtitle_file=subtitle_path
)
if not os.path.exists(subtitle_path):
subtitle_fallback = True
logger.warning("subtitle file not found, fallback to whisper")
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=40)
if subtitle_provider == "whisper" or subtitle_fallback:
subtitle.create(audio_file=audio_file, subtitle_file=subtitle_path)
logger.info("\n\n## correcting subtitle")
subtitle.correct(subtitle_file=subtitle_path, video_script=video_script)
downloaded_videos = []
subtitle_lines = subtitle.file_to_subtitles(subtitle_path)
if not subtitle_lines:
logger.warning(f"subtitle file is invalid: {subtitle_path}")
return ""
return subtitle_path
def get_video_materials(task_id, params, video_terms, audio_duration):
if params.video_source == "local":
logger.info("\n\n## preprocess local materials")
materials = video.preprocess_video(materials=params.video_materials, clip_duration=max_clip_duration)
print(materials)
materials = video.preprocess_video(
materials=params.video_materials, clip_duration=params.video_clip_duration
)
if not materials:
sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
logger.error("no valid materials found, please check the materials and try again.")
return
for material_info in materials:
print(material_info)
downloaded_videos.append(material_info.url)
logger.error(
"no valid materials found, please check the materials and try again."
)
return None
return [material_info.url for material_info in materials]
else:
logger.info(f"\n\n## downloading videos from {params.video_source}")
downloaded_videos = material.download_videos(task_id=task_id,
search_terms=video_terms,
source=params.video_source,
video_aspect=params.video_aspect,
video_contact_mode=params.video_concat_mode,
audio_duration=audio_duration * params.video_count,
max_clip_duration=max_clip_duration,
)
if not downloaded_videos:
sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
logger.error(
"failed to download videos, maybe the network is not available. if you are in China, please use a VPN.")
return
downloaded_videos = material.download_videos(
task_id=task_id,
search_terms=video_terms,
source=params.video_source,
video_aspect=params.video_aspect,
video_contact_mode=params.video_concat_mode,
audio_duration=audio_duration * params.video_count,
max_clip_duration=params.video_clip_duration,
)
if not downloaded_videos:
sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
logger.error(
"failed to download videos, maybe the network is not available. if you are in China, please use a VPN."
)
return None
return downloaded_videos
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=50)
def generate_final_videos(
task_id, params, downloaded_videos, audio_file, subtitle_path
):
final_video_paths = []
combined_video_paths = []
video_concat_mode = params.video_concat_mode
if params.video_count > 1:
video_concat_mode = VideoConcatMode.random
video_concat_mode = (
params.video_concat_mode if params.video_count > 1 else VideoConcatMode.random
)
_progress = 50
for i in range(params.video_count):
index = i + 1
combined_video_path = path.join(utils.task_dir(task_id), f"combined-{index}.mp4")
combined_video_path = path.join(
utils.task_dir(task_id), f"combined-{index}.mp4"
)
logger.info(f"\n\n## combining video: {index} => {combined_video_path}")
video.combine_videos(combined_video_path=combined_video_path,
video_paths=downloaded_videos,
audio_file=audio_file,
video_aspect=params.video_aspect,
video_concat_mode=video_concat_mode,
max_clip_duration=max_clip_duration,
threads=n_threads)
video.combine_videos(
combined_video_path=combined_video_path,
video_paths=downloaded_videos,
audio_file=audio_file,
video_aspect=params.video_aspect,
video_concat_mode=video_concat_mode,
max_clip_duration=params.video_clip_duration,
threads=params.n_threads,
)
_progress += 50 / params.video_count / 2
sm.state.update_task(task_id, progress=_progress)
@@ -181,13 +188,13 @@ def start(task_id, params: VideoParams):
final_video_path = path.join(utils.task_dir(task_id), f"final-{index}.mp4")
logger.info(f"\n\n## generating video: {index} => {final_video_path}")
# Put everything together
video.generate_video(video_path=combined_video_path,
audio_path=audio_file,
subtitle_path=subtitle_path,
output_file=final_video_path,
params=params,
)
video.generate_video(
video_path=combined_video_path,
audio_path=audio_file,
subtitle_path=subtitle_path,
output_file=final_video_path,
params=params,
)
_progress += 50 / params.video_count / 2
sm.state.update_task(task_id, progress=_progress)
@@ -195,16 +202,119 @@ def start(task_id, params: VideoParams):
final_video_paths.append(final_video_path)
combined_video_paths.append(combined_video_path)
logger.success(f"task {task_id} finished, generated {len(final_video_paths)} videos.")
return final_video_paths, combined_video_paths
def start(task_id, params: VideoParams, stop_at: str = "video"):
logger.info(f"start task: {task_id}, stop_at: {stop_at}")
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=5)
# 1. Generate script
video_script = generate_script(task_id, params)
if not video_script:
sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
return
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=10)
if stop_at == "script":
sm.state.update_task(
task_id, state=const.TASK_STATE_COMPLETE, progress=100, script=video_script
)
return {"script": video_script}
# 2. Generate terms
video_terms = ""
if params.video_source != "local":
video_terms = generate_terms(task_id, params, video_script)
if not video_terms:
sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
return
save_script_data(task_id, video_script, video_terms, params)
if stop_at == "terms":
sm.state.update_task(
task_id, state=const.TASK_STATE_COMPLETE, progress=100, terms=video_terms
)
return {"script": video_script, "terms": video_terms}
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=20)
# 3. Generate audio
audio_file, audio_duration = generate_audio(task_id, params, video_script)
if not audio_file:
sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
return
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=30)
if stop_at == "audio":
sm.state.update_task(
task_id,
state=const.TASK_STATE_COMPLETE,
progress=100,
audio_file=audio_file,
)
return {"audio_file": audio_file, "audio_duration": audio_duration}
# 4. Generate subtitle
subtitle_path = generate_subtitle(task_id, params, video_script, None, audio_file)
if stop_at == "subtitle":
sm.state.update_task(
task_id,
state=const.TASK_STATE_COMPLETE,
progress=100,
subtitle_path=subtitle_path,
)
return {"subtitle_path": subtitle_path}
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=40)
# 5. Get video materials
downloaded_videos = get_video_materials(
task_id, params, video_terms, audio_duration
)
if not downloaded_videos:
sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
return
if stop_at == "materials":
sm.state.update_task(
task_id,
state=const.TASK_STATE_COMPLETE,
progress=100,
materials=downloaded_videos,
)
return {"materials": downloaded_videos}
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=50)
# 6. Generate final videos
final_video_paths, combined_video_paths = generate_final_videos(
task_id, params, downloaded_videos, audio_file, subtitle_path
)
if not final_video_paths:
sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
return
logger.success(
f"task {task_id} finished, generated {len(final_video_paths)} videos."
)
kwargs = {
"videos": final_video_paths,
"combined_videos": combined_video_paths
"combined_videos": combined_video_paths,
"script": video_script,
"terms": video_terms,
"audio_file": audio_file,
"audio_duration": audio_duration,
"subtitle_path": subtitle_path,
"materials": downloaded_videos,
}
sm.state.update_task(task_id, state=const.TASK_STATE_COMPLETE, progress=100, **kwargs)
sm.state.update_task(
task_id, state=const.TASK_STATE_COMPLETE, progress=100, **kwargs
)
return kwargs
# def start_test(task_id, params: VideoParams):
# print(f"start task {task_id} \n")
# time.sleep(5)
# print(f"task {task_id} finished \n")