@@ -256,3 +256,70 @@ def test_categorical_model(separate_trees, split_rule):
256
256
# Fit should be good enough so right category is selected over 50% of time
257
257
assert (idata .predictions .y .median (["chain" , "draw" ]) == Y ).all ()
258
258
assert pmb .compute_variable_importance (idata , bartrv = lo , X = X )["preds" ].shape == (5 , 50 , 9 , 3 )
259
+
260
+
261
+ def test_multiple_bart_variables ():
262
+ """Test that multiple BART variables can coexist in a single model."""
263
+ X1 = np .random .normal (0 , 1 , size = (50 , 2 ))
264
+ X2 = np .random .normal (0 , 1 , size = (50 , 3 ))
265
+ Y = np .random .normal (0 , 1 , size = 50 )
266
+
267
+ # Create correlated responses
268
+ Y1 = X1 [:, 0 ] + np .random .normal (0 , 0.1 , size = 50 )
269
+ Y2 = X2 [:, 0 ] + X2 [:, 1 ] + np .random .normal (0 , 0.1 , size = 50 )
270
+
271
+ with pm .Model () as model :
272
+ # Two separate BART variables with different covariates
273
+ mu1 = pmb .BART ("mu1" , X1 , Y1 , m = 5 )
274
+ mu2 = pmb .BART ("mu2" , X2 , Y2 , m = 5 )
275
+
276
+ # Combined model
277
+ sigma = pm .HalfNormal ("sigma" , 1 )
278
+ y = pm .Normal ("y" , mu1 + mu2 , sigma , observed = Y )
279
+
280
+ # Sample with automatic assignment of BART samplers
281
+ idata = pm .sample (tune = 50 , draws = 50 , chains = 1 , random_seed = 3415 )
282
+
283
+ # Verify both BART variables have their own tree collections
284
+ assert hasattr (mu1 .owner .op , "all_trees" )
285
+ assert hasattr (mu2 .owner .op , "all_trees" )
286
+
287
+ # Verify trees are stored separately (different object references)
288
+ assert mu1 .owner .op .all_trees is not mu2 .owner .op .all_trees
289
+
290
+ # Verify sampling worked
291
+ assert idata .posterior ["mu1" ].shape == (1 , 50 , 50 )
292
+ assert idata .posterior ["mu2" ].shape == (1 , 50 , 50 )
293
+
294
+
295
+ def test_multiple_bart_variables_manual_step ():
296
+ """Test that multiple BART variables work with manually assigned PGBART samplers."""
297
+ X1 = np .random .normal (0 , 1 , size = (30 , 2 ))
298
+ X2 = np .random .normal (0 , 1 , size = (30 , 2 ))
299
+ Y = np .random .normal (0 , 1 , size = 30 )
300
+
301
+ # Create simple responses
302
+ Y1 = X1 [:, 0 ] + np .random .normal (0 , 0.1 , size = 30 )
303
+ Y2 = X2 [:, 1 ] + np .random .normal (0 , 0.1 , size = 30 )
304
+
305
+ with pm .Model () as model :
306
+ # Two separate BART variables
307
+ mu1 = pmb .BART ("mu1" , X1 , Y1 , m = 3 )
308
+ mu2 = pmb .BART ("mu2" , X2 , Y2 , m = 3 )
309
+
310
+ # Non-BART variable
311
+ sigma = pm .HalfNormal ("sigma" , 1 )
312
+ y = pm .Normal ("y" , mu1 + mu2 , sigma , observed = Y )
313
+
314
+ # Manually create PGBART samplers for each BART variable
315
+ step1 = pmb .PGBART ([mu1 ], num_particles = 5 )
316
+ step2 = pmb .PGBART ([mu2 ], num_particles = 5 )
317
+
318
+ # Sample with manual step assignment
319
+ idata = pm .sample (tune = 20 , draws = 20 , chains = 1 , step = [step1 , step2 ], random_seed = 3415 )
320
+
321
+ # Verify both variables were sampled
322
+ assert "mu1" in idata .posterior
323
+ assert "mu2" in idata .posterior
324
+ assert idata .posterior ["mu1" ].shape == (1 , 20 , 30 )
325
+ assert idata .posterior ["mu2" ].shape == (1 , 20 , 30 )
0 commit comments