[pallas:triton] Temporarily reverted to the lowering using Triton IR
The new lowering caused a performance regression internally.

PiperOrigin-RevId: 723934141
superbobry authored and Google-ML-Automation committed Feb 6, 2025
1 parent 5d647cc commit efbb0af
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion jax/_src/pallas/triton/pallas_call_registration.py
@@ -100,7 +100,8 @@ def pallas_call_lowering(
   buf = io.BytesIO()
   module_op.write_bytecode(buf)
 
-  if jaxlib_version < (0, 5, 1):
+  # TODO(b/394629193): Remove True once the bug is fixed.
+  if True or jaxlib_version < (0, 5, 1):
     # AOT Triton compilation is only available on jaxlib 0.5.1+.
     out_types = [
         ir.RankedTensorType.get(bm.array_shape_dtype.shape,
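
For readers unfamiliar with this kind of temporary override, here is a minimal sketch of a version gate forced on by a constant. It assumes the override is meant to select the Triton IR branch on every jaxlib version, as the commit title implies. The helper `uses_triton_ir_lowering` and its `force_legacy` parameter are hypothetical names used only for illustration and are not part of JAX.

```python
# Hypothetical helper, not part of JAX: models a version gate that a
# temporary constant can force on.
def uses_triton_ir_lowering(jaxlib_version, force_legacy=True):
    """Return True when the Triton IR lowering branch would be taken."""
    # Version tuples compare element by element, left to right.
    # `force_legacy or ...` short-circuits: the version check is skipped
    # whenever the override is set.
    return force_legacy or jaxlib_version < (0, 5, 1)


# With the override set, every version takes the Triton IR branch.
assert uses_triton_ir_lowering((0, 5, 0))
assert uses_triton_ir_lowering((0, 5, 1))

# Without it, only versions below 0.5.1 take that branch.
assert uses_triton_ir_lowering((0, 5, 0), force_legacy=False)
assert not uses_triton_ir_lowering((0, 5, 1), force_legacy=False)
```

Keeping the original comparison next to the constant means the override can later be undone by deleting the constant alone.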
