From 1df8e8bee04c191f6802c88ea4c1044bd86fb3ab Mon Sep 17 00:00:00 2001 From: Adrian Lehmann Date: Fri, 7 Mar 2025 15:24:26 -0600 Subject: [PATCH] Add multi-condition rewrites --- src/rewrite.rs | 99 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 99 insertions(+) diff --git a/src/rewrite.rs b/src/rewrite.rs index eae5755e..0ceb39e6 100644 --- a/src/rewrite.rs +++ b/src/rewrite.rs @@ -538,6 +538,105 @@ where } } +#[derive(Clone)] +pub struct AndCondition +where + L: Language, + N: Analysis, +{ + pub c1: Arc + Send + Sync>, + pub c2: Arc + Send + Sync>, +} + +impl> Condition for AndCondition { + #[inline] + fn check(&self, egraph: &mut EGraph, id: Id, subst: &Subst) -> bool { + self.c1.check(egraph, id, subst) && self.c2.check(egraph, id, subst) + } + fn vars(&self) -> Vec { + let mut vars = self.c1.vars(); + vars.extend(self.c2.vars()); + vars + } +} + +#[derive(Clone)] +pub struct OrCondition +where + L: Language, + N: Analysis, +{ + pub c1: Arc + Send + Sync>, + pub c2: Arc + Send + Sync>, +} + +impl> Condition for OrCondition { + #[inline] + fn check(&self, egraph: &mut EGraph, id: Id, subst: &Subst) -> bool { + self.c1.check(egraph, id, subst) || self.c2.check(egraph, id, subst) + } + fn vars(&self) -> Vec { + let mut vars = self.c1.vars(); + vars.extend(self.c2.vars()); + vars + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct TrueCondition {} + +impl> Condition for TrueCondition { + #[inline(always)] + fn check(&self, _egraph: &mut EGraph, _id: Id, _subst: &Subst) -> bool { + true + } + #[inline(always)] + fn vars(&self) -> Vec { + vec![] + } +} +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct FalseCondition {} + +impl> Condition for FalseCondition { + #[inline(always)] + fn check(&self, _egraph: &mut EGraph, _id: Id, _subst: &Subst) -> bool { + false + } + #[inline(always)] + fn vars(&self) -> Vec { + vec![] + } +} + +pub fn any( + conds: Vec + Send + Sync>>, +) -> Arc + Send + Sync> +where + L: Language + 'static, + N: Analysis + 'static, +{ + conds + .into_iter() + .fold(Arc::new(FalseCondition {}), |acc, c| { + Arc::new(OrCondition { c1: acc, c2: c }) + }) +} + +pub fn all( + conds: Vec + Send + Sync>>, +) -> Arc + Send + Sync> +where + L: Language + 'static, + N: Analysis + 'static, +{ + conds + .into_iter() + .fold(Arc::new(TrueCondition {}), |acc, c| { + Arc::new(AndCondition { c1: acc, c2: c }) + }) +} + #[cfg(test)] mod tests {