2323import torch
2424import xarray as xr
2525
26- from train_interp import setup_trainer , Trainer
26+ from train import input_output_from_batch_data , setup_trainer , Trainer
2727
2828
2929def setup_analysis (
3030 cfg : dict , checkpoint : str | None = None , shuffle : bool = False
3131) -> Trainer :
32- """Setup trainer for validation analysis.
32+ """
33+ Setup trainer for validation analysis.
3334
3435 Parameters
3536 ----------
@@ -64,8 +65,9 @@ def inference_model(
6465 timesteps : int = 6 ,
6566 denorm : bool = True ,
6667 method : Literal ["fcinterp" , "linear" ] = "fcinterp" ,
67- ) -> Generator [tuple [torch .Tensor , torch .Tensor ], None , None ]:
68- """Run inference on validation data.
68+ ) -> Generator [tuple [torch .Tensor , torch .Tensor , int ], None , None ]:
69+ """
70+ Run inference on validation data.
6971
7072 Parameters
7173 ----------
@@ -80,83 +82,41 @@ def inference_model(
8082
8183 Yields
8284 ------
83- tuple[torch.Tensor, torch.Tensor]
84- True and predicted values for each batch.
85+ tuple[torch.Tensor, torch.Tensor, int ]
86+ True values, predicted values, and timestep index for each batch.
8587 """
8688 for batch in trainer .valid_datapipe :
8789 y_true_step = []
8890 y_pred_step = []
89- for step in range ( timesteps + 1 ):
90- ( invar , outvar_true ) = input_output_from_batch_data_analysis ( batch , step )
91- invar = tuple ( v .detach () for v in invar )
92- outvar_true = outvar_true . detach ( )
93- y_true_step . append ( outvar_true )
94- if method == "fcinterp" :
95- y_pred_step .append (trainer .eval_step (invar ))
96- elif method == "linear" :
97- y_pred_step .append (linear_interp_batch_data (batch , step ))
91+ ( invar , outvar_true ) = input_output_from_batch_data ( batch )
92+ invar = tuple ( v . detach () for v in invar )
93+ outvar_true = outvar_true .detach ()
94+ y_true_step . append ( outvar_true )
95+ step = int ( round ( invar [ 1 ]. item () * timesteps ) )
96+ if method == "fcinterp" :
97+ y_pred_step .append (trainer .eval_step (invar ))
98+ elif method == "linear" :
99+ y_pred_step .append (linear_interp_batch_data (batch , step ))
98100
99101 y_true = torch .stack (y_true_step , dim = 1 )
100102 y_pred = torch .stack (y_pred_step , dim = 1 )
101103 if denorm :
102104 y_true = denormalize (trainer , y_true )
103105 y_pred = denormalize (trainer , y_pred )
104106
105- yield (y_true , y_pred )
106-
107+ yield (y_true , y_pred , step )
107108
108- @torch .no_grad ()
109- def input_output_from_batch_data_analysis (
110- batch : list [dict [str , torch .Tensor ]], step : int , time_scale : float = 6 * 3600.0
111- ) -> tuple [tuple [torch .Tensor , torch .Tensor ], torch .Tensor ]:
112- """Convert batch data to model inputs and outputs for a specific timestep.
113-
114- Parameters
115- ----------
116- batch : list[dict[str, torch.Tensor]]
117- Batch dictionary from datapipe.
118- step : int
119- Timestep index for output.
120- time_scale : float, optional
121- Length of the interpolation interval in seconds.
122109
123- Returns
124- -------
125- tuple[tuple[torch.Tensor, torch.Tensor], torch.Tensor]
126- Model inputs (atmospheric variables, time) and ground truth output.
110+ def linear_interp_batch_data (
111+ batch : list [dict [str , torch .Tensor ]], step : int
112+ ) -> torch .Tensor :
127113 """
128- batch = batch [0 ]
129-
130- # concatenate all input variables to a single tensor
131- atmos_vars = batch ["state_seq-atmos" ]
132-
133- atmos_vars_in = [atmos_vars [:, 0 ], atmos_vars [:, - 1 ]]
134- if "cos_zenith-atmos" in batch :
135- atmos_vars_in = atmos_vars_in + [batch ["cos_zenith-atmos" ].squeeze (dim = 2 )]
136- if "latlon" in batch :
137- atmos_vars_in = atmos_vars_in + [batch ["latlon" ]]
138- if "geopotential" in batch :
139- atmos_vars_in = atmos_vars_in + [batch ["geopotential" ]]
140- if "land_sea_mask" in batch :
141- atmos_vars_in = atmos_vars_in + [batch ["land_sea_mask" ]]
142- atmos_vars_in = torch .cat (atmos_vars_in , dim = 1 )
143-
144- atmos_vars_out = atmos_vars [:, step ]
145-
146- time = batch ["timestamps-atmos" ]
147- # normalize time coordinate
148- time = (time [:, step : step + 1 ] - time [:, :1 ]).to (dtype = torch .float32 ) / time_scale
149-
150- return ((atmos_vars_in , time ), atmos_vars_out )
151-
152-
153- def linear_interp_batch_data (batch : dict [str , torch .Tensor ], step : int ) -> torch .Tensor :
154- """Perform linear interpolation on batch data.
114+ Perform linear interpolation on batch data.
155115
156116 Parameters
157117 ----------
158- batch : dict[str, torch.Tensor]
159- Batch dictionary from datapipe.
118+ batch : list[ dict[str, torch.Tensor] ]
119+ Batch data from datapipe (list containing a dictionary) .
160120 step : int
161121 Timestep index for interpolation.
162122
@@ -173,7 +133,8 @@ def linear_interp_batch_data(batch: dict[str, torch.Tensor], step: int) -> torch
173133
174134
175135def denormalize (trainer : Trainer , y : torch .Tensor ) -> torch .Tensor :
176- """Denormalize predictions using dataset statistics.
136+ """
137+ Denormalize predictions using dataset statistics.
177138
178139 Parameters
179140 ----------
@@ -205,7 +166,12 @@ def error_by_time(
205166 nbins : int = 10000 ,
206167 n_samples : int = 1000 ,
207168) -> tuple [list [torch .Tensor ], torch .Tensor ]:
208- """Compute error statistics for each interpolation step.
169+ """
170+ Compute error statistics for each interpolation step. The error
171+ is computed as the squared difference of the prediction and truth
172+ and is area-weighted (i.e. multiplied by the cosine of the latitude).
173+ It is calculated on the values normalized to zero mean and unit variance,
174+ so that errors of all variables are comparable.
209175
210176 Parameters
211177 ----------
@@ -229,37 +195,35 @@ def error_by_time(
229195 tuple[list[torch.Tensor], torch.Tensor]
230196 Histogram counts for each timestep and bin edges.
231197 """
232- trainer = setup_analysis (cfg = cfg , checkpoint = checkpoint , shuffle = True )
198+ trainer = setup_analysis (cfg = cfg , checkpoint = checkpoint )
233199
234200 lat = torch .linspace (90 , - 90 , 721 )[:- 1 ].to (device = trainer .model .device )
235201 lat [0 ] = 0.5 * (lat [0 ] + lat [1 ])
236202 cos_lat = torch .cos (lat * (torch .pi / 180 ))[None , None , :, None ]
237203
238204 bins = torch .linspace (0 , max_error , nbins + 1 )
239205
240- def _hist (y_true , y_pred ) :
206+ def _hist (y_true : torch . Tensor , y_pred : torch . Tensor ) -> torch . Tensor :
241207 err = (y_true - y_pred ) ** 2
242208 weights = torch .ones_like (err ) * cos_lat
243209 return torch .histogram (
244210 err .ravel ().cpu (), bins = bins , weight = weights .ravel ().cpu ()
245211 )[0 ]
246212
247- hist_counts = [None ] * (timesteps + 1 )
213+ hist_counts = [
214+ torch .zeros (nbins , dtype = torch .float64 ) for _ in range (timesteps + 1 )
215+ ]
248216
249- for i_sample , (y_true , y_pred ) in enumerate (
217+ for i_sample , (y_true , y_pred , step ) in enumerate (
250218 inference_model (trainer , timesteps = timesteps , denorm = False , method = method )
251219 ):
252220 if i_sample % 100 == 0 :
253221 print (f"{ i_sample } /{ n_samples } " )
254222
255- for step in range (timesteps + 1 ):
256- hist_counts_step = _hist (y_true [:, step , ...], y_pred [:, step , ...])
257- if hist_counts [step ] is None :
258- hist_counts [step ] = hist_counts_step
259- else :
260- hist_counts [step ] += hist_counts_step
223+ hist_counts_step = _hist (y_true [:, - 1 , ...], y_pred [:, - 1 , ...])
224+ hist_counts [step ] += hist_counts_step
261225
262- if i_sample >= n_samples : # len(trainer.valid_datapipe) :
226+ if i_sample + 1 >= n_samples :
263227 break
264228
265229 return (hist_counts , bins )
@@ -268,7 +232,8 @@ def _hist(y_true, y_pred):
268232def save_histogram (
269233 hist_counts : list [torch .Tensor ], bins : torch .Tensor , output_path : str
270234) -> None :
271- """Save histogram data to netCDF4 file.
235+ """
236+ Save histogram data to netCDF4 file.
272237
273238 Parameters
274239 ----------
@@ -310,7 +275,8 @@ def save_histogram(
310275
311276@hydra .main (version_base = None , config_path = "config" )
312277def main (cfg : DictConfig ):
313- """Main entry point for validation and error analysis.
278+ """
279+ Run validation for interpolation error as a function of step.
314280
315281 Parameters
316282 ----------
0 commit comments