Skip to content

Commit

Permalink
Pol: ask for hashes instead of positions + tests
Browse files Browse the repository at this point in the history
This commit makes all proving-related methods take leaf hashes instead
of positions, and adds additional tests to proving and ingesting proofs
  • Loading branch information
Davidson-Souza committed Jan 29, 2025
1 parent 76c3bf3 commit 35602af
Showing 1 changed file with 112 additions and 28 deletions.
140 changes: 112 additions & 28 deletions src/accumulator/pollard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,11 @@ impl<Hash: AccumulatorHash> Pollard<Hash> {
self.do_ingest_proof(proof, del_hashes, remembers, false)
}

pub fn verify(&self, proof: &Proof<Hash>, del_hashes: &[Hash]) -> Result<bool, String> {
let roots = self.roots();
proof.verify(del_hashes, &roots, self.leaves)
}

pub fn verify_and_ingest(
&mut self,
proof: Proof<Hash>,
Expand All @@ -451,6 +456,8 @@ impl<Hash: AccumulatorHash> Pollard<Hash> {
}

pub fn prune(&mut self, positions: &[u64]) -> Result<(), &'static str> {
self.prune_map(positions);

let positions = detwin(positions.to_vec(), tree_rows(self.leaves));
let nodes = positions
.into_iter()
Expand Down Expand Up @@ -481,21 +488,21 @@ impl<Hash: AccumulatorHash> Pollard<Hash> {
///
/// This function takes a list of positions and returns a list of proofs for each position.
pub fn batch_proof(&self, targets: &[Hash]) -> Result<Proof<Hash>, String> {
let mut positions = Vec::new();
let mut target_positions = Vec::new();
for target in targets {
let node = self
.leaf_map
.get(target)
.ok_or(format!("leaf {target} not found in the forest"))?;
let position = self.get_pos(node)?;
positions.push(position);
target_positions.push(position);
}

let targets = detwin(positions, tree_rows(self.leaves));
let positions = get_proof_positions(&targets, self.leaves, tree_rows(self.leaves));
let proof_positions =
get_proof_positions(&target_positions, self.leaves, tree_rows(self.leaves));
let mut proof_hashes = Vec::new();

for pos in positions.iter() {
for pos in proof_positions.iter() {
let hash = self
.grab_position(*pos)
.ok_or("Position not found")?
Expand All @@ -507,7 +514,7 @@ impl<Hash: AccumulatorHash> Pollard<Hash> {

Ok(Proof::<Hash> {
hashes: proof_hashes,
targets: positions,
targets: target_positions,
})
}

Expand Down Expand Up @@ -581,6 +588,13 @@ type AddSingleResult<T> = (Vec<(u64, T)>, Vec<usize>);
type ChildrenTuple<Hash> = (Rc<PollardNode<Hash>>, Rc<PollardNode<Hash>>);

impl<Hash: AccumulatorHash> Pollard<Hash> {
fn prune_map(&mut self, positions: &[u64]) {
for pos in positions {
let node = self.grab_position(*pos).unwrap().0;
self.leaf_map.remove(&node.hash());
}
}

fn grab_position(&self, pos: u64) -> Option<ChildrenTuple<Hash>> {
let (root, depth, bits) = Self::detect_offset(pos, self.leaves);
let mut node = self.roots[root as usize].clone()?;
Expand Down Expand Up @@ -608,6 +622,7 @@ impl<Hash: AccumulatorHash> Pollard<Hash> {
fn ingest_positions(
&mut self,
mut iter: impl Iterator<Item = (u64, Hash)>,
remembers: &[u64],
) -> Result<(), String> {
let forest_rows = tree_rows(self.leaves);
while let Some((pos1, hash1)) = iter.next() {
Expand Down Expand Up @@ -638,6 +653,11 @@ impl<Hash: AccumulatorHash> Pollard<Hash> {
new_node.set_aunt(Rc::downgrade(&aunt));
new_sibling.set_aunt(Rc::downgrade(&aunt));

if remembers.contains(&pos1) || remembers.contains(&pos2) {
self.leaf_map.insert(hash1, Rc::downgrade(&new_node));
self.leaf_map.insert(hash2, Rc::downgrade(&new_sibling));
}

aunt.set_niece(Some(new_sibling), Some(new_node));
}

Expand All @@ -658,7 +678,7 @@ impl<Hash: AccumulatorHash> Pollard<Hash> {
all_nodes.extend(proof_positions.into_iter().zip(proof.hashes.clone()));
all_nodes.sort();
let iter = all_nodes.into_iter().rev();
self.ingest_positions(iter)?;
self.ingest_positions(iter, remembers)?;

let pruned = proof
.targets
Expand All @@ -667,7 +687,6 @@ impl<Hash: AccumulatorHash> Pollard<Hash> {
.copied()
.collect::<Vec<_>>();

self.map_targets(remembers)?;
self.prune(&pruned)?;

if recompute {
Expand All @@ -679,17 +698,6 @@ impl<Hash: AccumulatorHash> Pollard<Hash> {
Ok(())
}

fn map_targets(&mut self, targets: &[u64]) -> Result<(), String> {
for target in targets {
let node = self
.grab_position(*target)
.ok_or(format!("Position {target} not found"))?;
self.leaf_map.insert(node.0.hash(), Rc::downgrade(&node.0));
}

Ok(())
}

fn detect_offset(pos: u64, num_leaves: u64) -> (u8, u8, u64) {
let mut tr = tree_rows(num_leaves);
let nr = detect_row(pos, tr);
Expand Down Expand Up @@ -1156,6 +1164,88 @@ mod tests {
assert_eq!(root.unwrap().hash(), hashes[0].hash);
}

#[test]
fn test_ingest_proof_and_prove() {
// this test will create a forest, prove a few leaves, prune all leaves, ingest the proof
// and prove the same leaves + siblings again
let values = vec![0, 1, 2, 3, 4, 5, 6, 7];
let hashes: Vec<_> = values
.into_iter()
.map(|preimage| {
let hash = hash_from_u8(preimage);
PollardAddition {
hash,
remember: true,
}
})
.collect();

let mut acc = Pollard::<BitcoinNodeHash>::new();
acc.modify(&hashes, &[], Proof::default()).unwrap();

let del_hashes = [
hash_from_u8(2),
hash_from_u8(1),
hash_from_u8(4),
hash_from_u8(6),
];
let proof = acc.batch_proof(&del_hashes).unwrap();

acc.prune(&[0, 1, 2, 3, 4, 5, 6, 7]).unwrap();
acc.ingest_proof(proof, &del_hashes, &[2, 1, 4, 6]).unwrap();

let del_hashes = [0, 1, 4, 5, 6, 7]
.iter()
.map(|x| hash_from_u8(*x))
.collect::<Vec<_>>();
let proof = acc.batch_proof(&del_hashes).unwrap();
assert!(acc.verify(&proof, &del_hashes).unwrap());
}
#[test]
fn test_prove() {
let values = vec![0, 1, 2, 3, 4, 5, 6, 7];
let hashes: Vec<_> = values
.into_iter()
.map(|preimage| {
let hash = hash_from_u8(preimage);
PollardAddition {
hash,
remember: true,
}
})
.collect();

let mut acc = Pollard::<BitcoinNodeHash>::new();
acc.modify(&hashes, &[], Proof::default()).unwrap();
let del_hashes = [
hash_from_u8(2),
hash_from_u8(1),
hash_from_u8(4),
hash_from_u8(6),
];
let proof = acc.batch_proof(&del_hashes).unwrap();
let expected_proof = Proof::new(
[2, 1, 4, 6].to_vec(),
vec![
"6e340b9cffb37a989ca544e6bb780a2c78901d3fb33738768511a30617afa01d"
.parse()
.unwrap(),
"084fed08b978af4d7d196a7446a86b58009e636b611db16211b65a9aadff29c5"
.parse()
.unwrap(),
"e77b9a9ae9e30b0dbdb6f510a264ef9de781501d7b6b92ae89eb059c5ab743db"
.parse()
.unwrap(),
"ca358758f6d27e6cf45272937977a748fd88391db679ceda7dc7bf1f005ee879"
.parse()
.unwrap(),
],
);

assert_eq!(proof, expected_proof);
assert!(acc.verify(&proof, &del_hashes).unwrap());
}

#[test]
fn test_prove_single() {
let values = vec![0, 1, 2, 3, 4, 5];
Expand Down Expand Up @@ -1232,20 +1322,14 @@ mod tests {
test_get_pos!(p, 10);
test_get_pos!(p, 11);
test_get_pos!(p, 12);

let root = p.roots[3].as_ref().unwrap();
let left = root.left_niece().unwrap();
let right = root.right_niece().unwrap();

assert_eq!(p.get_pos(&Rc::downgrade(&root)), Ok(28));
assert_eq!(
p.get_pos(&Rc::downgrade(&left)),
Ok(24)
);
assert_eq!(
p.get_pos(&Rc::downgrade(&right)),
Ok(25)
);
assert_eq!(p.get_pos(&Rc::downgrade(&left)), Ok(24));
assert_eq!(p.get_pos(&Rc::downgrade(&right)), Ok(25));
}

#[test]
Expand Down

0 comments on commit 35602af

Please sign in to comment.