Skip to content

Commit 37b7ad5

Browse files
committed
Switch to new Scan API
Co-authored-by: Michal-Novomestsky
1 parent e5eb8c9 commit 37b7ad5

File tree

9 files changed

+159
-95
lines changed

9 files changed

+159
-95
lines changed

pymc/distributions/timeseries.py

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -442,23 +442,26 @@ def rv_op(cls, rhos, sigma, init_dist, steps, ar_order, constant_term, size=None
442442
rhos_bcast = pt.broadcast_to(rhos, rhos_bcast_shape)
443443

444444
def step(*args):
445-
*prev_xs, reversed_rhos, sigma, rng = args
445+
*prev_xs, rng, reversed_rhos, sigma = args
446446
if constant_term:
447447
mu = reversed_rhos[-1] + pt.sum(prev_xs * reversed_rhos[:-1], axis=0)
448448
else:
449449
mu = pt.sum(prev_xs * reversed_rhos, axis=0)
450450
next_rng, new_x = Normal.dist(mu=mu, sigma=sigma, rng=rng).owner.outputs
451-
return new_x, {rng: next_rng}
451+
return new_x, next_rng
452452

453453
# We transpose inputs as scan iterates over first dimension
454-
innov, innov_updates = pytensor.scan(
454+
innov, noise_next_rng = pytensor.scan(
455455
fn=step,
456-
outputs_info=[{"initial": init_dist.T, "taps": range(-ar_order, 0)}],
457-
non_sequences=[rhos_bcast.T[::-1], sigma.T, noise_rng],
456+
outputs_info=[
457+
{"initial": init_dist.T, "taps": range(-ar_order, 0)},
458+
noise_rng,
459+
],
460+
non_sequences=[rhos_bcast.T[::-1], sigma.T],
458461
n_steps=steps,
459462
strict=True,
463+
return_updates=False,
460464
)
461-
(noise_next_rng,) = tuple(innov_updates.values())
462465
ar = pt.concatenate([init_dist, innov.T], axis=-1)
463466

464467
return AutoRegressiveRV(
@@ -710,24 +713,25 @@ def rv_op(cls, omega, alpha_1, beta_1, initial_vol, init_dist, steps, size=None)
710713

711714
# Create OpFromGraph representing random draws from GARCH11 process
712715

713-
def step(prev_y, prev_sigma, omega, alpha_1, beta_1, rng):
716+
def step(prev_y, prev_sigma, rng, omega, alpha_1, beta_1):
714717
new_sigma = pt.sqrt(
715718
omega + alpha_1 * pt.square(prev_y) + beta_1 * pt.square(prev_sigma)
716719
)
717720
next_rng, new_y = Normal.dist(mu=0, sigma=new_sigma, rng=rng).owner.outputs
718-
return (new_y, new_sigma), {rng: next_rng}
721+
return new_y, new_sigma, next_rng
719722

720-
(y_t, _), innov_updates = pytensor.scan(
723+
y_t, _, noise_next_rng = pytensor.scan(
721724
fn=step,
722725
outputs_info=[
723726
init_dist,
724727
pt.broadcast_to(initial_vol.astype("floatX"), init_dist.shape),
728+
noise_rng,
725729
],
726-
non_sequences=[omega, alpha_1, beta_1, noise_rng],
730+
non_sequences=[omega, alpha_1, beta_1],
727731
n_steps=steps,
728732
strict=True,
733+
return_updates=False,
729734
)
730-
(noise_next_rng,) = tuple(innov_updates.values())
731735

732736
garch11 = pt.concatenate([init_dist[None, ...], y_t], axis=0).dimshuffle(
733737
(*range(1, y_t.ndim), 0)
@@ -816,12 +820,13 @@ def garch11_logp(
816820
def volatility_update(x, vol, w, a, b):
817821
return pt.sqrt(w + a * pt.square(x) + b * pt.square(vol))
818822

819-
vol, _ = pytensor.scan(
823+
vol = pytensor.scan(
820824
fn=volatility_update,
821825
sequences=[value_dimswapped[:-1]],
822826
outputs_info=[initial_vol],
823827
non_sequences=[omega, alpha_1, beta_1],
824828
strict=True,
829+
return_updates=False,
825830
)
826831
sigma_t = pt.concatenate([[initial_vol], vol])
827832
# Compute and collapse logp across time dimension
@@ -861,21 +866,21 @@ def rv_op(cls, init_dist, steps, sde_pars, dt, sde_fn, size=None):
861866

862867
# Create OpFromGraph representing random draws from SDE process
863868
def step(*prev_args):
864-
prev_y, *prev_sde_pars, rng = prev_args
869+
prev_y, rng, *prev_sde_pars = prev_args
865870
f, g = sde_fn(prev_y, *prev_sde_pars)
866871
mu = prev_y + dt * f
867872
sigma = pt.sqrt(dt) * g
868873
next_rng, next_y = Normal.dist(mu=mu, sigma=sigma, rng=rng).owner.outputs
869-
return next_y, {rng: next_rng}
874+
return next_y, next_rng
870875

871-
y_t, innov_updates = pytensor.scan(
876+
y_t, noise_next_rng = pytensor.scan(
872877
fn=step,
873-
outputs_info=[init_dist],
874-
non_sequences=[*sde_pars, noise_rng],
878+
outputs_info=[init_dist, noise_rng],
879+
non_sequences=[*sde_pars],
875880
n_steps=steps,
876881
strict=True,
882+
return_updates=False,
877883
)
878-
(noise_next_rng,) = tuple(innov_updates.values())
879884

880885
sde_out = pt.concatenate([init_dist[None, ...], y_t], axis=0).dimshuffle(
881886
(*range(1, y_t.ndim), 0)

pymc/logprob/transforms.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -729,11 +729,12 @@ def calc_delta_x(value, prior_result):
729729
2 * prior_result * pt.erfcx(prior_result) - 2 / pt.sqrt(np.pi)
730730
)
731731

732-
result, updates = scan(
732+
result = scan(
733733
fn=calc_delta_x,
734734
outputs_info=pt.ones_like(x),
735735
non_sequences=value,
736736
n_steps=10,
737+
return_updates=False,
737738
)
738739
return result[-1]
739740

pymc/pytensorf.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ def jacobian1(f, v):
331331
def grad_i(i):
332332
return gradient1(f[i], v)
333333

334-
return pytensor.map(grad_i, idx)[0]
334+
return pytensor.map(grad_i, idx, return_updates=False)
335335

336336

337337
def jacobian(f, vars=None):
@@ -355,8 +355,13 @@ def grad_ii(i, f, x):
355355
return grad(f[i], x)[i]
356356

357357
return pytensor.scan(
358-
grad_ii, sequences=[idx], n_steps=f.shape[0], non_sequences=[f, x], name="jacobian_diag"
359-
)[0]
358+
grad_ii,
359+
sequences=[idx],
360+
n_steps=f.shape[0],
361+
non_sequences=[f, x],
362+
name="jacobian_diag",
363+
return_updates=False,
364+
)
360365

361366

362367
@pytensor.config.change_flags(compute_test_value="ignore")
@@ -381,7 +386,7 @@ def hessian_diag1(f, v):
381386
def hess_ii(i):
382387
return gradient1(g[i], v)[i]
383388

384-
return pytensor.map(hess_ii, idx)[0]
389+
return pytensor.map(hess_ii, idx, return_updates=False)
385390

386391

387392
@pytensor.config.change_flags(compute_test_value="ignore")

pymc/variational/approximations.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,10 @@ def evaluate_over_trace(self, node):
392392
def sample(post, *_):
393393
return graph_replace(node, {self.input: post}, strict=False)
394394

395-
nodes, _ = pytensor.scan(
396-
sample, self.histogram, non_sequences=_known_scan_ignored_inputs(makeiter(node))
395+
nodes = pytensor.scan(
396+
sample,
397+
self.histogram,
398+
non_sequences=_known_scan_ignored_inputs(makeiter(node)),
399+
return_updates=False,
397400
)
398401
return nodes

pymc/variational/opvi.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1062,8 +1062,11 @@ def symbolic_sample_over_posterior(self, node):
10621062
def sample(post, *_):
10631063
return graph_replace(node, {self.input: post}, strict=False)
10641064

1065-
nodes, _ = pytensor.scan(
1066-
sample, random, non_sequences=_known_scan_ignored_inputs(makeiter(random))
1065+
nodes = pytensor.scan(
1066+
sample,
1067+
random,
1068+
non_sequences=_known_scan_ignored_inputs(makeiter(random)),
1069+
return_updates=False,
10671070
)
10681071
assert self.input not in set(pytensor.graph.graph_inputs(makeiter(nodes)))
10691072
return nodes
@@ -1451,8 +1454,11 @@ def symbolic_sample_over_posterior(self, node, more_replacements=None):
14511454
def sample(*post):
14521455
return graph_replace(node, dict(zip(self.inputs, post)), strict=False)
14531456

1454-
nodes, _ = pytensor.scan(
1455-
sample, self.symbolic_randoms, non_sequences=_known_scan_ignored_inputs(makeiter(node))
1457+
nodes = pytensor.scan(
1458+
sample,
1459+
self.symbolic_randoms,
1460+
non_sequences=_known_scan_ignored_inputs(makeiter(node)),
1461+
return_updates=False,
14561462
)
14571463
assert not (set(self.inputs) & set(pytensor.graph.graph_inputs(makeiter(nodes))))
14581464
return nodes

tests/distributions/test_custom.py

Lines changed: 42 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -362,32 +362,40 @@ def scan_step(left, right):
362362
return x, x_update
363363

364364
def dist(size):
365-
xs, updates = scan(
366-
fn=scan_step,
367-
sequences=[
368-
pt.as_tensor_variable(np.array([-4, -3])),
369-
pt.as_tensor_variable(np.array([-2, -1])),
370-
],
371-
name="xs",
372-
)
365+
with pytest.warns(DeprecationWarning, match="Scan return signature will change"):
366+
xs, updates = scan(
367+
fn=scan_step,
368+
sequences=[
369+
pt.as_tensor_variable(np.array([-4, -3])),
370+
pt.as_tensor_variable(np.array([-2, -1])),
371+
],
372+
name="xs",
373+
# There's a bug in the ordering of outputs when there's a mapped `None` output
374+
# We have to stick with the deprecated API for now
375+
return_updates=True,
376+
)
373377
return xs
374378

375379
with Model() as model:
376380
CustomDist("x", dist=dist)
377381
assert_support_point_is_expected(model, np.array([-3, -2]))
378382

379383
def test_custom_dist_default_support_point_scan_recurring(self):
380-
def scan_step(xtm1):
381-
x = Normal.dist(xtm1 + 1)
382-
x_update = collect_default_updates([x])
383-
return x, x_update
384+
def scan_step(xtm1, rng):
385+
next_rng, x = Normal.dist(xtm1 + 1, rng=rng).owner.outputs
386+
return x, next_rng
384387

385388
def dist(size):
386-
xs, _ = scan(
389+
rng = pytensor.shared(np.random.default_rng())
390+
xs, _next_rng = scan(
387391
fn=scan_step,
388-
outputs_info=pt.as_tensor_variable(np.array([0])).astype(float),
392+
outputs_info=[
393+
pt.as_tensor_variable(np.array([0])).astype(float),
394+
rng,
395+
],
389396
n_steps=3,
390397
name="xs",
398+
return_updates=False,
391399
)
392400
return xs
393401

@@ -527,16 +535,20 @@ def test_scan(self):
527535
def trw(nu, sigma, steps, size):
528536
if rv_size_is_none(size):
529537
size = ()
538+
rng = pytensor.shared(np.random.default_rng())
530539

531-
def step(xtm1, nu, sigma):
532-
x = StudentT.dist(nu=nu, mu=xtm1, sigma=sigma, shape=size)
533-
return x, collect_default_updates([x])
540+
def step(xtm1, rng, nu, sigma):
541+
next_rng, x = StudentT.dist(
542+
nu=nu, mu=xtm1, sigma=sigma, shape=size, rng=rng
543+
).owner.outputs
544+
return x, next_rng
534545

535-
xs, _ = scan(
546+
xs, _next_rng = scan(
536547
fn=step,
537-
outputs_info=pt.zeros(size),
548+
outputs_info=[pt.zeros(size), rng],
538549
non_sequences=[nu, sigma],
539550
n_steps=steps,
551+
return_updates=False,
540552
)
541553

542554
# Logprob inference cannot be derived yet https://github.com/pymc-devs/pymc/issues/6360
@@ -667,13 +679,17 @@ def step(s):
667679
traffic = s + innov
668680
return traffic, {innov.owner.inputs[0]: innov.owner.outputs[0]}
669681

670-
rv_seq, _ = pytensor.scan(
671-
fn=step,
672-
sequences=[seq],
673-
outputs_info=[None],
674-
n_steps=n_steps,
675-
strict=True,
676-
)
682+
with pytest.warns(DeprecationWarning, match="Scan return signature will change"):
683+
rv_seq, _ = pytensor.scan(
684+
fn=step,
685+
sequences=[seq],
686+
outputs_info=[None],
687+
n_steps=n_steps,
688+
strict=True,
689+
# There's a bug in the ordering of outputs when there's a mapped `None` output
690+
# We have to stick with the deprecated API for now
691+
return_updates=True,
692+
)
677693
return rv_seq
678694

679695
def normal_shifted(mu, size):

0 commit comments

Comments
 (0)