diff --git a/rumqttd/CHANGELOG.md b/rumqttd/CHANGELOG.md
index ed1a3248..1e853e16 100644
--- a/rumqttd/CHANGELOG.md
+++ b/rumqttd/CHANGELOG.md
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Added
 - Assign random identifier to clients connecting with empty client id.
 - `Unsubscribe` with `local::LinkTx`.
+- Optional shutdown handle for `Broker` in the form of `BrokerHandle`.
 
 ### Changed
 - Public re-export `Strategy` for shared subscriptions
diff --git a/rumqttd/src/server/broker.rs b/rumqttd/src/server/broker.rs
index d737fa79..99c12181 100644
--- a/rumqttd/src/server/broker.rs
+++ b/rumqttd/src/server/broker.rs
@@ -14,7 +14,7 @@ use flume::{RecvError, SendError, Sender};
 use std::collections::HashMap;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
 use std::sync::{Arc, Mutex};
-use tokio_util::sync::CancellationToken;
+use tokio_util::sync::{CancellationToken, DropGuard};
 use tracing::{error, field, info, warn, Instrument};
 use uuid::Uuid;
 
@@ -41,7 +41,7 @@ use crate::{Config, ConnectionId, ServerSettings};
 
 use tokio::net::{TcpListener, TcpStream};
 use tokio::time::error::Elapsed;
-use tokio::{select, task, time};
+use tokio::{task, time};
 
 #[derive(Debug, thiserror::Error)]
 #[error("Acceptor error")]
