Skip to content

Commit

Permalink
Merge pull request #1138 from rust-lang/careful-dropping
Browse files Browse the repository at this point in the history
Cancel tokens and abort tasks on drop
  • Loading branch information
shepmaster authored Feb 14, 2025
2 parents 87a7451 + 036f2b8 commit f698118
Show file tree
Hide file tree
Showing 8 changed files with 79 additions and 37 deletions.
22 changes: 21 additions & 1 deletion compiler/base/orchestrator/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion compiler/base/orchestrator/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ snafu = { version = "0.8.0", default-features = false, features = ["futures", "s
strum_macros = { version = "0.26.1", default-features = false }
tokio = { version = "1.28", default-features = false, features = ["fs", "io-std", "io-util", "macros", "process", "rt", "time", "sync"] }
tokio-stream = { version = "0.1.14", default-features = false }
tokio-util = { version = "0.7.8", default-features = false, features = ["io", "io-util"] }
tokio-util = { version = "0.7.8", default-features = false, features = ["io", "io-util", "rt"] }
toml = { version = "0.8.2", default-features = false, features = ["parse", "display"] }
tracing = { version = "0.1.37", default-features = false, features = ["attributes"] }

Expand Down
28 changes: 13 additions & 15 deletions compiler/base/orchestrator/src/coordinator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ use tokio::{
process::{Child, ChildStdin, ChildStdout, Command},
select,
sync::{mpsc, oneshot, OnceCell},
task::{JoinHandle, JoinSet},
task::JoinSet,
time::{self, MissedTickBehavior},
try_join,
};
use tokio_stream::wrappers::ReceiverStream;
use tokio_util::{io::SyncIoBridge, sync::CancellationToken};
use tokio_util::{io::SyncIoBridge, sync::CancellationToken, task::AbortOnDropHandle};
use tracing::{error, info, info_span, instrument, trace, trace_span, warn, Instrument};

use crate::{
Expand All @@ -30,7 +30,7 @@ use crate::{
ExecuteCommandResponse, JobId, Multiplexed, OneToOneResponse, ReadFileRequest,
ReadFileResponse, SerializedError2, WorkerMessage, WriteFileRequest,
},
DropErrorDetailsExt,
DropErrorDetailsExt, TaskAbortExt as _,
};

pub mod limits;
Expand Down Expand Up @@ -1161,7 +1161,7 @@ impl Drop for CancelOnDrop {
#[derive(Debug)]
struct Container {
permit: Box<dyn ContainerPermit>,
task: JoinHandle<Result<()>>,
task: AbortOnDropHandle<Result<()>>,
kill_child: TerminateContainer,
modify_cargo_toml: ModifyCargoToml,
commander: Commander,
Expand All @@ -1186,7 +1186,8 @@ impl Container {

let (command_tx, command_rx) = mpsc::channel(8);
let demultiplex_task =
tokio::spawn(Commander::demultiplex(command_rx, from_worker_rx).in_current_span());
tokio::spawn(Commander::demultiplex(command_rx, from_worker_rx).in_current_span())
.abort_on_drop();

let task = tokio::spawn(
async move {
Expand Down Expand Up @@ -1216,7 +1217,8 @@ impl Container {
Ok(())
}
.in_current_span(),
);
)
.abort_on_drop();

let commander = Commander {
to_worker_tx,
Expand Down Expand Up @@ -1865,7 +1867,8 @@ impl Container {
}
}
.instrument(trace_span!("cargo task").or_current())
});
})
.abort_on_drop();

Ok(SpawnCargo {
permit,
Expand Down Expand Up @@ -2128,7 +2131,7 @@ pub enum DoRequestError {

struct SpawnCargo {
permit: Box<dyn ProcessPermit>,
task: JoinHandle<Result<ExecuteCommandResponse, SpawnCargoError>>,
task: AbortOnDropHandle<Result<ExecuteCommandResponse, SpawnCargoError>>,
stdin_tx: mpsc::Sender<String>,
stdout_rx: mpsc::Receiver<String>,
stderr_rx: mpsc::Receiver<String>,
Expand Down Expand Up @@ -2842,14 +2845,9 @@ fn spawn_io_queue(stdin: ChildStdin, stdout: ChildStdout, token: CancellationTok
let handle = tokio::runtime::Handle::current();

loop {
let coordinator_msg = handle.block_on(async {
select! {
() = token.cancelled() => None,
msg = rx.recv() => msg,
}
});
let coordinator_msg = handle.block_on(token.run_until_cancelled(rx.recv()));

let Some(coordinator_msg) = coordinator_msg else {
let Some(Some(coordinator_msg)) = coordinator_msg else {
break;
};

Expand Down
10 changes: 10 additions & 0 deletions compiler/base/orchestrator/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,16 @@ pub mod coordinator;
mod message;
pub mod worker;

/// Extension trait that converts a task handle into a
/// `tokio_util::task::AbortOnDropHandle`.
///
/// `T` is the output type carried through to the returned
/// `AbortOnDropHandle<T>`. Implementors must be `Sized` because the
/// conversion takes `self` by value.
pub trait TaskAbortExt<T>: Sized {
/// Consumes `self` and returns it as an `AbortOnDropHandle<T>`.
fn abort_on_drop(self) -> tokio_util::task::AbortOnDropHandle<T>;
}

/// Allows `.abort_on_drop()` to be chained directly onto a
/// `tokio::task::JoinHandle<T>`, such as the value returned by `tokio::spawn`.
impl<T> TaskAbortExt<T> for tokio::task::JoinHandle<T> {
/// Wraps this `JoinHandle` with `AbortOnDropHandle::new`, taking ownership
/// of the handle.
// NOTE(review): per the tokio-util documentation, dropping the returned
// wrapper aborts the underlying task — confirm against the pinned version.
fn abort_on_drop(self) -> tokio_util::task::AbortOnDropHandle<T> {
tokio_util::task::AbortOnDropHandle::new(self)
}
}

pub trait DropErrorDetailsExt<T> {
fn drop_error_details(self) -> Result<T, tokio::sync::mpsc::error::SendError<()>>;
}
Expand Down
18 changes: 10 additions & 8 deletions compiler/base/orchestrator/src/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ use tokio::{
sync::mpsc,
task::JoinSet,
};
use tokio_util::sync::CancellationToken;
use tokio_util::sync::{CancellationToken, DropGuard};

use crate::{
bincode_input_closed,
Expand All @@ -55,7 +55,7 @@ use crate::{
ExecuteCommandResponse, JobId, Multiplexed, ReadFileRequest, ReadFileResponse,
SerializedError2, WorkerMessage, WriteFileRequest, WriteFileResponse,
},
DropErrorDetailsExt,
DropErrorDetailsExt as _, TaskAbortExt as _,
};

pub async fn listen(project_dir: impl Into<PathBuf>) -> Result<(), Error> {
Expand All @@ -66,14 +66,16 @@ pub async fn listen(project_dir: impl Into<PathBuf>) -> Result<(), Error> {
let mut io_tasks = spawn_io_queue(coordinator_msg_tx, worker_msg_rx);

let (process_tx, process_rx) = mpsc::channel(8);
let process_task = tokio::spawn(manage_processes(process_rx, project_dir.clone()));
let process_task =
tokio::spawn(manage_processes(process_rx, project_dir.clone())).abort_on_drop();

let handler_task = tokio::spawn(handle_coordinator_message(
coordinator_msg_rx,
worker_msg_tx,
project_dir,
process_tx,
));
))
.abort_on_drop();

select! {
Some(io_task) = io_tasks.join_next() => {
Expand Down Expand Up @@ -403,7 +405,7 @@ struct ProcessState {
processes: JoinSet<Result<(), ProcessError>>,
stdin_senders: HashMap<JobId, mpsc::Sender<String>>,
stdin_shutdown_tx: mpsc::Sender<JobId>,
kill_tokens: HashMap<JobId, CancellationToken>,
kill_tokens: HashMap<JobId, DropGuard>,
}

impl ProcessState {
Expand Down Expand Up @@ -456,7 +458,7 @@ impl ProcessState {

let task_set = stream_stdio(worker_msg_tx.clone(), stdin_rx, stdin, stdout, stderr);

self.kill_tokens.insert(job_id, token.clone());
self.kill_tokens.insert(job_id, token.clone().drop_guard());

self.processes.spawn({
let stdin_shutdown_tx = self.stdin_shutdown_tx.clone();
Expand Down Expand Up @@ -508,8 +510,8 @@ impl ProcessState {
}

fn kill(&mut self, job_id: JobId) {
if let Some(token) = self.kill_tokens.get(&job_id) {
token.cancel();
if let Some(token) = self.kill_tokens.remove(&job_id) {
drop(token);
}
}
}
Expand Down
12 changes: 10 additions & 2 deletions ui/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 6 additions & 5 deletions ui/src/server_axum/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use futures::{
future::{Fuse, FusedFuture as _},
FutureExt as _,
};
use orchestrator::DropErrorDetailsExt as _;
use orchestrator::{DropErrorDetailsExt as _, TaskAbortExt as _};
use snafu::prelude::*;
use std::{
future::Future,
Expand All @@ -13,9 +13,9 @@ use std::{
use tokio::{
select,
sync::{mpsc, oneshot},
task::JoinHandle,
time,
};
use tokio_util::task::AbortOnDropHandle;
use tracing::warn;

const ONE_HUNDRED_MILLISECONDS: Duration = Duration::from_millis(100);
Expand Down Expand Up @@ -48,12 +48,12 @@ where
{
pub fn spawn<Fut>(
f: impl FnOnce(mpsc::Receiver<CacheTaskItem<T, E>>) -> Fut,
) -> (JoinHandle<()>, Self)
) -> (AbortOnDropHandle<()>, Self)
where
Fut: Future<Output = ()> + Send + 'static,
{
let (tx, rx) = mpsc::channel(8);
let task = tokio::spawn(f(rx));
let task = tokio::spawn(f(rx)).abort_on_drop();
let cache_tx = CacheTx(tx);
(task, cache_tx)
}
Expand Down Expand Up @@ -148,7 +148,8 @@ where
let new_value = generator().await.map_err(CacheError::from);
CacheInfo::build(new_value)
}
});
})
.abort_on_drop();

new_value.set(new_value_task.fuse());
}
Expand Down
13 changes: 8 additions & 5 deletions ui/src/server_axum/websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use tokio::{
task::{AbortHandle, JoinSet},
time,
};
use tokio_util::sync::CancellationToken;
use tokio_util::sync::{CancellationToken, DropGuard};
use tracing::{error, info, instrument, warn, Instrument};

#[derive(Debug, serde::Deserialize, serde::Serialize)]
Expand Down Expand Up @@ -525,7 +525,7 @@ async fn handle_idle(manager: &mut CoordinatorManager, tx: &ResponseTx) -> Contr
ControlFlow::Continue(())
}

type ActiveExecutionInfo = (CancellationToken, Option<mpsc::Sender<String>>);
type ActiveExecutionInfo = (DropGuard, Option<mpsc::Sender<String>>);

async fn handle_msg(
txt: &str,
Expand All @@ -545,7 +545,10 @@ async fn handle_msg(

let guard = db.clone().start_with_guard("ws.Execute", txt).await;

active_executions.insert(meta.sequence_number, (token.clone(), Some(execution_tx)));
active_executions.insert(
meta.sequence_number,
(token.clone().drop_guard(), Some(execution_tx)),
);

// TODO: Should a single execute / build / etc. session have a timeout of some kind?
let spawned = manager
Expand Down Expand Up @@ -602,11 +605,11 @@ async fn handle_msg(
}

Ok(ExecuteKill { meta }) => {
let Some((token, _)) = active_executions.get(&meta.sequence_number) else {
let Some((token, _)) = active_executions.remove(&meta.sequence_number) else {
warn!("Received kill for an execution that is no longer active");
return;
};
token.cancel();
drop(token);
}

Err(e) => {
Expand Down

0 comments on commit f698118

Please sign in to comment.