Commit cd79949
Merge pull request #14 from SCOREC/ci-tests
Adding a CI test suite:

- A `--smoke-test` flag that runs a brief end-to-end training loop. It uses a new SyntheticXPointDataset to generate mock data and confirms the core training pipeline is functional: data loading, model forward/backward passes, and the main training loop all execute without errors.
- A pytest suite for more targeted checks, including unit tests for components such as the DiceLoss function and an integration test verifying that model checkpointing (saving and loading) works correctly. Run with `pytest test_xpoint_ml.py`.
2 parents 624e08f + 9e2e65a commit cd79949
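For reference, typical invocations of the two new checks might look like this (assumed commands, run from the repository root; note the smoke test still goes through the normal argument parser, so any arguments the parser requires must be supplied):

    python XPointMLTest.py --smoke-test
    pytest test_xpoint_ml.py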

3 files changed (+406 −9 lines)
XPointMLTest.py

Lines changed: 118 additions & 9 deletions
@@ -20,6 +20,8 @@
 
 from timeit import default_timer as timer
 
+from ci_tests import SyntheticXPointDataset, test_checkpoint_functionality
+
 def expand_xpoints_mask(binary_mask, kernel_size=9):
     """
     Expands each X-point in a binary mask to include surrounding cells
@@ -668,10 +670,19 @@ def parseCommandLineArgs():
                         help='create figures of the ground truth X-points and model identified X-points')
     parser.add_argument('--plotDir', type=Path, default="./plots",
                         help='directory where figures are written')
+
+    # CI TEST: Add smoke test flag
+    parser.add_argument('--smoke-test', action='store_true',
+                        help='Run a minimal smoke test for CI (overrides other parameters)')
+
     args = parser.parse_args()
     return args
 
 def checkCommandLineArgs(args):
+    # CI TEST: Skip file checks in smoke test mode
+    if args.smoke_test:
+        return
+
     if args.xptCacheDir != None:
         if not args.xptCacheDir.is_dir():
             print(f"Xpoint cache directory {args.xptCacheDir} does not exist. "
@@ -801,6 +812,32 @@ def load_model_checkpoint(model, optimizer, checkpoint_path):
 
 def main():
     args = parseCommandLineArgs()
+
+    # CI TEST: Override parameters for smoke test
+    if args.smoke_test:
+        print("=" * 60)
+        print("RUNNING IN SMOKE TEST MODE FOR CI")
+        print("=" * 60)
+
+        # Override with minimal parameters
+        args.epochs = 5
+        args.batchSize = 1
+        args.trainFrameFirst = 1
+        args.trainFrameLast = 11  # 10 frames for training
+        args.validationFrameFirst = 11
+        args.validationFrameLast = 12  # 1 frame for validation
+        args.plot = False  # Disable plotting for CI
+        args.checkPointFrequency = 2  # Save more frequently
+        args.minTrainingLoss = 0  # Don't fail on convergence in smoke test
+
+        print("Smoke test parameters:")
+        print(f" - Training frames: {args.trainFrameFirst} to {args.trainFrameLast-1}")
+        print(f" - Validation frames: {args.validationFrameFirst} to {args.validationFrameLast-1}")
+        print(f" - Epochs: {args.epochs}")
+        print(f" - Batch size: {args.batchSize}")
+        print(" - Plotting disabled")
+        print("=" * 60)
+
     checkCommandLineArgs(args)
     printCommandLineArgs(args)

@@ -809,13 +846,22 @@ def main():
     os.makedirs(outDir, exist_ok=True)
 
     t0 = timer()
-    train_fnums = range(args.trainFrameFirst, args.trainFrameLast)
-    val_fnums = range(args.validationFrameFirst, args.validationFrameLast)
+
+    # CI TEST: Use synthetic data for smoke test
+    if args.smoke_test:
+        print("\nUsing synthetic data for smoke test...")
+        train_dataset = SyntheticXPointDataset(nframes=10, shape=(64, 64), nxpoints=3)
+        val_dataset = SyntheticXPointDataset(nframes=1, shape=(64, 64), nxpoints=3, seed=123)
+        print(f"Created synthetic datasets: {len(train_dataset)} train, {len(val_dataset)} val frames")
+    else:
+        # Original data loading
+        train_fnums = range(args.trainFrameFirst, args.trainFrameLast)
+        val_fnums = range(args.validationFrameFirst, args.validationFrameLast)
 
-    train_dataset = XPointDataset(args.paramFile, train_fnums,
-                                  xptCacheDir=args.xptCacheDir, rotateAndReflect=True)
-    val_dataset = XPointDataset(args.paramFile, val_fnums,
-                                xptCacheDir=args.xptCacheDir)
+        train_dataset = XPointDataset(args.paramFile, train_fnums,
+                                      xptCacheDir=args.xptCacheDir, rotateAndReflect=True)
+        val_dataset = XPointDataset(args.paramFile, val_fnums,
+                                    xptCacheDir=args.xptCacheDir)
 
     t1 = timer()
     print("time (s) to create gkyl data loader: " + str(t1-t0))
@@ -838,7 +884,7 @@ def main():
     train_loss = []
     val_loss = []
 
-    if os.path.exists(latest_checkpoint_path):
+    if os.path.exists(latest_checkpoint_path) and not args.smoke_test:
         model, optimizer, start_epoch, train_loss, val_loss = load_model_checkpoint(
             model, optimizer, latest_checkpoint_path
         )
@@ -862,6 +908,65 @@ def main():
     plot_training_history(train_loss, val_loss)
     print("time (s) to train model: " + str(timer()-t2))
 
+    # CI TEST: Run additional tests if in smoke test mode
+    if args.smoke_test:
+        print("\n" + "="*60)
+        print("SMOKE TEST: Running additional CI tests")
+        print("="*60)
+
+        # Test 1: Checkpoint save/load
+        checkpoint_test_passed = test_checkpoint_functionality(
+            model, optimizer, device, val_loader, criterion, None, UNet, optim.Adam
+        )
+
+        # Test 2: Inference test
+        print("Running inference test...")
+        model.eval()
+        with torch.no_grad():
+            # Get one batch
+            test_batch = next(iter(val_loader))
+            test_input = test_batch["all"].to(device)
+            test_output = model(test_input)
+
+            # Apply sigmoid to get probabilities
+            test_probs = torch.sigmoid(test_output)
+
+            print(f"Input shape: {test_input.shape}")
+            print(f"Output shape: {test_output.shape}")
+            print(f"Output range (logits): [{test_output.min():.3f}, {test_output.max():.3f}]")
+            print(f"Output range (probs): [{test_probs.min():.3f}, {test_probs.max():.3f}]")
+            print(f"Predicted X-points: {(test_probs > 0.5).sum().item()} pixels")
+
+        # Test 3: Check if model learned anything
+        initial_train_loss = train_loss[0] if train_loss else float('inf')
+        final_train_loss = train_loss[-1] if train_loss else float('inf')
+
+        print("\nTraining progress:")
+        print(f"Initial loss: {initial_train_loss:.6f}")
+        print(f"Final loss: {final_train_loss:.6f}")
+
+        if final_train_loss < initial_train_loss:
+            print("✓ Model showed improvement during training")
+            training_improved = True
+        else:
+            print("✗ Model did not improve during training")
+            training_improved = False
+
+        # Overall smoke test result
+        print("\n" + "="*60)
+        print("SMOKE TEST SUMMARY")
+        print("="*60)
+        print(f"Checkpoint test: {'PASSED' if checkpoint_test_passed else 'FAILED'}")
+        print(f"Training improvement: {'YES' if training_improved else 'NO'}")
+        print(f"Overall result: {'PASSED' if checkpoint_test_passed else 'FAILED'}")
+        print("="*60)
+
+        # Return appropriate exit code for CI
+        if not checkpoint_test_passed:
+            return 1
+        else:
+            return 0
+
     requiredLossDecreaseMagnitude = args.minTrainingLoss
     if np.log10(abs(train_loss[0]/train_loss[-1])) < requiredLossDecreaseMagnitude:
         print(f"TrainLoss reduced by less than {requiredLossDecreaseMagnitude} orders of magnitude: "
@@ -874,8 +979,12 @@ def main():
     interpFac = 1
 
     # Evaluate on combined set for demonstration. Examine this part to see if it is safe to remove
-    full_fnums = list(train_fnums) + list(val_fnums)
-    full_dataset = [train_dataset, val_dataset]
+    if not args.smoke_test:
+        train_fnums = range(args.trainFrameFirst, args.trainFrameLast)
+        val_fnums = range(args.validationFrameFirst, args.validationFrameLast)
+        full_dataset = [train_dataset, val_dataset]
+    else:
+        full_dataset = [val_dataset]  # Only use validation data for smoke test
 
     t4 = timer()

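One detail the smoke-test path relies on: `main()` returning 0 or 1 only becomes the process exit code CI observes if the script's entry point forwards it. That guard is not shown in this diff, so the following is a sketch of the assumed wiring in XPointMLTest.py:

    import sys

    if __name__ == "__main__":
        # propagate main()'s 0/1 result so CI can fail the job on a non-zero code
        sys.exit(main())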
ci_tests.py

Lines changed: 153 additions & 0 deletions
@@ -0,0 +1,153 @@
+import numpy as np
+import torch
+from torch.utils.data import Dataset, DataLoader
+import torch.optim as optim
+import os
+
+class SyntheticXPointDataset(Dataset):
+    """
+    Synthetic dataset for CI testing that doesn't require actual simulation data.
+    Creates predictable X-point patterns for testing the model training pipeline.
+    """
+    def __init__(self, nframes=2, shape=(64, 64), nxpoints=4, seed=42):
+        """
+        nframes: Number of synthetic frames to generate
+        shape: Spatial dimensions (H, W) of each frame
+        nxpoints: Approximate number of X-points per frame
+        seed: Random seed for reproducibility
+        """
+        super().__init__()
+        self.nframes = nframes
+        self.shape = shape
+        self.nxpoints = nxpoints
+        self.rng = np.random.RandomState(seed)
+
+        # pre-generate all frames for consistency
+        self.data = []
+        for i in range(nframes):
+            frame_data = self._generate_frame(i)
+            self.data.append(frame_data)
+
+    def _generate_frame(self, idx):
+        """Generate a single synthetic frame with X-points"""
+        H, W = self.shape
+
+        # create a synthetic psi field with some structure
+        x = np.linspace(-np.pi, np.pi, W)
+        y = np.linspace(-np.pi, np.pi, H)
+        X, Y = np.meshgrid(x, y)
+
+        # create a field with saddle points (X-points)
+        psi = np.sin(X + 0.1*idx) * np.cos(Y + 0.1*idx) + \
+              0.5 * np.sin(2*X) * np.cos(2*Y)
+
+        # add some noise
+        psi += 0.1 * self.rng.randn(H, W)
+
+        # create synthetic B fields (derivatives of psi)
+        bx = np.gradient(psi, axis=0)
+        by = -np.gradient(psi, axis=1)
+
+        # create synthetic current (Laplacian of psi)
+        jz = -(np.gradient(np.gradient(psi, axis=0), axis=0) +
+               np.gradient(np.gradient(psi, axis=1), axis=1))
+
+        # create X-point mask
+        mask = np.zeros((H, W), dtype=np.float32)
+
+        for _ in range(self.nxpoints):
+            x_loc = self.rng.randint(10, W-10)
+            y_loc = self.rng.randint(10, H-10)
+            # Create 9x9 region around X-point
+            mask[max(0, y_loc-4):min(H, y_loc+5),
+                 max(0, x_loc-4):min(W, x_loc+5)] = 1.0
+
+        # convert to torch tensors
+        psi_torch = torch.from_numpy(psi.astype(np.float32)).unsqueeze(0)
+        bx_torch = torch.from_numpy(bx.astype(np.float32)).unsqueeze(0)
+        by_torch = torch.from_numpy(by.astype(np.float32)).unsqueeze(0)
+        jz_torch = torch.from_numpy(jz.astype(np.float32)).unsqueeze(0)
+        all_torch = torch.cat((psi_torch, bx_torch, by_torch, jz_torch))
+        mask_torch = torch.from_numpy(mask).float().unsqueeze(0)
+
+        x_coords = np.linspace(0, 1, W)
+        y_coords = np.linspace(0, 1, H)
+
+        params = {
+            "axesNorm": 1.0, "plotContours": 1, "colorContours": 'k',
+            "numContours": 50, "axisEqual": 1, "symBar": 1, "colormap": 'bwr'
+        }
+
+        return {
+            "fnum": idx, "rotation": 0, "reflectionAxis": -1, "psi": psi_torch,
+            "all": all_torch, "mask": mask_torch, "x": x_coords, "y": y_coords,
+            "filenameBase": f"synthetic_frame_{idx}", "params": params
+        }
+
+    def __len__(self):
+        return self.nframes
+
+    def __getitem__(self, idx):
+        return self.data[idx]
+
+def test_checkpoint_functionality(model, optimizer, device, val_loader, criterion, scaler, UNet, Adam):
+    """
+    Test that the model can be saved and loaded correctly.
+    Returns True if the test passes, False otherwise.
+    """
+    # Import locally to prevent a circular dependency
+    from XPointMLTest import validate_one_epoch
+
+    print("\n" + "="*60)
+    print("TESTING CHECKPOINT SAVE/LOAD FUNCTIONALITY")
+    print("="*60)
+
+    # get initial validation loss
+    model.eval()
+    initial_loss = validate_one_epoch(model, val_loader, criterion, device)
+    print(f"Initial validation loss: {initial_loss:.6f}")
+
+    # save checkpoint
+    test_checkpoint_path = "smoke_test_checkpoint.pt"
+    checkpoint = {
+        'model_state_dict': model.state_dict(),
+        'optimizer_state_dict': optimizer.state_dict(),
+        'val_loss': initial_loss,
+        'test_value': 42
+    }
+
+    torch.save(checkpoint, test_checkpoint_path)
+    print(f"Saved checkpoint to {test_checkpoint_path}")
+
+    # create a new model and optimizer
+    model2 = UNet(input_channels=4, base_channels=64).to(device)
+    optimizer2 = Adam(model2.parameters(), lr=1e-5)
+
+    # load checkpoint
+    loaded_checkpoint = torch.load(test_checkpoint_path, map_location=device, weights_only=False)
+    model2.load_state_dict(loaded_checkpoint['model_state_dict'])
+    optimizer2.load_state_dict(loaded_checkpoint['optimizer_state_dict'])
+
+    assert loaded_checkpoint['test_value'] == 42, "Test value mismatch!"
+    print("Checkpoint test value verified")
+
+    # get the loaded model's validation loss
+    model2.eval()
+    loaded_loss = validate_one_epoch(model2, val_loader, criterion, device)
+    print(f"Loaded model validation loss: {loaded_loss:.6f}")
+
+    # check if losses match
+    loss_diff = abs(initial_loss - loaded_loss)
+    success = loss_diff < 1e-6
+    if success:
+        print(f"✓ Checkpoint test PASSED (loss difference: {loss_diff:.2e})")
+    else:
+        print(f"✗ Checkpoint test FAILED (loss difference: {loss_diff:.2e})")
+
+    if os.path.exists(test_checkpoint_path):
+        os.remove(test_checkpoint_path)
+        print(f"Cleaned up {test_checkpoint_path}")
+
+    print("="*60 + "\n")
+    return success
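For reviewers, a minimal usage sketch of the new dataset (assuming ci_tests.py is importable from the repository root; the shapes follow from the 4-channel "all" tensor and 1-channel "mask" built in _generate_frame):

    from torch.utils.data import DataLoader
    from ci_tests import SyntheticXPointDataset

    ds = SyntheticXPointDataset(nframes=2, shape=(64, 64), nxpoints=3)
    loader = DataLoader(ds, batch_size=1, shuffle=False)
    batch = next(iter(loader))
    print(batch["all"].shape)   # torch.Size([1, 4, 64, 64]): psi, bx, by, jz
    print(batch["mask"].shape)  # torch.Size([1, 1, 64, 64]): binary X-point mask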
