
Conversation

@Jubeku Jubeku commented Nov 21, 2025

Description

Implement an MAE loss function (with weights for the physical loss, as in mse_channel_location_weighted).

UPDATE:
A generalized Lp-norm function is implemented, on top of which MAE, MSE, RMSE, and SSE can be expressed.
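For reference, a sketch of the weighted quantity the generalized function computes (default with_mean=True; the optional p-th root is applied per channel, before the channel mean), written in the same notation as the docstring discussed below:

loss = Mean_{channels}( weights_channels * ( Mean_{data_pts}( weights_points * |target - pred|**p ) )**(1/p) )

With p=1 this gives MAE, with p=2 (and no root) MSE, with p=2 and the root a per-channel RMSE, and summing instead of averaging over data points gives SSE.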

Issue Number

Closes #1333
Closes #1536

Is this PR a draft? Mark it as draft.

Checklist before asking for review

  • I have performed a self-review of my code
  • My changes comply with basic sanity checks:
    • I have fixed formatting issues with ./scripts/actions.sh lint
    • I have run unit tests with ./scripts/actions.sh unit-test
    • I have documented my code and I have updated the docstrings.
    • I have added unit tests, if relevant
  • I have tried my changes with data and code:
    • I have run the integration tests with ./scripts/actions.sh integration-test
    • (bigger changes) I have run a full training and I have written in the comment the run_id(s): launch-slurm.py --time 60
    • (bigger changes and experiments) I have shared a HedgeDoc in the GitHub issue with all the configurations and runs for these experiments
  • I have informed and aligned with people impacted by my change:
    • for config changes: the MatterMost channels and/or a design doc
    • for changes of dependencies: the MatterMost software development channel

@Jubeku Jubeku self-assigned this Nov 21, 2025
@Jubeku Jubeku added the model Related to model training or definition (not generic infra) label Nov 21, 2025
# apply per-location weights to the absolute differences
diff_abs = (diff_abs.transpose(1, 0) * weights_points).transpose(1, 0)
# mean over data points gives one loss value per channel
loss_chs = diff_abs.mean(0)
# optional channel weighting, then mean over channels
loss = torch.mean(loss_chs * weights_channels if weights_channels is not None else loss_chs)

Collaborator:

If we use the implementation below (didn't run it, so it needs testing) then we could support MAE, MSE, RMSE and any other l_p norm (for p < inf) in one function. Maybe it should be called generalized_lp_norm but I could live with the current naming. MAE and MSE would still exist as functions, but they would all use lp_norm in the implementation.

import torch


def lp_loss(
    target: torch.Tensor,
    pred: torch.Tensor,
    p_norm: int,
    with_p_root: bool = True,
    with_mean: bool = True,
    weights_channels: torch.Tensor | None = None,
    weights_points: torch.Tensor | None = None,
):
    assert type(p_norm) is int, "Only integer p supported for p-norm loss"

    # mask out NaNs in the target so they do not contribute to the loss
    mask_nan = ~torch.isnan(target)
    # collapse the ensemble dimension: keep the single member or take the ensemble mean
    pred = pred[0] if pred.shape[0] == 1 else pred.mean(0)

    # |target - pred|**p with NaN positions zeroed out
    diff_p = torch.pow(torch.abs(torch.where(mask_nan, target, 0.0) - torch.where(mask_nan, pred, 0.0)), p_norm)
    # optional per-location weighting
    if weights_points is not None:
        diff_p = (diff_p.transpose(1, 0) * weights_points).transpose(1, 0)
    # reduce over data points: mean (MAE/MSE-style) or sum (SSE-style)
    loss_chs = diff_p.mean(0) if with_mean else diff_p.sum(0)
    # optional p-th root per channel (e.g. RMSE for p=2)
    loss_chs = torch.pow(loss_chs, 1.0 / p_norm) if with_p_root else loss_chs
    # optional channel weighting, then mean over channels
    loss = torch.mean(loss_chs * weights_channels if weights_channels is not None else loss_chs)
    return loss, loss_chs
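A sketch of how the named losses could wrap this helper, as suggested above (wrapper names and signatures here are illustrative assumptions, not the PR's actual API):

def mae(target, pred, weights_channels=None, weights_points=None):
    # p=1; the 1/p root would be a no-op for p=1
    return lp_loss(target, pred, p_norm=1, with_p_root=False,
                   weights_channels=weights_channels, weights_points=weights_points)

def mse(target, pred, weights_channels=None, weights_points=None):
    # p=2, mean over data points, no root
    return lp_loss(target, pred, p_norm=2, with_p_root=False,
                   weights_channels=weights_channels, weights_points=weights_points)

def rmse(target, pred, weights_channels=None, weights_points=None):
    # p=2 with the p-th root applied per channel
    return lp_loss(target, pred, p_norm=2, with_p_root=True,
                   weights_channels=weights_channels, weights_points=weights_points)

def sse(target, pred, weights_channels=None, weights_points=None):
    # p=2, sum instead of mean over data points, no root
    return lp_loss(target, pred, p_norm=2, with_p_root=False, with_mean=False,
                   weights_channels=weights_channels, weights_points=weights_points)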

Collaborator:

Are you sure you want to recommend that? Most of these functions have accelerated code/intrinsics where the performance difference can be shocking, especially for flowing gradient computations. I would not trust the torch compiler to detect and optimize these patterns.
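For context, the fused built-ins alluded to here look like this (a sketch of the unweighted cases only; the weighted, NaN-masked variants in this PR cannot call them directly):

import torch
import torch.nn.functional as F

target = torch.randn(1024, 8)
pred = torch.randn(1024, 8)

# dedicated, kernel-backed losses with optimized backward passes
mae = F.l1_loss(pred, target, reduction="mean")
mse = F.mse_loss(pred, target, reduction="mean")
rmse = torch.sqrt(F.mse_loss(pred, target, reduction="mean"))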

@tjhunter

@Jubeku noting there is a conflict

@Jubeku Jubeku marked this pull request as ready for review January 6, 2026 13:53
Jubeku commented Jan 6, 2026

This PR makes PR #1551, which removes the implicit loss function conversion in LossPhysical, obsolete.

@Jubeku Jubeku requested a review from clessig January 6, 2026 14:14
@clessig clessig left a comment


Just some minor points on the documentation/comments. The code has been tested?

By default, the Lp-norm is normalized by the number of samples (i.e. with_mean=True).
* For example: p=1 corresponds to MAE; p=2 corresponds to MSE.
The samples are weighted by location if weights_points is not None.
The norm can optionally normalised by the pth root.
Collaborator:

optionally normalized -> optionally be normalized

The function implements:
loss = Mean_{channels}( weight_channels * Mean_{data_pts}( (target - pred) * weights_points ))
loss = Mean_{channels}(weight_channels * Mean_{data_pts}(|(target - pred)|**p * weights_points))
Collaborator:

The p-norm usually involves the p-th root "on the outside". This should also be an optional argument. Please make sure it's at the right position with respect to the mean operations.
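To spell out the two placements in question, in the docstring's own notation (a sketch; the quoted implementation applies the root per channel, before channel weighting and the channel mean):

root per channel: loss = Mean_{channels}( weights_channels * ( Mean_{data_pts}( weights_points * |target - pred|**p ) )**(1/p) )
root on the outside: loss = ( Mean_{channels}( weights_channels * Mean_{data_pts}( weights_points * |target - pred|**p ) ) )**(1/p)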

Collaborator:

It's actually present in the implementation but should be mentioned here then as well.

weights_channels : shape = (num_channels,)
weights_points : shape = (num_data_points)
target : tensor of shape ( num_data_points , num_channels )
target : tensor of shape ( ens_dim , num_data_points , num_channels)
Collaborator:

pred : ...

Return:
loss : weight loss for gradient computation
loss_chs : losses per channel with location weighting but no channel weighting
loss : (weighted) loss for gradient computation
Collaborator:

loss: (weighted) scalar loss (e.g. for gradient computation)

------------------------
where wp = weights_points and wc = weights_channels and "x" denotes row/col-wise multiplication.
return lp_loss(
Collaborator:

We should refer to lp_loss for a detailed explanation of the arguments.

target : shape ( ens_dim , num_data_points , num_channels)
weights_channels : shape = (num_channels,)
weights_points : shape = (num_data_points)
def sse(
Collaborator:

Is "sse" a standard name? I am not familar with it.

Contributor (Author):

Maybe Residual Sum of Squares (RSS) would be better?
https://en.wikipedia.org/wiki/Residual_sum_of_squares
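For reference, in the docstring's notation this is the unnormalized p=2 case (with_mean=False, with_p_root=False, before weighting):

RSS = Sum_{data_pts}( (target - pred)**2 )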

Jubeku commented Jan 7, 2026

Just some minor points on the documentation/comments. The code has been tested?

I ran the integration test and manually verified that the value of the MSE loss matches the value from the old implementation (mse_channel_location_weighted):
[screenshot: comparison of the new and old loss values]
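A consistency check along these lines could look like the sketch below (assumptions: the call signature of mse_channel_location_weighted and the return values of lp_loss are hypothetical, and the data here is random):

import torch

target = torch.randn(1024, 8)     # (num_data_points, num_channels)
pred = torch.randn(1, 1024, 8)    # (ens_dim, num_data_points, num_channels)
w_ch = torch.rand(8)
w_pt = torch.rand(1024)

new_loss, _ = lp_loss(target, pred, p_norm=2, with_p_root=False,
                      weights_channels=w_ch, weights_points=w_pt)
old_loss = mse_channel_location_weighted(target, pred, w_ch, w_pt)  # hypothetical signature
torch.testing.assert_close(new_loss, old_loss)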

@Jubeku Jubeku merged commit e2032c3 into develop Jan 7, 2026
5 checks passed