Skip to content

Make ifelse logp robust to non-lazy backends #8036

@ricardoV94

Description

@ricardoV94

Description

The way we derive the logp for the graph in this test is not robust to non-lazy evaluation:

def test_ifelse_mixture_multiple_components():
    """Check the logp of a two-output ifelse whose then-branch outputs depend on each other.

    ``comp_then2`` is built from ``comp_then1``, so the logp of ``mix2`` under the
    then-branch needs the value of ``mix1``. The else-branch components are
    independent half-normals whose shapes differ from the then-branch ones.
    """
    rng = np.random.default_rng(968)

    # Boolean switch selecting between the two sets of components.
    if_var = pt.scalar("if_var", dtype="bool")
    # Then-branch: comp_then2 is a function of comp_then1 (shapes (2,) and (2, 2)).
    comp_then1 = pt.random.normal(size=(2,), name="comp_true1")
    comp_then2 = comp_then1 + pt.random.normal(size=(2, 2))
    comp_then2.name = "comp_then2"
    # Else-branch: independent components with shapes (4,) and (4, 4).
    comp_else1 = pt.random.halfnormal(size=(4,), name="comp_else1")
    comp_else2 = pt.random.halfnormal(size=(4, 4), name="comp_else2")

    mix_rv1, mix_rv2 = ifelse(
        if_var, [comp_then1, comp_then2], [comp_else1, comp_else2], name="mix"
    )
    mix_rv1.name = "mix1"
    mix_rv2.name = "mix2"
    # Fresh copies of the mixture outputs, used as the value variables.
    mix_vv1 = mix_rv1.clone()
    mix_vv2 = mix_rv2.clone()
    mix_logp1, mix_logp2 = conditional_logp({mix_rv1: mix_vv1, mix_rv2: mix_vv2}).values()
    # The derived logp graphs must not contain any random variables.
    assert_no_rvs(mix_logp1)
    assert_no_rvs(mix_logp2)

    fn = function([if_var, mix_vv1, mix_vv2], mix_logp1.sum() + mix_logp2.sum())
    # Then-branch check: mix1 ~ N(0, 1) and mix2 ~ N(value of mix1, 1).
    # The order of the rng draws below determines the test values; keep it stable.
    mix_vv1_test = np.abs(rng.normal(size=(2,)))
    mix_vv2_test = np.abs(rng.normal(size=(2, 2)))
    np.testing.assert_almost_equal(
        fn(True, mix_vv1_test, mix_vv2_test),
        sp.norm(0, 1).logpdf(mix_vv1_test).sum()
        + sp.norm(mix_vv1_test, 1).logpdf(mix_vv2_test).sum(),
    )
    # Else-branch check: values carry the else-branch shapes (4,) and (4, 4).
    # NOTE(review): this evaluation is the one reported to fail with a shape error
    # on non-lazy backends, which evaluate both ifelse branches regardless of if_var.
    mix_vv1_test = np.abs(rng.normal(size=(4,)))
    mix_vv2_test = np.abs(rng.normal(size=(4, 4)))
    np.testing.assert_almost_equal(
        fn(False, mix_vv1_test, mix_vv2_test),
        sp.halfnorm(0, 1).logpdf(mix_vv1_test).sum() + sp.halfnorm(0, 1).logpdf(mix_vv2_test).sum(),
    )

Specifically, we split an ifelse with multiple outputs into a chain of single-output ifelse nodes.
So we end up with `ifelse(if_var, comp_then1, comp_else1)` and `ifelse(if_var, comp_then2, comp_else2)`.

Note that comp_then2 is a function of comp_then1. So far so good, but our rewrite makes it a function of the new (first) ifelse instead. That ifelse can technically output comp_else1, whose shape is incompatible: (4,) instead of the (2,) of comp_then1. In a non-lazy backend, ifelse evaluates both branches and only then picks the one the condition selects, so the then-branch is computed even when the condition is False. This leads to a shape error in the last `fn` evaluation in the test (the one with `if_var=False`).

A more robust approach is to keep the original multi-output ifelse intact and define its logp as a joint function of its valued outputs (two RVs in this test). Splitting into separate ifelse nodes was done for developer convenience, because it makes the logp function simpler. We can handle the dependency inside the logp function instead. We just need to do something similar to what the stack logp does, because the "value" of one of the variables may be needed to compute the logp of another (as is the case here).

pymc/pymc/logprob/tensor.py

Lines 136 to 137 in b935d0d

# If the stacked variables depend on each other, we have to replace them by the respective values
logps = replace_rvs_by_values(logps, rvs_to_values=base_rvs_to_split_values)

