Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
Expand All @@ -31,6 +33,8 @@ cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
TensorKitAMDGPUExt = "AMDGPU"
TensorKitCUDAExt = ["CUDA", "cuTENSOR"]
TensorKitChainRulesCoreExt = "ChainRulesCore"
TensorKitEnzymeExt = "Enzyme"
TensorKitEnzymeTestUtilsExt = "EnzymeTestUtils"
TensorKitFiniteDifferencesExt = "FiniteDifferences"
TensorKitMooncakeExt = "Mooncake"

Expand All @@ -43,19 +47,21 @@ AMDGPU = "2"
CUDA = "6"
ChainRulesCore = "1"
Dictionaries = "0.4"
Enzyme = "0.13.157"
EnzymeTestUtils = "0.2.7"
FiniteDifferences = "0.12"
LRUCache = "1.0.2"
LinearAlgebra = "1"
MatrixAlgebraKit = "0.6.7"
MatrixAlgebraKit = "0.6.8"
Mooncake = "0.5.27"
OhMyThreads = "0.8.0"
Printf = "1"
Random = "1"
ScopedValues = "1.3.0"
Strided = "2"
TensorKitSectors = "0.3.7"
TensorOperations = "5.5"
TensorOperations = "5.5.2"
TupleTools = "1.5"
VectorInterface = "0.4.8, 0.5, 0.6"
VectorInterface = "0.4.8, 0.5"
cuTENSOR = "6"
julia = "1.10"
16 changes: 16 additions & 0 deletions ext/TensorKitEnzymeExt/TensorKitEnzymeExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
module TensorKitEnzymeExt

using Enzyme
using TensorKit
import TensorKit as TK
using VectorInterface
using TensorOperations: TensorOperations, IndexTuple, Index2Tuple, linearize
import TensorOperations as TO
using MatrixAlgebraKit
using TupleTools
using Random: AbstractRNG

include("utility.jl")
include("tensoroperations.jl")

end
274 changes: 274 additions & 0 deletions ext/TensorKitEnzymeExt/tensoroperations.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,274 @@
# tensorcontract!
# ---------------
# TODO: it might be beneficial to compare here if it would make sense to simply compute the
# rrule of permute-permute-gemm-permute, rather than using the contractions directly.
# This could possibly out save some permutations being carried out twice, at the cost of having
# to store some more intermediate objects.
# For example, the combination `ΔC, pΔC, false` appears in the pullback for ΔA and ΔB, so effectively
# this permutation is done multiple times.

function EnzymeRules.augmented_primal(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof(TensorKit.blas_contract!)},
::Type{RT},
C::Annotation{<:AbstractTensorMap},
A::Annotation{<:AbstractTensorMap},
pA::Const{<:Index2Tuple},
B::Annotation{<:AbstractTensorMap},
pB::Const{<:Index2Tuple},
pAB::Const{<:Index2Tuple},
α::Annotation{<:Number},
β::Annotation{<:Number},
backend::Const,
allocator::Const
) where {RT}
Ccache = isa(β, Const) ? nothing : copy(C.val)
A_needs_cache = EnzymeRules.overwritten(config)[3] && !(typeof(B) <: Const) && !(typeof(C) <: Const)
Acache = A_needs_cache ? copy(A.val) : nothing
B_needs_cache = EnzymeRules.overwritten(config)[5] && !(typeof(A) <: Const) && !(typeof(C) <: Const)
Bcache = B_needs_cache ? copy(B.val) : nothing
AB = if !isa(α, Const)
AB = TO.tensorcontract(A.val, pA.val, false, B.val, pB.val, false, pAB.val, One(), backend.val, allocator.val)
add!(C.val, AB, α.val, β.val)
AB
else
TensorKit.blas_contract!(C.val, A.val, pA.val, B.val, pB.val, pAB.val, α.val, β.val, backend.val, allocator.val)
nothing
end
primal = EnzymeRules.needs_primal(config) ? C.val : nothing
shadow = EnzymeRules.needs_shadow(config) ? C.dval : nothing
cache = (Ccache, Acache, Bcache, AB)
return EnzymeRules.AugmentedReturn(primal, shadow, cache)
end

