@@ -27,23 +27,45 @@ function Reactant.make_tracer(
2727 return prev
2828end
2929
30+ # Call into Reactant.to_rarray
31+ function device_to_kwargs (dev:: ReactantDevice , x)
32+ kwargs = (;)
33+ dev. client === missing || (kwargs = (; kwargs... , client= dev. client))
34+ dev. device === missing || (kwargs = (; kwargs... , device= dev. device))
35+ if dev. sharding != = missing
36+ if dev. sharding isa IdDict
37+ sharding = (
38+ haskey (dev. sharding, x) ? dev. sharding[x] : Reactant. Sharding. NoSharding ()
39+ )
40+ @assert sharding isa Reactant. Sharding. AbstractSharding
41+ kwargs = (; kwargs... , sharding)
42+ elseif dev. sharding isa Reactant. Sharding. AbstractSharding
43+ kwargs = (; kwargs... , dev. sharding)
44+ else
45+ throw (ArgumentError (" `sharding` must be an `IdDict` or a \
46+ `Reactant.Sharding.AbstractSharding` but got \
47+ $(typeof (dev. sharding)) ." ))
48+ end
49+ end
50+ return kwargs
51+ end
52+
53+ function Internal. to_rarray_internal (dev:: ReactantDevice , x)
54+ return Reactant. to_rarray (x; device_to_kwargs (dev, x)... )
55+ end
56+
3057# Default RNG
3158MLDataDevices. default_device_rng (:: ReactantDevice ) = Reactant. TracedRandom. default_rng ()
3259
3360# Query Device from Array
34- @static if isdefined (Reactant, :ConcreteIFRTArray )
35- const AllConcreteTypes = Union{
36- Reactant. ConcreteIFRTNumber,
37- Reactant. ConcreteIFRTArray,
38- Reactant. ConcretePJRTNumber,
39- Reactant. ConcretePJRTArray,
40- }
41- elseif isdefined (Reactant, :ConcretePJRTArray )
42- const AllConcreteTypes = Union{Reactant. ConcretePJRTNumber,Reactant. ConcretePJRTArray}
43- else
44- const AllConcreteTypes = Union{ConcreteRNumber,ConcreteRArray}
45- end
61+ const AllConcreteTypes = Union{
62+ <: Reactant.ConcreteIFRTNumber ,
63+ <: Reactant.ConcreteIFRTArray ,
64+ <: Reactant.ConcretePJRTNumber ,
65+ <: Reactant.ConcretePJRTArray ,
66+ }
4667
68+ Internal. get_device (:: Type{<:AllConcreteTypes} ) = ReactantDevice ()
4769function Internal. get_device (x:: AllConcreteTypes )
4870 return ReactantDevice (
4971 Reactant. XLA. client (x),
@@ -53,6 +75,7 @@ function Internal.get_device(x::AllConcreteTypes)
5375 ),
5476 )
5577end
78+ Internal. get_device_type (:: Type{<:AllConcreteTypes} ) = ReactantDevice
5679Internal. get_device_type (:: AllConcreteTypes ) = ReactantDevice
5780
5881function Internal. get_device (:: Union{TracedRArray,TracedRNumber} )
@@ -69,26 +92,21 @@ Internal.unsafe_free_internal!(::Type{ReactantDevice}, x::AbstractArray) = nothi
6992Profiler. @annotate " Device Transfer (Reactant)" function Adapt. adapt_storage (
7093 dev:: ReactantDevice , x:: AbstractArray
7194)
72- kwargs = (;)
73- dev. client === missing || (kwargs = (; kwargs... , client= dev. client))
74- dev. device === missing || (kwargs = (; kwargs... , device= dev. device))
75- if dev. sharding != = missing
76- if dev. sharding isa IdDict
77- sharding =
78- haskey (dev. sharding, x) ? dev. sharding[x] : Reactant. Sharding. NoSharding ()
79- @assert sharding isa Reactant. Sharding. AbstractSharding
80- kwargs = (; kwargs... , sharding)
81- elseif dev. sharding isa Reactant. Sharding. AbstractSharding
82- kwargs = (; kwargs... , dev. sharding)
83- else
84- throw (ArgumentError (" `sharding` must be an `IdDict` or a \
85- `Reactant.Sharding.AbstractSharding` but got \
86- $(typeof (dev. sharding)) ." ))
87- end
88- end
89- return ConcreteRArray (x; kwargs... )
95+ return ConcreteRArray (x; device_to_kwargs (dev, x)... )
9096end
9197
98+ function Adapt. adapt_storage (
99+ :: CPUDevice ,
100+ T:: Type{<:Union{<:Reactant.ConcretePJRTNumber,<:Reactant.ConcreteIFRTNumber}} ,
101+ )
102+ return Reactant. unwrapped_eltype (T)
103+ end
104+
105+ function Adapt. adapt_storage (
106+ :: CPUDevice , x:: Union{<:Reactant.ConcretePJRTNumber,<:Reactant.ConcreteIFRTNumber}
107+ )
108+ return Reactant. unwrapped_eltype (x)(x)
109+ end
92110Adapt. adapt_storage (:: CPUDevice , :: Reactant.ReactantRNG ) = Random. default_rng ()
93111
94112end
0 commit comments