Skip to content

Commit 1294a0b

Browse files
committed
More AxisArray constructors
- also in the background @defdim is broken up into smaller macros
1 parent c149b00 commit 1294a0b

File tree

8 files changed

+233
-123
lines changed

8 files changed

+233
-123
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "AxisIndices"
22
uuid = "f52c9ee2-1b1c-4fd8-8546-6350938c7f11"
33
authors = ["Tokazama <[email protected]>"]
4-
version = "0.4.4"
4+
version = "0.4.5"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/Arrays/AbstractAxisArray.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,3 +325,10 @@ Base.vcat(A::AbstractAxisArray{T,N}) where {T,N} = A
325325

326326
Base.cat(A::AbstractAxisArray{T,N}; dims) where {T,N} = A
327327

328+
function Base.convert(::Type{T}, A::AbstractArray) where {T<:AbstractAxisArray}
329+
if A isa T
330+
return A
331+
else
332+
return T(A)
333+
end
334+
end

src/Arrays/AxisArray.jl

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -225,17 +225,15 @@ julia> size(AxisArray{Int,2}(undef, (["a", "b"], [:one, :two])))
225225
(2, 2)
226226
227227
"""
228-
function AxisArray{T,N}(x::AbstractArray{T2,N}, axis_keys::Tuple, check_length::Bool=true) where {T,T2,N}
229-
return AxisArray{T,N}(map(T, x), axis_keys, check_length)
228+
function AxisArray{T,N}(A::AbstractArray{T2,N}, axis_keys::Tuple, check_length::Bool=true) where {T,T2,N}
229+
return AxisArray{T,N}(copyto!(Array{T}(undef, size(A)), A), axis_keys, check_length)
230230
end
231231

232232
function AxisArray{T,N}(init::ArrayInitializer, args...) where {T,N}
233233
return AxisArray{T,N}(init, args)
234234
end
235235

236-
function AxisArray{T,N}(x::AbstractArray, args...) where {T,N}
237-
return AxisArray{T,N}(x, args)
238-
end
236+
AxisArray{T,N}(x::AbstractArray, args...) where {T,N} = AxisArray{T,N}(x, args)
239237

240238
function AxisArray{T,N}(init::ArrayInitializer, axs::Tuple{Vararg{Any,N}}) where {T,N}
241239
return AxisArray{T,N}(init, map(to_axis, axs))
@@ -246,14 +244,27 @@ function AxisArray{T,N}(init::ArrayInitializer, axs::AbstractAxes{N}) where {T,N
246244
return AxisArray{T,N,typeof(p),typeof(axs)}(p, axs)
247245
end
248246

249-
function AxisArray{T,N}(
250-
x::AbstractArray{T,N},
251-
axs::Tuple{Vararg{Any,N2}},
247+
function AxisArray{T,N}(x::AbstractArray{T,N}, axs::Tuple, check_length::Bool=true) where {T,N}
248+
return AxisArray{T,N,typeof(x)}(x, axs, check_length)
249+
end
250+
251+
###
252+
### AxisArray{T,N,P}
253+
###
254+
AxisArray{T,N,P}(A::AbstractArray, args...) where {T,N,P} = AxisArray{T,N,P}(A, args)
255+
256+
function AxisArray{T,N,P}(
257+
x::AbstractArray,
258+
axs::Tuple,
252259
check_length::Bool=true
253-
) where {T,N,N2}
260+
) where {T,N,P}
261+
262+
return AxisArray{T,N,P}(convert(P, x), axs, check_length)
263+
end
254264

265+
function AxisArray{T,N,P}(x::P, axs::Tuple, check_length::Bool=true) where {T,N,P<:AbstractArray{T,N}}
255266
axs = to_axes((), axs, axes(x), check_length, Staticness(x))
256-
return AxisArray{T,N,typeof(x),typeof(axs)}(x, axs)
267+
return AxisArray{T,N,P,typeof(axs)}(x, axs)
257268
end
258269

259270
###
@@ -301,3 +312,8 @@ function Base.reshape(A::AbstractArray, shp::Tuple{<:AbstractAxis,Vararg{<:Abstr
301312
return AxisArray{eltype(p),ndims(p),typeof(p),typeof(axs)}(p, axs)
302313
end
303314

315+
AxisArray{T,N,P}(A::AxisArray{T,N,P}) where {T,N,P} = A
316+
317+
function AxisArray{T,N,P}(A::AxisArray) where {T,N,P}
318+
return AxisArray{T,N,P}(convert(P, parent(A)), axes(A))
319+
end

src/Arrays/OffsetArray.jl

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,28 @@ end
3737

3838
OffsetArray{T,N}(A, inds::Vararg) where {T,N} = OffsetArray{T,N}(A, inds)
3939

40-
@inline function OffsetArray{T,N}(A::AbstractArray{T,N}, inds::NTuple{M,Any}) where {T,N,M}
40+
function OffsetArray{T,N}(A::AbstractArray{T,N}, inds::Tuple) where {T,N}
41+
return OffsetArray{T,N,typeof(A)}(A, inds)
42+
end
43+
44+
function OffsetArray{T,N}(A::AbstractArray{T2,N}, inds::Tuple) where {T,T2,N}
45+
return OffsetArray{T,N}(copyto!(Array{T}(undef, size(A)), A), inds)
46+
end
47+
48+
49+
function OffsetArray{T,N,P}(A::AbstractArray, inds::NTuple{M,Any}) where {T,N,P<:AbstractArray{T,N},M}
50+
return OffsetArray{T,N,P}(convert(P, A))
51+
end
52+
53+
OffsetArray{T,N,P}(A::OffsetArray{T,N,P}) where {T,N,P} = A
54+
55+
function OffsetArray{T,N,P}(A::OffsetArray) where {T,N,P}
56+
p = convert(P, parent(A))
57+
axs = map(assign_indices, axes(A), axes(p))
58+
return AxisArray{T,N,P,typeof(axs)}(p, axs)
59+
end
60+
61+
function OffsetArray{T,N,P}(A::P, inds::NTuple{M,Any}) where {T,N,P<:AbstractArray{T,N},M}
4162
S = Staticness(A)
4263
if N === M
4364
axs = map((x, y) -> OffsetAxis(as_staticness(S, x), as_staticness(S, y)), inds, axes(A))

src/Interface/names.jl

Lines changed: 117 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,94 @@ end
6161
# So `condition` also has to be pure, which shouldn't be hard because it should basically
6262
# just be comparing symbols
6363

64+
macro def_naxis(name, name_dim)
65+
nname = Symbol(:n, name)
66+
nname_doc = """
67+
$nname(x) -> Int
68+
69+
Returns the size along the dimension corresponding to the $name.
70+
"""
71+
72+
esc(quote
73+
@doc $nname_doc
74+
@inline $nname(x) = Base.size(x, $name_dim(x))
75+
end)
76+
end
77+
78+
79+
macro def_axis_keys(name, name_dims)
80+
name_keys = Symbol(name, :_keys)
81+
name_keys_doc = """
82+
$name_keys(x)
83+
84+
Returns the keys corresponding to the $name axis
85+
"""
86+
esc(quote
87+
@doc $name_keys_doc
88+
@inline $name_keys(x) = keys(axes(x, $name_dims(x)))
89+
end)
90+
end
91+
92+
93+
94+
macro def_axis_indices(name, name_dim)
95+
name_indices = Symbol(name, :_indices)
96+
name_indices_doc = """
97+
$name_indices(x)
98+
99+
Returns the indices corresponding to the $name axis
100+
"""
101+
esc(quote
102+
@doc $name_indices_doc
103+
@inline $name_indices(x) = indices(axes(x, $name_dim(x)))
104+
end)
105+
end
106+
107+
108+
# TODO I'm not sure this is the best name for this one
109+
macro def_axis_type(name, name_dim)
110+
name_type = Symbol(name, :_axis_type)
111+
name_type_doc = """
112+
$name_type(x)
113+
114+
Returns the key type corresponding to the $name axis.
115+
"""
116+
117+
esc(quote
118+
@doc $name_type_doc
119+
@inline $name_type(x) = keytype(axes(x, $name_dim(x)))
120+
end)
121+
end
122+
123+
macro def_selectdim(name, name_dim)
124+
name_selectdim = Symbol(:select_, name, :dim)
125+
name_selectdim_doc = """
126+
$name_selectdim(x, i)
127+
128+
Return a view of all the data of `x` where the index for the $name dimension equals `i`.
129+
"""
130+
131+
esc(quote
132+
@doc $name_selectdim_doc
133+
@inline $name_selectdim(x, i) = selectdim(x, $name_dim(x), i)
134+
135+
end)
136+
end
137+
138+
macro def_eachslice(name, name_dim)
139+
each_name = Symbol(:each_, name)
140+
each_name_doc = """
141+
$each_name(x)
142+
143+
Create a generator that iterates over the $name dimensions `A`, returning views that select
144+
all the data from the other dimensions in `A`.
145+
"""
146+
esc(quote
147+
@doc $each_name_doc
148+
@inline $each_name(x) = eachslice(x, dims=$name_dim(x))
149+
end)
150+
end
151+
64152
"""
65153
@defdim name condition
66154
@@ -92,7 +180,16 @@ julia> @defdim time is_time
92180
`@defdim` should be considered experimental and subject to change
93181
94182
"""
95-
macro defdim(name, condition)
183+
macro defdim(
184+
name,
185+
condition,
186+
def_naxis::Bool=true,
187+
def_axis_keys::Bool=true,
188+
def_axis_indices::Bool=true,
189+
def_axis_type::Bool=true,
190+
def_selectdim::Bool=true,
191+
def_eachslice::Bool=true,
192+
)
96193

97194
dim_noerror_name = Symbol(:dim_noerror_, name)
98195

@@ -103,13 +200,6 @@ macro defdim(name, condition)
103200
Returns the dimension corresponding to $name.
104201
"""
105202

