Format project code

This commit is contained in:
yyhhyyyyyy
2024-07-24 14:59:06 +08:00
parent bbd4e94941
commit 905841965a
18 changed files with 410 additions and 214 deletions

View File

@@ -1,4 +1,5 @@
"""Application implementation - ASGI."""
import os
from fastapi import FastAPI, Request
@@ -24,7 +25,9 @@ def exception_handler(request: Request, e: HttpException):
def validation_exception_handler(request: Request, e: RequestValidationError):
return JSONResponse(
status_code=400,
content=utils.get_response(status=400, data=e.errors(), message='field required'),
content=utils.get_response(
status=400, data=e.errors(), message="field required"
),
)
@@ -61,7 +64,9 @@ app.add_middleware(
)
task_dir = utils.task_dir()
app.mount("/tasks", StaticFiles(directory=task_dir, html=True, follow_symlink=True), name="")
app.mount(
"/tasks", StaticFiles(directory=task_dir, html=True, follow_symlink=True), name=""
)
public_dir = utils.public_dir()
app.mount("/", StaticFiles(directory=public_dir, html=True), name="")

View File

@@ -10,7 +10,9 @@ from app.utils import utils
def __init_logger():
# _log_file = utils.storage_dir("logs/server.log")
_lvl = config.log_level
root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
root_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
)
def format_record(record):
# 获取日志记录中的文件全路径
@@ -21,10 +23,13 @@ def __init_logger():
record["file"].path = f"./{relative_path}"
# 返回修改后的格式字符串
# 您可以根据需要调整这里的格式
_format = '<green>{time:%Y-%m-%d %H:%M:%S}</> | ' + \
'<level>{level}</> | ' + \
'"{file.path}:{line}":<blue> {function}</> ' + \
'- <level>{message}</>' + "\n"
_format = (
"<green>{time:%Y-%m-%d %H:%M:%S}</> | "
+ "<level>{level}</> | "
+ '"{file.path}:{line}":<blue> {function}</> '
+ "- <level>{message}</>"
+ "\n"
)
return _format
logger.remove()

View File

@@ -25,7 +25,7 @@ def load_config():
_config_ = toml.load(config_file)
except Exception as e:
logger.warning(f"load config failed: {str(e)}, try to load as utf-8-sig")
with open(config_file, mode="r", encoding='utf-8-sig') as fp:
with open(config_file, mode="r", encoding="utf-8-sig") as fp:
_cfg_content = fp.read()
_config_ = toml.loads(_cfg_content)
return _config_
@@ -52,8 +52,10 @@ log_level = _cfg.get("log_level", "DEBUG")
listen_host = _cfg.get("listen_host", "0.0.0.0")
listen_port = _cfg.get("listen_port", 8080)
project_name = _cfg.get("project_name", "MoneyPrinterTurbo")
project_description = _cfg.get("project_description",
"<a href='https://github.com/harry0703/MoneyPrinterTurbo'>https://github.com/harry0703/MoneyPrinterTurbo</a>")
project_description = _cfg.get(
"project_description",
"<a href='https://github.com/harry0703/MoneyPrinterTurbo'>https://github.com/harry0703/MoneyPrinterTurbo</a>",
)
project_version = _cfg.get("project_version", "1.1.9")
reload_debug = False

View File

@@ -7,14 +7,14 @@ from app.models.exception import HttpException
def get_task_id(request: Request):
task_id = request.headers.get('x-task-id')
task_id = request.headers.get("x-task-id")
if not task_id:
task_id = uuid4()
return str(task_id)
def get_api_key(request: Request):
api_key = request.headers.get('x-api-key')
api_key = request.headers.get("x-api-key")
return api_key
@@ -23,5 +23,9 @@ def verify_token(request: Request):
if token != config.app.get("api_key", ""):
request_id = get_task_id(request)
request_url = request.url
user_agent = request.headers.get('user-agent')
raise HttpException(task_id=request_id, status_code=401, message=f"invalid token: {request_url}, {user_agent}")
user_agent = request.headers.get("user-agent")
raise HttpException(
task_id=request_id,
status_code=401,
message=f"invalid token: {request_url}, {user_agent}",
)

View File

@@ -18,11 +18,15 @@ class TaskManager:
print(f"add task: {func.__name__}, current_tasks: {self.current_tasks}")
self.execute_task(func, *args, **kwargs)
else:
print(f"enqueue task: {func.__name__}, current_tasks: {self.current_tasks}")
print(
f"enqueue task: {func.__name__}, current_tasks: {self.current_tasks}"
)
self.enqueue({"func": func, "args": args, "kwargs": kwargs})
def execute_task(self, func: Callable, *args: Any, **kwargs: Any):
thread = threading.Thread(target=self.run_task, args=(func, *args), kwargs=kwargs)
thread = threading.Thread(
target=self.run_task, args=(func, *args), kwargs=kwargs
)
thread.start()
def run_task(self, func: Callable, *args: Any, **kwargs: Any):
@@ -35,11 +39,14 @@ class TaskManager:
def check_queue(self):
with self.lock:
if self.current_tasks < self.max_concurrent_tasks and not self.is_queue_empty():
if (
self.current_tasks < self.max_concurrent_tasks
and not self.is_queue_empty()
):
task_info = self.dequeue()
func = task_info['func']
args = task_info.get('args', ())
kwargs = task_info.get('kwargs', {})
func = task_info["func"]
args = task_info.get("args", ())
kwargs = task_info.get("kwargs", {})
self.execute_task(func, *args, **kwargs)
def task_done(self):

View File

