diff --git a/XPointMLTest.py b/XPointMLTest.py
index 201f862..b0c2b5d 100644
--- a/XPointMLTest.py
+++ b/XPointMLTest.py
@@ -20,6 +20,8 @@
 
 from timeit import default_timer as timer
 
+from ci_tests import SyntheticXPointDataset, test_checkpoint_functionality
+
 def expand_xpoints_mask(binary_mask, kernel_size=9):
     """
     Expands each X-point in a binary mask to include surrounding cells
@@ -668,10 +670,19 @@ def parseCommandLineArgs():
                         help='create figures of the ground truth X-points and model identified X-points')
     parser.add_argument('--plotDir', type=Path, default="./plots",
                         help='directory where figures are written')
+
+    # CI TEST: Add smoke test flag
+    parser.add_argument('--smoke-test', action='store_true',
+                        help='Run a minimal smoke test for CI (overrides other parameters)')
+
     args = parser.parse_args()
     return args
 
 def checkCommandLineArgs(args):
+    # CI TEST: Skip file checks in smoke test mode
+    if args.smoke_test:
+        return
+
     if args.xptCacheDir != None:
         if not args.xptCacheDir.is_dir():
             print(f"Xpoint cache directory {args.xptCacheDir} does not exist. "
@@ -801,6 +812,32 @@ def load_model_checkpoint(model, optimizer, checkpoint_path):
 
 def main():
     args = parseCommandLineArgs()
+
+    # CI TEST: Override parameters for smoke test
+    if args.smoke_test:
+        print("=" * 60)
+        print("RUNNING IN SMOKE TEST MODE FOR CI")
+        print("=" * 60)
+
+        # Override with minimal parameters
+        args.epochs = 5
+        args.batchSize = 1
+        args.trainFrameFirst = 1
+        args.trainFrameLast = 11  # 10 frames for training
+        args.validationFrameFirst = 11
+        args.validationFrameLast = 12  # 1 frame for validation
+        args.plot = False  # Disable plotting for CI
+        args.checkPointFrequency = 2  # Save more frequently
+        args.minTrainingLoss = 0  # Don't fail on convergence in smoke test
+
+        print("Smoke test parameters:")
+        print(f"  - Training frames: {args.trainFrameFirst} to {args.trainFrameLast-1}")
+        print(f"  - Validation frames: {args.validationFrameFirst} to {args.validationFrameLast-1}")
+        print(f"  - Epochs: {args.epochs}")
+        print(f"  - Batch size: {args.batchSize}")
+        print("  - Plotting disabled")
+        print("=" * 60)
+
     checkCommandLineArgs(args)
     printCommandLineArgs(args)
 
@@ -809,13 +846,22 @@ def main():
 
     os.makedirs(outDir, exist_ok=True)
 
     t0 = timer()
-    train_fnums = range(args.trainFrameFirst, args.trainFrameLast)
-    val_fnums = range(args.validationFrameFirst, args.validationFrameLast)
+
+    # CI TEST: Use synthetic data for smoke test
+    if args.smoke_test:
+        print("\nUsing synthetic data for smoke test...")
+        train_dataset = SyntheticXPointDataset(nframes=10, shape=(64, 64), nxpoints=3)
+        val_dataset = SyntheticXPointDataset(nframes=1, shape=(64, 64), nxpoints=3, seed=123)
+        print(f"Created synthetic datasets: {len(train_dataset)} train, {len(val_dataset)} val frames")
+    else:
+        # Original data loading
+        train_fnums = range(args.trainFrameFirst, args.trainFrameLast)
+        val_fnums = range(args.validationFrameFirst, args.validationFrameLast)
 
-    train_dataset = XPointDataset(args.paramFile, train_fnums,
-                                  xptCacheDir=args.xptCacheDir, rotateAndReflect=True)
-    val_dataset = XPointDataset(args.paramFile, val_fnums,
-                                xptCacheDir=args.xptCacheDir)
+        train_dataset = XPointDataset(args.paramFile, train_fnums,
+                                      xptCacheDir=args.xptCacheDir, rotateAndReflect=True)
+        val_dataset = XPointDataset(args.paramFile, val_fnums,
+                                    xptCacheDir=args.xptCacheDir)
     t1 = timer()
     print("time (s) to create gkyl data loader: " + str(t1-t0))
@@ -838,7 +884,7 @@ def main():
 
     train_loss = []
     val_loss = []
-    if os.path.exists(latest_checkpoint_path):
+    if os.path.exists(latest_checkpoint_path) and not args.smoke_test:
         model, optimizer, start_epoch, train_loss, val_loss = load_model_checkpoint(
             model, optimizer, latest_checkpoint_path
         )
@@ -862,6 +908,65 @@ def main():
         plot_training_history(train_loss, val_loss)
     print("time (s) to train model: " + str(timer()-t2))
 
+    # CI TEST: Run additional tests if in smoke test mode
+    if args.smoke_test:
+        print("\n" + "="*60)
+        print("SMOKE TEST: Running additional CI tests")
+        print("="*60)
+
+        # Test 1: Checkpoint save/load
+        checkpoint_test_passed = test_checkpoint_functionality(
+            model, optimizer, device, val_loader, criterion, None, UNet, optim.Adam
+        )
+
+        # Test 2: Inference test
+        print("Running inference test...")
+        model.eval()
+        with torch.no_grad():
+            # Get one batch
+            test_batch = next(iter(val_loader))
+            test_input = test_batch["all"].to(device)
+            test_output = model(test_input)
+
+            # Apply sigmoid to get probabilities
+            test_probs = torch.sigmoid(test_output)
+
+            print(f"Input shape: {test_input.shape}")
+            print(f"Output shape: {test_output.shape}")
+            print(f"Output range (logits): [{test_output.min():.3f}, {test_output.max():.3f}]")
+            print(f"Output range (probs): [{test_probs.min():.3f}, {test_probs.max():.3f}]")
+            print(f"Predicted X-points: {(test_probs > 0.5).sum().item()} pixels")
+
+        # Test 3: Check whether the model learned anything
+        initial_train_loss = train_loss[0] if train_loss else float('inf')
+        final_train_loss = train_loss[-1] if train_loss else float('inf')
+
+        print("\nTraining progress:")
+        print(f"Initial loss: {initial_train_loss:.6f}")
+        print(f"Final loss: {final_train_loss:.6f}")
+
+        if final_train_loss < initial_train_loss:
+            print("✓ Model showed improvement during training")
+            training_improved = True
+        else:
+            print("✗ Model did not improve during training")
+            training_improved = False
+
+        # Overall smoke test result (training improvement is reported but not required,
+        # since a few epochs on synthetic data may not converge)
+        print("\n" + "="*60)
+        print("SMOKE TEST SUMMARY")
+        print("="*60)
+        print(f"Checkpoint test: {'PASSED' if checkpoint_test_passed else 'FAILED'}")
+        print(f"Training improvement: {'YES' if training_improved else 'NO'}")
+        print(f"Overall result: {'PASSED' if checkpoint_test_passed else 'FAILED'}")
+        print("="*60)
+
+        # Return appropriate exit code for CI
+        if not checkpoint_test_passed:
+            return 1
+        else:
+            return 0
+
     requiredLossDecreaseMagnitude = args.minTrainingLoss
     if np.log10(abs(train_loss[0]/train_loss[-1])) < requiredLossDecreaseMagnitude:
         print(f"TrainLoss reduced by less than {requiredLossDecreaseMagnitude} orders of magnitude: "
@@ -874,8 +979,12 @@ def main():
 
     interpFac = 1
 
     # Evaluate on combined set for demonstration. Exam this part to see if save to remove
-    full_fnums = list(train_fnums) + list(val_fnums)
-    full_dataset = [train_dataset, val_dataset]
+    if not args.smoke_test:
+        train_fnums = range(args.trainFrameFirst, args.trainFrameLast)
+        val_fnums = range(args.validationFrameFirst, args.validationFrameLast)
+        full_dataset = [train_dataset, val_dataset]
+    else:
+        full_dataset = [val_dataset]  # Only use validation data for smoke test
 
     t4 = timer()
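Reviewer note: in smoke-test mode `main()` now returns 1 on failure and 0 on success, but CI only sees that value if the module's entry point propagates it as the process exit code. The `if __name__` block is not part of this diff, so the following is a minimal sketch of the assumed wiring, not code from the patch:

```python
# Hypothetical entry point for XPointMLTest.py (assumed, not shown in this
# diff): sys.exit() turns main()'s return value into the process exit code,
# so `python XPointMLTest.py --smoke-test` fails the CI job on a nonzero result.
import sys

if __name__ == "__main__":
    sys.exit(main())
```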
+ """ + def __init__(self, nframes=2, shape=(64, 64), nxpoints=4, seed=42): + """ + nframes: Number of synthetic frames to generate + shape: Spatial dimensions (H, W) of each frame + nxpoints: Approximate number of X-points per frame + seed: Random seed for reproducibility + """ + super().__init__() + self.nframes = nframes + self.shape = shape + self.nxpoints = nxpoints + self.rng = np.random.RandomState(seed) + + #pre-generate all frames for consistency + self.data = [] + for i in range(nframes): + frame_data = self._generate_frame(i) + self.data.append(frame_data) + + def _generate_frame(self, idx): + """Generate a single synthetic frame with X-points""" + H, W = self.shape + + #create synthetic psi field with some structure + x = np.linspace(-np.pi, np.pi, W) + y = np.linspace(-np.pi, np.pi, H) + X, Y = np.meshgrid(x, y) + + #create a field with saddle points (X-points) + psi = np.sin(X + 0.1*idx) * np.cos(Y + 0.1*idx) + \ + 0.5 * np.sin(2*X) * np.cos(2*Y) + + # add some noise + psi += 0.1 * self.rng.randn(H, W) + + #create synthetic B fields (derivatives of psi) + bx = np.gradient(psi, axis=0) + by = -np.gradient(psi, axis=1) + + #create synthetic current (Laplacian of psi) + jz = -(np.gradient(np.gradient(psi, axis=0), axis=0) + + np.gradient(np.gradient(psi, axis=1), axis=1)) + + # create X-point mask + mask = np.zeros((H, W), dtype=np.float32) + + for _ in range(self.nxpoints): + x_loc = self.rng.randint(10, W-10) + y_loc = self.rng.randint(10, H-10) + # Create 9x9 region around X-point + mask[max(0, y_loc-4):min(H, y_loc+5), + max(0, x_loc-4):min(W, x_loc+5)] = 1.0 + + #Convert to torch tensors + psi_torch = torch.from_numpy(psi.astype(np.float32)).unsqueeze(0) + bx_torch = torch.from_numpy(bx.astype(np.float32)).unsqueeze(0) + by_torch = torch.from_numpy(by.astype(np.float32)).unsqueeze(0) + jz_torch = torch.from_numpy(jz.astype(np.float32)).unsqueeze(0) + all_torch = torch.cat((psi_torch, bx_torch, by_torch, jz_torch)) + mask_torch = torch.from_numpy(mask).float().unsqueeze(0) + + x_coords = np.linspace(0, 1, W) + y_coords = np.linspace(0, 1, H) + + params = { + "axesNorm": 1.0, "plotContours": 1, "colorContours": 'k', + "numContours": 50, "axisEqual": 1, "symBar": 1, "colormap": 'bwr' + } + + return { + "fnum": idx, "rotation": 0, "reflectionAxis": -1, "psi": psi_torch, + "all": all_torch, "mask": mask_torch, "x": x_coords, "y": y_coords, + "filenameBase": f"synthetic_frame_{idx}", "params": params + } + + def __len__(self): + return self.nframes + + def __getitem__(self, idx): + return self.data[idx] + +def test_checkpoint_functionality(model, optimizer, device, val_loader, criterion, scaler, UNet, Adam): + """ + Test that model can be saved and loaded correctly. + Returns True if test passes, False otherwise. 
+ + """ + # Import locally to prevent circular dependency + from XPointMLTest import validate_one_epoch + + print("\n" + "="*60) + print("TESTING CHECKPOINT SAVE/LOAD FUNCTIONALITY") + print("="*60) + + #get initial validation loss + model.eval() + initial_loss = validate_one_epoch(model, val_loader, criterion, device) + print(f"Initial validation loss: {initial_loss:.6f}") + + #saves checkpoint + test_checkpoint_path = "smoke_test_checkpoint.pt" + checkpoint = { + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'val_loss': initial_loss, + 'test_value': 42 + } + + torch.save(checkpoint, test_checkpoint_path) + print(f"Saved checkpoint to {test_checkpoint_path}") + + # create new model and optimizer + model2 = UNet(input_channels=4, base_channels=64).to(device) + optimizer2 = Adam(model2.parameters(), lr=1e-5) + + # load checkpoint + loaded_checkpoint = torch.load(test_checkpoint_path, map_location=device, weights_only=False) + model2.load_state_dict(loaded_checkpoint['model_state_dict']) + optimizer2.load_state_dict(loaded_checkpoint['optimizer_state_dict']) + + assert loaded_checkpoint['test_value'] == 42, "Test value mismatch!" + print("Checkpoint test value verified") + + #get loaded model validation loss + model2.eval() + loaded_loss = validate_one_epoch(model2, val_loader, criterion, device) + print(f"Loaded model validation loss: {loaded_loss:.6f}") + + # check if losses match + loss_diff = abs(initial_loss - loaded_loss) + success = loss_diff < 1e-6 + if success: + print(f"✓ Checkpoint test PASSED (loss difference: {loss_diff:.2e})") + else: + print(f"✗ Checkpoint test FAILED (loss difference: {loss_diff:.2e})") + + if os.path.exists(test_checkpoint_path): + os.remove(test_checkpoint_path) + print(f"Cleaned up {test_checkpoint_path}") + + print("="*60 + "\n") + return success \ No newline at end of file diff --git a/test_xpoint_ml.py b/test_xpoint_ml.py new file mode 100644 index 0000000..842a48d --- /dev/null +++ b/test_xpoint_ml.py @@ -0,0 +1,135 @@ +import numpy as np +import torch +from torch.utils.data import DataLoader +import torch.optim as optim +import os +import pytest + +from XPointMLTest import UNet, DiceLoss, expand_xpoints_mask, validate_one_epoch +from ci_tests import SyntheticXPointDataset + +# --- Pytest Fixtures --- +@pytest.fixture +def unet_model(): + return UNet(input_channels=4, base_channels=16) + +@pytest.fixture +def dice_loss(): + return DiceLoss() + +@pytest.fixture +def synthetic_dataset(): + return SyntheticXPointDataset(nframes=2, shape=(32, 32)) + +@pytest.fixture +def synthetic_batch(synthetic_dataset): + return synthetic_dataset[0] + +# --- 1. 
+# --- 1. Unit Tests (Utils & Loss Functions) ---
+def test_expand_xpoints_mask():
+    mask = np.zeros((20, 20))
+    mask[10, 10] = 1
+    expanded = expand_xpoints_mask(mask, kernel_size=5)
+    assert expanded.shape == (20, 20)
+    assert np.sum(expanded) == 25
+    assert expanded[10, 10] == 1
+    assert expanded[8, 8] == 1
+    assert expanded[7, 7] == 0
+
+def test_dice_loss_perfect_match(dice_loss):
+    target = torch.ones(1, 1, 10, 10)
+    logits = torch.full((1, 1, 10, 10), 10.0)  # Large positive logits
+    loss = dice_loss(logits, target)
+    # Due to the smoothing factor, a perfect match doesn't give exactly 0
+    assert loss < 1e-4, f"Loss should be near 0, got {loss.item()}"
+
+def test_dice_loss_no_match(dice_loss):
+    target = torch.zeros(1, 1, 10, 10)
+    logits = torch.full((1, 1, 10, 10), 10.0)
+    loss = dice_loss(logits, target)
+    expected_loss = 1.0 - (1.0 / (100 + 1.0))
+    assert torch.isclose(loss, torch.tensor(expected_loss), atol=1e-3)
+
+# --- 2. Dataset Integrity Test ---
+def test_synthetic_dataset_integrity(synthetic_dataset):
+    assert len(synthetic_dataset) == 2
+    item = synthetic_dataset[0]
+    expected_keys = ["fnum", "all", "mask", "psi", "x", "y"]
+    assert all(key in item for key in expected_keys)
+    assert item['all'].shape == (4, 32, 32)
+    assert item['mask'].shape == (1, 32, 32)
+    assert item['psi'].shape == (1, 32, 32)
+    assert item['all'].dtype == torch.float32
+    assert item['mask'].dtype == torch.float32
+
+# --- 3. Model Forward/Backward Pass Test ---
+def test_model_forward_backward_pass(unet_model, synthetic_batch, dice_loss):
+    model = unet_model
+    loss_fn = dice_loss
+    input_tensor = synthetic_batch['all'].unsqueeze(0)
+    target_tensor = synthetic_batch['mask'].unsqueeze(0)
+    prediction = model(input_tensor)
+    assert prediction.shape == target_tensor.shape
+    loss = loss_fn(prediction, target_tensor)
+    assert loss.item() > 0
+    loss.backward()
+    has_grads = any(p.grad is not None for p in model.parameters())
+    assert has_grads, "No gradients were computed during the backward pass."
+    grad_sum = sum(p.grad.sum() for p in model.parameters() if p.grad is not None)
+    assert grad_sum != 0, "Gradients are all zero."
+
+# --- 4. Standalone checkpoint test for pytest ---
+def test_checkpoint_save_load(unet_model, synthetic_dataset):
+    """
+    Standalone pytest version of the checkpoint functionality test
+    """
+    device = torch.device("cpu")
+    model = unet_model.to(device)
+    optimizer = optim.Adam(model.parameters(), lr=1e-5)
+    criterion = DiceLoss()
+
+    # Create a simple dataloader
+    val_loader = DataLoader(synthetic_dataset, batch_size=1, shuffle=False)
+
+    # Get the initial loss
+    initial_loss = validate_one_epoch(model, val_loader, criterion, device)
+
+    # Save a checkpoint
+    test_checkpoint_path = "test_checkpoint_pytest.pt"
+    checkpoint = {
+        'model_state_dict': model.state_dict(),
+        'optimizer_state_dict': optimizer.state_dict(),
+        'val_loss': initial_loss,
+        'test_value': 42
+    }
+    torch.save(checkpoint, test_checkpoint_path)
+
+    # Create a new model and load the checkpoint
+    model2 = UNet(input_channels=4, base_channels=16).to(device)
+    optimizer2 = optim.Adam(model2.parameters(), lr=1e-5)
+
+    loaded_checkpoint = torch.load(test_checkpoint_path, map_location=device, weights_only=False)
+    model2.load_state_dict(loaded_checkpoint['model_state_dict'])
+    optimizer2.load_state_dict(loaded_checkpoint['optimizer_state_dict'])
+
+    assert loaded_checkpoint['test_value'] == 42
+
+    # Get the loaded model's loss
+    loaded_loss = validate_one_epoch(model2, val_loader, criterion, device)
+
+    # Check that the losses match
+    loss_diff = abs(initial_loss - loaded_loss)
+    assert loss_diff < 1e-6, f"Loss difference too large: {loss_diff}"
+
+    # Cleanup
+    if os.path.exists(test_checkpoint_path):
+        os.remove(test_checkpoint_path)
+
+def test_model_inference(unet_model, synthetic_batch):
+    model = unet_model
+    input_tensor = synthetic_batch['all'].unsqueeze(0)
+    with torch.no_grad():
+        output = model(input_tensor)
+    assert output.shape == (1, 1, 32, 32)
+    assert output.dtype == torch.float32
+    assert torch.isfinite(output).all()
\ No newline at end of file
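Reviewer note: taken together, the patch gives CI two entry points — the pytest suite in `test_xpoint_ml.py` and the end-to-end `--smoke-test` path in `XPointMLTest.py`. A local driver mirroring the assumed CI job might look like the sketch below; the wrapper itself is hypothetical, and only the file names and the flag come from this diff:

```python
# Hypothetical local CI driver: run the unit tests, then the smoke test,
# and stop at the first step that exits nonzero. Note the smoke test's
# exit code is only meaningful if XPointMLTest.py propagates main()'s
# return value (see the entry-point note earlier).
import subprocess
import sys

STEPS = [
    [sys.executable, "-m", "pytest", "-v", "test_xpoint_ml.py"],
    [sys.executable, "XPointMLTest.py", "--smoke-test"],
]

for cmd in STEPS:
    print("running:", " ".join(cmd))
    rc = subprocess.run(cmd).returncode
    if rc != 0:
        sys.exit(rc)
print("all CI steps passed")
```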