Skip to content

Commit fd5d3fc

Browse files
committed
Use tf.function for list column operations (#89)
1 parent 42dd301 commit fd5d3fc

File tree

1 file changed

+21
-10
lines changed

1 file changed

+21
-10
lines changed

tests/unit/dataloader/test_tf_dataloader.py

Lines changed: 21 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -94,25 +94,36 @@ def test_nested_list():
94 94
schema = ds.schema
95 95
schema["label"] = schema["label"].with_tags([Tags.TARGET])
96 96
ds.schema = schema
97-
train_dataset = tf_dataloader.Loader(
97+
loader = tf_dataloader.Loader(
98 98
ds,
99 99
batch_size=batch_size,
100 100
shuffle=False,
101 101
)
102 102

103-
batch = next(train_dataset)
103+
batch = next(loader)
104+
104 105
# [[1,2,3],[3,1],[...],[]]
105-
nested_data_col = tf.RaggedTensor.from_row_lengths(
106-
batch[0]["data"][0][:, 0], tf.cast(batch[0]["data"][1][:, 0], tf.int32)
107-
).to_tensor()
106+
@tf.function
107+
def _ragged_for_nested_data_col():
108+
nested_data_col = tf.RaggedTensor.from_row_lengths(
109+
batch[0]["data"][0][:, 0], tf.cast(batch[0]["data"][1][:, 0], tf.int32)
110+
).to_tensor()
111+
return nested_data_col
112+
113+
nested_data_col = _ragged_for_nested_data_col()
108 114
true_data_col = tf.reshape(
109-
tf.ragged.constant(df.iloc[:batch_size, 0].tolist()).to_tensor(),
110-
[batch_size, -1],
115+
tf.ragged.constant(df.iloc[:batch_size, 0].tolist()).to_tensor(), [batch_size, -1]
111 116
)
117+
112 118
# [1,2,3]
113-
multihot_data2_col = tf.RaggedTensor.from_row_lengths(
114-
batch[0]["data2"][0][:, 0], tf.cast(batch[0]["data2"][1][:, 0], tf.int32)
115-
).to_tensor()
119+
@tf.function
120+
def _ragged_for_multihot_data_col():
121+
multihot_data2_col = tf.RaggedTensor.from_row_lengths(
122+
batch[0]["data2"][0][:, 0], tf.cast(batch[0]["data2"][1][:, 0], tf.int32)
123+
).to_tensor()
124+
return multihot_data2_col
125+
126+
multihot_data2_col = _ragged_for_multihot_data_col()
116 127
true_data2_col = tf.reshape(
117 128
tf.ragged.constant(df.iloc[:batch_size, 1].tolist()).to_tensor(),
118 129
[batch_size, -1],

0 commit comments

Comments
 (0)