Skip to content

Commit 77ebf98

Browse files
lukebaumanncopybara-github
authored andcommitted
Fix Pathwaysutils tests to run with the JAX cpu platform.
PiperOrigin-RevId: 726291291
1 parent 08cacfe commit 77ebf98

File tree

2 files changed

+9
-1
lines changed

2 files changed

+9
-1
lines changed

pathwaysutils/test/persistence_test.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,11 @@ class PersistenceTest(absltest.TestCase):
1414
name = "name"
1515
dtype = np.dtype(np.int32)
1616
shape = [8, 4]
17-
timeout = datetime.timedelta(seconds=30)
17+
timeout = datetime.timedelta(seconds=3)
18+
19+
def setUp(self):
20+
jax.config.update("jax_platforms", "cpu")
21+
super().setUp()
1822

1923
def test_get_read_request(self):
2024
devices = jax.devices()

pathwaysutils/test/plugin_executable_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@
1313

1414
class PluginExecutableTest(absltest.TestCase):
1515

16+
def setUp(self):
17+
jax.config.update("jax_platforms", "cpu")
18+
super().setUp()
19+
1620
def test_bad_json_program(self):
1721
with self.assertRaisesRegex(XlaRuntimeError, "INVALID_ARGUMENT"):
1822
PluginExecutable('{"printTextRequest":{"badParamName":"foo"}}')

0 commit comments

Comments
 (0)