# mirror of https://github.com/index-tts/index-tts.git (synced 2025-11-26 03:44:54 +08:00)

import sys
from typing import List, Optional

import torch
from torch import nn

from .attention import (
    ForwardContext,
    get_forward_context,
    reset_forward_context,
    set_forward_context,
)
from .kv_manager import KVCacheManager, Seq


class Sampler(nn.Module):
    """Temperature-based token sampler with a greedy fallback for near-zero temperatures."""

    def __init__(self):
        super().__init__()

    @torch.compile
    def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
        temperatures = temperatures.to(logits.device).clamp(min=1e-8)
        greedy_mask = temperatures < 1e-5
        temp_for_scaling = torch.where(greedy_mask, 1.0, temperatures)
        scaled_logits = logits / temp_for_scaling.unsqueeze(-1)
        probs = torch.softmax(scaled_logits, dim=-1, dtype=torch.float32)
        # Exponential-race (Gumbel-max style) trick: dividing the probabilities by
        # i.i.d. Exponential(1) noise and taking the argmax draws a sample from the
        # categorical distribution without an explicit multinomial call.
        q = torch.empty_like(probs)
        q.exponential_()
        sampled_tokens = probs.div_(q).argmax(dim=-1)
        greedy_tokens = logits.argmax(dim=-1)
        return torch.where(greedy_mask, greedy_tokens, sampled_tokens)


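# The sampling trick above can be sanity-checked against torch.multinomial. The
# helper below is an illustrative sketch added for exposition only: it is not
# part of the original module and nothing in this file calls it.
def _demo_exponential_race_sampling(num_draws: int = 20000) -> None:
    """Empirically compare the exponential-race trick with torch.multinomial."""
    torch.manual_seed(0)
    probs = torch.tensor([0.1, 0.2, 0.3, 0.4])
    # Exponential-race trick: argmax of probs / Exp(1) noise.
    noise = torch.empty(num_draws, probs.numel()).exponential_()
    trick_counts = torch.bincount(
        (probs / noise).argmax(dim=-1), minlength=probs.numel()
    )
    # Reference: direct categorical sampling.
    ref_counts = torch.bincount(
        torch.multinomial(probs, num_draws, replacement=True),
        minlength=probs.numel(),
    )
    # Both frequency estimates should be close to [0.1, 0.2, 0.3, 0.4].
    print("trick    :", (trick_counts.float() / num_draws).tolist())
    print("reference:", (ref_counts.float() / num_draws).tolist())

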
class AccelInferenceEngine:
    def __init__(
        self,
        model,
        lm_head,
        num_layers: int,
        num_heads: int,
        head_dim: int,
        block_size: int = 256,
        num_blocks: int = 128,
        use_cuda_graph: bool = True,
    ):
        """
        Args:
            model: The GPT transformer model (should have accel attention)
            lm_head: Language model head for generating logits
            num_layers: Number of transformer layers
            num_heads: Number of attention heads
            head_dim: Dimension per head
            block_size: KV cache block size
            num_blocks: Total number of KV cache blocks
            use_cuda_graph: Whether to use CUDA Graph for decode optimization
        """
        self.model = model
        self.lm_head = lm_head
        self.block_size = block_size
        self.num_blocks = num_blocks
        self.use_cuda_graph = use_cuda_graph and torch.cuda.is_available()
        self.hidden_size = (
            model.config.hidden_size
            if hasattr(model, "config")
            else head_dim * num_heads
        )
        self.kv_manager = KVCacheManager(
            num_layers=num_layers,
            num_heads=num_heads,
            head_dim=head_dim,
            block_size=block_size,
            num_blocks=num_blocks,
            dtype=torch.float16,  # Force fp16 for FlashAttention
        )
        self.kv_manager.wire_kv_cache_to_model(model)
        self.sampler = Sampler()
        self.current_sequences = []
        # CUDA graph state: one captured graph per supported decode batch size.
        self.graphs = {}
        self.graph_vars = None
        self.graph_pool = None
        self.graph_captured = False

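    # Rough capacity arithmetic for the defaults above (illustrative numbers, not
    # from the original source, assuming the manager allocates one K and one V
    # tensor per layer): block_size=256 and num_blocks=128 give
    # 256 * 128 = 32,768 cacheable token slots shared across sequences. Per layer,
    # K and V each take num_blocks * block_size * num_heads * head_dim elements, so
    # total cache memory is roughly
    #   2 * num_layers * 32,768 * num_heads * head_dim * 2 bytes (fp16).
    # For example, num_layers=30, num_heads=16, head_dim=64 would come to about
    # 2 * 30 * 32768 * 16 * 64 * 2 bytes ≈ 3.75 GiB.
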
    def _prepare_prefill(self, requests: List[Seq]):
        input_ids = []
        positions = []
        cu_seqlens_q = [0]
        cu_seqlens_k = [0]
        max_seqlen_q = 0
        max_seqlen_k = 0
        slot_mapping = []

        for req in requests:
            seqlen = len(req)
            # Only the non-cached suffix of each sequence is fed through prefill.
            input_ids.extend(req[req.num_cached_tokens :])
            positions.extend(list(range(req.num_cached_tokens, seqlen)))
            seqlen_q = seqlen - req.num_cached_tokens
            seqlen_k = seqlen
            cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q)
            cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k)
            max_seqlen_q = max(seqlen_q, max_seqlen_q)
            max_seqlen_k = max(seqlen_k, max_seqlen_k)

            if req.block_table:
                num_cached = req.num_cached_tokens
                num_total = len(req)

                # Map each new token to its physical KV-cache slot:
                # slot = block_id * block_size + offset within the block.
                for token_idx in range(num_cached, num_total):
                    block_idx = token_idx // self.block_size
                    block_offset = token_idx % self.block_size
                    block_id = req.block_table[block_idx]
                    slot_idx = block_id * self.block_size + block_offset
                    slot_mapping.append(slot_idx)

        input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(
            non_blocking=True
        )
        positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(
            non_blocking=True
        )
        cu_seqlens_q = torch.tensor(
            cu_seqlens_q, dtype=torch.int32, pin_memory=True
        ).cuda(non_blocking=True)
        cu_seqlens_k = torch.tensor(
            cu_seqlens_k, dtype=torch.int32, pin_memory=True
        ).cuda(non_blocking=True)
        slot_mapping = torch.tensor(
            slot_mapping, dtype=torch.int32, pin_memory=True
        ).cuda(non_blocking=True)

        block_tables = None
        if cu_seqlens_k[-1] > cu_seqlens_q[-1]:
            # Some tokens are already cached, so the kernel needs block tables,
            # padded to the longest table in the batch with -1.
            max_len = max(len(req.block_table) for req in requests)
            block_tables_list = []
            for req in requests:
                table = req.block_table + [-1] * (max_len - len(req.block_table))
                block_tables_list.append(table)
            block_tables = torch.tensor(
                block_tables_list, dtype=torch.int32, pin_memory=True
            ).cuda(non_blocking=True)

        set_forward_context(
            True,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            slot_mapping,
            None,
            block_tables,
        )

        return input_ids, positions

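    # Worked example of the slot mapping above (illustrative, not from the
    # original source): with block_size=256, a sequence whose block_table is
    # [3, 7] stores token index 300 in block_idx = 300 // 256 = 1, i.e. physical
    # block 7, at offset 300 % 256 = 44, so its cache slot is 7 * 256 + 44 = 1836.
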
    def _prepare_decode(self, requests: List[Seq]):
        if not requests:
            raise RuntimeError("FATAL: No requests provided to _prepare_decode!")

        input_ids = []
        positions = []
        slot_mapping = []
        context_lens = []

        for req in requests:
            input_ids.append(req.last_token)

            pos = len(req) - 1
            if hasattr(self, "_tts_mode") and self._tts_mode:
                # In TTS mode, positions are relative to the end of the prompt
                # embeddings rather than to the start of the raw token sequence.
                pos = pos - (self._tts_prompt_len - 1)
            positions.append(pos)

            context_lens.append(len(req))
            slot_mapping.append(
                req.block_table[-1] * self.block_size + req.last_block_num_tokens - 1
            )

        input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(
            non_blocking=True
        )
        positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(
            non_blocking=True
        )
        slot_mapping = torch.tensor(
            slot_mapping, dtype=torch.int32, pin_memory=True
        ).cuda(non_blocking=True)
        context_lens = torch.tensor(
            context_lens, dtype=torch.int32, pin_memory=True
        ).cuda(non_blocking=True)

        max_len = max(len(req.block_table) for req in requests)
        block_tables_list = []
        for req in requests:
            table = req.block_table + [-1] * (max_len - len(req.block_table))
            block_tables_list.append(table)
        block_tables = torch.tensor(
            block_tables_list, dtype=torch.int32, pin_memory=True
        ).cuda(non_blocking=True)

        assert block_tables.dim() == 2, (
            f"block_tables must be 2D, got shape {block_tables.shape}"
        )
        assert block_tables.size(0) == len(requests), (
            f"block_tables batch size mismatch: {block_tables.size(0)} vs {len(requests)}"
        )

        set_forward_context(
            False,
            slot_mapping=slot_mapping,
            context_lens=context_lens,
            block_tables=block_tables,
        )

        return input_ids, positions

    def _prepare_sample(self, requests: List[Seq], temperature: float):
        temperatures = [temperature] * len(requests)
        temperatures = torch.tensor(
            temperatures, dtype=torch.float32, pin_memory=True
        ).cuda(non_blocking=True)
        return temperatures

    def _capture_cuda_graphs(self, tts_mel_embedding=None, tts_text_pos_embedding=None):
        print("Capturing CUDA graphs for decode optimization...")
        max_bs = 8  # Support up to batch size 8
        max_num_blocks = (2048 + self.block_size - 1) // self.block_size
        model_dtype = next(self.model.parameters()).dtype
        # Persistent buffers shared by all captured graphs; each decode step
        # copies its inputs into these tensors before replaying a graph.
        input_ids = torch.ones(max_bs, dtype=torch.int64, device="cuda")
        positions = torch.ones(max_bs, dtype=torch.int64, device="cuda")
        slot_mapping = torch.zeros(max_bs, dtype=torch.int32, device="cuda")
        context_lens = torch.zeros(max_bs, dtype=torch.int32, device="cuda")
        block_tables = torch.zeros(
            max_bs, max_num_blocks, dtype=torch.int32, device="cuda"
        )
        outputs = torch.zeros(
            max_bs, self.hidden_size, dtype=model_dtype, device="cuda"
        )
        inputs_embeds_buffer = torch.zeros(
            max_bs, self.hidden_size, dtype=model_dtype, device="cuda"
        )

        self.graph_bs = [1, 2, 4, 8]

        use_tts = tts_mel_embedding is not None and tts_text_pos_embedding is not None

        for bs in reversed(self.graph_bs):
            graph = torch.cuda.CUDAGraph()

            slot_mapping[:bs] = torch.arange(bs, dtype=torch.int32, device="cuda")
            context_lens[:bs] = bs + 1
            block_tables[:bs, :] = 0

            set_forward_context(
                False,
                slot_mapping=slot_mapping[:bs],
                context_lens=context_lens[:bs],
                block_tables=block_tables[:bs],
            )

            # Warmup pass outside the graph so lazy initialization does not
            # happen during capture.
            if use_tts:
                assert tts_mel_embedding is not None
                assert tts_text_pos_embedding is not None
                emb = tts_mel_embedding(input_ids[:bs])
                pos_clamped = torch.clamp(positions[:bs], min=0)
                pos_emb = tts_text_pos_embedding.emb(pos_clamped)
                inputs_embeds_buffer[:bs] = emb + pos_emb
                out = self.model(
                    inputs_embeds=inputs_embeds_buffer[:bs].unsqueeze(1),
                    return_dict=True,
                ).last_hidden_state
            else:
                out = self.model(
                    input_ids=input_ids[:bs].unsqueeze(1), return_dict=True
                ).last_hidden_state
            outputs[:bs] = out.squeeze(1) if out.dim() == 3 else out

            with torch.cuda.graph(graph, self.graph_pool):
                if use_tts:
                    assert tts_mel_embedding is not None
                    assert tts_text_pos_embedding is not None
                    emb = tts_mel_embedding(input_ids[:bs])
                    pos_clamped = torch.clamp(positions[:bs], min=0)
                    pos_emb = tts_text_pos_embedding.emb(pos_clamped)
                    inputs_embeds_buffer[:bs] = emb + pos_emb
                    out = self.model(
                        inputs_embeds=inputs_embeds_buffer[:bs].unsqueeze(1),
                        return_dict=True,
                    ).last_hidden_state
                else:
                    out = self.model(
                        input_ids=input_ids[:bs].unsqueeze(1), return_dict=True
                    ).last_hidden_state
                outputs[:bs] = out.squeeze(1) if out.dim() == 3 else out

            if self.graph_pool is None:
                self.graph_pool = graph.pool()

            self.graphs[bs] = graph
            torch.cuda.synchronize()
            reset_forward_context()

        self.graph_vars = {
            "input_ids": input_ids,
            "positions": positions,
            "slot_mapping": slot_mapping,
            "context_lens": context_lens,
            "block_tables": block_tables,
            "outputs": outputs,
            "inputs_embeds": inputs_embeds_buffer,
        }
        print(f"CUDA graphs captured for batch sizes: {self.graph_bs}")

    def _run_decode_with_graph(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        context: ForwardContext,
        tts_mel_embedding: Optional[torch.nn.Module] = None,
        tts_text_pos_embedding: Optional[torch.nn.Module] = None,
    ) -> torch.Tensor:
        bs = input_ids.size(0)
        use_tts_embedding = hasattr(self, "_tts_mode") and self._tts_mode

        # Eager fallback when CUDA graphs are disabled or not captured yet.
        if not self.use_cuda_graph or not self.graphs:
            if use_tts_embedding:
                assert tts_mel_embedding is not None
                assert tts_text_pos_embedding is not None
                inputs_embeds = tts_mel_embedding(input_ids)
                pos_clamped = torch.clamp(positions, min=0)
                pos_emb = tts_text_pos_embedding.emb(pos_clamped)
                inputs_embeds = inputs_embeds + pos_emb
                out = self.model(
                    inputs_embeds=inputs_embeds.unsqueeze(1), return_dict=True
                ).last_hidden_state
            else:
                out = self.model(
                    input_ids=input_ids.unsqueeze(1), return_dict=True
                ).last_hidden_state
            return out.squeeze(1) if out.dim() == 3 else out

        # Eager fallback when the batch is larger than any captured graph.
        graph_bs = next((x for x in self.graph_bs if x >= bs), None)
        if graph_bs is None:
            if use_tts_embedding:
                assert tts_mel_embedding is not None
                assert tts_text_pos_embedding is not None
                inputs_embeds = tts_mel_embedding(input_ids)
                pos_clamped = torch.clamp(positions, min=0)
                pos_emb = tts_text_pos_embedding.emb(pos_clamped)
                inputs_embeds = inputs_embeds + pos_emb
                out = self.model(
                    inputs_embeds=inputs_embeds.unsqueeze(1), return_dict=True
                ).last_hidden_state
            else:
                out = self.model(
                    input_ids=input_ids.unsqueeze(1), return_dict=True
                ).last_hidden_state
            return out.squeeze(1) if out.dim() == 3 else out

        graph = self.graphs[graph_bs]
        graph_vars = self.graph_vars

        if graph_vars is None:
            raise RuntimeError("Graph variables not initialized")

        # Refresh the persistent buffers in place, then replay the captured graph.
        graph_vars["input_ids"][:bs] = input_ids
        graph_vars["positions"][:bs] = positions
        graph_vars["slot_mapping"].fill_(-1)
        graph_vars["slot_mapping"][:bs] = context.slot_mapping
        graph_vars["context_lens"].zero_()
        graph_vars["context_lens"][:bs] = context.context_lens
        graph_vars["block_tables"][:bs, :].fill_(-1)
        graph_vars["block_tables"][:bs, : context.block_tables.size(1)] = (
            context.block_tables
        )
        graph.replay()

        return graph_vars["outputs"][:bs]

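    # Generic shape of the CUDA graph pattern used above (a minimal sketch for
    # reference, not code from the original source):
    #
    #     static_in = torch.zeros(8, device="cuda")
    #     static_out = torch.zeros(8, device="cuda")
    #     g = torch.cuda.CUDAGraph()
    #     with torch.cuda.graph(g):
    #         static_out.copy_(static_in * 2)              # capture ops on static buffers
    #     static_in.copy_(torch.arange(8.0, device="cuda"))  # later: refresh inputs in place
    #     g.replay()                                        # re-run the captured kernels
    #     result = static_out.clone()                       # read outputs out of the buffer
    #
    # _run_decode_with_graph follows the same contract: it copies the current
    # batch into the persistent graph_vars buffers, replays the graph captured
    # for the smallest batch size >= bs, and slices the output buffer.
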
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
        top_k: int = 50,
        top_p: float = 1.0,
        stop_tokens: Optional[List[int]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        tts_embeddings: Optional[
            torch.Tensor
        ] = None,  # TTS: [pad][cond][text] embeddings (87 tokens, NO start_mel)
        tts_mel_embedding: Optional[torch.nn.Module] = None,  # TTS: mel_embedding layer
        tts_text_pos_embedding: Optional[
            torch.nn.Module
        ] = None,  # TTS: text_pos_embedding layer
    ) -> torch.Tensor:
        """
        Generate tokens.

        Args:
            input_ids: Input token IDs [batch_size, seq_len]
            max_new_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature
            top_k: Top-k sampling (currently not applied by the sampler)
            top_p: Nucleus sampling threshold (currently not applied by the sampler)
            stop_tokens: List of token IDs that stop generation
            attention_mask: Optional mask for left-padded batches [batch_size, seq_len]
            tts_embeddings: Optional TTS prompt embeddings ([pad][cond][text], without start_mel)
            tts_mel_embedding: Optional TTS mel_embedding layer
            tts_text_pos_embedding: Optional TTS text_pos_embedding layer

        Returns:
            Generated token IDs [batch_size, total_len]
        """
        batch_size = input_ids.size(0)
        device = input_ids.device

        self._tts_mode = tts_embeddings is not None
        self._tts_prompt_len = input_ids.size(1) if self._tts_mode else 0

        if self.use_cuda_graph and not self.graph_captured:
            print(
                f"[CAPTURE] use_cuda_graph={self.use_cuda_graph}, graph_captured={self.graph_captured}",
                file=sys.stderr,
                flush=True,
            )
            self._capture_cuda_graphs(
                tts_mel_embedding=tts_mel_embedding,
                tts_text_pos_embedding=tts_text_pos_embedding,
            )
            self.graph_captured = True
            print(
                f"[CAPTURE] Completed! graphs={list(self.graphs.keys())}",
                file=sys.stderr,
                flush=True,
            )

        if tts_embeddings is not None:
            actual_seq_len = tts_embeddings.size(1) + 1  # embeddings + start_mel_token
        else:
            actual_seq_len = input_ids.size(1)

        # A variable-length batch: TTS embeddings with at least one padded row.
        is_varlen_batch = (
            tts_embeddings is not None
            and attention_mask is not None
            and batch_size > 1
            and (attention_mask.sum(dim=1) != attention_mask.size(1)).any()
        )

        if is_varlen_batch:
            seq_lens = [attention_mask[i].sum().item() for i in range(batch_size)]
        else:
            seq_lens = [actual_seq_len] * batch_size

        sequences = []
        for i in range(batch_size):
            seq_len = seq_lens[i]
            token_ids = [1] * seq_len
            if tts_embeddings is not None and seq_len > 0:
                # Placeholder token ids; only the final (start_mel) token matters.
                token_ids[-1] = input_ids[i, -1].item() if input_ids.size(1) > 0 else 1
            else:
                token_ids = input_ids[i].tolist()
            req = Seq(token_ids)
            self.kv_manager.allocate(req)
            sequences.append(req)

        self.current_sequences = sequences

        # Sets the prefill forward context; the returned tensors are not used
        # directly below because the TTS path feeds embeddings instead.
        prefill_ids, prefill_pos = self._prepare_prefill(sequences)

        if (
            tts_embeddings is not None
            and tts_mel_embedding is not None
            and tts_text_pos_embedding is not None
        ):
            start_token_id = input_ids[0, -1] if input_ids.size(1) > 0 else 8192

            start_emb = tts_mel_embedding(
                torch.tensor([[start_token_id]], device="cuda")
            )  # [1, 1, hidden_dim]

            start_pos = torch.tensor(
                [[tts_embeddings.size(1)]], device="cuda", dtype=torch.long
            )
            pos_emb = tts_text_pos_embedding.emb(start_pos)
            start_emb = start_emb + pos_emb
            start_emb = start_emb.repeat(batch_size, 1, 1)

            if is_varlen_batch:
                # Strip left-padding per sequence and pack everything into one
                # flattened varlen batch for the prefill kernel.
                valid_embeddings = []
                for i in range(batch_size):
                    emb_len = seq_lens[i] - 1
                    padding_len = tts_embeddings.size(1) - emb_len
                    valid_emb = tts_embeddings[i, padding_len:].unsqueeze(
                        0
                    )  # [1, emb_len, hidden_dim]
                    valid_embeddings.append(
                        torch.cat([valid_emb, start_emb[i : i + 1]], dim=1)
                    )
                full_embeddings = torch.cat(
                    valid_embeddings, dim=1
                )  # [1, total_tokens, hidden_dim]
            else:
                full_embeddings = torch.cat(
                    [tts_embeddings, start_emb], dim=1
                )  # [batch_size, seq_len, hidden_dim]

            model_dtype = next(self.model.parameters()).dtype
            if full_embeddings.dtype != model_dtype:
                full_embeddings = full_embeddings.to(model_dtype)

            hidden_states = self.model(
                inputs_embeds=full_embeddings, return_dict=True
            ).last_hidden_state

        else:
            hidden_states = self.model(
                input_ids=input_ids, attention_mask=attention_mask, return_dict=True
            ).last_hidden_state

        if is_varlen_batch:
            # Pick the last hidden state of each packed sequence.
            context = get_forward_context()
            cu_seqlens = context.cu_seqlens_q.cpu().tolist()
            last_hidden = torch.stack(
                [hidden_states[0, cu_seqlens[i + 1] - 1] for i in range(batch_size)]
            )
        else:
            last_hidden = hidden_states[:, -1, :]  # [batch_size, hidden_size]

        reset_forward_context()

        if self.lm_head is not None:
            if last_hidden.dtype != next(self.lm_head.parameters()).dtype:
                last_hidden = last_hidden.to(next(self.lm_head.parameters()).dtype)
            logits = self.lm_head(last_hidden)  # [batch_size, vocab_size]
        else:
            logits = self.model.compute_logits(last_hidden)  # [batch_size, vocab_size]

        temperatures = self._prepare_sample(sequences, temperature)
        if temperature > 0:
            first_token = self.sampler(logits, temperatures)
        else:
            first_token = torch.argmax(logits, dim=-1)

        first_token_list = first_token.tolist()

        generated_tokens = [[] for _ in range(batch_size)]
        is_finished = [False] * batch_size

        for i, token_id in enumerate(first_token_list):
            if stop_tokens and token_id in stop_tokens:
                is_finished[i] = True
            else:
                generated_tokens[i].append(token_id)
                sequences[i].append_token(token_id)
                self.kv_manager.append_to_seq(sequences[i])

        if all(is_finished):
            # Every sequence hit a stop token on the first step: free the KV
            # cache and return early.
            for req in sequences:
                self.kv_manager.remove_seq(req)
            self.current_sequences = []

            output_ids = []
            for i in range(batch_size):
                full_sequence = input_ids[i].tolist() + generated_tokens[i]
                output_ids.append(full_sequence)

            output = torch.tensor(output_ids, dtype=torch.long, device=device)
            return output

        remaining_tokens = max_new_tokens - 1

        for step in range(remaining_tokens):
            decode_ids, decode_pos = self._prepare_decode(sequences)

            context = get_forward_context()
            hidden_states = self._run_decode_with_graph(
                decode_ids,
                decode_pos,
                context,
                tts_mel_embedding=tts_mel_embedding,
                tts_text_pos_embedding=tts_text_pos_embedding,
            )

            # Get logits
            if self.lm_head is not None:
                logits = self.lm_head(hidden_states)  # [batch_size, vocab_size]
            else:
                logits = self.model.compute_logits(
                    hidden_states
                )  # [batch_size, vocab_size]

            reset_forward_context()

            temperatures = self._prepare_sample(sequences, temperature)
            if temperature > 0:
                next_token = self.sampler(logits, temperatures)
            else:
                next_token = torch.argmax(logits, dim=-1)
            next_token_list = next_token.tolist()

            for i, token_id in enumerate(next_token_list):
                if is_finished[i]:
                    continue
                elif stop_tokens and token_id in stop_tokens:
                    is_finished[i] = True
                else:
                    sequences[i].append_token(token_id)
                    self.kv_manager.append_to_seq(sequences[i])
                    generated_tokens[i].append(token_id)

            if all(is_finished):
                break

        for req in sequences:
            self.kv_manager.remove_seq(req)
        self.current_sequences = []

        pad_token = stop_tokens[0] if stop_tokens else 0

        if is_varlen_batch:
            # Re-apply the original left padding so outputs align with input_ids.
            max_prompt_len = attention_mask.size(1)
            output_ids = []

            for i in range(batch_size):
                padding_len = max_prompt_len - seq_lens[i]
                initial_tokens = sequences[i].token_ids[
                    : sequences[i].num_prompt_tokens
                ]
                padded_prompt = [pad_token] * padding_len + initial_tokens
                full_sequence = padded_prompt + generated_tokens[i]
                output_ids.append(full_sequence)
        else:
            output_ids = [
                sequences[i].token_ids[: sequences[i].num_prompt_tokens]
                + generated_tokens[i]
                for i in range(batch_size)
            ]

        max_length = max(len(seq) for seq in output_ids)
        padded_output_ids = [
            seq + [pad_token] * (max_length - len(seq)) for seq in output_ids
        ]

        output = torch.tensor(padded_output_ids, dtype=torch.long, device=device)

        assert output.size(0) == batch_size, (
            f"Output batch size mismatch: {output.size(0)} != {batch_size}"
        )

        return output


class Sampler(nn.Module):
    # NOTE: this second definition shadows the Sampler class defined earlier in
    # this module, so AccelInferenceEngine actually instantiates this simpler
    # variant. It has no explicit greedy branch for near-zero temperatures;
    # generate() only calls the sampler when temperature > 0.
    def __init__(self):
        super().__init__()

    @torch.compile
    def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
        logits = logits.float().div_(temperatures.unsqueeze(dim=1))
        probs = torch.softmax(logits, dim=-1)
        # Same exponential-race sampling trick as above, with the noise clamped
        # away from zero for numerical safety.
        sample_tokens = probs.div_(
            torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)
        ).argmax(dim=-1)
        return sample_tokens
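

# Illustrative usage sketch (added for exposition; not part of the original
# module and never called here). The model, lm_head, dimensions, and token ids
# below are placeholders -- the real values come from the GPT configuration that
# index-tts wires this engine into.
def _example_usage(model, lm_head):
    engine = AccelInferenceEngine(
        model=model,          # transformer with accel attention layers
        lm_head=lm_head,      # projection from hidden states to vocab logits
        num_layers=30,        # hypothetical layer count
        num_heads=16,         # hypothetical head count
        head_dim=64,          # hypothetical per-head dimension
        block_size=256,
        num_blocks=128,
        use_cuda_graph=True,
    )
    # Hypothetical prompt token ids.
    prompt = torch.randint(0, 8000, (1, 32), device="cuda")
    return engine.generate(
        prompt,
        max_new_tokens=64,
        temperature=0.8,
        stop_tokens=[8193],   # hypothetical stop token id
    )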