Perceiver encoder updated for cross-resolution fusion #702

Open
alxmrs wants to merge 4 commits into main from u/alxmrs/kr2/2-dual-perceiver

alxmrs commented Apr 15, 2026

Member

Replace the single-stream PerceiverEncoder with a dual-perceiver architecture: separate Perceivers for prognostic and boundary streams, fused via concat --> Linear projection. Because patch_extent is in degrees, both streams produce the same latent grid regardless of their spatial resolution, enabling cross-resolution forward passes (e.g. ¼° prog with 1° boundary).
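
To make the degrees-based argument concrete, here is a minimal sketch of the grid arithmetic. Only the idea that `patch_extent` is measured in degrees comes from the description above; the function name `latent_grid`, the 4° patch extent, and the 180° × 360° domain are illustrative assumptions.

```python
from fractions import Fraction


def latent_grid(resolution_deg, patch_extent_deg, domain_deg=(180, 360)):
    """Return (pixels per patch side, (lat_h, lat_w)) for one stream.

    Hypothetical helper: names and defaults are not taken from the PR.
    """
    res = Fraction(resolution_deg)
    extent = Fraction(patch_extent_deg)

    # Pixels per patch side depends on the stream's resolution...
    px = extent / res
    if px.denominator != 1:
        raise ValueError("patch_extent must cover a whole number of pixels")

    # ...but the latent grid depends only on the extent in degrees.
    lat_h = Fraction(domain_deg[0]) / extent
    lat_w = Fraction(domain_deg[1]) / extent
    if lat_h.denominator != 1 or lat_w.denominator != 1:
        raise ValueError("domain must cover a whole number of patches")
    return int(px), (int(lat_h), int(lat_w))


prog_px, prog_grid = latent_grid(Fraction(1, 4), 4)  # 1/4 degree prognostic stream
bnd_px, bnd_grid = latent_grid(1, 4)                 # 1 degree boundary stream
```

Under these assumptions the two streams differ in pixels per patch but share one latent grid, which is what allows them to be fused position by position.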

This is https://github.com/Open-Athena/Ocean_Emulator/pull/681 broken up into a stack of three small PRs. 2/3

@alxmrs alxmrs force-pushed the u/alxmrs/kr2/1-split-tensors branch from 3392629 to 93bca31 Compare April 15, 2026 02:40
@alxmrs alxmrs force-pushed the u/alxmrs/kr2/2-dual-perceiver branch from e0605e6 to 166410d Compare April 15, 2026 02:41
@alxmrs alxmrs requested a review from jder April 15, 2026 02:47
@alxmrs alxmrs marked this pull request as ready for review April 15, 2026 02:47
Base automatically changed from u/alxmrs/kr2/1-split-tensors to main April 18, 2026 00:07
@alxmrs alxmrs force-pushed the u/alxmrs/kr2/2-dual-perceiver branch from 166410d to 16f4e08 Compare April 22, 2026 19:03
alxmrs commented Apr 22, 2026

Member Author

@codex may I have your review?

@chatgpt-codex-connector


Codex Review: Didn't find any major issues. Another round soon, please!

@alxmrs alxmrs force-pushed the u/alxmrs/kr2/2-dual-perceiver branch from 16f4e08 to bb37bc4 Compare April 22, 2026 22:11
alxmrs commented Apr 24, 2026

Member Author

After discussion with @jder: this is not quite the direction that we want. There is a way to accomplish this with a single Perceiver: if we treat the prognostic and boundary inputs as 1-D vectors and concatenate them, we can run the Perceiver on the resulting 1-D sequence. To preserve the patch encoding, we'd need to use our own Fourier positional encoding scheme, which should be easy to implement or reuse. We would also need to encode which part of the sequence is prognostic and which is boundary.

There are many modeling reasons (large context / latent mixing) why this is preferred over using two perceivers.
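
A rough sketch of the positional-encoding idea, in plain Python for readability. It assumes each token's position is its pixel-centre offset within the patch, normalised to [-1, 1] by the patch extent in degrees, and it uses sin/cos features at linearly spaced frequencies in the style commonly used with Perceiver models. `fourier_features`, `patch_positions`, `encode_stream`, `num_bands`, and `max_freq` are hypothetical names, not the project's API, and the stream identity is shown as a plain integer tag rather than a learned embedding.

```python
import math


def fourier_features(pos, num_bands=4, max_freq=8.0):
    """Sin/cos features for a scalar position in [-1, 1] (needs num_bands >= 2)."""
    # Frequencies spaced linearly from 1 to max_freq / 2.
    step = (max_freq / 2 - 1.0) / (num_bands - 1)
    freqs = [1.0 + i * step for i in range(num_bands)]
    feats = [pos]
    for f in freqs:
        feats.append(math.sin(math.pi * f * pos))
        feats.append(math.cos(math.pi * f * pos))
    return feats


def patch_positions(resolution_deg, patch_extent_deg):
    """Normalised pixel-centre offsets along one axis of a patch."""
    n = round(patch_extent_deg / resolution_deg)
    return [((i + 0.5) * resolution_deg / patch_extent_deg) * 2.0 - 1.0
            for i in range(n)]


def encode_stream(resolution_deg, patch_extent_deg, stream_id):
    """One (stream_id, features) entry per token of a flattened 2-D patch."""
    axis = patch_positions(resolution_deg, patch_extent_deg)
    return [(stream_id, fourier_features(y) + fourier_features(x))
            for y in axis for x in axis]


PROG, BOUNDARY = 0, 1
# Concatenate both streams of one 4-degree patch into a single 1-D sequence.
seq = encode_stream(0.25, 4.0, PROG) + encode_stream(1.0, 4.0, BOUNDARY)
```

Because positions are expressed as a fraction of the patch extent rather than as pixel indices, tokens from the ¼° and 1° streams land in the same coordinate system even though the two streams contribute different numbers of tokens.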

jder commented Apr 24, 2026

Member

Sounds great to me!

@alxmrs alxmrs force-pushed the u/alxmrs/kr2/2-dual-perceiver branch from bb37bc4 to 56115ed Compare April 24, 2026 21:12
Replace the single-stream PerceiverEncoder with a dual-perceiver
architecture: separate Perceivers for prognostic and boundary streams,
fused via concat→Linear projection. Because patch_extent is in degrees,
both streams produce the same latent grid regardless of their spatial
resolution, enabling cross-resolution forward passes (e.g. ¼° prog
with 1° boundary).

FOMO.forward_once now passes prog and boundary separately to the
encoder instead of concatenating. 3D coordinates are appended to the
prognostic stream only.

Config: EncoderConfig gains a boundary_perceiver field (default depth=2,
num_latents=64). PerceiverConfig.build outputs latent_dim (pooled)
instead of out_channels.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@alxmrs alxmrs force-pushed the u/alxmrs/kr2/2-dual-perceiver branch from 56115ed to 429d0a4 Compare April 29, 2026 22:20
alxmrs and others added 3 commits April 29, 2026 16:12
Switches the test helpers and call sites from the old dual-perceiver
encoder kwargs (`prog_latent_dim`, `boundary_latent_dim`,
`boundary_perceiver`) to the new ones (`token_dim`, `latent_dim`,
`max_patch_size`, single `perceiver`).  The test Perceiver helper now
builds a 1-D Perceiver (`input_axis=1, fourier_encode_data=False`) to
match the encoder's concatenated prog+boundary token sequence, and
encoder calls now take a `GridContext` instead of a bare resolution
tuple.  Adds a small cross-resolution test for completeness.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
resolution. Within each patch:

* Prog and boundary pixels are flattened to 1-D token sequences and
linearly projected to a common ``token_dim``.

Member Author

This is equivalent to a 1x1 Conv, see comment in code.
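
The equivalence can be checked by hand on a tiny example. The sketch below is plain Python with made-up weights and a made-up 2-channel, 2×2 patch; `linear_per_pixel` and `conv1x1` are illustrative helpers, not code from this PR.

```python
def linear_per_pixel(pixels, weight, bias):
    """Apply y = W x + b to every token of a flattened (N, C_in) list."""
    return [
        [sum(w * x for w, x in zip(row, px)) + b for row, b in zip(weight, bias)]
        for px in pixels
    ]


def conv1x1(image, weight, bias):
    """1x1 convolution over a channels-first image[c][i][j]."""
    c_in, h, w = len(image), len(image[0]), len(image[0][0])
    return [
        [[sum(weight[o][c] * image[c][i][j] for c in range(c_in)) + bias[o]
          for j in range(w)] for i in range(h)]
        for o in range(len(weight))
    ]


# Hypothetical inputs: 2 input channels, 3 output channels.
image = [[[1.0, 2.0], [3.0, 4.0]],
         [[0.5, -1.0], [2.0, 0.0]]]
weight = [[1.0, 0.0], [0.5, 2.0], [-1.0, 1.0]]
bias = [0.0, 1.0, -0.5]

# Flatten to one token per pixel (row-major, channels last), then project.
tokens = [[image[c][i][j] for c in range(2)] for i in range(2) for j in range(2)]
projected = linear_per_pixel(tokens, weight, bias)

# Run the 1x1 conv and flatten its output the same way for comparison.
convolved = conv1x1(image, weight, bias)
conv_tokens = [[convolved[o][i][j] for o in range(3)]
               for i in range(2) for j in range(2)]
```

Both paths compute the same weighted sum over channels at each pixel, so `projected` and `conv_tokens` match; the only difference is whether the data is laid out as a token sequence or as an image.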

Comment on lines +227 to +228
pooled = self.perceiver(seq) # (B*lat_h*lat_w, latent_dim)
x = self.fusion_proj(pooled) # (B*lat_h*lat_w, out_channels)

Member Author

Why include the linear layer here? Why not just set the Perceiver's output dim directly? I believe the last layer of the Perceiver is already a projection.

# Concat prog + boundary tokens along the sequence dim. The Perceiver
# cross-attends over the unified sequence, so prog and boundary mix
# inside the latent set rather than being fused after the fact.
seq = torch.cat([prog_tokens, boundary_tokens], dim=1)

Member Author

I ran out of time to finish this PR today, but @jder, was this more of what you were imagining?

Member

Yes, this looks like the right direction to me!

tokens = tokens + self.pos_proj(pos_enc).unsqueeze(0)

stream_idx = torch.tensor(stream_id, device=tokens.device)
tokens = tokens + self.stream_embed(stream_idx)

Member

FWIW, given the per-stream projection above in `tokens = proj(tokens)`, I don't think we also need this, since there's effectively a learned embedding inside the linear projection.
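
One way to read this point: when each stream has its own linear projection, adding a constant per-stream vector afterwards is the same as shifting that projection's bias. A small sketch with made-up numbers (the `linear` helper and all values are illustrative, not from the PR):

```python
def linear(x, weight, bias):
    """Plain y = W x + b for a single token."""
    return [sum(w * xi for w, xi in zip(row, x)) + b
            for row, b in zip(weight, bias)]


weight = [[1.0, 2.0], [0.0, -1.0]]
bias = [0.5, 0.25]
stream_embed = [2.0, -4.0]  # one constant vector for this stream
x = [3.0, 1.0]

# Projection followed by an additive stream embedding.
with_embed = [y + e for y, e in zip(linear(x, weight, bias), stream_embed)]

# The same result from a projection whose bias has absorbed the embedding.
folded_bias = [b + e for b, e in zip(bias, stream_embed)]
folded = linear(x, weight, folded_bias)
```

This holds only because the projection is per-stream; with a single projection shared across both streams, the bias could not distinguish them and a separate stream embedding would carry real information.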

@alxmrs alxmrs changed the title Dual-perceiver encoder for cross-resolution fusion Perceiver encoder updated for cross-resolution fusion May 4, 2026
jder commented Jun 15, 2026

Member

Jesse to review after getting KR1 branch merged + running.

@jder jder self-assigned this Jun 15, 2026