diff --git a/src/beanmachine/ppl/compiler/bmg_requirements.py b/src/beanmachine/ppl/compiler/bmg_requirements.py index 458ca3ae21..41dd2c8f03 100644 --- a/src/beanmachine/ppl/compiler/bmg_requirements.py +++ b/src/beanmachine/ppl/compiler/bmg_requirements.py @@ -35,7 +35,15 @@ # what their inputs are. _known_requirements: Dict[type, List[bt.Requirement]] = { - # TODO See comment below regarding RealMatrix + # TODO: This is wrong in several ways. + # First, RealMatrix does not meet the contract of a requirement; + # in particular, it cannot be printed out by the requirement diagnostic + # in gen_to_dot. + # Second, it is too strict; as with matrix add, the requirement is + # presumably that the two operands be any double matrix (real, neg real, + # pos real or probability). + # Third, this requirement is too weak; we are missing the requirement + # that the operands have the same element type and shape. bn.ElementwiseMultiplyNode: [bt.RealMatrix, bt.RealMatrix], bn.Observation: [bt.any_requirement], bn.Query: [bt.any_requirement], @@ -68,16 +76,6 @@ bn.LogisticNode: [bt.Real], bn.Log1mexpNode: [bt.NegativeReal], bn.MatrixMultiplicationNode: [bt.any_real_matrix, bt.any_real_matrix], - # TODO: This is wrong in several ways. - # First, RealMatrix does not meet the contract of a requirement; - # in particular, it cannot be printed out by the requirement diagnostic - # in gen_to_dot. - # Second, it is too strict; the requirement on matrix add is actually - # that the two operands be any double matrix (real, neg real, - # pos real or probability). - # Third, this requirement is too weak; we are missing the requirement - # that the operands have the same element type and shape. 
- bn.MatrixAddNode: [bt.RealMatrix, bt.RealMatrix], bn.MatrixExpNode: [bt.any_real_matrix], bn.MatrixLogNode: [bt.any_pos_real_matrix], bn.MatrixLog1mexpNode: [bt.any_real_matrix], @@ -127,6 +125,7 @@ def __init__(self, typer: LatticeTyper) -> None: # TODO: bn.MatrixMultiplyNode: self._requirements_matrix_multiply, # see comment above bn.MatrixComplementNode: self._requrirements_matrix_complement, + bn.MatrixAddNode: self._requirements_matrix_add, bn.MatrixScaleNode: self._requirements_matrix_scale, bn.MultiplicationNode: self._requirements_multiplication, bn.NegateNode: self._requirements_exp_neg, @@ -447,6 +446,12 @@ def _requrirements_matrix_complement( req = [bt.SimplexMatrix] return req + def _requirements_matrix_add(self, node: bn.MatrixAddNode) -> List[bt.Requirement]: + # Matrix add requires that both operands be the same as the output type. + it = self.typer[node] + assert isinstance(it, bt.BMGMatrixType) + return [it, it] + def _requirements_matrix_scale( self, node: bn.MatrixScaleNode ) -> List[bt.Requirement]: diff --git a/src/beanmachine/ppl/compiler/fix_requirements.py b/src/beanmachine/ppl/compiler/fix_requirements.py index f38b051495..953f4802fa 100644 --- a/src/beanmachine/ppl/compiler/fix_requirements.py +++ b/src/beanmachine/ppl/compiler/fix_requirements.py @@ -338,11 +338,13 @@ def _meet_real_matrix_requirement( edge: str, ) -> bn.BMGNode: result = None - node_is_scalar = node_dim[0] == 1 and node_dim[1] == 1 - requires_scalar = dim_req[0] == 1 and dim_req[1] == 1 + req_rows, req_cols = dim_req + node_rows, node_cols = node_dim + node_is_scalar = node_rows == 1 and node_cols == 1 + requires_scalar = req_rows == 1 and req_cols == 1 if requires_scalar and node_is_scalar: result = self.bmg.add_to_real(node) - elif node_dim[0] == dim_req[0] and node_dim[1] == dim_req[1]: + elif node_rows == req_rows and node_cols == req_cols: result = self.bmg.add_to_real_matrix(node) if result is None: @@ -350,7 +352,7 @@ def _meet_real_matrix_requirement( 
Violation( node, self._typer[node], - bt.RealMatrix(1, 1), + bt.RealMatrix(req_rows, req_cols), consumer, edge, self.bmg.execution_context.node_locations(consumer), diff --git a/src/beanmachine/ppl/compiler/graph_labels.py b/src/beanmachine/ppl/compiler/graph_labels.py index 4f83871f1f..0fa9b08e6c 100644 --- a/src/beanmachine/ppl/compiler/graph_labels.py +++ b/src/beanmachine/ppl/compiler/graph_labels.py @@ -205,6 +205,7 @@ def _val(node: bn.ConstantNode) -> str: bn.LogSumExpVectorNode: "logsumexp", bn.LogAddExpNode: "logaddexp", bn.LShiftNode: "'left shift' (<<)", + bn.MatrixAddNode: "matrix add", bn.MatrixMultiplicationNode: "matrix multiplication (@)", bn.MatrixScaleNode: "matrix scale", bn.ModNode: "modulus (%)", diff --git a/tests/ppl/compiler/broadcast_test.py b/tests/ppl/compiler/broadcast_test.py index 9a25180714..2542137bb7 100644 --- a/tests/ppl/compiler/broadcast_test.py +++ b/tests/ppl/compiler/broadcast_test.py @@ -82,69 +82,15 @@ def test_broadcast_add(self) -> None: """ self.assertEqual(expected.strip(), observed.strip()) - # After: + # We do not yet insert a broadcast node. Demonstrate here + # that the compiler gives a reasonable error message. 
- observed = BMGInference().to_dot(queries, observations, after_transform=True) - expected = """ -digraph "graph" { - N00[label=0.0]; - N01[label=1.0]; - N02[label=Normal]; - N03[label=Sample]; - N04[label=Sample]; - N05[label=Normal]; - N06[label=Sample]; - N07[label=Normal]; - N08[label=Sample]; - N09[label=Sample]; - N10[label=Sample]; - N11[label=Normal]; - N12[label=Sample]; - N13[label=Normal]; - N14[label=Sample]; - N15[label=2]; - N16[label=1]; - N17[label=ToMatrix]; - N18[label=ToMatrix]; - N19[label=MatrixAdd]; - N20[label=Query]; - N00 -> N02; - N01 -> N02; - N01 -> N05; - N01 -> N07; - N01 -> N11; - N01 -> N13; - N02 -> N03; - N02 -> N04; - N02 -> N09; - N02 -> N10; - N03 -> N05; - N04 -> N07; - N05 -> N06; - N06 -> N17; - N07 -> N08; - N08 -> N17; - N09 -> N11; - N10 -> N13; - N11 -> N12; - N12 -> N18; - N13 -> N14; - N14 -> N18; - N15 -> N17; - N15 -> N18; - N16 -> N17; - N16 -> N18; - N17 -> N19; - N18 -> N19; - N19 -> N20; -} -""" - self.assertEqual(expected.strip(), observed.strip()) + with self.assertRaises(ValueError) as ex: + BMGInference().to_graph(queries, observations) - # BMG + expected = """ +The left of a matrix add is required to be a 2 x 2 real matrix but is a 2 x 1 real matrix. 
+The right of a matrix add is required to be a 2 x 2 real matrix but is a 1 x 2 real matrix.""" - with self.assertRaises(ValueError): - g, _ = BMGInference().to_graph(queries, observations) - # observed = g.to_dot() - # expected = "" - # self.assertEqual(expected.strip(), observed.strip()) + observed = str(ex.exception) + self.assertEqual(expected.strip(), observed.strip()) diff --git a/tests/ppl/compiler/fix_vectorized_models_test.py b/tests/ppl/compiler/fix_vectorized_models_test.py index 8b3d0f0a2a..ae2262ba50 100644 --- a/tests/ppl/compiler/fix_vectorized_models_test.py +++ b/tests/ppl/compiler/fix_vectorized_models_test.py @@ -753,7 +753,12 @@ def test_fix_vectorized_models_6(self) -> None: """ self.assertEqual(expected.strip(), observed.strip()) - def test_fix_vectorized_models_7(self) -> None: + def _disabled_test_fix_vectorized_models_7(self) -> None: + # TODO: This test is disabled until broadcasting semantics + # are correctly implemented. This model generates an incorrect + # BMG graph right now because it tries to add a 2x1 matrix of + # probabilities to a 2x2 matrix of reals without inserting the + # necessary broadcasting and type conversion nodes. self.maxDiff = None observations = {} queries = [operators()] diff --git a/tests/ppl/compiler/gep_test.py b/tests/ppl/compiler/gep_test.py index 940a4507d5..235aea7f53 100644 --- a/tests/ppl/compiler/gep_test.py +++ b/tests/ppl/compiler/gep_test.py @@ -72,96 +72,96 @@ def test_gep_model_compilation(self) -> None: self.maxDiff = None queries = [prev()] observations = {bucket_prob(): torch.tensor([1.0])} - observed = BMGInference().to_dot(queries, observations) + # Demonstrate that compiling to an actual BMG graph + # generates a graph which type checks. 
+ g, _ = BMGInference().to_graph(queries, observations) + observed = g.to_dot() expected = """ digraph "graph" { - N00[label=5.0]; - N01[label=HalfNormal]; - N02[label=Sample]; - N03[label=0.10000000149011612]; - N04[label=HalfNormal]; - N05[label=Sample]; - N06[label=0.0]; - N07[label=1.0]; - N08[label=Normal]; - N09[label=Sample]; - N10[label=Normal]; - N11[label=Sample]; - N12[label=HalfNormal]; - N13[label=Sample]; + N0[label="5"]; + N1[label="HalfNormal"]; + N2[label="~"]; + N3[label="0.1"]; + N4[label="HalfNormal"]; + N5[label="~"]; + N6[label="0"]; + N7[label="1"]; + N8[label="Normal"]; + N9[label="~"]; + N10[label="Normal"]; + N11[label="~"]; + N12[label="HalfNormal"]; + N13[label="~"]; N14[label="*"]; - N15[label=2.0]; + N15[label="2"]; N16[label="*"]; - N17[label=-1.0]; - N18[label="**"]; - N19[label=ToReal]; - N20[label="[[-0.0,-8.836000051815063e-05],\\\\n[-8.836000051815063e-05,-0.0]]"]; - N21[label=MatrixScale]; - N22[label=MatrixExp]; - N23[label=MatrixScale]; - N24[label=ToRealMatrix]; - N25[label="[[0.0010000000474974513,0.0],\\\\n[0.0,0.0010000000474974513]]"]; - N26[label=MatrixAdd]; - N27[label=Cholesky]; - N28[label=2]; - N29[label=1]; - N30[label=ToMatrix]; - N31[label="@"]; - N32[label=0]; - N33[label=index]; - N34[label=Normal]; - N35[label=Sample]; - N36[label=index]; - N37[label=Normal]; - N38[label=Sample]; - N39[label="[4.0,0.0]"]; - N40[label=Phi]; - N41[label=Phi]; - N42[label=ToMatrix]; - N43[label=MatrixLog]; - N44[label=ToRealMatrix]; - N45[label=ElementwiseMult]; - N46[label="[29850.0,2016.0]"]; - N47[label=complement]; - N48[label=Log]; - N49[label=complement]; - N50[label=Log]; - N51[label=ToMatrix]; - N52[label=ToRealMatrix]; - N53[label=ElementwiseMult]; - N54[label=MatrixAdd]; - N55[label=MatrixSum]; - N56[label=ToNegReal]; - N57[label=Log1mexp]; - N58[label="-"]; - N59[label=ToReal]; - N60[label="+"]; - N61[label="Bernoulli(logits)"]; - N62[label=Sample]; - N63[label="Observation True"]; - N64[label=ToMatrix]; - 
N65[label=Query]; - N00 -> N01; - N01 -> N02; - N02 -> N14; - N02 -> N14; - N03 -> N04; - N04 -> N05; - N05 -> N16; - N05 -> N16; - N06 -> N08; - N06 -> N10; - N07 -> N08; - N07 -> N10; - N07 -> N12; - N08 -> N09; - N09 -> N30; + N17[label="-1"]; + N18[label="^"]; + N19[label="ToReal"]; + N20[label="matrix"]; + N21[label="MatrixScale"]; + N22[label="MatrixExp"]; + N23[label="MatrixScale"]; + N24[label="matrix"]; + N25[label="MatrixAdd"]; + N26[label="Cholesky"]; + N27[label="2"]; + N28[label="1"]; + N29[label="ToMatrix"]; + N30[label="MatrixMultiply"]; + N31[label="0"]; + N32[label="Index"]; + N33[label="Normal"]; + N34[label="~"]; + N35[label="Index"]; + N36[label="Normal"]; + N37[label="~"]; + N38[label="matrix"]; + N39[label="Phi"]; + N40[label="Phi"]; + N41[label="ToMatrix"]; + N42[label="MatrixLog"]; + N43[label="ToReal"]; + N44[label="ElementwiseMultiply"]; + N45[label="matrix"]; + N46[label="Complement"]; + N47[label="Log"]; + N48[label="Complement"]; + N49[label="Log"]; + N50[label="ToMatrix"]; + N51[label="ToReal"]; + N52[label="ElementwiseMultiply"]; + N53[label="MatrixAdd"]; + N54[label="MatrixSum"]; + N55[label="ToNegReal"]; + N56[label="Log1mExp"]; + N57[label="Negate"]; + N58[label="ToReal"]; + N59[label="+"]; + N60[label="BernoulliLogit"]; + N61[label="~"]; + N62[label="ToMatrix"]; + N0 -> N1; + N1 -> N2; + N2 -> N14; + N2 -> N14; + N3 -> N4; + N4 -> N5; + N5 -> N16; + N5 -> N16; + N6 -> N8; + N6 -> N10; + N7 -> N8; + N7 -> N10; + N7 -> N12; + N8 -> N9; + N9 -> N29; N10 -> N11; - N11 -> N30; + N11 -> N29; N12 -> N13; - N13 -> N34; - N13 -> N37; + N13 -> N33; + N13 -> N36; N14 -> N23; N15 -> N16; N16 -> N18; @@ -171,60 +171,61 @@ def test_gep_model_compilation(self) -> None: N20 -> N21; N21 -> N22; N22 -> N23; - N23 -> N24; - N24 -> N26; + N23 -> N25; + N24 -> N25; N25 -> N26; - N26 -> N27; - N27 -> N31; - N28 -> N30; - N28 -> N42; - N28 -> N51; - N28 -> N64; + N26 -> N30; + N27 -> N29; + N27 -> N41; + N27 -> N50; + N27 -> N62; + N28 -> N29; + N28 -> 
N35; + N28 -> N41; + N28 -> N50; + N28 -> N62; N29 -> N30; - N29 -> N36; - N29 -> N42; - N29 -> N51; - N29 -> N64; - N30 -> N31; - N31 -> N33; - N31 -> N36; + N30 -> N32; + N30 -> N35; + N31 -> N32; N32 -> N33; N33 -> N34; - N34 -> N35; - N35 -> N40; - N35 -> N64; + N34 -> N39; + N34 -> N62; + N35 -> N36; N36 -> N37; - N37 -> N38; - N38 -> N41; - N38 -> N64; - N39 -> N45; - N40 -> N42; - N40 -> N47; + N37 -> N40; + N37 -> N62; + N38 -> N44; + N39 -> N41; + N39 -> N46; + N40 -> N41; + N40 -> N48; N41 -> N42; - N41 -> N49; N42 -> N43; N43 -> N44; - N44 -> N45; - N45 -> N54; - N46 -> N53; - N47 -> N48; - N48 -> N51; + N44 -> N53; + N45 -> N52; + N46 -> N47; + N47 -> N50; + N48 -> N49; N49 -> N50; N50 -> N51; N51 -> N52; N52 -> N53; N53 -> N54; N54 -> N55; + N54 -> N59; N55 -> N56; - N55 -> N60; N56 -> N57; N57 -> N58; N58 -> N59; N59 -> N60; N60 -> N61; - N61 -> N62; - N62 -> N63; - N64 -> N65; + O0[label="Observation"]; + N61 -> O0; + Q0[label="Query"]; + N62 -> Q0; } """ self.assertEqual(expected.strip(), observed.strip())