diff --git a/GFED5/convert.py b/GFED5/convert.py index d6ba355..62b730b 100644 --- a/GFED5/convert.py +++ b/GFED5/convert.py @@ -130,7 +130,6 @@ ###################################################################### # Set dimension attributes and encoding -ds = hf.set_time_attrs(ds) ds = hf.set_lat_attrs(ds) ds = hf.set_lon_attrs(ds) @@ -148,6 +147,7 @@ # Add time bounds ds = hf.add_time_bounds_monthly(ds) +ds = hf.set_time_attrs(ds) time_range = f"{ds['time'].min().dt.year:d}{ds['time'].min().dt.month:02d}" time_range += f"-{ds['time'].max().dt.year:d}{ds['time'].max().dt.month:02d}" diff --git a/HWSD2/convert.py b/HWSD2/convert.py index 93adc7a..9f1058a 100644 --- a/HWSD2/convert.py +++ b/HWSD2/convert.py @@ -1,270 +1,361 @@ +import datetime import os +import sqlite3 import subprocess +import sys import time -import datetime -import xarray as xr -import rioxarray as rxr +import warnings + import numpy as np -import cftime as cf -from osgeo import gdal import pandas as pd -import sqlite3 -import warnings +import rioxarray as rxr +import xarray as xr from dask.distributed import Client, LocalCluster +from osgeo import gdal + +# Determine the parent directory (ILAMB-DATA) +project_dir = os.path.abspath(os.path.join(os.getcwd(), "..")) +if project_dir not in sys.path: + sys.path.insert(0, project_dir) + +# Now you can import the helper_funcs module from the scripts package. +from scripts import biblatex_builder as bb +from scripts import helper_funcs as hf ##################################################### -# set the parameters for this particular dataset +# Set Parameters ##################################################### # main parameters -chunksize = 3000 -var = 'cSoil' -long_name = 'carbon mass in soil pool' -layers = ['D1', 'D2', 'D3', 'D4', 'D5', 'D6', 'D7'] -pools = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] -sdate = datetime.datetime(1960, 1, 1) -edate = datetime.datetime(2022, 1, 1) +VAR = "cSoil" +# VAR = "cSoilAbove1m" +LAYERS = ["D1", "D2", "D3", "D4", "D5", "D6", "D7"] # cSoil +# LAYERS = ["D1", "D2", "D3", "D4", "D5"] # cSoilAbove1m +POOLS = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] # soil types +SDATE = datetime.datetime(1960, 1, 1) +EDATE = datetime.datetime(2022, 1, 1) # dask parameters -- adjust these to fit your computer's capabilities # chatgpt can optimize n_workers, n_threads, and mem_limit if you provide your computer specs! 
-n_workers = 20 -n_threads = 1 -mem_limit = '3.5GB' +CHUNKSIZE = 3000 +N_WORKERS = 4 +N_THREADS = 2 +MEM_LIMIT = "16GB" # paths to files -remote_rast = 'https://s3.eu-west-1.amazonaws.com/data.gaezdev.aws.fao.org/HWSD/HWSD2_RASTER.zip' -local_rast = 'HWSD2_RASTER/HWSD2.bil' -remote_data = 'https://www.isric.org/sites/default/files/HWSD2.sqlite' -local_data = 'HWSD2.sqlite' -github_path = 'https://github.com/rubisco-sfa/ILAMB-Data/blob/master/HWSD2/convert.py' +REMOTE_RAST = ( + "https://s3.eu-west-1.amazonaws.com/data.gaezdev.aws.fao.org/HWSD/HWSD2_RASTER.zip" +) +LOCAL_RAST = "HWSD2_RASTER/HWSD2.bil" +REMOTE_DATA = "https://www.isric.org/sites/default/files/HWSD2.sqlite" +LOCAL_DATA = "HWSD2.sqlite" +GITHUB_PATH = "https://github.com/rubisco-sfa/ILAMB-Data/blob/master/HWSD2/convert.py" # suppress specific warnings -warnings.filterwarnings('ignore', message='invalid value encountered in cast') +warnings.filterwarnings("ignore", message="invalid value encountered in cast") gdal.DontUseExceptions() ##################################################### # functions in the order that they are used in main() ##################################################### + # 1. download raster and sql database to connect to raster def download_data(remote_rast, remote_data): # check for raster directory rast_dir = os.path.splitext(os.path.basename(remote_rast))[0] - if not os.path.isdir(rast_dir) or not any(fname.endswith('.bil') for fname in os.listdir(rast_dir)): - subprocess.run(['mkdir', rast_dir]) - subprocess.run(['curl', '-L', remote_rast, '-o', os.path.basename(remote_rast)]) - subprocess.run(['unzip', os.path.basename(remote_rast), '-d', rast_dir]) + if not os.path.isdir(rast_dir) or not any( + fname.endswith(".bil") for fname in os.listdir(rast_dir) + ): + subprocess.run(["mkdir", rast_dir]) + subprocess.run(["curl", "-L", remote_rast, "-o", os.path.basename(remote_rast)]) + subprocess.run(["unzip", os.path.basename(remote_rast), "-d", rast_dir]) # check for database sql_database = os.path.basename(remote_data) if not os.path.isfile(sql_database): - subprocess.run(['curl', '-L', remote_data, '-o', sql_database]) + subprocess.run(["curl", "-L", remote_data, "-o", sql_database]) else: - print(f'Raster {rast_dir} and Database {sql_database} are already downloaded to current directory.') + print( + f"Raster {rast_dir} and Database {sql_database} are already downloaded to current directory." + ) + # 2. initialize the dask multiprocessing client; the link can be used to track worker progress def initialize_client(n_workers, n_threads, mem_limit): - cluster = LocalCluster(n_workers=n_workers, - threads_per_worker=n_threads, - memory_limit=mem_limit) + cluster = LocalCluster( + n_workers=n_workers, threads_per_worker=n_threads, memory_limit=mem_limit + ) client = Client(cluster) - print(f'Dask dashboard link: {client.dashboard_link}') + print(f"Dask dashboard link: {client.dashboard_link}") return client + # 3. load the raster we use to connect with HWSDv2 data def load_raster(path, chunksize): - rast = rxr.open_rasterio(path, band_as_variable=True, - mask_and_scale=True, - chunks={'x': chunksize, 'y': chunksize}) - rast = rast.astype('int16').drop_vars('spatial_ref').rename_vars(band_1='HWSD2_SMU_ID') + rast = rxr.open_rasterio( + path, + band_as_variable=True, + mask_and_scale=True, + chunks={"x": chunksize, "y": chunksize}, + ) + rast = ( + rast.astype("int16").drop_vars("spatial_ref").rename_vars(band_1="HWSD2_SMU_ID") + ) return rast + # 4. 
load the table with data from the sqlite database def load_layer_table(db_path, table_name): conn = sqlite3.connect(db_path) - query = f'SELECT * FROM {table_name}' + query = f"SELECT * FROM {table_name}" layer_df = pd.read_sql_query(query, conn) conn.close() return layer_df + # 5(a). function to calculate carbon stock -def calculate_stock(df, depth, bulk_density_g_cm3, cf, organic_carbon): - df['stock'] = ( - df[bulk_density_g_cm3] * - (1 - df[cf] / 100) * - df[depth] * 0.01 * - df[organic_carbon] +def calculate_stock( + df, top_depth, bottom_depth, bulk_density_g_cm3, cf, organic_carbon +): + thickness_cm = df[bottom_depth] - df[top_depth] + df["stock"] = ( + (df[bulk_density_g_cm3] * 1000) # g to kg + * (1 - df[cf] / 100) + * (thickness_cm * 0.01) # cm to meter + * (df[organic_carbon] / 100) ) - return df['stock'] + return df["stock"] + # 5(b). function to calculate weighted mean def weighted_mean(values, weights): return (values * weights).sum() / weights.sum() + # 5. process each soil layer by selecting the layer & pools of interest, # removing erroneous negative values, calculating C stock, and getting # the weighted mean of the pools def process_layers(layer_df, layers, pools, var): dfs = [] for layer in layers: - sel = layer_df[['HWSD2_SMU_ID', 'LAYER', 'SEQUENCE', 'ORG_CARBON', - 'BULK', 'BOTDEP', 'TOPDEP', 'COARSE', 'SHARE']] - df = sel[sel['LAYER'] == layer].drop(columns=['LAYER']) - df = df[df['SEQUENCE'].isin(pools)] - for attr in ['ORG_CARBON', 'BULK', 'SHARE']: + sel = layer_df[ + [ + "HWSD2_SMU_ID", + "LAYER", + "SEQUENCE", + "ORG_CARBON", + "BULK", + "BOTDEP", + "TOPDEP", + "COARSE", + "SHARE", + ] + ] + df = sel[sel["LAYER"] == layer].drop(columns=["LAYER"]) + df = df[df["SEQUENCE"].isin(pools)] + for attr in ["ORG_CARBON", "BULK", "SHARE"]: df[attr] = df[attr].where(df[attr] > 0, np.nan) - df[var] = calculate_stock(df, 'BOTDEP', 'BULK', 'COARSE', 'ORG_CARBON') - grouped = df.groupby('HWSD2_SMU_ID').apply( - lambda x: pd.Series({ - var: weighted_mean(x['ORG_CARBON'], x['SHARE']) - }), include_groups=False - ).reset_index() + df[var] = calculate_stock( + df, "TOPDEP", "BOTDEP", "BULK", "COARSE", "ORG_CARBON" + ) + grouped = ( + df.groupby("HWSD2_SMU_ID") + .apply( + lambda x: pd.Series({var: weighted_mean(x["stock"], x["SHARE"])}), + include_groups=False, + ) + .reset_index() + ) dfs.append(grouped) return dfs + # 6. combine all the layers by summing, and set the data types def combine_and_summarize(dfs, var): total_df = pd.concat(dfs) - total_df = total_df.groupby('HWSD2_SMU_ID')[var].agg('sum').reset_index(drop=False) - total_df['HWSD2_SMU_ID'] = total_df['HWSD2_SMU_ID'].astype('int16') - total_df[var] = total_df[var].astype('float32') + total_df = total_df.groupby("HWSD2_SMU_ID")[var].agg("sum").reset_index(drop=False) + total_df["HWSD2_SMU_ID"] = total_df["HWSD2_SMU_ID"].astype("int16") + total_df[var] = total_df[var].astype("float32") return total_df + # 7(a). function to map the soil unit ID to the cSoil variable def map_uid_to_var(uid, uid_to_var): - return uid_to_var.get(uid, float('nan')) + return uid_to_var.get(uid, float("nan")) + # 7. 
create a variable in the rast dataset containing cSoil data def apply_mapping(rast, total_df, var): - uid_to_var = total_df.set_index('HWSD2_SMU_ID')[var].to_dict() + uid_to_var = total_df.set_index("HWSD2_SMU_ID")[var].to_dict() mapped_orgc = xr.apply_ufunc( - map_uid_to_var, - rast['HWSD2_SMU_ID'], - input_core_dims=[[]], - vectorize=True, - dask='parallelized', - output_dtypes=['float32'], - kwargs={'uid_to_var': uid_to_var} + map_uid_to_var, + rast["HWSD2_SMU_ID"], + input_core_dims=[[]], + vectorize=True, + dask="parallelized", + output_dtypes=["float32"], + kwargs={"uid_to_var": uid_to_var}, ) rast = rast.assign({var: mapped_orgc}) return rast + # 8. save the rast dataset as a tif def save_raster(rast, var, layers, pools): - output_path = f'hwsd2_{var}_{layers[0]}-{layers[-1]}_seq{pools[0]}-{pools[-1]}.tif' + output_path = f"hwsd2_{var}_{layers[0]}-{layers[-1]}_seq{pools[0]}-{pools[-1]}.tif" rast[[var]].rio.to_raster(output_path) return output_path + # 9. resample the 250m resolution to 0.5deg resolution def resample_raster(input_path, output_path, xres, yres, interp, nan): - gdal.SetConfigOption('GDAL_CACHEMAX', '500') - ds = gdal.Warp(output_path, input_path, - xRes=xres, yRes=yres, - resampleAlg=interp, - outputType=gdal.GDT_Float32, - dstNodata=nan, - outputBounds=(-180.0, -90.0, 180.0, 90.0)) + gdal.SetConfigOption("GDAL_CACHEMAX", "500") + ds = gdal.Warp( + output_path, + input_path, + xRes=xres, + yRes=yres, + resampleAlg=interp, + outputType=gdal.GDT_Float32, + dstNodata=nan, + outputBounds=(-180.0, -90.0, 180.0, 90.0), + ) del ds -# 10. create a netcdf of the 0.5deg resolution raster -def create_netcdf(input_path, output_path, var, sdate, edate, long_name): +# 10. create a netcdf of the 0.5deg resolution raster +def create_netcdf( + input_path, var, sdate, edate, local_data, remote_data, github_path, pools, layers +): # open the .tif file - csoil = rxr.open_rasterio(input_path, band_as_variable=True, mask_and_scale=True) - - # rename the bands - csoil = csoil.rename({'x': 'lon', 'y': 'lat', 'band_1': var}) + ds = rxr.open_rasterio(input_path, band_as_variable=True, mask_and_scale=True) + ds = ds.rename({"x": "lon", "y": "lat", "band_1": var}) # create time dimension - tb_arr = np.asarray([ - [cf.DatetimeNoLeap(sdate.year, sdate.month, sdate.day)], - [cf.DatetimeNoLeap(edate.year, edate.month, edate.day)] - ]).T - tb_da = xr.DataArray(tb_arr, dims=('time', 'nv')) - csoil = csoil.expand_dims(time=tb_da.mean(dim='nv')) - csoil['time_bounds'] = tb_da - - # dictionaries for formatting each dimension and variable - t_attrs = {'axis': 'T', 'long_name': 'time'} - y_attrs = {'axis': 'Y', 'long_name': 'latitude', 'units': 'degrees_north'} - x_attrs = {'axis': 'X', 'long_name': 'longitude', 'units': 'degrees_east'} - v_attrs = {'long_name': long_name, 'units': 'kg m-2'} - - # set the formats - csoil['time'].attrs = t_attrs - csoil['time_bounds'].attrs['long_name'] = 'time_bounds' - csoil['lat'].attrs = y_attrs - csoil['lon'].attrs = x_attrs - csoil[var].attrs = v_attrs - - # encode time information (necessary for export) - csoil['time'].encoding['units'] = f'days since {sdate.strftime("%Y-%m-%d %H:%M:%S")}' - csoil['time'].encoding['calendar'] = 'noleap' - csoil['time'].encoding['bounds'] = 'time_bounds' - csoil['time_bounds'].encoding['units'] = f'days since {sdate.strftime("%Y-%m-%d %H:%M:%S")}' + ds = hf.add_time_bounds_single(ds, sdate, edate) + ds = hf.set_time_attrs(ds) + ds = hf.set_lat_attrs(ds) + ds = hf.set_lon_attrs(ds) + + # Get variable attribute info via ESGF CMIP 
variable information
+    info = hf.get_cmip6_variable_info(var)
+
+    # Set variable attributes
+    ds = hf.set_var_attrs(
+        ds,
+        var=var,
+        units=info["variable_units"],
+        standard_name=info["cf_standard_name"],
+        long_name=info["variable_long_name"],
+    )
 
     # create the global attributes
