diff --git a/app/asgi.py b/app/asgi.py
index d21a724..aec304c 100644
--- a/app/asgi.py
+++ b/app/asgi.py
@@ -1,4 +1,5 @@
 """Application implementation - ASGI."""
+
 import os
 
 from fastapi import FastAPI, Request
@@ -24,7 +25,9 @@ def exception_handler(request: Request, e: HttpException):
 def validation_exception_handler(request: Request, e: RequestValidationError):
     return JSONResponse(
         status_code=400,
-        content=utils.get_response(status=400, data=e.errors(), message='field required'),
+        content=utils.get_response(
+            status=400, data=e.errors(), message="field required"
+        ),
     )
 
 
@@ -61,7 +64,9 @@ app.add_middleware(
 )
 
 task_dir = utils.task_dir()
-app.mount("/tasks", StaticFiles(directory=task_dir, html=True, follow_symlink=True), name="")
+app.mount(
+    "/tasks", StaticFiles(directory=task_dir, html=True, follow_symlink=True), name=""
+)
 
 public_dir = utils.public_dir()
 app.mount("/", StaticFiles(directory=public_dir, html=True), name="")
diff --git a/app/config/__init__.py b/app/config/__init__.py
index 4bfbd1d..dd46812 100644
--- a/app/config/__init__.py
+++ b/app/config/__init__.py
@@ -10,7 +10,9 @@ from app.utils import utils
 def __init_logger():
     # _log_file = utils.storage_dir("logs/server.log")
     _lvl = config.log_level
-    root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
+    root_dir = os.path.dirname(
+        os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+    )
 
     def format_record(record):
         # 获取日志记录中的文件全路径
@@ -21,10 +23,13 @@ def __init_logger():
         record["file"].path = f"./{relative_path}"
         # 返回修改后的格式字符串
         # 您可以根据需要调整这里的格式
-        _format = '{time:%Y-%m-%d %H:%M:%S} | ' + \
-                  '{level} | ' + \
-                  '"{file.path}:{line}": {function} ' + \
-                  '- {message}' + "\n"
+        _format = (
+            "{time:%Y-%m-%d %H:%M:%S} | "
+            + "{level} | "
+            + '"{file.path}:{line}": {function} '
+            + "- {message}"
+            + "\n"
+        )
         return _format
 
     logger.remove()
diff --git a/app/config/config.py b/app/config/config.py
index 843ef75..9717c89 100644
--- a/app/config/config.py
+++ b/app/config/config.py
@@ -25,7 +25,7 @@ def load_config():
         _config_ = toml.load(config_file)
     except Exception as e:
         logger.warning(f"load config failed: {str(e)}, try to load as utf-8-sig")
-        with open(config_file, mode="r", encoding='utf-8-sig') as fp:
+        with open(config_file, mode="r", encoding="utf-8-sig") as fp:
             _cfg_content = fp.read()
             _config_ = toml.loads(_cfg_content)
     return _config_
@@ -52,8 +52,10 @@ log_level = _cfg.get("log_level", "DEBUG")
 listen_host = _cfg.get("listen_host", "0.0.0.0")
 listen_port = _cfg.get("listen_port", 8080)
 project_name = _cfg.get("project_name", "MoneyPrinterTurbo")
-project_description = _cfg.get("project_description",
-                               "https://github.com/harry0703/MoneyPrinterTurbo")
+project_description = _cfg.get(
+    "project_description",
+    "https://github.com/harry0703/MoneyPrinterTurbo",
+)
 project_version = _cfg.get("project_version", "1.1.9")
 
 reload_debug = False
diff --git a/app/controllers/base.py b/app/controllers/base.py
index b2cdd9b..122e341 100644
--- a/app/controllers/base.py
+++ b/app/controllers/base.py
@@ -7,14 +7,14 @@ from app.models.exception import HttpException
 
 
 def get_task_id(request: Request):
-    task_id = request.headers.get('x-task-id')
+    task_id = request.headers.get("x-task-id")
     if not task_id:
         task_id = uuid4()
     return str(task_id)
 
 
 def get_api_key(request: Request):
-    api_key = request.headers.get('x-api-key')
+    api_key = request.headers.get("x-api-key")
     return api_key
 
 
@@ -23,5 +23,9 @@ def verify_token(request: Request):
     if token != config.app.get("api_key", ""):
         request_id = get_task_id(request)
         request_url = request.url
-        user_agent = request.headers.get('user-agent')
-        raise HttpException(task_id=request_id, status_code=401, message=f"invalid token: {request_url}, {user_agent}")
+        user_agent = request.headers.get("user-agent")
+        raise HttpException(
+            task_id=request_id,
+            status_code=401,
+            message=f"invalid token: {request_url}, {user_agent}",
+        )
diff --git a/app/controllers/manager/base_manager.py b/app/controllers/manager/base_manager.py
index 99cbf6f..462589e 100644
--- a/app/controllers/manager/base_manager.py
+++ b/app/controllers/manager/base_manager.py
@@ -18,11 +18,15 @@ class TaskManager:
             print(f"add task: {func.__name__}, current_tasks: {self.current_tasks}")
             self.execute_task(func, *args, **kwargs)
         else:
-            print(f"enqueue task: {func.__name__}, current_tasks: {self.current_tasks}")
+            print(
+                f"enqueue task: {func.__name__}, current_tasks: {self.current_tasks}"
+            )
             self.enqueue({"func": func, "args": args, "kwargs": kwargs})
 
     def execute_task(self, func: Callable, *args: Any, **kwargs: Any):
-        thread = threading.Thread(target=self.run_task, args=(func, *args), kwargs=kwargs)
+        thread = threading.Thread(
+            target=self.run_task, args=(func, *args), kwargs=kwargs
+        )
         thread.start()
 
     def run_task(self, func: Callable, *args: Any, **kwargs: Any):
@@ -35,11 +39,14 @@ class TaskManager:
 
     def check_queue(self):
         with self.lock:
-            if self.current_tasks < self.max_concurrent_tasks and not self.is_queue_empty():
+            if (
+                self.current_tasks < self.max_concurrent_tasks
+                and not self.is_queue_empty()
+            ):
                 task_info = self.dequeue()
-                func = task_info['func']
-                args = task_info.get('args', ())
-                kwargs = task_info.get('kwargs', {})
+                func = task_info["func"]
+                args = task_info.get("args", ())
+                kwargs = task_info.get("kwargs", {})
                 self.execute_task(func, *args, **kwargs)
 
     def task_done(self):
