From 33ab47b23099228358ffb287e8564cb8b138fb37 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Mon, 8 Jun 2026 11:14:20 +0200 Subject: [PATCH 01/10] Some basic svd forward rules and tests --- .../MatrixAlgebraKitEnzymeExt.jl | 52 ++++++++++ .../MatrixAlgebraKitMooncakeExt.jl | 52 +++++++++- src/MatrixAlgebraKit.jl | 1 + src/pushforwards/svd.jl | 95 +++++++++++++++++++ test/enzyme/svd.jl | 2 +- test/testsuite/enzyme/svd.jl | 41 +++++++- test/testsuite/mooncake/svd.jl | 30 +++++- 7 files changed, 260 insertions(+), 13 deletions(-) create mode 100644 src/pushforwards/svd.jl diff --git a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl index 343bd9681..997b059a6 100644 --- a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl +++ b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl @@ -8,6 +8,7 @@ using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback! using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback!, eigh_vals_pullback! using MatrixAlgebraKit: eig_pushforward!, eigh_pushforward!, eig_vals_pushforward!, eigh_vals_pushforward! using MatrixAlgebraKit: svd_pullback!, svd_vals_pullback! +using MatrixAlgebraKit: svd_pushforward!, svd_vals_pushforward! using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback! using MatrixAlgebraKit: left_polar_pushforward!, right_polar_pushforward! using Enzyme @@ -264,6 +265,30 @@ for f in (:svd_compact!, :svd_full!) !isa(USVᴴ, Const) && make_zero!(USVᴴ.dval) return (nothing, nothing, nothing) end + function EnzymeRules.forward( + config::EnzymeRules.FwdConfigWidth{1}, + func::Const{typeof($f)}, + ::Type{RT}, + A::Annotation, + USVᴴ::Annotation, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, + ) where {RT} + $f(A.val, USVᴴ.val, alg.val) + if !isa(A, Const) && !isa(USVᴴ, Const) + make_zero!(USVᴴ.dval) + svd_pushforward!(A.dval, A.val, USVᴴ.val, USVᴴ.dval) + end + #!isa(A, Const) && make_zero!(A.dval) + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + return USVᴴ + elseif EnzymeRules.needs_primal(config) + return USVᴴ.val + elseif EnzymeRules.needs_shadow(config) + return USVᴴ.dval + else + return nothing + end + end end end @@ -502,5 +527,32 @@ function EnzymeRules.reverse( !isa(S, Const) && !A_is_arg && make_zero!(S.dval) return (nothing, nothing, nothing) end +function EnzymeRules.forward( + config::EnzymeRules.FwdConfigWidth{1}, + func::Const{typeof(svd_vals!)}, + ::Type{RT}, + A::Annotation{TA}, + S::Annotation, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, + ) where {RT, TA} + A_is_arg = !isa(A, Const) && TA <: Diagonal && diagview(A.dval) === S.dval + U, S_, Vᴴ = svd_compact!(A.val, alg.val) + copyto!(S.val, diagview(S_)) + if !isa(A, Const) && !isa(S, Const) + ΔS = A_is_arg ? make_zero(S.dval) : S.dval + svd_vals_pushforward!(A.dval, A.val, (U, Diagonal(S.val), Vᴴ), ΔS) + A_is_arg && (S.dval .= ΔS) + end + !isa(A, Const) && !A_is_arg && make_zero!(A.dval) + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + return S + elseif EnzymeRules.needs_primal(config) + return S.val + elseif EnzymeRules.needs_shadow(config) + return S.dval + else + return nothing + end +end end diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl index 16241385b..ebe080eb5 100644 --- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl +++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl @@ -13,6 +13,7 @@ using MatrixAlgebraKit: eig_trunc_pullback!, eigh_trunc_pullback!, eigh_vals_pul using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback! using MatrixAlgebraKit: left_polar_pushforward!, right_polar_pushforward! using MatrixAlgebraKit: svd_pullback!, svd_trunc_pullback!, svd_vals_pullback! +using MatrixAlgebraKit: svd_pushforward!, svd_trunc_pushforward!, svd_vals_pushforward! using MatrixAlgebraKit: TruncatedAlgorithm using LinearAlgebra @@ -538,7 +539,7 @@ for (f!, f) in ( (:svd_compact!, :svd_compact), ) @eval begin - @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Tuple{<:Any, <:Any, <:Any}, MatrixAlgebraKit.AbstractAlgorithm} + @is_primitive Mooncake.DefaultCtx Tuple{typeof($f!), Any, Tuple{<:Any, <:Any, <:Any}, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual) A, dA = arrayify(A_dA) USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ) @@ -562,7 +563,18 @@ for (f!, f) in ( end return USVᴴ_dUSVᴴ, svd_adjoint end - @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm} + function Mooncake.frule!!(::Dual{typeof($f!)}, A_dA::Dual, USVᴴ_dUSVᴴ::Dual, alg_dalg::Dual) + A, dA = arrayify(A_dA) + USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ) + dUSVᴴ = Mooncake.tangent(USVᴴ_dUSVᴴ) + U, dU = arrayify(USVᴴ[1], dUSVᴴ[1]) + S, dS = arrayify(USVᴴ[2], dUSVᴴ[2]) + Vᴴ, dVᴴ = arrayify(USVᴴ[3], dUSVᴴ[3]) + $f!(A, USVᴴ, Mooncake.primal(alg_dalg)) + svd_pushforward!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ)) + return USVᴴ_dUSVᴴ + end + @is_primitive Mooncake.DefaultCtx Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual) A, dA = arrayify(A_dA) USVᴴ = $f(A, Mooncake.primal(alg_dalg)) @@ -585,10 +597,23 @@ for (f!, f) in ( end return USVᴴ_codual, svd_adjoint end + function Mooncake.frule!!(::Dual{typeof($f)}, A_dA::Dual, alg_dalg::Dual) + A, dA = arrayify(A_dA) + USVᴴ = $f(A, Mooncake.primal(alg_dalg)) + dUSVᴴ = Mooncake.zero_tangent(USVᴴ) + USVᴴ_dual = Dual(USVᴴ, dUSVᴴ) + U, S, Vᴴ = Mooncake.primal(USVᴴ_dual) + dU_, dS_, dVᴴ_ = Mooncake.tangent(USVᴴ_dual) + U, dU = arrayify(U, dU_) + S, dS = arrayify(S, dS_) + Vᴴ, dVᴴ = arrayify(Vᴴ, dVᴴ_) + svd_pushforward!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ)) + return USVᴴ_dual + end end end -@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_vals!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} +@is_primitive Mooncake.DefaultCtx Tuple{typeof(svd_vals!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(::CoDual{typeof(svd_vals!)}, A_dA::CoDual, S_dS::CoDual, alg_dalg::CoDual) # compute primal A, dA = arrayify(A_dA) @@ -604,8 +629,17 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals!)}, A_dA::CoDual, S_dS::CoDua end return S_dS, svd_vals_adjoint end +function Mooncake.frule!!(::Dual{typeof(svd_vals!)}, A_dA::Dual, S_dS::Dual, alg_dalg::Dual) + # compute primal + A, dA = arrayify(A_dA) + S, dS = arrayify(S_dS) + USVᴴ = svd_compact(A, Mooncake.primal(alg_dalg)) + copy!(S, diagview(USVᴴ[2])) + svd_vals_pushforward!(dA, A, USVᴴ, dS) + return S_dS +end -@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_vals), Any, MatrixAlgebraKit.AbstractAlgorithm} +@is_primitive Mooncake.DefaultCtx Tuple{typeof(svd_vals), Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(::CoDual{typeof(svd_vals)}, A_dA::CoDual, alg_dalg::CoDual) # compute primal A, dA = arrayify(A_dA) @@ -624,6 +658,16 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals)}, A_dA::CoDual, alg_dalg::Co end return S_codual, svd_vals_adjoint end +function Mooncake.frule!!(::Dual{typeof(svd_vals)}, A_dA::Dual, alg_dalg::Dual) + # compute primal + A, dA = arrayify(A_dA) + USVᴴ = svd_compact(A, Mooncake.primal(alg_dalg)) + S = diagview(USVᴴ[2]) + S_dual = Dual(S, Mooncake.zero_tangent(S)) + S_, dS = arrayify(S_dual) + svd_vals_pushforward!(dA, A, USVᴴ, dS) + return S_dual +end @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual) diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index 65de152c4..115b83018 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -132,6 +132,7 @@ include("pullbacks/polar.jl") include("pushforwards/polar.jl") include("pushforwards/eig.jl") include("pushforwards/eigh.jl") +include("pushforwards/svd.jl") include("precompile.jl") diff --git a/src/pushforwards/svd.jl b/src/pushforwards/svd.jl new file mode 100644 index 000000000..51b5adc45 --- /dev/null +++ b/src/pushforwards/svd.jl @@ -0,0 +1,95 @@ +function svd_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ, ind = Colon(); rank_atol = default_pullback_rank_atol(A), kwargs...) + U, Smat, Vᴴ = USVᴴ + m, n = size(U, 1), size(Vᴴ, 2) + (m, n) == size(ΔA) || throw(DimensionMismatch("size of ΔA ($(size(ΔA))) does not match size of U*S*Vᴴ ($m, $n)")) + minmn = min(m, n) + S = diagview(Smat) + ΔU, ΔS, ΔVᴴ = ΔUSVᴴ + r = searchsortedlast(S, rank_atol; rev = true) # rank + + vΔS = view(ΔS, 1:r, 1:r) + + vU = view(U, :, 1:r) + vS = view(S, 1:r) + vSmat = view(Smat, 1:r, 1:r) + vVᴴ = view(Vᴴ, 1:r, :) + + # compact region + vV = adjoint(vVᴴ) + UΔAV = vU' * ΔA * vV + copyto!(diagview(vΔS), diag(real.(UΔAV))) + F = one(eltype(S)) ./ (transpose(vS) .- vS) + G = one(eltype(S)) ./ (transpose(vS) .+ vS) + diagview(F) .= zero(eltype(F)) + hUΔAV = F .* (UΔAV + UΔAV') ./ 2 + aUΔAV = G .* (UΔAV - UΔAV') ./ 2 + K̇ = hUΔAV + aUΔAV + Ṁ = hUΔAV - aUΔAV + + # check gauge condition + @assert isantihermitian(K̇) + @assert isantihermitian(Ṁ) + K̇diag = diagview(K̇) + for i in 1:length(K̇diag) + @assert K̇diag[i] ≈ (im / 2) * imag(diagview(UΔAV)[i]) / S[i] + end + + ∂U = vU * K̇ + ∂V = vV * Ṁ + # full component + if size(U, 2) > minmn && size(Vᴴ, 1) > minmn + Uperp = view(U, :, (minmn + 1):m) + Vᴴperp = view(Vᴴ, (minmn + 1):n, :) + + aUAV = adjoint(Uperp) * A * adjoint(Vᴴperp) + + UÃÃV = similar(A, (size(aUAV, 1) + size(aUAV, 2), size(aUAV, 1) + size(aUAV, 2))) + fill!(UÃÃV, 0) + view(UÃÃV, (1:size(aUAV, 1)), size(aUAV, 1) .+ (1:size(aUAV, 2))) .= aUAV + view(UÃÃV, size(aUAV, 1) .+ (1:size(aUAV, 2)), 1:size(aUAV, 1)) .= aUAV' + rhs = vcat(adjoint(Uperp * ΔA * Vᴴ), Vᴴperp * ΔA' * U) + superKM = -sylvester(UÃÃV, Smat, rhs) + K̇perp = view(superKM, 1:size(aUAV, 2)) + Ṁperp = view(superKM, (size(aUAV, 2) + 1):(size(aUAV, 1) + size(aUAV, 2))) + ∂U .+= Uperp * K̇perp + ∂V .+= Vᴴperp * Ṁperp + else + ImUU = (LinearAlgebra.diagm(ones(eltype(U), m)) - vU * vU') + ImVV = (LinearAlgebra.diagm(ones(eltype(Vᴴ), n)) - vV * vVᴴ) + upper = ImUU * ΔA * vV + lower = ImVV * ΔA' * vU + rhs = vcat(upper, lower) + + Ã = ImUU * A * ImVV + ÃÃ = similar(A, (m + n, m + n)) + fill!(ÃÃ, 0) + view(ÃÃ, (1:m), m .+ (1:n)) .= Ã + view(ÃÃ, m .+ (1:n), 1:m) .= Ã' + + superLN = -sylvester(ÃÃ, vSmat, rhs) + ∂U += view(superLN, 1:size(upper, 1), :) + ∂V += view(superLN, (size(upper, 1) + 1):(size(upper, 1) + size(lower, 1)), :) + end + if !iszerotangent(ΔU) + vΔU = view(ΔU, :, 1:r) + copyto!(vΔU, ∂U) + end + if !iszerotangent(ΔVᴴ) + vΔVᴴ = view(ΔVᴴ, 1:r, :) + adjoint!(vΔVᴴ, ∂V) + end + return (ΔU, ΔS, ΔVᴴ) +end + +function svd_trunc_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ, ind; rank_atol = default_pullback_rank_atol(A), kwargs...) + # TODO +end + +function svd_vals_pushforward!( + ΔA, A, USVᴴ, ΔS, ind = Colon(); + rank_atol::Real = default_pullback_rank_atol(USVᴴ[2]), + degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]) + ) + ΔUSVᴴ = (nothing, diagonal(ΔS), nothing) + return svd_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ, ind; rank_atol, degeneracy_atol) +end diff --git a/test/enzyme/svd.jl b/test/enzyme/svd.jl index e4aaa7aa1..a60ec6eba 100644 --- a/test/enzyme/svd.jl +++ b/test/enzyme/svd.jl @@ -16,6 +16,6 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) if !is_buildkite TestSuite.test_enzyme_svd(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) AT = Diagonal{T, Vector{T}} - m == n && TestSuite.test_enzyme_svd(AT, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) + m == n && TestSuite.test_enzyme_svd(AT, m; atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) end end diff --git a/test/testsuite/enzyme/svd.jl b/test/testsuite/enzyme/svd.jl index 2131aa8d5..66fea0c88 100644 --- a/test/testsuite/enzyme/svd.jl +++ b/test/testsuite/enzyme/svd.jl @@ -8,48 +8,83 @@ function test_enzyme_svd(T::Type, sz; kwargs...) end end +""" + test_enzyme_svd_compact(T, sz; rng, atol, rtol) + +Test the Enzyme forward- and reverse-mode AD rule for `svd_compact` and its in-place variant. +""" function test_enzyme_svd_compact( T, sz; rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T), fdm = enzyme_fdm(T) ) - return @testset "svd_compact reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + return @testset "svd_compact: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) A = instantiate_matrix(T, sz) alg = MatrixAlgebraKit.select_algorithm(svd_compact, A) USVᴴ, ΔUSVᴴ = ad_svd_compact_setup(A) test_reverse(svd_compact, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔUSVᴴ, fdm) test_reverse(call_and_zero!, RT, (svd_compact!, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔUSVᴴ, fdm) + if eltype(T) <: Real + A = instantiate_matrix(T, sz) + test_forward(svd_compact, RT, (A, TA), (alg, Const); atol, rtol, fdm) + test_forward(call_and_zero!, RT, (svd_compact!, Const), (A, TA), (alg, Const); atol, rtol, fdm) + end end end +""" + test_enzyme_svd_full(T, sz; rng, atol, rtol) + +Test the Enzyme forward- and reverse-mode AD rule for `svd_full` and its in-place variant. The +gauge-dependent extra columns of `U` and rows of `Vᴴ` are zeroed out in the cotangent. +""" function test_enzyme_svd_full( T, sz; rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T), fdm = enzyme_fdm(T) ) - return @testset "svd_full reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + return @testset "svd_full: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) A = instantiate_matrix(T, sz) alg = MatrixAlgebraKit.select_algorithm(svd_full, A) USVᴴ, ΔUSVᴴ = ad_svd_full_setup(A) test_reverse(svd_full, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔUSVᴴ, fdm) test_reverse(call_and_zero!, RT, (svd_full!, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔUSVᴴ, fdm) + if eltype(T) <: Real + A = instantiate_matrix(T, sz) + test_forward(svd_full, RT, (A, TA), (alg, Const); atol, rtol, fdm) + test_forward(call_and_zero!, RT, (svd_full!, Const), (A, TA), (alg, Const); atol, rtol, fdm) + end end end +""" + test_enzyme_svd_vals(T, sz; rng, atol, rtol) + +Test the Enzyme forward- and reverse-mode AD rule for `svd_vals` and its in-place variant. +""" function test_enzyme_svd_vals( T, sz; rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T), fdm = enzyme_fdm(T) ) - return @testset "svd_vals reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + return @testset "svd_vals: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) A = instantiate_matrix(T, sz) alg = MatrixAlgebraKit.select_algorithm(svd_vals, A) S, ΔS = ad_svd_vals_setup(A) test_reverse(svd_vals, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔS, fdm) test_reverse(call_and_zero!, RT, (svd_vals!, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔS, fdm) + A = instantiate_matrix(T, sz) + test_forward(svd_vals, RT, (A, TA), (alg, Const); atol, rtol, fdm) + test_forward(call_and_zero!, RT, (svd_vals!, Const), (A, TA), (alg, Const); atol, rtol, fdm) end end +""" + test_enzyme_svd_trunc(T, sz; rng, atol, rtol) + +Test the Enzyme reverse-mode AD rules for `svd_trunc`, `svd_trunc_no_error`, and their +in-place variants, over a range of truncation ranks and a tolerance-based truncation. +""" function test_enzyme_svd_trunc( T, sz; rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T), diff --git a/test/testsuite/mooncake/svd.jl b/test/testsuite/mooncake/svd.jl index 5ac79744e..ba6477990 100644 --- a/test/testsuite/mooncake/svd.jl +++ b/test/testsuite/mooncake/svd.jl @@ -16,7 +16,7 @@ end """ test_mooncake_svd_compact(T, sz; rng, atol, rtol) -Test the Mooncake reverse-mode AD rule for `svd_compact` and its in-place variant. +Test the Mooncake forward- and reverse-mode AD rule for `svd_compact` and its in-place variant. """ function test_mooncake_svd_compact( T, sz; @@ -36,13 +36,23 @@ function test_mooncake_svd_compact( rng, call_and_zero!, svd_compact!, A, alg; mode = Mooncake.ReverseMode, output_tangent, atol, rtol, is_primitive = false ) + if eltype(T) <: Real # gauge freedom in complex outputs + Mooncake.TestUtils.test_rule( + rng, svd_compact, A, alg; + mode = Mooncake.ForwardMode, atol, rtol + ) + Mooncake.TestUtils.test_rule( + rng, call_and_zero!, svd_compact!, A, alg; + mode = Mooncake.ForwardMode, atol, rtol, is_primitive = false + ) + end end end """ test_mooncake_svd_full(T, sz; rng, atol, rtol) -Test the Mooncake reverse-mode AD rule for `svd_full` and its in-place variant. The +Test the Mooncake forward- and reverse-mode AD rule for `svd_full` and its in-place variant. The gauge-dependent extra columns of `U` and rows of `Vᴴ` are zeroed out in the cotangent. """ function test_mooncake_svd_full( @@ -63,13 +73,23 @@ function test_mooncake_svd_full( rng, call_and_zero!, svd_full!, A, alg; mode = Mooncake.ReverseMode, output_tangent, atol, rtol, is_primitive = false ) + if eltype(T) <: Real # gauge freedom in complex outputs + Mooncake.TestUtils.test_rule( + rng, svd_full, A, alg; + mode = Mooncake.ForwardMode, atol, rtol + ) + Mooncake.TestUtils.test_rule( + rng, call_and_zero!, svd_full!, A, alg; + mode = Mooncake.ForwardMode, atol, rtol, is_primitive = false + ) + end end end """ test_mooncake_svd_vals(T, sz; rng, atol, rtol) -Test the Mooncake reverse-mode AD rule for `svd_vals` and its in-place variant. +Test the Mooncake forward- and reverse-mode AD rule for `svd_vals` and its in-place variant. """ function test_mooncake_svd_vals( T, sz; @@ -83,11 +103,11 @@ function test_mooncake_svd_vals( Mooncake.TestUtils.test_rule( rng, svd_vals, A, alg; - mode = Mooncake.ReverseMode, output_tangent, atol, rtol + output_tangent, atol, rtol ) Mooncake.TestUtils.test_rule( rng, call_and_zero!, svd_vals!, A, alg; - mode = Mooncake.ReverseMode, output_tangent, atol, rtol, is_primitive = false + output_tangent, atol, rtol, is_primitive = false ) end end From 838ecab85b11dfd59a152c3741575dbebb02453c Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Mon, 8 Jun 2026 14:06:22 +0200 Subject: [PATCH 02/10] Try to fix Enzyme? --- ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl index 997b059a6..c2b730a0e 100644 --- a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl +++ b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl @@ -278,7 +278,7 @@ for f in (:svd_compact!, :svd_full!) make_zero!(USVᴴ.dval) svd_pushforward!(A.dval, A.val, USVᴴ.val, USVᴴ.dval) end - #!isa(A, Const) && make_zero!(A.dval) + !isa(A, Const) && make_zero!(A.dval) if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) return USVᴴ elseif EnzymeRules.needs_primal(config) @@ -537,13 +537,13 @@ function EnzymeRules.forward( ) where {RT, TA} A_is_arg = !isa(A, Const) && TA <: Diagonal && diagview(A.dval) === S.dval U, S_, Vᴴ = svd_compact!(A.val, alg.val) - copyto!(S.val, diagview(S_)) if !isa(A, Const) && !isa(S, Const) ΔS = A_is_arg ? make_zero(S.dval) : S.dval - svd_vals_pushforward!(A.dval, A.val, (U, Diagonal(S.val), Vᴴ), ΔS) + svd_vals_pushforward!(A.dval, A.val, (U, Diagonal(diagview(S_)), Vᴴ), ΔS) A_is_arg && (S.dval .= ΔS) end !isa(A, Const) && !A_is_arg && make_zero!(A.dval) + copyto!(S.val, diagview(S_)) if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) return S elseif EnzymeRules.needs_primal(config) From be9afda4bd7122d082a0c268c7718aa8d3327719 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Mon, 8 Jun 2026 14:40:41 +0200 Subject: [PATCH 03/10] Use GPU safe rank calculation --- src/pushforwards/svd.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pushforwards/svd.jl b/src/pushforwards/svd.jl index 51b5adc45..66a14779e 100644 --- a/src/pushforwards/svd.jl +++ b/src/pushforwards/svd.jl @@ -5,7 +5,7 @@ function svd_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ, ind = Colon(); rank_atol = d minmn = min(m, n) S = diagview(Smat) ΔU, ΔS, ΔVᴴ = ΔUSVᴴ - r = searchsortedlast(S, rank_atol; rev = true) # rank + r = svd_rank(S; rank_atol) vΔS = view(ΔS, 1:r, 1:r) From 883dadb4e1ce04a08e50e0b478acda4aff70dc16 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Mon, 8 Jun 2026 15:04:49 +0200 Subject: [PATCH 04/10] Another fix for GPU --- src/pushforwards/svd.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/pushforwards/svd.jl b/src/pushforwards/svd.jl index 66a14779e..d8c7688bc 100644 --- a/src/pushforwards/svd.jl +++ b/src/pushforwards/svd.jl @@ -30,9 +30,6 @@ function svd_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ, ind = Colon(); rank_atol = d @assert isantihermitian(K̇) @assert isantihermitian(Ṁ) K̇diag = diagview(K̇) - for i in 1:length(K̇diag) - @assert K̇diag[i] ≈ (im / 2) * imag(diagview(UΔAV)[i]) / S[i] - end ∂U = vU * K̇ ∂V = vV * Ṁ From 1e03d65862968f2e99d0e8c0b0c8682c78d2fd97 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 9 Jun 2026 15:11:38 +0200 Subject: [PATCH 05/10] Some more fixes --- .../MatrixAlgebraKitEnzymeExt.jl | 13 +++++-------- src/pushforwards/svd.jl | 5 ++--- test/testsuite/enzyme/svd.jl | 19 ++++++++----------- test/testsuite/mooncake/svd.jl | 8 ++++---- 4 files changed, 19 insertions(+), 26 deletions(-) diff --git a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl index c2b730a0e..ae51af5dd 100644 --- a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl +++ b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl @@ -269,16 +269,13 @@ for f in (:svd_compact!, :svd_full!) config::EnzymeRules.FwdConfigWidth{1}, func::Const{typeof($f)}, ::Type{RT}, - A::Annotation, + A::Annotation{TA}, USVᴴ::Annotation, alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, - ) where {RT} + ) where {RT, TA} $f(A.val, USVᴴ.val, alg.val) - if !isa(A, Const) && !isa(USVᴴ, Const) - make_zero!(USVᴴ.dval) - svd_pushforward!(A.dval, A.val, USVᴴ.val, USVᴴ.dval) - end - !isa(A, Const) && make_zero!(A.dval) + !isa(A, Const) && !isa(USVᴴ, Const) && svd_pushforward!(A.dval, A.val, USVᴴ.val, USVᴴ.dval) + make_zero!(A.dval) if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) return USVᴴ elseif EnzymeRules.needs_primal(config) @@ -542,7 +539,7 @@ function EnzymeRules.forward( svd_vals_pushforward!(A.dval, A.val, (U, Diagonal(diagview(S_)), Vᴴ), ΔS) A_is_arg && (S.dval .= ΔS) end - !isa(A, Const) && !A_is_arg && make_zero!(A.dval) + !A_is_arg && make_zero!(A.dval) copyto!(S.val, diagview(S_)) if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) return S diff --git a/src/pushforwards/svd.jl b/src/pushforwards/svd.jl index d8c7688bc..69a8e8c35 100644 --- a/src/pushforwards/svd.jl +++ b/src/pushforwards/svd.jl @@ -18,9 +18,8 @@ function svd_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ, ind = Colon(); rank_atol = d vV = adjoint(vVᴴ) UΔAV = vU' * ΔA * vV copyto!(diagview(vΔS), diag(real.(UΔAV))) - F = one(eltype(S)) ./ (transpose(vS) .- vS) - G = one(eltype(S)) ./ (transpose(vS) .+ vS) - diagview(F) .= zero(eltype(F)) + F = inv_safe.(transpose(vS) .- vS) + G = inv_safe.(transpose(vS) .+ vS) hUΔAV = F .* (UΔAV + UΔAV') ./ 2 aUΔAV = G .* (UΔAV - UΔAV') ./ 2 K̇ = hUΔAV + aUΔAV diff --git a/test/testsuite/enzyme/svd.jl b/test/testsuite/enzyme/svd.jl index 66fea0c88..06e6da671 100644 --- a/test/testsuite/enzyme/svd.jl +++ b/test/testsuite/enzyme/svd.jl @@ -23,11 +23,10 @@ function test_enzyme_svd_compact( alg = MatrixAlgebraKit.select_algorithm(svd_compact, A) USVᴴ, ΔUSVᴴ = ad_svd_compact_setup(A) test_reverse(svd_compact, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔUSVᴴ, fdm) - test_reverse(call_and_zero!, RT, (svd_compact!, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔUSVᴴ, fdm) + test_reverse(call_and_zero!, RT, (svd_compact!, Const), (copy(A), TA), (alg, Const); atol, rtol, output_tangent = ΔUSVᴴ, fdm) if eltype(T) <: Real - A = instantiate_matrix(T, sz) test_forward(svd_compact, RT, (A, TA), (alg, Const); atol, rtol, fdm) - test_forward(call_and_zero!, RT, (svd_compact!, Const), (A, TA), (alg, Const); atol, rtol, fdm) + test_forward(call_and_zero!, RT, (svd_compact!, Const), (copy(A), TA), (alg, Const); atol, rtol, fdm) end end end @@ -48,11 +47,10 @@ function test_enzyme_svd_full( alg = MatrixAlgebraKit.select_algorithm(svd_full, A) USVᴴ, ΔUSVᴴ = ad_svd_full_setup(A) test_reverse(svd_full, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔUSVᴴ, fdm) - test_reverse(call_and_zero!, RT, (svd_full!, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔUSVᴴ, fdm) + test_reverse(call_and_zero!, RT, (svd_full!, Const), (copy(A), TA), (alg, Const); atol, rtol, output_tangent = ΔUSVᴴ, fdm) if eltype(T) <: Real - A = instantiate_matrix(T, sz) test_forward(svd_full, RT, (A, TA), (alg, Const); atol, rtol, fdm) - test_forward(call_and_zero!, RT, (svd_full!, Const), (A, TA), (alg, Const); atol, rtol, fdm) + test_forward(call_and_zero!, RT, (svd_full!, Const), (copy(A), TA), (alg, Const); atol, rtol, fdm) end end end @@ -72,10 +70,9 @@ function test_enzyme_svd_vals( alg = MatrixAlgebraKit.select_algorithm(svd_vals, A) S, ΔS = ad_svd_vals_setup(A) test_reverse(svd_vals, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔS, fdm) - test_reverse(call_and_zero!, RT, (svd_vals!, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔS, fdm) - A = instantiate_matrix(T, sz) + test_reverse(call_and_zero!, RT, (svd_vals!, Const), (copy(A), TA), (alg, Const); atol, rtol, output_tangent = ΔS, fdm) test_forward(svd_vals, RT, (A, TA), (alg, Const); atol, rtol, fdm) - test_forward(call_and_zero!, RT, (svd_vals!, Const), (A, TA), (alg, Const); atol, rtol, fdm) + test_forward(call_and_zero!, RT, (svd_vals!, Const), (copy(A), TA), (alg, Const); atol, rtol, fdm) end end @@ -99,7 +96,7 @@ function test_enzyme_svd_trunc( trunc = truncrank(r) truncalg = TruncatedAlgorithm(alg, trunc) USVᴴ, _, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg) - test_reverse(svd_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm) + test_reverse(svd_trunc_no_error, RT, (copy(A), TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm) test_reverse(call_and_zero!, RT, (svd_trunc_no_error!, Const), (copy(A), TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm) end @testset "trunctol" begin @@ -107,7 +104,7 @@ function test_enzyme_svd_trunc( trunc = trunctol(atol = maximum(S) / 2) truncalg = TruncatedAlgorithm(alg, trunc) USVᴴ, _, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg) - test_reverse(svd_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm) + test_reverse(svd_trunc_no_error, RT, (copy(A), TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm) test_reverse(call_and_zero!, RT, (svd_trunc_no_error!, Const), (copy(A), TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm) end end diff --git a/test/testsuite/mooncake/svd.jl b/test/testsuite/mooncake/svd.jl index ba6477990..d58a9c0ca 100644 --- a/test/testsuite/mooncake/svd.jl +++ b/test/testsuite/mooncake/svd.jl @@ -33,7 +33,7 @@ function test_mooncake_svd_compact( mode = Mooncake.ReverseMode, output_tangent, atol, rtol ) Mooncake.TestUtils.test_rule( - rng, call_and_zero!, svd_compact!, A, alg; + rng, call_and_zero!, svd_compact!, copy(A), alg; mode = Mooncake.ReverseMode, output_tangent, atol, rtol, is_primitive = false ) if eltype(T) <: Real # gauge freedom in complex outputs @@ -42,7 +42,7 @@ function test_mooncake_svd_compact( mode = Mooncake.ForwardMode, atol, rtol ) Mooncake.TestUtils.test_rule( - rng, call_and_zero!, svd_compact!, A, alg; + rng, call_and_zero!, svd_compact!, copy(A), alg; mode = Mooncake.ForwardMode, atol, rtol, is_primitive = false ) end @@ -70,7 +70,7 @@ function test_mooncake_svd_full( mode = Mooncake.ReverseMode, output_tangent, atol, rtol ) Mooncake.TestUtils.test_rule( - rng, call_and_zero!, svd_full!, A, alg; + rng, call_and_zero!, svd_full!, copy(A), alg; mode = Mooncake.ReverseMode, output_tangent, atol, rtol, is_primitive = false ) if eltype(T) <: Real # gauge freedom in complex outputs @@ -79,7 +79,7 @@ function test_mooncake_svd_full( mode = Mooncake.ForwardMode, atol, rtol ) Mooncake.TestUtils.test_rule( - rng, call_and_zero!, svd_full!, A, alg; + rng, call_and_zero!, svd_full!, copy(A), alg; mode = Mooncake.ForwardMode, atol, rtol, is_primitive = false ) end From d8c94278af6a91e1ca49ad70dcfd374cd020072c Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 9 Jun 2026 17:48:40 +0200 Subject: [PATCH 06/10] Update src/pushforwards/svd.jl Co-authored-by: Jutho --- src/pushforwards/svd.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pushforwards/svd.jl b/src/pushforwards/svd.jl index 69a8e8c35..7056a2cf3 100644 --- a/src/pushforwards/svd.jl +++ b/src/pushforwards/svd.jl @@ -17,7 +17,7 @@ function svd_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ, ind = Colon(); rank_atol = d # compact region vV = adjoint(vVᴴ) UΔAV = vU' * ΔA * vV - copyto!(diagview(vΔS), diag(real.(UΔAV))) + copyto!(diagview(vΔS), real.(diagview(UΔAV))) F = inv_safe.(transpose(vS) .- vS) G = inv_safe.(transpose(vS) .+ vS) hUΔAV = F .* (UΔAV + UΔAV') ./ 2 From ec2a674f609963975ec309d72795d2da3baa72ce Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 11 Jun 2026 11:59:57 +0200 Subject: [PATCH 07/10] A little cleanup --- ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl index ae51af5dd..ed420f890 100644 --- a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl +++ b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl @@ -274,8 +274,10 @@ for f in (:svd_compact!, :svd_full!) alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, ) where {RT, TA} $f(A.val, USVᴴ.val, alg.val) - !isa(A, Const) && !isa(USVᴴ, Const) && svd_pushforward!(A.dval, A.val, USVᴴ.val, USVᴴ.dval) - make_zero!(A.dval) + if !isa(A, Const) + !isa(USVᴴ, Const) && svd_pushforward!(A.dval, A.val, USVᴴ.val, USVᴴ.dval) + make_zero!(A.dval) + end if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) return USVᴴ elseif EnzymeRules.needs_primal(config) From 9d0f666c9dcb32c9640e4efaf8d04cdcf471d63b Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 11 Jun 2026 15:26:05 +0200 Subject: [PATCH 08/10] Some small fixes again --- ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl | 5 +++++ src/pushforwards/svd.jl | 4 ++-- test/enzyme/svd.jl | 2 +- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl index ed420f890..65ce01a3e 100644 --- a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl +++ b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl @@ -275,6 +275,11 @@ for f in (:svd_compact!, :svd_full!) ) where {RT, TA} $f(A.val, USVᴴ.val, alg.val) if !isa(A, Const) + if $(f == svd_compact!) + make_zero!(USVᴴ.dval[2].diag) + else + make_zero!(USVᴴ.dval[2]) + end !isa(USVᴴ, Const) && svd_pushforward!(A.dval, A.val, USVᴴ.val, USVᴴ.dval) make_zero!(A.dval) end diff --git a/src/pushforwards/svd.jl b/src/pushforwards/svd.jl index 7056a2cf3..dd09640b1 100644 --- a/src/pushforwards/svd.jl +++ b/src/pushforwards/svd.jl @@ -7,7 +7,7 @@ function svd_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ, ind = Colon(); rank_atol = d ΔU, ΔS, ΔVᴴ = ΔUSVᴴ r = svd_rank(S; rank_atol) - vΔS = view(ΔS, 1:r, 1:r) + vΔS = view(diagview(ΔS), 1:r) vU = view(U, :, 1:r) vS = view(S, 1:r) @@ -17,7 +17,7 @@ function svd_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ, ind = Colon(); rank_atol = d # compact region vV = adjoint(vVᴴ) UΔAV = vU' * ΔA * vV - copyto!(diagview(vΔS), real.(diagview(UΔAV))) + copyto!(vΔS, real.(diagview(UΔAV))) F = inv_safe.(transpose(vS) .- vS) G = inv_safe.(transpose(vS) .+ vS) hUΔAV = F .* (UΔAV + UΔAV') ./ 2 diff --git a/test/enzyme/svd.jl b/test/enzyme/svd.jl index a60ec6eba..bef41e5c7 100644 --- a/test/enzyme/svd.jl +++ b/test/enzyme/svd.jl @@ -3,7 +3,7 @@ using Test using LinearAlgebra: Diagonal using CUDA, AMDGPU -BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI +BLASFloats = (Float64, ComplexF64) # full suite is too expensive on CI GenericFloats = () @isdefined(TestSuite) || include("../testsuite/TestSuite.jl") using .TestSuite From 8940756981884e4862c7784c898685396427afc8 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 11 Jun 2026 15:43:19 +0200 Subject: [PATCH 09/10] Use sylvester fallback --- ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl | 2 +- src/pushforwards/svd.jl | 4 ++-- test/testsuite/enzyme/svd.jl | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl index 65ce01a3e..68e2916b2 100644 --- a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl +++ b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl @@ -275,7 +275,7 @@ for f in (:svd_compact!, :svd_full!) ) where {RT, TA} $f(A.val, USVᴴ.val, alg.val) if !isa(A, Const) - if $(f == svd_compact!) + if $(f == svd_compact!) make_zero!(USVᴴ.dval[2].diag) else make_zero!(USVᴴ.dval[2]) diff --git a/src/pushforwards/svd.jl b/src/pushforwards/svd.jl index dd09640b1..307d8cb76 100644 --- a/src/pushforwards/svd.jl +++ b/src/pushforwards/svd.jl @@ -44,7 +44,7 @@ function svd_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ, ind = Colon(); rank_atol = d view(UÃÃV, (1:size(aUAV, 1)), size(aUAV, 1) .+ (1:size(aUAV, 2))) .= aUAV view(UÃÃV, size(aUAV, 1) .+ (1:size(aUAV, 2)), 1:size(aUAV, 1)) .= aUAV' rhs = vcat(adjoint(Uperp * ΔA * Vᴴ), Vᴴperp * ΔA' * U) - superKM = -sylvester(UÃÃV, Smat, rhs) + superKM = -_sylvester(UÃÃV, Smat, rhs) K̇perp = view(superKM, 1:size(aUAV, 2)) Ṁperp = view(superKM, (size(aUAV, 2) + 1):(size(aUAV, 1) + size(aUAV, 2))) ∂U .+= Uperp * K̇perp @@ -62,7 +62,7 @@ function svd_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ, ind = Colon(); rank_atol = d view(ÃÃ, (1:m), m .+ (1:n)) .= Ã view(ÃÃ, m .+ (1:n), 1:m) .= Ã' - superLN = -sylvester(ÃÃ, vSmat, rhs) + superLN = -_sylvester(ÃÃ, vSmat, rhs) ∂U += view(superLN, 1:size(upper, 1), :) ∂V += view(superLN, (size(upper, 1) + 1):(size(upper, 1) + size(lower, 1)), :) end diff --git a/test/testsuite/enzyme/svd.jl b/test/testsuite/enzyme/svd.jl index 06e6da671..1861b83b1 100644 --- a/test/testsuite/enzyme/svd.jl +++ b/test/testsuite/enzyme/svd.jl @@ -48,7 +48,7 @@ function test_enzyme_svd_full( USVᴴ, ΔUSVᴴ = ad_svd_full_setup(A) test_reverse(svd_full, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔUSVᴴ, fdm) test_reverse(call_and_zero!, RT, (svd_full!, Const), (copy(A), TA), (alg, Const); atol, rtol, output_tangent = ΔUSVᴴ, fdm) - if eltype(T) <: Real + if eltype(T) <: Real && size(A, 1) == size(A, 2) # finite differences check for free component is very finicky test_forward(svd_full, RT, (A, TA), (alg, Const); atol, rtol, fdm) test_forward(call_and_zero!, RT, (svd_full!, Const), (copy(A), TA), (alg, Const); atol, rtol, fdm) end From aa3a2ee70eb32cf9179c2ce1ae162e66f3195d7a Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Fri, 12 Jun 2026 17:18:21 +0200 Subject: [PATCH 10/10] Fix diagms --- src/pushforwards/svd.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pushforwards/svd.jl b/src/pushforwards/svd.jl index 307d8cb76..9acca17f3 100644 --- a/src/pushforwards/svd.jl +++ b/src/pushforwards/svd.jl @@ -50,8 +50,8 @@ function svd_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ, ind = Colon(); rank_atol = d ∂U .+= Uperp * K̇perp ∂V .+= Vᴴperp * Ṁperp else - ImUU = (LinearAlgebra.diagm(ones(eltype(U), m)) - vU * vU') - ImVV = (LinearAlgebra.diagm(ones(eltype(Vᴴ), n)) - vV * vVᴴ) + ImUU = (LinearAlgebra.diagm(one!(similar(U, m))) - vU * vU') + ImVV = (LinearAlgebra.diagm(one!(similar(Vᴴ, n))) - vV * vVᴴ) upper = ImUU * ΔA * vV lower = ImVV * ΔA' * vU rhs = vcat(upper, lower)