Skip to content

Commit

Permalink
Fix TSAN error by adding some indirection to TableItem.
Browse files Browse the repository at this point in the history
We have taken a lot of care to make sure that we always make safe use of the TableItem even though it is accessed from multiple threads at once. These accesses include modifications of the `priority` and `times_sampled` fields, which were assumed to be safe as these fields are guaranteed not to be read at the same time as they are changed. HOWEVER, it turns out that modifying a single field within a proto actually invalidates all the fields of the proto, so the field access we are doing is not actually safe.

This CL modifies TableItem so that it provides read access to the static fields and then keeps the mutable fields isolated from the rest of the proto. This means that we can modify these two fields while concurrently reading from the remaining fields without risking data races.

PiperOrigin-RevId: 549590469
Change-Id: Ia369082aba22ea5d78f6c5095c7f89e18219c642
  • Loading branch information
acassirer authored and copybara-github committed Jul 20, 2023
1 parent 190eae0 commit 33aba4e
Show file tree
Hide file tree
Showing 12 changed files with 267 additions and 154 deletions.
30 changes: 14 additions & 16 deletions reverb/cc/platform/tfrecord_checkpointer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -506,37 +506,35 @@ absl::Status LoadWithCompression(absl::string_view path,

for (auto& table : *tables) {
for (auto& checkpoint_item : table_to_items[table->name()]) {
Table::Item insert_item;
std::swap(insert_item.item, checkpoint_item);

if (insert_item.item.has_deprecated_sequence_range()) {
if (checkpoint_item.has_deprecated_sequence_range()) {
std::vector<std::shared_ptr<ChunkStore::Chunk>> trajectory_chunks;
REVERB_RETURN_IF_ERROR(chunk_store->Get(
insert_item.item.deprecated_chunk_keys(), &trajectory_chunks));
checkpoint_item.deprecated_chunk_keys(), &trajectory_chunks));

*insert_item.item.mutable_flat_trajectory() =
*checkpoint_item.mutable_flat_trajectory() =
internal::FlatTimestepTrajectory(
trajectory_chunks,
insert_item.item.deprecated_sequence_range().offset(),
insert_item.item.deprecated_sequence_range().length());
checkpoint_item.deprecated_sequence_range().offset(),
checkpoint_item.deprecated_sequence_range().length());

insert_item.item.clear_deprecated_sequence_range();
insert_item.item.clear_deprecated_chunk_keys();
checkpoint_item.clear_deprecated_sequence_range();
checkpoint_item.clear_deprecated_chunk_keys();
}

std::vector<std::shared_ptr<ChunkStore::Chunk>> chunks;
REVERB_RETURN_IF_ERROR(chunk_store->Get(
internal::GetChunkKeys(insert_item.item.flat_trajectory()),
&insert_item.chunks));
internal::GetChunkKeys(checkpoint_item.flat_trajectory()), &chunks));

      // The original table has already been destroyed so if this fails then
      // there is no way to recover.
REVERB_RETURN_IF_ERROR(
table->InsertCheckpointItem(std::move(insert_item)));
REVERB_RETURN_IF_ERROR(table->InsertCheckpointItem(
Table::Item(std::move(checkpoint_item), std::move(chunks))));
}

REVERB_LOG(REVERB_INFO)
<< "Table " << table->name()
<< " has been successfully loaded from the checkpoint.";
<< "Table " << table->name() << " and " << table->size()
<< " items have been successfully loaded from checkpoint at path "
<< path << ".";
}

