Skip to content

Commit 255ad78

Browse files
authored
Use |> for moving data to devices (#1559)
* Use `|>` for moving data to devices It's more visually suggestive of the data being moved. Also, in an assignment expression, the main data piece/function call is closer to the equal sign rather than the device call. * Run JuliaFormatter over examples/
1 parent 15e62d3 commit 255ad78

File tree

11 files changed

+38
-34
lines changed

11 files changed

+38
-34
lines changed

examples/ConvolutionalVAE/main.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -256,14 +256,14 @@ function main(;
256256
Random.seed!(rng, seed)
257257

258258
cvae = CVAE(rng; num_latent_dims, image_shape=(image_size..., 1), max_num_filters)
259-
ps, st = xdev(Lux.setup(rng, cvae))
259+
ps, st = Lux.setup(rng, cvae) |> xdev
260260

261261
z = xdev(randn(Float32, num_latent_dims, num_samples))
262262
decode_compiled = @compile decode(cvae, z, ps, Lux.testmode(st))
263-
x = xdev(randn(Float32, image_size..., 1, batchsize))
263+
x = randn(Float32, image_size..., 1, batchsize) |> xdev
264264
cvae_compiled = @compile cvae(x, ps, Lux.testmode(st))
265265

266-
train_dataloader = xdev(loadmnist(batchsize, image_size))
266+
train_dataloader = loadmnist(batchsize, image_size) |> xdev
267267

268268
opt = AdamW(; eta=learning_rate, lambda=weight_decay)
269269

examples/DDIM/main.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -501,7 +501,7 @@ function Base.getindex(ds::FlowersDataset, idxs)
501501
imgs = Array{Float32,4}(undef, ds.image_size..., 3, length(idxs))
502502
tforeach(1:length(idxs)) do i
503503
img = Image(load(ds.image_files[idxs[i]]))
504-
copyto!(view(imgs, :, :, :, i), itemdata(apply(ds.transform, img)))
504+
return copyto!(view(imgs, :, :, :, i), itemdata(apply(ds.transform, img)))
505505
end
506506
return imgs
507507
end

examples/GCN_Cora/main.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,10 +91,10 @@ function main(;
9191
rng = Random.default_rng()
9292
Random.seed!(rng, 0)
9393

94-
features, targets, adj, (train_idx, val_idx, test_idx) = xdev(loadcora())
94+
features, targets, adj, (train_idx, val_idx, test_idx) = loadcora() |> xdev
9595

9696
gcn = GCN(size(features, 1), hidden_dim, size(targets, 1); nb_layers, dropout, use_bias)
97-
ps, st = xdev(Lux.setup(rng, gcn))
97+
ps, st = Lux.setup(rng, gcn) |> xdev
9898
opt = iszero(weight_decay) ? Adam(lr) : AdamW(; eta=lr, lambda=weight_decay)
9999

100100
train_state = Training.TrainState(gcn, ps, st, opt)

examples/HyperNet/main.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,9 @@ function accuracy(model, ps, st, dataloader, data_idx)
9797
cdev = cpu_device()
9898
st = Lux.testmode(st)
9999
for (x, y) in dataloader
100-
target_class = onecold(cdev(y))
101-
predicted_class = onecold(cdev(first(model((data_idx, x), ps, st))))
100+
ŷ, _ = model((data_idx, x), ps, st)
101+
target_class = y |> cdev |> onecold
102+
predicted_class = ŷ |> cdev |> onecold
102103
total_correct += sum(target_class .== predicted_class)
103104
total += length(target_class)
104105
end
@@ -111,10 +112,10 @@ function train()
111112
dev = reactant_device(; force=true)
112113

113114
model = create_model()
114-
dataloaders = dev(load_datasets())
115+
dataloaders = load_datasets() |> dev
115116

116117
Random.seed!(1234)
117-
ps, st = dev(Lux.setup(Random.default_rng(), model))
118+
ps, st = Lux.setup(Random.default_rng(), model) |> dev
118119

119120
train_state = Training.TrainState(model, ps, st, Adam(0.0003f0))
120121

examples/ImageNet/main.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ function construct_model(;
226226
rng::AbstractRNG, model_name::String, model_args, pretrained::Bool=false
227227
)
228228
model = getproperty(Vision, Symbol(model_name))(model_args...; pretrained)
229-
ps, st = gdev(Lux.setup(rng, model))
229+
ps, st = Lux.setup(rng, model) |> gdev
230230

231231
sensible_println("=> model `$(model_name)` created.")
232232
pretrained && sensible_println("==> using pre-trained model`")
@@ -549,7 +549,7 @@ Comonicon.@main function main(;
549549

550550
ckpt = load_checkpoint(rpath)
551551
if !isnothing(ckpt)
552-
ps, st = gdev((ckpt.ps, ckpt.st))
552+
ps, st = (ckpt.ps, ckpt.st) |> gdev
553553
initial_step = ckpt.step
554554
sensible_println("=> training started from $(initial_step)")
555555
else

examples/NeuralODE/main.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,11 @@ function create_model(
128128
Random.seed!(rng, 0)
129129

130130
ps, st = Lux.setup(rng, model)
131-
ps = dev((use_named_tuple ? ps : ComponentArray(ps)))
132-
st = dev(st)
131+
if !use_named_tuple
132+
ps = ComponentArray(ps)
133+
end
134+
ps = ps |> dev
135+
st = st |> dev
133136

134137
return model, ps, st
135138
end
@@ -239,7 +242,7 @@ model, ps, st = create_model(NeuralODE)
239242

240243
model_stateful, ps_stateful, st_stateful = create_model(StatefulNeuralODE)
241244

242-
x = gpu_device()(ones(Float32, 28, 28, 1, 3));
245+
x = ones(Float32, 28, 28, 1, 3) |> gpu_device();
243246

244247
# NeuralODE is not type stable due to the boxing of `st`
245248

examples/OptimizationIntegration/main.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ Base.length(t::TimeWrapper) = length(t.t)
7878

7979
Base.getindex(t::TimeWrapper, i) = TimeWrapper(t.t[i])
8080

81-
dataloader = gdev(DataLoader((ode_data, TimeWrapper(t)); batchsize=8))
81+
dataloader = DataLoader((ode_data, TimeWrapper(t)); batchsize=8) |> gdev
8282
nothing #hide
8383

8484
# ## Training the model
@@ -101,8 +101,8 @@ function train_model(dataloader)
101101
model = Chain(Dense(2, 32, tanh), Dense(32, 32, tanh), Dense(32, 2))
102102
ps, st = Lux.setup(Random.default_rng(), model)
103103

104-
ps_ca = gdev(ComponentArray(ps))
105-
st = gdev(st)
104+
ps_ca = ComponentArray(ps) |> gdev
105+
st = st |> gdev
106106

107107
function callback(state, l)
108108
if state.iter == 1 || state.iter % 25 == 0
@@ -153,7 +153,7 @@ nothing #hide
153153
dudt(u, p, t) = trained_model(u, p)
154154
prob = ODEProblem(dudt, gdev(u0), (tspan[1], tspan[2]), trained_model.ps)
155155
sol = solve(prob, Tsit5(); saveat=t)
156-
pred = cdev(convert(AbstractArray, sol))
156+
pred = convert(AbstractArray, sol) |> cdev
157157

158158
begin
159159
fig = Figure()

examples/PolynomialFitting/main.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ const loss_function = MSELoss()
6969
const cdev = cpu_device()
7070
const xdev = reactant_device()
7171

72-
ps, st = xdev(Lux.setup(rng, model))
72+
ps, st = Lux.setup(rng, model) |> xdev
7373

7474
# ## Training
7575

@@ -104,11 +104,10 @@ forward_pass = @compile Lux.apply(
104104
tstate.model, xdev(x), tstate.parameters, Lux.testmode(tstate.states)
105105
)
106106

107-
y_pred = cdev(
108-
first(
109-
forward_pass(tstate.model, xdev(x), tstate.parameters, Lux.testmode(tstate.states))
110-
),
111-
)
107+
y_pred =
108+
forward_pass(tstate.model, xdev(x), tstate.parameters, Lux.testmode(tstate.states)) |>
109+
first |>
110+
cdev
112111
nothing #hide
113112

114113
# Let's plot the results

examples/RealNVP/main.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -223,12 +223,13 @@ function main(;
223223
rng = Random.default_rng()
224224
Random.seed!(rng, 0)
225225

226-
dataloader = Iterators.cycle(
227-
xdev(load_moons_dataloader(rng, Float32, n_train_samples; batchsize, noise))
228-
)
226+
dataloader =
227+
load_moons_dataloader(rng, Float32, n_train_samples; batchsize, noise) |>
228+
xdev |>
229+
Iterators.cycle
229230

230231
model = RealNVP(; n_transforms, dist_dims=2, hidden_dims, n_layers)
231-
ps, st = xdev(Lux.setup(rng, model))
232+
ps, st = Lux.setup(rng, model) |> xdev
232233
opt = Adam(lr)
233234

234235
train_state = Training.TrainState(model, ps, st, opt)

examples/SimpleChains/main.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,8 @@ end
7474

7575
# ## Define the Training Loop
7676
function train(model, dev=cpu_device(); rng=Random.default_rng(), kwargs...)
77-
train_dataloader, test_dataloader = dev(loadmnist(128, 0.9))
78-
ps, st = dev(Lux.setup(rng, model))
77+
train_dataloader, test_dataloader = loadmnist(128, 0.9) |> dev
78+
ps, st = Lux.setup(rng, model) |> dev
7979

8080
vjp = dev isa ReactantDevice ? AutoEnzyme() : AutoZygote()
8181

0 commit comments

Comments
 (0)