From 1f838c639908d1a69e735e938e09ec687df21b3b Mon Sep 17 00:00:00 2001 From: dewert99 Date: Wed, 3 Jan 2024 10:21:33 -0800 Subject: [PATCH 01/20] Added `nodes` field to `EGraph` to avoid storing nodes in `analysis` and `analysis_pending` --- src/eclass.rs | 10 +++++----- src/egraph.rs | 41 +++++++++++++++++++++++++---------------- 2 files changed, 30 insertions(+), 21 deletions(-) diff --git a/src/eclass.rs b/src/eclass.rs index 5f74b2c2..640dea63 100644 --- a/src/eclass.rs +++ b/src/eclass.rs @@ -17,8 +17,8 @@ pub struct EClass { /// Modifying this field will _not_ cause changes to propagate through the e-graph. /// Prefer [`EGraph::set_analysis_data`] instead. pub data: D, - /// The parent enodes and their original Ids. - pub(crate) parents: Vec<(L, Id)>, + /// The original Ids of parent enodes. + pub(crate) parents: Vec, } impl EClass { @@ -37,9 +37,9 @@ impl EClass { self.nodes.iter() } - /// Iterates over the parent enodes of this eclass. - pub fn parents(&self) -> impl ExactSizeIterator { - self.parents.iter().map(|(node, id)| (node, *id)) + /// Iterates over the non-canonical ids of parent enodes of this eclass. + pub fn parents(&self) -> impl ExactSizeIterator + '_ { + self.parents.iter().copied() } } diff --git a/src/egraph.rs b/src/egraph.rs index 6af452b2..f05456de 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -57,6 +57,8 @@ pub struct EGraph> { /// The `Explain` used to explain equivalences in this `EGraph`. pub(crate) explain: Option>, unionfind: UnionFind, + /// Stores the original node represented by each non-canonical id + nodes: Vec, /// Stores each enode's `Id`, not the `Id` of the eclass. /// Enodes in the memo are canonicalized at each rebuild, but after rebuilding new /// unions can cause them to become out of date. @@ -64,8 +66,8 @@ pub struct EGraph> { memo: HashMap, /// Nodes which need to be processed for rebuilding. The `Id` is the `Id` of the enode, /// not the canonical id of the eclass. - pending: Vec<(L, Id)>, - analysis_pending: UniqueQueue<(L, Id)>, + pending: Vec, + analysis_pending: UniqueQueue, #[cfg_attr( feature = "serde-1", serde(bound( @@ -114,6 +116,7 @@ impl> EGraph { analysis, classes: Default::default(), unionfind: Default::default(), + nodes: Default::default(), clean: false, explain: None, pending: Default::default(), @@ -769,7 +772,9 @@ impl> EGraph { *existing_explain } else { let new_id = self.unionfind.make_set(); - explain.add(original, new_id, new_id); + explain.add(original.clone(), new_id, new_id); + self.nodes.push(original); + debug_assert_eq!(Id::from(self.nodes.len()), new_id); self.unionfind.union(id, new_id); explain.union(existing_id, new_id, Justification::Congruence, true); new_id @@ -778,7 +783,7 @@ impl> EGraph { existing_id } } else { - let id = self.make_new_eclass(enode); + let id = self.make_new_eclass(enode, original.clone()); if let Some(explain) = self.explain.as_mut() { explain.add(original, id, id); } @@ -791,24 +796,26 @@ impl> EGraph { } /// This function makes a new eclass in the egraph (but doesn't touch explanations) - fn make_new_eclass(&mut self, enode: L) -> Id { + fn make_new_eclass(&mut self, enode: L, original: L) -> Id { let id = self.unionfind.make_set(); log::trace!(" ...adding to {}", id); let class = EClass { id, nodes: vec![enode.clone()], - data: N::make(self, &enode), + data: N::make(self, &original), parents: Default::default(), }; + self.nodes.push(original); + debug_assert_eq!(Id::from(self.nodes.len()), id); + // add this enode to the parent lists of its children enode.for_each(|child| { - let tup = (enode.clone(), id); - self[child].parents.push(tup); + self[child].parents.push(id); }); // TODO is this needed? - self.pending.push((enode.clone(), id)); + self.pending.push(id); self.classes.insert(id, class); assert!(self.memo.insert(enode, id).is_none()); @@ -943,13 +950,13 @@ impl> EGraph { let class1 = self.classes.get_mut(&id1).unwrap(); assert_eq!(id1, class1.id); - self.pending.extend(class2.parents.iter().cloned()); + self.pending.extend(class2.parents.iter().copied()); let did_merge = self.analysis.merge(&mut class1.data, class2.data); if did_merge.0 { - self.analysis_pending.extend(class1.parents.iter().cloned()); + self.analysis_pending.extend(class1.parents.iter().copied()); } if did_merge.1 { - self.analysis_pending.extend(class2.parents.iter().cloned()); + self.analysis_pending.extend(class2.parents.iter().copied()); } concat_vecs(&mut class1.nodes, class2.nodes); @@ -968,7 +975,7 @@ impl> EGraph { let id = self.find_mut(id); let class = self.classes.get_mut(&id).unwrap(); class.data = new_data; - self.analysis_pending.extend(class.parents.iter().cloned()); + self.analysis_pending.extend(class.parents.iter().copied()); N::modify(self, id) } @@ -1103,7 +1110,8 @@ impl> EGraph { let mut n_unions = 0; while !self.pending.is_empty() || !self.analysis_pending.is_empty() { - while let Some((mut node, class)) = self.pending.pop() { + while let Some(class) = self.pending.pop() { + let mut node = self.nodes[usize::from(class)].clone(); node.update_children(|id| self.find_mut(id)); if let Some(memo_class) = self.memo.insert(node, class) { let did_something = self.perform_union( @@ -1116,14 +1124,15 @@ impl> EGraph { } } - while let Some((node, class_id)) = self.analysis_pending.pop() { + while let Some(class_id) = self.analysis_pending.pop() { + let node = self.nodes[usize::from(class_id)].clone(); let class_id = self.find_mut(class_id); let node_data = N::make(self, &node); let class = self.classes.get_mut(&class_id).unwrap(); let did_merge = self.analysis.merge(&mut class.data, node_data); if did_merge.0 { - self.analysis_pending.extend(class.parents.iter().cloned()); + self.analysis_pending.extend(class.parents.iter().copied()); N::modify(self, class_id) } } From 3145a305d5054fb475e291528be46aff8deeb201 Mon Sep 17 00:00:00 2001 From: dewert99 Date: Wed, 3 Jan 2024 11:43:49 -0800 Subject: [PATCH 02/20] eliminated `node` field of `ExplainNode` (used `EGraph.nodes` instead) --- src/egraph.rs | 107 ++++++++++++++------ src/explain.rs | 260 ++++++++++++++++++++----------------------------- 2 files changed, 186 insertions(+), 181 deletions(-) diff --git a/src/egraph.rs b/src/egraph.rs index f05456de..3e0d8225 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -217,12 +217,11 @@ impl> EGraph { /// Make a copy of the egraph with the same nodes, but no unions between them. pub fn copy_without_unions(&self, analysis: N) -> Self { - if let Some(explain) = &self.explain { - let egraph = Self::new(analysis); - explain.populate_enodes(egraph) - } else { - panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get a copied egraph without unions"); + let mut egraph = Self::new(analysis); + for node in &self.nodes { + egraph.add(node.clone()); } + egraph } /// Performs the union between two egraphs. @@ -342,20 +341,33 @@ impl> EGraph { /// was obtained (see [`add_uncanoncial`](EGraph::add_uncanonical), /// [`add_expr_uncanonical`](EGraph::add_expr_uncanonical)) pub fn id_to_expr(&self, id: Id) -> RecExpr { - if let Some(explain) = &self.explain { - explain.node_to_recexpr(id) - } else { - panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get unique expressions per id"); + let mut res = Default::default(); + let mut cache = Default::default(); + self.id_to_expr_internal(&mut res, id, &mut cache); + res + } + + fn id_to_expr_internal( + &self, + res: &mut RecExpr, + node_id: Id, + cache: &mut HashMap, + ) -> Id { + if let Some(existing) = cache.get(&node_id) { + return *existing; } + let new_node = self + .id_to_node(node_id) + .clone() + .map_children(|child| self.id_to_expr_internal(res, child, cache)); + let res_id = res.add(new_node); + cache.insert(node_id, res_id); + res_id } /// Like [`id_to_expr`](EGraph::id_to_expr) but only goes one layer deep pub fn id_to_node(&self, id: Id) -> &L { - if let Some(explain) = &self.explain { - explain.node(id) - } else { - panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get unique expressions per id"); - } + &self.nodes[usize::from(id)] } /// Like [`id_to_expr`](EGraph::id_to_expr), but creates a pattern instead of a term. @@ -363,11 +375,36 @@ impl> EGraph { /// It also adds this variable and the corresponding Id value to the resulting [`Subst`] /// Otherwise it behaves like [`id_to_expr`](EGraph::id_to_expr). pub fn id_to_pattern(&self, id: Id, substitutions: &HashMap) -> (Pattern, Subst) { - if let Some(explain) = &self.explain { - explain.node_to_pattern(id, substitutions) - } else { - panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get unique patterns per id"); + let mut res = Default::default(); + let mut subst = Default::default(); + let mut cache = Default::default(); + self.id_to_pattern_internal(&mut res, id, substitutions, &mut subst, &mut cache); + (Pattern::new(res), subst) + } + + fn id_to_pattern_internal( + &self, + res: &mut PatternAst, + node_id: Id, + var_substitutions: &HashMap, + subst: &mut Subst, + cache: &mut HashMap, + ) -> Id { + if let Some(existing) = cache.get(&node_id) { + return *existing; } + let res_id = if let Some(existing) = var_substitutions.get(&node_id) { + let var = format!("?{}", node_id).parse().unwrap(); + subst.insert(var, *existing); + res.add(ENodeOrVar::Var(var)) + } else { + let new_node = self.id_to_node(node_id).clone().map_children(|child| { + self.id_to_pattern_internal(res, child, var_substitutions, subst, cache) + }); + res.add(ENodeOrVar::ENode(new_node)) + }; + cache.insert(node_id, res_id); + res_id } /// Get all the unions ever found in the egraph in terms of enode ids. @@ -393,8 +430,10 @@ impl> EGraph { /// Get the number of congruences between nodes in the egraph. /// Only available when explanations are enabled. pub fn get_num_congr(&mut self) -> usize { - if let Some(explain) = &self.explain { - explain.get_num_congr::(&self.classes, &self.unionfind) + if let Some(explain) = &mut self.explain { + explain + .with_nodes(&self.nodes) + .get_num_congr::(&self.classes, &self.unionfind) } else { panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.") } @@ -402,8 +441,8 @@ impl> EGraph { /// Get the number of nodes in the egraph used for explanations. pub fn get_explanation_num_nodes(&mut self) -> usize { - if let Some(explain) = &self.explain { - explain.get_num_nodes() + if let Some(explain) = &mut self.explain { + explain.with_nodes(&self.nodes).get_num_nodes() } else { panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.") } @@ -441,7 +480,12 @@ impl> EGraph { ); } if let Some(explain) = &mut self.explain { - explain.explain_equivalence::(left, right, &mut self.unionfind, &self.classes) + explain.with_nodes(&self.nodes).explain_equivalence::( + left, + right, + &mut self.unionfind, + &self.classes, + ) } else { panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.") } @@ -464,7 +508,7 @@ impl> EGraph { /// but more efficient fn explain_existance_id(&mut self, id: Id) -> Explanation { if let Some(explain) = &mut self.explain { - explain.explain_existance(id) + explain.with_nodes(&self.nodes).explain_existance(id) } else { panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.") } @@ -478,7 +522,7 @@ impl> EGraph { ) -> Explanation { let id = self.add_instantiation_noncanonical(pattern, subst); if let Some(explain) = &mut self.explain { - explain.explain_existance(id) + explain.with_nodes(&self.nodes).explain_existance(id) } else { panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.") } @@ -501,7 +545,12 @@ impl> EGraph { ); } if let Some(explain) = &mut self.explain { - explain.explain_equivalence::(left, right, &mut self.unionfind, &self.classes) + explain.with_nodes(&self.nodes).explain_equivalence::( + left, + right, + &mut self.unionfind, + &self.classes, + ) } else { panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations."); } @@ -1213,9 +1262,9 @@ impl> EGraph { n_unions } - pub(crate) fn check_each_explain(&self, rules: &[&Rewrite]) -> bool { - if let Some(explain) = &self.explain { - explain.check_each_explain(rules) + pub(crate) fn check_each_explain(&mut self, rules: &[&Rewrite]) -> bool { + if let Some(explain) = &mut self.explain { + explain.with_nodes(&self.nodes).check_each_explain(rules) } else { panic!("Can't check explain when explanations are off"); } diff --git a/src/explain.rs b/src/explain.rs index 187aecfc..59315615 100644 --- a/src/explain.rs +++ b/src/explain.rs @@ -1,12 +1,13 @@ use crate::Symbol; use crate::{ - util::pretty_print, Analysis, EClass, EGraph, ENodeOrVar, FromOp, HashMap, HashSet, Id, - Language, Pattern, PatternAst, RecExpr, Rewrite, Subst, UnionFind, Var, + util::pretty_print, Analysis, EClass, ENodeOrVar, FromOp, HashMap, HashSet, Id, Language, + PatternAst, RecExpr, Rewrite, UnionFind, Var, }; use saturating::Saturating; use std::cmp::Ordering; use std::collections::{BinaryHeap, VecDeque}; use std::fmt::{self, Debug, Display, Formatter}; +use std::ops::{Deref, DerefMut}; use std::rc::Rc; use symbolic_expressions::Sexp; @@ -38,8 +39,7 @@ struct Connection { #[derive(Debug, Clone)] #[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))] -struct ExplainNode { - node: L, +struct ExplainNode { // neighbors includes parent connections neighbors: Vec, parent_connection: Connection, @@ -54,7 +54,7 @@ struct ExplainNode { #[derive(Debug, Clone)] #[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))] pub struct Explain { - explainfind: Vec>, + explainfind: Vec, #[cfg_attr(feature = "serde-1", serde(with = "vectorize"))] pub uncanon_memo: HashMap, /// By default, egg uses a greedy algorithm to find shorter explanations when they are extracted. @@ -69,6 +69,11 @@ pub struct Explain { shortest_explanation_memo: HashMap<(Id, Id), (ProofCost, Id)>, } +pub(crate) struct ExplainNodes<'a, L: Language> { + explain: &'a mut Explain, + nodes: &'a [L], +} + #[derive(Default)] struct DistanceMemo { parent_distance: Vec<(Id, ProofCost)>, @@ -883,97 +888,6 @@ impl PartialOrd for HeapState { } impl Explain { - pub(crate) fn node(&self, node_id: Id) -> &L { - &self.explainfind[usize::from(node_id)].node - } - fn node_to_explanation( - &self, - node_id: Id, - cache: &mut NodeExplanationCache, - ) -> Rc> { - if let Some(existing) = cache.get(&node_id) { - existing.clone() - } else { - let node = self.node(node_id).clone(); - let children = node.fold(vec![], |mut sofar, child| { - sofar.push(vec![self.node_to_explanation(child, cache)]); - sofar - }); - let res = Rc::new(TreeTerm::new(node, children)); - cache.insert(node_id, res.clone()); - res - } - } - - pub(crate) fn node_to_recexpr(&self, node_id: Id) -> RecExpr { - let mut res = Default::default(); - let mut cache = Default::default(); - self.node_to_recexpr_internal(&mut res, node_id, &mut cache); - res - } - fn node_to_recexpr_internal( - &self, - res: &mut RecExpr, - node_id: Id, - cache: &mut HashMap, - ) { - let new_node = self.node(node_id).clone().map_children(|child| { - if let Some(existing) = cache.get(&child) { - *existing - } else { - self.node_to_recexpr_internal(res, child, cache); - Id::from(res.as_ref().len() - 1) - } - }); - res.add(new_node); - } - - pub(crate) fn node_to_pattern( - &self, - node_id: Id, - substitutions: &HashMap, - ) -> (Pattern, Subst) { - let mut res = Default::default(); - let mut subst = Default::default(); - let mut cache = Default::default(); - self.node_to_pattern_internal(&mut res, node_id, substitutions, &mut subst, &mut cache); - (Pattern::new(res), subst) - } - - fn node_to_pattern_internal( - &self, - res: &mut PatternAst, - node_id: Id, - var_substitutions: &HashMap, - subst: &mut Subst, - cache: &mut HashMap, - ) { - if let Some(existing) = var_substitutions.get(&node_id) { - let var = format!("?{}", node_id).parse().unwrap(); - res.add(ENodeOrVar::Var(var)); - subst.insert(var, *existing); - } else { - let new_node = self.node(node_id).clone().map_children(|child| { - if let Some(existing) = cache.get(&child) { - *existing - } else { - self.node_to_pattern_internal(res, child, var_substitutions, subst, cache); - Id::from(res.as_ref().len() - 1) - } - }); - res.add(ENodeOrVar::ENode(new_node)); - } - } - - fn node_to_flat_explanation(&self, node_id: Id) -> FlatTerm { - let node = self.node(node_id).clone(); - let children = node.fold(vec![], |mut sofar, child| { - sofar.push(self.node_to_flat_explanation(child)); - sofar - }); - FlatTerm::new(node, children) - } - fn make_rule_table<'a, N: Analysis>( rules: &[&'a Rewrite], ) -> HashMap> { @@ -983,52 +897,6 @@ impl Explain { } table } - - pub fn check_each_explain>(&self, rules: &[&Rewrite]) -> bool { - let rule_table = Explain::make_rule_table(rules); - for i in 0..self.explainfind.len() { - let explain_node = &self.explainfind[i]; - - // check that explanation reasons never form a cycle - let mut existance = i; - let mut seen_existance: HashSet = Default::default(); - loop { - seen_existance.insert(existance); - let next = usize::from(self.explainfind[existance].existance_node); - if existance == next { - break; - } - existance = next; - if seen_existance.contains(&existance) { - panic!("Cycle in existance!"); - } - } - - if explain_node.parent_connection.next != Id::from(i) { - let mut current_explanation = self.node_to_flat_explanation(Id::from(i)); - let mut next_explanation = - self.node_to_flat_explanation(explain_node.parent_connection.next); - if let Justification::Rule(rule_name) = - &explain_node.parent_connection.justification - { - if let Some(rule) = rule_table.get(rule_name) { - if !explain_node.parent_connection.is_rewrite_forward { - std::mem::swap(&mut current_explanation, &mut next_explanation); - } - if !Explanation::check_rewrite( - ¤t_explanation, - &next_explanation, - rule, - ) { - return false; - } - } - } - } - } - true - } - pub fn new() -> Self { Explain { explainfind: vec![], @@ -1046,7 +914,6 @@ impl Explain { assert_eq!(self.explainfind.len(), usize::from(set)); self.uncanon_memo.insert(node.clone(), set); self.explainfind.push(ExplainNode { - node, neighbors: vec![], parent_connection: Connection { justification: Justification::Congruence, @@ -1119,7 +986,7 @@ impl Explain { new_rhs: bool, ) { if let Justification::Congruence = justification { - assert!(self.node(node1).matches(self.node(node2))); + // assert!(self.node(node1).matches(self.node(node2))); } if new_rhs { self.set_existance_reason(node2, node1) @@ -1155,7 +1022,6 @@ impl Explain { .push(other_pconnection); self.explainfind[usize::from(node1)].parent_connection = pconnection; } - pub(crate) fn get_union_equalities(&self) -> UnionEqualities { let mut equalities = vec![]; for node in &self.explainfind { @@ -1170,13 +1036,103 @@ impl Explain { equalities } - pub(crate) fn populate_enodes>(&self, mut egraph: EGraph) -> EGraph { - for i in 0..self.explainfind.len() { - let node = &self.explainfind[i]; - egraph.add(node.node.clone()); + pub(crate) fn with_nodes<'a>(&'a mut self, nodes: &'a [L]) -> ExplainNodes<'a, L> { + ExplainNodes { + explain: self, + nodes, } + } +} + +impl<'a, L: Language> Deref for ExplainNodes<'a, L> { + type Target = Explain; + + fn deref(&self) -> &Self::Target { + self.explain + } +} + +impl<'a, L: Language> DerefMut for ExplainNodes<'a, L> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut *self.explain + } +} + +impl<'x, L: Language> ExplainNodes<'x, L> { + pub(crate) fn node(&self, node_id: Id) -> &L { + &self.nodes[usize::from(node_id)] + } + fn node_to_explanation( + &self, + node_id: Id, + cache: &mut NodeExplanationCache, + ) -> Rc> { + if let Some(existing) = cache.get(&node_id) { + existing.clone() + } else { + let node = self.node(node_id).clone(); + let children = node.fold(vec![], |mut sofar, child| { + sofar.push(vec![self.node_to_explanation(child, cache)]); + sofar + }); + let res = Rc::new(TreeTerm::new(node, children)); + cache.insert(node_id, res.clone()); + res + } + } + + fn node_to_flat_explanation(&self, node_id: Id) -> FlatTerm { + let node = self.node(node_id).clone(); + let children = node.fold(vec![], |mut sofar, child| { + sofar.push(self.node_to_flat_explanation(child)); + sofar + }); + FlatTerm::new(node, children) + } + + pub fn check_each_explain>(&self, rules: &[&Rewrite]) -> bool { + let rule_table = Explain::make_rule_table(rules); + for i in 0..self.explainfind.len() { + let explain_node = &self.explainfind[i]; + + // check that explanation reasons never form a cycle + let mut existance = i; + let mut seen_existance: HashSet = Default::default(); + loop { + seen_existance.insert(existance); + let next = usize::from(self.explainfind[existance].existance_node); + if existance == next { + break; + } + existance = next; + if seen_existance.contains(&existance) { + panic!("Cycle in existance!"); + } + } - egraph + if explain_node.parent_connection.next != Id::from(i) { + let mut current_explanation = self.node_to_flat_explanation(Id::from(i)); + let mut next_explanation = + self.node_to_flat_explanation(explain_node.parent_connection.next); + if let Justification::Rule(rule_name) = + &explain_node.parent_connection.justification + { + if let Some(rule) = rule_table.get(rule_name) { + if !explain_node.parent_connection.is_rewrite_forward { + std::mem::swap(&mut current_explanation, &mut next_explanation); + } + if !Explanation::check_rewrite( + ¤t_explanation, + &next_explanation, + rule, + ) { + return false; + } + } + } + } + } + true } pub(crate) fn explain_equivalence>( @@ -1328,7 +1284,7 @@ impl Explain { let mut new_rest_of_proof = (*self.node_to_explanation(existance, enode_cache)).clone(); let mut index_of_child = 0; let mut found = false; - existance_node.node.for_each(|child| { + self.node(existance).for_each(|child| { if found { return; } @@ -2092,7 +2048,7 @@ mod tests { #[test] fn simple_explain_union_trusted() { - use crate::SymbolLang; + use crate::{EGraph, SymbolLang}; crate::init_logger(); let mut egraph = EGraph::new(()).with_explanations_enabled(); From c075cbf95bb7f3abb1cd904c59bc6ab3ad2752f2 Mon Sep 17 00:00:00 2001 From: dewert99 Date: Wed, 3 Jan 2024 12:34:37 -0800 Subject: [PATCH 03/20] serde --- src/explain.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/explain.rs b/src/explain.rs index 59315615..a2d0a2b2 100644 --- a/src/explain.rs +++ b/src/explain.rs @@ -56,6 +56,10 @@ struct ExplainNode { pub struct Explain { explainfind: Vec, #[cfg_attr(feature = "serde-1", serde(with = "vectorize"))] + #[cfg_attr( + feature = "serde-1", + serde(bound(serialize = "L: Serialize", deserialize = "L: for<'a> Deserialize<'a>",)) + )] pub uncanon_memo: HashMap, /// By default, egg uses a greedy algorithm to find shorter explanations when they are extracted. pub optimize_explanation_lengths: bool, @@ -912,7 +916,7 @@ impl Explain { pub(crate) fn add(&mut self, node: L, set: Id, existance_node: Id) -> Id { assert_eq!(self.explainfind.len(), usize::from(set)); - self.uncanon_memo.insert(node.clone(), set); + self.uncanon_memo.insert(node, set); self.explainfind.push(ExplainNode { neighbors: vec![], parent_connection: Connection { From 3187e3688053cf29f666a0d69c87e9c5d96d5535 Mon Sep 17 00:00:00 2001 From: dewert99 Date: Wed, 3 Jan 2024 13:08:41 -0800 Subject: [PATCH 04/20] serde --- src/explain.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/explain.rs b/src/explain.rs index a2d0a2b2..9de2a17e 100644 --- a/src/explain.rs +++ b/src/explain.rs @@ -58,7 +58,10 @@ pub struct Explain { #[cfg_attr(feature = "serde-1", serde(with = "vectorize"))] #[cfg_attr( feature = "serde-1", - serde(bound(serialize = "L: Serialize", deserialize = "L: for<'a> Deserialize<'a>",)) + serde(bound( + serialize = "L: serde::Serialize", + deserialize = "L: serde::Deserialize<'de>", + )) )] pub uncanon_memo: HashMap, /// By default, egg uses a greedy algorithm to find shorter explanations when they are extracted. From 4d4c52d3fd4091ddafe484bbe7d11f87ff8ca4bd Mon Sep 17 00:00:00 2001 From: dewert99 Date: Wed, 3 Jan 2024 19:22:11 -0800 Subject: [PATCH 05/20] Clarify `id_to_expr` and prevent `copy_with_unions` when explanations are disabled --- src/egraph.rs | 32 +++++++++++++++++++++++++++----- 1 file changed, 27 insertions(+), 5 deletions(-) diff --git a/src/egraph.rs b/src/egraph.rs index 3e0d8225..3f292460 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -217,6 +217,9 @@ impl> EGraph { /// Make a copy of the egraph with the same nodes, but no unions between them. pub fn copy_without_unions(&self, analysis: N) -> Self { + if self.explain.is_none() { + panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get a copied egraph without unions"); + } let mut egraph = Self::new(analysis); for node in &self.nodes { egraph.add(node.clone()); @@ -638,7 +641,7 @@ impl> EGraph { /// Similar to [`add_expr`](EGraph::add_expr) but the `Id` returned may not be canonical /// - /// Calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` return a copy of `expr` + /// Calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` return a copy of `expr` when explanations are enabled pub fn add_expr_uncanonical(&mut self, expr: &RecExpr) -> Id { let nodes = expr.as_ref(); let mut new_ids = Vec::with_capacity(nodes.len()); @@ -676,7 +679,7 @@ impl> EGraph { /// canonical /// /// Like [`add_uncanonical`](EGraph::add_uncanonical), when explanations are enabled calling - /// Calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` return an corrispond to the + /// Calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` return an correspond to the /// instantiation of the pattern fn add_instantiation_noncanonical(&mut self, pat: &PatternAst, subst: &Subst) -> Id { let nodes = pat.as_ref(); @@ -796,7 +799,7 @@ impl> EGraph { /// When explanations are enabled calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` will /// correspond to the parameter `enode` /// - /// # Example + /// ## Example /// ``` /// # use egg::*; /// let mut egraph: EGraph = EGraph::default().with_explanations_enabled(); @@ -811,6 +814,25 @@ impl> EGraph { /// assert_eq!(egraph.id_to_expr(fa), "(f a)".parse().unwrap()); /// assert_eq!(egraph.id_to_expr(fb), "(f b)".parse().unwrap()); /// ``` + /// + /// When explanations are not enabled calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` will + /// produce an expression with equivalent but not necessarily identical children + /// + /// # Example + /// ``` + /// # use egg::*; + /// let mut egraph: EGraph = EGraph::default().with_explanations_disabled(); + /// let a = egraph.add_uncanonical(SymbolLang::leaf("a")); + /// let b = egraph.add_uncanonical(SymbolLang::leaf("b")); + /// egraph.union(a, b); + /// egraph.rebuild(); + /// + /// let fa = egraph.add_uncanonical(SymbolLang::new("f", vec![a])); + /// let fb = egraph.add_uncanonical(SymbolLang::new("f", vec![b])); + /// + /// assert_eq!(egraph.id_to_expr(fa), "(f a)".parse().unwrap()); + /// assert_eq!(egraph.id_to_expr(fb), "(f a)".parse().unwrap()); + /// ``` pub fn add_uncanonical(&mut self, mut enode: L) -> Id { let original = enode.clone(); if let Some(existing_id) = self.lookup_internal(&mut enode) { @@ -822,8 +844,8 @@ impl> EGraph { } else { let new_id = self.unionfind.make_set(); explain.add(original.clone(), new_id, new_id); - self.nodes.push(original); debug_assert_eq!(Id::from(self.nodes.len()), new_id); + self.nodes.push(original); self.unionfind.union(id, new_id); explain.union(existing_id, new_id, Justification::Congruence, true); new_id @@ -855,8 +877,8 @@ impl> EGraph { parents: Default::default(), }; - self.nodes.push(original); debug_assert_eq!(Id::from(self.nodes.len()), id); + self.nodes.push(original); // add this enode to the parent lists of its children enode.for_each(|child| { From 8bcfe665eee9828f26d9ff9db6581d20f0df5c65 Mon Sep 17 00:00:00 2001 From: dewert99 Date: Sat, 3 Feb 2024 11:02:22 -0800 Subject: [PATCH 06/20] Extracted out low level egraph API --- src/dot.rs | 30 +-- src/eclass.rs | 28 +- src/egraph.rs | 460 +++++++++----------------------- src/explain.rs | 81 ++---- src/lib.rs | 3 + src/raw.rs | 5 + src/raw/eclass.rs | 43 +++ src/raw/egraph.rs | 656 ++++++++++++++++++++++++++++++++++++++++++++++ 8 files changed, 885 insertions(+), 421 deletions(-) create mode 100644 src/raw.rs create mode 100644 src/raw/eclass.rs create mode 100644 src/raw/egraph.rs diff --git a/src/dot.rs b/src/dot.rs index cefaf440..111fac51 100644 --- a/src/dot.rs +++ b/src/dot.rs @@ -11,7 +11,7 @@ use std::fmt::{self, Debug, Display, Formatter}; use std::io::{Error, ErrorKind, Result, Write}; use std::path::Path; -use crate::{egraph::EGraph, Analysis, Language}; +use crate::{raw, Language}; /** A wrapper for an [`EGraph`] that can output [GraphViz] for @@ -50,8 +50,8 @@ instead of to its own eclass. [GraphViz]: https://graphviz.gitlab.io/ **/ -pub struct Dot<'a, L: Language, N: Analysis> { - pub(crate) egraph: &'a EGraph, +pub struct Dot<'a, L: Language> { + pub(crate) egraph: &'a raw::EGraphResidual, /// A list of strings to be output top part of the dot file. pub config: Vec, /// Whether or not to anchor the edges in the output. @@ -59,10 +59,9 @@ pub struct Dot<'a, L: Language, N: Analysis> { pub use_anchors: bool, } -impl<'a, L, N> Dot<'a, L, N> +impl<'a, L> Dot<'a, L> where L: Language + Display, - N: Analysis, { /// Writes the `Dot` to a .dot file with the given filename. /// Does _not_ require a `dot` binary. @@ -170,16 +169,15 @@ where } } -impl<'a, L: Language, N: Analysis> Debug for Dot<'a, L, N> { +impl<'a, L: Language> Debug for Dot<'a, L> { fn fmt(&self, f: &mut Formatter) -> fmt::Result { f.debug_tuple("Dot").field(self.egraph).finish() } } -impl<'a, L, N> Display for Dot<'a, L, N> +impl<'a, L> Display for Dot<'a, L> where L: Language + Display, - N: Analysis, { fn fmt(&self, f: &mut Formatter) -> fmt::Result { writeln!(f, "digraph egraph {{")?; @@ -192,17 +190,19 @@ where writeln!(f, " {}", line)?; } + let classes = self.egraph.generate_class_nodes(); + // define all the nodes, clustered by eclass - for class in self.egraph.classes() { - writeln!(f, " subgraph cluster_{} {{", class.id)?; + for (&id, class) in &classes { + writeln!(f, " subgraph cluster_{} {{", id)?; writeln!(f, " style=dotted")?; for (i, node) in class.iter().enumerate() { - writeln!(f, " {}.{}[label = \"{}\"]", class.id, i, node)?; + writeln!(f, " {}.{}[label = \"{}\"]", id, i, node)?; } writeln!(f, " }}")?; } - for class in self.egraph.classes() { + for (&id, class) in &classes { for (i_in_class, node) in class.iter().enumerate() { let mut arg_i = 0; node.try_for_each(|child| { @@ -210,19 +210,19 @@ where let (anchor, label) = self.edge(arg_i, node.len()); let child_leader = self.egraph.find(child); - if child_leader == class.id { + if child_leader == id { writeln!( f, // {}.0 to pick an arbitrary node in the cluster " {}.{}{} -> {}.{}:n [lhead = cluster_{}, {}]", - class.id, i_in_class, anchor, class.id, i_in_class, class.id, label + id, i_in_class, anchor, id, i_in_class, id, label )?; } else { writeln!( f, // {}.0 to pick an arbitrary node in the cluster " {}.{}{} -> {}.0 [lhead = cluster_{}, {}]", - class.id, i_in_class, anchor, child, child_leader, label + id, i_in_class, anchor, child, child_leader, label )?; } arg_i += 1; diff --git a/src/eclass.rs b/src/eclass.rs index 640dea63..8136cff4 100644 --- a/src/eclass.rs +++ b/src/eclass.rs @@ -1,15 +1,13 @@ -use std::fmt::Debug; +use std::fmt::{Debug, Formatter}; use std::iter::ExactSizeIterator; use crate::*; -/// An equivalence class of enodes. +/// The additional data required to turn a [`raw::RawEClass`] into a [`EClass`] #[non_exhaustive] -#[derive(Debug, Clone)] +#[derive(Clone)] #[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))] -pub struct EClass { - /// This eclass's id. - pub id: Id, +pub struct EClassData { /// The equivalent enodes in this equivalence class. pub nodes: Vec, /// The analysis data associated with this eclass. @@ -17,10 +15,19 @@ pub struct EClass { /// Modifying this field will _not_ cause changes to propagate through the e-graph. /// Prefer [`EGraph::set_analysis_data`] instead. pub data: D, - /// The original Ids of parent enodes. - pub(crate) parents: Vec, } +impl Debug for EClassData { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let mut nodes = self.nodes.clone(); + nodes.sort(); + writeln!(f, "({:?}): {:?}", self.data, nodes) + } +} + +/// An equivalence class of enodes +pub type EClass = raw::RawEClass>; + impl EClass { /// Returns `true` if the `eclass` is empty. pub fn is_empty(&self) -> bool { @@ -36,11 +43,6 @@ impl EClass { pub fn iter(&self) -> impl ExactSizeIterator { self.nodes.iter() } - - /// Iterates over the non-canonical ids of parent enodes of this eclass. - pub fn parents(&self) -> impl ExactSizeIterator + '_ { - self.parents.iter().copied() - } } impl EClass { diff --git a/src/egraph.rs b/src/egraph.rs index 3f292460..6be8729c 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -1,12 +1,12 @@ use crate::*; -use std::{ - borrow::BorrowMut, - fmt::{self, Debug, Display}, -}; +use std::fmt::{self, Debug, Display}; +use std::ops::Deref; #[cfg(feature = "serde-1")] use serde::{Deserialize, Serialize}; +use crate::eclass::EClassData; +use crate::raw::{EGraphResidual, RawEGraph}; use log::*; /** A data structure to keep track of equalities between expressions. @@ -56,17 +56,6 @@ pub struct EGraph> { pub analysis: N, /// The `Explain` used to explain equivalences in this `EGraph`. pub(crate) explain: Option>, - unionfind: UnionFind, - /// Stores the original node represented by each non-canonical id - nodes: Vec, - /// Stores each enode's `Id`, not the `Id` of the eclass. - /// Enodes in the memo are canonicalized at each rebuild, but after rebuilding new - /// unions can cause them to become out of date. - #[cfg_attr(feature = "serde-1", serde(with = "vectorize"))] - memo: HashMap, - /// Nodes which need to be processed for rebuilding. The `Id` is the `Id` of the enode, - /// not the canonical id of the eclass. - pending: Vec, analysis_pending: UniqueQueue, #[cfg_attr( feature = "serde-1", @@ -75,7 +64,7 @@ pub struct EGraph> { deserialize = "N::Data: for<'a> Deserialize<'a>", )) )] - pub(crate) classes: HashMap>, + pub(crate) inner: RawEGraph>, #[cfg_attr(feature = "serde-1", serde(skip))] #[cfg_attr(feature = "serde-1", serde(default = "default_classes_by_op"))] pub(crate) classes_by_op: HashMap>, @@ -102,10 +91,16 @@ impl + Default> Default for EGraph { // manual debug impl to avoid L: Language bound on EGraph defn impl> Debug for EGraph { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_struct("EGraph") - .field("memo", &self.memo) - .field("classes", &self.classes) - .finish() + self.inner.fmt(f) + } +} + +impl> Deref for EGraph { + type Target = EGraphResidual; + + #[inline] + fn deref(&self) -> &Self::Target { + &self.inner } } @@ -114,13 +109,9 @@ impl> EGraph { pub fn new(analysis: N) -> Self { Self { analysis, - classes: Default::default(), - unionfind: Default::default(), - nodes: Default::default(), clean: false, explain: None, - pending: Default::default(), - memo: Default::default(), + inner: Default::default(), analysis_pending: Default::default(), classes_by_op: Default::default(), } @@ -128,25 +119,12 @@ impl> EGraph { /// Returns an iterator over the eclasses in the egraph. pub fn classes(&self) -> impl ExactSizeIterator> { - self.classes.values() + self.inner.classes() } /// Returns an mutating iterator over the eclasses in the egraph. pub fn classes_mut(&mut self) -> impl ExactSizeIterator> { - self.classes.values_mut() - } - - /// Returns `true` if the egraph is empty - /// # Example - /// ``` - /// use egg::{*, SymbolLang as S}; - /// let mut egraph = EGraph::::default(); - /// assert!(egraph.is_empty()); - /// egraph.add(S::leaf("foo")); - /// assert!(!egraph.is_empty()); - /// ``` - pub fn is_empty(&self) -> bool { - self.memo.is_empty() + self.inner.classes_mut().0 } /// Returns the number of enodes in the `EGraph`. @@ -166,7 +144,7 @@ impl> EGraph { /// assert_eq!(egraph.number_of_classes(), 1); /// ``` pub fn total_size(&self) -> usize { - self.memo.len() + self.inner.total_size() } /// Iterates over the classes, returning the total number of nodes. @@ -176,7 +154,7 @@ impl> EGraph { /// Returns the number of eclasses in the egraph. pub fn number_of_classes(&self) -> usize { - self.classes.len() + self.classes().len() } /// Enable explanations for this `EGraph`. @@ -221,7 +199,7 @@ impl> EGraph { panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get a copied egraph without unions"); } let mut egraph = Self::new(analysis); - for node in &self.nodes { + for (_, node) in self.uncanonical_nodes() { egraph.add(node.clone()); } egraph @@ -315,8 +293,8 @@ impl> EGraph { product_map: &mut HashMap<(Id, Id), Id>, ) { let res_id = Self::get_product_id(class1, class2, product_map); - for node1 in &self.classes[&class1].nodes { - for node2 in &other.classes[&class2].nodes { + for node1 in &self[class1].nodes { + for node2 in &other[class2].nodes { if node1.matches(node2) { let children1 = node1.children(); let children2 = node2.children(); @@ -338,41 +316,6 @@ impl> EGraph { } } - /// Pick a representative term for a given Id. - /// - /// Calling this function on an uncanonical `Id` returns a representative based on the how it - /// was obtained (see [`add_uncanoncial`](EGraph::add_uncanonical), - /// [`add_expr_uncanonical`](EGraph::add_expr_uncanonical)) - pub fn id_to_expr(&self, id: Id) -> RecExpr { - let mut res = Default::default(); - let mut cache = Default::default(); - self.id_to_expr_internal(&mut res, id, &mut cache); - res - } - - fn id_to_expr_internal( - &self, - res: &mut RecExpr, - node_id: Id, - cache: &mut HashMap, - ) -> Id { - if let Some(existing) = cache.get(&node_id) { - return *existing; - } - let new_node = self - .id_to_node(node_id) - .clone() - .map_children(|child| self.id_to_expr_internal(res, child, cache)); - let res_id = res.add(new_node); - cache.insert(node_id, res_id); - res_id - } - - /// Like [`id_to_expr`](EGraph::id_to_expr) but only goes one layer deep - pub fn id_to_node(&self, id: Id) -> &L { - &self.nodes[usize::from(id)] - } - /// Like [`id_to_expr`](EGraph::id_to_expr), but creates a pattern instead of a term. /// When an eclass listed in the given substitutions is found, it creates a variable. /// It also adds this variable and the corresponding Id value to the resulting [`Subst`] @@ -434,9 +377,7 @@ impl> EGraph { /// Only available when explanations are enabled. pub fn get_num_congr(&mut self) -> usize { if let Some(explain) = &mut self.explain { - explain - .with_nodes(&self.nodes) - .get_num_congr::(&self.classes, &self.unionfind) + explain.with_raw_egraph(&self.inner).get_num_congr() } else { panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.") } @@ -444,11 +385,7 @@ impl> EGraph { /// Get the number of nodes in the egraph used for explanations. pub fn get_explanation_num_nodes(&mut self) -> usize { - if let Some(explain) = &mut self.explain { - explain.with_nodes(&self.nodes).get_num_nodes() - } else { - panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.") - } + self.number_of_uncanonical_nodes() } /// When explanations are enabled, this function @@ -483,12 +420,9 @@ impl> EGraph { ); } if let Some(explain) = &mut self.explain { - explain.with_nodes(&self.nodes).explain_equivalence::( - left, - right, - &mut self.unionfind, - &self.classes, - ) + explain + .with_raw_egraph(&self.inner) + .explain_equivalence(left, right) } else { panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.") } @@ -511,7 +445,7 @@ impl> EGraph { /// but more efficient fn explain_existance_id(&mut self, id: Id) -> Explanation { if let Some(explain) = &mut self.explain { - explain.with_nodes(&self.nodes).explain_existance(id) + explain.with_raw_egraph(&self.inner).explain_existance(id) } else { panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.") } @@ -525,7 +459,7 @@ impl> EGraph { ) -> Explanation { let id = self.add_instantiation_noncanonical(pattern, subst); if let Some(explain) = &mut self.explain { - explain.with_nodes(&self.nodes).explain_existance(id) + explain.with_raw_egraph(&self.inner).explain_existance(id) } else { panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.") } @@ -548,63 +482,20 @@ impl> EGraph { ); } if let Some(explain) = &mut self.explain { - explain.with_nodes(&self.nodes).explain_equivalence::( - left, - right, - &mut self.unionfind, - &self.classes, - ) + explain + .with_raw_egraph(&self.inner) + .explain_equivalence(left, right) } else { panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations."); } } - - /// Canonicalizes an eclass id. - /// - /// This corresponds to the `find` operation on the egraph's - /// underlying unionfind data structure. - /// - /// # Example - /// ``` - /// use egg::{*, SymbolLang as S}; - /// let mut egraph = EGraph::::default(); - /// let x = egraph.add(S::leaf("x")); - /// let y = egraph.add(S::leaf("y")); - /// assert_ne!(egraph.find(x), egraph.find(y)); - /// - /// egraph.union(x, y); - /// egraph.rebuild(); - /// assert_eq!(egraph.find(x), egraph.find(y)); - /// ``` - pub fn find(&self, id: Id) -> Id { - self.unionfind.find(id) - } - - /// This is private, but internals should use this whenever - /// possible because it does path compression. - fn find_mut(&mut self, id: Id) -> Id { - self.unionfind.find_mut(id) - } - - /// Creates a [`Dot`] to visualize this egraph. See [`Dot`]. - /// - pub fn dot(&self) -> Dot { - Dot { - egraph: self, - config: vec![], - use_anchors: true, - } - } } /// Given an `Id` using the `egraph[id]` syntax, retrieve the e-class. impl> std::ops::Index for EGraph { type Output = EClass; fn index(&self, id: Id) -> &Self::Output { - let id = self.find(id); - self.classes - .get(&id) - .unwrap_or_else(|| panic!("Invalid id {}", id)) + self.inner.get_class(id) } } @@ -612,10 +503,7 @@ impl> std::ops::Index for EGraph { /// reference to the e-class. impl> std::ops::IndexMut for EGraph { fn index_mut(&mut self, id: Id) -> &mut Self::Output { - let id = self.find_mut(id); - self.classes - .get_mut(&id) - .unwrap_or_else(|| panic!("Invalid id {}", id)) + self.inner.get_class_mut(id).0 } } @@ -648,9 +536,9 @@ impl> EGraph { let mut new_node_q = Vec::with_capacity(nodes.len()); for node in nodes { let new_node = node.clone().map_children(|i| new_ids[usize::from(i)]); - let size_before = self.unionfind.size(); + let size_before = self.inner.number_of_uncanonical_nodes(); let next_id = self.add_uncanonical(new_node); - if self.unionfind.size() > size_before { + if self.inner.number_of_uncanonical_nodes() > size_before { new_node_q.push(true); } else { new_node_q.push(false); @@ -694,9 +582,9 @@ impl> EGraph { } ENodeOrVar::ENode(node) => { let new_node = node.clone().map_children(|i| new_ids[usize::from(i)]); - let size_before = self.unionfind.size(); + let size_before = self.inner.number_of_uncanonical_nodes(); let next_id = self.add_uncanonical(new_node); - if self.unionfind.size() > size_before { + if self.inner.number_of_uncanonical_nodes() > size_before { new_node_q.push(true); } else { new_node_q.push(false); @@ -716,67 +604,6 @@ impl> EGraph { *new_ids.last().unwrap() } - /// Lookup the eclass of the given enode. - /// - /// You can pass in either an owned enode or a `&mut` enode, - /// in which case the enode's children will be canonicalized. - /// - /// # Example - /// ``` - /// # use egg::*; - /// let mut egraph: EGraph = Default::default(); - /// let a = egraph.add(SymbolLang::leaf("a")); - /// let b = egraph.add(SymbolLang::leaf("b")); - /// - /// // lookup will find this node if its in the egraph - /// let mut node_f_ab = SymbolLang::new("f", vec![a, b]); - /// assert_eq!(egraph.lookup(node_f_ab.clone()), None); - /// let id = egraph.add(node_f_ab.clone()); - /// assert_eq!(egraph.lookup(node_f_ab.clone()), Some(id)); - /// - /// // if the query node isn't canonical, and its passed in by &mut instead of owned, - /// // its children will be canonicalized - /// egraph.union(a, b); - /// egraph.rebuild(); - /// assert_eq!(egraph.lookup(&mut node_f_ab), Some(id)); - /// assert_eq!(node_f_ab, SymbolLang::new("f", vec![a, a])); - /// ``` - pub fn lookup(&self, enode: B) -> Option - where - B: BorrowMut, - { - self.lookup_internal(enode).map(|id| self.find(id)) - } - - fn lookup_internal(&self, mut enode: B) -> Option - where - B: BorrowMut, - { - let enode = enode.borrow_mut(); - enode.update_children(|id| self.find(id)); - self.memo.get(enode).copied() - } - - /// Lookup the eclass of the given [`RecExpr`]. - /// - /// Equivalent to the last value in [`EGraph::lookup_expr_ids`]. - pub fn lookup_expr(&self, expr: &RecExpr) -> Option { - self.lookup_expr_ids(expr) - .and_then(|ids| ids.last().copied()) - } - - /// Lookup the eclasses of all the nodes in the given [`RecExpr`]. - pub fn lookup_expr_ids(&self, expr: &RecExpr) -> Option> { - let nodes = expr.as_ref(); - let mut new_ids = Vec::with_capacity(nodes.len()); - for node in nodes { - let node = node.clone().map_children(|i| new_ids[usize::from(i)]); - let id = self.lookup(node)?; - new_ids.push(id) - } - Some(new_ids) - } - /// Adds an enode to the [`EGraph`]. /// /// When adding an enode, to the egraph, [`add`] it performs @@ -833,64 +660,48 @@ impl> EGraph { /// assert_eq!(egraph.id_to_expr(fa), "(f a)".parse().unwrap()); /// assert_eq!(egraph.id_to_expr(fb), "(f a)".parse().unwrap()); /// ``` - pub fn add_uncanonical(&mut self, mut enode: L) -> Id { - let original = enode.clone(); - if let Some(existing_id) = self.lookup_internal(&mut enode) { - let id = self.find(existing_id); - // when explanations are enabled, we need a new representative for this expr - if let Some(explain) = self.explain.as_mut() { - if let Some(existing_explain) = explain.uncanon_memo.get(&original) { - *existing_explain + pub fn add_uncanonical(&mut self, enode: L) -> Id { + let mut added = false; + let id = RawEGraph::raw_add( + self, + |x| &mut x.inner, + enode, + |this, existing_id, enode| { + if let Some(explain) = this.explain.as_mut() { + if let Some(existing_id) = explain.uncanon_memo.get(enode) { + Some(*existing_id) + } else { + None + } } else { - let new_id = self.unionfind.make_set(); - explain.add(original.clone(), new_id, new_id); - debug_assert_eq!(Id::from(self.nodes.len()), new_id); - self.nodes.push(original); - self.unionfind.union(id, new_id); + Some(existing_id) + } + }, + |this, existing_id, new_id| { + if let Some(explain) = this.explain.as_mut() { + explain.add(this.inner.id_to_node(new_id).clone(), new_id, new_id); explain.union(existing_id, new_id, Justification::Congruence, true); - new_id } - } else { - existing_id - } - } else { - let id = self.make_new_eclass(enode, original.clone()); + }, + |this, id, _| { + added = true; + let node = this.id_to_node(id).clone(); + let data = N::make(this, &node); + EClassData { + nodes: vec![node], + data, + } + }, + ); + if added { if let Some(explain) = self.explain.as_mut() { - explain.add(original, id, id); + explain.add(self.inner.id_to_node(id).clone(), id, id); } // now that we updated explanations, run the analysis for the new eclass N::modify(self, id); self.clean = false; - id } - } - - /// This function makes a new eclass in the egraph (but doesn't touch explanations) - fn make_new_eclass(&mut self, enode: L, original: L) -> Id { - let id = self.unionfind.make_set(); - log::trace!(" ...adding to {}", id); - let class = EClass { - id, - nodes: vec![enode.clone()], - data: N::make(self, &original), - parents: Default::default(), - }; - - debug_assert_eq!(Id::from(self.nodes.len()), id); - self.nodes.push(original); - - // add this enode to the parent lists of its children - enode.for_each(|child| { - self[child].parents.push(id); - }); - - // TODO is this needed? - self.pending.push(id); - - self.classes.insert(id, class); - assert!(self.memo.insert(enode, id).is_none()); - id } @@ -936,9 +747,9 @@ impl> EGraph { rule_name: impl Into, ) -> (Id, bool) { let id1 = self.add_instantiation_noncanonical(from_pat, subst); - let size_before = self.unionfind.size(); + let size_before = self.number_of_uncanonical_nodes(); let id2 = self.add_instantiation_noncanonical(to_pat, subst); - let rhs_new = self.unionfind.size() > size_before; + let rhs_new = self.number_of_uncanonical_nodes() > size_before; let did_union = self.perform_union( id1, @@ -992,49 +803,42 @@ impl> EGraph { N::pre_union(self, enode_id1, enode_id2, &rule); self.clean = false; - let mut id1 = self.find_mut(enode_id1); - let mut id2 = self.find_mut(enode_id2); - if id1 == id2 { + if let Some((id, class2)) = self.inner.raw_union(enode_id1, enode_id2) { + self.merge(id, class2); + if let Some(explain) = &mut self.explain { + explain.union(enode_id1, enode_id2, rule.unwrap(), any_new_rhs); + } + true + } else { if let Some(Justification::Rule(_)) = rule { if let Some(explain) = &mut self.explain { - explain.alternate_rewrite(enode_id1, enode_id2, rule.unwrap()); + explain.alternate_rewrite(enode_id1, enode_id2, rule.unwrap()) } } - return false; - } - // make sure class2 has fewer parents - let class1_parents = self.classes[&id1].parents.len(); - let class2_parents = self.classes[&id2].parents.len(); - if class1_parents < class2_parents { - std::mem::swap(&mut id1, &mut id2); - } - - if let Some(explain) = &mut self.explain { - explain.union(enode_id1, enode_id2, rule.unwrap(), any_new_rhs); + false } + } - // make id1 the new root - self.unionfind.union(id1, id2); - - assert_ne!(id1, id2); - let class2 = self.classes.remove(&id2).unwrap(); - let class1 = self.classes.get_mut(&id1).unwrap(); - assert_eq!(id1, class1.id); - - self.pending.extend(class2.parents.iter().copied()); + fn merge(&mut self, id1: Id, class2: EClass) { + let class1 = self.inner.get_class_mut_with_cannon(id1).0; + let (class2, parents) = class2.destruct(); let did_merge = self.analysis.merge(&mut class1.data, class2.data); if did_merge.0 { - self.analysis_pending.extend(class1.parents.iter().copied()); + // class1.parents already contains the combined parents, + // so we only take the ones that were there before the union + self.analysis_pending.extend( + class1 + .parents() + .take(class1.parents().len() - parents.len()), + ); } if did_merge.1 { - self.analysis_pending.extend(class2.parents.iter().copied()); + self.analysis_pending.extend(parents); } concat_vecs(&mut class1.nodes, class2.nodes); - concat_vecs(&mut class1.parents, class2.parents); - N::modify(self, id1); - true + N::modify(self, id1) } /// Update the analysis data of an e-class. @@ -1043,10 +847,9 @@ impl> EGraph { /// so [`Analysis::make`] and [`Analysis::merge`] will get /// called for other parts of the e-graph on rebuild. pub fn set_analysis_data(&mut self, id: Id, new_data: N::Data) { - let id = self.find_mut(id); - let class = self.classes.get_mut(&id).unwrap(); + let class = self.inner.get_class_mut(id).0; class.data = new_data; - self.analysis_pending.extend(class.parents.iter().copied()); + self.analysis_pending.extend(class.parents()); N::modify(self, id) } @@ -1059,7 +862,7 @@ impl> EGraph { /// /// [`Debug`]: std::fmt::Debug pub fn dump(&self) -> impl Debug + '_ { - EGraphDump(self) + self.inner.dump_classes() } } @@ -1098,9 +901,9 @@ impl> EGraph { classes_by_op.values_mut().for_each(|ids| ids.clear()); let mut trimmed = 0; - let uf = &mut self.unionfind; + let (classes, uf) = self.inner.classes_mut(); - for class in self.classes.values_mut() { + for class in classes { let old_len = class.len(); class .nodes @@ -1146,8 +949,8 @@ impl> EGraph { fn check_memo(&self) -> bool { let mut test_memo = HashMap::default(); - for (&id, class) in self.classes.iter() { - assert_eq!(class.id, id); + for class in self.classes() { + let id = class.id; for node in &class.nodes { if let Some(old) = test_memo.insert(node, id) { assert_eq!( @@ -1166,7 +969,7 @@ impl> EGraph { assert_eq!(e, self.find(e)); assert_eq!( Some(e), - self.memo.get(n).map(|id| self.find(*id)), + self.lookup(n.clone()), "Entry for {:?} at {} in test_memo was incorrect", n, e @@ -1180,36 +983,32 @@ impl> EGraph { fn process_unions(&mut self) -> usize { let mut n_unions = 0; - while !self.pending.is_empty() || !self.analysis_pending.is_empty() { - while let Some(class) = self.pending.pop() { - let mut node = self.nodes[usize::from(class)].clone(); - node.update_children(|id| self.find_mut(id)); - if let Some(memo_class) = self.memo.insert(node, class) { - let did_something = self.perform_union( - memo_class, - class, - Some(Justification::Congruence), - false, - ); + while !self.inner.is_clean() || !self.analysis_pending.is_empty() { + RawEGraph::raw_rebuild( + self, + |this| &mut this.inner, + |this, id1, id2| { + let did_something = + this.perform_union(id1, id2, Some(Justification::Congruence), false); n_unions += did_something as usize; - } - } + }, + |_, _, _| {}, + ); - while let Some(class_id) = self.analysis_pending.pop() { - let node = self.nodes[usize::from(class_id)].clone(); - let class_id = self.find_mut(class_id); + while let Some(mut class_id) = self.analysis_pending.pop() { + let node = self.id_to_node(class_id).clone(); let node_data = N::make(self, &node); - let class = self.classes.get_mut(&class_id).unwrap(); + let class = self.inner.get_class_mut(&mut class_id).0; let did_merge = self.analysis.merge(&mut class.data, node_data); if did_merge.0 { - self.analysis_pending.extend(class.parents.iter().copied()); + self.analysis_pending.extend(class.parents()); N::modify(self, class_id) } } } - assert!(self.pending.is_empty()); + assert!(self.inner.is_clean()); assert!(self.analysis_pending.is_empty()); n_unions @@ -1253,7 +1052,7 @@ impl> EGraph { /// assert_eq!(egraph.find(ax), egraph.find(ay)); /// ``` pub fn rebuild(&mut self) -> usize { - let old_hc_size = self.memo.len(); + let old_hc_size = self.total_size(); let old_n_eclasses = self.number_of_classes(); let start = Instant::now(); @@ -1273,7 +1072,7 @@ impl> EGraph { elapsed.subsec_millis(), old_hc_size, old_n_eclasses, - self.memo.len(), + self.total_size(), self.number_of_classes(), n_unions, trimmed_nodes, @@ -1286,28 +1085,15 @@ impl> EGraph { pub(crate) fn check_each_explain(&mut self, rules: &[&Rewrite]) -> bool { if let Some(explain) = &mut self.explain { - explain.with_nodes(&self.nodes).check_each_explain(rules) + explain + .with_raw_egraph(&self.inner) + .check_each_explain(rules) } else { panic!("Can't check explain when explanations are off"); } } } -struct EGraphDump<'a, L: Language, N: Analysis>(&'a EGraph); - -impl<'a, L: Language, N: Analysis> Debug for EGraphDump<'a, L, N> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut ids: Vec = self.0.classes().map(|c| c.id).collect(); - ids.sort(); - for id in ids { - let mut nodes = self.0[id].nodes.clone(); - nodes.sort(); - writeln!(f, "{} ({:?}): {:?}", id, self.0[id].data, nodes)? - } - Ok(()) - } -} - #[cfg(test)] mod tests { diff --git a/src/explain.rs b/src/explain.rs index 9de2a17e..53a8ec61 100644 --- a/src/explain.rs +++ b/src/explain.rs @@ -1,7 +1,7 @@ use crate::Symbol; use crate::{ - util::pretty_print, Analysis, EClass, ENodeOrVar, FromOp, HashMap, HashSet, Id, Language, - PatternAst, RecExpr, Rewrite, UnionFind, Var, + util::pretty_print, Analysis, ENodeOrVar, FromOp, HashMap, HashSet, Id, Language, PatternAst, + RecExpr, Rewrite, UnionFind, Var, }; use saturating::Saturating; use std::cmp::Ordering; @@ -10,6 +10,7 @@ use std::fmt::{self, Debug, Display, Formatter}; use std::ops::{Deref, DerefMut}; use std::rc::Rc; +use crate::raw::RawEGraph; use symbolic_expressions::Sexp; type ProofCost = Saturating; @@ -76,9 +77,9 @@ pub struct Explain { shortest_explanation_memo: HashMap<(Id, Id), (ProofCost, Id)>, } -pub(crate) struct ExplainNodes<'a, L: Language> { +pub(crate) struct ExplainWith<'a, L: Language, X> { explain: &'a mut Explain, - nodes: &'a [L], + raw: X, } #[derive(Default)] @@ -1043,15 +1044,12 @@ impl Explain { equalities } - pub(crate) fn with_nodes<'a>(&'a mut self, nodes: &'a [L]) -> ExplainNodes<'a, L> { - ExplainNodes { - explain: self, - nodes, - } + pub(crate) fn with_raw_egraph<'a, X>(&'a mut self, raw: X) -> ExplainWith<'a, L, X> { + ExplainWith { explain: self, raw } } } -impl<'a, L: Language> Deref for ExplainNodes<'a, L> { +impl<'a, L: Language, X> Deref for ExplainWith<'a, L, X> { type Target = Explain; fn deref(&self) -> &Self::Target { @@ -1059,15 +1057,15 @@ impl<'a, L: Language> Deref for ExplainNodes<'a, L> { } } -impl<'a, L: Language> DerefMut for ExplainNodes<'a, L> { +impl<'a, L: Language, X> DerefMut for ExplainWith<'a, L, X> { fn deref_mut(&mut self) -> &mut Self::Target { &mut *self.explain } } -impl<'x, L: Language> ExplainNodes<'x, L> { +impl<'x, L: Language, D> ExplainWith<'x, L, &'x RawEGraph> { pub(crate) fn node(&self, node_id: Id) -> &L { - &self.nodes[usize::from(node_id)] + self.raw.id_to_node(node_id) } fn node_to_explanation( &self, @@ -1142,15 +1140,9 @@ impl<'x, L: Language> ExplainNodes<'x, L> { true } - pub(crate) fn explain_equivalence>( - &mut self, - left: Id, - right: Id, - unionfind: &mut UnionFind, - classes: &HashMap>, - ) -> Explanation { + pub(crate) fn explain_equivalence(&mut self, left: Id, right: Id) -> Explanation { if self.optimize_explanation_lengths { - self.calculate_shortest_explanations::(left, right, classes, unionfind); + self.calculate_shortest_explanations(left, right); } let mut cache = Default::default(); @@ -1588,12 +1580,7 @@ impl<'x, L: Language> ExplainNodes<'x, L> { distance_memo.parent_distance[usize::from(enode)].1 } - fn find_congruence_neighbors>( - &self, - classes: &HashMap>, - congruence_neighbors: &mut [Vec], - unionfind: &UnionFind, - ) { + fn find_congruence_neighbors(&self, congruence_neighbors: &mut [Vec]) { let mut counter = 0; // add the normal congruence edges first for node in &self.explainfind { @@ -1606,15 +1593,15 @@ impl<'x, L: Language> ExplainNodes<'x, L> { } } - 'outer: for eclass in classes.keys() { - let enodes = self.find_all_enodes(*eclass); + 'outer: for eclass in self.raw.classes().map(|x| x.id) { + let enodes = self.find_all_enodes(eclass); // find all congruence nodes let mut cannon_enodes: HashMap> = Default::default(); for enode in &enodes { let cannon = self .node(*enode) .clone() - .map_children(|child| unionfind.find(child)); + .map_children(|child| self.raw.find(child)); if let Some(others) = cannon_enodes.get_mut(&cannon) { for other in others.iter() { congruence_neighbors[usize::from(*enode)].push(*other); @@ -1634,13 +1621,9 @@ impl<'x, L: Language> ExplainNodes<'x, L> { } } - pub fn get_num_congr>( - &self, - classes: &HashMap>, - unionfind: &UnionFind, - ) -> usize { + pub fn get_num_congr(&self) -> usize { let mut congruence_neighbors = vec![vec![]; self.explainfind.len()]; - self.find_congruence_neighbors::(classes, &mut congruence_neighbors, unionfind); + self.find_congruence_neighbors(&mut congruence_neighbors); let mut count = 0; for v in congruence_neighbors { count += v.len(); @@ -1649,10 +1632,6 @@ impl<'x, L: Language> ExplainNodes<'x, L> { count / 2 } - pub fn get_num_nodes(&self) -> usize { - self.explainfind.len() - } - fn shortest_path_modulo_congruence( &mut self, start: Id, @@ -1851,11 +1830,7 @@ impl<'x, L: Language> ExplainNodes<'x, L> { self.explainfind[usize::from(enode)].parent_connection.next } - fn calculate_common_ancestor>( - &self, - classes: &HashMap>, - congruence_neighbors: &[Vec], - ) -> HashMap<(Id, Id), Id> { + fn calculate_common_ancestor(&self, congruence_neighbors: &[Vec]) -> HashMap<(Id, Id), Id> { let mut common_ancestor_queries = HashMap::default(); for (s_int, others) in congruence_neighbors.iter().enumerate() { let start = &Id::from(s_int); @@ -1887,8 +1862,8 @@ impl<'x, L: Language> ExplainNodes<'x, L> { unionfind.make_set(); ancestor.push(Id::from(i)); } - for (eclass, _) in classes.iter() { - let enodes = self.find_all_enodes(*eclass); + for eclass in self.raw.classes().map(|x| x.id) { + let enodes = self.find_all_enodes(eclass); let mut children: HashMap> = HashMap::default(); for enode in &enodes { children.insert(*enode, vec![]); @@ -1919,15 +1894,9 @@ impl<'x, L: Language> ExplainNodes<'x, L> { common_ancestor } - fn calculate_shortest_explanations>( - &mut self, - start: Id, - end: Id, - classes: &HashMap>, - unionfind: &UnionFind, - ) { + fn calculate_shortest_explanations(&mut self, start: Id, end: Id) { let mut congruence_neighbors = vec![vec![]; self.explainfind.len()]; - self.find_congruence_neighbors::(classes, &mut congruence_neighbors, unionfind); + self.find_congruence_neighbors(&mut congruence_neighbors); let mut parent_distance = vec![(Id::from(0), Saturating(0)); self.explainfind.len()]; for (i, entry) in parent_distance.iter_mut().enumerate() { entry.0 = Id::from(i); @@ -1935,7 +1904,7 @@ impl<'x, L: Language> ExplainNodes<'x, L> { let mut distance_memo = DistanceMemo { parent_distance, - common_ancestor: self.calculate_common_ancestor::(classes, &congruence_neighbors), + common_ancestor: self.calculate_common_ancestor(&congruence_neighbors), tree_depth: self.calculate_tree_depths(), }; diff --git a/src/lib.rs b/src/lib.rs index 5a293a58..c7e8537f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -48,6 +48,9 @@ mod lp_extract; mod machine; mod multipattern; mod pattern; + +/// Lower level egraph API +pub mod raw; mod rewrite; mod run; mod subst; diff --git a/src/raw.rs b/src/raw.rs new file mode 100644 index 00000000..22395db6 --- /dev/null +++ b/src/raw.rs @@ -0,0 +1,5 @@ +mod eclass; +mod egraph; + +pub use eclass::RawEClass; +pub use egraph::{EGraphResidual, RawEGraph}; diff --git a/src/raw/eclass.rs b/src/raw/eclass.rs new file mode 100644 index 00000000..dd6e43be --- /dev/null +++ b/src/raw/eclass.rs @@ -0,0 +1,43 @@ +use crate::Id; +use std::fmt::Debug; +use std::iter::ExactSizeIterator; +use std::ops::{Deref, DerefMut}; + +/// An equivalence class of enodes. +#[non_exhaustive] +#[derive(Debug, Clone)] +#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))] +pub struct RawEClass { + /// This eclass's id. + pub id: Id, + /// Arbitrary data associated with this eclass. + pub(super) raw_data: D, + /// The original Ids of parent enodes. + pub(super) parents: Vec, +} + +impl RawEClass { + /// Iterates over the non-canonical ids of parent enodes of this eclass. + pub fn parents(&self) -> impl ExactSizeIterator + '_ { + self.parents.iter().copied() + } + + /// Consumes `self` returning the stored data and an iterator similar to [`parents`](RawEClass::parents) + pub fn destruct(self) -> (D, impl ExactSizeIterator) { + (self.raw_data, self.parents.into_iter()) + } +} + +impl Deref for RawEClass { + type Target = D; + + fn deref(&self) -> &D { + &self.raw_data + } +} + +impl DerefMut for RawEClass { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.raw_data + } +} diff --git a/src/raw/egraph.rs b/src/raw/egraph.rs new file mode 100644 index 00000000..55ab1368 --- /dev/null +++ b/src/raw/egraph.rs @@ -0,0 +1,656 @@ +use crate::{raw::RawEClass, Dot, HashMap, Id, Language, RecExpr, UnionFind}; +use std::ops::{Deref, DerefMut}; +use std::{ + borrow::BorrowMut, + fmt::{self, Debug}, +}; + +#[cfg(feature = "serde-1")] +use serde::{Deserialize, Serialize}; + +/// A [`RawEGraph`] without its classes that can be obtained by dereferencing a [`RawEGraph`]. +/// +/// It exists as a separate type so that it can still be used while mutably borrowing a [`RawEClass`] +/// +/// See [`RawEGraph::classes_mut`], [`RawEGraph::get_class_mut`] +#[derive(Clone)] +#[cfg_attr(feature = "serde-1", derive(Serialize, Deserialize))] +pub struct EGraphResidual { + unionfind: UnionFind, + /// Stores the original node represented by each non-canonical id + nodes: Vec, + /// Stores each enode's `Id`, not the `Id` of the eclass. + /// Enodes in the memo are canonicalized at each rebuild, but after rebuilding new + /// unions can cause them to become out of date. + #[cfg_attr(feature = "serde-1", serde(with = "vectorize"))] + memo: HashMap, +} + +impl EGraphResidual { + /// Pick a representative term for a given Id. + /// + /// Calling this function on an uncanonical `Id` returns a representative based on how it + /// was obtained + pub fn id_to_expr(&self, id: Id) -> RecExpr { + let mut res = Default::default(); + let mut cache = Default::default(); + self.id_to_expr_internal(&mut res, id, &mut cache); + res + } + + fn id_to_expr_internal( + &self, + res: &mut RecExpr, + node_id: Id, + cache: &mut HashMap, + ) -> Id { + if let Some(existing) = cache.get(&node_id) { + return *existing; + } + let new_node = self + .id_to_node(node_id) + .clone() + .map_children(|child| self.id_to_expr_internal(res, child, cache)); + let res_id = res.add(new_node); + cache.insert(node_id, res_id); + res_id + } + + /// Like [`id_to_expr`](EGraph::id_to_expr) but only goes one layer deep + pub fn id_to_node(&self, id: Id) -> &L { + &self.nodes[usize::from(id)] + } + + /// Canonicalizes an eclass id. + /// + /// This corresponds to the `find` operation on the egraph's + /// underlying unionfind data structure. + /// + /// # Example + /// ``` + /// use egg::{raw::*, SymbolLang as S}; + /// let mut egraph = RawEGraph::::default(); + /// let x = egraph.add_uncanonical(S::leaf("x")); + /// let y = egraph.add_uncanonical(S::leaf("y")); + /// assert_ne!(egraph.find(x), egraph.find(y)); + /// + /// egraph.union(x, y); + /// egraph.rebuild(); + /// assert_eq!(egraph.find(x), egraph.find(y)); + /// ``` + pub fn find(&self, id: Id) -> Id { + self.unionfind.find(id) + } + + /// Same as [`find`](EGraphResidual::find) but requires mutable access since it does path compression + pub fn find_mut(&mut self, id: Id) -> Id { + self.unionfind.find_mut(id) + } + + /// Returns `true` if the egraph is empty + /// # Example + /// ``` + /// use egg::{raw::*, SymbolLang as S}; + /// let mut egraph = RawEGraph::::default(); + /// assert!(egraph.is_empty()); + /// egraph.add_uncanonical(S::leaf("foo")); + /// assert!(!egraph.is_empty()); + /// ``` + pub fn is_empty(&self) -> bool { + self.nodes.is_empty() + } + + /// Returns the number of uncanonical enodes in the `EGraph`. + /// + /// # Example + /// ``` + /// use egg::{raw::*, SymbolLang as S}; + /// let mut egraph = RawEGraph::::default(); + /// let x = egraph.add_uncanonical(S::leaf("x")); + /// let y = egraph.add_uncanonical(S::leaf("y")); + /// let fx = egraph.add_uncanonical(S::new("f", vec![x])); + /// let fy = egraph.add_uncanonical(S::new("f", vec![y])); + /// // only one eclass + /// egraph.union(x, y); + /// egraph.rebuild(); + /// + /// assert_eq!(egraph.number_of_uncanonical_nodes(), 4); + /// assert_eq!(egraph.number_of_classes(), 2); + /// ``` + pub fn number_of_uncanonical_nodes(&self) -> usize { + self.nodes.len() + } + + /// Returns an iterator over the uncanonical ids in the egraph and the node + /// that would be obtained by calling [`id_to_node`](EGraphResidual::id_to_node) on each of them + pub fn uncanonical_nodes(&self) -> impl ExactSizeIterator { + self.nodes + .iter() + .enumerate() + .map(|(id, node)| (Id::from(id), node)) + } + + /// Returns the number of enodes in the `EGraph`. + /// + /// Actually returns the size of the hashcons index. + /// # Example + /// ``` + /// use egg::{*, SymbolLang as S}; + /// let mut egraph = EGraph::::default(); + /// let x = egraph.add(S::leaf("x")); + /// let y = egraph.add(S::leaf("y")); + /// // only one eclass + /// egraph.union(x, y); + /// egraph.rebuild(); + /// + /// assert_eq!(egraph.total_size(), 2); + /// assert_eq!(egraph.number_of_classes(), 1); + /// ``` + pub fn total_size(&self) -> usize { + self.memo.len() + } + + /// Lookup the eclass of the given enode. + /// + /// You can pass in either an owned enode or a `&mut` enode, + /// in which case the enode's children will be canonicalized. + /// + /// # Example + /// ``` + /// # use egg::*; + /// let mut egraph: EGraph = Default::default(); + /// let a = egraph.add(SymbolLang::leaf("a")); + /// let b = egraph.add(SymbolLang::leaf("b")); + /// + /// // lookup will find this node if its in the egraph + /// let mut node_f_ab = SymbolLang::new("f", vec![a, b]); + /// assert_eq!(egraph.lookup(node_f_ab.clone()), None); + /// let id = egraph.add(node_f_ab.clone()); + /// assert_eq!(egraph.lookup(node_f_ab.clone()), Some(id)); + /// + /// // if the query node isn't canonical, and its passed in by &mut instead of owned, + /// // its children will be canonicalized + /// egraph.union(a, b); + /// egraph.rebuild(); + /// assert_eq!(egraph.lookup(&mut node_f_ab), Some(id)); + /// assert_eq!(node_f_ab, SymbolLang::new("f", vec![a, a])); + /// ``` + pub fn lookup(&self, enode: B) -> Option + where + B: BorrowMut, + { + self.lookup_internal(enode).map(|id| self.find(id)) + } + + #[inline] + fn lookup_internal(&self, mut enode: B) -> Option + where + B: BorrowMut, + { + let enode = enode.borrow_mut(); + enode.update_children(|id| self.find(id)); + self.memo.get(enode).copied() + } + + /// Lookup the eclass of the given [`RecExpr`]. + /// + /// Equivalent to the last value in [`EGraphResidual::lookup_expr_ids`]. + pub fn lookup_expr(&self, expr: &RecExpr) -> Option { + self.lookup_expr_ids(expr) + .and_then(|ids| ids.last().copied()) + } + + /// Lookup the eclasses of all the nodes in the given [`RecExpr`]. + pub fn lookup_expr_ids(&self, expr: &RecExpr) -> Option> { + let nodes = expr.as_ref(); + let mut new_ids = Vec::with_capacity(nodes.len()); + for node in nodes { + let node = node.clone().map_children(|i| new_ids[usize::from(i)]); + let id = self.lookup(node)?; + new_ids.push(id) + } + Some(new_ids) + } + + /// Generate a mapping from canonical ids to the list of nodes they represent + pub fn generate_class_nodes(&self) -> HashMap> { + let mut classes = HashMap::default(); + let find = |id| self.find(id); + for (id, node) in self.uncanonical_nodes() { + let id = find(id); + let node = node.clone().map_children(find); + match classes.get_mut(&id) { + None => { + classes.insert(id, vec![node]); + } + Some(x) => x.push(node), + } + } + + // define all the nodes, clustered by eclass + for class in classes.values_mut() { + class.sort_unstable(); + class.dedup(); + } + classes + } + + /// Returns a more debug-able representation of the egraph focusing on its uncanonical ids and nodes. + /// + /// [`EGraph`]s implement [`Debug`], but it ain't pretty. It + /// prints a lot of stuff you probably don't care about. + /// This method returns a wrapper that implements [`Debug`] in a + /// slightly nicer way, just dumping enodes in each eclass. + /// + /// [`Debug`]: std::fmt::Debug + pub fn dump_uncanonical(&self) -> impl Debug + '_ { + EGraphUncanonicalDump(self) + } + + /// Creates a [`Dot`] to visualize this egraph. See [`Dot`]. + pub fn dot(&self) -> Dot<'_, L> { + Dot { + egraph: self, + config: vec![], + use_anchors: true, + } + } +} + +// manual debug impl to avoid L: Language bound on EGraph defn +impl Debug for EGraphResidual { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("EGraphResidual") + .field("unionfind", &self.unionfind) + .field("nodes", &self.nodes) + .field("memo", &self.memo) + .finish() + } +} + +/** A data structure to keep track of equalities between expressions. + +Check out the [background tutorial](crate::tutorials::_01_background) +for more information on e-graphs in general. + +# E-graphs in `egg::raw` + +In `egg::raw`, the main types associated with e-graphs are +[`RawEGraph`], [`RawEClass`], [`Language`], and [`Id`]. + +[`RawEGraph`] and [`RawEClass`] are all generic over a +[`Language`], meaning that types actually floating around in the +egraph are all user-defined. +In particular, the e-nodes are elements of your [`Language`]. +[`RawEGraph`]s and [`RawEClass`]es are additionally parameterized by some +abritrary data associated with each e-class. + +Many methods of [`RawEGraph`] deal with [`Id`]s, which represent e-classes. +Because eclasses are frequently merged, many [`Id`]s will refer to the +same e-class. + +[`RawEGraph`] provides a low level API for dealing with egraphs, in particular with handling the data +stored in each [`RawEClass`] so user will likely want to implemented wrappers around +[`raw_add`](RawEGraph::raw_add), [`raw_union`](RawEGraph::raw_union), and [`raw_rebuild`](RawEGraph::raw_rebuild) +to properly handle this data + **/ +#[derive(Clone)] +#[cfg_attr(feature = "serde-1", derive(Serialize, Deserialize))] +pub struct RawEGraph { + #[cfg_attr(feature = "serde-1", serde(flatten))] + residual: EGraphResidual, + /// Nodes which need to be processed for rebuilding. The `Id` is the `Id` of the enode, + /// not the canonical id of the eclass. + pending: Vec, + classes: HashMap>, +} + +impl Default for RawEGraph { + fn default() -> Self { + let residual = EGraphResidual { + unionfind: Default::default(), + nodes: Default::default(), + memo: Default::default(), + }; + RawEGraph { + residual, + pending: Default::default(), + classes: Default::default(), + } + } +} + +impl Deref for RawEGraph { + type Target = EGraphResidual; + + #[inline] + fn deref(&self) -> &Self::Target { + &self.residual + } +} + +impl DerefMut for RawEGraph { + #[inline] + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.residual + } +} + +// manual debug impl to avoid L: Language bound on EGraph defn +impl Debug for RawEGraph { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("EGraph") + .field("memo", &self.residual.memo) + .field("classes", &self.classes) + .finish() + } +} + +impl RawEGraph { + /// Returns an iterator over the eclasses in the egraph. + pub fn classes(&self) -> impl ExactSizeIterator> { + self.classes.iter().map(|(id, class)| { + debug_assert_eq!(*id, class.id); + class + }) + } + + /// Returns a mutating iterator over the eclasses in the egraph. + /// Also returns the [`EGraphResidual`] so it can still be used while `self` is borrowed + pub fn classes_mut( + &mut self, + ) -> ( + impl ExactSizeIterator>, + &mut EGraphResidual, + ) { + let iter = self.classes.iter_mut().map(|(id, class)| { + debug_assert_eq!(*id, class.id); + class + }); + (iter, &mut self.residual) + } + + /// Returns the number of eclasses in the egraph. + pub fn number_of_classes(&self) -> usize { + self.classes().len() + } + + /// Returns the eclass corresponding to `id` + pub fn get_class>(&self, mut id: I) -> &RawEClass { + let id = id.borrow_mut(); + *id = self.find(*id); + self.get_class_with_cannon(*id) + } + + /// Like [`get_class`](RawEGraph::get_class) but panics if `id` is not canonical + pub fn get_class_with_cannon(&self, id: Id) -> &RawEClass { + self.classes + .get(&id) + .unwrap_or_else(|| panic!("Invalid id {}", id)) + } + + /// Returns the eclass corresponding to `id` + /// Also returns the [`EGraphResidual`] so it can still be used while `self` is borrowed + pub fn get_class_mut>( + &mut self, + mut id: I, + ) -> (&mut RawEClass, &mut EGraphResidual) { + let id = id.borrow_mut(); + *id = self.find_mut(*id); + self.get_class_mut_with_cannon(*id) + } + + /// Like [`get_class_mut`](RawEGraph::get_class_mut) but panics if `id` is not canonical + pub fn get_class_mut_with_cannon( + &mut self, + id: Id, + ) -> (&mut RawEClass, &mut EGraphResidual) { + ( + self.classes + .get_mut(&id) + .unwrap_or_else(|| panic!("Invalid id {}", id)), + &mut self.residual, + ) + } +} + +impl RawEGraph { + /// Adds `enode` to a [`RawEGraph`] contained within a wrapper type `T` + /// + /// ## Parameters + /// + /// ### `get_self` + /// Called to extract the [`RawEGraph`] from the wrapper type, and should not perform any mutation. + /// + /// This will likely be a simple field access or just the identity function if there is no wrapper type. + /// + /// ### `handle_equiv` + /// When there already exists a node that is congruently equivalent to `enode` in the egraph + /// this function is called with the uncanonical id of a equivalent node, and a reference to `enode` + /// + /// Returning `Some(id)` will cause `raw_add` to immediately return `id` + /// (in this case `id` should represent an enode that is equivalent to the one being inserted). + /// + /// Returning `None` will cause `raw_add` to create a new id for `enode`, union it to the equivalent node, + /// and then return it. + /// + /// ### `handle_union` + /// Called after `handle_equiv` returns `None` with the uncanonical id of the equivalent node + /// and the new `id` assigned to `enode` + /// + /// Calling [`id_to_node`](EGraphResidual::id_to_node) on the new `id` will return a reference to `enode` + /// + /// ### `mk_data` + /// When there does not already exist a node is congruently equivalent to `enode` in the egraph + /// this function is called with the new `id` assigned to `enode` and a reference to the canonicalized version of + /// `enode` to create to data that will be stored in the [`RawEClass`] associated with it + /// + /// Calling [`id_to_node`](EGraphResidual::id_to_node) on the new `id` will return a reference to `enode` + /// + /// Calling [`get_class`](RawEGraph::get_class) on the new `id` will cause a panic since the [`RawEClass`] is + /// still being built + #[inline] + pub fn raw_add( + outer: &mut T, + get_self: impl Fn(&mut T) -> &mut Self, + mut enode: L, + handle_equiv: impl FnOnce(&mut T, Id, &L) -> Option, + handle_union: impl FnOnce(&mut T, Id, Id), + mk_data: impl FnOnce(&mut T, Id, &L) -> D, + ) -> Id { + let this = get_self(outer); + let original = enode.clone(); + if let Some(existing_id) = this.lookup_internal(&mut enode) { + let canon_id = this.find(existing_id); + // when explanations are enabled, we need a new representative for this expr + if let Some(existing_id) = handle_equiv(outer, existing_id, &original) { + existing_id + } else { + let this = get_self(outer); + let new_id = this.residual.unionfind.make_set(); + debug_assert_eq!(Id::from(this.nodes.len()), new_id); + this.residual.nodes.push(original); + this.residual.unionfind.union(canon_id, new_id); + handle_union(outer, existing_id, new_id); + new_id + } + } else { + let id = this.residual.unionfind.make_set(); + debug_assert_eq!(Id::from(this.nodes.len()), id); + this.residual.nodes.push(original); + + log::trace!(" ...adding to {}", id); + let class = RawEClass { + id, + raw_data: mk_data(outer, id, &enode), + parents: Default::default(), + }; + let this = get_self(outer); + + // add this enode to the parent lists of its children + enode.for_each(|child| { + this.get_class_mut(child).0.parents.push(id); + }); + + // TODO is this needed? + this.pending.push(id); + + this.classes.insert(id, class); + assert!(this.residual.memo.insert(enode, id).is_none()); + + id + } + } + + /// Unions two eclasses given their ids. + /// + /// The given ids need not be canonical. + /// + /// Returns `None` if the two ids were already equivalent. + /// + /// Returns `Some((id, class))` if two classes were merged where `id` is the id of the newly merged class + /// and `class` is the old `RawEClass` that merged into `id` + #[inline] + pub fn raw_union(&mut self, enode_id1: Id, enode_id2: Id) -> Option<(Id, RawEClass)> { + let mut id1 = self.find_mut(enode_id1); + let mut id2 = self.find_mut(enode_id2); + if id1 == id2 { + return None; + } + // make sure class2 has fewer parents + let class1_parents = self.classes[&id1].parents.len(); + let class2_parents = self.classes[&id2].parents.len(); + if class1_parents < class2_parents { + std::mem::swap(&mut id1, &mut id2); + } + + // make id1 the new root + self.residual.unionfind.union(id1, id2); + + assert_ne!(id1, id2); + let class2 = self.classes.remove(&id2).unwrap(); + let class1 = self.classes.get_mut(&id1).unwrap(); + assert_eq!(id1, class1.id); + + self.pending.extend(class2.parents()); + + class1.parents.extend(class2.parents()); + Some((id1, class2)) + } + + #[inline] + /// Rebuild to [`RawEGraph`] to restore congruence closure + /// + /// ## Parameters + /// + /// ### `get_self` + /// Called to extract the [`RawEGraph`] from the wrapper type, and should not perform any mutation. + /// + /// This will likely be a simple field access or just the identity function if there is no wrapper type. + /// + /// ### `perform_union` + /// Called on each pair of ids that needs to be unioned + /// + /// In order to be correct `perform_union` should call [`raw_union`](RawEGraph::raw_union) + /// + /// ### `handle_pending` + /// Called with the uncanonical id of each enode whose canonical children have changned, along with a canonical + /// version of it + pub fn raw_rebuild( + outer: &mut T, + get_self: impl Fn(&mut T) -> &mut Self, + mut perform_union: impl FnMut(&mut T, Id, Id), + mut handle_pending: impl FnMut(&mut T, Id, &L), + ) { + loop { + let this = get_self(outer); + if let Some(class) = this.pending.pop() { + let mut node = this.id_to_node(class).clone(); + node.update_children(|id| this.find_mut(id)); + handle_pending(outer, class, &node); + if let Some(memo_class) = get_self(outer).residual.memo.insert(node, class) { + perform_union(outer, memo_class, class); + } + } else { + break; + } + } + } + + /// Returns whether `self` is congruently closed + /// + /// This will always be true after calling [`raw_rebuild`](RawEGraph::raw_rebuild) + pub fn is_clean(&self) -> bool { + self.pending.is_empty() + } + + /// Returns a more debug-able representation of the egraph focusing on its classes. + /// + /// [`EGraph`]s implement [`Debug`], but it ain't pretty. It + /// prints a lot of stuff you probably don't care about. + /// This method returns a wrapper that implements [`Debug`] in a + /// slightly nicer way, just dumping enodes in each eclass. + /// + /// [`Debug`]: std::fmt::Debug + pub fn dump_classes(&self) -> impl Debug + '_ + where + D: Debug, + { + EGraphDump(self) + } +} + +impl RawEGraph { + /// Simplified version of [`raw_add`](RawEGraph::raw_add) for egraphs without eclass data + pub fn add_uncanonical(&mut self, enode: L) -> Id { + Self::raw_add( + self, + |x| x, + enode, + |_, id, _| Some(id), + |_, _, _| {}, + |_, _, _| (), + ) + } + + /// Simplified version of [`raw_union`](RawEGraph::raw_union) for egraphs without eclass data + pub fn union(&mut self, id1: Id, id2: Id) -> bool { + Self::raw_union(self, id1, id2).is_some() + } + + /// Simplified version of [`raw_rebuild`](RawEGraph::raw_rebuild) for egraphs without eclass data + pub fn rebuild(&mut self) { + Self::raw_rebuild( + self, + |x| x, + |this, id1, id2| { + this.union(id1, id2); + }, + |_, _, _| {}, + ); + } +} + +struct EGraphUncanonicalDump<'a, L: Language>(&'a EGraphResidual); + +impl<'a, L: Language> Debug for EGraphUncanonicalDump<'a, L> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + for (id, node) in self.0.uncanonical_nodes() { + writeln!(f, "{}: {:?} (root={})", id, node, self.0.find(id))? + } + Ok(()) + } +} + +struct EGraphDump<'a, L: Language, D>(&'a RawEGraph); + +impl<'a, L: Language, D: Debug> Debug for EGraphDump<'a, L, D> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut ids: Vec = self.0.classes().map(|c| c.id).collect(); + ids.sort(); + for id in ids { + writeln!(f, "{} {:?}", id, self.0.get_class(id).raw_data)? + } + Ok(()) + } +} From 8370122f9efcf1e8f0ba7a3ce0d159c343a3654b Mon Sep 17 00:00:00 2001 From: dewert99 Date: Wed, 7 Feb 2024 17:13:18 -0800 Subject: [PATCH 07/20] doc-link fixes --- src/dot.rs | 10 +++++----- src/egraph.rs | 22 +++++++++++----------- src/lib.rs | 2 +- src/raw/egraph.rs | 6 +++--- 4 files changed, 20 insertions(+), 20 deletions(-) diff --git a/src/dot.rs b/src/dot.rs index 111fac51..b68028ce 100644 --- a/src/dot.rs +++ b/src/dot.rs @@ -1,7 +1,7 @@ /*! EGraph visualization with [GraphViz] -Use the [`Dot`] struct to visualize an [`EGraph`] +Use the [`Dot`] struct to visualize an [`EGraph`](crate::EGraph) [GraphViz]: https://graphviz.gitlab.io/ !*/ @@ -11,13 +11,13 @@ use std::fmt::{self, Debug, Display, Formatter}; use std::io::{Error, ErrorKind, Result, Write}; use std::path::Path; -use crate::{raw, Language}; +use crate::{raw::EGraphResidual, Language}; /** -A wrapper for an [`EGraph`] that can output [GraphViz] for +A wrapper for an [`EGraphResidual`] that can output [GraphViz] for visualization. -The [`EGraph::dot`](EGraph::dot()) method creates `Dot`s. +The [`EGraphResidual::dot`] method creates `Dot`s. # Example @@ -51,7 +51,7 @@ instead of to its own eclass. [GraphViz]: https://graphviz.gitlab.io/ **/ pub struct Dot<'a, L: Language> { - pub(crate) egraph: &'a raw::EGraphResidual, + pub(crate) egraph: &'a EGraphResidual, /// A list of strings to be output top part of the dot file. pub config: Vec, /// Whether or not to anchor the edges in the output. diff --git a/src/egraph.rs b/src/egraph.rs index 6be8729c..5544fb99 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -316,10 +316,10 @@ impl> EGraph { } } - /// Like [`id_to_expr`](EGraph::id_to_expr), but creates a pattern instead of a term. + /// Like [`id_to_expr`](EGraphResidual::id_to_expr), but creates a pattern instead of a term. /// When an eclass listed in the given substitutions is found, it creates a variable. /// It also adds this variable and the corresponding Id value to the resulting [`Subst`] - /// Otherwise it behaves like [`id_to_expr`](EGraph::id_to_expr). + /// Otherwise it behaves like [`id_to_expr`](EGraphResidual::id_to_expr). pub fn id_to_pattern(&self, id: Id, substitutions: &HashMap) -> (Pattern, Subst) { let mut res = Default::default(); let mut subst = Default::default(); @@ -405,10 +405,10 @@ impl> EGraph { self.explain_id_equivalence(left, right) } - /// Equivalent to calling [`explain_equivalence`](EGraph::explain_equivalence)`(`[`id_to_expr`](EGraph::id_to_expr)`(left),` - /// [`id_to_expr`](EGraph::id_to_expr)`(right))` but more efficient + /// Equivalent to calling [`explain_equivalence`](EGraph::explain_equivalence)`(`[`id_to_expr`](EGraphResidual::id_to_expr)`(left),` + /// [`id_to_expr`](EGraphResidual::id_to_expr)`(right))` but more efficient /// - /// This function picks representatives using [`id_to_expr`](EGraph::id_to_expr) so choosing + /// This function picks representatives using [`id_to_expr`](EGraphResidual::id_to_expr) so choosing /// `Id`s returned by functions like [`add_uncanonical`](EGraph::add_uncanonical) is important /// to control explanations pub fn explain_id_equivalence(&mut self, left: Id, right: Id) -> Explanation { @@ -441,7 +441,7 @@ impl> EGraph { self.explain_existance_id(id) } - /// Equivalent to calling [`explain_existance`](EGraph::explain_existance)`(`[`id_to_expr`](EGraph::id_to_expr)`(id))` + /// Equivalent to calling [`explain_existance`](EGraph::explain_existance)`(`[`id_to_expr`](EGraphResidual::id_to_expr)`(id))` /// but more efficient fn explain_existance_id(&mut self, id: Id) -> Explanation { if let Some(explain) = &mut self.explain { @@ -529,7 +529,7 @@ impl> EGraph { /// Similar to [`add_expr`](EGraph::add_expr) but the `Id` returned may not be canonical /// - /// Calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` return a copy of `expr` when explanations are enabled + /// Calling [`id_to_expr`](EGraphResidual::id_to_expr) on this `Id` return a copy of `expr` when explanations are enabled pub fn add_expr_uncanonical(&mut self, expr: &RecExpr) -> Id { let nodes = expr.as_ref(); let mut new_ids = Vec::with_capacity(nodes.len()); @@ -567,7 +567,7 @@ impl> EGraph { /// canonical /// /// Like [`add_uncanonical`](EGraph::add_uncanonical), when explanations are enabled calling - /// Calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` return an correspond to the + /// Calling [`id_to_expr`](EGraphResidual::id_to_expr) on this `Id` return an correspond to the /// instantiation of the pattern fn add_instantiation_noncanonical(&mut self, pat: &PatternAst, subst: &Subst) -> Id { let nodes = pat.as_ref(); @@ -623,7 +623,7 @@ impl> EGraph { /// Similar to [`add`](EGraph::add) but the `Id` returned may not be canonical /// - /// When explanations are enabled calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` will + /// When explanations are enabled calling [`id_to_expr`](EGraphResidual::id_to_expr) on this `Id` will /// correspond to the parameter `enode` /// /// ## Example @@ -642,7 +642,7 @@ impl> EGraph { /// assert_eq!(egraph.id_to_expr(fb), "(f b)".parse().unwrap()); /// ``` /// - /// When explanations are not enabled calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` will + /// When explanations are not enabled calling [`id_to_expr`](EGraphResidual::id_to_expr) on this `Id` will /// produce an expression with equivalent but not necessarily identical children /// /// # Example @@ -762,7 +762,7 @@ impl> EGraph { /// Unions two e-classes, using a given reason to justify it. /// - /// This function picks representatives using [`id_to_expr`](EGraph::id_to_expr) so choosing + /// This function picks representatives using [`id_to_expr`](EGraphResidual::id_to_expr) so choosing /// `Id`s returned by functions like [`add_uncanonical`](EGraph::add_uncanonical) is important /// to control explanations pub fn union_trusted(&mut self, from: Id, to: Id, reason: impl Into) -> bool { diff --git a/src/lib.rs b/src/lib.rs index c7e8537f..298b3518 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -92,7 +92,7 @@ pub(crate) use {explain::Explain, unionfind::UnionFind}; pub use { dot::Dot, - eclass::EClass, + eclass::{EClass, EClassData}, egraph::EGraph, explain::{ Explanation, FlatExplanation, FlatTerm, Justification, TreeExplanation, TreeTerm, diff --git a/src/raw/egraph.rs b/src/raw/egraph.rs index 55ab1368..510c12b5 100644 --- a/src/raw/egraph.rs +++ b/src/raw/egraph.rs @@ -56,7 +56,7 @@ impl EGraphResidual { res_id } - /// Like [`id_to_expr`](EGraph::id_to_expr) but only goes one layer deep + /// Like [`id_to_expr`](EGraphResidual::id_to_expr) but only goes one layer deep pub fn id_to_node(&self, id: Id) -> &L { &self.nodes[usize::from(id)] } @@ -237,7 +237,7 @@ impl EGraphResidual { /// Returns a more debug-able representation of the egraph focusing on its uncanonical ids and nodes. /// - /// [`EGraph`]s implement [`Debug`], but it ain't pretty. It + /// [`RawEGraph`]s implement [`Debug`], but it's not pretty. It /// prints a lot of stuff you probably don't care about. /// This method returns a wrapper that implements [`Debug`] in a /// slightly nicer way, just dumping enodes in each eclass. @@ -586,7 +586,7 @@ impl RawEGraph { /// Returns a more debug-able representation of the egraph focusing on its classes. /// - /// [`EGraph`]s implement [`Debug`], but it ain't pretty. It + /// [`RawEGraph`]s implement [`Debug`], but it's not pretty. It /// prints a lot of stuff you probably don't care about. /// This method returns a wrapper that implements [`Debug`] in a /// slightly nicer way, just dumping enodes in each eclass. From c18f6d4a44714e684b7291af7ff8eebe88370914 Mon Sep 17 00:00:00 2001 From: dewert99 Date: Sat, 10 Feb 2024 16:57:14 -0800 Subject: [PATCH 08/20] Improved raw_union interface, fixed EGraph::dump and updated edition --- Cargo.toml | 2 +- src/eclass.rs | 2 +- src/egraph.rs | 42 ++++++++++++++++++------------------------ src/raw/egraph.rs | 46 ++++++++++++++++++++++++++++++++++++---------- 4 files changed, 56 insertions(+), 36 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index e2b3af6b..9decbb29 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ authors = ["Max Willsey "] categories = ["data-structures"] description = "An implementation of egraphs" -edition = "2018" +edition = "2021" keywords = ["e-graphs"] license = "MIT" name = "egg" diff --git a/src/eclass.rs b/src/eclass.rs index 8136cff4..e235d58e 100644 --- a/src/eclass.rs +++ b/src/eclass.rs @@ -21,7 +21,7 @@ impl Debug for EClassData { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { let mut nodes = self.nodes.clone(); nodes.sort(); - writeln!(f, "({:?}): {:?}", self.data, nodes) + write!(f, "({:?}): {:?}", self.data, nodes) } } diff --git a/src/egraph.rs b/src/egraph.rs index 5544fb99..0eafb9d4 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -803,11 +803,27 @@ impl> EGraph { N::pre_union(self, enode_id1, enode_id2, &rule); self.clean = false; - if let Some((id, class2)) = self.inner.raw_union(enode_id1, enode_id2) { - self.merge(id, class2); + let mut new_root = None; + self.inner + .raw_union(enode_id1, enode_id2, |class1, id1, p1, class2, _, p2| { + new_root = Some(id1); + + let did_merge = self.analysis.merge(&mut class1.data, class2.data); + if did_merge.0 { + self.analysis_pending.extend(p1); + } + if did_merge.1 { + self.analysis_pending.extend(p2); + } + + concat_vecs(&mut class1.nodes, class2.nodes); + }); + if let Some(id) = new_root { if let Some(explain) = &mut self.explain { explain.union(enode_id1, enode_id2, rule.unwrap(), any_new_rhs); } + N::modify(self, id); + true } else { if let Some(Justification::Rule(_)) = rule { @@ -819,28 +835,6 @@ impl> EGraph { } } - fn merge(&mut self, id1: Id, class2: EClass) { - let class1 = self.inner.get_class_mut_with_cannon(id1).0; - let (class2, parents) = class2.destruct(); - let did_merge = self.analysis.merge(&mut class1.data, class2.data); - if did_merge.0 { - // class1.parents already contains the combined parents, - // so we only take the ones that were there before the union - self.analysis_pending.extend( - class1 - .parents() - .take(class1.parents().len() - parents.len()), - ); - } - if did_merge.1 { - self.analysis_pending.extend(parents); - } - - concat_vecs(&mut class1.nodes, class2.nodes); - - N::modify(self, id1) - } - /// Update the analysis data of an e-class. /// /// This also propagates the changes through the e-graph, diff --git a/src/raw/egraph.rs b/src/raw/egraph.rs index 510c12b5..d9e4592e 100644 --- a/src/raw/egraph.rs +++ b/src/raw/egraph.rs @@ -3,11 +3,23 @@ use std::ops::{Deref, DerefMut}; use std::{ borrow::BorrowMut, fmt::{self, Debug}, + iter, slice, }; #[cfg(feature = "serde-1")] use serde::{Deserialize, Serialize}; +pub struct Parents<'a>(&'a [Id]); + +impl<'a> IntoIterator for Parents<'a> { + type Item = Id; + type IntoIter = iter::Copied>; + + fn into_iter(self) -> Self::IntoIter { + self.0.iter().copied() + } +} + /// A [`RawEGraph`] without its classes that can be obtained by dereferencing a [`RawEGraph`]. /// /// It exists as a separate type so that it can still be used while mutably borrowing a [`RawEClass`] @@ -506,16 +518,18 @@ impl RawEGraph { /// /// The given ids need not be canonical. /// - /// Returns `None` if the two ids were already equivalent. - /// - /// Returns `Some((id, class))` if two classes were merged where `id` is the id of the newly merged class - /// and `class` is the old `RawEClass` that merged into `id` + /// If a union occurs, `merge` is called with the data, id, and parents of the two eclasses being merged #[inline] - pub fn raw_union(&mut self, enode_id1: Id, enode_id2: Id) -> Option<(Id, RawEClass)> { + pub fn raw_union( + &mut self, + enode_id1: Id, + enode_id2: Id, + merge: impl FnOnce(&mut D, Id, Parents<'_>, D, Id, Parents<'_>), + ) { let mut id1 = self.find_mut(enode_id1); let mut id2 = self.find_mut(enode_id2); if id1 == id2 { - return None; + return; } // make sure class2 has fewer parents let class1_parents = self.classes[&id1].parents.len(); @@ -531,11 +545,19 @@ impl RawEGraph { let class2 = self.classes.remove(&id2).unwrap(); let class1 = self.classes.get_mut(&id1).unwrap(); assert_eq!(id1, class1.id); + let (p1, p2) = (Parents(&class1.parents), Parents(&class2.parents)); + merge( + &mut class1.raw_data, + class1.id, + p1, + class2.raw_data, + class2.id, + p2, + ); - self.pending.extend(class2.parents()); + self.pending.extend(&class2.parents); - class1.parents.extend(class2.parents()); - Some((id1, class2)) + class1.parents.extend(class2.parents); } #[inline] @@ -615,7 +637,11 @@ impl RawEGraph { /// Simplified version of [`raw_union`](RawEGraph::raw_union) for egraphs without eclass data pub fn union(&mut self, id1: Id, id2: Id) -> bool { - Self::raw_union(self, id1, id2).is_some() + let mut unioned = false; + self.raw_union(id1, id2, |_, _, _, _, _, _| { + unioned = true; + }); + unioned } /// Simplified version of [`raw_rebuild`](RawEGraph::raw_rebuild) for egraphs without eclass data From f955195d08d2ffa1af31ea9f8d405e3acdfdad0c Mon Sep 17 00:00:00 2001 From: dewert99 Date: Sat, 10 Feb 2024 17:25:30 -0800 Subject: [PATCH 09/20] Added 2 versions of semi-persistence to `RawEGraph` --- Cargo.toml | 3 +- rust-toolchain | 2 +- src/lib.rs | 3 +- src/raw.rs | 14 +- src/raw/dhashmap.rs | 149 ++++++++++++++++ src/raw/egraph.rs | 150 +++++++++++----- src/raw/semi_persistent.rs | 128 ++++++++++++++ src/raw/semi_persistent1.rs | 223 +++++++++++++++++++++++ src/raw/semi_persistent2.rs | 343 ++++++++++++++++++++++++++++++++++++ src/{ => raw}/unionfind.rs | 18 +- src/util.rs | 3 + 11 files changed, 985 insertions(+), 51 deletions(-) create mode 100644 src/raw/dhashmap.rs create mode 100644 src/raw/semi_persistent.rs create mode 100644 src/raw/semi_persistent1.rs create mode 100644 src/raw/semi_persistent2.rs rename src/{ => raw}/unionfind.rs (74%) diff --git a/Cargo.toml b/Cargo.toml index 9decbb29..21f54c71 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,11 +9,12 @@ name = "egg" readme = "README.md" repository = "https://github.com/egraphs-good/egg" version = "0.9.5" +rust-version = "1.63.0" [dependencies] env_logger = { version = "0.9.0", default-features = false } fxhash = "0.2.1" -hashbrown = "0.12.1" +hashbrown = { version = "0.14.3", default-features = false, features = ["inline-more", "ahash"] } indexmap = "1.8.1" instant = "0.1.12" log = "0.4.17" diff --git a/rust-toolchain b/rust-toolchain index 2fef84a8..6cb4a6fe 100644 --- a/rust-toolchain +++ b/rust-toolchain @@ -1 +1 @@ -1.60 \ No newline at end of file +1.63 \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 298b3518..7f8853fa 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -54,7 +54,6 @@ pub mod raw; mod rewrite; mod run; mod subst; -mod unionfind; mod util; /// A key to identify [`EClass`]es within an @@ -88,7 +87,7 @@ impl std::fmt::Display for Id { } } -pub(crate) use {explain::Explain, unionfind::UnionFind}; +pub(crate) use {explain::Explain, raw::UnionFind}; pub use { dot::Dot, diff --git a/src/raw.rs b/src/raw.rs index 22395db6..444f2016 100644 --- a/src/raw.rs +++ b/src/raw.rs @@ -1,5 +1,17 @@ +mod dhashmap; mod eclass; mod egraph; +mod semi_persistent; + +/// One variant of semi_persistence +pub mod semi_persistent1; + +/// Another variant of semi_persistence +pub mod semi_persistent2; +mod unionfind; pub use eclass::RawEClass; -pub use egraph::{EGraphResidual, RawEGraph}; +pub use egraph::{EGraphResidual, RawEGraph, UnionInfo}; +use semi_persistent::Sealed; +pub use semi_persistent::{AsUnwrap, UndoLogT}; +pub use unionfind::UnionFind; diff --git a/src/raw/dhashmap.rs b/src/raw/dhashmap.rs new file mode 100644 index 00000000..a4ef7423 --- /dev/null +++ b/src/raw/dhashmap.rs @@ -0,0 +1,149 @@ +use std::fmt::{Debug, Formatter}; +use std::hash::{BuildHasher, Hash}; +use std::iter; +use std::iter::FromIterator; + +use hashbrown::hash_table; + +pub(super) type DHMIdx = u32; + +/// Similar to [`HashMap`](std::collections::HashMap) but with deterministic iteration order +/// +/// Accessing individual elements has similar performance to a [`HashMap`](std::collections::HashMap) +/// (faster than an `IndexMap`), but iteration requires allocation +/// +#[derive(Clone)] +pub(super) struct DHashMap { + data: hash_table::HashTable<(K, V, DHMIdx)>, + hasher: S, +} + +impl Default for DHashMap { + fn default() -> Self { + DHashMap { + data: Default::default(), + hasher: Default::default(), + } + } +} + +pub(super) struct VacantEntry<'a, K, V> { + len: DHMIdx, + entry: hash_table::VacantEntry<'a, (K, V, DHMIdx)>, + k: K, +} + +impl<'a, K, V> VacantEntry<'a, K, V> { + pub(super) fn insert(self, v: V) { + self.entry.insert((self.k, v, self.len)); + } +} + +pub(super) enum Entry<'a, K, V> { + Occupied((K, &'a mut V)), + Vacant(VacantEntry<'a, K, V>), +} + +#[inline] +fn hash_one(hasher: &impl BuildHasher, hash: impl Hash) -> u64 { + use core::hash::Hasher; + let mut hasher = hasher.build_hasher(); + hash.hash(&mut hasher); + hasher.finish() +} + +#[inline] +fn eq<'a, K: Eq, V>(k: &'a K) -> impl Fn(&(K, V, DHMIdx)) -> bool + 'a { + move |x| &x.0 == k +} + +#[inline] +fn hasher_fn<'a, K: Hash, V, S: BuildHasher>( + hasher: &'a S, +) -> impl Fn(&(K, V, DHMIdx)) -> u64 + 'a { + move |x| hash_one(hasher, &x.0) +} + +impl DHashMap { + #[inline] + pub(super) fn entry(&mut self, k: K) -> (Entry<'_, K, V>, u64) { + let hash = hash_one(&mut self.hasher, &k); + let len = self.data.len() as DHMIdx; + let entry = match self.data.entry(hash, eq(&k), hasher_fn(&self.hasher)) { + hash_table::Entry::Occupied(entry) => Entry::Occupied((k, &mut entry.into_mut().1)), + hash_table::Entry::Vacant(entry) => Entry::Vacant(VacantEntry { len, entry, k }), + }; + (entry, hash) + } + + #[inline] + pub(super) fn insert_with_hash(&mut self, hash: u64, k: K, v: V) { + debug_assert!({ + let (v, hash2) = self.get(&k); + v.is_none() && hash == hash2 + }); + let len = self.data.len() as DHMIdx; + self.data + .insert_unique(hash, (k, v, len), hasher_fn(&self.hasher)); + } + + #[inline] + pub(super) fn remove_nth(&mut self, hash: u64, idx: DHMIdx) { + debug_assert_eq!(self.data.len() as DHMIdx - 1, idx); + match self.data.find_entry(hash, |x| x.2 == idx) { + Ok(x) => x.remove(), + Err(_) => unreachable!(), + }; + } + + #[inline] + pub(super) fn len(&self) -> usize { + self.data.len() + } + + #[inline] + pub(super) fn get(&self, k: &K) -> (Option<&V>, u64) { + let hash = hash_one(&self.hasher, k); + (self.data.find(hash, eq(k)).map(|x| &x.1), hash) + } + + pub(super) fn clear(&mut self) { + self.data.clear() + } +} + +impl<'a, K, V, S> IntoIterator for &'a DHashMap { + type Item = (&'a K, &'a V); + + // TODO replace with TAIT + type IntoIter = iter::Map< + std::vec::IntoIter>, + fn(Option<(&'a K, &'a V)>) -> (&'a K, &'a V), + >; + + #[inline(never)] + fn into_iter(self) -> Self::IntoIter { + let mut data: Vec<_> = iter::repeat(None).take(self.data.len()).collect(); + for (k, v, i) in &self.data { + data[*i as usize] = Some((k, v)) + } + data.into_iter().map(Option::unwrap) + } +} + +impl FromIterator<(K, V)> for DHashMap { + fn from_iter>(iter: T) -> Self { + let mut res = Self::default(); + iter.into_iter().for_each(|(k, v)| { + let hash = hash_one(&mut res.hasher, &k); + res.insert_with_hash(hash, k, v) + }); + res + } +} + +impl Debug for DHashMap { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_map().entries(self).finish() + } +} diff --git a/src/raw/egraph.rs b/src/raw/egraph.rs index d9e4592e..dcbb775d 100644 --- a/src/raw/egraph.rs +++ b/src/raw/egraph.rs @@ -1,4 +1,5 @@ use crate::{raw::RawEClass, Dot, HashMap, Id, Language, RecExpr, UnionFind}; +use std::collections::BTreeMap; use std::ops::{Deref, DerefMut}; use std::{ borrow::BorrowMut, @@ -6,6 +7,8 @@ use std::{ iter, slice, }; +use crate::raw::dhashmap::*; +use crate::raw::UndoLogT; #[cfg(feature = "serde-1")] use serde::{Deserialize, Serialize}; @@ -28,14 +31,14 @@ impl<'a> IntoIterator for Parents<'a> { #[derive(Clone)] #[cfg_attr(feature = "serde-1", derive(Serialize, Deserialize))] pub struct EGraphResidual { - unionfind: UnionFind, + pub(super) unionfind: UnionFind, /// Stores the original node represented by each non-canonical id - nodes: Vec, + pub(super) nodes: Vec, /// Stores each enode's `Id`, not the `Id` of the eclass. /// Enodes in the memo are canonicalized at each rebuild, but after rebuilding new /// unions can cause them to become out of date. #[cfg_attr(feature = "serde-1", serde(with = "vectorize"))] - memo: HashMap, + pub(super) memo: DHashMap, } impl EGraphResidual { @@ -142,15 +145,22 @@ impl EGraphResidual { .map(|(id, node)| (Id::from(id), node)) } + /// Returns an iterator over all the uncanonical ids + pub fn uncanonical_ids(&self) -> impl ExactSizeIterator + 'static { + (0..self.number_of_uncanonical_nodes()) + .into_iter() + .map(Id::from) + } + /// Returns the number of enodes in the `EGraph`. /// /// Actually returns the size of the hashcons index. /// # Example /// ``` - /// use egg::{*, SymbolLang as S}; - /// let mut egraph = EGraph::::default(); - /// let x = egraph.add(S::leaf("x")); - /// let y = egraph.add(S::leaf("y")); + /// use egg::{raw::*, SymbolLang as S}; + /// let mut egraph = RawEGraph::::default(); + /// let x = egraph.add_uncanonical(S::leaf("x")); + /// let y = egraph.add_uncanonical(S::leaf("y")); /// // only one eclass /// egraph.union(x, y); /// egraph.rebuild(); @@ -169,15 +179,15 @@ impl EGraphResidual { /// /// # Example /// ``` - /// # use egg::*; - /// let mut egraph: EGraph = Default::default(); - /// let a = egraph.add(SymbolLang::leaf("a")); - /// let b = egraph.add(SymbolLang::leaf("b")); + /// # use egg::{SymbolLang, raw::*}; + /// let mut egraph: RawEGraph = Default::default(); + /// let a = egraph.add_uncanonical(SymbolLang::leaf("a")); + /// let b = egraph.add_uncanonical(SymbolLang::leaf("b")); /// /// // lookup will find this node if its in the egraph /// let mut node_f_ab = SymbolLang::new("f", vec![a, b]); /// assert_eq!(egraph.lookup(node_f_ab.clone()), None); - /// let id = egraph.add(node_f_ab.clone()); + /// let id = egraph.add_uncanonical(node_f_ab.clone()); /// assert_eq!(egraph.lookup(node_f_ab.clone()), Some(id)); /// /// // if the query node isn't canonical, and its passed in by &mut instead of owned, @@ -201,7 +211,7 @@ impl EGraphResidual { { let enode = enode.borrow_mut(); enode.update_children(|id| self.find(id)); - self.memo.get(enode).copied() + self.memo.get(enode).0.copied() } /// Lookup the eclass of the given [`RecExpr`]. @@ -308,16 +318,17 @@ to properly handle this data **/ #[derive(Clone)] #[cfg_attr(feature = "serde-1", derive(Serialize, Deserialize))] -pub struct RawEGraph { +pub struct RawEGraph { #[cfg_attr(feature = "serde-1", serde(flatten))] - residual: EGraphResidual, + pub(super) residual: EGraphResidual, /// Nodes which need to be processed for rebuilding. The `Id` is the `Id` of the enode, /// not the canonical id of the eclass. - pending: Vec, - classes: HashMap>, + pub(super) pending: Vec, + pub(super) classes: HashMap>, + pub(super) undo_log: U, } -impl Default for RawEGraph { +impl Default for RawEGraph { fn default() -> Self { let residual = EGraphResidual { unionfind: Default::default(), @@ -328,11 +339,12 @@ impl Default for RawEGraph { residual, pending: Default::default(), classes: Default::default(), + undo_log: Default::default(), } } } -impl Deref for RawEGraph { +impl Deref for RawEGraph { type Target = EGraphResidual; #[inline] @@ -341,7 +353,7 @@ impl Deref for RawEGraph { } } -impl DerefMut for RawEGraph { +impl DerefMut for RawEGraph { #[inline] fn deref_mut(&mut self) -> &mut Self::Target { &mut self.residual @@ -349,16 +361,32 @@ impl DerefMut for RawEGraph { } // manual debug impl to avoid L: Language bound on EGraph defn -impl Debug for RawEGraph { +impl Debug for RawEGraph { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let classes: BTreeMap<_, _> = self + .classes + .iter() + .map(|(x, y)| { + let mut parents = y.parents.clone(); + parents.sort_unstable(); + ( + *x, + RawEClass { + id: y.id, + raw_data: &y.raw_data, + parents, + }, + ) + }) + .collect(); f.debug_struct("EGraph") .field("memo", &self.residual.memo) - .field("classes", &self.classes) + .field("classes", &classes) .finish() } } -impl RawEGraph { +impl RawEGraph { /// Returns an iterator over the eclasses in the egraph. pub fn classes(&self) -> impl ExactSizeIterator> { self.classes.iter().map(|(id, class)| { @@ -424,9 +452,28 @@ impl RawEGraph { &mut self.residual, ) } + + /// Returns whether `self` is congruently closed + /// + /// This will always be true after calling [`raw_rebuild`](RawEGraph::raw_rebuild) + pub fn is_clean(&self) -> bool { + self.pending.is_empty() + } +} + +/// Information about a call to [`RawEGraph::raw_union`] +pub struct UnionInfo { + /// The canonical id of the newly merged class + pub new_id: Id, + /// The number of parents that were in the newly merged class before it was merged + pub parents_cut: usize, + /// The id that used to canonically represent the class that was merged into `new_id` + pub old_id: Id, + /// The data that was in the class reprented by `old_id` + pub old_data: D, } -impl RawEGraph { +impl> RawEGraph { /// Adds `enode` to a [`RawEGraph`] contained within a wrapper type `T` /// /// ## Parameters @@ -465,14 +512,15 @@ impl RawEGraph { pub fn raw_add( outer: &mut T, get_self: impl Fn(&mut T) -> &mut Self, - mut enode: L, + original: L, handle_equiv: impl FnOnce(&mut T, Id, &L) -> Option, handle_union: impl FnOnce(&mut T, Id, Id), mk_data: impl FnOnce(&mut T, Id, &L) -> D, ) -> Id { let this = get_self(outer); - let original = enode.clone(); - if let Some(existing_id) = this.lookup_internal(&mut enode) { + let enode = original.clone().map_children(|x| this.find(x)); + let (existing_id, hash) = this.residual.memo.get(&enode); + if let Some(&existing_id) = existing_id { let canon_id = this.find(existing_id); // when explanations are enabled, we need a new representative for this expr if let Some(existing_id) = handle_equiv(outer, existing_id, &original) { @@ -480,6 +528,8 @@ impl RawEGraph { } else { let this = get_self(outer); let new_id = this.residual.unionfind.make_set(); + this.undo_log.add_node(&original, &[], new_id); + this.undo_log.union(canon_id, new_id, Vec::new()); debug_assert_eq!(Id::from(this.nodes.len()), new_id); this.residual.nodes.push(original); this.residual.unionfind.union(canon_id, new_id); @@ -488,6 +538,7 @@ impl RawEGraph { } } else { let id = this.residual.unionfind.make_set(); + this.undo_log.add_node(&original, enode.children(), id); debug_assert_eq!(Id::from(this.nodes.len()), id); this.residual.nodes.push(original); @@ -508,7 +559,8 @@ impl RawEGraph { this.pending.push(id); this.classes.insert(id, class); - assert!(this.residual.memo.insert(enode, id).is_none()); + this.residual.memo.insert_with_hash(hash, enode, id); + this.undo_log.insert_memo(hash); id } @@ -545,6 +597,7 @@ impl RawEGraph { let class2 = self.classes.remove(&id2).unwrap(); let class1 = self.classes.get_mut(&id1).unwrap(); assert_eq!(id1, class1.id); + let (p1, p2) = (Parents(&class1.parents), Parents(&class2.parents)); merge( &mut class1.raw_data, @@ -557,10 +610,11 @@ impl RawEGraph { self.pending.extend(&class2.parents); - class1.parents.extend(class2.parents); + class1.parents.extend(&class2.parents); + + self.undo_log.union(id1, id2, class2.parents); } - #[inline] /// Rebuild to [`RawEGraph`] to restore congruence closure /// /// ## Parameters @@ -578,6 +632,7 @@ impl RawEGraph { /// ### `handle_pending` /// Called with the uncanonical id of each enode whose canonical children have changned, along with a canonical /// version of it + #[inline] pub fn raw_rebuild( outer: &mut T, get_self: impl Fn(&mut T) -> &mut Self, @@ -590,8 +645,17 @@ impl RawEGraph { let mut node = this.id_to_node(class).clone(); node.update_children(|id| this.find_mut(id)); handle_pending(outer, class, &node); - if let Some(memo_class) = get_self(outer).residual.memo.insert(node, class) { - perform_union(outer, memo_class, class); + let this = get_self(outer); + let (entry, hash) = this.residual.memo.entry(node); + match entry { + Entry::Occupied((_, id)) => { + let memo_class = *id; + perform_union(outer, memo_class, class); + } + Entry::Vacant(vac) => { + this.undo_log.insert_memo(hash); + vac.insert(class); + } } } else { break; @@ -599,13 +663,6 @@ impl RawEGraph { } } - /// Returns whether `self` is congruently closed - /// - /// This will always be true after calling [`raw_rebuild`](RawEGraph::raw_rebuild) - pub fn is_clean(&self) -> bool { - self.pending.is_empty() - } - /// Returns a more debug-able representation of the egraph focusing on its classes. /// /// [`RawEGraph`]s implement [`Debug`], but it's not pretty. It @@ -620,9 +677,18 @@ impl RawEGraph { { EGraphDump(self) } + + /// Remove all nodes from this egraph + pub fn clear(&mut self) { + self.residual.nodes.clear(); + self.residual.memo.clear(); + self.residual.unionfind.clear(); + self.pending.clear(); + self.undo_log.clear(); + } } -impl RawEGraph { +impl> RawEGraph { /// Simplified version of [`raw_add`](RawEGraph::raw_add) for egraphs without eclass data pub fn add_uncanonical(&mut self, enode: L) -> Id { Self::raw_add( @@ -668,9 +734,9 @@ impl<'a, L: Language> Debug for EGraphUncanonicalDump<'a, L> { } } -struct EGraphDump<'a, L: Language, D>(&'a RawEGraph); +struct EGraphDump<'a, L: Language, D, U>(&'a RawEGraph); -impl<'a, L: Language, D: Debug> Debug for EGraphDump<'a, L, D> { +impl<'a, L: Language, D: Debug, U> Debug for EGraphDump<'a, L, D, U> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let mut ids: Vec = self.0.classes().map(|c| c.id).collect(); ids.sort(); diff --git a/src/raw/semi_persistent.rs b/src/raw/semi_persistent.rs new file mode 100644 index 00000000..87112186 --- /dev/null +++ b/src/raw/semi_persistent.rs @@ -0,0 +1,128 @@ +use crate::raw::RawEGraph; +use crate::{Id, Language}; +use std::fmt::Debug; + +pub trait Sealed {} +impl Sealed for () {} +impl Sealed for Option {} + +/// A sealed trait for types that can be used for `push`/`pop` APIs +/// It is trivially implemented for `()` +pub trait UndoLogT: Default + Debug + Sealed { + #[doc(hidden)] + fn add_node(&mut self, node: &L, canon_children: &[Id], node_id: Id); + + #[doc(hidden)] + fn union(&mut self, id1: Id, id2: Id, id2_parents: Vec); + + #[doc(hidden)] + fn insert_memo(&mut self, hash: u64); + + #[doc(hidden)] + fn clear(&mut self); + + #[doc(hidden)] + fn is_enabled(&self) -> bool; +} + +impl UndoLogT for () { + #[inline] + fn add_node(&mut self, _: &L, _: &[Id], _: Id) {} + + #[inline] + fn union(&mut self, _: Id, _: Id, _: Vec) {} + + #[inline] + fn insert_memo(&mut self, _: u64) {} + + #[inline] + fn clear(&mut self) {} + + fn is_enabled(&self) -> bool { + false + } +} + +impl> UndoLogT for Option { + #[inline] + fn add_node(&mut self, node: &L, canon_children: &[Id], node_id: Id) { + if let Some(undo) = self { + undo.add_node(node, canon_children, node_id) + } + } + + #[inline] + fn union(&mut self, id1: Id, id2: Id, id2_parents: Vec) { + if let Some(undo) = self { + undo.union(id1, id2, id2_parents) + } + } + + #[inline] + fn insert_memo(&mut self, hash: u64) { + if let Some(undo) = self { + undo.insert_memo(hash) + } + } + + #[inline] + fn clear(&mut self) { + if let Some(undo) = self { + undo.clear() + } + } + + #[inline] + fn is_enabled(&self) -> bool { + self.as_ref().map(U::is_enabled).unwrap_or(false) + } +} + +/// Trait implemented for `T` and `Option` used to provide bounds for push/pop impls +pub trait AsUnwrap { + #[doc(hidden)] + fn as_unwrap(&self) -> &T; + + #[doc(hidden)] + fn as_mut_unwrap(&mut self) -> &mut T; +} + +impl AsUnwrap for T { + #[inline] + fn as_unwrap(&self) -> &T { + self + } + + #[inline] + fn as_mut_unwrap(&mut self) -> &mut T { + self + } +} +impl AsUnwrap for Option { + #[inline] + fn as_unwrap(&self) -> &T { + self.as_ref().unwrap() + } + + #[inline] + fn as_mut_unwrap(&mut self) -> &mut T { + self.as_mut().unwrap() + } +} + +impl> RawEGraph { + /// Change the [`UndoLogT`] being used + /// + /// If the new [`UndoLogT`] is enabled then the egraph must be empty + pub fn set_undo_log(&mut self, undo_log: U) { + if !self.is_empty() && undo_log.is_enabled() { + panic!("Need to set undo log enabled before adding any expressions to the egraph.") + } + self.undo_log = undo_log + } + + /// Check if the [`UndoLogT`] being used is enabled + pub fn has_undo_log(&self) -> bool { + self.undo_log.is_enabled() + } +} diff --git a/src/raw/semi_persistent1.rs b/src/raw/semi_persistent1.rs new file mode 100644 index 00000000..5ec5f2c9 --- /dev/null +++ b/src/raw/semi_persistent1.rs @@ -0,0 +1,223 @@ +use crate::raw::dhashmap::DHMIdx; +use crate::raw::{AsUnwrap, RawEClass, RawEGraph, Sealed, UndoLogT, UnionFind}; +use crate::{Id, Language}; +use std::fmt::Debug; + +/// Stored information required to restore the egraph to a previous state +/// +/// see [`push1`](RawEGraph::push1) and [`pop1`](RawEGraph::pop1) +#[derive(Debug, Clone)] +#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))] +pub struct PushInfo { + node_count: usize, + union_count: usize, + memo_log_count: usize, + pop_parents_count: usize, +} + +impl PushInfo { + /// Returns the result of [`EGraphResidual::number_of_uncanonical_nodes`](super::EGraphResidual::number_of_uncanonical_nodes) + /// from the state where `self` was created + pub fn number_of_uncanonical_nodes(&self) -> usize { + self.node_count + } +} + +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))] +struct UnionInfo { + old_id: Id, + old_parents: Vec, + added_after: u32, +} + +/// Value for [`UndoLogT`] that enables [`push1`](RawEGraph::push1) and [`raw_pop1`](RawEGraph::raw_pop1) +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))] +pub struct UndoLog { + // Mirror of the union find without path compression + undo_find: UnionFind, + pop_parents: Vec, + union_log: Vec, + memo_log: Vec, +} + +impl Default for UndoLog { + fn default() -> Self { + UndoLog { + undo_find: Default::default(), + pop_parents: Default::default(), + union_log: vec![UnionInfo { + old_id: Id::from(0), + old_parents: vec![], + added_after: 0, + }], + memo_log: Default::default(), + } + } +} + +impl Sealed for UndoLog {} + +impl UndoLogT for UndoLog { + fn add_node(&mut self, _: &L, canon_children: &[Id], node_id: Id) { + let new = self.undo_find.make_set(); + debug_assert_eq!(new, node_id); + self.pop_parents.extend(canon_children); + self.union_log.last_mut().unwrap().added_after += canon_children.len() as u32; + } + + fn union(&mut self, id1: Id, id2: Id, old_parents: Vec) { + self.undo_find.union(id1, id2); + self.union_log.push(UnionInfo { + old_id: id2, + added_after: 0, + old_parents, + }) + } + + fn insert_memo(&mut self, hash: u64) { + self.memo_log.push(hash); + } + + fn clear(&mut self) { + self.union_log.truncate(1); + self.union_log[0].added_after = 0; + self.memo_log.clear(); + self.undo_find.clear(); + } + + #[inline] + fn is_enabled(&self) -> bool { + true + } +} + +impl> RawEGraph { + /// Create a [`PushInfo`] representing the current state of the egraph + /// which can later be passed into [`raw_pop1`](RawEGraph::raw_pop1) + /// + /// Requires [`self.is_clean()`](RawEGraph::is_clean) + /// + /// # Example + /// ``` + /// use egg::{raw::*, SymbolLang as S}; + /// use egg::raw::semi_persistent1::UndoLog; + /// let mut egraph = RawEGraph::::default(); + /// let a = egraph.add_uncanonical(S::leaf("a")); + /// let fa = egraph.add_uncanonical(S::new("f", vec![a])); + /// let c = egraph.add_uncanonical(S::leaf("c")); + /// egraph.rebuild(); + /// let old = egraph.clone(); + /// let restore_point = egraph.push1(); + /// let b = egraph.add_uncanonical(S::leaf("b")); + /// let _fb = egraph.add_uncanonical(S::new("g", vec![b])); + /// egraph.union(b, a); + /// egraph.union(b, c); + /// egraph.rebuild(); + /// egraph.pop1(restore_point); + /// assert_eq!(format!("{:#?}", egraph.dump_uncanonical()), format!("{:#?}", old.dump_uncanonical())); + /// assert_eq!(format!("{:#?}", egraph), format!("{:#?}", old)); + /// ``` + pub fn push1(&self) -> PushInfo { + assert!(self.is_clean()); + let undo = self.undo_log.as_unwrap(); + PushInfo { + node_count: self.number_of_uncanonical_nodes(), + union_count: undo.union_log.len(), + memo_log_count: undo.memo_log.len(), + pop_parents_count: undo.pop_parents.len(), + } + } + + /// Mostly restores the egraph to the state it was it when it called [`push1`](RawEGraph::push1) + /// to create `info` + /// + /// Invalidates all [`PushInfo`]s that were created after `info` + /// + /// The `raw_data` fields of the [`RawEClass`]s are not properly restored + /// Instead, `split` is called to undo each union with a mutable reference to the merged data, and the two ids + /// that were merged to create the data for the eclass of the second `id` (the eclass of the first `id` will + /// be what's left of the merged data after the call) + pub fn raw_pop1(&mut self, info: PushInfo, split: impl FnMut(&mut D, Id, Id) -> D) { + let PushInfo { + node_count, + union_count, + memo_log_count, + pop_parents_count, + } = info; + self.pop_memo1(memo_log_count); + self.pop_unions1(union_count, pop_parents_count, split); + self.pop_nodes1(node_count); + } + + fn pop_memo1(&mut self, old_count: usize) { + assert!(self.memo.len() >= old_count); + let memo_log = &mut self.undo_log.as_mut_unwrap().memo_log; + let len = memo_log.len(); + for (hash, idx) in memo_log.drain(old_count..).zip(old_count..len).rev() { + self.residual.memo.remove_nth(hash, idx as DHMIdx); + } + } + + fn pop_unions1( + &mut self, + old_count: usize, + pop_parents_count: usize, + mut split: impl FnMut(&mut D, Id, Id) -> D, + ) { + let undo = self.undo_log.as_mut_unwrap(); + assert!(self.residual.number_of_uncanonical_nodes() >= old_count); + for info in undo.union_log.drain(old_count..).rev() { + for _ in 0..info.added_after { + let id = undo.pop_parents.pop().unwrap(); + self.classes.get_mut(&id).unwrap().parents.pop(); + } + let old_id = info.old_id; + let new_id = undo.undo_find.parent(old_id); + debug_assert_ne!(new_id, old_id); + debug_assert_eq!(undo.undo_find.find(new_id), new_id); + *undo.undo_find.parent_mut(old_id) = old_id; + let new_class = &mut self.classes.get_mut(&new_id).unwrap(); + let cut = new_class.parents.len() - info.old_parents.len(); + debug_assert_eq!(&new_class.parents[cut..], &info.old_parents); + new_class.parents.truncate(cut); + let old_data = split(&mut new_class.raw_data, new_id, old_id); + self.classes.insert( + old_id, + RawEClass { + id: old_id, + raw_data: old_data, + parents: info.old_parents, + }, + ); + } + let rem = undo.pop_parents.len() - pop_parents_count; + for _ in 0..rem { + let id = undo.pop_parents.pop().unwrap(); + self.classes.get_mut(&id).unwrap().parents.pop(); + } + undo.union_log.last_mut().unwrap().added_after -= rem as u32; + } + + fn pop_nodes1(&mut self, old_count: usize) { + assert!(self.number_of_uncanonical_nodes() >= old_count); + let undo = self.undo_log.as_mut_unwrap(); + undo.undo_find.parents.truncate(old_count); + self.residual + .unionfind + .parents + .clone_from(&undo.undo_find.parents); + for id in (old_count..self.number_of_uncanonical_nodes()).map(Id::from) { + self.classes.remove(&id); + } + self.residual.nodes.truncate(old_count); + } +} + +impl> RawEGraph { + /// Simplified version of [`raw_pop1`](RawEGraph::raw_pop1) for egraphs without eclass data + pub fn pop1(&mut self, info: PushInfo) { + self.raw_pop1(info, |_, _, _| ()) + } +} diff --git a/src/raw/semi_persistent2.rs b/src/raw/semi_persistent2.rs new file mode 100644 index 00000000..5d715424 --- /dev/null +++ b/src/raw/semi_persistent2.rs @@ -0,0 +1,343 @@ +use crate::raw::dhashmap::DHMIdx; +use crate::raw::{AsUnwrap, RawEClass, RawEGraph, Sealed, UndoLogT}; +use crate::util::{Entry, HashSet}; +use crate::{Id, Language}; +use std::fmt::Debug; + +#[derive(Debug, Clone, Default)] +#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))] +struct UndoNode { + /// Other ENodes that were unioned with this ENode and chose it as their representative + representative_of: Vec, + /// Non-canonical Id's of direct parents of this non-canonical node + parents: Vec, +} + +fn visit_undo_node(id: Id, undo_find: &[UndoNode], f: &mut impl FnMut(Id, &UndoNode)) { + let node = &undo_find[usize::from(id)]; + f(id, node); + node.representative_of + .iter() + .for_each(|&id| visit_undo_node(id, undo_find, &mut *f)) +} + +/// Stored information required to restore the egraph to a previous state +/// +/// see [`push2`](RawEGraph::push2) and [`pop2`](RawEGraph::pop2) +#[derive(Debug, Clone)] +#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))] +pub struct PushInfo { + node_count: usize, + union_count: usize, + memo_log_count: usize, + pop_parents_count: usize, +} + +impl PushInfo { + /// Returns the result of [`EGraphResidual::number_of_uncanonical_nodes`](super::EGraphResidual::number_of_uncanonical_nodes) + /// from the state where `self` was created + pub fn number_of_uncanonical_nodes(&self) -> usize { + self.node_count + } +} + +/// Value for [`UndoLogT`] that enables [`push2`](RawEGraph::push2) and [`raw_pop2`](RawEGraph::raw_pop2) +#[derive(Clone, Debug, Default)] +#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))] +pub struct UndoLog { + undo_find: Vec, + union_log: Vec, + memo_log: Vec, + pop_parents: Vec, + // Scratch space, should be empty other that when inside `pop` + #[cfg_attr(feature = "serde-1", serde(skip))] + dirty: HashSet, +} + +impl Sealed for UndoLog {} + +impl UndoLogT for UndoLog { + fn add_node(&mut self, node: &L, canon: &[Id], node_id: Id) { + debug_assert_eq!(self.undo_find.len(), usize::from(node_id)); + self.undo_find.push(UndoNode::default()); + if !canon.is_empty() { + // this node's children shouldn't since it was equivalent when it was added + for id in node.children() { + self.undo_find[usize::from(*id)].parents.push(node_id) + } + } + self.pop_parents.extend(canon) + } + + fn union(&mut self, id1: Id, id2: Id, _: Vec) { + self.undo_find[usize::from(id1)].representative_of.push(id2); + self.union_log.push(id1) + } + + fn insert_memo(&mut self, hash: u64) { + self.memo_log.push(hash); + } + + fn clear(&mut self) { + self.union_log.clear(); + self.memo_log.clear(); + self.undo_find.clear(); + } + + fn is_enabled(&self) -> bool { + true + } +} + +impl> RawEGraph { + /// Create a [`PushInfo`] representing the current state of the egraph + /// which can later be passed into [`raw_pop2`](RawEGraph::raw_pop2) + /// + /// Requires [`self.is_clean()`](RawEGraph::is_clean) + /// + /// # Example + /// ``` + /// use egg::{raw::*, SymbolLang as S}; + /// use egg::raw::semi_persistent2::UndoLog; + /// let mut egraph = RawEGraph::::default(); + /// let a = egraph.add_uncanonical(S::leaf("a")); + /// let fa = egraph.add_uncanonical(S::new("f", vec![a])); + /// let c = egraph.add_uncanonical(S::leaf("c")); + /// egraph.rebuild(); + /// assert_eq!(egraph.number_of_classes(), 3); + /// assert_eq!(egraph.number_of_uncanonical_nodes(), 3); + /// assert_eq!(egraph.total_size(), 3); + /// let restore_point = egraph.push2(); + /// let b = egraph.add_uncanonical(S::leaf("b")); + /// let _fb = egraph.add_uncanonical(S::new("g", vec![b])); + /// egraph.union(b, a); + /// egraph.union(b, c); + /// egraph.rebuild(); + /// assert_eq!(egraph.find(a), b); + /// assert_eq!(egraph.number_of_classes(), 3); + /// assert_eq!(egraph.number_of_uncanonical_nodes(), 5); + /// assert_eq!(egraph.total_size(), 6); + /// egraph.pop2(restore_point); + /// assert_ne!(egraph.find(a), b); + /// assert_eq!(egraph.lookup(S::leaf("a")), Some(a)); + /// assert_eq!(egraph.lookup(S::new("f", vec![a])), Some(fa)); + /// assert_eq!(egraph.lookup(S::leaf("b")), None); + /// assert_eq!(egraph.number_of_classes(), 3); + /// assert_eq!(egraph.number_of_uncanonical_nodes(), 3); + /// assert_eq!(egraph.total_size(), 3); + /// ``` + pub fn push2(&self) -> PushInfo { + assert!(self.is_clean()); + let undo = self.undo_log.as_unwrap(); + PushInfo { + node_count: undo.undo_find.len(), + union_count: undo.union_log.len(), + memo_log_count: undo.memo_log.len(), + pop_parents_count: undo.pop_parents.len(), + } + } + + /// Mostly restores the egraph to the state it was it when it called [`push2`](RawEGraph::push2) + /// to create `info` + /// + /// Invalidates all [`PushInfo`]s that were created after `info` + /// + /// The `raw_data` fields of the [`RawEClass`]s are not properly restored + /// Instead all eclasses that have were merged into another eclass are recreated with `mk_data` and + /// `clear` is called eclass that had another eclass merged into them + /// + /// After each call to either `mk_data` or `clear`, `handle_eqv` is called on each id that is in + /// the eclass (that was handled by `mk_data` or `clear` + /// + /// The `state` parameter represents arbitrary state that be accessed in any of the closures + pub fn raw_pop2( + &mut self, + info: PushInfo, + state: &mut T, + clear: impl FnMut(&mut T, &mut D, Id, UndoCtx<'_, L>), + mk_data: impl FnMut(&mut T, Id, UndoCtx<'_, L>) -> D, + handle_eqv: impl FnMut(&mut T, &mut D, Id, UndoCtx<'_, L>), + ) { + let PushInfo { + node_count, + union_count, + memo_log_count, + pop_parents_count, + } = info; + self.pop_memo2(memo_log_count); + self.pop_parents2(pop_parents_count, node_count); + self.pop_unions2(union_count, node_count, state, clear, mk_data, handle_eqv); + self.pop_nodes2(usize::from(node_count)); + } + + fn pop_memo2(&mut self, old_count: usize) { + assert!(self.memo.len() >= old_count); + let memo_log = &mut self.undo_log.as_mut_unwrap().memo_log; + let len = memo_log.len(); + for (hash, idx) in memo_log.drain(old_count..).zip(old_count..len).rev() { + self.residual.memo.remove_nth(hash, idx as DHMIdx); + } + } + + fn pop_parents2(&mut self, old_count: usize, node_count: usize) { + // Pop uncanonical parents within undo find + let undo = self.undo_log.as_mut_unwrap(); + for (id, node) in self + .residual + .nodes + .iter() + .enumerate() + .skip(node_count) + .rev() + { + for child in node.children() { + let parents = &mut undo.undo_find[usize::from(*child)].parents; + if parents.last().copied() == Some(Id::from(id)) { + // Otherwise this id's children never had it added to its parents + // since it was already equivalent to another node when it was added + parents.pop(); + } + } + } + // Pop canonical parents from classes in egraph + // Note, if `id` is not canonical then its class must have been merged into another class so it's parents will + // be rebuilt anyway + // If another class was merged into `id` we will be popping an incorrect parent, but again it's parents will + // be rebuilt anyway + for id in undo.pop_parents.drain(old_count..) { + if let Some(x) = self.classes.get_mut(&id) { + x.parents.pop(); + } + } + } + + fn pop_unions2( + &mut self, + old_count: usize, + node_count: usize, + state: &mut T, + mut clear: impl FnMut(&mut T, &mut D, Id, UndoCtx<'_, L>), + mut mk_data: impl FnMut(&mut T, Id, UndoCtx<'_, L>) -> D, + mut handle_eqv: impl FnMut(&mut T, &mut D, Id, UndoCtx<'_, L>), + ) { + let undo = self.undo_log.as_mut_unwrap(); + assert!(undo.union_log.len() >= old_count); + for id in undo.union_log.drain(old_count..) { + let id2 = undo.undo_find[usize::from(id)] + .representative_of + .pop() + .unwrap(); + for id in [id, id2] { + if usize::from(id) < node_count { + undo.dirty.insert(id); + } + } + } + let ctx = UndoCtx { + nodes: &self.residual.nodes[..node_count], + undo_find: &undo.undo_find[..node_count], + }; + for root in undo.dirty.iter().copied() { + let union_find = &mut self.residual.unionfind; + let class = match self.classes.entry(root) { + Entry::Vacant(vac) => { + let default = RawEClass { + id: root, + raw_data: mk_data(state, root, ctx), + parents: Default::default(), + }; + vac.insert(default) + } + Entry::Occupied(occ) => { + let res = occ.into_mut(); + clear(state, &mut res.raw_data, root, ctx); + res.parents.clear(); + res + } + }; + class.parents.clear(); + let parents = &mut class.parents; + let data = &mut class.raw_data; + visit_undo_node(root, &undo.undo_find, &mut |id, node| { + union_find.parents[usize::from(id)] = root; + parents.extend(&node.parents); + handle_eqv(state, data, id, ctx) + }); + // If we call pop again we need parents added more recently at the end + parents.sort_unstable(); + } + undo.dirty.clear(); + } + + fn pop_nodes2(&mut self, old_count: usize) { + assert!(self.number_of_uncanonical_nodes() >= old_count); + for id in (old_count..self.number_of_uncanonical_nodes()).map(Id::from) { + if self.find(id) == id { + self.classes.remove(&id); + } + } + self.residual.nodes.truncate(old_count); + self.undo_log.as_mut_unwrap().undo_find.truncate(old_count); + self.residual.unionfind.parents.truncate(old_count); + } + + /// Returns the [`UndoCtx`] corresponding to the current egraph + pub fn undo_ctx(&self) -> UndoCtx<'_, L> { + UndoCtx { + nodes: &self.nodes, + undo_find: &self.undo_log.as_unwrap().undo_find, + } + } +} + +/// The egraph is in a partially broken state during a call to [`RawEGraph::raw_pop2`] so the passed in closures +/// are given this struct which represents the aspects of the egraph that are currently usable +pub struct UndoCtx<'a, L> { + nodes: &'a [L], + undo_find: &'a [UndoNode], +} + +impl<'a, L> Copy for UndoCtx<'a, L> {} + +impl<'a, L> Clone for UndoCtx<'a, L> { + fn clone(&self) -> Self { + *self + } +} + +impl<'a, L> UndoCtx<'a, L> { + /// Calls `f` on all nodes that are equivalent to `id` + /// + /// Requires `id` to be canonical + pub fn equivalent_nodes(self, id: Id, mut f: impl FnMut(Id)) { + visit_undo_node(id, self.undo_find, &mut |id, _| f(id)) + } + + /// Returns an iterator of the uncanonical ids of nodes that contain the uncanonical id `id` + pub fn direct_parents(self, id: Id) -> impl ExactSizeIterator + 'a { + self.undo_find[usize::from(id)].parents.iter().copied() + } + + /// See [`EGraphResidual::id_to_node`](super::EGraphResidual::id_to_node) + pub fn id_to_node(self, id: Id) -> &'a L { + &self.nodes[usize::from(id)] + } + + /// See [`EGraphResidual::number_of_uncanonical_nodes`](super::EGraphResidual::number_of_uncanonical_nodes) + pub fn number_of_uncanonical_nodes(self) -> usize { + self.nodes.len() + } +} + +impl> RawEGraph { + /// Simplified version of [`raw_pop2`](RawEGraph::raw_pop2) for egraphs without eclass data + pub fn pop2(&mut self, info: PushInfo) { + self.raw_pop2( + info, + &mut (), + |_, _, _, _| {}, + |_, _, _| (), + |_, _, _, _| {}, + ) + } +} diff --git a/src/unionfind.rs b/src/raw/unionfind.rs similarity index 74% rename from src/unionfind.rs rename to src/raw/unionfind.rs index 39e9bc58..32fc8e0c 100644 --- a/src/unionfind.rs +++ b/src/raw/unionfind.rs @@ -3,29 +3,33 @@ use std::fmt::Debug; #[derive(Debug, Clone, Default)] #[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))] +/// Data structure that stores disjoint sets of `Id`s each with a representative pub struct UnionFind { - parents: Vec, + pub(super) parents: Vec, } impl UnionFind { + /// Creates a singleton set and returns its representative pub fn make_set(&mut self) -> Id { let id = Id::from(self.parents.len()); self.parents.push(id); id } + /// Returns the number of ids in all the sets pub fn size(&self) -> usize { self.parents.len() } - fn parent(&self, query: Id) -> Id { + pub(super) fn parent(&self, query: Id) -> Id { self.parents[usize::from(query)] } - fn parent_mut(&mut self, query: Id) -> &mut Id { + pub(super) fn parent_mut(&mut self, query: Id) -> &mut Id { &mut self.parents[usize::from(query)] } + /// Returns the representative of the set `current` belongs to pub fn find(&self, mut current: Id) -> Id { while current != self.parent(current) { current = self.parent(current) @@ -33,6 +37,7 @@ impl UnionFind { current } + /// Equivalent to [`find`](UnionFind::find) but preforms path-compression to optimize further calls pub fn find_mut(&mut self, mut current: Id) -> Id { while current != self.parent(current) { let grandparent = self.parent(self.parent(current)); @@ -42,11 +47,16 @@ impl UnionFind { current } - /// Given two leader ids, unions the two eclasses making root1 the leader. + /// Given two representative ids, unions the two eclasses making root1 the representative. pub fn union(&mut self, root1: Id, root2: Id) -> Id { *self.parent_mut(root2) = root1; root1 } + + /// Resets the union find + pub fn clear(&mut self) { + self.parents.clear() + } } #[cfg(test)] diff --git a/src/util.rs b/src/util.rs index 0e9051ee..808be648 100644 --- a/src/util.rs +++ b/src/util.rs @@ -53,12 +53,15 @@ pub(crate) use hashmap::*; mod hashmap { pub(crate) type HashMap = super::IndexMap; pub(crate) type HashSet = super::IndexSet; + + pub(crate) type Entry<'a, K, V> = indexmap::map::Entry; } #[cfg(not(feature = "deterministic"))] mod hashmap { use super::BuildHasher; pub(crate) type HashMap = hashbrown::HashMap; pub(crate) type HashSet = hashbrown::HashSet; + pub(crate) type Entry<'a, K, V> = hashbrown::hash_map::Entry<'a, K, V, BuildHasher>; } pub(crate) type IndexMap = indexmap::IndexMap; From d7bc03e485aa4512d50b3046d208008d5fc542e7 Mon Sep 17 00:00:00 2001 From: dewert99 Date: Mon, 12 Feb 2024 12:30:40 -0800 Subject: [PATCH 10/20] Lifted semi-persistence to `EGraph` --- Cargo.toml | 2 + Makefile | 2 + src/egraph.rs | 236 ++++++++++++++++++++++++++++++++- src/explain.rs | 74 ++++++++--- src/explain/semi_persistent.rs | 73 ++++++++++ src/language.rs | 10 +- src/test.rs | 53 +++++++- src/util.rs | 7 +- tests/lambda.rs | 4 +- tests/math.rs | 11 +- 10 files changed, 437 insertions(+), 35 deletions(-) create mode 100644 src/explain/semi_persistent.rs diff --git a/Cargo.toml b/Cargo.toml index 21f54c71..e304dc9a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -50,9 +50,11 @@ serde-1 = [ "vectorize", ] wasm-bindgen = ["instant/wasm-bindgen"] +push-pop-alt = [] # private features for testing test-explanations = [] +test-push-pop = ["deterministic"] [package.metadata.docs.rs] all-features = true diff --git a/Makefile b/Makefile index 229977bf..279518cb 100644 --- a/Makefile +++ b/Makefile @@ -6,6 +6,8 @@ test: cargo test --release --features=lp # don't run examples in proof-production mode cargo test --release --features "test-explanations" + cargo test --release --features "test-push-pop" --features "test-explanations" + cargo test --release --features "test-push-pop" --features "push-pop-alt" .PHONY: nits diff --git a/src/egraph.rs b/src/egraph.rs index 0eafb9d4..e35e3787 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -1,5 +1,6 @@ use crate::*; use std::fmt::{self, Debug, Display}; +use std::mem; use std::ops::Deref; #[cfg(feature = "serde-1")] @@ -9,6 +10,14 @@ use crate::eclass::EClassData; use crate::raw::{EGraphResidual, RawEGraph}; use log::*; +#[cfg(feature = "push-pop-alt")] +use raw::semi_persistent1 as sp; + +#[cfg(not(feature = "push-pop-alt"))] +use raw::semi_persistent2 as sp; + +use sp::UndoLog; +type PushInfo = (sp::PushInfo, explain::PushInfo, usize); /** A data structure to keep track of equalities between expressions. Check out the [background tutorial](crate::tutorials::_01_background) @@ -64,7 +73,7 @@ pub struct EGraph> { deserialize = "N::Data: for<'a> Deserialize<'a>", )) )] - pub(crate) inner: RawEGraph>, + pub(crate) inner: RawEGraph, Option>, #[cfg_attr(feature = "serde-1", serde(skip))] #[cfg_attr(feature = "serde-1", serde(default = "default_classes_by_op"))] pub(crate) classes_by_op: HashMap>, @@ -75,6 +84,8 @@ pub struct EGraph> { /// Only manually set it if you know what you're doing. #[cfg_attr(feature = "serde-1", serde(skip))] pub clean: bool, + push_log: Vec, + data_history: Vec<(Id, N::Data)>, } #[cfg(feature = "serde-1")] @@ -114,6 +125,8 @@ impl> EGraph { inner: Default::default(), analysis_pending: Default::default(), classes_by_op: Default::default(), + push_log: Default::default(), + data_history: Default::default(), } } @@ -167,7 +180,11 @@ impl> EGraph { if self.total_size() > 0 { panic!("Need to set explanations enabled before adding any expressions to the egraph."); } - self.explain = Some(Explain::new()); + let mut explain = Explain::new(); + if self.inner.has_undo_log() { + explain.enable_undo_log() + } + self.explain = Some(explain); self } @@ -193,6 +210,28 @@ impl> EGraph { } } + /// Enable [`push`](EGraph::push) and [`pop`](EGraph::pop) for this `EGraph`. + /// This allows the egraph to revert to an earlier state + pub fn with_push_pop_enabled(mut self) -> Self { + if self.inner.has_undo_log() { + return self; + } + self.inner.set_undo_log(Some(UndoLog::default())); + if let Some(explain) = &mut self.explain { + explain.enable_undo_log() + } + self + } + + /// Disable [`push`](EGraph::push) and [`pop`](EGraph::pop) for this `EGraph`. + pub fn with_push_pop_disabled(mut self) -> Self { + self.inner.set_undo_log(None); + if let Some(explain) = &mut self.explain { + explain.disable_undo_log() + } + self + } + /// Make a copy of the egraph with the same nodes, but no unions between them. pub fn copy_without_unions(&self, analysis: N) -> Self { if self.explain.is_none() { @@ -804,9 +843,14 @@ impl> EGraph { self.clean = false; let mut new_root = None; + let has_undo_log = self.inner.has_undo_log(); self.inner - .raw_union(enode_id1, enode_id2, |class1, id1, p1, class2, _, p2| { + .raw_union(enode_id1, enode_id2, |class1, id1, p1, class2, id2, p2| { new_root = Some(id1); + if has_undo_log && mem::size_of::() > 0 { + self.data_history.push((id1, class1.data.clone())); + self.data_history.push((id2, class2.data.clone())); + } let did_merge = self.analysis.merge(&mut class1.data, class2.data); if did_merge.0 { @@ -841,9 +885,13 @@ impl> EGraph { /// so [`Analysis::make`] and [`Analysis::merge`] will get /// called for other parts of the e-graph on rebuild. pub fn set_analysis_data(&mut self, id: Id, new_data: N::Data) { - let class = self.inner.get_class_mut(id).0; - class.data = new_data; + let mut canon = id; + let class = self.inner.get_class_mut(&mut canon).0; + let old_data = mem::replace(&mut class.data, new_data); self.analysis_pending.extend(class.parents()); + if self.inner.has_undo_log() && mem::size_of::() > 0 { + self.data_history.push((canon, old_data)) + } N::modify(self, id) } @@ -992,8 +1040,11 @@ impl> EGraph { while let Some(mut class_id) = self.analysis_pending.pop() { let node = self.id_to_node(class_id).clone(); let node_data = N::make(self, &node); + let has_undo_log = self.inner.has_undo_log(); let class = self.inner.get_class_mut(&mut class_id).0; - + if has_undo_log && mem::size_of::() > 0 { + self.data_history.push((class.id, class.data.clone())); + } let did_merge = self.analysis.merge(&mut class.data, node_data); if did_merge.0 { self.analysis_pending.extend(class.parents()); @@ -1086,6 +1137,176 @@ impl> EGraph { panic!("Can't check explain when explanations are off"); } } + + /// Remove all nodes from this egraph + pub fn clear(&mut self) { + self.push_log.clear(); + self.inner.clear(); + self.clean = true; + if let Some(explain) = &mut self.explain { + explain.clear() + } + self.analysis_pending.clear(); + self.data_history.clear(); + } +} + +impl> EGraph +where + N::Data: Default, +{ + /// Push the current egraph off the stack + /// Requires that the egraph is clean + /// + /// See [`EGraph::pop`] + pub fn push(&mut self) { + assert!( + self.analysis_pending.is_empty() && self.inner.is_clean(), + "`push` can only be called on clean egraphs" + ); + if !self.inner.has_undo_log() { + panic!("Use egraph.with_push_pop() before running to call push"); + } + N::pre_push(self); + let exp_push_info = self.explain.as_ref().map(Explain::push).unwrap_or_default(); + #[cfg(feature = "push-pop-alt")] + let raw_push_info = self.inner.push1(); + #[cfg(not(feature = "push-pop-alt"))] + let raw_push_info = self.inner.push2(); + self.push_log + .push((raw_push_info, exp_push_info, self.data_history.len())) + } + + /// Pop the current egraph off the stack, replacing + /// it with the previously [`push`](EGraph::push)ed egraph + /// + /// ``` + /// use egg::{EGraph, SymbolLang}; + /// let mut egraph = EGraph::new(()).with_push_pop_enabled(); + /// let a = egraph.add_uncanonical(SymbolLang::leaf("a")); + /// let b = egraph.add_uncanonical(SymbolLang::leaf("b")); + /// egraph.rebuild(); + /// egraph.push(); + /// egraph.union(a, b); + /// assert_eq!(egraph.find(a), egraph.find(b)); + /// egraph.pop(); + /// assert_ne!(egraph.find(a), egraph.find(b)); + /// ``` + pub fn pop(&mut self) { + self.pop_n(1) + } + + /// Equivalent to calling [`pop`](EGraph::pop) `n` times but possibly more efficient + pub fn pop_n(&mut self, n: usize) { + if !self.inner.has_undo_log() { + panic!("Use egraph.with_push_pop() before running to call pop"); + } + if n > self.push_log.len() { + self.clear() + } + let mut info = None; + for _ in 0..n { + info = self.push_log.pop() + } + if let Some(info) = info { + self.pop_internal(info); + N::post_pop_n(self, n); + } + } + + #[cfg(not(feature = "push-pop-alt"))] + fn pop_internal(&mut self, (raw_info, exp_info, data_history_len): PushInfo) { + if let Some(explain) = &mut self.explain { + explain.pop( + exp_info, + raw_info.number_of_uncanonical_nodes(), + &self.inner, + ) + } + self.analysis_pending.clear(); + + let mut has_dirty_parents = Vec::new(); + let mut dirty_status = HashMap::default(); + self.inner.raw_pop2( + raw_info, + &mut dirty_status, + |dirty_status, data, id, _| { + dirty_status.insert(id, false); + data.nodes.clear(); + }, + |dirty_status, id, _| { + has_dirty_parents.push(id); + dirty_status.insert(id, false); + EClassData { + nodes: vec![], + data: Default::default(), + } + }, + |_, data, id, ctx| data.nodes.push(ctx.id_to_node(id).clone()), + ); + for id in has_dirty_parents { + for parent in self.inner.get_class_with_cannon(id).parents() { + dirty_status.entry(self.find(parent)).or_insert(true); + } + } + for (id, needs_reset) in dirty_status { + if needs_reset { + let mut nodes = mem::take(&mut self.inner.get_class_mut_with_cannon(id).0.nodes); + nodes.clear(); + self.inner + .undo_ctx() + .equivalent_nodes(id, |eqv| nodes.push(self.id_to_node(eqv).clone())); + self.inner.get_class_mut_with_cannon(id).0.nodes = nodes; + } + let (class, residual) = self.inner.get_class_mut_with_cannon(id); + for node in &mut class.nodes { + node.update_children(|id| residual.find(id)); + } + class.nodes.sort_unstable(); + class.nodes.dedup(); + } + + for (id, data) in self.data_history.drain(data_history_len..).rev() { + if usize::from(id) < self.inner.number_of_uncanonical_nodes() { + self.inner.get_class_mut_with_cannon(id).0.data = data; + } + } + + self.clean = true; + } + + #[cfg(feature = "push-pop-alt")] + fn pop_internal(&mut self, (raw_info, exp_info, data_history_len): PushInfo) { + if let Some(explain) = &mut self.explain { + explain.pop( + exp_info, + raw_info.number_of_uncanonical_nodes(), + &self.inner, + ) + } + self.analysis_pending.clear(); + + self.inner.raw_pop1(raw_info, |_, _, _| EClassData { + nodes: vec![], + data: Default::default(), + }); + + for class in self.classes_mut() { + class.nodes.clear() + } + + for id in self.uncanonical_ids() { + let node = self.id_to_node(id).clone().map_children(|x| self.find(x)); + self[id].nodes.push(node) + } + + for (id, data) in self.data_history.drain(data_history_len..).rev() { + if usize::from(id) < self.inner.number_of_uncanonical_nodes() { + self.inner.get_class_mut_with_cannon(id).0.data = data; + } + } + self.rebuild_classes(); + } } #[cfg(test)] @@ -1125,6 +1346,9 @@ mod tests { de(&egraph); let json_rep = serde_json::to_string_pretty(&egraph).unwrap(); + let egraph2: EGraph = serde_json::from_str(&json_rep).unwrap(); + let json_rep2 = serde_json::to_string_pretty(&egraph2).unwrap(); + assert_eq!(json_rep, json_rep2); println!("{}", json_rep); } } diff --git a/src/explain.rs b/src/explain.rs index 53a8ec61..c25f6ba2 100644 --- a/src/explain.rs +++ b/src/explain.rs @@ -1,3 +1,6 @@ +mod semi_persistent; +pub(crate) use semi_persistent::PushInfo; + use crate::Symbol; use crate::{ util::pretty_print, Analysis, ENodeOrVar, FromOp, HashMap, HashSet, Id, Language, PatternAst, @@ -7,6 +10,7 @@ use saturating::Saturating; use std::cmp::Ordering; use std::collections::{BinaryHeap, VecDeque}; use std::fmt::{self, Debug, Display, Formatter}; +use std::mem; use std::ops::{Deref, DerefMut}; use std::rc::Rc; @@ -38,6 +42,18 @@ struct Connection { is_rewrite_forward: bool, } +impl Connection { + #[inline] + fn end(node: Id) -> Self { + Connection { + next: node, + current: node, + justification: Justification::Congruence, + is_rewrite_forward: false, + } + } +} + #[derive(Debug, Clone)] #[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))] struct ExplainNode { @@ -75,6 +91,7 @@ pub struct Explain { // That is, less than or equal to the result of `distance_between` #[cfg_attr(feature = "serde-1", serde(skip))] shortest_explanation_memo: HashMap<(Id, Id), (ProofCost, Id)>, + undo_log: semi_persistent::UndoLog, } pub(crate) struct ExplainWith<'a, L: Language, X> { @@ -911,6 +928,7 @@ impl Explain { uncanon_memo: Default::default(), shortest_explanation_memo: Default::default(), optimize_explanation_lengths: true, + undo_log: None, } } @@ -920,33 +938,44 @@ impl Explain { pub(crate) fn add(&mut self, node: L, set: Id, existance_node: Id) -> Id { assert_eq!(self.explainfind.len(), usize::from(set)); - self.uncanon_memo.insert(node, set); + self.uncanon_memo.entry(node).or_insert(set); + // If the node already in uncanon memo keep the old version so it's easier to revert the add self.explainfind.push(ExplainNode { neighbors: vec![], - parent_connection: Connection { - justification: Justification::Congruence, - is_rewrite_forward: false, - next: set, - current: set, - }, + parent_connection: Connection::end(set), existance_node, }); set } + /// Reorient connections to make `node` the leader (Used for testing push/pop) + pub(crate) fn test_mk_root(&mut self, node: Id) { + self.set_parent(node, Connection::end(node)) + } + // reverse edges recursively to make this node the leader - fn make_leader(&mut self, node: Id) { - let next = self.explainfind[usize::from(node)].parent_connection.next; - if next != node { - self.make_leader(next); - let node_connection = &self.explainfind[usize::from(node)].parent_connection; + fn set_parent(&mut self, node: Id, parent: Connection) { + let mut prev = node; + let mut curr = mem::replace( + &mut self.explainfind[usize::from(prev)].parent_connection, + parent, + ); + let mut count = 0; + while prev != curr.next { let pconnection = Connection { - justification: node_connection.justification.clone(), - is_rewrite_forward: !node_connection.is_rewrite_forward, - next: node, - current: next, + justification: curr.justification, + is_rewrite_forward: !curr.is_rewrite_forward, + next: prev, + current: curr.next, }; - self.explainfind[usize::from(next)].parent_connection = pconnection; + let next = mem::replace( + &mut self.explainfind[usize::from(curr.next)].parent_connection, + pconnection, + ); + prev = curr.next; + curr = next; + count += 1; + assert!(count < 1000); } } @@ -984,6 +1013,7 @@ impl Explain { .insert((node1, node2), (Saturating(1), node2)); self.shortest_explanation_memo .insert((node2, node1), (Saturating(1), node1)); + self.undo_log_union(node1); } pub(crate) fn union( @@ -1000,9 +1030,6 @@ impl Explain { self.set_existance_reason(node2, node1) } - self.make_leader(node1); - self.explainfind[usize::from(node1)].parent_connection.next = node2; - if let Justification::Rule(_) = justification { self.shortest_explanation_memo .insert((node1, node2), (Saturating(1), node2)); @@ -1028,7 +1055,10 @@ impl Explain { self.explainfind[usize::from(node2)] .neighbors .push(other_pconnection); - self.explainfind[usize::from(node1)].parent_connection = pconnection; + + self.set_parent(node1, pconnection); + + self.undo_log_union(node1); } pub(crate) fn get_union_equalities(&self) -> UnionEqualities { let mut equalities = vec![]; @@ -1063,7 +1093,7 @@ impl<'a, L: Language, X> DerefMut for ExplainWith<'a, L, X> { } } -impl<'x, L: Language, D> ExplainWith<'x, L, &'x RawEGraph> { +impl<'x, L: Language, D, U> ExplainWith<'x, L, &'x RawEGraph> { pub(crate) fn node(&self, node_id: Id) -> &L { self.raw.id_to_node(node_id) } diff --git a/src/explain/semi_persistent.rs b/src/explain/semi_persistent.rs new file mode 100644 index 00000000..bcabcd94 --- /dev/null +++ b/src/explain/semi_persistent.rs @@ -0,0 +1,73 @@ +use crate::explain::{Connection, Explain}; +use crate::raw::EGraphResidual; +use crate::{Id, Language}; + +pub(super) type UndoLog = Option>; + +#[derive(Default, Clone, Debug)] +#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))] +pub(crate) struct PushInfo(usize); + +impl Explain { + pub(super) fn undo_log_union(&mut self, node: Id) { + if let Some(x) = &mut self.undo_log { + x.push(node) + } + } + pub(crate) fn enable_undo_log(&mut self) { + assert_eq!(self.explainfind.len(), 0); + self.undo_log = Some(Vec::new()); + } + + pub(crate) fn disable_undo_log(&mut self) { + self.undo_log = None + } + + pub(crate) fn push(&self) -> PushInfo { + PushInfo(self.undo_log.as_ref().unwrap().len()) + } + + pub(crate) fn pop( + &mut self, + info: PushInfo, + number_of_uncanon_nodes: usize, + egraph: &EGraphResidual, + ) { + for id in self.undo_log.as_mut().unwrap().drain(info.0..).rev() { + let node1 = &mut self.explainfind[usize::from(id)]; + let id2 = node1.neighbors.pop().unwrap().next; + if node1.parent_connection.next == id2 { + node1.parent_connection = Connection::end(id); + } + let node2 = &mut self.explainfind[usize::from(id2)]; + let id1 = node2.neighbors.pop().unwrap().next; + assert_eq!(id, id1); + if node2.parent_connection.next == id1 { + node2.parent_connection = Connection::end(id2); + } + } + self.explainfind.truncate(number_of_uncanon_nodes); + // We can't easily undo memoize operations, so we just clear them + self.shortest_explanation_memo.clear(); + dbg!(egraph.dump_uncanonical()); + dbg!(&self.uncanon_memo); + for (id, node) in egraph.uncanonical_nodes().skip(number_of_uncanon_nodes) { + if *self.uncanon_memo.get(node).unwrap() == id { + self.uncanon_memo.remove(node).unwrap(); + } + } + } + + pub(crate) fn clear_memo(&mut self) { + self.shortest_explanation_memo.clear() + } + + pub(crate) fn clear(&mut self) { + if let Some(v) = &mut self.undo_log { + v.clear() + } + self.explainfind.clear(); + self.uncanon_memo.clear(); + self.shortest_explanation_memo.clear(); + } +} diff --git a/src/language.rs b/src/language.rs index 6414c63a..40072358 100644 --- a/src/language.rs +++ b/src/language.rs @@ -698,7 +698,7 @@ assert_eq!(runner.egraph.find(runner.roots[0]), runner.egraph.find(just_foo)); */ pub trait Analysis: Sized { /// The per-[`EClass`] data for this analysis. - type Data: Debug; + type Data: Debug + Clone; /// Makes a new [`Analysis`] data for a given e-node. /// @@ -761,6 +761,14 @@ pub trait Analysis: Sized { /// `Analysis::merge` when unions are performed. #[allow(unused_variables)] fn modify(egraph: &mut EGraph, id: Id) {} + + /// A hook called at the start of [`EGraph::push`] + #[allow(unused_variables)] + fn pre_push(egraph: &mut EGraph) {} + + /// A hook called at the end of [`EGraph::pop_n`] + #[allow(unused_variables)] + fn post_pop_n(egraph: &mut EGraph, n: usize) {} } impl Analysis for () { diff --git a/src/test.rs b/src/test.rs index 10815d66..44e185db 100644 --- a/src/test.rs +++ b/src/test.rs @@ -3,6 +3,8 @@ These are not considered part of the public api. */ +use std::cell::RefCell; +use std::rc::Rc; use std::{fmt::Display, fs::File, io::Write, path::PathBuf}; use saturating::Saturating; @@ -37,11 +39,19 @@ pub fn test_runner( should_check: bool, ) where L: Language + Display + FromOp + 'static, - A: Analysis + Default, + A: Analysis + Default + Clone + 'static, + A::Data: Default + Clone, { let _ = env_logger::builder().is_test(true).try_init(); let mut runner = runner.unwrap_or_default(); + let nodes: Vec<_> = runner + .egraph + .uncanonical_nodes() + .map(|(_, n)| n.clone()) + .collect(); + runner.egraph.clear(); + if let Some(lim) = env_var("EGG_NODE_LIMIT") { runner = runner.with_node_limit(lim) } @@ -57,6 +67,22 @@ pub fn test_runner( runner = runner.with_explanations_enabled(); } + let history = Rc::new(RefCell::new(Vec::new())); + let history2 = history.clone(); + // Test push if feature is on + if cfg!(feature = "test-push-pop") { + runner.egraph = runner.egraph.with_push_pop_enabled(); + runner = runner.with_hook(move |runner| { + runner.egraph.push(); + history2.borrow_mut().push(EGraph::clone(&runner.egraph)); + Ok(()) + }); + } + + for node in nodes { + runner.egraph.add_uncanonical(node); + } + runner = runner.with_expr(&start); // NOTE this is a bit of hack, we rely on the fact that the // initial root is the last expr added by the runner. We can't @@ -118,6 +144,31 @@ pub fn test_runner( if let Some(check_fn) = check_fn { check_fn(runner) + } else if cfg!(feature = "test-push-pop") { + let mut egraph = runner.egraph; + let _ = runner.hooks; + for mut old in history.borrow().iter().cloned().rev() { + dbg!(old.total_size()); + egraph.pop(); + assert_eq!( + format!("{:#?}", old.dump_uncanonical()), + format!("{:#?}", egraph.dump_uncanonical()), + ); + assert_eq!(format!("{:#?}", old), format!("{:#?}", egraph)); + assert_eq!( + format!("{:#?}", old.dump()), + format!("{:#?}", egraph.dump()), + ); + if let Some(explain) = &mut egraph.explain { + let old_explain = old.explain.as_mut().unwrap(); + old_explain.clear_memo(); + for class in egraph.inner.classes_mut().0 { + explain.test_mk_root(class.id); + old_explain.test_mk_root(class.id); + } + assert_eq!(format!("{:#?}", old_explain), format!("{:#?}", explain)); + } + } } } } diff --git a/src/util.rs b/src/util.rs index 808be648..61417210 100644 --- a/src/util.rs +++ b/src/util.rs @@ -54,7 +54,7 @@ mod hashmap { pub(crate) type HashMap = super::IndexMap; pub(crate) type HashSet = super::IndexSet; - pub(crate) type Entry<'a, K, V> = indexmap::map::Entry; + pub(crate) type Entry<'a, K, V> = indexmap::map::Entry<'a, K, V>; } #[cfg(not(feature = "deterministic"))] mod hashmap { @@ -174,4 +174,9 @@ where debug_assert_eq!(r, self.set.is_empty()); r } + + pub fn clear(&mut self) { + self.queue.clear(); + self.set.clear(); + } } diff --git a/tests/lambda.rs b/tests/lambda.rs index 80ea4fbd..0ec3acb8 100644 --- a/tests/lambda.rs +++ b/tests/lambda.rs @@ -33,10 +33,10 @@ impl Lambda { type EGraph = egg::EGraph; -#[derive(Default)] +#[derive(Default, Clone)] struct LambdaAnalysis; -#[derive(Debug)] +#[derive(Debug, Clone, Default)] struct Data { free: HashSet, constant: Option<(Lambda, PatternAst)>, diff --git a/tests/math.rs b/tests/math.rs index a0d8c07a..3d9dc5db 100644 --- a/tests/math.rs +++ b/tests/math.rs @@ -45,7 +45,7 @@ impl egg::CostFunction for MathCostFn { } } -#[derive(Default)] +#[derive(Default, Clone)] pub struct ConstantFold; impl Analysis for ConstantFold { type Data = Option<(Constant, PatternAst)>; @@ -102,6 +102,14 @@ impl Analysis for ConstantFold { egraph[id].assert_unique_leaves(); } } + + fn post_pop_n(egraph: &mut EGraph, _: usize) { + for class in egraph.classes_mut() { + if class.data.is_some() { + class.nodes.retain(|x| x.is_leaf()) + } + } + } } fn is_const_or_distinct_var(v: &str, w: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool { @@ -275,7 +283,6 @@ egg::test_fn! { .with_time_limit(std::time::Duration::from_secs(10)) .with_iter_limit(60) .with_node_limit(100_000) - .with_explanations_enabled() // HACK this needs to "see" the end expression .with_expr(&"(* x (- (* 3 x) 14))".parse().unwrap()), "(d x (- (pow x 3) (* 7 (pow x 2))))" From e053f95ae19341f13158d238349b37fee1cb2660 Mon Sep 17 00:00:00 2001 From: dewert99 Date: Mon, 12 Feb 2024 12:30:40 -0800 Subject: [PATCH 11/20] Lifted semi-persistence to `EGraph` --- src/explain/semi_persistent.rs | 2 -- src/test.rs | 1 - 2 files changed, 3 deletions(-) diff --git a/src/explain/semi_persistent.rs b/src/explain/semi_persistent.rs index bcabcd94..b8309b95 100644 --- a/src/explain/semi_persistent.rs +++ b/src/explain/semi_persistent.rs @@ -49,8 +49,6 @@ impl Explain { self.explainfind.truncate(number_of_uncanon_nodes); // We can't easily undo memoize operations, so we just clear them self.shortest_explanation_memo.clear(); - dbg!(egraph.dump_uncanonical()); - dbg!(&self.uncanon_memo); for (id, node) in egraph.uncanonical_nodes().skip(number_of_uncanon_nodes) { if *self.uncanon_memo.get(node).unwrap() == id { self.uncanon_memo.remove(node).unwrap(); diff --git a/src/test.rs b/src/test.rs index 44e185db..1a730930 100644 --- a/src/test.rs +++ b/src/test.rs @@ -148,7 +148,6 @@ pub fn test_runner( let mut egraph = runner.egraph; let _ = runner.hooks; for mut old in history.borrow().iter().cloned().rev() { - dbg!(old.total_size()); egraph.pop(); assert_eq!( format!("{:#?}", old.dump_uncanonical()), From 420e527df848f3e0b2765d360b91c81380568888 Mon Sep 17 00:00:00 2001 From: dewert99 Date: Mon, 12 Feb 2024 14:14:19 -0800 Subject: [PATCH 12/20] Fixed bug and switched `UniqueQueue` to use `util::BuildHasher` --- src/explain/semi_persistent.rs | 6 +++++- src/raw/egraph.rs | 4 +++- src/util.rs | 2 +- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/explain/semi_persistent.rs b/src/explain/semi_persistent.rs index b8309b95..744e8243 100644 --- a/src/explain/semi_persistent.rs +++ b/src/explain/semi_persistent.rs @@ -49,7 +49,11 @@ impl Explain { self.explainfind.truncate(number_of_uncanon_nodes); // We can't easily undo memoize operations, so we just clear them self.shortest_explanation_memo.clear(); - for (id, node) in egraph.uncanonical_nodes().skip(number_of_uncanon_nodes) { + for (id, node) in egraph + .uncanonical_nodes() + .skip(number_of_uncanon_nodes) + .rev() + { if *self.uncanon_memo.get(node).unwrap() == id { self.uncanon_memo.remove(node).unwrap(); } diff --git a/src/raw/egraph.rs b/src/raw/egraph.rs index dcbb775d..55cdc9eb 100644 --- a/src/raw/egraph.rs +++ b/src/raw/egraph.rs @@ -138,7 +138,9 @@ impl EGraphResidual { /// Returns an iterator over the uncanonical ids in the egraph and the node /// that would be obtained by calling [`id_to_node`](EGraphResidual::id_to_node) on each of them - pub fn uncanonical_nodes(&self) -> impl ExactSizeIterator { + pub fn uncanonical_nodes( + &self, + ) -> impl ExactSizeIterator + DoubleEndedIterator { self.nodes .iter() .enumerate() diff --git a/src/util.rs b/src/util.rs index 61417210..90eadece 100644 --- a/src/util.rs +++ b/src/util.rs @@ -128,7 +128,7 @@ pub(crate) struct UniqueQueue where T: Eq + std::hash::Hash + Clone, { - set: hashbrown::HashSet, + set: hashbrown::HashSet, queue: std::collections::VecDeque, } From 9da2e487e740b61799b97253499af4e6648076ce Mon Sep 17 00:00:00 2001 From: dewert99 Date: Mon, 12 Feb 2024 14:27:11 -0800 Subject: [PATCH 13/20] Removed ahash dependency --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index e304dc9a..59545916 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,7 @@ rust-version = "1.63.0" [dependencies] env_logger = { version = "0.9.0", default-features = false } fxhash = "0.2.1" -hashbrown = { version = "0.14.3", default-features = false, features = ["inline-more", "ahash"] } +hashbrown = { version = "0.14.3", default-features = false, features = ["inline-more"] } indexmap = "1.8.1" instant = "0.1.12" log = "0.4.17" From 8649cdbc32958ac189f4800695aca580aa762c46 Mon Sep 17 00:00:00 2001 From: dewert99 Date: Mon, 12 Feb 2024 15:26:03 -0800 Subject: [PATCH 14/20] Added `DHashMap` tests --- src/raw/dhashmap.rs | 55 +++++++++++++++++++++++++++++++++++-- src/raw/semi_persistent1.rs | 3 +- src/raw/semi_persistent2.rs | 3 +- 3 files changed, 55 insertions(+), 6 deletions(-) diff --git a/src/raw/dhashmap.rs b/src/raw/dhashmap.rs index a4ef7423..ec847c42 100644 --- a/src/raw/dhashmap.rs +++ b/src/raw/dhashmap.rs @@ -88,8 +88,9 @@ impl DHashMap { } #[inline] - pub(super) fn remove_nth(&mut self, hash: u64, idx: DHMIdx) { - debug_assert_eq!(self.data.len() as DHMIdx - 1, idx); + pub(super) fn remove_nth(&mut self, hash: u64, idx: usize) { + debug_assert_eq!(self.data.len() - 1, idx); + let idx = idx as DHMIdx; match self.data.find_entry(hash, |x| x.2 == idx) { Ok(x) => x.remove(), Err(_) => unreachable!(), @@ -147,3 +148,53 @@ impl Debug for DHashMap { f.debug_map().entries(self).finish() } } + +#[cfg(test)] +mod test { + use crate::raw::dhashmap::DHashMap; + use std::fmt::Debug; + use std::hash::{Hash, Hasher}; + + #[derive(Eq, PartialEq, Debug, Clone)] + struct BadHash(T); + + impl Hash for BadHash { + fn hash(&self, _: &mut H) {} + } + + fn test(arr: [(K, V); N]) { + let mut map: DHashMap = DHashMap::default(); + let mut hashes = Vec::new(); + for (k, v) in arr.iter().cloned() { + let (r, hash) = map.get(&k); + assert!(r.is_none()); + hashes.push(hash); + map.insert_with_hash(hash, k, v) + } + assert_eq!(map.len(), N); + for (i, (k, v)) in arr.iter().enumerate().rev() { + let (r, hash) = map.get(k); + assert_eq!(Some(hash), hashes.pop()); + assert_eq!(r, Some(v)); + map.remove_nth(hash, i); + let (r2, hash2) = map.get(k); + assert_eq!(hash2, hash); + assert_eq!(r2, None); + assert_eq!(map.len(), i); + } + } + + #[test] + fn test_base() { + test([('a', "a"), ('b', "b"), ('c', "c")]) + } + + #[test] + fn test_bad_hash() { + test([ + (BadHash('a'), "a"), + (BadHash('b'), "b"), + (BadHash('c'), "c"), + ]) + } +} diff --git a/src/raw/semi_persistent1.rs b/src/raw/semi_persistent1.rs index 5ec5f2c9..9162b7ae 100644 --- a/src/raw/semi_persistent1.rs +++ b/src/raw/semi_persistent1.rs @@ -1,4 +1,3 @@ -use crate::raw::dhashmap::DHMIdx; use crate::raw::{AsUnwrap, RawEClass, RawEGraph, Sealed, UndoLogT, UnionFind}; use crate::{Id, Language}; use std::fmt::Debug; @@ -156,7 +155,7 @@ impl> RawEGraph { let memo_log = &mut self.undo_log.as_mut_unwrap().memo_log; let len = memo_log.len(); for (hash, idx) in memo_log.drain(old_count..).zip(old_count..len).rev() { - self.residual.memo.remove_nth(hash, idx as DHMIdx); + self.residual.memo.remove_nth(hash, idx); } } diff --git a/src/raw/semi_persistent2.rs b/src/raw/semi_persistent2.rs index 5d715424..982b8a5d 100644 --- a/src/raw/semi_persistent2.rs +++ b/src/raw/semi_persistent2.rs @@ -1,4 +1,3 @@ -use crate::raw::dhashmap::DHMIdx; use crate::raw::{AsUnwrap, RawEClass, RawEGraph, Sealed, UndoLogT}; use crate::util::{Entry, HashSet}; use crate::{Id, Language}; @@ -175,7 +174,7 @@ impl> RawEGraph { let memo_log = &mut self.undo_log.as_mut_unwrap().memo_log; let len = memo_log.len(); for (hash, idx) in memo_log.drain(old_count..).zip(old_count..len).rev() { - self.residual.memo.remove_nth(hash, idx as DHMIdx); + self.residual.memo.remove_nth(hash, idx); } } From 34091d09dd2abc741f27b06bf568caa239507174 Mon Sep 17 00:00:00 2001 From: dewert99 Date: Mon, 12 Feb 2024 15:50:09 -0800 Subject: [PATCH 15/20] Clippy --- src/eclass.rs | 1 + src/egraph.rs | 6 +----- src/explain.rs | 2 +- src/raw/dhashmap.rs | 11 +++++------ src/raw/semi_persistent2.rs | 2 +- 5 files changed, 9 insertions(+), 13 deletions(-) diff --git a/src/eclass.rs b/src/eclass.rs index e235d58e..fdb47f0b 100644 --- a/src/eclass.rs +++ b/src/eclass.rs @@ -35,6 +35,7 @@ impl EClass { } /// Returns the number of enodes in this eclass. + #[allow(clippy::len_without_is_empty)] // https://github.com/rust-lang/rust-clippy/issues/11165 pub fn len(&self) -> usize { self.nodes.len() } diff --git a/src/egraph.rs b/src/egraph.rs index e35e3787..9ff10c1e 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -707,11 +707,7 @@ impl> EGraph { enode, |this, existing_id, enode| { if let Some(explain) = this.explain.as_mut() { - if let Some(existing_id) = explain.uncanon_memo.get(enode) { - Some(*existing_id) - } else { - None - } + explain.uncanon_memo.get(enode).copied() } else { Some(existing_id) } diff --git a/src/explain.rs b/src/explain.rs index c25f6ba2..76ae5e4e 100644 --- a/src/explain.rs +++ b/src/explain.rs @@ -1074,7 +1074,7 @@ impl Explain { equalities } - pub(crate) fn with_raw_egraph<'a, X>(&'a mut self, raw: X) -> ExplainWith<'a, L, X> { + pub(crate) fn with_raw_egraph(&mut self, raw: X) -> ExplainWith<'_, L, X> { ExplainWith { explain: self, raw } } } diff --git a/src/raw/dhashmap.rs b/src/raw/dhashmap.rs index ec847c42..bc2ac8eb 100644 --- a/src/raw/dhashmap.rs +++ b/src/raw/dhashmap.rs @@ -53,21 +53,19 @@ fn hash_one(hasher: &impl BuildHasher, hash: impl Hash) -> u64 { } #[inline] -fn eq<'a, K: Eq, V>(k: &'a K) -> impl Fn(&(K, V, DHMIdx)) -> bool + 'a { +fn eq(k: &K) -> impl Fn(&(K, V, DHMIdx)) -> bool + '_ { move |x| &x.0 == k } #[inline] -fn hasher_fn<'a, K: Hash, V, S: BuildHasher>( - hasher: &'a S, -) -> impl Fn(&(K, V, DHMIdx)) -> u64 + 'a { +fn hasher_fn(hasher: &S) -> impl Fn(&(K, V, DHMIdx)) -> u64 + '_ { move |x| hash_one(hasher, &x.0) } impl DHashMap { #[inline] pub(super) fn entry(&mut self, k: K) -> (Entry<'_, K, V>, u64) { - let hash = hash_one(&mut self.hasher, &k); + let hash = hash_one(&self.hasher, &k); let len = self.data.len() as DHMIdx; let entry = match self.data.entry(hash, eq(&k), hasher_fn(&self.hasher)) { hash_table::Entry::Occupied(entry) => Entry::Occupied((k, &mut entry.into_mut().1)), @@ -136,7 +134,7 @@ impl FromIterator<(K, V)> for DHashMa fn from_iter>(iter: T) -> Self { let mut res = Self::default(); iter.into_iter().for_each(|(k, v)| { - let hash = hash_one(&mut res.hasher, &k); + let hash = hash_one(&res.hasher, &k); res.insert_with_hash(hash, k, v) }); res @@ -158,6 +156,7 @@ mod test { #[derive(Eq, PartialEq, Debug, Clone)] struct BadHash(T); + #[allow(clippy::derive_hash_xor_eq)] // We explicitly want to test a bad implementation impl Hash for BadHash { fn hash(&self, _: &mut H) {} } diff --git a/src/raw/semi_persistent2.rs b/src/raw/semi_persistent2.rs index 982b8a5d..3e321e9a 100644 --- a/src/raw/semi_persistent2.rs +++ b/src/raw/semi_persistent2.rs @@ -166,7 +166,7 @@ impl> RawEGraph { self.pop_memo2(memo_log_count); self.pop_parents2(pop_parents_count, node_count); self.pop_unions2(union_count, node_count, state, clear, mk_data, handle_eqv); - self.pop_nodes2(usize::from(node_count)); + self.pop_nodes2(node_count); } fn pop_memo2(&mut self, old_count: usize) { From 1af40397a9d2be1c804a4ba795a997456902d821 Mon Sep 17 00:00:00 2001 From: dewert99 Date: Mon, 12 Feb 2024 15:58:40 -0800 Subject: [PATCH 16/20] Fixed error messages --- src/egraph.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/egraph.rs b/src/egraph.rs index 9ff10c1e..d4cf717c 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -1161,7 +1161,7 @@ where "`push` can only be called on clean egraphs" ); if !self.inner.has_undo_log() { - panic!("Use egraph.with_push_pop() before running to call push"); + panic!("Use egraph.with_push_pop_enabled() before running to call push"); } N::pre_push(self); let exp_push_info = self.explain.as_ref().map(Explain::push).unwrap_or_default(); @@ -1195,7 +1195,7 @@ where /// Equivalent to calling [`pop`](EGraph::pop) `n` times but possibly more efficient pub fn pop_n(&mut self, n: usize) { if !self.inner.has_undo_log() { - panic!("Use egraph.with_push_pop() before running to call pop"); + panic!("Use egraph.with_push_pop_enabled() before running to call pop"); } if n > self.push_log.len() { self.clear() From 2c1c95f558355397eae0b8327bdd3903be4fc8da Mon Sep 17 00:00:00 2001 From: dewert99 Date: Fri, 8 Mar 2024 10:28:55 -0800 Subject: [PATCH 17/20] Fixed `pop` to clear `pending` --- src/raw/semi_persistent1.rs | 1 + src/raw/semi_persistent2.rs | 1 + 2 files changed, 2 insertions(+) diff --git a/src/raw/semi_persistent1.rs b/src/raw/semi_persistent1.rs index 9162b7ae..e71b3dde 100644 --- a/src/raw/semi_persistent1.rs +++ b/src/raw/semi_persistent1.rs @@ -145,6 +145,7 @@ impl> RawEGraph { memo_log_count, pop_parents_count, } = info; + self.pending.clear(); self.pop_memo1(memo_log_count); self.pop_unions1(union_count, pop_parents_count, split); self.pop_nodes1(node_count); diff --git a/src/raw/semi_persistent2.rs b/src/raw/semi_persistent2.rs index 3e321e9a..31639515 100644 --- a/src/raw/semi_persistent2.rs +++ b/src/raw/semi_persistent2.rs @@ -163,6 +163,7 @@ impl> RawEGraph { memo_log_count, pop_parents_count, } = info; + self.pending.clear(); self.pop_memo2(memo_log_count); self.pop_parents2(pop_parents_count, node_count); self.pop_unions2(union_count, node_count, state, clear, mk_data, handle_eqv); From da807b00a3579820a1885e74d9c8e5f83a1afeab Mon Sep 17 00:00:00 2001 From: dewert99 Date: Mon, 11 Mar 2024 16:33:25 -0700 Subject: [PATCH 18/20] Added method to access union find without path compression from undo-log --- src/raw/semi_persistent1.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/raw/semi_persistent1.rs b/src/raw/semi_persistent1.rs index e71b3dde..eddf5526 100644 --- a/src/raw/semi_persistent1.rs +++ b/src/raw/semi_persistent1.rs @@ -151,6 +151,11 @@ impl> RawEGraph { self.pop_nodes1(node_count); } + /// Return the direct parent from the union find without path compression + pub fn find_direct_parent(&self, id: Id) -> Id { + self.undo_log.as_unwrap().undo_find.parent(id) + } + fn pop_memo1(&mut self, old_count: usize) { assert!(self.memo.len() >= old_count); let memo_log = &mut self.undo_log.as_mut_unwrap().memo_log; From fb07f3bcc7466d803099eb3dfc4fb9dab79641a8 Mon Sep 17 00:00:00 2001 From: dewert99 Date: Wed, 20 Mar 2024 16:41:39 -0700 Subject: [PATCH 19/20] Make `raw_union` more flexible and add a fallible `try_raw_rebuild` --- src/egraph.rs | 25 +++++++-------- src/raw/egraph.rs | 78 +++++++++++++++++++++++++++++++++++++---------- 2 files changed, 75 insertions(+), 28 deletions(-) diff --git a/src/egraph.rs b/src/egraph.rs index 0eafb9d4..3f35eddb 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -804,20 +804,21 @@ impl> EGraph { self.clean = false; let mut new_root = None; - self.inner - .raw_union(enode_id1, enode_id2, |class1, id1, p1, class2, _, p2| { - new_root = Some(id1); + self.inner.raw_union(enode_id1, enode_id2, |info| { + new_root = Some(info.id1); - let did_merge = self.analysis.merge(&mut class1.data, class2.data); - if did_merge.0 { - self.analysis_pending.extend(p1); - } - if did_merge.1 { - self.analysis_pending.extend(p2); - } + let did_merge = self.analysis.merge(&mut info.data1.data, info.data2.data); + if did_merge.0 { + self.analysis_pending + .extend(info.parents1.into_iter().copied()); + } + if did_merge.1 { + self.analysis_pending + .extend(info.parents2.into_iter().copied()); + } - concat_vecs(&mut class1.nodes, class2.nodes); - }); + concat_vecs(&mut info.data1.nodes, info.data2.nodes); + }); if let Some(id) = new_root { if let Some(explain) = &mut self.explain { explain.union(enode_id1, enode_id2, rule.unwrap(), any_new_rhs); diff --git a/src/raw/egraph.rs b/src/raw/egraph.rs index d9e4592e..383b1361 100644 --- a/src/raw/egraph.rs +++ b/src/raw/egraph.rs @@ -1,4 +1,5 @@ use crate::{raw::RawEClass, Dot, HashMap, Id, Language, RecExpr, UnionFind}; +use std::convert::Infallible; use std::ops::{Deref, DerefMut}; use std::{ borrow::BorrowMut, @@ -426,6 +427,26 @@ impl RawEGraph { } } +/// Information for [`RawEGraph::raw_union`] callback +#[non_exhaustive] +pub struct MergeInfo<'a, D: 'a> { + /// id that will be the root for the newly merged eclass + pub id1: Id, + /// data associated with `id1` that can be modified to reflect `data2` being merged into it + pub data1: &'a mut D, + /// parents of `id1` before the merge + pub parents1: &'a [Id], + /// id that used to be a root but will now be in `id1` eclass + pub id2: Id, + /// data associated with `id2` + pub data2: D, + /// parents of `id2` before the merge + pub parents2: &'a [Id], + /// true if `id1` was the root of the second id passed to [`RawEGraph::raw_union`] + /// false if `id1` was the root of the first id passed to [`RawEGraph::raw_union`] + pub swapped_ids: bool, +} + impl RawEGraph { /// Adds `enode` to a [`RawEGraph`] contained within a wrapper type `T` /// @@ -524,7 +545,7 @@ impl RawEGraph { &mut self, enode_id1: Id, enode_id2: Id, - merge: impl FnOnce(&mut D, Id, Parents<'_>, D, Id, Parents<'_>), + merge: impl FnOnce(MergeInfo<'_, D>), ) { let mut id1 = self.find_mut(enode_id1); let mut id2 = self.find_mut(enode_id2); @@ -534,7 +555,9 @@ impl RawEGraph { // make sure class2 has fewer parents let class1_parents = self.classes[&id1].parents.len(); let class2_parents = self.classes[&id2].parents.len(); + let mut swapped = false; if class1_parents < class2_parents { + swapped = true; std::mem::swap(&mut id1, &mut id2); } @@ -545,22 +568,22 @@ impl RawEGraph { let class2 = self.classes.remove(&id2).unwrap(); let class1 = self.classes.get_mut(&id1).unwrap(); assert_eq!(id1, class1.id); - let (p1, p2) = (Parents(&class1.parents), Parents(&class2.parents)); - merge( - &mut class1.raw_data, - class1.id, - p1, - class2.raw_data, - class2.id, - p2, - ); + let info = MergeInfo { + id1: class1.id, + data1: &mut class1.raw_data, + parents1: &class1.parents, + id2: class2.id, + data2: class2.raw_data, + parents2: &class2.parents, + swapped_ids: swapped, + }; + merge(info); self.pending.extend(&class2.parents); class1.parents.extend(class2.parents); } - #[inline] /// Rebuild to [`RawEGraph`] to restore congruence closure /// /// ## Parameters @@ -576,14 +599,31 @@ impl RawEGraph { /// In order to be correct `perform_union` should call [`raw_union`](RawEGraph::raw_union) /// /// ### `handle_pending` - /// Called with the uncanonical id of each enode whose canonical children have changned, along with a canonical + /// Called with the uncanonical id of each enode whose canonical children have changed, along with a canonical /// version of it + #[inline] pub fn raw_rebuild( outer: &mut T, get_self: impl Fn(&mut T) -> &mut Self, mut perform_union: impl FnMut(&mut T, Id, Id), - mut handle_pending: impl FnMut(&mut T, Id, &L), + handle_pending: impl FnMut(&mut T, Id, &L), ) { + let _: Result<(), Infallible> = RawEGraph::try_raw_rebuild( + outer, + get_self, + |this, id1, id2| Ok(perform_union(this, id1, id2)), + handle_pending, + ); + } + + /// Similar to [`raw_rebuild`] but allows for the union operation to fail and abort the rebuild + #[inline] + pub fn try_raw_rebuild( + outer: &mut T, + get_self: impl Fn(&mut T) -> &mut Self, + mut perform_union: impl FnMut(&mut T, Id, Id) -> Result<(), E>, + mut handle_pending: impl FnMut(&mut T, Id, &L), + ) -> Result<(), E> { loop { let this = get_self(outer); if let Some(class) = this.pending.pop() { @@ -591,10 +631,16 @@ impl RawEGraph { node.update_children(|id| this.find_mut(id)); handle_pending(outer, class, &node); if let Some(memo_class) = get_self(outer).residual.memo.insert(node, class) { - perform_union(outer, memo_class, class); + match perform_union(outer, memo_class, class) { + Ok(()) => {} + Err(e) => { + get_self(outer).pending.push(class); + return Err(e); + } + } } } else { - break; + break Ok(()); } } } @@ -638,7 +684,7 @@ impl RawEGraph { /// Simplified version of [`raw_union`](RawEGraph::raw_union) for egraphs without eclass data pub fn union(&mut self, id1: Id, id2: Id) -> bool { let mut unioned = false; - self.raw_union(id1, id2, |_, _, _, _, _, _| { + self.raw_union(id1, id2, |_| { unioned = true; }); unioned From 91c1a10926c950e02841df6dcb84a9bc48d226e4 Mon Sep 17 00:00:00 2001 From: dewert99 Date: Thu, 28 Mar 2024 15:35:39 -0700 Subject: [PATCH 20/20] Bump symbol_table --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 59545916..f87e9493 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,7 +19,7 @@ indexmap = "1.8.1" instant = "0.1.12" log = "0.4.17" smallvec = { version = "1.8.0", features = ["union", "const_generics"] } -symbol_table = { version = "0.2.0", features = ["global"] } +symbol_table = { version = "0.3.0", features = ["global"] } symbolic_expressions = "5.0.3" thiserror = "1.0.31"