diff --git a/examples/geophysics/diffusion_fwi/data/generate_data.py b/examples/geophysics/diffusion_fwi/data/generate_data.py
index 453af783f6..a7d09af6b8 100644
--- a/examples/geophysics/diffusion_fwi/data/generate_data.py
+++ b/examples/geophysics/diffusion_fwi/data/generate_data.py
@@ -366,6 +366,12 @@ def main():
         help="Peak frequency (Hz) of the Ricker source wavelet used during "
         "forward modeling. Defaults to 15.",
     )
+    parser.add_argument(
+        "--n_workers",
+        type=int,
+        default=8,
+        help="Number of workers per GPU. Defaults to 8.",
+    )
 
     args = parser.parse_args()
     dataset_path: Path = Path(args.in_dir) / "samples"
@@ -402,13 +408,18 @@ def main():
            if (i + 1) % 1000 == 0:
                logging.info(f"Processed {i + 1} / {total_files} files")
    else:
-        logging.info(f"Found {num_gpus} GPUs. Starting parallel processing.")
+        workers_per_gpu: int = args.n_workers
+        num_workers: int = num_gpus * workers_per_gpu
+
+        logging.info(
+            f"Found {num_gpus} GPUs. Starting parallel processing with "
+            f"{num_workers} workers ({workers_per_gpu} per GPU).")

        args: list[tuple[str, str, int, int]] = [
            (filepath, output_path, i % num_gpus, user_source_frequency)
            for i, filepath in enumerate(file_list)
        ]

-        with mp.get_context("spawn").Pool(processes=num_gpus) as pool:
+        with mp.get_context("spawn").Pool(processes=num_workers) as pool:
            iterator = pool.imap_unordered(process_file_wrapper, args)
            for i, result in enumerate(iterator):
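
Note on the second hunk: the change sizes the spawn-context pool at num_gpus * n_workers while keeping the existing round-robin GPU assignment (i % num_gpus). The sketch below is a minimal, self-contained illustration of that pattern only; run_parallel and the toy process_file_wrapper body are hypothetical stand-ins, not the actual code in generate_data.py, where the real wrapper does the forward modeling.

# Sketch only: mirrors the pool sizing and round-robin GPU assignment
# introduced by the diff, with a stand-in worker function.
import multiprocessing as mp


def process_file_wrapper(task: tuple[str, str, int, int]) -> str:
    # Hypothetical stand-in: the real wrapper would pin the process to
    # `gpu_id` (e.g. via CUDA_VISIBLE_DEVICES) and run forward modeling.
    filepath, output_path, gpu_id, source_frequency = task
    return f"{filepath} -> {output_path} (GPU {gpu_id}, {source_frequency} Hz)"


def run_parallel(
    file_list: list[str],
    output_path: str,
    num_gpus: int,
    workers_per_gpu: int,
    source_frequency: int,
) -> None:
    # Oversubscribe each GPU with several worker processes; tasks still go
    # to GPUs round-robin via `i % num_gpus`, as in the diff.
    num_workers = num_gpus * workers_per_gpu
    tasks = [
        (filepath, output_path, i % num_gpus, source_frequency)
        for i, filepath in enumerate(file_list)
    ]
    with mp.get_context("spawn").Pool(processes=num_workers) as pool:
        for result in pool.imap_unordered(process_file_wrapper, tasks):
            print(result)


if __name__ == "__main__":
    run_parallel(
        file_list=[f"sample_{i}.npz" for i in range(8)],
        output_path="out/",
        num_gpus=2,
        workers_per_gpu=3,
        source_frequency=15,
    )

Running more pool processes than GPUs trades some device contention for keeping each GPU fed while other workers handle CPU-side I/O and setup; the best --n_workers value is workload-dependent.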