-
Notifications
You must be signed in to change notification settings - Fork 7
Exponential #94
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Exponential #94
Changes from all commits
eb913cb
a3dc04d
c4564ee
8dc3ecd
d9fb748
5095cdb
89dfa23
996ecb5
dc78eb0
f220035
c68afad
95ddb06
5d6f4f3
c8e811c
0229417
cbbf813
d08d545
720ada5
d738c22
be111ea
c760a47
d0d14e1
cf98bd4
1536eb4
349800e
28b5bc5
c313009
55794da
e04d94c
04f08ab
bfcd6ca
3c3124f
04a1436
2d736f6
89f02c3
dba342f
ed4228e
329f493
62b448b
1cd993e
cdec4d5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,127 @@ | ||||||||
| # Inputs | ||||||||
| # ------ | ||||||||
| function copy_input(::typeof(exponential), A::AbstractMatrix) | ||||||||
| return copy!(similar(A, float(eltype(A))), A) | ||||||||
| end | ||||||||
|
|
||||||||
| copy_input(::typeof(exponential), A::Diagonal) = copy(A) | ||||||||
| copy_input(::typeof(exponential), (τ, A)::Tuple{Number, AbstractMatrix}) = (τ, copy!(similar(A, float(eltype(A))), A)) | ||||||||
| copy_input(::typeof(exponential), (τ, A)::Tuple{Number, Diagonal}) = τ, copy(A) | ||||||||
|
|
||||||||
| function check_input(::typeof(exponential!), A::AbstractMatrix, expA::AbstractMatrix, alg::AbstractAlgorithm) | ||||||||
| m, n = size(A) | ||||||||
| m == n || throw(DimensionMismatch("square input matrix expected. Got ($m,$n)")) | ||||||||
| @check_size(expA, (m, m)) | ||||||||
| @check_scalar(expA, A) | ||||||||
| return nothing | ||||||||
| end | ||||||||
|
|
||||||||
| function check_input(::typeof(exponential!), A::AbstractMatrix, expA::AbstractMatrix, alg::MatrixFunctionViaEigh) | ||||||||
| if !ishermitian(A) | ||||||||
| throw(DomainError(A, "Hermitian matrix was expected. Use `project_hermitian` to project onto the nearest hermitian matrix)")) | ||||||||
| end | ||||||||
|
Comment on lines
+20
to
+22
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
| m, n = size(A) | ||||||||
| m == n || throw(DimensionMismatch("square input matrix expected. Got ($m,$n)")) | ||||||||
| @check_size(expA, (m, m)) | ||||||||
| @check_scalar(expA, A) | ||||||||
| return nothing | ||||||||
| end | ||||||||
|
|
||||||||
| function check_input(::typeof(exponential!), A::AbstractMatrix, expA::AbstractMatrix, ::DiagonalAlgorithm) | ||||||||
| m, n = size(A) | ||||||||
| @assert m == n && isdiag(A) | ||||||||
| @assert expA isa Diagonal | ||||||||
| @check_size(expA, (m, m)) | ||||||||
| @check_scalar(expA, A) | ||||||||
| return nothing | ||||||||
| end | ||||||||
|
|
||||||||
| function check_input(type::typeof(exponential!), τ::Real, A::AbstractMatrix, expA::AbstractMatrix, alg) | ||||||||
| return check_input(type, A, expA, alg) | ||||||||
| end | ||||||||
|
|
||||||||
| function check_input(type::typeof(exponential!), τ, A::AbstractMatrix, expA::AbstractMatrix, alg) | ||||||||
| return check_input(type, complex(A), expA, alg) | ||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is probably not entirely how we should handle this, as we don't want to copy |
||||||||
| end | ||||||||
|
|
||||||||
|
|
||||||||
| # Outputs | ||||||||
| # ------- | ||||||||
| initialize_output(::typeof(exponential!), A::AbstractMatrix, ::AbstractAlgorithm) = A | ||||||||
| initialize_output(::typeof(exponential!), (τ, A)::Tuple{T, AbstractMatrix}, ::AbstractAlgorithm) where {T <: Real} = A | ||||||||
| initialize_output(::typeof(exponential!), (τ, A)::Tuple{Number, AbstractMatrix}, ::AbstractAlgorithm) = complex(A) | ||||||||
|
|
||||||||
| # Implementation | ||||||||
| # -------------- | ||||||||
| function exponential!(A, expA, alg::MatrixFunctionViaLA) | ||||||||
| check_input(exponential!, A, expA, alg) | ||||||||
| A = LinearAlgebra.exp!(A) | ||||||||
| A === expA || copy!(expA, A) | ||||||||
| return expA | ||||||||
| end | ||||||||
|
|
||||||||
| function exponential!(A, expA, alg::MatrixFunctionViaEigh) | ||||||||
| check_input(exponential!, A, expA, alg) | ||||||||
| D, V = eigh_full!(A, alg.eigh_alg) | ||||||||
| expD = map_diagonal!(x -> exp(x / 2), D, D) | ||||||||
| VexpD = rmul!(V, expD) | ||||||||
| return mul!(expA, VexpD, V') | ||||||||
| end | ||||||||
|
|
||||||||
| function exponential!(A::AbstractMatrix, expA::AbstractMatrix, alg::MatrixFunctionViaEig) | ||||||||
| check_input(exponential!, A, expA, alg) | ||||||||
| D, V = eig_full!(A, alg.eig_alg) | ||||||||
| expD = map_diagonal!(exp, D, D) | ||||||||
| iV = inv(V) | ||||||||
| VexpD = rmul!(V, expD) | ||||||||
| if eltype(A) <: Real | ||||||||
| expA .= real.(VexpD * iV) | ||||||||
| else | ||||||||
| mul!(expA, VexpD, iV) | ||||||||
| end | ||||||||
| return expA | ||||||||
| end | ||||||||
|
|
||||||||
| function exponential!((τ, A)::Tuple{Number, AbstractMatrix}, expA::AbstractMatrix, alg::MatrixFunctionViaLA) | ||||||||
| check_input(exponential!, τ, A, expA, alg) | ||||||||
| expA .= A .* τ | ||||||||
| return LinearAlgebra.exp!(expA) | ||||||||
| end | ||||||||
|
|
||||||||
| function exponential!((τ, A)::Tuple{Number, AbstractMatrix}, expA::AbstractMatrix, alg::MatrixFunctionViaEigh) | ||||||||
| check_input(exponential!, τ, A, expA, alg) | ||||||||
| D, V = eigh_full!(A, alg.eigh_alg) | ||||||||
| expD = map_diagonal(x -> exp(x * τ), D) | ||||||||
| VexpD = V * expD | ||||||||
| if eltype(A) <: Real && eltype(τ) <: Real | ||||||||
| return expA .= real.(VexpD * V') | ||||||||
| else | ||||||||
| return mul!(expA, VexpD, V') | ||||||||
| end | ||||||||
| end | ||||||||
|
|
||||||||
| function exponential!((τ, A)::Tuple{Number, AbstractMatrix}, expA, alg::MatrixFunctionViaEig) | ||||||||
| check_input(exponential!, τ, A, expA, alg) | ||||||||
| D, V = eig_full!(A, alg.eig_alg) | ||||||||
| expD = map_diagonal!(x -> exp(x * τ), D, D) | ||||||||
| iV = inv(V) | ||||||||
| VexpD = rmul!(V, expD) | ||||||||
| if eltype(A) <: Real && eltype(τ) <: Real | ||||||||
| expA .= real.(VexpD * iV) | ||||||||
| return expA | ||||||||
| else | ||||||||
| return mul!(expA, VexpD, iV) | ||||||||
| end | ||||||||
| end | ||||||||
|
|
||||||||
| # Diagonal logic | ||||||||
| # -------------- | ||||||||
| function exponential!(A, expA, alg::DiagonalAlgorithm) | ||||||||
| check_input(exponential!, A, expA, alg) | ||||||||
| return map_diagonal!(exp, expA, A) | ||||||||
| end | ||||||||
|
|
||||||||
| function exponential!((τ, A)::Tuple{Number, AbstractMatrix}, expA, alg::DiagonalAlgorithm) | ||||||||
| check_input(exponential!, τ, A, expA, alg) | ||||||||
| return map_diagonal!(x -> exp(x * τ), expA, A) | ||||||||
| end | ||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,43 @@ | ||
| # Exponential functions | ||
| # -------------- | ||
|
|
||
| """ | ||
| exponential(A; kwargs...) -> expA | ||
| exponential(A, alg::AbstractAlgorithm) -> expA | ||
| exponential!(A, [expA]; kwargs...) -> expA | ||
| exponential!(A, [expA], alg::AbstractAlgorithm) -> expA | ||
| exponential((τ,A); kwargs...) -> expτA | ||
| exponential((τ,A), alg::AbstractAlgorithm) -> expτA | ||
| exponential!((τ,A), [expA]; kwargs...) -> expτA | ||
| exponential!((τ,A), [expA], alg::AbstractAlgorithm) -> expτA | ||
|
|
||
| Compute the exponential of the square matrix `A` or `τ*A`, | ||
|
|
||
| !!! note | ||
| The bang method `exponential!` optionally accepts the output structure and | ||
| possibly destroys the input matrix `A`. Always use the return value of the function | ||
| as it may not always be possible to use the provided `expA` as output. | ||
| """ | ||
| @functiondef exponential | ||
|
|
||
| # Algorithm selection | ||
| # ------------------- | ||
| default_exponential_algorithm(A; kwargs...) = default_exponential_algorithm(typeof(A); kwargs...) | ||
| function default_exponential_algorithm(T::Type; kwargs...) | ||
| return MatrixFunctionViaLA(; kwargs...) | ||
| end | ||
| function default_exponential_algorithm(::Type{T}; kwargs...) where {T <: Diagonal} | ||
| return DiagonalAlgorithm(; kwargs...) | ||
| end | ||
|
|
||
| for f in (:exponential!,) | ||
| @eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} | ||
| return default_exponential_algorithm(A; kwargs...) | ||
| end | ||
| end | ||
|
|
||
| for f in (:exponential!,) | ||
| @eval function default_algorithm(::typeof($f), ::Tuple{A, B}; kwargs...) where {A, B} | ||
| return default_exponential_algorithm(B; kwargs...) | ||
| end | ||
| end |
|
sanderdemeyer marked this conversation as resolved.
|
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,39 @@ | ||
| # ================================ | ||
| # EXPONENTIAL ALGORITHMS | ||
| # ================================ | ||
| """ | ||
| MatrixFunctionViaLA() | ||
|
|
||
| Algorithm type to denote finding the exponential of `A` via the implementation of `LinearAlgebra`. | ||
| """ | ||
| @algdef MatrixFunctionViaLA | ||
|
|
||
| """ | ||
| MatrixFunctionViaEigh() | ||
|
|
||
| Algorithm type to denote finding the exponential `A` by computing the hermitian eigendecomposition of `A`. | ||
| The `eigh_alg` specifies which hermitian eigendecomposition implementation to use. | ||
| """ | ||
| struct MatrixFunctionViaEigh{A <: AbstractAlgorithm} <: AbstractAlgorithm | ||
| eigh_alg::A | ||
| end | ||
| function Base.show(io::IO, alg::MatrixFunctionViaEigh) | ||
| print(io, "MatrixFunctionViaEigh(") | ||
| _show_alg(io, alg.eigh_alg) | ||
| return print(io, ")") | ||
| end | ||
|
|
||
| """ | ||
| MatrixFunctionViaEig() | ||
|
|
||
| Algorithm type to denote finding the exponential `A` by computing the eigendecomposition of `A`. | ||
| The `eig_alg` specifies which eigendecomposition implementation to use. | ||
| """ | ||
| struct MatrixFunctionViaEig{A <: AbstractAlgorithm} <: AbstractAlgorithm | ||
| eig_alg::A | ||
| end | ||
| function Base.show(io::IO, alg::MatrixFunctionViaEig) | ||
| print(io, "MatrixFunctionViaEig(") | ||
| _show_alg(io, alg.eig_alg) | ||
| return print(io, ")") | ||
| end |
This file was deleted.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,88 @@ | ||
| using MatrixAlgebraKit | ||
| using Test | ||
| using TestExtras | ||
| using StableRNGs | ||
| using MatrixAlgebraKit: diagview | ||
| using LinearAlgebra | ||
| using LinearAlgebra: exp | ||
|
|
||
| BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) | ||
| GenericFloats = (Float16, ComplexF16, BigFloat, Complex{BigFloat}) | ||
|
|
||
| @testset "exponential! for T = $T" for T in BLASFloats | ||
| rng = StableRNG(123) | ||
| m = 54 | ||
|
|
||
| A = LinearAlgebra.normalize!(randn(rng, T, m, m)) | ||
| Ac = copy(A) | ||
| expA = LinearAlgebra.exp(A) | ||
|
|
||
| expA2 = @constinferred exponential(A) | ||
| @test expA ≈ expA2 | ||
| @test A == Ac | ||
|
|
||
| algs = (MatrixFunctionViaLA(), MatrixFunctionViaEig(LAPACK_Simple())) | ||
| @testset "algorithm $alg" for alg in algs | ||
| expA2 = @constinferred exponential(A, alg) | ||
| @test expA ≈ expA2 | ||
| @test A == Ac | ||
| end | ||
|
|
||
| @test_throws DomainError exponential(A; alg = MatrixFunctionViaEigh(LAPACK_QRIteration())) | ||
| end | ||
|
|
||
| @testset "exponential! for T = $T" for T in BLASFloats | ||
| rng = StableRNG(123) | ||
| m = 54 | ||
|
|
||
| A = randn(rng, T, m, m) | ||
| τ = randn(rng, T) | ||
| Ac = copy(A) | ||
|
|
||
| Aτ = A * τ | ||
| expAτ = LinearAlgebra.exp(Aτ) | ||
|
|
||
| expAτ2 = @constinferred exponential((τ, A)) | ||
| @test expAτ ≈ expAτ2 | ||
| @test A == Ac | ||
|
|
||
| algs = (MatrixFunctionViaLA(), MatrixFunctionViaEig(LAPACK_Simple())) | ||
| @testset "algorithm $alg" for alg in algs | ||
| expAτ2 = @constinferred exponential((τ, A), alg) | ||
| @test expAτ ≈ expAτ2 | ||
| @test A == Ac | ||
| end | ||
|
|
||
| @test_throws DomainError exponential((τ, A); alg = MatrixFunctionViaEigh(LAPACK_QRIteration())) | ||
| end | ||
|
|
||
| @testset "exponential! for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...) | ||
| rng = StableRNG(123) | ||
| m = 54 | ||
|
|
||
| A = Diagonal(randn(rng, T, m)) | ||
| τ = randn(rng, T) | ||
| Ac = copy(A) | ||
|
|
||
| expA = LinearAlgebra.exp(A) | ||
|
|
||
| expA2 = @constinferred exponential(A) | ||
| @test expA ≈ expA2 | ||
| @test A == Ac | ||
| end | ||
|
|
||
| @testset "exponential! for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...) | ||
| rng = StableRNG(123) | ||
| m = 1 | ||
|
|
||
| A = Diagonal(randn(rng, T, m)) | ||
| τ = randn(rng, T) | ||
| Ac = copy(A) | ||
|
|
||
| Aτ = A * τ | ||
| expAτ = LinearAlgebra.exp(Aτ) | ||
|
|
||
| expAτ2 = @constinferred exponential((τ, A)) | ||
| @test expAτ ≈ expAτ2 | ||
| @test A == Ac | ||
| end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TODO: do we want to keep this check here, knowing that it will again be checked in the iimplementation of
eigh_full?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Indeed, we should remove this test.