diff --git a/src/lib.rs b/src/lib.rs
index 308c2bd..17958fb 100755
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -15,4 +15,4 @@

 mod multiset;

-pub use multiset::{HashMultiSet, Iter};
+pub use multiset::{HashMultiSet, Intersection, IntersectionCounts, Iter, Union, UnionCounts};
diff --git a/src/multiset.rs b/src/multiset.rs
index dacaec8..2a55ef8 100644
--- a/src/multiset.rs
+++ b/src/multiset.rs
@@ -340,6 +340,239 @@ where
     pub fn count_of(&self, val: &K) -> usize {
         self.elem_counts.get(val).map_or(0, |x| *x)
     }
+
+    /// Returns an iterator over the union of the two sets.
+    ///
+    /// This entails visiting each key the maximum number of times it appears in either set.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// use multiset::HashMultiSet;
+    ///
+    /// let set1: HashMultiSet<_> = [1, 1, 2].iter().cloned().collect();
+    /// let set2: HashMultiSet<_> = [2, 2, 3].iter().cloned().collect();
+    /// let mut union_vec: Vec<_> = set1.union(&set2).cloned().collect();
+    /// union_vec.sort(); // Order is arbitrary
+    /// assert_eq!(union_vec, vec![1, 1, 2, 2, 3]);
+    /// ```
+    pub fn union<'a>(&'a self, other: &'a HashMultiSet<K>) -> Union<'a, K> {
+        Union {
+            set1_iter: self.elem_counts.iter(),
+            set2_iter: other.elem_counts.iter(),
+            set1: self,
+            set2: other,
+            cur_entry: None,
+            cur_count: 0,
+        }
+    }
+
+    /// Returns an iterator over the intersection of the two sets.
+    ///
+    /// This entails visiting each key the minimum number of times it appears in either set.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// use multiset::HashMultiSet;
+    ///
+    /// let set1: HashMultiSet<_> = [1, 1, 2, 2, 2].iter().cloned().collect();
+    /// let set2: HashMultiSet<_> = [1, 2, 2, 3].iter().cloned().collect();
+    /// let mut intersection_vec: Vec<_> = set1.intersection(&set2).cloned().collect();
+    /// intersection_vec.sort(); // Order is arbitrary
+    /// assert_eq!(intersection_vec, vec![1, 2, 2]);
+    /// ```
+    pub fn intersection<'a>(&'a self, other: &'a HashMultiSet<K>) -> Intersection<'a, K> {
+        Intersection {
+            set1_iter: self.elem_counts.iter(),
+            set2: other,
+            cur_entry: None,
+            cur_count: 0,
+        }
+    }
+
+    /// Returns an iterator over the keys of the union of the two sets and their counts.
+    ///
+    /// This entails yielding each key with the maximum number of times it appears in either set.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// use multiset::HashMultiSet;
+    ///
+    /// let set1: HashMultiSet<_> = [1, 1, 2].iter().cloned().collect();
+    /// let set2: HashMultiSet<_> = [2, 2, 3].iter().cloned().collect();
+    /// let set_union: HashMultiSet<_> = set1
+    ///     .union_counts(&set2)
+    ///     .map(|(&key, count)| (key, count))
+    ///     .collect();
+    /// assert_eq!(set_union.count_of(&1), 2);
+    /// assert_eq!(set_union.count_of(&2), 2);
+    /// assert_eq!(set_union.count_of(&3), 1);
+    /// ```
+    pub fn union_counts<'a>(&'a self, other: &'a HashMultiSet<K>) -> UnionCounts<'a, K> {
+        UnionCounts {
+            set1: self,
+            set2: other,
+            set1_iter: self.elem_counts.iter(),
+            set2_iter: other.elem_counts.iter(),
+        }
+    }
+
+    /// Returns an iterator over the keys of the intersection of the two sets and their counts.
+    ///
+    /// This entails yielding each key with the minimum number of times it appears in either set. If
+    /// a key only appears in one set, it will not appear at all in the intersection.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// use multiset::HashMultiSet;
+    ///
+    /// let set1: HashMultiSet<_> = [1, 1, 2].iter().cloned().collect();
+    /// let set2: HashMultiSet<_> = [2, 2, 3].iter().cloned().collect();
+    /// let set_intersection: HashMultiSet<_> = set1
+    ///     .intersection_counts(&set2)
+    ///     .map(|(&key, count)| (key, count))
+    ///     .collect();
+    /// assert_eq!(set_intersection.count_of(&1), 0);
+    /// assert_eq!(set_intersection.count_of(&2), 1);
+    /// assert_eq!(set_intersection.count_of(&3), 0);
+    /// ```
+    pub fn intersection_counts<'a>(
+        &'a self,
+        other: &'a HashMultiSet<K>,
+    ) -> IntersectionCounts<'a, K> {
+        IntersectionCounts {
+            set2: other,
+            set1_iter: self.elem_counts.iter(),
+        }
+    }
+}
+
+/// An iterator over the union of two multisets; see [HashMultiSet::union] for details.
+pub struct Union<'a, K> {
+    set1_iter: hash_map::Iter<'a, K, usize>,
+    set2_iter: hash_map::Iter<'a, K, usize>,
+    set1: &'a HashMultiSet<K>,
+    set2: &'a HashMultiSet<K>,
+    cur_entry: Option<(&'a K, usize)>,
+    cur_count: usize,
+}
+
+impl<'a, K> Iterator for Union<'a, K>
+where
+    K: Eq + Hash,
+{
+    type Item = &'a K;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        // A loop rather than recursion, so stack depth never depends on the inputs.
+        loop {
+            if let Some((key, count)) = self.cur_entry {
+                if self.cur_count < count {
+                    self.cur_count += 1;
+                    return Some(key);
+                }
+            }
+            // Keys of set1 first (larger count wins), then the keys only set2 holds.
+            let next_entry = match self.set1_iter.next() {
+                Some((key, &count)) => (key, std::cmp::max(count, self.set2.count_of(key))),
+                None => loop {
+                    let (key, &count) = self.set2_iter.next()?;
+                    if !self.set1.contains(key) {
+                        break (key, count);
+                    }
+                },
+            };
+            self.cur_entry = Some(next_entry);
+            self.cur_count = 0;
+        }
+    }
+}
+
+/// An iterator over the intersection of two multisets; see [HashMultiSet::intersection].
+pub struct Intersection<'a, K> {
+    set1_iter: hash_map::Iter<'a, K, usize>,
+    set2: &'a HashMultiSet<K>,
+    cur_entry: Option<(&'a K, usize)>,
+    cur_count: usize,
+}
+
+impl<'a, K> Iterator for Intersection<'a, K>
+where
+    K: Eq + Hash,
+{
+    type Item = &'a K;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        // A loop rather than recursion: keys missing from set2 must not grow the stack.
+        loop {
+            if let Some((key, count)) = self.cur_entry {
+                if self.cur_count < count {
+                    self.cur_count += 1;
+                    return Some(key);
+                }
+            }
+            let (new_key, &new_count) = self.set1_iter.next()?;
+            let min_count = std::cmp::min(new_count, self.set2.count_of(new_key));
+            self.cur_entry = Some((new_key, min_count));
+            self.cur_count = 0;
+        }
+    }
+}
+
+/// An iterator over the distinct keys of the union of two multisets and their counts.
+/// See [HashMultiSet::union_counts] for more details.
+pub struct UnionCounts<'a, K> {
+    set1: &'a HashMultiSet<K>,
+    set2: &'a HashMultiSet<K>,
+    set1_iter: hash_map::Iter<'a, K, usize>,
+    set2_iter: hash_map::Iter<'a, K, usize>,
+}
+
+impl<'a, K> Iterator for UnionCounts<'a, K>
+where
+    K: Eq + Hash,
+{
+    type Item = (&'a K, usize);
+
+    fn next(&mut self) -> Option<Self::Item> {
+        // `if let`, not `while let`: this arm always returns on its first pass.
+        if let Some((key, &count)) = self.set1_iter.next() {
+            return Some((key, std::cmp::max(count, self.set2.count_of(key))));
+        }
+        for (key, &count) in self.set2_iter.by_ref() {
+            if !self.set1.contains(key) {
+                return Some((key, count));
+            }
+        }
+        None
+    }
+}
+
+/// An iterator over the distinct keys of the intersection of two multisets and their
+/// counts. See [HashMultiSet::intersection_counts] for more details.
+pub struct IntersectionCounts<'a, K> {
+    set2: &'a HashMultiSet<K>,
+    set1_iter: hash_map::Iter<'a, K, usize>,
+}
+
+impl<'a, K> Iterator for IntersectionCounts<'a, K>
+where
+    K: Eq + Hash,
+{
+    type Item = (&'a K, usize);
+
+    fn next(&mut self) -> Option<Self::Item> {
+        for (key, &count) in self.set1_iter.by_ref() {
+            if self.set2.contains(key) {
+                let min_count = std::cmp::min(count, self.set2.count_of(key));
+                return Some((key, min_count));
+            }
+        }
+        None
+    }
 }

 impl<T> Add for HashMultiSet<T>
@@ -447,6 +680,38 @@ where
     }
 }

