Skip to content

Commit 47ab13d

Browse files
authored
Add TPU tests for JAX and TensorFlow. (#160)
This uses the self-hosted TPU runners. Only the multi-backend `keras_rs/src/layers/embedding/distributed_embedding_test.py` test is run for now. Many other tests have failures; they will be addressed in subsequent PRs. - Modified `distributed_embedding_test.py` to replace the `TPU` flag with a `TPU_NAME` environment variable, as plumbing a flag through `pytest` is needlessly complicated. - Modified `distributed_embedding_test.py` to not require JAX to be installed when running against the TensorFlow backend.
1 parent eb4a13f commit 47ab13d

File tree

4 files changed

+89
-20
lines changed

4 files changed

+89
-20
lines changed

.github/workflows/actions.yml

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,53 @@ jobs:
4545
- name: Test with pytest
4646
run: |
4747
pytest keras_rs/
48+
49+
run_tests_in_container:
50+
name: Test the code on TPU
51+
runs-on: linux-x86-ct6e-44-1tpu
52+
53+
strategy:
54+
fail-fast: false
55+
matrix:
56+
backend: [tensorflow, jax]
57+
58+
container:
59+
image: python:3.11-slim
60+
options: --privileged --network host
61+
62+
steps:
63+
- name: Checkout repository
64+
uses: actions/checkout@v4
65+
66+
- name: Install Dependencies
67+
run: |
68+
pip install --no-cache-dir -U pip && \
69+
pip install --no-cache-dir -r requirements-${{ matrix.backend }}-tpu.txt
70+
71+
- name: Set Keras Backend
72+
run: |
73+
echo "KERAS_BACKEND=${{ matrix.backend }}" >> $GITHUB_ENV
74+
echo "TPU_NAME=local" >> $GITHUB_ENV
75+
76+
- name: Set TF Specific Environment Variables
77+
if: ${{ matrix.backend == 'tensorflow'}}
78+
run: |
79+
echo "PJRT_DEVICE=TPU" >> $GITHUB_ENV
80+
echo "NEXT_PLUGGABLE_DEVICE_USE_C_API=true" >> $GITHUB_ENV
81+
echo "TF_XLA_FLAGS=--tf_mlir_enable_mlir_bridge=true" >> $GITHUB_ENV
82+
pip show libtpu | grep "^Location: " | sed "s/^Location: \(.*\)$/TF_PLUGGABLE_DEVICE_LIBRARY_PATH=\1\/libtpu\/libtpu.so/1" >> $GITHUB_ENV
83+
84+
- name: Verify TF Installation
85+
if: ${{ matrix.backend == 'tensorflow'}}
86+
run: python3 -c "import tensorflow as tf; print('Tensorflow devices:', tf.config.list_logical_devices())"
87+
88+
- name: Verify JAX Installation
89+
if: ${{ matrix.backend == 'jax'}}
90+
run: python3 -c "import jax; print('JAX devices:', jax.devices())"
91+
92+
- name: Test with pytest
93+
run: pytest keras_rs/src/layers/embedding/distributed_embedding_test.py
94+
4895
check_format:
4996
name: Check the code format
5097
runs-on: ubuntu-latest

keras_rs/src/layers/embedding/distributed_embedding_test.py

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,9 @@
44
import os
55
import tempfile
66

7-
import jax
8-
import jax.experimental.sparse as jax_sparse
9-
import jax.numpy as jnp
107
import keras
118
import numpy as np
129
import tensorflow as tf
13-
from absl import flags
1410
from absl.testing import absltest
1511
from absl.testing import parameterized
1612

@@ -19,8 +15,12 @@
1915
from keras_rs.src.layers.embedding import distributed_embedding
2016
from keras_rs.src.layers.embedding import distributed_embedding_config as config
2117

22-
FLAGS = flags.FLAGS
23-
_TPU = flags.DEFINE_string("tpu", None, "The TPU to use for TPUStrategy.")
18+
try:
19+
import jax
20+
import jax.experimental.sparse as jax_sparse
21+
except ImportError:
22+
jax = None
23+
jax_sparse = None
2424

2525

2626
FEATURE1_EMBEDDING_OUTPUT_DIM = 7
@@ -50,29 +50,32 @@ def experimental_distribute_dataset(self, dataset, options=None):
5050
class JaxDummyStrategy(DummyStrategy):
5151
@property
5252
def num_replicas_in_sync(self):
53-
return len(jax.devices("tpu"))
53+
return jax.device_count("tpu")
54+
55+
56+
def ragged_bool_true(self):
57+
return True
5458

5559

5660
class DistributedEmbeddingTest(testing.TestCase, parameterized.TestCase):
5761
def setUp(self):
5862
super().setUp()
59-
try:
60-
self.on_tpu = _TPU.value is not None
61-
except flags.UnparsedFlagAccessError:
62-
self.on_tpu = False
63-
63+
self.on_tpu = "TPU_NAME" in os.environ
6464
self.placement = "sparsecore" if self.on_tpu else "default_device"
6565

6666
if keras.backend.backend() == "tensorflow":
6767
tf.debugging.disable_traceback_filtering()
6868

6969
if keras.backend.backend() == "tensorflow" and self.on_tpu:
70+
# Workaround for a bug preventing weights from being ragged tensors.
71+
# The fix in TensorFlow was added after 2.19.1:
72+
# https://github.com/tensorflow/tensorflow/commit/185f2f58bafc6410125080264d5d7730e1fa1eb2
73+
tf.RaggedTensor.__bool__ = ragged_bool_true
74+
7075
# FLAGS.xla_sparse_core_max_ids_per_partition_per_sample = 16
7176
# FLAGS.xla_sparse_core_max_unique_ids_per_partition_per_sample = 16
7277

73-
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
74-
tpu=_TPU.value
75-
)
78+
resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
7679
tf.config.experimental_connect_to_cluster(resolver)
7780

