Skip to content

Commit 2a6c06a

Browse files
committed
Update validation and README
1 parent 5912555 commit 2a6c06a

File tree

2 files changed

+61
-79
lines changed

2 files changed

+61
-79
lines changed

examples/weather/temporal_interpolation/README.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,3 +120,19 @@ can use:
120120
```bash
121121
torchrun --nnodes=8 --nproc-per-node=8 train.py --config-name=train_interp.yaml ++training.optimizer_params.lr=0.0001
122122
```
123+
124+
## Validation
125+
126+
To evaluate checkpoints, you can use the `validate.py` script. The script computes a
127+
histogram of squared errors as a function of the interpolation step (+0 h to +6 h),
128+
which can be used to produce a plot similar to Figure 3 of the paper. The validation
129+
uses the same configuration files as training, with validation-specific options passed
130+
through the `validation` configuration group. Refer to the docstring of `error_by_time`
131+
in `validate.py` for the recognized options.
132+
133+
For example, to run the validation of a model trained with `train_interp.yaml` and save
134+
the resulting error histogram to `validation.nc`:
135+
136+
```bash
137+
python validate.py --config-name=train_interp.yaml ++validation.output_path=validation.nc
138+
```

examples/weather/temporal_interpolation/validate.py

Lines changed: 45 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,14 @@
2323
import torch
2424
import xarray as xr
2525

26-
from train_interp import setup_trainer, Trainer
26+
from train import input_output_from_batch_data, setup_trainer, Trainer
2727

2828

