mirror of
https://github.com/index-tts/index-tts.git
synced 2025-11-25 19:37:47 +08:00
add deepspeed cmd option (#307)
This commit is contained in:
@@ -38,7 +38,7 @@ import torch.nn.functional as F
|
||||
class IndexTTS2:
|
||||
def __init__(
|
||||
self, cfg_path="checkpoints/config.yaml", model_dir="checkpoints", use_fp16=False, device=None,
|
||||
use_cuda_kernel=None,
|
||||
use_cuda_kernel=None,use_deepspeed=False
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@@ -47,6 +47,7 @@ class IndexTTS2:
|
||||
use_fp16 (bool): whether to use fp16.
|
||||
device (str): device to use (e.g., 'cuda:0', 'cpu'). If None, it will be set automatically based on the availability of CUDA or MPS.
|
||||
use_cuda_kernel (None | bool): whether to use BigVGan custom fused activation CUDA kernel, only for CUDA device.
|
||||
use_deepspeed (bool): whether to use DeepSpeed acceleration.
|
||||
"""
|
||||
if device is not None:
|
||||
self.device = device
|
||||
@@ -87,12 +88,12 @@ class IndexTTS2:
|
||||
self.gpt.eval()
|
||||
print(">> GPT weights restored from:", self.gpt_path)
|
||||
|
||||
use_deepspeed = True
|
||||
try:
|
||||
import deepspeed
|
||||
except (ImportError, OSError, CalledProcessError) as e:
|
||||
if use_deepspeed:
|
||||
print(f">> DeepSpeed加载失败,回退到标准推理: {e}")
|
||||
use_deepspeed = False
|
||||
print(f">> DeepSpeed加载失败,回退到标准推理: {e}")
|
||||
|
||||
self.gpt.post_init_gpt2_config(use_deepspeed=use_deepspeed, kv_cache=True, half=self.use_fp16)
|
||||
|
||||
|
||||
14
webui.py
14
webui.py
@@ -25,6 +25,7 @@ parser.add_argument("--port", type=int, default=7860, help="Port to run the web
|
||||
parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to run the web UI on")
|
||||
parser.add_argument("--model_dir", type=str, default="./checkpoints", help="Model checkpoints directory")
|
||||
parser.add_argument("--fp16", action="store_true", default=False, help="Use FP16 for inference if available")
|
||||
parser.add_argument("--use_deepspeed", action="store_true", default=False, help="Use Deepspeed to accelerate if available")
|
||||
parser.add_argument("--cuda_kernel", action="store_true", default=False, help="Use cuda kernel for inference if available")
|
||||
parser.add_argument("--gui_seg_tokens", type=int, default=120, help="GUI: Max tokens per generation segment")
|
||||
cmd_args = parser.parse_args()
|
||||
@@ -51,13 +52,12 @@ from tools.i18n.i18n import I18nAuto
|
||||
|
||||
i18n = I18nAuto(language="Auto")
|
||||
MODE = 'local'
|
||||
tts = IndexTTS2(
|
||||
model_dir=cmd_args.model_dir,
|
||||
cfg_path=os.path.join(cmd_args.model_dir, "config.yaml"),
|
||||
use_fp16=cmd_args.fp16,
|
||||
use_cuda_kernel=cmd_args.cuda_kernel,
|
||||
)
|
||||
|
||||
tts = IndexTTS2(model_dir=cmd_args.model_dir,
|
||||
cfg_path=os.path.join(cmd_args.model_dir, "config.yaml"),
|
||||
use_fp16=cmd_args.fp16,
|
||||
use_deepspeed=cmd_args.use_deepspeed,
|
||||
use_cuda_kernel=cmd_args.cuda_kernel,
|
||||
)
|
||||
# 支持的语言列表
|
||||
LANGUAGES = {
|
||||
"中文": "zh_CN",
|
||||
|
||||
Reference in New Issue
Block a user