@@ -8,7 +8,7 @@ from app.models.schema import VideoParams
from app.services import task as tm
FUNC_MAP = {
'start': tm.start,
"start": tm.start,
# 'start_test': tm.start_test
}
@@ -24,11 +24,15 @@ class RedisTaskManager(TaskManager):
def enqueue(self, task: Dict):
task_with_serializable_params = task.copy()
if 'params' in task['kwargs'] and isinstance(task['kwargs']['params'], VideoParams):
task_with_serializable_params['kwargs']['params'] = task['kwargs']['params'].dict()
if "params" in task["kwargs"] and isinstance(
task["kwargs"]["params"], VideoParams
):
task_with_serializable_params["kwargs"]["params"] = task["kwargs"][
"params"
].dict()
# 将函数对象转换为其名称
task_with_serializable_params['func'] = task['func'].__name__
task_with_serializable_params["func"] = task["func"].__name__
self.redis_client.rpush(self.queue, json.dumps(task_with_serializable_params))
def dequeue(self):
@@ -36,10 +40,14 @@ class RedisTaskManager(TaskManager):
if task_json:
task_info = json.loads(task_json)
# 将函数名称转换回函数对象
task_info['func'] = FUNC_MAP[task_info['func']]
task_info["func"] = FUNC_MAP[task_info["func"]]
if 'params' in task_info['kwargs'] and isinstance(task_info['kwargs']['params'], dict):
task_info['kwargs']['params'] = VideoParams(**task_info['kwargs']['params'])
if "params" in task_info["kwargs"] and isinstance(
task_info["kwargs"]["params"], dict
):
task_info["kwargs"]["params"] = VideoParams(
**task_info["kwargs"]["params"]
)
return task_info
return None

View File

@@ -4,6 +4,11 @@ from fastapi import Request
router = APIRouter()
@router.get("/ping", tags=["Health Check"], description="检查服务可用性", response_description="pong")
@router.get(
"/ping",
tags=["Health Check"],
description="检查服务可用性",
response_description="pong",
)
def ping(request: Request) -> str:
return "pong"

View File

@@ -3,8 +3,8 @@ from fastapi import APIRouter, Depends
def new_router(dependencies=None):
router = APIRouter()
router.tags = ['V1']
router.prefix = '/api/v1'
router.tags = ["V1"]
router.prefix = "/api/v1"
# 将认证依赖项应用于所有路由
if dependencies:
router.dependencies = dependencies

View File

@@ -1,6 +1,11 @@
from fastapi import Request
from app.controllers.v1.base import new_router
from app.models.schema import VideoScriptResponse, VideoScriptRequest, VideoTermsResponse, VideoTermsRequest
from app.models.schema import (
VideoScriptResponse,
VideoScriptRequest,
VideoTermsResponse,
VideoTermsRequest,
)
from app.services import llm
from app.utils import utils
@@ -9,23 +14,31 @@ from app.utils import utils
router = new_router()
@router.post("/scripts", response_model=VideoScriptResponse, summary="Create a script for the video")
@router.post(
"/scripts",
response_model=VideoScriptResponse,
summary="Create a script for the video",
)
def generate_video_script(request: Request, body: VideoScriptRequest):
video_script = llm.generate_script(video_subject=body.video_subject,
language=body.video_language,
paragraph_number=body.paragraph_number)
response = {
"video_script": video_script
}
video_script = llm.generate_script(
video_subject=body.video_subject,
language=body.video_language,
paragraph_number=body.paragraph_number,
)
response = {"video_script": video_script}
return utils.get_response(200, response)
@router.post("/terms", response_model=VideoTermsResponse, summary="Generate video terms based on the video script")
@router.post(
"/terms",
response_model=VideoTermsResponse,
summary="Generate video terms based on the video script",
)
def generate_video_terms(request: Request, body: VideoTermsRequest):
video_terms = llm.generate_terms(video_subject=body.video_subject,
video_script=body.video_script,
amount=body.amount)
response = {
"video_terms": video_terms
}
video_terms = llm.generate_terms(
video_subject=body.video_subject,
video_script=body.video_script,
amount=body.amount,
)
response = {"video_terms": video_terms}
return utils.get_response(200, response)

View File

@@ -1,11 +1,25 @@
PUNCTUATIONS = [
"?", ",", ".", "", ";", ":", "!", "",
"", "", "", "", "", "", "", "...",
"?",
",",
".",
"",
";",
":",
"!",
"",
"",
"",
"",
"",
"",
"",
"",
"...",
]
TASK_STATE_FAILED = -1
TASK_STATE_COMPLETE = 1
TASK_STATE_PROCESSING = 4
FILE_TYPE_VIDEOS = ['mp4', 'mov', 'mkv', 'webm']
FILE_TYPE_IMAGES = ['jpg', 'jpeg', 'png', 'bmp']
FILE_TYPE_VIDEOS = ["mp4", "mov", "mkv", "webm"]
FILE_TYPE_IMAGES = ["jpg", "jpeg", "png", "bmp"]

View File

@@ -5,16 +5,18 @@ from loguru import logger
class HttpException(Exception):
def __init__(self, task_id: str, status_code: int, message: str = '', data: Any = None):
def __init__(
self, task_id: str, status_code: int, message: str = "", data: Any = None
):
self.message = message
self.status_code = status_code
self.data = data
# 获取异常堆栈信息
tb_str = traceback.format_exc().strip()
if not tb_str or tb_str == "NoneType: None":
msg = f'HttpException: {status_code}, {task_id}, {message}'
msg = f"HttpException: {status_code}, {task_id}, {message}"
else:
msg = f'HttpException: {status_code}, {task_id}, {message}\n{tb_str}'
msg = f"HttpException: {status_code}, {task_id}, {message}\n{tb_str}"
if status_code == 400:
logger.warning(msg)