7881
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
@@ -187,7 +190,10 @@ def create_tensor(feature_config, op):
187190
sequence_length = feature_config.input_shape[-1]
188191
indices = [[i, i % sequence_length] for i in range(batch_size)]
189192
return jax_sparse.BCOO(
190-
(jnp.asarray(op((batch_size,))), jnp.asarray(indices)),
193+
(
194+
jax.numpy.asarray(op((batch_size,))),
195+
jax.numpy.asarray(indices),
196+
),
191197
shape=(batch_size, sequence_length),
192198
unique_indices=True,
193199
)
@@ -534,16 +540,18 @@ def test_correctness(
534540
elif keras.backend.backend() == "jax":
535541
inputs = jax_sparse.BCOO(
536542
(
537-
jnp.asarray([1, 2, 3, 4, 5] * num_repeats),
538-
jnp.asarray(indices),
543+
jax.numpy.asarray([1, 2, 3, 4, 5] * num_repeats),
544+
jax.numpy.asarray(indices),
539545
),
540546
shape=(self.batch_size, 4),
541547
unique_indices=True,
542548
)
543549
weights = jax_sparse.BCOO(
544550
(
545-
jnp.asarray([1.0, 1.0, 2.0, 3.0, 4.0] * num_repeats),
546-
jnp.asarray(indices),
551+
jax.numpy.asarray(
552+
[1.0, 1.0, 2.0, 3.0, 4.0] * num_repeats
553+
),
554+
jax.numpy.asarray(indices),
547555
),
548556
shape=(self.batch_size, 4),
549557
unique_indices=True,

requirements-jax-tpu.txt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Tensorflow cpu-only version.
2+
tensorflow-cpu>=2.20.0
3+
4+
# Jax with TPU support.
5+
jax[tpu]
6+
7+
# Support for TPU embeddings.
8+
jax-tpu-embedding
9+
10+
-r requirements-common.txt

requirements-tensorflow-tpu.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# Tensorflow with TPU support.
2+
tensorflow-tpu==2.19.1
3+
4+
-r requirements-common.txt

0 commit comments

Comments
 (0)