# index-tts/indextts/accel/accel_engine.py

import sys
from typing import List, Optional

import torch
from torch import nn

from .attention import (
    ForwardContext,
    get_forward_context,
    reset_forward_context,
    set_forward_context,
)
from .kv_manager import KVCacheManager, Seq

class Sampler(nn.Module):
    def __init__(self):
        super().__init__()

    @torch.compile
    def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
        temperatures = temperatures.to(logits.device).clamp(min=1e-8)
        greedy_mask = temperatures < 1e-5
        temp_for_scaling = torch.where(greedy_mask, 1.0, temperatures)
        scaled_logits = logits / temp_for_scaling.unsqueeze(-1)
        probs = torch.softmax(scaled_logits, dim=-1, dtype=torch.float32)
        q = torch.empty_like(probs)
        q.exponential_()
        sampled_tokens = probs.div_(q).argmax(dim=-1)
        greedy_tokens = logits.argmax(dim=-1)
        return torch.where(greedy_mask, greedy_tokens, sampled_tokens)
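
# Note: Sampler.forward draws a categorical sample without torch.multinomial.
# Dividing the probabilities by i.i.d. Exponential(1) noise and taking the
# argmax (the exponential-race / Gumbel-max trick) is equivalent to sampling
# from softmax(logits / temperature); per-sequence temperatures below 1e-5
# fall back to greedy argmax over the raw logits.
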
class AccelInferenceEngine:
    def __init__(
        self,
        model,
        lm_head,
        num_layers: int,
        num_heads: int,
        head_dim: int,
        block_size: int = 256,
        num_blocks: int = 128,
        use_cuda_graph: bool = True,
    ):
        """
        Args:
            model: The GPT transformer model (should have accel attention)
            lm_head: Language model head for generating logits
            num_layers: Number of transformer layers
            num_heads: Number of attention heads
            head_dim: Dimension per head
            block_size: KV cache block size
            num_blocks: Total number of KV cache blocks
            use_cuda_graph: Whether to use CUDA Graph for decode optimization
        """
        self.model = model
        self.lm_head = lm_head
        self.block_size = block_size
        self.num_blocks = num_blocks
        self.use_cuda_graph = use_cuda_graph and torch.cuda.is_available()
        self.hidden_size = (
            model.config.hidden_size
            if hasattr(model, "config")
            else head_dim * num_heads
        )
        self.kv_manager = KVCacheManager(
            num_layers=num_layers,
            num_heads=num_heads,
            head_dim=head_dim,
            block_size=block_size,
            num_blocks=num_blocks,
            dtype=torch.float16,  # Force fp16 for FlashAttention
        )
        self.kv_manager.wire_kv_cache_to_model(model)
        self.sampler = Sampler()
        self.current_sequences = []
        self.graphs = {}
        self.graph_vars = None
        self.graph_pool = None
        self.graph_captured = False
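
    # The KV cache is paged: each layer owns num_blocks blocks of block_size
    # token slots, and every sequence carries a block_table mapping its logical
    # positions to physical blocks. The _prepare_* helpers below translate
    # token positions into flat slot indices (block_id * block_size + offset)
    # for the attention kernels.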
    def _prepare_prefill(self, requests: List[Seq]):
        """Flatten the uncached tokens of all sequences into varlen-prefill
        inputs, set the prefill ForwardContext, and return the flattened
        input_ids and positions."""
        input_ids = []
        positions = []
        cu_seqlens_q = [0]
        cu_seqlens_k = [0]
        max_seqlen_q = 0
        max_seqlen_k = 0
        slot_mapping = []
        for req in requests:
            seqlen = len(req)
            input_ids.extend(req[req.num_cached_tokens :])
            positions.extend(list(range(req.num_cached_tokens, seqlen)))
            seqlen_q = seqlen - req.num_cached_tokens
            seqlen_k = seqlen
            cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q)
            cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k)
            max_seqlen_q = max(seqlen_q, max_seqlen_q)
            max_seqlen_k = max(seqlen_k, max_seqlen_k)
            if req.block_table:
                num_cached = req.num_cached_tokens
                num_total = len(req)
                for token_idx in range(num_cached, num_total):
                    block_idx = token_idx // self.block_size
                    block_offset = token_idx % self.block_size
                    block_id = req.block_table[block_idx]
                    # Flat KV slot for this token: block id * block size + offset.
                    slot_idx = block_id * self.block_size + block_offset
                    slot_mapping.append(slot_idx)
        input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(
            non_blocking=True
        )
        positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(
            non_blocking=True
        )
        cu_seqlens_q = torch.tensor(
            cu_seqlens_q, dtype=torch.int32, pin_memory=True
        ).cuda(non_blocking=True)
        cu_seqlens_k = torch.tensor(
            cu_seqlens_k, dtype=torch.int32, pin_memory=True
        ).cuda(non_blocking=True)
        slot_mapping = torch.tensor(
            slot_mapping, dtype=torch.int32, pin_memory=True
        ).cuda(non_blocking=True)
        block_tables = None
        # Block tables are only needed when some KV entries are already cached,
        # i.e. there are more key tokens than query tokens.
        if cu_seqlens_k[-1] > cu_seqlens_q[-1]:
            max_len = max(len(req.block_table) for req in requests)
            block_tables_list = []
            for req in requests:
                table = req.block_table + [-1] * (max_len - len(req.block_table))
                block_tables_list.append(table)
            block_tables = torch.tensor(
                block_tables_list, dtype=torch.int32, pin_memory=True
            ).cuda(non_blocking=True)
        set_forward_context(
            True,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            slot_mapping,
            None,
            block_tables,
        )
        return input_ids, positions
    def _prepare_decode(self, requests: List[Seq]):
        """Build single-token decode inputs for each running sequence and set
        the decode ForwardContext (slot mapping, context lengths, block tables)."""
        if not requests:
            raise RuntimeError("FATAL: No requests provided to _prepare_decode!")
        input_ids = []
        positions = []
        slot_mapping = []
        context_lens = []
        for req in requests:
            input_ids.append(req.last_token)
            pos = len(req) - 1
            if hasattr(self, "_tts_mode") and self._tts_mode:
                # In TTS mode, decode positions are taken relative to the mel
                # stream rather than the full embedding prompt.
                pos = pos - (self._tts_prompt_len - 1)
            positions.append(pos)
            context_lens.append(len(req))
            # KV slot for the current last token, e.g. block_table[-1] == 7,
            # last_block_num_tokens == 10, block_size == 256
            # -> slot 7 * 256 + 10 - 1 = 1801.
            slot_mapping.append(
                req.block_table[-1] * self.block_size + req.last_block_num_tokens - 1
            )
        input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(
            non_blocking=True
        )
        positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(
            non_blocking=True
        )
        slot_mapping = torch.tensor(
            slot_mapping, dtype=torch.int32, pin_memory=True
        ).cuda(non_blocking=True)
        context_lens = torch.tensor(
            context_lens, dtype=torch.int32, pin_memory=True
        ).cuda(non_blocking=True)
        max_len = max(len(req.block_table) for req in requests)
        block_tables_list = []
        for req in requests:
            table = req.block_table + [-1] * (max_len - len(req.block_table))
            block_tables_list.append(table)
        block_tables = torch.tensor(
            block_tables_list, dtype=torch.int32, pin_memory=True
        ).cuda(non_blocking=True)
        assert block_tables.dim() == 2, (
            f"block_tables must be 2D, got shape {block_tables.shape}"
        )
        assert block_tables.size(0) == len(requests), (
            f"block_tables batch size mismatch: {block_tables.size(0)} vs {len(requests)}"
        )
        set_forward_context(
            False,
            slot_mapping=slot_mapping,
            context_lens=context_lens,
            block_tables=block_tables,
        )
        return input_ids, positions
    def _prepare_sample(self, requests: List[Seq], temperature: float):
        """Broadcast the scalar sampling temperature to a per-sequence CUDA tensor."""
        temperatures = [temperature] * len(requests)
        temperatures = torch.tensor(
            temperatures, dtype=torch.float32, pin_memory=True
        ).cuda(non_blocking=True)
        return temperatures
    def _capture_cuda_graphs(self, tts_mel_embedding=None, tts_text_pos_embedding=None):
        """Capture one CUDA graph per decode batch size in self.graph_bs.

        Each graph records a single decode step run against fixed static
        buffers (input_ids, positions, slot_mapping, context_lens,
        block_tables, outputs); later decode calls copy live data into those
        buffers and replay the matching graph instead of re-launching kernels.
        """
        print("Capturing CUDA graphs for decode optimization...")
        max_bs = 8  # Support up to batch size 8
        max_num_blocks = (2048 + self.block_size - 1) // self.block_size
        model_dtype = next(self.model.parameters()).dtype
        input_ids = torch.ones(max_bs, dtype=torch.int64, device="cuda")
        positions = torch.ones(max_bs, dtype=torch.int64, device="cuda")
        slot_mapping = torch.zeros(max_bs, dtype=torch.int32, device="cuda")
        context_lens = torch.zeros(max_bs, dtype=torch.int32, device="cuda")
        block_tables = torch.zeros(
            max_bs, max_num_blocks, dtype=torch.int32, device="cuda"
        )
        outputs = torch.zeros(
            max_bs, self.hidden_size, dtype=model_dtype, device="cuda"
        )
        inputs_embeds_buffer = torch.zeros(
            max_bs, self.hidden_size, dtype=model_dtype, device="cuda"
        )
        self.graph_bs = [1, 2, 4, 8]
        use_tts = tts_mel_embedding is not None and tts_text_pos_embedding is not None
        for bs in reversed(self.graph_bs):
            graph = torch.cuda.CUDAGraph()
            slot_mapping[:bs] = torch.arange(bs, dtype=torch.int32, device="cuda")
            context_lens[:bs] = bs + 1
            block_tables[:bs, :] = 0
            set_forward_context(
                False,
                slot_mapping=slot_mapping[:bs],
                context_lens=context_lens[:bs],
                block_tables=block_tables[:bs],
            )
            # warmup
            if use_tts:
                assert tts_mel_embedding is not None
                assert tts_text_pos_embedding is not None
                emb = tts_mel_embedding(input_ids[:bs])
                pos_clamped = torch.clamp(positions[:bs], min=0)
                pos_emb = tts_text_pos_embedding.emb(pos_clamped)
                inputs_embeds_buffer[:bs] = emb + pos_emb
                out = self.model(
                    inputs_embeds=inputs_embeds_buffer[:bs].unsqueeze(1),
                    return_dict=True,
                ).last_hidden_state
            else:
                out = self.model(
                    input_ids=input_ids[:bs].unsqueeze(1), return_dict=True
                ).last_hidden_state
            outputs[:bs] = out.squeeze(1) if out.dim() == 3 else out
            with torch.cuda.graph(graph, self.graph_pool):
                if use_tts:
                    assert tts_mel_embedding is not None
                    assert tts_text_pos_embedding is not None
                    emb = tts_mel_embedding(input_ids[:bs])
                    pos_clamped = torch.clamp(positions[:bs], min=0)
                    pos_emb = tts_text_pos_embedding.emb(pos_clamped)
                    inputs_embeds_buffer[:bs] = emb + pos_emb
                    out = self.model(
                        inputs_embeds=inputs_embeds_buffer[:bs].unsqueeze(1),
                        return_dict=True,
                    ).last_hidden_state
                else:
                    out = self.model(
                        input_ids=input_ids[:bs].unsqueeze(1), return_dict=True
                    ).last_hidden_state
                outputs[:bs] = out.squeeze(1) if out.dim() == 3 else out
            if self.graph_pool is None:
                self.graph_pool = graph.pool()
            self.graphs[bs] = graph
            torch.cuda.synchronize()
            reset_forward_context()
        self.graph_vars = {
            "input_ids": input_ids,
            "positions": positions,
            "slot_mapping": slot_mapping,
            "context_lens": context_lens,
            "block_tables": block_tables,
            "outputs": outputs,
            "inputs_embeds": inputs_embeds_buffer,
        }
        print(f"CUDA graphs captured for batch sizes: {self.graph_bs}")
    def _run_decode_with_graph(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        context: ForwardContext,
        tts_mel_embedding: Optional[torch.nn.Module] = None,
        tts_text_pos_embedding: Optional[torch.nn.Module] = None,
    ) -> torch.Tensor:
        """Run one decode step, replaying a captured CUDA graph when possible.

        Falls back to a normal eager forward when graphs are disabled, not yet
        captured, or the batch exceeds the largest captured size; otherwise the
        batch is padded up to the next captured batch size.
        """
        bs = input_ids.size(0)
        use_tts_embedding = hasattr(self, "_tts_mode") and self._tts_mode
        if not self.use_cuda_graph or not self.graphs:
            if use_tts_embedding:
                assert tts_mel_embedding is not None
                assert tts_text_pos_embedding is not None
                inputs_embeds = tts_mel_embedding(input_ids)
                pos_clamped = torch.clamp(positions, min=0)
                pos_emb = tts_text_pos_embedding.emb(pos_clamped)
                inputs_embeds = inputs_embeds + pos_emb
                out = self.model(
                    inputs_embeds=inputs_embeds.unsqueeze(1), return_dict=True
                ).last_hidden_state
            else:
                out = self.model(
                    input_ids=input_ids.unsqueeze(1), return_dict=True
                ).last_hidden_state
            return out.squeeze(1) if out.dim() == 3 else out
        graph_bs = next((x for x in self.graph_bs if x >= bs), None)
        if graph_bs is None:
            if use_tts_embedding:
                assert tts_mel_embedding is not None
                assert tts_text_pos_embedding is not None
                inputs_embeds = tts_mel_embedding(input_ids)
                pos_clamped = torch.clamp(positions, min=0)
                pos_emb = tts_text_pos_embedding.emb(pos_clamped)
                inputs_embeds = inputs_embeds + pos_emb
                out = self.model(
                    inputs_embeds=inputs_embeds.unsqueeze(1), return_dict=True
                ).last_hidden_state
            else:
                out = self.model(
                    input_ids=input_ids.unsqueeze(1), return_dict=True
                ).last_hidden_state
            return out.squeeze(1) if out.dim() == 3 else out
        graph = self.graphs[graph_bs]
        graph_vars = self.graph_vars
        if graph_vars is None:
            raise RuntimeError("Graph variables not initialized")
        # Copy live inputs into the static buffers captured by the graph, then
        # replay; only the first bs rows of the outputs buffer are valid.
        graph_vars["input_ids"][:bs] = input_ids
        graph_vars["positions"][:bs] = positions
        graph_vars["slot_mapping"].fill_(-1)
        graph_vars["slot_mapping"][:bs] = context.slot_mapping
        graph_vars["context_lens"].zero_()
        graph_vars["context_lens"][:bs] = context.context_lens
        graph_vars["block_tables"][:bs, :].fill_(-1)
        graph_vars["block_tables"][:bs, : context.block_tables.size(1)] = (
            context.block_tables
        )
        graph.replay()
        return graph_vars["outputs"][:bs]
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
        top_k: int = 50,
        top_p: float = 1.0,
        stop_tokens: Optional[List[int]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        tts_embeddings: Optional[
            torch.Tensor
        ] = None,  # TTS: [pad][cond][text] embeddings (87 tokens, NO start_mel)
        tts_mel_embedding: Optional[torch.nn.Module] = None,  # TTS: mel_embedding layer
        tts_text_pos_embedding: Optional[
            torch.nn.Module
        ] = None,  # TTS: text_pos_embedding layer
    ) -> torch.Tensor:
        """
        Generate tokens.

        Args:
            input_ids: Input token IDs [batch_size, seq_len]
            max_new_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature (0 selects greedy decoding)
            top_k: Top-k sampling (accepted but currently unused)
            top_p: Nucleus sampling threshold (accepted but currently unused)
            stop_tokens: List of token IDs that stop generation
            attention_mask: Optional padding mask for variable-length batches
            tts_embeddings: Precomputed [pad][cond][text] prompt embeddings (TTS mode)
            tts_mel_embedding: mel_embedding layer for embedding generated tokens (TTS mode)
            tts_text_pos_embedding: positional embedding layer for decode positions (TTS mode)

        Returns:
            Generated token IDs [batch_size, total_len]
        """
        batch_size = input_ids.size(0)
        device = input_ids.device
        self._tts_mode = tts_embeddings is not None
        self._tts_prompt_len = input_ids.size(1) if self._tts_mode else 0
        if self.use_cuda_graph and not self.graph_captured:
            print(
                f"[CAPTURE] use_cuda_graph={self.use_cuda_graph}, graph_captured={self.graph_captured}",
                file=sys.stderr,
                flush=True,
            )
            self._capture_cuda_graphs(
                tts_mel_embedding=tts_mel_embedding,
                tts_text_pos_embedding=tts_text_pos_embedding,
            )
            self.graph_captured = True
            print(
                f"[CAPTURE] Completed! graphs={list(self.graphs.keys())}",
                file=sys.stderr,
                flush=True,
            )
        if tts_embeddings is not None:
            actual_seq_len = tts_embeddings.size(1) + 1  # embeddings + start_mel_token
        else:
            actual_seq_len = input_ids.size(1)
        is_varlen_batch = (
            tts_embeddings is not None
            and attention_mask is not None
            and batch_size > 1
            and (attention_mask.sum(dim=1) != attention_mask.size(1)).any()
        )
        if is_varlen_batch:
            seq_lens = [attention_mask[i].sum().item() for i in range(batch_size)]
        else:
            seq_lens = [actual_seq_len] * batch_size
        sequences = []
        for i in range(batch_size):
            seq_len = seq_lens[i]
            # In TTS mode the prompt is fed as embeddings, so the Seq only needs
            # placeholder ids of the right length; the real final token (the
            # start-mel token) is kept so decoding can start from it.
            token_ids = [1] * seq_len
            if tts_embeddings is not None and seq_len > 0:
                token_ids[-1] = input_ids[i, -1].item() if input_ids.size(1) > 0 else 1
            else:
                token_ids = input_ids[i].tolist()
            req = Seq(token_ids)
            self.kv_manager.allocate(req)
            sequences.append(req)
        self.current_sequences = sequences
        prefill_ids, prefill_pos = self._prepare_prefill(sequences)
        if (
            tts_embeddings is not None
            and tts_mel_embedding is not None
            and tts_text_pos_embedding is not None
        ):
            start_token_id = input_ids[0, -1] if input_ids.size(1) > 0 else 8192
            start_emb = tts_mel_embedding(
                torch.tensor([[start_token_id]], device="cuda")
            )  # [1, 1, hidden_dim]
            start_pos = torch.tensor(
                [[tts_embeddings.size(1)]], device="cuda", dtype=torch.long
            )
            pos_emb = tts_text_pos_embedding.emb(start_pos)
            start_emb = start_emb + pos_emb
            start_emb = start_emb.repeat(batch_size, 1, 1)
            if is_varlen_batch:
                valid_embeddings = []
                for i in range(batch_size):
                    emb_len = seq_lens[i] - 1
                    padding_len = tts_embeddings.size(1) - emb_len
                    valid_emb = tts_embeddings[i, padding_len:].unsqueeze(
                        0
                    )  # [1, emb_len, hidden_dim]
                    valid_embeddings.append(
                        torch.cat([valid_emb, start_emb[i : i + 1]], dim=1)
                    )
                full_embeddings = torch.cat(
                    valid_embeddings, dim=1
                )  # [1, total_tokens, hidden_dim]
            else:
                full_embeddings = torch.cat(
                    [tts_embeddings, start_emb], dim=1
                )  # [batch_size, seq_len, hidden_dim]
            model_dtype = next(self.model.parameters()).dtype
            if full_embeddings.dtype != model_dtype:
                full_embeddings = full_embeddings.to(model_dtype)
            hidden_states = self.model(
                inputs_embeds=full_embeddings, return_dict=True
            ).last_hidden_state
        else:
            hidden_states = self.model(
                input_ids=input_ids, attention_mask=attention_mask, return_dict=True
            ).last_hidden_state
        if is_varlen_batch:
            context = get_forward_context()
            cu_seqlens = context.cu_seqlens_q.cpu().tolist()
            last_hidden = torch.stack(
                [hidden_states[0, cu_seqlens[i + 1] - 1] for i in range(batch_size)]
            )
        else:
            last_hidden = hidden_states[:, -1, :]  # [batch_size, hidden_size]
        reset_forward_context()
        if self.lm_head is not None:
            if last_hidden.dtype != next(self.lm_head.parameters()).dtype:
                last_hidden = last_hidden.to(next(self.lm_head.parameters()).dtype)
            logits = self.lm_head(last_hidden)  # [batch_size, vocab_size]
        else:
            logits = self.model.compute_logits(last_hidden)  # [batch_size, vocab_size]
        temperatures = self._prepare_sample(sequences, temperature)
        if temperature > 0:
            first_token = self.sampler(logits, temperatures)
        else:
            first_token = torch.argmax(logits, dim=-1)
        first_token_list = first_token.tolist()
        generated_tokens = [[] for _ in range(batch_size)]
        is_finished = [False] * batch_size
        for i, token_id in enumerate(first_token_list):
            if stop_tokens and token_id in stop_tokens:
                is_finished[i] = True
            else:
                generated_tokens[i].append(token_id)
                sequences[i].append_token(token_id)
                self.kv_manager.append_to_seq(sequences[i])
        if all(is_finished):
            for req in sequences:
                self.kv_manager.remove_seq(req)
            self.current_sequences = []
            output_ids = []
            for i in range(batch_size):
                full_sequence = input_ids[i].tolist() + generated_tokens[i]
                output_ids.append(full_sequence)
            output = torch.tensor(output_ids, dtype=torch.long, device=device)
            return output
        remaining_tokens = max_new_tokens - 1
        for step in range(remaining_tokens):
            decode_ids, decode_pos = self._prepare_decode(sequences)
            context = get_forward_context()
            hidden_states = self._run_decode_with_graph(
                decode_ids,
                decode_pos,
                context,
                tts_mel_embedding=tts_mel_embedding,
                tts_text_pos_embedding=tts_text_pos_embedding,
            )
            # Get logits
            if self.lm_head is not None:
                logits = self.lm_head(hidden_states)  # [batch_size, vocab_size]
            else:
                logits = self.model.compute_logits(hidden_states)  # [batch_size, vocab_size]
            reset_forward_context()
            temperatures = self._prepare_sample(sequences, temperature)
            if temperature > 0:
                next_token = self.sampler(logits, temperatures)
            else:
                next_token = torch.argmax(logits, dim=-1)
            next_token_list = next_token.tolist()
            for i, token_id in enumerate(next_token_list):
                if is_finished[i]:
                    continue
                elif stop_tokens and token_id in stop_tokens:
                    is_finished[i] = True
                else:
                    sequences[i].append_token(token_id)
                    self.kv_manager.append_to_seq(sequences[i])
                    generated_tokens[i].append(token_id)
            if all(is_finished):
                break
        for req in sequences:
            self.kv_manager.remove_seq(req)
        self.current_sequences = []
        pad_token = stop_tokens[0] if stop_tokens else 0
        if is_varlen_batch:
            max_prompt_len = attention_mask.size(1)
            output_ids = []
            for i in range(batch_size):
                padding_len = max_prompt_len - seq_lens[i]
                initial_tokens = sequences[i].token_ids[: sequences[i].num_prompt_tokens]
                padded_prompt = [pad_token] * padding_len + initial_tokens
                full_sequence = padded_prompt + generated_tokens[i]
                output_ids.append(full_sequence)
        else:
            output_ids = [
                sequences[i].token_ids[: sequences[i].num_prompt_tokens]
                + generated_tokens[i]
                for i in range(batch_size)
            ]
        max_length = max(len(seq) for seq in output_ids)
        padded_output_ids = [
            seq + [pad_token] * (max_length - len(seq)) for seq in output_ids
        ]
        output = torch.tensor(padded_output_ids, dtype=torch.long, device=device)
        assert output.size(0) == batch_size, (
            f"Output batch size mismatch: {output.size(0)} != {batch_size}"
        )
        return output
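

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the engine). The model,
# lm_head, dimensions, and token values below are placeholders, not values
# taken from the index-tts configs.
#
#   engine = AccelInferenceEngine(
#       model=gpt,                        # accel-attention transformer
#       lm_head=lm_head,                  # projection to the mel-token vocab
#       num_layers=num_layers,
#       num_heads=num_heads,
#       head_dim=hidden_size // num_heads,
#       block_size=256,
#       num_blocks=128,
#       use_cuda_graph=True,
#   )
#   codes = engine.generate(
#       input_ids=prompt_ids.cuda(),      # [batch, seq_len]
#       max_new_tokens=600,
#       temperature=0.8,
#       stop_tokens=[stop_mel_token],     # hypothetical stop token id
#   )
# ---------------------------------------------------------------------------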