diff --git a/indextts/infer_v2.py b/indextts/infer_v2.py index 1cab10a..f5c0b4a 100644 --- a/indextts/infer_v2.py +++ b/indextts/infer_v2.py @@ -38,7 +38,7 @@ import torch.nn.functional as F class IndexTTS2: def __init__( self, cfg_path="checkpoints/config.yaml", model_dir="checkpoints", use_fp16=False, device=None, - use_cuda_kernel=None, + use_cuda_kernel=None, use_deepspeed=False ): """ Args: @@ -47,6 +47,7 @@ class IndexTTS2: use_fp16 (bool): whether to use fp16. device (str): device to use (e.g., 'cuda:0', 'cpu'). If None, it will be set automatically based on the availability of CUDA or MPS. use_cuda_kernel (None | bool): whether to use BigVGan custom fused activation CUDA kernel, only for CUDA device. + use_deepspeed (bool): whether to use DeepSpeed or not. """ if device is not None: self.device = device @@ -87,12 +88,12 @@ class IndexTTS2: self.gpt.eval() print(">> GPT weights restored from:", self.gpt_path) - use_deepspeed = True try: import deepspeed except (ImportError, OSError, CalledProcessError) as e: + if use_deepspeed: + print(f">> DeepSpeed加载失败,回退到标准推理: {e}") use_deepspeed = False - print(f">> DeepSpeed加载失败,回退到标准推理: {e}") self.gpt.post_init_gpt2_config(use_deepspeed=use_deepspeed, kv_cache=True, half=self.use_fp16) diff --git a/webui.py b/webui.py index 705b126..e623eb5 100644 --- a/webui.py +++ b/webui.py @@ -25,6 +25,7 @@ parser.add_argument("--port", type=int, default=7860, help="Port to run the web parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to run the web UI on") parser.add_argument("--model_dir", type=str, default="./checkpoints", help="Model checkpoints directory") parser.add_argument("--fp16", action="store_true", default=False, help="Use FP16 for inference if available") +parser.add_argument("--use_deepspeed", action="store_true", default=False, help="Use DeepSpeed to accelerate if available") parser.add_argument("--cuda_kernel", action="store_true", default=False, help="Use cuda kernel for inference if available") 
parser.add_argument("--gui_seg_tokens", type=int, default=120, help="GUI: Max tokens per generation segment") cmd_args = parser.parse_args() @@ -51,13 +52,12 @@ from tools.i18n.i18n import I18nAuto i18n = I18nAuto(language="Auto") MODE = 'local' -tts = IndexTTS2( - model_dir=cmd_args.model_dir, - cfg_path=os.path.join(cmd_args.model_dir, "config.yaml"), - use_fp16=cmd_args.fp16, - use_cuda_kernel=cmd_args.cuda_kernel, -) - +tts = IndexTTS2(model_dir=cmd_args.model_dir, + cfg_path=os.path.join(cmd_args.model_dir, "config.yaml"), + use_fp16=cmd_args.fp16, + use_deepspeed=cmd_args.use_deepspeed, + use_cuda_kernel=cmd_args.cuda_kernel, + ) # 支持的语言列表 LANGUAGES = { "中文": "zh_CN",