更新WebUI,添加模型目录检查和必要文件验证

- 新增示例
- 新增模型版本提示
- 新增生成参数设置
- 新增分句预览
This commit is contained in:
yrom
2025-05-18 16:48:07 +08:00
parent 60a2238eac
commit 76e7645a8d
3 changed files with 158 additions and 15 deletions

View File

@@ -149,6 +149,7 @@ wget https://huggingface.co/IndexTeam/IndexTTS-1.5/resolve/main/bpe.model -P che
wget https://huggingface.co/IndexTeam/IndexTTS-1.5/resolve/main/dvae.pth -P checkpoints
wget https://huggingface.co/IndexTeam/IndexTTS-1.5/resolve/main/gpt.pth -P checkpoints
wget https://huggingface.co/IndexTeam/IndexTTS-1.5/resolve/main/unigram_12000.vocab -P checkpoints
wget https://huggingface.co/IndexTeam/IndexTTS-1.5/resolve/main/config.yaml -P checkpoints
```
4. Run test script:
@@ -180,6 +181,9 @@ indextts --help
```bash
pip install -e ".[webui]"
python webui.py
# use another model version:
python webui.py --model_dir IndexTTS-1.5
```
Open your browser and visit `http://127.0.0.1:7860` to see the demo.

8
tests/cases.jsonl Normal file
View File

@@ -0,0 +1,8 @@
{"prompt_audio":"sample_prompt.wav","text":"IndexTTS 正式发布1.0版本了效果666","infer_mode":0}
{"prompt_audio":"sample_prompt.wav","text":"大家好我现在正在bilibili 体验 ai 科技说实话来之前我绝对想不到AI技术已经发展到这样匪夷所思的地步了","infer_mode":0}
{"prompt_audio":"sample_prompt.wav","text":"晕XUAN4是一种GAN3觉","infer_mode":0}
{"prompt_audio":"sample_prompt.wav","text":"最zhong4要的是不要chong2蹈覆辙","infer_mode":0}
{"prompt_audio":"sample_prompt.wav","text":"Matt Hougan, chief investment officer at Bitwise, predicts Bitcoin (BTC) will reach $200,000 by the end of 2025 due to a supply shock from heightened institutional demand. In an interview with Cointelegraph at Consensus 2025 in Toronto, the executive said that Bitwise's Bitcoin price prediction model is driven exclusively by supply and demand metrics. \"I think eventually that will exhaust sellers at the $100,000 level where we have been stuck, and I think the next stopping point above that is $200,000,\" the executive said.","infer_mode":1}
{"prompt_audio":"sample_prompt.wav","text":"《盗梦空间》英语Inception是由美国华纳兄弟影片公司出品的电影由克里斯托弗·诺兰Christopher Edward Nolan执导并编剧莱昂纳多·迪卡普里奥Leonardo Wilhelm DiCaprio、玛丽昂·歌迪亚、约瑟夫·高登-莱维特、艾利奥特·佩吉、汤姆·哈迪等联袂主演2010年7月16日在美国上映2010年9月1日在中国内地上映2020年8月28日在中国内地重映。豆瓣评分9.4IMDB 8.8。影片剧情游走于梦境与现实之间,被定义为“发生在意识结构内的当代动作科幻片”,讲述了由 Leonardo 扮演的造梦师,带领特工团队进入他人梦境,从他人的潜意识中盗取机密,并重塑他人梦境的故事。","infer_mode":1}
{"prompt_audio":"sample_prompt.wav","text":"清晨拉开窗帘阳光洒在窗台的Bloomixy花艺礼盒上——薰衣草香薰蜡烛唤醒嗅觉永生花束折射出晨露般光泽。设计师将“自然绽放美学”融入每个细节手工陶瓷花瓶可作首饰收纳香薰精油含依兰依兰舒缓配方。限量款附赠《365天插花灵感手册》让每个平凡日子都有花开仪式感。宴会厅灯光暗下的刹那Glimmeria星月系列耳坠开始发光——瑞士冷珐琅工艺让蓝宝石如银河流动钛合金骨架仅3.2g无负重感。设计师秘密内置微型重力感应器随步伐产生0.01mm振幅,打造“行走的星光”。七夕限定礼盒含星座定制铭牌,让爱意如星辰永恒闪耀。","infer_mode":1}
{"prompt_audio":"sample_prompt.wav","text":"当地时间15日随着特朗普与阿联酋敲定2000亿美元协议特朗普的中东之行正式收官。特朗普已宣布获得沙特6000亿美元和卡塔尔2430亿美元投资承诺。商业协议成为特朗普重返白宫后首次外访的核心成果。香港英文媒体《南华早报》South China Morning Post特朗普访问期间提出了以经济合作为驱动的中东及南亚和平计划。分析人士称该战略包含多项旨在遏制中国在这些地区影响力的措施。英国伦敦国王学院安全研究教授安德烈亚斯·克里格(Dr. Andreas Krieg)对半岛电视台表示特朗普访问海湾地区的主要目标有三个其一以军工投资和能源合作的形式获得海湾国家的切实承诺其二加强与“让美国再次伟大”运动结盟的外交伙伴关系维持美国外交影响力其三将海湾国家重新定位为美国在从加沙到伊朗等地区危机前线的调解人这样就可以不用增强军事部署。安德烈亚斯·克里格直言“海湾国家不会为美国牺牲与中国的关系他们的战略自主性远超特朗普想象。”美国有线电视新闻网CNN报道称特朗普到访的三个能源富国每个国家都对美国有着长长的诉求清单。尽管这些国家豪掷重金但美国并未实现所有诉求。","infer_mode":1}