REVERB_LOG(REVERB_INFO) << "Successfully loaded " << table_checkpoints.size()
Expand Down
7 changes: 4 additions & 3 deletions reverb/cc/platform/tfrecord_checkpointer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,11 @@ TEST(TFRecordCheckpointerTest, SaveAndLoad) {
REVERB_EXPECT_OK(loaded_tables[i]->Sample(&sample));
bool item_found = false;
for (auto& item : tables[i]->Copy()) {
if (item.item.key() == sample.ref->item.key()) {
if (item.key() == sample.ref->key()) {
item_found = true;
item.item.set_times_sampled(item.item.times_sampled() + 1);
EXPECT_THAT(item.item, EqualsProto(sample.ref->item));
item.set_times_sampled(item.times_sampled() + 1);
EXPECT_THAT(item.AsPrioritizedItem(),
EqualsProto(sample.ref->AsPrioritizedItem()));
break;
}
}
Expand Down
49 changes: 24 additions & 25 deletions reverb/cc/reverb_service_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include "reverb/cc/platform/status_macros.h"
#include "reverb/cc/reverb_server_reactor.h"
#include "reverb/cc/reverb_service.pb.h"
#include "reverb/cc/schema.pb.h"
#include "reverb/cc/support/grpc_util.h"
#include "reverb/cc/support/trajectory_util.h"
#include "reverb/cc/support/uint128.h"
Expand Down Expand Up @@ -231,19 +232,20 @@ ReverbServiceImpl::InsertStream(grpc::CallbackServerContext* context) {
}
bool can_insert = true;
for (auto& request_item : *request->mutable_items()) {
Table::Item item;
if (auto status = GetItemWithChunks(&item, &request_item);
!status.ok()) {
return status;
auto item_or_status = GetItemWithChunks(std::move(request_item));
if (!item_or_status.ok()) {
return ToGrpcStatus(item_or_status.status());
}
const auto& table_name = item.item.table();

const auto& table_name = item_or_status->table();
// Check that table name is valid.
auto table = server_->TableByName(table_name);
if (table == nullptr) {
return TableNotFound(table_name);
}
if (auto status = table->InsertOrAssignAsync(
std::move(item), &can_insert, insert_completed_);
if (auto status =
table->InsertOrAssignAsync(std::move(item_or_status).value(),
&can_insert, insert_completed_);
!status.ok()) {
return ToGrpcStatus(status);
}
Expand Down Expand Up @@ -272,22 +274,20 @@ ReverbServiceImpl::InsertStream(grpc::CallbackServerContext* context) {
return grpc::Status::OK;
}

grpc::Status GetItemWithChunks(
Table::Item* item,
PrioritizedItem* request_item) {
absl::StatusOr<Table::Item> GetItemWithChunks(
PrioritizedItem request_item) {
std::vector<std::shared_ptr<ChunkStore::Chunk>> chunks;
for (ChunkStore::Key key :
internal::GetChunkKeys(request_item->flat_trajectory())) {
internal::GetChunkKeys(request_item.flat_trajectory())) {
auto it = chunks_.find(key);
if (it == chunks_.end()) {
return Internal(
return absl::InternalError(
absl::StrCat("Could not find sequence chunk ", key, "."));
}
item->chunks.push_back(it->second);
chunks.push_back(it->second);
}

item->item = std::move(*request_item);

return grpc::Status::OK;
return Table::Item(std::move(request_item), std::move(chunks));
}

grpc::Status ReleaseOutOfRangeChunks(absl::Span<const uint64_t> keep_keys) {
Expand Down Expand Up @@ -623,31 +623,30 @@ ReverbServiceImpl::SampleStream(grpc::CallbackServerContext* context) {
}
SampleStreamResponseCtx* response = &responses_to_send_.back();
auto* entry = response->payload.add_entries();
for (int i = 0; i < sample->ref->chunks.size(); i++) {
entry->set_end_of_sequence(i + 1 == sample->ref->chunks.size());
for (int i = 0; i < sample->ref->chunks().size(); i++) {
entry->set_end_of_sequence(i + 1 == sample->ref->chunks().size());
// Attach the info to the first message.
if (i == 0) {
auto* item = entry->mutable_info()->mutable_item();
auto& sample_item = sample->ref->item;
item->set_key(sample_item.key());
item->set_table(sample_item.table());
item->set_key(sample->ref->key());
item->set_table(std::string(sample->ref->table()));
item->set_priority(sample->priority);
item->set_times_sampled(sample->times_sampled);
// ~SampleStreamResponseCtx releases these fields from the proto
// upon destruction of the item.
item->/*unsafe_arena_*/set_allocated_inserted_at(
sample_item.mutable_inserted_at());
sample->ref->unsafe_mutable_inserted_at());
item->/*unsafe_arena_*/set_allocated_flat_trajectory(
sample_item.mutable_flat_trajectory());
sample->ref->unsafe_mutable_flat_trajectory());
entry->mutable_info()->set_probability(sample->probability);
entry->mutable_info()->set_table_size(sample->table_size);
entry->mutable_info()->set_rate_limited(sample->rate_limited);
}
ChunkData* chunk =
const_cast<ChunkData*>(&sample->ref->chunks[i]->data());
const_cast<ChunkData*>(&sample->ref->chunks()[i]->data());
current_response_size_bytes_ += chunk->ByteSizeLong();
entry->mutable_data()->UnsafeArenaAddAllocated(chunk);
if (i < sample->ref->chunks.size() - 1 &&
if (i < sample->ref->chunks().size() - 1 &&
current_response_size_bytes_ > kMaxSampleResponseSizeBytes) {
// Current response is too big, start a new one.
responses_to_send_.emplace();
Expand Down
13 changes: 9 additions & 4 deletions reverb/cc/reverb_service_impl_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -709,12 +709,17 @@ TEST(ReverbServiceImplTest, InitializeConnectionFromOtherProcess) {
TEST(InsertWorkerTest, InsertWorkerReturnsCorrectStats) {
auto insert_worker = std::make_unique<InsertWorker>(
/*num_threads=*/1, /*max_queue_size_to_warn=*/3, "TestWorker");
Table::Item item;
item.item.set_table("my_table");

PrioritizedItem prioritized_item;
prioritized_item.set_table("my_table");

Table::Item item(std::move(prioritized_item), {});
absl::BlockingCounter counter(2);
for (int i = 0; i < 2; i++) {
InsertTaskInfo task_info;
task_info.item = item;
InsertTaskInfo task_info = {
.item = item,
.table = nullptr,
};
insert_worker->Schedule(
task_info,
[&counter](InsertTaskInfo task_info, const absl::Status& status,
Expand Down
14 changes: 6 additions & 8 deletions reverb/cc/sampler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,17 +131,15 @@ absl::Status AsSample(std::vector<SampleStreamResponse::SampleEntry> responses,
absl::Status AsSample(const Table::SampledItem& sampled_item,
std::unique_ptr<Sample>* sample) {
internal::flat_hash_map<uint64_t, std::shared_ptr<ChunkStore::Chunk>> chunks(
sampled_item.ref->chunks.size());
for (auto& chunk : sampled_item.ref->chunks) {
sampled_item.ref->chunks().size());
for (auto& chunk : sampled_item.ref->chunks()) {
chunks[chunk->key()] = chunk;
}

std::vector<std::vector<tensorflow::Tensor>> column_chunks;
column_chunks.reserve(
sampled_item.ref->item.flat_trajectory().columns_size());
column_chunks.reserve(sampled_item.ref->flat_trajectory().columns_size());

for (const auto& column :
sampled_item.ref->item.flat_trajectory().columns()) {
for (const auto& column : sampled_item.ref->flat_trajectory().columns()) {
std::vector<tensorflow::Tensor> unpacked_chunks;

for (const auto& slice : column.chunk_slices()) {
Expand All @@ -154,11 +152,11 @@ absl::Status AsSample(const Table::SampledItem& sampled_item,
}

std::vector<bool> squeeze_columns;
for (const auto& col : sampled_item.ref->item.flat_trajectory().columns()) {
for (const auto& col : sampled_item.ref->flat_trajectory().columns()) {
squeeze_columns.push_back(col.squeeze());
}
auto info = std::make_shared<SampleInfo>();
info->mutable_item()->set_key(sampled_item.ref->item.key());
info->mutable_item()->set_key(sampled_item.ref->key());
info->mutable_item()->set_priority(sampled_item.priority);
info->mutable_item()->set_times_sampled(sampled_item.times_sampled);
info->set_probability(sampled_item.probability);
Expand Down
23 changes: 11 additions & 12 deletions reverb/cc/sampler_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -230,33 +230,31 @@ ChunkData MakeChunkData(uint64_t key, SequenceRange range) {
TableItem MakeItem(uint64_t key, double priority,
const std::vector<SequenceRange>& sequences, int32_t offset,
int32_t length) {
TableItem item;

std::vector<std::shared_ptr<ChunkStore::Chunk>> chunks;
std::vector<ChunkData> data(sequences.size());
for (int i = 0; i < sequences.size(); i++) {
data[i] = MakeChunkData(key * 100 + i, sequences[i]);
item.chunks.push_back(std::make_shared<ChunkStore::Chunk>(data[i]));
chunks.push_back(std::make_shared<ChunkStore::Chunk>(data[i]));
}

item.item = testing::MakePrioritizedItem(key, priority, data);
Table::Item item(testing::MakePrioritizedItem(key, priority, data),
std::move(chunks));

int32_t remaining = length;
for (int slice_index = 0; slice_index < sequences.size(); slice_index++) {
for (int col_index = 0;
col_index < item.item.flat_trajectory().columns_size(); col_index++) {
for (int col_index = 0; col_index < item.flat_trajectory().columns_size();
col_index++) {
auto* col =
item.item.mutable_flat_trajectory()->mutable_columns(col_index);
item.unsafe_mutable_flat_trajectory()->mutable_columns(col_index);
auto* slice = col->mutable_chunk_slices(slice_index);
slice->set_offset(offset);
slice->set_length(
std::min<int32_t>(slice->length() - slice->offset(), remaining));
slice->set_index(col_index);
}

remaining -= item.item.flat_trajectory()
.columns(0)
.chunk_slices(slice_index)
.length();
remaining -=
item.flat_trajectory().columns(0).chunk_slices(slice_index).length();
offset = 0;
}

Expand All @@ -282,7 +280,8 @@ void InsertItem(Table* table, uint64_t key, double priority,
}

auto item = MakeItem(key, priority, ranges, offset, length);
item.item.mutable_flat_trajectory()->mutable_columns(0)->set_squeeze(squeeze);
item.unsafe_mutable_flat_trajectory()->mutable_columns(0)->set_squeeze(
squeeze);
REVERB_EXPECT_OK(table->InsertOrAssign(std::move(item)));
}

Expand Down
6 changes: 3 additions & 3 deletions reverb/cc/support/signature.cc
Original file line number Diff line number Diff line change
Expand Up @@ -224,17 +224,17 @@ tensorflow::StructuredValue StructuredValueFromItem(const TableItem& item) {
tensorflow::StructuredValue value;

auto get_tensor = [&](const FlatTrajectory::ChunkSlice& slice) {
for (const auto& chunk : item.chunks) {
for (const auto& chunk : item.chunks()) {
if (chunk->key() == slice.chunk_key()) {
return &chunk->data().data().tensors(slice.index());
}
}
REVERB_CHECK(false) << "Invalid item.";
};

for (int col_idx = 0; col_idx < item.item.flat_trajectory().columns_size();
for (int col_idx = 0; col_idx < item.flat_trajectory().columns_size();
col_idx++) {
const auto& col = item.item.flat_trajectory().columns(col_idx);
const auto& col = item.flat_trajectory().columns(col_idx);
const auto* tensor_proto = get_tensor(col.chunk_slices(0));

auto* spec =
Expand Down
Loading

0 comments on commit 33aba4e

Please sign in to comment.