Implement MAE loss function #1334

Conversation
    diff_abs = (diff_abs.transpose(1, 0) * weights_points).transpose(1, 0)
    loss_chs = diff_abs.mean(0)
    loss = torch.mean(loss_chs * weights_channels if weights_channels is not None else loss_chs)
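As an aside on the transpose pattern in the first line above: it scales each data point (row) by its weight and is equivalent to broadcasting over an inserted axis. A quick self-contained check (the shapes here are made up for illustration):

    import torch

    diff_abs = torch.rand(100, 8)     # (num_data_points, num_channels)
    weights_points = torch.rand(100)  # (num_data_points,)

    a = (diff_abs.transpose(1, 0) * weights_points).transpose(1, 0)
    b = diff_abs * weights_points[:, None]  # same result via broadcasting
    assert torch.allclose(a, b)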
If we use the implementation below (didn't run it, so needs testing) then we could support MAE, MSE, RMSE and any other l_p norm (for p < inf) in one function. Maybe it should be called generalized_lp_norm, but I could live with the current naming. MAE and MSE would still exist as functions, but they would all use lp_loss in the implementation.
import torch

def lp_loss(
    target: torch.Tensor,
    pred: torch.Tensor,
    p_norm: int,
    with_p_root: bool = True,
    with_mean: bool = True,
    weights_channels: torch.Tensor | None = None,
    weights_points: torch.Tensor | None = None,
):
    assert isinstance(p_norm, int), "Only integer p supported for p-norm loss"
    # Mask out NaNs in the target so they contribute zero to the loss.
    mask_nan = ~torch.isnan(target)
    # Reduce the ensemble dimension: single member or ensemble mean.
    pred = pred[0] if pred.shape[0] == 1 else pred.mean(0)
    diff_p = torch.pow(
        torch.abs(torch.where(mask_nan, target, 0.0) - torch.where(mask_nan, pred, 0.0)),
        p_norm,
    )
    # Per-point (location) weighting.
    if weights_points is not None:
        diff_p = (diff_p.transpose(1, 0) * weights_points).transpose(1, 0)
    # Reduce over data points, then optionally take the p-th root per channel.
    loss_chs = diff_p.mean(0) if with_mean else diff_p.sum(0)
    loss_chs = torch.pow(loss_chs, 1.0 / p_norm) if with_p_root else loss_chs
    # Per-channel weighting, then reduce to a scalar.
    loss = torch.mean(loss_chs * weights_channels if weights_channels is not None else loss_chs)
    return loss, loss_chs
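To make the delegation concrete, here is a hypothetical sketch of how the named losses could wrap lp_loss (the wrapper names and the **kw pass-through are illustrative, not code from the PR):

    # Hypothetical wrappers; MAE/MSE stay as standalone functions
    # but all delegate to lp_loss as suggested above.
    def mae(target, pred, **kw):
        # p=1, no root: mean absolute error
        return lp_loss(target, pred, p_norm=1, with_p_root=False, **kw)

    def mse(target, pred, **kw):
        # p=2, no root: mean squared error
        return lp_loss(target, pred, p_norm=2, with_p_root=False, **kw)

    def rmse(target, pred, **kw):
        # p=2 with root: root mean squared error
        return lp_loss(target, pred, p_norm=2, with_p_root=True, **kw)

    def sse(target, pred, **kw):
        # p=2, sum instead of mean, no root: sum of squared errors
        return lp_loss(target, pred, p_norm=2, with_p_root=False, with_mean=False, **kw)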
Are you sure you want to recommend that? Most of these functions have accelerated code/intrinsics where the performance difference can be shocking, especially when gradients flow through these computations. I would not trust the torch compiler to detect and optimize these patterns.
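For context, a sketch of the kind of dedicated kernels this comment alludes to; note that PyTorch's built-ins below lack the NaN masking and weighting of the PR's functions, so this compares code paths rather than offering a drop-in replacement:

    import torch
    import torch.nn.functional as F

    target = torch.randn(1024, 8)
    pred = torch.randn(1024, 8)

    # Dedicated fused kernels with optimized forward and backward passes:
    mae_fast = F.l1_loss(pred, target)
    mse_fast = F.mse_loss(pred, target)

    # Generic elementwise chains that a unified lp implementation would emit:
    mae_generic = torch.abs(target - pred).mean()
    mse_generic = torch.pow(torch.abs(target - pred), 2).mean()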
@Jubeku noting there is a conflict

This PR makes PR #1551, which removes the implicit loss function conversion in LossPhysical, obsolete.
clessig left a comment
Just some minor points on the documentation/comments. Has the code been tested?
    By default, the Lp-norm is normalized by the number of samples (i.e. with_mean=True).
    * For example: p=1 corresponds to MAE; p=2 corresponds to MSE.
    The samples are weighted by location if weights_points is not None.
    The norm can optionally normalised by the pth root.
optionally normalized -> optionally be normalized
    The function implements:
    loss = Mean_{channels}( weight_channels * Mean_{data_pts}( (target - pred) * weights_points ))
    loss = Mean_{channels}(weight_channels * Mean_{data_pts}(|(target - pred)|**p * weights_points))
The p-norm usually involves the p-th root "on the outside". This should also be an optional argument. Please make sure it's at the right position with respect to the mean operations.
It's actually present in the implementation but should be mentioned here then as well.
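Spelled out in the docstring's own notation, what the implementation computes when with_p_root=True is (my reading of the sketch above):

    loss = Mean_{channels}( weights_channels * ( Mean_{data_pts}( |target - pred|**p * weights_points ) )**(1/p) )

i.e. the p-th root is applied per channel, after the mean over data points and before the mean over channels.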
    weights_channels : shape = (num_channels,)
    weights_points : shape = (num_data_points)
    target : tensor of shape ( num_data_points , num_channels )
    target : tensor of shape ( ens_dim , num_data_points , num_channels)
pred : ...
    Return:
    loss : weight loss for gradient computation
    loss_chs : losses per channel with location weighting but no channel weighting
    loss : (weighted) loss for gradient computation
loss: (weighted) scalar loss (e.g. for gradient computation)
    ------------------------
    where wp = weights_points and wc = weights_channels and "x" denotes row/col-wise multiplication.
    return lp_loss(
We should refer to lp_loss for detailed explanation of arguments.
    target : shape ( ens_dim , num_data_points , num_channels)
    weights_channels : shape = (num_channels,)
    weights_points : shape = (num_data_points)
    def sse(
Is "sse" a standard name? I am not familar with it.
Maybe better Residual Sum of Squares (RSS)?
https://en.wikipedia.org/wiki/Residual_sum_of_squares
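For reference, the definition from the linked article, in the document's notation:

    RSS = Sum_{data_pts}( (y_i - f(x_i))**2 )

which corresponds to lp_loss with p_norm=2, with_mean=False and with_p_root=False.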

Description
Implement MAE loss function (with weights for physical loss such as mse_channel_location_weighted).
UPDATE:
A generalized lp-norm function is implemented, on top of which MAE, MSE, RMSE and SSE can be expressed.
Issue Number
Closes #1333
Closes #1536
Is this PR a draft? Mark it as draft.
Checklist before asking for review
./scripts/actions.sh lint
./scripts/actions.sh unit-test
./scripts/actions.sh integration-test
launch-slurm.py --time 60