-    generate_stamp = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(os.path.getmtime(local_data)))
-    csoil.attrs = {
-        'title': f'Harmonized World Soil Database version 2.0 (HWSD v2.0) {long_name}',
-        'institution': 'International Soil Reference and Information Centre (ISRIC)',
-        'source': 'Harmonized international soil profiles from WISE30sec 2015 with 7 soil layers and expanded soil attributes',
-        'history': f"""
+    generate_stamp = time.strftime(
+        "%Y-%m-%d %H:%M:%S", time.localtime(os.path.getmtime(local_data))
+    )
+
+    # create correctly formatted citation
+    data_citation = bb.generate_biblatex_techreport(
+        cite_key="Nachtergaele2023",
+        author=[
+            "Nachtergaele, Freddy",
+            "van Velthuizen, Harrij",
+            "Verelst, Luc",
+            "Wiberg, Dave",
+            "Henry, Matieu",
+            "Chiozza, Federica",
+            "Yigini, Yusuf",
+            "Aksoy, Ece",
+            "Batjes, Niels",
+            "Boateng, Enoch",
+            "Fischer, Günther",
+            "Jones, Arwyn",
+            "Montanarella, Luca",
+            "Shi, Xuezheng",
+            "Tramberend, Sylvia",
+        ],
+        title="Harmonized World Soil Database",
+        institution="Food and Agriculture Organization of the United Nations and International Institute for Applied Systems Analysis, Rome and Laxenburg",
+        year=2023,
+        number="version 2.0",
+    )
+
+    history = f"""
 {generate_stamp}: downloaded source from {remote_data}
 {generate_stamp}: filtered data to soil dominance sequence(s) {pools}; where 1 is the dominant soil type
 {generate_stamp}: masked invalid negative organic_carbon_pct_wt and bulk_density_g_cm3 with np.nan
-{generate_stamp}: calculated cSoilLevels in kg m-2 for each level {layers}: bulk_density_g_cm3 * 10 * (1 - coarse_fragment_pct_vol / 100) * bottom_depth_cm / 10 * organic_carbon_pct_wt / 100)
-{generate_stamp}: calculated {var} by getting the weighted mean of all pools in a level and summing {layers} cSoilLevels
+{generate_stamp}: calculated cSoilLevels in kg m-2 for each level {layers}: (bulk_density_g_cm3 * 1000) * (1 - coarse_fragment_pct_vol / 100) * (thickness_cm * 0.01) * (organic_carbon_pct_wt / 100)
+{generate_stamp}: calculated {var} by getting the weighted mean of all pools in a level and summing {layers} cSoilLevels where levels are 0–20cm (D1), 20–40cm (D2), 40–60cm (D3), 60–80cm (D4), 80–100cm (D5), 100–150cm (D6), 150–200cm (D7)
 {generate_stamp}: resampled to 0.5 degree resolution using mean
 {generate_stamp}: created CF-compliant metadata
