diff --git a/.bazelrc b/.bazelrc
new file mode 100644
index 0000000..587a0f9
--- /dev/null
+++ b/.bazelrc
@@ -0,0 +1,23 @@
+# gRPC using libcares in opensource has some issues.
+build --define=grpc_no_ares=true
+
+# Suppress all warning messages.
+build:short_logs --output_filter=DONT_MATCH_ANYTHING
+
+# Force python3
+build --action_env=PYTHON_BIN_PATH=python3
+build --repo_env=PYTHON_BIN_PATH=python3
+build --python_path=python3
+
+build:manylinux2010 --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:toolchain
+
+build -c opt
+build --cxxopt="-std=c++14"
+build --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0"
+build --auto_output_filter=subpackages
+build --copt="-Wall" --copt="-Wno-sign-compare"
+build --linkopt="-lrt -lm"
+
+# TF isn't built in dbg mode, so our dbg builds will segfault due to inconsistency
+# of defines when using tf's headers. In particular in refcount.h.
+build --cxxopt="-DNDEBUG"
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 0000000..db177d4
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,28 @@
+# How to Contribute
+
+We'd love to accept your patches and contributions to this project. There are
+just a few small guidelines you need to follow.
+
+## Contributor License Agreement
+
+Contributions to this project must be accompanied by a Contributor License
+Agreement. You (or your employer) retain the copyright to your contribution;
+this simply gives us permission to use and redistribute your contributions as
+part of the project. Head over to <https://cla.developers.google.com/> to see
+your current agreements on file or to sign a new one.
+
+You generally only need to submit a CLA once, so if you've already submitted one
+(even if it was for a different project), you probably don't need to do it
+again.
+
+## Code reviews
+
+All submissions, including submissions by project members, require review. We
+use GitHub pull requests for this purpose. Consult
+[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
+information on using pull requests.
+
+## Community Guidelines
+
+This project follows
+[Google's Open Source Community Guidelines](https://opensource.google.com/conduct/).
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..8018c87
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,202 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2018, DeepMind Technologies Limited.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..679e11b
--- /dev/null
+++ b/README.md
@@ -0,0 +1,5 @@
+# Reverb
+
+Reverb is a service for data transport (and storage) that is used for machine
+learning research. One particularly common use is as a prioritized experience
+replay system in reinforcement learning algorithms.
diff --git a/WORKSPACE b/WORKSPACE
new file mode 100644
index 0000000..ceb6840
--- /dev/null
+++ b/WORKSPACE
@@ -0,0 +1,86 @@
+workspace(name = "reverb")
+
+# To change to a version of protoc compatible with tensorflow:
+# 1. Convert the required header version to a version string, e.g.:
+# 3011004 => "3.11.4"
+# 2. Calculate the sha256 of the binary:
+# PROTOC_VERSION="3.11.4"
+#   curl -L "https://github.com/protocolbuffers/protobuf/releases/download/v${PROTOC_VERSION}/protoc-${PROTOC_VERSION}-linux-x86_64.zip" | sha256sum
+# 3. Update the two variables below.
+#
+# Alternatively, run bazel with the environment var "REVERB_PROTOC_VERSION"
+# set to override PROTOC_VERSION.
+#
+# *WARNING* If using the REVERB_PROTOC_VERSION environment variable, sha256
+# checking is disabled. Use at your own risk.
+PROTOC_VERSION = "3.9.0"
+PROTOC_SHA256 = "15e395b648a1a6dda8fd66868824a396e9d3e89bc2c8648e3b9ab9801bea5d55"
+
+load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
+
+http_archive(
+ name = "com_google_googletest",
+ sha256 = "ff7a82736e158c077e76188232eac77913a15dac0b22508c390ab3f88e6d6d86",
+ strip_prefix = "googletest-b6cd405286ed8635ece71c72f118e659f4ade3fb",
+ urls = [
+ "https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/googletest/archive/b6cd405286ed8635ece71c72f118e659f4ade3fb.zip",
+ "https://github.com/google/googletest/archive/b6cd405286ed8635ece71c72f118e659f4ade3fb.zip",
+ ],
+)
+
+http_archive(
+ name = "com_google_absl",
+ sha256 = "f368a8476f4e2e0eccf8a7318b98dafbe30b2600f4e3cf52636e5eb145aba06a", # SHARED_ABSL_SHA
+ strip_prefix = "abseil-cpp-df3ea785d8c30a9503321a3d35ee7d35808f190d",
+ urls = [
+ "https://storage.googleapis.com/mirror.tensorflow.org/github.com/abseil/abseil-cpp/archive/df3ea785d8c30a9503321a3d35ee7d35808f190d.tar.gz",
+ "https://github.com/abseil/abseil-cpp/archive/df3ea785d8c30a9503321a3d35ee7d35808f190d.tar.gz",
+ ],
+)
+
+## Begin GRPC related deps
+http_archive(
+ name = "com_github_grpc_grpc",
+ patch_cmds = [
+ """sed -i.bak 's/"python",/"python3",/g' third_party/py/python_configure.bzl""",
+ """sed -i.bak 's/PYTHONHASHSEED=0/PYTHONHASHSEED=0 python3/g' bazel/cython_library.bzl""",
+ ],
+ sha256 = "b956598d8cbe168b5ee717b5dafa56563eb5201a947856a6688bbeac9cac4e1f",
+ strip_prefix = "grpc-b54a5b338637f92bfcf4b0bc05e0f57a5fd8fadd",
+ urls = [
+ "https://storage.googleapis.com/mirror.tensorflow.org/github.com/grpc/grpc/archive/b54a5b338637f92bfcf4b0bc05e0f57a5fd8fadd.tar.gz",
+ "https://github.com/grpc/grpc/archive/b54a5b338637f92bfcf4b0bc05e0f57a5fd8fadd.tar.gz",
+ ],
+)
+
+load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps")
+
+grpc_deps()
+
+load("@upb//bazel:repository_defs.bzl", "bazel_version_repository")
+
+bazel_version_repository(
+ name = "bazel_version",
+)
+
+load("@build_bazel_rules_apple//apple:repositories.bzl", "apple_rules_dependencies")
+
+apple_rules_dependencies()
+
+load("@build_bazel_apple_support//lib:repositories.bzl", "apple_support_dependencies")
+
+apple_support_dependencies()
+## End GRPC related deps
+
+load(
+ "//reverb/cc/platform/default:repo.bzl",
+ "cc_tf_configure",
+ "reverb_protoc_deps",
+ "reverb_python_deps",
+)
+
+cc_tf_configure()
+
+reverb_python_deps()
+
+reverb_protoc_deps(version = PROTOC_VERSION, sha256 = PROTOC_SHA256)
diff --git a/docker/dev.dockerfile b/docker/dev.dockerfile
new file mode 100644
index 0000000..7a90ef9
--- /dev/null
+++ b/docker/dev.dockerfile
@@ -0,0 +1,97 @@
+# Run the following commands in order:
+#
+# REVERB_DIR="/tmp/reverb" # (change to the cloned reverb directory, e.g. "$HOME/reverb")
+# docker build --tag tensorflow:reverb - < "$REVERB_DIR/docker/dev.dockerfile"
+# docker run --rm -it -v ${REVERB_DIR}:/tmp/reverb \
+# -v ${HOME}/.gitconfig:/home/${USER}/.gitconfig:ro \
+# --name reverb tensorflow:reverb bash
+#
+# Test that everything worked:
+#
+# bazel test -c opt --test_output=streamed //reverb:tf_client_test
+
+ARG cpu_base_image="ubuntu:18.04"
+ARG base_image=$cpu_base_image
+FROM $base_image
+
+LABEL maintainer="Reverb Team <no-reply@google.com>"
+
+# Re-declare args because the args declared before FROM can't be used in any
+# instruction after a FROM.
+ARG cpu_base_image="ubuntu:18.04"
+ARG base_image=$cpu_base_image
+
+# Pick up some TF dependencies
+RUN apt-get update && apt-get install -y --no-install-recommends software-properties-common
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ aria2 \
+ build-essential \
+ curl \
+ git \
+ less \
+ libfreetype6-dev \
+ libhdf5-serial-dev \
+ libpng-dev \
+ libzmq3-dev \
+ lsof \
+ pkg-config \
+ python3-distutils \
+ python3-dev \
+ python3.6-dev \
+ rename \
+ rsync \
+ sox \
+ unzip \
+ vim \
+ && \
+ apt-get clean && \
+ rm -rf /var/lib/apt/lists/*
+
+RUN curl -O https://bootstrap.pypa.io/get-pip.py && python3 get-pip.py && rm get-pip.py
+
+ARG bazel_version=2.2.0
+# This is to install bazel, for development purposes.
+ENV BAZEL_VERSION ${bazel_version}
+RUN mkdir /bazel && \
+ cd /bazel && \
+ curl -H "User-Agent: Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/57.0.2987.133 Safari/537.36" -fSsL -O https://github.com/bazelbuild/bazel/releases/download/$BAZEL_VERSION/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \
+ curl -H "User-Agent: Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/57.0.2987.133 Safari/537.36" -fSsL -o /bazel/LICENSE.txt https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE && \
+ chmod +x bazel-*.sh && \
+ ./bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \
+ cd / && \
+ rm -f /bazel/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh
+
+# Add support for Bazel autocomplete, see
+# https://docs.bazel.build/versions/master/completion.html for instructions.
+RUN cp /usr/local/lib/bazel/bin/bazel-complete.bash /etc/bash_completion.d
+# TODO(b/154932078): This line should go into bashrc.
+# NOTE(ebrevdo): RUN source doesn't work. Disabling the command below for now.
+# RUN source /etc/bash_completion.d/bazel-complete.bash
+
+
+ARG pip_dependencies=' \
+ contextlib2 \
+ dm-tree \
+ google-api-python-client \
+ h5py \
+ numpy \
+ oauth2client \
+ pandas \
+ portpicker'
+
+RUN pip3 --no-cache-dir install $pip_dependencies
+
+# The latest tensorflow requires CUDA 10 compatible nvidia drivers (410.xx).
+# If you are unable to update your drivers, an alternative is to compile
+# tensorflow from source instead of installing from pip.
+# Ensure we install the correct version by uninstalling first.
+RUN pip3 uninstall -y tensorflow tensorflow-gpu tf-nightly tf-nightly-gpu
+
+RUN pip3 --no-cache-dir install tf-nightly --upgrade
+
+# bazel assumes the python executable is "python".
+RUN ln -s /usr/bin/python3 /usr/bin/python
+
+WORKDIR "/tmp/reverb"
+
+CMD ["/bin/bash"]
diff --git a/docker/release.dockerfile b/docker/release.dockerfile
new file mode 100644
index 0000000..e50a88c
--- /dev/null
+++ b/docker/release.dockerfile
@@ -0,0 +1,74 @@
+# Run the following commands in order:
+#
+# REVERB_DIR="/tmp/reverb" # (change to the cloned reverb directory, e.g. "$HOME/reverb")
+# docker build --tag tensorflow:reverb_release - < "$REVERB_DIR/docker/release.dockerfile"
+# docker run --rm -it -v ${REVERB_DIR}:/tmp/reverb \
+# -v ${HOME}/.gitconfig:/home/${USER}/.gitconfig:ro \
+# --name reverb_release tensorflow:reverb_release bash
+#
+# Test that everything worked:
+#
+# bazel test -c opt --copt=-mavx --config=manylinux2010 --test_output=errors //reverb/...
+
+ARG cpu_base_image="tensorflow/tensorflow:2.1.0-custom-op-ubuntu16"
+ARG base_image=$cpu_base_image
+FROM $base_image
+
+LABEL maintainer="Reverb Team <no-reply@google.com>"
+
+# Re-declare args because the args declared before FROM can't be used in any
+# instruction after a FROM.
+ARG cpu_base_image="tensorflow/tensorflow:2.1.0-custom-op-ubuntu16"
+ARG base_image=$cpu_base_image
+
+# Pick up some TF dependencies
+RUN apt-get update && apt-get install -y --no-install-recommends software-properties-common
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ aria2 \
+ build-essential \
+ curl \
+ git \
+ less \
+ libfreetype6-dev \
+ libhdf5-serial-dev \
+ libpng-dev \
+ libzmq3-dev \
+ lsof \
+ pkg-config \
+ python3-dev \
+ python3.6-dev \
+ python3.7-dev \
+ python3.8-dev \
+ rename \
+ rsync \
+ sox \
+ unzip \
+ vim \
+ && \
+ apt-get clean && \
+ rm -rf /var/lib/apt/lists/*
+
+RUN curl -O https://bootstrap.pypa.io/get-pip.py && python3 get-pip.py && rm get-pip.py
+
+ARG pip_dependencies=' \
+ absl-py \
+ contextlib2 \
+ dm-tree \
+ google-api-python-client \
+ h5py \
+ numpy \
+ oauth2client \
+ pandas \
+ portpicker'
+
+# TODO(b/154930404): Update to 2.2.0 once it's out. May need to
+# cut a branch to make changes that allow us to build against 2.2.0 instead
+# of tf-nightly due to API changes.
+RUN pip3 uninstall -y tensorflow tensorflow-gpu tf-nightly tf-nightly-gpu
+RUN pip3 --no-cache-dir install tf-nightly --upgrade
+
+RUN pip3 --no-cache-dir install $pip_dependencies
+
+WORKDIR "/tmp/reverb"
+
+CMD ["/bin/bash"]
diff --git a/requirements-no-deps.txt b/requirements-no-deps.txt
new file mode 100644
index 0000000..c24d6dc
--- /dev/null
+++ b/requirements-no-deps.txt
@@ -0,0 +1,5 @@
+# These packages are required by TensorFlow, but they need to be installed
+# without their dependencies.
+# To install: python -m pip install --no-deps -r requirements-no-deps.txt.
+keras_applications
+keras_preprocessing
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..fc3b214
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,7 @@
+dm-tree>=0.1.1
+numpy
+wheel
+setuptools
+mock
+future>=0.17.1
+portpicker
diff --git a/reverb/BUILD b/reverb/BUILD
new file mode 100644
index 0000000..33953d5
--- /dev/null
+++ b/reverb/BUILD
@@ -0,0 +1,216 @@
+# Description: Reverb is an efficient and easy to use prioritized replay system designed for ML research.
+
+load(
+ "//reverb/cc/platform:build_rules.bzl",
+ "reverb_absl_deps",
+ "reverb_py_test",
+ "reverb_pybind_deps",
+ "reverb_pybind_extension",
+ "reverb_pytype_library",
+ "reverb_pytype_strict_library",
+)
+
+package(default_visibility = [":__subpackages__"])
+
+licenses(["notice"])
+
+exports_files(["LICENSE"])
+
+reverb_pytype_strict_library(
+ name = "reverb",
+ srcs = ["__init__.py"],
+ srcs_version = "PY3",
+ deps = [
+ ":client",
+ ":distributions",
+ ":errors",
+ ":rate_limiters",
+ ":replay_sample",
+ ":server",
+ ":tf_client",
+ ],
+)
+
+reverb_pytype_strict_library(
+ name = "distributions",
+ srcs = ["distributions.py"],
+ srcs_version = "PY3",
+ deps = [":pybind"],
+)
+
+reverb_pytype_strict_library(
+ name = "rate_limiters",
+ srcs = ["rate_limiters.py"],
+ srcs_version = "PY3",
+ deps = [":pybind"],
+)
+
+reverb_pytype_library(
+ name = "client",
+ srcs = ["client.py"],
+ srcs_version = "PY3",
+ strict_deps = True,
+ visibility = ["//visibility:public"],
+ deps = [
+ ":errors",
+ ":pybind",
+ ":replay_sample",
+ ":reverb_types",
+ "//reverb/cc:schema_py_pb2",
+ ],
+)
+
+reverb_pytype_library(
+ name = "errors",
+ srcs = ["errors.py"],
+ srcs_version = "PY3",
+ strict_deps = True,
+ visibility = ["//visibility:public"],
+ deps = [],
+)
+
+reverb_pytype_library(
+ name = "server",
+ srcs = ["server.py"],
+ srcs_version = "PY3",
+ strict_deps = True,
+ visibility = ["//visibility:public"],
+ deps = [
+ ":checkpointer",
+ ":client",
+ ":distributions",
+ ":pybind",
+ ":rate_limiters",
+ ":reverb_types",
+ ],
+)
+
+reverb_pytype_library(
+ name = "replay_sample",
+ srcs = ["replay_sample.py"],
+ srcs_version = "PY3",
+ strict_deps = True,
+)
+
+reverb_pybind_extension(
+ name = "pybind",
+ srcs = ["pybind.cc"],
+ module_name = "libpybind",
+ srcs_version = "PY3ONLY",
+ visibility = ["//reverb:__subpackages__"],
+ deps = [
+ "//reverb/cc:priority_table",
+ "//reverb/cc:replay_client",
+ "//reverb/cc:replay_writer",
+ "//reverb/cc:replay_sampler",
+ "//reverb/cc:reverb_server",
+ "//reverb/cc/checkpointing:interface",
+ "//reverb/cc/distributions:fifo",
+ "//reverb/cc/distributions:heap",
+ "//reverb/cc/distributions:interface",
+ "//reverb/cc/distributions:lifo",
+ "//reverb/cc/distributions:prioritized",
+ "//reverb/cc/distributions:uniform",
+ "//reverb/cc/platform:checkpointing",
+ "//reverb/cc/table_extensions:interface",
+ ] + reverb_pybind_deps() + reverb_absl_deps(),
+)
+
+reverb_pytype_library(
+ name = "tf_client",
+ srcs = ["tf_client.py"],
+ srcs_version = "PY3",
+ strict_deps = True,
+ visibility = ["//visibility:public"],
+ deps = [
+ ":replay_sample",
+ "//reverb/cc/ops:gen_client_ops",
+ "//reverb/cc/ops:gen_dataset_op",
+ ],
+)
+
+reverb_pytype_library(
+ name = "checkpointer",
+ srcs = ["checkpointer.py"],
+ srcs_version = "PY3",
+ strict_deps = True,
+ visibility = ["//visibility:public"],
+ deps = [":pybind"],
+)
+
+reverb_pytype_library(
+ name = "reverb_types",
+ srcs = ["reverb_types.py"],
+ srcs_version = "PY3",
+ strict_deps = True,
+ visibility = ["//visibility:public"],
+ deps = [
+ ":pybind",
+ "//reverb/cc:schema_py_pb2",
+ ],
+)
+
+reverb_py_test(
+ name = "checkpointer_test",
+ srcs = ["checkpointer_test.py"],
+ python_version = "PY3",
+ deps = [
+ ":checkpointer",
+ ":pybind",
+ ],
+)
+
+reverb_py_test(
+ name = "client_test",
+ srcs = ["client_test.py"],
+ python_version = "PY3",
+ deps = [
+ ":client",
+ ":distributions",
+ ":errors",
+ ":rate_limiters",
+ ":server",
+ ],
+)
+
+reverb_py_test(
+ name = "server_test",
+ srcs = ["server_test.py"],
+ python_version = "PY3",
+ deps = [
+ ":client",
+ ":distributions",
+ ":rate_limiters",
+ ":server",
+ ],
+)
+
+reverb_py_test(
+ name = "tf_client_test",
+ timeout = "short",
+ srcs = ["tf_client_test.py"],
+ python_version = "PY3",
+ shard_count = 6,
+ deps = [
+ ":client",
+ ":distributions",
+ ":rate_limiters",
+ ":replay_sample",
+ ":server",
+ ":tf_client",
+ ],
+)
+
+reverb_py_test(
+ name = "rate_limiters_test",
+ srcs = ["rate_limiters_test.py"],
+ python_version = "PY3",
+ deps = [":rate_limiters"],
+)
+
+reverb_py_test(
+ name = "pybind_test",
+ srcs = ["pybind_test.py"],
+ python_version = "PY3",
+ deps = [":reverb"],
+)
diff --git a/reverb/__init__.py b/reverb/__init__.py
new file mode 100644
index 0000000..a90da56
--- /dev/null
+++ b/reverb/__init__.py
@@ -0,0 +1,33 @@
+# Copyright 2019 DeepMind Technologies Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Reverb."""
+
+from reverb import distributions
+from reverb import rate_limiters
+
+from reverb.client import Client
+from reverb.client import Writer
+
+from reverb.errors import ReverbError
+from reverb.errors import TimeoutError
+
+from reverb.replay_sample import ReplaySample
+from reverb.replay_sample import SampleInfo
+
+from reverb.server import PriorityTable
+from reverb.server import Server
+
+from reverb.tf_client import ReplayDataset
+from reverb.tf_client import TFClient
diff --git a/reverb/cc/BUILD b/reverb/cc/BUILD
new file mode 100644
index 0000000..7c8d884
--- /dev/null
+++ b/reverb/cc/BUILD
@@ -0,0 +1,290 @@
+# Description: Reverb is an efficient and easy to use prioritized replay system designed for ML research.
+
+load(
+ "//reverb/cc/platform:build_rules.bzl",
+ "reverb_absl_deps",
+ "reverb_cc_grpc_library",
+ "reverb_cc_library",
+ "reverb_cc_proto_library",
+ "reverb_cc_test",
+ "reverb_grpc_deps",
+ "reverb_py_proto_library",
+ "reverb_tf_deps",
+)
+
+package(default_visibility = ["//reverb:__subpackages__"])
+
+licenses(["notice"])
+
+exports_files(["LICENSE"])
+
+reverb_cc_test(
+ name = "chunk_store_test",
+ srcs = ["chunk_store_test.cc"],
+ deps = [
+ ":chunk_store",
+ ":schema_cc_proto",
+ "//reverb/cc/platform:thread",
+ "//reverb/cc/testing:proto_test_util",
+ ] + reverb_tf_deps(),
+)
+
+reverb_cc_test(
+ name = "rate_limiter_test",
+ srcs = ["rate_limiter_test.cc"],
+ deps = [
+ ":priority_table",
+ "//reverb/cc/distributions:uniform",
+ "//reverb/cc/platform:thread",
+ "//reverb/cc/testing:proto_test_util",
+ ] + reverb_absl_deps(),
+)
+
+reverb_cc_test(
+ name = "priority_table_test",
+ srcs = ["priority_table_test.cc"],
+ deps = [
+ ":chunk_store",
+ ":priority_table",
+ ":schema_cc_proto",
+ "//reverb/cc/checkpointing:checkpoint_cc_proto",
+ "//reverb/cc/distributions:fifo",
+ "//reverb/cc/distributions:uniform",
+ "//reverb/cc/platform:thread",
+ "//reverb/cc/testing:proto_test_util",
+ ] + reverb_tf_deps() + reverb_absl_deps(),
+)
+
+reverb_cc_test(
+ name = "tensor_compression_test",
+ srcs = ["tensor_compression_test.cc"],
+ deps = [
+ ":tensor_compression",
+ "//reverb/cc/testing:tensor_testutil",
+ ] + reverb_tf_deps(),
+)
+
+reverb_cc_test(
+ name = "replay_sampler_test",
+ srcs = ["replay_sampler_test.cc"],
+ deps = [
+ ":replay_sampler",
+ ":replay_service_cc_grpc_proto",
+ ":replay_service_cc_proto",
+ ":tensor_compression",
+ "//reverb/cc/platform:logging",
+ "//reverb/cc/testing:tensor_testutil",
+ "//reverb/cc/testing:time_testutil",
+ ] + reverb_tf_deps() + reverb_grpc_deps() + reverb_absl_deps(),
+)
+
+reverb_cc_test(
+ name = "replay_writer_test",
+ srcs = ["replay_writer_test.cc"],
+ deps = [
+ ":replay_client",
+ ":replay_service_cc_grpc_proto",
+ ":replay_writer",
+ "//reverb/cc/support:grpc_util",
+ "//reverb/cc/support:uint128",
+ "//reverb/cc/testing:proto_test_util",
+ ] + reverb_tf_deps() + reverb_grpc_deps(),
+)
+
+reverb_cc_test(
+ name = "replay_client_test",
+ srcs = ["replay_client_test.cc"],
+ deps = [
+ ":replay_client",
+ ":replay_service_cc_grpc_proto",
+ ":replay_service_cc_proto",
+ "//reverb/cc/support:uint128",
+ "//reverb/cc/testing:proto_test_util",
+ ] + reverb_tf_deps() + reverb_grpc_deps(),
+)
+
+reverb_cc_test(
+ name = "reverb_server_test",
+ srcs = ["reverb_server_test.cc"],
+ deps = [
+ ":reverb_server",
+ "//reverb/cc/platform:net",
+ ] + reverb_tf_deps() + reverb_grpc_deps(),
+)
+
+reverb_cc_test(
+ name = "replay_service_impl_test",
+ srcs = ["replay_service_impl_test.cc"],
+ deps = [
+ ":replay_service_cc_proto",
+ ":replay_service_impl",
+ ":schema_cc_proto",
+ "//reverb/cc/distributions:fifo",
+ "//reverb/cc/distributions:uniform",
+ "//reverb/cc/platform:checkpointing",
+ "//reverb/cc/platform:thread",
+ "//reverb/cc/testing:proto_test_util",
+ ] + reverb_tf_deps() + reverb_absl_deps(),
+)
+
+reverb_cc_library(
+ name = "chunk_store",
+ srcs = ["chunk_store.cc"],
+ hdrs = ["chunk_store.h"],
+ deps = [
+ ":schema_cc_proto",
+ "//reverb/cc/platform:thread",
+ "//reverb/cc/support:queue",
+ ] + reverb_tf_deps() + reverb_absl_deps(),
+)
+
+reverb_cc_library(
+ name = "priority_table",
+ srcs = [
+ "priority_table.cc",
+ "rate_limiter.cc",
+ ],
+ hdrs = [
+ "priority_table.h",
+ "rate_limiter.h",
+ ],
+ visibility = ["//reverb:__subpackages__"],
+ deps = [
+ ":chunk_store",
+ ":priority_table_item",
+ ":schema_cc_proto",
+ "//reverb/cc/checkpointing:checkpoint_cc_proto",
+ "//reverb/cc/distributions:interface",
+ "//reverb/cc/platform:logging",
+ "//reverb/cc/table_extensions:interface",
+ ] + reverb_tf_deps() + reverb_absl_deps(),
+)
+
+reverb_cc_library(
+ name = "tensor_compression",
+ srcs = ["tensor_compression.cc"],
+ hdrs = ["tensor_compression.h"],
+ visibility = ["//reverb:__subpackages__"],
+ deps = [
+ "//reverb/cc/platform:logging",
+ "//reverb/cc/platform:snappy",
+ ] + reverb_tf_deps(),
+)
+
+reverb_cc_library(
+ name = "replay_sampler",
+ srcs = ["replay_sampler.cc"],
+ hdrs = ["replay_sampler.h"],
+ visibility = ["//reverb:__subpackages__"],
+ deps = [
+ ":replay_service_cc_grpc_proto",
+ ":replay_service_cc_proto",
+ ":tensor_compression",
+ "//reverb/cc/platform:logging",
+ "//reverb/cc/platform:thread",
+ "//reverb/cc/support:grpc_util",
+ "//reverb/cc/support:queue",
+ ] + reverb_tf_deps() + reverb_grpc_deps() + reverb_absl_deps(),
+)
+
+reverb_cc_library(
+ name = "replay_writer",
+ srcs = ["replay_writer.cc"],
+ hdrs = ["replay_writer.h"],
+ visibility = ["//reverb:__subpackages__"],
+ deps = [
+ ":replay_service_cc_grpc_proto",
+ ":replay_service_cc_proto",
+ ":schema_cc_proto",
+ ":tensor_compression",
+ "//reverb/cc/platform:logging",
+ "//reverb/cc/support:grpc_util",
+ "//reverb/cc/support:signature",
+ ] + reverb_tf_deps() + reverb_grpc_deps() + reverb_absl_deps(),
+)
+
+reverb_cc_library(
+ name = "replay_client",
+ srcs = ["replay_client.cc"],
+ hdrs = ["replay_client.h"],
+ visibility = ["//reverb:__subpackages__"],
+ deps = [
+ ":replay_sampler",
+ ":replay_service_cc_grpc_proto",
+ ":replay_service_cc_proto",
+ ":replay_writer",
+ ":schema_cc_proto",
+ "//reverb/cc/platform:grpc_utils",
+ "//reverb/cc/platform:logging",
+ "//reverb/cc/support:grpc_util",
+ "//reverb/cc/support:signature",
+ "//reverb/cc/support:uint128",
+ ] + reverb_tf_deps() + reverb_grpc_deps() + reverb_absl_deps(),
+)
+
+reverb_cc_library(
+ name = "replay_service_impl",
+ srcs = ["replay_service_impl.cc"],
+ hdrs = ["replay_service_impl.h"],
+ deps = [
+ ":chunk_store",
+ ":priority_table",
+ ":replay_service_cc_grpc_proto",
+ ":replay_service_cc_proto",
+ ":schema_cc_proto",
+ "//reverb/cc/checkpointing:interface",
+ "//reverb/cc/platform:logging",
+ "//reverb/cc/support:grpc_util",
+ "//reverb/cc/support:uint128",
+ ] + reverb_tf_deps() + reverb_grpc_deps() + reverb_absl_deps(),
+)
+
+reverb_cc_library(
+ name = "priority_table_item",
+ hdrs = ["priority_table_item.h"],
+ deps = [
+ ":chunk_store",
+ ":schema_cc_proto",
+ ],
+)
+
+reverb_cc_library(
+ name = "reverb_server",
+ srcs = ["reverb_server.cc"],
+ hdrs = ["reverb_server.h"],
+ visibility = ["//reverb:__subpackages__"],
+ deps = [
+ ":priority_table",
+ ":replay_client",
+ ":replay_service_impl",
+ "//reverb/cc/checkpointing:interface",
+ "//reverb/cc/platform:grpc_utils",
+ "//reverb/cc/platform:logging",
+ ] + reverb_grpc_deps() + reverb_absl_deps(),
+)
+
+reverb_cc_proto_library(
+ name = "schema_cc_proto",
+ srcs = ["schema.proto"],
+)
+
+reverb_py_proto_library(
+ name = "schema_py_pb2",
+ srcs = ["schema.proto"],
+ deps = [":schema_cc_proto"],
+)
+
+reverb_cc_proto_library(
+ name = "replay_service_cc_proto",
+ srcs = ["replay_service.proto"],
+ visibility = ["//reverb:__subpackages__"],
+ deps = [":schema_cc_proto"],
+)
+
+reverb_cc_grpc_library(
+ name = "replay_service_cc_grpc_proto",
+ srcs = ["replay_service.proto"],
+ generate_mocks = True,
+ visibility = ["//reverb:__subpackages__"],
+ deps = [":replay_service_cc_proto"],
+)
diff --git a/reverb/cc/checkpointing/BUILD b/reverb/cc/checkpointing/BUILD
new file mode 100644
index 0000000..80a5013
--- /dev/null
+++ b/reverb/cc/checkpointing/BUILD
@@ -0,0 +1,23 @@
+load(
+ "//reverb/cc/platform:build_rules.bzl",
+ "reverb_cc_library",
+ "reverb_cc_proto_library",
+)
+
+package(default_visibility = ["//reverb:__subpackages__"])
+
+licenses(["notice"])
+
+reverb_cc_proto_library(
+ name = "checkpoint_cc_proto",
+ srcs = ["checkpoint.proto"],
+ deps = [
+ "//reverb/cc:schema_cc_proto",
+ ],
+)
+
+reverb_cc_library(
+ name = "interface",
+ hdrs = ["interface.h"],
+ deps = ["//reverb/cc:priority_table"],
+)
diff --git a/reverb/cc/checkpointing/checkpoint.proto b/reverb/cc/checkpointing/checkpoint.proto
new file mode 100644
index 0000000..34dbcbc
--- /dev/null
+++ b/reverb/cc/checkpointing/checkpoint.proto
@@ -0,0 +1,61 @@
+syntax = "proto3";
+
+package deepmind.reverb;
+
+import "reverb/cc/schema.proto";
+
+// Configs for reconstructing a distribution to its initial state.
+
+message PriorityTableCheckpoint {
+ // Name of the priority table.
+ string table_name = 1;
+
+ // Maximum number of items in the priority table.
+ // If an insert would result in this value getting exceeded, `remover` is used
+ // to select an item to remove before proceeding with the insert.
+ int64 max_size = 6;
+
+ // The maximum number of times an item can be sampled before being removed.
+ int32 max_times_sampled = 7;
+
+ // Items in the priority table ordered by `inserted_at` (asc).
+ // When loading a checkpoint the items should be added in the same order so
+ // position based distributions (e.g. fifo) are reconstructed correctly.
+ repeated PrioritizedItem items = 2;
+
+ // Checkpoint of the associated rate limiter.
+ RateLimiterCheckpoint rate_limiter = 3;
+
+ // Options for constructing new samplers and removers of the correct type.
+ // Note that this does not include the state that they currently hold as it
+ // will be reproduced using the order of `items`.
+ KeyDistributionOptions sampler = 4;
+ KeyDistributionOptions remover = 5;
+}
+
+message RateLimiterCheckpoint {
+ reserved 1; // Deprecated field `name`.
+
+ // The average number of times each item should be sampled during its
+ // lifetime.
+ double samples_per_insert = 2;
+
+ // The minimum and maximum values the cursor is allowed to reach. The cursor
+ // value is calculated as `insert_count * samples_per_insert -
+ // sample_count`. If the value would go beyond these limits then the call is
+ // blocked until it can proceed without violating the constraints.
+ double min_diff = 3;
+ double max_diff = 4;
+
+ // The minimum number of inserts required before any sample operation.
+ int64 min_size_to_sample = 5;
+
+ // The total number of samples that occurred before the checkpoint.
+ int64 sample_count = 6;
+
+ // The total number of inserts that occurred before the checkpoint.
+ int64 insert_count = 7;
+
+ // The total number of deletes that occurred before the checkpoint.
+ int64 delete_count = 8;
+}
diff --git a/reverb/cc/checkpointing/interface.h b/reverb/cc/checkpointing/interface.h
new file mode 100644
index 0000000..c94f6e3
--- /dev/null
+++ b/reverb/cc/checkpointing/interface.h
@@ -0,0 +1,55 @@
+// Copyright 2019 DeepMind Technologies Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef REVERB_CC_CHECKPOINTING_INTERFACE_H_
+#define REVERB_CC_CHECKPOINTING_INTERFACE_H_
+
+#include "reverb/cc/priority_table.h"
+
+namespace deepmind {
+namespace reverb {
+
+// A checkpointer is able to encode the configuration, data and state as a
+// proto. This proto is stored in a permanent storage system where it can be
+// retrieved at a later point to restore a copy of the checkpointed tables.
+class CheckpointerInterface {
+ public:
+ virtual ~CheckpointerInterface() = default;
+
+ // Save a new checkpoint for every table in `tables` to permanent storage. If
+ // successful, `path` will contain an ABSOLUTE path that could be used to
+ // restore the checkpoint.
+ virtual tensorflow::Status Save(std::vector<PriorityTable*> tables,
+ int keep_latest, std::string* path) = 0;
+
+ // Attempts to load a checkpoint from the active workspace.
+ //
+ // Tables loaded from checkpoint must already exist in `tables`. When
+ // constructing the newly loaded table the extensions are passed from the old
+ // table and the item is replaced with the newly loaded table.
+ virtual tensorflow::Status Load(
+ absl::string_view relative_path, ChunkStore* chunk_store,
+ std::vector<std::shared_ptr<PriorityTable>>* tables) = 0;
+
+ // Finds the most recent checkpoint within the active workspace. See `Load`
+ // for more details.
+ virtual tensorflow::Status LoadLatest(
+ ChunkStore* chunk_store,
+ std::vector<std::shared_ptr<PriorityTable>>* tables) = 0;
+};
+
+} // namespace reverb
+} // namespace deepmind
+
+#endif // REVERB_CC_CHECKPOINTING_INTERFACE_H_
diff --git a/reverb/cc/chunk_store.cc b/reverb/cc/chunk_store.cc
new file mode 100644
index 0000000..d86bcf0
--- /dev/null
+++ b/reverb/cc/chunk_store.cc
@@ -0,0 +1,101 @@
+// Copyright 2019 DeepMind Technologies Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "reverb/cc/chunk_store.h"
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include <cstdint>
+#include "absl/container/flat_hash_map.h"
+#include "absl/strings/str_cat.h"
+#include "absl/synchronization/mutex.h"
+#include "absl/time/time.h"
+#include "absl/types/span.h"
+#include "reverb/cc/platform/thread.h"
+#include "reverb/cc/schema.pb.h"
+#include "reverb/cc/support/queue.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace deepmind {
+namespace reverb {
+
+ChunkStore::ChunkStore(int cleanup_batch_size)
+ : delete_keys_(std::make_shared<internal::Queue<Key>>(10000000)),
+ cleaner_(internal::StartThread(
+ "ChunkStore-Cleaner", [this, cleanup_batch_size] {
+ while (CleanupInternal(cleanup_batch_size)) {
+ }
+ })) {}
+
+ChunkStore::~ChunkStore() {
+ // Closing the queue makes all calls to `CleanupInternal` to return false
+ // which will break the loop in `cleaner_` making it joinable.
+ delete_keys_->Close();
+ cleaner_ = nullptr; // Joins thread.
+}
+
+std::shared_ptr<ChunkStore::Chunk> ChunkStore::Insert(ChunkData item) {
+ absl::WriterMutexLock lock(&mu_);
+ std::weak_ptr<Chunk>& wp = data_[item.chunk_key()];
+ std::shared_ptr<Chunk> sp = wp.lock();
+ if (sp == nullptr) {
+ wp = (sp = std::shared_ptr<Chunk>(new Chunk(std::move(item)),
+ [q = delete_keys_](Chunk* chunk) {
+ q->Push(chunk->data().chunk_key());
+ delete chunk;
+ }));
+ }
+ return sp;
+}
+
+tensorflow::Status ChunkStore::Get(
+ absl::Span<const Key> keys,
+ std::vector<std::shared_ptr<Chunk>>* chunks) {
+ absl::ReaderMutexLock lock(&mu_);
+ chunks->clear();
+ chunks->reserve(keys.size());
+ for (int i = 0; i < keys.size(); i++) {
+ chunks->push_back(GetItem(keys[i]));
+ if (!chunks->at(i)) {
+ return tensorflow::errors::NotFound(
+ absl::StrCat("Chunk ", keys[i], " cannot be found."));
+ }
+ }
+ return tensorflow::Status::OK();
+}
+
+std::shared_ptr<ChunkStore::Chunk> ChunkStore::GetItem(Key key) {
+ auto it = data_.find(key);
+ return it == data_.end() ? nullptr : it->second.lock();
+}
+
+bool ChunkStore::CleanupInternal(int num_chunks) {
+ std::vector popped_keys(num_chunks);
+ for (int i = 0; i < num_chunks; i++) {
+ if (!delete_keys_->Pop(&popped_keys[i])) return false;
+ }
+
+ absl::WriterMutexLock data_lock(&mu_);
+ for (const Key& key : popped_keys) {
+ data_.erase(key);
+ }
+
+ return true;
+}
+
+} // namespace reverb
+} // namespace deepmind
diff --git a/reverb/cc/chunk_store.h b/reverb/cc/chunk_store.h
new file mode 100644
index 0000000..4b37aca
--- /dev/null
+++ b/reverb/cc/chunk_store.h
@@ -0,0 +1,127 @@
+// Copyright 2019 DeepMind Technologies Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef REVERB_CC_CHUNK_STORE_H_
+#define REVERB_CC_CHUNK_STORE_H_
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include <cstdint>
+#include "absl/base/thread_annotations.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/synchronization/mutex.h"
+#include "absl/types/span.h"
+#include "reverb/cc/platform/thread.h"
+#include "reverb/cc/schema.pb.h"
+#include "reverb/cc/support/queue.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace deepmind {
+namespace reverb {
+
+// Maintains a bijection from chunk keys to Chunks. For inserting, the caller
+// passes ChunkData which contains a chunk key and the actual data. We use the
+// key for the mapping and wrap the ChunkData in a thin class which
+// provides a read-only accessor to the ChunkData.
+//
+// For ChunkData that populates sequence_range, an interval tree is maintained
+// to allow lookups of intersecting ranges.
+//
+// Each Chunk is reference counted individually. When its reference count drops
+// to zero, the Chunk is destroyed and subsequent calls to Get() will no longer
+// return that Chunk. Please note that this container only holds a weak pointer
+// to a Chunk, and thus does not count towards the reference count. For this
+// reason, Insert() returns a shared pointer, as otherwise the Chunk would be
+// destroyed right away.
+//
+// All public methods are thread safe.
+class ChunkStore {
+ public:
+ using Key = uint64_t;
+
+ class Chunk {
+ public:
+ explicit Chunk(ChunkData data) : data_(std::move(data)) {}
+
+ // Returns the proto data of the chunk.
+ const ChunkData& data() const { return data_; }
+
+ // (Potentially cached) size of `data`.
+ size_t DataByteSizeLong() const {
+ if (data_byte_size_ == 0) {
+ data_byte_size_ = data_.ByteSizeLong();
+ }
+ return data_byte_size_;
+ }
+
+ private:
+ ChunkData data_;
+ mutable size_t data_byte_size_ = 0;
+ };
+
+ // Starts `cleaner_`. `cleanup_batch_size` is the number of keys the cleaner
+ // should wait for before acquiring the lock and erasing them from `data_`.
+ explicit ChunkStore(int cleanup_batch_size = 1000);
+
+ // Stops `cleaner_` and closes `delete_keys_`.
+ ~ChunkStore();
+
+ // Attempts to insert a Chunk into the map using the key inside `item`. If no
+ // entry existed for the key, a new Chunk is created, inserted and returned.
+ // Otherwise, the existing chunk is returned.
+ std::shared_ptr<Chunk> Insert(ChunkData item) ABSL_LOCKS_EXCLUDED(mu_);
+
+ // Gets the Chunk for each given key. Returns an error if one of the items
+ // does not exist or if `Close` has been called. On success, the returned
+ // items are in the same order as given in `keys`.
+ tensorflow::Status Get(absl::Span<const Key> keys,
+ std::vector<std::shared_ptr<Chunk>> *chunks)
+ ABSL_LOCKS_EXCLUDED(mu_);
+
+ // Blocks until `num_chunks` expired entries have been cleaned up from
+ // `data_`. This method is called automatically by a background thread to
+ // limit memory size, but does not have any effect on the semantics of Get()
+ // or Insert() calls.
+ //
+ // Returns false if `delete_keys_` closed before `num_chunks` could be popped.
+ bool CleanupInternal(int num_chunks) ABSL_LOCKS_EXCLUDED(mu_);
+
+ private:
+ // Gets an item. Returns nullptr if the item does not exist.
+ std::shared_ptr<Chunk> GetItem(Key key) ABSL_SHARED_LOCKS_REQUIRED(mu_);
+
+ // Holds the actual mapping of key to Chunk. We only hold a weak pointer to
+ // the Chunk, which means that destruction and reference counting of the
+ // chunks happens independently of this map.
+ absl::flat_hash_map<Key, std::weak_ptr<Chunk>> data_ ABSL_GUARDED_BY(mu_);
+
+ // Mutex protecting access to `data_`.
+ mutable absl::Mutex mu_;
+
+ // Queue of keys of deleted items that will be cleaned up by `cleaner_`. Note
+ // the queue has to be allocated on the heap in order to avoid dereferencing
+ // errors caused by a stack allocated ChunkStore getting destroyed before all
+ // Chunks have been destroyed.
+ std::shared_ptr<internal::Queue<Key>> delete_keys_;
+
+ // Consumes `delete_keys_` to remove dead pointers in `data_`.
+ std::unique_ptr<internal::Thread> cleaner_;
+};
+
+} // namespace reverb
+} // namespace deepmind
+
+#endif // REVERB_CC_CHUNK_STORE_H_
diff --git a/reverb/cc/chunk_store_test.cc b/reverb/cc/chunk_store_test.cc
new file mode 100644
index 0000000..f412458
--- /dev/null
+++ b/reverb/cc/chunk_store_test.cc
@@ -0,0 +1,128 @@
+// Copyright 2019 DeepMind Technologies Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "reverb/cc/chunk_store.h"
+
+#include <atomic>
+#include <memory>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "reverb/cc/platform/thread.h"
+#include "reverb/cc/schema.pb.h"
+#include "reverb/cc/testing/proto_test_util.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace deepmind {
+namespace reverb {
+namespace {
+
+using ChunkVector = ::std::vector<::std::shared_ptr<ChunkStore::Chunk>>;
+
+TEST(ChunkStoreTest, GetAfterInsertSucceeds) {
+ ChunkStore store;
+ std::shared_ptr<ChunkStore::Chunk> inserted =
+ store.Insert(testing::MakeChunkData(2));
+ ChunkVector chunks;
+ TF_ASSERT_OK(store.Get({2}, &chunks));
+ EXPECT_EQ(inserted, chunks[0]);
+}
+
+TEST(ChunkStoreTest, GetFailsWhenKeyDoesNotExist) {
+ ChunkStore store;
+ ChunkVector chunks;
+ EXPECT_EQ(store.Get({2}, &chunks).code(), tensorflow::error::NOT_FOUND);
+}
+
+TEST(ChunkStoreTest, GetFailsAfterChunkIsDestroyed) {
+ ChunkStore store;
+ std::shared_ptr<ChunkStore::Chunk> inserted =
+ store.Insert(testing::MakeChunkData(1));
+ inserted = nullptr;
+ ChunkVector chunks;
+ EXPECT_EQ(store.Get({2}, &chunks).code(), tensorflow::error::NOT_FOUND);
+}
+
+TEST(ChunkStoreTest, InsertingTwiceReturnsExistingChunk) {
+ ChunkStore store;
+ ChunkData data = testing::MakeChunkData(2);
+ data.add_data();
+ std::shared_ptr<ChunkStore::Chunk> first =
+ store.Insert(testing::MakeChunkData(2));
+ EXPECT_NE(first, nullptr);
+ std::shared_ptr<ChunkStore::Chunk> second =
+ store.Insert(testing::MakeChunkData(2));
+ EXPECT_EQ(first, second);
+}
+
+TEST(ChunkStoreTest, InsertingTwiceSucceedsWhenChunkIsDestroyed) {
+ ChunkStore store;
+ std::shared_ptr<ChunkStore::Chunk> first =
+ store.Insert(testing::MakeChunkData(1));
+ EXPECT_NE(first, nullptr);
+ first = nullptr;
+ std::shared_ptr<ChunkStore::Chunk> second =
+ store.Insert(testing::MakeChunkData(1));
+ EXPECT_NE(second, nullptr);
+}
+
+TEST(ChunkStoreTest, CleanupDoesNotDeleteRequiredChunks) {
+ ChunkStore store(/*cleanup_batch_size=*/1);
+
+ // Keep this one around.
+ std::shared_ptr<ChunkStore::Chunk> first =
+ store.Insert(testing::MakeChunkData(1));
+
+ // Let this one expire.
+ {
+ std::shared_ptr<ChunkStore::Chunk> second =
+ store.Insert(testing::MakeChunkData(2));
+ }
+
+ // The first one should still be available.
+ ChunkVector chunks;
+ TF_EXPECT_OK(store.Get({1}, &chunks));
+
+ // The second one should be gone eventually.
+ tensorflow::Status status;
+ while (status.code() != tensorflow::error::NOT_FOUND) {
+ status = store.Get({2}, &chunks);
+ }
+}
+
+TEST(ChunkStoreTest, ConcurrentCalls) {
+ ChunkStore store;
+ std::vector<std::unique_ptr<internal::Thread>> bundle;
+ std::atomic<int> count(0);
+ for (ChunkStore::Key i = 0; i < 1000; i++) {
+ bundle.push_back(internal::StartThread("", [i, &store, &count] {
+ std::shared_ptr<ChunkStore::Chunk> first =
+ store.Insert(testing::MakeChunkData(i));
+ ChunkVector chunks;
+ TF_ASSERT_OK(store.Get({i}, &chunks));
+ first = nullptr;
+ while (store.Get({i}, &chunks).code() != tensorflow::error::NOT_FOUND) {
+ }
+ count++;
+ }));
+ }
+ bundle.clear(); // Joins all threads.
+ EXPECT_EQ(count, 1000);
+}
+
+} // namespace
+} // namespace reverb
+} // namespace deepmind
diff --git a/reverb/cc/distributions/BUILD b/reverb/cc/distributions/BUILD
new file mode 100644
index 0000000..6907126
--- /dev/null
+++ b/reverb/cc/distributions/BUILD
@@ -0,0 +1,130 @@
+load(
+ "//reverb/cc/platform:build_rules.bzl",
+ "reverb_absl_deps",
+ "reverb_cc_library",
+ "reverb_cc_test",
+ "reverb_tf_deps",
+)
+
+package(default_visibility = ["//reverb:__subpackages__"])
+
+licenses(["notice"])
+
+reverb_cc_library(
+ name = "interface",
+ hdrs = ["interface.h"],
+ deps = [
+ "//reverb/cc/checkpointing:checkpoint_cc_proto",
+ ] + reverb_tf_deps() + reverb_absl_deps(),
+)
+
+reverb_cc_library(
+ name = "uniform",
+ srcs = ["uniform.cc"],
+ hdrs = ["uniform.h"],
+ deps = [
+ ":interface",
+ "//reverb/cc:schema_cc_proto",
+ "//reverb/cc/checkpointing:checkpoint_cc_proto",
+ "//reverb/cc/platform:logging",
+ ] + reverb_tf_deps() + reverb_absl_deps(),
+)
+
+reverb_cc_library(
+ name = "fifo",
+ srcs = ["fifo.cc"],
+ hdrs = ["fifo.h"],
+ deps = [
+ ":interface",
+ "//reverb/cc:schema_cc_proto",
+ "//reverb/cc/checkpointing:checkpoint_cc_proto",
+ "//reverb/cc/platform:logging",
+ ] + reverb_tf_deps() + reverb_absl_deps(),
+)
+
+reverb_cc_library(
+ name = "lifo",
+ srcs = ["lifo.cc"],
+ hdrs = ["lifo.h"],
+ deps = [
+ ":interface",
+ "//reverb/cc:schema_cc_proto",
+ "//reverb/cc/checkpointing:checkpoint_cc_proto",
+ "//reverb/cc/platform:logging",
+ ] + reverb_tf_deps() + reverb_absl_deps(),
+)
+
+reverb_cc_library(
+ name = "prioritized",
+ srcs = ["prioritized.cc"],
+ hdrs = ["prioritized.h"],
+ deps = [
+ ":interface",
+ "//reverb/cc:schema_cc_proto",
+ "//reverb/cc/checkpointing:checkpoint_cc_proto",
+ "//reverb/cc/platform:logging",
+ ] + reverb_tf_deps() + reverb_absl_deps(),
+)
+
+reverb_cc_library(
+ name = "heap",
+ srcs = ["heap.cc"],
+ hdrs = ["heap.h"],
+ deps = [
+ ":interface",
+ "//reverb/cc:schema_cc_proto",
+ "//reverb/cc/checkpointing:checkpoint_cc_proto",
+ "//reverb/cc/support:intrusive_heap",
+ ] + reverb_tf_deps() + reverb_absl_deps(),
+)
+
+reverb_cc_test(
+ name = "uniform_test",
+ srcs = ["uniform_test.cc"],
+ deps = [
+ ":uniform",
+ "//reverb/cc:schema_cc_proto",
+ "//reverb/cc/testing:proto_test_util",
+ ],
+)
+
+reverb_cc_test(
+ name = "fifo_test",
+ srcs = ["fifo_test.cc"],
+ deps = [
+ ":fifo",
+ "//reverb/cc:schema_cc_proto",
+ "//reverb/cc/testing:proto_test_util",
+ ],
+)
+
+reverb_cc_test(
+ name = "lifo_test",
+ srcs = ["lifo_test.cc"],
+ deps = [
+ ":lifo",
+ "//reverb/cc:schema_cc_proto",
+ "//reverb/cc/testing:proto_test_util",
+ ],
+)
+
+reverb_cc_test(
+ name = "prioritized_test",
+ srcs = ["prioritized_test.cc"],
+ deps = [
+ ":prioritized",
+ "//reverb/cc:schema_cc_proto",
+ "//reverb/cc/testing:proto_test_util",
+ ] + reverb_absl_deps(),
+)
+
+reverb_cc_test(
+ name = "heap_test",
+ srcs = ["heap_test.cc"],
+ deps = [
+ ":heap",
+ ":interface",
+ "//reverb/cc:schema_cc_proto",
+ "//reverb/cc/testing:proto_test_util",
+ ],
+)
diff --git a/reverb/cc/distributions/fifo.cc b/reverb/cc/distributions/fifo.cc
new file mode 100644
index 0000000..60e6023
--- /dev/null
+++ b/reverb/cc/distributions/fifo.cc
@@ -0,0 +1,72 @@
+// Copyright 2019 DeepMind Technologies Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "reverb/cc/distributions/fifo.h"
+
+#include "reverb/cc/checkpointing/checkpoint.pb.h"
+#include "reverb/cc/platform/logging.h"
+#include "reverb/cc/schema.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace deepmind {
+namespace reverb {
+
+tensorflow::Status FifoDistribution::Delete(KeyDistributionInterface::Key key) {
+ auto it = key_to_iterator_.find(key);
+ if (it == key_to_iterator_.end())
+ return tensorflow::errors::InvalidArgument(
+ absl::StrCat("Key ", key, " not found in distribution."));
+ keys_.erase(it->second);
+ key_to_iterator_.erase(it);
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status FifoDistribution::Insert(KeyDistributionInterface::Key key,
+ double priority) {
+ if (key_to_iterator_.find(key) != key_to_iterator_.end()) {
+ return tensorflow::errors::InvalidArgument(
+ absl::StrCat("Key ", key, " already exists in distribution."));
+ }
+ key_to_iterator_.emplace(key, keys_.emplace(keys_.end(), key));
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status FifoDistribution::Update(KeyDistributionInterface::Key key,
+ double priority) {
+ if (key_to_iterator_.find(key) == key_to_iterator_.end()) {
+ return tensorflow::errors::InvalidArgument(
+ absl::StrCat("Key ", key, " not found in distribution."));
+ }
+ return tensorflow::Status::OK();
+}
+
+KeyDistributionInterface::KeyWithProbability FifoDistribution::Sample() {
+ REVERB_CHECK(!keys_.empty());
+ return {keys_.front(), 1.};
+}
+
+void FifoDistribution::Clear() {
+ keys_.clear();
+ key_to_iterator_.clear();
+}
+
+KeyDistributionOptions FifoDistribution::options() const {
+ KeyDistributionOptions options;
+ options.set_fifo(true);
+ return options;
+}
+
+} // namespace reverb
+} // namespace deepmind
diff --git a/reverb/cc/distributions/fifo.h b/reverb/cc/distributions/fifo.h
new file mode 100644
index 0000000..1cdd02d
--- /dev/null
+++ b/reverb/cc/distributions/fifo.h
@@ -0,0 +1,56 @@
+// Copyright 2019 DeepMind Technologies Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LEARNING_DEEPMIND_REPLAY_REVERB_DISTRIBUTIONS_FIFO_H_
+#define LEARNING_DEEPMIND_REPLAY_REVERB_DISTRIBUTIONS_FIFO_H_
+
+#include <list>
+
+#include "absl/container/flat_hash_map.h"
+#include "reverb/cc/checkpointing/checkpoint.pb.h"
+#include "reverb/cc/distributions/interface.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace deepmind {
+namespace reverb {
+
+// Fifo sampling. We ignore all priority values in the calls. Sample() always
+// returns the key that was inserted first until this key is deleted. All
+// operations take O(1) time. See KeyDistributionInterface for documentation
+// about the methods.
+class FifoDistribution : public KeyDistributionInterface {
+ public:
+ tensorflow::Status Delete(Key key) override;
+
+ // The priority is ignored.
+ tensorflow::Status Insert(Key key, double priority) override;
+
+ // This is a no-op but will return an error if the key does not exist.
+ tensorflow::Status Update(Key key, double priority) override;
+
+ KeyWithProbability Sample() override;
+
+ void Clear() override;
+
+ KeyDistributionOptions options() const override;
+
+ private:
+ std::list<Key> keys_;
+ absl::flat_hash_map<Key, std::list<Key>::iterator> key_to_iterator_;
+};
+
+} // namespace reverb
+} // namespace deepmind
+
+#endif // LEARNING_DEEPMIND_REPLAY_REVERB_DISTRIBUTIONS_FIFO_H_
diff --git a/reverb/cc/distributions/fifo_test.cc b/reverb/cc/distributions/fifo_test.cc
new file mode 100644
index 0000000..7080096
--- /dev/null
+++ b/reverb/cc/distributions/fifo_test.cc
@@ -0,0 +1,88 @@
+// Copyright 2019 DeepMind Technologies Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "reverb/cc/distributions/fifo.h"
+
+#include <cstdint>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "reverb/cc/schema.pb.h"
+#include "reverb/cc/testing/proto_test_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace deepmind {
+namespace reverb {
+namespace {
+
+TEST(FifoTest, ReturnValueSantiyChecks) {
+ FifoDistribution fifo;
+
+ // Non existent keys cannot be deleted or updated.
+ EXPECT_EQ(fifo.Delete(123).code(), tensorflow::error::INVALID_ARGUMENT);
+ EXPECT_EQ(fifo.Update(123, 4).code(), tensorflow::error::INVALID_ARGUMENT);
+
+ // Keys cannot be inserted twice.
+ TF_EXPECT_OK(fifo.Insert(123, 4));
+ EXPECT_THAT(fifo.Insert(123, 4).code(), tensorflow::error::INVALID_ARGUMENT);
+
+ // Existing keys can be updated and sampled.
+ TF_EXPECT_OK(fifo.Update(123, 5));
+ EXPECT_EQ(fifo.Sample().key, 123);
+
+ // Existing keys cannot be deleted twice.
+ TF_EXPECT_OK(fifo.Delete(123));
+ EXPECT_THAT(fifo.Delete(123).code(), tensorflow::error::INVALID_ARGUMENT);
+}
+
+TEST(FifoTest, MatchesFifoOrdering) {
+ int64_t kItems = 100;
+
+ FifoDistribution fifo;
+ // Insert items.
+ for (int i = 0; i < kItems; i++) {
+ TF_EXPECT_OK(fifo.Insert(i, 0));
+ }
+ // Delete every 10th item.
+ for (int i = 0; i < kItems; i++) {
+ if (i % 10 == 0) TF_EXPECT_OK(fifo.Delete(i));
+ }
+
+ for (int i = 0; i < kItems; i++) {
+ if (i % 10 == 0) continue;
+ KeyDistributionInterface::KeyWithProbability sample = fifo.Sample();
+ EXPECT_EQ(sample.key, i);
+ EXPECT_EQ(sample.probability, 1);
+ TF_EXPECT_OK(fifo.Delete(sample.key));
+ }
+}
+
+TEST(FifoTest, Options) {
+ FifoDistribution fifo;
+ EXPECT_THAT(fifo.options(), testing::EqualsProto("fifo: true"));
+}
+
+TEST(FifoDeathTest, ClearThenSample) {
+ FifoDistribution fifo;
+ for (int i = 0; i < 100; i++) {
+ TF_EXPECT_OK(fifo.Insert(i, i));
+ }
+ fifo.Sample();
+ fifo.Clear();
+ EXPECT_DEATH(fifo.Sample(), "");
+}
+
+} // namespace
+} // namespace reverb
+} // namespace deepmind
diff --git a/reverb/cc/distributions/heap.cc b/reverb/cc/distributions/heap.cc
new file mode 100644
index 0000000..3b6b985
--- /dev/null
+++ b/reverb/cc/distributions/heap.cc
@@ -0,0 +1,81 @@
+// Copyright 2019 DeepMind Technologies Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "reverb/cc/distributions/heap.h"
+
+#include "reverb/cc/checkpointing/checkpoint.pb.h"
+#include "reverb/cc/distributions/interface.h"
+#include "reverb/cc/schema.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace deepmind {
+namespace reverb {
+
+HeapDistribution::HeapDistribution(bool min_heap)
+ : sign_(min_heap ? 1 : -1), update_count_(0) {}
+
+tensorflow::Status HeapDistribution::Delete(KeyDistributionInterface::Key key) {
+ auto it = nodes_.find(key);
+ if (it == nodes_.end()) {
+ return tensorflow::errors::InvalidArgument(
+ absl::StrCat("Key ", key, " not found in distribution."));
+ }
+ heap_.Remove(it->second.get());
+ nodes_.erase(it);
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status HeapDistribution::Insert(KeyDistributionInterface::Key key,
+ double priority) {
+ if (nodes_.contains(key)) {
+ return tensorflow::errors::InvalidArgument(
+ absl::StrCat("Key ", key, " already exists in distribution."));
+ }
+  nodes_[key] =
+      absl::make_unique<HeapNode>(key, priority * sign_, update_count_++);
+ heap_.Push(nodes_[key].get());
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status HeapDistribution::Update(KeyDistributionInterface::Key key,
+ double priority) {
+ if (!nodes_.contains(key)) {
+ return tensorflow::errors::InvalidArgument(
+ absl::StrCat("Key ", key, " not found in distribution."));
+ }
+ nodes_[key]->priority = priority * sign_;
+ nodes_[key]->update_number = update_count_++;
+ heap_.Adjust(nodes_[key].get());
+ return tensorflow::Status::OK();
+}
+
+KeyDistributionInterface::KeyWithProbability HeapDistribution::Sample() {
+ REVERB_CHECK(!nodes_.empty());
+ return {heap_.top()->key, 1.};
+}
+
+void HeapDistribution::Clear() {
+ nodes_.clear();
+ heap_.Clear();
+}
+
+KeyDistributionOptions HeapDistribution::options() const {
+ KeyDistributionOptions options;
+ options.mutable_heap()->set_min_heap(sign_ == 1);
+ return options;
+}
+
+} // namespace reverb
+} // namespace deepmind
diff --git a/reverb/cc/distributions/heap.h b/reverb/cc/distributions/heap.h
new file mode 100644
index 0000000..be0c609
--- /dev/null
+++ b/reverb/cc/distributions/heap.h
@@ -0,0 +1,90 @@
+// Copyright 2019 DeepMind Technologies Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LEARNING_DEEPMIND_REPLAY_REVERB_DISTRIBUTIONS_HEAP_H_
+#define LEARNING_DEEPMIND_REPLAY_REVERB_DISTRIBUTIONS_HEAP_H_
+
+#include <memory>
+#include "absl/container/flat_hash_map.h"
+#include "reverb/cc/distributions/interface.h"
+#include "reverb/cc/support/intrusive_heap.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace deepmind {
+namespace reverb {
+
+// HeapDistribution always samples the item with the lowest or highest priority
+// (controlled by `min_heap`). If multiple items share the same priority then
+// the least recently inserted or updated key is sampled.
+class HeapDistribution : public KeyDistributionInterface {
+ public:
+ explicit HeapDistribution(bool min_heap = true);
+
+ // O(log n) time.
+ tensorflow::Status Delete(Key key) override;
+
+ // O(log n) time.
+ tensorflow::Status Insert(Key key, double priority) override;
+
+ // O(log n) time.
+ tensorflow::Status Update(Key key, double priority) override;
+
+ // O(1) time.
+ KeyWithProbability Sample() override;
+
+ // O(n) time.
+ void Clear() override;
+
+ KeyDistributionOptions options() const override;
+
+ private:
+ struct HeapNode {
+ Key key;
+ double priority;
+ IntrusiveHeapLink heap;
+ uint64_t update_number;
+
+ HeapNode(Key key, double priority, uint64_t update_number)
+ : key(key), priority(priority), update_number(update_number) {}
+ };
+
+ struct HeapNodeCompare {
+ bool operator()(const HeapNode* a, const HeapNode* b) const {
+ // Lexicographic ordering by (priority, update_number).
+ return (a->priority < b->priority) ||
+ ((a->priority == b->priority) &&
+ (a->update_number < b->update_number));
+ }
+ };
+
+ // 1 if `min_heap` = true, else -1. Priorities are multiplied by this number
+ // to control whether the min or max priority item should be sampled.
+ const double sign_;
+
+ // Heap where the top item is the one with the lowest/highest priority in the
+ // distribution.
+  IntrusiveHeap<HeapNode, HeapNodeCompare> heap_;
+
+ // `IntrusiveHeap` does not manage the memory of its nodes so they are stored
+ // in `nodes_`. The content of nodes_ and heap_ are always kept in sync.
+  absl::flat_hash_map<Key, std::unique_ptr<HeapNode>> nodes_;
+
+ // Keep track of the number of inserts/updates for most-recent tie-breaking.
+ uint64_t update_count_;
+};
+
+} // namespace reverb
+} // namespace deepmind
+
+#endif // LEARNING_DEEPMIND_REPLAY_REVERB_DISTRIBUTIONS_HEAP_H_
diff --git a/reverb/cc/distributions/heap_test.cc b/reverb/cc/distributions/heap_test.cc
new file mode 100644
index 0000000..dac8b6a
--- /dev/null
+++ b/reverb/cc/distributions/heap_test.cc
@@ -0,0 +1,186 @@
+// Copyright 2019 DeepMind Technologies Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "reverb/cc/distributions/heap.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "reverb/cc/distributions/interface.h"
+#include "reverb/cc/schema.pb.h"
+#include "reverb/cc/testing/proto_test_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace deepmind {
+namespace reverb {
+namespace {
+
+
+TEST(HeapDistributionTest, ReturnValueSantiyChecks) {
+ HeapDistribution heap;
+
+ // Non existent keys cannot be deleted or updated.
+ EXPECT_EQ(heap.Delete(123).code(), tensorflow::error::INVALID_ARGUMENT);
+ EXPECT_EQ(heap.Update(123, 4).code(), tensorflow::error::INVALID_ARGUMENT);
+
+ // Keys cannot be inserted twice.
+ TF_EXPECT_OK(heap.Insert(123, 4));
+ EXPECT_EQ(heap.Insert(123, 4).code(), tensorflow::error::INVALID_ARGUMENT);
+
+ // Existing keys can be updated and sampled.
+ TF_EXPECT_OK(heap.Update(123, 5));
+ EXPECT_EQ(heap.Sample().key, 123);
+
+ // Existing keys cannot be deleted twice.
+ TF_EXPECT_OK(heap.Delete(123));
+ EXPECT_EQ(heap.Delete(123).code(), tensorflow::error::INVALID_ARGUMENT);
+}
+
+TEST(HeapDistributionTest, SampleMinPriorityFirstByDefault) {
+ HeapDistribution heap;
+
+ TF_EXPECT_OK(heap.Insert(123, 2));
+ TF_EXPECT_OK(heap.Insert(124, 1));
+ TF_EXPECT_OK(heap.Insert(125, 3));
+
+ EXPECT_EQ(heap.Sample().key, 124);
+
+ // Remove the top item.
+ TF_EXPECT_OK(heap.Delete(124));
+
+ // Second lowest priority is now the lowest.
+ EXPECT_EQ(heap.Sample().key, 123);
+}
+
+TEST(HeapDistributionTest, BreakTiesByInsertionOrder) {
+ HeapDistribution heap;
+
+ // We insert keys with priorities such that the keys are removed in the
+ // order [0, 1, 2, 3, 4, 5].
+ // Three-way tie and two-way tie are checked.
+ TF_EXPECT_OK(heap.Insert(5, 300));
+ TF_EXPECT_OK(heap.Insert(0, 1));
+ TF_EXPECT_OK(heap.Insert(3, 20));
+ TF_EXPECT_OK(heap.Insert(1, 1));
+ TF_EXPECT_OK(heap.Insert(4, 20));
+ TF_EXPECT_OK(heap.Insert(2, 1));
+
+ for (auto i = 0; i < 6; i++) {
+ EXPECT_EQ(heap.Sample().key, i);
+ TF_EXPECT_OK(heap.Delete(i));
+ }
+}
+
+TEST(HeapDistributionTest, BreakTiesByUpdateOrder) {
+ HeapDistribution heap;
+
+ TF_EXPECT_OK(heap.Insert(2, 1));
+ TF_EXPECT_OK(heap.Insert(0, 1));
+ TF_EXPECT_OK(heap.Insert(1, 1));
+
+ // Removing keys at this point would result in the order [2, 0, 1]
+  // by LRU because the priorities are equal.
+ // This update does not change the priority, but does increase the update
+ // recency, resulting in the new order [0, 1, 2] which we verify.
+ TF_EXPECT_OK(heap.Update(2, 1));
+ for (auto i = 0; i < 3; i++) {
+ EXPECT_EQ(heap.Sample().key, i);
+ TF_EXPECT_OK(heap.Delete(i));
+ }
+}
+
+TEST(HeapDistributionTest, SampleMaxPriorityWhenMinHeapFalse) {
+ HeapDistribution heap(false);
+
+ TF_EXPECT_OK(heap.Insert(123, 2));
+ TF_EXPECT_OK(heap.Insert(124, 1));
+ TF_EXPECT_OK(heap.Insert(125, 3));
+
+ EXPECT_EQ(heap.Sample().key, 125);
+
+ // Remove the top item.
+ TF_EXPECT_OK(heap.Delete(125));
+
+ // Second lowest priority is now the highest.
+ EXPECT_EQ(heap.Sample().key, 123);
+}
+
+TEST(HeapDistributionTest, UpdateChangesOrder) {
+ HeapDistribution heap;
+
+ TF_EXPECT_OK(heap.Insert(123, 2));
+ TF_EXPECT_OK(heap.Insert(124, 1));
+ TF_EXPECT_OK(heap.Insert(125, 3));
+
+ EXPECT_EQ(heap.Sample().key, 124);
+
+ // Update the current top item.
+ TF_EXPECT_OK(heap.Update(124, 5));
+ EXPECT_EQ(heap.Sample().key, 123);
+
+ // Update another item and check that it is moved to the top.
+ TF_EXPECT_OK(heap.Update(125, 0.5));
+ EXPECT_EQ(heap.Sample().key, 125);
+}
+
+TEST(HeapDistributionTest, Clear) {
+ HeapDistribution heap;
+
+ TF_EXPECT_OK(heap.Insert(123, 2));
+ TF_EXPECT_OK(heap.Insert(124, 1));
+
+ EXPECT_EQ(heap.Sample().key, 124);
+
+  // Clear distribution and insert an item that should otherwise be at the end.
+ heap.Clear();
+ TF_EXPECT_OK(heap.Insert(125, 10));
+ EXPECT_EQ(heap.Sample().key, 125);
+}
+
+TEST(HeapDistributionTest, ProbabilityIsAlwaysOne) {
+ HeapDistribution heap;
+
+ for (int i = 100; i < 150; i++) {
+ TF_EXPECT_OK(heap.Insert(i, i));
+ }
+
+ for (int i = 0; i < 50; i++) {
+ auto sample = heap.Sample();
+ EXPECT_EQ(sample.probability, 1);
+ TF_EXPECT_OK(heap.Delete(sample.key));
+ }
+}
+
+TEST(HeapDistributionTest, Options) {
+ HeapDistribution min_heap;
+ HeapDistribution max_heap(false);
+ EXPECT_THAT(min_heap.options(),
+ testing::EqualsProto("heap: { min_heap: true }"));
+ EXPECT_THAT(max_heap.options(),
+ testing::EqualsProto("heap: { min_heap: false }"));
+}
+
+TEST(HeapDistributionDeathTest, SampleFromEmptyDistribution) {
+ HeapDistribution heap;
+ EXPECT_DEATH(heap.Sample(), "");
+
+ TF_EXPECT_OK(heap.Insert(123, 2));
+ heap.Sample();
+
+ TF_EXPECT_OK(heap.Delete(123));
+ EXPECT_DEATH(heap.Sample(), "");
+}
+
+} // namespace
+} // namespace reverb
+} // namespace deepmind
diff --git a/reverb/cc/distributions/interface.h b/reverb/cc/distributions/interface.h
new file mode 100644
index 0000000..53f61c4
--- /dev/null
+++ b/reverb/cc/distributions/interface.h
@@ -0,0 +1,68 @@
+// Copyright 2019 DeepMind Technologies Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LEARNING_DEEPMIND_REPLAY_REVERB_DISTRIBUTIONS_INTERFACE_H_
+#define LEARNING_DEEPMIND_REPLAY_REVERB_DISTRIBUTIONS_INTERFACE_H_
+
+#include <cstdint>
+#include "reverb/cc/checkpointing/checkpoint.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace deepmind {
+namespace reverb {
+
+// Allows sampling from a population of keys with a specified priority per key.
+//
+// Member methods will not be called concurrently, so implementations do not
+// need to be thread-safe. More to the point, a number of subclasses use bit
+// generators that are not thread-safe, so methods like `Sample` are not
+// thread-safe.
+class KeyDistributionInterface {
+ public:
+ using Key = uint64_t;
+
+ struct KeyWithProbability {
+ Key key;
+ double probability;
+ };
+
+ virtual ~KeyDistributionInterface() = default;
+
+ // Deletes a key and the associated priority. Returns an error if the key does
+ // not exist.
+ virtual tensorflow::Status Delete(Key key) = 0;
+
+ // Inserts a key and associated priority. Returns an error without any change
+ // if the key already exists.
+ virtual tensorflow::Status Insert(Key key, double priority) = 0;
+
+ // Updates a key and associated priority. Returns an error if the key does
+ // not exist.
+ virtual tensorflow::Status Update(Key key, double priority) = 0;
+
+ // Samples a key. Must contain keys when this is called.
+ virtual KeyWithProbability Sample() = 0;
+
+ // Clear the distribution of all data.
+ virtual void Clear() = 0;
+
+ // Options for dynamically constructing the distribution. Required when
+ // reconstructing class from checkpoint. Also used to query table metadata.
+ virtual KeyDistributionOptions options() const = 0;
+};
+
+} // namespace reverb
+} // namespace deepmind
+
+#endif // LEARNING_DEEPMIND_REPLAY_REVERB_DISTRIBUTIONS_INTERFACE_H_
diff --git a/reverb/cc/distributions/lifo.cc b/reverb/cc/distributions/lifo.cc
new file mode 100644
index 0000000..fe88785
--- /dev/null
+++ b/reverb/cc/distributions/lifo.cc
@@ -0,0 +1,72 @@
+// Copyright 2019 DeepMind Technologies Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "reverb/cc/distributions/lifo.h"
+
+#include "reverb/cc/checkpointing/checkpoint.pb.h"
+#include "reverb/cc/platform/logging.h"
+#include "reverb/cc/schema.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace deepmind {
+namespace reverb {
+
+tensorflow::Status LifoDistribution::Delete(KeyDistributionInterface::Key key) {
+ auto it = key_to_iterator_.find(key);
+ if (it == key_to_iterator_.end())
+ return tensorflow::errors::InvalidArgument(
+ absl::StrCat("Key ", key, " not found in distribution."));
+ keys_.erase(it->second);
+ key_to_iterator_.erase(it);
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status LifoDistribution::Insert(KeyDistributionInterface::Key key,
+ double priority) {
+ if (key_to_iterator_.find(key) != key_to_iterator_.end()) {
+ return tensorflow::errors::InvalidArgument(
+ absl::StrCat("Key ", key, " already exists in distribution."));
+ }
+ key_to_iterator_.emplace(key, keys_.emplace(keys_.begin(), key));
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status LifoDistribution::Update(KeyDistributionInterface::Key key,
+ double priority) {
+ if (key_to_iterator_.find(key) == key_to_iterator_.end()) {
+ return tensorflow::errors::InvalidArgument(
+ absl::StrCat("Key ", key, " not found in distribution."));
+ }
+ return tensorflow::Status::OK();
+}
+
+KeyDistributionInterface::KeyWithProbability LifoDistribution::Sample() {
+ REVERB_CHECK(!keys_.empty());
+ return {keys_.front(), 1.};
+}
+
+void LifoDistribution::Clear() {
+ keys_.clear();
+ key_to_iterator_.clear();
+}
+
+KeyDistributionOptions LifoDistribution::options() const {
+ KeyDistributionOptions options;
+ options.set_lifo(true);
+ return options;
+}
+
+} // namespace reverb
+} // namespace deepmind
diff --git a/reverb/cc/distributions/lifo.h b/reverb/cc/distributions/lifo.h
new file mode 100644
index 0000000..3850471
--- /dev/null
+++ b/reverb/cc/distributions/lifo.h
@@ -0,0 +1,56 @@
+// Copyright 2019 DeepMind Technologies Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LEARNING_DEEPMIND_REPLAY_REVERB_DISTRIBUTIONS_LIFO_H_
+#define LEARNING_DEEPMIND_REPLAY_REVERB_DISTRIBUTIONS_LIFO_H_
+
+#include <list>
+
+#include "absl/container/flat_hash_map.h"
+#include "reverb/cc/checkpointing/checkpoint.pb.h"
+#include "reverb/cc/distributions/interface.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace deepmind {
+namespace reverb {
+
+// Lifo sampling. We ignore all priority values in the calls. Sample() always
+// returns the key that was inserted last until this key is deleted. All
+// operations take O(1) time. See KeyDistributionInterface for documentation
+// about the methods.
+class LifoDistribution : public KeyDistributionInterface {
+ public:
+ tensorflow::Status Delete(Key key) override;
+
+ // The priority is ignored.
+ tensorflow::Status Insert(Key key, double priority) override;
+
+ // This is a no-op but will return an error if the key does not exist.
+ tensorflow::Status Update(Key key, double priority) override;
+
+ KeyWithProbability Sample() override;
+
+ void Clear() override;
+
+ KeyDistributionOptions options() const override;
+
+ private:
+  std::list<Key> keys_;
+  absl::flat_hash_map<Key, std::list<Key>::iterator> key_to_iterator_;
+};
+
+} // namespace reverb
+} // namespace deepmind
+
+#endif // LEARNING_DEEPMIND_REPLAY_REVERB_DISTRIBUTIONS_LIFO_H_
diff --git a/reverb/cc/distributions/lifo_test.cc b/reverb/cc/distributions/lifo_test.cc
new file mode 100644
index 0000000..e8d2b10
--- /dev/null
+++ b/reverb/cc/distributions/lifo_test.cc
@@ -0,0 +1,88 @@
+// Copyright 2019 DeepMind Technologies Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "reverb/cc/distributions/lifo.h"
+
+#include <cstdint>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "reverb/cc/schema.pb.h"
+#include "reverb/cc/testing/proto_test_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace deepmind {
+namespace reverb {
+namespace {
+
+TEST(LifoTest, ReturnValueSantiyChecks) {
+ LifoDistribution lifo;
+
+ // Non existent keys cannot be deleted or updated.
+ EXPECT_EQ(lifo.Delete(123).code(), tensorflow::error::INVALID_ARGUMENT);
+ EXPECT_EQ(lifo.Update(123, 4).code(), tensorflow::error::INVALID_ARGUMENT);
+
+ // Keys cannot be inserted twice.
+ TF_EXPECT_OK(lifo.Insert(123, 4));
+ EXPECT_THAT(lifo.Insert(123, 4).code(), tensorflow::error::INVALID_ARGUMENT);
+
+ // Existing keys can be updated and sampled.
+ TF_EXPECT_OK(lifo.Update(123, 5));
+ EXPECT_EQ(lifo.Sample().key, 123);
+
+ // Existing keys cannot be deleted twice.
+ TF_EXPECT_OK(lifo.Delete(123));
+ EXPECT_THAT(lifo.Delete(123).code(), tensorflow::error::INVALID_ARGUMENT);
+}
+
+TEST(LifoTest, MatchesLifoOrdering) {
+ int64_t kItems = 100;
+
+ LifoDistribution lifo;
+ // Insert items.
+ for (int i = 0; i < kItems; i++) {
+ TF_EXPECT_OK(lifo.Insert(i, 0));
+ }
+ // Delete every 10th item.
+ for (int i = 0; i < kItems; i++) {
+ if (i % 10 == 0) TF_EXPECT_OK(lifo.Delete(i));
+ }
+
+ for (int i = kItems - 1; i >= 0; i--) {
+ if (i % 10 == 0) continue;
+ KeyDistributionInterface::KeyWithProbability sample = lifo.Sample();
+ EXPECT_EQ(sample.key, i);
+ EXPECT_EQ(sample.probability, 1);
+ TF_EXPECT_OK(lifo.Delete(sample.key));
+ }
+}
+
+TEST(LifoTest, Options) {
+ LifoDistribution lifo;
+ EXPECT_THAT(lifo.options(), testing::EqualsProto("lifo: true"));
+}
+
+TEST(LifoDeathTest, ClearThenSample) {
+ LifoDistribution lifo;
+ for (int i = 0; i < 100; i++) {
+ TF_EXPECT_OK(lifo.Insert(i, i));
+ }
+ lifo.Sample();
+ lifo.Clear();
+ EXPECT_DEATH(lifo.Sample(), "");
+}
+
+} // namespace
+} // namespace reverb
+} // namespace deepmind
diff --git a/reverb/cc/distributions/prioritized.cc b/reverb/cc/distributions/prioritized.cc
new file mode 100644
index 0000000..4a0f0f3
--- /dev/null
+++ b/reverb/cc/distributions/prioritized.cc
@@ -0,0 +1,181 @@
+// Copyright 2019 DeepMind Technologies Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "reverb/cc/distributions/prioritized.h"
+
+#include <cmath>
+#include <cstddef>
+
+#include "absl/random/distributions.h"
+#include "reverb/cc/platform/logging.h"
+#include "reverb/cc/schema.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace deepmind {
+namespace reverb {
+namespace {
+
+// A priority of zero should correspond to zero probability, even if the
+// priority exponent is zero. So this modified version of std::pow is used to
+// turn priorities into weights. Expects base and exponent to be non-negative.
+inline double power(double base, double exponent) {
+ return base == 0. ? 0. : std::pow(base, exponent);
+}
+
+tensorflow::Status CheckValidPriority(double priority) {
+ if (std::isnan(priority))
+ return tensorflow::errors::InvalidArgument("Priority must not be NaN.");
+ if (priority < 0)
+ return tensorflow::errors::InvalidArgument(
+ "Priority must not be negative.");
+ return tensorflow::Status::OK();
+}
+
+} // namespace
+
+PrioritizedDistribution::PrioritizedDistribution(double priority_exponent)
+ : priority_exponent_(priority_exponent), capacity_(std::pow(2, 17)) {
+ REVERB_CHECK_GE(priority_exponent_, 0);
+ sum_tree_.resize(capacity_);
+}
+
+tensorflow::Status PrioritizedDistribution::Delete(Key key) {
+ const size_t last_index = key_to_index_.size() - 1;
+ const auto it = key_to_index_.find(key);
+ if (it == key_to_index_.end())
+ return tensorflow::errors::InvalidArgument(
+ absl::StrCat("Key ", key, " not found in distribution."));
+ const size_t index = it->second;
+
+ if (index != last_index) {
+ // Replace the element that we want to remove with the last element.
+ SetNode(index, NodeValue(last_index));
+ const Key last_key = sum_tree_[last_index].key;
+ sum_tree_[index].key = last_key;
+ key_to_index_[last_key] = index;
+ }
+
+ SetNode(last_index, 0);
+ key_to_index_.erase(it); // Note that this must occur after SetNode.
+
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status PrioritizedDistribution::Insert(Key key, double priority) {
+ TF_RETURN_IF_ERROR(CheckValidPriority(priority));
+ const size_t index = key_to_index_.size();
+ if (index == capacity_) {
+ capacity_ *= 2;
+ sum_tree_.resize(capacity_);
+ }
+ if (!key_to_index_.try_emplace(key, index).second) {
+ return tensorflow::errors::InvalidArgument(
+ absl::StrCat("Key ", key, " already exists in distribution."));
+ }
+ sum_tree_[index].key = key;
+ REVERB_CHECK_EQ(sum_tree_[index].sum, 0);
+ SetNode(index, power(priority, priority_exponent_));
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status PrioritizedDistribution::Update(Key key, double priority) {
+ TF_RETURN_IF_ERROR(CheckValidPriority(priority));
+ const auto it = key_to_index_.find(key);
+ if (it == key_to_index_.end()) {
+ return tensorflow::errors::InvalidArgument(
+ absl::StrCat("Key ", key, " not found in distribution."));
+ }
+ SetNode(it->second, power(priority, priority_exponent_));
+ return tensorflow::Status::OK();
+}
+
+KeyDistributionInterface::KeyWithProbability PrioritizedDistribution::Sample() {
+ const size_t size = key_to_index_.size();
+ REVERB_CHECK_NE(size, 0);
+
+ // This should never be called concurrently from multiple threads.
+  const double target = absl::Uniform<double>(bit_gen_, 0, 1);
+ const double total_weight = sum_tree_[0].sum;
+
+ // All keys have zero priority so treat as if uniformly sampling.
+ if (total_weight == 0) {
+    const size_t pos = static_cast<size_t>(target * size);
+ return {sum_tree_[pos].key, 1. / size};
+ }
+
+ // We begin traversing the `sum_tree_` from the root to the children in order
+ // to find the `index` corresponding to the sampled `target_weight`.
+ size_t index = 0;
+ double target_weight = target * total_weight;
+ while (true) {
+ // Go to the left sub tree if it contains our sampled `target_weight`.
+ const size_t left_index = 2 * index + 1;
+ const double left_sum = NodeSum(left_index);
+ if (target_weight < left_sum) {
+ index = left_index;
+ continue;
+ }
+ target_weight -= left_sum;
+ // Go to the right sub tree if it contains our sampled `target_weight`.
+ const size_t right_index = 2 * index + 2;
+ const double right_sum = NodeSum(right_index);
+ if (target_weight < right_sum) {
+ index = right_index;
+ continue;
+ }
+ target_weight -= right_sum;
+ // Otherwise it is the current index.
+ break;
+ }
+ REVERB_CHECK_LT(index, size);
+ const double picked_weight = NodeValue(index);
+ REVERB_CHECK_LT(target_weight, picked_weight);
+ return {sum_tree_[index].key, picked_weight / total_weight};
+}
+
+void PrioritizedDistribution::Clear() {
+ for (size_t i = 0; i < key_to_index_.size(); ++i) {
+ sum_tree_[i].sum = 0;
+ }
+ key_to_index_.clear();
+}
+
+KeyDistributionOptions PrioritizedDistribution::options() const {
+ KeyDistributionOptions options;
+ options.mutable_prioritized()->set_priority_exponent(priority_exponent_);
+ return options;
+}
+
+double PrioritizedDistribution::NodeValue(size_t index) const {
+ const size_t left_index = 2 * index + 1;
+ const size_t right_index = 2 * index + 2;
+ return sum_tree_[index].sum - NodeSum(left_index) - NodeSum(right_index);
+}
+
+double PrioritizedDistribution::NodeSum(size_t index) const {
+ return index < key_to_index_.size() ? sum_tree_[index].sum : 0;
+}
+
+void PrioritizedDistribution::SetNode(size_t index, double value) {
+ double difference = value - NodeValue(index);
+ sum_tree_[index].sum += difference;
+ while (index != 0) {
+ index = (index - 1) / 2;
+ sum_tree_[index].sum += difference;
+ }
+}
+
+} // namespace reverb
+} // namespace deepmind
diff --git a/reverb/cc/distributions/prioritized.h b/reverb/cc/distributions/prioritized.h
new file mode 100644
index 0000000..c9106cb
--- /dev/null
+++ b/reverb/cc/distributions/prioritized.h
@@ -0,0 +1,108 @@
+// Copyright 2019 DeepMind Technologies Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LEARNING_DEEPMIND_REPLAY_REVERB_DISTRIBUTIONS_PRIORITIZED_H_
+#define LEARNING_DEEPMIND_REPLAY_REVERB_DISTRIBUTIONS_PRIORITIZED_H_
+
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/random/random.h"
+#include "reverb/cc/checkpointing/checkpoint.pb.h"
+#include "reverb/cc/distributions/interface.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace deepmind {
+namespace reverb {
+
+// This an implementation of a categorical distribution that allows incremental
+// changes to the keys to be made efficiently. The probability of sampling a key
+// is proportional to its priority raised to configurable exponent.
+//
+// Since the priorities and probabilities are stored as doubles, numerical
+// rounding errors may be introduced especially when the relative size of
+// probabilities for keys is large. Ideally when using this class priorities are
+// roughly the same scale and the priority exponent is not large, e.g. less than
+// 2.
+//
+// This was forked from:
+// ## proportional_picker.h
+//
+class PrioritizedDistribution : public KeyDistributionInterface {
+ public:
+  explicit PrioritizedDistribution(double priority_exponent);
+
+  // O(log n) time.
+  tensorflow::Status Delete(Key key) override;
+
+  // The priority must be non-negative. O(log n) time.
+  tensorflow::Status Insert(Key key, double priority) override;
+
+  // The priority must be non-negative. O(log n) time.
+  tensorflow::Status Update(Key key, double priority) override;
+
+  // O(log n) time.
+  KeyWithProbability Sample() override;
+
+  // O(n) time.
+  void Clear() override;
+
+  KeyDistributionOptions options() const override;
+
+ private:
+  struct Node {
+    Key key;
+    // Sum of the exponentiated priority of this node and all its descendants.
+    // This includes the entire sub tree with inner and leaf nodes.
+    // `NodeValue()` can be used to get the exponentiated priority of a node
+    // without its children.
+    double sum = 0;
+  };
+
+  // Gets the individual value of a node in `sum_tree_` without the summed up
+  // value of all its descendants.
+  double NodeValue(size_t index) const;
+
+  // Sum of the exponentiated priority of this node and all its descendants.
+  // If the index is out of bounds, then 0 is returned.
+  double NodeSum(size_t index) const;
+
+  // Sets the individual value of a node in the `sum_tree_`. This does not
+  // include the value of the descendants.
+  void SetNode(size_t index, double value);
+
+  // Controls the degree of prioritization. Priorities are raised to this
+  // exponent before adding them to the `SumTree` as weights. A non-negative
+  // number where a value of zero corresponds to each key having the same
+  // probability (except for keys with zero priority).
+  const double priority_exponent_;
+
+  // Capacity of the summary tree. Starts at ~130000 and grows exponentially.
+  size_t capacity_;
+
+  // A tree stored as a flat vector where each node is the sum of its children
+  // plus its own exponentiated priority.
+  std::vector<Node> sum_tree_;
+
+  // Maps a key to the index where this key can be found in `sum_tree_`.
+  absl::flat_hash_map<Key, size_t> key_to_index_;
+
+  // Used for sampling, not thread-safe.
+  absl::BitGen bit_gen_;
+};
+
+} // namespace reverb
+} // namespace deepmind
+
+#endif // LEARNING_DEEPMIND_REPLAY_REVERB_DISTRIBUTIONS_PRIORITIZED_H_
diff --git a/reverb/cc/distributions/prioritized_test.cc b/reverb/cc/distributions/prioritized_test.cc
new file mode 100644
index 0000000..7f631de
--- /dev/null
+++ b/reverb/cc/distributions/prioritized_test.cc
@@ -0,0 +1,151 @@
+// Copyright 2019 DeepMind Technologies Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "reverb/cc/distributions/prioritized.h"
+
+#include <cmath>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/random/distributions.h"
+#include "absl/random/random.h"
+#include "reverb/cc/schema.pb.h"
+#include "reverb/cc/testing/proto_test_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace deepmind {
+namespace reverb {
+namespace {
+
+const double kInitialPriorityExponent = 1;
+
+// Validates the error contracts of Delete/Insert/Update/Sample.
+TEST(PrioritizedTest, ReturnValueSanityChecks) {
+  PrioritizedDistribution prioritized(kInitialPriorityExponent);
+
+  // Non existent keys cannot be deleted or updated.
+  EXPECT_EQ(prioritized.Delete(123).code(),
+            tensorflow::error::INVALID_ARGUMENT);
+  EXPECT_EQ(prioritized.Update(123, 4).code(),
+            tensorflow::error::INVALID_ARGUMENT);
+
+  // Keys cannot be inserted twice.
+  TF_EXPECT_OK(prioritized.Insert(123, 4));
+  EXPECT_EQ(prioritized.Insert(123, 4).code(),
+            tensorflow::error::INVALID_ARGUMENT);
+
+  // Existing keys can be updated and sampled.
+  TF_EXPECT_OK(prioritized.Update(123, 5));
+  EXPECT_EQ(prioritized.Sample().key, 123);
+
+  // Negative priorities are not allowed.
+  EXPECT_EQ(prioritized.Update(123, -1).code(),
+            tensorflow::error::INVALID_ARGUMENT);
+  EXPECT_EQ(prioritized.Insert(456, -1).code(),
+            tensorflow::error::INVALID_ARGUMENT);
+
+  // NaN priorities are not allowed.
+  EXPECT_EQ(prioritized.Update(123, NAN).code(),
+            tensorflow::error::INVALID_ARGUMENT);
+  EXPECT_EQ(prioritized.Insert(456, NAN).code(),
+            tensorflow::error::INVALID_ARGUMENT);
+
+  // Existing keys cannot be deleted twice.
+  TF_EXPECT_OK(prioritized.Delete(123));
+  EXPECT_EQ(prioritized.Delete(123).code(),
+            tensorflow::error::INVALID_ARGUMENT);
+}
+
+// With every priority equal to zero, sampling should degenerate to uniform.
+TEST(PrioritizedTest, AllZeroPrioritiesResultsInUniformSampling) {
+  const int64_t kItems = 100;
+  const int64_t kSamples = 1000000;
+  const double expected_probability = 1. / static_cast<double>(kItems);
+
+  PrioritizedDistribution prioritized(kInitialPriorityExponent);
+  for (int i = 0; i < kItems; i++) {
+    TF_EXPECT_OK(prioritized.Insert(i, 0));
+  }
+  std::vector<int64_t> counts(kItems);
+  for (int i = 0; i < kSamples; i++) {
+    KeyDistributionInterface::KeyWithProbability sample = prioritized.Sample();
+    EXPECT_EQ(sample.probability, expected_probability);
+    counts[sample.key]++;
+  }
+  // Empirical frequencies should be close to the reported probability.
+  for (int64_t count : counts) {
+    EXPECT_NEAR(static_cast<double>(count) / static_cast<double>(kSamples),
+                expected_probability, 0.05);
+  }
+}
+
+// Inserts keys with priority == key, deletes the first few, then verifies
+// that both the reported probabilities and the empirical sampling frequencies
+// match priority / total_priority.
+TEST(PrioritizedTest, SampledDistributionMatchesProbabilities) {
+  const int kStart = 10;
+  const int kEnd = 100;
+  const int kSamples = 1000000;
+
+  PrioritizedDistribution prioritized(kInitialPriorityExponent);
+  double sum = 0;
+  absl::BitGen bit_gen;
+  for (int i = 0; i < kEnd; i++) {
+    // Half of the time insert with the final priority directly; otherwise
+    // insert with a placeholder priority and set the real one via Update.
+    if (absl::Uniform<double>(bit_gen, 0, 1) < 0.5) {
+      TF_EXPECT_OK(prioritized.Insert(i, i));
+    } else {
+      TF_EXPECT_OK(prioritized.Insert(i, 123));
+      TF_EXPECT_OK(prioritized.Update(i, i));
+    }
+    sum += i;
+  }
+  // Remove the first few items.
+  for (int i = 0; i < kStart; i++) {
+    TF_EXPECT_OK(prioritized.Delete(i));
+    sum -= i;
+  }
+  std::vector<int64_t> counts(kEnd);
+  absl::flat_hash_map<uint64_t, double> probabilities;
+  for (int i = 0; i < kSamples; i++) {
+    KeyDistributionInterface::KeyWithProbability sample = prioritized.Sample();
+    probabilities[sample.key] = sample.probability;
+    counts[sample.key]++;
+    EXPECT_NEAR(sample.probability, sample.key / sum, 0.001);
+  }
+  for (int k = 0; k < kStart; k++) EXPECT_EQ(counts[k], 0);
+  for (int k = kStart; k < kEnd; k++) {
+    EXPECT_NEAR(static_cast<double>(counts[k]) / static_cast<double>(kSamples),
+                probabilities[k], 0.05);
+  }
+}
+
+// Verifies that `options()` round-trips the priority exponent that was
+// passed to the constructor.
+TEST(PrioritizedTest, SetsPriorityExponentInOptions) {
+  PrioritizedDistribution prioritized_a(0.1);
+  PrioritizedDistribution prioritized_b(0.5);
+  EXPECT_THAT(prioritized_a.options(),
+              testing::EqualsProto("prioritized: { priority_exponent: 0.1 } "));
+  EXPECT_THAT(prioritized_b.options(),
+              testing::EqualsProto("prioritized: { priority_exponent: 0.5 } "));
+}
+
+// Sampling from an emptied (cleared) distribution must crash via the
+// REVERB_CHECK in `Sample` rather than return garbage.
+TEST(PrioritizedDeathTest, ClearThenSample) {
+  PrioritizedDistribution prioritized(kInitialPriorityExponent);
+  for (int i = 0; i < 100; i++) {
+    TF_EXPECT_OK(prioritized.Insert(i, i));
+  }
+  prioritized.Sample();
+  prioritized.Clear();
+  EXPECT_DEATH(prioritized.Sample(), "");
+}
+
+} // namespace
+} // namespace reverb
+} // namespace deepmind
diff --git a/reverb/cc/distributions/uniform.cc b/reverb/cc/distributions/uniform.cc
new file mode 100644
index 0000000..0e6ed43
--- /dev/null
+++ b/reverb/cc/distributions/uniform.cc
@@ -0,0 +1,83 @@
+// Copyright 2019 DeepMind Technologies Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "reverb/cc/distributions/uniform.h"
+
+#include "absl/strings/str_cat.h"
+#include "reverb/cc/checkpointing/checkpoint.pb.h"
+#include "reverb/cc/platform/logging.h"
+#include "reverb/cc/schema.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace deepmind {
+namespace reverb {
+
+tensorflow::Status UniformDistribution::Delete(Key key) {
+  // Deleting a key that was never inserted is an error.
+  auto it = key_to_index_.find(key);
+  if (it == key_to_index_.end()) {
+    return tensorflow::errors::InvalidArgument(
+        absl::StrCat("Key ", key, " not found in distribution."));
+  }
+  const size_t index = it->second;
+  key_to_index_.erase(it);
+
+  // Keep `keys_` dense by moving the last element into the freed slot.
+  if (index + 1 != keys_.size()) {
+    const Key moved_key = keys_.back();
+    keys_[index] = moved_key;
+    key_to_index_[moved_key] = index;
+  }
+  keys_.pop_back();
+  return tensorflow::Status::OK();
+}
+
+tensorflow::Status UniformDistribution::Insert(Key key, double priority) {
+  // The priority is ignored; only uniqueness of the key matters.
+  const size_t next_index = keys_.size();
+  const bool inserted = key_to_index_.emplace(key, next_index).second;
+  if (!inserted) {
+    return tensorflow::errors::InvalidArgument(
+        absl::StrCat("Key ", key, " already exists in distribution."));
+  }
+  keys_.push_back(key);
+  return tensorflow::Status::OK();
+}
+
+tensorflow::Status UniformDistribution::Update(Key key, double priority) {
+  // The priority is ignored; we merely validate that the key exists.
+  if (key_to_index_.contains(key)) return tensorflow::Status::OK();
+  return tensorflow::errors::InvalidArgument(
+      absl::StrCat("Key ", key, " not found in distribution."));
+}
+
+// Samples one key uniformly at random. Requires a non-empty distribution.
+KeyDistributionInterface::KeyWithProbability UniformDistribution::Sample() {
+  REVERB_CHECK(!keys_.empty());
+
+  // This code is not thread-safe, because bit_gen_ is not protected by a mutex
+  // and is not itself thread-safe.
+  const size_t index = absl::Uniform<size_t>(bit_gen_, 0, keys_.size());
+  return {keys_[index], 1.0 / static_cast<double>(keys_.size())};
+}
+
+void UniformDistribution::Clear() {
+  // Drop both the dense key list and the reverse index together so the two
+  // structures stay consistent.
+  key_to_index_.clear();
+  keys_.clear();
+}
+
+KeyDistributionOptions UniformDistribution::options() const {
+  // The uniform distribution has no tunable parameters; only flag its kind.
+  KeyDistributionOptions result;
+  result.set_uniform(true);
+  return result;
+}
+
+} // namespace reverb
+} // namespace deepmind
diff --git a/reverb/cc/distributions/uniform.h b/reverb/cc/distributions/uniform.h
new file mode 100644
index 0000000..623892b
--- /dev/null
+++ b/reverb/cc/distributions/uniform.h
@@ -0,0 +1,60 @@
+// Copyright 2019 DeepMind Technologies Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LEARNING_DEEPMIND_REPLAY_REVERB_DISTRIBUTIONS_UNIFORM_H_
+#define LEARNING_DEEPMIND_REPLAY_REVERB_DISTRIBUTIONS_UNIFORM_H_
+
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/random/random.h"
+#include "reverb/cc/checkpointing/checkpoint.pb.h"
+#include "reverb/cc/distributions/interface.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace deepmind {
+namespace reverb {
+
+// Uniform sampling. We ignore all priority values in the calls. All operations
+// take O(1) time. See KeyDistributionInterface for documentation about the
+// methods.
+class UniformDistribution : public KeyDistributionInterface {
+ public:
+  tensorflow::Status Delete(Key key) override;
+
+  tensorflow::Status Insert(Key key, double priority) override;
+
+  tensorflow::Status Update(Key key, double priority) override;
+
+  KeyWithProbability Sample() override;
+
+  void Clear() override;
+
+  KeyDistributionOptions options() const override;
+
+ private:
+  // All keys.
+  std::vector<Key> keys_;
+
+  // Maps a key to the index where this key can be found in `keys_`.
+  absl::flat_hash_map<Key, size_t> key_to_index_;
+
+  // Used for sampling, not thread-safe.
+  absl::BitGen bit_gen_;
+};
+
+} // namespace reverb
+} // namespace deepmind
+
+#endif // LEARNING_DEEPMIND_REPLAY_REVERB_DISTRIBUTIONS_UNIFORM_H_
diff --git a/reverb/cc/distributions/uniform_test.cc b/reverb/cc/distributions/uniform_test.cc
new file mode 100644
index 0000000..8000dbe
--- /dev/null
+++ b/reverb/cc/distributions/uniform_test.cc
@@ -0,0 +1,87 @@
+// Copyright 2019 DeepMind Technologies Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "reverb/cc/distributions/uniform.h"
+
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "reverb/cc/schema.pb.h"
+#include "reverb/cc/testing/proto_test_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace deepmind {
+namespace reverb {
+namespace {
+
+TEST(UniformTest, ReturnValueSantiyChecks) {
+ UniformDistribution uniform;
+
+ // Non existent keys cannot be deleted or updated.
+ EXPECT_EQ(uniform.Delete(123).code(), tensorflow::error::INVALID_ARGUMENT);
+ EXPECT_EQ(uniform.Update(123, 4).code(), tensorflow::error::INVALID_ARGUMENT);
+
+ // Keys cannot be inserted twice.
+ TF_EXPECT_OK(uniform.Insert(123, 4));
+ EXPECT_EQ(uniform.Insert(123, 4).code(), tensorflow::error::INVALID_ARGUMENT);
+
+ // Existing keys can be updated and sampled.
+ TF_EXPECT_OK(uniform.Update(123, 5));
+ EXPECT_EQ(uniform.Sample().key, 123);
+
+ // Existing keys cannot be deleted twice.
+ TF_EXPECT_OK(uniform.Delete(123));
+ EXPECT_EQ(uniform.Delete(123).code(), tensorflow::error::INVALID_ARGUMENT);
+}
+
+// Sampling frequencies should converge to 1/N regardless of priorities.
+TEST(UniformTest, MatchesUniformDistribution) {
+  const int64_t kItems = 100;
+  const int64_t kSamples = 1000000;
+  const double expected_probability = 1. / static_cast<double>(kItems);
+
+  UniformDistribution uniform;
+  for (int i = 0; i < kItems; i++) {
+    TF_EXPECT_OK(uniform.Insert(i, 0));
+  }
+  std::vector<int64_t> counts(kItems);
+  for (int i = 0; i < kSamples; i++) {
+    KeyDistributionInterface::KeyWithProbability sample = uniform.Sample();
+    EXPECT_EQ(sample.probability, expected_probability);
+    counts[sample.key]++;
+  }
+  for (int64_t count : counts) {
+    EXPECT_NEAR(static_cast<double>(count) / static_cast<double>(kSamples),
+                expected_probability, 0.05);
+  }
+}
+
+// `options()` should mark the distribution as uniform.
+TEST(UniformTest, Options) {
+  UniformDistribution uniform;
+  EXPECT_THAT(uniform.options(), testing::EqualsProto("uniform: true"));
+}
+
+// Sampling from an emptied (cleared) distribution must crash via the
+// REVERB_CHECK in `Sample` rather than return garbage.
+TEST(UniformDeathTest, ClearThenSample) {
+  UniformDistribution uniform;
+  for (int i = 0; i < 100; i++) {
+    TF_EXPECT_OK(uniform.Insert(i, i));
+  }
+  uniform.Sample();
+  uniform.Clear();
+  EXPECT_DEATH(uniform.Sample(), "");
+}
+
+} // namespace
+} // namespace reverb
+} // namespace deepmind
diff --git a/reverb/cc/ops/BUILD b/reverb/cc/ops/BUILD
new file mode 100644
index 0000000..8ed6418
--- /dev/null
+++ b/reverb/cc/ops/BUILD
@@ -0,0 +1,41 @@
+load(
+    "//reverb/cc/platform:build_rules.bzl",
+    "reverb_absl_deps",
+    "reverb_gen_op_wrapper_py",
+    "reverb_kernel_library",
+    "reverb_tf_ops_visibility",
+)
+
+package(default_visibility = reverb_tf_ops_visibility())
+
+licenses(["notice"])
+
+# Op + kernel registrations for the resource-handle based client ops
+# (ReverbClient, ReverbClientInsert, ReverbClientSample, ...). See client.cc.
+reverb_kernel_library(
+    name = "client",
+    srcs = ["client.cc"],
+    deps = [
+        "//reverb/cc:replay_client",
+    ] + reverb_absl_deps(),
+)
+
+# Op + kernel registration for ReverbDataset. See dataset.cc.
+reverb_kernel_library(
+    name = "dataset",
+    srcs = ["dataset.cc"],
+    deps = [
+        "//reverb/cc:replay_client",
+        "//reverb/cc:replay_sampler",
+        "//reverb/cc/platform:logging",
+    ],
+)
+
+# Generated Python wrappers for the op libraries above.
+reverb_gen_op_wrapper_py(
+    name = "gen_client_ops",
+    out = "gen_client_ops.py",
+    kernel_lib = ":client",
+)
+
+reverb_gen_op_wrapper_py(
+    name = "gen_dataset_op",
+    out = "gen_dataset_op.py",
+    kernel_lib = ":dataset",
+)
diff --git a/reverb/cc/ops/client.cc b/reverb/cc/ops/client.cc
new file mode 100644
index 0000000..4105069
--- /dev/null
+++ b/reverb/cc/ops/client.cc
@@ -0,0 +1,284 @@
+// Copyright 2019 DeepMind Technologies Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include <utility>  // NOTE(review): header name lost in extraction; confirm against upstream.
+#include "absl/strings/str_cat.h"
+#include "reverb/cc/replay_client.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace deepmind {
+namespace reverb {
+namespace {
+
+using ::tensorflow::tstring;
+using ::tensorflow::errors::InvalidArgument;
+
+// Op registrations. The matching kernels are registered for DEVICE_CPU at the
+// bottom of this file.
+REGISTER_OP("ReverbClient")
+    .Output("handle: resource")
+    .Attr("server_address: string")
+    .Attr("container: string = ''")
+    .Attr("shared_name: string = ''")
+    .SetIsStateful()
+    .SetShapeFn(tensorflow::shape_inference::ScalarShape)
+    .Doc(R"doc(
+Constructs a `ClientResource` that constructs a `ReplayClient` connected to
+`server_address`. The resource allows ops to share the stub across calls.
+)doc");
+
+REGISTER_OP("ReverbClientSample")
+    .Attr("Toutput_list: list(type) >= 0")
+    .Input("handle: resource")
+    .Input("table: string")
+    .Output("key: uint64")
+    .Output("probability: double")
+    .Output("table_size: int64")
+    .Output("outputs: Toutput_list")
+    .Doc(R"doc(
+Blocking call to sample a single item from table `table` using shared resource.
+A `SampleStream`-stream is opened between the client and the server and when
+the one sample has been received, the stream is closed.
+
+Prefer to use `ReverbDataset` when requesting more than one sample to avoid
+opening and closing the stream with each call.
+)doc");
+
+REGISTER_OP("ReverbClientUpdatePriorities")
+    .Input("handle: resource")
+    .Input("table: string")
+    .Input("keys: uint64")
+    .Input("priorities: double")
+    .Doc(R"doc(
+Blocking call to update the priorities of a collection of items. Keys that could
+not be found in table `table` on server are ignored and does not impact the rest
+of the request.
+)doc");
+
+REGISTER_OP("ReverbClientInsert")
+    .Attr("T: list(type) >= 0")
+    .Input("handle: resource")
+    .Input("data: T")
+    .Input("tables: string")
+    .Input("priorities: double")
+    .Doc(R"doc(
+Blocking call to insert a single trajectory into one or more tables. The data
+is treated as an episode constituting of a single timestep. Note that this mean
+that when the item is sampled, it will be returned as a sequence of length 1,
+containing `data`.
+)doc");
+
+// Resource wrapping a `ReplayClient` so that the underlying connection can be
+// shared by several ops through a resource handle.
+class ClientResource : public tensorflow::ResourceBase {
+ public:
+  explicit ClientResource(const std::string& server_address)
+      : tensorflow::ResourceBase(),
+        client_(server_address),
+        server_address_(server_address) {}
+
+  std::string DebugString() const override {
+    return tensorflow::strings::StrCat("Client with server address: ",
+                                       server_address_);
+  }
+
+  // Borrowed pointer; the resource retains ownership of the client.
+  ReplayClient* client() { return &client_; }
+
+ private:
+  ReplayClient client_;
+  std::string server_address_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(ClientResource);
+};
+
+// Resource op that creates (or looks up) a `ClientResource` and outputs its
+// handle. Implements the `ReverbClient` op.
+class ClientHandleOp : public tensorflow::ResourceOpKernel<ClientResource> {
+ public:
+  explicit ClientHandleOp(tensorflow::OpKernelConstruction* context)
+      : tensorflow::ResourceOpKernel<ClientResource>(context) {
+    OP_REQUIRES_OK(context,
+                   context->GetAttr("server_address", &server_address_));
+  }
+
+ private:
+  // Called by ResourceOpKernel on first use; ownership of the new resource
+  // passes to the resource manager.
+  tensorflow::Status CreateResource(ClientResource** ret) override {
+    *ret = new ClientResource(server_address_);
+    return tensorflow::Status::OK();
+  }
+
+  std::string server_address_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(ClientHandleOp);
+};
+
+// TODO(b/154929314): Change this to an async op.
+class SampleOp : public tensorflow::OpKernel {
+ public:
+ explicit SampleOp(tensorflow::OpKernelConstruction* context)
+ : OpKernel(context) {}
+
+ void Compute(tensorflow::OpKernelContext* context) override {
+ ClientResource* resource;
+ OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
+ &resource));
+
+ const tensorflow::Tensor* table_tensor;
+ OP_REQUIRES_OK(context, context->input("table", &table_tensor));
+ std::string table = table_tensor->scalar()();
+
+ std::vector sample;
+ std::unique_ptr sampler;
+
+ ReplaySampler::Options options;
+ options.max_samples = 1;
+ options.max_in_flight_samples_per_worker = 1;
+
+ OP_REQUIRES_OK(context,
+ resource->client()->NewSampler(table, options, &sampler));
+ OP_REQUIRES_OK(context, sampler->GetNextTimestep(&sample, nullptr));
+ OP_REQUIRES(context, sample.size() == context->num_outputs(),
+ InvalidArgument(
+ "Number of tensors in the replay sample did not match the "
+ "expected count."));
+
+ for (int i = 0; i < sample.size(); i++) {
+ tensorflow::Tensor* tensor;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(i, sample[i].shape(), &tensor));
+ *tensor = std::move(sample[i]);
+ }
+ }
+
+ TF_DISALLOW_COPY_AND_ASSIGN(SampleOp);
+};
+
+// Implements `ReverbClientUpdatePriorities`: updates the priorities of
+// `keys` in `table`, retrying while the server is unavailable.
+class UpdatePrioritiesOp : public tensorflow::OpKernel {
+ public:
+  explicit UpdatePrioritiesOp(tensorflow::OpKernelConstruction* context)
+      : OpKernel(context) {}
+
+  void Compute(tensorflow::OpKernelContext* context) override {
+    ClientResource* resource;
+    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
+                                           &resource));
+
+    const tensorflow::Tensor* table;
+    OP_REQUIRES_OK(context, context->input("table", &table));
+    const tensorflow::Tensor* keys;
+    OP_REQUIRES_OK(context, context->input("keys", &keys));
+    const tensorflow::Tensor* priorities;
+    OP_REQUIRES_OK(context, context->input("priorities", &priorities));
+
+    OP_REQUIRES(
+        context, keys->dims() == 1,
+        InvalidArgument("Tensors `keys` and `priorities` must be of rank 1."));
+    OP_REQUIRES(context, keys->shape() == priorities->shape(),
+                InvalidArgument(
+                    "Tensors `keys` and `priorities` do not match in shape."));
+
+    std::string table_str = table->scalar<tstring>()();
+    std::vector<KeyWithPriority> updates;
+    for (int i = 0; i < keys->dim_size(0); i++) {
+      KeyWithPriority update;
+      update.set_key(keys->flat<tensorflow::uint64>()(i));
+      update.set_priority(priorities->flat<double>()(i));
+      updates.push_back(std::move(update));
+    }
+
+    // The call will only fail if the Reverb-server is brought down during an
+    // active call (e.g preempted). When this happens the request is retried and
+    // since MutatePriorities sets `wait_for_ready` the request will no be sent
+    // before the server is brought up again. It is therefore no problem to have
+    // this retry in this tight loop.
+    tensorflow::Status status;
+    do {
+      status = resource->client()->MutatePriorities(table_str, updates, {});
+    } while (tensorflow::errors::IsUnavailable(status) ||
+             tensorflow::errors::IsDeadlineExceeded(status));
+    OP_REQUIRES_OK(context, status);
+  }
+
+  TF_DISALLOW_COPY_AND_ASSIGN(UpdatePrioritiesOp);
+};
+
+// Implements `ReverbClientInsert`: writes `data` as a single-timestep episode
+// and registers it in each of `tables` with the corresponding priority.
+class InsertOp : public tensorflow::OpKernel {
+ public:
+  explicit InsertOp(tensorflow::OpKernelConstruction* context)
+      : OpKernel(context) {}
+
+  void Compute(tensorflow::OpKernelContext* context) override {
+    ClientResource* resource;
+    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
+                                           &resource));
+
+    const tensorflow::Tensor* tables;
+    OP_REQUIRES_OK(context, context->input("tables", &tables));
+    const tensorflow::Tensor* priorities;
+    OP_REQUIRES_OK(context, context->input("priorities", &priorities));
+
+    OP_REQUIRES(context, tables->dims() == 1 && priorities->dims() == 1,
+                InvalidArgument(
+                    "Tensors `tables` and `priorities` must be of rank 1."));
+    OP_REQUIRES(
+        context, tables->shape() == priorities->shape(),
+        InvalidArgument(
+            "Tensors `tables` and `priorities` do not match in shape."));
+
+    tensorflow::OpInputList data;
+    OP_REQUIRES_OK(context, context->input_list("data", &data));
+
+    // TODO(b/154929210): This can probably be avoided.
+    std::vector<tensorflow::Tensor> tensors;
+    tensors.reserve(data.size());
+    for (const auto& i : data) {
+      tensors.push_back(i);
+    }
+
+    std::unique_ptr<ReplayWriter> writer;
+    OP_REQUIRES_OK(context,
+                   resource->client()->NewWriter(1, 1, false, &writer));
+    OP_REQUIRES_OK(context, writer->AppendTimestep(std::move(tensors)));
+
+    // Register the single timestep (sequence length 1) in every target table.
+    auto tables_t = tables->flat<tstring>();
+    auto priorities_t = priorities->flat<double>();
+    for (int i = 0; i < tables->dim_size(0); i++) {
+      OP_REQUIRES_OK(context,
+                     writer->AddPriority(tables_t(i), 1, priorities_t(i)));
+    }
+
+    OP_REQUIRES_OK(context, writer->Close());
+  }
+
+  TF_DISALLOW_COPY_AND_ASSIGN(InsertOp);
+};
+
+// Register all client kernels for CPU; the ops only perform RPC and host-side
+// tensor shuffling.
+REGISTER_KERNEL_BUILDER(Name("ReverbClient").Device(tensorflow::DEVICE_CPU),
+                        ClientHandleOp);
+
+REGISTER_KERNEL_BUILDER(
+    Name("ReverbClientInsert").Device(tensorflow::DEVICE_CPU), InsertOp);
+
+REGISTER_KERNEL_BUILDER(
+    Name("ReverbClientSample").Device(tensorflow::DEVICE_CPU), SampleOp);
+
+REGISTER_KERNEL_BUILDER(
+    Name("ReverbClientUpdatePriorities").Device(tensorflow::DEVICE_CPU),
+    UpdatePrioritiesOp);
+
+} // namespace
+} // namespace reverb
+} // namespace deepmind
diff --git a/reverb/cc/ops/dataset.cc b/reverb/cc/ops/dataset.cc
new file mode 100644
index 0000000..a8f9d38
--- /dev/null
+++ b/reverb/cc/ops/dataset.cc
@@ -0,0 +1,371 @@
+// Copyright 2019 DeepMind Technologies Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tensorflow/core/framework/dataset.h"
+
+#include "reverb/cc/platform/logging.h"
+#include "reverb/cc/replay_client.h"
+#include "reverb/cc/replay_sampler.h"
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
+
+namespace deepmind {
+namespace reverb {
+namespace {
+
+using ::tensorflow::errors::Cancelled;
+using ::tensorflow::errors::FailedPrecondition;
+using ::tensorflow::errors::InvalidArgument;
+using ::tensorflow::errors::Unimplemented;
+
+REGISTER_OP("ReverbDataset")
+ .Attr("server_address: string")
+ .Attr("table: string")
+ .Attr("sequence_length: int = -1")
+ .Attr("emit_timesteps: bool = true")
+ .Attr("max_in_flight_samples_per_worker: int = 100")
+ .Attr("num_workers_per_iterator: int = -1")
+ .Attr("max_samples_per_stream: int = -1")
+ .Attr("dtypes: list(type) >= 1")
+ .Attr("shapes: list(shape) >= 1")
+ .Output("dataset: variant")
+ .SetIsStateful()
+ .SetShapeFn(tensorflow::shape_inference::ScalarShape)
+ .Doc(R"doc(
+Establishes and manages a connection to gRPC ReplayService at `server_address`
+to stream samples from table `table`.
+
+The connection is managed using a single instance of `ReplayClient` (see
+../replay_client.h) owned by the Dataset. From the shared `ReplayClient`, each
+iterator maintains their own `ReplaySampler` (see ../replay_sampler.h), allowing
+for multiple parallel streams using a single connection.
+
+`dtypes` and `shapes` must match the type and shape of a single "timestep"
+within sampled sequences. That is, (key, priority, table_size, ...data passed to
+`ReplayWriter::AppendTimestep` at insertion time). This is the type and shape of
+tensors returned by `GetNextTimestep`.
+
+sequence_length: (Defaults to -1, i.e unknown) The number of timesteps in
+the samples. If set then the length of the received samples are checked against
+this value.
+
+`emit_timesteps` (defaults to true) determines whether individual timesteps or
+complete sequences should be returned from the iterators. When set to false
+(i.e return sequences), `shapes` must have dim[0] equal to `sequence_length`.
+Emitting complete samples is more efficient as it avoids the memcopies involved
+in splitting up a sequence and then batching it up again.
+
+`max_in_flight_samples_per_worker` (defaults to 100) is the maximum number of
+ sampled item allowed to exist in flight (per iterator). See
+`ReplaySampler::Options::max_in_flight_samples_per_worker` for more details.
+
+`num_workers_per_iterator` (defaults to -1, i.e auto selected) is the number of
+worker threads to start per iterator. When the selected table uses a FIFO
+sampler (i.e a queue) then exactly 1 worker must be used to avoid races causing
+invalid ordering of items. For all other samplers, this value should be roughly
+equal to the number of threads available on the CPU.
+
+`max_samples_per_stream` (defaults to -1, i.e auto selected) is the maximum
+number of samples to fetch from a stream before a new call is made. Keeping this
+number low ensures that the data is fetched uniformly from all servers.
+)doc");
+
+class ReverbDatasetOp : public tensorflow::data::DatasetOpKernel {
+ public:
+ explicit ReverbDatasetOp(tensorflow::OpKernelConstruction* ctx)
+ : tensorflow::data::DatasetOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("server_address", &server_address_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("table", &table_));
+ OP_REQUIRES_OK(
+ ctx, ctx->GetAttr("max_in_flight_samples_per_worker",
+ &sampler_options_.max_in_flight_samples_per_worker));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("num_workers_per_iterator",
+ &sampler_options_.num_workers));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("max_samples_per_stream",
+ &sampler_options_.max_samples_per_stream));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("sequence_length", &sequence_length_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("emit_timesteps", &emit_timesteps_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("shapes", &shapes_));
+
+ if (!emit_timesteps_) {
+ for (int i = 0; i < shapes_.size(); i++) {
+ OP_REQUIRES(ctx, shapes_[i].dims() != 0,
+ InvalidArgument(
+ "When emit_timesteps is false, all elements of shapes "
+ "must have "
+ "dim[0] = sequence_length (",
+ sequence_length_, "). Element ", i,
+ " of flattened shapes has rank 0 and thus no dim[0]."));
+
+ OP_REQUIRES(ctx, shapes_[i].dim_size(0) == sequence_length_,
+ InvalidArgument("When emit_timesteps is false, all "
+ "elements of shapes must have "
+ "dim[0] = sequence_length (",
+ sequence_length_, "). Element ", i,
+ " of flattened shapes has dim[0] = ",
+ shapes_[i].dim_size(0), "."));
+ }
+ }
+ }
+
+ void MakeDataset(tensorflow::OpKernelContext* ctx,
+ tensorflow::data::DatasetBase** output) override {
+ *output = new Dataset(ctx, server_address_, dtypes_, shapes_, table_,
+ sampler_options_, sequence_length_, emit_timesteps_);
+ }
+
+ private:
+ class Dataset : public tensorflow::data::DatasetBase {
+ public:
+ Dataset(tensorflow::OpKernelContext* ctx, std::string server_address,
+ tensorflow::DataTypeVector dtypes,
+ std::vector shapes,
+ std::string table, const ReplaySampler::Options& sampler_options,
+ int sequence_length, bool emit_timesteps)
+ : tensorflow::data::DatasetBase(tensorflow::data::DatasetContext(ctx)),
+ server_address_(std::move(server_address)),
+ dtypes_(std::move(dtypes)),
+ shapes_(std::move(shapes)),
+ table_(std::move(table)),
+ sampler_options_(sampler_options),
+ sequence_length_(sequence_length),
+ emit_timesteps_(emit_timesteps),
+ client_(absl::make_unique(server_address_)) {}
+
+ std::unique_ptr MakeIteratorInternal(
+ const std::string& prefix) const override {
+ return absl::make_unique(
+ tensorflow::data::DatasetIterator::Params{
+ this, absl::StrCat(prefix, "::ReverbDataset")},
+ client_.get(), table_, sampler_options_, sequence_length_,
+ emit_timesteps_, dtypes_, shapes_);
+ }
+
+ const tensorflow::DataTypeVector& output_dtypes() const override {
+ return dtypes_;
+ }
+
+ const std::vector& output_shapes()
+ const override {
+ return shapes_;
+ }
+
+ std::string DebugString() const override {
+ return "ReverbDatasetOp::Dataset";
+ }
+
+ tensorflow::Status CheckExternalState() const override {
+ return FailedPrecondition(DebugString(), " depends on external state.");
+ }
+
+ protected:
+ tensorflow::Status AsGraphDefInternal(
+ tensorflow::data::SerializationContext* ctx, DatasetGraphDefBuilder* b,
+ tensorflow::Node** output) const override {
+ tensorflow::AttrValue server_address_attr;
+ tensorflow::AttrValue table_attr;
+ tensorflow::AttrValue max_in_flight_samples_per_worker_attr;
+ tensorflow::AttrValue num_workers_attr;
+ tensorflow::AttrValue max_samples_per_stream_attr;
+ tensorflow::AttrValue sequence_length_attr;
+ tensorflow::AttrValue emit_timesteps_attr;
+ tensorflow::AttrValue dtypes_attr;
+ tensorflow::AttrValue shapes_attr;
+
+ b->BuildAttrValue(server_address_, &server_address_attr);
+ b->BuildAttrValue(table_, &table_attr);
+ b->BuildAttrValue(sampler_options_.max_in_flight_samples_per_worker,
+ &max_in_flight_samples_per_worker_attr);
+ b->BuildAttrValue(sampler_options_.num_workers, &num_workers_attr);
+ b->BuildAttrValue(sampler_options_.max_samples_per_stream,
+ &max_samples_per_stream_attr);
+ b->BuildAttrValue(sequence_length_, &sequence_length_attr);
+ b->BuildAttrValue(emit_timesteps_, &emit_timesteps_attr);
+ b->BuildAttrValue(dtypes_, &dtypes_attr);
+ b->BuildAttrValue(shapes_, &shapes_attr);
+
+ TF_RETURN_IF_ERROR(b->AddDataset(
+ this, {},
+ {
+ {"server_address", server_address_attr},
+ {"table", table_attr},
+ {"max_in_flight_samples_per_worker",
+ max_in_flight_samples_per_worker_attr},
+ {"num_workers_per_iterator", num_workers_attr},
+ {"max_samples_per_stream", max_samples_per_stream_attr},
+ {"sequence_length", sequence_length_attr},
+ {"emit_timesteps", emit_timesteps_attr},
+ {"dtypes", dtypes_attr},
+ {"shapes", shapes_attr},
+ },
+ output));
+
+ return tensorflow::Status::OK();
+ }
+
+ private:
+ class Iterator : public tensorflow::data::DatasetIterator {
+ public:
+ explicit Iterator(
+ const Params& params, ReplayClient* client, const std::string& table,
+ const ReplaySampler::Options& sampler_options, int sequence_length,
+ bool emit_timesteps, const tensorflow::DataTypeVector& dtypes,
+ const std::vector& shapes)
+ : DatasetIterator(params),
+ client_(client),
+ table_(table),
+ sampler_options_(sampler_options),
+ sequence_length_(sequence_length),
+ emit_timesteps_(emit_timesteps),
+ dtypes_(dtypes),
+ shapes_(shapes),
+ step_within_sample_(0) {}
+
+ tensorflow::Status Initialize(
+ tensorflow::data::IteratorContext* ctx) override {
+ // If sequences are emitted then the all shapes will start with the
+ // sequence length. The validation expects the shapes of a single
+ // timestep so if sequences are emitted then we need to trim the leading
+ // dim on all shapes before validating it.
+ auto validation_shapes = shapes_;
+ if (!emit_timesteps_) {
+ for (auto& shape : validation_shapes) {
+ shape.RemoveDim(0);
+ }
+ }
+
+ // TODO(b/154929217): Expose this option so it can be set to infinite
+ // outside of tests.
+ const auto kTimeout = absl::Seconds(30);
+ auto status =
+ client_->NewSampler(table_, sampler_options_,
+ /*validation_dtypes=*/dtypes_,
+ validation_shapes, kTimeout, &sampler_);
+ if (tensorflow::errors::IsDeadlineExceeded(status)) {
+ REVERB_LOG(REVERB_WARNING)
+ << "Unable to validate shapes and dtypes of new sampler for '"
+ << table_ << "' as server could not be reached in time ("
+ << kTimeout
+ << "). We were thus unable to fetch signature from server. The "
+ "sampler will be constructed without validating the dtypes "
+ "and shapes.";
+ return client_->NewSampler(table_, sampler_options_, &sampler_);
+ }
+ return status;
+ }
+
+ tensorflow::Status GetNextInternal(
+ tensorflow::data::IteratorContext* ctx,
+ std::vector* out_tensors,
+ bool* end_of_sequence) override {
+ REVERB_CHECK(sampler_.get() != nullptr)
+ << "Initialize was not called?";
+
+ auto token = ctx->cancellation_manager()->get_cancellation_token();
+ bool registered = ctx->cancellation_manager()->RegisterCallback(
+ token, [&] { sampler_->Close(); });
+ if (!registered) {
+ sampler_->Close();
+ }
+
+ tensorflow::Status status;
+ if (emit_timesteps_) {
+ bool last_timestep = false;
+ status = sampler_->GetNextTimestep(out_tensors, &last_timestep);
+
+ step_within_sample_++;
+
+ if (last_timestep && sequence_length_ > 0 &&
+ step_within_sample_ != sequence_length_) {
+ return InvalidArgument(
+ "Received sequence of invalid length. Expected ",
+ sequence_length_, " steps, got ", step_within_sample_);
+ }
+ if (step_within_sample_ == sequence_length_ && !last_timestep) {
+ return InvalidArgument(
+ "Receieved sequence did not terminate after expected number of "
+ "steps (",
+ sequence_length_, ").");
+ }
+ if (last_timestep) {
+ step_within_sample_ = 0;
+ }
+ } else {
+ status = sampler_->GetNextSample(out_tensors);
+ }
+
+ if (registered &&
+ !ctx->cancellation_manager()->DeregisterCallback(token)) {
+ return Cancelled("Iterator context was cancelled");
+ }
+
+ return status;
+ }
+
+ protected:
+ tensorflow::Status SaveInternal(
+ tensorflow::data::SerializationContext* ctx,
+ tensorflow::data::IteratorStateWriter* writer) override {
+ return Unimplemented("SaveInternal is currently not supported");
+ }
+
+ tensorflow::Status RestoreInternal(
+ tensorflow::data::IteratorContext* ctx,
+ tensorflow::data::IteratorStateReader* reader) override {
+ return Unimplemented("RestoreInternal is currently not supported");
+ }
+
+ private:
+ ReplayClient* client_;
+ const std::string& table_;
+ const ReplaySampler::Options sampler_options_;
+ const int sequence_length_;
+ const bool emit_timesteps_;
+ const tensorflow::DataTypeVector& dtypes_;
+ const std::vector& shapes_;
+ std::unique_ptr sampler_;
+ int step_within_sample_;
+ }; // Iterator.
+
+ const std::string server_address_;
+ const tensorflow::DataTypeVector dtypes_;
+ const std::vector shapes_;
+ const std::string table_;
+ const ReplaySampler::Options sampler_options_;
+ const int sequence_length_;
+ const bool emit_timesteps_;
+ std::unique_ptr client_;
+ }; // Dataset.
+
+ std::string server_address_;
+ std::string table_;
+ ReplaySampler::Options sampler_options_;
+ int sequence_length_;
+ bool emit_timesteps_;
+ tensorflow::DataTypeVector dtypes_;
+ std::vector shapes_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ReverbDatasetOp);
+};
+
// Registers the dataset kernel on CPU.
REGISTER_KERNEL_BUILDER(Name("ReverbDataset").Device(tensorflow::DEVICE_CPU),
                        ReverbDatasetOp);
