Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 99 additions & 0 deletions src/rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,105 @@ where
}
}

#[derive(Clone)]
pub struct AndCondition<L, N>
where
L: Language,
N: Analysis<L>,
{
pub c1: Arc<dyn Condition<L, N> + Send + Sync>,
pub c2: Arc<dyn Condition<L, N> + Send + Sync>,
}

impl<L: Language, N: Analysis<L>> Condition<L, N> for AndCondition<L, N> {
#[inline]
fn check(&self, egraph: &mut EGraph<L, N>, id: Id, subst: &Subst) -> bool {
self.c1.check(egraph, id, subst) && self.c2.check(egraph, id, subst)
}
fn vars(&self) -> Vec<Var> {
let mut vars = self.c1.vars();
vars.extend(self.c2.vars());
vars
}
}

#[derive(Clone)]
pub struct OrCondition<L, N>
where
L: Language,
N: Analysis<L>,
{
pub c1: Arc<dyn Condition<L, N> + Send + Sync>,
pub c2: Arc<dyn Condition<L, N> + Send + Sync>,
}

impl<L: Language, N: Analysis<L>> Condition<L, N> for OrCondition<L, N> {
#[inline]
fn check(&self, egraph: &mut EGraph<L, N>, id: Id, subst: &Subst) -> bool {
self.c1.check(egraph, id, subst) || self.c2.check(egraph, id, subst)
}
fn vars(&self) -> Vec<Var> {
let mut vars = self.c1.vars();
vars.extend(self.c2.vars());
vars
}
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TrueCondition {}

impl<L: Language, N: Analysis<L>> Condition<L, N> for TrueCondition {
#[inline(always)]
fn check(&self, _egraph: &mut EGraph<L, N>, _id: Id, _subst: &Subst) -> bool {
true
}
#[inline(always)]
fn vars(&self) -> Vec<Var> {
vec![]
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FalseCondition {}

impl<L: Language, N: Analysis<L>> Condition<L, N> for FalseCondition {
#[inline(always)]
fn check(&self, _egraph: &mut EGraph<L, N>, _id: Id, _subst: &Subst) -> bool {
false
}
#[inline(always)]
fn vars(&self) -> Vec<Var> {
vec![]
}
}

pub fn any<L, N>(
conds: Vec<Arc<dyn Condition<L, N> + Send + Sync>>,
) -> Arc<dyn Condition<L, N> + Send + Sync>
where
L: Language + 'static,
N: Analysis<L> + 'static,
{
conds
.into_iter()
.fold(Arc::new(FalseCondition {}), |acc, c| {
Arc::new(OrCondition { c1: acc, c2: c })
})
}

pub fn all<L, N>(
conds: Vec<Arc<dyn Condition<L, N> + Send + Sync>>,
) -> Arc<dyn Condition<L, N> + Send + Sync>
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could also easily change the return type here into Arc<AndCondition<L,N>> if we wanted to (same for any with OrCondition) by making the base an and of true conditions. Little less efficient but nicer types

where
L: Language + 'static,
N: Analysis<L> + 'static,
{
conds
.into_iter()
.fold(Arc::new(TrueCondition {}), |acc, c| {
Arc::new(AndCondition { c1: acc, c2: c })
})
}

#[cfg(test)]
mod tests {

Expand Down
Loading