function EnzymeRules.reverse(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof(TensorKit.blas_contract!)},
::Type{RT},
cache,
C::Annotation{<:AbstractTensorMap},
A::Annotation{<:AbstractTensorMap},
pA::Const{<:Index2Tuple},
B::Annotation{<:AbstractTensorMap},
pB::Const{<:Index2Tuple},
pAB::Const{<:Index2Tuple},
α::Annotation{<:Number},
β::Annotation{<:Number},
backend::Const,
allocator::Const
) where {RT}
cacheC, cacheA, cacheB, AB = cache
Cval = cacheC
Aval = something(cacheA, A.val)
Bval = something(cacheB, B.val)

Δα = pullback_dα(α, C, AB)
Δβ = pullback_dβ(β, C, Cval)

if !isa(A, Const)
blas_contract_pullback_ΔA!(
A.dval, C.dval, Aval, pA.val, Bval, pB.val, pAB.val, α.val, backend.val, allocator.val
) # this typically returns nothing
end
if !isa(B, Const)
blas_contract_pullback_ΔB!(
B.dval, C.dval, Aval, pA.val, Bval, pB.val, pAB.val, α.val, backend.val, allocator.val
) # this typically returns nothing
end
!isa(C, Const) && pullback_dC!(C.dval, β.val) # this typically returns nothing
return nothing, nothing, nothing, nothing, nothing, nothing, Δα, Δβ, nothing, nothing
end

function EnzymeRules.forward(
config::EnzymeRules.FwdConfigWidth{1},
func::Const{typeof(TensorKit.blas_contract!)},
::Type{RT},
C::Annotation{<:AbstractTensorMap},
A::Annotation{<:AbstractTensorMap},
pA::Annotation{<:Index2Tuple},
B::Annotation{<:AbstractTensorMap},
pB::Annotation{<:Index2Tuple},
pAB::Annotation{<:Index2Tuple},
α::Annotation{<:Number},
β::Annotation{<:Number},
backend::Const,
allocator::Const
) where {RT}
# ΔC′ = ΔC*β + C*Δβ + A*B*Δα + ΔA*B*α + A*ΔB*α
if !isa(C, Const)
if isa(β, Const)
scale!(C.dval, β.val)
else
add!(C.dval, C.val, β.dval, β.val)
end
!isa(α, Const) && TensorKit.blas_contract!(C.dval, A.val, pA.val, B.val, pB.val, pAB.val, α.dval, One(), backend.val, allocator.val)
!isa(A, Const) && TensorKit.blas_contract!(C.dval, A.dval, pA.val, B.val, pB.val, pAB.val, α.val, One(), backend.val, allocator.val)
!isa(B, Const) && TensorKit.blas_contract!(C.dval, A.val, pA.val, B.dval, pB.val, pAB.val, α.val, One(), backend.val, allocator.val)
end
TensorKit.blas_contract!(C.val, A.val, pA.val, B.val, pB.val, pAB.val, α.val, β.val, backend.val, allocator.val)
if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
return C
elseif EnzymeRules.needs_primal(config)
return C.val
elseif EnzymeRules.needs_shadow(config)
return C.dval
else
return nothing
end
end

function blas_contract_pullback_ΔA!(
ΔA, ΔC, A, pA, B, pB, pAB, α, backend, allocator
)
ipAB = invperm(linearize(pAB))
pΔC = _repartition(ipAB, TO.numout(pA))
ipA = _repartition(invperm(linearize(pA)), A)

tB = twist(
B,
TupleTools.vcat(
filter(x -> !isdual(space(B, x)), pB[1]),
filter(x -> isdual(space(B, x)), pB[2])
); copy = false
)

project_contract!(
ΔA,
ΔC, pΔC, false,
tB, reverse(pB), true,
ipA, conj(α), backend, allocator
)

return nothing
end

function blas_contract_pullback_ΔB!(
ΔB, ΔC, A, pA, B, pB, pAB, α, backend, allocator
)
ipAB = invperm(linearize(pAB))
pΔC = _repartition(ipAB, TO.numout(pA))
ipB = _repartition(invperm(linearize(pB)), B)

tA = twist(
A,
TupleTools.vcat(
filter(x -> isdual(space(A, x)), pA[1]),
filter(x -> !isdual(space(A, x)), pA[2])
); copy = false
)