+
+} // namespace
+} // namespace reverb
+} // namespace deepmind
diff --git a/reverb/cc/platform/BUILD b/reverb/cc/platform/BUILD
new file mode 100644
index 0000000..357dbee
--- /dev/null
+++ b/reverb/cc/platform/BUILD
@@ -0,0 +1,151 @@
+# Platform-specific code for reverb
+
+load(
+ "//reverb/cc/platform:build_rules.bzl",
+ "reverb_absl_deps",
+ "reverb_cc_library",
+ "reverb_cc_test",
+ "reverb_grpc_deps",
+ "reverb_tf_deps",
+)
+
+package(default_visibility = ["//reverb:__subpackages__"])
+
+licenses(["notice"])
+
+exports_files(["LICENSE"])
+
# Checkpointer implementation that persists tables via TFRecord files.
reverb_cc_library(
    name = "tfrecord_checkpointer",
    srcs = ["tfrecord_checkpointer.cc"],
    hdrs = ["tfrecord_checkpointer.h"],
    visibility = ["//reverb:__subpackages__"],
    deps = [
        "//reverb/cc:chunk_store",
        "//reverb/cc:priority_table",
        "//reverb/cc:schema_cc_proto",
        "//reverb/cc/checkpointing:checkpoint_cc_proto",
        "//reverb/cc/checkpointing:interface",
        "//reverb/cc/distributions:fifo",
        "//reverb/cc/distributions:heap",
        "//reverb/cc/distributions:interface",
        "//reverb/cc/distributions:lifo",
        "//reverb/cc/distributions:prioritized",
        "//reverb/cc/distributions:uniform",
        "//reverb/cc/table_extensions:interface",
    ] + reverb_tf_deps() + reverb_absl_deps(),
)