2929
def setup_analysis(
3030
cfg: dict, checkpoint: str | None = None, shuffle: bool = False
3131
) -> Trainer:
32-
"""Setup trainer for validation analysis.
32+
"""
33+
Setup trainer for validation analysis.
3334
3435
Parameters
3536
----------
@@ -64,8 +65,9 @@ def inference_model(
6465
timesteps: int = 6,
6566
denorm: bool = True,
6667
method: Literal["fcinterp", "linear"] = "fcinterp",
67-
) -> Generator[tuple[torch.Tensor, torch.Tensor], None, None]:
68-
"""Run inference on validation data.
68+
) -> Generator[tuple[torch.Tensor, torch.Tensor, int], None, None]:
69+
"""
70+
Run inference on validation data.
6971
7072
Parameters
7173
----------
@@ -80,83 +82,41 @@ def inference_model(
8082
8183
Yields
8284
------
83-
tuple[torch.Tensor, torch.Tensor]
84-
True and predicted values for each batch.
85+
tuple[torch.Tensor, torch.Tensor, int]
86+
True values, predicted values, and timestep index for each batch.
8587
"""
8688
for batch in trainer.valid_datapipe:
8789
y_true_step = []
8890
y_pred_step = []
89-
for step in range(timesteps + 1):
90-
(invar, outvar_true) = input_output_from_batch_data_analysis(batch, step)
91-
invar = tuple(v.detach() for v in invar)
92-
outvar_true = outvar_true.detach()
93-
y_true_step.append(outvar_true)
94-
if method == "fcinterp":
95-
y_pred_step.append(trainer.eval_step(invar))
96-
elif method == "linear":
97-
y_pred_step.append(linear_interp_batch_data(batch, step))
91+
(invar, outvar_true) = input_output_from_batch_data(batch)
92+
invar = tuple(v.detach() for v in invar)
93+
outvar_true = outvar_true.detach()
94+
y_true_step.append(outvar_true)
95+
step = int(round(invar[1].item() * timesteps))
96+
if method == "fcinterp":
97+
y_pred_step.append(trainer.eval_step(invar))
98+
elif method == "linear":
99+
y_pred_step.append(linear_interp_batch_data(batch, step))
98100

99101
y_true = torch.stack(y_true_step, dim=1)
100102
y_pred = torch.stack(y_pred_step, dim=1)
101103
if denorm:
102104
y_true = denormalize(trainer, y_true)
103105
y_pred = denormalize(trainer, y_pred)
104106

105-
yield (y_true, y_pred)
106-
107+
yield (y_true, y_pred, step)
107108

108-
@torch.no_grad()
109-
def input_output_from_batch_data_analysis(
110-
batch: list[dict[str, torch.Tensor]], step: int, time_scale: float = 6 * 3600.0
111-
) -> tuple[tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
112-
"""Convert batch data to model inputs and outputs for a specific timestep.
113-
114-
Parameters
115-
----------
116-
batch : list[dict[str, torch.Tensor]]
117-
Batch dictionary from datapipe.
118-
step : int
119-
Timestep index for output.
120-
time_scale : float, optional
121-
Length of the interpolation interval in seconds.
122109

123-
Returns
124-
-------
125-
tuple[tuple[torch.Tensor, torch.Tensor], torch.Tensor]
126-
Model inputs (atmospheric variables, time) and ground truth output.
110+
def linear_interp_batch_data(
111+
batch: list[dict[str, torch.Tensor]], step: int
112+
) -> torch.Tensor:
127113
"""
128-
batch = batch[0]
129-
130-
# concatenate all input variables to a single tensor
131-
atmos_vars = batch["state_seq-atmos"]
132-
133-
atmos_vars_in = [atmos_vars[:, 0], atmos_vars[:, -1]]
134-
if "cos_zenith-atmos" in batch:
135-
atmos_vars_in = atmos_vars_in + [batch["cos_zenith-atmos"].squeeze(dim=2)]
136-
if "latlon" in batch:
137-
atmos_vars_in = atmos_vars_in + [batch["latlon"]]
138-
if "geopotential" in batch:
139-
atmos_vars_in = atmos_vars_in + [batch["geopotential"]]
140-
if "land_sea_mask" in batch:
141-
atmos_vars_in = atmos_vars_in + [batch["land_sea_mask"]]
142-
atmos_vars_in = torch.cat(atmos_vars_in, dim=1)
143-
144-
atmos_vars_out = atmos_vars[:, step]
145-
146-
time = batch["timestamps-atmos"]
147-
# normalize time coordinate
148-
time = (time[:, step : step + 1] - time[:, :1]).to(dtype=torch.float32) / time_scale
149-
150-
return ((atmos_vars_in, time), atmos_vars_out)
151-
152-
153-
def linear_interp_batch_data(batch: dict[str, torch.Tensor], step: int) -> torch.Tensor:
154-
"""Perform linear interpolation on batch data.
114+
Perform linear interpolation on batch data.
155115
156116
Parameters
157117
----------
158-
batch : dict[str, torch.Tensor]
159-
Batch dictionary from datapipe.
118+
batch : list[dict[str, torch.Tensor]]
119+
Batch data from datapipe (list containing a dictionary).
160120
step : int
161121
Timestep index for interpolation.
162122
@@ -173,7 +133,8 @@ def linear_interp_batch_data(batch: dict[str, torch.Tensor], step: int) -> torch
173133

174134

175135
def denormalize(trainer: Trainer, y: torch.Tensor) -> torch.Tensor:
176-
"""Denormalize predictions using dataset statistics.
136+
"""
137+
Denormalize predictions using dataset statistics.
177138
178139
Parameters
179140
----------
@@ -205,7 +166,12 @@ def error_by_time(
205166
nbins: int = 10000,
206167
n_samples: int = 1000,
207168
) -> tuple[list[torch.Tensor], torch.Tensor]:
208-
"""Compute error statistics for each interpolation step.
169+
"""
170+
Compute error statistics for each interpolation step. The error
171+
is computed as the squared difference of the prediction and truth
172+
and is area-weighted (i.e. multiplied by the cosine of the latitude).
173+
It is calculated on the values normalized to zero mean and unit variance,
174+
so that errors of all variables are comparable.
209175
210176
Parameters
211177
----------
@@ -229,37 +195,35 @@ def error_by_time(
229195
tuple[list[torch.Tensor], torch.Tensor]
230196
Histogram counts for each timestep and bin edges.
231197
"""
232-
trainer = setup_analysis(cfg=cfg, checkpoint=checkpoint, shuffle=True)
198+
trainer = setup_analysis(cfg=cfg, checkpoint=checkpoint)
233199

234200
lat = torch.linspace(90, -90, 721)[:-1].to(device=trainer.model.device)
235201
lat[0] = 0.5 * (lat[0] + lat[1])
236202
cos_lat = torch.cos(lat * (torch.pi / 180))[None, None, :, None]
237203

238204
bins = torch.linspace(0, max_error, nbins + 1)
239205

240-
def _hist(y_true, y_pred):
206+
def _hist(y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
241207
err = (y_true - y_pred) ** 2
242208
weights = torch.ones_like(err) * cos_lat
243209
return torch.histogram(
244210
err.ravel().cpu(), bins=bins, weight=weights.ravel().cpu()
245211
)[0]
246212

247-
hist_counts = [None] * (timesteps + 1)
213+
hist_counts = [
214+
torch.zeros(nbins, dtype=torch.float64) for _ in range(timesteps + 1)
215+
]
248216

249-
for i_sample, (y_true, y_pred) in enumerate(
217+
for i_sample, (y_true, y_pred, step) in enumerate(
250218
inference_model(trainer, timesteps=timesteps, denorm=False, method=method)
251219
):
252220
if i_sample % 100 == 0:
253221
print(f"{i_sample}/{n_samples}")
254222

255-
for step in range(timesteps + 1):
256-
hist_counts_step = _hist(y_true[:, step, ...], y_pred[:, step, ...])
257-
if hist_counts[step] is None:
258-
hist_counts[step] = hist_counts_step
259-
else:
260-
hist_counts[step] += hist_counts_step
223+
hist_counts_step = _hist(y_true[:, -1, ...], y_pred[:, -1, ...])
224+
hist_counts[step] += hist_counts_step
261225

262-
if i_sample >= n_samples: # len(trainer.valid_datapipe):
226+
if i_sample + 1 >= n_samples:
263227
break
264228

265229
return (hist_counts, bins)
@@ -268,7 +232,8 @@ def _hist(y_true, y_pred):
268232
def save_histogram(
269233
hist_counts: list[torch.Tensor], bins: torch.Tensor, output_path: str
270234
) -> None:
271-
"""Save histogram data to netCDF4 file.
235+
"""
236+
Save histogram data to netCDF4 file.
272237
273238
Parameters
274239
----------
@@ -310,7 +275,8 @@ def save_histogram(
310275

311276
@hydra.main(version_base=None, config_path="config")
312277
def main(cfg: DictConfig):
313-
"""Main entry point for validation and error analysis.
278+
"""
279+
Run validation for interpolation error as a function of step.
314280
315281
Parameters
316282
----------

0 commit comments

Comments
 (0)