diff --git a/app/controllers/manager/redis_manager.py b/app/controllers/manager/redis_manager.py
index a37c26c..cad1912 100644
--- a/app/controllers/manager/redis_manager.py
+++ b/app/controllers/manager/redis_manager.py
@@ -8,7 +8,7 @@ from app.models.schema import VideoParams
 from app.services import task as tm
 
 FUNC_MAP = {
-    'start': tm.start,
+    "start": tm.start,
     # 'start_test': tm.start_test
 }
 
@@ -24,11 +24,15 @@ class RedisTaskManager(TaskManager):
 
     def enqueue(self, task: Dict):
         task_with_serializable_params = task.copy()
-        if 'params' in task['kwargs'] and isinstance(task['kwargs']['params'], VideoParams):
-            task_with_serializable_params['kwargs']['params'] = task['kwargs']['params'].dict()
+        if "params" in task["kwargs"] and isinstance(
+            task["kwargs"]["params"], VideoParams
+        ):
+            task_with_serializable_params["kwargs"]["params"] = task["kwargs"][
+                "params"
+            ].dict()
 
         # 将函数对象转换为其名称
-        task_with_serializable_params['func'] = task['func'].__name__
+        task_with_serializable_params["func"] = task["func"].__name__
         self.redis_client.rpush(self.queue, json.dumps(task_with_serializable_params))
 
     def dequeue(self):
@@ -36,10 +40,14 @@
         if task_json:
             task_info = json.loads(task_json)
             # 将函数名称转换回函数对象
-            task_info['func'] = FUNC_MAP[task_info['func']]
+            task_info["func"] = FUNC_MAP[task_info["func"]]
 
-            if 'params' in task_info['kwargs'] and isinstance(task_info['kwargs']['params'], dict):
-                task_info['kwargs']['params'] = VideoParams(**task_info['kwargs']['params'])
+            if "params" in task_info["kwargs"] and isinstance(
+                task_info["kwargs"]["params"], dict
+            ):
+                task_info["kwargs"]["params"] = VideoParams(
+                    **task_info["kwargs"]["params"]
+                )
 
             return task_info
         return None
diff --git a/app/controllers/ping.py b/app/controllers/ping.py
index 980e81d..a3eeff0 100644
--- a/app/controllers/ping.py
+++ b/app/controllers/ping.py
@@ -4,6 +4,11 @@ from fastapi import Request
 router = APIRouter()
 
 
-@router.get("/ping", tags=["Health Check"], description="检查服务可用性", response_description="pong")
+@router.get(
+    "/ping",
+    tags=["Health Check"],
+    description="检查服务可用性",
+    response_description="pong",
+)
 def ping(request: Request) -> str:
     return "pong"
diff --git a/app/controllers/v1/base.py b/app/controllers/v1/base.py
index 99e5729..51794df 100644
--- a/app/controllers/v1/base.py
+++ b/app/controllers/v1/base.py
@@ -3,8 +3,8 @@ from fastapi import APIRouter, Depends
 
 def new_router(dependencies=None):
     router = APIRouter()
-    router.tags = ['V1']
-    router.prefix = '/api/v1'
+    router.tags = ["V1"]
+    router.prefix = "/api/v1"
     # 将认证依赖项应用于所有路由
     if dependencies:
         router.dependencies = dependencies
diff --git a/app/controllers/v1/llm.py b/app/controllers/v1/llm.py
index e30df67..e841d68 100644
--- a/app/controllers/v1/llm.py
+++ b/app/controllers/v1/llm.py
@@ -1,6 +1,11 @@
 from fastapi import Request
 
 from app.controllers.v1.base import new_router
-from app.models.schema import VideoScriptResponse, VideoScriptRequest, VideoTermsResponse, VideoTermsRequest
+from app.models.schema import (
+    VideoScriptResponse,
+    VideoScriptRequest,
+    VideoTermsResponse,
+    VideoTermsRequest,
+)
 from app.services import llm
 from app.utils import utils
 
@@ -9,23 +14,31 @@ from app.utils import utils
 router = new_router()
 
 
-@router.post("/scripts", response_model=VideoScriptResponse, summary="Create a script for the video")
+@router.post(
+    "/scripts",
+    response_model=VideoScriptResponse,
+    summary="Create a script for the video",
+)
 def generate_video_script(request: Request, body: VideoScriptRequest):
-    video_script = llm.generate_script(video_subject=body.video_subject,
-                                       language=body.video_language,
-                                       paragraph_number=body.paragraph_number)
-    response = {
-        "video_script": video_script
-    }
+    video_script = llm.generate_script(
+        video_subject=body.video_subject,
+        language=body.video_language,
+        paragraph_number=body.paragraph_number,
+    )
+    response = {"video_script": video_script}
     return utils.get_response(200, response)
 
 
-@router.post("/terms", response_model=VideoTermsResponse, summary="Generate video terms based on the video script")
+@router.post(
+    "/terms",
+    response_model=VideoTermsResponse,
+    summary="Generate video terms based on the video script",
+)
 def generate_video_terms(request: Request, body: VideoTermsRequest):
-    video_terms = llm.generate_terms(video_subject=body.video_subject,
-                                     video_script=body.video_script,
-                                     amount=body.amount)
-    response = {
-        "video_terms": video_terms
-    }
+    video_terms = llm.generate_terms(
+        video_subject=body.video_subject,
+        video_script=body.video_script,
+        amount=body.amount,
+    )
+    response = {"video_terms": video_terms}
     return utils.get_response(200, response)
diff --git a/app/models/const.py b/app/models/const.py
index 2c62c95..e7540ef 100644
--- a/app/models/const.py
+++ b/app/models/const.py
@@ -1,11 +1,25 @@
 PUNCTUATIONS = [
-    "?", ",", ".", "、", ";", ":", "!", "…",
-    "?", ",", "。", "、", ";", ":", "!", "...",
+    "?",
+    ",",
+    ".",
+    "、",
+    ";",
+    ":",
+    "!",
+    "…",
+    "?",
+    ",",
+    "。",
+    "、",
+    ";",
+    ":",
+    "!",
+    "...",
 ]
 
 TASK_STATE_FAILED = -1
 TASK_STATE_COMPLETE = 1
 TASK_STATE_PROCESSING = 4
 
