Would it be possible to have MLX equivalents for jax.lax.scan and jax.lax.while_loop?
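For reference, this is roughly what the two JAX primitives compute. The sketch is a plain-Python restatement of their semantics as described in the JAX docs, not how JAX implements them. It omits the `length`/`reverse`/`unroll` options of `scan`. The point is that each is a loop whose body is traced once, so the compiler never needs concrete values to unroll it.

```python
import numpy as np

def scan(f, init, xs):
    # Reference semantics of jax.lax.scan: thread a carry through f and
    # stack the per-step outputs (length/reverse/unroll options omitted).
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)

def while_loop(cond_fun, body_fun, init_val):
    # Reference semantics of jax.lax.while_loop: repeat body_fun while
    # cond_fun returns True; the trip count can depend on runtime values.
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val

# Cumulative sum via scan: final carry 6, stacked outputs [1, 3, 6].
carry, ys = scan(lambda c, x: (c + x, c + x), 0, np.array([1, 2, 3]))
# Doubling until the value reaches 10: returns 16.
result = while_loop(lambda v: v < 10, lambda v: v * 2, 1)
```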
Matching algorithms like Hungarian matching for object detection (DETR) need boolean arrays to be evaluated (their concrete values are inspected in Python control flow) to compute optimal matchings correctly. As I understand it, this makes it unlikely that an MLX-specific implementation could be compiled to optimise the end-to-end training process.
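To make the problem concrete, here is a simplified matching loop written in NumPy. It is a greedy stand-in chosen for brevity, not the Hungarian algorithm, and `greedy_match` is a hypothetical helper. The line `while unassigned_targets.any():` turns an array into a Python `bool`, which needs concrete values. Under a tracing compiler such as `jax.jit`, this kind of value-dependent control flow is what `jax.lax.while_loop` exists to express.

```python
import numpy as np

def greedy_match(cost):
    # cost: (num_queries, num_targets) matching cost matrix.
    # Greedy illustration only; not an optimal (Hungarian) assignment.
    cost = np.asarray(cost, dtype=float)
    assert cost.shape[0] >= cost.shape[1], "need at least as many queries as targets"
    free_queries = np.ones(cost.shape[0], dtype=bool)
    unassigned_targets = np.ones(cost.shape[1], dtype=bool)
    pairs = []
    # Value-dependent loop condition: requires evaluating a boolean array.
    while unassigned_targets.any():
        # Hide already-used rows/columns by setting their cost to +inf.
        valid = free_queries[:, None] & unassigned_targets[None, :]
        masked = np.where(valid, cost, np.inf)
        q, t = np.unravel_index(np.argmin(masked), masked.shape)
        pairs.append((int(q), int(t)))
        free_queries[q] = False
        unassigned_targets[t] = False
    return pairs

# Picks the cheapest pair first: query 1 -> target 1, then query 0 -> target 0.
pairs = greedy_match([[1, 5], [2, 0], [3, 3]])
```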
What I'm not sure about is how much performance benefit compiling the loss computation (which depends on the matching) would actually bring. From the documentation, compilation can yield significant performance improvements.