Skip to content

Commit

Permalink
Reverb: Rename PriorityTableExtension* to TableExtension* in the C++ …
Browse files Browse the repository at this point in the history
…code.

PiperOrigin-RevId: 312472894
Change-Id: Ie6a2ff488e4045eccdb7df8e997f2b6e564f5bd9
  • Loading branch information
acassirer authored and copybara-github committed May 20, 2020
1 parent adbd117 commit d9ceb4e
Show file tree
Hide file tree
Showing 8 changed files with 41 additions and 47 deletions.
2 changes: 1 addition & 1 deletion reverb/cc/reverb_service_impl_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ std::unique_ptr<ReverbServiceImpl> MakeService(
absl::make_unique<RateLimiter>(kSamplesPerInsert, kMinSizeToSample,
kMinDiff, kMaxDiff),
/*extensions=*/
std::vector<std::shared_ptr<PriorityTableExtensionInterface>>{},
std::vector<std::shared_ptr<TableExtensionInterface>>{},
/*signature=*/absl::make_optional(MakeSignature())));
return absl::make_unique<ReverbServiceImpl>(std::move(tables),
std::move(checkpointer));
Expand Down
13 changes: 6 additions & 7 deletions reverb/cc/table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,7 @@ namespace deepmind {
namespace reverb {
namespace {

using Extensions =
std::vector<std::shared_ptr<PriorityTableExtensionInterface>>;
using Extensions = std::vector<std::shared_ptr<TableExtensionInterface>>;

inline bool IsAdjacent(const SequenceRange& a, const SequenceRange& b) {
return a.episode_id() == b.episode_id() && a.end() + 1 == b.start();
Expand Down Expand Up @@ -262,7 +261,7 @@ void Table::DeleteItem(Table::Key key) {

tensorflow::Status Table::UpdateItem(
Key key, double priority,
std::initializer_list<PriorityTableExtensionInterface*> exclude) {
std::initializer_list<TableExtensionInterface*> exclude) {
auto it = data_.find(key);
if (it == data_.end()) {
return tensorflow::Status::OK();
Expand Down Expand Up @@ -363,15 +362,15 @@ const absl::flat_hash_map<Table::Key, Table::Item>* Table::RawLookup() {
}

void Table::UnsafeAddExtension(
std::shared_ptr<PriorityTableExtensionInterface> extension) {
std::shared_ptr<TableExtensionInterface> extension) {
TF_CHECK_OK(extension->RegisterTable(&mu_, this));
absl::WriterMutexLock lock(&mu_);
REVERB_CHECK(data_.empty());
extensions_.push_back(std::move(extension));
}

const std::vector<std::shared_ptr<PriorityTableExtensionInterface>>&
Table::extensions() const {
const std::vector<std::shared_ptr<TableExtensionInterface>>& Table::extensions()
const {
return extensions_;
}

Expand Down Expand Up @@ -403,7 +402,7 @@ int64_t Table::num_episodes() const {

tensorflow::Status Table::UnsafeUpdateItem(
Key key, double priority,
std::initializer_list<PriorityTableExtensionInterface*> exclude) {
std::initializer_list<TableExtensionInterface*> exclude) {
mu_.AssertHeld();
return UpdateItem(key, priority, std::move(exclude));
}
Expand Down
16 changes: 7 additions & 9 deletions reverb/cc/table.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,7 @@ class Table {
Table(std::string name, std::shared_ptr<KeyDistributionInterface> sampler,
std::shared_ptr<KeyDistributionInterface> remover, int64_t max_size,
int32_t max_times_sampled, std::shared_ptr<RateLimiter> rate_limiter,
std::vector<std::shared_ptr<PriorityTableExtensionInterface>>
extensions = {},
std::vector<std::shared_ptr<TableExtensionInterface>> extensions = {},
absl::optional<tensorflow::StructuredValue> signature = absl::nullopt);

~Table();
Expand Down Expand Up @@ -170,12 +169,11 @@ class Table {
//
// Note! This method is not thread safe and caller is responsible for making
// sure that this method, nor any other method, is called concurrently.
void UnsafeAddExtension(
std::shared_ptr<PriorityTableExtensionInterface> extension);
void UnsafeAddExtension(std::shared_ptr<TableExtensionInterface> extension);

// Registered table extensions.
const std::vector<std::shared_ptr<PriorityTableExtensionInterface>>&
extensions() const;
const std::vector<std::shared_ptr<TableExtensionInterface>>& extensions()
const;

// Lookup a single item. Returns true if found, else false.
bool Get(Key key, Item* item) ABSL_LOCKS_EXCLUDED(mu_);
Expand Down Expand Up @@ -217,15 +215,15 @@ class Table {
// Asserts that `mu_` is held at runtime and calls UpdateItem.
tensorflow::Status UnsafeUpdateItem(
Key key, double priority,
std::initializer_list<PriorityTableExtensionInterface*> exclude)
std::initializer_list<TableExtensionInterface*> exclude)
ABSL_ASSERT_EXCLUSIVE_LOCK(mu_);

private:
// Updates item priority in `data_`, `samper_`, `remover_` and calls
// `OnUpdate` on all extensions not part of `exclude`.
tensorflow::Status UpdateItem(
Key key, double priority,
std::initializer_list<PriorityTableExtensionInterface*> exclude = {})
std::initializer_list<TableExtensionInterface*> exclude = {})
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

// Deletes the item associated with the key from `data_`, `sampler_` and
Expand Down Expand Up @@ -263,7 +261,7 @@ class Table {

// Extensions implement hooks that are executed while holding `mu_` as part
// of insert, delete, update or reset operations.
std::vector<std::shared_ptr<PriorityTableExtensionInterface>> extensions_
std::vector<std::shared_ptr<TableExtensionInterface>> extensions_
ABSL_GUARDED_BY(mu_);

// Synchronizes access to `sampler_`, `remover_`, 'rate_limiter_`,
Expand Down
35 changes: 17 additions & 18 deletions reverb/cc/table_extensions/base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
namespace deepmind {
namespace reverb {

tensorflow::Status PriorityTableExtensionBase::RegisterTable(absl::Mutex* mu,
Table* table) {
tensorflow::Status TableExtensionBase::RegisterTable(absl::Mutex* mu,
Table* table) {
absl::WriterMutexLock lock(&table_mu_);
if (table_) {
return tensorflow::errors::FailedPrecondition(
Expand All @@ -35,45 +35,44 @@ tensorflow::Status PriorityTableExtensionBase::RegisterTable(absl::Mutex* mu,
return tensorflow::Status::OK();
}

void PriorityTableExtensionBase::UnregisterTable(absl::Mutex* mu,
Table* table) {
void TableExtensionBase::UnregisterTable(absl::Mutex* mu, Table* table) {
absl::WriterMutexLock lock(&table_mu_);
REVERB_CHECK_EQ(table, table_)
<< "The wrong Table attempted to unregister this extension.";
table_ = nullptr;
}

void PriorityTableExtensionBase::OnDelete(absl::Mutex* mu,
const PriorityTableItem& item) {
void TableExtensionBase::OnDelete(absl::Mutex* mu,
const PriorityTableItem& item) {
ApplyOnDelete(item);
}

void PriorityTableExtensionBase::OnInsert(absl::Mutex* mu,
const PriorityTableItem& item) {
void TableExtensionBase::OnInsert(absl::Mutex* mu,
const PriorityTableItem& item) {
ApplyOnInsert(item);
}

void PriorityTableExtensionBase::OnReset(absl::Mutex* mu) { ApplyOnReset(); }
void TableExtensionBase::OnReset(absl::Mutex* mu) { ApplyOnReset(); }

void PriorityTableExtensionBase::OnUpdate(absl::Mutex* mu,
const PriorityTableItem& item) {
void TableExtensionBase::OnUpdate(absl::Mutex* mu,
const PriorityTableItem& item) {
ApplyOnUpdate(item);
}

void PriorityTableExtensionBase::OnSample(absl::Mutex* mu,
const PriorityTableItem& item) {
void TableExtensionBase::OnSample(absl::Mutex* mu,
const PriorityTableItem& item) {
ApplyOnSample(item);
}

void PriorityTableExtensionBase::ApplyOnDelete(const PriorityTableItem& item) {}
void TableExtensionBase::ApplyOnDelete(const PriorityTableItem& item) {}

void PriorityTableExtensionBase::ApplyOnInsert(const PriorityTableItem& item) {}
void TableExtensionBase::ApplyOnInsert(const PriorityTableItem& item) {}

void PriorityTableExtensionBase::ApplyOnReset() {}
void TableExtensionBase::ApplyOnReset() {}

void PriorityTableExtensionBase::ApplyOnUpdate(const PriorityTableItem& item) {}
void TableExtensionBase::ApplyOnUpdate(const PriorityTableItem& item) {}

void PriorityTableExtensionBase::ApplyOnSample(const PriorityTableItem& item) {}
void TableExtensionBase::ApplyOnSample(const PriorityTableItem& item) {}

} // namespace reverb
} // namespace deepmind
6 changes: 3 additions & 3 deletions reverb/cc/table_extensions/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,16 @@
namespace deepmind {
namespace reverb {

// Base implementation for PriorityTableExtensionInterface.
// Base implementation for TableExtensionInterface.
//
// This class implements table registration and all mutex protected On*-methods
// by delegating it to a "simpler" ApplyOn method. Children are thus able to
// implement any subset of the ApplyOn (and avoid the overly verbose API)
// without losing the safety provided by the static analysis of the mutexes.
//
class PriorityTableExtensionBase : public PriorityTableExtensionInterface {
class TableExtensionBase : public TableExtensionInterface {
public:
virtual ~PriorityTableExtensionBase() = default;
virtual ~TableExtensionBase() = default;

// Children should override these (noop by default).
virtual void ApplyOnDelete(const PriorityTableItem& item);
Expand Down
6 changes: 3 additions & 3 deletions reverb/cc/table_extensions/interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ namespace reverb {

class Table;

// A `PriorityTableExtension` is passed to a single `Table` and executed
// A `TableExtension` is passed to a single `Table` and executed
// as part of the atomic operations of the parent table. All "hooks" are
// executed while parent is holding its mutex and thus latency is very
// important.
class PriorityTableExtensionInterface {
class TableExtensionInterface {
public:
virtual ~PriorityTableExtensionInterface() = default;
virtual ~TableExtensionInterface() = default;

protected:
friend class Table;
Expand Down
8 changes: 3 additions & 5 deletions reverb/pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -488,10 +488,8 @@ PYBIND11_MODULE(libpybind, m) {
std::shared_ptr<HeapDistribution>>(m, "HeapDistribution")
.def(py::init<bool>(), py::arg("min_heap"));

py::class_<PriorityTableExtensionInterface,
std::shared_ptr<PriorityTableExtensionInterface>>
unused_priority_table_extension_interface(
m, "PriorityTableExtensionInterface");
py::class_<TableExtensionInterface, std::shared_ptr<TableExtensionInterface>>
unused_priority_table_extension_interface(m, "TableExtensionInterface");

py::class_<RateLimiter, std::shared_ptr<RateLimiter>>(m, "RateLimiter")
.def(py::init<double, int, double, double>(),
Expand All @@ -505,7 +503,7 @@ PYBIND11_MODULE(libpybind, m) {
int max_size, int max_times_sampled,
const std::shared_ptr<RateLimiter> &rate_limiter,
const std::vector<std::shared_ptr<
PriorityTableExtensionInterface>> &extensions,
TableExtensionInterface>> &extensions,
const absl::optional<std::string> &serialized_signature =
absl::nullopt) -> Table * {
absl::optional<tensorflow::StructuredValue> signature =
Expand Down
2 changes: 1 addition & 1 deletion reverb/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class PriorityTableExtensionBase(metaclass=abc.ABCMeta):

@abc.abstractmethod
def build_internal_extensions(
self, table_name: str) -> List[pybind.PriorityTableExtensionInterface]:
self, table_name: str) -> List[pybind.TableExtensionInterface]:
"""Constructs the c++ PriorityTableExtensions."""


Expand Down

0 comments on commit d9ceb4e

Please sign in to comment.