@@ -68,16 +68,30 @@ pub enum Error {
 pub struct Broker {
     config: Arc<Config>,
     router_tx: Sender<(ConnectionId, Event)>,
+    cancellation_token: CancellationToken,
 }
 
+/// Handle to signal a shutdown to the broker.
+///
+/// NOTE: Handles are oneshot. So if a broker is stopped using the [`BrokerHandle::stop()`] method,
+/// any handles that were created before stop was called will no longer control the broker.
 pub struct BrokerHandle {
     token: CancellationToken,
 }
 
 impl BrokerHandle {
-    pub fn stop(self) {
+    pub fn stop(&self) {
         self.token.cancel();
     }
+    /// Returns `true` once the token held by this handle has been cancelled.
+    pub fn is_stopped(&self) -> bool {
+        self.token.is_cancelled()
+    }
+    /// Fetches the underlying drop guard for the broker handle.
+    /// Will cancel the broker if dropped.
+    pub fn drop_guard(self) -> DropGuard {
+        self.token.drop_guard()
+    }
 }
 
 impl Broker {
@@ -85,6 +99,7 @@ impl Broker {
         let config = Arc::new(config);
         let router_config = config.router.clone();
         let router: Router = Router::new(config.id, router_config);
+        let cancellation_token = CancellationToken::new();
 
         // Setup cluster if cluster settings are configured.
         match config.cluster.clone() {
@@ -98,11 +113,19 @@ impl Broker {
                 // Start router first and then cluster in the background
                 let router_tx = router.spawn();
                 // cluster.spawn();
-                Broker { config, router_tx }
+                Broker {
+                    config,
+                    router_tx,
+                    cancellation_token,
+                }
             }
             None => {
                 let router_tx = router.spawn();
-                Broker { config, router_tx }
+                Broker {
+                    config,
+                    router_tx,
+                    cancellation_token,
+                }
             }
         }
     }
@@ -166,14 +189,21 @@ impl Broker {
         Ok((link_tx, link_rx))
     }
 
+    /// Get a shutdown handle for this broker.
+    pub fn handle(&self) -> BrokerHandle {
+        BrokerHandle {
+            token: self.cancellation_token.clone(),
+        }
+    }
+
     #[tracing::instrument(skip(self))]
-    pub fn start(&mut self) -> Result<BrokerHandle, Error> {
+    pub fn start(&mut self) -> Result<(), Error> {
         if self.config.v4.is_none()
             && self.config.v5.is_none()
             && (cfg!(not(feature = "websocket")) || self.config.ws.is_none())
         {
             return Err(Error::Config(
-                "Atleast one server config must be specified, \
+                "At least one server config must be specified, \
                 consider adding either of [v4.x]/[v5.x] or [ws.x] (if enabled) in config file."
                     .to_string(),
             ));
@@ -183,22 +213,22 @@ impl Broker {
         // so we collect handles for all of the spawned servers
         let mut server_thread_handles = Vec::new();
 
-        let cancel_token = CancellationToken::new();
-
         if let Some(metrics_config) = self.config.metrics.clone() {
             let timer_thread = thread::Builder::new().name("timer".to_owned());
             let router_tx = self.router_tx.clone();
-            let token = cancel_token.clone();
+            let token = self.cancellation_token.clone();
             timer_thread.spawn(move || {
                 let mut runtime = tokio::runtime::Builder::new_current_thread();
                 let runtime = runtime.enable_all().build().unwrap();
 
                 runtime.block_on(async move {
-                    select! {
-                        _ = timer::start(metrics_config, router_tx) => {}
-                        _ = token.cancelled() => {
-                            info!("shutting down timer");
-                        }
+                    let out = token
+                        .run_until_cancelled(timer::start(metrics_config, router_tx))
+                        .await;
+                    if out.is_none() {
+                        info!("Timer thread was cancelled.")
+                    } else {
+                        info!("Timer thread completed.")
                     }
                 });
             })?;
@@ -208,20 +238,23 @@ impl Broker {
         if let Some(bridge_config) = self.config.bridge.clone() {
             let bridge_thread = thread::Builder::new().name(bridge_config.name.clone());
             let router_tx = self.router_tx.clone();
-            let token = cancel_token.clone();
+            let token = self.cancellation_token.clone();
             bridge_thread.spawn(move || {
                 let mut runtime = tokio::runtime::Builder::new_current_thread();
                 let runtime = runtime.enable_all().build().unwrap();
 
                 runtime.block_on(async move {
-                    select! {
-                        val = bridge::start(bridge_config, router_tx, V4) => {
-                            if let Err(e) = val { error!(error=?e, "Bridge Link error") };
+                    let out = token
+                        .run_until_cancelled(bridge::start(bridge_config, router_tx, V4))
+                        .await;
+                    match out {
+                        Some(Ok(())) => {
+                            info!("Bridge thread completed.")
                         }
-                        _ = token.cancelled() => {
-                            info!("shutting down bridge");
+                        Some(Err(err)) => error!(error=%err, "Bridge thread error"),
+                        None => {
+                            info!("Bridge thread cancelled.")
                         }
-
                     }
                 });
             })?;
@@ -229,23 +262,28 @@ impl Broker {
 
         // Spawn servers in a separate thread.
         if let Some(v4_config) = &self.config.v4 {
-            for (_, config) in v4_config.clone() {
+            for (server_name, config) in v4_config.clone() {
                 let server_thread = thread::Builder::new().name(config.name.clone());
                 let mut server = Server::new(config, self.router_tx.clone(), V4);
-                let token = cancel_token.clone();
+                let token = self.cancellation_token.clone();
                 let handle = server_thread.spawn(move || {
                     let mut runtime = tokio::runtime::Builder::new_current_thread();
                     let runtime = runtime.enable_all().build().unwrap();
 
                     runtime.block_on(async {
-                        select! {
-                            val = server.start(LinkType::Remote) => {
-                                if let Err(e) = val { error!(error=?e, "Server error - V4") };
+                        let out = token
+                            .run_until_cancelled(server.start(LinkType::Remote))
+                            .await;
+                        match out {
+                            Some(Ok(())) => {
+                                info!(server_name, "V4 server thread completed.")
                             }
-                            _ = token.cancelled() => {
-                                info!("shutting down V4 server");
+                            Some(Err(err)) => {
+                                error!(server_name, error=%err, "V4 server error")
+                            }
+                            None => {
+                                info!(server_name, "V4 server cancelled")
                             }
-
                         }
                     });
                 })?;
@@ -254,23 +292,28 @@ impl Broker {
         }
 
         if let Some(v5_config) = &self.config.v5 {
-            for (_, config) in v5_config.clone() {
+            for (server_name, config) in v5_config.clone() {
                 let server_thread = thread::Builder::new().name(config.name.clone());
                 let mut server = Server::new(config, self.router_tx.clone(), V5);
-                let token = cancel_token.clone();
+                let token = self.cancellation_token.clone();
                 let handle = server_thread.spawn(move || {
                     let mut runtime = tokio::runtime::Builder::new_current_thread();
                     let runtime = runtime.enable_all().build().unwrap();
 
                     runtime.block_on(async {
-                        select! {
-                            val = server.start(LinkType::Remote) => {
-                                if let Err(e) = val { error!(error=?e, "Server error - V5") };
+                        let out = token
+                            .run_until_cancelled(server.start(LinkType::Remote))
+                            .await;
+                        match out {
+                            Some(Ok(())) => {
+                                info!(server_name, "V5 server thread completed.")
                             }
-                            _ = token.cancelled() => {
-                                info!("shutting down V5 server");
+                            Some(Err(err)) => {
+                                error!(server_name, error=%err, "V5 server error")
+                            }
+                            None => {
+                                info!(server_name, "V5 server cancelled")
                             }
-
                         }
                     });
                 })?;
@@ -285,24 +328,29 @@ impl Broker {
 
         #[cfg(feature = "websocket")]
         if let Some(ws_config) = &self.config.ws {
-            for (_, config) in ws_config.clone() {
+            for (server_name, config) in ws_config.clone() {
                 let server_thread = thread::Builder::new().name(config.name.clone());
                 //TODO: Add support for V5 procotol with websockets. Registered in config or on ServerSettings
                 let mut server = Server::new(config, self.router_tx.clone(), V4);
-                let token = cancel_token.clone();
+                let token = self.cancellation_token.clone();
                 let handle = server_thread.spawn(move || {
                     let mut runtime = tokio::runtime::Builder::new_current_thread();
                     let runtime = runtime.enable_all().build().unwrap();
 
                     runtime.block_on(async {
-                        select! {
-                            val = server.start(LinkType::Websocket) => {
-                                if let Err(e) = val { error!(error=?e, "Server error - WS") };
+                        let out = token
+                            .run_until_cancelled(server.start(LinkType::Websocket))
+                            .await;
+                        match out {
+                            Some(Ok(())) => {
+                                info!(server_name, "WS server thread completed.")
                             }
-                            _ = token.cancelled() => {
-                                info!("shutting down WS server");
+                            Some(Err(err)) => {
+                                error!(server_name, error=%err, "WS server error")
+                            }
+                            None => {
+                                info!(server_name, "WS server cancelled")
                             }
-
                         }
                     });
                 })?;
@@ -327,7 +375,7 @@ impl Broker {
             };
             let metrics_thread = thread::Builder::new().name("Metrics".to_owned());
             let meter_link = self.meters().unwrap();
-            let token = cancel_token.clone();
+            let token = self.cancellation_token.clone();
             metrics_thread.spawn(move || {
                 let builder = PrometheusBuilder::new().with_http_listener(addr);
                 builder.install().unwrap();
@@ -336,6 +384,11 @@ impl Broker {
                 let total_connections = gauge!("metrics.router.total_connections");
                 let failed_publishes = gauge!("metrics.router.failed_publishes");
                 loop {
+                    if token.is_cancelled() {
+                        info!("shutting down prometheus");
+                        break;
+                    }
+
                     if let Ok(metrics) = meter_link.recv() {
                         for m in metrics {
                             match m {
@@ -349,11 +402,6 @@ impl Broker {
                         }
                     }
 
-                    if token.is_cancelled() {
-                        info!("shutting down prometheus");
-                        break;
-                    }
-
                     std::thread::sleep(Duration::from_secs(timeout));
                 }
             })?;
@@ -364,16 +412,21 @@ impl Broker {
 
             let console_link = Arc::new(console_link);
             let console_thread = thread::Builder::new().name("Console".to_string());
-            let token = cancel_token.clone();
+            let token = self.cancellation_token.clone();
             console_thread.spawn(move || {
                 let mut runtime = tokio::runtime::Builder::new_current_thread();
                 let runtime = runtime.enable_all().build().unwrap();
 
                 runtime.block_on(async move {
-                    select! {
-                        _ = console::start(console_link) => {}
-                        _ = token.cancelled() => {
-                            info!("shutting down console");
+                    let out = token
+                        .run_until_cancelled(console::start(console_link))
+                        .await;
+                    match out {
+                        Some(()) => {
+                            info!("Console thread completed.")
+                        }
+                        None => {
+                            info!("Console thread cancelled")
                         }
                     }
                 });
@@ -383,15 +436,13 @@ impl Broker {
         // // in ideal case, where server doesn't crash, join() will never resolve
         // // we still try to join threads so that we don't return from function
         // // unless everything crashes.
-        // server_thread_handles.into_iter().for_each(|handle| {
-        //     // join() might panic in case the thread panics
-        //     // we just ignore it
-        //     let _ = handle.join();
-        // });
-
-        Ok(BrokerHandle {
-            token: cancel_token,
-        })
+        server_thread_handles.into_iter().for_each(|handle| {
+            // join() might panic in case the thread panics
+            // we just ignore it
+            let _ = handle.join();
+        });
+        self.cancellation_token = CancellationToken::new();
+        Ok(())
     }
 }
 
@@ -691,3 +742,56 @@ async fn remote(
         router_tx.send((connection_id, message)).ok();
     }
 }
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    #[tokio::test]
+    async fn test_broker_handle_stops_broker() {
+        let config = config::Config::builder()
+            .add_source(config::File::with_name("rumqttd.toml"))
+            .build()
+            .unwrap();
+
+        let config: Config = config.try_deserialize().unwrap();
+
+        dbg!(&config);
+
+        let mut broker = Broker::new(config);
+        let broker_handle = broker.handle();
+        let broker_task = tokio::task::spawn_blocking(move || broker.start());
+        assert!(!broker_handle.is_stopped());
+        tokio::time::sleep(Duration::from_secs(1)).await;
+        broker_handle.stop();
+        tokio::time::timeout(Duration::from_secs(5), broker_task)
+            .await
+            .unwrap()
+            .unwrap()
+            .unwrap();
+        assert!(broker_handle.is_stopped());
+    }
+    #[tokio::test]
+    async fn test_broker_handle_drop_guard_stops_broker_when_dropped() {
+        let config = config::Config::builder()
+            .add_source(config::File::with_name("rumqttd.toml"))
+            .build()
+            .unwrap();
+
+        let config: Config = config.try_deserialize().unwrap();
+
+        dbg!(&config);
+
+        let mut broker = Broker::new(config);
+        let broker_handle = broker.handle();
+        let broker_task = tokio::task::spawn_blocking(move || broker.start());
+        assert!(!broker_handle.is_stopped());
+        tokio::time::sleep(Duration::from_secs(1)).await;
+        drop(broker_handle.drop_guard());
+
+        tokio::time::timeout(Duration::from_secs(5), broker_task)
+            .await
+            .unwrap()
+            .unwrap()
+            .unwrap();
+    }
+}