Mirror of https://github.com/index-tts/index-tts.git
Update WebUI: add model directory check and required file validation

- Add examples
- Add model version hint
- Add generation parameter settings
- Add sentence segmentation preview
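For reference, a quick launch sketch using the command-line options this commit introduces in webui.py (`--model_dir`, `--host`, `--port`, `--verbose`); the host, port, and model directory values below are illustrative choices, not defaults mandated by the commit:

```bash
# Install the WebUI extras and start the demo (per the project README).
pip install -e ".[webui]"

# Launch with the argparse options added in this commit.
# 0.0.0.0 / 7860 / IndexTTS-1.5 are example values; the defaults are
# 127.0.0.1, 7860, and "checkpoints" respectively.
python webui.py --model_dir IndexTTS-1.5 --host 0.0.0.0 --port 7860 --verbose
```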
@@ -149,6 +149,7 @@ wget https://huggingface.co/IndexTeam/IndexTTS-1.5/resolve/main/bpe.model -P che
wget https://huggingface.co/IndexTeam/IndexTTS-1.5/resolve/main/dvae.pth -P checkpoints
wget https://huggingface.co/IndexTeam/IndexTTS-1.5/resolve/main/gpt.pth -P checkpoints
wget https://huggingface.co/IndexTeam/IndexTTS-1.5/resolve/main/unigram_12000.vocab -P checkpoints
wget https://huggingface.co/IndexTeam/IndexTTS-1.5/resolve/main/config.yaml -P checkpoints
```

4. Run test script:
@@ -180,6 +181,9 @@ indextts --help
```bash
pip install -e ".[webui]"
python webui.py

# use another model version:
python webui.py --model_dir IndexTTS-1.5
```
Open your browser and visit `http://127.0.0.1:7860` to see the demo.
tests/cases.jsonl (8 additions, Normal file)
@@ -0,0 +1,8 @@
{"prompt_audio":"sample_prompt.wav","text":"IndexTTS 正式发布1.0版本了,效果666","infer_mode":0}
{"prompt_audio":"sample_prompt.wav","text":"大家好,我现在正在bilibili 体验 ai 科技,说实话,来之前我绝对想不到!AI技术已经发展到这样匪夷所思的地步了!","infer_mode":0}
{"prompt_audio":"sample_prompt.wav","text":"晕XUAN4是一种GAN3觉","infer_mode":0}
{"prompt_audio":"sample_prompt.wav","text":"最zhong4要的是:不要chong2蹈覆辙","infer_mode":0}
{"prompt_audio":"sample_prompt.wav","text":"Matt Hougan, chief investment officer at Bitwise, predicts Bitcoin (BTC) will reach $200,000 by the end of 2025 due to a supply shock from heightened institutional demand. In an interview with Cointelegraph at Consensus 2025 in Toronto, the executive said that Bitwise's Bitcoin price prediction model is driven exclusively by supply and demand metrics. \"I think eventually that will exhaust sellers at the $100,000 level where we have been stuck, and I think the next stopping point above that is $200,000,\" the executive said.","infer_mode":1}
{"prompt_audio":"sample_prompt.wav","text":"《盗梦空间》(英语:Inception)是由美国华纳兄弟影片公司出品的电影,由克里斯托弗·诺兰(Christopher Edward Nolan)执导并编剧,莱昂纳多·迪卡普里奥(Leonardo Wilhelm DiCaprio)、玛丽昂·歌迪亚、约瑟夫·高登-莱维特、艾利奥特·佩吉、汤姆·哈迪等联袂主演,2010年7月16日在美国上映,2010年9月1日在中国内地上映,2020年8月28日在中国内地重映。豆瓣评分:9.4,IMDB 8.8。影片剧情游走于梦境与现实之间,被定义为“发生在意识结构内的当代动作科幻片”,讲述了由 Leonardo 扮演的造梦师,带领特工团队进入他人梦境,从他人的潜意识中盗取机密,并重塑他人梦境的故事。","infer_mode":1}
{"prompt_audio":"sample_prompt.wav","text":"清晨拉开窗帘,阳光洒在窗台的Bloomixy花艺礼盒上——薰衣草香薰蜡烛唤醒嗅觉,永生花束折射出晨露般光泽。设计师将“自然绽放美学”融入每个细节:手工陶瓷花瓶可作首饰收纳,香薰精油含依兰依兰舒缓配方。限量款附赠《365天插花灵感手册》,让每个平凡日子都有花开仪式感。宴会厅灯光暗下的刹那,Glimmeria星月系列耳坠开始发光——瑞士冷珐琅工艺让蓝宝石如银河流动,钛合金骨架仅3.2g无负重感。设计师秘密:内置微型重力感应器,随步伐产生0.01mm振幅,打造“行走的星光”。七夕限定礼盒含星座定制铭牌,让爱意如星辰永恒闪耀。","infer_mode":1}
{"prompt_audio":"sample_prompt.wav","text":"当地时间15日,随着特朗普与阿联酋敲定2000亿美元协议,特朗普的中东之行正式收官。特朗普已宣布获得沙特6000亿美元和卡塔尔2430亿美元投资承诺。商业协议成为特朗普重返白宫后首次外访的核心成果。香港英文媒体《南华早报》(South China Morning Post)称,特朗普访问期间提出了以经济合作为驱动的中东及南亚和平计划。分析人士称,该战略包含多项旨在遏制中国在这些地区影响力的措施。英国伦敦国王学院安全研究教授安德烈亚斯·克里格(Dr. Andreas Krieg)对半岛电视台表示,特朗普访问海湾地区的主要目标有三个:其一,以军工投资和能源合作的形式获得海湾国家的切实承诺;其二,加强与“让美国再次伟大”运动结盟的外交伙伴关系,维持美国外交影响力;其三:将海湾国家重新定位为美国在从加沙到伊朗等地区危机前线的调解人,这样就可以不用增强军事部署。安德烈亚斯·克里格直言:“海湾国家不会为美国牺牲与中国的关系,他们的战略自主性远超特朗普想象。”美国有线电视新闻网(CNN)报道称,特朗普到访的三个能源富国,每个国家都对美国有着长长的诉求清单。尽管这些国家豪掷重金,但美国并未实现所有诉求。","infer_mode":1}
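Each line in tests/cases.jsonl is one standalone JSON object with three fields: prompt_audio (resolved relative to tests/ by webui.py), text, and infer_mode (0 selects 普通推理 / normal inference, 1 selects 批次推理 / batch inference, matching the radio choices in the UI). A minimal sketch for appending a custom case of your own; the text value here is a made-up placeholder:

```bash
# Append a hypothetical extra case; the schema mirrors the entries above
# (prompt_audio, text, infer_mode with 0 = normal, 1 = batch inference).
cat >> tests/cases.jsonl <<'EOF'
{"prompt_audio":"sample_prompt.wav","text":"这是一条自定义测试文本。","infer_mode":0}
EOF
```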
webui.py (161 changed lines)
@@ -1,5 +1,5 @@
import json
import os
import shutil
import sys
import threading
import time
@@ -12,70 +12,201 @@ current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_dir)
sys.path.append(os.path.join(current_dir, "indextts"))

import argparse
parser = argparse.ArgumentParser(description="IndexTTS WebUI")
parser.add_argument("--verbose", action="store_true", default=False, help="Enable verbose mode")
parser.add_argument("--port", type=int, default=7860, help="Port to run the web UI on")
parser.add_argument("--host", type=str, default="127.0.0.1", help="Host to run the web UI on")
parser.add_argument("--model_dir", type=str, default="checkpoints", help="Model checkpoints directory")
cmd_args = parser.parse_args()

if not os.path.exists(cmd_args.model_dir):
    print(f"Model directory {cmd_args.model_dir} does not exist. Please download the model first.")
    sys.exit(1)

for file in [
    "bigvgan_generator.pth",
    "bpe.model",
    "gpt.pth",
    "config.yaml",
]:
    file_path = os.path.join(cmd_args.model_dir, file)
    if not os.path.exists(file_path):
        print(f"Required file {file_path} does not exist. Please download it.")
        sys.exit(1)
import gradio as gr
from indextts.utils.webui_utils import next_page, prev_page

from indextts.infer import IndexTTS
from tools.i18n.i18n import I18nAuto

i18n = I18nAuto(language="zh_CN")
MODE = 'local'
tts = IndexTTS(model_dir="checkpoints",cfg_path="checkpoints/config.yaml")
tts = IndexTTS(model_dir=cmd_args.model_dir, cfg_path=os.path.join(cmd_args.model_dir, "config.yaml"),)

os.makedirs("outputs/tasks",exist_ok=True)
os.makedirs("prompts",exist_ok=True)

with open("tests/cases.jsonl", "r", encoding="utf-8") as f:
    example_cases = []
    for line in f:
        line = line.strip()
        if not line:
            continue
        example = json.loads(line)
        example_cases.append([os.path.join("tests", example.get("prompt_audio", "sample_prompt.wav")),
                              example.get("text"), ["普通推理", "批次推理"][example.get("infer_mode", 0)]])
def gen_single(prompt, text, infer_mode, progress=gr.Progress()):
def gen_single(prompt, text, infer_mode, max_text_tokens_per_sentence=120, sentences_bucket_max_size=4,
               *args, progress=gr.Progress()):
    output_path = None
    if not output_path:
        output_path = os.path.join("outputs", f"spk_{int(time.time())}.wav")
    # set gradio progress
    tts.gr_progress = progress
    do_sample, top_p, top_k, temperature, \
        length_penalty, num_beams, repetition_penalty, max_mel_tokens = args
    kwargs = {
        "do_sample": bool(do_sample),
        "top_p": float(top_p),
        "top_k": int(top_k) if int(top_k) > 0 else None,
        "temperature": float(temperature),
        "length_penalty": float(length_penalty),
        "num_beams": num_beams,
        "repetition_penalty": float(repetition_penalty),
        "max_mel_tokens": int(max_mel_tokens),
        # "typical_sampling": bool(typical_sampling),
        # "typical_mass": float(typical_mass),
    }
    if infer_mode == "普通推理":
        output = tts.infer(prompt, text, output_path) # 普通推理
        output = tts.infer(prompt, text, output_path, verbose=cmd_args.verbose,
                           max_text_tokens_per_sentence=int(max_text_tokens_per_sentence),
                           **kwargs)
    else:
        output = tts.infer_fast(prompt, text, output_path) # 批次推理
        # 批次推理
        output = tts.infer_fast(prompt, text, output_path, verbose=cmd_args.verbose,
                                max_text_tokens_per_sentence=int(max_text_tokens_per_sentence),
                                sentences_bucket_max_size=(sentences_bucket_max_size),
                                **kwargs)
    return gr.update(value=output,visible=True)
def update_prompt_audio():
    update_button = gr.update(interactive=True)
    return update_button
with gr.Blocks() as demo:
with gr.Blocks(title="IndexTTS Demo") as demo:
    mutex = threading.Lock()
    gr.HTML('''
    <h2><center>IndexTTS: An Industrial-Level Controllable and Efficient Zero-Shot Text-To-Speech System</h2>
<h2><center>(一款工业级可控且高效的零样本文本转语音系统)</h2>

<p align="center">
<a href='https://arxiv.org/abs/2502.05512'><img src='https://img.shields.io/badge/ArXiv-2502.05512-red'></a>
</p>
    ''')
    with gr.Tab("音频生成"):
        with gr.Row():
            os.makedirs("prompts",exist_ok=True)
            prompt_audio = gr.Audio(label="请上传参考音频",key="prompt_audio",
            prompt_audio = gr.Audio(label="参考音频",key="prompt_audio",
                                    sources=["upload","microphone"],type="filepath")
            prompt_list = os.listdir("prompts")
            default = ''
            if prompt_list:
                default = prompt_list[0]
            with gr.Column():
                input_text_single = gr.TextArea(label="请输入目标文本",key="input_text_single")
                infer_mode = gr.Radio(choices=["普通推理", "批次推理"], label="选择推理模式(批次推理:更适合长句,性能翻倍)",value="普通推理")
                gen_button = gr.Button("生成语音",key="gen_button",interactive=True)
                input_text_single = gr.TextArea(label="文本",key="input_text_single", placeholder="请输入目标文本", info="当前模型版本{}".format(tts.model_version or "1.0"))
                infer_mode = gr.Radio(choices=["普通推理", "批次推理"], label="推理模式",info="批次推理:更适合长句,性能翻倍",value="普通推理")
                gen_button = gr.Button("生成语音", key="gen_button",interactive=True)
            output_audio = gr.Audio(label="生成结果", visible=True,key="output_audio")
        with gr.Accordion("高级生成参数设置", open=False):
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("**GPT2 采样设置** _参数会影响音频多样性和生成速度详见[Generation strategies](https://huggingface.co/docs/transformers/main/en/generation_strategies)_")
                    with gr.Row():
                        do_sample = gr.Checkbox(label="do_sample", value=True, info="是否进行采样")
                        temperature = gr.Slider(label="temperature", minimum=0.1, maximum=2.0, value=1.0, step=0.1)
                    with gr.Row():
                        top_p = gr.Slider(label="top_p", minimum=0.0, maximum=1.0, value=0.8, step=0.01)
                        top_k = gr.Slider(label="top_k", minimum=0, maximum=100, value=30, step=1)
                        num_beams = gr.Slider(label="num_beams", value=3, minimum=1, maximum=10, step=1)
                    with gr.Row():
                        repetition_penalty = gr.Number(label="repetition_penalty", precision=None, value=10.0, minimum=0.1, maximum=20.0, step=0.1)
                        length_penalty = gr.Number(label="length_penalty", precision=None, value=0.0, minimum=-2.0, maximum=2.0, step=0.1)
                    max_mel_tokens = gr.Slider(label="max_mel_tokens", value=600, minimum=50, maximum=tts.cfg.gpt.max_mel_tokens, step=10, info="生成Token最大数量,过小导致音频被截断", key="max_mel_tokens")
                    # with gr.Row():
                    #     typical_sampling = gr.Checkbox(label="typical_sampling", value=False, info="不建议使用")
                    #     typical_mass = gr.Slider(label="typical_mass", value=0.9, minimum=0.0, maximum=1.0, step=0.1)
                with gr.Column(scale=2):
                    gr.Markdown("**分句设置** _参数会影响音频质量和生成速度_")
                    with gr.Row():
                        max_text_tokens_per_sentence = gr.Slider(
                            label="分句最大Token数", value=120, minimum=20, maximum=tts.cfg.gpt.max_text_tokens, step=2, key="max_text_tokens_per_sentence",
                            info="建议80~200之间,值越大,分句越长;值越小,分句越碎;过小过大都可能导致音频质量不高",
                        )
                        sentences_bucket_max_size = gr.Slider(
                            label="分句分桶的最大容量(批次推理生效)", value=4, minimum=1, maximum=16, step=1, key="sentences_bucket_max_size",
                            info="建议2-8之间,值越大,一批次推理包含的分句数越多,过大可能导致内存溢出",
                        )
                    with gr.Accordion("预览分句结果", open=True) as sentences_settings:
                        sentences_preview = gr.Dataframe(
                            headers=["序号", "分句内容", "Token数"],
                            key="sentences_preview",
                            wrap=True,
                        )
            advanced_params = [
                do_sample, top_p, top_k, temperature,
                length_penalty, num_beams, repetition_penalty, max_mel_tokens,
                # typical_sampling, typical_mass,
            ]

        if len(example_cases) > 0:
            gr.Examples(
                examples=example_cases,
                inputs=[prompt_audio, input_text_single, infer_mode],
            )
    def on_input_text_change(text, max_tokens_per_sentence):
        if text and len(text) > 0:
            text_tokens_list = tts.tokenizer.tokenize(text)

            sentences = tts.tokenizer.split_sentences(text_tokens_list, max_tokens_per_sentence=int(max_tokens_per_sentence))
            data = []
            for i, s in enumerate(sentences):
                sentence_str = ''.join(s)
                tokens_count = len(s)
                data.append([i, sentence_str, tokens_count])

            return {
                sentences_preview: gr.update(value=data, visible=True, type="array"),
            }
        else:
            df = pd.DataFrame([], columns=["序号", "分句内容", "Token数"])
            return {
                sentences_preview: gr.update(value=df)
            }
    input_text_single.change(
        on_input_text_change,
        inputs=[input_text_single, max_text_tokens_per_sentence],
        outputs=[sentences_preview]
    )
    max_text_tokens_per_sentence.change(
        on_input_text_change,
        inputs=[input_text_single, max_text_tokens_per_sentence],
        outputs=[sentences_preview]
    )
    prompt_audio.upload(update_prompt_audio,
                        inputs=[],
                        outputs=[gen_button])
    gen_button.click(gen_single,
                     inputs=[prompt_audio, input_text_single, infer_mode],
                     inputs=[prompt_audio, input_text_single, infer_mode,
                             max_text_tokens_per_sentence, sentences_bucket_max_size,
                             *advanced_params,
                     ],
                     outputs=[output_audio])
if __name__ == "__main__":
    demo.queue(20)
    demo.launch(server_name="127.0.0.1")
    demo.launch(server_name=cmd_args.host, server_port=cmd_args.port)