index-tts/indextts/accel/gpt2_accel.py

import torch
import torch.nn as nn
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from transformers.models.gpt2.modeling_gpt2 import Conv1D, GPT2Block, GPT2Model

from .attention import Attention
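
# NOTE (assumption): `Attention` is taken to be this package's fused attention
# kernel wrapper; everything below adapts GPT2's tensors to its layout.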


class GPT2AccelAttention(nn.Module):
    def __init__(self, config, layer_idx=None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        max_positions = config.max_position_embeddings
        self.register_buffer(
            "bias",
            torch.tril(
                torch.ones((max_positions, max_positions), dtype=torch.bool)
            ).view(1, 1, max_positions, max_positions),
            persistent=False,
        )
        self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)

        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        self.split_size = self.embed_dim
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale_attn_weights = config.scale_attn_weights

        self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
        self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)

        scale = (self.head_dim**-0.5) if self.scale_attn_weights else 1.0
        self.accel_attn = Attention(
            self.num_heads, self.head_dim, scale, self.num_heads
        )
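        # NOTE (assumption): the fourth `Attention` argument repeats
        # `num_heads`, presumably the kernel's KV-head count, i.e. plain
        # multi-head attention with no grouped-query sharing. `attn_dropout`
        # above is kept for parity with GPT2Attention but is not applied on
        # the fused path.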

    def forward(
        self,
        hidden_states: torch.Tensor,
        layer_past=None,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        use_cache=False,
        output_attentions=False,
        past_key_value=None,
        **kwargs,
    ):
        if encoder_hidden_states is not None:
            raise NotImplementedError("Cross attention not supported in accel mode")
        # Fused QKV projection, then split into query/key/value.
        qkv = self.c_attn(hidden_states)
        query, key, value = qkv.split(self.split_size, dim=2)

        # [B, T, H*D] -> [B, H, T, D]
        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        # Flatten to [B*T, H, D], the layout the fused kernel consumes.
        bsz, num_heads, seq_len, head_dim = query.shape
        q_flat = query.transpose(1, 2).contiguous().view(-1, num_heads, head_dim)
        k_flat = key.transpose(1, 2).contiguous().view(-1, num_heads, head_dim)
        v_flat = value.transpose(1, 2).contiguous().view(-1, num_heads, head_dim)

        # The fused kernel expects fp16 on CUDA; remember the original dtype
        # so the output can be cast back afterwards.
        if q_flat.device.type == "cuda" and q_flat.dtype != torch.float16:
            orig_dtype = q_flat.dtype
            q_flat = q_flat.to(torch.float16)
            k_flat = k_flat.to(torch.float16)
            v_flat = v_flat.to(torch.float16)
        else:
            orig_dtype = q_flat.dtype

        o_flat = self.accel_attn(q_flat, k_flat, v_flat)  # [B*T, H, D]
        if o_flat.dtype != orig_dtype:
            o_flat = o_flat.to(orig_dtype)

        # Reshape back: [B*T, H, D] -> [B, H, T, D] -> [B, T, H*D]
        attn_output = o_flat.view(bsz, seq_len, num_heads, head_dim).transpose(1, 2)
        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)
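        # Mirror GPT2Attention's return contract: (attn_output, present) plus
        # attention weights when requested. The fused path keeps no KV cache
        # and never materializes the weights, so both slots are None.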
        outputs = (attn_output, None)
        if output_attentions:
            outputs += (None,)
        return outputs

    def _split_heads(self, tensor, num_heads, head_dim):
        new_shape = tensor.size()[:-1] + (num_heads, head_dim)
        tensor = tensor.view(new_shape)
        return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

    def _merge_heads(self, tensor, num_heads, head_dim):
        tensor = tensor.permute(0, 2, 1, 3).contiguous()
        new_shape = tensor.size()[:-2] + (num_heads * head_dim,)
        return tensor.view(new_shape)
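

# GPT2AccelBlock is a stock GPT2Block whose attention module is swapped for
# the fused implementation; the MLP, layer norms, and residual wiring are
# inherited unchanged.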
class GPT2AccelBlock(GPT2Block):
    def __init__(self, config, layer_idx=None):
        super().__init__(config, layer_idx)
        self.attn = GPT2AccelAttention(config, layer_idx)
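

# GPT2AccelModel keeps the stock GPT2 embedding and final-LayerNorm stack but
# rebuilds every transformer block with the accel attention above.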
class GPT2AccelModel(GPT2Model):
    def __init__(self, config):
        super().__init__(config)
        self.h = nn.ModuleList(
            [
                GPT2AccelBlock(config, layer_idx=i)
                for i in range(config.num_hidden_layers)
            ]
        )
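
    # Two forward paths: with precomputed `inputs_embeds` we take the accel
    # fast path (no attention mask, no KV cache; position embeddings are
    # assumed to be baked into `inputs_embeds` by the caller, since wpe is
    # not applied here); otherwise we defer to the stock GPT2Model forward
    # with caching disabled.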
    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
            for block in self.h:
                hidden_states = block(hidden_states)[0]
            hidden_states = self.ln_f(hidden_states)
            if return_dict:
                return BaseModelOutputWithPastAndCrossAttentions(
                    last_hidden_state=hidden_states,
                    past_key_values=None,
                    hidden_states=None,
                    attentions=None,
                )
            return (hidden_states,)
        else:
            return super().forward(
                input_ids=input_ids,
                past_key_values=None,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                position_ids=position_ids,
                head_mask=head_mask,
                inputs_embeds=None,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                use_cache=False,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
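

# Minimal smoke test (a sketch, not part of the original module): build a tiny
# config and run the accel model on random precomputed embeddings. Assumes the
# fused kernel in `.attention` runs on the current device; on CUDA the
# attention inputs are cast to fp16 internally.
if __name__ == "__main__":
    from transformers import GPT2Config

    config = GPT2Config(n_embd=64, n_head=4, n_layer=2, n_positions=128)
    model = GPT2AccelModel(config).eval()
    if torch.cuda.is_available():
        model = model.half().cuda()
    device = next(model.parameters()).device
    dtype = next(model.parameters()).dtype
    embeds = torch.randn(1, 16, config.n_embd, device=device, dtype=dtype)
    with torch.no_grad():
        out = model(inputs_embeds=embeds, return_dict=True)
    print(out.last_hidden_state.shape)  # expected: torch.Size([1, 16, 64])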