Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions examples/geophysics/diffusion_fwi/data/generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,12 @@ def main():
help="Peak frequency (Hz) of the Ricker source wavelet used during "
"forward modeling. Defaults to 15.",
)
parser.add_argument(
"--n_workers",
type=int,
default=8,
help="Num of workers per GPU. Defaults to 8",
)
args = parser.parse_args()

dataset_path: Path = Path(args.in_dir) / "samples"
Expand Down Expand Up @@ -402,13 +408,18 @@ def main():
if (i + 1) % 1000 == 0:
logging.info(f"Processed {i + 1} / {total_files} files")
else:
logging.info(f"Found {num_gpus} GPUs. Starting parallel processing.")
workers_per_gpu: int = args.n_workers
num_workers: int = num_gpus * workers_per_gpu

logging.info(
f"Found {num_gpus} GPUs. Starting parallel processing with\
{num_workers} workers ({workers_per_gpu} per GPU).")
args: list[tuple[str, str, int, int]] = [
(filepath, output_path, i % num_gpus, user_source_frequency)
for i, filepath in enumerate(file_list)
Comment on lines 417 to 419
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logic: Multiple workers will be assigned to the same GPU (via modulo), but each worker loads models/data onto the same GPU device without coordination. This could cause CUDA out-of-memory errors. Have you tested this with multiple workers per GPU to ensure GPU memory usage doesn't exceed available VRAM?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It has been tested. Default n_wrokers are set according to that.

]

with mp.get_context("spawn").Pool(processes=num_gpus) as pool:
with mp.get_context("spawn").Pool(processes=num_workers) as pool:
iterator = pool.imap_unordered(process_file_wrapper, args)

for i, result in enumerate(iterator):
Expand Down