-FILE_TYPE_VIDEOS = ['mp4', 'mov', 'mkv', 'webm']
-FILE_TYPE_IMAGES = ['jpg', 'jpeg', 'png', 'bmp']
+FILE_TYPE_VIDEOS = ["mp4", "mov", "mkv", "webm"]
+FILE_TYPE_IMAGES = ["jpg", "jpeg", "png", "bmp"]
diff --git a/app/models/exception.py b/app/models/exception.py
index 641cbf3..b186cae 100644
--- a/app/models/exception.py
+++ b/app/models/exception.py
@@ -5,16 +5,18 @@ from loguru import logger
 
 
 class HttpException(Exception):
-    def __init__(self, task_id: str, status_code: int, message: str = '', data: Any = None):
+    def __init__(
+        self, task_id: str, status_code: int, message: str = "", data: Any = None
+    ):
         self.message = message
         self.status_code = status_code
         self.data = data
         # 获取异常堆栈信息
         tb_str = traceback.format_exc().strip()
         if not tb_str or tb_str == "NoneType: None":
-            msg = f'HttpException: {status_code}, {task_id}, {message}'
+            msg = f"HttpException: {status_code}, {task_id}, {message}"
         else:
-            msg = f'HttpException: {status_code}, {task_id}, {message}\n{tb_str}'
+            msg = f"HttpException: {status_code}, {task_id}, {message}\n{tb_str}"
 
         if status_code == 400:
             logger.warning(msg)
diff --git a/app/services/llm.py b/app/services/llm.py
index 40fe707..f6f1bc9 100644
--- a/app/services/llm.py
+++ b/app/services/llm.py
@@ -21,6 +21,7 @@ def _generate_response(prompt: str) -> str:
         if not model_name:
             model_name = "gpt-3.5-turbo-16k-0613"
         import g4f
+
         content = g4f.ChatCompletion.create(
             model=model_name,
             messages=[{"role": "user", "content": prompt}],
@@ -78,44 +79,56 @@ def _generate_response(prompt: str) -> str:
             base_url = config.app.get("ernie_base_url")
             model_name = "***"
             if not secret_key:
-                raise ValueError(f"{llm_provider}: secret_key is not set, please set it in the config.toml file.")
+                raise ValueError(
+                    f"{llm_provider}: secret_key is not set, please set it in the config.toml file."
+                )
         else:
-            raise ValueError("llm_provider is not set, please set it in the config.toml file.")
+            raise ValueError(
+                "llm_provider is not set, please set it in the config.toml file."
+            )
 
         if not api_key:
-            raise ValueError(f"{llm_provider}: api_key is not set, please set it in the config.toml file.")
+            raise ValueError(
+                f"{llm_provider}: api_key is not set, please set it in the config.toml file."
+            )
         if not model_name:
-            raise ValueError(f"{llm_provider}: model_name is not set, please set it in the config.toml file.")
+            raise ValueError(
+                f"{llm_provider}: model_name is not set, please set it in the config.toml file."
+            )
         if not base_url:
-            raise ValueError(f"{llm_provider}: base_url is not set, please set it in the config.toml file.")
+            raise ValueError(
+                f"{llm_provider}: base_url is not set, please set it in the config.toml file."
+            )
 
     if llm_provider == "qwen":
         import dashscope
         from dashscope.api_entities.dashscope_response import GenerationResponse
+
         dashscope.api_key = api_key
         response = dashscope.Generation.call(
-            model=model_name,
-            messages=[{"role": "user", "content": prompt}]
+            model=model_name, messages=[{"role": "user", "content": prompt}]
         )
         if response:
             if isinstance(response, GenerationResponse):
                 status_code = response.status_code
                 if status_code != 200:
                     raise Exception(
-                        f"[{llm_provider}] returned an error response: \"{response}\"")
+                        f'[{llm_provider}] returned an error response: "{response}"'
+                    )
                 content = response["output"]["text"]
                 return content.replace("\n", "")
             else:
                 raise Exception(
-                    f"[{llm_provider}] returned an invalid response: \"{response}\"")
+                    f'[{llm_provider}] returned an invalid response: "{response}"'
+                )
         else:
-            raise Exception(
-                f"[{llm_provider}] returned an empty response")
+            raise Exception(f"[{llm_provider}] returned an empty response")
 
     if llm_provider == "gemini":
         import google.generativeai as genai
-        genai.configure(api_key=api_key, transport='rest')
+
+        genai.configure(api_key=api_key, transport="rest")
 
         generation_config = {
             "temperature": 0.5,
@@ -127,25 +140,27 @@ def _generate_response(prompt: str) -> str:
         safety_settings = [
             {
                 "category": "HARM_CATEGORY_HARASSMENT",
-                "threshold": "BLOCK_ONLY_HIGH"
+                "threshold": "BLOCK_ONLY_HIGH",
             },
             {
                 "category": "HARM_CATEGORY_HATE_SPEECH",
-                "threshold": "BLOCK_ONLY_HIGH"
+                "threshold": "BLOCK_ONLY_HIGH",
             },
             {
                 "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
-                "threshold": "BLOCK_ONLY_HIGH"
+                "threshold": "BLOCK_ONLY_HIGH",
             },
             {
                 "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
-                "threshold": "BLOCK_ONLY_HIGH"
+                "threshold": "BLOCK_ONLY_HIGH",
             },
         ]
 
-        model = genai.GenerativeModel(model_name=model_name,
-                                      generation_config=generation_config,
-                                      safety_settings=safety_settings)
+        model = genai.GenerativeModel(
+            model_name=model_name,
+            generation_config=generation_config,
+            safety_settings=safety_settings,
+        )
 
         try:
             response = model.generate_content(prompt)
@@ -158,15 +173,16 @@
     if llm_provider == "cloudflare":
         import requests
+
         response = requests.post(
             f"https://api.cloudflare.com/client/v4/accounts/{account_id}/ai/run/{model_name}",
             headers={"Authorization": f"Bearer {api_key}"},
             json={
                 "messages": [
                     {"role": "system", "content": "You are a friendly assistant"},
-                    {"role": "user", "content": prompt}
+                    {"role": "user", "content": prompt},
                 ]
-            }
+            },
         )
         result = response.json()
         logger.info(result)
@@ -174,30 +190,35 @@
     if llm_provider == "ernie":
         import requests
