Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/ISSUE_TEMPLATE/Issue.yml
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
name: General Issue
description: Your issue does not fit the categories already mentioned.
labels:
labels:
- help wanted
body:
- type: markdown
attributes:
value: >
If you want to discuss more about the issue, reach us at our [Slack Channel](https://stingraysoftware.slack.com/ssb/redirecthttps://join.slack.com/t/stingraysoftware/shared_invite/zt-49kv4kba-mD1Y~s~rlrOOmvqM7mZugQ)
If you want to discuss more about the issue, reach us at our [Slack Channel](https://join.slack.com/t/stingraysoftware/shared_invite/zt-49kv4kba-mD1Y~s~rlrOOmvqM7mZugQ)
- type: textarea
attributes:
label: Description of the Issue
Expand Down
1 change: 1 addition & 0 deletions docs/changes/976.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
The tests `test_get_mean_gaussian` and `test_get_mean_skew_gaussian` used exact floating-point equality (`==`) to compare results, which broke in JAX 0.10.x due to a 1-ULP change in how XLA compiles broadcast array division; replaced with `np.testing.assert_allclose`.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By "JAX 0.10.x", do you mean all possible versions starting from the 0.10.0? I see that JAX had a newer release on May, 20. I am wondering whether this bug still affects the latest version or only older ones.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, all new versions since 0.10.0 have this problem, so this test will never work with the old syntax

20 changes: 10 additions & 10 deletions stingray/modeling/tests/test_gpmodeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,18 +117,18 @@ def test_get_mean_gaussian(self):
result_gaussian = 3 * jnp.exp(-((self.t - 0.2) ** 2) / (2 * (0.2**2))) + 4 * jnp.exp(
-((self.t - 0.7) ** 2) / (2 * (0.1**2))
)
assert (get_mean("gaussian", self.mean_params)(self.t) == result_gaussian).all()
assert np.allclose(get_mean("gaussian", self.mean_params)(self.t), result_gaussian)

def test_get_mean_exponential(self):
result_exponential = 3 * jnp.exp(-jnp.abs(self.t - 0.2) / (2 * (0.2**2))) + 4 * jnp.exp(
-jnp.abs(self.t - 0.7) / (2 * (0.1**2))
)
assert (get_mean("exponential", self.mean_params)(self.t) == result_exponential).all()
assert np.allclose(get_mean("exponential", self.mean_params)(self.t), result_exponential)

def test_get_mean_constant(self):
result_constant = 3 * jnp.ones_like(self.t)
const_param_dict = {"A": jnp.array([3.0])}
assert (get_mean("constant", const_param_dict)(self.t) == result_constant).all()
assert np.allclose(get_mean("constant", const_param_dict)(self.t), result_constant)

def test_get_mean_skew_gaussian(self):
result_skew_gaussian = 3.0 * jnp.where(
Expand All @@ -140,9 +140,9 @@ def test_get_mean_skew_gaussian(self):
jnp.exp(-((self.t - 0.7) ** 2) / (2 * (0.4**2))),
jnp.exp(-((self.t - 0.7) ** 2) / (2 * (0.1**2))),
)
assert (
get_mean("skew_gaussian", self.skew_mean_params)(self.t) == result_skew_gaussian
).all()
assert np.allclose(
get_mean("skew_gaussian", self.skew_mean_params)(self.t), result_skew_gaussian
)

def test_get_mean_skew_exponential(self):
result_skew_exponential = 3.0 * jnp.where(
Expand All @@ -154,15 +154,15 @@ def test_get_mean_skew_exponential(self):
jnp.exp(-jnp.abs(self.t - 0.7) / (2 * (0.4**2))),
jnp.exp(-jnp.abs(self.t - 0.7) / (2 * (0.1**2))),
)
assert (
get_mean("skew_exponential", self.skew_mean_params)(self.t) == result_skew_exponential
).all()
assert np.allclose(
get_mean("skew_exponential", self.skew_mean_params)(self.t), result_skew_exponential
)

def test_get_mean_fred(self):
result_fred = 3.0 * jnp.exp(-4.0 * ((self.t + 0.3) / 0.2 + 0.2 / (self.t + 0.3))) * jnp.exp(
2 * 4.0
) + 4.0 * jnp.exp(-5.0 * ((self.t + 0.4) / 0.7 + 0.7 / (self.t + 0.4))) * jnp.exp(2 * 5.0)
assert (get_mean("fred", self.fred_mean_params)(self.t) == result_fred).all()
assert np.allclose(get_mean("fred", self.fred_mean_params)(self.t), result_fred)

def test_value_error(self):
with pytest.raises(ValueError, match="Mean type not implemented"):
Expand Down
Loading