Add changes for pushforward trick #1997

Open
SavvasMel wants to merge 6 commits into ecmwf:develop from SavvasMel:SavvasMel/develop/pushf_trick

Conversation

@SavvasMel
Contributor

Description

This PR adds the necessary changes for the pushforward trick.

Issue Number

Closes #1740

Is this PR a draft? Mark it as draft.

Checklist before asking for review

  • I have performed a self-review of my code
  • My changes comply with basic sanity checks:
    • I have fixed formatting issues with ./scripts/actions.sh lint
    • I have run unit tests with ./scripts/actions.sh unit-test
    • I have documented my code and I have updated the docstrings.
    • I have added unit tests, if relevant
  • I have tried my changes with data and code:
    • I have run the integration tests with ./scripts/actions.sh integration-test
    • (bigger changes) I have run a full training and I have written in the comment the run_id(s): launch-slurm.py --time 60
    • (bigger changes and experiments) I have shared a hedgedoc in the GitHub issue with all the configurations and runs for these experiments
  • I have informed and aligned with people impacted by my change:
    • for config changes: the MatterMost channels and/or a design doc
    • for changes of dependencies: the MatterMost software development channel

@github-actions github-actions bot added the labels model, science Mar 9, 2026
@SavvasMel SavvasMel requested a review from clessig March 12, 2026 11:57
# reshard_after_forward=False keeps FE parameters unsharded
# during the multi-step rollout loop.
# Needed for pushforward trick.
fully_shard(module, reshard_after_forward=False, **fsdp_kwargs)
Collaborator

@sophie-xhonneux: could this be related to the problem we are seeing with the EMATeacher, where we need to reshard?

@SavvasMel SavvasMel force-pushed the SavvasMel/develop/pushf_trick branch from 056d4be to 8ad0cff March 28, 2026 10:58
@github-actions github-actions bot added the labels data, eval, infra Mar 28, 2026
@SavvasMel SavvasMel force-pushed the SavvasMel/develop/pushf_trick branch from 2f89054 to 02a0d46 March 28, 2026 11:09
@SavvasMel SavvasMel requested a review from clessig March 28, 2026 11:11
Collaborator

@clessig clessig left a comment

Essentially ready to merge, just some minor details.


def __init__(self):
    super().__init__()
    self.fe_blocks = torch.nn.ModuleList()
Collaborator

self.blocks = ... Also, do we need this at all? If anything, I would expect self.blocks = torch.nn.Identity(), but since we implement forward it might not be needed.

Contributor Author

I actually added that for this:

num_params_fe = get_num_parameters(self.forecast_engine.fe_blocks)

See below.

tokens = tokens.reshape(shape).sum(axis=1)

# Allow for pushforward trick
p_fwd = self.cf.get("pushforward_trick", False)
Collaborator

I would expect that the push-forward trick config is part of forecast, i.e.

training_config:
  ....
  forecast:
    num_steps: 3
    push_forward: True
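
A minimal sketch of how the lookup could read the nested key suggested here, assuming the config behaves like a nested dict (the diff only shows self.cf.get). The helper name use_push_forward and the fallback to the flat pushforward_trick key are hypothetical and not part of this PR:

```python
from typing import Any, Mapping


def use_push_forward(cf: Mapping[str, Any]) -> bool:
    """Return whether the pushforward trick is enabled (hypothetical helper).

    Prefers the nested key training_config.forecast.push_forward and falls
    back to the flat pushforward_trick key read in the current diff.
    """
    # Assumption: missing sections are treated as empty mappings.
    forecast = cf.get("training_config", {}).get("forecast", {})
    if "push_forward" in forecast:
        return bool(forecast["push_forward"])
    # Fallback to the flat key; default is disabled, as in the diff.
    return bool(cf.get("pushforward_trick", False))


nested = {"training_config": {"forecast": {"num_steps": 3, "push_forward": True}}}
flat = {"pushforward_trick": True}
print(use_push_forward(nested), use_push_forward(flat), use_push_forward({}))
```

If the nested key is adopted, the fallback could be dropped once existing configs are migrated; it is only there so that both spellings can be compared side by side.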

Labels

data: Anything related to the datasets used in the project
eval: anything related to the model evaluation pipeline
infra: Issues related to infrastructure
model: Related to model training or definition (not generic infra)
science: Scientific questions

Projects

Status: No status

Development

Successfully merging this pull request may close these issues.

Run pushforward trick experiments

3 participants