Skip to content

Commit 32b325b

Browse files
feat: don't collect FillArrays for reactant (#1471)
* feat: don't collect FillArrays for reactant * Update lib/MLDataDevices/src/internal.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fix: missing input * fix: reactant device merging --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 48822f4 commit 32b325b

File tree

6 files changed

+118
-58
lines changed

6 files changed

+118
-58
lines changed

lib/MLDataDevices/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLDataDevices"
22
uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
33
authors = ["Avik Pal <[email protected]> and contributors"]
4-
version = "1.11.3"
4+
version = "1.12.0"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,33 @@
11
module MLDataDevicesFillArraysExt
22

33
using Adapt: Adapt
4-
using FillArrays: AbstractFill
5-
using MLDataDevices: CPUDevice, AbstractDevice, Internal
4+
using FillArrays: AbstractFill, OneElement, Fill, Ones, Zeros
5+
using MLDataDevices: CPUDevice, ReactantDevice, AbstractDevice, Internal
66

77
Adapt.adapt_structure(::CPUDevice, x::AbstractFill) = x
8+
function Adapt.adapt_structure(dev::CPUDevice, x::Ones{T}) where {T}
9+
return Ones{Adapt.adapt(dev, T)}(axes(x))
10+
end
11+
function Adapt.adapt_structure(dev::CPUDevice, x::Zeros{T}) where {T}
12+
return Zeros{Adapt.adapt(dev, T)}(axes(x))
13+
end
14+
Adapt.adapt_structure(dev::CPUDevice, x::Fill) = Fill(Adapt.adapt(dev, x.value), axes(x))
15+
function Adapt.adapt_structure(dev::CPUDevice, x::OneElement)
16+
return OneElement(Adapt.adapt(dev, x.val), x.ind, x.axes)
17+
end
18+
19+
Adapt.adapt_structure(dev::ReactantDevice, x::AbstractFill) = Internal.to_rarray(dev, x)
20+
Adapt.adapt_structure(dev::ReactantDevice, x::OneElement) = Internal.to_rarray(dev, x)
21+
822
Adapt.adapt_structure(to::AbstractDevice, x::AbstractFill) = Adapt.adapt(to, collect(x))
23+
Adapt.adapt_structure(to::AbstractDevice, x::OneElement) = Adapt.adapt(to, collect(x))
24+
25+
Internal.get_device(::AbstractFill{T}) where {T} = Internal.get_device(T)
26+
Internal.get_device(f::Fill) = Internal.get_device(f.value)
27+
Internal.get_device(e::OneElement) = Internal.get_device(e.val)
928

10-
Internal.get_device(::AbstractFill) = CPUDevice()
11-
Internal.get_device_type(::AbstractFill) = CPUDevice
29+
Internal.get_device_type(::AbstractFill{T}) where {T} = Internal.get_device_type(T)
30+
Internal.get_device_type(f::Fill) = Internal.get_device_type(f.value)
31+
Internal.get_device_type(e::OneElement) = Internal.get_device_type(e.val)
1232

1333
end

lib/MLDataDevices/ext/MLDataDevicesReactantExt.jl

