mirror of
https://github.com/harry0703/MoneyPrinterTurbo.git
synced 2025-11-25 03:15:04 +08:00
Merge pull request #458 from yyhhyyyyyy/refactor-task-add-subtitle-api
Refactor task.py and add subtitle API
This commit is contained in:
@@ -1,11 +1,12 @@
|
||||
import os
|
||||
import glob
|
||||
import os
|
||||
import pathlib
|
||||
import shutil
|
||||
from typing import Union
|
||||
|
||||
from fastapi import Request, Depends, Path, BackgroundTasks, UploadFile
|
||||
from fastapi.responses import FileResponse, StreamingResponse
|
||||
from fastapi import BackgroundTasks, Depends, Path, Request, UploadFile
|
||||
from fastapi.params import File
|
||||
from fastapi.responses import FileResponse, StreamingResponse
|
||||
from loguru import logger
|
||||
|
||||
from app.config import config
|
||||
@@ -14,10 +15,19 @@ from app.controllers.manager.memory_manager import InMemoryTaskManager
|
||||
from app.controllers.manager.redis_manager import RedisTaskManager
|
||||
from app.controllers.v1.base import new_router
|
||||
from app.models.exception import HttpException
|
||||
from app.models.schema import TaskVideoRequest, TaskQueryResponse, TaskResponse, TaskQueryRequest, \
|
||||
BgmUploadResponse, BgmRetrieveResponse, TaskDeletionResponse
|
||||
from app.services import task as tm
|
||||
from app.models.schema import (
|
||||
AudioRequest,
|
||||
BgmRetrieveResponse,
|
||||
BgmUploadResponse,
|
||||
SubtitleRequest,
|
||||
TaskDeletionResponse,
|
||||
TaskQueryRequest,
|
||||
TaskQueryResponse,
|
||||
TaskResponse,
|
||||
TaskVideoRequest,
|
||||
)
|
||||
from app.services import state as sm
|
||||
from app.services import task as tm
|
||||
from app.utils import utils
|
||||
|
||||
# 认证依赖项
|
||||
@@ -34,48 +44,65 @@ _max_concurrent_tasks = config.app.get("max_concurrent_tasks", 5)
|
||||
redis_url = f"redis://:{_redis_password}@{_redis_host}:{_redis_port}/{_redis_db}"
|
||||
# 根据配置选择合适的任务管理器
|
||||
if _enable_redis:
|
||||
task_manager = RedisTaskManager(max_concurrent_tasks=_max_concurrent_tasks, redis_url=redis_url)
|
||||
task_manager = RedisTaskManager(
|
||||
max_concurrent_tasks=_max_concurrent_tasks, redis_url=redis_url
|
||||
)
|
||||
else:
|
||||
task_manager = InMemoryTaskManager(max_concurrent_tasks=_max_concurrent_tasks)
|
||||
|
||||
# @router.post("/videos-test", response_model=TaskResponse, summary="Generate a short video")
|
||||
# async def create_video_test(request: Request, body: TaskVideoRequest):
|
||||
# task_id = utils.get_uuid()
|
||||
# request_id = base.get_task_id(request)
|
||||
# try:
|
||||
# task = {
|
||||
# "task_id": task_id,
|
||||
# "request_id": request_id,
|
||||
# "params": body.dict(),
|
||||
# }
|
||||
# task_manager.add_task(tm.start_test, task_id=task_id, params=body)
|
||||
# return utils.get_response(200, task)
|
||||
# except ValueError as e:
|
||||
# raise HttpException(task_id=task_id, status_code=400, message=f"{request_id}: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/videos", response_model=TaskResponse, summary="Generate a short video")
|
||||
def create_video(background_tasks: BackgroundTasks, request: Request, body: TaskVideoRequest):
|
||||
def create_video(
|
||||
background_tasks: BackgroundTasks, request: Request, body: TaskVideoRequest
|
||||
):
|
||||
return create_task(request, body, stop_at="video")
|
||||
|
||||
|
||||
@router.post("/subtitle", response_model=TaskResponse, summary="Generate subtitle only")
|
||||
def create_subtitle(
|
||||
background_tasks: BackgroundTasks, request: Request, body: SubtitleRequest
|
||||
):
|
||||
return create_task(request, body, stop_at="subtitle")
|
||||
|
||||
|
||||
@router.post("/audio", response_model=TaskResponse, summary="Generate audio only")
|
||||
def create_audio(
|
||||
background_tasks: BackgroundTasks, request: Request, body: AudioRequest
|
||||
):
|
||||
return create_task(request, body, stop_at="audio")
|
||||
|
||||
|
||||
def create_task(
|
||||
request: Request,
|
||||
body: Union[TaskVideoRequest, SubtitleRequest, AudioRequest],
|
||||
stop_at: str,
|
||||
):
|
||||
task_id = utils.get_uuid()
|
||||
request_id = base.get_task_id(request)
|
||||
try:
|
||||
task = {
|
||||
"task_id": task_id,
|
||||
"request_id": request_id,
|
||||
"params": body.dict(),
|
||||
"params": body.model_dump(),
|
||||
}
|
||||
sm.state.update_task(task_id)
|
||||
# background_tasks.add_task(tm.start, task_id=task_id, params=body)
|
||||
task_manager.add_task(tm.start, task_id=task_id, params=body)
|
||||
logger.success(f"video created: {utils.to_json(task)}")
|
||||
task_manager.add_task(tm.start, task_id=task_id, params=body, stop_at=stop_at)
|
||||
logger.success(f"Task created: {utils.to_json(task)}")
|
||||
return utils.get_response(200, task)
|
||||
except ValueError as e:
|
||||
raise HttpException(task_id=task_id, status_code=400, message=f"{request_id}: {str(e)}")
|
||||
raise HttpException(
|
||||
task_id=task_id, status_code=400, message=f"{request_id}: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/tasks/{task_id}", response_model=TaskQueryResponse, summary="Query task status")
|
||||
def get_task(request: Request, task_id: str = Path(..., description="Task ID"),
|
||||
query: TaskQueryRequest = Depends()):
|
||||
@router.get(
|
||||
"/tasks/{task_id}", response_model=TaskQueryResponse, summary="Query task status"
|
||||
)
|
||||
def get_task(
|
||||
request: Request,
|
||||
task_id: str = Path(..., description="Task ID"),
|
||||
query: TaskQueryRequest = Depends(),
|
||||
):
|
||||
endpoint = config.app.get("endpoint", "")
|
||||
if not endpoint:
|
||||
endpoint = str(request.base_url)
|
||||
@@ -108,10 +135,16 @@ def get_task(request: Request, task_id: str = Path(..., description="Task ID"),
|
||||
task["combined_videos"] = urls
|
||||
return utils.get_response(200, task)
|
||||
|
||||
raise HttpException(task_id=task_id, status_code=404, message=f"{request_id}: task not found")
|
||||
raise HttpException(
|
||||
task_id=task_id, status_code=404, message=f"{request_id}: task not found"
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/tasks/{task_id}", response_model=TaskDeletionResponse, summary="Delete a generated short video task")
|
||||
@router.delete(
|
||||
"/tasks/{task_id}",
|
||||
response_model=TaskDeletionResponse,
|
||||
summary="Delete a generated short video task",
|
||||
)
|
||||
def delete_video(request: Request, task_id: str = Path(..., description="Task ID")):
|
||||
request_id = base.get_task_id(request)
|
||||
task = sm.state.get_task(task_id)
|
||||
@@ -125,32 +158,40 @@ def delete_video(request: Request, task_id: str = Path(..., description="Task ID
|
||||
logger.success(f"video deleted: {utils.to_json(task)}")
|
||||
return utils.get_response(200)
|
||||
|
||||
raise HttpException(task_id=task_id, status_code=404, message=f"{request_id}: task not found")
|
||||
raise HttpException(
|
||||
task_id=task_id, status_code=404, message=f"{request_id}: task not found"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/musics", response_model=BgmRetrieveResponse, summary="Retrieve local BGM files")
|
||||
@router.get(
|
||||
"/musics", response_model=BgmRetrieveResponse, summary="Retrieve local BGM files"
|
||||
)
|
||||
def get_bgm_list(request: Request):
|
||||
suffix = "*.mp3"
|
||||
song_dir = utils.song_dir()
|
||||
files = glob.glob(os.path.join(song_dir, suffix))
|
||||
bgm_list = []
|
||||
for file in files:
|
||||
bgm_list.append({
|
||||
"name": os.path.basename(file),
|
||||
"size": os.path.getsize(file),
|
||||
"file": file,
|
||||
})
|
||||
response = {
|
||||
"files": bgm_list
|
||||
}
|
||||
bgm_list.append(
|
||||
{
|
||||
"name": os.path.basename(file),
|
||||
"size": os.path.getsize(file),
|
||||
"file": file,
|
||||
}
|
||||
)
|
||||
response = {"files": bgm_list}
|
||||
return utils.get_response(200, response)
|
||||
|
||||
|
||||
@router.post("/musics", response_model=BgmUploadResponse, summary="Upload the BGM file to the songs directory")
|
||||
@router.post(
|
||||
"/musics",
|
||||
response_model=BgmUploadResponse,
|
||||
summary="Upload the BGM file to the songs directory",
|
||||
)
|
||||
def upload_bgm_file(request: Request, file: UploadFile = File(...)):
|
||||
request_id = base.get_task_id(request)
|
||||
# check file ext
|
||||
if file.filename.endswith('mp3'):
|
||||
if file.filename.endswith("mp3"):
|
||||
song_dir = utils.song_dir()
|
||||
save_path = os.path.join(song_dir, file.filename)
|
||||
# save file
|
||||
@@ -158,26 +199,26 @@ def upload_bgm_file(request: Request, file: UploadFile = File(...)):
|
||||
# If the file already exists, it will be overwritten
|
||||
file.file.seek(0)
|
||||
buffer.write(file.file.read())
|
||||
response = {
|
||||
"file": save_path
|
||||
}
|
||||
response = {"file": save_path}
|
||||
return utils.get_response(200, response)
|
||||
|
||||
raise HttpException('', status_code=400, message=f"{request_id}: Only *.mp3 files can be uploaded")
|
||||
raise HttpException(
|
||||
"", status_code=400, message=f"{request_id}: Only *.mp3 files can be uploaded"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/stream/{file_path:path}")
|
||||
async def stream_video(request: Request, file_path: str):
|
||||
tasks_dir = utils.task_dir()
|
||||
video_path = os.path.join(tasks_dir, file_path)
|
||||
range_header = request.headers.get('Range')
|
||||
range_header = request.headers.get("Range")
|
||||
video_size = os.path.getsize(video_path)
|
||||
start, end = 0, video_size - 1
|
||||
|
||||
length = video_size
|
||||
if range_header:
|
||||
range_ = range_header.split('bytes=')[1]
|
||||
start, end = [int(part) if part else None for part in range_.split('-')]
|
||||
range_ = range_header.split("bytes=")[1]
|
||||
start, end = [int(part) if part else None for part in range_.split("-")]
|
||||
if start is None:
|
||||
start = video_size - end
|
||||
end = video_size - 1
|
||||
@@ -186,7 +227,7 @@ async def stream_video(request: Request, file_path: str):
|
||||
length = end - start + 1
|
||||
|
||||
def file_iterator(file_path, offset=0, bytes_to_read=None):
|
||||
with open(file_path, 'rb') as f:
|
||||
with open(file_path, "rb") as f:
|
||||
f.seek(offset, os.SEEK_SET)
|
||||
remaining = bytes_to_read or video_size
|
||||
while remaining > 0:
|
||||
@@ -197,10 +238,12 @@ async def stream_video(request: Request, file_path: str):
|
||||
remaining -= len(data)
|
||||
yield data
|
||||
|
||||
response = StreamingResponse(file_iterator(video_path, start, length), media_type='video/mp4')
|
||||
response.headers['Content-Range'] = f'bytes {start}-{end}/{video_size}'
|
||||
response.headers['Accept-Ranges'] = 'bytes'
|
||||
response.headers['Content-Length'] = str(length)
|
||||
response = StreamingResponse(
|
||||
file_iterator(video_path, start, length), media_type="video/mp4"
|
||||
)
|
||||
response.headers["Content-Range"] = f"bytes {start}-{end}/{video_size}"
|
||||
response.headers["Accept-Ranges"] = "bytes"
|
||||
response.headers["Content-Length"] = str(length)
|
||||
response.status_code = 206 # Partial Content
|
||||
|
||||
return response
|
||||
@@ -219,8 +262,10 @@ async def download_video(_: Request, file_path: str):
|
||||
file_path = pathlib.Path(video_path)
|
||||
filename = file_path.stem
|
||||
extension = file_path.suffix
|
||||
headers = {
|
||||
"Content-Disposition": f"attachment; filename={filename}{extension}"
|
||||
}
|
||||
return FileResponse(path=video_path, headers=headers, filename=f"{filename}{extension}",
|
||||
media_type=f'video/{extension[1:]}')
|
||||
headers = {"Content-Disposition": f"attachment; filename={filename}{extension}"}
|
||||
return FileResponse(
|
||||
path=video_path,
|
||||
headers=headers,
|
||||
filename=f"{filename}{extension}",
|
||||
media_type=f"video/{extension[1:]}",
|
||||
)
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
import warnings
|
||||
from enum import Enum
|
||||
from typing import Any, Optional, List
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import pydantic
|
||||
from pydantic import BaseModel
|
||||
import warnings
|
||||
|
||||
# 忽略 Pydantic 的特定警告
|
||||
warnings.filterwarnings("ignore", category=UserWarning, message="Field name.*shadows an attribute in parent.*")
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
category=UserWarning,
|
||||
message="Field name.*shadows an attribute in parent.*",
|
||||
)
|
||||
|
||||
|
||||
class VideoConcatMode(str, Enum):
|
||||
@@ -61,7 +65,6 @@ class MaterialInfo:
|
||||
# # "male-zh-TW-YunJheNeural",
|
||||
#
|
||||
# # en-US
|
||||
#
|
||||
# "female-en-US-AnaNeural",
|
||||
# "female-en-US-AriaNeural",
|
||||
# "female-en-US-AvaNeural",
|
||||
@@ -93,6 +96,7 @@ class VideoParams(BaseModel):
|
||||
"stroke_width": 1.5
|
||||
}
|
||||
"""
|
||||
|
||||
video_subject: str
|
||||
video_script: str = "" # 用于生成视频的脚本
|
||||
video_terms: Optional[str | list] = None # 用于生成视频的关键词
|
||||
@@ -126,6 +130,38 @@ class VideoParams(BaseModel):
|
||||
paragraph_number: Optional[int] = 1
|
||||
|
||||
|
||||
class SubtitleRequest(BaseModel):
|
||||
video_script: str
|
||||
video_language: Optional[str] = ""
|
||||
voice_name: Optional[str] = "zh-CN-XiaoxiaoNeural-Female"
|
||||
voice_volume: Optional[float] = 1.0
|
||||
voice_rate: Optional[float] = 1.2
|
||||
bgm_type: Optional[str] = "random"
|
||||
bgm_file: Optional[str] = ""
|
||||
bgm_volume: Optional[float] = 0.2
|
||||
subtitle_position: Optional[str] = "bottom"
|
||||
font_name: Optional[str] = "STHeitiMedium.ttc"
|
||||
text_fore_color: Optional[str] = "#FFFFFF"
|
||||
text_background_color: Optional[str] = "transparent"
|
||||
font_size: int = 60
|
||||
stroke_color: Optional[str] = "#000000"
|
||||
stroke_width: float = 1.5
|
||||
video_source: Optional[str] = "local"
|
||||
subtitle_enabled: Optional[str] = "true"
|
||||
|
||||
|
||||
class AudioRequest(BaseModel):
|
||||
video_script: str
|
||||
video_language: Optional[str] = ""
|
||||
voice_name: Optional[str] = "zh-CN-XiaoxiaoNeural-Female"
|
||||
voice_volume: Optional[float] = 1.0
|
||||
voice_rate: Optional[float] = 1.2
|
||||
bgm_type: Optional[str] = "random"
|
||||
bgm_file: Optional[str] = ""
|
||||
bgm_volume: Optional[float] = 0.2
|
||||
video_source: Optional[str] = "local"
|
||||
|
||||
|
||||
class VideoScriptParams:
|
||||
"""
|
||||
{
|
||||
@@ -134,6 +170,7 @@ class VideoScriptParams:
|
||||
"paragraph_number": 1
|
||||
}
|
||||
"""
|
||||
|
||||
video_subject: Optional[str] = "春天的花海"
|
||||
video_language: Optional[str] = ""
|
||||
paragraph_number: Optional[int] = 1
|
||||
@@ -147,14 +184,17 @@ class VideoTermsParams:
|
||||
"amount": 5
|
||||
}
|
||||
"""
|
||||
|
||||
video_subject: Optional[str] = "春天的花海"
|
||||
video_script: Optional[str] = "春天的花海,如诗如画般展现在眼前。万物复苏的季节里,大地披上了一袭绚丽多彩的盛装。金黄的迎春、粉嫩的樱花、洁白的梨花、艳丽的郁金香……"
|
||||
video_script: Optional[str] = (
|
||||
"春天的花海,如诗如画般展现在眼前。万物复苏的季节里,大地披上了一袭绚丽多彩的盛装。金黄的迎春、粉嫩的樱花、洁白的梨花、艳丽的郁金香……"
|
||||
)
|
||||
amount: Optional[int] = 5
|
||||
|
||||
|
||||
class BaseResponse(BaseModel):
|
||||
status: int = 200
|
||||
message: Optional[str] = 'success'
|
||||
message: Optional[str] = "success"
|
||||
data: Any = None
|
||||
|
||||
|
||||
@@ -189,9 +229,7 @@ class TaskResponse(BaseResponse):
|
||||
"example": {
|
||||
"status": 200,
|
||||
"message": "success",
|
||||
"data": {
|
||||
"task_id": "6c85c8cc-a77a-42b9-bc30-947815aa0558"
|
||||
}
|
||||
"data": {"task_id": "6c85c8cc-a77a-42b9-bc30-947815aa0558"},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -210,8 +248,8 @@ class TaskQueryResponse(BaseResponse):
|
||||
],
|
||||
"combined_videos": [
|
||||
"http://127.0.0.1:8080/tasks/6c85c8cc-a77a-42b9-bc30-947815aa0558/combined-1.mp4"
|
||||
]
|
||||
}
|
||||
],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -230,8 +268,8 @@ class TaskDeletionResponse(BaseResponse):
|
||||
],
|
||||
"combined_videos": [
|
||||
"http://127.0.0.1:8080/tasks/6c85c8cc-a77a-42b9-bc30-947815aa0558/combined-1.mp4"
|
||||
]
|
||||
}
|
||||
],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -244,7 +282,7 @@ class VideoScriptResponse(BaseResponse):
|
||||
"message": "success",
|
||||
"data": {
|
||||
"video_script": "春天的花海,是大自然的一幅美丽画卷。在这个季节里,大地复苏,万物生长,花朵争相绽放,形成了一片五彩斑斓的花海..."
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -255,9 +293,7 @@ class VideoTermsResponse(BaseResponse):
|
||||
"example": {
|
||||
"status": 200,
|
||||
"message": "success",
|
||||
"data": {
|
||||
"video_terms": ["sky", "tree"]
|
||||
}
|
||||
"data": {"video_terms": ["sky", "tree"]},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -273,10 +309,10 @@ class BgmRetrieveResponse(BaseResponse):
|
||||
{
|
||||
"name": "output013.mp3",
|
||||
"size": 1891269,
|
||||
"file": "/MoneyPrinterTurbo/resource/songs/output013.mp3"
|
||||
"file": "/MoneyPrinterTurbo/resource/songs/output013.mp3",
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -287,8 +323,6 @@ class BgmUploadResponse(BaseResponse):
|
||||
"example": {
|
||||
"status": 200,
|
||||
"message": "success",
|
||||
"data": {
|
||||
"file": "/MoneyPrinterTurbo/resource/songs/example.mp3"
|
||||
}
|
||||
"data": {"file": "/MoneyPrinterTurbo/resource/songs/example.mp3"},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -7,58 +7,42 @@ from loguru import logger
|
||||
|
||||
from app.config import config
|
||||
from app.models import const
|
||||
from app.models.schema import VideoParams, VideoConcatMode
|
||||
from app.services import llm, material, voice, video, subtitle
|
||||
from app.models.schema import VideoConcatMode, VideoParams
|
||||
from app.services import llm, material, subtitle, video, voice
|
||||
from app.services import state as sm
|
||||
from app.utils import utils
|
||||
|
||||
|
||||
def start(task_id, params: VideoParams):
|
||||
"""
|
||||
{
|
||||
"video_subject": "",
|
||||
"video_aspect": "横屏 16:9(西瓜视频)",
|
||||
"voice_name": "女生-晓晓",
|
||||
"enable_bgm": false,
|
||||
"font_name": "STHeitiMedium 黑体-中",
|
||||
"text_color": "#FFFFFF",
|
||||
"font_size": 60,
|
||||
"stroke_color": "#000000",
|
||||
"stroke_width": 1.5
|
||||
}
|
||||
"""
|
||||
logger.info(f"start task: {task_id}")
|
||||
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=5)
|
||||
|
||||
video_subject = params.video_subject
|
||||
voice_name = voice.parse_voice_name(params.voice_name)
|
||||
voice_rate = params.voice_rate
|
||||
paragraph_number = params.paragraph_number
|
||||
n_threads = params.n_threads
|
||||
max_clip_duration = params.video_clip_duration
|
||||
|
||||
def generate_script(task_id, params):
|
||||
logger.info("\n\n## generating video script")
|
||||
video_script = params.video_script.strip()
|
||||
if not video_script:
|
||||
video_script = llm.generate_script(video_subject=video_subject, language=params.video_language,
|
||||
paragraph_number=paragraph_number)
|
||||
video_script = llm.generate_script(
|
||||
video_subject=params.video_subject,
|
||||
language=params.video_language,
|
||||
paragraph_number=params.paragraph_number,
|
||||
)
|
||||
else:
|
||||
logger.debug(f"video script: \n{video_script}")
|
||||
|
||||
if not video_script:
|
||||
sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
|
||||
logger.error("failed to generate video script.")
|
||||
return
|
||||
return None
|
||||
|
||||
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=10)
|
||||
return video_script
|
||||
|
||||
|
||||
def generate_terms(task_id, params, video_script):
|
||||
logger.info("\n\n## generating video terms")
|
||||
video_terms = params.video_terms
|
||||
if not video_terms:
|
||||
video_terms = llm.generate_terms(video_subject=video_subject, video_script=video_script, amount=5)
|
||||
video_terms = llm.generate_terms(
|
||||
video_subject=params.video_subject, video_script=video_script, amount=5
|
||||
)
|
||||
else:
|
||||
if isinstance(video_terms, str):
|
||||
video_terms = [term.strip() for term in re.split(r'[,,]', video_terms)]
|
||||
video_terms = [term.strip() for term in re.split(r"[,,]", video_terms)]
|
||||
elif isinstance(video_terms, list):
|
||||
video_terms = [term.strip() for term in video_terms]
|
||||
else:
|
||||
@@ -69,9 +53,13 @@ def start(task_id, params: VideoParams):
|
||||
if not video_terms:
|
||||
sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
|
||||
logger.error("failed to generate video terms.")
|
||||
return
|
||||
return None
|
||||
|
||||
script_file = path.join(utils.task_dir(task_id), f"script.json")
|
||||
return video_terms
|
||||
|
||||
|
||||
def save_script_data(task_id, video_script, video_terms, params):
|
||||
script_file = path.join(utils.task_dir(task_id), "script.json")
|
||||
script_data = {
|
||||
"script": video_script,
|
||||
"search_terms": video_terms,
|
||||
@@ -81,11 +69,16 @@ def start(task_id, params: VideoParams):
|
||||
with open(script_file, "w", encoding="utf-8") as f:
|
||||
f.write(utils.to_json(script_data))
|
||||
|
||||
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=20)
|
||||
|
||||
def generate_audio(task_id, params, video_script):
|
||||
logger.info("\n\n## generating audio")
|
||||
audio_file = path.join(utils.task_dir(task_id), f"audio.mp3")
|
||||
sub_maker = voice.tts(text=video_script, voice_name=voice_name, voice_rate=voice_rate, voice_file=audio_file)
|
||||
audio_file = path.join(utils.task_dir(task_id), "audio.mp3")
|
||||
sub_maker = voice.tts(
|
||||
text=video_script,
|
||||
voice_name=voice.parse_voice_name(params.voice_name),
|
||||
voice_rate=params.voice_rate,
|
||||
voice_file=audio_file,
|
||||
)
|
||||
if sub_maker is None:
|
||||
sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
|
||||
logger.error(
|
||||
@@ -94,86 +87,100 @@ def start(task_id, params: VideoParams):
|
||||
2. check if the network is available. If you are in China, it is recommended to use a VPN and enable the global traffic mode.
|
||||
""".strip()
|
||||
)
|
||||
return
|
||||
return None, None
|
||||
|
||||
audio_duration = voice.get_audio_duration(sub_maker)
|
||||
audio_duration = math.ceil(audio_duration)
|
||||
audio_duration = math.ceil(voice.get_audio_duration(sub_maker))
|
||||
return audio_file, audio_duration
|
||||
|
||||
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=30)
|
||||
|
||||
subtitle_path = ""
|
||||
if params.subtitle_enabled:
|
||||
subtitle_path = path.join(utils.task_dir(task_id), f"subtitle.srt")
|
||||
subtitle_provider = config.app.get("subtitle_provider", "").strip().lower()
|
||||
logger.info(f"\n\n## generating subtitle, provider: {subtitle_provider}")
|
||||
subtitle_fallback = False
|
||||
if subtitle_provider == "edge":
|
||||
voice.create_subtitle(text=video_script, sub_maker=sub_maker, subtitle_file=subtitle_path)
|
||||
if not os.path.exists(subtitle_path):
|
||||
subtitle_fallback = True
|
||||
logger.warning("subtitle file not found, fallback to whisper")
|
||||
def generate_subtitle(task_id, params, video_script, sub_maker, audio_file):
|
||||
if not params.subtitle_enabled:
|
||||
return ""
|
||||
|
||||
if subtitle_provider == "whisper" or subtitle_fallback:
|
||||
subtitle.create(audio_file=audio_file, subtitle_file=subtitle_path)
|
||||
logger.info("\n\n## correcting subtitle")
|
||||
subtitle.correct(subtitle_file=subtitle_path, video_script=video_script)
|
||||
subtitle_path = path.join(utils.task_dir(task_id), "subtitle.srt")
|
||||
subtitle_provider = config.app.get("subtitle_provider", "").strip().lower()
|
||||
logger.info(f"\n\n## generating subtitle, provider: {subtitle_provider}")
|
||||
|
||||
subtitle_lines = subtitle.file_to_subtitles(subtitle_path)
|
||||
if not subtitle_lines:
|
||||
logger.warning(f"subtitle file is invalid: {subtitle_path}")
|
||||
subtitle_path = ""
|
||||
subtitle_fallback = False
|
||||
if subtitle_provider == "edge":
|
||||
voice.create_subtitle(
|
||||
text=video_script, sub_maker=sub_maker, subtitle_file=subtitle_path
|
||||
)
|
||||
if not os.path.exists(subtitle_path):
|
||||
subtitle_fallback = True
|
||||
logger.warning("subtitle file not found, fallback to whisper")
|
||||
|
||||
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=40)
|
||||
if subtitle_provider == "whisper" or subtitle_fallback:
|
||||
subtitle.create(audio_file=audio_file, subtitle_file=subtitle_path)
|
||||
logger.info("\n\n## correcting subtitle")
|
||||
subtitle.correct(subtitle_file=subtitle_path, video_script=video_script)
|
||||
|
||||
downloaded_videos = []
|
||||
subtitle_lines = subtitle.file_to_subtitles(subtitle_path)
|
||||
if not subtitle_lines:
|
||||
logger.warning(f"subtitle file is invalid: {subtitle_path}")
|
||||
return ""
|
||||
|
||||
return subtitle_path
|
||||
|
||||
|
||||
def get_video_materials(task_id, params, video_terms, audio_duration):
|
||||
if params.video_source == "local":
|
||||
logger.info("\n\n## preprocess local materials")
|
||||
materials = video.preprocess_video(materials=params.video_materials, clip_duration=max_clip_duration)
|
||||
print(materials)
|
||||
|
||||
materials = video.preprocess_video(
|
||||
materials=params.video_materials, clip_duration=params.video_clip_duration
|
||||
)
|
||||
if not materials:
|
||||
sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
|
||||
logger.error("no valid materials found, please check the materials and try again.")
|
||||
return
|
||||
for material_info in materials:
|
||||
print(material_info)
|
||||
downloaded_videos.append(material_info.url)
|
||||
logger.error(
|
||||
"no valid materials found, please check the materials and try again."
|
||||
)
|
||||
return None
|
||||
return [material_info.url for material_info in materials]
|
||||
else:
|
||||
logger.info(f"\n\n## downloading videos from {params.video_source}")
|
||||
downloaded_videos = material.download_videos(task_id=task_id,
|
||||
search_terms=video_terms,
|
||||
source=params.video_source,
|
||||
video_aspect=params.video_aspect,
|
||||
video_contact_mode=params.video_concat_mode,
|
||||
audio_duration=audio_duration * params.video_count,
|
||||
max_clip_duration=max_clip_duration,
|
||||
)
|
||||
if not downloaded_videos:
|
||||
sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
|
||||
logger.error(
|
||||
"failed to download videos, maybe the network is not available. if you are in China, please use a VPN.")
|
||||
return
|
||||
downloaded_videos = material.download_videos(
|
||||
task_id=task_id,
|
||||
search_terms=video_terms,
|
||||
source=params.video_source,
|
||||
video_aspect=params.video_aspect,
|
||||
video_contact_mode=params.video_concat_mode,
|
||||
audio_duration=audio_duration * params.video_count,
|
||||
max_clip_duration=params.video_clip_duration,
|
||||
)
|
||||
if not downloaded_videos:
|
||||
sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
|
||||
logger.error(
|
||||
"failed to download videos, maybe the network is not available. if you are in China, please use a VPN."
|
||||
)
|
||||
return None
|
||||
return downloaded_videos
|
||||
|
||||
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=50)
|
||||
|
||||
def generate_final_videos(
|
||||
task_id, params, downloaded_videos, audio_file, subtitle_path
|
||||
):
|
||||
final_video_paths = []
|
||||
combined_video_paths = []
|
||||
video_concat_mode = params.video_concat_mode
|
||||
if params.video_count > 1:
|
||||
video_concat_mode = VideoConcatMode.random
|
||||
video_concat_mode = (
|
||||
params.video_concat_mode if params.video_count > 1 else VideoConcatMode.random
|
||||
)
|
||||
|
||||
_progress = 50
|
||||
for i in range(params.video_count):
|
||||
index = i + 1
|
||||
combined_video_path = path.join(utils.task_dir(task_id), f"combined-{index}.mp4")
|
||||
combined_video_path = path.join(
|
||||
utils.task_dir(task_id), f"combined-{index}.mp4"
|
||||
)
|
||||
logger.info(f"\n\n## combining video: {index} => {combined_video_path}")
|
||||
video.combine_videos(combined_video_path=combined_video_path,
|
||||
video_paths=downloaded_videos,
|
||||
audio_file=audio_file,
|
||||
video_aspect=params.video_aspect,
|
||||
video_concat_mode=video_concat_mode,
|
||||
max_clip_duration=max_clip_duration,
|
||||
threads=n_threads)
|
||||
video.combine_videos(
|
||||
combined_video_path=combined_video_path,
|
||||
video_paths=downloaded_videos,
|
||||
audio_file=audio_file,
|
||||
video_aspect=params.video_aspect,
|
||||
video_concat_mode=video_concat_mode,
|
||||
max_clip_duration=params.video_clip_duration,
|
||||
threads=params.n_threads,
|
||||
)
|
||||
|
||||
_progress += 50 / params.video_count / 2
|
||||
sm.state.update_task(task_id, progress=_progress)
|
||||
@@ -181,13 +188,13 @@ def start(task_id, params: VideoParams):
|
||||
final_video_path = path.join(utils.task_dir(task_id), f"final-{index}.mp4")
|
||||
|
||||
logger.info(f"\n\n## generating video: {index} => {final_video_path}")
|
||||
# Put everything together
|
||||
video.generate_video(video_path=combined_video_path,
|
||||
audio_path=audio_file,
|
||||
subtitle_path=subtitle_path,
|
||||
output_file=final_video_path,
|
||||
params=params,
|
||||
)
|
||||
video.generate_video(
|
||||
video_path=combined_video_path,
|
||||
audio_path=audio_file,
|
||||
subtitle_path=subtitle_path,
|
||||
output_file=final_video_path,
|
||||
params=params,
|
||||
)
|
||||
|
||||
_progress += 50 / params.video_count / 2
|
||||
sm.state.update_task(task_id, progress=_progress)
|
||||
@@ -195,16 +202,119 @@ def start(task_id, params: VideoParams):
|
||||
final_video_paths.append(final_video_path)
|
||||
combined_video_paths.append(combined_video_path)
|
||||
|
||||
logger.success(f"task {task_id} finished, generated {len(final_video_paths)} videos.")
|
||||
return final_video_paths, combined_video_paths
|
||||
|
||||
|
||||
def start(task_id, params: VideoParams, stop_at: str = "video"):
|
||||
logger.info(f"start task: {task_id}, stop_at: {stop_at}")
|
||||
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=5)
|
||||
|
||||
# 1. Generate script
|
||||
video_script = generate_script(task_id, params)
|
||||
if not video_script:
|
||||
sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
|
||||
return
|
||||
|
||||
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=10)
|
||||
|
||||
if stop_at == "script":
|
||||
sm.state.update_task(
|
||||
task_id, state=const.TASK_STATE_COMPLETE, progress=100, script=video_script
|
||||
)
|
||||
return {"script": video_script}
|
||||
|
||||
# 2. Generate terms
|
||||
video_terms = ""
|
||||
if params.video_source != "local":
|
||||
video_terms = generate_terms(task_id, params, video_script)
|
||||
if not video_terms:
|
||||
sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
|
||||
return
|
||||
|
||||
save_script_data(task_id, video_script, video_terms, params)
|
||||
|
||||
if stop_at == "terms":
|
||||
sm.state.update_task(
|
||||
task_id, state=const.TASK_STATE_COMPLETE, progress=100, terms=video_terms
|
||||
)
|
||||
return {"script": video_script, "terms": video_terms}
|
||||
|
||||
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=20)
|
||||
|
||||
# 3. Generate audio
|
||||
audio_file, audio_duration = generate_audio(task_id, params, video_script)
|
||||
if not audio_file:
|
||||
sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
|
||||
return
|
||||
|
||||
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=30)
|
||||
|
||||
if stop_at == "audio":
|
||||
sm.state.update_task(
|
||||
task_id,
|
||||
state=const.TASK_STATE_COMPLETE,
|
||||
progress=100,
|
||||
audio_file=audio_file,
|
||||
)
|
||||
return {"audio_file": audio_file, "audio_duration": audio_duration}
|
||||
|
||||
# 4. Generate subtitle
|
||||
subtitle_path = generate_subtitle(task_id, params, video_script, None, audio_file)
|
||||
|
||||
if stop_at == "subtitle":
|
||||
sm.state.update_task(
|
||||
task_id,
|
||||
state=const.TASK_STATE_COMPLETE,
|
||||
progress=100,
|
||||
subtitle_path=subtitle_path,
|
||||
)
|
||||
return {"subtitle_path": subtitle_path}
|
||||
|
||||
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=40)
|
||||
|
||||
# 5. Get video materials
|
||||
downloaded_videos = get_video_materials(
|
||||
task_id, params, video_terms, audio_duration
|
||||
)
|
||||
if not downloaded_videos:
|
||||
sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
|
||||
return
|
||||
|
||||
if stop_at == "materials":
|
||||
sm.state.update_task(
|
||||
task_id,
|
||||
state=const.TASK_STATE_COMPLETE,
|
||||
progress=100,
|
||||
materials=downloaded_videos,
|
||||
)
|
||||
return {"materials": downloaded_videos}
|
||||
|
||||
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=50)
|
||||
|
||||
# 6. Generate final videos
|
||||
final_video_paths, combined_video_paths = generate_final_videos(
|
||||
task_id, params, downloaded_videos, audio_file, subtitle_path
|
||||
)
|
||||
|
||||
if not final_video_paths:
|
||||
sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
|
||||
return
|
||||
|
||||
logger.success(
|
||||
f"task {task_id} finished, generated {len(final_video_paths)} videos."
|
||||
)
|
||||
|
||||
kwargs = {
|
||||
"videos": final_video_paths,
|
||||
"combined_videos": combined_video_paths
|
||||
"combined_videos": combined_video_paths,
|
||||
"script": video_script,
|
||||
"terms": video_terms,
|
||||
"audio_file": audio_file,
|
||||
"audio_duration": audio_duration,
|
||||
"subtitle_path": subtitle_path,
|
||||
"materials": downloaded_videos,
|
||||
}
|
||||
sm.state.update_task(task_id, state=const.TASK_STATE_COMPLETE, progress=100, **kwargs)
|
||||
sm.state.update_task(
|
||||
task_id, state=const.TASK_STATE_COMPLETE, progress=100, **kwargs
|
||||
)
|
||||
return kwargs
|
||||
|
||||
# def start_test(task_id, params: VideoParams):
|
||||
# print(f"start task {task_id} \n")
|
||||
# time.sleep(5)
|
||||
# print(f"task {task_id} finished \n")
|
||||
|
||||
Reference in New Issue
Block a user