Skip to content

Commit

Permalink
[jax-transfer-lib]: Add timeouts to the event loop.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 725346139
  • Loading branch information
pschuh authored and Google-ML-Automation committed Feb 10, 2025
1 parent 841b1bf commit 6b7e725
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 1 deletion.
1 change: 1 addition & 0 deletions xla/python/transfer/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ cc_library(
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
"@tsl//tsl/platform:env",
],
)
Expand Down
43 changes: 42 additions & 1 deletion xla/python/transfer/event_loop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ limitations under the License.
#include <atomic>
#include <cerrno>
#include <memory>
#include <queue>
#include <string>
#include <utility>
#include <vector>
Expand All @@ -38,6 +39,7 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/synchronization/mutex.h"
#include "absl/time/clock.h"
#include "tsl/platform/env.h"

namespace aux {
Expand Down Expand Up @@ -66,20 +68,39 @@ class PollEventLoopImpl : public PollEventLoop {
WakeInternal();
}

void ScheduleAt(absl::Time t, absl::AnyInvocable<void() &&> cb) override {
absl::MutexLock l(&mu_);
bool needs_wake = timeout_cbs_.empty() || timeout_cbs_.top().t > t;
timeout_cbs_.push({t, std::move(cb)});
if (needs_wake) {
WakeInternal();
}
}

private:
void Run() {
// TODO(parkers): switch to epoll if handlers.size() is too big.
std::vector<Handler*> handlers;
std::vector<Handler*> new_handlers;
std::vector<pollfd> fds;
absl::Time wake_time = absl::InfiniteFuture();
while (true) {
fds.resize(handlers.size() + 1);
for (size_t i = 0; i < handlers.size(); ++i) {
memset(&fds[i], 0, sizeof(pollfd));
handlers[i]->PopulatePollInfo(fds[i]);
}
fds[handlers.size()] = {.fd = event_fd_, .events = POLLIN, .revents = 0};
poll(&fds[0], fds.size(), -1);
{
auto poll_time = absl::Now();
if (wake_time < poll_time) {
} else if (wake_time < absl::InfiniteFuture()) {
auto poll_duration = absl::ToTimespec(wake_time - poll_time);
ppoll(&fds[0], fds.size(), &poll_duration, nullptr);
} else {
poll(&fds[0], fds.size(), -1);
}
}
absl::InlinedVector<Handler*, 4> inserts;
absl::flat_hash_set<Handler*> wakes;
std::vector<absl::AnyInvocable<void() &&>> cbs;
Expand All @@ -93,6 +114,15 @@ class PollEventLoopImpl : public PollEventLoop {
std::swap(wakes_, wakes);
std::swap(inserts, inserts_);
std::swap(cbs, cbs_);
{
auto woken_time = absl::Now();
while (!timeout_cbs_.empty() && timeout_cbs_.top().t < woken_time) {
cbs.push_back(std::move(std::move(timeout_cbs_.top().cb)));
timeout_cbs_.pop();
}
wake_time = timeout_cbs_.empty() ? absl::InfiniteFuture()
: timeout_cbs_.top().t;
}
needs_wake_ = true;
}
for (auto& cb : cbs) {
Expand Down Expand Up @@ -125,6 +155,17 @@ class PollEventLoopImpl : public PollEventLoop {
bool needs_wake_ = true;
int event_fd_ = eventfd(0, EFD_CLOEXEC);
std::vector<absl::AnyInvocable<void() &&>> cbs_;
struct TimeoutWork {
absl::Time t;
mutable absl::AnyInvocable<void() &&> cb;
};
struct TimeoutOrder {
bool operator()(const TimeoutWork& a, const TimeoutWork& b) const {
return a.t < b.t;
}
};
std::priority_queue<TimeoutWork, std::vector<TimeoutWork>, TimeoutOrder>
timeout_cbs_;
absl::InlinedVector<Handler*, 4> inserts_;
absl::flat_hash_set<Handler*> wakes_;
std::unique_ptr<tsl::Thread> thread_;
Expand Down
5 changes: 5 additions & 0 deletions xla/python/transfer/event_loop.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ limitations under the License.
#include "absl/functional/any_invocable.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/time/time.h"

// socket.h in conda sysroot include directory does not define
// SO_ZEROCOPY and SO_EE_ORIGIN_ZEROCOPY that were introduced in a
Expand Down Expand Up @@ -71,6 +72,10 @@ class PollEventLoop {
// Notifies the EventLoop to call HandleEvents with a spurious wake.
virtual void SendWake(Handler* handler) = 0;

// Run callback on the event loop at some point in the future.
virtual void ScheduleAt(absl::Time time,
absl::AnyInvocable<void() &&> cb) = 0;

private:
// Implementation detail of Handler::Register.
virtual void RegisterHandler(Handler* handler) = 0;
Expand Down
9 changes: 9 additions & 0 deletions xla/python/transfer/event_loop_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -154,5 +154,14 @@ TEST(EventLoopTest, TestSchedule) {
done_notify.WaitForNotification();
}

TEST(EventLoopTest, TestScheduleAt) {
absl::Notification done_notify;
auto wake_time = absl::Now() + absl::Seconds(2);
PollEventLoop::GetDefault()->ScheduleAt(
wake_time, [&done_notify]() { done_notify.Notify(); });
done_notify.WaitForNotification();
ASSERT_GE(absl::Now(), wake_time);
}

} // namespace
} // namespace aux

0 comments on commit 6b7e725

Please sign in to comment.