Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add missing roll function #192

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Source/MLX/Documentation.docc/Organization/shapes.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ and ``MLXArray/shape`` of the dimensions without changing the number of elements
- ``flattened(_:start:end:stream:)``
- ``reshaped(_:_:stream:)-5x3y0``
- ``squeezed(_:axes:stream:)``
- ``roll(_:shift:axis:stream:)``
- ``roll(_:shift:axes:stream:)``

### MLXArray Shape Methods (Change Size)

Expand Down
48 changes: 48 additions & 0 deletions Source/MLX/Ops.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2201,6 +2201,54 @@ public func remainder<A: ScalarOrArray, B: ScalarOrArray>(
return MLXArray(result)
}

/// Roll array elements along a given axis.
///
/// Elements that are rolled beyond the end of the array are introduced at the beggining and vice-versa.
///
/// - Parameters:
/// - a: input array
/// - shift: The number of places by which elements
/// are shifted. If positive the array is rolled to the right, if
/// negative it is rolled to the left.
/// - axis: the axis along which to roll the elements
/// - stream: stream or device to evaluate on
///
/// ### See Also
/// - <doc:shapes>
public func roll(_ a: MLXArray, shift: Int, axis: Int, stream: StreamOrDevice = .default)
-> MLXArray
{
var result = mlx_array_new()
mlx_roll(&result, a.ctx, shift.int32, [axis.int32], 1, stream.ctx)
return MLXArray(result)
}

/// Roll array elements along a given axis.
///
/// Elements that are rolled beyond the end of the array are introduced at the beggining and vice-versa.
///
/// - Parameters:
/// - a: input array
/// - shift: The number of places by which elements
/// are shifted. If positive the array is rolled to the right, if
/// negative it is rolled to the left.
/// - axes: the axes along which to roll the elements, or all if omitted
/// - stream: stream or device to evaluate on
///
/// ### See Also
/// - <doc:shapes>
public func roll(_ a: MLXArray, shift: Int, axes: [Int]? = nil, stream: StreamOrDevice = .default)
-> MLXArray
{
var result = mlx_array_new()
if let axes {
mlx_roll(&result, a.ctx, shift.int32, axes.asInt32, axes.count, stream.ctx)
} else {
mlx_roll_all(&result, a.ctx, shift.int32, stream.ctx)
}
return MLXArray(result)
}

/// Save array to a binary file in `.npy`format.
///
/// - Parameters:
Expand Down