Description
While running the Modeling a simple Einstein ring example notebook, I noticed that all calculations were performed on a single CPU core. This suggests the code has not been sufficiently vectorized to take advantage of JAX's features. During the PSO fitting sequence in the example there are 200 "particles" which can in principle be evaluated simultaneously. Similarly, the optional MCMC stage has 70 walkers that can be run fully in parallel.
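As a rough sketch of what the particle evaluation could look like, `jax.vmap` can map a per-particle objective over the whole swarm in one batched call. The `objective` function below is hypothetical, standing in for the lens-model likelihood that the PSO evaluates once per particle:

```python
import jax
import jax.numpy as jnp

# Hypothetical per-particle objective; stands in for the likelihood
# that the PSO would evaluate once per particle.
def objective(params):
    # params: (n_dim,) array of model parameters for one particle
    return jnp.sum((params - 1.0) ** 2)

# vmap maps the scalar objective over the leading (particle) axis,
# so all 200 particles are evaluated in a single batched call that
# JAX can fuse and dispatch to all cores or the GPU at once.
batched_objective = jax.jit(jax.vmap(objective))

particles = jnp.zeros((200, 5))  # 200 particles, 5 free parameters each
values = batched_objective(particles)
print(values.shape)  # (200,)
```

The same pattern would apply to the 70 MCMC walkers: one `vmap` over the walker axis replaces a Python loop over sequential likelihood calls.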
On a CPU this vectorization can achieve noticeable real speedup, or at least acceleration proportional to the number of available cores. On a GPU, however, vectorized operations are absolutely critical in order to utilize the resource at all. A quick look through the repo suggests that the current JAXtronomy PSO implementation is actually just imported from lenstronomy. I understand this fits the JAXtronomy mission statement of being a drop-in replacement for lenstronomy. However, I think absorbing this and similar functions into JAXtronomy, so that appropriate accelerations/vectorizations can be made, is necessary to really say that JAXtronomy is GPU compatible.
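To illustrate what an absorbed, JAX-native PSO could look like, here is a minimal sketch of one iteration of the textbook PSO update (inertia plus cognitive and social terms) written as pure array operations over the whole swarm. This is an assumption about the algorithm, not the actual lenstronomy implementation:

```python
import jax
import jax.numpy as jnp

def pso_step(key, x, v, pbest, gbest, w=0.7, c1=1.5, c2=1.5):
    """One textbook PSO velocity/position update, vectorized over particles.

    x, v, pbest: (n_particles, n_dim) arrays; gbest: (n_dim,) array.
    """
    k1, k2 = jax.random.split(key)
    r1 = jax.random.uniform(k1, x.shape)
    r2 = jax.random.uniform(k2, x.shape)
    # Inertia + cognitive (personal best) + social (global best) terms,
    # applied to all particles at once with no Python loop.
    v = w * v + c1 * r1 * (pbest - x) + c2 * r2 * (gbest - x)
    return x + v, v

pso_step = jax.jit(pso_step, static_argnames=())

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (200, 5))   # 200 particles, 5 parameters
v = jnp.zeros_like(x)
x_new, v_new = pso_step(key, x, v, x, jnp.zeros(5))
print(x_new.shape)  # (200, 5)
```

Combined with a batched objective, every PSO iteration then becomes a handful of fused kernels rather than 200 sequential likelihood evaluations.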
The L-BFGS optimizer does seem to be vectorized; at least on my laptop it is able to use a significant fraction of the cores.