+impl<K> FromIterator<(K, usize)> for HashMultiSet<K>
+where
+    K: Eq + Hash,
+{
+    /// Creates a new `HashMultiSet` from an iterable of `(element, count)` pairs.
+    ///
+    /// # Examples
+    ///
+    /// Build a multiset from explicit per-element counts:
+    ///
+    /// ```
+    /// use multiset::HashMultiSet;
+    /// use std::iter::FromIterator;
+    ///
+    /// let pairs = vec![('h', 1), ('l', 3)];
+    /// let multiset: HashMultiSet<char> = FromIterator::from_iter(pairs);
+    /// assert_eq!(1, multiset.count_of(&'h'));
+    /// assert_eq!(3, multiset.count_of(&'l'));
+    /// assert_eq!(0, multiset.count_of(&'z'));
+    /// ```
+    fn from_iter<T>(iterable: T) -> HashMultiSet<K>
+    where
+        T: IntoIterator<Item = (K, usize)>,
+    {
+        let mut multiset: HashMultiSet<K> = HashMultiSet::new();
+        for (elem, count) in iterable {
+            multiset.insert_times(elem, count);
+        }
+        multiset
+    }
+}
+
 impl<T> PartialEq for HashMultiSet<T>
 where
     T: Eq + Hash,
@@ -556,5 +821,265 @@ mod test_multiset {
         assert_eq!(set.len(), 1);
         set.remove(&'d');
         assert_eq!(set.len(), 0);
+
+        set.insert_times('e', 2);
+        assert_eq!(set.len(), 2);
+        assert_eq!(set.remove_times(&'e', 4), 2);
+        assert_eq!(set.len(), 0);
+    }
+
+    #[test]
+    fn test_union() {
+        let empty_array: [u32; 0] = [];
+
+        let set1: HashMultiSet<_> = [1, 1, 2].iter().cloned().collect();
+        let set2: HashMultiSet<_> = [2, 2, 3].iter().cloned().collect();
+        let mut union_vec: Vec<_> = set1.union(&set2).cloned().collect();
+        union_vec.sort();
+        assert_eq!(union_vec, vec![1, 1, 2, 2, 3]);
+
+        let set1: HashMultiSet<_> = [1, 1, 2].iter().cloned().collect(); // disjoint sets
+        let set2: HashMultiSet<_> = [3, 4, 4].iter().cloned().collect();
+        let mut union_vec: Vec<_> = set1.union(&set2).cloned().collect();
+        union_vec.sort(); // iteration order is arbitrary
+        assert_eq!(union_vec, vec![1, 1, 2, 3, 4, 4]);
+
+        let set1: HashMultiSet<_> = [1, 1, 2].iter().cloned().collect(); // identical sets
+        let set2: HashMultiSet<_> = [1, 1, 2].iter().cloned().collect();
+        let mut union_vec: Vec<_> = set1.union(&set2).cloned().collect();
+        union_vec.sort();
+        assert_eq!(union_vec, vec![1, 1, 2]);
+
+        let set1: HashMultiSet<_> = [1, 1, 2].iter().cloned().collect(); // set2 has more 2s
+        let set2: HashMultiSet<_> = [1, 1, 2, 2].iter().cloned().collect();
+        let mut union_vec: Vec<_> = set1.union(&set2).cloned().collect();
+        union_vec.sort();
+        assert_eq!(union_vec, vec![1, 1, 2, 2]);
+
+        let set1: HashMultiSet<_> = [1, 1, 2, 2].iter().cloned().collect(); // set1 has more 2s
+        let set2: HashMultiSet<_> = [1, 1, 2].iter().cloned().collect();
+        let mut union_vec: Vec<_> = set1.union(&set2).cloned().collect();
+        union_vec.sort();
+        assert_eq!(union_vec, vec![1, 1, 2, 2]);
+
+        let set1: HashMultiSet<_> = [1, 1, 2].iter().cloned().collect(); // each side wins one key
+        let set2: HashMultiSet<_> = [1, 2, 2].iter().cloned().collect();
+        let mut union_vec: Vec<_> = set1.union(&set2).cloned().collect();
+        union_vec.sort();
+        assert_eq!(union_vec, vec![1, 1, 2, 2]);
+
+        let set1: HashMultiSet<_> = empty_array.iter().cloned().collect(); // set1 empty
+        let set2: HashMultiSet<_> = [2, 2, 3].iter().cloned().collect();
+        let mut union_vec: Vec<_> = set1.union(&set2).cloned().collect();
+        union_vec.sort();
+        assert_eq!(union_vec, vec![2, 2, 3]);
+
+        let set1: HashMultiSet<_> = [1, 1, 2].iter().cloned().collect(); // set2 empty
+        let set2: HashMultiSet<_> = empty_array.iter().cloned().collect();
+        let mut union_vec: Vec<_> = set1.union(&set2).cloned().collect();
+        union_vec.sort();
+        assert_eq!(union_vec, vec![1, 1, 2]);
+
+        let set1: HashMultiSet<_> = empty_array.iter().cloned().collect(); // both empty
+        let set2: HashMultiSet<_> = empty_array.iter().cloned().collect();
+        let mut union_vec: Vec<_> = set1.union(&set2).cloned().collect();
+        union_vec.sort();
+        assert_eq!(union_vec, vec![]);
+    }
+
+    #[test]
+    fn test_intersection() {
+        let empty_array: [u32; 0] = [];
+
+        let set1: HashMultiSet<_> = [1, 1, 2].iter().cloned().collect(); // one shared key
+        let set2: HashMultiSet<_> = [2, 2, 3].iter().cloned().collect();
+        let mut intersection_vec: Vec<_> = set1.intersection(&set2).cloned().collect();
+        intersection_vec.sort(); // iteration order is arbitrary
+        assert_eq!(intersection_vec, vec![2]);
+
+        let set1: HashMultiSet<_> = [1, 1, 2, 2, 2].iter().cloned().collect(); // set2 has fewer
+        let set2: HashMultiSet<_> = [1, 2, 2, 3].iter().cloned().collect();
+        let mut intersection_vec: Vec<_> = set1.intersection(&set2).cloned().collect();
+        intersection_vec.sort();
+        assert_eq!(intersection_vec, vec![1, 2, 2]);
+
+        let set1: HashMultiSet<_> = [1, 1, 2].iter().cloned().collect(); // disjoint sets
+        let set2: HashMultiSet<_> = [3, 4, 4].iter().cloned().collect();
+        let mut intersection_vec: Vec<_> = set1.intersection(&set2).cloned().collect();
+        intersection_vec.sort();
+        assert_eq!(intersection_vec, vec![]);
+
+        let set1: HashMultiSet<_> = [1, 1, 2].iter().cloned().collect(); // identical sets
+        let set2: HashMultiSet<_> = [1, 1, 2].iter().cloned().collect();
+        let mut intersection_vec: Vec<_> = set1.intersection(&set2).cloned().collect();
+        intersection_vec.sort();
+        assert_eq!(intersection_vec, vec![1, 1, 2]);
+
+        let set1: HashMultiSet<_> = [1, 1, 2].iter().cloned().collect(); // set2 has more 2s
+        let set2: HashMultiSet<_> = [1, 1, 2, 2].iter().cloned().collect();
+        let mut intersection_vec: Vec<_> = set1.intersection(&set2).cloned().collect();
+        intersection_vec.sort();
+        assert_eq!(intersection_vec, vec![1, 1, 2]);
+
+        let set1: HashMultiSet<_> = [1, 1, 2, 2].iter().cloned().collect(); // set1 has more 2s
+        let set2: HashMultiSet<_> = [1, 1, 2].iter().cloned().collect();
+        let mut intersection_vec: Vec<_> = set1.intersection(&set2).cloned().collect();
+        intersection_vec.sort();
+        assert_eq!(intersection_vec, vec![1, 1, 2]);
+
+        let set1: HashMultiSet<_> = [1, 1, 2].iter().cloned().collect(); // each side caps one key
+        let set2: HashMultiSet<_> = [1, 2, 2].iter().cloned().collect();
+        let mut intersection_vec: Vec<_> = set1.intersection(&set2).cloned().collect();
+        intersection_vec.sort();
+        assert_eq!(intersection_vec, vec![1, 2]);
+
+        let set1: HashMultiSet<_> = empty_array.iter().cloned().collect(); // set1 empty
+        let set2: HashMultiSet<_> = [2, 2, 3].iter().cloned().collect();
+        let mut intersection_vec: Vec<_> = set1.intersection(&set2).cloned().collect();
+        intersection_vec.sort();
+        assert_eq!(intersection_vec, vec![]);
+
+        let set1: HashMultiSet<_> = [1, 1, 2].iter().cloned().collect(); // set2 empty
+        let set2: HashMultiSet<_> = empty_array.iter().cloned().collect();
+        let mut intersection_vec: Vec<_> = set1.intersection(&set2).cloned().collect();
+        intersection_vec.sort();
+        assert_eq!(intersection_vec, vec![]);
+
+        let set1: HashMultiSet<_> = empty_array.iter().cloned().collect(); // both empty
+        let set2: HashMultiSet<_> = empty_array.iter().cloned().collect();
+        let mut intersection_vec: Vec<_> = set1.intersection(&set2).cloned().collect();
+        intersection_vec.sort();
+        assert_eq!(intersection_vec, vec![]);
+    }
+
+    #[test]
+    fn test_union_counts() {
+        let set1: HashMultiSet<_> = [1, 1, 2].iter().cloned().collect(); // one shared key
+        let set2: HashMultiSet<_> = [2, 2, 3].iter().cloned().collect();
+        let set_union: HashMultiSet<_> = set1
+            .union_counts(&set2)
+            .map(|(&key, count)| (key, count))
+            .collect();
+        assert_eq!(set_union.count_of(&1), 2);
+        assert_eq!(set_union.count_of(&2), 2);
+        assert_eq!(set_union.count_of(&3), 1);
+
+        let set1: HashMultiSet<_> = [1, 1].iter().cloned().collect(); // disjoint sets
+        let set2: HashMultiSet<_> = [2, 3].iter().cloned().collect();
+        let set_union: HashMultiSet<_> = set1
+            .union_counts(&set2)
+            .map(|(&key, count)| (key, count))
+            .collect();
+        assert_eq!(set_union.count_of(&1), 2);
+        assert_eq!(set_union.count_of(&2), 1);
+        assert_eq!(set_union.count_of(&3), 1);
+        assert_eq!(set_union.len(), 4);
+
+        let set1: HashMultiSet<_> = [1, 1].iter().cloned().collect(); // identical sets
+        let set2: HashMultiSet<_> = [1, 1].iter().cloned().collect();
+        let set_union: HashMultiSet<_> = set1
+            .union_counts(&set2)
+            .map(|(&key, count)| (key, count))
+            .collect();
+        assert_eq!(set_union.count_of(&1), 2);
+        assert_eq!(set_union.len(), 2);
+
+        let set1: HashMultiSet<_> = [1, 1].iter().cloned().collect(); // set2 has more 1s
+        let set2: HashMultiSet<_> = [1, 1, 1].iter().cloned().collect();
+        let set_union: HashMultiSet<_> = set1
+            .union_counts(&set2)
+            .map(|(&key, count)| (key, count))
+            .collect();
+        assert_eq!(set_union.count_of(&1), 3);
+        assert_eq!(set_union.len(), 3);
+
+        let set1: HashMultiSet<_> = [1, 1].iter().cloned().collect(); // disjoint, repeated keys
+        let set2: HashMultiSet<_> = [2, 2, 2].iter().cloned().collect();
+        let set_union: HashMultiSet<_> = set1
+            .union_counts(&set2)
+            .map(|(&key, count)| (key, count))
+            .collect();
+        assert_eq!(set_union.count_of(&1), 2);
+        assert_eq!(set_union.count_of(&2), 3);
+        assert_eq!(set_union.len(), 5);
+
+        let set1: HashMultiSet<_> = [1, 1, 2, 3].iter().cloned().collect(); // set2 empty
+        let empty_array: [u32; 0] = [];
+        let set2: HashMultiSet<_> = empty_array.iter().cloned().collect();
+        let set_union: HashMultiSet<_> = set1
+            .union_counts(&set2)
+            .map(|(&key, count)| (key, count))
+            .collect();
+        assert_eq!(set_union.count_of(&1), 2);
+        assert_eq!(set_union.count_of(&2), 1);
+        assert_eq!(set_union.count_of(&3), 1);
+        assert_eq!(set_union.len(), 4);
+
+        let set1: HashMultiSet<_> = [1, 1, 2, 3].iter().cloned().collect(); // identical sets
+        let set2: HashMultiSet<_> = [1, 1, 2, 3].iter().cloned().collect();
+        let set_union: HashMultiSet<_> = set1
+            .union_counts(&set2)
+            .map(|(&key, count)| (key, count))
+            .collect();
+        assert_eq!(set_union.count_of(&1), 2);
+        assert_eq!(set_union.count_of(&2), 1);
+        assert_eq!(set_union.count_of(&3), 1);
+        assert_eq!(set_union.len(), 4);
+    }
+
+    #[test]
+    fn test_intersection_counts() {
+        let set1: HashMultiSet<_> = [1, 1].iter().cloned().collect(); // set2 has more 1s
+        let set2: HashMultiSet<_> = [1, 1, 1].iter().cloned().collect();
+        let set_intersection: HashMultiSet<_> = set1
+            .intersection_counts(&set2)
+            .map(|(&key, count)| (key, count))
+            .collect();
+        assert_eq!(set_intersection.count_of(&1), 2);
+        assert_eq!(set_intersection.len(), 2);
+
+        let set1: HashMultiSet<_> = [1, 1].iter().cloned().collect(); // disjoint sets
+        let set2: HashMultiSet<_> = [2, 2, 2].iter().cloned().collect();
+        let set_intersection: HashMultiSet<_> = set1
+            .intersection_counts(&set2)
+            .map(|(&key, count)| (key, count))
+            .collect();
+        assert_eq!(set_intersection.count_of(&1), 0);
+        assert_eq!(set_intersection.count_of(&2), 0);
+        assert_eq!(set_intersection.len(), 0);
+
+        let set1: HashMultiSet<_> = [1, 1, 2].iter().cloned().collect(); // only key 1 is shared
+        let set2: HashMultiSet<_> = [1, 1, 3].iter().cloned().collect();
+        let set_intersection: HashMultiSet<_> = set1
+            .intersection_counts(&set2)
+            .map(|(&key, count)| (key, count))
+            .collect();
+        assert_eq!(set_intersection.count_of(&1), 2);
+        assert_eq!(set_intersection.count_of(&2), 0);
+        assert_eq!(set_intersection.count_of(&3), 0);
+        assert_eq!(set_intersection.len(), 2);
+
+        let set1: HashMultiSet<_> = [1, 1, 2, 3].iter().cloned().collect(); // set2 empty
+        let empty_array: [u32; 0] = [];
+        let set2: HashMultiSet<_> = empty_array.iter().cloned().collect();
+        let set_intersection: HashMultiSet<_> = set1
+            .intersection_counts(&set2)
+            .map(|(&key, count)| (key, count))
+            .collect();
+        assert_eq!(set_intersection.count_of(&1), 0);
+        assert_eq!(set_intersection.count_of(&2), 0);
+        assert_eq!(set_intersection.count_of(&3), 0);
+        assert_eq!(set_intersection.len(), 0);
+
+        let set1: HashMultiSet<_> = [1, 1, 2, 3].iter().cloned().collect(); // identical sets
+        let set2: HashMultiSet<_> = [1, 1, 2, 3].iter().cloned().collect();
+        let set_intersection: HashMultiSet<_> = set1
+            .intersection_counts(&set2)
+            .map(|(&key, count)| (key, count))
+            .collect();
+        assert_eq!(set_intersection.count_of(&1), 2);
+        assert_eq!(set_intersection.count_of(&2), 1);
+        assert_eq!(set_intersection.count_of(&3), 1);
+        assert_eq!(set_intersection.len(), 4);
     }
 }