View File

@@ -21,6 +21,7 @@ def _generate_response(prompt: str) -> str:
if not model_name:
model_name = "gpt-3.5-turbo-16k-0613"
import g4f
content = g4f.ChatCompletion.create(
model=model_name,
messages=[{"role": "user", "content": prompt}],
@@ -78,44 +79,56 @@ def _generate_response(prompt: str) -> str:
base_url = config.app.get("ernie_base_url")
model_name = "***"
if not secret_key:
raise ValueError(f"{llm_provider}: secret_key is not set, please set it in the config.toml file.")
raise ValueError(
f"{llm_provider}: secret_key is not set, please set it in the config.toml file."
)
else:
raise ValueError("llm_provider is not set, please set it in the config.toml file.")
raise ValueError(
"llm_provider is not set, please set it in the config.toml file."
)
if not api_key:
raise ValueError(f"{llm_provider}: api_key is not set, please set it in the config.toml file.")
raise ValueError(
f"{llm_provider}: api_key is not set, please set it in the config.toml file."
)
if not model_name:
raise ValueError(f"{llm_provider}: model_name is not set, please set it in the config.toml file.")
raise ValueError(
f"{llm_provider}: model_name is not set, please set it in the config.toml file."
)
if not base_url:
raise ValueError(f"{llm_provider}: base_url is not set, please set it in the config.toml file.")
raise ValueError(
f"{llm_provider}: base_url is not set, please set it in the config.toml file."
)
if llm_provider == "qwen":
import dashscope
from dashscope.api_entities.dashscope_response import GenerationResponse
dashscope.api_key = api_key
response = dashscope.Generation.call(
model=model_name,
messages=[{"role": "user", "content": prompt}]
model=model_name, messages=[{"role": "user", "content": prompt}]
)
if response:
if isinstance(response, GenerationResponse):
status_code = response.status_code
if status_code != 200:
raise Exception(
f"[{llm_provider}] returned an error response: \"{response}\"")
f'[{llm_provider}] returned an error response: "{response}"'
)
content = response["output"]["text"]
return content.replace("\n", "")
else:
raise Exception(
f"[{llm_provider}] returned an invalid response: \"{response}\"")
f'[{llm_provider}] returned an invalid response: "{response}"'
)
else:
raise Exception(
f"[{llm_provider}] returned an empty response")
raise Exception(f"[{llm_provider}] returned an empty response")
if llm_provider == "gemini":
import google.generativeai as genai
genai.configure(api_key=api_key, transport='rest')
genai.configure(api_key=api_key, transport="rest")
generation_config = {
"temperature": 0.5,
@@ -127,25 +140,27 @@ def _generate_response(prompt: str) -> str:
safety_settings = [
{
"category": "HARM_CATEGORY_HARASSMENT",
"threshold": "BLOCK_ONLY_HIGH"
"threshold": "BLOCK_ONLY_HIGH",
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"threshold": "BLOCK_ONLY_HIGH"
"threshold": "BLOCK_ONLY_HIGH",
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"threshold": "BLOCK_ONLY_HIGH"
"threshold": "BLOCK_ONLY_HIGH",
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"threshold": "BLOCK_ONLY_HIGH"
"threshold": "BLOCK_ONLY_HIGH",
},
]
model = genai.GenerativeModel(model_name=model_name,
generation_config=generation_config,
safety_settings=safety_settings)
model = genai.GenerativeModel(
model_name=model_name,
generation_config=generation_config,
safety_settings=safety_settings,
)
try:
response = model.generate_content(prompt)
@@ -158,15 +173,16 @@ def _generate_response(prompt: str) -> str:
if llm_provider == "cloudflare":
import requests
response = requests.post(
f"https://api.cloudflare.com/client/v4/accounts/{account_id}/ai/run/{model_name}",
headers={"Authorization": f"Bearer {api_key}"},
json={
"messages": [
{"role": "system", "content": "You are a friendly assistant"},
{"role": "user", "content": prompt}
{"role": "user", "content": prompt},
]
}
},
)
result = response.json()
logger.info(result)
@@ -174,30 +190,35 @@ def _generate_response(prompt: str) -> str:
if llm_provider == "ernie":
import requests
params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key}
access_token = requests.post("https://aip.baidubce.com/oauth/2.0/token", params=params).json().get(
"access_token")
params = {
"grant_type": "client_credentials",
"client_id": api_key,
"client_secret": secret_key,
}
access_token = (
requests.post("https://aip.baidubce.com/oauth/2.0/token", params=params)
.json()
.get("access_token")
)
url = f"{base_url}?access_token={access_token}"
payload = json.dumps({
"messages": [
{
"role": "user",
"content": prompt
}
],
"temperature": 0.5,
"top_p": 0.8,
"penalty_score": 1,
"disable_search": False,
"enable_citation": False,
"response_format": "text"
})
headers = {
'Content-Type': 'application/json'
}
payload = json.dumps(
{
"messages": [{"role": "user", "content": prompt}],
"temperature": 0.5,
"top_p": 0.8,
"penalty_score": 1,
"disable_search": False,
"enable_citation": False,
"response_format": "text",
}
)
headers = {"Content-Type": "application/json"}
response = requests.request("POST", url, headers=headers, data=payload).json()
response = requests.request(
"POST", url, headers=headers, data=payload
).json()
return response.get("result")
if llm_provider == "azure":
@@ -213,24 +234,27 @@ def _generate_response(prompt: str) -> str:
)
response = client.chat.completions.create(
model=model_name,
messages=[{"role": "user", "content": prompt}]
model=model_name, messages=[{"role": "user", "content": prompt}]
)
if response:
if isinstance(response, ChatCompletion):
content = response.choices[0].message.content
else:
raise Exception(
f"[{llm_provider}] returned an invalid response: \"{response}\", please check your network "
f"connection and try again.")
f'[{llm_provider}] returned an invalid response: "{response}", please check your network '
f"connection and try again."
)
else:
raise Exception(
f"[{llm_provider}] returned an empty response, please check your network connection and try again.")
f"[{llm_provider}] returned an empty response, please check your network connection and try again."
)
return content.replace("\n", "")
def generate_script(video_subject: str, language: str = "", paragraph_number: int = 1) -> str:
def generate_script(
video_subject: str, language: str = "", paragraph_number: int = 1
) -> str:
prompt = f"""
# Role: Video Script Generator
@@ -335,14 +359,16 @@ Please note that you must use English for generating video search terms; Chinese
try:
response = _generate_response(prompt)
search_terms = json.loads(response)
if not isinstance(search_terms, list) or not all(isinstance(term, str) for term in search_terms):
if not isinstance(search_terms, list) or not all(
isinstance(term, str) for term in search_terms
):
logger.error("response is not a list of strings.")
continue
except Exception as e:
logger.warning(f"failed to generate video terms: {str(e)}")
if response:
match = re.search(r'\[.*]', response)
match = re.search(r"\[.*]", response)
if match:
try:
search_terms = json.loads(match.group())
@@ -361,9 +387,13 @@ Please note that you must use English for generating video search terms; Chinese
if __name__ == "__main__":
video_subject = "生命的意义是什么"
script = generate_script(video_subject=video_subject, language="zh-CN", paragraph_number=1)
script = generate_script(
video_subject=video_subject, language="zh-CN", paragraph_number=1
)
print("######################")
print(script)
search_terms = generate_terms(video_subject=video_subject, video_script=script, amount=5)
search_terms = generate_terms(
video_subject=video_subject, video_script=script, amount=5
)
print("######################")
print(search_terms)

View File

@@ -19,7 +19,8 @@ def get_api_key(cfg_key: str):
if not api_keys:
raise ValueError(
f"\n\n##### {cfg_key} is not set #####\n\nPlease set it in the config.toml file: {config.config_file}\n\n"
f"{utils.to_json(config.app)}")
f"{utils.to_json(config.app)}"
)
# if only one key is provided, return it
if isinstance(api_keys, str):
@@ -30,28 +31,29 @@ def get_api_key(cfg_key: str):
return api_keys[requested_count % len(api_keys)]
def search_videos_pexels(search_term: str,
minimum_duration: int,
video_aspect: VideoAspect = VideoAspect.portrait,
) -> List[MaterialInfo]:
def search_videos_pexels(
search_term: str,
minimum_duration: int,
video_aspect: VideoAspect = VideoAspect.portrait,
) -> List[MaterialInfo]:
aspect = VideoAspect(video_aspect)
video_orientation = aspect.name
video_width, video_height = aspect.to_resolution()
api_key = get_api_key("pexels_api_keys")
headers = {
"Authorization": api_key
}
headers = {"Authorization": api_key}
# Build URL
params = {
"query": search_term,
"per_page": 20,
"orientation": video_orientation
}
params = {"query": search_term, "per_page": 20, "orientation": video_orientation}
query_url = f"https://api.pexels.com/videos/search?{urlencode(params)}"
logger.info(f"searching videos: {query_url}, with proxies: {config.proxy}")
try:
r = requests.get(query_url, headers=headers, proxies=config.proxy, verify=False, timeout=(30, 60))
r = requests.get(
query_url,
headers=headers,
proxies=config.proxy,
verify=False,
timeout=(30, 60),
)
response = r.json()
video_items = []
if "videos" not in response:
@@ -83,10 +85,11 @@ def search_videos_pexels(search_term: str,
return []
def search_videos_pixabay(search_term: str,
minimum_duration: int,
video_aspect: VideoAspect = VideoAspect.portrait,
) -> List[MaterialInfo]:
def search_videos_pixabay(
search_term: str,
minimum_duration: int,
video_aspect: VideoAspect = VideoAspect.portrait,
) -> List[MaterialInfo]:
aspect = VideoAspect(video_aspect)
video_width, video_height = aspect.to_resolution()
@@ -97,13 +100,15 @@ def search_videos_pixabay(search_term: str,
"q": search_term,
"video_type": "all", # Accepted values: "all", "film", "animation"
"per_page": 50,
"key": api_key
"key": api_key,
}
query_url = f"https://pixabay.com/api/videos/?{urlencode(params)}"
logger.info(f"searching videos: {query_url}, with proxies: {config.proxy}")
try:
r = requests.get(query_url, proxies=config.proxy, verify=False, timeout=(30, 60))
r = requests.get(
query_url, proxies=config.proxy, verify=False, timeout=(30, 60)
)
response = r.json()
video_items = []
if "hits" not in response:
@@ -155,7 +160,11 @@ def save_video(video_url: str, save_dir: str = "") -> str:
# if video does not exist, download it
with open(video_path, "wb") as f:
f.write(requests.get(video_url, proxies=config.proxy, verify=False, timeout=(60, 240)).content)
f.write(
requests.get(
video_url, proxies=config.proxy, verify=False, timeout=(60, 240)
).content
)
if os.path.exists(video_path) and os.path.getsize(video_path) > 0:
try:
@@ -174,14 +183,15 @@ def save_video(video_url: str, save_dir: str = "") -> str:
return ""
def download_videos(task_id: str,
search_terms: List[str],
source: str = "pexels",
video_aspect: VideoAspect = VideoAspect.portrait,
video_contact_mode: VideoConcatMode = VideoConcatMode.random,
audio_duration: float = 0.0,
max_clip_duration: int = 5,
) -> List[str]:
def download_videos(
task_id: str,
search_terms: List[str],
source: str = "pexels",
video_aspect: VideoAspect = VideoAspect.portrait,
video_contact_mode: VideoConcatMode = VideoConcatMode.random,
audio_duration: float = 0.0,
max_clip_duration: int = 5,
) -> List[str]:
valid_video_items = []
valid_video_urls = []
found_duration = 0.0
@@ -190,9 +200,11 @@ def download_videos(task_id: str,
search_videos = search_videos_pixabay
for search_term in search_terms:
video_items = search_videos(search_term=search_term,
minimum_duration=max_clip_duration,
video_aspect=video_aspect)
video_items = search_videos(
search_term=search_term,
minimum_duration=max_clip_duration,
video_aspect=video_aspect,
)
logger.info(f"found {len(video_items)} videos for '{search_term}'")
for item in video_items:
@@ -202,7 +214,8 @@ def download_videos(task_id: str,
found_duration += item.duration
logger.info(
f"found total videos: {len(valid_video_items)}, required duration: {audio_duration} seconds, found duration: {found_duration} seconds")
f"found total videos: {len(valid_video_items)}, required duration: {audio_duration} seconds, found duration: {found_duration} seconds"
)
video_paths = []
material_directory = config.app.get("material_directory", "").strip()
@@ -218,14 +231,18 @@ def download_videos(task_id: str,
for item in valid_video_items:
try:
logger.info(f"downloading video: {item.url}")
saved_video_path = save_video(video_url=item.url, save_dir=material_directory)
saved_video_path = save_video(
video_url=item.url, save_dir=material_directory
)
if saved_video_path:
logger.info(f"video saved: {saved_video_path}")
video_paths.append(saved_video_path)
seconds = min(max_clip_duration, item.duration)
total_duration += seconds
if total_duration > audio_duration:
logger.info(f"total duration of downloaded videos: {total_duration} seconds, skip downloading more")
logger.info(
f"total duration of downloaded videos: {total_duration} seconds, skip downloading more"
)
break
except Exception as e:
logger.error(f"failed to download video: {utils.to_json(item)} => {str(e)}")
@@ -234,4 +251,6 @@ def download_videos(task_id: str,
if __name__ == "__main__":
download_videos("test123", ["Money Exchange Medium"], audio_duration=100, source="pixabay")
download_videos(
"test123", ["Money Exchange Medium"], audio_duration=100, source="pixabay"
)

View File

@@ -6,7 +6,6 @@ from app.models import const
# Base class for state management
class BaseState(ABC):
@abstractmethod
def update_task(self, task_id: str, state: int, progress: int = 0, **kwargs):
pass
@@ -18,11 +17,16 @@ class BaseState(ABC):
# Memory state management
class MemoryState(BaseState):
def __init__(self):
self._tasks = {}
def update_task(self, task_id: str, state: int = const.TASK_STATE_PROCESSING, progress: int = 0, **kwargs):
def update_task(
self,
task_id: str,
state: int = const.TASK_STATE_PROCESSING,
progress: int = 0,
**kwargs,
):
progress = int(progress)
if progress > 100:
progress = 100
@@ -43,12 +47,18 @@ class MemoryState(BaseState):
# Redis state management
class RedisState(BaseState):
def __init__(self, host='localhost', port=6379, db=0, password=None):
def __init__(self, host="localhost", port=6379, db=0, password=None):
import redis
self._redis = redis.StrictRedis(host=host, port=port, db=db, password=password)
def update_task(self, task_id: str, state: int = const.TASK_STATE_PROCESSING, progress: int = 0, **kwargs):
def update_task(
self,
task_id: str,
state: int = const.TASK_STATE_PROCESSING,
progress: int = 0,
**kwargs,
):
progress = int(progress)
if progress > 100:
progress = 100
@@ -67,7 +77,10 @@ class RedisState(BaseState):
if not task_data:
return None
task = {key.decode('utf-8'): self._convert_to_original_type(value) for key, value in task_data.items()}
task = {
key.decode("utf-8"): self._convert_to_original_type(value)
for key, value in task_data.items()
}
return task
def delete_task(self, task_id: str):
@@ -79,7 +92,7 @@ class RedisState(BaseState):
Convert the value from byte string to its original data type.
You can extend this method to handle other data types as needed.
"""
value_str = value.decode('utf-8')
value_str = value.decode("utf-8")
try:
# try to convert byte string array to list
@@ -100,4 +113,10 @@ _redis_port = config.app.get("redis_port", 6379)
_redis_db = config.app.get("redis_db", 0)
_redis_password = config.app.get("redis_password", None)
state = RedisState(host=_redis_host, port=_redis_port, db=_redis_db, password=_redis_password) if _enable_redis else MemoryState()
state = (
RedisState(
host=_redis_host, port=_redis_port, db=_redis_db, password=_redis_password
)
if _enable_redis
else MemoryState()
)

