From 74ce72d19c90a781b23f9d24809303a9bbcc4bc7 Mon Sep 17 00:00:00 2001 From: Silvestr Predko Date: Thu, 13 Jun 2024 14:57:23 +0300 Subject: [PATCH] - bridge.rs: - Add a `shutdown_rx` parameter to the `start` function to receive shutdown signals. - Handle the shutdown signal in the main loop and break out cleanly. - console.rs: - Add a `shutdown_rx` parameter to the `start` function to receive shutdown signals. - Use `axum::serve` with `with_graceful_shutdown` to handle shutdown signals. - timer.rs: - Add a `shutdown_rx` parameter to the `start` function to receive shutdown signals. - Handle the shutdown signal in the main loop and break out cleanly. - broker.rs: - Add a `shutdown_rx` and `shutdown_tx` to the `Broker` struct for shutdown handling. - Pass the `shutdown_rx` to various components like bridge, console, timer, and servers. - Handle shutdown signals in the server's main loop and break out cleanly. - Add a `ShutdownHandler` and `ShutdownDropGuard` for managing shutdown signals. These changes allow for graceful shutdown of the rumqtt broker and its components, ensuring clean termination and resource cleanup when a shutdown signal is received. --- rumqttd/src/link/bridge.rs | 6 ++ rumqttd/src/link/console.rs | 14 ++- rumqttd/src/link/timer.rs | 8 +- rumqttd/src/server/broker.rs | 168 +++++++++++++++++++++++++++++------ 4 files changed, 165 insertions(+), 31 deletions(-) diff --git a/rumqttd/src/link/bridge.rs b/rumqttd/src/link/bridge.rs index 0c6b193be..eac2050dd 100644 --- a/rumqttd/src/link/bridge.rs +++ b/rumqttd/src/link/bridge.rs @@ -12,6 +12,7 @@ use std::{io, net::AddrParseError, time::Duration}; use tokio::{ net::TcpStream, + sync::watch, time::{sleep, sleep_until, Instant}, }; @@ -48,6 +49,7 @@ pub async fn start

