Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix build, and build against TF 2.15.1 #1

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 1 addition & 9 deletions reverb/cc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ reverb_cc_test(
deps = [
":tensor_compression",
"//reverb/cc/testing:tensor_testutil",
"//third_party/absl/status",
"//reverb/cc/platform:status_matchers",
] + reverb_tf_deps(),
)

Expand All @@ -88,8 +88,6 @@ reverb_cc_test(
"//reverb/cc/testing:proto_test_util",
"//reverb/cc/testing:tensor_testutil",
"//reverb/cc/testing:time_testutil",
"//third_party/absl/log:check",
"//third_party/grpc:grpc++",
] + reverb_tf_deps() + reverb_grpc_deps() + reverb_absl_deps(),
)

Expand Down Expand Up @@ -191,10 +189,6 @@ reverb_cc_library(
deps = [
"//reverb/cc/platform:logging",
"//reverb/cc/platform:snappy",
"//third_party/absl/status",
"//third_party/absl/status:statusor",
"//third_party/absl/strings",
"//third_party/tensorflow/tsl/platform:status",
] + reverb_tf_deps(),
)

Expand Down Expand Up @@ -241,8 +235,6 @@ reverb_cc_library(
"//reverb/cc/support:signature",
"//reverb/cc/support:tf_util",
"//reverb/cc/support:trajectory_util",
"//third_party/absl/log:check",
"//third_party/grpc:grpc++",
] + reverb_tf_deps() + reverb_grpc_deps() + reverb_absl_deps(),
)

Expand Down
1 change: 1 addition & 0 deletions reverb/cc/platform/default/build_rules.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,7 @@ def reverb_absl_deps():
"@com_google_absl//absl/base",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/functional:bind_front",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/numeric:int128",
"@com_google_absl//absl/random",
Expand Down
8 changes: 4 additions & 4 deletions reverb/cc/platform/default/repo.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -360,11 +360,11 @@ def googletest_deps():
def absl_deps():
http_archive(
name = "com_google_absl",
sha256 = "0320586856674d16b0b7a4d4afb22151bdc798490bb7f295eddd8f6a62b46fea", # SHARED_ABSL_SHA
strip_prefix = "abseil-cpp-fb3621f4f897824c0dbe0615fa94543df6192f30",
sha256 = "8eeec9382fc0338ef5c60053f3a4b0e0708361375fe51c9e65d0ce46ccfe55a7", # SHARED_ABSL_SHA
strip_prefix = "abseil-cpp-b971ac5250ea8de900eae9f95e06548d14cd95fe",
urls = [
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/abseil/abseil-cpp/archive/fb3621f4f897824c0dbe0615fa94543df6192f30.tar.gz",
"https://github.com/abseil/abseil-cpp/archive/fb3621f4f897824c0dbe0615fa94543df6192f30.tar.gz",
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/abseil/abseil-cpp/archive/b971ac5250ea8de900eae9f95e06548d14cd95fe.tar.gz",
"https://github.com/abseil/abseil-cpp/archive/b971ac5250ea8de900eae9f95e06548d14cd95fe.tar.gz",
],
)