106-
nname = Symbol(:n, name)
107-
nname_doc = """
108-
$nname(x) -> Int
109-
110-
Returns the size along the dimension corresponding to the $name.
111-
"""
112-
113203
has_name_dim = Symbol(:has_, name, :dim)
114204
has_name_dim_doc = """
115205
$has_name_dim(x) -> Bool
@@ -130,42 +220,6 @@ macro defdim(name, condition)
130220
Returns an `AxisIterator` along the $name axis.
131221
"""
132222

133-
name_indices = Symbol(name, :_indices)
134-
name_indices_doc = """
135-
$name_indices(x)
136-
137-
Returns the indices corresponding to the $name axis
138-
"""
139-
140-
name_keys = Symbol(name, :_keys)
141-
name_keys_doc = """
142-
$name_keys(x)
143-
144-
Returns the keys corresponding to the $name axis
145-
"""
146-
147-
name_type = Symbol(name, :_axis_type)
148-
name_type_doc = """
149-
$name_type(x)
150-
151-
Returns the key type corresponding to the $name axis.
152-
"""
153-
154-
name_selectdim = Symbol(:select_, name, :dim)
155-
name_selectdim_doc = """
156-
$name_selectdim(x, i)
157-
158-
Return a view of all the data of `x` where the index for the $name dimension equals `i`.
159-
"""
160-
161-
each_name = Symbol(:each_, name)
162-
each_name_doc = """
163-
$each_name(x)
164-
165-
Create a generator that iterates over the $name dimensions `A`, returning views that select
166-
all the data from the other dimensions in `A`.
167-
"""
168-
169223
err_msg = "Method $(Symbol(condition)) is not true for any dimensions of "
170224

