diff --git a/reverb/cc/BUILD b/reverb/cc/BUILD index ef8b5b3..648fafc 100644 --- a/reverb/cc/BUILD +++ b/reverb/cc/BUILD @@ -142,7 +142,6 @@ reverb_cc_library( visibility = ["//reverb:__subpackages__"], deps = [ ":chunk_store", - ":priority_table_item", ":schema_cc_proto", "//reverb/cc/checkpointing:checkpoint_cc_proto", "//reverb/cc/distributions:interface", @@ -230,15 +229,6 @@ reverb_cc_library( ] + reverb_tf_deps() + reverb_grpc_deps() + reverb_absl_deps(), ) -reverb_cc_library( - name = "priority_table_item", - hdrs = ["priority_table_item.h"], - deps = [ - ":chunk_store", - ":schema_cc_proto", - ], -) - reverb_cc_proto_library( name = "schema_cc_proto", srcs = ["schema.proto"], diff --git a/reverb/cc/priority_table_item.h b/reverb/cc/priority_table_item.h deleted file mode 100644 index d1f93b3..0000000 --- a/reverb/cc/priority_table_item.h +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright 2019 DeepMind Technologies Limited. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef REVERB_CC_PRIORITY_TABLE_ITEM_H_ -#define REVERB_CC_PRIORITY_TABLE_ITEM_H_ - -#include -#include - -#include "reverb/cc/chunk_store.h" -#include "reverb/cc/schema.pb.h" - -namespace deepmind { -namespace reverb { - -// Used for representing items of the priority distribution. See -// PrioritizedItem in schema.proto for documentation. -struct PriorityTableItem { - PrioritizedItem item; - std::vector> chunks; -}; - -} // namespace reverb -} // namespace deepmind - -#endif // REVERB_CC_PRIORITY_TABLE_ITEM_H_ diff --git a/reverb/cc/table.h b/reverb/cc/table.h index d431c11..39e984a 100644 --- a/reverb/cc/table.h +++ b/reverb/cc/table.h @@ -32,7 +32,6 @@ #include "reverb/cc/checkpointing/checkpoint.pb.h" #include "reverb/cc/chunk_store.h" #include "reverb/cc/distributions/interface.h" -#include "reverb/cc/priority_table_item.h" #include "reverb/cc/rate_limiter.h" #include "reverb/cc/schema.pb.h" #include "reverb/cc/table_extensions/interface.h" @@ -42,6 +41,13 @@ namespace deepmind { namespace reverb { +// Used for representing items of the priority distribution. See +// PrioritizedItem in schema.proto for documentation. +struct TableItem { + PrioritizedItem item; + std::vector> chunks; +}; + // A Table is a structure for storing `PriorityItem` objects. The Table uses two // instances of KeyDistributionInterface, one for sampling (sampler) and another // for removing (remover). PriorityItems are registered with both the sampler @@ -69,7 +75,7 @@ namespace reverb { class Table { public: using Key = KeyDistributionInterface::Key; - using Item = PriorityTableItem; + using Item = TableItem; // Used as the return of Sample(). Note that this returns the probability of // an item instead as opposed to the raw priority value. diff --git a/reverb/cc/table_extensions/BUILD b/reverb/cc/table_extensions/BUILD index b88c2b8..007f4c0 100644 --- a/reverb/cc/table_extensions/BUILD +++ b/reverb/cc/table_extensions/BUILD @@ -13,9 +13,8 @@ reverb_cc_library( name = "interface", hdrs = ["interface.h"], deps = [ - "//reverb/cc:priority_table_item", "//reverb/cc:schema_cc_proto", - ] + reverb_absl_deps(), + ] + reverb_absl_deps() + reverb_tf_deps(), ) reverb_cc_library( @@ -25,7 +24,6 @@ reverb_cc_library( deps = [ ":interface", "//reverb/cc:table", - "//reverb/cc:priority_table_item", "//reverb/cc:schema_cc_proto", "//reverb/cc/platform:logging", ] + reverb_absl_deps() + reverb_tf_deps(), diff --git a/reverb/cc/table_extensions/base.cc b/reverb/cc/table_extensions/base.cc index 16e7ff5..7f4752f 100644 --- a/reverb/cc/table_extensions/base.cc +++ b/reverb/cc/table_extensions/base.cc @@ -15,7 +15,6 @@ #include "reverb/cc/table_extensions/base.h" #include "reverb/cc/platform/logging.h" -#include "reverb/cc/priority_table_item.h" #include "reverb/cc/table.h" #include "tensorflow/core/platform/errors.h" @@ -42,37 +41,33 @@ void TableExtensionBase::UnregisterTable(absl::Mutex* mu, Table* table) { table_ = nullptr; } -void TableExtensionBase::OnDelete(absl::Mutex* mu, - const PriorityTableItem& item) { +void TableExtensionBase::OnDelete(absl::Mutex* mu, const TableItem& item) { ApplyOnDelete(item); } -void TableExtensionBase::OnInsert(absl::Mutex* mu, - const PriorityTableItem& item) { +void TableExtensionBase::OnInsert(absl::Mutex* mu, const TableItem& item) { ApplyOnInsert(item); } void TableExtensionBase::OnReset(absl::Mutex* mu) { ApplyOnReset(); } -void TableExtensionBase::OnUpdate(absl::Mutex* mu, - const PriorityTableItem& item) { +void TableExtensionBase::OnUpdate(absl::Mutex* mu, const TableItem& item) { ApplyOnUpdate(item); } -void TableExtensionBase::OnSample(absl::Mutex* mu, - const PriorityTableItem& item) { +void TableExtensionBase::OnSample(absl::Mutex* mu, const TableItem& item) { ApplyOnSample(item); } -void TableExtensionBase::ApplyOnDelete(const PriorityTableItem& item) {} +void TableExtensionBase::ApplyOnDelete(const TableItem& item) {} -void TableExtensionBase::ApplyOnInsert(const PriorityTableItem& item) {} +void TableExtensionBase::ApplyOnInsert(const TableItem& item) {} void TableExtensionBase::ApplyOnReset() {} -void TableExtensionBase::ApplyOnUpdate(const PriorityTableItem& item) {} +void TableExtensionBase::ApplyOnUpdate(const TableItem& item) {} -void TableExtensionBase::ApplyOnSample(const PriorityTableItem& item) {} +void TableExtensionBase::ApplyOnSample(const TableItem& item) {} } // namespace reverb } // namespace deepmind diff --git a/reverb/cc/table_extensions/base.h b/reverb/cc/table_extensions/base.h index 5fa4fe6..c5d4281 100644 --- a/reverb/cc/table_extensions/base.h +++ b/reverb/cc/table_extensions/base.h @@ -17,7 +17,6 @@ #include "absl/base/thread_annotations.h" #include "absl/synchronization/mutex.h" -#include "reverb/cc/priority_table_item.h" #include "reverb/cc/table.h" #include "reverb/cc/table_extensions/interface.h" @@ -36,11 +35,11 @@ class TableExtensionBase : public TableExtensionInterface { virtual ~TableExtensionBase() = default; // Children should override these (noop by default). - virtual void ApplyOnDelete(const PriorityTableItem& item); - virtual void ApplyOnInsert(const PriorityTableItem& item); + virtual void ApplyOnDelete(const TableItem& item); + virtual void ApplyOnInsert(const TableItem& item); virtual void ApplyOnReset(); - virtual void ApplyOnUpdate(const PriorityTableItem& item); - virtual void ApplyOnSample(const PriorityTableItem& item); + virtual void ApplyOnUpdate(const TableItem& item); + virtual void ApplyOnSample(const TableItem& item); protected: friend class Table; @@ -52,22 +51,22 @@ class TableExtensionBase : public TableExtensionInterface { ABSL_LOCKS_EXCLUDED(mu) override; // Delegates call to ApplyOnDelete. - void OnDelete(absl::Mutex* mu, const PriorityTableItem& item) override + void OnDelete(absl::Mutex* mu, const TableItem& item) override ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu); // Delegates call to ApplyOnInsert. - void OnInsert(absl::Mutex* mu, const PriorityTableItem& item) override + void OnInsert(absl::Mutex* mu, const TableItem& item) override ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu); // Delegates call to ApplyOnReset. void OnReset(absl::Mutex* mu) override ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu); // Delegates call to ApplyOnUpdate. - void OnUpdate(absl::Mutex* mu, const PriorityTableItem& item) override + void OnUpdate(absl::Mutex* mu, const TableItem& item) override ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu); // Delegates call to ApplyOnSample. - void OnSample(absl::Mutex* mu, const PriorityTableItem& item) override + void OnSample(absl::Mutex* mu, const TableItem& item) override ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu); protected: diff --git a/reverb/cc/table_extensions/interface.h b/reverb/cc/table_extensions/interface.h index 7172aa8..717bafa 100644 --- a/reverb/cc/table_extensions/interface.h +++ b/reverb/cc/table_extensions/interface.h @@ -20,13 +20,14 @@ #include #include "absl/base/thread_annotations.h" #include "absl/synchronization/mutex.h" -#include "reverb/cc/priority_table_item.h" #include "reverb/cc/schema.pb.h" +#include "tensorflow/core/platform/status.h" namespace deepmind { namespace reverb { class Table; +class TableItem; // A `TableExtension` is passed to a single `Table` and executed // as part of the atomic operations of the parent table. All "hooks" are @@ -40,21 +41,21 @@ class TableExtensionInterface { friend class Table; // Executed just after item is inserted into parent `Table`. - virtual void OnInsert(absl::Mutex* mu, const PriorityTableItem& item) + virtual void OnInsert(absl::Mutex* mu, const TableItem& item) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu) = 0; // Executed just before item is removed from parent `Table`. - virtual void OnDelete(absl::Mutex* mu, const PriorityTableItem& item) + virtual void OnDelete(absl::Mutex* mu, const TableItem& item) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu) = 0; // Executed just after the priority of an item has been updated in parent // `Table`. - virtual void OnUpdate(absl::Mutex* mu, const PriorityTableItem& item) + virtual void OnUpdate(absl::Mutex* mu, const TableItem& item) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu) = 0; // Executed just before a sample is returned. The sample count of the item // includes the active sample and thus always is >= 1. - virtual void OnSample(absl::Mutex* mu, const PriorityTableItem& item) + virtual void OnSample(absl::Mutex* mu, const TableItem& item) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu) = 0; // Executed just before all items are deleted. diff --git a/reverb/cc/table_test.cc b/reverb/cc/table_test.cc index 279c36c..47f0f7e 100644 --- a/reverb/cc/table_test.cc +++ b/reverb/cc/table_test.cc @@ -51,9 +51,9 @@ using ::testing::SizeIs; MATCHER_P(HasItemKey, key, "") { return arg.item.key() == key; } -PriorityTableItem MakeItem(uint64_t key, double priority, - const std::vector& sequences) { - PriorityTableItem item; +TableItem MakeItem(uint64_t key, double priority, + const std::vector& sequences) { + TableItem item; std::vector data(sequences.size()); for (int i = 0; i < sequences.size(); i++) { @@ -66,7 +66,7 @@ PriorityTableItem MakeItem(uint64_t key, double priority, return item; } -PriorityTableItem MakeItem(uint64_t key, double priority) { +TableItem MakeItem(uint64_t key, double priority) { return MakeItem(key, priority, {testing::MakeSequenceRange(key * 100, 0, 1)}); } @@ -567,7 +567,7 @@ TEST(TableTest, GetExistingItem) { TF_EXPECT_OK(table->InsertOrAssign(MakeItem(2, 1))); TF_EXPECT_OK(table->InsertOrAssign(MakeItem(3, 1))); - PriorityTableItem item; + TableItem item; EXPECT_TRUE(table->Get(2, &item)); EXPECT_THAT(item, HasItemKey(2)); } @@ -578,7 +578,7 @@ TEST(TableTest, GetMissingItem) { TF_EXPECT_OK(table->InsertOrAssign(MakeItem(1, 1))); TF_EXPECT_OK(table->InsertOrAssign(MakeItem(3, 1))); - PriorityTableItem item; + TableItem item; EXPECT_FALSE(table->Get(2, &item)); }