The rewrite that does the splitting is here:

@node_rewriter([IfElse])
def split_valued_ifelse(fgraph, node):
    """Split valued variables in multi-output ifelse into their own ifelse.

    The valued output whose branch outputs appear first in the graph's
    topological order gets its own single-output ``ifelse``. All remaining valued
    outputs are grouped into one more ``ifelse``, in which references to the first
    output's then/else branch variables are rewired to the new first valued ifelse.

    Returns a dict mapping each original valued output to its replacement, or
    ``None`` to leave the node unchanged.

    NOTE(review): after this rewrite the remaining ifelse depends on the first
    valued ifelse, whose value may come from either branch. A non-lazy backend
    that evaluates both branches can therefore hit shape errors when the
    then/else branches have incompatible shapes.
    """
    op = node.op

    if op.n_outs == 1:
        # Single-output IfElse: nothing to split.
        return None

    # Nodes that attach a value variable to one of this ifelse's outputs.
    valued_output_nodes = get_related_valued_nodes(fgraph, node)
    if not valued_output_nodes:
        return None

    # IfElse inputs are laid out as [cond, *then_branch, *else_branch].
    cond, *all_outputs = node.inputs
    then_outputs = all_outputs[: op.n_outs]
    else_outputs = all_outputs[op.n_outs :]

    # Split first topological valued output
    # Collect (then_branch_var, else_branch_var, value_var, valued_output) per valued output.
    then_else_valued_outputs = []
    for valued_output_node in valued_output_nodes:
        rv, value = valued_output_node.inputs
        [valued_out] = valued_output_node.outputs
        rv_idx = node.outputs.index(rv)
        then_else_valued_outputs.append(
            (
                then_outputs[rv_idx],
                else_outputs[rv_idx],
                value,
                valued_out,
            )
        )

    # Order by the later toposort position of each pair of branch variables, so the
    # first entry is the one whose branches are computed earliest in the graph.
    # (toposort.index is a linear scan per key call.)
    toposort = fgraph.toposort()
    then_else_valued_outputs = sorted(
        then_else_valued_outputs,
        key=lambda x: max(toposort.index(x[0].owner), toposort.index(x[1].owner)),
    )

    (first_then, first_else, first_value_var, first_valued_out), *remaining_vars = (
        then_else_valued_outputs
    )
    # Dedicated single-output ifelse for the first valued output.
    first_ifelse = ifelse(cond, first_then, first_else)
    first_valued_ifelse = valued_rv(first_ifelse, first_value_var)
    replacements = {first_valued_out: first_valued_ifelse}

    if remaining_vars:
        # Non-root variables that the first output's branches depend on.
        first_ifelse_ancestors = {a for a in ancestors((first_then, first_else)) if a.owner}
        remaining_thens = [then_out for (then_out, _, _, _) in remaining_vars]
        remaininng_elses = [else_out for (_, else_out, _, _) in remaining_vars]
        if set(remaining_thens + remaininng_elses) & first_ifelse_ancestors:
            # IfElse graph cannot be split, because some remaining variables are inputs to first ifelse
            return None

        # One ifelse holding all the remaining valued outputs.
        remaining_ifelses = ifelse(cond, remaining_thens, remaininng_elses)
        # Replace potential dependencies on first_then, first_else in remaining ifelse by first_valued_ifelse
        # first_valued_ifelse itself has first_then/first_else as ancestors, so they are
        # first swapped for a dummy placeholder, and the placeholder is then swapped for
        # first_valued_ifelse (presumably to avoid creating a cycle in temp_fgraph).
        dummy_first_valued_ifelse = first_valued_ifelse.type()
        temp_fgraph = FunctionGraph(
            outputs=[*remaining_ifelses, dummy_first_valued_ifelse], clone=False
        )
        temp_fgraph.replace(first_then, dummy_first_valued_ifelse)
        temp_fgraph.replace(first_else, dummy_first_valued_ifelse)
        temp_fgraph.replace(dummy_first_valued_ifelse, first_valued_ifelse, import_missing=True)
        # Attach each remaining output's value variable to its new ifelse output.
        for remaining_ifelse, (_, _, remaining_value_var, remaining_valued_out) in zip(
            remaining_ifelses, remaining_vars
        ):
            remaining_valued_ifelse = valued_rv(remaining_ifelse, remaining_value_var)
            replacements[remaining_valued_out] = remaining_valued_ifelse

    return replacements

Note that we can (and probably still want to) exclude any non-valued outputs from the ifelse to simplify the logic.

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions