diff --git a/.bazelrc b/.bazelrc
new file mode 100644
index 0000000..587a0f9
--- /dev/null
+++ b/.bazelrc
@@ -0,0 +1,23 @@
+# gRPC using libcares in opensource has some issues.
+build --define=grpc_no_ares=true
+
+# Suppress all warning messages.
+build:short_logs --output_filter=DONT_MATCH_ANYTHING
+
+# Force python3
+build --action_env=PYTHON_BIN_PATH=python3
+build --repo_env=PYTHON_BIN_PATH=python3
+build --python_path=python3
+
+build:manylinux2010 --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:toolchain
+
+build -c opt
+build --cxxopt="-std=c++14"
+build --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0"
+build --auto_output_filter=subpackages
+build --copt="-Wall" --copt="-Wno-sign-compare"
+build --linkopt="-lrt -lm"
+
+# TF isn't built in dbg mode, so our dbg builds will segfault due to inconsistency
+# of defines when using tf's headers. In particular in refcount.h.
+build --cxxopt="-DNDEBUG"
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 0000000..db177d4
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,28 @@
+# How to Contribute
+
+We'd love to accept your patches and contributions to this project. There are
+just a few small guidelines you need to follow.
+
+## Contributor License Agreement
+
+Contributions to this project must be accompanied by a Contributor License
+Agreement. You (or your employer) retain the copyright to your contribution;
+this simply gives us permission to use and redistribute your contributions as
+part of the project. Head over to <https://cla.developers.google.com/> to see
+your current agreements on file or to sign a new one.
+
+You generally only need to submit a CLA once, so if you've already submitted one
+(even if it was for a different project), you probably don't need to do it
+again.
+
+## Code reviews
+
+All submissions, including submissions by project members, require review. We
+use GitHub pull requests for this purpose. Consult
+[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
+information on using pull requests.
+
+## Community Guidelines
+
+This project follows
+[Google's Open Source Community Guidelines](https://opensource.google.com/conduct/).
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..8018c87
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,202 @@
+
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+ + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2018, DeepMind Technologies Limited. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md new file mode 100644 index 0000000..679e11b --- /dev/null +++ b/README.md @@ -0,0 +1,5 @@ +# Reverb + +Reverb is a service for data transport (and storage) that is used for machine +learning research. One particularly common use is as a prioritized experience +replay system in reinforcement learning algorithms. diff --git a/WORKSPACE b/WORKSPACE new file mode 100644 index 0000000..ceb6840 --- /dev/null +++ b/WORKSPACE @@ -0,0 +1,86 @@ +workspace(name = "reverb") + +# To change to a version of protoc compatible with tensorflow: +# 1. Convert the required header version to a version string, e.g.: +# 3011004 => "3.11.4" +# 2. Calculate the sha256 of the binary: +# PROTOC_VERSION="3.11.4" +# curl -L "https://github.com/protocolbuffers/protobuf/releases/download/v${PROTOC_VERSION}/protoc-${PROTOC_VERSION}-linux-x86_64.zip" | sha256 +# 3. Update the two variables below. 
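The three numbered steps above come down to a fixed arithmetic decoding of protobuf's numeric header version plus a checksum of the matching release binary. As a hedged illustration of step 1 only, the sketch below decodes the example constant 3011004 into "3.11.4"; it assumes protobuf's usual major * 1000000 + minor * 1000 + patch encoding, and the constant and expected output are simply the example from the comment, not values this workspace requires.

// Standalone illustration; not part of the WORKSPACE file itself.
#include <cstdio>

int main() {
  const int header_version = 3011004;  // e.g. the PROTOBUF_VERSION required by TF.
  std::printf("%d.%d.%d\n", header_version / 1000000,
              (header_version / 1000) % 1000, header_version % 1000);
  return 0;  // Prints "3.11.4".
}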
+# +# Alternatively, run bazel with the environment var "REVERB_PROTOC_VERSION" +# set to override PROTOC_VERSION. +# +# *WARNING* If using the REVERB_PROTOC_VERSION environment variable, sha256 +# checking is disabled. Use at your own risk. +PROTOC_VERSION = "3.9.0" +PROTOC_SHA256 = "15e395b648a1a6dda8fd66868824a396e9d3e89bc2c8648e3b9ab9801bea5d55" + +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") + +http_archive( + name = "com_google_googletest", + sha256 = "ff7a82736e158c077e76188232eac77913a15dac0b22508c390ab3f88e6d6d86", + strip_prefix = "googletest-b6cd405286ed8635ece71c72f118e659f4ade3fb", + urls = [ + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/googletest/archive/b6cd405286ed8635ece71c72f118e659f4ade3fb.zip", + "https://github.com/google/googletest/archive/b6cd405286ed8635ece71c72f118e659f4ade3fb.zip", + ], +) + +http_archive( + name = "com_google_absl", + sha256 = "f368a8476f4e2e0eccf8a7318b98dafbe30b2600f4e3cf52636e5eb145aba06a", # SHARED_ABSL_SHA + strip_prefix = "abseil-cpp-df3ea785d8c30a9503321a3d35ee7d35808f190d", + urls = [ + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/abseil/abseil-cpp/archive/df3ea785d8c30a9503321a3d35ee7d35808f190d.tar.gz", + "https://github.com/abseil/abseil-cpp/archive/df3ea785d8c30a9503321a3d35ee7d35808f190d.tar.gz", + ], +) + +## Begin GRPC related deps +http_archive( + name = "com_github_grpc_grpc", + patch_cmds = [ + """sed -i.bak 's/"python",/"python3",/g' third_party/py/python_configure.bzl""", + """sed -i.bak 's/PYTHONHASHSEED=0/PYTHONHASHSEED=0 python3/g' bazel/cython_library.bzl""", + ], + sha256 = "b956598d8cbe168b5ee717b5dafa56563eb5201a947856a6688bbeac9cac4e1f", + strip_prefix = "grpc-b54a5b338637f92bfcf4b0bc05e0f57a5fd8fadd", + urls = [ + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/grpc/grpc/archive/b54a5b338637f92bfcf4b0bc05e0f57a5fd8fadd.tar.gz", + "https://github.com/grpc/grpc/archive/b54a5b338637f92bfcf4b0bc05e0f57a5fd8fadd.tar.gz", + ], +) + +load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps") + +grpc_deps() + +load("@upb//bazel:repository_defs.bzl", "bazel_version_repository") + +bazel_version_repository( + name = "bazel_version", +) + +load("@build_bazel_rules_apple//apple:repositories.bzl", "apple_rules_dependencies") + +apple_rules_dependencies() + +load("@build_bazel_apple_support//lib:repositories.bzl", "apple_support_dependencies") + +apple_support_dependencies() +## End GRPC related deps + +load( + "//reverb/cc/platform/default:repo.bzl", + "cc_tf_configure", + "reverb_protoc_deps", + "reverb_python_deps", +) + +cc_tf_configure() + +reverb_python_deps() + +reverb_protoc_deps(version = PROTOC_VERSION, sha256 = PROTOC_SHA256) diff --git a/docker/dev.dockerfile b/docker/dev.dockerfile new file mode 100644 index 0000000..7a90ef9 --- /dev/null +++ b/docker/dev.dockerfile @@ -0,0 +1,97 @@ +# Run the following commands in order: +# +# REVERB_DIR="/tmp/reverb" # (change to the cloned reverb directory, e.g. 
"$HOME/reverb") +# docker build --tag tensorflow:reverb - < "$REVERB_DIR/docker/dev.dockerfile" +# docker run --rm -it -v ${REVERB_DIR}:/tmp/reverb \ +# -v ${HOME}/.gitconfig:/home/${USER}/.gitconfig:ro \ +# --name reverb tensorflow:reverb bash +# +# Test that everything worked: +# +# bazel test -c opt --test_output=streamed //reverb:tf_client_test + +ARG cpu_base_image="ubuntu:18.04" +ARG base_image=$cpu_base_image +FROM $base_image + +LABEL maintainer="Reverb Team " + +# Re-declare args because the args declared before FROM can't be used in any +# instruction after a FROM. +ARG cpu_base_image="ubuntu:18.04" +ARG base_image=$cpu_base_image + +# Pick up some TF dependencies +RUN apt-get update && apt-get install -y --no-install-recommends software-properties-common +RUN apt-get update && apt-get install -y --no-install-recommends \ + aria2 \ + build-essential \ + curl \ + git \ + less \ + libfreetype6-dev \ + libhdf5-serial-dev \ + libpng-dev \ + libzmq3-dev \ + lsof \ + pkg-config \ + python3-distutils \ + python3-dev \ + python3.6-dev \ + rename \ + rsync \ + sox \ + unzip \ + vim \ + && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +RUN curl -O https://bootstrap.pypa.io/get-pip.py && python3 get-pip.py && rm get-pip.py + +ARG bazel_version=2.2.0 +# This is to install bazel, for development purposes. +ENV BAZEL_VERSION ${bazel_version} +RUN mkdir /bazel && \ + cd /bazel && \ + curl -H "User-Agent: Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/57.0.2987.133 Safari/537.36" -fSsL -O https://github.com/bazelbuild/bazel/releases/download/$BAZEL_VERSION/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \ + curl -H "User-Agent: Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/57.0.2987.133 Safari/537.36" -fSsL -o /bazel/LICENSE.txt https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE && \ + chmod +x bazel-*.sh && \ + ./bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \ + cd / && \ + rm -f /bazel/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh + +# Add support for Bazel autocomplete, see +# https://docs.bazel.build/versions/master/completion.html for instructions. +RUN cp /usr/local/lib/bazel/bin/bazel-complete.bash /etc/bash_completion.d +# TODO(b/154932078): This line should go into bashrc. +# NOTE(ebrevdo): RUN source doesn't work. Disabling the command below for now. +# RUN source /etc/bash_autcompletion.d/bazel-complete.bash + + +ARG pip_dependencies=' \ + contextlib2 \ + dm-tree \ + google-api-python-client \ + h5py \ + numpy \ + oauth2client \ + pandas \ + portpicker' + +RUN pip3 --no-cache-dir install $pip_dependencies + +# The latest tensorflow requires CUDA 10 compatible nvidia drivers (410.xx). +# If you are unable to update your drivers, an alternative is to compile +# tensorflow from source instead of installing from pip. +# Ensure we install the correct version by uninstalling first. +RUN pip3 uninstall -y tensorflow tensorflow-gpu tf-nightly tf-nightly-gpu + +RUN pip3 --no-cache-dir install tf-nightly --upgrade + +# bazel assumes the python executable is "python". +RUN ln -s /usr/bin/python3 /usr/bin/python + +WORKDIR "/tmp/reverb" + +CMD ["/bin/bash"] diff --git a/docker/release.dockerfile b/docker/release.dockerfile new file mode 100644 index 0000000..e50a88c --- /dev/null +++ b/docker/release.dockerfile @@ -0,0 +1,74 @@ +# Run the following commands in order: +# +# REVERB_DIR="/tmp/reverb" # (change to the cloned reverb directory, e.g. 
"$HOME/reverb") +# docker build --tag tensorflow:reverb_release - < "$REVERB_DIR/docker/release.dockerfile" +# docker run --rm -it -v ${REVERB_DIR}:/tmp/reverb \ +# -v ${HOME}/.gitconfig:/home/${USER}/.gitconfig:ro \ +# --name reverb_release tensorflow:reverb_release bash +# +# Test that everything worked: +# +# bazel test -c opt --copt=-mavx --config=manylinux2010 --test_output=errors //reverb/... + +ARG cpu_base_image="tensorflow/tensorflow:2.1.0-custom-op-ubuntu16" +ARG base_image=$cpu_base_image +FROM $base_image + +LABEL maintainer="Reverb Team " + +# Re-declare args because the args declared before FROM can't be used in any +# instruction after a FROM. +ARG cpu_base_image="tensorflow/tensorflow:2.1.0-custom-op-ubuntu16" +ARG base_image=$cpu_base_image + +# Pick up some TF dependencies +RUN apt-get update && apt-get install -y --no-install-recommends software-properties-common +RUN apt-get update && apt-get install -y --no-install-recommends \ + aria2 \ + build-essential \ + curl \ + git \ + less \ + libfreetype6-dev \ + libhdf5-serial-dev \ + libpng-dev \ + libzmq3-dev \ + lsof \ + pkg-config \ + python3-dev \ + python3.6-dev \ + python3.7-dev \ + python3.8-dev \ + rename \ + rsync \ + sox \ + unzip \ + vim \ + && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +RUN curl -O https://bootstrap.pypa.io/get-pip.py && python3 get-pip.py && rm get-pip.py + +ARG pip_dependencies=' \ + absl-py \ + contextlib2 \ + dm-tree \ + google-api-python-client \ + h5py \ + numpy \ + oauth2client \ + pandas \ + portpicker' + +# TODO(b/154930404): Update to 2.2.0 once it's out. May need to +# cut a branch to make changes that allow us to build against 2.2.0 instead +# of tf-nightly due to API changes. +RUN pip3 uninstall -y tensorflow tensorflow-gpu tf-nightly tf-nightly-gpu +RUN pip3 --no-cache-dir install tf-nightly --upgrade + +RUN pip3 --no-cache-dir install $pip_dependencies + +WORKDIR "/tmp/reverb" + +CMD ["/bin/bash"] diff --git a/requirements-no-deps.txt b/requirements-no-deps.txt new file mode 100644 index 0000000..c24d6dc --- /dev/null +++ b/requirements-no-deps.txt @@ -0,0 +1,5 @@ +# These packages are required by TensorFlow, but they need to be installed +# without their dependencies. +# To install: python -m pip install --no-deps -r requirements.txt. +keras_applications +keras_preprocessing diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..fc3b214 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,7 @@ +dm-tree>=0.1.1 +numpy +wheel +setuptools +mock +future>=0.17.1 +portpicker diff --git a/reverb/BUILD b/reverb/BUILD new file mode 100644 index 0000000..33953d5 --- /dev/null +++ b/reverb/BUILD @@ -0,0 +1,216 @@ +# Description: Reverb is an efficient and easy to use prioritized replay system designed for ML research. 
+ +load( + "//reverb/cc/platform:build_rules.bzl", + "reverb_absl_deps", + "reverb_py_test", + "reverb_pybind_deps", + "reverb_pybind_extension", + "reverb_pytype_library", + "reverb_pytype_strict_library", +) + +package(default_visibility = [":__subpackages__"]) + +licenses(["notice"]) + +exports_files(["LICENSE"]) + +reverb_pytype_strict_library( + name = "reverb", + srcs = ["__init__.py"], + srcs_version = "PY3", + deps = [ + ":client", + ":distributions", + ":errors", + ":rate_limiters", + ":replay_sample", + ":server", + ":tf_client", + ], +) + +reverb_pytype_strict_library( + name = "distributions", + srcs = ["distributions.py"], + srcs_version = "PY3", + deps = [":pybind"], +) + +reverb_pytype_strict_library( + name = "rate_limiters", + srcs = ["rate_limiters.py"], + srcs_version = "PY3", + deps = [":pybind"], +) + +reverb_pytype_library( + name = "client", + srcs = ["client.py"], + srcs_version = "PY3", + strict_deps = True, + visibility = ["//visibility:public"], + deps = [ + ":errors", + ":pybind", + ":replay_sample", + ":reverb_types", + "//reverb/cc:schema_py_pb2", + ], +) + +reverb_pytype_library( + name = "errors", + srcs = ["errors.py"], + srcs_version = "PY3", + strict_deps = True, + visibility = ["//visibility:public"], + deps = [], +) + +reverb_pytype_library( + name = "server", + srcs = ["server.py"], + srcs_version = "PY3", + strict_deps = True, + visibility = ["//visibility:public"], + deps = [ + ":checkpointer", + ":client", + ":distributions", + ":pybind", + ":rate_limiters", + ":reverb_types", + ], +) + +reverb_pytype_library( + name = "replay_sample", + srcs = ["replay_sample.py"], + srcs_version = "PY3", + strict_deps = True, +) + +reverb_pybind_extension( + name = "pybind", + srcs = ["pybind.cc"], + module_name = "libpybind", + srcs_version = "PY3ONLY", + visibility = ["//reverb:__subpackages__"], + deps = [ + "//reverb/cc:priority_table", + "//reverb/cc:replay_client", + "//reverb/cc:replay_writer", + "//reverb/cc:replay_sampler", + "//reverb/cc:reverb_server", + "//reverb/cc/checkpointing:interface", + "//reverb/cc/distributions:fifo", + "//reverb/cc/distributions:heap", + "//reverb/cc/distributions:interface", + "//reverb/cc/distributions:lifo", + "//reverb/cc/distributions:prioritized", + "//reverb/cc/distributions:uniform", + "//reverb/cc/platform:checkpointing", + "//reverb/cc/table_extensions:interface", + ] + reverb_pybind_deps() + reverb_absl_deps(), +) + +reverb_pytype_library( + name = "tf_client", + srcs = ["tf_client.py"], + srcs_version = "PY3", + strict_deps = True, + visibility = ["//visibility:public"], + deps = [ + ":replay_sample", + "//reverb/cc/ops:gen_client_ops", + "//reverb/cc/ops:gen_dataset_op", + ], +) + +reverb_pytype_library( + name = "checkpointer", + srcs = ["checkpointer.py"], + srcs_version = "PY3", + strict_deps = True, + visibility = ["//visibility:public"], + deps = [":pybind"], +) + +reverb_pytype_library( + name = "reverb_types", + srcs = ["reverb_types.py"], + srcs_version = "PY3", + strict_deps = True, + visibility = ["//visibility:public"], + deps = [ + ":pybind", + "//reverb/cc:schema_py_pb2", + ], +) + +reverb_py_test( + name = "checkpointer_test", + srcs = ["checkpointer_test.py"], + python_version = "PY3", + deps = [ + ":checkpointer", + ":pybind", + ], +) + +reverb_py_test( + name = "client_test", + srcs = ["client_test.py"], + python_version = "PY3", + deps = [ + ":client", + ":distributions", + ":errors", + ":rate_limiters", + ":server", + ], +) + +reverb_py_test( + name = "server_test", + srcs = ["server_test.py"], 
+ python_version = "PY3", + deps = [ + ":client", + ":distributions", + ":rate_limiters", + ":server", + ], +) + +reverb_py_test( + name = "tf_client_test", + timeout = "short", + srcs = ["tf_client_test.py"], + python_version = "PY3", + shard_count = 6, + deps = [ + ":client", + ":distributions", + ":rate_limiters", + ":replay_sample", + ":server", + ":tf_client", + ], +) + +reverb_py_test( + name = "rate_limiters_test", + srcs = ["rate_limiters_test.py"], + python_version = "PY3", + deps = [":rate_limiters"], +) + +reverb_py_test( + name = "pybind_test", + srcs = ["pybind_test.py"], + python_version = "PY3", + deps = [":reverb"], +) diff --git a/reverb/__init__.py b/reverb/__init__.py new file mode 100644 index 0000000..a90da56 --- /dev/null +++ b/reverb/__init__.py @@ -0,0 +1,33 @@ +# Copyright 2019 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Reverb.""" + +from reverb import distributions +from reverb import rate_limiters + +from reverb.client import Client +from reverb.client import Writer + +from reverb.errors import ReverbError +from reverb.errors import TimeoutError + +from reverb.replay_sample import ReplaySample +from reverb.replay_sample import SampleInfo + +from reverb.server import PriorityTable +from reverb.server import Server + +from reverb.tf_client import ReplayDataset +from reverb.tf_client import TFClient diff --git a/reverb/cc/BUILD b/reverb/cc/BUILD new file mode 100644 index 0000000..7c8d884 --- /dev/null +++ b/reverb/cc/BUILD @@ -0,0 +1,290 @@ +# Description: Reverb is an efficient and easy to use prioritized replay system designed for ML research. 
+ +load( + "//reverb/cc/platform:build_rules.bzl", + "reverb_absl_deps", + "reverb_cc_grpc_library", + "reverb_cc_library", + "reverb_cc_proto_library", + "reverb_cc_test", + "reverb_grpc_deps", + "reverb_py_proto_library", + "reverb_tf_deps", +) + +package(default_visibility = ["//reverb:__subpackages__"]) + +licenses(["notice"]) + +exports_files(["LICENSE"]) + +reverb_cc_test( + name = "chunk_store_test", + srcs = ["chunk_store_test.cc"], + deps = [ + ":chunk_store", + ":schema_cc_proto", + "//reverb/cc/platform:thread", + "//reverb/cc/testing:proto_test_util", + ] + reverb_tf_deps(), +) + +reverb_cc_test( + name = "rate_limiter_test", + srcs = ["rate_limiter_test.cc"], + deps = [ + ":priority_table", + "//reverb/cc/distributions:uniform", + "//reverb/cc/platform:thread", + "//reverb/cc/testing:proto_test_util", + ] + reverb_absl_deps(), +) + +reverb_cc_test( + name = "priority_table_test", + srcs = ["priority_table_test.cc"], + deps = [ + ":chunk_store", + ":priority_table", + ":schema_cc_proto", + "//reverb/cc/checkpointing:checkpoint_cc_proto", + "//reverb/cc/distributions:fifo", + "//reverb/cc/distributions:uniform", + "//reverb/cc/platform:thread", + "//reverb/cc/testing:proto_test_util", + ] + reverb_tf_deps() + reverb_absl_deps(), +) + +reverb_cc_test( + name = "tensor_compression_test", + srcs = ["tensor_compression_test.cc"], + deps = [ + ":tensor_compression", + "//reverb/cc/testing:tensor_testutil", + ] + reverb_tf_deps(), +) + +reverb_cc_test( + name = "replay_sampler_test", + srcs = ["replay_sampler_test.cc"], + deps = [ + ":replay_sampler", + ":replay_service_cc_grpc_proto", + ":replay_service_cc_proto", + ":tensor_compression", + "//reverb/cc/platform:logging", + "//reverb/cc/testing:tensor_testutil", + "//reverb/cc/testing:time_testutil", + ] + reverb_tf_deps() + reverb_grpc_deps() + reverb_absl_deps(), +) + +reverb_cc_test( + name = "replay_writer_test", + srcs = ["replay_writer_test.cc"], + deps = [ + ":replay_client", + ":replay_service_cc_grpc_proto", + ":replay_writer", + "//reverb/cc/support:grpc_util", + "//reverb/cc/support:uint128", + "//reverb/cc/testing:proto_test_util", + ] + reverb_tf_deps() + reverb_grpc_deps(), +) + +reverb_cc_test( + name = "replay_client_test", + srcs = ["replay_client_test.cc"], + deps = [ + ":replay_client", + ":replay_service_cc_grpc_proto", + ":replay_service_cc_proto", + "//reverb/cc/support:uint128", + "//reverb/cc/testing:proto_test_util", + ] + reverb_tf_deps() + reverb_grpc_deps(), +) + +reverb_cc_test( + name = "reverb_server_test", + srcs = ["reverb_server_test.cc"], + deps = [ + ":reverb_server", + "//reverb/cc/platform:net", + ] + reverb_tf_deps() + reverb_grpc_deps(), +) + +reverb_cc_test( + name = "replay_service_impl_test", + srcs = ["replay_service_impl_test.cc"], + deps = [ + ":replay_service_cc_proto", + ":replay_service_impl", + ":schema_cc_proto", + "//reverb/cc/distributions:fifo", + "//reverb/cc/distributions:uniform", + "//reverb/cc/platform:checkpointing", + "//reverb/cc/platform:thread", + "//reverb/cc/testing:proto_test_util", + ] + reverb_tf_deps() + reverb_absl_deps(), +) + +reverb_cc_library( + name = "chunk_store", + srcs = ["chunk_store.cc"], + hdrs = ["chunk_store.h"], + deps = [ + ":schema_cc_proto", + "//reverb/cc/platform:thread", + "//reverb/cc/support:queue", + ] + reverb_tf_deps() + reverb_absl_deps(), +) + +reverb_cc_library( + name = "priority_table", + srcs = [ + "priority_table.cc", + "rate_limiter.cc", + ], + hdrs = [ + "priority_table.h", + "rate_limiter.h", + ], + visibility = 
["//reverb:__subpackages__"], + deps = [ + ":chunk_store", + ":priority_table_item", + ":schema_cc_proto", + "//reverb/cc/checkpointing:checkpoint_cc_proto", + "//reverb/cc/distributions:interface", + "//reverb/cc/platform:logging", + "//reverb/cc/table_extensions:interface", + ] + reverb_tf_deps() + reverb_absl_deps(), +) + +reverb_cc_library( + name = "tensor_compression", + srcs = ["tensor_compression.cc"], + hdrs = ["tensor_compression.h"], + visibility = ["//reverb:__subpackages__"], + deps = [ + "//reverb/cc/platform:logging", + "//reverb/cc/platform:snappy", + ] + reverb_tf_deps(), +) + +reverb_cc_library( + name = "replay_sampler", + srcs = ["replay_sampler.cc"], + hdrs = ["replay_sampler.h"], + visibility = ["//reverb:__subpackages__"], + deps = [ + ":replay_service_cc_grpc_proto", + ":replay_service_cc_proto", + ":tensor_compression", + "//reverb/cc/platform:logging", + "//reverb/cc/platform:thread", + "//reverb/cc/support:grpc_util", + "//reverb/cc/support:queue", + ] + reverb_tf_deps() + reverb_grpc_deps() + reverb_absl_deps(), +) + +reverb_cc_library( + name = "replay_writer", + srcs = ["replay_writer.cc"], + hdrs = ["replay_writer.h"], + visibility = ["//reverb:__subpackages__"], + deps = [ + ":replay_service_cc_grpc_proto", + ":replay_service_cc_proto", + ":schema_cc_proto", + ":tensor_compression", + "//reverb/cc/platform:logging", + "//reverb/cc/support:grpc_util", + "//reverb/cc/support:signature", + ] + reverb_tf_deps() + reverb_grpc_deps() + reverb_absl_deps(), +) + +reverb_cc_library( + name = "replay_client", + srcs = ["replay_client.cc"], + hdrs = ["replay_client.h"], + visibility = ["//reverb:__subpackages__"], + deps = [ + ":replay_sampler", + ":replay_service_cc_grpc_proto", + ":replay_service_cc_proto", + ":replay_writer", + ":schema_cc_proto", + "//reverb/cc/platform:grpc_utils", + "//reverb/cc/platform:logging", + "//reverb/cc/support:grpc_util", + "//reverb/cc/support:signature", + "//reverb/cc/support:uint128", + ] + reverb_tf_deps() + reverb_grpc_deps() + reverb_absl_deps(), +) + +reverb_cc_library( + name = "replay_service_impl", + srcs = ["replay_service_impl.cc"], + hdrs = ["replay_service_impl.h"], + deps = [ + ":chunk_store", + ":priority_table", + ":replay_service_cc_grpc_proto", + ":replay_service_cc_proto", + ":schema_cc_proto", + "//reverb/cc/checkpointing:interface", + "//reverb/cc/platform:logging", + "//reverb/cc/support:grpc_util", + "//reverb/cc/support:uint128", + ] + reverb_tf_deps() + reverb_grpc_deps() + reverb_absl_deps(), +) + +reverb_cc_library( + name = "priority_table_item", + hdrs = ["priority_table_item.h"], + deps = [ + ":chunk_store", + ":schema_cc_proto", + ], +) + +reverb_cc_library( + name = "reverb_server", + srcs = ["reverb_server.cc"], + hdrs = ["reverb_server.h"], + visibility = ["//reverb:__subpackages__"], + deps = [ + ":priority_table", + ":replay_client", + ":replay_service_impl", + "//reverb/cc/checkpointing:interface", + "//reverb/cc/platform:grpc_utils", + "//reverb/cc/platform:logging", + ] + reverb_grpc_deps() + reverb_absl_deps(), +) + +reverb_cc_proto_library( + name = "schema_cc_proto", + srcs = ["schema.proto"], +) + +reverb_py_proto_library( + name = "schema_py_pb2", + srcs = ["schema.proto"], + deps = [":schema_cc_proto"], +) + +reverb_cc_proto_library( + name = "replay_service_cc_proto", + srcs = ["replay_service.proto"], + visibility = ["//reverb:__subpackages__"], + deps = [":schema_cc_proto"], +) + +reverb_cc_grpc_library( + name = "replay_service_cc_grpc_proto", + srcs = ["replay_service.proto"], + 
generate_mocks = True, + visibility = ["//reverb:__subpackages__"], + deps = [":replay_service_cc_proto"], +) diff --git a/reverb/cc/checkpointing/BUILD b/reverb/cc/checkpointing/BUILD new file mode 100644 index 0000000..80a5013 --- /dev/null +++ b/reverb/cc/checkpointing/BUILD @@ -0,0 +1,23 @@ +load( + "//reverb/cc/platform:build_rules.bzl", + "reverb_cc_library", + "reverb_cc_proto_library", +) + +package(default_visibility = ["//reverb:__subpackages__"]) + +licenses(["notice"]) + +reverb_cc_proto_library( + name = "checkpoint_cc_proto", + srcs = ["checkpoint.proto"], + deps = [ + "//reverb/cc:schema_cc_proto", + ], +) + +reverb_cc_library( + name = "interface", + hdrs = ["interface.h"], + deps = ["//reverb/cc:priority_table"], +) diff --git a/reverb/cc/checkpointing/checkpoint.proto b/reverb/cc/checkpointing/checkpoint.proto new file mode 100644 index 0000000..34dbcbc --- /dev/null +++ b/reverb/cc/checkpointing/checkpoint.proto @@ -0,0 +1,61 @@ +syntax = "proto3"; + +package deepmind.reverb; + +import "reverb/cc/schema.proto"; + +// Configs for reconstructing a distribution to its initial state. + +message PriorityTableCheckpoint { + // Name of the priority table. + string table_name = 1; + + // Maximum number of items in the priority table. + // If an insert would result in this value getting exceeded, `remover` is used + // to select an item to remove before proceeding with the insert. + int64 max_size = 6; + + // The maximum number of times an item can be sampled before being removed. + int32 max_times_sampled = 7; + + // Items in the priority table ordered by `inserted_at` (asc). + // When loading a checkpoint the items should be added in the same order so + // position based distributions (e.g fifo) are reconstructed correctly. + repeated PrioritizedItem items = 2; + + // Checkpoint of the associated rate limiter. + RateLimiterCheckpoint rate_limiter = 3; + + // Options for constructing new samplers and removers of the correct type. + // Note that this does not include the state that they currently hold as it + // will be reproduced using the order of `items. + KeyDistributionOptions sampler = 4; + KeyDistributionOptions remover = 5; +} + +message RateLimiterCheckpoint { + reserved 1; // Deprecated field `name`. + + // The average number of times each item should be sampled during its + // lifetime. + double samples_per_insert = 2; + + // The minimum and maximum values the cursor is allowed to reach. The cursor + // value is calculated as `insert_count * samples_per_insert - + // sample_count`. If the value would go beyond these limits then the call is + // blocked until it can proceed without violating the constraints. + double min_diff = 3; + double max_diff = 4; + + // The minimum number of inserts required before any sample operation. + int64 min_size_to_sample = 5; + + // The total number of samples that occurred before the checkpoint. + int64 sample_count = 6; + + // The total number of inserts that occurred before the checkpoint. + int64 insert_count = 7; + + // The total number of deletes that occured before the checkpoint. + int64 delete_count = 8; +} diff --git a/reverb/cc/checkpointing/interface.h b/reverb/cc/checkpointing/interface.h new file mode 100644 index 0000000..c94f6e3 --- /dev/null +++ b/reverb/cc/checkpointing/interface.h @@ -0,0 +1,55 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef REVERB_CC_CHECKPOINTING_INTERFACE_H_
+#define REVERB_CC_CHECKPOINTING_INTERFACE_H_
+
+#include "reverb/cc/priority_table.h"
+
+namespace deepmind {
+namespace reverb {
+
+// A checkpointer is able to encode the configuration, data and state as a
+// proto. This proto is stored in a permanent storage system where it can be
+// retrieved at a later point to restore a copy of the checkpointed tables.
+class CheckpointerInterface {
+ public:
+  virtual ~CheckpointerInterface() = default;
+
+  // Save a new checkpoint for every table in `tables` to permanent storage. If
+  // successful, `path` will contain an ABSOLUTE path that could be used to
+  // restore the checkpoint.
+  virtual tensorflow::Status Save(std::vector<PriorityTable*> tables,
+                                  int keep_latest, std::string* path) = 0;
+
+  // Attempts to load a checkpoint from the active workspace.
+  //
+  // Tables loaded from the checkpoint must already exist in `tables`. When
+  // constructing the newly loaded table, the extensions are passed from the
+  // old table and the item in `tables` is replaced with the newly loaded
+  // table.
+  virtual tensorflow::Status Load(
+      absl::string_view relative_path, ChunkStore* chunk_store,
+      std::vector<std::shared_ptr<PriorityTable>>* tables) = 0;
+
+  // Finds the most recent checkpoint within the active workspace. See `Load`
+  // for more details.
+  virtual tensorflow::Status LoadLatest(
+      ChunkStore* chunk_store,
+      std::vector<std::shared_ptr<PriorityTable>>* tables) = 0;
+};
+
+}  // namespace reverb
+}  // namespace deepmind
+
+#endif  // REVERB_CC_CHECKPOINTING_INTERFACE_H_
diff --git a/reverb/cc/chunk_store.cc b/reverb/cc/chunk_store.cc
new file mode 100644
index 0000000..d86bcf0
--- /dev/null
+++ b/reverb/cc/chunk_store.cc
@@ -0,0 +1,101 @@
+// Copyright 2019 DeepMind Technologies Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
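The `CheckpointerInterface` declared above is easiest to read from the call site. The sketch below shows how an owner of a set of tables might drive `Save` and then `LoadLatest`; only the virtual method signatures come from the header, while the free function, variable names and the `keep_latest` value are illustrative assumptions.

// Illustrative sketch only; not part of this change.
#include <memory>
#include <string>
#include <vector>

#include "reverb/cc/checkpointing/interface.h"
#include "reverb/cc/chunk_store.h"
#include "reverb/cc/priority_table.h"
#include "tensorflow/core/lib/core/errors.h"

namespace deepmind {
namespace reverb {

// Saves every table, then restores the most recent checkpoint back into
// `tables` (which, per the interface contract, must already contain an entry
// for every checkpointed table).
tensorflow::Status SaveAndRestore(
    CheckpointerInterface* checkpointer, ChunkStore* chunk_store,
    std::vector<std::shared_ptr<PriorityTable>>* tables) {
  std::vector<PriorityTable*> raw_tables;
  for (const auto& table : *tables) raw_tables.push_back(table.get());

  std::string path;  // Receives the ABSOLUTE path of the new checkpoint.
  TF_RETURN_IF_ERROR(
      checkpointer->Save(raw_tables, /*keep_latest=*/3, &path));

  // Replace the in-memory tables with the state that was just written.
  return checkpointer->LoadLatest(chunk_store, tables);
}

}  // namespace reverb
}  // namespace deepmind

The raw-pointer vector mirrors the `Save` signature, which takes non-owning pointers, while `Load` and `LoadLatest` operate in place on the owning `shared_ptr` vector.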
+
+#include "reverb/cc/chunk_store.h"
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/strings/str_cat.h"
+#include "absl/synchronization/mutex.h"
+#include "absl/time/time.h"
+#include "absl/types/span.h"
+#include "reverb/cc/platform/thread.h"
+#include "reverb/cc/schema.pb.h"
+#include "reverb/cc/support/queue.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace deepmind {
+namespace reverb {
+
+ChunkStore::ChunkStore(int cleanup_batch_size)
+    : delete_keys_(std::make_shared<internal::Queue<Key>>(10000000)),
+      cleaner_(internal::StartThread(
+          "ChunkStore-Cleaner", [this, cleanup_batch_size] {
+            while (CleanupInternal(cleanup_batch_size)) {
+            }
+          })) {}
+
+ChunkStore::~ChunkStore() {
+  // Closing the queue makes all calls to `CleanupInternal` return false,
+  // which breaks the loop in `cleaner_`, making it joinable.
+  delete_keys_->Close();
+  cleaner_ = nullptr;  // Joins thread.
+}
+
+std::shared_ptr<ChunkStore::Chunk> ChunkStore::Insert(ChunkData item) {
+  absl::WriterMutexLock lock(&mu_);
+  std::weak_ptr<Chunk>& wp = data_[item.chunk_key()];
+  std::shared_ptr<Chunk> sp = wp.lock();
+  if (sp == nullptr) {
+    wp = (sp = std::shared_ptr<Chunk>(new Chunk(std::move(item)),
+                                      [q = delete_keys_](Chunk* chunk) {
+                                        q->Push(chunk->data().chunk_key());
+                                        delete chunk;
+                                      }));
+  }
+  return sp;
+}
+
+tensorflow::Status ChunkStore::Get(
+    absl::Span<const Key> keys,
+    std::vector<std::shared_ptr<Chunk>>* chunks) {
+  absl::ReaderMutexLock lock(&mu_);
+  chunks->clear();
+  chunks->reserve(keys.size());
+  for (int i = 0; i < keys.size(); i++) {
+    chunks->push_back(GetItem(keys[i]));
+    if (!chunks->at(i)) {
+      return tensorflow::errors::NotFound(
+          absl::StrCat("Chunk ", keys[i], " cannot be found."));
+    }
+  }
+  return tensorflow::Status::OK();
+}
+
+std::shared_ptr<ChunkStore::Chunk> ChunkStore::GetItem(Key key) {
+  auto it = data_.find(key);
+  return it == data_.end() ? nullptr : it->second.lock();
+}
+
+bool ChunkStore::CleanupInternal(int num_chunks) {
+  std::vector<Key> popped_keys(num_chunks);
+  for (int i = 0; i < num_chunks; i++) {
+    if (!delete_keys_->Pop(&popped_keys[i])) return false;
+  }
+
+  absl::WriterMutexLock data_lock(&mu_);
+  for (const Key& key : popped_keys) {
+    data_.erase(key);
+  }
+
+  return true;
+}
+
+}  // namespace reverb
+}  // namespace deepmind
diff --git a/reverb/cc/chunk_store.h b/reverb/cc/chunk_store.h
new file mode 100644
index 0000000..4b37aca
--- /dev/null
+++ b/reverb/cc/chunk_store.h
@@ -0,0 +1,127 @@
+// Copyright 2019 DeepMind Technologies Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
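`ChunkStore::Insert` above is the heart of the class: a map of `weak_ptr`s combined with a custom `shared_ptr` deleter means a chunk's key is queued for cleanup at the exact moment the last caller drops its reference, without the map itself keeping chunks alive. The self-contained sketch below shows the same pattern with generic stand-in types; every name in it is illustrative rather than Reverb's.

// Minimal stand-in for the weak_ptr-cache + custom-deleter pattern; assumes
// nothing from Reverb beyond the idea itself.
#include <cstdint>
#include <iostream>
#include <memory>
#include <queue>
#include <unordered_map>

class Cache {
 public:
  std::shared_ptr<int> Insert(uint64_t key, int value) {
    std::weak_ptr<int>& wp = data_[key];
    std::shared_ptr<int> sp = wp.lock();
    if (sp == nullptr) {
      // The deleter runs when the last caller-held shared_ptr is released. It
      // only touches the heap-allocated queue (kept alive by the captured
      // shared_ptr), so it stays valid even if the Cache is gone by then.
      wp = (sp = std::shared_ptr<int>(new int(value),
                                      [q = delete_keys_, key](int* p) {
                                        q->push(key);
                                        delete p;
                                      }));
    }
    return sp;
  }

  std::shared_ptr<int> Get(uint64_t key) {
    auto it = data_.find(key);
    return it == data_.end() ? nullptr : it->second.lock();
  }

 private:
  std::unordered_map<uint64_t, std::weak_ptr<int>> data_;
  std::shared_ptr<std::queue<uint64_t>> delete_keys_ =
      std::make_shared<std::queue<uint64_t>>();
};

int main() {
  Cache cache;
  std::shared_ptr<int> chunk = cache.Insert(1, 42);
  chunk = nullptr;  // Last reference dies; key 1 is queued for cleanup.
  std::cout << (cache.Get(1) == nullptr) << "\n";  // Prints 1: entry expired.
}

The real ChunkStore goes one step further: a background cleaner thread drains the queue in batches and erases the stale map entries, which is what `CleanupInternal` above implements.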
+
+#ifndef REVERB_CC_CHUNK_STORE_H_
+#define REVERB_CC_CHUNK_STORE_H_
+
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "absl/base/thread_annotations.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/synchronization/mutex.h"
+#include "absl/types/span.h"
+#include "reverb/cc/platform/thread.h"
+#include "reverb/cc/schema.pb.h"
+#include "reverb/cc/support/queue.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace deepmind {
+namespace reverb {
+
+// Maintains a bijection from chunk keys to Chunks. For inserting, the caller
+// passes ChunkData which contains a chunk key and the actual data. We use the
+// key for the mapping and wrap the ChunkData in a thin class which provides a
+// read-only accessor to the ChunkData.
+//
+// For ChunkData that populates sequence_range, an interval tree is maintained
+// to allow lookups of intersecting ranges.
+//
+// Each Chunk is reference counted individually. When its reference count drops
+// to zero, the Chunk is destroyed and subsequent calls to Get() will no longer
+// return that Chunk. Please note that this container only holds a weak pointer
+// to a Chunk, and thus does not count towards the reference count. For this
+// reason, Insert() returns a shared pointer, as otherwise the Chunk would be
+// destroyed right away.
+//
+// All public methods are thread safe.
+class ChunkStore {
+ public:
+  using Key = uint64_t;
+
+  class Chunk {
+   public:
+    explicit Chunk(ChunkData data) : data_(std::move(data)) {}
+
+    // Returns the proto data of the chunk.
+    const ChunkData& data() const { return data_; }
+
+    // (Potentially cached) size of `data`.
+    size_t DataByteSizeLong() const {
+      if (data_byte_size_ == 0) {
+        data_byte_size_ = data_.ByteSizeLong();
+      }
+      return data_byte_size_;
+    }
+
+   private:
+    ChunkData data_;
+    mutable size_t data_byte_size_ = 0;
+  };
+
+  // Starts `cleaner_`. `cleanup_batch_size` is the number of keys the cleaner
+  // should wait for before acquiring the lock and erasing them from `data_`.
+  explicit ChunkStore(int cleanup_batch_size = 1000);
+
+  // Stops `cleaner_` and closes `delete_keys_`.
+  ~ChunkStore();
+
+  // Attempts to insert a Chunk into the map using the key inside `item`. If no
+  // entry existed for the key, a new Chunk is created, inserted and returned.
+  // Otherwise, the existing chunk is returned.
+  std::shared_ptr<Chunk> Insert(ChunkData item) ABSL_LOCKS_EXCLUDED(mu_);
+
+  // Gets the Chunk for each given key. Returns an error if one of the items
+  // does not exist or if `Close` has been called. On success, the returned
+  // items are in the same order as given in `keys`.
+  tensorflow::Status Get(absl::Span<const Key> keys,
+                         std::vector<std::shared_ptr<Chunk>>* chunks)
+      ABSL_LOCKS_EXCLUDED(mu_);
+
+  // Blocks until `num_chunks` expired entries have been cleaned up from
+  // `data_`. This method is called automatically by a background thread to
+  // limit memory size, but does not have any effect on the semantics of Get()
+  // or Insert() calls.
+  //
+  // Returns false if `delete_keys_` closed before `num_chunks` could be popped.
+  bool CleanupInternal(int num_chunks) ABSL_LOCKS_EXCLUDED(mu_);
+
+ private:
+  // Gets an item. Returns nullptr if the item does not exist.
+  std::shared_ptr<Chunk> GetItem(Key key) ABSL_SHARED_LOCKS_REQUIRED(mu_);
+
+  // Holds the actual mapping of key to Chunk. We only hold a weak pointer to
+  // the Chunk, which means that destruction and reference counting of the
+  // chunks happens independently of this map.
+  absl::flat_hash_map<Key, std::weak_ptr<Chunk>> data_ ABSL_GUARDED_BY(mu_);
+
+  // Mutex protecting access to `data_`.
+  mutable absl::Mutex mu_;
+
+  // Queue of keys of deleted items that will be cleaned up by `cleaner_`. Note
+  // that the queue has to be allocated on the heap in order to avoid
+  // dereferencing errors caused by a stack-allocated ChunkStore getting
+  // destroyed before all Chunks have been destroyed.
+  std::shared_ptr<internal::Queue<Key>> delete_keys_;
+
+  // Consumes `delete_keys_` to remove dead pointers in `data_`.
+  std::unique_ptr<internal::Thread> cleaner_;
+};
+
+}  // namespace reverb
+}  // namespace deepmind
+
+#endif  // REVERB_CC_CHUNK_STORE_H_
diff --git a/reverb/cc/chunk_store_test.cc b/reverb/cc/chunk_store_test.cc
new file mode 100644
index 0000000..f412458
--- /dev/null
+++ b/reverb/cc/chunk_store_test.cc
@@ -0,0 +1,128 @@
+// Copyright 2019 DeepMind Technologies Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "reverb/cc/chunk_store.h"
+
+#include <atomic>
+#include <memory>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "reverb/cc/platform/thread.h"
+#include "reverb/cc/schema.pb.h"
+#include "reverb/cc/testing/proto_test_util.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace deepmind {
+namespace reverb {
+namespace {
+
+using ChunkVector = ::std::vector<::std::shared_ptr<ChunkStore::Chunk>>;
+
+TEST(ChunkStoreTest, GetAfterInsertSucceeds) {
+  ChunkStore store;
+  std::shared_ptr<ChunkStore::Chunk> inserted =
+      store.Insert(testing::MakeChunkData(2));
+  ChunkVector chunks;
+  TF_ASSERT_OK(store.Get({2}, &chunks));
+  EXPECT_EQ(inserted, chunks[0]);
+}
+
+TEST(ChunkStoreTest, GetFailsWhenKeyDoesNotExist) {
+  ChunkStore store;
+  ChunkVector chunks;
+  EXPECT_EQ(store.Get({2}, &chunks).code(), tensorflow::error::NOT_FOUND);
+}
+
+TEST(ChunkStoreTest, GetFailsAfterChunkIsDestroyed) {
+  ChunkStore store;
+  std::shared_ptr<ChunkStore::Chunk> inserted =
+      store.Insert(testing::MakeChunkData(1));
+  inserted = nullptr;
+  ChunkVector chunks;
+  EXPECT_EQ(store.Get({1}, &chunks).code(), tensorflow::error::NOT_FOUND);
+}
+
+TEST(ChunkStoreTest, InsertingTwiceReturnsExistingChunk) {
+  ChunkStore store;
+  ChunkData data = testing::MakeChunkData(2);
+  data.add_data();
+  std::shared_ptr<ChunkStore::Chunk> first = store.Insert(data);
+  EXPECT_NE(first, nullptr);
+  std::shared_ptr<ChunkStore::Chunk> second =
+      store.Insert(testing::MakeChunkData(2));
+  EXPECT_EQ(first, second);
+}
+
+TEST(ChunkStoreTest, InsertingTwiceSucceedsWhenChunkIsDestroyed) {
+  ChunkStore store;
+  std::shared_ptr<ChunkStore::Chunk> first =
+      store.Insert(testing::MakeChunkData(1));
+  EXPECT_NE(first, nullptr);
+  first = nullptr;
+  std::shared_ptr<ChunkStore::Chunk> second =
+      store.Insert(testing::MakeChunkData(1));
+  EXPECT_NE(second, nullptr);
+}
+
+TEST(ChunkStoreTest, CleanupDoesNotDeleteRequiredChunks) {
+  ChunkStore store(/*cleanup_batch_size=*/1);
+
+  // Keep this one around.
+  std::shared_ptr<ChunkStore::Chunk> first =
+      store.Insert(testing::MakeChunkData(1));
+
+  // Let this one expire.
+  {
+    std::shared_ptr<ChunkStore::Chunk> second =
+        store.Insert(testing::MakeChunkData(2));
+  }
+
+  // The first one should still be available.
+  ChunkVector chunks;
+  TF_EXPECT_OK(store.Get({1}, &chunks));
+
+  // The second one should be gone eventually.
+  tensorflow::Status status;
+  while (status.code() != tensorflow::error::NOT_FOUND) {
+    status = store.Get({2}, &chunks);
+  }
+}
+
+TEST(ChunkStoreTest, ConcurrentCalls) {
+  ChunkStore store;
+  std::vector<std::unique_ptr<internal::Thread>> bundle;
+  std::atomic<int> count(0);
+  for (ChunkStore::Key i = 0; i < 1000; i++) {
+    bundle.push_back(internal::StartThread("", [i, &store, &count] {
+      std::shared_ptr<ChunkStore::Chunk> first =
+          store.Insert(testing::MakeChunkData(i));
+      ChunkVector chunks;
+      TF_ASSERT_OK(store.Get({i}, &chunks));
+      first = nullptr;
+      while (store.Get({i}, &chunks).code() != tensorflow::error::NOT_FOUND) {
+      }
+      count++;
+    }));
+  }
+  bundle.clear();  // Joins all threads.
+  EXPECT_EQ(count, 1000);
+}
+
+}  // namespace
+}  // namespace reverb
+}  // namespace deepmind
diff --git a/reverb/cc/distributions/BUILD b/reverb/cc/distributions/BUILD
new file mode 100644
index 0000000..6907126
--- /dev/null
+++ b/reverb/cc/distributions/BUILD
@@ -0,0 +1,130 @@
+load(
+    "//reverb/cc/platform:build_rules.bzl",
+    "reverb_absl_deps",
+    "reverb_cc_library",
+    "reverb_cc_test",
+    "reverb_tf_deps",
+)
+
+package(default_visibility = ["//reverb:__subpackages__"])
+
+licenses(["notice"])
+
+reverb_cc_library(
+    name = "interface",
+    hdrs = ["interface.h"],
+    deps = [
+        "//reverb/cc/checkpointing:checkpoint_cc_proto",
+    ] + reverb_tf_deps() + reverb_absl_deps(),
+)
+
+reverb_cc_library(
+    name = "uniform",
+    srcs = ["uniform.cc"],
+    hdrs = ["uniform.h"],
+    deps = [
+        ":interface",
+        "//reverb/cc:schema_cc_proto",
+        "//reverb/cc/checkpointing:checkpoint_cc_proto",
+        "//reverb/cc/platform:logging",
+    ] + reverb_tf_deps() + reverb_absl_deps(),
+)
+
+reverb_cc_library(
+    name = "fifo",
+    srcs = ["fifo.cc"],
+    hdrs = ["fifo.h"],
+    deps = [
+        ":interface",
+        "//reverb/cc:schema_cc_proto",
+        "//reverb/cc/checkpointing:checkpoint_cc_proto",
+        "//reverb/cc/platform:logging",
+    ] + reverb_tf_deps() + reverb_absl_deps(),
+)
+
+reverb_cc_library(
+    name = "lifo",
+    srcs = ["lifo.cc"],
+    hdrs = ["lifo.h"],
+    deps = [
+        ":interface",
+        "//reverb/cc:schema_cc_proto",
+        "//reverb/cc/checkpointing:checkpoint_cc_proto",
+        "//reverb/cc/platform:logging",
+    ] + reverb_tf_deps() + reverb_absl_deps(),
+)
+
+reverb_cc_library(
+    name = "prioritized",
+    srcs = ["prioritized.cc"],
+    hdrs = ["prioritized.h"],
+    deps = [
+        ":interface",
+        "//reverb/cc:schema_cc_proto",
+        "//reverb/cc/checkpointing:checkpoint_cc_proto",
+        "//reverb/cc/platform:logging",
+    ] + reverb_tf_deps() + reverb_absl_deps(),
+)
+
+reverb_cc_library(
+    name = "heap",
+    srcs = ["heap.cc"],
+    hdrs = ["heap.h"],
+    deps = [
+        ":interface",
+        "//reverb/cc:schema_cc_proto",
+        "//reverb/cc/checkpointing:checkpoint_cc_proto",
+        "//reverb/cc/support:intrusive_heap",
+    ] + reverb_tf_deps() + reverb_absl_deps(),
+)
+
+reverb_cc_test(
+    name = "uniform_test",
+    srcs = ["uniform_test.cc"],
+    deps = [
+        ":uniform",
+        "//reverb/cc:schema_cc_proto",
+        "//reverb/cc/testing:proto_test_util",
+    ],
+)
+
+reverb_cc_test(
+    name = "fifo_test",
+    srcs = ["fifo_test.cc"],
+    deps = [
+        ":fifo",
+        "//reverb/cc:schema_cc_proto",
+        "//reverb/cc/testing:proto_test_util",
+    ],
+)
+
+reverb_cc_test(
+    name = "lifo_test",
+    srcs = ["lifo_test.cc"],
+    deps = [
+        ":lifo",
+        "//reverb/cc:schema_cc_proto",
+        "//reverb/cc/testing:proto_test_util",
+    ],
+)
+
+reverb_cc_test(
+    name = "prioritized_test",
+
srcs = ["prioritized_test.cc"], + deps = [ + ":prioritized", + "//reverb/cc:schema_cc_proto", + "//reverb/cc/testing:proto_test_util", + ] + reverb_absl_deps(), +) + +reverb_cc_test( + name = "heap_test", + srcs = ["heap_test.cc"], + deps = [ + ":heap", + ":interface", + "//reverb/cc:schema_cc_proto", + "//reverb/cc/testing:proto_test_util", + ], +) diff --git a/reverb/cc/distributions/fifo.cc b/reverb/cc/distributions/fifo.cc new file mode 100644 index 0000000..60e6023 --- /dev/null +++ b/reverb/cc/distributions/fifo.cc @@ -0,0 +1,72 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "reverb/cc/distributions/fifo.h" + +#include "reverb/cc/checkpointing/checkpoint.pb.h" +#include "reverb/cc/platform/logging.h" +#include "reverb/cc/schema.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" + +namespace deepmind { +namespace reverb { + +tensorflow::Status FifoDistribution::Delete(KeyDistributionInterface::Key key) { + auto it = key_to_iterator_.find(key); + if (it == key_to_iterator_.end()) + return tensorflow::errors::InvalidArgument( + absl::StrCat("Key ", key, " not found in distribution.")); + keys_.erase(it->second); + key_to_iterator_.erase(it); + return tensorflow::Status::OK(); +} + +tensorflow::Status FifoDistribution::Insert(KeyDistributionInterface::Key key, + double priority) { + if (key_to_iterator_.find(key) != key_to_iterator_.end()) { + return tensorflow::errors::InvalidArgument( + absl::StrCat("Key ", key, " already exists in distribution.")); + } + key_to_iterator_.emplace(key, keys_.emplace(keys_.end(), key)); + return tensorflow::Status::OK(); +} + +tensorflow::Status FifoDistribution::Update(KeyDistributionInterface::Key key, + double priority) { + if (key_to_iterator_.find(key) == key_to_iterator_.end()) { + return tensorflow::errors::InvalidArgument( + absl::StrCat("Key ", key, " not found in distribution.")); + } + return tensorflow::Status::OK(); +} + +KeyDistributionInterface::KeyWithProbability FifoDistribution::Sample() { + REVERB_CHECK(!keys_.empty()); + return {keys_.front(), 1.}; +} + +void FifoDistribution::Clear() { + keys_.clear(); + key_to_iterator_.clear(); +} + +KeyDistributionOptions FifoDistribution::options() const { + KeyDistributionOptions options; + options.set_fifo(true); + return options; +} + +} // namespace reverb +} // namespace deepmind diff --git a/reverb/cc/distributions/fifo.h b/reverb/cc/distributions/fifo.h new file mode 100644 index 0000000..1cdd02d --- /dev/null +++ b/reverb/cc/distributions/fifo.h @@ -0,0 +1,56 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LEARNING_DEEPMIND_REPLAY_REVERB_DISTRIBUTIONS_FIFO_H_
+#define LEARNING_DEEPMIND_REPLAY_REVERB_DISTRIBUTIONS_FIFO_H_
+
+#include <list>
+
+#include "absl/container/flat_hash_map.h"
+#include "reverb/cc/checkpointing/checkpoint.pb.h"
+#include "reverb/cc/distributions/interface.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace deepmind {
+namespace reverb {
+
+// Fifo sampling. We ignore all priority values in the calls. Sample() always
+// returns the key that was inserted first until this key is deleted. All
+// operations take O(1) time. See KeyDistributionInterface for documentation
+// about the methods.
+class FifoDistribution : public KeyDistributionInterface {
+ public:
+  tensorflow::Status Delete(Key key) override;
+
+  // The priority is ignored.
+  tensorflow::Status Insert(Key key, double priority) override;
+
+  // This is a no-op but will return an error if the key does not exist.
+  tensorflow::Status Update(Key key, double priority) override;
+
+  KeyWithProbability Sample() override;
+
+  void Clear() override;
+
+  KeyDistributionOptions options() const override;
+
+ private:
+  std::list<Key> keys_;
+  absl::flat_hash_map<Key, std::list<Key>::iterator> key_to_iterator_;
+};
+
+}  // namespace reverb
+}  // namespace deepmind
+
+#endif  // LEARNING_DEEPMIND_REPLAY_REVERB_DISTRIBUTIONS_FIFO_H_
diff --git a/reverb/cc/distributions/fifo_test.cc b/reverb/cc/distributions/fifo_test.cc
new file mode 100644
index 0000000..7080096
--- /dev/null
+++ b/reverb/cc/distributions/fifo_test.cc
@@ -0,0 +1,88 @@
+// Copyright 2019 DeepMind Technologies Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "reverb/cc/distributions/fifo.h"
+
+#include <cstdint>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "reverb/cc/schema.pb.h"
+#include "reverb/cc/testing/proto_test_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace deepmind {
+namespace reverb {
+namespace {
+
+TEST(FifoTest, ReturnValueSanityChecks) {
+  FifoDistribution fifo;
+
+  // Non existent keys cannot be deleted or updated.
+  EXPECT_EQ(fifo.Delete(123).code(), tensorflow::error::INVALID_ARGUMENT);
+  EXPECT_EQ(fifo.Update(123, 4).code(), tensorflow::error::INVALID_ARGUMENT);
+
+  // Keys cannot be inserted twice.
+  TF_EXPECT_OK(fifo.Insert(123, 4));
+  EXPECT_THAT(fifo.Insert(123, 4).code(), tensorflow::error::INVALID_ARGUMENT);
+
+  // Existing keys can be updated and sampled.
+  TF_EXPECT_OK(fifo.Update(123, 5));
+  EXPECT_EQ(fifo.Sample().key, 123);
+
+  // Existing keys cannot be deleted twice.
+ TF_EXPECT_OK(fifo.Delete(123)); + EXPECT_THAT(fifo.Delete(123).code(), tensorflow::error::INVALID_ARGUMENT); +} + +TEST(FifoTest, MatchesFifoOrdering) { + int64_t kItems = 100; + + FifoDistribution fifo; + // Insert items. + for (int i = 0; i < kItems; i++) { + TF_EXPECT_OK(fifo.Insert(i, 0)); + } + // Delete every 10th item. + for (int i = 0; i < kItems; i++) { + if (i % 10 == 0) TF_EXPECT_OK(fifo.Delete(i)); + } + + for (int i = 0; i < kItems; i++) { + if (i % 10 == 0) continue; + KeyDistributionInterface::KeyWithProbability sample = fifo.Sample(); + EXPECT_EQ(sample.key, i); + EXPECT_EQ(sample.probability, 1); + TF_EXPECT_OK(fifo.Delete(sample.key)); + } +} + +TEST(FifoTest, Options) { + FifoDistribution fifo; + EXPECT_THAT(fifo.options(), testing::EqualsProto("fifo: true")); +} + +TEST(FifoDeathTest, ClearThenSample) { + FifoDistribution fifo; + for (int i = 0; i < 100; i++) { + TF_EXPECT_OK(fifo.Insert(i, i)); + } + fifo.Sample(); + fifo.Clear(); + EXPECT_DEATH(fifo.Sample(), ""); +} + +} // namespace +} // namespace reverb +} // namespace deepmind diff --git a/reverb/cc/distributions/heap.cc b/reverb/cc/distributions/heap.cc new file mode 100644 index 0000000..3b6b985 --- /dev/null +++ b/reverb/cc/distributions/heap.cc @@ -0,0 +1,81 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "reverb/cc/distributions/heap.h" + +#include "reverb/cc/checkpointing/checkpoint.pb.h" +#include "reverb/cc/distributions/interface.h" +#include "reverb/cc/schema.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" + +namespace deepmind { +namespace reverb { + +HeapDistribution::HeapDistribution(bool min_heap) + : sign_(min_heap ? 
1 : -1), update_count_(0) {} + +tensorflow::Status HeapDistribution::Delete(KeyDistributionInterface::Key key) { + auto it = nodes_.find(key); + if (it == nodes_.end()) { + return tensorflow::errors::InvalidArgument( + absl::StrCat("Key ", key, " not found in distribution.")); + } + heap_.Remove(it->second.get()); + nodes_.erase(it); + return tensorflow::Status::OK(); +} + +tensorflow::Status HeapDistribution::Insert(KeyDistributionInterface::Key key, + double priority) { + if (nodes_.contains(key)) { + return tensorflow::errors::InvalidArgument( + absl::StrCat("Key ", key, " already exists in distribution.")); + } + nodes_[key] = + absl::make_unique(key, priority * sign_, update_count_++); + heap_.Push(nodes_[key].get()); + return tensorflow::Status::OK(); +} + +tensorflow::Status HeapDistribution::Update(KeyDistributionInterface::Key key, + double priority) { + if (!nodes_.contains(key)) { + return tensorflow::errors::InvalidArgument( + absl::StrCat("Key ", key, " not found in distribution.")); + } + nodes_[key]->priority = priority * sign_; + nodes_[key]->update_number = update_count_++; + heap_.Adjust(nodes_[key].get()); + return tensorflow::Status::OK(); +} + +KeyDistributionInterface::KeyWithProbability HeapDistribution::Sample() { + REVERB_CHECK(!nodes_.empty()); + return {heap_.top()->key, 1.}; +} + +void HeapDistribution::Clear() { + nodes_.clear(); + heap_.Clear(); +} + +KeyDistributionOptions HeapDistribution::options() const { + KeyDistributionOptions options; + options.mutable_heap()->set_min_heap(sign_ == 1); + return options; +} + +} // namespace reverb +} // namespace deepmind diff --git a/reverb/cc/distributions/heap.h b/reverb/cc/distributions/heap.h new file mode 100644 index 0000000..be0c609 --- /dev/null +++ b/reverb/cc/distributions/heap.h @@ -0,0 +1,90 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LEARNING_DEEPMIND_REPLAY_REVERB_DISTRIBUTIONS_HEAP_H_ +#define LEARNING_DEEPMIND_REPLAY_REVERB_DISTRIBUTIONS_HEAP_H_ + +#include +#include "absl/container/flat_hash_map.h" +#include "reverb/cc/distributions/interface.h" +#include "reverb/cc/support/intrusive_heap.h" +#include "tensorflow/core/lib/core/status.h" + +namespace deepmind { +namespace reverb { + +// HeapDistribution always samples the item with the lowest or highest priority +// (controlled by `min_heap`). If multiple items share the same priority then +// the least recently inserted or updated key is sampled. +class HeapDistribution : public KeyDistributionInterface { + public: + explicit HeapDistribution(bool min_heap = true); + + // O(log n) time. + tensorflow::Status Delete(Key key) override; + + // O(log n) time. + tensorflow::Status Insert(Key key, double priority) override; + + // O(log n) time. + tensorflow::Status Update(Key key, double priority) override; + + // O(1) time. + KeyWithProbability Sample() override; + + // O(n) time. 
+ void Clear() override; + + KeyDistributionOptions options() const override; + + private: + struct HeapNode { + Key key; + double priority; + IntrusiveHeapLink heap; + uint64_t update_number; + + HeapNode(Key key, double priority, uint64_t update_number) + : key(key), priority(priority), update_number(update_number) {} + }; + + struct HeapNodeCompare { + bool operator()(const HeapNode* a, const HeapNode* b) const { + // Lexicographic ordering by (priority, update_number). + return (a->priority < b->priority) || + ((a->priority == b->priority) && + (a->update_number < b->update_number)); + } + }; + + // 1 if `min_heap` = true, else -1. Priorities are multiplied by this number + // to control whether the min or max priority item should be sampled. + const double sign_; + + // Heap where the top item is the one with the lowest/highest priority in the + // distribution. + IntrusiveHeap heap_; + + // `IntrusiveHeap` does not manage the memory of its nodes so they are stored + // in `nodes_`. The content of nodes_ and heap_ are always kept in sync. + absl::flat_hash_map> nodes_; + + // Keep track of the number of inserts/updates for most-recent tie-breaking. + uint64_t update_count_; +}; + +} // namespace reverb +} // namespace deepmind + +#endif // LEARNING_DEEPMIND_REPLAY_REVERB_DISTRIBUTIONS_HEAP_H_ diff --git a/reverb/cc/distributions/heap_test.cc b/reverb/cc/distributions/heap_test.cc new file mode 100644 index 0000000..dac8b6a --- /dev/null +++ b/reverb/cc/distributions/heap_test.cc @@ -0,0 +1,186 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "reverb/cc/distributions/heap.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "reverb/cc/distributions/interface.h" +#include "reverb/cc/schema.pb.h" +#include "reverb/cc/testing/proto_test_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace deepmind { +namespace reverb { +namespace { + + +TEST(HeapDistributionTest, ReturnValueSantiyChecks) { + HeapDistribution heap; + + // Non existent keys cannot be deleted or updated. + EXPECT_EQ(heap.Delete(123).code(), tensorflow::error::INVALID_ARGUMENT); + EXPECT_EQ(heap.Update(123, 4).code(), tensorflow::error::INVALID_ARGUMENT); + + // Keys cannot be inserted twice. + TF_EXPECT_OK(heap.Insert(123, 4)); + EXPECT_EQ(heap.Insert(123, 4).code(), tensorflow::error::INVALID_ARGUMENT); + + // Existing keys can be updated and sampled. + TF_EXPECT_OK(heap.Update(123, 5)); + EXPECT_EQ(heap.Sample().key, 123); + + // Existing keys cannot be deleted twice. + TF_EXPECT_OK(heap.Delete(123)); + EXPECT_EQ(heap.Delete(123).code(), tensorflow::error::INVALID_ARGUMENT); +} + +TEST(HeapDistributionTest, SampleMinPriorityFirstByDefault) { + HeapDistribution heap; + + TF_EXPECT_OK(heap.Insert(123, 2)); + TF_EXPECT_OK(heap.Insert(124, 1)); + TF_EXPECT_OK(heap.Insert(125, 3)); + + EXPECT_EQ(heap.Sample().key, 124); + + // Remove the top item. 
+ TF_EXPECT_OK(heap.Delete(124)); + + // Second lowest priority is now the lowest. + EXPECT_EQ(heap.Sample().key, 123); +} + +TEST(HeapDistributionTest, BreakTiesByInsertionOrder) { + HeapDistribution heap; + + // We insert keys with priorities such that the keys are removed in the + // order [0, 1, 2, 3, 4, 5]. + // Three-way tie and two-way tie are checked. + TF_EXPECT_OK(heap.Insert(5, 300)); + TF_EXPECT_OK(heap.Insert(0, 1)); + TF_EXPECT_OK(heap.Insert(3, 20)); + TF_EXPECT_OK(heap.Insert(1, 1)); + TF_EXPECT_OK(heap.Insert(4, 20)); + TF_EXPECT_OK(heap.Insert(2, 1)); + + for (auto i = 0; i < 6; i++) { + EXPECT_EQ(heap.Sample().key, i); + TF_EXPECT_OK(heap.Delete(i)); + } +} + +TEST(HeapDistributionTest, BreakTiesByUpdateOrder) { + HeapDistribution heap; + + TF_EXPECT_OK(heap.Insert(2, 1)); + TF_EXPECT_OK(heap.Insert(0, 1)); + TF_EXPECT_OK(heap.Insert(1, 1)); + + // Removing keys at this point would result in the order [2, 0, 1] + // by LRU because the priorites are equal. + // This update does not change the priority, but does increase the update + // recency, resulting in the new order [0, 1, 2] which we verify. + TF_EXPECT_OK(heap.Update(2, 1)); + for (auto i = 0; i < 3; i++) { + EXPECT_EQ(heap.Sample().key, i); + TF_EXPECT_OK(heap.Delete(i)); + } +} + +TEST(HeapDistributionTest, SampleMaxPriorityWhenMinHeapFalse) { + HeapDistribution heap(false); + + TF_EXPECT_OK(heap.Insert(123, 2)); + TF_EXPECT_OK(heap.Insert(124, 1)); + TF_EXPECT_OK(heap.Insert(125, 3)); + + EXPECT_EQ(heap.Sample().key, 125); + + // Remove the top item. + TF_EXPECT_OK(heap.Delete(125)); + + // Second lowest priority is now the highest. + EXPECT_EQ(heap.Sample().key, 123); +} + +TEST(HeapDistributionTest, UpdateChangesOrder) { + HeapDistribution heap; + + TF_EXPECT_OK(heap.Insert(123, 2)); + TF_EXPECT_OK(heap.Insert(124, 1)); + TF_EXPECT_OK(heap.Insert(125, 3)); + + EXPECT_EQ(heap.Sample().key, 124); + + // Update the current top item. + TF_EXPECT_OK(heap.Update(124, 5)); + EXPECT_EQ(heap.Sample().key, 123); + + // Update another item and check that it is moved to the top. + TF_EXPECT_OK(heap.Update(125, 0.5)); + EXPECT_EQ(heap.Sample().key, 125); +} + +TEST(HeapDistributionTest, Clear) { + HeapDistribution heap; + + TF_EXPECT_OK(heap.Insert(123, 2)); + TF_EXPECT_OK(heap.Insert(124, 1)); + + EXPECT_EQ(heap.Sample().key, 124); + + // Clear distibution and insert an item that should otherwise be at the end. 
+ heap.Clear(); + TF_EXPECT_OK(heap.Insert(125, 10)); + EXPECT_EQ(heap.Sample().key, 125); +} + +TEST(HeapDistributionTest, ProbabilityIsAlwaysOne) { + HeapDistribution heap; + + for (int i = 100; i < 150; i++) { + TF_EXPECT_OK(heap.Insert(i, i)); + } + + for (int i = 0; i < 50; i++) { + auto sample = heap.Sample(); + EXPECT_EQ(sample.probability, 1); + TF_EXPECT_OK(heap.Delete(sample.key)); + } +} + +TEST(HeapDistributionTest, Options) { + HeapDistribution min_heap; + HeapDistribution max_heap(false); + EXPECT_THAT(min_heap.options(), + testing::EqualsProto("heap: { min_heap: true }")); + EXPECT_THAT(max_heap.options(), + testing::EqualsProto("heap: { min_heap: false }")); +} + +TEST(HeapDistributionDeathTest, SampleFromEmptyDistribution) { + HeapDistribution heap; + EXPECT_DEATH(heap.Sample(), ""); + + TF_EXPECT_OK(heap.Insert(123, 2)); + heap.Sample(); + + TF_EXPECT_OK(heap.Delete(123)); + EXPECT_DEATH(heap.Sample(), ""); +} + +} // namespace +} // namespace reverb +} // namespace deepmind diff --git a/reverb/cc/distributions/interface.h b/reverb/cc/distributions/interface.h new file mode 100644 index 0000000..53f61c4 --- /dev/null +++ b/reverb/cc/distributions/interface.h @@ -0,0 +1,68 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LEARNING_DEEPMIND_REPLAY_REVERB_DISTRIBUTIONS_INTERFACE_H_ +#define LEARNING_DEEPMIND_REPLAY_REVERB_DISTRIBUTIONS_INTERFACE_H_ + +#include +#include "reverb/cc/checkpointing/checkpoint.pb.h" +#include "tensorflow/core/lib/core/status.h" + +namespace deepmind { +namespace reverb { + +// Allows sampling from a population of keys with a specified priority per key. +// +// Member methods will not be called concurrently, so implementations do not +// need to be thread-safe. More to the point, a number of subclasses use bit +// generators that are not thread-safe, so methods like `Sample` are not +// thread-safe. +class KeyDistributionInterface { + public: + using Key = uint64_t; + + struct KeyWithProbability { + Key key; + double probability; + }; + + virtual ~KeyDistributionInterface() = default; + + // Deletes a key and the associated priority. Returns an error if the key does + // not exist. + virtual tensorflow::Status Delete(Key key) = 0; + + // Inserts a key and associated priority. Returns an error without any change + // if the key already exists. + virtual tensorflow::Status Insert(Key key, double priority) = 0; + + // Updates a key and associated priority. Returns an error if the key does + // not exist. + virtual tensorflow::Status Update(Key key, double priority) = 0; + + // Samples a key. Must contain keys when this is called. + virtual KeyWithProbability Sample() = 0; + + // Clear the distribution of all data. + virtual void Clear() = 0; + + // Options for dynamically constructing the distribution. Required when + // reconstructing class from checkpoint. Also used to query table metadata. 
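+  // For example, `FifoDistribution::options()` returns options with
+  // `fifo: true` set, which is enough to rebuild an equivalent distribution.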
+ virtual KeyDistributionOptions options() const = 0; +}; + +} // namespace reverb +} // namespace deepmind + +#endif // LEARNING_DEEPMIND_REPLAY_REVERB_DISTRIBUTIONS_INTERFACE_H_ diff --git a/reverb/cc/distributions/lifo.cc b/reverb/cc/distributions/lifo.cc new file mode 100644 index 0000000..fe88785 --- /dev/null +++ b/reverb/cc/distributions/lifo.cc @@ -0,0 +1,72 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "reverb/cc/distributions/lifo.h" + +#include "reverb/cc/checkpointing/checkpoint.pb.h" +#include "reverb/cc/platform/logging.h" +#include "reverb/cc/schema.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" + +namespace deepmind { +namespace reverb { + +tensorflow::Status LifoDistribution::Delete(KeyDistributionInterface::Key key) { + auto it = key_to_iterator_.find(key); + if (it == key_to_iterator_.end()) + return tensorflow::errors::InvalidArgument( + absl::StrCat("Key ", key, " not found in distribution.")); + keys_.erase(it->second); + key_to_iterator_.erase(it); + return tensorflow::Status::OK(); +} + +tensorflow::Status LifoDistribution::Insert(KeyDistributionInterface::Key key, + double priority) { + if (key_to_iterator_.find(key) != key_to_iterator_.end()) { + return tensorflow::errors::InvalidArgument( + absl::StrCat("Key ", key, " already exists in distribution.")); + } + key_to_iterator_.emplace(key, keys_.emplace(keys_.begin(), key)); + return tensorflow::Status::OK(); +} + +tensorflow::Status LifoDistribution::Update(KeyDistributionInterface::Key key, + double priority) { + if (key_to_iterator_.find(key) == key_to_iterator_.end()) { + return tensorflow::errors::InvalidArgument( + absl::StrCat("Key ", key, " not found in distribution.")); + } + return tensorflow::Status::OK(); +} + +KeyDistributionInterface::KeyWithProbability LifoDistribution::Sample() { + REVERB_CHECK(!keys_.empty()); + return {keys_.front(), 1.}; +} + +void LifoDistribution::Clear() { + keys_.clear(); + key_to_iterator_.clear(); +} + +KeyDistributionOptions LifoDistribution::options() const { + KeyDistributionOptions options; + options.set_lifo(true); + return options; +} + +} // namespace reverb +} // namespace deepmind diff --git a/reverb/cc/distributions/lifo.h b/reverb/cc/distributions/lifo.h new file mode 100644 index 0000000..3850471 --- /dev/null +++ b/reverb/cc/distributions/lifo.h @@ -0,0 +1,56 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LEARNING_DEEPMIND_REPLAY_REVERB_DISTRIBUTIONS_LIFO_H_
+#define LEARNING_DEEPMIND_REPLAY_REVERB_DISTRIBUTIONS_LIFO_H_
+
+#include <list>
+
+#include "absl/container/flat_hash_map.h"
+#include "reverb/cc/checkpointing/checkpoint.pb.h"
+#include "reverb/cc/distributions/interface.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace deepmind {
+namespace reverb {
+
+// Lifo sampling. We ignore all priority values in the calls. Sample() always
+// returns the key that was inserted last until this key is deleted. All
+// operations take O(1) time. See KeyDistributionInterface for documentation
+// about the methods.
+class LifoDistribution : public KeyDistributionInterface {
+ public:
+  tensorflow::Status Delete(Key key) override;
+
+  // The priority is ignored.
+  tensorflow::Status Insert(Key key, double priority) override;
+
+  // This is a no-op but will return an error if the key does not exist.
+  tensorflow::Status Update(Key key, double priority) override;
+
+  KeyWithProbability Sample() override;
+
+  void Clear() override;
+
+  KeyDistributionOptions options() const override;
+
+ private:
+  std::list<Key> keys_;
+  absl::flat_hash_map<Key, std::list<Key>::iterator> key_to_iterator_;
+};
+
+}  // namespace reverb
+}  // namespace deepmind
+
+#endif  // LEARNING_DEEPMIND_REPLAY_REVERB_DISTRIBUTIONS_LIFO_H_
diff --git a/reverb/cc/distributions/lifo_test.cc b/reverb/cc/distributions/lifo_test.cc
new file mode 100644
index 0000000..e8d2b10
--- /dev/null
+++ b/reverb/cc/distributions/lifo_test.cc
@@ -0,0 +1,88 @@
+// Copyright 2019 DeepMind Technologies Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "reverb/cc/distributions/lifo.h"
+
+#include <cstdint>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "reverb/cc/schema.pb.h"
+#include "reverb/cc/testing/proto_test_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace deepmind {
+namespace reverb {
+namespace {
+
+TEST(LifoTest, ReturnValueSanityChecks) {
+  LifoDistribution lifo;
+
+  // Non existent keys cannot be deleted or updated.
+  EXPECT_EQ(lifo.Delete(123).code(), tensorflow::error::INVALID_ARGUMENT);
+  EXPECT_EQ(lifo.Update(123, 4).code(), tensorflow::error::INVALID_ARGUMENT);
+
+  // Keys cannot be inserted twice.
+  TF_EXPECT_OK(lifo.Insert(123, 4));
+  EXPECT_THAT(lifo.Insert(123, 4).code(), tensorflow::error::INVALID_ARGUMENT);
+
+  // Existing keys can be updated and sampled.
+  TF_EXPECT_OK(lifo.Update(123, 5));
+  EXPECT_EQ(lifo.Sample().key, 123);
+
+  // Existing keys cannot be deleted twice.
+  TF_EXPECT_OK(lifo.Delete(123));
+  EXPECT_THAT(lifo.Delete(123).code(), tensorflow::error::INVALID_ARGUMENT);
+}
+
+TEST(LifoTest, MatchesLifoOrdering) {
+  int64_t kItems = 100;
+
+  LifoDistribution lifo;
+  // Insert items.
+  for (int i = 0; i < kItems; i++) {
+    TF_EXPECT_OK(lifo.Insert(i, 0));
+  }
+  // Delete every 10th item.
+ for (int i = 0; i < kItems; i++) { + if (i % 10 == 0) TF_EXPECT_OK(lifo.Delete(i)); + } + + for (int i = kItems - 1; i >= 0; i--) { + if (i % 10 == 0) continue; + KeyDistributionInterface::KeyWithProbability sample = lifo.Sample(); + EXPECT_EQ(sample.key, i); + EXPECT_EQ(sample.probability, 1); + TF_EXPECT_OK(lifo.Delete(sample.key)); + } +} + +TEST(LifoTest, Options) { + LifoDistribution lifo; + EXPECT_THAT(lifo.options(), testing::EqualsProto("lifo: true")); +} + +TEST(LifoDeathTest, ClearThenSample) { + LifoDistribution lifo; + for (int i = 0; i < 100; i++) { + TF_EXPECT_OK(lifo.Insert(i, i)); + } + lifo.Sample(); + lifo.Clear(); + EXPECT_DEATH(lifo.Sample(), ""); +} + +} // namespace +} // namespace reverb +} // namespace deepmind diff --git a/reverb/cc/distributions/prioritized.cc b/reverb/cc/distributions/prioritized.cc new file mode 100644 index 0000000..4a0f0f3 --- /dev/null +++ b/reverb/cc/distributions/prioritized.cc @@ -0,0 +1,181 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "reverb/cc/distributions/prioritized.h" + +#include +#include + +#include "absl/random/distributions.h" +#include "reverb/cc/platform/logging.h" +#include "reverb/cc/schema.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" + +namespace deepmind { +namespace reverb { +namespace { + +// A priority of zero should correspond to zero probability, even if the +// priority exponent is zero. So this modified version of std::pow is used to +// turn priorities into weights. Expects base and exponent to be non-negative. +inline double power(double base, double exponent) { + return base == 0. ? 0. : std::pow(base, exponent); +} + +tensorflow::Status CheckValidPriority(double priority) { + if (std::isnan(priority)) + return tensorflow::errors::InvalidArgument("Priority must not be NaN."); + if (priority < 0) + return tensorflow::errors::InvalidArgument( + "Priority must not be negative."); + return tensorflow::Status::OK(); +} + +} // namespace + +PrioritizedDistribution::PrioritizedDistribution(double priority_exponent) + : priority_exponent_(priority_exponent), capacity_(std::pow(2, 17)) { + REVERB_CHECK_GE(priority_exponent_, 0); + sum_tree_.resize(capacity_); +} + +tensorflow::Status PrioritizedDistribution::Delete(Key key) { + const size_t last_index = key_to_index_.size() - 1; + const auto it = key_to_index_.find(key); + if (it == key_to_index_.end()) + return tensorflow::errors::InvalidArgument( + absl::StrCat("Key ", key, " not found in distribution.")); + const size_t index = it->second; + + if (index != last_index) { + // Replace the element that we want to remove with the last element. + SetNode(index, NodeValue(last_index)); + const Key last_key = sum_tree_[last_index].key; + sum_tree_[index].key = last_key; + key_to_index_[last_key] = index; + } + + SetNode(last_index, 0); + key_to_index_.erase(it); // Note that this must occur after SetNode. 
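+  // Moving the last element into the vacated slot keeps the keys stored in a
+  // contiguous prefix of `sum_tree_`, which is what keeps Delete at O(log n).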
+ + return tensorflow::Status::OK(); +} + +tensorflow::Status PrioritizedDistribution::Insert(Key key, double priority) { + TF_RETURN_IF_ERROR(CheckValidPriority(priority)); + const size_t index = key_to_index_.size(); + if (index == capacity_) { + capacity_ *= 2; + sum_tree_.resize(capacity_); + } + if (!key_to_index_.try_emplace(key, index).second) { + return tensorflow::errors::InvalidArgument( + absl::StrCat("Key ", key, " already exists in distribution.")); + } + sum_tree_[index].key = key; + REVERB_CHECK_EQ(sum_tree_[index].sum, 0); + SetNode(index, power(priority, priority_exponent_)); + return tensorflow::Status::OK(); +} + +tensorflow::Status PrioritizedDistribution::Update(Key key, double priority) { + TF_RETURN_IF_ERROR(CheckValidPriority(priority)); + const auto it = key_to_index_.find(key); + if (it == key_to_index_.end()) { + return tensorflow::errors::InvalidArgument( + absl::StrCat("Key ", key, " not found in distribution.")); + } + SetNode(it->second, power(priority, priority_exponent_)); + return tensorflow::Status::OK(); +} + +KeyDistributionInterface::KeyWithProbability PrioritizedDistribution::Sample() { + const size_t size = key_to_index_.size(); + REVERB_CHECK_NE(size, 0); + + // This should never be called concurrently from multiple threads. + const double target = absl::Uniform(bit_gen_, 0, 1); + const double total_weight = sum_tree_[0].sum; + + // All keys have zero priority so treat as if uniformly sampling. + if (total_weight == 0) { + const size_t pos = static_cast(target * size); + return {sum_tree_[pos].key, 1. / size}; + } + + // We begin traversing the `sum_tree_` from the root to the children in order + // to find the `index` corresponding to the sampled `target_weight`. + size_t index = 0; + double target_weight = target * total_weight; + while (true) { + // Go to the left sub tree if it contains our sampled `target_weight`. + const size_t left_index = 2 * index + 1; + const double left_sum = NodeSum(left_index); + if (target_weight < left_sum) { + index = left_index; + continue; + } + target_weight -= left_sum; + // Go to the right sub tree if it contains our sampled `target_weight`. + const size_t right_index = 2 * index + 2; + const double right_sum = NodeSum(right_index); + if (target_weight < right_sum) { + index = right_index; + continue; + } + target_weight -= right_sum; + // Otherwise it is the current index. + break; + } + REVERB_CHECK_LT(index, size); + const double picked_weight = NodeValue(index); + REVERB_CHECK_LT(target_weight, picked_weight); + return {sum_tree_[index].key, picked_weight / total_weight}; +} + +void PrioritizedDistribution::Clear() { + for (size_t i = 0; i < key_to_index_.size(); ++i) { + sum_tree_[i].sum = 0; + } + key_to_index_.clear(); +} + +KeyDistributionOptions PrioritizedDistribution::options() const { + KeyDistributionOptions options; + options.mutable_prioritized()->set_priority_exponent(priority_exponent_); + return options; +} + +double PrioritizedDistribution::NodeValue(size_t index) const { + const size_t left_index = 2 * index + 1; + const size_t right_index = 2 * index + 2; + return sum_tree_[index].sum - NodeSum(left_index) - NodeSum(right_index); +} + +double PrioritizedDistribution::NodeSum(size_t index) const { + return index < key_to_index_.size() ? 
sum_tree_[index].sum : 0; +} + +void PrioritizedDistribution::SetNode(size_t index, double value) { + double difference = value - NodeValue(index); + sum_tree_[index].sum += difference; + while (index != 0) { + index = (index - 1) / 2; + sum_tree_[index].sum += difference; + } +} + +} // namespace reverb +} // namespace deepmind diff --git a/reverb/cc/distributions/prioritized.h b/reverb/cc/distributions/prioritized.h new file mode 100644 index 0000000..c9106cb --- /dev/null +++ b/reverb/cc/distributions/prioritized.h @@ -0,0 +1,108 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LEARNING_DEEPMIND_REPLAY_REVERB_DISTRIBUTIONS_PRIORITIZED_H_ +#define LEARNING_DEEPMIND_REPLAY_REVERB_DISTRIBUTIONS_PRIORITIZED_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/random/random.h" +#include "reverb/cc/checkpointing/checkpoint.pb.h" +#include "reverb/cc/distributions/interface.h" +#include "tensorflow/core/lib/core/status.h" + +namespace deepmind { +namespace reverb { + +// This an implementation of a categorical distribution that allows incremental +// changes to the keys to be made efficiently. The probability of sampling a key +// is proportional to its priority raised to configurable exponent. +// +// Since the priorities and probabilities are stored as doubles, numerical +// rounding errors may be introduced especially when the relative size of +// probabilities for keys is large. Ideally when using this class priorities are +// roughly the same scale and the priority exponent is not large, e.g. less than +// 2. +// +// This was forked from: +// ## proportional_picker.h +// +class PrioritizedDistribution : public KeyDistributionInterface { + public: + explicit PrioritizedDistribution(double priority_exponent); + + // O(log n) time. + tensorflow::Status Delete(Key key) override; + + // The priority must be non-negative. O(log n) time. + tensorflow::Status Insert(Key key, double priority) override; + + // The priority must be non-negative. O(log n) time. + tensorflow::Status Update(Key key, double priority) override; + + // O(log n) time. + KeyWithProbability Sample() override; + + // O(n) time. + void Clear() override; + + KeyDistributionOptions options() const override; + + private: + struct Node { + Key key; + // Sum of the exponentiated priority of this node and all its descendants. + // This includes the entire sub tree with inner and leaf nodes. + // `NodeValue()` can be used to get the exponentiated priority of a node + // without its children. + double sum = 0; + }; + + // Gets the individual value of a node in `sum_tree_` without the summed up + // value of all its descendants. + double NodeValue(size_t index) const; + + // Sum of the exponentiated priority of this node and all its descendants. + // If the index is out of bounds, then 0 is returned. + double NodeSum(size_t index) const; + + // Sets the individual value of a node in the `sum_tree_`. 
This does not
+  // include the value of the descendants.
+  void SetNode(size_t index, double value);
+
+  // Controls the degree of prioritization. Priorities are raised to this
+  // exponent before adding them to the `SumTree` as weights. A non-negative
+  // number where a value of zero corresponds to each key having the same
+  // probability (except for keys with zero priority).
+  const double priority_exponent_;
+
+  // Capacity of the `sum_tree_`. Starts at ~130000 and grows exponentially.
+  size_t capacity_;
+
+  // A tree stored as a flat vector where each node is the sum of its children
+  // plus its own exponentiated priority.
+  std::vector<Node> sum_tree_;
+
+  // Maps a key to the index where this key can be found in `sum_tree_`.
+  absl::flat_hash_map<Key, size_t> key_to_index_;
+
+  // Used for sampling, not thread-safe.
+  absl::BitGen bit_gen_;
+};
+
+}  // namespace reverb
+}  // namespace deepmind
+
+#endif  // LEARNING_DEEPMIND_REPLAY_REVERB_DISTRIBUTIONS_PRIORITIZED_H_
diff --git a/reverb/cc/distributions/prioritized_test.cc b/reverb/cc/distributions/prioritized_test.cc
new file mode 100644
index 0000000..7f631de
--- /dev/null
+++ b/reverb/cc/distributions/prioritized_test.cc
@@ -0,0 +1,151 @@
+// Copyright 2019 DeepMind Technologies Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "reverb/cc/distributions/prioritized.h"
+
+#include <cmath>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/random/distributions.h"
+#include "absl/random/random.h"
+#include "reverb/cc/schema.pb.h"
+#include "reverb/cc/testing/proto_test_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace deepmind {
+namespace reverb {
+namespace {
+
+const double kInitialPriorityExponent = 1;
+
+TEST(PrioritizedTest, ReturnValueSanityChecks) {
+  PrioritizedDistribution prioritized(kInitialPriorityExponent);
+
+  // Non existent keys cannot be deleted or updated.
+  EXPECT_EQ(prioritized.Delete(123).code(),
+            tensorflow::error::INVALID_ARGUMENT);
+  EXPECT_EQ(prioritized.Update(123, 4).code(),
+            tensorflow::error::INVALID_ARGUMENT);
+
+  // Keys cannot be inserted twice.
+  TF_EXPECT_OK(prioritized.Insert(123, 4));
+  EXPECT_EQ(prioritized.Insert(123, 4).code(),
+            tensorflow::error::INVALID_ARGUMENT);
+
+  // Existing keys can be updated and sampled.
+  TF_EXPECT_OK(prioritized.Update(123, 5));
+  EXPECT_EQ(prioritized.Sample().key, 123);
+
+  // Negative priorities are not allowed.
+  EXPECT_EQ(prioritized.Update(123, -1).code(),
+            tensorflow::error::INVALID_ARGUMENT);
+  EXPECT_EQ(prioritized.Insert(456, -1).code(),
+            tensorflow::error::INVALID_ARGUMENT);
+
+  // NaN priorities are not allowed.
+  EXPECT_EQ(prioritized.Update(123, NAN).code(),
+            tensorflow::error::INVALID_ARGUMENT);
+  EXPECT_EQ(prioritized.Insert(456, NAN).code(),
+            tensorflow::error::INVALID_ARGUMENT);
+
+  // Existing keys cannot be deleted twice.
+ TF_EXPECT_OK(prioritized.Delete(123)); + EXPECT_EQ(prioritized.Delete(123).code(), + tensorflow::error::INVALID_ARGUMENT); +} + +TEST(PrioritizedTest, AllZeroPrioritiesResultsInUniformSampling) { + int64_t kItems = 100; + int64_t kSamples = 1000000; + double expected_probability = 1. / static_cast(kItems); + + PrioritizedDistribution prioritized(kInitialPriorityExponent); + for (int i = 0; i < kItems; i++) { + TF_EXPECT_OK(prioritized.Insert(i, 0)); + } + std::vector counts(kItems); + for (int i = 0; i < kSamples; i++) { + KeyDistributionInterface::KeyWithProbability sample = prioritized.Sample(); + EXPECT_EQ(sample.probability, expected_probability); + counts[sample.key]++; + } + for (int64_t count : counts) { + EXPECT_NEAR(static_cast(count) / static_cast(kSamples), + expected_probability, 0.05); + } +} + +TEST(PrioritizedTest, SampledDistributionMatchesProbabilities) { + const int kStart = 10; + const int kEnd = 100; + const int kSamples = 1000000; + + PrioritizedDistribution prioritized(kInitialPriorityExponent); + double sum = 0; + absl::BitGen bit_gen_; + for (int i = 0; i < kEnd; i++) { + if (absl::Uniform(bit_gen_, 0, 1) < 0.5) { + TF_EXPECT_OK(prioritized.Insert(i, i)); + } else { + TF_EXPECT_OK(prioritized.Insert(i, 123)); + TF_EXPECT_OK(prioritized.Update(i, i)); + } + sum += i; + } + // Remove the first few items. + for (int i = 0; i < kStart; i++) { + TF_EXPECT_OK(prioritized.Delete(i)); + sum -= i; + } + // Update the priorities. + std::vector counts(kEnd); + absl::flat_hash_map probabilities; + for (int i = 0; i < kSamples; i++) { + KeyDistributionInterface::KeyWithProbability sample = prioritized.Sample(); + probabilities[sample.key] = sample.probability; + counts[sample.key]++; + EXPECT_NEAR(sample.probability, sample.key / sum, 0.001); + } + for (int k = 0; k < kStart; k++) EXPECT_EQ(counts[k], 0); + for (int k = kStart; k < kEnd; k++) { + EXPECT_NEAR(static_cast(counts[k]) / static_cast(kSamples), + probabilities[k], 0.05); + } +} + +TEST(PrioritizedTest, SetsPriorityExponentInOptions) { + PrioritizedDistribution prioritized_a(0.1); + PrioritizedDistribution prioritized_b(0.5); + EXPECT_THAT(prioritized_a.options(), + testing::EqualsProto("prioritized: { priority_exponent: 0.1 } ")); + EXPECT_THAT(prioritized_b.options(), + testing::EqualsProto("prioritized: { priority_exponent: 0.5 } ")); +} + +TEST(PrioritizedDeathTest, ClearThenSample) { + PrioritizedDistribution prioritized(kInitialPriorityExponent); + for (int i = 0; i < 100; i++) { + TF_EXPECT_OK(prioritized.Insert(i, i)); + } + prioritized.Sample(); + prioritized.Clear(); + EXPECT_DEATH(prioritized.Sample(), ""); +} + +} // namespace +} // namespace reverb +} // namespace deepmind diff --git a/reverb/cc/distributions/uniform.cc b/reverb/cc/distributions/uniform.cc new file mode 100644 index 0000000..0e6ed43 --- /dev/null +++ b/reverb/cc/distributions/uniform.cc @@ -0,0 +1,83 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
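+
+// Implementation of `UniformDistribution` (see uniform.h): keys are stored in
+// a flat vector and sampled uniformly at random; priorities are ignored.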
+ +#include "reverb/cc/distributions/uniform.h" + +#include "absl/strings/str_cat.h" +#include "reverb/cc/checkpointing/checkpoint.pb.h" +#include "reverb/cc/platform/logging.h" +#include "reverb/cc/schema.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" + +namespace deepmind { +namespace reverb { + +tensorflow::Status UniformDistribution::Delete(Key key) { + const auto it = key_to_index_.find(key); + if (it == key_to_index_.end()) + return tensorflow::errors::InvalidArgument( + absl::StrCat("Key ", key, " not found in distribution.")); + const size_t index = it->second; + key_to_index_.erase(it); + + const size_t last_index = keys_.size() - 1; + const Key last_key = keys_.back(); + if (index != last_index) { + keys_[index] = last_key; + key_to_index_[last_key] = index; + } + + keys_.pop_back(); + return tensorflow::Status::OK(); +} + +tensorflow::Status UniformDistribution::Insert(Key key, double priority) { + const size_t index = keys_.size(); + if (!key_to_index_.emplace(key, index).second) + return tensorflow::errors::InvalidArgument( + absl::StrCat("Key ", key, " already exists in distribution.")); + keys_.push_back(key); + return tensorflow::Status::OK(); +} + +tensorflow::Status UniformDistribution::Update(Key key, double priority) { + if (key_to_index_.find(key) == key_to_index_.end()) + return tensorflow::errors::InvalidArgument( + absl::StrCat("Key ", key, " not found in distribution.")); + return tensorflow::Status::OK(); +} + +KeyDistributionInterface::KeyWithProbability UniformDistribution::Sample() { + REVERB_CHECK(!keys_.empty()); + + // This code is not thread-safe, because bit_gen_ is not protected by a mutex + // and is not itself thread-safe. + const size_t index = absl::Uniform(bit_gen_, 0, keys_.size()); + return {keys_[index], 1.0 / static_cast(keys_.size())}; +} + +void UniformDistribution::Clear() { + keys_.clear(); + key_to_index_.clear(); +} + +KeyDistributionOptions UniformDistribution::options() const { + KeyDistributionOptions options; + options.set_uniform(true); + return options; +} + +} // namespace reverb +} // namespace deepmind diff --git a/reverb/cc/distributions/uniform.h b/reverb/cc/distributions/uniform.h new file mode 100644 index 0000000..623892b --- /dev/null +++ b/reverb/cc/distributions/uniform.h @@ -0,0 +1,60 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LEARNING_DEEPMIND_REPLAY_REVERB_DISTRIBUTIONS_UNIFORM_H_ +#define LEARNING_DEEPMIND_REPLAY_REVERB_DISTRIBUTIONS_UNIFORM_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/random/random.h" +#include "reverb/cc/checkpointing/checkpoint.pb.h" +#include "reverb/cc/distributions/interface.h" +#include "tensorflow/core/lib/core/status.h" + +namespace deepmind { +namespace reverb { + +// Uniform sampling. We ignore all priority values in the calls. All operations +// take O(1) time. See KeyDistributionInterface for documentation about the +// methods. 
+class UniformDistribution : public KeyDistributionInterface { + public: + tensorflow::Status Delete(Key key) override; + + tensorflow::Status Insert(Key key, double priority) override; + + tensorflow::Status Update(Key key, double priority) override; + + KeyWithProbability Sample() override; + + void Clear() override; + + KeyDistributionOptions options() const override; + + private: + // All keys. + std::vector keys_; + + // Maps a key to the index where this key can be found in `keys_. + absl::flat_hash_map key_to_index_; + + // Used for sampling, not thread-safe. + absl::BitGen bit_gen_; +}; + +} // namespace reverb +} // namespace deepmind + +#endif // LEARNING_DEEPMIND_REPLAY_REVERB_DISTRIBUTIONS_UNIFORM_H_ diff --git a/reverb/cc/distributions/uniform_test.cc b/reverb/cc/distributions/uniform_test.cc new file mode 100644 index 0000000..8000dbe --- /dev/null +++ b/reverb/cc/distributions/uniform_test.cc @@ -0,0 +1,87 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "reverb/cc/distributions/uniform.h" + +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "reverb/cc/schema.pb.h" +#include "reverb/cc/testing/proto_test_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace deepmind { +namespace reverb { +namespace { + +TEST(UniformTest, ReturnValueSantiyChecks) { + UniformDistribution uniform; + + // Non existent keys cannot be deleted or updated. + EXPECT_EQ(uniform.Delete(123).code(), tensorflow::error::INVALID_ARGUMENT); + EXPECT_EQ(uniform.Update(123, 4).code(), tensorflow::error::INVALID_ARGUMENT); + + // Keys cannot be inserted twice. + TF_EXPECT_OK(uniform.Insert(123, 4)); + EXPECT_EQ(uniform.Insert(123, 4).code(), tensorflow::error::INVALID_ARGUMENT); + + // Existing keys can be updated and sampled. + TF_EXPECT_OK(uniform.Update(123, 5)); + EXPECT_EQ(uniform.Sample().key, 123); + + // Existing keys cannot be deleted twice. + TF_EXPECT_OK(uniform.Delete(123)); + EXPECT_EQ(uniform.Delete(123).code(), tensorflow::error::INVALID_ARGUMENT); +} + +TEST(UniformTest, MatchesUniformDistribution) { + const int64_t kItems = 100; + const int64_t kSamples = 1000000; + double expected_probability = 1. 
/ static_cast(kItems); + + UniformDistribution uniform; + for (int i = 0; i < kItems; i++) { + TF_EXPECT_OK(uniform.Insert(i, 0)); + } + std::vector counts(kItems); + for (int i = 0; i < kSamples; i++) { + KeyDistributionInterface::KeyWithProbability sample = uniform.Sample(); + EXPECT_EQ(sample.probability, expected_probability); + counts[sample.key]++; + } + for (int64_t count : counts) { + EXPECT_NEAR(static_cast(count) / static_cast(kSamples), + expected_probability, 0.05); + } +} + +TEST(UniformTest, Options) { + UniformDistribution uniform; + EXPECT_THAT(uniform.options(), testing::EqualsProto("uniform: true")); +} + +TEST(UniformDeathTest, ClearThenSample) { + UniformDistribution uniform; + for (int i = 0; i < 100; i++) { + TF_EXPECT_OK(uniform.Insert(i, i)); + } + uniform.Sample(); + uniform.Clear(); + EXPECT_DEATH(uniform.Sample(), ""); +} + +} // namespace +} // namespace reverb +} // namespace deepmind diff --git a/reverb/cc/ops/BUILD b/reverb/cc/ops/BUILD new file mode 100644 index 0000000..8ed6418 --- /dev/null +++ b/reverb/cc/ops/BUILD @@ -0,0 +1,41 @@ +load( + "//reverb/cc/platform:build_rules.bzl", + "reverb_absl_deps", + "reverb_gen_op_wrapper_py", + "reverb_kernel_library", + "reverb_tf_ops_visibility", +) + +package(default_visibility = reverb_tf_ops_visibility()) + +licenses(["notice"]) + +reverb_kernel_library( + name = "client", + srcs = ["client.cc"], + deps = [ + "//reverb/cc:replay_client", + ] + reverb_absl_deps(), +) + +reverb_kernel_library( + name = "dataset", + srcs = ["dataset.cc"], + deps = [ + "//reverb/cc:replay_client", + "//reverb/cc:replay_sampler", + "//reverb/cc/platform:logging", + ], +) + +reverb_gen_op_wrapper_py( + name = "gen_client_ops", + out = "gen_client_ops.py", + kernel_lib = ":client", +) + +reverb_gen_op_wrapper_py( + name = "gen_dataset_op", + out = "gen_dataset_op.py", + kernel_lib = ":dataset", +) diff --git a/reverb/cc/ops/client.cc b/reverb/cc/ops/client.cc new file mode 100644 index 0000000..4105069 --- /dev/null +++ b/reverb/cc/ops/client.cc @@ -0,0 +1,284 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
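+
+// TensorFlow ops wrapping a shared `ReplayClient`: a resource handle op plus
+// blocking kernels for sampling, inserting and updating priorities.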
+ +#include +#include +#include + +#include +#include "absl/strings/str_cat.h" +#include "reverb/cc/replay_client.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/lib/core/status.h" + +namespace deepmind { +namespace reverb { +namespace { + +using ::tensorflow::tstring; +using ::tensorflow::errors::InvalidArgument; + +REGISTER_OP("ReverbClient") + .Output("handle: resource") + .Attr("server_address: string") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .SetIsStateful() + .SetShapeFn(tensorflow::shape_inference::ScalarShape) + .Doc(R"doc( +Constructs a `ClientResource` that constructs a `ReplayClient` connected to +`server_address`. The resource allows ops to share the stub across calls. +)doc"); + +REGISTER_OP("ReverbClientSample") + .Attr("Toutput_list: list(type) >= 0") + .Input("handle: resource") + .Input("table: string") + .Output("key: uint64") + .Output("probability: double") + .Output("table_size: int64") + .Output("outputs: Toutput_list") + .Doc(R"doc( +Blocking call to sample a single item from table `table` using shared resource. +A `SampleStream`-stream is opened between the client and the server and when +the one sample has been received, the stream is closed. + +Prefer to use `ReverbDataset` when requesting more than one sample to avoid +opening and closing the stream with each call. +)doc"); + +REGISTER_OP("ReverbClientUpdatePriorities") + .Input("handle: resource") + .Input("table: string") + .Input("keys: uint64") + .Input("priorities: double") + .Doc(R"doc( +Blocking call to update the priorities of a collection of items. Keys that could +not be found in table `table` on server are ignored and does not impact the rest +of the request. +)doc"); + +REGISTER_OP("ReverbClientInsert") + .Attr("T: list(type) >= 0") + .Input("handle: resource") + .Input("data: T") + .Input("tables: string") + .Input("priorities: double") + .Doc(R"doc( +Blocking call to insert a single trajectory into one or more tables. The data +is treated as an episode constituting of a single timestep. Note that this mean +that when the item is sampled, it will be returned as a sequence of length 1, +containing `data`. +)doc"); + +class ClientResource : public tensorflow::ResourceBase { + public: + explicit ClientResource(const std::string& server_address) + : tensorflow::ResourceBase(), + client_(server_address), + server_address_(server_address) {} + + std::string DebugString() const override { + return tensorflow::strings::StrCat("Client with server address: ", + server_address_); + } + + ReplayClient* client() { return &client_; } + + private: + ReplayClient client_; + std::string server_address_; + + TF_DISALLOW_COPY_AND_ASSIGN(ClientResource); +}; + +class ClientHandleOp : public tensorflow::ResourceOpKernel { + public: + explicit ClientHandleOp(tensorflow::OpKernelConstruction* context) + : tensorflow::ResourceOpKernel(context) { + OP_REQUIRES_OK(context, + context->GetAttr("server_address", &server_address_)); + } + + private: + tensorflow::Status CreateResource(ClientResource** ret) override { + *ret = new ClientResource(server_address_); + return tensorflow::Status::OK(); + } + + std::string server_address_; + + TF_DISALLOW_COPY_AND_ASSIGN(ClientHandleOp); +}; + +// TODO(b/154929314): Change this to an async op. 
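+// Kernel for `ReverbClientSample`: looks up the shared `ClientResource`,
+// opens a single-sample `ReplaySampler` on `table`, blocks until one timestep
+// has been received and copies the resulting tensors to the op outputs.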
+class SampleOp : public tensorflow::OpKernel {
+ public:
+  explicit SampleOp(tensorflow::OpKernelConstruction* context)
+      : OpKernel(context) {}
+
+  void Compute(tensorflow::OpKernelContext* context) override {
+    ClientResource* resource;
+    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
+                                           &resource));
+
+    const tensorflow::Tensor* table_tensor;
+    OP_REQUIRES_OK(context, context->input("table", &table_tensor));
+    std::string table = table_tensor->scalar<tstring>()();
+
+    std::vector<tensorflow::Tensor> sample;
+    std::unique_ptr<ReplaySampler> sampler;
+
+    ReplaySampler::Options options;
+    options.max_samples = 1;
+    options.max_in_flight_samples_per_worker = 1;
+
+    OP_REQUIRES_OK(context,
+                   resource->client()->NewSampler(table, options, &sampler));
+    OP_REQUIRES_OK(context, sampler->GetNextTimestep(&sample, nullptr));
+    OP_REQUIRES(context, sample.size() == context->num_outputs(),
+                InvalidArgument(
+                    "Number of tensors in the replay sample did not match the "
+                    "expected count."));
+
+    for (int i = 0; i < sample.size(); i++) {
+      tensorflow::Tensor* tensor;
+      OP_REQUIRES_OK(context,
+                     context->allocate_output(i, sample[i].shape(), &tensor));
+      *tensor = std::move(sample[i]);
+    }
+  }
+
+  TF_DISALLOW_COPY_AND_ASSIGN(SampleOp);
+};
+
+class UpdatePrioritiesOp : public tensorflow::OpKernel {
+ public:
+  explicit UpdatePrioritiesOp(tensorflow::OpKernelConstruction* context)
+      : OpKernel(context) {}
+
+  void Compute(tensorflow::OpKernelContext* context) override {
+    ClientResource* resource;
+    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
+                                           &resource));
+
+    const tensorflow::Tensor* table;
+    OP_REQUIRES_OK(context, context->input("table", &table));
+    const tensorflow::Tensor* keys;
+    OP_REQUIRES_OK(context, context->input("keys", &keys));
+    const tensorflow::Tensor* priorities;
+    OP_REQUIRES_OK(context, context->input("priorities", &priorities));
+
+    OP_REQUIRES(
+        context, keys->dims() == 1,
+        InvalidArgument("Tensors `keys` and `priorities` must be of rank 1."));
+    OP_REQUIRES(context, keys->shape() == priorities->shape(),
+                InvalidArgument(
+                    "Tensors `keys` and `priorities` do not match in shape."));
+
+    std::string table_str = table->scalar<tstring>()();
+    std::vector<KeyWithPriority> updates;
+    for (int i = 0; i < keys->dim_size(0); i++) {
+      KeyWithPriority update;
+      update.set_key(keys->flat<tensorflow::uint64>()(i));
+      update.set_priority(priorities->flat<double>()(i));
+      updates.push_back(std::move(update));
+    }
+
+    // The call will only fail if the Reverb server is brought down during an
+    // active call (e.g. preempted). When this happens the request is retried,
+    // and since MutatePriorities sets `wait_for_ready` the request will not be
+    // sent before the server is brought up again. It is therefore not a
+    // problem to have this retry in this tight loop.
+    tensorflow::Status status;
+    do {
+      status = resource->client()->MutatePriorities(table_str, updates, {});
+    } while (tensorflow::errors::IsUnavailable(status) ||
+             tensorflow::errors::IsDeadlineExceeded(status));
+    OP_REQUIRES_OK(context, status);
+  }
+
+  TF_DISALLOW_COPY_AND_ASSIGN(UpdatePrioritiesOp);
+};
+
+class InsertOp : public tensorflow::OpKernel {
+ public:
+  explicit InsertOp(tensorflow::OpKernelConstruction* context)
+      : OpKernel(context) {}
+
+  void Compute(tensorflow::OpKernelContext* context) override {
+    ClientResource* resource;
+    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
+                                           &resource));
+
+    const tensorflow::Tensor* tables;
+    OP_REQUIRES_OK(context, context->input("tables", &tables));
+    const tensorflow::Tensor* priorities;
+    OP_REQUIRES_OK(context, context->input("priorities", &priorities));
+
+    OP_REQUIRES(context, tables->dims() == 1 && priorities->dims() == 1,
+                InvalidArgument(
+                    "Tensors `tables` and `priorities` must be of rank 1."));
+    OP_REQUIRES(
+        context, tables->shape() == priorities->shape(),
+        InvalidArgument(
+            "Tensors `tables` and `priorities` do not match in shape."));
+
+    tensorflow::OpInputList data;
+    OP_REQUIRES_OK(context, context->input_list("data", &data));
+
+    // TODO(b/154929210): This can probably be avoided.
+    std::vector<tensorflow::Tensor> tensors;
+    for (const auto& i : data) {
+      tensors.push_back(i);
+    }
+
+    std::unique_ptr<ReplayWriter> writer;
+    OP_REQUIRES_OK(context,
+                   resource->client()->NewWriter(1, 1, false, &writer));
+    OP_REQUIRES_OK(context, writer->AppendTimestep(std::move(tensors)));
+
+    auto tables_t = tables->flat<tstring>();
+    auto priorities_t = priorities->flat<double>();
+    for (int i = 0; i < tables->dim_size(0); i++) {
+      OP_REQUIRES_OK(context,
+                     writer->AddPriority(tables_t(i), 1, priorities_t(i)));
+    }
+
+    OP_REQUIRES_OK(context, writer->Close());
+  }
+
+  TF_DISALLOW_COPY_AND_ASSIGN(InsertOp);
+};
+
+REGISTER_KERNEL_BUILDER(Name("ReverbClient").Device(tensorflow::DEVICE_CPU),
+                        ClientHandleOp);
+
+REGISTER_KERNEL_BUILDER(
+    Name("ReverbClientInsert").Device(tensorflow::DEVICE_CPU), InsertOp);
+
+REGISTER_KERNEL_BUILDER(
+    Name("ReverbClientSample").Device(tensorflow::DEVICE_CPU), SampleOp);
+
+REGISTER_KERNEL_BUILDER(
+    Name("ReverbClientUpdatePriorities").Device(tensorflow::DEVICE_CPU),
+    UpdatePrioritiesOp);
+
+}  // namespace
+}  // namespace reverb
+}  // namespace deepmind
diff --git a/reverb/cc/ops/dataset.cc b/reverb/cc/ops/dataset.cc
new file mode 100644
index 0000000..a8f9d38
--- /dev/null
+++ b/reverb/cc/ops/dataset.cc
@@ -0,0 +1,371 @@
+// Copyright 2019 DeepMind Technologies Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
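+//
+// Illustrative usage only (not part of this file): the `ReverbDataset` op
+// registered below is normally wrapped by a Python `tf.data.Dataset`. A
+// minimal sketch, assuming a generated wrapper module named `gen_reverb_ops`
+// and a table that stores an [84, 84] float observation per timestep:
+//
+//   variant = gen_reverb_ops.reverb_dataset(
+//       server_address='localhost:8000', table='my_table',
+//       dtypes=[tf.uint64, tf.double, tf.int64, tf.float32],
+//       shapes=[tf.TensorShape([])] * 3 + [tf.TensorShape([84, 84])])
+//
+// The `dtypes`/`shapes` describe (key, probability, table_size, ...data) for a
+// single timestep, as documented on the op below.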
+ +#include "tensorflow/core/framework/dataset.h" + +#include "reverb/cc/platform/logging.h" +#include "reverb/cc/replay_client.h" +#include "reverb/cc/replay_sampler.h" +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" + +namespace deepmind { +namespace reverb { +namespace { + +using ::tensorflow::errors::Cancelled; +using ::tensorflow::errors::FailedPrecondition; +using ::tensorflow::errors::InvalidArgument; +using ::tensorflow::errors::Unimplemented; + +REGISTER_OP("ReverbDataset") + .Attr("server_address: string") + .Attr("table: string") + .Attr("sequence_length: int = -1") + .Attr("emit_timesteps: bool = true") + .Attr("max_in_flight_samples_per_worker: int = 100") + .Attr("num_workers_per_iterator: int = -1") + .Attr("max_samples_per_stream: int = -1") + .Attr("dtypes: list(type) >= 1") + .Attr("shapes: list(shape) >= 1") + .Output("dataset: variant") + .SetIsStateful() + .SetShapeFn(tensorflow::shape_inference::ScalarShape) + .Doc(R"doc( +Establishes and manages a connection to gRPC ReplayService at `server_address` +to stream samples from table `table`. + +The connection is managed using a single instance of `ReplayClient` (see +../replay_client.h) owned by the Dataset. From the shared `ReplayClient`, each +iterator maintains their own `ReplaySampler` (see ../replay_sampler.h), allowing +for multiple parallel streams using a single connection. + +`dtypes` and `shapes` must match the type and shape of a single "timestep" +within sampled sequences. That is, (key, priority, table_size, ...data passed to +`ReplayWriter::AppendTimestep` at insertion time). This is the type and shape of +tensors returned by `GetNextTimestep`. + +sequence_length: (Defaults to -1, i.e unknown) The number of timesteps in +the samples. If set then the length of the received samples are checked against +this value. + +`emit_timesteps` (defaults to true) determines whether individual timesteps or +complete sequences should be returned from the iterators. When set to false +(i.e return sequences), `shapes` must have dim[0] equal to `sequence_length`. +Emitting complete samples is more efficient as it avoids the memcopies involved +in splitting up a sequence and then batching it up again. + +`max_in_flight_samples_per_worker` (defaults to 100) is the maximum number of + sampled item allowed to exist in flight (per iterator). See +`ReplaySampler::Options::max_in_flight_samples_per_worker` for more details. + +`num_workers_per_iterator` (defaults to -1, i.e auto selected) is the number of +worker threads to start per iterator. When the selected table uses a FIFO +sampler (i.e a queue) then exactly 1 worker must be used to avoid races causing +invalid ordering of items. For all other samplers, this value should be roughly +equal to the number of threads available on the CPU. + +`max_samples_per_stream` (defaults to -1, i.e auto selected) is the maximum +number of samples to fetch from a stream before a new call is made. Keeping this +number low ensures that the data is fetched uniformly from all servers. 
+)doc"); + +class ReverbDatasetOp : public tensorflow::data::DatasetOpKernel { + public: + explicit ReverbDatasetOp(tensorflow::OpKernelConstruction* ctx) + : tensorflow::data::DatasetOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("server_address", &server_address_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("table", &table_)); + OP_REQUIRES_OK( + ctx, ctx->GetAttr("max_in_flight_samples_per_worker", + &sampler_options_.max_in_flight_samples_per_worker)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("num_workers_per_iterator", + &sampler_options_.num_workers)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("max_samples_per_stream", + &sampler_options_.max_samples_per_stream)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("sequence_length", &sequence_length_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("emit_timesteps", &emit_timesteps_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("shapes", &shapes_)); + + if (!emit_timesteps_) { + for (int i = 0; i < shapes_.size(); i++) { + OP_REQUIRES(ctx, shapes_[i].dims() != 0, + InvalidArgument( + "When emit_timesteps is false, all elements of shapes " + "must have " + "dim[0] = sequence_length (", + sequence_length_, "). Element ", i, + " of flattened shapes has rank 0 and thus no dim[0].")); + + OP_REQUIRES(ctx, shapes_[i].dim_size(0) == sequence_length_, + InvalidArgument("When emit_timesteps is false, all " + "elements of shapes must have " + "dim[0] = sequence_length (", + sequence_length_, "). Element ", i, + " of flattened shapes has dim[0] = ", + shapes_[i].dim_size(0), ".")); + } + } + } + + void MakeDataset(tensorflow::OpKernelContext* ctx, + tensorflow::data::DatasetBase** output) override { + *output = new Dataset(ctx, server_address_, dtypes_, shapes_, table_, + sampler_options_, sequence_length_, emit_timesteps_); + } + + private: + class Dataset : public tensorflow::data::DatasetBase { + public: + Dataset(tensorflow::OpKernelContext* ctx, std::string server_address, + tensorflow::DataTypeVector dtypes, + std::vector shapes, + std::string table, const ReplaySampler::Options& sampler_options, + int sequence_length, bool emit_timesteps) + : tensorflow::data::DatasetBase(tensorflow::data::DatasetContext(ctx)), + server_address_(std::move(server_address)), + dtypes_(std::move(dtypes)), + shapes_(std::move(shapes)), + table_(std::move(table)), + sampler_options_(sampler_options), + sequence_length_(sequence_length), + emit_timesteps_(emit_timesteps), + client_(absl::make_unique(server_address_)) {} + + std::unique_ptr MakeIteratorInternal( + const std::string& prefix) const override { + return absl::make_unique( + tensorflow::data::DatasetIterator::Params{ + this, absl::StrCat(prefix, "::ReverbDataset")}, + client_.get(), table_, sampler_options_, sequence_length_, + emit_timesteps_, dtypes_, shapes_); + } + + const tensorflow::DataTypeVector& output_dtypes() const override { + return dtypes_; + } + + const std::vector& output_shapes() + const override { + return shapes_; + } + + std::string DebugString() const override { + return "ReverbDatasetOp::Dataset"; + } + + tensorflow::Status CheckExternalState() const override { + return FailedPrecondition(DebugString(), " depends on external state."); + } + + protected: + tensorflow::Status AsGraphDefInternal( + tensorflow::data::SerializationContext* ctx, DatasetGraphDefBuilder* b, + tensorflow::Node** output) const override { + tensorflow::AttrValue server_address_attr; + tensorflow::AttrValue table_attr; + tensorflow::AttrValue max_in_flight_samples_per_worker_attr; + 
tensorflow::AttrValue num_workers_attr; + tensorflow::AttrValue max_samples_per_stream_attr; + tensorflow::AttrValue sequence_length_attr; + tensorflow::AttrValue emit_timesteps_attr; + tensorflow::AttrValue dtypes_attr; + tensorflow::AttrValue shapes_attr; + + b->BuildAttrValue(server_address_, &server_address_attr); + b->BuildAttrValue(table_, &table_attr); + b->BuildAttrValue(sampler_options_.max_in_flight_samples_per_worker, + &max_in_flight_samples_per_worker_attr); + b->BuildAttrValue(sampler_options_.num_workers, &num_workers_attr); + b->BuildAttrValue(sampler_options_.max_samples_per_stream, + &max_samples_per_stream_attr); + b->BuildAttrValue(sequence_length_, &sequence_length_attr); + b->BuildAttrValue(emit_timesteps_, &emit_timesteps_attr); + b->BuildAttrValue(dtypes_, &dtypes_attr); + b->BuildAttrValue(shapes_, &shapes_attr); + + TF_RETURN_IF_ERROR(b->AddDataset( + this, {}, + { + {"server_address", server_address_attr}, + {"table", table_attr}, + {"max_in_flight_samples_per_worker", + max_in_flight_samples_per_worker_attr}, + {"num_workers_per_iterator", num_workers_attr}, + {"max_samples_per_stream", max_samples_per_stream_attr}, + {"sequence_length", sequence_length_attr}, + {"emit_timesteps", emit_timesteps_attr}, + {"dtypes", dtypes_attr}, + {"shapes", shapes_attr}, + }, + output)); + + return tensorflow::Status::OK(); + } + + private: + class Iterator : public tensorflow::data::DatasetIterator { + public: + explicit Iterator( + const Params& params, ReplayClient* client, const std::string& table, + const ReplaySampler::Options& sampler_options, int sequence_length, + bool emit_timesteps, const tensorflow::DataTypeVector& dtypes, + const std::vector& shapes) + : DatasetIterator(params), + client_(client), + table_(table), + sampler_options_(sampler_options), + sequence_length_(sequence_length), + emit_timesteps_(emit_timesteps), + dtypes_(dtypes), + shapes_(shapes), + step_within_sample_(0) {} + + tensorflow::Status Initialize( + tensorflow::data::IteratorContext* ctx) override { + // If sequences are emitted then the all shapes will start with the + // sequence length. The validation expects the shapes of a single + // timestep so if sequences are emitted then we need to trim the leading + // dim on all shapes before validating it. + auto validation_shapes = shapes_; + if (!emit_timesteps_) { + for (auto& shape : validation_shapes) { + shape.RemoveDim(0); + } + } + + // TODO(b/154929217): Expose this option so it can be set to infinite + // outside of tests. + const auto kTimeout = absl::Seconds(30); + auto status = + client_->NewSampler(table_, sampler_options_, + /*validation_dtypes=*/dtypes_, + validation_shapes, kTimeout, &sampler_); + if (tensorflow::errors::IsDeadlineExceeded(status)) { + REVERB_LOG(REVERB_WARNING) + << "Unable to validate shapes and dtypes of new sampler for '" + << table_ << "' as server could not be reached in time (" + << kTimeout + << "). We were thus unable to fetch signature from server. 
The " + "sampler will be constructed without validating the dtypes " + "and shapes."; + return client_->NewSampler(table_, sampler_options_, &sampler_); + } + return status; + } + + tensorflow::Status GetNextInternal( + tensorflow::data::IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + REVERB_CHECK(sampler_.get() != nullptr) + << "Initialize was not called?"; + + auto token = ctx->cancellation_manager()->get_cancellation_token(); + bool registered = ctx->cancellation_manager()->RegisterCallback( + token, [&] { sampler_->Close(); }); + if (!registered) { + sampler_->Close(); + } + + tensorflow::Status status; + if (emit_timesteps_) { + bool last_timestep = false; + status = sampler_->GetNextTimestep(out_tensors, &last_timestep); + + step_within_sample_++; + + if (last_timestep && sequence_length_ > 0 && + step_within_sample_ != sequence_length_) { + return InvalidArgument( + "Received sequence of invalid length. Expected ", + sequence_length_, " steps, got ", step_within_sample_); + } + if (step_within_sample_ == sequence_length_ && !last_timestep) { + return InvalidArgument( + "Receieved sequence did not terminate after expected number of " + "steps (", + sequence_length_, ")."); + } + if (last_timestep) { + step_within_sample_ = 0; + } + } else { + status = sampler_->GetNextSample(out_tensors); + } + + if (registered && + !ctx->cancellation_manager()->DeregisterCallback(token)) { + return Cancelled("Iterator context was cancelled"); + } + + return status; + } + + protected: + tensorflow::Status SaveInternal( + tensorflow::data::SerializationContext* ctx, + tensorflow::data::IteratorStateWriter* writer) override { + return Unimplemented("SaveInternal is currently not supported"); + } + + tensorflow::Status RestoreInternal( + tensorflow::data::IteratorContext* ctx, + tensorflow::data::IteratorStateReader* reader) override { + return Unimplemented("RestoreInternal is currently not supported"); + } + + private: + ReplayClient* client_; + const std::string& table_; + const ReplaySampler::Options sampler_options_; + const int sequence_length_; + const bool emit_timesteps_; + const tensorflow::DataTypeVector& dtypes_; + const std::vector& shapes_; + std::unique_ptr sampler_; + int step_within_sample_; + }; // Iterator. + + const std::string server_address_; + const tensorflow::DataTypeVector dtypes_; + const std::vector shapes_; + const std::string table_; + const ReplaySampler::Options sampler_options_; + const int sequence_length_; + const bool emit_timesteps_; + std::unique_ptr client_; + }; // Dataset. 
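+
+  // Attr values parsed once in the kernel constructor above and forwarded to
+  // each `Dataset` created by `MakeDataset`.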
+ + std::string server_address_; + std::string table_; + ReplaySampler::Options sampler_options_; + int sequence_length_; + bool emit_timesteps_; + tensorflow::DataTypeVector dtypes_; + std::vector shapes_; + + TF_DISALLOW_COPY_AND_ASSIGN(ReverbDatasetOp); +}; + +REGISTER_KERNEL_BUILDER(Name("ReverbDataset").Device(tensorflow::DEVICE_CPU), + ReverbDatasetOp); + +} // namespace +} // namespace reverb +} // namespace deepmind diff --git a/reverb/cc/platform/BUILD b/reverb/cc/platform/BUILD new file mode 100644 index 0000000..357dbee --- /dev/null +++ b/reverb/cc/platform/BUILD @@ -0,0 +1,151 @@ +# Platform-specific code for reverb + +load( + "//reverb/cc/platform:build_rules.bzl", + "reverb_absl_deps", + "reverb_cc_library", + "reverb_cc_test", + "reverb_grpc_deps", + "reverb_tf_deps", +) + +package(default_visibility = ["//reverb:__subpackages__"]) + +licenses(["notice"]) + +exports_files(["LICENSE"]) + +reverb_cc_library( + name = "tfrecord_checkpointer", + srcs = ["tfrecord_checkpointer.cc"], + hdrs = ["tfrecord_checkpointer.h"], + visibility = ["//reverb:__subpackages__"], + deps = [ + "//reverb/cc:chunk_store", + "//reverb/cc:priority_table", + "//reverb/cc:schema_cc_proto", + "//reverb/cc/checkpointing:checkpoint_cc_proto", + "//reverb/cc/checkpointing:interface", + "//reverb/cc/distributions:fifo", + "//reverb/cc/distributions:heap", + "//reverb/cc/distributions:interface", + "//reverb/cc/distributions:lifo", + "//reverb/cc/distributions:prioritized", + "//reverb/cc/distributions:uniform", + "//reverb/cc/table_extensions:interface", + ] + reverb_tf_deps() + reverb_absl_deps(), +) + +reverb_cc_test( + name = "tfrecord_checkpointer_test", + srcs = ["tfrecord_checkpointer_test.cc"], + deps = [ + ":tfrecord_checkpointer", + "//reverb/cc:chunk_store", + "//reverb/cc:priority_table", + "//reverb/cc/distributions:fifo", + "//reverb/cc/distributions:heap", + "//reverb/cc/distributions:prioritized", + "//reverb/cc/distributions:uniform", + "//reverb/cc/testing:proto_test_util", + ] + reverb_tf_deps(), +) + +reverb_cc_library( + name = "checkpointing_hdr", + hdrs = ["checkpointing.h"], + deps = [ + "//reverb/cc/checkpointing:interface", + ] + reverb_absl_deps(), +) + +reverb_cc_library( + name = "checkpointing", + hdrs = ["checkpointing.h"], + visibility = ["//reverb:__subpackages__"], + deps = [ + "//reverb/cc/checkpointing:interface", + "//reverb/cc/platform/default:checkpointer", + ] + reverb_absl_deps(), +) + +reverb_cc_library( + name = "grpc_utils_hdr", + hdrs = ["grpc_utils.h"], + deps = reverb_grpc_deps() + reverb_absl_deps(), +) + +reverb_cc_library( + name = "grpc_utils", + hdrs = ["grpc_utils.h"], + visibility = ["//reverb:__subpackages__"], + deps = [ + "//reverb/cc/platform/default:grpc_utils", + ] + reverb_grpc_deps() + reverb_absl_deps(), +) + +reverb_cc_library( + name = "snappy_hdr", + hdrs = ["snappy.h"], + deps = reverb_absl_deps(), +) + +reverb_cc_library( + name = "snappy", + hdrs = ["snappy.h"], + visibility = ["//reverb:__subpackages__"], + deps = [ + "//reverb/cc/platform/default:snappy", + ] + reverb_absl_deps(), +) + +reverb_cc_library( + name = "thread_hdr", + hdrs = ["thread.h"], + deps = reverb_absl_deps(), +) + +reverb_cc_library( + name = "thread", + hdrs = ["thread.h"], + visibility = ["//reverb:__subpackages__"], + deps = [ + "//reverb/cc/platform/default:thread", + ] + reverb_absl_deps(), +) + +reverb_cc_test( + name = "thread_test", + srcs = ["thread_test.cc"], + deps = [ + ":thread", + ] + reverb_absl_deps(), +) + +reverb_cc_library( + name = "logging", + 
hdrs = ["logging.h"], + visibility = ["//reverb:__subpackages__"], + deps = ["//reverb/cc/platform/default:logging"], +) + +reverb_cc_library( + name = "net_hdr", + hdrs = ["net.h"], +) + +reverb_cc_library( + name = "net", + hdrs = ["net.h"], + visibility = ["//reverb:__subpackages__"], + deps = ["//reverb/cc/platform/default:net"], +) + +reverb_cc_test( + name = "net_test", + srcs = ["net_test.cc"], + deps = [ + ":logging", + ":net", + ], +) diff --git a/reverb/cc/platform/build_rules.bzl b/reverb/cc/platform/build_rules.bzl new file mode 100644 index 0000000..cffa946 --- /dev/null +++ b/reverb/cc/platform/build_rules.bzl @@ -0,0 +1,40 @@ +"""Main Starlark code for platform-specific build rules.""" + +load( + "//reverb/cc/platform/default:build_rules.bzl", + _reverb_absl_deps = "reverb_absl_deps", + _reverb_cc_grpc_library = "reverb_cc_grpc_library", + _reverb_cc_library = "reverb_cc_library", + _reverb_cc_proto_library = "reverb_cc_proto_library", + _reverb_cc_test = "reverb_cc_test", + _reverb_gen_op_wrapper_py = "reverb_gen_op_wrapper_py", + _reverb_grpc_deps = "reverb_grpc_deps", + _reverb_kernel_library = "reverb_kernel_library", + _reverb_py_proto_library = "reverb_py_proto_library", + _reverb_py_standard_imports = "reverb_py_standard_imports", + _reverb_py_test = "reverb_py_test", + _reverb_pybind_deps = "reverb_pybind_deps", + _reverb_pybind_extension = "reverb_pybind_extension", + _reverb_pytype_library = "reverb_pytype_library", + _reverb_pytype_strict_library = "reverb_pytype_strict_library", + _reverb_tf_deps = "reverb_tf_deps", + _reverb_tf_ops_visibility = "reverb_tf_ops_visibility", +) + +reverb_absl_deps = _reverb_absl_deps +reverb_cc_library = _reverb_cc_library +reverb_cc_test = _reverb_cc_test +reverb_cc_grpc_library = _reverb_cc_grpc_library +reverb_cc_proto_library = _reverb_cc_proto_library +reverb_gen_op_wrapper_py = _reverb_gen_op_wrapper_py +reverb_grpc_deps = _reverb_grpc_deps +reverb_kernel_library = _reverb_kernel_library +reverb_py_proto_library = _reverb_py_proto_library +reverb_py_standard_imports = _reverb_py_standard_imports +reverb_py_test = _reverb_py_test +reverb_pybind_deps = _reverb_pybind_deps +reverb_pybind_extension = _reverb_pybind_extension +reverb_pytype_library = _reverb_pytype_library +reverb_pytype_strict_library = _reverb_pytype_strict_library +reverb_tf_ops_visibility = _reverb_tf_ops_visibility +reverb_tf_deps = _reverb_tf_deps diff --git a/reverb/cc/platform/checkpointing.h b/reverb/cc/platform/checkpointing.h new file mode 100644 index 0000000..a90e2fd --- /dev/null +++ b/reverb/cc/platform/checkpointing.h @@ -0,0 +1,33 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
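+//
+// Declares the platform-specific factory for the default checkpointer. A
+// minimal usage sketch (the path is an assumption made for illustration):
+//
+//   auto checkpointer = CreateDefaultCheckpointer("/tmp/reverb_checkpoint");
+//
+// On the default (open-source) platform this returns a `TFRecordCheckpointer`;
+// see reverb/cc/platform/default/default_checkpointer.cc.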
+ +#ifndef REVERB_CC_PLATFORM_CHECKPOINTER_H_ +#define REVERB_CC_PLATFORM_CHECKPOINTER_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "reverb/cc/checkpointing/interface.h" + +namespace deepmind { +namespace reverb { + +std::unique_ptr CreateDefaultCheckpointer( + std::string root_dir, std::string group = ""); + +} // namespace reverb +} // namespace deepmind + +#endif // REVERB_CC_PLATFORM_CHECKPOINTER_H_ diff --git a/reverb/cc/platform/default/BUILD b/reverb/cc/platform/default/BUILD new file mode 100644 index 0000000..0b7e276 --- /dev/null +++ b/reverb/cc/platform/default/BUILD @@ -0,0 +1,72 @@ +# Platform-specific code for reverb + +load( + "//reverb/cc/platform/default:build_rules.bzl", + "reverb_cc_library", + "reverb_grpc_deps", + "reverb_tf_deps", +) + +package(default_visibility = ["//reverb/cc/platform:__pkg__"]) + +licenses(["notice"]) + +exports_files(["LICENSE"]) + +reverb_cc_library( + name = "snappy", + srcs = ["snappy.cc"], + deps = [ + "//reverb/cc/platform:snappy_hdr", + "@com_google_absl//absl/strings", + ] + reverb_tf_deps(), + alwayslink = 1, +) + +reverb_cc_library( + name = "checkpointer", + srcs = ["default_checkpointer.cc"], + deps = [ + "//reverb/cc/checkpointing:interface", + "//reverb/cc/platform:checkpointing_hdr", + "//reverb/cc/platform:tfrecord_checkpointer", + "@com_google_absl//absl/strings", + ], + alwayslink = 1, +) + +reverb_cc_library( + name = "thread", + srcs = ["thread.cc"], + deps = [ + "//reverb/cc/platform:thread_hdr", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ], + alwayslink = 1, +) + +reverb_cc_library( + name = "logging", + hdrs = ["logging.h"], + deps = reverb_tf_deps(), +) + +reverb_cc_library( + name = "grpc_utils", + srcs = ["grpc_utils.cc"], + deps = [ + "//reverb/cc/platform:grpc_utils_hdr", + ] + reverb_grpc_deps(), + alwayslink = 1, +) + +reverb_cc_library( + name = "net", + srcs = ["net.cc"], + deps = [ + "//reverb/cc/platform:logging", + "//reverb/cc/platform:net_hdr", + ], + alwayslink = 1, +) diff --git a/reverb/cc/platform/default/build_rules.bzl b/reverb/cc/platform/default/build_rules.bzl new file mode 100644 index 0000000..c996a96 --- /dev/null +++ b/reverb/cc/platform/default/build_rules.bzl @@ -0,0 +1,552 @@ +"""Default versions of reverb build rule helpers.""" + +def tf_copts(): + return ["-Wno-sign-compare"] + +def reverb_cc_library( + name, + srcs = [], + hdrs = [], + deps = [], + testonly = 0, + **kwargs): + if testonly: + new_deps = [ + "@com_google_googletest//:gtest", + "@tensorflow_includes//:includes", + "@tensorflow_solib//:framework_lib", + ] + else: + new_deps = [] + native.cc_library( + name = name, + srcs = srcs, + hdrs = hdrs, + copts = tf_copts(), + testonly = testonly, + deps = depset(deps + new_deps), + **kwargs + ) + +def reverb_kernel_library(name, srcs = [], deps = [], **kwargs): + deps = deps + reverb_tf_deps() + reverb_cc_library( + name = name, + srcs = srcs, + deps = deps, + alwayslink = 1, + **kwargs + ) + +def _normalize_proto(x): + if x.endswith("_proto"): + x = x.rstrip("_proto") + if x.endswith("_cc"): + x = x.rstrip("_cc") + if x.endswith("_pb2"): + x = x.rstrip("_pb2") + return x + +def _strip_proto_suffix(x): + # Workaround for bug that str.rstrip(".END") takes off more than just ".END" + if x.endswith(".proto"): + x = x[:-6] + return x + +def reverb_cc_proto_library(name, srcs = [], deps = [], **kwargs): + """Build a proto cc_library. 
+ + This rule does three things: + + 1) Create a filegroup with name `name` that contains `srcs` + and any sources from deps named "x_proto" or "x_cc_proto". + + 2) Uses protoc to compile srcs to .h/.cc files, allowing any + tensorflow imports. + + 3) Creates a cc_library with name `name` building the resulting .h/.cc + files. + + Args: + name: The name, should end with "_cc_proto". + srcs: The .proto files. + deps: Any reverb_cc_proto_library targets. + **kwargs: Any additional args for the cc_library rule. + """ + gen_srcs = [_strip_proto_suffix(x) + ".pb.cc" for x in srcs] + gen_hdrs = [_strip_proto_suffix(x) + ".pb.h" for x in srcs] + src_paths = ["$(location {})".format(x) for x in srcs] + dep_srcs = [] + for x in deps: + if x.endswith("_proto"): + dep_srcs.append(_normalize_proto(x)) + native.filegroup( + name = _normalize_proto(name), + srcs = srcs + dep_srcs, + **kwargs + ) + native.genrule( + name = name + "_gen", + srcs = srcs, + outs = gen_srcs + gen_hdrs, + tools = dep_srcs + [ + "@protobuf_protoc//:protoc_bin", + "@tensorflow_includes//:protos", + ], + cmd = """ + OUTDIR=$$(echo $(RULEDIR) | sed -e 's#reverb/.*##') + $(location @protobuf_protoc//:protoc_bin) \ + --proto_path=external/tensorflow_includes/tensorflow_includes/ \ + --proto_path=. \ + --cpp_out=$$OUTDIR {}""".format( + " ".join(src_paths), + ), + ) + + native.cc_library( + name = "{}_static".format(name), + srcs = gen_srcs, + hdrs = gen_hdrs, + deps = depset(deps + reverb_tf_deps()), + alwayslink = 1, + **kwargs + ) + native.cc_binary( + name = "lib{}.so".format(name), + deps = ["{}_static".format(name)], + linkshared = 1, + **kwargs + ) + native.cc_library( + name = name, + hdrs = gen_hdrs, + srcs = ["lib{}.so".format(name)], + deps = depset(deps + reverb_tf_deps()), + alwayslink = 1, + **kwargs + ) + +def reverb_py_proto_library(name, srcs = [], deps = [], **kwargs): + """Build a proto py_library. + + This rule does three things: + + 1) Create a filegroup with name `name` that contains `srcs` + and any sources from deps named "x_proto" or "x_py_proto". + + 2) Uses protoc to compile srcs to _pb2.py files, allowing any + tensorflow imports. + + 3) Creates a py_library with name `name` building the resulting .py + files. + + Args: + name: The name, should end with "_cc_proto". + srcs: The .proto files. + deps: Any reverb_cc_proto_library targets. + **kwargs: Any additional args for the cc_library rule. + """ + gen_srcs = [_strip_proto_suffix(x) + "_pb2.py" for x in srcs] + src_paths = ["$(location {})".format(x) for x in srcs] + proto_deps = [] + py_deps = [] + for x in deps: + if x.endswith("_proto"): + proto_deps.append(_normalize_proto(x)) + else: + py_deps.append(x) + native.filegroup( + name = _normalize_proto(name), + srcs = srcs + proto_deps, + **kwargs + ) + native.genrule( + name = name + "_gen", + srcs = srcs, + outs = gen_srcs, + tools = proto_deps + [ + "@protobuf_protoc//:protoc_bin", + "@tensorflow_includes//:protos", + ], + cmd = """ + OUTDIR=$$(echo $(RULEDIR) | sed -e 's#reverb/.*##') + $(location @protobuf_protoc//:protoc_bin) \ + --proto_path=external/tensorflow_includes/tensorflow_includes/ \ + --proto_path=. \ + --python_out=$$OUTDIR {}""".format( + " ".join(src_paths), + ), + ) + native.py_library( + name = name, + srcs = gen_srcs, + deps = py_deps, + data = proto_deps, + **kwargs + ) + +def reverb_cc_grpc_library( + name, + srcs = [], + deps = [], + generate_mocks = False, + **kwargs): + """Build a grpc cc_library. 
+ + This rule does two things: + + 1) Uses protoc + grpc plugin to compile srcs to .h/.cc files, allowing any + tensorflow imports. Also creates mock headers if requested. + + 2) Creates a cc_library with name `name` building the resulting .h/.cc + files. + + Args: + name: The name, should end with "_cc_grpc_proto". + srcs: The .proto files. + deps: reverb_cc_proto_library targets. Must include src + "_cc_proto", + the cc_proto library, for each src in srcs. + generate_mocks: If true, creates mock headers for each source. + **kwargs: Any additional args for the cc_library rule. + """ + gen_srcs = [x.rstrip(".proto") + ".grpc.pb.cc" for x in srcs] + gen_hdrs = [x.rstrip(".proto") + ".grpc.pb.h" for x in srcs] + proto_src_deps = [] + for x in deps: + if x.endswith("_proto"): + proto_src_deps.append(_normalize_proto(x)) + src_paths = ["$(location {})".format(x) for x in srcs] + + if generate_mocks: + gen_mocks = [x.rstrip(".proto") + "_mock.grpc.pb.h" for x in srcs] + else: + gen_mocks = [] + + native.genrule( + name = name + "_gen", + srcs = srcs, + outs = gen_srcs + gen_hdrs + gen_mocks, + tools = proto_src_deps + [ + "@protobuf_protoc//:protoc_bin", + "@tensorflow_includes//:protos", + "@com_github_grpc_grpc//src/compiler:grpc_cpp_plugin", + ], + cmd = """ + OUTDIR=$$(echo $(RULEDIR) | sed -e 's#reverb/.*##') + $(location @protobuf_protoc//:protoc_bin) \ + --plugin=protoc-gen-grpc=$(location @com_github_grpc_grpc//src/compiler:grpc_cpp_plugin) \ + --proto_path=external/tensorflow_includes/tensorflow_includes/ \ + --proto_path=. \ + --grpc_out={} {}""".format( + "generate_mock_code=true:$$OUTDIR" if generate_mocks else "$$OUTDIR", + " ".join(src_paths), + ), + ) + + native.cc_library( + name = name, + srcs = gen_srcs, + hdrs = gen_hdrs + gen_mocks, + deps = depset(deps + ["@com_github_grpc_grpc//:grpc++_codegen_proto"]), + **kwargs + ) + +def reverb_cc_test(name, srcs, deps = [], **kwargs): + """Reverb-specific version of cc_test. + + Args: + name: Target name. + srcs: Target sources. + deps: Target deps. + **kwargs: Additional args to cc_test. + """ + new_deps = [ + "@com_google_googletest//:gtest", + "@tensorflow_includes//:includes", + "@tensorflow_solib//:framework_lib", + "@com_google_googletest//:gtest_main", + ] + size = kwargs.pop("size", "small") + native.cc_test( + name = name, + size = size, + copts = tf_copts(), + srcs = srcs, + deps = depset(deps + new_deps), + **kwargs + ) + +def reverb_gen_op_wrapper_py(name, out, kernel_lib, linkopts = [], **kwargs): + """Generates the py_library `name` with a data dep on the ops in kernel_lib. + + The resulting py_library creates file `$out`, and has a dependency on a + symbolic library called lib{$name}_gen_op.so, which contains the kernels + and ops and can be loaded via `tf.load_op_library`. + + Args: + name: The name of the py_library. + out: The name of the python file. Use "gen_{name}_ops.py". + kernel_lib: A cc_kernel_library target to generate for. + **kwargs: Any args to the `cc_binary` and `py_library` internal rules. 
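+
+    Example (illustrative; the target names are made up):
+
+        reverb_gen_op_wrapper_py(
+            name = "reverb",
+            out = "gen_reverb_ops.py",
+            kernel_lib = ":reverb_ops_kernels",
+        )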
+ """ + if not out.endswith(".py"): + fail("Argument out must end with '.py', but saw: {}".format(out)) + + module_name = "lib{}_gen_op".format(name) + version_script_file = "%s-version-script.lds" % module_name + native.genrule( + name = module_name + "_version_script", + outs = [version_script_file], + cmd = "echo '{global:\n *tensorflow*;\n *deepmind*;\n local: *;};' >$@", + output_licenses = ["unencumbered"], + visibility = ["//visibility:private"], + ) + native.cc_binary( + name = "{}.so".format(module_name), + deps = [kernel_lib] + reverb_tf_deps() + [version_script_file], + copts = tf_copts() + [ + "-fno-strict-aliasing", # allow a wider range of code [aliasing] to compile. + "-fvisibility=hidden", # avoid symbol clashes between DSOs. + ], + linkshared = 1, + linkopts = linkopts + _rpath_linkopts(module_name) + [ + "-Wl,--version-script", + "$(location %s)" % version_script_file, + ], + **kwargs + ) + native.genrule( + name = "{}_genrule".format(out), + outs = [out], + cmd = """ + echo 'import tensorflow as tf +_reverb_gen_op = tf.load_op_library( + tf.compat.v1.resource_loader.get_path_to_datafile( + "lib{}_gen_op.so")) +_locals = locals() +for k in dir(_reverb_gen_op): + _locals[k] = getattr(_reverb_gen_op, k) +del _locals' > $@""".format(name), + ) + native.py_library( + name = name, + srcs = [out], + data = [":lib{}_gen_op.so".format(name)], + **kwargs + ) + +def reverb_pytype_library(**kwargs): + if "strict_deps" in kwargs: + kwargs.pop("strict_deps") + native.py_library(**kwargs) + +reverb_pytype_strict_library = native.py_library + +def _make_search_paths(prefix, levels_to_root): + return ",".join( + [ + "-rpath,%s/%s" % (prefix, "/".join([".."] * search_level)) + for search_level in range(levels_to_root + 1) + ], + ) + +def _rpath_linkopts(name): + # Search parent directories up to the TensorFlow root directory for shared + # object dependencies, even if this op shared object is deeply nested + # (e.g. tensorflow/contrib/package:python/ops/_op_lib.so). tensorflow/ is then + # the root and tensorflow/libtensorflow_framework.so should exist when + # deployed. Other shared object dependencies (e.g. shared between contrib/ + # ops) are picked up as long as they are in either the same or a parent + # directory in the tensorflow/ tree. + levels_to_root = native.package_name().count("/") + name.count("/") + return ["-Wl,%s" % (_make_search_paths("$$ORIGIN", levels_to_root),)] + +def reverb_pybind_extension( + name, + srcs, + module_name, + hdrs = [], + features = [], + srcs_version = "PY3", + data = [], + copts = [], + linkopts = [], + deps = [], + defines = [], + visibility = None, + testonly = None, + licenses = None, + compatible_with = None, + restricted_to = None, + deprecation = None): + """Builds a generic Python extension module. + + The module can be loaded in python by performing "import ${name}.". + + Args: + name: Name. + srcs: cc files. + module_name: The name of the hidden module. It should be different + from `name`, and *must* match the MODULE declaration in the .cc file. + hdrs: h files. + features: see bazel docs. + srcs_version: srcs_version for py_library. + data: data deps. + copts: compilation opts. + linkopts: linking opts. + deps: cc_library deps. + defines: cc_library defines. + visibility: visibility. + testonly: whether the rule is testonly. + licenses: see bazel docs. + compatible_with: see bazel docs. + restricted_to: see bazel docs. + deprecation: see bazel docs. + """ + if name == module_name: + fail( + "Must have name != module_name ({} vs. 
{}) because the python ".format(name, module_name) + + "wrapper $name.py needs to add extra logic loading tensorflow.", + ) + py_file = "%s.py" % name + so_file = "%s.so" % module_name + pyd_file = "%s.pyd" % module_name + symbol = "init%s" % module_name + symbol2 = "init_%s" % module_name + symbol3 = "PyInit_%s" % module_name + exported_symbols_file = "%s-exported-symbols.lds" % module_name + version_script_file = "%s-version-script.lds" % module_name + native.genrule( + name = module_name + "_exported_symbols", + outs = [exported_symbols_file], + cmd = "echo '_%s\n_%s\n_%s' >$@" % (symbol, symbol2, symbol3), + output_licenses = ["unencumbered"], + visibility = ["//visibility:private"], + testonly = testonly, + ) + native.genrule( + name = module_name + "_version_script", + outs = [version_script_file], + cmd = "echo '{global:\n %s;\n %s;\n %s;\n local: *;};' >$@" % (symbol, symbol2, symbol3), + output_licenses = ["unencumbered"], + visibility = ["//visibility:private"], + testonly = testonly, + ) + native.cc_binary( + name = so_file, + srcs = srcs + hdrs, + data = data, + copts = copts + [ + "-fno-strict-aliasing", # allow a wider range of code [aliasing] to compile. + "-fexceptions", # pybind relies on exceptions, required to compile. + "-fvisibility=hidden", # avoid pybind symbol clashes between DSOs. + ], + linkopts = linkopts + _rpath_linkopts(module_name) + [ + "-Wl,--version-script", + "$(location %s)" % version_script_file, + ], + deps = depset(deps + [ + exported_symbols_file, + version_script_file, + ]), + defines = defines, + features = features + ["-use_header_modules"], + linkshared = 1, + testonly = testonly, + licenses = licenses, + visibility = visibility, + deprecation = deprecation, + restricted_to = restricted_to, + compatible_with = compatible_with, + ) + native.genrule( + name = module_name + "_pyd_copy", + srcs = [so_file], + outs = [pyd_file], + cmd = "cp $< $@", + output_to_bindir = True, + visibility = visibility, + deprecation = deprecation, + restricted_to = restricted_to, + compatible_with = compatible_with, + ) + native.genrule( + name = name + "_py_file", + outs = [py_file], + cmd = ( + "echo 'import tensorflow as _tf; from .%s import *; del _tf' >$@" % + module_name + ), + output_licenses = ["unencumbered"], + visibility = visibility, + testonly = testonly, + ) + native.py_library( + name = name, + data = [so_file], + srcs = [py_file], + srcs_version = srcs_version, + licenses = licenses, + testonly = testonly, + visibility = visibility, + deprecation = deprecation, + restricted_to = restricted_to, + compatible_with = compatible_with, + ) + +def reverb_py_standard_imports(): + return [] + +def reverb_py_test( + name, + srcs = [], + deps = [], + paropts = [], + python_version = "PY3", + **kwargs): + size = kwargs.pop("size", "small") + native.py_test( + name = name, + size = size, + srcs = srcs, + deps = deps, + python_version = python_version, + **kwargs + ) + return + +def reverb_pybind_deps(): + return [ + "@pybind11", + ] + +def reverb_tf_ops_visibility(): + return [ + "//reverb:__subpackages__", + ] + +def reverb_tf_deps(): + return [ + "@tensorflow_includes//:includes", + "@tensorflow_solib//:framework_lib", + ] + +def reverb_grpc_deps(): + return ["@com_github_grpc_grpc//:grpc++"] + +def reverb_absl_deps(): + return [ + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/functional:bind_front", + "@com_google_absl//absl/memory", + 
"@com_google_absl//absl/numeric:int128", + "@com_google_absl//absl/random", + "@com_google_absl//absl/random:distributions", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + ] diff --git a/reverb/cc/platform/default/default_checkpointer.cc b/reverb/cc/platform/default/default_checkpointer.cc new file mode 100644 index 0000000..59f6a7d --- /dev/null +++ b/reverb/cc/platform/default/default_checkpointer.cc @@ -0,0 +1,28 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "reverb/cc/platform/checkpointing.h" +#include "reverb/cc/platform/tfrecord_checkpointer.h" + +namespace deepmind { +namespace reverb { + +std::unique_ptr CreateDefaultCheckpointer( + std::string root_dir, std::string group) { + return absl::make_unique(std::move(root_dir), + std::move(group)); +} + +} // namespace reverb +} // namespace deepmind diff --git a/reverb/cc/platform/default/grpc_utils.cc b/reverb/cc/platform/default/grpc_utils.cc new file mode 100644 index 0000000..7dbface --- /dev/null +++ b/reverb/cc/platform/default/grpc_utils.cc @@ -0,0 +1,44 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "reverb/cc/platform/grpc_utils.h" + +#include + +#include "grpcpp/create_channel.h" +#include "grpcpp/security/credentials.h" +#include "grpcpp/security/server_credentials.h" +#include "grpcpp/support/channel_arguments.h" + +namespace deepmind { +namespace reverb { + +std::shared_ptr MakeServerCredentials() { + return grpc::InsecureServerCredentials(); +} + +std::shared_ptr MakeChannelCredentials() { + return grpc::InsecureChannelCredentials(); +} + +std::shared_ptr CreateCustomGrpcChannel( + absl::string_view target, + const std::shared_ptr& credentials, + const grpc::ChannelArguments& channel_arguments) { + return grpc::CreateCustomChannel( + std::string(target), credentials, channel_arguments); +} + +} // namespace reverb +} // namespace deepmind diff --git a/reverb/cc/platform/default/logging.h b/reverb/cc/platform/default/logging.h new file mode 100644 index 0000000..41d3e3f --- /dev/null +++ b/reverb/cc/platform/default/logging.h @@ -0,0 +1,206 @@ +// Copyright 2019 DeepMind Technologies Limited. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Copyright 2016-2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================ +// +// A minimal replacement for "glog"-like functionality. Does not provide output +// in a separate thread nor backtracing. + +#ifndef REVERB_CC_PLATFORM_DEFAULT_LOGGING_H_ +#define REVERB_CC_PLATFORM_DEFAULT_LOGGING_H_ + +#include +#include +#include +#include +#include + +#ifdef __GNUC__ +#define PREDICT_TRUE(x) (__builtin_expect(!!(x), 1)) +#define PREDICT_FALSE(x) (__builtin_expect(x, 0)) +#define NORETURN __attribute__((noreturn)) +#else +#define PREDICT_TRUE(x) (x) +#define PREDICT_FALSE(x) (x) +#define NORETURN +#endif + +namespace deepmind { +namespace reverb { +namespace internal { + +struct CheckOpString { + explicit CheckOpString(std::string* str) : str_(str) {} + explicit operator bool() const { return PREDICT_FALSE(str_ != nullptr); } + std::string* const str_; +}; + +template +CheckOpString MakeCheckOpString(const T1& v1, const T2& v2, + const char* exprtext) { + std::ostringstream oss; + oss << exprtext << " (" << v1 << " vs. 
" << v2 << ")"; + return CheckOpString(new std::string(oss.str())); +} + +#define DEFINE_CHECK_OP_IMPL(name, op) \ + template \ + inline CheckOpString name##Impl(const T1& v1, const T2& v2, \ + const char* exprtext) { \ + if (PREDICT_TRUE(v1 op v2)) { \ + return CheckOpString(nullptr); \ + } else { \ + return (MakeCheckOpString)(v1, v2, exprtext); \ + } \ + } \ + inline CheckOpString name##Impl(int v1, int v2, const char* exprtext) { \ + return (name##Impl)(v1, v2, exprtext); \ + } +DEFINE_CHECK_OP_IMPL(Check_EQ, ==) +DEFINE_CHECK_OP_IMPL(Check_NE, !=) +DEFINE_CHECK_OP_IMPL(Check_LE, <=) +DEFINE_CHECK_OP_IMPL(Check_LT, <) +DEFINE_CHECK_OP_IMPL(Check_GE, >=) +DEFINE_CHECK_OP_IMPL(Check_GT, >) +#undef DEFINE_CHECK_OP_IMPL + +class LogMessage { + public: + LogMessage(const char* file, int line) { + std::clog << "[" << file << ":" << line << "] "; + } + + ~LogMessage() { std::clog << "\n"; } + + std::ostream& stream() && { return std::clog; } +}; + +class LogMessageFatal { + public: + LogMessageFatal(const char* file, int line) { + stream_ << "[" << file << ":" << line << "] "; + } + + LogMessageFatal(const char* file, int line, const CheckOpString& result) { + stream_ << "[" << file << ":" << line << "] Check failed: " << *result.str_; + } + + ~LogMessageFatal() NORETURN; + + std::ostream& stream() && { return stream_; } + + private: + std::ostringstream stream_; +}; + +inline LogMessageFatal::~LogMessageFatal() { + std::cerr << stream_.str() << std::endl; + std::abort(); +} + +struct NullStream {}; + +template +NullStream&& operator<<(NullStream&& s, T&&) { + return std::move(s); +} + +enum class LogSeverity { + kFatal, + kNonFatal, +}; + +LogMessage LogStream( + std::integral_constant); +LogMessageFatal LogStream( + std::integral_constant); + +struct Voidify { + void operator&(std::ostream&) {} +}; + +} // namespace internal +} // namespace reverb +} // namespace deepmind + +#define REVERB_CHECK_OP_LOG(name, op, val1, val2, log) \ + while (::deepmind::reverb::internal::CheckOpString _result = \ + ::deepmind::reverb::internal::name##Impl( \ + val1, val2, #val1 " " #op " " #val2)) \ + log(__FILE__, __LINE__, _result).stream() + +#define REVERB_CHECK_OP(name, op, val1, val2) \ + REVERB_CHECK_OP_LOG(name, op, val1, val2, \ + ::deepmind::reverb::internal::LogMessageFatal) + +#define REVERB_CHECK_EQ(val1, val2) REVERB_CHECK_OP(Check_EQ, ==, val1, val2) +#define REVERB_CHECK_NE(val1, val2) REVERB_CHECK_OP(Check_NE, !=, val1, val2) +#define REVERB_CHECK_LE(val1, val2) REVERB_CHECK_OP(Check_LE, <=, val1, val2) +#define REVERB_CHECK_LT(val1, val2) REVERB_CHECK_OP(Check_LT, <, val1, val2) +#define REVERB_CHECK_GE(val1, val2) REVERB_CHECK_OP(Check_GE, >=, val1, val2) +#define REVERB_CHECK_GT(val1, val2) REVERB_CHECK_OP(Check_GT, >, val1, val2) + +#define REVERB_QCHECK_EQ(val1, val2) REVERB_CHECK_OP(Check_EQ, ==, val1, val2) +#define REVERB_QCHECK_NE(val1, val2) REVERB_CHECK_OP(Check_NE, !=, val1, val2) +#define REVERB_QCHECK_LE(val1, val2) REVERB_CHECK_OP(Check_LE, <=, val1, val2) +#define REVERB_QCHECK_LT(val1, val2) REVERB_CHECK_OP(Check_LT, <, val1, val2) +#define REVERB_QCHECK_GE(val1, val2) REVERB_CHECK_OP(Check_GE, >=, val1, val2) +#define REVERB_QCHECK_GT(val1, val2) REVERB_CHECK_OP(Check_GT, >, val1, val2) + +#define REVERB_CHECK(condition) \ + while (auto _result = ::deepmind::reverb::internal::CheckOpString( \ + (condition) ? 
nullptr : new std::string(#condition))) \ + ::deepmind::reverb::internal::LogMessageFatal(__FILE__, __LINE__, _result) \ + .stream() + +#define REVERB_QCHECK(condition) REVERB_CHECK(condition) + +#define REVERB_FATAL ::deepmind::reverb::internal::LogSeverity::kFatal +#define REVERB_QFATAL ::deepmind::reverb::internal::LogSeverity::kFatal +#define REVERB_INFO ::deepmind::reverb::internal::LogSeverity::kNonFatal +#define REVERB_WARNING ::deepmind::reverb::internal::LogSeverity::kNonFatal +#define REVERB_ERROR ::deepmind::reverb::internal::LogSeverity::kNonFatal + +#define REVERB_LOG(level) \ + decltype(::deepmind::reverb::internal::LogStream( \ + std::integral_constant<::deepmind::reverb::internal::LogSeverity, \ + level>()))(__FILE__, __LINE__) \ + .stream() + +#define REVERB_VLOG(level) ::deepmind::reverb::internal::NullStream() + +#define REVERB_LOG_IF(level, condition) \ + !(condition) \ + ? static_cast(0) \ + : ::deepmind::reverb::internal::Voidify() & \ + decltype(::deepmind::reverb::internal::LogStream( \ + std::integral_constant< \ + ::deepmind::reverb::internal::LogSeverity, level>()))( \ + __FILE__, __LINE__) \ + .stream() + +#endif // REVERB_CC_PLATFORM_DEFAULT_LOGGING_H_ diff --git a/reverb/cc/platform/default/net.cc b/reverb/cc/platform/default/net.cc new file mode 100644 index 0000000..62e869f --- /dev/null +++ b/reverb/cc/platform/default/net.cc @@ -0,0 +1,136 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "reverb/cc/platform/net.h" + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "reverb/cc/platform/logging.h" + +namespace deepmind::reverb::internal { +namespace { +bool IsPortAvailable(int* port, bool is_tcp) { + const int protocol = is_tcp ? IPPROTO_TCP : 0; + const int fd = socket(AF_INET, is_tcp ? SOCK_STREAM : SOCK_DGRAM, protocol); + + struct sockaddr_in addr; + socklen_t addr_len = sizeof(addr); + int actual_port; + + REVERB_CHECK_GE(*port, 0); + REVERB_CHECK_LE(*port, 65535); + if (fd < 0) { + REVERB_LOG(REVERB_ERROR) << "socket() failed: " << strerror(errno); + return false; + } + + // SO_REUSEADDR lets us start up a server immediately after it exists. + int one = 1; + if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one)) < 0) { + REVERB_LOG(REVERB_ERROR) << "setsockopt() failed: " << strerror(errno); + if (close(fd) < 0) { + REVERB_LOG(REVERB_ERROR) << "close() failed: " << strerror(errno); + } + return false; + } + + // Try binding to port. + addr.sin_family = AF_INET; + addr.sin_addr.s_addr = INADDR_ANY; + addr.sin_port = htons(static_cast(*port)); + if (bind(fd, reinterpret_cast(&addr), sizeof(addr)) < 0) { + REVERB_LOG(REVERB_WARNING) + << "bind(port=" << *port << ") failed: " << strerror(errno); + if (close(fd) < 0) { + REVERB_LOG(REVERB_ERROR) << "close() failed: " << strerror(errno); + } + return false; + } + + // Get the bound port number. 
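+  // When `*port` was passed as 0, bind() let the kernel pick an ephemeral
+  // port, so the actual port must be read back with getsockname(); otherwise
+  // this simply confirms that the requested port was bound.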
+ if (getsockname(fd, reinterpret_cast(&addr), &addr_len) < + 0) { + REVERB_LOG(REVERB_WARNING) << "getsockname() failed: " << strerror(errno); + if (close(fd) < 0) { + REVERB_LOG(REVERB_ERROR) << "close() failed: " << strerror(errno); + } + return false; + } + REVERB_CHECK_LE(addr_len, sizeof(addr)); + actual_port = ntohs(addr.sin_port); + REVERB_CHECK_GT(actual_port, 0); + if (*port == 0) { + *port = actual_port; + } else { + REVERB_CHECK_EQ(*port, actual_port); + } + if (close(fd) < 0) { + REVERB_LOG(REVERB_ERROR) << "close() failed: " << strerror(errno); + } + return true; +} + +const int kNumRandomPortsToPick = 100; +const int kMaximumTrials = 1000; + +} // namespace + +int PickUnusedPortOrDie() { + static std::unordered_set chosen_ports; + + // Type of port to first pick in the next iteration. + bool is_tcp = true; + int trial = 0; + while (true) { + int port; + trial++; + REVERB_CHECK_LE(trial, kMaximumTrials) + << "Failed to pick an unused port for testing."; + if (trial == 1) { + port = getpid() % (65536 - 30000) + 30000; + } else if (trial <= kNumRandomPortsToPick) { + port = rand() % (65536 - 30000) + 30000; // NOLINT: Ignore suggestion to use rand_r instead. + } else { + port = 0; + } + + if (chosen_ports.find(port) != chosen_ports.end()) { + continue; + } + if (!IsPortAvailable(&port, is_tcp)) { + continue; + } + + REVERB_CHECK_GT(port, 0); + if (!IsPortAvailable(&port, !is_tcp)) { + is_tcp = !is_tcp; + continue; + } + + chosen_ports.insert(port); + return port; + } + + return 0; +} + +} // namespace deepmind::reverb::internal diff --git a/reverb/cc/platform/default/repo.bzl b/reverb/cc/platform/default/repo.bzl new file mode 100644 index 0000000..200e9d1 --- /dev/null +++ b/reverb/cc/platform/default/repo.bzl @@ -0,0 +1,361 @@ +"""Reverb custom external dependencies.""" + +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") + +# Sanitize a dependency so that it works correctly from code that includes +# reverb as a submodule. +def clean_dep(dep): + return str(Label(dep)) + +def get_python_path(ctx): + path = ctx.os.environ.get("PYTHON_BIN_PATH") + if not path: + fail( + "Could not get environment variable PYTHON_BIN_PATH. 
" + + "Check your .bazelrc file.", + ) + return path + +def _find_tf_include_path(repo_ctx): + exec_result = repo_ctx.execute( + [ + get_python_path(repo_ctx), + "-c", + "import tensorflow as tf; import sys; " + + "sys.stdout.write(tf.sysconfig.get_include())", + ], + quiet = True, + ) + if exec_result.return_code != 0: + fail("Could not locate tensorflow installation path:\n{}" + .format(exec_result.stderr)) + return exec_result.stdout.splitlines()[-1] + +def _find_tf_lib_path(repo_ctx): + exec_result = repo_ctx.execute( + [ + get_python_path(repo_ctx), + "-c", + "import tensorflow as tf; import sys; " + + "sys.stdout.write(tf.sysconfig.get_lib())", + ], + quiet = True, + ) + if exec_result.return_code != 0: + fail("Could not locate tensorflow installation path:\n{}" + .format(exec_result.stderr)) + return exec_result.stdout.splitlines()[-1] + +def _find_numpy_include_path(repo_ctx): + exec_result = repo_ctx.execute( + [ + get_python_path(repo_ctx), + "-c", + "import numpy; import sys; " + + "sys.stdout.write(numpy.get_include())", + ], + quiet = True, + ) + if exec_result.return_code != 0: + fail("Could not locate numpy includes path:\n{}" + .format(exec_result.stderr)) + return exec_result.stdout.splitlines()[-1] + +def _find_python_include_path(repo_ctx): + exec_result = repo_ctx.execute( + [ + get_python_path(repo_ctx), + "-c", + "from distutils import sysconfig; import sys; " + + "sys.stdout.write(sysconfig.get_python_inc())", + ], + quiet = True, + ) + if exec_result.return_code != 0: + fail("Could not locate python includes path:\n{}" + .format(exec_result.stderr)) + return exec_result.stdout.splitlines()[-1] + +def _find_python_solib_path(repo_ctx): + exec_result = repo_ctx.execute( + [ + get_python_path(repo_ctx), + "-c", + "import sys; vi = sys.version_info; " + + "sys.stdout.write('python{}.{}'.format(vi.major, vi.minor))", + ], + ) + if exec_result.return_code != 0: + fail("Could not locate python shared library path:\n{}" + .format(exec_result.stderr)) + version = exec_result.stdout.splitlines()[-1] + basename = "lib{}.so".format(version) + exec_result = repo_ctx.execute( + ["{}-config".format(version), "--configdir"], + quiet = True, + ) + if exec_result.return_code != 0: + fail("Could not locate python shared library path:\n{}" + .format(exec_result.stderr)) + solib_dir = exec_result.stdout.splitlines()[-1] + full_path = repo_ctx.path("{}/{}".format(solib_dir, basename)) + if not full_path.exists: + fail("Unable to find python shared library file:\n{}/{}" + .format(solib_dir, basename)) + return struct(dir = solib_dir, basename = basename) + +def _eigen_archive_repo_impl(repo_ctx): + tf_include_path = _find_tf_include_path(repo_ctx) + repo_ctx.symlink(tf_include_path, "tf_includes") + repo_ctx.file( + "BUILD", + content = """ +cc_library( + name = "includes", + hdrs = glob(["tf_includes/Eigen/**/*.h", + "tf_includes/Eigen/**", + "tf_includes/unsupported/Eigen/**/*.h", + "tf_includes/unsupported/Eigen/**"]), + # https://groups.google.com/forum/#!topic/bazel-discuss/HyyuuqTxKok + includes = ["tf_includes"], + visibility = ["//visibility:public"], +) +""", + executable = False, + ) + +def _nsync_includes_repo_impl(repo_ctx): + tf_include_path = _find_tf_include_path(repo_ctx) + repo_ctx.symlink(tf_include_path + "/external", "nsync_includes") + repo_ctx.file( + "BUILD", + content = """ +cc_library( + name = "includes", + hdrs = glob(["nsync_includes/nsync/public/*.h"]), + includes = ["nsync_includes"], + visibility = ["//visibility:public"], +) +""", + executable = False, + ) 
+ +def _zlib_includes_repo_impl(repo_ctx): + tf_include_path = _find_tf_include_path(repo_ctx) + repo_ctx.symlink( + tf_include_path + "/external/zlib", + "zlib", + ) + repo_ctx.file( + "BUILD", + content = """ +cc_library( + name = "includes", + hdrs = glob(["zlib/**/*.h"]), + includes = ["zlib"], + visibility = ["//visibility:public"], +) +""", + executable = False, + ) + +def _snappy_includes_repo_impl(repo_ctx): + tf_include_path = _find_tf_include_path(repo_ctx) + repo_ctx.symlink( + tf_include_path + "/external/snappy", + "snappy", + ) + repo_ctx.file( + "BUILD", + content = """ +cc_library( + name = "includes", + hdrs = glob(["snappy/*.h"]), + includes = ["snappy"], + visibility = ["//visibility:public"], +) +""", + executable = False, + ) + +def _protobuf_includes_repo_impl(repo_ctx): + tf_include_path = _find_tf_include_path(repo_ctx) + repo_ctx.symlink(tf_include_path, "tf_includes") + repo_ctx.symlink(Label("//third_party:protobuf.BUILD"), "BUILD") + +def _tensorflow_includes_repo_impl(repo_ctx): + tf_include_path = _find_tf_include_path(repo_ctx) + repo_ctx.symlink(tf_include_path, "tensorflow_includes") + repo_ctx.file( + "BUILD", + content = """ +cc_library( + name = "includes", + hdrs = glob( + [ + "tensorflow_includes/**/*.h", + "tensorflow_includes/third_party/eigen3/**", + ], + exclude = ["tensorflow_includes/absl/**/*.h"], + ), + includes = ["tensorflow_includes"], + deps = [ + "@eigen_archive//:includes", + "@protobuf_archive//:includes", + "@zlib_includes//:includes", + "@snappy_includes//:includes", + ], + visibility = ["//visibility:public"], +) +filegroup( + name = "protos", + srcs = glob(["tensorflow_includes/**/*.proto"]), + visibility = ["//visibility:public"], +) +""", + executable = False, + ) + +def _tensorflow_solib_repo_impl(repo_ctx): + tf_lib_path = _find_tf_lib_path(repo_ctx) + repo_ctx.symlink(tf_lib_path, "tensorflow_solib") + repo_ctx.file( + "BUILD", + content = """ +cc_library( + name = "framework_lib", + srcs = ["tensorflow_solib/libtensorflow_framework.so.2"], + deps = ["@python_includes", "@python_includes//:numpy_includes"], + visibility = ["//visibility:public"], +) +""", + ) + +def _python_includes_repo_impl(repo_ctx): + python_include_path = _find_python_include_path(repo_ctx) + python_solib = _find_python_solib_path(repo_ctx) + repo_ctx.symlink(python_include_path, "python_includes") + numpy_include_path = _find_numpy_include_path(repo_ctx) + repo_ctx.symlink(numpy_include_path, "numpy_includes") + repo_ctx.symlink( + "{}/{}".format(python_solib.dir, python_solib.basename), + python_solib.basename, + ) + + # Note, "@python_includes" is a misnomer since we include the + # libpythonX.Y.so in the srcs, so we can get access to python's various + # symbols at link time. 
+ repo_ctx.file( + "BUILD", + content = """ +cc_library( + name = "python_includes", + hdrs = glob(["python_includes/**/*.h"]), + srcs = ["{}"], + includes = ["python_includes"], + visibility = ["//visibility:public"], +) +cc_library( + name = "numpy_includes", + hdrs = glob(["numpy_includes/**/*.h"]), + includes = ["numpy_includes"], + visibility = ["//visibility:public"], +) +""".format(python_solib.basename), + executable = False, + ) + +def cc_tf_configure(): + """Autoconf pre-installed tensorflow repo.""" + make_eigen_repo = repository_rule(implementation = _eigen_archive_repo_impl) + make_eigen_repo(name = "eigen_archive") + make_nsync_repo = repository_rule( + implementation = _nsync_includes_repo_impl, + ) + make_nsync_repo(name = "nsync_includes") + make_zlib_repo = repository_rule( + implementation = _zlib_includes_repo_impl, + ) + make_zlib_repo(name = "zlib_includes") + make_snappy_repo = repository_rule( + implementation = _snappy_includes_repo_impl, + ) + make_snappy_repo(name = "snappy_includes") + make_protobuf_repo = repository_rule( + implementation = _protobuf_includes_repo_impl, + ) + make_protobuf_repo(name = "protobuf_archive") + make_tfinc_repo = repository_rule( + implementation = _tensorflow_includes_repo_impl, + ) + make_tfinc_repo(name = "tensorflow_includes") + make_tflib_repo = repository_rule( + implementation = _tensorflow_solib_repo_impl, + ) + make_tflib_repo(name = "tensorflow_solib") + make_python_inc_repo = repository_rule( + implementation = _python_includes_repo_impl, + ) + make_python_inc_repo(name = "python_includes") + +def reverb_python_deps(): + http_archive( + name = "pybind11", + urls = [ + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/pybind/pybind11/archive/v2.4.3.tar.gz", + "https://github.com/pybind/pybind11/archive/v2.4.3.tar.gz", + ], + sha256 = "1eed57bc6863190e35637290f97a20c81cfe4d9090ac0a24f3bbf08f265eb71d", + strip_prefix = "pybind11-2.4.3", + build_file = clean_dep("//third_party:pybind11.BUILD"), + ) + + http_archive( + name = "absl_py", + sha256 = "603febc9b95a8f2979a7bdb77d2f5e4d9b30d4e0d59579f88eba67d4e4cc5462", + strip_prefix = "abseil-py-pypi-v0.9.0", + urls = [ + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/abseil/abseil-py/archive/pypi-v0.9.0.tar.gz", + "https://github.com/abseil/abseil-py/archive/pypi-v0.9.0.tar.gz", + ], + ) + +def _reverb_protoc_archive(ctx): + version = ctx.attr.version + sha256 = ctx.attr.sha256 + + override_version = ctx.os.environ.get("REVERB_PROTOC_VERSION") + if override_version: + sha256 = "" + version = override_version + + urls = [ + "https://github.com/protocolbuffers/protobuf/releases/download/v%s/protoc-%s-linux-x86_64.zip" % (version, version), + ] + ctx.download_and_extract( + url = urls, + sha256 = sha256, + ) + + ctx.file( + "BUILD", + content = """ +filegroup( + name = "protoc_bin", + srcs = ["bin/protoc"], + visibility = ["//visibility:public"], +) +""", + executable = False, + ) + +reverb_protoc_archive = repository_rule( + implementation = _reverb_protoc_archive, + attrs = { + "version": attr.string(mandatory = True), + "sha256": attr.string(mandatory = True), + }, +) + +def reverb_protoc_deps(version, sha256): + reverb_protoc_archive(name = "protobuf_protoc", version = version, sha256 = sha256) diff --git a/reverb/cc/platform/default/snappy.cc b/reverb/cc/platform/default/snappy.cc new file mode 100644 index 0000000..5fd5e09 --- /dev/null +++ b/reverb/cc/platform/default/snappy.cc @@ -0,0 +1,171 @@ +// Copyright 2019 DeepMind Technologies Limited. 
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "reverb/cc/platform/snappy.h"
+
+#include <cstddef>
+
+#include "absl/meta/type_traits.h"
+#include "snappy-sinksource.h"  // NOLINT(build/include)
+#include "snappy.h"  // NOLINT(build/include)
+
+namespace deepmind {
+namespace reverb {
+
+namespace {
+
+// Helpers for STLStringResizeUninitialized
+// HasMember is true_type or false_type, depending on whether or not
+// T has a __resize_default_init member. Resize will call the
+// __resize_default_init member if it exists, and will call the resize
+// member otherwise.
+template <typename string_type, typename = void>
+struct ResizeUninitializedTraits {
+  using HasMember = std::false_type;
+  static void Resize(string_type* s, size_t new_size) { s->resize(new_size); }
+};
+
+// __resize_default_init is provided by libc++ >= 8.0.
+template <typename string_type>
+struct ResizeUninitializedTraits<
+    string_type, absl::void_t<decltype(std::declval<string_type&>()
+                                           .__resize_default_init(237))> > {
+  using HasMember = std::true_type;
+  static void Resize(string_type* s, size_t new_size) {
+    s->__resize_default_init(new_size);
+  }
+};
+
+template <typename string_type>
+inline constexpr bool STLStringSupportsNontrashingResize(string_type*) {
+  return ResizeUninitializedTraits<string_type>::HasMember::value;
+}
+
+// Resize string `s` to `new_size`, leaving the data uninitialized.
+static inline void STLStringResizeUninitialized(std::string* s,
+                                                size_t new_size) {
+  ResizeUninitializedTraits<std::string>::Resize(s, new_size);
+}
+
+class StringSink : public snappy::Sink {
+ public:
+  explicit StringSink(std::string* dest) : dest_(dest) {}
+
+  StringSink(const StringSink&) = delete;
+  StringSink& operator=(const StringSink&) = delete;
+
+  void Append(const char* data, size_t n) override {
+    if (STLStringSupportsNontrashingResize(dest_)) {
+      size_t current_size = dest_->size();
+      if (data == (const_cast<char*>(dest_->data()) + current_size)) {
+        // Zero copy append
+        STLStringResizeUninitialized(dest_, current_size + n);
+        return;
+      }
+    }
+    dest_->append(data, n);
+  }
+
+  char* GetAppendBuffer(size_t size, char* scratch) override {
+    if (!STLStringSupportsNontrashingResize(dest_)) {
+      return scratch;
+    }
+
+    const size_t current_size = dest_->size();
+    if ((size + current_size) > dest_->capacity()) {
+      // Use resize instead of reserve so that we grow by the string's growth
+      // factor. Then reset the size to where it was.
+      STLStringResizeUninitialized(dest_, size + current_size);
+      STLStringResizeUninitialized(dest_, current_size);
+    }
+
+    // If string size is zero, then string_as_array() returns nullptr, so
+    // we need to use data() instead
+    return const_cast<char*>(dest_->data()) + current_size;
+  }
+
+ private:
+  std::string* dest_;
+};
+
+// TODO(b/140988915): See if this can be moved to snappy's codebase.
+class CheckedByteArraySink : public snappy::Sink {
+  // A snappy Sink that takes an output buffer and a capacity value. If the
+  // writer attempts to write more data than capacity, it does the safe thing
+  // and doesn't attempt to write past the data boundary. After writing,
+  // call sink.Overflowed() to see if an overflow occurred.
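The class comment above states the contract; as a hedged illustration of how such a capacity-bounded sink is typically driven with snappy's Source/Sink API (the buffer size constant and ReadCompressedPayload are illustrative placeholders, not part of this file):

    // Sketch only: assumes <string>, <vector>, "snappy.h", "snappy-sinksource.h".
    std::string compressed = ReadCompressedPayload();   // hypothetical helper
    std::vector<char> out(kMaxUncompressedBytes);       // caller-chosen capacity
    snappy::ByteArraySource source(compressed.data(), compressed.size());
    CheckedByteArraySink sink(out.data(), out.size());
    const bool ok = snappy::Uncompress(&source, &sink) && !sink.Overflowed();
    // On success, the first sink.NumberOfBytesWritten() bytes of `out` are valid.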
+ + public: + CheckedByteArraySink(char* outbuf, size_t capacity) + : outbuf_(outbuf), capacity_(capacity), size_(0), overflowed_(false) {} + CheckedByteArraySink(const CheckedByteArraySink&) = delete; + CheckedByteArraySink& operator=(const CheckedByteArraySink&) = delete; + + void Append(const char* bytes, size_t n) override { + size_t available = capacity_ - size_; + if (n > available) { + n = available; + overflowed_ = true; + } + if (n > 0 && bytes != (outbuf_ + size_)) { + // Catch cases where the pointer returned by GetAppendBuffer() was + // modified. + assert(!(outbuf_ <= bytes && bytes < outbuf_ + capacity_)); + memcpy(outbuf_ + size_, bytes, n); + } + size_ += n; + } + + char* GetAppendBuffer(size_t length, char* scratch) override { + size_t available = capacity_ - size_; + if (available >= length) { + return outbuf_ + size_; + } else { + return scratch; + } + } + + // Returns the number of bytes actually written to the sink. + size_t NumberOfBytesWritten() const { return size_; } + + // Returns true if any bytes were discarded during the Append(), i.e., if + // Append() attempted to write more than 'capacity' bytes. + bool Overflowed() const { return overflowed_; } + + private: + char* outbuf_; + const size_t capacity_; + size_t size_; + bool overflowed_; +}; + +} // namespace + +template <> +size_t SnappyCompressFromString(absl::string_view input, std::string* output) { + snappy::ByteArraySource source(input.data(), input.size()); + StringSink sink(output); + return snappy::Compress(&source, &sink); +} + +template <> +bool SnappyUncompressToString(const std::string& input, size_t output_capacity, + char* output) { + snappy::ByteArraySource source(input.data(), input.size()); + CheckedByteArraySink sink(output, output_capacity); + return snappy::Uncompress(&source, &sink); +} + +} // namespace reverb +} // namespace deepmind diff --git a/reverb/cc/platform/default/thread.cc b/reverb/cc/platform/default/thread.cc new file mode 100644 index 0000000..583d283 --- /dev/null +++ b/reverb/cc/platform/default/thread.cc @@ -0,0 +1,45 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "reverb/cc/platform/thread.h" + +#include // NOLINT(build/c++11) + +#include "absl/memory/memory.h" + +namespace deepmind { +namespace reverb { +namespace internal { +namespace { + +class StdThread : public Thread { + public: + explicit StdThread(std::function fn) : thread_(std::move(fn)) {} + + ~StdThread() override { thread_.join(); } + + private: + std::thread thread_; +}; + +} // namespace + +std::unique_ptr StartThread(absl::string_view name, + std::function fn) { + return {absl::make_unique(std::move(fn))}; +} + +} // namespace internal +} // namespace reverb +} // namespace deepmind diff --git a/reverb/cc/platform/grpc_utils.h b/reverb/cc/platform/grpc_utils.h new file mode 100644 index 0000000..12f4659 --- /dev/null +++ b/reverb/cc/platform/grpc_utils.h @@ -0,0 +1,40 @@ +// Copyright 2019 DeepMind Technologies Limited. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef REVERB_CC_PLATFORM_GRPC_CREDENTIALS_H_ +#define REVERB_CC_PLATFORM_GRPC_CREDENTIALS_H_ + +#include + +#include "absl/strings/string_view.h" +#include "grpcpp/security/credentials.h" +#include "grpcpp/security/server_credentials.h" +#include "grpcpp/support/channel_arguments.h" + +namespace deepmind { +namespace reverb { + +std::shared_ptr MakeServerCredentials(); + +std::shared_ptr MakeChannelCredentials(); + +std::shared_ptr CreateCustomGrpcChannel( + absl::string_view target, + const std::shared_ptr& credentials, + const grpc::ChannelArguments& channel_arguments); + +} // namespace reverb +} // namespace deepmind + +#endif // REVERB_CC_PLATFORM_GRPC_CREDENTIALS_H_ diff --git a/reverb/cc/platform/logging.h b/reverb/cc/platform/logging.h new file mode 100644 index 0000000..a7d1335 --- /dev/null +++ b/reverb/cc/platform/logging.h @@ -0,0 +1,20 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef REVERB_CC_PLATFORM_LOGGING_H_ +#define REVERB_CC_PLATFORM_LOGGING_H_ + +#include "reverb/cc/platform/default/logging.h" + +#endif // REVERB_CC_PLATFORM_LOGGING_H_ diff --git a/reverb/cc/platform/net.h b/reverb/cc/platform/net.h new file mode 100644 index 0000000..6113d80 --- /dev/null +++ b/reverb/cc/platform/net.h @@ -0,0 +1,22 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef REVERB_CC_TESTING_NET_H_ +#define REVERB_CC_TESTING_NET_H_ + +namespace deepmind::reverb::internal{ +int PickUnusedPortOrDie(); +} // namespace deepmind::reverb::internal + +#endif // REVERB_CC_TESTING_NET_H_ diff --git a/reverb/cc/platform/net_test.cc b/reverb/cc/platform/net_test.cc new file mode 100644 index 0000000..e877bdd --- /dev/null +++ b/reverb/cc/platform/net_test.cc @@ -0,0 +1,36 @@ +// Copyright 2019 DeepMind Technologies Limited. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "reverb/cc/platform/net.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "reverb/cc/platform/logging.h" + +namespace deepmind::reverb::internal{ +namespace { + + +TEST(Net, PickUnusedPortOrDie) { + int port0 = PickUnusedPortOrDie(); + int port1 = PickUnusedPortOrDie(); + REVERB_CHECK_GE(port0, 0); + REVERB_CHECK_LT(port0, 65536); + REVERB_CHECK_GE(port1, 0); + REVERB_CHECK_LT(port1, 65536); + REVERB_CHECK_NE(port0, port1); +} + +} // namespace +} // namespace deepmind::reverb::internal diff --git a/reverb/cc/platform/snappy.h b/reverb/cc/platform/snappy.h new file mode 100644 index 0000000..209313f --- /dev/null +++ b/reverb/cc/platform/snappy.h @@ -0,0 +1,37 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef REVERB_CC_PLATFORM_SNAPPY_H_ +#define REVERB_CC_PLATFORM_SNAPPY_H_ + +#include + +#include "absl/strings/string_view.h" + +namespace deepmind { +namespace reverb { + +// Compress a string to a `Toutput` output. Return the number of bytes stored. +template +size_t SnappyCompressFromString(absl::string_view input, Toutput* output); + +// Uncompress an `input` containing snappy-compressed data to *output. +template +bool SnappyUncompressToString(const Tinput& input, size_t output_capacity, + char* output); + +} // namespace reverb +} // namespace deepmind + +#endif // REVERB_CC_PLATFORM_SNAPPY_H_ diff --git a/reverb/cc/platform/tfrecord_checkpointer.cc b/reverb/cc/platform/tfrecord_checkpointer.cc new file mode 100644 index 0000000..4be2fb2 --- /dev/null +++ b/reverb/cc/platform/tfrecord_checkpointer.cc @@ -0,0 +1,332 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
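The OpenWriter/OpenReader helpers further down tie each RecordWriter or RecordReader to the file it wraps by capturing the raw file pointer in the unique_ptr's deleter. A hedged, generic sketch of that ownership pattern (WrapWithFile and the template parameter names are illustrative, not symbols from this file):

    #include <functional>
    #include <memory>

    // Destroys the wrapped object first, then the file it was writing to or
    // reading from, mirroring the deleters built inside OpenWriter/OpenReader.
    template <typename T, typename File>
    std::unique_ptr<T, std::function<void(T*)>> WrapWithFile(
        std::unique_ptr<File> file, std::unique_ptr<T> obj) {
      File* file_ptr = file.release();
      return std::unique_ptr<T, std::function<void(T*)>>(
          obj.release(), [file_ptr](T* ptr) {
            delete ptr;       // the writer/reader must be destroyed first
            delete file_ptr;  // then the underlying file handle
          });
    }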
+ +#include "reverb/cc/platform/tfrecord_checkpointer.h" + +#include +#include +#include +#include +#include +#include + +#include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "reverb/cc/checkpointing/checkpoint.pb.h" +#include "reverb/cc/chunk_store.h" +#include "reverb/cc/distributions/fifo.h" +#include "reverb/cc/distributions/heap.h" +#include "reverb/cc/distributions/interface.h" +#include "reverb/cc/distributions/lifo.h" +#include "reverb/cc/distributions/prioritized.h" +#include "reverb/cc/distributions/uniform.h" +#include "reverb/cc/priority_table.h" +#include "reverb/cc/rate_limiter.h" +#include "reverb/cc/schema.pb.h" +#include "reverb/cc/table_extensions/interface.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/io/record_reader.h" +#include "tensorflow/core/lib/io/record_writer.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/file_system.h" + +namespace deepmind { +namespace reverb { +namespace { + +constexpr char kTablesFileName[] = "tables.tfrecord"; +constexpr char kChunksFileName[] = "chunks.tfrecord"; +constexpr char kDoneFileName[] = "DONE"; + +using RecordWriterUniquePtr = + std::unique_ptr>; +using RecordReaderUniquePtr = + std::unique_ptr>; + +tensorflow::Status OpenWriter(const std::string& path, + RecordWriterUniquePtr* writer) { + std::unique_ptr file; + TF_RETURN_IF_ERROR(tensorflow::Env::Default()->NewWritableFile(path, &file)); + auto* file_ptr = file.release(); + *writer = RecordWriterUniquePtr(new tensorflow::io::RecordWriter(file_ptr), + [file_ptr](tensorflow::io::RecordWriter* w) { + delete w; + delete file_ptr; + }); + return tensorflow::Status::OK(); +} + +tensorflow::Status OpenReader(const std::string& path, + RecordReaderUniquePtr* reader) { + std::unique_ptr file; + TF_RETURN_IF_ERROR( + tensorflow::Env::Default()->NewRandomAccessFile(path, &file)); + auto* file_ptr = file.release(); + *reader = RecordReaderUniquePtr(new tensorflow::io::RecordReader(file_ptr), + [file_ptr](tensorflow::io::RecordReader* r) { + delete r; + delete file_ptr; + }); + return tensorflow::Status::OK(); +} + +inline tensorflow::Status WriteDone(const std::string& path) { + std::unique_ptr file; + TF_RETURN_IF_ERROR(tensorflow::Env::Default()->NewWritableFile( + tensorflow::io::JoinPath(path, kDoneFileName), &file)); + return file->Close(); +} + +inline bool HasDone(const std::string& path) { + return tensorflow::Env::Default() + ->FileExists(tensorflow::io::JoinPath(path, kDoneFileName)) + .ok(); +} + +std::unique_ptr MakeDistribution( + const KeyDistributionOptions& options) { + switch (options.distribution_case()) { + case KeyDistributionOptions::kFifo: + return absl::make_unique(); + case KeyDistributionOptions::kLifo: + return absl::make_unique(); + case KeyDistributionOptions::kUniform: + return absl::make_unique(); + case KeyDistributionOptions::kPrioritized: + return absl::make_unique( + options.prioritized().priority_exponent()); + case KeyDistributionOptions::kHeap: + return absl::make_unique(options.heap().min_heap()); + case KeyDistributionOptions::DISTRIBUTION_NOT_SET: + REVERB_LOG(REVERB_FATAL) << "Distribution not set"; + default: + REVERB_LOG(REVERB_FATAL) << 
"Distribution not supported"; + } +} + +inline size_t find_table_index( + const std::vector>* tables, + const std::string& name) { + for (int i = 0; i < tables->size(); i++) { + if (tables->at(i)->name() == name) return i; + } + return -1; +} + +} // namespace + +TFRecordCheckpointer::TFRecordCheckpointer(std::string root_dir, + std::string group) + : root_dir_(std::move(root_dir)), group_(std::move(group)) { + REVERB_LOG(REVERB_INFO) << "Initializing TFRecordCheckpointer in " + << root_dir_; +} + +tensorflow::Status TFRecordCheckpointer::Save( + std::vector tables, int keep_latest, std::string* path) { + if (keep_latest <= 0) { + return tensorflow::errors::InvalidArgument( + "TFRecordCheckpointer must have keep_latest > 0."); + } + if (!group_.empty()) { + return tensorflow::errors::InvalidArgument( + "Setting non-empty group is not supported"); + } + + std::string dir_path = + tensorflow::io::JoinPath(root_dir_, absl::FormatTime(absl::Now())); + TF_RETURN_IF_ERROR( + tensorflow::Env::Default()->RecursivelyCreateDir(dir_path)); + + RecordWriterUniquePtr table_writer; + TF_RETURN_IF_ERROR(OpenWriter( + tensorflow::io::JoinPath(dir_path, kTablesFileName), &table_writer)); + + absl::flat_hash_set> chunks; + for (PriorityTable* table : tables) { + auto checkpoint = table->Checkpoint(); + chunks.merge(checkpoint.chunks); + TF_RETURN_IF_ERROR( + table_writer->WriteRecord(checkpoint.checkpoint.SerializeAsString())); + } + + TF_RETURN_IF_ERROR(table_writer->Close()); + table_writer = nullptr; + + RecordWriterUniquePtr chunk_writer; + TF_RETURN_IF_ERROR(OpenWriter( + tensorflow::io::JoinPath(dir_path, kChunksFileName), &chunk_writer)); + + for (const auto& chunk : chunks) { + TF_RETURN_IF_ERROR( + chunk_writer->WriteRecord(chunk->data().SerializeAsString())); + } + TF_RETURN_IF_ERROR(chunk_writer->Close()); + chunk_writer = nullptr; + + // Both chunks and table checkpoint has now been written so we can proceed to + // add the DONE-file. + TF_RETURN_IF_ERROR(WriteDone(dir_path)); + + // Delete the older checkpoints. + std::vector filenames; + TF_RETURN_IF_ERROR(tensorflow::Env::Default()->GetMatchingPaths( + tensorflow::io::JoinPath(root_dir_, "*"), &filenames)); + std::sort(filenames.begin(), filenames.end()); + int history_counter = 0; + for (auto it = filenames.rbegin(); it != filenames.rend(); it++) { + if (++history_counter > keep_latest) { + tensorflow::int64 undeleted_files; + tensorflow::int64 undeleted_dirs; + TF_RETURN_IF_ERROR(tensorflow::Env::Default()->DeleteRecursively( + *it, &undeleted_files, &undeleted_dirs)); + } + } + + *path = std::move(dir_path); + return tensorflow::Status::OK(); +} + +tensorflow::Status TFRecordCheckpointer::Load( + absl::string_view relative_path, ChunkStore* chunk_store, + std::vector>* tables) { + const std::string dir_path = + tensorflow::io::JoinPath(root_dir_, relative_path); + REVERB_LOG(REVERB_INFO) << "Loading checkpoint from " << dir_path; + if (!HasDone(dir_path)) { + return tensorflow::errors::InvalidArgument( + absl::StrCat("Load called with invalid checkpoint path: ", dir_path)); + } + // Insert data first to ensure that all data referenced by the priority tables + // exists. Keep the map of chunks around so that none of the chunks are + // cleaned up before all the priority tables have been loaded. 
+ absl::flat_hash_map> + chunk_by_key; + { + RecordReaderUniquePtr chunk_reader; + TF_RETURN_IF_ERROR(OpenReader( + tensorflow::io::JoinPath(dir_path, kChunksFileName), &chunk_reader)); + + ChunkData chunk_data; + tensorflow::Status chunk_status; + tensorflow::uint64 chunk_offset = 0; + tensorflow::tstring chunk_record; + do { + chunk_status = chunk_reader->ReadRecord(&chunk_offset, &chunk_record); + if (!chunk_status.ok()) break; + if (!chunk_data.ParseFromArray(chunk_record.data(), + chunk_record.size())) { + return tensorflow::errors::DataLoss( + "Could not parse TFRecord as ChunkData: '", chunk_record, "'"); + } + chunk_by_key[chunk_data.chunk_key()] = chunk_store->Insert(chunk_data); + } while (chunk_status.ok()); + if (!tensorflow::errors::IsOutOfRange(chunk_status)) { + return chunk_status; + } + } + + RecordReaderUniquePtr table_reader; + TF_RETURN_IF_ERROR(OpenReader( + tensorflow::io::JoinPath(dir_path, kTablesFileName), &table_reader)); + + PriorityTableCheckpoint checkpoint; + tensorflow::Status table_status; + tensorflow::uint64 table_offset = 0; + tensorflow::tstring table_record; + do { + table_status = table_reader->ReadRecord(&table_offset, &table_record); + if (!table_status.ok()) break; + if (!checkpoint.ParseFromArray(table_record.data(), table_record.size())) { + return tensorflow::errors::DataLoss( + "Could not parse TFRecord as Checkpoint: '", table_record, "'"); + } + + int index = find_table_index(tables, checkpoint.table_name()); + if (index == -1) { + std::vector table_names; + for (const auto& table : *tables) { + table_names.push_back(absl::StrCat("'", table->name(), "'")); + } + return tensorflow::errors::InvalidArgument( + "Trying to load table ", checkpoint.table_name(), + " but table was not found in provided list of tables. 
Available " + "tables: [", + absl::StrJoin(table_names, ", "), "]"); + } + + auto sampler = MakeDistribution(checkpoint.sampler()); + auto remover = MakeDistribution(checkpoint.remover()); + auto rate_limiter = + std::make_shared(checkpoint.rate_limiter()); + + auto table = std::make_shared( + /*name=*/checkpoint.table_name(), /*sampler=*/std::move(sampler), + /*remover=*/std::move(remover), + /*max_size=*/checkpoint.max_size(), + /*max_times_sampled=*/checkpoint.max_times_sampled(), + /*rate_limiter=*/std::move(rate_limiter), + /*extensions=*/tables->at(index)->extensions()); + + for (const auto& checkpoint_item : checkpoint.items()) { + PriorityTable::Item insert_item; + insert_item.item = checkpoint_item; + + for (const auto& key : checkpoint_item.chunk_keys()) { + REVERB_CHECK(chunk_by_key.contains(key)); + insert_item.chunks.push_back(chunk_by_key[key]); + } + + TF_RETURN_IF_ERROR(table->InsertCheckpointItem(std::move(insert_item))); + } + + tables->at(index).swap(table); + } while (table_status.ok()); + + if (!tensorflow::errors::IsOutOfRange(table_status)) { + return table_status; + } + return tensorflow::Status::OK(); +} + +tensorflow::Status TFRecordCheckpointer::LoadLatest( + ChunkStore* chunk_store, + std::vector>* tables) { + REVERB_LOG(REVERB_INFO) << "Loading latest checkpoint from " << root_dir_; + std::vector filenames; + TF_RETURN_IF_ERROR(tensorflow::Env::Default()->GetMatchingPaths( + tensorflow::io::JoinPath(root_dir_, "*"), &filenames)); + std::sort(filenames.begin(), filenames.end()); + for (auto it = filenames.rbegin(); it != filenames.rend(); it++) { + if (HasDone(*it)) { + return Load(tensorflow::io::Basename(*it), chunk_store, tables); + } + } + return tensorflow::errors::NotFound( + absl::StrCat("No checkpoint found in ", root_dir_)); +} + +} // namespace reverb +} // namespace deepmind diff --git a/reverb/cc/platform/tfrecord_checkpointer.h b/reverb/cc/platform/tfrecord_checkpointer.h new file mode 100644 index 0000000..7575395 --- /dev/null +++ b/reverb/cc/platform/tfrecord_checkpointer.h @@ -0,0 +1,100 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef REVERB_CC_PLATFORM_TFRECORD_CHECKPOINTER_H_ +#define REVERB_CC_PLATFORM_TFRECORD_CHECKPOINTER_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "reverb/cc/checkpointing/interface.h" +#include "reverb/cc/chunk_store.h" +#include "reverb/cc/priority_table.h" +#include "tensorflow/core/lib/core/status.h" + +namespace deepmind { +namespace reverb { + +// Generates and stores proto checkpoints of PriorityTables and ChunkStore data +// to a directory inside the top level `root_dir`. +// +// A set of PriorityTable constitutes the bases for a checkpoint. When `Save` is +// called state of each PriorityTable is encoded into a PriorityTableCheckpoint. 
+// The proto contains the state and initialization options of the table itself +// and all its dependencies (RateLimiter, KeyDistribution etc) but does not +// include the actual data. Instead a container with shared_ptr to every +// referenced ChunkStore::Chunk is attached which ensures that all data remains +// for the complete duration of the checkpointing operation. +// +// To avoid duplicating data, the union of the referenced chunks are +// deduplicated before being stored to disk. The stored checkpoint has the +// following format: +// +// / +// / +// tables.tfrecord +// chunks.tfrecord +// DONE +// +// DONE an empty file written once the checkpoint has been successfully written. +// If DONE does not exist then the checkpoint is in process of being written or +// the operation was unexpectedly interrupted and the data should be considered +// corrupt. +// +// The most recent checkpoint can therefore be inferred from the name of the +// directories within `root_dir`. +// +// If `group` is nonempty then the directory containing the checkpoint will be +// created with `group` as group. +class TFRecordCheckpointer : public CheckpointerInterface { + public: + explicit TFRecordCheckpointer(std::string root_dir, std::string group = ""); + + // Save a new checkpoint for every table in `tables` in sub directory + // inside `root_dir_`. If the call is successful, the ABSOLUTE path to the + // newly created checkpoint directory is returned. + // + // If `root_path_` does not exist then `Save` attempts to recursively + // create it before proceeding. + // + // After a successful save, all but the `keep_latest` most recent checkpoints + // are deleted. + tensorflow::Status Save(std::vector tables, int keep_latest, + std::string* path) override; + + // Attempts to load a checkpoint stored within `root_dir_`. + tensorflow::Status Load( + absl::string_view relative_path, ChunkStore* chunk_store, + std::vector>* tables) override; + + // Finds the most recent checkpoint within `root_dir_` and calls `Load`. + tensorflow::Status LoadLatest( + ChunkStore* chunk_store, + std::vector>* tables) override; + + // TFRecordCheckpointer is neither copyable nor movable. + TFRecordCheckpointer(const TFRecordCheckpointer&) = delete; + TFRecordCheckpointer& operator=(const TFRecordCheckpointer&) = delete; + + private: + const std::string root_dir_; + const std::string group_; +}; + +} // namespace reverb +} // namespace deepmind + +#endif // REVERB_CC_PLATFORM_TFRECORD_CHECKPOINTER_H_ diff --git a/reverb/cc/platform/tfrecord_checkpointer_test.cc b/reverb/cc/platform/tfrecord_checkpointer_test.cc new file mode 100644 index 0000000..a1bd403 --- /dev/null +++ b/reverb/cc/platform/tfrecord_checkpointer_test.cc @@ -0,0 +1,233 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
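The tests below exercise the full save/load cycle; as a compact, hedged reference for the API they cover (the root path is illustrative, and MakeUniformTable refers to the helper defined further down in this file):

    ChunkStore chunk_store;
    std::vector<std::shared_ptr<PriorityTable>> tables;
    tables.push_back(MakeUniformTable("uniform"));  // name must match the checkpoint

    TFRecordCheckpointer checkpointer("/tmp/reverb_checkpoints");
    std::string path;
    TF_CHECK_OK(checkpointer.Save({tables[0].get()}, /*keep_latest=*/3, &path));
    // Later, typically with freshly constructed (empty) tables of the same names:
    TF_CHECK_OK(checkpointer.LoadLatest(&chunk_store, &tables));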
+ +#include "reverb/cc/platform/tfrecord_checkpointer.h" + +#include +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "reverb/cc/chunk_store.h" +#include "reverb/cc/distributions/fifo.h" +#include "reverb/cc/distributions/heap.h" +#include "reverb/cc/distributions/prioritized.h" +#include "reverb/cc/distributions/uniform.h" +#include "reverb/cc/priority_table.h" +#include "reverb/cc/rate_limiter.h" +#include "reverb/cc/testing/proto_test_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/env.h" + +namespace deepmind { +namespace reverb { +namespace { + +using ::deepmind::reverb::testing::EqualsProto; + +inline std::string MakeRoot() { + std::string name; + REVERB_CHECK(tensorflow::Env::Default()->LocalTempFilename(&name)); + return name; +} + +std::unique_ptr MakeUniformTable(const std::string& name) { + return absl::make_unique( + name, absl::make_unique(), + absl::make_unique(), 1000, 0, + absl::make_unique(1.0, 1, -DBL_MAX, DBL_MAX)); +} + +std::unique_ptr MakePrioritizedTable(const std::string& name, + double exponent) { + return absl::make_unique( + name, absl::make_unique(exponent), + absl::make_unique(), 1000, 0, + absl::make_unique(1.0, 1, -DBL_MAX, DBL_MAX)); +} + +TEST(TFRecordCheckpointerTest, CreatesDirectoryInRoot) { + std::string root = MakeRoot(); + TFRecordCheckpointer checkpointer(root); + std::string path; + auto* env = tensorflow::Env::Default(); + TF_ASSERT_OK(checkpointer.Save(std::vector{}, 1, &path)); + ASSERT_EQ(tensorflow::io::Dirname(path), root); + TF_EXPECT_OK(env->FileExists(path)); +} + +TEST(TFRecordCheckpointerTest, SaveAndLoad) { + ChunkStore chunk_store; + + std::vector> tables; + tables.push_back(MakeUniformTable("uniform")); + tables.push_back(MakePrioritizedTable("prioritized_a", 0.5)); + tables.push_back(MakePrioritizedTable("prioritized_b", 0.9)); + + std::vector chunk_keys; + for (int i = 0; i < 100; i++) { + for (int j = 0; j < tables.size(); j++) { + chunk_keys.push_back((j + 1) * 1000 + i); + auto chunk = + chunk_store.Insert(testing::MakeChunkData(chunk_keys.back())); + TF_EXPECT_OK(tables[j]->InsertOrAssign( + {testing::MakePrioritizedItem(i, i, {chunk->data()}), {chunk}})); + } + } + + for (int i = 0; i < 100; i++) { + for (auto& table : tables) { + PriorityTable::SampledItem sample; + TF_EXPECT_OK(table->Sample(&sample)); + } + } + + TFRecordCheckpointer checkpointer(MakeRoot()); + + std::string path; + TF_ASSERT_OK(checkpointer.Save( + {tables[0].get(), tables[1].get(), tables[2].get()}, 1, &path)); + + ChunkStore loaded_chunk_store; + std::vector> loaded_tables; + loaded_tables.push_back(MakeUniformTable("uniform")); + loaded_tables.push_back(MakePrioritizedTable("prioritized_a", 0.5)); + loaded_tables.push_back(MakePrioritizedTable("prioritized_b", 0.9)); + TF_ASSERT_OK(checkpointer.Load(tensorflow::io::Basename(path), + &loaded_chunk_store, &loaded_tables)); + + // Check that all the chunks have been added. + std::vector> chunks; + TF_EXPECT_OK(loaded_chunk_store.Get(chunk_keys, &chunks)); + + // Check that the number of items matches for the loaded tables. + for (int i = 0; i < tables.size(); i++) { + EXPECT_EQ(loaded_tables[i]->size(), tables[i]->size()); + } + + // Sample a random item and check that it matches the item in the original + // table. 
+ for (int i = 0; i < tables.size(); i++) { + PriorityTable::SampledItem sample; + TF_EXPECT_OK(loaded_tables[i]->Sample(&sample)); + bool item_found = false; + for (auto& item : tables[i]->Copy()) { + if (item.item.key() == sample.item.key()) { + item_found = true; + item.item.set_times_sampled(item.item.times_sampled() + 1); + EXPECT_THAT(item.item, EqualsProto(sample.item)); + break; + } + } + EXPECT_TRUE(item_found); + } +} + +TEST(TFRecordCheckpointerTest, SaveDeletesOldData) { + ChunkStore chunk_store; + + std::vector> tables; + tables.push_back(MakeUniformTable("uniform")); + tables.push_back(MakePrioritizedTable("prioritized_a", 0.5)); + tables.push_back(MakePrioritizedTable("prioritized_b", 0.9)); + + std::vector chunk_keys; + for (int i = 0; i < 100; i++) { + for (int j = 0; j < tables.size(); j++) { + chunk_keys.push_back((j + 1) * 1000 + i); + auto chunk = + chunk_store.Insert(testing::MakeChunkData(chunk_keys.back())); + TF_EXPECT_OK(tables[j]->InsertOrAssign( + {testing::MakePrioritizedItem(i, i, {chunk->data()}), {chunk}})); + } + } + + for (int i = 0; i < 100; i++) { + for (auto& table : tables) { + PriorityTable::SampledItem sample; + TF_EXPECT_OK(table->Sample(&sample)); + } + } + + auto test = [&tables](int keep_latest) { + auto root = MakeRoot(); + TFRecordCheckpointer checkpointer(root); + + for (int i = 0; i < 10; i++) { + std::string path; + TF_ASSERT_OK( + checkpointer.Save({tables[0].get(), tables[1].get(), tables[2].get()}, + keep_latest, &path)); + + std::vector filenames; + TF_ASSERT_OK(tensorflow::Env::Default()->GetMatchingPaths( + tensorflow::io::JoinPath(root, "*"), &filenames)); + ASSERT_EQ(filenames.size(), std::min(keep_latest, i + 1)); + } + }; + test(1); // Keep one checkpoint. + test(3); // Edge case keep_latest == num_tables + test(5); // Edge case keep_latest > num_tables +} + +TEST(TFRecordCheckpointerTest, KeepLatestZeroReturnsError) { + ChunkStore chunk_store; + + std::vector> tables; + tables.push_back(MakeUniformTable("uniform")); + tables.push_back(MakePrioritizedTable("prioritized_a", 0.5)); + tables.push_back(MakePrioritizedTable("prioritized_b", 0.9)); + + std::vector chunk_keys; + for (int i = 0; i < 100; i++) { + for (int j = 0; j < tables.size(); j++) { + chunk_keys.push_back((j + 1) * 1000 + i); + auto chunk = + chunk_store.Insert(testing::MakeChunkData(chunk_keys.back())); + TF_EXPECT_OK(tables[j]->InsertOrAssign( + {testing::MakePrioritizedItem(i, i, {chunk->data()}), {chunk}})); + } + } + + for (int i = 0; i < 100; i++) { + for (auto& table : tables) { + PriorityTable::SampledItem sample; + TF_EXPECT_OK(table->Sample(&sample)); + } + } + + TFRecordCheckpointer checkpointer(MakeRoot()); + std::string path; + EXPECT_EQ( + checkpointer + .Save({tables[0].get(), tables[1].get(), tables[2].get()}, 0, &path) + .code(), + tensorflow::error::INVALID_ARGUMENT); +} + +TEST(TFRecordCheckpointerTest, LoadLatestInEmptyDir) { + TFRecordCheckpointer checkpointer(MakeRoot()); + ChunkStore chunk_store; + std::vector> tables; + EXPECT_EQ(checkpointer.LoadLatest(&chunk_store, &tables).code(), + tensorflow::error::NOT_FOUND); +} + +} // namespace +} // namespace reverb +} // namespace deepmind diff --git a/reverb/cc/platform/thread.h b/reverb/cc/platform/thread.h new file mode 100644 index 0000000..5d9def7 --- /dev/null +++ b/reverb/cc/platform/thread.h @@ -0,0 +1,58 @@ +// Copyright 2019 DeepMind Technologies Limited. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Used to switch between different threading implementations. Concretely we +// switch between Google internal threading libraries and std::thread. + +#ifndef REVERB_CC_SUPPORT_THREAD_H_ +#define REVERB_CC_SUPPORT_THREAD_H_ + +#include +#include + +#include "absl/strings/string_view.h" + +namespace deepmind { +namespace reverb { +namespace internal { + +// The `Thread` class can be subclassed to hold an object that invokes a +// method in a separate thread. A `Thread` is considered to be active after +// construction until the execution terminates. Calling the destructor of this +// class must join the separate thread and block until it has completed. +// Use `StartThread()` to create an instance of this class. +class Thread { + public: + // Joins the running thread, i.e. blocks until the thread function has + // returned. + virtual ~Thread() = default; + + // A Thread is not copyable. + Thread(const Thread&) = delete; + Thread& operator=(const Thread&) = delete; + + protected: + Thread() = default; +}; + +// Starts a new thread that executes (a copy of) fn. The `name_prefix` may be +// used by the implementation to label the new thread. +std::unique_ptr StartThread(absl::string_view name_prefix, + std::function fn); + +} // namespace internal +} // namespace reverb +} // namespace deepmind + +#endif // REVERB_CC_SUPPORT_THREAD_H_ diff --git a/reverb/cc/platform/thread_test.cc b/reverb/cc/platform/thread_test.cc new file mode 100644 index 0000000..5a8b7b7 --- /dev/null +++ b/reverb/cc/platform/thread_test.cc @@ -0,0 +1,40 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "reverb/cc/platform/thread.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/synchronization/notification.h" + +namespace deepmind { +namespace reverb { +namespace internal { +namespace { + +TEST(ThreadStdTest, ThreadRuns) { + absl::Notification n; + int x; + auto t = StartThread("", [&n, &x] { + x = 7; + n.Notify(); + }); + n.WaitForNotification(); + EXPECT_EQ(x, 7); +} + +} // namespace +} // namespace internal +} // namespace reverb +} // namespace deepmind diff --git a/reverb/cc/priority_table.cc b/reverb/cc/priority_table.cc new file mode 100644 index 0000000..87afc40 --- /dev/null +++ b/reverb/cc/priority_table.cc @@ -0,0 +1,377 @@ +// Copyright 2019 DeepMind Technologies Limited. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "reverb/cc/priority_table.h" + +#include +#include +#include +#include +#include +#include + +#include "google/protobuf/timestamp.pb.h" +#include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/memory/memory.h" +#include "absl/synchronization/mutex.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "reverb/cc/checkpointing/checkpoint.pb.h" +#include "reverb/cc/chunk_store.h" +#include "reverb/cc/distributions/interface.h" +#include "reverb/cc/platform/logging.h" +#include "reverb/cc/rate_limiter.h" +#include "reverb/cc/schema.pb.h" +#include "reverb/cc/table_extensions/interface.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" + +namespace deepmind { +namespace reverb { +namespace { + +using Extensions = + std::vector>; + +inline bool IsAdjacent(const SequenceRange& a, const SequenceRange& b) { + return a.episode_id() == b.episode_id() && a.end() + 1 == b.start(); +} + +inline bool IsInsertedBefore(const PrioritizedItem& a, + const PrioritizedItem& b) { + return a.inserted_at().seconds() < b.inserted_at().seconds() || + (a.inserted_at().seconds() == b.inserted_at().seconds() && + a.inserted_at().nanos() < b.inserted_at().nanos()); +} + +inline void EncodeAsTimestampProto(absl::Time t, + google::protobuf::Timestamp* proto) { + const int64_t s = absl::ToUnixSeconds(t); + proto->set_seconds(s); + proto->set_nanos((t - absl::FromUnixSeconds(s)) / absl::Nanoseconds(1)); +} + +} // namespace + +PriorityTable::PriorityTable( + std::string name, std::shared_ptr sampler, + std::shared_ptr remover, int64_t max_size, + int32_t max_times_sampled, std::shared_ptr rate_limiter, + Extensions extensions, + absl::optional signature) + : sampler_(std::move(sampler)), + remover_(std::move(remover)), + max_size_(max_size), + max_times_sampled_(max_times_sampled), + name_(std::move(name)), + rate_limiter_(std::move(rate_limiter)), + extensions_(std::move(extensions)), + signature_(std::move(signature)) { + TF_CHECK_OK(rate_limiter_->RegisterPriorityTable(this)); +} + +PriorityTable::~PriorityTable() { + rate_limiter_->UnregisterPriorityTable(&mu_, this); +} + +std::vector PriorityTable::Copy(size_t count) const { + std::vector items; + absl::ReaderMutexLock lock(&mu_); + items.reserve(count == 0 ? data_.size() : count); + for (auto it = data_.cbegin(); + it != data_.cend() && (count == 0 || items.size() < count); it++) { + items.push_back(it->second); + } + return items; +} + +tensorflow::Status PriorityTable::InsertOrAssign(Item item) { + auto key = item.item.key(); + auto priority = item.item.priority(); + + absl::WriterMutexLock lock(&mu_); + + /// If item already exists in table then update its priority. + if (data_.contains(key)) { + return UpdateItem(key, priority, /*diffuse=*/true); + } + + // Wait for the insert to be staged. 
While waiting the lock is released but + // once it returns the lock is aquired again. While waiting for the right to + // insert the operation might have transformed into an update. + TF_RETURN_IF_ERROR(rate_limiter_->AwaitCanInsert(&mu_)); + + if (data_.contains(key)) { + return UpdateItem(key, priority, /*diffuse=*/true); + } + + // Set the insertion timestamp after the lock has been acquired as this + // represents the order it was inserted into the sampler and remover. + EncodeAsTimestampProto(absl::Now(), item.item.mutable_inserted_at()); + data_[key] = std::move(item); + + // TODO(b/154929932): If these fail, the rate limiter becomes out of sync. + TF_RETURN_IF_ERROR(sampler_->Insert(key, priority)); + TF_RETURN_IF_ERROR(remover_->Insert(key, priority)); + + auto it = data_.find(key); + for (auto& extension : extensions_) { + extension->OnInsert(it->second); + } + + // Remove an item if we exceeded `max_size_`. + if (data_.size() > max_size_) { + DeleteItem(remover_->Sample().key); + } + + // Now that the new item has been inserted and an older item has + // (potentially) been removed the insert can be finalized. + rate_limiter_->Insert(&mu_); + + return tensorflow::Status::OK(); +} + +tensorflow::Status PriorityTable::MutateItems( + absl::Span updates, absl::Span deletes) { + absl::WriterMutexLock lock(&mu_); + + for (Key key : deletes) { + DeleteItem(key); + } + + for (const auto& item : updates) { + TF_RETURN_IF_ERROR( + UpdateItem(item.key(), item.priority(), /*diffuse=*/true)); + } + + return tensorflow::Status::OK(); +} + +tensorflow::Status PriorityTable::Sample(SampledItem* sampled_item) { + absl::WriterMutexLock lock(&mu_); + TF_RETURN_IF_ERROR(rate_limiter_->AwaitAndFinalizeSample(&mu_)); + + KeyDistributionInterface::KeyWithProbability sample = sampler_->Sample(); + Item& item = data_.at(sample.key); + + // Increment the sample count. + item.item.set_times_sampled(item.item.times_sampled() + 1); + + // Copy Details of the sampled item. + sampled_item->item = item.item; + sampled_item->chunks = item.chunks; + sampled_item->probability = sample.probability; + sampled_item->table_size = data_.size(); + + // Notify extensions which item was sampled. + for (auto& extension : extensions_) { + extension->OnSample(item); + } + + // If there is an upper bound of the number of times an item can be sampled + // and it is now reached then delete the item before the lock is released. 
+ if (item.item.times_sampled() == max_times_sampled_) { + DeleteItem(item.item.key()); + } + + return tensorflow::Status::OK(); +} + +int64_t PriorityTable::size() const { + absl::ReaderMutexLock lock(&mu_); + return data_.size(); +} + +const std::string& PriorityTable::name() const { return name_; } + +TableInfo PriorityTable::info() const { + absl::ReaderMutexLock lock(&mu_); + TableInfo info; + info.set_name(name_); + info.set_max_size(max_size_); + info.set_max_times_sampled(max_times_sampled_); + *info.mutable_rate_limiter_info() = rate_limiter_->info(); + + if (signature_) { + *info.mutable_signature() = *signature_; + } + + *info.mutable_sampler_options() = sampler_->options(); + *info.mutable_remover_options() = remover_->options(); + info.set_current_size(data_.size()); + + return info; +} + +void PriorityTable::Close() { + absl::WriterMutexLock lock(&mu_); + rate_limiter_->Cancel(&mu_); +} + +void PriorityTable::DeleteItem(PriorityTable::Key key) { + auto it = data_.find(key); + if (it == data_.end()) return; + + for (auto& extension : extensions_) { + extension->OnDelete(it->second); + } + + data_.erase(it); + rate_limiter_->Delete(&mu_); + TF_CHECK_OK(sampler_->Delete(key)); + TF_CHECK_OK(remover_->Delete(key)); +} + +tensorflow::Status PriorityTable::UpdateItem(Key key, double priority, + bool diffuse) { + auto it = data_.find(key); + if (it == data_.end()) { + return tensorflow::Status::OK(); + } + const double old_priority = it->second.item.priority(); + it->second.item.set_priority(priority); + TF_RETURN_IF_ERROR(sampler_->Update(key, priority)); + TF_RETURN_IF_ERROR(remover_->Update(key, priority)); + + for (auto& extension : extensions_) { + extension->OnUpdate(it->second); + } + + if (diffuse) { + for (auto& extension : extensions_) { + for (const auto& diffused_item : + extension->Diffuse(this, it->second, old_priority)) { + TF_RETURN_IF_ERROR(UpdateItem( + diffused_item.key(), diffused_item.priority(), /*diffuse=*/false)); + } + } + } + + return tensorflow::Status::OK(); +} + +tensorflow::Status PriorityTable::Reset() { + absl::WriterMutexLock lock(&mu_); + + for (auto& extension : extensions_) { + extension->OnReset(); + } + + sampler_->Clear(); + remover_->Clear(); + + data_.clear(); + + rate_limiter_->Reset(&mu_); + + return tensorflow::Status::OK(); +} + +PriorityTable::CheckpointAndChunks PriorityTable::Checkpoint() { + absl::ReaderMutexLock lock(&mu_); + + PriorityTableCheckpoint checkpoint; + checkpoint.set_table_name(name()); + checkpoint.set_max_size(max_size_); + checkpoint.set_max_times_sampled(max_times_sampled_); + *checkpoint.mutable_sampler() = sampler_->options(); + *checkpoint.mutable_remover() = remover_->options(); + + // Note that is is important that the rate limiter checkpoint is + // finalized before the items are added + *checkpoint.mutable_rate_limiter() = rate_limiter_->CheckpointReader(&mu_); + + absl::flat_hash_set> chunks; + for (const auto& entry : data_) { + *checkpoint.add_items() = entry.second.item; + chunks.insert(entry.second.chunks.begin(), entry.second.chunks.end()); + } + + // Sort the items in ascending order based on their insertion time. This makes + // it possible to reconstruct ordered structures (Fifo) when the checkpoint is + // loaded. 
+ std::sort(checkpoint.mutable_items()->begin(), + checkpoint.mutable_items()->end(), IsInsertedBefore); + + return {std::move(checkpoint), std::move(chunks)}; +} + +tensorflow::Status PriorityTable::InsertCheckpointItem( + PriorityTable::Item item) { + absl::WriterMutexLock lock(&mu_); + REVERB_CHECK_LE(data_.size() + 1, max_size_) + << "InsertCheckpointItem called on already full PriorityTable"; + REVERB_CHECK(!data_.contains(item.item.key())) + << "InsertCheckpointItem called for item with already present key: " + << item.item.key(); + + TF_RETURN_IF_ERROR(sampler_->Insert(item.item.key(), item.item.priority())); + TF_RETURN_IF_ERROR(remover_->Insert(item.item.key(), item.item.priority())); + + auto it = data_.emplace(item.item.key(), std::move(item)).first; + for (auto& extension : extensions_) { + extension->OnInsert(it->second); + } + + return tensorflow::Status::OK(); +} + +bool PriorityTable::Get(PriorityTable::Key key, PriorityTable::Item* item) { + absl::ReaderMutexLock lock(&mu_); + auto it = data_.find(key); + if (it != data_.end()) { + *item = it->second; + return true; + } + return false; +} + +const absl::flat_hash_map* +PriorityTable::RawLookup() { + mu_.AssertReaderHeld(); + return &data_; +} + +void PriorityTable::UnsafeAddExtension( + std::shared_ptr extension) { + absl::WriterMutexLock lock(&mu_); + REVERB_CHECK(data_.empty()); + extensions_.push_back(std::move(extension)); +} + +const std::vector>& +PriorityTable::extensions() const { + return extensions_; +} + +const absl::optional& PriorityTable::signature() + const { + return signature_; +} + +bool PriorityTable::CanSample(int num_samples) const { + absl::ReaderMutexLock lock(&mu_); + return rate_limiter_->CanSample(&mu_, num_samples); +} + +bool PriorityTable::CanInsert(int num_inserts) const { + absl::ReaderMutexLock lock(&mu_); + return rate_limiter_->CanInsert(&mu_, num_inserts); +} + +} // namespace reverb +} // namespace deepmind diff --git a/reverb/cc/priority_table.h b/reverb/cc/priority_table.h new file mode 100644 index 0000000..c860578 --- /dev/null +++ b/reverb/cc/priority_table.h @@ -0,0 +1,252 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
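Before the declaration itself, a hedged sketch of the basic insert-and-sample flow against the API declared below (construction is sketched separately next to the constructor; MakeItem is a placeholder for code that fills in a PrioritizedItem and its chunks):

    void InsertAndSampleOnce(PriorityTable* table) {
      PriorityTable::Item item = MakeItem();  // hypothetical helper
      TF_CHECK_OK(table->InsertOrAssign(std::move(item)));

      PriorityTable::SampledItem sample;
      TF_CHECK_OK(table->Sample(&sample));
      // `sample.probability` is the sampling probability of the item, not its
      // raw priority; `sample.chunks` holds the data referenced by the item.
    }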
+
+#ifndef REVERB_CC_PRIORITY_TABLE_H_
+#define REVERB_CC_PRIORITY_TABLE_H_
+
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include "absl/base/thread_annotations.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/synchronization/mutex.h"
+#include "absl/types/optional.h"
+#include "absl/types/span.h"
+#include "reverb/cc/checkpointing/checkpoint.pb.h"
+#include "reverb/cc/chunk_store.h"
+#include "reverb/cc/distributions/interface.h"
+#include "reverb/cc/priority_table_item.h"
+#include "reverb/cc/rate_limiter.h"
+#include "reverb/cc/schema.pb.h"
+#include "reverb/cc/table_extensions/interface.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/protobuf/struct.pb.h"
+
+namespace deepmind {
+namespace reverb {
+
+// Maintains a priority distribution over keys used for sampling from the
+// replay. Internally, this container maintains two instances of
+// KeyDistributionInterface, one for sampling and one for removing. The remover
+// is needed to ensure that the size of the container does not grow beyond a
+// given capacity.
+//
+// Please note that the removing implementation only limits the number of items
+// in the priority table, not the number of timesteps (or actual memory) on the
+// server. When we delete an item from a priority table, the reference counts
+// of its chunks decrease and the chunks may become deletable. However, this is
+// not guaranteed, as other priority tables might still hold references to the
+// chunks, in which case no memory is freed up. This means you must be careful
+// when choosing the remover strategy. A dangerous example would be using a
+// FIFO remover for one priority table and then introducing another table with
+// a LIFO remover. In this scenario, the two priority tables would not share
+// any chunks and would thus require twice the amount of storage.
+//
+// All public methods are thread safe.
+class PriorityTable {
+ public:
+  using Key = KeyDistributionInterface::Key;
+  using Item = PriorityTableItem;
+
+  // Used as the return value of Sample(). Note that this contains the
+  // probability of the item as opposed to its raw priority value.
+  struct SampledItem {
+    PrioritizedItem item;
+    std::vector<std::shared_ptr<ChunkStore::Chunk>> chunks;
+    double probability;
+    int64_t table_size;
+  };
+
+  // Used when checkpointing to ensure that none of the chunks referenced by
+  // the checkpointed items are removed before the checkpoint operation has
+  // completed.
+  struct CheckpointAndChunks {
+    PriorityTableCheckpoint checkpoint;
+    absl::flat_hash_set<std::shared_ptr<ChunkStore::Chunk>> chunks;
+  };
+
+  // Constructor.
+  // `name` is the name of the table. Must be unique within the server.
+  // `sampler` is used in Sample() calls, while `remover` is used in
+  // InsertOrAssign() when we need to remove an item to not exceed `max_size`
+  // items in this container.
+  // `max_times_sampled` is the maximum number of times an item may be sampled
+  // before it is deleted. A value <= 0 means there is no limit.
+  // `rate_limiter` controls when sample and insert calls are allowed to
+  // proceed.
+  // `extensions` allows additional features in the table, like time diffusion.
+  // `signature` allows an optional declaration of the data that can be stored
+  // in this table. Writers and readers are responsible for checking against
+  // this signature, as it is available via RPC request.
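+  //
+  // Example (an illustrative sketch only; `UniformDistribution` and
+  // `FifoDistribution` are assumed to be the distributions declared in
+  // distributions/uniform.h and distributions/fifo.h):
+  //
+  //   PriorityTable table(
+  //       /*name=*/"my_table",
+  //       /*sampler=*/absl::make_unique<UniformDistribution>(),
+  //       /*remover=*/absl::make_unique<FifoDistribution>(),
+  //       /*max_size=*/1000,
+  //       /*max_times_sampled=*/0,
+  //       /*rate_limiter=*/std::make_shared<RateLimiter>(
+  //           /*samples_per_insert=*/1.0, /*min_size_to_sample=*/1,
+  //           /*min_diff=*/-DBL_MAX, /*max_diff=*/DBL_MAX));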
+ PriorityTable( + std::string name, std::shared_ptr sampler, + std::shared_ptr remover, int64_t max_size, + int32_t max_times_sampled, std::shared_ptr rate_limiter, + std::vector> extensions = + {}, + absl::optional signature = absl::nullopt); + + ~PriorityTable(); + + // Copies at most `count` items that are currently in the table. + // If `count` is `0` (default) then all items are copied. + // If `count` is less than `size` then a subset is selected with in an + // undefined manner. + std::vector Copy(size_t count = 0) const; + + // Attempts to insert an item into the priority distribution. If the item + // already exists, the existing item is updated. Also applies the necessary + // updates to sampler and remover. + // + // This call also ensures that the container does not grow larger than + // `max_size`. If an insertion causes the container to exceed `max_size_`, one + // item is removed with the strategy specified through `remover_`. Please note + // that we insert the new item that exceeds the capacity BEFORE we run the + // remover. This means that the newly inserted item could be deleted right + // away. + tensorflow::Status InsertOrAssign(Item item); + + // Inserts an item without consulting or modifying the RateLimiter about the + // operation. + // + // This should ONLY be used when restoring a PriorityTable from a checkpoint. + tensorflow::Status InsertCheckpointItem(Item item); + + // Updates the priority or deletes items in this priority distribution. All + // operations in the arguments are applied in the order that they are listed. + // Different operations can be set at the same time. Ignores non existing keys + // but returns any other errors. The operations might be applied partially + // when an error occurs. + tensorflow::Status MutateItems(absl::Span updates, + absl::Span deletes); + + // Attempts to sample an item from this distribution with the sampling + // strategy passed in the constructor. We only allow the sample operation if + // the `rate_limiter_` allows it. If the item has reached + // `max_times_sampled_`, then we delete it before returning so it cannot be + // sampled again. + tensorflow::Status Sample(SampledItem* item); + + // Returns true iff the current state would allow for `num_samples` to be + // sampled. Dies if `num_samples` is < 1. + // + // TODO(b/153258711): This currently ignores max_size and max_times_sampled + // arguments to the PriorityTable, and will return True if e.g. there are + // 2 items in the table, max_times_sampled=1, and num_samples=3. + bool CanSample(int num_samples) const; + + // Returns true iff the current state would allow for `num_inserts` to be + // inserted. Dies if `num_inserts` is < 1. + // + // TODO(b/153258711): This currently ignores max_size and max_times_sampled + // arguments to the PriorityTable. + bool CanInsert(int num_inserts) const; + + // Appends the extension to the internal list. Note that this must be called + // before any other operation is called. If called when the number of items + // is non zero, death is triggered. + // + // Note! This method is not thread safe and caller is responsible for making + // sure that this method, nor any other method, is called concurrently. + void UnsafeAddExtension( + std::shared_ptr extension); + + // Registered table extensions. + const std::vector>& + extensions() const; + + // Lookup a single item. Returns true if found, else false. + bool Get(Key key, Item* item) ABSL_LOCKS_EXCLUDED(mu_); + + // Get pointer to `data_`. 
Must only be called by extensions while lock held. + const absl::flat_hash_map* RawLookup() + ABSL_ASSERT_SHARED_LOCK(mu_); + + // Removes all items and resets the RateLimiter to its initial state. + tensorflow::Status Reset(); + + // Generate a checkpoint from the PriorityTable's current state. + CheckpointAndChunks Checkpoint(); + + // Number of items in the priority distribution. + int64_t size() const; + + const std::string& name() const; + + // Metadata about the table. + TableInfo info() const; + + // Signature (if any) of the table. + const absl::optional& signature() const; + + // Cancels pending calls and marks object as closed. Object must be abandoned + // after `Close` called. + void Close(); + + private: + tensorflow::Status UpdateItem(Key key, double priority, bool diffuse) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Deletes the item associated with the key from `data_`, `sampler_` and + // `remover_`. Ignores the key if it cannot be found. + void DeleteItem(Key key) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Distribution used for sampling. + std::shared_ptr sampler_ ABSL_GUARDED_BY(mu_); + + // Distribution used for removing. + std::shared_ptr remover_ ABSL_GUARDED_BY(mu_); + + // Bijection of key to item. Used for storing the chunks and timestep range of + // each item. + absl::flat_hash_map data_ ABSL_GUARDED_BY(mu_); + + // Maximum number of items that this container can hold. InsertOrAssign() + // respects this limit when inserting a new item. + const int64_t max_size_; + + // Maximum number of times an item can be sampled before it is deleted. + // A value <= 0 means there is no limit. + const int32_t max_times_sampled_; + + // Name of the table. + const std::string name_; + + // Controls what operations can proceed. A shared_ptr is used to allow the + // Python layer to interact with the object after it has been passed to the + // PriorityTable. + std::shared_ptr rate_limiter_ ABSL_GUARDED_BY(mu_); + + // Extensions implement hooks that are executed while holding `mu_` as part + // of insert, update or delete operation. + std::vector> extensions_ + ABSL_GUARDED_BY(mu_); + + // Synchronizes access to `sampler_`, `remover_`, 'rate_limiter_`, + // 'extensions_` and `data_`, + mutable absl::Mutex mu_; + + // Optional signature for data in the table. + const absl::optional signature_; +}; + +} // namespace reverb +} // namespace deepmind + +#endif // REVERB_CC_PRIORITY_TABLE_H_ diff --git a/reverb/cc/priority_table_item.h b/reverb/cc/priority_table_item.h new file mode 100644 index 0000000..d1f93b3 --- /dev/null +++ b/reverb/cc/priority_table_item.h @@ -0,0 +1,37 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef REVERB_CC_PRIORITY_TABLE_ITEM_H_ +#define REVERB_CC_PRIORITY_TABLE_ITEM_H_ + +#include +#include + +#include "reverb/cc/chunk_store.h" +#include "reverb/cc/schema.pb.h" + +namespace deepmind { +namespace reverb { + +// Used for representing items of the priority distribution. 
See +// PrioritizedItem in schema.proto for documentation. +struct PriorityTableItem { + PrioritizedItem item; + std::vector> chunks; +}; + +} // namespace reverb +} // namespace deepmind + +#endif // REVERB_CC_PRIORITY_TABLE_ITEM_H_ diff --git a/reverb/cc/priority_table_test.cc b/reverb/cc/priority_table_test.cc new file mode 100644 index 0000000..9d67a65 --- /dev/null +++ b/reverb/cc/priority_table_test.cc @@ -0,0 +1,604 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "reverb/cc/priority_table.h" + +#include +#include +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include +#include "absl/memory/memory.h" +#include "absl/synchronization/notification.h" +#include "absl/time/time.h" +#include "reverb/cc/checkpointing/checkpoint.pb.h" +#include "reverb/cc/chunk_store.h" +#include "reverb/cc/distributions/fifo.h" +#include "reverb/cc/distributions/uniform.h" +#include "reverb/cc/platform/thread.h" +#include "reverb/cc/rate_limiter.h" +#include "reverb/cc/schema.pb.h" +#include "reverb/cc/testing/proto_test_util.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace deepmind { +namespace reverb { +namespace { + +const absl::Duration kTimeout = absl::Milliseconds(250); + +using ::deepmind::reverb::testing::Partially; +using ::testing::ElementsAre; +using ::testing::IsEmpty; +using ::testing::SizeIs; + +MATCHER_P(HasItemKey, key, "") { return arg.item.key() == key; } + +PriorityTableItem MakeItem(uint64_t key, double priority, + const std::vector& sequences) { + PriorityTableItem item; + + std::vector data(sequences.size()); + for (int i = 0; i < sequences.size(); i++) { + data[i] = testing::MakeChunkData(key * 100 + i, sequences[i]); + item.chunks.push_back(std::make_shared(data[i])); + } + + item.item = testing::MakePrioritizedItem(key, priority, data); + + return item; +} + +PriorityTableItem MakeItem(uint64_t key, double priority) { + return MakeItem(key, priority, {testing::MakeSequenceRange(key * 100, 0, 1)}); +} + +std::unique_ptr MakeLimiter(int64_t min_size) { + return absl::make_unique(1.0, min_size, -DBL_MAX, DBL_MAX); +} + +std::unique_ptr MakeUniformTable(const std::string& name, + int64_t max_size = 1000, + int32_t max_times_sampled = 0) { + return absl::make_unique( + name, absl::make_unique(), + absl::make_unique(), max_size, max_times_sampled, + MakeLimiter(1)); +} + +TEST(PriorityTableTest, SetsName) { + auto first = MakeUniformTable("first"); + auto second = MakeUniformTable("second"); + EXPECT_EQ(first->name(), "first"); + EXPECT_EQ(second->name(), "second"); +} + +TEST(PriorityTableTest, CopyAfterInsert) { + auto table = MakeUniformTable("dist"); + TF_EXPECT_OK(table->InsertOrAssign(MakeItem(3, 123))); + + auto items = table->Copy(); + ASSERT_THAT(items, SizeIs(1)); + EXPECT_THAT( + items[0].item, + Partially(testing::EqualsProto("key: 3 times_sampled: 0 priority: 123"))); +} + 
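+// A minimal sketch of Copy's documented `count == 0` behaviour (every item is
+// copied); it relies only on the MakeUniformTable and MakeItem helpers defined
+// above.
+TEST(PriorityTableTest, CopyWithZeroCountReturnsAllItems) {
+  auto table = MakeUniformTable("dist");
+  TF_EXPECT_OK(table->InsertOrAssign(MakeItem(1, 123)));
+  TF_EXPECT_OK(table->InsertOrAssign(MakeItem(2, 123)));
+  EXPECT_THAT(table->Copy(0), SizeIs(2));
+  EXPECT_THAT(table->Copy(), SizeIs(2));
+}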
+TEST(PriorityTableTest, CopySubset) { + auto table = MakeUniformTable("dist"); + TF_EXPECT_OK(table->InsertOrAssign(MakeItem(3, 123))); + TF_EXPECT_OK(table->InsertOrAssign(MakeItem(4, 123))); + TF_EXPECT_OK(table->InsertOrAssign(MakeItem(5, 123))); + EXPECT_THAT(table->Copy(1), SizeIs(1)); + EXPECT_THAT(table->Copy(2), SizeIs(2)); +} + +TEST(PriorityTableTest, InsertOrAssignOverwrites) { + auto table = MakeUniformTable("dist"); + TF_EXPECT_OK(table->InsertOrAssign(MakeItem(3, 123))); + TF_EXPECT_OK(table->InsertOrAssign(MakeItem(3, 456))); + + auto items = table->Copy(); + ASSERT_THAT(items, SizeIs(1)); + EXPECT_EQ(items[0].item.priority(), 456); +} + +TEST(PriorityTableTest, UpdatesAreAppliedPartially) { + auto table = MakeUniformTable("dist"); + TF_EXPECT_OK(table->InsertOrAssign(MakeItem(3, 123))); + TF_EXPECT_OK(table->MutateItems( + { + testing::MakeKeyWithPriority(5, 55), + testing::MakeKeyWithPriority(3, 456), + }, + {})); + + auto items = table->Copy(); + ASSERT_THAT(items, SizeIs(1)); + EXPECT_EQ(items[0].item.priority(), 456); +} + +TEST(PriorityTableTest, DeletesAreAppliedPartially) { + auto table = MakeUniformTable("dist"); + TF_EXPECT_OK(table->InsertOrAssign(MakeItem(3, 123))); + TF_EXPECT_OK(table->InsertOrAssign(MakeItem(7, 456))); + TF_EXPECT_OK(table->MutateItems({}, {5, 3})); + EXPECT_THAT(table->Copy(), ElementsAre(HasItemKey(7))); +} + +TEST(PriorityTableTest, SampleBlocksWhenNotEnoughItems) { + auto table = MakeUniformTable("dist"); + + absl::Notification notification; + auto sample_thread = internal::StartThread("", [&table, ¬ification] { + PriorityTable::SampledItem item; + TF_EXPECT_OK(table->Sample(&item)); + notification.Notify(); + }); + + EXPECT_FALSE(notification.WaitForNotificationWithTimeout(kTimeout)); + + // Inserting an item should allow the call to complete. + TF_EXPECT_OK(table->InsertOrAssign(MakeItem(3, 123))); + EXPECT_TRUE(notification.WaitForNotificationWithTimeout(kTimeout)); + + sample_thread = nullptr; // Joins the thread. 
+} + +TEST(PriorityTableTest, SampleMatchesInsert) { + auto table = MakeUniformTable("dist"); + + PriorityTable::Item item = MakeItem(3, 123); + TF_EXPECT_OK(table->InsertOrAssign(item)); + + PriorityTable::SampledItem sample; + TF_EXPECT_OK(table->Sample(&sample)); + item.item.set_times_sampled(1); + sample.item.clear_inserted_at(); + EXPECT_THAT(sample.item, testing::EqualsProto(item.item)); + EXPECT_EQ(sample.chunks, item.chunks); + EXPECT_EQ(sample.probability, 1); +} + +TEST(PriorityTableTest, SampleIncrementsSampleTimes) { + auto table = MakeUniformTable("dist"); + + TF_EXPECT_OK(table->InsertOrAssign(MakeItem(3, 123))); + + PriorityTable::SampledItem item; + EXPECT_EQ(table->Copy()[0].item.times_sampled(), 0); + TF_EXPECT_OK(table->Sample(&item)); + EXPECT_EQ(table->Copy()[0].item.times_sampled(), 1); + TF_EXPECT_OK(table->Sample(&item)); + EXPECT_EQ(table->Copy()[0].item.times_sampled(), 2); +} + +TEST(PriorityTableTest, MaxTimesSampledIsRespected) { + auto table = MakeUniformTable("dist", 10, 2); + + TF_EXPECT_OK(table->InsertOrAssign(MakeItem(3, 123))); + + PriorityTable::SampledItem item; + EXPECT_EQ(table->Copy()[0].item.times_sampled(), 0); + TF_ASSERT_OK(table->Sample(&item)); + EXPECT_EQ(table->Copy()[0].item.times_sampled(), 1); + TF_ASSERT_OK(table->Sample(&item)); + EXPECT_THAT(table->Copy(), IsEmpty()); +} + +TEST(PriorityTableTest, InsertDeletesWhenOverflowing) { + auto table = MakeUniformTable("dist", 10); + + for (int i = 0; i < 15; i++) { + TF_EXPECT_OK(table->InsertOrAssign(MakeItem(i, 123))); + } + auto items = table->Copy(); + EXPECT_THAT(items, SizeIs(10)); + for (const PriorityTable::Item& item : items) { + EXPECT_GE(item.item.key(), 5); + EXPECT_LT(item.item.key(), 15); + } +} + +TEST(PriorityTableTest, ConcurrentCalls) { + auto table = MakeUniformTable("dist", 1000); + + std::vector> bundle; + std::atomic count(0); + for (PriorityTable::Key i = 0; i < 1000; i++) { + bundle.push_back(internal::StartThread("", [i, &table, &count] { + TF_EXPECT_OK(table->InsertOrAssign(MakeItem(i, 123))); + PriorityTable::SampledItem item; + TF_EXPECT_OK(table->Sample(&item)); + TF_EXPECT_OK( + table->MutateItems({testing::MakeKeyWithPriority(i, 456)}, {i})); + count++; + })); + } + bundle.clear(); // Joins all threads. + EXPECT_EQ(count, 1000); +} + +TEST(PriorityTableTest, UseAsQueue) { + PriorityTable queue( + /*name=*/"queue", + /*sampler=*/absl::make_unique(), + /*remover=*/absl::make_unique(), + /*max_size=*/10, + /*max_times_sampled=*/1, + absl::make_unique( + /*samples_per_insert=*/1.0, + /*min_size_to_sample=*/1, + /*min_diff=*/0, + /*max_diff=*/10.0)); + for (int i = 0; i < 10; i++) { + TF_EXPECT_OK(queue.InsertOrAssign(MakeItem(i, 123))); + } + + // This should now be blocked + absl::Notification insert; + auto insert_thread = internal::StartThread("", [&] { + TF_EXPECT_OK(queue.InsertOrAssign(MakeItem(10, 123))); + insert.Notify(); + }); + + EXPECT_FALSE(insert.WaitForNotificationWithTimeout(kTimeout)); + + for (int i = 0; i < 11; i++) { + PriorityTable::SampledItem item; + TF_EXPECT_OK(queue.Sample(&item)); + EXPECT_THAT(item, HasItemKey(i)); + } + + EXPECT_TRUE(insert.WaitForNotificationWithTimeout(kTimeout)); + + insert_thread = nullptr; // Joins the thread. + + EXPECT_EQ(queue.size(), 0); + + // Sampling should now be blocked. 
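+  // (The queue is empty again: every item was sampled once and, because
+  // max_times_sampled == 1, deleted, so the rate limiter's size
+  // `inserts - deletes` is back below min_size_to_sample == 1.)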
+ absl::Notification sample; + auto sample_thread = internal::StartThread("", [&] { + PriorityTable::SampledItem item; + TF_EXPECT_OK(queue.Sample(&item)); + sample.Notify(); + }); + + EXPECT_FALSE(sample.WaitForNotificationWithTimeout(kTimeout)); + + // Inserting a new item should result in it being sampled straight away. + TF_EXPECT_OK(queue.InsertOrAssign(MakeItem(100, 123))); + EXPECT_TRUE(sample.WaitForNotificationWithTimeout(kTimeout)); + + EXPECT_EQ(queue.size(), 0); + + sample_thread = nullptr; // Joins the thread. +} + +TEST(PriorityTableTest, ConcurrentInsertOfTheSameKey) { + PriorityTable table( + /*name=*/"dist", + /*sampler=*/absl::make_unique(), + /*remover=*/absl::make_unique(), + /*max_size=*/1000, + /*max_times_sampled=*/0, + absl::make_unique( + /*samples_per_insert=*/1.0, + /*min_size_to_sample=*/1, + /*min_diff=*/-1, + /*max_diff=*/1)); + + // Insert one item to make new inserts block. + TF_ASSERT_OK(table.InsertOrAssign(MakeItem(1, 123))); // diff = 1.0 + + std::vector> bundle; + + // Try to insert the same item 10 times. All should be blocked. + std::atomic count(0); + for (int i = 0; i < 10; i++) { + bundle.push_back(internal::StartThread("", [&] { + TF_EXPECT_OK(table.InsertOrAssign(MakeItem(10, 123))); + count++; + })); + } + + EXPECT_EQ(count, 0); + + // Making a single sample should unblock one of the inserts. The other inserts + // are now updates but they are still waiting for their right to insert. + PriorityTable::SampledItem item; + TF_EXPECT_OK(table.Sample(&item)); + + // Sampling once more would unblock one of the inserts, it will then see that + // it is now an update and not use its right to insert. Once it releases the + // lock the same process will follow for all the remaining inserts. + TF_EXPECT_OK(table.Sample(&item)); + + bundle.clear(); // Joins all threads. + + EXPECT_EQ(count, 10); + EXPECT_EQ(table.size(), 2); +} + +TEST(PriorityTableTest, CloseCancelsPendingCalls) { + PriorityTable table( + /*name=*/"dist", + /*sampler=*/absl::make_unique(), + /*remover=*/absl::make_unique(), + /*max_size=*/1000, + /*max_times_sampled=*/0, + absl::make_unique( + /*samples_per_insert=*/1.0, + /*min_size_to_sample=*/1, + /*min_diff=*/-1, + /*max_diff=*/1)); + + // Insert two item to make new inserts block. + TF_ASSERT_OK(table.InsertOrAssign(MakeItem(1, 123))); // diff = 1.0 + + tensorflow::Status status; + absl::Notification notification; + auto thread = internal::StartThread("", [&] { + status = table.InsertOrAssign(MakeItem(10, 123)); + notification.Notify(); + }); + + EXPECT_FALSE(notification.WaitForNotificationWithTimeout(kTimeout)); + + table.Close(); + + EXPECT_TRUE(notification.WaitForNotificationWithTimeout(kTimeout)); + EXPECT_EQ(status.code(), tensorflow::error::CANCELLED); + + thread = nullptr; // Joins the thread. +} + +TEST(PriorityTableTest, ResetResetsRateLimiter) { + PriorityTable table( + /*name=*/"dist", + /*sampler=*/absl::make_unique(), + /*remover=*/absl::make_unique(), + /*max_size=*/1000, + /*max_times_sampled=*/0, + absl::make_unique( + /*samples_per_insert=*/1.0, + /*min_size_to_sample=*/1, + /*min_diff=*/-1, + /*max_diff=*/1)); + + // Insert two item to make new inserts block. 
+ TF_ASSERT_OK(table.InsertOrAssign(MakeItem(1, 123))); // diff = 1.0 + + absl::Notification notification; + auto thread = internal::StartThread("", [&] { + TF_ASSERT_OK(table.InsertOrAssign(MakeItem(10, 123))); + notification.Notify(); + }); + + EXPECT_FALSE(notification.WaitForNotificationWithTimeout(kTimeout)); + + // Resetting the table should unblock new inserts. + TF_ASSERT_OK(table.Reset()); + + EXPECT_TRUE(notification.WaitForNotificationWithTimeout(kTimeout)); + + thread = nullptr; // Joins the thread. +} + +TEST(PriorityTableTest, ResetClearsAllData) { + auto table = MakeUniformTable("dist"); + TF_ASSERT_OK(table->InsertOrAssign(MakeItem(1, 123))); + EXPECT_EQ(table->size(), 1); + TF_ASSERT_OK(table->Reset()); + EXPECT_EQ(table->size(), 0); +} + +TEST(PriorityTableTest, ResetWhileConcurrentCalls) { + auto table = MakeUniformTable("dist"); + std::vector> bundle; + for (PriorityTable::Key i = 0; i < 1000; i++) { + bundle.push_back(internal::StartThread("", [i, &table] { + if (i % 123 == 0) TF_EXPECT_OK(table->Reset()); + TF_EXPECT_OK(table->InsertOrAssign(MakeItem(i, 123))); + TF_EXPECT_OK( + table->MutateItems({testing::MakeKeyWithPriority(i, 456)}, {i})); + })); + } + bundle.clear(); // Joins all threads. +} + +TEST(PriorityTableTest, CheckpointOrderItems) { + auto table = MakeUniformTable("dist"); + + TF_EXPECT_OK(table->InsertOrAssign(MakeItem(1, 123))); + TF_EXPECT_OK(table->InsertOrAssign(MakeItem(3, 125))); + TF_EXPECT_OK(table->InsertOrAssign(MakeItem(2, 124))); + + auto checkpoint = table->Checkpoint(); + EXPECT_THAT(checkpoint.checkpoint.items(), + ElementsAre(Partially(testing::EqualsProto("key: 1")), + Partially(testing::EqualsProto("key: 3")), + Partially(testing::EqualsProto("key: 2")))); +} + +TEST(PriorityTableTest, CheckpointSanityCheck) { + PriorityTable table("dist", absl::make_unique(), + absl::make_unique(), 10, 1, + absl::make_unique(1.0, 3, -10, 7)); + + TF_EXPECT_OK(table.InsertOrAssign(MakeItem(1, 123))); + + auto checkpoint = table.Checkpoint(); + + PriorityTableCheckpoint want; + want.set_table_name("dist"); + want.set_max_size(10); + want.set_max_times_sampled(1); + want.add_items()->set_key(1); + want.mutable_rate_limiter()->set_samples_per_insert(1.0); + want.mutable_rate_limiter()->set_min_size_to_sample(3); + want.mutable_rate_limiter()->set_min_diff(-10); + + EXPECT_THAT(checkpoint.checkpoint, + Partially(testing::EqualsProto("table_name: 'dist' " + "max_size: 10 " + "max_times_sampled: 1 " + "items: { key: 1 } " + "rate_limiter: { " + " samples_per_insert: 1.0" + " min_size_to_sample: 3" + " min_diff: -10" + " max_diff: 7" + " sample_count: 0" + " insert_count: 1" + "} " + "sampler: { uniform: true } " + "remover: { fifo: true } "))); +} + +TEST(PriorityTableTest, BlocksSamplesWhenSizeToSmallDueToAutoDelete) { + PriorityTable table( + /*name=*/"dist", + /*sampler=*/absl::make_unique(), + /*remover=*/absl::make_unique(), + /*max_size=*/10, + /*max_times_sampled=*/2, + absl::make_unique( + /*samples_per_insert=*/1.0, + /*min_size_to_sample=*/3, + /*min_diff=*/0, + /*max_diff=*/5)); + TF_EXPECT_OK(table.InsertOrAssign(MakeItem(1, 1))); + TF_EXPECT_OK(table.InsertOrAssign(MakeItem(2, 1))); + TF_EXPECT_OK(table.InsertOrAssign(MakeItem(3, 1))); + + // It should be fine to sample now as the table has been reached its min size. + PriorityTable::SampledItem sample_1; + TF_EXPECT_OK(table.Sample(&sample_1)); + EXPECT_THAT(sample_1, HasItemKey(1)); + + // A second sample should be fine since the table is still large enough. 
+ PriorityTable::SampledItem sample_2; + TF_EXPECT_OK(table.Sample(&sample_2)); + EXPECT_THAT(sample_2, HasItemKey(1)); + + // Due to max_times_sampled, the table should have one item less which should + // block more samples from proceeding. + absl::Notification notification; + auto sample_thread = internal::StartThread("", [&] { + PriorityTable::SampledItem sample; + TF_EXPECT_OK(table.Sample(&sample)); + notification.Notify(); + }); + EXPECT_FALSE(notification.WaitForNotificationWithTimeout(kTimeout)); + + // Inserting a new item should unblock the sampling. + TF_EXPECT_OK(table.InsertOrAssign(MakeItem(4, 1))); + EXPECT_TRUE(notification.WaitForNotificationWithTimeout(kTimeout)); + + sample_thread = nullptr; // Joins the thread. +} + +TEST(PriorityTableTest, BlocksSamplesWhenSizeToSmallDueToExplicitDelete) { + PriorityTable table( + /*name=*/"dist", + /*sampler=*/absl::make_unique(), + /*remover=*/absl::make_unique(), + /*max_size=*/10, + /*max_times_sampled=*/-1, + absl::make_unique( + /*samples_per_insert=*/1.0, + /*min_size_to_sample=*/3, + /*min_diff=*/0, + /*max_diff=*/5)); + TF_EXPECT_OK(table.InsertOrAssign(MakeItem(1, 1))); + TF_EXPECT_OK(table.InsertOrAssign(MakeItem(2, 1))); + TF_EXPECT_OK(table.InsertOrAssign(MakeItem(3, 1))); + + // It should be fine to sample now as the table has been reached its min size. + PriorityTable::SampledItem sample_1; + TF_EXPECT_OK(table.Sample(&sample_1)); + EXPECT_THAT(sample_1, HasItemKey(1)); + + // Deleting an item will make the table too small to allow samples. + TF_EXPECT_OK(table.MutateItems({}, {1})); + + absl::Notification notification; + auto sample_thread = internal::StartThread("", [&] { + PriorityTable::SampledItem sample; + TF_EXPECT_OK(table.Sample(&sample)); + notification.Notify(); + }); + EXPECT_FALSE(notification.WaitForNotificationWithTimeout(kTimeout)); + + // Inserting a new item should unblock the sampling. + TF_EXPECT_OK(table.InsertOrAssign(MakeItem(4, 1))); + EXPECT_TRUE(notification.WaitForNotificationWithTimeout(kTimeout)); + + sample_thread = nullptr; // Joins the thread. + + // And any new samples should be fine. 
+ PriorityTable::SampledItem sample_2; + TF_EXPECT_OK(table.Sample(&sample_2)); + EXPECT_THAT(sample_2, HasItemKey(2)); +} + +TEST(PriorityTableTest, GetExistingItem) { + auto table = MakeUniformTable("dist"); + + TF_EXPECT_OK(table->InsertOrAssign(MakeItem(1, 1))); + TF_EXPECT_OK(table->InsertOrAssign(MakeItem(2, 1))); + TF_EXPECT_OK(table->InsertOrAssign(MakeItem(3, 1))); + + PriorityTableItem item; + EXPECT_TRUE(table->Get(2, &item)); + EXPECT_THAT(item, HasItemKey(2)); +} + +TEST(PriorityTableTest, GetMissingItem) { + auto table = MakeUniformTable("dist"); + + TF_EXPECT_OK(table->InsertOrAssign(MakeItem(1, 1))); + TF_EXPECT_OK(table->InsertOrAssign(MakeItem(3, 1))); + + PriorityTableItem item; + EXPECT_FALSE(table->Get(2, &item)); +} + +TEST(PriorityTableTest, SampleSetsTableSize) { + auto table = MakeUniformTable("dist"); + + for (int i = 1; i <= 10; i++) { + TF_EXPECT_OK(table->InsertOrAssign(MakeItem(i, 1))); + PriorityTable::SampledItem sample; + TF_EXPECT_OK(table->Sample(&sample)); + EXPECT_EQ(sample.table_size, i); + } +} + +TEST(PriorityTableDeathTest, DiesIfUnsafeAddExtensionCalledWhenNonEmpty) { + auto table = MakeUniformTable("dist"); + TF_EXPECT_OK(table->InsertOrAssign(MakeItem(1, 1))); + ASSERT_DEATH(table->UnsafeAddExtension(nullptr), ""); +} + +} // namespace +} // namespace reverb +} // namespace deepmind diff --git a/reverb/cc/rate_limiter.cc b/reverb/cc/rate_limiter.cc new file mode 100644 index 0000000..645a47c --- /dev/null +++ b/reverb/cc/rate_limiter.cc @@ -0,0 +1,208 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "reverb/cc/rate_limiter.h" + +#include +#include + +#include +#include "absl/base/thread_annotations.h" +#include "absl/synchronization/mutex.h" +#include "absl/time/time.h" +#include "reverb/cc/checkpointing/checkpoint.pb.h" +#include "reverb/cc/platform/logging.h" +#include "reverb/cc/priority_table.h" +#include "reverb/cc/schema.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/errors.h" + +namespace deepmind { +namespace reverb { +namespace { + +bool WaitAndLog(absl::CondVar* cv, absl::Mutex* mu, absl::Time deadline, + absl::string_view call_description, + absl::Duration log_after = absl::Seconds(10)) { + const auto log_deadline = absl::Now() + log_after; + if (log_deadline < deadline) { + if (!cv->WaitWithDeadline(mu, log_deadline)) { + return false; + } + REVERB_LOG(REVERB_INFO) << call_description << " blocked for " << log_after; + } + + return cv->WaitWithDeadline(mu, deadline); +} +} // namespace + +RateLimiter::RateLimiter(double samples_per_insert, int64_t min_size_to_sample, + double min_diff, double max_diff) + : samples_per_insert_(samples_per_insert), + min_diff_(min_diff), + max_diff_(max_diff), + min_size_to_sample_(min_size_to_sample), + inserts_(0), + samples_(0), + deletes_(0), + cancelled_(false) { + REVERB_CHECK_GT(min_size_to_sample, 0); +} + +RateLimiter::RateLimiter(const RateLimiterCheckpoint& checkpoint) + : RateLimiter(/*samples_per_insert=*/checkpoint.samples_per_insert(), + /*min_size_to_sample=*/ + checkpoint.min_size_to_sample(), + /*min_diff=*/checkpoint.min_diff(), + /*max_diff=*/checkpoint.max_diff()) { + inserts_ = checkpoint.insert_count(); + samples_ = checkpoint.sample_count(); + deletes_ = checkpoint.delete_count(); +} + +tensorflow::Status RateLimiter::RegisterPriorityTable( + PriorityTable* priority_table) { + if (priority_table_) { + return tensorflow::errors::FailedPrecondition( + "Attempting to registering a priority table ", priority_table, + " (name: ", priority_table->name(), ") with RateLimiter when is ", + "already registered with this limiter: ", priority_table_, + " (name: ", priority_table_->name(), ")"); + } + priority_table_ = priority_table; + return tensorflow::Status::OK(); +} + +void RateLimiter::UnregisterPriorityTable(absl::Mutex* mu, + PriorityTable* table) { + // Keep priority_table_registered_ at its current value to ensure that + // no one else tries to access state associated with a table that no longer + // exists. 
+ REVERB_CHECK_EQ(table, priority_table_) + << "The wrong PriorityTable attempted to unregister this rate limiter."; + absl::MutexLock lock(mu); + Reset(mu); + priority_table_ = nullptr; +} + +tensorflow::Status RateLimiter::AwaitCanInsert(absl::Mutex* mu, + absl::Duration timeout) { + const auto start = absl::Now(); + const auto deadline = start + timeout; + while (!cancelled_ && !CanInsert(mu, 1)) { + if (WaitAndLog(&can_insert_cv_, mu, deadline, "Insert call")) { + return tensorflow::errors::DeadlineExceeded( + "timeout exceeded before right to insert was acquired."); + } + } + TF_RETURN_IF_ERROR(CheckIfCancelled()); + + return tensorflow::Status::OK(); +} + +void RateLimiter::Insert(absl::Mutex* mu) { + inserts_++; + MaybeSignalCondVars(mu); +} + +void RateLimiter::Delete(absl::Mutex* mu) { + deletes_++; + MaybeSignalCondVars(mu); +} + +void RateLimiter::Reset(absl::Mutex* mu) { + inserts_ = 0; + samples_ = 0; + deletes_ = 0; + MaybeSignalCondVars(mu); +} + +tensorflow::Status RateLimiter::AwaitAndFinalizeSample(absl::Mutex* mu, + absl::Duration timeout) { + const auto start = absl::Now(); + const auto deadline = start + timeout; + while (!cancelled_ && !CanSample(mu, 1)) { + if (WaitAndLog(&can_sample_cv_, mu, deadline, "Sample call")) { + return tensorflow::errors::DeadlineExceeded( + "timeout exceeded before right to sample was acquired."); + } + } + TF_RETURN_IF_ERROR(CheckIfCancelled()); + + samples_++; + MaybeSignalCondVars(mu); + return tensorflow::Status::OK(); +} + +bool RateLimiter::CanSample(absl::Mutex*, int num_samples) const { + REVERB_CHECK_GT(num_samples, 0); + if (inserts_ - deletes_ < min_size_to_sample_) { + return false; + } + double diff = inserts_ * samples_per_insert_ - samples_ - num_samples; + return diff >= min_diff_; +} + +bool RateLimiter::CanInsert(absl::Mutex*, int num_inserts) const { + REVERB_CHECK_GT(num_inserts, 0); + // Until the min size is reached inserts are free to progress. 
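+  // Worked example: with samples_per_insert = 1.0, min_size_to_sample = 2 and
+  // max_diff = 1, the first two inserts pass the size check below. A third
+  // insert (with no samples yet) gives diff = 3 * 1.0 - 0 = 3 > max_diff, so
+  // CanInsert returns false and AwaitCanInsert blocks until enough samples
+  // (or a Reset) move the cursor back into range.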
+ if (inserts_ + num_inserts - deletes_ <= min_size_to_sample_) { + return true; + } + + double diff = (num_inserts + inserts_) * samples_per_insert_ - samples_; + return diff <= max_diff_; +} + +void RateLimiter::Cancel(absl::Mutex*) { + cancelled_ = true; + can_insert_cv_.SignalAll(); + can_sample_cv_.SignalAll(); +} + +RateLimiterCheckpoint RateLimiter::CheckpointReader(absl::Mutex*) const { + RateLimiterCheckpoint checkpoint; + checkpoint.set_samples_per_insert(samples_per_insert_); + checkpoint.set_min_diff(min_diff_); + checkpoint.set_max_diff(max_diff_); + checkpoint.set_min_size_to_sample(min_size_to_sample_); + checkpoint.set_sample_count(samples_); + checkpoint.set_insert_count(inserts_); + checkpoint.set_delete_count(deletes_); + + return checkpoint; +} + +RateLimiterInfo RateLimiter::info() const { + RateLimiterInfo info_proto; + info_proto.set_samples_per_insert(samples_per_insert_); + info_proto.set_min_diff(min_diff_); + info_proto.set_max_diff(max_diff_); + info_proto.set_min_size_to_sample(min_size_to_sample_); + return info_proto; +} + +tensorflow::Status RateLimiter::CheckIfCancelled() const { + if (!cancelled_) return tensorflow::Status::OK(); + return tensorflow::errors::Cancelled("RateLimiter has been cancelled"); +} + +void RateLimiter::MaybeSignalCondVars(absl::Mutex* mu) { + if (CanInsert(mu, 1)) can_insert_cv_.Signal(); + if (CanSample(mu, 1)) can_sample_cv_.Signal(); +} + +} // namespace reverb +} // namespace deepmind diff --git a/reverb/cc/rate_limiter.h b/reverb/cc/rate_limiter.h new file mode 100644 index 0000000..33e6299 --- /dev/null +++ b/reverb/cc/rate_limiter.h @@ -0,0 +1,150 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef REVERB_CC_RATE_LIMITER_H_ +#define REVERB_CC_RATE_LIMITER_H_ + +#include + +#include +#include "absl/base/thread_annotations.h" +#include "absl/synchronization/mutex.h" +#include "absl/time/time.h" +#include "reverb/cc/checkpointing/checkpoint.pb.h" +#include "reverb/cc/schema.pb.h" +#include "tensorflow/core/lib/core/status.h" + +namespace deepmind { +namespace reverb { + +class PriorityTable; + +constexpr absl::Duration kDefaultTimeout = absl::InfiniteDuration(); + +// RateLimiter manages the data throughput for a PriorityTable by blocking +// sample or insert calls if the ratio between the two deviates too much from +// the ratio specified by `samples_per_insert`. +class RateLimiter { + public: + RateLimiter(double samples_per_insert, int64_t min_size_to_sample, + double min_diff, double max_diff); + + // Construct and restore a RateLimiter from a previous checkpoint. + explicit RateLimiter(const RateLimiterCheckpoint& checkpoint); + + // Waits until the insert operation can proceed without violating the + // conditions of the rate limiter. + // + // The state is not modified as the caller must first check that the operation + // is still an insert op (while waiting the item may be inserted by another + // thread and thus the operation now is an update). 
If the operation remains + // an insert then `Insert` must be called to commit the state change. + tensorflow::Status AwaitCanInsert(absl::Mutex* mu, + absl::Duration timeout = kDefaultTimeout) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu); + + // Waits until the sample operation can proceed without violating the + // conditions of the rate limiter. If the condition is fulfilled before the + // timeout expires or `Cancel` called then the state is updated. + tensorflow::Status AwaitAndFinalizeSample( + absl::Mutex* mu, absl::Duration timeout = kDefaultTimeout) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu); + + // Register that an item has been inserted into the table. Caller must call + // `AwaitCanInsert` before calling this method without releasing the lock in + // between. + void Insert(absl::Mutex* mu) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu); + + // Register that an item have been deleted from the table. + void Delete(absl::Mutex* mu) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu); + + // Register that the table has been fully reset. + void Reset(absl::Mutex* mu) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu); + + // Unblocks any `Await` calls with a Cancelled-status. + void Cancel(absl::Mutex* mu) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu); + + // Returns true iff the current state would allow for `num_samples` to be + // sampled. Dies if `num_samples` is < 1. + bool CanSample(absl::Mutex* mu, int num_samples) const + ABSL_SHARED_LOCKS_REQUIRED(mu); + + // Returns true iff the current state would allow for `num_inserts` to be + // inserted. Dies if `num_inserts` is < 1. + bool CanInsert(absl::Mutex* mu, int num_inserts) const + ABSL_SHARED_LOCKS_REQUIRED(mu); + + // Creates a checkpoint of the current state for the rate limiter. + RateLimiterCheckpoint CheckpointReader(absl::Mutex* mu) const + ABSL_SHARED_LOCKS_REQUIRED(mu); + + // Configuration details of the limiter. + RateLimiterInfo info() const; + + private: + friend class PriorityTable; + // PriorityTable calls these methods on construction and destruction. + tensorflow::Status RegisterPriorityTable(PriorityTable* table); + void UnregisterPriorityTable(absl::Mutex* mu, PriorityTable* table) + ABSL_LOCKS_EXCLUDED(mu); + + // Checks if sample and insert operations can proceed and if so calls `Signal` + // on respective `CondVar` + void MaybeSignalCondVars(absl::Mutex* mu) ABSL_SHARED_LOCKS_REQUIRED(mu); + + // Returns Cancelled-status if `Cancel` have been called. + tensorflow::Status CheckIfCancelled() const; + + // Pointer to the priority table. We expect this to be available (if set), + // since it's set by a PriorityTable calling RegisterPriorityTable(this) after + // it stores a shared_ptr to this RateLimiter;. + PriorityTable* priority_table_ = nullptr; + + // The desired ratio between sample ops and insert operations. This can be + // interpreted as the average number of times each item is sampled during + // its total lifetime. + const double samples_per_insert_; + + // The minimum and maximum values the cursor is allowed to reach. The cursor + // value is calculated as `insert_count_ * samples_per_insert_ - + // sample_count_`. If the value would go beyond these limits then the call is + // blocked until it can proceed without violating the constraints. + const double min_diff_; + const double max_diff_; + + // The minimum number of items that must exist in the distribution for samples + // to be allowed. + const int64_t min_size_to_sample_; + + // Total number of items inserted into table. + int64_t inserts_; + + // Total number of times any item has been sampled from the table. 
+ int64_t samples_; + + // Total number of items that has been deleted from the table. + int64_t deletes_; + + // Whether `Cancel` has been called. + bool cancelled_; + + // Signal called on respective cv if operation can proceed after state change. + absl::CondVar can_insert_cv_; + absl::CondVar can_sample_cv_; +}; + +} // namespace reverb +} // namespace deepmind + +#endif // REVERB_CC_RATE_LIMITER_H_ diff --git a/reverb/cc/rate_limiter_test.cc b/reverb/cc/rate_limiter_test.cc new file mode 100644 index 0000000..baf9e2e --- /dev/null +++ b/reverb/cc/rate_limiter_test.cc @@ -0,0 +1,446 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "reverb/cc/rate_limiter.h" + +#include +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/synchronization/mutex.h" +#include "absl/synchronization/notification.h" +#include "absl/time/time.h" +#include "reverb/cc/distributions/uniform.h" +#include "reverb/cc/platform/thread.h" +#include "reverb/cc/priority_table.h" +#include "reverb/cc/testing/proto_test_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace deepmind { +namespace reverb { + +namespace { + +using ::deepmind::reverb::testing::EqualsProto; +using ::deepmind::reverb::testing::Partially; + +constexpr absl::Duration kTimeout = absl::Milliseconds(100); + +std::unique_ptr MakeTable(const std::string &name, + std::shared_ptr limiter) { + return absl::make_unique( + name, absl::make_unique(), + absl::make_unique(), 10000, 0, std::move(limiter)); +} + +TEST(RateLimiterTest, BlocksSamplesUntilMinInsertsReached) { + auto limiter = + std::make_shared(/*samples_per_insert=*/1.0, + /*min_size_to_sample=*/2, /*min_diff=*/-1.0, + /*max_diff=*/1.0); + auto table = MakeTable("table", limiter); + absl::Notification notification; + absl::Mutex mu; + auto thread = internal::StartThread("", [&] { + absl::WriterMutexLock lock(&mu); + TF_EXPECT_OK(limiter->AwaitAndFinalizeSample(&mu)); + notification.Notify(); + }); + + // No inserts yet so the sample should be blocked. + EXPECT_FALSE(notification.WaitForNotificationWithTimeout(kTimeout)); + + // 1 insert is not enough so the sample should still be blocked. + { + absl::WriterMutexLock lock(&mu); + TF_EXPECT_OK(limiter->AwaitCanInsert(&mu)); + limiter->Insert(&mu); + } + EXPECT_FALSE(notification.WaitForNotificationWithTimeout(kTimeout)); + + // 2 inserts is enough, the sampling should now be unblocked. + { + absl::WriterMutexLock lock(&mu); + TF_EXPECT_OK(limiter->AwaitCanInsert(&mu)); + limiter->Insert(&mu); + } + EXPECT_TRUE(notification.WaitForNotificationWithTimeout(kTimeout)); + + thread = nullptr; // Joins the thread. 
+} + +TEST(RateLimiterTest, OperationsWithinTheBufferAreNotBlocked) { + auto limiter = + std::make_shared(/*samples_per_insert=*/1.5, + /*min_size_to_sample=*/1, /*min_diff=*/-3.0, + /*max_diff=*/3.1); + auto table = MakeTable("table", limiter); + absl::Mutex mu; + + // First insert is always fine because min_size_to_sample is not yet + // reached. The "diff" is now 1.5. + absl::WriterMutexLock lock(&mu); + TF_EXPECT_OK(limiter->AwaitCanInsert(&mu)); + limiter->Insert(&mu); + + // Second insert should be fine as the "diff" after the insert is 3.0 which + // is part of the buffer range. + TF_EXPECT_OK(limiter->AwaitCanInsert(&mu)); + limiter->Insert(&mu); + + // Sample calls should not be blocked as long as diff is >= -3.0. + TF_EXPECT_OK(limiter->AwaitAndFinalizeSample(&mu)); // diff = 2.0. + TF_EXPECT_OK(limiter->AwaitAndFinalizeSample(&mu)); // diff = 1.0. + TF_EXPECT_OK(limiter->AwaitAndFinalizeSample(&mu)); // diff = 0.0. + TF_EXPECT_OK(limiter->AwaitAndFinalizeSample(&mu)); // diff = -1.0. + TF_EXPECT_OK(limiter->AwaitAndFinalizeSample(&mu)); // diff = -2.0. + TF_EXPECT_OK(limiter->AwaitAndFinalizeSample(&mu)); // diff = -3.0. +} + +TEST(RateLimiterTest, UnblocksCallsWhenCancelled) { + auto limiter = + std::make_shared(/*samples_per_insert=*/1.0, + /*min_size_to_sample=*/2, /*min_diff=*/-1.0, + /*max_diff=*/1.0); + auto table = MakeTable("table", limiter); + absl::Mutex mu; + absl::Notification notification; + auto thread = internal::StartThread("", [&] { + absl::WriterMutexLock lock(&mu); + EXPECT_EQ(limiter->AwaitAndFinalizeSample(&mu).code(), + tensorflow::error::CANCELLED); + notification.Notify(); + }); + + EXPECT_FALSE(notification.WaitForNotificationWithTimeout(kTimeout)); + + { + absl::WriterMutexLock lock(&mu); + limiter->Cancel(&mu); + } + EXPECT_TRUE(notification.WaitForNotificationWithTimeout(kTimeout)); + + thread = nullptr; // Joins the thread. +} + +TEST(RateLimiterTest, BlocksCallsThatExceedsTheMinMaxLimits) { + auto limiter = + std::make_shared(/*samples_per_insert=*/1.5, + /*min_size_to_sample=*/2, /*min_diff=*/-1.0, + /*max_diff=*/3.0); + auto table = MakeTable("table", limiter); + absl::Mutex mu; + + std::vector> bundle; + + absl::Notification sample; + bundle.push_back(internal::StartThread("", [&] { + absl::WriterMutexLock lock(&mu); + TF_EXPECT_OK(limiter->AwaitAndFinalizeSample(&mu)); + sample.Notify(); + })); + + // No inserts yet so the sample should be blocked. + EXPECT_FALSE(sample.WaitForNotificationWithTimeout(kTimeout)); // diff = 0.0 + + // 1 insert is not enough so the sample should still be blocked. + { + absl::WriterMutexLock lock(&mu); + TF_EXPECT_OK(limiter->AwaitCanInsert(&mu)); // diff = 1.5 + limiter->Insert(&mu); + } + EXPECT_FALSE(sample.WaitForNotificationWithTimeout(kTimeout)); + + // 2 inserts is enough, the sampling should now be unblocked. + { + absl::WriterMutexLock lock(&mu); + TF_EXPECT_OK(limiter->AwaitCanInsert(&mu)); + limiter->Insert(&mu); // diff = 3.0 + } + + EXPECT_TRUE(sample.WaitForNotificationWithTimeout(kTimeout)); // diff = 2.0 + + // Inserts should now be blocked as it should lead to diff = 3.5. + absl::Notification insert; + bundle.push_back(internal::StartThread("", [&] { + absl::WriterMutexLock lock(&mu); + TF_EXPECT_OK(limiter->AwaitCanInsert(&mu)); + limiter->Insert(&mu); + insert.Notify(); + })); + + EXPECT_FALSE(insert.WaitForNotificationWithTimeout(kTimeout)); + + // But adding a new sample should allow it to proceed. 
+ { + absl::WriterMutexLock lock(&mu); + TF_EXPECT_OK(limiter->AwaitAndFinalizeSample(&mu)); + } + + EXPECT_TRUE(insert.WaitForNotificationWithTimeout(kTimeout)); + + bundle.clear(); // Joins all threads. +} + +TEST(RateLimiterTest, CanSample) { + auto limiter = + std::make_shared(/*samples_per_insert=*/1.0, + /*min_size_to_sample=*/1, /*min_diff=*/-1.0, + /*max_diff=*/1.0); + auto table = MakeTable("table", limiter); + absl::Mutex mu; + absl::WriterMutexLock lock(&mu); + + // Min size should not have been reached so no samples should be allowed. + EXPECT_FALSE(limiter->CanSample(&mu, 1)); + + // Insert a single item. + TF_EXPECT_OK(limiter->AwaitCanInsert(&mu, kTimeout)); + limiter->Insert(&mu); + + // It should now be possible to sample at most two items. + EXPECT_TRUE(limiter->CanSample(&mu, 1)); // diff = 0. + EXPECT_TRUE(limiter->CanSample(&mu, 2)); // diff = -1.0. + EXPECT_FALSE(limiter->CanSample(&mu, 3)); // diff = -2.0. +} + +TEST(RateLimiterTest, CanInsert) { + auto limiter = + std::make_shared(/*samples_per_insert=*/1.5, + /*min_size_to_sample=*/2, /*min_diff=*/0.0, + /*max_diff=*/5.0); + auto table = MakeTable("table", limiter); + absl::Mutex mu; + absl::WriterMutexLock lock(&mu); + + // The min size allows for the first two inserts and the error buffer allows + // for one additional insert. + EXPECT_TRUE(limiter->CanInsert(&mu, 1)); // diff = 1.5 (lt min size). + EXPECT_TRUE(limiter->CanInsert(&mu, 2)); // diff = 3.0 (eq min size). + EXPECT_TRUE(limiter->CanInsert(&mu, 3)); // diff = 4.5. + EXPECT_FALSE(limiter->CanInsert(&mu, 4)); // diff = 6.0 + + // Do the inserts. + for (int i = 0; i < 3; i++) { + TF_EXPECT_OK(limiter->AwaitCanInsert(&mu, kTimeout)); + limiter->Insert(&mu); + } + + // No inserts should be allowed now. + EXPECT_FALSE(limiter->CanInsert(&mu, 1)); // diff = 6.0 + + // Move the cursor by sampling two items. + TF_EXPECT_OK(limiter->AwaitAndFinalizeSample(&mu, kTimeout)); // diff = 3.5 + TF_EXPECT_OK(limiter->AwaitAndFinalizeSample(&mu, kTimeout)); // diff = 2.5 + + // One more sample should now be allowed. + EXPECT_TRUE(limiter->CanInsert(&mu, 1)); // diff = 4.0. + EXPECT_FALSE(limiter->CanInsert(&mu, 2)); // diff = 5.5. 
+} + +TEST(RateLimiterTest, CheckpointSetsBasicOptions) { + auto limiter = + std::make_shared(/*samples_per_insert=*/1.5, + /*min_size_to_sample=*/2, /*min_diff=*/0.0, + /*max_diff=*/5.0); + auto table = MakeTable("table", limiter); + absl::Mutex mu; + absl::WriterMutexLock lock(&mu); + EXPECT_THAT(limiter->CheckpointReader(&mu), + testing::EqualsProto("samples_per_insert: 1.5 min_diff: 0 " + "max_diff: 5 min_size_to_sample: 2")); +} + +TEST(RateLimiterTest, CheckpointSetsInsertAndDeleteAndSampleCount) { + auto limiter = + std::make_shared(/*samples_per_insert=*/1.5, + /*min_size_to_sample=*/2, /*min_diff=*/0.0, + /*max_diff=*/5.0); + auto table = MakeTable("table", limiter); + absl::Mutex mu; + absl::WriterMutexLock lock(&mu); + + EXPECT_THAT( + limiter->CheckpointReader(&mu), + Partially(testing::EqualsProto("sample_count: 0 insert_count: 0"))); + + TF_EXPECT_OK(limiter->AwaitCanInsert(&mu, kTimeout)); + limiter->Insert(&mu); + TF_EXPECT_OK(limiter->AwaitCanInsert(&mu, kTimeout)); + limiter->Insert(&mu); + TF_EXPECT_OK(limiter->AwaitAndFinalizeSample(&mu, kTimeout)); + limiter->Delete(&mu); + + EXPECT_THAT(limiter->CheckpointReader(&mu), + Partially(testing::EqualsProto( + "sample_count: 1 insert_count: 2 delete_count: 1"))); +} + +TEST(RateLimiterTest, CanBeRestoredFromCheckpoint) { + auto limiter = + std::make_shared(/*samples_per_insert=*/1.5, + /*min_size_to_sample=*/2, /*min_diff=*/0.0, + /*max_diff=*/5.0); + auto table = MakeTable("table", limiter); + absl::Mutex mu; + absl::WriterMutexLock lock(&mu); + + TF_EXPECT_OK(limiter->AwaitCanInsert(&mu, kTimeout)); + limiter->Insert(&mu); + TF_EXPECT_OK(limiter->AwaitCanInsert(&mu, kTimeout)); + limiter->Insert(&mu); + TF_EXPECT_OK(limiter->AwaitAndFinalizeSample(&mu, kTimeout)); + limiter->Delete(&mu); + + // Create a checkpoint and check its content. + auto checkpoint = limiter->CheckpointReader(&mu); + EXPECT_THAT(checkpoint, testing::EqualsProto("samples_per_insert: 1.5 " + "min_diff: 0 " + "max_diff: 5 " + "min_size_to_sample: 2 " + "sample_count: 1 " + "insert_count: 2 " + "delete_count: 1")); + + // Create a new RateLimiter from the checkpoint and verify that it behaves as + // expected and that checkpoints generated from the restored RateLimiter + // includes both new and inherited information. + auto restored = std::make_shared(checkpoint); + table = MakeTable("table", restored); + + TF_EXPECT_OK(restored->AwaitCanInsert(&mu, kTimeout)); + restored->Insert(&mu); + TF_EXPECT_OK(restored->AwaitAndFinalizeSample(&mu, kTimeout)); + + EXPECT_THAT(restored->CheckpointReader(&mu), + testing::EqualsProto("samples_per_insert: 1.5 " + "min_diff: 0 " + "max_diff: 5 " + "min_size_to_sample: 2 " + "sample_count: 2 " + "insert_count: 3 " + "delete_count: 1")); +} + +TEST(RateLimiterTest, UnblocksInsertsIfDeletedItemsBringsSizeBelowMinSize) { + auto limiter = + std::make_shared(/*samples_per_insert=*/5, + /*min_size_to_sample=*/2, /*min_diff=*/0.0, + /*max_diff=*/5.0); + auto table = MakeTable("table", limiter); + absl::Mutex mu; + + { + absl::WriterMutexLock lock(&mu); + TF_EXPECT_OK(limiter->AwaitCanInsert(&mu, kTimeout)); + limiter->Insert(&mu); + TF_EXPECT_OK(limiter->AwaitCanInsert(&mu, kTimeout)); + limiter->Insert(&mu); + } + + // No more inserts should be allowed until now. 
+ absl::Notification insert; + auto insert_thread = internal::StartThread("", [&] { + absl::WriterMutexLock lock(&mu); + TF_EXPECT_OK(limiter->AwaitCanInsert(&mu)); + insert.Notify(); + }); + EXPECT_FALSE(insert.WaitForNotificationWithTimeout(kTimeout)); + + // Sampling should be fine now since the min size has been reached. + { + absl::WriterMutexLock lock(&mu); + TF_EXPECT_OK(limiter->AwaitAndFinalizeSample(&mu, kTimeout)); + } + + // The insert should still be blocked due to the large samples_per_insert. + EXPECT_FALSE(insert.WaitForNotificationWithTimeout(kTimeout)); + + // If we remove an item then the min size is no reached which should unblock + // the insert. + { + absl::WriterMutexLock lock(&mu); + limiter->Delete(&mu); + } + EXPECT_TRUE(insert.WaitForNotificationWithTimeout(kTimeout)); + + insert_thread = nullptr; // Joins the thread. +} + +TEST(RateLimiterTest, BlocksSamplesIfDeleteBringsSizeBelowMinSize) { + auto limiter = + std::make_shared(/*samples_per_insert=*/5, + /*min_size_to_sample=*/2, /*min_diff=*/0.0, + /*max_diff=*/5.0); + auto table = MakeTable("table", limiter); + absl::Mutex mu; + + { + absl::WriterMutexLock lock(&mu); + TF_EXPECT_OK(limiter->AwaitCanInsert(&mu, kTimeout)); + limiter->Insert(&mu); + TF_EXPECT_OK(limiter->AwaitCanInsert(&mu, kTimeout)); + limiter->Insert(&mu); + + // Sampling should be fine now since the min size has been reached. + TF_EXPECT_OK(limiter->AwaitAndFinalizeSample(&mu, kTimeout)); + + // Deleting an item will bring the size back below the + // min_size_to_sample which should block any further samples. + limiter->Delete(&mu); + } + + absl::Notification sample; + auto sample_thread = internal::StartThread("", [&] { + absl::WriterMutexLock lock(&mu); + TF_EXPECT_OK(limiter->AwaitAndFinalizeSample(&mu)); + sample.Notify(); + }); + EXPECT_FALSE(sample.WaitForNotificationWithTimeout(kTimeout)); + + // Inserting a new item will bring the size up again which should unblock the + // sampling. It should however not be unblocked by simply staging the insert. + { + absl::WriterMutexLock lock(&mu); + TF_EXPECT_OK(limiter->AwaitCanInsert(&mu, kTimeout)); + } + EXPECT_FALSE(sample.WaitForNotificationWithTimeout(kTimeout)); + { + absl::WriterMutexLock lock(&mu); + limiter->Insert(&mu); + } + + EXPECT_TRUE(sample.WaitForNotificationWithTimeout(kTimeout)); + + sample_thread = nullptr; // Joins the thread. +} + +TEST(RateLimiterTest, Info) { + EXPECT_THAT(RateLimiter(1, 1, 0, 5).info(), + EqualsProto("samples_per_insert: 1 min_size_to_sample: 1 " + "min_diff: 0 max_diff: 5")); + EXPECT_THAT(RateLimiter(1.5, 14, -10, 5.3).info(), + EqualsProto("samples_per_insert: 1.5 min_size_to_sample: 14 " + "min_diff: -10 max_diff: 5.3")); +} + +TEST(RateLimiterDeathTest, DiesIfMinSizeToSampleNonPositive) { + ASSERT_DEATH(RateLimiter(1, 0, 0, 5), ""); + ASSERT_DEATH(RateLimiter(1, -1, 0, 5), ""); +} + +} // namespace +} // namespace reverb +} // namespace deepmind diff --git a/reverb/cc/replay_client.cc b/reverb/cc/replay_client.cc new file mode 100644 index 0000000..cdd1f62 --- /dev/null +++ b/reverb/cc/replay_client.cc @@ -0,0 +1,287 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "reverb/cc/replay_client.h" + +#include +#include + +#include "grpcpp/support/channel_arguments.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "reverb/cc/platform/grpc_utils.h" +#include "reverb/cc/platform/logging.h" +#include "reverb/cc/replay_service.pb.h" +#include "reverb/cc/schema.pb.h" +#include "reverb/cc/support/grpc_util.h" +#include "reverb/cc/support/uint128.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/errors.h" + +namespace deepmind { +namespace reverb { +namespace { + +constexpr int kMaxMessageSize = 30 * 1000 * 1000; + +grpc::ChannelArguments CreateChannelArguments() { + grpc::ChannelArguments arguments; + arguments.SetMaxReceiveMessageSize(kMaxMessageSize); + arguments.SetMaxSendMessageSize(kMaxMessageSize); + arguments.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, 30 * 1000); + arguments.SetLoadBalancingPolicyName("round_robin"); + return arguments; +} + +} // namespace + +ReplayClient::ReplayClient( + std::shared_ptr stub) + : stub_(std::move(stub)) { + REVERB_CHECK(stub_ != nullptr); +} + +ReplayClient::ReplayClient(absl::string_view server_address) + : stub_(/* grpc_gen:: */ReplayService::NewStub(CreateCustomGrpcChannel( + server_address, MakeChannelCredentials(), CreateChannelArguments()))) {} + +tensorflow::Status ReplayClient::MaybeUpdateServerInfoCache( + absl::Duration timeout, + std::shared_ptr* cached_flat_signatures) { + // TODO(b/154927570): Once tables can be mutated on the server, we'll need to + // decide a new rule for updating the server info, instead of doing it just + // once at the beginning. + { + // Exit early if we have table info cached. + absl::ReaderMutexLock lock(&cached_table_mu_); + if (cached_flat_signatures_) { + *cached_flat_signatures = cached_flat_signatures_; + return tensorflow::Status::OK(); + } + } + + // This performs an RPC, so don't run it within a mutex. + // Note, this operation can run into a race condition where multiple + // threads of the same ReplayClient request server info, get different + // values, and one of these overwrites cached_table_info_ with a staler + // ServerInfo after another thread writes a newer version of ServerInfo + // Then future writers see stale signatures. + // + // In practice this isn't a real issue because: + // (1) This type of concurrency is not common: once ServerInfo + // is set, this code path isn't executed again. + // (2) ServerInfo doesn't change often on a reverb server. + // (3) Due to the default gRPC client load balancing mechanism, + // a client with a stub connection to one IP of a group of + // servers will always use the same IP address for all + // consecutive requests. 
So even concurrent requests will all + // go to the same server ("pick_first" policy): + // + // https://github.com/grpc/grpc/blob/631fe79f84af295c60aea5693350b45154827398/src/core/ext/filters/client_channel/client_channel.cc#L1661 + struct ServerInfo info; + TF_RETURN_IF_ERROR(GetServerInfo(timeout, &info)); + + absl::MutexLock lock(&cached_table_mu_); + TF_RETURN_IF_ERROR(LockedUpdateServerInfoCache(info)); + *cached_flat_signatures = cached_flat_signatures_; + return tensorflow::Status::OK(); +} + +tensorflow::Status ReplayClient::NewWriter( + int chunk_length, int max_timesteps, bool delta_encoded, + std::unique_ptr* writer) { + // TODO(b/154928265): caching this request? For example, if + // it's been N seconds or minutes, it may be time to + // get an updated ServerInfo and see if there are new tables. + std::shared_ptr cached_flat_signatures; + // TODO(b/154927687): It is not ideal that this blocks forever. We should + // probably limit this and ignore the signature if it couldn't be found within + // some limits. + TF_RETURN_IF_ERROR(MaybeUpdateServerInfoCache(absl::InfiniteDuration(), + &cached_flat_signatures)); + *writer = absl::make_unique(stub_, chunk_length, max_timesteps, + delta_encoded, + std::move(cached_flat_signatures)); + return tensorflow::Status::OK(); +} + +tensorflow::Status ReplayClient::MutatePriorities( + absl::string_view table, const std::vector& updates, + const std::vector& deletes) { + grpc::ClientContext context; + context.set_wait_for_ready(true); + MutatePrioritiesRequest request; + request.set_table(table.data(), table.size()); + for (const KeyWithPriority& item : updates) { + *request.add_updates() = item; + } + for (int64_t key : deletes) { + request.add_delete_keys(key); + } + MutatePrioritiesResponse response; + return FromGrpcStatus(stub_->MutatePriorities(&context, request, &response)); +} + +tensorflow::Status ReplayClient::NewSampler( + const std::string& table, const ReplaySampler::Options& options, + std::unique_ptr* sampler) { + *sampler = absl::make_unique(stub_, table, options); + return tensorflow::Status::OK(); +} + +tensorflow::Status ReplayClient::NewSampler( + const std::string& table, const ReplaySampler::Options& options, + const tensorflow::DataTypeVector& validation_dtypes, + const std::vector& validation_shapes, + absl::Duration validation_timeout, + std::unique_ptr* sampler) { + // TODO(b/154928265): caching this request? For example, if + // it's been N seconds or minutes, it may be time to + // get an updated ServerInfo and see if there are new tables. + std::shared_ptr cached_flat_signatures; + TF_RETURN_IF_ERROR( + MaybeUpdateServerInfoCache(validation_timeout, &cached_flat_signatures)); + + const auto iter = cached_flat_signatures->find(table); + if (iter == cached_flat_signatures->end()) { + std::vector table_names; + for (const auto& table : *cached_flat_signatures) { + table_names.push_back(absl::StrCat("'", table.first, "'")); + } + REVERB_LOG(REVERB_WARNING) + << "Unable to find table '" << table + << "' in server signature. Perhaps the table hasn't yet been added to " + "the server? Available tables: [" + << absl::StrJoin(table_names, ", ") << "]."; + } else { + const auto& dtypes_and_shapes_no_info = iter->second; + // Only perform check if the table had a signature associated with it. + if (dtypes_and_shapes_no_info) { + std::vector dtypes_and_shapes; + // First element of sampled signature is the key. 
+ dtypes_and_shapes.push_back( + {tensorflow::DT_UINT64, tensorflow::PartialTensorShape({})}); + // Second element of sampled signature is the probability value. + dtypes_and_shapes.push_back( + {tensorflow::DT_DOUBLE, tensorflow::PartialTensorShape({})}); + // Third element of sampled signature is the size of the table. + dtypes_and_shapes.push_back( + {tensorflow::DT_INT64, tensorflow::PartialTensorShape({})}); + for (const auto& dtype_and_shape : *dtypes_and_shapes_no_info) { + dtypes_and_shapes.push_back(dtype_and_shape); + } + if (dtypes_and_shapes.size() != validation_shapes.size()) { + return tensorflow::errors::InvalidArgument( + "Inconsistent number of tensors requested from table '", table, + "'. Requested ", validation_shapes.size(), + " tensors, but table signature shows ", dtypes_and_shapes.size(), + " tensors. Table signature: ", + internal::DtypesShapesString(dtypes_and_shapes)); + } + for (int i = 0; i < dtypes_and_shapes.size(); ++i) { + if (dtypes_and_shapes[i].dtype != validation_dtypes[i] || + !dtypes_and_shapes[i].shape.IsCompatibleWith( + validation_shapes[i])) { + return tensorflow::errors::InvalidArgument( + "Requested incompatible tensor at flattened index ", i, + " from table '", table, "'. Requested (dtype, shape): (", + tensorflow::DataTypeString(validation_dtypes[i]), ", ", + validation_shapes[i].DebugString(), + "). Signature (dtype, shape): (", + tensorflow::DataTypeString(dtypes_and_shapes[i].dtype), ", ", + dtypes_and_shapes[i].shape.DebugString(), "). Table signature: ", + internal::DtypesShapesString(dtypes_and_shapes)); + } + } + } + } + + // TODO(b/154927849): Do sanity checks on the buffer_size and max_samples. + // TODO(b/154928566): Maybe we don't even need to expose the buffer_size. + return NewSampler(table, options, sampler); +} + +tensorflow::Status ReplayClient::GetServerInfo(absl::Duration timeout, + struct ServerInfo* info) { + grpc::ClientContext context; + context.set_wait_for_ready(true); + if (timeout != absl::InfiniteDuration()) { + context.set_deadline(std::chrono::system_clock::now() + + absl::ToChronoSeconds(timeout)); + } + + ServerInfoRequest request; + ServerInfoResponse response; + TF_RETURN_IF_ERROR( + FromGrpcStatus(stub_->ServerInfo(&context, request, &response))); + info->tables_state_id = MessageToUint128(response.tables_state_id()); + for (class TableInfo& table : *response.mutable_table_info()) { + info->table_info.emplace_back(std::move(table)); + } + return tensorflow::Status::OK(); +} + +tensorflow::Status ReplayClient::ServerInfo(struct ServerInfo* info) { + return ServerInfo(absl::InfiniteDuration(), info); +} + +tensorflow::Status ReplayClient::ServerInfo(absl::Duration timeout, + struct ServerInfo* info) { + struct ServerInfo local_info; + TF_RETURN_IF_ERROR(GetServerInfo(timeout, &local_info)); + { + absl::MutexLock lock(&cached_table_mu_); + TF_RETURN_IF_ERROR(LockedUpdateServerInfoCache(local_info)); + } + std::swap(*info, local_info); + return tensorflow::Status::OK(); +} + +tensorflow::Status ReplayClient::LockedUpdateServerInfoCache( + const struct ServerInfo& info) { + if (!cached_flat_signatures_ || tables_state_id_ != info.tables_state_id) { + internal::FlatSignatureMap signatures; + for (const auto& table_info : info.table_info) { + TF_RETURN_IF_ERROR(internal::FlatSignatureFromTableInfo( + table_info, &(signatures[table_info.name()]))); + } + cached_flat_signatures_.reset( + new internal::FlatSignatureMap(std::move(signatures))); + tables_state_id_ = info.tables_state_id; + } + return 
tensorflow::Status::OK(); +} + +tensorflow::Status ReplayClient::Reset(const std::string& table) { + grpc::ClientContext context; + context.set_wait_for_ready(true); + ResetRequest request; + request.set_table(table); + ResetResponse response; + return FromGrpcStatus(stub_->Reset(&context, request, &response)); +} + +tensorflow::Status ReplayClient::Checkpoint(std::string* path) { + grpc::ClientContext context; + context.set_fail_fast(true); + CheckpointRequest request; + CheckpointResponse response; + TF_RETURN_IF_ERROR( + FromGrpcStatus(stub_->Checkpoint(&context, request, &response))); + *path = response.checkpoint_path(); + return tensorflow::Status::OK(); +} + +} // namespace reverb +} // namespace deepmind diff --git a/reverb/cc/replay_client.h b/reverb/cc/replay_client.h new file mode 100644 index 0000000..c06b603 --- /dev/null +++ b/reverb/cc/replay_client.h @@ -0,0 +1,137 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef REVERB_CC_REPLAY_CLIENT_H_ +#define REVERB_CC_REPLAY_CLIENT_H_ + +#include + +#include +#include +#include + +#include +#include "absl/base/thread_annotations.h" +#include "absl/numeric/int128.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/time/time.h" +#include "reverb/cc/replay_sampler.h" +#include "reverb/cc/replay_service.grpc.pb.h" +#include "reverb/cc/replay_service.pb.h" +#include "reverb/cc/replay_writer.h" +#include "reverb/cc/schema.pb.h" +#include "reverb/cc/support/signature.h" +#include "tensorflow/core/lib/core/status.h" + +namespace deepmind { +namespace reverb { + +class ReplayWriter; + +// See ReplayService proto definition for documentation. +class ReplayClient { + public: + struct ServerInfo { + // This struct mirrors the ServerInfo message in + // replay_service.proto. Take a look at that proto file for + // field documentation. + absl::uint128 tables_state_id; + std::vector table_info; + }; + + explicit ReplayClient( + std::shared_ptr stub); + explicit ReplayClient(absl::string_view server_address); + + // Upon successful return, `writer` will contain an instance of ReplayWriter. + tensorflow::Status NewWriter(int chunk_length, int max_timesteps, + bool delta_encoded, + std::unique_ptr* writer); + + // Upon successful return, `sampler` will contain an instance of + // ReplaySampler. + tensorflow::Status NewSampler(const std::string& table, + const ReplaySampler::Options& options, + std::unique_ptr* sampler); + + // Upon successful return, `sampler` will contain an instance of + // ReplaySampler. + // + // If the table has signature metadata available on the server, then + // `validation_shapes` and `validation_dtypes` are checked against the + // flattened signature. + // + // On the other hand, if the table info returned from the server lacks a + // signature, then no validation is performed. If no table entry exists + // for the given table string, then a warning is logged. 
+ // + // **NOTE** Because the sampler always prepends the entry key and + // priority tensors when returning samples, the `validation_{dtypes, shapes}` + // vectors must always be prepended with the signatures of these outputs. + // Specifically, the user must pass: + // + // validation_dtypes[0:1] = {DT_UINT64, DT_DOUBLE} + // validation_shapes[0:1] = {PartialTensorShape({}), PartialTensorShape({})} + // + // and the remaining elements should be the dtypes/shapes of the entries + // expected in table signature. + // + tensorflow::Status NewSampler( + const std::string& table, const ReplaySampler::Options& options, + const tensorflow::DataTypeVector& validation_dtypes, + const std::vector& validation_shapes, + absl::Duration validation_timeout, + std::unique_ptr* sampler); + + tensorflow::Status MutatePriorities( + absl::string_view table, const std::vector& updates, + const std::vector& deletes); + + tensorflow::Status Reset(const std::string& table); + + tensorflow::Status Checkpoint(std::string* path); + + // Requests ServerInfo. Forces an update of internal signature caches. + tensorflow::Status ServerInfo(absl::Duration timeout, + struct ServerInfo* info); + // Waits indefinetely for server to respond. + tensorflow::Status ServerInfo(struct ServerInfo* info); + + private: + const std::shared_ptr stub_; + + tensorflow::Status MaybeUpdateServerInfoCache( + absl::Duration timeout, + std::shared_ptr* cached_flat_signatures); + + // Purely functional request for server info. Does not update any internal + // caches. + tensorflow::Status GetServerInfo(absl::Duration timeout, + struct ServerInfo* info); + + // Updates tables_state_id_ and cached_flat_signatures_ using info. + tensorflow::Status LockedUpdateServerInfoCache(const struct ServerInfo& info) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(cached_table_mu_); + + absl::Mutex cached_table_mu_; + absl::uint128 tables_state_id_ ABSL_GUARDED_BY(cached_table_mu_); + std::shared_ptr cached_flat_signatures_ + ABSL_GUARDED_BY(cached_table_mu_); +}; + +} // namespace reverb +} // namespace deepmind + +#endif // REVERB_CC_REPLAY_CLIENT_H_ diff --git a/reverb/cc/replay_client_test.cc b/reverb/cc/replay_client_test.cc new file mode 100644 index 0000000..142cc8a --- /dev/null +++ b/reverb/cc/replay_client_test.cc @@ -0,0 +1,132 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "reverb/cc/replay_client.h" + +#include + +#include "grpcpp/impl/codegen/client_context.h" +#include "grpcpp/impl/codegen/status.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "reverb/cc/replay_service.pb.h" +#include "reverb/cc/replay_service_mock.grpc.pb.h" +#include "reverb/cc/support/uint128.h" +#include "reverb/cc/testing/proto_test_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace deepmind { +namespace reverb { +namespace { + +constexpr char kCheckpointPath[] = "/path/to/checkpoint"; + +class FakeStub : public /* grpc_gen:: */MockReplayServiceStub { + public: + grpc::Status MutatePriorities(grpc::ClientContext* context, + const MutatePrioritiesRequest& request, + MutatePrioritiesResponse* response) override { + mutate_priorities_request_ = request; + return grpc::Status::OK; + } + + grpc::Status Reset(grpc::ClientContext* context, const ResetRequest& request, + ResetResponse* response) override { + reset_request_ = request; + return grpc::Status::OK; + } + + grpc::Status Checkpoint(grpc::ClientContext* context, + const CheckpointRequest& request, + CheckpointResponse* response) override { + response->set_checkpoint_path(kCheckpointPath); + return grpc::Status::OK; + } + + grpc::Status ServerInfo(grpc::ClientContext* context, + const ServerInfoRequest& request, + ServerInfoResponse* response) override { + *response->mutable_tables_state_id() = + Uint128ToMessage(absl::MakeUint128(1, 2)); + response->add_table_info()->set_max_size(2); + return grpc::Status::OK; + } + + const MutatePrioritiesRequest& mutate_priorities_request() { + return mutate_priorities_request_; + } + + const ResetRequest& reset_request() { return reset_request_; } + + private: + MutatePrioritiesRequest mutate_priorities_request_; + ResetRequest reset_request_; +}; + +TEST(ReplayClientTest, MutatePrioritiesDefaultValues) { + auto stub = std::make_shared(); + ReplayClient client(stub); + TF_EXPECT_OK(client.MutatePriorities("", {}, {})); + EXPECT_THAT(stub->mutate_priorities_request(), + testing::EqualsProto(MutatePrioritiesRequest())); +} + +TEST(ReplayClientTest, MutatePrioritiesFilled) { + auto stub = std::make_shared(); + ReplayClient client(stub); + auto pair = testing::MakeKeyWithPriority(123, 456); + TF_EXPECT_OK(client.MutatePriorities("table", {pair}, {4})); + + MutatePrioritiesRequest expected; + expected.set_table("table"); + *expected.add_updates() = pair; + expected.add_delete_keys(4); + EXPECT_THAT(stub->mutate_priorities_request(), + testing::EqualsProto(expected)); +} + +TEST(ReplayClientTest, ResetRequestFilled) { + auto stub = std::make_shared(); + ReplayClient client(stub); + TF_EXPECT_OK(client.Reset("table")); + + ResetRequest expected; + expected.set_table("table"); + EXPECT_THAT(stub->reset_request(), testing::EqualsProto(expected)); +} + +TEST(ReplayClientTest, Checkpoint) { + auto stub = std::make_shared(); + ReplayClient client(stub); + std::string path; + TF_EXPECT_OK(client.Checkpoint(&path)); + EXPECT_EQ(path, kCheckpointPath); +} + +TEST(ReplayClientTest, ServerInfoRequestFilled) { + auto stub = std::make_shared(); + ReplayClient client(stub); + struct ReplayClient::ServerInfo info; + TF_EXPECT_OK(client.ServerInfo(&info)); + + TableInfo expected_info; + expected_info.set_max_size(2); + EXPECT_EQ(info.tables_state_id, absl::MakeUint128(1, 2)); + EXPECT_EQ(info.table_info.size(), 1); + EXPECT_THAT(info.table_info[0], testing::EqualsProto(expected_info)); +} + +} // namespace +} // namespace reverb +} // namespace deepmind diff --git 
a/reverb/cc/replay_sampler.cc b/reverb/cc/replay_sampler.cc new file mode 100644 index 0000000..cbed60a --- /dev/null +++ b/reverb/cc/replay_sampler.cc @@ -0,0 +1,419 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "reverb/cc/replay_sampler.h" + +#include +#include + +#include "grpcpp/impl/codegen/client_context.h" +#include "grpcpp/impl/codegen/sync_stream.h" +#include "reverb/cc/platform/logging.h" +#include "reverb/cc/platform/thread.h" +#include "reverb/cc/replay_service.pb.h" +#include "reverb/cc/support/grpc_util.h" +#include "reverb/cc/tensor_compression.h" +#include "tensorflow/core/framework/tensor_util.h" +#include "tensorflow/core/lib/core/status.h" + +namespace deepmind { +namespace reverb { +namespace { + +inline bool SampleIsDone(const std::vector& sample) { + if (sample.empty()) return false; + int64_t chunk_length = 0; + for (const auto& response : sample) { + chunk_length += response.data().data(0).tensor_shape().dim(0).size(); + } + const auto& range = sample.front().info().item().sequence_range(); + return chunk_length >= range.length() + range.offset(); +} + +template +tensorflow::Tensor InitializeTensor(T value, int64_t length) { + tensorflow::Tensor tensor(tensorflow::DataTypeToEnum::v(), + tensorflow::TensorShape({length})); + auto tensor_t = tensor.flat(); + std::fill(tensor_t.data(), tensor_t.data() + length, value); + return tensor; +} + +std::unique_ptr AsSample(std::vector responses) { + const auto& info = responses.front().info(); + + // Extract all chunks belonging to this sample. + std::list> chunks; + + // The chunks are not required to be aligned perfectly with the data so a + // part of the first chunk is potentially stripped. The same applies to the + // last part of the final chunk. + int64_t offset = info.item().sequence_range().offset(); + int64_t remaining = info.item().sequence_range().length(); + + for (auto& response : responses) { + REVERB_CHECK_GT(remaining, 0); + + std::vector batches; + batches.resize(response.data().data_size()); + + int64_t batch_size = -1; + + // Convert each chunk tensor and release the chunk memory afterwards. + int64_t insert_index = response.data().data_size() - 1; + while (!response.data().data().empty()) { + tensorflow::Tensor batch; + + { + // This ensures we release the response proto after converting the + // result to a tensor. 
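+        // Chunks are detached from the back of the repeated field
+        // (ReleaseLast), which is why `insert_index` fills `batches` from the
+        // back.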
+ auto chunk = absl::WrapUnique( + response.mutable_data()->mutable_data()->ReleaseLast()); + batch = DecompressTensorFromProto(*chunk); + } + + if (response.data().delta_encoded()) { + batch = DeltaEncode(batch, /*encode=*/false); + } + + if (batch_size < 0) { + batch_size = batch.dim_size(0); + } else { + REVERB_CHECK_EQ(batch_size, batch.dim_size(0)) + << "Chunks of the same response have varying batch size."; + } + + batch = + batch.Slice(offset, std::min(offset + remaining, batch_size)); + if (!batch.IsAligned()) { + batch = tensorflow::tensor::DeepCopy(batch); + } + + batches[insert_index--] = std::move(batch); + } + + chunks.push_back(std::move(batches)); + + remaining -= std::min(remaining, batch_size - offset); + offset = 0; + } + + REVERB_CHECK_EQ(remaining, 0); + + return absl::make_unique(info.item().key(), info.probability(), + info.table_size(), std::move(chunks)); +} + +} // namespace + +ReplaySampler::ReplaySampler( + std::shared_ptr stub, + const std::string& table, const Options& options) + : stub_(std::move(stub)), + max_samples_(options.max_samples == kUnlimitedMaxSamples + ? INT64_MAX + : options.max_samples), + max_samples_per_stream_(options.max_samples_per_stream == kAutoSelectValue + ? kDefaultMaxSamplesPerStream + : options.max_samples_per_stream), + active_sample_(nullptr), + samples_(std::max(options.num_workers, 1)) { + REVERB_CHECK_GT(max_samples_, 0); + REVERB_CHECK_GT(options.max_in_flight_samples_per_worker, 0); + REVERB_CHECK(options.num_workers == kAutoSelectValue || + options.num_workers > 0); + + int64_t num_workers = options.num_workers == kAutoSelectValue + ? kDefaultNumWorkers + : options.num_workers; + + // If a subset of the workers are able to fetch all of `max_samples_` in the + // first batch then there is no point in creating all of them. + num_workers = std::min( + num_workers, + std::max(1, + max_samples_ / options.max_in_flight_samples_per_worker)); + + for (int i = 0; i < num_workers; i++) { + workers_.push_back(absl::make_unique( + stub_, table, options.max_in_flight_samples_per_worker)); + worker_threads_.push_back(internal::StartThread( + absl::StrCat("SampleWorker", i), + [this, worker = workers_[i].get()] { RunWorker(worker); })); + } +} + +ReplaySampler::~ReplaySampler() { Close(); } + +tensorflow::Status ReplaySampler::GetNextTimestep( + std::vector* data, bool* end_of_sequence) { + TF_RETURN_IF_ERROR(MaybeSampleNext()); + + *data = active_sample_->GetNextTimestep(); + + if (end_of_sequence != nullptr) { + *end_of_sequence = active_sample_->is_end_of_sample(); + } + + if (active_sample_->is_end_of_sample()) { + absl::WriterMutexLock lock(&mu_); + if (++returned_ == max_samples_) samples_.Close(); + } + + return tensorflow::Status::OK(); +} + +tensorflow::Status ReplaySampler::GetNextSample( + std::vector* data) { + std::unique_ptr sample; + TF_RETURN_IF_ERROR(PopNextSample(&sample)); + *data = sample->AsBatchedTimesteps(); + + absl::WriterMutexLock lock(&mu_); + if (++returned_ == max_samples_) samples_.Close(); + return tensorflow::Status::OK(); +} + +bool ReplaySampler::should_stop_workers() const { + return closed_ || returned_ == max_samples_ || !stream_status_.ok(); +} + +void ReplaySampler::Close() { + { + absl::WriterMutexLock lock(&mu_); + if (closed_) return; + closed_ = true; + } + + for (auto& worker : workers_) { + worker->Cancel(); + } + + samples_.Close(); + worker_threads_.clear(); // Joins worker threads. 
+} + +tensorflow::Status ReplaySampler::MaybeSampleNext() { + if (active_sample_ != nullptr && !active_sample_->is_end_of_sample()) { + return tensorflow::Status::OK(); + } + + return PopNextSample(&active_sample_); +} + +tensorflow::Status ReplaySampler::PopNextSample( + std::unique_ptr* sample) { + if (samples_.Pop(sample)) return tensorflow::Status::OK(); + + absl::ReaderMutexLock lock(&mu_); + if (returned_ == max_samples_) { + return tensorflow::errors::OutOfRange("`max_samples` already returned."); + } + if (closed_) { + return tensorflow::errors::Cancelled("Sampler has been cancelled."); + } + return FromGrpcStatus(stream_status_); +} + +void ReplaySampler::RunWorker(Worker* worker) { + auto trigger = [this]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + return should_stop_workers() || requested_ < max_samples_; + }; + + while (true) { + mu_.LockWhen(absl::Condition(&trigger)); + + if (should_stop_workers()) { + mu_.Unlock(); + return; + } + int64_t samples_to_stream = + std::min(max_samples_per_stream_, max_samples_ - requested_); + requested_ += samples_to_stream; + mu_.Unlock(); + + auto result = worker->OpenStreamAndFetch(&samples_, samples_to_stream); + + { + absl::WriterMutexLock lock(&mu_); + + // If the stream was closed prematurely then we need to reduce the number + // of requested samples by the difference of the expected number and the + // actual. + requested_ -= samples_to_stream - result.first; + + // Overwrite the final status only if it wasn't already an error. + if (stream_status_.ok() && !result.second.ok() && + result.second.error_code() != grpc::StatusCode::UNAVAILABLE) { + stream_status_ = result.second; + samples_.Close(); // Unblock any pending calls. + return; + } + } + } +} + +ReplaySampler::Worker::Worker( + std::shared_ptr stub, + std::string table, int64_t samples_per_request) + : stub_(std::move(stub)), + table_(std::move(table)), + samples_per_request_(samples_per_request) {} + +std::pair ReplaySampler::Worker::OpenStreamAndFetch( + deepmind::reverb::internal::Queue>* queue, + int64_t num_samples) { + std::unique_ptr> + stream; + { + absl::MutexLock lock(&mu_); + if (closed_) { + return {0, grpc::Status(grpc::StatusCode::CANCELLED, + "`Close` called on ReplaySampler.")}; + } + context_ = absl::make_unique(); + context_->set_wait_for_ready(false); + stream = stub_->SampleStream(context_.get()); + } + + int64_t num_samples_returned = 0; + while (num_samples_returned < num_samples) { + SampleStreamRequest request; + request.set_table(table_); + request.set_num_samples( + std::min(samples_per_request_, num_samples - num_samples_returned)); + + if (!stream->Write(request)) { + return {num_samples_returned, stream->Finish()}; + } + + for (int64_t i = 0; i < request.num_samples(); i++) { + std::vector responses; + while (!SampleIsDone(responses)) { + SampleStreamResponse response; + if (!stream->Read(&response)) { + return {num_samples_returned, stream->Finish()}; + } + responses.push_back(std::move(response)); + } + + if (!queue->Push(AsSample(std::move(responses)))) { + return {num_samples_returned, + grpc::Status(grpc::StatusCode::CANCELLED, + "`Close` called on ReplaySampler.")}; + } + ++num_samples_returned; + } + } + + // TODO(b/147404612): Remove this or return INTERNAL error. 
+ REVERB_CHECK_EQ(num_samples_returned, num_samples); + return {num_samples_returned, grpc::Status::OK}; +} + +void ReplaySampler::Worker::Cancel() { + absl::MutexLock lock(&mu_); + closed_ = true; + if (context_ != nullptr) context_->TryCancel(); +} + +Sample::Sample(tensorflow::uint64 key, double probability, + tensorflow::int64 table_size, + std::list> chunks) + : key_(key), + probability_(probability), + table_size_(table_size), + num_timesteps_(0), + num_data_tensors_(0), + chunks_(std::move(chunks)), + next_timestep_index_(0), + next_timestep_called_(false) { + REVERB_CHECK(!chunks_.empty()) << "Must provide at least one chunk."; + REVERB_CHECK(!chunks_.front().empty()) + << "Chunks must hold at least one tensor."; + + num_data_tensors_ = chunks_.front().size(); + for (const auto& batches : chunks_) { + num_timesteps_ += batches.front().dim_size(0); + } +} + +std::vector Sample::GetNextTimestep() { + REVERB_CHECK(!is_end_of_sample()); + + // Construct the output tensors. + std::vector result; + result.reserve(num_data_tensors_ + 3); + result.push_back(tensorflow::Tensor(key_)); + result.push_back(tensorflow::Tensor(probability_)); + result.push_back(tensorflow::Tensor(table_size_)); + + for (const auto& t : chunks_.front()) { + auto slice = t.SubSlice(next_timestep_index_); + if (slice.IsAligned()) { + result.push_back(std::move(slice)); + } else { + result.push_back(tensorflow::tensor::DeepCopy(slice)); + } + } + + // Advance the iterator. + ++next_timestep_index_; + if (next_timestep_index_ == chunks_.front().front().dim_size(0)) { + // Go to the next chunk. + chunks_.pop_front(); + next_timestep_index_ = 0; + } + next_timestep_called_ = true; + + return result; +} + +bool Sample::is_end_of_sample() const { return chunks_.empty(); } + +std::vector Sample::AsBatchedTimesteps() { + CHECK(!next_timestep_called_) << "Some time steps have been lost."; + + std::vector sequences(num_data_tensors_ + 3); + + // Initialize the first three items with the key, probability and table size. + sequences[0] = InitializeTensor(key_, num_timesteps_); + sequences[1] = InitializeTensor(probability_, num_timesteps_); + sequences[2] = InitializeTensor(table_size_, num_timesteps_); + + // Prepare the data for concatenation. + // data_tensors[i][j] is the j-th chunk of the i-th data tensor. + std::vector> data_tensors(num_data_tensors_); + + // Extract all chunks. + while (!chunks_.empty()) { + auto it_to = data_tensors.begin(); + for (auto& batch : chunks_.front()) { + (it_to++)->push_back(std::move(batch)); + } + chunks_.pop_front(); + } + + // Concatenate all chunks. + int64_t i = 3; + for (const auto& chunks : data_tensors) { + TF_CHECK_OK(tensorflow::tensor::Concat(chunks, &sequences[i++])); + } + + return sequences; +} + +} // namespace reverb +} // namespace deepmind diff --git a/reverb/cc/replay_sampler.h b/reverb/cc/replay_sampler.h new file mode 100644 index 0000000..5df8e4f --- /dev/null +++ b/reverb/cc/replay_sampler.h @@ -0,0 +1,299 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LEARNING_DEEPMIND_REPLAY_REVERB_REPLAY_SAMPLER_H_
+#define LEARNING_DEEPMIND_REPLAY_REVERB_REPLAY_SAMPLER_H_
+
+#include
+
+#include
+#include
+#include
+#include
+
+#include
+#include "reverb/cc/platform/thread.h"
+#include "reverb/cc/replay_service.grpc.pb.h"
+#include "reverb/cc/support/queue.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace deepmind {
+namespace reverb {
+
+// A sample from the replay buffer.
+class Sample {
+ public:
+ Sample(tensorflow::uint64 key, double probability,
+ tensorflow::int64 table_size,
+ std::list<std::vector<tensorflow::Tensor>> chunks);
+
+ // Returns the next time step from this sample as a flat sequence of tensors.
+ // CHECK-fails if the entire sample has already been returned.
+ std::vector<tensorflow::Tensor> GetNextTimestep();
+
+ // Returns the entire sample as a flat sequence of batched tensors.
+ // CHECK-fails if `GetNextTimestep()` has already been called on this sample.
+ // Returns:
+ // K+3 tensors, each having a leading dimension of size N (= sample
+ // length). The first three tensors are 1D (length N) representing the key,
+ // sample probability and table size respectively. The following K tensors
+ // hold the actual timestep data batched into a tensor of shape [N,
+ // ...original_shape].
+ std::vector<tensorflow::Tensor> AsBatchedTimesteps();
+
+ // Returns true if the end of the sample has been reached.
+ ABSL_MUST_USE_RESULT bool is_end_of_sample() const;
+
+ private:
+ // The key of the replay item this time step was sampled from.
+ tensorflow::uint64 key_;
+ // The probability of the replay item this time step was sampled from.
+ double probability_;
+ // The size of the replay table this time step was sampled from at the time
+ // of sampling.
+ tensorflow::int64 table_size_;
+
+ // Total number of time steps in this sample.
+ int64_t num_timesteps_;
+
+ // Number of data tensors per time step.
+ int64_t num_data_tensors_;
+
+ // A list of tensor chunks.
+ std::list<std::vector<tensorflow::Tensor>> chunks_;
+
+ // The next time step to return when GetNextTimestep() is called.
+ int64_t next_timestep_index_;
+
+ // True if GetNextTimestep() has been called on this sample.
+ bool next_timestep_called_;
+};
+
+// The `ReplaySampler` class should be used to retrieve samples from a
+// ReplayService. A set of workers is created, each managing a bi-directional
+// gRPC stream. The workers unpack the responses into samples (sequences of
+// timesteps) which are returned through calls to `GetNextTimestep` and
+// `GetNextSample`.
+//
+// Concurrent calls to `GetNextTimestep` are NOT supported! This includes
+// calling `GetNextSample` and `GetNextTimestep` concurrently.
+//
+// Terminology:
+// Timestep:
+// Set of tensors representing a single "step" (i.e. data passed to
+// `ReplayWriter::AppendTimestep`).
+// Chunk:
+// Timesteps batched (along the time dimension) and compressed. If each
+// timestep contains K tensors of dtype dt_k and shape s_k and the chunk
+// has length N then the chunk will contain K (compressed) tensors of dtype
+// dt_k and shape [N, ...s_k].
+// Sample:
+// Metadata (i.e. priority, key) and the sequence of timesteps that
+// constitute an item in a `PriorityTable`. During transmission the "sample"
+// is made up of a vector of chunks and metadata that defines which parts of
+// the chunks are actually part of the sample. Once received, the sample is
+// unpacked into a sequence of `Timestep`s before being returned to the
+// caller.
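+// For example (illustrative shapes): with K = 2 timestep tensors of shapes
+// [3] and [84, 84] and a chunk length of N = 100, each chunk transmits two
+// compressed tensors of shapes [100, 3] and [100, 84, 84], and an item
+// covering timesteps 20..29 of that chunk references it with offset 20 and
+// length 10 in its sequence_range.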
+// Worker: +// Instance of `ReplaySampler::Worker` running within its own thread managed +// by the parent `ReplaySampler`. The worker opens and manages +// bi-directional gRPC streams to the server. It unpacks responses into +// samples and pushes these into a `Queue` owned by the `ReplaySampler` +// (effectively merging the outputs of the workers). +// +// Each `ReplaySampler` will create a set of `Worker`s, each managing a stream +// to a server. The combined output of the workers are merged into a `Queue` of +// complete samples. If `GetNextTimestep` is called then a sample is popped from +// the queue and split into timesteps and the first one returned. Timesteps are +// then popped one by one until the sample has been completely emitted and the +// process starts over. Calls to `GetNextSample` skips the timestep splitting +// and returns samples as a "batch of timesteps". +// +class ReplaySampler { + public: + static const int64_t kUnlimitedMaxSamples = -1; + static const int kAutoSelectValue = -1; + + // By default, streams are only allowed to be open for a small number + // (10000) of samples. A larger value could provide better performance + // (reconnecting less frequently) but increases the risk of subtle "bias" in + // the sampling distribution across a multi server setup. The bias will be + // caused by a non uniform number of SampleStream-connections across the + // servers being maintained for a longer period. The same phenomenon is + // present with more short lived connections but is mitigated by the round + // robin of the (more) frequently created new connections. + // TODO(b/147425281): Set this value higher for localhost connections. + static const int kDefaultMaxSamplesPerStream = 10000; + + // By default, only one worker is used as any higher number could lead to + // incorrect behavior for FIFO samplers. + static const int kDefaultNumWorkers = 1; + + struct Options { + // `max_samples` is the maximum number of samples the object will return. + // Must be a positive number or `kUnlimitedMaxSamples`. + int64_t max_samples = kUnlimitedMaxSamples; + + // `max_in_flight_samples_per_worker` is the number of samples requested by + // a worker in each batch. A new batch is requested once all the requested + // samples have been received. + int max_in_flight_samples_per_worker = 100; + + // `num_workers` is the number of worker threads started. + // + // When set to `kAutoSelectValue`, `kDefaultNumWorkers` is used. + int num_workers = kAutoSelectValue; + + // `max_samples_per_stream` is the maximum number of samples to fetch from a + // stream before a new call is made. Keeping this number low ensures that + // the data is fetched uniformly from all servers behind the `stub`. + // + // When set to `kAutoSelectValue`, `kDefaultMaxSamplesPerStream` is used. + int max_samples_per_stream = kAutoSelectValue; + }; + + // Constructs a new `ReplaySampler`. + // + // `stub` is a connected gRPC stub to the ReplayService. + // `table` is the name of the `PriorityTable` to sample from. + // `options` defines details of how to samples. + ReplaySampler(std::shared_ptr stub, + const std::string& table, const Options& options); + + // Joins worker threads through call to `Close`. + virtual ~ReplaySampler(); + + // Blocks until a timestep has been retrieved or until a non transient error + // is encountered or `Close` has been called. 
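+ // The returned tensors are flattened as [key, probability, table_size,
+ // data_0, ..., data_{K-1}]. If `end_of_sequence` is non-null it is set to
+ // true iff the returned timestep is the last one of its sample.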
+ tensorflow::Status GetNextTimestep(std::vector* data, + bool* end_of_sequence); + + // Blocks until a complete sample has been retrieved or until a non transient + // error is encountered or `Close` has been called. + tensorflow::Status GetNextSample(std::vector* data); + + // Cancels all workers and joins their threads. Any blocking or future call + // to `GetNextTimestep` or `GetNextSample` will return CancelledError without + // blocking. + void Close(); + + // ReplaySampler is neither copyable nor movable. + ReplaySampler(const ReplaySampler&) = delete; + ReplaySampler& operator=(const ReplaySampler&) = delete; + + private: + class Worker { + public: + // Constructs a new worker without creating a stream to a server. + explicit Worker( + std::shared_ptr stub, + std::string table, int64_t samples_per_request); + + // Cancels the stream and marks the worker as closed. Active and future + // calls to `OpenStreamAndFetch` will return status `CANCELLED`. + void Cancel(); + + // Opens a new `SampleStream` to a server and requests `num_samples` samples + // in batches with maximum size `samples_per_request`. Once complete (either + // done or from non transient error), the stream is closed and the number of + // samples pushed to `queue` is returned together with the status of the + // stream. + std::pair OpenStreamAndFetch( + internal::Queue>* queue, int64_t num_samples); + + private: + // Stub used to open `SampleStream`-streams to a server. + std::shared_ptr stub_; + + // Name of the `PriorityTable` to sample from. + const std::string table_; + + // The maximum number of samples to request in a "batch". + const int64_t samples_per_request_; + + // Context of the active stream. + std::unique_ptr context_ ABSL_GUARDED_BY(mu_); + + // True if `Cancel` has been called. + bool closed_ ABSL_GUARDED_BY(mu_) = false; + + absl::Mutex mu_; + }; + + void RunWorker(Worker* worker) ABSL_LOCKS_EXCLUDED(mu_); + + // If `active_sample_` has been read, blocks until a sample has been retrieved + // (popped from `samples_`) and populates `active_sample_`. + tensorflow::Status MaybeSampleNext(); + + // Blocks until a complete sample has been retrieved or until a non transient + // error is encountered or `Close` has been called. Note that this method does + // NOT increment `returned_`. This is left to `GetNextTimestep` and + // `GetNextSample`. The returned pointer is only valid if the status is OK. + tensorflow::Status PopNextSample(std::unique_ptr* sample); + + // True if the workers should be shut down. This is the case when either: + // - `Close` has been called. + // - The number of returned samples equal `max_samples_`. + // - One of the worker streams has been closed with a non transient error + // status. + bool should_stop_workers() const ABSL_SHARED_LOCKS_REQUIRED(mu_); + + // Stub used by workers to open SampleStream-connections to the servers. Note + // that the endpoints are load balanced using "roundrobin" which results in + // uniform sampling when using multiple backends. + std::shared_ptr stub_; + + // The maximum number of samples to fetch. Calls to `GetNextTimestep` or + // `GetNextSample` after `max_samples_` has been returned will result in + // OutOfRangeError. + const int64_t max_samples_; + + // The maximum number of samples to stream from a single call. Once the number + // of samples has been reached, a new stream is opened through the `stub_`. + // This ensures that data is fetched from all the servers. 
+ const int64_t max_samples_per_stream_;
+
+ // The number of complete samples that have been successfully requested.
+ int64_t requested_ ABSL_GUARDED_BY(mu_) = 0;
+
+ // The number of complete samples that have been returned through
+ // `GetNextTimestep`.
+ int64_t returned_ ABSL_GUARDED_BY(mu_) = 0;
+
+ // Workers and threads managing the worker with the same index.
+ std::vector<std::unique_ptr<Worker>> workers_;
+ std::vector<std::unique_ptr<internal::Thread>> worker_threads_;
+
+ // Remaining timesteps of the currently active sample. Note that this is not
+ // protected by a mutex as concurrent calls to `GetNextTimestep` are not
+ // supported.
+ std::unique_ptr<Sample> active_sample_;
+
+ // Queue of complete samples (timesteps batched up into sequences).
+ internal::Queue<std::unique_ptr<Sample>> samples_;
+
+ // Set if `Close` has been called.
+ bool closed_ ABSL_GUARDED_BY(mu_) = false;
+
+ // OK or the first non transient error encountered by a worker.
+ grpc::Status stream_status_ ABSL_GUARDED_BY(mu_);
+
+ mutable absl::Mutex mu_;
+};
+
+} // namespace reverb
+} // namespace deepmind
+
+#endif // LEARNING_DEEPMIND_REPLAY_REVERB_REPLAY_SAMPLER_H_ 
diff --git a/reverb/cc/replay_sampler_test.cc b/reverb/cc/replay_sampler_test.cc new file mode 100644 index 0000000..a00de73 --- /dev/null +++ b/reverb/cc/replay_sampler_test.cc @@ -0,0 +1,541 @@
+// Copyright 2019 DeepMind Technologies Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "reverb/cc/replay_sampler.h" + +#include +#include + +#include "grpcpp/client_context.h" +#include "grpcpp/impl/codegen/call_op_set.h" +#include "grpcpp/impl/codegen/status.h" +#include "grpcpp/support/status.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/base/thread_annotations.h" +#include "absl/synchronization/mutex.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "reverb/cc/platform/logging.h" +#include "reverb/cc/replay_service.pb.h" +#include "reverb/cc/replay_service_mock.grpc.pb.h" +#include "reverb/cc/tensor_compression.h" +#include "reverb/cc/testing/tensor_testutil.h" +#include "reverb/cc/testing/time_testutil.h" +#include "tensorflow/core/framework/tensor_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace deepmind { +namespace reverb { +namespace { + +using test::ExpectTensorEqual; +using ::testing::SizeIs; + +class FakeStream + : public grpc::ClientReaderWriterInterface { + public: + FakeStream(std::function on_write, + std::vector responses, grpc::Status status) + : on_write_(std::move(on_write)), + responses_(std::move(responses)), + status_(std::move(status)) {} + + bool Write(const SampleStreamRequest& request, + grpc::WriteOptions options) override { + on_write_(request); + return status_.ok(); + } + + bool Read(SampleStreamResponse* response) override { + if (!responses_.empty() && status_.ok()) { + *response->mutable_data() = responses_.front().data(); + *response->mutable_info() = responses_.front().info(); + responses_.erase(responses_.begin()); + return true; + } + return false; + } + + void WaitForInitialMetadata() override {} + + bool WritesDone() override { return true; } + + bool NextMessageSize(uint32_t* sz) override { + *sz = responses_.front().ByteSizeLong(); + return true; + } + + grpc::Status Finish() override { return status_; } + + private: + std::function on_write_; + std::vector responses_; + grpc::Status status_; +}; + +class FakeStub : public /* grpc_gen:: */MockReplayServiceStub { + public: + grpc::ClientReaderWriterInterface* + SampleStreamRaw(grpc::ClientContext* context) override { + absl::WriterMutexLock lock(&mu_); + if (!streams_.empty()) { + FakeStream* stream = streams_.front().release(); + streams_.pop_front(); + return stream; + } + + return new FakeStream( + [this](const SampleStreamRequest& request) { + absl::WriterMutexLock lock(&mu_); + requests_.push_back(request); + }, + {}, grpc::Status::OK); + } + + void AddStream(std::vector responses, + grpc::Status status = grpc::Status::OK) { + absl::WriterMutexLock lock(&mu_); + streams_.push_back(absl::make_unique( + [this](const SampleStreamRequest& request) { + absl::WriterMutexLock lock(&mu_); + requests_.push_back(request); + }, + std::move(responses), std::move(status))); + } + + std::vector requests() const { + absl::ReaderMutexLock lock(&mu_); + return requests_; + } + + private: + std::list> streams_ ABSL_GUARDED_BY(mu_); + std::vector requests_ ABSL_GUARDED_BY(mu_); + mutable absl::Mutex mu_; +}; + +std::shared_ptr MakeFlakyStub( + std::vector responses, + std::vector errors) { + auto stub = std::make_shared(); + for (const auto& error : errors) { + stub->AddStream(responses, error); + } + stub->AddStream(responses); + return stub; +} + +std::shared_ptr MakeGoodStub( + std::vector responses) { + return MakeFlakyStub(std::move(responses), /*errors=*/{}); +} + +tensorflow::Tensor MakeTensor(int length) { + tensorflow::TensorShape shape({length, 2}); + tensorflow::Tensor tensor(tensorflow::DT_UINT64, 
shape); + for (int i = 0; i < tensor.NumElements(); i++) { + tensor.flat().data()[i] = i; + } + return tensor; +} + +SampleStreamResponse MakeResponse(int item_length, bool delta_encode = false, + int offset = 0, int data_length = 0) { + if (data_length == 0) { + data_length = item_length; + } + REVERB_CHECK_LE(item_length + offset, data_length); + + SampleStreamResponse response; + response.mutable_info()->mutable_item()->mutable_sequence_range()->set_length( + item_length); + response.mutable_info()->mutable_item()->mutable_sequence_range()->set_offset( + offset); + auto tensor = MakeTensor(data_length); + if (delta_encode) { + tensor = DeltaEncode(tensor, true); + response.mutable_data()->set_delta_encoded(true); + } + + CompressTensorAsProto(tensor, response.mutable_data()->add_data()); + return response; +} + +TEST(ReplaySamplerTest, SendsFirstRequest) { + auto stub = MakeGoodStub({MakeResponse(1)}); + ReplaySampler sampler(stub, "table", {1, 1, 1}); + std::vector sample; + bool end_of_sequence; + TF_EXPECT_OK(sampler.GetNextTimestep(&sample, &end_of_sequence)); + EXPECT_THAT(stub->requests(), SizeIs(1)); +} + +TEST(ReplaySamplerTest, SetsEndOfSequence) { + auto stub = MakeGoodStub({MakeResponse(2), MakeResponse(1)}); + ReplaySampler sampler(stub, "table", {2, 1}); + + std::vector sample; + bool end_of_sequence; + + // First sequence has 2 timesteps so first timestep should not be the end of + // a sequence. + TF_EXPECT_OK(sampler.GetNextTimestep(&sample, &end_of_sequence)); + absl::SleepFor(absl::Milliseconds(5)); + EXPECT_FALSE(end_of_sequence); + + // Second timestep is the end of the first sequence. + TF_EXPECT_OK(sampler.GetNextTimestep(&sample, &end_of_sequence)); + absl::SleepFor(absl::Milliseconds(5)); + EXPECT_TRUE(end_of_sequence); + + // Third timestep is the first and only timestep of the second sequence. + TF_EXPECT_OK(sampler.GetNextTimestep(&sample, &end_of_sequence)); + absl::SleepFor(absl::Milliseconds(5)); + EXPECT_TRUE(end_of_sequence); +} + +TEST(ReplaySamplerTest, GetNextSampleReturnsWholeSequence) { + auto stub = MakeGoodStub({MakeResponse(5), MakeResponse(3)}); + ReplaySampler sampler(stub, "table", {2, 1}); + + std::vector first; + TF_EXPECT_OK(sampler.GetNextSample(&first)); + EXPECT_THAT(first, SizeIs(4)); // ID, probability, table size, data. + ExpectTensorEqual(first[3], MakeTensor(5)); + + std::vector second; + TF_EXPECT_OK(sampler.GetNextSample(&second)); + EXPECT_THAT(second, SizeIs(4)); // ID, probability, table size, data. + ExpectTensorEqual(second[3], MakeTensor(3)); +} + +TEST(ReplaySamplerTest, GetNextSampleTrimsSequence) { + auto stub = MakeGoodStub({ + MakeResponse(5, false, 1, 6), // Trim offset at the start. + MakeResponse(3, false, 0, 4), // Trim timestep from end. + MakeResponse(2, false, 1, 10), // Trim offset and end. + }); + ReplaySampler sampler(stub, "table", {3, 1}); + + std::vector start_trimmed; + TF_EXPECT_OK(sampler.GetNextSample(&start_trimmed)); + ASSERT_THAT(start_trimmed, SizeIs(4)); // ID, probability, table size, data. + ExpectTensorEqual( + start_trimmed[3], + tensorflow::tensor::DeepCopy(MakeTensor(6).Slice(1, 6))); + + std::vector end_trimmed; + TF_EXPECT_OK(sampler.GetNextSample(&end_trimmed)); + ASSERT_THAT(end_trimmed, SizeIs(4)); // ID, probability, table size, data. + ExpectTensorEqual(end_trimmed[3], + MakeTensor(4).Slice(0, 3)); + + std::vector start_and_end_trimmed; + TF_EXPECT_OK(sampler.GetNextSample(&start_and_end_trimmed)); + ASSERT_THAT(start_and_end_trimmed, + SizeIs(4)); // ID, probability, table size, data. 
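+ // (With offset 1 and item length 2, rows 1..2 of the 10-row chunk are
+ // expected, hence Slice(1, 3).)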
+ ExpectTensorEqual( + start_and_end_trimmed[3], + tensorflow::tensor::DeepCopy(MakeTensor(10).Slice(1, 3))); +} + +TEST(ReplaySamplerTest, RespectsBufferSizeAndMaxSamples) { + const int kMaxSamples = 20; + const int kMaxInFlightSamplesPerWorker = 11; + const int kNumWorkers = 1; + + std::vector responses; + for (int i = 0; i < 40; i++) responses.push_back(MakeResponse(1)); + auto stub = MakeGoodStub(std::move(responses)); + + ReplaySampler sampler( + stub, "table", {kMaxSamples, kMaxInFlightSamplesPerWorker, kNumWorkers}); + + test::WaitFor( + [&]() { + return !stub->requests().empty() && stub->requests()[0].num_samples() == + kMaxInFlightSamplesPerWorker; + }, + absl::Milliseconds(10), 100); + + std::vector sample; + bool end_of_sequence; + + // The first request should aim to fill up the buffer. + ASSERT_THAT(stub->requests(), SizeIs(1)); + EXPECT_EQ(stub->requests()[0].num_samples(), kMaxInFlightSamplesPerWorker); + + // The queue outside the workers has size `num_workers` (i.e 1 here) so in + // addition to the samples actually returned to the user, an additional + // sample is considered to been "consumed" from the perspective of the worker. + + // The first 9 (9 + 1 = 10) pops should not result in a new request. + for (int i = 0; i < kMaxInFlightSamplesPerWorker - kNumWorkers - 1; i++) { + TF_EXPECT_OK(sampler.GetNextTimestep(&sample, &end_of_sequence)); + } + + test::WaitFor([&]() { return stub->requests().size() == 1; }, + absl::Milliseconds(10), 100); + EXPECT_THAT(stub->requests(), SizeIs(1)); + + // The 10th sample (+1 in the queue) mean that all the requested samples + // have been received and thus a new request is sent to retrieve more. + TF_EXPECT_OK(sampler.GetNextTimestep(&sample, &end_of_sequence)); + test::WaitFor( + [&]() { + return stub->requests().size() == 2 && + stub->requests()[1].num_samples() == + kMaxSamples - kMaxInFlightSamplesPerWorker; + }, + absl::Milliseconds(10), 100); + EXPECT_THAT(stub->requests(), SizeIs(2)); + + // The second request should respect the `max_samples` and thus only request + // 9 (9 + 11 = 20) more samples. + EXPECT_EQ(stub->requests()[1].num_samples(), + kMaxSamples - kMaxInFlightSamplesPerWorker); + + // Consuming the remaining 10 samples should not trigger any more requests + // as this would violate `max_samples`. 
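+ // (At this point 10 of the 20 max samples have been returned and the two
+ // requests already issued, 11 + 9, cover the remaining 10.)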
+ for (int i = 0; i < 10; i++) { + TF_EXPECT_OK(sampler.GetNextTimestep(&sample, &end_of_sequence)); + } + test::WaitFor([&]() { return stub->requests().size() == 2; }, + absl::Milliseconds(10), 100); + EXPECT_THAT(stub->requests(), SizeIs(2)); +} + +TEST(ReplaySamplerTest, UnpacksDeltaEncodedTensors) { + auto stub = MakeGoodStub({MakeResponse(10, false), MakeResponse(10, true)}); + ReplaySampler sampler(stub, "table", {2, 1}); + std::vector not_encoded; + std::vector encoded; + TF_EXPECT_OK(sampler.GetNextSample(¬_encoded)); + TF_EXPECT_OK(sampler.GetNextSample(&encoded)); + ASSERT_EQ(not_encoded.size(), encoded.size()); + EXPECT_EQ(encoded[0].dtype(), tensorflow::DT_UINT64); + for (int i = 3; i < encoded.size(); i++) { + ExpectTensorEqual(encoded[i], not_encoded[i]); + } +} + +TEST(ReplaySamplerTest, GetNextTimestepForwardsFatalServerError) { + const int kNumWorkers = 4; + const int kItemLength = 10; + const auto kError = grpc::Status(grpc::StatusCode::NOT_FOUND, ""); + + auto stub = MakeFlakyStub({MakeResponse(kItemLength)}, {kError}); + ReplaySampler sampler(stub, "table", + {ReplaySampler::kUnlimitedMaxSamples, 1, kNumWorkers}); + + // It is possible that the sample returned by one of the workers is reached + // before the failing worker has reported it's error so we need to pop at + // least two samples to ensure that the we will see the error. + tensorflow::Status status; + for (int i = 0; status.ok() && i < kItemLength + 1; i++) { + std::vector sample; + bool end_of_sequence; + status = sampler.GetNextTimestep(&sample, &end_of_sequence); + } + EXPECT_EQ(status.code(), tensorflow::error::NOT_FOUND); + sampler.Close(); +} + +TEST(ReplaySamplerTest, GetNextSampleForwardsFatalServerError) { + const int kNumWorkers = 4; + const int kItemLength = 10; + const auto kError = grpc::Status(grpc::StatusCode::NOT_FOUND, ""); + + auto stub = MakeFlakyStub({MakeResponse(kItemLength)}, {kError}); + ReplaySampler sampler(stub, "table", + {ReplaySampler::kUnlimitedMaxSamples, 1, kNumWorkers}); + + // It is possible that the sample returned by one of the workers is reached + // before the failing worker has reported it's error so we need to pop at + // least two samples to ensure that the we will see the error. + tensorflow::Status status; + for (int i = 0; status.ok() && i < 2; i++) { + std::vector sample; + status = sampler.GetNextSample(&sample); + } + EXPECT_EQ(status.code(), tensorflow::error::NOT_FOUND); +} + +TEST(ReplaySamplerTest, GetNextTimestepRetriesTransientErrors) { + const int kNumWorkers = 2; + const int kItemLength = 10; + const auto kError = grpc::Status(grpc::StatusCode::UNAVAILABLE, ""); + + auto stub = MakeFlakyStub( + {MakeResponse(kItemLength), MakeResponse(kItemLength)}, {kError}); + ReplaySampler sampler(stub, "table", + {ReplaySampler::kUnlimitedMaxSamples, 1, kNumWorkers}); + + // It is possible that the sample returned by one of the workers is reached + // before the failing worker has reported it's error so we need to pop at + // least two samples to ensure that the we will see the error. 
+ for (int i = 0; i < kItemLength + 1; i++) { + std::vector sample; + bool end_of_sequence; + TF_EXPECT_OK(sampler.GetNextTimestep(&sample, &end_of_sequence)); + } +} + +TEST(ReplaySamplerTest, GetNextSampleRetriesTransientErrors) { + const int kNumWorkers = 2; + const int kItemLength = 10; + const auto kError = grpc::Status(grpc::StatusCode::UNAVAILABLE, ""); + + auto stub = MakeFlakyStub( + {MakeResponse(kItemLength), MakeResponse(kItemLength)}, {kError}); + ReplaySampler sampler(stub, "table", + {ReplaySampler::kUnlimitedMaxSamples, 1, kNumWorkers}); + + // It is possible that the sample returned by one of the workers is reached + // before the failing worker has reported it's error so we need to pop at + // least two samples to ensure that the we will see the error. + for (int i = 0; i < 2; i++) { + std::vector sample; + TF_EXPECT_OK(sampler.GetNextSample(&sample)); + } +} + +TEST(ReplaySamplerTest, GetNextTimestepReturnsErrorIfMaximumSamplesExceeded) { + auto stub = MakeGoodStub({MakeResponse(1), MakeResponse(1), MakeResponse(1)}); + ReplaySampler sampler(stub, "table", {2, 1, 1}); + std::vector sample; + bool end_of_sequence; + TF_EXPECT_OK(sampler.GetNextTimestep(&sample, &end_of_sequence)); + TF_EXPECT_OK(sampler.GetNextTimestep(&sample, &end_of_sequence)); + EXPECT_EQ(sampler.GetNextTimestep(&sample, &end_of_sequence).code(), + tensorflow::error::OUT_OF_RANGE); +} + +TEST(ReplaySamplerTest, GetNextSampleReturnsErrorIfMaximumSamplesExceeded) { + auto stub = MakeGoodStub({MakeResponse(5), MakeResponse(5), MakeResponse(5)}); + ReplaySampler sampler(stub, "table", {2, 1, 1}); + std::vector sample; + TF_EXPECT_OK(sampler.GetNextSample(&sample)); + TF_EXPECT_OK(sampler.GetNextSample(&sample)); + EXPECT_EQ(sampler.GetNextSample(&sample).code(), + tensorflow::error::OUT_OF_RANGE); +} + +TEST(ReplaySamplerTest, StressTestWithoutErrors) { + const int kNumWorkers = 100; // Should be larger than the number of CPUs. + const int kMaxSamples = 10000; + const int kMaxSamplesPerStream = 50; + const int kMaxInflightSamplesPerStream = 7; + const int kItemLength = 3; + + std::vector responses(kMaxSamplesPerStream); + for (int i = 0; i < kMaxSamplesPerStream; i++) { + responses[i] = MakeResponse(kItemLength); + } + + auto stub = std::make_shared(); + for (int i = 0; i < (kMaxSamples / kMaxSamplesPerStream) + kNumWorkers; i++) { + stub->AddStream(responses); + } + + ReplaySampler sampler(stub, "table", + {kMaxSamples, kMaxInflightSamplesPerStream, kNumWorkers, + kMaxSamplesPerStream}); + + for (int i = 0; i < kItemLength * kMaxSamples; i++) { + std::vector sample; + bool end_of_sequence; + TF_EXPECT_OK(sampler.GetNextTimestep(&sample, &end_of_sequence)); + } + + // There should be no more samples left. + std::vector sample; + bool end_of_sequence; + EXPECT_EQ(sampler.GetNextTimestep(&sample, &end_of_sequence).code(), + tensorflow::error::OUT_OF_RANGE); +} + +TEST(ReplaySamplerTest, StressTestWithTransientErrors) { + const int kNumWorkers = 100; // Should be larger than the number of CPUs. + const int kMaxSamples = 10000; + const int kMaxSamplesPerStream = 50; + const int kMaxInflightSamplesPerStream = 7; + const int kItemLength = 3; + const int kTransientErrorFrequency = 23; + + std::vector responses(kMaxSamplesPerStream); + for (int i = 0; i < kMaxSamplesPerStream; i++) { + responses[i] = MakeResponse(kItemLength); + } + + auto stub = std::make_shared(); + for (int i = 0; i < (kMaxSamples / kMaxSamplesPerStream) + kNumWorkers; i++) { + auto status = i % kTransientErrorFrequency != 0 + ? 
grpc::Status::OK + : grpc::Status(grpc::StatusCode::UNAVAILABLE, ""); + stub->AddStream(responses, status); + } + + ReplaySampler sampler(stub, "table", + {kMaxSamples, kMaxInflightSamplesPerStream, kNumWorkers, + kMaxSamplesPerStream}); + + for (int i = 0; i < kItemLength * kMaxSamples; i++) { + std::vector sample; + bool end_of_sequence; + TF_EXPECT_OK(sampler.GetNextTimestep(&sample, &end_of_sequence)); + } + + // There should be no more samples left. + std::vector sample; + bool end_of_sequence; + EXPECT_EQ(sampler.GetNextTimestep(&sample, &end_of_sequence).code(), + tensorflow::error::OUT_OF_RANGE); +} + +TEST(ReplaySamplerDeathTest, DiesIfMaxInFlightSamplesPerWorkerIsNonPositive) { + ReplaySampler::Options options; + + options.max_in_flight_samples_per_worker = 0; + ASSERT_DEATH(ReplaySampler sampler(MakeGoodStub({}), "table", options), ""); + + options.max_in_flight_samples_per_worker = -1; + ASSERT_DEATH(ReplaySampler sampler(MakeGoodStub({}), "table", options), ""); +} + +TEST(ReplaySamplerDeathTest, DiesIfMaxSamplesInvalid) { + ReplaySampler::Options options; + + options.max_samples = -2; + ASSERT_DEATH(ReplaySampler sampler(MakeGoodStub({}), "table", options), ""); + + options.max_samples = 0; + ASSERT_DEATH(ReplaySampler sampler(MakeGoodStub({}), "table", options), ""); +} + +TEST(ReplaySamplerDeathTest, DiesIfNumWorkersIsInvalid) { + ReplaySampler::Options options; + + options.num_workers = 0; + ASSERT_DEATH(ReplaySampler sampler(MakeGoodStub({}), "table", options), ""); + + options.num_workers = -2; + ASSERT_DEATH(ReplaySampler sampler(MakeGoodStub({}), "table", options), ""); +} + +} // namespace +} // namespace reverb +} // namespace deepmind diff --git a/reverb/cc/replay_service.proto b/reverb/cc/replay_service.proto new file mode 100644 index 0000000..f50974a --- /dev/null +++ b/reverb/cc/replay_service.proto @@ -0,0 +1,129 @@ +syntax = "proto3"; + +package deepmind.reverb; + +import "reverb/cc/schema.proto"; + +service ReplayService { + // Writes all in-memory data to disk. We store the priority tables + // along with the chunks. On success, the path of the checkpoint is + // returned. On preemption, the last checkpoint will be used to restore the + // replay. + rpc Checkpoint(CheckpointRequest) returns (CheckpointResponse) {} + + // Inserts chunks into the ChunkStore and items into priority tables. This + // operation is a stream that is called by `ReplayWriter`. A stream mesasage + // either inserts a chunk or an item into a priority table. + // + // Important: We keep a reference to each chunk that was written to + // the stream. When inserting an item into a priority table, this item is + // allowed to refer to any items that we keep references to. After inserting + // an item, we clear our references to all chunks which are not explictly + // specified in `keep_chunk_keys`. This means the typical + // order of stream messages is something like: [CHUNK C1] [CHUNK C2] [ITEM + // USING C1&C2 AND KEEP C2] [CHUNK C3] [ITEM USING C2&C3] + rpc InsertStream(stream InsertStreamRequest) returns (InsertStreamResponse) {} + + // Changes the items in a priority table. The operations might be + // applied partially if an error occurs. + rpc MutatePriorities(MutatePrioritiesRequest) + returns (MutatePrioritiesResponse) {} + + // Clears all items of a PriorityTable and resets its RateLimiter. + rpc Reset(ResetRequest) returns (ResetResponse) {} + + // Samples items from a priority table. 
The client starts by requesting a + // fixed (or unlimited) number of samples from a specific table. Once + // received, the server starts streaming the requested samples. The first + // response of each sample contains info in addition to the first chunk of + // data. A typical response looks like: [INFO + CHUNK] [CHUNK] [CHUNK] [INFO + + // CHUNK] [CHUNK] ... + // The client is allowed at any time to send additional requests for more + // samples. + rpc SampleStream(stream SampleStreamRequest) + returns (stream SampleStreamResponse) {} + + // Get updated information on all of the tables on the server. + rpc ServerInfo(ServerInfoRequest) returns (ServerInfoResponse) {} +} + +message CheckpointRequest {} + +message CheckpointResponse { + // Path to disk where the checkpoint was written to. + string checkpoint_path = 1; +} + +message InsertStreamRequest { + message PriorityInsertion { + // The item that should be inserted. + PrioritizedItem item = 1; + + // Specifies which chunk keys are needed in the next request. This will + // result in an internal reference which prevents the chunks from deletion + // until the next priority insertion. + repeated uint64 keep_chunk_keys = 2; + } + + oneof payload { + // Chunk that can be referenced by a PrioritizedItem. + ChunkData chunk = 1; + + // Item for inserting into priority tables. The item must only reference + // chunks that has been sent been sent on the stream thus far and kept after + // previous insertion requests. + PriorityInsertion item = 2; + } +} + +message InsertStreamResponse {} + +message MutatePrioritiesRequest { + // Name of the priority table that the operations below should be + // applied to. This field must be set. + string table = 1; + + // All operations below are applied in the order that they are listed. + // Different operations can be set at the same time. + + // Items to update. If an item does not exist, that item is ignored. + repeated KeyWithPriority updates = 2; + + // Items to delete. If an item does not exist, that item is deleted. + repeated uint64 delete_keys = 3; +} + +message MutatePrioritiesResponse {} + +message ServerInfoRequest {} + +message ServerInfoResponse { + Uint128 tables_state_id = 1; + repeated TableInfo table_info = 2; +} + +message SampleStreamRequest { + // Name of the priority table that we should sample from. + string table = 1; + + // The number of samples to stream. Defaults to infinite. + int64 num_samples = 2; +} + +message SampleStreamResponse { + // First response item in the stream is info about the sample. + SampleInfo info = 1; + + // Followed by at least one data response. + ChunkData data = 2; + + // True if this is the last message in the sequence. + bool end_of_sequence = 3; +} + +message ResetRequest { + // The table to reset. + string table = 1; +} + +message ResetResponse {} diff --git a/reverb/cc/replay_service_impl.cc b/reverb/cc/replay_service_impl.cc new file mode 100644 index 0000000..512aa3f --- /dev/null +++ b/reverb/cc/replay_service_impl.cc @@ -0,0 +1,290 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "reverb/cc/replay_service_impl.h" + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/str_cat.h" +#include "absl/time/time.h" +#include "reverb/cc/checkpointing/interface.h" +#include "reverb/cc/platform/logging.h" +#include "reverb/cc/replay_service.grpc.pb.h" +#include "reverb/cc/support/grpc_util.h" +#include "reverb/cc/support/uint128.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" + +namespace deepmind { +namespace reverb { +namespace { + +inline grpc::Status TableNotFound(absl::string_view name) { + return grpc::Status(grpc::StatusCode::NOT_FOUND, + absl::StrCat("Priority table ", name, " was not found")); +} + +inline grpc::Status Internal(const std::string& message) { + return grpc::Status(grpc::StatusCode::INTERNAL, message); +} + +} // namespace + +ReplayServiceImpl::ReplayServiceImpl( + std::vector> priority_tables, + std::shared_ptr checkpointer) + : checkpointer_(std::move(checkpointer)) { + if (checkpointer_ != nullptr) { + auto status = checkpointer_->LoadLatest(&chunk_store_, &priority_tables); + if (!tensorflow::errors::IsNotFound(status)) { + TF_CHECK_OK(status) << "Error when loading checkpoint: " + << status.ToString(); + } + } + + for (auto& priority_table : priority_tables) { + priority_tables_[priority_table->name()] = std::move(priority_table); + } + + tables_state_id_ = absl::MakeUint128(absl::Uniform(rnd_), + absl::Uniform(rnd_)); +} + +grpc::Status ReplayServiceImpl::Checkpoint(grpc::ServerContext* context, + const CheckpointRequest* request, + CheckpointResponse* response) { + if (checkpointer_ == nullptr) { + return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, + "no Checkpointer configured for the replay service."); + } + + std::vector tables; + for (auto& table : priority_tables_) { + tables.push_back(table.second.get()); + } + + auto status = checkpointer_->Save(std::move(tables), 1, + response->mutable_checkpoint_path()); + if (!status.ok()) return ToGrpcStatus(status); + + REVERB_LOG(REVERB_INFO) << "Stored checkpoint to " + << response->checkpoint_path(); + return grpc::Status::OK; +} + +grpc::Status ReplayServiceImpl::InsertStream( + grpc::ServerContext* context, + grpc::ServerReader* reader, + InsertStreamResponse* response) { + return InsertStreamInternal(context, reader, response); +} + +grpc::Status ReplayServiceImpl::InsertStreamInternal( + grpc::ServerContext* context, + grpc::ServerReaderInterface* reader, + InsertStreamResponse* response) { + absl::flat_hash_map> + chunks; + InsertStreamRequest request; + + while (reader->Read(&request)) { + if (request.has_chunk()) { + ChunkStore::Key key = request.chunk().chunk_key(); + std::shared_ptr chunk = + chunk_store_.Insert(std::move(*request.mutable_chunk())); + if (!chunk) { + return grpc::Status(grpc::StatusCode::CANCELLED, + "Service has been closed"); + } + chunks[key] = std::move(chunk); + } else if (request.has_item()) { + PriorityTable::Item item; + + auto push_or = [&chunks, &item](ChunkStore::Key key) -> grpc::Status { + auto it = chunks.find(key); + if (it == chunks.end()) { + return Internal( + absl::StrCat("Could not find sequence chunk ", key, ".")); + } + item.chunks.push_back(it->second); + return grpc::Status::OK; + }; + + for (ChunkStore::Key key : request.item().item().chunk_keys()) { + auto status = push_or(key); + if 
(!status.ok()) return status; + } + + const auto& table_name = request.item().item().table(); + PriorityTable* priority_table = PriorityTableByName(table_name); + if (priority_table == nullptr) return TableNotFound(table_name); + + item.item = *request.mutable_item()->mutable_item(); + + { + auto status = priority_table->InsertOrAssign(item); + if (!status.ok()) { + return ToGrpcStatus(status); + } + } + + // Only keep specified chunks. + absl::flat_hash_set keep_keys{ + request.item().keep_chunk_keys().begin(), + request.item().keep_chunk_keys().end()}; + for (auto it = chunks.cbegin(); it != chunks.cend();) { + if (keep_keys.find(it->first) == keep_keys.end()) { + chunks.erase(it++); + } else { + ++it; + } + } + REVERB_CHECK_EQ(chunks.size(), keep_keys.size()) + << "Kept less chunks than expected."; + } + } + + return grpc::Status::OK; +} + +grpc::Status ReplayServiceImpl::MutatePriorities( + grpc::ServerContext* context, const MutatePrioritiesRequest* request, + MutatePrioritiesResponse* response) { + PriorityTable* priority_table = PriorityTableByName(request->table()); + if (priority_table == nullptr) return TableNotFound(request->table()); + + auto status = priority_table->MutateItems( + std::vector(request->updates().begin(), + request->updates().end()), + request->delete_keys()); + if (!status.ok()) return ToGrpcStatus(status); + return grpc::Status::OK; +} + +grpc::Status ReplayServiceImpl::Reset(grpc::ServerContext* context, + const ResetRequest* request, + ResetResponse* response) { + PriorityTable* priority_table = PriorityTableByName(request->table()); + if (priority_table == nullptr) return TableNotFound(request->table()); + + auto status = priority_table->Reset(); + if (!status.ok()) { + return ToGrpcStatus(status); + } + return grpc::Status::OK; +} + +grpc::Status ReplayServiceImpl::SampleStream( + grpc::ServerContext* context, + grpc::ServerReaderWriter* + stream) { + return SampleStreamInternal(context, stream); +} + +grpc::Status ReplayServiceImpl::SampleStreamInternal( + grpc::ServerContext* context, + grpc::ServerReaderWriterInterface* stream) { + SampleStreamRequest request; + if (!stream->Read(&request)) { + return Internal("Could not read initial request"); + } + + do { + if (request.num_samples() <= 0) { + return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, + "`num_samples` must be > 0."); + } + PriorityTable* priority_table = PriorityTableByName(request.table()); + if (priority_table == nullptr) return TableNotFound(request.table()); + + int count = 0; + while (!context->IsCancelled() && count++ != request.num_samples()) { + PriorityTable::SampledItem sample; + { + + auto status = priority_table->Sample(&sample); + if (!status.ok()) { + return ToGrpcStatus(status); + } + } + + for (int i = 0; i < sample.chunks.size(); i++) { + SampleStreamResponse response; + response.set_end_of_sequence(i + 1 == sample.chunks.size()); + + // Attach the info to the first message. + if (i == 0) { + *response.mutable_info()->mutable_item() = sample.item; + response.mutable_info()->set_probability(sample.probability); + response.mutable_info()->set_table_size(sample.table_size); + } + + // We const cast to avoid copying the proto. + response.set_allocated_data( + const_cast(&sample.chunks[i]->data())); + + grpc::WriteOptions options; + options.set_no_compression(); // Data is already compressed. 
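+          // Note: `release_data()` below must run whether or not the write
+          // succeeds; otherwise `response` would delete chunk data owned by
+          // the ChunkStore when it goes out of scope.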
+ bool ok = stream->Write(response, options); + response.release_data(); + if (!ok) { + return Internal("Failed to write to Sample stream."); + } + + // We no longer need our chunk reference, so we free it. + sample.chunks[i] = nullptr; + } + } + + request.Clear(); + } while (stream->Read(&request)); + + return grpc::Status::OK; +} + +PriorityTable* ReplayServiceImpl::PriorityTableByName( + absl::string_view name) const { + auto it = priority_tables_.find(name); + if (it == priority_tables_.end()) return nullptr; + return it->second.get(); +} + +void ReplayServiceImpl::Close() { + for (auto& table : priority_tables_) { + table.second->Close(); + } +} + +grpc::Status ReplayServiceImpl::ServerInfo(grpc::ServerContext* context, + const ServerInfoRequest* request, + ServerInfoResponse* response) { + for (const auto& iter : priority_tables_) { + *response->add_table_info() = iter.second->info(); + } + *response->mutable_tables_state_id() = Uint128ToMessage(tables_state_id_); + return grpc::Status::OK; +} + +absl::flat_hash_map> +ReplayServiceImpl::tables() const { + return priority_tables_; +} + +} // namespace reverb +} // namespace deepmind diff --git a/reverb/cc/replay_service_impl.h b/reverb/cc/replay_service_impl.h new file mode 100644 index 0000000..ca5fee8 --- /dev/null +++ b/reverb/cc/replay_service_impl.h @@ -0,0 +1,109 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef REVERB_CC_REPLAY_SERVICE_IMPL_H_ +#define REVERB_CC_REPLAY_SERVICE_IMPL_H_ + +#include + +#include "grpcpp/grpcpp.h" +#include "absl/container/flat_hash_map.h" +#include "absl/numeric/int128.h" +#include "absl/random/random.h" +#include "absl/strings/string_view.h" +#include "reverb/cc/checkpointing/interface.h" +#include "reverb/cc/chunk_store.h" +#include "reverb/cc/priority_table.h" +#include "reverb/cc/replay_service.grpc.pb.h" +#include "reverb/cc/replay_service.pb.h" +#include "reverb/cc/schema.pb.h" + +namespace deepmind { +namespace reverb { + +// Implements ReplayService. See replay_service.proto for documentation. 
+class ReplayServiceImpl : public /* grpc_gen:: */ReplayService::Service { + public: + explicit ReplayServiceImpl( + std::vector> priority_tables, + std::shared_ptr checkpointer = nullptr); + + grpc::Status Checkpoint(grpc::ServerContext* context, + const CheckpointRequest* request, + CheckpointResponse* response) override; + + grpc::Status InsertStream(grpc::ServerContext* context, + grpc::ServerReader* reader, + InsertStreamResponse* response) override; + + grpc::Status InsertStreamInternal( + grpc::ServerContext* context, + grpc::ServerReaderInterface* reader, + InsertStreamResponse* response); + + grpc::Status MutatePriorities(grpc::ServerContext* context, + const MutatePrioritiesRequest* request, + MutatePrioritiesResponse* response) override; + + grpc::Status Reset(grpc::ServerContext* context, const ResetRequest* request, + ResetResponse* response) override; + + grpc::Status SampleStream( + grpc::ServerContext* context, + grpc::ServerReaderWriter* + stream) override; + + grpc::Status SampleStreamInternal( + grpc::ServerContext* context, + grpc::ServerReaderWriterInterface* stream); + + grpc::Status ServerInfo(grpc::ServerContext* context, + const ServerInfoRequest* request, + ServerInfoResponse* response) override; + + // Gets a copy of the table lookup. + absl::flat_hash_map> tables() + const; + + // Closes all priority tables and the chunk store. + void Close(); + + private: + // Lookups the priority table for a given name. Returns nullptr if not found. + PriorityTable* PriorityTableByName(absl::string_view name) const; + + // Checkpointer used to restore state in the constructor and to save data + // when `Checkpoint` is called. Note that if `checkpointer_` is nullptr then + // `Checkpoint` will return an `InvalidArgumentError`. + std::shared_ptr checkpointer_; + + // Stores chunks and keeps references to them. + ChunkStore chunk_store_; + + // Priority tables. Must be destroyed after `chunk_store_`. + absl::flat_hash_map> + priority_tables_; + + absl::BitGen rnd_; + + // A new id must be generated whenever a table is added, deleted, or has its + // signature modified. + absl::uint128 tables_state_id_; +}; + +} // namespace reverb +} // namespace deepmind + +#endif // REVERB_CC_REPLAY_SERVICE_IMPL_H_ diff --git a/reverb/cc/replay_service_impl_test.cc b/reverb/cc/replay_service_impl_test.cc new file mode 100644 index 0000000..7325004 --- /dev/null +++ b/reverb/cc/replay_service_impl_test.cc @@ -0,0 +1,420 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
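
For orientation before the service tests: a minimal sketch of how the ReplayServiceImpl declared above might be hosted behind a gRPC server. This is not part of the change; the listening address and the MakeTables() helper are illustrative assumptions, and constructing the PriorityTable instances is elided (MakeService() in the tests below shows one concrete way to build one).

#include <memory>
#include <vector>

#include "grpcpp/grpcpp.h"
#include "reverb/cc/priority_table.h"
#include "reverb/cc/replay_service_impl.h"

// Hypothetical helper that builds the tables the server should host; see
// MakeService() in the tests below for one way to construct a PriorityTable.
std::vector<std::shared_ptr<deepmind::reverb::PriorityTable>> MakeTables();

void RunReplayServer() {
  deepmind::reverb::ReplayServiceImpl service(MakeTables(),
                                              /*checkpointer=*/nullptr);

  grpc::ServerBuilder builder;
  builder.AddListeningPort("[::]:8000", grpc::InsecureServerCredentials());
  builder.RegisterService(&service);
  std::unique_ptr<grpc::Server> server = builder.BuildAndStart();
  server->Wait();  // Blocks until the server is shut down.
}
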
+ +#include "reverb/cc/replay_service_impl.h" + +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/memory/memory.h" +#include "absl/synchronization/notification.h" +#include "absl/types/optional.h" +#include "reverb/cc/distributions/fifo.h" +#include "reverb/cc/distributions/uniform.h" +#include "reverb/cc/platform/checkpointing.h" +#include "reverb/cc/platform/thread.h" +#include "reverb/cc/replay_service.pb.h" +#include "reverb/cc/schema.pb.h" +#include "reverb/cc/testing/proto_test_util.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/protobuf/struct.pb.h" + +namespace deepmind { +namespace reverb { +namespace { + +const int64_t kMinSizeToSample = 1; +const double kSamplesPerInsert = 1.0; +const double kMinDiff = -DBL_MAX; +const double kMaxDiff = DBL_MAX; + +int64_t nextId = 1; + +class FakeInsertStreamReader + : public grpc::ServerReaderInterface { + public: + void AddChunk(int64_t key) { + InsertStreamRequest request; + request.mutable_chunk()->set_chunk_key(key); + read_buffer_.push_back(std::move(request)); + } + + PrioritizedItem AddItem(absl::string_view table, + const std::vector& sequence_chunks, + const std::vector& keep_chunks = {}) { + PrioritizedItem item; + item.set_key(nextId++); + item.set_table(table.data(), table.size()); + *item.mutable_chunk_keys() = {sequence_chunks.begin(), + sequence_chunks.end()}; + if (!sequence_chunks.empty()) { + item.mutable_sequence_range()->set_offset(0); + item.mutable_sequence_range()->set_length(100); + } + + InsertStreamRequest request; + *request.mutable_item()->mutable_keep_chunk_keys() = {keep_chunks.begin(), + keep_chunks.end()}; + *request.mutable_item()->mutable_item() = item; + read_buffer_.push_back(std::move(request)); + return item; + } + + bool Read(InsertStreamRequest* request) override { + if (read_buffer_.empty()) return false; + *request = read_buffer_.front(); + read_buffer_.pop_front(); + return true; + } + + void SendInitialMetadata() override {} + bool NextMessageSize(uint32_t*) override { return false; } + + private: + std::list read_buffer_; +}; + +class FakeSampleStream + : public grpc::ServerReaderWriterInterface { + public: + explicit FakeSampleStream() = default; + + const std::vector& responses() { return buffer_; } + const grpc::WriteOptions last_options() { return options_; } + + bool Write(const SampleStreamResponse& response, + grpc::WriteOptions options) override { + buffer_.push_back(response); + options_ = options; + return true; + } + + bool Read(SampleStreamRequest* request) override { + if (requests_.empty()) return false; + request->set_table(requests_.front().table()); + request->set_num_samples(requests_.front().num_samples()); + requests_.pop_front(); + return true; + } + + bool NextMessageSize(uint32_t* sz) override { + if (!requests_.empty()) *sz = requests_.front().ByteSizeLong(); + return !requests_.empty(); + } + + void AddRequest(std::string table, int num_samples) { + SampleStreamRequest request; + request.set_table(std::move(table)); + request.set_num_samples(num_samples); + requests_.push_back(std::move(request)); + } + + void SendInitialMetadata() override {} + + private: + std::list requests_; + std::vector buffer_; + grpc::WriteOptions options_; +}; + +tensorflow::StructuredValue MakeSignature() { + tensorflow::StructuredValue signature; + auto* tensor_spec = signature.mutable_tensor_spec_value(); + tensor_spec->set_name("item0"); + 
tensorflow::TensorShape().AsProto(tensor_spec->mutable_shape()); + tensor_spec->set_dtype(tensorflow::DT_INT32); + return signature; +} + +std::unique_ptr MakeService( + int max_size, std::unique_ptr checkpointer) { + std::vector> tables; + + tables.push_back(absl::make_unique( + "dist", absl::make_unique(), + absl::make_unique(), max_size, 0, + absl::make_unique(kSamplesPerInsert, kMinSizeToSample, + kMinDiff, kMaxDiff), + /*extensions=*/ + std::vector>{}, + /*signature=*/absl::make_optional(MakeSignature()))); + return absl::make_unique(std::move(tables), + std::move(checkpointer)); +} + +std::unique_ptr MakeService(int max_size) { + return MakeService(max_size, nullptr); +} + +TEST(ReplayServiceImplTest, SampleAfterInsertWorks) { + std::unique_ptr service = MakeService(10); + + FakeInsertStreamReader reader; + reader.AddChunk(1); + reader.AddChunk(2); + reader.AddChunk(3); + PrioritizedItem item = reader.AddItem("dist", {2, 3}); + ASSERT_TRUE(service->InsertStreamInternal(nullptr, &reader, nullptr).ok()); + + for (int i = 0; i < 5; i++) { + FakeSampleStream stream; + stream.AddRequest("dist", 1); + + grpc::ServerContext context; + ASSERT_TRUE(service->SampleStreamInternal(&context, &stream).ok()); + ASSERT_EQ(stream.responses().size(), 2); + + item.set_times_sampled(i + 1); + + SampleInfo info = stream.responses()[0].info(); + info.mutable_item()->clear_inserted_at(); + EXPECT_THAT(info.item(), testing::EqualsProto(item)); + EXPECT_EQ(info.probability(), 1); + EXPECT_EQ(info.table_size(), 1); + + EXPECT_EQ(stream.responses()[0].data().chunk_key(), item.chunk_keys(0)); + EXPECT_EQ(stream.responses()[1].data().chunk_key(), item.chunk_keys(1)); + EXPECT_TRUE(stream.last_options().get_no_compression()); + } +} + +TEST(ReplayServiceImplTest, InsertChunksWithoutItemWorks) { + std::unique_ptr service = MakeService(10); + grpc::ServerContext context; + + FakeInsertStreamReader reader; + reader.AddChunk(1); + reader.AddChunk(2); + EXPECT_OK(service->InsertStreamInternal(&context, &reader, nullptr)); +} + +TEST(ReplayServiceImplTest, InsertSameChunkTwiceWorks) { + std::unique_ptr service = MakeService(10); + grpc::ServerContext context; + + FakeInsertStreamReader reader; + reader.AddChunk(1); + reader.AddChunk(1); + EXPECT_OK(service->InsertStreamInternal(&context, &reader, nullptr)); +} + +TEST(ReplayServiceImplTest, InsertItemWithoutKeptChunkFails) { + std::unique_ptr service = MakeService(10); + grpc::ServerContext context; + + FakeInsertStreamReader reader; + reader.AddChunk(1); + reader.AddChunk(2); + reader.AddItem("dist", {1, 2}); + reader.AddItem("dist", {2, 3}); + EXPECT_EQ( + service->InsertStreamInternal(&context, &reader, nullptr).error_code(), + grpc::StatusCode::INTERNAL); +} + +TEST(ReplayServiceImplTest, InsertItemWithKeptChunkWorks) { + std::unique_ptr service = MakeService(10); + grpc::ServerContext context; + + FakeInsertStreamReader reader; + reader.AddChunk(1); + reader.AddChunk(2); + reader.AddItem("dist", {1, 2}, {2}); + reader.AddItem("dist", {2, 3}); + EXPECT_EQ( + service->InsertStreamInternal(&context, &reader, nullptr).error_code(), + grpc::StatusCode::INTERNAL); +} + +TEST(ReplayServiceImplTest, InsertItemWithMissingChunksFails) { + std::unique_ptr service = MakeService(10); + grpc::ServerContext context; + + FakeInsertStreamReader reader; + reader.AddChunk(1); + reader.AddItem("dist", {2}); + EXPECT_EQ( + service->InsertStreamInternal(&context, &reader, nullptr).error_code(), + grpc::StatusCode::INTERNAL); +} + +TEST(ReplayServiceImplTest, 
SampleBlocksUntilEnoughInserts) { + std::unique_ptr service = MakeService(10); + absl::Notification notification; + auto thread = internal::StartThread("", [&] { + FakeSampleStream stream; + stream.AddRequest("dist", 1); + grpc::ServerContext context; + EXPECT_OK(service->SampleStreamInternal(&context, &stream)); + notification.Notify(); + }); + + // Blocking because there are no data to sample. + EXPECT_FALSE(notification.HasBeenNotified()); + + // Insert an item. + FakeInsertStreamReader reader; + reader.AddChunk(1); + reader.AddItem("dist", {1}); + ASSERT_TRUE(service->InsertStreamInternal(nullptr, &reader, nullptr).ok()); + + // The sample should now complete because there is data to sample. + notification.WaitForNotification(); + + thread = nullptr; // Joins the thread. +} + +TEST(ReplayServiceImplTest, MutateDeletionWorks) { + std::unique_ptr service = MakeService(10); + + FakeInsertStreamReader reader; + reader.AddChunk(1); + PrioritizedItem item = reader.AddItem("dist", {1}); + ASSERT_TRUE(service->InsertStreamInternal(nullptr, &reader, nullptr).ok()); + + EXPECT_EQ(service->tables()["dist"]->size(), 1); + + MutatePrioritiesRequest mutate_request; + mutate_request.set_table("dist"); + mutate_request.add_delete_keys(item.key()); + EXPECT_OK(service->MutatePriorities(nullptr, &mutate_request, nullptr)); + + EXPECT_EQ(service->tables()["dist"]->size(), 0); +} + +TEST(ReplayServiceImplTest, AnyCallWithInvalidDistributionFails) { + std::unique_ptr service = MakeService(10); + grpc::ServerContext context; + + FakeSampleStream sample_stream; + sample_stream.AddRequest("invalid", 1); + EXPECT_EQ( + service->SampleStreamInternal(&context, &sample_stream).error_code(), + grpc::StatusCode::NOT_FOUND); + + MutatePrioritiesRequest mutate_request; + mutate_request.set_table("invalid"); + EXPECT_EQ( + service->MutatePriorities(nullptr, &mutate_request, nullptr).error_code(), + grpc::StatusCode::NOT_FOUND); + + FakeInsertStreamReader reader; + reader.AddChunk(1); + reader.AddItem("invalid", {1}); + EXPECT_EQ( + service->InsertStreamInternal(nullptr, &reader, nullptr).error_code(), + grpc::StatusCode::NOT_FOUND); +} + +TEST(ReplayServiceImplTest, ResetWorks) { + std::unique_ptr service = MakeService(10); + + FakeInsertStreamReader reader; + reader.AddChunk(1); + PrioritizedItem item = reader.AddItem("dist", {1}); + ASSERT_TRUE(service->InsertStreamInternal(nullptr, &reader, nullptr).ok()); + + EXPECT_EQ(service->tables()["dist"]->size(), 1); + + ResetRequest reset_request; + reset_request.set_table("dist"); + ResetResponse reset_response; + ASSERT_TRUE(service->Reset(nullptr, &reset_request, &reset_response).ok()); + + EXPECT_EQ(service->tables()["dist"]->size(), 0); +} + +TEST(ReplayServiceImplTest, ServerInfoWorks) { + auto service = MakeService(10); + + ServerInfoRequest server_info_request; + ServerInfoResponse server_info_response; + ASSERT_TRUE( + service->ServerInfo(nullptr, &server_info_request, &server_info_response) + .ok()); + + // The probability of these being 0 is 2^{-128} + EXPECT_NE(std::make_pair(server_info_response.tables_state_id().low(), + server_info_response.tables_state_id().high()), + std::make_pair(uint64_t{0}, uint64_t{0})); + + EXPECT_EQ(server_info_response.table_info_size(), 1); + const auto& table_info = server_info_response.table_info()[0]; + + TableInfo expected_table_info; + expected_table_info.set_name("dist"); + expected_table_info.mutable_sampler_options()->set_uniform(true); + expected_table_info.mutable_remover_options()->set_fifo(true); + 
expected_table_info.set_max_size(10); + expected_table_info.set_current_size(0); + auto rate_limiter = expected_table_info.mutable_rate_limiter_info(); + rate_limiter->set_samples_per_insert(kSamplesPerInsert); + rate_limiter->set_min_size_to_sample(kMinSizeToSample); + rate_limiter->set_min_diff(kMinDiff); + rate_limiter->set_max_diff(kMaxDiff); + *expected_table_info.mutable_signature() = MakeSignature(); + + EXPECT_THAT(table_info, testing::EqualsProto(expected_table_info)); +} + +TEST(ReplayServiceImplTest, CheckpointCalledWithoutCheckpointer) { + auto service = MakeService(10); + CheckpointRequest request; + CheckpointResponse response; + + EXPECT_EQ(service->Checkpoint(nullptr, &request, &response).error_code(), + grpc::StatusCode::INVALID_ARGUMENT); +} + +TEST(ReplayServiceImplTest, CheckpointAndLoadFromCheckpoint) { + std::string path = getenv("TEST_TMPDIR"); + REVERB_CHECK(tensorflow::Env::Default()->CreateUniqueFileName(&path, "temp")); + auto service = MakeService(10, CreateDefaultCheckpointer(path)); + + // Check that there are no items in the service to begin with. + EXPECT_EQ(service->tables()["dist"]->size(), 0); + + // Insert an item. + { + FakeInsertStreamReader reader; + reader.AddChunk(1); + reader.AddItem("dist", {1}); + ASSERT_TRUE(service->InsertStreamInternal(nullptr, &reader, nullptr).ok()); + } + + EXPECT_EQ(service->tables()["dist"]->size(), 1); + + // Checkpoint the service. + { + CheckpointRequest request; + CheckpointResponse response; + grpc::ServerContext context; + EXPECT_OK(service->Checkpoint(nullptr, &request, &response)); + } + + // Create a new service from the checkpoint and check that it has the correct + // number of items. + auto loaded_service = MakeService(10, CreateDefaultCheckpointer(path)); + EXPECT_EQ(loaded_service->tables()["dist"]->size(), 1); +} + +} // namespace +} // namespace reverb +} // namespace deepmind diff --git a/reverb/cc/replay_writer.cc b/reverb/cc/replay_writer.cc new file mode 100644 index 0000000..56fd9f8 --- /dev/null +++ b/reverb/cc/replay_writer.cc @@ -0,0 +1,379 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
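
The ReplayWriter implementation that follows is the client side of the InsertStream protocol documented in replay_service.proto above. As a rough illustration of the documented message ordering, chunks first and then an item that references them and names which chunks the server should keep, here is a hedged sketch; the stub type is assumed to be the generated ReplayService::StubInterface, and MakeChunk()/MakeItem() are hypothetical helpers.

#include <cstdint>
#include <string>
#include <vector>

#include "grpcpp/grpcpp.h"
#include "reverb/cc/replay_service.grpc.pb.h"
#include "reverb/cc/replay_service.pb.h"

// Hypothetical helpers that fill in the ChunkData / PrioritizedItem protos.
deepmind::reverb::ChunkData MakeChunk(uint64_t key);
deepmind::reverb::PrioritizedItem MakeItem(
    const std::string& table, const std::vector<uint64_t>& chunk_keys);

grpc::Status WriteOneItem(
    deepmind::reverb::ReplayService::StubInterface* stub) {
  grpc::ClientContext context;
  deepmind::reverb::InsertStreamResponse response;
  auto stream = stub->InsertStream(&context, &response);

  // [CHUNK C1] [CHUNK C2]
  for (uint64_t key : {1, 2}) {
    deepmind::reverb::InsertStreamRequest request;
    *request.mutable_chunk() = MakeChunk(key);
    if (!stream->Write(request)) return stream->Finish();
  }

  // [ITEM USING C1&C2 AND KEEP C2]: once the item is inserted the server drops
  // its reference to C1, so later items may only reference C2 or new chunks.
  deepmind::reverb::InsertStreamRequest request;
  *request.mutable_item()->mutable_item() = MakeItem("my_table", {1, 2});
  request.mutable_item()->add_keep_chunk_keys(2);
  stream->Write(request);

  stream->WritesDone();
  return stream->Finish();
}
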
+ +#include "reverb/cc/replay_writer.h" + +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/strings/str_cat.h" +#include "reverb/cc/platform/logging.h" +#include "reverb/cc/schema.pb.h" +#include "reverb/cc/support/grpc_util.h" +#include "reverb/cc/tensor_compression.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_util.h" +#include "tensorflow/core/lib/core/status.h" + +namespace deepmind { +namespace reverb { +namespace { + +bool IsTransientError(const tensorflow::Status& status) { + return tensorflow::errors::IsDeadlineExceeded(status) || + tensorflow::errors::IsUnavailable(status) || + tensorflow::errors::IsCancelled(status); +} + +int PositiveModulo(int value, int divisor) { + if (divisor == 0) return value; + + if ((value > 0) == (divisor > 0)) { + // value and divisor have the same sign. + return value % divisor; + } else { + // value and divisor have different signs. + return (value % divisor + divisor) % divisor; + } +} + +} // namespace + +ReplayWriter::ReplayWriter( + std::shared_ptr stub, + int chunk_length, int max_timesteps, bool delta_encoded, + std::shared_ptr signatures) + : stub_(std::move(stub)), + chunk_length_(chunk_length), + max_timesteps_(max_timesteps), + delta_encoded_(delta_encoded), + signatures_(std::move(signatures)), + next_chunk_key_(NewID()), + episode_id_(NewID()), + index_within_episode_(0), + closed_(false), + inserted_dtypes_and_shapes_(max_timesteps) {} + +ReplayWriter::~ReplayWriter() { + if (!closed_) Close().IgnoreError(); +} + +tensorflow::Status ReplayWriter::AppendTimestep( + std::vector data) { + if (closed_) { + return tensorflow::errors::FailedPrecondition( + "Calling method AppendTimestep after Close has been called"); + } + if (!buffer_.empty() && buffer_.front().size() != data.size()) { + return tensorflow::errors::InvalidArgument( + "Number of tensors per timestep was inconsistent. Previously ", + buffer_.front().size(), " now ", data.size()); + } + + // Store flattened signature into inserted_dtypes_and_shapes_ + internal::DtypesAndShapes dtypes_and_shapes_t(0); + dtypes_and_shapes_t->reserve(data.size()); + for (const auto& t : data) { + dtypes_and_shapes_t->push_back( + {t.dtype(), tensorflow::PartialTensorShape(t.shape())}); + } + std::swap(dtypes_and_shapes_t, + inserted_dtypes_and_shapes_[insert_dtypes_and_shapes_location_]); + insert_dtypes_and_shapes_location_ = + (insert_dtypes_and_shapes_location_ + 1) % max_timesteps_; + + buffer_.push_back(std::move(data)); + if (buffer_.size() < chunk_length_) return tensorflow::Status::OK(); + + auto status = Finish(); + if (!status.ok()) { + // Undo adding stuff to the buffer and undo the dtypes_and_shapes_ changes. 
+ buffer_.pop_back(); + insert_dtypes_and_shapes_location_ = + PositiveModulo(insert_dtypes_and_shapes_location_ - 1, max_timesteps_); + std::swap(dtypes_and_shapes_t, + inserted_dtypes_and_shapes_[insert_dtypes_and_shapes_location_]); + } + return status; +} + +tensorflow::Status ReplayWriter::AddPriority(const std::string& table, + int num_timesteps, + double priority) { + if (closed_) { + return tensorflow::errors::FailedPrecondition( + "Calling method AddPriority after Close has been called"); + } + if (num_timesteps > chunks_.size() * chunk_length_ + buffer_.size()) { + return tensorflow::errors::InvalidArgument( + "Argument `num_timesteps` is larger than number of buffered " + "timesteps."); + } + if (num_timesteps > max_timesteps_) { + return tensorflow::errors::InvalidArgument( + "`num_timesteps` must be <= `max_timesteps`"); + } + + const internal::DtypesAndShapes* dtypes_and_shapes = nullptr; + TF_RETURN_IF_ERROR(GetFlatSignature(table, &dtypes_and_shapes)); + CHECK(dtypes_and_shapes != nullptr); + if (dtypes_and_shapes->has_value()) { + for (int t = 0; t < num_timesteps; ++t) { + // Subtract 1 from the location since it is currently pointing to the next + // write. + const int check_offset = PositiveModulo( + insert_dtypes_and_shapes_location_ - 1 - t, max_timesteps_); + const auto& dtypes_and_shapes_t = + inserted_dtypes_and_shapes_[check_offset]; + REVERB_CHECK(dtypes_and_shapes_t.has_value()) + << "Unexpected missing dtypes and shapes while calling AddPriority: " + "expected a value at index " + << check_offset << " (timestep offset " << t << ")"; + + if (dtypes_and_shapes_t->size() != (*dtypes_and_shapes)->size()) { + return tensorflow::errors::InvalidArgument( + "Unable to AddPriority to table ", table, + " because AppendTimestep was called with a tensor signature " + "inconsistent with table signature. AppendTimestep for timestep " + "offset ", + t, " was called with ", dtypes_and_shapes_t->size(), + " tensors, but table requires ", (*dtypes_and_shapes)->size(), + " tensors per entry. Table signature: ", + internal::DtypesShapesString(**dtypes_and_shapes)); + } + + for (int c = 0; c < dtypes_and_shapes_t->size(); ++c) { + const auto& signature_dtype_and_shape = (**dtypes_and_shapes)[c]; + const auto& seen_dtype_and_shape = (*dtypes_and_shapes_t)[c]; + if (seen_dtype_and_shape.dtype != signature_dtype_and_shape.dtype || + !signature_dtype_and_shape.shape.IsCompatibleWith( + seen_dtype_and_shape.shape)) { + return tensorflow::errors::InvalidArgument( + "Unable to AddPriority to table ", table, + " because AppendTimestep was called with a tensor signature " + "inconsistent with table signature. Saw a tensor at " + "timestep offset ", + t, " in (flattened) tensor location ", c, " with dtype ", + DataTypeString(seen_dtype_and_shape.dtype), " and shape ", + seen_dtype_and_shape.shape.DebugString(), + " but expected a tensor of dtype ", + DataTypeString(signature_dtype_and_shape.dtype), + " and shape compatible with ", + signature_dtype_and_shape.shape.DebugString(), + ". (Flattened) table signature: ", + internal::DtypesShapesString(**dtypes_and_shapes)); + } + } + } + } + + int remaining = num_timesteps - buffer_.size(); + int num_chunks = + remaining / chunk_length_ + (remaining % chunk_length_ ? 1 : 0); + + // Don't use additional chunks if the entire episode is contained in the + // current buffer. 
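+  // Worked example (constants as in replay_writer_test.cc below): with
+  // `chunk_length_` == 2, an empty buffer and `num_timesteps` == 3, remaining
+  // is 3, so `num_chunks` == 2 and the offset computed below is
+  // (2 - 3 % 2) % 2 == 1.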
+ if (remaining < 0) { + num_chunks = 0; + } + + PrioritizedItem item; + item.set_key(NewID()); + item.set_table(table.data(), table.size()); + item.set_priority(priority); + item.mutable_sequence_range()->set_length(num_timesteps); + item.mutable_sequence_range()->set_offset( + (chunk_length_ - (remaining % chunk_length_)) % chunk_length_); + + for (auto it = std::next(chunks_.begin(), chunks_.size() - num_chunks); + it != chunks_.end(); it++) { + item.add_chunk_keys(it->chunk_key()); + } + if (!buffer_.empty()) { + item.add_chunk_keys(next_chunk_key_); + } + + pending_items_.push_back(item); + + if (buffer_.empty()) { + auto status = WriteWithRetries(); + if (!status.ok()) pending_items_.pop_back(); + return status; + } + + return tensorflow::Status::OK(); +} + +tensorflow::Status ReplayWriter::Close() { + if (closed_) + return tensorflow::errors::FailedPrecondition( + "Calling method Close after Close has been called"); + if (!pending_items_.empty()) { + TF_RETURN_IF_ERROR(Finish()); + } + if (stream_) { + stream_->WritesDone(); + auto status = stream_->Finish(); + if (!status.ok()) { + REVERB_LOG(REVERB_INFO) << "Received error when closing the stream: " + << FormatGrpcStatus(status); + } + stream_ = nullptr; + } + chunks_.clear(); + closed_ = true; + return tensorflow::Status::OK(); +} + +tensorflow::Status ReplayWriter::Finish() { + SequenceRange sequence; + std::vector batched_tensors; + for (int i = 0; i < buffer_[0].size(); ++i) { + std::vector tensors(buffer_.size()); + for (int j = 0; j < buffer_.size(); ++j) { + const tensorflow::Tensor& item = buffer_[j][i]; + tensorflow::TensorShape shape = item.shape(); + shape.InsertDim(0, 1); + REVERB_CHECK(tensors[j].CopyFrom(item, shape)); + } + batched_tensors.emplace_back(); + TF_RETURN_IF_ERROR( + tensorflow::tensor::Concat(tensors, &batched_tensors.back())); + } + + ChunkData chunk; + chunk.set_chunk_key(next_chunk_key_); + chunk.mutable_sequence_range()->set_episode_id(episode_id_); + chunk.mutable_sequence_range()->set_start(index_within_episode_); + chunk.mutable_sequence_range()->set_end(index_within_episode_ + + buffer_.size() - 1); + + if (delta_encoded_) { + batched_tensors = DeltaEncodeList(batched_tensors, true); + chunk.set_delta_encoded(true); + } + + for (const auto& tensor : batched_tensors) { + CompressTensorAsProto(tensor, chunk.add_data()); + } + + chunks_.push_back(std::move(chunk)); + + auto status = WriteWithRetries(); + if (status.ok()) { + index_within_episode_ += buffer_.size(); + buffer_.clear(); + next_chunk_key_ = NewID(); + while ((chunks_.size() - 1) * chunk_length_ >= max_timesteps_) { + streamed_chunk_keys_.erase(chunks_.front().chunk_key()); + chunks_.pop_front(); + } + } else { + chunks_.pop_back(); + } + return status; +} + +tensorflow::Status ReplayWriter::WriteWithRetries() { + tensorflow::Status status; + while (true) { + if (WritePendingData()) return tensorflow::Status::OK(); + stream_->WritesDone(); + status = FromGrpcStatus(stream_->Finish()); + stream_ = nullptr; + if (!IsTransientError(status)) break; + } + return status; +} + +bool ReplayWriter::WritePendingData() { + if (!stream_) { + streamed_chunk_keys_.clear(); + context_ = absl::make_unique(); + stream_ = stub_->InsertStream(context_.get(), &response_); + } + + // Stream all chunks which are referenced by the pending items and haven't + // already been sent. After the items has been inserted we want the server + // to keep references only to the ones which the client still keeps + // around. 
+ absl::flat_hash_set item_chunk_keys; + for (const auto& item : pending_items_) { + for (uint64_t key : item.chunk_keys()) { + item_chunk_keys.insert(key); + } + } + std::vector keep_chunk_keys; + for (const ChunkData& chunk : chunks_) { + if (item_chunk_keys.contains(chunk.chunk_key()) && + !streamed_chunk_keys_.contains(chunk.chunk_key())) { + InsertStreamRequest request; + request.set_allocated_chunk(const_cast(&chunk)); + grpc::WriteOptions options; + options.set_no_compression(); + bool ok = stream_->Write(request, options); + request.release_chunk(); + if (!ok) return false; + streamed_chunk_keys_.insert(chunk.chunk_key()); + } + if (streamed_chunk_keys_.contains(chunk.chunk_key())) { + keep_chunk_keys.push_back(chunk.chunk_key()); + } + } + while (!pending_items_.empty()) { + InsertStreamRequest request; + *request.mutable_item()->mutable_item() = pending_items_.front(); + *request.mutable_item()->mutable_keep_chunk_keys() = { + keep_chunk_keys.begin(), keep_chunk_keys.end()}; + if (!stream_->Write(request)) return false; + pending_items_.pop_front(); + } + + return true; +} + +uint64_t ReplayWriter::NewID() { + return absl::Uniform(bit_gen_, 0, UINT64_MAX); +} + +tensorflow::Status ReplayWriter::GetFlatSignature( + absl::string_view table, + const internal::DtypesAndShapes** dtypes_and_shapes) const { + static const auto* empty_dtypes_and_shapes = + new internal::DtypesAndShapes(absl::nullopt); + if (!signatures_) { + // No signatures available, return an unknown set. + *dtypes_and_shapes = empty_dtypes_and_shapes; + return tensorflow::Status::OK(); + } + auto iter = signatures_->find(table); + if (iter == signatures_->end()) { + std::vector table_names; + for (const auto& table : *signatures_) { + table_names.push_back(absl::StrCat("'", table.first, "'")); + } + return tensorflow::errors::InvalidArgument( + "Unable to find signatures for table '", table, + "' in signature cache. Available tables: [", + absl::StrJoin(table_names, ", "), "]."); + } + *dtypes_and_shapes = &(iter->second); + return tensorflow::Status::OK(); +} + +} // namespace reverb +} // namespace deepmind diff --git a/reverb/cc/replay_writer.h b/reverb/cc/replay_writer.h new file mode 100644 index 0000000..57d59d2 --- /dev/null +++ b/reverb/cc/replay_writer.h @@ -0,0 +1,168 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
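
Before the writer's header, a standalone sketch of the batching step performed in ReplayWriter::Finish() above: each buffered tensor is given a leading dimension of size 1 and the copies are concatenated along that axis. The function name is illustrative and the two scalar floats stand in for a real timestep buffer.

#include <vector>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"

tensorflow::Status BatchTwoScalars(tensorflow::Tensor* batched) {
  std::vector<tensorflow::Tensor> tensors(2);
  for (int j = 0; j < 2; ++j) {
    tensorflow::Tensor item(1.0f * j);       // Scalar with shape [].
    tensorflow::TensorShape shape = item.shape();
    shape.InsertDim(0, 1);                   // Now shape [1].
    if (!tensors[j].CopyFrom(item, shape)) {
      return tensorflow::errors::Internal("CopyFrom failed");
    }
  }
  // `batched` ends up with shape [2]: one row per buffered timestep.
  return tensorflow::tensor::Concat(tensors, batched);
}
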
+ +#ifndef LEARNING_DEEPMIND_REPLAY_REVERB_REPLAY_WRITER_H_ +#define LEARNING_DEEPMIND_REPLAY_REVERB_REPLAY_WRITER_H_ + +#include +#include +#include + +#include "grpcpp/impl/codegen/client_context.h" +#include "grpcpp/impl/codegen/sync_stream.h" +#include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/random/random.h" +#include "absl/strings/string_view.h" +#include "reverb/cc/replay_service.grpc.pb.h" +#include "reverb/cc/replay_service.pb.h" +#include "reverb/cc/schema.pb.h" +#include "reverb/cc/support/signature.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/hash/hash.h" + +namespace deepmind { +namespace reverb { + +// None of the methods are thread safe. +class ReplayWriter { + public: + // The client must not be deleted while any of its writer instances exist. + ReplayWriter( + std::shared_ptr stub, + int chunk_length, int max_timesteps, bool delta_encoded = false, + std::shared_ptr signatures = nullptr); + ~ReplayWriter(); + + // Appends a timestamp to internal `buffer_`. If the size of the buffer + // reached `chunk_length_` then its content is batched and inserted into + // `chunks_`. If `pending_items_` is not empty then its items are streamed to + // the ReplayService and popped. + // + // If all operations are successful then `buffer_` is cleared, a new + // `next_chunk_key_` is set and old items are removed from `chunks_` until its + // size is <= `max_chunks_`. If unsuccessful all internal state is reverted. + tensorflow::Status AppendTimestep(std::vector data); + + // Adds a new PrioritizedItem to `table` spanning the last `num_timesteps` and + // pushes new item to `pending_items_`. If `buffer_` is empty then the new + // item is streamed to the ReplayService. If unsuccessful all internal state + // is reverted. + tensorflow::Status AddPriority(const std::string& table, int num_timesteps, + double priority); + + // TODO(b/154929199): There should probably be a method for ending an episode + // even if you don't want to close the stream. + + // Creates a new batch from the content of `buffer_` and writes all + // `pending_items_` and closes the stream_. The object must be abandoned after + // calling this method. + tensorflow::Status Close(); + + private: + // Creates a new batch from the content of `buffer_` and inserts it into + // `chunks_`. If `pending_items_` is not empty then the items are streamed to + // the ReplayService and popped. + // + // If all operations are successful then `buffer_` is cleared, a new + // `next_chunk_key_` is set, `index_within_episode_` is incremented by the + // number of items in `buffer_` and old items are removed from `chunks_` until + // its size is <= `max_chunks_`. If the operation was unsuccessful then chunk + // is popped from `chunks_`. + tensorflow::Status Finish(); + + // Retries `WritePendingData` for at most `kMaxRetries` times. + tensorflow::Status WriteWithRetries(); + + // Streams the chunks in `chunks_` referenced by `pending_items_` followed by + // items in `pending_items_` + bool WritePendingData(); + + // Helper for generating a random ID. + uint64_t NewID(); + + // gRPC stub for the ReplayService. + std::shared_ptr stub_; + + // gRPC stream to the ReplayService.InsertStream endpoint. + std::unique_ptr> stream_; + std::unique_ptr context_; + InsertStreamResponse response_; + + // The number of timesteps to batch in each chunk. 
+ const int chunk_length_; + + // The maximum number of recent timesteps which new items can reference. + const int max_timesteps_; + + // Whether chunks should be delta encoded before compressed. + const bool delta_encoded_; + + // Cache mapping table name to cached flattened signature. + std::shared_ptr signatures_; + + // Bit generator used by `NewID`. + absl::BitGen bit_gen_; + + // PriorityItems waiting to be sent to the ReplayService. Items are appended + // to the list when they reference timesteps in `buffer_`. Once `buffer_` has + // size `chunk_length_` the content is chunked and the pending items are + // written to the ReplayService. While `buffer_` is empty new items are + // written to the ReplayService immediately. + std::list pending_items_; + + // Timesteps not yet batched up and put into `chunks_`. + std::vector> buffer_; + + // Batched timesteps that can be referenced by new items. + std::list chunks_; + + // Keys of the chunks which have been streamed to the server. + absl::flat_hash_set streamed_chunk_keys_; + + // The key used to reference the items currently in `buffer_`. + uint64_t next_chunk_key_; + + // The episode id to attach to inserted timesteps. + uint64_t episode_id_; + + // Index of the first timestep in `buffer_`. + int32_t index_within_episode_; + + // Set if `Close` has been called. + bool closed_; + + // Set of signatures passed to AppendTimestep in a circular buffer. Each + // entry is the flat list of tensor dtypes and shapes in past AppendTimestep + // calls. The vector itself is of length max_time_steps_ and AppendTimestep + // updates the DtypesAndShapes at index append_dtypes_and_shapes_location_. + std::vector inserted_dtypes_and_shapes_; + int insert_dtypes_and_shapes_location_ = 0; + + // Get a pointer to the DtypesAndShapes flattened signature for `table`. + // Returns a nullopt signature if no signatures were provided to the Writer on + // initialization. Raises an InvalidArgument if the table is not in + // signatures_. + tensorflow::Status GetFlatSignature( + absl::string_view table, + const internal::DtypesAndShapes** dtypes_and_shapes) const; +}; + +} // namespace reverb +} // namespace deepmind + +#endif // LEARNING_DEEPMIND_REPLAY_REVERB_REPLAY_WRITER_H_ diff --git a/reverb/cc/replay_writer_test.cc b/reverb/cc/replay_writer_test.cc new file mode 100644 index 0000000..3f00da6 --- /dev/null +++ b/reverb/cc/replay_writer_test.cc @@ -0,0 +1,610 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
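
Ahead of the tests, a minimal usage sketch of the writer declared above, assuming an already-connected stub (creation elided) of the generated ReplayService::StubInterface type; the table name, chunk length, and priority are illustrative.

#include <memory>
#include <utility>
#include <vector>

#include "reverb/cc/replay_service.grpc.pb.h"
#include "reverb/cc/replay_writer.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"

tensorflow::Status WriteShortEpisode(
    std::shared_ptr<deepmind::reverb::ReplayService::StubInterface> stub) {
  deepmind::reverb::ReplayWriter writer(std::move(stub), /*chunk_length=*/2,
                                        /*max_timesteps=*/10);
  for (int i = 0; i < 4; ++i) {
    // One scalar tensor per timestep.
    TF_RETURN_IF_ERROR(writer.AppendTimestep({tensorflow::Tensor(1.0f * i)}));
    // From the second step on, insert an item covering the last two timesteps.
    if (i >= 1) {
      TF_RETURN_IF_ERROR(writer.AddPriority("my_table", /*num_timesteps=*/2,
                                            /*priority=*/1.0));
    }
  }
  // Flushes any pending chunk and items, then closes the stream.
  return writer.Close();
}
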
+ +#include "reverb/cc/replay_writer.h" + +#include +#include + +#include "grpcpp/impl/codegen/call_op_set.h" +#include "grpcpp/impl/codegen/status.h" +#include "grpcpp/impl/codegen/sync_stream.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "reverb/cc/replay_client.h" +#include "reverb/cc/replay_service.grpc.pb.h" +#include "reverb/cc/replay_service_mock.grpc.pb.h" +#include "reverb/cc/support/grpc_util.h" +#include "reverb/cc/support/uint128.h" +#include "reverb/cc/testing/proto_test_util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/protobuf/struct.pb.h" + +namespace deepmind { + +namespace reverb { +namespace { + +using ::deepmind::reverb::testing::Partially; +using ::tensorflow::errors::DeadlineExceeded; +using ::tensorflow::errors::Internal; +using ::tensorflow::errors::Unavailable; +using ::testing::ElementsAre; +using ::testing::SizeIs; + +std::vector MakeTimestep(int num_tensors = 1) { + tensorflow::Tensor tensor(1.0f); + std::vector res(num_tensors, tensor); + return res; +} + +tensorflow::StructuredValue MakeSignature( + tensorflow::DataType dtype = tensorflow::DT_FLOAT, + const tensorflow::PartialTensorShape& shape = + tensorflow::PartialTensorShape{}) { + tensorflow::StructuredValue signature; + auto* spec = signature.mutable_tensor_spec_value(); + spec->set_dtype(dtype); + spec->set_name("tensor0"); + shape.AsProto(spec->mutable_shape()); + return signature; +} + +MATCHER(IsChunk, "") { return arg.has_chunk(); } + +MATCHER_P4(IsItemWithRangeAndPriorityAndTable, offset, length, priority, table, + "") { + return arg.has_item() && + arg.item().item().sequence_range().offset() == offset && + arg.item().item().sequence_range().length() == length && + arg.item().item().priority() == priority && + arg.item().item().table() == table; +} + +class FakeWriter : public grpc::ClientWriterInterface { + public: + FakeWriter(std::vector* requests, int num_success_writes, + grpc::Status bad_status) + : requests_(requests), + num_success_writes_(num_success_writes), + bad_status_(std::move(bad_status)) {} + + bool Write(const InsertStreamRequest& msg, + grpc::WriteOptions options) override { + requests_->push_back(msg); + return num_success_writes_-- > 0; + } + + grpc::Status Finish() override { + return num_success_writes_ >= 0 ? grpc::Status::OK : bad_status_; + } + + bool WritesDone() override { return num_success_writes_-- > 0; } + + private: + std::vector* requests_; + int num_success_writes_; + grpc::Status bad_status_; +}; + +class FakeStub : public /* grpc_gen:: */MockReplayServiceStub { + public: + explicit FakeStub(std::list writers, + const tensorflow::StructuredValue* signature = nullptr) + : writers_(std::move(writers)) { + if (signature) { + *response_.mutable_tables_state_id() = + Uint128ToMessage(absl::MakeUint128(1, 2)); + auto* table_info = response_.add_table_info(); + table_info->set_name("dist"); + *table_info->mutable_signature() = *signature; + } + } + ~FakeStub() override { + // Since writers where allocated with New we manually free the memory if + // the writer hasn't already been passed to the ReplayWriter where it is + // handled as a unique ptr and thus is destroyed with ~ReplayWriter. 
+ while (!writers_.empty()) { + auto writer = writers_.front(); + delete writer; + writers_.pop_front(); + } + } + + grpc::ClientWriterInterface* InsertStreamRaw( + grpc::ClientContext* context, InsertStreamResponse* response) override { + auto writer = writers_.front(); + writers_.pop_front(); + return writer; + } + + grpc::Status ServerInfo(grpc::ClientContext* context, + const ServerInfoRequest& request, + ServerInfoResponse* response) override { + *response = response_; + return grpc::Status::OK; + } + + private: + ServerInfoResponse response_; + std::list writers_; +}; + +std::shared_ptr MakeGoodStub( + std::vector* requests, + const tensorflow::StructuredValue* signature = nullptr) { + FakeWriter* writer = + new FakeWriter(requests, 10000, ToGrpcStatus(Internal(""))); + return std::make_shared(std::list{writer}, signature); +} + +std::shared_ptr MakeFlakyStub( + std::vector* requests, int num_success, int num_fail, + grpc::Status error) { + std::list writers; + writers.push_back(new FakeWriter(requests, num_success, error)); + for (int i = 1; i < num_fail; i++) { + writers.push_back(new FakeWriter(requests, 0, error)); + } + writers.push_back( + new FakeWriter(requests, 10000, ToGrpcStatus(Internal("")))); + return std::make_shared(std::move(writers)); +} + +TEST(ReplayWriterTest, DoesNotSendTimestepsWhenThereAreNoItems) { + std::vector requests; + auto stub = MakeGoodStub(&requests); + ReplayWriter client(stub, 2, 10); + TF_ASSERT_OK(client.AppendTimestep(MakeTimestep())); + TF_ASSERT_OK(client.AppendTimestep(MakeTimestep())); + EXPECT_THAT(requests, SizeIs(0)); +} + +TEST(ReplayWriterTest, OnlySendsChunksWhichAreUsedByItems) { + std::vector requests; + auto stub = MakeGoodStub(&requests); + ReplayWriter client(stub, 2, 10); + TF_ASSERT_OK(client.AppendTimestep(MakeTimestep())); + TF_ASSERT_OK(client.AppendTimestep(MakeTimestep())); + TF_ASSERT_OK(client.AppendTimestep(MakeTimestep())); + TF_ASSERT_OK(client.AppendTimestep(MakeTimestep())); + TF_ASSERT_OK(client.AppendTimestep(MakeTimestep())); + TF_ASSERT_OK(client.AppendTimestep(MakeTimestep())); + EXPECT_THAT(requests, SizeIs(0)); + + TF_ASSERT_OK(client.AddPriority("dist", 3, 1.0)); + ASSERT_THAT(requests, SizeIs(3)); + EXPECT_THAT(requests[0], IsChunk()); + EXPECT_THAT(requests[1], IsChunk()); + EXPECT_THAT(requests[2], + IsItemWithRangeAndPriorityAndTable(1, 3, 1.0, "dist")); + EXPECT_THAT(requests[2].item().item().chunk_keys(), + ElementsAre(requests[0].chunk().chunk_key(), + requests[1].chunk().chunk_key())); +} + +TEST(ReplayWriterTest, DoesNotSendAlreadySentChunks) { + std::vector requests; + auto stub = MakeGoodStub(&requests); + ReplayWriter client(stub, 2, 10); + + TF_ASSERT_OK(client.AppendTimestep(MakeTimestep())); + TF_ASSERT_OK(client.AppendTimestep(MakeTimestep())); + TF_ASSERT_OK(client.AddPriority("dist", 1, 1.5)); + + ASSERT_THAT(requests, SizeIs(2)); + + EXPECT_THAT(requests[0], IsChunk()); + auto first_chunk_key = requests[0].chunk().chunk_key(); + + EXPECT_THAT(requests[1], + IsItemWithRangeAndPriorityAndTable(1, 1, 1.5, "dist")); + EXPECT_THAT(requests[1].item().item().chunk_keys(), + ElementsAre(first_chunk_key)); + + requests.clear(); + TF_ASSERT_OK(client.AppendTimestep(MakeTimestep())); + TF_ASSERT_OK(client.AppendTimestep(MakeTimestep())); + TF_ASSERT_OK(client.AddPriority("dist", 3, 1.3)); + + ASSERT_THAT(requests, SizeIs(2)); + EXPECT_THAT(requests[0], IsChunk()); + auto second_chunk_key = requests[0].chunk().chunk_key(); + + EXPECT_THAT(requests[1], + IsItemWithRangeAndPriorityAndTable(1, 3, 1.3, 
"dist")); + EXPECT_THAT(requests[1].item().item().chunk_keys(), + ElementsAre(first_chunk_key, second_chunk_key)); +} + +TEST(ReplayWriterTest, SendsPendingDataOnClose) { + std::vector requests; + auto stub = MakeGoodStub(&requests); + ReplayWriter client(stub, 2, 10); + + TF_ASSERT_OK(client.AppendTimestep(MakeTimestep())); + TF_ASSERT_OK(client.AppendTimestep(MakeTimestep())); + TF_ASSERT_OK(client.AppendTimestep(MakeTimestep())); + TF_ASSERT_OK(client.AddPriority("dist", 1, 1.0)); + EXPECT_THAT(requests, SizeIs(0)); + + TF_ASSERT_OK(client.Close()); + ASSERT_THAT(requests, SizeIs(2)); + EXPECT_THAT(requests[0], IsChunk()); + EXPECT_THAT(requests[1], + IsItemWithRangeAndPriorityAndTable(0, 1, 1.0, "dist")); + EXPECT_THAT(requests[1].item().item().chunk_keys(), + ElementsAre(requests[0].chunk().chunk_key())); +} + +TEST(ReplayWriterTest, FailsIfMethodsCalledAfterClose) { + std::vector requests; + auto stub = MakeGoodStub(&requests); + ReplayWriter client(stub, 2, 10); + + TF_ASSERT_OK(client.Close()); + + EXPECT_FALSE(client.Close().ok()); + EXPECT_FALSE(client.AppendTimestep(MakeTimestep()).ok()); + EXPECT_FALSE(client.AddPriority("dist", 1, 1.0).ok()); +} + +TEST(ReplayWriterTest, RetriesOnTransientError) { + std::vector transient_errors( + {DeadlineExceeded(""), Unavailable("")}); + + for (const auto& error : transient_errors) { + std::vector requests; + // 1 fail, then all success. + auto stub = MakeFlakyStub(&requests, 0, 1, ToGrpcStatus(error)); + ReplayWriter client(stub, 2, 10); + + TF_ASSERT_OK(client.AppendTimestep(MakeTimestep())); + TF_ASSERT_OK(client.AppendTimestep(MakeTimestep())); + TF_ASSERT_OK(client.AddPriority("dist", 1, 1.0)); + + ASSERT_THAT(requests, SizeIs(3)); + EXPECT_THAT(requests[0], IsChunk()); + EXPECT_THAT(requests[1], IsChunk()); + EXPECT_THAT(requests[0], testing::EqualsProto(requests[1])); + EXPECT_THAT(requests[2], + IsItemWithRangeAndPriorityAndTable(1, 1, 1.0, "dist")); + EXPECT_THAT(requests[2].item().item().chunk_keys(), + ElementsAre(requests[0].chunk().chunk_key())); + } +} + +TEST(ReplayWriterTest, DoesNotRetryOnNonTransientError) { + std::vector requests; + auto stub = MakeFlakyStub(&requests, 0, 1, ToGrpcStatus(Internal(""))); + ReplayWriter client(stub, 2, 10); + + TF_ASSERT_OK(client.AppendTimestep(MakeTimestep())); + TF_ASSERT_OK(client.AppendTimestep(MakeTimestep())); + EXPECT_FALSE(client.AddPriority("dist", 1, 1.0).ok()); + + EXPECT_THAT(requests, SizeIs(1)); // Tries only once and then gives up. 
+} + +TEST(ReplayWriterTest, CallsCloseWhenObjectDestroyed) { + std::vector requests; + { + auto stub = MakeGoodStub(&requests); + ReplayWriter client(stub, 2, 10); + TF_ASSERT_OK(client.AppendTimestep(MakeTimestep())); + TF_ASSERT_OK(client.AddPriority("dist", 1, 1.0)); + EXPECT_THAT(requests, SizeIs(0)); + } + ASSERT_THAT(requests, SizeIs(2)); +} + +TEST(ReplayWriterTest, ResendsOnlyTheChunksTheRemainingItemsNeedWithNewStream) { + std::vector requests; + auto stub = + MakeFlakyStub(&requests, 3, 1, ToGrpcStatus(DeadlineExceeded(""))); + ReplayWriter client(stub, 2, 10); + + TF_ASSERT_OK(client.AppendTimestep(MakeTimestep())); + TF_ASSERT_OK(client.AppendTimestep(MakeTimestep())); + TF_ASSERT_OK(client.AppendTimestep(MakeTimestep())); + TF_ASSERT_OK(client.AddPriority("dist", 3, 1.0)); + TF_ASSERT_OK(client.AddPriority("dist2", 1, 1.0)); + EXPECT_THAT(requests, SizeIs(0)); + + TF_ASSERT_OK(client.AppendTimestep(MakeTimestep())); + + ASSERT_THAT(requests, SizeIs(6)); + EXPECT_THAT(requests[0], IsChunk()); + EXPECT_THAT(requests[1], IsChunk()); + auto first_chunk_key = requests[0].chunk().chunk_key(); + auto second_chunk_key = requests[1].chunk().chunk_key(); + + EXPECT_THAT(requests[2], + IsItemWithRangeAndPriorityAndTable(0, 3, 1.0, "dist")); + EXPECT_THAT(requests[2].item().item().chunk_keys(), + ElementsAre(first_chunk_key, second_chunk_key)); + + EXPECT_THAT(requests[3], IsItemWithRangeAndPriorityAndTable( + 0, 1, 1.0, "dist2")); // Failed + EXPECT_THAT(requests[3].item().item().chunk_keys(), + ElementsAre(second_chunk_key)); + + // Stream is opened and only the second chunk is sent again. + EXPECT_THAT(requests[4], IsChunk()); + EXPECT_THAT(requests[5], + IsItemWithRangeAndPriorityAndTable(0, 1, 1.0, "dist2")); + EXPECT_THAT(requests[5].item().item().chunk_keys(), + ElementsAre(second_chunk_key)); +} + +TEST(ReplayWriterTest, TellsServerToKeepStreamedItemsStillInClient) { + std::vector requests; + auto stub = MakeGoodStub(&requests); + ReplayWriter client(stub, 2, 6); + + TF_ASSERT_OK(client.AppendTimestep(MakeTimestep())); + TF_ASSERT_OK(client.AppendTimestep(MakeTimestep())); + TF_ASSERT_OK(client.AddPriority("dist", 1, 1.0)); + + ASSERT_THAT(requests, SizeIs(2)); + EXPECT_THAT(requests[0], IsChunk()); + auto first_chunk_key = requests[0].chunk().chunk_key(); + + EXPECT_THAT(requests[1], + IsItemWithRangeAndPriorityAndTable(1, 1, 1.0, "dist")); + EXPECT_THAT(requests[1].item().keep_chunk_keys(), + ElementsAre(first_chunk_key)); + + requests.clear(); + + TF_ASSERT_OK(client.AppendTimestep(MakeTimestep())); + TF_ASSERT_OK(client.AppendTimestep(MakeTimestep())); + + TF_ASSERT_OK(client.AppendTimestep(MakeTimestep())); + TF_ASSERT_OK(client.AppendTimestep(MakeTimestep())); + TF_ASSERT_OK(client.AddPriority("dist", 1, 1.0)); + + ASSERT_THAT(requests, SizeIs(2)); + EXPECT_THAT(requests[0], IsChunk()); + auto third_chunk_key = requests[0].chunk().chunk_key(); + + EXPECT_THAT(requests[1], + IsItemWithRangeAndPriorityAndTable(1, 1, 1.0, "dist")); + EXPECT_THAT(requests[1].item().keep_chunk_keys(), + ElementsAre(first_chunk_key, third_chunk_key)); + + requests.clear(); + + // Now the first chunk will go out of scope + TF_ASSERT_OK(client.AppendTimestep(MakeTimestep())); + TF_ASSERT_OK(client.AppendTimestep(MakeTimestep())); + TF_ASSERT_OK(client.AddPriority("dist", 1, 1.0)); + + ASSERT_THAT(requests, SizeIs(2)); + EXPECT_THAT(requests[0], IsChunk()); + auto forth_chunk_key = requests[0].chunk().chunk_key(); + + EXPECT_THAT(requests[1], + IsItemWithRangeAndPriorityAndTable(1, 1, 1.0, "dist")); 
+ EXPECT_THAT(requests[1].item().keep_chunk_keys(), + ElementsAre(third_chunk_key, forth_chunk_key)); +} + +TEST(ReplayWriterTest, IgnoresCloseErrorsIfAllItemsWritten) { + std::vector requests; + auto stub = MakeFlakyStub(&requests, /*num_success=*/2, + /*num_fail=*/1, ToGrpcStatus(Internal(""))); + ReplayWriter client(stub, /*chunk_length=*/1, /*max_timesteps=*/2); + + // Insert an item and make sure it is flushed to the server. + TF_EXPECT_OK(client.AppendTimestep(MakeTimestep())); + TF_EXPECT_OK(client.AddPriority("dist", 1, 1.0)); + EXPECT_THAT(requests, SizeIs(2)); + + // Close the client without any pending items and check that it swallows + // the error. + TF_EXPECT_OK(client.Close()); +} + +TEST(ReplayWriterTest, ReturnsCloseErrorsIfAllItemsNotWritten) { + std::vector requests; + auto stub = MakeFlakyStub(&requests, /*num_success=*/1, + /*num_fail=*/1, ToGrpcStatus(Internal(""))); + ReplayWriter client(stub, /*chunk_length=*/2, /*max_timesteps=*/4); + + // Insert an item which is shorter + // than the batch and thus should not + // be automatically flushed. + TF_EXPECT_OK(client.AppendTimestep(MakeTimestep())); + TF_EXPECT_OK(client.AddPriority("dist", 1, 1.0)); + EXPECT_THAT(requests, SizeIs(0)); + + // Since not all items where sent + // before the error should be + // returned. + EXPECT_EQ(client.Close().code(), tensorflow::error::INTERNAL); +} + +TEST(ReplayWriterTest, SequenceRangeIsSetOnChunks) { + std::vector requests; + auto stub = MakeGoodStub(&requests); + ReplayWriter client(stub, /*chunk_length=*/2, + /*max_timesteps=*/4); + + TF_EXPECT_OK(client.AppendTimestep(MakeTimestep())); + TF_EXPECT_OK(client.AppendTimestep(MakeTimestep())); + TF_EXPECT_OK(client.AppendTimestep(MakeTimestep())); + TF_EXPECT_OK(client.AddPriority("dist", 3, 1.0)); + TF_EXPECT_OK(client.AppendTimestep(MakeTimestep())); + + EXPECT_THAT( + requests, + ElementsAre( + Partially(testing::EqualsProto("chunk: { sequence_range: { start: 0 " + "end: 1 } delta_encoded: false }")), + Partially(testing::EqualsProto("chunk: { sequence_range: { start: 2 " + "end: 3 } delta_encoded: false }")), + IsItemWithRangeAndPriorityAndTable(0, 3, 1.0, "dist"))); + + EXPECT_NE(requests[0].chunk().sequence_range().episode_id(), 0); + EXPECT_EQ(requests[0].chunk().sequence_range().episode_id(), + requests[1].chunk().sequence_range().episode_id()); +} + +TEST(ReplayWriterTest, DeltaEncode) { + std::vector requests; + auto stub = MakeGoodStub(&requests); + ReplayWriter client(stub, /*chunk_length=*/2, + /*max_timesteps=*/4, /*delta_encoded=*/true); + + TF_EXPECT_OK(client.AppendTimestep(MakeTimestep())); + TF_EXPECT_OK(client.AppendTimestep(MakeTimestep())); + TF_EXPECT_OK(client.AppendTimestep(MakeTimestep())); + TF_EXPECT_OK(client.AddPriority("dist", 3, 1.0)); + TF_EXPECT_OK(client.AppendTimestep(MakeTimestep())); + + EXPECT_THAT( + requests, + ElementsAre( + Partially(testing::EqualsProto("chunk: { sequence_range: { start: 0 " + "end: 1 } delta_encoded: true }")), + Partially(testing::EqualsProto("chunk: { sequence_range: { start: 2 " + "end: 3 } delta_encoded: true }")), + IsItemWithRangeAndPriorityAndTable(0, 3, 1.0, "dist"))); +} + +TEST(ReplayWriterTest, MultiChunkItemsAreCorrect) { + std::vector requests; + auto stub = MakeGoodStub(&requests); + ReplayWriter client(stub, /*chunk_length=*/3, + /*max_timesteps=*/4, /*delta_encoded=*/false); + + // We create two chunks with 5 time steps (t_0,.., t_4) and 3 sequences + // (s_0, s_1, s_2): + // +--- CHUNK0 --+- CHUNK1 -+ + // | t_0 t_1 t_2 | t_3 t_4 | + // 
+-------------+----------+ + // | s_0 s_0 s_1 | s_1 s_3 | + // +-------------+----------+ + + // First item: 1 chunk. + TF_EXPECT_OK(client.AppendTimestep(MakeTimestep())); + TF_EXPECT_OK(client.AppendTimestep(MakeTimestep())); + TF_EXPECT_OK(client.AddPriority("dist", 2, 1.0)); + + // Second item: 2 chunks. + TF_EXPECT_OK(client.AppendTimestep(MakeTimestep())); + TF_EXPECT_OK(client.AppendTimestep(MakeTimestep())); + TF_EXPECT_OK(client.AddPriority("dist", 2, 1.0)); + + // Third item: 1 chunk. + TF_EXPECT_OK(client.AppendTimestep(MakeTimestep())); + TF_EXPECT_OK(client.AddPriority("dist", 1, 1.0)); + + TF_EXPECT_OK(client.Close()); + + EXPECT_THAT( + requests, + ElementsAre( + Partially(testing::EqualsProto("chunk: { sequence_range: { start: 0 " + "end: 2 } delta_encoded: false }")), + IsItemWithRangeAndPriorityAndTable(0, 2, 1.0, "dist"), + Partially(testing::EqualsProto("chunk: { sequence_range: { start: 3 " + "end: 4 } delta_encoded: false }")), + IsItemWithRangeAndPriorityAndTable(2, 2, 1.0, "dist"), + IsItemWithRangeAndPriorityAndTable(1, 1, 1.0, "dist"))); + + EXPECT_EQ(requests[1].item().item().chunk_keys_size(), 1); + EXPECT_EQ(requests[3].item().item().chunk_keys_size(), 2); + EXPECT_EQ(requests[4].item().item().chunk_keys_size(), 1); +} + +TEST(ReplayWriterTest, WriteTimeStepsMatchingSignature) { + std::vector requests; + tensorflow::StructuredValue signature = + MakeSignature(tensorflow::DT_FLOAT, tensorflow::PartialTensorShape({})); + auto stub = MakeGoodStub(&requests, &signature); + ReplayClient replay_client(stub); + std::unique_ptr client; + TF_EXPECT_OK(replay_client.NewWriter(2, 6, /*delta_encoded=*/false, &client)); + + TF_ASSERT_OK(client->AppendTimestep(MakeTimestep())); + TF_ASSERT_OK(client->AppendTimestep(MakeTimestep())); + TF_ASSERT_OK(client->AddPriority("dist", 2, 1.0)); + ASSERT_THAT(requests, SizeIs(2)); +} + +TEST(ReplayWriterTest, WriteTimeStepsNumTensorsDontMatchSignatureError) { + std::vector requests; + tensorflow::StructuredValue signature = MakeSignature(); + auto stub = MakeGoodStub(&requests, &signature); + ReplayClient replay_client(stub); + std::unique_ptr client; + TF_EXPECT_OK(replay_client.NewWriter(2, 6, /*delta_encoded=*/false, &client)); + + TF_ASSERT_OK(client->AppendTimestep(MakeTimestep(/*num_tensors=*/2))); + TF_ASSERT_OK(client->AppendTimestep(MakeTimestep(/*num_tensors=*/2))); + auto status = client->AddPriority("dist", 2, 1.0); + EXPECT_EQ(status.code(), tensorflow::error::INVALID_ARGUMENT); + EXPECT_THAT( + status.error_message(), + ::testing::HasSubstr( + "AppendTimestep for timestep offset 0 was called with 2 tensors, " + "but table requires 1 tensors per entry.")); +} + +TEST(ReplayWriterTest, WriteTimeStepsInconsistentDtypeError) { + std::vector requests; + tensorflow::StructuredValue signature = MakeSignature(tensorflow::DT_INT32); + auto stub = MakeGoodStub(&requests, &signature); + ReplayClient replay_client(stub); + std::unique_ptr client; + TF_EXPECT_OK(replay_client.NewWriter(2, 6, /*delta_encoded=*/false, &client)); + + TF_ASSERT_OK(client->AppendTimestep(MakeTimestep())); + TF_ASSERT_OK(client->AppendTimestep(MakeTimestep())); + auto status = client->AddPriority("dist", 2, 1.0); + EXPECT_EQ(status.code(), tensorflow::error::INVALID_ARGUMENT); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr( + "timestep offset 0 in (flattened) tensor location 0 with " + "dtype float and shape [] but expected a tensor of dtype " + "int32 and shape compatible with ")); +} + +TEST(ReplayWriterTest, 
WriteTimeStepsInconsistentShapeError) { + std::vector requests; + tensorflow::StructuredValue signature = + MakeSignature(tensorflow::DT_FLOAT, tensorflow::PartialTensorShape({-1})); + auto stub = MakeGoodStub(&requests, &signature); + ReplayClient replay_client(stub); + std::unique_ptr client; + TF_EXPECT_OK(replay_client.NewWriter(2, 6, /*delta_encoded=*/false, &client)); + + TF_ASSERT_OK(client->AppendTimestep(MakeTimestep())); + TF_ASSERT_OK(client->AppendTimestep(MakeTimestep())); + auto status = client->AddPriority("dist", 2, 1.0); + EXPECT_EQ(status.code(), tensorflow::error::INVALID_ARGUMENT); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr( + "timestep offset 0 in (flattened) tensor location 0 with " + "dtype float and shape [] but expected a tensor of dtype " + "float and shape compatible with [?]")); +} + +} // namespace +} // namespace reverb +} // namespace deepmind diff --git a/reverb/cc/reverb_server.cc b/reverb/cc/reverb_server.cc new file mode 100644 index 0000000..f9982d7 --- /dev/null +++ b/reverb/cc/reverb_server.cc @@ -0,0 +1,118 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "reverb/cc/reverb_server.h" + +#include // NOLINT(build/c++11) - grpc API requires it. +#include +#include + +#include "grpcpp/server_builder.h" +#include "absl/strings/str_cat.h" +#include "reverb/cc/checkpointing/interface.h" +#include "reverb/cc/platform/grpc_utils.h" +#include "reverb/cc/platform/logging.h" + +namespace deepmind { +namespace reverb { +namespace { + +constexpr int kMaxMessageSize = 300 * 1000 * 1000; + +} // namespace + +ReverbServer::ReverbServer( + std::vector> priority_tables, int port, + std::shared_ptr checkpointer) : port_(port) { + + replay_service_ = absl::make_unique( + std::move(priority_tables), std::move(checkpointer)); + + server_ = grpc::ServerBuilder() + .AddListeningPort(absl::StrCat("[::]:", port), + MakeServerCredentials()) + .RegisterService(replay_service_.get()) + .SetMaxSendMessageSize(kMaxMessageSize) + .SetMaxReceiveMessageSize(kMaxMessageSize) + .BuildAndStart(); +} + +tensorflow::Status ReverbServer::Initialize() { + absl::WriterMutexLock lock(&mu_); + REVERB_CHECK(!running_) << "Initialize() called twice?"; + if (!server_) { + return tensorflow::errors::InvalidArgument( + "Failed to BuildAndStart gRPC server"); + } + running_ = true; + REVERB_LOG(REVERB_INFO) << "Started replay server on port " << port_; + return tensorflow::Status::OK(); +} + +/* static */ tensorflow::Status ReverbServer::StartReverbServer( + std::vector> priority_tables, int port, + std::unique_ptr* server) { + // We can't use make_unique here since it can't access the private + // ReverbServer constructor. 
+ std::unique_ptr s( + new ReverbServer(std::move(priority_tables), port)); + TF_RETURN_IF_ERROR(s->Initialize()); + std::swap(s, *server); + return tensorflow::Status::OK(); +} + +/* static */ tensorflow::Status ReverbServer::StartReverbServer( + std::vector> priority_tables, int port, + std::shared_ptr checkpointer, + std::unique_ptr* server) { + // We can't use make_unique here since it can't access the private + // ReverbServer constructor. + std::unique_ptr s(new ReverbServer( + std::move(priority_tables), port, std::move(checkpointer))); + TF_RETURN_IF_ERROR(s->Initialize()); + std::swap(s, *server); + return tensorflow::Status::OK(); +} + +ReverbServer::~ReverbServer() { Stop(); } + +void ReverbServer::Stop() { + absl::WriterMutexLock lock(&mu_); + if (!running_) return; + REVERB_LOG(REVERB_INFO) << "Shutting down replay server"; + + // Closes the dependent services in the desirable order. + replay_service_->Close(); + + // Set a deadline as the sampler streams never closes by themselves. + server_->Shutdown(std::chrono::system_clock::now() + std::chrono::seconds(5)); + + running_ = false; +} + +void ReverbServer::Wait() { + server_->Wait(); +} + +std::unique_ptr ReverbServer::InProcessClient() { + grpc::ChannelArguments arguments; + arguments.SetMaxReceiveMessageSize(kMaxMessageSize); + arguments.SetMaxSendMessageSize(kMaxMessageSize); + absl::WriterMutexLock lock(&mu_); + return absl::make_unique( + /* grpc_gen:: */ReplayService::NewStub(server_->InProcessChannel(arguments))); +} + +} // namespace reverb +} // namespace deepmind diff --git a/reverb/cc/reverb_server.h b/reverb/cc/reverb_server.h new file mode 100644 index 0000000..9fe2c12 --- /dev/null +++ b/reverb/cc/reverb_server.h @@ -0,0 +1,74 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef REVERB_CC_REVERB_SERVER_H_ +#define REVERB_CC_REVERB_SERVER_H_ + +#include + +#include "reverb/cc/checkpointing/interface.h" +#include "reverb/cc/priority_table.h" +#include "reverb/cc/replay_client.h" +#include "reverb/cc/replay_service_impl.h" + +namespace deepmind { +namespace reverb { + +class ReverbServer { + public: + static tensorflow::Status StartReverbServer( + std::vector> priority_tables, int port, + std::shared_ptr checkpointer, + std::unique_ptr* server); + + static tensorflow::Status StartReverbServer( + std::vector> priority_tables, int port, + std::unique_ptr* server); + + ~ReverbServer(); + + // Terminates the server and blocks until it has been terminated. + void Stop(); + + // Blocks until the server has terminated. Does not terminate the server + // itself. Use this to want to wait indefinitely. + void Wait(); + + // Gets a local in process client. This bypasses proto serialization and + // network overhead. Careful: The resulting client instance must not be used + // after this server instance has terminated. 
+ std::unique_ptr InProcessClient(); + + private: + ReverbServer(std::vector> priority_tables, + int port, + std::shared_ptr checkpointer = nullptr); + + tensorflow::Status Initialize(); + + // The port the server is on. + int port_; + + std::unique_ptr replay_service_; + + std::unique_ptr server_ = nullptr; + + absl::Mutex mu_; + bool running_ ABSL_GUARDED_BY(mu_) = false; +}; + +} // namespace reverb +} // namespace deepmind + +#endif // REVERB_CC_REVERB_SERVER_H_ diff --git a/reverb/cc/reverb_server_test.cc b/reverb/cc/reverb_server_test.cc new file mode 100644 index 0000000..21c0bd6 --- /dev/null +++ b/reverb/cc/reverb_server_test.cc @@ -0,0 +1,51 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "reverb/cc/reverb_server.h" + +#include + +#include "grpcpp/impl/codegen/client_context.h" +#include "grpcpp/impl/codegen/status.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "reverb/cc/platform/net.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace deepmind { +namespace reverb { +namespace { + +TEST(ReverbServerTest, StartServer) { + int port = internal::PickUnusedPortOrDie(); + std::unique_ptr server; + TF_EXPECT_OK(ReverbServer::StartReverbServer(/*priority_tables=*/{}, + /*port=*/port, &server)); +} + +TEST(ReverbServerTest, ErrorOnUnavailablePort) { + // We expect that port==-1 to always be unavailable. + std::unique_ptr server; + auto status = ReverbServer::StartReverbServer(/*priority_tables=*/{}, + /*port=*/-1, &server); + EXPECT_EQ(status.code(), tensorflow::error::INVALID_ARGUMENT); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr("Failed to BuildAndStart gRPC server")); +} + +} // namespace +} // namespace reverb +} // namespace deepmind diff --git a/reverb/cc/schema.proto b/reverb/cc/schema.proto new file mode 100644 index 0000000..bb1dff8 --- /dev/null +++ b/reverb/cc/schema.proto @@ -0,0 +1,171 @@ +syntax = "proto3"; + +package deepmind.reverb; + +import "google/protobuf/timestamp.proto"; +import "tensorflow/core/framework/tensor.proto"; +import "tensorflow/core/protobuf/struct.proto"; + +// The actual data is stored in chunks. The data can be arbitrary tensors. We do +// not interpret the bytes data of the tensors on the server side. It is up to +// the client to compress the bytes blob within the tensors. +message ChunkData { + // Unique identifier of the chunk. + uint64 chunk_key = 1; + + // The timesteps within the episode that the chunk covers. + SequenceRange sequence_range = 2; + + // Actual tensor data. + repeated tensorflow.TensorProto data = 3; + + // True if delta encoding has been applied before compressing data. + bool delta_encoded = 4; +} + +// A range that specifies which items to slice out from a sequence of chunks. +// The length of all chunks must at least be `offset`+`length`. +message SliceRange { + // Offset where the slice should start. 
+ int32 offset = 1; + + // Length of the slice. Can span multiple chunks. + int32 length = 2; +} + +message SequenceRange { + // Globally unique identifier of the episode the sequence belongs to. + uint64 episode_id = 1; + + // Index within the episode of the first timestep covered by the range. + int32 start = 2; + + // Index within the episode of the last timestep covered by the range. + // Must be >= start_index. + int32 end = 3; +} + +// A prioritized item is part of a priority table and references a chunk of +// data. Sampling happens based on the priority of items. +message PrioritizedItem { + // Unique identifier of this item. + uint64 key = 1; + + // Priority table that the item belongs to. + string table = 2; + + // Sequence of chunks that is referenced by this item. Chunks in this list + // will be batched together and are expected to have compatible shape and + // dtype. Optional. + repeated uint64 chunk_keys = 3; + + // Range that should be sliced out of the chunks in `chunk_keys`. This range + // can span multiple chunks. + SliceRange sequence_range = 4; + + // Priority used for sampling. + double priority = 5; + + // The number of times the item has been sampled. + int32 times_sampled = 6; + + // The time when the item was first inserted. + google.protobuf.Timestamp inserted_at = 7; +} + +// Used for updating an existing PrioritizedItem. +message KeyWithPriority { + // Identifier of the PrioritizedItem. + uint64 key = 1; + + // Priority used for sampling. + double priority = 2; +} + +message SampleInfo { + // Item from that was sampled from the priority table. + PrioritizedItem item = 1; + + // Probability that this item had at sampling time. Useful for importance + // sampling. + double probability = 2; + + // Number of items in the table at the time of the sample operation. + int64 table_size = 3; +} + +// Metadata about the table, including (optional) data signature. +// +// These fields correspond to initialization arguments of the +// `PriorityTable` class, unless noted otherwise. +message TableInfo { + // Table's name. + string name = 8; + + // Sampler and remover metadata. + KeyDistributionOptions sampler_options = 1; + KeyDistributionOptions remover_options = 2; + + // Max size of the table. + int64 max_size = 3; + + // Max number of times an element can be sampled before being + // removed. + int32 max_times_sampled = 4; + + // How data read/write is rate limited. + RateLimiterInfo rate_limiter_info = 5; + + // Optional data signature for tensors stored in the table. Note + // that this data type is more flexible than we use. For example, + // we only store tensors (TensorSpecProto, TypeSpecProto) and not + // any special data types (no NoneValue or other special fixed values). + tensorflow.StructuredValue signature = 6; + + // Current size of table. + int64 current_size = 7; +} + +message RateLimiterInfo { + // The average number of times each item should be sampled during its + // lifetime. + double samples_per_insert = 1; + + // The minimum and maximum values the cursor is allowed to reach. The cursor + // value is calculated as `insert_count * samples_per_insert - + // sample_count`. If the value would go beyond these limits then the call is + // blocked until it can proceed without violating the constraints. + double min_diff = 2; + double max_diff = 3; + + // The minimum number of inserts required before any sample operation. + int64 min_size_to_sample = 4; +} + +// Metadata about sampler or remover. Describes its configuration. 
+message KeyDistributionOptions { + message Prioritized { + double priority_exponent = 1; + } + + message Heap { + bool min_heap = 1; + } + + oneof distribution { + bool fifo = 1; + bool uniform = 2; + Prioritized prioritized = 3; + Heap heap = 4; + + bool lifo = 6; + } + + reserved 5; +} + +// Uint128 representation. Can be used for unique identifiers. +message Uint128 { + uint64 high = 1; + uint64 low = 2; +} diff --git a/reverb/cc/support/BUILD b/reverb/cc/support/BUILD new file mode 100644 index 0000000..5885704 --- /dev/null +++ b/reverb/cc/support/BUILD @@ -0,0 +1,87 @@ +load( + "//reverb/cc/platform:build_rules.bzl", + "reverb_absl_deps", + "reverb_cc_library", + "reverb_cc_test", + "reverb_grpc_deps", + "reverb_tf_deps", +) + +package(default_visibility = ["//reverb:__subpackages__"]) + +licenses(["notice"]) + +reverb_cc_library( + name = "uint128", + hdrs = ["uint128.h"], + deps = [ + "//reverb/cc:schema_cc_proto", + ] + reverb_absl_deps(), +) + +reverb_cc_library( + name = "grpc_util", + hdrs = ["grpc_util.h"], + deps = reverb_tf_deps() + reverb_grpc_deps() + reverb_absl_deps(), +) + +reverb_cc_library( + name = "queue", + hdrs = ["queue.h"], + deps = reverb_absl_deps(), +) + +reverb_cc_library( + name = "intrusive_heap", + srcs = ["intrusive_heap.cc"], + hdrs = ["intrusive_heap.h"], + deps = [ + "//reverb/cc/platform:logging", + ], +) + +reverb_cc_test( + name = "intrusive_heap_test", + srcs = ["intrusive_heap_test.cc"], + deps = [ + ":intrusive_heap", + ] + reverb_absl_deps(), +) + +reverb_cc_library( + name = "periodic_closure", + srcs = ["periodic_closure.cc"], + hdrs = ["periodic_closure.h"], + deps = [ + "//reverb/cc/platform:logging", + "//reverb/cc/platform:thread", + ] + reverb_absl_deps() + reverb_tf_deps(), +) + +reverb_cc_test( + name = "queue_test", + srcs = ["queue_test.cc"], + deps = [ + ":queue", + "//reverb/cc/platform:logging", + "//reverb/cc/platform:thread", + ] + reverb_absl_deps(), +) + +reverb_cc_test( + name = "periodic_closure_test", + srcs = ["periodic_closure_test.cc"], + deps = [ + ":periodic_closure", + "//reverb/cc/testing:time_testutil", + ] + reverb_absl_deps() + reverb_tf_deps(), +) + +reverb_cc_library( + name = "signature", + srcs = ["signature.cc"], + hdrs = ["signature.h"], + deps = [ + "//reverb/cc:schema_cc_proto", + ] + reverb_tf_deps() + reverb_absl_deps(), +) diff --git a/reverb/cc/support/grpc_util.h b/reverb/cc/support/grpc_util.h new file mode 100644 index 0000000..5061f41 --- /dev/null +++ b/reverb/cc/support/grpc_util.h @@ -0,0 +1,69 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
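The `RateLimiterInfo` comments above define the cursor as `insert_count * samples_per_insert - sample_count` and state that a call blocks whenever it would push the cursor outside `[min_diff, max_diff]`, or when fewer than `min_size_to_sample` items have been inserted. The sketch below is one reading of that rule with the blocking behaviour replaced by boolean checks; `ToyRateLimiter` and the exact threshold comparisons are assumptions for illustration, not the real Reverb rate limiter.

#include <cstdint>
#include <iostream>

struct ToyRateLimiter {
  double samples_per_insert;
  double min_diff;
  double max_diff;
  int64_t min_size_to_sample;
  int64_t insert_count;
  int64_t sample_count;

  // Cursor as described in RateLimiterInfo.
  double Cursor() const {
    return insert_count * samples_per_insert - sample_count;
  }

  // Assumed rule: an insert may proceed if the cursor stays at or below
  // max_diff afterwards.
  bool CanInsert() const { return Cursor() + samples_per_insert <= max_diff; }

  // Assumed rule: a sample may proceed if enough items have been inserted and
  // the cursor stays at or above min_diff afterwards.
  bool CanSample() const {
    return insert_count >= min_size_to_sample && Cursor() - 1 >= min_diff;
  }
};

int main() {
  // Roughly one sample per insert, tolerating an imbalance of one.
  ToyRateLimiter limiter{/*samples_per_insert=*/1.0, /*min_diff=*/-1.0,
                         /*max_diff=*/1.0, /*min_size_to_sample=*/1,
                         /*insert_count=*/0, /*sample_count=*/0};
  limiter.insert_count = 1;
  std::cout << "can sample: " << limiter.CanSample() << "\n";  // 1
  limiter.sample_count = 2;
  std::cout << "can sample: " << limiter.CanSample() << "\n";  // 0, would block
  std::cout << "can insert: " << limiter.CanInsert() << "\n";  // 1
}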
+ +#ifndef LEARNING_DEEPMIND_REPLAY_REVERB_SUPPORT_GRPC_UTIL_H_ +#define LEARNING_DEEPMIND_REPLAY_REVERB_SUPPORT_GRPC_UTIL_H_ + +#include "grpcpp/grpcpp.h" +#include "grpcpp/impl/codegen/proto_utils.h" +#include "tensorflow/core/lib/core/error_codes.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "absl/strings/substitute.h" + +namespace deepmind { +namespace reverb { + +constexpr char kStreamRemovedMessage[] = "Stream removed"; + +// Identify if the given grpc::Status corresponds to an HTTP stream removed +// error (see chttp2_transport.cc). +// +// When auto-reconnecting to a remote TensorFlow worker after it restarts, gRPC +// can return an UNKNOWN error code with a "Stream removed" error message. +// This should not be treated as an unrecoverable error. +// +// N.B. This is dependent on the error message from grpc remaining consistent. +inline bool IsStreamRemovedError(const ::grpc::Status& s) { + return !s.ok() && s.error_code() == ::grpc::StatusCode::UNKNOWN && + s.error_message() == kStreamRemovedMessage; +} + +inline grpc::Status ToGrpcStatus(const tensorflow::Status& s) { + if (s.ok()) return grpc::Status::OK; + + return grpc::Status(static_cast(s.code()), + s.error_message()); +} + +inline tensorflow::Status FromGrpcStatus(const grpc::Status& s) { + if (s.ok()) return tensorflow::Status::OK(); + + // Convert "UNKNOWN" stream removed errors into unavailable, to allow + // for retry upstream. + if (IsStreamRemovedError(s)) { + return tensorflow::Status(tensorflow::error::UNAVAILABLE, + s.error_message()); + } + return tensorflow::Status( + static_cast(s.error_code()), s.error_message()); +} + +inline std::string FormatGrpcStatus(const grpc::Status& s) { + return absl::Substitute("[$0] $1", s.error_code(), s.error_message()); +} + +} // namespace reverb +} // namespace deepmind + +#endif // LEARNING_DEEPMIND_REPLAY_REVERB_SUPPORT_GRPC_UTIL_H_ diff --git a/reverb/cc/support/intrusive_heap.cc b/reverb/cc/support/intrusive_heap.cc new file mode 100644 index 0000000..0626ba0 --- /dev/null +++ b/reverb/cc/support/intrusive_heap.cc @@ -0,0 +1,23 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "reverb/cc/support/intrusive_heap.h" + +namespace deepmind { +namespace reverb { + +const IntrusiveHeapLink::size_type IntrusiveHeapLink::kNotMember; + +} +} // namespace deepmind diff --git a/reverb/cc/support/intrusive_heap.h b/reverb/cc/support/intrusive_heap.h new file mode 100644 index 0000000..f2c7742 --- /dev/null +++ b/reverb/cc/support/intrusive_heap.h @@ -0,0 +1,276 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// A heap that supports removing and adjusting the weights of arbitrary +// elements. To do so, it records the heap location of each element in +// a field within that element. +// +// The default IntrusiveHeap uses a heap value field embedded in each item to +// maintain the heap ordering. The item type T provides a public field +// named "heap" (by default) of type IntrusiveHeapLink, which the +// heap uses for this. See intrusive_heap_test.cc for an example. +// +// The storage for the heap link in elements can be customized by providing a +// LinkAccess policy. This should not commonly be required. + +#ifndef REVERB_CC_SUPPORT_INTRUSIVE_HEAP_H_ +#define REVERB_CC_SUPPORT_INTRUSIVE_HEAP_H_ + +#include + +#include +#include // NOLINT +#include + +#include "reverb/cc/platform/logging.h" + +namespace deepmind { +namespace reverb { + +// The bookkeeping area inside each element, used by IntrusiveHeap. +// IntrusiveHeap objects are configured with a LinkAccess policy with +// read-write access to the IntrusiveHeapLink object within each element. +// Currently implemented as a vector index. +class IntrusiveHeapLink { + public: + using size_type = size_t; + static const size_type kNotMember = -1; + + IntrusiveHeapLink() = default; + + // Only IntrusiveHeap and LinkAccess objects should make these objects. + explicit IntrusiveHeapLink(size_type pos) : pos_{pos} {} + + // Only IntrusiveHeap and LinkAccess should get the value. + size_type get() const { return pos_; } + + private: + size_type pos_{kNotMember}; +}; + +// Manipulate a link object accessible as a data member. +// Usable as an IntrusiveHeap's LinkAccess policy object (see IntrusiveHeap). +template +struct IntrusiveHeapDataMemberLinkAccess { + IntrusiveHeapLink Get(const T* elem) const { return elem->*M; } + void Set(T* elem, IntrusiveHeapLink link) const { elem->*M = link; } +}; + +// The default LinkAccess object, uses the 'heap' data member as a Link. +// Usable as an IntrusiveHeap's LinkAccess policy object (see IntrusiveHeap). +template +struct DefaultIntrusiveHeapLinkAccess { + IntrusiveHeapLink Get(const T* elem) const { return elem->heap; } + void Set(T* elem, IntrusiveHeapLink link) const { elem->heap = link; } +}; + +// IntrusiveHeap +// +// A min-heap (under PtrCompare ordering) of pointers to T. +// +// Supports random access removal of elements (in O(lg(N) time), but +// requires that the pointed-to elements provide a +// IntrusiveHeapLink data member (usually called 'heap'). +// +// T: the value type to be referenced by the IntrusiveHeap. Note that +// IntrusiveHeap does not take ownership of its elements; it merely points +// to them. +// PtrCompare: a binary predicate applying a strict weak ordering over +// 'const T*' returning true if and only if 'a' should be considered +// less than 'b'. Note that IntrusiveHeap is a min-heap under the +// PtrCompare ordering, such that if PtrCompare(a, b), then 'a' will be +// popped before 'b'. +// LinkAccess: Rarely specified, as the default is sufficient for most +// uses. 
A policy class providing functions with the signatures +// 'IntrusiveHeapLink Get(const T* elem)' and +// void Set(T* elem, IntrusiveHeapLink link)'. +// These functions allow for customization of location of +// the IntrusiveHeapLink member in a T* object. The default +// LinkAccessor policy's Get(elem) and Set(link,elem) functions +// manipulate the member accessed by 'elem->heap'. +// Alloc: an STL allocator for T* elements. Default is std::allocator. +// +// Note that the IntrusiveHeap does not hold or own any T objects, +// only pointers to them. Users must manage storage on their own. +template , + typename Alloc = std::allocator > +class IntrusiveHeap { + public: + typedef typename IntrusiveHeapLink::size_type size_type; + typedef T value_type; + typedef T* pointer; + typedef const T* const_pointer; + typedef PtrCompare pointer_compare_type; + typedef LinkAccess link_access_type; + typedef Alloc allocator_type; + + explicit IntrusiveHeap( + const pointer_compare_type& comp = pointer_compare_type(), + const link_access_type& link_access = link_access_type(), + const allocator_type& alloc = allocator_type()) + : rep_(comp, link_access, alloc) { } + + size_type size() const { + return heap().size(); + } + + bool empty() const { + return heap().empty(); + } + + // Return the top element, but don't remove it. + pointer top() const { + REVERB_CHECK(!empty()); + return heap()[0]; + } + + // Remove the top() pointer from the heap and return it. + pointer Pop() { + pointer t = top(); + Remove(t); + return t; + } + + // Insert 't' into the heap. + void Push(pointer t) { + SetPositionOf(t, heap().size()); + heap().push_back(t); + FixHeapUp(t); + } + + // Adjust the heap to accommodate changes in '*t'. + void Adjust(pointer t) { + REVERB_CHECK(Contains(t)); + size_type h = GetPositionOf(t); + if (h != 0 && compare()(t, heap()[(h - 1) >> 1])) { + FixHeapUp(t); + } else { + FixHeapDown(t); + } + } + + // Remove the specified pointer from the heap. + void Remove(pointer t) { + REVERB_CHECK(Contains(t)); + size_type h = GetPositionOf(t); + SetPositionOf(t, IntrusiveHeapLink::kNotMember); + if (h == heap().size() - 1) { + // Fast path for removing from back of heap. + heap().pop_back(); + return; + } + // Move the element from the back of the heap to overwrite 't'. + pointer& elem = heap()[h]; + elem = heap().back(); + SetPositionOf(elem, h); // Element has moved, so update its link. + heap().pop_back(); + Adjust(elem); // Restore the heap invariant. + } + + void Clear() { + heap().clear(); + } + + bool Contains(const_pointer t) const { + size_type h = GetPositionOf(t); + return (h != IntrusiveHeapLink::kNotMember) && + (h < size()) && + heap()[h] == t; + } + + void reserve(size_type n) { heap().reserve(n); } + + size_type capacity() const { return heap().capacity(); } + + allocator_type get_allocator() const { return rep_.heap_.get_allocator(); } + + private: + typedef std::vector heap_type; + + // Empty base class optimization for pointer_compare and link_access. + // The heap_ data member retains a copy of the allocator, so it is not + // stored explicitly. 
+ struct Rep : pointer_compare_type, link_access_type { + explicit Rep(const pointer_compare_type& cmp, + const link_access_type& link_access, + const allocator_type& alloc) + : pointer_compare_type(cmp), + link_access_type(link_access), + heap_(alloc) { } + heap_type heap_; + }; + + const pointer_compare_type& compare() const { return rep_; } + + pointer_compare_type compare() { return rep_; } + + const link_access_type& link_access() const { return rep_; } + + const heap_type& heap() const { return rep_.heap_; } + heap_type& heap() { return rep_.heap_; } + + size_type GetPositionOf(const_pointer t) const { + return link_access().Get(t).get(); + } + + void SetPositionOf(pointer t, size_type pos) const { + return link_access().Set(t, IntrusiveHeapLink(pos)); + } + + void FixHeapUp(pointer t) { + size_type h = GetPositionOf(t); + while (h != 0) { + size_type parent = (h - 1) >> 1; + if (compare()(heap()[parent], t)) { + break; + } + heap()[h] = heap()[parent]; + SetPositionOf(heap()[h], h); + h = parent; + } + heap()[h] = t; + SetPositionOf(t, h); + } + + void FixHeapDown(pointer t) { + size_type h = GetPositionOf(t); + for (;;) { + size_type kid = (h << 1) + 1; + if (kid >= heap().size()) { + break; + } + if (kid + 1 < heap().size() && + compare()(heap()[kid + 1], heap()[kid])) { + ++kid; + } + if (compare()(t, heap()[kid])) { + break; + } + heap()[h] = heap()[kid]; + SetPositionOf(heap()[h], h); + h = kid; + } + + heap()[h] = t; + SetPositionOf(t, h); + } + + Rep rep_; +}; + +} // namespace reverb +} // namespace deepmind + +#endif // REVERB_CC_SUPPORT_INTRUSIVE_HEAP_H_ diff --git a/reverb/cc/support/intrusive_heap_test.cc b/reverb/cc/support/intrusive_heap_test.cc new file mode 100644 index 0000000..69580d3 --- /dev/null +++ b/reverb/cc/support/intrusive_heap_test.cc @@ -0,0 +1,265 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
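Putting `intrusive_heap.h` together, the sketch below shows the intended usage: an element type carrying an `IntrusiveHeapLink heap` member, a pointer comparator defining a min-heap, and an in-place priority change propagated with `Adjust`. The two template arguments match what the tests use; the template parameter lists are garbled in this copy of the header, so treat the exact signature as an assumption.

#include <iostream>
#include <string>

#include "reverb/cc/support/intrusive_heap.h"

using deepmind::reverb::IntrusiveHeap;
using deepmind::reverb::IntrusiveHeapLink;

// The default link-access policy looks for a member named `heap`.
struct Task {
  std::string name;
  double priority;
  IntrusiveHeapLink heap;  // position inside the heap
};

// Min-heap ordering: the lowest priority value is popped first.
struct TaskLess {
  bool operator()(const Task* a, const Task* b) const {
    return a->priority < b->priority;
  }
};

int main() {
  Task a{"a", 3.0};
  Task b{"b", 1.0};
  Task c{"c", 2.0};

  IntrusiveHeap<Task, TaskLess> heap;  // stores pointers, never owns elements
  heap.Push(&a);
  heap.Push(&b);
  heap.Push(&c);

  // Change a priority in place and restore the heap invariant in O(lg N).
  a.priority = 0.5;
  heap.Adjust(&a);

  while (!heap.empty()) {
    std::cout << heap.Pop()->name << "\n";  // a, b, c
  }
}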
+ +#include "reverb/cc/support/intrusive_heap.h" + +#include +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/random/random.h" + +namespace deepmind { +namespace reverb { +namespace { + +static const int kNumElems = 100; + +class IntrusiveHeapTest : public testing::Test { + protected: + struct Elem { + int32_t val; + int iota; + IntrusiveHeapLink heap; // position in the heap + }; + + struct ElemChild : Elem {}; + + struct ElemLess { + bool operator()(const Elem* e1, const Elem* e2) const { + if (e1->val != e2->val) { + return e1->val < e2->val; + } + return e1->iota < e2->iota; + } + }; + struct ElemValLess { + bool operator()(const Elem& e1, const Elem& e2) const { + return ElemLess()(&e1, &e2); + } + }; + struct StatefulLess { + bool operator()(const Elem* e1, const Elem* e2) const { + return ElemLess()(e1, e2); + } + void* dummy; + }; + struct StatefulLinkAccess { + typedef IntrusiveHeapLink Link; + Link Get(const Elem* e) const { return e->heap; } + void Set(Elem* e, Link link) const { e->heap = link; } + void* dummy; + }; + + typedef IntrusiveHeap ElemHeap; + + absl::BitGen rnd_; + ElemHeap heap_; // The heap + std::vector elems_; // Storage for items in the heap + std::vector expected_; // Copy of the elements, for reference + + IntrusiveHeapTest() {} + + // Build a heap. + void BuildHeap() { + elems_.resize(kNumElems); + for (int i = 0; i < kNumElems; i++) { + elems_[i].val = absl::Uniform(rnd_); + elems_[i].iota = i; + heap_.Push(&elems_[i]); + expected_.push_back(elems_[i]); + } + } + + // Pop the elements from the heap, verifying they are as expected. + void VerifyHeap() { + EXPECT_EQ(expected_.size(), heap_.size()); + EXPECT_FALSE(heap_.empty()); + + ElemValLess less; + std::sort(expected_.begin(), expected_.end(), less); + + for (int i = 0; i < expected_.size(); i++) { + ASSERT_FALSE(heap_.empty()); + Elem* e = heap_.Pop(); + EXPECT_EQ(expected_[i].iota, e->iota) << i; + EXPECT_EQ(expected_[i].val, e->val) << i; + } + + EXPECT_EQ(0, heap_.size()); + EXPECT_TRUE(heap_.empty()); + } +}; + +TEST_F(IntrusiveHeapTest, PushPop) { + BuildHeap(); + VerifyHeap(); +} + +TEST_F(IntrusiveHeapTest, Clear) { + Elem dummy; + dummy.val = 8675309; + dummy.iota = 123456; + heap_.Push(&dummy); + heap_.Clear(); + EXPECT_EQ(0, heap_.size()); +} + +TEST_F(IntrusiveHeapTest, Contains) { + Elem dummy; + dummy.val = 8675309; + dummy.iota = 123456; + EXPECT_FALSE(heap_.Contains(&dummy)); + heap_.Push(&dummy); + EXPECT_TRUE(heap_.Contains(&dummy)); + heap_.Clear(); + EXPECT_FALSE(heap_.Contains(&dummy)); +} + +TEST_F(IntrusiveHeapTest, ContainsTwoHeaps) { + Elem dummy1; + dummy1.val = 8675309; + dummy1.iota = 123456; + Elem dummy2 = dummy1; + heap_.Push(&dummy1); + + ElemHeap other_heap; + + EXPECT_FALSE(other_heap.Contains(&dummy1)); + EXPECT_FALSE(other_heap.Contains(&dummy2)); + + other_heap.Push(&dummy2); + + EXPECT_TRUE(heap_.Contains(&dummy1)); + EXPECT_FALSE(heap_.Contains(&dummy2)); + EXPECT_FALSE(other_heap.Contains(&dummy1)); + EXPECT_TRUE(other_heap.Contains(&dummy2)); +} + +TEST_F(IntrusiveHeapTest, Remove) { + BuildHeap(); + + // Remove the second half of the elements. + for (int i = kNumElems / 2; i < kNumElems; i++) { + heap_.Remove(&elems_[i]); + } + expected_.resize(kNumElems / 2); + + VerifyHeap(); +} + +TEST_F(IntrusiveHeapTest, Adjust) { + BuildHeap(); + + // Adjust the weights of all elements. 
+ for (int i = 0; i < kNumElems; i++) { + elems_[i].val = absl::Uniform(rnd_); + expected_[i].val = elems_[i].val; + heap_.Adjust(&elems_[i]); + } + + VerifyHeap(); +} + +TEST_F(IntrusiveHeapTest, EmptyBaseClassOptimization) { + // EBC optimization reduces size from 32 to 24 bytes. + // Testing that neither stateless PtrCompare nor stateless + // StatefulLinkAccess contribute to object size. + EXPECT_LT(sizeof(IntrusiveHeap), + sizeof(IntrusiveHeap)); + EXPECT_LT(sizeof(IntrusiveHeap), + sizeof(IntrusiveHeap)); + EXPECT_LT(sizeof(IntrusiveHeap), + sizeof(IntrusiveHeap)); +} + +// Test that an IntrusiveHeap can access +// T's HeapLink element even if T inherits it. +// That is, even if IntrusiveHeapLink data member comes from a base +// class of the Element type, we should still find it. +TEST_F(IntrusiveHeapTest, InheritElement) { + std::vector elems(5); + for (int i = 0; i < elems.size(); ++i) { + elems[i].val = (i * 19) % 7; + elems[i].iota = i; + } + typedef IntrusiveHeap Heap; + Heap heap; + for (ElemChild& e : elems) { + heap.Push(&e); + } + std::vector expected = elems; + std::sort(expected.begin(), expected.end(), ElemValLess()); + std::vector actual; + while (!heap.empty()) { + actual.push_back(*heap.Pop()); + } + for (int i = 0; i < expected.size(); ++i) { + EXPECT_EQ(expected[i].val, actual[i].val); + } +} + +class SyntheticLinkTest : public testing::Test { + public: + struct Element { + std::string value; + unsigned char heap; + }; + struct Access { + using Link = IntrusiveHeapLink; + Link Get(const Element* elem) const { + return Link(elem->heap); + } + void Set(Element* elem, Link link) const { + elem->heap = link.get(); + } + }; + struct PtrCompare { + bool operator()(const Element* a, const Element* b) const { + return a->value < b->value; + } + }; +}; + +TEST_F(SyntheticLinkTest, SetAndGet) { + IntrusiveHeap heap; + std::vector elems{{"d"}, {"b"}, {"e"}, {"a"}, {"c"}}; + for (auto& e : elems) heap.Push(&e); + std::vector out; + while (!heap.empty()) { + out.push_back(heap.top()->value); + heap.Pop(); + } + auto sorted = out; + std::sort(sorted.begin(), sorted.end()); + EXPECT_THAT(out, testing::ElementsAreArray(sorted)); +} + +TEST_F(SyntheticLinkTest, ReserveAndCapacity) { + IntrusiveHeap heap; + std::vector elems{{"d"}, {"b"}, {"e"}, {"a"}, {"c"}}; + EXPECT_EQ(0, heap.capacity()); + EXPECT_EQ(0, heap.size()); + heap.reserve(elems.size()); + EXPECT_EQ(0, heap.size()); + EXPECT_GE(heap.capacity(), elems.size()); + for (auto& e : elems) heap.Push(&e); + EXPECT_EQ(heap.size(), elems.size()); + EXPECT_GE(heap.capacity(), elems.size()); +} + +} // namespace +} // namespace reverb +} // namespace deepmind diff --git a/reverb/cc/support/periodic_closure.cc b/reverb/cc/support/periodic_closure.cc new file mode 100644 index 0000000..7f440ef --- /dev/null +++ b/reverb/cc/support/periodic_closure.cc @@ -0,0 +1,83 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
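The tests above cover `Remove`, `Adjust` and a hand-rolled link-access policy. The header also offers `IntrusiveHeapDataMemberLinkAccess` for links stored under a different member name; the sketch below assumes its parameters are `<typename T, IntrusiveHeapLink T::*M>`, which is consistent with the `elem->*M` accessors shown but not spelled out in this garbled copy of the header.

#include <iostream>

#include "reverb/cc/support/intrusive_heap.h"

using deepmind::reverb::IntrusiveHeap;
using deepmind::reverb::IntrusiveHeapDataMemberLinkAccess;
using deepmind::reverb::IntrusiveHeapLink;

// The link member is not called `heap`, so the heap is told where to find it.
struct Node {
  int weight;
  IntrusiveHeapLink pos_in_heap;
};

struct NodeLess {
  bool operator()(const Node* a, const Node* b) const {
    return a->weight < b->weight;
  }
};

using NodeHeap =
    IntrusiveHeap<Node, NodeLess,
                  IntrusiveHeapDataMemberLinkAccess<Node, &Node::pos_in_heap>>;

int main() {
  Node n1{5}, n2{2}, n3{7};
  NodeHeap heap;
  heap.Push(&n1);
  heap.Push(&n2);
  heap.Push(&n3);
  std::cout << heap.Pop()->weight << "\n";  // 2
  heap.Remove(&n3);                         // random-access removal
  std::cout << heap.Pop()->weight << "\n";  // 5
  std::cout << heap.empty() << "\n";        // 1
}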
+ +#include "reverb/cc/support/periodic_closure.h" + +#include +#include + +#include "absl/synchronization/mutex.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "reverb/cc/platform/logging.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/errors.h" + +namespace deepmind { +namespace reverb { +namespace internal { + +PeriodicClosure::~PeriodicClosure() { + REVERB_CHECK(worker_ == nullptr) << "must be Stop()'d before destructed"; +} + +PeriodicClosure::PeriodicClosure(std::function fn, + absl::Duration period, std::string name_prefix) + : fn_(std::move(fn)), + period_(period), + name_prefix_(std::move(name_prefix)) { + REVERB_CHECK_GE(period_, absl::ZeroDuration()) << "period should be >= 0"; +} + +tensorflow::Status PeriodicClosure::Start() { + absl::WriterMutexLock lock(&mu_); + if (stopped_) { + return tensorflow::errors::InvalidArgument( + "PeriodicClosure: Start called after Close"); + } + if (worker_ != nullptr) { + return tensorflow::errors::InvalidArgument( + "PeriodicClosure: Start called when closure already running"); + } + worker_ = StartThread(name_prefix_, [this] { + for (auto next_run = absl::Now() + period_; true;) { + if (mu_.LockWhenWithDeadline(absl::Condition(&stopped_), next_run)) { + mu_.Unlock(); + return; + } + mu_.Unlock(); + next_run = absl::Now() + period_; + + fn_(); + } + }); + return tensorflow::Status::OK(); +} + +tensorflow::Status PeriodicClosure::Stop() { + { + absl::MutexLock l(&mu_); + if (stopped_) { + return tensorflow::errors::InvalidArgument( + "PeriodicClsoure: Stop called multiple times"); + } + stopped_ = true; + } + worker_ = nullptr; // Join thread. + return tensorflow::Status::OK(); +} + +} // namespace internal +} // namespace reverb +} // namespace deepmind diff --git a/reverb/cc/support/periodic_closure.h b/reverb/cc/support/periodic_closure.h new file mode 100644 index 0000000..07e7a8c --- /dev/null +++ b/reverb/cc/support/periodic_closure.h @@ -0,0 +1,116 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef REVERB_CC_SUPPORT_PERIODIC_CLOSURE_H_ +#define REVERB_CC_SUPPORT_PERIODIC_CLOSURE_H_ + +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/synchronization/mutex.h" +#include "absl/time/time.h" +#include "reverb/cc/platform/thread.h" +#include "tensorflow/core/lib/core/status.h" + +namespace deepmind { +namespace reverb { +namespace internal { + +// PeriodicClosure will periodically call the given closure with a specified +// period in a background thread. After Start() returns, the thread is +// guaranteed to have started and after Stop() returns, the thread is +// guaranteed to be stopped. Start()/Stop() may be called more than once; each +// pair of calls will result in a new thread being created and subsequently +// destroyed. +// +// PeriodicClosure runs the closure as soon as any previous run both is +// complete and was started more than "interval" earlier. 
Thus, runs are +// both serialized, and normally have a period of "interval" if no run +// exceeds the time. +// +// Note that, if the closure takes longer than the interval, then the closure is +// called immediately and the next call is scheduled at `interval` into the +// future. If the interval is 50ms and the first call to the closure takes 75ms +// and all other calls takes 25ms, then the closure will run at: 0ms, 75ms, +// 125ms, 175ms, and so on. +// +// This object is thread-safe. +// +// Example: +// +// class Foo { +// public: +// Foo() : periodic_closure_([this]() { Bar(); }, +// absl::Seconds(1)) { +// periodic_closure_.Start(); +// } +// +// ~Foo() { +// periodic_closure_.Stop(); +// } +// +// private: +// void Bar() { ... } +// +// PeriodicClosure periodic_closure_; +// }; +// +class PeriodicClosure { + public: + PeriodicClosure(std::function fn, absl::Duration period, + std::string name_prefix = ""); + + // Dies if `Start` but not `Stop` called. + ~PeriodicClosure(); + + // Starts the background thread that will be calling the closure. + // + // Returns InvalidArgument if called more than once. + tensorflow::Status Start(); + + // Waits for active closure call (if any) to complete and joins background + // thread. Must be called before object is destroyed and `Start` has been + // called. + // + // Returns InvalidArgument if called more than once. + tensorflow::Status Stop(); + + // PeriodicClosure is neither copyable nor movable. + PeriodicClosure(const PeriodicClosure&) = delete; + PeriodicClosure& operator=(const PeriodicClosure&) = delete; + + private: + // Closure called by the background thread. + const std::function fn_; + + // The minimum duration between calls to `fn_`. + const absl::Duration period_; + + // Name prefix assigned to background thread. + const std::string name_prefix_; + + // Flag to ensure that `Start` and `Stop` is not called multiple times. + bool stopped_ ABSL_GUARDED_BY(mu_) = false; + absl::Mutex mu_; + + // Background thread constructed in `Start` and joined in `Stop`. + std::unique_ptr worker_ = nullptr; +}; + +} // namespace internal +} // namespace reverb +} // namespace deepmind + +#endif // REVERB_CC_SUPPORT_PERIODIC_CLOSURE_H_ diff --git a/reverb/cc/support/periodic_closure_test.cc b/reverb/cc/support/periodic_closure_test.cc new file mode 100644 index 0000000..b640116 --- /dev/null +++ b/reverb/cc/support/periodic_closure_test.cc @@ -0,0 +1,129 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
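The scheduling rule described in the header comment (a run starts once the previous run has finished and at least `period` has elapsed since the previous run started) can be checked with a few lines of arithmetic. This standalone sketch reproduces the 0ms, 75ms, 125ms, 175ms example from the comment without using `PeriodicClosure` itself.

#include <algorithm>
#include <iostream>
#include <vector>

// Returns the start time of each run: the next run starts once the previous
// run has finished AND at least `interval_ms` has passed since the previous
// run started.
std::vector<int> RunStartTimesMs(int interval_ms,
                                 const std::vector<int>& durations_ms) {
  std::vector<int> starts;
  int start = 0;
  for (int duration : durations_ms) {
    starts.push_back(start);
    start += std::max(interval_ms, duration);
  }
  return starts;
}

int main() {
  // 50ms interval, the first call takes 75ms, the remaining calls take 25ms.
  for (int t : RunStartTimesMs(50, {75, 25, 25, 25})) std::cout << t << "ms ";
  std::cout << "\n";  // 0ms 75ms 125ms 175ms
}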
+ +#include "reverb/cc/support/periodic_closure.h" + +#include + +#include "gtest/gtest.h" +#include "absl/time/clock.h" +#include "reverb/cc/testing/time_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace deepmind { +namespace reverb { +namespace internal { +namespace { + +void IncrementAndSleep(std::atomic_int* value, absl::Duration interval) { + *value += 1; + absl::SleepFor(interval); +} + +TEST(PeriodicClosureTest, ObeyInterval) { + const absl::Duration kPeriod = absl::Milliseconds(10); + const int kCalls = 10; + const absl::Duration timeout = (kPeriod * kCalls); + + std::atomic_int actual_calls(0); + auto callback = [&] { IncrementAndSleep(&actual_calls, kPeriod); }; + + PeriodicClosure pc(callback, kPeriod); + + TF_EXPECT_OK(pc.Start()); + absl::SleepFor(timeout); + TF_EXPECT_OK(pc.Stop()); + + // The closure could get called up to kCalls+1 times: once at time 0, once + // at time kPeriod, once at time kPeriod*2, up to once at time + // kPeriod*kCalls. It could be called fewer times if, say, the machine is + // overloaded, so let's check that: + // (kCalls - 5) <= actual_calls <= (kCalls + 1). + ASSERT_LE(kCalls - 5, actual_calls); + ASSERT_LE(actual_calls, kCalls + 1); +} + +// If this test hangs forever, its probably a deadlock caused by setting the +// PeriodicClosure's interval to 0ms. +TEST(PeriodicClosureTest, MinInterval) { + const absl::Duration kCallDuration = absl::Milliseconds(10); + + std::atomic_int actual_calls(0); + auto callback = [&] { IncrementAndSleep(&actual_calls, kCallDuration); }; + + PeriodicClosure pc(callback, absl::ZeroDuration()); + + TF_EXPECT_OK(pc.Start()); + + test::WaitFor([&]() { return actual_calls > 0 && actual_calls < 3; }, + kCallDuration, 100); + + ASSERT_GT(actual_calls, 0); + ASSERT_LT(actual_calls, 3); + + TF_EXPECT_OK(pc.Stop()); // we should be able to Stop() +} + +std::function DoNothing() { + return []() {}; +} + +TEST(PeriodicClosureDeathTest, BadInterval) { + EXPECT_DEATH(PeriodicClosure pc(DoNothing, absl::Milliseconds(-1)), + ".* should be >= 0"); +} + +TEST(PeriodicClosureDeathTest, NotStopped) { + PeriodicClosure* pc = + new PeriodicClosure(DoNothing(), absl::Milliseconds(10)); + + TF_EXPECT_OK(pc->Start()); + ASSERT_DEATH(delete pc, ".* before destructed"); + + TF_EXPECT_OK(pc->Stop()); + delete pc; +} + +TEST(PeriodicClosureDeathTest, DoubleStart) { + PeriodicClosure pc(DoNothing, absl::Milliseconds(10)); + + TF_EXPECT_OK(pc.Start()); + auto status = pc.Start(); + EXPECT_EQ(status.code(), tensorflow::error::INVALID_ARGUMENT); + + TF_EXPECT_OK(pc.Stop()); +} + +TEST(PeriodicClosureDeathTest, DoubleStop) { + PeriodicClosure pc(DoNothing, absl::Milliseconds(10)); + + TF_EXPECT_OK(pc.Start()); + + TF_EXPECT_OK(pc.Stop()); + auto status = pc.Stop(); + EXPECT_EQ(status.code(), tensorflow::error::INVALID_ARGUMENT); +} + +TEST(PeriodicClosureDeathTest, StartAfterStop) { + PeriodicClosure pc(DoNothing, absl::Milliseconds(10)); + + TF_EXPECT_OK(pc.Stop()); + auto status = pc.Start(); + EXPECT_EQ(status.code(), tensorflow::error::INVALID_ARGUMENT); +} + +} // namespace +} // namespace internal +} // namespace reverb +} // namespace deepmind diff --git a/reverb/cc/support/queue.h b/reverb/cc/support/queue.h new file mode 100644 index 0000000..f942dfb --- /dev/null +++ b/reverb/cc/support/queue.h @@ -0,0 +1,106 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef REVERB_CC_SUPPORT_QUEUE_H_ +#define REVERB_CC_SUPPORT_QUEUE_H_ + +#include + +#include "absl/base/thread_annotations.h" +#include "absl/synchronization/mutex.h" + +namespace deepmind { +namespace reverb { +namespace internal { + +// Thread-safe and closable queue with fixed capacity (buffered channel). +// +// A call to `Push()` inserts an item while `Pop()` removes and retrieves an +// item in fifo order. Once the maximum capacity has been reached, calls to +// `Push()` block until the queue is no longer full. Similarly, `Pop()` blocks +// if there are no items in the queue. `Close()` can be called to unblock any +// pending and future calls to `Push()` and `Pop()`. +// +// Note: This implementation is only intended for a single producer and single +// consumer use case, where Close() is called by the consumer. The +// implementation has not been tested for other use cases! +template +class Queue { + public: + // `capacity` is the maximum number of elements which the queue can hold. + explicit Queue(int capacity) + : buffer_(capacity), size_(0), index_(0), closed_(false) {} + + // Closes the queue. All pending and future calls to `Push()` and `Pop()` are + // unblocked and return false without performing the operation. Additional + // calls of Close after the first one have no effect. + void Close() { + absl::MutexLock lock(&mu_); + closed_ = true; + } + + // Pushes an item to the queue. Blocks if the queue has reached `capacity`. On + // success, `true` is returned. If the queue is closed, `false` is returned. + bool Push(T x) { + absl::MutexLock lock(&mu_); + mu_.Await(absl::Condition( + +[](Queue* q) { return q->closed_ || q->size_ < q->buffer_.size(); }, + this)); + if (closed_) return false; + buffer_[(index_ + size_) % buffer_.size()] = std::move(x); + ++size_; + return true; + } + + // Removes an element from the queue and move-assigns it to *item. Blocks if + // the queue is empty. On success, `true` is returned. If the queue was + // closed, `false` is returned. + bool Pop(T* item) { + absl::MutexLock lock(&mu_); + mu_.Await(absl::Condition( + +[](Queue* q) { return q->closed_ || q->size_ > 0; }, this)); + if (closed_) return false; + *item = std::move(buffer_[index_]); + index_ = (index_ + 1) % buffer_.size(); + --size_; + return true; + } + + // Current number of elements. + int size() const { + absl::ReaderMutexLock lock(&mu_); + return size_; + } + + private: + mutable absl::Mutex mu_; + + // Circular buffer. Initialized with fixed size `capacity_`. + std::vector buffer_; + + // Current number of elements. + int size_; + + // Index of the beginning of the queue in the circular buffer. + int index_; + + // Whether `Close()` was called. + bool closed_; +}; + +} // namespace internal +} // namespace reverb +} // namespace deepmind + +#endif // REVERB_CC_SUPPORT_QUEUE_H_ diff --git a/reverb/cc/support/queue_test.cc b/reverb/cc/support/queue_test.cc new file mode 100644 index 0000000..25e3b66 --- /dev/null +++ b/reverb/cc/support/queue_test.cc @@ -0,0 +1,125 @@ +// Copyright 2019 DeepMind Technologies Limited. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "reverb/cc/support/queue.h" + +#include +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/synchronization/notification.h" +#include "reverb/cc/platform/logging.h" +#include "reverb/cc/platform/thread.h" + +namespace deepmind { +namespace reverb { +namespace internal { +namespace { + +TEST(QueueTest, PushAndPopAreConsistent) { + Queue q(10); + int output; + for (int i = 0; i < 100; i++) { + q.Push(i); + q.Pop(&output); + EXPECT_EQ(output, i); + } +} + +TEST(QueueTest, PushBlocksWhenFull) { + Queue q(2); + ASSERT_TRUE(q.Push(1)); + ASSERT_TRUE(q.Push(2)); + absl::Notification n; + auto t = StartThread("", [&q, &n] { + REVERB_CHECK(q.Push(3)); + n.Notify(); + }); + ASSERT_FALSE(n.HasBeenNotified()); + int output; + ASSERT_TRUE(q.Pop(&output)); + n.WaitForNotification(); + EXPECT_EQ(output, 1); +} + +TEST(QueueTest, PopBlocksWhenEmpty) { + Queue q(2); + absl::Notification n; + int output; + auto t = StartThread("", [&q, &n, &output] { + REVERB_CHECK(q.Pop(&output)); + n.Notify(); + }); + ASSERT_FALSE(n.HasBeenNotified()); + ASSERT_TRUE(q.Push(1)); + n.WaitForNotification(); + EXPECT_EQ(output, 1); +} + +TEST(QueueTest, AfterClosePushAndPopReturnFalse) { + Queue q(2); + q.Close(); + EXPECT_FALSE(q.Push(1)); + EXPECT_FALSE(q.Pop(nullptr)); +} + +TEST(QueueTest, CloseUnblocksPush) { + Queue q(2); + ASSERT_TRUE(q.Push(1)); + ASSERT_TRUE(q.Push(2)); + absl::Notification n; + bool ok; + auto t = StartThread("", [&q, &n, &ok] { + ok = q.Push(3); + n.Notify(); + }); + ASSERT_FALSE(n.HasBeenNotified()); + q.Close(); + n.WaitForNotification(); + EXPECT_FALSE(ok); +} + +TEST(QueueTest, CloseUnblocksPop) { + Queue q(2); + absl::Notification n; + bool ok; + auto t = StartThread("", [&q, &n, &ok] { + int output; + ok = q.Pop(&output); + n.Notify(); + }); + ASSERT_FALSE(n.HasBeenNotified()); + q.Close(); + n.WaitForNotification(); + EXPECT_FALSE(ok); +} + +TEST(QueueTest, SizeReturnsNumberOfElements) { + Queue q(3); + EXPECT_EQ(q.size(), 0); + + q.Push(20); + q.Push(30); + EXPECT_EQ(q.size(), 2); + + int v; + ASSERT_TRUE(q.Pop(&v)); + EXPECT_EQ(q.size(), 1); +} + +} // namespace +} // namespace internal +} // namespace reverb +} // namespace deepmind diff --git a/reverb/cc/support/signature.cc b/reverb/cc/support/signature.cc new file mode 100644 index 0000000..a7e7578 --- /dev/null +++ b/reverb/cc/support/signature.cc @@ -0,0 +1,173 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "reverb/cc/support/signature.h" + +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/types.h" + +namespace deepmind { +namespace reverb { +namespace internal { + +tensorflow::Status FlatSignatureFromTableInfo( + const TableInfo& info, DtypesAndShapes* dtypes_and_shapes) { + if (!info.has_signature()) { + *dtypes_and_shapes = absl::nullopt; + } else { + const auto& sig = info.signature(); + *dtypes_and_shapes = DtypesAndShapes::value_type{}; + auto status = FlatSignatureFromStructuredValue(sig, dtypes_and_shapes); + if (!status.ok()) { + tensorflow::errors::AppendToMessage(&status, "Full signature struct: '", + info.signature().DebugString(), "'"); + return status; + } + } + return tensorflow::Status::OK(); +} + +tensorflow::Status FlatSignatureFromStructuredValue( + const tensorflow::StructuredValue& value, + DtypesAndShapes* dtypes_and_shapes) { + switch (value.kind_case()) { + case tensorflow::StructuredValue::kTensorSpecValue: { + const auto& tensor_spec = value.tensor_spec_value(); + (*dtypes_and_shapes) + ->push_back({tensor_spec.dtype(), + tensorflow::PartialTensorShape(tensor_spec.shape())}); + } break; + case tensorflow::StructuredValue::kListValue: { + for (const auto& v : value.list_value().values()) { + TF_RETURN_IF_ERROR( + FlatSignatureFromStructuredValue(v, dtypes_and_shapes)); + } + } break; + case tensorflow::StructuredValue::kTupleValue: { + for (const auto& v : value.tuple_value().values()) { + TF_RETURN_IF_ERROR( + FlatSignatureFromStructuredValue(v, dtypes_and_shapes)); + } + } break; + case tensorflow::StructuredValue::kDictValue: { + std::vector keys; + keys.reserve(value.dict_value().fields_size()); + for (const auto& f : value.dict_value().fields()) { + keys.push_back(f.first); + } + std::sort(keys.begin(), keys.end()); + for (const auto& k : keys) { + TF_RETURN_IF_ERROR(FlatSignatureFromStructuredValue( + value.dict_value().fields().at(k), dtypes_and_shapes)); + } + } break; + case tensorflow::StructuredValue::kNamedTupleValue: { + for (const auto &p : value.named_tuple_value().values()) { + TF_RETURN_IF_ERROR(FlatSignatureFromStructuredValue( + p.value(), dtypes_and_shapes)); + } + } break; + default: + return tensorflow::errors::InvalidArgument( + "Saw unsupported encoded subtree in signature: '", + value.DebugString(), "'"); + } + return tensorflow::Status::OK(); +} + +std::string DtypesShapesString( + const std::vector& + dtypes_and_shapes) { + std::vector strings; + strings.reserve(dtypes_and_shapes.size()); + for (const auto& p : dtypes_and_shapes) { + strings.push_back(absl::StrCat( + "Tensor")); + } + return absl::StrJoin(strings, ", "); +} + +tensorflow::StructuredValue StructuredValueFromChunkData( + const ChunkData& chunk_data) { + tensorflow::StructuredValue value; + for (int i = 0; i < chunk_data.data_size(); i++) { + const auto& chunk = chunk_data.data(i); + tensorflow::PartialTensorShape shape(chunk.tensor_shape()); + shape.RemoveDim(0); + + auto* spec = + value.mutable_list_value()->add_values()->mutable_tensor_spec_value(); + spec->set_dtype(chunk.dtype()); + shape.AsProto(spec->mutable_shape()); + } + + return value; +} + +tensorflow::Status FlatPathFromStructuredValue( + const tensorflow::StructuredValue& value, absl::string_view prefix, + std::vector* paths) { + switch (value.kind_case()) { + case 
tensorflow::StructuredValue::kTensorSpecValue: + paths->push_back(std::string(prefix)); + break; + case tensorflow::StructuredValue::kListValue: { + for (int i = 0; i < value.list_value().values_size(); i++) { + TF_RETURN_IF_ERROR(FlatPathFromStructuredValue( + value.list_value().values(i), absl::StrCat(prefix, "[", i, "]"), + paths)); + } + } break; + case tensorflow::StructuredValue::kTupleValue: { + for (int i = 0; i < value.tuple_value().values_size(); i++) { + TF_RETURN_IF_ERROR(FlatPathFromStructuredValue( + value.tuple_value().values(i), absl::StrCat(prefix, "[", i, "]"), + paths)); + } + } break; + case tensorflow::StructuredValue::kDictValue: { + std::vector keys; + keys.reserve(value.dict_value().fields_size()); + for (const auto& f : value.dict_value().fields()) { + keys.push_back(f.first); + } + std::sort(keys.begin(), keys.end()); + for (const auto& k : keys) { + TF_RETURN_IF_ERROR( + FlatPathFromStructuredValue(value.dict_value().fields().at(k), + absl::StrCat(prefix, ".", k), paths)); + } + } break; + case tensorflow::StructuredValue::kNamedTupleValue: { + for (const auto& p : value.named_tuple_value().values()) { + TF_RETURN_IF_ERROR(FlatPathFromStructuredValue( + p.value(), absl::StrCat(prefix, ".", p.key()), paths)); + } + } break; + default: + return tensorflow::errors::InvalidArgument( + "Saw unsupported encoded subtree in signature: '", + value.DebugString(), "'"); + } + return tensorflow::Status::OK(); +} + +} // namespace internal +} // namespace reverb +} // namespace deepmind diff --git a/reverb/cc/support/signature.h b/reverb/cc/support/signature.h new file mode 100644 index 0000000..9566ade --- /dev/null +++ b/reverb/cc/support/signature.h @@ -0,0 +1,64 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef REVERB_CC_SUPPORT_SIGNATURE_H_ +#define REVERB_CC_SUPPORT_SIGNATURE_H_ + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "reverb/cc/schema.pb.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/protobuf/struct.pb.h" + +namespace deepmind { +namespace reverb { +namespace internal { + +typedef absl::optional> + DtypesAndShapes; + +tensorflow::Status FlatSignatureFromTableInfo( + const TableInfo& info, DtypesAndShapes* dtypes_and_shapes); + +tensorflow::Status FlatSignatureFromStructuredValue( + const tensorflow::StructuredValue& value, + DtypesAndShapes* dtypes_and_shapes); + +tensorflow::StructuredValue StructuredValueFromChunkData( + const ChunkData& chunk_data); + +tensorflow::Status FlatPathFromStructuredValue( + const tensorflow::StructuredValue& value, absl::string_view prefix, + std::vector* paths); + +// Map from table name to optional vector of flattened (dtype, shape) pairs. 
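+// For example, a table whose signature flattens to two tensor specs, say a
+// float32 [84, 84] observation and an int32 [] action, maps to a vector of
+// those two (dtype, partial shape) pairs, while a table registered without a
+// signature maps to absl::nullopt. (The specs here are illustrative only.)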
+typedef absl::flat_hash_map + FlatSignatureMap; + +std::string DtypesShapesString( + const std::vector& + dtypes_and_shapes); + +} // namespace internal +} // namespace reverb +} // namespace deepmind + +#endif // REVERB_CC_SUPPORT_SIGNATURE_H_ diff --git a/reverb/cc/support/uint128.h b/reverb/cc/support/uint128.h new file mode 100644 index 0000000..bcaec2d --- /dev/null +++ b/reverb/cc/support/uint128.h @@ -0,0 +1,38 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef REVERB_CC_SUPPORT_UINT128_H_ +#define REVERB_CC_SUPPORT_UINT128_H_ + +#include "reverb/cc/schema.pb.h" +#include "reverb/cc/support/uint128.h" + +namespace deepmind { +namespace reverb { + +inline Uint128 Uint128ToMessage(const absl::uint128& value) { + Uint128 message; + message.set_high(absl::Uint128High64(value)); + message.set_low(absl::Uint128Low64(value)); + return message; +} + +inline absl::uint128 MessageToUint128(const Uint128& message) { + return absl::MakeUint128(message.high(), message.low()); +} + +} // namespace reverb +} // namespace deepmind + +#endif // REVERB_CC_SUPPORT_UINT128_H_ diff --git a/reverb/cc/table_extensions/BUILD b/reverb/cc/table_extensions/BUILD new file mode 100644 index 0000000..c2598be --- /dev/null +++ b/reverb/cc/table_extensions/BUILD @@ -0,0 +1,18 @@ +load( + "//reverb/cc/platform:build_rules.bzl", + "reverb_absl_deps", + "reverb_cc_library", +) + +package(default_visibility = ["//reverb:__subpackages__"]) + +licenses(["notice"]) + +reverb_cc_library( + name = "interface", + hdrs = ["interface.h"], + deps = [ + "//reverb/cc:priority_table_item", + "//reverb/cc:schema_cc_proto", + ] + reverb_absl_deps(), +) diff --git a/reverb/cc/table_extensions/interface.h b/reverb/cc/table_extensions/interface.h new file mode 100644 index 0000000..c4598f5 --- /dev/null +++ b/reverb/cc/table_extensions/interface.h @@ -0,0 +1,74 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef REVERB_CC_TABLE_EXTENSIONS_INTERFACE_H_ +#define REVERB_CC_TABLE_EXTENSIONS_INTERFACE_H_ + +#include + +#include +#include "reverb/cc/priority_table_item.h" +#include "reverb/cc/schema.pb.h" + +namespace deepmind { +namespace reverb { + +class PriorityTable; + +// A `PriorityTableExtension` is passed to a single `PriorityTable` and executed +// as part of the atomic operations of the parent table. 
All "hooks" are +// executed while parent is holding its mutex and thus latency is very +// important. +class PriorityTableExtensionInterface { + public: + virtual ~PriorityTableExtensionInterface() = default; + + // Executed just after item is inserted into parent `PriorityTable`. + virtual void OnInsert(const PriorityTableItem& item) = 0; + + // Executed just before item is removed from parent `PriorityTable`. + virtual void OnDelete(const PriorityTableItem& item) = 0; + + // Executed just after the priority of an item has been updated in parent + // `PriorityTable`. `OnUpdate` of all registered extensions are called before + // `Diffuse` is called. + virtual void OnUpdate(const PriorityTableItem& item) = 0; + + // Executed just before a sample is returned. The sample count of the item + // includes the active sample and thus always is >= 1. + virtual void OnSample(const PriorityTableItem& item) = 0; + + // Executed just before all items are deleted. + virtual void OnReset() = 0; + + // Diffuses the update to the neighborhood and returns a vector of updates + // that should be applied as a result. + // + // `item` is the updated item after the update has been applied and + // `old_priority` is was the priority of the item before the update was + // applied. + // + // This method must only be called from `table` as mutex lock is held as part + // of an update. + // + // `table` must not be nullptr and `item` must contain chunks. + virtual std::vector Diffuse(PriorityTable* table, + const PriorityTableItem& item, + double old_priority) = 0; +}; + +} // namespace reverb +} // namespace deepmind + +#endif // REVERB_CC_TABLE_EXTENSIONS_INTERFACE_H_ diff --git a/reverb/cc/tensor_compression.cc b/reverb/cc/tensor_compression.cc new file mode 100644 index 0000000..169a5eb --- /dev/null +++ b/reverb/cc/tensor_compression.cc @@ -0,0 +1,111 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
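+
+// Worked example of the delta encoding implemented below (values are
+// arbitrary): with the first dimension treated as time, a column holding
+// [3, 5, 9, 10] is encoded as [3, 2, 4, 1] (each element minus its
+// predecessor), and decoding accumulates the deltas to recover the original
+// values. Tensors with fewer than two dimensions, and non-integral dtypes,
+// are returned unchanged.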
+ +#include "reverb/cc/tensor_compression.h" + +#include + +#include "reverb/cc/platform/logging.h" +#include "reverb/cc/platform/snappy.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" + +namespace deepmind { +namespace reverb { +namespace { + +template +tensorflow::Tensor DeltaEncode(const tensorflow::Tensor& tensor, bool encode) { + tensorflow::Tensor output(tensor.dtype(), tensor.shape()); + + tensorflow::Tensor tensor_reinterpret; + TF_CHECK_OK(tensor_reinterpret.BitcastFrom( + tensor, tensorflow::DataTypeToEnum::v(), tensor.shape())); + + tensorflow::Tensor output_reinterpret; + TF_CHECK_OK(output_reinterpret.BitcastFrom( + output, tensorflow::DataTypeToEnum::v(), output.shape())); + + auto src = tensor_reinterpret.flat_outer_dims(); + auto dst = output_reinterpret.flat_outer_dims(); + for (int j = 0; j < src.dimension(1); j++) { + dst(0, j) = src(0, j); + } + for (int i = 1; i < src.dimension(0); i++) { + for (int j = 0; j < src.dimension(1); j++) { + dst(i, j) = src(i, j) + (encode ? -src(i - 1, j) : dst(i - 1, j)); + } + } + return output; +} + +} // namespace + +tensorflow::Tensor DeltaEncode(const tensorflow::Tensor& tensor, bool encode) { + if (tensor.dims() < 2) return tensor; + + switch (tensor.dtype()) { +#define DELTA_ENCODE(T) \ + case tensorflow::DataTypeToEnum::value: \ + return DeltaEncode::Type>(tensor, encode); + TF_CALL_INTEGRAL_TYPES(DELTA_ENCODE) +#undef DELTA_ENCODE + default: + return tensor; + } +} + +std::vector DeltaEncodeList( + const std::vector& tensors, bool encode) { + std::vector outputs; + outputs.reserve(tensors.size()); + for (const tensorflow::Tensor& tensor : tensors) { + outputs.push_back(DeltaEncode(tensor, encode)); + } + return outputs; +} + +void CompressTensorAsProto(const tensorflow::Tensor& tensor, + tensorflow::TensorProto* proto) { + if (tensor.dtype() == tensorflow::DT_STRING) { + tensor.AsProtoTensorContent(proto); + } else { + proto->set_dtype(tensor.dtype()); + tensor.shape().AsProto(proto->mutable_tensor_shape()); + SnappyCompressFromString(tensor.tensor_data(), + proto->mutable_tensor_content()); + } +} + +tensorflow::Tensor DecompressTensorFromProto( + const tensorflow::TensorProto& proto) { + if (proto.dtype() == tensorflow::DT_STRING) { + tensorflow::Tensor tensor; + REVERB_CHECK(tensor.FromProto(proto)); + return tensor; + } else { + tensorflow::Tensor tensor(proto.dtype(), + tensorflow::TensorShape(proto.tensor_shape())); + const auto& tensor_content = proto.tensor_content(); + SnappyUncompressToString(tensor_content, + tensor.tensor_data().size(), + const_cast(tensor.tensor_data().data())); + return tensor; + } +} + +} // namespace reverb +} // namespace deepmind diff --git a/reverb/cc/tensor_compression.h b/reverb/cc/tensor_compression.h new file mode 100644 index 0000000..a119097 --- /dev/null +++ b/reverb/cc/tensor_compression.h @@ -0,0 +1,70 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LEARNING_DEEPMIND_REPLAY_REVERB_TENSOR_COMPRESSION_H_ +#define LEARNING_DEEPMIND_REPLAY_REVERB_TENSOR_COMPRESSION_H_ + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor.pb.h" + +namespace deepmind { +namespace reverb { + +// Delta encodes INT8,16,32,64 and UINT8,16,32,64 tensors of dimensions >= 2. +// The first dimension is assumed to be the time step and each timestep will be +// encoded as follows: output[i] = input[i] - input[i-1]. For encoding +// `encode=true` should be passed, for decoding `encode=false`. +tensorflow::Tensor DeltaEncode(const tensorflow::Tensor& tensor, bool encode); + +// Applies `DeltaEncode` on a vector of tensors. +std::vector DeltaEncodeList( + const std::vector& tensors, bool encode); + +// Compresses a Tensor with Zippy. The resulting `proto` must be read with +// `DecompressTensorFromProto`. Note that string tensors are not compressed. +void CompressTensorAsProto(const tensorflow::Tensor& tensor, + tensorflow::TensorProto* proto); + +// Assumes that the TensorProto was built by calling `CompressTensorAsProto`. +tensorflow::Tensor DecompressTensorFromProto( + const tensorflow::TensorProto& proto); + +template +struct UnsignedType { + static_assert( + tensorflow::kDataTypeIsUnsigned.Contains( + tensorflow::DataTypeToEnum::value), + "Attempt to treat signed data type as unsigned. Perhaps a new integer " + "type was added to TensorFlow's TF_CALL_INTEGRAL_TYPES? Please extend " + "UnsignedType specializations for this new data type."); + typedef T Type; +}; + +#define REVERB_CREATE_UNSIGNED_TYPE(S, U) \ + template <> \ + struct UnsignedType { \ + typedef U Type; \ + }; + +REVERB_CREATE_UNSIGNED_TYPE(tensorflow::int8, tensorflow::uint8) +REVERB_CREATE_UNSIGNED_TYPE(tensorflow::int16, tensorflow::uint16) +REVERB_CREATE_UNSIGNED_TYPE(tensorflow::int32, tensorflow::uint32) +REVERB_CREATE_UNSIGNED_TYPE(tensorflow::int64, tensorflow::uint64) + +#undef REVERB_CREATE_UNSIGNED_TYPE + +} // namespace reverb +} // namespace deepmind + +#endif // LEARNING_DEEPMIND_REPLAY_REVERB_TENSOR_COMPRESSION_H_ diff --git a/reverb/cc/tensor_compression_test.cc b/reverb/cc/tensor_compression_test.cc new file mode 100644 index 0000000..1aa4958 --- /dev/null +++ b/reverb/cc/tensor_compression_test.cc @@ -0,0 +1,103 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
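+
+// Round-trip sketch for the API under test (illustrative; the shape, dtype
+// and values are arbitrary):
+//
+//   tensorflow::Tensor t(tensorflow::DT_INT32, tensorflow::TensorShape({4, 2}));
+//   t.flat<tensorflow::int32>().setRandom();
+//   tensorflow::TensorProto proto;
+//   CompressTensorAsProto(DeltaEncode(t, /*encode=*/true), &proto);
+//   tensorflow::Tensor roundtrip =
+//       DeltaEncode(DecompressTensorFromProto(proto), /*encode=*/false);
+//   // `roundtrip` now holds the same values as `t`.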
+ +#include "reverb/cc/tensor_compression.h" + +#include + +#include "gtest/gtest.h" +#include "reverb/cc/testing/tensor_testutil.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" + +namespace deepmind { +namespace reverb { +namespace { + +template +void EncodeMatchesDecodeT() { + tensorflow::Tensor tensor(tensorflow::DataTypeToEnum::v(), + tensorflow::TensorShape({16, 37, 6})); + tensor.flat().setRandom(); + tensorflow::Tensor encoded = DeltaEncode(tensor, true); + tensorflow::Tensor decoded = DeltaEncode(encoded, false); + test::ExpectTensorEqual(tensor, decoded); +} + +TEST(TensorCompressionTest, EncodeMatchesDecode) { +#define ENCODE_MATCHES_DECODE(T) EncodeMatchesDecodeT(); + TF_CALL_INTEGRAL_TYPES(ENCODE_MATCHES_DECODE) +#undef ENCODE_MATCHES_DECODE + EncodeMatchesDecodeT(); + EncodeMatchesDecodeT(); + EncodeMatchesDecodeT(); +} + +TEST(TensorCompressionTest, EncodeListMatchesDecode) { + tensorflow::Tensor tensor(tensorflow::DT_INT32, + tensorflow::TensorShape({16, 37, 6})); + tensor.flat().setRandom(); + std::vector tensors{tensor, tensor}; + std::vector encoded = DeltaEncodeList(tensors, true); + std::vector decoded = DeltaEncodeList(encoded, false); + EXPECT_EQ(tensors.size(), decoded.size()); + for (int i = 0; i < tensors.size(); i++) { + test::ExpectTensorEqual(tensors[i], decoded[i]); + } +} + +TEST(TensorCompressionTest, StringTensor) { + tensorflow::Tensor tensor(tensorflow::DT_STRING, + tensorflow::TensorShape({2})); + tensor.flat()(0) = "hello"; + tensor.flat()(1) = "world"; + + tensorflow::TensorProto proto; + CompressTensorAsProto(tensor, &proto); + + tensorflow::Tensor result = DecompressTensorFromProto(proto); + test::ExpectTensorEqual(tensor, result); +} + +TEST(TensorCompressionTest, NonStringTensor) { + tensorflow::Tensor tensor(tensorflow::DT_INT32, + tensorflow::TensorShape({2, 2})); + tensor.flat().setRandom(); + + tensorflow::TensorProto proto; + CompressTensorAsProto(tensor, &proto); + + tensorflow::Tensor result = DecompressTensorFromProto(proto); + test::ExpectTensorEqual(tensor, result); +} + +TEST(TensorCompressionTest, NonStringTensorWithDeltaEncoding) { + tensorflow::Tensor tensor(tensorflow::DT_INT32, + tensorflow::TensorShape({2, 2})); + tensor.flat().setRandom(); + + tensorflow::TensorProto proto; + CompressTensorAsProto(DeltaEncode(tensor, true), &proto); + + tensorflow::Tensor result = DecompressTensorFromProto(proto); + test::ExpectTensorEqual(tensor, DeltaEncode(result, false)); +} + +} // namespace +} // namespace reverb +} // namespace deepmind diff --git a/reverb/cc/testing/BUILD b/reverb/cc/testing/BUILD new file mode 100644 index 0000000..a27ecf0 --- /dev/null +++ b/reverb/cc/testing/BUILD @@ -0,0 +1,38 @@ +load( + "//reverb/cc/platform:build_rules.bzl", + "reverb_absl_deps", + "reverb_cc_library", + "reverb_tf_deps", +) + +package(default_visibility = ["//reverb:__subpackages__"]) + +licenses(["notice"]) + +reverb_cc_library( + name = "proto_test_util", + testonly = 1, + srcs = ["proto_test_util.cc"], + hdrs = ["proto_test_util.h"], + deps = [ + "//reverb/cc:schema_cc_proto", + "//reverb/cc:tensor_compression", + "//reverb/cc/platform:logging", + ] + reverb_tf_deps(), +) + +reverb_cc_library( + name = "tensor_testutil", + testonly = 1, + hdrs = ["tensor_testutil.h"], + deps = [ + 
"//reverb/cc/platform:logging", + ] + reverb_tf_deps(), +) + +reverb_cc_library( + name = "time_testutil", + testonly = 1, + hdrs = ["time_testutil.h"], + deps = reverb_absl_deps(), +) diff --git a/reverb/cc/testing/proto_test_util.cc b/reverb/cc/testing/proto_test_util.cc new file mode 100644 index 0000000..b149d6f --- /dev/null +++ b/reverb/cc/testing/proto_test_util.cc @@ -0,0 +1,81 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "reverb/cc/testing/proto_test_util.h" + +#include + +#include "reverb/cc/platform/logging.h" +#include "reverb/cc/schema.pb.h" +#include "reverb/cc/tensor_compression.h" +#include "tensorflow/core/framework/tensor.h" + +namespace deepmind { +namespace reverb { +namespace testing { + +ChunkData MakeChunkData(uint64_t key) { + return MakeChunkData(key, MakeSequenceRange(key * 100, 0, 1)); +} + +ChunkData MakeChunkData(uint64_t key, SequenceRange range) { + ChunkData chunk; + chunk.set_chunk_key(key); + tensorflow::Tensor t(tensorflow::DT_INT32, + {range.end() - range.start() + 1, 10}); + t.flat().setConstant(1); + CompressTensorAsProto(t, chunk.add_data()); + *chunk.mutable_sequence_range() = std::move(range); + + return chunk; +} + +SequenceRange MakeSequenceRange(uint64_t episode_id, int32_t start, int32_t end) { + REVERB_CHECK_LE(start, end); + SequenceRange sequence; + sequence.set_episode_id(episode_id); + sequence.set_start(start); + sequence.set_end(end); + return sequence; +} + +KeyWithPriority MakeKeyWithPriority(uint64_t key, double priority) { + KeyWithPriority update; + update.set_key(key); + update.set_priority(priority); + return update; +} + +PrioritizedItem MakePrioritizedItem(uint64_t key, double priority, + const std::vector& chunks) { + QCHECK(!chunks.empty()); + + PrioritizedItem item; + item.set_key(key); + item.set_priority(priority); + + for (const auto& chunk : chunks) { + item.add_chunk_keys(chunk.chunk_key()); + } + + item.mutable_sequence_range()->set_length( + 1 + chunks.back().sequence_range().end() - + chunks.front().sequence_range().start()); + + return item; +} + +} // namespace testing +} // namespace reverb +} // namespace deepmind diff --git a/reverb/cc/testing/proto_test_util.h b/reverb/cc/testing/proto_test_util.h new file mode 100644 index 0000000..4542993 --- /dev/null +++ b/reverb/cc/testing/proto_test_util.h @@ -0,0 +1,129 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef REVERB_CC_TESTING_PROTO_TEST_UTIL_H_ +#define REVERB_CC_TESTING_PROTO_TEST_UTIL_H_ + +#include + +#include "google/protobuf/text_format.h" +#include "google/protobuf/util/message_differencer.h" +#include "gmock/gmock.h" +#include "reverb/cc/platform/logging.h" +#include "reverb/cc/schema.pb.h" + +namespace deepmind { +namespace reverb { +namespace testing { + +ChunkData MakeChunkData(uint64_t key); +ChunkData MakeChunkData(uint64_t key, SequenceRange range); + +SequenceRange MakeSequenceRange(uint64_t episode_id, int32_t start, int32_t end); + +KeyWithPriority MakeKeyWithPriority(uint64_t key, double priority); + +PrioritizedItem MakePrioritizedItem(uint64_t key, double priority, + const std::vector& chunks); + +// Simple implementation of a proto matcher comparing string representations. +// +// IMPORTANT: Only use this for protos whose textual representation is +// deterministic (that may not be the case for the map collection type). + +class ProtoStringMatcher { + public: + explicit ProtoStringMatcher(const std::string& expected) + : expected_proto_str_(expected) {} + explicit ProtoStringMatcher(const google::protobuf::Message& expected) + : expected_proto_str_(expected.DebugString()) {} + + template + bool MatchAndExplain(const Message& actual_proto, + ::testing::MatchResultListener* listener) const; + + void DescribeTo(::std::ostream* os) const { *os << expected_proto_str_; } + void DescribeNegationTo(::std::ostream* os) const { + *os << "not equal to expected message: " << expected_proto_str_; + } + + void SetComparePartially() { + scope_ = ::google::protobuf::util::MessageDifferencer::PARTIAL; + } + + private: + const std::string expected_proto_str_; + google::protobuf::util::MessageDifferencer::Scope scope_ = + google::protobuf::util::MessageDifferencer::FULL; +}; + +template +T CreateProto(const std::string& textual_proto) { + T proto; + REVERB_CHECK(google::protobuf::TextFormat::ParseFromString(textual_proto, &proto)); + return proto; +} + +template +bool ProtoStringMatcher::MatchAndExplain( + const Message& actual_proto, + ::testing::MatchResultListener* listener) const { + Message expected_proto = CreateProto(expected_proto_str_); + + google::protobuf::util::MessageDifferencer differencer; + std::string differences; + differencer.ReportDifferencesToString(&differences); + differencer.set_scope(scope_); + + if (!differencer.Compare(expected_proto, actual_proto)) { + *listener << "the protos are different:\n" << differences; + return false; + } + + return true; +} + +// Polymorphic matcher to compare any two protos. +inline ::testing::PolymorphicMatcher EqualsProto( + const std::string& x) { + return ::testing::MakePolymorphicMatcher(ProtoStringMatcher(x)); +} + +// Polymorphic matcher to compare any two protos. +inline ::testing::PolymorphicMatcher EqualsProto( + const google::protobuf::Message& x) { + return ::testing::MakePolymorphicMatcher(ProtoStringMatcher(x)); +} + +// Only compare the fields populated in the matcher proto. +template +inline InnerProtoMatcher Partially(InnerProtoMatcher inner_proto_matcher) { + inner_proto_matcher.mutable_impl().SetComparePartially(); + return inner_proto_matcher; +} + +// Parse input string as a protocol buffer. 
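+// Check-fails (via REVERB_CHECK) and reports the offending input if parsing
+// fails.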
+template +T ParseTextProtoOrDie(const std::string& input) { + T result; + REVERB_CHECK(google::protobuf::TextFormat::ParseFromString(input, &result)) + << "Failed to parse: " << input; + return result; +} + +} // namespace testing +} // namespace reverb +} // namespace deepmind + +#endif // REVERB_CC_TESTING_PROTO_TEST_UTIL_H_ diff --git a/reverb/cc/testing/tensor_testutil.h b/reverb/cc/testing/tensor_testutil.h new file mode 100644 index 0000000..c055d36 --- /dev/null +++ b/reverb/cc/testing/tensor_testutil.h @@ -0,0 +1,227 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef REVERB_CC_TESTING_TENSOR_TESTUTIL_H +#define REVERB_CC_TESTING_TENSOR_TESTUTIL_H + +#include + +#include "gtest/gtest.h" +#include "reverb/cc/platform/logging.h" +#include "tensorflow/core/framework/tensor.h" + +namespace deepmind { +namespace reverb { +namespace test { + +// Expects "x" and "y" are tensors of the same type, same shape, and +// identical values. +template +void ExpectTensorEqual(const tensorflow::Tensor& x, + const tensorflow::Tensor& y); + +// Expects "x" and "y" are tensors of the same type, same shape, and +// approximate equal values, each within "abs_err". +template +void ExpectTensorNear(const tensorflow::Tensor& x, const tensorflow::Tensor& y, + const T& abs_err); + +// Expects "x" and "y" are tensors of the same type (float or double), +// same shape and element-wise difference between x and y is no more +// than atol + rtol * abs(x). If atol or rtol is negative, it is replaced +// with a default tolerance value = data type's epsilon * kSlackFactor. +void ExpectClose(const tensorflow::Tensor& x, const tensorflow::Tensor& y, + double atol = -1.0, double rtol = -1.0); + +// Implementation details. + +namespace internal { + +template +struct is_floating_point_type { + static const bool value = std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same>::value || + std::is_same>::value; +}; + +template +inline void ExpectEqual(const T& a, const T& b) { + EXPECT_EQ(a, b); +} + +template <> +inline void ExpectEqual(const float& a, const float& b) { + EXPECT_FLOAT_EQ(a, b); +} + +template <> +inline void ExpectEqual(const double& a, const double& b) { + EXPECT_DOUBLE_EQ(a, b); +} + +template <> +inline void ExpectEqual(const tensorflow::complex64& a, + const tensorflow::complex64& b) { + EXPECT_FLOAT_EQ(a.real(), b.real()) << a << " vs. " << b; + EXPECT_FLOAT_EQ(a.imag(), b.imag()) << a << " vs. " << b; +} + +template <> +inline void ExpectEqual( + const tensorflow::complex128& a, const tensorflow::complex128& b) { + EXPECT_DOUBLE_EQ(a.real(), b.real()) << a << " vs. " << b; + EXPECT_DOUBLE_EQ(a.imag(), b.imag()) << a << " vs. 
" << b; +} + +template +inline void ExpectEqual(const T& a, const T& b, int index) { + EXPECT_EQ(a, b) << " at index " << index; +} + +template <> +inline void ExpectEqual(const float& a, const float& b, int index) { + EXPECT_FLOAT_EQ(a, b) << " at index " << index; +} + +template <> +inline void ExpectEqual(const double& a, const double& b, int index) { + EXPECT_DOUBLE_EQ(a, b) << " at index " << index; +} + +template <> +inline void ExpectEqual(const tensorflow::complex64& a, + const tensorflow::complex64& b, + int index) { + EXPECT_FLOAT_EQ(a.real(), b.real()) + << a << " vs. " << b << " at index " << index; + EXPECT_FLOAT_EQ(a.imag(), b.imag()) + << a << " vs. " << b << " at index " << index; +} + +template <> +inline void ExpectEqual(const tensorflow::complex128& a, + const tensorflow::complex128& b, + int index) { + EXPECT_DOUBLE_EQ(a.real(), b.real()) + << a << " vs. " << b << " at index " << index; + EXPECT_DOUBLE_EQ(a.imag(), b.imag()) + << a << " vs. " << b << " at index " << index; +} + +inline void AssertSameTypeDims(const tensorflow::Tensor& x, + const tensorflow::Tensor& y) { + ASSERT_EQ(x.dtype(), y.dtype()); + ASSERT_TRUE(x.IsSameSize(y)) + << "x.shape [" << x.shape().DebugString() << "] vs " + << "y.shape [ " << y.shape().DebugString() << "]"; +} + +template ::value> +struct Expector; + +template +struct Expector { + static void Equal(const T& a, const T& b) { ExpectEqual(a, b); } + + static void Equal(const tensorflow::Tensor& x, const tensorflow::Tensor& y) { + ASSERT_EQ(x.dtype(), tensorflow::DataTypeToEnum::v()); + AssertSameTypeDims(x, y); + const auto size = x.NumElements(); + const T* a = x.flat().data(); + const T* b = y.flat().data(); + for (int i = 0; i < size; ++i) { + ExpectEqual(a[i], b[i]); + } + } +}; + +// Partial specialization for float and double. +template +struct Expector { + static void Equal(const T& a, const T& b) { ExpectEqual(a, b); } + + static void Equal(const tensorflow::Tensor& x, const tensorflow::Tensor& y) { + ASSERT_EQ(x.dtype(), tensorflow::DataTypeToEnum::v()); + AssertSameTypeDims(x, y); + const auto size = x.NumElements(); + const T* a = x.flat().data(); + const T* b = y.flat().data(); + for (int i = 0; i < size; ++i) { + ExpectEqual(a[i], b[i]); + } + } + + static bool Near(const T& a, const T& b, const double abs_err) { + // Need a == b so that infinities are close to themselves. + return (a == b) || + (static_cast(Eigen::numext::abs(a - b)) <= abs_err); + } + + static void Near(const tensorflow::Tensor& x, const tensorflow::Tensor& y, + const double abs_err) { + ASSERT_EQ(x.dtype(), tensorflow::DataTypeToEnum::v()); + AssertSameTypeDims(x, y); + const auto size = x.NumElements(); + const T* a = x.flat().data(); + const T* b = y.flat().data(); + for (int i = 0; i < size; ++i) { + EXPECT_TRUE(Near(a[i], b[i], abs_err)) + << "a = " << a[i] << " b = " << b[i] << " index = " << i; + } + } +}; + +template +struct Helper { + // Assumes atol and rtol are nonnegative. + static bool IsClose(const T& x, const T& y, const T& atol, const T& rtol) { + // Need x == y so that infinities are close to themselves. 
+    return (x == y) ||
+           (Eigen::numext::abs(x - y) <= atol + rtol * Eigen::numext::abs(x));
+  }
+};
+
+template <typename T>
+struct Helper<std::complex<T>> {
+  static bool IsClose(const std::complex<T>& x, const std::complex<T>& y,
+                      const T& atol, const T& rtol) {
+    return Helper<T>::IsClose(x.real(), y.real(), atol, rtol) &&
+           Helper<T>::IsClose(x.imag(), y.imag(), atol, rtol);
+  }
+};
+
+}  // namespace internal
+
+template <typename T>
+void ExpectTensorEqual(const tensorflow::Tensor& x,
+                       const tensorflow::Tensor& y) {
+  internal::Expector<T>::Equal(x, y);
+}
+
+template <typename T>
+void ExpectTensorNear(const tensorflow::Tensor& x, const tensorflow::Tensor& y,
+                      const double abs_err) {
+  static_assert(internal::is_floating_point_type<T>::value,
+                "T is not a floating point type.");
+  ASSERT_GE(abs_err, 0.0) << "abs_err is negative: " << abs_err;
+  internal::Expector<T>::Near(x, y, abs_err);
+}
+
+}  // namespace test
+}  // namespace reverb
+}  // namespace deepmind
+
+#endif  // REVERB_CC_TESTING_TENSOR_TESTUTIL_H
diff --git a/reverb/cc/testing/time_testutil.h b/reverb/cc/testing/time_testutil.h
new file mode 100644
index 0000000..b34649d
--- /dev/null
+++ b/reverb/cc/testing/time_testutil.h
@@ -0,0 +1,38 @@
+// Copyright 2019 DeepMind Technologies Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef REVERB_CC_TESTING_TIME_TESTUTIL_H_
+#define REVERB_CC_TESTING_TIME_TESTUTIL_H_
+
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+
+namespace deepmind {
+namespace reverb {
+namespace test {
+
+// Repeatedly sleeps for `wait_duration` until `exit_criteria_fn` returns true
+// or `max_iteration` sleeps have been performed.
+template <typename F>
+void WaitFor(F&& exit_criteria_fn, const absl::Duration& wait_duration,
+             int max_iteration) {
+  for (int retries = 0; !exit_criteria_fn() && retries < max_iteration;
+       ++retries) {
+    absl::SleepFor(wait_duration);
+  }
+}
+
+}  // namespace test
+}  // namespace reverb
+}  // namespace deepmind
+
+#endif  // REVERB_CC_TESTING_TIME_TESTUTIL_H_
diff --git a/reverb/checkpointer.py b/reverb/checkpointer.py
new file mode 100644
index 0000000..b0fed08
--- /dev/null
+++ b/reverb/checkpointer.py
@@ -0,0 +1,60 @@
+# Lint as: python3
+# Copyright 2019 DeepMind Technologies Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
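+
+# Illustrative note: a checkpointer constructed here, e.g.
+# `DefaultCheckpointer('/tmp/reverb_checkpoints')` (path chosen arbitrarily),
+# is handed to the Reverb server, which is assumed to call
+# `internal_checkpointer()` to obtain the underlying C++ object.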
+ +"""Python wrappers for constructing Checkpointers to pass to ReverbServer.""" + +import abc +import tempfile + +import numpy # pylint: disable=unused-import +from reverb import pybind + + +class CheckpointerBase(metaclass=abc.ABCMeta): + """Base class for Python wrappers of the Checkpointer.""" + + @abc.abstractmethod + def internal_checkpointer(self) -> pybind.CheckpointerInterface: + """Creates the actual Checkpointer-object used by the C++ layer.""" + + +class DefaultCheckpointer(CheckpointerBase): + """Base class for storing checkpoints to as recordIO files..""" + + def __init__(self, path: str, group: str = ''): + """Constructor of DefaultCheckpointer. + + Args: + path: Root directory to store checkpoints in. + group: MDB group to set as "group" of checkpoint directory. If empty + (default) then no group is set. + """ + self.path = path + self.group = group + + def internal_checkpointer(self) -> pybind.CheckpointerInterface: + """Creates the actual Checkpointer-object used by the C++ layer.""" + return pybind.create_default_checkpointer(self.path, self.group) + + +class TempDirCheckpointer(DefaultCheckpointer): + """Stores and loads checkpoints from a temporary directory.""" + + def __init__(self): + super().__init__(tempfile.mkdtemp()) + + +def default_checkpointer() -> CheckpointerBase: + return TempDirCheckpointer() diff --git a/reverb/checkpointer_test.py b/reverb/checkpointer_test.py new file mode 100644 index 0000000..2f031ac --- /dev/null +++ b/reverb/checkpointer_test.py @@ -0,0 +1,32 @@ +# Lint as: python3 +# Copyright 2019 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Reverb checkpointer.""" + +from absl.testing import absltest +from reverb import checkpointer as checkpointer_lib +from reverb import pybind + + +class TempDirCheckpointer(absltest.TestCase): + + def test_constructs_internal_checkpointer(self): + checkpointer = checkpointer_lib.TempDirCheckpointer() + self.assertIsInstance(checkpointer.internal_checkpointer(), + pybind.CheckpointerInterface) + + +if __name__ == '__main__': + absltest.main() diff --git a/reverb/client.py b/reverb/client.py new file mode 100644 index 0000000..4d98766 --- /dev/null +++ b/reverb/client.py @@ -0,0 +1,436 @@ +# Lint as: python3 +# Copyright 2019 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Replay client Python interface. + +The ReverbClient is used primarily for feeding the ReplayService with new data. 
+The preferred method is to use the `Writer` as it allows for the most +flexibility. + +Consider an example where we wish to generate all possible connected sequences +of length 5 based on a single actor. + +```python + + client = Client(...) + env = .... # Construct the environment + policy = .... # Construct the agent's policy + + for episode in range(NUM_EPISODES): + timestep = env.reset() + step_within_episode = 0 + with client.writer(max_sequence_length=5) as writer: + while not timestep.last(): + action = policy(timestep) + new_timestep = env.step(action) + + # Add the observation of the state the agent when doing action, the + # action it took and the reward it received. + writer.append_timestep( + (timestep.observation, action, new_timestep.reward)) + + timestep = new_timestep + step_within_episode += 1 + + if step_within_episode >= 5: + writer.create_prioritized_item( + table='my_distribution', + num_timesteps=5, + priority=calc_priority(...)) + + # Add an item for the sequence terminating in the final stage. + if steps_within_episode >= 5: + writer.create_prioritized_item( + table='my_distribution', + num_timesteps=5, + priority=calc_priority(...)) + +``` + +If you do not want overlapping sequences but instead want to insert complete +trajectories then the `insert`-method should be used. + +```python + + client = Client(...) + + trajectory_generator = ... + for trajectory in trajectory_generator: + client.insert(trajectory, {'my_distribution': calc_priority(trajectory)}) + +``` + +""" + +from typing import Dict, List, Optional + +from absl import logging +from reverb import errors +from reverb import pybind +from reverb import replay_sample +from reverb import reverb_types +import tree + +from reverb.cc import schema_pb2 +from tensorflow.python.saved_model import nested_structure_coder # pylint: disable=g-direct-tensorflow-import + + +class Client: + """Client for interacting with a Reverb ReplayService from Python. + + Note: This client should primarily be used when inserting data or prototyping + at very small scale. + Whenever possible, prefer to use TFClient (see ./tf_client.py). + """ + + def __init__(self, server_address: str, client: pybind.ReplayClient = None): + """Constructor of ReverbClient. + + Args: + server_address: Address to the Reverb ReplayService. + client: Optional pre-existing ReplayClient. For internal use only. + """ + self._server_address = server_address + self._client = client if client else pybind.ReplayClient(server_address) + + def __reduce__(self): + return self.__class__, (self._server_address,) + + @property + def server_address(self): + return self._server_address + + def insert(self, data, priorities: Dict[str, float]): + """Inserts a "blob" (e.g. trajectory) into one or more priority tables. + + Note: The data is only stored once even if samples are inserted into + multiple priority tables. + + Note: When possible, prefer to use the in graph version (see ./tf_client.py) + to avoid stepping through Python. + + Args: + data: A (possible nested) structure to insert. + priorities: Mapping from table name to priority value. + + Raises: + ValueError: If priorities is empty. 
+ """ + if not priorities: + raise ValueError('priorities must contain at least one item') + + with self.writer(max_sequence_length=1) as writer: + writer.append_timestep(data) + for table, priority in priorities.items(): + writer.create_prioritized_item( + table=table, num_timesteps=1, priority=priority) + + def writer( + self, + max_sequence_length: int, + delta_encoded: bool = False, + chunk_length: int = None, + ): + """Constructs a writer with a `max_sequence_length` buffer. + + The writer can be used to stream data of any length. `max_sequence_length` + controls the size of the internal buffer and ensures that prioritized items + can be created of any length <= `max_sequence_length`. + + The writer is stateful and must be closed after the write has finished. The + easiest way to manage this is to use it as a contextmanager: + + ```python + + with client.writer(10) as writer: + ... # Write data of any length. + + ``` + + If not used as a contextmanager then `.close()` must be called explicitly. + + Args: + max_sequence_length: Size of the internal buffer controlling the upper + limit of the number of timesteps which can be referenced in a single + prioritized item. Note that this is NOT a limit of how many timesteps or + items that can be inserted. + delta_encoded: If `True` (False by default) tensors are delta encoded + against the first item within their respective batch before compressed. + This can significantly reduce RAM at the cost of a small amount of CPU + for highly correlated data (e.g frames of video observations). + chunk_length: Number of timesteps grouped together before delta encoding + and compression. Set by default to `min(10, max_sequence_length)` but + can be overridden to achieve better compression rates when using longer + sequences with a small overlap. + + Returns: + A `Writer` with `max_sequence_length`. + + Raises: + ValueError: If max_sequence_length < 1. + ValueError: if chunk_length > max_sequence_length. + ValueError: if chunk_length < 1. + """ + if max_sequence_length < 1: + raise ValueError('max_sequence_length (%d) must be a positive integer' % + max_sequence_length) + + if chunk_length is None: + chunk_length = min(10, max_sequence_length) + + if chunk_length < 1 or chunk_length > max_sequence_length: + raise ValueError( + 'chunk_length (%d) must be a positive integer le to max_sequence_length (%d)' + % (chunk_length, max_sequence_length)) + + return Writer( + self._client.NewWriter(chunk_length, max_sequence_length, + delta_encoded)) + + def sample(self, table: str, num_samples=1): + """Samples `num_samples` items from table `table` of the Server. + + NOTE: This method should NOT be used for real training. TFClient (see + tf_client.py) has far superior performance and should always be preferred. + + Note: If data was written using `insert` (e.g when inserting complete + trajectories) then the returned "sequence" will be a list of length 1 + containing the trajectory as a single item. + + If `num_samples` is greater than the number of items in `table`, (or + a rate limiter is used to control sampling), then the returned generator + will block when an item past the sampling limit is requested. It will + unblock when sufficient additional items have been added to `table`. + + Example: + ```python + server = Server(..., tables=[queue("queue", ...)]) + client = Client(...) + # Don't insert anything into "queue" + generator = client.sample("queue") + generator.next() # Blocks until another thread/process writes to queue. 
+ ``` + + Args: + table: Name of the priority table to sample from. + num_samples: (default to 1) The number of samples to fetch. + + Yields: + Lists of timesteps (lists of instances of `ReplaySample`). + If data was inserted into the table via `insert`, then each element + of the generator is a length 1 list containing a `ReplaySample`. + If data was inserted via a writer, then each element is a list whose + length is the sampled trajectory's length. + """ + sampler = self._client.NewSampler(table, num_samples, 1) + + for _ in range(num_samples): + sequence = [] + last = False + + while not last: + step, last = sampler.GetNextTimestep() + key = int(step[0]) + probability = float(step[1]) + table_size = int(step[2]) + data = step[3:] + sequence.append( + replay_sample.ReplaySample( + info=replay_sample.SampleInfo(key, probability, table_size), + data=data)) + + yield sequence + + def mutate_priorities(self, + table: str, + updates: Dict[int, float] = None, + deletes: List[int] = None): + """Updates and/or deletes existing items in a priority table. + + NOTE: Whenever possible, prefer to use `TFClient.update_priorities` + instead to avoid leaving the graph. + + Actions are executed in the same order as the arguments are specified. + + Args: + table: Name of the priority table to update. + updates: Mapping from priority item key to new priority value. If a key + cannot be found then it is ignored. + deletes: List of keys for priority items to delete. If a key cannot be + found then it is ignored. + """ + if updates is None: + updates = {} + if deletes is None: + deletes = [] + self._client.MutatePriorities(table, list(updates.items()), deletes) + + def reset(self, table: str): + """Clears all items of the table and resets its RateLimiter. + + Args: + table: Name of the priority table to reset. + """ + self._client.Reset(table) + + def server_info( + self, + timeout: Optional[int] = None) -> Dict[str, reverb_types.TableInfo]: + """Get table metadata information. + + Args: + timeout: Timeout in seconds to wait for server response. By default no + deadline is set and call will block indefinetely until server responds. + + Returns: + A dictionary mapping table names to their associated `TableInfo` + instances, which contain metadata about the table. + + Raises: + errors.TimeoutError: If timeout provided and exceeded. + """ + try: + info_proto_strings = self._client.ServerInfo(timeout or 0) + except RuntimeError as e: + if 'Deadline Exceeded' in str(e) and timeout is not None: + raise errors.TimeoutError( + f'ServerInfo call did not complete within provided timeout of ' + f'{timeout}s') + raise + + table_info = {} + for proto_string in info_proto_strings: + proto = schema_pb2.TableInfo.FromString(proto_string) + if proto.HasField('signature'): + signature = nested_structure_coder.StructureCoder().decode_proto( + proto.signature) + else: + signature = None + info_dict = dict((descr.name, getattr(proto, descr.name)) + for descr in proto.DESCRIPTOR.fields) + info_dict['signature'] = signature + name = str(info_dict['name']) + table_info[name] = reverb_types.TableInfo(**info_dict) + return table_info + + def checkpoint(self) -> str: + """Triggers a checkpoint to be created. + + Returns: + Absolute path to the saved checkpoint. + """ + return self._client.Checkpoint() + + +class Writer: + """Writer is used for streaming data of arbitrary length. + + See ReverbClient.writer for documentation. 
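+
+  A minimal usage sketch (the table name 'my_table' and the priority value are
+  assumptions; `client` is any `Client` connected to the server):
+
+  ```python
+
+  with client.writer(max_sequence_length=3) as writer:
+    for step in range(9):
+      writer.append_timestep(step)
+      if step >= 2:
+        # Each item references the 3 most recent timesteps.
+        writer.create_prioritized_item(
+            'my_table', num_timesteps=3, priority=1.0)
+
+  ```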
+  """
+
+  def __init__(self, internal_writer):
+    """Constructor for Writer (must only be called by ReverbClient.writer)."""
+    self._writer = internal_writer
+    self._closed = False
+
+  def __enter__(self):
+    if self._closed:
+      raise ValueError('Cannot reuse already closed Writer')
+    return self
+
+  def __exit__(self, *_):
+    self.close()
+
+  def __del__(self):
+    if not self._closed:
+      logging.warning(
+          'Writer-object deleted without calling .close explicitly.')
+
+  def append_timestep(self, timestep):
+    """Appends a timestep to the internal buffer.
+
+    NOTE: Calling this method alone does not result in anything being inserted
+    into the replay. To trigger timestep insertion, `create_prioritized_item`
+    must be called so that the resulting sequence includes the timestep.
+
+    Consider the following example:
+
+    ```python
+
+    A, B, C = ...
+    client = Client(...)
+
+    with client.writer(max_sequence_length=2) as writer:
+      writer.append_timestep(A)  # A is added to the internal buffer.
+      writer.append_timestep(B)  # B is added to the internal buffer.
+
+      # The buffer is now full so when this is called C is added and A is
+      # removed from the internal buffer. Since A was never referenced by
+      # a prioritized item it was never sent to the server.
+      writer.append_timestep(C)
+
+      # A sequence of length 1 is created referencing only C and thus C is
+      # sent to the server.
+      writer.create_prioritized_item('my_table', 1, 5.0)
+
+    # The writer is now closed. B was never referenced by a prioritized item
+    # and thus never sent to the server.
+
+    ```
+
+    Args:
+      timestep: The (possibly nested) structure to make available for new
+        prioritized items to reference.
+    """
+    self._writer.AppendTimestep(tree.flatten(timestep))
+
+  def create_prioritized_item(self, table: str, num_timesteps: int,
+                              priority: float):
+    """Creates a prioritized item and sends it to the ReplayService.
+
+    This method is what effectively makes data available for sampling. See the
+    docstring of `append_timestep` for an illustrative example of the behavior.
+
+    Args:
+      table: Name of the priority table to insert the item into.
+      num_timesteps: The number of most recently added timesteps that the new
+        item should reference.
+      priority: The priority used for determining the sample probability of the
+        new item.
+
+    Raises:
+      ValueError: If num_timesteps is < 1.
+      StatusNotOk: If num_timesteps is greater than the number of timesteps
+        currently available in the buffer.
+    """
+    if num_timesteps < 1:
+      raise ValueError(
+          'num_timesteps (%d) must be a positive integer' % num_timesteps)
+    # TODO(b/154930410): Catch the StatusNotOk and raise a ValueError instead.
+    self._writer.AddPriority(table, num_timesteps, priority)
+
+  def close(self):
+    """Closes the stream to the ReplayService.
+
+    The method is automatically called when exiting the contextmanager scope.
+
+    Note: The Writer-object must be abandoned after this method is called.
+
+    Raises:
+      ValueError: If `close` has already been called on this Writer.
+    """
+    if self._closed:
+      raise ValueError('close() has already been called on Writer.')
+    self._closed = True
+    self._writer.Close()
diff --git a/reverb/client_test.py b/reverb/client_test.py
new file mode 100644
index 0000000..e9b1446
--- /dev/null
+++ b/reverb/client_test.py
@@ -0,0 +1,228 @@
+# Lint as: python3
+# Copyright 2019 DeepMind Technologies Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for python client.""" + +import collections +import multiprocessing.dummy as multithreading +import pickle + +from absl.testing import absltest +from reverb import client +from reverb import distributions +from reverb import errors +from reverb import rate_limiters +from reverb import server +import tensorflow.compat.v1 as tf + +TABLE_NAME = 'table' + + +class ReverbClientTest(absltest.TestCase): + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.server = server.Server( + priority_tables=[ + server.PriorityTable( + name=TABLE_NAME, + sampler=distributions.Prioritized(1), + remover=distributions.Fifo(), + max_size=1000, + rate_limiter=rate_limiters.MinSize(3), + signature=tf.TensorSpec(dtype=tf.int64, shape=()), + ), + ], + port=None) + cls.client = client.Client(f'localhost:{cls.server.port}') + + def tearDown(self): + self.client.reset(TABLE_NAME) + super().tearDown() + + @classmethod + def tearDownClass(cls): + cls.server.stop() + super().tearDownClass() + + def _get_sample_frequency(self, n=10000): + keys = [sample[0].info.key for sample in self.client.sample(TABLE_NAME, n)] + counter = collections.Counter(keys) + return [count / n for _, count in counter.most_common()] + + def test_sample_sets_table_size(self): + for i in range(1, 11): + self.client.insert(i, {TABLE_NAME: 1.0}) + if i >= 3: + sample = next(self.client.sample(TABLE_NAME, 1))[0] + self.assertEqual(sample.info.table_size, i) + + def test_sample_sets_probability(self): + for i in range(1, 11): + self.client.insert(i, {TABLE_NAME: 1.0}) + if i >= 3: + sample = next(self.client.sample(TABLE_NAME, 1))[0] + self.assertAlmostEqual(sample.info.probability, 1.0 / i, 0.01) + + def test_insert_raises_if_priorities_empty(self): + with self.assertRaises(ValueError): + self.client.insert([1], {}) + + def test_insert(self): + self.client.insert(1, {TABLE_NAME: 1.0}) # This should be sampled often. + self.client.insert(2, {TABLE_NAME: 0.1}) # This should be sampled rarely. + self.client.insert(3, {TABLE_NAME: 0.0}) # This should never be sampled. + + freqs = self._get_sample_frequency() + + self.assertLen(freqs, 2) + self.assertAlmostEqual(freqs[0], 0.9, delta=0.05) + self.assertAlmostEqual(freqs[1], 0.1, delta=0.05) + + def test_writer_raises_if_max_sequence_length_lt_1(self): + with self.assertRaises(ValueError): + self.client.writer(0) + + def test_writer_raises_if_chunk_length_lt_1(self): + self.client.writer(2, chunk_length=1) # Should be fine. + + for chunk_length in [0, -1]: + with self.assertRaises(ValueError): + self.client.writer(2, chunk_length=chunk_length) + + def test_writer_raises_if_chunk_length_gt_max_sequence_length(self): + self.client.writer(2, chunk_length=1) # lt should be fine. + self.client.writer(2, chunk_length=2) # eq should be fine. 
+ + with self.assertRaises(ValueError): + self.client.writer(2, chunk_length=3) + + def test_writer(self): + with self.client.writer(2) as writer: + writer.append_timestep([0]) + writer.create_prioritized_item(TABLE_NAME, 1, 1.0) + writer.append_timestep([1]) + writer.create_prioritized_item(TABLE_NAME, 2, 1.0) + writer.append_timestep([2]) + writer.create_prioritized_item(TABLE_NAME, 1, 1.0) + + freqs = self._get_sample_frequency() + self.assertLen(freqs, 3) + for freq in freqs: + self.assertAlmostEqual(freq, 0.33, delta=0.05) + + def test_mutate_priorities_update(self): + self.client.insert([0], {TABLE_NAME: 1.0}) + self.client.insert([0], {TABLE_NAME: 1.0}) + self.client.insert([0], {TABLE_NAME: 1.0}) + + before = self._get_sample_frequency() + self.assertLen(before, 3) + for freq in before: + self.assertAlmostEqual(freq, 0.33, delta=0.05) + + key = next(self.client.sample(TABLE_NAME, 1))[0].info.key + self.client.mutate_priorities(TABLE_NAME, updates={key: 0.5}) + + after = self._get_sample_frequency() + self.assertLen(after, 3) + self.assertAlmostEqual(after[0], 0.4, delta=0.05) + self.assertAlmostEqual(after[1], 0.4, delta=0.05) + self.assertAlmostEqual(after[2], 0.2, delta=0.05) + + def test_mutate_priorities_delete(self): + self.client.insert([0], {TABLE_NAME: 1.0}) + self.client.insert([0], {TABLE_NAME: 1.0}) + self.client.insert([0], {TABLE_NAME: 1.0}) + self.client.insert([0], {TABLE_NAME: 1.0}) + + before = self._get_sample_frequency() + self.assertLen(before, 4) + + key = next(self.client.sample(TABLE_NAME, 1))[0].info.key + self.client.mutate_priorities(TABLE_NAME, deletes=[key]) + + after = self._get_sample_frequency() + self.assertLen(after, 3) + + def test_reset(self): + self.client.insert([0], {TABLE_NAME: 1.0}) + self.client.insert([0], {TABLE_NAME: 1.0}) + self.client.insert([0], {TABLE_NAME: 1.0}) + + keys_before = set( + sample[0].info.key for sample in self.client.sample(TABLE_NAME, 1000)) + self.assertLen(keys_before, 3) + + self.client.reset(TABLE_NAME) + + self.client.insert([0], {TABLE_NAME: 1.0}) + self.client.insert([0], {TABLE_NAME: 1.0}) + self.client.insert([0], {TABLE_NAME: 1.0}) + + keys_after = set( + sample[0].info.key for sample in self.client.sample(TABLE_NAME, 1000)) + self.assertLen(keys_after, 3) + + self.assertTrue(keys_after.isdisjoint(keys_before)) + + def test_server_info(self): + self.client.insert([0], {TABLE_NAME: 1.0}) + self.client.insert([0], {TABLE_NAME: 1.0}) + self.client.insert([0], {TABLE_NAME: 1.0}) + server_info = self.client.server_info() + self.assertLen(server_info, 1) + self.assertIn(TABLE_NAME, server_info) + info = server_info[TABLE_NAME] + self.assertEqual(info.current_size, 3) + self.assertEqual(info.max_size, 1000) + self.assertEqual(info.sampler_options.prioritized.priority_exponent, 1) + self.assertTrue(info.remover_options.fifo) + self.assertEqual(info.signature, tf.TensorSpec(dtype=tf.int64, shape=())) + + def test_server_info_timeout(self): + # Setup a client that doesn't actually connect to anything. 
+ dummy_client = client.Client(f'localhost:{self.server.port + 1}') + with self.assertRaises( + errors.TimeoutError, + msg='ServerInfo call did not complete within provided timeout of 1s'): + dummy_client.server_info(timeout=1) + + def test_pickle(self): + loaded_client = pickle.loads(pickle.dumps(self.client)) + self.assertEqual(loaded_client._server_address, self.client._server_address) + loaded_client.insert([0], {TABLE_NAME: 1.0}) + + def test_multithreaded_writer(self): + # Ensure that we don't have any errors caused by multithreaded use of + # writers or clients. + pool = multithreading.Pool(64) + def _write(i): + with self.client.writer(1) as writer: + writer.append_timestep([i]) + writer.create_prioritized_item(TABLE_NAME, 1, 1.0) + + for _ in range(5): + pool.map(_write, list(range(256))) + + info = self.client.server_info()[TABLE_NAME] + self.assertEqual(info.current_size, 1000) + pool.close() + pool.join() + + +if __name__ == '__main__': + absltest.main() diff --git a/reverb/distributions.py b/reverb/distributions.py new file mode 100644 index 0000000..ce9e293 --- /dev/null +++ b/reverb/distributions.py @@ -0,0 +1,26 @@ +# Copyright 2019 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Sampling and removing distributions.""" + +import functools + +from reverb import pybind + +Fifo = pybind.FifoDistribution +Lifo = pybind.LifoDistribution +MaxHeap = functools.partial(pybind.HeapDistribution, False) # pylint: disable=invalid-name +MinHeap = functools.partial(pybind.HeapDistribution, True) # pylint: disable=invalid-name +Prioritized = pybind.PrioritizedDistribution +Uniform = pybind.UniformDistribution diff --git a/reverb/errors.py b/reverb/errors.py new file mode 100644 index 0000000..760427a --- /dev/null +++ b/reverb/errors.py @@ -0,0 +1,25 @@ +# Copyright 2019 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Error classes for Reverb.""" + + +class ReverbError(Exception): + """Base class for Reverb errors.""" + pass + + +class TimeoutError(ReverbError): + """A call to the server timed out.""" + pass diff --git a/reverb/pybind.cc b/reverb/pybind.cc new file mode 100644 index 0000000..e837312 --- /dev/null +++ b/reverb/pybind.cc @@ -0,0 +1,670 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "numpy/arrayobject.h" +#include "absl/container/inlined_vector.h" +#include "absl/time/time.h" +#include "absl/types/optional.h" +#include "pybind11/numpy.h" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" +#include "reverb/cc/checkpointing/interface.h" +#include "reverb/cc/distributions/fifo.h" +#include "reverb/cc/distributions/heap.h" +#include "reverb/cc/distributions/interface.h" +#include "reverb/cc/distributions/lifo.h" +#include "reverb/cc/distributions/prioritized.h" +#include "reverb/cc/distributions/uniform.h" +#include "reverb/cc/platform/checkpointing.h" +#include "reverb/cc/priority_table.h" +#include "reverb/cc/rate_limiter.h" +#include "reverb/cc/replay_client.h" +#include "reverb/cc/replay_sampler.h" +#include "reverb/cc/replay_writer.h" +#include "reverb/cc/reverb_server.h" +#include "reverb/cc/table_extensions/interface.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" + +namespace { + +using ::tensorflow::error::Code; + +struct PyDecrefDeleter { + void operator()(PyObject *p) const { Py_DECREF(p); } +}; +using Safe_PyObjectPtr = std::unique_ptr; +Safe_PyObjectPtr make_safe(PyObject *o) { return Safe_PyObjectPtr(o); } + +// Converts non OK statuses to Python exceptions and throws. Does nothing for +// OK statuses. +inline void MaybeRaiseFromStatus(const tensorflow::Status &status) { + if (status.ok()) return; + + // TODO(b/152982733): Add tests that validates that casting behaviour is + // aligned with what tensorflow does. + switch (status.code()) { +#define CODE_TO_PY_EXC(CODE, PY_EXC) \ + case CODE: \ + PyErr_SetString(PY_EXC, status.error_message().c_str()); \ + break; + + CODE_TO_PY_EXC(Code::INVALID_ARGUMENT, PyExc_ValueError) + CODE_TO_PY_EXC(Code::RESOURCE_EXHAUSTED, PyExc_IndexError) + CODE_TO_PY_EXC(Code::UNIMPLEMENTED, PyExc_NotImplementedError) + CODE_TO_PY_EXC(Code::INTERNAL, PyExc_RuntimeError) + + // TODO(b/154927554): Map more status codes to Python exceptions. 
+ +#undef CODE_TO_PY_EXC + + default: + PyErr_SetString(PyExc_RuntimeError, status.error_message().c_str()); + } + + throw pybind11::error_already_set(); +} + +char const *NumpyTypeName(int numpy_type) { + switch (numpy_type) { +#define TYPE_CASE(s) \ + case s: \ + return #s + + TYPE_CASE(NPY_BOOL); + TYPE_CASE(NPY_BYTE); + TYPE_CASE(NPY_UBYTE); + TYPE_CASE(NPY_SHORT); + TYPE_CASE(NPY_USHORT); + TYPE_CASE(NPY_INT); + TYPE_CASE(NPY_UINT); + TYPE_CASE(NPY_LONG); + TYPE_CASE(NPY_ULONG); + TYPE_CASE(NPY_LONGLONG); + TYPE_CASE(NPY_ULONGLONG); + TYPE_CASE(NPY_FLOAT); + TYPE_CASE(NPY_DOUBLE); + TYPE_CASE(NPY_LONGDOUBLE); + TYPE_CASE(NPY_CFLOAT); + TYPE_CASE(NPY_CDOUBLE); + TYPE_CASE(NPY_CLONGDOUBLE); + TYPE_CASE(NPY_OBJECT); + TYPE_CASE(NPY_STRING); + TYPE_CASE(NPY_UNICODE); + TYPE_CASE(NPY_VOID); + TYPE_CASE(NPY_DATETIME); + TYPE_CASE(NPY_TIMEDELTA); + TYPE_CASE(NPY_HALF); + TYPE_CASE(NPY_NTYPES); + TYPE_CASE(NPY_NOTYPE); + TYPE_CASE(NPY_USERDEF); + + default: + return "not a numpy type"; + } +} + +void ImportNumpy() { import_array1(); } + +tensorflow::Status PyObjectToString(PyObject *obj, const char **ptr, + Py_ssize_t *len, PyObject **ptr_owner) { + *ptr_owner = nullptr; + if (PyBytes_Check(obj)) { + char *buf; + if (PyBytes_AsStringAndSize(obj, &buf, len) != 0) { + return tensorflow::errors::Internal("Unable to get element as bytes."); + } + *ptr = buf; + } else if (PyUnicode_Check(obj)) { + *ptr = PyUnicode_AsUTF8AndSize(obj, len); + if (*ptr == nullptr) { + return tensorflow::errors::Internal("Unable to convert element to UTF-8"); + } + } else { + return tensorflow::errors::Internal("Unsupported object type ", + obj->ob_type->tp_name); + } + + return tensorflow::Status::OK(); +} + +// Iterate over the string array 'array', extract the ptr and len of each string +// element and call f(ptr, len). 
+template +tensorflow::Status PyBytesArrayMap(PyArrayObject *array, F f) { + auto iter = make_safe(PyArray_IterNew(reinterpret_cast(array))); + + while (PyArray_ITER_NOTDONE(iter.get())) { + PyObject *item = PyArray_GETITEM( + array, static_cast(PyArray_ITER_DATA(iter.get()))); + if (!item) { + return tensorflow::errors::Internal( + "Unable to get element from the feed - no item."); + } + Py_ssize_t len; + const char *ptr; + PyObject *ptr_owner = nullptr; + TF_RETURN_IF_ERROR(PyObjectToString(item, &ptr, &len, &ptr_owner)); + f(ptr, len); + Py_XDECREF(ptr_owner); + PyArray_ITER_NEXT(iter.get()); + } + return tensorflow::Status::OK(); +} + +tensorflow::Status StringTensorToPyArray(const tensorflow::Tensor &tensor, + PyArrayObject *dst) { + DCHECK_EQ(tensor.dtype(), tensorflow::DT_STRING); + + auto iter = make_safe(PyArray_IterNew(reinterpret_cast(dst))); + + const auto &flat_data = tensor.flat().data(); + for (int i = 0; i < tensor.NumElements(); i++) { + const auto &value = flat_data[i]; + auto py_string = + make_safe(PyBytes_FromStringAndSize(value.c_str(), value.size())); + if (py_string == nullptr) { + return tensorflow::errors::Internal( + "failed to create a python byte array when converting element #", i, + " of a TF_STRING tensor to a numpy ndarray"); + } + + if (PyArray_SETITEM(dst, PyArray_ITER_DATA(iter.get()), py_string.get()) != + 0) { + return tensorflow::errors::Internal("Error settings element #", i, + " in the numpy ndarray"); + } + + PyArray_ITER_NEXT(iter.get()); + } + + return tensorflow::Status::OK(); +} + +tensorflow::Status GetPyDescrFromTensor(const tensorflow::Tensor &tensor, + PyArray_Descr **out_descr) { + switch (tensor.dtype()) { +#define TF_TO_PY_ARRAY_TYPE_CASE(TF_DTYPE, PY_ARRAY_TYPE) \ + case TF_DTYPE: \ + *out_descr = PyArray_DescrFromType(PY_ARRAY_TYPE); \ + break; + + TF_TO_PY_ARRAY_TYPE_CASE(tensorflow::DT_HALF, NPY_FLOAT16) + TF_TO_PY_ARRAY_TYPE_CASE(tensorflow::DT_FLOAT, NPY_FLOAT32) + TF_TO_PY_ARRAY_TYPE_CASE(tensorflow::DT_DOUBLE, NPY_FLOAT64) + TF_TO_PY_ARRAY_TYPE_CASE(tensorflow::DT_INT32, NPY_INT32) + TF_TO_PY_ARRAY_TYPE_CASE(tensorflow::DT_UINT8, NPY_UINT8) + TF_TO_PY_ARRAY_TYPE_CASE(tensorflow::DT_UINT16, NPY_UINT16) + TF_TO_PY_ARRAY_TYPE_CASE(tensorflow::DT_UINT32, NPY_UINT32) + TF_TO_PY_ARRAY_TYPE_CASE(tensorflow::DT_INT8, NPY_INT8) + TF_TO_PY_ARRAY_TYPE_CASE(tensorflow::DT_INT16, NPY_INT16) + TF_TO_PY_ARRAY_TYPE_CASE(tensorflow::DT_BOOL, NPY_BOOL) + TF_TO_PY_ARRAY_TYPE_CASE(tensorflow::DT_COMPLEX64, NPY_COMPLEX64) + TF_TO_PY_ARRAY_TYPE_CASE(tensorflow::DT_COMPLEX128, NPY_COMPLEX128) + TF_TO_PY_ARRAY_TYPE_CASE(tensorflow::DT_STRING, NPY_OBJECT) + TF_TO_PY_ARRAY_TYPE_CASE(tensorflow::DT_UINT64, NPY_UINT64) + TF_TO_PY_ARRAY_TYPE_CASE(tensorflow::DT_INT64, NPY_INT64) + +#undef TF_DTYPE_TO_PY_ARRAY_TYPE_CASE + + default: + return tensorflow::errors::Internal( + "Unsupported tf type: ", tensorflow::DataType_Name(tensor.dtype())); + } + + return tensorflow::Status::OK(); +} + +tensorflow::Status GetTensorDtypeFromPyArray( + PyArrayObject *array, tensorflow::DataType *out_tf_datatype) { + int pyarray_type = PyArray_TYPE(array); + switch (pyarray_type) { +#define NP_TO_TF_DTYPE_CASE(NP_DTYPE, TF_DTYPE) \ + case NP_DTYPE: \ + *out_tf_datatype = TF_DTYPE; \ + break; + + NP_TO_TF_DTYPE_CASE(NPY_FLOAT16, tensorflow::DT_HALF) + NP_TO_TF_DTYPE_CASE(NPY_FLOAT32, tensorflow::DT_FLOAT) + NP_TO_TF_DTYPE_CASE(NPY_FLOAT64, tensorflow::DT_DOUBLE) + + NP_TO_TF_DTYPE_CASE(NPY_INT8, tensorflow::DT_INT8) + NP_TO_TF_DTYPE_CASE(NPY_INT16, tensorflow::DT_INT16) + 
NP_TO_TF_DTYPE_CASE(NPY_INT32, tensorflow::DT_INT32) + NP_TO_TF_DTYPE_CASE(NPY_LONGLONG, tensorflow::DT_INT64) + NP_TO_TF_DTYPE_CASE(NPY_INT64, tensorflow::DT_INT64) + + NP_TO_TF_DTYPE_CASE(NPY_UINT8, tensorflow::DT_UINT8) + NP_TO_TF_DTYPE_CASE(NPY_UINT16, tensorflow::DT_UINT16) + NP_TO_TF_DTYPE_CASE(NPY_UINT32, tensorflow::DT_UINT32) + NP_TO_TF_DTYPE_CASE(NPY_ULONGLONG, tensorflow::DT_UINT64) + NP_TO_TF_DTYPE_CASE(NPY_UINT64, tensorflow::DT_UINT64) + + NP_TO_TF_DTYPE_CASE(NPY_BOOL, tensorflow::DT_BOOL) + + NP_TO_TF_DTYPE_CASE(NPY_COMPLEX64, tensorflow::DT_COMPLEX64) + NP_TO_TF_DTYPE_CASE(NPY_COMPLEX128, tensorflow::DT_COMPLEX128) + + // TODO(b/154925823): String types are not supported + NP_TO_TF_DTYPE_CASE(NPY_OBJECT, tensorflow::DT_STRING) + NP_TO_TF_DTYPE_CASE(NPY_STRING, tensorflow::DT_STRING) + NP_TO_TF_DTYPE_CASE(NPY_UNICODE, tensorflow::DT_STRING) + +#undef NP_TO_TF_DTYPE_CASE + + case NPY_VOID: + // TODO(b/154925774): Support struct and quantized types. + return tensorflow::errors::Unimplemented( + "Custom structs and quantized types are not supported"); + default: + // TODO(b/154926401): Add support for bfloat16. + // The bfloat16 type is defined in the internals of tf. + if (pyarray_type == -1) { + return tensorflow::errors::Unimplemented( + "bfloat16 types are not yet supported"); + } + + return tensorflow::errors::Internal("Unsupported numpy type: ", + NumpyTypeName(pyarray_type)); + } + return tensorflow::Status::OK(); +} + +inline tensorflow::Status VerifyDtypeIsSupported( + const tensorflow::DataType &dtype) { + if (!tensorflow::DataTypeCanUseMemcpy(dtype) && + dtype != tensorflow::DT_STRING) { + return tensorflow::errors::Unimplemented( + "ndarrays that maps to tensors with dtype ", + tensorflow::DataType_Name(dtype), " are not yet supported"); + } + return tensorflow::Status::OK(); +} + +tensorflow::Status NdArrayToTensor(PyObject *ndarray, + tensorflow::Tensor *out_tensor) { + DCHECK(out_tensor != nullptr); + auto array_safe = make_safe(PyArray_FromAny( + /*op=*/ndarray, + /*dtype=*/nullptr, + /*min_depth=*/0, + /*max_depth=*/0, + /*requirements=*/NPY_ARRAY_CARRAY_RO, + /*context=*/nullptr)); + if (!array_safe) { + return tensorflow::errors::InvalidArgument( + "Provided input could not be interpreted as an ndarray"); + } + PyArrayObject *py_array = reinterpret_cast(array_safe.get()); + + // Convert numpy dtype to TensorFlow dtype. 
+ tensorflow::DataType dtype; + TF_RETURN_IF_ERROR(GetTensorDtypeFromPyArray(py_array, &dtype)); + TF_RETURN_IF_ERROR(VerifyDtypeIsSupported(dtype)); + + absl::InlinedVector dims(PyArray_NDIM(py_array)); + tensorflow::int64 nelems = 1; + for (int i = 0; i < PyArray_NDIM(py_array); ++i) { + dims[i] = PyArray_SHAPE(py_array)[i]; + nelems *= dims[i]; + } + + if (tensorflow::DataTypeCanUseMemcpy(dtype)) { + *out_tensor = tensorflow::Tensor(dtype, tensorflow::TensorShape(dims)); + size_t size = PyArray_NBYTES(py_array); + memcpy(out_tensor->data(), PyArray_DATA(py_array), size); + } else if (dtype == tensorflow::DT_STRING) { + *out_tensor = tensorflow::Tensor(dtype, tensorflow::TensorShape(dims)); + int i = 0; + auto *out_t = out_tensor->flat().data(); + TF_RETURN_IF_ERROR( + PyBytesArrayMap(py_array, [out_t, &i](const char *ptr, Py_ssize_t len) { + out_t[i++] = tensorflow::tstring(ptr, len); + })); + } else { + return tensorflow::errors::Unimplemented("Unexpected dtype: ", + tensorflow::DataTypeString(dtype)); + } + + return tensorflow::Status::OK(); +} + +tensorflow::Status TensorToNdArray(const tensorflow::Tensor &tensor, + PyObject **out_ndarray) { + TF_RETURN_IF_ERROR(VerifyDtypeIsSupported(tensor.dtype())); + + // Extract the numpy type and dimensions. + PyArray_Descr *descr = nullptr; + TF_RETURN_IF_ERROR(GetPyDescrFromTensor(tensor, &descr)); + + absl::InlinedVector dims(tensor.dims()); + for (int i = 0; i < tensor.dims(); i++) { + dims[i] = tensor.dim_size(i); + } + + // Allocate an empty array of the desired shape and type. + auto safe_out_ndarray = + make_safe(PyArray_Empty(dims.size(), dims.data(), descr, 0)); + if (!safe_out_ndarray) { + return tensorflow::errors::Internal("Could not allocate ndarray"); + } + + // Populate the ndarray with data from the tensor. + PyArrayObject *py_array = + reinterpret_cast(safe_out_ndarray.get()); + if (tensorflow::DataTypeCanUseMemcpy(tensor.dtype())) { + memcpy(PyArray_DATA(py_array), tensor.data(), PyArray_NBYTES(py_array)); + } else if (tensor.dtype() == tensorflow::DT_STRING) { + TF_RETURN_IF_ERROR(StringTensorToPyArray(tensor, py_array)); + } else { + return tensorflow::errors::Unimplemented( + "Unexpected tensor dtype: ", + tensorflow::DataTypeString(tensor.dtype())); + } + + *out_ndarray = safe_out_ndarray.release(); + return tensorflow::Status::OK(); +} + +} // namespace + +namespace pybind11 { +namespace detail { + +// Convert between absl::optional and python. +// +// pybind11 supports std::optional, and absl::optional is meant to be a +// drop-in replacement for std::optional, so we can just use the built in +// implementation. +// +// If we start getting errors due to this being defined in multiple places that +// likely means that pybind11 has included the cast itself and we can remove +// this implementation. +#ifndef ABSL_USES_STD_OPTIONAL +template +struct type_caster> + : public optional_caster> {}; + +template <> +struct type_caster : public void_caster {}; +#endif + +template <> +struct type_caster { + public: + PYBIND11_TYPE_CASTER(tensorflow::Tensor, _("tensorflow::Tensor")); + + bool load(handle handle, bool) { + tensorflow::Status status = NdArrayToTensor(handle.ptr(), &value); + + if (!status.ok()) { + std::string message = status.ToString(); + REVERB_LOG(REVERB_ERROR) + << "Tensor can't be extracted from the source represented as " + "ndarray: " + << message; + // When a conversion fails, PyErr is set. 
Returning from `load` with PyErr + // set results in crashes so we clear the error here to make the Python + // error slightly more readable. + PyErr_Clear(); + return false; + } + return true; + } + + static handle cast(const tensorflow::Tensor &src, return_value_policy, + handle) { + PyObject *ret; + tensorflow::Status status = TensorToNdArray(src, &ret); + if (!status.ok()) { + std::string message = status.ToString(); + PyErr_SetString(PyExc_ValueError, message.data()); + return nullptr; + } + return ret; + } +}; + +// Raise an exception if a given status is not OK, otherwise return None. +template <> +struct type_caster { + public: + PYBIND11_TYPE_CASTER(tensorflow::Status, _("Status")); + static handle cast(tensorflow::Status status, return_value_policy, handle) { + MaybeRaiseFromStatus(status); + return none().inc_ref(); + } +}; + +} // namespace detail +} // namespace pybind11 + +namespace deepmind { +namespace reverb { +namespace { + +namespace py = pybind11; + +PYBIND11_MODULE(libpybind, m) { + // Initialization code to use numpy types in the type casters. + ImportNumpy(); + + py::class_> + unused_key_distribution_interface(m, "KeyDistributionInterface"); + + py::class_>( + m, "PrioritizedDistribution") + .def(py::init(), py::arg("priority_exponent")); + + py::class_>(m, "FifoDistribution") + .def(py::init()); + + py::class_>(m, "LifoDistribution") + .def(py::init()); + + py::class_>(m, "UniformDistribution") + .def(py::init()); + + py::class_>(m, "HeapDistribution") + .def(py::init(), py::arg("min_heap")); + + py::class_> + unused_priority_table_extension_interface( + m, "PriorityTableExtensionInterface"); + + py::class_>(m, "RateLimiter") + .def(py::init(), + py::arg("samples_per_insert"), py::arg("min_size_to_sample"), + py::arg("min_diff"), py::arg("max_diff")); + + py::class_>(m, "PriorityTable") + .def(py::init([](const std::string &name, + const std::shared_ptr &sampler, + const std::shared_ptr &remover, + int max_size, int max_times_sampled, + const std::shared_ptr &rate_limiter, + const std::vector> &extensions, + const absl::optional &serialized_signature = + absl::nullopt) -> PriorityTable * { + absl::optional signature = + absl::nullopt; + if (serialized_signature) { + signature.emplace(); + if (!signature->ParseFromString(*serialized_signature)) { + MaybeRaiseFromStatus(tensorflow::errors::InvalidArgument( + "Unable to deserialize StructuredValue from " + "serialized proto bytes: '", + *serialized_signature, "'")); + return nullptr; + } + } + return new PriorityTable(name, sampler, remover, max_size, + max_times_sampled, rate_limiter, + extensions, std::move(signature)); + }), + py::arg("name"), py::arg("sampler"), py::arg("remover"), + py::arg("max_size"), py::arg("max_times_sampled"), + py::arg("rate_limiter"), py::arg("extensions"), py::arg("signature")) + .def("name", &PriorityTable::name) + .def("can_sample", &PriorityTable::CanSample, + py::call_guard()) + .def("can_insert", &PriorityTable::CanInsert, + py::call_guard()); + + py::class_(m, "ReplayWriter") + .def("AppendTimestep", &ReplayWriter::AppendTimestep, + py::call_guard()) + .def("AddPriority", &ReplayWriter::AddPriority, + py::call_guard()) + .def("Close", &ReplayWriter::Close, + py::call_guard()); + + py::class_(m, "ReplaySampler") + .def( + "GetNextTimestep", + [](ReplaySampler *sampler) { + std::vector sample; + bool end_of_sequence; + MaybeRaiseFromStatus( + sampler->GetNextTimestep(&sample, &end_of_sequence)); + return std::make_pair(std::move(sample), end_of_sequence); + }, + py::call_guard()) + 
.def("Close", &ReplaySampler::Close, + py::call_guard()); + + py::class_(m, "ReplayClient") + .def(py::init(), py::arg("server_name")) + .def( + "NewWriter", + [](ReplayClient *client, int chunk_length, int max_timesteps, + bool delta_encoded) { + std::unique_ptr writer; + MaybeRaiseFromStatus(client->NewWriter(chunk_length, max_timesteps, + delta_encoded, &writer)); + return writer; + }, + py::call_guard(), py::arg("chunk_length"), + py::arg("max_timesteps"), py::arg("delta_encoded") = false) + .def( + "NewSampler", + [](ReplayClient *client, const std::string &table, int64_t max_samples, + size_t buffer_size) { + std::unique_ptr sampler; + ReplaySampler::Options options; + options.max_samples = max_samples; + options.max_in_flight_samples_per_worker = buffer_size; + MaybeRaiseFromStatus(client->NewSampler(table, options, &sampler)); + return sampler; + }, + py::call_guard()) + .def( + "MutatePriorities", + [](ReplayClient *client, const std::string &table, + const std::vector> &updates, + const std::vector &deletes) { + std::vector update_protos; + for (const auto &update : updates) { + update_protos.emplace_back(); + update_protos.back().set_key(update.first); + update_protos.back().set_priority(update.second); + } + return client->MutatePriorities(table, update_protos, deletes); + }, + py::call_guard()) + .def("Reset", &ReplayClient::Reset, + py::call_guard()) + .def("ServerInfo", + [](ReplayClient *client, int timeout_sec) { + // Wait indefinetely for server to startup when timeout not + // provided. + auto timeout = timeout_sec > 0 ? absl::Seconds(timeout_sec) + : absl::InfiniteDuration(); + + struct ReplayClient::ServerInfo info; + + // Release the GIL only when waiting for the call to complete. If + // the GIL is not held when `MaybeRaiseFromStatus` is called it can + // result in segfaults as the Python exception is populated with + // details from the status. + tensorflow::Status status; + { + py::gil_scoped_release g; + status = client->ServerInfo(timeout, &info); + } + MaybeRaiseFromStatus(status); + + // Return a serialized ServerInfo proto bytes string. 
+ std::vector serialized_table_info; + serialized_table_info.reserve(info.table_info.size()); + for (const auto &table_info : info.table_info) { + serialized_table_info.push_back( + py::bytes(table_info.SerializeAsString())); + } + return serialized_table_info; + }) + .def("Checkpoint", [](ReplayClient *client) { + std::string path; + MaybeRaiseFromStatus(client->Checkpoint(&path)); + return path; + }); + + py::class_> + unused_checkpointer_interface(m, "CheckpointerInterface"); + + m.def( + "create_default_checkpointer", + [](const std::string &name, const std::string &group = "") { + auto checkpointer = CreateDefaultCheckpointer(name, group); + return std::shared_ptr(checkpointer.release()); + }, + py::call_guard()); + + py::class_>(m, "ReverbServer") + .def(py::init( + [](std::vector> priority_tables, + int port, + std::shared_ptr checkpointer = + nullptr) { + std::unique_ptr server; + MaybeRaiseFromStatus(ReverbServer::StartReverbServer( + std::move(priority_tables), port, std::move(checkpointer), + &server)); + return server.release(); + }), + py::arg("priority_tables"), py::arg("port"), + py::arg("checkpointer") = nullptr) + .def("Stop", &ReverbServer::Stop, + py::call_guard()) + .def("Wait", &ReverbServer::Wait, + py::call_guard()) + .def("InProcessClient", &ReverbServer::InProcessClient, + py::call_guard()); +} + +} // namespace +} // namespace reverb +} // namespace deepmind diff --git a/reverb/pybind_test.py b/reverb/pybind_test.py new file mode 100644 index 0000000..7e89701 --- /dev/null +++ b/reverb/pybind_test.py @@ -0,0 +1,75 @@ +# Copyright 2019 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Sanity tests for the pybind.py.""" + +from absl.testing import absltest +from absl.testing import parameterized +import numpy as np +import reverb + +TABLE_NAME = 'queue' + + +class TestNdArrayToTensorAndBack(parameterized.TestCase): + + @classmethod + def setUpClass(cls): + super(TestNdArrayToTensorAndBack, cls).setUpClass() + cls._server = reverb.Server( + priority_tables=[reverb.PriorityTable.queue(TABLE_NAME, 1)], + port=None, + ) + cls._client = cls._server.in_process_client() + + def tearDown(self): + super(TestNdArrayToTensorAndBack, self).tearDown() + self._client.reset(TABLE_NAME) + + @classmethod + def tearDownClass(cls): + super(TestNdArrayToTensorAndBack, cls).tearDownClass() + cls._server.stop() + + @parameterized.parameters( + (1,), + (1.0,), + (np.arange(4).reshape([2, 2]),), + (np.array(1, dtype=np.float16),), + (np.array(1, dtype=np.float32),), + (np.array(1, dtype=np.float64),), + (np.array(1, dtype=np.int8),), + (np.array(1, dtype=np.int16),), + (np.array(1, dtype=np.int32),), + (np.array(1, dtype=np.int64),), + (np.array(1, dtype=np.uint8),), + (np.array(1, dtype=np.uint16),), + (np.array(1, dtype=np.uint32),), + (np.array(1, dtype=np.uint64),), + (np.array(True, dtype=np.bool),), + (np.array(1, dtype=np.complex64),), + (np.array(1, dtype=np.complex128),), + (np.array([b'a string']),), + ) + def test_sanity_check(self, data): + with self._client.writer(1) as writer: + writer.append_timestep([data]) + writer.create_prioritized_item(TABLE_NAME, 1, 1) + + sample = next(self._client.sample(TABLE_NAME)) + got = sample[0].data[0] + np.testing.assert_array_equal(data, got) + +if __name__ == '__main__': + absltest.main() diff --git a/reverb/rate_limiters.py b/reverb/rate_limiters.py new file mode 100644 index 0000000..a14f877 --- /dev/null +++ b/reverb/rate_limiters.py @@ -0,0 +1,180 @@ +# Lint as: python3 +# Copyright 2019 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Rate limiters.""" + +import abc +import sys + +from typing import Tuple, Union +from absl import logging +from reverb import pybind + + +class RateLimiter(metaclass=abc.ABCMeta): + """Abstract base class for RateLimiters.""" + + def __init__(self, internal_limiter: pybind.RateLimiter): + self.internal_limiter = internal_limiter + + +class MinSize(RateLimiter): + """Block sample calls unless replay contains `min_size_to_sample`. + + This limiter blocks all sample calls when the replay contains less than + `min_size_to_sample` items, and accepts all sample calls otherwise. + """ + + def __init__(self, min_size_to_sample: int): + if min_size_to_sample < 1: + raise ValueError( + f'min_size_to_sample ({min_size_to_sample}) must be a positive ' + f'integer') + + super().__init__( + pybind.RateLimiter( + samples_per_insert=1.0, + min_size_to_sample=min_size_to_sample, + min_diff=-sys.float_info.max, + max_diff=sys.float_info.max)) + + +class SampleToInsertRatio(RateLimiter): + """Maintains a specified ratio between samples and inserts. 
+ + The limiter works in two stages: + + Stage 1. Size of table is lt `min_size_to_sample`. + Stage 2. Size of table is ge `min_size_to_sample`. + + During stage 1 the limiter works exactly like MinSize, i.e. it allows + all insert calls and blocks all sample calls. Note that it is possible to + transition into stage 1 from stage 2 when items are removed from the table. + + During stage 2 the limiter attempts to maintain the ratio + `samples_per_inserts` between the samples and inserts. This is done by + measuring the "error" in this ratio, calculated as: + + (number_of_inserts - min_size_to_sample) * samples_per_insert + - number_of_samples + + If this quantity is within the range (-error_buffer, error_buffer) then no + limiting occurs. If the error is larger than `error_buffer` then insert calls + will be blocked; sampling will be blocked for error less than -error_buffer. + + `error_buffer` exists to avoid unnecessary blocking for a system that is + more or less in equilibrium. + """ + + def __init__(self, samples_per_insert: float, min_size_to_sample: int, + error_buffer: Union[float, Tuple[float, float]]): + """Constructor of SampleToInsertRatio. + + Args: + samples_per_insert: The average number of times the learner should sample + each item in the replay error_buffer during the item's entire lifetime. + min_size_to_sample: The minimum number of items that the table must + contain before transitioning into stage 2. + error_buffer: Maximum size of the "error" before calls should be blocked. + When a single value is provided then inferred range is + ( + min_size_to_sample * samples_per_insert - error_buffer, + min_size_to_sample * samples_per_insert + error_buffer + ) + The offset is added so that the error tracked is for the insert/sample + ratio only takes into account operatons occurring AFTER stage 1. If a + range (two float tuple) then the values are used without any offset. + + Raises: + ValueError: If error_buffer is smaller than max(1.0, samples_per_inserts). + """ + if isinstance(error_buffer, float) or isinstance(error_buffer, int): + offset = samples_per_insert * min_size_to_sample + min_diff = offset - error_buffer + max_diff = offset + error_buffer + else: + min_diff, max_diff = error_buffer + + if max_diff - min_diff < 2 * max(1.0, samples_per_insert): + raise ValueError( + 'The size of error_buffer must be >= 2 * max(1.0, samples_per_insert) ' + 'as smaller values could completely block samples and/or insert calls.' + ) + + if max_diff < samples_per_insert * min_size_to_sample: + logging.warning( + 'The range covered by error_buffer is below ' + 'samples_per_insert * min_size_to_sample. If the sampler cannot ' + 'sample concurrently, this will result in a deadlock as soon as ' + 'min_size_to_sample items have been inserted.') + if min_diff > samples_per_insert * min_size_to_sample: + raise ValueError( + 'The range covered by error_buffer is above ' + 'samples_per_insert * min_size_to_sample. This will result in a ' + 'deadlock as soon as min_size_to_sample items have been inserted.') + + if min_size_to_sample < 1: + raise ValueError( + f'min_size_to_sample ({min_size_to_sample}) must be a positive ' + f'integer') + + super().__init__( + pybind.RateLimiter( + samples_per_insert=samples_per_insert, + min_size_to_sample=min_size_to_sample, + min_diff=min_diff, + max_diff=max_diff)) + + +class Queue(RateLimiter): + """Effectively turns the priority table into a queue. + + NOTE: Do not use this RateLimiter directly. Use PriorityTable.queue instead. 
+ NOTE: Must be used in conjunction with a Fifo sampler and remover. + """ + + def __init__(self, size: int): + """Constructor of Queue (do not use directly). + + Args: + size: Maximum size of the queue. + """ + super().__init__( + pybind.RateLimiter( + samples_per_insert=1.0, + min_size_to_sample=1, + min_diff=0.0, + max_diff=size)) + + +class Stack(RateLimiter): + """Effectively turns the priority table into a stack. + + NOTE: Do not use this RateLimiter directly. Use PriorityTable.stack instead. + NOTE: Must be used in conjunction with a Lifo sampler and remover. + """ + + def __init__(self, size: int): + """Constructor of Stack (do not use directly). + + Args: + size: Maximum size of the stack. + """ + super().__init__( + pybind.RateLimiter( + samples_per_insert=1.0, + min_size_to_sample=1, + min_diff=0.0, + max_diff=size)) diff --git a/reverb/rate_limiters_test.py b/reverb/rate_limiters_test.py new file mode 100644 index 0000000..2e70a1c --- /dev/null +++ b/reverb/rate_limiters_test.py @@ -0,0 +1,125 @@ +# Lint as: python3 +# Copyright 2019 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Reverb rate limiters.""" + +from absl.testing import absltest +from absl.testing import parameterized +from reverb import rate_limiters + + +class TestSampleToInsertRatio(parameterized.TestCase): + + @parameterized.named_parameters( + { + 'testcase_name': 'less_than_samples_per_insert', + 'samples_per_insert': 5, + 'error_buffer': 4, + 'want': ValueError, + }, + { + 'testcase_name': 'less_than_one', + 'samples_per_insert': 0.5, + 'error_buffer': 0.9, + 'want': ValueError, + }, + { + 'testcase_name': 'valid', + 'samples_per_insert': 0.5, + 'error_buffer': 1.1, + 'want': None, + }, + ) + def test_validates_single_number_error_buffer(self, samples_per_insert, + error_buffer, want): + if want: + with self.assertRaises(want): + rate_limiters.SampleToInsertRatio(samples_per_insert, 10, error_buffer) + else: # Should not raise any error. 
+ rate_limiters.SampleToInsertRatio(samples_per_insert, 10, error_buffer) + + @parameterized.named_parameters( + { + 'testcase_name': 'range_too_small_due_to_sample_per_insert_ratio', + 'min_size_to_sample': 10, + 'samples_per_insert': 5, + 'error_buffer': (8, 12), + 'want': ValueError, + }, + { + 'testcase_name': 'range_smaller_than_2', + 'min_size_to_sample': 10, + 'samples_per_insert': 0.1, + 'error_buffer': (9.5, 10.5), + 'want': ValueError, + }, + { + 'testcase_name': 'range_below_min_size_to_sample', + 'min_size_to_sample': 10, + 'samples_per_insert': 1, + 'error_buffer': (5, 9), + 'want': None, + }, + { + 'testcase_name': 'range_above_min_size_to_sample', + 'min_size_to_sample': 10, + 'samples_per_insert': 1, + 'error_buffer': (11, 15), + 'want': ValueError, + }, + { + 'testcase_name': 'min_size_to_sample_smaller_than_1', + 'min_size_to_sample': 0, + 'samples_per_insert': 1, + 'error_buffer': (-100, 100), + 'want': ValueError, + }, + { + 'testcase_name': 'valid', + 'min_size_to_sample': 10, + 'samples_per_insert': 1, + 'error_buffer': (7, 12), + 'want': None, + }, + ) + def test_validates_explicit_range_error_buffer(self, min_size_to_sample, + samples_per_insert, + error_buffer, want): + if want: + with self.assertRaises(want): + rate_limiters.SampleToInsertRatio(samples_per_insert, + min_size_to_sample, error_buffer) + else: # Should not raise any error. + rate_limiters.SampleToInsertRatio(samples_per_insert, min_size_to_sample, + error_buffer) + + +class TestMinSize(parameterized.TestCase): + + @parameterized.parameters( + (-1, True), + (0, True), + (1, False), + ) + def test_raises_if_min_size_lt_1(self, min_size_to_sample, want_error): + if want_error: + with self.assertRaises(ValueError): + rate_limiters.MinSize(min_size_to_sample) + else: + rate_limiters.MinSize(min_size_to_sample) + + +if __name__ == '__main__': + absltest.main() diff --git a/reverb/replay_sample.py b/reverb/replay_sample.py new file mode 100644 index 0000000..b0fe1ac --- /dev/null +++ b/reverb/replay_sample.py @@ -0,0 +1,54 @@ +# Lint as: python3 +# Copyright 2019 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Data structures for output of client samples.""" + +from typing import Any, NamedTuple, Sequence, Union + +import numpy as np +import tensorflow.compat.v1 as tf + + +class SampleInfo(NamedTuple): + """Extra details about the sampled item. + + Fields: + key: Key of the item that was sampled. Used for updating the priority. + Typically a python `int` (for output of ReplayClient.sample) or + `tf.uint64` Tensor (for output of TF ReplayClient.sample). + probability: Probability of selecting the item at the time of sampling. + A python `float` or `tf.float64` Tensor. + table_size: The total number of items present in the table at sample time. 
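+
+  A sketch of typical usage of `key` (the table name 'my_table' and the new
+  priority value are assumptions; `client` is any connected `Client`):
+
+  ```python
+
+  sample = next(client.sample('my_table'))[0]
+  client.mutate_priorities('my_table', updates={sample.info.key: 0.5})
+
+  ```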
+ """ + key: Union[np.ndarray, tf.Tensor] + probability: Union[np.ndarray, tf.Tensor] + table_size: Union[np.ndarray, tf.Tensor] + + @classmethod + def tf_dtypes(cls): + return cls(tf.uint64, tf.double, tf.int64) + + +class ReplaySample(NamedTuple): + """Item returned by sample operations. + + Fields: + info: Details about the sampled item. Instance of `SampleInfo`. + data: Tensors for the data. Flat list of numpy arrays for output of python + `ReverbClient.sample`, nested structure of Tensors for TensorFlow + `ReverbClient.sample`. + """ + info: SampleInfo + data: Union[Sequence[np.ndarray], Any] diff --git a/reverb/reverb_types.py b/reverb/reverb_types.py new file mode 100644 index 0000000..e21a2dd --- /dev/null +++ b/reverb/reverb_types.py @@ -0,0 +1,51 @@ +# Lint as: python3 +# Copyright 2019 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pytype helpers.""" + +import collections +from typing import Any, Iterable, Mapping, NamedTuple, Optional, Union, get_type_hints + +from reverb import pybind +import tensorflow.compat.v1 as tf + +from reverb.cc import schema_pb2 + + +Fifo = pybind.FifoDistribution +Prioritized = pybind.PrioritizedDistribution +Uniform = pybind.UniformDistribution + +DistributionType = Union[Fifo, pybind.HeapDistribution, Prioritized, Uniform] + +# Note that this is effectively treated as `Any`; see b/109648354. +SpecNest = Union[ + tf.TensorSpec, Iterable['SpecNest'], Mapping[str, 'SpecNest']] # pytype: disable=not-supported-yet + +_table_info_proto_types = get_type_hints(schema_pb2.TableInfo) or {} + +_table_info_type_dict = collections.OrderedDict( + (descr.name, _table_info_proto_types.get(descr.name, Any)) + for descr in schema_pb2.TableInfo.DESCRIPTOR.fields) +_table_info_type_dict['signature'] = Optional[SpecNest] + + +"""A tuple describing PriorityTable information. + +The main difference between this object and a `schema_pb2.TableInfo` message +is that the signature is a nested structure of `tf.TypeSpec` objects, +instead of a raw proto. +""" +TableInfo = NamedTuple('TableInfo', tuple(_table_info_type_dict.items())) diff --git a/reverb/server.py b/reverb/server.py new file mode 100644 index 0000000..eb17281 --- /dev/null +++ b/reverb/server.py @@ -0,0 +1,257 @@ +# Lint as: python3 +# Copyright 2019 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Python bindings for creating and serving the Reverb ReplayService. 
+ +See ./client.py and ./tf_client.py for details of how to interact with the +service. +""" + +import abc +import collections +from typing import List, Optional, Sequence, Union + +from absl import logging + +import portpicker +from reverb import checkpointer as checkpointer_lib +from reverb import client +from reverb import distributions +from reverb import pybind +from reverb import rate_limiters +from reverb import reverb_types +import termcolor +import tree + +# pylint: disable=g-direct-tensorflow-import +from tensorflow.python.framework import tensor_spec +from tensorflow.python.saved_model import nested_structure_coder +# pylint: enable=g-direct-tensorflow-import + + +class PriorityTableExtensionBase(metaclass=abc.ABCMeta): + """Abstract base class for PriorityTable extensions.""" + + @abc.abstractmethod + def build_internal_extensions( + self, table_name: str) -> List[pybind.PriorityTableExtensionInterface]: + """Constructs the c++ PriorityTableExtensions.""" + + +class PriorityTable: + """PriorityTable defines how items are selected for sampling and removal.""" + + def __init__(self, + name: str, + sampler: reverb_types.DistributionType, + remover: reverb_types.DistributionType, + max_size: int, + max_times_sampled: int = 0, + rate_limiter: Optional[rate_limiters.RateLimiter] = None, + extensions: Sequence[PriorityTableExtensionBase] = (), + signature: Optional[reverb_types.SpecNest] = None): + """Constructor of the PriorityTable. + + Args: + name: Name of the priority table. + sampler: The strategy to use when selecting samples. + remover: The strategy to use when selecting which items to remove. + max_size: The maximum number of items which the replay is allowed to hold. + When an item is inserted into an already full priority table the + `remover` is used for selecting which item to remove before proceeding + with the new insert. + max_times_sampled: Maximum number of times an item can be sampled before + it is deleted. Any value < 1 is ignored and means there is no limit. + rate_limiter: Manages the data flow by limiting the sample and insert + calls. Defaults to `rate_limiters.MinSize` using 95% of `max_size` as + `min_size_to_sample`. + extensions: Optional sequence of extensions used to add extra features to + the table. + signature: Optional nested structure containing `tf.TypeSpec` objects, + describing the storage schema for this table. + + Raises: + ValueError: If name is empty. + ValueError: If max_size <= 0. + """ + if not name: + raise ValueError('name must be nonempty') + if max_size <= 0: + raise ValueError('max_size (%d) must be a positive integer' % max_size) + + if rate_limiter is None: + min_items_for_sampling = int(0.95 * max_size) + rate_limiter = rate_limiters.MinSize(min_items_for_sampling) + + # Merge the c++ extensions into a single list. 
+    internal_extensions = []
+    for extension in extensions:
+      internal_extensions += extension.build_internal_extensions(name)
+
+    if signature:
+      flat_signature = tree.flatten(signature)
+      for s in flat_signature:
+        if not isinstance(s, tensor_spec.TensorSpec):
+          raise ValueError(f'Unsupported signature spec: {s}')
+      signature_proto_str = (
+          nested_structure_coder.StructureCoder().encode_structure(
+              signature).SerializeToString())
+    else:
+      signature_proto_str = None
+
+    self.internal_table = pybind.PriorityTable(
+        name=name,
+        sampler=sampler,
+        remover=remover,
+        max_size=max_size,
+        max_times_sampled=max_times_sampled,
+        rate_limiter=rate_limiter.internal_limiter,
+        extensions=internal_extensions,
+        signature=signature_proto_str)
+
+  @classmethod
+  def queue(cls, name: str, max_size: int):
+    """Constructs a PriorityTable which acts like a queue.
+
+    Args:
+      name: Name of the priority table (aka queue).
+      max_size: Maximum number of items in the priority table (aka queue).
+
+    Returns:
+      PriorityTable which behaves like a queue of size `max_size`.
+    """
+    return cls(
+        name=name,
+        sampler=distributions.Fifo(),
+        remover=distributions.Fifo(),
+        max_size=max_size,
+        max_times_sampled=1,
+        rate_limiter=rate_limiters.Queue(max_size))
+
+  @classmethod
+  def stack(cls, name: str, max_size: int):
+    """Constructs a PriorityTable which acts like a stack.
+
+    Args:
+      name: Name of the priority table (aka stack).
+      max_size: Maximum number of items in the priority table (aka stack).
+
+    Returns:
+      PriorityTable which behaves like a stack of size `max_size`.
+    """
+    return cls(
+        name=name,
+        sampler=distributions.Lifo(),
+        remover=distributions.Lifo(),
+        max_size=max_size,
+        max_times_sampled=1,
+        rate_limiter=rate_limiters.Stack(max_size))
+
+  @property
+  def name(self):
+    return self.internal_table.name()
+
+  def can_sample(self, num_samples: int) -> bool:
+    """Returns True if a sample operation is permitted at the current state."""
+    return self.internal_table.can_sample(num_samples)
+
+  def can_insert(self, num_inserts: int) -> bool:
+    """Returns True if an insert operation is permitted at the current state."""
+    return self.internal_table.can_insert(num_inserts)
+
+
+class Server:
+  """Reverb replay server.
+
+  The Server hosts the gRPC-service deepmind.reverb.ReplayService (see
+  //third_party/reverb/replay_service.proto). See ./client.py and
+  ./tf_client.py for details of how to interact with the service.
+
+  A Server maintains inserted data and one or more PriorityTables. Multiple
+  tables can be used to provide different views of the same underlying data.
+  Since the operations performed by a PriorityTable are relatively inexpensive
+  compared to operations on the actual data, using multiple tables that
+  reference the same data is encouraged over replicating the data itself.
+  """
+
+  def __init__(self,
+               priority_tables: List[PriorityTable],
+               port: Union[int, None],
+               checkpointer: checkpointer_lib.CheckpointerBase = None):
+    """Constructor of Server serving the ReplayService.
+
+    Args:
+      priority_tables: A list of priority tables to host on the server.
+      port: The port number to serve the gRPC-service on. If `None` is passed
+        then a port is automatically picked and assigned.
+      checkpointer: Checkpointer used for storing/loading checkpoints. If None
+        (default) then `checkpointer_lib.default_checkpointer` is used to
+        construct the checkpointer.
+
+    Raises:
+      ValueError: If priority_tables is empty.
+      ValueError: If multiple tables in priority_tables share the same name.
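+
+    Example (an illustrative sketch; the table names, sizes, priority exponent
+    and rate limiter thresholds below are arbitrary placeholders):
+
+      tables = [
+          PriorityTable(
+              name='replay',
+              sampler=distributions.Prioritized(priority_exponent=0.8),
+              remover=distributions.Fifo(),
+              max_size=1000,
+              rate_limiter=rate_limiters.MinSize(100)),
+          PriorityTable.queue(name='queue', max_size=100),
+      ]
+      # port=None picks a free port; checkpointer=None uses the default.
+      my_server = Server(priority_tables=tables, port=None)
+      print(my_server.port)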
+ """ + if not priority_tables: + raise ValueError('At least one priority table must be provided') + names = collections.Counter(table.name for table in priority_tables) + duplicates = [name for name, count in names.items() if count > 1] + if duplicates: + raise ValueError( + 'Multiple items in priority_tables have the same name: {}'.format( + ', '.join(duplicates))) + + if port is None: + port = portpicker.pick_unused_port() + + if checkpointer is None: + checkpointer = checkpointer_lib.default_checkpointer() + + self._server = pybind.ReverbServer( + [table.internal_table for table in priority_tables], port, + checkpointer.internal_checkpointer()) + self._port = port + + def __del__(self): + """Stop server and free up the port if was reserved through portpicker.""" + if hasattr(self, '_server'): + self.stop() + + if hasattr(self, '_port'): + portpicker.return_port(self._port) + + @property + def port(self): + """Port the gRPC service is running at.""" + return self._port + + def stop(self): + """Request that the ReplayService is terminated and wait for shutdown.""" + return self._server.Stop() + + def wait(self): + """Wait indefinitely for the ReplayService to stop.""" + return self._server.Wait() + + def in_process_client(self): + """Gets a local in process client. + + This bypasses proto serialization and network overhead. + + Returns: + ReplayClient. Must not be used after this ReplayServer has been stopped! + """ + return client.Client(f'[::1]:{self._port}', + self._server.InProcessClient()) diff --git a/reverb/server_test.py b/reverb/server_test.py new file mode 100644 index 0000000..eec2ab0 --- /dev/null +++ b/reverb/server_test.py @@ -0,0 +1,94 @@ +# Lint as: python3 +# Copyright 2019 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for python server. + +Note: Most of the functionality is tested through ./client_test.py. This file +only contains a few extra cases which does not fit well in the client tests. 
+""" + +from absl.testing import absltest +from reverb import distributions +from reverb import rate_limiters +from reverb import server + +TABLE_NAME = 'table' + + +class ServerTest(absltest.TestCase): + + def test_in_process_client(self): + my_server = server.Server( + priority_tables=[ + server.PriorityTable( + name=TABLE_NAME, + sampler=distributions.Prioritized(1), + remover=distributions.Fifo(), + max_size=100, + rate_limiter=rate_limiters.MinSize(2)), + ], + port=None) + my_client = my_server.in_process_client() + my_client.reset(TABLE_NAME) + del my_client + my_server.stop() + + def test_duplicate_priority_table_name(self): + with self.assertRaises(ValueError): + server.Server( + priority_tables=[ + server.PriorityTable( + name='test', + sampler=distributions.Prioritized(1), + remover=distributions.Fifo(), + max_size=100, + rate_limiter=rate_limiters.MinSize(2)), + server.PriorityTable( + name='test', + sampler=distributions.Prioritized(2), + remover=distributions.Fifo(), + max_size=200, + rate_limiter=rate_limiters.MinSize(1)) + ], + port=None) + + def test_no_priority_table_provided(self): + with self.assertRaises(ValueError): + server.Server(priority_tables=[], port=None) + + def test_can_sample(self): + table = server.PriorityTable( + name=TABLE_NAME, + sampler=distributions.Prioritized(1), + remover=distributions.Fifo(), + max_size=100, + max_times_sampled=1, + rate_limiter=rate_limiters.MinSize(2)) + my_server = server.Server(priority_tables=[table], port=None) + my_client = my_server.in_process_client() + self.assertFalse(table.can_sample(1)) + self.assertTrue(table.can_insert(1)) + my_client.insert(1, {TABLE_NAME: 1.0}) + self.assertFalse(table.can_sample(1)) + my_client.insert(1, {TABLE_NAME: 1.0}) + self.assertTrue(table.can_sample(2)) + # TODO(b/153258711): This should return False since max_times_sampled=1. + self.assertTrue(table.can_sample(3)) + del my_client + my_server.stop() + + +if __name__ == '__main__': + absltest.main() diff --git a/reverb/tf_client.py b/reverb/tf_client.py new file mode 100644 index 0000000..132c817 --- /dev/null +++ b/reverb/tf_client.py @@ -0,0 +1,404 @@ +# Lint as: python3 +# Copyright 2019 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""TFClient provides tf-ops for interacting with Reverb.""" + +from typing import Any, Sequence, Optional + +from reverb import replay_sample +import tensorflow.compat.v1 as tf +import tree + +from reverb.cc.ops import gen_client_ops +from reverb.cc.ops import gen_dataset_op + + +class ReplayDataset(tf.data.Dataset): + """A tf.data.Dataset which samples timesteps from the ReplayService. + + Note: The dataset returns `ReplaySample` where `data` with the structure of + `dtypes` and `shapes`. + + Note: Uses of Python lists are converted into tuples as nest used by the + tf.data API doesn't have good support for lists. + + Timesteps are streamed through the dataset as follows: + + 1. Does an active prioritized item exists? + - Yes: Go to 3 + - No: Go to 2. + 2. 
Sample a prioritized item from `table` using its sample-function and set
+       the item as "active". Go to 3.
+    3. Yield the next timestep within the active prioritized item. If the
+       timestep was the last one within the item, clear its "active" status.
+
+  This allows for items of arbitrary length to be streamed with limited memory.
+  """
+
+  def __init__(self,
+               server_address: str,
+               table: str,
+               dtypes: Any,
+               shapes: Any,
+               max_in_flight_samples_per_worker: int,
+               num_workers_per_iterator: int = -1,
+               max_samples_per_stream: int = -1,
+               sequence_length: Optional[int] = None,
+               emit_timesteps: bool = True):
+    """Constructs a new ReplayDataset.
+
+    Args:
+      server_address: Address of gRPC ReplayService.
+      table: Probability table to sample from.
+      dtypes: Dtypes of the data output. Can be nested.
+      shapes: Shapes of the data output. Can be nested.
+      max_in_flight_samples_per_worker: The number of samples requested in each
+        batch of samples. Higher values give higher throughput, but values that
+        are too large can result in skewed sampling distributions as a large
+        number of samples are fetched from a single snapshot of the replay
+        (followed by a period of lower activity as the samples are consumed).
+        A good rule of thumb is to set this value to 2-3x the batch size used.
+      num_workers_per_iterator: (Defaults to -1, i.e. auto selected) The
+        number of worker threads to create per dataset iterator. When the
+        selected table uses a FIFO sampler (i.e. a queue) then exactly 1 worker
+        must be used to avoid races causing invalid ordering of items. For all
+        other samplers, this value should be roughly equal to the number of
+        threads available on the CPU.
+      max_samples_per_stream: (Defaults to -1, i.e. auto selected) The
+        maximum number of samples to fetch from a stream before a new call is
+        made. Keeping this number low ensures that the data is fetched
+        uniformly from all servers.
+      sequence_length: (Defaults to None, i.e. unknown) The number of timesteps
+        that each sample consists of. If set then the length of samples
+        received from the server will be validated against this number.
+      emit_timesteps: (Defaults to True) If set, timesteps instead of full
+        sequences are returned from the dataset. Returning sequences instead
+        of timesteps can be more efficient as the memcopies caused by the
+        splitting and batching of tensors can be avoided. Note that if set to
+        False then all `shapes` must have dim[0] equal to `sequence_length`.
+
+    Raises:
+      ValueError: If `dtypes` and `shapes` don't share the same structure.
+      ValueError: If max_in_flight_samples_per_worker is not a positive integer.
+      ValueError: If num_workers_per_iterator is not a positive integer or -1.
+      ValueError: If max_samples_per_stream is not a positive integer or -1.
+      ValueError: If sequence_length is not a positive integer or None.
+      ValueError: If emit_timesteps is False and not all items in shapes have
+        `sequence_length` as their leading dimension.
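+
+    Example (an illustrative sketch; the address, table name and specs are
+    placeholder values):
+
+      dataset = ReplayDataset(
+          server_address='localhost:12345',
+          table='replay',
+          dtypes=(tf.float32,),
+          shapes=(tf.TensorShape([3, 3]),),
+          # Rule of thumb from above: roughly 2-3x the training batch size.
+          max_in_flight_samples_per_worker=64)
+      # Each element is a ReplaySample of (info, data).
+      dataset = dataset.batch(32)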
+ """ + tree.assert_same_structure(dtypes, shapes, False) + if max_in_flight_samples_per_worker < 1: + raise ValueError( + 'max_in_flight_samples_per_worker (%d) must be a positive integer' % + max_in_flight_samples_per_worker) + if num_workers_per_iterator < 1 and num_workers_per_iterator != -1: + raise ValueError( + 'num_workers_per_iterator (%d) must be a positive integer or -1' % + num_workers_per_iterator) + if max_samples_per_stream < 1 and max_samples_per_stream != -1: + raise ValueError( + 'max_samples_per_stream (%d) must be a positive integer or -1' % + max_samples_per_stream) + if sequence_length is not None and sequence_length < 1: + raise ValueError( + 'sequence_length (%s) must be None or a positive integer' % + sequence_length) + + # Add the info fields. + dtypes = replay_sample.ReplaySample(replay_sample.SampleInfo.tf_dtypes(), + dtypes) + shapes = replay_sample.ReplaySample( + replay_sample.SampleInfo( + tf.TensorShape([sequence_length] if not emit_timesteps else []), + tf.TensorShape([sequence_length] if not emit_timesteps else []), + tf.TensorShape([sequence_length] if not emit_timesteps else [])), + shapes) + + # If sequences are to be emitted then all shapes must specify use + # sequence_length as their batch dimension. + if not emit_timesteps: + + def _validate_batch_dim(path: str, shape: tf.TensorShape): + if (not shape.ndims + or tf.compat.dimension_value(shape[0]) != sequence_length): + raise ValueError( + 'All items in shapes must use sequence_range (%s) as the leading ' + 'dimension, but "%s" has shape %s' % + (sequence_length, path[0], shape)) + + tree.map_structure_with_path(_validate_batch_dim, shapes.data) + + # The tf.data API doesn't fully support lists so we convert all uses of + # lists into tuples. + dtypes = _convert_lists_to_tuples(dtypes) + shapes = _convert_lists_to_tuples(shapes) + + self._server_address = server_address + self._table = table + self._dtypes = dtypes + self._shapes = shapes + self._sequence_length = sequence_length + self._emit_timesteps = emit_timesteps + self._max_in_flight_samples_per_worker = max_in_flight_samples_per_worker + self._num_workers_per_iterator = num_workers_per_iterator + self._max_samples_per_stream = max_samples_per_stream + + if _is_tf1_runtime(): + # Disabling to avoid errors given the different tf.data.Dataset init args + # between v1 and v2 APIs. + # pytype: disable=wrong-arg-count + super().__init__() + else: + # DatasetV2 requires the dataset as a variant tensor during init. + super().__init__(self._as_variant_tensor()) + # pytype: enable=wrong-arg-count + + def _as_variant_tensor(self): + return gen_dataset_op.reverb_dataset( + server_address=self._server_address, + table=self._table, + dtypes=tree.flatten(self._dtypes), + shapes=tree.flatten(self._shapes), + emit_timesteps=self._emit_timesteps, + sequence_length=self._sequence_length or -1, + max_in_flight_samples_per_worker=self._max_in_flight_samples_per_worker, + num_workers_per_iterator=self._num_workers_per_iterator, + max_samples_per_stream=self._max_samples_per_stream) + + def _inputs(self): + return [] + + @property + def element_spec(self): + return tree.map_structure(tf.TensorSpec, self._shapes, self._dtypes) + + +class TFClient: + """Client class for calling Reverb replay servers from a TensorFlow graph.""" + + def __init__(self, + server_address: str, + shared_name: Optional[str] = None, + name='reverb'): + """Creates the client TensorFlow handle. + + Args: + server_address: Address of the server. 
+ shared_name: (Optional) If non-empty, this client will be shared under the + given name across multiple sessions. + name: Optional name for the Client operations. + """ + self._name = name + self._server_address = server_address + self._handle = gen_client_ops.reverb_client( + server_address=server_address, shared_name=shared_name, name=name) + + def sample(self, + table: str, + data_dtypes, + name: Optional[str] = None) -> replay_sample.ReplaySample: + """Samples an item from the replay. + + This only allows sampling items with a data field. + + Args: + table: Probability table to sample from. + data_dtypes: Dtypes of the data output. Can be nested. + name: Optional name for the Client operations. + + Returns: + A ReplaySample with data nested according to data_dtypes. See ReplaySample + for more details. + """ + with tf.name_scope(name, f'{self._name}_sample', ['sample']) as scope: + key, probability, table_size, data = gen_client_ops.reverb_client_sample( + self._handle, table, tree.flatten(data_dtypes), name=scope) + return replay_sample.ReplaySample( + replay_sample.SampleInfo(key, probability, table_size), + tree.unflatten_as(data_dtypes, data)) + + def insert(self, + data: Sequence[tf.Tensor], + tables: tf.Tensor, + priorities: tf.Tensor, + name: Optional[str] = None): + """Inserts a trajectory into one or more tables. + + The content of `tables` and `priorities` are zipped to create the + prioritized items. That is, an item with priority `priorities[i]` is + inserted into `tables[i]`. + + Args: + data: Tensors to insert as the trajectory. + tables: Rank 1 tensor with the names of the tables to create prioritized + items in. + priorities: Rank 1 tensor with priorities of the new items. + name: Optional name for the client operation. + + Returns: + A tf-op for performing the insert. + + Raises: + ValueError: If tables is not a string tensor of rank 1. + ValueError: If priorities is not a float64 tensor of rank 1. + ValueError: If priorities and tables does not have the same shape. + """ + if tables.dtype != tf.string or tables.shape.rank != 1: + raise ValueError('tables must be a string tensor of rank 1') + if priorities.dtype != tf.float64 or priorities.shape.rank != 1: + raise ValueError('priorities must be a float64 tensor of rank 1') + if not tables.shape.is_compatible_with(priorities.shape): + raise ValueError('priorities and tables must have the same shape') + + with tf.name_scope(name, f'{self._name}_insert', ['insert']) as scope: + return gen_client_ops.reverb_client_insert( + self._handle, data, tables, priorities, name=scope) + + def update_priorities(self, + table: str, + keys: tf.Tensor, + priorities: tf.Tensor, + name: str = None): + """Creates op for updating priorities of existing items in the replay. + + Not found elements for `keys` are silently ignored. + + Args: + table: Probability table to update. + keys: Keys of the items to update. Must be same length as `priorities`. + priorities: New priorities for `keys`. Must be same length as `keys`. + name: Optional name for the operation. + + Returns: + A tf-op for performing the update. 
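+
+    Example (an illustrative sketch, given a `TFClient` instance `client`;
+    the table name and new priority are placeholders, and the keys would
+    normally come from a previously sampled item):
+
+      sample = client.sample('my_table', data_dtypes=[tf.float32])
+      update_op = client.update_priorities(
+          table=tf.constant('my_table'),
+          keys=tf.expand_dims(sample.info.key, 0),
+          priorities=tf.constant([0.5], dtype=tf.float64))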
+    """
+
+    with tf.name_scope(name, f'{self._name}_update_priorities',
+                       ['update_priorities']) as scope:
+      return gen_client_ops.reverb_client_update_priorities(
+          self._handle, table, keys, priorities, name=scope)
+
+  def dataset(self,
+              table: str,
+              dtypes: Sequence[Any],
+              shapes: Sequence[Any],
+              capacity: int = 100,
+              num_workers_per_iterator: int = -1,
+              max_samples_per_stream: int = -1,
+              sequence_length: Optional[int] = None,
+              emit_timesteps: bool = True) -> ReplayDataset:
+    """Creates a ReplayDataset which samples from the Replay service.
+
+    Note: Uses of Python lists are converted into tuples as nest used by the
+    tf.data API doesn't have good support for lists.
+
+    See ReplayDataset for detailed documentation.
+
+    Args:
+      table: Probability table to sample from.
+      dtypes: Dtypes of a single timestep in the sampled items. Can be nested.
+      shapes: Shapes of a single timestep in the sampled items. Can be nested.
+      capacity: (Defaults to 100) Maximum number of samples requested by the
+        workers with each request. Higher values give higher throughput, but
+        values that are too large can result in skewed sampling distributions
+        as a large number of samples are fetched from a single snapshot of the
+        replay (followed by a period of lower activity as the samples are
+        consumed). A good rule of thumb is to set this value to 2-3x the batch
+        size used.
+      num_workers_per_iterator: (Defaults to -1, i.e. auto selected) The number
+        of worker threads to create per dataset iterator. When the selected
+        table uses a FIFO sampler (i.e. a queue) then exactly 1 worker must be
+        used to avoid races causing invalid ordering of items. For all other
+        samplers, this value should be roughly equal to the number of threads
+        available on the CPU.
+      max_samples_per_stream: (Defaults to -1, i.e. auto selected) The maximum
+        number of samples to fetch from a stream before a new call is made.
+        Keeping this number low ensures that the data is fetched uniformly from
+        all servers.
+      sequence_length: (Defaults to None, i.e. unknown) The number of timesteps
+        that each sample consists of. If set then the length of samples
+        received from the server will be validated against this number.
+      emit_timesteps: (Defaults to True) If set, timesteps instead of full
+        sequences are returned from the dataset. Returning sequences instead
+        of timesteps can be more efficient as the memcopies caused by the
+        splitting and batching of tensors can be avoided. Note that if set to
+        False then all `shapes` must have dim[0] equal to `sequence_length`.
+
+    Returns:
+      A ReplayDataset with the above specification.
+    """
+    return ReplayDataset(
+        server_address=self._server_address,
+        table=table,
+        dtypes=dtypes,
+        shapes=shapes,
+        max_in_flight_samples_per_worker=capacity,
+        num_workers_per_iterator=num_workers_per_iterator,
+        max_samples_per_stream=max_samples_per_stream,
+        sequence_length=sequence_length,
+        emit_timesteps=emit_timesteps)
+
+
+# TODO(b/148080741): switch to tree.apply_to_structure when it is available.
+def _apply_to_structure(branch_fn, leaf_fn, structure):
+  """`apply_to_structure` applies branch_fn and leaf_fn to branches and leaves.
+
+  This function applies `branch_fn` to nested structures and `leaf_fn` to
+  leaves, recursing into the result of `branch_fn`.
+
+  Args:
+    branch_fn: A function to call on a struct if is_nested(struct) is `True`.
+    leaf_fn: A function to call on a struct if is_nested(struct) is `False`.
+    structure: A (possibly nested) structure to apply the functions to.
+
+  Returns:
+    A nested structure of function outputs.
+ + Raises: + TypeError: If `branch_fn` or `leaf_fn` is not callable. + ValueError: If no structure is provided. + """ + if not callable(leaf_fn): + raise TypeError('leaf_fn must be callable, got: %s' % leaf_fn) + + if not callable(branch_fn): + raise TypeError('branch_fn must be callable, got: %s' % branch_fn) + + if not tree.is_nested(structure): + return leaf_fn(structure) + + processed = branch_fn(structure) + + # pylint: disable=protected-access + new_structure = [ + _apply_to_structure(branch_fn, leaf_fn, value) + for value in tree._yield_value(processed) + ] + return tree._sequence_like(processed, new_structure) + # pylint: enable=protected-access + + +def _convert_lists_to_tuples(structure: Any): + return _apply_to_structure( + branch_fn=lambda s: tuple(s) if isinstance(s, list) else s, + leaf_fn=lambda s: s, + structure=structure) + + +def _is_tf1_runtime(): + """Returns True if the runtime is executing with TF1.0 APIs.""" + # TODO(b/145023272): Update when/if there is a better way. + return hasattr(tf, 'to_float') diff --git a/reverb/tf_client_test.py b/reverb/tf_client_test.py new file mode 100644 index 0000000..26918f2 --- /dev/null +++ b/reverb/tf_client_test.py @@ -0,0 +1,734 @@ +# Lint as: python3 +# Copyright 2019 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for tf_client.""" + +from concurrent import futures +import threading +import time + +from absl.testing import parameterized +import numpy as np +from reverb import client as reverb_client +from reverb import distributions +from reverb import rate_limiters +from reverb import replay_sample +from reverb import server +from reverb import tf_client +import tensorflow.compat.v1 as tf +import tree + + +def make_server(): + return server.Server( + priority_tables=[ + server.PriorityTable( + 'dist', + sampler=distributions.Prioritized(priority_exponent=1), + remover=distributions.Fifo(), + max_size=1000000, + rate_limiter=rate_limiters.MinSize(1)), + server.PriorityTable( + 'dist2', + sampler=distributions.Prioritized(priority_exponent=1), + remover=distributions.Fifo(), + max_size=1000000, + rate_limiter=rate_limiters.MinSize(1)), + server.PriorityTable( + 'signatured', + sampler=distributions.Prioritized(priority_exponent=1), + remover=distributions.Fifo(), + max_size=1000000, + rate_limiter=rate_limiters.MinSize(1), + signature=tf.TensorSpec(dtype=tf.float32, shape=(None, None))), + ], + port=None, + ) + + +class SampleOpTest(tf.test.TestCase): + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls._server = make_server() + cls._client = reverb_client.Client(f'localhost:{cls._server.port}') + + def tearDown(self): + super().tearDown() + self._client.reset('dist') + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + cls._server.stop() + + def testSample(self): + input_data = [np.ones((81, 81), dtype=np.float64)] + self._client.insert(input_data, {'dist': 1}) + with self.session() as session: + client = tf_client.TFClient(self._client.server_address) + sample = session.run(client.sample('dist', [tf.float64])) + np.testing.assert_equal(input_data, sample.data) + self.assertNotEqual(sample.info.key, 0) + self.assertEqual(sample.info.probability, 1) + self.assertEqual(sample.info.table_size, 1) + + def testSampleDtypeMismatchFails(self): + data = [np.zeros((81, 81))] + self._client.insert(data, {'dist': 1}) + with self.session() as session: + client = tf_client.TFClient(self._client.server_address) + with self.assertRaises(tf.errors.InternalError): + session.run(client.sample('dist', [tf.float32])) + + def testSampleForwardServerError(self): + with self.session() as session: + client = tf_client.TFClient(self._client.server_address) + with self.assertRaises(tf.errors.NotFoundError): + session.run(client.sample('invalid', [tf.float64])) + + def testSampleRetryUntilOkOrFatal(self): + with self.session() as session: + client = tf_client.TFClient(self._client.server_address) + with futures.ThreadPoolExecutor(max_workers=1) as executor: + sample = executor.submit(session.run, + client.sample('dist', [tf.float64])) + input_data = [np.zeros((81, 81))] + self._client.insert(input_data, {'dist': 1}) + np.testing.assert_equal(input_data, sample.result().data) + + +class UpdatePrioritiesOpTest(tf.test.TestCase): + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls._server = make_server() + cls._client = reverb_client.Client(f'localhost:{cls._server.port}') + + def tearDown(self): + super().tearDown() + self._client.reset('dist') + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + cls._server.stop() + + def testUpdatePrioritiesShapeMismatchFails(self): + with self.session() as session: + client = tf_client.TFClient(self._client.server_address) + update_op = client.update_priorities( + tf.constant('dist'), tf.constant([1, 2], 
dtype=tf.uint64), + tf.constant([1], dtype=tf.float64)) + with self.assertRaises(tf.errors.InvalidArgumentError): + session.run(update_op) + + def testUpdatePriorities(self): + with self.session() as session: + client = tf_client.TFClient(self._client.server_address) + update_op = client.update_priorities( + tf.constant('dist'), tf.constant([1], dtype=tf.uint64), + tf.constant([1], dtype=tf.float64)) + # TODO(b/154931002): Test that update is applied once Sample method is + # exposed. + self.assertEqual(None, session.run(update_op)) + + +class InsertOpTest(tf.test.TestCase): + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls._server = make_server() + cls._client = reverb_client.Client(f'localhost:{cls._server.port}') + + def tearDown(self): + super().tearDown() + self._client.reset('dist') + self._client.reset('dist2') + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + cls._server.stop() + + def setUp(self): + super().setUp() + self.data = [tf.constant([1, 2, 3], dtype=tf.int8)] + + def testValidatesTablesHasRank1(self): + client = tf_client.TFClient(self._client.server_address) + priorities = tf.constant([1.0], dtype=tf.float64) + + # Works for rank 1. + client.insert(self.data, tf.constant(['dist']), priorities) + + # Does not work for rank > 1. + with self.assertRaises(ValueError): + client.insert(self.data, tf.constant([['dist']]), priorities) + + # Does not work for rank < 1. + with self.assertRaises(ValueError): + client.insert(self.data, tf.constant('dist'), priorities) + + def testValidatesTablesDtype(self): + client = tf_client.TFClient(self._client.server_address) + with self.assertRaises(ValueError): + client.insert(self.data, tf.constant([1]), + tf.constant([1.0], dtype=tf.float64)) + + def testValidatesPrioritiesHasRank1(self): + client = tf_client.TFClient(self._client.server_address) + data = [tf.constant([1, 2])] + tables = tf.constant(['dist']) + + # Works for rank 1. + client.insert(data, tables, tf.constant([1.0], dtype=tf.float64)) + + # Does not work for rank > 1. + with self.assertRaises(ValueError): + client.insert(data, tables, tf.constant([[1.0]], dtype=tf.float64)) + + # Does not work for rank < 1. + with self.assertRaises(ValueError): + client.insert(data, tables, tf.constant(1.0, dtype=tf.float64)) + + def testValidatesPrioritiesDtype(self): + client = tf_client.TFClient(self._client.server_address) + with self.assertRaises(ValueError): + client.insert(self.data, tf.constant(['dist']), + tf.constant([1.0], dtype=tf.float32)) + + def testValidatesTablesAndPrioritiesHaveSameShape(self): + client = tf_client.TFClient(self._client.server_address) + with self.assertRaises(ValueError): + client.insert(self.data, tf.constant(['dist', 'dist2']), + tf.constant([1.0], dtype=tf.float64)) + + def testInsertSingleTable(self): + with self.session() as session: + client = tf_client.TFClient(self._client.server_address) + insert_op = client.insert( + data=[tf.constant([1, 2, 3], dtype=tf.int8)], + tables=tf.constant(['dist']), + priorities=tf.constant([1.0], dtype=tf.float64)) + sample_op = client.sample('dist', [tf.int8]) + + # Check that insert op succeeds. + self.assertEqual(None, session.run(insert_op)) + + # Check that the sampled data matches the inserted. 
+ sample = session.run(sample_op) + self.assertLen(sample.data, 1) + np.testing.assert_equal( + np.array([1, 2, 3], dtype=np.int8), sample.data[0]) + + def testInsertMultiTable(self): + with self.session() as session: + client = tf_client.TFClient(self._client.server_address) + insert_op = client.insert( + data=[tf.constant([1, 2, 3], dtype=tf.int8)], + tables=tf.constant(['dist', 'dist2']), + priorities=tf.constant([1.0, 2.0], dtype=tf.float64)) + + sample_ops = [ + client.sample('dist', [tf.int8]), + client.sample('dist2', [tf.int8]) + ] + + # Check that insert op succeeds. + self.assertEqual(None, session.run(insert_op)) + + # Check that the sampled data matches the inserted in all tables. + for sample_op in sample_ops: + sample = session.run(sample_op) + self.assertLen(sample.data, 1) + np.testing.assert_equal( + np.array([1, 2, 3], dtype=np.int8), sample.data[0]) + + +class DatasetTest(tf.test.TestCase, parameterized.TestCase): + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls._server = make_server() + cls._client = reverb_client.Client(f'localhost:{cls._server.port}') + + def tearDown(self): + super().tearDown() + self._client.reset('dist') + self._client.reset('signatured') + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + cls._server.stop() + + def _PopulateReplay(self, sequence_length=100, max_time_steps=None): + max_time_steps = max_time_steps or sequence_length + with self._client.writer(max_time_steps) as writer: + for i in range(1000): + writer.append_timestep([np.zeros((3, 3), dtype=np.float32)]) + if i % 5 == 0 and i >= sequence_length: + writer.create_prioritized_item( + table='dist', num_timesteps=sequence_length, priority=1) + writer.create_prioritized_item( + table='signatured', num_timesteps=sequence_length, priority=1) + + def _SampleFrom(self, dataset, num_samples): + iterator = dataset.make_initializable_iterator() + dataset_item = iterator.get_next() + self.evaluate(iterator.initializer) + return [self.evaluate(dataset_item) for _ in range(num_samples)] + + @parameterized.named_parameters( + { + 'testcase_name': 'default_values', + }, + { + 'testcase_name': 'num_workers_per_iterator_is_0', + 'num_workers_per_iterator': 0, + 'want_error': ValueError, + }, + { + 'testcase_name': 'num_workers_per_iterator_is_1', + 'num_workers_per_iterator': 1, + }, + { + 'testcase_name': 'num_workers_per_iterator_is_minus_1', + 'num_workers_per_iterator': -1, + }, + { + 'testcase_name': 'num_workers_per_iterator_is_minus_2', + 'num_workers_per_iterator': -2, + 'want_error': ValueError, + }, + { + 'testcase_name': 'max_samples_per_stream_is_0', + 'max_samples_per_stream': 0, + 'want_error': ValueError, + }, + { + 'testcase_name': 'max_samples_per_stream_is_1', + 'max_samples_per_stream': 1, + }, + { + 'testcase_name': 'max_samples_per_stream_is_minus_1', + 'max_samples_per_stream': -1, + }, + { + 'testcase_name': 'max_samples_per_stream_is_minus_2', + 'num_workers_per_iterator': -2, + 'want_error': ValueError, + }, + { + 'testcase_name': 'capacity_is_0', + 'capacity': 0, + 'want_error': ValueError, + }, + { + 'testcase_name': 'capacity_is_1', + 'capacity': 1, + }, + { + 'testcase_name': 'capacity_is_minus_1', + 'capacity': -1, + 'want_error': ValueError, + }, + ) + def testSamplerParametersValidation(self, **kwargs): + client = tf_client.TFClient(self._client.server_address) + dtypes = (tf.float32,) + shapes = (tf.TensorShape([3, 3]),) + + if 'want_error' in kwargs: + error = kwargs.pop('want_error') + with self.assertRaises(error): + 
client.dataset('dist', dtypes, shapes, **kwargs) + else: + client.dataset('dist', dtypes, shapes, **kwargs) + + def testIterate(self): + self._PopulateReplay() + + client = tf_client.TFClient(self._client.server_address) + dataset = client.dataset( + table='dist', dtypes=(tf.float32,), shapes=(tf.TensorShape([3, 3]),)) + got = self._SampleFrom(dataset, 10) + for sample in got: + self.assertIsInstance(sample, replay_sample.ReplaySample) + # A single sample is returned so the key should be a scalar int64. + self.assertIsInstance(sample.info.key, np.uint64) + np.testing.assert_array_equal(sample.data[0], + np.zeros((3, 3), dtype=np.float32)) + + def testInconsistentSignatureSize(self): + self._PopulateReplay() + + client = tf_client.TFClient(self._client.server_address) + dataset = client.dataset( + table='signatured', + dtypes=(tf.float32, tf.float64), + shapes=(tf.TensorShape([3, 3]), tf.TensorShape([]))) + with self.assertRaisesWithPredicateMatch( + tf.errors.InvalidArgumentError, + r'Inconsistent number of tensors requested from table \'signatured\'. ' + r'Requested 5 tensors, but table signature shows 4 tensors.'): + self._SampleFrom(dataset, 10) + + def testIncompatibleSignatureDtype(self): + self._PopulateReplay() + + client = tf_client.TFClient(self._client.server_address) + dataset = client.dataset( + table='signatured', + dtypes=(tf.int64,), + shapes=(tf.TensorShape([3, 3]),)) + with self.assertRaisesWithPredicateMatch( + tf.errors.InvalidArgumentError, + r'Requested incompatible tensor at flattened index 3 from table ' + r'\'signatured\'. Requested \(dtype, shape\): \(int64, \[3,3\]\). ' + r'Signature \(dtype, shape\): \(float, \[\?,\?\]\)'): + self._SampleFrom(dataset, 10) + + def testIncompatibleSignatureShape(self): + self._PopulateReplay() + + client = tf_client.TFClient(self._client.server_address) + dataset = client.dataset( + table='signatured', dtypes=(tf.float32,), shapes=(tf.TensorShape([3]),)) + with self.assertRaisesWithPredicateMatch( + tf.errors.InvalidArgumentError, + r'Requested incompatible tensor at flattened index 3 from table ' + r'\'signatured\'. Requested \(dtype, shape\): \(float, \[3\]\). ' + r'Signature \(dtype, shape\): \(float, \[\?,\?\]\)'): + self._SampleFrom(dataset, 10) + + @parameterized.parameters([1], [3], [10]) + def testIncompatibleShapeWhenUsingSequenceLength(self, sequence_length): + client = tf_client.TFClient(self._client.server_address) + with self.assertRaises(ValueError): + client.dataset( + table='dist', + dtypes=(tf.float32,), + shapes=(tf.TensorShape([sequence_length + 1, 3, 3]),), + emit_timesteps=False, + sequence_length=sequence_length) + + @parameterized.parameters( + ('dist', 1, 1), + ('dist', 1, 3), + ('dist', 3, 3), + ('dist', 3, 5), + ('dist', 10, 10), + ('dist', 10, 11), + ('signatured', 1, 1), + ('signatured', 3, 3), + ('signatured', 3, 5), + ('signatured', 10, 10), + ) + def testIterateWithSequenceLength( + self, table_name, sequence_length, max_time_steps): + # Also ensure we get sequence_length-shaped outputs when + # writers' max_time_steps != sequence_length. 
+ self._PopulateReplay(sequence_length, max_time_steps=max_time_steps) + + client = tf_client.TFClient(self._client.server_address) + dataset = client.dataset( + table=table_name, + dtypes=(tf.float32,), + shapes=(tf.TensorShape([sequence_length, 3, 3]),), + emit_timesteps=False, + sequence_length=sequence_length) + + got = self._SampleFrom(dataset, 10) + for sample in got: + self.assertIsInstance(sample, replay_sample.ReplaySample) + + # The keys and data should be batched up by the sequence length. + self.assertEqual(sample.info.key.shape, (sequence_length,)) + np.testing.assert_array_equal( + sample.data[0], np.zeros((sequence_length, 3, 3), dtype=np.float32)) + + @parameterized.parameters( + ('dist', 1), + ('dist', 3), + ('dist', 10), + ('signatured', 1), + ('signatured', 3), + ('signatured', 10), + ) + def testIterateWithUnknownSequenceLength(self, table_name, sequence_length): + self._PopulateReplay(sequence_length) + + client = tf_client.TFClient(self._client.server_address) + dataset = client.dataset( + table=table_name, + dtypes=(tf.float32,), + shapes=(tf.TensorShape([None, 3, 3]),), + emit_timesteps=False, + sequence_length=None) + + # Check the shape of the items. + iterator = dataset.make_initializable_iterator() + dataset_item = iterator.get_next() + self.assertIsNone(dataset_item.info.key.shape.as_list()[0], None) + self.assertIsNone(dataset_item.data[0].shape.as_list()[0], None) + + # Verify that once evaluated, the samples has the expected length. + got = self._SampleFrom(dataset, 10) + for sample in got: + self.assertIsInstance(sample, replay_sample.ReplaySample) + + # The keys and data should be batched up by the sequence length. + self.assertEqual(sample.info.key.shape, (sequence_length,)) + np.testing.assert_array_equal( + sample.data[0], np.zeros((sequence_length, 3, 3), dtype=np.float32)) + + @parameterized.parameters( + ('dist', 1, 2), + ('dist', 2, 1), + ('signatured', 1, 2), + ('signatured', 2, 1), + ) + def testValidatesSequenceLengthWhenTimestepsEmitted(self, table_name, + actual_sequence_length, + provided_sequence_length): + self._PopulateReplay(actual_sequence_length) + + client = tf_client.TFClient(self._client.server_address) + dataset = client.dataset( + table=table_name, + dtypes=(tf.float32,), + shapes=(tf.TensorShape([provided_sequence_length, 3, 3]),), + emit_timesteps=True, + sequence_length=provided_sequence_length) + + with self.assertRaises(tf.errors.InvalidArgumentError): + self._SampleFrom(dataset, 10) + + @parameterized.named_parameters( + dict(testcase_name='TableDist', table_name='dist'), + dict(testcase_name='TableSignatured', table_name='signatured')) + def testIterateBatched(self, table_name): + self._PopulateReplay() + + client = tf_client.TFClient(self._client.server_address) + dataset = client.dataset( + table=table_name, + dtypes=(tf.float32,), + shapes=(tf.TensorShape([3, 3]),)) + dataset = dataset.batch(2, True) + + got = self._SampleFrom(dataset, 10) + for sample in got: + self.assertIsInstance(sample, replay_sample.ReplaySample) + + # The keys should be batched up like the data. 
+ self.assertEqual(sample.info.key.shape, (2,)) + + np.testing.assert_array_equal(sample.data[0], + np.zeros((2, 3, 3), dtype=np.float32)) + + def testIterateNestedAndBatched(self): + with self._client.writer(100) as writer: + for i in range(1000): + writer.append_timestep({ + 'observation': { + 'data': np.zeros((3, 3), dtype=np.float32), + 'extras': [ + np.int64(10), + np.ones([1], dtype=np.int32), + ], + }, + 'reward': np.zeros((10, 10), dtype=np.float32), + }) + if i % 5 == 0 and i >= 100: + writer.create_prioritized_item( + table='dist', num_timesteps=100, priority=1) + + client = tf_client.TFClient(self._client.server_address) + dataset = client.dataset( + table='dist', + dtypes=(((tf.float32), (tf.int64, tf.int32)), tf.float32), + shapes=((tf.TensorShape([3, 3]), (tf.TensorShape(None), + tf.TensorShape([1]))), + tf.TensorShape([10, 10])), + ) + dataset = dataset.batch(3) + + structure = { + 'observation': { + 'data': + tf.TensorSpec([3, 3], tf.float32), + 'extras': [ + tf.TensorSpec([], tf.int64), + tf.TensorSpec([1], tf.int32), + ], + }, + 'reward': tf.TensorSpec([], tf.int64), + } + + got = self._SampleFrom(dataset, 10) + self.assertLen(got, 10) + for sample in got: + self.assertIsInstance(sample, replay_sample.ReplaySample) + + transition = tree.unflatten_as(structure, tree.flatten(sample.data)) + np.testing.assert_array_equal(transition['observation']['data'], + np.zeros([3, 3, 3], dtype=np.float32)) + np.testing.assert_array_equal(transition['observation']['extras'][0], + np.ones([3], dtype=np.int64) * 10) + np.testing.assert_array_equal(transition['observation']['extras'][1], + np.ones([3, 1], dtype=np.int32)) + np.testing.assert_array_equal(transition['reward'], + np.zeros([3, 10, 10], dtype=np.float32)) + + def testMultipleIterators(self): + with self._client.writer(100) as writer: + for i in range(10): + writer.append_timestep([np.ones((81, 81), dtype=np.float32) * i]) + writer.create_prioritized_item(table='dist', num_timesteps=10, priority=1) + + trajectory_length = 5 + batch_size = 3 + + client = tf_client.TFClient(self._client.server_address) + dataset = client.dataset( + table='dist', dtypes=(tf.float32,), shapes=(tf.TensorShape([81, 81]),)) + dataset = dataset.batch(trajectory_length) + + iterators = [ + dataset.make_initializable_iterator() for _ in range(batch_size) + ] + items = tf.stack( + [tf.squeeze(iterator.get_next().data) for iterator in iterators]) + + with self.session() as session: + session.run([iterator.initializer for iterator in iterators]) + got = session.run(items) + self.assertEqual(got.shape, (batch_size, trajectory_length, 81, 81)) + + want = np.array( + [[np.ones([81, 81]) * i for i in range(trajectory_length)]] * + batch_size) + np.testing.assert_array_equal(got, want) + + def testIterateOverBlobs(self): + for _ in range(10): + self._client.insert((np.ones([3, 3], dtype=np.int32)), {'dist': 1}) + + client = tf_client.TFClient(self._client.server_address) + dataset = client.dataset( + table='dist', dtypes=(tf.int32,), shapes=(tf.TensorShape([3, 3]),)) + + got = self._SampleFrom(dataset, 20) + self.assertLen(got, 20) + for sample in got: + self.assertIsInstance(sample, replay_sample.ReplaySample) + self.assertIsInstance(sample.info.key, np.uint64) + self.assertIsInstance(sample.info.probability, np.float64) + np.testing.assert_array_equal(sample.data[0], + np.ones((3, 3), dtype=np.int32)) + + def testIterateOverBatchedBlobs(self): + for _ in range(10): + self._client.insert((np.ones([3, 3], dtype=np.int32)), {'dist': 1}) + + client = 
tf_client.TFClient(self._client.server_address) + dataset = client.dataset( + table='dist', dtypes=(tf.int32,), shapes=(tf.TensorShape([3, 3]),)) + + dataset = dataset.batch(5) + + got = self._SampleFrom(dataset, 20) + self.assertLen(got, 20) + for sample in got: + self.assertIsInstance(sample, replay_sample.ReplaySample) + self.assertEqual(sample.info.key.shape, (5,)) + np.testing.assert_array_equal(sample.data[0], + np.ones((5, 3, 3), dtype=np.int32)) + + def testConvertsSpecListsIntoTuples(self): + for _ in range(10): + data = [ + (np.ones([1, 1], dtype=np.int32),), + [ + np.ones([3, 3], dtype=np.int8), + (np.ones([2, 2], dtype=np.float64),) + ], + ] + self._client.insert(data, {'dist': 1}) + + client = tf_client.TFClient(self._client.server_address) + dataset = client.dataset( + table='dist', + dtypes=[ + (tf.int32,), + [ + tf.int8, + (tf.float64,), + ], + ], + shapes=[ + (tf.TensorShape([1, 1]),), + [ + tf.TensorShape([3, 3]), + (tf.TensorShape([2, 2]),), + ], + ]) + + got = self._SampleFrom(dataset, 10) + + for sample in got: + self.assertIsInstance(sample, replay_sample.ReplaySample) + self.assertIsInstance(sample.info.key, np.uint64) + tree.assert_same_structure(sample.data, ( + (None,), + ( + None, + (None,), + ), + )) + + def testSessionIsClosedWhileOpPending(self): + client = tf_client.TFClient(self._client.server_address) + dataset = client.dataset( + table='dist', dtypes=tf.float32, shapes=tf.TensorShape([])) + + iterator = dataset.make_initializable_iterator() + item = iterator.get_next() + + def _session_closer(sess, wait_time_secs): + def _fn(): + time.sleep(wait_time_secs) + sess.close() + + return _fn + + with self.session() as sess: + sess.run(iterator.initializer) + thread = threading.Thread(target=_session_closer(sess, 3)) + thread.start() + with self.assertRaises(tf.errors.CancelledError): + sess.run(item) + + +if __name__ == '__main__': + tf.disable_eager_execution() + tf.test.main() diff --git a/third_party/BUILD b/third_party/BUILD new file mode 100644 index 0000000..fe756e1 --- /dev/null +++ b/third_party/BUILD @@ -0,0 +1 @@ +licenses(["notice"]) # Apache 2.0 diff --git a/third_party/__init__.py b/third_party/__init__.py new file mode 100644 index 0000000..ac5cd1f --- /dev/null +++ b/third_party/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2019 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + diff --git a/third_party/auditwheel.sh b/third_party/auditwheel.sh new file mode 100755 index 0000000..d6320c7 --- /dev/null +++ b/third_party/auditwheel.sh @@ -0,0 +1,25 @@ +#!/bin/bash +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +TF_SHARED_LIBRARY_NAME=$(grep -r TF_SHARED_LIBRARY_NAME .bazelrc | awk -F= '{print$2}') + +POLICY_JSON=$(find / -name policy.json) + +sed -i "s/libresolv.so.2\"/libresolv.so.2\", $TF_SHARED_LIBRARY_NAME/g" $POLICY_JSON + +cat $POLICY_JSON + +auditwheel $@ diff --git a/third_party/opensource_only.files b/third_party/opensource_only.files new file mode 100644 index 0000000..b472d39 --- /dev/null +++ b/third_party/opensource_only.files @@ -0,0 +1,7 @@ +reverb/cc/platform/default/BUILD +third_party/BUILD +third_party/protobuf.BUILD +third_party/pybind11.BUILD +third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/BUILD +third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/cc_toolchain_config.bzl +third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/dummy_toolchain.bzl \ No newline at end of file diff --git a/third_party/protobuf.BUILD b/third_party/protobuf.BUILD new file mode 100644 index 0000000..b8e9496 --- /dev/null +++ b/third_party/protobuf.BUILD @@ -0,0 +1,40 @@ +_CHECK_VERSION = """ +PROTOC_VERSION=$$($(location @protobuf_protoc//:protoc_bin) --version \ + | cut -d' ' -f2 | sed -e 's/\\./ /g') +PROTOC_VERSION=$$(printf '%d%03d%03d' $${PROTOC_VERSION}) +TF_PROTO_VERSION=$$(grep '#define PROTOBUF_MIN_PROTOC_VERSION' \ + $(location tf_includes/google/protobuf/port_def.inc) | cut -d' ' -f3) +if [ "$${PROTOC_VERSION}" -ne "$${TF_PROTO_VERSION}" ]; then + echo !!!!!!!!!!!!!!!!!!!!!!!!!!!!! 1>&2 + echo Your protoc version does not match the tensorflow proto header \ + required version: "$${PROTOC_VERSION}" vs. "$${TF_PROTO_VERSION}" 1>&2 + echo Please update the PROTOC_VERSION in your WORKSPACE file. 1>&2 + echo !!!!!!!!!!!!!!!!!!!!!!!!!!!!! 
1>&2 + false +else + touch $@ +fi +""" + +genrule( + name = "compare_protobuf_version", + outs = ["versions_compared"], + srcs = [ + "tf_includes/google/protobuf/port_def.inc", + ], + tools = ["@protobuf_protoc//:protoc_bin"], + cmd = _CHECK_VERSION, +) + +cc_library( + name = "includes", + data = [":versions_compared"], + hdrs = glob([ + "tf_includes/google/protobuf/*.h", + "tf_includes/google/protobuf/*.inc", + "tf_includes/google/protobuf/**/*.h", + "tf_includes/google/protobuf/**/*.inc", + ]), + includes = ["tf_includes"], + visibility = ["//visibility:public"], +) diff --git a/third_party/pybind11.BUILD b/third_party/pybind11.BUILD new file mode 100644 index 0000000..8bee1d6 --- /dev/null +++ b/third_party/pybind11.BUILD @@ -0,0 +1,24 @@ +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "pybind11", + hdrs = glob( + include = [ + "include/pybind11/*.h", + "include/pybind11/detail/*.h", + ], + exclude = [ + "include/pybind11/common.h", + "include/pybind11/eigen.h", + ], + ), + copts = [ + "-fexceptions", + "-Wno-undefined-inline", + "-Wno-pragma-once-outside-header", + ], + includes = ["include"], + deps = [ + "@python_includes", + ], +) diff --git a/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/BUILD b/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/BUILD new file mode 100755 index 0000000..305dfbf --- /dev/null +++ b/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/BUILD @@ -0,0 +1,121 @@ +# Copyright 2016 The Bazel Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This becomes the BUILD file for @local_config_cc// under non-FreeBSD unixes. + +load(":cc_toolchain_config.bzl", "cc_toolchain_config") + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +cc_library( + name = "malloc", +) + +filegroup( + name = "empty", + srcs = [], +) + +filegroup( + name = "cc_wrapper", + srcs = ["cc_wrapper.sh"], +) + +filegroup( + name = "compiler_deps", + srcs = glob(["extra_tools/**"]) + [":empty"], +) + +# This is the entry point for --crosstool_top. Toolchains are found +# by lopping off the name of --crosstool_top and searching for +# the "${CPU}" entry in the toolchains attribute. 
+cc_toolchain_suite( + name = "toolchain", + toolchains = { + "k8|/dt7/usr/bin/gcc": ":cc-compiler-k8", + "k8": ":cc-compiler-k8", + "armeabi-v7a|compiler": ":cc-compiler-armeabi-v7a", + "armeabi-v7a": ":cc-compiler-armeabi-v7a", + }, +) + +cc_toolchain( + name = "cc-compiler-k8", + all_files = ":compiler_deps", + ar_files = ":empty", + as_files = ":empty", + compiler_files = ":compiler_deps", + dwp_files = ":empty", + linker_files = ":compiler_deps", + objcopy_files = ":empty", + strip_files = ":empty", + supports_param_files = 1, + toolchain_config = ":linux_gnu_x86", + toolchain_identifier = "linux_gnu_x86", +) + +cc_toolchain_config( + name = "linux_gnu_x86", + compiler = "/dt7/usr/bin/gcc", + cpu = "k8", +) + +toolchain( + name = "cc-toolchain-k8", + exec_compatible_with = [ + # TODO(b/154931569): add autodiscovered constraints for host CPU and OS. + ], + target_compatible_with = [ + # TODO(b/154931569): add autodiscovered constraints for host CPU and OS. + ], + toolchain = ":cc-compiler-k8", + toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", +) + +# Android tooling requires a default toolchain for the armeabi-v7a cpu. +cc_toolchain( + name = "cc-compiler-armeabi-v7a", + all_files = ":empty", + ar_files = ":empty", + as_files = ":empty", + compiler_files = ":empty", + dwp_files = ":empty", + linker_files = ":empty", + objcopy_files = ":empty", + strip_files = ":empty", + supports_param_files = 1, + toolchain_config = ":stub_armeabi-v7a", + toolchain_identifier = "stub_armeabi-v7a", +) + +cc_toolchain_config( + name = "stub_armeabi-v7a", + compiler = "compiler", + cpu = "armeabi-v7a", +) + +toolchain( + name = "cc-toolchain-armeabi-v7a", + exec_compatible_with = [ + # TODO(b/154931569): add autodiscovered constraints for host CPU and OS. + ], + target_compatible_with = [ + "@bazel_tools//platforms:arm", + "@bazel_tools//platforms:android", + ], + toolchain = ":cc-compiler-armabi-v7a", + toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", +) diff --git a/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/WORKSPACE b/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/WORKSPACE new file mode 100644 index 0000000..bc05b4c --- /dev/null +++ b/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/WORKSPACE @@ -0,0 +1,2 @@ +# DO NOT EDIT: automatically generated WORKSPACE file for cc_autoconf rule +workspace(name = "local_config_cc") diff --git a/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/cc_toolchain_config.bzl b/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/cc_toolchain_config.bzl new file mode 100755 index 0000000..12f087e --- /dev/null +++ b/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/cc_toolchain_config.bzl @@ -0,0 +1,1732 @@ +# Copyright 2019 The Bazel Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""A Starlark cc_toolchain configuration rule""" + +load( + "@bazel_tools//tools/cpp:cc_toolchain_config_lib.bzl", + "action_config", + "artifact_name_pattern", + "env_entry", + "env_set", + "feature", + "feature_set", + "flag_group", + "flag_set", + "make_variable", # @unused + "tool", + "tool_path", + "variable_with_value", + "with_feature_set", +) +load("@bazel_tools//tools/build_defs/cc:action_names.bzl", "ACTION_NAMES") + +all_compile_actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.clif_match, + ACTION_NAMES.lto_backend, +] + +all_cpp_compile_actions = [ + ACTION_NAMES.cpp_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.clif_match, +] + +preprocessor_compile_actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.clif_match, +] + +codegen_compile_actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.lto_backend, +] + +all_link_actions = [ + ACTION_NAMES.cpp_link_executable, + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, +] + +def _windows_msvc_impl(ctx): + toolchain_identifier = "msvc_x64" + host_system_name = "local" + target_system_name = "local" + target_cpu = "x64_windows" + target_libc = "msvcrt" + compiler = "msvc-cl" + abi_version = "local" + abi_libc_version = "local" + cc_target_os = None + builtin_sysroot = None + + cxx_builtin_include_directories = [ + "/dt7/usr/lib/gcc/x86_64-pc-linux-gnu/7/include", + "/dt7/usr/lib/gcc/x86_64-pc-linux-gnu/7/include-fixed", + "/dt7/usr/include", + "/dt7/usr/include/c++/7", + "/dt7/usr/include/c++/7/x86_64-pc-linux-gnu", + "/dt7/usr/include/c++/7/backward", + ] + + cpp_link_nodeps_dynamic_library_action = action_config( + action_name = ACTION_NAMES.cpp_link_nodeps_dynamic_library, + implies = [ + "nologo", + "shared_flag", + "linkstamps", + "output_execpath_flags", + "input_param_flags", + "user_link_flags", + "default_link_flags", + "linker_subsystem_flag", + "linker_param_file", + "msvc_env", + "no_stripping", + "has_configured_linker_path", + "def_file", + ], + tools = [tool(path = "")], + ) + + cpp_link_static_library_action = action_config( + action_name = ACTION_NAMES.cpp_link_static_library, + implies = [ + "nologo", + "archiver_flags", + "input_param_flags", + "linker_param_file", + "msvc_env", + ], + tools = [tool(path = "")], + ) + + assemble_action = action_config( + action_name = ACTION_NAMES.assemble, + implies = [ + "compiler_input_flags", + "compiler_output_flags", + "nologo", + "msvc_env", + "sysroot", + ], + tools = [tool(path = "")], + ) + + preprocess_assemble_action = action_config( + action_name = ACTION_NAMES.preprocess_assemble, + implies = [ + "compiler_input_flags", + "compiler_output_flags", + "nologo", + "msvc_env", + "sysroot", + ], + tools = [tool(path = "")], + ) + + c_compile_action = action_config( + action_name = ACTION_NAMES.c_compile, + implies = [ + "compiler_input_flags", + "compiler_output_flags", + 
"default_compile_flags", + "nologo", + "msvc_env", + "parse_showincludes", + "user_compile_flags", + "sysroot", + "unfiltered_compile_flags", + ], + tools = [tool(path = "")], + ) + + cpp_compile_action = action_config( + action_name = ACTION_NAMES.cpp_compile, + implies = [ + "compiler_input_flags", + "compiler_output_flags", + "default_compile_flags", + "nologo", + "msvc_env", + "parse_showincludes", + "user_compile_flags", + "sysroot", + "unfiltered_compile_flags", + ], + tools = [tool(path = "")], + ) + + cpp_link_executable_action = action_config( + action_name = ACTION_NAMES.cpp_link_executable, + implies = [ + "nologo", + "linkstamps", + "output_execpath_flags", + "input_param_flags", + "user_link_flags", + "default_link_flags", + "linker_subsystem_flag", + "linker_param_file", + "msvc_env", + "no_stripping", + ], + tools = [tool(path = "")], + ) + + cpp_link_dynamic_library_action = action_config( + action_name = ACTION_NAMES.cpp_link_dynamic_library, + implies = [ + "nologo", + "shared_flag", + "linkstamps", + "output_execpath_flags", + "input_param_flags", + "user_link_flags", + "default_link_flags", + "linker_subsystem_flag", + "linker_param_file", + "msvc_env", + "no_stripping", + "has_configured_linker_path", + "def_file", + ], + tools = [tool(path = "")], + ) + + action_configs = [ + assemble_action, + preprocess_assemble_action, + c_compile_action, + cpp_compile_action, + cpp_link_executable_action, + cpp_link_dynamic_library_action, + cpp_link_nodeps_dynamic_library_action, + cpp_link_static_library_action, + ] + + msvc_link_env_feature = feature( + name = "msvc_link_env", + env_sets = [ + env_set( + actions = all_link_actions + + [ACTION_NAMES.cpp_link_static_library], + env_entries = [env_entry(key = "LIB", value = "")], + ), + ], + ) + + shared_flag_feature = feature( + name = "shared_flag", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ], + flag_groups = [flag_group(flags = ["/DLL"])], + ), + ], + ) + + determinism_feature = feature( + name = "determinism", + enabled = True, + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [ + flag_group( + flags = [ + "/wd4117", + "-D__DATE__=\"redacted\"", + "-D__TIMESTAMP__=\"redacted\"", + "-D__TIME__=\"redacted\"", + ], + ), + ], + ), + ], + ) + + sysroot_feature = feature( + name = "sysroot", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.cpp_link_executable, + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ], + flag_groups = [ + flag_group( + flags = ["--sysroot=%{sysroot}"], + iterate_over = "sysroot", + expand_if_available = "sysroot", + ), + ], + ), + ], + ) + + unfiltered_compile_flags_feature = feature( + name = "unfiltered_compile_flags", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ], + flag_groups = [ + flag_group( + flags = ["%{unfiltered_compile_flags}"], + iterate_over = "unfiltered_compile_flags", + expand_if_available = "unfiltered_compile_flags", + ), + ], + ), + ], + ) + + copy_dynamic_libraries_to_binary_feature = 
feature(name = "copy_dynamic_libraries_to_binary") + + input_param_flags_feature = feature( + name = "input_param_flags", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ], + flag_groups = [ + flag_group( + flags = ["/IMPLIB:%{interface_library_output_path}"], + expand_if_available = "interface_library_output_path", + ), + ], + ), + flag_set( + actions = all_link_actions, + flag_groups = [ + flag_group( + flags = ["%{libopts}"], + iterate_over = "libopts", + expand_if_available = "libopts", + ), + ], + ), + flag_set( + actions = all_link_actions + + [ACTION_NAMES.cpp_link_static_library], + flag_groups = [ + flag_group( + iterate_over = "libraries_to_link", + flag_groups = [ + flag_group( + iterate_over = "libraries_to_link.object_files", + flag_groups = [flag_group(flags = ["%{libraries_to_link.object_files}"])], + expand_if_equal = variable_with_value( + name = "libraries_to_link.type", + value = "object_file_group", + ), + ), + flag_group( + flag_groups = [flag_group(flags = ["%{libraries_to_link.name}"])], + expand_if_equal = variable_with_value( + name = "libraries_to_link.type", + value = "object_file", + ), + ), + flag_group( + flag_groups = [flag_group(flags = ["%{libraries_to_link.name}"])], + expand_if_equal = variable_with_value( + name = "libraries_to_link.type", + value = "interface_library", + ), + ), + flag_group( + flag_groups = [ + flag_group( + flags = ["%{libraries_to_link.name}"], + expand_if_false = "libraries_to_link.is_whole_archive", + ), + flag_group( + flags = ["/WHOLEARCHIVE:%{libraries_to_link.name}"], + expand_if_true = "libraries_to_link.is_whole_archive", + ), + ], + expand_if_equal = variable_with_value( + name = "libraries_to_link.type", + value = "static_library", + ), + ), + ], + expand_if_available = "libraries_to_link", + ), + ], + ), + ], + ) + + fastbuild_feature = feature( + name = "fastbuild", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/Od", "/Z7"])], + ), + flag_set( + actions = all_link_actions, + flag_groups = [ + flag_group( + flags = ["", "/INCREMENTAL:NO"], + ), + ], + ), + ], + implies = ["generate_pdb_file"], + ) + + user_compile_flags_feature = feature( + name = "user_compile_flags", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ], + flag_groups = [ + flag_group( + flags = ["%{user_compile_flags}"], + iterate_over = "user_compile_flags", + expand_if_available = "user_compile_flags", + ), + ], + ), + ], + ) + + archiver_flags_feature = feature( + name = "archiver_flags", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.cpp_link_static_library], + flag_groups = [ + flag_group( + flags = ["/OUT:%{output_execpath}"], + expand_if_available = "output_execpath", + ), + ], + ), + ], + ) + + default_link_flags_feature = feature( + name = "default_link_flags", + enabled = True, + flag_sets = [ + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["/MACHINE:X64"])], + ), + ], + ) + + static_link_msvcrt_feature = feature(name = "static_link_msvcrt") + + dynamic_link_msvcrt_debug_feature = feature( + name = "dynamic_link_msvcrt_debug", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = 
["/MDd"])], + ), + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["/DEFAULTLIB:msvcrtd.lib"])], + ), + ], + requires = [feature_set(features = ["dbg"])], + ) + + dbg_feature = feature( + name = "dbg", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/Od", "/Z7"])], + ), + flag_set( + actions = all_link_actions, + flag_groups = [ + flag_group( + flags = ["", "/INCREMENTAL:NO"], + ), + ], + ), + ], + implies = ["generate_pdb_file"], + ) + + opt_feature = feature( + name = "opt", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/O2"])], + ), + ], + implies = ["frame_pointer"], + ) + + supports_interface_shared_libraries_feature = feature( + name = "supports_interface_shared_libraries", + enabled = True, + ) + + user_link_flags_feature = feature( + name = "user_link_flags", + flag_sets = [ + flag_set( + actions = all_link_actions, + flag_groups = [ + flag_group( + flags = ["%{user_link_flags}"], + iterate_over = "user_link_flags", + expand_if_available = "user_link_flags", + ), + ], + ), + ], + ) + + default_compile_flags_feature = feature( + name = "default_compile_flags", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.lto_backend, + ACTION_NAMES.clif_match, + ], + flag_groups = [ + flag_group( + flags = [ + "/DCOMPILER_MSVC", + "/DNOMINMAX", + "/D_WIN32_WINNT=0x0601", + "/D_CRT_SECURE_NO_DEPRECATE", + "/D_CRT_SECURE_NO_WARNINGS", + "/bigobj", + "/Zm500", + "/EHsc", + "/wd4351", + "/wd4291", + "/wd4250", + "/wd4996", + ], + ), + ], + ), + ], + ) + + msvc_compile_env_feature = feature( + name = "msvc_compile_env", + env_sets = [ + env_set( + actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ], + env_entries = [env_entry(key = "INCLUDE", value = "")], + ), + ], + ) + + preprocessor_defines_feature = feature( + name = "preprocessor_defines", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ], + flag_groups = [ + flag_group( + flags = ["/D%{preprocessor_defines}"], + iterate_over = "preprocessor_defines", + ), + ], + ), + ], + ) + + generate_pdb_file_feature = feature( + name = "generate_pdb_file", + requires = [ + feature_set(features = ["dbg"]), + feature_set(features = ["fastbuild"]), + ], + ) + + output_execpath_flags_feature = feature( + name = "output_execpath_flags", + flag_sets = [ + flag_set( + actions = all_link_actions, + flag_groups = [ + flag_group( + flags = ["/OUT:%{output_execpath}"], + expand_if_available = "output_execpath", + ), + ], + ), + ], + ) + + dynamic_link_msvcrt_no_debug_feature = feature( + name = "dynamic_link_msvcrt_no_debug", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/MD"])], + ), + flag_set( + actions = all_link_actions, + flag_groups = 
[flag_group(flags = ["/DEFAULTLIB:msvcrt.lib"])], + ), + ], + requires = [ + feature_set(features = ["fastbuild"]), + feature_set(features = ["opt"]), + ], + ) + + disable_assertions_feature = feature( + name = "disable_assertions", + enabled = True, + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/DNDEBUG"])], + with_features = [with_feature_set(features = ["opt"])], + ), + ], + ) + + has_configured_linker_path_feature = feature(name = "has_configured_linker_path") + + supports_dynamic_linker_feature = feature(name = "supports_dynamic_linker", enabled = True) + + no_stripping_feature = feature(name = "no_stripping") + + linker_param_file_feature = feature( + name = "linker_param_file", + flag_sets = [ + flag_set( + actions = all_link_actions + + [ACTION_NAMES.cpp_link_static_library], + flag_groups = [ + flag_group( + flags = ["@%{linker_param_file}"], + expand_if_available = "linker_param_file", + ), + ], + ), + ], + ) + + ignore_noisy_warnings_feature = feature( + name = "ignore_noisy_warnings", + enabled = True, + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.cpp_link_static_library], + flag_groups = [flag_group(flags = ["/ignore:4221"])], + ), + ], + ) + + no_legacy_features_feature = feature(name = "no_legacy_features") + + parse_showincludes_feature = feature( + name = "parse_showincludes", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_header_parsing, + ], + flag_groups = [flag_group(flags = ["/showIncludes"])], + ), + ], + ) + + static_link_msvcrt_no_debug_feature = feature( + name = "static_link_msvcrt_no_debug", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/MT"])], + ), + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["/DEFAULTLIB:libcmt.lib"])], + ), + ], + requires = [ + feature_set(features = ["fastbuild"]), + feature_set(features = ["opt"]), + ], + ) + + treat_warnings_as_errors_feature = feature( + name = "treat_warnings_as_errors", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/WX"])], + ), + ], + ) + + windows_export_all_symbols_feature = feature(name = "windows_export_all_symbols") + + no_windows_export_all_symbols_feature = feature(name = "no_windows_export_all_symbols") + + include_paths_feature = feature( + name = "include_paths", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ], + flag_groups = [ + flag_group( + flags = ["/I%{quote_include_paths}"], + iterate_over = "quote_include_paths", + ), + flag_group( + flags = ["/I%{include_paths}"], + iterate_over = "include_paths", + ), + flag_group( + flags = ["/I%{system_include_paths}"], + iterate_over = "system_include_paths", + ), + ], + ), + ], + ) + + linkstamps_feature = feature( + name = "linkstamps", + flag_sets = [ + flag_set( + actions = all_link_actions, + flag_groups = [ + flag_group( + flags = ["%{linkstamp_paths}"], + iterate_over = "linkstamp_paths", + expand_if_available = "linkstamp_paths", + ), + ], + ), + ], + ) + + targets_windows_feature = feature( + name = "targets_windows", + enabled = 
True, + implies = ["copy_dynamic_libraries_to_binary"], + ) + + linker_subsystem_flag_feature = feature( + name = "linker_subsystem_flag", + flag_sets = [ + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["/SUBSYSTEM:CONSOLE"])], + ), + ], + ) + + static_link_msvcrt_debug_feature = feature( + name = "static_link_msvcrt_debug", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/MTd"])], + ), + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["/DEFAULTLIB:libcmtd.lib"])], + ), + ], + requires = [feature_set(features = ["dbg"])], + ) + + frame_pointer_feature = feature( + name = "frame_pointer", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/Oy-"])], + ), + ], + ) + + compiler_output_flags_feature = feature( + name = "compiler_output_flags", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.assemble], + flag_groups = [ + flag_group( + flag_groups = [ + flag_group( + flags = ["/Fo%{output_file}", "/Zi"], + expand_if_available = "output_file", + expand_if_not_available = "output_assembly_file", + ), + ], + expand_if_not_available = "output_preprocess_file", + ), + ], + ), + flag_set( + actions = [ + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ], + flag_groups = [ + flag_group( + flag_groups = [ + flag_group( + flags = ["/Fo%{output_file}"], + expand_if_not_available = "output_preprocess_file", + ), + ], + expand_if_available = "output_file", + expand_if_not_available = "output_assembly_file", + ), + flag_group( + flag_groups = [ + flag_group( + flags = ["/Fa%{output_file}"], + expand_if_available = "output_assembly_file", + ), + ], + expand_if_available = "output_file", + ), + flag_group( + flag_groups = [ + flag_group( + flags = ["/P", "/Fi%{output_file}"], + expand_if_available = "output_preprocess_file", + ), + ], + expand_if_available = "output_file", + ), + ], + ), + ], + ) + + nologo_feature = feature( + name = "nologo", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.cpp_link_executable, + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ACTION_NAMES.cpp_link_static_library, + ], + flag_groups = [flag_group(flags = ["/nologo"])], + ), + ], + ) + + smaller_binary_feature = feature( + name = "smaller_binary", + enabled = True, + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/Gy", "/Gw"])], + with_features = [with_feature_set(features = ["opt"])], + ), + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["/OPT:ICF", "/OPT:REF"])], + with_features = [with_feature_set(features = ["opt"])], + ), + ], + ) + + compiler_input_flags_feature = feature( + name = "compiler_input_flags", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ], + flag_groups 
= [ + flag_group( + flags = ["/c", "%{source_file}"], + expand_if_available = "source_file", + ), + ], + ), + ], + ) + + def_file_feature = feature( + name = "def_file", + flag_sets = [ + flag_set( + actions = all_link_actions, + flag_groups = [ + flag_group( + flags = ["/DEF:%{def_file_path}", "/ignore:4070"], + expand_if_available = "def_file_path", + ), + ], + ), + ], + ) + + msvc_env_feature = feature( + name = "msvc_env", + env_sets = [ + env_set( + actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.cpp_link_executable, + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ACTION_NAMES.cpp_link_static_library, + ], + env_entries = [ + env_entry(key = "PATH", value = ""), + env_entry(key = "TMP", value = ""), + env_entry(key = "TEMP", value = ""), + ], + ), + ], + implies = ["msvc_compile_env", "msvc_link_env"], + ) + + features = [ + no_legacy_features_feature, + nologo_feature, + has_configured_linker_path_feature, + no_stripping_feature, + targets_windows_feature, + copy_dynamic_libraries_to_binary_feature, + default_compile_flags_feature, + msvc_env_feature, + msvc_compile_env_feature, + msvc_link_env_feature, + include_paths_feature, + preprocessor_defines_feature, + parse_showincludes_feature, + generate_pdb_file_feature, + shared_flag_feature, + linkstamps_feature, + output_execpath_flags_feature, + archiver_flags_feature, + input_param_flags_feature, + linker_subsystem_flag_feature, + user_link_flags_feature, + default_link_flags_feature, + linker_param_file_feature, + static_link_msvcrt_feature, + static_link_msvcrt_no_debug_feature, + dynamic_link_msvcrt_no_debug_feature, + static_link_msvcrt_debug_feature, + dynamic_link_msvcrt_debug_feature, + dbg_feature, + fastbuild_feature, + opt_feature, + frame_pointer_feature, + disable_assertions_feature, + determinism_feature, + treat_warnings_as_errors_feature, + smaller_binary_feature, + ignore_noisy_warnings_feature, + user_compile_flags_feature, + sysroot_feature, + unfiltered_compile_flags_feature, + compiler_output_flags_feature, + compiler_input_flags_feature, + def_file_feature, + windows_export_all_symbols_feature, + no_windows_export_all_symbols_feature, + supports_dynamic_linker_feature, + supports_interface_shared_libraries_feature, + ] + + artifact_name_patterns = [ + artifact_name_pattern( + category_name = "object_file", + prefix = "", + extension = ".obj", + ), + artifact_name_pattern( + category_name = "static_library", + prefix = "", + extension = ".lib", + ), + artifact_name_pattern( + category_name = "alwayslink_static_library", + prefix = "", + extension = ".lo.lib", + ), + artifact_name_pattern( + category_name = "executable", + prefix = "", + extension = ".exe", + ), + artifact_name_pattern( + category_name = "dynamic_library", + prefix = "", + extension = ".dll", + ), + artifact_name_pattern( + category_name = "interface_library", + prefix = "", + extension = ".if.lib", + ), + ] + + make_variables = [] + + tool_paths = [ + tool_path(name = "ar", path = ""), + tool_path(name = "ml", path = ""), + tool_path(name = "cpp", path = ""), + tool_path(name = "gcc", path = ""), + tool_path(name = "gcov", path = "wrapper/bin/msvc_nop.bat"), + tool_path(name = "ld", path = ""), + tool_path(name = "nm", path = "wrapper/bin/msvc_nop.bat"), + tool_path( + name = "objcopy", + path = 
"wrapper/bin/msvc_nop.bat", + ), + tool_path( + name = "objdump", + path = "wrapper/bin/msvc_nop.bat", + ), + tool_path( + name = "strip", + path = "wrapper/bin/msvc_nop.bat", + ), + ] + + return cc_common.create_cc_toolchain_config_info( + ctx = ctx, + features = features, + action_configs = action_configs, + artifact_name_patterns = artifact_name_patterns, + cxx_builtin_include_directories = cxx_builtin_include_directories, + toolchain_identifier = toolchain_identifier, + host_system_name = host_system_name, + target_system_name = target_system_name, + target_cpu = target_cpu, + target_libc = target_libc, + compiler = compiler, + abi_version = abi_version, + abi_libc_version = abi_libc_version, + tool_paths = tool_paths, + make_variables = make_variables, + builtin_sysroot = builtin_sysroot, + cc_target_os = None, + ) + +def _windows_msys_mingw_impl(ctx): + toolchain_identifier = "msys_x64_mingw" + host_system_name = "local" + target_system_name = "local" + target_cpu = "x64_windows" + target_libc = "mingw" + compiler = "mingw-gcc" + abi_version = "local" + abi_libc_version = "local" + cc_target_os = None + builtin_sysroot = None + action_configs = [] + + targets_windows_feature = feature( + name = "targets_windows", + implies = ["copy_dynamic_libraries_to_binary"], + enabled = True, + ) + + copy_dynamic_libraries_to_binary_feature = feature(name = "copy_dynamic_libraries_to_binary") + + gcc_env_feature = feature( + name = "gcc_env", + enabled = True, + env_sets = [ + env_set( + actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.cpp_link_executable, + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ACTION_NAMES.cpp_link_static_library, + ], + env_entries = [ + env_entry(key = "PATH", value = "NOT_USED"), + ], + ), + ], + ) + + msys_mingw_flags = [ + ] + msys_mingw_link_flags = [ + ] + + default_compile_flags_feature = feature( + name = "default_compile_flags", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.lto_backend, + ACTION_NAMES.clif_match, + ], + ), + flag_set( + actions = [ + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.lto_backend, + ACTION_NAMES.clif_match, + ], + flag_groups = ([flag_group(flags = msys_mingw_flags)] if msys_mingw_flags else []), + ), + ], + ) + + default_link_flags_feature = feature( + name = "default_link_flags", + enabled = True, + flag_sets = [ + flag_set( + actions = all_link_actions, + flag_groups = ([flag_group(flags = msys_mingw_link_flags)] if msys_mingw_link_flags else []), + ), + ], + ) + + supports_dynamic_linker_feature = feature(name = "supports_dynamic_linker", enabled = True) + + features = [ + targets_windows_feature, + copy_dynamic_libraries_to_binary_feature, + gcc_env_feature, + default_compile_flags_feature, + default_link_flags_feature, + supports_dynamic_linker_feature, + ] + + cxx_builtin_include_directories = [ + ] + + artifact_name_patterns = [ + artifact_name_pattern( + category_name = 
"executable", + prefix = "", + extension = ".exe", + ), + ] + + make_variables = [] + tool_paths = [ + ] + + return cc_common.create_cc_toolchain_config_info( + ctx = ctx, + features = features, + action_configs = action_configs, + artifact_name_patterns = artifact_name_patterns, + cxx_builtin_include_directories = cxx_builtin_include_directories, + toolchain_identifier = toolchain_identifier, + host_system_name = host_system_name, + target_system_name = target_system_name, + target_cpu = target_cpu, + target_libc = target_libc, + compiler = compiler, + abi_version = abi_version, + abi_libc_version = abi_libc_version, + tool_paths = tool_paths, + make_variables = make_variables, + builtin_sysroot = builtin_sysroot, + cc_target_os = cc_target_os, + ) + +def _armeabi_impl(ctx): + toolchain_identifier = "stub_armeabi-v7a" + host_system_name = "armeabi-v7a" + target_system_name = "armeabi-v7a" + target_cpu = "armeabi-v7a" + target_libc = "armeabi-v7a" + compiler = "compiler" + abi_version = "armeabi-v7a" + abi_libc_version = "armeabi-v7a" + cc_target_os = None + builtin_sysroot = None + action_configs = [] + + supports_pic_feature = feature(name = "supports_pic", enabled = True) + supports_dynamic_linker_feature = feature(name = "supports_dynamic_linker", enabled = True) + features = [supports_dynamic_linker_feature, supports_pic_feature] + + cxx_builtin_include_directories = [] + artifact_name_patterns = [] + make_variables = [] + + tool_paths = [ + tool_path(name = "ar", path = "/bin/false"), + tool_path(name = "compat-ld", path = "/bin/false"), + tool_path(name = "cpp", path = "/bin/false"), + tool_path(name = "dwp", path = "/bin/false"), + tool_path(name = "gcc", path = "/bin/false"), + tool_path(name = "gcov", path = "/bin/false"), + tool_path(name = "ld", path = "/bin/false"), + tool_path(name = "nm", path = "/bin/false"), + tool_path(name = "objcopy", path = "/bin/false"), + tool_path(name = "objdump", path = "/bin/false"), + tool_path(name = "strip", path = "/bin/false"), + ] + + return cc_common.create_cc_toolchain_config_info( + ctx = ctx, + features = features, + action_configs = action_configs, + artifact_name_patterns = artifact_name_patterns, + cxx_builtin_include_directories = cxx_builtin_include_directories, + toolchain_identifier = toolchain_identifier, + host_system_name = host_system_name, + target_system_name = target_system_name, + target_cpu = target_cpu, + target_libc = target_libc, + compiler = compiler, + abi_version = abi_version, + abi_libc_version = abi_libc_version, + tool_paths = tool_paths, + make_variables = make_variables, + builtin_sysroot = builtin_sysroot, + cc_target_os = cc_target_os, + ) + +def _impl(ctx): + if ctx.attr.cpu == "armeabi-v7a": + return _armeabi_impl(ctx) + elif ctx.attr.cpu == "x64_windows" and ctx.attr.compiler == "msvc-cl": + return _windows_msvc_impl(ctx) + elif ctx.attr.cpu == "x64_windows" and ctx.attr.compiler == "mingw-gcc": + return _windows_msys_mingw_impl(ctx) + + tool_paths = [ + tool_path(name = "ar", path = "/usr/bin/ar"), + tool_path(name = "ld", path = "/usr/bin/ld"), + tool_path(name = "cpp", path = "/usr/bin/cpp"), + tool_path(name = "gcc", path = "/dt7/usr/bin/gcc"), + tool_path(name = "dwp", path = "/usr/bin/dwp"), + tool_path(name = "gcov", path = "/usr/bin/gcov"), + tool_path(name = "nm", path = "/usr/bin/nm"), + tool_path(name = "objcopy", path = "/usr/bin/objcopy"), + tool_path(name = "objdump", path = "/usr/bin/objdump"), + tool_path(name = "strip", path = "/usr/bin/strip"), + ] + + cxx_builtin_include_directories = 
[ + "/dt7/usr/lib/gcc/x86_64-pc-linux-gnu/7/include", + "/dt7/usr/lib/gcc/x86_64-pc-linux-gnu/7/include-fixed", + "/dt7/usr/include", + "/dt7/usr/include/c++/7", + "/dt7/usr/include/c++/7/x86_64-pc-linux-gnu", + "/dt7/usr/include/c++/7/backward", + ] + + action_configs = [] + + compile_flags = [ + "-U_FORTIFY_SOURCE", + "-fstack-protector", + "-Wall", + "-Wunused-but-set-parameter", + "-Wno-free-nonheap-object", + "-fno-omit-frame-pointer", + ] + + dbg_compile_flags = [ + "-g", + ] + + opt_compile_flags = [ + "-g0", + "-O2", + "-D_FORTIFY_SOURCE=1", + "-DNDEBUG", + "-ffunction-sections", + "-fdata-sections", + ] + + cxx_flags = [ + "-std=c++0x", + ] + + link_flags = [ + "-fuse-ld=gold", + "-Wl,-no-as-needed", + "-Wl,-z,relro,-z,now", + "-B/dt7/usr/bin", + "-pass-exit-codes", + "-lstdc++", + "-lm", + ] + + opt_link_flags = [ + "-Wl,--gc-sections", + ] + + unfiltered_compile_flags = [ + "-fno-canonical-system-headers", + "-Wno-builtin-macro-redefined", + "-D__DATE__=\"redacted\"", + "-D__TIMESTAMP__=\"redacted\"", + "-D__TIME__=\"redacted\"", + ] + + targets_windows_feature = feature( + name = "targets_windows", + implies = ["copy_dynamic_libraries_to_binary"], + enabled = True, + ) + + copy_dynamic_libraries_to_binary_feature = feature(name = "copy_dynamic_libraries_to_binary") + + gcc_env_feature = feature( + name = "gcc_env", + enabled = True, + env_sets = [ + env_set( + actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.cpp_link_executable, + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ACTION_NAMES.cpp_link_static_library, + ], + env_entries = [ + env_entry(key = "PATH", value = "NOT_USED"), + ], + ), + ], + ) + + windows_features = [ + targets_windows_feature, + copy_dynamic_libraries_to_binary_feature, + gcc_env_feature, + ] + + coverage_feature = feature( + name = "coverage", + provides = ["profile"], + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ], + flag_groups = [ + flag_group(flags = ["--coverage"]), + ], + ), + flag_set( + actions = [ + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ACTION_NAMES.cpp_link_executable, + ], + flag_groups = [ + flag_group(flags = ["--coverage"]), + ], + ), + ], + ) + + supports_pic_feature = feature( + name = "supports_pic", + enabled = True, + ) + supports_start_end_lib_feature = feature( + name = "supports_start_end_lib", + enabled = True, + ) + + default_compile_flags_feature = feature( + name = "default_compile_flags", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.lto_backend, + ACTION_NAMES.clif_match, + ], + flag_groups = ([flag_group(flags = compile_flags)] if compile_flags else []), + ), + flag_set( + actions = [ + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + 
ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.lto_backend, + ACTION_NAMES.clif_match, + ], + flag_groups = ([flag_group(flags = dbg_compile_flags)] if dbg_compile_flags else []), + with_features = [with_feature_set(features = ["dbg"])], + ), + flag_set( + actions = [ + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.lto_backend, + ACTION_NAMES.clif_match, + ], + flag_groups = ([flag_group(flags = opt_compile_flags)] if opt_compile_flags else []), + with_features = [with_feature_set(features = ["opt"])], + ), + flag_set( + actions = [ + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.lto_backend, + ACTION_NAMES.clif_match, + ], + flag_groups = ([flag_group(flags = cxx_flags)] if cxx_flags else []), + ), + ], + ) + + default_link_flags_feature = feature( + name = "default_link_flags", + enabled = True, + flag_sets = [ + flag_set( + actions = all_link_actions, + flag_groups = ([flag_group(flags = link_flags)] if link_flags else []), + ), + flag_set( + actions = all_link_actions, + flag_groups = ([flag_group(flags = opt_link_flags)] if opt_link_flags else []), + with_features = [with_feature_set(features = ["opt"])], + ), + ], + ) + + dbg_feature = feature(name = "dbg") + + opt_feature = feature(name = "opt") + + sysroot_feature = feature( + name = "sysroot", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.lto_backend, + ACTION_NAMES.clif_match, + ACTION_NAMES.cpp_link_executable, + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ], + flag_groups = [ + flag_group( + flags = ["--sysroot=%{sysroot}"], + expand_if_available = "sysroot", + ), + ], + ), + ], + ) + + fdo_optimize_feature = feature( + name = "fdo_optimize", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [ + flag_group( + flags = [ + "-fprofile-use=%{fdo_profile_path}", + "-fprofile-correction", + ], + expand_if_available = "fdo_profile_path", + ), + ], + ), + ], + provides = ["profile"], + ) + + supports_dynamic_linker_feature = feature(name = "supports_dynamic_linker", enabled = True) + + user_compile_flags_feature = feature( + name = "user_compile_flags", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.lto_backend, + ACTION_NAMES.clif_match, + ], + flag_groups = [ + flag_group( + flags = ["%{user_compile_flags}"], + iterate_over = "user_compile_flags", + expand_if_available = "user_compile_flags", + ), + ], + ), + ], + ) + + unfiltered_compile_flags_feature = feature( + name = "unfiltered_compile_flags", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.linkstamp_compile, + 
ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.lto_backend, + ACTION_NAMES.clif_match, + ], + flag_groups = ([flag_group(flags = unfiltered_compile_flags)] if unfiltered_compile_flags else []), + ), + ], + ) + + features = [ + supports_pic_feature, + supports_start_end_lib_feature, + coverage_feature, + default_compile_flags_feature, + default_link_flags_feature, + fdo_optimize_feature, + supports_dynamic_linker_feature, + dbg_feature, + opt_feature, + user_compile_flags_feature, + sysroot_feature, + unfiltered_compile_flags_feature, + ] + + artifact_name_patterns = [ + ] + + make_variables = [] + + return cc_common.create_cc_toolchain_config_info( + ctx = ctx, + features = features, + action_configs = action_configs, + artifact_name_patterns = artifact_name_patterns, + cxx_builtin_include_directories = cxx_builtin_include_directories, + toolchain_identifier = "linux_gnu_x86", + host_system_name = "i686-unknown-linux-gnu", + target_system_name = "x86_64-unknown-linux-gnu", + target_cpu = "k8", + target_libc = "glibc_2.19", + compiler = "/dt7/usr/bin/gcc", + abi_version = "gcc", + abi_libc_version = "glibc_2.19", + tool_paths = tool_paths, + make_variables = make_variables, + builtin_sysroot = "", + cc_target_os = None, + ) + +cc_toolchain_config = rule( + implementation = _impl, + attrs = { + "cpu": attr.string(mandatory = True), + "compiler": attr.string(), + }, + provides = [CcToolchainConfigInfo], +) diff --git a/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/cc_wrapper.sh b/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/cc_wrapper.sh new file mode 100755 index 0000000..898befb --- /dev/null +++ b/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/cc_wrapper.sh @@ -0,0 +1,25 @@ +#!/bin/bash +# +# Copyright 2015 The Bazel Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Ship the environment to the C++ action +# +set -eu + +# Set-up the environment + + +# Call the C++ compiler +/dt7/usr/bin/gcc "$@" diff --git a/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/dummy_toolchain.bzl b/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/dummy_toolchain.bzl new file mode 100755 index 0000000..85b3412 --- /dev/null +++ b/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/dummy_toolchain.bzl @@ -0,0 +1,23 @@ +# pylint: disable=g-bad-file-header +# Copyright 2017 The Bazel Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Starlark rule that stubs a toolchain.""" + +def _dummy_toolchain_impl(ctx): + ctx = ctx # unused argument + toolchain = platform_common.ToolchainInfo() + return [toolchain] + +dummy_toolchain = rule(_dummy_toolchain_impl, attrs = {}) diff --git a/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/tools/cpp/empty.cc b/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/tools/cpp/empty.cc new file mode 100755 index 0000000..40da260 --- /dev/null +++ b/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/tools/cpp/empty.cc @@ -0,0 +1,15 @@ +// Copyright 2019 DeepMind Technologies Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +int main() {}
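If one wanted to make the toolchain() targets declared above visible to
platform-based toolchain resolution, a minimal WORKSPACE sketch would look as
follows; it assumes the BUILD file that declares them lives in
third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010 alongside the
other files added here:

    # WORKSPACE (illustrative sketch, not taken from this patch)
    register_toolchains(
        "//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:cc-toolchain-k8",
        "//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:cc-toolchain-armeabi-v7a",
    )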