
Commit 8099829

Merge pull request #730 from pymc-devs/transform
automatic transform application
2 parents: 10360db + d81a974

23 files changed: +485 −396 lines

pymc3/backends/base.py

Lines changed: 1 addition & 0 deletions

@@ -31,6 +31,7 @@ def __init__(self, name, model=None, vars=None):
         self.varnames = [str(var) for var in vars]
         self.fn = model.fastfn(vars)

+
         ## Get variable shapes. Most backends will need this
         ## information.
         var_values = list(zip(self.varnames, self.fn(model.test_point)))

pymc3/blocking.py

Lines changed: 2 additions & 2 deletions

@@ -47,7 +47,7 @@ def map(self, dpt):
         """
         apt = np.empty(self.ordering.dimensions)
         for var, slc, _ in self.ordering.vmap:
-            apt[slc] = np.ravel(dpt[var])
+            apt[slc] = dpt[var].ravel()
         return apt

     def rmap(self, apt):
@@ -61,7 +61,7 @@ def rmap(self, apt):
         dpt = self.dpt.copy()

         for var, slc, shp in self.ordering.vmap:
-            dpt[var] = np.reshape(np.atleast_1d(apt)[slc], shp)
+            dpt[var] = apt[slc].reshape(shp)

         return dpt
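
For orientation, a minimal NumPy-only sketch of the dict-to-array round trip that map() and rmap() implement (the vmap layout and variable names here are hypothetical, not taken from the PyMC3 source). Note that the new dpt[var].ravel() form assumes the point values are already ndarrays:

    import numpy as np

    # Hypothetical ordering: two variables packed into one flat array.
    vmap = [('mu', slice(0, 1), (1,)), ('beta', slice(1, 4), (3,))]
    dimensions = 4

    def dict_to_array(dpt):            # mirrors map()
        apt = np.empty(dimensions)
        for var, slc, _ in vmap:
            apt[slc] = dpt[var].ravel()
        return apt

    def array_to_dict(apt):            # mirrors rmap()
        return {var: apt[slc].reshape(shp) for var, slc, shp in vmap}

    point = {'mu': np.array([0.5]), 'beta': np.array([1., 2., 3.])}
    assert np.allclose(array_to_dict(dict_to_array(point))['beta'], point['beta'])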

pymc3/distributions/continuous.py

Lines changed: 25 additions & 10 deletions

@@ -9,12 +9,23 @@

 from .dist_math import *
 from numpy.random import uniform as runiform, normal as rnormal
+from .transforms import logtransform, logoddstransform, interval_transform

 __all__ = ['Uniform', 'Flat', 'Normal', 'Beta', 'Exponential', 'Laplace',
            'T', 'StudentT', 'Cauchy', 'HalfCauchy', 'Gamma', 'Weibull','Bound',
            'Tpos', 'Lognormal', 'ChiSquared', 'HalfNormal', 'Wald',
            'Pareto', 'InverseGamma']

+class PositiveContinuous(Continuous):
+    """Base class for positive continuous distributions"""
+    def __init__(self, transform=logtransform, *args, **kwargs):
+        super(PositiveContinuous, self).__init__(transform=transform, *args, **kwargs)
+
+class UnitContinuous(Continuous):
+    """Base class for continuous distributions on [0,1]"""
+    def __init__(self, transform=logoddstransform, *args, **kwargs):
+        super(UnitContinuous, self).__init__(transform=transform, *args, **kwargs)
+
 def get_tau_sd(tau=None, sd=None):
     """
     Find precision and standard deviation
@@ -70,13 +81,17 @@ class Uniform(Continuous):
     upper : float
         Upper limit (defaults to 1)
     """
-    def __init__(self, lower=0, upper=1, *args, **kwargs):
+    def __init__(self, lower=0, upper=1, transform='interval', *args, **kwargs):
         super(Uniform, self).__init__(*args, **kwargs)
+
         self.lower = lower
         self.upper = upper
         self.mean = (upper + lower) / 2.
         self.median = self.mean

+        if transform is 'interval':
+            self.transform = interval_transform(lower, upper)
+
     def logp(self, value):
         lower = self.lower
         upper = self.upper
@@ -142,7 +157,7 @@ def logp(self, value):
         )


-class HalfNormal(Continuous):
+class HalfNormal(PositiveContinuous):
     """
     Half-normal log-likelihood, a normal distribution with mean 0 limited
     to the domain :math:`x \in [0, \infty)`.
@@ -206,7 +221,7 @@ def logp(self, value):



-class Beta(Continuous):
+class Beta(UnitContinuous):
     """
     Beta log-likelihood. The conjugate prior for the parameter
     :math:`p` of the binomial distribution.
@@ -268,7 +283,7 @@ def logp(self, value):
                      beta > 0)


-class Exponential(Continuous):
+class Exponential(PositiveContinuous):
     """
     Exponential distribution

@@ -320,7 +335,7 @@ def logp(self, value):
         return -log(2 * b) - abs(value - mu) / b


-class Lognormal(Continuous):
+class Lognormal(PositiveContinuous):
     """
     Log-normal log-likelihood.

@@ -409,7 +424,7 @@ def logp(self, value):
 StudentT = T


-class Pareto(Continuous):
+class Pareto(PositiveContinuous):
     """
     Pareto log-likelihood. The Pareto is a continuous, positive
     probability distribution with two parameters. It is often used
@@ -482,7 +497,7 @@ def logp(self, value):
                      value - alpha) / beta) ** 2),
                  beta > 0)

-class HalfCauchy(Continuous):
+class HalfCauchy(PositiveContinuous):
     """
     Half-Cauchy log-likelihood. Simply the absolute value of Cauchy.

@@ -510,7 +525,7 @@ def logp(self, value):
                      value >= 0)


-class Gamma(Continuous):
+class Gamma(PositiveContinuous):
     """
     Gamma log-likelihood.

@@ -575,7 +590,7 @@ def logp(self, value):
             beta > 0)


-class InverseGamma(Continuous):
+class InverseGamma(PositiveContinuous):
     """
     Inverse gamma log-likelihood, the reciprocal of the gamma distribution.

@@ -634,7 +649,7 @@ def __init__(self, nu, *args, **kwargs):
         super(ChiSquared, self).__init__(alpha=nu/2., beta=0.5, *args, **kwargs)


-class Weibull(Continuous):
+class Weibull(PositiveContinuous):
     """
     Weibull log-likelihood

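
The net effect of these changes, as a hedged usage sketch: positive distributions now default to a log transform and unit-interval ones to a log-odds transform, while Uniform builds an interval transform from its bounds. The exact free-variable names are an assumption here, inferred from the *_interval/*_logodds keys in the updated examples below:

    import pymc3 as pm

    with pm.Model() as model:
        sd = pm.HalfNormal('sd', sd=1.)            # PositiveContinuous -> log transform
        x = pm.Uniform('x', lower=-2., upper=2.)   # interval transform from the bounds

    # Sampling runs on the unconstrained variables, e.g. 'sd_log' and
    # 'x_interval' by the naming convention assumed above.
    print([str(v) for v in model.vars])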

pymc3/distributions/distribution.py

Lines changed: 2 additions & 10 deletions

@@ -31,12 +31,13 @@ def dist(cls, *args, **kwargs):
         dist.__init__(*args, **kwargs)
         return dist

-    def __init__(self, shape, dtype, testval=None, defaults=[]):
+    def __init__(self, shape, dtype, testval=None, defaults=[], transform=None):
         self.shape = np.atleast_1d(shape)
         self.dtype = dtype
         self.type = TensorType(self.dtype, self.shape)
         self.testval = testval
         self.defaults = defaults
+        self.transform = transform

     def default(self):
         return self.get_test_val(self.testval, self.defaults)
@@ -80,15 +81,6 @@ def __init__(self, logp, shape=(), dtype='float64',testval=0, *args, **kwargs):
         super(DensityDist, self).__init__(shape, dtype, testval, *args, **kwargs)
         self.logp = logp

-def TransformedVar(*args, **kwargs):
-    try:
-        model = Model.get_context()
-    except TypeError:
-        raise TypeError("No model on context stack, which is needed to use the Normal('x', 0,1) syntax. Add a 'with model:' block")
-
-    return model.TransformedVar(*args, **kwargs)
-
-
 class MultivariateContinuous(Continuous):

     pass
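
Because Distribution.__init__ now stores a transform attribute, the automatic transform can also be switched off per variable. A hedged sketch (assuming the kwarg is forwarded through the subclass constructors exactly as in the class definitions above):

    import pymc3 as pm

    with pm.Model() as model:
        # transform=None overrides the PositiveContinuous default
        # (logtransform), so 'sd' is sampled on its original scale.
        sd = pm.HalfNormal('sd', sd=1., transform=None)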

pymc3/distributions/multivariate.py

Lines changed: 3 additions & 2 deletions

@@ -3,6 +3,7 @@
 from .dist_math import *

 import numpy as np
+from .transforms import simplextransform

 from theano.tensor.nlinalg import det, matrix_inverse, trace, eigh
 from theano.tensor import dot, cast, eye, diag, eq, le, ge, gt, all
@@ -65,8 +66,8 @@ class Dirichlet(Continuous):
     Only the first `k-1` elements of `x` are expected. Can be used
     as a parent of Multinomial and Categorical nevertheless.
     """
-    def __init__(self, a, *args, **kwargs):
-        super(Dirichlet, self).__init__(*args, **kwargs)
+    def __init__(self, a, transform=simplextransform, *args, **kwargs):
+        super(Dirichlet, self).__init__(transform=transform, *args, **kwargs)
         self.a = a
         self.k = a.shape[0]
         self.mean = a / sum(a)
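
The Dirichlet now attaches simplextransform by default; its forward map (shown truncated below in transforms.py) drops the redundant last coordinate of a simplex point. A plain-NumPy sketch of the round trip (the backward step here is an assumed reconstruction from the sum-to-one constraint, since the diff truncates it):

    import numpy as np

    def simplex_forward(p):    # drop the redundant last coordinate
        return p[:-1]

    def simplex_backward(q):   # assumed inverse: restore it from the constraint
        return np.concatenate([q, [1 - q.sum()]])

    p = np.array([0.2, 0.3, 0.1, 0.25, 0.15])
    assert np.allclose(simplex_backward(simplex_forward(p)), p)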

pymc3/distributions/transforms.py

Lines changed: 32 additions & 2 deletions

@@ -38,7 +38,7 @@ def __init__(self, dist, transform, *args, **kwargs):
         testval = forward(dist.default())

         self.dist = dist
-        self.transform = transform
+        self.transform_used = transform
         v = forward(FreeRV(name='v', distribution=dist))
         self.type = v.type

@@ -49,12 +49,42 @@ def __init__(self, dist, transform, *args, **kwargs):


     def logp(self, x):
-        return self.dist.logp(self.transform.backward(x)) + self.transform.jacobian_det(x)
+        return self.dist.logp(self.transform_used.backward(x)) + self.transform_used.jacobian_det(x)

 transform = Transform

+
 logtransform = transform("log", log, exp, idfn)

+
+logistic = t.nnet.sigmoid
+def logistic_jacobian(x):
+    ex = exp(-x)
+    return log(ex/(ex +1)**2)
+
+def logit(x):
+    return log(x/(1-x))
+logoddstransform = transform("logodds", logit, logistic, logistic_jacobian)
+
+
+def interval_transform(a, b):
+    def interval_real(x):
+        r= log((x-a)/(b-x))
+        return r
+
+    def real_interval(x):
+        r = (b-a)*exp(x)/(1+exp(x)) + a
+        return r
+
+    def real_interval_jacobian(x):
+        ex = exp(-x)
+        jac = log(ex*(b-a)/(ex + 1)**2)
+        return jac
+
+    return transform("interval", interval_real, real_interval, real_interval_jacobian)
+
+
+
 simplextransform = transform("simplex",
                              lambda p: p[:-1],
                              lambda p: concatenate(
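
A quick standalone check of the interval transform's algebra (plain NumPy mirroring the Theano code above; the finite-difference comparison is only illustrative):

    import numpy as np

    a, b = 2.0, 5.0
    forward = lambda x: np.log((x - a) / (b - x))                    # interval_real
    backward = lambda y: (b - a) * np.exp(y) / (1 + np.exp(y)) + a   # real_interval
    log_jac = lambda y: np.log(np.exp(-y) * (b - a) / (np.exp(-y) + 1) ** 2)

    x = 3.7
    y = forward(x)
    assert np.isclose(backward(y), x)   # round trip

    # jacobian_det should equal log|d backward / dy|
    eps = 1e-6
    fd = (backward(y + eps) - backward(y - eps)) / (2 * eps)
    assert np.isclose(np.exp(log_jac(y)), fd)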

pymc3/examples/ARM12_6.py

Lines changed: 3 additions & 6 deletions

@@ -50,17 +50,14 @@ def run(n=3000):
     n = 50
     with model:
         start = {'groupmean': obs_means.mean(),
-                 'groupsd': obs_means.std(),
-                 'sd': data.groupby('group').lradon.std().mean(),
+                 'groupsd_interval': 0,
+                 'sd_interval': 0,
                  'means': np.array(obs_means),
                  'floor_m': 0.,
                  }

         start = find_MAP(start, [groupmean, sd, floor_m])
-        H = model.fastd2logp()
-        h = np.diag(H(start))
-
-        step = HamiltonianMC(model.vars, h)
+        step = NUTS(model.vars, scaling=start)

         trace = sample(n, step, start)

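
A note on the new start values: variables with an interval transform are initialized on the transformed scale, and 0 there maps back to the midpoint of the original interval. A standalone check using the backward map from transforms.py above:

    import numpy as np

    def real_interval(x, a, b):   # backward map of the interval transform
        return (b - a) * np.exp(x) / (1 + np.exp(x)) + a

    assert real_interval(0.0, 2.0, 10.0) == 6.0   # 0 -> interval midpoint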

pymc3/examples/ARM12_6uranium.py

Lines changed: 5 additions & 4 deletions

@@ -53,13 +53,14 @@ def run(n=3000):
     n = 50
     with model:

-        start = {'groupmean': obs_means.mean(),
-                 'groupsd': obs_means.std(),
-                 'sd': data.groupby('group').lradon.std().mean(),
+        start = Point({
+            'groupmean': obs_means.mean(),
+            'groupsd_interval': 0,
+            'sd_interval': 0,
                  'means': np.array(obs_means),
                  'u_m': np.array([.72]),
                  'floor_m': 0.,
-                 }
+            })

         start = find_MAP(start, model.vars[:-1])
         H = model.fastd2logp()

pymc3/examples/dirichlet.py

Lines changed: 1 addition & 3 deletions

@@ -7,9 +7,7 @@
 k = 5
 a = constant(np.array([2, 3., 4, 2, 2]))

-p, p_m1 = model.TransformedVar(
-    'p', Dirichlet.dist(a, shape=k),
-    simplextransform)
+p = Dirichlet('p', a, shape=k)

 c = Categorical('c', p, observed=np.random.randint(0, k, 5))


pymc3/examples/discrete_find_MAP.py

Lines changed: 12 additions & 12 deletions

@@ -40,13 +40,13 @@

 with model:
     for i in range(n+1):
-        s = {'p': 0.5, 'surv_sim': i}
+        s = {'p_logodds': 0.5, 'surv_sim': i}
         map_est = mc.find_MAP(start=s, vars=model.vars,
                               fmin=mc.starting.optimize.fmin_bfgs)
         print('surv_sim: %i->%i, p: %f->%f, LogP:%f'%(s['surv_sim'],
                                                       map_est['surv_sim'],
-                                                      s['p'],
-                                                      map_est['p'],
+                                                      s['p_logodds'],
+                                                      map_est['p_logodds'],
                                                       model.logp(map_est)))

 # Once again because the gradient of `surv_sim` provides no information to the
@@ -58,12 +58,12 @@

 with model:
     for i in range(n+1):
-        s = {'p': 0.5, 'surv_sim': i}
+        s = {'p_logodds': 0.0, 'surv_sim': i}
         map_est = mc.find_MAP(start=s, vars=model.vars)
         print('surv_sim: %i->%i, p: %f->%f, LogP:%f'%(s['surv_sim'],
                                                       map_est['surv_sim'],
-                                                      s['p'],
-                                                      map_est['p'],
+                                                      s['p_logodds'],
+                                                      map_est['p_logodds'],
                                                       model.logp(map_est)))

 # For most starting values this converges to the maximum log likelihood of
@@ -82,12 +82,12 @@ def bh(*args, **kwargs):

 with model:
     for i in range(n+1):
-        s = {'p': 0.5, 'surv_sim': i}
+        s = {'p_logodds': 0.0, 'surv_sim': i}
         map_est = mc.find_MAP(start=s, vars=model.vars, fmin=bh)
         print('surv_sim: %i->%i, p: %f->%f, LogP:%f'%(s['surv_sim'],
                                                       map_est['surv_sim'],
-                                                      s['p'],
-                                                      map_est['p'],
+                                                      s['p_logodds'],
+                                                      map_est['p_logodds'],
                                                       model.logp(map_est)))

 # By default `basinhopping` uses a gradient minimization technique,
@@ -96,13 +96,13 @@ def bh(*args, **kwargs):

 with model:
     for i in range(n+1):
-        s = {'p': 0.5, 'surv_sim': i}
+        s = {'p_logodds': 0.0, 'surv_sim': i}
         map_est = mc.find_MAP(start=s, vars=model.vars,
                               fmin=bh, minimizer_kwargs={"method": "Powell"})
         print('surv_sim: %i->%i, p: %f->%f, LogP:%f'%(s['surv_sim'],
                                                       map_est['surv_sim'],
-                                                      s['p'],
-                                                      map_est['p'],
+                                                      s['p_logodds'],
+                                                      map_est['p_logodds'],
                                                       model.logp(map_est)))

 # Confident in our MAP estimate we can sample from the posterior, making sure
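
The renamed keys reflect that p, a unit-interval parameter, is now sampled as the free variable p_logodds; on the log-odds scale the neutral start is logit(0.5) = 0, which is why the later hunks start at 0.0. A one-line check:

    import numpy as np
    logit = lambda q: np.log(q / (1 - q))
    assert logit(0.5) == 0.0   # neutral start on the transformed scale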
