diff --git a/tests/unit/loader/test_tf_dataloader.py b/tests/unit/loader/test_tf_dataloader.py index b51d068190..6d4e6a0c28 100644 --- a/tests/unit/loader/test_tf_dataloader.py +++ b/tests/unit/loader/test_tf_dataloader.py @@ -382,11 +382,9 @@ def test_mh_support(tmpdir, batch_size): array, offsets = X[f"{mh_name}__values"], X[f"{mh_name}__offsets"] offsets = offsets.numpy() array = array.numpy() - lens = [0] - cur = 0 - for x in multihot_data[mh_name][idx * batch_size : idx * batch_size + n_samples]: - cur += len(x) - lens.append(cur) + m_dta = [len(x) for x in multihot_data[mh_name][idx * batch_size : idx * batch_size + n_samples]] + lens = [0] + np.cumsum(m_dta).tolist() + cur = np.sum(m_dta) assert (offsets == np.array(lens)).all() assert len(array) == max(lens)