feat: add latent encoder/decoder infrastructure for Graph-EFM port#648
feat: add latent encoder/decoder infrastructure for Graph-EFM port#648Sir-Sloth-The-Lazy wants to merge 41 commits into
Conversation
- Update test_datasets.py to use ForecasterModule instead of GraphLAM - Update test_plotting.py to use ForecasterModule instead of GraphLAM - Fix interior_mask_bool property shape (1,) -> (N,) for correct loss masking - Fix all_gather_cat to handle single-device runs without incorrect dim collapse
…r hierarchy - Replace opaque argparse.Namespace with explicit keyword arguments in StepPredictor, BaseGraphModel, BaseHiGraphModel, GraphLAM, HiLAM, and HiLAMParallel __init__ methods - Reorder methods in step_predictor.py: forward/expand_to_batch now appear before clamping methods - Update all instantiation sites (train_model.py, test_training.py, test_prediction_model_classes.py) to pass explicit kwargs - HiLAM helper methods (make_same/up/down_gnns) now use self.hidden_dim and self.hidden_layers instead of args parameter Addresses review comments on PR mllam#208.
- Rename border to boundary in Forecaster - Pass Forecaster object to ForecasterModule init instead of Predictor - Remove inline imports in ForecasterModule - Move loss-related pred_std logic fully into ForecasterModule - Delete obsolete test_refactored_hierarchy.py
Co-authored-by: Joel Oskarsson <joel.oskarsson@outlook.com>
- Add predicts_std property to StepPredictor, Forecaster and ARForecaster so ForecasterModule can query the forecaster instead of taking output_std as a separate constructor argument - Remove output_std parameter from ForecasterModule; use self._forecaster.predicts_std throughout - Move fallback per_var_std logic out of forecast_for_batch into each step method so pred_std is None before fallback, enabling direct None checks instead of hparam checks - Replace len(datastore.boundary_mask) with datastore.num_grid_points in StepPredictor to avoid relying on boundary_mask - Move get_state_feature_weighting and ARForecaster inline imports to module-level imports in forecaster_module.py and train_model.py - Fix statement ordering in StepPredictor.__init__ so register_buffer for grid_static_features appears directly after building the tensor - Replace dict+loop pattern for registering state_mean/state_std buffers with two direct register_buffer calls - Remove all internal Item N checklist references from comments - Remove TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD env var hack; pass weights_only=False explicitly to load_from_checkpoint calls and weights_only=True to torch.load in test_graph_creation.py - Add test_step_predictor_no_static_features to verify models initialise and run correctly when the datastore returns None for static features - Fix graph= -> graph_name= and model.forecaster -> model._forecaster in tests to match current API
…r_batch Makes the forecasting path tolerant to batch-folded execution so that future ensemble generation can fold (S, B) into (S*B) before calling ARForecaster, without any changes to ARForecaster or StepPredictor. Prediction is kept folded through the existing deterministic logging and aggregation paths so all dim assumptions in training_step, validation_step, and test_step remain correct. Unfolding to (*leading, T, N, F) is deferred to ensemble-specific subclasses (e.g. EnsForecasterModule). Adds test_fold_unfold_equivalence to confirm ARForecaster's rollout is rank-transparent under a pre-entry fold.
…ecast_for_batch" This reverts commit e659f7a.
Remove the thin forecast_for_batch wrapper, inlining the batch unpacking and forecaster call into each step method. Also fix _forecaster -> forecaster attribute references in checkpoint loading and tests to match the actual attribute name.
…sue-49 Resolved conflicts: - Removed ar_model.py (replaced by ForecasterModule/ARForecaster hierarchy) - Updated test_training.py: use ForecasterModule imports, dynamic device allocation, and wandb.init(mode="disabled") - Updated test_plotting.py: replace ARModel.all_gather_cat with ForecasterModule.all_gather_cat, fix model.args.create_gif reference Ported features from ar_model.py: - GIF export (create_gif param + plot_examples logic) into ForecasterModule - Added common_step to ForecasterModule; refactored training/validation/ test steps to delegate to it, eliminating duplicated forecaster calls
Resolved conflicts between local model class hierarchy refactor and upstream's equivalent refactor (mllam#208). Upstream's version removes the config parameter from model constructors, passing output_clamping_lower/upper explicitly instead, and adds create_gif/args to ForecasterModule. Also includes upstream's new load_forecaster_module_from_checkpoint helper and additional test coverage.
Adds neural_lam/models/latent/ with the encoder and decoder submodules needed by the probabilistic GraphEFM model (issue mllam#62). Ported from the prob_model_lam branch with adaptations for the current main architecture: - constants.GRID_STATE_DIM replaced by a num_state_vars constructor arg - interaction_net imports updated to neural_lam.gnn_layers - GraphLatentDecoder.processor unified with the other four GNN-seq constructions to use utils.make_gnn_seq (handles processor_layers=0) - HiGraph{Encoder,Decoder} guard against single-level meshes where the latent variable would be silently ignored - ConstantLatentEncoder docstring documents the N(1,1) vs N(0,1) discrepancy with the prob_model_lam CLI help (open question upstream) Also adds to neural_lam/utils.py: - IdentityModule: pass-through nn.Module for multi-arg sequential GNNs - make_gnn_seq: builds a pyg.nn.Sequential of InteractionNets, or an IdentityModule when num_gnn_layers=0; lazy-imports gnn_layers to avoid the existing gnn_layers -> utils circular dependency 17 tests in tests/test_latent_modules.py cover output shapes, distribution properties, backprop to every parameter, 2- and 3-level hierarchical graphs, intra_level_layers=0, and the single-level guard.
|
Pinging @joeloskarsson ! |
There was a problem hiding this comment.
Is this (and also forecaster, ar_forecaster and step_predictor) now added as duplicate files here? A bit confused about these additions? These files are placed differently on main?
There was a problem hiding this comment.
This is branched of the latest main , I don't think these are duplicates , however will tell you concertely in this weeks meeting.
There was a problem hiding this comment.
Hmm, if I go to the "files changed" tab then I see these files as added in this PR? This is what confuses me.
There was a problem hiding this comment.
I think there is a discussion to be had if we want to use the encoder/decoder terminology here, since this means something different than the encoder/decoder used in the encode-process-decode framework of the current deterministic models. I think this should be fine, since these are explicitly Latent{Encoder, Decoder}, and we don't have any other classes called encoder/decoder right now. But wanted to make note of this, and open to other ideas.
There was a problem hiding this comment.
I have this noted and will bring this up in future discussions , if ambiguity comes up !
joeloskarsson
left a comment
There was a problem hiding this comment.
Had a look over the encoders now, decoders still TODO :)
| ): | ||
| super().__init__(latent_dim, output_dist) | ||
|
|
||
| self.g2m_gnn = PropagationNet( |
There was a problem hiding this comment.
I think also here the gnn type should be set with the corresponding argparse flag.
There was a problem hiding this comment.
I've switched it to resolve the GNN via get_gnn_class() and added a g2m_gnn_type parameter wired to the existing --g2m_gnn_type flag, matching the deterministic models. Two things I'd like your opinion on before I finalize:
Default value. The original prob_model_lam encoder used PropagationNet for the g2m step, but --g2m_gnn_type defaults to InteractionNet. I've defaulted the new parameter to InteractionNet for consistency with the flag and the rest of the codebase, which does change the encoder's default behavior vs. the original port. Are you happy with that, or would you prefer keeping PropagationNet as the default here?
Shared vs. separate flag. Wiring this to --g2m_gnn_type means the latent encoder's g2m GNN and the step predictor's g2m GNN share a single flag and can't be configured independently. In the original Graph-EFM these differed (predictor g2m = InteractionNet, encoder g2m = PropagationNet). Do you think sharing the one flag is fine, or should the encoder get its own flag (e.g. --latent_g2m_gnn_type) to preserve that independence?
| g2m_edge_index, | ||
| m2m_edge_index, | ||
| hidden_dim, | ||
| processor_layers, |
There was a problem hiding this comment.
This feels like a strange argument name to me. We don't really have a "processor" here, as this is all an encoder.
| Used as a non-learned prior in ``GraphEFM`` when ``learn_prior`` is | ||
| disabled. ``compute_dist_params`` returns a tensor of ones, so the | ||
| resulting Normal is ``Normal(mean=1, std=1)`` for ``output_dist= | ||
| "isotropic"`` and ``Normal(mean=1, std=softplus(1)+eps)`` for | ||
| ``output_dist="diagonal"``. (Note: the ``train_model.py`` CLI help on | ||
| ``prob_model_lam`` describes this prior as "mean 0"; the code itself | ||
| has always produced mean 1. Preserved as-is during the port for | ||
| behavioral parity — open question for upstream.) |
There was a problem hiding this comment.
Should be mean 0, the mean 1 is probably just a bug. Does not really change anything, just a constant offset, but makes more sense to use 0 mean.
| num_gnn_layers is 0. | ||
| """ | ||
| # First-party | ||
| from neural_lam.gnn_layers import InteractionNet |
There was a problem hiding this comment.
Can we make this general enough to not just support Inets?
| if num_gnn_layers == 0: | ||
| return IdentityModule() |
There was a problem hiding this comment.
This feels like not the best place to handle this. If we call this to build a gnn sequence, then we should at least request one layer. Otherwise there should be a check that makes this not being called at all.
| self.g2m_gnn = PropagationNet( | ||
| g2m_edge_index, | ||
| hidden_dim, | ||
| hidden_layers=hidden_layers, | ||
| update_edges=False, | ||
| ) | ||
|
|
||
| self.mesh_up_gnns = nn.ModuleList( | ||
| [ | ||
| PropagationNet( | ||
| edge_index, | ||
| hidden_dim, | ||
| hidden_layers=hidden_layers, | ||
| update_edges=False, | ||
| ) | ||
| for edge_index in mesh_up_edge_index | ||
| ] | ||
| ) |
There was a problem hiding this comment.
Can we also here use argparse options for choosing the gnn types.
Co-authored-by: Joel Oskarsson <joel.oskarsson@outlook.com>
|
On it ! |
Make GNN types configurable and tidy up the latent modules per PR review: - make_gnn_seq: accept a gnn_type arg (resolved via get_gnn_class) so it is not limited to InteractionNet, and make it strict (raise on num_gnn_layers < 1) instead of silently returning an IdentityModule; callers now own the no-op (identity) case explicitly. - graph/hi encoders and decoders: expose g2m/m2g/mesh_up/mesh_down gnn_type parameters wired to get_gnn_class, with defaults matching prob_model_lam. - graph encoder/decoder: rename processor_layers -> m2m_layers (and the self.processor attribute -> self.m2m_gnns); "processor" was misleading in an encoder/decoder context. - ConstantLatentEncoder: return zeros instead of ones so the static prior is mean 0 (fixes the prob_model_lam mean-1 bug; matches its own CLI help). - tests: update for the renamed arg and strict make_gnn_seq, add coverage for the flat zero-m2m identity path, and assert the constant prior is N(0, 1).
|
I have pushed the latest changes , please have a look whenever you have time |
Port prob_model_lam's GraphEFM single-step half onto the StepPredictor interface, reusing the latent encoder/decoder infra. The predictor owns its conditional prior, variational encoder, and latent decoder, plus the per-step ELBO pieces (compute_step_loss) and sampling helpers; rollout, ELBO assembly, ensemble logic, and logging stay outside it. - forward is source's predict_step (prior rsample -> decode -> sampled next state); no rescaling/clamping - loss_fn and interior_mask are threaded parameters, not predictor state; compute_step_loss takes compute_kl (kl_term=None when off) - per_var_std mirrors ForecasterModule's formula, hence the config arg - one class for flat + hierarchical meshes, resolved from self.hierarchical - not registered in MODELS yet (needs config / no mesh_aggr); config-aware assembly deferred to the ensemble-forecaster PR Adds tests/test_graph_efm_predictor.py covering forward shapes, output_std, compute_step_loss + KL toggle, differentiability, member stochasticity, sample_obs_noise, and the per_var_std formula (flat + hierarchical).
Describe your changes
Adds
neural_lam/models/latent/the encoder and decoder submodules that the probabilistic Graph-EFM model needs. This is infrastructure-only: no model uses these classes yet. They are consumed by the upcomingGraphEFMPredictor(StepPredictorsubclass) which will close #62.New modules in
neural_lam/models/latent/:base_encoder.pyBaseLatentEncoder: abstract base; handles isotropic / diagonal Gaussian outputbase_decoder.pyBaseGraphLatentDecoder: abstract base; residual grid MLP + latent embedder + param mapconstant_encoder.pyConstantLatentEncoder: input-independent prior (used whenlearn_prior=False)graph_encoder.pyGraphLatentEncoder: flat graph: grid → mesh via PropagationNet + InteractionNet stackgraph_decoder.pyGraphLatentDecoder: flat graph: grid + latent → grid via g2m / processor / m2ghi_graph_encoder.pyHiGraphLatentEncoder: hierarchical mesh: propagates up to top level, reads out latent disthi_graph_decoder.pyHiGraphLatentDecoder: hierarchical mesh: up + latent fusion + down pass back to gridAdaptations from
prob_model_lamfor the currentmainarchitecture:constants.GRID_STATE_DIM(removed) →num_state_varsconstructor arg on all decodersfrom neural_lam.interaction_net import ...→from neural_lam.gnn_layers import ...GraphLatentDecoder.processorunified with the other four GNN-seq constructions to useutils.make_gnn_seq, which also handlesprocessor_layers=0gracefullyHiGraph{Encoder,Decoder}raiseValueErrorfor single-level meshes (where the latent would be silently ignored); points users to the flat variantsAlso adds to
neural_lam/utils.py:IdentityModulepass-throughnn.Modulefor multi-argpyg.nn.Sequentialpipelinesmake_gnn_seqbuilds apyg.nn.SequentialofInteractionNetlayers, orIdentityModulewhennum_gnn_layers=0; lazy-importsgnn_layersto avoid the existinggnn_layers → utilscircular dependencyOpen question flagged in
constant_encoder.pydocstring: the static prior returnsNormal(mean=1, std=1)faithful toprob_model_lambut the--learn_priorCLI help on that branch describes it as "mean 0". One of the two is wrong; will raise it separately with @joeloskarsson.Dependencies:
PropagationNetandInteractionNetare already onmain(AddsPropagationNetGNN layer and makes it optionally usable in existing deterministic models #507 merged).Issue Link
Partially addresses #62 (prerequisite for the
GraphEFMPredictorPR).Type of change
Checklist before requesting a review
Checklist for reviewers
Author checklist after completed review
Checklist for assignee