project_contract!(
ΔB,
tA, reverse(pA), true,
ΔC, pΔC, false,
ipB, conj(α), backend, allocator
)

return nothing
end


# tensortrace!
# ------------

function EnzymeRules.augmented_primal(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof(TensorKit.trace_permute!)},
::Type{RT},
C::Annotation{<:AbstractTensorMap},
A::Annotation{<:AbstractTensorMap},
p::Const{<:Index2Tuple},
q::Const{<:Index2Tuple},
α::Annotation{<:Number},
β::Annotation{<:Number},
backend::Const,
) where {RT}
C_cache = !isa(β, Const) ? copy(C.val) : nothing
A_cache = EnzymeRules.overwritten(config)[3] ? copy(A.val) : nothing
At = if !isa(α, Const)
At = TO.tensortrace(A.val, p.val, q.val, false, One(), backend.val)
add!(C.val, At, α.val, β.val)
At
else
TensorKit.trace_permute!(C.val, A.val, p.val, q.val, α.val, β.val, backend.val)
nothing
end
primal = EnzymeRules.needs_primal(config) ? C.val : nothing
shadow = EnzymeRules.needs_shadow(config) ? C.dval : nothing
cache = (C_cache, A_cache, At)
return EnzymeRules.AugmentedReturn(primal, shadow, cache)
end


function EnzymeRules.reverse(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof(TensorKit.trace_permute!)},
::Type{RT},
cache,
C::Annotation{<:AbstractTensorMap},
A::Annotation{<:AbstractTensorMap},
p::Const{<:Index2Tuple},
q::Const{<:Index2Tuple},
α::Annotation{<:Number},
β::Annotation{<:Number},
backend::Const,
) where {RT}
C_cache, A_cache, At = cache
Aval = something(A_cache, A.val)
Cval = something(C_cache, C.val)
!isa(A, Const) && !isa(C, Const) && trace_permute_pullback_ΔA!(A.dval, C.dval, Aval, p.val, q.val, α.val, backend.val)
Δαr = pullback_dα(α, C, At)
Δβr = pullback_dβ(β, C, Cval)
!isa(C, Const) && pullback_dC!(C.dval, β.val)
return nothing, nothing, nothing, nothing, Δαr, Δβr, nothing
end

function trace_permute_pullback_ΔA!(
ΔA, ΔC, A, p, q, α, backend
)
ip = invperm((linearize(p)..., q[1]..., q[2]...))
pdA = _repartition(ip, A)
E = one!(TO.tensoralloc_add(scalartype(A), A, q, false))
twist!(E, filter(x -> !isdual(space(E, x)), codomainind(E)))
pE = ((), trivtuple(TO.numind(q)))
pΔC = (trivtuple(TO.numind(p)), ())
TO.tensorproduct!(
ΔA, ΔC, pΔC, false, E, pE, false, pdA, conj(α), One(), backend
)
return nothing
end

function EnzymeRules.forward(
config::EnzymeRules.FwdConfigWidth{1},
func::Const{typeof(TensorKit.trace_permute!)},
::Type{RT},
C::Annotation{<:AbstractTensorMap},
A::Annotation{<:AbstractTensorMap},
p::Annotation{<:Index2Tuple},
q::Annotation{<:Index2Tuple},
α::Annotation{<:Number},
β::Annotation{<:Number},
backend::Const,
) where {RT}
# dD = dα * tr(A) + α * tr(dA) + dβ * C + β * dC
# dC1 = dβ * C + β * dC
if !isa(C, Const)
if isa(β, Const)
scale!(C.dval, β.val)
else
add!(C.dval, C.val, β.dval, β.val)
end
!isa(α, Const) && TensorKit.trace_permute!(C.dval, A.val, p.val, q.val, α.dval, One(), backend.val)
!isa(A, Const) && TensorKit.trace_permute!(C.dval, A.dval, p.val, q.val, α.val, One(), backend.val)
end
TensorKit.trace_permute!(C.val, A.val, p.val, q.val, α.val, β.val, backend.val)
if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
return C
elseif EnzymeRules.needs_primal(config)
return C.val
elseif EnzymeRules.needs_shadow(config)
return C.dval
else
return nothing
end
end
Loading
Loading