reverb_cc_test(
    name = "tfrecord_checkpointer_test",
    srcs = ["tfrecord_checkpointer_test.cc"],
    deps = [
        ":tfrecord_checkpointer",
        "//reverb/cc:chunk_store",
        "//reverb/cc:priority_table",
        "//reverb/cc/distributions:fifo",
        "//reverb/cc/distributions:heap",
        "//reverb/cc/distributions:prioritized",
        "//reverb/cc/distributions:uniform",
        "//reverb/cc/testing:proto_test_util",
    ] + reverb_tf_deps(),
)

# For each platform API below there is a pair of targets: the "*_hdr" target
# exposes only the header (so //reverb/cc/platform/default implementations can
# depend on the interface without a dependency cycle), while the unsuffixed
# target bundles the header with the default implementation and is what the
# rest of the codebase depends on.
reverb_cc_library(
    name = "checkpointing_hdr",
    hdrs = ["checkpointing.h"],
    deps = [
        "//reverb/cc/checkpointing:interface",
    ] + reverb_absl_deps(),
)

reverb_cc_library(
    name = "checkpointing",
    hdrs = ["checkpointing.h"],
    visibility = ["//reverb:__subpackages__"],
    deps = [
        "//reverb/cc/checkpointing:interface",
        "//reverb/cc/platform/default:checkpointer",
    ] + reverb_absl_deps(),
)

