From c87dd1576f682ed341616402e90db4f816ae2c7e Mon Sep 17 00:00:00 2001 From: Alexandru Fikl Date: Fri, 29 Apr 2022 20:20:53 -0500 Subject: [PATCH 1/2] add equality test --- test/test_pymbolic.py | 154 +++++++++++++++++++++++++++++++++++++++--- 1 file changed, 146 insertions(+), 8 deletions(-) diff --git a/test/test_pymbolic.py b/test/test_pymbolic.py index 96b6d678..32b18343 100644 --- a/test/test_pymbolic.py +++ b/test/test_pymbolic.py @@ -20,19 +20,14 @@ THE SOFTWARE. """ -import pymbolic.primitives as prim import pytest +from functools import reduce + +import pymbolic.primitives as prim from pymbolic import parse from pytools.lex import ParseError - - from pymbolic.mapper import IdentityMapper -try: - reduce -except NameError: - from functools import reduce - # {{{ utilities @@ -61,6 +56,8 @@ def assert_parse_roundtrip(expr_str): # }}} +# {{{ test_integer_power + def test_integer_power(): from pymbolic.algorithm import integer_power @@ -72,6 +69,10 @@ def test_integer_power(): ]: assert base**expn == integer_power(base, expn) +# }}} + + +# {{{ test_expand def test_expand(): from pymbolic import var, expand @@ -80,6 +81,10 @@ def test_expand(): u = (x+1)**5 expand(u) +# }}} + + +# {{{ test_substitute def test_substitute(): from pymbolic import parse, substitute, evaluate @@ -87,6 +92,10 @@ def test_substitute(): xmin = parse("x.min") assert evaluate(substitute(u, {xmin: 25})) == 630 +# }}} + + +# {{{ test_no_comparison def test_no_comparison(): from pymbolic import parse @@ -107,6 +116,10 @@ def expect_typeerror(f): expect_typeerror(lambda: x > y) expect_typeerror(lambda: x >= y) +# }}} + + +# {{{ test_structure_preservation def test_structure_preservation(): x = prim.Sum((5, 7)) @@ -114,6 +127,10 @@ def test_structure_preservation(): x2 = IdentityMapper()(x) assert x == x2 +# }}} + + +# {{{ test_sympy_interaction def test_sympy_interaction(): pytest.importorskip("sympy") @@ -141,6 +158,8 @@ def test_sympy_interaction(): assert sp.ratsimp(s1_expr - s3_expr) == 0 +# }}} + # {{{ fft @@ -201,6 +220,8 @@ def test_fft(): # }}} +# {{{ test_sparse_multiply + def test_sparse_multiply(): numpy = pytest.importorskip("numpy") pytest.importorskip("scipy") @@ -219,6 +240,8 @@ def test_sparse_multiply(): assert la.norm(mat_vec-mat_vec_2) < 1e-14 +# }}} + # {{{ parser @@ -295,6 +318,8 @@ def test_parser(): # }}} +# {{{ test_mappers + def test_mappers(): from pymbolic import variables f, x, y, z = variables("f x y z") @@ -310,6 +335,11 @@ def test_mappers(): DependencyMapper()(expr) +# }}} + + +# {{{ test_func_dep_consistency + def test_func_dep_consistency(): from pymbolic import var from pymbolic.mapper.dependency import DependencyMapper @@ -319,6 +349,10 @@ def test_func_dep_consistency(): assert dep_map(f(x)) == {x} assert dep_map(f(x=x)) == {x} +# }}} + + +# {{{ test_conditions def test_conditions(): from pymbolic import var @@ -326,6 +360,10 @@ def test_conditions(): y = var("y") assert str(x.eq(y).and_(x.le(5))) == "x == y and x <= 5" +# }}} + + +# {{{ test_graphviz def test_graphviz(): from pymbolic import parse @@ -338,6 +376,8 @@ def test_graphviz(): gvm(expr) print(gvm.get_dot_code()) +# }}} + # {{{ geometric algebra @@ -443,6 +483,8 @@ def test_geometric_algebra(dims): # }}} +# {{{ test_ast_interop + def test_ast_interop(): src = """ def f(): @@ -472,6 +514,10 @@ def f(): print(lhs, rhs) +# }}} + + +# {{{ test_compile def test_compile(): from pymbolic import parse, compile @@ -483,6 +529,10 @@ def test_compile(): code = pickle.loads(pickle.dumps(code)) assert code(3, 3) == 27 +# }}} + + +# {{{ test_unifier def test_unifier(): from pymbolic import var @@ -521,6 +571,10 @@ def match_found(records, eqns): assert len(recs) == 1 assert match_found(recs, {(a, b), (b, c), (c, d)}) +# }}} + + +# {{{ test_long_sympy_mapping def test_long_sympy_mapping(): sp = pytest.importorskip("sympy") @@ -528,6 +582,10 @@ def test_long_sympy_mapping(): SympyToPymbolicMapper()(sp.sympify(int(10**20))) SympyToPymbolicMapper()(sp.sympify(int(10))) +# }}} + + +# {{{ test_stringifier_preserve_shift_order def test_stringifier_preserve_shift_order(): for expr in [ @@ -536,6 +594,10 @@ def test_stringifier_preserve_shift_order(): ]: assert parse(str(expr)) == expr +# }}} + + +# {{{ test_latex_mapper LATEX_TEMPLATE = r"""\documentclass{article} \usepackage{amsmath} @@ -604,6 +666,10 @@ def add(expr): finally: shutil.rmtree(latex_dir) +# }}} + + +# {{{ test_flop_counter def test_flop_counter(): x = prim.Variable("x") @@ -618,6 +684,10 @@ def test_flop_counter(): assert CSEAwareFlopCounter()(expr) == 4 + 2 +# }}} + + +# {{{ test_make_sym_vector def test_make_sym_vector(): numpy = pytest.importorskip("numpy") @@ -627,6 +697,10 @@ def test_make_sym_vector(): assert len(make_sym_vector("vec", numpy.int32(2))) == 2 assert len(make_sym_vector("vec", [1, 2, 3])) == 3 +# }}} + + +# {{{ test_multiplicative_stringify_preserves_association def test_multiplicative_stringify_preserves_association(): for inner in ["*", " / ", " // ", " % "]: @@ -639,6 +713,10 @@ def test_multiplicative_stringify_preserves_association(): assert_parse_roundtrip("(-1)*(((-1)*x) / 5)") +# }}} + + +# {{{ test_differentiator_flags_for_nonsmooth_and_discontinuous def test_differentiator_flags_for_nonsmooth_and_discontinuous(): import pymbolic.functions as pf @@ -658,6 +736,10 @@ def test_differentiator_flags_for_nonsmooth_and_discontinuous(): result = differentiate(pf.sign(x), x, allowed_nonsmoothness="discontinuous") assert result == 0 +# }}} + + +# {{{ test_diff_cse def test_diff_cse(): from pymbolic.mapper.differentiator import differentiate @@ -686,6 +768,10 @@ def test_diff_cse(): assert err2 < 1.1 * 0.5**2 * err1 +# }}} + + +# {{{ test_coefficient_collector def test_coefficient_collector(): from pymbolic.mapper.coefficient import CoefficientCollector @@ -698,6 +784,10 @@ def test_coefficient_collector(): assert cc(2*x + y - z) == {x: 2, y: 1, 1: -z} assert cc(x/2 + z**2) == {x: prim.Quotient(1, 2), 1: z**2} +# }}} + + +# {{{ test_np_bool_handling def test_np_bool_handling(): from pymbolic.mapper.evaluator import evaluate @@ -705,6 +795,10 @@ def test_np_bool_handling(): expr = prim.LogicalNot(numpy.bool_(False)) assert evaluate(expr) is True +# }}} + + +# {{{ test_mapper_method_of_parent_class def test_mapper_method_of_parent_class(): class SpatialConstant(prim.Variable): @@ -719,6 +813,50 @@ def map_spatial_constant(self, expr): assert MyMapper()(c) == 2*c assert IdentityMapper()(c) == c +# }}} + + +# {{{ test_equality_complexity + +@pytest.mark.xfail +def test_equality_complexity(): + # NOTE: https://github.com/inducer/pymbolic/issues/73 + from numpy.random import default_rng + + def construct_intestine_graph(depth=64, seed=0): + rng = default_rng(seed) + x = prim.Variable("x") + + for _ in range(depth): + coeff1, coeff2 = rng.integers(1, 10, 2) + x = coeff1 * x + coeff2 * x + + return x + + def check_equality(): + graph1 = construct_intestine_graph() + graph2 = construct_intestine_graph() + graph3 = construct_intestine_graph(seed=3) + + assert graph1 == graph2 + assert graph2 == graph1 + assert graph1 != graph3 + assert graph2 != graph3 + + # NOTE: this should finish in a second! + import multiprocessing + p = multiprocessing.Process(target=check_equality) + p.start() + p.join(timeout=1) + + is_alive = p.is_alive() + if p.is_alive(): + p.terminate() + + assert not is_alive + +# }}} + if __name__ == "__main__": import sys From d2364df109f03c6f898abd35a132c5b64541960e Mon Sep 17 00:00:00 2001 From: Alexandru Fikl Date: Fri, 29 Apr 2022 20:32:20 -0500 Subject: [PATCH 2/2] test_pymbolic: use a logger --- test/test_pymbolic.py | 75 +++++++++++++++++++++++-------------------- 1 file changed, 40 insertions(+), 35 deletions(-) diff --git a/test/test_pymbolic.py b/test/test_pymbolic.py index 32b18343..785feafb 100644 --- a/test/test_pymbolic.py +++ b/test/test_pymbolic.py @@ -28,15 +28,20 @@ from pytools.lex import ParseError from pymbolic.mapper import IdentityMapper +import logging +logger = logging.getLogger(__name__) + # {{{ utilities def assert_parsed_same_as_python(expr_str): # makes sure that has only one line expr_str, = expr_str.split("\n") - from pymbolic.interop.ast import ASTToPymbolic + import ast + from pymbolic.interop.ast import ASTToPymbolic ast2p = ASTToPymbolic() + try: expr_parsed_by_python = ast2p(ast.parse(expr_str).body[0].value) except SyntaxError: @@ -48,9 +53,10 @@ def assert_parsed_same_as_python(expr_str): def assert_parse_roundtrip(expr_str): - expr = parse(expr_str) from pymbolic.mapper.stringifier import StringifyMapper + expr = parse(expr_str) strified = StringifyMapper()(expr) + assert strified == expr_str, (strified, expr_str) # }}} @@ -123,7 +129,6 @@ def expect_typeerror(f): def test_structure_preservation(): x = prim.Sum((5, 7)) - from pymbolic.mapper import IdentityMapper x2 = IdentityMapper()(x) assert x == x2 @@ -200,9 +205,9 @@ def test_fft(): from pymbolic.algorithm import fft, sym_fft vars = numpy.array([var(chr(97+i)) for i in range(16)], dtype=object) - print(vars) + logger.info("vars: %s", vars) - print(fft(vars)) + logger.info("fft: %s", fft(vars)) traced_fft = sym_fft(vars) from pymbolic.mapper.stringifier import PREC_NONE @@ -212,10 +217,10 @@ def test_fft(): code = [ccm(tfi, PREC_NONE) for tfi in traced_fft] for cse_name, cse_str in enumerate(ccm.cse_name_list): - print(f"{cse_name} = {cse_str}") + logger.info("%s = %s", cse_name, cse_str) for i, line in enumerate(code): - print("result[%d] = %s" % (i, line)) + logger.info("result[%d] = %s", i, line) # }}} @@ -250,25 +255,25 @@ def test_parser(): parse("(2*a[1]*b[1]+2*a[0]*b[0])*(hankel_1(-1,sqrt(a[1]**2+a[0]**2)*k) " "-hankel_1(1,sqrt(a[1]**2+a[0]**2)*k))*k /(4*sqrt(a[1]**2+a[0]**2)) " "+hankel_1(0,sqrt(a[1]**2+a[0]**2)*k)") - print(repr(parse("d4knl0"))) - print(repr(parse("0."))) - print(repr(parse("0.e1"))) + logger.info("%r", parse("d4knl0")) + logger.info("%r", parse("0.")) + logger.info("%r", parse("0.e1")) assert parse("0.e1") == 0 assert parse("1e-12") == 1e-12 - print(repr(parse("a >= 1"))) - print(repr(parse("a <= 1"))) - - print(repr(parse(":"))) - print(repr(parse("1:"))) - print(repr(parse(":2"))) - print(repr(parse("1:2"))) - print(repr(parse("::"))) - print(repr(parse("1::"))) - print(repr(parse(":1:"))) - print(repr(parse("::1"))) - print(repr(parse("3::1"))) - print(repr(parse(":5:1"))) - print(repr(parse("3:5:1"))) + logger.info("%r", parse("a >= 1")) + logger.info("%r", parse("a <= 1")) + + logger.info("%r", parse(":")) + logger.info("%r", parse("1:")) + logger.info("%r", parse(":2")) + logger.info("%r", parse("1:2")) + logger.info("%r", parse("::")) + logger.info("%r", parse("1::")) + logger.info("%r", parse(":1:")) + logger.info("%r", parse("::1")) + logger.info("%r", parse("3::1")) + logger.info("%r", parse(":5:1")) + logger.info("%r", parse("3:5:1")) assert_parse_roundtrip("()") assert_parse_roundtrip("(3,)") @@ -280,17 +285,17 @@ def test_parser(): assert_parse_roundtrip("g[i, k] + 2.0*h[i, k]") parse("g[i,k]+(+2.0)*h[i, k]") - print(repr(parse("a - b - c"))) - print(repr(parse("-a - -b - -c"))) - print(repr(parse("- - - a - - - - b - - - - - c"))) + logger.info("%r", parse("a - b - c")) + logger.info("%r", parse("-a - -b - -c")) + logger.info("%r", parse("- - - a - - - - b - - - - - c")) - print(repr(parse("~(a ^ b)"))) - print(repr(parse("(a | b) | ~(~a & ~b)"))) + logger.info("%r", parse("~(a ^ b)")) + logger.info("%r", parse("(a | b) | ~(~a & ~b)")) - print(repr(parse("3 << 1"))) - print(repr(parse("1 >> 3"))) + logger.info("%r", parse("3 << 1")) + logger.info("%r", parse("1 >> 3")) - print(parse("3::1")) + logger.info(parse("3::1")) assert parse("e1") == prim.Variable("e1") assert parse("d1") == prim.Variable("d1") @@ -374,7 +379,7 @@ def test_graphviz(): from pymbolic.mapper.graphviz import GraphvizMapper gvm = GraphvizMapper() gvm(expr) - print(gvm.get_dot_code()) + logger.info("%s", gvm.get_dot_code()) # }}} @@ -495,7 +500,7 @@ def f(): import ast mod = ast.parse(src.replace("\n ", "\n")) - print(ast.dump(mod)) + logger.info("%s", ast.dump(mod)) from pymbolic.interop.ast import ASTToPymbolic ast2p = ASTToPymbolic() @@ -512,7 +517,7 @@ def f(): lhs = ast2p(lhs) rhs = ast2p(stmt.value) - print(lhs, rhs) + logger.info("lhs %s rhs %s", lhs, rhs) # }}}