From a1ebaa1c40fb4a150991fd16e389b7ddc4fad6f6 Mon Sep 17 00:00:00 2001 From: tommelt Date: Thu, 22 May 2025 21:45:20 +0100 Subject: [PATCH 01/15] chore: add new data to the download script --- era5_training/get-model-and-data.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/era5_training/get-model-and-data.sh b/era5_training/get-model-and-data.sh index 059c60a..65954da 100755 --- a/era5_training/get-model-and-data.sh +++ b/era5_training/get-model-and-data.sh @@ -11,4 +11,4 @@ wget https://huggingface.co/amangupta2/iccs_coupling_checkpoints/resolve/main/at cd .. echo "retrieving test input..." -(cd inputs && wget https://g-b56e81.7a577b.6fbd.data.globus.org/1x1_inputfeatures_u_v_theta_w_uw_vw_era5_training_data_hourly_2010_constant_mu_sigma_scaling01.nc) +(cd inputs && wget https://g-b56e81.7a577b.6fbd.data.globus.org/1x1_inputfeatures_u_v_theta_w_uw_vw_era5_training_data_hourly_2015_constant_mu_sigma_scaling01.nc) From 0f5d416d5d17d4e1c692f3814934de15c43b5b3b Mon Sep 17 00:00:00 2001 From: tommelt Date: Thu, 22 May 2025 21:45:49 +0100 Subject: [PATCH 02/15] docs: update commands --- era5_training/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/era5_training/README.md b/era5_training/README.md index 7a379e1..3d78fd0 100644 --- a/era5_training/README.md +++ b/era5_training/README.md @@ -51,7 +51,7 @@ test-data/ ### Ann ```bash -python inference.py -M ann -d global -v global -f uvthetaw -e 8 -m 1 -s 1 -t era5 -i inputs/ -c model-huggingface/ -o outputs/ --script +python inference.py -M ann -d global -v global -f uvthetaw -e 45 -m 1 -s 1 -t era5 -i inputs/ -c model-huggingface/ -o outputs/ --script ``` This will generate some test data and a torchscripted model, to be used by `infer.f90` and `infer.py` later on. @@ -86,7 +86,7 @@ python infer.py -M ann -t test-data/ -s . To test the newly generate torchscript models, use the following command: ```bash -bash compile-and-run.sh intel +bash compile-and-run.sh gcc ``` This will compile `infer.f90` into `infer.exe`. This requires having cuda installed on your system. It also requires `ftorch` to From 675c75871eb2fcd77b28e4e76912d4b3c67fecd8 Mon Sep 17 00:00:00 2001 From: Matt Archer Date: Fri, 23 May 2025 15:22:07 +0100 Subject: [PATCH 03/15] Update path to new (45 epoch) trained model --- era5_training/get-model-and-data.sh | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/era5_training/get-model-and-data.sh b/era5_training/get-model-and-data.sh index 65954da..37687a9 100755 --- a/era5_training/get-model-and-data.sh +++ b/era5_training/get-model-and-data.sh @@ -5,10 +5,12 @@ mkdir -p inputs echo "retrieving model weights..." cd model-huggingface -wget https://huggingface.co/amangupta2/iccs_coupling_checkpoints/resolve/main/ann_cnn_1x1_global_global_era5_uvthetaw__train_epoch8.pt +wget https://huggingface.co/amangupta2/iccs_coupling_checkpoints/resolve/main/retrained_ann_cnn_1x1_global_global_era5_uvthetaw__train_epoch45.pt wget https://huggingface.co/amangupta2/iccs_coupling_checkpoints/resolve/main/ann_cnn_1x1_global_global_era5_uvthetaw__train_epoch94.pt wget https://huggingface.co/amangupta2/iccs_coupling_checkpoints/resolve/main/attnunet_era5_global_global_uvthetaw_mseloss_train_epoch119.pt cd .. +mv model-huggingface/retrained_ann_cnn_1x1_global_global_era5_uvthetaw__train_epoch45.pt model-huggingface/ann_cnn_1x1_global_global_era5_uvthetaw__train_epoch45.pt + echo "retrieving test input..." (cd inputs && wget https://g-b56e81.7a577b.6fbd.data.globus.org/1x1_inputfeatures_u_v_theta_w_uw_vw_era5_training_data_hourly_2015_constant_mu_sigma_scaling01.nc) From 72a38f51b512e407b06dff90aad1d322e0a897d2 Mon Sep 17 00:00:00 2001 From: Matt Archer Date: Wed, 28 May 2025 09:10:45 +0100 Subject: [PATCH 04/15] Modify inference.py for high res model / data --- era5_training/inference.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/era5_training/inference.py b/era5_training/inference.py index 3189f6b..8113295 100644 --- a/era5_training/inference.py +++ b/era5_training/inference.py @@ -157,7 +157,7 @@ # Define test files # ------- To test on one year of ERA5 data test_files = [] -test_years = np.array([2010]) +test_years = np.array([2015]) test_month = args.month # int(sys.argv[4]) # np.arange(1,13) logger.info(f"Inference for month {test_month}") if teston == "era5": @@ -231,6 +231,11 @@ # ---- load model PATH = pref + ckpt checkpoint = torch.load(PATH, map_location=torch.device(device)) + + state_dict = checkpoint["model_state_dict"] + filtered_state_dict = {k: v for k, v in state_dict.items() if "bnorm" not in k} + model.load_state_dict(filtered_state_dict, strict=False) + model.load_state_dict(checkpoint["model_state_dict"]) model = model.to(device) model.eval() From 3ec3a57e1b0494c4979e6dab35cf000fca2fd101 Mon Sep 17 00:00:00 2001 From: Matt Archer Date: Wed, 28 May 2025 11:10:33 +0100 Subject: [PATCH 05/15] Amend hardcoded integer for higher resolution model --- era5_training/inference.py | 3 ++- utils/dataloader_definition.py | 14 +++++++------- utils/function_training.py | 6 ++++-- 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/era5_training/inference.py b/era5_training/inference.py index 8113295..0dc6192 100644 --- a/era5_training/inference.py +++ b/era5_training/inference.py @@ -219,6 +219,7 @@ ) idim = testset.idim + odim = testset.odim hdim = 4 * idim @@ -236,7 +237,7 @@ filtered_state_dict = {k: v for k, v in state_dict.items() if "bnorm" not in k} model.load_state_dict(filtered_state_dict, strict=False) - model.load_state_dict(checkpoint["model_state_dict"]) + # model.load_state_dict(checkpoint["model_state_dict"]) model = model.to(device) model.eval() diff --git a/utils/dataloader_definition.py b/utils/dataloader_definition.py index 2c65014..9d5dfe7 100644 --- a/utils/dataloader_definition.py +++ b/utils/dataloader_definition.py @@ -53,10 +53,10 @@ def __init__(self, files, domain, vertical, stencil, manual_shuffle, features, r if self.features == "uvtheta": self.v = np.arange(0, 369) # for u,v,theta elif self.features == "uvthetaw": - self.v = np.arange(0, 491) # for u,v,theta,w + self.v = np.arange(0, 551) # for u,v,theta,w elif self.features == "uvw": self.v = np.concatenate( - (np.arange(0, 247), np.arange(369, 491)), axis=0 + (np.arange(0, 247), np.arange(369, 551)), axis=0 ) # for u,v,w self.w = np.arange(0, self.odim) # all vertical channels @@ -86,7 +86,7 @@ def __init__(self, files, domain, vertical, stencil, manual_shuffle, features, r self.v = np.arange(0, 491) # for u,v,theta,w elif self.features == "uvw": self.v = np.concatenate( - (np.arange(0, 247), np.arange(369, 491)), axis=0 + (np.arange(0, 247), np.arange(369, 551)), axis=0 ) # for u,v,w self.w = np.concatenate( (np.arange(0, 60), np.arange(122, 182)), axis=0 @@ -298,10 +298,10 @@ def __init__(self, files, domain, vertical, manual_shuffle, features, region="1a if self.features == "uvtheta": self.v = np.arange(3, 369) # for u,v,theta elif self.features == "uvthetaw": - self.v = np.arange(3, 491) # for u,v,theta,w + self.v = np.arange(3, 551) # for u,v,theta,w elif self.features == "uvw": self.v = np.concatenate( - (np.arange(3, 247), np.arange(369, 491)), axis=0 + (np.arange(3, 247), np.arange(369, 551)), axis=0 ) # for u,v,w self.w = np.arange(0, self.odim) # all vertical channels @@ -328,10 +328,10 @@ def __init__(self, files, domain, vertical, manual_shuffle, features, region="1a if self.features == "uvtheta": self.v = np.arange(3, 369) # for u,v,theta elif self.features == "uvthetaw": - self.v = np.arange(3, 491) # for u,v,theta,w + self.v = np.arange(3, 551) # for u,v,theta,w elif self.features == "uvw": self.v = np.concatenate( - (np.arange(3, 247), np.arange(369, 491)), axis=0 + (np.arange(3, 247), np.arange(369, 551)), axis=0 ) # for u,v,w self.w = np.concatenate( (np.arange(0, 60), np.arange(122, 182)), axis=0 diff --git a/utils/function_training.py b/utils/function_training.py index f3d54cd..a8c20d1 100644 --- a/utils/function_training.py +++ b/utils/function_training.py @@ -192,7 +192,9 @@ def Inference_and_Save_ANN_CNN( INP = INP.reshape(T[0] * T[1], T[2], T[3], T[4]) T = OUT.shape OUT = OUT.reshape(T[0] * T[1], -1) - PRED = model(INP) + + with torch.no_grad(): + PRED = model(INP) if is_script: print("saving data...") @@ -386,7 +388,7 @@ def Inference_and_Save_AttentionUNet( model.eval() count = 0 for i, (INP, OUT) in enumerate(testloader): - # print([i,count]) + # print([i, count]) INP = INP.to(device) S = OUT.shape o_output[count : count + S[0], :, :, :] = OUT[ From 38adde6397f8f9b85b6af4ae04c036ef0e86f4e7 Mon Sep 17 00:00:00 2001 From: Matt Archer Date: Wed, 28 May 2025 17:36:07 +0100 Subject: [PATCH 06/15] Use torch.no_grad() to save memory during inference --- era5_training/infer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/era5_training/infer.py b/era5_training/infer.py index 58154bd..8907be4 100644 --- a/era5_training/infer.py +++ b/era5_training/infer.py @@ -26,7 +26,8 @@ def main(): model = torch.jit.load(model_path) # run model inference - pred = model(torch.tensor(input_data).to(device)) + with torch.no_grad(): + pred = model(torch.tensor(input_data).to(device)) pred = pred.cpu().detach().numpy() print("pred.shape = ", pred.shape) From 0e7db00a41e407ef299c298c0f2b6363e339abcb Mon Sep 17 00:00:00 2001 From: Matt Archer Date: Fri, 6 Jun 2025 16:27:44 +0100 Subject: [PATCH 07/15] Ensure model definition is compatible with trained version --- era5_training/inference.py | 7 +------ utils/model_definition.py | 34 +++++++++++++++++++++++++--------- 2 files changed, 26 insertions(+), 15 deletions(-) diff --git a/era5_training/inference.py b/era5_training/inference.py index 0dc6192..07646a4 100644 --- a/era5_training/inference.py +++ b/era5_training/inference.py @@ -232,12 +232,7 @@ # ---- load model PATH = pref + ckpt checkpoint = torch.load(PATH, map_location=torch.device(device)) - - state_dict = checkpoint["model_state_dict"] - filtered_state_dict = {k: v for k, v in state_dict.items() if "bnorm" not in k} - model.load_state_dict(filtered_state_dict, strict=False) - - # model.load_state_dict(checkpoint["model_state_dict"]) + model.load_state_dict(checkpoint["model_state_dict"]) model = model.to(device) model.eval() diff --git a/utils/model_definition.py b/utils/model_definition.py index daa82f9..49e4245 100644 --- a/utils/model_definition.py +++ b/utils/model_definition.py @@ -47,27 +47,43 @@ def __init__(self, idim, odim, hdim, stencil, dropout=0.0): self.act_cnn = nn.ReLU() self.dropout0 = nn.Dropout(p=0.5 * self.dropout_prob) + self.dropout0 = nn.Dropout(p=0.5 * self.dropout_prob) # can define a block and divide it into blocks as well self.layer1 = nn.Linear(idim, hdim) # ,dtype=torch.float16) - self.act1 = nn.LeakyReLU() - + self.act1 = ( + nn.LeakyReLU() + ) # nn.Tanh()#nn.LeakyReLU()#nn.Tanh()#nn.LeakyReLU()#nn.Tanh()#nn.GELU()#nn.ReLU() + self.bnorm1 = nn.BatchNorm1d(hdim) self.dropout = nn.Dropout(p=self.dropout_prob) - self.layer2 = nn.Linear(hdim, hdim) - self.act2 = nn.LeakyReLU() + self.act2 = ( + nn.LeakyReLU() + ) # nn.Tanh()#nn.LeakyReLU()#nn.Tanh()#nn.LeakyReLU()#nn.Tanh()#nn.GELU()#nn.ReLU() + self.bnorm2 = nn.BatchNorm1d(hdim) # ------------------------------------------------------- self.layer3 = nn.Linear(hdim, hdim) - self.act3 = nn.LeakyReLU() + self.act3 = ( + nn.LeakyReLU() + ) # nn.Tanh()#nn.LeakyReLU()#nn.Tanh()#nn.LeakyReLU()#nn.Tanh()#nn.GELU()#nn.ReLU() + self.bnorm3 = nn.BatchNorm1d(hdim) # ------------------------------------------------------- self.layer4 = nn.Linear(hdim, hdim) - self.act4 = nn.LeakyReLU() + self.act4 = ( + nn.LeakyReLU() + ) # nn.Tanh()#nn.LeakyReLU()#nn.Tanh()#nn.LeakyReLU()#nn.Tanh()#nn.GELU()#nn.ReLU() + self.bnorm4 = nn.BatchNorm1d(2 * hdim) # -------------------------------------------------------- self.layer5 = nn.Linear(hdim, hdim) - self.act5 = nn.LeakyReLU() + self.act5 = ( + nn.LeakyReLU() + ) # nn.Tanh()#nn.LeakyReLU()#nn.Tanh()#nn.LeakyReLU()#nn.Tanh()#nn.GELU()#nn.ReLU() + self.bnorm5 = nn.BatchNorm1d(hdim) # ------------------------------------------------------- self.layer6 = nn.Linear(hdim, 2 * odim) - self.act6 = nn.LeakyReLU() - + self.act6 = ( + nn.LeakyReLU() + ) # nn.Tanh()#nn.LeakyReLU()#nn.Tanh()#nn.LeakyReLU()#nn.Tanh()#nn.GELU()#nn.ReLU() + self.bnorm6 = nn.BatchNorm1d(2 * odim) self.output = nn.Linear(2 * odim, odim) def forward(self, x): From 96a11405463eed07abdad576ba9851808135e856 Mon Sep 17 00:00:00 2001 From: Matt Archer Date: Thu, 12 Jun 2025 10:28:41 +0100 Subject: [PATCH 08/15] Update to epoch 85 model --- era5_training/README.md | 2 +- era5_training/get-model-and-data.sh | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/era5_training/README.md b/era5_training/README.md index 3d78fd0..a085160 100644 --- a/era5_training/README.md +++ b/era5_training/README.md @@ -51,7 +51,7 @@ test-data/ ### Ann ```bash -python inference.py -M ann -d global -v global -f uvthetaw -e 45 -m 1 -s 1 -t era5 -i inputs/ -c model-huggingface/ -o outputs/ --script +python inference.py -M ann -d global -v global -f uvthetaw -e 85 -m 1 -s 1 -t era5 -i inputs/ -c model-huggingface/ -o outputs/ --script ``` This will generate some test data and a torchscripted model, to be used by `infer.f90` and `infer.py` later on. diff --git a/era5_training/get-model-and-data.sh b/era5_training/get-model-and-data.sh index 37687a9..e53d6c3 100755 --- a/era5_training/get-model-and-data.sh +++ b/era5_training/get-model-and-data.sh @@ -5,12 +5,12 @@ mkdir -p inputs echo "retrieving model weights..." cd model-huggingface -wget https://huggingface.co/amangupta2/iccs_coupling_checkpoints/resolve/main/retrained_ann_cnn_1x1_global_global_era5_uvthetaw__train_epoch45.pt +wget https://huggingface.co/amangupta2/iccs_coupling_checkpoints/resolve/main/retrained_ann_cnn_1x1_global_global_era5_uvthetaw__train_epoch85.pt wget https://huggingface.co/amangupta2/iccs_coupling_checkpoints/resolve/main/ann_cnn_1x1_global_global_era5_uvthetaw__train_epoch94.pt wget https://huggingface.co/amangupta2/iccs_coupling_checkpoints/resolve/main/attnunet_era5_global_global_uvthetaw_mseloss_train_epoch119.pt cd .. -mv model-huggingface/retrained_ann_cnn_1x1_global_global_era5_uvthetaw__train_epoch45.pt model-huggingface/ann_cnn_1x1_global_global_era5_uvthetaw__train_epoch45.pt +mv model-huggingface/retrained_ann_cnn_1x1_global_global_era5_uvthetaw__train_epoch85.pt model-huggingface/ann_cnn_1x1_global_global_era5_uvthetaw__train_epoch85.pt echo "retrieving test input..." (cd inputs && wget https://g-b56e81.7a577b.6fbd.data.globus.org/1x1_inputfeatures_u_v_theta_w_uw_vw_era5_training_data_hourly_2015_constant_mu_sigma_scaling01.nc) From 9dde0801179bf403c31d0a41bdd6d92973f5628c Mon Sep 17 00:00:00 2001 From: Matt Archer Date: Mon, 16 Jun 2025 11:52:19 +0100 Subject: [PATCH 09/15] Add device flag to model filename --- utils/function_training.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/function_training.py b/utils/function_training.py index a8c20d1..ccd037f 100644 --- a/utils/function_training.py +++ b/utils/function_training.py @@ -207,7 +207,7 @@ def Inference_and_Save_ANN_CNN( xdata.to_netcdf(f"test-data/ann-cnn-{k}.nc") print("scripting...") - script_to_torchscript(model, filename="nlgw_ann-cnn_gpu_scripted.pt") + script_to_torchscript(model, filename=f"nlgw_ann-cnn_{device}_scripted.pt") print("complete") S = PRED.shape From e6dcdcc2ffbb5f6b390f44cf15ed84ce02bc9424 Mon Sep 17 00:00:00 2001 From: tommelt Date: Wed, 18 Jun 2025 06:35:16 -0600 Subject: [PATCH 10/15] chore: update to be more consistent with CESM --- README.md | 2 +- era5_training/compile-and-run.sh | 36 +++++++------------------------- 2 files changed, 9 insertions(+), 29 deletions(-) diff --git a/README.md b/README.md index f047266..169abaf 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,7 @@ source .nlgw/bin/activate Now we can install `poetry` ```bash -pip install poetry +pip install "poetry<2.0.0" ``` The following command installs all of the necessary dependencies for `nonlocal_gwfluxes`. diff --git a/era5_training/compile-and-run.sh b/era5_training/compile-and-run.sh index 60c0491..7af1485 100755 --- a/era5_training/compile-and-run.sh +++ b/era5_training/compile-and-run.sh @@ -1,34 +1,15 @@ -COMP=$1 +FC=ifort +FFLAGS="" -if [[ ${COMP} == "intel" ]]; then - FC=ifort - FFLAGS="" - - # source /glade/u/home/tmeltzer/cam-test/debug_env.sh - - module purge - module load cesmdev/1.0 ncarenv/23.06 craype/2.7.20 linaro-forge/23.0 intel/2023.0.0 mkl/2023.0.0 - module load ncarcompilers/1.0.0 cmake/3.26.3 cray-mpich/8.1.25 hdf5-mpi/1.12.2 - module load netcdf-mpi/4.9.2 parallel-netcdf/1.12.3 parallelio/2.6.2-debug esmf/8.6.0b04-debug -elif [[ ${COMP} == "gcc" ]]; then - - FC=gfortran - FFLAGS="-ffree-line-length-none" - - module purge - module load ncarenv/24.12 gcc/12.4.0 cmake cuda/12.3.2 netcdf/4.9.3 -else - RED='\033[0;31m' - GREEN='\033[0;32m' - YELLOW='\033[0;33m' - NC='\033[0m' # No Color - echo -e "${RED}ERROR:${YELLOW} required option missing. Please specify [${GREEN}gcc${YELLOW}] or [${GREEN}intel${YELLOW}] as compiler.${NC}" - exit 1 -fi +module --force purge +# these come from the environment listed in software_environment.txt in the CESM Case directory +module load cesmdev/1.0 ncarenv/23.06 craype/2.7.20 intel/2023.0.0 mkl/2023.0.0 ncarcompilers/1.0.0 +module load cmake/3.26.3 cray-mpich/8.1.25 hdf5-mpi/1.12.2 netcdf-mpi/4.9.2 parallel-netcdf/1.12.3 +module load parallelio/2.6.2 esmf/8.6.0b04 source ../.nlgw/bin/activate -FTORCH_ROOT="/glade/u/home/tmeltzer/FTorch/bin/ftorch_${COMP}" +FTORCH_ROOT="${HOME}/fresh/ftorch-install" NETCDF_LIB="${NETCDF}/lib" export LD_LIBRARY_PATH="${NETCDF_LIB}:${FTORCH_ROOT}/lib64:${LD_LIBRARY_PATH}" @@ -45,7 +26,6 @@ echo $COMMAND ${COMMAND} -# gdb -q --args ./infer.exe attention test-data/ . ./infer.exe attention test-data/ . echo echo "=========================================" From c8deabcaf1c8cc59750fd4302fafd8e1d20d05a3 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Fri, 17 Oct 2025 14:01:25 -0600 Subject: [PATCH 11/15] Minor changes to the file to trace the L93 ANN and UNet --- era5_training/batch_ann.sh | 41 +++++++++++++++++++++---------- era5_training/batch_unet.sh | 44 +++++++++++++++++++--------------- era5_training/inference.py | 15 ++++++++---- utils/dataloader_definition.py | 26 ++++++++++++-------- 4 files changed, 80 insertions(+), 46 deletions(-) diff --git a/era5_training/batch_ann.sh b/era5_training/batch_ann.sh index e3f655f..240367a 100644 --- a/era5_training/batch_ann.sh +++ b/era5_training/batch_ann.sh @@ -1,5 +1,5 @@ #!/bin/bash -l -#PBS -N 1x1_uvthw +#PBS -N scripting #PBS -A USTN0009 #PBS -l select=1:ncpus=4:ngpus=1:mem=80GB #PBS -l walltime=01:00:00 @@ -33,19 +33,36 @@ source ~/nonlocal_gwfluxes/.nlgw/bin/activate # -o /glade/derecho/scratch/agupta/torch_saved_models/ +#python inference.py \ +# -M attention \ +# -d global \ +# -v global \ +# -f uvthetaw \ +# -e 119 \ +# -m 1 \ +# -s 3 \ +# -t era5 \ +# -i /glade/derecho/scratch/agupta/era5_training_data/ \ +# -c /glade/derecho/scratch/agupta/hugging_face_checkpoints/ \ +# -o /glade/derecho/scratch/agupta/gw_inference_files/ + + python inference.py \ - -M attention \ - -d global \ - -v global \ - -f uvthetaw \ - -e 119 \ - -m 1 \ - -s 3 \ - -t era5 \ - -i /glade/derecho/scratch/agupta/era5_training_data/ \ - -c /glade/derecho/scratch/agupta/hugging_face_checkpoints/ \ - -o /glade/derecho/scratch/agupta/gw_inference_files/ + -M ann \ + -d global \ + -v global \ + -f uvthetaw \ + -e 70 \ + -s 1 \ + -t era5 \ + -m 1 \ + -i inputs/ \ + -c model-huggingface/ \ + -o outputs/ \ + --script + +#python inference.py -M ann -d global -v global -f uvthetaw -e 85 -m 1 -s 1 -t era5 -i /glade/derecho/scratch/agupta/new_training_data/ -c /glade/derecho/scratch/agupta/hugging_face_checkpoints/ -o /glade/derecho/scratch/agupta/gw_inference_files/ --script diff --git a/era5_training/batch_unet.sh b/era5_training/batch_unet.sh index 8bc86c6..79bf3b2 100644 --- a/era5_training/batch_unet.sh +++ b/era5_training/batch_unet.sh @@ -24,25 +24,31 @@ source ~/nonlocal_gwfluxes/.nlgw/bin/activate #python training_attention_unet.py stratosphere_only uvthetawN2 -python training.py \ - -M attention \ - -d global \ - -v stratosphere_update \ - -f uvw \ - -i /glade/derecho/scratch/agupta/era5_training_data/ \ - -o /glade/derecho/scratch/agupta/torch_saved_models/ - - -#python inference.py \ -# -M attention \ -# -d global \ -# -v stratosphere_update \ -# -f uvw \ -# -e 100 \ -# -s 1 \ -# -t era5 \ -# -m 1 \ -# -i /glade/derecho/scratch/agupta/era5_training_data/ \ +#python training.py \ +# -M attention \ +# -d global \ +# -v stratosphere_update \ +# -f uvw \ +# -i /glade/derecho/scratch/agupta/era5_training_data/ \ +# -o /glade/derecho/scratch/agupta/torch_saved_models/ + + +python inference.py \ + -M attention \ + -d global \ + -v global \ + -f uvthetaw \ + -e 100 \ + -s 1 \ + -t era5 \ + -m 1 \ + -i inputs/ \ + -c model-huggingface/ \ + -o outputs/ \ + --script + + +# -i /glade/derecho/scratch/agupta/era5_training_data/ \ # -c /glade/derecho/scratch/agupta/torch_saved_models/ \ # -o /glade/derecho/scratch/agupta/gw_inference_files/ diff --git a/era5_training/inference.py b/era5_training/inference.py index 07646a4..506f371 100644 --- a/era5_training/inference.py +++ b/era5_training/inference.py @@ -108,7 +108,7 @@ print(f"output_dir={args.output_dir}") print(f"script={args.script}") -bs_train = 20 # 80 (80 works for most). (does not work for global uvthetaw) +bs_train = 5 # 20 # 80 (80 works for most). (does not work for global uvthetaw) bs_test = bs_train # -------------------------------------------------- @@ -136,11 +136,13 @@ odir = str(args.output_dir) + "/" pref = str(args.ckpt_dir) + "/" # "/scratch/users/ag4680/torch_saved_models/attention_unet/" if model == "ann": - ckpt = f"ann_cnn_{stencil}x{stencil}_{domain}_{vertical}_era5_{features}__train_epoch{epoch}.pt" + # ckpt = f"retrained_ann_cnn_{stencil}x{stencil}_{domain}_{vertical}_era5_{features}__train_epoch{epoch}.pt" + ckpt = f"retrained_L93_ann_cnn_{stencil}x{stencil}_{domain}_{vertical}_era5_{features}__train_epoch{epoch}.pt" log_filename = f"./{teston}_inference_ann_cnn_{stencil}x{stencil}_{domain}_{vertical}_{features}_ckpt_epoch_{epoch}.txt" elif model == "attention": ckpt = ( - f"attnunet_era5_{domain}_{vertical}_{features}_mseloss_train_epoch{str(epoch).zfill(2)}.pt" + # f"attnunet_era5_{domain}_{vertical}_{features}_mseloss_train_epoch{str(epoch).zfill(2)}.pt" + f"retrained_L93_attnunet_era5_{domain}_{vertical}_{features}_mseloss_train_epoch{epoch}.pt" ) log_filename = ( f"./{teston}_inference_attnunet_{domain}_{vertical}_{features}_ckpt_epoch_{epoch}.txt" @@ -174,7 +176,7 @@ ) elif vertical == "global" or vertical == "stratosphere_update": if stencil == 1: - pre = idir + f"1x1_inputfeatures_u_v_theta_w_uw_vw_era5_training_data_hourly_" + pre = idir + f"1x1_inputfeatures_u_v_theta_w_uw_vw_gcp_era5_training_data_hourly_" else: pre = ( idir @@ -183,7 +185,10 @@ for year in test_years: for months in np.arange(test_month, test_month + 1): - test_files.append(f"{pre}{year}_constant_mu_sigma_scaling{str(months).zfill(2)}.nc") + # test_files.append(f"{pre}{year}_constant_mu_sigma_scaling{str(months).zfill(2)}.nc") # usual + test_files.append( + f"{pre}{year}_L93_constant_mu_sigma_scaling{str(months).zfill(2)}.nc" + ) # L93 elif teston == "ifs": if vertical == "stratosphere_only": diff --git a/utils/dataloader_definition.py b/utils/dataloader_definition.py index 9d5dfe7..11328eb 100644 --- a/utils/dataloader_definition.py +++ b/utils/dataloader_definition.py @@ -51,13 +51,16 @@ def __init__(self, files, domain, vertical, stencil, manual_shuffle, features, r if self.vertical == "global": # 122 channels for each feature if self.features == "uvtheta": - self.v = np.arange(0, 369) # for u,v,theta + # self.v = np.arange(0, 369) # for u,v,theta + self.v = np.arange(0, 282) # for L93 elif self.features == "uvthetaw": - self.v = np.arange(0, 551) # for u,v,theta,w + # self.v = np.arange(0, 551) # for u,v,theta,w + self.v = np.arange(0, 375) # for L93 elif self.features == "uvw": - self.v = np.concatenate( - (np.arange(0, 247), np.arange(369, 551)), axis=0 - ) # for u,v,w + # self.v = np.concatenate( + # (np.arange(0, 247), np.arange(369, 551)), axis=0 + # ) # for u,v,w + self.v = np.concatenate((np.arange(0, 189), np.arange(282, 375)), axis=0) # for L93 self.w = np.arange(0, self.odim) # all vertical channels elif self.vertical == "stratosphere_only": @@ -296,13 +299,16 @@ def __init__(self, files, domain, vertical, manual_shuffle, features, region="1a if self.vertical == "global": # 122 channels for each feature if self.features == "uvtheta": - self.v = np.arange(3, 369) # for u,v,theta + self.v = np.arange(3, 282) # for L93 + # self.v = np.arange(3, 369) # for u,v,theta elif self.features == "uvthetaw": - self.v = np.arange(3, 551) # for u,v,theta,w + self.v = np.arange(3, 375) # for L93 + # self.v = np.arange(3, 551) # for u,v,theta,w elif self.features == "uvw": - self.v = np.concatenate( - (np.arange(3, 247), np.arange(369, 551)), axis=0 - ) # for u,v,w + self.v = np.concatenate((np.arange(3, 189), np.arange(282, 375)), axis=0) # for L93 + # self.v = np.concatenate( + # (np.arange(3, 247), np.arange(369, 551)), axis=0 + # ) # for u,v,w self.w = np.arange(0, self.odim) # all vertical channels elif self.vertical == "stratosphere_only": From 7d5006cb173859272e0781255a6fe575a194e6c8 Mon Sep 17 00:00:00 2001 From: Matt Archer Date: Fri, 5 Dec 2025 14:54:37 +0000 Subject: [PATCH 12/15] Add circular padding to UNet --- utils/model_definition.py | 50 +++++++++++++++++++++++++++++++++------ 1 file changed, 43 insertions(+), 7 deletions(-) diff --git a/utils/model_definition.py b/utils/model_definition.py index 49e4245..9004f3f 100644 --- a/utils/model_definition.py +++ b/utils/model_definition.py @@ -138,23 +138,37 @@ def totalsize(self): class Conv_block(nn.Module): def __init__(self, ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True): super().__init__() + + if padding > 0: + # pad width dimension circularly + pad_layer = nn.CircularPad2d((padding, padding, 0, 0)) + # pad height dimension with zeros (height, width) + conv_padding = (padding, 0) + else: + pad_layer = nn.Identity() + conv_padding = padding + + # two applications of pad_layer, conv_padding, pad_layer, conv_padding self.conv = nn.Sequential( + pad_layer, nn.Conv2d( in_channels=ch_in, out_channels=ch_out, kernel_size=kernel_size, stride=stride, - padding=padding, + padding=conv_padding, bias=bias, ), nn.BatchNorm2d(ch_out), nn.ReLU(inplace=True), + + pad_layer, nn.Conv2d( in_channels=ch_out, out_channels=ch_out, kernel_size=kernel_size, stride=stride, - padding=padding, + padding=conv_padding, bias=bias, ), nn.BatchNorm2d(ch_out), @@ -169,13 +183,22 @@ def forward(self, x): class Upsample(nn.Module): def __init__(self, ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True): super().__init__() + + if padding > 0: + pad_layer = nn.CircularPad2d((padding, padding, 0, 0)) + conv_padding = (padding, 0) + else: + pad_layer = nn.Identity() + conv_padding = padding + self.up = nn.Sequential( + pad_layer, nn.Upsample(scale_factor=2), nn.Conv2d( in_channels=ch_in, out_channels=ch_out, kernel_size=kernel_size, - padding=padding, + padding=conv_padding, stride=stride, bias=bias, ), @@ -192,43 +215,56 @@ class Attention_block(nn.Module): def __init__( self, F_x, F_g, F_int, kernel_size=3, stride=1, padding=1, bias=True, attn_3d=False ): + super().__init__() if attn_3d: self.F_attn = F_x else: self.F_attn = 1 - super().__init__() + + if padding > 0: + # pad width dimension circularly + pad_layer = nn.CircularPad2d((padding, padding, 0, 0)) + # pad height dimension with zeros (height, width) + conv_padding = (padding, 0) + else: + pad_layer = nn.Identity() + conv_padding = padding + self.Wx = nn.Sequential( + pad_layer, nn.Conv2d( in_channels=F_x, out_channels=F_int, kernel_size=kernel_size, stride=stride, - padding=padding, + padding=conv_padding, bias=bias, ), nn.BatchNorm2d(F_int), ) self.Wg = nn.Sequential( + pad_layer, nn.Conv2d( in_channels=F_g, out_channels=F_int, kernel_size=kernel_size, stride=stride, - padding=padding, + padding=conv_padding, bias=bias, ), nn.BatchNorm2d(F_int), ) self.Psi = nn.Sequential( + pad_layer, nn.Conv2d( in_channels=F_int, out_channels=self.F_attn, kernel_size=kernel_size, - padding=padding, stride=stride, + padding=conv_padding, bias=bias, ), nn.BatchNorm2d(self.F_attn), From aed690bc9c3ca121f07ceb411f8f3d7ca0f81390 Mon Sep 17 00:00:00 2001 From: Matt Archer Date: Fri, 5 Dec 2025 16:39:25 +0000 Subject: [PATCH 13/15] style: apply ruff formatter --- utils/model_definition.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/utils/model_definition.py b/utils/model_definition.py index 9004f3f..5a77f4d 100644 --- a/utils/model_definition.py +++ b/utils/model_definition.py @@ -138,8 +138,8 @@ def totalsize(self): class Conv_block(nn.Module): def __init__(self, ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True): super().__init__() - - if padding > 0: + + if padding > 0: # pad width dimension circularly pad_layer = nn.CircularPad2d((padding, padding, 0, 0)) # pad height dimension with zeros (height, width) @@ -161,7 +161,6 @@ def __init__(self, ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True) ), nn.BatchNorm2d(ch_out), nn.ReLU(inplace=True), - pad_layer, nn.Conv2d( in_channels=ch_out, @@ -184,7 +183,7 @@ class Upsample(nn.Module): def __init__(self, ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True): super().__init__() - if padding > 0: + if padding > 0: pad_layer = nn.CircularPad2d((padding, padding, 0, 0)) conv_padding = (padding, 0) else: @@ -221,8 +220,7 @@ def __init__( else: self.F_attn = 1 - - if padding > 0: + if padding > 0: # pad width dimension circularly pad_layer = nn.CircularPad2d((padding, padding, 0, 0)) # pad height dimension with zeros (height, width) From 38516dbdce726d09cdc03fd718a9126b0e56ff38 Mon Sep 17 00:00:00 2001 From: Matt Archer Date: Sat, 6 Dec 2025 13:21:29 +0000 Subject: [PATCH 14/15] refactor: simplify code padding=0 condition never met in Conv_block, Attention_block or Upsample. --- utils/model_definition.py | 34 ++++++++++++---------------------- 1 file changed, 12 insertions(+), 22 deletions(-) diff --git a/utils/model_definition.py b/utils/model_definition.py index 5a77f4d..685fafb 100644 --- a/utils/model_definition.py +++ b/utils/model_definition.py @@ -139,14 +139,10 @@ class Conv_block(nn.Module): def __init__(self, ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True): super().__init__() - if padding > 0: - # pad width dimension circularly - pad_layer = nn.CircularPad2d((padding, padding, 0, 0)) - # pad height dimension with zeros (height, width) - conv_padding = (padding, 0) - else: - pad_layer = nn.Identity() - conv_padding = padding + # pad width dimension circularly (left, right, top, bottom) + pad_layer = nn.CircularPad2d((padding, padding, 0, 0)) + # pad height dimension with zeros (height, width) + conv_padding = (padding, 0) # two applications of pad_layer, conv_padding, pad_layer, conv_padding self.conv = nn.Sequential( @@ -183,12 +179,10 @@ class Upsample(nn.Module): def __init__(self, ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True): super().__init__() - if padding > 0: - pad_layer = nn.CircularPad2d((padding, padding, 0, 0)) - conv_padding = (padding, 0) - else: - pad_layer = nn.Identity() - conv_padding = padding + # pad width dimension circularly (left, right, top, bottom) + pad_layer = nn.CircularPad2d((padding, padding, 0, 0)) + # pad height dimension with zeros (height, width) + conv_padding = (padding, 0) self.up = nn.Sequential( pad_layer, @@ -220,14 +214,10 @@ def __init__( else: self.F_attn = 1 - if padding > 0: - # pad width dimension circularly - pad_layer = nn.CircularPad2d((padding, padding, 0, 0)) - # pad height dimension with zeros (height, width) - conv_padding = (padding, 0) - else: - pad_layer = nn.Identity() - conv_padding = padding + # pad width dimension circularly (left, right, top, bottom) + pad_layer = nn.CircularPad2d((padding, padding, 0, 0)) + # pad height dimension with zeros (height, width) + conv_padding = (padding, 0) self.Wx = nn.Sequential( pad_layer, From a6361c2117aebcfb790f0e4b36705bd47323ef6d Mon Sep 17 00:00:00 2001 From: Matt Archer Date: Sat, 6 Dec 2025 13:42:43 +0000 Subject: [PATCH 15/15] feat: add replication padding at poles --- utils/model_definition.py | 37 ++++++++++++++++++------------------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/utils/model_definition.py b/utils/model_definition.py index 685fafb..3860edc 100644 --- a/utils/model_definition.py +++ b/utils/model_definition.py @@ -139,12 +139,11 @@ class Conv_block(nn.Module): def __init__(self, ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True): super().__init__() - # pad width dimension circularly (left, right, top, bottom) - pad_layer = nn.CircularPad2d((padding, padding, 0, 0)) - # pad height dimension with zeros (height, width) - conv_padding = (padding, 0) + pad_layer = nn.Sequential( + nn.CircularPad2d((padding, padding, 0, 0)), + nn.ReplicationPad2d((0, 0, padding, padding)), + ) - # two applications of pad_layer, conv_padding, pad_layer, conv_padding self.conv = nn.Sequential( pad_layer, nn.Conv2d( @@ -152,7 +151,7 @@ def __init__(self, ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True) out_channels=ch_out, kernel_size=kernel_size, stride=stride, - padding=conv_padding, + padding=0, bias=bias, ), nn.BatchNorm2d(ch_out), @@ -163,7 +162,7 @@ def __init__(self, ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True) out_channels=ch_out, kernel_size=kernel_size, stride=stride, - padding=conv_padding, + padding=0, bias=bias, ), nn.BatchNorm2d(ch_out), @@ -179,10 +178,10 @@ class Upsample(nn.Module): def __init__(self, ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True): super().__init__() - # pad width dimension circularly (left, right, top, bottom) - pad_layer = nn.CircularPad2d((padding, padding, 0, 0)) - # pad height dimension with zeros (height, width) - conv_padding = (padding, 0) + pad_layer = nn.Sequential( + nn.CircularPad2d((padding, padding, 0, 0)), + nn.ReplicationPad2d((0, 0, padding, padding)), + ) self.up = nn.Sequential( pad_layer, @@ -191,8 +190,8 @@ def __init__(self, ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True) in_channels=ch_in, out_channels=ch_out, kernel_size=kernel_size, - padding=conv_padding, stride=stride, + padding=0, bias=bias, ), nn.BatchNorm2d(ch_out), @@ -214,10 +213,10 @@ def __init__( else: self.F_attn = 1 - # pad width dimension circularly (left, right, top, bottom) - pad_layer = nn.CircularPad2d((padding, padding, 0, 0)) - # pad height dimension with zeros (height, width) - conv_padding = (padding, 0) + pad_layer = nn.Sequential( + nn.CircularPad2d((padding, padding, 0, 0)), + nn.ReplicationPad2d((0, 0, padding, padding)), + ) self.Wx = nn.Sequential( pad_layer, @@ -226,7 +225,7 @@ def __init__( out_channels=F_int, kernel_size=kernel_size, stride=stride, - padding=conv_padding, + padding=0, bias=bias, ), nn.BatchNorm2d(F_int), @@ -239,7 +238,7 @@ def __init__( out_channels=F_int, kernel_size=kernel_size, stride=stride, - padding=conv_padding, + padding=0, bias=bias, ), nn.BatchNorm2d(F_int), @@ -252,7 +251,7 @@ def __init__( out_channels=self.F_attn, kernel_size=kernel_size, stride=stride, - padding=conv_padding, + padding=0, bias=bias, ), nn.BatchNorm2d(self.F_attn),