-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Description
The way we derive the logp for the graph in this test is not robust to non-lazy evaluation:
def test_ifelse_mixture_multiple_components():
    """Test the logp of a two-output ``ifelse`` mixture with dependent components.

    ``comp_then2`` is defined on top of ``comp_then1``, so in the "then" case the
    logp of ``mix2`` needs the value of ``mix1``. Expected values are computed
    with ``scipy.stats`` for both settings of ``if_var``.

    NOTE(review): with the current split-ifelse rewrite, the then-branch of
    ``mix2`` is made a function of the split ``mix1`` ifelse, whose else-branch
    has shape (4,) instead of (2,). A non-lazy backend evaluates both branches,
    so the final ``fn(False, ...)`` call is expected to raise a shape error
    there; the test relies on lazy ``ifelse`` evaluation.
    """
    rng = np.random.default_rng(968)

    if_var = pt.scalar("if_var", dtype="bool")

    # "then" components: comp_then2 depends on comp_then1
    comp_then1 = pt.random.normal(size=(2,), name="comp_true1")
    comp_then2 = comp_then1 + pt.random.normal(size=(2, 2))
    comp_then2.name = "comp_then2"

    # "else" components: independent halfnormals with different shapes
    comp_else1 = pt.random.halfnormal(size=(4,), name="comp_else1")
    comp_else2 = pt.random.halfnormal(size=(4, 4), name="comp_else2")

    mix_rv1, mix_rv2 = ifelse(
        if_var, [comp_then1, comp_then2], [comp_else1, comp_else2], name="mix"
    )
    mix_rv1.name = "mix1"
    mix_rv2.name = "mix2"

    mix_vv1 = mix_rv1.clone()
    mix_vv2 = mix_rv2.clone()

    mix_logp1, mix_logp2 = conditional_logp({mix_rv1: mix_vv1, mix_rv2: mix_vv2}).values()
    assert_no_rvs(mix_logp1)
    assert_no_rvs(mix_logp2)

    fn = function([if_var, mix_vv1, mix_vv2], mix_logp1.sum() + mix_logp2.sum())

    # Condition True: normal components, mix2 centered on the value of mix1
    mix_vv1_test = np.abs(rng.normal(size=(2,)))
    mix_vv2_test = np.abs(rng.normal(size=(2, 2)))
    np.testing.assert_almost_equal(
        fn(True, mix_vv1_test, mix_vv2_test),
        sp.norm(0, 1).logpdf(mix_vv1_test).sum()
        + sp.norm(mix_vv1_test, 1).logpdf(mix_vv2_test).sum(),
    )

    # Condition False: independent halfnormal components
    mix_vv1_test = np.abs(rng.normal(size=(4,)))
    mix_vv2_test = np.abs(rng.normal(size=(4, 4)))
    np.testing.assert_almost_equal(
        fn(False, mix_vv1_test, mix_vv2_test),
        sp.halfnorm(0, 1).logpdf(mix_vv1_test).sum() + sp.halfnorm(0, 1).logpdf(mix_vv2_test).sum(),
    )


