Commit b9608e4

Sfno fix (#239)
* Add warning if jsbeautifier not installed, set default for h5 in inference, fix import
* Copy pytorch patches instead of using monkeypatching
* Update README.md to include patching doc

Co-authored-by: Mohammad Amin Nabian <[email protected]>
1 parent b615801 commit b9608e4
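The first bullet above refers to guarding an optional dependency. A minimal sketch of that pattern, assuming the usual try/except guard (the warning text, and the `maybe_beautify` helper, are hypothetical; the actual call site in the repo may differ):

import warnings

try:
    import jsbeautifier  # optional: used only for pretty-printing
except ImportError:
    jsbeautifier = None
    warnings.warn("jsbeautifier is not installed; output will not be pretty-printed")

def maybe_beautify(src: str) -> str:
    # Fall back to the raw string when the optional dependency is missing.
    return jsbeautifier.beautify(src) if jsbeautifier is not None else src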

File tree

22 files changed (+25, -102 lines)

modulus/experimental/sfno/README.md

Lines changed: 2 additions & 0 deletions
@@ -11,6 +11,8 @@ This is a research code built for massively parallel training of SFNO for weathe
 
 ## Getting started
 
+**For distributed training or inference, run `patch_pytorch.sh` in advance. This will patch the pytorch distributed utilities to support complex values.**
+
 ## Installing optional dependencies
 
 Install the optional dependencies by running
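Context for the patch requirement: in the PyTorch builds this code targets, parts of the distributed stack do not accept complex tensors directly, so values must round-trip through a real-valued view. A minimal sketch of that workaround (the helper name is hypothetical; assumes an initialized process group):

import torch
import torch.distributed as dist

def all_gather_complex(x: torch.Tensor) -> torch.Tensor:
    # View the complex tensor as real (adds a trailing dimension of size 2)
    # so the collective only ever sees a real dtype.
    xr = torch.view_as_real(x).contiguous()
    out = [torch.empty_like(xr) for _ in range(dist.get_world_size())]
    dist.all_gather(out, xr)
    # Stack the per-rank results and view them back as complex.
    return torch.view_as_complex(torch.stack(out))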

modulus/experimental/sfno/convert_legacy_to_flexible.py

Lines changed: 1 addition & 4 deletions
@@ -28,6 +28,7 @@
 )
 from modulus.experimental.sfno.utils import logging_utils
 
+import torch.distributed as dist
 
 from modulus.experimental.sfno.networks.models import get_model
 
@@ -36,10 +37,6 @@
 from modulus.experimental.sfno.utils.trainer import Trainer
 from modulus.experimental.sfno.utils.YParams import ParamsBase
 
-# import patched distributed
-from modulus.experimental.sfno.utils.distributed_patch import dist_patch
-dist = dist_patch()
-
 
 class CheckpointSaver(Trainer):
     """

modulus/experimental/sfno/inference/inferencer.py

Lines changed: 1 addition & 3 deletions
@@ -36,10 +36,8 @@
 # distributed computing stuff
 from modulus.experimental.sfno.utils import comm
 from modulus.experimental.sfno.utils import visualize
+import torch.distributed as dist
 
-# import patched distributed
-from modulus.experimental.sfno.utils.distributed_patch import dist_patch
-dist = dist_patch()
 
 class Inferencer(Trainer):
     """

modulus/experimental/sfno/mpu/helpers.py

Lines changed: 1 addition & 3 deletions
@@ -14,14 +14,12 @@
 
 import torch
 import torch.nn.functional as F
+import torch.distributed as dist
 
 from modulus.experimental.sfno.utils import comm
 
 from torch._utils import _flatten_dense_tensors
 
-# import patched distributed
-from modulus.experimental.sfno.utils.distributed_patch import dist_patch
-dist = dist_patch()
 
 def get_memory_format(tensor):
     if tensor.is_contiguous(memory_format=torch.channels_last):

modulus/experimental/sfno/mpu/layers.py

Lines changed: 1 addition & 4 deletions
@@ -16,6 +16,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch.distributed as dist
 from torch.cuda.amp import custom_fwd, custom_bwd
 from modulus.experimental.sfno.utils import comm
 
@@ -28,10 +29,6 @@
 from modulus.experimental.sfno.mpu.helpers import pad_helper
 from modulus.experimental.sfno.mpu.helpers import truncate_helper
 
-# import patched distributed
-from modulus.experimental.sfno.utils.distributed_patch import dist_patch
-dist = dist_patch()
-
 
 class distributed_transpose_w(torch.autograd.Function):
 
modulus/experimental/sfno/mpu/mappings.py

Lines changed: 1 addition & 3 deletions
@@ -18,6 +18,7 @@
 import torch
 from torch.nn.parallel import DistributedDataParallel
 from modulus.experimental.sfno.utils import comm
+import torch.distributed as dist
 
 # torch utils
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
@@ -28,9 +29,6 @@
 from modulus.experimental.sfno.mpu.helpers import _split
 from modulus.experimental.sfno.mpu.helpers import _gather
 
-# import patched distributed
-from modulus.experimental.sfno.utils.distributed_patch import dist_patch
-dist = dist_patch()
 
 # generalized
 class _CopyToParallelRegion(torch.autograd.Function):

modulus/experimental/sfno/networks/helpers.py

Lines changed: 1 addition & 4 deletions
@@ -15,10 +15,7 @@
 import torch
 
 from utils import comm
-
-# imprt patched distributed
-from modulus.experimental.sfno.utils.distributed_patch import dist_patch
-dist = dist_patch()
+import torch.distributed as dist
 
 def count_parameters(model, device):
     with torch.no_grad():

modulus/experimental/sfno/patch_pytorch.sh

Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
+#!/bin/bash
+cp third_party/torch/distributed/utils.py /usr/local/lib/python3.10/dist-packages/torch/distributed/
+echo "Patching complete"

modulus/experimental/sfno/perf_tests/distributed/comm_test.py

Lines changed: 1 addition & 3 deletions
@@ -20,15 +20,13 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch.distributed as dist
 from torch.cuda import amp
 
 sys.path.append(os.path.join("/opt", "ERA5_wind"))
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 from modulus.experimental.sfno.utils import comm
 
-# import patched distributed
-from modulus.experimental.sfno.utils.distributed_patch import dist_patch
-dist = dist_patch()
 
 # profile stuff
 from ctypes import cdll

modulus/experimental/sfno/perf_tests/distributed/dist_fft.py

Lines changed: 1 addition & 4 deletions
@@ -20,6 +20,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch.distributed as dist
 from torch.cuda import amp
 
 sys.path.append(os.path.join("/opt", "ERA5_wind"))
@@ -31,10 +32,6 @@
 from modulus.experimental.sfno.mpu.mappings import gather_from_parallel_region, scatter_to_parallel_region, reduce_from_parallel_region
 from modulus.experimental.sfno.mpu.layers import DistributedRealFFT2, DistributedInverseRealFFT2
 
-# imprt patched distributed
-from modulus.experimental.sfno.utils.distributed_patch import dist_patch
-dist = dist_patch()
-
 
 def main(args, verify):
     # parameters
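After running the patch script, a single-process smoke test can confirm that a complex collective goes through; a hypothetical check, not part of this commit (whether it actually exercises the patched utilities depends on the PyTorch build):

import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

x = torch.randn(4, dtype=torch.complex64)
dist.all_reduce(x)  # complex dtype should be accepted after patching
dist.destroy_process_group()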