-        params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key}
-        access_token = requests.post("https://aip.baidubce.com/oauth/2.0/token", params=params).json().get(
-            "access_token")
+
+        params = {
+            "grant_type": "client_credentials",
+            "client_id": api_key,
+            "client_secret": secret_key,
+        }
+        access_token = (
+            requests.post("https://aip.baidubce.com/oauth/2.0/token", params=params)
+            .json()
+            .get("access_token")
+        )
         url = f"{base_url}?access_token={access_token}"
 
-        payload = json.dumps({
-            "messages": [
-                {
-                    "role": "user",
-                    "content": prompt
-                }
-            ],
-            "temperature": 0.5,
-            "top_p": 0.8,
-            "penalty_score": 1,
-            "disable_search": False,
-            "enable_citation": False,
-            "response_format": "text"
-        })
-        headers = {
-            'Content-Type': 'application/json'
-        }
+        payload = json.dumps(
+            {
+                "messages": [{"role": "user", "content": prompt}],
+                "temperature": 0.5,
+                "top_p": 0.8,
+                "penalty_score": 1,
+                "disable_search": False,
+                "enable_citation": False,
+                "response_format": "text",
+            }
+        )
+        headers = {"Content-Type": "application/json"}
 
-        response = requests.request("POST", url, headers=headers, data=payload).json()
+        response = requests.request(
+            "POST", url, headers=headers, data=payload
+        ).json()
         return response.get("result")
 
     if llm_provider == "azure":
@@ -213,24 +234,27 @@
         )
 
         response = client.chat.completions.create(
-            model=model_name,
-            messages=[{"role": "user", "content": prompt}]
+            model=model_name, messages=[{"role": "user", "content": prompt}]
         )
         if response:
             if isinstance(response, ChatCompletion):
                 content = response.choices[0].message.content
             else:
                 raise Exception(
-                    f"[{llm_provider}] returned an invalid response: \"{response}\", please check your network "
-                    f"connection and try again.")
+                    f'[{llm_provider}] returned an invalid response: "{response}", please check your network '
+                    f"connection and try again."
+                )
         else:
             raise Exception(
-                f"[{llm_provider}] returned an empty response, please check your network connection and try again.")
+                f"[{llm_provider}] returned an empty response, please check your network connection and try again."
+            )
 
     return content.replace("\n", "")
 
 
