Merge pull request #347 from index-tts/cut_audio

feat: truncate overly long input audio to 15 s to reduce RAM and VRAM blow-ups
Authored by nanaoto on 2025-09-12 16:48:14 +08:00; committed by GitHub


@@ -292,6 +292,20 @@ class IndexTTS2:
         if self.gr_progress is not None:
             self.gr_progress(value, desc=desc)
+    def _load_and_cut_audio(self,audio_path,max_audio_length_seconds,verbose=False,sr=None):
+        if not sr:
+            audio, sr = librosa.load(audio_path)
+        else:
+            audio, _ = librosa.load(audio_path,sr=sr)
+        audio = torch.tensor(audio).unsqueeze(0)
+        max_audio_samples = int(max_audio_length_seconds * sr)
+        if audio.shape[1] > max_audio_samples:
+            if verbose:
+                print(f"Audio too long ({audio.shape[1]} samples), truncating to {max_audio_samples} samples")
+            audio = audio[:, :max_audio_samples]
+        return audio, sr
     # Original inference mode
     def infer(self, spk_audio_prompt, text, output_path,
               emo_audio_prompt=None, emo_alpha=1.0,
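
The cap translates directly into a ceiling on prompt length in samples (max_audio_length_seconds * sr). As a quick illustrative check for the two sample rates that matter in infer() below (librosa.load's 22.05 kHz default for the speaker prompt, 16 kHz for the emotion prompt); the numbers are plain arithmetic, not taken from the diff:

# Illustrative: sample ceilings implied by a 15 s cap at the rates used below.
for sr in (22050, 16000):
    print(f"{sr} Hz -> at most {int(15 * sr)} samples")
# 22050 Hz -> at most 330750 samples
# 16000 Hz -> at most 240000 samples
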
@@ -340,8 +354,7 @@ class IndexTTS2:
         # Only regenerate when the reference audio has changed, to speed things up
         if self.cache_spk_cond is None or self.cache_spk_audio_prompt != spk_audio_prompt:
-            audio, sr = librosa.load(spk_audio_prompt)
-            audio = torch.tensor(audio).unsqueeze(0)
+            audio,sr = self._load_and_cut_audio(spk_audio_prompt,15,verbose)
             audio_22k = torchaudio.transforms.Resample(sr, 22050)(audio)
             audio_16k = torchaudio.transforms.Resample(sr, 16000)(audio)
@@ -392,7 +405,7 @@ class IndexTTS2:
             emovec_mat = emovec_mat.unsqueeze(0)
         if self.cache_emo_cond is None or self.cache_emo_audio_prompt != emo_audio_prompt:
-            emo_audio, _ = librosa.load(emo_audio_prompt, sr=16000)
+            emo_audio, _ = self._load_and_cut_audio(emo_audio_prompt,15,verbose,sr=16000)
             emo_inputs = self.extract_features(emo_audio, sampling_rate=16000, return_tensors="pt")
             emo_input_features = emo_inputs["input_features"]
             emo_attention_mask = emo_inputs["attention_mask"]
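
A minimal way to exercise the new helper end to end: write a synthetic 30 s tone as a stand-in for an overly long prompt and confirm it comes back capped at 15 s. This is a sketch, not part of the change; tts is assumed to be an already-constructed IndexTTS2 instance, and the file name and tone parameters are placeholders.

import numpy as np
import soundfile as sf

# Synthetic 30 s, 16 kHz tone standing in for an overly long emotion prompt.
sr = 16000
t = np.linspace(0, 30, 30 * sr, endpoint=False)
sf.write("long_prompt.wav", 0.1 * np.sin(2 * np.pi * 220 * t), sr)

# tts is assumed to be an IndexTTS2 instance constructed elsewhere.
audio, out_sr = tts._load_and_cut_audio("long_prompt.wav", 15, verbose=True, sr=16000)
print(audio.shape, out_sr)  # expected: torch.Size([1, 240000]) 16000

With verbose=True the helper also prints the truncation message from the diff, which makes it easy to confirm the cap is being applied to both the speaker and emotion prompts.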