Skip to content

Commit 18c3a03

Browse files
committed
Add validation for empty dataset and enhance oneshot function parameters
Signed-off-by: Arka Sanka <[email protected]>
1 parent 1c85a66 commit 18c3a03

File tree

3 files changed

+72
-9
lines changed

3 files changed

+72
-9
lines changed

src/llmcompressor/datasets/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,12 @@ def format_calibration_data(
149149
f"the provided dataset only has {safe_calibration_samples}. "
150150
)
151151

152+
if safe_calibration_samples == 0:
153+
logger.error("Dataset is empty. Cannot create a calibration dataloader.")
154+
raise ValueError(
155+
"Dataset is empty. Cannot create a calibration dataloader with 0 samples."
156+
)
157+
152158
if do_shuffle:
153159
tokenized_dataset = tokenized_dataset.shuffle()
154160
tokenized_calibration = tokenized_dataset.select(range(safe_calibration_samples))

src/llmcompressor/entrypoints/oneshot.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import os
1313
from datetime import datetime
1414
from pathlib import Path
15-
from typing import TYPE_CHECKING
15+
from typing import TYPE_CHECKING, Any, Callable
1616

1717
from loguru import logger
1818
from torch.utils.data import DataLoader
@@ -260,8 +260,16 @@ def oneshot(
260260
preprocessing_num_workers: int | None = None,
261261
min_tokens_per_module: float | None = None,
262262
moe_calibrate_all_experts: bool = True,
263+
pipeline: str = "independent",
264+
tracing_ignore: list[str] | None = None,
265+
raw_kwargs: dict[str, Any] | None = None,
266+
preprocessing_func: Callable | None = None,
267+
max_train_samples: int | None = None,
268+
remove_columns: list[str] | None = None,
269+
dvc_data_repository: str | None = None,
263270
quantization_aware_calibration: bool = True,
264-
# Miscellaneous arguments
271+
sequential_targets: list[str] | None = None,
272+
# Miscellaneous arguments
265273
output_dir: str | None = None,
266274
log_dir: str | None = None,
267275
**kwargs,
@@ -331,6 +339,16 @@ def oneshot(
331339
during forward pass in calibration. When False, quantization is disabled
332340
during forward pass in calibration. Default is set to True.
333341
342+
:param pipeline: The pipeline configuration to use for calibration. Options include
343+
'independent', 'sequential', or 'layer_sequential'.
344+
:param tracing_ignore: List of module names to ignore during tracing.
345+
:param raw_kwargs: Dictionary of raw keyword arguments passed to the function.
346+
:param preprocessing_func: Optional callable for preprocessing the dataset.
347+
:param max_train_samples: Maximum number of training samples to use.
348+
:param remove_columns: List of column names to remove from the dataset.
349+
:param dvc_data_repository: Path to the DVC data repository, if applicable.
350+
:param sequential_targets: List of sequential targets for calibration.
351+
334352
# Miscellaneous arguments
335353
:param output_dir: Path to save the output model after calibration.
336354
Nothing is saved if None.
@@ -340,10 +358,18 @@ def oneshot(
340358
:return: The calibrated PreTrainedModel
341359
"""
342360

343-
# pass all args directly into Oneshot
361+
if sequential_targets and pipeline == "independent":
362+
raise ValueError(
363+
"Invalid configuration: "
364+
"sequential_targets' cannot be used with 'independent' pipeline. "
365+
"Please use 'sequential' or 'layer_sequential' pipeline when specifying "
366+
"sequential_targets."
367+
)
368+
344369
local_args = {
345370
k: v for k, v in locals().items() if k not in ("local_args", "kwargs")
346371
}
372+
347373
one_shot = Oneshot(**local_args, **kwargs)
348374
one_shot()
349375

tests/llmcompressor/transformers/oneshot/test_api_inputs.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
1+
import logging
2+
13
import pytest
24
from transformers import AutoModelForCausalLM, AutoTokenizer
35

46
from llmcompressor import oneshot
57
from tests.llmcompressor.transformers.oneshot.dataset_processing import get_data_utils
68
from tests.testing_utils import parse_params
79

10+
logging.basicConfig(level=logging.INFO)
11+
logger = logging.getLogger(__name__)
12+
813
CONFIGS_DIRECTORY = "tests/llmcompressor/transformers/oneshot/oneshot_configs"
914

1015
# TODO: Seems better to mark test type (smoke, sanity, regression) as a marker as
@@ -42,15 +47,41 @@ def wrapped_preprocess_func(sample):
4247
dataset_config_name=config.get("dataset_config_name"),
4348
)
4449

50+
args["pipeline"] = config.get("pipeline", "independent")
51+
args["sequential_targets"] = config.get("sequential_targets", None)
52+
args["tracing_ignore"] = config.get("tracing_ignore", [])
53+
args["raw_kwargs"] = config.get("raw_kwargs", {})
54+
args["preprocessing_func"] = config.get("preprocessing_func", lambda x: x)
55+
args["max_train_samples"] = config.get("max_train_samples", 50)
56+
args["remove_columns"] = config.get("remove_columns", None)
57+
args["dvc_data_repository"] = config.get("dvc_data_repository", None)
58+
args["splits"] = config.get("splits", {"calibration": "train[:50]"})
59+
args["log_dir"] = config.get("log_dir", "sparse_logs")
60+
4561
return args
4662

4763

4864
@pytest.mark.smoke
4965
@pytest.mark.integration
5066
def test_one_shot_inputs(one_shot_args, tmp_path):
51-
oneshot(
52-
**one_shot_args,
53-
output_dir=tmp_path,
54-
num_calibration_samples=10,
55-
pad_to_max_length=False,
56-
)
67+
logger.info(f"Dataset type: {type(one_shot_args.get('dataset'))}")
68+
if isinstance(one_shot_args.get("dataset"), str):
69+
logger.info(f"Dataset name: {one_shot_args.get('dataset')}")
70+
logger.info(f"Dataset config: {one_shot_args.get('dataset_config_name')}")
71+
try:
72+
# Call oneshot with all parameters as flat arguments
73+
oneshot(
74+
**one_shot_args,
75+
output_dir=tmp_path,
76+
num_calibration_samples=10,
77+
pad_to_max_length=False,
78+
)
79+
80+
except ValueError as e:
81+
if "num_samples should be a positive integer value" in str(
82+
e
83+
) or "Dataset is empty. Cannot create a calibration dataloader" in str(e):
84+
logger.warning(f"Dataset is empty: {one_shot_args.get('dataset')}")
85+
pytest.skip(f"Dataset is empty: {one_shot_args.get('dataset')}")
86+
else:
87+
raise # Re-raise other ValueError exceptions

0 commit comments

Comments
 (0)