Perceiver encoder updated for cross-resolution fusion#702
Conversation
3392629 to
93bca31
Compare
e0605e6 to
166410d
Compare
166410d to
16f4e08
Compare
|
@codex may I have your review? |
|
Codex Review: Didn't find any major issues. Another round soon, please! ℹ️ About Codex in GitHubCodex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
If Codex has suggestions, it will comment; otherwise it will react with 👍. When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback". |
16f4e08 to
bb37bc4
Compare
|
Discussion with @jder - this is not quite the direction that we want. There is a way to accomplish this with a single perceiver: if we treat the prognostic and boundary as 1d vectors and contact then, we could use the perceiver on the 1d sequence. To preserve the patch encoding, we'd need to use our own Fourier positional encoding scheme, which should be easy to implement or reuse. We would also need to encode which part of the sequence was the prognostic and boundary in an encoding scheme. There are many modeling reasons (large context / latent mixing) why this is preferred over using two perceivers. |
|
Sounds great to me! |
bb37bc4 to
56115ed
Compare
Replace the single-stream PerceiverEncoder with a dual-perceiver architecture: separate Perceivers for prognostic and boundary streams, fused via concat→Linear projection. Because patch_extent is in degrees, both streams produce the same latent grid regardless of their spatial resolution, enabling cross-resolution forward passes (e.g. ¼° prog with 1° boundary). FOMO.forward_once now passes prog and boundary separately to the encoder instead of concatenating. 3D coordinates are appended to the prognostic stream only. Config: EncoderConfig gains a boundary_perceiver field (default depth=2, num_latents=64). PerceiverConfig.build outputs latent_dim (pooled) instead of out_channels. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
56115ed to
429d0a4
Compare
Switches the test helpers and call sites from the old dual-perceiver encoder kwargs (`prog_latent_dim`, `boundary_latent_dim`, `boundary_perceiver`) to the new ones (`token_dim`, `latent_dim`, `max_patch_size`, single `perceiver`). The test Perceiver helper now builds a 1-D Perceiver (`input_axis=1, fourier_encode_data=False`) to match the encoder's concatenated prog+boundary token sequence, and encoder calls now take a `GridContext` instead of a bare resolution tuple. Adds a small cross-resolution test for completeness. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
| resolution. Within each patch: | ||
|
|
||
| * Prog and boundary pixels are flattened to 1-D token sequences and | ||
| linearly projected to a common ``token_dim``. |
There was a problem hiding this comment.
This is equivalent to a 1x1 Conv, see comment in code.
| pooled = self.perceiver(seq) # (B*lat_h*lat_w, latent_dim) | ||
| x = self.fusion_proj(pooled) # (B*lat_h*lat_w, out_channels) |
There was a problem hiding this comment.
Why include the linear layer here? Why not just set the perceiver to use the output dim? I believe the last layer of the perceiver is a proj.
| # Concat prog + boundary tokens along the sequence dim. The Perceiver | ||
| # cross-attends over the unified sequence, so prog and boundary mix | ||
| # inside the latent set rather than being fused after the fact. | ||
| seq = torch.cat([prog_tokens, boundary_tokens], dim=1) |
There was a problem hiding this comment.
I ran out of time to finish this PR today, but @jder, was this more of what you were imagining?
There was a problem hiding this comment.
Yes, this looks like the right direction to me!
| tokens = tokens + self.pos_proj(pos_enc).unsqueeze(0) | ||
|
|
||
| stream_idx = torch.tensor(stream_id, device=tokens.device) | ||
| tokens = tokens + self.stream_embed(stream_idx) |
There was a problem hiding this comment.
FWIW given the per-stream projection above in tokens = proj(tokens) I don't think we also need this since there's effectively a learned embedding inside the linear projection.
|
Jesse to review after getting KR1 branch merged + running. |
Replace the single-stream PerceiverEncoder with a dual-perceiver architecture: separate Perceivers for prognostic and boundary streams, fused via concat --> Linear projection. Because patch_extent is in degrees, both streams produce the same latent grid regardless of their spatial resolution, enabling cross-resolution forward passes (e.g. ¼° prog with 1° boundary).
This is https://github.com/Open-Athena/Ocean_Emulator/pull/681 broken up into a stack of three small PRs. 2/3