Skip to content

Commit 9f1557c

Browse files
Refactor: Replace embedding_utils with table_stacking in DistributedEmbedding and tests (#161)
1 parent e7cda94 commit 9f1557c

File tree

5 files changed

+19
-308
lines changed

5 files changed

+19
-308
lines changed

keras_rs/src/layers/embedding/jax/distributed_embedding.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -442,7 +442,7 @@ def sparsecore_build(
442442

443443
# Collect all stacked tables.
444444
table_specs = embedding.get_table_specs(feature_specs)
445-
table_stacks = embedding_utils.get_table_stacks(table_specs)
445+
table_stacks = jte_table_stacking.get_table_stacks(table_specs)
446446

447447
# Create variables for all stacked tables and slot variables.
448448
with sparsecore_distribution.scope():
@@ -516,7 +516,7 @@ def _sparsecore_symbolic_preprocess(
516516

517517
# Each stacked-table gets a ShardedCooMatrix.
518518
table_specs = embedding.get_table_specs(self._config.feature_specs)
519-
table_stacks = embedding_utils.get_table_stacks(table_specs)
519+
table_stacks = jte_table_stacking.get_table_stacks(table_specs)
520520
stacked_table_specs = {
521521
stack_name: stack[0].stacked_table_spec
522522
for stack_name, stack in table_stacks.items()
@@ -720,7 +720,7 @@ def _sparsecore_set_tables(self, tables: Mapping[str, ArrayLike]) -> None:
720720
config = self._config
721721
num_table_shards = config.mesh.devices.size * config.num_sc_per_device
722722
table_specs = embedding.get_table_specs(config.feature_specs)
723-
sharded_tables = embedding_utils.stack_and_shard_tables(
723+
sharded_tables = jte_table_stacking.stack_and_shard_tables(
724724
table_specs,
725725
tables,
726726
num_table_shards,
@@ -763,7 +763,7 @@ def _sparsecore_get_embedding_tables(self) -> dict[str, ArrayLike]:
763763

764764
return typing.cast(
765765
dict[str, ArrayLike],
766-
embedding_utils.unshard_and_unstack_tables(
766+
jte_table_stacking.unshard_and_unstack_tables(
767767
table_specs, table_variables, num_table_shards
768768
),
769769
)

keras_rs/src/layers/embedding/jax/distributed_embedding_test.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from keras_rs.src.layers.embedding.jax import (
2525
distributed_embedding as jax_distributed_embedding,
2626
)
27-
from keras_rs.src.layers.embedding.jax import embedding_utils
2827
from keras_rs.src.layers.embedding.jax import test_utils
2928

3029
keras.config.disable_traceback_filtering()
@@ -177,7 +176,7 @@ def test_sharded_matches_unsharded(self):
177176
)
178177
self.assertEqual(actual.shape, expected_shape)
179178

180-
unsharded_tables = embedding_utils.unshard_and_unstack_tables(
179+
unsharded_tables = table_stacking_lib.unshard_and_unstack_tables(
181180
table_specs,
182181
{stacked_table_spec.stack_name: actual},
183182
num_table_shards,

keras_rs/src/layers/embedding/jax/embedding_lookup_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,7 @@ def test_backward_pass(
398398
res=(sharded_samples, sharded_table_and_slot_variables, None),
399399
gradients=activation_grads,
400400
)
401-
updated_tables_and_slots = embedding_utils.unshard_and_unstack_tables(
401+
updated_tables_and_slots = table_stacking.unshard_and_unstack_tables(
402402
table_specs, updated_stacked_tables, num_table_shards
403403
)
404404

@@ -553,7 +553,7 @@ def loss_fn(params, lookups, labels):
553553
lookup_grads = grads["lookup_tables"]
554554

555555
# Recover unstacked and unsharded gradients.
556-
updated_tables_and_slots = embedding_utils.unshard_and_unstack_tables(
556+
updated_tables_and_slots = table_stacking.unshard_and_unstack_tables(
557557
table_specs, lookup_grads, num_table_shards
558558
)
559559

keras_rs/src/layers/embedding/jax/embedding_utils.py

Lines changed: 5 additions & 296 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,13 @@
11
"""Utility functions for manipulating JAX embedding tables and inputs."""
22

33
import collections
4-
import typing
54
from typing import Any, Mapping, NamedTuple, Sequence, TypeAlias, TypeVar
65

76
import jax
87
import numpy as np
9-
from jax import numpy as jnp
108
from jax_tpu_embedding.sparsecore.lib.nn import embedding
9+
from jax_tpu_embedding.sparsecore.lib.nn import table_stacking
1110
from jax_tpu_embedding.sparsecore.lib.nn.embedding_spec import FeatureSpec
12-
from jax_tpu_embedding.sparsecore.lib.nn.embedding_spec import StackedTableSpec
13-
from jax_tpu_embedding.sparsecore.lib.nn.embedding_spec import TableSpec
1411

1512
from keras_rs.src.types import Nested
1613

@@ -34,297 +31,6 @@ class ShardedCooMatrix(NamedTuple):
3431
values: ArrayLike
3532

3633

37-
def _round_up_to_multiple(value: int, multiple: int) -> int:
38-
return ((value + multiple - 1) // multiple) * multiple
39-
40-
41-
def _default_stacked_table_spec(
42-
table_spec: TableSpec, num_shards: int, batch_size: int
43-
) -> StackedTableSpec:
44-
return StackedTableSpec(
45-
stack_name=table_spec.name,
46-
stack_vocab_size=_round_up_to_multiple(
47-
table_spec.vocabulary_size, 8 * num_shards
48-
),
49-
stack_embedding_dim=_round_up_to_multiple(table_spec.embedding_dim, 8),
50-
optimizer=table_spec.optimizer,
51-
combiner=table_spec.combiner,
52-
total_sample_count=batch_size,
53-
max_ids_per_partition=table_spec.max_ids_per_partition,
54-
max_unique_ids_per_partition=table_spec.max_unique_ids_per_partition,
55-
)
56-
57-
58-
def _get_stacked_table_spec(
59-
table_spec: TableSpec, num_shards: int, batch_size: int = 0
60-
) -> StackedTableSpec:
61-
return table_spec.stacked_table_spec or _default_stacked_table_spec(
62-
table_spec, num_shards, batch_size
63-
)
64-
65-
66-
def pad_table(
67-
table_spec: TableSpec,
68-
table_values: jax.Array,
69-
num_shards: int,
70-
pad_value: jnp.float32 = jnp.nan,
71-
) -> jax.Array:
72-
"""Adds appropriate padding to a table to prepare for stacking.
73-
74-
Args:
75-
table_spec: Table specification describing the table to pad.
76-
table_values: Table values array to pad.
77-
num_shards: Number of shards in the table (typically
78-
`global_device_count * num_sc_per_device`).
79-
pad_value: Value to use for padding.
80-
81-
Returns:
82-
Padded table values.
83-
"""
84-
vocabulary_size = table_spec.vocabulary_size
85-
embedding_dim = table_spec.embedding_dim
86-
padded_vocabulary_size = _round_up_to_multiple(
87-
vocabulary_size, 8 * num_shards
88-
)
89-
stack_embedding_dim = _get_stacked_table_spec(
90-
table_spec, num_shards
91-
).stack_embedding_dim
92-
return jnp.pad(
93-
table_values,
94-
(
95-
(0, padded_vocabulary_size - vocabulary_size),
96-
(0, stack_embedding_dim - embedding_dim),
97-
),
98-
constant_values=pad_value,
99-
)
100-
101-
102-
def _stack_and_shard_table(
103-
stacked_table: jax.Array,
104-
table_spec: TableSpec,
105-
table: jax.Array,
106-
num_shards: int,
107-
pad_value: jnp.float32,
108-
) -> jax.Array:
109-
"""Stacks and shards a single table for use in sparsecore lookups."""
110-
padded_values = pad_table(table_spec, table, num_shards, pad_value)
111-
sharded_padded_vocabulary_size = padded_values.shape[0] // num_shards
112-
stack_embedding_dim = stacked_table.shape[-1]
113-
114-
# Mod-shard vocabulary across devices.
115-
sharded_values = jnp.swapaxes(
116-
padded_values.reshape(-1, num_shards, stack_embedding_dim),
117-
0,
118-
1,
119-
)
120-
121-
# Rotate shards.
122-
setting_in_stack = table_spec.setting_in_stack
123-
rotated_values = jnp.roll(
124-
sharded_values, setting_in_stack.shard_rotation, axis=0
125-
)
126-
127-
# Insert table into the stack.
128-
table_row = setting_in_stack.row_offset_in_shard
129-
stacked_table = stacked_table.at[
130-
:, table_row : (table_row + sharded_padded_vocabulary_size), :
131-
].set(rotated_values)
132-
133-
return stacked_table
134-
135-
136-
def stack_and_shard_tables(
137-
table_specs: Nested[TableSpec],
138-
tables: Nested[ArrayLike],
139-
num_shards: int,
140-
pad_value: jnp.float32 = jnp.nan,
141-
) -> dict[str, Nested[jax.Array]]:
142-
"""Stacks and shards tables for use in sparsecore lookups.
143-
144-
Args:
145-
table_specs: Nested collection of unstacked table specifications.
146-
tables: Table values corresponding to the table_specs.
147-
num_shards: Number of shards in the table (typically
148-
`global_device_count * num_sc_per_device`).
149-
pad_value: Value to use for padding.
150-
151-
Returns:
152-
A mapping of stacked table names to stacked table values.
153-
"""
154-
155-
# Gather stacked table information.
156-
stacked_table_map: dict[
157-
str,
158-
tuple[StackedTableSpec, list[TableSpec]],
159-
] = {}
160-
161-
def collect_stacked_tables(table_spec: TableSpec) -> None:
162-
stacked_table_spec = _get_stacked_table_spec(table_spec, num_shards)
163-
stacked_table_name = stacked_table_spec.stack_name
164-
if stacked_table_name not in stacked_table_map:
165-
stacked_table_map[stacked_table_name] = (stacked_table_spec, [])
166-
stacked_table_map[stacked_table_name][1].append(table_spec)
167-
168-
_ = jax.tree.map(collect_stacked_tables, table_specs)
169-
170-
table_map: dict[str, Nested[jax.Array]] = {}
171-
172-
def collect_tables(table_spec: TableSpec, table: Nested[jax.Array]) -> None:
173-
table_map[table_spec.name] = table
174-
175-
_ = jax.tree.map(collect_tables, table_specs, tables)
176-
177-
stacked_tables: dict[str, Nested[jax.Array]] = {}
178-
for (
179-
stacked_table_spec,
180-
table_specs,
181-
) in stacked_table_map.values():
182-
stack_vocab_size = stacked_table_spec.stack_vocab_size
183-
sharded_vocab_size = stack_vocab_size // num_shards
184-
stack_embedding_dim = stacked_table_spec.stack_embedding_dim
185-
186-
# Allocate initial buffer. The stacked table will be divided among
187-
# shards by splitting the vocabulary dimension:
188-
# [ v, e ] -> [s, v/s, e]
189-
stacked_table_tree = jax.tree.map(
190-
lambda _: jnp.zeros(
191-
# pylint: disable-next=cell-var-from-loop, used only in loop body.
192-
shape=(num_shards, sharded_vocab_size, stack_embedding_dim),
193-
dtype=jnp.float32,
194-
),
195-
table_map[table_specs[0].name],
196-
)
197-
198-
for table_spec in table_specs:
199-
table_tree = table_map[table_spec.name]
200-
stacked_table_tree = jax.tree.map(
201-
lambda stacked_table, table: _stack_and_shard_table(
202-
# pylint: disable-next=cell-var-from-loop, used only in loop body.
203-
stacked_table,
204-
# pylint: disable-next=cell-var-from-loop, used only in loop body.
205-
table_spec,
206-
table,
207-
num_shards,
208-
pad_value,
209-
),
210-
stacked_table_tree,
211-
table_tree,
212-
)
213-
214-
stacked_tables[stacked_table_spec.stack_name] = stacked_table_tree
215-
216-
return stacked_tables
217-
218-
219-
def _unshard_and_unstack_table(
220-
table_spec: TableSpec,
221-
stacked_table_tree: Nested[jax.Array],
222-
num_shards: int,
223-
) -> Nested[jax.Array]:
224-
"""Unshards and unstacks a single table."""
225-
vocabulary_size = table_spec.vocabulary_size
226-
embedding_dim = table_spec.embedding_dim
227-
228-
def _unshard_and_unstack_single_table(
229-
table_spec: TableSpec, stacked_table: jax.Array
230-
) -> jax.Array:
231-
stack_embedding_dim = stacked_table.shape[-1]
232-
233-
# Maybe re-shape in case it was flattened.
234-
stacked_table = stacked_table.reshape(
235-
num_shards, -1, stack_embedding_dim
236-
)
237-
sharded_vocabulary_size = (
238-
_round_up_to_multiple(vocabulary_size, 8 * num_shards) // num_shards
239-
)
240-
241-
# Extract padded values from the stacked table.
242-
setting_in_stack = table_spec.setting_in_stack
243-
row = setting_in_stack.row_offset_in_shard
244-
padded_values = stacked_table[
245-
:, row : (row + sharded_vocabulary_size), :
246-
]
247-
248-
# Un-rotate shards.
249-
padded_values = jnp.roll(
250-
padded_values, -setting_in_stack.shard_rotation, axis=0
251-
)
252-
253-
# Un-mod-shard.
254-
padded_values = jnp.swapaxes(padded_values, 0, 1).reshape(
255-
-1, stack_embedding_dim
256-
)
257-
258-
# Un-pad.
259-
return padded_values[:vocabulary_size, :embedding_dim]
260-
261-
output: Nested[jax.Array] = jax.tree.map(
262-
lambda stacked_table: _unshard_and_unstack_single_table(
263-
table_spec, stacked_table
264-
),
265-
stacked_table_tree,
266-
)
267-
return output
268-
269-
270-
def unshard_and_unstack_tables(
271-
table_specs: Nested[TableSpec],
272-
stacked_tables: Mapping[str, Nested[jax.Array]],
273-
num_shards: int,
274-
) -> Nested[jax.Array]:
275-
"""Unshards and unstacks a collection of tables.
276-
277-
Args:
278-
table_specs: Nested collection of unstacked table specifications.
279-
stacked_tables: Mapping of stacked table names to stacked table values.
280-
num_shards: Number of shards in the table (typically
281-
`global_device_count * num_sc_per_device`).
282-
283-
Returns:
284-
A mapping of table names to unstacked table values.
285-
"""
286-
output: Nested[jax.Array] = jax.tree.map(
287-
lambda table_spec: _unshard_and_unstack_table(
288-
table_spec,
289-
stacked_tables[
290-
_get_stacked_table_spec(table_spec, num_shards=1).stack_name
291-
],
292-
num_shards,
293-
),
294-
table_specs,
295-
)
296-
return output
297-
298-
299-
def get_table_stacks(
300-
table_specs: Nested[TableSpec],
301-
) -> dict[str, list[TableSpec]]:
302-
"""Extracts lists of tables that are stacked together.
303-
304-
Args:
305-
table_specs: Nested collection of table specifications.
306-
307-
Returns:
308-
A mapping of stacked table names to lists of table specifications for
309-
each stack.
310-
"""
311-
stacked_table_specs: dict[str, list[TableSpec]] = collections.defaultdict(
312-
list
313-
)
314-
flat_table_specs, _ = jax.tree.flatten(table_specs)
315-
for table_spec in flat_table_specs:
316-
table_spec = typing.cast(TableSpec, table_spec)
317-
stacked_table_spec = table_spec.stacked_table_spec
318-
if stacked_table_spec is not None:
319-
stacked_table_specs[stacked_table_spec.stack_name].append(
320-
table_spec
321-
)
322-
else:
323-
stacked_table_specs[table_spec.name].append(table_spec)
324-
325-
return stacked_table_specs
326-
327-
32834
def convert_to_numpy(
32935
ragged_or_dense: np.ndarray[Any, Any] | Sequence[Sequence[Any]] | Any,
33036
dtype: Any,
@@ -522,7 +228,10 @@ def collect_tokens_and_weights(
522228
for table_name in tables_names:
523229
shard_ends = preprocessed_inputs.lhs_row_pointers[table_name]
524230
shard_starts = np.concatenate(
525-
[np.asarray([0]), _round_up_to_multiple(shard_ends[:-1], 8)]
231+
[
232+
np.asarray([0]),
233+
table_stacking._next_largest_multiple(shard_ends[:-1], 8),
234+
]
526235
)
527236
out[table_name] = ShardedCooMatrix(
528237
shard_starts=shard_starts,

0 commit comments

Comments
 (0)