Skip to content

Commit

Permalink
Record device time measurements in PJRT stream executor client. Set d…
Browse files Browse the repository at this point in the history
…evice type to the platform that the client is running on.

PiperOrigin-RevId: 725859515
  • Loading branch information
Google-ML-Automation committed Feb 12, 2025
1 parent d6be12c commit 3352841
Show file tree
Hide file tree
Showing 11 changed files with 345 additions and 20 deletions.
7 changes: 6 additions & 1 deletion xla/pjrt/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,9 @@ cc_library(

cc_library(
name = "pjrt_stream_executor_client",
srcs = ["pjrt_stream_executor_client.cc"],
srcs = [
"pjrt_stream_executor_client.cc",
],
hdrs = ["pjrt_stream_executor_client.h"],
visibility = internal_visibility(["//xla:friends"]),
deps = [
Expand Down Expand Up @@ -511,6 +513,7 @@ cc_library(
"//xla/hlo/builder:xla_computation",
"//xla/hlo/ir:hlo",
"//xla/pjrt/distributed:protocol_proto_cc",
"//xla/pjrt/profiling:device_time_measurement",
"//xla/service:compiler",
"//xla/service:computation_layout",
"//xla/service:computation_placer",
Expand Down Expand Up @@ -572,6 +575,7 @@ xla_cc_test(
"//xla/client:local_client",
"//xla/hlo/builder:xla_builder",
"//xla/hlo/testlib:test",
"//xla/pjrt/profiling:device_time_measurement",
"//xla/service:cpu_plugin",
"//xla/service:platform_util",
"//xla/stream_executor:platform",
Expand All @@ -582,6 +586,7 @@ xla_cc_test(
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
"@com_google_googletest//:gtest_main",
],
)
Expand Down
2 changes: 2 additions & 0 deletions xla/pjrt/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,8 @@ xla_cc_test(
"//xla/pjrt/distributed:client",
"//xla/pjrt/distributed:in_memory_key_value_store",
"//xla/pjrt/plugin/xla_gpu:xla_gpu_client_options",
"//xla/pjrt/profiling:device_time_measurement",
"//xla/pjrt/profiling/test_util:mock_device_time_measurement",
"//xla/service:gpu_plugin",
"//xla/service:platform_util",
"//xla/stream_executor:device_memory",
Expand Down
62 changes: 62 additions & 0 deletions xla/pjrt/gpu/se_gpu_pjrt_client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ limitations under the License.
#include "xla/pjrt/pjrt_future.h"
#include "xla/pjrt/pjrt_stream_executor_client.h"
#include "xla/pjrt/plugin/xla_gpu/xla_gpu_client_options.h"
#include "xla/pjrt/profiling/device_time_measurement.h"
#include "xla/pjrt/profiling/test_util/mock_device_time_measurement.h"
#include "xla/service/platform_util.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
Expand Down Expand Up @@ -1875,6 +1877,66 @@ TEST(StreamExecutorGpuClientTest, AutoLayoutIsSupported) {
EXPECT_NE(layouts[1]->ToString(), "{2,1,0}");
}

// Same test as SendRecvChunked, but check GPU device time measurement.
TEST(StreamExecutorGpuClientTest, NonZeroGPUDeviceTimeMeasurement) {
TF_ASSERT_OK_AND_ASSIGN(auto client,
GetStreamExecutorGpuClient(GpuClientOptions()));

TF_ASSERT_OK_AND_ASSIGN(auto executable,
CompileExecutable(kProgram, *client));

std::array<float, 2> sent_value = {0.0f, 0.0f};

// Send buffer to host.
SendCallback send_callback = {
/*channel_id=*/1, [&](const PjRtTransferMetadata& m, PjRtChunk chunk,
int64_t total_size_in_bytes, bool done) {
float* data = reinterpret_cast<float*>(chunk.data());
sent_value[0] = data[0];
sent_value[1] = data[1];
return absl::OkStatus();
}};

// Recv buffer from host.
RecvCallback recv_callback = {
/*channel_id=*/2, [&](const PjRtTransferMetadata& m,
std::unique_ptr<CopyToDeviceStream> stream) {
auto chunk0 = PjRtChunk::AllocateDefault(sizeof(float));
*reinterpret_cast<float*>(chunk0.data()) = 5.0f;
TF_CHECK_OK(stream->AddChunk(std::move(chunk0)).Await());

auto chunk1 = PjRtChunk::AllocateDefault(sizeof(float));
*reinterpret_cast<float*>(chunk1.data()) = 6.0f;
TF_CHECK_OK(stream->AddChunk(std::move(chunk1)).Await());

return absl::OkStatus();
}};

// Callbacks for point-to-point communication ops.
std::vector<std::vector<SendCallback>> send_callbacks = {{send_callback}};
std::vector<std::vector<RecvCallback>> recv_callbacks = {{recv_callback}};

ExecuteOptions opts;
opts.send_callbacks = send_callbacks;
opts.recv_callbacks = recv_callbacks;

// Test non-zero GPU device time measurement.
auto measurement0 = CreateDeviceTimeMeasurement();
auto result = executable->Execute(/*argument_handles=*/{{}}, opts);

TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr<xla::Literal> result_literal,
ExtractSingleResult(result));
EXPECT_EQ(sent_value[0], 2.0f);
EXPECT_EQ(sent_value[1], 3.0f);
EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1<float>({5.0f, 6.0f}),
*result_literal));

// Check measurement after execution completes.
EXPECT_GT(
measurement0->GetTotalDuration(DeviceTimeMeasurement::DeviceType::kGpu),
absl::ZeroDuration());
}

struct ShardedAutotuningTestInfo {
bool use_xla_computation;
int num_active_nodes;
Expand Down
34 changes: 34 additions & 0 deletions xla/pjrt/pjrt_stream_executor_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ limitations under the License.
#include "xla/pjrt/pjrt_compiler.h"
#include "xla/pjrt/pjrt_executable.h"
#include "xla/pjrt/pjrt_future.h"
#include "xla/pjrt/profiling/device_time_measurement.h"
#include "xla/pjrt/semaphore.h"
#include "xla/pjrt/tracked_device_buffer.h"
#include "xla/pjrt/transpose.h"
Expand Down Expand Up @@ -2843,6 +2844,18 @@ PjRtStreamExecutorLoadedExecutable::EnqueueExecution(
device_state->compute_semaphore().ScopedAcquire(1));
}

auto start_time_ns = std::make_shared<uint64_t>();
std::optional<uint64_t> key = xla::GetDeviceTimeMeasurementKey();
// Record the start time of the execution by placing a callback on the stream
// directly before the execution. If this callback is added, another callback
// will be added directly after the execution to record the elapsed device
// time.
if (key.has_value()) {
TF_RETURN_IF_ERROR(device_state->ThenExecuteCallback(
device_state->compute_stream(), [start_time_ns]() {
*start_time_ns = tsl::Env::Default()->NowNanos();
}));
}
absl::StatusOr<ExecutionOutput> result_buffer_or_status =
executables_[executable_idx]->RunAsync(std::move(execution_inputs),
run_options);
Expand All @@ -2854,6 +2867,27 @@ PjRtStreamExecutorLoadedExecutable::EnqueueExecution(
return result_buffer_or_status.status();
}

// Add a callback on the stream to record the elapsed device time of the
// executable execution.
//
// Do not place other callbacks between the callback recording the start time
// and this callback because their execution time will incorrectly count
// toward device execution time.
//
// This callback is only added if there is a valid key to guarantee that
// either both or none of the device time measurement callbacks are added to
// the stream, and to avoid needing a mutex.
if (key.has_value()) {
TF_RETURN_IF_ERROR(device_state->ThenExecuteCallback(
device_state->compute_stream(),
[key, start_time_ns,
device_type = GetDeviceType(client_->platform_id())]() {
auto elapsed = absl::FromUnixNanos(tsl::Env::Default()->NowNanos()) -
absl::FromUnixNanos(*start_time_ns);
xla::RecordDeviceTimeMeasurement(*key, elapsed, device_type);
}));
}

if (device_state->allocation_model() == LocalDeviceState::kSynchronous) {
ExecutionOutput& execution_output = result_buffer_or_status.value();
// If we used a transient tuple for the arguments we donated its root table
Expand Down
27 changes: 16 additions & 11 deletions xla/pjrt/profiling/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -28,34 +28,39 @@ exports_files(
)

cc_library(
name = "device_time_measurement",
name = "no_op_device_time_measurement",
srcs = [
"device_time_measurement.h",
"no_op_device_time_measurement.cc",
"no_op_device_time_measurement.h",
],
# copybara:uncomment_begin(google-only)
# compatible_with = ["//buildenv/target:non_prod"],
# copybara:uncomment_end
textual_hdrs = ["device_time_measurement.h"],
deps = [
# copybara:comment_begin(oss-only)
":no_op_device_time_measurement",
# copybara:comment_end
# copybara:uncomment_begin(google-only)
# "//learning/brain/google/runtime:device_runtime_profiling",
# copybara:uncomment_end
"//xla/pjrt:pjrt_compiler",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
],
)

cc_library(
name = "no_op_device_time_measurement",
srcs = ["device_time_measurement.h"],
hdrs = ["no_op_device_time_measurement.h"],
name = "device_time_measurement",
# copybara:uncomment_begin(google-only)
# compatible_with = ["//buildenv/target:non_prod"],
# copybara:uncomment_end
textual_hdrs = ["device_time_measurement.h"],
deps = [
# copybara:comment_begin(oss-only)
":no_op_device_time_measurement",
# copybara:comment_end
# copybara:uncomment_begin(google-only)
# "//learning/brain/google/runtime:device_runtime_profiling",
# copybara:uncomment_end
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
"//xla/pjrt:pjrt_compiler",
],
)
14 changes: 14 additions & 0 deletions xla/pjrt/profiling/device_time_measurement.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ limitations under the License.
#include "absl/container/flat_hash_map.h"
#include "absl/synchronization/mutex.h"
#include "absl/time/time.h"
#include "xla/pjrt/pjrt_compiler.h"

namespace xla {

Expand Down Expand Up @@ -79,5 +80,18 @@ void RecordDeviceTimeMeasurement(
uint64_t key, absl::Duration elapsed,
xla::DeviceTimeMeasurement::DeviceType device_type);

// Helper function to convert PjRtPlatformId to
// DeviceTimeMeasurement::DeviceType.
inline DeviceTimeMeasurement::DeviceType GetDeviceType(
PjRtPlatformId platform_id) {
if (platform_id == CudaId() || platform_id == RocmId() ||
platform_id == SyclId()) {
return DeviceTimeMeasurement::DeviceType::kGpu;
} else if (platform_id == TpuId()) {
return DeviceTimeMeasurement::DeviceType::kTpu;
}
return DeviceTimeMeasurement::DeviceType::kUnknown;
}

} // namespace xla
#endif // XLA_PJRT_PROFILING_DEVICE_TIME_MEASUREMENT_H_
37 changes: 37 additions & 0 deletions xla/pjrt/profiling/no_op_device_time_measurement.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/* Copyright 2025 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/pjrt/profiling/no_op_device_time_measurement.h"

#include <cstdint>
#include <memory>
#include <optional>

#include "absl/time/time.h"
#include "xla/pjrt/profiling/device_time_measurement.h"

namespace xla {

std::unique_ptr<DeviceTimeMeasurement> CreateDeviceTimeMeasurement() {
return std::make_unique<NoOpDeviceTimeMeasurement>();
}

std::optional<uint64_t> GetDeviceTimeMeasurementKey() { return std::nullopt; }

void RecordDeviceTimeMeasurement(
uint64_t key, absl::Duration elapsed,
xla::DeviceTimeMeasurement::DeviceType device_type) {}

} // namespace xla
12 changes: 4 additions & 8 deletions xla/pjrt/profiling/no_op_device_time_measurement.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,17 +52,13 @@ class NoOpDeviceTimeMeasurement : public DeviceTimeMeasurement {
void Record(absl::Duration elapsed, DeviceType device_type) override {};
};

inline std::unique_ptr<DeviceTimeMeasurement> CreateDeviceTimeMeasurement() {
return std::make_unique<NoOpDeviceTimeMeasurement>();
}
std::unique_ptr<DeviceTimeMeasurement> CreateDeviceTimeMeasurement();

inline std::optional<uint64_t> GetDeviceTimeMeasurementKey() {
return std::nullopt;
}
std::optional<uint64_t> GetDeviceTimeMeasurementKey();

inline void RecordDeviceTimeMeasurement(
void RecordDeviceTimeMeasurement(
uint64_t key, absl::Duration elapsed,
xla::DeviceTimeMeasurement::DeviceType device_type) {}
xla::DeviceTimeMeasurement::DeviceType device_type);

} // namespace xla
#endif // XLA_PJRT_PROFILING_NO_OP_DEVICE_TIME_MEASUREMENT_H_
32 changes: 32 additions & 0 deletions xla/pjrt/profiling/test_util/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
load(
"//xla/tsl:tsl.bzl",
"internal_visibility",
)

package(
# copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
default_visibility = internal_visibility([
"//xla:internal",
]),
licenses = ["notice"],
)

cc_library(
name = "mock_device_time_measurement",
testonly = True,
srcs = [
"mock_device_time_measurement.cc",
"//xla/pjrt/profiling:device_time_measurement.h",
],
hdrs = ["mock_device_time_measurement.h"],
# copybara:uncomment_begin(google-only)
# compatible_with = ["//buildenv/target:non_prod"],
# copybara:uncomment_end
deps = [
"//xla/pjrt:pjrt_compiler",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/debugging:leak_check",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
],
)
Loading

0 comments on commit 3352841

Please sign in to comment.