reverb_cc_library(
    name = "grpc_utils_hdr",
    hdrs = ["grpc_utils.h"],
    deps = reverb_grpc_deps() + reverb_absl_deps(),
)

reverb_cc_library(
    name = "grpc_utils",
    hdrs = ["grpc_utils.h"],
    visibility = ["//reverb:__subpackages__"],
    deps = [
        "//reverb/cc/platform/default:grpc_utils",
    ] + reverb_grpc_deps() + reverb_absl_deps(),
)

reverb_cc_library(
    name = "snappy_hdr",
    hdrs = ["snappy.h"],
    deps = reverb_absl_deps(),
)

reverb_cc_library(
    name = "snappy",
    hdrs = ["snappy.h"],
    visibility = ["//reverb:__subpackages__"],
    deps = [
        "//reverb/cc/platform/default:snappy",
    ] + reverb_absl_deps(),
)

reverb_cc_library(
    name = "thread_hdr",
    hdrs = ["thread.h"],
    deps = reverb_absl_deps(),
)

reverb_cc_library(
    name = "thread",
    hdrs = ["thread.h"],
    visibility = ["//reverb:__subpackages__"],
    deps = [
        "//reverb/cc/platform/default:thread",
    ] + reverb_absl_deps(),
)

reverb_cc_test(
    name = "thread_test",
    srcs = ["thread_test.cc"],
    deps = [
        ":thread",
    ] + reverb_absl_deps(),
)

