From 0381e30565af940e151b135cb99ca234c23f33bb Mon Sep 17 00:00:00 2001 From: keepsimple1 Date: Sat, 5 Oct 2024 13:28:18 -0700 Subject: [PATCH] dns_cache: address record should only flush on the same network (#261) --- src/dns_cache.rs | 29 ++++++++++++--- src/service_daemon.rs | 35 +++++++++--------- src/service_info.rs | 82 ++++++++++++++++++++++++++++++++++++++++++- tests/mdns_test.rs | 10 +++--- 4 files changed, 129 insertions(+), 27 deletions(-) diff --git a/src/dns_cache.rs b/src/dns_cache.rs index 5f9dc05..e2d56fd 100644 --- a/src/dns_cache.rs +++ b/src/dns_cache.rs @@ -2,12 +2,16 @@ //! //! This is an internal implementation, not visible to the public API. -use crate::dns_parser::{ - current_time_millis, split_sub_domain, DnsAddress, DnsPointer, DnsRecordBox, DnsSrv, TYPE_A, - TYPE_AAAA, TYPE_NSEC, TYPE_PTR, TYPE_SRV, TYPE_TXT, -}; #[cfg(feature = "logging")] use crate::log::debug; +use crate::{ + dns_parser::{ + current_time_millis, split_sub_domain, DnsAddress, DnsPointer, DnsRecordBox, DnsSrv, + TYPE_A, TYPE_AAAA, TYPE_NSEC, TYPE_PTR, TYPE_SRV, TYPE_TXT, + }, + service_info::valid_two_addrs_on_intf, +}; +use if_addrs::Interface; use std::{ collections::{HashMap, HashSet}, net::IpAddr, @@ -112,6 +116,7 @@ impl DnsCache { /// If need to add new timers for related records, push into `timers`. pub(crate) fn add_or_update( &mut self, + intf: &Interface, incoming: DnsRecordBox, timers: &mut Vec, ) -> Option<(&DnsRecordBox, bool)> { @@ -153,11 +158,27 @@ impl DnsCache { // Ref: RFC 6762 Section 10.2 // // Note: when the updated record actually expires, it will trigger events properly. + let mut should_flush = false; + if class == r.get_class() && rtype == r.get_type() && now > r.get_created() + 1000 && r.get_expire() > now + 1000 { + should_flush = true; + + // additional checks for address records. + if rtype == TYPE_A || rtype == TYPE_AAAA { + if let Some(addr) = r.any().downcast_ref::() { + if let Some(addr_b) = incoming.any().downcast_ref::() { + should_flush = + valid_two_addrs_on_intf(&addr.address, &addr_b.address, intf); + } + } + } + } + + if should_flush { debug!("FLUSH one record: {:?}", &r); let new_expire = now + 1000; r.set_expire(new_expire); diff --git a/src/service_daemon.rs b/src/service_daemon.rs index 73f8264..9efadcc 100644 --- a/src/service_daemon.rs +++ b/src/service_daemon.rs @@ -945,7 +945,8 @@ impl Zeroconf { let my_ifaddrs = my_ip_interfaces(); // Create a socket for every IP addr. - // Note: it is possible that `my_ifaddrs` contains duplicated IP addrs. + // Note: it is possible that `my_ifaddrs` contains the same IP addr with different interface names, + // or the same interface name with different IP addrs. let mut intf_socks = HashMap::new(); for intf in my_ifaddrs { let sock = match new_socket_bind(&intf) { @@ -1054,7 +1055,8 @@ impl Zeroconf { Self::add_poll_impl(&mut self.poll_ids, &mut self.poll_id_count, intf) } - /// Insert a new interface into the poll map and return key + /// Insert a new interface into the poll map and return its key. + /// /// This exist to satisfy the borrow checker fn add_poll_impl( poll_ids: &mut HashMap, @@ -1382,13 +1384,6 @@ impl Zeroconf { send_dns_outgoing(&out, intf, sock).remove(0) } - /// Binds a channel `listener` to querying mDNS domain type `ty`. - /// - /// If there is already a `listener`, it will be updated, i.e. overwritten. - fn add_service_querier(&mut self, ty: String, listener: Sender) { - self.service_queriers.insert(ty, listener); - } - /// Binds a channel `listener` to querying mDNS hostnames. /// /// If there is already a `listener`, it will be updated, i.e. overwritten. @@ -1495,7 +1490,7 @@ impl Zeroconf { if msg.is_query() { self.handle_query(msg, intf); } else if msg.is_response() { - self.handle_response(msg); + self.handle_response(msg, intf); } else { error!("Invalid message: not query and not response"); } @@ -1532,7 +1527,7 @@ impl Zeroconf { /// Checks if `ty_domain` has records in the cache. If yes, sends the /// cached records via `sender`. - fn query_cache_for_service(&mut self, ty_domain: &str, sender: Sender) { + fn query_cache_for_service(&mut self, ty_domain: &str, sender: &Sender) { let mut resolved: HashSet = HashSet::new(); let mut unresolved: HashSet = HashSet::new(); @@ -1668,7 +1663,7 @@ impl Zeroconf { /// Deal with incoming response packets. All answers /// are held in the cache, and listeners are notified. - fn handle_response(&mut self, mut msg: DnsIncoming) { + fn handle_response(&mut self, mut msg: DnsIncoming, intf: &Interface) { debug!( "handle_response: {} answers {} authorities {} additionals", &msg.answers.len(), @@ -1716,7 +1711,7 @@ impl Zeroconf { let mut changes = Vec::new(); let mut timers = Vec::new(); for record in msg.answers { - match self.cache.add_or_update(record, &mut timers) { + match self.cache.add_or_update(intf, record, &mut timers) { Some((dns_record, true)) => { timers.push(dns_record.get_record().get_expire_time()); timers.push(dns_record.get_record().get_refresh_time()); @@ -1895,7 +1890,7 @@ impl Zeroconf { TYPE_AAAA => "TYPE_AAAA", _ => "invalid_type", }; - error!( + debug!( "Cannot find valid addrs for {} response on intf {:?}", t, &intf ); @@ -2042,8 +2037,8 @@ impl Zeroconf { .collect(); if let Err(e) = listener.send(ServiceEvent::SearchStarted(format!( - "{} on addrs [{}]", - &ty, + "{ty} on {} interfaces [{}]", + pretty_addrs.len(), pretty_addrs.join(", ") ))) { error!( @@ -2053,9 +2048,13 @@ impl Zeroconf { return; } if !repeating { - self.add_service_querier(ty.clone(), listener.clone()); + // Binds a `listener` to querying mDNS domain type `ty`. + // + // If there is already a `listener`, it will be updated, i.e. overwritten. + self.service_queriers.insert(ty.clone(), listener.clone()); + // if we already have the records in our cache, just send them - self.query_cache_for_service(&ty, listener.clone()); + self.query_cache_for_service(&ty, &listener); } self.send_query(&ty, TYPE_PTR); diff --git a/src/service_info.rs b/src/service_info.rs index 139d99d..aacdd46 100644 --- a/src/service_info.rs +++ b/src/service_info.rs @@ -726,9 +726,34 @@ pub fn valid_ip_on_intf(addr: &IpAddr, intf: &Interface) -> bool { } } +/// Returns true if `addr_a` and `addr_b` are in the same network as `intf`. +pub fn valid_two_addrs_on_intf(addr_a: &IpAddr, addr_b: &IpAddr, intf: &Interface) -> bool { + match (addr_a, addr_b, &intf.addr) { + (IpAddr::V4(ipv4_a), IpAddr::V4(ipv4_b), IfAddr::V4(intf)) => { + let netmask = u32::from(intf.netmask); + let intf_net = u32::from(intf.ip) & netmask; + let net_a = u32::from(*ipv4_a) & netmask; + let net_b = u32::from(*ipv4_b) & netmask; + net_a == intf_net && net_b == intf_net + } + (IpAddr::V6(ipv6_a), IpAddr::V6(ipv6_b), IfAddr::V6(intf)) => { + let netmask = u128::from(intf.netmask); + let intf_net = u128::from(intf.ip) & netmask; + let net_a = u128::from(*ipv6_a) & netmask; + let net_b = u128::from(*ipv6_b) & netmask; + net_a == intf_net && net_b == intf_net + } + _ => false, + } +} + #[cfg(test)] mod tests { - use super::{decode_txt, encode_txt, u8_slice_to_hex, ServiceInfo, TxtProperty}; + use super::{ + decode_txt, encode_txt, u8_slice_to_hex, valid_two_addrs_on_intf, ServiceInfo, TxtProperty, + }; + use if_addrs::{IfAddr, Ifv4Addr, Ifv6Addr, Interface}; + use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; #[test] fn test_txt_encode_decode() { @@ -876,4 +901,59 @@ mod tests { // Test that the key of the property we parsed is "key1" assert_eq!(decoded[0].key, "key1"); } + + #[test] + fn test_valid_two_addrs_on_intf() { + // test IPv4 + + let ipv4_netmask = Ipv4Addr::new(192, 168, 1, 0); + let ipv4_intf_addr = IfAddr::V4(Ifv4Addr { + ip: Ipv4Addr::new(192, 168, 1, 10), + netmask: ipv4_netmask, + prefixlen: 24, + broadcast: None, + }); + let ipv4_intf = Interface { + name: "e0".to_string(), + addr: ipv4_intf_addr, + index: Some(1), + #[cfg(windows)] + adapter_name: "ethernet".to_string(), + }; + let ipv4_a = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 10)); + let ipv4_b = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 11)); + + let result = valid_two_addrs_on_intf(&ipv4_a, &ipv4_b, &ipv4_intf); + assert!(result); + + let ipv4_c = IpAddr::V4(Ipv4Addr::new(172, 17, 0, 1)); + let result = valid_two_addrs_on_intf(&ipv4_a, &ipv4_c, &ipv4_intf); + assert!(!result); + + // test IPv6 (generated by AI) + + let ipv6_netmask = Ipv6Addr::new(0xffff, 0xffff, 0, 0, 0, 0, 0, 0); // Equivalent to /32 prefix length + let ipv6_intf_addr = IfAddr::V6(Ifv6Addr { + ip: Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1), + netmask: ipv6_netmask, + prefixlen: 32, + broadcast: None, + }); + let ipv6_intf = Interface { + name: "eth0".to_string(), + addr: ipv6_intf_addr, + index: Some(2), + #[cfg(windows)] + adapter_name: "ethernet".to_string(), + }; + let ipv6_a = IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1)); + let ipv6_b = IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 2)); + + let result = valid_two_addrs_on_intf(&ipv6_a, &ipv6_b, &ipv6_intf); + assert!(result); // Expect true since both addresses are in the same subnet + + let ipv6_c = IpAddr::V6(Ipv6Addr::new(0x2002, 0xdb8, 0, 0, 0, 0, 0, 1)); + let result = valid_two_addrs_on_intf(&ipv6_a, &ipv6_c, &ipv6_intf); + assert!(!result); // Expect false since addresses are in different subnets + } } diff --git a/tests/mdns_test.rs b/tests/mdns_test.rs index 7986650..41ed2dc 100644 --- a/tests/mdns_test.rs +++ b/tests/mdns_test.rs @@ -36,7 +36,9 @@ fn integration_success() { // use the same approach as `IntfSock.multicast_send_tracker` if let Some(idx) = intf.index { - unique_intf_idx_ip_ver_set.insert((idx, ip_ver)); + if !unique_intf_idx_ip_ver_set.insert((idx, ip_ver)) { + println!("index {idx} IP v{ip_ver} repeated on interface {}, likely multi-addr on the same interface", intf.name); + } } else { non_idx_count += 1; } @@ -973,7 +975,7 @@ fn my_ip_interfaces() -> Vec { match std::net::UdpSocket::bind((ifv4.ip, test_port)) { Ok(_) => Some(i), Err(e) => { - println!("bind {}: {}, skipped.", ifv4.ip, e); + println!("failed to bind {}: {e}, skipped.", ifv4.ip); None } } @@ -990,7 +992,7 @@ fn my_ip_interfaces() -> Vec { match std::net::UdpSocket::bind(sock) { Ok(_) => Some(i), Err(e) => { - println!("bind {}: {}, skipped.", ifv6.ip, e); + println!("failed to bind {}: {e}, skipped.", ifv6.ip); None } } @@ -1211,7 +1213,7 @@ fn test_cache_flush_record() { // Stop browsing for a moment. client.stop_browse(service).unwrap(); - sleep(Duration::from_secs(1)); + sleep(Duration::from_secs(2)); // Let the cache record be surely older than 1 second. // Modify the IPv4 address for the service. if let IpAddr::V4(ipv4) = service_ip_addr {