@@ -94,25 +94,36 @@ def test_nested_list():
9494 schema = ds .schema
9595 schema ["label" ] = schema ["label" ].with_tags ([Tags .TARGET ])
9696 ds .schema = schema
97- train_dataset = tf_dataloader .Loader (
97+ loader = tf_dataloader .Loader (
9898 ds ,
9999 batch_size = batch_size ,
100100 shuffle = False ,
101101 )
102102
103- batch = next (train_dataset )
103+ batch = next (loader )
104+
104105 # [[1,2,3],[3,1],[...],[]]
105- nested_data_col = tf .RaggedTensor .from_row_lengths (
106- batch [0 ]["data" ][0 ][:, 0 ], tf .cast (batch [0 ]["data" ][1 ][:, 0 ], tf .int32 )
107- ).to_tensor ()
106+ @tf .function
107+ def _ragged_for_nested_data_col ():
108+ nested_data_col = tf .RaggedTensor .from_row_lengths (
109+ batch [0 ]["data" ][0 ][:, 0 ], tf .cast (batch [0 ]["data" ][1 ][:, 0 ], tf .int32 )
110+ ).to_tensor ()
111+ return nested_data_col
112+
113+ nested_data_col = _ragged_for_nested_data_col ()
108114 true_data_col = tf .reshape (
109- tf .ragged .constant (df .iloc [:batch_size , 0 ].tolist ()).to_tensor (),
110- [batch_size , - 1 ],
115+ tf .ragged .constant (df .iloc [:batch_size , 0 ].tolist ()).to_tensor (), [batch_size , - 1 ]
111116 )
117+
112118 # [1,2,3]
113- multihot_data2_col = tf .RaggedTensor .from_row_lengths (
114- batch [0 ]["data2" ][0 ][:, 0 ], tf .cast (batch [0 ]["data2" ][1 ][:, 0 ], tf .int32 )
115- ).to_tensor ()
119+ @tf .function
120+ def _ragged_for_multihot_data_col ():
121+ multihot_data2_col = tf .RaggedTensor .from_row_lengths (
122+ batch [0 ]["data2" ][0 ][:, 0 ], tf .cast (batch [0 ]["data2" ][1 ][:, 0 ], tf .int32 )
123+ ).to_tensor ()
124+ return multihot_data2_col
125+
126+ multihot_data2_col = _ragged_for_multihot_data_col ()
116127 true_data2_col = tf .reshape (
117128 tf .ragged .constant (df .iloc [:batch_size , 1 ].tolist ()).to_tensor (),
118129 [batch_size , - 1 ],
0 commit comments