View File

@@ -23,18 +23,22 @@ def create(audio_file, subtitle_file: str = ""):
if not os.path.isdir(model_path) or not os.path.isfile(model_bin_file):
model_path = model_size
logger.info(f"loading model: {model_path}, device: {device}, compute_type: {compute_type}")
logger.info(
f"loading model: {model_path}, device: {device}, compute_type: {compute_type}"
)
try:
model = WhisperModel(model_size_or_path=model_path,
device=device,
compute_type=compute_type)
model = WhisperModel(
model_size_or_path=model_path, device=device, compute_type=compute_type
)
except Exception as e:
logger.error(f"failed to load model: {e} \n\n"
f"********************************************\n"
f"this may be caused by network issue. \n"
f"please download the model manually and put it in the 'models' folder. \n"
f"see [README.md FAQ](https://github.com/harry0703/MoneyPrinterTurbo) for more details.\n"
f"********************************************\n\n")
logger.error(
f"failed to load model: {e} \n\n"
f"********************************************\n"
f"this may be caused by network issue. \n"
f"please download the model manually and put it in the 'models' folder. \n"
f"see [README.md FAQ](https://github.com/harry0703/MoneyPrinterTurbo) for more details.\n"
f"********************************************\n\n"
)
return None
logger.info(f"start, output file: {subtitle_file}")
@@ -49,7 +53,9 @@ def create(audio_file, subtitle_file: str = ""):
vad_parameters=dict(min_silence_duration_ms=500),
)
logger.info(f"detected language: '{info.language}', probability: {info.language_probability:.2f}")
logger.info(
f"detected language: '{info.language}', probability: {info.language_probability:.2f}"
)
start = timer()
subtitles = []
@@ -62,11 +68,9 @@ def create(audio_file, subtitle_file: str = ""):
msg = "[%.2fs -> %.2fs] %s" % (seg_start, seg_end, seg_text)
logger.debug(msg)
subtitles.append({
"msg": seg_text,
"start_time": seg_start,
"end_time": seg_end
})
subtitles.append(
{"msg": seg_text, "start_time": seg_start, "end_time": seg_end}
)
for segment in segments:
words_idx = 0
@@ -119,7 +123,11 @@ def create(audio_file, subtitle_file: str = ""):
for subtitle in subtitles:
text = subtitle.get("msg")
if text:
lines.append(utils.text_to_srt(idx, text, subtitle.get("start_time"), subtitle.get("end_time")))
lines.append(
utils.text_to_srt(
idx, text, subtitle.get("start_time"), subtitle.get("end_time")
)
)
idx += 1
sub = "\n".join(lines) + "\n"
@@ -136,12 +144,12 @@ def file_to_subtitles(filename):
current_times = None
current_text = ""
index = 0
with open(filename, 'r', encoding="utf-8") as f:
with open(filename, "r", encoding="utf-8") as f:
for line in f:
times = re.findall("([0-9]*:[0-9]*:[0-9]*,[0-9]*)", line)
if times:
current_times = line
elif line.strip() == '' and current_times:
elif line.strip() == "" and current_times:
index += 1
times_texts.append((index, current_times.strip(), current_text.strip()))
current_times, current_text = None, ""
@@ -166,9 +174,10 @@ def levenshtein_distance(s1, s2):
substitutions = previous_row[j] + (c1 != c2)
current_row.append(min(insertions, deletions, substitutions))
previous_row = current_row
return previous_row[-1]
def similarity(a, b):
distance = levenshtein_distance(a.lower(), b.lower())
max_length = max(len(a), len(b))
@@ -194,26 +203,44 @@ def correct(subtitle_file, video_script):
subtitle_index += 1
else:
combined_subtitle = subtitle_line
start_time = subtitle_items[subtitle_index][1].split(' --> ')[0]
end_time = subtitle_items[subtitle_index][1].split(' --> ')[1]
start_time = subtitle_items[subtitle_index][1].split(" --> ")[0]
end_time = subtitle_items[subtitle_index][1].split(" --> ")[1]
next_subtitle_index = subtitle_index + 1
while next_subtitle_index < len(subtitle_items):
next_subtitle = subtitle_items[next_subtitle_index][2].strip()
if similarity(script_line, combined_subtitle + " " + next_subtitle) > similarity(script_line, combined_subtitle):
if similarity(
script_line, combined_subtitle + " " + next_subtitle
) > similarity(script_line, combined_subtitle):
combined_subtitle += " " + next_subtitle
end_time = subtitle_items[next_subtitle_index][1].split(' --> ')[1]
end_time = subtitle_items[next_subtitle_index][1].split(" --> ")[1]
next_subtitle_index += 1
else:
break
if similarity(script_line, combined_subtitle) > 0.8:
logger.warning(f"Merged/Corrected - Script: {script_line}, Subtitle: {combined_subtitle}")
new_subtitle_items.append((len(new_subtitle_items) + 1, f"{start_time} --> {end_time}", script_line))
logger.warning(
f"Merged/Corrected - Script: {script_line}, Subtitle: {combined_subtitle}"
)
new_subtitle_items.append(
(
len(new_subtitle_items) + 1,
f"{start_time} --> {end_time}",
script_line,
)
)
corrected = True
else:
logger.warning(f"Mismatch - Script: {script_line}, Subtitle: {combined_subtitle}")
new_subtitle_items.append((len(new_subtitle_items) + 1, f"{start_time} --> {end_time}", script_line))
logger.warning(
f"Mismatch - Script: {script_line}, Subtitle: {combined_subtitle}"
)
new_subtitle_items.append(
(
len(new_subtitle_items) + 1,
f"{start_time} --> {end_time}",
script_line,
)
)
corrected = True
script_index += 1
@@ -223,10 +250,22 @@ def correct(subtitle_file, video_script):
while script_index < len(script_lines):
logger.warning(f"Extra script line: {script_lines[script_index]}")
if subtitle_index < len(subtitle_items):
new_subtitle_items.append((len(new_subtitle_items) + 1, subtitle_items[subtitle_index][1], script_lines[script_index]))
new_subtitle_items.append(
(
len(new_subtitle_items) + 1,
subtitle_items[subtitle_index][1],
script_lines[script_index],
)
)
subtitle_index += 1
else:
new_subtitle_items.append((len(new_subtitle_items) + 1, "00:00:00,000 --> 00:00:00,000", script_lines[script_index]))
new_subtitle_items.append(
(
len(new_subtitle_items) + 1,
"00:00:00,000 --> 00:00:00,000",
script_lines[script_index],
)
)
script_index += 1
corrected = True

