Skip to content

Commit e7cda94

Browse files
authored
Fix for jax import in distributed_embedding_test.py. (#162)
Passing works better in some environments.
1 parent 47ab13d commit e7cda94

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

keras_rs/src/layers/embedding/distributed_embedding_test.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,7 @@
1919
import jax
2020
import jax.experimental.sparse as jax_sparse
2121
except ImportError:
22-
jax = None
23-
jax_sparse = None
22+
pass
2423

2524

2625
FEATURE1_EMBEDDING_OUTPUT_DIM = 7

0 commit comments

Comments
 (0)