Skip to content

Commit 376a954

Browse files
committed
clean up
1 parent 0b81a27 commit 376a954

File tree

1 file changed

+2
-5
lines changed

1 file changed

+2
-5
lines changed

keras_rs/src/layers/embedding/jax/distributed_embedding_test.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ def my_initializer(shape: tuple[int, int], dtype: Any):
309309
keras.backend.backend() != "jax",
310310
reason="Backend specific test",
311311
)
312-
class DistributedEmbeddingLayerTest(testing.TestCase, parameterized.TestCase):
312+
class DistributedEmbeddingLayerTest(parameterized.TestCase):
313313
@parameterized.product(
314314
ragged=[True, False],
315315
combiner=["sum", "mean", "sqrtn"],
@@ -327,7 +327,6 @@ def test_call(
327327
table_stacking: str | list[str] | list[list[str]],
328328
jit: bool,
329329
):
330-
self.on_tpu = "TPU_NAME" in os.environ
331330
table_configs = keras_test_utils.create_random_table_configs(
332331
combiner=combiner, seed=10
333332
)
@@ -376,9 +375,7 @@ def test_call(
376375
)
377376

378377
keras.tree.map_structure(
379-
lambda a, b: self.assertAllClose(
380-
a, b, atol=1e-3, is_tpu=self.on_tpu
381-
),
378+
lambda a, b: np.testing.assert_allclose(a, b, atol=1e-5),
382379
outputs,
383380
expected_outputs,
384381
)

0 commit comments

Comments
 (0)