
Commit 9f22028 (parent: 64c3b7e)

update to separate validation and upload of dataset

3 files changed (+120, -48 lines)

eval_protocol/cli_commands/create_rft.py (93 additions, 44 deletions)
```diff
@@ -363,18 +363,21 @@ def _resolve_evaluator(
     return evaluator_id, evaluator_resource_name, selected_test_file_path, selected_test_func_name
 
 
-def _resolve_and_prepare_dataset(
+def _resolve_dataset(
     project_root: str,
     account_id: str,
-    api_key: str,
-    api_base: str,
     evaluator_id: str,
     args: argparse.Namespace,
     selected_test_file_path: Optional[str],
     selected_test_func_name: Optional[str],
-    dry_run: bool,
 ) -> tuple[Optional[str], Optional[str], Optional[str]]:
-    """Resolve dataset id/resource and ensure dataset exists if using JSONL."""
+    """Resolve dataset source without performing any uploads.
+
+    Returns a tuple of:
+    - dataset_id: existing dataset id when using --dataset or a fully-qualified dataset resource
+    - dataset_resource: fully-qualified dataset resource for existing datasets; None for JSONL sources
+    - dataset_jsonl: local JSONL path when using --dataset-jsonl or inferred sources; None for id-only datasets
+    """
     dataset_id = getattr(args, "dataset", None)
     dataset_jsonl = getattr(args, "dataset_jsonl", None)
     dataset_display_name = getattr(args, "dataset_display_name", None)
```
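The new signature drops api_key, api_base, and dry_run, so resolution is now a pure local lookup that cannot trigger a network call. A minimal sketch of how a caller might interpret the returned triple (the helper below is hypothetical; only the tuple shape and semantics come from the docstring above):

```python
from typing import Optional


def describe_dataset_source(
    dataset_id: Optional[str],
    dataset_resource: Optional[str],
    dataset_jsonl: Optional[str],
) -> str:
    # Existing dataset: both the id and the fully-qualified resource are set.
    if dataset_id and dataset_resource:
        return f"existing dataset: {dataset_resource}"
    # JSONL source: only the local path is set; the upload happens later.
    if dataset_jsonl:
        return f"JSONL pending upload: {dataset_jsonl}"
    return "no dataset source"


print(describe_dataset_source("my-data", "accounts/acct/datasets/my-data", None))
print(describe_dataset_source(None, None, "data/train.jsonl"))
```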
```diff
@@ -432,40 +435,72 @@ def _resolve_and_prepare_dataset(
         )
         return None, None, None
 
-    inferred_dataset_id = _build_trimmed_dataset_id(evaluator_id)
-    if dry_run:
-        print("--dry-run: would create dataset and upload JSONL")
-        dataset_id = inferred_dataset_id
-    else:
-        try:
-            # Resolve dataset_jsonl path relative to CWD if needed
-            jsonl_path_for_upload = (
-                dataset_jsonl
-                if os.path.isabs(dataset_jsonl)
-                else os.path.abspath(os.path.join(project_root, dataset_jsonl))
-            )
-            dataset_id, _ = create_dataset_from_jsonl(
-                account_id=account_id,
-                api_key=api_key,
-                api_base=api_base,
-                dataset_id=inferred_dataset_id,
-                display_name=dataset_display_name or inferred_dataset_id,
-                jsonl_path=jsonl_path_for_upload,
-            )
-            print(f"✓ Created and uploaded dataset: {dataset_id}")
-        except Exception as e:
-            print(f"Error creating/uploading dataset: {e}")
-            return None, None, None
-
-    if not dataset_id:
-        return None, None, None
+    # Build dataset resource for existing datasets; JSONL-based datasets will be uploaded later.
+    dataset_resource = None
+    if dataset_id:
+        dataset_resource = dataset_resource_override or f"accounts/{account_id}/datasets/{dataset_id}"
 
-    # Build dataset resource (prefer override when provided)
-    dataset_resource = dataset_resource_override or f"accounts/{account_id}/datasets/{dataset_id}"
     return dataset_id, dataset_resource, dataset_jsonl
 
 
-def _ensure_evaluator_active(
+def _upload_dataset(
+    project_root: str,
+    account_id: str,
+    api_key: str,
+    api_base: str,
+    evaluator_id: str,
+    dataset_id: Optional[str],
+    dataset_resource: Optional[str],
+    dataset_jsonl: Optional[str],
+    args: argparse.Namespace,
+    dry_run: bool,
+) -> tuple[Optional[str], Optional[str]]:
+    """Create/upload the dataset when using a local JSONL source.
+
+    For existing datasets (--dataset or fully-qualified ids), this is a no-op that
+    simply ensures dataset_id and dataset_resource are populated.
+    """
+    # Existing dataset case: nothing to upload
+    if not dataset_jsonl:
+        if not dataset_id:
+            return None, None
+        if not dataset_resource:
+            dataset_resource = f"accounts/{account_id}/datasets/{dataset_id}"
+        return dataset_id, dataset_resource
+
+    # JSONL-based dataset: upload or simulate upload
+    inferred_dataset_id = _build_trimmed_dataset_id(evaluator_id)
+    dataset_display_name = getattr(args, "dataset_display_name", None) or inferred_dataset_id
+
+    # Resolve dataset_jsonl path relative to CWD if needed
+    jsonl_path_for_upload = (
+        dataset_jsonl if os.path.isabs(dataset_jsonl) else os.path.abspath(os.path.join(project_root, dataset_jsonl))
+    )
+
+    if dry_run:
+        print("--dry-run: would create dataset and upload JSONL")
+        dataset_id = inferred_dataset_id
+        dataset_resource = f"accounts/{account_id}/datasets/{dataset_id}"
+        return dataset_id, dataset_resource
+
+    try:
+        dataset_id, _ = create_dataset_from_jsonl(
+            account_id=account_id,
+            api_key=api_key,
+            api_base=api_base,
+            dataset_id=inferred_dataset_id,
+            display_name=dataset_display_name,
+            jsonl_path=jsonl_path_for_upload,
+        )
+        print(f"✓ Created and uploaded dataset: {dataset_id}")
+        dataset_resource = f"accounts/{account_id}/datasets/{dataset_id}"
+        return dataset_id, dataset_resource
+    except Exception as e:
+        print(f"Error creating/uploading dataset: {e}")
+        return None, None
+
+
+def _upload_and_ensure_evaluator(
     project_root: str,
     evaluator_id: str,
     evaluator_resource_name: str,
```
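The new `_upload_dataset` has three outcomes: a pass-through for existing datasets, a simulated upload under `--dry-run`, and a real `create_dataset_from_jsonl` call. The path handling it inherits is easy to get wrong, so here it is in isolation (the function name is ours; the expression mirrors the diff, and the asserts assume POSIX paths). Note that the in-diff comment says "relative to CWD" while the code actually anchors at project_root; the sketch follows the code:

```python
import os


def resolve_jsonl_path(project_root: str, dataset_jsonl: str) -> str:
    # Absolute paths pass through; relative paths are anchored at project_root.
    return (
        dataset_jsonl
        if os.path.isabs(dataset_jsonl)
        else os.path.abspath(os.path.join(project_root, dataset_jsonl))
    )


assert resolve_jsonl_path("/repo", "/tmp/data.jsonl") == "/tmp/data.jsonl"
assert resolve_jsonl_path("/repo", "data/train.jsonl") == "/repo/data/train.jsonl"
```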
```diff
@@ -726,19 +761,17 @@ def create_rft_command(args) -> int:
     if not evaluator_id or not evaluator_resource_name:
         return 1
 
-    # 2) Resolve dataset (id/resource) and underlying JSONL (if any)
-    dataset_id, dataset_resource, dataset_jsonl = _resolve_and_prepare_dataset(
+    # 2) Resolve dataset source (id or JSONL path)
+    dataset_id, dataset_resource, dataset_jsonl = _resolve_dataset(
         project_root=project_root,
         account_id=account_id,
-        api_key=api_key,
-        api_base=api_base,
         evaluator_id=evaluator_id,
         args=args,
         selected_test_file_path=selected_test_file_path,
         selected_test_func_name=selected_test_func_name,
-        dry_run=dry_run,
     )
-    if not dataset_id or not dataset_resource:
+    # Require either an existing dataset id or a JSONL source to materialize from
+    if dataset_jsonl is None and not dataset_id:
         return 1
 
     # 3) Optional local validation
```
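Because the upload is deferred to step 4, the early-exit check after step 2 is relaxed: it no longer requires a materialized dataset_resource, only that some source exists. The guard's truth table, restated (this predicate is an illustration, not repo code):

```python
from typing import Optional


def has_dataset_source(dataset_id: Optional[str], dataset_jsonl: Optional[str]) -> bool:
    # Equivalent to the diff's `if dataset_jsonl is None and not dataset_id: return 1`.
    return dataset_jsonl is not None or bool(dataset_id)


assert has_dataset_source("my-data", None)             # existing dataset id
assert has_dataset_source(None, "data/train.jsonl")    # JSONL to upload in step 4
assert not has_dataset_source(None, None)              # nothing to train on: exit 1
```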
```diff
@@ -758,8 +791,24 @@ def create_rft_command(args) -> int:
     ):
         return 1
 
-    # 4) Ensure evaluator exists and is ACTIVE (upload + poll if needed)
-    if not _ensure_evaluator_active(
+    # 4) Upload dataset when using JSONL sources (no-op for existing datasets)
+    dataset_id, dataset_resource = _upload_dataset(
+        project_root=project_root,
+        account_id=account_id,
+        api_key=api_key,
+        api_base=api_base,
+        evaluator_id=evaluator_id,
+        dataset_id=dataset_id,
+        dataset_resource=dataset_resource,
+        dataset_jsonl=dataset_jsonl,
+        args=args,
+        dry_run=dry_run,
+    )
+    if not dataset_id or not dataset_resource:
+        return 1
+
+    # 5) Ensure evaluator exists and is ACTIVE (upload + poll if needed)
+    if not _upload_and_ensure_evaluator(
         project_root=project_root,
         evaluator_id=evaluator_id,
         evaluator_resource_name=evaluator_resource_name,
@@ -769,7 +818,7 @@ def create_rft_command(args) -> int:
     ):
         return 1
 
-    # 5) Create the RFT job
+    # 6) Create the RFT job
     return _create_rft_job(
         account_id=account_id,
         api_key=api_key,
```
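Taken together, create_rft_command now runs in six steps, with all local checks ahead of any remote side effect, so a failing validation can no longer leave a freshly created dataset behind. A condensed sketch of the new ordering (`deps` is a hypothetical bag of the helpers named in the diff, not a real object in the repo):

```python
from types import SimpleNamespace


def create_rft_flow(deps: SimpleNamespace, args: SimpleNamespace) -> int:
    # 2) Resolve the dataset source locally: pure, no credentials needed.
    dataset_id, dataset_resource, dataset_jsonl = deps.resolve_dataset(args)
    if dataset_jsonl is None and not dataset_id:
        return 1
    # 3) Local validation runs before anything is uploaded.
    if not args.skip_validation and not deps.validate_locally(args):
        return 1
    # 4) First remote side effect: upload the JSONL (no-op for existing datasets).
    dataset_id, dataset_resource = deps.upload_dataset(dataset_id, dataset_resource, dataset_jsonl)
    if not dataset_id or not dataset_resource:
        return 1
    # 5) Upload the evaluator and poll until ACTIVE, then 6) create the job.
    if not deps.upload_and_ensure_evaluator(args):
        return 1
    return deps.create_rft_job(dataset_resource, args)
```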

tests/test_cli_create_rft_infer.py (26 additions, 2 deletions)
```diff
@@ -40,7 +40,7 @@ def rft_test_harness(tmp_path, monkeypatch):
     monkeypatch.setattr(upload_mod, "_prompt_select", lambda tests, non_interactive=False: tests[:1])
     monkeypatch.setattr(upload_mod, "upload_command", lambda args: 0)
     monkeypatch.setattr(cr, "_poll_evaluator_status", lambda **kwargs: True)
-    monkeypatch.setattr(cr, "_ensure_evaluator_active", lambda *a, **k: True)
+    monkeypatch.setattr(cr, "_upload_and_ensure_evaluator", lambda *a, **k: True)
 
     return project
 
@@ -82,6 +82,25 @@ def _fake_create_job(account_id, api_key, api_base, body):
 
     monkeypatch.setattr(cr, "create_reinforcement_fine_tuning_job", _fake_create_job)
 
+    # Stub validation helpers: dataset always valid; capture evaluator validation flags
+    monkeypatch.setattr(cr, "_validate_dataset", lambda dataset_jsonl: True)
+    flag_calls = {"ignore_docker": None, "docker_build_extra": None, "docker_run_extra": None}
+
+    def _fake_validate_evaluator_locally(
+        project_root,
+        selected_test_file,
+        selected_test_func,
+        ignore_docker,
+        docker_build_extra,
+        docker_run_extra,
+    ):
+        flag_calls["ignore_docker"] = ignore_docker
+        flag_calls["docker_build_extra"] = docker_build_extra
+        flag_calls["docker_run_extra"] = docker_run_extra
+        return True
+
+    monkeypatch.setattr(cr, "_validate_evaluator_locally", _fake_validate_evaluator_locally)
+
     args = argparse.Namespace(
         # Evaluator and dataset
         evaluator="my-evaluator",
@@ -94,7 +113,7 @@ def _fake_create_job(account_id, api_key, api_base, body):
         dry_run=False,
         force=False,
         env_file=None,
-        skip_validation=True,
+        skip_validation=False,
         ignore_docker=False,
         docker_build_extra="--build-extra FLAG",
         docker_run_extra="--run-extra FLAG",
@@ -177,6 +196,11 @@ def _fake_create_job(account_id, api_key, api_base, body):
     for k in ("skip_validation", "ignore_docker", "docker_build_extra", "docker_run_extra"):
         assert k not in body
 
+    # But they should be propagated into local evaluator validation
+    assert flag_calls["ignore_docker"] is False
+    assert flag_calls["docker_build_extra"] == "--build-extra FLAG"
+    assert flag_calls["docker_run_extra"] == "--run-extra FLAG"
+
 
 def test_create_rft_evaluator_validation_fails(rft_test_harness, monkeypatch):
     project = rft_test_harness
```
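The fixture changes rely on a standard pytest idiom: swap a collaborator for a fake that records its arguments in a dict, then assert on the dict after the code under test runs. Stripped of repo specifics, the pattern looks like this (all names below are generic, not from the repo):

```python
def validate(path, strict):
    # Stand-in for a real collaborator; tests replace it via monkeypatch.
    raise RuntimeError("unexpected real call")


def run_pipeline(path, strict):
    return validate(path, strict)


def test_flags_propagate(monkeypatch):
    calls = {"path": None, "strict": None}

    def fake_validate(path, strict):
        calls["path"], calls["strict"] = path, strict  # capture what the caller passed
        return True

    # pytest's monkeypatch accepts a "module.attr" string target.
    monkeypatch.setattr(f"{__name__}.validate", fake_validate)
    assert run_pipeline("data.jsonl", strict=True) is True
    assert calls == {"path": "data.jsonl", "strict": True}
```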

tests/test_cli_local_test.py (1 addition, 2 deletions)
```diff
@@ -126,8 +126,7 @@ def test_local_test_selector_single_test(tmp_path, monkeypatch):
 
     # No entry; force discover + selector
     disc = SimpleNamespace(qualname="metric.test_sel", file_path=str(test_file))
-    monkeypatch.setattr(lt, "_discover_tests", lambda root: [disc])
-    monkeypatch.setattr(lt, "_prompt_select", lambda tests, non_interactive=False: tests[:1])
+    monkeypatch.setattr(lt, "_discover_and_select_tests", lambda cwd, non_interactive=False: [disc])
    monkeypatch.setattr(lt, "_find_dockerfiles", lambda root: [])
 
     called = {"host": False}
```
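This test change tracks a refactor in the module under test: the separate `_discover_tests` and `_prompt_select` seams appear to have been merged into a single `_discover_and_select_tests` entry point, so the test now needs only one stub. Under that assumption (the implementation itself is not shown in this commit), the consolidated helper would reduce to roughly:

```python
def _discover_and_select_tests(cwd, non_interactive=False):
    # Assumed composition of the two pre-refactor helpers the old test patched;
    # the real implementation lives in the module patched as `lt`.
    tests = _discover_tests(cwd)
    return _prompt_select(tests, non_interactive=non_interactive)
```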