-{generate_stamp}: exact details on this process can be found at {github_path}""",
-    'references': """
-@techreport{Nachtergaele2023,
-author = {Nachtergaele, Freddy and van Velthuizen, Harrij and Verelst, Luc and Wiberg, Dave and Henry, Matieu and Chiozza, Federica and Yigini, Yusuf and Aksoy, Ece and Batjes, Niels and Boateng, Enoch and Fischer, Günther and Jones, Arwyn and Montanarella, Luca and Shi, Xuezheng and Tramberend, Sylvia},
-title = {Harmonized World Soil Database},
-institution = {Food and Agriculture Organization of the United Nations and International Institute for Applied Systems Analysis, Rome and Laxenburg}
-year = {2023},
-number = {version 2.0}}""",
-    'comment': '',
-    'Conventions': 'CF-1.11'
-    }
-
-    # clean up the dataset
-    csoil['lat'] = csoil['lat'].astype('float32')
-    csoil['lon'] = csoil['lon'].astype('float32')
-    csoil = csoil.drop_vars('spatial_ref')
-    csoil = 
csoil.reindex(lat=list(reversed(csoil.lat))) +{generate_stamp}: exact details on this process can be found at {github_path} +""" + + ds = hf.set_cf_global_attributes( + ds, + title=f"Harmonized World Soil Database version 2.0 (HWSD v2.0) {var}", + institution="International Soil Reference and Information Centre (ISRIC)", + source="Harmonized international soil profiles from WISE30sec 2015 with 7 soil layers and expanded soil attributes", + history=history, + references=data_citation, + comment="", + conventions="CF 1.12", + ) # export as netcdf - csoil.to_netcdf(output_path, format='NETCDF4', engine='netcdf4') + ds.to_netcdf( + "{variable}_{frequency}_{source_id}_{st_date}-{en_date}.nc".format( + variable=var, + frequency="fx", + source_id="HWSD2", + st_date=sdate.strftime("%Y%m%d"), + en_date=edate.strftime("%Y%m%d"), + ), + encoding={VAR: {"zlib": True}}, + ) + # use all nine steps above to convert the data into a netcdf def main(): + download_data(REMOTE_RAST, REMOTE_DATA) + + client = initialize_client(N_WORKERS, N_THREADS, MEM_LIMIT) + + rast = load_raster(LOCAL_RAST, CHUNKSIZE) + + layer_df = load_layer_table(LOCAL_DATA, "HWSD2_LAYERS") + + dfs = process_layers(layer_df, LAYERS, POOLS, VAR) + + total_df = combine_and_summarize(dfs, VAR) + + rast = apply_mapping(rast, total_df, VAR) + + output_path = save_raster(rast, VAR, LAYERS, POOLS) + + resample_raster( + output_path, + f"hwsd2_{VAR}_{LAYERS[0]}-{LAYERS[-1]}_seq{POOLS[0]}-{POOLS[-1]}_resamp.tif", + 0.5, + 0.5, + "average", + 0, + ) + + create_netcdf( + f"hwsd2_{VAR}_{LAYERS[0]}-{LAYERS[-1]}_seq{POOLS[0]}-{POOLS[-1]}_resamp.tif", + VAR, + SDATE, + EDATE, + LOCAL_DATA, + REMOTE_DATA, + GITHUB_PATH, + POOLS, + LAYERS, + ) - download_data(remote_rast, remote_data) - - client = initialize_client(n_workers, n_threads, mem_limit) - - rast = load_raster(local_rast, chunksize) - - layer_df = load_layer_table(local_data, 'HWSD2_LAYERS') - - dfs = process_layers(layer_df, layers, pools, var) - - total_df = combine_and_summarize(dfs, var) - - rast = apply_mapping(rast, total_df, var) - - output_path = save_raster(rast, var, layers, pools) - - resample_raster(output_path, - f'hwsd2_{var}_{layers[0]}-{layers[-1]}_seq{pools[0]}-{pools[-1]}_resamp.tif', - 0.5, 0.5, 'average', 0) - - create_netcdf(f'hwsd2_{var}_{layers[0]}-{layers[-1]}_seq{pools[0]}-{pools[-1]}_resamp.tif', - f'hwsd2_{var}.nc', var, sdate, edate, long_name) - client.close() if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/scripts/biblatex_builder.py b/scripts/biblatex_builder.py index e393220..85ab1e4 100644 --- a/scripts/biblatex_builder.py +++ b/scripts/biblatex_builder.py @@ -102,7 +102,7 @@ def _validate_and_format_authors(authors: list[str]) -> str: def generate_biblatex_techreport( cite_key: str, - author: str, + author: list[str], title: str, institution: str, year: str, @@ -122,8 +122,9 @@ def generate_biblatex_techreport( Returns: str: Formatted BibLaTeX @techreport entry as a multiline string. 
""" + author_str = _validate_and_format_authors(author) fields = { - "author": author, + "author": author_str, "title": title, "institution": institution, "year": year, diff --git a/scripts/helper_funcs.py b/scripts/helper_funcs.py index 98b6bbf..1b2ef4e 100644 --- a/scripts/helper_funcs.py +++ b/scripts/helper_funcs.py @@ -1,7 +1,7 @@ import datetime import os -import cftime as cf +import cftime import numpy as np import requests import xarray as xr @@ -135,28 +135,31 @@ def set_time_attrs(ds: xr.Dataset) -> xr.Dataset: """ Ensure the xarray dataset's time attributes are formatted according to CF-Conventions. """ - assert "time" in ds da = ds["time"] - # Ensure time is an accepted xarray time dtype + # build ref_date if np.issubdtype(da.dtype, np.datetime64): ref_date = np.datetime_as_string(da.min().values, unit="s") - elif isinstance(da.values[0], cf.datetime): + elif isinstance(da.values[0], cftime.datetime): ref_date = da.values[0].strftime("%Y-%m-%d %H:%M:%S") else: - raise TypeError( - f"Unsupported xarray time format: {type(da.values[0])}. Accepted types are np.datetime64 or cftime.datetime." - ) + raise TypeError(f"Unsupported time dtype {type(da.values[0])}.") + + # set a valid calendar + da.encoding["calendar"] = "standard" + + # set the units string + da.encoding["units"] = f"days since {ref_date}" + + # update (not replace!) the attrs so we keep .attrs["bounds"] + da.attrs.update( + { + "axis": "T", + "standard_name": "time", + "long_name": "time", + } + ) - da.encoding = { - "units": f"days since {ref_date}", - "calendar": da.encoding.get("calendar"), - } - da.attrs = { - "axis": "T", - "standard_name": "time", - "long_name": "time", - } ds["time"] = da return ds @@ -167,12 +170,14 @@ def set_lat_attrs(ds: xr.Dataset) -> xr.Dataset: """ assert "lat" in ds da = ds["lat"] - da.attrs = { - "axis": "Y", - "units": "degrees_north", - "standard_name": "latitude", - "long_name": "latitude", - } + da.attrs.update( + { + "axis": "Y", + "units": "degrees_north", + "standard_name": "latitude", + "long_name": "latitude", + } + ) ds["lat"] = da return ds @@ -183,12 +188,14 @@ def set_lon_attrs(ds: xr.Dataset) -> xr.Dataset: """ assert "lon" in ds da = ds["lon"] - da.attrs = { - "axis": "X", - "units": "degrees_east", - "standard_name": "longitude", - "long_name": "longitude", - } + da.attrs.update( + { + "axis": "X", + "units": "degrees_east", + "standard_name": "longitude", + "long_name": "longitude", + } + ) ds["lon"] = da return ds @@ -266,6 +273,50 @@ def _make_timestamp(t: xr.DataArray, ymd: tuple[int, int, int]) -> np.datetime64 return ds +def add_time_bounds_single( + ds: xr.Dataset, start_date: str, end_date: str +) -> xr.Dataset: + """ + Add a single time coordinate with bounds to an xarray Dataset. + + The 'time' coordinate is set to the midpoint between the start and end dates, + and a 'time_bounds' coordinate is added following CF conventions. + + Args: + ds (xr.Dataset): Dataset to modify. + start_date (str): Start of the time bounds (e.g., '2020-01-01'). + end_date (str): End of the time bounds (e.g., '2020-02-01'). + + Returns: + xr.Dataset: Dataset with a single 'time' coordinate and 'time_bounds'. 
+ """ + + start = np.datetime64(start_date) + end = np.datetime64(end_date) + if end <= start: + raise ValueError("end_date must be after start_date.") + + midpoint = start + (end - start) / 2 + time = xr.DataArray([midpoint], dims="time", name="time") + time_bounds = xr.DataArray( + np.array([[start, end]], dtype="datetime64[ns]"), + dims=("time", "bounds"), + name="time_bounds", + ) + + # expand_dims can take a DataArray as the new coord + ds = ds.expand_dims({"time": time}) + ds = ds.assign_coords(time_bounds=time_bounds) + + # set the CF-required attrs on the bounds var + ds["time_bounds"].attrs["long_name"] = "time_bounds" + + # give the time coord a pointer to its bounds + ds["time"].attrs["bounds"] = "time_bounds" + + return ds + + def set_cf_global_attributes( ds: xr.Dataset, *, # keyword only for the following args diff --git a/scripts/visualize_netcdf.py b/scripts/visualize_netcdf.py new file mode 100644 index 0000000..89c31ed --- /dev/null +++ b/scripts/visualize_netcdf.py @@ -0,0 +1,48 @@ +from pathlib import Path + +import matplotlib.pyplot as plt +import xarray as xr + +# set parameters +wdir = "../HWSD2" +file_path = f"{wdir}/cSoil_fx_HWSD2_19600101-20220101.nc" +tstep = 0 +vmin = 0 # Minimum value for colormap +vmax = 100 # Maximum value for colormap +var = "cSoil" + +# open netcdf and select time step +base_name = Path(file_path).stem +data = xr.open_dataset(file_path) +da = data[var].isel(time=tstep) + +# create and save full map +plt.figure(figsize=(10, 6)) +p = da.plot(vmin=vmin, vmax=vmax) +plt.savefig(f"{wdir}/{base_name}_timestep_{tstep}.png", dpi=300, bbox_inches="tight") +plt.close() + +#### Create zoomed-in map #### +# Define Southeastern US bounding box +lon_min, lon_max = -95, -75 +lat_min, lat_max = 25, 37 + +# Determine proper slicing directions +lat_vals = da["lat"].values +lon_vals = da["lon"].values + +lat_slice = ( + slice(lat_min, lat_max) if lat_vals[0] < lat_vals[-1] else slice(lat_max, lat_min) +) +lon_slice = ( + slice(lon_min, lon_max) if lon_vals[0] < lon_vals[-1] else slice(lon_max, lon_min) +) + +# Clip the data dynamically +da_se = da.sel(lon=lon_slice, lat=lat_slice) + +# Plot the clipped region +plt.figure(figsize=(8, 6)) +p = da_se.plot(vmin=vmin, vmax=vmax) +plt.savefig(f"{wdir}/{base_name}_SE-US_zoom.png", dpi=300, bbox_inches="tight") +plt.close()