11"""Utility functions for manipulating JAX embedding tables and inputs."""
22
33import collections
4- import typing
54from typing import Any , Mapping , NamedTuple , Sequence , TypeAlias , TypeVar
65
76import jax
87import numpy as np
9- from jax import numpy as jnp
108from jax_tpu_embedding .sparsecore .lib .nn import embedding
9+ from jax_tpu_embedding .sparsecore .lib .nn import table_stacking
1110from jax_tpu_embedding .sparsecore .lib .nn .embedding_spec import FeatureSpec
12- from jax_tpu_embedding .sparsecore .lib .nn .embedding_spec import StackedTableSpec
13- from jax_tpu_embedding .sparsecore .lib .nn .embedding_spec import TableSpec
1411
1512from keras_rs .src .types import Nested
1613
@@ -34,297 +31,6 @@ class ShardedCooMatrix(NamedTuple):
3431 values : ArrayLike
3532
3633
37- def _round_up_to_multiple (value : int , multiple : int ) -> int :
38- return ((value + multiple - 1 ) // multiple ) * multiple
39-
40-
41- def _default_stacked_table_spec (
42- table_spec : TableSpec , num_shards : int , batch_size : int
43- ) -> StackedTableSpec :
44- return StackedTableSpec (
45- stack_name = table_spec .name ,
46- stack_vocab_size = _round_up_to_multiple (
47- table_spec .vocabulary_size , 8 * num_shards
48- ),
49- stack_embedding_dim = _round_up_to_multiple (table_spec .embedding_dim , 8 ),
50- optimizer = table_spec .optimizer ,
51- combiner = table_spec .combiner ,
52- total_sample_count = batch_size ,
53- max_ids_per_partition = table_spec .max_ids_per_partition ,
54- max_unique_ids_per_partition = table_spec .max_unique_ids_per_partition ,
55- )
56-
57-
58- def _get_stacked_table_spec (
59- table_spec : TableSpec , num_shards : int , batch_size : int = 0
60- ) -> StackedTableSpec :
61- return table_spec .stacked_table_spec or _default_stacked_table_spec (
62- table_spec , num_shards , batch_size
63- )
64-
65-
66- def pad_table (
67- table_spec : TableSpec ,
68- table_values : jax .Array ,
69- num_shards : int ,
70- pad_value : jnp .float32 = jnp .nan ,
71- ) -> jax .Array :
72- """Adds appropriate padding to a table to prepare for stacking.
73-
74- Args:
75- table_spec: Table specification describing the table to pad.
76- table_values: Table values array to pad.
77- num_shards: Number of shards in the table (typically
78- `global_device_count * num_sc_per_device`).
79- pad_value: Value to use for padding.
80-
81- Returns:
82- Padded table values.
83- """
84- vocabulary_size = table_spec .vocabulary_size
85- embedding_dim = table_spec .embedding_dim
86- padded_vocabulary_size = _round_up_to_multiple (
87- vocabulary_size , 8 * num_shards
88- )
89- stack_embedding_dim = _get_stacked_table_spec (
90- table_spec , num_shards
91- ).stack_embedding_dim
92- return jnp .pad (
93- table_values ,
94- (
95- (0 , padded_vocabulary_size - vocabulary_size ),
96- (0 , stack_embedding_dim - embedding_dim ),
97- ),
98- constant_values = pad_value ,
99- )
100-
101-
102- def _stack_and_shard_table (
103- stacked_table : jax .Array ,
104- table_spec : TableSpec ,
105- table : jax .Array ,
106- num_shards : int ,
107- pad_value : jnp .float32 ,
108- ) -> jax .Array :
109- """Stacks and shards a single table for use in sparsecore lookups."""
110- padded_values = pad_table (table_spec , table , num_shards , pad_value )
111- sharded_padded_vocabulary_size = padded_values .shape [0 ] // num_shards
112- stack_embedding_dim = stacked_table .shape [- 1 ]
113-
114- # Mod-shard vocabulary across devices.
115- sharded_values = jnp .swapaxes (
116- padded_values .reshape (- 1 , num_shards , stack_embedding_dim ),
117- 0 ,
118- 1 ,
119- )
120-
121- # Rotate shards.
122- setting_in_stack = table_spec .setting_in_stack
123- rotated_values = jnp .roll (
124- sharded_values , setting_in_stack .shard_rotation , axis = 0
125- )
126-
127- # Insert table into the stack.
128- table_row = setting_in_stack .row_offset_in_shard
129- stacked_table = stacked_table .at [
130- :, table_row : (table_row + sharded_padded_vocabulary_size ), :
131- ].set (rotated_values )
132-
133- return stacked_table
134-
135-
136- def stack_and_shard_tables (
137- table_specs : Nested [TableSpec ],
138- tables : Nested [ArrayLike ],
139- num_shards : int ,
140- pad_value : jnp .float32 = jnp .nan ,
141- ) -> dict [str , Nested [jax .Array ]]:
142- """Stacks and shards tables for use in sparsecore lookups.
143-
144- Args:
145- table_specs: Nested collection of unstacked table specifications.
146- tables: Table values corresponding to the table_specs.
147- num_shards: Number of shards in the table (typically
148- `global_device_count * num_sc_per_device`).
149- pad_value: Value to use for padding.
150-
151- Returns:
152- A mapping of stacked table names to stacked table values.
153- """
154-
155- # Gather stacked table information.
156- stacked_table_map : dict [
157- str ,
158- tuple [StackedTableSpec , list [TableSpec ]],
159- ] = {}
160-
161- def collect_stacked_tables (table_spec : TableSpec ) -> None :
162- stacked_table_spec = _get_stacked_table_spec (table_spec , num_shards )
163- stacked_table_name = stacked_table_spec .stack_name
164- if stacked_table_name not in stacked_table_map :
165- stacked_table_map [stacked_table_name ] = (stacked_table_spec , [])
166- stacked_table_map [stacked_table_name ][1 ].append (table_spec )
167-
168- _ = jax .tree .map (collect_stacked_tables , table_specs )
169-
170- table_map : dict [str , Nested [jax .Array ]] = {}
171-
172- def collect_tables (table_spec : TableSpec , table : Nested [jax .Array ]) -> None :
173- table_map [table_spec .name ] = table
174-
175- _ = jax .tree .map (collect_tables , table_specs , tables )
176-
177- stacked_tables : dict [str , Nested [jax .Array ]] = {}
178- for (
179- stacked_table_spec ,
180- table_specs ,
181- ) in stacked_table_map .values ():
182- stack_vocab_size = stacked_table_spec .stack_vocab_size
183- sharded_vocab_size = stack_vocab_size // num_shards
184- stack_embedding_dim = stacked_table_spec .stack_embedding_dim
185-
186- # Allocate initial buffer. The stacked table will be divided among
187- # shards by splitting the vocabulary dimension:
188- # [ v, e ] -> [s, v/s, e]
189- stacked_table_tree = jax .tree .map (
190- lambda _ : jnp .zeros (
191- # pylint: disable-next=cell-var-from-loop, used only in loop body.
192- shape = (num_shards , sharded_vocab_size , stack_embedding_dim ),
193- dtype = jnp .float32 ,
194- ),
195- table_map [table_specs [0 ].name ],
196- )
197-
198- for table_spec in table_specs :
199- table_tree = table_map [table_spec .name ]
200- stacked_table_tree = jax .tree .map (
201- lambda stacked_table , table : _stack_and_shard_table (
202- # pylint: disable-next=cell-var-from-loop, used only in loop body.
203- stacked_table ,
204- # pylint: disable-next=cell-var-from-loop, used only in loop body.
205- table_spec ,
206- table ,
207- num_shards ,
208- pad_value ,
209- ),
210- stacked_table_tree ,
211- table_tree ,
212- )
213-
214- stacked_tables [stacked_table_spec .stack_name ] = stacked_table_tree
215-
216- return stacked_tables
217-
218-
219- def _unshard_and_unstack_table (
220- table_spec : TableSpec ,
221- stacked_table_tree : Nested [jax .Array ],
222- num_shards : int ,
223- ) -> Nested [jax .Array ]:
224- """Unshards and unstacks a single table."""
225- vocabulary_size = table_spec .vocabulary_size
226- embedding_dim = table_spec .embedding_dim
227-
228- def _unshard_and_unstack_single_table (
229- table_spec : TableSpec , stacked_table : jax .Array
230- ) -> jax .Array :
231- stack_embedding_dim = stacked_table .shape [- 1 ]
232-
233- # Maybe re-shape in case it was flattened.
234- stacked_table = stacked_table .reshape (
235- num_shards , - 1 , stack_embedding_dim
236- )
237- sharded_vocabulary_size = (
238- _round_up_to_multiple (vocabulary_size , 8 * num_shards ) // num_shards
239- )
240-
241- # Extract padded values from the stacked table.
242- setting_in_stack = table_spec .setting_in_stack
243- row = setting_in_stack .row_offset_in_shard
244- padded_values = stacked_table [
245- :, row : (row + sharded_vocabulary_size ), :
246- ]
247-
248- # Un-rotate shards.
249- padded_values = jnp .roll (
250- padded_values , - setting_in_stack .shard_rotation , axis = 0
251- )
252-
253- # Un-mod-shard.
254- padded_values = jnp .swapaxes (padded_values , 0 , 1 ).reshape (
255- - 1 , stack_embedding_dim
256- )
257-
258- # Un-pad.
259- return padded_values [:vocabulary_size , :embedding_dim ]
260-
261- output : Nested [jax .Array ] = jax .tree .map (
262- lambda stacked_table : _unshard_and_unstack_single_table (
263- table_spec , stacked_table
264- ),
265- stacked_table_tree ,
266- )
267- return output
268-
269-
270- def unshard_and_unstack_tables (
271- table_specs : Nested [TableSpec ],
272- stacked_tables : Mapping [str , Nested [jax .Array ]],
273- num_shards : int ,
274- ) -> Nested [jax .Array ]:
275- """Unshards and unstacks a collection of tables.
276-
277- Args:
278- table_specs: Nested collection of unstacked table specifications.
279- stacked_tables: Mapping of stacked table names to stacked table values.
280- num_shards: Number of shards in the table (typically
281- `global_device_count * num_sc_per_device`).
282-
283- Returns:
284- A mapping of table names to unstacked table values.
285- """
286- output : Nested [jax .Array ] = jax .tree .map (
287- lambda table_spec : _unshard_and_unstack_table (
288- table_spec ,
289- stacked_tables [
290- _get_stacked_table_spec (table_spec , num_shards = 1 ).stack_name
291- ],
292- num_shards ,
293- ),
294- table_specs ,
295- )
296- return output
297-
298-
299- def get_table_stacks (
300- table_specs : Nested [TableSpec ],
301- ) -> dict [str , list [TableSpec ]]:
302- """Extracts lists of tables that are stacked together.
303-
304- Args:
305- table_specs: Nested collection of table specifications.
306-
307- Returns:
308- A mapping of stacked table names to lists of table specifications for
309- each stack.
310- """
311- stacked_table_specs : dict [str , list [TableSpec ]] = collections .defaultdict (
312- list
313- )
314- flat_table_specs , _ = jax .tree .flatten (table_specs )
315- for table_spec in flat_table_specs :
316- table_spec = typing .cast (TableSpec , table_spec )
317- stacked_table_spec = table_spec .stacked_table_spec
318- if stacked_table_spec is not None :
319- stacked_table_specs [stacked_table_spec .stack_name ].append (
320- table_spec
321- )
322- else :
323- stacked_table_specs [table_spec .name ].append (table_spec )
324-
325- return stacked_table_specs
326-
327-
32834def convert_to_numpy (
32935 ragged_or_dense : np .ndarray [Any , Any ] | Sequence [Sequence [Any ]] | Any ,
33036 dtype : Any ,
@@ -522,7 +228,10 @@ def collect_tokens_and_weights(
522228 for table_name in tables_names :
523229 shard_ends = preprocessed_inputs .lhs_row_pointers [table_name ]
524230 shard_starts = np .concatenate (
525- [np .asarray ([0 ]), _round_up_to_multiple (shard_ends [:- 1 ], 8 )]
231+ [
232+ np .asarray ([0 ]),
233+ table_stacking ._next_largest_multiple (shard_ends [:- 1 ], 8 ),
234+ ]
526235 )
527236 out [table_name ] = ShardedCooMatrix (
528237 shard_starts = shard_starts ,
0 commit comments