View File

@@ -988,7 +988,7 @@ Name: zh-CN-XiaoxiaoMultilingualNeural-V2
Gender: Female
""".strip()
voices = []
name = ''
name = ""
for line in voices_str.split("\n"):
line = line.strip()
if not line:
@@ -1008,7 +1008,7 @@ Gender: Female
voices.append(f"{name}-{gender}")
else:
voices.append(f"{name}-{gender}")
name = ''
name = ""
voices.sort()
return voices
@@ -1028,7 +1028,9 @@ def is_azure_v2_voice(voice_name: str):
return ""
def tts(
    text: str, voice_name: str, voice_rate: float, voice_file: str
) -> "SubMaker | None":
    """Synthesize ``text`` into ``voice_file`` with the matching Azure backend.

    Dispatches to ``azure_tts_v2`` when ``is_azure_v2_voice(voice_name)`` is
    truthy, and to ``azure_tts_v1`` otherwise.

    Args:
        text: Text to synthesize.
        voice_name: Voice identifier; also decides which backend is used.
        voice_rate: Speaking rate. Forwarded to the v1 backend only; the v2
            call does not take a rate.
        voice_file: Output path handed to the selected backend.

    Returns:
        Whatever the selected backend returns (annotated as a ``SubMaker``
        or ``None``).
    """
    # The return annotation is a string on purpose: it is not evaluated at
    # definition time, and it replaces the former ``[SubMaker, None]`` list
    # literal, which is not a valid type expression.
    if is_azure_v2_voice(voice_name):
        return azure_tts_v2(text, voice_name, voice_file)
    return azure_tts_v1(text, voice_name, voice_rate, voice_file)
@@ -1042,9 +1044,11 @@ def convert_rate_to_percent(rate: float) -> str:
return f"+{percent}%"
else:
return f"{percent}%"
def azure_tts_v1(text: str, voice_name: str, voice_rate: float, voice_file: str) -> [SubMaker, None]:
def azure_tts_v1(
text: str, voice_name: str, voice_rate: float, voice_file: str
) -> [SubMaker, None]:
voice_name = parse_voice_name(voice_name)
text = text.strip()
rate_str = convert_rate_to_percent(voice_rate)
@@ -1060,7 +1064,9 @@ def azure_tts_v1(text: str, voice_name: str, voice_rate: float, voice_file: str)
if chunk["type"] == "audio":
file.write(chunk["data"])
elif chunk["type"] == "WordBoundary":
sub_maker.create_sub((chunk["offset"], chunk["duration"]), chunk["text"])
sub_maker.create_sub(
(chunk["offset"], chunk["duration"]), chunk["text"]
)
return sub_maker
sub_maker = asyncio.run(_do())
@@ -1085,8 +1091,12 @@ def azure_tts_v2(text: str, voice_name: str, voice_file: str) -> [SubMaker, None
def _format_duration_to_offset(duration) -> int:
if isinstance(duration, str):
time_obj = datetime.strptime(duration, "%H:%M:%S.%f")
milliseconds = (time_obj.hour * 3600000) + (time_obj.minute * 60000) + (time_obj.second * 1000) + (
time_obj.microsecond // 1000)
milliseconds = (
(time_obj.hour * 3600000)
+ (time_obj.minute * 60000)
+ (time_obj.second * 1000)
+ (time_obj.microsecond // 1000)
)
return milliseconds * 10000
if isinstance(duration, int):
@@ -1119,20 +1129,29 @@ def azure_tts_v2(text: str, voice_name: str, voice_file: str) -> [SubMaker, None
# Creates an instance of a speech config with specified subscription key and service region.
speech_key = config.azure.get("speech_key", "")
service_region = config.azure.get("speech_region", "")
audio_config = speechsdk.audio.AudioOutputConfig(filename=voice_file, use_default_speaker=True)
speech_config = speechsdk.SpeechConfig(subscription=speech_key,
region=service_region)
audio_config = speechsdk.audio.AudioOutputConfig(
filename=voice_file, use_default_speaker=True
)
speech_config = speechsdk.SpeechConfig(
subscription=speech_key, region=service_region
)
speech_config.speech_synthesis_voice_name = voice_name
# speech_config.set_property(property_id=speechsdk.PropertyId.SpeechServiceResponse_RequestSentenceBoundary,
# value='true')
speech_config.set_property(property_id=speechsdk.PropertyId.SpeechServiceResponse_RequestWordBoundary,
value='true')
speech_config.set_property(
property_id=speechsdk.PropertyId.SpeechServiceResponse_RequestWordBoundary,
value="true",
)
speech_config.set_speech_synthesis_output_format(
speechsdk.SpeechSynthesisOutputFormat.Audio48Khz192KBitRateMonoMp3)
speech_synthesizer = speechsdk.SpeechSynthesizer(audio_config=audio_config,
speech_config=speech_config)
speech_synthesizer.synthesis_word_boundary.connect(speech_synthesizer_word_boundary_cb)
speechsdk.SpeechSynthesisOutputFormat.Audio48Khz192KBitRateMonoMp3
)
speech_synthesizer = speechsdk.SpeechSynthesizer(
audio_config=audio_config, speech_config=speech_config
)
speech_synthesizer.synthesis_word_boundary.connect(
speech_synthesizer_word_boundary_cb
)
result = speech_synthesizer.speak_text_async(text).get()
if result.reason == speechsdk.ResultReason.SynthesizingAudioCompleted:
@@ -1140,9 +1159,13 @@ def azure_tts_v2(text: str, voice_name: str, voice_file: str) -> [SubMaker, None
return sub_maker
elif result.reason == speechsdk.ResultReason.Canceled:
cancellation_details = result.cancellation_details
logger.error(f"azure v2 speech synthesis canceled: {cancellation_details.reason}")
logger.error(
f"azure v2 speech synthesis canceled: {cancellation_details.reason}"
)
if cancellation_details.reason == speechsdk.CancellationReason.Error:
logger.error(f"azure v2 speech synthesis error: {cancellation_details.error_details}")
logger.error(
f"azure v2 speech synthesis error: {cancellation_details.error_details}"
)
logger.info(f"completed, output file: {voice_file}")
except Exception as e:
logger.error(f"failed, error: {str(e)}")
@@ -1179,11 +1202,7 @@ def create_subtitle(sub_maker: submaker.SubMaker, text: str, subtitle_file: str)
"""
start_t = mktimestamp(start_time).replace(".", ",")
end_t = mktimestamp(end_time).replace(".", ",")
return (
f"{idx}\n"
f"{start_t} --> {end_t}\n"
f"{sub_text}\n"
)
return f"{idx}\n" f"{start_t} --> {end_t}\n" f"{sub_text}\n"
start_time = -1.0
sub_items = []
@@ -1240,12 +1259,16 @@ def create_subtitle(sub_maker: submaker.SubMaker, text: str, subtitle_file: str)
try:
sbs = subtitles.file_to_subtitles(subtitle_file, encoding="utf-8")
duration = max([tb for ((ta, tb), txt) in sbs])
logger.info(f"completed, subtitle file created: {subtitle_file}, duration: {duration}")
logger.info(
f"completed, subtitle file created: {subtitle_file}, duration: {duration}"
)
except Exception as e:
logger.error(f"failed, error: {str(e)}")
os.remove(subtitle_file)
else:
logger.warning(f"failed, sub_items len: {len(sub_items)}, script_lines len: {len(script_lines)}")
logger.warning(
f"failed, sub_items len: {len(sub_items)}, script_lines len: {len(script_lines)}"
)
except Exception as e:
logger.error(f"failed, error: {str(e)}")
@@ -1269,7 +1292,6 @@ if __name__ == "__main__":
voices = get_all_azure_voices()
print(len(voices))
async def _do():
temp_dir = utils.storage_dir("temp")
@@ -1318,12 +1340,13 @@ if __name__ == "__main__":
for voice_name in voice_names:
voice_file = f"{temp_dir}/tts-{voice_name}.mp3"
subtitle_file = f"{temp_dir}/tts.mp3.srt"
sub_maker = azure_tts_v2(text=text, voice_name=voice_name, voice_file=voice_file)
sub_maker = azure_tts_v2(
text=text, voice_name=voice_name, voice_file=voice_file
)
create_subtitle(sub_maker=sub_maker, text=text, subtitle_file=subtitle_file)
audio_duration = get_audio_duration(sub_maker)
print(f"voice: {voice_name}, audio duration: {audio_duration}s")
loop = asyncio.get_event_loop_policy().get_event_loop()
try:
loop.run_until_complete(_do())