-def generate_script(video_subject: str, language: str = "", paragraph_number: int = 1) -> str:
+def generate_script(
+    video_subject: str, language: str = "", paragraph_number: int = 1
+) -> str:
     prompt = f"""
 # Role: Video Script Generator
@@ -335,14 +359,16 @@ Please note that you must use English for generating video search terms; Chinese
         try:
             response = _generate_response(prompt)
             search_terms = json.loads(response)
-            if not isinstance(search_terms, list) or not all(isinstance(term, str) for term in search_terms):
+            if not isinstance(search_terms, list) or not all(
+                isinstance(term, str) for term in search_terms
+            ):
                 logger.error("response is not a list of strings.")
                 continue
         except Exception as e:
             logger.warning(f"failed to generate video terms: {str(e)}")
             if response:
-                match = re.search(r'\[.*]', response)
+                match = re.search(r"\[.*]", response)
                 if match:
                     try:
                         search_terms = json.loads(match.group())
@@ -361,9 +387,13 @@ Please note that you must use English for generating video search terms; Chinese
 if __name__ == "__main__":
     video_subject = "生命的意义是什么"
-    script = generate_script(video_subject=video_subject, language="zh-CN", paragraph_number=1)
+    script = generate_script(
+        video_subject=video_subject, language="zh-CN", paragraph_number=1
+    )
     print("######################")
     print(script)
 
-    search_terms = generate_terms(video_subject=video_subject, video_script=script, amount=5)
+    search_terms = generate_terms(
+        video_subject=video_subject, video_script=script, amount=5
+    )
     print("######################")
     print(search_terms)
diff --git a/app/services/material.py b/app/services/material.py
index 6fc72a4..77bf6b9 100644
--- a/app/services/material.py
+++ b/app/services/material.py
@@ -19,7 +19,8 @@ def get_api_key(cfg_key: str):
     if not api_keys:
         raise ValueError(
             f"\n\n##### {cfg_key} is not set #####\n\nPlease set it in the config.toml file: {config.config_file}\n\n"
-            f"{utils.to_json(config.app)}")
+            f"{utils.to_json(config.app)}"
+        )
 
     # if only one key is provided, return it
     if isinstance(api_keys, str):
@@ -30,28 +31,29 @@ def get_api_key(cfg_key: str):
     return api_keys[requested_count % len(api_keys)]
 
 
-def search_videos_pexels(search_term: str,
-                         minimum_duration: int,
-                         video_aspect: VideoAspect = VideoAspect.portrait,
-                         ) -> List[MaterialInfo]:
+def search_videos_pexels(
+    search_term: str,
+    minimum_duration: int,
+    video_aspect: VideoAspect = VideoAspect.portrait,
+) -> List[MaterialInfo]:
     aspect = VideoAspect(video_aspect)
     video_orientation = aspect.name
     video_width, video_height = aspect.to_resolution()
     api_key = get_api_key("pexels_api_keys")
-    headers = {
-        "Authorization": api_key
-    }
+    headers = {"Authorization": api_key}
     # Build URL
-    params = {
-        "query": search_term,
-        "per_page": 20,
-        "orientation": video_orientation
-    }
+    params = {"query": search_term, "per_page": 20, "orientation": video_orientation}
     query_url = f"https://api.pexels.com/videos/search?{urlencode(params)}"
     logger.info(f"searching videos: {query_url}, with proxies: {config.proxy}")
 
     try:
-        r = requests.get(query_url, headers=headers, proxies=config.proxy, verify=False, timeout=(30, 60))
+        r = requests.get(
+            query_url,
+            headers=headers,
+            proxies=config.proxy,
+            verify=False,
+            timeout=(30, 60),
+        )
         response = r.json()
         video_items = []
         if "videos" not in response:
@@ -83,10 +85,11 @@ def search_videos_pexels(search_term: str,
     return []
 
 
-def search_videos_pixabay(search_term: str,
-                          minimum_duration: int,
-                          video_aspect: VideoAspect = VideoAspect.portrait,
-                          ) -> List[MaterialInfo]:
+def search_videos_pixabay(
+    search_term: str,
+    minimum_duration: int,
+    video_aspect: VideoAspect = VideoAspect.portrait,
+) -> List[MaterialInfo]:
     aspect = VideoAspect(video_aspect)
     video_width, video_height = aspect.to_resolution()
 
@@ -97,13 +100,15 @@ def search_videos_pixabay(search_term: str,
         "q": search_term,
         "video_type": "all",  # Accepted values: "all", "film", "animation"
         "per_page": 50,
-        "key": api_key
+        "key": api_key,
     }
     query_url = f"https://pixabay.com/api/videos/?{urlencode(params)}"
     logger.info(f"searching videos: {query_url}, with proxies: {config.proxy}")
 
     try:
-        r = requests.get(query_url, proxies=config.proxy, verify=False, timeout=(30, 60))
+        r = requests.get(
+            query_url, proxies=config.proxy, verify=False, timeout=(30, 60)
+        )
         response = r.json()
         video_items = []
         if "hits" not in response:
@@ -155,7 +160,11 @@ def save_video(video_url: str, save_dir: str = "") -> str:
 
     # if video does not exist, download it
     with open(video_path, "wb") as f:
-        f.write(requests.get(video_url, proxies=config.proxy, verify=False, timeout=(60, 240)).content)
+        f.write(
+            requests.get(
+                video_url, proxies=config.proxy, verify=False, timeout=(60, 240)
+            ).content
+        )
 
     if os.path.exists(video_path) and os.path.getsize(video_path) > 0:
         try:
@@ -174,14 +183,15 @@ def save_video(video_url: str, save_dir: str = "") -> str:
     return ""
 
 
-def download_videos(task_id: str,
-                    search_terms: List[str],
-                    source: str = "pexels",
-                    video_aspect: VideoAspect = VideoAspect.portrait,
-                    video_contact_mode: VideoConcatMode = VideoConcatMode.random,
-                    audio_duration: float = 0.0,
-                    max_clip_duration: int = 5,
-                    ) -> List[str]:
+def download_videos(
+    task_id: str,
+    search_terms: List[str],
+    source: str = "pexels",
+    video_aspect: VideoAspect = VideoAspect.portrait,
+    video_contact_mode: VideoConcatMode = VideoConcatMode.random,
+    audio_duration: float = 0.0,
+    max_clip_duration: int = 5,
+) -> List[str]:
     valid_video_items = []
     valid_video_urls = []
     found_duration = 0.0
@@ -190,9 +200,11 @@ def download_videos(task_id: str,
         search_videos = search_videos_pixabay
 
     for search_term in search_terms:
-        video_items = search_videos(search_term=search_term,
-                                    minimum_duration=max_clip_duration,
-                                    video_aspect=video_aspect)
+        video_items = search_videos(
+            search_term=search_term,
+            minimum_duration=max_clip_duration,
+            video_aspect=video_aspect,
+        )
         logger.info(f"found {len(video_items)} videos for '{search_term}'")
 
         for item in video_items:
@@ -202,7 +214,8 @@ def download_videos(task_id: str,
                 found_duration += item.duration
 
     logger.info(
-        f"found total videos: {len(valid_video_items)}, required duration: {audio_duration} seconds, found duration: {found_duration} seconds")
+        f"found total videos: {len(valid_video_items)}, required duration: {audio_duration} seconds, found duration: {found_duration} seconds"
+    )
     video_paths = []
 
     material_directory = config.app.get("material_directory", "").strip()
@@ -218,14 +231,18 @@ def download_videos(task_id: str,
     for item in valid_video_items:
         try:
             logger.info(f"downloading video: {item.url}")
-            saved_video_path = save_video(video_url=item.url, save_dir=material_directory)
+            saved_video_path = save_video(
+                video_url=item.url, save_dir=material_directory
+            )
             if saved_video_path:
                 logger.info(f"video saved: {saved_video_path}")
                 video_paths.append(saved_video_path)
                 seconds = min(max_clip_duration, item.duration)
                 total_duration += seconds
                 if total_duration > audio_duration:
-                    logger.info(f"total duration of downloaded videos: {total_duration} seconds, skip downloading more")
+                    logger.info(
+                        f"total duration of downloaded videos: {total_duration} seconds, skip downloading more"
+                    )
                     break
         except Exception as e:
             logger.error(f"failed to download video: {utils.to_json(item)} => {str(e)}")
@@ -234,4 +251,6 @@ def download_videos(task_id: str,
 
 
 if __name__ == "__main__":
-    download_videos("test123", ["Money Exchange Medium"], audio_duration=100, source="pixabay")
+    download_videos(
+        "test123", ["Money Exchange Medium"], audio_duration=100, source="pixabay"
+    )
diff --git a/app/services/state.py b/app/services/state.py
index 5edd56d..51904fb 100644
--- a/app/services/state.py
+++ b/app/services/state.py
@@ -6,7 +6,6 @@ from app.models import const
 
 # Base class for state management
 class BaseState(ABC):
-
     @abstractmethod
     def update_task(self, task_id: str, state: int, progress: int = 0, **kwargs):
         pass
@@ -18,11 +17,16 @@ class BaseState(ABC):
 
 # Memory state management
 class MemoryState(BaseState):
-
     def __init__(self):
         self._tasks = {}
 
-    def update_task(self, task_id: str, state: int = const.TASK_STATE_PROCESSING, progress: int = 0, **kwargs):
+    def update_task(
+        self,
+        task_id: str,
+        state: int = const.TASK_STATE_PROCESSING,
+        progress: int = 0,
+        **kwargs,
+    ):
         progress = int(progress)
         if progress > 100:
             progress = 100
@@ -43,12 +47,18 @@ class MemoryState(BaseState):
 
 # Redis state management
 class RedisState(BaseState):
-
-    def __init__(self, host='localhost', port=6379, db=0, password=None):
+    def __init__(self, host="localhost", port=6379, db=0, password=None):
         import redis
+
         self._redis = redis.StrictRedis(host=host, port=port, db=db, password=password)
 
-    def update_task(self, task_id: str, state: int = const.TASK_STATE_PROCESSING, progress: int = 0, **kwargs):
+    def update_task(
+        self,
+        task_id: str,
+        state: int = const.TASK_STATE_PROCESSING,
+        progress: int = 0,
+        **kwargs,
+    ):
         progress = int(progress)
         if progress > 100:
             progress = 100
@@ -67,7 +77,10 @@ class RedisState(BaseState):
         if not task_data:
             return None
 
-        task = {key.decode('utf-8'): self._convert_to_original_type(value) for key, value in task_data.items()}
+        task = {
+            key.decode("utf-8"): self._convert_to_original_type(value)
+            for key, value in task_data.items()
+        }
         return task
 
     def delete_task(self, task_id: str):
@@ -79,7 +92,7 @@ class RedisState(BaseState):
         Convert the value from byte string to its original data type.
         You can extend this method to handle other data types as needed.
         """
