diff --git a/ext/MatrixAlgebraKitGenericSchurExt.jl b/ext/MatrixAlgebraKitGenericSchurExt.jl index 4c54f3c0c..c34f1abc4 100644 --- a/ext/MatrixAlgebraKitGenericSchurExt.jl +++ b/ext/MatrixAlgebraKitGenericSchurExt.jl @@ -14,6 +14,13 @@ function MatrixAlgebraKit.default_eig_algorithm( return QRIteration(; driver, kwargs...) end +function MatrixAlgebraKit.default_exponential_algorithm( + type::Type{T}; kwargs... + ) where {T <: StridedMatrix{<:GSFloat}} + eig_alg = MatrixAlgebraKit.default_eig_algorithm(type; kwargs...) + return MatrixFunctionViaEig(eig_alg) +end + function geev!(::GS, A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...) D, Vmat = GenericSchur.eigen!(A) copyto!(Dd, D) diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index 65de152c4..20ccc1e8b 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -30,6 +30,7 @@ export left_polar, right_polar export left_polar!, right_polar! export left_orth, right_orth, left_null, right_null export left_orth!, right_orth!, left_null!, right_null! +export exponential, exponential!, exponentialr, exponentialr! export Householder, Native_HouseholderQR, Native_HouseholderLQ export DivideAndConquer, SafeDivideAndConquer, QRIteration, Bisection, Jacobi, SVDViaPolar @@ -40,6 +41,7 @@ export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LAPACK_Simple, LAPACK_Expert, export GLA_HouseholderQR, GLA_QRIteration, GS_QRIteration export LQViaTransposedQR export PolarViaSVD, PolarNewton +export MatrixFunctionViaLA, MatrixFunctionViaEig, MatrixFunctionViaEigh export DefaultAlgorithm export DiagonalAlgorithm export NativeBlocked @@ -95,9 +97,12 @@ include("common/matrixproperties.jl") include("yalapack.jl") include("algorithms.jl") + include("interface/projections.jl") include("interface/decompositions.jl") include("interface/truncation.jl") +include("interface/matrixfunctions.jl") + include("interface/qr.jl") include("interface/lq.jl") include("interface/svd.jl") @@ -107,6 +112,7 @@ include("interface/gen_eig.jl") include("interface/schur.jl") include("interface/polar.jl") include("interface/orthnull.jl") +include("interface/exponential.jl") include("implementations/projections.jl") include("implementations/truncation.jl") @@ -119,6 +125,7 @@ include("implementations/gen_eig.jl") include("implementations/schur.jl") include("implementations/polar.jl") include("implementations/orthnull.jl") +include("implementations/exponential.jl") include("common/gauge.jl") # needs to be defined after the functions are diff --git a/src/common/view.jl b/src/common/view.jl index 8cd989ea0..61c8da8b8 100644 --- a/src/common/view.jl +++ b/src/common/view.jl @@ -20,6 +20,26 @@ See also [`diagview`](@ref). diagonal(v::AbstractVector) = Diagonal(v) +""" + map_diagonal!(f, dst, src...) + +Map the scalar function `f` over all elements of the diagonal of `src...`, returning +a diagonal result. + +See also [`map_diagonal!`](@ref). +""" +map_diagonal(f, src, srcs...) = diagonal(f.(diagview(src), map(diagview, srcs)...)) + +""" + map_diagonal!(f, dst, src...) + +Map the scalar function `f` over all elements of the diagonal of `src...`, +into the diagonal elements of destination `dst`. + +See also [`map_diagonal`](@ref). +""" +map_diagonal!(f, dst, src, srcs...) = (diagview(dst) .= f.(diagview(src), map(diagview, srcs)...); dst) + # triangularind function lowertriangularind(A::AbstractMatrix) Base.require_one_based_indexing(A) diff --git a/src/implementations/exponential.jl b/src/implementations/exponential.jl new file mode 100644 index 000000000..2da6feeae --- /dev/null +++ b/src/implementations/exponential.jl @@ -0,0 +1,129 @@ +# Inputs +# ------ +function copy_input(::typeof(exponential), A::AbstractMatrix) + return copy!(similar(A, float(eltype(A))), A) +end + +copy_input(::typeof(exponential), A::Diagonal) = copy(A) +copy_input(::typeof(exponential), (τ, A)::Tuple{Number, AbstractMatrix}) = (τ, copy!(similar(A, float(eltype(A))), A)) +copy_input(::typeof(exponential), (τ, A)::Tuple{Number, Diagonal}) = τ, copy(A) + +function check_input(::typeof(exponential!), A::AbstractMatrix, expA::AbstractMatrix, alg::AbstractAlgorithm) + m = LinearAlgebra.checksquare(A) + @check_size(expA, (m, m)) + @check_scalar(expA, A) + return nothing +end + +function check_input(::typeof(exponential!), A::AbstractMatrix, expA::AbstractMatrix, alg::MatrixFunctionViaEigh) + m = LinearAlgebra.checksquare(A) + @check_size(expA, (m, m)) + @check_scalar(expA, A) + return nothing +end + +function check_input(::typeof(exponential!), A::AbstractMatrix, expA::AbstractMatrix, ::DiagonalAlgorithm) + m = LinearAlgebra.checksquare(A) + @assert isdiag(A) + @assert expA isa Diagonal + @check_size(expA, (m, m)) + @check_scalar(expA, A) + return nothing +end + +function check_input(::typeof(exponential!), (τ, A)::Tuple{Number, AbstractMatrix}, expA::AbstractMatrix, alg::AbstractAlgorithm) + m = LinearAlgebra.checksquare(A) + @check_size(expA, (m, m)) + @check_scalar(expA, A, (τ isa Real) ? identity : complex) + return nothing +end + +function check_input(::typeof(exponential!), (τ, A)::Tuple{Number, AbstractMatrix}, expA::AbstractMatrix, ::DiagonalAlgorithm) + m = LinearAlgebra.checksquare(A) + @assert isdiag(A) + @assert expA isa Diagonal + @check_size(expA, (m, m)) + @check_scalar(expA, A, (τ isa Real) ? identity : complex) + return nothing +end + +# Outputs +# ------- +initialize_output(::typeof(exponential!), A::AbstractMatrix, ::AbstractAlgorithm) = A +initialize_output(::typeof(exponential!), (τ, A)::Tuple{T, AbstractMatrix}, ::AbstractAlgorithm) where {T <: Real} = A +initialize_output(::typeof(exponential!), (τ, A)::Tuple{Number, AbstractMatrix}, ::AbstractAlgorithm) = complex(A) + +# Implementation +# -------------- +function exponential!(A, expA, alg::MatrixFunctionViaLA) + check_input(exponential!, A, expA, alg) + A = LinearAlgebra.exp!(A) + A === expA || copy!(expA, A) + return expA +end + +function exponential!(A, expA, alg::MatrixFunctionViaEigh) + check_input(exponential!, A, expA, alg) + D, V = eigh_full!(A, alg.eigh_alg) + expD = map_diagonal!(x -> exp(x / 2), D, D) + VexpD = rmul!(V, expD) + return mul!(expA, VexpD, V') +end + +function exponential!(A::AbstractMatrix, expA::AbstractMatrix, alg::MatrixFunctionViaEig) + check_input(exponential!, A, expA, alg) + D, V = eig_full!(A, alg.eig_alg) + expD = map_diagonal!(exp, D, D) + iV = inv(V) + VexpD = rmul!(V, expD) + if eltype(A) <: Real + expA .= real.(VexpD * iV) + else + mul!(expA, VexpD, iV) + end + return expA +end + +function exponential!((τ, A)::Tuple{Number, AbstractMatrix}, expA::AbstractMatrix, alg::MatrixFunctionViaLA) + check_input(exponential!, (τ, A), expA, alg) + expA .= A .* τ + return LinearAlgebra.exp!(expA) +end + +function exponential!((τ, A)::Tuple{Number, AbstractMatrix}, expA::AbstractMatrix, alg::MatrixFunctionViaEigh) + check_input(exponential!, (τ, A), expA, alg) + D, V = eigh_full!(A, alg.eigh_alg) + expD = map_diagonal(x -> exp(x * τ), D) + VexpD = V * expD + if eltype(A) <: Real && eltype(τ) <: Real + return expA .= real.(VexpD * V') + else + return mul!(expA, VexpD, V') + end +end + +function exponential!((τ, A)::Tuple{Number, AbstractMatrix}, expA, alg::MatrixFunctionViaEig) + check_input(exponential!, (τ, A), expA, alg) + D, V = eig_full!(A, alg.eig_alg) + expD = map_diagonal!(x -> exp(x * τ), D, D) + iV = inv(V) + VexpD = rmul!(V, expD) + if eltype(A) <: Real && eltype(τ) <: Real + expA .= real.(VexpD * iV) + return expA + else + return mul!(expA, VexpD, iV) + end +end + +# Diagonal logic +# -------------- +function exponential!(A, expA, alg::DiagonalAlgorithm) + check_input(exponential!, A, expA, alg) + return map_diagonal!(exp, expA, A) +end + +function exponential!((τ, A)::Tuple{Number, AbstractMatrix}, expA, alg::DiagonalAlgorithm) + check_input(exponential!, (τ, A), expA, alg) + return map_diagonal!(x -> exp(x * τ), expA, A) +end diff --git a/src/interface/exponential.jl b/src/interface/exponential.jl new file mode 100644 index 000000000..cd15d1557 --- /dev/null +++ b/src/interface/exponential.jl @@ -0,0 +1,43 @@ +# Exponential functions +# -------------- + +""" + exponential(A; kwargs...) -> expA + exponential(A, alg::AbstractAlgorithm) -> expA + exponential!(A, [expA]; kwargs...) -> expA + exponential!(A, [expA], alg::AbstractAlgorithm) -> expA + exponential((τ,A); kwargs...) -> expτA + exponential((τ,A), alg::AbstractAlgorithm) -> expτA + exponential!((τ,A), [expA]; kwargs...) -> expτA + exponential!((τ,A), [expA], alg::AbstractAlgorithm) -> expτA + +Compute the exponential of the square matrix `A` or `τ*A`, + +!!! note + The bang method `exponential!` optionally accepts the output structure and + possibly destroys the input matrix `A`. Always use the return value of the function + as it may not always be possible to use the provided `expA` as output. +""" +@functiondef exponential + +# Algorithm selection +# ------------------- +default_exponential_algorithm(A; kwargs...) = default_exponential_algorithm(typeof(A); kwargs...) +function default_exponential_algorithm(T::Type; kwargs...) + return MatrixFunctionViaLA(; kwargs...) +end +function default_exponential_algorithm(::Type{T}; kwargs...) where {T <: Diagonal} + return DiagonalAlgorithm(; kwargs...) +end + +for f in (:exponential!,) + @eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} + return default_exponential_algorithm(A; kwargs...) + end +end + +for f in (:exponential!,) + @eval function default_algorithm(::typeof($f), ::Tuple{A, B}; kwargs...) where {A, B} + return default_exponential_algorithm(B; kwargs...) + end +end diff --git a/src/interface/matrixfunctions.jl b/src/interface/matrixfunctions.jl new file mode 100644 index 000000000..ce24652de --- /dev/null +++ b/src/interface/matrixfunctions.jl @@ -0,0 +1,39 @@ +# ================================ +# EXPONENTIAL ALGORITHMS +# ================================ +""" + MatrixFunctionViaLA() + +Algorithm type to denote finding the exponential of `A` via the implementation of `LinearAlgebra`. +""" +@algdef MatrixFunctionViaLA + +""" + MatrixFunctionViaEigh() + +Algorithm type to denote finding the exponential `A` by computing the hermitian eigendecomposition of `A`. +The `eigh_alg` specifies which hermitian eigendecomposition implementation to use. +""" +struct MatrixFunctionViaEigh{A <: AbstractAlgorithm} <: AbstractAlgorithm + eigh_alg::A +end +function Base.show(io::IO, alg::MatrixFunctionViaEigh) + print(io, "MatrixFunctionViaEigh(") + _show_alg(io, alg.eigh_alg) + return print(io, ")") +end + +""" + MatrixFunctionViaEig() + +Algorithm type to denote finding the exponential `A` by computing the eigendecomposition of `A`. +The `eig_alg` specifies which eigendecomposition implementation to use. +""" +struct MatrixFunctionViaEig{A <: AbstractAlgorithm} <: AbstractAlgorithm + eig_alg::A +end +function Base.show(io::IO, alg::MatrixFunctionViaEig) + print(io, "MatrixFunctionViaEig(") + _show_alg(io, alg.eig_alg) + return print(io, ")") +end diff --git a/src/matrixfunctions.jl b/src/matrixfunctions.jl deleted file mode 100644 index 8b1378917..000000000 --- a/src/matrixfunctions.jl +++ /dev/null @@ -1 +0,0 @@ - diff --git a/test/exponential.jl b/test/exponential.jl new file mode 100644 index 000000000..ad7628c3a --- /dev/null +++ b/test/exponential.jl @@ -0,0 +1,88 @@ +using MatrixAlgebraKit +using Test +using TestExtras +using StableRNGs +using MatrixAlgebraKit: diagview +using LinearAlgebra +using LinearAlgebra: exp + +BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) +GenericFloats = (Float16, ComplexF16, BigFloat, Complex{BigFloat}) + +@testset "exponential! for T = $T" for T in BLASFloats + rng = StableRNG(123) + m = 54 + + A = LinearAlgebra.normalize!(randn(rng, T, m, m)) + Ac = copy(A) + expA = LinearAlgebra.exp(A) + + expA2 = @constinferred exponential(A) + @test expA ≈ expA2 + @test A == Ac + + algs = (MatrixFunctionViaLA(), MatrixFunctionViaEig(LAPACK_Simple())) + @testset "algorithm $alg" for alg in algs + expA2 = @constinferred exponential(A, alg) + @test expA ≈ expA2 + @test A == Ac + end + + @test_throws DomainError exponential(A; alg = MatrixFunctionViaEigh(LAPACK_QRIteration())) +end + +@testset "exponential! for T = $T" for T in BLASFloats + rng = StableRNG(123) + m = 54 + + A = randn(rng, T, m, m) + τ = randn(rng, T) + Ac = copy(A) + + Aτ = A * τ + expAτ = LinearAlgebra.exp(Aτ) + + expAτ2 = @constinferred exponential((τ, A)) + @test expAτ ≈ expAτ2 + @test A == Ac + + algs = (MatrixFunctionViaLA(), MatrixFunctionViaEig(LAPACK_Simple())) + @testset "algorithm $alg" for alg in algs + expAτ2 = @constinferred exponential((τ, A), alg) + @test expAτ ≈ expAτ2 + @test A == Ac + end + + @test_throws DomainError exponential((τ, A); alg = MatrixFunctionViaEigh(LAPACK_QRIteration())) +end + +@testset "exponential! for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...) + rng = StableRNG(123) + m = 54 + + A = Diagonal(randn(rng, T, m)) + τ = randn(rng, T) + Ac = copy(A) + + expA = LinearAlgebra.exp(A) + + expA2 = @constinferred exponential(A) + @test expA ≈ expA2 + @test A == Ac +end + +@testset "exponential! for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...) + rng = StableRNG(123) + m = 1 + + A = Diagonal(randn(rng, T, m)) + τ = randn(rng, T) + Ac = copy(A) + + Aτ = A * τ + expAτ = LinearAlgebra.exp(Aτ) + + expAτ2 = @constinferred exponential((τ, A)) + @test expAτ ≈ expAτ2 + @test A == Ac +end diff --git a/test/genericlinearalgebra/exponential.jl b/test/genericlinearalgebra/exponential.jl new file mode 100644 index 000000000..5743d3b4d --- /dev/null +++ b/test/genericlinearalgebra/exponential.jl @@ -0,0 +1,46 @@ +using MatrixAlgebraKit +using Test +using TestExtras +using StableRNGs +using MatrixAlgebraKit: diagview +using LinearAlgebra +using GenericLinearAlgebra + +GenericFloats = (BigFloat, Complex{BigFloat}) + +@testset "exponential! for T = $T" for T in GenericFloats + rng = StableRNG(123) + m = 54 + + A = project_hermitian!(randn(rng, T, m, m)) + D, V = @constinferred eigh_full(A) + algs = (MatrixFunctionViaEigh(GLA_QRIteration()),) + @testset "algorithm $alg" for alg in algs + expA = @constinferred exponential!(copy(A); alg) + expA2 = @constinferred exponential(A; alg) + @test expA2 ≈ expA + + Dexp, Vexp = @constinferred eigh_full(expA) + @test diagview(Dexp) ≈ LinearAlgebra.exp.(diagview(D)) + end +end + +using GenericSchur +@testset "exponential! for T1 = $T1, T2 = $T2" for T1 in GenericFloats, T2 in GenericFloats + rng = StableRNG(123) + m = 54 + A = project_hermitian!(randn(rng, T1, m, m)) + τ = randn(rng, T2) + + D, V = @constinferred eigh_full(A) + algs = (MatrixFunctionViaEigh(GLA_QRIteration()),) + @testset "algorithm $alg" for alg in algs + expτA = @constinferred exponential!((τ, copy(A)); alg) + expτA2 = @constinferred exponential((τ, A); alg) + @test expτA2 ≈ expτA + + Dexp, Vexp = @constinferred eig_full(expτA) + + @test sort(diagview(Dexp); by = real) ≈ sort(LinearAlgebra.exp.(diagview(D) .* τ); by = real) + end +end diff --git a/test/genericschur/exponential.jl b/test/genericschur/exponential.jl new file mode 100644 index 000000000..e2d26e2bd --- /dev/null +++ b/test/genericschur/exponential.jl @@ -0,0 +1,47 @@ +using MatrixAlgebraKit +using Test +using TestExtras +using StableRNGs +using MatrixAlgebraKit: diagview +using LinearAlgebra +using GenericSchur + +GenericFloats = (BigFloat, Complex{BigFloat}) + +@testset "exponential! for T = $T" for T in GenericFloats + rng = StableRNG(123) + m = 54 + + A = randn(rng, T, m, m) + D, V = @constinferred eig_full(A) + algs = (MatrixFunctionViaEig(GS_QRIteration()),) + expA_LA = @constinferred exponential(A) + @testset "algorithm $alg" for alg in algs + expA = @constinferred exponential!(copy(A)) + expA2 = @constinferred exponential(A; alg = alg) + @test expA ≈ expA_LA + @test expA2 ≈ expA + + Dexp, Vexp = @constinferred eig_full(expA) + @test sort(diagview(Dexp); by = imag) ≈ sort(LinearAlgebra.exp.(diagview(D)); by = imag) + end +end + +@testset "exponential! for T1 = $T1, T2 = $T2" for T1 in GenericFloats, T2 in GenericFloats + rng = StableRNG(123) + m = 54 + + A = randn(rng, T1, m, m) + τ = randn(rng, T2) + + D, V = @constinferred eig_full(A) + algs = (MatrixFunctionViaEig(GS_QRIteration()),) + @testset "algorithm $alg" for alg in algs + expτA = @constinferred exponential!((τ, copy(A))) + expτA2 = @constinferred exponential((τ, A); alg) + @test expτA2 ≈ expτA + + Dexp, Vexp = @constinferred eig_full(expτA) + @test sort(diagview(Dexp); by = x -> (imag(x), real(x))) ≈ sort(LinearAlgebra.exp.(diagview(D) .* τ); by = x -> (imag(x), real(x))) + end +end