|
14 | 14 | from collections.abc import Sequence
|
15 | 15 |
|
16 | 16 | from pytensor import Variable, clone_replace
|
| 17 | +from pytensor.compile import SharedVariable |
17 | 18 | from pytensor.graph import ancestors
|
18 | 19 | from pytensor.graph.fg import FunctionGraph
|
19 | 20 |
|
20 |
| -from pymc.data import MinibatchOp |
| 21 | +from pymc.data import Minibatch, MinibatchOp |
21 | 22 | from pymc.model.core import Model
|
22 | 23 | from pymc.model.fgraph import (
|
23 | 24 | ModelObservedRV,
|
24 | 25 | ModelVar,
|
| 26 | + extract_dims, |
25 | 27 | fgraph_from_model,
|
26 | 28 | model_from_fgraph,
|
| 29 | + model_observed_rv, |
27 | 30 | )
|
| 31 | +from pymc.pytensorf import toposort_replace |
28 | 32 |
|
29 | 33 | ModelVariable = Variable | str
|
30 | 34 |
|
@@ -62,6 +66,47 @@ def parse_vars(model: Model, vars: ModelVariable | Sequence[ModelVariable]) -> l
|
62 | 66 | return [model[var] if isinstance(var, str) else var for var in vars_seq]
|
63 | 67 |
|
64 | 68 |
|
| 69 | +def model_to_minibatch(model: Model, batch_size: int) -> Model: |
| 70 | + """Replace all Data containers with pm.Minibatch, and add total_size to all observed RVs.""" |
| 71 | + from pymc.variational.minibatch_rv import create_minibatch_rv |
| 72 | + |
| 73 | + fgraph, memo = fgraph_from_model(model, inlined_views=True) |
| 74 | + |
| 75 | + # obs_rvs, data_vars = model.rvs_to_values.items() |
| 76 | + |
| 77 | + data_vars = [ |
| 78 | + memo[datum].owner.inputs[0] |
| 79 | + for datum in (model.named_vars[datum_name] for datum_name in model.named_vars) |
| 80 | + if isinstance(datum, SharedVariable) |
| 81 | + ] |
| 82 | + |
| 83 | + minibatch_vars = Minibatch(*data_vars, batch_size=batch_size) |
| 84 | + replacements = {datum: minibatch_vars[i] for i, datum in enumerate(data_vars)} |
| 85 | + assert 0 |
| 86 | + # Add total_size to all observed RVs |
| 87 | + total_size = data_vars[0].get_value().shape[0] |
| 88 | + for obs_var in model.observed_RVs: |
| 89 | + model_var = memo[obs_var] |
| 90 | + var = model_var.owner.inputs[0] |
| 91 | + var.name = model_var.name |
| 92 | + dims = extract_dims(model_var) |
| 93 | + |
| 94 | + new_rv = create_minibatch_rv(var, total_size=total_size) |
| 95 | + new_rv.name = var.name |
| 96 | + |
| 97 | + replacements[model_var] = model_observed_rv(new_rv, model.rvs_to_values[obs_var], *dims) |
| 98 | + |
| 99 | + # old_outs, old_coords, old_dim_lengths = fgraph.outputs, fgraph._coords, fgraph._dim_lengths |
| 100 | + toposort_replace(fgraph, tuple(replacements.items())) |
| 101 | + # new_outs = clone_replace(old_outs, replacements, rebuild_strict=False) # type: ignore[arg-type] |
| 102 | + |
| 103 | + # fgraph = FunctionGraph(outputs=new_outs, clone=False) |
| 104 | + # fgraph._coords = old_coords # type: ignore[attr-defined] |
| 105 | + # fgraph._dim_lengths = old_dim_lengths # type: ignore[attr-defined] |
| 106 | + |
| 107 | + return model_from_fgraph(fgraph, mutate_fgraph=True) |
| 108 | + |
| 109 | + |
65 | 110 | def remove_minibatched_nodes(model: Model) -> Model:
|
66 | 111 | """Remove all uses of pm.Minibatch in the Model."""
|
67 | 112 | fgraph, _ = fgraph_from_model(model)
|
|
0 commit comments