-        value_str = value.decode('utf-8')
+        value_str = value.decode("utf-8")
 
         try:
             # try to convert byte string array to list
@@ -100,4 +113,10 @@ _redis_port = config.app.get("redis_port", 6379)
 _redis_db = config.app.get("redis_db", 0)
 _redis_password = config.app.get("redis_password", None)
 
-state = RedisState(host=_redis_host, port=_redis_port, db=_redis_db, password=_redis_password) if _enable_redis else MemoryState()
+state = (
+    RedisState(
+        host=_redis_host, port=_redis_port, db=_redis_db, password=_redis_password
+    )
+    if _enable_redis
+    else MemoryState()
+)
diff --git a/app/services/subtitle.py b/app/services/subtitle.py
index 1bda56a..a939a54 100644
--- a/app/services/subtitle.py
+++ b/app/services/subtitle.py
@@ -23,18 +23,22 @@ def create(audio_file, subtitle_file: str = ""):
     if not os.path.isdir(model_path) or not os.path.isfile(model_bin_file):
         model_path = model_size
 
-    logger.info(f"loading model: {model_path}, device: {device}, compute_type: {compute_type}")
+    logger.info(
+        f"loading model: {model_path}, device: {device}, compute_type: {compute_type}"
+    )
     try:
-        model = WhisperModel(model_size_or_path=model_path,
-                             device=device,
-                             compute_type=compute_type)
+        model = WhisperModel(
+            model_size_or_path=model_path, device=device, compute_type=compute_type
+        )
     except Exception as e:
-        logger.error(f"failed to load model: {e} \n\n"
-                     f"********************************************\n"
-                     f"this may be caused by network issue. \n"
-                     f"please download the model manually and put it in the 'models' folder. \n"
-                     f"see [README.md FAQ](https://github.com/harry0703/MoneyPrinterTurbo) for more details.\n"
-                     f"********************************************\n\n")
+        logger.error(
+            f"failed to load model: {e} \n\n"
+            f"********************************************\n"
+            f"this may be caused by network issue. \n"
+            f"please download the model manually and put it in the 'models' folder. \n"
+            f"see [README.md FAQ](https://github.com/harry0703/MoneyPrinterTurbo) for more details.\n"
+            f"********************************************\n\n"
+        )
         return None
 
     logger.info(f"start, output file: {subtitle_file}")
@@ -49,7 +53,9 @@ def create(audio_file, subtitle_file: str = ""):
         vad_parameters=dict(min_silence_duration_ms=500),
     )
 
-    logger.info(f"detected language: '{info.language}', probability: {info.language_probability:.2f}")
+    logger.info(
+        f"detected language: '{info.language}', probability: {info.language_probability:.2f}"
+    )
 
     start = timer()
     subtitles = []
@@ -62,11 +68,9 @@ def create(audio_file, subtitle_file: str = ""):
         msg = "[%.2fs -> %.2fs] %s" % (seg_start, seg_end, seg_text)
         logger.debug(msg)
 
-        subtitles.append({
-            "msg": seg_text,
-            "start_time": seg_start,
-            "end_time": seg_end
-        })
+        subtitles.append(
+            {"msg": seg_text, "start_time": seg_start, "end_time": seg_end}
+        )
 
     for segment in segments:
         words_idx = 0
@@ -119,7 +123,11 @@ def create(audio_file, subtitle_file: str = ""):
         for subtitle in subtitles:
             text = subtitle.get("msg")
             if text:
-                lines.append(utils.text_to_srt(idx, text, subtitle.get("start_time"), subtitle.get("end_time")))
+                lines.append(
+                    utils.text_to_srt(
+                        idx, text, subtitle.get("start_time"), subtitle.get("end_time")
+                    )
+                )
                 idx += 1
 
         sub = "\n".join(lines) + "\n"
@@ -136,12 +144,12 @@ def file_to_subtitles(filename):
     current_times = None
     current_text = ""
     index = 0
-    with open(filename, 'r', encoding="utf-8") as f:
+    with open(filename, "r", encoding="utf-8") as f:
        for line in f:
             times = re.findall("([0-9]*:[0-9]*:[0-9]*,[0-9]*)", line)
             if times:
                 current_times = line
-            elif line.strip() == '' and current_times:
+            elif line.strip() == "" and current_times:
                 index += 1
                 times_texts.append((index, current_times.strip(), current_text.strip()))
                 current_times, current_text = None, ""
@@ -166,9 +174,10 @@ def levenshtein_distance(s1, s2):
             substitutions = previous_row[j] + (c1 != c2)
             current_row.append(min(insertions, deletions, substitutions))
         previous_row = current_row
-    
+
     return previous_row[-1]
 
+
 def similarity(a, b):
     distance = levenshtein_distance(a.lower(), b.lower())
     max_length = max(len(a), len(b))
@@ -194,26 +203,44 @@ def correct(subtitle_file, video_script):
             subtitle_index += 1
         else:
             combined_subtitle = subtitle_line
-            start_time = subtitle_items[subtitle_index][1].split(' --> ')[0]
-            end_time = subtitle_items[subtitle_index][1].split(' --> ')[1]
+            start_time = subtitle_items[subtitle_index][1].split(" --> ")[0]
+            end_time = subtitle_items[subtitle_index][1].split(" --> ")[1]
             next_subtitle_index = subtitle_index + 1
 
             while next_subtitle_index < len(subtitle_items):
                 next_subtitle = subtitle_items[next_subtitle_index][2].strip()
-                if similarity(script_line, combined_subtitle + " " + next_subtitle) > similarity(script_line, combined_subtitle):
+                if similarity(
+                    script_line, combined_subtitle + " " + next_subtitle
+                ) > similarity(script_line, combined_subtitle):
                     combined_subtitle += " " + next_subtitle
-                    end_time = subtitle_items[next_subtitle_index][1].split(' --> ')[1]
+                    end_time = subtitle_items[next_subtitle_index][1].split(" --> ")[1]
                     next_subtitle_index += 1
                 else:
                     break
 
             if similarity(script_line, combined_subtitle) > 0.8:
