...In some multihead models, the inputs are shaped [batch, seq, dims]. In these cases, the code should treat the last two axes as the batch and sequence axes.

PiperOrigin-RevId: 706118602
The praxis Authors committed Dec 14, 2024
1 parent ecff295 commit edc9a6f
Showing 2 changed files with 80 additions and 14 deletions.
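For orientation, here is a minimal, self-contained JAX sketch of the rank handling this change introduces. It is not the praxis implementation: the apply_rope name, the timescale construction, and the rotation follow the standard rotary-embedding formulation and are assumptions where the diff below elides those details; only the rank-3 / rank-4 broadcasting mirrors the changed lines.

import jax.numpy as jnp

def apply_rope(inputs, min_timescale=1.0, max_timescale=10_000.0, position=None):
  """Applies rotary position embedding to [B, S, N, H] or [B, S, H] inputs."""
  embedding_dims = inputs.shape[-1]
  half = embedding_dims // 2
  fraction = 2 * jnp.arange(0, half, dtype=jnp.float32) / embedding_dims
  timescale = min_timescale * (max_timescale / min_timescale) ** fraction
  if position is None:
    seq_length = inputs.shape[1]
    position = jnp.arange(seq_length, dtype=jnp.float32)[jnp.newaxis, :]
  if inputs.ndim == 4:
    # [B, S, N, H]: broadcast position over heads and dims, timescale over the rest.
    position = position[:, :, jnp.newaxis, jnp.newaxis]
    timescale = timescale[jnp.newaxis, jnp.newaxis, jnp.newaxis, :]
  elif inputs.ndim == 3:
    # [B, S, H]: no heads axis, so one fewer broadcast dimension.
    position = position[:, :, jnp.newaxis]
    timescale = timescale[jnp.newaxis, jnp.newaxis, :]
  else:
    raise ValueError('Inputs must be of rank 3 or 4.')
  sinusoid_inp = position / timescale
  sin, cos = jnp.sin(sinusoid_inp), jnp.cos(sinusoid_inp)
  first_half, second_half = jnp.split(inputs, 2, axis=-1)
  first_part = first_half * cos - second_half * sin
  second_part = second_half * cos + first_half * sin
  return jnp.concatenate([first_part, second_part], axis=-1)

# Both multi-head [B, S, N, H] and multi-query [B, S, H] inputs keep their shape.
assert apply_rope(jnp.ones((2, 8, 4, 32))).shape == (2, 8, 4, 32)
assert apply_rope(jnp.ones((2, 8, 32))).shape == (2, 8, 32)

The only difference between the two ranks is how position and timescale are broadcast against the inputs; the rotation itself is unchanged.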
37 changes: 23 additions & 14 deletions praxis/layers/embedding_softmax.py
@@ -1184,21 +1184,16 @@ def __call__(
     Args:
       inputs: The input sequence on which to apply the Rotary position
         embedding. Since rotary position embeddings are applied to query and
-        keys after projection, it is assumed of shape [B, S, N, H].
+        keys after projection, it is assumed of shape [B, S, N, H] or [B, S, H].
       position: Optional position JTensor which denotes the position of each
         token in the sequence. This only needs to be supplied when the sequence
         is packed. It is of shape [B, S].

     Returns:
-      a JTensor of shape [B, S, N, H] which includes the inputs together with
-      the rotary position embedding incorporated in it.
+      a JTensor of shape [B, S, N, H] or [B, S, H] which includes the inputs
+      together with the rotary position embedding incorporated in it.
     """
-    if len(inputs.shape) != 4:
-      raise ValueError(
-          'Input is assumed to be a rank 4 tensor of shape'
-          '[batch, sequence, heads, dims].'
-      )
-    if self.embedding_dims != inputs.shape[3]:
+    if self.embedding_dims != inputs.shape[-1]:
       raise ValueError(
           'The embedding dims of the rotary position embedding'
           'must match the hidden dimension of the inputs.'
@@ -1212,8 +1207,14 @@ def __call__(
     if position is None:
       seq_length = inputs.shape[1]
       position = jnp.arange(seq_length, dtype=jnp.float32)[jnp.newaxis, :]
-    position = position[:, :, jnp.newaxis, jnp.newaxis]
-    timescale = timescale[jnp.newaxis, jnp.newaxis, jnp.newaxis, :]
+    if len(inputs.shape) == 4:
+      position = position[:, :, jnp.newaxis, jnp.newaxis]
+      timescale = timescale[jnp.newaxis, jnp.newaxis, jnp.newaxis, :]
+    elif len(inputs.shape) == 3:
+      position = position[:, :, jnp.newaxis]
+      timescale = timescale[jnp.newaxis, jnp.newaxis, :]
+    else:
+      raise ValueError('Inputs must be of rank 3 or 4.')
     sinusoid_inp = position / timescale
     sin = jnp.sin(sinusoid_inp)
     cos = jnp.cos(sinusoid_inp)
@@ -1227,15 +1228,19 @@ def __call__(
     return jnp.concatenate([first_part, second_part], axis=-1)

   def extend_step(
-      self, inputs: JTensor, position: int | JTensor | None = None
+      self,
+      inputs: JTensor,
+      position: int | JTensor | None = None,
   ) -> JTensor:
     """Generates a JTensor of sinusoids with different frequencies for a step.

     Args:
       inputs: The input sequence on which to apply the Rotary position
         embedding. Since rotary position embeddings are applied to query and
         keys after projection, it is assumed of shape [B, N, H] or of shape [B,
-        P, N, H] where P may be a prefix length.
+        P, N, H] where P may be a prefix length if using multi-head attention.
+        If using multi-query attention, the shape is [B, H] or [B, P, H] where P
+        may be a prefix length.
       position: The position which is being decoded, this should correspond to
         the logical position of the last token in the prefix window (P) in the
         entire sequence length S. It is a scalar or having shape [B].
@@ -1244,10 +1249,12 @@ def extend_step(
       a JTensor of the same shape as input with the rotary position embedding
       incorporated in it.
     """
-    assert len(inputs.shape) in [3, 4]
+    assert len(inputs.shape) in [2, 3, 4]
     inputs_shape = inputs.shape
     if len(inputs_shape) == 3:
       inputs = inputs[:, jnp.newaxis, :, :]
+    elif len(inputs_shape) == 2:
+      inputs = inputs[:, jnp.newaxis, :]
     seq_length = inputs.shape[1]
     # Adjust the prefix's position with position.
     # Note that position may be a tracer rather than an int, and so we must use
@@ -1263,6 +1270,8 @@
     output = self(inputs, position=prefix_position)
     if len(inputs_shape) == 3:
       output = jnp.squeeze(output, axis=1)
+    elif len(inputs_shape) == 2:
+      output = jnp.squeeze(output, axis=1)
     return output


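The extend_step change applies the same idea to single decoding steps: rank-2 [B, H] (multi-query) and rank-3 [B, N, H] (multi-head) inputs get a temporary time axis, are embedded at the decoded position, and are squeezed back. Below is a sketch of that shape bookkeeping, reusing apply_rope from the sketch above; the prefix-position arithmetic (counting back from position and clamping padded slots to zero) is an assumption about the lines the excerpt elides.

import jax.numpy as jnp

def rope_extend_step(inputs, position):
  """Single decoding step; inputs is [B, H], [B, P, H], [B, N, H] or [B, P, N, H]."""
  assert inputs.ndim in (2, 3, 4)
  inputs_shape = inputs.shape
  if inputs.ndim == 3:
    inputs = inputs[:, jnp.newaxis, :, :]   # [B, N, H] -> [B, 1, N, H]
  elif inputs.ndim == 2:
    inputs = inputs[:, jnp.newaxis, :]      # [B, H] -> [B, 1, H]
  seq_length = inputs.shape[1]
  # Positions of the prefix window: the last entry equals `position`, earlier
  # entries count back from it, and padded slots are clamped to zero.
  steps = jnp.arange(seq_length, dtype=jnp.float32)[jnp.newaxis, :]
  position = jnp.asarray(position, dtype=jnp.float32).reshape(-1, 1)
  prefix_position = jnp.maximum(position - (seq_length - 1) + steps, 0.0)
  output = apply_rope(inputs, position=prefix_position)
  if len(inputs_shape) in (2, 3):
    output = jnp.squeeze(output, axis=1)    # Drop the temporary time axis.
  return output

# Multi-query decoding step: a single [B, H] key/query slice at position 5.
step_out = rope_extend_step(jnp.ones((2, 32)), position=5)
assert step_out.shape == (2, 32)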
57 changes: 57 additions & 0 deletions praxis/layers/embedding_softmax_test.py
@@ -929,6 +929,7 @@ def test_rotary_position_embedding_layer_prefix(
         max_timescale=max_timescale,
     )
     pos_layer = instantiate(p)
+    # test the case when the input is rank 4.
     inputs = np.random.normal(1.5, 2.5, (2, 8, 4, embedding_dims))
     prng_key = jax.random.PRNGKey(seed=123)
     initial_vars = pos_layer.init(prng_key, inputs)
@@ -953,6 +954,42 @@
      )
      self.assertArraysEqual(jax_fprop_slice, jax_np_extend_step_out)

+    # test the case when the input is rank 3.
+    inputs_for_mqa = np.random.normal(1.5, 2.5, (2, 8, embedding_dims))[
+        :, :, jnp.newaxis, :
+    ]
+    prng_key = jax.random.PRNGKey(seed=123)
+    initial_vars = pos_layer.init(
+        prng_key, jnp.squeeze(inputs_for_mqa, axis=-2)
+    )
+    output = pos_layer.apply(initial_vars, jnp.squeeze(inputs_for_mqa, axis=-2))
+    # Test whether extend_step returns same output.
+    for i in range(inputs_for_mqa.shape[1]):
+      start = max(0, i + 1 - window_size)
+      end = i + 1
+      inputs_prefix = inputs_for_mqa[:, start:end, :, :]
+      pad_width = window_size - end + start
+      paddings = [(0, 0), (pad_width, 0), (0, 0), (0, 0)]
+      inputs_prefix = jnp.pad(inputs_prefix, paddings)
+      jax_extend_step_out = pos_layer.apply(
+          initial_vars,
+          inputs_prefix,
+          position=i,
+          method=pos_layer.extend_step,
+      )
+      jax_extend_step_out = jax.lax.dynamic_slice_in_dim(
+          jax_extend_step_out,
+          start_index=window_size - 1,
+          slice_size=1,
+          axis=1,
+      )
+      jax_np_extend_step_out = test_utils.to_np(jax_extend_step_out)
+      jax_np_extend_step_out = jnp.squeeze(jax_np_extend_step_out, axis=-2)
+      jax_fprop_slice = jax.lax.dynamic_slice_in_dim(
+          output, start_index=i, slice_size=1, axis=1
+      )
+      self.assertArraysEqual(jax_fprop_slice, jax_np_extend_step_out)
+
   @parameterized.parameters((1, 10), (1, 1e5), (10, 20), (10, 1e5))
   def test_rotary_position_embedding_layer_no_prefix(
       self, min_timescale, max_timescale
@@ -966,6 +1003,7 @@ def test_rotary_position_embedding_layer_no_prefix(
         max_timescale=max_timescale,
     )
     pos_layer = instantiate(p)
+    # test the case when the input is rank 4.
     inputs = np.random.normal(1.5, 2.5, (2, 8, 4, embedding_dims))
     prng_key = jax.random.PRNGKey(seed=123)
     initial_vars = pos_layer.init(prng_key, inputs=inputs)
@@ -982,6 +1020,25 @@ def test_rotary_position_embedding_layer_no_prefix(
       jax_fprop_slice = output[:, i, :, :]
       self.assertArraysEqual(jax_fprop_slice, jax_np_extend_step_out)

+    # test the case when the input is rank 3.
+    inputs_for_mqa = np.random.normal(1.5, 2.5, (2, 8, embedding_dims))[
+        :, :, jnp.newaxis, :
+    ]
+    prng_key = jax.random.PRNGKey(seed=123)
+    initial_vars = pos_layer.init(prng_key, inputs=inputs_for_mqa)
+    output = pos_layer.apply(initial_vars, inputs=inputs_for_mqa)
+    # Test whether extend_step returns same output.
+    for i in range(inputs_for_mqa.shape[1]):
+      jax_extend_step_out = pos_layer.apply(
+          initial_vars,
+          inputs_for_mqa[:, i, :],
+          position=i,
+          method=pos_layer.extend_step,
+      )
+      jax_np_extend_step_out = test_utils.to_np(jax_extend_step_out)
+      jax_fprop_slice = output[:, i, :]
+      self.assertArraysEqual(jax_fprop_slice, jax_np_extend_step_out)
+
   @parameterized.parameters(
       ([0, 1, 0, 1],),
       ([0, 1, 2, 3],),

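The new tests build their multi-query inputs by adding or squeezing a singleton heads axis, which relies on [B, S, H] and [B, S, 1, H] inputs producing the same embeddings up to that axis. With the apply_rope sketch from above, that equivalence looks like this (a check of the sketch, not of the praxis layer itself):

import jax.numpy as jnp

x3 = jnp.ones((2, 8, 32))                        # [B, S, H] multi-query input.
x4 = x3[:, :, jnp.newaxis, :]                    # Same data viewed as [B, S, 1, H].
out_rank3 = apply_rope(x3)
out_rank4 = jnp.squeeze(apply_rope(x4), axis=-2)
assert jnp.allclose(out_rank3, out_rank4)        # Identical once the axis is squeezed.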