reverb_cc_library(
    name = "logging",
    hdrs = ["logging.h"],
    visibility = ["//reverb:__subpackages__"],
    deps = ["//reverb/cc/platform/default:logging"],
)

reverb_cc_library(
    name = "net_hdr",
    hdrs = ["net.h"],
)

reverb_cc_library(
    name = "net",
    hdrs = ["net.h"],
    visibility = ["//reverb:__subpackages__"],
    deps = ["//reverb/cc/platform/default:net"],
)

reverb_cc_test(
    name = "net_test",
    srcs = ["net_test.cc"],
    deps = [
        ":logging",
        ":net",
    ],
)
diff --git a/reverb/cc/platform/build_rules.bzl b/reverb/cc/platform/build_rules.bzl
new file mode 100644
index 0000000..cffa946
--- /dev/null
+++ b/reverb/cc/platform/build_rules.bzl
@@ -0,0 +1,40 @@
+"""Main Starlark code for platform-specific build rules."""
+
+load(
+ "//reverb/cc/platform/default:build_rules.bzl",
+ _reverb_absl_deps = "reverb_absl_deps",
+ _reverb_cc_grpc_library = "reverb_cc_grpc_library",
+ _reverb_cc_library = "reverb_cc_library",
+ _reverb_cc_proto_library = "reverb_cc_proto_library",
+ _reverb_cc_test = "reverb_cc_test",
+ _reverb_gen_op_wrapper_py = "reverb_gen_op_wrapper_py",
+ _reverb_grpc_deps = "reverb_grpc_deps",
+ _reverb_kernel_library = "reverb_kernel_library",
+ _reverb_py_proto_library = "reverb_py_proto_library",
+ _reverb_py_standard_imports = "reverb_py_standard_imports",
+ _reverb_py_test = "reverb_py_test",
+ _reverb_pybind_deps = "reverb_pybind_deps",
+ _reverb_pybind_extension = "reverb_pybind_extension",
+ _reverb_pytype_library = "reverb_pytype_library",
+ _reverb_pytype_strict_library = "reverb_pytype_strict_library",
+ _reverb_tf_deps = "reverb_tf_deps",
+ _reverb_tf_ops_visibility = "reverb_tf_ops_visibility",
+)
+
+reverb_absl_deps = _reverb_absl_deps
+reverb_cc_library = _reverb_cc_library
+reverb_cc_test = _reverb_cc_test
+reverb_cc_grpc_library = _reverb_cc_grpc_library
+reverb_cc_proto_library = _reverb_cc_proto_library
+reverb_gen_op_wrapper_py = _reverb_gen_op_wrapper_py
+reverb_grpc_deps = _reverb_grpc_deps
+reverb_kernel_library = _reverb_kernel_library
+reverb_py_proto_library = _reverb_py_proto_library
+reverb_py_standard_imports = _reverb_py_standard_imports
+reverb_py_test = _reverb_py_test
+reverb_pybind_deps = _reverb_pybind_deps
+reverb_pybind_extension = _reverb_pybind_extension
+reverb_pytype_library = _reverb_pytype_library
+reverb_pytype_strict_library = _reverb_pytype_strict_library
+reverb_tf_ops_visibility = _reverb_tf_ops_visibility
+reverb_tf_deps = _reverb_tf_deps
diff --git a/reverb/cc/platform/checkpointing.h b/reverb/cc/platform/checkpointing.h
new file mode 100644
index 0000000..a90e2fd
--- /dev/null
+++ b/reverb/cc/platform/checkpointing.h
@@ -0,0 +1,33 @@
+// Copyright 2019 DeepMind Technologies Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef REVERB_CC_PLATFORM_CHECKPOINTER_H_
+#define REVERB_CC_PLATFORM_CHECKPOINTER_H_
+
+#include
+#include
+
+#include "absl/strings/string_view.h"
+#include "reverb/cc/checkpointing/interface.h"
+
+namespace deepmind {
+namespace reverb {
+
+std::unique_ptr CreateDefaultCheckpointer(
+ std::string root_dir, std::string group = "");
+
+} // namespace reverb
+} // namespace deepmind
+
+#endif // REVERB_CC_PLATFORM_CHECKPOINTER_H_
diff --git a/reverb/cc/platform/default/BUILD b/reverb/cc/platform/default/BUILD
new file mode 100644
index 0000000..0b7e276
--- /dev/null
+++ b/reverb/cc/platform/default/BUILD
@@ -0,0 +1,72 @@
+# Platform-specific code for reverb
+
+load(
+ "//reverb/cc/platform/default:build_rules.bzl",
+ "reverb_cc_library",
+ "reverb_grpc_deps",
+ "reverb_tf_deps",
+)
+
+package(default_visibility = ["//reverb/cc/platform:__pkg__"])
+
+licenses(["notice"])
+
+exports_files(["LICENSE"])
+
# Default (OSS) implementations of the platform APIs declared in
# //reverb/cc/platform. Each target is alwayslink so its definitions are kept
# even when only the platform header target is referenced.
reverb_cc_library(
    name = "snappy",
    srcs = ["snappy.cc"],
    deps = [
        "//reverb/cc/platform:snappy_hdr",
        "@com_google_absl//absl/strings",
    ] + reverb_tf_deps(),
    alwayslink = 1,
)

reverb_cc_library(
    name = "checkpointer",
    srcs = ["default_checkpointer.cc"],
    deps = [
        "//reverb/cc/checkpointing:interface",
        "//reverb/cc/platform:checkpointing_hdr",
        "//reverb/cc/platform:tfrecord_checkpointer",
        "@com_google_absl//absl/strings",
    ],
    alwayslink = 1,
)

reverb_cc_library(
    name = "thread",
    srcs = ["thread.cc"],
    deps = [
        "//reverb/cc/platform:thread_hdr",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
    ],
    alwayslink = 1,
)

reverb_cc_library(
    name = "logging",
    hdrs = ["logging.h"],
    deps = reverb_tf_deps(),
)

reverb_cc_library(
    name = "grpc_utils",
    srcs = ["grpc_utils.cc"],
    deps = [
        "//reverb/cc/platform:grpc_utils_hdr",
    ] + reverb_grpc_deps(),
    alwayslink = 1,
)

reverb_cc_library(
    name = "net",
    srcs = ["net.cc"],
    deps = [
        "//reverb/cc/platform:logging",
        "//reverb/cc/platform:net_hdr",
    ],
    alwayslink = 1,
)
diff --git a/reverb/cc/platform/default/build_rules.bzl b/reverb/cc/platform/default/build_rules.bzl
new file mode 100644
index 0000000..c996a96
--- /dev/null
+++ b/reverb/cc/platform/default/build_rules.bzl
@@ -0,0 +1,552 @@
+"""Default versions of reverb build rule helpers."""
+
def tf_copts():
    """Compiler options applied to all reverb C++ targets."""
    copts = ["-Wno-sign-compare"]
    return copts
+
def reverb_cc_library(
        name,
        srcs = [],
        hdrs = [],
        deps = [],
        testonly = 0,
        **kwargs):
    """cc_library wrapper that appends gtest and TF deps to testonly targets.

    Args:
      name: Target name.
      srcs: Source files.
      hdrs: Header files.
      deps: Dependencies.
      testonly: Standard Bazel testonly flag; when set, gtest plus the
        TF headers and shared framework library are appended to deps.
      **kwargs: Forwarded to native.cc_library.
    """
    if testonly:
        new_deps = [
            "@com_google_googletest//:gtest",
            "@tensorflow_includes//:includes",
            "@tensorflow_solib//:framework_lib",
        ]
    else:
        new_deps = []
    native.cc_library(
        name = name,
        srcs = srcs,
        hdrs = hdrs,
        copts = tf_copts(),
        testonly = testonly,
        # depset de-duplicates deps that appear both in `deps` and `new_deps`.
        deps = depset(deps + new_deps),
        **kwargs
    )
+
def reverb_kernel_library(name, srcs = [], deps = [], **kwargs):
    """Builds a TF op-kernel library: a reverb_cc_library with the TF deps
    appended and alwayslink set so kernel registrations are retained."""
    reverb_cc_library(
        name = name,
        srcs = srcs,
        deps = deps + reverb_tf_deps(),
        alwayslink = 1,
        **kwargs
    )
+
+def _normalize_proto(x):
+ if x.endswith("_proto"):
+ x = x.rstrip("_proto")
+ if x.endswith("_cc"):
+ x = x.rstrip("_cc")
+ if x.endswith("_pb2"):
+ x = x.rstrip("_pb2")
+ return x
+
+def _strip_proto_suffix(x):
+ # Workaround for bug that str.rstrip(".END") takes off more than just ".END"
+ if x.endswith(".proto"):
+ x = x[:-6]
+ return x
+
def reverb_cc_proto_library(name, srcs = [], deps = [], **kwargs):
    """Build a proto cc_library.

    This rule does three things:

    1) Create a filegroup with name `name` that contains `srcs`
       and any sources from deps named "x_proto" or "x_cc_proto".

    2) Uses protoc to compile srcs to .h/.cc files, allowing any
       tensorflow imports.

    3) Creates a cc_library with name `name` building the resulting .h/.cc
       files.

    Args:
      name: The name, should end with "_cc_proto".
      srcs: The .proto files.
      deps: Any reverb_cc_proto_library targets.
      **kwargs: Any additional args for the cc_library rule.
    """
    gen_srcs = [_strip_proto_suffix(x) + ".pb.cc" for x in srcs]
    gen_hdrs = [_strip_proto_suffix(x) + ".pb.h" for x in srcs]
    src_paths = ["$(location {})".format(x) for x in srcs]
    dep_srcs = []
    for x in deps:
        if x.endswith("_proto"):
            dep_srcs.append(_normalize_proto(x))
    native.filegroup(
        name = _normalize_proto(name),
        srcs = srcs + dep_srcs,
        **kwargs
    )
    # Invoke protoc directly; the vendored TF include tree resolves any
    # tensorflow .proto imports.
    native.genrule(
        name = name + "_gen",
        srcs = srcs,
        outs = gen_srcs + gen_hdrs,
        tools = dep_srcs + [
            "@protobuf_protoc//:protoc_bin",
            "@tensorflow_includes//:protos",
        ],
        cmd = """
        OUTDIR=$$(echo $(RULEDIR) | sed -e 's#reverb/.*##')
        $(location @protobuf_protoc//:protoc_bin) \
        --proto_path=external/tensorflow_includes/tensorflow_includes/ \
        --proto_path=. \
        --cpp_out=$$OUTDIR {}""".format(
            " ".join(src_paths),
        ),
    )

    # Build a static archive, wrap it in a shared object, and expose the .so
    # through a cc_library so the generated code is linked exactly once.
    native.cc_library(
        name = "{}_static".format(name),
        srcs = gen_srcs,
        hdrs = gen_hdrs,
        deps = depset(deps + reverb_tf_deps()),
        alwayslink = 1,
        **kwargs
    )
    native.cc_binary(
        name = "lib{}.so".format(name),
        deps = ["{}_static".format(name)],
        linkshared = 1,
        **kwargs
    )
    native.cc_library(
        name = name,
        hdrs = gen_hdrs,
        srcs = ["lib{}.so".format(name)],
        deps = depset(deps + reverb_tf_deps()),
        alwayslink = 1,
        **kwargs
    )
+
def reverb_py_proto_library(name, srcs = [], deps = [], **kwargs):
    """Build a proto py_library.

    This rule does three things:

    1) Create a filegroup with name `name` that contains `srcs`
       and any sources from deps named "x_proto" or "x_py_proto".

    2) Uses protoc to compile srcs to _pb2.py files, allowing any
       tensorflow imports.

    3) Creates a py_library with name `name` building the resulting .py
       files.

    Args:
      name: The name, should end with "_cc_proto".
      srcs: The .proto files.
      deps: Any reverb_cc_proto_library targets.
      **kwargs: Any additional args for the cc_library rule.
    """
    gen_srcs = [_strip_proto_suffix(x) + "_pb2.py" for x in srcs]
    src_paths = ["$(location {})".format(x) for x in srcs]
    # Split deps into proto filegroups (forwarded to protoc/data) and regular
    # python deps.
    proto_deps = []
    py_deps = []
    for x in deps:
        if x.endswith("_proto"):
            proto_deps.append(_normalize_proto(x))
        else:
            py_deps.append(x)
    native.filegroup(
        name = _normalize_proto(name),
        srcs = srcs + proto_deps,
        **kwargs
    )
    native.genrule(
        name = name + "_gen",
        srcs = srcs,
        outs = gen_srcs,
        tools = proto_deps + [
            "@protobuf_protoc//:protoc_bin",
            "@tensorflow_includes//:protos",
        ],
        cmd = """
        OUTDIR=$$(echo $(RULEDIR) | sed -e 's#reverb/.*##')
        $(location @protobuf_protoc//:protoc_bin) \
        --proto_path=external/tensorflow_includes/tensorflow_includes/ \
        --proto_path=. \
        --python_out=$$OUTDIR {}""".format(
            " ".join(src_paths),
        ),
    )
    native.py_library(
        name = name,
        srcs = gen_srcs,
        deps = py_deps,
        data = proto_deps,
        **kwargs
    )
+
def reverb_cc_grpc_library(
        name,
        srcs = [],
        deps = [],
        generate_mocks = False,
        **kwargs):
    """Build a grpc cc_library.

    This rule does two things:

    1) Uses protoc + grpc plugin to compile srcs to .h/.cc files, allowing any
       tensorflow imports. Also creates mock headers if requested.

    2) Creates a cc_library with name `name` building the resulting .h/.cc
       files.

    Args:
      name: The name, should end with "_cc_grpc_proto".
      srcs: The .proto files.
      deps: reverb_cc_proto_library targets. Must include src + "_cc_proto",
        the cc_proto library, for each src in srcs.
      generate_mocks: If true, creates mock headers for each source.
      **kwargs: Any additional args for the cc_library rule.
    """

    # Use _strip_proto_suffix instead of str.rstrip(".proto"): rstrip removes
    # a trailing character *set*, so names ending in any of ".proto"'s
    # characters (e.g. "transport.proto") would be over-truncated. This is the
    # same bug _strip_proto_suffix was introduced to work around.
    gen_srcs = [_strip_proto_suffix(x) + ".grpc.pb.cc" for x in srcs]
    gen_hdrs = [_strip_proto_suffix(x) + ".grpc.pb.h" for x in srcs]
    proto_src_deps = []
    for x in deps:
        if x.endswith("_proto"):
            proto_src_deps.append(_normalize_proto(x))
    src_paths = ["$(location {})".format(x) for x in srcs]

    if generate_mocks:
        gen_mocks = [_strip_proto_suffix(x) + "_mock.grpc.pb.h" for x in srcs]
    else:
        gen_mocks = []

    native.genrule(
        name = name + "_gen",
        srcs = srcs,
        outs = gen_srcs + gen_hdrs + gen_mocks,
        tools = proto_src_deps + [
            "@protobuf_protoc//:protoc_bin",
            "@tensorflow_includes//:protos",
            "@com_github_grpc_grpc//src/compiler:grpc_cpp_plugin",
        ],
        cmd = """
        OUTDIR=$$(echo $(RULEDIR) | sed -e 's#reverb/.*##')
        $(location @protobuf_protoc//:protoc_bin) \
        --plugin=protoc-gen-grpc=$(location @com_github_grpc_grpc//src/compiler:grpc_cpp_plugin) \
        --proto_path=external/tensorflow_includes/tensorflow_includes/ \
        --proto_path=. \
        --grpc_out={} {}""".format(
            "generate_mock_code=true:$$OUTDIR" if generate_mocks else "$$OUTDIR",
            " ".join(src_paths),
        ),
    )

    native.cc_library(
        name = name,
        srcs = gen_srcs,
        hdrs = gen_hdrs + gen_mocks,
        deps = depset(deps + ["@com_github_grpc_grpc//:grpc++_codegen_proto"]),
        **kwargs
    )
