Skip to content

Commit a7f8cc8

Browse files
authored
Move support of InverseFunctions and ChangesOfVariables to extensions
1 parent d245bcf commit a7f8cc8

File tree

5 files changed

+34
-11
lines changed

5 files changed

+34
-11
lines changed

Project.toml

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,21 @@ version = "0.8.12"
55

66
[deps]
77
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
8-
ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
98
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
109
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
11-
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
1210
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1311
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
1412
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1513
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1614

15+
[weakdeps]
16+
ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
17+
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
18+
19+
[extensions]
20+
ChangesOfVariablesExt = "ChangesOfVariables"
21+
InverseFunctionsExt = "InverseFunctions"
22+
1723
[compat]
1824
ArgCheck = "1, 2"
1925
ChangesOfVariables = "0.1"

ext/ChangesOfVariablesExt.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
module ChangesOfVariablesExt
2+
3+
import TransformVariables
4+
import ChangesOfVariables
5+
6+
function ChangesOfVariables.with_logabsdet_jacobian(f::TransformVariables.CallableTransform, x)
7+
return TransformVariables.transform_and_logjac(f.t, x)
8+
end
9+
function ChangesOfVariables.with_logabsdet_jacobian(f::TransformVariables.CallableInverse, x)
10+
return TransformVariables.inverse_and_logjac(f.t, x)
11+
end
12+
13+
end # module

ext/InverseFunctionsExt.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
module InverseFunctionsExt
2+
3+
import TransformVariables
4+
import InverseFunctions
5+
6+
function InverseFunctions.inverse(f::TransformVariables.CallableTransform)
7+
return TransformVariables.inverse(f)
8+
end
9+
function InverseFunctions.inverse(f::TransformVariables.CallableInverse)
10+
return TransformVariables.inverse(f)
11+
end
12+
13+
end # module

src/TransformVariables.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,6 @@ using LinearAlgebra: UpperTriangular, logabsdet
88
using Random: AbstractRNG, GLOBAL_RNG
99
using StaticArrays: MMatrix, SMatrix, SArray, SVector, pushfirst
1010

11-
import ChangesOfVariables
12-
import InverseFunctions
13-
1411
include("utilities.jl")
1512
include("generic.jl")
1613
include("scalar.jl")

src/generic.jl

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -143,8 +143,6 @@ end
143143

144144
transform(t) = CallableTransform(t)
145145

146-
ChangesOfVariables.with_logabsdet_jacobian(f::CallableTransform, x) = transform_and_logjac(f.t, x)
147-
148146
"""
149147
$(TYPEDEF)
150148
@@ -160,12 +158,8 @@ function Base.show(io::IO, f::CallableInverse)
160158
end
161159

162160
inverse(f::CallableInverse) = CallableTransform(f.t)
163-
InverseFunctions.inverse(f::CallableInverse) = CallableTransform(f.t)
164161

165162
inverse(f::CallableTransform) = CallableInverse(f.t)
166-
InverseFunctions.inverse(f::CallableTransform) = CallableInverse(f.t)
167-
168-
ChangesOfVariables.with_logabsdet_jacobian(f::CallableInverse, x) = inverse_and_logjac(f.t, x)
169163

170164
"""
171165
`$(FUNCTIONNAME)(t, y)`

0 commit comments

Comments
 (0)