171225
esc(quote
@@ -186,9 +240,6 @@ macro defdim(name, condition)
186240
end
187241
end
188242

189-
@doc $nname_doc
190-
@inline $nname(x) = Base.size(x, $name_dim(x))
191-
192243
@doc $has_name_dim_doc
193244
@inline $has_name_dim(x) = !($dim_noerror_name(dimnames(x)) === 0)
194245

@@ -198,22 +249,30 @@ macro defdim(name, condition)
198249
@doc $name_axis_itr
199250
@inline $name_axis(x, sz; kwargs...) = AxisIterator(axes(x, $name_dim(x)), sz; kwargs...)
200251

201-
@doc $name_keys_doc
202-
@inline $name_keys(x) = keys($name_axis(x))
252+
if $def_naxis
253+
Interface.@def_naxis($name, $name_dim)
254+
end
203255

204-
@doc $name_indices_doc
205-
@inline $name_indices(x) = values($name_axis(x))
256+
if $def_axis_keys
257+
Interface.@def_axis_keys($name, $name_dim)
258+
end
206259

207-
@doc $name_type_doc
208-
@inline $name_type(x) = keytype($name_axis(x))
260+
if $def_axis_indices
261+
Interface.@def_axis_indices($name, $name_dim)
262+
end
209263

210-
@doc $name_selectdim_doc
211-
@inline $name_selectdim(x, i) = selectdim(x, $name_dim(x), i)
264+
if $def_axis_type
265+
Interface.@def_axis_type($name, $name_dim)
266+
end
212267

213-
@doc $each_name_doc
214-
@inline $each_name(x) = eachslice(x, dims=$name_dim(x))
268+
if $def_selectdim
269+
Interface.@def_selectdim($name, $name_dim)
270+
end
271+
272+
if $def_eachslice
273+
Interface.@def_eachslice($name, $name_dim)
274+
end
215275

216276
nothing
217277
end)
218278
end
219-

test/Arrays/Arrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11

22
include("vectors.jl")
3+
include("AxisArray.jl")
34

45

56
@testset "permuteddimsview" begin

0 commit comments

Comments
 (0)