View File

@@ -15,12 +15,12 @@ urllib3.disable_warnings()
def get_response(status: int, data: Any = None, message: str = ""):
    """Build the response payload dictionary.

    The payload always carries ``"status"``. ``"data"`` and ``"message"`` are
    added, in that order, only when the corresponding argument is truthy.
    """
    payload = {"status": status}
    optional_fields = (("data", data), ("message", message))
    for key, value in optional_fields:
        # Falsy values (None, "", empty containers, 0) are left out entirely.
        if value:
            payload[key] = value
    return payload
@@ -41,7 +41,7 @@ def to_json(obj):
elif isinstance(o, (list, tuple)):
return [serialize(item) for item in o]
# 如果对象是自定义类型尝试返回其__dict__属性
elif hasattr(o, '__dict__'):
elif hasattr(o, "__dict__"):
return serialize(o.__dict__)
# 其他情况返回None或者可以选择抛出异常
else:
@@ -199,7 +199,8 @@ def split_string_by_punctuations(s):
def md5(text):
    """Return the hexadecimal MD5 digest of ``text`` encoded as UTF-8."""
    import hashlib

    hasher = hashlib.md5()
    hasher.update(text.encode("utf-8"))
    return hasher.hexdigest()
def get_system_locale():

View File

@@ -12,6 +12,6 @@ build_and_render(
parse_refs=False,
sections=["build", "deps", "feat", "fix", "refactor"],
versioning="pep440",
bump="1.1.2", # 指定bump版本
bump="1.1.2", # 指定bump版本
in_place=True,
)