-
Notifications
You must be signed in to change notification settings - Fork 63
Add ability to save synthesizers and data when running benchmark_single_table #415
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
a572a3a
8e967cc
cdf3fce
6f6b0bc
8a70378
98f6ce3
d83d96c
995ab2e
1b70b04
341c77d
0f81feb
a1502a6
1d77b58
662786d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,9 +6,13 @@ | |
import multiprocessing | ||
import os | ||
import pickle | ||
import re | ||
import tracemalloc | ||
import warnings | ||
from collections import defaultdict | ||
from contextlib import contextmanager | ||
from datetime import datetime | ||
from importlib.metadata import version | ||
from pathlib import Path | ||
|
||
import boto3 | ||
|
@@ -17,6 +21,7 @@ | |
import numpy as np | ||
import pandas as pd | ||
import tqdm | ||
import yaml | ||
from sdmetrics.reports.multi_table import ( | ||
DiagnosticReport as MultiTableDiagnosticReport, | ||
) | ||
|
@@ -63,6 +68,15 @@ | |
'covtype', | ||
] | ||
N_BYTES_IN_MB = 1000 * 1000 | ||
EXTERNAL_SYNTHESIZER_TO_LIBRARY = { | ||
'RealTabFormerSynthesizer': 'realtabformer', | ||
} | ||
SDV_SINGLE_TABLE_SYNTHESIZERS = [ | ||
'GaussianCopulaSynthesizer', | ||
'CTGANSynthesizer', | ||
'CopulaGANSynthesizer', | ||
'TVAESynthesizer', | ||
] | ||
|
||
|
||
def _validate_inputs(output_filepath, detailed_results_folder, synthesizers, custom_synthesizers): | ||
|
@@ -93,13 +107,67 @@ def _create_detailed_results_directory(detailed_results_folder): | |
os.makedirs(detailed_results_folder, exist_ok=True) | ||
|
||
|
||
def _setup_output_destination(output_destination, synthesizers, datasets):
    """Create the on-disk folder tree for benchmark outputs and map every path.

    Builds ``output_destination/SDGym_results_<date>/<dataset>_<date>/<synthesizer>/``
    folders and returns, per (dataset, synthesizer) pair, the file paths where the
    pickled synthesizer, the synthetic data, the per-job benchmark result, the
    run-id YAML and the aggregated results CSV will be written.

    Args:
        output_destination (str or None):
            The path to the output directory where results will be saved.
            If None, no output will be saved.
        synthesizers (list):
            The list of synthesizers to benchmark.
        datasets (list):
            The list of datasets to benchmark.

    Returns:
        dict:
            ``{dataset: {synthesizer: {path_kind: path_str}}}`` mapping, empty
            when ``output_destination`` is None.
    """
    if output_destination is None:
        return {}

    _validate_output_destination(output_destination)
    root = Path(output_destination)
    root.mkdir(parents=True, exist_ok=True)
    today = datetime.today().strftime('%m_%d_%Y')
    results_folder = root / f'SDGym_results_{today}'
    results_folder.mkdir(parents=True, exist_ok=True)

    # Find the highest run increment already written today so this run gets a
    # fresh, strictly increasing id.
    run_pattern = re.compile(rf'run_{re.escape(today)}_(\d+)\.yaml$')
    existing_increments = [
        int(found.group(1))
        for candidate in results_folder.glob(f'run_{today}_*.yaml')
        if (found := run_pattern.match(candidate.name))
    ]
    next_increment = max(existing_increments, default=0) + 1

    # The run-id and results files are shared by the whole run, so compute
    # them once instead of rebuilding them per (dataset, synthesizer) pair.
    run_id_path = str(results_folder / f'run_{today}_{next_increment}.yaml')
    results_path = str(results_folder / f'results_{today}_{next_increment}.csv')

    paths = defaultdict(dict)
    for dataset_name in datasets:
        dataset_folder = results_folder / f'{dataset_name}_{today}'
        dataset_folder.mkdir(parents=True, exist_ok=True)

        for synthesizer_name in synthesizers:
            synthesizer_folder = dataset_folder / synthesizer_name
            synthesizer_folder.mkdir(parents=True, exist_ok=True)

            paths[dataset_name][synthesizer_name] = {
                'synthesizer': str(synthesizer_folder / f'{synthesizer_name}_synthesizer.pkl'),
                'synthetic_data': str(synthesizer_folder / f'{synthesizer_name}_synthetic_data.csv'),
                'benchmark_result': str(
                    synthesizer_folder / f'{synthesizer_name}_benchmark_result.csv'
                ),
                'run_id': run_id_path,
                'results': results_path,
            }

    return paths
