Skip to content

Commit

Permalink
dns_cache: address record should only flush on the same network (#261)
Browse files Browse the repository at this point in the history
  • Loading branch information
keepsimple1 authored Oct 5, 2024
1 parent d5e9d9c commit 0381e30
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 27 deletions.
29 changes: 25 additions & 4 deletions src/dns_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@
//!
//! This is an internal implementation, not visible to the public API.
use crate::dns_parser::{
current_time_millis, split_sub_domain, DnsAddress, DnsPointer, DnsRecordBox, DnsSrv, TYPE_A,
TYPE_AAAA, TYPE_NSEC, TYPE_PTR, TYPE_SRV, TYPE_TXT,
};
#[cfg(feature = "logging")]
use crate::log::debug;
use crate::{
dns_parser::{
current_time_millis, split_sub_domain, DnsAddress, DnsPointer, DnsRecordBox, DnsSrv,
TYPE_A, TYPE_AAAA, TYPE_NSEC, TYPE_PTR, TYPE_SRV, TYPE_TXT,
},
service_info::valid_two_addrs_on_intf,
};
use if_addrs::Interface;
use std::{
collections::{HashMap, HashSet},
net::IpAddr,
Expand Down Expand Up @@ -112,6 +116,7 @@ impl DnsCache {
/// If need to add new timers for related records, push into `timers`.
pub(crate) fn add_or_update(
&mut self,
intf: &Interface,
incoming: DnsRecordBox,
timers: &mut Vec<u64>,
) -> Option<(&DnsRecordBox, bool)> {
Expand Down Expand Up @@ -153,11 +158,27 @@ impl DnsCache {
// Ref: RFC 6762 Section 10.2
//
// Note: when the updated record actually expires, it will trigger events properly.
let mut should_flush = false;

if class == r.get_class()
&& rtype == r.get_type()
&& now > r.get_created() + 1000
&& r.get_expire() > now + 1000
{
should_flush = true;

// additional checks for address records.
if rtype == TYPE_A || rtype == TYPE_AAAA {
if let Some(addr) = r.any().downcast_ref::<DnsAddress>() {
if let Some(addr_b) = incoming.any().downcast_ref::<DnsAddress>() {
should_flush =
valid_two_addrs_on_intf(&addr.address, &addr_b.address, intf);
}
}
}
}

if should_flush {
debug!("FLUSH one record: {:?}", &r);
let new_expire = now + 1000;
r.set_expire(new_expire);
Expand Down
35 changes: 17 additions & 18 deletions src/service_daemon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -945,7 +945,8 @@ impl Zeroconf {
let my_ifaddrs = my_ip_interfaces();

// Create a socket for every IP addr.
// Note: it is possible that `my_ifaddrs` contains duplicated IP addrs.
// Note: it is possible that `my_ifaddrs` contains the same IP addr with different interface names,
// or the same interface name with different IP addrs.
let mut intf_socks = HashMap::new();
for intf in my_ifaddrs {
let sock = match new_socket_bind(&intf) {
Expand Down Expand Up @@ -1054,7 +1055,8 @@ impl Zeroconf {
Self::add_poll_impl(&mut self.poll_ids, &mut self.poll_id_count, intf)
}

/// Insert a new interface into the poll map and return key
/// Insert a new interface into the poll map and return its key.
///
/// This exist to satisfy the borrow checker
fn add_poll_impl(
poll_ids: &mut HashMap<usize, Interface>,
Expand Down Expand Up @@ -1382,13 +1384,6 @@ impl Zeroconf {
send_dns_outgoing(&out, intf, sock).remove(0)
}

/// Binds a channel `listener` to querying mDNS domain type `ty`.
///
/// If there is already a `listener`, it will be updated, i.e. overwritten.
fn add_service_querier(&mut self, ty: String, listener: Sender<ServiceEvent>) {
self.service_queriers.insert(ty, listener);
}

/// Binds a channel `listener` to querying mDNS hostnames.
///
/// If there is already a `listener`, it will be updated, i.e. overwritten.
Expand Down Expand Up @@ -1495,7 +1490,7 @@ impl Zeroconf {
if msg.is_query() {
self.handle_query(msg, intf);
} else if msg.is_response() {
self.handle_response(msg);
self.handle_response(msg, intf);
} else {
error!("Invalid message: not query and not response");
}
Expand Down Expand Up @@ -1532,7 +1527,7 @@ impl Zeroconf {

/// Checks if `ty_domain` has records in the cache. If yes, sends the
/// cached records via `sender`.
fn query_cache_for_service(&mut self, ty_domain: &str, sender: Sender<ServiceEvent>) {
fn query_cache_for_service(&mut self, ty_domain: &str, sender: &Sender<ServiceEvent>) {
let mut resolved: HashSet<String> = HashSet::new();
let mut unresolved: HashSet<String> = HashSet::new();

Expand Down Expand Up @@ -1668,7 +1663,7 @@ impl Zeroconf {

/// Deal with incoming response packets. All answers
/// are held in the cache, and listeners are notified.
fn handle_response(&mut self, mut msg: DnsIncoming) {
fn handle_response(&mut self, mut msg: DnsIncoming, intf: &Interface) {
debug!(
"handle_response: {} answers {} authorities {} additionals",
&msg.answers.len(),
Expand Down Expand Up @@ -1716,7 +1711,7 @@ impl Zeroconf {
let mut changes = Vec::new();
let mut timers = Vec::new();
for record in msg.answers {
match self.cache.add_or_update(record, &mut timers) {
match self.cache.add_or_update(intf, record, &mut timers) {
Some((dns_record, true)) => {
timers.push(dns_record.get_record().get_expire_time());
timers.push(dns_record.get_record().get_refresh_time());
Expand Down Expand Up @@ -1895,7 +1890,7 @@ impl Zeroconf {
TYPE_AAAA => "TYPE_AAAA",
_ => "invalid_type",
};
error!(
debug!(
"Cannot find valid addrs for {} response on intf {:?}",
t, &intf
);
Expand Down Expand Up @@ -2042,8 +2037,8 @@ impl Zeroconf {
.collect();

if let Err(e) = listener.send(ServiceEvent::SearchStarted(format!(
"{} on addrs [{}]",
&ty,
"{ty} on {} interfaces [{}]",
pretty_addrs.len(),
pretty_addrs.join(", ")
))) {
error!(
Expand All @@ -2053,9 +2048,13 @@ impl Zeroconf {
return;
}
if !repeating {
self.add_service_querier(ty.clone(), listener.clone());
// Binds a `listener` to querying mDNS domain type `ty`.
//
// If there is already a `listener`, it will be updated, i.e. overwritten.
self.service_queriers.insert(ty.clone(), listener.clone());

// if we already have the records in our cache, just send them
self.query_cache_for_service(&ty, listener.clone());
self.query_cache_for_service(&ty, &listener);
}

self.send_query(&ty, TYPE_PTR);
Expand Down
82 changes: 81 additions & 1 deletion src/service_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -726,9 +726,34 @@ pub fn valid_ip_on_intf(addr: &IpAddr, intf: &Interface) -> bool {
}
}

/// Returns true if `addr_a` and `addr_b` are in the same network as `intf`.
pub fn valid_two_addrs_on_intf(addr_a: &IpAddr, addr_b: &IpAddr, intf: &Interface) -> bool {
match (addr_a, addr_b, &intf.addr) {
(IpAddr::V4(ipv4_a), IpAddr::V4(ipv4_b), IfAddr::V4(intf)) => {
let netmask = u32::from(intf.netmask);
let intf_net = u32::from(intf.ip) & netmask;
let net_a = u32::from(*ipv4_a) & netmask;
let net_b = u32::from(*ipv4_b) & netmask;
net_a == intf_net && net_b == intf_net
}
(IpAddr::V6(ipv6_a), IpAddr::V6(ipv6_b), IfAddr::V6(intf)) => {
let netmask = u128::from(intf.netmask);
let intf_net = u128::from(intf.ip) & netmask;
let net_a = u128::from(*ipv6_a) & netmask;
let net_b = u128::from(*ipv6_b) & netmask;
net_a == intf_net && net_b == intf_net
}
_ => false,
}
}

#[cfg(test)]
mod tests {
use super::{decode_txt, encode_txt, u8_slice_to_hex, ServiceInfo, TxtProperty};
use super::{
decode_txt, encode_txt, u8_slice_to_hex, valid_two_addrs_on_intf, ServiceInfo, TxtProperty,
};
use if_addrs::{IfAddr, Ifv4Addr, Ifv6Addr, Interface};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

#[test]
fn test_txt_encode_decode() {
Expand Down Expand Up @@ -876,4 +901,59 @@ mod tests {
// Test that the key of the property we parsed is "key1"
assert_eq!(decoded[0].key, "key1");
}

#[test]
fn test_valid_two_addrs_on_intf() {
// test IPv4

let ipv4_netmask = Ipv4Addr::new(192, 168, 1, 0);
let ipv4_intf_addr = IfAddr::V4(Ifv4Addr {
ip: Ipv4Addr::new(192, 168, 1, 10),
netmask: ipv4_netmask,
prefixlen: 24,
broadcast: None,
});
let ipv4_intf = Interface {
name: "e0".to_string(),
addr: ipv4_intf_addr,
index: Some(1),
#[cfg(windows)]
adapter_name: "ethernet".to_string(),
};
let ipv4_a = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 10));
let ipv4_b = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 11));

let result = valid_two_addrs_on_intf(&ipv4_a, &ipv4_b, &ipv4_intf);
assert!(result);

let ipv4_c = IpAddr::V4(Ipv4Addr::new(172, 17, 0, 1));
let result = valid_two_addrs_on_intf(&ipv4_a, &ipv4_c, &ipv4_intf);
assert!(!result);

// test IPv6 (generated by AI)

let ipv6_netmask = Ipv6Addr::new(0xffff, 0xffff, 0, 0, 0, 0, 0, 0); // Equivalent to /32 prefix length
let ipv6_intf_addr = IfAddr::V6(Ifv6Addr {
ip: Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1),
netmask: ipv6_netmask,
prefixlen: 32,
broadcast: None,
});
let ipv6_intf = Interface {
name: "eth0".to_string(),
addr: ipv6_intf_addr,
index: Some(2),
#[cfg(windows)]
adapter_name: "ethernet".to_string(),
};
let ipv6_a = IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1));
let ipv6_b = IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 2));

let result = valid_two_addrs_on_intf(&ipv6_a, &ipv6_b, &ipv6_intf);
assert!(result); // Expect true since both addresses are in the same subnet

let ipv6_c = IpAddr::V6(Ipv6Addr::new(0x2002, 0xdb8, 0, 0, 0, 0, 0, 1));
let result = valid_two_addrs_on_intf(&ipv6_a, &ipv6_c, &ipv6_intf);
assert!(!result); // Expect false since addresses are in different subnets
}
}
10 changes: 6 additions & 4 deletions tests/mdns_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ fn integration_success() {

// use the same approach as `IntfSock.multicast_send_tracker`
if let Some(idx) = intf.index {
unique_intf_idx_ip_ver_set.insert((idx, ip_ver));
if !unique_intf_idx_ip_ver_set.insert((idx, ip_ver)) {
println!("index {idx} IP v{ip_ver} repeated on interface {}, likely multi-addr on the same interface", intf.name);
}
} else {
non_idx_count += 1;
}
Expand Down Expand Up @@ -973,7 +975,7 @@ fn my_ip_interfaces() -> Vec<Interface> {
match std::net::UdpSocket::bind((ifv4.ip, test_port)) {
Ok(_) => Some(i),
Err(e) => {
println!("bind {}: {}, skipped.", ifv4.ip, e);
println!("failed to bind {}: {e}, skipped.", ifv4.ip);
None
}
}
Expand All @@ -990,7 +992,7 @@ fn my_ip_interfaces() -> Vec<Interface> {
match std::net::UdpSocket::bind(sock) {
Ok(_) => Some(i),
Err(e) => {
println!("bind {}: {}, skipped.", ifv6.ip, e);
println!("failed to bind {}: {e}, skipped.", ifv6.ip);
None
}
}
Expand Down Expand Up @@ -1211,7 +1213,7 @@ fn test_cache_flush_record() {

// Stop browsing for a moment.
client.stop_browse(service).unwrap();
sleep(Duration::from_secs(1));
sleep(Duration::from_secs(2)); // Let the cache record be surely older than 1 second.

// Modify the IPv4 address for the service.
if let IpAddr::V4(ipv4) = service_ip_addr {
Expand Down

0 comments on commit 0381e30

Please sign in to comment.