+
def reverb_cc_test(name, srcs, deps = [], **kwargs):
    """Reverb-specific version of cc_test.

    Args:
      name: Target name.
      srcs: Target sources.
      deps: Target deps.
      **kwargs: Additional args to cc_test.
    """
    # Every test links gtest (including gtest_main) plus the TF headers and
    # shared framework library.
    new_deps = [
        "@com_google_googletest//:gtest",
        "@tensorflow_includes//:includes",
        "@tensorflow_solib//:framework_lib",
        "@com_google_googletest//:gtest_main",
    ]
    # Tests default to "small" unless the caller overrides the size.
    size = kwargs.pop("size", "small")
    native.cc_test(
        name = name,
        size = size,
        copts = tf_copts(),
        srcs = srcs,
        deps = depset(deps + new_deps),
        **kwargs
    )
+
def reverb_gen_op_wrapper_py(name, out, kernel_lib, linkopts = [], **kwargs):
    """Generates the py_library `name` with a data dep on the ops in kernel_lib.

    The resulting py_library creates file `$out`, and has a dependency on a
    symbolic library called lib{$name}_gen_op.so, which contains the kernels
    and ops and can be loaded via `tf.load_op_library`.

    Args:
      name: The name of the py_library.
      out: The name of the python file. Use "gen_{name}_ops.py".
      kernel_lib: A cc_kernel_library target to generate for.
      linkopts: Additional linker options for the generated shared object.
      **kwargs: Any args to the `cc_binary` and `py_library` internal rules.
    """
    if not out.endswith(".py"):
        fail("Argument out must end with '.py', but saw: {}".format(out))

    module_name = "lib{}_gen_op".format(name)
    # Version script restricting the .so's exported symbols to the
    # tensorflow/deepmind namespaces, avoiding clashes with other DSOs.
    version_script_file = "%s-version-script.lds" % module_name
    native.genrule(
        name = module_name + "_version_script",
        outs = [version_script_file],
        cmd = "echo '{global:\n *tensorflow*;\n *deepmind*;\n local: *;};' >$@",
        output_licenses = ["unencumbered"],
        visibility = ["//visibility:private"],
    )
    native.cc_binary(
        name = "{}.so".format(module_name),
        deps = [kernel_lib] + reverb_tf_deps() + [version_script_file],
        copts = tf_copts() + [
            "-fno-strict-aliasing",  # allow a wider range of code [aliasing] to compile.
            "-fvisibility=hidden",  # avoid symbol clashes between DSOs.
        ],
        linkshared = 1,
        linkopts = linkopts + _rpath_linkopts(module_name) + [
            "-Wl,--version-script",
            "$(location %s)" % version_script_file,
        ],
        **kwargs
    )
    # Emit a python stub that loads the .so via tf.load_op_library and
    # re-exports every generated op symbol at module level.
    native.genrule(
        name = "{}_genrule".format(out),
        outs = [out],
        cmd = """
        echo 'import tensorflow as tf
_reverb_gen_op = tf.load_op_library(
    tf.compat.v1.resource_loader.get_path_to_datafile(
        "lib{}_gen_op.so"))
_locals = locals()
for k in dir(_reverb_gen_op):
  _locals[k] = getattr(_reverb_gen_op, k)
del _locals' > $@""".format(name),
    )
    native.py_library(
        name = name,
        srcs = [out],
        data = [":lib{}_gen_op.so".format(name)],
        **kwargs
    )
+
def reverb_pytype_library(**kwargs):
    """py_library wrapper that tolerates (and drops) a `strict_deps` arg,
    which is unsupported in the open-source build."""
    kwargs.pop("strict_deps", None)
    native.py_library(**kwargs)
+
+reverb_pytype_strict_library = native.py_library
+
+def _make_search_paths(prefix, levels_to_root):
+ return ",".join(
+ [
+ "-rpath,%s/%s" % (prefix, "/".join([".."] * search_level))
+ for search_level in range(levels_to_root + 1)
+ ],
+ )
+
def _rpath_linkopts(name):
    """Returns linkopts adding rpath entries for every parent directory.

    Searches parent directories up to the TensorFlow root directory for shared
    object dependencies, even if this op shared object is deeply nested
    (e.g. tensorflow/contrib/package:python/ops/_op_lib.so): tensorflow/ is
    then the root and tensorflow/libtensorflow_framework.so should exist when
    deployed. Other shared object dependencies (e.g. shared between contrib/
    ops) are picked up as long as they are in either the same or a parent
    directory in the tensorflow/ tree.
    """
    depth = native.package_name().count("/") + name.count("/")
    return ["-Wl,%s" % _make_search_paths("$$ORIGIN", depth)]
