diff --git a/tests/BUILD b/tests/BUILD index e1c055ed5bf5..07ba4b52fddb 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1518,7 +1518,7 @@ jax_multiplatform_test( ], shard_count = { "cpu": 4, - "gpu": 4, + "gpu": 6, "tpu": 4, }, tags = [