fix: minor reactant stuff + docs build #1548
Conversation
Summary of Changes
Hello @avik-pal, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed: this pull request focuses on enhancing the integration of Reactant and Enzyme in the documentation and on refactoring the Reactant training extension.
ext/LuxReactantExt/training.jl (Outdated)

        )
    else
        cache = TrainingBackendCache(
            backend, False(), dparameters, (; compiled_grad_and_step_function, is_sharded)
[JuliaFormatter] reported by reviewdog 🐶
Suggested change:
-            backend, False(), dparameters, (; compiled_grad_and_step_function, is_sharded)
+            backend,
+            False(),
+            dparameters,
+            (; compiled_grad_and_step_function, is_sharded),
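Side note for readers skimming the hunk above: the last constructor argument packs backend-specific state into a NamedTuple so the compiled step function and the sharding flag travel together with the cache. A minimal, self-contained sketch of that pattern (the struct and field names below are illustrative, not Lux internals):

# Illustrative only: NOT Lux's actual TrainingBackendCache, just a self-contained
# demo of the NamedTuple-extras pattern used in the hunk above.
struct DemoBackendCache{B,P,E<:NamedTuple}
    backend::B
    first_try::Bool
    dparameters::P
    extras::E
end

compiled = (args...) -> nothing  # stand-in for the compiled grad-and-step function
cache = DemoBackendCache(
    :reactant,
    false,
    (; weight=zeros(Float32, 2, 2)),
    (; compiled_grad_and_step_function=compiled, is_sharded=false),
)

# Extra fields travel with the cache and are read back by name:
cache.extras.compiled_grad_and_step_function
cache.extras.is_sharded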
Code Review
This pull request updates the documentation to use Reactant and Enzyme and refactors the Reactant training extension. The documentation changes look good but introduce a small bug in the quickstart example. The refactoring of the training extension is a nice improvement, but it introduces a critical bug in the apply_gradients implementation that could lead to a crash. My review includes suggestions to fix these issues.
 ## First construct a TrainState
-train_state = Lux.Training.TrainState(model, ps, st, Adam(0.0001f0))
+train_state = Training.TrainState(model, ps, st, Adam(0.0001f0))
 ## We can compute the gradients using Training.compute_gradients
 ## TrainState handles compilation internally
 gs, loss, stats, train_state = Lux.Training.compute_gradients(
-    AutoZygote(), MSELoss(),
-    (x, dev(rand(rng, Float32, 10, 2))), train_state
+    AutoEnzyme(),
+    MSELoss(),
+    (x, dev(rand(rng, Float32, 10, 2))),
+    train_state
 )
 ## Optimization
 train_state = Training.apply_gradients!(train_state, gs) # or Training.apply_gradients (no `!` at the end)
-# Both these steps can be combined into a single call
+# Both these steps can be combined into a single call (preferred approach)
 gs, loss, stats, train_state = Training.single_train_step!(
-    AutoZygote(), MSELoss(),
-    (x, dev(rand(rng, Float32, 10, 2))), train_state
+    AutoEnzyme(),
+    MSELoss(),
+    (x, dev(rand(rng, Float32, 10, 2))),
+    train_state
 )
The Training module is not imported directly, so calls to Training.TrainState, Training.apply_gradients!, and Training.single_train_step! will fail. You should prefix them with Lux. to match the style in the rest of the file and ensure the example code is runnable.
## First construct a TrainState
train_state = Lux.Training.TrainState(model, ps, st, Adam(0.0001f0))
## We can compute the gradients using Training.compute_gradients
## TrainState handles compilation internally
gs, loss, stats, train_state = Lux.Training.compute_gradients(
    AutoEnzyme(),
    MSELoss(),
    (x, dev(rand(rng, Float32, 10, 2))),
    train_state
)
## Optimization
train_state = Lux.Training.apply_gradients!(train_state, gs) # or Training.apply_gradients (no `!` at the end)
# Both these steps can be combined into a single call (preferred approach)
gs, loss, stats, train_state = Lux.Training.single_train_step!(
    AutoEnzyme(),
    MSELoss(),
    (x, dev(rand(rng, Float32, 10, 2))),
    train_state
)
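Side note: the snippet above assumes that model, ps, st, x, dev, and rng already exist. A minimal sketch of that setup, roughly mirroring the Lux quickstart (the model shape, seed, and device choice here are illustrative, not taken from this PR):

using Lux, Optimisers, Random, Enzyme

rng = Random.default_rng()
Random.seed!(rng, 0)

# Small illustrative model; the real quickstart model may differ.
model = Chain(Dense(10 => 32, tanh), Dense(32 => 2))

dev = cpu_device()                     # or reactant_device() / gpu_device()
ps, st = Lux.setup(rng, model) |> dev  # parameters and states, moved to the device
x = dev(rand(rng, Float32, 10, 2))     # dummy input batch

With that in place, the suggested block runs as written; Enzyme must be loaded for AutoEnzyme, and Reactant for reactant_device.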
84094d2 to 0fe8fa8 (compare)
Codecov Report
❌ Patch coverage is
Additional details and impacted files

@@            Coverage Diff             @@
##             main    #1548      +/-   ##
==========================================
- Coverage   82.47%   73.87%    -8.61%
==========================================
  Files         168      168
  Lines        6957     6954       -3
==========================================
- Hits         5738     5137     -601
- Misses       1219     1817     +598

☔ View full report in Codecov by Sentry.
Benchmark Results (Julia v1.11)
Time benchmarks
Memory benchmarks
0fe8fa8 to 9c9b0d6 (compare)
No description provided.