From 0be1379d5c39429389134d9394c0009a747ed8ad Mon Sep 17 00:00:00 2001 From: Eric Lippert Date: Mon, 10 Oct 2022 17:45:03 -0700 Subject: [PATCH] Continue implementing broadcasting semantics in compiler Summary: We're continuing to add broadcasting semantics by small steps. In this diff: * I've fixed a bug in how requirements on the matrix add incoming edges are represented. The correct representation is that the input types must be identical to the output types. In the test case, the output type is a 2x2 real matrix, so the inputs must also be 2x2 real matrices. * I've fixed a bug in the error message when that requirement is not met; previously it said that the requirement was "a real" regardless of what the actual requirement was. The requirement is "a real 2x2 matrix" in this example, so that's what we say in the error message. * The code which handled detecting that there's a problem with the dimensions was hard to read. I refactored it to introduce explanatory variables. Since we do not yet introduce a broadcast node, the correct thing to do here is to give an error, which we now do. In an upcoming diff we'll instead insert a broadcast node. Reviewed By: AishwaryaSivaraman Differential Revision: D40046224 fbshipit-source-id: 30e605dfdb52bed68144ad982a05155547213716 --- .../ppl/compiler/bmg_requirements.py | 27 +- .../ppl/compiler/fix_requirements.py | 10 +- src/beanmachine/ppl/compiler/graph_labels.py | 1 + tests/ppl/compiler/broadcast_test.py | 72 +----- .../compiler/fix_vectorized_models_test.py | 7 +- tests/ppl/compiler/gep_test.py | 235 +++++++++--------- 6 files changed, 156 insertions(+), 196 deletions(-) diff --git a/src/beanmachine/ppl/compiler/bmg_requirements.py b/src/beanmachine/ppl/compiler/bmg_requirements.py index 458ca3ae21..41dd2c8f03 100644 --- a/src/beanmachine/ppl/compiler/bmg_requirements.py +++ b/src/beanmachine/ppl/compiler/bmg_requirements.py @@ -35,7 +35,15 @@ # what their inputs are. 
_known_requirements: Dict[type, List[bt.Requirement]] = { - # TODO See comment below regarding RealMatrix + # TODO: This is wrong in several ways. + # First, RealMatrix does not meet the contract of a requirement; + # in particular, it cannot be printed out by the requirement diagnostic + # in gen_to_dot. + # Second, it is too strict; as with matrix add, the requirement is likely + # that the two operands be any double matrix (real, neg real, + # pos real or probability). + # Third, this requirement is too weak; we are missing the requirement + # that the operands have the same element type and shape. bn.ElementwiseMultiplyNode: [bt.RealMatrix, bt.RealMatrix], bn.Observation: [bt.any_requirement], bn.Query: [bt.any_requirement], @@ -68,16 +76,6 @@ bn.LogisticNode: [bt.Real], bn.Log1mexpNode: [bt.NegativeReal], bn.MatrixMultiplicationNode: [bt.any_real_matrix, bt.any_real_matrix], - # TODO: This is wrong in several ways. - # First, RealMatrix does not meet the contract of a requirement; - # in particular, it cannot be printed out by the requirement diagnostic - # in gen_to_dot. - # Second, it is too strict; the requirement on matrix add is actually - # that the two operands be any double matrix (real, neg real, - # pos real or probability). - # Third, this requirement is too weak; we are missing the requirement - # that the operands have the same element type and shape. 
- bn.MatrixAddNode: [bt.RealMatrix, bt.RealMatrix], bn.MatrixExpNode: [bt.any_real_matrix], bn.MatrixLogNode: [bt.any_pos_real_matrix], bn.MatrixLog1mexpNode: [bt.any_real_matrix], @@ -127,6 +125,7 @@ def __init__(self, typer: LatticeTyper) -> None: # TODO: bn.MatrixMultiplyNode: self._requirements_matrix_multiply, # see comment above bn.MatrixComplementNode: self._requrirements_matrix_complement, + bn.MatrixAddNode: self._requirements_matrix_add, bn.MatrixScaleNode: self._requirements_matrix_scale, bn.MultiplicationNode: self._requirements_multiplication, bn.NegateNode: self._requirements_exp_neg, @@ -447,6 +446,12 @@ def _requrirements_matrix_complement( req = [bt.SimplexMatrix] return req + def _requirements_matrix_add(self, node: bn.MatrixAddNode) -> List[bt.Requirement]: + # Matrix add requires that both operands be the same as the output type. + it = self.typer[node] + assert isinstance(it, bt.BMGMatrixType) + return [it, it] + def _requirements_matrix_scale( self, node: bn.MatrixScaleNode ) -> List[bt.Requirement]: diff --git a/src/beanmachine/ppl/compiler/fix_requirements.py b/src/beanmachine/ppl/compiler/fix_requirements.py index f38b051495..953f4802fa 100644 --- a/src/beanmachine/ppl/compiler/fix_requirements.py +++ b/src/beanmachine/ppl/compiler/fix_requirements.py @@ -338,11 +338,13 @@ def _meet_real_matrix_requirement( edge: str, ) -> bn.BMGNode: result = None - node_is_scalar = node_dim[0] == 1 and node_dim[1] == 1 - requires_scalar = dim_req[0] == 1 and dim_req[1] == 1 + req_rows, req_cols = dim_req + node_rows, node_cols = node_dim + node_is_scalar = node_rows == 1 and node_cols == 1 + requires_scalar = req_rows == 1 and req_cols == 1 if requires_scalar and node_is_scalar: result = self.bmg.add_to_real(node) - elif node_dim[0] == dim_req[0] and node_dim[1] == dim_req[1]: + elif node_rows == req_rows and node_cols == req_cols: result = self.bmg.add_to_real_matrix(node) if result is None: @@ -350,7 +352,7 @@ def _meet_real_matrix_requirement( 
Violation( node, self._typer[node], - bt.RealMatrix(1, 1), + bt.RealMatrix(req_rows, req_cols), consumer, edge, self.bmg.execution_context.node_locations(consumer), diff --git a/src/beanmachine/ppl/compiler/graph_labels.py b/src/beanmachine/ppl/compiler/graph_labels.py index 4f83871f1f..0fa9b08e6c 100644 --- a/src/beanmachine/ppl/compiler/graph_labels.py +++ b/src/beanmachine/ppl/compiler/graph_labels.py @@ -205,6 +205,7 @@ def _val(node: bn.ConstantNode) -> str: bn.LogSumExpVectorNode: "logsumexp", bn.LogAddExpNode: "logaddexp", bn.LShiftNode: "'left shift' (<<)", + bn.MatrixAddNode: "matrix add", bn.MatrixMultiplicationNode: "matrix multiplication (@)", bn.MatrixScaleNode: "matrix scale", bn.ModNode: "modulus (%)", diff --git a/tests/ppl/compiler/broadcast_test.py b/tests/ppl/compiler/broadcast_test.py index 9a25180714..2542137bb7 100644 --- a/tests/ppl/compiler/broadcast_test.py +++ b/tests/ppl/compiler/broadcast_test.py @@ -82,69 +82,15 @@ def test_broadcast_add(self) -> None: """ self.assertEqual(expected.strip(), observed.strip()) - # After: + # We do not yet insert a broadcast node. Demonstrate here + # that the compiler gives a reasonable error message. 
- observed = BMGInference().to_dot(queries, observations, after_transform=True) - expected = """ -digraph "graph" { - N00[label=0.0]; - N01[label=1.0]; - N02[label=Normal]; - N03[label=Sample]; - N04[label=Sample]; - N05[label=Normal]; - N06[label=Sample]; - N07[label=Normal]; - N08[label=Sample]; - N09[label=Sample]; - N10[label=Sample]; - N11[label=Normal]; - N12[label=Sample]; - N13[label=Normal]; - N14[label=Sample]; - N15[label=2]; - N16[label=1]; - N17[label=ToMatrix]; - N18[label=ToMatrix]; - N19[label=MatrixAdd]; - N20[label=Query]; - N00 -> N02; - N01 -> N02; - N01 -> N05; - N01 -> N07; - N01 -> N11; - N01 -> N13; - N02 -> N03; - N02 -> N04; - N02 -> N09; - N02 -> N10; - N03 -> N05; - N04 -> N07; - N05 -> N06; - N06 -> N17; - N07 -> N08; - N08 -> N17; - N09 -> N11; - N10 -> N13; - N11 -> N12; - N12 -> N18; - N13 -> N14; - N14 -> N18; - N15 -> N17; - N15 -> N18; - N16 -> N17; - N16 -> N18; - N17 -> N19; - N18 -> N19; - N19 -> N20; -} -""" - self.assertEqual(expected.strip(), observed.strip()) + with self.assertRaises(ValueError) as ex: + BMGInference().to_graph(queries, observations) - # BMG + expected = """ +The left of a matrix add is required to be a 2 x 2 real matrix but is a 2 x 1 real matrix. 
+The right of a matrix add is required to be a 2 x 2 real matrix but is a 1 x 2 real matrix.""" - with self.assertRaises(ValueError): - g, _ = BMGInference().to_graph(queries, observations) - # observed = g.to_dot() - # expected = "" - # self.assertEqual(expected.strip(), observed.strip()) + observed = str(ex.exception) + self.assertEqual(expected.strip(), observed.strip()) diff --git a/tests/ppl/compiler/fix_vectorized_models_test.py b/tests/ppl/compiler/fix_vectorized_models_test.py index 8b3d0f0a2a..ae2262ba50 100644 --- a/tests/ppl/compiler/fix_vectorized_models_test.py +++ b/tests/ppl/compiler/fix_vectorized_models_test.py @@ -753,7 +753,12 @@ def test_fix_vectorized_models_6(self) -> None: """ self.assertEqual(expected.strip(), observed.strip()) - def test_fix_vectorized_models_7(self) -> None: + def _disabled_test_fix_vectorized_models_7(self) -> None: + # TODO: This test is disabled until broadcasting semantics + # are correctly implemented. This model generates an incorrect + # BMG graph right now because it tries to add a 2x1 matrix of + # probabilities to a 2x2 matrix of reals without inserting the + # necessary broadcasting and type conversion nodes. self.maxDiff = None observations = {} queries = [operators()] diff --git a/tests/ppl/compiler/gep_test.py b/tests/ppl/compiler/gep_test.py index 940a4507d5..235aea7f53 100644 --- a/tests/ppl/compiler/gep_test.py +++ b/tests/ppl/compiler/gep_test.py @@ -72,96 +72,96 @@ def test_gep_model_compilation(self) -> None: self.maxDiff = None queries = [prev()] observations = {bucket_prob(): torch.tensor([1.0])} - observed = BMGInference().to_dot(queries, observations) + # Demonstrate that compiling to an actual BMG graph + # generates a graph which type checks. 
+ g, _ = BMGInference().to_graph(queries, observations) + observed = g.to_dot() expected = """ digraph "graph" { - N00[label=5.0]; - N01[label=HalfNormal]; - N02[label=Sample]; - N03[label=0.10000000149011612]; - N04[label=HalfNormal]; - N05[label=Sample]; - N06[label=0.0]; - N07[label=1.0]; - N08[label=Normal]; - N09[label=Sample]; - N10[label=Normal]; - N11[label=Sample]; - N12[label=HalfNormal]; - N13[label=Sample]; + N0[label="5"]; + N1[label="HalfNormal"]; + N2[label="~"]; + N3[label="0.1"]; + N4[label="HalfNormal"]; + N5[label="~"]; + N6[label="0"]; + N7[label="1"]; + N8[label="Normal"]; + N9[label="~"]; + N10[label="Normal"]; + N11[label="~"]; + N12[label="HalfNormal"]; + N13[label="~"]; N14[label="*"]; - N15[label=2.0]; + N15[label="2"]; N16[label="*"]; - N17[label=-1.0]; - N18[label="**"]; - N19[label=ToReal]; - N20[label="[[-0.0,-8.836000051815063e-05],\\\\n[-8.836000051815063e-05,-0.0]]"]; - N21[label=MatrixScale]; - N22[label=MatrixExp]; - N23[label=MatrixScale]; - N24[label=ToRealMatrix]; - N25[label="[[0.0010000000474974513,0.0],\\\\n[0.0,0.0010000000474974513]]"]; - N26[label=MatrixAdd]; - N27[label=Cholesky]; - N28[label=2]; - N29[label=1]; - N30[label=ToMatrix]; - N31[label="@"]; - N32[label=0]; - N33[label=index]; - N34[label=Normal]; - N35[label=Sample]; - N36[label=index]; - N37[label=Normal]; - N38[label=Sample]; - N39[label="[4.0,0.0]"]; - N40[label=Phi]; - N41[label=Phi]; - N42[label=ToMatrix]; - N43[label=MatrixLog]; - N44[label=ToRealMatrix]; - N45[label=ElementwiseMult]; - N46[label="[29850.0,2016.0]"]; - N47[label=complement]; - N48[label=Log]; - N49[label=complement]; - N50[label=Log]; - N51[label=ToMatrix]; - N52[label=ToRealMatrix]; - N53[label=ElementwiseMult]; - N54[label=MatrixAdd]; - N55[label=MatrixSum]; - N56[label=ToNegReal]; - N57[label=Log1mexp]; - N58[label="-"]; - N59[label=ToReal]; - N60[label="+"]; - N61[label="Bernoulli(logits)"]; - N62[label=Sample]; - N63[label="Observation True"]; - N64[label=ToMatrix]; - 
N65[label=Query]; - N00 -> N01; - N01 -> N02; - N02 -> N14; - N02 -> N14; - N03 -> N04; - N04 -> N05; - N05 -> N16; - N05 -> N16; - N06 -> N08; - N06 -> N10; - N07 -> N08; - N07 -> N10; - N07 -> N12; - N08 -> N09; - N09 -> N30; + N17[label="-1"]; + N18[label="^"]; + N19[label="ToReal"]; + N20[label="matrix"]; + N21[label="MatrixScale"]; + N22[label="MatrixExp"]; + N23[label="MatrixScale"]; + N24[label="matrix"]; + N25[label="MatrixAdd"]; + N26[label="Cholesky"]; + N27[label="2"]; + N28[label="1"]; + N29[label="ToMatrix"]; + N30[label="MatrixMultiply"]; + N31[label="0"]; + N32[label="Index"]; + N33[label="Normal"]; + N34[label="~"]; + N35[label="Index"]; + N36[label="Normal"]; + N37[label="~"]; + N38[label="matrix"]; + N39[label="Phi"]; + N40[label="Phi"]; + N41[label="ToMatrix"]; + N42[label="MatrixLog"]; + N43[label="ToReal"]; + N44[label="ElementwiseMultiply"]; + N45[label="matrix"]; + N46[label="Complement"]; + N47[label="Log"]; + N48[label="Complement"]; + N49[label="Log"]; + N50[label="ToMatrix"]; + N51[label="ToReal"]; + N52[label="ElementwiseMultiply"]; + N53[label="MatrixAdd"]; + N54[label="MatrixSum"]; + N55[label="ToNegReal"]; + N56[label="Log1mExp"]; + N57[label="Negate"]; + N58[label="ToReal"]; + N59[label="+"]; + N60[label="BernoulliLogit"]; + N61[label="~"]; + N62[label="ToMatrix"]; + N0 -> N1; + N1 -> N2; + N2 -> N14; + N2 -> N14; + N3 -> N4; + N4 -> N5; + N5 -> N16; + N5 -> N16; + N6 -> N8; + N6 -> N10; + N7 -> N8; + N7 -> N10; + N7 -> N12; + N8 -> N9; + N9 -> N29; N10 -> N11; - N11 -> N30; + N11 -> N29; N12 -> N13; - N13 -> N34; - N13 -> N37; + N13 -> N33; + N13 -> N36; N14 -> N23; N15 -> N16; N16 -> N18; @@ -171,60 +171,61 @@ def test_gep_model_compilation(self) -> None: N20 -> N21; N21 -> N22; N22 -> N23; - N23 -> N24; - N24 -> N26; + N23 -> N25; + N24 -> N25; N25 -> N26; - N26 -> N27; - N27 -> N31; - N28 -> N30; - N28 -> N42; - N28 -> N51; - N28 -> N64; + N26 -> N30; + N27 -> N29; + N27 -> N41; + N27 -> N50; + N27 -> N62; + N28 -> N29; + N28 -> 
N35; + N28 -> N41; + N28 -> N50; + N28 -> N62; N29 -> N30; - N29 -> N36; - N29 -> N42; - N29 -> N51; - N29 -> N64; - N30 -> N31; - N31 -> N33; - N31 -> N36; + N30 -> N32; + N30 -> N35; + N31 -> N32; N32 -> N33; N33 -> N34; - N34 -> N35; - N35 -> N40; - N35 -> N64; + N34 -> N39; + N34 -> N62; + N35 -> N36; N36 -> N37; - N37 -> N38; - N38 -> N41; - N38 -> N64; - N39 -> N45; - N40 -> N42; - N40 -> N47; + N37 -> N40; + N37 -> N62; + N38 -> N44; + N39 -> N41; + N39 -> N46; + N40 -> N41; + N40 -> N48; N41 -> N42; - N41 -> N49; N42 -> N43; N43 -> N44; - N44 -> N45; - N45 -> N54; - N46 -> N53; - N47 -> N48; - N48 -> N51; + N44 -> N53; + N45 -> N52; + N46 -> N47; + N47 -> N50; + N48 -> N49; N49 -> N50; N50 -> N51; N51 -> N52; N52 -> N53; N53 -> N54; N54 -> N55; + N54 -> N59; N55 -> N56; - N55 -> N60; N56 -> N57; N57 -> N58; N58 -> N59; N59 -> N60; N60 -> N61; - N61 -> N62; - N62 -> N63; - N64 -> N65; + O0[label="Observation"]; + N61 -> O0; + Q0[label="Query"]; + N62 -> Q0; } """ self.assertEqual(expected.strip(), observed.strip())