Skip to content
This repository was archived by the owner on Apr 25, 2024. It is now read-only.

Make anti-unification more sort-aware #598

Merged
merged 20 commits into from
Aug 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
28157b5
Add least_common_supersort(), use in sort(), use sort() in anti_unify…
nwatson22 Aug 10, 2023
d5895f9
Merge 28157b53dac3a613c8681df13907b5cbd9458d9d into 97a9a21bb10f163e0…
nwatson22 Aug 10, 2023
a8222ac
Set Version: 0.1.412
rv-auditor Aug 10, 2023
b2ad557
Move anti-unification into cterm.py and make anti-unification of cter…
nwatson22 Aug 11, 2023
778cd00
Merge branch 'noah/anti-unify-sort-fix' of https://github.com/runtime…
nwatson22 Aug 11, 2023
fee00bb
Merge master into branch
nwatson22 Aug 14, 2023
b41f7a3
Merge fee00bbc56924ae24af85126efda22d0007ef2e9 into fd06d71322d85f41e…
nwatson22 Aug 14, 2023
3fcf774
Set Version: 0.1.414
rv-auditor Aug 14, 2023
34f35cf
Incorporarte code from #544 which removes constraints for variables t…
nwatson22 Aug 14, 2023
b73dfba
Merge branch 'noah/anti-unify-sort-fix' of https://github.com/runtime…
nwatson22 Aug 14, 2023
410d334
Simplify code, allow pruning of constraints referring to absent varia…
nwatson22 Aug 14, 2023
c782d0f
Fix formatting and error message typing
nwatson22 Aug 14, 2023
b09cd39
Merge branch 'master' into noah/anti-unify-sort-fix
nwatson22 Aug 16, 2023
633caf3
Merge b09cd393d978c417d45863d172626a82e2d5bf17 into 66e1af269c21ab330…
nwatson22 Aug 16, 2023
3a1c459
Set Version: 0.1.415
rv-auditor Aug 16, 2023
9153ffb
Add tests for KDefinition.sort()
nwatson22 Aug 16, 2023
365f7ec
Merge branch 'noah/anti-unify-sort-fix' of https://github.com/runtime…
nwatson22 Aug 16, 2023
4bb5db9
Merge branch 'master' into noah/anti-unify-sort-fix
nwatson22 Aug 16, 2023
28b0b76
Merge 4bb5db9cfd6cf6894b7a09d504234b407f77573c into 3dbb978c37d9e006e…
nwatson22 Aug 16, 2023
4b7472e
Set Version: 0.1.416
rv-auditor Aug 16, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion package/version
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.1.415
0.1.416
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "pyk"
version = "0.1.415"
version = "0.1.416"
description = ""
authors = [
"Runtime Verification, Inc. <[email protected]>",
Expand Down
63 changes: 60 additions & 3 deletions src/pyk/cterm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
from itertools import chain
from typing import TYPE_CHECKING

from .kast.inner import KApply, KInner, KRewrite, KVariable, Subst
from .kast.inner import KApply, KInner, KRewrite, KToken, KVariable, Subst, bottom_up
from .kast.kast import KAtt
from .kast.manip import (
abstract_term_safely,
apply_existential_substitutions,
count_vars,
flatten_label,
Expand All @@ -22,13 +23,16 @@
)
from .kast.outer import KClaim, KRule
from .prelude.k import GENERATED_TOP_CELL
from .prelude.ml import is_top, mlAnd, mlImplies, mlTop
from .prelude.kbool import orBool
from .prelude.ml import is_top, mlAnd, mlEqualsTrue, mlImplies, mlTop
from .utils import unique

if TYPE_CHECKING:
from collections.abc import Iterable, Iterator
from typing import Any

from .kast.outer import KDefinition


@dataclass(frozen=True, order=True)
class CTerm:
Expand Down Expand Up @@ -56,7 +60,7 @@ def from_dict(dct: dict[str, Any]) -> CTerm:
@staticmethod
def _check_config(config: KInner) -> None:
if not isinstance(config, KApply) or not config.is_cell:
raise ValueError('Expected cell label, found: {config.label.name}')
raise ValueError(f'Expected cell label, found: {config}')

@staticmethod
def _normalize_constraints(constraints: Iterable[KInner]) -> tuple[KInner, ...]:
Expand Down Expand Up @@ -138,6 +142,59 @@ def _ml_impl(antecedents: Iterable[KInner], consequents: Iterable[KInner]) -> KI
def add_constraint(self, new_constraint: KInner) -> CTerm:
return CTerm(self.config, [new_constraint] + list(self.constraints))

def anti_unify(
self, other: CTerm, keep_values: bool = False, kdef: KDefinition | None = None
) -> tuple[CTerm, CSubst, CSubst]:
def disjunction_from_substs(subst1: Subst, subst2: Subst) -> KInner:
if KToken('true', 'Bool') in [subst1.pred, subst2.pred]:
return mlTop()
return mlEqualsTrue(orBool([subst1.pred, subst2.pred]))

new_config, self_subst, other_subst = anti_unify(self.config, other.config, kdef=kdef)
common_constraints = [constraint for constraint in self.constraints if constraint in other.constraints]

new_cterm = CTerm(
config=new_config, constraints=([disjunction_from_substs(self_subst, other_subst)] if keep_values else [])
)

new_constraints = []
fvs = free_vars(new_cterm.kast)
len_fvs = 0
while len_fvs < len(fvs):
len_fvs = len(fvs)
for constraint in common_constraints:
if constraint not in new_constraints:
constraint_fvs = free_vars(constraint)
if any(fv in fvs for fv in constraint_fvs):
new_constraints.append(constraint)
fvs.extend(constraint_fvs)

for constraint in new_constraints:
new_cterm = new_cterm.add_constraint(constraint)
self_csubst = new_cterm.match_with_constraint(self)
other_csubst = new_cterm.match_with_constraint(other)
if self_csubst is None or other_csubst is None:
raise ValueError(
f'Anti-unification failed to produce a more general state: {(new_cterm, (self, self_csubst), (other, other_csubst))}'
)
return (new_cterm, self_csubst, other_csubst)


def anti_unify(state1: KInner, state2: KInner, kdef: KDefinition | None = None) -> tuple[KInner, Subst, Subst]:
def _rewrites_to_abstractions(_kast: KInner) -> KInner:
if type(_kast) is KRewrite:
sort = kdef.sort(_kast) if kdef else None
return abstract_term_safely(_kast, sort=sort)
return _kast

minimized_rewrite = push_down_rewrites(KRewrite(state1, state2))
abstracted_state = bottom_up(_rewrites_to_abstractions, minimized_rewrite)
subst1 = abstracted_state.match(state1)
subst2 = abstracted_state.match(state2)
if subst1 is None or subst2 is None:
raise ValueError('Anti-unification failed to produce a more general state!')
return (abstracted_state, subst1, subst2)


@dataclass(frozen=True, order=True)
class CSubst:
Expand Down
54 changes: 1 addition & 53 deletions src/pyk/kast/manip.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from ..prelude.k import DOTS, GENERATED_TOP_CELL
from ..prelude.kbool import FALSE, TRUE, andBool, impliesBool, notBool, orBool
from ..prelude.ml import mlAnd, mlEqualsTrue, mlImplies, mlOr, mlTop
from ..prelude.ml import mlAnd, mlEqualsTrue, mlOr
from ..utils import find_common_items, hash_str
from .inner import KApply, KRewrite, KSequence, KToken, KVariable, Subst, bottom_up, top_down, var_occurrences
from .kast import EMPTY_ATT, KAtt, WithKAtt
Expand Down Expand Up @@ -582,58 +582,6 @@ def _abstract(k: KInner) -> KVariable:
return new_var


def anti_unify(state1: KInner, state2: KInner) -> tuple[KInner, Subst, Subst]:
def _rewrites_to_abstractions(_kast: KInner) -> KInner:
if type(_kast) is KRewrite:
return abstract_term_safely(_kast)
return _kast

minimized_rewrite = push_down_rewrites(KRewrite(state1, state2))
abstracted_state = bottom_up(_rewrites_to_abstractions, minimized_rewrite)
subst1 = abstracted_state.match(state1)
subst2 = abstracted_state.match(state2)
if subst1 is None or subst2 is None:
raise ValueError('Anti-unification failed to produce a more general state!')
return (abstracted_state, subst1, subst2)


def anti_unify_with_constraints(
constrained_term_1: KInner,
constrained_term_2: KInner,
implications: bool = False,
constraint_disjunct: bool = False,
abstracted_disjunct: bool = False,
) -> KInner:
def disjunction_from_substs(subst1: Subst, subst2: Subst) -> KInner:
if KToken('true', 'Bool') in [subst1.pred, subst2.pred]:
return mlTop()
return mlEqualsTrue(orBool([subst1.pred, subst2.pred]))

state1, constraint1 = split_config_and_constraints(constrained_term_1)
state2, constraint2 = split_config_and_constraints(constrained_term_2)
constraints1 = flatten_label('#And', constraint1)
constraints2 = flatten_label('#And', constraint2)
state, subst1, subst2 = anti_unify(state1, state2)

constraints = [c for c in constraints1 if c in constraints2]
constraint1 = mlAnd([c for c in constraints1 if c not in constraints])
constraint2 = mlAnd([c for c in constraints2 if c not in constraints])
implication1 = mlImplies(constraint1, subst1.ml_pred)
implication2 = mlImplies(constraint2, subst2.ml_pred)

if abstracted_disjunct:
constraints.append(disjunction_from_substs(subst1, subst2))

if implications:
constraints.append(implication1)
constraints.append(implication2)

if constraint_disjunct:
constraints.append(mlOr([constraint1, constraint2]))

return mlAnd([state] + constraints)


def apply_existential_substitutions(constrained_term: KInner) -> KInner:
state, constraint = split_config_and_constraints(constrained_term)
constraints = flatten_label('#And', constraint)
Expand Down
20 changes: 18 additions & 2 deletions src/pyk/kast/outer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1101,8 +1101,11 @@ def sort(self, kast: KInner) -> KSort | None:
case KToken(_, sort) | KVariable(_, sort):
return sort
case KRewrite(lhs, rhs):
sort = self.sort(lhs)
return sort if sort == self.sort(rhs) else None
lhs_sort = self.sort(lhs)
rhs_sort = self.sort(rhs)
if lhs_sort and rhs_sort:
return self.least_common_supersort(lhs_sort, rhs_sort)
return None
case KSequence(_):
return KSort('K')
case KApply(label, _):
Expand All @@ -1128,13 +1131,26 @@ def sort_strict(self, kast: KInner) -> KSort:
raise ValueError(f'Could not determine sort of term: {kast}')
return sort

def least_common_supersort(self, sort1: KSort, sort2: KSort) -> KSort | None:
if sort1 == sort2:
return sort1
if sort1 in self.subsorts(sort2):
return sort2
if sort2 in self.subsorts(sort1):
return sort1
# Computing least common supersort is not currently supported if sort1 is not a subsort of sort2 or
# vice versa. In that case there may be more than one LCS.
return None

def greatest_common_subsort(self, sort1: KSort, sort2: KSort) -> KSort | None:
if sort1 == sort2:
return sort1
if sort1 in self.subsorts(sort2):
return sort1
if sort2 in self.subsorts(sort1):
return sort2
# Computing greatest common subsort is not currently supported if sort1 is not a subsort of sort2 or
# vice versa. In that case there may be more than one GCS.
return None

# Sorts like Int cannot be injected directly into sort K so they are embedded in a KSequence.
Expand Down
138 changes: 136 additions & 2 deletions src/tests/integration/kcfg/test_imp.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@

from pyk.cterm import CSubst, CTerm
from pyk.kast.inner import KApply, KSequence, KSort, KToken, KVariable, Subst
from pyk.kast.manip import minimize_term
from pyk.kast.manip import get_cell, minimize_term
from pyk.kcfg.semantics import KCFGSemantics
from pyk.kcfg.show import KCFGShow
from pyk.prelude.kbool import BOOL, notBool
from pyk.prelude.kbool import BOOL, notBool, orBool
from pyk.prelude.kint import intToken
from pyk.prelude.ml import mlAnd, mlBottom, mlEqualsFalse, mlEqualsTrue, mlTop
from pyk.proof import APRBMCProof, APRBMCProver, APRProof, APRProver, ProofStatus
Expand Down Expand Up @@ -1147,3 +1147,137 @@ def test_fail_fast(
assert len(proof.pending) == 1
assert len(proof.terminal) == 1
assert len(proof.failing) == 1

def test_anti_unify_forget_values(
self,
kcfg_explore: KCFGExplore,
kprint: KPrint,
) -> None:
cterm1 = self.config(
kprint=kprint,
k='int $n ; { }',
state='N |-> X:Int',
constraint=mlAnd(
[
mlEqualsTrue(KApply('_>Int_', [KVariable('K', 'Int'), KToken('1', 'Int')])),
mlEqualsTrue(KApply('_>Int_', [KVariable('N', 'Int'), KToken('1', 'Int')])),
mlEqualsTrue(KApply('_>Int_', [KVariable('X', 'Int'), KToken('1', 'Int')])),
mlEqualsTrue(KApply('_>Int_', [KVariable('Y', 'Int'), KToken('1', 'Int')])),
]
),
)
cterm2 = self.config(
kprint=kprint,
k='int $n ; { }',
state='N |-> Y:Int',
constraint=mlAnd(
[
mlEqualsTrue(KApply('_>Int_', [KVariable('K', 'Int'), KToken('1', 'Int')])),
mlEqualsTrue(KApply('_>Int_', [KVariable('N', 'Int'), KToken('1', 'Int')])),
mlEqualsTrue(KApply('_>Int_', [KVariable('X', 'Int'), KToken('1', 'Int')])),
mlEqualsTrue(KApply('_>Int_', [KVariable('Y', 'Int'), KToken('1', 'Int')])),
]
),
)

anti_unifier, subst1, subst2 = cterm1.anti_unify(cterm2, keep_values=False, kdef=kprint.definition)

k_cell = get_cell(anti_unifier.kast, 'STATE_CELL')
assert type(k_cell) is KApply
assert k_cell.label.name == '_|->_'
assert type(k_cell.args[1]) is KVariable
abstracted_var: KVariable = k_cell.args[1]

expected_anti_unifier = self.config(
kprint=kprint,
k='int $n ; { }',
state=f'N |-> {abstracted_var.name}:Int',
constraint=mlEqualsTrue(KApply('_>Int_', [KVariable('N', 'Int'), KToken('1', 'Int')])),
)

assert anti_unifier.kast == expected_anti_unifier.kast

def test_anti_unify_keep_values(
self,
kcfg_explore: KCFGExplore,
kprint: KPrint,
) -> None:
cterm1 = self.config(
kprint=kprint,
k='int $n ; { }',
state='N |-> X:Int',
constraint=mlAnd(
[
mlEqualsTrue(KApply('_>Int_', [KVariable('K', 'Int'), KToken('1', 'Int')])),
mlEqualsTrue(KApply('_>Int_', [KVariable('N', 'Int'), KToken('1', 'Int')])),
mlEqualsTrue(KApply('_>Int_', [KVariable('X', 'Int'), KToken('1', 'Int')])),
mlEqualsTrue(KApply('_>Int_', [KVariable('Y', 'Int'), KToken('1', 'Int')])),
]
),
)
cterm2 = self.config(
kprint=kprint,
k='int $n ; { }',
state='N |-> Y:Int',
constraint=mlAnd(
[
mlEqualsTrue(KApply('_>Int_', [KVariable('K', 'Int'), KToken('1', 'Int')])),
mlEqualsTrue(KApply('_>Int_', [KVariable('N', 'Int'), KToken('1', 'Int')])),
mlEqualsTrue(KApply('_>Int_', [KVariable('X', 'Int'), KToken('1', 'Int')])),
mlEqualsTrue(KApply('_>Int_', [KVariable('Y', 'Int'), KToken('1', 'Int')])),
]
),
)

anti_unifier, subst1, subst2 = cterm1.anti_unify(cterm2, keep_values=True, kdef=kprint.definition)

k_cell = get_cell(anti_unifier.kast, 'STATE_CELL')
assert type(k_cell) is KApply
assert k_cell.label.name == '_|->_'
assert type(k_cell.args[1]) is KVariable
abstracted_var: KVariable = k_cell.args[1]

expected_anti_unifier = self.config(
kprint=kprint,
k='int $n ; { }',
state=f'N |-> {abstracted_var.name}:Int',
constraint=mlAnd(
[
mlEqualsTrue(KApply('_>Int_', [KVariable('N', 'Int'), KToken('1', 'Int')])),
mlEqualsTrue(KApply('_>Int_', [KVariable('X', 'Int'), KToken('1', 'Int')])),
mlEqualsTrue(KApply('_>Int_', [KVariable('Y', 'Int'), KToken('1', 'Int')])),
mlEqualsTrue(
orBool(
[
KApply('_==K_', [KVariable(name=abstracted_var.name), KVariable('X', 'Int')]),
KApply('_==K_', [KVariable(name=abstracted_var.name), KVariable('Y', 'Int')]),
]
)
),
]
),
)

assert anti_unifier.kast == expected_anti_unifier.kast

def test_anti_unify_subst_true(
self,
kcfg_explore: KCFGExplore,
kprint: KPrint,
) -> None:
cterm1 = self.config(
kprint=kprint,
k='int $n ; { }',
state='N |-> 0',
constraint=mlEqualsTrue(KApply('_==K_', [KVariable('N', 'Int'), KToken('1', 'Int')])),
)
cterm2 = self.config(
kprint=kprint,
k='int $n ; { }',
state='N |-> 0',
constraint=mlEqualsTrue(KApply('_==K_', [KVariable('N', 'Int'), KToken('1', 'Int')])),
)

anti_unifier, _, _ = cterm1.anti_unify(cterm2, keep_values=True, kdef=kprint.definition)

assert anti_unifier.kast == cterm1.kast
Loading