( config: BridgeConfig, router_tx: Sender<(ConnectionId, Event)>, protocol: P, + mut shutdown_rx: watch::Receiver<()>, ) -> Result<(), BridgeError> where P: Protocol + Clone + Send + 'static, @@ -154,6 +156,10 @@ where // resetting timeout because tokio::select! consumes the old timeout future timeout = sleep_until(ping_time + Duration::from_secs(config.ping_delay)); } + _ = shutdown_rx.changed() => { + debug!("Shutting down bridge"); + break 'outer Ok(()); + } } } } diff --git a/rumqttd/src/link/console.rs b/rumqttd/src/link/console.rs index 9f03f8a1e..d566e1896 100644 --- a/rumqttd/src/link/console.rs +++ b/rumqttd/src/link/console.rs @@ -9,8 +9,8 @@ use axum::Json; use axum::{routing::get, Router}; use flume::Sender; use std::sync::Arc; -use tokio::net::TcpListener; -use tracing::info; +use tokio::{net::TcpListener, sync::watch}; +use tracing::{debug, info}; #[derive(Debug)] pub struct ConsoleLink { @@ -39,7 +39,7 @@ impl ConsoleLink { } #[tracing::instrument] -pub async fn start(console: Arc) { +pub async fn start(console: Arc, mut shutdown_rx: watch::Receiver<()>) { let listener = TcpListener::bind(console.config.listen.clone()) .await .unwrap(); @@ -56,7 +56,13 @@ pub async fn start(console: Arc) { .route("/logs", post(logs)) .with_state(console); - axum::serve(listener, app).await.unwrap(); + axum::serve(listener, app) + .with_graceful_shutdown(async move { + debug!("Shutting down console"); + let _ = shutdown_rx.changed().await; + }) + .await + .unwrap(); } async fn root(State(console): State>) -> impl IntoResponse { diff --git a/rumqttd/src/link/timer.rs b/rumqttd/src/link/timer.rs index fcc3bf00c..e72d48d25 100644 --- a/rumqttd/src/link/timer.rs +++ b/rumqttd/src/link/timer.rs @@ -5,7 +5,8 @@ use crate::{router::Event, MetricType}; use crate::{ConnectionId, MetricSettings}; use flume::{SendError, Sender}; use tokio::select; -use tracing::error; +use tokio::sync::watch; +use tracing::{debug, error}; #[derive(Debug, thiserror::Error)] pub enum Error { @@ 
-18,6 +19,7 @@ pub enum Error { pub async fn start( config: HashMap, router_tx: Sender<(ConnectionId, Event)>, + mut shutdown_rx: watch::Receiver<()>, ) { let span = tracing::info_span!("metrics_timer"); let _guard = span.enter(); @@ -42,6 +44,10 @@ pub async fn start( error!("Failed to push alerts: {e}"); } } + _ = shutdown_rx.changed() => { + debug!("Shutting down metrics timer"); + break; + } } } } diff --git a/rumqttd/src/server/broker.rs b/rumqttd/src/server/broker.rs index 9886541c9..20053a4b3 100644 --- a/rumqttd/src/server/broker.rs +++ b/rumqttd/src/server/broker.rs @@ -14,7 +14,7 @@ use flume::{RecvError, SendError, Sender}; use std::collections::HashMap; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::sync::{Arc, Mutex}; -use tracing::{error, field, info, warn, Instrument}; +use tracing::{debug, error, field, info, warn, Instrument}; use uuid::Uuid; #[cfg(feature = "websocket")] @@ -39,6 +39,7 @@ use crate::router::{Event, Router}; use crate::{Config, ConnectionId, ServerSettings}; use tokio::net::{TcpListener, TcpStream}; +use tokio::sync::watch; use tokio::time::error::Elapsed; use tokio::{task, time}; @@ -67,6 +68,8 @@ pub enum Error { pub struct Broker { config: Arc, router_tx: Sender<(ConnectionId, Event)>, + shutdown_rx: watch::Receiver<()>, + shutdown_tx: Arc>, } impl Broker { @@ -74,6 +77,7 @@ impl Broker { let config = Arc::new(config); let router_config = config.router.clone(); let router: Router = Router::new(config.id, router_config); + let (tx, shutdown_rx) = watch::channel(()); // Setup cluster if cluster settings are configured. 
match config.cluster.clone() { @@ -87,11 +91,21 @@ impl Broker { // Start router first and then cluster in the background let router_tx = router.spawn(); // cluster.spawn(); - Broker { config, router_tx } + Broker { + config, + router_tx, + shutdown_rx, + shutdown_tx: Arc::new(tx), + } } None => { let router_tx = router.spawn(); - Broker { config, router_tx } + Broker { + config, + router_tx, + shutdown_rx, + shutdown_tx: Arc::new(tx), + } } } } @@ -155,6 +169,11 @@ impl Broker { Ok((link_tx, link_rx)) } + /// Returns a `ShutdownHandler` that can be used to trigger a shutdown of the broker. + pub fn shutdown_handler(&mut self) -> ShutdownHandler { + ShutdownHandler(Some(self.shutdown_tx.clone())) + } + #[tracing::instrument(skip(self))] pub fn start(&mut self) -> Result<(), Error> { if self.config.v4.is_none() @@ -175,12 +194,13 @@ impl Broker { if let Some(metrics_config) = self.config.metrics.clone() { let timer_thread = thread::Builder::new().name("timer".to_owned()); let router_tx = self.router_tx.clone(); + let shutdown_rx = self.shutdown_rx.clone(); timer_thread.spawn(move || { let mut runtime = tokio::runtime::Builder::new_current_thread(); let runtime = runtime.enable_all().build().unwrap(); runtime.block_on(async move { - timer::start(metrics_config, router_tx).await; + timer::start(metrics_config, router_tx, shutdown_rx).await; }); })?; } @@ -189,12 +209,13 @@ impl Broker { if let Some(bridge_config) = self.config.bridge.clone() { let bridge_thread = thread::Builder::new().name(bridge_config.name.clone()); let router_tx = self.router_tx.clone(); + let shutdown_rx = self.shutdown_rx.clone(); bridge_thread.spawn(move || { let mut runtime = tokio::runtime::Builder::new_current_thread(); let runtime = runtime.enable_all().build().unwrap(); runtime.block_on(async move { - if let Err(e) = bridge::start(bridge_config, router_tx, V4).await { + if let Err(e) = bridge::start(bridge_config, router_tx, V4, shutdown_rx).await { error!(error=?e, "Bridge Link error"); 
}; }); @@ -206,13 +227,16 @@ impl Broker { for (_, config) in v4_config.clone() { let server_thread = thread::Builder::new().name(config.name.clone()); let mut server = Server::new(config, self.router_tx.clone(), V4); + let shutdown_rx = self.shutdown_rx.clone(); let handle = server_thread.spawn(move || { let mut runtime = tokio::runtime::Builder::new_current_thread(); let runtime = runtime.enable_all().build().unwrap(); runtime.block_on(async { - if let Err(e) = server.start(LinkType::Remote).await { + if let Err(e) = server.start(LinkType::Remote, shutdown_rx).await { error!(error=?e, "Server error - V4"); + } else { + debug!("Shutting down v4 server"); } }); })?; @@ -224,13 +248,16 @@ impl Broker { for (_, config) in v5_config.clone() { let server_thread = thread::Builder::new().name(config.name.clone()); let mut server = Server::new(config, self.router_tx.clone(), V5); + let shutdown_rx = self.shutdown_rx.clone(); let handle = server_thread.spawn(move || { let mut runtime = tokio::runtime::Builder::new_current_thread(); let runtime = runtime.enable_all().build().unwrap(); runtime.block_on(async { - if let Err(e) = server.start(LinkType::Remote).await { + if let Err(e) = server.start(LinkType::Remote, shutdown_rx).await { error!(error=?e, "Server error - V5"); + } else { + debug!("Shutting down v5 server"); } }); })?; @@ -249,13 +276,16 @@ impl Broker { let server_thread = thread::Builder::new().name(config.name.clone()); //TODO: Add support for V5 procotol with websockets. 
Registered in config or on ServerSettings let mut server = Server::new(config, self.router_tx.clone(), V4); + let shutdown_rx = self.shutdown_rx.clone(); let handle = server_thread.spawn(move || { let mut runtime = tokio::runtime::Builder::new_current_thread(); let runtime = runtime.enable_all().build().unwrap(); runtime.block_on(async { - if let Err(e) = server.start(LinkType::Websocket).await { + if let Err(e) = server.start(LinkType::Websocket, shutdown_rx).await { error!(error=?e, "Server error - WS"); + } else { + debug!("Shutting down websocket server"); } }); })?; @@ -280,6 +310,7 @@ impl Broker { }; let metrics_thread = thread::Builder::new().name("Metrics".to_owned()); let meter_link = self.meters().unwrap(); + let shutdown_rx = self.shutdown_rx.clone(); metrics_thread.spawn(move || { let builder = PrometheusBuilder::new().with_http_listener(addr); builder.install().unwrap(); @@ -301,6 +332,11 @@ impl Broker { } } + if shutdown_rx.has_changed().is_ok_and(|flag| flag) { + debug!("Shutting down metrics"); + break; + } + std::thread::sleep(Duration::from_secs(timeout)); } })?; @@ -311,10 +347,11 @@ impl Broker { let console_link = Arc::new(console_link); let console_thread = thread::Builder::new().name("Console".to_string()); + let shutdown_rx = self.shutdown_rx.clone(); console_thread.spawn(move || { let mut runtime = tokio::runtime::Builder::new_current_thread(); let runtime = runtime.enable_all().build().unwrap(); - runtime.block_on(console::start(console_link)); + runtime.block_on(console::start(console_link, shutdown_rx)); })?; } @@ -379,7 +416,11 @@ impl Server

{ Ok((Box::new(stream), None)) } - async fn start(&mut self, link_type: LinkType) -> Result<(), Error> { + async fn start( + &mut self, + link_type: LinkType, + mut shutdown_rx: watch::Receiver<()>, + ) -> Result<(), Error> { let listener = TcpListener::bind(&self.config.listen).await?; let delay = Duration::from_millis(self.config.next_connection_delay_ms); let mut count: usize = 0; @@ -392,19 +433,33 @@ impl Server

{ ); loop { // Await new network connection. - let (stream, addr) = match listener.accept().await { - Ok((s, r)) => (s, r), - Err(e) => { - error!(error=?e, "Unable to accept socket."); - continue; + let (stream, addr) = tokio::select! { + accept = listener.accept() => { + match accept { + Ok((s, r)) => (s, r), + Err(e) => { + error!(error=?e, "Unable to accept socket."); + continue; + } + } + } + _ = shutdown_rx.changed() => { + return Ok(()); } }; - let (network, tenant_id) = match self.tls_accept(stream).await { - Ok(o) => o, - Err(e) => { - error!(error=?e, "Tls accept error"); - continue; + let (network, tenant_id) = tokio::select! { + accept = self.tls_accept(stream) => { + match accept { + Ok(o) => o, + Err(e) => { + error!(error=?e, "Tls accept error"); + continue; + } + } + } + _ = shutdown_rx.changed() => { + return Ok(()); } }; @@ -420,11 +475,18 @@ impl Server

{ match link_type { #[cfg(feature = "websocket")] LinkType::Websocket => { - let stream = match accept_hdr_async(network, WSCallback).await { - Ok(s) => Box::new(WsStream::new(s)), - Err(e) => { - error!(error=?e, "Websocket failed handshake"); - continue; + let stream = tokio::select! { + hdr_accept = accept_hdr_async(network, WSCallback) => { + match hdr_accept { + Ok(s) => Box::new(WsStream::new(s)), + Err(e) => { + error!(error=?e, "Websocket failed handshake"); + continue; + } + } + } + _ = shutdown_rx.changed() => { + return Ok(()); } }; task::spawn( @@ -461,7 +523,12 @@ impl Server

{ ), }; - time::sleep(delay).await; + tokio::select! { + _ = time::sleep(delay) => {} + _ = shutdown_rx.changed() => { + return Ok(()); + } + }; } } } @@ -627,3 +694,52 @@ async fn remote( router_tx.send((connection_id, message)).ok(); } } + +/// A struct that holds a shutdown handler for the broker. +/// +/// The `ShutdownHandler` struct is responsible for managing the shutdown of the broker. +/// The `shutdown_broker` method can be used to send the shutdown signal, and the `drop_guard` +/// method can be used to create a `ShutdownDropGuard` that will automatically send the shutdown +/// signal when it is dropped. +#[derive(Debug)] +pub struct ShutdownHandler(Option<Arc<watch::Sender<()>>>); + +impl ShutdownHandler { + /// Sends a shutdown signal to the broker. + /// This method takes ownership of the `ShutdownHandler` and sends a shutdown signal. + pub fn shutdown_broker(mut self) { + if let Some(handler) = self.0.take() { + let _ = handler.send(()); + } + } + + /// Creates a `ShutdownDropGuard` that will automatically send a shutdown signal + /// to the broker when it is dropped. + #[inline] + pub fn drop_guard(mut self) -> ShutdownDropGuard { + ShutdownDropGuard(self.0.take()) + } +} + +/// A struct that holds a shutdown signal for the broker. +/// +/// The `ShutdownDropGuard` struct is responsible for managing the shutdown of the broker. +/// When the `ShutdownDropGuard` is dropped, it will automatically send the shutdown signal. +#[derive(Debug)] +pub struct ShutdownDropGuard(Option<Arc<watch::Sender<()>>>); + +impl ShutdownDropGuard { + #[inline] + /// Disarms the guard and returns the `ShutdownHandler` it wrapped, allowing the caller to manage the shutdown signal. + pub fn disarm(mut self) -> ShutdownHandler { + ShutdownHandler(self.0.take()) + } +} + +impl Drop for ShutdownDropGuard { + fn drop(&mut self) { + if let Some(handler) = self.0.take() { + let _ = handler.send(()); + } + } +}