diff --git a/reverb/cc/trajectory_writer.cc b/reverb/cc/trajectory_writer.cc index b5645ad..afc2478 100644 --- a/reverb/cc/trajectory_writer.cc +++ b/reverb/cc/trajectory_writer.cc @@ -14,6 +14,7 @@ #include "reverb/cc/trajectory_writer.h" +#include #include #include #include @@ -430,6 +431,9 @@ absl::Status TrajectoryWriter::CreateItem( [](const TrajectoryColumn& col) { return col.empty(); })) { return absl::InvalidArgumentError("trajectory must not be empty."); } + if (std::isnan(priority)) { + return absl::InvalidArgumentError("`priority` must not be nan."); + } { absl::MutexLock lock(&mu_); diff --git a/reverb/cc/trajectory_writer_test.cc b/reverb/cc/trajectory_writer_test.cc index 498d05c..1541520 100644 --- a/reverb/cc/trajectory_writer_test.cc +++ b/reverb/cc/trajectory_writer_test.cc @@ -13,8 +13,10 @@ // limitations under the License. #include "reverb/cc/trajectory_writer.h" + #include +#include #include #include #include @@ -1358,6 +1360,24 @@ TEST(TrajectoryWriter, CreateItemValidatesSqueezedColumns) { "exactly one row when squeeze is set but got 2.")); } +TEST(TrajectoryWriter, CreateItemValidatesPriorityIsNotNan) { + AsyncInterface success_stream; + auto stub = std::make_shared(); + EXPECT_CALL(*stub, async()).WillRepeatedly(Return(&success_stream)); + + TrajectoryWriter writer( + stub, MakeOptions(/*max_chunk_length=*/1, /*num_keep_alive_refs=*/1)); + + StepRef step; + REVERB_ASSERT_OK(writer.Append(Step({MakeTensor(kIntSpec)}), &step)); + + auto status = + writer.CreateItem("table", std::nan("1"), MakeTrajectory({step})); + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(std::string(status.message()), + ::testing::HasSubstr("`priority` must not be nan.")); +} + class TrajectoryWriterSignatureValidationTest : public ::testing::Test { protected: void SetUp() override { diff --git a/reverb/cc/writer.cc b/reverb/cc/writer.cc index b61cfa4..308aacf 100644 --- a/reverb/cc/writer.cc +++ b/reverb/cc/writer.cc @@ -15,6 +15,7 @@ #include "reverb/cc/writer.h" #include +#include #include #include #include @@ -169,6 +170,9 @@ absl::Status Writer::CreateItem(const std::string& table, int num_timesteps, return absl::InvalidArgumentError( "`num_timesteps` must be <= `max_timesteps`"); } + if (std::isnan(priority)) { + return absl::InvalidArgumentError("`priority` must not be nan."); + } const internal::DtypesAndShapes* dtypes_and_shapes = nullptr; REVERB_RETURN_IF_ERROR(GetFlatSignature(table, &dtypes_and_shapes)); diff --git a/reverb/cc/writer_test.cc b/reverb/cc/writer_test.cc index 212d616..e6997b4 100644 --- a/reverb/cc/writer_test.cc +++ b/reverb/cc/writer_test.cc @@ -15,6 +15,7 @@ #include "reverb/cc/writer.h" #include +#include #include #include #include @@ -854,6 +855,25 @@ TEST(WriterTest, WriteTimeStepsInconsistentShapeError) { "dtype float and shape compatible with [5]")); } +TEST(WriterTest, WriteNanPriorityError) { + std::vector requests; + auto stub = MakeGoodStub(&requests); + Client client(stub); + std::unique_ptr writer; + REVERB_EXPECT_OK(client.NewWriter(2, 6, /*delta_encoded=*/false, &writer)); + + REVERB_ASSERT_OK(writer->Append( + MakeTimestep(/*num_tensors=*/1, /*shape=*/tensorflow::TensorShape({1})))); + REVERB_ASSERT_OK(writer->Append( + MakeTimestep(/*num_tensors=*/1, /*shape=*/tensorflow::TensorShape({1})))); + + auto status = writer->CreateItem("dist", 2, std::nan("1")); + + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(std::string(status.message()), + ::testing::HasSubstr("`priority` must not be nan.")); +} + TEST(WriterTest, WriteTimeStepsInconsistentShapeErrorAgainstBoundedSpec) { std::vector requests; tensorflow::StructuredValue signature = MakeBoundedTensorSpecSignature(