Lines changed: 48 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -27,23 +27,45 @@ function Reactant.make_tracer(
2727
return prev
2828
end
2929

30+
# Call into Reactant.to_rarray
31+
function device_to_kwargs(dev::ReactantDevice, x)
32+
kwargs = (;)
33+
dev.client === missing || (kwargs = (; kwargs..., client=dev.client))
34+
dev.device === missing || (kwargs = (; kwargs..., device=dev.device))
35+
if dev.sharding !== missing
36+
if dev.sharding isa IdDict
37+
sharding = (
38+
haskey(dev.sharding, x) ? dev.sharding[x] : Reactant.Sharding.NoSharding()
39+
)
40+
@assert sharding isa Reactant.Sharding.AbstractSharding
41+
kwargs = (; kwargs..., sharding)
42+
elseif dev.sharding isa Reactant.Sharding.AbstractSharding
43+
kwargs = (; kwargs..., dev.sharding)
44+
else
45+
throw(ArgumentError("`sharding` must be an `IdDict` or a \
46+
`Reactant.Sharding.AbstractSharding` but got \
47+
$(typeof(dev.sharding))."))
48+
end
49+
end
50+
return kwargs
51+
end
52+
53+
function Internal.to_rarray_internal(dev::ReactantDevice, x)
54+
return Reactant.to_rarray(x; device_to_kwargs(dev, x)...)
55+
end
56+
3057
# Default RNG
3158
MLDataDevices.default_device_rng(::ReactantDevice) = Reactant.TracedRandom.default_rng()
3259

3360
# Query Device from Array
34-
@static if isdefined(Reactant, :ConcreteIFRTArray)
35-
const AllConcreteTypes = Union{
36-
Reactant.ConcreteIFRTNumber,
37-
Reactant.ConcreteIFRTArray,
38-
Reactant.ConcretePJRTNumber,
39-
Reactant.ConcretePJRTArray,
40-
}
41-
elseif isdefined(Reactant, :ConcretePJRTArray)
42-
const AllConcreteTypes = Union{Reactant.ConcretePJRTNumber,Reactant.ConcretePJRTArray}
43-
else
44-
const AllConcreteTypes = Union{ConcreteRNumber,ConcreteRArray}
45-
end
61+
const AllConcreteTypes = Union{
62+
<:Reactant.ConcreteIFRTNumber,
63+
<:Reactant.ConcreteIFRTArray,
64+
<:Reactant.ConcretePJRTNumber,
65+
<:Reactant.ConcretePJRTArray,
66+
}
4667

68+
Internal.get_device(::Type{<:AllConcreteTypes}) = ReactantDevice()
4769
function Internal.get_device(x::AllConcreteTypes)
4870
return ReactantDevice(
4971
Reactant.XLA.client(x),
@@ -53,6 +75,7 @@ function Internal.get_device(x::AllConcreteTypes)
5375
),
5476
)
5577
end
78+
Internal.get_device_type(::Type{<:AllConcreteTypes}) = ReactantDevice
5679
Internal.get_device_type(::AllConcreteTypes) = ReactantDevice
5780

5881
function Internal.get_device(::Union{TracedRArray,TracedRNumber})
@@ -69,26 +92,21 @@ Internal.unsafe_free_internal!(::Type{ReactantDevice}, x::AbstractArray) = nothi
6992
Profiler.@annotate "Device Transfer (Reactant)" function Adapt.adapt_storage(
7093
dev::ReactantDevice, x::AbstractArray
7194
)
72-
kwargs = (;)
73-
dev.client === missing || (kwargs = (; kwargs..., client=dev.client))
74-
dev.device === missing || (kwargs = (; kwargs..., device=dev.device))
75-
if dev.sharding !== missing
76-
if dev.sharding isa IdDict
77-
sharding =
78-
haskey(dev.sharding, x) ? dev.sharding[x] : Reactant.Sharding.NoSharding()
79-
@assert sharding isa Reactant.Sharding.AbstractSharding
80-
kwargs = (; kwargs..., sharding)
81-
elseif dev.sharding isa Reactant.Sharding.AbstractSharding
82-
kwargs = (; kwargs..., dev.sharding)
83-
else
84-
throw(ArgumentError("`sharding` must be an `IdDict` or a \
85-
`Reactant.Sharding.AbstractSharding` but got \
86-
$(typeof(dev.sharding))."))
87-
end
88-
end
89-
return ConcreteRArray(x; kwargs...)
95+
return ConcreteRArray(x; device_to_kwargs(dev, x)...)
9096
end
9197

98+
function Adapt.adapt_storage(
99+
::CPUDevice,
100+
T::Type{<:Union{<:Reactant.ConcretePJRTNumber,<:Reactant.ConcreteIFRTNumber}},
101+
)
102+
return Reactant.unwrapped_eltype(T)
103+
end
104+
105+
function Adapt.adapt_storage(
106+
::CPUDevice, x::Union{<:Reactant.ConcretePJRTNumber,<:Reactant.ConcreteIFRTNumber}
107+
)
108+
return Reactant.unwrapped_eltype(x)(x)
109+
end
92110
Adapt.adapt_storage(::CPUDevice, ::Reactant.ReactantRNG) = Random.default_rng()
93111

94112
end

lib/MLDataDevices/src/internal.jl

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -149,11 +149,23 @@ combine_devices(::AbstractDevice, dev::ReactantDevice) = dev
149149
function combine_devices(dev1::ReactantDevice, dev2::ReactantDevice)
150150
if dev1 == dev2
151151
# `merge(...)` of `IdDict` constructs a `Dict`
152-
sharding = dev1.sharding
153-
for (k, v) in dev2.sharding
154-
sharding[k] = v
152+
if dev1.sharding isa IdDict
153+
sharding = dev1.sharding
154+
if dev2.sharding isa IdDict
155+
for (k, v) in dev2.sharding
156+
sharding[k] = v
157+
end
158+
end
159+
elseif dev2.sharding isa IdDict
160+
sharding = dev2.sharding
161+
else
162+
sharding = missing
155163
end
156-
return ReactantDevice(dev1.client, dev1.device, sharding)
164+
165+
client = dev1.client === missing ? dev2.client : dev1.client
166+
device = dev1.device === missing ? dev2.device : dev1.device
167+
168+
return ReactantDevice(client, device, sharding)
157169
end
158170
throw(ArgumentError("Objects are on different devices: $(dev1) and $(dev2)."))
159171
end
@@ -211,6 +223,9 @@ for op in (:get_device, :get_device_type)
211223
end
212224
end
213225

226+
get_device(::Type{<:Number}) = CPUDevice()
227+
get_device_type(::Type{<:Number}) = CPUDevice
228+
214229
get_device(_) = UnknownDevice()
215230
get_device_type(_) = UnknownDevice
216231

@@ -263,4 +278,11 @@ end
263278

264279
static_length(t::Tuple) = Val(length(t))
265280

281+
function to_rarray(args...; kwargs...)
282+
loaded(ReactantDevice) && return to_rarray_internal(args...; kwargs...)
283+
return error("`to_rarray` is only supported with `Reactant` loaded.")
284+
end
285+
286+
function to_rarray_internal end
287+
266288
end

lib/MLDataDevices/src/public.jl

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,14 @@ struct oneAPIDevice <: AbstractGPUDevice end
1616
end
1717

1818
function Base.:(==)(x::ReactantDevice, y::ReactantDevice)
19-
if x.client !== missing
20-
y.client === missing && return false
21-
x.client.client != y.client.client && return false
22-
else
23-
y.client !== missing && return false
19+
if x.client !== missing && y.client !== missing && x.client.client != y.client.client
20+
return false
2421
end
25-
if x.device !== missing
26-
y.device === missing && return false
27-
x.device.device != y.device.device && return false
28-
else
29-
y.device !== missing && return false
22+
23+
if x.device !== missing && y.device !== missing && x.device.device != y.device.device
24+
return false
3025
end
26+
3127
return true
3228
end
3329

lib/MLDataDevices/test/reactant_tests.jl

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,16 @@ using FillArrays, Zygote # Extensions
4141
rng=MersenneTwister(),
4242
one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)),
4343
farray=Fill(1.0f0, (2, 3)),
44+
one_elem2=FillArrays.OneElement(2.0f0, (2, 3), (1:3, 1:4)),
45+
zeros_fa=Zeros{Float32}((2, 3)),
46+
ones_fa=Ones{Float32}((2, 3)),
4447
)
4548

4649
device = reactant_device()
4750
aType = MLDataDevices.functional(ReactantDevice) ? Reactant.ConcreteRArray : Array
48-
rngType =
51+
rngType = (
4952
MLDataDevices.functional(ReactantDevice) ? Reactant.ReactantRNG : Random.AbstractRNG
53+
)
5054

5155
ps_xpu = device(ps)
5256
@test get_device(ps_xpu) isa ReactantDevice
@@ -70,10 +74,10 @@ using FillArrays, Zygote # Extensions
7074

7175
if MLDataDevices.functional(ReactantDevice)
7276
@test ps_xpu.one_elem isa Reactant.RArray
73-
@test ps_xpu.farray isa Reactant.RArray
74-
else
75-
@test ps_xpu.one_elem isa Zygote.OneElement
76-
@test ps_xpu.farray isa Fill
77+
@test ps_xpu.farray isa Fill{<:Reactant.ConcreteRNumber}
78+
@test ps_xpu.one_elem2 isa FillArrays.OneElement{<:Reactant.ConcreteRNumber}
79+
@test ps_xpu.zeros_fa isa Zeros{<:Reactant.ConcreteRNumber}
80+
@test ps_xpu.ones_fa isa Ones{<:Reactant.ConcreteRNumber}
7781
end
7882

7983
ps_cpu = cpu_device()(ps_xpu)
@@ -100,10 +104,10 @@ using FillArrays, Zygote # Extensions
100104

101105
if MLDataDevices.functional(ReactantDevice)
102106
@test ps_cpu.one_elem isa Array
103-
@test ps_cpu.farray isa Array
104-
else
105-
@test ps_cpu.one_elem isa Zygote.OneElement
106-
@test ps_cpu.farray isa Fill
107+
@test ps_cpu.farray isa Fill{Float32}
108+
@test ps_cpu.one_elem2 isa FillArrays.OneElement{Float32}
109+
@test ps_cpu.zeros_fa isa Zeros{Float32}
110+
@test ps_cpu.ones_fa isa Ones{Float32}
107111
end
108112

109113
ps_mixed = (; a=rand(2), b=device(rand(2)))

0 commit comments

Comments
 (0)