Expand Down
5 changes: 1 addition & 4 deletions reverb/cc/sampler_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,6 @@
#include "absl/synchronization/mutex.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "third_party/grpc/include/grpcpp/client_context.h"
#include "third_party/grpc/include/grpcpp/impl/call_op_set.h"
#include "third_party/grpc/include/grpcpp/support/sync_stream.h"
#include "reverb/cc/chunk_store.h"
#include "reverb/cc/platform/logging.h"
#include "reverb/cc/platform/status_matchers.h"
Expand Down Expand Up @@ -703,7 +700,7 @@ TEST(GrpcSamplerTest, GetNextTimestepReturnsErrorIfNotDecomposible) {
auto* entry = response.mutable_entries(0);

// Add a column of length 10 to the existing one of length 5.
ASSERT_OK(CompressTensorAsProto(
REVERB_ASSERT_OK(CompressTensorAsProto(
MakeTensor(10), entry->add_data()->mutable_data()->add_tensors()));
auto* slice = entry->mutable_info()
->mutable_item()
Expand Down
4 changes: 2 additions & 2 deletions reverb/cc/support/trajectory_util_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -246,9 +246,9 @@ TEST(UnpackChunkColumn, SelectsCorrectColumn) {
tensorflow::Tensor second_col_tensor(static_cast<int32_t>(9000));

ChunkData data;
ASSERT_OK(CompressTensorAsProto(first_col_tensor,
REVERB_ASSERT_OK(CompressTensorAsProto(first_col_tensor,
data.mutable_data()->add_tensors()));
ASSERT_OK(CompressTensorAsProto(second_col_tensor,
REVERB_ASSERT_OK(CompressTensorAsProto(second_col_tensor,
data.mutable_data()->add_tensors()));
data.set_data_tensors_len(2);

Expand Down
25 changes: 15 additions & 10 deletions reverb/cc/tensor_compression_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "absl/status/status.h"
#include "reverb/cc/platform/status_matchers.h"
#include "reverb/cc/testing/tensor_testutil.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
Expand All @@ -28,13 +29,17 @@
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/framework/variant_encode_decode.h"

namespace deepmind {
namespace reverb {
namespace {

using ::testing::HasSubstr;
using ::testing::status::StatusIs;

MATCHER_P2(StatusIs, code, message, "") {
return arg.code() == code && absl::StrContains(arg.message(), message);
}

template <typename T>
void EncodeMatchesDecodeT() {
Expand Down Expand Up @@ -75,9 +80,9 @@ TEST(TensorCompressionTest, StringTensor) {
tensor.flat<tensorflow::tstring>()(1) = "world";

tensorflow::TensorProto proto;
ASSERT_OK(CompressTensorAsProto(tensor, &proto));
REVERB_ASSERT_OK(CompressTensorAsProto(tensor, &proto));

ASSERT_OK_AND_ASSIGN(tensorflow::Tensor result,
TF_ASSERT_OK_AND_ASSIGN(tensorflow::Tensor result,
DecompressTensorFromProto(proto));
test::ExpectTensorEqual<tensorflow::tstring>(tensor, result);
}
Expand All @@ -88,9 +93,9 @@ TEST(TensorCompressionTest, NonStringTensor) {
tensor.flat<int>().setRandom();

tensorflow::TensorProto proto;
ASSERT_OK(CompressTensorAsProto(tensor, &proto));
REVERB_ASSERT_OK(CompressTensorAsProto(tensor, &proto));

ASSERT_OK_AND_ASSIGN(tensorflow::Tensor result,
TF_ASSERT_OK_AND_ASSIGN(tensorflow::Tensor result,
DecompressTensorFromProto(proto));
test::ExpectTensorEqual<int>(tensor, result);
}
Expand All @@ -101,9 +106,9 @@ TEST(TensorCompressionTest, NonStringTensorWithDeltaEncoding) {
tensor.flat<int>().setRandom();

tensorflow::TensorProto proto;
ASSERT_OK(CompressTensorAsProto(DeltaEncode(tensor, true), &proto));
REVERB_ASSERT_OK(CompressTensorAsProto(DeltaEncode(tensor, true), &proto));

ASSERT_OK_AND_ASSIGN(tensorflow::Tensor result,
TF_ASSERT_OK_AND_ASSIGN(tensorflow::Tensor result,
DecompressTensorFromProto(proto));
test::ExpectTensorEqual<int>(tensor, DeltaEncode(result, false));
}
Expand All @@ -120,16 +125,16 @@ TEST(TensorCompressionTest, CompressingVariantNotSupported) {
tensorflow::TensorProto proto;
EXPECT_THAT(CompressTensorAsProto(DeltaEncode(tensor, true), &proto),
StatusIs(absl::StatusCode::kInvalidArgument,
HasSubstr("variant is not supported")));
"variant is not supported"));
}

TEST(TensorCompressionTest, DecompressingVariantNotSupported) {
tensorflow::TensorProto proto;
proto.set_dtype(tensorflow::DT_VARIANT);

EXPECT_THAT(DecompressTensorFromProto(proto),
EXPECT_THAT(DecompressTensorFromProto(proto).status(),
StatusIs(absl::StatusCode::kInvalidArgument,
HasSubstr("variant is not supported")));
"variant is not supported"));
}

} // namespace
Expand Down
1 change: 0 additions & 1 deletion reverb/cc/testing/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ reverb_cc_library(
"//reverb/cc:schema_cc_proto",
"//reverb/cc:tensor_compression",
"//reverb/cc/platform:logging",
"//third_party/absl/log:check",
] + reverb_tf_deps(),
)

Expand Down
3 changes: 1 addition & 2 deletions reverb/cc/writer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/optional.h"
#include "third_party/grpc/include/grpcpp/client_context.h"
#include "third_party/grpc/include/grpcpp/impl/call_op_set.h"
#include "grpcpp/impl/codegen/client_context.h"
#include "reverb/cc/platform/hash_set.h"
#include "reverb/cc/platform/logging.h"
#include "reverb/cc/platform/status_macros.h"
Expand Down