diff --git a/src/vocab.rs b/src/vocab.rs index bebfb3a..307270b 100644 --- a/src/vocab.rs +++ b/src/vocab.rs @@ -312,7 +312,7 @@ impl Vocabulary { #[inline] /// Compute the mean of a collection of binary arrays (descriptors). - fn desc_mean(descriptors: Vec<&Desc>) -> Desc { + pub fn desc_mean(descriptors: Vec<&Desc>) -> Desc { let n2 = descriptors.len() / 2; let mut counts = vec![0; std::mem::size_of::() * 8]; let mut result: Desc = [0; std::mem::size_of::()]; @@ -357,6 +357,38 @@ impl Vocabulary { levels: l, } } + + /// Checks the children of the root against a grouped set of kps that should match exactly + /// Used for tests + pub fn check_root_children(&self, kp_aggregates: Vec<(usize, usize, Desc)>) -> bool { + let root_block = &self.blocks[0]; + let mut is_consistent = true; + + for (id, size, centroid) in kp_aggregates { + //Find the index of the child that matches 'id' + let child_idx = root_block.children.ids.iter().position(|x| { + if let NodeId::Block(x) = x { + *x == id + } else { + false + } + }); + + //Check that the feature centroid and kp descriptor centroid are the same, and cluster size is the same + if let Some(child_idx) = child_idx { + if root_block.children.cluster_size[child_idx] != size + || root_block.children.features[child_idx] != centroid + { + is_consistent = false; + break; + } + } else { + is_consistent = false; + break; + } + } + is_consistent + } } #[inline]