-                logger.warning(f"Merged/Corrected - Script: {script_line}, Subtitle: {combined_subtitle}")
-                new_subtitle_items.append((len(new_subtitle_items) + 1, f"{start_time} --> {end_time}", script_line))
+                logger.warning(
+                    f"Merged/Corrected - Script: {script_line}, Subtitle: {combined_subtitle}"
+                )
+                new_subtitle_items.append(
+                    (
+                        len(new_subtitle_items) + 1,
+                        f"{start_time} --> {end_time}",
+                        script_line,
+                    )
+                )
                 corrected = True
             else:
-                logger.warning(f"Mismatch - Script: {script_line}, Subtitle: {combined_subtitle}")
-                new_subtitle_items.append((len(new_subtitle_items) + 1, f"{start_time} --> {end_time}", script_line))
+                logger.warning(
+                    f"Mismatch - Script: {script_line}, Subtitle: {combined_subtitle}"
+                )
+                new_subtitle_items.append(
+                    (
+                        len(new_subtitle_items) + 1,
+                        f"{start_time} --> {end_time}",
+                        script_line,
+                    )
+                )
                 corrected = True
 
             script_index += 1
@@ -223,10 +250,22 @@ def correct(subtitle_file, video_script):
     while script_index < len(script_lines):
         logger.warning(f"Extra script line: {script_lines[script_index]}")
         if subtitle_index < len(subtitle_items):
-            new_subtitle_items.append((len(new_subtitle_items) + 1, subtitle_items[subtitle_index][1], script_lines[script_index]))
+            new_subtitle_items.append(
+                (
+                    len(new_subtitle_items) + 1,
+                    subtitle_items[subtitle_index][1],
+                    script_lines[script_index],
+                )
+            )
             subtitle_index += 1
         else:
-            new_subtitle_items.append((len(new_subtitle_items) + 1, "00:00:00,000 --> 00:00:00,000", script_lines[script_index]))
+            new_subtitle_items.append(
+                (
+                    len(new_subtitle_items) + 1,
+                    "00:00:00,000 --> 00:00:00,000",
+                    script_lines[script_index],
+                )
+            )
             script_index += 1
         corrected = True
diff --git a/app/services/voice.py b/app/services/voice.py
index 2a31637..287e22d 100644
--- a/app/services/voice.py
+++ b/app/services/voice.py
@@ -988,7 +988,7 @@ Name: zh-CN-XiaoxiaoMultilingualNeural-V2
 Gender: Female
 """.strip()
     voices = []
-    name = ''
+    name = ""
     for line in voices_str.split("\n"):
         line = line.strip()
         if not line:
@@ -1008,7 +1008,7 @@ Gender: Female
                 voices.append(f"{name}-{gender}")
             else:
                 voices.append(f"{name}-{gender}")
-            name = ''
+            name = ""
 
     voices.sort()
     return voices
@@ -1028,7 +1028,9 @@ def is_azure_v2_voice(voice_name: str):
     return ""
 
 
-def tts(text: str, voice_name: str, voice_rate: float, voice_file: str) -> [SubMaker, None]:
+def tts(
+    text: str, voice_name: str, voice_rate: float, voice_file: str
+) -> [SubMaker, None]:
     if is_azure_v2_voice(voice_name):
         return azure_tts_v2(text, voice_name, voice_file)
     return azure_tts_v1(text, voice_name, voice_rate, voice_file)
@@ -1042,9 +1044,11 @@ def convert_rate_to_percent(rate: float) -> str:
         return f"+{percent}%"
     else:
         return f"{percent}%"
-    
 
-def azure_tts_v1(text: str, voice_name: str, voice_rate: float, voice_file: str) -> [SubMaker, None]:
+
+def azure_tts_v1(
+    text: str, voice_name: str, voice_rate: float, voice_file: str
+) -> [SubMaker, None]:
     voice_name = parse_voice_name(voice_name)
     text = text.strip()
     rate_str = convert_rate_to_percent(voice_rate)
@@ -1060,7 +1064,9 @@ def azure_tts_v1(text: str, voice_name: str, voice_rate: float, voice_file: str)
                 if chunk["type"] == "audio":
                     file.write(chunk["data"])
                 elif chunk["type"] == "WordBoundary":
-                    sub_maker.create_sub((chunk["offset"], chunk["duration"]), chunk["text"])
+                    sub_maker.create_sub(
+                        (chunk["offset"], chunk["duration"]), chunk["text"]
+                    )
             return sub_maker
 
         sub_maker = asyncio.run(_do())
@@ -1085,8 +1091,12 @@ def azure_tts_v2(text: str, voice_name: str, voice_file: str) -> [SubMaker, None
     def _format_duration_to_offset(duration) -> int:
         if isinstance(duration, str):
             time_obj = datetime.strptime(duration, "%H:%M:%S.%f")
-            milliseconds = (time_obj.hour * 3600000) + (time_obj.minute * 60000) + (time_obj.second * 1000) + (
-                    time_obj.microsecond // 1000)
+            milliseconds = (
+                (time_obj.hour * 3600000)
+                + (time_obj.minute * 60000)
+                + (time_obj.second * 1000)
+                + (time_obj.microsecond // 1000)
+            )
             return milliseconds * 10000
 
         if isinstance(duration, int):
@@ -1119,20 +1129,29 @@ def azure_tts_v2(text: str, voice_name: str, voice_file: str) -> [SubMaker, None
             # Creates an instance of a speech config with specified subscription key and service region.
             speech_key = config.azure.get("speech_key", "")
             service_region = config.azure.get("speech_region", "")
-            audio_config = speechsdk.audio.AudioOutputConfig(filename=voice_file, use_default_speaker=True)
-            speech_config = speechsdk.SpeechConfig(subscription=speech_key,
-                                                   region=service_region)
+            audio_config = speechsdk.audio.AudioOutputConfig(
+                filename=voice_file, use_default_speaker=True
+            )
+            speech_config = speechsdk.SpeechConfig(
+                subscription=speech_key, region=service_region
+            )
             speech_config.speech_synthesis_voice_name = voice_name
             # speech_config.set_property(property_id=speechsdk.PropertyId.SpeechServiceResponse_RequestSentenceBoundary,
             #                            value='true')
-            speech_config.set_property(property_id=speechsdk.PropertyId.SpeechServiceResponse_RequestWordBoundary,
-                                       value='true')
+            speech_config.set_property(
+                property_id=speechsdk.PropertyId.SpeechServiceResponse_RequestWordBoundary,
+                value="true",
+            )
             speech_config.set_speech_synthesis_output_format(
-                speechsdk.SpeechSynthesisOutputFormat.Audio48Khz192KBitRateMonoMp3)
-            speech_synthesizer = speechsdk.SpeechSynthesizer(audio_config=audio_config,
-                                                             speech_config=speech_config)
-            speech_synthesizer.synthesis_word_boundary.connect(speech_synthesizer_word_boundary_cb)
+                speechsdk.SpeechSynthesisOutputFormat.Audio48Khz192KBitRateMonoMp3
+            )
+            speech_synthesizer = speechsdk.SpeechSynthesizer(
+                audio_config=audio_config, speech_config=speech_config
+            )
+            speech_synthesizer.synthesis_word_boundary.connect(
+                speech_synthesizer_word_boundary_cb
+            )
 
             result = speech_synthesizer.speak_text_async(text).get()
             if result.reason == speechsdk.ResultReason.SynthesizingAudioCompleted:
@@ -1140,9 +1159,13 @@ def azure_tts_v2(text: str, voice_name: str, voice_file: str) -> [SubMaker, None
                 return sub_maker
             elif result.reason == speechsdk.ResultReason.Canceled:
                 cancellation_details = result.cancellation_details
-                logger.error(f"azure v2 speech synthesis canceled: {cancellation_details.reason}")
+                logger.error(
+                    f"azure v2 speech synthesis canceled: {cancellation_details.reason}"
+                )
                 if cancellation_details.reason == speechsdk.CancellationReason.Error:
-                    logger.error(f"azure v2 speech synthesis error: {cancellation_details.error_details}")
+                    logger.error(
+                        f"azure v2 speech synthesis error: {cancellation_details.error_details}"
+                    )
             logger.info(f"completed, output file: {voice_file}")
         except Exception as e:
             logger.error(f"failed, error: {str(e)}")
@@ -1179,11 +1202,7 @@ def create_subtitle(sub_maker: submaker.SubMaker, text: str, subtitle_file: str)
         """
         start_t = mktimestamp(start_time).replace(".", ",")
         end_t = mktimestamp(end_time).replace(".", ",")