161
webui.py
View File

@@ -1,5 +1,5 @@
import json
import os
import shutil
import sys
import threading
import time
@@ -12,70 +12,201 @@ current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_dir)
sys.path.append(os.path.join(current_dir, "indextts"))
import argparse

parser = argparse.ArgumentParser(description="IndexTTS WebUI")
parser.add_argument("--verbose", action="store_true", default=False, help="Enable verbose mode")
parser.add_argument("--port", type=int, default=7860, help="Port to run the web UI on")
parser.add_argument("--host", type=str, default="127.0.0.1", help="Host to run the web UI on")
parser.add_argument("--model_dir", type=str, default="checkpoints", help="Model checkpoints directory")
cmd_args = parser.parse_args()

# Fail fast with an actionable message before importing the heavy TTS stack.
if not os.path.exists(cmd_args.model_dir):
    print(f"Model directory {cmd_args.model_dir} does not exist. Please download the model first.")
    sys.exit(1)

# Checkpoint files the WebUI cannot run without.
REQUIRED_FILES = (
    "bigvgan_generator.pth",
    "bpe.model",
    "gpt.pth",
    "config.yaml",
)
# Report *all* missing files in one pass instead of exiting on the first one,
# so the user can download everything in a single round trip.
missing_files = [
    os.path.join(cmd_args.model_dir, name)
    for name in REQUIRED_FILES
    if not os.path.exists(os.path.join(cmd_args.model_dir, name))
]
if missing_files:
    for file_path in missing_files:
        print(f"Required file {file_path} does not exist. Please download it.")
    sys.exit(1)
import gradio as gr
from indextts.utils.webui_utils import next_page, prev_page
from indextts.infer import IndexTTS
from tools.i18n.i18n import I18nAuto

i18n = I18nAuto(language="zh_CN")
MODE = 'local'
# Build the TTS engine from the user-selected checkpoint directory.
# NOTE(review): the flattened diff also carried the superseded hard-coded
# `IndexTTS(model_dir="checkpoints", ...)` call; keeping both would load the
# model twice and ignore --model_dir, so only the parameterized call remains.
tts = IndexTTS(model_dir=cmd_args.model_dir, cfg_path=os.path.join(cmd_args.model_dir, "config.yaml"))

os.makedirs("outputs/tasks", exist_ok=True)
os.makedirs("prompts", exist_ok=True)

