diff --git a/compiler/base/orchestrator/Cargo.lock b/compiler/base/orchestrator/Cargo.lock index 55ad4a33..b2238ec1 100644 --- a/compiler/base/orchestrator/Cargo.lock +++ b/compiler/base/orchestrator/Cargo.lock @@ -171,6 +171,17 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" +[[package]] +name = "futures-macro" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "futures-sink" version = "0.3.31" @@ -192,6 +203,7 @@ dependencies = [ "futures-channel", "futures-core", "futures-io", + "futures-macro", "futures-sink", "futures-task", "memchr", @@ -218,6 +230,12 @@ version = "0.31.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" + [[package]] name = "hashbrown" version = "0.15.2" @@ -243,7 +261,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8c9c992b02b5b4c94ea26e32fe5bccb7aa7d9f390ab5c1221ff895bc7ea8b652" dependencies = [ "equivalent", - "hashbrown", + "hashbrown 0.15.2", ] [[package]] @@ -711,6 +729,8 @@ dependencies = [ "bytes", "futures-core", "futures-sink", + "futures-util", + "hashbrown 0.14.5", "pin-project-lite", "tokio", ] diff --git a/compiler/base/orchestrator/Cargo.toml b/compiler/base/orchestrator/Cargo.toml index 2f0b7bd8..066f4303 100644 --- a/compiler/base/orchestrator/Cargo.toml +++ b/compiler/base/orchestrator/Cargo.toml @@ -20,7 +20,7 @@ snafu = { version = "0.8.0", default-features = false, features = ["futures", "s strum_macros = { version = "0.26.1", default-features = false } tokio = { version = "1.28", default-features = false, features = ["fs", "io-std", "io-util", "macros", "process", "rt", "time", "sync"] } tokio-stream = { version = "0.1.14", default-features = false } -tokio-util = { version = "0.7.8", default-features = false, features = ["io", "io-util"] } +tokio-util = { version = "0.7.8", default-features = false, features = ["io", "io-util", "rt"] } toml = { version = "0.8.2", default-features = false, features = ["parse", "display"] } tracing = { version = "0.1.37", default-features = false, features = ["attributes"] } diff --git a/compiler/base/orchestrator/src/coordinator.rs b/compiler/base/orchestrator/src/coordinator.rs index 2e58aeb3..85645496 100644 --- a/compiler/base/orchestrator/src/coordinator.rs +++ b/compiler/base/orchestrator/src/coordinator.rs @@ -15,12 +15,12 @@ use tokio::{ process::{Child, ChildStdin, ChildStdout, Command}, select, sync::{mpsc, oneshot, OnceCell}, - task::{JoinHandle, JoinSet}, + task::JoinSet, time::{self, MissedTickBehavior}, try_join, }; use tokio_stream::wrappers::ReceiverStream; -use tokio_util::{io::SyncIoBridge, sync::CancellationToken}; +use tokio_util::{io::SyncIoBridge, sync::CancellationToken, task::AbortOnDropHandle}; use tracing::{error, info, info_span, instrument, trace, trace_span, warn, Instrument}; use crate::{ @@ -30,7 +30,7 @@ use crate::{ ExecuteCommandResponse, JobId, Multiplexed, OneToOneResponse, ReadFileRequest, ReadFileResponse, SerializedError2, WorkerMessage, WriteFileRequest, }, - DropErrorDetailsExt, + DropErrorDetailsExt, TaskAbortExt as _, }; pub mod limits; @@ -1161,7 +1161,7 @@ impl Drop for CancelOnDrop { #[derive(Debug)] struct Container { permit: Box, - task: JoinHandle>, + task: AbortOnDropHandle>, kill_child: TerminateContainer, modify_cargo_toml: ModifyCargoToml, commander: Commander, @@ -1186,7 +1186,8 @@ impl Container { let (command_tx, command_rx) = mpsc::channel(8); let demultiplex_task = - tokio::spawn(Commander::demultiplex(command_rx, from_worker_rx).in_current_span()); + tokio::spawn(Commander::demultiplex(command_rx, from_worker_rx).in_current_span()) + .abort_on_drop(); let task = tokio::spawn( async move { @@ -1216,7 +1217,8 @@ impl Container { Ok(()) } .in_current_span(), - ); + ) + .abort_on_drop(); let commander = Commander { to_worker_tx, @@ -1865,7 +1867,8 @@ impl Container { } } .instrument(trace_span!("cargo task").or_current()) - }); + }) + .abort_on_drop(); Ok(SpawnCargo { permit, @@ -2128,7 +2131,7 @@ pub enum DoRequestError { struct SpawnCargo { permit: Box, - task: JoinHandle>, + task: AbortOnDropHandle>, stdin_tx: mpsc::Sender, stdout_rx: mpsc::Receiver, stderr_rx: mpsc::Receiver, @@ -2842,14 +2845,9 @@ fn spawn_io_queue(stdin: ChildStdin, stdout: ChildStdout, token: CancellationTok let handle = tokio::runtime::Handle::current(); loop { - let coordinator_msg = handle.block_on(async { - select! { - () = token.cancelled() => None, - msg = rx.recv() => msg, - } - }); + let coordinator_msg = handle.block_on(token.run_until_cancelled(rx.recv())); - let Some(coordinator_msg) = coordinator_msg else { + let Some(Some(coordinator_msg)) = coordinator_msg else { break; }; diff --git a/compiler/base/orchestrator/src/lib.rs b/compiler/base/orchestrator/src/lib.rs index 08662e31..62a3746a 100644 --- a/compiler/base/orchestrator/src/lib.rs +++ b/compiler/base/orchestrator/src/lib.rs @@ -4,6 +4,16 @@ pub mod coordinator; mod message; pub mod worker; +pub trait TaskAbortExt: Sized { + fn abort_on_drop(self) -> tokio_util::task::AbortOnDropHandle; +} + +impl TaskAbortExt for tokio::task::JoinHandle { + fn abort_on_drop(self) -> tokio_util::task::AbortOnDropHandle { + tokio_util::task::AbortOnDropHandle::new(self) + } +} + pub trait DropErrorDetailsExt { fn drop_error_details(self) -> Result>; } diff --git a/compiler/base/orchestrator/src/worker.rs b/compiler/base/orchestrator/src/worker.rs index 3e912c3b..5ce151df 100644 --- a/compiler/base/orchestrator/src/worker.rs +++ b/compiler/base/orchestrator/src/worker.rs @@ -46,7 +46,7 @@ use tokio::{ sync::mpsc, task::JoinSet, }; -use tokio_util::sync::CancellationToken; +use tokio_util::sync::{CancellationToken, DropGuard}; use crate::{ bincode_input_closed, @@ -55,7 +55,7 @@ use crate::{ ExecuteCommandResponse, JobId, Multiplexed, ReadFileRequest, ReadFileResponse, SerializedError2, WorkerMessage, WriteFileRequest, WriteFileResponse, }, - DropErrorDetailsExt, + DropErrorDetailsExt as _, TaskAbortExt as _, }; pub async fn listen(project_dir: impl Into) -> Result<(), Error> { @@ -66,14 +66,16 @@ pub async fn listen(project_dir: impl Into) -> Result<(), Error> { let mut io_tasks = spawn_io_queue(coordinator_msg_tx, worker_msg_rx); let (process_tx, process_rx) = mpsc::channel(8); - let process_task = tokio::spawn(manage_processes(process_rx, project_dir.clone())); + let process_task = + tokio::spawn(manage_processes(process_rx, project_dir.clone())).abort_on_drop(); let handler_task = tokio::spawn(handle_coordinator_message( coordinator_msg_rx, worker_msg_tx, project_dir, process_tx, - )); + )) + .abort_on_drop(); select! { Some(io_task) = io_tasks.join_next() => { @@ -403,7 +405,7 @@ struct ProcessState { processes: JoinSet>, stdin_senders: HashMap>, stdin_shutdown_tx: mpsc::Sender, - kill_tokens: HashMap, + kill_tokens: HashMap, } impl ProcessState { @@ -456,7 +458,7 @@ impl ProcessState { let task_set = stream_stdio(worker_msg_tx.clone(), stdin_rx, stdin, stdout, stderr); - self.kill_tokens.insert(job_id, token.clone()); + self.kill_tokens.insert(job_id, token.clone().drop_guard()); self.processes.spawn({ let stdin_shutdown_tx = self.stdin_shutdown_tx.clone(); @@ -508,8 +510,8 @@ impl ProcessState { } fn kill(&mut self, job_id: JobId) { - if let Some(token) = self.kill_tokens.get(&job_id) { - token.cancel(); + if let Some(token) = self.kill_tokens.remove(&job_id) { + drop(token); } } } diff --git a/ui/Cargo.lock b/ui/Cargo.lock index 59b02052..e475eda0 100644 --- a/ui/Cargo.lock +++ b/ui/Cargo.lock @@ -525,6 +525,12 @@ version = "0.31.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" + [[package]] name = "hashbrown" version = "0.15.2" @@ -540,7 +546,7 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7382cf6263419f2d8df38c55d7da83da5c18aef87fc7a7fc1fb1e344edfe14c1" dependencies = [ - "hashbrown", + "hashbrown 0.15.2", ] [[package]] @@ -871,7 +877,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8c9c992b02b5b4c94ea26e32fe5bccb7aa7d9f390ab5c1221ff895bc7ea8b652" dependencies = [ "equivalent", - "hashbrown", + "hashbrown 0.15.2", ] [[package]] @@ -1972,6 +1978,8 @@ dependencies = [ "bytes", "futures-core", "futures-sink", + "futures-util", + "hashbrown 0.14.5", "pin-project-lite", "tokio", ] diff --git a/ui/src/server_axum/cache.rs b/ui/src/server_axum/cache.rs index 65a2c02e..7fbe0011 100644 --- a/ui/src/server_axum/cache.rs +++ b/ui/src/server_axum/cache.rs @@ -2,7 +2,7 @@ use futures::{ future::{Fuse, FusedFuture as _}, FutureExt as _, }; -use orchestrator::DropErrorDetailsExt as _; +use orchestrator::{DropErrorDetailsExt as _, TaskAbortExt as _}; use snafu::prelude::*; use std::{ future::Future, @@ -13,9 +13,9 @@ use std::{ use tokio::{ select, sync::{mpsc, oneshot}, - task::JoinHandle, time, }; +use tokio_util::task::AbortOnDropHandle; use tracing::warn; const ONE_HUNDRED_MILLISECONDS: Duration = Duration::from_millis(100); @@ -48,12 +48,12 @@ where { pub fn spawn( f: impl FnOnce(mpsc::Receiver>) -> Fut, - ) -> (JoinHandle<()>, Self) + ) -> (AbortOnDropHandle<()>, Self) where Fut: Future + Send + 'static, { let (tx, rx) = mpsc::channel(8); - let task = tokio::spawn(f(rx)); + let task = tokio::spawn(f(rx)).abort_on_drop(); let cache_tx = CacheTx(tx); (task, cache_tx) } @@ -148,7 +148,8 @@ where let new_value = generator().await.map_err(CacheError::from); CacheInfo::build(new_value) } - }); + }) + .abort_on_drop(); new_value.set(new_value_task.fuse()); } diff --git a/ui/src/server_axum/websocket.rs b/ui/src/server_axum/websocket.rs index d621329c..3ddeda84 100644 --- a/ui/src/server_axum/websocket.rs +++ b/ui/src/server_axum/websocket.rs @@ -29,7 +29,7 @@ use tokio::{ task::{AbortHandle, JoinSet}, time, }; -use tokio_util::sync::CancellationToken; +use tokio_util::sync::{CancellationToken, DropGuard}; use tracing::{error, info, instrument, warn, Instrument}; #[derive(Debug, serde::Deserialize, serde::Serialize)] @@ -525,7 +525,7 @@ async fn handle_idle(manager: &mut CoordinatorManager, tx: &ResponseTx) -> Contr ControlFlow::Continue(()) } -type ActiveExecutionInfo = (CancellationToken, Option>); +type ActiveExecutionInfo = (DropGuard, Option>); async fn handle_msg( txt: &str, @@ -545,7 +545,10 @@ async fn handle_msg( let guard = db.clone().start_with_guard("ws.Execute", txt).await; - active_executions.insert(meta.sequence_number, (token.clone(), Some(execution_tx))); + active_executions.insert( + meta.sequence_number, + (token.clone().drop_guard(), Some(execution_tx)), + ); // TODO: Should a single execute / build / etc. session have a timeout of some kind? let spawned = manager @@ -602,11 +605,11 @@ async fn handle_msg( } Ok(ExecuteKill { meta }) => { - let Some((token, _)) = active_executions.get(&meta.sequence_number) else { + let Some((token, _)) = active_executions.remove(&meta.sequence_number) else { warn!("Received kill for an execution that is no longer active"); return; }; - token.cancel(); + drop(token); } Err(e) => {