|
||
|
||
def _generate_job_args_list( | ||
limit_dataset_size, | ||
sdv_datasets, | ||
additional_datasets_folder, | ||
sdmetrics, | ||
detailed_results_folder, | ||
timeout, | ||
output_destination, | ||
compute_quality_score, | ||
compute_diagnostic_score, | ||
compute_privacy_score, | ||
|
@@ -119,7 +187,9 @@ def _generate_job_args_list( | |
else get_dataset_paths(bucket=additional_datasets_folder) | ||
) | ||
datasets = sdv_datasets + additional_datasets | ||
|
||
synthesizer_names = [synthesizer['name'] for synthesizer in synthesizers] | ||
dataset_names = [dataset.name for dataset in datasets] | ||
paths = _setup_output_destination(output_destination, synthesizer_names, dataset_names) | ||
job_tuples = [] | ||
for dataset in datasets: | ||
for synthesizer in synthesizers: | ||
|
@@ -130,7 +200,7 @@ def _generate_job_args_list( | |
data, metadata_dict = load_dataset( | ||
'single_table', dataset, limit_dataset_size=limit_dataset_size | ||
) | ||
|
||
path = paths.get(dataset.name, {}).get(synthesizer['name'], None) | ||
args = ( | ||
synthesizer, | ||
data, | ||
|
@@ -143,13 +213,14 @@ def _generate_job_args_list( | |
compute_privacy_score, | ||
dataset.name, | ||
'single_table', | ||
path, | ||
) | ||
job_args_list.append(args) | ||
|
||
return job_args_list | ||
|
||
|
||
def _synthesize(synthesizer_dict, real_data, metadata): | ||
def _synthesize(synthesizer_dict, real_data, metadata, synthesizer_path=None): | ||
synthesizer = synthesizer_dict['synthesizer'] | ||
if isinstance(synthesizer, type): | ||
assert issubclass(synthesizer, BaselineSynthesizer), ( | ||
|
@@ -177,6 +248,10 @@ def _synthesize(synthesizer_dict, real_data, metadata): | |
peak_memory = tracemalloc.get_traced_memory()[1] / N_BYTES_IN_MB | ||
tracemalloc.stop() | ||
tracemalloc.clear_traces() | ||
if synthesizer_path is not None: | ||
synthetic_data.to_csv(synthesizer_path['synthetic_data'], index=False) | ||
with open(synthesizer_path['synthesizer'], 'wb') as f: | ||
pickle.dump(synthesizer_obj, f) | ||
|
||
return synthetic_data, train_now - now, sample_now - train_now, synthesizer_size, peak_memory | ||
|
||
|
@@ -283,6 +358,7 @@ def _score( | |
compute_privacy_score=False, | ||
modality=None, | ||
dataset_name=None, | ||
synthesizer_path=None, | ||
): | ||
if output is None: | ||
output = {} | ||
|
@@ -302,7 +378,7 @@ def _score( | |
# To be deleted if there is no error | ||
output['error'] = 'Synthesizer Timeout' | ||
synthetic_data, train_time, sample_time, synthesizer_size, peak_memory = _synthesize( | ||
synthesizer, data.copy(), metadata | ||
synthesizer, data.copy(), metadata, synthesizer_path=synthesizer_path | ||
) | ||
|
||
output['synthetic_data'] = synthetic_data | ||
|
@@ -383,6 +459,7 @@ def _score_with_timeout( | |
compute_privacy_score=False, | ||
modality=None, | ||
dataset_name=None, | ||
synthesizer_path=None, | ||
): | ||
with multiprocessing_context(): | ||
with multiprocessing.Manager() as manager: | ||
|
@@ -400,6 +477,7 @@ def _score_with_timeout( | |
compute_privacy_score, | ||
modality, | ||
dataset_name, | ||
synthesizer_path, | ||
), | ||
) | ||
|
||
|
@@ -500,6 +578,7 @@ def _run_job(args): | |
compute_privacy_score, | ||
dataset_name, | ||
modality, | ||
synthesizer_path, | ||
) = args | ||
|
||
name = synthesizer['name'] | ||
|
@@ -510,7 +589,6 @@ def _run_job(args): | |
timeout, | ||
used_memory(), | ||
) | ||
|
||
output = {} | ||
try: | ||
if timeout: | ||
|
@@ -525,6 +603,7 @@ def _run_job(args): | |
compute_privacy_score=compute_privacy_score, | ||
modality=modality, | ||
dataset_name=dataset_name, | ||
synthesizer_path=synthesizer_path, | ||
) | ||
else: | ||
output = _score( | ||
|
@@ -537,6 +616,7 @@ def _run_job(args): | |
compute_privacy_score=compute_privacy_score, | ||
modality=modality, | ||
dataset_name=dataset_name, | ||
synthesizer_path=synthesizer_path, | ||
) | ||
except Exception as error: | ||
output['exception'] = error | ||
|
@@ -551,6 +631,9 @@ def _run_job(args): | |
cache_dir, | ||
) | ||
|
||
if synthesizer_path is not None: | ||
scores.to_csv(synthesizer_path['benchmark_result'], index=False) | ||
|
||
return scores | ||
|
||
|
||
|
@@ -607,6 +690,13 @@ def _run_jobs(multi_processing_config, job_args_list, show_progress): | |
raise SDGymError('No valid Dataset/Synthesizer combination given.') | ||
|
||
scores = pd.concat(scores, ignore_index=True) | ||
output_directions = job_args_list[0][-1] | ||
if output_directions and isinstance(output_directions, dict): | ||
result_file = Path(output_directions['results']) | ||
if not result_file.exists(): | ||
scores.to_csv(result_file, index=False, mode='w') | ||
else: | ||
scores.to_csv(result_file, index=False, mode='a', header=False) | ||
|
||
return scores | ||
|
||
|
@@ -779,6 +869,77 @@ def _create_instance_on_ec2(script_content): | |
print(f'Job kicked off for SDGym on {instance_id}') # noqa | ||
|
||
|
||
def _handle_deprecated_parameters( | ||
output_filepath, detailed_results_folder, multi_processing_config | ||
): | ||
"""Handle deprecated parameters and issue warnings.""" | ||
parameters_to_deprecate = { | ||
'output_filepath': output_filepath, | ||
'detailed_results_folder': detailed_results_folder, | ||
'multi_processing_config': multi_processing_config, | ||
} | ||
parameters = [] | ||
for name, value in parameters_to_deprecate.items(): | ||
if value is not None and value: | ||
parameters.append(name) | ||
|
||
if parameters: | ||
parameters = "', '".join(sorted(parameters)) | ||
message = ( | ||
f"Parameters '{parameters}' are deprecated in the 'benchmark_single_table' " | ||
'function and will be removed in October 2025. ' | ||
"For saving results, please use the 'output_destination' parameter. For running SDGym" | ||
" remotely on AWS please use the 'benchmark_single_table_aws' method." | ||
) | ||
warnings.warn(message, FutureWarning) | ||
|
||
|
||
def _validate_output_destination(output_destination):
    """Validate that ``output_destination`` is a usable local path string.

    Raises:
        ValueError:
            If the value is not a string, or if it points to S3 (remote runs
            must go through ``benchmark_single_table_aws``).
    """
    if not isinstance(output_destination, str):
        error_message = (
            'The `output_destination` parameter must be a string representing the output path.'
        )
        raise ValueError(error_message)

    if is_s3_path(output_destination):
        error_message = (
            'The `output_destination` parameter cannot be an S3 path. '
            'Please use `benchmark_single_table_aws` instead.'
        )
        raise ValueError(error_message)