-        return (
-            f"{idx}\n"
-            f"{start_t} --> {end_t}\n"
-            f"{sub_text}\n"
-        )
+        return f"{idx}\n" f"{start_t} --> {end_t}\n" f"{sub_text}\n"
 
     start_time = -1.0
     sub_items = []
@@ -1240,12 +1259,16 @@ def create_subtitle(sub_maker: submaker.SubMaker, text: str, subtitle_file: str)
         try:
             sbs = subtitles.file_to_subtitles(subtitle_file, encoding="utf-8")
             duration = max([tb for ((ta, tb), txt) in sbs])
-            logger.info(f"completed, subtitle file created: {subtitle_file}, duration: {duration}")
+            logger.info(
+                f"completed, subtitle file created: {subtitle_file}, duration: {duration}"
+            )
         except Exception as e:
             logger.error(f"failed, error: {str(e)}")
             os.remove(subtitle_file)
     else:
-        logger.warning(f"failed, sub_items len: {len(sub_items)}, script_lines len: {len(script_lines)}")
+        logger.warning(
+            f"failed, sub_items len: {len(sub_items)}, script_lines len: {len(script_lines)}"
+        )
 
     except Exception as e:
         logger.error(f"failed, error: {str(e)}")
@@ -1269,7 +1292,6 @@ if __name__ == "__main__":
     voices = get_all_azure_voices()
     print(len(voices))
 
-
     async def _do():
         temp_dir = utils.storage_dir("temp")
@@ -1318,12 +1340,13 @@ if __name__ == "__main__":
         for voice_name in voice_names:
             voice_file = f"{temp_dir}/tts-{voice_name}.mp3"
             subtitle_file = f"{temp_dir}/tts.mp3.srt"
-            sub_maker = azure_tts_v2(text=text, voice_name=voice_name, voice_file=voice_file)
+            sub_maker = azure_tts_v2(
+                text=text, voice_name=voice_name, voice_file=voice_file
+            )
             create_subtitle(sub_maker=sub_maker, text=text, subtitle_file=subtitle_file)
             audio_duration = get_audio_duration(sub_maker)
             print(f"voice: {voice_name}, audio duration: {audio_duration}s")
 
-
     loop = asyncio.get_event_loop_policy().get_event_loop()
     try:
         loop.run_until_complete(_do())
diff --git a/app/utils/utils.py b/app/utils/utils.py
index cca2154..c12342d 100644
--- a/app/utils/utils.py
+++ b/app/utils/utils.py
@@ -15,12 +15,12 @@ urllib3.disable_warnings()
 
 def get_response(status: int, data: Any = None, message: str = ""):
     obj = {
-        'status': status,
+        "status": status,
    }
     if data:
-        obj['data'] = data
+        obj["data"] = data
     if message:
-        obj['message'] = message
+        obj["message"] = message
     return obj
 
 
@@ -41,7 +41,7 @@ def to_json(obj):
         elif isinstance(o, (list, tuple)):
             return [serialize(item) for item in o]
         # 如果对象是自定义类型,尝试返回其__dict__属性
-        elif hasattr(o, '__dict__'):
+        elif hasattr(o, "__dict__"):
             return serialize(o.__dict__)
         # 其他情况返回None(或者可以选择抛出异常)
         else:
@@ -199,7 +199,8 @@ def split_string_by_punctuations(s):
 
 def md5(text):
     import hashlib
-    return hashlib.md5(text.encode('utf-8')).hexdigest()
+
+    return hashlib.md5(text.encode("utf-8")).hexdigest()
 
 
 def get_system_locale():
diff --git a/changelog.py b/changelog.py
index 2f1cd88..31a1337 100644
--- a/changelog.py
+++ b/changelog.py
@@ -12,6 +12,6 @@ build_and_render(
     parse_refs=False,
     sections=["build", "deps", "feat", "fix", "refactor"],
     versioning="pep440",
-    bump="1.1.2", # 指定bump版本
+    bump="1.1.2",  # 指定bump版本
     in_place=True,
 )