77import keras
88import numpy as np
99from jax import numpy as jnp
10+ from jax_tpu_embedding .sparsecore .lib .nn import embedding
1011from jax_tpu_embedding .sparsecore .lib .nn import embedding_spec
1112from jax_tpu_embedding .sparsecore .lib .nn .embedding_spec import FeatureSpec
1213from jax_tpu_embedding .sparsecore .lib .nn .embedding_spec import TableSpec
@@ -142,7 +143,7 @@ def create_tables(
142143def create_table_and_slot_variables (
143144 table_specs : Nested [TableSpec ],
144145 keys : Nested [ArrayLike ] | None = None ,
145- ) -> Nested [ArrayLike ]:
146+ ) -> Nested [embedding . EmbeddingVariables ]:
146147 """Creates and initializes embedding tables and slot variables.
147148
148149 Args:
@@ -164,7 +165,7 @@ def create_table_and_slot_variables(
164165 def _create_table_and_slot_variables (
165166 table_spec : TableSpec ,
166167 key : ArrayLike ,
167- ) -> tuple [ jax . Array , tuple [ jax . Array , ...]] :
168+ ) -> embedding . EmbeddingVariables :
168169 slot_initializers = table_spec .optimizer .slot_variables_initializers ()
169170 num_slot_variables = len (keras .tree .flatten (slot_initializers ))
170171 slot_keys = jnp .unstack (jax .random .split (key , num_slot_variables ))
@@ -178,10 +179,10 @@ def _create_table_and_slot_variables(
178179 slot_initializers ,
179180 slot_keys ,
180181 )
181- return (table , slot_variables )
182+ return embedding . EmbeddingVariables (table , slot_variables )
182183
183184 # Initialize tables.
184- output : Nested [ArrayLike ] = jax .tree .map (
185+ output : Nested [embedding . EmbeddingVariables ] = jax .tree .map (
185186 _create_table_and_slot_variables ,
186187 table_specs ,
187188 keys ,
@@ -311,14 +312,14 @@ def _create_samples(
311312
312313def stack_shard_and_put_tables (
313314 table_specs : Nested [TableSpec ],
314- tables : Nested [jax . Array ],
315+ tables : Nested [embedding . EmbeddingVariables ],
315316 num_shards : int ,
316317 sharding : jax .sharding .Sharding ,
317- ) -> dict [str , Nested [ jax . Array ] ]:
318+ ) -> dict [str , embedding . EmbeddingVariables ]:
318319 sharded_tables = embedding_utils .stack_and_shard_tables (
319320 table_specs , tables , num_shards
320321 )
321- output : dict [str , Nested [ jax . Array ] ] = jax .device_put (
322+ output : dict [str , embedding . EmbeddingVariables ] = jax .device_put (
322323 jax .tree .map (
323324 # Flatten shard dimension to allow auto-sharding to split the array.
324325 lambda table : table .reshape ((- 1 , table .shape [- 1 ])),
@@ -469,27 +470,24 @@ def compute_expected_lookup_grad(
469470def _update_table_and_slot_variables (
470471 table_spec : TableSpec ,
471472 grad : jax .Array ,
472- table_and_slot_variables : tuple [jax .Array , tuple [jax .Array , ...]],
473- ) -> tuple [
474- jax .Array ,
475- embedding_spec .SGDSlotVariables | embedding_spec .AdagradSlotVariables ,
476- ]:
473+ table_and_slot_variables : embedding .EmbeddingVariables ,
474+ ) -> embedding .EmbeddingVariables :
477475 """Updates a table and its slot variables based on the gradient."""
478- table = table_and_slot_variables [ 0 ]
476+ table = table_and_slot_variables . table
479477 optimizer = table_spec .optimizer
480478
481479 # Adagrad, update and apply gradient accumulator.
482480 if isinstance (optimizer , embedding_spec .AdagradOptimizerSpec ):
483- accumulator = table_and_slot_variables [ 1 ][ 0 ]
481+ accumulator = table_and_slot_variables . slot . accumulator
484482 accumulator = accumulator + grad * grad
485483 learning_rate = optimizer .get_learning_rate (0 ) / jnp .sqrt (accumulator )
486- return (
484+ return embedding . EmbeddingVariables (
487485 table - learning_rate * grad ,
488486 embedding_spec .AdagradSlotVariables (accumulator = accumulator ),
489487 )
490488
491489 # SGD
492- return (
490+ return embedding . EmbeddingVariables (
493491 table - optimizer .get_learning_rate (0 ) * grad ,
494492 embedding_spec .SGDSlotVariables (),
495493 )
@@ -500,8 +498,8 @@ def compute_expected_updates(
500498 feature_samples : Nested [FeatureSamples ],
501499 activation_gradients : Nested [jax .Array ],
502500 table_specs : Nested [TableSpec ],
503- table_and_slot_variables : Nested [jax . Array ],
504- ) -> Nested [jax . Array ]:
501+ table_and_slot_variables : Nested [embedding . EmbeddingVariables ],
502+ ) -> Nested [embedding . EmbeddingVariables ]:
505503 """Computes the expected updates for a given embedding lookup.
506504
507505 Args:
@@ -522,7 +520,7 @@ def compute_expected_updates(
522520 )
523521
524522 # Apply updates per table.
525- output : Nested [jax . Array ] = jax .tree .map (
523+ output : Nested [embedding . EmbeddingVariables ] = jax .tree .map (
526524 _update_table_and_slot_variables ,
527525 table_specs ,
528526 table_grads ,
0 commit comments