
Commit 24bee5c

mnabian, ktangsali, and saydemr authored
CorrDiff integration: Support wandb logging (#316)
* Update blossom-ci.yml (#295)
* Change pip install commands with the correct PyPI package name (#298)
* add wb logging
* formatting
* make mode configurable
---------
Co-authored-by: Kaustubh Tangsali <[email protected]>
Co-authored-by: Abdullah <[email protected]>
1 parent 85c10f9 commit 24bee5c

4 files changed: +28, -4 lines

examples/generative/corrdiff/conf/config_train_diffusion.yaml

Lines changed: 2 additions & 0 deletions
@@ -61,6 +61,8 @@ workers: 4
 
 
 ## I/O-related options
+wandb_mode: offline
+# Weights & Biases mode [online, offline, disabled]
 desc: ''
 # String to include in result dir name
 tick: 1

examples/generative/corrdiff/conf/config_train_regression.yaml

Lines changed: 2 additions & 0 deletions
@@ -61,6 +61,8 @@ workers: 4
 
 
 ## I/O-related options
+wandb_mode: offline
+# Weights & Biases mode [online, offline, disabled]
 desc: ''
 # String to include in result dir name
 tick: 1
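
Both training configs gain the same option with the same default. The three accepted values are the standard Weights & Biases run modes; below is a short standalone sketch of what each one means in practice (the project name is illustrative, not part of the commit):

import wandb

# "online"   - stream metrics to the W&B server while the run executes (requires login)
# "offline"  - write the run to the local ./wandb directory; upload later with `wandb sync`
# "disabled" - turn every wandb call into a no-op (handy for CI runs and quick debugging)
run = wandb.init(project="corrdiff-demo", mode="offline")
wandb.log({"loss": 0.123}, step=0)
run.finish()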

examples/generative/corrdiff/train.py

Lines changed: 2 additions & 0 deletions
@@ -69,6 +69,7 @@ def main(cfg: DictConfig) -> None:
     workers = getattr(cfg, "workers", 4)
 
     # Parse I/O-related options
+    wandb_mode = getattr(cfg, "wandb_mode", "disabled")
     desc = getattr(cfg, "desc")
     tick = getattr(cfg, "tick", 1)
     snap = getattr(cfg, "snap", 1)
@@ -80,6 +81,7 @@ def main(cfg: DictConfig) -> None:
     # Parse weather data options
     c = EasyDict()
     c.task = task
+    c.wandb_mode = wandb_mode
     c.train_data_path = getattr(cfg, "train_data_path")
     c.crop_size_x = getattr(cfg, "crop_size_x", 448)
     c.crop_size_y = getattr(cfg, "crop_size_y", 448)

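train.py reads the new option with the same getattr-with-default pattern it already uses for the other I/O options, so "disabled" is the fallback when a config does not define wandb_mode. A minimal standalone sketch of that plumbing (the in-memory config and its values below are assumptions for illustration, not the repository's Hydra config):

from omegaconf import OmegaConf

# Stand-in for the Hydra-composed training config.
cfg = OmegaConf.create({"task": "diffusion", "wandb_mode": "offline", "desc": "", "tick": 1})

# Same parsing pattern as train.py; "disabled" is the fallback value.
wandb_mode = getattr(cfg, "wandb_mode", "disabled")

# train.py stashes the value on the EasyDict `c` that is forwarded to the training loop;
# a plain dict of keyword arguments plays that role here.
c = {"task": cfg.task, "wandb_mode": wandb_mode}
print(c)
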
examples/generative/corrdiff/training/training_loop.py

Lines changed: 22 additions & 4 deletions
@@ -19,6 +19,7 @@
 import os
 import sys
 import time
+import wandb as wb
 
 import numpy as np
 import psutil
@@ -29,7 +30,11 @@
 sys.path.append("../")
 from module import Module
 from modulus.distributed import DistributedManager
-from modulus.launch.logging import PythonLogger, RankZeroLoggingWrapper
+from modulus.launch.logging import (
+    PythonLogger,
+    RankZeroLoggingWrapper,
+    initialize_wandb,
+)
 from modulus.utils.generative import (
     InfiniteSampler,
     construct_class_by_name,
@@ -81,6 +86,7 @@ def training_loop(
     gridtype="sinusoidal",
     N_grid_channels=4,
     normalization="v1",
+    wandb_mode="disabled",
 ):
     """CorrDiff training loop"""
 
@@ -93,6 +99,15 @@ def training_loop(
     logger0 = RankZeroLoggingWrapper(logger, dist)
     logger.file_logging(file_name=f".logs/training_loop_{dist.rank}.log")
 
+    # wandb logger
+    initialize_wandb(
+        project="Modulus-Generative",
+        entity="Modulus",
+        name="CorrDiff",
+        group="CorrDiff-DDP-Group",
+        mode=wandb_mode,
+    )
+
     # Initialize.
     start_time = time.time()
 
@@ -241,6 +256,7 @@ def training_loop(
     while True:
         # Accumulate gradients.
         optimizer.zero_grad(set_to_none=True)
+        loss_accum = 0
         for round_idx in range(num_accumulation_rounds):
             with ddp_sync(ddp, (round_idx == num_accumulation_rounds - 1)):
                 # Fetch training data: weather
@@ -261,13 +277,17 @@
                     augment_pipe=augment_pipe,
                 )
                 training_stats.report("Loss/loss", loss)
-                loss.sum().mul(loss_scaling / batch_gpu_total).backward()
+                loss = loss.sum().mul(loss_scaling / batch_gpu_total)
+                loss_accum += loss
+                loss.backward()
+        wb.log({"loss": loss_accum}, step=cur_nimg)
 
         # Update weights.
         for g in optimizer.param_groups:
             g["lr"] = optimizer_kwargs["lr"] * min(
                 cur_nimg / max(lr_rampup_kimg * 1000, 1e-8), 1
             )  # TODO better handling (potential bug)
+            wb.log({"lr": g["lr"]}, step=cur_nimg)
         for param in net.parameters():
             if param.grad is not None:
                 torch.nan_to_num(
@@ -324,8 +344,6 @@
         torch.cuda.reset_peak_memory_stats()
         logger0.info(" ".join(fields))
 
-        ckpt_dir = run_dir
-
         # Save full dump of the training state.
         if (
             (state_dump_ticks is not None)

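For context, here is a condensed, self-contained sketch of the logging pattern these hunks introduce: initialize W&B once with the configured mode, accumulate the scaled loss across the gradient-accumulation rounds, then log the accumulated loss and the current learning rate once per optimizer step, keyed by the running image count. The plain wandb.init stands in for Modulus's initialize_wandb wrapper so the snippet runs on its own; the tiny model, data, and project name are illustrative assumptions.

import torch
import wandb as wb

# Stand-in for initialize_wandb(...); offline mode needs no login.
wb.init(project="corrdiff-demo", mode="offline")

net = torch.nn.Linear(8, 1)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
num_accumulation_rounds, batch_gpu, cur_nimg = 2, 4, 0

for _ in range(3):  # a few dummy optimizer steps
    optimizer.zero_grad(set_to_none=True)
    loss_accum = 0
    for _ in range(num_accumulation_rounds):
        x = torch.randn(batch_gpu, 8)
        # Scaled per-round loss, analogous to loss.sum().mul(loss_scaling / batch_gpu_total)
        loss = (net(x) ** 2).sum() / (batch_gpu * num_accumulation_rounds)
        loss_accum += loss
        loss.backward()
    # One data point per optimizer step, keyed by the number of images seen so far.
    wb.log({"loss": float(loss_accum)}, step=cur_nimg)
    for g in optimizer.param_groups:
        wb.log({"lr": g["lr"]}, step=cur_nimg)
    optimizer.step()
    cur_nimg += batch_gpu * num_accumulation_rounds

wb.finish()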