Skip to content

Commit 91e82e7

Browse files
committed
fix test errors
1 parent e557805 commit 91e82e7

File tree

2 files changed

+13
-9
lines changed

2 files changed

+13
-9
lines changed

keras_rs/src/layers/embedding/jax/distributed_embedding_test.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
keras.config.disable_traceback_filtering()
3030

31+
from keras_rs.src import testing
3132

3233
def _create_sparsecore_layout(
3334
sharding_axis: str = "sparsecore",
@@ -308,7 +309,7 @@ def my_initializer(shape: tuple[int, int], dtype: Any):
308309
keras.backend.backend() != "jax",
309310
reason="Backend specific test",
310311
)
311-
class DistributedEmbeddingLayerTest(parameterized.TestCase):
312+
class DistributedEmbeddingLayerTest(testing.TestCase, parameterized.TestCase):
312313
@parameterized.product(
313314
ragged=[True, False],
314315
combiner=["sum", "mean", "sqrtn"],
@@ -326,6 +327,7 @@ def test_call(
326327
table_stacking: str | list[str] | list[list[str]],
327328
jit: bool,
328329
):
330+
self.on_tpu = "TPU_NAME" in os.environ
329331
table_configs = keras_test_utils.create_random_table_configs(
330332
combiner=combiner, seed=10
331333
)
@@ -374,7 +376,7 @@ def test_call(
374376
)
375377

376378
keras.tree.map_structure(
377-
lambda a, b: np.testing.assert_allclose(a, b, atol=1e-5),
379+
lambda a, b: self.assertAllClose(a, b, atol=1e-3, is_tpu=self.on_tpu),
378380
outputs,
379381
expected_outputs,
380382
)

keras_rs/src/utils/tpu_test_utils.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import contextlib
22
import os
33
from types import ModuleType
4-
from typing import Any, Callable, Optional, Tuple, Union
4+
from typing import Any, Callable, ContextManager, Optional, Tuple, Union
55

66
import keras
77
import tensorflow as tf
@@ -15,24 +15,26 @@
1515

1616

1717
class DummyStrategy:
18-
def scope(self):
18+
def scope(self) -> ContextManager[None]:
1919
return contextlib.nullcontext()
2020

2121
@property
22-
def num_replicas_in_sync(self):
22+
def num_replicas_in_sync(self) -> int:
2323
return 1
2424

25-
def run(self, fn, args):
25+
def run(self, fn: Callable[..., Any], args: Tuple[Any, ...]) -> Any:
2626
return fn(*args)
2727

28-
def experimental_distribute_dataset(self, dataset, options=None):
28+
def experimental_distribute_dataset(
29+
self, dataset: Any, options: Optional[Any] = None
30+
) -> Any:
2931
del options
3032
return dataset
3133

3234

3335
class JaxDummyStrategy(DummyStrategy):
3436
@property
35-
def num_replicas_in_sync(self):
37+
def num_replicas_in_sync(self) -> int:
3638
if jax is None:
3739
return 0
3840
return jax.device_count("tpu")
@@ -87,7 +89,7 @@ def run_with_strategy(
8789
sample_weight_value = kwargs.get("sample_weight", None)
8890
all_inputs = args + (sample_weight_value,)
8991

90-
@tf.function(jit_compile=jit_compile)
92+
@tf.function(jit_compile=jit_compile) # type: ignore[misc]
9193
def tf_function_wrapper(input_tuple: Tuple[Any, ...]) -> Any:
9294
num_original_args = len(args)
9395
core_args = input_tuple[:num_original_args]

0 commit comments

Comments
 (0)