Skip to content

Commit e504d23

Browse files
authored
Merge pull request #122 from devmotion/dw/callable
Make `CallableTransform` and `CallableInverse` aliases of `Base.Fix1`
2 parents 9a144ca + f5e51cc commit e504d23

File tree

3 files changed

+25
-33
lines changed

3 files changed

+25
-33
lines changed

ext/ChangesOfVariablesExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@ import TransformVariables
44
import ChangesOfVariables
55

66
function ChangesOfVariables.with_logabsdet_jacobian(f::TransformVariables.CallableTransform, x)
7-
return TransformVariables.transform_and_logjac(f.t, x)
7+
return TransformVariables.transform_and_logjac(f.x, x)
88
end
99
function ChangesOfVariables.with_logabsdet_jacobian(f::TransformVariables.CallableInverse, x)
10-
return TransformVariables.inverse_and_logjac(f.t, x)
10+
return TransformVariables.inverse_and_logjac(f.x, x)
1111
end
1212

1313
end # module

src/generic.jl

Lines changed: 19 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -125,41 +125,21 @@ transform(t, x) == transform(t)(x)
125125
"""
126126
function transform end
127127

128+
transform(t::AbstractTransform) = Base.Fix1(transform, t)
129+
128130
"""
129131
$(TYPEDEF)
130132
131133
Partial application of `transform(t, x)`, callable with `x`. Use `transform(t)` to
132134
construct.
133135
"""
134-
struct CallableTransform{T}
135-
t::T
136-
end
137-
138-
(f::CallableTransform)(x) = transform(f.t, x)
136+
const CallableTransform{T} = Base.Fix1{typeof(transform),T} where {T<:AbstractTransform}
139137

140138
function Base.show(io::IO, f::CallableTransform)
141-
print(io, "callable for the transformation $(f.t)")
142-
end
143-
144-
transform(t) = CallableTransform(t)
145-
146-
"""
147-
$(TYPEDEF)
148-
149-
Partial application of `inverse(t, y)`, callable with `y`. Use `inverse(t)` to
150-
construct.
151-
"""
152-
struct CallableInverse{T}
153-
t::T
154-
end
155-
156-
function Base.show(io::IO, f::CallableInverse)
157-
print(io, "callable inverse for the transformation $(f.t)")
139+
print(io, "callable for the transformation ", f.x)
158140
end
159141

160-
inverse(f::CallableInverse) = CallableTransform(f.t)
161-
162-
inverse(f::CallableTransform) = CallableInverse(f.t)
142+
inverse(f::CallableTransform) = Base.Fix1(inverse, f.x)
163143

164144
"""
165145
`$(FUNCTIONNAME)(t, y)`
@@ -174,9 +154,21 @@ with transform, so the following holds:
174154
inverse(t)(y) == inverse(t, y) == inverse(transform(t))(y)
175155
```
176156
"""
177-
inverse(t::AbstractTransform) = CallableInverse(t)
157+
inverse(t::AbstractTransform) = Base.Fix1(inverse, t)
158+
159+
"""
160+
$(TYPEDEF)
161+
162+
Partial application of `inverse(t, y)`, callable with `y`. Use `inverse(t)` to
163+
construct.
164+
"""
165+
const CallableInverse{T} = Base.Fix1{typeof(inverse),T} where {T<:AbstractTransform}
166+
167+
function Base.show(io::IO, f::CallableInverse)
168+
print(io, "callable inverse for the transformation ", f.x)
169+
end
178170

179-
(f::CallableInverse)(y) = inverse(f.t, y)
171+
inverse(f::CallableInverse) = Base.Fix1(transform, f.x)
180172

181173
"""
182174
`$(FUNCTIONNAME)(t::AbstractTransform, y)`

test/runtests.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -583,17 +583,17 @@ end
583583

584584
@testset "ChangesOfVariables" begin
585585
t = as(Real, 1.0, 3.0)
586-
f = TransformVariables.CallableTransform(t)
587-
inv_f = TransformVariables.CallableInverse(t)
586+
f = transform(t)
587+
inv_f = inverse(t)
588588
ChangesOfVariables.test_with_logabsdet_jacobian(f, -4.2, ForwardDiff.derivative)
589589
ChangesOfVariables.test_with_logabsdet_jacobian(inv_f, 1.7, ForwardDiff.derivative)
590590
end
591591

592592

593593
@testset "InverseFunctions" begin
594594
t = as(Real, 1.0, 3.0)
595-
f = TransformVariables.CallableTransform(t)
596-
inv_f = TransformVariables.CallableInverse(t)
595+
f = transform(t)
596+
inv_f = inverse(t)
597597
InverseFunctions.test_inverse(f, -4.2)
598598
InverseFunctions.test_inverse(inv_f, 1.7)
599599
end

0 commit comments

Comments
 (0)