@@ -53,7 +53,7 @@ def setUp(self):
5353 # FLAGS.xla_sparse_core_max_unique_ids_per_partition_per_sample = 16
5454
5555 self .batch_size = (
56- BATCH_SIZE_PER_CORE * self ._strategy .num_replicas_in_sync
56+ BATCH_SIZE_PER_CORE * self .strategy .num_replicas_in_sync
5757 )
5858
5959 def get_embedding_config (self , input_type , placement ):
@@ -194,11 +194,11 @@ def test_basics(self, input_type, placement):
194194
195195 if placement == "sparsecore" and not self .on_tpu :
196196 with self .assertRaisesRegex (Exception , "sparsecore" ):
197- with self ._strategy .scope ():
197+ with self .strategy .scope ():
198198 distributed_embedding .DistributedEmbedding (feature_configs )
199199 return
200200
201- with self ._strategy .scope ():
201+ with self .strategy .scope ():
202202 layer = distributed_embedding .DistributedEmbedding (feature_configs )
203203
204204 if keras .backend .backend () == "jax" :
@@ -276,7 +276,7 @@ def test_model_fit(self, input_type, use_weights):
276276 (test_model_inputs , test_labels )
277277 )
278278
279- with self ._strategy .scope ():
279+ with self .strategy .scope ():
280280 layer = distributed_embedding .DistributedEmbedding (feature_configs )
281281
282282 def _create_keras_input (
@@ -347,7 +347,7 @@ def test_dataset_generator():
347347 # New preprocessed data removes the `weights` component.
348348 dataset_has_weights = False
349349 else :
350- train_dataset = self ._strategy .experimental_distribute_dataset (
350+ train_dataset = self .strategy .experimental_distribute_dataset (
351351 train_dataset ,
352352 options = tf .distribute .InputOptions (
353353 experimental_fetch_to_device = False
@@ -362,7 +362,7 @@ def test_dataset_generator():
362362 inputs = keras_model_inputs , outputs = keras_model_outputs
363363 )
364364
365- with self ._strategy .scope ():
365+ with self .strategy .scope ():
366366 model .compile (optimizer = "adam" , loss = "mse" )
367367
368368 model_inputs , _ = next (iter (test_dataset ))
@@ -511,7 +511,7 @@ def test_correctness(
511511 if not use_weights :
512512 weights = None
513513
514- with self ._strategy .scope ():
514+ with self .strategy .scope ():
515515 layer = distributed_embedding .DistributedEmbedding (feature_config )
516516
517517 if keras .backend .backend () == "jax" :
@@ -568,7 +568,7 @@ def test_correctness(
568568
569569 self .assertEqual (res .shape , (self .batch_size , EMBEDDING_OUTPUT_DIM ))
570570
571- with self ._strategy .scope ():
571+ with self .strategy .scope ():
572572 tables = layer .get_embedding_tables ()
573573
574574 emb = tables ["table" ]
@@ -633,11 +633,11 @@ def test_shared_table(self):
633633 "dense" , embedding_config
634634 )
635635
636- with self ._strategy .scope ():
636+ with self .strategy .scope ():
637637 layer = distributed_embedding .DistributedEmbedding (embedding_config )
638638
639639 res = tpu_test_utils .run_with_strategy (
640- self ._strategy , layer .__call__ , inputs
640+ self .strategy , layer .__call__ , inputs
641641 )
642642
643643 if self .placement == "default_device" :
@@ -709,11 +709,11 @@ def test_mixed_placement(self):
709709 "dense" , embedding_config
710710 )
711711
712- with self ._strategy .scope ():
712+ with self .strategy .scope ():
713713 layer = distributed_embedding .DistributedEmbedding (embedding_config )
714714
715715 res = tpu_test_utils .run_with_strategy (
716- self ._strategy , layer .__call__ , inputs
716+ self .strategy , layer .__call__ , inputs
717717 )
718718
719719 self .assertEqual (
@@ -740,22 +740,22 @@ def test_save_load_model(self):
740740 with tempfile .TemporaryDirectory () as temp_dir :
741741 path = os .path .join (temp_dir , "model.keras" )
742742
743- with self ._strategy .scope ():
743+ with self .strategy .scope ():
744744 layer = distributed_embedding .DistributedEmbedding (
745745 feature_configs
746746 )
747747 keras_outputs = layer (keras_inputs )
748748 model = keras .Model (inputs = keras_inputs , outputs = keras_outputs )
749749
750750 output_before = tpu_test_utils .run_with_strategy (
751- self ._strategy , model .__call__ , inputs
751+ self .strategy , model .__call__ , inputs
752752 )
753753 model .save (path )
754754
755- with self ._strategy .scope ():
755+ with self .strategy .scope ():
756756 reloaded_model = keras .models .load_model (path )
757757 output_after = tpu_test_utils .run_with_strategy (
758- self ._strategy , reloaded_model .__call__ , inputs
758+ self .strategy , reloaded_model .__call__ , inputs
759759 )
760760
761761 if self .placement == "sparsecore" :
0 commit comments