diff --git a/src/experimental/zipper_algebra.rs b/src/experimental/zipper_algebra.rs index a69630b..96b8dc9 100644 --- a/src/experimental/zipper_algebra.rs +++ b/src/experimental/zipper_algebra.rs @@ -55,13 +55,13 @@ pub use zipper_algebra_poly::ZipperMergeF; /// - [`zipper_meet`] /// - [`zipper_subtract`] pub trait ZipperAlgebraExt: - ZipperInfallibleSubtries + ZipperMoving + Sized + ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving + Sized { #[inline] fn join(&mut self, rhs: &mut ZR, out: &mut Out) where V: Lattice, - ZR: ZipperInfallibleSubtries + ZipperMoving, + ZR: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, Out: ZipperWriting, { zipper_join(self, rhs, out); @@ -71,7 +71,7 @@ pub trait ZipperAlgebraExt: fn meet(&mut self, rhs: &mut ZR, out: &mut Out) where V: Lattice, - ZR: ZipperInfallibleSubtries + ZipperMoving, + ZR: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, Out: ZipperWriting, { zipper_meet(self, rhs, out); @@ -81,7 +81,7 @@ pub trait ZipperAlgebraExt: fn subtract(&mut self, rhs: &mut ZR, out: &mut Out) where V: DistributiveLattice, - ZR: ZipperInfallibleSubtries + ZipperMoving, + ZR: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, Out: ZipperWriting, { zipper_subtract(self, rhs, out); @@ -139,8 +139,8 @@ pub fn zipper_join(lhs: &mut ZL, rhs: &mut ZR, out: &mut Out) where V: Lattice + Clone + Send + Sync, A: Allocator, - ZL: ZipperInfallibleSubtries + ZipperMoving, - ZR: ZipperInfallibleSubtries + ZipperMoving, + ZL: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, + ZR: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, Out: ZipperWriting, { zipper_merge::(lhs, rhs, out); @@ -157,9 +157,9 @@ pub fn zipper_join3(lhs: &mut ZL, mid: &mut ZM, rhs: &mut where V: Lattice + Clone + Send + Sync, A: Allocator, - ZL: ZipperInfallibleSubtries + ZipperMoving, - ZM: ZipperInfallibleSubtries + ZipperMoving, - ZR: ZipperInfallibleSubtries + ZipperMoving, + ZL: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, + ZM: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, + ZR: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, Out: ZipperWriting, { zipper_merge3::(lhs, mid, rhs, out); @@ -209,8 +209,8 @@ pub fn zipper_meet(lhs: &mut ZL, rhs: &mut ZR, out: &mut Out) where V: Lattice + Clone + Send + Sync, A: Allocator, - ZL: ZipperInfallibleSubtries + ZipperMoving, - ZR: ZipperInfallibleSubtries + ZipperMoving, + ZL: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, + ZR: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, Out: ZipperWriting, { zipper_merge::(lhs, rhs, out); @@ -227,9 +227,9 @@ pub fn zipper_meet3(lhs: &mut ZL, mid: &mut ZM, rhs: &mut where V: Lattice + Clone + Send + Sync, A: Allocator, - ZL: ZipperInfallibleSubtries + ZipperMoving, - ZM: ZipperInfallibleSubtries + ZipperMoving, - ZR: ZipperInfallibleSubtries + ZipperMoving, + ZL: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, + ZM: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, + ZR: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, Out: ZipperWriting, { zipper_merge3::(lhs, mid, rhs, out); @@ -285,8 +285,8 @@ pub fn zipper_subtract(lhs: &mut ZL, rhs: &mut ZR, out: &mut where V: DistributiveLattice + Clone + Send + Sync, A: Allocator, - ZL: ZipperInfallibleSubtries + ZipperMoving, - ZR: ZipperInfallibleSubtries + ZipperMoving, + ZL: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, + ZR: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, Out: ZipperWriting, { zipper_merge::(lhs, rhs, out); @@ -307,9 +307,9 @@ pub fn zipper_subtract3( ) where V: DistributiveLattice + Clone + Send + Sync, A: Allocator, - ZL: ZipperInfallibleSubtries + ZipperMoving, - ZM: ZipperInfallibleSubtries + ZipperMoving, - ZR: ZipperInfallibleSubtries + ZipperMoving, + ZL: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, + ZM: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, + ZR: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, Out: ZipperWriting, { zipper_merge3::(lhs, mid, rhs, out); @@ -343,6 +343,11 @@ trait MergePolicy { Out: ZipperWriting; fn descend_on_some_equal(mask: u64) -> bool; + fn on_id(z: &mut Z, out: &mut Out) + where + A: Allocator, + Z: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, + Out: ZipperWriting; } trait ValuePolicy { @@ -373,10 +378,24 @@ where V: Clone + Send + Sync, P: MergePolicy + ValuePolicy, A: Allocator, - ZL: ZipperInfallibleSubtries + ZipperMoving, - ZR: ZipperInfallibleSubtries + ZipperMoving, + ZL: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, + ZR: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, Out: ZipperWriting, { + fn check_sharing(lhs: &ZL, rhs: &ZR) -> bool + where + ZL: ZipperConcrete, + ZR: ZipperConcrete, + { + lhs.shared_node_id() + .is_some_and(|lsnid| rhs.shared_node_id().is_some_and(|rsnid| lsnid == rsnid)) + } + // check for node-sharing first + if check_sharing(lhs, rhs) { + P::on_id(lhs, out); + return; + } + // merge root values before descending if let Some(v) = P::combine(lhs.val(), rhs.val()) { out.set_val(v); @@ -416,6 +435,20 @@ where lhs.descend_to_byte(lhs_byte); rhs.descend_to_byte(lhs_byte); + // optimization - if both zippers share the node after descend, we can skip + // further descend and continue merging + if check_sharing(lhs, rhs) { + P::on_id(lhs, out); + + rhs.ascend_byte(); + rhs_next = rhs_mask.next_bit(lhs_byte); + lhs.ascend_byte(); + lhs_next = lhs_mask.next_bit(lhs_byte); + out.ascend_byte(); + + continue 'merge_level; + } + if let Some(v) = P::combine(lhs.val(), rhs.val()) { out.set_val(v); } @@ -482,9 +515,9 @@ where V: Clone + Send + Sync, P: MergePolicy + ValuePolicy, A: Allocator, - ZL: ZipperInfallibleSubtries + ZipperMoving, - ZM: ZipperInfallibleSubtries + ZipperMoving, - ZR: ZipperInfallibleSubtries + ZipperMoving, + ZL: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, + ZM: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, + ZR: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, Out: ZipperWriting, { const L: u8 = 0b001; @@ -500,8 +533,8 @@ where V: Clone + Send + Sync, P: MergePolicy + ValuePolicy, A: Allocator, - ZL: ZipperInfallibleSubtries + ZipperMoving, - ZR: ZipperInfallibleSubtries + ZipperMoving, + ZL: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, + ZR: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, Out: ZipperWriting, { out.descend_to_byte(b); @@ -515,6 +548,25 @@ where out.ascend_byte(); } + fn all_share(lhs: &ZL, mid: &ZM, rhs: &ZR) -> bool + where + ZL: ZipperConcrete, + ZM: ZipperConcrete, + ZR: ZipperConcrete, + { + lhs.shared_node_id().is_some_and(|lsnid| { + mid.shared_node_id().is_some_and(|msnid| { + lsnid == msnid && rhs.shared_node_id().is_some_and(|rsnid| msnid == rsnid) + }) + }) + } + + // check for node-sharing first + if all_share(lhs, mid, rhs) { + P::on_id(lhs, out); + return; + } + // merge root values before descending if let Some(v) = P::combine3(lhs.val(), mid.val(), rhs.val()) { out.set_val(v); @@ -611,6 +663,21 @@ where mid.descend_to_byte(min); rhs.descend_to_byte(min); + //structural sharing check + if all_share(lhs, mid, rhs) { + P::on_id(lhs, out); + + rhs.ascend_byte(); + r = rhs_mask.next_bit(min); + mid.ascend_byte(); + m = mid_mask.next_bit(min); + lhs.ascend_byte(); + l = lhs_mask.next_bit(min); + out.ascend_byte(); + + continue 'merge_level; + } + if let Some(val) = P::combine3(lhs.val(), mid.val(), rhs.val()) { out.set_val(val); } @@ -672,12 +739,35 @@ fn zipper_merge4( V: Clone + Send + Sync, P: MergePolicy + ValuePolicy, A: Allocator, - Z0: ZipperInfallibleSubtries + ZipperMoving, - Z1: ZipperInfallibleSubtries + ZipperMoving, - Z2: ZipperInfallibleSubtries + ZipperMoving, - Z3: ZipperInfallibleSubtries + ZipperMoving, + Z0: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, + Z1: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, + Z2: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, + Z3: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, Out: ZipperWriting, { + fn all_share(z0: &Z0, z1: &Z1, z2: &Z2, z3: &Z3) -> bool + where + Z0: ZipperConcrete, + Z1: ZipperConcrete, + Z2: ZipperConcrete, + Z3: ZipperConcrete, + { + z0.shared_node_id().is_some_and(|snid0| { + z1.shared_node_id().is_some_and(|snid1| { + snid0 == snid1 + && z2.shared_node_id().is_some_and(|snid2| { + snid1 == snid2 && z3.shared_node_id().is_some_and(|snid3| snid2 == snid3) + }) + }) + }) + } + + // check for node-sharing first + if all_share(z0, z1, z2, z3) { + P::on_id(z0, out); + return; + } + // merge root values before descending if let Some(v) = P::combine4(z0.val(), z1.val(), z2.val(), z3.val()) { out.set_val(v); @@ -731,6 +821,23 @@ fn zipper_merge4( z2.descend_to_byte(min); z3.descend_to_byte(min); + // check structural sharing + if all_share(z0, z1, z2, z3) { + P::on_id(z0, out); + + z3.ascend_byte(); + b3 = m3.next_bit(min); + z2.ascend_byte(); + b2 = m2.next_bit(min); + z1.ascend_byte(); + b1 = m1.next_bit(min); + z0.ascend_byte(); + b0 = m0.next_bit(min); + out.ascend_byte(); + + continue 'merge_level; + } + if let Some(v) = P::combine4(z0.val(), z1.val(), z2.val(), z3.val()) { out.set_val(v); } @@ -958,7 +1065,7 @@ fn zipper_merge4( pub fn zipper_n_join(zs: &mut [Z; N], out: &mut Out) where V: Lattice + Clone + Send + Sync + Unpin, - Z: ZipperInfallibleSubtries + ZipperMoving, + Z: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, Out: ZipperWriting, A: Allocator, { @@ -994,7 +1101,7 @@ where pub fn zipper_n_meet(zs: &mut [Z; N], out: &mut Out) where V: Lattice + Clone + Send + Sync + Unpin, - Z: ZipperInfallibleSubtries + ZipperMoving, + Z: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, Out: ZipperWriting, A: Allocator, { @@ -1037,7 +1144,7 @@ where pub fn zipper_n_subtract(zs: &mut [Z; N], out: &mut Out) where V: DistributiveLattice + Clone + Send + Sync + Unpin, - Z: ZipperInfallibleSubtries + ZipperMoving, + Z: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, Out: ZipperWriting, A: Allocator, { @@ -1054,7 +1161,7 @@ where V: Clone + Send + Sync + Unpin, P: MergePolicy + ValuePolicy, A: Allocator, - Z: ZipperInfallibleSubtries + ZipperMoving, + Z: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, Out: ZipperWriting, { debug_assert!(N > 0 && N <= 64); @@ -1068,10 +1175,10 @@ where } fn only_active<'a, T, const N: usize>( - zs: &'a [T; N], + ts: &'a [T; N], active: u64, ) -> impl Iterator { - active_bits::(active).map(|i| (i, &zs[i])) + active_bits::(active).map(|i| (i, &ts[i])) } fn values<'a, V, Z, const N: usize>( @@ -1085,7 +1192,25 @@ where only_active(zs, active).map(|(_, z)| z.val()) } + fn all_active_share(zs: &[Z; N], active: u64) -> bool + where + Z: ZipperConcrete, + { + let mut iter = only_active(zs, active).map(|(_, z)| z.shared_node_id()); + match iter.next() { + Some(Some(first)) => iter.all(|next| next.is_some_and(|snid| snid == first)), + _ => false, + } + } + // small micro-helpers + #[inline(always)] + fn first_active(ts: &mut [T; N], active: u64) -> &mut T { + debug_assert_ne!(active, 0); + let i0 = active.trailing_zeros() as usize; + &mut ts[i0] + } + #[inline(always)] fn for_each_bit(mut bits: u64, mut f: impl FnMut(usize)) { while bits != 0 { @@ -1124,6 +1249,12 @@ where f(refs) } + // check for node-sharing first + if all_active_share(zs, active) { + P::on_id(first_active(zs, active), out); + return; + } + // combine root values if let Some(v) = P::combine_n(values(zs, active)) { out.set_val(v); @@ -1158,7 +1289,7 @@ where // (`descend_to_byte` / `ascend_byte`) and an explicit depth counter, // avoiding recursion in the common case. let mut k = 0; - debug_assert!(active.count_ones() > 0); + debug_assert_ne!(active.count_ones(), 0); 'ascend: loop { 'merge_level: loop { let mut min = None; @@ -1205,16 +1336,30 @@ where // descend and refresh masks and indices for_each_bit(active, |i| { - let mut z = &mut zs[i]; - z.descend_to_byte(a); - masks[i] = z.child_mask(); - bytes[i] = masks[i].indexed_bit::(0); + zs[i].descend_to_byte(a); }); + // check structural sharing first + if all_active_share(zs, active) { + P::on_id(first_active(zs, active), out); + + for_each_bit(active, |i| { + zs[i].ascend_byte(); + bytes[i] = masks[i].next_bit(a); + }); + out.ascend_byte(); + continue 'merge_level; + } + if let Some(v) = P::combine_n(values(zs, active)) { out.set_val(v); } + for_each_bit(active, |i| { + masks[i] = zs[i].child_mask(); + bytes[i] = masks[i].indexed_bit::(0); + }); + k += 1; continue 'merge_level; } @@ -1301,9 +1446,10 @@ where if (k == 0) { break 'ascend; } - - let i0 = active.trailing_zeros() as usize; - let byte_from = *zs[i0].path().last().expect("non-empty path when k > 0"); + let byte_from = *first_active(zs, active) + .path() + .last() + .expect("non-empty path when k > 0"); // ascend for_each_bit(active, |i| { @@ -1336,6 +1482,16 @@ impl MergePolicy for Join { fn descend_on_some_equal(_mask: u64) -> bool { true } + + #[inline] + fn on_id(z: &mut Z, out: &mut Out) + where + A: Allocator, + Z: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, + Out: ZipperWriting, + { + out.graft_children(z, ByteMask::FULL); + } } impl ValuePolicy for Join { @@ -1401,6 +1557,16 @@ impl MergePolicy for Meet { fn descend_on_some_equal(_mask: u64) -> bool { false } + + #[inline(always)] + fn on_id(z: &mut Z, out: &mut Out) + where + A: Allocator, + Z: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, + Out: ZipperWriting, + { + out.graft_children(z, ByteMask::FULL); + } } impl ValuePolicy for Meet { @@ -1502,6 +1668,15 @@ impl MergePolicy for Subtract { fn descend_on_some_equal(mask: u64) -> bool { mask & 1 != 0 } + + #[inline] + fn on_id(_z: &mut Z, _out: &mut Out) + where + A: Allocator, + Z: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, + Out: ZipperWriting, + { + } } impl ValuePolicy for Subtract { @@ -1581,7 +1756,7 @@ mod zipper_algebra_poly { use pathmap_derive::PolyZipperExplicit; #[derive(PolyZipperExplicit)] - #[poly_zipper_explicit(traits(ZipperMoving, ZipperValues))] + #[poly_zipper_explicit(traits(ZipperMoving, ZipperValues, ZipperConcrete))] pub(super) enum SomeMutRefZ<'a, 'trie, 'path, V: Clone + Send + Sync + Unpin, A: Allocator> { RZ(&'a mut ReadZipperUntracked<'trie, 'path, V, A>), RZT(&'a mut ReadZipperTracked<'trie, 'path, V, A>), @@ -1664,8 +1839,8 @@ mod zipper_algebra_poly { where V: Clone + Send + Sync, A: Allocator, - Z1: ZipperInfallibleSubtries + ZipperMoving, - Z2: ZipperInfallibleSubtries + ZipperMoving, + Z1: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, + Z2: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, Out: ZipperWriting, { fn merge_n

(mut self, out: &mut Out) @@ -1680,9 +1855,9 @@ mod zipper_algebra_poly { where V: Clone + Send + Sync, A: Allocator, - Z1: ZipperInfallibleSubtries + ZipperMoving, - Z2: ZipperInfallibleSubtries + ZipperMoving, - Z3: ZipperInfallibleSubtries + ZipperMoving, + Z1: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, + Z2: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, + Z3: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, Out: ZipperWriting, { fn merge_n

(mut self, out: &mut Out) @@ -1697,10 +1872,10 @@ mod zipper_algebra_poly { where V: Clone + Send + Sync, A: Allocator, - Z1: ZipperInfallibleSubtries + ZipperMoving, - Z2: ZipperInfallibleSubtries + ZipperMoving, - Z3: ZipperInfallibleSubtries + ZipperMoving, - Z4: ZipperInfallibleSubtries + ZipperMoving, + Z1: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, + Z2: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, + Z3: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, + Z4: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, Out: ZipperWriting, { fn merge_n

(mut self, out: &mut Out)