diff --git a/Project.toml b/Project.toml index 3d0abc9b1..a554735cf 100644 --- a/Project.toml +++ b/Project.toml @@ -23,6 +23,8 @@ VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8" AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1" @@ -31,6 +33,8 @@ cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1" TensorKitAMDGPUExt = "AMDGPU" TensorKitCUDAExt = ["CUDA", "cuTENSOR"] TensorKitChainRulesCoreExt = "ChainRulesCore" +TensorKitEnzymeExt = "Enzyme" +TensorKitEnzymeTestUtilsExt = "EnzymeTestUtils" TensorKitFiniteDifferencesExt = "FiniteDifferences" TensorKitMooncakeExt = "Mooncake" @@ -43,10 +47,12 @@ AMDGPU = "2" CUDA = "6" ChainRulesCore = "1" Dictionaries = "0.4" +Enzyme = "0.13.157" +EnzymeTestUtils = "0.2.7" FiniteDifferences = "0.12" LRUCache = "1.0.2" LinearAlgebra = "1" -MatrixAlgebraKit = "0.6.7" +MatrixAlgebraKit = "0.6.8" Mooncake = "0.5.27" OhMyThreads = "0.8.0" Printf = "1" @@ -54,8 +60,8 @@ Random = "1" ScopedValues = "1.3.0" Strided = "2" TensorKitSectors = "0.3.7" -TensorOperations = "5.5" +TensorOperations = "5.5.2" TupleTools = "1.5" -VectorInterface = "0.4.8, 0.5, 0.6" +VectorInterface = "0.4.8, 0.5" cuTENSOR = "6" julia = "1.10" diff --git a/ext/TensorKitEnzymeExt/TensorKitEnzymeExt.jl b/ext/TensorKitEnzymeExt/TensorKitEnzymeExt.jl new file mode 100644 index 000000000..54bf1acb9 --- /dev/null +++ b/ext/TensorKitEnzymeExt/TensorKitEnzymeExt.jl @@ -0,0 +1,16 @@ +module TensorKitEnzymeExt + +using Enzyme +using TensorKit +import TensorKit as TK +using VectorInterface +using TensorOperations: TensorOperations, IndexTuple, Index2Tuple, linearize +import TensorOperations as TO +using MatrixAlgebraKit +using TupleTools +using Random: AbstractRNG + +include("utility.jl") +include("tensoroperations.jl") + +end diff --git a/ext/TensorKitEnzymeExt/tensoroperations.jl b/ext/TensorKitEnzymeExt/tensoroperations.jl new file mode 100644 index 000000000..1f8dd799b --- /dev/null +++ b/ext/TensorKitEnzymeExt/tensoroperations.jl @@ -0,0 +1,274 @@ +# tensorcontract! +# --------------- +# TODO: it might be beneficial to compare here if it would make sense to simply compute the +# rrule of permute-permute-gemm-permute, rather than using the contractions directly. +# This could possibly out save some permutations being carried out twice, at the cost of having +# to store some more intermediate objects. +# For example, the combination `ΔC, pΔC, false` appears in the pullback for ΔA and ΔB, so effectively +# this permutation is done multiple times. + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(TensorKit.blas_contract!)}, + ::Type{RT}, + C::Annotation{<:AbstractTensorMap}, + A::Annotation{<:AbstractTensorMap}, + pA::Const{<:Index2Tuple}, + B::Annotation{<:AbstractTensorMap}, + pB::Const{<:Index2Tuple}, + pAB::Const{<:Index2Tuple}, + α::Annotation{<:Number}, + β::Annotation{<:Number}, + backend::Const, + allocator::Const + ) where {RT} + Ccache = isa(β, Const) ? nothing : copy(C.val) + A_needs_cache = EnzymeRules.overwritten(config)[3] && !(typeof(B) <: Const) && !(typeof(C) <: Const) + Acache = A_needs_cache ? copy(A.val) : nothing + B_needs_cache = EnzymeRules.overwritten(config)[5] && !(typeof(A) <: Const) && !(typeof(C) <: Const) + Bcache = B_needs_cache ? copy(B.val) : nothing + AB = if !isa(α, Const) + AB = TO.tensorcontract(A.val, pA.val, false, B.val, pB.val, false, pAB.val, One(), backend.val, allocator.val) + add!(C.val, AB, α.val, β.val) + AB + else + TensorKit.blas_contract!(C.val, A.val, pA.val, B.val, pB.val, pAB.val, α.val, β.val, backend.val, allocator.val) + nothing + end + primal = EnzymeRules.needs_primal(config) ? C.val : nothing + shadow = EnzymeRules.needs_shadow(config) ? C.dval : nothing + cache = (Ccache, Acache, Bcache, AB) + return EnzymeRules.AugmentedReturn(primal, shadow, cache) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(TensorKit.blas_contract!)}, + ::Type{RT}, + cache, + C::Annotation{<:AbstractTensorMap}, + A::Annotation{<:AbstractTensorMap}, + pA::Const{<:Index2Tuple}, + B::Annotation{<:AbstractTensorMap}, + pB::Const{<:Index2Tuple}, + pAB::Const{<:Index2Tuple}, + α::Annotation{<:Number}, + β::Annotation{<:Number}, + backend::Const, + allocator::Const + ) where {RT} + cacheC, cacheA, cacheB, AB = cache + Cval = cacheC + Aval = something(cacheA, A.val) + Bval = something(cacheB, B.val) + + Δα = pullback_dα(α, C, AB) + Δβ = pullback_dβ(β, C, Cval) + + if !isa(A, Const) + blas_contract_pullback_ΔA!( + A.dval, C.dval, Aval, pA.val, Bval, pB.val, pAB.val, α.val, backend.val, allocator.val + ) # this typically returns nothing + end + if !isa(B, Const) + blas_contract_pullback_ΔB!( + B.dval, C.dval, Aval, pA.val, Bval, pB.val, pAB.val, α.val, backend.val, allocator.val + ) # this typically returns nothing + end + !isa(C, Const) && pullback_dC!(C.dval, β.val) # this typically returns nothing + return nothing, nothing, nothing, nothing, nothing, nothing, Δα, Δβ, nothing, nothing +end + +function EnzymeRules.forward( + config::EnzymeRules.FwdConfigWidth{1}, + func::Const{typeof(TensorKit.blas_contract!)}, + ::Type{RT}, + C::Annotation{<:AbstractTensorMap}, + A::Annotation{<:AbstractTensorMap}, + pA::Annotation{<:Index2Tuple}, + B::Annotation{<:AbstractTensorMap}, + pB::Annotation{<:Index2Tuple}, + pAB::Annotation{<:Index2Tuple}, + α::Annotation{<:Number}, + β::Annotation{<:Number}, + backend::Const, + allocator::Const + ) where {RT} + # ΔC′ = ΔC*β + C*Δβ + A*B*Δα + ΔA*B*α + A*ΔB*α + if !isa(C, Const) + if isa(β, Const) + scale!(C.dval, β.val) + else + add!(C.dval, C.val, β.dval, β.val) + end + !isa(α, Const) && TensorKit.blas_contract!(C.dval, A.val, pA.val, B.val, pB.val, pAB.val, α.dval, One(), backend.val, allocator.val) + !isa(A, Const) && TensorKit.blas_contract!(C.dval, A.dval, pA.val, B.val, pB.val, pAB.val, α.val, One(), backend.val, allocator.val) + !isa(B, Const) && TensorKit.blas_contract!(C.dval, A.val, pA.val, B.dval, pB.val, pAB.val, α.val, One(), backend.val, allocator.val) + end + TensorKit.blas_contract!(C.val, A.val, pA.val, B.val, pB.val, pAB.val, α.val, β.val, backend.val, allocator.val) + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + return C + elseif EnzymeRules.needs_primal(config) + return C.val + elseif EnzymeRules.needs_shadow(config) + return C.dval + else + return nothing + end +end + +function blas_contract_pullback_ΔA!( + ΔA, ΔC, A, pA, B, pB, pAB, α, backend, allocator + ) + ipAB = invperm(linearize(pAB)) + pΔC = _repartition(ipAB, TO.numout(pA)) + ipA = _repartition(invperm(linearize(pA)), A) + + tB = twist( + B, + TupleTools.vcat( + filter(x -> !isdual(space(B, x)), pB[1]), + filter(x -> isdual(space(B, x)), pB[2]) + ); copy = false + ) + + project_contract!( + ΔA, + ΔC, pΔC, false, + tB, reverse(pB), true, + ipA, conj(α), backend, allocator + ) + + return nothing +end + +function blas_contract_pullback_ΔB!( + ΔB, ΔC, A, pA, B, pB, pAB, α, backend, allocator + ) + ipAB = invperm(linearize(pAB)) + pΔC = _repartition(ipAB, TO.numout(pA)) + ipB = _repartition(invperm(linearize(pB)), B) + + tA = twist( + A, + TupleTools.vcat( + filter(x -> isdual(space(A, x)), pA[1]), + filter(x -> !isdual(space(A, x)), pA[2]) + ); copy = false + ) + + project_contract!( + ΔB, + tA, reverse(pA), true, + ΔC, pΔC, false, + ipB, conj(α), backend, allocator + ) + + return nothing +end + + +# tensortrace! +# ------------ + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(TensorKit.trace_permute!)}, + ::Type{RT}, + C::Annotation{<:AbstractTensorMap}, + A::Annotation{<:AbstractTensorMap}, + p::Const{<:Index2Tuple}, + q::Const{<:Index2Tuple}, + α::Annotation{<:Number}, + β::Annotation{<:Number}, + backend::Const, + ) where {RT} + C_cache = !isa(β, Const) ? copy(C.val) : nothing + A_cache = EnzymeRules.overwritten(config)[3] ? copy(A.val) : nothing + At = if !isa(α, Const) + At = TO.tensortrace(A.val, p.val, q.val, false, One(), backend.val) + add!(C.val, At, α.val, β.val) + At + else + TensorKit.trace_permute!(C.val, A.val, p.val, q.val, α.val, β.val, backend.val) + nothing + end + primal = EnzymeRules.needs_primal(config) ? C.val : nothing + shadow = EnzymeRules.needs_shadow(config) ? C.dval : nothing + cache = (C_cache, A_cache, At) + return EnzymeRules.AugmentedReturn(primal, shadow, cache) +end + + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(TensorKit.trace_permute!)}, + ::Type{RT}, + cache, + C::Annotation{<:AbstractTensorMap}, + A::Annotation{<:AbstractTensorMap}, + p::Const{<:Index2Tuple}, + q::Const{<:Index2Tuple}, + α::Annotation{<:Number}, + β::Annotation{<:Number}, + backend::Const, + ) where {RT} + C_cache, A_cache, At = cache + Aval = something(A_cache, A.val) + Cval = something(C_cache, C.val) + !isa(A, Const) && !isa(C, Const) && trace_permute_pullback_ΔA!(A.dval, C.dval, Aval, p.val, q.val, α.val, backend.val) + Δαr = pullback_dα(α, C, At) + Δβr = pullback_dβ(β, C, Cval) + !isa(C, Const) && pullback_dC!(C.dval, β.val) + return nothing, nothing, nothing, nothing, Δαr, Δβr, nothing +end + +function trace_permute_pullback_ΔA!( + ΔA, ΔC, A, p, q, α, backend + ) + ip = invperm((linearize(p)..., q[1]..., q[2]...)) + pdA = _repartition(ip, A) + E = one!(TO.tensoralloc_add(scalartype(A), A, q, false)) + twist!(E, filter(x -> !isdual(space(E, x)), codomainind(E))) + pE = ((), trivtuple(TO.numind(q))) + pΔC = (trivtuple(TO.numind(p)), ()) + TO.tensorproduct!( + ΔA, ΔC, pΔC, false, E, pE, false, pdA, conj(α), One(), backend + ) + return nothing +end + +function EnzymeRules.forward( + config::EnzymeRules.FwdConfigWidth{1}, + func::Const{typeof(TensorKit.trace_permute!)}, + ::Type{RT}, + C::Annotation{<:AbstractTensorMap}, + A::Annotation{<:AbstractTensorMap}, + p::Annotation{<:Index2Tuple}, + q::Annotation{<:Index2Tuple}, + α::Annotation{<:Number}, + β::Annotation{<:Number}, + backend::Const, + ) where {RT} + # dD = dα * tr(A) + α * tr(dA) + dβ * C + β * dC + # dC1 = dβ * C + β * dC + if !isa(C, Const) + if isa(β, Const) + scale!(C.dval, β.val) + else + add!(C.dval, C.val, β.dval, β.val) + end + !isa(α, Const) && TensorKit.trace_permute!(C.dval, A.val, p.val, q.val, α.dval, One(), backend.val) + !isa(A, Const) && TensorKit.trace_permute!(C.dval, A.dval, p.val, q.val, α.val, One(), backend.val) + end + TensorKit.trace_permute!(C.val, A.val, p.val, q.val, α.val, β.val, backend.val) + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + return C + elseif EnzymeRules.needs_primal(config) + return C.val + elseif EnzymeRules.needs_shadow(config) + return C.dval + else + return nothing + end +end diff --git a/ext/TensorKitEnzymeExt/utility.jl b/ext/TensorKitEnzymeExt/utility.jl new file mode 100644 index 000000000..03ade424a --- /dev/null +++ b/ext/TensorKitEnzymeExt/utility.jl @@ -0,0 +1,80 @@ +# Projection +# ---------- +pullback_dα(α::Const, C::Const, A) = nothing +pullback_dα(α::Const, C::Annotation, A) = nothing +pullback_dα(α::Annotation, C::Const, A) = zero(α.val) +pullback_dα(α::Annotation, C::Annotation, A) = project_scalar(α.val, inner(A, C.dval)) + +pullback_dβ(β::Const, C::Const, Ccache) = nothing +pullback_dβ(β::Const, C::Annotation, Ccache) = nothing +pullback_dβ(β::Annotation, C::Const, Ccache) = zero(β.val) +pullback_dβ(β::Annotation, C::Annotation, Ccache) = project_scalar(β.val, inner(Ccache, C.dval)) + +pullback_dC!(ΔC, β::Number) = scale!(ΔC, conj(β)) + +""" + project_scalar(x::Number, dx::Number) + +Project a computed tangent `dx` onto the correct tangent type for `x`. +For example, we might compute a complex `dx` but only require the real part. +""" +project_scalar(x::Number, dx::Number) = oftype(x, dx) +project_scalar(x::Real, dx::Complex) = project_scalar(x, real(dx)) + +# in-place multiplication and accumulation which might project to (real) +# TODO: this could probably be done without allocating +function project_mul!(C, A, B, α) + TC = TO.promote_contract(scalartype(A), scalartype(B), scalartype(α)) + return if !(TC <: Real) && scalartype(C) <: Real + add!(C, real(mul!(zerovector(C, TC), A, B, α))) + else + mul!(C, A, B, α, One()) + end +end +function project_contract!(C, A, pA, conjA, B, pB, conjB, pAB, α, backend, allocator) + TA = TensorKit.promote_permute(A) + TB = TensorKit.promote_permute(B) + TC = TO.promote_contract(TA, TB, scalartype(α)) + + return if scalartype(C) <: Real && !(TC <: Real) + add!(C, real(TO.tensorcontract!(zerovector(C, TC), A, pA, conjA, B, pB, conjB, pAB, α, Zero(), backend, allocator))) + else + TO.tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, One(), backend, allocator) + end +end + +# IndexTuple utility +# ------------------ +trivtuple(N) = ntuple(identity, N) + +Base.@constprop :aggressive function _repartition(p::IndexTuple, N₁::Int) + length(p) >= N₁ || + throw(ArgumentError("cannot repartition $(typeof(p)) to $N₁, $(length(p) - N₁)")) + return TupleTools.getindices(p, trivtuple(N₁)), + TupleTools.getindices(p, trivtuple(length(p) - N₁) .+ N₁) +end +Base.@constprop :aggressive function _repartition(p::Index2Tuple, N₁::Int) + return _repartition(linearize(p), N₁) +end +function _repartition(p::Union{IndexTuple, Index2Tuple}, ::Index2Tuple{N₁}) where {N₁} + return _repartition(p, N₁) +end +function _repartition(p::Union{IndexTuple, Index2Tuple}, t::AbstractTensorMap) + return _repartition(p, TensorKit.numout(t)) +end + +# Ignore derivatives +# ------------------ + +@inline EnzymeRules.inactive_type(::Type{<:TensorKit.FusionTree}) = true +@inline EnzymeRules.inactive_type(::Type{<:TensorKit.GenericTreeTransformer}) = true +@inline EnzymeRules.inactive_type(::Type{<:TensorKit.VectorSpace}) = true + +@inline EnzymeRules.inactive(::typeof(TensorKit.sectorstructure), ::Any) = nothing +@inline EnzymeRules.inactive(::typeof(TensorKit.degeneracystructure), ::Any) = nothing +@inline EnzymeRules.inactive(::typeof(TensorKit.select), s::HomSpace, i::Index2Tuple) = nothing +@inline EnzymeRules.inactive(::typeof(TensorKit.flip), s::HomSpace, i::Any) = nothing +@inline EnzymeRules.inactive(::typeof(TensorKit.permute), s::HomSpace, i::Index2Tuple) = nothing +@inline EnzymeRules.inactive(::typeof(TensorKit.braid), s::HomSpace, i::Index2Tuple, ::IndexTuple) = nothing +@inline EnzymeRules.inactive(::typeof(TensorKit.compose), s1::HomSpace, s2::HomSpace) = nothing +@inline EnzymeRules.inactive(::typeof(TensorOperations.tensorcontract), c::HomSpace, p::Index2Tuple, α::Bool, b::HomSpace, q::Index2Tuple, β::Bool, pq::Index2Tuple) = nothing diff --git a/ext/TensorKitEnzymeTestUtilsExt.jl b/ext/TensorKitEnzymeTestUtilsExt.jl new file mode 100644 index 000000000..4a1f393b1 --- /dev/null +++ b/ext/TensorKitEnzymeTestUtilsExt.jl @@ -0,0 +1,66 @@ +module TensorKitEnzymeTestUtilsExt + +using TensorKit +using EnzymeTestUtils +using EnzymeTestUtils: Enzyme +import EnzymeTestUtils: to_vec, from_vec, rand_tangent + +function EnzymeTestUtils.to_vec(x::TensorMap, seen_vecs::EnzymeTestUtils.AliasDict) + has_seen = haskey(seen_vecs, x) + is_const = Enzyme.Compiler.guaranteed_const(Core.Typeof(x)) + if has_seen || is_const + x_vec = Float32[] + else + vec_of_vecs = [b * TensorKit.sqrtdim(c) for (c, b) in blocks(x)] + x_vec, back = to_vec(vec_of_vecs) + seen_vecs[x] = x_vec + end + function TensorMap_from_vec(x_vec_new::AbstractVector, seen_xs::EnzymeTestUtils.AliasDict) + if xor(has_seen, haskey(seen_xs, x)) + throw(ErrorException("Arrays must be reconstructed in the same order as they are vectorized.")) + end + has_seen && return seen_xs[x] + is_const && return x + + x_new = similar(x) + xvec_of_vecs = back(x_vec_new) + for (i, (c, b)) in enumerate(blocks(x_new)) + scale!(b, xvec_of_vecs[i], TensorKit.invsqrtdim(c)) + end + if Core.Typeof(x_new) != Core.Typeof(x) + x_new = Core.Typeof(x)(x_new) + end + seen_xs[x] = x_new + return x_new + end + return x_vec, TensorMap_from_vec +end +function EnzymeTestUtils.to_vec(t::TensorKit.AdjointTensorMap, seen_vecs::EnzymeTestUtils.AliasDict) + parent_vec, parent_t = to_vec(parent(t), seen_vecs) + return parent_vec, adjoint ∘ parent_t +end +function EnzymeTestUtils.to_vec(t::TensorKit.DiagonalTensorMap, seen_vecs::EnzymeTestUtils.AliasDict) + parent_vec, parent_t = to_vec(TensorMap(t), seen_vecs) + return parent_vec, TensorKit.DiagonalTensorMap ∘ parent_t +end + +# generate random tangents for testing +function EnzymeTestUtils.rand_tangent(rng, t::TensorMap) + return TensorMap(EnzymeTestUtils.rand_tangent(rng, t.data), space(t)) +end + +function EnzymeTestUtils.rand_tangent(rng, t::TensorKit.AdjointTensorMap) + return adjoint(rand_tangent(rng, parent(t))) +end + +function EnzymeTestUtils.rand_tangent(rng, t::DiagonalTensorMap) + return DiagonalTensorMap(EnzymeTestUtils.rand_tangent(rng, t.data), space(t, 1)) +end + +function EnzymeTestUtils.map_fields_recursive(f::typeof(Base.copyto!), y::TensorKit.SortedVectorDict{K, V}, x::TensorKit.SortedVectorDict{K, V}) where {K, V} + copyto!(y.keys, x.keys) + copyto!(y.values, x.values) + return y +end + +end diff --git a/test/Project.toml b/test/Project.toml index 18af8af80..5252ff1f4 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -9,6 +9,8 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" diff --git a/test/enzyme-tensoroperations/contract.jl b/test/enzyme-tensoroperations/contract.jl new file mode 100644 index 000000000..064bfba74 --- /dev/null +++ b/test/enzyme-tensoroperations/contract.jl @@ -0,0 +1,138 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using VectorInterface: One, Zero +using Enzyme, EnzymeTestUtils + +is_ci = get(ENV, "CI", "false") == "true" + +spacelist = ad_spacelist(fast_tests) +eltypes = (Float64, ComplexF64) + +@timedtestset "Enzyme - TensorOperations" begin + @timedtestset verbose = true "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes + atol = default_tol(T) + rtol = default_tol(T) + symmetricbraiding = BraidingStyle(sectortype(eltype(V))) isa SymmetricBraiding + symmetricbraiding && @timedtestset "tensorcontract!" begin + d = 0 + local V1, V2, V3 + # retry a couple times to make sure there are at least some nonzero elements + for _ in 1:10 + k1 = rand(0:3) + k2 = rand(0:2) + k3 = rand(0:2) + V1 = prod(v -> rand(Bool) ? v' : v, rand(V, k1); init = one(V[1])) + V2 = prod(v -> rand(Bool) ? v' : v, rand(V, k2); init = one(V[1])) + V3 = prod(v -> rand(Bool) ? v' : v, rand(V, k3); init = one(V[1])) + d = min(dim(V1 ← V2), dim(V1' ← V2), dim(V2 ← V3), dim(V2' ← V3)) + d > 0 && break + end + ipA = randindextuple(length(V1) + length(V2)) + pA = _repartition(invperm(linearize(ipA)), length(V1)) + ipB = randindextuple(length(V2) + length(V3)) + pB = _repartition(invperm(linearize(ipB)), length(V2)) + pAB = randindextuple(length(V1) + length(V3)) + + α = randn(T) + β = randn(T) + V2_conj = prod(conj, V2; init = one(V[1])) + A = randn(T, permute(V1 ← V2, ipA)) + B = randn(T, permute(V2 ← V3, ipB)) + C = randn!( + TensorOperations.tensoralloc_contract( + T, A, pA, false, B, pB, false, pAB, Val(false) + ) + ) + + αβs = is_ci ? (((α, Active), (β, Active)),) : Iterators.product(((One(), Const), (α, Const), (α, Active)), ((Zero(), Const), (β, Const), (β, Active))) + for (α_, β_) in αβs + EnzymeTestUtils.test_reverse( + TensorKit.blas_contract!, Duplicated, + (copy(C), Duplicated), (A, Duplicated), (pA, Const), + (B, Duplicated), (pB, Const), (pAB, Const), + α_, β_, + (TensorOperations.DefaultBackend(), Const), + (TensorOperations.DefaultAllocator(), Const); + atol, rtol, + testset_name = "blas_contract! reverse α $α_ β $β_", + ) + end + αβs = is_ci ? (((α, Duplicated), (β, Duplicated)),) : Iterators.product(((One(), Const), (α, Const), (α, Duplicated)), ((Zero(), Const), (β, Const), (β, Duplicated))) + for (α_, β_) in αβs + EnzymeTestUtils.test_forward( + TensorKit.blas_contract!, Duplicated, + (copy(C), Duplicated), (A, Duplicated), (pA, Const), + (B, Duplicated), (pB, Const), (pAB, Const), + α_, β_, + (TensorOperations.DefaultBackend(), Const), + (TensorOperations.DefaultAllocator(), Const); + atol, rtol, + testset_name = "blas_contract! forward α $α_ β $β_", + ) + end + if !(T <: Real) && !is_ci + EnzymeTestUtils.test_reverse( + TensorKit.blas_contract!, Duplicated, + (copy(C), Duplicated), (A, Duplicated), (pA, Const), + (B, Duplicated), (pB, Const), (pAB, Const), + (real(α), Active), (real(β), Active), + (TensorOperations.DefaultBackend(), Const), + (TensorOperations.DefaultAllocator(), Const); + atol, rtol, + testset_name = "blas_contract! reverse real(α) real(β)", + ) + EnzymeTestUtils.test_reverse( + TensorKit.blas_contract!, Duplicated, + (copy(C), Duplicated), (real(A), Duplicated), (pA, Const), + (B, Duplicated), (pB, Const), (pAB, Const), + (real(α), Active), (real(β), Active), + (TensorOperations.DefaultBackend(), Const), + (TensorOperations.DefaultAllocator(), Const); + atol, rtol, + testset_name = "blas_contract! reverse real(A) real(α) real(β)", + ) + EnzymeTestUtils.test_reverse( + TensorKit.blas_contract!, Duplicated, + (copy(C), Duplicated), (A, Duplicated), (pA, Const), + (real(B), Duplicated), (pB, Const), (pAB, Const), + (real(α), Active), (real(β), Active), + (TensorOperations.DefaultBackend(), Const), + (TensorOperations.DefaultAllocator(), Const); + atol, rtol, + testset_name = "blas_contract! reverse real(B) real(α) real(β)", + ) + EnzymeTestUtils.test_forward( + TensorKit.blas_contract!, Duplicated, + (copy(C), Duplicated), (A, Duplicated), (pA, Const), + (B, Duplicated), (pB, Const), (pAB, Const), + (real(α), Active), (real(β), Active), + (TensorOperations.DefaultBackend(), Const), + (TensorOperations.DefaultAllocator(), Const); + atol, rtol, + testset_name = "blas_contract! forward real(α) real(β)", + ) + EnzymeTestUtils.test_forward( + TensorKit.blas_contract!, Duplicated, + (copy(C), Duplicated), (real(A), Duplicated), (pA, Const), + (B, Duplicated), (pB, Const), (pAB, Const), + (real(α), Active), (real(β), Active), + (TensorOperations.DefaultBackend(), Const), + (TensorOperations.DefaultAllocator(), Const); + atol, rtol, + testset_name = "blas_contract! forward real(A) real(α) real(β)", + ) + EnzymeTestUtils.test_forward( + TensorKit.blas_contract!, Duplicated, + (copy(C), Duplicated), (A, Duplicated), (pA, Const), + (real(B), Duplicated), (pB, Const), (pAB, Const), + (real(α), Active), (real(β), Active), + (TensorOperations.DefaultBackend(), Const), + (TensorOperations.DefaultAllocator(), Const); + atol, rtol, + testset_name = "blas_contract! forward real(B) real(α) real(β)", + ) + end + end + end +end diff --git a/test/enzyme-tensoroperations/trace.jl b/test/enzyme-tensoroperations/trace.jl new file mode 100644 index 000000000..958af2e11 --- /dev/null +++ b/test/enzyme-tensoroperations/trace.jl @@ -0,0 +1,60 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using VectorInterface: One, Zero +using Enzyme, EnzymeTestUtils + +spacelist = ad_spacelist(fast_tests) +eltypes = (Float64, ComplexF64) + +is_ci = get(ENV, "CI", "false") == "true" +rTαs = is_ci ? (Active,) : (Const, Active) +rTβs = is_ci ? (Active,) : (Const, Active) +fTαs = is_ci ? (Duplicated,) : (Const, Duplicated) +fTβs = is_ci ? (Duplicated,) : (Const, Duplicated) +TCs = is_ci ? (Duplicated,) : (Const, Duplicated) +TAs = is_ci ? (Duplicated,) : (Const, Duplicated) + +@timedtestset "Enzyme - TensorOperations (trace)" begin + @timedtestset verbose = true "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes + atol = default_tol(T) + rtol = default_tol(T) + symmetricbraiding = BraidingStyle(sectortype(eltype(V))) isa SymmetricBraiding + symmetricbraiding && @timedtestset "trace_permute!" begin + k1 = rand(0:2) + k2 = rand(1:2) + V1 = map(v -> rand(Bool) ? v' : v, rand(V, k1)) + V2 = map(v -> rand(Bool) ? v' : v, rand(V, k2)) + + (_p, _q) = randindextuple(k1 + 2 * k2, k1) + p = _repartition(_p, rand(0:k1)) + q = _repartition(_q, k2) + ip = _repartition(invperm(linearize((_p, _q))), rand(0:(k1 + 2 * k2))) + A = randn(T, permute(prod(V1) ⊗ prod(V2) ← prod(V2), ip)) + + α = randn(T) + β = randn(T) + C = randn!(TensorOperations.tensoralloc_add(T, A, p, false, Val(false))) + for TC in TCs, TA in TAs + for Tα in rTαs, Tβ in rTβs + EnzymeTestUtils.test_reverse( + TensorKit.trace_permute!, TC, + (copy(C), TC), (A, TA), (p, Const), (q, Const), + (α, Tα), (β, Tβ), (TensorOperations.DefaultBackend(), Const); + atol, rtol, + testset_name = "trace_permute! reverse TC $TC TA $TA Tα $Tα Tβ $Tβ", + ) + end + for Tα in fTαs, Tβ in fTβs + EnzymeTestUtils.test_forward( + TensorKit.trace_permute!, TC, + (copy(C), TC), (A, TA), (p, Const), (q, Const), + (α, Tα), (β, Tβ), (TensorOperations.DefaultBackend(), Const); + atol, rtol, + testset_name = "trace_permute! forward TC $TC TA $TA Tα $Tα Tβ $Tβ", + ) + end + end + end + end +end