Mirror of https://github.com/index-tts/index-tts.git
Merge branch 'main' into perf/gumbel_softmax_sampler
@@ -59,7 +59,6 @@ class AccelInferenceEngine:
        self.block_size = block_size
        self.num_blocks = num_blocks
        self.use_cuda_graph = use_cuda_graph and torch.cuda.is_available()
        model_dtype = next(model.parameters()).dtype
        self.hidden_size = (
            model.config.hidden_size
            if hasattr(model, "config")
@@ -102,14 +101,15 @@ class AccelInferenceEngine:
        max_seqlen_k = max(seqlen_k, max_seqlen_k)

            if req.block_table:
                for i in range(req.num_cached_blocks, req.num_blocks):
                    block_id = req.block_table[i]
                    start = block_id * self.block_size
                    if i != req.num_blocks - 1:
                        end = start + self.block_size
                    else:
                        end = start + req.last_block_num_tokens
                    slot_mapping.extend(list(range(start, end)))
            num_cached = req.num_cached_tokens
            num_total = len(req)

            for token_idx in range(num_cached, num_total):
                block_idx = token_idx // self.block_size
                block_offset = token_idx % self.block_size
                block_id = req.block_table[block_idx]
                slot_idx = block_id * self.block_size + block_offset
                slot_mapping.append(slot_idx)

        input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(
            non_blocking=True
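The replacement loop above walks tokens rather than whole blocks: every not-yet-cached token is mapped to a flat KV-cache slot, slot = block_id * block_size + offset-within-block. A self-contained sketch of that mapping (the function name and inputs are illustrative, not the engine's API):

from typing import List

def compute_slot_mapping(block_table: List[int], block_size: int,
                         num_cached: int, num_total: int) -> List[int]:
    # For each token that is not yet in the KV cache, find its physical slot.
    slot_mapping = []
    for token_idx in range(num_cached, num_total):
        block_idx = token_idx // block_size    # which logical block holds the token
        block_offset = token_idx % block_size  # position inside that block
        block_id = block_table[block_idx]      # physical block assigned to the sequence
        slot_mapping.append(block_id * block_size + block_offset)
    return slot_mapping

# Blocks 7 and 3 assigned, block_size=4, tokens 0-1 already cached:
assert compute_slot_mapping([7, 3], 4, 2, 6) == [30, 31, 12, 13]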
@@ -218,13 +218,12 @@ class AccelInferenceEngine:
        ).cuda(non_blocking=True)
        return temperatures

    @torch.inference_mode()
    def _capture_cuda_graphs(self, tts_mel_embedding=None, tts_text_pos_embedding=None):
        print("Capturing CUDA graphs for decode optimization...")
        max_bs = 8  # Support up to batch size 8
        max_num_blocks = (2048 + self.block_size - 1) // self.block_size
        model_dtype = next(self.model.parameters()).dtype
        input_ids = torch.ones(max_bs, dtype=torch.int64, device="cuda") * 8192
        input_ids = torch.ones(max_bs, dtype=torch.int64, device="cuda")
        positions = torch.ones(max_bs, dtype=torch.int64, device="cuda")
        slot_mapping = torch.zeros(max_bs, dtype=torch.int32, device="cuda")
        context_lens = torch.zeros(max_bs, dtype=torch.int32, device="cuda")
@@ -238,7 +237,7 @@ class AccelInferenceEngine:
            max_bs, self.hidden_size, dtype=model_dtype, device="cuda"
        )

        self.graph_bs = [1]
        self.graph_bs = [1, 2, 4, 8]

        use_tts = tts_mel_embedding is not None and tts_text_pos_embedding is not None
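Capturing graphs for the buckets [1, 2, 4, 8] instead of only [1] lets any decode batch of up to 8 sequences reuse a recorded graph, presumably by padding the batch up to the nearest bucket. A rough sketch of that selection (pick_graph_bs is a hypothetical helper, not code from this commit):

graph_bs = [1, 2, 4, 8]  # bucket sizes captured above

def pick_graph_bs(batch_size: int) -> int:
    # Hypothetical helper: replay the smallest captured graph that fits the batch.
    if batch_size > graph_bs[-1]:
        raise RuntimeError(f"batch_size={batch_size} exceeds the largest captured graph")
    return next(b for b in graph_bs if b >= batch_size)

assert pick_graph_bs(3) == 4  # a 3-sequence batch is padded into the bs=4 graph
assert pick_graph_bs(8) == 8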
@@ -247,7 +246,7 @@ class AccelInferenceEngine:

            slot_mapping[:bs] = torch.arange(bs, dtype=torch.int32, device="cuda")
            context_lens[:bs] = bs + 1
            block_tables[:bs, 0] = 0
            block_tables[:bs, :] = 0

            set_forward_context(
                False,
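Each batch-size bucket is captured once against fixed static tensors; at decode time the engine only overwrites those tensors and replays the graph. A minimal sketch of that general capture/replay pattern with torch.cuda.CUDAGraph, using a toy model in place of the GPT decode step (the shapes, the Linear model, and the warm-up stream are illustrative; a CUDA device is required):

import torch

model = torch.nn.Linear(16, 16, device="cuda").eval()
static_in = torch.zeros(8, 16, device="cuda")    # buffers sized for the largest bucket

# Warm up on a side stream before capture, as recommended by PyTorch.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s), torch.no_grad():
    model(static_in)
torch.cuda.current_stream().wait_stream(s)

graph = torch.cuda.CUDAGraph()
with torch.no_grad(), torch.cuda.graph(graph):
    static_out = model(static_in)                # this output tensor is reused on every replay

# Each later step overwrites the static input and replays the recorded kernels.
static_in.copy_(torch.randn(8, 16, device="cuda"))
graph.replay()
next_result = static_out.clone()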
@@ -310,7 +309,6 @@ class AccelInferenceEngine:
        }
        print(f"CUDA graphs captured for batch sizes: {self.graph_bs}")

    @torch.inference_mode()
    def _run_decode_with_graph(
        self,
        input_ids: torch.Tensor,
@@ -363,19 +361,13 @@ class AccelInferenceEngine:
        if graph_vars is None:
            raise RuntimeError("Graph variables not initialized")

        set_forward_context(
            False,
            slot_mapping=graph_vars["slot_mapping"][:graph_bs],
            context_lens=graph_vars["context_lens"][:graph_bs],
            block_tables=graph_vars["block_tables"][:graph_bs],
        )

        graph_vars["input_ids"][:bs] = input_ids
        graph_vars["positions"][:bs] = positions
        graph_vars["slot_mapping"].fill_(-1)
        graph_vars["slot_mapping"][:bs] = context.slot_mapping
        graph_vars["context_lens"].zero_()
        graph_vars["context_lens"][:bs] = context.context_lens
        graph_vars["block_tables"][:bs, :].fill_(-1)
        graph_vars["block_tables"][:bs, : context.block_tables.size(1)] = (
            context.block_tables
        )
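Because a captured graph always executes at its bucket size, a smaller live batch is copied into the front of each static buffer and the unused tail is reset to inert values (-1 slot indices, zero context lengths) so the padded lanes do no meaningful work. A sketch of that copy-in step with graph_vars as a plain dict of preallocated tensors (field names mirror the code above; shapes and values are made up):

import torch

bs, graph_bs = 3, 4  # live batch of 3 replayed on the bs=4 graph

graph_vars = {  # stand-ins for the preallocated static buffers
    "input_ids": torch.zeros(graph_bs, dtype=torch.int64),
    "slot_mapping": torch.empty(graph_bs, dtype=torch.int32),
    "context_lens": torch.empty(graph_bs, dtype=torch.int32),
}

live_ids = torch.tensor([11, 22, 33])
live_slots = torch.tensor([5, 9, 42], dtype=torch.int32)
live_lens = torch.tensor([7, 3, 12], dtype=torch.int32)

graph_vars["input_ids"][:bs] = live_ids
graph_vars["slot_mapping"].fill_(-1)          # padded lanes point at no cache slot
graph_vars["slot_mapping"][:bs] = live_slots
graph_vars["context_lens"].zero_()            # padded lanes see zero context
graph_vars["context_lens"][:bs] = live_lens
# graph.replay() would now run the recorded kernels on the bucket-sized buffers.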
@@ -383,7 +375,6 @@ class AccelInferenceEngine:

        return graph_vars["outputs"][:bs]

    @torch.inference_mode()
    def generate(
        self,
        input_ids: torch.Tensor,
@@ -440,14 +431,26 @@ class AccelInferenceEngine:

        if tts_embeddings is not None:
            actual_seq_len = tts_embeddings.size(1) + 1  # embeddings + start_mel_token
            pass
        else:
            actual_seq_len = input_ids.size(1)

        is_varlen_batch = (
            tts_embeddings is not None
            and attention_mask is not None
            and batch_size > 1
            and (attention_mask.sum(dim=1) != attention_mask.size(1)).any()
        )

        if is_varlen_batch:
            seq_lens = [attention_mask[i].sum().item() for i in range(batch_size)]
        else:
            seq_lens = [actual_seq_len] * batch_size

        sequences = []
        for i in range(batch_size):
            token_ids = [1] * actual_seq_len
            if tts_embeddings is not None and actual_seq_len > 0:
            seq_len = seq_lens[i]
            token_ids = [1] * seq_len
            if tts_embeddings is not None and seq_len > 0:
                token_ids[-1] = input_ids[i, -1].item() if input_ids.size(1) > 0 else 1
            else:
                token_ids = input_ids[i].tolist()
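A batch is treated as variable-length only when the TTS embedding path is active and at least one attention-mask row is shorter than the padded width; the per-sequence lengths are then just the mask row sums. A small sketch of that check, assuming a left-padded 0/1 mask as used here:

import torch

# Two prompts left-padded to width 5; 0 marks padding.
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])

is_varlen = (attention_mask.sum(dim=1) != attention_mask.size(1)).any().item()
seq_lens = attention_mask.sum(dim=1).tolist()

assert is_varlen is True
assert seq_lens == [3, 5]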
@@ -457,18 +460,8 @@ class AccelInferenceEngine:

        self.current_sequences = sequences

        # Prefill phase
        prefill_ids, prefill_pos = self._prepare_prefill(sequences)

        if prefill_ids.dim() == 1:
            prefill_ids = prefill_ids.unsqueeze(
                0
            )  # [total_tokens] -> [1, total_tokens]
        if prefill_pos.dim() == 1:
            prefill_pos = prefill_pos.unsqueeze(
                0
            )  # [total_tokens] -> [1, total_tokens]

        if (
            tts_embeddings is not None
            and tts_mel_embedding is not None
@@ -480,11 +473,31 @@ class AccelInferenceEngine:
                torch.tensor([[start_token_id]], device="cuda")
            )  # [1, 1, hidden_dim]

            start_emb = start_emb + tts_text_pos_embedding(start_emb)
            start_pos = torch.tensor(
                [[tts_embeddings.size(1)]], device="cuda", dtype=torch.long
            )
            pos_emb = tts_text_pos_embedding.emb(start_pos)
            start_emb = start_emb + pos_emb
            start_emb = start_emb.repeat(batch_size, 1, 1)

            full_embeddings = torch.cat(
                [tts_embeddings, start_emb], dim=1
            )  # [1, 88, hidden_dim]
            if is_varlen_batch:
                valid_embeddings = []
                for i in range(batch_size):
                    emb_len = seq_lens[i] - 1
                    padding_len = tts_embeddings.size(1) - emb_len
                    valid_emb = tts_embeddings[i, padding_len:].unsqueeze(
                        0
                    )  # [1, emb_len, hidden_dim]
                    valid_embeddings.append(
                        torch.cat([valid_emb, start_emb[i : i + 1]], dim=1)
                    )
                full_embeddings = torch.cat(
                    valid_embeddings, dim=1
                )  # [1, total_tokens, hidden_dim]
            else:
                full_embeddings = torch.cat(
                    [tts_embeddings, start_emb], dim=1
                )  # [batch_size, seq_len, hidden_dim]

            model_dtype = next(self.model.parameters()).dtype
            if full_embeddings.dtype != model_dtype:
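For a variable-length batch, the left padding is stripped from each row of tts_embeddings and the valid pieces, each followed by its start-token embedding, are concatenated along the sequence axis into a single packed [1, total_tokens, hidden] tensor for prefill. A shape-only sketch of that packing with made-up tensors:

import torch

hidden = 4
seq_lens = [3, 5]                                  # valid lengths per sequence
max_len = max(seq_lens)
tts_embeddings = torch.randn(2, max_len, hidden)   # left-padded batch
start_emb = torch.randn(2, 1, hidden)              # one start-token embedding per sequence

pieces = []
for i, seq_len in enumerate(seq_lens):
    emb_len = seq_len - 1                          # the last position is the start token
    pad = max_len - emb_len
    pieces.append(torch.cat([tts_embeddings[i, pad:].unsqueeze(0),
                             start_emb[i:i + 1]], dim=1))

packed = torch.cat(pieces, dim=1)                  # [1, sum(seq_lens), hidden]
assert packed.shape == (1, sum(seq_lens), hidden)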
@@ -499,9 +512,16 @@ class AccelInferenceEngine:
            input_ids=input_ids, attention_mask=attention_mask, return_dict=True
        ).last_hidden_state

        reset_forward_context()
        if is_varlen_batch:
            context = get_forward_context()
            cu_seqlens = context.cu_seqlens_q.cpu().tolist()
            last_hidden = torch.stack(
                [hidden_states[0, cu_seqlens[i + 1] - 1] for i in range(batch_size)]
            )
        else:
            last_hidden = hidden_states[:, -1, :]  # [batch_size, hidden_size]

        last_hidden = hidden_states[:, -1, :]  # [batch_size, hidden_size]
        reset_forward_context()

        if self.lm_head is not None:
            if last_hidden.dtype != next(self.lm_head.parameters()).dtype:
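With the packed layout, the final position of each sequence is recovered from the cumulative sequence lengths used by varlen attention kernels: sequence i ends at flat index cu_seqlens[i + 1] - 1. A sketch with cu_seqlens built by hand instead of read from the forward context:

import torch

hidden_states = torch.randn(1, 8, 4)       # packed prefill output: [1, total_tokens, hidden]
seq_lens = [3, 5]
cu_seqlens = [0]
for n in seq_lens:
    cu_seqlens.append(cu_seqlens[-1] + n)  # [0, 3, 8]

last_hidden = torch.stack(
    [hidden_states[0, cu_seqlens[i + 1] - 1] for i in range(len(seq_lens))]
)
assert last_hidden.shape == (2, 4)         # one vector per sequence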
@@ -519,15 +539,17 @@ class AccelInferenceEngine:
        first_token_list = first_token.tolist()

        generated_tokens = [[] for _ in range(batch_size)]
        hit_stop_on_first = False
        is_finished = [False] * batch_size

        for i, token_id in enumerate(first_token_list):
            if stop_tokens and token_id in stop_tokens:
                hit_stop_on_first = True
                is_finished[i] = True
            else:
                generated_tokens[i].append(token_id)
                sequences[i].append_token(token_id)
                self.kv_manager.append_to_seq(sequences[i])

        if hit_stop_on_first:
        if all(is_finished):
            for req in sequences:
                self.kv_manager.remove_seq(req)
            self.current_sequences = []
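Replacing the single hit_stop_on_first flag with a per-sequence is_finished list lets one sequence stop on its end token while the rest keep decoding; cleanup happens only once every entry is finished. A stand-alone sketch of that bookkeeping with made-up token streams:

STOP = 99
streams = [[5, 99, 99], [7, 8, 99]]         # hypothetical per-sequence next tokens
is_finished = [False, False]
generated = [[], []]

for step in range(3):
    for i, stream in enumerate(streams):
        token = stream[step]
        if is_finished[i]:
            continue                        # frozen sequences ignore new tokens
        if token == STOP:
            is_finished[i] = True
        else:
            generated[i].append(token)
    if all(is_finished):
        break                               # stop only when every sequence is done

assert generated == [[5], [7, 8]]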
@@ -540,22 +562,11 @@ class AccelInferenceEngine:
            output = torch.tensor(output_ids, dtype=torch.long, device=device)
            return output

        if not hit_stop_on_first:
            for i, req in enumerate(sequences):
                req.append_token(first_token_list[i])
                self.kv_manager.append_to_seq(req)

        remaining_tokens = max_new_tokens - 1

        for step in range(remaining_tokens):
            decode_ids, decode_pos = self._prepare_decode(sequences)

            # Forward pass
            if batch_size > 8:
                raise RuntimeError(
                    f"FATAL: batch_size={batch_size} exceeds CUDA Graph limit (8)!"
                )

            context = get_forward_context()
            hidden_states = self._run_decode_with_graph(
                decode_ids,
@@ -582,32 +593,67 @@ class AccelInferenceEngine:
            next_token = torch.argmax(logits, dim=-1)
            next_token_list = next_token.tolist()

            should_stop = False
            for i, token_id in enumerate(next_token_list):
                if stop_tokens and token_id in stop_tokens:
                    should_stop = True
                if is_finished[i]:
                    continue
                elif stop_tokens and token_id in stop_tokens:
                    is_finished[i] = True
                else:
                    sequences[i].append_token(token_id)
                    self.kv_manager.append_to_seq(sequences[i])
                    generated_tokens[i].append(token_id)

            if should_stop:
            if all(is_finished):
                break

        for req in sequences:
            self.kv_manager.remove_seq(req)
        self.current_sequences = []

        output_ids = []
        for i in range(batch_size):
            initial_tokens = sequences[i].token_ids[: sequences[i].num_prompt_tokens]
            full_sequence = initial_tokens + generated_tokens[i]
            output_ids.append(full_sequence)
        pad_token = stop_tokens[0] if stop_tokens else 0

        output = torch.tensor(output_ids, dtype=torch.long, device=device)
        if is_varlen_batch:
            max_prompt_len = attention_mask.size(1)
            output_ids = []

            for i in range(batch_size):
                padding_len = max_prompt_len - seq_lens[i]
                initial_tokens = sequences[i].token_ids[
                    : sequences[i].num_prompt_tokens
                ]
                padded_prompt = [pad_token] * padding_len + initial_tokens
                full_sequence = padded_prompt + generated_tokens[i]
                output_ids.append(full_sequence)
        else:
            output_ids = [
                sequences[i].token_ids[: sequences[i].num_prompt_tokens]
                + generated_tokens[i]
                for i in range(batch_size)
            ]

        max_length = max(len(seq) for seq in output_ids)
        padded_output_ids = [
            seq + [pad_token] * (max_length - len(seq)) for seq in output_ids
        ]

        output = torch.tensor(padded_output_ids, dtype=torch.long, device=device)

        assert output.size(0) == batch_size, (
            f"Output batch size mismatch: {output.size(0)} != {batch_size}"
        )

        return output
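Because sequences can now finish at different steps, the collected outputs are ragged; each row is right-padded with a pad token (the first stop token when one is given) up to the longest row before the batch is packed into a single tensor. A sketch:

import torch

output_ids = [[1, 2, 3, 7], [4, 5, 9]]     # ragged per-sequence outputs
pad_token = 0

max_length = max(len(seq) for seq in output_ids)
padded = [seq + [pad_token] * (max_length - len(seq)) for seq in output_ids]
output = torch.tensor(padded, dtype=torch.long)

assert output.shape == (2, 4)
assert output[1].tolist() == [4, 5, 9, 0]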
class Sampler(nn.Module):
    def __init__(self):
        super().__init__()

    @torch.compile
    def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
        logits = logits.float().div_(temperatures.unsqueeze(dim=1))
        probs = torch.softmax(logits, dim=-1)
        sample_tokens = probs.div_(
            torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)
        ).argmax(dim=-1)
        return sample_tokens
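The Sampler avoids torch.multinomial: after temperature scaling, dividing the softmax probabilities by i.i.d. Exponential(1) noise and taking the argmax is the exponential-race form of the Gumbel-max trick, which selects index i with probability probs[i] while staying easy to fuse under torch.compile (the clamp_min_ only guards against a zero divisor). A Monte-Carlo sanity sketch of that equivalence, separate from the engine:

import torch

torch.manual_seed(0)
logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
probs = torch.softmax(logits, dim=-1)

n = 200_000
noise = torch.empty(n, probs.size(1)).exponential_(1).clamp_min_(1e-10)
samples = (probs / noise).argmax(dim=-1)
empirical = torch.bincount(samples, minlength=probs.size(1)).float() / n

# The empirical frequencies should track the softmax probabilities closely.
assert torch.allclose(empirical, probs.squeeze(0), atol=0.01)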