|
||
|
||
def _write_run_id_file(synthesizers, job_args_list):
    """Write the run-metadata YAML file describing this benchmark run.

    Records the run id, start date, SDGym version, the library version of
    every synthesizer involved, and the (dataset, synthesizer) job list.

    Args:
        synthesizers (list):
            Names of the synthesizers being benchmarked.
        job_args_list (list):
            Job argument tuples from ``_generate_job_args_list``; ``job[-3]``
            is the dataset name, ``job[0]['name']`` the synthesizer name and
            ``job[-1]`` the output-paths dict.
    """
    # One [dataset_name, synthesizer_name] pair per scheduled job.
    jobs = [[job[-3], job[0]['name']] for job in job_args_list]
    output_directions = job_args_list[0][-1]
    path = output_directions['run_id']
    run_id = Path(path).stem
    metadata = {
        'run_id': run_id,
        'starting_date': datetime.today().strftime('%m_%d_%Y %H:%M:%S'),
        'completed_date': None,
        'sdgym_version': version('sdgym'),
        'jobs': jobs,
    }
    for synthesizer in synthesizers:
        if synthesizer not in SDV_SINGLE_TABLE_SYNTHESIZERS:
            # NOTE: raises KeyError for external synthesizers missing from
            # EXTERNAL_SYNTHESIZER_TO_LIBRARY — only known libraries are
            # supported here.
            ext_lib = EXTERNAL_SYNTHESIZER_TO_LIBRARY[synthesizer]
            metadata[f'{ext_lib}_version'] = version(ext_lib)
        elif 'sdv_version' not in metadata:
            # Fix: the original tested for the key 'sdv' while storing
            # 'sdv_version', so the guard never matched and the version was
            # redundantly recomputed for every SDV synthesizer.
            metadata['sdv_version'] = version('sdv')

    with open(path, 'w') as file:
        yaml.dump(metadata, file)
