feat: support custom subtitle positioning

This commit is contained in:
yyhhyyyyyy
2024-07-24 14:25:20 +08:00
parent e64041c93d
commit e8b20c697d
9 changed files with 382 additions and 234 deletions

View File

@@ -119,6 +119,7 @@ class VideoParams(BaseModel):
subtitle_enabled: Optional[bool] = True
subtitle_position: Optional[str] = "bottom" # top, bottom, center
custom_position: float = 70.0
font_name: Optional[str] = "STHeitiMedium.ttc"
text_fore_color: Optional[str] = "#FFFFFF"
text_background_color: Optional[str] = "transparent"

View File

@@ -6,9 +6,10 @@ Resources:
1. https://fastapi.tiangolo.com/tutorial/bigger-applications
"""
from fastapi import APIRouter
from app.controllers.v1 import video, llm
from app.controllers.v1 import llm, video
root_api_router = APIRouter()
# v1

View File

@@ -1,13 +1,14 @@
import glob
import random
from typing import List
from PIL import ImageFont, Image
from loguru import logger
from moviepy.editor import *
from moviepy.video.tools.subtitles import SubtitlesClip
from PIL import ImageFont
from app.models import const
from app.models.schema import VideoAspect, VideoParams, VideoConcatMode, MaterialInfo
from app.models.schema import MaterialInfo, VideoAspect, VideoConcatMode, VideoParams
from app.utils import utils
@@ -27,14 +28,15 @@ def get_bgm_file(bgm_type: str = "random", bgm_file: str = ""):
return ""
def combine_videos(combined_video_path: str,
video_paths: List[str],
audio_file: str,
video_aspect: VideoAspect = VideoAspect.portrait,
video_concat_mode: VideoConcatMode = VideoConcatMode.random,
max_clip_duration: int = 5,
threads: int = 2,
) -> str:
def combine_videos(
combined_video_path: str,
video_paths: List[str],
audio_file: str,
video_aspect: VideoAspect = VideoAspect.portrait,
video_concat_mode: VideoConcatMode = VideoConcatMode.random,
max_clip_duration: int = 5,
threads: int = 2,
) -> str:
audio_clip = AudioFileClip(audio_file)
audio_duration = audio_clip.duration
logger.info(f"max duration of audio: {audio_duration} seconds")
@@ -102,13 +104,19 @@ def combine_videos(combined_video_path: str,
new_height = int(clip_h * scale_factor)
clip_resized = clip.resize(newsize=(new_width, new_height))
background = ColorClip(size=(video_width, video_height), color=(0, 0, 0))
clip = CompositeVideoClip([
background.set_duration(clip.duration),
clip_resized.set_position("center")
])
background = ColorClip(
size=(video_width, video_height), color=(0, 0, 0)
)
clip = CompositeVideoClip(
[
background.set_duration(clip.duration),
clip_resized.set_position("center"),
]
)
logger.info(f"resizing video to {video_width} x {video_height}, clip size: {clip_w} x {clip_h}")
logger.info(
f"resizing video to {video_width} x {video_height}, clip size: {clip_w} x {clip_h}"
)
if clip.duration > max_clip_duration:
clip = clip.subclip(0, max_clip_duration)
@@ -118,21 +126,22 @@ def combine_videos(combined_video_path: str,
video_clip = concatenate_videoclips(clips)
video_clip = video_clip.set_fps(30)
logger.info(f"writing")
logger.info("writing")
# https://github.com/harry0703/MoneyPrinterTurbo/issues/111#issuecomment-2032354030
video_clip.write_videofile(filename=combined_video_path,
threads=threads,
logger=None,
temp_audiofile_path=output_dir,
audio_codec="aac",
fps=30,
)
video_clip.write_videofile(
filename=combined_video_path,
threads=threads,
logger=None,
temp_audiofile_path=output_dir,
audio_codec="aac",
fps=30,
)
video_clip.close()
logger.success(f"completed")
logger.success("completed")
return combined_video_path
def wrap_text(text, max_width, font='Arial', fontsize=60):
def wrap_text(text, max_width, font="Arial", fontsize=60):
# 创建字体对象
font = ImageFont.truetype(font, fontsize)
@@ -151,7 +160,7 @@ def wrap_text(text, max_width, font='Arial', fontsize=60):
_wrapped_lines_ = []
words = text.split(" ")
_txt_ = ''
_txt_ = ""
for word in words:
_before = _txt_
_txt_ += f"{word} "
@@ -167,14 +176,14 @@ def wrap_text(text, max_width, font='Arial', fontsize=60):
_wrapped_lines_.append(_txt_)
if processed:
_wrapped_lines_ = [line.strip() for line in _wrapped_lines_]
result = '\n'.join(_wrapped_lines_).strip()
result = "\n".join(_wrapped_lines_).strip()
height = len(_wrapped_lines_) * height
# logger.warning(f"wrapped text: {result}")
return result, height
_wrapped_lines_ = []
chars = list(text)
_txt_ = ''
_txt_ = ""
for word in chars:
_txt_ += word
_width, _height = get_text_size(_txt_)
@@ -182,20 +191,21 @@ def wrap_text(text, max_width, font='Arial', fontsize=60):
continue
else:
_wrapped_lines_.append(_txt_)
_txt_ = ''
_txt_ = ""
_wrapped_lines_.append(_txt_)
result = '\n'.join(_wrapped_lines_).strip()
result = "\n".join(_wrapped_lines_).strip()
height = len(_wrapped_lines_) * height
# logger.warning(f"wrapped text: {result}")
return result, height
def generate_video(video_path: str,
audio_path: str,
subtitle_path: str,
output_file: str,
params: VideoParams,
):
def generate_video(
video_path: str,
audio_path: str,
subtitle_path: str,
output_file: str,
params: VideoParams,
):
aspect = VideoAspect(params.video_aspect)
video_width, video_height = aspect.to_resolution()
@@ -215,7 +225,7 @@ def generate_video(video_path: str,
if not params.font_name:
params.font_name = "STHeitiMedium.ttc"
font_path = os.path.join(utils.font_dir(), params.font_name)
if os.name == 'nt':
if os.name == "nt":
font_path = font_path.replace("\\", "/")
logger.info(f"using font: {font_path}")
@@ -223,11 +233,9 @@ def generate_video(video_path: str,
def create_text_clip(subtitle_item):
phrase = subtitle_item[1]
max_width = video_width * 0.9
wrapped_txt, txt_height = wrap_text(phrase,
max_width=max_width,
font=font_path,
fontsize=params.font_size
)
wrapped_txt, txt_height = wrap_text(
phrase, max_width=max_width, font=font_path, fontsize=params.font_size
)
_clip = TextClip(
wrapped_txt,
font=font_path,
@@ -243,18 +251,26 @@ def generate_video(video_path: str,
_clip = _clip.set_end(subtitle_item[0][1])
_clip = _clip.set_duration(duration)
if params.subtitle_position == "bottom":
_clip = _clip.set_position(('center', video_height * 0.95 - _clip.h))
_clip = _clip.set_position(("center", video_height * 0.95 - _clip.h))
elif params.subtitle_position == "top":
_clip = _clip.set_position(('center', video_height * 0.1))
else:
_clip = _clip.set_position(('center', 'center'))
_clip = _clip.set_position(("center", video_height * 0.05))
elif params.subtitle_position == "custom":
# 确保字幕完全在屏幕内
margin = 10 # 额外的边距,单位为像素
max_y = video_height - _clip.h - margin
min_y = margin
custom_y = (video_height - _clip.h) * (params.custom_position / 100)
custom_y = max(min_y, min(custom_y, max_y)) # 限制 y 值在有效范围内
_clip = _clip.set_position(("center", custom_y))
else: # center
_clip = _clip.set_position(("center", "center"))
return _clip
video_clip = VideoFileClip(video_path)
audio_clip = AudioFileClip(audio_path).volumex(params.voice_volume)
if subtitle_path and os.path.exists(subtitle_path):
sub = SubtitlesClip(subtitles=subtitle_path, encoding='utf-8')
sub = SubtitlesClip(subtitles=subtitle_path, encoding="utf-8")
text_clips = []
for item in sub.subtitles:
clip = create_text_clip(subtitle_item=item)
@@ -264,24 +280,25 @@ def generate_video(video_path: str,
bgm_file = get_bgm_file(bgm_type=params.bgm_type, bgm_file=params.bgm_file)
if bgm_file:
try:
bgm_clip = (AudioFileClip(bgm_file)
.volumex(params.bgm_volume)
.audio_fadeout(3))
bgm_clip = (
AudioFileClip(bgm_file).volumex(params.bgm_volume).audio_fadeout(3)
)
bgm_clip = afx.audio_loop(bgm_clip, duration=video_clip.duration)
audio_clip = CompositeAudioClip([audio_clip, bgm_clip])
except Exception as e:
logger.error(f"failed to add bgm: {str(e)}")
video_clip = video_clip.set_audio(audio_clip)
video_clip.write_videofile(output_file,
audio_codec="aac",
temp_audiofile_path=output_dir,
threads=params.n_threads or 2,
logger=None,
fps=30,
)
video_clip.write_videofile(
output_file,
audio_codec="aac",
temp_audiofile_path=output_dir,
threads=params.n_threads or 2,
logger=None,
fps=30,
)
video_clip.close()
logger.success(f"completed")
logger.success("completed")
def preprocess_video(materials: List[MaterialInfo], clip_duration=4):
@@ -292,7 +309,7 @@ def preprocess_video(materials: List[MaterialInfo], clip_duration=4):
ext = utils.parse_extension(material.url)
try:
clip = VideoFileClip(material.url)
except Exception as e:
except Exception:
clip = ImageClip(material.url)
width = clip.size[0]
@@ -304,12 +321,18 @@ def preprocess_video(materials: List[MaterialInfo], clip_duration=4):
if ext in const.FILE_TYPE_IMAGES:
logger.info(f"processing image: {material.url}")
# 创建一个图片剪辑并设置持续时间为3秒钟
clip = ImageClip(material.url).set_duration(clip_duration).set_position("center")
clip = (
ImageClip(material.url)
.set_duration(clip_duration)
.set_position("center")
)
# 使用resize方法来添加缩放效果。这里使用了lambda函数来使得缩放效果随时间变化。
# 假设我们想要从原始大小逐渐放大到120%的大小。
# t代表当前时间clip.duration为视频总时长这里是3秒。
# 注意1 表示100%的大小所以1.2表示120%的大小
zoom_clip = clip.resize(lambda t: 1 + (clip_duration * 0.03) * (t / clip.duration))
zoom_clip = clip.resize(
lambda t: 1 + (clip_duration * 0.03) * (t / clip.duration)
)
# 如果需要,可以创建一个包含缩放剪辑的复合视频剪辑
# (这在您想要在视频中添加其他元素时非常有用)

16
main.py
View File

@@ -1,8 +1,16 @@
import uvicorn
from loguru import logger
from app.config import config
if __name__ == '__main__':
logger.info("start server, docs: http://127.0.0.1:" + str(config.listen_port) + "/docs")
uvicorn.run(app="app.asgi:app", host=config.listen_host, port=config.listen_port, reload=config.reload_debug,
log_level="warning")
if __name__ == "__main__":
logger.info(
"start server, docs: http://127.0.0.1:" + str(config.listen_port) + "/docs"
)
uvicorn.run(
app="app.asgi:app",
host=config.listen_host,
port=config.listen_port,
reload=config.reload_debug,
log_level="warning",
)

View File

@@ -1,6 +1,5 @@
import sys
import os
import time
import sys
# Add the root directory of the project to the system path to allow importing modules from the project
root_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
@@ -10,31 +9,33 @@ if root_dir not in sys.path:
print(sys.path)
print("")
import streamlit as st
import os
from uuid import uuid4
import platform
import streamlit.components.v1 as components
from uuid import uuid4
import streamlit as st
from loguru import logger
st.set_page_config(page_title="MoneyPrinterTurbo",
page_icon="🤖",
layout="wide",
initial_sidebar_state="auto",
menu_items={
'Report a bug': "https://github.com/harry0703/MoneyPrinterTurbo/issues",
'About': "# MoneyPrinterTurbo\nSimply provide a topic or keyword for a video, and it will "
"automatically generate the video copy, video materials, video subtitles, "
"and video background music before synthesizing a high-definition short "
"video.\n\nhttps://github.com/harry0703/MoneyPrinterTurbo"
})
st.set_page_config(
page_title="MoneyPrinterTurbo",
page_icon="🤖",
layout="wide",
initial_sidebar_state="auto",
menu_items={
"Report a bug": "https://github.com/harry0703/MoneyPrinterTurbo/issues",
"About": "# MoneyPrinterTurbo\nSimply provide a topic or keyword for a video, and it will "
"automatically generate the video copy, video materials, video subtitles, "
"and video background music before synthesizing a high-definition short "
"video.\n\nhttps://github.com/harry0703/MoneyPrinterTurbo",
},
)
from app.models.schema import VideoParams, VideoAspect, VideoConcatMode, MaterialInfo
from app.services import task as tm, llm, voice
from app.utils import utils
from app.config import config
from app.models.const import FILE_TYPE_VIDEOS, FILE_TYPE_IMAGES
from app.models.const import FILE_TYPE_IMAGES, FILE_TYPE_VIDEOS
from app.models.schema import MaterialInfo, VideoAspect, VideoConcatMode, VideoParams
from app.services import llm, voice
from app.services import task as tm
from app.utils import utils
hide_streamlit_style = """
<style>#root > div:nth-child(1) > div > div > div > div > section > div {padding-top: 0rem;}</style>
@@ -42,7 +43,16 @@ hide_streamlit_style = """
st.markdown(hide_streamlit_style, unsafe_allow_html=True)
st.title(f"MoneyPrinterTurbo v{config.project_version}")
support_locales = ["zh-CN", "zh-HK", "zh-TW", "de-DE", "en-US", "fr-FR", "vi-VN", "th-TH"]
support_locales = [
"zh-CN",
"zh-HK",
"zh-TW",
"de-DE",
"en-US",
"fr-FR",
"vi-VN",
"th-TH",
]
font_dir = os.path.join(root_dir, "resource", "fonts")
song_dir = os.path.join(root_dir, "resource", "songs")
@@ -51,14 +61,14 @@ config_file = os.path.join(root_dir, "webui", ".streamlit", "webui.toml")
system_locale = utils.get_system_locale()
# print(f"******** system locale: {system_locale} ********")
if 'video_subject' not in st.session_state:
st.session_state['video_subject'] = ''
if 'video_script' not in st.session_state:
st.session_state['video_script'] = ''
if 'video_terms' not in st.session_state:
st.session_state['video_terms'] = ''
if 'ui_language' not in st.session_state:
st.session_state['ui_language'] = config.ui.get("language", system_locale)
if "video_subject" not in st.session_state:
st.session_state["video_subject"] = ""
if "video_script" not in st.session_state:
st.session_state["video_script"] = ""
if "video_terms" not in st.session_state:
st.session_state["video_terms"] = ""
if "ui_language" not in st.session_state:
st.session_state["ui_language"] = config.ui.get("language", system_locale)
def get_all_fonts():
@@ -85,25 +95,25 @@ def open_task_folder(task_id):
sys = platform.system()
path = os.path.join(root_dir, "storage", "tasks", task_id)
if os.path.exists(path):
if sys == 'Windows':
if sys == "Windows":
os.system(f"start {path}")
if sys == 'Darwin':
if sys == "Darwin":
os.system(f"open {path}")
except Exception as e:
logger.error(e)
def scroll_to_bottom():
js = f"""
js = """
<script>
console.log("scroll_to_bottom");
function scroll(dummy_var_to_force_repeat_execution){{
function scroll(dummy_var_to_force_repeat_execution){
var sections = parent.document.querySelectorAll('section.main');
console.log(sections);
for(let index = 0; index<sections.length; index++) {{
for(let index = 0; index<sections.length; index++) {
sections[index].scrollTop = sections[index].scrollHeight;
}}
}}
}
}
scroll(1);
</script>
"""
@@ -123,12 +133,15 @@ def init_log():
record["file"].path = f"./{relative_path}"
# 返回修改后的格式字符串
# 您可以根据需要调整这里的格式
record['message'] = record['message'].replace(root_dir, ".")
record["message"] = record["message"].replace(root_dir, ".")
_format = '<green>{time:%Y-%m-%d %H:%M:%S}</> | ' + \
'<level>{level}</> | ' + \
'"{file.path}:{line}":<blue> {function}</> ' + \
'- <level>{message}</>' + "\n"
_format = (
"<green>{time:%Y-%m-%d %H:%M:%S}</> | "
+ "<level>{level}</> | "
+ '"{file.path}:{line}":<blue> {function}</> '
+ "- <level>{message}</>"
+ "\n"
)
return _format
logger.add(
@@ -145,7 +158,7 @@ locales = utils.load_locales(i18n_dir)
def tr(key):
loc = locales.get(st.session_state['ui_language'], {})
loc = locales.get(st.session_state["ui_language"], {})
return loc.get("Translation", {}).get(key, key)
@@ -164,19 +177,22 @@ if not config.app.get("hide_config", False):
selected_index = 0
for i, code in enumerate(locales.keys()):
display_languages.append(f"{code} - {locales[code].get('Language')}")
if code == st.session_state['ui_language']:
if code == st.session_state["ui_language"]:
selected_index = i
selected_language = st.selectbox(tr("Language"), options=display_languages,
index=selected_index)
selected_language = st.selectbox(
tr("Language"), options=display_languages, index=selected_index
)
if selected_language:
code = selected_language.split(" - ")[0].strip()
st.session_state['ui_language'] = code
config.ui['language'] = code
st.session_state["ui_language"] = code
config.ui["language"] = code
# 是否禁用日志显示
hide_log = st.checkbox(tr("Hide Log"), value=config.app.get("hide_log", False))
config.ui['hide_log'] = hide_log
hide_log = st.checkbox(
tr("Hide Log"), value=config.app.get("hide_log", False)
)
config.ui["hide_log"] = hide_log
with middle_config_panel:
# openai
@@ -187,8 +203,19 @@ if not config.app.get("hide_config", False):
# qwen (通义千问)
# gemini
# ollama
llm_providers = ['OpenAI', 'Moonshot', 'Azure', 'Qwen', 'DeepSeek', 'Gemini', 'Ollama', 'G4f', 'OneAPI',
"Cloudflare", "ERNIE"]
llm_providers = [
"OpenAI",
"Moonshot",
"Azure",
"Qwen",
"DeepSeek",
"Gemini",
"Ollama",
"G4f",
"OneAPI",
"Cloudflare",
"ERNIE",
]
saved_llm_provider = config.app.get("llm_provider", "OpenAI").lower()
saved_llm_provider_index = 0
for i, provider in enumerate(llm_providers):
@@ -196,19 +223,25 @@ if not config.app.get("hide_config", False):
saved_llm_provider_index = i
break
llm_provider = st.selectbox(tr("LLM Provider"), options=llm_providers, index=saved_llm_provider_index)
llm_provider = st.selectbox(
tr("LLM Provider"),
options=llm_providers,
index=saved_llm_provider_index,
)
llm_helper = st.container()
llm_provider = llm_provider.lower()
config.app["llm_provider"] = llm_provider
llm_api_key = config.app.get(f"{llm_provider}_api_key", "")
llm_secret_key = config.app.get(f"{llm_provider}_secret_key", "") # only for baidu ernie
llm_secret_key = config.app.get(
f"{llm_provider}_secret_key", ""
) # only for baidu ernie
llm_base_url = config.app.get(f"{llm_provider}_base_url", "")
llm_model_name = config.app.get(f"{llm_provider}_model_name", "")
llm_account_id = config.app.get(f"{llm_provider}_account_id", "")
tips = ""
if llm_provider == 'ollama':
if llm_provider == "ollama":
if not llm_model_name:
llm_model_name = "qwen:7b"
if not llm_base_url:
@@ -224,7 +257,7 @@ if not config.app.get("hide_config", False):
- **Model Name**: 使用 `ollama list` 查看,比如 `qwen:7b`
"""
if llm_provider == 'openai':
if llm_provider == "openai":
if not llm_model_name:
llm_model_name = "gpt-3.5-turbo"
with llm_helper:
@@ -236,7 +269,7 @@ if not config.app.get("hide_config", False):
- **Model Name**: 填写**有权限**的模型,[点击查看模型列表](https://platform.openai.com/settings/organization/limits)
"""
if llm_provider == 'moonshot':
if llm_provider == "moonshot":
if not llm_model_name:
llm_model_name = "moonshot-v1-8k"
with llm_helper:
@@ -246,9 +279,11 @@ if not config.app.get("hide_config", False):
- **Base Url**: 固定为 https://api.moonshot.cn/v1
- **Model Name**: 比如 moonshot-v1-8k[点击查看模型列表](https://platform.moonshot.cn/docs/intro#%E6%A8%A1%E5%9E%8B%E5%88%97%E8%A1%A8)
"""
if llm_provider == 'oneapi':
if llm_provider == "oneapi":
if not llm_model_name:
llm_model_name = "claude-3-5-sonnet-20240620" # 默认模型,可以根据需要调整
llm_model_name = (
"claude-3-5-sonnet-20240620" # 默认模型,可以根据需要调整
)
with llm_helper:
tips = """
##### OneAPI 配置说明
@@ -256,8 +291,8 @@ if not config.app.get("hide_config", False):
- **Base Url**: 填写 OneAPI 的基础 URL
- **Model Name**: 填写您要使用的模型名称,例如 claude-3-5-sonnet-20240620
"""
if llm_provider == 'qwen':
if llm_provider == "qwen":
if not llm_model_name:
llm_model_name = "qwen-max"
with llm_helper:
@@ -268,7 +303,7 @@ if not config.app.get("hide_config", False):
- **Model Name**: 比如 qwen-max[点击查看模型列表](https://help.aliyun.com/zh/dashscope/developer-reference/model-introduction#3ef6d0bcf91wy)
"""
if llm_provider == 'g4f':
if llm_provider == "g4f":
if not llm_model_name:
llm_model_name = "gpt-3.5-turbo"
with llm_helper:
@@ -279,7 +314,7 @@ if not config.app.get("hide_config", False):
- **Base Url**: 留空
- **Model Name**: 比如 gpt-3.5-turbo[点击查看模型列表](https://github.com/xtekky/gpt4free/blob/main/g4f/models.py#L308)
"""
if llm_provider == 'azure':
if llm_provider == "azure":
with llm_helper:
tips = """
##### Azure 配置说明
@@ -289,7 +324,7 @@ if not config.app.get("hide_config", False):
- **Model Name**: 填写你实际的部署名
"""
if llm_provider == 'gemini':
if llm_provider == "gemini":
if not llm_model_name:
llm_model_name = "gemini-1.0-pro"
@@ -302,7 +337,7 @@ if not config.app.get("hide_config", False):
- **Model Name**: 比如 gemini-1.0-pro
"""
if llm_provider == 'deepseek':
if llm_provider == "deepseek":
if not llm_model_name:
llm_model_name = "deepseek-chat"
if not llm_base_url:
@@ -315,7 +350,7 @@ if not config.app.get("hide_config", False):
- **Model Name**: 固定为 deepseek-chat
"""
if llm_provider == 'ernie':
if llm_provider == "ernie":
with llm_helper:
tips = """
##### 百度文心一言 配置说明
@@ -324,16 +359,23 @@ if not config.app.get("hide_config", False):
- **Base Url**: 填写 **请求地址** [点击查看文档](https://cloud.baidu.com/doc/WENXINWORKSHOP/s/jlil56u11#%E8%AF%B7%E6%B1%82%E8%AF%B4%E6%98%8E)
"""
if tips and config.ui['language'] == 'zh':
if tips and config.ui["language"] == "zh":
st.warning(
"中国用户建议使用 **DeepSeek** 或 **Moonshot** 作为大模型提供商\n- 国内可直接访问不需要VPN \n- 注册就送额度,基本够用")
"中国用户建议使用 **DeepSeek** 或 **Moonshot** 作为大模型提供商\n- 国内可直接访问不需要VPN \n- 注册就送额度,基本够用"
)
st.info(tips)
st_llm_api_key = st.text_input(tr("API Key"), value=llm_api_key, type="password")
st_llm_api_key = st.text_input(
tr("API Key"), value=llm_api_key, type="password"
)
st_llm_base_url = st.text_input(tr("Base Url"), value=llm_base_url)
st_llm_model_name = ""
if llm_provider != 'ernie':
st_llm_model_name = st.text_input(tr("Model Name"), value=llm_model_name, key=f"{llm_provider}_model_name_input")
if llm_provider != "ernie":
st_llm_model_name = st.text_input(
tr("Model Name"),
value=llm_model_name,
key=f"{llm_provider}_model_name_input",
)
if st_llm_model_name:
config.app[f"{llm_provider}_model_name"] = st_llm_model_name
else:
@@ -345,16 +387,21 @@ if not config.app.get("hide_config", False):
config.app[f"{llm_provider}_base_url"] = st_llm_base_url
if st_llm_model_name:
config.app[f"{llm_provider}_model_name"] = st_llm_model_name
if llm_provider == 'ernie':
st_llm_secret_key = st.text_input(tr("Secret Key"), value=llm_secret_key, type="password")
if llm_provider == "ernie":
st_llm_secret_key = st.text_input(
tr("Secret Key"), value=llm_secret_key, type="password"
)
config.app[f"{llm_provider}_secret_key"] = st_llm_secret_key
if llm_provider == 'cloudflare':
st_llm_account_id = st.text_input(tr("Account ID"), value=llm_account_id)
if llm_provider == "cloudflare":
st_llm_account_id = st.text_input(
tr("Account ID"), value=llm_account_id
)
if st_llm_account_id:
config.app[f"{llm_provider}_account_id"] = st_llm_account_id
with right_config_panel:
def get_keys_from_config(cfg_key):
api_keys = config.app.get(cfg_key, [])
if isinstance(api_keys, str):
@@ -362,19 +409,21 @@ if not config.app.get("hide_config", False):
api_key = ", ".join(api_keys)
return api_key
def save_keys_to_config(cfg_key, value):
value = value.replace(" ", "")
if value:
config.app[cfg_key] = value.split(",")
pexels_api_key = get_keys_from_config("pexels_api_keys")
pexels_api_key = st.text_input(tr("Pexels API Key"), value=pexels_api_key, type="password")
pexels_api_key = st.text_input(
tr("Pexels API Key"), value=pexels_api_key, type="password"
)
save_keys_to_config("pexels_api_keys", pexels_api_key)
pixabay_api_key = get_keys_from_config("pixabay_api_keys")
pixabay_api_key = st.text_input(tr("Pixabay API Key"), value=pixabay_api_key, type="password")
pixabay_api_key = st.text_input(
tr("Pixabay API Key"), value=pixabay_api_key, type="password"
)
save_keys_to_config("pixabay_api_keys", pixabay_api_key)
panel = st.columns(3)
@@ -388,8 +437,9 @@ uploaded_files = []
with left_panel:
with st.container(border=True):
st.write(tr("Video Script Settings"))
params.video_subject = st.text_input(tr("Video Subject"),
value=st.session_state['video_subject']).strip()
params.video_subject = st.text_input(
tr("Video Subject"), value=st.session_state["video_subject"]
).strip()
video_languages = [
(tr("Auto Detect"), ""),
@@ -397,24 +447,27 @@ with left_panel:
for code in support_locales:
video_languages.append((code, code))
selected_index = st.selectbox(tr("Script Language"),
index=0,
options=range(len(video_languages)), # 使用索引作为内部选项值
format_func=lambda x: video_languages[x][0] # 显示给用户的是标签
)
selected_index = st.selectbox(
tr("Script Language"),
index=0,
options=range(len(video_languages)), # 使用索引作为内部选项值
format_func=lambda x: video_languages[x][0], # 显示给用户的是标签
)
params.video_language = video_languages[selected_index][1]
if st.button(tr("Generate Video Script and Keywords"), key="auto_generate_script"):
if st.button(
tr("Generate Video Script and Keywords"), key="auto_generate_script"
):
with st.spinner(tr("Generating Video Script and Keywords")):
script = llm.generate_script(video_subject=params.video_subject, language=params.video_language)
script = llm.generate_script(
video_subject=params.video_subject, language=params.video_language
)
terms = llm.generate_terms(params.video_subject, script)
st.session_state['video_script'] = script
st.session_state['video_terms'] = ", ".join(terms)
st.session_state["video_script"] = script
st.session_state["video_terms"] = ", ".join(terms)
params.video_script = st.text_area(
tr("Video Script"),
value=st.session_state['video_script'],
height=280
tr("Video Script"), value=st.session_state["video_script"], height=280
)
if st.button(tr("Generate Video Keywords"), key="auto_generate_terms"):
if not params.video_script:
@@ -423,12 +476,11 @@ with left_panel:
with st.spinner(tr("Generating Video Keywords")):
terms = llm.generate_terms(params.video_subject, params.video_script)
st.session_state['video_terms'] = ", ".join(terms)
st.session_state["video_terms"] = ", ".join(terms)
params.video_terms = st.text_area(
tr("Video Keywords"),
value=st.session_state['video_terms'],
height=50)
tr("Video Keywords"), value=st.session_state["video_terms"], height=50
)
with middle_panel:
with st.container(border=True):
@@ -447,73 +499,93 @@ with middle_panel:
]
saved_video_source_name = config.app.get("video_source", "pexels")
saved_video_source_index = [v[1] for v in video_sources].index(saved_video_source_name)
saved_video_source_index = [v[1] for v in video_sources].index(
saved_video_source_name
)
selected_index = st.selectbox(tr("Video Source"),
options=range(len(video_sources)),
format_func=lambda x: video_sources[x][0],
index=saved_video_source_index
)
selected_index = st.selectbox(
tr("Video Source"),
options=range(len(video_sources)),
format_func=lambda x: video_sources[x][0],
index=saved_video_source_index,
)
params.video_source = video_sources[selected_index][1]
config.app["video_source"] = params.video_source
if params.video_source == 'local':
if params.video_source == "local":
_supported_types = FILE_TYPE_VIDEOS + FILE_TYPE_IMAGES
uploaded_files = st.file_uploader("Upload Local Files",
type=["mp4", "mov", "avi", "flv", "mkv", "jpg", "jpeg", "png"],
accept_multiple_files=True)
uploaded_files = st.file_uploader(
"Upload Local Files",
type=["mp4", "mov", "avi", "flv", "mkv", "jpg", "jpeg", "png"],
accept_multiple_files=True,
)
selected_index = st.selectbox(tr("Video Concat Mode"),
index=1,
options=range(len(video_concat_modes)), # 使用索引作为内部选项值
format_func=lambda x: video_concat_modes[x][0] # 显示给用户的是标签
)
params.video_concat_mode = VideoConcatMode(video_concat_modes[selected_index][1])
selected_index = st.selectbox(
tr("Video Concat Mode"),
index=1,
options=range(len(video_concat_modes)), # 使用索引作为内部选项值
format_func=lambda x: video_concat_modes[x][0], # 显示给用户的是标签
)
params.video_concat_mode = VideoConcatMode(
video_concat_modes[selected_index][1]
)
video_aspect_ratios = [
(tr("Portrait"), VideoAspect.portrait.value),
(tr("Landscape"), VideoAspect.landscape.value),
]
selected_index = st.selectbox(tr("Video Ratio"),
options=range(len(video_aspect_ratios)), # 使用索引作为内部选项值
format_func=lambda x: video_aspect_ratios[x][0] # 显示给用户的是标签
)
selected_index = st.selectbox(
tr("Video Ratio"),
options=range(len(video_aspect_ratios)), # 使用索引作为内部选项值
format_func=lambda x: video_aspect_ratios[x][0], # 显示给用户的是标签
)
params.video_aspect = VideoAspect(video_aspect_ratios[selected_index][1])
params.video_clip_duration = st.selectbox(tr("Clip Duration"), options=[2, 3, 4, 5, 6], index=1)
params.video_count = st.selectbox(tr("Number of Videos Generated Simultaneously"), options=[1, 2, 3, 4, 5],
index=0)
params.video_clip_duration = st.selectbox(
tr("Clip Duration"), options=[2, 3, 4, 5, 6], index=1
)
params.video_count = st.selectbox(
tr("Number of Videos Generated Simultaneously"),
options=[1, 2, 3, 4, 5],
index=0,
)
with st.container(border=True):
st.write(tr("Audio Settings"))
# tts_providers = ['edge', 'azure']
# tts_provider = st.selectbox(tr("TTS Provider"), tts_providers)
voices = voice.get_all_azure_voices(
filter_locals=support_locales)
voices = voice.get_all_azure_voices(filter_locals=support_locales)
friendly_names = {
v: v.
replace("Female", tr("Female")).
replace("Male", tr("Male")).
replace("Neural", "") for
v in voices}
v: v.replace("Female", tr("Female"))
.replace("Male", tr("Male"))
.replace("Neural", "")
for v in voices
}
saved_voice_name = config.ui.get("voice_name", "")
saved_voice_name_index = 0
if saved_voice_name in friendly_names:
saved_voice_name_index = list(friendly_names.keys()).index(saved_voice_name)
else:
for i, v in enumerate(voices):
if v.lower().startswith(st.session_state['ui_language'].lower()) and "V2" not in v:
if (
v.lower().startswith(st.session_state["ui_language"].lower())
and "V2" not in v
):
saved_voice_name_index = i
break
selected_friendly_name = st.selectbox(tr("Speech Synthesis"),
options=list(friendly_names.values()),
index=saved_voice_name_index)
selected_friendly_name = st.selectbox(
tr("Speech Synthesis"),
options=list(friendly_names.values()),
index=saved_voice_name_index,
)
voice_name = list(friendly_names.keys())[list(friendly_names.values()).index(selected_friendly_name)]
voice_name = list(friendly_names.keys())[
list(friendly_names.values()).index(selected_friendly_name)
]
params.voice_name = voice_name
config.ui['voice_name'] = voice_name
config.ui["voice_name"] = voice_name
if st.button(tr("Play Voice")):
play_content = params.video_subject
@@ -524,11 +596,21 @@ with middle_panel:
with st.spinner(tr("Synthesizing Voice")):
temp_dir = utils.storage_dir("temp", create=True)
audio_file = os.path.join(temp_dir, f"tmp-voice-{str(uuid4())}.mp3")
sub_maker = voice.tts(text=play_content, voice_name=voice_name, voice_rate=params.voice_rate, voice_file=audio_file)
sub_maker = voice.tts(
text=play_content,
voice_name=voice_name,
voice_rate=params.voice_rate,
voice_file=audio_file,
)
# if the voice file generation failed, try again with a default content.
if not sub_maker:
play_content = "This is a example voice. if you hear this, the voice synthesis failed with the original content."
sub_maker = voice.tts(text=play_content, voice_name=voice_name, voice_rate=params.voice_rate, voice_file=audio_file)
sub_maker = voice.tts(
text=play_content,
voice_name=voice_name,
voice_rate=params.voice_rate,
voice_file=audio_file,
)
if sub_maker and os.path.exists(audio_file):
st.audio(audio_file, format="audio/mp3")
@@ -536,29 +618,40 @@ with middle_panel:
os.remove(audio_file)
if voice.is_azure_v2_voice(voice_name):
saved_azure_speech_region = config.azure.get(f"speech_region", "")
saved_azure_speech_key = config.azure.get(f"speech_key", "")
azure_speech_region = st.text_input(tr("Speech Region"), value=saved_azure_speech_region)
azure_speech_key = st.text_input(tr("Speech Key"), value=saved_azure_speech_key, type="password")
saved_azure_speech_region = config.azure.get("speech_region", "")
saved_azure_speech_key = config.azure.get("speech_key", "")
azure_speech_region = st.text_input(
tr("Speech Region"), value=saved_azure_speech_region
)
azure_speech_key = st.text_input(
tr("Speech Key"), value=saved_azure_speech_key, type="password"
)
config.azure["speech_region"] = azure_speech_region
config.azure["speech_key"] = azure_speech_key
params.voice_volume = st.selectbox(tr("Speech Volume"),
options=[0.6, 0.8, 1.0, 1.2, 1.5, 2.0, 3.0, 4.0, 5.0], index=2)
params.voice_rate = st.selectbox(tr("Speech Rate"),
options=[0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.5, 1.8, 2.0], index=2)
params.voice_volume = st.selectbox(
tr("Speech Volume"),
options=[0.6, 0.8, 1.0, 1.2, 1.5, 2.0, 3.0, 4.0, 5.0],
index=2,
)
params.voice_rate = st.selectbox(
tr("Speech Rate"),
options=[0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.5, 1.8, 2.0],
index=2,
)
bgm_options = [
(tr("No Background Music"), ""),
(tr("Random Background Music"), "random"),
(tr("Custom Background Music"), "custom"),
]
selected_index = st.selectbox(tr("Background Music"),
index=1,
options=range(len(bgm_options)), # 使用索引作为内部选项值
format_func=lambda x: bgm_options[x][0] # 显示给用户的是标签
)
selected_index = st.selectbox(
tr("Background Music"),
index=1,
options=range(len(bgm_options)), # 使用索引作为内部选项值
format_func=lambda x: bgm_options[x][0], # 显示给用户的是标签
)
# 获取选择的背景音乐类型
params.bgm_type = bgm_options[selected_index][1]
@@ -568,8 +661,11 @@ with middle_panel:
if custom_bgm_file and os.path.exists(custom_bgm_file):
params.bgm_file = custom_bgm_file
# st.write(f":red[已选择自定义背景音乐]**{custom_bgm_file}**")
params.bgm_volume = st.selectbox(tr("Background Music Volume"),
options=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0], index=2)
params.bgm_volume = st.selectbox(
tr("Background Music Volume"),
options=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
index=2,
)
with right_panel:
with st.container(border=True):
@@ -580,31 +676,48 @@ with right_panel:
saved_font_name_index = 0
if saved_font_name in font_names:
saved_font_name_index = font_names.index(saved_font_name)
params.font_name = st.selectbox(tr("Font"), font_names, index=saved_font_name_index)
config.ui['font_name'] = params.font_name
params.font_name = st.selectbox(
tr("Font"), font_names, index=saved_font_name_index
)
config.ui["font_name"] = params.font_name
subtitle_positions = [
(tr("Top"), "top"),
(tr("Center"), "center"),
(tr("Bottom"), "bottom"),
(tr("Custom"), "custom"),
]
selected_index = st.selectbox(tr("Position"),
index=2,
options=range(len(subtitle_positions)), # 使用索引作为内部选项值
format_func=lambda x: subtitle_positions[x][0] # 显示给用户的是标签
)
selected_index = st.selectbox(
tr("Position"),
index=2,
options=range(len(subtitle_positions)),
format_func=lambda x: subtitle_positions[x][0],
)
params.subtitle_position = subtitle_positions[selected_index][1]
if params.subtitle_position == "custom":
custom_position = st.text_input(
tr("Custom Position (% from top)"), value="50"
)
try:
params.custom_position = float(custom_position)
if params.custom_position < 0 or params.custom_position > 100:
st.error(tr("Please enter a value between 0 and 100"))
except ValueError:
st.error(tr("Please enter a valid number"))
font_cols = st.columns([0.3, 0.7])
with font_cols[0]:
saved_text_fore_color = config.ui.get("text_fore_color", "#FFFFFF")
params.text_fore_color = st.color_picker(tr("Font Color"), saved_text_fore_color)
config.ui['text_fore_color'] = params.text_fore_color
params.text_fore_color = st.color_picker(
tr("Font Color"), saved_text_fore_color
)
config.ui["text_fore_color"] = params.text_fore_color
with font_cols[1]:
saved_font_size = config.ui.get("font_size", 60)
params.font_size = st.slider(tr("Font Size"), 30, 100, saved_font_size)
config.ui['font_size'] = params.font_size
config.ui["font_size"] = params.font_size
stroke_cols = st.columns([0.3, 0.7])
with stroke_cols[0]:
@@ -621,7 +734,7 @@ if start_button:
scroll_to_bottom()
st.stop()
if llm_provider != 'g4f' and not config.app.get(f"{llm_provider}_api_key", ""):
if llm_provider != "g4f" and not config.app.get(f"{llm_provider}_api_key", ""):
st.error(tr("Please Enter the LLM API Key"))
scroll_to_bottom()
st.stop()
@@ -657,15 +770,13 @@ if start_button:
log_container = st.empty()
log_records = []
def log_received(msg):
if config.ui['hide_log']:
if config.ui["hide_log"]:
return
with log_container:
log_records.append(msg)
st.code("\n".join(log_records))
logger.add(log_received)
st.toast(tr("Generating Video"))
@@ -687,7 +798,7 @@ if start_button:
player_cols = st.columns(len(video_files) * 2 + 1)
for i, url in enumerate(video_files):
player_cols[i * 2 + 1].video(url)
except Exception as e:
except Exception:
pass
open_task_folder(task_id)

View File

@@ -42,6 +42,7 @@
"Top": "Oben",
"Center": "Mittig",
"Bottom": "Unten (empfohlen)",
"Custom": "Benutzerdefinierte Position (70, was 70% von oben bedeutet)",
"Font Size": "Schriftgröße für Untertitel",
"Font Color": "Schriftfarbe",
"Stroke Color": "Kontur",

View File

@@ -42,6 +42,7 @@
"Top": "Top",
"Center": "Center",
"Bottom": "Bottom (Recommended)",
"Custom": "Custom position (70, indicating 70% down from the top)",
"Font Size": "Subtitle Font Size",
"Font Color": "Subtitle Font Color",
"Stroke Color": "Subtitle Outline Color",

View File

@@ -42,6 +42,7 @@
"Top": "Trên",
"Center": "Giữa",
"Bottom": "Dưới (Được Khuyến Nghị)",
"Custom": "Vị trí tùy chỉnh (70, chỉ ra là cách đầu trang 70%)",
"Font Size": "Cỡ Chữ Phụ Đề",
"Font Color": "Màu Chữ Phụ Đề",
"Stroke Color": "Màu Viền Phụ Đề",

View File

@@ -42,6 +42,7 @@
"Top": "顶部",
"Center": "中间",
"Bottom": "底部(推荐)",
"Custom": "自定义位置70表示离顶部70%的位置)",
"Font Size": "字幕大小",
"Font Color": "字幕颜色",
"Stroke Color": "描边颜色",