Skip to content

Commit 68303f2

Browse files
committed
LLMBlockConcurrency: Test that _sdg_init handles batch_size correctly
#135 Signed-off-by: Gabe Goodhart <[email protected]>
1 parent 470ffbb commit 68303f2

File tree

1 file changed

+44
-0
lines changed

1 file changed

+44
-0
lines changed

tests/test_generate_data.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
"""
2+
Unit tests for the top-level generate_data module.
3+
"""
4+
5+
# Standard
6+
from unittest import mock
7+
8+
# First Party
9+
from instructlab.sdg.generate_data import _sdg_init
10+
from instructlab.sdg.pipeline import PipelineContext
11+
12+
13+
def test_sdg_init_batch_size_optional():
14+
"""Test that the _sdg_init function can handle a missing batch size by
15+
delegating to the default in PipelineContext.
16+
"""
17+
sdgs = _sdg_init(
18+
"simple",
19+
None,
20+
"mixtral",
21+
"foo.bar",
22+
1,
23+
batch_size=None,
24+
batch_num_workers=None,
25+
)
26+
assert all(
27+
pipe.ctx.batch_size == PipelineContext.DEFAULT_BATCH_SIZE
28+
for sdg in sdgs
29+
for pipe in sdg.pipelines
30+
)
31+
32+
33+
def test_sdg_init_batch_size_optional():
34+
"""Test that the _sdg_init function can handle a passed batch size"""
35+
sdgs = _sdg_init(
36+
"simple",
37+
None,
38+
"mixtral",
39+
"foo.bar",
40+
1,
41+
batch_size=20,
42+
batch_num_workers=32,
43+
)
44+
assert all(pipe.ctx.batch_size == 20 for sdg in sdgs for pipe in sdg.pipelines)

0 commit comments

Comments
 (0)