mirror of
https://github.com/index-tts/index-tts.git
synced 2025-11-26 03:44:54 +08:00
182 lines
6.2 KiB
Python
182 lines
6.2 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
|
|
from transformers.models.gpt2.modeling_gpt2 import Conv1D, GPT2Block, GPT2Model
|
|
|
|
from .attention import Attention
|
|
|
|
|
|
class GPT2AccelAttention(nn.Module):
|
|
def __init__(self, config, layer_idx=None):
|
|
super().__init__()
|
|
self.config = config
|
|
self.layer_idx = layer_idx
|
|
|
|
max_positions = config.max_position_embeddings
|
|
self.register_buffer(
|
|
"bias",
|
|
torch.tril(
|
|
torch.ones((max_positions, max_positions), dtype=torch.bool)
|
|
).view(1, 1, max_positions, max_positions),
|
|
persistent=False,
|
|
)
|
|
self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
|
|
|
|
self.embed_dim = config.hidden_size
|
|
self.num_heads = config.num_attention_heads
|
|
self.head_dim = self.embed_dim // self.num_heads
|
|
self.split_size = self.embed_dim
|
|
|
|
if self.head_dim * self.num_heads != self.embed_dim:
|
|
raise ValueError(
|
|
f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
|
|
f" {self.num_heads})."
|
|
)
|
|
|
|
self.scale_attn_weights = config.scale_attn_weights
|
|
|
|
self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
|
|
self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
|
|
|
|
self.attn_dropout = nn.Dropout(config.attn_pdrop)
|
|
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
|
|
|
scale = (self.head_dim**-0.5) if self.scale_attn_weights else 1.0
|
|
self.accel_attn = Attention(
|
|
self.num_heads, self.head_dim, scale, self.num_heads
|
|
)
|
|
|
|
def forward(
|
|
self,
|
|
hidden_states: torch.Tensor,
|
|
layer_past=None,
|
|
attention_mask=None,
|
|
head_mask=None,
|
|
encoder_hidden_states=None,
|
|
encoder_attention_mask=None,
|
|
use_cache=False,
|
|
output_attentions=False,
|
|
past_key_value=None,
|
|
**kwargs,
|
|
):
|
|
if encoder_hidden_states is not None:
|
|
raise NotImplementedError("Cross attention not supported in accel mode")
|
|
|
|
qkv = self.c_attn(hidden_states)
|
|
query, key, value = qkv.split(self.split_size, dim=2)
|
|
|
|
# [B, T, H*D] -> [B, H, T, D]
|
|
query = self._split_heads(query, self.num_heads, self.head_dim)
|
|
key = self._split_heads(key, self.num_heads, self.head_dim)
|
|
value = self._split_heads(value, self.num_heads, self.head_dim)
|
|
|
|
# flatten to [B*T, H, D]
|
|
bsz, num_heads, seq_len, head_dim = query.shape
|
|
q_flat = query.transpose(1, 2).contiguous().view(-1, num_heads, head_dim)
|
|
k_flat = key.transpose(1, 2).contiguous().view(-1, num_heads, head_dim)
|
|
v_flat = value.transpose(1, 2).contiguous().view(-1, num_heads, head_dim)
|
|
|
|
# ensure fp16
|
|
if q_flat.device.type == "cuda" and q_flat.dtype != torch.float16:
|
|
orig_dtype = q_flat.dtype
|
|
q_flat = q_flat.to(torch.float16)
|
|
k_flat = k_flat.to(torch.float16)
|
|
v_flat = v_flat.to(torch.float16)
|
|
else:
|
|
orig_dtype = q_flat.dtype
|
|
|
|
o_flat = self.accel_attn(q_flat, k_flat, v_flat) # [B*T, H, D]
|
|
|
|
if o_flat.dtype != orig_dtype:
|
|
o_flat = o_flat.to(orig_dtype)
|
|
|
|
# Reshape back: [B*T, H, D] -> [B, H, T, D]
|
|
attn_output = o_flat.view(bsz, seq_len, num_heads, head_dim).transpose(1, 2)
|
|
|
|
attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
|
|
|
|
attn_output = self.c_proj(attn_output)
|
|
attn_output = self.resid_dropout(attn_output)
|
|
|
|
outputs = (attn_output, None)
|
|
if output_attentions:
|
|
outputs += (None,)
|
|
|
|
return outputs
|
|
|
|
def _split_heads(self, tensor, num_heads, head_dim):
|
|
new_shape = tensor.size()[:-1] + (num_heads, head_dim)
|
|
tensor = tensor.view(new_shape)
|
|
return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
|
|
|
|
def _merge_heads(self, tensor, num_heads, head_dim):
|
|
tensor = tensor.permute(0, 2, 1, 3).contiguous()
|
|
new_shape = tensor.size()[:-2] + (num_heads * head_dim,)
|
|
return tensor.view(new_shape)
|
|
|
|
|
|
class GPT2AccelBlock(GPT2Block):
|
|
def __init__(self, config, layer_idx=None):
|
|
super().__init__(config, layer_idx)
|
|
self.attn = GPT2AccelAttention(config, layer_idx)
|
|
|
|
|
|
class GPT2AccelModel(GPT2Model):
|
|
def __init__(self, config):
|
|
super().__init__(config)
|
|
self.h = nn.ModuleList(
|
|
[
|
|
GPT2AccelBlock(config, layer_idx=i)
|
|
for i in range(config.num_hidden_layers)
|
|
]
|
|
)
|
|
|
|
def forward(
|
|
self,
|
|
input_ids=None,
|
|
past_key_values=None,
|
|
attention_mask=None,
|
|
token_type_ids=None,
|
|
position_ids=None,
|
|
head_mask=None,
|
|
inputs_embeds=None,
|
|
encoder_hidden_states=None,
|
|
encoder_attention_mask=None,
|
|
use_cache=None,
|
|
output_attentions=None,
|
|
output_hidden_states=None,
|
|
return_dict=None,
|
|
):
|
|
if inputs_embeds is not None:
|
|
hidden_states = inputs_embeds
|
|
|
|
for block in self.h:
|
|
hidden_states = block(hidden_states)[0]
|
|
|
|
hidden_states = self.ln_f(hidden_states)
|
|
|
|
if return_dict:
|
|
return BaseModelOutputWithPastAndCrossAttentions(
|
|
last_hidden_state=hidden_states,
|
|
past_key_values=None,
|
|
hidden_states=None,
|
|
attentions=None,
|
|
)
|
|
return (hidden_states,)
|
|
else:
|
|
return super().forward(
|
|
input_ids=input_ids,
|
|
past_key_values=None,
|
|
attention_mask=attention_mask,
|
|
token_type_ids=token_type_ids,
|
|
position_ids=position_ids,
|
|
head_mask=head_mask,
|
|
inputs_embeds=None,
|
|
encoder_hidden_states=encoder_hidden_states,
|
|
encoder_attention_mask=encoder_attention_mask,
|
|
use_cache=False,
|
|
output_attentions=output_attentions,
|
|
output_hidden_states=output_hidden_states,
|
|
return_dict=return_dict,
|
|
)
|