@@ -20,6 +20,8 @@

 from timeit import default_timer as timer

+from ci_tests import SyntheticXPointDataset, test_checkpoint_functionality
+
 def expand_xpoints_mask(binary_mask, kernel_size=9):
     """
     Expands each X-point in a binary mask to include surrounding cells
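Note: the ci_tests module imported here is not part of this diff. A minimal sketch of what SyntheticXPointDataset might look like — assuming a torch Dataset whose items are dicts carrying the input under the "all" key (the key the inference test later in this diff reads) and the target mask under a hypothetical "label" key; only the constructor arguments are confirmed by the code below:

# Sketch only: ci_tests is not shown in this diff; the "label" key is an
# assumption, while "all" and the constructor arguments come from the diff.
import torch
from torch.utils.data import Dataset

class SyntheticXPointDataset(Dataset):
    def __init__(self, nframes=10, shape=(64, 64), nxpoints=3, seed=42):
        rng = torch.Generator().manual_seed(seed)
        self.frames = []
        for _ in range(nframes):
            field = torch.randn(1, *shape, generator=rng)  # fake input field
            mask = torch.zeros(1, *shape)                  # sparse X-point mask
            ys = torch.randint(0, shape[0], (nxpoints,), generator=rng)
            xs = torch.randint(0, shape[1], (nxpoints,), generator=rng)
            mask[0, ys, xs] = 1.0
            self.frames.append({"all": field, "label": mask})

    def __len__(self):
        return len(self.frames)

    def __getitem__(self, idx):
        return self.frames[idx]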
@@ -668,10 +670,19 @@ def parseCommandLineArgs():
                         help='create figures of the ground truth X-points and model identified X-points')
     parser.add_argument('--plotDir', type=Path, default="./plots",
                         help='directory where figures are written')
+
+    # CI TEST: Add smoke test flag
+    parser.add_argument('--smoke-test', action='store_true',
+                        help='Run a minimal smoke test for CI (overrides other parameters)')
+
     args = parser.parse_args()
     return args

 def checkCommandLineArgs(args):
+    # CI TEST: Skip file checks in smoke test mode
+    if args.smoke_test:
+        return
+
     if args.xptCacheDir != None:
         if not args.xptCacheDir.is_dir():
             print(f"Xpoint cache directory {args.xptCacheDir} does not exist. "
@@ -801,6 +812,32 @@ def load_model_checkpoint(model, optimizer, checkpoint_path):

 def main():
     args = parseCommandLineArgs()
+
+    # CI TEST: Override parameters for smoke test
+    if args.smoke_test:
+        print("=" * 60)
+        print("RUNNING IN SMOKE TEST MODE FOR CI")
+        print("=" * 60)
+
+        # Override with minimal parameters
+        args.epochs = 5
+        args.batchSize = 1
+        args.trainFrameFirst = 1
+        args.trainFrameLast = 11  # 10 frames for training
+        args.validationFrameFirst = 11
+        args.validationFrameLast = 12  # 1 frame for validation
+        args.plot = False  # Disable plotting for CI
+        args.checkPointFrequency = 2  # Save more frequently
+        args.minTrainingLoss = 0  # Don't fail on convergence in smoke test
+
+        print("Smoke test parameters:")
+        print(f"  - Training frames: {args.trainFrameFirst} to {args.trainFrameLast - 1}")
+        print(f"  - Validation frames: {args.validationFrameFirst} to {args.validationFrameLast - 1}")
+        print(f"  - Epochs: {args.epochs}")
+        print(f"  - Batch size: {args.batchSize}")
+        print("  - Plotting disabled")
+        print("=" * 60)
+
     checkCommandLineArgs(args)
     printCommandLineArgs(args)
@@ -809,13 +846,22 @@ def main():
     os.makedirs(outDir, exist_ok=True)

     t0 = timer()
-    train_fnums = range(args.trainFrameFirst, args.trainFrameLast)
-    val_fnums = range(args.validationFrameFirst, args.validationFrameLast)
+
+    # CI TEST: Use synthetic data for smoke test
+    if args.smoke_test:
+        print("\nUsing synthetic data for smoke test...")
+        train_dataset = SyntheticXPointDataset(nframes=10, shape=(64, 64), nxpoints=3)
+        val_dataset = SyntheticXPointDataset(nframes=1, shape=(64, 64), nxpoints=3, seed=123)
+        print(f"Created synthetic datasets: {len(train_dataset)} train, {len(val_dataset)} val frames")
+    else:
+        # Original data loading
+        train_fnums = range(args.trainFrameFirst, args.trainFrameLast)
+        val_fnums = range(args.validationFrameFirst, args.validationFrameLast)

-    train_dataset = XPointDataset(args.paramFile, train_fnums,
-                                  xptCacheDir=args.xptCacheDir, rotateAndReflect=True)
-    val_dataset = XPointDataset(args.paramFile, val_fnums,
-                                xptCacheDir=args.xptCacheDir)
+        train_dataset = XPointDataset(args.paramFile, train_fnums,
+                                      xptCacheDir=args.xptCacheDir, rotateAndReflect=True)
+        val_dataset = XPointDataset(args.paramFile, val_fnums,
+                                    xptCacheDir=args.xptCacheDir)

     t1 = timer()
     print("time (s) to create gkyl data loader: " + str(t1 - t0))
@@ -838,7 +884,7 @@ def main():
     train_loss = []
     val_loss = []

-    if os.path.exists(latest_checkpoint_path):
+    if os.path.exists(latest_checkpoint_path) and not args.smoke_test:
         model, optimizer, start_epoch, train_loss, val_loss = load_model_checkpoint(
             model, optimizer, latest_checkpoint_path
         )
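load_model_checkpoint itself is untouched by this change; only its call is now gated on not running the smoke test. From its signature and the five values unpacked above, it plausibly has this shape (the checkpoint dict keys are assumptions; only the argument and return lists are visible in the diff):

import torch

def load_model_checkpoint(model, optimizer, checkpoint_path):
    # Restore weights, optimizer state, and training bookkeeping.
    # Key names below are assumed; the return shape matches the call site.
    ckpt = torch.load(checkpoint_path)
    model.load_state_dict(ckpt["model_state_dict"])
    optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    return model, optimizer, ckpt["epoch"], ckpt["train_loss"], ckpt["val_loss"]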
@@ -862,6 +908,65 @@ def main():
     plot_training_history(train_loss, val_loss)
     print("time (s) to train model: " + str(timer() - t2))

+    # CI TEST: Run additional tests if in smoke test mode
+    if args.smoke_test:
+        print("\n" + "=" * 60)
+        print("SMOKE TEST: Running additional CI tests")
+        print("=" * 60)
+
+        # Test 1: Checkpoint save/load
+        checkpoint_test_passed = test_checkpoint_functionality(
+            model, optimizer, device, val_loader, criterion, None, UNet, optim.Adam
+        )
+
+        # Test 2: Inference test
+        print("Running inference test...")
+        model.eval()
+        with torch.no_grad():
+            # Get one batch
+            test_batch = next(iter(val_loader))
+            test_input = test_batch["all"].to(device)
+            test_output = model(test_input)
+
+            # Apply sigmoid to get probabilities
+            test_probs = torch.sigmoid(test_output)
+
+            print(f"Input shape: {test_input.shape}")
+            print(f"Output shape: {test_output.shape}")
+            print(f"Output range (logits): [{test_output.min():.3f}, {test_output.max():.3f}]")
+            print(f"Output range (probs): [{test_probs.min():.3f}, {test_probs.max():.3f}]")
+            print(f"Predicted X-points: {(test_probs > 0.5).sum().item()} pixels")
+
+        # Test 3: Check if model learned anything
+        initial_train_loss = train_loss[0] if train_loss else float('inf')
+        final_train_loss = train_loss[-1] if train_loss else float('inf')
+
+        print("\nTraining progress:")
+        print(f"Initial loss: {initial_train_loss:.6f}")
+        print(f"Final loss: {final_train_loss:.6f}")
+
+        if final_train_loss < initial_train_loss:
+            print("✓ Model showed improvement during training")
+            training_improved = True
+        else:
+            print("✗ Model did not improve during training")
+            training_improved = False
+
+        # Overall smoke test result
+        print("\n" + "=" * 60)
+        print("SMOKE TEST SUMMARY")
+        print("=" * 60)
+        print(f"Checkpoint test: {'PASSED' if checkpoint_test_passed else 'FAILED'}")
+        print(f"Training improvement: {'YES' if training_improved else 'NO'}")
+        print(f"Overall result: {'PASSED' if checkpoint_test_passed else 'FAILED'}")
+        print("=" * 60)
+
+        # Return appropriate exit code for CI
+        if not checkpoint_test_passed:
+            return 1
+        else:
+            return 0
+
     requiredLossDecreaseMagnitude = args.minTrainingLoss
     if np.log10(abs(train_loss[0] / train_loss[-1])) < requiredLossDecreaseMagnitude:
         print(f"TrainLoss reduced by less than {requiredLossDecreaseMagnitude} orders of magnitude: "
@@ -874,8 +979,12 @@ def main():
     interpFac = 1

     # Evaluate on combined set for demonstration. Examine this part to see if it is safe to remove.
-    full_fnums = list(train_fnums) + list(val_fnums)
-    full_dataset = [train_dataset, val_dataset]
+    if not args.smoke_test:
+        train_fnums = range(args.trainFrameFirst, args.trainFrameLast)
+        val_fnums = range(args.validationFrameFirst, args.validationFrameLast)
+        full_dataset = [train_dataset, val_dataset]
+    else:
+        full_dataset = [val_dataset]  # Only use validation data for smoke test

     t4 = timer()

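Two closing notes. For the return 1 / return 0 added above to reach CI as a process exit code, the script's entry point would need to propagate it, e.g. via sys.exit(main()); that entry point is outside this diff, so the snippet below is an assumption. And on the convergence gate above the evaluation block: np.log10(abs(train_loss[0] / train_loss[-1])) measures orders of magnitude of improvement (0.5 → 0.0005 gives log10(1000) = 3), so the smoke-test override minTrainingLoss = 0 only trips the warning if the loss actually got worse.

# Hypothetical entry point (not shown in this diff): sys.exit(main())
# turns main()'s 0/1 return value into the exit status that CI checks.
import sys

if __name__ == "__main__":
    sys.exit(main())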