Skip to content

Commit

Permalink
Add validation in the writers that priorities are never nan.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 525400218
Change-Id: I27308c5cae43f71798a0e41afec4f88cb4dfd514
  • Loading branch information
acassirer authored and copybara-github committed Apr 19, 2023
1 parent 379da9a commit 6316be4
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 0 deletions.
4 changes: 4 additions & 0 deletions reverb/cc/trajectory_writer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "reverb/cc/trajectory_writer.h"

#include <cmath>
#include <limits>
#include <memory>
#include <vector>
Expand Down Expand Up @@ -430,6 +431,9 @@ absl::Status TrajectoryWriter::CreateItem(
[](const TrajectoryColumn& col) { return col.empty(); })) {
return absl::InvalidArgumentError("trajectory must not be empty.");
}
if (std::isnan(priority)) {
return absl::InvalidArgumentError("`priority` must not be nan.");
}

{
absl::MutexLock lock(&mu_);
Expand Down
20 changes: 20 additions & 0 deletions reverb/cc/trajectory_writer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
// limitations under the License.

#include "reverb/cc/trajectory_writer.h"

#include <grpcpp/support/status.h>

#include <cmath>
#include <limits>
#include <memory>
#include <string>
Expand Down Expand Up @@ -1358,6 +1360,24 @@ TEST(TrajectoryWriter, CreateItemValidatesSqueezedColumns) {
"exactly one row when squeeze is set but got 2."));
}

TEST(TrajectoryWriter, CreateItemValidatesPriorityIsNotNan) {
AsyncInterface success_stream;
auto stub = std::make_shared<MockReverbServiceAsyncStub>();
EXPECT_CALL(*stub, async()).WillRepeatedly(Return(&success_stream));

TrajectoryWriter writer(
stub, MakeOptions(/*max_chunk_length=*/1, /*num_keep_alive_refs=*/1));

StepRef step;
REVERB_ASSERT_OK(writer.Append(Step({MakeTensor(kIntSpec)}), &step));

auto status =
writer.CreateItem("table", std::nan("1"), MakeTrajectory({step}));
EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
EXPECT_THAT(std::string(status.message()),
::testing::HasSubstr("`priority` must not be nan."));
}

class TrajectoryWriterSignatureValidationTest : public ::testing::Test {
protected:
void SetUp() override {
Expand Down
4 changes: 4 additions & 0 deletions reverb/cc/writer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "reverb/cc/writer.h"

#include <algorithm>
#include <cmath>
#include <iterator>
#include <memory>
#include <string>
Expand Down Expand Up @@ -169,6 +170,9 @@ absl::Status Writer::CreateItem(const std::string& table, int num_timesteps,
return absl::InvalidArgumentError(
"`num_timesteps` must be <= `max_timesteps`");
}
if (std::isnan(priority)) {
return absl::InvalidArgumentError("`priority` must not be nan.");
}

const internal::DtypesAndShapes* dtypes_and_shapes = nullptr;
REVERB_RETURN_IF_ERROR(GetFlatSignature(table, &dtypes_and_shapes));
Expand Down
20 changes: 20 additions & 0 deletions reverb/cc/writer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "reverb/cc/writer.h"

#include <algorithm>
#include <cmath>
#include <memory>
#include <queue>
#include <string>
Expand Down Expand Up @@ -854,6 +855,25 @@ TEST(WriterTest, WriteTimeStepsInconsistentShapeError) {
"dtype float and shape compatible with [5]"));
}

TEST(WriterTest, WriteNanPriorityError) {
std::vector<InsertStreamRequest> requests;
auto stub = MakeGoodStub(&requests);
Client client(stub);
std::unique_ptr<Writer> writer;
REVERB_EXPECT_OK(client.NewWriter(2, 6, /*delta_encoded=*/false, &writer));

REVERB_ASSERT_OK(writer->Append(
MakeTimestep(/*num_tensors=*/1, /*shape=*/tensorflow::TensorShape({1}))));
REVERB_ASSERT_OK(writer->Append(
MakeTimestep(/*num_tensors=*/1, /*shape=*/tensorflow::TensorShape({1}))));

auto status = writer->CreateItem("dist", 2, std::nan("1"));

EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
EXPECT_THAT(std::string(status.message()),
::testing::HasSubstr("`priority` must not be nan."));
}

TEST(WriterTest, WriteTimeStepsInconsistentShapeErrorAgainstBoundedSpec) {
std::vector<InsertStreamRequest> requests;
tensorflow::StructuredValue signature = MakeBoundedTensorSpecSignature(
Expand Down

0 comments on commit 6316be4

Please sign in to comment.