Merged
194 changes: 187 additions & 7 deletions sdgym/benchmark.py
@@ -6,9 +6,13 @@
import multiprocessing
import os
import pickle
import re
import tracemalloc
import warnings
from collections import defaultdict
from contextlib import contextmanager
from datetime import datetime
from importlib.metadata import version
from pathlib import Path

import boto3
@@ -17,6 +21,7 @@
import numpy as np
import pandas as pd
import tqdm
import yaml
from sdmetrics.reports.multi_table import (
DiagnosticReport as MultiTableDiagnosticReport,
)
@@ -63,6 +68,15 @@
'covtype',
]
N_BYTES_IN_MB = 1000 * 1000
EXTERNAL_SYNTHESIZER_TO_LIBRARY = {
'RealTabFormerSynthesizer': 'realtabformer',
}
SDV_SINGLE_TABLE_SYNTHESIZERS = [
'GaussianCopulaSynthesizer',
'CTGANSynthesizer',
'CopulaGANSynthesizer',
'TVAESynthesizer',
]


def _validate_inputs(output_filepath, detailed_results_folder, synthesizers, custom_synthesizers):
Expand Down Expand Up @@ -93,13 +107,67 @@ def _create_detailed_results_directory(detailed_results_folder):
os.makedirs(detailed_results_folder, exist_ok=True)


def _setup_output_destination(output_destination, synthesizers, datasets):
    """Set up the output destination for the benchmark results.

    Args:
        output_destination (str or None):
            The path to the output directory where results will be saved.
            If None, no output will be saved.
        synthesizers (list):
            The list of synthesizers to benchmark.
        datasets (list):
            The list of datasets to benchmark.

    Returns:
        dict:
            A mapping of each dataset name to each synthesizer name to the file
            paths where that pair's artifacts and the shared run files will be saved.
    """
    if output_destination is None:
        return {}

    _validate_output_destination(output_destination)
    output_path = Path(output_destination)
    output_path.mkdir(parents=True, exist_ok=True)
    today = datetime.today().strftime('%m_%d_%Y')
    top_folder = output_path / f'SDGym_results_{today}'
    top_folder.mkdir(parents=True, exist_ok=True)
    pattern = re.compile(rf'run_{re.escape(today)}_(\d+)\.yaml$')
    increments = []
    for file in top_folder.glob(f'run_{today}_*.yaml'):
        match = pattern.match(file.name)
        if match:
            increments.append(int(match.group(1)))

    if increments:
        next_increment = max(increments) + 1
    else:
        next_increment = 1

    paths = defaultdict(dict)
    for dataset in datasets:
        dataset_folder = top_folder / f'{dataset}_{today}'
        dataset_folder.mkdir(parents=True, exist_ok=True)

        for synth_name in synthesizers:
            synth_folder = dataset_folder / synth_name
            synth_folder.mkdir(parents=True, exist_ok=True)

            paths[dataset][synth_name] = {
                'synthesizer': str(synth_folder / f'{synth_name}_synthesizer.pkl'),
                'synthetic_data': str(synth_folder / f'{synth_name}_synthetic_data.csv'),
                'benchmark_result': str(synth_folder / f'{synth_name}_benchmark_result.csv'),
                'run_id': str(top_folder / f'run_{today}_{next_increment}.yaml'),
                'results': str(top_folder / f'results_{today}_{next_increment}.csv'),
            }

    return paths
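For illustration, a minimal sketch of the mapping this helper returns for one dataset/synthesizer pair; the dataset, synthesizer, date, and increment below are all hypothetical:

paths = _setup_output_destination(
    'benchmark_output/', ['GaussianCopulaSynthesizer'], ['adult']
)
# paths['adult']['GaussianCopulaSynthesizer'] would then look like:
# {
#     'synthesizer': 'benchmark_output/SDGym_results_06_01_2025/adult_06_01_2025/'
#                    'GaussianCopulaSynthesizer/GaussianCopulaSynthesizer_synthesizer.pkl',
#     'synthetic_data': '.../GaussianCopulaSynthesizer_synthetic_data.csv',
#     'benchmark_result': '.../GaussianCopulaSynthesizer_benchmark_result.csv',
#     'run_id': 'benchmark_output/SDGym_results_06_01_2025/run_06_01_2025_1.yaml',
#     'results': 'benchmark_output/SDGym_results_06_01_2025/results_06_01_2025_1.csv',
# }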


def _generate_job_args_list(
limit_dataset_size,
sdv_datasets,
additional_datasets_folder,
sdmetrics,
detailed_results_folder,
timeout,
output_destination,
compute_quality_score,
compute_diagnostic_score,
compute_privacy_score,
@@ -119,7 +187,9 @@ def _generate_job_args_list(
else get_dataset_paths(bucket=additional_datasets_folder)
)
datasets = sdv_datasets + additional_datasets

synthesizer_names = [synthesizer['name'] for synthesizer in synthesizers]
dataset_names = [dataset.name for dataset in datasets]
paths = _setup_output_destination(output_destination, synthesizer_names, dataset_names)
job_tuples = []
for dataset in datasets:
for synthesizer in synthesizers:
@@ -130,7 +200,7 @@
data, metadata_dict = load_dataset(
'single_table', dataset, limit_dataset_size=limit_dataset_size
)

path = paths.get(dataset.name, {}).get(synthesizer['name'], None)
args = (
synthesizer,
data,
@@ -143,13 +213,14 @@
compute_privacy_score,
dataset.name,
'single_table',
path,
)
job_args_list.append(args)

return job_args_list


def _synthesize(synthesizer_dict, real_data, metadata):
def _synthesize(synthesizer_dict, real_data, metadata, synthesizer_path=None):
synthesizer = synthesizer_dict['synthesizer']
if isinstance(synthesizer, type):
assert issubclass(synthesizer, BaselineSynthesizer), (
@@ -177,6 +248,10 @@ def _synthesize(synthesizer_dict, real_data, metadata):
peak_memory = tracemalloc.get_traced_memory()[1] / N_BYTES_IN_MB
tracemalloc.stop()
tracemalloc.clear_traces()
if synthesizer_path is not None:
synthetic_data.to_csv(synthesizer_path['synthetic_data'], index=False)
with open(synthesizer_path['synthesizer'], 'wb') as f:
pickle.dump(synthesizer_obj, f)

return synthetic_data, train_now - now, sample_now - train_now, synthesizer_size, peak_memory
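As a side note, the artifacts persisted above can be loaded back later; a small sketch, assuming the file names produced by _setup_output_destination:

import pickle

import pandas as pd

# Hypothetical file names; real runs nest them under
# <output_destination>/SDGym_results_<date>/<dataset>_<date>/<synthesizer>/.
synthetic_data = pd.read_csv('GaussianCopulaSynthesizer_synthetic_data.csv')
with open('GaussianCopulaSynthesizer_synthesizer.pkl', 'rb') as f:
    synthesizer_obj = pickle.load(f)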

@@ -283,6 +358,7 @@ def _score(
compute_privacy_score=False,
modality=None,
dataset_name=None,
synthesizer_path=None,
):
if output is None:
output = {}
@@ -302,7 +378,7 @@
# To be deleted if there is no error
output['error'] = 'Synthesizer Timeout'
synthetic_data, train_time, sample_time, synthesizer_size, peak_memory = _synthesize(
synthesizer, data.copy(), metadata
synthesizer, data.copy(), metadata, synthesizer_path=synthesizer_path
)

output['synthetic_data'] = synthetic_data
@@ -383,6 +459,7 @@ def _score_with_timeout(
compute_privacy_score=False,
modality=None,
dataset_name=None,
synthesizer_path=None,
):
with multiprocessing_context():
with multiprocessing.Manager() as manager:
@@ -400,6 +477,7 @@
compute_privacy_score,
modality,
dataset_name,
synthesizer_path,
),
)

@@ -500,6 +578,7 @@ def _run_job(args):
compute_privacy_score,
dataset_name,
modality,
synthesizer_path,
) = args

name = synthesizer['name']
@@ -510,7 +589,6 @@
timeout,
used_memory(),
)

output = {}
try:
if timeout:
@@ -525,6 +603,7 @@
compute_privacy_score=compute_privacy_score,
modality=modality,
dataset_name=dataset_name,
synthesizer_path=synthesizer_path,
)
else:
output = _score(
@@ -537,6 +616,7 @@
compute_privacy_score=compute_privacy_score,
modality=modality,
dataset_name=dataset_name,
synthesizer_path=synthesizer_path,
)
except Exception as error:
output['exception'] = error
@@ -551,6 +631,9 @@
cache_dir,
)

if synthesizer_path is not None:
scores.to_csv(synthesizer_path['benchmark_result'], index=False)

return scores


@@ -607,6 +690,13 @@ def _run_jobs(multi_processing_config, job_args_list, show_progress):
raise SDGymError('No valid Dataset/Synthesizer combination given.')

scores = pd.concat(scores, ignore_index=True)
output_directions = job_args_list[0][-1]
if output_directions and isinstance(output_directions, dict):
result_file = Path(output_directions['results'])
if not result_file.exists():
scores.to_csv(result_file, index=False, mode='w')
else:
scores.to_csv(result_file, index=False, mode='a', header=False)

return scores

@@ -779,6 +869,77 @@ def _create_instance_on_ec2(script_content):
print(f'Job kicked off for SDGym on {instance_id}') # noqa


def _handle_deprecated_parameters(
    output_filepath, detailed_results_folder, multi_processing_config
):
    """Handle deprecated parameters and issue warnings."""
    parameters_to_deprecate = {
        'output_filepath': output_filepath,
        'detailed_results_folder': detailed_results_folder,
        'multi_processing_config': multi_processing_config,
    }
    parameters = []
    for name, value in parameters_to_deprecate.items():
        if value is not None and value:
            parameters.append(name)

    if parameters:
        parameters = "', '".join(sorted(parameters))
        message = (
            f"Parameters '{parameters}' are deprecated in the 'benchmark_single_table' "
            'function and will be removed in October 2025. '
            "For saving results, please use the 'output_destination' parameter. For running SDGym"
            " remotely on AWS please use the 'benchmark_single_table_aws' method."
        )
        warnings.warn(message, FutureWarning)
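A quick sketch of the intended behavior; the call below is illustrative, not part of the diff:

import warnings

# Any non-empty deprecated parameter should trigger a FutureWarning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    _handle_deprecated_parameters('out.csv', None, None)

assert any(issubclass(w.category, FutureWarning) for w in caught)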


def _validate_output_destination(output_destination):
    if not isinstance(output_destination, str):
        raise ValueError(
            'The `output_destination` parameter must be a string representing the output path.'
        )

    if is_s3_path(output_destination):
        raise ValueError(
            'The `output_destination` parameter cannot be an S3 path. '
            'Please use `benchmark_single_table_aws` instead.'
        )
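Both failure modes, sketched; this assumes is_s3_path recognizes the s3:// scheme:

# _validate_output_destination(123)                 # raises ValueError: not a string
# _validate_output_destination('s3://bucket/path')  # raises ValueError: use benchmark_single_table_aws
_validate_output_destination('benchmark_output/')   # a plain local path passes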


def _write_run_id_file(synthesizers, job_args_list):
    jobs = [[job[-3], job[0]['name']] for job in job_args_list]
    output_directions = job_args_list[0][-1]
    path = output_directions['run_id']
    run_id = Path(path).stem
    metadata = {
        'run_id': run_id,
        'starting_date': datetime.today().strftime('%m_%d_%Y %H:%M:%S'),
        'completed_date': None,
        'sdgym_version': version('sdgym'),
        'jobs': jobs,
    }
    for synthesizer in synthesizers:
        if synthesizer not in SDV_SINGLE_TABLE_SYNTHESIZERS:
            ext_lib = EXTERNAL_SYNTHESIZER_TO_LIBRARY[synthesizer]
            library_version = version(ext_lib)
            metadata[f'{ext_lib}_version'] = library_version
        elif 'sdv_version' not in metadata:
            metadata['sdv_version'] = version('sdv')

    with open(path, 'w') as file:
        yaml.dump(metadata, file)
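For reference, a run file written by this helper could look like the following when loaded back; the versions, dates, and job names are made up, and yaml.dump sorts keys alphabetically:

import yaml

with open('run_06_01_2025_1.yaml') as f:  # hypothetical path
    print(yaml.safe_load(f))
# {'completed_date': None,
#  'jobs': [['adult', 'GaussianCopulaSynthesizer']],
#  'run_id': 'run_06_01_2025_1',
#  'sdgym_version': '0.9.1',
#  'sdv_version': '1.17.0',
#  'starting_date': '06_01_2025 14:02:11'}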


def _update_run_id_file(run_file):
    with open(run_file, 'r') as f:
        run_data = yaml.safe_load(f) or {}
Comment on lines +935 to +936

Contributor: Do we have to grab a lock here, or worry about multiple runs trying to modify this file at the same time?

Contributor Author: No, we're safe here because the method is called after all the jobs are run and the results generated.

    run_data['completed_date'] = datetime.today().strftime('%m_%d_%Y %H:%M:%S')
    with open(run_file, 'w') as f:
        yaml.dump(run_data, f)
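Once every job has finished, the same file gets its completion stamp; a one-line sketch with a hypothetical path:

_update_run_id_file('benchmark_output/SDGym_results_06_01_2025/run_06_01_2025_1.yaml')
# The file's completed_date is now filled in, e.g. '06_01_2025 15:47:03'.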


def benchmark_single_table(
synthesizers=DEFAULT_SYNTHESIZERS,
custom_synthesizers=None,
@@ -790,6 +951,7 @@
compute_privacy_score=True,
sdmetrics=None,
timeout=None,
output_destination=None,
output_filepath=None,
detailed_results_folder=None,
show_progress=False,
@@ -837,6 +999,18 @@ def benchmark_single_table(
timeout (int or ``None``):
The maximum number of seconds to wait for synthetic data creation. If ``None``, no
timeout is enforced.
output_destination (str or ``None``):
The path to the output directory where results will be saved. If ``None``, no
output is saved. The results are saved with the following structure:

output_destination/
    SDGym_results_<date>/
        run_<date>_<increment>.yaml
        results_<date>_<increment>.csv
        <dataset_name>_<date>/
            <synthesizer_name>/
                <synthesizer_name>_synthesizer.pkl
                <synthesizer_name>_synthetic_data.csv
                <synthesizer_name>_benchmark_result.csv
output_filepath (str or ``None``):
A file path for where to write the output as a csv file. If ``None``, no output
is written. If run_on_ec2 flag output_filepath needs to be defined and
@@ -863,6 +1037,7 @@
pandas.DataFrame:
A table containing one row per synthesizer + dataset + metric.
"""
_handle_deprecated_parameters(output_filepath, detailed_results_folder, multi_processing_config)
if run_on_ec2:
print("This will create an instance for the current AWS user's account.") # noqa
if output_filepath is not None:
@@ -873,22 +1048,23 @@
return None

_validate_inputs(output_filepath, detailed_results_folder, synthesizers, custom_synthesizers)

_create_detailed_results_directory(detailed_results_folder)

job_args_list = _generate_job_args_list(
limit_dataset_size,
sdv_datasets,
additional_datasets_folder,
sdmetrics,
detailed_results_folder,
timeout,
output_destination,
compute_quality_score,
compute_diagnostic_score,
compute_privacy_score,
synthesizers,
custom_synthesizers,
)
if output_destination is not None:
_write_run_id_file(synthesizers, job_args_list)

if job_args_list:
scores = _run_jobs(multi_processing_config, job_args_list, show_progress)
@@ -905,4 +1081,8 @@
if output_filepath:
write_csv(scores, output_filepath, None, None)

if output_destination is not None:
run_id_filename = job_args_list[0][-1]['run_id']
_update_run_id_file(run_id_filename)

return scores
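Putting it together, a hedged end-to-end sketch of the new parameter; the synthesizer and dataset choices are only examples:

import sdgym

scores = sdgym.benchmark_single_table(
    synthesizers=['GaussianCopulaSynthesizer'],
    sdv_datasets=['adult'],
    output_destination='benchmark_output/',
)
# benchmark_output/SDGym_results_<date>/ now holds the run yaml, the
# aggregated results csv, and one folder per dataset/synthesizer pair.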