+
+def reverb_pybind_extension(
+ name,
+ srcs,
+ module_name,
+ hdrs = [],
+ features = [],
+ srcs_version = "PY3",
+ data = [],
+ copts = [],
+ linkopts = [],
+ deps = [],
+ defines = [],
+ visibility = None,
+ testonly = None,
+ licenses = None,
+ compatible_with = None,
+ restricted_to = None,
+ deprecation = None):
+ """Builds a generic Python extension module.
+
+ The module can be loaded in python by performing "import ${name}.".
+
+ Args:
+ name: Name.
+ srcs: cc files.
+ module_name: The name of the hidden module. It should be different
+ from `name`, and *must* match the MODULE declaration in the .cc file.
+ hdrs: h files.
+ features: see bazel docs.
+ srcs_version: srcs_version for py_library.
+ data: data deps.
+ copts: compilation opts.
+ linkopts: linking opts.
+ deps: cc_library deps.
+ defines: cc_library defines.
+ visibility: visibility.
+ testonly: whether the rule is testonly.
+ licenses: see bazel docs.
+ compatible_with: see bazel docs.
+ restricted_to: see bazel docs.
+ deprecation: see bazel docs.
+ """
+ if name == module_name:
+ fail(
+ "Must have name != module_name ({} vs. {}) because the python ".format(name, module_name) +
+ "wrapper $name.py needs to add extra logic loading tensorflow.",
+ )
+ py_file = "%s.py" % name
+ so_file = "%s.so" % module_name
+ pyd_file = "%s.pyd" % module_name
+ symbol = "init%s" % module_name
+ symbol2 = "init_%s" % module_name
+ symbol3 = "PyInit_%s" % module_name
+ exported_symbols_file = "%s-exported-symbols.lds" % module_name
+ version_script_file = "%s-version-script.lds" % module_name
+ native.genrule(
+ name = module_name + "_exported_symbols",
+ outs = [exported_symbols_file],
+ cmd = "echo '_%s\n_%s\n_%s' >$@" % (symbol, symbol2, symbol3),
+ output_licenses = ["unencumbered"],
+ visibility = ["//visibility:private"],
+ testonly = testonly,
+ )
+ native.genrule(
+ name = module_name + "_version_script",
+ outs = [version_script_file],
+ cmd = "echo '{global:\n %s;\n %s;\n %s;\n local: *;};' >$@" % (symbol, symbol2, symbol3),
+ output_licenses = ["unencumbered"],
+ visibility = ["//visibility:private"],
+ testonly = testonly,
+ )
+ native.cc_binary(
+ name = so_file,
+ srcs = srcs + hdrs,
+ data = data,
+ copts = copts + [
+ "-fno-strict-aliasing", # allow a wider range of code [aliasing] to compile.
+ "-fexceptions", # pybind relies on exceptions, required to compile.
+ "-fvisibility=hidden", # avoid pybind symbol clashes between DSOs.
+ ],
+ linkopts = linkopts + _rpath_linkopts(module_name) + [
+ "-Wl,--version-script",
+ "$(location %s)" % version_script_file,
+ ],
+ deps = depset(deps + [
+ exported_symbols_file,
+ version_script_file,
+ ]),
+ defines = defines,
+ features = features + ["-use_header_modules"],
+ linkshared = 1,
+ testonly = testonly,
+ licenses = licenses,
+ visibility = visibility,
+ deprecation = deprecation,
+ restricted_to = restricted_to,
+ compatible_with = compatible_with,
+ )
+ native.genrule(
+ name = module_name + "_pyd_copy",
+ srcs = [so_file],
+ outs = [pyd_file],
+ cmd = "cp $< $@",
+ output_to_bindir = True,
+ visibility = visibility,
+ deprecation = deprecation,
+ restricted_to = restricted_to,
+ compatible_with = compatible_with,
+ )
+ native.genrule(
+ name = name + "_py_file",
+ outs = [py_file],
+ cmd = (
+ "echo 'import tensorflow as _tf; from .%s import *; del _tf' >$@" %
+ module_name
+ ),
+ output_licenses = ["unencumbered"],
+ visibility = visibility,
+ testonly = testonly,
+ )
+ native.py_library(
+ name = name,
+ data = [so_file],
+ srcs = [py_file],
+ srcs_version = srcs_version,
+ licenses = licenses,
+ testonly = testonly,
+ visibility = visibility,
+ deprecation = deprecation,
+ restricted_to = restricted_to,
+ compatible_with = compatible_with,
+ )
+
+def reverb_py_standard_imports():
+ return []
+
+def reverb_py_test(
+ name,
+ srcs = [],
+ deps = [],
+ paropts = [],
+ python_version = "PY3",
+ **kwargs):
+ size = kwargs.pop("size", "small")
+ native.py_test(
+ name = name,
+ size = size,
+ srcs = srcs,
+ deps = deps,
+ python_version = python_version,
+ **kwargs
+ )
+ return
+
+def reverb_pybind_deps():
+ return [
+ "@pybind11",
+ ]
+
+def reverb_tf_ops_visibility():
+ return [
+ "//reverb:__subpackages__",
+ ]
+
+def reverb_tf_deps():
+ return [
+ "@tensorflow_includes//:includes",
+ "@tensorflow_solib//:framework_lib",
+ ]
+
+def reverb_grpc_deps():
+ return ["@com_github_grpc_grpc//:grpc++"]
+
+def reverb_absl_deps():
+ return [
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/functional:bind_front",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/numeric:int128",
+ "@com_google_absl//absl/random",
+ "@com_google_absl//absl/random:distributions",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/time",
+ "@com_google_absl//absl/types:optional",
+ "@com_google_absl//absl/types:span",
+ ]
diff --git a/reverb/cc/platform/default/default_checkpointer.cc b/reverb/cc/platform/default/default_checkpointer.cc
new file mode 100644
index 0000000..59f6a7d
--- /dev/null
+++ b/reverb/cc/platform/default/default_checkpointer.cc
@@ -0,0 +1,28 @@
+// Copyright 2019 DeepMind Technologies Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "reverb/cc/platform/checkpointing.h"
+#include "reverb/cc/platform/tfrecord_checkpointer.h"
+
+namespace deepmind {
+namespace reverb {
+
+std::unique_ptr<Checkpointer> CreateDefaultCheckpointer(
+ std::string root_dir, std::string group) {
+ return absl::make_unique<TFRecordCheckpointer>(std::move(root_dir),
+ std::move(group));
+}
+
+} // namespace reverb
+} // namespace deepmind
diff --git a/reverb/cc/platform/default/grpc_utils.cc b/reverb/cc/platform/default/grpc_utils.cc
new file mode 100644
index 0000000..7dbface
--- /dev/null
+++ b/reverb/cc/platform/default/grpc_utils.cc
@@ -0,0 +1,44 @@
+// Copyright 2019 DeepMind Technologies Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "reverb/cc/platform/grpc_utils.h"
+
+#include <memory>
+
+#include "grpcpp/create_channel.h"
+#include "grpcpp/security/credentials.h"
+#include "grpcpp/security/server_credentials.h"
+#include "grpcpp/support/channel_arguments.h"
+
+namespace deepmind {
+namespace reverb {
+
+std::shared_ptr<grpc::ServerCredentials> MakeServerCredentials() {
+ return grpc::InsecureServerCredentials();
+}
+
+std::shared_ptr<grpc::ChannelCredentials> MakeChannelCredentials() {
+ return grpc::InsecureChannelCredentials();
+}
+
+std::shared_ptr<grpc::Channel> CreateCustomGrpcChannel(
+ absl::string_view target,
+ const std::shared_ptr<grpc::ChannelCredentials>& credentials,
+ const grpc::ChannelArguments& channel_arguments) {
+ return grpc::CreateCustomChannel(
+ std::string(target), credentials, channel_arguments);
+}
+
+} // namespace reverb
+} // namespace deepmind
diff --git a/reverb/cc/platform/default/logging.h b/reverb/cc/platform/default/logging.h
new file mode 100644
index 0000000..41d3e3f
--- /dev/null
+++ b/reverb/cc/platform/default/logging.h
@@ -0,0 +1,206 @@
+// Copyright 2019 DeepMind Technologies Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Copyright 2016-2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// ============================================================================
+//
+// A minimal replacement for "glog"-like functionality. Does not provide output
+// in a separate thread nor backtracing.
+
+#ifndef REVERB_CC_PLATFORM_DEFAULT_LOGGING_H_
+#define REVERB_CC_PLATFORM_DEFAULT_LOGGING_H_
+
+#include <cstdlib>
+#include <iostream>
+#include <sstream>
+#include <string>
+#include <type_traits>
+
+#ifdef __GNUC__
+#define PREDICT_TRUE(x) (__builtin_expect(!!(x), 1))
+#define PREDICT_FALSE(x) (__builtin_expect(x, 0))
+#define NORETURN __attribute__((noreturn))
+#else
+#define PREDICT_TRUE(x) (x)
+#define PREDICT_FALSE(x) (x)
+#define NORETURN
+#endif
+
+namespace deepmind {
+namespace reverb {
+namespace internal {
+
+struct CheckOpString {
+ explicit CheckOpString(std::string* str) : str_(str) {}
+ explicit operator bool() const { return PREDICT_FALSE(str_ != nullptr); }
+ std::string* const str_;
+};
+
+template <typename T1, typename T2>
+CheckOpString MakeCheckOpString(const T1& v1, const T2& v2,
+ const char* exprtext) {
+ std::ostringstream oss;
+ oss << exprtext << " (" << v1 << " vs. " << v2 << ")";
+ return CheckOpString(new std::string(oss.str()));
+}
+
+#define DEFINE_CHECK_OP_IMPL(name, op) \
+ template <typename T1, typename T2> \
+ inline CheckOpString name##Impl(const T1& v1, const T2& v2, \
+ const char* exprtext) { \
+ if (PREDICT_TRUE(v1 op v2)) { \
+ return CheckOpString(nullptr); \
+ } else { \
+ return (MakeCheckOpString)(v1, v2, exprtext); \
+ } \
+ } \
+ inline CheckOpString name##Impl(int v1, int v2, const char* exprtext) { \
+ return (name##Impl)(v1, v2, exprtext); \
+ }
+DEFINE_CHECK_OP_IMPL(Check_EQ, ==)
+DEFINE_CHECK_OP_IMPL(Check_NE, !=)
+DEFINE_CHECK_OP_IMPL(Check_LE, <=)
+DEFINE_CHECK_OP_IMPL(Check_LT, <)
+DEFINE_CHECK_OP_IMPL(Check_GE, >=)
+DEFINE_CHECK_OP_IMPL(Check_GT, >)
+#undef DEFINE_CHECK_OP_IMPL
+
+class LogMessage {
+ public:
+ LogMessage(const char* file, int line) {
+ std::clog << "[" << file << ":" << line << "] ";
+ }
+
+ ~LogMessage() { std::clog << "\n"; }
+
+ std::ostream& stream() && { return std::clog; }
+};
+
+class LogMessageFatal {
+ public:
+ LogMessageFatal(const char* file, int line) {
+ stream_ << "[" << file << ":" << line << "] ";
+ }
+
+ LogMessageFatal(const char* file, int line, const CheckOpString& result) {
+ stream_ << "[" << file << ":" << line << "] Check failed: " << *result.str_;
+ }
+
+ ~LogMessageFatal() NORETURN;
+
+ std::ostream& stream() && { return stream_; }
+
+ private:
+ std::ostringstream stream_;
+};
+
+inline LogMessageFatal::~LogMessageFatal() {
+ std::cerr << stream_.str() << std::endl;
+ std::abort();
+}
+
+struct NullStream {};
+
+template <typename T>
+NullStream&& operator<<(NullStream&& s, T&&) {
+ return std::move(s);
+}
+
+enum class LogSeverity {
+ kFatal,
+ kNonFatal,
+};
+
+LogMessage LogStream(
+ std::integral_constant<LogSeverity, LogSeverity::kNonFatal>);
+LogMessageFatal LogStream(
+ std::integral_constant<LogSeverity, LogSeverity::kFatal>);
+
+struct Voidify {
+ void operator&(std::ostream&) {}
+};
+
+} // namespace internal
+} // namespace reverb
+} // namespace deepmind
+
+#define REVERB_CHECK_OP_LOG(name, op, val1, val2, log) \
+ while (::deepmind::reverb::internal::CheckOpString _result = \
+ ::deepmind::reverb::internal::name##Impl( \
+ val1, val2, #val1 " " #op " " #val2)) \
+ log(__FILE__, __LINE__, _result).stream()
+
+#define REVERB_CHECK_OP(name, op, val1, val2) \
+ REVERB_CHECK_OP_LOG(name, op, val1, val2, \
+ ::deepmind::reverb::internal::LogMessageFatal)
+
+#define REVERB_CHECK_EQ(val1, val2) REVERB_CHECK_OP(Check_EQ, ==, val1, val2)
+#define REVERB_CHECK_NE(val1, val2) REVERB_CHECK_OP(Check_NE, !=, val1, val2)
+#define REVERB_CHECK_LE(val1, val2) REVERB_CHECK_OP(Check_LE, <=, val1, val2)
+#define REVERB_CHECK_LT(val1, val2) REVERB_CHECK_OP(Check_LT, <, val1, val2)
+#define REVERB_CHECK_GE(val1, val2) REVERB_CHECK_OP(Check_GE, >=, val1, val2)
+#define REVERB_CHECK_GT(val1, val2) REVERB_CHECK_OP(Check_GT, >, val1, val2)
+
+#define REVERB_QCHECK_EQ(val1, val2) REVERB_CHECK_OP(Check_EQ, ==, val1, val2)
+#define REVERB_QCHECK_NE(val1, val2) REVERB_CHECK_OP(Check_NE, !=, val1, val2)
+#define REVERB_QCHECK_LE(val1, val2) REVERB_CHECK_OP(Check_LE, <=, val1, val2)
+#define REVERB_QCHECK_LT(val1, val2) REVERB_CHECK_OP(Check_LT, <, val1, val2)
+#define REVERB_QCHECK_GE(val1, val2) REVERB_CHECK_OP(Check_GE, >=, val1, val2)
+#define REVERB_QCHECK_GT(val1, val2) REVERB_CHECK_OP(Check_GT, >, val1, val2)
+
+#define REVERB_CHECK(condition) \
+ while (auto _result = ::deepmind::reverb::internal::CheckOpString( \
+ (condition) ? nullptr : new std::string(#condition))) \
+ ::deepmind::reverb::internal::LogMessageFatal(__FILE__, __LINE__, _result) \
+ .stream()
+
+#define REVERB_QCHECK(condition) REVERB_CHECK(condition)
+
+#define REVERB_FATAL ::deepmind::reverb::internal::LogSeverity::kFatal
+#define REVERB_QFATAL ::deepmind::reverb::internal::LogSeverity::kFatal
+#define REVERB_INFO ::deepmind::reverb::internal::LogSeverity::kNonFatal
+#define REVERB_WARNING ::deepmind::reverb::internal::LogSeverity::kNonFatal
+#define REVERB_ERROR ::deepmind::reverb::internal::LogSeverity::kNonFatal
+
+#define REVERB_LOG(level) \
+ decltype(::deepmind::reverb::internal::LogStream( \
+ std::integral_constant<::deepmind::reverb::internal::LogSeverity, \
+ level>()))(__FILE__, __LINE__) \
+ .stream()
+
+#define REVERB_VLOG(level) ::deepmind::reverb::internal::NullStream()
+
+#define REVERB_LOG_IF(level, condition) \
+ !(condition) \
+ ? static_cast<void>(0) \
+ : ::deepmind::reverb::internal::Voidify() & \
+ decltype(::deepmind::reverb::internal::LogStream( \
+ std::integral_constant< \
+ ::deepmind::reverb::internal::LogSeverity, level>()))( \
+ __FILE__, __LINE__) \
+ .stream()
+
+#endif // REVERB_CC_PLATFORM_DEFAULT_LOGGING_H_
diff --git a/reverb/cc/platform/default/net.cc b/reverb/cc/platform/default/net.cc
new file mode 100644
index 0000000..62e869f
--- /dev/null
+++ b/reverb/cc/platform/default/net.cc
@@ -0,0 +1,136 @@
+// Copyright 2019 DeepMind Technologies Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "reverb/cc/platform/net.h"
+
+#include <netinet/in.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include <cerrno>
+#include <cstdlib>
+#include <cstring>
+#include <unordered_set>
+
+#include "reverb/cc/platform/logging.h"
+
+namespace deepmind::reverb::internal {
+namespace {
+bool IsPortAvailable(int* port, bool is_tcp) {
+ const int protocol = is_tcp ? IPPROTO_TCP : 0;
+ const int fd = socket(AF_INET, is_tcp ? SOCK_STREAM : SOCK_DGRAM, protocol);
+
+ struct sockaddr_in addr;
+ socklen_t addr_len = sizeof(addr);
+ int actual_port;
+
+ REVERB_CHECK_GE(*port, 0);
+ REVERB_CHECK_LE(*port, 65535);
+ if (fd < 0) {
+ REVERB_LOG(REVERB_ERROR) << "socket() failed: " << strerror(errno);
+ return false;
+ }
+
+ // SO_REUSEADDR lets us start up a server immediately after it exists.
+ int one = 1;
+ if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one)) < 0) {
+ REVERB_LOG(REVERB_ERROR) << "setsockopt() failed: " << strerror(errno);
+ if (close(fd) < 0) {
+ REVERB_LOG(REVERB_ERROR) << "close() failed: " << strerror(errno);
+ }
+ return false;
+ }
+
+ // Try binding to port.
+ addr.sin_family = AF_INET;
+ addr.sin_addr.s_addr = INADDR_ANY;
+ addr.sin_port = htons(static_cast<uint16_t>(*port));
+ if (bind(fd, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)) < 0) {
+ REVERB_LOG(REVERB_WARNING)
+ << "bind(port=" << *port << ") failed: " << strerror(errno);
+ if (close(fd) < 0) {
+ REVERB_LOG(REVERB_ERROR) << "close() failed: " << strerror(errno);
+ }
+ return false;
+ }
+
+ // Get the bound port number.
+ if (getsockname(fd, reinterpret_cast<struct sockaddr*>(&addr), &addr_len) <
+ 0) {
+ REVERB_LOG(REVERB_WARNING) << "getsockname() failed: " << strerror(errno);
+ if (close(fd) < 0) {
+ REVERB_LOG(REVERB_ERROR) << "close() failed: " << strerror(errno);
+ }
+ return false;
+ }
+ REVERB_CHECK_LE(addr_len, sizeof(addr));
+ actual_port = ntohs(addr.sin_port);
+ REVERB_CHECK_GT(actual_port, 0);
+ if (*port == 0) {
+ *port = actual_port;
+ } else {
+ REVERB_CHECK_EQ(*port, actual_port);
+ }
+ if (close(fd) < 0) {
+ REVERB_LOG(REVERB_ERROR) << "close() failed: " << strerror(errno);
+ }
+ return true;
+}
+
+const int kNumRandomPortsToPick = 100;
+const int kMaximumTrials = 1000;
+
+} // namespace
+
+int PickUnusedPortOrDie() {
+ static std::unordered_set<int> chosen_ports;
+
+ // Type of port to first pick in the next iteration.
+ bool is_tcp = true;
+ int trial = 0;
+ while (true) {
+ int port;
+ trial++;
+ REVERB_CHECK_LE(trial, kMaximumTrials)
+ << "Failed to pick an unused port for testing.";
+ if (trial == 1) {
+ port = getpid() % (65536 - 30000) + 30000;
+ } else if (trial <= kNumRandomPortsToPick) {
+ port = rand() % (65536 - 30000) + 30000; // NOLINT: Ignore suggestion to use rand_r instead.
+ } else {
+ port = 0;
+ }
+
+ if (chosen_ports.find(port) != chosen_ports.end()) {
+ continue;
+ }
+ if (!IsPortAvailable(&port, is_tcp)) {
+ continue;
+ }
+
+ REVERB_CHECK_GT(port, 0);
+ if (!IsPortAvailable(&port, !is_tcp)) {
+ is_tcp = !is_tcp;
+ continue;
+ }
+
+ chosen_ports.insert(port);
+ return port;
+ }
+
+ return 0;
+}
+
+} // namespace deepmind::reverb::internal
diff --git a/reverb/cc/platform/default/repo.bzl b/reverb/cc/platform/default/repo.bzl
new file mode 100644
index 0000000..200e9d1
--- /dev/null
+++ b/reverb/cc/platform/default/repo.bzl
@@ -0,0 +1,361 @@
+"""Reverb custom external dependencies."""
+
+load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
+
+# Sanitize a dependency so that it works correctly from code that includes
+# reverb as a submodule.
+def clean_dep(dep):
+ return str(Label(dep))
+
+def get_python_path(ctx):
+ path = ctx.os.environ.get("PYTHON_BIN_PATH")
+ if not path:
+ fail(
+ "Could not get environment variable PYTHON_BIN_PATH. " +
+ "Check your .bazelrc file.",
+ )
+ return path
+
+def _find_tf_include_path(repo_ctx):
+ exec_result = repo_ctx.execute(
+ [
+ get_python_path(repo_ctx),
+ "-c",
+ "import tensorflow as tf; import sys; " +
+ "sys.stdout.write(tf.sysconfig.get_include())",
+ ],
+ quiet = True,
+ )
+ if exec_result.return_code != 0:
+ fail("Could not locate tensorflow installation path:\n{}"
+ .format(exec_result.stderr))
+ return exec_result.stdout.splitlines()[-1]
+
+def _find_tf_lib_path(repo_ctx):
+ exec_result = repo_ctx.execute(
+ [
+ get_python_path(repo_ctx),
+ "-c",
+ "import tensorflow as tf; import sys; " +
+ "sys.stdout.write(tf.sysconfig.get_lib())",
+ ],
+ quiet = True,
+ )
+ if exec_result.return_code != 0:
+ fail("Could not locate tensorflow installation path:\n{}"
+ .format(exec_result.stderr))
+ return exec_result.stdout.splitlines()[-1]
+
+def _find_numpy_include_path(repo_ctx):
+ exec_result = repo_ctx.execute(
+ [
+ get_python_path(repo_ctx),
+ "-c",
+ "import numpy; import sys; " +
+ "sys.stdout.write(numpy.get_include())",
+ ],
+ quiet = True,
+ )
+ if exec_result.return_code != 0:
+ fail("Could not locate numpy includes path:\n{}"
+ .format(exec_result.stderr))
+ return exec_result.stdout.splitlines()[-1]
+
+def _find_python_include_path(repo_ctx):
+ exec_result = repo_ctx.execute(
+ [
+ get_python_path(repo_ctx),
+ "-c",
+ "from distutils import sysconfig; import sys; " +
+ "sys.stdout.write(sysconfig.get_python_inc())",
+ ],
+ quiet = True,
+ )
+ if exec_result.return_code != 0:
+ fail("Could not locate python includes path:\n{}"
+ .format(exec_result.stderr))
+ return exec_result.stdout.splitlines()[-1]
+
+def _find_python_solib_path(repo_ctx):
+ exec_result = repo_ctx.execute(
+ [
+ get_python_path(repo_ctx),
+ "-c",
+ "import sys; vi = sys.version_info; " +
+ "sys.stdout.write('python{}.{}'.format(vi.major, vi.minor))",
+ ],
+ )
+ if exec_result.return_code != 0:
+ fail("Could not locate python shared library path:\n{}"
+ .format(exec_result.stderr))
+ version = exec_result.stdout.splitlines()[-1]
+ basename = "lib{}.so".format(version)
+ exec_result = repo_ctx.execute(
+ ["{}-config".format(version), "--configdir"],
+ quiet = True,
+ )
+ if exec_result.return_code != 0:
+ fail("Could not locate python shared library path:\n{}"
+ .format(exec_result.stderr))
+ solib_dir = exec_result.stdout.splitlines()[-1]
+ full_path = repo_ctx.path("{}/{}".format(solib_dir, basename))
+ if not full_path.exists:
+ fail("Unable to find python shared library file:\n{}/{}"
+ .format(solib_dir, basename))
+ return struct(dir = solib_dir, basename = basename)
+
+def _eigen_archive_repo_impl(repo_ctx):
+ tf_include_path = _find_tf_include_path(repo_ctx)
+ repo_ctx.symlink(tf_include_path, "tf_includes")
+ repo_ctx.file(
+ "BUILD",
+ content = """
+cc_library(
+ name = "includes",
+ hdrs = glob(["tf_includes/Eigen/**/*.h",
+ "tf_includes/Eigen/**",
+ "tf_includes/unsupported/Eigen/**/*.h",
+ "tf_includes/unsupported/Eigen/**"]),
+ # https://groups.google.com/forum/#!topic/bazel-discuss/HyyuuqTxKok
+ includes = ["tf_includes"],
+ visibility = ["//visibility:public"],
+)
+""",
+ executable = False,
+ )
+
+def _nsync_includes_repo_impl(repo_ctx):
+ tf_include_path = _find_tf_include_path(repo_ctx)
+ repo_ctx.symlink(tf_include_path + "/external", "nsync_includes")
+ repo_ctx.file(
+ "BUILD",
+ content = """
+cc_library(
+ name = "includes",
+ hdrs = glob(["nsync_includes/nsync/public/*.h"]),
+ includes = ["nsync_includes"],
+ visibility = ["//visibility:public"],
+)
+""",
+ executable = False,
+ )
+
+def _zlib_includes_repo_impl(repo_ctx):
+ tf_include_path = _find_tf_include_path(repo_ctx)
+ repo_ctx.symlink(
+ tf_include_path + "/external/zlib",
+ "zlib",
+ )
+ repo_ctx.file(
+ "BUILD",
+ content = """
+cc_library(
+ name = "includes",
+ hdrs = glob(["zlib/**/*.h"]),
+ includes = ["zlib"],
+ visibility = ["//visibility:public"],
+)
+""",
+ executable = False,
+ )
+
+def _snappy_includes_repo_impl(repo_ctx):
+ tf_include_path = _find_tf_include_path(repo_ctx)
+ repo_ctx.symlink(
+ tf_include_path + "/external/snappy",
+ "snappy",
+ )
+ repo_ctx.file(
+ "BUILD",
+ content = """
+cc_library(
+ name = "includes",
+ hdrs = glob(["snappy/*.h"]),
+ includes = ["snappy"],
+ visibility = ["//visibility:public"],
+)
+""",
+ executable = False,
+ )
+
+def _protobuf_includes_repo_impl(repo_ctx):
+ tf_include_path = _find_tf_include_path(repo_ctx)
+ repo_ctx.symlink(tf_include_path, "tf_includes")
+ repo_ctx.symlink(Label("//third_party:protobuf.BUILD"), "BUILD")
+
+def _tensorflow_includes_repo_impl(repo_ctx):
+ tf_include_path = _find_tf_include_path(repo_ctx)
+ repo_ctx.symlink(tf_include_path, "tensorflow_includes")
+ repo_ctx.file(
+ "BUILD",
+ content = """
+cc_library(
+ name = "includes",
+ hdrs = glob(
+ [
+ "tensorflow_includes/**/*.h",
+ "tensorflow_includes/third_party/eigen3/**",
+ ],
+ exclude = ["tensorflow_includes/absl/**/*.h"],
+ ),
+ includes = ["tensorflow_includes"],
+ deps = [
+ "@eigen_archive//:includes",
+ "@protobuf_archive//:includes",
+ "@zlib_includes//:includes",
+ "@snappy_includes//:includes",
+ ],
+ visibility = ["//visibility:public"],
+)
+filegroup(
+ name = "protos",
+ srcs = glob(["tensorflow_includes/**/*.proto"]),
+ visibility = ["//visibility:public"],
+)
+""",
+ executable = False,
+ )
+
+def _tensorflow_solib_repo_impl(repo_ctx):
+ tf_lib_path = _find_tf_lib_path(repo_ctx)
+ repo_ctx.symlink(tf_lib_path, "tensorflow_solib")
+ repo_ctx.file(
+ "BUILD",
+ content = """
+cc_library(
+ name = "framework_lib",
+ srcs = ["tensorflow_solib/libtensorflow_framework.so.2"],
+ deps = ["@python_includes", "@python_includes//:numpy_includes"],
+ visibility = ["//visibility:public"],
+)
+""",
+ )
+
+def _python_includes_repo_impl(repo_ctx):
+ python_include_path = _find_python_include_path(repo_ctx)
+ python_solib = _find_python_solib_path(repo_ctx)
+ repo_ctx.symlink(python_include_path, "python_includes")
+ numpy_include_path = _find_numpy_include_path(repo_ctx)
+ repo_ctx.symlink(numpy_include_path, "numpy_includes")
+ repo_ctx.symlink(
+ "{}/{}".format(python_solib.dir, python_solib.basename),
+ python_solib.basename,
+ )
+
+ # Note, "@python_includes" is a misnomer since we include the
+ # libpythonX.Y.so in the srcs, so we can get access to python's various
+ # symbols at link time.
+ repo_ctx.file(
+ "BUILD",
+ content = """
+cc_library(
+ name = "python_includes",
+ hdrs = glob(["python_includes/**/*.h"]),
+ srcs = ["{}"],
+ includes = ["python_includes"],
+ visibility = ["//visibility:public"],
+)
+cc_library(
+ name = "numpy_includes",
+ hdrs = glob(["numpy_includes/**/*.h"]),
+ includes = ["numpy_includes"],
+ visibility = ["//visibility:public"],
+)
+""".format(python_solib.basename),
+ executable = False,
+ )
+
+def cc_tf_configure():
+ """Autoconf pre-installed tensorflow repo."""
+ make_eigen_repo = repository_rule(implementation = _eigen_archive_repo_impl)
+ make_eigen_repo(name = "eigen_archive")
+ make_nsync_repo = repository_rule(
+ implementation = _nsync_includes_repo_impl,
+ )
+ make_nsync_repo(name = "nsync_includes")
+ make_zlib_repo = repository_rule(
+ implementation = _zlib_includes_repo_impl,
+ )
+ make_zlib_repo(name = "zlib_includes")
+ make_snappy_repo = repository_rule(
+ implementation = _snappy_includes_repo_impl,
+ )
+ make_snappy_repo(name = "snappy_includes")
+ make_protobuf_repo = repository_rule(
+ implementation = _protobuf_includes_repo_impl,
+ )
+ make_protobuf_repo(name = "protobuf_archive")
+ make_tfinc_repo = repository_rule(
+ implementation = _tensorflow_includes_repo_impl,
+ )
+ make_tfinc_repo(name = "tensorflow_includes")
+ make_tflib_repo = repository_rule(
+ implementation = _tensorflow_solib_repo_impl,
+ )
+ make_tflib_repo(name = "tensorflow_solib")
+ make_python_inc_repo = repository_rule(
+ implementation = _python_includes_repo_impl,
+ )
+ make_python_inc_repo(name = "python_includes")
+
+def reverb_python_deps():
+ http_archive(
+ name = "pybind11",
+ urls = [
+ "https://storage.googleapis.com/mirror.tensorflow.org/github.com/pybind/pybind11/archive/v2.4.3.tar.gz",
+ "https://github.com/pybind/pybind11/archive/v2.4.3.tar.gz",
+ ],
+ sha256 = "1eed57bc6863190e35637290f97a20c81cfe4d9090ac0a24f3bbf08f265eb71d",
+ strip_prefix = "pybind11-2.4.3",
+ build_file = clean_dep("//third_party:pybind11.BUILD"),
+ )
+
+ http_archive(
+ name = "absl_py",
+ sha256 = "603febc9b95a8f2979a7bdb77d2f5e4d9b30d4e0d59579f88eba67d4e4cc5462",
+ strip_prefix = "abseil-py-pypi-v0.9.0",
+ urls = [
+ "https://storage.googleapis.com/mirror.tensorflow.org/github.com/abseil/abseil-py/archive/pypi-v0.9.0.tar.gz",
+ "https://github.com/abseil/abseil-py/archive/pypi-v0.9.0.tar.gz",
+ ],
+ )
+
+def _reverb_protoc_archive(ctx):
+ version = ctx.attr.version
+ sha256 = ctx.attr.sha256
+
+ override_version = ctx.os.environ.get("REVERB_PROTOC_VERSION")
+ if override_version:
+ sha256 = ""
+ version = override_version
+
+ urls = [
+ "https://github.com/protocolbuffers/protobuf/releases/download/v%s/protoc-%s-linux-x86_64.zip" % (version, version),
+ ]
+ ctx.download_and_extract(
+ url = urls,
+ sha256 = sha256,
+ )
+
+ ctx.file(
+ "BUILD",
+ content = """
+filegroup(
+ name = "protoc_bin",
+ srcs = ["bin/protoc"],
+ visibility = ["//visibility:public"],
+)
+""",
+ executable = False,
+ )
+
+reverb_protoc_archive = repository_rule(
+ implementation = _reverb_protoc_archive,
+ attrs = {
+ "version": attr.string(mandatory = True),
+ "sha256": attr.string(mandatory = True),
+ },
+)
+
+def reverb_protoc_deps(version, sha256):
+ reverb_protoc_archive(name = "protobuf_protoc", version = version, sha256 = sha256)
diff --git a/reverb/cc/platform/default/snappy.cc b/reverb/cc/platform/default/snappy.cc
new file mode 100644
index 0000000..5fd5e09
--- /dev/null
+++ b/reverb/cc/platform/default/snappy.cc
@@ -0,0 +1,171 @@
+// Copyright 2019 DeepMind Technologies Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "reverb/cc/platform/snappy.h"
+
+#include
+
+#include "absl/meta/type_traits.h"
+#include "snappy-sinksource.h" // NOLINT(build/include)
+#include "snappy.h" // NOLINT(build/include)
+
+namespace deepmind {
+namespace reverb {
+
+namespace {
+
+// Helpers for STLStringResizeUninitialized
+// HasMember is true_type or false_type, depending on whether or not
+// T has a __resize_default_init member. Resize will call the
+// __resize_default_init member if it exists, and will call the resize
+// member otherwise.
+// Primary template: assume the string type lacks __resize_default_init, so
+// Resize falls back to plain resize(), which value-initializes (zeroes) the
+// newly added tail.
+// NOTE(review): the template parameter list appears stripped by extraction
+// here ("template" with no <...>) — confirm against the original source.
+template
+struct ResizeUninitializedTraits {
+ using HasMember = std::false_type;
+ static void Resize(string_type* s, size_t new_size) { s->resize(new_size); }
+};
+
+// __resize_default_init is provided by libc++ >= 8.0.
+// Partial specialization chosen via the detection idiom (void_t SFINAE) when
+// the string type provides __resize_default_init, which grows the string
+// without initializing the new bytes.
+// NOTE(review): the detection expression looks garbled by extraction
+// ("absl::void_t()...") — verify against the original source.
+template
+struct ResizeUninitializedTraits<
+ string_type, absl::void_t()
+ .__resize_default_init(237))> > {
+ using HasMember = std::true_type;
+ static void Resize(string_type* s, size_t new_size) {
+ // Resize without writing to the new tail; caller fills it in.
+ s->__resize_default_init(new_size);
+ }
+};
+
+// Returns true (at compile time) if the string type can be resized without
+// value-initializing the newly exposed bytes, i.e. if the specialization
+// above was selected. The pointer parameter exists only for type deduction.
+template
+inline constexpr bool STLStringSupportsNontrashingResize(string_type*) {
+ return ResizeUninitializedTraits::HasMember::value;
+}
+
+// Resize string `s` to `new_size`, leaving the data uninitialized.
+// Dispatches to __resize_default_init when available, else plain resize().
+static inline void STLStringResizeUninitialized(std::string* s,
+ size_t new_size) {
+ ResizeUninitializedTraits::Resize(s, new_size);
+}
+
+// snappy::Sink that appends into a caller-owned std::string. When the
+// string type supports uninitialized resize, GetAppendBuffer hands snappy a
+// pointer directly into the string so Append can complete with a zero-copy
+// resize; otherwise snappy writes into its scratch buffer and Append copies.
+class StringSink : public snappy::Sink {
+ public:
+ // Does not take ownership of `dest`, which must outlive the sink.
+ explicit StringSink(std::string* dest) : dest_(dest) {}
+
+ StringSink(const StringSink&) = delete;
+ StringSink& operator=(const StringSink&) = delete;
+
+ void Append(const char* data, size_t n) override {
+ if (STLStringSupportsNontrashingResize(dest_)) {
+ size_t current_size = dest_->size();
+ // If `data` already points at the string's own tail (handed out by
+ // GetAppendBuffer below), the bytes are in place: just extend size.
+ if (data == (const_cast(dest_->data()) + current_size)) {
+ // Zero copy append
+ STLStringResizeUninitialized(dest_, current_size + n);
+ return;
+ }
+ }
+ dest_->append(data, n);
+ }
+
+ char* GetAppendBuffer(size_t size, char* scratch) override {
+ if (!STLStringSupportsNontrashingResize(dest_)) {
+ // Without uninitialized resize, pre-extending would zero-fill; let
+ // snappy use its scratch buffer and copy in Append instead.
+ return scratch;
+ }
+
+ const size_t current_size = dest_->size();
+ if ((size + current_size) > dest_->capacity()) {
+ // Use resize instead of reserve so that we grow by the strings growth
+ // factor. Then reset the size to where it was.
+ STLStringResizeUninitialized(dest_, size + current_size);
+ STLStringResizeUninitialized(dest_, current_size);
+ }
+
+ // If string size is zero, then string_as_array() returns nullptr, so
+ // we need to use data() instead
+ return const_cast(dest_->data()) + current_size;
+ }
+
+ private:
+ std::string* dest_; // not owned
+};
+
+// TODO(b/140988915): See if this can be moved to snappy's codebase.
+class CheckedByteArraySink : public snappy::Sink {
+ // A snappy Sink that takes an output buffer and a capacity value. If the
+ // writer attempts to write more data than capacity, it does the safe thing
+ // and doesn't attempt to write past the data boundary. After writing,
+ // call sink.Overflowed() to see if an overflow occurred.
+
+ public:
+ // Does not take ownership of `outbuf`, which must hold `capacity` bytes.
+ CheckedByteArraySink(char* outbuf, size_t capacity)
+ : outbuf_(outbuf), capacity_(capacity), size_(0), overflowed_(false) {}
+ CheckedByteArraySink(const CheckedByteArraySink&) = delete;
+ CheckedByteArraySink& operator=(const CheckedByteArraySink&) = delete;
+
+ void Append(const char* bytes, size_t n) override {
+ size_t available = capacity_ - size_;
+ // Clamp the write to the remaining capacity and remember the overflow.
+ if (n > available) {
+ n = available;
+ overflowed_ = true;
+ }
+ // Skip the copy when `bytes` is already the buffer handed out by
+ // GetAppendBuffer (snappy wrote in place).
+ if (n > 0 && bytes != (outbuf_ + size_)) {
+ // Catch cases where the pointer returned by GetAppendBuffer() was
+ // modified.
+ assert(!(outbuf_ <= bytes && bytes < outbuf_ + capacity_));
+ memcpy(outbuf_ + size_, bytes, n);
+ }
+ size_ += n;
+ }
+
+ char* GetAppendBuffer(size_t length, char* scratch) override {
+ // Hand snappy a direct pointer when the write fits, else fall back to
+ // scratch so Append can clamp the copy safely.
+ size_t available = capacity_ - size_;
+ if (available >= length) {
+ return outbuf_ + size_;
+ } else {
+ return scratch;
+ }
+ }
+
+ // Returns the number of bytes actually written to the sink.
+ size_t NumberOfBytesWritten() const { return size_; }
+
+ // Returns true if any bytes were discarded during the Append(), i.e., if
+ // Append() attempted to write more than 'capacity' bytes.
+ bool Overflowed() const { return overflowed_; }
+
+ private:
+ char* outbuf_; // not owned
+ const size_t capacity_; // fixed size of outbuf_
+ size_t size_; // bytes written so far
+ bool overflowed_; // set once a write was clamped
+};
+
+} // namespace
+
+// Specialization for std::string output: snappy-compresses `input` into
+// *output (appending via StringSink) and returns the compressed byte count.
+template <>
+size_t SnappyCompressFromString(absl::string_view input, std::string* output) {
+ snappy::ByteArraySource source(input.data(), input.size());
+ StringSink sink(output);
+ return snappy::Compress(&source, &sink);
+}
+
+// Specialization for std::string input: uncompresses `input` into the raw
+// buffer `output` of `output_capacity` bytes. Writes are clamped at the
+// capacity (see CheckedByteArraySink); returns snappy's success flag.
+template <>
+bool SnappyUncompressToString(const std::string& input, size_t output_capacity,
+ char* output) {
+ snappy::ByteArraySource source(input.data(), input.size());
+ CheckedByteArraySink sink(output, output_capacity);
+ return snappy::Uncompress(&source, &sink);
+}
+
+} // namespace reverb
+} // namespace deepmind
diff --git a/reverb/cc/platform/default/thread.cc b/reverb/cc/platform/default/thread.cc
new file mode 100644
index 0000000..583d283
--- /dev/null
+++ b/reverb/cc/platform/default/thread.cc
@@ -0,0 +1,45 @@
+// Copyright 2019 DeepMind Technologies Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "reverb/cc/platform/thread.h"
+
+#include // NOLINT(build/c++11)
+
+#include "absl/memory/memory.h"
+
+namespace deepmind {
+namespace reverb {
+namespace internal {
+namespace {
+
+// Thread implementation backed by std::thread. The function starts running
+// immediately on construction; the destructor blocks until it finishes.
+class StdThread : public Thread {
+ public:
+ // Starts executing `fn` on a new std::thread.
+ explicit StdThread(std::function fn) : thread_(std::move(fn)) {}
+
+ // Joins the thread, so destruction waits for `fn` to return.
+ ~StdThread() override { thread_.join(); }
+
+ private:
+ std::thread thread_;
+};
+
+} // namespace
+
+// Starts `fn` on a new StdThread. The `name` parameter is ignored by this
+// default platform implementation (std::thread has no portable naming).
+std::unique_ptr StartThread(absl::string_view name,
+ std::function fn) {
+ return {absl::make_unique(std::move(fn))};
+}
+
+} // namespace internal
+} // namespace reverb
+} // namespace deepmind
diff --git a/reverb/cc/platform/grpc_utils.h b/reverb/cc/platform/grpc_utils.h
new file mode 100644
index 0000000..12f4659
--- /dev/null
+++ b/reverb/cc/platform/grpc_utils.h
@@ -0,0 +1,40 @@
+// Copyright 2019 DeepMind Technologies Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef REVERB_CC_PLATFORM_GRPC_UTILS_H_
+#define REVERB_CC_PLATFORM_GRPC_UTILS_H_
+
+#include <memory>
+
+#include "absl/strings/string_view.h"
+#include "grpcpp/security/credentials.h"
+#include "grpcpp/security/server_credentials.h"
+#include "grpcpp/support/channel_arguments.h"
+
+namespace deepmind {
+namespace reverb {
+
+// Builds the server credentials used by the reverb gRPC server.
+std::shared_ptr<grpc::ServerCredentials> MakeServerCredentials();
+
+// Builds the channel (client-side) credentials used to connect to a server.
+std::shared_ptr<grpc::ChannelCredentials> MakeChannelCredentials();
+
+// Creates a gRPC channel to `target` using `credentials` and the provided
+// `channel_arguments`.
+std::shared_ptr<grpc::Channel> CreateCustomGrpcChannel(
+    absl::string_view target,
+    const std::shared_ptr<grpc::ChannelCredentials>& credentials,
+    const grpc::ChannelArguments& channel_arguments);
+
+}  // namespace reverb
+}  // namespace deepmind
+
+#endif  // REVERB_CC_PLATFORM_GRPC_UTILS_H_
diff --git a/reverb/cc/platform/logging.h b/reverb/cc/platform/logging.h
new file mode 100644
index 0000000..a7d1335
--- /dev/null
+++ b/reverb/cc/platform/logging.h
@@ -0,0 +1,20 @@
+// Copyright 2019 DeepMind Technologies Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef REVERB_CC_PLATFORM_LOGGING_H_
+#define REVERB_CC_PLATFORM_LOGGING_H_
+
+// Platform-selection shim: forwards to the default (open-source) logging
+// implementation. Include this header rather than the default/ one directly.
+#include "reverb/cc/platform/default/logging.h"
+
+#endif  // REVERB_CC_PLATFORM_LOGGING_H_
diff --git a/reverb/cc/platform/net.h b/reverb/cc/platform/net.h
new file mode 100644
index 0000000..6113d80
--- /dev/null
+++ b/reverb/cc/platform/net.h
@@ -0,0 +1,22 @@
+// Copyright 2019 DeepMind Technologies Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef REVERB_CC_PLATFORM_NET_H_
+#define REVERB_CC_PLATFORM_NET_H_
+
+namespace deepmind::reverb::internal {
+
+// Returns a port number that was unused at the time of the call, or
+// terminates the process if no port could be found.
+int PickUnusedPortOrDie();
+
+}  // namespace deepmind::reverb::internal
+
+#endif  // REVERB_CC_PLATFORM_NET_H_
diff --git a/reverb/cc/platform/net_test.cc b/reverb/cc/platform/net_test.cc
new file mode 100644
index 0000000..e877bdd
--- /dev/null
+++ b/reverb/cc/platform/net_test.cc
@@ -0,0 +1,36 @@
+// Copyright 2019 DeepMind Technologies Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "reverb/cc/platform/net.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "reverb/cc/platform/logging.h"
+
+namespace deepmind::reverb::internal {
+namespace {
+
+// Picks two ports and verifies both lie in the valid TCP port range and are
+// distinct. EXPECT_* (rather than process-aborting CHECKs) lets all
+// violations be reported as test failures.
+TEST(Net, PickUnusedPortOrDie) {
+  int port0 = PickUnusedPortOrDie();
+  int port1 = PickUnusedPortOrDie();
+  EXPECT_GE(port0, 0);
+  EXPECT_LT(port0, 65536);
+  EXPECT_GE(port1, 0);
+  EXPECT_LT(port1, 65536);
+  EXPECT_NE(port0, port1);
+}
+
+}  // namespace
+}  // namespace deepmind::reverb::internal
diff --git a/reverb/cc/platform/snappy.h b/reverb/cc/platform/snappy.h
new file mode 100644
index 0000000..209313f
--- /dev/null
+++ b/reverb/cc/platform/snappy.h
@@ -0,0 +1,37 @@
+// Copyright 2019 DeepMind Technologies Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef REVERB_CC_PLATFORM_SNAPPY_H_
+#define REVERB_CC_PLATFORM_SNAPPY_H_
+
+#include
+
+#include "absl/strings/string_view.h"
+
+namespace deepmind {
+namespace reverb {
+
+// Compress a string to a `Toutput` output. Return the number of bytes stored.
+template
+size_t SnappyCompressFromString(absl::string_view input, Toutput* output);
+
+// Uncompress an `input` containing snappy-compressed data to *output.
+template
+bool SnappyUncompressToString(const Tinput& input, size_t output_capacity,
+ char* output);
+
+} // namespace reverb
+} // namespace deepmind
+
+#endif // REVERB_CC_PLATFORM_SNAPPY_H_
diff --git a/reverb/cc/platform/tfrecord_checkpointer.cc b/reverb/cc/platform/tfrecord_checkpointer.cc
new file mode 100644
index 0000000..4be2fb2
--- /dev/null
+++ b/reverb/cc/platform/tfrecord_checkpointer.cc
@@ -0,0 +1,332 @@
+// Copyright 2019 DeepMind Technologies Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "reverb/cc/platform/tfrecord_checkpointer.h"
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/string_view.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "reverb/cc/checkpointing/checkpoint.pb.h"
+#include "reverb/cc/chunk_store.h"
+#include "reverb/cc/distributions/fifo.h"
+#include "reverb/cc/distributions/heap.h"
+#include "reverb/cc/distributions/interface.h"
+#include "reverb/cc/distributions/lifo.h"
+#include "reverb/cc/distributions/prioritized.h"
+#include "reverb/cc/distributions/uniform.h"
+#include "reverb/cc/priority_table.h"
+#include "reverb/cc/rate_limiter.h"
+#include "reverb/cc/schema.pb.h"
+#include "reverb/cc/table_extensions/interface.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/io/record_reader.h"
+#include "tensorflow/core/lib/io/record_writer.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/file_system.h"
+
+namespace deepmind {
+namespace reverb {
+namespace {
+
+// File names used inside each checkpoint directory.
+constexpr char kTablesFileName[] = "tables.tfrecord";
+constexpr char kChunksFileName[] = "chunks.tfrecord";
+// Empty marker file written last; its presence means the checkpoint is
+// complete (see HasDone/WriteDone below).
+constexpr char kDoneFileName[] = "DONE";
+
+// unique_ptrs with custom deleters that also free the underlying
+// tensorflow file object captured by OpenWriter/OpenReader.
+using RecordWriterUniquePtr =
+ std::unique_ptr>;
+using RecordReaderUniquePtr =
+ std::unique_ptr>;
+
+// Opens `path` for writing and wraps it in a RecordWriter. The writable
+// file's ownership is transferred into the returned pointer's deleter, so
+// the file is closed/freed together with the writer.
+tensorflow::Status OpenWriter(const std::string& path,
+ RecordWriterUniquePtr* writer) {
+ std::unique_ptr file;
+ TF_RETURN_IF_ERROR(tensorflow::Env::Default()->NewWritableFile(path, &file));
+ // Release the file so the custom deleter below controls its lifetime.
+ auto* file_ptr = file.release();
+ *writer = RecordWriterUniquePtr(new tensorflow::io::RecordWriter(file_ptr),
+ [file_ptr](tensorflow::io::RecordWriter* w) {
+ delete w;
+ delete file_ptr;
+ });
+ return tensorflow::Status::OK();
+}
+
+// Opens `path` for random-access reading and wraps it in a RecordReader.
+// As with OpenWriter, the file's lifetime is tied to the returned reader
+// via the custom deleter.
+tensorflow::Status OpenReader(const std::string& path,
+ RecordReaderUniquePtr* reader) {
+ std::unique_ptr file;
+ TF_RETURN_IF_ERROR(
+ tensorflow::Env::Default()->NewRandomAccessFile(path, &file));
+ auto* file_ptr = file.release();
+ *reader = RecordReaderUniquePtr(new tensorflow::io::RecordReader(file_ptr),
+ [file_ptr](tensorflow::io::RecordReader* r) {
+ delete r;
+ delete file_ptr;
+ });
+ return tensorflow::Status::OK();
+}
+
+// Creates the empty DONE marker file in checkpoint directory `path`,
+// signalling that the checkpoint was written completely.
+inline tensorflow::Status WriteDone(const std::string& path) {
+ std::unique_ptr file;
+ TF_RETURN_IF_ERROR(tensorflow::Env::Default()->NewWritableFile(
+ tensorflow::io::JoinPath(path, kDoneFileName), &file));
+ return file->Close();
+}
+
+// Returns true iff checkpoint directory `path` contains the DONE marker,
+// i.e. the checkpoint was fully written and can be loaded safely.
+inline bool HasDone(const std::string& path) {
+ return tensorflow::Env::Default()
+ ->FileExists(tensorflow::io::JoinPath(path, kDoneFileName))
+ .ok();
+}
+
+// Reconstructs a key-distribution object from its checkpointed options.
+// Fatal-logs (process abort) if the options are unset or unrecognized, since
+// that indicates a corrupt or incompatible checkpoint.
+std::unique_ptr MakeDistribution(
+ const KeyDistributionOptions& options) {
+ switch (options.distribution_case()) {
+ case KeyDistributionOptions::kFifo:
+ return absl::make_unique();
+ case KeyDistributionOptions::kLifo:
+ return absl::make_unique();
+ case KeyDistributionOptions::kUniform:
+ return absl::make_unique();
+ case KeyDistributionOptions::kPrioritized:
+ return absl::make_unique(
+ options.prioritized().priority_exponent());
+ case KeyDistributionOptions::kHeap:
+ return absl::make_unique(options.heap().min_heap());
+ case KeyDistributionOptions::DISTRIBUTION_NOT_SET:
+ REVERB_LOG(REVERB_FATAL) << "Distribution not set";
+ default:
+ REVERB_LOG(REVERB_FATAL) << "Distribution not supported";
+ }
+}
+
+// Returns the index in `tables` of the table named `name`, or -1 if no such
+// table exists. Returns int (not size_t) so the -1 sentinel is well-defined;
+// the caller stores the result in an int and compares against -1.
+inline int find_table_index(
+    const std::vector<std::shared_ptr<PriorityTable>>* tables,
+    const std::string& name) {
+  for (int i = 0; i < static_cast<int>(tables->size()); i++) {
+    if (tables->at(i)->name() == name) return i;
+  }
+  return -1;
+}
+
+} // namespace
+
+// Constructs a checkpointer rooted at `root_dir`. `group` is stored but a
+// non-empty value is rejected later by Save() (unsupported in this build).
+TFRecordCheckpointer::TFRecordCheckpointer(std::string root_dir,
+ std::string group)
+ : root_dir_(std::move(root_dir)), group_(std::move(group)) {
+ REVERB_LOG(REVERB_INFO) << "Initializing TFRecordCheckpointer in "
+ << root_dir_;
+}
+
+// Writes a new checkpoint directory under root_dir_ containing the state of
+// every table in `tables` plus the deduplicated set of chunks they
+// reference, then prunes all but the `keep_latest` newest checkpoints.
+// On success *path receives the new checkpoint directory.
+tensorflow::Status TFRecordCheckpointer::Save(
+ std::vector tables, int keep_latest, std::string* path) {
+ if (keep_latest <= 0) {
+ return tensorflow::errors::InvalidArgument(
+ "TFRecordCheckpointer must have keep_latest > 0.");
+ }
+ if (!group_.empty()) {
+ return tensorflow::errors::InvalidArgument(
+ "Setting non-empty group is not supported");
+ }
+
+ // Name the directory with the current timestamp; the sort further down
+ // assumes these names order chronologically (assumed true for
+ // absl::FormatTime's default format — TODO confirm).
+ std::string dir_path =
+ tensorflow::io::JoinPath(root_dir_, absl::FormatTime(absl::Now()));
+ TF_RETURN_IF_ERROR(
+ tensorflow::Env::Default()->RecursivelyCreateDir(dir_path));
+
+ RecordWriterUniquePtr table_writer;
+ TF_RETURN_IF_ERROR(OpenWriter(
+ tensorflow::io::JoinPath(dir_path, kTablesFileName), &table_writer));
+
+ // Checkpoint every table; merging into `chunks` both deduplicates the
+ // referenced chunks and keeps them alive until they are written out.
+ absl::flat_hash_set> chunks;
+ for (PriorityTable* table : tables) {
+ auto checkpoint = table->Checkpoint();
+ chunks.merge(checkpoint.chunks);
+ TF_RETURN_IF_ERROR(
+ table_writer->WriteRecord(checkpoint.checkpoint.SerializeAsString()));
+ }
+
+ // Close before writing chunks so the tables file is fully flushed.
+ TF_RETURN_IF_ERROR(table_writer->Close());
+ table_writer = nullptr;
+
+ RecordWriterUniquePtr chunk_writer;
+ TF_RETURN_IF_ERROR(OpenWriter(
+ tensorflow::io::JoinPath(dir_path, kChunksFileName), &chunk_writer));
+
+ for (const auto& chunk : chunks) {
+ TF_RETURN_IF_ERROR(
+ chunk_writer->WriteRecord(chunk->data().SerializeAsString()));
+ }
+ TF_RETURN_IF_ERROR(chunk_writer->Close());
+ chunk_writer = nullptr;
+
+ // Both chunks and table checkpoint has now been written so we can proceed to
+ // add the DONE-file.
+ TF_RETURN_IF_ERROR(WriteDone(dir_path));
+
+ // Delete the older checkpoints.
+ std::vector filenames;
+ TF_RETURN_IF_ERROR(tensorflow::Env::Default()->GetMatchingPaths(
+ tensorflow::io::JoinPath(root_dir_, "*"), &filenames));
+ std::sort(filenames.begin(), filenames.end());
+ // Iterate newest-first and remove everything past the keep_latest quota.
+ int history_counter = 0;
+ for (auto it = filenames.rbegin(); it != filenames.rend(); it++) {
+ if (++history_counter > keep_latest) {
+ tensorflow::int64 undeleted_files;
+ tensorflow::int64 undeleted_dirs;
+ TF_RETURN_IF_ERROR(tensorflow::Env::Default()->DeleteRecursively(
+ *it, &undeleted_files, &undeleted_dirs));
+ }
+ }
+
+ *path = std::move(dir_path);
+ return tensorflow::Status::OK();
+}
+
+// Loads the checkpoint at root_dir_/relative_path: inserts all chunks into
+// `chunk_store`, then rebuilds each checkpointed table and swaps it into the
+// matching (by name) entry of `tables`. Fails if the DONE marker is absent
+// or if a checkpointed table has no counterpart in `tables`.
+tensorflow::Status TFRecordCheckpointer::Load(
+ absl::string_view relative_path, ChunkStore* chunk_store,
+ std::vector>* tables) {
+ const std::string dir_path =
+ tensorflow::io::JoinPath(root_dir_, relative_path);
+ REVERB_LOG(REVERB_INFO) << "Loading checkpoint from " << dir_path;
+ if (!HasDone(dir_path)) {
+ return tensorflow::errors::InvalidArgument(
+ absl::StrCat("Load called with invalid checkpoint path: ", dir_path));
+ }
+ // Insert data first to ensure that all data referenced by the priority tables
+ // exists. Keep the map of chunks around so that none of the chunks are
+ // cleaned up before all the priority tables have been loaded.
+ absl::flat_hash_map>
+ chunk_by_key;
+ {
+ RecordReaderUniquePtr chunk_reader;
+ TF_RETURN_IF_ERROR(OpenReader(
+ tensorflow::io::JoinPath(dir_path, kChunksFileName), &chunk_reader));
+
+ ChunkData chunk_data;
+ tensorflow::Status chunk_status;
+ tensorflow::uint64 chunk_offset = 0;
+ tensorflow::tstring chunk_record;
+ // Read until ReadRecord fails; a clean EOF surfaces as OutOfRange.
+ do {
+ chunk_status = chunk_reader->ReadRecord(&chunk_offset, &chunk_record);
+ if (!chunk_status.ok()) break;
+ if (!chunk_data.ParseFromArray(chunk_record.data(),
+ chunk_record.size())) {
+ return tensorflow::errors::DataLoss(
+ "Could not parse TFRecord as ChunkData: '", chunk_record, "'");
+ }
+ chunk_by_key[chunk_data.chunk_key()] = chunk_store->Insert(chunk_data);
+ } while (chunk_status.ok());
+ // Any terminal status other than OutOfRange (EOF) is a real error.
+ if (!tensorflow::errors::IsOutOfRange(chunk_status)) {
+ return chunk_status;
+ }
+ }
+
+ RecordReaderUniquePtr table_reader;
+ TF_RETURN_IF_ERROR(OpenReader(
+ tensorflow::io::JoinPath(dir_path, kTablesFileName), &table_reader));
+
+ PriorityTableCheckpoint checkpoint;
+ tensorflow::Status table_status;
+ tensorflow::uint64 table_offset = 0;
+ tensorflow::tstring table_record;
+ // One record per checkpointed table; same EOF convention as above.
+ do {
+ table_status = table_reader->ReadRecord(&table_offset, &table_record);
+ if (!table_status.ok()) break;
+ if (!checkpoint.ParseFromArray(table_record.data(), table_record.size())) {
+ return tensorflow::errors::DataLoss(
+ "Could not parse TFRecord as Checkpoint: '", table_record, "'");
+ }
+
+ // The checkpointed table must already exist in `tables`; its extensions
+ // are reused for the rebuilt table below.
+ int index = find_table_index(tables, checkpoint.table_name());
+ if (index == -1) {
+ std::vector table_names;
+ for (const auto& table : *tables) {
+ table_names.push_back(absl::StrCat("'", table->name(), "'"));
+ }
+ return tensorflow::errors::InvalidArgument(
+ "Trying to load table ", checkpoint.table_name(),
+ " but table was not found in provided list of tables. Available "
+ "tables: [",
+ absl::StrJoin(table_names, ", "), "]");
+ }
+
+ // Rebuild the table's dependencies from their checkpointed options.
+ auto sampler = MakeDistribution(checkpoint.sampler());
+ auto remover = MakeDistribution(checkpoint.remover());
+ auto rate_limiter =
+ std::make_shared(checkpoint.rate_limiter());
+
+ auto table = std::make_shared(
+ /*name=*/checkpoint.table_name(), /*sampler=*/std::move(sampler),
+ /*remover=*/std::move(remover),
+ /*max_size=*/checkpoint.max_size(),
+ /*max_times_sampled=*/checkpoint.max_times_sampled(),
+ /*rate_limiter=*/std::move(rate_limiter),
+ /*extensions=*/tables->at(index)->extensions());
+
+ // Re-insert every checkpointed item, wiring it to the chunks loaded in
+ // the first pass. A missing chunk means a corrupt checkpoint.
+ for (const auto& checkpoint_item : checkpoint.items()) {
+ PriorityTable::Item insert_item;
+ insert_item.item = checkpoint_item;
+
+ for (const auto& key : checkpoint_item.chunk_keys()) {
+ REVERB_CHECK(chunk_by_key.contains(key));
+ insert_item.chunks.push_back(chunk_by_key[key]);
+ }
+
+ TF_RETURN_IF_ERROR(table->InsertCheckpointItem(std::move(insert_item)));
+ }
+
+ // Replace the caller's table with the freshly restored one.
+ tables->at(index).swap(table);
+ } while (table_status.ok());
+
+ if (!tensorflow::errors::IsOutOfRange(table_status)) {
+ return table_status;
+ }
+ return tensorflow::Status::OK();
+}
+
+// Finds the newest complete checkpoint (lexicographically greatest directory
+// that contains a DONE marker) under root_dir_ and delegates to Load().
+// Returns NotFound when no complete checkpoint exists.
+tensorflow::Status TFRecordCheckpointer::LoadLatest(
+ ChunkStore* chunk_store,
+ std::vector>* tables) {
+ REVERB_LOG(REVERB_INFO) << "Loading latest checkpoint from " << root_dir_;
+ std::vector filenames;
+ TF_RETURN_IF_ERROR(tensorflow::Env::Default()->GetMatchingPaths(
+ tensorflow::io::JoinPath(root_dir_, "*"), &filenames));
+ std::sort(filenames.begin(), filenames.end());
+ // Scan newest-first; skip incomplete (no DONE) checkpoint directories.
+ for (auto it = filenames.rbegin(); it != filenames.rend(); it++) {
+ if (HasDone(*it)) {
+ return Load(tensorflow::io::Basename(*it), chunk_store, tables);
+ }
+ }
+ return tensorflow::errors::NotFound(
+ absl::StrCat("No checkpoint found in ", root_dir_));
+}
+
+} // namespace reverb
+} // namespace deepmind
diff --git a/reverb/cc/platform/tfrecord_checkpointer.h b/reverb/cc/platform/tfrecord_checkpointer.h
new file mode 100644
index 0000000..7575395
--- /dev/null
+++ b/reverb/cc/platform/tfrecord_checkpointer.h
@@ -0,0 +1,100 @@
+// Copyright 2019 DeepMind Technologies Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef REVERB_CC_PLATFORM_TFRECORD_CHECKPOINTER_H_
+#define REVERB_CC_PLATFORM_TFRECORD_CHECKPOINTER_H_
+
+#include
+#include
+#include
+
+#include "absl/strings/string_view.h"
+#include "reverb/cc/checkpointing/interface.h"
+#include "reverb/cc/chunk_store.h"
+#include "reverb/cc/priority_table.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace deepmind {
+namespace reverb {
+
+// Generates and stores proto checkpoints of PriorityTables and ChunkStore data
+// to a directory inside the top level `root_dir`.
+//
+// A set of PriorityTable constitutes the bases for a checkpoint. When `Save` is
+// called state of each PriorityTable is encoded into a PriorityTableCheckpoint.
+// The proto contains the state and initialization options of the table itself
+// and all its dependencies (RateLimiter, KeyDistribution etc) but does not
+// include the actual data. Instead a container with shared_ptr to every
+// referenced ChunkStore::Chunk is attached which ensures that all data remains
+// for the complete duration of the checkpointing operation.
+//
+// To avoid duplicating data, the union of the referenced chunks are
+// deduplicated before being stored to disk. The stored checkpoint has the
+// following format:
+//
+// /
+// /
+// tables.tfrecord
+// chunks.tfrecord
+// DONE
+//
+// DONE an empty file written once the checkpoint has been successfully written.
+// If DONE does not exist then the checkpoint is in process of being written or
+// the operation was unexpectedly interrupted and the data should be considered
+// corrupt.
+//
+// The most recent checkpoint can therefore be inferred from the name of the
+// directories within `root_dir`.
+//
+// If `group` is nonempty then the directory containing the checkpoint will be
+// created with `group` as group.
+class TFRecordCheckpointer : public CheckpointerInterface {
+ public:
+ // `root_dir`: top-level directory that holds checkpoint subdirectories.
+ // `group`: optional filesystem group for created directories; a non-empty
+ // value is currently rejected by Save().
+ explicit TFRecordCheckpointer(std::string root_dir, std::string group = "");
+
+ // Save a new checkpoint for every table in `tables` in sub directory
+ // inside `root_dir_`. If the call is successful, the ABSOLUTE path to the
+ // newly created checkpoint directory is returned.
+ //
+ // If `root_path_` does not exist then `Save` attempts to recursively
+ // create it before proceeding.
+ //
+ // After a successful save, all but the `keep_latest` most recent checkpoints
+ // are deleted.
+ tensorflow::Status Save(std::vector tables, int keep_latest,
+ std::string* path) override;
+
+ // Attempts to load a checkpoint stored within `root_dir_`.
+ tensorflow::Status Load(
+ absl::string_view relative_path, ChunkStore* chunk_store,
+ std::vector>* tables) override;
+
+ // Finds the most recent checkpoint within `root_dir_` and calls `Load`.
+ tensorflow::Status LoadLatest(
+ ChunkStore* chunk_store,
+ std::vector>* tables) override;
+
+ // TFRecordCheckpointer is neither copyable nor movable.
+ TFRecordCheckpointer(const TFRecordCheckpointer&) = delete;
+ TFRecordCheckpointer& operator=(const TFRecordCheckpointer&) = delete;
+
+ private:
+ // Top-level directory under which checkpoint directories are created.
+ const std::string root_dir_;
+ // Filesystem group for created directories ("" = default group).
+ const std::string group_;
+};
+
+} // namespace reverb
+} // namespace deepmind
+
+#endif // REVERB_CC_PLATFORM_TFRECORD_CHECKPOINTER_H_
diff --git a/reverb/cc/platform/tfrecord_checkpointer_test.cc b/reverb/cc/platform/tfrecord_checkpointer_test.cc
new file mode 100644
index 0000000..a1bd403
--- /dev/null
+++ b/reverb/cc/platform/tfrecord_checkpointer_test.cc
@@ -0,0 +1,233 @@
+// Copyright 2019 DeepMind Technologies Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "reverb/cc/platform/tfrecord_checkpointer.h"
+
+#include <algorithm>
+#include <cfloat>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "reverb/cc/chunk_store.h"
+#include "reverb/cc/distributions/fifo.h"
+#include "reverb/cc/distributions/heap.h"
+#include "reverb/cc/distributions/prioritized.h"
+#include "reverb/cc/distributions/uniform.h"
+#include "reverb/cc/priority_table.h"
+#include "reverb/cc/rate_limiter.h"
+#include "reverb/cc/testing/proto_test_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/platform/env.h"
+
+namespace deepmind {
+namespace reverb {
+namespace {
+
+using ::deepmind::reverb::testing::EqualsProto;
+
+// Returns a fresh, unique temporary path suitable as a checkpoint root.
+inline std::string MakeRoot() {
+  std::string path;
+  REVERB_CHECK(tensorflow::Env::Default()->LocalTempFilename(&path));
+  return path;
+}
+
+// Builds a `PriorityTable` that samples uniformly and evicts in FIFO order.
+// Capacity is 1000 items; `times_sampled` is unlimited (0).
+std::unique_ptr<PriorityTable> MakeUniformTable(const std::string& name) {
+  return absl::make_unique<PriorityTable>(
+      name, absl::make_unique<UniformDistribution>(),
+      absl::make_unique<FifoDistribution>(), 1000, 0,
+      absl::make_unique<RateLimiter>(1.0, 1, -DBL_MAX, DBL_MAX));
+}
+
+// Builds a `PriorityTable` with prioritized sampling (priority^exponent) and
+// FIFO eviction. Capacity is 1000 items; `times_sampled` is unlimited (0).
+std::unique_ptr<PriorityTable> MakePrioritizedTable(const std::string& name,
+                                                    double exponent) {
+  return absl::make_unique<PriorityTable>(
+      name, absl::make_unique<PrioritizedDistribution>(exponent),
+      absl::make_unique<FifoDistribution>(), 1000, 0,
+      absl::make_unique<RateLimiter>(1.0, 1, -DBL_MAX, DBL_MAX));
+}
+
+TEST(TFRecordCheckpointerTest, CreatesDirectoryInRoot) {
+  std::string root = MakeRoot();
+  TFRecordCheckpointer checkpointer(root);
+  std::string path;
+  auto* env = tensorflow::Env::Default();
+  // Saving an empty set of tables must still create a checkpoint directory
+  // directly inside `root`.
+  TF_ASSERT_OK(checkpointer.Save(std::vector<PriorityTable*>{}, 1, &path));
+  ASSERT_EQ(tensorflow::io::Dirname(path), root);
+  TF_EXPECT_OK(env->FileExists(path));
+}
+
+TEST(TFRecordCheckpointerTest, SaveAndLoad) {
+  ChunkStore chunk_store;
+
+  std::vector<std::unique_ptr<PriorityTable>> tables;
+  tables.push_back(MakeUniformTable("uniform"));
+  tables.push_back(MakePrioritizedTable("prioritized_a", 0.5));
+  tables.push_back(MakePrioritizedTable("prioritized_b", 0.9));
+
+  // Insert 100 items into each table, each referencing a unique chunk.
+  std::vector<ChunkStore::Key> chunk_keys;
+  for (int i = 0; i < 100; i++) {
+    for (int j = 0; j < tables.size(); j++) {
+      chunk_keys.push_back((j + 1) * 1000 + i);
+      auto chunk =
+          chunk_store.Insert(testing::MakeChunkData(chunk_keys.back()));
+      TF_EXPECT_OK(tables[j]->InsertOrAssign(
+          {testing::MakePrioritizedItem(i, i, {chunk->data()}), {chunk}}));
+    }
+  }
+
+  // Sample from every table so `times_sampled` becomes nonzero and thus must
+  // be round-tripped by Save/Load.
+  for (int i = 0; i < 100; i++) {
+    for (auto& table : tables) {
+      PriorityTable::SampledItem sample;
+      TF_EXPECT_OK(table->Sample(&sample));
+    }
+  }
+
+  TFRecordCheckpointer checkpointer(MakeRoot());
+
+  std::string path;
+  TF_ASSERT_OK(checkpointer.Save(
+      {tables[0].get(), tables[1].get(), tables[2].get()}, 1, &path));
+
+  ChunkStore loaded_chunk_store;
+  std::vector<std::shared_ptr<PriorityTable>> loaded_tables;
+  loaded_tables.push_back(MakeUniformTable("uniform"));
+  loaded_tables.push_back(MakePrioritizedTable("prioritized_a", 0.5));
+  loaded_tables.push_back(MakePrioritizedTable("prioritized_b", 0.9));
+  TF_ASSERT_OK(checkpointer.Load(tensorflow::io::Basename(path),
+                                 &loaded_chunk_store, &loaded_tables));
+
+  // Check that all the chunks have been added.
+  std::vector<std::shared_ptr<ChunkStore::Chunk>> chunks;
+  TF_EXPECT_OK(loaded_chunk_store.Get(chunk_keys, &chunks));
+
+  // Check that the number of items matches for the loaded tables.
+  for (int i = 0; i < tables.size(); i++) {
+    EXPECT_EQ(loaded_tables[i]->size(), tables[i]->size());
+  }
+
+  // Sample a random item and check that it matches the item in the original
+  // table (after accounting for the extra sample in `times_sampled`).
+  for (int i = 0; i < tables.size(); i++) {
+    PriorityTable::SampledItem sample;
+    TF_EXPECT_OK(loaded_tables[i]->Sample(&sample));
+    bool item_found = false;
+    for (auto& item : tables[i]->Copy()) {
+      if (item.item.key() == sample.item.key()) {
+        item_found = true;
+        item.item.set_times_sampled(item.item.times_sampled() + 1);
+        EXPECT_THAT(item.item, EqualsProto(sample.item));
+        break;
+      }
+    }
+    EXPECT_TRUE(item_found);
+  }
+}
+
+TEST(TFRecordCheckpointerTest, SaveDeletesOldData) {
+  ChunkStore chunk_store;
+
+  std::vector<std::unique_ptr<PriorityTable>> tables;
+  tables.push_back(MakeUniformTable("uniform"));
+  tables.push_back(MakePrioritizedTable("prioritized_a", 0.5));
+  tables.push_back(MakePrioritizedTable("prioritized_b", 0.9));
+
+  std::vector<ChunkStore::Key> chunk_keys;
+  for (int i = 0; i < 100; i++) {
+    for (int j = 0; j < tables.size(); j++) {
+      chunk_keys.push_back((j + 1) * 1000 + i);
+      auto chunk =
+          chunk_store.Insert(testing::MakeChunkData(chunk_keys.back()));
+      TF_EXPECT_OK(tables[j]->InsertOrAssign(
+          {testing::MakePrioritizedItem(i, i, {chunk->data()}), {chunk}}));
+    }
+  }
+
+  for (int i = 0; i < 100; i++) {
+    for (auto& table : tables) {
+      PriorityTable::SampledItem sample;
+      TF_EXPECT_OK(table->Sample(&sample));
+    }
+  }
+
+  // Saves 10 checkpoints and verifies that after every save only the
+  // `keep_latest` most recent ones remain on disk.
+  auto test = [&tables](int keep_latest) {
+    auto root = MakeRoot();
+    TFRecordCheckpointer checkpointer(root);
+
+    for (int i = 0; i < 10; i++) {
+      std::string path;
+      TF_ASSERT_OK(
+          checkpointer.Save({tables[0].get(), tables[1].get(), tables[2].get()},
+                            keep_latest, &path));
+
+      std::vector<std::string> filenames;
+      TF_ASSERT_OK(tensorflow::Env::Default()->GetMatchingPaths(
+          tensorflow::io::JoinPath(root, "*"), &filenames));
+      ASSERT_EQ(filenames.size(), std::min(keep_latest, i + 1));
+    }
+  };
+  test(1);  // Keep one checkpoint.
+  test(3);  // Edge case keep_latest == num_tables
+  test(5);  // Edge case keep_latest > num_tables
+}
+
+TEST(TFRecordCheckpointerTest, KeepLatestZeroReturnsError) {
+  ChunkStore chunk_store;
+
+  std::vector<std::unique_ptr<PriorityTable>> tables;
+  tables.push_back(MakeUniformTable("uniform"));
+  tables.push_back(MakePrioritizedTable("prioritized_a", 0.5));
+  tables.push_back(MakePrioritizedTable("prioritized_b", 0.9));
+
+  std::vector<ChunkStore::Key> chunk_keys;
+  for (int i = 0; i < 100; i++) {
+    for (int j = 0; j < tables.size(); j++) {
+      chunk_keys.push_back((j + 1) * 1000 + i);
+      auto chunk =
+          chunk_store.Insert(testing::MakeChunkData(chunk_keys.back()));
+      TF_EXPECT_OK(tables[j]->InsertOrAssign(
+          {testing::MakePrioritizedItem(i, i, {chunk->data()}), {chunk}}));
+    }
+  }
+
+  for (int i = 0; i < 100; i++) {
+    for (auto& table : tables) {
+      PriorityTable::SampledItem sample;
+      TF_EXPECT_OK(table->Sample(&sample));
+    }
+  }
+
+  TFRecordCheckpointer checkpointer(MakeRoot());
+  std::string path;
+  // keep_latest == 0 would delete a checkpoint as soon as it was written, so
+  // `Save` must reject it.
+  EXPECT_EQ(
+      checkpointer
+          .Save({tables[0].get(), tables[1].get(), tables[2].get()}, 0, &path)
+          .code(),
+      tensorflow::error::INVALID_ARGUMENT);
+}
+
+TEST(TFRecordCheckpointerTest, LoadLatestInEmptyDir) {
+  TFRecordCheckpointer checkpointer(MakeRoot());
+  ChunkStore chunk_store;
+  std::vector<std::shared_ptr<PriorityTable>> tables;
+  // With no checkpoints on disk there is nothing to load.
+  EXPECT_EQ(checkpointer.LoadLatest(&chunk_store, &tables).code(),
+            tensorflow::error::NOT_FOUND);
+}
+
+} // namespace
+} // namespace reverb
+} // namespace deepmind
diff --git a/reverb/cc/platform/thread.h b/reverb/cc/platform/thread.h
new file mode 100644
index 0000000..5d9def7
--- /dev/null
+++ b/reverb/cc/platform/thread.h
@@ -0,0 +1,58 @@
+// Copyright 2019 DeepMind Technologies Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Used to switch between different threading implementations. Concretely we
+// switch between Google internal threading libraries and std::thread.
+
+#ifndef REVERB_CC_SUPPORT_THREAD_H_
+#define REVERB_CC_SUPPORT_THREAD_H_
+
+#include <functional>
+#include <memory>
+
+#include "absl/strings/string_view.h"
+
+namespace deepmind {
+namespace reverb {
+namespace internal {
+
+// The `Thread` class can be subclassed to hold an object that invokes a
+// method in a separate thread. A `Thread` is considered to be active after
+// construction until the execution terminates. Calling the destructor of this
+// class must join the separate thread and block until it has completed.
+// Use `StartThread()` to create an instance of this class.
+class Thread {
+ public:
+  // Joins the running thread, i.e. blocks until the thread function has
+  // returned.
+  virtual ~Thread() = default;
+
+  // A Thread is not copyable.
+  Thread(const Thread&) = delete;
+  Thread& operator=(const Thread&) = delete;
+
+ protected:
+  // Construction is reserved for subclasses; users obtain instances through
+  // `StartThread()`.
+  Thread() = default;
+};
+
+// Starts a new thread that executes (a copy of) `fn`. The `name_prefix` may
+// be used by the implementation to label the new thread.
+std::unique_ptr<Thread> StartThread(absl::string_view name_prefix,
+                                    std::function<void()> fn);
+
+} // namespace internal
+} // namespace reverb
+} // namespace deepmind
+
+#endif // REVERB_CC_SUPPORT_THREAD_H_
diff --git a/reverb/cc/platform/thread_test.cc b/reverb/cc/platform/thread_test.cc
new file mode 100644
index 0000000..5a8b7b7
--- /dev/null
+++ b/reverb/cc/platform/thread_test.cc
@@ -0,0 +1,40 @@
+// Copyright 2019 DeepMind Technologies Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "reverb/cc/platform/thread.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/synchronization/notification.h"
+
+namespace deepmind {
+namespace reverb {
+namespace internal {
+namespace {
+
+// Verifies that the function passed to StartThread actually executes on the
+// spawned thread and that its side effects are visible after notification.
+TEST(ThreadStdTest, ThreadRuns) {
+  absl::Notification done;
+  int value = 0;
+  auto thread = StartThread("", [&done, &value] {
+    value = 7;
+    done.Notify();
+  });
+  // WaitForNotification orders the write to `value` before the read below.
+  done.WaitForNotification();
+  EXPECT_EQ(value, 7);
+}
+
+} // namespace
+} // namespace internal
+} // namespace reverb
+} // namespace deepmind
diff --git a/reverb/cc/priority_table.cc b/reverb/cc/priority_table.cc
new file mode 100644
index 0000000..87afc40
--- /dev/null
+++ b/reverb/cc/priority_table.cc
@@ -0,0 +1,377 @@
+// Copyright 2019 DeepMind Technologies Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "reverb/cc/priority_table.h"
+
+#include <algorithm>
+#include <cstddef>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "google/protobuf/timestamp.pb.h"
+#include <cstdint>
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/memory/memory.h"
+#include "absl/synchronization/mutex.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "absl/types/span.h"
+#include "reverb/cc/checkpointing/checkpoint.pb.h"
+#include "reverb/cc/chunk_store.h"
+#include "reverb/cc/distributions/interface.h"
+#include "reverb/cc/platform/logging.h"
+#include "reverb/cc/rate_limiter.h"
+#include "reverb/cc/schema.pb.h"
+#include "reverb/cc/table_extensions/interface.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace deepmind {
+namespace reverb {
+namespace {
+
+// Convenience alias for the extension list accepted by PriorityTable.
+using Extensions =
+    std::vector<std::shared_ptr<PriorityTableExtensionInterface>>;
+
+// True when `b` directly continues `a` within the same episode.
+inline bool IsAdjacent(const SequenceRange& a, const SequenceRange& b) {
+  const bool same_episode = a.episode_id() == b.episode_id();
+  return same_episode && b.start() == a.end() + 1;
+}
+
+// Orders items by their `inserted_at` timestamp (seconds, then nanos).
+inline bool IsInsertedBefore(const PrioritizedItem& a,
+                             const PrioritizedItem& b) {
+  const auto& ta = a.inserted_at();
+  const auto& tb = b.inserted_at();
+  if (ta.seconds() != tb.seconds()) return ta.seconds() < tb.seconds();
+  return ta.nanos() < tb.nanos();
+}
+
+// Splits `t` into whole Unix seconds plus remaining nanoseconds and writes
+// the pair into `proto` (google.protobuf.Timestamp representation).
+inline void EncodeAsTimestampProto(absl::Time t,
+                                   google::protobuf::Timestamp* proto) {
+  const int64_t seconds = absl::ToUnixSeconds(t);
+  proto->set_seconds(seconds);
+  proto->set_nanos((t - absl::FromUnixSeconds(seconds)) /
+                   absl::Nanoseconds(1));
+}
+
+} // namespace
+
+// NOTE(review): the template arguments below were reconstructed from the
+// included headers (distributions/interface.h, rate_limiter.h) — confirm the
+// exact interface names against priority_table.h.
+PriorityTable::PriorityTable(
+    std::string name, std::shared_ptr<KeyDistributionInterface> sampler,
+    std::shared_ptr<KeyDistributionInterface> remover, int64_t max_size,
+    int32_t max_times_sampled, std::shared_ptr<RateLimiter> rate_limiter,
+    Extensions extensions,
+    absl::optional<tensorflow::StructuredValue> signature)
+    : sampler_(std::move(sampler)),
+      remover_(std::move(remover)),
+      max_size_(max_size),
+      max_times_sampled_(max_times_sampled),
+      name_(std::move(name)),
+      rate_limiter_(std::move(rate_limiter)),
+      extensions_(std::move(extensions)),
+      signature_(std::move(signature)) {
+  // Registration failure is fatal: the table cannot operate without its
+  // rate limiter.
+  TF_CHECK_OK(rate_limiter_->RegisterPriorityTable(this));
+}
+
+PriorityTable::~PriorityTable() {
+ // Counterpart to RegisterPriorityTable in the constructor; `mu_` is handed
+ // to the limiter so it can synchronize against in-flight operations.
+ rate_limiter_->UnregisterPriorityTable(&mu_, this);
+}
+
+// Returns up to `count` items (all items when `count` == 0) under a reader
+// lock. Iteration order of the underlying map is unspecified.
+std::vector<PriorityTable::Item> PriorityTable::Copy(size_t count) const {
+  std::vector<Item> items;
+  absl::ReaderMutexLock lock(&mu_);
+  items.reserve(count == 0 ? data_.size() : count);
+  for (auto it = data_.cbegin();
+       it != data_.cend() && (count == 0 || items.size() < count); it++) {
+    items.push_back(it->second);
+  }
+  return items;
+}
+
+tensorflow::Status PriorityTable::InsertOrAssign(Item item) {
+ auto key = item.item.key();
+ auto priority = item.item.priority();
+
+ absl::WriterMutexLock lock(&mu_);
+
+ // If the item already exists in the table then only update its priority.
+ if (data_.contains(key)) {
+ return UpdateItem(key, priority, /*diffuse=*/true);
+ }
+
+ // Wait for the insert to be staged. While waiting the lock is released but
+ // once it returns the lock is acquired again. While waiting for the right to
+ // insert the operation might have transformed into an update.
+ TF_RETURN_IF_ERROR(rate_limiter_->AwaitCanInsert(&mu_));
+
+ // Re-check: another writer may have inserted `key` while `mu_` was released
+ // inside AwaitCanInsert.
+ if (data_.contains(key)) {
+ return UpdateItem(key, priority, /*diffuse=*/true);
+ }
+
+ // Set the insertion timestamp after the lock has been acquired as this
+ // represents the order it was inserted into the sampler and remover.
+ EncodeAsTimestampProto(absl::Now(), item.item.mutable_inserted_at());
+ data_[key] = std::move(item);
+
+ // TODO(b/154929932): If these fail, the rate limiter becomes out of sync.
+ TF_RETURN_IF_ERROR(sampler_->Insert(key, priority));
+ TF_RETURN_IF_ERROR(remover_->Insert(key, priority));
+
+ // Notify extensions of the newly inserted item.
+ auto it = data_.find(key);
+ for (auto& extension : extensions_) {
+ extension->OnInsert(it->second);
+ }
+
+ // Remove an item if we exceeded `max_size_`.
+ if (data_.size() > max_size_) {
+ DeleteItem(remover_->Sample().key);
+ }
+
+ // Now that the new item has been inserted and an older item has
+ // (potentially) been removed the insert can be finalized.
+ rate_limiter_->Insert(&mu_);
+
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status PriorityTable::MutateItems(
+ absl::Span