Skip to content
This repository was archived by the owner on Dec 18, 2023. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 16 additions & 11 deletions src/beanmachine/ppl/compiler/bmg_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,15 @@
# what their inputs are.

_known_requirements: Dict[type, List[bt.Requirement]] = {
# TODO See comment below regarding RealMatrix
# TODO: This is wrong in several ways.
# First, RealMatrix does not meet the contract of a requirement;
# in particular, it cannot be printed out by the requirement diagnostic
# in gen_to_dot.
# Second, it is too strict; the requirement on matrix add is actually
# that the two operands be any double matrix (real, neg real,
# pos real or probability).
# Third, this requirement is too weak; we are missing the requirement
# that the operands have the same element type and shape.
bn.ElementwiseMultiplyNode: [bt.RealMatrix, bt.RealMatrix],
bn.Observation: [bt.any_requirement],
bn.Query: [bt.any_requirement],
Expand Down Expand Up @@ -68,16 +76,6 @@
bn.LogisticNode: [bt.Real],
bn.Log1mexpNode: [bt.NegativeReal],
bn.MatrixMultiplicationNode: [bt.any_real_matrix, bt.any_real_matrix],
# TODO: This is wrong in several ways.
# First, RealMatrix does not meet the contract of a requirement;
# in particular, it cannot be printed out by the requirement diagnostic
# in gen_to_dot.
# Second, it is too strict; the requirement on matrix add is actually
# that the two operands be any double matrix (real, neg real,
# pos real or probability).
# Third, this requirement is too weak; we are missing the requirement
# that the operands have the same element type and shape.
bn.MatrixAddNode: [bt.RealMatrix, bt.RealMatrix],
bn.MatrixExpNode: [bt.any_real_matrix],
bn.MatrixLogNode: [bt.any_pos_real_matrix],
bn.MatrixLog1mexpNode: [bt.any_real_matrix],
Expand Down Expand Up @@ -127,6 +125,7 @@ def __init__(self, typer: LatticeTyper) -> None:
# TODO: bn.MatrixMultiplyNode: self._requirements_matrix_multiply,
# see comment above
bn.MatrixComplementNode: self._requrirements_matrix_complement,
bn.MatrixAddNode: self._requirements_matrix_add,
bn.MatrixScaleNode: self._requirements_matrix_scale,
bn.MultiplicationNode: self._requirements_multiplication,
bn.NegateNode: self._requirements_exp_neg,
Expand Down Expand Up @@ -447,6 +446,12 @@ def _requrirements_matrix_complement(
req = [bt.SimplexMatrix]
return req

def _requirements_matrix_add(self, node: bn.MatrixAddNode) -> List[bt.Requirement]:
# Matrix add requires that both operands be the same as the output type.
it = self.typer[node]
assert isinstance(it, bt.BMGMatrixType)
return [it, it]

def _requirements_matrix_scale(
self, node: bn.MatrixScaleNode
) -> List[bt.Requirement]:
Expand Down
10 changes: 6 additions & 4 deletions src/beanmachine/ppl/compiler/fix_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,19 +338,21 @@ def _meet_real_matrix_requirement(
edge: str,
) -> bn.BMGNode:
result = None
node_is_scalar = node_dim[0] == 1 and node_dim[1] == 1
requires_scalar = dim_req[0] == 1 and dim_req[1] == 1
req_rows, req_cols = dim_req
node_rows, node_cols = node_dim
node_is_scalar = node_rows == 1 and node_cols == 1
requires_scalar = req_rows == 1 and req_cols == 1
if requires_scalar and node_is_scalar:
result = self.bmg.add_to_real(node)
elif node_dim[0] == dim_req[0] and node_dim[1] == dim_req[1]:
elif node_rows == req_rows and node_cols == req_cols:
result = self.bmg.add_to_real_matrix(node)

if result is None:
self.errors.add_error(
Violation(
node,
self._typer[node],
bt.RealMatrix(1, 1),
bt.RealMatrix(req_rows, req_cols),
consumer,
edge,
self.bmg.execution_context.node_locations(consumer),
Expand Down
1 change: 1 addition & 0 deletions src/beanmachine/ppl/compiler/graph_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ def _val(node: bn.ConstantNode) -> str:
bn.LogSumExpVectorNode: "logsumexp",
bn.LogAddExpNode: "logaddexp",
bn.LShiftNode: "'left shift' (<<)",
bn.MatrixAddNode: "matrix add",
bn.MatrixMultiplicationNode: "matrix multiplication (@)",
bn.MatrixScaleNode: "matrix scale",
bn.ModNode: "modulus (%)",
Expand Down
72 changes: 9 additions & 63 deletions tests/ppl/compiler/broadcast_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,69 +82,15 @@ def test_broadcast_add(self) -> None:
"""
self.assertEqual(expected.strip(), observed.strip())

# After:
# We do not yet insert a broadcast node. Demonstrate here
# that the compiler gives a reasonable error message.

observed = BMGInference().to_dot(queries, observations, after_transform=True)
expected = """
digraph "graph" {
N00[label=0.0];
N01[label=1.0];
N02[label=Normal];
N03[label=Sample];
N04[label=Sample];
N05[label=Normal];
N06[label=Sample];
N07[label=Normal];
N08[label=Sample];
N09[label=Sample];
N10[label=Sample];
N11[label=Normal];
N12[label=Sample];
N13[label=Normal];
N14[label=Sample];
N15[label=2];
N16[label=1];
N17[label=ToMatrix];
N18[label=ToMatrix];
N19[label=MatrixAdd];
N20[label=Query];
N00 -> N02;
N01 -> N02;
N01 -> N05;
N01 -> N07;
N01 -> N11;
N01 -> N13;
N02 -> N03;
N02 -> N04;
N02 -> N09;
N02 -> N10;
N03 -> N05;
N04 -> N07;
N05 -> N06;
N06 -> N17;
N07 -> N08;
N08 -> N17;
N09 -> N11;
N10 -> N13;
N11 -> N12;
N12 -> N18;
N13 -> N14;
N14 -> N18;
N15 -> N17;
N15 -> N18;
N16 -> N17;
N16 -> N18;
N17 -> N19;
N18 -> N19;
N19 -> N20;
}
"""
self.assertEqual(expected.strip(), observed.strip())
with self.assertRaises(ValueError) as ex:
BMGInference().to_graph(queries, observations)

# BMG
expected = """
The left of a matrix add is required to be a 2 x 2 real matrix but is a 2 x 1 real matrix.
The right of a matrix add is required to be a 2 x 2 real matrix but is a 1 x 2 real matrix."""

with self.assertRaises(ValueError):
g, _ = BMGInference().to_graph(queries, observations)
# observed = g.to_dot()
# expected = ""
# self.assertEqual(expected.strip(), observed.strip())
observed = str(ex.exception)
self.assertEqual(expected.strip(), observed.strip())
7 changes: 6 additions & 1 deletion tests/ppl/compiler/fix_vectorized_models_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,7 +753,12 @@ def test_fix_vectorized_models_6(self) -> None:
"""
self.assertEqual(expected.strip(), observed.strip())

def test_fix_vectorized_models_7(self) -> None:
def _disabled_test_fix_vectorized_models_7(self) -> None:
# TODO: This test is disabled until broadcasting semantics
# are correctly implemented. This model generates an incorrect
# BMG graph right now because it tries to add a 2x1 matrix of
# probabilities to a 2x2 matrix of reals without inserting the
# necessary broadcasting and type conversion nodes.
self.maxDiff = None
observations = {}
queries = [operators()]
Expand Down
Loading