U/alxmrs/experiments/kr2 #681

Draft
alxmrs wants to merge 28 commits into main from u/alxmrs/experiments/kr2

@alxmrs alxmrs commented Apr 13, 2026

Member

This PR makes four contributions towards implementing O1KR2 (#615):

  1. Updates the encoder to use dual perceivers that represent prognostic variables and boundary forcings separately. This is needed because, at inference, we plan to use coarse forcings for high-resolution prognostic prediction. Latents from the boundary and prognostic perceivers are linearly mixed, then positional encodings are added before the output is passed on to the processor.
  2. To enable (1) in the model, this PR also updates the data plumbing across the whole training and inference process to keep the prognostic and boundary tensors separate. They are no longer concatenated into a single "input" tensor.
  3. Also to enable (1), the training and inference autoregressive (AR) loops now update the GridContext input resolution after the first step. In this new scheme, step 0 does either downscaling or upscaling, and steps 1+ update the grid context so that the input resolution equals the output resolution. This is needed during "mix" schedule multiscale training because, from step 1 onward, the prognostic output of the previous decoder step is fed back at the output resolution. It does not affect "match" schedule multiscale training, where both resolutions are already equal.
  4. In addition to updating the unit tests, we add a new test_fomo_cross_resolution set of integration tests. These check that single-step cross-resolution (prognostic/boundary), single-step mix-schedule, and two-step AR mix-schedule training of the FOMO model all work as expected.
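
The resolution handoff described in (3) can be sketched roughly as follows. This is only an illustration of the schedule, not code from this branch: `resolutions_for_step` is a hypothetical helper, and the resolution labels ("1deg", "0.25deg") are made-up placeholders.

```python
def resolutions_for_step(step, src, dst):
    """Return the (input, output) resolution pair used at an AR step.

    Step 0 maps src -> dst (downscaling or upscaling). From step 1 onward
    the previous prediction is fed back, so the input is already at dst.
    """
    if step < 0:
        raise ValueError("step must be non-negative")
    return (src, dst) if step == 0 else (dst, dst)


# "mix" schedule: resolutions differ, so the input changes after step 0.
mix = [resolutions_for_step(s, "1deg", "0.25deg") for s in range(3)]

# "match" schedule: resolutions are equal, so every step looks the same.
match = [resolutions_for_step(s, "1deg", "1deg") for s in range(3)]
```

Under these assumptions, only the mix schedule sees a different input resolution at step 1, which is why the match schedule is unaffected.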

@alxmrs alxmrs force-pushed the u/alxmrs/experiments/kr2 branch from a736b1a to c159eb8 Compare April 13, 2026 19:34
Comment thread src/ocean_emulators/models/modules/encoder.py Outdated
@@ -28,26 +29,34 @@ def patch_from(


class PerceiverEncoder(nn.Module):

Member Author

The changes in this branch are in service of this encoder update.

alxmrs commented Apr 14, 2026

Member Author

@codex may I have your review?

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 2a89013570

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment thread src/ocean_emulators/models/fomo.py Outdated
Comment on lines +98 to +102
prognostic = self.maybe_add_3d_coordinates(
    prognostic, ctx.output_resolution_cpu
)

fts = self.encoder(prognostic, boundary, ctx.output_resolution_cpu)

[P1] Use input-resolution coordinates for FOMO prognostic path

When src and dst resolutions differ in mix schedule training, TrainData carries prognostic tensors on ctx.input_resolution_cpu (see TorchTrainDataset.ctx), but this code feeds ctx.output_resolution_cpu into the prognostic path. That makes spatial metadata inconsistent with the tensor shape: Concat3dCoordinates can fail with a size mismatch if 3D coordinates are enabled, and even without 3D coordinates the encoder’s positional encoding (pos_scale_enc) is computed from output-grid lat/lon with input-grid patch sizes, producing a token count mismatch at the x + pos_encoding add. As written, cross-resolution mix runs can crash on the first forward pass.
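
The token count mismatch can be illustrated with a small sketch. It assumes a simple non-overlapping patching scheme; `num_patches`, the grid shapes, and the patch size are all hypothetical, and the real encoder's patching may differ.

```python
def num_patches(grid_shape, patch_size):
    """Number of non-overlapping patches tiling a (lat, lon) grid."""
    (n_lat, n_lon), (p_lat, p_lon) = grid_shape, patch_size
    if n_lat % p_lat or n_lon % p_lon:
        raise ValueError("grid must be divisible by the patch size")
    return (n_lat // p_lat) * (n_lon // p_lon)


input_grid = (90, 180)    # resolution the prognostic tensor actually has
output_grid = (180, 360)  # resolution passed into the prognostic path
patch = (10, 10)          # illustrative patch size

tokens = num_patches(input_grid, patch)       # tokens from the tensor
pos_tokens = num_patches(output_grid, patch)  # positions from output grid

# Adding a positional encoding to the tokens requires equal counts.
shapes_agree = tokens == pos_tokens
```

With these made-up numbers the two counts differ, so an elementwise `x + pos_encoding` would fail, which is the failure mode the review describes.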

alxmrs commented Apr 14, 2026

Member Author

@codex may I have your review?

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: d1c7d2e68d

Comment thread src/ocean_emulators/models/fomo.py
alxmrs commented Apr 14, 2026

Member Author

@codex may I have another review?

@chatgpt-codex-connector

Codex Review: Didn't find any major issues. 🚀

@alxmrs alxmrs force-pushed the u/alxmrs/experiments/kr2 branch from 944c435 to 022914a Compare April 14, 2026 20:18
# so the encoder uses the correct resolution.
ctx = dataclasses.replace(
    ctx, input_resolution_cpu=ctx.output_resolution_cpu
)

Member Author

This addresses the TODO on the left.
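
For readers less familiar with `dataclasses.replace`, the call in the snippet above returns a new object and leaves the original untouched. A minimal sketch, assuming the context is a dataclass (frozen here for illustration); `Ctx` and its string field values are stand-ins, not the real GridContext:

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Ctx:
    # Hypothetical stand-in; the real context likely has more fields.
    input_resolution_cpu: str
    output_resolution_cpu: str


ctx = Ctx(input_resolution_cpu="coarse", output_resolution_cpu="fine")

# Build a copy whose input resolution matches the output resolution.
updated = dataclasses.replace(
    ctx, input_resolution_cpu=ctx.output_resolution_cpu
)

# The original is unchanged, so anything still holding `ctx` sees the
# step-0 resolutions.
print(ctx.input_resolution_cpu)      # coarse
print(updated.input_resolution_cpu)  # fine
```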

class Stepper:
    def __init__(self):
        pass
    def train_batch(

Member Author

The Stepper class was really three functions in a trenchcoat. So, I extracted them.

@alxmrs alxmrs force-pushed the u/alxmrs/experiments/kr2 branch from 58f278e to 8f7ffa2 Compare April 22, 2026 22:31