From 43316b9dbfda348ac20cf1ebb3ecbf46b3aa1f39 Mon Sep 17 00:00:00 2001 From: marius Date: Mon, 26 Apr 2021 12:40:58 -0700 Subject: [PATCH] fix some ComponentArray grads reverting to Vector --- src/autodiff.jl | 1 + src/dataset.jl | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/autodiff.jl b/src/autodiff.jl index 8da00c58..ef16da3e 100644 --- a/src/autodiff.jl +++ b/src/autodiff.jl @@ -1,6 +1,7 @@ # this does basis promotion, unlike Zygote's default for AbstractArrays Zygote.accum(a::Field, b::Field) = a+b +Zygote.accum(a::FieldTuple, b::FieldTuple) = Zygote.accum.(a,b) # this may create a LazyBinaryOp, unlike Zygote's Zygote.accum(a::FieldOp, b::FieldOp) = a+b diff --git a/src/dataset.jl b/src/dataset.jl index ab122461..28c46c66 100644 --- a/src/dataset.jl +++ b/src/dataset.jl @@ -230,7 +230,7 @@ function load_sim(; @warn "`rfid` will be removed in a future version. Use `fiducial_θ=(r=...,)` instead." fiducial_θ = merge(fiducial_θ,(r=rfid,)) end - Aϕ₀ = get(fiducial_θ, :Aϕ, 1) + Aϕ₀ = T(get(fiducial_θ, :Aϕ, 1)) fiducial_θ = Base.structdiff(fiducial_θ, NamedTuple{(:Aϕ,)}) # remove Aϕ key if present if (Cℓ == nothing) Cℓ = camb(;fiducial_θ..., ℓmax=ℓmax) @@ -241,7 +241,7 @@ function load_sim(; error("ℓmax of `Cℓ` argument should be higher than $ℓmax for this configuration.") end end - r₀ = Cℓ.params.r + r₀ = T(Cℓ.params.r) # noise Cℓs (these are non-debeamed, hence beamFWHM=0 below; the beam comes in via the B operator) if (Cℓn == nothing) @@ -264,7 +264,7 @@ function load_sim(; Cf̃ = Cℓ_to_Cov(pol, proj, (Cℓ.total[k] for k in ks)...) Cn̂ = Cℓ_to_Cov(pol, proj, (Cℓn[k] for k in ks)...) if (Cn == nothing); Cn = Cn̂; end - Cf = ParamDependentOp((;r=r₀, _...)->(Cfs + T(r/r₀)*Cft)) + Cf = ParamDependentOp((;r=r₀, _...)->(Cfs + (T(r)/r₀)*Cft)) Cϕ = ParamDependentOp((;Aϕ=Aϕ₀, _...)->(T(Aϕ) * Cϕ₀)) # data mask