Skip to content

Commit a66c0b2

Browse files
authored
Fix ODE integrator default time and compatibility with jax>=0.9.0 (#813)
* fix: ensure 't' keyword argument defaults to 0 in _call_integral and format code in build method * fix: update backend import for compatibility with jax>=0.8.0
1 parent 26778e8 commit a66c0b2

File tree

3 files changed

+13
-5
lines changed

3 files changed

+13
-5
lines changed

brainpy/integrators/base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,10 @@ def state_delays(self, value):
141141
raise ValueError('Cannot set "state_delays" by users.')
142142

143143
def _call_integral(self, *args, **kwargs):
144+
kwargs = dict(kwargs)
145+
t = kwargs.get('t', None)
146+
kwargs['t'] = 0. if t is None else t
147+
144148
if _during_compile:
145149
jaxpr, out_shapes = jax.make_jaxpr(self.integral, return_shape=True)(**kwargs)
146150
outs = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *jax.tree.leaves(kwargs))

brainpy/integrators/ode/explicit_rk.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -178,8 +178,7 @@ def __init__(self,
178178

179179
def build(self):
180180
# step stage
181-
common.step(self.variables, C.DT,
182-
self.A, self.C, self.code_lines, self.parameters)
181+
common.step(self.variables, C.DT, self.A, self.C, self.code_lines, self.parameters)
183182
# variable update
184183
return_args = common.update(self.variables, C.DT, self.B, self.code_lines)
185184
# returns
@@ -189,7 +188,8 @@ def build(self):
189188
code_scope={k: v for k, v in self.code_scope.items()},
190189
code_lines=self.code_lines,
191190
show_code=self.show_code,
192-
func_name=self.func_name)
191+
func_name=self.func_name
192+
)
193193

194194

195195
class Euler(ExplicitRKIntegrator):

brainpy/math/environment.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
import brainstate.environ
2626
import jax
2727
from jax import config, numpy as jnp, devices
28-
from jax.lib import xla_bridge
2928

3029
from . import modes
3130
from . import scales
@@ -733,8 +732,13 @@ def clear_buffer_memory(
733732
Clear name cache. Default is True.
734733
735734
"""
735+
if jax.__version_info__ < (0, 8, 0):
736+
from jax.lib.xla_bridge import get_backend
737+
else:
738+
from jax.extend.backend import get_backend
739+
736740
if array:
737-
for buf in xla_bridge.get_backend(platform).live_buffers():
741+
for buf in get_backend(platform).live_buffers():
738742
buf.delete()
739743
if compilation:
740744
jax.clear_caches()

0 commit comments

Comments
 (0)