# Load demo cases: one JSON object per line with keys "prompt_audio",
# "text" and "infer_mode" (0 = normal inference, 1 = batched inference).
example_cases = []
with open("tests/cases.jsonl", "r", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if not line:
            continue
        example = json.loads(line)
        example_cases.append([
            os.path.join("tests", example.get("prompt_audio", "sample_prompt.wav")),
            example.get("text"),
            ["普通推理", "批次推理"][example.get("infer_mode", 0)],
        ])
def gen_single(prompt, text, infer_mode, max_text_tokens_per_sentence=120, sentences_bucket_max_size=4,
               *args, progress=gr.Progress()):
    """Synthesize speech for `text` using `prompt` as the reference audio.

    Args:
        prompt: path to the reference (speaker) audio file.
        text: target text to synthesize.
        infer_mode: "普通推理" (normal) or "批次推理" (batched fast inference).
        max_text_tokens_per_sentence: token budget used when splitting text
            into sentences.
        sentences_bucket_max_size: max sentences per bucket (fast mode only).
        *args: advanced GPT2 sampling values, in the exact order of the
            `advanced_params` list wired up in the UI.
        progress: gradio progress reporter, forwarded to the TTS engine.

    Returns:
        A gr.update carrying the generated audio file path.
    """
    # Unique output file per request, keyed by wall-clock time.
    output_path = os.path.join("outputs", f"spk_{int(time.time())}.wav")
    # Let the engine push progress updates into the gradio UI.
    tts.gr_progress = progress
    (do_sample, top_p, top_k, temperature,
     length_penalty, num_beams, repetition_penalty, max_mel_tokens) = args
    kwargs = {
        "do_sample": bool(do_sample),
        "top_p": float(top_p),
        "top_k": int(top_k) if int(top_k) > 0 else None,  # 0 disables top-k
        "temperature": float(temperature),
        "length_penalty": float(length_penalty),
        "num_beams": int(num_beams),  # sliders deliver floats; beams must be int
        "repetition_penalty": float(repetition_penalty),
        "max_mel_tokens": int(max_mel_tokens),
        # "typical_sampling": bool(typical_sampling),
        # "typical_mass": float(typical_mass),
    }
    if infer_mode == "普通推理":
        output = tts.infer(prompt, text, output_path, verbose=cmd_args.verbose,
                           max_text_tokens_per_sentence=int(max_text_tokens_per_sentence),
                           **kwargs)
    else:
        # 批次推理 (batched inference)
        output = tts.infer_fast(prompt, text, output_path, verbose=cmd_args.verbose,
                                max_text_tokens_per_sentence=int(max_text_tokens_per_sentence),
                                sentences_bucket_max_size=int(sentences_bucket_max_size),
                                **kwargs)
    return gr.update(value=output, visible=True)
def update_prompt_audio():
    """Re-enable the generate button after a new prompt audio is provided."""
    return gr.update(interactive=True)
# Build the gradio UI. NOTE(review): the flattened diff interleaved the
# superseded pre-commit widget definitions with their replacements; only the
# post-commit versions are kept here to avoid duplicate widgets.
with gr.Blocks(title="IndexTTS Demo") as demo:
    mutex = threading.Lock()
    gr.HTML('''
<h2><center>IndexTTS: An Industrial-Level Controllable and Efficient Zero-Shot Text-To-Speech System</h2>
<h2><center>(一款工业级可控且高效的零样本文本转语音系统)</h2>
<p align="center">
<a href='https://arxiv.org/abs/2502.05512'><img src='https://img.shields.io/badge/ArXiv-2502.05512-red'></a>
</p>
''')
    with gr.Tab("音频生成"):
        with gr.Row():
            prompt_audio = gr.Audio(label="参考音频", key="prompt_audio",
                                    sources=["upload", "microphone"], type="filepath")
            prompt_list = os.listdir("prompts")
            default = ''
            if prompt_list:
                default = prompt_list[0]
            with gr.Column():
                input_text_single = gr.TextArea(label="文本", key="input_text_single", placeholder="请输入目标文本", info="当前模型版本{}".format(tts.model_version or "1.0"))
                infer_mode = gr.Radio(choices=["普通推理", "批次推理"], label="推理模式", info="批次推理:更适合长句,性能翻倍", value="普通推理")
                gen_button = gr.Button("生成语音", key="gen_button", interactive=True)
                output_audio = gr.Audio(label="生成结果", visible=True, key="output_audio")
        with gr.Accordion("高级生成参数设置", open=False):
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("**GPT2 采样设置** _参数会影响音频多样性和生成速度详见[Generation strategies](https://huggingface.co/docs/transformers/main/en/generation_strategies)_")
                    with gr.Row():
                        do_sample = gr.Checkbox(label="do_sample", value=True, info="是否进行采样")
                        temperature = gr.Slider(label="temperature", minimum=0.1, maximum=2.0, value=1.0, step=0.1)
                    with gr.Row():
                        top_p = gr.Slider(label="top_p", minimum=0.0, maximum=1.0, value=0.8, step=0.01)
                        top_k = gr.Slider(label="top_k", minimum=0, maximum=100, value=30, step=1)
                        num_beams = gr.Slider(label="num_beams", value=3, minimum=1, maximum=10, step=1)
                    with gr.Row():
                        repetition_penalty = gr.Number(label="repetition_penalty", precision=None, value=10.0, minimum=0.1, maximum=20.0, step=0.1)
                        length_penalty = gr.Number(label="length_penalty", precision=None, value=0.0, minimum=-2.0, maximum=2.0, step=0.1)
                    max_mel_tokens = gr.Slider(label="max_mel_tokens", value=600, minimum=50, maximum=tts.cfg.gpt.max_mel_tokens, step=10, info="生成Token最大数量过小导致音频被截断", key="max_mel_tokens")
                    # with gr.Row():
                    #     typical_sampling = gr.Checkbox(label="typical_sampling", value=False, info="不建议使用")
                    #     typical_mass = gr.Slider(label="typical_mass", value=0.9, minimum=0.0, maximum=1.0, step=0.1)
                with gr.Column(scale=2):
                    gr.Markdown("**分句设置** _参数会影响音频质量和生成速度_")
                    with gr.Row():
                        max_text_tokens_per_sentence = gr.Slider(
                            label="分句最大Token数", value=120, minimum=20, maximum=tts.cfg.gpt.max_text_tokens, step=2, key="max_text_tokens_per_sentence",
                            info="建议80~200之间值越大分句越长值越小分句越碎过小过大都可能导致音频质量不高",
                        )
                        sentences_bucket_max_size = gr.Slider(
                            label="分句分桶的最大容量(批次推理生效)", value=4, minimum=1, maximum=16, step=1, key="sentences_bucket_max_size",
                            info="建议2-8之间值越大一批次推理包含的分句数越多过大可能导致内存溢出",
                        )
                    with gr.Accordion("预览分句结果", open=True) as sentences_settings:
                        sentences_preview = gr.Dataframe(
                            headers=["序号", "分句内容", "Token数"],
                            key="sentences_preview",
                            wrap=True,
                        )
        # Order matters: gen_single unpacks *args in exactly this order.
        advanced_params = [
            do_sample, top_p, top_k, temperature,
            length_penalty, num_beams, repetition_penalty, max_mel_tokens,
            # typical_sampling, typical_mass,
        ]
        if len(example_cases) > 0:
            gr.Examples(
                examples=example_cases,
                inputs=[prompt_audio, input_text_single, infer_mode],
            )
def on_input_text_change(text, max_tokens_per_sentence):
if text and len(text) > 0:
text_tokens_list = tts.tokenizer.tokenize(text)
sentences = tts.tokenizer.split_sentences(text_tokens_list, max_tokens_per_sentence=int(max_tokens_per_sentence))
data = []
for i, s in enumerate(sentences):
sentence_str = ''.join(s)
tokens_count = len(s)
data.append([i, sentence_str, tokens_count])
return {
sentences_preview: gr.update(value=data, visible=True, type="array"),
}
else:
df = pd.DataFrame([], columns=["序号", "分句内容", "Token数"])
return {
sentences_preview: gr.update(value=df)
}
input_text_single.change(
on_input_text_change,
inputs=[input_text_single, max_text_tokens_per_sentence],
outputs=[sentences_preview]
)
max_text_tokens_per_sentence.change(
on_input_text_change,
inputs=[input_text_single, max_text_tokens_per_sentence],
outputs=[sentences_preview]
)
prompt_audio.upload(update_prompt_audio,
inputs=[],
outputs=[gen_button])
gen_button.click(gen_single,
inputs=[prompt_audio, input_text_single, infer_mode],
inputs=[prompt_audio, input_text_single, infer_mode,
max_text_tokens_per_sentence, sentences_bucket_max_size,
*advanced_params,
],
outputs=[output_audio])
if __name__ == "__main__":
    demo.queue(20)
    # Bind to the host/port given on the command line (default 127.0.0.1:7860).
    # NOTE(review): the flattened diff retained the superseded hard-coded
    # launch(server_name="127.0.0.1") call; only the parameterized call remains.
    demo.launch(server_name=cmd_args.host, server_port=cmd_args.port)