IndexTTS2 Release Preparation, Part 2 (#291)

* fix: Configure "uv" build system to use CUDA on supported platforms

- Linux builds of PyTorch always have CUDA acceleration built-in, but Windows only has it if we request a CUDA build.

- The built-in CUDA on Linux uses old libraries and can be slow.

- We now request PyTorch built for the most modern CUDA Toolkit on both Linux and Windows, which solves both problems. (A quick way to verify the installed build is shown after this list.)

- macOS uses PyTorch without CUDA support, since CUDA does not exist on that platform.

- Other dependencies have received new releases and are included in this fix too:

* click was downgraded because the author revoked 8.2.2 due to a bug.

* wetext has received a new release.
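
As noted above, a quick way to confirm which PyTorch build actually got installed (its CUDA Toolkit version and whether a GPU is usable) is a one-liner using standard PyTorch attributes. This is a generic check, not part of the commit itself:

```bash
# Prints the CUDA version PyTorch was built against (or "None" for CPU/Mac builds)
# and whether a CUDA device can actually be used on this machine.
uv run python -c "import torch; print(torch.version.cuda, torch.cuda.is_available())"
```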

* fix: Use PyPI as the hashing reference in "uv" lockfile

- PyPI is the most trustworthy source for package hashes. The custom mirror therefore had to be removed from the config; otherwise that mirror always becomes the default lockfile/package source, which leads to user-trust issues and package-impersonation risks.

- Regional mirrors should instead be added by users at installation time, via the `uv sync --default-index` flag. This is documented with an example Chinese mirror (shown after this list).

- When users add `--default-index`, "uv" will try to fetch the exact same packages from the mirror to improve download speeds, but automatically falls back to PyPI if the mirror is missing the files or if the mirror's file hashes are incorrect. This ensures that users always receive the correct package files.
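
The documented example for users in China, taken from the updated README:

```bash
uv sync --all-extras --default-index "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple"
```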

* docs: Improve README for IndexTTS2 release!

- "Abstract" separated into paragraphs for easier readability.

- Clearer document structure and many grammatical improvements.

- More emojis, to make it easier to find sections when scrolling through the page!

- Added missing instructions:

* Needing `git-lfs` to clone the code.
* Needing CUDA Toolkit to install the dependencies.
* How to install the `hf` or `modelscope` CLI tools to download the models.

- Made our web demo the first section within "quickstart", to give users a quick, fun demo to start experimenting with.

- Fixed a bug in the `PYTHONPATH` recommendation: the value must be enclosed in quotes (`"..."`), otherwise the combined path breaks on systems whose original path contains spaces. (The corrected form is shown after this list.)

- Improved all Python code-example descriptions to make them much easier to understand.

- Clearly marked the IndexTTS1 legacy section as "legacy" to avoid confusion.

- Removed the outdated Windows "conda/pip" instructions, which are no longer relevant now that we use "uv".
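
The corrected, quoted form now used in the README:

```bash
PYTHONPATH="$PYTHONPATH:." uv run indextts/infer_v2.py
```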

* refactor(webui): Remove unused imports

The old IndexTTS1 module and ModelScope were being imported even though they are not needed. They also pull in many dependencies, which slowed down startup and could even cause conflicts.

* feat!: Remove obsolete build system (setup.py)

BREAKING CHANGE: The `setup.py` file has been removed.

Users should now use the new `pyproject.toml`-based "uv" build system for installing and developing the project, as shown below.
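
Installation now goes through "uv" and the `pyproject.toml` metadata, as documented in the README:

```bash
uv sync              # create .venv and install the locked dependencies
uv sync --all-extras # also install the optional WebUI dependencies
```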

* feat: Add support for installing IndexTTS as a CLI tool

- We now support installing IndexTTS as a CLI tool via "uv" (a usage sketch is shown after this list).

- Uses the modern "hatchling" as the package / CLI build system.

- The `cli.py` code is currently outdated (it does not support IndexTTS2 yet); this is marked as a TODO.
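
A sketch of how the tool installation is expected to work. `uv tool install` is a real uv command, but the console-script name below is hypothetical and depends on the entry point defined in `pyproject.toml`:

```bash
# From the repository root: install the project's console script(s)
# into an isolated uv tool environment.
uv tool install .

# Hypothetical command name; check pyproject.toml for the actual entry point:
# indextts --help
```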

* chore: Add authors and classifiers metadata to pyproject.toml

* feat: Faster installs by making WebUI dependencies optional

* refactor!: Rename "sentences" to "segments" for clarity

- When we split text into generation chunks, we are *not* creating "sentences"; we are creating "segments". A *sentence* must always end with punctuation (".!?", etc.), whereas a *segment* can be a small fragment of a sentence without any punctuation, so using the word "sentences" was inaccurate and very misleading.

- All variables, function calls and strings have been carefully analyzed and renamed.

- This rename will surface in user-facing code via an upcoming feature, which is why the change was applied across the entire codebase.

- This change also helps future code contributors understand the code.

- All affected features are fully tested and work correctly after this refactoring.

- The `is_fp16` parameter has also been renamed to `use_fp16`, since the previous name was confusing ("is" implies an automatic check, while "use" implies a user decision to enable or disable FP16). See the example after this list.

- `cli.py`'s "--fp16" default value has been set to False, exactly like the web UI.

- `webui.py`'s "--is_fp16" flag has been changed to "--fp16" for easier usage and consistency with the CLI program, and its help description has been improved.
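
The renamed parameter in use, matching the updated README examples:

```python
from indextts.infer_v2 import IndexTTS2

# "use_fp16" replaces the old "is_fp16" parameter; the behavior is unchanged.
tts = IndexTTS2(cfg_path="checkpoints/config.yaml", model_dir="checkpoints",
                use_fp16=False, use_cuda_kernel=False)
tts.infer(spk_audio_prompt="examples/voice_01.wav",
          text="Translate for me, what is a surprise!",
          output_path="gen.wav", verbose=True)
```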

* feat(webui): Set "max tokens per generation segment" via CLI flag

- The "Max tokens per generation segment" is a critical setting, as it directly impacts VRAM usage. Since the optimal value varies significantly based on a user's GPU, it is a frequent point of adjustment to prevent out-of-memory issues.

- This change allows the default value to be set via a CLI flag. Users can now conveniently start the web UI with the correct setting for their system, eliminating the need to manually reconfigure the value on every restart.

- The `webui.py -h` help text has also been enhanced to automatically display the default values for all CLI settings.
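
The enhanced help output is the easiest way to discover the new flag and its default on your version; the flag name in the commented line below is illustrative only, not taken from this diff:

```bash
uv run webui.py -h   # lists every CLI flag together with its default value
# Illustrative flag name; use whatever the -h output actually reports:
# uv run webui.py --max_text_tokens_per_segment 120
```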

* refactor(i18n): Improve clarity of all web UI translation strings

* feat(webui): Use main text as emotion guidance when description is empty

If the user selects "text-to-emotion" control, but leaves the emotion description empty, we now automatically use the main text prompt instead.

This ensures that web users can enjoy every feature of IndexTTS2, including the ability to automatically guess the emotion from the main text prompt.
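
A minimal, self-contained sketch of the fallback behavior. The function name and exact wiring are assumptions for illustration, not the literal `webui.py` code:

```python
def resolve_emotion_text(text: str, emo_text: str | None, use_emo_text: bool) -> str | None:
    """Fall back to the main text prompt when emotion guidance is requested but empty."""
    if use_emo_text and not (emo_text and emo_text.strip()):
        return text
    return emo_text

# No description given, so the main prompt drives the emotion:
print(resolve_emotion_text("I can't believe we actually won!", "", True))
```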

* feat: Add PyTorch GPU acceleration diagnostic tool
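
The tool is run through "uv", as documented in the updated README:

```bash
uv run tools/gpu_check.py
```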

* chore: Use NVIDIA CUDA Toolkit v12.8

Downgrade from CUDA 12.9 to 12.8 to simplify user installation, since version 12.8 is very popular.

* docs: Simplify "uv run" command examples

The "uv run" command can take a `.py` file as direct argument and automatically understands that it should run via python.
Commit cdcc62ae22 (parent 3fe385af69), authored by Johnny Arcitec on 2025-09-09 06:51:45 +02:00 and committed by GitHub.
17 changed files with 2727 additions and 2476 deletions.

README.md (232 lines changed)

@@ -37,55 +37,75 @@
</a> </a>
</div> </div>
### Abstract ### Abstract
Existing autoregressive large-scale text-to-speech (TTS) models have advantages in speech naturalness, but their token-by-token generation mechanism makes it difficult to precisely control the duration of synthesized speech. This becomes a significant limitation in applications requiring strict audio-visual synchronization, such as video dubbing. This paper introduces IndexTTS2, which proposes a novel, general, and autoregressive model-friendly method for speech duration control. The method supports two generation modes: one explicitly specifies the number of generated tokens to precisely control speech duration; the other freely generates speech in an autoregressive manner without specifying the number of tokens, while faithfully reproducing the prosodic features of the input prompt. Furthermore, IndexTTS2 achieves disentanglement between emotional expression and speaker identity, enabling independent control over timbre and emotion. In the zero-shot setting, the model can accurately reconstruct the target timbre (from the timbre prompt) while perfectly reproducing the specified emotional tone (from the style prompt). To enhance speech clarity in highly emotional expressions, we incorporate GPT latent representations and design a novel three-stage training paradigm to improve the stability of the generated speech. Additionally, to lower the barrier for emotional control, we designed a soft instruction mechanism based on text descriptions by fine-tuning Qwen3, effectively guiding the generation of speech with the desired emotional orientation. Finally, experimental results on multiple datasets show that IndexTTS2 outperforms state-of-the-art zero-shot TTS models in terms of word error rate, speaker similarity, and emotional fidelity. Audio samples are available at: <a href="https://index-tts.github.io/index-tts2.github.io/">IndexTTS2 demo page</a> Existing autoregressive large-scale text-to-speech (TTS) models have advantages in speech naturalness, but their token-by-token generation mechanism makes it difficult to precisely control the duration of synthesized speech. This becomes a significant limitation in applications requiring strict audio-visual synchronization, such as video dubbing.
This paper introduces IndexTTS2, which proposes a novel, general, and autoregressive model-friendly method for speech duration control.
The method supports two generation modes: one explicitly specifies the number of generated tokens to precisely control speech duration; the other freely generates speech in an autoregressive manner without specifying the number of tokens, while faithfully reproducing the prosodic features of the input prompt.
Furthermore, IndexTTS2 achieves disentanglement between emotional expression and speaker identity, enabling independent control over timbre and emotion. In the zero-shot setting, the model can accurately reconstruct the target timbre (from the timbre prompt) while perfectly reproducing the specified emotional tone (from the style prompt).
To enhance speech clarity in highly emotional expressions, we incorporate GPT latent representations and design a novel three-stage training paradigm to improve the stability of the generated speech. Additionally, to lower the barrier for emotional control, we designed a soft instruction mechanism based on text descriptions by fine-tuning Qwen3, effectively guiding the generation of speech with the desired emotional orientation.
Finally, experimental results on multiple datasets show that IndexTTS2 outperforms state-of-the-art zero-shot TTS models in terms of word error rate, speaker similarity, and emotional fidelity. Audio samples are available at: <a href="https://index-tts.github.io/index-tts2.github.io/">IndexTTS2 demo page</a>.
**Tips:** Please contact the authors for more detailed information. For commercial usage and cooperation, please contact <u>indexspeech@bilibili.com</u>.
**Tips:** Please contact authors for more detailed information. For commercial cooperation, please contact <u>indexspeech@bilibili.com</u>
### Feel IndexTTS2 ### Feel IndexTTS2
<div align="center"> <div align="center">
**IndexTTS2: The Future of Voice, Now Generating** **IndexTTS2: The Future of Voice, Now Generating**
[![IndexTTS2 Demo](assets/IndexTTS2-video-pic.png)](https://www.bilibili.com/video/BV136a9zqEk5) [![IndexTTS2 Demo](assets/IndexTTS2-video-pic.png)](https://www.bilibili.com/video/BV136a9zqEk5)
*Click the image to watch IndexTTS2 video* *Click the image to watch the IndexTTS2 introduction video.*
</div> </div>
### Contact ### Contact
QQ Group: 553460296(No.1) 663272642(No.4)\
QQ Group: 553460296(No.1) 663272642(No.4) \
Discord: https://discord.gg/uT32E7KDmy \ Discord: https://discord.gg/uT32E7KDmy \
Emal: indexspeech@bilibili.com \ Email: indexspeech@bilibili.com \
You are welcome to join our community! 🌏 \
欢迎大家来交流讨论! 欢迎大家来交流讨论!
## 📣 Updates ## 📣 Updates
- `2025/09/08` 🔥🔥🔥 We release the **IndexTTS-2** - `2025/09/08` 🔥🔥🔥 We release **IndexTTS-2** to the world!
- The first autoregressive TTS model with precise synthesis duration control, supporting both controllable and uncontrollable modes. <i>This functionality is not yet enabled in this release.</i> - The first autoregressive TTS model with precise synthesis duration control, supporting both controllable and uncontrollable modes. <i>This functionality is not yet enabled in this release.</i>
- The model achieves highly expressive emotional speech synthesis, with emotion-controllable capabilities enabled through multiple input modalities. - The model achieves highly expressive emotional speech synthesis, with emotion-controllable capabilities enabled through multiple input modalities.
- `2025/05/14` 🔥🔥 We release the **IndexTTS-1.5**, Significantly improve the model's stability and its performance in the English language. - `2025/05/14` 🔥🔥 We release **IndexTTS-1.5**, significantly improving the model's stability and its performance in the English language.
- `2025/03/25` 🔥 We release **IndexTTS-1.0** model parameters and inference code. - `2025/03/25` 🔥 We release **IndexTTS-1.0** with model weights and inference code.
- `2025/02/12` 🔥 We submitted our paper on arXiv, and released our demos and test sets. - `2025/02/12` 🔥 We submitted our paper to arXiv, and released our demos and test sets.
## 🖥️ Method
The overview of IndexTTS2 is shown as follows. ## 🖥️ Neural Network Architecture
Architectural overview of IndexTTS2, our state-of-the art speech model:
<picture> <picture>
<img src="assets/IndexTTS2.png" width="800"/> <img src="assets/IndexTTS2.png" width="800"/>
</picture> </picture>
The key contributions of **indextts2** are summarized as follows: The key contributions of **IndexTTS2** are summarized as follows:
- We propose a duration adaptation scheme for autoregressive TTS models. IndexTTS2 is the first autoregressive zero-shot TTS model to combine precise duration control with natural duration generation, and the method is scalable for any autoregressive large-scale TTS model. - We propose a duration adaptation scheme for autoregressive TTS models. IndexTTS2 is the first autoregressive zero-shot TTS model to combine precise duration control with natural duration generation, and the method is scalable for any autoregressive large-scale TTS model.
- The emotional and speaker-related features are decoupled from the prompts, and a feature fusion strategy is designed to maintain semantic fluency and pronunciation clarity during emotionally rich expressions. Furthermore, a tool was developed for emotion control, utilising natural language descriptions for the benefit of users. - The emotional and speaker-related features are decoupled from the prompts, and a feature fusion strategy is designed to maintain semantic fluency and pronunciation clarity during emotionally rich expressions. Furthermore, a tool was developed for emotion control, utilizing natural language descriptions for the benefit of users.
- To address the lack of highly expressive speech data, we propose an effective training strategy, significantly enhancing the emotional expressiveness of zeroshot TTS to State-of-the-Art (SOTA) level. - To address the lack of highly expressive speech data, we propose an effective training strategy, significantly enhancing the emotional expressiveness of zeroshot TTS to State-of-the-Art (SOTA) level.
- We will publicly release the code and pre-trained weights to facilitate future research and practical applications. - We will publicly release the code and pre-trained weights to facilitate future research and practical applications.
## Model Download ## Model Download
| **HuggingFace** | **ModelScope** | | **HuggingFace** | **ModelScope** |
|----------------------------------------------------------|----------------------------------------------------------| |----------------------------------------------------------|----------------------------------------------------------|
| [😁 IndexTTS-2](https://huggingface.co/IndexTeam/IndexTTS-2) | [IndexTTS-2](https://modelscope.cn/models/IndexTeam/IndexTTS-2) | | [😁 IndexTTS-2](https://huggingface.co/IndexTeam/IndexTTS-2) | [IndexTTS-2](https://modelscope.cn/models/IndexTeam/IndexTTS-2) |
@@ -94,102 +114,189 @@ The key contributions of **indextts2** are summarized as follows:
## Usage Instructions ## Usage Instructions
### Environment Setup
1. Download this repository: ### ⚙️ Environment Setup
1. Ensure that you have both [git](https://git-scm.com/downloads)
and [git-lfs](https://git-lfs.com/) on your system.
The Git-LFS plugin must also be enabled on your current user account:
```bash
git lfs install
```
2. Download this repository:
```bash ```bash
git clone https://github.com/index-tts/index-tts.git && cd index-tts git clone https://github.com/index-tts/index-tts.git && cd index-tts
git lfs pull # fetch example files git lfs pull # download large repository files
``` ```
2. Install the [uv](https://docs.astral.sh/uv/getting-started/installation/) package 3. Install the [uv](https://docs.astral.sh/uv/getting-started/installation/) package
manager. It is *required* for a reliable, modern installation environment. manager. It is *required* for a reliable, modern installation environment.
3. Install dependencies: 4. Install required dependencies:
We use `uv` to manage the project's dependency environment. The following command
will install the correct versions of all dependencies into your `.venv` directory.
We use `uv` to manage the project's dependency environment.
```bash ```bash
uv sync uv sync --all-extras
``` ```
4. Download models: If the download is slow, please try a *local mirror*, for example China:
```bash
uv sync --all-extras --default-index "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple"
```
**Tip:** You can remove the `--all-extras` flag if you don't want to install the WebUI support.
**Important:** If you see an error about CUDA during the installation, please ensure
that you have installed NVIDIA's [CUDA Toolkit](https://developer.nvidia.com/cuda-toolkit)
version 12.8 (or newer) on your system.
5. Download the required models:
Download via `huggingface-cli`: Download via `huggingface-cli`:
```bash ```bash
uv tool install "huggingface_hub[cli]"
hf download IndexTeam/IndexTTS-2 --local-dir=checkpoints hf download IndexTeam/IndexTTS-2 --local-dir=checkpoints
``` ```
Or download via `modelscope` Or download via `modelscope`:
```bash ```bash
uv tool install "modelscope"
modelscope download --model IndexTeam/IndexTTS-2 --local_dir checkpoints modelscope download --model IndexTeam/IndexTTS-2 --local_dir checkpoints
``` ```
>In addition to the above models, some small models will also be automatically downloaded when the project is run for the first time. If your network environment has slow access to HuggingFace, it is recommended to execute command below. <br> > In addition to the above models, some small models will also be automatically
除了以上模型外项目初次运行时还会自动下载一些小模型如果您的网络环境访问HuggingFace的速度较慢推荐执行 > downloaded when the project is run for the first time. If your network environment
>```bash > has slow access to HuggingFace, it is recommended to execute the following
>export HF_ENDPOINT="https://hf-mirror.com" > command before running the code:
>``` >
> 除了以上模型外项目初次运行时还会自动下载一些小模型如果您的网络环境访问HuggingFace的速度较慢推荐执行
>
> ```bash
> export HF_ENDPOINT="https://hf-mirror.com"
> ```
#### 🖥️ Checking PyTorch GPU Acceleration
### IndexTTS2 Quickstart If you need to diagnose your environment to see which GPUs are detected,
you can use our included utility to check your system:
Examples of running scripts with `uv`.
```bash ```bash
PYTHONPATH=$PYTHONPATH:. uv run python indextts/infer_v2.py uv run tools/gpu_check.py
``` ```
1. Synthesize speech with a single reference audio only:
### 🔥 IndexTTS2 Quickstart
#### 🌐 Web Demo
```bash
uv run webui.py
```
Open your browser and visit `http://127.0.0.1:7860` to see the demo.
#### 📝 Using IndexTTS2 in Python
To run scripts, you *must* use the `uv run <file.py>` command to ensure that
the code runs inside your current "uv" environment. It *may* also be necessary
to add the current directory to your `PYTHONPATH`, to help it find the IndexTTS
modules.
Example of running a script via `uv`:
```bash
PYTHONPATH="$PYTHONPATH:." uv run indextts/infer_v2.py
```
Here are several examples of how to use IndexTTS2 in your own scripts:
1. Synthesize new speech with a single reference audio file (voice cloning):
```python ```python
from indextts.infer_v2 import IndexTTS2 from indextts.infer_v2 import IndexTTS2
tts = IndexTTS2(cfg_path="checkpoints/config.yaml", model_dir="checkpoints", is_fp16=False, use_cuda_kernel=False) tts = IndexTTS2(cfg_path="checkpoints/config.yaml", model_dir="checkpoints", use_fp16=False, use_cuda_kernel=False)
text = "Translate for mewhat is a surprise!" text = "Translate for me, what is a surprise!"
tts.infer(spk_audio_prompt='examples/voice_01.wav', text=text, output_path="gen.wav", verbose=True) tts.infer(spk_audio_prompt='examples/voice_01.wav', text=text, output_path="gen.wav", verbose=True)
``` ```
2. Use additional emotional reference audio to condition speech synthesis: 2. Using a separate, emotional reference audio file to condition the speech synthesis:
```python ```python
from indextts.infer_v2 import IndexTTS2 from indextts.infer_v2 import IndexTTS2
tts = IndexTTS2(cfg_path="checkpoints/config.yaml", model_dir="checkpoints", is_fp16=False, use_cuda_kernel=False) tts = IndexTTS2(cfg_path="checkpoints/config.yaml", model_dir="checkpoints", use_fp16=False, use_cuda_kernel=False)
text = "酒楼丧尽天良,开始借机竞拍房间,哎,一群蠢货。" text = "酒楼丧尽天良,开始借机竞拍房间,哎,一群蠢货。"
tts.infer(spk_audio_prompt='examples/voice_07.wav', text=text, output_path="gen.wav", emo_audio_prompt="examples/emo_sad.wav", verbose=True) tts.infer(spk_audio_prompt='examples/voice_07.wav', text=text, output_path="gen.wav", emo_audio_prompt="examples/emo_sad.wav", verbose=True)
``` ```
3. When an emotional reference audio is specified, you can additionally set the `emo_alpha` parameter. Default value is `1.0`: 3. When an emotional reference audio file is specified, you can optionally set
the `emo_alpha` to adjust how much it affects the output.
Valid range is `0.0 - 1.0`, and the default value is `1.0` (100%):
```python ```python
from indextts.infer_v2 import IndexTTS2 from indextts.infer_v2 import IndexTTS2
tts = IndexTTS2(cfg_path="checkpoints/config.yaml", model_dir="checkpoints", is_fp16=False, use_cuda_kernel=False) tts = IndexTTS2(cfg_path="checkpoints/config.yaml", model_dir="checkpoints", use_fp16=False, use_cuda_kernel=False)
text = "酒楼丧尽天良,开始借机竞拍房间,哎,一群蠢货。" text = "酒楼丧尽天良,开始借机竞拍房间,哎,一群蠢货。"
tts.infer(spk_audio_prompt='examples/voice_07.wav', text=text, output_path="gen.wav", emo_audio_prompt="examples/emo_sad.wav", emo_alpha=0.9, verbose=True) tts.infer(spk_audio_prompt='examples/voice_07.wav', text=text, output_path="gen.wav", emo_audio_prompt="examples/emo_sad.wav", emo_alpha=0.9, verbose=True)
``` ```
4. It's also possible to omit the emotional reference audio and instead provide
an 8-float list specifying the intensity of each emotion, in the following order:
`[happy, angry, sad, afraid, disgusted, melancholic, surprised, calm]`.
You can additionally use the `use_random` parameter to introduce stochasticity
during inference; the default is `False`, and setting it to `True` enables
randomness:
4. It's also possible to omit the emotional reference audio and instead provide an 8-float list specifying the intensity of each base emotion (Happy | Angery | Sad | Fear | Hate | Low | Surprise | Neutral). You can additionally control the `use_random` parameter to decide whether to introduce stochasticity during inference; the default is `False`, and setting it to `True` increases randomness:
```python ```python
from indextts.infer_v2 import IndexTTS2 from indextts.infer_v2 import IndexTTS2
tts = IndexTTS2(cfg_path="checkpoints/config.yaml", model_dir="checkpoints", is_fp16=False, use_cuda_kernel=False) tts = IndexTTS2(cfg_path="checkpoints/config.yaml", model_dir="checkpoints", use_fp16=False, use_cuda_kernel=False)
text = "哇塞!这个爆率也太高了!欧皇附体了!" text = "哇塞!这个爆率也太高了!欧皇附体了!"
tts.infer(spk_audio_prompt='examples/voice_10.wav', text=text, output_path="gen.wav", emo_vector=[0, 0, 0, 0, 0, 0, 0.45, 0], use_random=False, verbose=True) tts.infer(spk_audio_prompt='examples/voice_10.wav', text=text, output_path="gen.wav", emo_vector=[0, 0, 0, 0, 0, 0, 0.45, 0], use_random=False, verbose=True)
``` ```
5. Use a text emotion description via `use_emo_text` to guide synthesis. Control randomness with `use_random` (default: False; True adds randomness): 5. Alternatively, you can enable `use_emo_text` to guide the emotions based on
your provided `text` script. Your text script will then automatically
be converted into emotion vectors.
You can introduce randomness with `use_random` (default: `False`;
`True` enables randomness):
```python ```python
from indextts.infer_v2 import IndexTTS2 from indextts.infer_v2 import IndexTTS2
tts = IndexTTS2(cfg_path="checkpoints/config.yaml", model_dir="checkpoints", is_fp16=False, use_cuda_kernel=False) tts = IndexTTS2(cfg_path="checkpoints/config.yaml", model_dir="checkpoints", use_fp16=False, use_cuda_kernel=False)
text = "快躲起来!是他要来了!他要来抓我们了!" text = "快躲起来!是他要来了!他要来抓我们了!"
tts.infer(spk_audio_prompt='examples/voice_12.wav', text=text, output_path="gen.wav", use_emo_text=True, use_random=False, verbose=True) tts.infer(spk_audio_prompt='examples/voice_12.wav', text=text, output_path="gen.wav", use_emo_text=True, use_random=False, verbose=True)
``` ```
6. Without `emo_text`, infer emotion from the synthesis script; with `emo_text`, infer from the provided text. 6. It's also possible to directly provide a specific text emotion description
via the `emo_text` parameter. Your emotion text will then automatically be
converted into emotion vectors. This gives you separate control of the text
script and the text emotion description:
```python ```python
from indextts.infer_v2 import IndexTTS2 from indextts.infer_v2 import IndexTTS2
tts = IndexTTS2(cfg_path="checkpoints/config.yaml", model_dir="checkpoints", is_fp16=False, use_cuda_kernel=False) tts = IndexTTS2(cfg_path="checkpoints/config.yaml", model_dir="checkpoints", use_fp16=False, use_cuda_kernel=False)
text = "快躲起来!是他要来了!他要来抓我们了!" text = "快躲起来!是他要来了!他要来抓我们了!"
emo_text = "你吓死我了!你是鬼吗?" emo_text = "你吓死我了!你是鬼吗?"
tts.infer(spk_audio_prompt='examples/voice_12.wav', text=text, output_path="gen.wav", use_emo_text=True, emo_text=emo_text, use_random=False, verbose=True) tts.infer(spk_audio_prompt='examples/voice_12.wav', text=text, output_path="gen.wav", use_emo_text=True, emo_text=emo_text, use_random=False, verbose=True)
``` ```
### IndexTTS1 User Guide
### Legacy: IndexTTS1 User Guide
You can also use our previous IndexTTS1 model by importing a different module:
```python ```python
from indextts.infer import IndexTTS from indextts.infer import IndexTTS
tts = IndexTTS(model_dir="checkpoints",cfg_path="checkpoints/config.yaml") tts = IndexTTS(model_dir="checkpoints",cfg_path="checkpoints/config.yaml")
@@ -197,36 +304,20 @@ voice = "examples/voice_07.wav"
text = "大家好我现在正在bilibili 体验 ai 科技说实话来之前我绝对想不到AI技术已经发展到这样匪夷所思的地步了比如说现在正在说话的其实是B站为我现场复刻的数字分身简直就是平行宇宙的另一个我了。如果大家也想体验更多深入的AIGC功能可以访问 bilibili studio相信我你们也会吃惊的。" text = "大家好我现在正在bilibili 体验 ai 科技说实话来之前我绝对想不到AI技术已经发展到这样匪夷所思的地步了比如说现在正在说话的其实是B站为我现场复刻的数字分身简直就是平行宇宙的另一个我了。如果大家也想体验更多深入的AIGC功能可以访问 bilibili studio相信我你们也会吃惊的。"
tts.infer(voice, text, 'gen.wav') tts.infer(voice, text, 'gen.wav')
``` ```
For more information, see [README_INDEXTTS_1_5](archive/README_INDEXTTS_1_5.md), or visit the specific version at <a href="https://github.com/index-tts/index-tts/tree/v1.5.0">index-tts:v1.5.0</a>
### Web Demo For more detailed information, see [README_INDEXTTS_1_5](archive/README_INDEXTTS_1_5.md),
```bash or visit the IndexTTS1 repository at <a href="https://github.com/index-tts/index-tts/tree/v1.5.0">index-tts:v1.5.0</a>.
PYTHONPATH=$PYTHONPATH:. uv run webui.py
```
Open your browser and visit `http://127.0.0.1:7860` to see the demo.
### Note for Windows Users
On Windows, you may encounter [an error](https://github.com/index-tts/index-tts/issues/61) when installing `pynini`:
`ERROR: Failed building wheel for pynini`
In this case, please install `pynini` via `conda`:
```bash
# after conda activate index-tts
conda install -c conda-forge pynini==2.1.5
pip install WeTextProcessing==1.0.3
pip install -e ".[webui]"
```
## 👉🏻 IndexTTS 👈🏻 ## Our Releases and Demos
### IndexTTS2: [[Paper]](https://arxiv.org/abs/2506.21619); [[Demo]](https://index-tts.github.io/index-tts2.github.io/) ### IndexTTS2: [[Paper]](https://arxiv.org/abs/2506.21619); [[Demo]](https://index-tts.github.io/index-tts2.github.io/)
### IndexTTS1: [[Paper]](https://arxiv.org/abs/2502.05512); [[Demo]](https://index-tts.github.io/); [[ModelScope]](https://modelscope.cn/studios/IndexTeam/IndexTTS-Demo); [[HuggingFace]](https://huggingface.co/spaces/IndexTeam/IndexTTS) ### IndexTTS1: [[Paper]](https://arxiv.org/abs/2502.05512); [[Demo]](https://index-tts.github.io/); [[ModelScope]](https://modelscope.cn/studios/IndexTeam/IndexTTS-Demo); [[HuggingFace]](https://huggingface.co/spaces/IndexTeam/IndexTTS)
## Acknowledge ## Acknowledgements
1. [tortoise-tts](https://github.com/neonbjb/tortoise-tts) 1. [tortoise-tts](https://github.com/neonbjb/tortoise-tts)
2. [XTTSv2](https://github.com/coqui-ai/TTS) 2. [XTTSv2](https://github.com/coqui-ai/TTS)
3. [BigVGAN](https://github.com/NVIDIA/BigVGAN) 3. [BigVGAN](https://github.com/NVIDIA/BigVGAN)
@@ -241,7 +332,8 @@ pip install -e ".[webui]"
🌟 If you find our work helpful, please leave us a star and cite our paper. 🌟 If you find our work helpful, please leave us a star and cite our paper.
IndexTTS2 IndexTTS2:
``` ```
@article{zhou2025indextts2, @article{zhou2025indextts2,
title={IndexTTS2: A Breakthrough in Emotionally Expressive and Duration-Controlled Auto-Regressive Zero-Shot Text-to-Speech}, title={IndexTTS2: A Breakthrough in Emotionally Expressive and Duration-Controlled Auto-Regressive Zero-Shot Text-to-Speech},
@@ -251,7 +343,9 @@ IndexTTS2
} }
``` ```
IndexTTS
IndexTTS:
``` ```
@article{deng2025indextts, @article{deng2025indextts,
title={IndexTTS: An Industrial-Level Controllable and Efficient Zero-Shot Text-To-Speech System}, title={IndexTTS: An Industrial-Level Controllable and Efficient Zero-Shot Text-To-Speech System},


@@ -12,7 +12,7 @@ def main():
parser.add_argument("-o", "--output_path", type=str, default="gen.wav", help="Path to the output wav file") parser.add_argument("-o", "--output_path", type=str, default="gen.wav", help="Path to the output wav file")
parser.add_argument("-c", "--config", type=str, default="checkpoints/config.yaml", help="Path to the config file. Default is 'checkpoints/config.yaml'") parser.add_argument("-c", "--config", type=str, default="checkpoints/config.yaml", help="Path to the config file. Default is 'checkpoints/config.yaml'")
parser.add_argument("--model_dir", type=str, default="checkpoints", help="Path to the model directory. Default is 'checkpoints'") parser.add_argument("--model_dir", type=str, default="checkpoints", help="Path to the model directory. Default is 'checkpoints'")
parser.add_argument("--fp16", action="store_true", default=True, help="Use FP16 for inference if available") parser.add_argument("--fp16", action="store_true", default=False, help="Use FP16 for inference if available")
parser.add_argument("-f", "--force", action="store_true", default=False, help="Force to overwrite the output file if it exists") parser.add_argument("-f", "--force", action="store_true", default=False, help="Force to overwrite the output file if it exists")
parser.add_argument("-d", "--device", type=str, default=None, help="Device to run the model on (cpu, cuda, mps)." ) parser.add_argument("-d", "--device", type=str, default=None, help="Device to run the model on (cpu, cuda, mps)." )
args = parser.parse_args() args = parser.parse_args()
@@ -54,8 +54,9 @@ def main():
args.fp16 = False # Disable FP16 on CPU args.fp16 = False # Disable FP16 on CPU
print("WARNING: Running on CPU may be slow.") print("WARNING: Running on CPU may be slow.")
# TODO: Add CLI support for IndexTTS2.
from indextts.infer import IndexTTS from indextts.infer import IndexTTS
tts = IndexTTS(cfg_path=args.config, model_dir=args.model_dir, is_fp16=args.fp16, device=args.device) tts = IndexTTS(cfg_path=args.config, model_dir=args.model_dir, use_fp16=args.fp16, device=args.device)
tts.infer(audio_prompt=args.voice, text=args.text.strip(), output_path=output_path) tts.infer(audio_prompt=args.voice, text=args.text.strip(), output_path=output_path)
if __name__ == "__main__": if __name__ == "__main__":


@@ -26,38 +26,38 @@ from indextts.utils.front import TextNormalizer, TextTokenizer
class IndexTTS: class IndexTTS:
def __init__( def __init__(
self, cfg_path="checkpoints/config.yaml", model_dir="checkpoints", is_fp16=True, device=None, self, cfg_path="checkpoints/config.yaml", model_dir="checkpoints", use_fp16=True, device=None,
use_cuda_kernel=None, use_cuda_kernel=None,
): ):
""" """
Args: Args:
cfg_path (str): path to the config file. cfg_path (str): path to the config file.
model_dir (str): path to the model directory. model_dir (str): path to the model directory.
is_fp16 (bool): whether to use fp16. use_fp16 (bool): whether to use fp16.
device (str): device to use (e.g., 'cuda:0', 'cpu'). If None, it will be set automatically based on the availability of CUDA or MPS. device (str): device to use (e.g., 'cuda:0', 'cpu'). If None, it will be set automatically based on the availability of CUDA or MPS.
use_cuda_kernel (None | bool): whether to use BigVGan custom fused activation CUDA kernel, only for CUDA device. use_cuda_kernel (None | bool): whether to use BigVGan custom fused activation CUDA kernel, only for CUDA device.
""" """
if device is not None: if device is not None:
self.device = device self.device = device
self.is_fp16 = False if device == "cpu" else is_fp16 self.use_fp16 = False if device == "cpu" else use_fp16
self.use_cuda_kernel = use_cuda_kernel is not None and use_cuda_kernel and device.startswith("cuda") self.use_cuda_kernel = use_cuda_kernel is not None and use_cuda_kernel and device.startswith("cuda")
elif torch.cuda.is_available(): elif torch.cuda.is_available():
self.device = "cuda:0" self.device = "cuda:0"
self.is_fp16 = is_fp16 self.use_fp16 = use_fp16
self.use_cuda_kernel = use_cuda_kernel is None or use_cuda_kernel self.use_cuda_kernel = use_cuda_kernel is None or use_cuda_kernel
elif hasattr(torch, "mps") and torch.backends.mps.is_available(): elif hasattr(torch, "mps") and torch.backends.mps.is_available():
self.device = "mps" self.device = "mps"
self.is_fp16 = False # Use float16 on MPS is overhead than float32 self.use_fp16 = False # Use float16 on MPS is overhead than float32
self.use_cuda_kernel = False self.use_cuda_kernel = False
else: else:
self.device = "cpu" self.device = "cpu"
self.is_fp16 = False self.use_fp16 = False
self.use_cuda_kernel = False self.use_cuda_kernel = False
print(">> Be patient, it may take a while to run in CPU mode.") print(">> Be patient, it may take a while to run in CPU mode.")
self.cfg = OmegaConf.load(cfg_path) self.cfg = OmegaConf.load(cfg_path)
self.model_dir = model_dir self.model_dir = model_dir
self.dtype = torch.float16 if self.is_fp16 else None self.dtype = torch.float16 if self.use_fp16 else None
self.stop_mel_token = self.cfg.gpt.stop_mel_token self.stop_mel_token = self.cfg.gpt.stop_mel_token
# Comment-off to load the VQ-VAE model for debugging tokenizer # Comment-off to load the VQ-VAE model for debugging tokenizer
@@ -68,7 +68,7 @@ class IndexTTS:
# self.dvae_path = os.path.join(self.model_dir, self.cfg.dvae_checkpoint) # self.dvae_path = os.path.join(self.model_dir, self.cfg.dvae_checkpoint)
# load_checkpoint(self.dvae, self.dvae_path) # load_checkpoint(self.dvae, self.dvae_path)
# self.dvae = self.dvae.to(self.device) # self.dvae = self.dvae.to(self.device)
# if self.is_fp16: # if self.use_fp16:
# self.dvae.eval().half() # self.dvae.eval().half()
# else: # else:
# self.dvae.eval() # self.dvae.eval()
@@ -77,12 +77,12 @@ class IndexTTS:
self.gpt_path = os.path.join(self.model_dir, self.cfg.gpt_checkpoint) self.gpt_path = os.path.join(self.model_dir, self.cfg.gpt_checkpoint)
load_checkpoint(self.gpt, self.gpt_path) load_checkpoint(self.gpt, self.gpt_path)
self.gpt = self.gpt.to(self.device) self.gpt = self.gpt.to(self.device)
if self.is_fp16: if self.use_fp16:
self.gpt.eval().half() self.gpt.eval().half()
else: else:
self.gpt.eval() self.gpt.eval()
print(">> GPT weights restored from:", self.gpt_path) print(">> GPT weights restored from:", self.gpt_path)
if self.is_fp16: if self.use_fp16:
try: try:
import deepspeed import deepspeed
@@ -184,17 +184,17 @@ class IndexTTS:
code_lens = torch.tensor(code_lens, dtype=torch.long, device=device) code_lens = torch.tensor(code_lens, dtype=torch.long, device=device)
return codes, code_lens return codes, code_lens
def bucket_sentences(self, sentences, bucket_max_size=4) -> List[List[Dict]]: def bucket_segments(self, segments, bucket_max_size=4) -> List[List[Dict]]:
""" """
Sentence data bucketing. Segment data bucketing.
if ``bucket_max_size=1``, return all sentences in one bucket. if ``bucket_max_size=1``, return all segments in one bucket.
""" """
outputs: List[Dict] = [] outputs: List[Dict] = []
for idx, sent in enumerate(sentences): for idx, sent in enumerate(segments):
outputs.append({"idx": idx, "sent": sent, "len": len(sent)}) outputs.append({"idx": idx, "sent": sent, "len": len(sent)})
if len(outputs) > bucket_max_size: if len(outputs) > bucket_max_size:
# split sentences into buckets by sentence length # split segments into buckets by segment length
buckets: List[List[Dict]] = [] buckets: List[List[Dict]] = []
factor = 1.5 factor = 1.5
last_bucket = None last_bucket = None
@@ -203,7 +203,7 @@ class IndexTTS:
for sent in sorted(outputs, key=lambda x: x["len"]): for sent in sorted(outputs, key=lambda x: x["len"]):
current_sent_len = sent["len"] current_sent_len = sent["len"]
if current_sent_len == 0: if current_sent_len == 0:
print(">> skip empty sentence") print(">> skip empty segment")
continue continue
if last_bucket is None \ if last_bucket is None \
or current_sent_len >= int(last_bucket_sent_len_median * factor) \ or current_sent_len >= int(last_bucket_sent_len_median * factor) \
@@ -213,7 +213,7 @@ class IndexTTS:
last_bucket = buckets[-1] last_bucket = buckets[-1]
last_bucket_sent_len_median = current_sent_len last_bucket_sent_len_median = current_sent_len
else: else:
# current bucket can hold more sentences # current bucket can hold more segments
last_bucket.append(sent) # sorted last_bucket.append(sent) # sorted
mid = len(last_bucket) // 2 mid = len(last_bucket) // 2
last_bucket_sent_len_median = last_bucket[mid]["len"] last_bucket_sent_len_median = last_bucket[mid]["len"]
@@ -276,14 +276,14 @@ class IndexTTS:
self.gr_progress(value, desc=desc) self.gr_progress(value, desc=desc)
# 快速推理:对于“多句长文本”,可实现至少 2~10 倍以上的速度提升~ First modified by sunnyboxs 2025-04-16 # 快速推理:对于“多句长文本”,可实现至少 2~10 倍以上的速度提升~ First modified by sunnyboxs 2025-04-16
def infer_fast(self, audio_prompt, text, output_path, verbose=False, max_text_tokens_per_sentence=100, def infer_fast(self, audio_prompt, text, output_path, verbose=False, max_text_tokens_per_segment=100,
sentences_bucket_max_size=4, **generation_kwargs): segments_bucket_max_size=4, **generation_kwargs):
""" """
Args: Args:
``max_text_tokens_per_sentence``: 分句的最大token数默认``100``可以根据GPU硬件情况调整 ``max_text_tokens_per_segment``: 分句的最大token数默认``100``可以根据GPU硬件情况调整
- 越小batch 越多,推理速度越*快*,占用内存更多,可能影响质量 - 越小batch 越多,推理速度越*快*,占用内存更多,可能影响质量
- 越大batch 越少,推理速度越*慢*,占用内存和质量更接近于非快速推理 - 越大batch 越少,推理速度越*慢*,占用内存和质量更接近于非快速推理
``sentences_bucket_max_size``: 分句分桶的最大容量,默认``4``可以根据GPU内存调整 ``segments_bucket_max_size``: 分句分桶的最大容量,默认``4``可以根据GPU内存调整
- 越大bucket数量越少batch越多推理速度越*快*,占用内存更多,可能影响质量 - 越大bucket数量越少batch越多推理速度越*快*,占用内存更多,可能影响质量
- 越小bucket数量越多batch越少推理速度越*慢*,占用内存和质量更接近于非快速推理 - 越小bucket数量越多batch越少推理速度越*慢*,占用内存和质量更接近于非快速推理
""" """
@@ -319,13 +319,13 @@ class IndexTTS:
# text_tokens # text_tokens
text_tokens_list = self.tokenizer.tokenize(text) text_tokens_list = self.tokenizer.tokenize(text)
sentences = self.tokenizer.split_sentences(text_tokens_list, segments = self.tokenizer.split_segments(text_tokens_list,
max_tokens_per_sentence=max_text_tokens_per_sentence) max_text_tokens_per_segment=max_text_tokens_per_segment)
if verbose: if verbose:
print(">> text token count:", len(text_tokens_list)) print(">> text token count:", len(text_tokens_list))
print(" splited sentences count:", len(sentences)) print(" segments count:", len(segments))
print(" max_text_tokens_per_sentence:", max_text_tokens_per_sentence) print(" max_text_tokens_per_segment:", max_text_tokens_per_segment)
print(*sentences, sep="\n") print(*segments, sep="\n")
do_sample = generation_kwargs.pop("do_sample", True) do_sample = generation_kwargs.pop("do_sample", True)
top_p = generation_kwargs.pop("top_p", 0.8) top_p = generation_kwargs.pop("top_p", 0.8)
top_k = generation_kwargs.pop("top_k", 30) top_k = generation_kwargs.pop("top_k", 30)
@@ -346,17 +346,17 @@ class IndexTTS:
# text processing # text processing
all_text_tokens: List[List[torch.Tensor]] = [] all_text_tokens: List[List[torch.Tensor]] = []
self._set_gr_progress(0.1, "text processing...") self._set_gr_progress(0.1, "text processing...")
bucket_max_size = sentences_bucket_max_size if self.device != "cpu" else 1 bucket_max_size = segments_bucket_max_size if self.device != "cpu" else 1
all_sentences = self.bucket_sentences(sentences, bucket_max_size=bucket_max_size) all_segments = self.bucket_segments(segments, bucket_max_size=bucket_max_size)
bucket_count = len(all_sentences) bucket_count = len(all_segments)
if verbose: if verbose:
print(">> sentences bucket_count:", bucket_count, print(">> segments bucket_count:", bucket_count,
"bucket sizes:", [(len(s), [t["idx"] for t in s]) for s in all_sentences], "bucket sizes:", [(len(s), [t["idx"] for t in s]) for s in all_segments],
"bucket_max_size:", bucket_max_size) "bucket_max_size:", bucket_max_size)
for sentences in all_sentences: for segments in all_segments:
temp_tokens: List[torch.Tensor] = [] temp_tokens: List[torch.Tensor] = []
all_text_tokens.append(temp_tokens) all_text_tokens.append(temp_tokens)
for item in sentences: for item in segments:
sent = item["sent"] sent = item["sent"]
text_tokens = self.tokenizer.convert_tokens_to_ids(sent) text_tokens = self.tokenizer.convert_tokens_to_ids(sent)
text_tokens = torch.tensor(text_tokens, dtype=torch.int32, device=self.device).unsqueeze(0) text_tokens = torch.tensor(text_tokens, dtype=torch.int32, device=self.device).unsqueeze(0)
@@ -365,11 +365,11 @@ class IndexTTS:
print(f"text_tokens shape: {text_tokens.shape}, text_tokens type: {text_tokens.dtype}") print(f"text_tokens shape: {text_tokens.shape}, text_tokens type: {text_tokens.dtype}")
# debug tokenizer # debug tokenizer
text_token_syms = self.tokenizer.convert_ids_to_tokens(text_tokens[0].tolist()) text_token_syms = self.tokenizer.convert_ids_to_tokens(text_tokens[0].tolist())
print("text_token_syms is same as sentence tokens", text_token_syms == sent) print("text_token_syms is same as segment tokens", text_token_syms == sent)
temp_tokens.append(text_tokens) temp_tokens.append(text_tokens)
# Sequential processing of bucketing data # Sequential processing of bucketing data
all_batch_num = sum(len(s) for s in all_sentences) all_batch_num = sum(len(s) for s in all_segments)
all_batch_codes = [] all_batch_codes = []
processed_num = 0 processed_num = 0
for item_tokens in all_text_tokens: for item_tokens in all_text_tokens:
@@ -407,13 +407,13 @@ class IndexTTS:
all_idxs = [] all_idxs = []
all_latents = [] all_latents = []
has_warned = False has_warned = False
for batch_codes, batch_tokens, batch_sentences in zip(all_batch_codes, all_text_tokens, all_sentences): for batch_codes, batch_tokens, batch_segments in zip(all_batch_codes, all_text_tokens, all_segments):
for i in range(batch_codes.shape[0]): for i in range(batch_codes.shape[0]):
codes = batch_codes[i] # [x] codes = batch_codes[i] # [x]
if not has_warned and codes[-1] != self.stop_mel_token: if not has_warned and codes[-1] != self.stop_mel_token:
warnings.warn( warnings.warn(
f"WARN: generation stopped due to exceeding `max_mel_tokens` ({max_mel_tokens}). " f"WARN: generation stopped due to exceeding `max_mel_tokens` ({max_mel_tokens}). "
f"Consider reducing `max_text_tokens_per_sentence`({max_text_tokens_per_sentence}) or increasing `max_mel_tokens`.", f"Consider reducing `max_text_tokens_per_segment`({max_text_tokens_per_segment}) or increasing `max_mel_tokens`.",
category=RuntimeWarning category=RuntimeWarning
) )
has_warned = True has_warned = True
@@ -427,7 +427,7 @@ class IndexTTS:
print(codes) print(codes)
print("code_lens:", code_lens) print("code_lens:", code_lens)
text_tokens = batch_tokens[i] text_tokens = batch_tokens[i]
all_idxs.append(batch_sentences[i]["idx"]) all_idxs.append(batch_segments[i]["idx"])
m_start_time = time.perf_counter() m_start_time = time.perf_counter()
with torch.no_grad(): with torch.no_grad():
with torch.amp.autocast(text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype): with torch.amp.autocast(text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype):
@@ -440,7 +440,7 @@ class IndexTTS:
return_latent=True, clip_inputs=False) return_latent=True, clip_inputs=False)
gpt_forward_time += time.perf_counter() - m_start_time gpt_forward_time += time.perf_counter() - m_start_time
all_latents.append(latent) all_latents.append(latent)
del all_batch_codes, all_text_tokens, all_sentences del all_batch_codes, all_text_tokens, all_segments
# bigvgan chunk # bigvgan chunk
chunk_size = 2 chunk_size = 2
all_latents = [all_latents[all_idxs.index(i)] for i in range(len(all_latents))] all_latents = [all_latents[all_idxs.index(i)] for i in range(len(all_latents))]
@@ -503,7 +503,7 @@ class IndexTTS:
return (sampling_rate, wav_data) return (sampling_rate, wav_data)
# 原始推理模式 # 原始推理模式
def infer(self, audio_prompt, text, output_path, verbose=False, max_text_tokens_per_sentence=120, def infer(self, audio_prompt, text, output_path, verbose=False, max_text_tokens_per_segment=120,
**generation_kwargs): **generation_kwargs):
print(">> start inference...") print(">> start inference...")
self._set_gr_progress(0, "start inference...") self._set_gr_progress(0, "start inference...")
@@ -533,12 +533,12 @@ class IndexTTS:
self._set_gr_progress(0.1, "text processing...") self._set_gr_progress(0.1, "text processing...")
auto_conditioning = cond_mel auto_conditioning = cond_mel
text_tokens_list = self.tokenizer.tokenize(text) text_tokens_list = self.tokenizer.tokenize(text)
sentences = self.tokenizer.split_sentences(text_tokens_list, max_text_tokens_per_sentence) segments = self.tokenizer.split_segments(text_tokens_list, max_text_tokens_per_segment)
if verbose: if verbose:
print("text token count:", len(text_tokens_list)) print("text token count:", len(text_tokens_list))
print("sentences count:", len(sentences)) print("segments count:", len(segments))
print("max_text_tokens_per_sentence:", max_text_tokens_per_sentence) print("max_text_tokens_per_segment:", max_text_tokens_per_segment)
print(*sentences, sep="\n") print(*segments, sep="\n")
do_sample = generation_kwargs.pop("do_sample", True) do_sample = generation_kwargs.pop("do_sample", True)
top_p = generation_kwargs.pop("top_p", 0.8) top_p = generation_kwargs.pop("top_p", 0.8)
top_k = generation_kwargs.pop("top_k", 30) top_k = generation_kwargs.pop("top_k", 30)
@@ -557,7 +557,7 @@ class IndexTTS:
bigvgan_time = 0 bigvgan_time = 0
progress = 0 progress = 0
has_warned = False has_warned = False
for sent in sentences: for sent in segments:
text_tokens = self.tokenizer.convert_tokens_to_ids(sent) text_tokens = self.tokenizer.convert_tokens_to_ids(sent)
text_tokens = torch.tensor(text_tokens, dtype=torch.int32, device=self.device).unsqueeze(0) text_tokens = torch.tensor(text_tokens, dtype=torch.int32, device=self.device).unsqueeze(0)
# text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary. # text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary.
@@ -568,13 +568,13 @@ class IndexTTS:
print(f"text_tokens shape: {text_tokens.shape}, text_tokens type: {text_tokens.dtype}") print(f"text_tokens shape: {text_tokens.shape}, text_tokens type: {text_tokens.dtype}")
# debug tokenizer # debug tokenizer
text_token_syms = self.tokenizer.convert_ids_to_tokens(text_tokens[0].tolist()) text_token_syms = self.tokenizer.convert_ids_to_tokens(text_tokens[0].tolist())
print("text_token_syms is same as sentence tokens", text_token_syms == sent) print("text_token_syms is same as segment tokens", text_token_syms == sent)
# text_len = torch.IntTensor([text_tokens.size(1)], device=text_tokens.device) # text_len = torch.IntTensor([text_tokens.size(1)], device=text_tokens.device)
# print(text_len) # print(text_len)
progress += 1 progress += 1
self._set_gr_progress(0.2 + 0.4 * (progress - 1) / len(sentences), self._set_gr_progress(0.2 + 0.4 * (progress - 1) / len(segments),
f"gpt inference latent... {progress}/{len(sentences)}") f"gpt inference latent... {progress}/{len(segments)}")
m_start_time = time.perf_counter() m_start_time = time.perf_counter()
with torch.no_grad(): with torch.no_grad():
with torch.amp.autocast(text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype): with torch.amp.autocast(text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype):
@@ -597,7 +597,7 @@ class IndexTTS:
warnings.warn( warnings.warn(
f"WARN: generation stopped due to exceeding `max_mel_tokens` ({max_mel_tokens}). " f"WARN: generation stopped due to exceeding `max_mel_tokens` ({max_mel_tokens}). "
f"Input text tokens: {text_tokens.shape[1]}. " f"Input text tokens: {text_tokens.shape[1]}. "
f"Consider reducing `max_text_tokens_per_sentence`({max_text_tokens_per_sentence}) or increasing `max_mel_tokens`.", f"Consider reducing `max_text_tokens_per_segment`({max_text_tokens_per_segment}) or increasing `max_mel_tokens`.",
category=RuntimeWarning category=RuntimeWarning
) )
has_warned = True has_warned = True
@@ -615,8 +615,8 @@ class IndexTTS:
print(codes, type(codes)) print(codes, type(codes))
print(f"fix codes shape: {codes.shape}, codes type: {codes.dtype}") print(f"fix codes shape: {codes.shape}, codes type: {codes.dtype}")
print(f"code len: {code_lens}") print(f"code len: {code_lens}")
self._set_gr_progress(0.2 + 0.4 * progress / len(sentences), self._set_gr_progress(0.2 + 0.4 * progress / len(segments),
f"gpt inference speech... {progress}/{len(sentences)}") f"gpt inference speech... {progress}/{len(segments)}")
m_start_time = time.perf_counter() m_start_time = time.perf_counter()
# latent, text_lens_out, code_lens_out = \ # latent, text_lens_out, code_lens_out = \
with torch.amp.autocast(text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype): with torch.amp.autocast(text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype):


@@ -37,38 +37,38 @@ import torch.nn.functional as F
class IndexTTS2: class IndexTTS2:
def __init__( def __init__(
self, cfg_path="checkpoints/config.yaml", model_dir="checkpoints", is_fp16=False, device=None, self, cfg_path="checkpoints/config.yaml", model_dir="checkpoints", use_fp16=False, device=None,
use_cuda_kernel=None, use_cuda_kernel=None,
): ):
""" """
Args: Args:
cfg_path (str): path to the config file. cfg_path (str): path to the config file.
model_dir (str): path to the model directory. model_dir (str): path to the model directory.
is_fp16 (bool): whether to use fp16. use_fp16 (bool): whether to use fp16.
device (str): device to use (e.g., 'cuda:0', 'cpu'). If None, it will be set automatically based on the availability of CUDA or MPS. device (str): device to use (e.g., 'cuda:0', 'cpu'). If None, it will be set automatically based on the availability of CUDA or MPS.
use_cuda_kernel (None | bool): whether to use BigVGan custom fused activation CUDA kernel, only for CUDA device. use_cuda_kernel (None | bool): whether to use BigVGan custom fused activation CUDA kernel, only for CUDA device.
""" """
if device is not None: if device is not None:
self.device = device self.device = device
self.is_fp16 = False if device == "cpu" else is_fp16 self.use_fp16 = False if device == "cpu" else use_fp16
self.use_cuda_kernel = use_cuda_kernel is not None and use_cuda_kernel and device.startswith("cuda") self.use_cuda_kernel = use_cuda_kernel is not None and use_cuda_kernel and device.startswith("cuda")
elif torch.cuda.is_available(): elif torch.cuda.is_available():
self.device = "cuda:0" self.device = "cuda:0"
self.is_fp16 = is_fp16 self.use_fp16 = use_fp16
self.use_cuda_kernel = use_cuda_kernel is None or use_cuda_kernel self.use_cuda_kernel = use_cuda_kernel is None or use_cuda_kernel
elif hasattr(torch, "mps") and torch.backends.mps.is_available(): elif hasattr(torch, "mps") and torch.backends.mps.is_available():
self.device = "mps" self.device = "mps"
self.is_fp16 = False # Use float16 on MPS is overhead than float32 self.use_fp16 = False # Use float16 on MPS is overhead than float32
self.use_cuda_kernel = False self.use_cuda_kernel = False
else: else:
self.device = "cpu" self.device = "cpu"
self.is_fp16 = False self.use_fp16 = False
self.use_cuda_kernel = False self.use_cuda_kernel = False
print(">> Be patient, it may take a while to run in CPU mode.") print(">> Be patient, it may take a while to run in CPU mode.")
self.cfg = OmegaConf.load(cfg_path) self.cfg = OmegaConf.load(cfg_path)
self.model_dir = model_dir self.model_dir = model_dir
self.dtype = torch.float16 if self.is_fp16 else None self.dtype = torch.float16 if self.use_fp16 else None
self.stop_mel_token = self.cfg.gpt.stop_mel_token self.stop_mel_token = self.cfg.gpt.stop_mel_token
self.qwen_emo = QwenEmotion(os.path.join(self.model_dir, self.cfg.qwen_emo_path)) self.qwen_emo = QwenEmotion(os.path.join(self.model_dir, self.cfg.qwen_emo_path))
@@ -77,7 +77,7 @@ class IndexTTS2:
self.gpt_path = os.path.join(self.model_dir, self.cfg.gpt_checkpoint) self.gpt_path = os.path.join(self.model_dir, self.cfg.gpt_checkpoint)
load_checkpoint(self.gpt, self.gpt_path) load_checkpoint(self.gpt, self.gpt_path)
self.gpt = self.gpt.to(self.device) self.gpt = self.gpt.to(self.device)
if self.is_fp16: if self.use_fp16:
self.gpt.eval().half() self.gpt.eval().half()
else: else:
self.gpt.eval() self.gpt.eval()
@@ -90,7 +90,7 @@ class IndexTTS2:
use_deepspeed = False use_deepspeed = False
print(f">> DeepSpeed加载失败回退到标准推理: {e}") print(f">> DeepSpeed加载失败回退到标准推理: {e}")
self.gpt.post_init_gpt2_config(use_deepspeed=use_deepspeed, kv_cache=True, half=self.is_fp16) self.gpt.post_init_gpt2_config(use_deepspeed=use_deepspeed, kv_cache=True, half=self.use_fp16)
if self.use_cuda_kernel: if self.use_cuda_kernel:
# preload the CUDA kernel for BigVGAN # preload the CUDA kernel for BigVGAN
@@ -262,7 +262,7 @@ class IndexTTS2:
def insert_interval_silence(self, wavs, sampling_rate=22050, interval_silence=200): def insert_interval_silence(self, wavs, sampling_rate=22050, interval_silence=200):
""" """
Insert silences between sentences. Insert silences between generated segments.
wavs: List[torch.tensor] wavs: List[torch.tensor]
""" """
@@ -292,7 +292,7 @@ class IndexTTS2:
emo_audio_prompt=None, emo_alpha=1.0, emo_audio_prompt=None, emo_alpha=1.0,
emo_vector=None, emo_vector=None,
use_emo_text=False, emo_text=None, use_random=False, interval_silence=200, use_emo_text=False, emo_text=None, use_random=False, interval_silence=200,
verbose=False, max_text_tokens_per_sentence=120, **generation_kwargs): verbose=False, max_text_tokens_per_segment=120, **generation_kwargs):
print(">> start inference...") print(">> start inference...")
self._set_gr_progress(0, "start inference...") self._set_gr_progress(0, "start inference...")
if verbose: if verbose:
@@ -394,12 +394,12 @@ class IndexTTS2:
self._set_gr_progress(0.1, "text processing...") self._set_gr_progress(0.1, "text processing...")
text_tokens_list = self.tokenizer.tokenize(text) text_tokens_list = self.tokenizer.tokenize(text)
sentences = self.tokenizer.split_sentences(text_tokens_list, max_text_tokens_per_sentence) segments = self.tokenizer.split_segments(text_tokens_list, max_text_tokens_per_segment)
if verbose: if verbose:
print("text_tokens_list:", text_tokens_list) print("text_tokens_list:", text_tokens_list)
print("sentences count:", len(sentences)) print("segments count:", len(segments))
print("max_text_tokens_per_sentence:", max_text_tokens_per_sentence) print("max_text_tokens_per_segment:", max_text_tokens_per_segment)
print(*sentences, sep="\n") print(*segments, sep="\n")
do_sample = generation_kwargs.pop("do_sample", True) do_sample = generation_kwargs.pop("do_sample", True)
top_p = generation_kwargs.pop("top_p", 0.8) top_p = generation_kwargs.pop("top_p", 0.8)
top_k = generation_kwargs.pop("top_k", 30) top_k = generation_kwargs.pop("top_k", 30)
@@ -418,7 +418,7 @@ class IndexTTS2:
bigvgan_time = 0 bigvgan_time = 0
progress = 0 progress = 0
has_warned = False has_warned = False
for sent in sentences: for sent in segments:
text_tokens = self.tokenizer.convert_tokens_to_ids(sent) text_tokens = self.tokenizer.convert_tokens_to_ids(sent)
text_tokens = torch.tensor(text_tokens, dtype=torch.int32, device=self.device).unsqueeze(0) text_tokens = torch.tensor(text_tokens, dtype=torch.int32, device=self.device).unsqueeze(0)
if verbose: if verbose:
@@ -426,7 +426,7 @@ class IndexTTS2:
print(f"text_tokens shape: {text_tokens.shape}, text_tokens type: {text_tokens.dtype}") print(f"text_tokens shape: {text_tokens.shape}, text_tokens type: {text_tokens.dtype}")
# debug tokenizer # debug tokenizer
text_token_syms = self.tokenizer.convert_ids_to_tokens(text_tokens[0].tolist()) text_token_syms = self.tokenizer.convert_ids_to_tokens(text_tokens[0].tolist())
print("text_token_syms is same as sentence tokens", text_token_syms == sent) print("text_token_syms is same as segment tokens", text_token_syms == sent)
m_start_time = time.perf_counter() m_start_time = time.perf_counter()
with torch.no_grad(): with torch.no_grad():
@@ -467,7 +467,7 @@ class IndexTTS2:
warnings.warn( warnings.warn(
f"WARN: generation stopped due to exceeding `max_mel_tokens` ({max_mel_tokens}). " f"WARN: generation stopped due to exceeding `max_mel_tokens` ({max_mel_tokens}). "
f"Input text tokens: {text_tokens.shape[1]}. " f"Input text tokens: {text_tokens.shape[1]}. "
f"Consider reducing `max_text_tokens_per_sentence`({max_text_tokens_per_sentence}) or increasing `max_mel_tokens`.", f"Consider reducing `max_text_tokens_per_segment`({max_text_tokens_per_segment}) or increasing `max_mel_tokens`.",
category=RuntimeWarning category=RuntimeWarning
) )
has_warned = True has_warned = True
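
For orientation, here is a minimal usage sketch of the renamed inference API, using only the constructor and `infer()` parameters visible in this diff; the checkpoint directory and prompt paths are placeholders:

```python
from indextts.infer_v2 import IndexTTS2

# Placeholder paths; point these at your own checkpoint directory and voice prompt.
tts = IndexTTS2(
    model_dir="checkpoints",
    cfg_path="checkpoints/config.yaml",
    use_fp16=False,  # renamed from `is_fp16`
)

tts.infer(
    spk_audio_prompt="examples/voice_prompt.wav",  # placeholder voice reference
    text="Hello! This is a short segment-splitting test.",
    output_path="outputs/demo.wav",
    verbose=True,
    max_text_tokens_per_segment=120,  # renamed from `max_text_tokens_per_sentence`
)
```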


@@ -63,9 +63,9 @@ class BaseSpeakerTTS(OpenVoiceBaseClass):
return audio_segments return audio_segments
@staticmethod @staticmethod
def split_sentences_into_pieces(text, language_str): def split_segments_into_pieces(text, language_str):
texts = utils.split_sentence(text, language_str=language_str) texts = utils.split_segment(text, language_str=language_str)
print(" > Text splitted to sentences.") print(" > Text split into segments.")
print('\n'.join(texts)) print('\n'.join(texts))
print(" > ===========================") print(" > ===========================")
return texts return texts
@@ -74,7 +74,7 @@ class BaseSpeakerTTS(OpenVoiceBaseClass):
mark = self.language_marks.get(language.lower(), None) mark = self.language_marks.get(language.lower(), None)
assert mark is not None, f"language {language} is not supported" assert mark is not None, f"language {language} is not supported"
texts = self.split_sentences_into_pieces(text, mark) texts = self.split_segments_into_pieces(text, mark)
audio_list = [] audio_list = []
for t in texts: for t in texts:


@@ -233,7 +233,7 @@ with gr.Blocks(analytics_enabled=False) as demo:
with gr.Column(): with gr.Column():
input_text_gr = gr.Textbox( input_text_gr = gr.Textbox(
label="Text Prompt", label="Text Prompt",
info="One or two sentences at a time is better. Up to 200 text characters.", info="One or two sentences at a time produces the best results. Up to 200 text characters.",
value="He hoped there would be stew for dinner, turnips and carrots and bruised potatoes and fat mutton pieces to be ladled out in thick, peppered, flour-fattened sauce.", value="He hoped there would be stew for dinner, turnips and carrots and bruised potatoes and fat mutton pieces to be ladled out in thick, peppered, flour-fattened sauce.",
) )
style_gr = gr.Dropdown( style_gr = gr.Dropdown(


@@ -75,23 +75,23 @@ def bits_to_string(bits_array):
return output_string return output_string
def split_sentence(text, min_len=10, language_str='[EN]'): def split_segment(text, min_len=10, language_str='[EN]'):
if language_str in ['EN']: if language_str in ['EN']:
sentences = split_sentences_latin(text, min_len=min_len) segments = split_segments_latin(text, min_len=min_len)
else: else:
sentences = split_sentences_zh(text, min_len=min_len) segments = split_segments_zh(text, min_len=min_len)
return sentences return segments
def split_sentences_latin(text, min_len=10): def split_segments_latin(text, min_len=10):
"""Split Long sentences into list of short ones """Split Long sentences into list of short segments.
Args: Args:
str: Input sentences. str: Input sentences.
Returns: Returns:
List[str]: list of output sentences. List[str]: list of output segments.
""" """
# deal with dirty sentences # deal with dirty text characters
text = re.sub('[。!?;]', '.', text) text = re.sub('[。!?;]', '.', text)
text = re.sub('[]', ',', text) text = re.sub('[]', ',', text)
text = re.sub('[“”]', '"', text) text = re.sub('[“”]', '"', text)
@@ -100,36 +100,36 @@ def split_sentences_latin(text, min_len=10):
text = re.sub('[\n\t ]+', ' ', text) text = re.sub('[\n\t ]+', ' ', text)
text = re.sub('([,.!?;])', r'\1 $#!', text) text = re.sub('([,.!?;])', r'\1 $#!', text)
# split # split
sentences = [s.strip() for s in text.split('$#!')] segments = [s.strip() for s in text.split('$#!')]
if len(sentences[-1]) == 0: del sentences[-1] if len(segments[-1]) == 0: del segments[-1]
new_sentences = [] new_segments = []
new_sent = [] new_sent = []
count_len = 0 count_len = 0
for ind, sent in enumerate(sentences): for ind, sent in enumerate(segments):
# print(sent) # print(sent)
new_sent.append(sent) new_sent.append(sent)
count_len += len(sent.split(" ")) count_len += len(sent.split(" "))
if count_len > min_len or ind == len(sentences) - 1: if count_len > min_len or ind == len(segments) - 1:
count_len = 0 count_len = 0
new_sentences.append(' '.join(new_sent)) new_segments.append(' '.join(new_sent))
new_sent = [] new_sent = []
return merge_short_sentences_latin(new_sentences) return merge_short_segments_latin(new_segments)
def merge_short_sentences_latin(sens): def merge_short_segments_latin(sens):
"""Avoid short sentences by merging them with the following sentence. """Avoid short segments by merging them with the following segment.
Args: Args:
List[str]: list of input sentences. List[str]: list of input segments.
Returns: Returns:
List[str]: list of output sentences. List[str]: list of output segments.
""" """
sens_out = [] sens_out = []
for s in sens: for s in sens:
# If the previous sentence is too short, merge them with # If the previous segment is too short, merge them with
# the current sentence. # the current segment.
if len(sens_out) > 0 and len(sens_out[-1].split(" ")) <= 2: if len(sens_out) > 0 and len(sens_out[-1].split(" ")) <= 2:
sens_out[-1] = sens_out[-1] + " " + s sens_out[-1] = sens_out[-1] + " " + s
else: else:
@@ -142,7 +142,7 @@ def merge_short_sentences_latin(sens):
pass pass
return sens_out return sens_out
def split_sentences_zh(text, min_len=10): def split_segments_zh(text, min_len=10):
text = re.sub('[。!?;]', '.', text) text = re.sub('[。!?;]', '.', text)
text = re.sub('[]', ',', text) text = re.sub('[]', ',', text)
# 将文本中的换行符、空格和制表符替换为空格 # 将文本中的换行符、空格和制表符替换为空格
@@ -150,37 +150,37 @@ def split_sentences_zh(text, min_len=10):
# 在标点符号后添加一个空格 # 在标点符号后添加一个空格
text = re.sub('([,.!?;])', r'\1 $#!', text) text = re.sub('([,.!?;])', r'\1 $#!', text)
# 分隔句子并去除前后空格 # 分隔句子并去除前后空格
# sentences = [s.strip() for s in re.split('(。|||)', text)] # segments = [s.strip() for s in re.split('(。|||)', text)]
sentences = [s.strip() for s in text.split('$#!')] segments = [s.strip() for s in text.split('$#!')]
if len(sentences[-1]) == 0: del sentences[-1] if len(segments[-1]) == 0: del segments[-1]
new_sentences = [] new_segments = []
new_sent = [] new_sent = []
count_len = 0 count_len = 0
for ind, sent in enumerate(sentences): for ind, sent in enumerate(segments):
new_sent.append(sent) new_sent.append(sent)
count_len += len(sent) count_len += len(sent)
if count_len > min_len or ind == len(sentences) - 1: if count_len > min_len or ind == len(segments) - 1:
count_len = 0 count_len = 0
new_sentences.append(' '.join(new_sent)) new_segments.append(' '.join(new_sent))
new_sent = [] new_sent = []
return merge_short_sentences_zh(new_sentences) return merge_short_segments_zh(new_segments)
def merge_short_sentences_zh(sens): def merge_short_segments_zh(sens):
# return sens # return sens
"""Avoid short sentences by merging them with the following sentence. """Avoid short segments by merging them with the following segment.
Args: Args:
List[str]: list of input sentences. List[str]: list of input segments.
Returns: Returns:
List[str]: list of output sentences. List[str]: list of output segments.
""" """
sens_out = [] sens_out = []
for s in sens: for s in sens:
# If the previous sentense is too short, merge them with # If the previous segment is too short, merge them with
# the current sentence. # the current segment.
if len(sens_out) > 0 and len(sens_out[-1]) <= 2: if len(sens_out) > 0 and len(sens_out[-1]) <= 2:
sens_out[-1] = sens_out[-1] + " " + s sens_out[-1] = sens_out[-1] + " " + s
else: else:
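
A quick sanity check of the renamed OpenVoice splitter (a sketch; the import path is assumed, and the exact merge behaviour depends on `min_len`):

```python
# Assumed import path; `BaseSpeakerTTS` above calls these helpers via `utils.*`.
from openvoice import utils

text = "First part, quite short. Second part is a little bit longer than that."
# 'EN' selects the Latin splitter, which counts words rather than characters.
segments = utils.split_segment(text, min_len=10, language_str='EN')
for seg in segments:
    print(seg)
# Neighbouring pieces are merged until roughly `min_len` words have accumulated,
# so here the three comma/period-separated pieces come back as a single segment.
```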


@@ -342,8 +342,8 @@ class TextTokenizer:
return de_tokenized_by_CJK_char(decoded, do_lower_case=do_lower_case) return de_tokenized_by_CJK_char(decoded, do_lower_case=do_lower_case)
@staticmethod @staticmethod
def split_sentences_by_token( def split_segments_by_token(
tokenized_str: List[str], split_tokens: List[str], max_tokens_per_sentence: int tokenized_str: List[str], split_tokens: List[str], max_text_tokens_per_segment: int
) -> List[List[str]]: ) -> List[List[str]]:
""" """
将tokenize后的结果按特定token进一步分割 将tokenize后的结果按特定token进一步分割
@@ -351,67 +351,67 @@ class TextTokenizer:
# 处理特殊情况 # 处理特殊情况
if len(tokenized_str) == 0: if len(tokenized_str) == 0:
return [] return []
sentences: List[List[str]] = [] segments: List[List[str]] = []
current_sentence = [] current_segment = []
current_sentence_tokens_len = 0 current_segment_tokens_len = 0
for i in range(len(tokenized_str)): for i in range(len(tokenized_str)):
token = tokenized_str[i] token = tokenized_str[i]
current_sentence.append(token) current_segment.append(token)
current_sentence_tokens_len += 1 current_segment_tokens_len += 1
if current_sentence_tokens_len <= max_tokens_per_sentence: if current_segment_tokens_len <= max_text_tokens_per_segment:
if token in split_tokens and current_sentence_tokens_len > 2: if token in split_tokens and current_segment_tokens_len > 2:
if i < len(tokenized_str) - 1: if i < len(tokenized_str) - 1:
if tokenized_str[i + 1] in ["'", "'"]: if tokenized_str[i + 1] in ["'", "'"]:
# 后续token是',则不切分 # 后续token是',则不切分
current_sentence.append(tokenized_str[i + 1]) current_segment.append(tokenized_str[i + 1])
i += 1 i += 1
sentences.append(current_sentence) segments.append(current_segment)
current_sentence = [] current_segment = []
current_sentence_tokens_len = 0 current_segment_tokens_len = 0
continue continue
# 如果当前tokens的长度超过最大限制 # 如果当前tokens的长度超过最大限制
if not ("," in split_tokens or "▁," in split_tokens ) and ("," in current_sentence or "▁," in current_sentence): if not ("," in split_tokens or "▁," in split_tokens ) and ("," in current_segment or "▁," in current_segment):
# 如果当前tokens中有,,则按,分割 # 如果当前tokens中有,,则按,分割
sub_sentences = TextTokenizer.split_sentences_by_token( sub_segments = TextTokenizer.split_segments_by_token(
current_sentence, [",", "▁,"], max_tokens_per_sentence=max_tokens_per_sentence current_segment, [",", "▁,"], max_text_tokens_per_segment=max_text_tokens_per_segment
) )
elif "-" not in split_tokens and "-" in current_sentence: elif "-" not in split_tokens and "-" in current_segment:
# 没有,,则按-分割 # 没有,,则按-分割
sub_sentences = TextTokenizer.split_sentences_by_token( sub_segments = TextTokenizer.split_segments_by_token(
current_sentence, ["-"], max_tokens_per_sentence=max_tokens_per_sentence current_segment, ["-"], max_text_tokens_per_segment=max_text_tokens_per_segment
) )
else: else:
# 按照长度分割 # 按照长度分割
sub_sentences = [] sub_segments = []
for j in range(0, len(current_sentence), max_tokens_per_sentence): for j in range(0, len(current_segment), max_text_tokens_per_segment):
if j + max_tokens_per_sentence < len(current_sentence): if j + max_text_tokens_per_segment < len(current_segment):
sub_sentences.append(current_sentence[j : j + max_tokens_per_sentence]) sub_segments.append(current_segment[j : j + max_text_tokens_per_segment])
else: else:
sub_sentences.append(current_sentence[j:]) sub_segments.append(current_segment[j:])
warnings.warn( warnings.warn(
f"The tokens length of sentence exceeds limit: {max_tokens_per_sentence}, " f"The tokens length of segment exceeds limit: {max_text_tokens_per_segment}, "
f"Tokens in sentence: {current_sentence}." f"Tokens in segment: {current_segment}."
"Maybe unexpected behavior", "Maybe unexpected behavior",
RuntimeWarning, RuntimeWarning,
) )
sentences.extend(sub_sentences) segments.extend(sub_segments)
current_sentence = [] current_segment = []
current_sentence_tokens_len = 0 current_segment_tokens_len = 0
if current_sentence_tokens_len > 0: if current_segment_tokens_len > 0:
assert current_sentence_tokens_len <= max_tokens_per_sentence assert current_segment_tokens_len <= max_text_tokens_per_segment
sentences.append(current_sentence) segments.append(current_segment)
# 如果相邻的句子加起来长度小于最大限制,则合并 # 如果相邻的句子加起来长度小于最大限制,则合并
merged_sentences = [] merged_segments = []
for sentence in sentences: for segment in segments:
if len(sentence) == 0: if len(segment) == 0:
continue continue
if len(merged_sentences) == 0: if len(merged_segments) == 0:
merged_sentences.append(sentence) merged_segments.append(segment)
elif len(merged_sentences[-1]) + len(sentence) <= max_tokens_per_sentence: elif len(merged_segments[-1]) + len(segment) <= max_text_tokens_per_segment:
merged_sentences[-1] = merged_sentences[-1] + sentence merged_segments[-1] = merged_segments[-1] + segment
else: else:
merged_sentences.append(sentence) merged_segments.append(segment)
return merged_sentences return merged_segments
punctuation_marks_tokens = [ punctuation_marks_tokens = [
".", ".",
@@ -422,9 +422,9 @@ class TextTokenizer:
"▁?", "▁?",
"▁...", # ellipsis "▁...", # ellipsis
] ]
def split_sentences(self, tokenized: List[str], max_tokens_per_sentence=120) -> List[List[str]]: def split_segments(self, tokenized: List[str], max_text_tokens_per_segment=120) -> List[List[str]]:
return TextTokenizer.split_sentences_by_token( return TextTokenizer.split_segments_by_token(
tokenized, self.punctuation_marks_tokens, max_tokens_per_sentence=max_tokens_per_sentence tokenized, self.punctuation_marks_tokens, max_text_tokens_per_segment=max_text_tokens_per_segment
) )
@@ -516,19 +516,19 @@ if __name__ == "__main__":
# 测试 normalize后的字符能被分词器识别 # 测试 normalize后的字符能被分词器识别
print(f"`{ch}`", "->", tokenizer.sp_model.Encode(ch, out_type=str)) print(f"`{ch}`", "->", tokenizer.sp_model.Encode(ch, out_type=str))
print(f"` {ch}`", "->", tokenizer.sp_model.Encode(f" {ch}", out_type=str)) print(f"` {ch}`", "->", tokenizer.sp_model.Encode(f" {ch}", out_type=str))
max_tokens_per_sentence=120 max_text_tokens_per_segment=120
for i in range(len(cases)): for i in range(len(cases)):
print(f"原始文本: {cases[i]}") print(f"原始文本: {cases[i]}")
print(f"Normalized: {text_normalizer.normalize(cases[i])}") print(f"Normalized: {text_normalizer.normalize(cases[i])}")
tokens = tokenizer.tokenize(cases[i]) tokens = tokenizer.tokenize(cases[i])
print("Tokenzied: ", ", ".join([f"`{t}`" for t in tokens])) print("Tokenzied: ", ", ".join([f"`{t}`" for t in tokens]))
sentences = tokenizer.split_sentences(tokens, max_tokens_per_sentence=max_tokens_per_sentence) segments = tokenizer.split_segments(tokens, max_text_tokens_per_segment=max_text_tokens_per_segment)
print("Splitted sentences count:", len(sentences)) print("Segments count:", len(segments))
if len(sentences) > 1: if len(segments) > 1:
for j in range(len(sentences)): for j in range(len(segments)):
print(f" {j}, count:", len(sentences[j]), ", tokens:", "".join(sentences[j])) print(f" {j}, count:", len(segments[j]), ", tokens:", "".join(segments[j]))
if len(sentences[j]) > max_tokens_per_sentence: if len(segments[j]) > max_text_tokens_per_segment:
print(f"Warning: sentence {j} is too long, length: {len(sentences[j])}") print(f"Warning: segment {j} is too long, length: {len(segments[j])}")
#print(f"Token IDs (first 10): {codes[i][:10]}") #print(f"Token IDs (first 10): {codes[i][:10]}")
if tokenizer.unk_token in codes[i]: if tokenizer.unk_token in codes[i]:
print(f"Warning: `{cases[i]}` contains UNKNOWN token") print(f"Warning: `{cases[i]}` contains UNKNOWN token")


@@ -1,50 +1,109 @@
[project] [project]
name = "index-tts" name = "indextts"
version = "2.0.0" version = "2.0.0"
description = "IndexTTS2: A Breakthrough in Emotionally Expressive and Duration-Controlled Auto-Regressive Zero-Shot Text-to-Speech" description = "IndexTTS2: A Breakthrough in Emotionally Expressive and Duration-Controlled Auto-Regressive Zero-Shot Text-to-Speech"
authors = [{ name = "Bilibili IndexTTS Team" }]
license = "Apache-2.0" license = "Apache-2.0"
license-files = ["LICEN[CS]E*", "INDEX_MODEL_LICENSE*"] license-files = ["LICEN[CS]E*", "INDEX_MODEL_LICENSE*"]
readme = "README.md" readme = "README.md"
classifiers = [
"Development Status :: 5 - Production/Stable",
"Intended Audience :: Science/Research",
"Intended Audience :: Developers",
"Topic :: Scientific/Engineering",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Natural Language :: English",
"Natural Language :: Chinese (Simplified)",
"Programming Language :: Python :: 3",
"Operating System :: OS Independent",
]
requires-python = ">=3.10" requires-python = ">=3.10"
dependencies = [ dependencies = [
# IMPORTANT: Always run `uv lock` to resolve and update the lockfile after edits: # IMPORTANT: Always run `uv lock` or `uv lock --upgrade` to resolve dependencies
"accelerate==1.8.1", # and update the lockfile after editing anything below.
"cn2an==0.5.22", # WARNING: Ensure that you don't have a local `uv.toml` which overrides PyPI
"cython==3.0.7", # while generating the lockfile: https://github.com/astral-sh/uv/issues/15741
"deepspeed==0.17.1", "accelerate==1.8.1",
"descript-audiotools==0.7.2", "cn2an==0.5.22",
"ffmpeg-python==0.2.0", "cython==3.0.7",
"g2p-en==2.1.0", "deepspeed==0.17.1",
"gradio>=5.44.1", "descript-audiotools==0.7.2",
"jieba==0.42.1", "ffmpeg-python==0.2.0",
"json5==0.10.0", "g2p-en==2.1.0",
"keras==2.9.0", "jieba==0.42.1",
"librosa==0.10.2.post1", "json5==0.10.0",
"matplotlib==3.8.2", "keras==2.9.0",
"modelscope==1.27.0", "librosa==0.10.2.post1",
"munch==4.0.0", "matplotlib==3.8.2",
"numba==0.58.1", "modelscope==1.27.0",
"numpy==1.26.2", "munch==4.0.0",
"omegaconf>=2.3.0", "numba==0.58.1",
"opencv-python==4.9.0.80", "numpy==1.26.2",
"pandas==2.3.2", "omegaconf>=2.3.0",
"safetensors==0.5.2", "opencv-python==4.9.0.80",
"sentencepiece>=0.2.1", "pandas==2.3.2",
"tensorboard==2.9.1", "safetensors==0.5.2",
"textstat>=0.7.10", "sentencepiece>=0.2.1",
"tokenizers==0.21.0", "tensorboard==2.9.1",
"tqdm>=4.67.1", "textstat>=0.7.10",
"transformers==4.52.1", "tokenizers==0.21.0",
"torch==2.8.*",
"torchaudio==2.8.*",
"tqdm>=4.67.1",
"transformers==4.52.1",
# Use "wetext" on Windows/Mac, otherwise "WeTextProcessing" on Linux." # Use "wetext" on Windows/Mac, otherwise "WeTextProcessing" on Linux.
"wetext>=0.0.9; sys_platform != 'linux'", "wetext>=0.0.9; sys_platform != 'linux'",
"WeTextProcessing; sys_platform == 'linux'", "WeTextProcessing; sys_platform == 'linux'",
]
[project.optional-dependencies]
# To install the WebUI support, use `uv sync --extra webui` (or `--all-extras`).
webui = [
"gradio>=5.44.1",
] ]
[project.urls] [project.urls]
Homepage = "https://github.com/index-tts/index-tts" Homepage = "https://github.com/index-tts/index-tts"
Repository = "https://github.com/index-tts/index-tts.git" Repository = "https://github.com/index-tts/index-tts.git"
[project.scripts]
# Set the installed binary names and entry points.
indextts = "indextts.cli:main"
[build-system]
# How to build the project as a CLI tool or PyPI package.
# NOTE: Use `uv tool install -e .` to install the package as a CLI tool.
requires = ["hatchling >= 1.27.0"]
build-backend = "hatchling.build"
[tool.uv] [tool.uv]
extra-index-url = ["https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple"] # Disable build isolation when building DeepSpeed from source.
no-build-isolation-package = ["deepspeed"] no-build-isolation-package = ["deepspeed"]
[tool.uv.sources]
# Install PyTorch with CUDA support on Linux/Windows (CUDA doesn't exist for Mac).
# NOTE: We must explicitly request them as `dependencies` above. These improved
# versions will not be selected if they're only third-party dependencies.
torch = [
{ index = "pytorch-cuda", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
]
torchaudio = [
{ index = "pytorch-cuda", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
]
torchvision = [
{ index = "pytorch-cuda", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
]
[[tool.uv.index]]
name = "pytorch-cuda"
# Use PyTorch built for NVIDIA Toolkit version 12.8.
# Available versions: https://pytorch.org/get-started/locally/
url = "https://download.pytorch.org/whl/cu128"
# Only use this index when explicitly requested by `tool.uv.sources`.
explicit = true
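
After `uv sync`, a minimal way to confirm that the CUDA-enabled wheels from the `pytorch-cuda` index were actually selected (the new `tools/gpu_check.py` below does a fuller probe):

```python
import torch

# `torch.version.cuda` reports the CUDA toolkit the wheel was built against
# (it is None for the CPU-only/macOS build); runtime use also needs a working driver.
print("Torch:", torch.__version__)
print("Built for CUDA:", torch.version.cuda)
print("CUDA available:", torch.cuda.is_available())
```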


@@ -1,80 +0,0 @@
import platform
import os
from setuptools import find_packages, setup
# add fused `anti_alias_activation` cuda extension if CUDA is available
anti_alias_activation_cuda_ext = None
if platform.system() != "Darwin":
try:
from torch.utils import cpp_extension
if cpp_extension.CUDA_HOME is not None:
anti_alias_activation_cuda_ext = cpp_extension.CUDAExtension(
name="indextts.BigVGAN.alias_free_activation.cuda.anti_alias_activation_cuda",
sources=[
"indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp",
"indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu",
],
include_dirs=["indextts/BigVGAN/alias_free_activation/cuda"],
extra_compile_args={
"cxx": ["-O3"],
"nvcc": [
"-O3",
"--use_fast_math",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
],
},
)
else:
print("CUDA_HOME is not set. Skipping anti_alias_activation CUDA extension.")
except ImportError:
print("PyTorch is not installed. Skipping torch extension.")
setup(
name="indextts",
version="0.1.4",
author="Index SpeechTeam",
author_email="xuanwu@bilibili.com",
long_description=open("README.md", encoding="utf8").read(),
long_description_content_type="text/markdown",
description="An Industrial-Level Controllable and Efficient Zero-Shot Text-To-Speech System",
url="https://github.com/index-tts/index-tts",
packages=find_packages(),
include_package_data=True,
install_requires=[
"torch>=2.1.2",
"torchaudio",
"transformers==4.36.2",
"accelerate",
"tokenizers==0.15.0",
"einops==0.8.1",
"matplotlib==3.8.2",
"omegaconf",
"sentencepiece",
"librosa",
"numpy",
"wetext" if platform.system() == "Darwin" else "WeTextProcessing",
],
extras_require={
"webui": ["gradio"],
},
ext_modules=[anti_alias_activation_cuda_ext] if anti_alias_activation_cuda_ext else [],
cmdclass={"build_ext": cpp_extension.BuildExtension} if anti_alias_activation_cuda_ext else {},
entry_points={
"console_scripts": [
"indextts = indextts.cli:main",
]
},
license="Apache-2.0",
python_requires=">=3.10",
classifiers=[
"Programming Language :: Python :: 3.10",
"Operating System :: OS Independent",
"Intended Audience :: Science/Research",
"Topic :: Scientific/Engineering",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
],
)


@@ -21,7 +21,7 @@ if __name__ == "__main__":
else: else:
model_dir = "checkpoints" model_dir = "checkpoints"
audio_prompt="tests/sample_prompt.wav" audio_prompt="tests/sample_prompt.wav"
tts = IndexTTS(cfg_path=f"{model_dir}/config.yaml", model_dir=model_dir, is_fp16=False, use_cuda_kernel=False) tts = IndexTTS(cfg_path=f"{model_dir}/config.yaml", model_dir=model_dir, use_fp16=False, use_cuda_kernel=False)
text = "晕 XUAN4 是 一 种 not very good GAN3 觉" text = "晕 XUAN4 是 一 种 not very good GAN3 觉"
text_tokens = tts.tokenizer.encode(text) text_tokens = tts.tokenizer.encode(text)
text_tokens = torch.tensor(text_tokens, dtype=torch.int32, device=tts.device).unsqueeze(0) # [1, L] text_tokens = torch.tensor(text_tokens, dtype=torch.int32, device=tts.device).unsqueeze(0) # [1, L]


@@ -2,7 +2,7 @@ from indextts.infer import IndexTTS
if __name__ == "__main__": if __name__ == "__main__":
prompt_wav="tests/sample_prompt.wav" prompt_wav="tests/sample_prompt.wav"
tts = IndexTTS(cfg_path="checkpoints/config.yaml", model_dir="checkpoints", is_fp16=True, use_cuda_kernel=False) tts = IndexTTS(cfg_path="checkpoints/config.yaml", model_dir="checkpoints", use_fp16=True, use_cuda_kernel=False)
# 单音频推理测试 # 单音频推理测试
text="晕 XUAN4 是 一 种 GAN3 觉" text="晕 XUAN4 是 一 种 GAN3 觉"
tts.infer(audio_prompt=prompt_wav, text=text, output_path=f"outputs/{text[:20]}.wav", verbose=True) tts.infer(audio_prompt=prompt_wav, text=text, output_path=f"outputs/{text[:20]}.wav", verbose=True)

tools/gpu_check.py (new file)

@@ -0,0 +1,47 @@
import torch
def show_cuda_gpu_list() -> None:
"""
Displays a list of all detected GPUs that support the CUDA Torch APIs.
"""
num_gpus = torch.cuda.device_count()
print(f"Number of GPUs found: {num_gpus}")
for i in range(num_gpus):
gpu_name = torch.cuda.get_device_name(i)
print(f'GPU {i}: "{gpu_name}"')
def check_torch_gpus() -> None:
"""
Checks for the availability of various PyTorch GPU acceleration platforms
and prints information about the discovered GPUs.
"""
# Check for AMD ROCm/HIP first, since it modifies the CUDA APIs.
# NOTE: The unofficial ROCm/HIP backend exposes the AMD features through
# the CUDA Torch API calls.
if hasattr(torch.backends, "hip") and torch.backends.hip.is_available():
print("PyTorch: AMD ROCm/HIP is available!")
show_cuda_gpu_list()
# Check for NVIDIA CUDA.
elif torch.cuda.is_available():
print("PyTorch: NVIDIA CUDA is available!")
show_cuda_gpu_list()
# Check for Apple Metal Performance Shaders (MPS).
elif torch.backends.mps.is_available():
print("PyTorch: Apple MPS is available!")
# PyTorch with MPS doesn't have a direct equivalent of `device_count()`
# or `get_device_name()` for now, so we just confirm its presence.
print("Using Apple Silicon GPU.")
else:
print("PyTorch: No GPU acceleration detected. Running in CPU mode.")
if __name__ == "__main__":
check_torch_gpus()


@@ -1,10 +1,10 @@
{ {
"本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.": "This software is open-sourced under the MIT License. The author has no control over the software, and users of the software, as well as those who distribute the audio generated by the software, assume full responsibility.", "本软件以Apache-2.0协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.": "This software is open-sourced under the Apache-2.0 License. The author has no control over the software, and users of the software, as well as those who distribute the audio generated by the software, assume full responsibility.",
"如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.": "If you do not agree to these terms, you are not permitted to use or reference any code or files within the software package. For further details, please refer to the LICENSE file in the root directory.", "如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.": "If you do not agree to these terms, you are not permitted to use or reference any code or files within the software package. For further details, please refer to the LICENSE files in the root directory.",
"时长必须为正数": "Duration must be a positive number", "时长必须为正数": "Duration must be a positive number",
"请输入有效的浮点数": "Please enter a valid floating-point number", "请输入有效的浮点数": "Please enter a valid floating-point number",
"使用情感参考音频": "Use emotion reference audio", "使用情感参考音频": "Use emotion reference audio",
"使用情感向量控制": "Use emotion vector", "使用情感向量控制": "Use emotion vectors",
"使用情感描述文本控制": "Use text description to control emotion", "使用情感描述文本控制": "Use text description to control emotion",
"上传情感参考音频": "Upload emotion reference audio", "上传情感参考音频": "Upload emotion reference audio",
"情感权重": "Emotion control weight", "情感权重": "Emotion control weight",
@@ -17,32 +17,32 @@
"惊喜": "Surprised", "惊喜": "Surprised",
"平静": "Calm", "平静": "Calm",
"情感描述文本": "Emotion description", "情感描述文本": "Emotion description",
"请输入情描述文本": "Please input emotion description", "请输入情描述(或留空以自动使用目标文本作为情绪描述)": "Please input an emotion description (or leave blank to automatically use the main text prompt)",
"高级生成参数设置": "Advanced generation parameter settings", "高级生成参数设置": "Advanced generation parameter settings",
"情感向量之和不能超过1.5,请调整后重试。": "The sum of the emotion vectors cannot exceed 1.5. Please adjust and try again.", "情感向量之和不能超过1.5,请调整后重试。": "The sum of the emotion vectors cannot exceed 1.5. Please adjust and try again.",
"音色参考音频": "Voice reference", "音色参考音频": "Voice Reference",
"音频生成": "Speech Synthesis", "音频生成": "Speech Synthesis",
"文本": "Text", "文本": "Text",
"生成语音": "Synthesize", "生成语音": "Synthesize",
"生成结果": "Synthesis Result", "生成结果": "Synthesis Result",
"功能设置": "Settings", "功能设置": "Settings",
"分句设置": "Sentence segmentation settings", "分句设置": "Text segmentation settings",
"参数会影响音频质量和生成速度": "Parameters below affect audio quality and generation speed", "参数会影响音频质量和生成速度": "These parameters affect the audio quality and generation speed.",
"分句最大Token数": "Max tokens per generation segment", "分句最大Token数": "Max tokens per generation segment",
"建议80~200之间值越大分句越长值越小分句越碎过小过大都可能导致音频质量不高": "Recommended range: 80 - 200. Larger values require more VRAM but improves the flow of the speech, while lower values require less VRAM but means more fragmented sentences. Values that are too small or too large may lead to less coherent speech.", "建议80~200之间值越大分句越长值越小分句越碎过小过大都可能导致音频质量不高": "Recommended range: 80 - 200. Larger values require more VRAM but improves the flow of the speech, while lower values require less VRAM but means more fragmented sentences. Values that are too small or too large may lead to less coherent speech.",
"预览分句结果": "Preview sentence segmentation result", "预览分句结果": "Preview of the audio generation segments",
"序号": "Index", "序号": "Index",
"分句内容": "Content", "分句内容": "Content",
"Token数": "Token Count", "Token数": "Token Count",
"情感控制方式": "Emotion control method", "情感控制方式": "Emotion control method",
"GPT2 采样设置": "GPT-2 Sampling Configuration", "GPT2 采样设置": "GPT-2 Sampling Configuration",
"参数会影响音频多样性和生成速度详见": "Influence both the diversity of the generated audio and the generation speed. For further details, refer to", "参数会影响音频多样性和生成速度详见": "Influences both the diversity of the generated audio and the generation speed. For further details, refer to",
"是否进行采样": "Enable GPT-2 sampling", "是否进行采样": "Enable GPT-2 sampling",
"生成Token最大数量过小导致音频被截断": "Maximum number of tokens to generate. If text exceeds this, the audio will be cut off.", "生成Token最大数量过小导致音频被截断": "Maximum number of tokens to generate. If text exceeds this, the audio will be cut off.",
"请上传情感参考音频": "Please upload emotion reference audio", "请上传情感参考音频": "Please upload the emotion reference audio",
"当前模型版本": "Current model version ", "当前模型版本": "Current model version: ",
"请输入目标文本": "Please input text to synthesize", "请输入目标文本": "Please input the text to synthesize",
"例如:高兴,愤怒,悲伤等": "e.g., happy, angry, sad, etc.", "例如:高兴,愤怒,悲伤等": "e.g., happy, angry, sad, etc.",
"与音色参考音频相同": "Same as the voice reference", "与音色参考音频相同": "Same as the voice reference",
"情感随机采样": "Random emotion sampling" "情感随机采样": "Randomize emotion sampling"
} }


@@ -1,5 +1,5 @@
{ {
"本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.": "本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.", "本软件以Apache-2.0协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.": "本软件以Apache-2.0协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.",
"如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.": "如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.", "如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.": "如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.",
"时长必须为正数": "时长必须为正数", "时长必须为正数": "时长必须为正数",
"请输入有效的浮点数": "请输入有效的浮点数", "请输入有效的浮点数": "请输入有效的浮点数",
@@ -17,7 +17,7 @@
"惊喜": "惊喜", "惊喜": "惊喜",
"平静": "平静", "平静": "平静",
"情感描述文本": "情感描述文本", "情感描述文本": "情感描述文本",
"请输入情描述文本": "请输入情描述文本", "请输入情描述(或留空以自动使用目标文本作为情绪描述)": "请输入情描述(或留空以自动使用目标文本作为情绪描述)",
"高级生成参数设置": "高级生成参数设置", "高级生成参数设置": "高级生成参数设置",
"情感向量之和不能超过1.5,请调整后重试。": "情感向量之和不能超过1.5,请调整后重试。", "情感向量之和不能超过1.5,请调整后重试。": "情感向量之和不能超过1.5,请调整后重试。",
"音色参考音频": "音色参考音频", "音色参考音频": "音色参考音频",

uv.lock (regenerated; diff too large to display)

@@ -1,5 +1,4 @@
import json import json
import logging
import os import os
import sys import sys
import threading import threading
@@ -17,12 +16,16 @@ sys.path.append(current_dir)
sys.path.append(os.path.join(current_dir, "indextts")) sys.path.append(os.path.join(current_dir, "indextts"))
import argparse import argparse
parser = argparse.ArgumentParser(description="IndexTTS WebUI") parser = argparse.ArgumentParser(
description="IndexTTS WebUI",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--verbose", action="store_true", default=False, help="Enable verbose mode") parser.add_argument("--verbose", action="store_true", default=False, help="Enable verbose mode")
parser.add_argument("--port", type=int, default=7860, help="Port to run the web UI on") parser.add_argument("--port", type=int, default=7860, help="Port to run the web UI on")
parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to run the web UI on") parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to run the web UI on")
parser.add_argument("--model_dir", type=str, default="checkpoints", help="Model checkpoints directory") parser.add_argument("--model_dir", type=str, default="./checkpoints", help="Model checkpoints directory")
parser.add_argument("--is_fp16", action="store_true", default=False, help="Fp16 infer") parser.add_argument("--fp16", action="store_true", default=False, help="Use FP16 for inference if available")
parser.add_argument("--gui_seg_tokens", type=int, default=120, help="GUI: Max tokens per generation segment")
cmd_args = parser.parse_args() cmd_args = parser.parse_args()
if not os.path.exists(cmd_args.model_dir): if not os.path.exists(cmd_args.model_dir):
@@ -42,14 +45,12 @@ for file in [
sys.exit(1) sys.exit(1)
import gradio as gr import gradio as gr
from indextts import infer
from indextts.infer_v2 import IndexTTS2 from indextts.infer_v2 import IndexTTS2
from tools.i18n.i18n import I18nAuto from tools.i18n.i18n import I18nAuto
from modelscope.hub import api
i18n = I18nAuto(language="Auto") i18n = I18nAuto(language="Auto")
MODE = 'local' MODE = 'local'
tts = IndexTTS2(model_dir=cmd_args.model_dir, cfg_path=os.path.join(cmd_args.model_dir, "config.yaml"),is_fp16=cmd_args.is_fp16) tts = IndexTTS2(model_dir=cmd_args.model_dir, cfg_path=os.path.join(cmd_args.model_dir, "config.yaml"),use_fp16=cmd_args.fp16)
# 支持的语言列表 # 支持的语言列表
LANGUAGES = { LANGUAGES = {
@@ -96,7 +97,7 @@ def gen_single(emo_control_method,prompt, text,
emo_ref_path, emo_weight, emo_ref_path, emo_weight,
vec1, vec2, vec3, vec4, vec5, vec6, vec7, vec8, vec1, vec2, vec3, vec4, vec5, vec6, vec7, vec8,
emo_text,emo_random, emo_text,emo_random,
max_text_tokens_per_sentence=120, max_text_tokens_per_segment=120,
*args, progress=gr.Progress()): *args, progress=gr.Progress()):
output_path = None output_path = None
if not output_path: if not output_path:
@@ -133,6 +134,10 @@ def gen_single(emo_control_method,prompt, text,
else: else:
vec = None vec = None
if emo_text == "":
# erase empty emotion descriptions; `infer()` will then automatically use the main prompt
emo_text = None
print(f"Emo control mode:{emo_control_method},vec:{vec}") print(f"Emo control mode:{emo_control_method},vec:{vec}")
output = tts.infer(spk_audio_prompt=prompt, text=text, output = tts.infer(spk_audio_prompt=prompt, text=text,
output_path=output_path, output_path=output_path,
@@ -140,7 +145,7 @@ def gen_single(emo_control_method,prompt, text,
emo_vector=vec, emo_vector=vec,
use_emo_text=(emo_control_method==3), emo_text=emo_text,use_random=emo_random, use_emo_text=(emo_control_method==3), emo_text=emo_text,use_random=emo_random,
verbose=cmd_args.verbose, verbose=cmd_args.verbose,
max_text_tokens_per_sentence=int(max_text_tokens_per_sentence), max_text_tokens_per_segment=int(max_text_tokens_per_segment),
**kwargs) **kwargs)
return gr.update(value=output,visible=True) return gr.update(value=output,visible=True)
@@ -204,7 +209,7 @@ with gr.Blocks(title="IndexTTS Demo") as demo:
with gr.Group(visible=False) as emo_text_group: with gr.Group(visible=False) as emo_text_group:
with gr.Row(): with gr.Row():
emo_text = gr.Textbox(label=i18n("情感描述文本"), placeholder=i18n("请输入情描述文本"), value="", info=i18n("例如:高兴,愤怒,悲伤等")) emo_text = gr.Textbox(label=i18n("情感描述文本"), placeholder=i18n("请输入情描述(或留空以自动使用目标文本作为情绪描述)"), value="", info=i18n("例如:高兴,愤怒,悲伤等"))
with gr.Accordion(i18n("高级生成参数设置"), open=False): with gr.Accordion(i18n("高级生成参数设置"), open=False):
with gr.Row(): with gr.Row():
@@ -227,14 +232,15 @@ with gr.Blocks(title="IndexTTS Demo") as demo:
with gr.Column(scale=2): with gr.Column(scale=2):
gr.Markdown(f'**{i18n("分句设置")}** _{i18n("参数会影响音频质量和生成速度")}_') gr.Markdown(f'**{i18n("分句设置")}** _{i18n("参数会影响音频质量和生成速度")}_')
with gr.Row(): with gr.Row():
max_text_tokens_per_sentence = gr.Slider( initial_value = max(20, min(tts.cfg.gpt.max_text_tokens, cmd_args.gui_seg_tokens))
label=i18n("分句最大Token数"), value=120, minimum=20, maximum=tts.cfg.gpt.max_text_tokens, step=2, key="max_text_tokens_per_sentence", max_text_tokens_per_segment = gr.Slider(
label=i18n("分句最大Token数"), value=initial_value, minimum=20, maximum=tts.cfg.gpt.max_text_tokens, step=2, key="max_text_tokens_per_segment",
info=i18n("建议80~200之间值越大分句越长值越小分句越碎过小过大都可能导致音频质量不高"), info=i18n("建议80~200之间值越大分句越长值越小分句越碎过小过大都可能导致音频质量不高"),
) )
with gr.Accordion(i18n("预览分句结果"), open=True) as sentences_settings: with gr.Accordion(i18n("预览分句结果"), open=True) as segments_settings:
sentences_preview = gr.Dataframe( segments_preview = gr.Dataframe(
headers=[i18n("序号"), i18n("分句内容"), i18n("Token数")], headers=[i18n("序号"), i18n("分句内容"), i18n("Token数")],
key="sentences_preview", key="segments_preview",
wrap=True, wrap=True,
) )
advanced_params = [ advanced_params = [
@@ -256,23 +262,23 @@ with gr.Blocks(title="IndexTTS Demo") as demo:
vec1,vec2,vec3,vec4,vec5,vec6,vec7,vec8] vec1,vec2,vec3,vec4,vec5,vec6,vec7,vec8]
) )
def on_input_text_change(text, max_tokens_per_sentence): def on_input_text_change(text, max_text_tokens_per_segment):
if text and len(text) > 0: if text and len(text) > 0:
text_tokens_list = tts.tokenizer.tokenize(text) text_tokens_list = tts.tokenizer.tokenize(text)
sentences = tts.tokenizer.split_sentences(text_tokens_list, max_tokens_per_sentence=int(max_tokens_per_sentence)) segments = tts.tokenizer.split_segments(text_tokens_list, max_text_tokens_per_segment=int(max_text_tokens_per_segment))
data = [] data = []
for i, s in enumerate(sentences): for i, s in enumerate(segments):
sentence_str = ''.join(s) segment_str = ''.join(s)
tokens_count = len(s) tokens_count = len(s)
data.append([i, sentence_str, tokens_count]) data.append([i, segment_str, tokens_count])
return { return {
sentences_preview: gr.update(value=data, visible=True, type="array"), segments_preview: gr.update(value=data, visible=True, type="array"),
} }
else: else:
df = pd.DataFrame([], columns=[i18n("序号"), i18n("分句内容"), i18n("Token数")]) df = pd.DataFrame([], columns=[i18n("序号"), i18n("分句内容"), i18n("Token数")])
return { return {
sentences_preview: gr.update(value=df), segments_preview: gr.update(value=df),
} }
def on_method_select(emo_control_method): def on_method_select(emo_control_method):
if emo_control_method == 1: if emo_control_method == 1:
@@ -310,13 +316,13 @@ with gr.Blocks(title="IndexTTS Demo") as demo:
input_text_single.change( input_text_single.change(
on_input_text_change, on_input_text_change,
inputs=[input_text_single, max_text_tokens_per_sentence], inputs=[input_text_single, max_text_tokens_per_segment],
outputs=[sentences_preview] outputs=[segments_preview]
) )
max_text_tokens_per_sentence.change( max_text_tokens_per_segment.change(
on_input_text_change, on_input_text_change,
inputs=[input_text_single, max_text_tokens_per_sentence], inputs=[input_text_single, max_text_tokens_per_segment],
outputs=[sentences_preview] outputs=[segments_preview]
) )
prompt_audio.upload(update_prompt_audio, prompt_audio.upload(update_prompt_audio,
inputs=[], inputs=[],
@@ -326,7 +332,7 @@ with gr.Blocks(title="IndexTTS Demo") as demo:
inputs=[emo_control_method,prompt_audio, input_text_single, emo_upload, emo_weight, inputs=[emo_control_method,prompt_audio, input_text_single, emo_upload, emo_weight,
vec1, vec2, vec3, vec4, vec5, vec6, vec7, vec8, vec1, vec2, vec3, vec4, vec5, vec6, vec7, vec8,
emo_text,emo_random, emo_text,emo_random,
max_text_tokens_per_sentence, max_text_tokens_per_segment,
*advanced_params, *advanced_params,
], ],
outputs=[output_audio]) outputs=[output_audio])