Mirror of https://github.com/index-tts/index-tts.git
Merge branch 'main' into perf/gumbel_softmax_sampler
@@ -59,7 +59,6 @@ class AccelInferenceEngine:
         self.block_size = block_size
         self.num_blocks = num_blocks
         self.use_cuda_graph = use_cuda_graph and torch.cuda.is_available()
-        model_dtype = next(model.parameters()).dtype
         self.hidden_size = (
             model.config.hidden_size
             if hasattr(model, "config")
@@ -102,14 +101,15 @@ class AccelInferenceEngine:
             max_seqlen_k = max(seqlen_k, max_seqlen_k)
 
             if req.block_table:
-                for i in range(req.num_cached_blocks, req.num_blocks):
-                    block_id = req.block_table[i]
-                    start = block_id * self.block_size
-                    if i != req.num_blocks - 1:
-                        end = start + self.block_size
-                    else:
-                        end = start + req.last_block_num_tokens
-                    slot_mapping.extend(list(range(start, end)))
+                num_cached = req.num_cached_tokens
+                num_total = len(req)
+
+                for token_idx in range(num_cached, num_total):
+                    block_idx = token_idx // self.block_size
+                    block_offset = token_idx % self.block_size
+                    block_id = req.block_table[block_idx]
+                    slot_idx = block_id * self.block_size + block_offset
+                    slot_mapping.append(slot_idx)
 
         input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(
             non_blocking=True
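Note: the rewritten loop computes one flat KV-cache slot per uncached token instead of expanding whole blocks, which also covers a partially filled last block without the old special case. A standalone sketch of the same indexing (plain Python, hypothetical helper name, not code from this repo):

# Illustration only: map token positions [num_cached, num_total) to flat cache slots.
def token_slot_mapping(block_table, num_cached, num_total, block_size):
    slots = []
    for token_idx in range(num_cached, num_total):
        block_idx = token_idx // block_size    # logical block holding this token
        block_offset = token_idx % block_size  # position inside that block
        block_id = block_table[block_idx]      # physical block id from the table
        slots.append(block_id * block_size + block_offset)
    return slots

# block_size=4, physical blocks [7, 2], 3 tokens already cached, 6 tokens total
# -> tokens 3..5 map to slots [31, 8, 9].
assert token_slot_mapping([7, 2], 3, 6, 4) == [31, 8, 9]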
@@ -218,13 +218,12 @@ class AccelInferenceEngine:
         ).cuda(non_blocking=True)
         return temperatures
 
-    @torch.inference_mode()
     def _capture_cuda_graphs(self, tts_mel_embedding=None, tts_text_pos_embedding=None):
         print("Capturing CUDA graphs for decode optimization...")
         max_bs = 8  # Support up to batch size 8
         max_num_blocks = (2048 + self.block_size - 1) // self.block_size
         model_dtype = next(self.model.parameters()).dtype
-        input_ids = torch.ones(max_bs, dtype=torch.int64, device="cuda") * 8192
+        input_ids = torch.ones(max_bs, dtype=torch.int64, device="cuda")
         positions = torch.ones(max_bs, dtype=torch.int64, device="cuda")
         slot_mapping = torch.zeros(max_bs, dtype=torch.int32, device="cuda")
         context_lens = torch.zeros(max_bs, dtype=torch.int32, device="cuda")
@@ -238,7 +237,7 @@ class AccelInferenceEngine:
             max_bs, self.hidden_size, dtype=model_dtype, device="cuda"
         )
 
-        self.graph_bs = [1]
+        self.graph_bs = [1, 2, 4, 8]
 
         use_tts = tts_mel_embedding is not None and tts_text_pos_embedding is not None
 
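Note: capturing graphs for batch sizes [1, 2, 4, 8] presumably lets decode at any live batch size up to 8 reuse the smallest captured graph that fits, with unused slots padded. A hedged sketch of one common lookup pattern (assumed, not taken from this repo):

# Pick the smallest captured graph batch size that can hold the live batch.
graph_bs = [1, 2, 4, 8]

def pick_graph_bs(bs):
    return next(g for g in graph_bs if g >= bs)

assert [pick_graph_bs(b) for b in (1, 2, 3, 5, 8)] == [1, 2, 4, 8, 8]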
@@ -247,7 +246,7 @@ class AccelInferenceEngine:
 
             slot_mapping[:bs] = torch.arange(bs, dtype=torch.int32, device="cuda")
             context_lens[:bs] = bs + 1
-            block_tables[:bs, 0] = 0
+            block_tables[:bs, :] = 0
 
             set_forward_context(
                 False,
@@ -310,7 +309,6 @@ class AccelInferenceEngine:
         }
         print(f"CUDA graphs captured for batch sizes: {self.graph_bs}")
 
-    @torch.inference_mode()
     def _run_decode_with_graph(
         self,
         input_ids: torch.Tensor,
@@ -363,19 +361,13 @@ class AccelInferenceEngine:
         if graph_vars is None:
             raise RuntimeError("Graph variables not initialized")
 
-        set_forward_context(
-            False,
-            slot_mapping=graph_vars["slot_mapping"][:graph_bs],
-            context_lens=graph_vars["context_lens"][:graph_bs],
-            block_tables=graph_vars["block_tables"][:graph_bs],
-        )
-
         graph_vars["input_ids"][:bs] = input_ids
         graph_vars["positions"][:bs] = positions
         graph_vars["slot_mapping"].fill_(-1)
         graph_vars["slot_mapping"][:bs] = context.slot_mapping
         graph_vars["context_lens"].zero_()
         graph_vars["context_lens"][:bs] = context.context_lens
+        graph_vars["block_tables"][:bs, :].fill_(-1)
         graph_vars["block_tables"][:bs, : context.block_tables.size(1)] = (
             context.block_tables
         )
@@ -383,7 +375,6 @@ class AccelInferenceEngine:
 
         return graph_vars["outputs"][:bs]
 
-    @torch.inference_mode()
     def generate(
         self,
         input_ids: torch.Tensor,
@@ -440,14 +431,26 @@ class AccelInferenceEngine:
 
         if tts_embeddings is not None:
             actual_seq_len = tts_embeddings.size(1) + 1  # embeddings + start_mel_token
-            pass
         else:
             actual_seq_len = input_ids.size(1)
 
+        is_varlen_batch = (
+            tts_embeddings is not None
+            and attention_mask is not None
+            and batch_size > 1
+            and (attention_mask.sum(dim=1) != attention_mask.size(1)).any()
+        )
+
+        if is_varlen_batch:
+            seq_lens = [attention_mask[i].sum().item() for i in range(batch_size)]
+        else:
+            seq_lens = [actual_seq_len] * batch_size
+
         sequences = []
         for i in range(batch_size):
-            token_ids = [1] * actual_seq_len
-            if tts_embeddings is not None and actual_seq_len > 0:
+            seq_len = seq_lens[i]
+            token_ids = [1] * seq_len
+            if tts_embeddings is not None and seq_len > 0:
                 token_ids[-1] = input_ids[i, -1].item() if input_ids.size(1) > 0 else 1
             else:
                 token_ids = input_ids[i].tolist()
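Note: the is_varlen_batch check and the seq_lens list are driven entirely by the attention mask. A minimal sketch of that computation on toy tensors (values made up):

import torch

# Left-padded toy batch: 0 = padding, 1 = real token.
attention_mask = torch.tensor([
    [0, 0, 1, 1, 1],   # true length 3
    [1, 1, 1, 1, 1],   # true length 5 (no padding)
])

batch_size = attention_mask.size(0)
is_varlen = (attention_mask.sum(dim=1) != attention_mask.size(1)).any()
seq_lens = [attention_mask[i].sum().item() for i in range(batch_size)]
print(is_varlen.item(), seq_lens)  # True [3, 5]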
@@ -457,18 +460,8 @@ class AccelInferenceEngine:
 
         self.current_sequences = sequences
 
-        # Prefill phase
         prefill_ids, prefill_pos = self._prepare_prefill(sequences)
 
-        if prefill_ids.dim() == 1:
-            prefill_ids = prefill_ids.unsqueeze(
-                0
-            )  # [total_tokens] -> [1, total_tokens]
-        if prefill_pos.dim() == 1:
-            prefill_pos = prefill_pos.unsqueeze(
-                0
-            )  # [total_tokens] -> [1, total_tokens]
-
         if (
             tts_embeddings is not None
             and tts_mel_embedding is not None
@@ -480,11 +473,31 @@ class AccelInferenceEngine:
                 torch.tensor([[start_token_id]], device="cuda")
             )  # [1, 1, hidden_dim]
 
-            start_emb = start_emb + tts_text_pos_embedding(start_emb)
+            start_pos = torch.tensor(
+                [[tts_embeddings.size(1)]], device="cuda", dtype=torch.long
+            )
+            pos_emb = tts_text_pos_embedding.emb(start_pos)
+            start_emb = start_emb + pos_emb
+            start_emb = start_emb.repeat(batch_size, 1, 1)
 
-            full_embeddings = torch.cat(
-                [tts_embeddings, start_emb], dim=1
-            )  # [1, 88, hidden_dim]
+            if is_varlen_batch:
+                valid_embeddings = []
+                for i in range(batch_size):
+                    emb_len = seq_lens[i] - 1
+                    padding_len = tts_embeddings.size(1) - emb_len
+                    valid_emb = tts_embeddings[i, padding_len:].unsqueeze(
+                        0
+                    )  # [1, emb_len, hidden_dim]
+                    valid_embeddings.append(
+                        torch.cat([valid_emb, start_emb[i : i + 1]], dim=1)
+                    )
+                full_embeddings = torch.cat(
+                    valid_embeddings, dim=1
+                )  # [1, total_tokens, hidden_dim]
+            else:
+                full_embeddings = torch.cat(
+                    [tts_embeddings, start_emb], dim=1
+                )  # [batch_size, seq_len, hidden_dim]
 
             model_dtype = next(self.model.parameters()).dtype
             if full_embeddings.dtype != model_dtype:
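Note: for a varlen batch the per-sequence embeddings are stripped of their left padding and packed along the sequence dimension, so the prefill sees a single [1, total_tokens, hidden] tensor. A toy sketch of the same packing (shapes and names are illustrative):

import torch

tts_embeddings = torch.randn(2, 4, 3)  # [batch, max_prompt_len, hidden], left-padded
start_emb = torch.randn(2, 1, 3)       # per-sequence start-token embedding
seq_lens = [3, 5]                      # true prompt lengths incl. the start token

packed = []
for i in range(2):
    emb_len = seq_lens[i] - 1                       # embeddings excl. start token
    padding_len = tts_embeddings.size(1) - emb_len  # leading padding to drop
    valid = tts_embeddings[i, padding_len:].unsqueeze(0)
    packed.append(torch.cat([valid, start_emb[i : i + 1]], dim=1))

full_embeddings = torch.cat(packed, dim=1)
print(full_embeddings.shape)  # torch.Size([1, 8, 3]) == [1, sum(seq_lens), hidden]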
@@ -499,9 +512,16 @@ class AccelInferenceEngine:
                 input_ids=input_ids, attention_mask=attention_mask, return_dict=True
             ).last_hidden_state
 
-        reset_forward_context()
+        if is_varlen_batch:
+            context = get_forward_context()
+            cu_seqlens = context.cu_seqlens_q.cpu().tolist()
+            last_hidden = torch.stack(
+                [hidden_states[0, cu_seqlens[i + 1] - 1] for i in range(batch_size)]
+            )
+        else:
+            last_hidden = hidden_states[:, -1, :]  # [batch_size, hidden_size]
 
-        last_hidden = hidden_states[:, -1, :]  # [batch_size, hidden_size]
+        reset_forward_context()
 
         if self.lm_head is not None:
             if last_hidden.dtype != next(self.lm_head.parameters()).dtype:
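Note: with packed prefill the hidden states come back as [1, total_tokens, hidden], so the last position of each sequence is located through the cumulative sequence lengths rather than index -1. A toy illustration of that gather (cu_seqlens values invented):

import torch

hidden_states = torch.arange(12.0).reshape(1, 6, 2)  # [1, total_tokens=6, hidden=2]
cu_seqlens = [0, 2, 6]                               # two sequences of lengths 2 and 4
batch_size = len(cu_seqlens) - 1

last_hidden = torch.stack(
    [hidden_states[0, cu_seqlens[i + 1] - 1] for i in range(batch_size)]
)
print(last_hidden)  # rows at token indices 1 and 5 -> [[2., 3.], [10., 11.]]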
@@ -519,15 +539,17 @@ class AccelInferenceEngine:
         first_token_list = first_token.tolist()
 
         generated_tokens = [[] for _ in range(batch_size)]
-        hit_stop_on_first = False
+        is_finished = [False] * batch_size
 
         for i, token_id in enumerate(first_token_list):
             if stop_tokens and token_id in stop_tokens:
-                hit_stop_on_first = True
+                is_finished[i] = True
             else:
                 generated_tokens[i].append(token_id)
+                sequences[i].append_token(token_id)
+                self.kv_manager.append_to_seq(sequences[i])
 
-        if hit_stop_on_first:
+        if all(is_finished):
             for req in sequences:
                 self.kv_manager.remove_seq(req)
             self.current_sequences = []
@@ -540,22 +562,11 @@ class AccelInferenceEngine:
             output = torch.tensor(output_ids, dtype=torch.long, device=device)
             return output
 
-        if not hit_stop_on_first:
-            for i, req in enumerate(sequences):
-                req.append_token(first_token_list[i])
-                self.kv_manager.append_to_seq(req)
-
         remaining_tokens = max_new_tokens - 1
 
         for step in range(remaining_tokens):
             decode_ids, decode_pos = self._prepare_decode(sequences)
 
-            # Forward pass
-            if batch_size > 8:
-                raise RuntimeError(
-                    f"FATAL: batch_size={batch_size} exceeds CUDA Graph limit (8)!"
-                )
-
             context = get_forward_context()
             hidden_states = self._run_decode_with_graph(
                 decode_ids,
@@ -582,32 +593,67 @@ class AccelInferenceEngine:
             next_token = torch.argmax(logits, dim=-1)
             next_token_list = next_token.tolist()
 
-            should_stop = False
             for i, token_id in enumerate(next_token_list):
-                if stop_tokens and token_id in stop_tokens:
-                    should_stop = True
+                if is_finished[i]:
+                    continue
+                elif stop_tokens and token_id in stop_tokens:
+                    is_finished[i] = True
                 else:
                     sequences[i].append_token(token_id)
                     self.kv_manager.append_to_seq(sequences[i])
                     generated_tokens[i].append(token_id)
 
-            if should_stop:
+            if all(is_finished):
                 break
 
         for req in sequences:
            self.kv_manager.remove_seq(req)
        self.current_sequences = []
 
-        output_ids = []
-        for i in range(batch_size):
-            initial_tokens = sequences[i].token_ids[: sequences[i].num_prompt_tokens]
-            full_sequence = initial_tokens + generated_tokens[i]
-            output_ids.append(full_sequence)
+        pad_token = stop_tokens[0] if stop_tokens else 0
 
-        output = torch.tensor(output_ids, dtype=torch.long, device=device)
+        if is_varlen_batch:
+            max_prompt_len = attention_mask.size(1)
+            output_ids = []
+
+            for i in range(batch_size):
+                padding_len = max_prompt_len - seq_lens[i]
+                initial_tokens = sequences[i].token_ids[
+                    : sequences[i].num_prompt_tokens
+                ]
+                padded_prompt = [pad_token] * padding_len + initial_tokens
+                full_sequence = padded_prompt + generated_tokens[i]
+                output_ids.append(full_sequence)
+        else:
+            output_ids = [
+                sequences[i].token_ids[: sequences[i].num_prompt_tokens]
+                + generated_tokens[i]
+                for i in range(batch_size)
+            ]
+
+        max_length = max(len(seq) for seq in output_ids)
+        padded_output_ids = [
+            seq + [pad_token] * (max_length - len(seq)) for seq in output_ids
+        ]
+
+        output = torch.tensor(padded_output_ids, dtype=torch.long, device=device)
 
         assert output.size(0) == batch_size, (
             f"Output batch size mismatch: {output.size(0)} != {batch_size}"
         )
 
         return output
+
+
+class Sampler(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @torch.compile
+    def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
+        logits = logits.float().div_(temperatures.unsqueeze(dim=1))
+        probs = torch.softmax(logits, dim=-1)
+        sample_tokens = probs.div_(
+            torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)
+        ).argmax(dim=-1)
+        return sample_tokens
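Note: the new Sampler avoids torch.multinomial by dividing the softmax probabilities by i.i.d. Exponential(1) noise and taking the argmax, i.e. the exponential-race form of the Gumbel-max trick; the clamp_min_ only guards against division by zero. A quick self-contained check of the equivalence on a toy distribution (not code from this repo):

import torch

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.2, 0.7])

# argmax_i probs[i] / E_i with E_i ~ Exp(1) is distributed like
# torch.multinomial(probs, 1); empirical frequencies should approach probs.
draws = torch.stack([
    probs.div(torch.empty_like(probs).exponential_(1)).argmax()
    for _ in range(20_000)
])
print(torch.bincount(draws, minlength=3) / draws.numel())  # ~[0.10, 0.20, 0.70]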