Commit 2871cb2

Merge pull request #17 from pfizer-opensource/feature/zero_is_not_nan
Adding feature to use custom "default_value" as filler for intervals …
2 parents b026d84 + 499b766 commit 2871cb2

9 files changed: +112 -19 lines changed
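
In short: every position in a returned batch that no bigwig interval covers is now filled with a configurable default_value instead of a hard-coded 0.0. A minimal usage sketch (not part of the diff below; the directory path is a placeholder, and BigWigCollection is assumed to accept a directory of bigwig files, as elsewhere in this repo):

import cupy as cp
from bigwig_loader.collection import BigWigCollection

collection = BigWigCollection("/data/bigwig_tracks")  # placeholder path

# With default_value=cp.nan, "no interval here" becomes NaN in the output,
# distinguishable from a genuine 0.0 stored in the file.
batch = collection.get_batch(
    ["chr1", "chr2"],
    [0, 1_000_000],
    [1_000, 1_001_000],
    default_value=cp.nan,
)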

CHANGELOG.md

Lines changed: 7 additions & 2 deletions
@@ -9,6 +9,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## [Unreleased]
 ...

+## [v0.1.5]
+### Added
+- set a default value different from 0.0
+
 ## [v0.1.4]
 ### Fixed
 - Updated README.md with better install instructions. Making release so
@@ -43,8 +47,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Added
 - release to pypi

-[Unreleased]: https://github.com/pfizer-opensource/bigwig-loader/compare/v0.1.4...HEAD
-[v0.1.3]: https://github.com/pfizer-opensource/bigwig-loader/compare/v0.1.4...v0.1.4
+[Unreleased]: https://github.com/pfizer-opensource/bigwig-loader/compare/v0.1.5...HEAD
+[v0.1.5]: https://github.com/pfizer-opensource/bigwig-loader/compare/v0.1.4...v0.1.5
+[v0.1.4]: https://github.com/pfizer-opensource/bigwig-loader/compare/v0.1.3...v0.1.4
 [v0.1.3]: https://github.com/pfizer-opensource/bigwig-loader/compare/v0.1.2...v0.1.3
 [v0.1.2]: https://github.com/pfizer-opensource/bigwig-loader/compare/v0.1.1...v0.1.2
 [v0.1.1]: https://github.com/pfizer-opensource/bigwig-loader/compare/v0.1.0...v0.1.1

bigwig_loader/batch_processor.py

Lines changed: 2 additions & 1 deletion
@@ -104,6 +104,7 @@ def get_batch(
         end: Union[Sequence[int], npt.NDArray[np.int64]],
         window_size: int = 1,
         scaling_factors_cupy: Optional[cp.ndarray] = None,
+        default_value: float = 0.0,
         out: Optional[cp.ndarray] = None,
     ) -> cp.ndarray:
         (
@@ -137,10 +138,10 @@ def get_batch(
             query_starts=abs_start,
             query_ends=abs_end,
             window_size=window_size,
+            default_value=default_value,
             out=out,
         )
         batch = cp.transpose(out, (1, 0, 2))
         if scaling_factors_cupy is not None:
             batch *= scaling_factors_cupy
-
         return batch
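
Worth noting: get_batch applies the scaling factors after the fill, which still behaves correctly for a NaN filler because NaN propagates through multiplication. A standalone check of that property:

import cupy as cp

out = cp.full((2, 4), cp.nan, dtype=cp.float32)  # buffer seeded with the filler
out[0] = 1.0                                     # pretend track 0 received real data
out *= 2.0                                       # stand-in for scaling_factors_cupy

print(out)  # track 0 scaled to 2.0; track 1 stays NaN (IEEE 754: nan * x == nan)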

bigwig_loader/collection.py

Lines changed: 4 additions & 1 deletion
@@ -140,6 +140,7 @@ def get_batch(
         start: Union[Sequence[int], npt.NDArray[np.int64]],
         end: Union[Sequence[int], npt.NDArray[np.int64]],
         window_size: int = 1,
+        default_value: float = 0.0,
         out: Optional[cp.ndarray] = None,
     ) -> cp.ndarray:
         return self.batch_processor.get_batch(
@@ -148,6 +149,7 @@ def get_batch(
             end=end,
             window_size=window_size,
             scaling_factors_cupy=self.scaling_factors_cupy,
+            default_value=default_value,
             out=out,
         )

@@ -171,7 +173,8 @@ def make_positions_global(

         """
         offsets = np.array(
-            [self.chromosome_offset_dict[chrom] for chrom in chromosomes]
+            [self.chromosome_offset_dict[chrom] for chrom in chromosomes],
+            dtype=np.int64,
         )
         return positions + offsets

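The dtype=np.int64 pin in make_positions_global is a small correctness fix riding along with the feature: np.array over a list of Python ints otherwise picks a platform-dependent integer width (int32 on Windows), which is too narrow for global genome coordinates. A rough illustration with a made-up offset dict:

import numpy as np

chromosome_offset_dict = {"chr1": 0, "chr2": 248_956_422}  # illustrative offsets
positions = np.array([10, 20], dtype=np.int64)

offsets = np.array(
    [chromosome_offset_dict[chrom] for chrom in ["chr2", "chr2"]],
    dtype=np.int64,  # explicit, so the addition below never runs in int32
)
print(positions + offsets)  # [248956432 248956442]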

bigwig_loader/dataset.py

Lines changed: 12 additions & 5 deletions
@@ -3,6 +3,7 @@
 from typing import Any
 from typing import Callable
 from typing import Iterator
+from typing import Literal
 from typing import Optional
 from typing import Sequence
 from typing import Union
@@ -38,16 +39,16 @@ class BigWigDataset:
         reference_genome_path: path to fasta file containing the reference genome.
         sequence_length: number of base pairs in input sequence
         center_bin_to_predict: if given, only do prediction on a central window. Should be
-            smaller than or equal to sequence_length. If not given will be the same as
-            sequence_length.
+            smaller than or equal to sequence_length. If None, the whole sequence length
+            will be used. Default: None
         window_size: used to down sample the resolution of the target from sequence_length
         moving_average_window_size: window size for moving average on the target. Can
             help too smooth out the target. Default: 1, which means no smoothing. If
             used in combination with window_size, the target is first downsampled and
             then smoothed.
         batch_size: batch size
         super_batch_size: batch size that is used in the background to load data from
-            bigwig files. Should be larget than or equal to batch_size. If None, it will
+            bigwig files. Should be larger than or equal to batch_size. If None, it will
             be equal to batch_size.
         batches_per_epoch: because the length of an epoch is slightly arbitrary here,
             the number of batches can be set by hand. If not the number of batches per
@@ -61,6 +62,9 @@ class BigWigDataset:
             If None, no scaling is done. Keys can be (partial) file paths. See
             bigwig_loader.path.match_key_to_path for more information about how
             dict keys are mapped to paths.
+        default_value: value to use for intervals that are not present in the
+            bigwig file. Defaults to 0.0. Can be set to cp.nan to differentiate
+            between missing values listed as 0.0.
         first_n_files: Only use the first n files (handy for debugging on less tasks)
         position_sampler_buffer_size: number of intervals picked up front by the position sampler.
             When all intervals are used, new intervals are picked.
@@ -73,7 +77,7 @@ class BigWigDataset:
         n_threads: number of python threads / cuda streams to use for loading the data to
             GPU. More threads means that more IO can take place while the GPU is busy doing
             calculations (decompressing or neural network training for example). More threads
-            also means a higher GPU memory usage.
+            also means a higher GPU memory usage. Default: 4
         return_batch_objects: if True, the batches will be returned as instances of
             bigwig_loader.batch.Batch
     """
@@ -92,11 +96,12 @@ def __init__(
         batches_per_epoch: Optional[int] = None,
         maximum_unknown_bases_fraction: float = 0.1,
         sequence_encoder: Optional[
-            Union[Callable[[Sequence[str]], Any], str]
+            Union[Callable[[Sequence[str]], Any], Literal["onehot"]]
         ] = "onehot",
         file_extensions: Sequence[str] = (".bigWig", ".bw"),
         crawl: bool = True,
         scale: Optional[dict[Union[str | Path], Any]] = None,
+        default_value: float = 0.0,
         first_n_files: Optional[int] = None,
         position_sampler_buffer_size: int = 100000,
         repeat_same_positions: bool = False,
@@ -139,6 +144,7 @@ def __init__(
         self._first_n_files = first_n_files
         self._file_extensions = file_extensions
         self._crawl = crawl
+        self._default_value = default_value
         self._scale = scale
         self._position_sampler_buffer_size = position_sampler_buffer_size
         self._repeat_same_positions = repeat_same_positions
@@ -181,6 +187,7 @@ def _create_dataloader(self) -> StreamedDataloader:
             queue_size=self._n_threads + 1,
             slice_size=self.batch_size,
             window_size=self.window_size,
+            default_value=self._default_value,
         )

     def __iter__(
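
With the keyword threaded through to the dataloader, the NaN filler can be requested at construction time. A hedged sketch (paths are placeholders, and the regions_of_interest table with chrom/start/end columns is an assumption based on the project README):

import cupy as cp
import pandas as pd
from bigwig_loader.dataset import BigWigDataset

# Hypothetical sampling intervals; column names are assumed.
train_regions = pd.DataFrame(
    {"chrom": ["chr1", "chr2"], "start": [0, 0], "end": [248_000_000, 242_000_000]}
)

dataset = BigWigDataset(
    regions_of_interest=train_regions,
    collection="/data/bigwig_tracks",        # placeholder path
    reference_genome_path="/data/hg38.fa",   # placeholder path
    sequence_length=1000,
    batch_size=256,
    default_value=cp.nan,                    # new in this commit; previously always 0.0
)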

bigwig_loader/intervals_to_values.py

Lines changed: 8 additions & 5 deletions
@@ -8,8 +8,6 @@

 CUDA_KERNEL_DIR = Path(__file__).parent.parent / "cuda_kernels"

-_zero = cp.asarray(0.0, dtype=cp.float32).item()
-

 def get_cuda_kernel() -> str:
     with open(CUDA_KERNEL_DIR / "intervals_to_values.cu") as f:
@@ -31,6 +29,7 @@ def intervals_to_values(
     found_ends: cp.ndarray | None = None,
     sizes: cp.ndarray | None = None,
     window_size: int = 1,
+    default_value: float = 0.0,
     out: cp.ndarray | None = None,
 ) -> cp.ndarray:
     """
@@ -58,6 +57,8 @@ def intervals_to_values(
         sizes: number of elements in track_starts/track_ends/track_values for each track.
             Only needed when found_starts and found_ends are not given.
         window_size: size in basepairs to average over (default: 1)
+        default_value: value to use for regions where no data is specified (default: 0.0)
+        out: array of size n_tracks x batch_size x sequence_length to store the output.
     Returns:
         out: array of size n_tracks x batch_size x sequence_length

@@ -85,12 +86,15 @@ def intervals_to_values(
     )

     if out is None:
-        out = cp.zeros(
+        out = cp.full(
             (found_starts.shape[0], len(query_starts), sequence_length // window_size),
+            default_value,
             dtype=cp.float32,
         )
     else:
-        out *= _zero
+        logging.debug(f"Setting default value in output tensor to {default_value}")
+        out.fill(default_value)
+        logging.debug(out)

     max_number_intervals = min(
         sequence_length, (found_ends - found_starts).max().item()
@@ -136,7 +140,6 @@ def intervals_to_values(
             out,
         ),
     )
-
     return out

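This is the heart of the change. The old reset of a reused buffer, out *= _zero, cannot express a non-zero filler and would not even clear a buffer that already contains NaN (nan * 0 is still nan); cp.full and ndarray.fill sidestep both problems. In isolation:

import cupy as cp

default_value = float("nan")

# Fresh allocation, as in the `if out is None` branch.
out = cp.full((1, 2, 8), default_value, dtype=cp.float32)

# Reuse, as in the `else` branch: refill a caller-provided buffer in place.
reserved = cp.zeros((1, 2, 8), dtype=cp.float32)
reserved.fill(default_value)

# Whatever the CUDA kernel writes later overwrites the filler; every
# position it never touches keeps default_value.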

bigwig_loader/pytorch.py

Lines changed: 21 additions & 4 deletions
@@ -124,8 +124,8 @@ class PytorchBigWigDataset(IterableDataset[BATCH_TYPE]):
         reference_genome_path: path to fasta file containing the reference genome.
         sequence_length: number of base pairs in input sequence
         center_bin_to_predict: if given, only do prediction on a central window. Should be
-            smaller than or equal to sequence_length. If not given will be the same as
-            sequence_length.
+            smaller than or equal to sequence_length. If None, the whole sequence length
+            will be used. Default: None
         window_size: used to down sample the resolution of the target from sequence_length
         moving_average_window_size: window size for moving average on the target. Can
             help too smooth out the target. Default: 1, which means no smoothing. If
@@ -141,8 +141,21 @@ class PytorchBigWigDataset(IterableDataset[BATCH_TYPE]):
         maximum_unknown_bases_fraction: maximum number of bases in an input sequence that
             is unknown.
         sequence_encoder: encoder to apply to the sequence. Default: bigwig_loader.util.onehot_sequences
-        position_samples_buffer_size: number of intervals picked up front by the position sampler.
+        file_extensions: load files with these extensions (default .bw and .bigWig)
+        crawl: whether to search in sub-directories for BigWig files
+        scale: Optional, dictionary with scaling factors for each BigWig file.
+            If None, no scaling is done. Keys can be (partial) file paths. See
+            bigwig_loader.path.match_key_to_path for more information about how
+            dict keys are mapped to paths.
+        default_value: value to use for intervals that are not present in the
+            bigwig file. Defaults to 0.0. Can be set to cp.nan to differentiate
+            between missing values listed as 0.0.
+        first_n_files: Only use the first n files (handy for debugging on less tasks)
+        position_sampler_buffer_size: number of intervals picked up front by the position sampler.
             When all intervals are used, new intervals are picked.
+        repeat_same_positions: if True the positions sampler does not draw a new random collection
+            of positions when the buffer runs out, but repeats the same samples. Can be used to
+            check whether network can overfit.
         sub_sample_tracks: int, if set a different random set of tracks is selected in each
             superbatch from the total number of tracks. The indices corresponding to those tracks
             are returned in the output.
@@ -160,7 +173,7 @@ def __init__(
         collection: Union[str, Sequence[str], Path, Sequence[Path], BigWigCollection],
         reference_genome_path: Path,
         sequence_length: int = 1000,
-        center_bin_to_predict: Optional[int] = 200,
+        center_bin_to_predict: Optional[int] = None,
         window_size: int = 1,
         moving_average_window_size: int = 1,
         batch_size: int = 256,
@@ -172,6 +185,8 @@ def __init__(
         ] = "onehot",
         file_extensions: Sequence[str] = (".bigWig", ".bw"),
         crawl: bool = True,
+        scale: Optional[dict[Union[str | Path], Any]] = None,
+        default_value: float = 0.0,
         first_n_files: Optional[int] = None,
         position_sampler_buffer_size: int = 100000,
         repeat_same_positions: bool = False,
@@ -195,6 +210,8 @@ def __init__(
             sequence_encoder=sequence_encoder,
             file_extensions=file_extensions,
             crawl=crawl,
+            scale=scale,
+            default_value=default_value,
             first_n_files=first_n_files,
             position_sampler_buffer_size=position_sampler_buffer_size,
             repeat_same_positions=repeat_same_positions,
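
Once batches can carry NaN to mean "no data", training code has to mask those positions before computing a loss, or a single NaN poisons the gradient. A hedged sketch of one way to do that on the PyTorch side (not part of this commit):

import torch

def masked_mse(prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    mask = ~torch.isnan(target)  # True where the bigwig actually had data
    return torch.mean((prediction[mask] - target[mask]) ** 2)

pred = torch.rand(2, 3)
tgt = torch.tensor([[1.0, float("nan"), 0.0], [float("nan"), 2.0, 3.0]])
print(masked_mse(pred, tgt))  # finite: the NaN-filled targets are excluded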

bigwig_loader/streamed_dataset.py

Lines changed: 4 additions & 1 deletion
@@ -108,6 +108,7 @@ def __init__(
         queue_size: int = 10,
         slice_size: int | None = None,
         window_size: int = 1,
+        default_value: float = 0.0,
     ):
         self.input_generator = input_generator
         self.collection = collection
@@ -127,6 +128,7 @@ def __init__(
         self.data_generator_thread: threading.Thread | None = None
         self._entered = False
         self._out = None
+        self._default_value = default_value

     def __enter__(self) -> "StreamedDataloader":
         self._entered = True
@@ -214,6 +216,7 @@ def _generate_batches(self) -> Generator[Batch, None, None]:
                 found_starts=found_starts[:, select],
                 found_ends=found_ends[:, select],
                 window_size=self.window_size,
+                default_value=self._default_value,
                 out=out,
             )

@@ -248,7 +251,7 @@ def _get_out_tensor(
         shape = (number_of_tracks, batch_size, sequence_length)

         if self._out is None or self._out.shape != shape:
-            self._out = cp.zeros(shape, dtype=cp.float32)
+            self._out = cp.full(shape, self._default_value, dtype=cp.float32)
         return self._out

     def _determine_slice_size(self, n_samples: int) -> int:
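
_get_out_tensor caches one output buffer and reallocates only when the requested shape changes; stale values cannot leak between batches because intervals_to_values refills the buffer with default_value on every call (the out.fill above). A toy version of the caching pattern:

import cupy as cp

class OutCache:
    """Minimal stand-in for StreamedDataloader's output-buffer reuse."""

    def __init__(self, default_value: float = 0.0) -> None:
        self._default_value = default_value
        self._out: cp.ndarray | None = None

    def get(self, shape: tuple[int, ...]) -> cp.ndarray:
        # Reallocate only on a shape change; otherwise hand back the same buffer.
        if self._out is None or self._out.shape != shape:
            self._out = cp.full(shape, self._default_value, dtype=cp.float32)
        return self._out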

tests/test_collection.py

Lines changed: 23 additions & 0 deletions
@@ -57,6 +57,29 @@ def test_get_batch(collection):
     assert batch.shape == (256, n_files, 1000)


+def test_get_batch_with_nans(collection):
+    """
+    Testing whether NaNs are returned when setting
+    the default_value to cp.nan. Giving it some
+    coordinates way beyond the total chromosome
+    length now because the bigwigs I am testing
+    on include intervals with value 0.0 instead
+    of just not listing it.
+    """
+    n_files = len(collection.bigwig_paths)
+
+    batch = collection.get_batch(
+        ["chr1", "chr20", "chr4"],
+        [0, 99998999, 99998999],
+        [1000, 99999999, 99999999],
+        default_value=cp.nan,
+    )
+
+    print(batch)
+    assert batch.shape == (3, n_files, 1000)
+    assert cp.any(cp.isnan(batch))
+
+
 def test_exclude_intervals(collection):
     intervals = collection.intervals(
         ["chr3", "chr4", "chr5"], exclude_chromosomes=["chr4"]

tests/test_intervals_to_values.py

Lines changed: 31 additions & 0 deletions
@@ -230,3 +230,34 @@ def test_get_values_from_intervals_batch_multiple_tracks() -> None:
     print(expected)
     print(values)
     assert (values == expected).all()
+
+
+def test_default_nan() -> None:
+    """Query end is exactly at end index before "gap"
+    Now instead of zeros, NaN values should be
+    used.
+    ."""
+    track_starts = cp.asarray([5, 10, 12, 18], dtype=cp.int32)
+    track_ends = cp.asarray([10, 12, 14, 20], dtype=cp.int32)
+    track_values = cp.asarray([20.0, 30.0, 40.0, 50.0], dtype=cp.dtype("f4"))
+    query_starts = cp.asarray([7, 9], dtype=cp.int32)
+    query_ends = cp.asarray([18, 20], dtype=cp.int32)
+    reserved = cp.zeros([2, 11], dtype=cp.dtype("<f4"))
+    values = intervals_to_values(
+        track_starts,
+        track_ends,
+        track_values,
+        query_starts,
+        query_ends,
+        default_value=cp.nan,
+        out=reserved,
+    )
+    expected = cp.asarray(
+        [
+            [20.0, 20.0, 20.0, 30.0, 30.0, 40.0, 40.0, cp.nan, cp.nan, cp.nan, cp.nan],
+            [20.0, 30.0, 30.0, 40.0, 40.0, cp.nan, cp.nan, cp.nan, cp.nan, 50.0, 50.0],
+        ]
+    )
+    print(expected)
+    print(values)
+    assert cp.allclose(values, expected, equal_nan=True)
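
One detail in the assertion deserves a note: the earlier tests compare with (values == expected).all(), but that can never succeed once NaN is involved, since IEEE 754 defines NaN as unequal to everything, itself included. Hence allclose with equal_nan=True:

import cupy as cp

a = cp.asarray([1.0, cp.nan], dtype=cp.float32)
print(bool((a == a).all()))                     # False: nan != nan
print(bool(cp.allclose(a, a, equal_nan=True)))  # True: NaNs compared as equal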
