diff --git a/src/accumulator/proof.rs b/src/accumulator/proof.rs index c0d8836..d5c4f68 100644 --- a/src/accumulator/proof.rs +++ b/src/accumulator/proof.rs @@ -379,51 +379,88 @@ impl Proof { // Nodes must be sorted for finding siblings during hashing nodes.sort(); - let mut i = 0; - while i < nodes.len() { - let (pos1, hash1) = nodes[i]; - let next_to_prove = util::parent(pos1, total_rows); - - // If the current position is a root, we add that to our result and don't go any further - if util::is_root_position(pos1, num_leaves, total_rows) { - calculated_root_hashes.push(hash1); - i += 1; + let mut computed = Vec::with_capacity(nodes.len() * 2); + let mut computed_index = 0; + let mut provided_index = 0; + loop { + let Some(next_vec) = Self::get_next(&computed, &nodes, computed_index, provided_index) + else { + break; + }; + + let (next_pos, next_hash) = match next_vec { + true => { + computed_index += 1; + computed[computed_index - 1] + } + false => { + provided_index += 1; + nodes[provided_index - 1] + } + }; + + if util::is_root_position(next_pos, num_leaves, total_rows) { + calculated_root_hashes.push(next_hash); continue; } - let Some((pos2, hash2)) = nodes.get(i + 1) else { - return Err(format!( - "Proof is too short. Expected at least {} elements, got {}", - i + 1, - nodes.len() - )); + let sibling = next_pos | 1; + let sibling_vec = Self::get_next(&computed, &nodes, computed_index, provided_index) + .ok_or(format!("Missing sibling for {}", next_pos))?; + let (sibling_pos, sibling_hash) = match sibling_vec { + true => { + computed_index += 1; + computed[computed_index - 1] + } + false => { + provided_index += 1; + nodes[provided_index - 1] + } }; - if pos1 != util::left_sibling(*pos2) { - return Err(format!( - "Invalid proof. Expected left sibling of {} to be {}, got {}", - pos2, - util::left_sibling(*pos2), - pos1 - )); + if sibling_pos != sibling { + return Err(format!("Missing sibling for {}", next_pos)); } - let parent_hash = match (hash1.is_empty(), hash2.is_empty()) { + let parent_hash = match (next_hash.is_empty(), sibling_hash.is_empty()) { (true, true) => NodeHash::empty(), - (true, false) => *hash2, - (false, true) => hash1, - (false, false) => NodeHash::parent_hash(&hash1, hash2), + (true, false) => sibling_hash, + (false, true) => next_hash, + (false, false) => NodeHash::parent_hash(&next_hash, &sibling_hash), }; - Self::sorted_push(&mut nodes, (next_to_prove, parent_hash)); - i += 2; + let parent = util::parent(next_pos, total_rows); + computed.push((parent, parent_hash)); } // we shouldn't return the hashes in the proof + nodes.extend(computed); nodes.retain(|(pos, _)| !proof_positions.contains(pos)); - Ok((nodes, calculated_root_hashes)) } + + fn get_next( + computed: &[(u64, NodeHash)], + provided: &[(u64, NodeHash)], + computed_pos: usize, + provided_pos: usize, + ) -> Option { + let last_computed = computed.get(computed_pos); + let last_provided = provided.get(provided_pos); + + match (last_computed, last_provided) { + (Some((pos1, _)), Some((pos2, _))) => { + if pos1 < pos2 { + Some(true) + } else { + Some(false) + } + } + (Some(_), None) => Some(true), + (None, Some(_)) => Some(false), + (None, None) => None, + } + } /// Uses the data passed in to update a proof, creating a valid proof for a given /// set of targets, after an update. This is useful for caching UTXOs. You grab a proof /// for it once and then keep updating it every block, yielding an always valid proof @@ -699,12 +736,6 @@ impl Proof { new_positions.sort(); Ok(new_positions) } - fn sorted_push(nodes: &mut Vec<(u64, NodeHash)>, to_add: (u64, NodeHash)) { - let pos = nodes - .binary_search_by(|(pos, _)| pos.cmp(&to_add.0)) - .unwrap_or_else(|x| x); - nodes.insert(pos, to_add); - } } #[cfg(test)] @@ -854,6 +885,20 @@ mod tests { assert_eq!(cached_hashes, expected_cached_hashes); } } + + #[test] + fn test_get_next() { + use super::Proof; + let computed = vec![(1, NodeHash::empty()), (3, NodeHash::empty())]; + let provided = vec![(2, NodeHash::empty()), (4, NodeHash::empty())]; + + assert_eq!(Proof::get_next(&computed, &provided, 0, 0), Some(true)); + assert_eq!(Proof::get_next(&computed, &provided, 1, 0), Some(false)); + assert_eq!(Proof::get_next(&computed, &provided, 1, 1), Some(true)); + assert_eq!(Proof::get_next(&computed, &provided, 1, 2), Some(true)); + assert_eq!(Proof::get_next(&computed, &provided, 2, 2), None); + } + #[test] fn test_calc_next_positions() { use super::Proof;