# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
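

# NOTE: the pytest fixture that provides `tokenizer` is defined outside this
# excerpt. The expected IDs below (A=65, C=67, G=71, T=84; BOS=2, EOS=0, PAD=1)
# imply a byte-level scheme in which each base maps to its ASCII code. A
# hypothetical minimal stand-in for the encode path (illustration only, not
# the real tokenizer under test):
class _AsciiEncodeSketch:
    bos_token_id = 2
    eos_token_id = 0
    pad_token_id = 1

    def encode(self, sequence, add_special_tokens=True):
        ids = [ord(base) for base in sequence]  # "A" -> 65, "C" -> 67, ...
        if add_special_tokens:
            ids = [self.bos_token_id, *ids, self.eos_token_id]
        return ids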


def test_tokenizer_encode_simple_sequences(tokenizer):
    """Test encoding simple repeated-character sequences."""
    sequence = "AAAA"
    encoded = tokenizer.encode(sequence, add_special_tokens=True)

    # Expected: BOS + AAAA + EOS = [2, 65, 65, 65, 65, 0]
    expected = [2, 65, 65, 65, 65, 0]
    assert encoded == expected

    sequence = "C"
    encoded = tokenizer.encode(sequence, add_special_tokens=True)

    # Expected: BOS + C + EOS = [2, 67, 0]
    expected = [2, 67, 0]
    assert encoded == expected

    sequence = "G" * 20
    encoded = tokenizer.encode(sequence, add_special_tokens=True)
    expected = [2] + [71] * 20 + [0]
    assert encoded == expected


def test_tokenizer_encode_without_special_tokens(tokenizer):
    """Test encoding without BOS/EOS tokens."""
    sequence = "TTTT"
    encoded = tokenizer.encode(sequence, add_special_tokens=False)

    # Expected: just the Ts (T=84)
    expected = [84, 84, 84, 84]
    assert encoded == expected


def test_tokenizer_roundtrip_encode_decode(tokenizer):
    sequence = "ATCGATCG"
    encoded = tokenizer.encode(sequence, add_special_tokens=True)
    decoded = tokenizer.decode(encoded, skip_special_tokens=True)

    # Decoded may have spaces between tokens, so compare without spaces
    assert sequence == decoded.replace(" ", "")


def test_tokenizer_padding_to_longest(tokenizer):
    """Test padding pads to the longest sequence in the batch."""
    batch = tokenizer(["AAAA", "TTTTTTTT"], padding=True, add_special_tokens=True, return_tensors="pt")

    # AAAA → [2, 65, 65, 65, 65, 0] = 6 tokens
    # TTTTTTTT → [2, 84, 84, 84, 84, 84, 84, 84, 84, 0] = 10 tokens
    # Should pad to 10
    assert batch["input_ids"].shape == torch.Size([2, 10])

    # First sequence should have padding (PAD=1)
    assert batch["input_ids"][0, 6].item() == 1  # First padding position
    assert batch["input_ids"][0, 9].item() == 1  # Last padding position

    # Attention mask: 1 for real tokens, 0 for padding
    assert batch["attention_mask"][0, 5].item() == 1  # Last real token
    assert batch["attention_mask"][0, 6].item() == 0  # First padding
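    # For reference, the padded first row is then expected to be
    # [2, 65, 65, 65, 65, 0, 1, 1, 1, 1]: BOS, four As, EOS, four PADs.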


def test_tokenizer_attention_mask_correct(tokenizer):
    """Test attention mask is 1 for real tokens, 0 for padding."""
    batch = tokenizer(["GG", "GGGGGG"], padding=True, add_special_tokens=True, return_tensors="pt")

    # GG → 4 tokens (BOS + GG + EOS)
    # GGGGGG → 8 tokens (BOS + GGGGGG + EOS)
    # Padded to 8 tokens

    # First sequence: 4 real + 4 padding
    expected_mask_0 = [1, 1, 1, 1, 0, 0, 0, 0]
    assert batch["attention_mask"][0].tolist() == expected_mask_0

    # Second sequence: all real
    expected_mask_1 = [1, 1, 1, 1, 1, 1, 1, 1]
    assert batch["attention_mask"][1].tolist() == expected_mask_1


def test_tokenizer_mixed_nucleotides(tokenizer):
    """Test all standard nucleotides encode correctly."""
    sequence = "ATCGGTC"
    encoded = tokenizer.encode(sequence, add_special_tokens=False)

    # A=65, T=84, C=67, G=71
    # ATCGGTC = A, T, C, G, G, T, C
    expected = [65, 84, 67, 71, 71, 84, 67]
    assert encoded == expected


def test_10kbp_sequence_creates_expected_window_count(tokenizer):
    """Test 10kbp sequence creates correct number of windows with seq_length=1000, stride=800.

    Verifies windowing math: 10000bp with seq_length=1000, stride=800.
    """
    sequence = "A" * 10000  # 10kbp

    result = tokenizer(
        sequence,
        max_length=1000,
        stride=800,  # 800 token overlap
        truncation=True,
        return_overflowing_tokens=True,
        add_special_tokens=True,
    )

    # Hardcoded expectation based on input data:
    # 10000bp with 1000 token windows and 800 token stride
    # Step forward = 1000 - 800 = 200 tokens per window
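    # In general, n_windows ≈ ceil((total_tokens - window) / step) + 1, so a
    # smaller step (i.e. more overlap) yields more windows over the same input;
    # the exact count also depends on how BOS/EOS are accounted per window.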


def test_overlapping_windows_creates_more_samples(tokenizer):
    """Test overlapping stride creates more windows than less overlapping."""
    sequence = "ATCG" * 2500  # 10kbp

    result_more_overlap = tokenizer(
        sequence,
        max_length=1000,
        # ... (stride and truncation arguments not shown in this excerpt)
        return_overflowing_tokens=True,
        add_special_tokens=True,
    )

    result_less_overlap = tokenizer(
        sequence,
        max_length=1000,
        # ... (stride and truncation arguments not shown in this excerpt)
        return_overflowing_tokens=True,
        add_special_tokens=True,
    )

    # Hardcoded expectations
    assert len(result_more_overlap["input_ids"]) == 47  # With more overlap (smaller step)
    assert len(result_less_overlap["input_ids"]) == 20  # With less overlap (larger step)
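    # 47 > 20 follows directly from the window arithmetic above: the first call
    # uses a smaller step (window minus overlap), so covering the same 10kbp
    # input requires more windows.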


def test_production_window_length_creates_expected_samples(tokenizer):
    """Test production settings (8192 window, 200 overlap) create correct number of windows."""
    sequence = "A" * 50000  # 50kbp sequence

    result = tokenizer(
        sequence,
        max_length=8192,
        stride=200,  # 200 token overlap
        truncation=True,
        return_overflowing_tokens=True,
        add_special_tokens=True,
    )

    # Hardcoded expectation with production settings:
    # 50000bp with 8192 window and 200 stride (overlap)
    # Step forward = 8192 - 200 = 7992 tokens per window
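    # By that arithmetic, the 50000 - 8192 = 41808 tokens beyond the first
    # window require ceil(41808 / 7992) = 6 further windows.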


def test_short_sequences_dont_overflow(tokenizer):
    """Test that short sequences (< max_length) don't create overflow windows."""
    sequence = "ATCG" * 100  # 400bp

    result = tokenizer(
        sequence,
        max_length=1000,
        # ... (stride argument not shown in this excerpt)
        truncation=True,
        return_overflowing_tokens=True,
        add_special_tokens=True,
    )

    # Sequence is shorter than max_length, should only create 1 window
    assert len(result["input_ids"]) == 1
    # Length should be 400bp + BOS + EOS = 402 tokens


def test_bos_eos_in_overlapping_windows(tokenizer):
    """Test that BOS/EOS tokens are added to every overlapping window.

    Verifies that when using return_overflowing_tokens with add_special_tokens=True,
    each window gets its own BOS and EOS tokens, treating each as an independent sequence.
    This matches the behavior needed for causal language modeling training.
    """
    # Use a short genomic sequence that will produce exactly 2 overlapping windows
    # With max_length=7 and stride=4, sequence of 8bp should give 2 windows
    sequence = "ATCGATCG"  # 8bp

    result = tokenizer(
        sequence,
        max_length=7,  # BOS + 5 content + EOS = 7 tokens total
        stride=4,  # Overlap of 4 tokens between windows
        truncation=True,
        return_overflowing_tokens=True,
        add_special_tokens=True,
    )

    # Should produce at least 2 windows
    num_windows = len(result["input_ids"])
    assert num_windows >= 2, f"Should produce at least 2 overlapping windows, got {num_windows}"

    first_window = result["input_ids"][0]
    second_window = result["input_ids"][1]

    # Verify both windows have BOS at start and EOS at end
    assert first_window[0] == tokenizer.bos_token_id
    assert first_window[-1] == tokenizer.eos_token_id
    assert second_window[0] == tokenizer.bos_token_id
    assert second_window[-1] == tokenizer.eos_token_id

    # Verify windows are actually overlapping by checking they share some content
    first_content = set(first_window[1:-1])
    second_content = set(second_window[1:-1])
    assert len(first_content & second_content) > 0
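    # Note: the set-based check above is a loose proxy for overlap. Any two
    # windows drawn from the same ACGT alphabet share token values, so it
    # cannot distinguish true positional overlap from mere alphabet reuse.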