|
||
|
||
def _update_run_id_file(run_file):
    """Stamp the run-metadata YAML file with its completion time.

    NOTE(review): no file lock is needed here — per the PR discussion, this
    runs exactly once, after all benchmark jobs have finished and the results
    have been generated, so concurrent writers are not a concern.

    Args:
        run_file (str):
            Path to the ``run_<date>_<n>.yaml`` file written at run start.
    """
    with open(run_file, 'r') as yaml_file:
        run_data = yaml.safe_load(yaml_file) or {}

    completed = datetime.today().strftime('%m_%d_%Y %H:%M:%S')
    run_data['completed_date'] = completed
    with open(run_file, 'w') as yaml_file:
        yaml.dump(run_data, yaml_file)
|
||
|
||
def benchmark_single_table( | ||
synthesizers=DEFAULT_SYNTHESIZERS, | ||
custom_synthesizers=None, | ||
|
@@ -790,6 +951,7 @@ def benchmark_single_table( | |
compute_privacy_score=True, | ||
sdmetrics=None, | ||
timeout=None, | ||
output_destination=None, | ||
output_filepath=None, | ||
detailed_results_folder=None, | ||
show_progress=False, | ||
|
@@ -837,6 +999,18 @@ def benchmark_single_table( | |
timeout (int or ``None``): | ||
The maximum number of seconds to wait for synthetic data creation. If ``None``, no | ||
timeout is enforced. | ||
output_destination (str or ``None``): | ||
The path to the output directory where results will be saved. If ``None``, no | ||
output is saved. The results are saved with the following structure: | ||
output_destination/ | ||
run_<id>.yaml | ||
SDGym_results_<date>/ | ||
results.csv | ||
<dataset_name>_<date>/ | ||
meta.yaml | ||
<synthesizer_name>/ | ||
synthesizer.pkl | ||
synthetic_data.csv | ||
output_filepath (str or ``None``): | ||
A file path for where to write the output as a csv file. If ``None``, no output | ||
is written. If run_on_ec2 flag output_filepath needs to be defined and | ||
|
@@ -863,6 +1037,7 @@ def benchmark_single_table( | |
pandas.DataFrame: | ||
A table containing one row per synthesizer + dataset + metric. | ||
""" | ||
_handle_deprecated_parameters(output_filepath, detailed_results_folder, multi_processing_config) | ||
if run_on_ec2: | ||
print("This will create an instance for the current AWS user's account.") # noqa | ||
if output_filepath is not None: | ||
|
@@ -873,22 +1048,23 @@ def benchmark_single_table( | |
return None | ||
|
||
_validate_inputs(output_filepath, detailed_results_folder, synthesizers, custom_synthesizers) | ||
|
||
_create_detailed_results_directory(detailed_results_folder) | ||
|
||
job_args_list = _generate_job_args_list( | ||
limit_dataset_size, | ||
sdv_datasets, | ||
additional_datasets_folder, | ||
sdmetrics, | ||
detailed_results_folder, | ||
timeout, | ||
output_destination, | ||
compute_quality_score, | ||
compute_diagnostic_score, | ||
compute_privacy_score, | ||
synthesizers, | ||
custom_synthesizers, | ||
) | ||
if output_destination is not None: | ||
_write_run_id_file(synthesizers, job_args_list) | ||
|
||
if job_args_list: | ||
scores = _run_jobs(multi_processing_config, job_args_list, show_progress) | ||
|
@@ -905,4 +1081,8 @@ def benchmark_single_table( | |
if output_filepath: | ||
write_csv(scores, output_filepath, None, None) | ||
|
||
if output_destination is not None: | ||
run_id_filename = job_args_list[0][-1]['run_id'] | ||
_update_run_id_file(run_id_filename) | ||
|
||
return scores |
Uh oh!
There was an error while loading. Please reload this page.