Skip to content

Commit 68c1854

Browse files
swbgCharlelieLrt
andauthored
Small CorrDiff fixes (#1062)
* apply small CorrDiff fixes * only remove checkpoints after saving new checkpoint * write CorrDiff output data only once * run pre-commit --------- Co-authored-by: Charlelie Laurent <[email protected]>
1 parent ed8b3ce commit 68c1854

File tree

4 files changed

+28
-28
lines changed

4 files changed

+28
-28
lines changed

examples/weather/corrdiff/conf/base/generation/base_all.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,6 @@ perf:
4646
# Use Apex GroupNorm (optimized normalization for performance with channelslast layout)
4747
profile_mode: false
4848
# Enable NVTX annotations for performance profiling
49-
io_syncronous: true
49+
io_synchronous: true
5050
# Synchronize I/O operations for writing inference results
5151

examples/weather/corrdiff/generate.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,7 @@ def generate_fn():
355355
has_lead_time=has_lead_time,
356356
)
357357

358-
if cfg.generation.perf.io_syncronous:
358+
if cfg.generation.perf.io_synchronous:
359359
writer_executor = ThreadPoolExecutor(
360360
max_workers=cfg.generation.perf.num_writer_workers
361361
)
@@ -381,8 +381,9 @@ def elapsed_time(self, _):
381381
start = end = DummyEvent()
382382

383383
times = dataset.time()
384-
for index, (image_tar, image_lr, *lead_time_label) in enumerate(
385-
iter(data_loader)
384+
for dataset_index, (image_tar, image_lr, *lead_time_label) in zip(
385+
sampler,
386+
iter(data_loader),
386387
):
387388
time_index += 1
388389
if dist.rank == 0:
@@ -405,7 +406,7 @@ def elapsed_time(self, _):
405406
image_out = generate_fn()
406407
if dist.rank == 0:
407408
batch_size = image_out.shape[0]
408-
if cfg.generation.perf.io_syncronous:
409+
if cfg.generation.perf.io_synchronous:
409410
# write out data in a seperate thread so we don't hold up inferencing
410411
writer_threads.append(
411412
writer_executor.submit(
@@ -417,8 +418,7 @@ def elapsed_time(self, _):
417418
image_tar.cpu(),
418419
image_lr.cpu(),
419420
time_index,
420-
index,
421-
has_lead_time,
421+
dataset_index,
422422
)
423423
)
424424
else:
@@ -430,8 +430,7 @@ def elapsed_time(self, _):
430430
image_tar.cpu(),
431431
image_lr.cpu(),
432432
time_index,
433-
index,
434-
has_lead_time,
433+
dataset_index,
435434
)
436435
end.record()
437436
end.synchronize()
@@ -449,7 +448,7 @@ def elapsed_time(self, _):
449448
)
450449

451450
# make sure all the workers are done writing
452-
if dist.rank == 0 and cfg.generation.perf.io_syncronous:
451+
if dist.rank == 0 and cfg.generation.perf.io_synchronous:
453452
for thread in list(writer_threads):
454453
thread.result()
455454
writer_threads.remove(thread)

examples/weather/corrdiff/helpers/generate_helpers.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,7 @@ def save_images(
5151
image_tar,
5252
image_lr,
5353
time_index,
54-
t_index,
55-
has_lead_time,
54+
dataset_index,
5655
):
5756
"""
5857
Saves inferencing result along with the baseline
@@ -71,7 +70,7 @@ def save_images(
7170
image_tar (torch.Tensor): Ground truth data
7271
image_lr (torch.Tensor): Low resolution input data
7372
time_index (int): Epoch number
74-
t_index (int): index where times are located
73+
dataset_index (int): index where times are located
7574
"""
7675
# weather sub-plot
7776
image_lr2 = image_lr[0].unsqueeze(0)
@@ -95,7 +94,7 @@ def save_images(
9594
image_out2 = image_out2.cpu().numpy()
9695
image_out2 = dataset.denormalize_output(image_out2)
9796

98-
time = times[t_index]
97+
time = times[dataset_index]
9998
writer.write_time(time_index, time)
10099
for channel_idx in range(image_out2.shape[1]):
101100
info = dataset.output_channels()[channel_idx]
@@ -107,10 +106,10 @@ def save_images(
107106
channel_name, time_index, idx, image_out2[0, channel_idx]
108107
)
109108

110-
input_channel_info = dataset.input_channels()
111-
for channel_idx in range(len(input_channel_info)):
112-
info = input_channel_info[channel_idx]
113-
channel_name = info.name + info.level
114-
writer.write_input(channel_name, time_index, image_lr2[0, channel_idx])
115-
if channel_idx == image_lr2.shape[1] - 1:
116-
break
109+
input_channel_info = dataset.input_channels()
110+
for channel_idx in range(len(input_channel_info)):
111+
info = input_channel_info[channel_idx]
112+
channel_name = info.name + info.level
113+
writer.write_input(channel_name, time_index, image_lr2[0, channel_idx])
114+
if channel_idx == image_lr2.shape[1] - 1:
115+
break

examples/weather/corrdiff/train.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -828,13 +828,15 @@ def main(cfg: DictConfig) -> None:
828828
epoch=cur_nimg,
829829
)
830830

831-
# Retain only the recent n checkpoints, if desired
832-
if cfg.training.io.save_n_recent_checkpoints > 0:
833-
for suffix in [".mdlus", ".pt"]:
834-
ckpts = checkpoint_list(checkpoint_dir, suffix=suffix)
835-
while len(ckpts) > cfg.training.io.save_n_recent_checkpoints:
836-
os.remove(os.path.join(checkpoint_dir, ckpts[0]))
837-
ckpts = ckpts[1:]
831+
# Retain only the recent n checkpoints, if desired
832+
if cfg.training.io.save_n_recent_checkpoints > 0:
833+
for suffix in [".mdlus", ".pt"]:
834+
ckpts = checkpoint_list(checkpoint_dir, suffix=suffix)
835+
while (
836+
len(ckpts) > cfg.training.io.save_n_recent_checkpoints
837+
):
838+
os.remove(os.path.join(checkpoint_dir, ckpts[0]))
839+
ckpts = ckpts[1:]
838840

839841
# Done.
840842
logger0.info("Training Completed.")

0 commit comments

Comments
 (0)