# Specifically, what we do is split the ifelse with multiple variables into
# multiple separate (single-output) ifelse.
So we end up with `ifelse(if_var, comp_then1, comp_else1)` and `ifelse(if_var, comp_then2, comp_else2)`.
Note that `comp_then2` is a function of `comp_then1`. So far so good, but our rewrite makes it a function of the new `ifelse`, whose output can be `comp_else1`, which has an incompatible shape ((4,) instead of (2,)). In a non-lazy backend, `ifelse` evaluates both branches and then picks the one the condition selects. This leads to a shape error in the last `fn` evaluation of the test.
A more robust approach is to keep the original multi-output ifelse intact and define its logp as a joint function of both RVs. The splitting into separate ifelse was done for developer convenience, as it makes the logp function simpler. But we can handle the dependency in the logp function instead; we just need to do something like the stack logp does, since the "value" of one of the variables may be needed to compute the logp of the other (as is the case here).
Lines 136 to 137 in b935d0d
| # If the stacked variables depend on each other, we have to replace them by the respective values | |
| logps = replace_rvs_by_values(logps, rvs_to_values=base_rvs_to_split_values) |
The rewrite that does the splitting is here:
Lines 464 to 532 in b935d0d
@node_rewriter([IfElse])
def split_valued_ifelse(fgraph, node):
    """Split valued variables in multi-output ifelse into their own ifelse.

    A single split is performed per call:

    * The valued output whose branches come earliest in topological order is
      moved into its own single-output ``ifelse``.
    * All remaining valued outputs are gathered into a new multi-output
      ``ifelse``.
    * Any dependency of the remaining branches on the first output's
      then/else branches is redirected to the first output's new valued
      ``ifelse``.

    Returns ``None`` when:

    * the node has a single output;
    * no output is valued; or
    * a remaining valued output is an ancestor of the first output's branches,
      so the graph cannot be split.

    Otherwise, returns a dict mapping each original valued output to its
    replacement.

    NOTE(review): redirecting both branches of the remaining outputs to the
    first valued ``ifelse`` is only safe under lazy evaluation. A non-lazy
    backend evaluates both branches, so a then-branch can be combined with a
    value that has the else-branch shape (and vice versa), producing shape
    errors. Deriving a joint logp for the unsplit multi-output ``ifelse``
    would avoid this.
    """
    op = node.op

    if op.n_outs == 1:
        # Single-output IfElse: nothing to split
        return None

    valued_output_nodes = get_related_valued_nodes(fgraph, node)
    if not valued_output_nodes:
        return None

    cond, *all_outputs = node.inputs
    then_outputs = all_outputs[: op.n_outs]
    else_outputs = all_outputs[op.n_outs :]

    # Collect (then_branch, else_branch, value, valued_output) per valued output
    then_else_valued_outputs = []
    for valued_output_node in valued_output_nodes:
        rv, value = valued_output_node.inputs
        [valued_out] = valued_output_node.outputs
        rv_idx = node.outputs.index(rv)
        then_else_valued_outputs.append(
            (then_outputs[rv_idx], else_outputs[rv_idx], value, valued_out)
        )

    # Build the topological position lookup once, instead of an O(n)
    # ``list.index`` per sort-key call. Root variables (owner is None) are not
    # in the toposort; placing them before every node avoids a ValueError.
    topo_position = {apply: i for i, apply in enumerate(fgraph.toposort())}

    def _topo_key(entry):
        """Return the latest topological position of an entry's two branches."""
        then_out, else_out, _, _ = entry
        return max(
            topo_position.get(then_out.owner, -1),
            topo_position.get(else_out.owner, -1),
        )

    # Split first topological valued output
    then_else_valued_outputs.sort(key=_topo_key)

    (first_then, first_else, first_value_var, first_valued_out), *remaining_vars = (
        then_else_valued_outputs
    )
    first_ifelse = ifelse(cond, first_then, first_else)
    first_valued_ifelse = valued_rv(first_ifelse, first_value_var)
    replacements = {first_valued_out: first_valued_ifelse}

    if not remaining_vars:
        return replacements

    first_ifelse_ancestors = {a for a in ancestors((first_then, first_else)) if a.owner}
    remaining_thens = [then_out for (then_out, _, _, _) in remaining_vars]
    remaining_elses = [else_out for (_, else_out, _, _) in remaining_vars]
    if set(remaining_thens + remaining_elses) & first_ifelse_ancestors:
        # IfElse graph cannot be split, because some remaining variables are
        # inputs to the first ifelse
        return None

    remaining_ifelses = ifelse(cond, remaining_thens, remaining_elses)

    # Replace potential dependencies on first_then, first_else in the remaining
    # ifelse by first_valued_ifelse. Presumably the dummy placeholder is used
    # because first_valued_ifelse itself depends on first_then/first_else, so it
    # is only imported into the graph after both branches have been swapped out.
    dummy_first_valued_ifelse = first_valued_ifelse.type()
    temp_fgraph = FunctionGraph(
        outputs=[*remaining_ifelses, dummy_first_valued_ifelse], clone=False
    )
    temp_fgraph.replace(first_then, dummy_first_valued_ifelse)
    temp_fgraph.replace(first_else, dummy_first_valued_ifelse)
    temp_fgraph.replace(dummy_first_valued_ifelse, first_valued_ifelse, import_missing=True)

    for remaining_ifelse, (_, _, remaining_value_var, remaining_valued_out) in zip(
        remaining_ifelses, remaining_vars
    ):
        replacements[remaining_valued_out] = valued_rv(remaining_ifelse, remaining_value_var)

    return replacements
Note that we can (and probably still want to) exclude any non-valued branches from the ifelse to simplify the logic.