feat: add latent encoder/decoder infrastructure for Graph-EFM port #648

Draft
Sir-Sloth-The-Lazy wants to merge 41 commits into mllam:main from Sir-Sloth-The-Lazy:feat/latent-encoder-decoder-infra

Conversation

Sir-Sloth-The-Lazy commented May 27, 2026

Contributor

Describe your changes

Adds neural_lam/models/latent/, which contains the encoder and decoder submodules that the probabilistic Graph-EFM model needs. This is infrastructure-only: no model uses these classes yet. They will be consumed by the upcoming GraphEFMPredictor (a StepPredictor subclass), which will close #62.

New modules in neural_lam/models/latent/:

| File | What it adds |
|---|---|
| base_encoder.py | BaseLatentEncoder: abstract base; handles isotropic / diagonal Gaussian output |
| base_decoder.py | BaseGraphLatentDecoder: abstract base; residual grid MLP + latent embedder + param map |
| constant_encoder.py | ConstantLatentEncoder: input-independent prior (used when learn_prior=False) |
| graph_encoder.py | GraphLatentEncoder: flat graph; grid → mesh via PropagationNet + InteractionNet stack |
| graph_decoder.py | GraphLatentDecoder: flat graph; grid + latent → grid via g2m / processor / m2g |
| hi_graph_encoder.py | HiGraphLatentEncoder: hierarchical mesh; propagates up to top level, reads out latent dist |
| hi_graph_decoder.py | HiGraphLatentDecoder: hierarchical mesh; up + latent fusion + down pass back to grid |
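
As a rough illustration of the isotropic / diagonal handling attributed to BaseLatentEncoder, the sketch below maps raw encoder outputs for one mesh node to a mean and a standard deviation. It is framework-free plain Python, not the code in this PR: the name `to_gaussian_params`, the value of `EPS`, and the convention that the diagonal case packs means first and raw stds second are all assumptions made for the example.

```python
import math

EPS = 1e-4  # assumed small constant keeping std strictly positive


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def to_gaussian_params(raw, latent_dim, output_dist):
    """Map raw outputs for one node to (means, stds).

    isotropic: raw has latent_dim entries; std is fixed to 1.
    diagonal:  raw has 2 * latent_dim entries; the second half goes
               through softplus to give a positive per-dimension std.
    """
    if output_dist == "isotropic":
        if len(raw) != latent_dim:
            raise ValueError("isotropic expects latent_dim values")
        return list(raw), [1.0] * latent_dim
    if output_dist == "diagonal":
        if len(raw) != 2 * latent_dim:
            raise ValueError("diagonal expects 2 * latent_dim values")
        means, raw_stds = raw[:latent_dim], raw[latent_dim:]
        return list(means), [softplus(s) + EPS for s in raw_stds]
    raise ValueError(f"Unknown output_dist: {output_dist!r}")


iso_mean, iso_std = to_gaussian_params([0.0, 0.0], 2, "isotropic")
diag_mean, diag_std = to_gaussian_params([0.0, 0.0, 1.0, 1.0], 2, "diagonal")
```

Under these assumptions a raw std value of 1 gives `softplus(1) + EPS`, which is the same expression the ConstantLatentEncoder docstring uses for the diagonal case.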

Adaptations from prob_model_lam for the current main architecture:

  • constants.GRID_STATE_DIM (removed) → num_state_vars constructor arg on all decoders
  • from neural_lam.interaction_net import ... → from neural_lam.gnn_layers import ...
  • GraphLatentDecoder.processor unified with the other four GNN-seq constructions to use utils.make_gnn_seq, which also handles processor_layers=0 gracefully
  • HiGraph{Encoder,Decoder} raise ValueError for single-level meshes (where the latent would be silently ignored); points users to the flat variants
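
The single-level guard in the last bullet could look roughly like the following. This is a minimal sketch, not the PR's implementation: the helper name `check_hierarchical` and the assumption that the constructor receives one up-edge index per pair of adjacent mesh levels are both hypothetical.

```python
def check_hierarchical(mesh_up_edge_index):
    """Return the number of mesh levels, raising for a single-level mesh.

    Assumes one edge-index entry per pair of adjacent levels, so an
    L-level mesh has L - 1 entries and a single-level mesh has none.
    """
    num_levels = len(mesh_up_edge_index) + 1
    if num_levels < 2:
        raise ValueError(
            "Hierarchical latent modules need at least 2 mesh levels, got "
            f"{num_levels}; use the flat GraphLatentEncoder / "
            "GraphLatentDecoder instead."
        )
    return num_levels


# Placeholder strings stand in for edge-index tensors.
levels = check_hierarchical(["up_0_to_1", "up_1_to_2"])
```

Failing at construction time makes the misconfiguration visible immediately, rather than training a model whose latent variable never reaches the output.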

Also adds to neural_lam/utils.py:

  • IdentityModule: pass-through nn.Module for multi-arg pyg.nn.Sequential pipelines
  • make_gnn_seq: builds a pyg.nn.Sequential of InteractionNet layers, or an IdentityModule when num_gnn_layers=0; lazy-imports gnn_layers to avoid the existing gnn_layers → utils circular dependency
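
To show why a function-level (lazy) import sidesteps a utils ↔ gnn_layers cycle, here is a self-contained sketch that writes two throwaway modules to a temporary directory. The module names `demo_utils` and `demo_gnn_layers`, the helper `make_mlp_name`, and the list-based return value are all invented stand-ins; only the import pattern is the point.

```python
import importlib
import sys
import tempfile
import textwrap
from pathlib import Path

# Hypothetical stand-in for utils.py: the layers import is deferred.
UTILS_SRC = textwrap.dedent(
    """
    def make_mlp_name():
        return "mlp"

    def make_gnn_seq(num_gnn_layers):
        # Lazy import: resolved at call time, once both modules are loaded.
        from demo_gnn_layers import InteractionNet
        return [InteractionNet() for _ in range(num_gnn_layers)]
    """
)

# Hypothetical stand-in for gnn_layers.py: imports utils at module level.
LAYERS_SRC = textwrap.dedent(
    """
    from demo_utils import make_mlp_name

    class InteractionNet:
        kind = make_mlp_name()
    """
)


def build_demo_layers(num_gnn_layers):
    """Import the two demo modules from a temp dir and build the layers."""
    with tempfile.TemporaryDirectory() as tmp:
        Path(tmp, "demo_utils.py").write_text(UTILS_SRC)
        Path(tmp, "demo_gnn_layers.py").write_text(LAYERS_SRC)
        sys.path.insert(0, tmp)
        try:
            utils = importlib.import_module("demo_utils")
            return utils.make_gnn_seq(num_gnn_layers)
        finally:
            # Leave the interpreter state as we found it.
            sys.path.remove(tmp)
            sys.modules.pop("demo_utils", None)
            sys.modules.pop("demo_gnn_layers", None)


layers = build_demo_layers(2)
```

Had `demo_utils` imported `InteractionNet` at module level before defining `make_mlp_name`, importing it first would trigger `demo_gnn_layers`, which would in turn find `demo_utils` only partially initialised.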

Open question flagged in the constant_encoder.py docstring: the static prior returns Normal(mean=1, std=1), which is faithful to prob_model_lam, but the --learn_prior CLI help on that branch describes it as "mean 0". One of the two is wrong; I will raise it separately with @joeloskarsson.

Dependencies:

Issue Link

Partially addresses #62 (prerequisite for the GraphEFMPredictor PR).

Type of change

  • 🐛 Bug fix (non-breaking change that fixes an issue)
  • ✨ New feature (non-breaking change that adds functionality)
  • 💥 Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • 📖 Documentation (Addition or improvements to documentation)

Checklist before requesting a review

  • My branch is up-to-date with the target branch
  • I have performed a self-review of my code
  • For any new/modified functions/classes I have added docstrings that clearly describe its purpose, expected inputs and returned values
  • I have placed in-line comments to clarify the intent of any hard-to-understand passages of my code
  • I have updated the README to cover introduced code changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have given the PR a name that clearly describes the change, written in imperative form
  • I have requested a reviewer and an assignee

Checklist for reviewers

  • the code is readable
  • the code is well tested
  • the code is documented (including return types and parameters)
  • the code is easy to maintain

Author checklist after completed review

  • I have added a line to the CHANGELOG describing this change

Checklist for assignee

  • PR is up to date with the base branch
  • the tests pass
  • the PR is assigned to the next milestone
  • author has added an entry to the changelog

Sir-Sloth-The-Lazy and others added 30 commits February 21, 2026 17:42
- Update test_datasets.py to use ForecasterModule instead of GraphLAM
- Update test_plotting.py to use ForecasterModule instead of GraphLAM
- Fix interior_mask_bool property shape (1,) -> (N,) for correct loss masking
- Fix all_gather_cat to handle single-device runs without incorrect dim collapse
…r hierarchy

- Replace opaque argparse.Namespace with explicit keyword arguments in
  StepPredictor, BaseGraphModel, BaseHiGraphModel, GraphLAM, HiLAM,
  and HiLAMParallel __init__ methods
- Reorder methods in step_predictor.py: forward/expand_to_batch now
  appear before clamping methods
- Update all instantiation sites (train_model.py, test_training.py,
  test_prediction_model_classes.py) to pass explicit kwargs
- HiLAM helper methods (make_same/up/down_gnns) now use self.hidden_dim
  and self.hidden_layers instead of args parameter

Addresses review comments on PR mllam#208.
- Rename border to boundary in Forecaster
- Pass Forecaster object to ForecasterModule init instead of Predictor
- Remove inline imports in ForecasterModule
- Move loss-related pred_std logic fully into ForecasterModule
- Delete obsolete test_refactored_hierarchy.py
Co-authored-by: Joel Oskarsson <joel.oskarsson@outlook.com>
- Add predicts_std property to StepPredictor, Forecaster and ARForecaster
  so ForecasterModule can query the forecaster instead of taking output_std
  as a separate constructor argument
- Remove output_std parameter from ForecasterModule; use
  self._forecaster.predicts_std throughout
- Move fallback per_var_std logic out of forecast_for_batch into each
  step method so pred_std is None before fallback, enabling direct None
  checks instead of hparam checks
- Replace len(datastore.boundary_mask) with datastore.num_grid_points in
  StepPredictor to avoid relying on boundary_mask
- Move get_state_feature_weighting and ARForecaster inline imports to
  module-level imports in forecaster_module.py and train_model.py
- Fix statement ordering in StepPredictor.__init__ so register_buffer for
  grid_static_features appears directly after building the tensor
- Replace dict+loop pattern for registering state_mean/state_std buffers
  with two direct register_buffer calls
- Remove all internal Item N checklist references from comments
- Remove TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD env var hack; pass
  weights_only=False explicitly to load_from_checkpoint calls and
  weights_only=True to torch.load in test_graph_creation.py
- Add test_step_predictor_no_static_features to verify models initialise
  and run correctly when the datastore returns None for static features
- Fix graph= -> graph_name= and model.forecaster -> model._forecaster in
  tests to match current API
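
The `predicts_std` delegation and the "None before fallback" pattern described in this commit can be sketched without any framework. The class names below end in `Sketch` to mark them as stand-ins; the method names, the toy prediction rule, and the `step` helper are assumptions, not the project's actual signatures.

```python
class StepPredictorSketch:
    """Stand-in for a StepPredictor; only the std-related part is modelled."""

    def __init__(self, output_std):
        self._output_std = output_std

    @property
    def predicts_std(self):
        return self._output_std

    def forward(self, state):
        pred = [0.9 * v for v in state]  # toy one-step prediction
        pred_std = [0.1 for _ in state] if self._output_std else None
        return pred, pred_std


class ForecasterSketch:
    """Stand-in forecaster that simply delegates to its predictor."""

    def __init__(self, predictor):
        self.predictor = predictor

    @property
    def predicts_std(self):
        # The module queries the forecaster instead of storing output_std.
        return self.predictor.predicts_std

    def forward(self, state):
        return self.predictor.forward(state)


def step(forecaster, state, per_var_std):
    """One step method: pred_std stays None until the fallback is applied."""
    pred, pred_std = forecaster.forward(state)
    if pred_std is None:  # direct None check instead of an hparam check
        pred_std = per_var_std
    return pred, pred_std


deterministic = ForecasterSketch(StepPredictorSketch(output_std=False))
probabilistic = ForecasterSketch(StepPredictorSketch(output_std=True))
_, det_std = step(deterministic, [1.0, 2.0], per_var_std=[0.5, 0.5])
_, prob_std = step(probabilistic, [1.0, 2.0], per_var_std=[0.5, 0.5])
```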
…r_batch

Makes the forecasting path tolerant to batch-folded execution so that
future ensemble generation can fold (S, B) into (S*B) before calling
ARForecaster, without any changes to ARForecaster or StepPredictor.

Prediction is kept folded through the existing deterministic logging and
aggregation paths so all dim assumptions in training_step, validation_step,
and test_step remain correct. Unfolding to (*leading, T, N, F) is deferred
to ensemble-specific subclasses (e.g. EnsForecasterModule).

Adds test_fold_unfold_equivalence to confirm ARForecaster's rollout is
rank-transparent under a pre-entry fold.
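
A toy version of the fold / unfold equivalence, using nested Python lists in place of tensors: S ensemble members and B batch items are flattened to S*B samples before a per-sample rollout, then regrouped. The `rollout` rule and all function names are invented for the example; the real test presumably works on tensors through ARForecaster.

```python
def fold(nested):
    """(S, B, ...) -> (S * B, ...), also returning (S, B) for unfolding."""
    num_members, batch_size = len(nested), len(nested[0])
    flat = [item for member in nested for item in member]
    return flat, (num_members, batch_size)


def unfold(flat, shape):
    """(S * B, ...) -> (S, B, ...), the inverse of fold."""
    num_members, batch_size = shape
    return [
        flat[s * batch_size:(s + 1) * batch_size]
        for s in range(num_members)
    ]


def rollout(batch, num_steps):
    """Toy autoregressive rollout applied independently to each sample."""
    trajectories = []
    for state in batch:
        trajectory = []
        for _ in range(num_steps):
            state = [0.5 * v + 1.0 for v in state]
            trajectory.append(state)
        trajectories.append(trajectory)
    return trajectories


# S=3 members, B=2 batch items, each with a 2-value state.
ensemble = [
    [[0.0, 1.0], [2.0, 3.0]],
    [[4.0, 5.0], [6.0, 7.0]],
    [[8.0, 9.0], [10.0, 11.0]],
]
per_member = [rollout(member, num_steps=3) for member in ensemble]
flat, shape = fold(ensemble)
folded = unfold(rollout(flat, num_steps=3), shape)
```

Because the rollout never mixes samples, `folded` and `per_member` hold identical values, which is the property the commit relies on.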
Remove the thin forecast_for_batch wrapper, inlining the batch
unpacking and forecaster call into each step method. Also fix
_forecaster -> forecaster attribute references in checkpoint
loading and tests to match the actual attribute name.
…sue-49

Resolved conflicts:
- Removed ar_model.py (replaced by ForecasterModule/ARForecaster hierarchy)
- Updated test_training.py: use ForecasterModule imports, dynamic device
  allocation, and wandb.init(mode="disabled")
- Updated test_plotting.py: replace ARModel.all_gather_cat with
  ForecasterModule.all_gather_cat, fix model.args.create_gif reference

Ported features from ar_model.py:
- GIF export (create_gif param + plot_examples logic) into ForecasterModule
- Added common_step to ForecasterModule; refactored training/validation/
  test steps to delegate to it, eliminating duplicated forecaster calls
Resolved conflicts between local model class hierarchy refactor and upstream's
equivalent refactor (mllam#208). Upstream's version removes the config parameter
from model constructors, passing output_clamping_lower/upper explicitly instead,
and adds create_gif/args to ForecasterModule. Also includes upstream's new
load_forecaster_module_from_checkpoint helper and additional test coverage.
Adds neural_lam/models/latent/ with the encoder and decoder submodules
needed by the probabilistic GraphEFM model (issue mllam#62). Ported from the
prob_model_lam branch with adaptations for the current main architecture:

- constants.GRID_STATE_DIM replaced by a num_state_vars constructor arg
- interaction_net imports updated to neural_lam.gnn_layers
- GraphLatentDecoder.processor unified with the other four GNN-seq
  constructions to use utils.make_gnn_seq (handles processor_layers=0)
- HiGraph{Encoder,Decoder} guard against single-level meshes where the
  latent variable would be silently ignored
- ConstantLatentEncoder docstring documents the N(1,1) vs N(0,1)
  discrepancy with the prob_model_lam CLI help (open question upstream)

Also adds to neural_lam/utils.py:
- IdentityModule: pass-through nn.Module for multi-arg sequential GNNs
- make_gnn_seq: builds a pyg.nn.Sequential of InteractionNets, or an
  IdentityModule when num_gnn_layers=0; lazy-imports gnn_layers to
  avoid the existing gnn_layers -> utils circular dependency

17 tests in tests/test_latent_modules.py cover output shapes,
distribution properties, backprop to every parameter, 2- and 3-level
hierarchical graphs, intra_level_layers=0, and the single-level guard.
@Sir-Sloth-The-Lazy

Contributor Author

Pinging @joeloskarsson!

Collaborator

Is this (and also forecaster, ar_forecaster and step_predictor) now added as duplicate files here? A bit confused about these additions? These files are placed differently on main?

Contributor Author

This is branched off the latest main. I don't think these are duplicates, but I will confirm concretely in this week's meeting.

Collaborator

Hmm, if I go to the "files changed" tab then I see these files as added in this PR? This is what confuses me.

Collaborator

I think there is a discussion to be had if we want to use the encoder/decoder terminology here, since this means something different than the encoder/decoder used in the encode-process-decode framework of the current deterministic models. I think this should be fine, since these are explicitly Latent{Encoder, Decoder}, and we don't have any other classes called encoder/decoder right now. But wanted to make note of this, and open to other ideas.

Contributor Author

I have noted this and will bring it up in future discussions if ambiguity comes up!

@joeloskarsson joeloskarsson added this to the v0.7.0 (proposed) milestone Jun 4, 2026

@joeloskarsson joeloskarsson left a comment

Collaborator

Had a look over the encoders now, decoders still TODO :)

    ):
        super().__init__(latent_dim, output_dist)

        self.g2m_gnn = PropagationNet(

Collaborator

I think also here the gnn type should be set with the corresponding argparse flag.

Contributor Author

I've switched it to resolve the GNN via get_gnn_class() and added a g2m_gnn_type parameter wired to the existing --g2m_gnn_type flag, matching the deterministic models. Two things I'd like your opinion on before I finalize:

Default value. The original prob_model_lam encoder used PropagationNet for the g2m step, but --g2m_gnn_type defaults to InteractionNet. I've defaulted the new parameter to InteractionNet for consistency with the flag and the rest of the codebase, which does change the encoder's default behavior vs. the original port. Are you happy with that, or would you prefer keeping PropagationNet as the default here?

Shared vs. separate flag. Wiring this to --g2m_gnn_type means the latent encoder's g2m GNN and the step predictor's g2m GNN share a single flag and can't be configured independently. In the original Graph-EFM these differed (predictor g2m = InteractionNet, encoder g2m = PropagationNet). Do you think sharing the one flag is fine, or should the encoder get its own flag (e.g. --latent_g2m_gnn_type) to preserve that independence?
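
One way the separate-flag option could be wired is sketched below with argparse: the latent flag is optional and falls back to the shared flag when unset. `--latent_g2m_gnn_type` is the hypothetical flag named above, and the `choices` tuple and the helper `resolve_latent_g2m_gnn_type` are assumptions for illustration, not existing code.

```python
import argparse

# Assumed set of selectable GNN types for this sketch.
GNN_CHOICES = ("InteractionNet", "PropagationNet")


def build_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--g2m_gnn_type",
        choices=GNN_CHOICES,
        default="InteractionNet",
        help="GNN type for the step predictor's g2m step",
    )
    parser.add_argument(
        "--latent_g2m_gnn_type",  # hypothetical flag
        choices=GNN_CHOICES,
        default=None,
        help="GNN type for the latent encoder's g2m step "
        "(defaults to --g2m_gnn_type when not given)",
    )
    return parser


def resolve_latent_g2m_gnn_type(args):
    """Use the latent-specific flag if set, else the shared one."""
    if args.latent_g2m_gnn_type is not None:
        return args.latent_g2m_gnn_type
    return args.g2m_gnn_type


shared = build_parser().parse_args([])
split = build_parser().parse_args(
    ["--latent_g2m_gnn_type", "PropagationNet"]
)
```

With this fallback, existing command lines keep a single shared setting, while the original Graph-EFM combination (predictor InteractionNet, encoder PropagationNet) stays expressible.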

        g2m_edge_index,
        m2m_edge_index,
        hidden_dim,
        processor_layers,

Collaborator

This feels like a strange argument name to me. We don't really have a "processor" here, as this is all an encoder.

Comment on lines +12 to +19
Used as a non-learned prior in ``GraphEFM`` when ``learn_prior`` is
disabled. ``compute_dist_params`` returns a tensor of ones, so the
resulting Normal is ``Normal(mean=1, std=1)`` for ``output_dist=
"isotropic"`` and ``Normal(mean=1, std=softplus(1)+eps)`` for
``output_dist="diagonal"``. (Note: the ``train_model.py`` CLI help on
``prob_model_lam`` describes this prior as "mean 0"; the code itself
has always produced mean 1. Preserved as-is during the port for
behavioral parity — open question for upstream.)

Collaborator

Should be mean 0, the mean 1 is probably just a bug. Does not really change anything, just a constant offset, but makes more sense to use 0 mean.
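
The "constant offset" point can be checked numerically with the closed-form KL divergence from a one-dimensional Gaussian posterior to a unit-variance prior, KL = -ln(σ) + (σ² + (μ - m)²)/2 - 1/2, where m is the prior mean. The function name and the sample values below are chosen only for this illustration.

```python
import math


def kl_to_unit_std_prior(q_mean, q_std, prior_mean):
    """KL( N(q_mean, q_std^2) || N(prior_mean, 1) ) for a scalar latent."""
    return (
        -math.log(q_std)
        + 0.5 * (q_std**2 + (q_mean - prior_mean) ** 2)
        - 0.5
    )


# Same posterior spread; means differ by exactly the prior offset.
kl_with_mean1_prior = kl_to_unit_std_prior(1.7, 0.8, prior_mean=1.0)
kl_with_mean0_prior = kl_to_unit_std_prior(0.7, 0.8, prior_mean=0.0)

# A posterior equal to the prior has zero divergence.
kl_matching = kl_to_unit_std_prior(0.0, 1.0, prior_mean=0.0)
```

The divergence depends on the posterior mean only through its distance to the prior mean, so a mean-1 prior with every latent shifted by 1 scores the same as a mean-0 prior.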

Comment thread neural_lam/models/latent/constant_encoder.py Outdated
Comment thread neural_lam/utils.py Outdated
        num_gnn_layers is 0.
    """
    # First-party
    from neural_lam.gnn_layers import InteractionNet

Collaborator

Can we make this general enough to not just support Inets?

Comment thread neural_lam/utils.py Outdated
Comment on lines +409 to +410
    if num_gnn_layers == 0:
        return IdentityModule()

Collaborator

This does not feel like the best place to handle this. If we call this to build a GNN sequence, then we should request at least one layer. Otherwise there should be a check that prevents this from being called at all.

Comment on lines +43 to +60
        self.g2m_gnn = PropagationNet(
            g2m_edge_index,
            hidden_dim,
            hidden_layers=hidden_layers,
            update_edges=False,
        )

        self.mesh_up_gnns = nn.ModuleList(
            [
                PropagationNet(
                    edge_index,
                    hidden_dim,
                    hidden_layers=hidden_layers,
                    update_edges=False,
                )
                for edge_index in mesh_up_edge_index
            ]
        )

Collaborator

Can we also use argparse options here for choosing the GNN types?

@joeloskarsson joeloskarsson self-assigned this Jun 4, 2026
Co-authored-by: Joel Oskarsson <joel.oskarsson@outlook.com>
Sir-Sloth-The-Lazy commented Jun 5, 2026

Contributor Author

On it!

Make GNN types configurable and tidy up the latent modules per PR review:

- make_gnn_seq: accept a gnn_type arg (resolved via get_gnn_class) so it is
  not limited to InteractionNet, and make it strict (raise on
  num_gnn_layers < 1) instead of silently returning an IdentityModule;
  callers now own the no-op (identity) case explicitly.
- graph/hi encoders and decoders: expose g2m/m2g/mesh_up/mesh_down gnn_type
  parameters wired to get_gnn_class, with defaults matching prob_model_lam.
- graph encoder/decoder: rename processor_layers -> m2m_layers (and the
  self.processor attribute -> self.m2m_gnns); "processor" was misleading in
  an encoder/decoder context.
- ConstantLatentEncoder: return zeros instead of ones so the static prior is
  mean 0 (fixes the prob_model_lam mean-1 bug; matches its own CLI help).
- tests: update for the renamed arg and strict make_gnn_seq, add coverage for
  the flat zero-m2m identity path, and assert the constant prior is N(0, 1).
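
The strict `make_gnn_seq` and the caller-owned identity case from this commit might be structured as in the framework-free sketch below. The stub layer classes, the `GNN_CLASSES` registry, the exact `get_gnn_class` lookup, and the `build_m2m` helper are all assumptions; the real functions operate on torch modules and may have different signatures.

```python
class _StubLayer:
    """Toy GNN layer: adds 1 to every node value, leaves edges untouched."""

    def __init__(self, edge_index):
        self.edge_index = edge_index

    def __call__(self, node_rep, edge_rep):
        return [v + 1.0 for v in node_rep], edge_rep


class InteractionNetStub(_StubLayer):
    pass


class PropagationNetStub(_StubLayer):
    pass


# Assumed name-to-class registry backing the lookup.
GNN_CLASSES = {
    "InteractionNet": InteractionNetStub,
    "PropagationNet": PropagationNetStub,
}


def get_gnn_class(gnn_type):
    if gnn_type not in GNN_CLASSES:
        raise ValueError(f"Unknown GNN type: {gnn_type!r}")
    return GNN_CLASSES[gnn_type]


def make_gnn_seq(edge_index, num_gnn_layers, gnn_type="InteractionNet"):
    """Strict builder: refuses to build an empty sequence."""
    if num_gnn_layers < 1:
        raise ValueError("make_gnn_seq needs num_gnn_layers >= 1")
    gnn_class = get_gnn_class(gnn_type)
    layers = [gnn_class(edge_index) for _ in range(num_gnn_layers)]

    def run(node_rep, edge_rep):
        for layer in layers:
            node_rep, edge_rep = layer(node_rep, edge_rep)
        return node_rep, edge_rep

    return run


def identity(node_rep, edge_rep):
    """Multi-argument pass-through used when no layers are requested."""
    return node_rep, edge_rep


def build_m2m(edge_index, m2m_layers, gnn_type):
    """Caller owns the no-op case instead of the builder."""
    if m2m_layers == 0:
        return identity
    return make_gnn_seq(edge_index, m2m_layers, gnn_type)


three_layers = build_m2m(None, 3, "PropagationNet")
no_layers = build_m2m(None, 0, "InteractionNet")
```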
@Sir-Sloth-The-Lazy

Contributor Author

I have pushed the latest changes; please have a look whenever you have time.

Port prob_model_lam's GraphEFM single-step half onto the StepPredictor
interface, reusing the latent encoder/decoder infra. The predictor owns
its conditional prior, variational encoder, and latent decoder, plus the
per-step ELBO pieces (compute_step_loss) and sampling helpers; rollout,
ELBO assembly, ensemble logic, and logging stay outside it.

- forward is source's predict_step (prior rsample -> decode -> sampled
  next state); no rescaling/clamping
- loss_fn and interior_mask are threaded parameters, not predictor state;
  compute_step_loss takes compute_kl (kl_term=None when off)
- per_var_std mirrors ForecasterModule's formula, hence the config arg
- one class for flat + hierarchical meshes, resolved from self.hierarchical
- not registered in MODELS yet (needs config / no mesh_aggr); config-aware
  assembly deferred to the ensemble-forecaster PR

Adds tests/test_graph_efm_predictor.py covering forward shapes, output_std,
compute_step_loss + KL toggle, differentiability, member stochasticity,
sample_obs_noise, and the per_var_std formula (flat + hierarchical).
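
To make the `compute_kl` toggle concrete, here is a plain-Python sketch of a per-step loss that returns a likelihood term and an optional KL term between diagonal Gaussians. The signature, the `(means, stds)` tuple convention, and the `masked_mse` loss are assumptions for illustration; the predictor's real `compute_step_loss` works on tensors and may differ.

```python
import math


def gaussian_kl(q_mean, q_std, p_mean, p_std):
    """KL(q || p) for diagonal Gaussians, summed over latent dimensions."""
    total = 0.0
    for qm, qs, pm, ps in zip(q_mean, q_std, p_mean, p_std):
        total += (
            math.log(ps / qs)
            + (qs**2 + (qm - pm) ** 2) / (2.0 * ps**2)
            - 0.5
        )
    return total


def masked_mse(pred, target, interior_mask):
    """Mean squared error over grid points where the mask is 1."""
    num = sum(m * (p - t) ** 2 for p, t, m in zip(pred, target, interior_mask))
    return num / sum(interior_mask)


def compute_step_loss(
    pred, target, interior_mask, loss_fn, posterior, prior, compute_kl=True
):
    """Return (likelihood_term, kl_term); kl_term is None when disabled.

    loss_fn and interior_mask are passed in rather than stored as state.
    posterior and prior are (means, stds) tuples.
    """
    likelihood_term = loss_fn(pred, target, interior_mask)
    kl_term = gaussian_kl(*posterior, *prior) if compute_kl else None
    return likelihood_term, kl_term


pred, target, mask = [1.0, 2.0, 5.0], [1.0, 0.0, 0.0], [1, 1, 0]
prior = ([0.0], [1.0])
with_kl = compute_step_loss(pred, target, mask, masked_mse, prior, prior)
without_kl = compute_step_loss(
    pred, target, mask, masked_mse, prior, prior, compute_kl=False
)
```

Returning `None` rather than zero lets the code that assembles the ELBO tell "KL disabled" apart from "KL happens to be zero".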
Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Merge Graph-EFM model from prob_model_lam branch

2 participants