Merge branch 'main' into perf/gumbel_softmax_sampler

This commit is contained in:
Vanka0051
2025-11-07 15:32:22 +08:00
committed by GitHub

View File

@@ -59,7 +59,6 @@ class AccelInferenceEngine:
self.block_size = block_size
self.num_blocks = num_blocks
self.use_cuda_graph = use_cuda_graph and torch.cuda.is_available()
model_dtype = next(model.parameters()).dtype
self.hidden_size = (
model.config.hidden_size
if hasattr(model, "config")
@@ -102,14 +101,15 @@ class AccelInferenceEngine:
max_seqlen_k = max(seqlen_k, max_seqlen_k)
if req.block_table:
for i in range(req.num_cached_blocks, req.num_blocks):
block_id = req.block_table[i]
start = block_id * self.block_size
if i != req.num_blocks - 1:
end = start + self.block_size
else:
end = start + req.last_block_num_tokens
slot_mapping.extend(list(range(start, end)))
num_cached = req.num_cached_tokens
num_total = len(req)
for token_idx in range(num_cached, num_total):
block_idx = token_idx // self.block_size
block_offset = token_idx % self.block_size
block_id = req.block_table[block_idx]
slot_idx = block_id * self.block_size + block_offset
slot_mapping.append(slot_idx)
input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(
non_blocking=True
@@ -218,13 +218,12 @@ class AccelInferenceEngine:
).cuda(non_blocking=True)
return temperatures
@torch.inference_mode()
def _capture_cuda_graphs(self, tts_mel_embedding=None, tts_text_pos_embedding=None):
print("Capturing CUDA graphs for decode optimization...")
max_bs = 8 # Support up to batch size 8
max_num_blocks = (2048 + self.block_size - 1) // self.block_size
model_dtype = next(self.model.parameters()).dtype
input_ids = torch.ones(max_bs, dtype=torch.int64, device="cuda") * 8192
input_ids = torch.ones(max_bs, dtype=torch.int64, device="cuda")
positions = torch.ones(max_bs, dtype=torch.int64, device="cuda")
slot_mapping = torch.zeros(max_bs, dtype=torch.int32, device="cuda")
context_lens = torch.zeros(max_bs, dtype=torch.int32, device="cuda")
@@ -238,7 +237,7 @@ class AccelInferenceEngine:
max_bs, self.hidden_size, dtype=model_dtype, device="cuda"
)
self.graph_bs = [1]
self.graph_bs = [1, 2, 4, 8]
use_tts = tts_mel_embedding is not None and tts_text_pos_embedding is not None
@@ -247,7 +246,7 @@ class AccelInferenceEngine:
slot_mapping[:bs] = torch.arange(bs, dtype=torch.int32, device="cuda")
context_lens[:bs] = bs + 1
block_tables[:bs, 0] = 0
block_tables[:bs, :] = 0
set_forward_context(
False,
@@ -310,7 +309,6 @@ class AccelInferenceEngine:
}
print(f"CUDA graphs captured for batch sizes: {self.graph_bs}")
@torch.inference_mode()
def _run_decode_with_graph(
self,
input_ids: torch.Tensor,
@@ -363,19 +361,13 @@ class AccelInferenceEngine:
if graph_vars is None:
raise RuntimeError("Graph variables not initialized")
set_forward_context(
False,
slot_mapping=graph_vars["slot_mapping"][:graph_bs],
context_lens=graph_vars["context_lens"][:graph_bs],
block_tables=graph_vars["block_tables"][:graph_bs],
)
graph_vars["input_ids"][:bs] = input_ids
graph_vars["positions"][:bs] = positions
graph_vars["slot_mapping"].fill_(-1)
graph_vars["slot_mapping"][:bs] = context.slot_mapping
graph_vars["context_lens"].zero_()
graph_vars["context_lens"][:bs] = context.context_lens
graph_vars["block_tables"][:bs, :].fill_(-1)
graph_vars["block_tables"][:bs, : context.block_tables.size(1)] = (
context.block_tables
)
@@ -383,7 +375,6 @@ class AccelInferenceEngine:
return graph_vars["outputs"][:bs]
@torch.inference_mode()
def generate(
self,
input_ids: torch.Tensor,
@@ -440,14 +431,26 @@ class AccelInferenceEngine:
if tts_embeddings is not None:
actual_seq_len = tts_embeddings.size(1) + 1 # embeddings + start_mel_token
pass
else:
actual_seq_len = input_ids.size(1)
is_varlen_batch = (
tts_embeddings is not None
and attention_mask is not None
and batch_size > 1
and (attention_mask.sum(dim=1) != attention_mask.size(1)).any()
)
if is_varlen_batch:
seq_lens = [attention_mask[i].sum().item() for i in range(batch_size)]
else:
seq_lens = [actual_seq_len] * batch_size
sequences = []
for i in range(batch_size):
token_ids = [1] * actual_seq_len
if tts_embeddings is not None and actual_seq_len > 0:
seq_len = seq_lens[i]
token_ids = [1] * seq_len
if tts_embeddings is not None and seq_len > 0:
token_ids[-1] = input_ids[i, -1].item() if input_ids.size(1) > 0 else 1
else:
token_ids = input_ids[i].tolist()
@@ -457,18 +460,8 @@ class AccelInferenceEngine:
self.current_sequences = sequences
# Prefill phase
prefill_ids, prefill_pos = self._prepare_prefill(sequences)
if prefill_ids.dim() == 1:
prefill_ids = prefill_ids.unsqueeze(
0
) # [total_tokens] -> [1, total_tokens]
if prefill_pos.dim() == 1:
prefill_pos = prefill_pos.unsqueeze(
0
) # [total_tokens] -> [1, total_tokens]
if (
tts_embeddings is not None
and tts_mel_embedding is not None
@@ -480,11 +473,31 @@ class AccelInferenceEngine:
torch.tensor([[start_token_id]], device="cuda")
) # [1, 1, hidden_dim]
start_emb = start_emb + tts_text_pos_embedding(start_emb)
start_pos = torch.tensor(
[[tts_embeddings.size(1)]], device="cuda", dtype=torch.long
)
pos_emb = tts_text_pos_embedding.emb(start_pos)
start_emb = start_emb + pos_emb
start_emb = start_emb.repeat(batch_size, 1, 1)
full_embeddings = torch.cat(
[tts_embeddings, start_emb], dim=1
) # [1, 88, hidden_dim]
if is_varlen_batch:
valid_embeddings = []
for i in range(batch_size):
emb_len = seq_lens[i] - 1
padding_len = tts_embeddings.size(1) - emb_len
valid_emb = tts_embeddings[i, padding_len:].unsqueeze(
0
) # [1, emb_len, hidden_dim]
valid_embeddings.append(
torch.cat([valid_emb, start_emb[i : i + 1]], dim=1)
)
full_embeddings = torch.cat(
valid_embeddings, dim=1
) # [1, total_tokens, hidden_dim]
else:
full_embeddings = torch.cat(
[tts_embeddings, start_emb], dim=1
) # [batch_size, seq_len, hidden_dim]
model_dtype = next(self.model.parameters()).dtype
if full_embeddings.dtype != model_dtype:
@@ -499,9 +512,16 @@ class AccelInferenceEngine:
input_ids=input_ids, attention_mask=attention_mask, return_dict=True
).last_hidden_state
reset_forward_context()
if is_varlen_batch:
context = get_forward_context()
cu_seqlens = context.cu_seqlens_q.cpu().tolist()
last_hidden = torch.stack(
[hidden_states[0, cu_seqlens[i + 1] - 1] for i in range(batch_size)]
)
else:
last_hidden = hidden_states[:, -1, :] # [batch_size, hidden_size]
last_hidden = hidden_states[:, -1, :] # [batch_size, hidden_size]
reset_forward_context()
if self.lm_head is not None:
if last_hidden.dtype != next(self.lm_head.parameters()).dtype:
@@ -519,15 +539,17 @@ class AccelInferenceEngine:
first_token_list = first_token.tolist()
generated_tokens = [[] for _ in range(batch_size)]
hit_stop_on_first = False
is_finished = [False] * batch_size
for i, token_id in enumerate(first_token_list):
if stop_tokens and token_id in stop_tokens:
hit_stop_on_first = True
is_finished[i] = True
else:
generated_tokens[i].append(token_id)
sequences[i].append_token(token_id)
self.kv_manager.append_to_seq(sequences[i])
if hit_stop_on_first:
if all(is_finished):
for req in sequences:
self.kv_manager.remove_seq(req)
self.current_sequences = []
@@ -540,22 +562,11 @@ class AccelInferenceEngine:
output = torch.tensor(output_ids, dtype=torch.long, device=device)
return output
if not hit_stop_on_first:
for i, req in enumerate(sequences):
req.append_token(first_token_list[i])
self.kv_manager.append_to_seq(req)
remaining_tokens = max_new_tokens - 1
for step in range(remaining_tokens):
decode_ids, decode_pos = self._prepare_decode(sequences)
# Forward pass
if batch_size > 8:
raise RuntimeError(
f"FATAL: batch_size={batch_size} exceeds CUDA Graph limit (8)!"
)
context = get_forward_context()
hidden_states = self._run_decode_with_graph(
decode_ids,
@@ -582,32 +593,67 @@ class AccelInferenceEngine:
next_token = torch.argmax(logits, dim=-1)
next_token_list = next_token.tolist()
should_stop = False
for i, token_id in enumerate(next_token_list):
if stop_tokens and token_id in stop_tokens:
should_stop = True
if is_finished[i]:
continue
elif stop_tokens and token_id in stop_tokens:
is_finished[i] = True
else:
sequences[i].append_token(token_id)
self.kv_manager.append_to_seq(sequences[i])
generated_tokens[i].append(token_id)
if should_stop:
if all(is_finished):
break
for req in sequences:
self.kv_manager.remove_seq(req)
self.current_sequences = []
output_ids = []
for i in range(batch_size):
initial_tokens = sequences[i].token_ids[: sequences[i].num_prompt_tokens]
full_sequence = initial_tokens + generated_tokens[i]
output_ids.append(full_sequence)
pad_token = stop_tokens[0] if stop_tokens else 0
output = torch.tensor(output_ids, dtype=torch.long, device=device)
if is_varlen_batch:
max_prompt_len = attention_mask.size(1)
output_ids = []
for i in range(batch_size):
padding_len = max_prompt_len - seq_lens[i]
initial_tokens = sequences[i].token_ids[
: sequences[i].num_prompt_tokens
]
padded_prompt = [pad_token] * padding_len + initial_tokens
full_sequence = padded_prompt + generated_tokens[i]
output_ids.append(full_sequence)
else:
output_ids = [
sequences[i].token_ids[: sequences[i].num_prompt_tokens]
+ generated_tokens[i]
for i in range(batch_size)
]
max_length = max(len(seq) for seq in output_ids)
padded_output_ids = [
seq + [pad_token] * (max_length - len(seq)) for seq in output_ids
]
output = torch.tensor(padded_output_ids, dtype=torch.long, device=device)
assert output.size(0) == batch_size, (
f"Output batch size mismatch: {output.size(0)} != {batch_size}"
)
return output
class Sampler(nn.Module):
def __init__(self):
super().__init__()
@torch.compile
def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
logits = logits.float().div_(temperatures.unsqueeze(dim=1))
probs = torch.softmax(logits, dim=-1)
sample_tokens = probs.div_(
torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)
).argmax(dim=-1)
return sample_tokens