From 7b0bd95549cde17628a5ab5a5766d7682d30adaf Mon Sep 17 00:00:00 2001 From: cmdupuis3 Date: Tue, 16 Nov 2021 20:18:03 +0000 Subject: [PATCH 01/12] Prepend batch dimension if it's not there (partial fix?) --- xbatcher/generators.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/xbatcher/generators.py b/xbatcher/generators.py index 612be61..992c32f 100644 --- a/xbatcher/generators.py +++ b/xbatcher/generators.py @@ -34,7 +34,7 @@ def _iterate_through_dataset(ds, dims, overlap={}): size = dims[dim] olap = overlap.get(dim, 0) dim_slices.append(_slices(dimsize, size, olap)) - + for slices in itertools.product(*dim_slices): selector = {key: slice for key, slice in zip(dims, slices)} yield ds.isel(**selector) @@ -53,11 +53,10 @@ def _drop_input_dims(ds, input_dims, suffix='_input'): out.coords[dim] = newdim, ds[dim].data, ds[dim].attrs return out - def _maybe_stack_batch_dims(ds, input_dims, stacked_dim_name='sample'): batch_dims = [d for d in ds.dims if d not in input_dims] if len(batch_dims) < 2: - return ds + return ds.expand_dims(stacked_dim_name, 0) ds_stack = ds.stack(**{stacked_dim_name: batch_dims}) # ensure correct order dim_order = (stacked_dim_name,) + tuple(input_dims) From 0679395d67b6f731c9115cba7ef0575d4bf815f2 Mon Sep 17 00:00:00 2001 From: cmdupuis3 Date: Tue, 16 Nov 2021 20:20:03 +0000 Subject: [PATCH 02/12] whitespace --- xbatcher/generators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xbatcher/generators.py b/xbatcher/generators.py index 992c32f..5992acd 100644 --- a/xbatcher/generators.py +++ b/xbatcher/generators.py @@ -34,7 +34,7 @@ def _iterate_through_dataset(ds, dims, overlap={}): size = dims[dim] olap = overlap.get(dim, 0) dim_slices.append(_slices(dimsize, size, olap)) - + for slices in itertools.product(*dim_slices): selector = {key: slice for key, slice in zip(dims, slices)} yield ds.isel(**selector) From 1322eff50d00cd4e5f4b4463eb58ab1ca05640f2 Mon Sep 17 00:00:00 2001 From: cmdupuis3 Date: Thu, 18 Nov 2021 19:55:12 +0000 Subject: [PATCH 03/12] Wrap batch dim generation in a flag --- xbatcher/generators.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/xbatcher/generators.py b/xbatcher/generators.py index 5992acd..1496c0b 100644 --- a/xbatcher/generators.py +++ b/xbatcher/generators.py @@ -53,10 +53,13 @@ def _drop_input_dims(ds, input_dims, suffix='_input'): out.coords[dim] = newdim, ds[dim].data, ds[dim].attrs return out -def _maybe_stack_batch_dims(ds, input_dims, stacked_dim_name='sample'): +def _maybe_stack_batch_dims(ds, input_dims, squeeze_batch_dim, stacked_dim_name='sample'): batch_dims = [d for d in ds.dims if d not in input_dims] if len(batch_dims) < 2: - return ds.expand_dims(stacked_dim_name, 0) + if(squeeze_batch_dim): + return ds + else: + return ds.expand_dims(stacked_dim_name, 0) ds_stack = ds.stack(**{stacked_dim_name: batch_dims}) # ensure correct order dim_order = (stacked_dim_name,) + tuple(input_dims) @@ -89,6 +92,10 @@ class BatchGenerator: preload_batch : bool, optional If ``True``, each batch will be loaded into memory before reshaping / processing, triggering any dask arrays to be computed. + squeeze_batch_dim : bool, optional + If ``False", each batch's dataset will have a "batch" dimension of size 1 + prepended to the array. This functionality is useful for interoperability + with Keras / Tensorflow. Yields ------ @@ -104,6 +111,7 @@ def __init__( batch_dims={}, concat_input_dims=False, preload_batch=True, + squeeze_batch_dim=True ): self.ds = _as_xarray_dataset(ds) @@ -113,6 +121,7 @@ def __init__( self.batch_dims = OrderedDict(batch_dims) self.concat_input_dims = concat_input_dims self.preload_batch = preload_batch + self.squeeze_batch_dim = squeeze_batch_dim def __iter__(self): for ds_batch in self._iterate_batch_dims(self.ds): @@ -131,11 +140,11 @@ def __iter__(self): new_input_dims = [ dim + new_dim_suffix for dim in self.input_dims ] - yield _maybe_stack_batch_dims(dsc, new_input_dims) + yield _maybe_stack_batch_dims(dsc, new_input_dims, self.squeeze_batch_dim) else: for ds_input in input_generator: yield _maybe_stack_batch_dims( - ds_input, list(self.input_dims) + ds_input, list(self.input_dims), self.squeeze_batch_dim ) def _iterate_batch_dims(self, ds): From ca34638915d10ea2e08a1ef43f35773bd7c00f65 Mon Sep 17 00:00:00 2001 From: cmdupuis3 Date: Thu, 18 Nov 2021 20:05:26 +0000 Subject: [PATCH 04/12] Linter fixes --- xbatcher/generators.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/xbatcher/generators.py b/xbatcher/generators.py index 1496c0b..6a5bf95 100644 --- a/xbatcher/generators.py +++ b/xbatcher/generators.py @@ -56,7 +56,7 @@ def _drop_input_dims(ds, input_dims, suffix='_input'): def _maybe_stack_batch_dims(ds, input_dims, squeeze_batch_dim, stacked_dim_name='sample'): batch_dims = [d for d in ds.dims if d not in input_dims] if len(batch_dims) < 2: - if(squeeze_batch_dim): + if squeeze_batch_dim: return ds else: return ds.expand_dims(stacked_dim_name, 0) @@ -111,7 +111,7 @@ def __init__( batch_dims={}, concat_input_dims=False, preload_batch=True, - squeeze_batch_dim=True + squeeze_batch_dim=True, ): self.ds = _as_xarray_dataset(ds) @@ -140,7 +140,9 @@ def __iter__(self): new_input_dims = [ dim + new_dim_suffix for dim in self.input_dims ] - yield _maybe_stack_batch_dims(dsc, new_input_dims, self.squeeze_batch_dim) + yield _maybe_stack_batch_dims( + dsc, new_input_dims, self.squeeze_batch_dim + ) else: for ds_input in input_generator: yield _maybe_stack_batch_dims( From 630bb2717b82d64a0d640cf40b730ca52fcdbbb1 Mon Sep 17 00:00:00 2001 From: cmdupuis3 Date: Thu, 18 Nov 2021 20:09:46 +0000 Subject: [PATCH 05/12] really? --- xbatcher/generators.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/xbatcher/generators.py b/xbatcher/generators.py index 6a5bf95..9c105fd 100644 --- a/xbatcher/generators.py +++ b/xbatcher/generators.py @@ -53,7 +53,9 @@ def _drop_input_dims(ds, input_dims, suffix='_input'): out.coords[dim] = newdim, ds[dim].data, ds[dim].attrs return out -def _maybe_stack_batch_dims(ds, input_dims, squeeze_batch_dim, stacked_dim_name='sample'): +def _maybe_stack_batch_dims( + ds, input_dims, squeeze_batch_dim, stacked_dim_name='sample' +): batch_dims = [d for d in ds.dims if d not in input_dims] if len(batch_dims) < 2: if squeeze_batch_dim: From c61846d01a97e6413132c8e7b25ae6f16051dfd0 Mon Sep 17 00:00:00 2001 From: cmdupuis3 Date: Thu, 18 Nov 2021 20:11:59 +0000 Subject: [PATCH 06/12] omg --- xbatcher/generators.py | 1 + 1 file changed, 1 insertion(+) diff --git a/xbatcher/generators.py b/xbatcher/generators.py index 9c105fd..b179fe3 100644 --- a/xbatcher/generators.py +++ b/xbatcher/generators.py @@ -53,6 +53,7 @@ def _drop_input_dims(ds, input_dims, suffix='_input'): out.coords[dim] = newdim, ds[dim].data, ds[dim].attrs return out + def _maybe_stack_batch_dims( ds, input_dims, squeeze_batch_dim, stacked_dim_name='sample' ): From 0e8f7166a5ce56dfee28c46de3dfc01caa87f21d Mon Sep 17 00:00:00 2001 From: cmdupuis3 Date: Thu, 18 Nov 2021 20:45:34 +0000 Subject: [PATCH 07/12] squeeze_batch_dim test sketch --- xbatcher/tests/test_generators.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/xbatcher/tests/test_generators.py b/xbatcher/tests/test_generators.py index 54984b3..d3a2756 100644 --- a/xbatcher/tests/test_generators.py +++ b/xbatcher/tests/test_generators.py @@ -160,3 +160,23 @@ def test_batch_3d_2d_input_concat(sample_ds_3d, bsize): * (sample_ds_3d.dims['y'] // bsize) * sample_ds_3d.dims['time'] ) + + +@pytest.mark.parametrize('bsize', [5, 10]) +def test_batch_3d_squeeze_batch_dim(sample_ds_3d, bsize): + xbsize = 20 + bg = BatchGenerator( + sample_ds_3d, + input_dims={'y': bsize, 'x': xbsize}, + squeeze_batch_dim=False, + ) + for ds_batch in bg: + assert ds_batch['x'].shape == [1, bsize, xbsize] + + bg2 = BatchGenerator( + sample_ds_3d, + input_dims={'y': bsize, 'x': xbsize}, + squeeze_batch_dim=True, + ) + for ds_batch in bg: + assert ds_batch['x'].shape == [bsize, xbsize] \ No newline at end of file From 749ac265634dd6f773bd55e1da78ec7bea62fefa Mon Sep 17 00:00:00 2001 From: cmdupuis3 Date: Fri, 19 Nov 2021 16:33:29 +0000 Subject: [PATCH 08/12] squeeze_batch_dim test (attempt 2) --- xbatcher/tests/test_generators.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xbatcher/tests/test_generators.py b/xbatcher/tests/test_generators.py index d3a2756..5227e77 100644 --- a/xbatcher/tests/test_generators.py +++ b/xbatcher/tests/test_generators.py @@ -167,15 +167,15 @@ def test_batch_3d_squeeze_batch_dim(sample_ds_3d, bsize): xbsize = 20 bg = BatchGenerator( sample_ds_3d, - input_dims={'y': bsize, 'x': xbsize}, + input_dims={'time': 1, 'y': bsize, 'x': xbsize}, squeeze_batch_dim=False, ) for ds_batch in bg: assert ds_batch['x'].shape == [1, bsize, xbsize] - + bg2 = BatchGenerator( sample_ds_3d, - input_dims={'y': bsize, 'x': xbsize}, + input_dims={'time': 1, 'y': bsize, 'x': xbsize}, squeeze_batch_dim=True, ) for ds_batch in bg: From 142031d618988eec0bbd8db3b7ba9ab5dfbdbf5f Mon Sep 17 00:00:00 2001 From: cmdupuis3 Date: Fri, 19 Nov 2021 18:19:49 +0000 Subject: [PATCH 09/12] Fix 1D squeeze_batch_dim test --- xbatcher/tests/test_generators.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/xbatcher/tests/test_generators.py b/xbatcher/tests/test_generators.py index 5227e77..399f02f 100644 --- a/xbatcher/tests/test_generators.py +++ b/xbatcher/tests/test_generators.py @@ -163,20 +163,20 @@ def test_batch_3d_2d_input_concat(sample_ds_3d, bsize): @pytest.mark.parametrize('bsize', [5, 10]) -def test_batch_3d_squeeze_batch_dim(sample_ds_3d, bsize): +def test_batch_1d_squeeze_batch_dim(sample_ds_1d, bsize): xbsize = 20 bg = BatchGenerator( - sample_ds_3d, - input_dims={'time': 1, 'y': bsize, 'x': xbsize}, + sample_ds_1d, + input_dims={'x': xbsize}, squeeze_batch_dim=False, ) for ds_batch in bg: - assert ds_batch['x'].shape == [1, bsize, xbsize] + assert list(ds_batch['foo'].shape) == [1, xbsize] bg2 = BatchGenerator( - sample_ds_3d, - input_dims={'time': 1, 'y': bsize, 'x': xbsize}, + sample_ds_1d, + input_dims={'x': xbsize}, squeeze_batch_dim=True, ) - for ds_batch in bg: - assert ds_batch['x'].shape == [bsize, xbsize] \ No newline at end of file + for ds_batch in bg2: + assert list(ds_batch['foo'].shape) == [xbsize] \ No newline at end of file From da42a9cab573a4c21b89be61bad808deaac8a101 Mon Sep 17 00:00:00 2001 From: cmdupuis3 Date: Fri, 19 Nov 2021 19:03:35 +0000 Subject: [PATCH 10/12] More squeeze_batch_dim tests; fix bug --- xbatcher/generators.py | 13 ++++++---- xbatcher/tests/test_generators.py | 40 +++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 5 deletions(-) diff --git a/xbatcher/generators.py b/xbatcher/generators.py index b179fe3..b522fa9 100644 --- a/xbatcher/generators.py +++ b/xbatcher/generators.py @@ -58,15 +58,18 @@ def _maybe_stack_batch_dims( ds, input_dims, squeeze_batch_dim, stacked_dim_name='sample' ): batch_dims = [d for d in ds.dims if d not in input_dims] - if len(batch_dims) < 2: + if len(batch_dims) == 0: if squeeze_batch_dim: return ds else: return ds.expand_dims(stacked_dim_name, 0) - ds_stack = ds.stack(**{stacked_dim_name: batch_dims}) - # ensure correct order - dim_order = (stacked_dim_name,) + tuple(input_dims) - return ds_stack.transpose(*dim_order) + elif len(batch_dims) == 1: + return ds + else: + ds_stack = ds.stack(**{stacked_dim_name: batch_dims}) + # ensure correct order + dim_order = (stacked_dim_name,) + tuple(input_dims) + return ds_stack.transpose(*dim_order) class BatchGenerator: diff --git a/xbatcher/tests/test_generators.py b/xbatcher/tests/test_generators.py index 9152007..8bf3df8 100644 --- a/xbatcher/tests/test_generators.py +++ b/xbatcher/tests/test_generators.py @@ -189,6 +189,46 @@ def test_batch_1d_squeeze_batch_dim(sample_ds_1d, bsize): assert list(ds_batch['foo'].shape) == [xbsize] +@pytest.mark.parametrize('bsize', [5, 10]) +def test_batch_3d_squeeze_batch_dim(sample_ds_3d, bsize): + xbsize = 20 + bg = BatchGenerator( + sample_ds_3d, + input_dims={'y': bsize, 'x': xbsize}, + squeeze_batch_dim=False, + ) + for ds_batch in bg: + assert list(ds_batch['foo'].shape) == [10, bsize, xbsize] + + bg2 = BatchGenerator( + sample_ds_3d, + input_dims={'y': bsize, 'x': xbsize}, + squeeze_batch_dim=True, + ) + for ds_batch in bg2: + assert list(ds_batch['foo'].shape) == [10, bsize, xbsize] + + +@pytest.mark.parametrize('bsize', [5, 10]) +def test_batch_3d_squeeze_batch_dim2(sample_ds_3d, bsize): + xbsize = 20 + bg = BatchGenerator( + sample_ds_3d, + input_dims={'x': xbsize}, + squeeze_batch_dim=False, + ) + for ds_batch in bg: + assert list(ds_batch['foo'].shape) == [500, xbsize] + + bg2 = BatchGenerator( + sample_ds_3d, + input_dims={'x': xbsize}, + squeeze_batch_dim=True, + ) + for ds_batch in bg2: + assert list(ds_batch['foo'].shape) == [500, xbsize] + + def test_preload_batch_false(sample_ds_1d): sample_ds_1d_dask = sample_ds_1d.chunk({'x': 2}) bg = BatchGenerator( From a54a9b79ca230c32526cfb939873ae6529214360 Mon Sep 17 00:00:00 2001 From: cmdupuis3 Date: Tue, 30 Nov 2021 19:10:08 +0000 Subject: [PATCH 11/12] minor updates --- xbatcher/generators.py | 6 +++--- xbatcher/tests/test_generators.py | 11 +++++------ 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/xbatcher/generators.py b/xbatcher/generators.py index b522fa9..886ac47 100644 --- a/xbatcher/generators.py +++ b/xbatcher/generators.py @@ -99,9 +99,9 @@ class BatchGenerator: If ``True``, each batch will be loaded into memory before reshaping / processing, triggering any dask arrays to be computed. squeeze_batch_dim : bool, optional - If ``False", each batch's dataset will have a "batch" dimension of size 1 - prepended to the array. This functionality is useful for interoperability - with Keras / Tensorflow. + If ``False" and all dims are input dims, each batch's dataset will have a + "batch" dimension of size 1 prepended to the array. This functionality is + useful for interoperability with Keras / Tensorflow. Yields ------ diff --git a/xbatcher/tests/test_generators.py b/xbatcher/tests/test_generators.py index 8bf3df8..05683c0 100644 --- a/xbatcher/tests/test_generators.py +++ b/xbatcher/tests/test_generators.py @@ -169,24 +169,23 @@ def test_batch_3d_2d_input_concat(sample_ds_3d, bsize): ) -@pytest.mark.parametrize('bsize', [5, 10]) +@pytest.mark.parametrize('bsize', [10, 20]) def test_batch_1d_squeeze_batch_dim(sample_ds_1d, bsize): - xbsize = 20 bg = BatchGenerator( sample_ds_1d, - input_dims={'x': xbsize}, + input_dims={'x': bsize}, squeeze_batch_dim=False, ) for ds_batch in bg: - assert list(ds_batch['foo'].shape) == [1, xbsize] + assert list(ds_batch['foo'].shape) == [1, bsize] bg2 = BatchGenerator( sample_ds_1d, - input_dims={'x': xbsize}, + input_dims={'x': bsize}, squeeze_batch_dim=True, ) for ds_batch in bg2: - assert list(ds_batch['foo'].shape) == [xbsize] + assert list(ds_batch['foo'].shape) == [bsize] @pytest.mark.parametrize('bsize', [5, 10]) From f56919055c3c3202ae0c71ec47cef8075484df9b Mon Sep 17 00:00:00 2001 From: cmdupuis3 Date: Tue, 30 Nov 2021 19:46:55 +0000 Subject: [PATCH 12/12] streamline squeeze_batch_dim tests --- xbatcher/generators.py | 2 +- xbatcher/tests/test_generators.py | 18 ++++++++---------- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/xbatcher/generators.py b/xbatcher/generators.py index 886ac47..a257567 100644 --- a/xbatcher/generators.py +++ b/xbatcher/generators.py @@ -99,7 +99,7 @@ class BatchGenerator: If ``True``, each batch will be loaded into memory before reshaping / processing, triggering any dask arrays to be computed. squeeze_batch_dim : bool, optional - If ``False" and all dims are input dims, each batch's dataset will have a + If ``False`` and all dims are input dims, each batch's dataset will have a "batch" dimension of size 1 prepended to the array. This functionality is useful for interoperability with Keras / Tensorflow. diff --git a/xbatcher/tests/test_generators.py b/xbatcher/tests/test_generators.py index 05683c0..1b99083 100644 --- a/xbatcher/tests/test_generators.py +++ b/xbatcher/tests/test_generators.py @@ -190,42 +190,40 @@ def test_batch_1d_squeeze_batch_dim(sample_ds_1d, bsize): @pytest.mark.parametrize('bsize', [5, 10]) def test_batch_3d_squeeze_batch_dim(sample_ds_3d, bsize): - xbsize = 20 bg = BatchGenerator( sample_ds_3d, - input_dims={'y': bsize, 'x': xbsize}, + input_dims={'y': bsize, 'x': bsize}, squeeze_batch_dim=False, ) for ds_batch in bg: - assert list(ds_batch['foo'].shape) == [10, bsize, xbsize] + assert list(ds_batch['foo'].shape) == [10, bsize, bsize] bg2 = BatchGenerator( sample_ds_3d, - input_dims={'y': bsize, 'x': xbsize}, + input_dims={'y': bsize, 'x': bsize}, squeeze_batch_dim=True, ) for ds_batch in bg2: - assert list(ds_batch['foo'].shape) == [10, bsize, xbsize] + assert list(ds_batch['foo'].shape) == [10, bsize, bsize] @pytest.mark.parametrize('bsize', [5, 10]) def test_batch_3d_squeeze_batch_dim2(sample_ds_3d, bsize): - xbsize = 20 bg = BatchGenerator( sample_ds_3d, - input_dims={'x': xbsize}, + input_dims={'x': bsize}, squeeze_batch_dim=False, ) for ds_batch in bg: - assert list(ds_batch['foo'].shape) == [500, xbsize] + assert list(ds_batch['foo'].shape) == [500, bsize] bg2 = BatchGenerator( sample_ds_3d, - input_dims={'x': xbsize}, + input_dims={'x': bsize}, squeeze_batch_dim=True, ) for ds_batch in bg2: - assert list(ds_batch['foo'].shape) == [500, xbsize] + assert list(ds_batch['foo'].shape) == [500, bsize] def test_preload_batch_false(sample_ds_1d):