Skip to content

Commit 8d1b479

Browse files
jessegrabowski authored and zaxtax committed
initial PR
1 parent e6d3390 commit 8d1b479

File tree

2 files changed

+73
-2
lines changed

2 files changed

+73
-2
lines changed

pymc/model/transform/basic.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,21 @@
1414
from collections.abc import Sequence
1515

1616
from pytensor import Variable, clone_replace
17+
from pytensor.compile import SharedVariable
1718
from pytensor.graph import ancestors
1819
from pytensor.graph.fg import FunctionGraph
1920

20-
from pymc.data import MinibatchOp
21+
from pymc.data import Minibatch, MinibatchOp
2122
from pymc.model.core import Model
2223
from pymc.model.fgraph import (
2324
ModelObservedRV,
2425
ModelVar,
26+
extract_dims,
2527
fgraph_from_model,
2628
model_from_fgraph,
29+
model_observed_rv,
2730
)
31+
from pymc.pytensorf import toposort_replace
2832

2933
ModelVariable = Variable | str
3034

@@ -62,6 +66,47 @@ def parse_vars(model: Model, vars: ModelVariable | Sequence[ModelVariable]) -> l
6266
return [model[var] if isinstance(var, str) else var for var in vars_seq]
6367

6468

69+
def model_to_minibatch(model: Model, batch_size: int) -> Model:
    """Replace all Data containers with pm.Minibatch, and add total_size to all observed RVs.

    Parameters
    ----------
    model : Model
        Model whose named shared variables (data containers) are replaced.
    batch_size : int
        Batch size forwarded to ``Minibatch``.

    Returns
    -------
    Model
        The model built by ``model_from_fgraph`` from the rewritten graph.

    Raises
    ------
    ValueError
        If ``model`` has no named shared variable to minibatch.
    """
    from pymc.variational.minibatch_rv import create_minibatch_rv

    fgraph, memo = fgraph_from_model(model, inlined_views=True)

    # For every named shared variable, take the first input of its node in the
    # fgraph; this is the variable that gets swapped for its minibatched version.
    data_vars = [
        memo[datum].owner.inputs[0]
        for datum in model.named_vars.values()
        if isinstance(datum, SharedVariable)
    ]
    if not data_vars:
        raise ValueError("model_to_minibatch requires a model with at least one Data container.")

    minibatch_vars = Minibatch(*data_vars, batch_size=batch_size)
    # Minibatch hands back a bare variable (not a sequence) when given a single
    # input; normalise so each data variable is paired with its own minibatch.
    if not isinstance(minibatch_vars, (list, tuple)):
        minibatch_vars = [minibatch_vars]
    replacements = dict(zip(data_vars, minibatch_vars))

    # Add total_size to all observed RVs
    # NOTE(review): uses the leading dimension of the first data container only;
    # assumes every container shares that length -- TODO confirm / validate.
    total_size = data_vars[0].get_value().shape[0]
    for obs_var in model.observed_RVs:
        model_var = memo[obs_var]
        var = model_var.owner.inputs[0]
        # Carry the model-level name over to the inner RV and its replacement.
        var.name = model_var.name
        dims = extract_dims(model_var)

        new_rv = create_minibatch_rv(var, total_size=total_size)
        new_rv.name = var.name

        replacements[model_var] = model_observed_rv(new_rv, model.rvs_to_values[obs_var], *dims)

    toposort_replace(fgraph, tuple(replacements.items()))

    return model_from_fgraph(fgraph, mutate_fgraph=True)
108+
109+
65110
def remove_minibatched_nodes(model: Model) -> Model:
66111
"""Remove all uses of pm.Minibatch in the Model."""
67112
fgraph, _ = fgraph_from_model(model)

tests/model/transform/test_basic.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,11 @@
1515

1616
import pymc as pm
1717

18-
from pymc.model.transform.basic import prune_vars_detached_from_observed, remove_minibatched_nodes
18+
from pymc.model.transform.basic import (
19+
model_to_minibatch,
20+
prune_vars_detached_from_observed,
21+
remove_minibatched_nodes,
22+
)
1923

2024

2125
def test_prune_vars_detached_from_observed():
@@ -34,6 +38,28 @@ def test_prune_vars_detached_from_observed():
3438
assert set(pruned_m.named_vars.keys()) == {"obs_data", "a0", "a1", "a2", "obs"}
3539

3640

41+
def test_model_to_minibatch():
    """Check that model_to_minibatch turns the observed RV into a minibatch RV."""
    from pymc.variational.minibatch_rv import MinibatchRandomVariable

    data_size = 100
    n_features = 4
    batch_size = 10

    # Keep the raw arrays under their own names so the pm.Data variables below
    # do not shadow them.
    obs_values = np.zeros((data_size,))
    X_values = np.random.normal(size=(data_size, n_features))

    with pm.Model(coords={"feature": range(n_features), "data_dim": range(data_size)}) as m1:
        obs_data = pm.Data("obs_data", obs_values, dims=["data_dim"])
        X_data = pm.Data("X_data", X_values, dims=["data_dim", "feature"])
        beta = pm.Normal("beta", dims="feature")

        mu = X_data @ beta

        pm.Normal("y", mu=mu, sigma=1, observed=obs_data, dims="data_dim")

    m2 = model_to_minibatch(m1, batch_size=batch_size)

    # The source model must be left untouched by the transformation.
    assert not isinstance(m1["y"].owner.op, MinibatchRandomVariable)

    # The observed RV of the new model is wrapped so that total_size is applied.
    assert isinstance(m2["y"].owner.op, MinibatchRandomVariable)
    assert [rv.name for rv in m2.observed_RVs] == ["y"]
    assert "beta" in m2.named_vars
61+
62+
3763
def test_remove_minibatches():
3864
data_size = 100
3965
data = np.zeros((data_size,))

0 commit comments

Comments
 (0)