
Commit

Merge pull request #27 from xingchensong/Mddct-sdpa
support flash attention via torch.nn.functional.scaled_dot_product_attention
xingchensong authored Dec 27, 2024
2 parents 0c87c76 + 1d90e04 commit f7ee8f6
Showing 2 changed files with 101 additions and 47 deletions.
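
For reference, a minimal usage sketch of the new flag, assuming ModelConfig and S3Tokenizer are imported from s3tokenizer/model.py as shown in the diff below; the checkpoint name and the mel layout are assumptions for illustration, and weight loading is omitted:

import torch
from s3tokenizer.model import ModelConfig, S3Tokenizer

# use_sdpa=True routes attention through torch.nn.functional.scaled_dot_product_attention
# (flash attention where the backend supports it); the default False keeps the
# original matmul + softmax path.
config = ModelConfig(use_sdpa=True)
model = S3Tokenizer("speech_tokenizer_v1_25hz", config=config).eval()

mel = torch.randn(1, config.n_mels, 200)   # assumed layout: (batch, n_mels, frames)
mel_len = torch.tensor([200])
codes, codes_len = model(mel, mel_len)     # forward() delegates to quantize()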
75 changes: 50 additions & 25 deletions s3tokenizer/model.py
@@ -29,14 +29,16 @@


@dataclass
class ModelDimensions:
class ModelConfig:
n_mels: int = 128
n_audio_ctx: int = 1500
n_audio_state: int = 1280
n_audio_head: int = 20
n_audio_layer: int = 6
n_codebook_size: int = 4096

use_sdpa: bool = False


class LayerNorm(nn.LayerNorm):

@@ -75,14 +77,16 @@ def sinusoids(length, channels, max_timescale=10000):

class MultiHeadAttention(nn.Module):

def __init__(self, n_state: int, n_head: int):
def __init__(self, n_state: int, n_head: int, use_sdpa: bool = False):
super().__init__()
self.n_head = n_head
self.query = Linear(n_state, n_state)
self.key = Linear(n_state, n_state, bias=False)
self.value = Linear(n_state, n_state)
self.out = Linear(n_state, n_state)

self.use_sdpa = use_sdpa

def forward(
self,
x: Tensor,
@@ -100,27 +104,44 @@ def qkv_attention(self,
k: Tensor,
v: Tensor,
mask: Optional[Tensor] = None):
_, T, D = q.shape
_, _, D = q.shape
scale = (D // self.n_head)**-0.25
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
k = k.view(*k.shape[:2], self.n_head, -1)
v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)

qk = q @ k # (B, n_head, T, T)
if mask is not None:
qk = qk + mask
qk = qk.float()

w = F.softmax(qk, dim=-1).to(q.dtype)
return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
if not self.use_sdpa:
k = k.permute(0, 2, 3, 1) * scale
qk = q @ k # (B, n_head, T, T)
if mask is not None:
qk = qk + mask
qk = qk.float()
w = torch.nn.functional.softmax(qk, dim=-1).to(q.dtype)
return (w @ v).permute(0, 2, 1,
3).flatten(start_dim=2), qk.detach()
else:
k = k.permute(0, 2, 1, 3) * scale
assert mask is not None
output = torch.nn.functional.scaled_dot_product_attention(
q,
k,
v,
attn_mask=mask,
dropout_p=0.,
scale=1.,
)
output = (output.transpose(1,
2).contiguous().view(q.size(0), -1, D)
) # (batch, time1, d_model)
return output, None


class ResidualAttentionBlock(nn.Module):

def __init__(self, n_state: int, n_head: int):
def __init__(self, n_state: int, n_head: int, use_sdpa: bool):
super().__init__()

self.attn = MultiHeadAttention(n_state, n_head)
self.attn = MultiHeadAttention(n_state, n_head, use_sdpa=use_sdpa)
self.attn_ln = LayerNorm(n_state)

n_mlp = n_state * 4
@@ -148,6 +169,7 @@ def __init__(
n_head: int,
n_layer: int,
stride: int,
use_sdpa: bool,
):
super().__init__()
self.stride = stride
@@ -163,8 +185,10 @@ def __init__(
padding=1)
self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))

self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
[ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)])
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList([
ResidualAttentionBlock(n_state, n_head, use_sdpa=use_sdpa)
for _ in range(n_layer)
])

def forward(self, x: Tensor, x_len: Tensor) -> Tuple[Tensor, Tensor]:
"""
@@ -277,22 +301,23 @@ def decode(self, embed_ind: Tensor) -> Tensor:
class S3Tokenizer(nn.Module):
"""S3 tokenizer implementation (inference-only).
Args:
dims (ModelDimensions): Dimension
config (ModelConfig): Config
"""

def __init__(self, name: str, dims: ModelDimensions = ModelDimensions()):
def __init__(self, name: str, config: ModelConfig = ModelConfig()):
super().__init__()
self.dims = dims
self.config = config
self.encoder = AudioEncoder(
self.dims.n_mels,
self.dims.n_audio_ctx,
self.dims.n_audio_state,
self.dims.n_audio_head,
self.dims.n_audio_layer,
self.config.n_mels,
self.config.n_audio_ctx,
self.config.n_audio_state,
self.config.n_audio_head,
self.config.n_audio_layer,
2 if name == "speech_tokenizer_v1_25hz" else 1,
self.config.use_sdpa,
)
self.quantizer = VectorQuantization(self.dims.n_audio_state,
self.dims.n_codebook_size)
self.quantizer = VectorQuantization(self.config.n_audio_state,
self.config.n_codebook_size)

def forward(self, mel: Tensor, mel_len: Tensor) -> Tuple[Tensor, Tensor]:
return self.quantize(mel, mel_len)
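The new SDPA branch in qkv_attention pre-scales both q and k by (D // n_head)**-0.25 and then passes scale=1. to scaled_dot_product_attention, so the usual 1/sqrt(d_head) factor is not applied twice. A standalone sketch of that equivalence with plain tensors (PyTorch >= 2.1 for the scale argument; the shapes and the zero additive mask are illustrative):

import torch
import torch.nn.functional as F

B, H, T, Dh = 2, 4, 16, 64                 # batch, heads, time, per-head dim
q = torch.randn(B, H, T, Dh)
k = torch.randn(B, H, T, Dh)
v = torch.randn(B, H, T, Dh)
mask = torch.zeros(T, T)                   # additive mask; 0.0 means "attend"

scale = Dh ** -0.25
qs, ks = q * scale, k * scale              # both scaled, as in the diff

# Manual path, mirroring the not-use_sdpa branch.
qk = (qs @ ks.transpose(-2, -1) + mask).float()
w = F.softmax(qk, dim=-1).to(q.dtype)
out_manual = w @ v

# SDPA path: scale=1.0 so the pre-applied scaling is not doubled.
out_sdpa = F.scaled_dot_product_attention(qs, ks, v, attn_mask=mask,
                                          dropout_p=0.0, scale=1.0)

print(torch.allclose(out_manual, out_sdpa, atol=1e-5))   # expected: True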
73 changes: 51 additions & 22 deletions s3tokenizer/model_v2.py
@@ -23,14 +23,16 @@


@dataclass
class ModelDimensions:
class ModelConfig:
n_mels: int = 128
n_audio_ctx: int = 1500
n_audio_state: int = 1280
n_audio_head: int = 20
n_audio_layer: int = 6
n_codebook_size: int = 3**8

use_sdpa: bool = False


def precompute_freqs_cis(dim: int,
end: int,
@@ -154,6 +156,7 @@ def __init__(
n_state: int,
n_head: int,
kernel_size: int = 31,
use_sdpa: bool = False,
):
super().__init__(n_state, n_head)

@@ -169,6 +172,8 @@ def __init__(
self.pad_fn = torch.nn.ConstantPad1d(
(self.left_padding, self.right_padding), 0.0)

self.use_sdpa = use_sdpa

def forward_fsmn(self,
inputs: torch.Tensor,
mask: Optional[torch.Tensor] = None):
@@ -202,16 +207,32 @@ def qkv_attention(self,
fsm_memory = self.forward_fsmn(v, mask_pad)

q = q.permute(0, 2, 1, 3) * scale
k = k.permute(0, 2, 3, 1) * scale
v = v.permute(0, 2, 1, 3)

qk = q @ k # (B, n_head, T, T)
if mask is not None:
qk = qk + mask
qk = qk.float()
w = torch.nn.functional.softmax(qk, dim=-1).to(q.dtype)
return (w @ v).permute(0, 2, 1,
3).flatten(start_dim=2), qk.detach(), fsm_memory
if not self.use_sdpa:
k = k.permute(0, 2, 3, 1) * scale
qk = q @ k # (B, n_head, T, T)
if mask is not None:
qk = qk + mask
qk = qk.float()
w = torch.nn.functional.softmax(qk, dim=-1).to(q.dtype)
return (w @ v).permute(
0, 2, 1, 3).flatten(start_dim=2), qk.detach(), fsm_memory
else:
k = k.permute(0, 2, 1, 3) * scale
assert mask is not None
output = torch.nn.functional.scaled_dot_product_attention(
q,
k,
v,
attn_mask=mask,
dropout_p=0.,
scale=1.,
)
output = (output.transpose(1,
2).contiguous().view(q.size(0), -1, D)
) # (batch, time1, d_model)
return output, None, fsm_memory

def forward(self,
x: torch.Tensor,
@@ -235,10 +256,14 @@ def __init__(
n_state: int,
n_head: int,
kernel_size: int = 31,
use_sdpa: bool = False,
):
super().__init__()

self.attn = FSMNMultiHeadAttention(n_state, n_head, kernel_size)
self.attn = FSMNMultiHeadAttention(n_state,
n_head,
kernel_size,
use_sdpa=use_sdpa)
self.attn_ln = LayerNorm(n_state, eps=1e-6)

n_mlp = n_state * 4
@@ -271,6 +296,7 @@ def __init__(
n_head: int,
n_layer: int,
stride: int,
use_sdpa: bool,
):
super().__init__()
self.stride = stride
@@ -286,8 +312,10 @@ def __init__(
stride=2,
padding=1)
self.freqs_cis = precompute_freqs_cis(64, 1024 * 2)
self.blocks = torch.nn.ModuleList(
[ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)])
self.blocks = torch.nn.ModuleList([
ResidualAttentionBlock(n_state, n_head, use_sdpa=use_sdpa)
for _ in range(n_layer)
])

def forward(self, x: torch.Tensor,
x_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -326,26 +354,27 @@ def forward(self, x: torch.Tensor,
class S3TokenizerV2(torch.nn.Module):
"""S3 tokenizer v2 implementation (inference-only).
Args:
dims (ModelDimensions): Dimension
config (ModelConfig): Config
"""

def __init__(self, name: str, dims: ModelDimensions = ModelDimensions()):
def __init__(self, name: str, config: ModelConfig = ModelConfig()):
super().__init__()
if 'v1' not in name:
assert 'v2' in name
# TODO(Mddct): make it configurable
dims.n_codebook_size = 3**8
self.dims = dims
config.n_codebook_size = 3**8
self.config = config
self.encoder = AudioEncoderV2(
self.dims.n_mels,
self.dims.n_audio_state,
self.dims.n_audio_head,
self.dims.n_audio_layer,
self.config.n_mels,
self.config.n_audio_state,
self.config.n_audio_head,
self.config.n_audio_layer,
2,
self.config.use_sdpa,
)
self.quantizer = FSQVectorQuantization(
self.dims.n_audio_state,
self.dims.n_codebook_size,
self.config.n_audio_state,
self.config.n_codebook_size,
)

def forward(self, mel: torch.Tensor,
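Similarly, a minimal sketch for the v2 tokenizer, assuming the classes are imported from s3tokenizer/model_v2.py as defined above; the model name is a placeholder that only has to contain "v2" to satisfy the assertion in __init__, which also forces n_codebook_size to 3**8:

from s3tokenizer.model_v2 import ModelConfig, S3TokenizerV2

config = ModelConfig(use_sdpa=True)                       # default remains False
model = S3TokenizerV2("any_name_with_v2", config=config).eval()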
