
Commit 58c7897

use a shared strategy in conftest.py
1 parent b634f78 commit 58c7897

9 files changed (+119 lines, -24 lines)

conftest.py

Lines changed: 16 additions & 0 deletions

@@ -0,0 +1,16 @@
+import pytest
+import os
+from keras_rs.src.utils import tpu_test_utils
+
+@pytest.fixture(scope="session", autouse=True)
+def prime_shared_tpu_strategy(request):
+    """
+    Eagerly initializes the shared TPU strategy at the beginning of the session
+    if running on a TPU. This helps catch initialization errors early.
+    """
+    strategy = tpu_test_utils.get_shared_tpu_strategy()
+    if not strategy:
+        pytest.fail(
+            "Failed to initialize shared TPUStrategy for the test session. "
+            "Check logs for details from create_tpu_strategy."
+        )
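
Note: because the fixture is session-scoped and autouse, pytest runs it once before any test executes, so a failed TPU bring-up aborts the run immediately instead of surfacing in whichever test first touches the strategy. Tests then reach the already-cached singleton through the TestCase.strategy property. A minimal sketch of a consuming test; the class name, assertion, and the `from keras_rs.src import testing` import are illustrative assumptions, not part of this commit:

    from keras_rs.src import testing


    class SharedStrategySmokeTest(testing.TestCase):
        def test_strategy_is_primed(self):
            # Off TPU this resolves to a DummyStrategy; on TPU, to the
            # session-wide TPUStrategy created once via conftest.py.
            self.assertGreaterEqual(self.strategy.num_replicas_in_sync, 1)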

keras_rs/src/layers/embedding/distributed_embedding_test.py

Lines changed: 16 additions & 16 deletions

@@ -53,7 +53,7 @@ def setUp(self):
         # FLAGS.xla_sparse_core_max_unique_ids_per_partition_per_sample = 16

         self.batch_size = (
-            BATCH_SIZE_PER_CORE * self._strategy.num_replicas_in_sync
+            BATCH_SIZE_PER_CORE * self.strategy.num_replicas_in_sync
         )

     def get_embedding_config(self, input_type, placement):
@@ -194,11 +194,11 @@ def test_basics(self, input_type, placement):

         if placement == "sparsecore" and not self.on_tpu:
             with self.assertRaisesRegex(Exception, "sparsecore"):
-                with self._strategy.scope():
+                with self.strategy.scope():
                     distributed_embedding.DistributedEmbedding(feature_configs)
             return

-        with self._strategy.scope():
+        with self.strategy.scope():
             layer = distributed_embedding.DistributedEmbedding(feature_configs)

         if keras.backend.backend() == "jax":
@@ -276,7 +276,7 @@ def test_model_fit(self, input_type, use_weights):
             (test_model_inputs, test_labels)
         )

-        with self._strategy.scope():
+        with self.strategy.scope():
             layer = distributed_embedding.DistributedEmbedding(feature_configs)

         def _create_keras_input(
@@ -347,7 +347,7 @@ def test_dataset_generator():
             # New preprocessed data removes the `weights` component.
             dataset_has_weights = False
         else:
-            train_dataset = self._strategy.experimental_distribute_dataset(
+            train_dataset = self.strategy.experimental_distribute_dataset(
                 train_dataset,
                 options=tf.distribute.InputOptions(
                     experimental_fetch_to_device=False
@@ -362,7 +362,7 @@ def test_dataset_generator():
             inputs=keras_model_inputs, outputs=keras_model_outputs
         )

-        with self._strategy.scope():
+        with self.strategy.scope():
             model.compile(optimizer="adam", loss="mse")

         model_inputs, _ = next(iter(test_dataset))
@@ -511,7 +511,7 @@ def test_correctness(
         if not use_weights:
             weights = None

-        with self._strategy.scope():
+        with self.strategy.scope():
             layer = distributed_embedding.DistributedEmbedding(feature_config)

         if keras.backend.backend() == "jax":
@@ -568,7 +568,7 @@ def test_correctness(

         self.assertEqual(res.shape, (self.batch_size, EMBEDDING_OUTPUT_DIM))

-        with self._strategy.scope():
+        with self.strategy.scope():
             tables = layer.get_embedding_tables()

         emb = tables["table"]
@@ -633,11 +633,11 @@ def test_shared_table(self):
             "dense", embedding_config
         )

-        with self._strategy.scope():
+        with self.strategy.scope():
             layer = distributed_embedding.DistributedEmbedding(embedding_config)

         res = tpu_test_utils.run_with_strategy(
-            self._strategy, layer.__call__, inputs
+            self.strategy, layer.__call__, inputs
         )

         if self.placement == "default_device":
@@ -709,11 +709,11 @@ def test_mixed_placement(self):
             "dense", embedding_config
         )

-        with self._strategy.scope():
+        with self.strategy.scope():
             layer = distributed_embedding.DistributedEmbedding(embedding_config)

         res = tpu_test_utils.run_with_strategy(
-            self._strategy, layer.__call__, inputs
+            self.strategy, layer.__call__, inputs
         )

         self.assertEqual(
@@ -740,22 +740,22 @@ def test_save_load_model(self):
         with tempfile.TemporaryDirectory() as temp_dir:
             path = os.path.join(temp_dir, "model.keras")

-            with self._strategy.scope():
+            with self.strategy.scope():
                 layer = distributed_embedding.DistributedEmbedding(
                     feature_configs
                 )
                 keras_outputs = layer(keras_inputs)
                 model = keras.Model(inputs=keras_inputs, outputs=keras_outputs)

             output_before = tpu_test_utils.run_with_strategy(
-                self._strategy, model.__call__, inputs
+                self.strategy, model.__call__, inputs
             )
             model.save(path)

-            with self._strategy.scope():
+            with self.strategy.scope():
                 reloaded_model = keras.models.load_model(path)
             output_after = tpu_test_utils.run_with_strategy(
-                self._strategy, reloaded_model.__call__, inputs
+                self.strategy, reloaded_model.__call__, inputs
             )

             if self.placement == "sparsecore":
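
Note: every change in this file is the same mechanical rename from the private self._strategy attribute to the public self.strategy property, which now resolves to the shared session-wide strategy. The tf.distribute calls themselves are unchanged; for reference, this is the standard pattern the tests exercise, shown here as a sketch with the default (no-op) strategy so it runs on CPU:

    import tensorflow as tf

    strategy = tf.distribute.get_strategy()  # default strategy off-TPU

    # Variables created under scope() are placed according to the strategy.
    with strategy.scope():
        model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
        model.compile(optimizer="adam", loss="mse")

    # Distributing a dataset shards each batch across replicas.
    dataset = tf.data.Dataset.from_tensor_slices(
        (tf.random.normal((8, 4)), tf.random.normal((8, 1)))
    ).batch(4)
    dist_dataset = strategy.experimental_distribute_dataset(dataset)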

keras_rs/src/losses/list_mle_loss_test.py

Lines changed: 1 addition & 0 deletions

@@ -10,6 +10,7 @@

 class ListMLELossTest(testing.TestCase, parameterized.TestCase):
     def setUp(self):
+        super().setUp()
         self.unbatched_scores = ops.array(
             [1.0, 3.0, 2.0, 4.0, 0.8], dtype="float32"
         )

keras_rs/src/losses/pairwise_hinge_loss_test.py

Lines changed: 1 addition & 0 deletions

@@ -10,6 +10,7 @@

 class PairwiseHingeLossTest(testing.TestCase, parameterized.TestCase):
     def setUp(self):
+        super().setUp()
         self.unbatched_scores = ops.array([1.0, 3.0, 2.0, 4.0, 0.8])
         self.unbatched_labels = ops.array([1.0, 0.0, 1.0, 3.0, 2.0])

keras_rs/src/losses/pairwise_logistic_loss_test.py

Lines changed: 1 addition & 0 deletions

@@ -10,6 +10,7 @@

 class PairwiseLogisticLossTest(testing.TestCase, parameterized.TestCase):
     def setUp(self):
+        super().setUp()
         self.unbatched_scores = ops.array([1.0, 3.0, 2.0, 4.0, 0.8])
         self.unbatched_labels = ops.array([1.0, 0.0, 1.0, 3.0, 2.0])

keras_rs/src/losses/pairwise_mean_squared_error_test.py

Lines changed: 1 addition & 0 deletions

@@ -12,6 +12,7 @@

 class PairwiseMeanSquaredErrorTest(testing.TestCase, parameterized.TestCase):
     def setUp(self):
+        super().setUp()
         self.unbatched_scores = ops.array([1.0, 3.0, 2.0, 4.0, 0.8])
         self.unbatched_labels = ops.array([1.0, 0.0, 1.0, 3.0, 2.0])

keras_rs/src/losses/pairwise_soft_zero_one_loss_test.py

Lines changed: 1 addition & 0 deletions

@@ -12,6 +12,7 @@

 class PairwiseSoftZeroOneLossTest(testing.TestCase, parameterized.TestCase):
     def setUp(self):
+        super().setUp()
         self.unbatched_scores = ops.array([1.0, 3.0, 2.0, 4.0, 0.8])
         self.unbatched_labels = ops.array([1.0, 0.0, 1.0, 3.0, 2.0])
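
Note: all five loss-test fixes are the same one-liner. A setUp override that does not chain to super() silently skips the base TestCase.setUp, so attributes like on_tpu (and now _strategy) were never initialized on these classes. A minimal illustration of the failure mode; the class names here are illustrative:

    import unittest


    class Base(unittest.TestCase):
        def setUp(self):
            self.on_tpu = False  # state every test is expected to have


    class Broken(Base):
        def setUp(self):  # forgets super().setUp()
            self.scores = [1.0, 3.0, 2.0]
            # self.on_tpu was never set -> AttributeError when read


    class Fixed(Base):
        def setUp(self):
            super().setUp()  # chain first, then add per-test fixtures
            self.scores = [1.0, 3.0, 2.0]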

keras_rs/src/testing/test_case.py

Lines changed: 24 additions & 8 deletions

@@ -1,7 +1,7 @@
 import os
 import tempfile
 import unittest
-from typing import Any
+from typing import Any, Optional, Union

 import keras
 import numpy as np
@@ -10,6 +10,12 @@
 from keras_rs.src import types
 from keras_rs.src.utils import tpu_test_utils

+StrategyType = Union[
+    tf.distribute.Strategy,
+    tpu_test_utils.DummyStrategy,
+    tpu_test_utils.JaxDummyStrategy,
+]
+

 class TestCase(unittest.TestCase):
     """TestCase class for all Keras Recommenders tests."""
@@ -21,22 +27,32 @@ def setUp(self) -> None:
         if keras.backend.backend() == "tensorflow":
             tf.debugging.disable_traceback_filtering()
         self.on_tpu = "TPU_NAME" in os.environ
+        self._strategy: Optional[StrategyType] = None

     @property
-    def strategy(self):
-        if hasattr(self, "_strategy"):
-            return self._strategy
-        self._strategy = tpu_test_utils.get_tpu_strategy(self)
-        return self._strategy
+    def strategy(self) -> StrategyType:
+        strat = tpu_test_utils.get_shared_tpu_strategy()
+        if strat is None:
+            # This case should ideally be caught by the conftest.py fixture.
+            self.fail(
+                "TPU environment detected, but the shared TPUStrategy is "
+                "None. Initialization likely failed."
+            )
+        return strat
+        # if self._strategy is not None:
+        #     return self._strategy
+        # self._strategy = tpu_test_utils.get_tpu_strategy(self)
+        # return self._strategy

     def assertAllClose(
         self,
         actual: types.Tensor,
         desired: types.Tensor,
         atol: float = 1e-6,
         rtol: float = 1e-6,
-        tpu_atol=None,
-        tpu_rtol=None,
+        tpu_atol: Optional[float] = None,
+        tpu_rtol: Optional[float] = None,
         msg: str = "",
     ) -> None:
         """Verify that two tensors are close in value element by element.

keras_rs/src/utils/tpu_test_utils.py

Lines changed: 58 additions & 0 deletions

@@ -1,5 +1,6 @@
 import contextlib
 import os
+import threading
 from types import ModuleType
 from typing import Any, Callable, ContextManager, Optional, Tuple, Union

@@ -42,6 +43,63 @@ def num_replicas_in_sync(self) -> Any:

 StrategyType = Union[tf.distribute.Strategy, DummyStrategy, JaxDummyStrategy]

+_shared_strategy: Optional[StrategyType] = None
+_lock = threading.Lock()
+
+
+def create_tpu_strategy() -> Optional[StrategyType]:
+    """Initializes the TPU system and returns a TPUStrategy."""
+    print("Attempting to create TPUStrategy...")
+    try:
+        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")
+        tf.config.experimental_connect_to_cluster(resolver)
+        tf.tpu.experimental.initialize_tpu_system(resolver)
+        strategy = tf.distribute.TPUStrategy(resolver)
+        print(
+            "TPUStrategy created successfully. "
+            f"Replicas: {strategy.num_replicas_in_sync}"
+        )
+        return strategy
+    except Exception as e:
+        print(f"Error creating TPUStrategy: {e}")
+        return None
+
+
+def get_shared_tpu_strategy() -> Optional[StrategyType]:
+    """Returns a session-wide shared TPUStrategy instance.
+
+    Creates the instance on the first call. Returns None if not in a TPU
+    environment or if creation fails.
+    """
+    global _shared_strategy
+    if _shared_strategy is not None:
+        return _shared_strategy
+
+    with _lock:
+        if _shared_strategy is None:
+            if "TPU_NAME" not in os.environ:
+                _shared_strategy = DummyStrategy()
+                return _shared_strategy
+            if keras.backend.backend() == "tensorflow":
+                resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
+                tf.config.experimental_connect_to_cluster(resolver)
+                topology = tf.tpu.experimental.initialize_tpu_system(resolver)
+                tpu_metadata = resolver.get_tpu_system_metadata()
+                device_assignment = tf.tpu.experimental.DeviceAssignment.build(
+                    topology, num_replicas=tpu_metadata.num_hosts
+                )
+                _shared_strategy = tf.distribute.TPUStrategy(
+                    resolver, experimental_device_assignment=device_assignment
+                )
+                print("### num_replicas", _shared_strategy.num_replicas_in_sync)
+            elif keras.backend.backend() == "jax":
+                if jax is None:
+                    raise ImportError(
+                        "JAX backend requires jax to be installed for TPU."
+                    )
+                print("### num_replicas", jax.device_count("tpu"))
+                _shared_strategy = JaxDummyStrategy()
+            else:
+                _shared_strategy = DummyStrategy()
+    if _shared_strategy is None:
+        print("Failed to create the shared TPUStrategy.")
+    return _shared_strategy
+

 def get_tpu_strategy(test_case: Any) -> StrategyType:
     """Get TPU strategy if on TPU, otherwise return DummyStrategy."""
