Skip to content

Commit

Permalink
use LuxCore.apply
Browse files Browse the repository at this point in the history
  • Loading branch information
prbzrg committed Mar 25, 2024
1 parent 392d6d4 commit 8e9320b
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 23 deletions.
2 changes: 1 addition & 1 deletion src/cond_layer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ struct CondLayer{NN <: LuxCore.AbstractExplicitLayer, AT <: AbstractArray} <:
end

@inline function (m::CondLayer)(z::AbstractArray, ps::Any, st::Any)
m.nn(vcat(z, m.ys), ps, st)
LuxCore.apply(m.nn, vcat(z, m.ys), ps, st)

Check warning on line 8 in src/cond_layer.jl

View check run for this annotation

Codecov / codecov/patch

src/cond_layer.jl#L8

Added line #L8 was not covered by tests
end
44 changes: 22 additions & 22 deletions src/icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ function augmented_f(
ż, J = AbstractDifferentiation.value_and_jacobian(
icnf.differentiation_backend,
let p = p
x -> nn(x, p)
x -> LuxCore.apply(nn, x, p)

Check warning on line 87 in src/icnf.jl

View check run for this annotation

Codecov / codecov/patch

src/icnf.jl#L87

Added line #L87 was not covered by tests
end,
z,
)
Expand All @@ -107,7 +107,7 @@ function augmented_f(
ż, J = AbstractDifferentiation.value_and_jacobian(
icnf.differentiation_backend,
let p = p
x -> nn(x, p)
x -> LuxCore.apply(nn, x, p)

Check warning on line 110 in src/icnf.jl

View check run for this annotation

Codecov / codecov/patch

src/icnf.jl#L110

Added line #L110 was not covered by tests
end,
z,
)
Expand All @@ -128,7 +128,7 @@ function augmented_f(
n_aug = n_augment(icnf, mode)
z = u[begin:(end - n_aug - 1)]
ż, J = Zygote.withjacobian(let p = p
x -> nn(x, p)
x -> LuxCore.apply(nn, x, p)

Check warning on line 131 in src/icnf.jl

View check run for this annotation

Codecov / codecov/patch

src/icnf.jl#L131

Added line #L131 was not covered by tests
end, z)
= -tr(only(J))
vcat(ż, l̇)
Expand All @@ -147,7 +147,7 @@ function augmented_f(
n_aug = n_augment(icnf, mode)
z = u[begin:(end - n_aug - 1)]
ż, J = Zygote.withjacobian(let p = p
x -> nn(x, p)
x -> LuxCore.apply(nn, x, p)

Check warning on line 150 in src/icnf.jl

View check run for this annotation

Codecov / codecov/patch

src/icnf.jl#L150

Added line #L150 was not covered by tests
end, z)
du[begin:(end - n_aug - 1)] .=
du[(end - n_aug)] = -tr(only(J))
Expand All @@ -166,7 +166,7 @@ function augmented_f(
n_aug = n_augment(icnf, mode)
z = u[begin:(end - n_aug - 1), :]
ż, J = jacobian_batched(icnf, let p = p
x -> nn(x, p)
x -> LuxCore.apply(nn, x, p)

Check warning on line 169 in src/icnf.jl

View check run for this annotation

Codecov / codecov/patch

src/icnf.jl#L169

Added line #L169 was not covered by tests
end, z)
= -transpose(tr.(eachslice(J; dims = 3)))
vcat(ż, l̇)
Expand All @@ -185,7 +185,7 @@ function augmented_f(
n_aug = n_augment(icnf, mode)
z = u[begin:(end - n_aug - 1), :]
ż, J = jacobian_batched(icnf, let p = p
x -> nn(x, p)
x -> LuxCore.apply(nn, x, p)

Check warning on line 188 in src/icnf.jl

View check run for this annotation

Codecov / codecov/patch

src/icnf.jl#L188

Added line #L188 was not covered by tests
end, z)
du[begin:(end - n_aug - 1), :] .=
du[(end - n_aug), :] .= -(tr.(eachslice(J; dims = 3)))
Expand Down Expand Up @@ -220,7 +220,7 @@ function augmented_f(
ż, VJ = AbstractDifferentiation.value_and_pullback_function(
icnf.differentiation_backend,
let p = p
x -> nn(x, p)
x -> LuxCore.apply(nn, x, p)

Check warning on line 223 in src/icnf.jl

View check run for this annotation

Codecov / codecov/patch

src/icnf.jl#L223

Added line #L223 was not covered by tests
end,
z,
)
Expand Down Expand Up @@ -273,7 +273,7 @@ function augmented_f(
ż, VJ = AbstractDifferentiation.value_and_pullback_function(
icnf.differentiation_backend,
let p = p
x -> nn(x, p)
x -> LuxCore.apply(nn, x, p)

Check warning on line 276 in src/icnf.jl

View check run for this annotation

Codecov / codecov/patch

src/icnf.jl#L276

Added line #L276 was not covered by tests
end,
z,
)
Expand Down Expand Up @@ -326,7 +326,7 @@ function augmented_f(
ż_JV = AbstractDifferentiation.value_and_pushforward_function(
icnf.differentiation_backend,
let p = p
x -> nn(x, p)
x -> LuxCore.apply(nn, x, p)

Check warning on line 329 in src/icnf.jl

View check run for this annotation

Codecov / codecov/patch

src/icnf.jl#L329

Added line #L329 was not covered by tests
end,
z,
)
Expand Down Expand Up @@ -380,7 +380,7 @@ function augmented_f(
ż_JV = AbstractDifferentiation.value_and_pushforward_function(
icnf.differentiation_backend,
let p = p
x -> nn(x, p)
x -> LuxCore.apply(nn, x, p)

Check warning on line 383 in src/icnf.jl

View check run for this annotation

Codecov / codecov/patch

src/icnf.jl#L383

Added line #L383 was not covered by tests
end,
z,
)
Expand Down Expand Up @@ -432,7 +432,7 @@ function augmented_f(
z_aug = z[(end - n_aug_input + 1):end]
end
ż, VJ = Zygote.pullback(let p = p
x -> nn(x, p)
x -> LuxCore.apply(nn, x, p)

Check warning on line 435 in src/icnf.jl

View check run for this annotation

Codecov / codecov/patch

src/icnf.jl#L435

Added line #L435 was not covered by tests
end, z)
ϵJ = only(VJ(ϵ))
= -(ϵJ ϵ)
Expand Down Expand Up @@ -481,7 +481,7 @@ function augmented_f(
z_aug = z[(end - n_aug_input + 1):end]
end
ż, VJ = Zygote.pullback(let p = p
x -> nn(x, p)
x -> LuxCore.apply(nn, x, p)

Check warning on line 484 in src/icnf.jl

View check run for this annotation

Codecov / codecov/patch

src/icnf.jl#L484

Added line #L484 was not covered by tests
end, z)
ϵJ = only(VJ(ϵ))
du[begin:(end - n_aug - 1)] .=
Expand Down Expand Up @@ -529,9 +529,9 @@ function augmented_f(
n_aug_input = n_augment_input(icnf)
z_aug = z[(end - n_aug_input + 1):end, :]
end
= nn(z, p)
= LuxCore.apply(nn, z, p)

Check warning on line 532 in src/icnf.jl

View check run for this annotation

Codecov / codecov/patch

src/icnf.jl#L532

Added line #L532 was not covered by tests
Jf = VecJac(let p = p
x -> nn(x, p)
x -> LuxCore.apply(nn, x, p)

Check warning on line 534 in src/icnf.jl

View check run for this annotation

Codecov / codecov/patch

src/icnf.jl#L534

Added line #L534 was not covered by tests
end, z; autodiff = icnf.autodiff_backend)
ϵJ = reshape(Jf * ϵ, size(z))
= -sum(ϵJ .* ϵ; dims = 1)
Expand Down Expand Up @@ -579,9 +579,9 @@ function augmented_f(
n_aug_input = n_augment_input(icnf)
z_aug = z[(end - n_aug_input + 1):end, :]
end
= nn(z, p)
= LuxCore.apply(nn, z, p)

Check warning on line 582 in src/icnf.jl

View check run for this annotation

Codecov / codecov/patch

src/icnf.jl#L582

Added line #L582 was not covered by tests
Jf = VecJac(let p = p
x -> nn(x, p)
x -> LuxCore.apply(nn, x, p)

Check warning on line 584 in src/icnf.jl

View check run for this annotation

Codecov / codecov/patch

src/icnf.jl#L584

Added line #L584 was not covered by tests
end, z; autodiff = icnf.autodiff_backend)
ϵJ = reshape(Jf * ϵ, size(z))
du[begin:(end - n_aug - 1), :] .=
Expand Down Expand Up @@ -629,9 +629,9 @@ function augmented_f(
n_aug_input = n_augment_input(icnf)
z_aug = z[(end - n_aug_input + 1):end, :]
end
= nn(z, p)
= LuxCore.apply(nn, z, p)

Check warning on line 632 in src/icnf.jl

View check run for this annotation

Codecov / codecov/patch

src/icnf.jl#L632

Added line #L632 was not covered by tests
Jf = JacVec(let p = p
x -> nn(x, p)
x -> LuxCore.apply(nn, x, p)

Check warning on line 634 in src/icnf.jl

View check run for this annotation

Codecov / codecov/patch

src/icnf.jl#L634

Added line #L634 was not covered by tests
end, z; autodiff = icnf.autodiff_backend)
= reshape(Jf * ϵ, size(z))
= -sum.* Jϵ; dims = 1)
Expand Down Expand Up @@ -679,9 +679,9 @@ function augmented_f(
n_aug_input = n_augment_input(icnf)
z_aug = z[(end - n_aug_input + 1):end, :]
end
= nn(z, p)
= LuxCore.apply(nn, z, p)

Check warning on line 682 in src/icnf.jl

View check run for this annotation

Codecov / codecov/patch

src/icnf.jl#L682

Added line #L682 was not covered by tests
Jf = JacVec(let p = p
x -> nn(x, p)
x -> LuxCore.apply(nn, x, p)

Check warning on line 684 in src/icnf.jl

View check run for this annotation

Codecov / codecov/patch

src/icnf.jl#L684

Added line #L684 was not covered by tests
end, z; autodiff = icnf.autodiff_backend)
= reshape(Jf * ϵ, size(z))
du[begin:(end - n_aug - 1), :] .=
Expand Down Expand Up @@ -730,7 +730,7 @@ function augmented_f(
z_aug = z[(end - n_aug_input + 1):end, :]
end
ż, VJ = Zygote.pullback(let p = p
x -> nn(x, p)
x -> LuxCore.apply(nn, x, p)

Check warning on line 733 in src/icnf.jl

View check run for this annotation

Codecov / codecov/patch

src/icnf.jl#L733

Added line #L733 was not covered by tests
end, z)
ϵJ = only(VJ(ϵ))
= -sum(ϵJ .* ϵ; dims = 1)
Expand Down Expand Up @@ -779,7 +779,7 @@ function augmented_f(
z_aug = z[(end - n_aug_input + 1):end, :]
end
ż, VJ = Zygote.pullback(let p = p
x -> nn(x, p)
x -> LuxCore.apply(nn, x, p)

Check warning on line 782 in src/icnf.jl

View check run for this annotation

Codecov / codecov/patch

src/icnf.jl#L782

Added line #L782 was not covered by tests
end, z)
ϵJ = only(VJ(ϵ))
du[begin:(end - n_aug - 1), :] .=
Expand Down

0 comments on commit 8e9320b

Please sign in to comment.