Conversation
# Conflicts: # heavyball/utils.py
GPU Test Results0/31 passed
Failure details
|
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: b2ef73b885
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
heavyball/utils.py
Outdated
| if ( | ||
| group.get("merge_dims", False) | ||
| and p.data.ndim >= 4 | ||
| and p.data.is_contiguous(memory_format=torch.channels_last) | ||
| ): |
There was a problem hiding this comment.
Preserve merge_dims updates for non-contiguous params
This guard now only makes parameters contiguous when they are already channels_last, so other non-contiguous layouts skip the normalization step. With merge_dims=True, merge_group()/dim_merger() can then materialize contiguous temporary tensors, and later updates apply to those temporaries instead of the original parameter storage, which can silently prevent real weights from updating for transposed/custom-layout parameters (including 3D channels-last layouts that do not satisfy this predicate).
Useful? React with 👍 / 👎.
| @zero_guard("mars_old_grad") | ||
| @no_state | ||
| def mars(group, update, grad, param, mars_old_grad): | ||
| utils.mars_correction(update, mars_old_grad, group["mars_gamma"], utils.get_beta1(group)) |
There was a problem hiding this comment.
Apply MARS correction to gradient tensor too
The new MARS transform mutates update but leaves grad unchanged, yet downstream fused update paths still use grad for caution masking (for example via update_by_* into _compilable_update_). In mars=True + caution=True runs this makes the mask compare corrected updates against stale gradients, so valid updates can be masked or scaled incorrectly; previously MARS was applied before gradients were yielded so both tensors stayed consistent.
Useful? React with 👍 / 👎.
No description provided.