11"""Utility functions for manipulating JAX embedding tables and inputs."""
22
33import collections
4- import dataclasses
54import typing
65from typing import Any , Mapping , NamedTuple , Sequence , TypeAlias , TypeVar
76
@@ -35,12 +34,6 @@ class ShardedCooMatrix(NamedTuple):
3534 values : ArrayLike
3635
3736
class InputStatsPerTable(NamedTuple):
  """Input statistics tracked for one table.

  Attributes:
    max_ids_per_partition: Maximum number of IDs per partition.
    max_unique_ids_per_partition: Maximum number of unique IDs per partition.
    required_buffer_size_per_device: Buffer size required on each device.
  """

  max_ids_per_partition: int
  max_unique_ids_per_partition: int
  required_buffer_size_per_device: int
42-
43-
def _round_up_to_multiple(value: int, multiple: int) -> int:
  """Rounds `value` up to the nearest multiple of `multiple`.

  Args:
    value: The integer to round up.
    multiple: The step to round to. For a positive `multiple` the result is
      the smallest multiple that is greater than or equal to `value`.

  Returns:
    The rounded-up value.
  """
  # Bump past the next boundary, then strip the remainder to land on it.
  padded = value + multiple - 1
  return padded - padded % multiple
4639
@@ -303,15 +296,6 @@ def unshard_and_unstack_tables(
303296 return output
304297
305298
def get_table_specs(feature_specs: Nested[FeatureSpec]) -> dict[str, TableSpec]:
  """Collects the table specs referenced by a nested set of feature specs.

  Args:
    feature_specs: Nested structure whose leaves each expose a `table_spec`.

  Returns:
    A dict mapping each table spec's `name` to the table spec itself. If
    several leaves reference table specs with the same name, the one visited
    last is kept.
  """
  leaves = jax.tree.flatten(feature_specs)[0]
  return {leaf.table_spec.name: leaf.table_spec for leaf in leaves}
313-
314-
315299def get_table_stacks (
316300 table_specs : Nested [TableSpec ],
317301) -> dict [str , list [TableSpec ]]:
@@ -341,84 +325,6 @@ def get_table_stacks(
341325 return stacked_table_specs
342326
343327
def get_stacked_table_stats(
    feature_specs: Nested[FeatureSpec],
) -> dict[str, InputStatsPerTable]:
  """Extracts the stacked-table input statistics from the feature specs.

  Args:
    feature_specs: Feature specs from which to extract the statistics.

  Returns:
    A mapping of stacked table names to input statistics per table.
  """
  # De-duplicate the stacked table specs by stack name. When several features
  # share a stack, the spec reached last is the one that is kept.
  unique_stacks: dict[str, StackedTableSpec] = {}
  flat_feature_specs, _ = jax.tree.flatten(feature_specs)
  for leaf in flat_feature_specs:
    table_spec = typing.cast(FeatureSpec, leaf).table_spec
    stacked = typing.cast(StackedTableSpec, table_spec.stacked_table_spec)
    unique_stacks[stacked.stack_name] = stacked

  # A falsy (e.g. unset) suggested buffer size is reported as 0.
  return {
      stack_name: InputStatsPerTable(
          max_ids_per_partition=stacked.max_ids_per_partition,
          max_unique_ids_per_partition=stacked.max_unique_ids_per_partition,
          required_buffer_size_per_device=(
              stacked.suggested_coo_buffer_size_per_device or 0
          ),
      )
      for stack_name, stacked in unique_stacks.items()
  }
374-
375-
def update_stacked_table_stats(
    feature_specs: Nested[FeatureSpec],
    stats: Mapping[str, InputStatsPerTable],
) -> None:
  """Updates stacked-table input properties in the supplied feature specs.

  Each table spec reachable from `feature_specs` has its `stacked_table_spec`
  replaced by a copy carrying the limits from `stats`.

  Args:
    feature_specs: Feature specs to update in-place.
    stats: Per-stacked-table input statistics, looked up by stack name for
      every stacked table referenced by `feature_specs`.
  """
  # Unique table specs, keyed by table name.
  tables_by_name: dict[str, TableSpec] = {}
  flat_feature_specs, _ = jax.tree.flatten(feature_specs)
  for leaf in flat_feature_specs:
    table = typing.cast(FeatureSpec, leaf).table_spec
    tables_by_name[table.name] = table

  # Unique stacked table specs, keyed by stack name.
  stacks_by_name: dict[str, StackedTableSpec] = {}
  for table in tables_by_name.values():
    stacked = typing.cast(StackedTableSpec, table.stacked_table_spec)
    stacks_by_name[stacked.stack_name] = stacked

  # Build the replacement stacked table specs before touching any table, so a
  # failed `stats` lookup leaves the feature specs unmodified.
  updated_stacks: dict[str, StackedTableSpec] = {}
  for stack_name, stacked in stacks_by_name.items():
    new_stats = stats[stack_name]
    updated_stacks[stack_name] = dataclasses.replace(
        stacked,
        max_ids_per_partition=new_stats.max_ids_per_partition,
        max_unique_ids_per_partition=new_stats.max_unique_ids_per_partition,
        # A falsy buffer size (e.g. 0) is stored as None.
        suggested_coo_buffer_size_per_device=(
            new_stats.required_buffer_size_per_device or None
        ),
    )

  # Point every table at the refreshed spec for its stack.
  for table in tables_by_name.values():
    current = typing.cast(StackedTableSpec, table.stacked_table_spec)
    table.stacked_table_spec = updated_stacks[current.stack_name]
420-
421-
422328def convert_to_numpy (
423329 ragged_or_dense : np .ndarray [Any , Any ] | Sequence [Sequence [Any ]] | Any ,
424330 dtype : Any ,
@@ -483,7 +389,7 @@ def ones_like(
483389
484390 Args:
485391 ragged_or_dense: The ragged or dense input whose shape and data-type
486- define these same attributes of the returned array.
392+ define these same attributes of the returned array.
487393 dtype: The data-type of the returned array.
488394
489395 Returns:
@@ -567,7 +473,7 @@ def stack_and_shard_samples(
567473 global_device_count : int ,
568474 num_sc_per_device : int ,
569475 static_buffer_size : int | Mapping [str , int ] | None = None ,
570- ) -> tuple [dict [str , ShardedCooMatrix ], dict [ str , InputStatsPerTable ] ]:
476+ ) -> tuple [dict [str , ShardedCooMatrix ], embedding . SparseDenseMatmulInputStats ]:
571477 """Prepares input samples for use in embedding lookups.
572478
573479 Args:
@@ -612,7 +518,6 @@ def collect_tokens_and_weights(
612518 )
613519
614520 out : dict [str , ShardedCooMatrix ] = {}
615- out_stats : dict [str , InputStatsPerTable ] = {}
616521 tables_names = preprocessed_inputs .lhs_row_pointers .keys ()
617522 for table_name in tables_names :
618523 shard_ends = preprocessed_inputs .lhs_row_pointers [table_name ]
@@ -626,17 +531,5 @@ def collect_tokens_and_weights(
626531 row_ids = preprocessed_inputs .lhs_sample_ids [table_name ],
627532 values = preprocessed_inputs .lhs_gains [table_name ],
628533 )
629- out_stats [table_name ] = InputStatsPerTable (
630- max_ids_per_partition = np .max (
631- stats .max_ids_per_partition [table_name ]
632- ),
633- max_unique_ids_per_partition = np .max (
634- stats .max_unique_ids_per_partition [table_name ]
635- ),
636- required_buffer_size_per_device = np .max (
637- stats .required_buffer_size_per_sc [table_name ]
638- )
639- * num_sc_per_device ,
640- )
641534
642- return out , out_stats
535+ return out , stats
0 commit comments