diff --git a/src/swell/suites/localensembleda/flow.cylc b/src/swell/suites/localensembleda/flow.cylc index 8675bccba..3c7e6cefa 100644 --- a/src/swell/suites/localensembleda/flow.cylc +++ b/src/swell/suites/localensembleda/flow.cylc @@ -61,9 +61,10 @@ GetEnsembleGeosExperiment-{{model_component}} => sync_point - sync_point => RunJediObsfiltersExecutable-{{model_component}} {% if skip_ensemble_hofx %} - sync_point => RunJediObsfiltersExecutable-{{model_component}} => RunJediLocalEnsembleDaExecutable-{{model_component}} + sync_point => RunJediObsfiltersExecutable-{{model_component}} + RunJediObsfiltersExecutable-{{model_component}} => RunJediEtkfObserver-{{model_component}} + RunJediEtkfObserver-{{model_component}} => RunJediEtkfSolver-{{model_component}} {% else %} # Run hofx for ensemble members according to strategy {% if ensemble_hofx_strategy == 'serial' %} @@ -86,17 +87,17 @@ # EvaIncrement - RunJediLocalEnsembleDaExecutable-{{model_component}} => EvaIncrement-{{model_component}} + RunJediEtkfSolver-{{model_component}} => EvaIncrement-{{model_component}} # EvaObservations - # RunJediLocalEnsembleDaExecutable-{{model_component}} => EvaObservations-{{model_component}} + RunJediEtkfSolver-{{model_component}} => EvaObservations-{{model_component}} # Save observations - # RunJediLocalEnsembleDaExecutable-{{model_component}} => SaveObsDiags-{{model_component}} + # RunJediEtkfSolver-{{model_component}} => SaveObsDiags-{{model_component}} # Clean up large files # EvaObservations-{{model_component}} & SaveObsDiags-{{model_component}} & - EvaIncrement-{{model_component}} => CleanCycle-{{model_component}} + # EvaIncrement-{{model_component}} => CleanCycle-{{model_component}} {% endif %} {% endfor %} @@ -211,6 +212,24 @@ --{{key}} = {{value}} {%- endfor %} + [[RunJediEtkfObserver-{{model_component}}]] + script = "swell task RunJediEtkfObserver $config -d $datetime -m {{model_component}}" + platform = {{platform}} + execution time limit = {{scheduling["RunJediEtkfObserver"]["execution_time_limit"]}} + [[[directives]]] + {%- for key, value in scheduling["RunJediEtkfObserver"]["directives"][model_component].items() %} + --{{key}} = {{value}} + {%- endfor %} + + [[RunJediEtkfSolver-{{model_component}}]] + script = "swell task RunJediEtkfSolver $config -d $datetime -m {{model_component}}" + platform = {{platform}} + execution time limit = {{scheduling["RunJediEtkfSolver"]["execution_time_limit"]}} + [[[directives]]] + {%- for key, value in scheduling["RunJediEtkfSolver"]["directives"][model_component].items() %} + --{{key}} = {{value}} + {%- endfor %} + [[EvaIncrement-{{model_component}}]] script = "swell task EvaIncrement $config -d $datetime -m {{model_component}}" @@ -232,7 +251,6 @@ script = "swell task CleanCycle $config -d $datetime -m {{model_component}}" {% endfor %} - [[sync_point]] script = true # -------------------------------------------------------------------------------------------------- diff --git a/src/swell/suites/localensembleda/suite_config.py b/src/swell/suites/localensembleda/suite_config.py index e4c670272..9d89de48b 100644 --- a/src/swell/suites/localensembleda/suite_config.py +++ b/src/swell/suites/localensembleda/suite_config.py @@ -39,24 +39,27 @@ class SuiteConfig(QuestionContainer, Enum): 'rtodling/archive/Restarts/JEDI/541x'), qd.geos_x_ensemble_directory('/discover/nobackup/projects/gmao/dadev/' 'rtodling/archive/541/Milan'), - qd.npx_proc(3), - qd.npy_proc(3), + qd.npx_proc(4), + qd.npy_proc(4), qd.cycle_times(['T00']), qd.background_time_offset("PT3H"), qd.ensemble_num_members(3), qd.skip_ensemble_hofx(True), qd.local_ensemble_solver("Deterministic GETKF"), - qd.local_ensemble_use_linear_observer(False), + qd.local_ensemble_use_linear_observer(True), qd.ensmean_only(False), qd.local_ensemble_save_posterior_mean(True), qd.local_ensemble_save_posterior_mean_increment(True), qd.local_ensemble_save_posterior_ensemble(False), qd.local_ensemble_save_posterior_ensemble_increments(False), - qd.obs_thinning_rej_fraction(0.75), + qd.obs_thinning_rej_fraction(0.9), qd.observations([ + "sondes", + "sfcship", "atms_n20", ]), qd.window_type("3D"), + qd.change_vbc_to_sbc(False), qd.clean_patterns(['*.txt']) ] ) @@ -79,9 +82,9 @@ class SuiteConfig(QuestionContainer, Enum): 'rtodling/archive/Restarts/JEDI/541x'), qd.geos_x_ensemble_directory('/discover/nobackup/projects/gmao/dadev/' 'rtodling/archive/541/Milan'), - qd.npx_proc(4), - qd.npy_proc(4), - # qd.perhost(32), + qd.npx_proc(8), + qd.npy_proc(8), + qd.perhost(96), qd.cycle_times(['T00']), qd.background_time_offset("PT3H"), qd.ensemble_num_members(16), @@ -95,35 +98,12 @@ class SuiteConfig(QuestionContainer, Enum): qd.local_ensemble_save_posterior_ensemble_increments(False), qd.obs_thinning_rej_fraction(0.75), qd.observations([ - "aircraft_temperature", - "aircraft_wind", "sondes", - "gps", - "amsua_aqua", - "amsua_n15", - "amsua_n18", - "amsua_n19", - "amsr2_gcom-w1", - "atms_n20", - "atms_npp", - "avhrr3_metop-b", - "avhrr3_n18", - "avhrr3_n19", - "scatwind", "sfcship", - "sfc", - "mhs_metop-b", - "mhs_metop-c", - "mhs_n19", - "mls55_aura", - "omi_aura", - "ompsnm_npp", - "pibal", - "ssmis_f17", - "amsua_metop-b", - "amsua_metop-c" + "atms_n20", ]), qd.window_type("3D"), + qd.change_vbc_to_sbc(False), qd.clean_patterns(['*.txt']) ] ) diff --git a/src/swell/tasks/run_jedi_etkf_observer.py b/src/swell/tasks/run_jedi_etkf_observer.py new file mode 100644 index 000000000..4b80ed3ec --- /dev/null +++ b/src/swell/tasks/run_jedi_etkf_observer.py @@ -0,0 +1,224 @@ +# (C) Copyright 2021- United States Government as represented by the Administrator of the +# National Aeronautics and Space Administration. All Rights Reserved. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. + + +# -------------------------------------------------------------------------------------------------- + + +import os +import copy +import subprocess +from ruamel.yaml import YAML + +from swell.swell_path import get_swell_path +from swell.tasks.base.task_base import taskBase +from swell.utilities.yaml_utils import replace_key + +# -------------------------------------------------------------------------------------------------- + + +class RunJediEtkfObserver(taskBase): + + # ---------------------------------------------------------------------------------------------- + + def execute(self) -> None: + + # Jedi application name + # --------------------- + jedi_application = 'localensembleda' + + # Parse configuration + # ------------------- + window_type = self.config.window_type() + window_length = self.config.window_length() + background_time_offset = self.config.background_time_offset() + + jedi_forecast_model = self.config.jedi_forecast_model(None) + generate_yaml_and_exit = self.config.generate_yaml_and_exit(False) + change_vbc_to_sbc = self.config.change_vbc_to_sbc(False) + + # Set the observing system records path + self.jedi_rendering.set_obs_records_path(self.config.observing_system_records_path(None)) + + # Compute data assimilation window parameters + background_time = self.da_window_params.background_time(background_time_offset) + local_background_time = self.da_window_params.local_background_time(window_length, + window_type) + local_background_time_iso = self.da_window_params.local_background_time_iso(window_length, + window_type) + window_begin = self.da_window_params.window_begin(window_length) + window_begin_iso = self.da_window_params.window_begin_iso(window_length) + window_end_iso = self.da_window_params.window_end_iso(window_length) + + # Populate jedi interface templates dictionary + # -------------------------------------------- + self.jedi_rendering.add_key('window_begin_iso', window_begin_iso) + self.jedi_rendering.add_key('window_length', window_length) + self.jedi_rendering.add_key('window_end_iso', window_end_iso) + + # Background + self.jedi_rendering.add_key('horizontal_resolution', self.config.horizontal_resolution()) + self.jedi_rendering.add_key('local_background_time', local_background_time) + self.jedi_rendering.add_key('local_background_time_iso', local_background_time_iso) + self.jedi_rendering.add_key('ensemble_num_members', self.config.ensemble_num_members()) + + # Geometry + self.jedi_rendering.add_key('vertical_resolution', self.config.vertical_resolution()) + self.jedi_rendering.add_key('npx_proc', self.config.npx_proc(None)) + self.jedi_rendering.add_key('npy_proc', self.config.npy_proc(None)) + self.jedi_rendering.add_key('total_processors', self.config.total_processors(None)) + + # Observations + self.jedi_rendering.add_key('background_time', background_time) + self.jedi_rendering.add_key('crtm_coeff_dir', self.config.crtm_coeff_dir(None)) + self.jedi_rendering.add_key('window_begin', window_begin) + + # Ensemble Localizations + self.jedi_rendering.add_key('horizontal_localization_method', + self.config.horizontal_localization_method()) + self.jedi_rendering.add_key('horizontal_localization_lengthscale', + self.config.horizontal_localization_lengthscale()) + self.jedi_rendering.add_key('horizontal_localization_max_nobs', + self.config.horizontal_localization_max_nobs()) + self.jedi_rendering.add_key('vertical_localization_method', + self.config.vertical_localization_method()) + self.jedi_rendering.add_key('vertical_localization_apply_log_transform', + self.config.vertical_localization_apply_log_transform()) + self.jedi_rendering.add_key('vertical_localization_lengthscale', + self.config.vertical_localization_lengthscale()) + self.jedi_rendering.add_key('vertical_localization_ioda_vertical_coord', + self.config.vertical_localization_ioda_vertical_coord()) + self.jedi_rendering.add_key('vertical_localization_ioda_vertical_coord_group', + self.config.vertical_localization_ioda_vertical_coord_group()) + self.jedi_rendering.add_key('vertical_localization_function', + self.config.vertical_localization_function()) + + # Driver + self.jedi_rendering.add_key('local_ensemble_solver', self.config.local_ensemble_solver()) + self.jedi_rendering.add_key('local_ensemble_inflation_rtps', + self.config.local_ensemble_inflation_rtps()) + self.jedi_rendering.add_key('local_ensemble_inflation_rtpp', + self.config.local_ensemble_inflation_rtpp()) + self.jedi_rendering.add_key('local_ensemble_inflation_mult', + self.config.local_ensemble_inflation_mult()) + self.jedi_rendering.add_key('local_ensemble_save_posterior_mean', + self.config.local_ensemble_save_posterior_mean()) + self.jedi_rendering.add_key('local_ensemble_save_posterior_ensemble', + self.config.local_ensemble_save_posterior_ensemble()) + self.jedi_rendering.add_key('local_ensemble_save_posterior_mean_increment', + self.config.local_ensemble_save_posterior_mean_increment()) + self.jedi_rendering.add_key('local_ensemble_save_posterior_ensemble_increments', + self.config.local_ensemble_save_posterior_ensemble_increments()) + self.jedi_rendering.add_key('local_ensemble_use_linear_observer', + self.config.local_ensemble_use_linear_observer()) + self.jedi_rendering.add_key('skip_ensemble_hofx', self.config.skip_ensemble_hofx()) + + # Prevent both 'local_ensemble_save_posterior_mean' and + # 'local_ensemble_save_posterior_ensemble' from being true + # -------------------------------------------------------- + if self.config.local_ensemble_save_posterior_mean() and \ + self.config.local_ensemble_save_posterior_ensemble(): + raise ValueError("'local_ensemble_save_posterior_mean' and\ + 'local_ensemble_save_posterior_ensemble' cannot be both true!") + + # Open the JEDI config file and fill initial templates + # ---------------------------------------------------- + jedi_config_dict = self.jedi_rendering.render_oops_file('LocalEnsembleDA', + window_type, + jedi_forecast_model) + + # Assemble localizations + # ---------------------- + # # Vertical localizations have bug(s) - Commented out for now... + # vertLoc = {'localization method': self.config.vertical_localization_method(), + # 'apply log transformation': + # self.config.vertical_localization_apply_log_transform(), + # 'vertical lengthscale': self.config.vertical_localization_lengthscale(), + # 'ioda vertical coordinate': + # self.config.vertical_localization_ioda_vertical_coord(), + # 'ioda vertical coordinate group': + # self.config.vertical_localization_ioda_vertical_coord_group(), + # 'localization function': self.config.vertical_localization_function()} + # localizations = [horizLoc, vertLoc] if len(vertLoc) != 0 else [horizLoc] + + # Include ensemble localizations and halo types with each observation + # ------------------------------------------------------------------- + + swell_path = get_swell_path() + localization_path = os.path.join(swell_path, + f'configuration/jedi/interfaces/geos_atmosphere' + f'/observations/localization') + yaml = YAML() + # update localizations in dict + for observer in jedi_config_dict['observations']['observers']: + # Get observation name + observation_name = observer['observation_name'] + config_file = os.path.join(localization_path, f'{observation_name}.yaml') + with open(config_file, 'r') as f: + loc_list = yaml.load(f) + horizLoc = loc_list['obs localizations'] + localization = [horizLoc] + observer.update({'obs localizations': localization}) + observer['obs space'].update( + {'distribution': {'name': 'RoundRobin', 'halo size': 1500.e3}}) + + # change variational bc to static bc + # ------------------------------------------------------------------- + if change_vbc_to_sbc: + for observer in jedi_config_dict['observations']['observers']: + if 'obs bias' in observer: + observer['obs bias'] = replace_key(observer['obs bias'], + "variational bc", "static bc") + model_component_meta = self.jedi_rendering.render_interface_meta() + jedi_executable = model_component_meta['executables'][f'{jedi_application}'] + jedi_executable_path = os.path.join(self.experiment_path(), 'jedi_bundle', + 'build', 'bin', jedi_executable) + + # seperate each obs and write to disk + # ------------------------------------------------------------------- + driver = jedi_config_dict['driver'] + driver['run as observer only'] = True + driver['read HX from disk'] = False + print(f'driver= {driver}') + + observers = jedi_config_dict["observations"]["observers"] + npx = 1 + npy = 1 + np = 6 * npx * npy + cmd = """ + export SLURM_MPI_TYPE=pmi2 + export I_MPI_PMI_LIBRARY=/usr/lib64/libpmi2.so + """ + cmd += f"cd {self.cycle_dir()} \n" + cmd += f"rm -f log.* logfile* \n" + for i, obs in enumerate(observers): + x0 = copy.deepcopy(jedi_config_dict) + x0["observations"]["observers"] = [obs] + x0['geometry']['layout'] = [npx, npy] + observation_name = obs['observation_name'] + tmp_file1 = os.path.join(self.cycle_dir(), f'diag_{observation_name}.yaml') + tmp_file2 = os.path.join(self.cycle_dir(), f'log.diag_{observation_name}') + with open(tmp_file1, "w") as f: + yaml.dump(x0, f) + cmd += ( + f"srun --exclusive --mpi=pmi2 -n {np} " + f"{jedi_executable_path} {tmp_file1} {tmp_file2} &\n" + ) + cmd += f"wait \n" + print(f'nobs = {i+1}') + np_use = (i+1) * np + np_total = eval(str(model_component_meta['total_processors'])) + error_msg = f'{i+1} obs: each {np} cores, np_use: {np_use} vs np_avail: {np_total}' + assert np_use <= np_total, error_msg + + if not generate_yaml_and_exit: + subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, + stderr=subprocess.PIPE, check=True) + else: + print(f'intended mpi_command = {cmd}') + self.logger.info('YAML generated, now exiting.') + +# -------------------------------------------------------------------------------------------------- diff --git a/src/swell/tasks/run_jedi_etkf_solver.py b/src/swell/tasks/run_jedi_etkf_solver.py new file mode 100644 index 000000000..0e7bf6245 --- /dev/null +++ b/src/swell/tasks/run_jedi_etkf_solver.py @@ -0,0 +1,220 @@ +# (C) Copyright 2021- United States Government as represented by the Administrator of the +# National Aeronautics and Space Administration. All Rights Reserved. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. + + +# -------------------------------------------------------------------------------------------------- + + +import os +from ruamel.yaml import YAML + +from swell.swell_path import get_swell_path +from swell.tasks.base.task_base import taskBase +from swell.utilities.run_jedi_executables import run_executable +from swell.utilities.yaml_utils import replace_key + +# -------------------------------------------------------------------------------------------------- + + +class RunJediEtkfSolver(taskBase): + + # ---------------------------------------------------------------------------------------------- + + def execute(self) -> None: + + # Jedi application name + # --------------------- + jedi_application = 'localensembleda' + + # Parse configuration + # ------------------- + window_type = self.config.window_type() + window_length = self.config.window_length() + background_time_offset = self.config.background_time_offset() + change_vbc_to_sbc = self.config.change_vbc_to_sbc(False) + + jedi_forecast_model = self.config.jedi_forecast_model(None) + generate_yaml_and_exit = self.config.generate_yaml_and_exit(False) + + # Set the observing system records path + self.jedi_rendering.set_obs_records_path(self.config.observing_system_records_path(None)) + + # Compute data assimilation window parameters + background_time = self.da_window_params.background_time(background_time_offset) + local_background_time = self.da_window_params.local_background_time(window_length, + window_type) + local_background_time_iso = self.da_window_params.local_background_time_iso(window_length, + window_type) + window_begin = self.da_window_params.window_begin(window_length) + window_begin_iso = self.da_window_params.window_begin_iso(window_length) + window_end_iso = self.da_window_params.window_end_iso(window_length) + + # Populate jedi interface templates dictionary + # -------------------------------------------- + self.jedi_rendering.add_key('window_begin_iso', window_begin_iso) + self.jedi_rendering.add_key('window_length', window_length) + self.jedi_rendering.add_key('window_end_iso', window_end_iso) + + # Background + self.jedi_rendering.add_key('horizontal_resolution', self.config.horizontal_resolution()) + self.jedi_rendering.add_key('local_background_time', local_background_time) + self.jedi_rendering.add_key('local_background_time_iso', local_background_time_iso) + self.jedi_rendering.add_key('ensemble_num_members', self.config.ensemble_num_members()) + + # Geometry + self.jedi_rendering.add_key('vertical_resolution', self.config.vertical_resolution()) + self.jedi_rendering.add_key('npx_proc', self.config.npx_proc(None)) + self.jedi_rendering.add_key('npy_proc', self.config.npy_proc(None)) + self.jedi_rendering.add_key('total_processors', self.config.total_processors(None)) + + # Observations + self.jedi_rendering.add_key('background_time', background_time) + self.jedi_rendering.add_key('crtm_coeff_dir', self.config.crtm_coeff_dir(None)) + self.jedi_rendering.add_key('window_begin', window_begin) + + # Ensemble Localizations + self.jedi_rendering.add_key('horizontal_localization_method', + self.config.horizontal_localization_method()) + self.jedi_rendering.add_key('horizontal_localization_lengthscale', + self.config.horizontal_localization_lengthscale()) + self.jedi_rendering.add_key('horizontal_localization_max_nobs', + self.config.horizontal_localization_max_nobs()) + self.jedi_rendering.add_key('vertical_localization_method', + self.config.vertical_localization_method()) + self.jedi_rendering.add_key('vertical_localization_apply_log_transform', + self.config.vertical_localization_apply_log_transform()) + self.jedi_rendering.add_key('vertical_localization_lengthscale', + self.config.vertical_localization_lengthscale()) + self.jedi_rendering.add_key('vertical_localization_ioda_vertical_coord', + self.config.vertical_localization_ioda_vertical_coord()) + self.jedi_rendering.add_key('vertical_localization_ioda_vertical_coord_group', + self.config.vertical_localization_ioda_vertical_coord_group()) + self.jedi_rendering.add_key('vertical_localization_function', + self.config.vertical_localization_function()) + + # Driver + self.jedi_rendering.add_key('local_ensemble_solver', self.config.local_ensemble_solver()) + self.jedi_rendering.add_key('local_ensemble_inflation_rtps', + self.config.local_ensemble_inflation_rtps()) + self.jedi_rendering.add_key('local_ensemble_inflation_rtpp', + self.config.local_ensemble_inflation_rtpp()) + self.jedi_rendering.add_key('local_ensemble_inflation_mult', + self.config.local_ensemble_inflation_mult()) + self.jedi_rendering.add_key('local_ensemble_save_posterior_mean', + self.config.local_ensemble_save_posterior_mean()) + self.jedi_rendering.add_key('local_ensemble_save_posterior_ensemble', + self.config.local_ensemble_save_posterior_ensemble()) + self.jedi_rendering.add_key('local_ensemble_save_posterior_mean_increment', + self.config.local_ensemble_save_posterior_mean_increment()) + self.jedi_rendering.add_key('local_ensemble_save_posterior_ensemble_increments', + self.config.local_ensemble_save_posterior_ensemble_increments()) + self.jedi_rendering.add_key('local_ensemble_use_linear_observer', + self.config.local_ensemble_use_linear_observer()) + self.jedi_rendering.add_key('skip_ensemble_hofx', self.config.skip_ensemble_hofx()) + + # Prevent both 'local_ensemble_save_posterior_mean' and + # 'local_ensemble_save_posterior_ensemble' from being true + # -------------------------------------------------------- + if self.config.local_ensemble_save_posterior_mean() and \ + self.config.local_ensemble_save_posterior_ensemble(): + raise ValueError("'local_ensemble_save_posterior_mean' and\ + 'local_ensemble_save_posterior_ensemble' cannot be both true!") + + # Jedi configuration file + # ----------------------- + jedi_config_file = os.path.join(self.cycle_dir(), f'jedi_etkf_solver_config.yaml') + + # Output log file + # --------------- + output_log_file = os.path.join(self.cycle_dir(), f'jedi_etkf_solver_log.log') + + # Open the JEDI config file and fill initial templates + # ---------------------------------------------------- + jedi_config_dict = self.jedi_rendering.render_oops_file('LocalEnsembleDA', + window_type, + jedi_forecast_model) + + # Assemble localizations + # ---------------------- + # # Vertical localizations have bug(s) - Commented out for now... + # vertLoc = {'localization method': self.config.vertical_localization_method(), + # 'apply log transformation': + # self.config.vertical_localization_apply_log_transform(), + # 'vertical lengthscale': self.config.vertical_localization_lengthscale(), + # 'ioda vertical coordinate': + # self.config.vertical_localization_ioda_vertical_coord(), + # 'ioda vertical coordinate group': + # self.config.vertical_localization_ioda_vertical_coord_group(), + # 'localization function': self.config.vertical_localization_function()} + # localizations = [horizLoc, vertLoc] if len(vertLoc) != 0 else [horizLoc] + + # Include ensemble localizations and halo types with each observation + # ------------------------------------------------------------------- + + swell_path = get_swell_path() + localization_path = os.path.join(swell_path, + f'configuration/jedi/interfaces/geos_atmosphere' + f'/observations/localization') + yaml = YAML() + # update localizations in dict + for observer in jedi_config_dict['observations']['observers']: + # Get observation name + observation_name = observer['observation_name'] + config_file = os.path.join(localization_path, f'{observation_name}.yaml') + with open(config_file, 'r') as f: + loc_list = yaml.load(f) + horizLoc = loc_list['obs localizations'] + localization = [horizLoc] + observer.update({'obs localizations': localization}) + observer['obs space'].update( + {'distribution': {'name': 'Halo', 'halo size': 1500.e3}}) + + # change variational bc to static bc + # ------------------------------------------------------------------- + if change_vbc_to_sbc: + for observer in jedi_config_dict['observations']['observers']: + if 'obs bias' in observer: + observer['obs bias'] = replace_key(observer['obs bias'], + "variational bc", "static bc") + + driver = jedi_config_dict['driver'] + driver['read HX from disk'] = True + driver['run as observer only'] = False + print(f'driver= {driver}') + + observers = jedi_config_dict["observations"]["observers"] + for i, obs in enumerate(observers): + observation_name = obs['observation_name'] + obs_file_read = obs['obs space']['obsdataout']['engine']['obsfile'] + print(f'\n obs_file_read = {obs_file_read}') + obs['obs space']['obsdatain']['engine']['obsfile'] = obs_file_read + dir_path = os.path.dirname(obs_file_read) + file_name = os.path.basename(obs_file_read) + obs['obs space']['obsdataout']['engine']['obsfile'] = ( + os.path.join(dir_path, 'solver.' + file_name) + ) + + with open(jedi_config_file, 'w') as f: + yaml.dump(jedi_config_dict, f) + + model_component_meta = self.jedi_rendering.render_interface_meta() + jedi_executable = model_component_meta['executables'][f'{jedi_application}'] + jedi_executable_path = os.path.join(self.experiment_path(), 'jedi_bundle', 'build', 'bin', + jedi_executable) + np = eval(str(model_component_meta['total_processors'])) + perhost = self.config.perhost(None) + if not generate_yaml_and_exit: + run_executable(self.logger, self.cycle_dir(), np, jedi_executable_path, + jedi_config_file, output_log_file, perhost=perhost) + else: + mpi_command = "mpirun" + if not (perhost is None or perhost == "None"): + mpi_command += f" -perhost {perhost}" + mpi_command += f" -np {np} {jedi_executable_path} {jedi_config_file} {output_log_file}" + print(f'intended mpi_command = {mpi_command}') + self.logger.info('YAML generated, now exiting.') + +# -------------------------------------------------------------------------------------------------- diff --git a/src/swell/tasks/run_jedi_local_ensemble_da_executable.py b/src/swell/tasks/run_jedi_local_ensemble_da_executable.py index b0d67a5b8..e01ead301 100644 --- a/src/swell/tasks/run_jedi_local_ensemble_da_executable.py +++ b/src/swell/tasks/run_jedi_local_ensemble_da_executable.py @@ -13,27 +13,12 @@ from swell.swell_path import get_swell_path from swell.tasks.base.task_base import taskBase +from swell.utilities.yaml_utils import replace_key from swell.utilities.run_jedi_executables import run_executable # -------------------------------------------------------------------------------------------------- -def replace_key(obj, old_key, new_key): - """ - Recursively replace dictionary keys in nested dictionaries/lists. - """ - if isinstance(obj, dict): - new_dict = {} - for k, v in obj.items(): - new_k = new_key if k == old_key else k - new_dict[new_k] = replace_key(v, old_key, new_key) - return new_dict - elif isinstance(obj, list): - return [replace_key(item, old_key, new_key) for item in obj] - else: - return obj - - class RunJediLocalEnsembleDaExecutable(taskBase): # ---------------------------------------------------------------------------------------------- @@ -55,6 +40,7 @@ def execute(self) -> None: generate_yaml_and_exit = self.config.generate_yaml_and_exit(False) ensmean_only = self.config.ensmean_only() ensmeanvariance_only = self.config.ensmeanvariance_only() + change_vbc_to_sbc = self.config.change_vbc_to_sbc(False) # Set the observing system records path self.jedi_rendering.set_obs_records_path(self.config.observing_system_records_path(None)) @@ -208,10 +194,11 @@ def execute(self) -> None: # change variational bc to static bc # ------------------------------------------------------------------- - for observer in jedi_config_dict['observations']['observers']: - if 'obs bias' in observer: - observer['obs bias'] = replace_key(observer['obs bias'], - "variational bc", "static bc") + if change_vbc_to_sbc: + for observer in jedi_config_dict['observations']['observers']: + if 'obs bias' in observer: + observer['obs bias'] = replace_key(observer['obs bias'], + "variational bc", "static bc") # Write the expanded dictionary to YAML file (in rt mode) # ------------------------------------------ diff --git a/src/swell/tasks/run_jedi_variational_executable.py b/src/swell/tasks/run_jedi_variational_executable.py index 3c9e0b5a8..d129e35e8 100644 --- a/src/swell/tasks/run_jedi_variational_executable.py +++ b/src/swell/tasks/run_jedi_variational_executable.py @@ -152,6 +152,11 @@ def execute(self) -> None: run_executable(self.logger, self.cycle_dir(), np, jedi_executable_path, jedi_config_file, output_log_file, perhost) else: + mpi_command = "mpirun" + if not (perhost is None or perhost == "None"): + mpi_command += f" -perhost {perhost}" + mpi_command += f" -np {np} {jedi_executable_path} {jedi_config_file} {output_log_file}" + print(f'intended mpi_command = {mpi_command}') self.logger.info('YAML generated, now exiting.') # -------------------------------------------------------------------------------------------------- diff --git a/src/swell/tasks/task_questions.py b/src/swell/tasks/task_questions.py index 8347850a7..5643a98ce 100644 --- a/src/swell/tasks/task_questions.py +++ b/src/swell/tasks/task_questions.py @@ -693,6 +693,7 @@ class TaskQuestions(QuestionContainer, Enum): qd.vertical_localization_lengthscale(), qd.vertical_localization_method(), qd.perhost(), + qd.change_vbc_to_sbc(), qd.comparison_log_type('localensembleda'), ] ) @@ -717,6 +718,80 @@ class TaskQuestions(QuestionContainer, Enum): # -------------------------------------------------------------------------------------------------- + RunJediEtkfObserver = QuestionList( + list_name="RunJediEtkfObserver", + questions=[ + np_proc_resolution, + window_questions, + background_crtm_obs, + qd.ensemble_num_members(), + qd.generate_yaml_and_exit(), + qd.horizontal_localization_lengthscale(), + qd.horizontal_localization_max_nobs(), + qd.horizontal_localization_method(), + qd.jedi_forecast_model(), + qd.local_ensemble_inflation_mult(), + qd.local_ensemble_inflation_rtpp(), + qd.local_ensemble_inflation_rtps(), + qd.local_ensemble_save_posterior_ensemble(), + qd.local_ensemble_save_posterior_ensemble_increments(), + qd.local_ensemble_save_posterior_mean(), + qd.local_ensemble_save_posterior_mean_increment(), + qd.local_ensemble_solver(), + qd.local_ensemble_use_linear_observer(), + qd.skip_ensemble_hofx(), + qd.total_processors(), + qd.vertical_localization_apply_log_transform(), + qd.vertical_localization_function(), + qd.vertical_localization_ioda_vertical_coord(), + qd.vertical_localization_ioda_vertical_coord_group(), + qd.vertical_localization_lengthscale(), + qd.vertical_localization_method(), + qd.perhost(), + qd.change_vbc_to_sbc(), + qd.comparison_log_type('localensembleda'), + ] + ) + + # -------------------------------------------------------------------------------------------------- + + RunJediEtkfSolver = QuestionList( + list_name="RunJediEtkfSolver", + questions=[ + np_proc_resolution, + window_questions, + background_crtm_obs, + qd.ensemble_num_members(), + qd.generate_yaml_and_exit(), + qd.horizontal_localization_lengthscale(), + qd.horizontal_localization_max_nobs(), + qd.horizontal_localization_method(), + qd.jedi_forecast_model(), + qd.local_ensemble_inflation_mult(), + qd.local_ensemble_inflation_rtpp(), + qd.local_ensemble_inflation_rtps(), + qd.local_ensemble_save_posterior_ensemble(), + qd.local_ensemble_save_posterior_ensemble_increments(), + qd.local_ensemble_save_posterior_mean(), + qd.local_ensemble_save_posterior_mean_increment(), + qd.local_ensemble_solver(), + qd.local_ensemble_use_linear_observer(), + qd.skip_ensemble_hofx(), + qd.total_processors(), + qd.vertical_localization_apply_log_transform(), + qd.vertical_localization_function(), + qd.vertical_localization_ioda_vertical_coord(), + qd.vertical_localization_ioda_vertical_coord_group(), + qd.vertical_localization_lengthscale(), + qd.vertical_localization_method(), + qd.perhost(), + qd.change_vbc_to_sbc(), + qd.comparison_log_type('localensembleda'), + ] + ) + + # -------------------------------------------------------------------------------------------------- + RunJediUfoTestsExecutable = QuestionList( list_name="RunJediUfoTestsExecutable", questions=[ diff --git a/src/swell/utilities/question_defaults.py b/src/swell/utilities/question_defaults.py index 987f64322..b60e690e0 100644 --- a/src/swell/utilities/question_defaults.py +++ b/src/swell/utilities/question_defaults.py @@ -1006,6 +1006,19 @@ class local_ensemble_use_linear_observer(TaskQuestion): # -------------------------------------------------------------------------------------------------- + @dataclass + class change_vbc_to_sbc(TaskQuestion): + default_value: str = "defer_to_model" + question_name: str = "change_vbc_to_sbc" + options: str = "defer_to_model" + models: List[str] = mutable_field([ + "geos_atmosphere" + ]) + prompt: str = "Shall variational bc be changed to static bc in local ensemble DA yaml?" + widget_type: WType = WType.BOOLEAN + + # -------------------------------------------------------------------------------------------------- + @dataclass class minimizer(TaskQuestion): default_value: str = "defer_to_model" diff --git a/src/swell/utilities/slurm.py b/src/swell/utilities/slurm.py index 787650d81..78ac48aab 100644 --- a/src/swell/utilities/slurm.py +++ b/src/swell/utilities/slurm.py @@ -46,7 +46,9 @@ def prepare_scheduling_dict( task_defaults = { "RunJediVariationalExecutable": {"all": {"nodes": 3}}, "RunJediUfoTestsExecutable": {"all": {"ntasks-per-node": 1}}, - "RunJediConvertStateSoca2ciceExecutable": {"all": {"nodes": 1}} + "RunJediConvertStateSoca2ciceExecutable": {"all": {"nodes": 1}}, + "RunJediEtkfObserver": {"all": {"nodes": 4}}, + "RunJediEtkfSolver": {"all": {"nodes": 4}} } # Global SLURM settings stored in $HOME/.swell/swell-slurm.yaml @@ -85,6 +87,8 @@ def prepare_scheduling_dict( 'RunJediHofxEnsembleExecutable', 'RunJediHofxExecutable', 'RunJediLocalEnsembleDaExecutable', + 'RunJediEtkfObserver', + 'RunJediEtkfSolver', 'RunJediObsfiltersExecutable', 'RunJediUfoTestsExecutable', 'RunJediVariationalExecutable', diff --git a/src/swell/utilities/yaml_utils.py b/src/swell/utilities/yaml_utils.py new file mode 100644 index 000000000..fc73d28e8 --- /dev/null +++ b/src/swell/utilities/yaml_utils.py @@ -0,0 +1,19 @@ + +# -------------------------------------------------------------------------------------------------- + +def replace_key(obj, old_key, new_key): + """ + Recursively replace dictionary keys in nested dictionaries/lists. + """ + if isinstance(obj, dict): + new_dict = {} + for k, v in obj.items(): + new_k = new_key if k == old_key else k + new_dict[new_k] = replace_key(v, old_key, new_key) + return new_dict + elif isinstance(obj, list): + return [replace_key(item, old_key, new_key) for item in obj] + else: + return obj + +# --------------------------------------------------------------------------------------------------