From 87a1c9b72ffba18cd47ad27394a408a66570c1c6 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Thu, 20 Mar 2025 14:33:50 -0400 Subject: [PATCH] Update tests for new cmdstan defaults --- test/test_log_prob.py | 2 +- test/test_pathfinder.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/test/test_log_prob.py b/test/test_log_prob.py index bed2f57b..046dd3af 100644 --- a/test/test_log_prob.py +++ b/test/test_log_prob.py @@ -30,7 +30,7 @@ ["-5.5395901199", "-1.4903938392"], ), (3, ["-7.02", "-1.19"], ["-5.54", "-1.49"]), - (None, ["-7.02147", "-1.18847"], ["-5.53959", "-1.49039"]), + (None, ["-7.0214668", "-1.1884726"], ["-5.5395901", "-1.4903938"]), ], ) def test_lp_good( diff --git a/test/test_pathfinder.py b/test/test_pathfinder.py index dfbab0b6..c4d8bb43 100644 --- a/test/test_pathfinder.py +++ b/test/test_pathfinder.py @@ -36,7 +36,7 @@ def test_pathfinder_outputs(): assert pathfinder.is_resampled - assert pathfinder.draws().shape == (draws, 3) + assert pathfinder.draws().shape == (draws, 4) def test_pathfinder_from_csv(): @@ -159,7 +159,7 @@ def test_pathfinder_no_psis(): pathfinder = bern_model.pathfinder(data=jdata, psis_resample=False) assert not pathfinder.is_resampled - assert pathfinder.draws().shape == (4000, 3) + assert pathfinder.draws().shape == (4000, 4) def test_pathfinder_no_lp_calc(): @@ -170,7 +170,7 @@ def test_pathfinder_no_lp_calc(): pathfinder = bern_model.pathfinder(data=jdata, calculate_lp=False) assert not pathfinder.is_resampled - assert pathfinder.draws().shape == (4000, 3) + assert pathfinder.draws().shape == (4000, 4) n_lp_nan = np.sum(np.isnan(pathfinder.method_variables()['lp__'])) assert n_lp_nan < 4000 # some lp still calculated during pathfinder assert n_lp_nan > 3000 # but most are not @@ -190,4 +190,4 @@ def test_pathfinder_threads(): stan_file=stan, cpp_options={'STAN_THREADS': True}, force_compile=True ) pathfinder = bern_model.pathfinder(data=jdata, num_threads=4) - assert pathfinder.draws().shape == (1000, 3) + assert pathfinder.draws().shape == (1000, 4)