@@ -51,7 +51,7 @@ def _num_steps(num_samples, step_size):
5151class LoaderBase :
5252 """Base class containing common functionality between the PyTorch and TensorFlow dataloaders."""
5353
54- _use_nnz = False
54+ _use_row_lengths = False
5555
5656 def __init__ (
5757 self ,
@@ -337,15 +337,15 @@ def _get_next_batch(self):
337337 return batch
338338
339339 @annotate ("make_tensors" , color = "darkgreen" , domain = "merlin_dataloader" )
340- def make_tensors (self , gdf , use_nnz = False ):
340+ def make_tensors (self , gdf , use_row_lengths = False ):
341341 """Turns a gdf into tensor representation by column
342342
343343 Parameters
344344 ----------
345345 gdf : DataFrame
346346 A dataframe type object.
347- use_nnz : bool, optional
348- toggle nnzs or use offsets for list columns, by default False
347+ use_row_lengths : bool, optional
348+ Enable using row lengths instead of offsets for list columns, by default False
349349
350350 Returns
351351 -------
@@ -357,7 +357,7 @@ def make_tensors(self, gdf, use_nnz=False):
357357 # map from big chunk to framework-specific tensors
358358 chunks , names = self ._create_tensors (gdf )
359359
360- # if we have any offsets, calculate nnzs up front
360+ # if we have any offsets, calculate row lengths up front
361361 # will need to get offsets if list columns detected in schema
362362
363363 # if len(chunks) == 4:
@@ -368,8 +368,8 @@ def make_tensors(self, gdf, use_nnz=False):
368368 ]
369369 if len (lists_list ) > 0 :
370370 offsets = chunks [- 1 ]
371- if use_nnz :
372- nnzs = offsets [1 :] - offsets [:- 1 ]
371+ if use_row_lengths :
372+ row_lengths = offsets [1 :] - offsets [:- 1 ]
373373 chunks = chunks [:- 1 ]
374374
375375 # split them into batches and map to the framework-specific output format
@@ -388,43 +388,43 @@ def make_tensors(self, gdf, use_nnz=False):
388388 if lists is not None :
389389 num_list_columns = len (lists )
390390
391- # grab the set of offsets and nnzs corresponding to
392- # the list columns from this chunk
391+ # grab the set of offsets and row lengths
392+ # corresponding to the list columns from this chunk
393393 chunk_offsets = offsets [:, offset_idx : offset_idx + num_list_columns ]
394- if use_nnz :
395- chunk_nnzs = nnzs [:, offset_idx : offset_idx + num_list_columns ]
394+ if use_row_lengths :
395+ chunk_row_lengths = row_lengths [:, offset_idx : offset_idx + num_list_columns ]
396396 offset_idx += num_list_columns
397397
398398 # split them into batches, including an extra 1 on the offsets
399399 # so we know how long the very last element is
400400 batch_offsets = self ._split_fn (chunk_offsets , split_idx + [1 ])
401- if use_nnz and len (split_idx ) > 1 :
402- batch_nnzs = self ._split_fn (chunk_nnzs , split_idx )
403- elif use_nnz :
404- batch_nnzs = [chunk_nnzs ]
401+ if use_row_lengths and len (split_idx ) > 1 :
402+ batch_row_lengths = self ._split_fn (chunk_row_lengths , split_idx )
403+ elif use_row_lengths :
404+ batch_row_lengths = [chunk_row_lengths ]
405405 else :
406- batch_nnzs = [None ] * (len (batch_offsets ) - 1 )
406+ batch_row_lengths = [None ] * (len (batch_offsets ) - 1 )
407407
408408 # group all these indices together and iterate through
409409 # them in batches to grab the proper elements from each
410410 # values tensor
411- chunk = zip (chunk , batch_offsets [:- 1 ], batch_offsets [1 :], batch_nnzs )
411+ chunk = zip (chunk , batch_offsets [:- 1 ], batch_offsets [1 :], batch_row_lengths )
412412
413413 for n , c in enumerate (chunk ):
414414 if isinstance (c , tuple ):
415- c , off0s , off1s , _nnzs = c
415+ c , off0s , off1s , _row_lengths = c
416416 offsets_split_idx = [1 for _ in range (num_list_columns )]
417417 off0s = self ._split_fn (off0s , offsets_split_idx , axis = 1 )
418418 off1s = self ._split_fn (off1s , offsets_split_idx , axis = 1 )
419- if use_nnz :
420- _nnzs = self ._split_fn (_nnzs , offsets_split_idx , axis = 1 )
419+ if use_row_lengths :
420+ _row_lengths = self ._split_fn (_row_lengths , offsets_split_idx , axis = 1 )
421421
422422 # TODO: does this need to be ordereddict?
423423 batch_lists = {}
424424 for k , (column_name , values ) in enumerate (lists .items ()):
425425 off0 , off1 = off0s [k ], off1s [k ]
426- if use_nnz :
427- nnz = _nnzs [k ]
426+ if use_row_lengths :
427+ row_length = _row_lengths [k ]
428428
429429 # need to grab scalars for TF case
430430 if len (off0 .shape ) == 1 :
@@ -435,7 +435,7 @@ def make_tensors(self, gdf, use_nnz=False):
435435 print (off0 , off1 )
436436 raise ValueError
437437 value = values [int (start ) : int (stop )]
438- index = off0 - start if not use_nnz else nnz
438+ index = off0 - start if not use_row_lengths else row_length
439439 batch_lists [column_name ] = (value , index )
440440 c = (c , batch_lists )
441441
@@ -829,7 +829,7 @@ def chunk_logic(self, itr):
829829 chunks = shuffle_df (chunks )
830830
831831 if len (chunks ) > 0 :
832- chunks = self .dataloader .make_tensors (chunks , self .dataloader ._use_nnz )
832+ chunks = self .dataloader .make_tensors (chunks , self .dataloader ._use_row_lengths )
833833 # put returns True if buffer is stopped before
834834 # packet can be put in queue. Keeps us from
835835 # freezing on a put on a full queue
@@ -838,7 +838,7 @@ def chunk_logic(self, itr):
838838 chunks = None
839839 # takes care of the final batch, which is less than the batch size
840840 if not self .dataloader .drop_last and spill is not None and not spill .empty :
841- spill = self .dataloader .make_tensors (spill , self .dataloader ._use_nnz )
841+ spill = self .dataloader .make_tensors (spill , self .dataloader ._use_row_lengths )
842842 self .put (spill )
843843
844844 @annotate ("load_chunks" , color = "darkgreen" , domain = "merlin_dataloader" )
0 commit comments