Skip to content

Commit

Permalink
Use assertAllClose to compare float32 arrays
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 723676676
  • Loading branch information
ezhulenev authored and pax authors committed Feb 5, 2025
1 parent ccd78af commit ccd25a1
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions praxis/sample_decode_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,10 +506,10 @@ def decode_fn(model, input_ids, input_paddings):
# batch size is 1.
self.assertEqual(1, top_candidate_logprobs.shape[0])
self.assertEqual(1, top_candidate_ids.shape[0])
self.assertArraysEqual(
self.assertAllClose(
logprobs, top_candidate_logprobs[0, :, :, :num_per_token_logprobs]
)
self.assertArraysEqual(
self.assertAllClose(
ids, top_candidate_ids[0, :, :, :num_per_token_logprobs]
)

Expand Down

0 comments on commit ccd25a1

Please sign in to comment.