Skip to content

Commit

Permalink
Trigger CancellationTokens on drop
Browse files Browse the repository at this point in the history
We'd previously done this for one specific usage via the
`CancelOnDrop` wrapper. This works for the more complicated case where
we need to continue to clone the token. Simpler cases can use the
`DropGuard` directly.
  • Loading branch information
shepmaster committed Feb 14, 2025
1 parent e226d44 commit 036f2b8
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 10 deletions.
10 changes: 5 additions & 5 deletions compiler/base/orchestrator/src/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ use tokio::{
sync::mpsc,
task::JoinSet,
};
use tokio_util::sync::CancellationToken;
use tokio_util::sync::{CancellationToken, DropGuard};

use crate::{
bincode_input_closed,
Expand Down Expand Up @@ -405,7 +405,7 @@ struct ProcessState {
processes: JoinSet<Result<(), ProcessError>>,
stdin_senders: HashMap<JobId, mpsc::Sender<String>>,
stdin_shutdown_tx: mpsc::Sender<JobId>,
kill_tokens: HashMap<JobId, CancellationToken>,
kill_tokens: HashMap<JobId, DropGuard>,
}

impl ProcessState {
Expand Down Expand Up @@ -458,7 +458,7 @@ impl ProcessState {

let task_set = stream_stdio(worker_msg_tx.clone(), stdin_rx, stdin, stdout, stderr);

self.kill_tokens.insert(job_id, token.clone());
self.kill_tokens.insert(job_id, token.clone().drop_guard());

self.processes.spawn({
let stdin_shutdown_tx = self.stdin_shutdown_tx.clone();
Expand Down Expand Up @@ -510,8 +510,8 @@ impl ProcessState {
}

fn kill(&mut self, job_id: JobId) {
if let Some(token) = self.kill_tokens.get(&job_id) {
token.cancel();
if let Some(token) = self.kill_tokens.remove(&job_id) {
drop(token);
}
}
}
Expand Down
13 changes: 8 additions & 5 deletions ui/src/server_axum/websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use tokio::{
task::{AbortHandle, JoinSet},
time,
};
use tokio_util::sync::CancellationToken;
use tokio_util::sync::{CancellationToken, DropGuard};
use tracing::{error, info, instrument, warn, Instrument};

#[derive(Debug, serde::Deserialize, serde::Serialize)]
Expand Down Expand Up @@ -525,7 +525,7 @@ async fn handle_idle(manager: &mut CoordinatorManager, tx: &ResponseTx) -> Contr
ControlFlow::Continue(())
}

type ActiveExecutionInfo = (CancellationToken, Option<mpsc::Sender<String>>);
type ActiveExecutionInfo = (DropGuard, Option<mpsc::Sender<String>>);

async fn handle_msg(
txt: &str,
Expand All @@ -545,7 +545,10 @@ async fn handle_msg(

let guard = db.clone().start_with_guard("ws.Execute", txt).await;

active_executions.insert(meta.sequence_number, (token.clone(), Some(execution_tx)));
active_executions.insert(
meta.sequence_number,
(token.clone().drop_guard(), Some(execution_tx)),
);

// TODO: Should a single execute / build / etc. session have a timeout of some kind?
let spawned = manager
Expand Down Expand Up @@ -602,11 +605,11 @@ async fn handle_msg(
}

Ok(ExecuteKill { meta }) => {
let Some((token, _)) = active_executions.get(&meta.sequence_number) else {
let Some((token, _)) = active_executions.remove(&meta.sequence_number) else {
warn!("Received kill for an execution that is no longer active");
return;
};
token.cancel();
drop(token);
}

Err(e) => {
Expand Down

0 comments on commit 036f2b8

Please sign in to comment.