Skip to content

Commit 02aad21

Browse files
authored
Replace nnzs with row_lengths for clarity (#99)
* Replace `nnzs` with `row_lengths` for clarity. The term `nnzs` (i.e. number of non-zeros) comes from sparse tensor nomenclature, but sparse tensors are an implementation detail of the dataloaders, not a domain concept we should be propagating throughout our code base. Instead, let's call them `row_lengths`, which makes sense as part of a ragged tensor representation throughout Merlin. * Autoupdate pre-commit package versions. * Apply auto-formatting to appease the linter.
1 parent 2aeb862 commit 02aad21

File tree

5 files changed

+40
-41
lines changed

5 files changed

+40
-41
lines changed

.pre-commit-config.yaml

Lines changed: 7 additions & 7 deletions
Original file line number · Diff line number · Diff line change
@@ -5,30 +5,30 @@ repos:
55
hooks:
66
- id: absolufy-imports
77
- repo: https://github.com/timothycrosley/isort
8-
rev: 5.10.1
8+
rev: 5.12.0
99
hooks:
1010
- id: isort
1111
additional_dependencies: [toml]
1212
exclude: examples/.*
1313
# code style
1414
- repo: https://github.com/python/black
15-
rev: 22.6.0
15+
rev: 23.1.0
1616
hooks:
1717
- id: black
1818
- repo: https://github.com/pycqa/pylint
19-
rev: v2.14.1
19+
rev: v2.16.1
2020
hooks:
2121
- id: pylint
2222
- repo: https://github.com/pycqa/flake8
23-
rev: 3.9.2
23+
rev: 6.0.0
2424
hooks:
2525
- id: flake8
2626
- repo: https://github.com/adrienverge/yamllint
27-
rev: v1.28.0
27+
rev: v1.29.0
2828
hooks:
2929
- id: yamllint
3030
- repo: https://github.com/pre-commit/mirrors-prettier
31-
rev: v2.7.1
31+
rev: v3.0.0-alpha.4
3232
hooks:
3333
- id: prettier
3434
types_or: [yaml, markdown]
@@ -46,7 +46,7 @@ repos:
4646
exclude: ^(docs|examples|tests|setup.py|versioneer.py)
4747
args: [--config=pyproject.toml]
4848
- repo: https://github.com/codespell-project/codespell
49-
rev: v2.2.1
49+
rev: v2.2.2
5050
hooks:
5151
- id: codespell
5252
exclude: .github/.*

merlin/dataloader/loader_base.py

Lines changed: 25 additions & 25 deletions
Original file line number · Diff line number · Diff line change
@@ -51,7 +51,7 @@ def _num_steps(num_samples, step_size):
5151
class LoaderBase:
5252
"""Base class containing common functionality between the PyTorch and TensorFlow dataloaders."""
5353

54-
_use_nnz = False
54+
_use_row_lengths = False
5555

5656
def __init__(
5757
self,
@@ -337,15 +337,15 @@ def _get_next_batch(self):
337337
return batch
338338

339339
@annotate("make_tensors", color="darkgreen", domain="merlin_dataloader")
340-
def make_tensors(self, gdf, use_nnz=False):
340+
def make_tensors(self, gdf, use_row_lengths=False):
341341
"""Turns a gdf into tensor representation by column
342342
343343
Parameters
344344
----------
345345
gdf : DataFrame
346346
A dataframe type object.
347-
use_nnz : bool, optional
348-
toggle nnzs or use offsets for list columns, by default False
347+
use_row_lengths : bool, optional
348+
Enable using row lengths instead of offsets for list columns, by default False
349349
350350
Returns
351351
-------
@@ -357,7 +357,7 @@ def make_tensors(self, gdf, use_nnz=False):
357357
# map from big chunk to framework-specific tensors
358358
chunks, names = self._create_tensors(gdf)
359359

360-
# if we have any offsets, calculate nnzs up front
360+
# if we have any offsets, calculate row lengths up front
361361
# will need to get offsets if list columns detected in schema
362362

363363
# if len(chunks) == 4:
@@ -368,8 +368,8 @@ def make_tensors(self, gdf, use_nnz=False):
368368
]
369369
if len(lists_list) > 0:
370370
offsets = chunks[-1]
371-
if use_nnz:
372-
nnzs = offsets[1:] - offsets[:-1]
371+
if use_row_lengths:
372+
row_lengths = offsets[1:] - offsets[:-1]
373373
chunks = chunks[:-1]
374374

375375
# split them into batches and map to the framework-specific output format
@@ -388,43 +388,43 @@ def make_tensors(self, gdf, use_nnz=False):
388388
if lists is not None:
389389
num_list_columns = len(lists)
390390

391-
# grab the set of offsets and nnzs corresponding to
392-
# the list columns from this chunk
391+
# grab the set of offsets and row lengths
392+
# corresponding to the list columns from this chunk
393393
chunk_offsets = offsets[:, offset_idx : offset_idx + num_list_columns]
394-
if use_nnz:
395-
chunk_nnzs = nnzs[:, offset_idx : offset_idx + num_list_columns]
394+
if use_row_lengths:
395+
chunk_row_lengths = row_lengths[:, offset_idx : offset_idx + num_list_columns]
396396
offset_idx += num_list_columns
397397

398398
# split them into batches, including an extra 1 on the offsets
399399
# so we know how long the very last element is
400400
batch_offsets = self._split_fn(chunk_offsets, split_idx + [1])
401-
if use_nnz and len(split_idx) > 1:
402-
batch_nnzs = self._split_fn(chunk_nnzs, split_idx)
403-
elif use_nnz:
404-
batch_nnzs = [chunk_nnzs]
401+
if use_row_lengths and len(split_idx) > 1:
402+
batch_row_lengths = self._split_fn(chunk_row_lengths, split_idx)
403+
elif use_row_lengths:
404+
batch_row_lengths = [chunk_row_lengths]
405405
else:
406-
batch_nnzs = [None] * (len(batch_offsets) - 1)
406+
batch_row_lengths = [None] * (len(batch_offsets) - 1)
407407

408408
# group all these indices together and iterate through
409409
# them in batches to grab the proper elements from each
410410
# values tensor
411-
chunk = zip(chunk, batch_offsets[:-1], batch_offsets[1:], batch_nnzs)
411+
chunk = zip(chunk, batch_offsets[:-1], batch_offsets[1:], batch_row_lengths)
412412

413413
for n, c in enumerate(chunk):
414414
if isinstance(c, tuple):
415-
c, off0s, off1s, _nnzs = c
415+
c, off0s, off1s, _row_lengths = c
416416
offsets_split_idx = [1 for _ in range(num_list_columns)]
417417
off0s = self._split_fn(off0s, offsets_split_idx, axis=1)
418418
off1s = self._split_fn(off1s, offsets_split_idx, axis=1)
419-
if use_nnz:
420-
_nnzs = self._split_fn(_nnzs, offsets_split_idx, axis=1)
419+
if use_row_lengths:
420+
_row_lengths = self._split_fn(_row_lengths, offsets_split_idx, axis=1)
421421

422422
# TODO: does this need to be ordereddict?
423423
batch_lists = {}
424424
for k, (column_name, values) in enumerate(lists.items()):
425425
off0, off1 = off0s[k], off1s[k]
426-
if use_nnz:
427-
nnz = _nnzs[k]
426+
if use_row_lengths:
427+
row_length = _row_lengths[k]
428428

429429
# need to grab scalars for TF case
430430
if len(off0.shape) == 1:
@@ -435,7 +435,7 @@ def make_tensors(self, gdf, use_nnz=False):
435435
print(off0, off1)
436436
raise ValueError
437437
value = values[int(start) : int(stop)]
438-
index = off0 - start if not use_nnz else nnz
438+
index = off0 - start if not use_row_lengths else row_length
439439
batch_lists[column_name] = (value, index)
440440
c = (c, batch_lists)
441441

@@ -829,7 +829,7 @@ def chunk_logic(self, itr):
829829
chunks = shuffle_df(chunks)
830830

831831
if len(chunks) > 0:
832-
chunks = self.dataloader.make_tensors(chunks, self.dataloader._use_nnz)
832+
chunks = self.dataloader.make_tensors(chunks, self.dataloader._use_row_lengths)
833833
# put returns True if buffer is stopped before
834834
# packet can be put in queue. Keeps us from
835835
# freezing on a put on a full queue
@@ -838,7 +838,7 @@ def chunk_logic(self, itr):
838838
chunks = None
839839
# takes care final batch, which is less than batch size
840840
if not self.dataloader.drop_last and spill is not None and not spill.empty:
841-
spill = self.dataloader.make_tensors(spill, self.dataloader._use_nnz)
841+
spill = self.dataloader.make_tensors(spill, self.dataloader._use_row_lengths)
842842
self.put(spill)
843843

844844
@annotate("load_chunks", color="darkgreen", domain="merlin_dataloader")

merlin/dataloader/tensorflow.py

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -102,7 +102,7 @@ class Loader(tf.keras.utils.Sequence, LoaderBase):
102102
will usually contain fewer rows.
103103
"""
104104

105-
_use_nnz = True
105+
_use_row_lengths = True
106106

107107
def __init__(
108108
self,

merlin/dataloader/utils/tf/tf_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -90,7 +90,7 @@ def seed_fn():
9090
for col in CATEGORICAL_MH_COLUMNS:
9191
inputs[col] = (
9292
tf.keras.Input(name=f"{col}__values", dtype=tf.int64, shape=(1,)),
93-
tf.keras.Input(name=f"{col}__nnzs", dtype=tf.int64, shape=(1,)),
93+
tf.keras.Input(name=f"{col}__lengths", dtype=tf.int64, shape=(1,)),
9494
)
9595
for col in CATEGORICAL_COLUMNS + CATEGORICAL_MH_COLUMNS:
9696
emb_layers.append(

tests/unit/dataloader/test_tf_dataloader.py

Lines changed: 6 additions & 7 deletions
Original file line number · Diff line number · Diff line change
@@ -382,7 +382,7 @@ def test_mh_support(tmpdir, multihot_data, multihot_dataset, batch_size):
382382
batch_size=batch_size,
383383
shuffle=False,
384384
)
385-
nnzs = None
385+
row_lengths = None
386386
idx = 0
387387

388388
for X, y in data_itr:
@@ -391,18 +391,18 @@ def test_mh_support(tmpdir, multihot_data, multihot_dataset, batch_size):
391391

392392
for mh_name in ["Authors", "Reviewers", "Embedding"]:
393393
# assert (mh_name) in X
394-
array, nnzs = X[mh_name]
395-
nnzs = nnzs.numpy()[:, 0]
394+
array, row_lengths = X[mh_name]
395+
row_lengths = row_lengths.numpy()[:, 0]
396396
array = array.numpy()[:, 0]
397397

398398
if mh_name == "Embedding":
399-
assert (nnzs == 3).all()
399+
assert (row_lengths == 3).all()
400400
else:
401401
lens = [
402402
len(x)
403403
for x in multihot_data[mh_name][idx * batch_size : idx * batch_size + n_samples]
404404
]
405-
assert (nnzs == np.array(lens)).all()
405+
assert (row_lengths == np.array(lens)).all()
406406

407407
if mh_name == "Embedding":
408408
assert len(array) == (n_samples * 3)
@@ -533,7 +533,7 @@ def test_sparse_tensors(tmpdir, sparse_dense):
533533
for batch in data_itr:
534534
feats, labs = batch
535535
for col in spa_lst:
536-
# grab nnzs
536+
# grab row lengths
537537
feature_tensor = feats[f"{col}"]
538538
if not sparse_dense:
539539
assert list(feature_tensor.shape) == [batch_size, spa_mx[col]]
@@ -648,7 +648,6 @@ def test_dataloader_schema(tmpdir, dataset, batch_size, cpu):
648648
batch_size=batch_size,
649649
shuffle=False,
650650
) as data_loader:
651-
652651
batch = data_loader.peek()
653652

654653
columns = set(dataset.schema.column_names) - {"label"}

0 commit comments

Comments (0)